Replace helper functions with native burn functions
This commit is contained in:
@@ -1,59 +1,59 @@
|
||||
pub mod load;
|
||||
|
||||
use burn::{
|
||||
config::Config,
|
||||
config::Config,
|
||||
module::{Module, Param},
|
||||
nn::{self, PaddingConfig2d, conv::{Conv2d, Conv2dConfig, Conv2dRecord}},
|
||||
nn::{
|
||||
self,
|
||||
conv::{Conv2d, Conv2dConfig, Conv2dRecord},
|
||||
PaddingConfig2d,
|
||||
},
|
||||
tensor::{
|
||||
activation::{sigmoid, softmax},
|
||||
backend::Backend,
|
||||
activation::{softmax, sigmoid},
|
||||
module::embedding,
|
||||
Tensor,
|
||||
Distribution,
|
||||
Int,
|
||||
module::embedding,
|
||||
Distribution, Int, Tensor,
|
||||
},
|
||||
};
|
||||
|
||||
use crate::helper::div_roundup;
|
||||
|
||||
use super::silu::*;
|
||||
use super::groupnorm::*;
|
||||
use super::attention::qkv_attention;
|
||||
use super::groupnorm::*;
|
||||
use super::silu::*;
|
||||
|
||||
use std::iter;
|
||||
|
||||
|
||||
#[derive(Config)]
|
||||
pub struct AutoencoderConfig {}
|
||||
|
||||
impl AutoencoderConfig {
|
||||
pub fn init<B: Backend>(&self) -> Autoencoder<B> {
|
||||
let encoder = EncoderConfig::new(vec![(128, 128), (128, 256), (256, 512), (512, 512)], 32, 8).init();
|
||||
let decoder = DecoderConfig::new(vec![(512, 512), (512, 512), (512, 256), (256, 128)], 32).init();
|
||||
let encoder =
|
||||
EncoderConfig::new(vec![(128, 128), (128, 256), (256, 512), (512, 512)], 32, 8).init();
|
||||
let decoder =
|
||||
DecoderConfig::new(vec![(512, 512), (512, 512), (512, 256), (256, 128)], 32).init();
|
||||
let quant_conv = Conv2dConfig::new([8, 8], [1, 1]).init();
|
||||
let post_quant_conv = Conv2dConfig::new([4, 4], [1, 1]).init();
|
||||
|
||||
Autoencoder {
|
||||
encoder,
|
||||
decoder,
|
||||
quant_conv,
|
||||
post_quant_conv,
|
||||
encoder,
|
||||
decoder,
|
||||
quant_conv,
|
||||
post_quant_conv,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Autoencoder<B: Backend> {
|
||||
encoder: Encoder<B>,
|
||||
decoder: Decoder<B>,
|
||||
quant_conv: Conv2d<B>,
|
||||
post_quant_conv: Conv2d<B>,
|
||||
encoder: Encoder<B>,
|
||||
decoder: Decoder<B>,
|
||||
quant_conv: Conv2d<B>,
|
||||
post_quant_conv: Conv2d<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Autoencoder<B> {
|
||||
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
self.decode_latent( self.encode_image(x) )
|
||||
self.decode_latent(self.encode_image(x))
|
||||
}
|
||||
|
||||
pub fn encode_image(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
@@ -72,48 +72,60 @@ impl<B: Backend> Autoencoder<B> {
|
||||
|
||||
#[derive(Config)]
|
||||
pub struct EncoderConfig {
|
||||
channels: Vec<(usize, usize)>,
|
||||
n_group: usize,
|
||||
n_channels_out: usize,
|
||||
channels: Vec<(usize, usize)>,
|
||||
n_group: usize,
|
||||
n_channels_out: usize,
|
||||
}
|
||||
|
||||
impl EncoderConfig {
|
||||
fn init<B: Backend>(&self) -> Encoder<B> {
|
||||
let n_expanded_channels_initial = self.channels.first().map(|f| f.1).expect("Channels must not be empty.");
|
||||
let n_expanded_channels_initial = self
|
||||
.channels
|
||||
.first()
|
||||
.map(|f| f.1)
|
||||
.expect("Channels must not be empty.");
|
||||
let n_expanded_channels_final = self.channels.first().unwrap().0;
|
||||
|
||||
let conv_in = Conv2dConfig::new([3, n_expanded_channels_initial], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init();
|
||||
let conv_in = Conv2dConfig::new([3, n_expanded_channels_initial], [3, 3])
|
||||
.with_padding(PaddingConfig2d::Explicit(1, 1))
|
||||
.init();
|
||||
|
||||
let blocks = self.channels.iter().enumerate().map(|(i, &(n_channel_in, n_channel_out))| {
|
||||
let downsample = i != self.channels.len() - 1;
|
||||
EncoderBlockConfig::new(n_channel_in, n_channel_out, downsample).init()
|
||||
}).collect();
|
||||
let blocks = self
|
||||
.channels
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, &(n_channel_in, n_channel_out))| {
|
||||
let downsample = i != self.channels.len() - 1;
|
||||
EncoderBlockConfig::new(n_channel_in, n_channel_out, downsample).init()
|
||||
})
|
||||
.collect();
|
||||
|
||||
let mid = MidConfig::new(n_expanded_channels_final).init();
|
||||
let norm_out = GroupNormConfig::new(self.n_group, n_expanded_channels_final).init();
|
||||
let silu = SILU::new();
|
||||
let conv_out = Conv2dConfig::new([n_expanded_channels_final, self.n_channels_out], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init();
|
||||
let conv_out = Conv2dConfig::new([n_expanded_channels_final, self.n_channels_out], [3, 3])
|
||||
.with_padding(PaddingConfig2d::Explicit(1, 1))
|
||||
.init();
|
||||
|
||||
Encoder {
|
||||
conv_in,
|
||||
mid,
|
||||
blocks,
|
||||
norm_out,
|
||||
silu,
|
||||
conv_out,
|
||||
conv_in,
|
||||
mid,
|
||||
blocks,
|
||||
norm_out,
|
||||
silu,
|
||||
conv_out,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Encoder<B: Backend> {
|
||||
conv_in: Conv2d<B>,
|
||||
mid: Mid<B>,
|
||||
blocks: Vec<EncoderBlock<B>>,
|
||||
norm_out: GroupNorm<B>,
|
||||
silu: SILU,
|
||||
conv_out: Conv2d<B>,
|
||||
conv_in: Conv2d<B>,
|
||||
mid: Mid<B>,
|
||||
blocks: Vec<EncoderBlock<B>>,
|
||||
norm_out: GroupNorm<B>,
|
||||
silu: SILU,
|
||||
conv_out: Conv2d<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Encoder<B> {
|
||||
@@ -126,55 +138,66 @@ impl<B: Backend> Encoder<B> {
|
||||
}
|
||||
|
||||
let x = self.mid.forward(x);
|
||||
self.conv_out.forward( self.silu.forward( self.norm_out.forward(x) ) )
|
||||
self.conv_out
|
||||
.forward(self.silu.forward(self.norm_out.forward(x)))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
#[derive(Config)]
|
||||
pub struct DecoderConfig {
|
||||
channels: Vec<(usize, usize)>,
|
||||
n_group: usize,
|
||||
channels: Vec<(usize, usize)>,
|
||||
n_group: usize,
|
||||
}
|
||||
|
||||
impl DecoderConfig {
|
||||
fn init<B: Backend>(&self) -> Decoder<B> {
|
||||
let n_expanded_channels = self.channels.first().map(|f| f.0).expect("Channels must not be empty.");
|
||||
let n_expanded_channels = self
|
||||
.channels
|
||||
.first()
|
||||
.map(|f| f.0)
|
||||
.expect("Channels must not be empty.");
|
||||
let n_condensed_channels = self.channels.last().unwrap().1;
|
||||
|
||||
let conv_in = Conv2dConfig::new([4, n_expanded_channels], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init();
|
||||
let conv_in = Conv2dConfig::new([4, n_expanded_channels], [3, 3])
|
||||
.with_padding(PaddingConfig2d::Explicit(1, 1))
|
||||
.init();
|
||||
let mid = MidConfig::new(n_expanded_channels).init();
|
||||
|
||||
let blocks = self.channels.iter().enumerate().map(|(i, &(n_channel_in, n_channel_out))| {
|
||||
let upsample = i != self.channels.len() - 1;
|
||||
DecoderBlockConfig::new(n_channel_in, n_channel_out, upsample).init()
|
||||
}).collect();
|
||||
let blocks = self
|
||||
.channels
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, &(n_channel_in, n_channel_out))| {
|
||||
let upsample = i != self.channels.len() - 1;
|
||||
DecoderBlockConfig::new(n_channel_in, n_channel_out, upsample).init()
|
||||
})
|
||||
.collect();
|
||||
|
||||
let norm_out = GroupNormConfig::new(self.n_group, n_condensed_channels).init();
|
||||
let silu = SILU::new();
|
||||
let conv_out = Conv2dConfig::new([n_condensed_channels, 3], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init();
|
||||
let conv_out = Conv2dConfig::new([n_condensed_channels, 3], [3, 3])
|
||||
.with_padding(PaddingConfig2d::Explicit(1, 1))
|
||||
.init();
|
||||
|
||||
Decoder {
|
||||
conv_in,
|
||||
mid,
|
||||
blocks,
|
||||
norm_out,
|
||||
silu,
|
||||
conv_out,
|
||||
conv_in,
|
||||
mid,
|
||||
blocks,
|
||||
norm_out,
|
||||
silu,
|
||||
conv_out,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Decoder<B: Backend> {
|
||||
conv_in: Conv2d<B>,
|
||||
mid: Mid<B>,
|
||||
blocks: Vec<DecoderBlock<B>>,
|
||||
norm_out: GroupNorm<B>,
|
||||
silu: SILU,
|
||||
conv_out: Conv2d<B>,
|
||||
conv_in: Conv2d<B>,
|
||||
mid: Mid<B>,
|
||||
blocks: Vec<DecoderBlock<B>>,
|
||||
norm_out: GroupNorm<B>,
|
||||
silu: SILU,
|
||||
conv_out: Conv2d<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Decoder<B> {
|
||||
@@ -187,15 +210,16 @@ impl<B: Backend> Decoder<B> {
|
||||
x = block.forward(x);
|
||||
}
|
||||
|
||||
self.conv_out.forward( self.silu.forward( self.norm_out.forward(x) ) )
|
||||
self.conv_out
|
||||
.forward(self.silu.forward(self.norm_out.forward(x)))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Config)]
|
||||
pub struct EncoderBlockConfig {
|
||||
n_channels_in: usize,
|
||||
n_channels_out: usize,
|
||||
downsample: bool,
|
||||
n_channels_in: usize,
|
||||
n_channels_out: usize,
|
||||
downsample: bool,
|
||||
}
|
||||
|
||||
impl EncoderBlockConfig {
|
||||
@@ -204,24 +228,28 @@ impl EncoderBlockConfig {
|
||||
let res2 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init();
|
||||
let downsampler = if self.downsample {
|
||||
let padding = Padding::new(0, 1, 0, 1);
|
||||
Some( PaddedConv2dConfig::new([self.n_channels_out, self.n_channels_out], 3, padding).with_stride(2).init() )
|
||||
Some(
|
||||
PaddedConv2dConfig::new([self.n_channels_out, self.n_channels_out], 3, padding)
|
||||
.with_stride(2)
|
||||
.init(),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
EncoderBlock {
|
||||
res1,
|
||||
res2,
|
||||
downsampler,
|
||||
res1,
|
||||
res2,
|
||||
downsampler,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct EncoderBlock<B: Backend> {
|
||||
res1: ResnetBlock<B>,
|
||||
res2: ResnetBlock<B>,
|
||||
downsampler: Option<PaddedConv2d<B>>,
|
||||
res1: ResnetBlock<B>,
|
||||
res2: ResnetBlock<B>,
|
||||
downsampler: Option<PaddedConv2d<B>>,
|
||||
}
|
||||
|
||||
impl<B: Backend> EncoderBlock<B> {
|
||||
@@ -238,9 +266,9 @@ impl<B: Backend> EncoderBlock<B> {
|
||||
|
||||
#[derive(Config)]
|
||||
pub struct DecoderBlockConfig {
|
||||
n_channels_in: usize,
|
||||
n_channels_out: usize,
|
||||
upsample: bool,
|
||||
n_channels_in: usize,
|
||||
n_channels_out: usize,
|
||||
upsample: bool,
|
||||
}
|
||||
|
||||
impl DecoderBlockConfig {
|
||||
@@ -249,26 +277,30 @@ impl DecoderBlockConfig {
|
||||
let res2 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init();
|
||||
let res3 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init();
|
||||
let upsampler = if self.upsample {
|
||||
Some( Conv2dConfig::new([self.n_channels_out, self.n_channels_out], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init() )
|
||||
Some(
|
||||
Conv2dConfig::new([self.n_channels_out, self.n_channels_out], [3, 3])
|
||||
.with_padding(PaddingConfig2d::Explicit(1, 1))
|
||||
.init(),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
DecoderBlock {
|
||||
res1,
|
||||
res2,
|
||||
res3,
|
||||
upsampler,
|
||||
res1,
|
||||
res2,
|
||||
res3,
|
||||
upsampler,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct DecoderBlock<B: Backend> {
|
||||
res1: ResnetBlock<B>,
|
||||
res2: ResnetBlock<B>,
|
||||
res3: ResnetBlock<B>,
|
||||
upsampler: Option<Conv2d<B>>,
|
||||
res1: ResnetBlock<B>,
|
||||
res2: ResnetBlock<B>,
|
||||
res3: ResnetBlock<B>,
|
||||
upsampler: Option<Conv2d<B>>,
|
||||
}
|
||||
|
||||
impl<B: Backend> DecoderBlock<B> {
|
||||
@@ -280,10 +312,10 @@ impl<B: Backend> DecoderBlock<B> {
|
||||
if let Some(d) = self.upsampler.as_ref() {
|
||||
let [n_batch, n_channel, height, width] = x.dims();
|
||||
let x = x
|
||||
.reshape([n_batch, n_channel, height, 1, width, 1])
|
||||
.repeat(3, 2)
|
||||
.repeat(5, 2)
|
||||
.reshape([n_batch, n_channel, 2 * height, 2 * width]);
|
||||
.reshape([n_batch, n_channel, height, 1, width, 1])
|
||||
.repeat(3, 2)
|
||||
.repeat(5, 2)
|
||||
.reshape([n_batch, n_channel, 2 * height, 2 * width]);
|
||||
d.forward(x)
|
||||
} else {
|
||||
x
|
||||
@@ -291,14 +323,13 @@ impl<B: Backend> DecoderBlock<B> {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[derive(Config)]
|
||||
pub struct PaddedConv2dConfig {
|
||||
channels: [usize; 2],
|
||||
kernel_size: usize,
|
||||
channels: [usize; 2],
|
||||
kernel_size: usize,
|
||||
#[config(default = 1)]
|
||||
stride: usize,
|
||||
padding: Padding,
|
||||
stride: usize,
|
||||
padding: Padding,
|
||||
}
|
||||
|
||||
impl PaddedConv2dConfig {
|
||||
@@ -328,57 +359,68 @@ impl PaddedConv2dConfig {
|
||||
let padding = self.padding;
|
||||
|
||||
PaddedConv2d {
|
||||
conv,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
padding_actual,
|
||||
conv,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
padding_actual,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn div_roundup(x: usize, y: usize) -> usize {
|
||||
(x + y - 1) / y
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct PaddedConv2d<B: Backend> {
|
||||
conv: Conv2d<B>,
|
||||
kernel_size: usize,
|
||||
stride: usize,
|
||||
padding: Padding,
|
||||
padding_actual: [usize; 2],
|
||||
conv: Conv2d<B>,
|
||||
kernel_size: usize,
|
||||
stride: usize,
|
||||
padding: Padding,
|
||||
padding_actual: [usize; 2],
|
||||
}
|
||||
|
||||
impl<B: Backend> PaddedConv2d<B> {
|
||||
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
println!("{} {} {:?} {:?}", self.kernel_size, self.stride, self.padding, self.padding_actual);
|
||||
println!(
|
||||
"{} {} {:?} {:?}",
|
||||
self.kernel_size, self.stride, self.padding, self.padding_actual
|
||||
);
|
||||
let [n_batch, n_channel, height, width] = x.dims();
|
||||
|
||||
let desired_height = (self.padding.pad_top + self.padding.pad_bottom + height - self.kernel_size) / self.stride + 1;
|
||||
let desired_width = (self.padding.pad_left + self.padding.pad_right + width - self.kernel_size) / self.stride + 1;
|
||||
let desired_height = (self.padding.pad_top + self.padding.pad_bottom + height
|
||||
- self.kernel_size)
|
||||
/ self.stride
|
||||
+ 1;
|
||||
let desired_width = (self.padding.pad_left + self.padding.pad_right + width
|
||||
- self.kernel_size)
|
||||
/ self.stride
|
||||
+ 1;
|
||||
|
||||
let skip_vert = (self.padding_actual[0] - self.padding.pad_top) / self.stride;
|
||||
let skip_hor = (self.padding_actual[1] - self.padding.pad_left) / self.stride;
|
||||
|
||||
self.conv
|
||||
.forward(x)
|
||||
.slice([
|
||||
0..n_batch,
|
||||
0..n_channel,
|
||||
skip_vert..(skip_vert + desired_height),
|
||||
skip_hor..(skip_hor + desired_width)
|
||||
])
|
||||
self.conv.forward(x).slice([
|
||||
0..n_batch,
|
||||
0..n_channel,
|
||||
skip_vert..(skip_vert + desired_height),
|
||||
skip_hor..(skip_hor + desired_width),
|
||||
])
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Config, Module, Copy, Debug)]
|
||||
pub struct Padding {
|
||||
pad_left: usize,
|
||||
pad_right: usize,
|
||||
pad_top: usize,
|
||||
pad_left: usize,
|
||||
pad_right: usize,
|
||||
pad_top: usize,
|
||||
pad_bottom: usize,
|
||||
}
|
||||
|
||||
#[derive(Config)]
|
||||
pub struct MidConfig {
|
||||
n_channel: usize,
|
||||
n_channel: usize,
|
||||
}
|
||||
|
||||
impl MidConfig {
|
||||
@@ -388,18 +430,18 @@ impl MidConfig {
|
||||
let block_2 = ResnetBlockConfig::new(self.n_channel, self.n_channel).init();
|
||||
|
||||
Mid {
|
||||
block_1,
|
||||
attn,
|
||||
block_2,
|
||||
block_1,
|
||||
attn,
|
||||
block_2,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Mid<B: Backend> {
|
||||
block_1: ResnetBlock<B>,
|
||||
attn: ConvSelfAttentionBlock<B>,
|
||||
block_2: ResnetBlock<B>,
|
||||
block_1: ResnetBlock<B>,
|
||||
attn: ConvSelfAttentionBlock<B>,
|
||||
block_2: ResnetBlock<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Mid<B> {
|
||||
@@ -411,21 +453,24 @@ impl<B: Backend> Mid<B> {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[derive(Config)]
|
||||
pub struct ResnetBlockConfig {
|
||||
in_channels: usize,
|
||||
out_channels: usize,
|
||||
in_channels: usize,
|
||||
out_channels: usize,
|
||||
}
|
||||
|
||||
impl ResnetBlockConfig {
|
||||
fn init<B: Backend>(&self) -> ResnetBlock<B> {
|
||||
let norm1 = GroupNormConfig::new(32, self.in_channels).init();
|
||||
let conv1 = Conv2dConfig::new([self.in_channels, self.out_channels], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init();
|
||||
let conv1 = Conv2dConfig::new([self.in_channels, self.out_channels], [3, 3])
|
||||
.with_padding(PaddingConfig2d::Explicit(1, 1))
|
||||
.init();
|
||||
let norm2 = GroupNormConfig::new(32, self.out_channels).init();
|
||||
let conv2 = Conv2dConfig::new([self.out_channels, self.out_channels], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init();
|
||||
let conv2 = Conv2dConfig::new([self.out_channels, self.out_channels], [3, 3])
|
||||
.with_padding(PaddingConfig2d::Explicit(1, 1))
|
||||
.init();
|
||||
let nin_shortcut = if self.in_channels != self.out_channels {
|
||||
Some( Conv2dConfig::new([self.in_channels, self.out_channels], [1, 1]).init() )
|
||||
Some(Conv2dConfig::new([self.in_channels, self.out_channels], [1, 1]).init())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
@@ -434,34 +479,37 @@ impl ResnetBlockConfig {
|
||||
let silu2 = SILU::new();
|
||||
|
||||
ResnetBlock {
|
||||
norm1,
|
||||
silu1,
|
||||
conv1,
|
||||
norm2,
|
||||
silu2,
|
||||
conv2,
|
||||
nin_shortcut,
|
||||
norm1,
|
||||
silu1,
|
||||
conv1,
|
||||
norm2,
|
||||
silu2,
|
||||
conv2,
|
||||
nin_shortcut,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct ResnetBlock<B: Backend> {
|
||||
norm1: GroupNorm<B>,
|
||||
silu1: SILU,
|
||||
conv1: Conv2d<B>,
|
||||
norm2: GroupNorm<B>,
|
||||
silu2: SILU,
|
||||
conv2: Conv2d<B>,
|
||||
nin_shortcut: Option<Conv2d<B>>,
|
||||
norm1: GroupNorm<B>,
|
||||
silu1: SILU,
|
||||
conv1: Conv2d<B>,
|
||||
norm2: GroupNorm<B>,
|
||||
silu2: SILU,
|
||||
conv2: Conv2d<B>,
|
||||
nin_shortcut: Option<Conv2d<B>>,
|
||||
}
|
||||
|
||||
impl<B: Backend> ResnetBlock<B> {
|
||||
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let h = self.conv1.forward( self.silu1.forward(self.norm1.forward(x.clone())) );
|
||||
let h = self.conv2.forward( self.silu2.forward(self.norm2.forward(h)) );
|
||||
let h = self
|
||||
.conv1
|
||||
.forward(self.silu1.forward(self.norm1.forward(x.clone())));
|
||||
let h = self
|
||||
.conv2
|
||||
.forward(self.silu2.forward(self.norm2.forward(h)));
|
||||
|
||||
|
||||
if let Some(ns) = self.nin_shortcut.as_ref() {
|
||||
ns.forward(x) + h
|
||||
} else {
|
||||
@@ -472,7 +520,7 @@ impl<B: Backend> ResnetBlock<B> {
|
||||
|
||||
#[derive(Config)]
|
||||
pub struct ConvSelfAttentionBlockConfig {
|
||||
n_channel: usize,
|
||||
n_channel: usize,
|
||||
}
|
||||
|
||||
impl ConvSelfAttentionBlockConfig {
|
||||
@@ -484,22 +532,22 @@ impl ConvSelfAttentionBlockConfig {
|
||||
let proj_out = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init();
|
||||
|
||||
ConvSelfAttentionBlock {
|
||||
norm,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
proj_out,
|
||||
norm,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
proj_out,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct ConvSelfAttentionBlock<B: Backend> {
|
||||
norm: GroupNorm<B>,
|
||||
q: Conv2d<B>,
|
||||
k: Conv2d<B>,
|
||||
v: Conv2d<B>,
|
||||
proj_out: Conv2d<B>,
|
||||
norm: GroupNorm<B>,
|
||||
q: Conv2d<B>,
|
||||
k: Conv2d<B>,
|
||||
v: Conv2d<B>,
|
||||
proj_out: Conv2d<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> ConvSelfAttentionBlock<B> {
|
||||
@@ -508,9 +556,21 @@ impl<B: Backend> ConvSelfAttentionBlock<B> {
|
||||
|
||||
let h = self.norm.forward(x.clone());
|
||||
|
||||
let q = self.q.forward(h.clone()).reshape([n_batch, n_channel, height * width]).swap_dims(1, 2);
|
||||
let k = self.k.forward(h.clone()).reshape([n_batch, n_channel, height * width]).swap_dims(1, 2);
|
||||
let v = self.v.forward(h).reshape([n_batch, n_channel, height * width]).swap_dims(1, 2);
|
||||
let q = self
|
||||
.q
|
||||
.forward(h.clone())
|
||||
.reshape([n_batch, n_channel, height * width])
|
||||
.swap_dims(1, 2);
|
||||
let k = self
|
||||
.k
|
||||
.forward(h.clone())
|
||||
.reshape([n_batch, n_channel, height * width])
|
||||
.swap_dims(1, 2);
|
||||
let v = self
|
||||
.v
|
||||
.forward(h)
|
||||
.reshape([n_batch, n_channel, height * width])
|
||||
.swap_dims(1, 2);
|
||||
|
||||
let wv = qkv_attention(q, k, v, None, 1)
|
||||
.swap_dims(1, 2)
|
||||
|
||||
Reference in New Issue
Block a user