Replace helper functions with native burn functions

This commit is contained in:
Gadersd
2023-09-07 12:23:18 -04:00
parent a62795347f
commit f4c58c1790
20 changed files with 1091 additions and 950 deletions

View File

@@ -1,59 +1,59 @@
pub mod load;
use burn::{
config::Config,
config::Config,
module::{Module, Param},
nn::{self, PaddingConfig2d, conv::{Conv2d, Conv2dConfig, Conv2dRecord}},
nn::{
self,
conv::{Conv2d, Conv2dConfig, Conv2dRecord},
PaddingConfig2d,
},
tensor::{
activation::{sigmoid, softmax},
backend::Backend,
activation::{softmax, sigmoid},
module::embedding,
Tensor,
Distribution,
Int,
module::embedding,
Distribution, Int, Tensor,
},
};
use crate::helper::div_roundup;
use super::silu::*;
use super::groupnorm::*;
use super::attention::qkv_attention;
use super::groupnorm::*;
use super::silu::*;
use std::iter;
#[derive(Config)]
pub struct AutoencoderConfig {}
impl AutoencoderConfig {
pub fn init<B: Backend>(&self) -> Autoencoder<B> {
let encoder = EncoderConfig::new(vec![(128, 128), (128, 256), (256, 512), (512, 512)], 32, 8).init();
let decoder = DecoderConfig::new(vec![(512, 512), (512, 512), (512, 256), (256, 128)], 32).init();
let encoder =
EncoderConfig::new(vec![(128, 128), (128, 256), (256, 512), (512, 512)], 32, 8).init();
let decoder =
DecoderConfig::new(vec![(512, 512), (512, 512), (512, 256), (256, 128)], 32).init();
let quant_conv = Conv2dConfig::new([8, 8], [1, 1]).init();
let post_quant_conv = Conv2dConfig::new([4, 4], [1, 1]).init();
Autoencoder {
encoder,
decoder,
quant_conv,
post_quant_conv,
encoder,
decoder,
quant_conv,
post_quant_conv,
}
}
}
#[derive(Module, Debug)]
pub struct Autoencoder<B: Backend> {
encoder: Encoder<B>,
decoder: Decoder<B>,
quant_conv: Conv2d<B>,
post_quant_conv: Conv2d<B>,
encoder: Encoder<B>,
decoder: Decoder<B>,
quant_conv: Conv2d<B>,
post_quant_conv: Conv2d<B>,
}
impl<B: Backend> Autoencoder<B> {
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
self.decode_latent( self.encode_image(x) )
self.decode_latent(self.encode_image(x))
}
pub fn encode_image(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
@@ -72,48 +72,60 @@ impl<B: Backend> Autoencoder<B> {
#[derive(Config)]
pub struct EncoderConfig {
channels: Vec<(usize, usize)>,
n_group: usize,
n_channels_out: usize,
channels: Vec<(usize, usize)>,
n_group: usize,
n_channels_out: usize,
}
impl EncoderConfig {
fn init<B: Backend>(&self) -> Encoder<B> {
let n_expanded_channels_initial = self.channels.first().map(|f| f.1).expect("Channels must not be empty.");
let n_expanded_channels_initial = self
.channels
.first()
.map(|f| f.1)
.expect("Channels must not be empty.");
let n_expanded_channels_final = self.channels.first().unwrap().0;
let conv_in = Conv2dConfig::new([3, n_expanded_channels_initial], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init();
let conv_in = Conv2dConfig::new([3, n_expanded_channels_initial], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init();
let blocks = self.channels.iter().enumerate().map(|(i, &(n_channel_in, n_channel_out))| {
let downsample = i != self.channels.len() - 1;
EncoderBlockConfig::new(n_channel_in, n_channel_out, downsample).init()
}).collect();
let blocks = self
.channels
.iter()
.enumerate()
.map(|(i, &(n_channel_in, n_channel_out))| {
let downsample = i != self.channels.len() - 1;
EncoderBlockConfig::new(n_channel_in, n_channel_out, downsample).init()
})
.collect();
let mid = MidConfig::new(n_expanded_channels_final).init();
let norm_out = GroupNormConfig::new(self.n_group, n_expanded_channels_final).init();
let silu = SILU::new();
let conv_out = Conv2dConfig::new([n_expanded_channels_final, self.n_channels_out], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init();
let conv_out = Conv2dConfig::new([n_expanded_channels_final, self.n_channels_out], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init();
Encoder {
conv_in,
mid,
blocks,
norm_out,
silu,
conv_out,
conv_in,
mid,
blocks,
norm_out,
silu,
conv_out,
}
}
}
#[derive(Module, Debug)]
pub struct Encoder<B: Backend> {
conv_in: Conv2d<B>,
mid: Mid<B>,
blocks: Vec<EncoderBlock<B>>,
norm_out: GroupNorm<B>,
silu: SILU,
conv_out: Conv2d<B>,
conv_in: Conv2d<B>,
mid: Mid<B>,
blocks: Vec<EncoderBlock<B>>,
norm_out: GroupNorm<B>,
silu: SILU,
conv_out: Conv2d<B>,
}
impl<B: Backend> Encoder<B> {
@@ -126,55 +138,66 @@ impl<B: Backend> Encoder<B> {
}
let x = self.mid.forward(x);
self.conv_out.forward( self.silu.forward( self.norm_out.forward(x) ) )
self.conv_out
.forward(self.silu.forward(self.norm_out.forward(x)))
}
}
#[derive(Config)]
pub struct DecoderConfig {
channels: Vec<(usize, usize)>,
n_group: usize,
channels: Vec<(usize, usize)>,
n_group: usize,
}
impl DecoderConfig {
fn init<B: Backend>(&self) -> Decoder<B> {
let n_expanded_channels = self.channels.first().map(|f| f.0).expect("Channels must not be empty.");
let n_expanded_channels = self
.channels
.first()
.map(|f| f.0)
.expect("Channels must not be empty.");
let n_condensed_channels = self.channels.last().unwrap().1;
let conv_in = Conv2dConfig::new([4, n_expanded_channels], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init();
let conv_in = Conv2dConfig::new([4, n_expanded_channels], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init();
let mid = MidConfig::new(n_expanded_channels).init();
let blocks = self.channels.iter().enumerate().map(|(i, &(n_channel_in, n_channel_out))| {
let upsample = i != self.channels.len() - 1;
DecoderBlockConfig::new(n_channel_in, n_channel_out, upsample).init()
}).collect();
let blocks = self
.channels
.iter()
.enumerate()
.map(|(i, &(n_channel_in, n_channel_out))| {
let upsample = i != self.channels.len() - 1;
DecoderBlockConfig::new(n_channel_in, n_channel_out, upsample).init()
})
.collect();
let norm_out = GroupNormConfig::new(self.n_group, n_condensed_channels).init();
let silu = SILU::new();
let conv_out = Conv2dConfig::new([n_condensed_channels, 3], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init();
let conv_out = Conv2dConfig::new([n_condensed_channels, 3], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init();
Decoder {
conv_in,
mid,
blocks,
norm_out,
silu,
conv_out,
conv_in,
mid,
blocks,
norm_out,
silu,
conv_out,
}
}
}
#[derive(Module, Debug)]
pub struct Decoder<B: Backend> {
conv_in: Conv2d<B>,
mid: Mid<B>,
blocks: Vec<DecoderBlock<B>>,
norm_out: GroupNorm<B>,
silu: SILU,
conv_out: Conv2d<B>,
conv_in: Conv2d<B>,
mid: Mid<B>,
blocks: Vec<DecoderBlock<B>>,
norm_out: GroupNorm<B>,
silu: SILU,
conv_out: Conv2d<B>,
}
impl<B: Backend> Decoder<B> {
@@ -187,15 +210,16 @@ impl<B: Backend> Decoder<B> {
x = block.forward(x);
}
self.conv_out.forward( self.silu.forward( self.norm_out.forward(x) ) )
self.conv_out
.forward(self.silu.forward(self.norm_out.forward(x)))
}
}
#[derive(Config)]
pub struct EncoderBlockConfig {
n_channels_in: usize,
n_channels_out: usize,
downsample: bool,
n_channels_in: usize,
n_channels_out: usize,
downsample: bool,
}
impl EncoderBlockConfig {
@@ -204,24 +228,28 @@ impl EncoderBlockConfig {
let res2 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init();
let downsampler = if self.downsample {
let padding = Padding::new(0, 1, 0, 1);
Some( PaddedConv2dConfig::new([self.n_channels_out, self.n_channels_out], 3, padding).with_stride(2).init() )
Some(
PaddedConv2dConfig::new([self.n_channels_out, self.n_channels_out], 3, padding)
.with_stride(2)
.init(),
)
} else {
None
};
EncoderBlock {
res1,
res2,
downsampler,
res1,
res2,
downsampler,
}
}
}
#[derive(Module, Debug)]
pub struct EncoderBlock<B: Backend> {
res1: ResnetBlock<B>,
res2: ResnetBlock<B>,
downsampler: Option<PaddedConv2d<B>>,
res1: ResnetBlock<B>,
res2: ResnetBlock<B>,
downsampler: Option<PaddedConv2d<B>>,
}
impl<B: Backend> EncoderBlock<B> {
@@ -238,9 +266,9 @@ impl<B: Backend> EncoderBlock<B> {
#[derive(Config)]
pub struct DecoderBlockConfig {
n_channels_in: usize,
n_channels_out: usize,
upsample: bool,
n_channels_in: usize,
n_channels_out: usize,
upsample: bool,
}
impl DecoderBlockConfig {
@@ -249,26 +277,30 @@ impl DecoderBlockConfig {
let res2 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init();
let res3 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init();
let upsampler = if self.upsample {
Some( Conv2dConfig::new([self.n_channels_out, self.n_channels_out], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init() )
Some(
Conv2dConfig::new([self.n_channels_out, self.n_channels_out], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init(),
)
} else {
None
};
DecoderBlock {
res1,
res2,
res3,
upsampler,
res1,
res2,
res3,
upsampler,
}
}
}
#[derive(Module, Debug)]
pub struct DecoderBlock<B: Backend> {
res1: ResnetBlock<B>,
res2: ResnetBlock<B>,
res3: ResnetBlock<B>,
upsampler: Option<Conv2d<B>>,
res1: ResnetBlock<B>,
res2: ResnetBlock<B>,
res3: ResnetBlock<B>,
upsampler: Option<Conv2d<B>>,
}
impl<B: Backend> DecoderBlock<B> {
@@ -280,10 +312,10 @@ impl<B: Backend> DecoderBlock<B> {
if let Some(d) = self.upsampler.as_ref() {
let [n_batch, n_channel, height, width] = x.dims();
let x = x
.reshape([n_batch, n_channel, height, 1, width, 1])
.repeat(3, 2)
.repeat(5, 2)
.reshape([n_batch, n_channel, 2 * height, 2 * width]);
.reshape([n_batch, n_channel, height, 1, width, 1])
.repeat(3, 2)
.repeat(5, 2)
.reshape([n_batch, n_channel, 2 * height, 2 * width]);
d.forward(x)
} else {
x
@@ -291,14 +323,13 @@ impl<B: Backend> DecoderBlock<B> {
}
}
#[derive(Config)]
pub struct PaddedConv2dConfig {
channels: [usize; 2],
kernel_size: usize,
channels: [usize; 2],
kernel_size: usize,
#[config(default = 1)]
stride: usize,
padding: Padding,
stride: usize,
padding: Padding,
}
impl PaddedConv2dConfig {
@@ -328,57 +359,68 @@ impl PaddedConv2dConfig {
let padding = self.padding;
PaddedConv2d {
conv,
kernel_size,
stride,
padding,
padding_actual,
conv,
kernel_size,
stride,
padding,
padding_actual,
}
}
}
fn div_roundup(x: usize, y: usize) -> usize {
(x + y - 1) / y
}
#[derive(Module, Debug)]
pub struct PaddedConv2d<B: Backend> {
conv: Conv2d<B>,
kernel_size: usize,
stride: usize,
padding: Padding,
padding_actual: [usize; 2],
conv: Conv2d<B>,
kernel_size: usize,
stride: usize,
padding: Padding,
padding_actual: [usize; 2],
}
impl<B: Backend> PaddedConv2d<B> {
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
println!("{} {} {:?} {:?}", self.kernel_size, self.stride, self.padding, self.padding_actual);
println!(
"{} {} {:?} {:?}",
self.kernel_size, self.stride, self.padding, self.padding_actual
);
let [n_batch, n_channel, height, width] = x.dims();
let desired_height = (self.padding.pad_top + self.padding.pad_bottom + height - self.kernel_size) / self.stride + 1;
let desired_width = (self.padding.pad_left + self.padding.pad_right + width - self.kernel_size) / self.stride + 1;
let desired_height = (self.padding.pad_top + self.padding.pad_bottom + height
- self.kernel_size)
/ self.stride
+ 1;
let desired_width = (self.padding.pad_left + self.padding.pad_right + width
- self.kernel_size)
/ self.stride
+ 1;
let skip_vert = (self.padding_actual[0] - self.padding.pad_top) / self.stride;
let skip_hor = (self.padding_actual[1] - self.padding.pad_left) / self.stride;
self.conv
.forward(x)
.slice([
0..n_batch,
0..n_channel,
skip_vert..(skip_vert + desired_height),
skip_hor..(skip_hor + desired_width)
])
self.conv.forward(x).slice([
0..n_batch,
0..n_channel,
skip_vert..(skip_vert + desired_height),
skip_hor..(skip_hor + desired_width),
])
}
}
#[derive(Config, Module, Copy, Debug)]
pub struct Padding {
pad_left: usize,
pad_right: usize,
pad_top: usize,
pad_left: usize,
pad_right: usize,
pad_top: usize,
pad_bottom: usize,
}
#[derive(Config)]
pub struct MidConfig {
n_channel: usize,
n_channel: usize,
}
impl MidConfig {
@@ -388,18 +430,18 @@ impl MidConfig {
let block_2 = ResnetBlockConfig::new(self.n_channel, self.n_channel).init();
Mid {
block_1,
attn,
block_2,
block_1,
attn,
block_2,
}
}
}
#[derive(Module, Debug)]
pub struct Mid<B: Backend> {
block_1: ResnetBlock<B>,
attn: ConvSelfAttentionBlock<B>,
block_2: ResnetBlock<B>,
block_1: ResnetBlock<B>,
attn: ConvSelfAttentionBlock<B>,
block_2: ResnetBlock<B>,
}
impl<B: Backend> Mid<B> {
@@ -411,21 +453,24 @@ impl<B: Backend> Mid<B> {
}
}
#[derive(Config)]
pub struct ResnetBlockConfig {
in_channels: usize,
out_channels: usize,
in_channels: usize,
out_channels: usize,
}
impl ResnetBlockConfig {
fn init<B: Backend>(&self) -> ResnetBlock<B> {
let norm1 = GroupNormConfig::new(32, self.in_channels).init();
let conv1 = Conv2dConfig::new([self.in_channels, self.out_channels], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init();
let conv1 = Conv2dConfig::new([self.in_channels, self.out_channels], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init();
let norm2 = GroupNormConfig::new(32, self.out_channels).init();
let conv2 = Conv2dConfig::new([self.out_channels, self.out_channels], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init();
let conv2 = Conv2dConfig::new([self.out_channels, self.out_channels], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init();
let nin_shortcut = if self.in_channels != self.out_channels {
Some( Conv2dConfig::new([self.in_channels, self.out_channels], [1, 1]).init() )
Some(Conv2dConfig::new([self.in_channels, self.out_channels], [1, 1]).init())
} else {
None
};
@@ -434,34 +479,37 @@ impl ResnetBlockConfig {
let silu2 = SILU::new();
ResnetBlock {
norm1,
silu1,
conv1,
norm2,
silu2,
conv2,
nin_shortcut,
norm1,
silu1,
conv1,
norm2,
silu2,
conv2,
nin_shortcut,
}
}
}
#[derive(Module, Debug)]
pub struct ResnetBlock<B: Backend> {
norm1: GroupNorm<B>,
silu1: SILU,
conv1: Conv2d<B>,
norm2: GroupNorm<B>,
silu2: SILU,
conv2: Conv2d<B>,
nin_shortcut: Option<Conv2d<B>>,
norm1: GroupNorm<B>,
silu1: SILU,
conv1: Conv2d<B>,
norm2: GroupNorm<B>,
silu2: SILU,
conv2: Conv2d<B>,
nin_shortcut: Option<Conv2d<B>>,
}
impl<B: Backend> ResnetBlock<B> {
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
let h = self.conv1.forward( self.silu1.forward(self.norm1.forward(x.clone())) );
let h = self.conv2.forward( self.silu2.forward(self.norm2.forward(h)) );
let h = self
.conv1
.forward(self.silu1.forward(self.norm1.forward(x.clone())));
let h = self
.conv2
.forward(self.silu2.forward(self.norm2.forward(h)));
if let Some(ns) = self.nin_shortcut.as_ref() {
ns.forward(x) + h
} else {
@@ -472,7 +520,7 @@ impl<B: Backend> ResnetBlock<B> {
#[derive(Config)]
pub struct ConvSelfAttentionBlockConfig {
n_channel: usize,
n_channel: usize,
}
impl ConvSelfAttentionBlockConfig {
@@ -484,22 +532,22 @@ impl ConvSelfAttentionBlockConfig {
let proj_out = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init();
ConvSelfAttentionBlock {
norm,
q,
k,
v,
proj_out,
norm,
q,
k,
v,
proj_out,
}
}
}
#[derive(Module, Debug)]
pub struct ConvSelfAttentionBlock<B: Backend> {
norm: GroupNorm<B>,
q: Conv2d<B>,
k: Conv2d<B>,
v: Conv2d<B>,
proj_out: Conv2d<B>,
norm: GroupNorm<B>,
q: Conv2d<B>,
k: Conv2d<B>,
v: Conv2d<B>,
proj_out: Conv2d<B>,
}
impl<B: Backend> ConvSelfAttentionBlock<B> {
@@ -508,9 +556,21 @@ impl<B: Backend> ConvSelfAttentionBlock<B> {
let h = self.norm.forward(x.clone());
let q = self.q.forward(h.clone()).reshape([n_batch, n_channel, height * width]).swap_dims(1, 2);
let k = self.k.forward(h.clone()).reshape([n_batch, n_channel, height * width]).swap_dims(1, 2);
let v = self.v.forward(h).reshape([n_batch, n_channel, height * width]).swap_dims(1, 2);
let q = self
.q
.forward(h.clone())
.reshape([n_batch, n_channel, height * width])
.swap_dims(1, 2);
let k = self
.k
.forward(h.clone())
.reshape([n_batch, n_channel, height * width])
.swap_dims(1, 2);
let v = self
.v
.forward(h)
.reshape([n_batch, n_channel, height * width])
.swap_dims(1, 2);
let wv = qkv_attention(q, k, v, None, 1)
.swap_dims(1, 2)