Update to burn v0.14.0 and switch to .mpk model file

This commit is contained in:
Hermes
2024-10-05 14:19:49 -04:00
parent 9e4d7bd310
commit 893fb0950d
19 changed files with 366 additions and 311 deletions

View File

@@ -1,13 +1,16 @@
use burn::tensor::{activation::softmax, Tensor};
use burn::prelude::Backend;
/*pub type FloatTensor<B, const D: usize> = <B as burn::tensor::backend::Backend>::TensorPrimitive<D>;
pub trait Backend: burn::tensor::backend::Backend {
fn qkv_attention(
q: Self::TensorPrimitive<3>,
k: Self::TensorPrimitive<3>,
v: Self::TensorPrimitive<3>,
mask: Option<Self::TensorPrimitive<2>>,
q: FloatTensor<Self, 3>,
k: FloatTensor<Self, 3>,
v: FloatTensor<Self, 3>,
mask: Option<FloatTensor<Self, 2>>,
n_head: usize,
) -> Self::TensorPrimitive<3> {
) -> FloatTensor<Self, 3> {
qkv_attention(
Tensor::<Self, 3>::from_primitive(q),
Tensor::from_primitive(k),
@@ -18,24 +21,23 @@ pub trait Backend: burn::tensor::backend::Backend {
.into_primitive()
}
fn attn_decoder_mask(seq_length: usize, device: &Self::Device) -> Self::TensorPrimitive<2> {
fn attn_decoder_mask(seq_length: usize, device: &Self::Device) -> FloatTensor<Self, 2> {
attn_decoder_mask::<Self>(seq_length, device).into_primitive()
}
}
use burn::tensor::ops::TensorOps;
use burn::tensor::Float;
use burn_tch::{self, TchElement, TchTensor};
use tch;
impl<E: TchElement> Backend for burn_tch::TchBackend<E> {
impl<E: TchElement> Backend for burn_tch::LibTorch<E> {
fn qkv_attention(
q: Self::TensorPrimitive<3>,
k: Self::TensorPrimitive<3>,
v: Self::TensorPrimitive<3>,
mask: Option<Self::TensorPrimitive<2>>,
q: FloatTensor<Self, 3>,
k: FloatTensor<Self, 3>,
v: FloatTensor<Self, 3>,
mask: Option<FloatTensor<Self, 2>>,
n_head: usize,
) -> Self::TensorPrimitive<3> {
) -> FloatTensor<Self, 2> {
let q = Tensor::from_primitive(q);
let k = Tensor::from_primitive(k);
let v = Tensor::from_primitive(v);
@@ -56,7 +58,7 @@ impl<E: TchElement> Backend for burn_tch::TchBackend<E> {
// for some reason torch crashes when mask is None
let mask = mask.unwrap_or_else(|| {
Tensor::<Self, 2, Float>::zeros_device([q_ctx, k_ctx], &Self::device(&v))
Tensor::<Self, 2, Float>::zeros([q_ctx, k_ctx], &Self::device(&v))
.into_primitive()
});
@@ -68,6 +70,7 @@ impl<E: TchElement> Backend for burn_tch::TchBackend<E> {
Some(mask.tensor),
0.0,
false,
None,
),
))
.swap_dims(1, 2)
@@ -78,11 +81,11 @@ impl<E: TchElement> Backend for burn_tch::TchBackend<E> {
use burn_autodiff;
impl<B: Backend> Backend for burn_autodiff::ADBackendDecorator<B> {}
impl<B: Backend> Backend for burn_autodiff::Autodiff<B> {}*/
use std::f32::NEG_INFINITY;
fn qkv_attention<B: Backend>(
pub fn qkv_attention<B: Backend>(
q: Tensor<B, 3>,
k: Tensor<B, 3>,
v: Tensor<B, 3>,
@@ -124,13 +127,13 @@ fn qkv_attention<B: Backend>(
return o;
}
fn attn_decoder_mask<B: Backend>(seq_length: usize, device: &B::Device) -> Tensor<B, 2> {
let mut mask = Tensor::<B, 2>::zeros([seq_length, seq_length]);
pub fn attn_decoder_mask<B: Backend>(seq_length: usize, device: &B::Device) -> Tensor<B, 2> {
let mut mask = Tensor::<B, 2>::zeros([seq_length, seq_length], device);
for i in 0..(seq_length - 1) {
let values = Tensor::<B, 2>::zeros([1, seq_length - (i + 1)]).add_scalar(NEG_INFINITY);
let values = Tensor::<B, 2>::zeros([1, seq_length - (i + 1)], device).add_scalar(NEG_INFINITY);
mask = mask.slice_assign([i..i + 1, i + 1..seq_length], values);
}
return mask.to_device(device);
return mask;
}

View File

@@ -11,9 +11,9 @@ use burn::{
tensor::{backend::Backend, Tensor},
};
use burn_ndarray::{NdArrayBackend, NdArrayDevice};
use burn_ndarray::{NdArray, NdArrayDevice};
use burn::record::{self, BinFileRecorder, FullPrecisionSettings, Recorder};
use burn::record::{self, NamedMpkFileRecorder, FullPrecisionSettings, Recorder};
fn convert_dump_to_model<B: Backend>(
dump_path: &str,
@@ -33,11 +33,11 @@ fn save_model_file<B: Backend>(
model: StableDiffusion<B>,
name: &str,
) -> Result<(), record::RecorderError> {
BinFileRecorder::<FullPrecisionSettings>::new().record(model.into_record(), name.into())
NamedMpkFileRecorder::<FullPrecisionSettings>::new().record(model.into_record(), name.into())
}
fn main() {
type Backend = NdArrayBackend<f32>;
type Backend = NdArray<f32>;
let device = NdArrayDevice::Cpu;
let args: Vec<String> = env::args().collect();

View File

@@ -14,7 +14,7 @@ cfg_if::cfg_if! {
if #[cfg(feature = "wgpu-backend")] {
use burn_wgpu::{WgpuBackend, WgpuDevice, AutoGraphicsApi};
} else {
use burn_tch::{TchBackend, TchDevice};
use burn_tch::{LibTorch, LibTorchDevice};
}
}
@@ -22,30 +22,21 @@ use std::env;
use std::io;
use std::process;
use burn::record::{self, BinFileRecorder, FullPrecisionSettings, Recorder};
use burn::record::{self, NamedMpkFileRecorder, FullPrecisionSettings, Recorder};
fn load_stable_diffusion_model_file<B: Backend>(
filename: &str,
device: &B::Device,
) -> Result<StableDiffusion<B>, record::RecorderError> {
BinFileRecorder::<FullPrecisionSettings>::new()
.load(filename.into())
.map(|record| StableDiffusionConfig::new().init().load_record(record))
NamedMpkFileRecorder::<FullPrecisionSettings>::new()
.load(filename.into(), device)
.map(|record| StableDiffusionConfig::new().init(device).load_record(record))
}
fn main() {
cfg_if::cfg_if! {
if #[cfg(feature = "wgpu-backend")] {
type Backend = WgpuBackend<AutoGraphicsApi, f32, i32>;
let device = WgpuDevice::BestAvailable;
} else {
type Backend = TchBackend<f32>;
let device = TchDevice::Cuda(0);
}
}
let args: Vec<String> = std::env::args().collect();
if args.len() != 7 {
eprintln!("Usage: {} <model_type(burn or dump)> <model_name> <unconditional_guidance_scale> <n_diffusion_steps> <prompt> <output_image_name>", args[0]);
if args.len() != 7 && args.len() != 8 {
eprintln!("Usage: {} <model_type(burn or dump)> <model_name> <unconditional_guidance_scale> <n_diffusion_steps> <prompt> <output_image_name> [device(cuda, mps, cpu)]", args[0]);
process::exit(1);
}
@@ -62,11 +53,40 @@ fn main() {
let prompt = &args[5];
let output_image_name = &args[6];
// Optional device parameter
let device_arg = if args.len() == 8 { Some(&args[7]) } else { None };
cfg_if::cfg_if! {
if #[cfg(feature = "wgpu-backend")] {
type Backend = WgpuBackend<AutoGraphicsApi, f32, i32>;
let device = WgpuDevice::BestAvailable;
} else {
type Backend = LibTorch<f32>;
let device = if let Some(dev_str) = device_arg {
match dev_str.to_lowercase().as_str() {
"cpu" => LibTorchDevice::Cpu,
"mps" => LibTorchDevice::Mps,
s if s.starts_with("cuda") => {
let idx = s[4..].parse().unwrap_or(0);
LibTorchDevice::Cuda(idx)
}
_ => {
eprintln!("Unknown device: {}", dev_str);
process::exit(1);
}
}
} else {
LibTorchDevice::Cuda(0)
};
}
}
println!("Loading tokenizer...");
let tokenizer = SimpleTokenizer::new().unwrap();
println!("Loading model...");
let sd: StableDiffusion<Backend> = if model_type == "burn" {
load_stable_diffusion_model_file(model_name).unwrap_or_else(|err| {
load_stable_diffusion_model_file(model_name, &device).unwrap_or_else(|err| {
eprintln!("Error loading model: {}", err);
process::exit(1);
})
@@ -77,8 +97,6 @@ fn main() {
})
};
let sd = sd.to_device(&device);
let unconditional_context = sd.unconditional_context(&tokenizer);
let context = sd.context(&tokenizer, prompt).unsqueeze::<3>(); //.repeat(0, 2); // generate 2 samples

View File

@@ -45,12 +45,12 @@ pub fn qkv_attention<B: Backend>(
}
pub fn attn_decoder_mask<B: Backend>(seq_length: usize, device: &B::Device) -> Tensor<B, 2> {
let mut mask = Tensor::<B, 2>::zeros([seq_length, seq_length]);
let mut mask = Tensor::<B, 2>::zeros([seq_length, seq_length], device);
for i in 0..(seq_length - 1) {
let values = Tensor::<B, 2>::zeros([1, seq_length - (i + 1)]).add_scalar(NEG_INFINITY);
let values = Tensor::<B, 2>::zeros([1, seq_length - (i + 1)], device).add_scalar(NEG_INFINITY);
mask = mask.slice_assign([i..i + 1, i + 1..seq_length], values);
}
return mask.to_device(device);
return mask;
}

View File

@@ -71,7 +71,7 @@ fn load_padded_conv2d<B: Backend>(
path: &str,
device: &B::Device,
) -> Result<PaddedConv2d<B>, Box<dyn Error>> {
let conv = load_conv2d(&format!("{}/{}", path, "conv"), device)?;
let mut conv = load_conv2d(&format!("{}/{}", path, "conv"), device)?;
let channels = load_tensor::<B, 1>("channels", path, device)?;
let channels = tensor_to_array_2(channels);
@@ -81,18 +81,21 @@ fn load_padded_conv2d<B: Backend>(
let padding = load_tensor::<B, 1>("padding", path, device)?;
let padding: [usize; 4] = tensor_to_array(padding);
let padding = Padding::new(padding[0], padding[1], padding[2], padding[3]);
let padding = PaddingCfg::new(padding[0], padding[1], padding[2], padding[3]);
let mut record = conv.into_record();
//let mut record = conv.into_record();
let mut padded_conv: PaddedConv2d<B> = PaddedConv2dConfig::new(channels, kernel_size, padding)
.with_stride(stride)
.init();
.init(device);
let padding_actual =
PaddingConfig2d::Explicit(padded_conv.padding_actual[0], padded_conv.padding_actual[1]);
record.padding = <PaddingConfig2d as Module<B>>::into_record(padding_actual);
padded_conv.conv = padded_conv.conv.load_record(record);
conv.padding = burn::module::Ignored(padding_actual);
padded_conv.conv = conv;
//record.padding = <PaddingConfig2d as Module<B>>::into_record(padding_actual);
//padded_conv.conv = padded_conv.conv.load_record(record);
Ok(padded_conv)
}

View File

@@ -18,7 +18,8 @@ use burn::{
use super::groupnorm::*;
use super::silu::*;
use crate::backend::Backend as MyBackend;
//use crate::backend::Backend as MyBackend;
use crate::backend::{qkv_attention, attn_decoder_mask};
use std::iter;
@@ -26,13 +27,13 @@ use std::iter;
pub struct AutoencoderConfig {}
impl AutoencoderConfig {
pub fn init<B: Backend>(&self) -> Autoencoder<B> {
pub fn init<B: Backend>(&self, device: &B::Device) -> Autoencoder<B> {
let encoder =
EncoderConfig::new(vec![(128, 128), (128, 256), (256, 512), (512, 512)], 32, 8).init();
EncoderConfig::new(vec![(128, 128), (128, 256), (256, 512), (512, 512)], 32, 8).init(device);
let decoder =
DecoderConfig::new(vec![(512, 512), (512, 512), (512, 256), (256, 128)], 32).init();
let quant_conv = Conv2dConfig::new([8, 8], [1, 1]).init();
let post_quant_conv = Conv2dConfig::new([4, 4], [1, 1]).init();
DecoderConfig::new(vec![(512, 512), (512, 512), (512, 256), (256, 128)], 32).init(device);
let quant_conv = Conv2dConfig::new([8, 8], [1, 1]).init(device);
let post_quant_conv = Conv2dConfig::new([4, 4], [1, 1]).init(device);
Autoencoder {
encoder,
@@ -51,7 +52,7 @@ pub struct Autoencoder<B: Backend> {
post_quant_conv: Conv2d<B>,
}
impl<B: MyBackend> Autoencoder<B> {
impl<B: Backend> Autoencoder<B> {
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
self.decode_latent(self.encode_image(x))
}
@@ -78,7 +79,7 @@ pub struct EncoderConfig {
}
impl EncoderConfig {
fn init<B: Backend>(&self) -> Encoder<B> {
fn init<B: Backend>(&self, device: &B::Device) -> Encoder<B> {
let n_expanded_channels_initial = self
.channels
.first()
@@ -88,7 +89,7 @@ impl EncoderConfig {
let conv_in = Conv2dConfig::new([3, n_expanded_channels_initial], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init();
.init(device);
let blocks = self
.channels
@@ -96,16 +97,16 @@ impl EncoderConfig {
.enumerate()
.map(|(i, &(n_channel_in, n_channel_out))| {
let downsample = i != self.channels.len() - 1;
EncoderBlockConfig::new(n_channel_in, n_channel_out, downsample).init()
EncoderBlockConfig::new(n_channel_in, n_channel_out, downsample).init(device)
})
.collect();
let mid = MidConfig::new(n_expanded_channels_final).init();
let norm_out = GroupNormConfig::new(self.n_group, n_expanded_channels_final).init();
let mid = MidConfig::new(n_expanded_channels_final).init(device);
let norm_out = GroupNormConfig::new(self.n_group, n_expanded_channels_final).init(device);
let silu = SILU::new();
let conv_out = Conv2dConfig::new([n_expanded_channels_final, self.n_channels_out], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init();
.init(device);
Encoder {
conv_in,
@@ -128,7 +129,7 @@ pub struct Encoder<B: Backend> {
conv_out: Conv2d<B>,
}
impl<B: MyBackend> Encoder<B> {
impl<B: Backend> Encoder<B> {
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
let x = self.conv_in.forward(x);
@@ -150,7 +151,7 @@ pub struct DecoderConfig {
}
impl DecoderConfig {
fn init<B: Backend>(&self) -> Decoder<B> {
fn init<B: Backend>(&self, device: &B::Device) -> Decoder<B> {
let n_expanded_channels = self
.channels
.first()
@@ -160,8 +161,8 @@ impl DecoderConfig {
let conv_in = Conv2dConfig::new([4, n_expanded_channels], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init();
let mid = MidConfig::new(n_expanded_channels).init();
.init(device);
let mid = MidConfig::new(n_expanded_channels).init(device);
let blocks = self
.channels
@@ -169,15 +170,15 @@ impl DecoderConfig {
.enumerate()
.map(|(i, &(n_channel_in, n_channel_out))| {
let upsample = i != self.channels.len() - 1;
DecoderBlockConfig::new(n_channel_in, n_channel_out, upsample).init()
DecoderBlockConfig::new(n_channel_in, n_channel_out, upsample).init(device)
})
.collect();
let norm_out = GroupNormConfig::new(self.n_group, n_condensed_channels).init();
let norm_out = GroupNormConfig::new(self.n_group, n_condensed_channels).init(device);
let silu = SILU::new();
let conv_out = Conv2dConfig::new([n_condensed_channels, 3], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init();
.init(device);
Decoder {
conv_in,
@@ -200,7 +201,7 @@ pub struct Decoder<B: Backend> {
conv_out: Conv2d<B>,
}
impl<B: MyBackend> Decoder<B> {
impl<B: Backend> Decoder<B> {
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
let x = self.conv_in.forward(x);
let x = self.mid.forward(x);
@@ -223,15 +224,15 @@ pub struct EncoderBlockConfig {
}
impl EncoderBlockConfig {
fn init<B: Backend>(&self) -> EncoderBlock<B> {
let res1 = ResnetBlockConfig::new(self.n_channels_in, self.n_channels_out).init();
let res2 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init();
fn init<B: Backend>(&self, device: &B::Device) -> EncoderBlock<B> {
let res1 = ResnetBlockConfig::new(self.n_channels_in, self.n_channels_out).init(device);
let res2 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init(device);
let downsampler = if self.downsample {
let padding = Padding::new(0, 1, 0, 1);
let padding = PaddingCfg::new(0, 1, 0, 1);
Some(
PaddedConv2dConfig::new([self.n_channels_out, self.n_channels_out], 3, padding)
.with_stride(2)
.init(),
.init(device),
)
} else {
None
@@ -272,15 +273,15 @@ pub struct DecoderBlockConfig {
}
impl DecoderBlockConfig {
fn init<B: Backend>(&self) -> DecoderBlock<B> {
let res1 = ResnetBlockConfig::new(self.n_channels_in, self.n_channels_out).init();
let res2 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init();
let res3 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init();
fn init<B: Backend>(&self, device: &B::Device) -> DecoderBlock<B> {
let res1 = ResnetBlockConfig::new(self.n_channels_in, self.n_channels_out).init(device);
let res2 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init(device);
let res3 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init(device);
let upsampler = if self.upsample {
Some(
Conv2dConfig::new([self.n_channels_out, self.n_channels_out], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init(),
.init(device),
)
} else {
None
@@ -313,8 +314,7 @@ impl<B: Backend> DecoderBlock<B> {
let [n_batch, n_channel, height, width] = x.dims();
let x = x
.reshape([n_batch, n_channel, height, 1, width, 1])
.repeat(3, 2)
.repeat(5, 2)
.repeat(&[1, 1, 1, 2, 1, 2])
.reshape([n_batch, n_channel, 2 * height, 2 * width]);
d.forward(x)
} else {
@@ -329,11 +329,11 @@ pub struct PaddedConv2dConfig {
kernel_size: usize,
#[config(default = 1)]
stride: usize,
padding: Padding,
padding: PaddingCfg,
}
impl PaddedConv2dConfig {
fn init<B: Backend>(&self) -> PaddedConv2d<B> {
fn init<B: Backend>(&self, device: &B::Device) -> PaddedConv2d<B> {
let calc_padding = |p_left, p_right| {
let n = if p_left >= p_right {
0
@@ -351,12 +351,17 @@ impl PaddedConv2dConfig {
let conv = Conv2dConfig::new(self.channels, [self.kernel_size, self.kernel_size])
.with_stride([self.stride, self.stride])
.with_padding(PaddingConfig2d::Explicit(pad_vertical, pad_horizontal))
.init();
.init(device);
let kernel_size = self.kernel_size;
let stride = self.stride;
let padding = self.padding;
let padding = Padding {
pad_left: self.padding.pad_left,
pad_right: self.padding.pad_right,
pad_top: self.padding.pad_top,
pad_bottom: self.padding.pad_bottom,
};
PaddedConv2d {
conv,
@@ -406,7 +411,15 @@ impl<B: Backend> PaddedConv2d<B> {
}
}
#[derive(Config, Module, Copy, Debug)]
#[derive(Config, Debug)]
pub struct PaddingCfg {
pad_left: usize,
pad_right: usize,
pad_top: usize,
pad_bottom: usize,
}
#[derive(Module, Clone, Debug)]
pub struct Padding {
pad_left: usize,
pad_right: usize,
@@ -420,10 +433,10 @@ pub struct MidConfig {
}
impl MidConfig {
fn init<B: Backend>(&self) -> Mid<B> {
let block_1 = ResnetBlockConfig::new(self.n_channel, self.n_channel).init();
let attn = ConvSelfAttentionBlockConfig::new(self.n_channel).init();
let block_2 = ResnetBlockConfig::new(self.n_channel, self.n_channel).init();
fn init<B: Backend>(&self, device: &B::Device) -> Mid<B> {
let block_1 = ResnetBlockConfig::new(self.n_channel, self.n_channel).init(device);
let attn = ConvSelfAttentionBlockConfig::new(self.n_channel).init(device);
let block_2 = ResnetBlockConfig::new(self.n_channel, self.n_channel).init(device);
Mid {
block_1,
@@ -440,7 +453,7 @@ pub struct Mid<B: Backend> {
block_2: ResnetBlock<B>,
}
impl<B: MyBackend> Mid<B> {
impl<B: Backend> Mid<B> {
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
let x = self.block_1.forward(x);
let x = self.attn.forward(x);
@@ -456,17 +469,17 @@ pub struct ResnetBlockConfig {
}
impl ResnetBlockConfig {
fn init<B: Backend>(&self) -> ResnetBlock<B> {
let norm1 = GroupNormConfig::new(32, self.in_channels).init();
fn init<B: Backend>(&self, device: &B::Device) -> ResnetBlock<B> {
let norm1 = GroupNormConfig::new(32, self.in_channels).init(device);
let conv1 = Conv2dConfig::new([self.in_channels, self.out_channels], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init();
let norm2 = GroupNormConfig::new(32, self.out_channels).init();
.init(device);
let norm2 = GroupNormConfig::new(32, self.out_channels).init(device);
let conv2 = Conv2dConfig::new([self.out_channels, self.out_channels], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init();
.init(device);
let nin_shortcut = if self.in_channels != self.out_channels {
Some(Conv2dConfig::new([self.in_channels, self.out_channels], [1, 1]).init())
Some(Conv2dConfig::new([self.in_channels, self.out_channels], [1, 1]).init(device))
} else {
None
};
@@ -520,12 +533,12 @@ pub struct ConvSelfAttentionBlockConfig {
}
impl ConvSelfAttentionBlockConfig {
fn init<B: Backend>(&self) -> ConvSelfAttentionBlock<B> {
let norm = GroupNormConfig::new(32, self.n_channel).init();
let q = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init();
let k = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init();
let v = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init();
let proj_out = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init();
fn init<B: Backend>(&self, device: &B::Device) -> ConvSelfAttentionBlock<B> {
let norm = GroupNormConfig::new(32, self.n_channel).init(device);
let q = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init(device);
let k = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init(device);
let v = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init(device);
let proj_out = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init(device);
ConvSelfAttentionBlock {
norm,
@@ -546,7 +559,7 @@ pub struct ConvSelfAttentionBlock<B: Backend> {
proj_out: Conv2d<B>,
}
impl<B: MyBackend> ConvSelfAttentionBlock<B> {
impl<B: Backend> ConvSelfAttentionBlock<B> {
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
let [n_batch, n_channel, height, width] = x.dims();
@@ -568,7 +581,7 @@ impl<B: MyBackend> ConvSelfAttentionBlock<B> {
.reshape([n_batch, n_channel, height * width])
.swap_dims(1, 2);
let wv = Tensor::from_primitive(B::qkv_attention(
/*let wv = Tensor::from_primitive(B::qkv_attention(
q.into_primitive(),
k.into_primitive(),
v.into_primitive(),
@@ -576,6 +589,16 @@ impl<B: MyBackend> ConvSelfAttentionBlock<B> {
1,
))
.swap_dims(1, 2)
.reshape([n_batch, n_channel, height, width]);*/
let wv = qkv_attention(
q,
k,
v,
None,
1,
)
.swap_dims(1, 2)
.reshape([n_batch, n_channel, height, width]);
let projected = self.proj_out.forward(wv);

View File

@@ -68,7 +68,7 @@ pub fn load_residual_decoder_attention_block<B: Backend>(
pub fn load_clip<B: Backend>(path: &str, device: &B::Device) -> Result<CLIP<B>, Box<dyn Error>> {
let token_embedding = load_embedding(&format!("{}/{}", path, "token_embedding"), device)?;
let position_embedding =
load_tensor("weight", &format!("{}/position_embedding", path), device)?.into();
Param::from_tensor(load_tensor("weight", &format!("{}/position_embedding", path), device)?);
let n_layer = load_usize::<B>("n_layer", path, device)?;
let mut blocks = (0..n_layer)

View File

@@ -12,7 +12,8 @@ use burn::{
},
};
use crate::backend::Backend as MyBackend;
//use crate::backend::Backend as MyBackend;
use crate::backend::{qkv_attention, attn_decoder_mask};
#[derive(Config)]
pub struct CLIPConfig {
@@ -24,15 +25,15 @@ pub struct CLIPConfig {
}
impl CLIPConfig {
pub fn init<B: Backend>(&self) -> CLIP<B> {
let token_embedding = nn::EmbeddingConfig::new(self.n_vocab, self.n_state).init();
pub fn init<B: Backend>(&self, device: &B::Device) -> CLIP<B> {
let token_embedding = nn::EmbeddingConfig::new(self.n_vocab, self.n_state).init(device);
let position_embedding =
Tensor::random([self.n_ctx, self.n_state], Distribution::Normal(0.0, 1.0)).into();
Param::from_tensor(Tensor::random([self.n_ctx, self.n_state], Distribution::Normal(0.0, 1.0), device));
let blocks = (0..self.n_layer)
.into_iter()
.map(|_| ResidualDecoderAttentionBlockConfig::new(self.n_state, self.n_head).init())
.map(|_| ResidualDecoderAttentionBlockConfig::new(self.n_state, self.n_head).init(device))
.collect();
let layer_norm = nn::LayerNormConfig::new(self.n_state).init();
let layer_norm = nn::LayerNormConfig::new(self.n_state).init(device);
CLIP {
token_embedding,
@@ -51,11 +52,12 @@ pub struct CLIP<B: Backend> {
layer_norm: nn::LayerNorm<B>,
}
impl<B: MyBackend> CLIP<B> {
impl<B: Backend> CLIP<B> {
pub fn forward(&self, x: Tensor<B, 2, Int>) -> Tensor<B, 3> {
let [n_batch, seq_len] = x.dims();
let mask = Tensor::from_primitive(B::attn_decoder_mask(seq_len, &x.device()));
//let mask = Tensor::from_primitive(B::attn_decoder_mask(seq_len, &x.device()));
let mask = attn_decoder_mask(seq_len, &x.device());
let embedded = self.token_embedding.forward(x)
+ self
@@ -80,12 +82,12 @@ pub struct ResidualDecoderAttentionBlockConfig {
}
impl ResidualDecoderAttentionBlockConfig {
pub fn init<B: Backend>(&self) -> ResidualDecoderAttentionBlock<B> {
let attn = MultiHeadSelfAttentionConfig::new(self.n_state, self.n_head).init();
let attn_ln = nn::LayerNormConfig::new(self.n_state).init();
pub fn init<B: Backend>(&self, device: &B::Device) -> ResidualDecoderAttentionBlock<B> {
let attn = MultiHeadSelfAttentionConfig::new(self.n_state, self.n_head).init(device);
let attn_ln = nn::LayerNormConfig::new(self.n_state).init(device);
let mlp = MLPConfig::new(self.n_state, 4 * self.n_state).init();
let mlp_ln = nn::LayerNormConfig::new(self.n_state).init();
let mlp = MLPConfig::new(self.n_state, 4 * self.n_state).init(device);
let mlp_ln = nn::LayerNormConfig::new(self.n_state).init(device);
ResidualDecoderAttentionBlock {
attn,
@@ -104,7 +106,7 @@ pub struct ResidualDecoderAttentionBlock<B: Backend> {
mlp_ln: nn::LayerNorm<B>,
}
impl<B: MyBackend> ResidualDecoderAttentionBlock<B> {
impl<B: Backend> ResidualDecoderAttentionBlock<B> {
fn forward(&self, x: Tensor<B, 3>, mask: Tensor<B, 2>) -> Tensor<B, 3> {
let x = x.clone() + self.attn.forward(self.attn_ln.forward(x), Some(mask));
let x = x.clone() + self.mlp.forward(self.mlp_ln.forward(x));
@@ -119,7 +121,7 @@ pub struct MultiHeadSelfAttentionConfig {
}
impl MultiHeadSelfAttentionConfig {
fn init<B: Backend>(&self) -> MultiHeadSelfAttention<B> {
fn init<B: Backend>(&self, device: &B::Device) -> MultiHeadSelfAttention<B> {
assert!(
self.n_state % self.n_head == 0,
"State size {} must be a multiple of head size {}",
@@ -128,10 +130,10 @@ impl MultiHeadSelfAttentionConfig {
);
let n_head = self.n_head;
let query = nn::LinearConfig::new(self.n_state, self.n_state).init();
let key = nn::LinearConfig::new(self.n_state, self.n_state).init();
let value = nn::LinearConfig::new(self.n_state, self.n_state).init();
let out = nn::LinearConfig::new(self.n_state, self.n_state).init();
let query = nn::LinearConfig::new(self.n_state, self.n_state).init(device);
let key = nn::LinearConfig::new(self.n_state, self.n_state).init(device);
let value = nn::LinearConfig::new(self.n_state, self.n_state).init(device);
let out = nn::LinearConfig::new(self.n_state, self.n_state).init(device);
MultiHeadSelfAttention {
n_head,
@@ -152,19 +154,27 @@ pub struct MultiHeadSelfAttention<B: Backend> {
out: nn::Linear<B>,
}
impl<B: MyBackend> MultiHeadSelfAttention<B> {
impl<B: Backend> MultiHeadSelfAttention<B> {
pub fn forward(&self, x: Tensor<B, 3>, mask: Option<Tensor<B, 2>>) -> Tensor<B, 3> {
let q = self.query.forward(x.clone());
let k = self.key.forward(x.clone());
let v = self.value.forward(x);
let wv = Tensor::from_primitive(B::qkv_attention(
/*let wv = Tensor::from_primitive(B::qkv_attention(
q.into_primitive(),
k.into_primitive(),
v.into_primitive(),
mask.map(|m| m.into_primitive()),
self.n_head,
));
));*/
let wv = qkv_attention(
q,
k,
v,
mask,
self.n_head,
);
return self.out.forward(wv);
}
@@ -177,10 +187,10 @@ pub struct MLPConfig {
}
impl MLPConfig {
fn init<B: Backend>(&self) -> MLP<B> {
let fc1 = nn::LinearConfig::new(self.input_size, self.hidden_size).init();
fn init<B: Backend>(&self, device: &B::Device) -> MLP<B> {
let fc1 = nn::LinearConfig::new(self.input_size, self.hidden_size).init(device);
let gelu = QuickGELU::new();
let fc2 = nn::LinearConfig::new(self.hidden_size, self.input_size).init();
let fc2 = nn::LinearConfig::new(self.hidden_size, self.input_size).init(device);
MLP { fc1, gelu, fc2 }
}

View File

@@ -18,14 +18,14 @@ pub fn load_group_norm<B: Backend>(
let n_channel = load_usize::<B>("n_channel", path, device)?.into();
let eps = load_f32::<B>("eps", path, device)?.into();
let gamma = load_tensor::<B, 1>("weight", path, device)
let gamma = Param::from_tensor(load_tensor::<B, 1>("weight", path, device)
.ok()
.unwrap_or_else(|| Tensor::ones_device([n_channel], device))
.into();
let beta = load_tensor::<B, 1>("bias", path, device)
.unwrap_or_else(|| Tensor::ones([n_channel], device))
);
let beta = Param::from_tensor(load_tensor::<B, 1>("bias", path, device)
.ok()
.unwrap_or_else(|| Tensor::zeros_device([n_channel], device))
.into();
.unwrap_or_else(|| Tensor::zeros([n_channel], device))
);
Ok(GroupNorm {
n_group,

View File

@@ -15,7 +15,7 @@ pub struct GroupNormConfig {
}
impl GroupNormConfig {
pub fn init<B: Backend>(&self) -> GroupNorm<B> {
pub fn init<B: Backend>(&self, device: &B::Device) -> GroupNorm<B> {
assert!(
self.n_channel % self.n_group == 0,
"The number of channels {} must be divisible by the number of groups {}",
@@ -25,8 +25,8 @@ impl GroupNormConfig {
let n_per_group = self.n_channel / self.n_group;
let gamma = Tensor::ones([self.n_channel]).into();
let beta = Tensor::zeros([self.n_channel]).into();
let gamma = Param::from_tensor(Tensor::ones([self.n_channel], device));
let beta = Param::from_tensor(Tensor::zeros([self.n_channel], device));
let eps = self.eps;

View File

@@ -1,5 +1,7 @@
use npy::{self, NpyData};
use num_traits::cast::ToPrimitive;
use burn::tensor::cast::ToElement;
use burn::prelude::TensorData;
use std::error::Error;
use std::io::Read;
@@ -21,7 +23,8 @@ pub fn numpy_to_tensor<B: Backend, const D: usize>(
let shape: Vec<_> = v[0..D].into_iter().map(|&v| v as usize).collect();
let data: Vec<B::FloatElem> = v[D..].into_iter().map(|e| e.elem()).collect();
Tensor::from_data_device(Data::new(data, shape.into()), device)
//Tensor::from_data_device(Data::new(data, shape.into()), device)
Tensor::from_data(TensorData::new(data, shape), device)
}
pub fn load_tensor<B: Backend, const D: usize>(
@@ -48,7 +51,7 @@ pub fn load_f32<B: Backend>(
path: &str,
device: &B::Device,
) -> Result<f32, Box<dyn Error>> {
load_tensor::<B, 1>(name, path, device).map(|t| t.into_scalar().to_f32().unwrap())
load_tensor::<B, 1>(name, path, device).map(|t| t.into_scalar().to_f32())
}
pub fn load_usize<B: Backend>(
@@ -56,7 +59,7 @@ pub fn load_usize<B: Backend>(
path: &str,
device: &B::Device,
) -> Result<usize, Box<dyn Error>> {
load_tensor::<B, 1>(name, path, device).map(|t| t.into_scalar().to_usize().unwrap())
load_tensor::<B, 1>(name, path, device).map(|t| t.into_scalar().to_usize())
}
pub fn load_linear<B: Backend>(
@@ -66,13 +69,10 @@ pub fn load_linear<B: Backend>(
let weight = load_tensor::<B, 2>("weight", path, device)?;
let bias = load_tensor::<B, 1>("bias", path, device).ok();
let record = nn::LinearRecord {
weight: weight.into(),
bias: bias.map(|t| t.into()),
};
let linear: nn::Linear<B> = nn::LinearConfig::new(3, 3).init_with(record);
Ok(linear)
Ok(nn::Linear {
weight: Param::from_tensor(weight),
bias: bias.map(|t| Param::from_tensor(t)),
})
}
pub fn load_embedding<B: Backend>(
@@ -80,14 +80,10 @@ pub fn load_embedding<B: Backend>(
device: &B::Device,
) -> Result<nn::Embedding<B>, Box<dyn Error>> {
let weight = load_tensor::<B, 2>("weight", path, device)?;
let [n_vocab, n_state] = weight.dims();
let record = nn::EmbeddingRecord {
weight: weight.into(),
};
let embedding = nn::EmbeddingConfig::new(n_vocab, n_state).init_with(record);
Ok(embedding)
Ok(nn::Embedding {
weight: Param::from_tensor(weight),
})
}
pub fn load_layer_norm<B: Backend>(
@@ -100,13 +96,9 @@ pub fn load_layer_norm<B: Backend>(
let [n_state] = weight.dims();
let record = nn::LayerNormRecord {
gamma: weight.into(),
beta: bias.into(),
epsilon: <f64 as Module<B>>::into_record(eps),
};
let layer_norm: nn::LayerNorm<B> = nn::LayerNormConfig::new(n_state).init_with(record);
let mut layer_norm = nn::LayerNormConfig::new(n_state).with_epsilon(eps).init(device);
layer_norm.gamma = Param::from_tensor(weight);
layer_norm.beta = Param::from_tensor(bias);
Ok(layer_norm)
}
@@ -116,7 +108,7 @@ pub fn load_layer_norm<B: Backend>(
let eps = load_f32::<B>("eps", path, device)?.into();
let rmsnorm = RMSNorm {
weight: weight.into(),
weight: Param::from_tensor(weight),
eps: eps
};
@@ -148,40 +140,38 @@ pub fn load_conv2d<B: Backend>(
let padding = tensor_to_array_2(padding);
let padding = nn::PaddingConfig2d::Explicit(padding[0], padding[1]);
let record = conv::Conv2dRecord {
weight: weight.into(),
bias: bias.map(|t| t.into()),
stride: <[usize; 2] as Module<B>>::into_record(stride),
kernel_size: <[usize; 2] as Module<B>>::into_record(kernel_size),
dilation: <[usize; 2] as Module<B>>::into_record(dilation),
groups: <usize as Module<B>>::into_record(n_group),
padding: <nn::PaddingConfig2d as Module<B>>::into_record(padding.clone()),
};
let conv2d: conv::Conv2d<B> =
conv::Conv2dConfig::new([n_channels_in, n_channels_out], kernel_size)
let mut conv2d = conv::Conv2dConfig::new([n_channels_in, n_channels_out], kernel_size)
.with_stride(stride)
.with_dilation(dilation)
.with_groups(n_group)
.with_padding(padding)
.with_padding(padding.clone())
.with_bias(has_bias)
.init_with(record);
.init(device);
conv2d.weight = Param::from_tensor(weight);
conv2d.bias = bias.map(|t| Param::from_tensor(t));
conv2d.stride = stride;
conv2d.kernel_size = kernel_size;
conv2d.dilation = dilation;
conv2d.groups = n_group;
conv2d.padding = burn::module::Ignored(padding);
Ok(conv2d)
}
pub fn tensor_to_array_2<B: Backend>(x: Tensor<B, 1>) -> [usize; 2] {
let vec = x.into_data().value;
let vec: Vec<<B as Backend>::FloatElem> = x.into_data().to_vec().unwrap();
assert!(vec.len() == 2, "Tensor length must be 2.");
[vec[0].to_usize().unwrap(), vec[1].to_usize().unwrap()]
[vec[0].to_usize(), vec[1].to_usize()]
}
pub fn tensor_to_array<const N: usize, B: Backend>(x: Tensor<B, 1>) -> [usize; N] {
let vec = x.into_data().value;
let vec: Vec<<B as Backend>::FloatElem> = x.into_data().to_vec().unwrap();
assert!(vec.len() == N, "Tensor length must be {}.", N);
let mut arr = [0; N];
for (a, t) in arr.iter_mut().zip(vec) {
*a = t.to_usize().unwrap();
*a = t.to_usize();
}
arr

View File

@@ -18,7 +18,7 @@ pub fn load_stable_diffusion<B: Backend>(
device: &B::Device,
) -> Result<StableDiffusion<B>, Box<dyn Error>> {
let n_steps = load_usize::<B>("n_steps", path, device)?;
let alpha_cumulative_products = load_tensor::<B, 1>("alphas_cumprod", path, device)?.into();
let alpha_cumulative_products = Param::from_tensor(load_tensor::<B, 1>("alphas_cumprod", path, device)?);
let autoencoder = load_autoencoder(&format!("{}/{}", path, "autoencoder"), device)?;
let diffusion = load_unet(&format!("{}/{}", path, "unet"), device)?;
let clip = load_clip(&format!("{}/{}", path, "clip"), device)?;

View File

@@ -4,11 +4,12 @@ use burn::{
config::Config,
module::{Module, Param},
tensor::{backend::Backend, BasicOps, Data, Distribution, Float, Int, Tensor},
tensor::cast::ToElement,
};
use num_traits::ToPrimitive;
use crate::backend::Backend as MyBackend;
//use crate::backend::Backend as MyBackend;
use super::autoencoder::{Autoencoder, AutoencoderConfig};
use super::clip::{CLIPConfig, CLIP};
@@ -19,13 +20,13 @@ use crate::tokenizer::SimpleTokenizer;
pub struct StableDiffusionConfig {}
impl StableDiffusionConfig {
pub fn init<B: Backend>(&self) -> StableDiffusion<B> {
pub fn init<B: Backend>(&self, device: &B::Device) -> StableDiffusion<B> {
let n_steps = 1000;
let alpha_cumulative_products = offset_cosine_schedule_cumprod::<B>(n_steps).into();
let alpha_cumulative_products = Param::from_tensor(offset_cosine_schedule_cumprod::<B>(n_steps as i64, device));
let autoencoder = AutoencoderConfig::new().init();
let diffusion = UNetConfig::new().init();
let clip = CLIPConfig::new(49408, 768, 12, 77, 12).init();
let autoencoder = AutoencoderConfig::new().init(device);
let diffusion = UNetConfig::new().init(device);
let clip = CLIPConfig::new(49408, 768, 12, 77, 12).init(device);
StableDiffusion {
n_steps,
@@ -46,7 +47,7 @@ pub struct StableDiffusion<B: Backend> {
clip: CLIP<B>,
}
impl<B: MyBackend> StableDiffusion<B> {
impl<B: Backend> StableDiffusion<B> {
pub fn sample_image(
&self,
context: Tensor<B, 3>,
@@ -82,7 +83,7 @@ impl<B: MyBackend> StableDiffusion<B> {
.swap_dims(2, 3)
.mul_scalar(255.0);
let flattened: Vec<_> = image.into_data().value;
let flattened: Vec<B::FloatElem> = image.into_data().to_vec().unwrap();
(0..n_batch)
.into_iter()
@@ -92,7 +93,7 @@ impl<B: MyBackend> StableDiffusion<B> {
flattened[start..end]
.into_iter()
.map(|v| v.to_f64().unwrap().min(255.0).max(0.0).to_u8().unwrap())
.map(|v| v.to_f64().min(255.0).max(0.0) as u8)
.collect()
})
.collect()
@@ -112,8 +113,7 @@ impl<B: MyBackend> StableDiffusion<B> {
let [n_batches, _, _] = context.dims();
let gen_noise = || {
Tensor::random([n_batches, 4, 64, 64], Distribution::Normal(0.0, 1.0))
.to_device(&device)
Tensor::random([n_batches, 4, 64, 64], Distribution::Normal(0.0, 1.0), &device)
};
let sigma = 0.0; // Use deterministic diffusion
@@ -126,8 +126,8 @@ impl<B: MyBackend> StableDiffusion<B> {
.val()
.slice([t..t + 1])
.into_scalar()
.to_f64()
.unwrap();
.to_f64();
let prev_alpha: f64 = if t >= step_size {
let i = t - step_size;
self.alpha_cumulative_products
@@ -135,14 +135,13 @@ impl<B: MyBackend> StableDiffusion<B> {
.slice([i..i + 1])
.into_scalar()
.to_f64()
.unwrap()
} else {
1.0
};
let sqrt_noise = (1.0 - current_alpha).sqrt();
let timestep = Tensor::from_ints([t as i32]).to_device(&device);
let timestep = Tensor::from_ints([t as i32], &device);
let pred_noise = self.forward_diffuser(
latent.clone(),
timestep,
@@ -174,7 +173,7 @@ impl<B: MyBackend> StableDiffusion<B> {
let unconditional_latent = self.diffusion.forward(
latent.clone(),
timestep.clone(),
unconditional_context.unsqueeze().repeat(0, n_batch),
unconditional_context.unsqueeze().repeat(&[0, n_batch]),
);
let conditional_latent = self.diffusion.forward(latent, timestep, context);
@@ -206,8 +205,7 @@ impl<B: MyBackend> StableDiffusion<B> {
.collect();
self.clip.forward(
Tensor::from_ints(&tokenized[..])
.to_device(device)
Tensor::<B, 1, Int>::from_ints(&tokenized[..], device)
.unsqueeze(),
)
}
@@ -215,25 +213,25 @@ impl<B: MyBackend> StableDiffusion<B> {
use std::f64::consts::PI;
fn cosine_schedule<B: Backend>(n_steps: usize) -> Tensor<B, 1> {
Tensor::arange(1..n_steps + 1)
fn cosine_schedule<B: Backend>(n_steps: i64, device: &B::Device) -> Tensor<B, 1> {
Tensor::arange(1..n_steps + 1, device)
.float()
.mul_scalar(PI * 0.5 / n_steps as f64)
.cos()
}
fn offset_cosine_schedule<B: Backend>(n_steps: usize) -> Tensor<B, 1> {
fn offset_cosine_schedule<B: Backend>(n_steps: i64, device: &B::Device) -> Tensor<B, 1> {
let min_signal_rate: f64 = 0.02;
let max_signal_rate: f64 = 0.95;
let start_angle = max_signal_rate.acos();
let end_angle = min_signal_rate.acos();
let times = Tensor::arange(1..n_steps + 1).float();
let times = Tensor::arange(1..n_steps + 1, device).float();
let diffusion_angles = times * ((end_angle - start_angle) / n_steps as f64) + start_angle;
diffusion_angles.cos()
}
fn offset_cosine_schedule_cumprod<B: Backend>(n_steps: usize) -> Tensor<B, 1> {
offset_cosine_schedule::<B>(n_steps).powf(2.0)
fn offset_cosine_schedule_cumprod<B: Backend>(n_steps: i64, device: &B::Device) -> Tensor<B, 1> {
offset_cosine_schedule::<B>(n_steps, device).powf_scalar(2.0)
}

View File

@@ -65,7 +65,7 @@ pub fn load_geglu<B: Backend>(path: &str, device: &B::Device) -> Result<GEGLU<B>
let geglue = GEGLU {
proj: proj,
gelu: GELU::new(), // Assuming GELU::new() initializes a new GELU struct
gelu: Gelu::new(), // Assuming Gelu::new() initializes a new Gelu struct
};
Ok(geglue)

View File

@@ -6,7 +6,7 @@ use burn::{
nn::{
self,
conv::{Conv2d, Conv2dConfig},
PaddingConfig2d, GELU,
PaddingConfig2d, Gelu,
},
tensor::{activation::softmax, backend::Backend, module::embedding, Distribution, Int, Tensor},
};
@@ -22,7 +22,7 @@ fn timestep_embedding<B: Backend>(
max_period: usize,
) -> Tensor<B, 2> {
let half = dim / 2;
let freqs = (Tensor::arange_device(0..half, &timesteps.device()).float()
let freqs = (Tensor::arange(0..half as i64, &timesteps.device()).float()
* (-(max_period as f64).ln() / half as f64))
.exp();
let args = timesteps.float() * freqs;
@@ -33,50 +33,50 @@ fn timestep_embedding<B: Backend>(
pub struct UNetConfig {}
impl UNetConfig {
pub fn init<B: Backend>(&self) -> UNet<B> {
let lin1_time_embed = nn::LinearConfig::new(320, 1280).init();
pub fn init<B: Backend>(&self, device: &B::Device) -> UNet<B> {
let lin1_time_embed = nn::LinearConfig::new(320, 1280).init(device);
let silu_time_embed = SILU::new();
let lin2_time_embed = nn::LinearConfig::new(1280, 1280).init();
let lin2_time_embed = nn::LinearConfig::new(1280, 1280).init(device);
let input_blocks = UNetInputBlocks {
conv: Conv2dConfig::new([4, 320], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init(),
rt1: ResTransformerConfig::new(320, 1280, 320, 768, 8).init(),
rt2: ResTransformerConfig::new(320, 1280, 320, 768, 8).init(),
d1: DownsampleConfig::new(320).init(),
rt3: ResTransformerConfig::new(320, 1280, 640, 768, 8).init(),
rt4: ResTransformerConfig::new(640, 1280, 640, 768, 8).init(),
d2: DownsampleConfig::new(640).init(),
rt5: ResTransformerConfig::new(640, 1280, 1280, 768, 8).init(),
rt6: ResTransformerConfig::new(1280, 1280, 1280, 768, 8).init(),
d3: DownsampleConfig::new(1280).init(),
r1: ResBlockConfig::new(1280, 1280, 1280).init(),
r2: ResBlockConfig::new(1280, 1280, 1280).init(),
.init(device),
rt1: ResTransformerConfig::new(320, 1280, 320, 768, 8).init(device),
rt2: ResTransformerConfig::new(320, 1280, 320, 768, 8).init(device),
d1: DownsampleConfig::new(320).init(device),
rt3: ResTransformerConfig::new(320, 1280, 640, 768, 8).init(device),
rt4: ResTransformerConfig::new(640, 1280, 640, 768, 8).init(device),
d2: DownsampleConfig::new(640).init(device),
rt5: ResTransformerConfig::new(640, 1280, 1280, 768, 8).init(device),
rt6: ResTransformerConfig::new(1280, 1280, 1280, 768, 8).init(device),
d3: DownsampleConfig::new(1280).init(device),
r1: ResBlockConfig::new(1280, 1280, 1280).init(device),
r2: ResBlockConfig::new(1280, 1280, 1280).init(device),
};
let middle_block = ResTransformerResConfig::new(1280, 1280, 1280, 768, 8).init();
let middle_block = ResTransformerResConfig::new(1280, 1280, 1280, 768, 8).init(device);
let output_blocks = UNetOutputBlocks {
r1: ResBlockConfig::new(2560, 1280, 1280).init(),
r2: ResBlockConfig::new(2560, 1280, 1280).init(),
ru: ResUpSampleConfig::new(2560, 1280, 1280).init(),
rt1: ResTransformerConfig::new(2560, 1280, 1280, 768, 8).init(),
rt2: ResTransformerConfig::new(2560, 1280, 1280, 768, 8).init(),
rtu1: ResTransformerUpsampleConfig::new(1920, 1280, 1280, 768, 8).init(),
rt3: ResTransformerConfig::new(1920, 1280, 640, 768, 8).init(),
rt4: ResTransformerConfig::new(1280, 1280, 640, 768, 8).init(),
rtu2: ResTransformerUpsampleConfig::new(960, 1280, 640, 768, 8).init(),
rt5: ResTransformerConfig::new(960, 1280, 320, 768, 8).init(),
rt6: ResTransformerConfig::new(640, 1280, 320, 768, 8).init(),
rt7: ResTransformerConfig::new(640, 1280, 320, 768, 8).init(),
r1: ResBlockConfig::new(2560, 1280, 1280).init(device),
r2: ResBlockConfig::new(2560, 1280, 1280).init(device),
ru: ResUpSampleConfig::new(2560, 1280, 1280).init(device),
rt1: ResTransformerConfig::new(2560, 1280, 1280, 768, 8).init(device),
rt2: ResTransformerConfig::new(2560, 1280, 1280, 768, 8).init(device),
rtu1: ResTransformerUpsampleConfig::new(1920, 1280, 1280, 768, 8).init(device),
rt3: ResTransformerConfig::new(1920, 1280, 640, 768, 8).init(device),
rt4: ResTransformerConfig::new(1280, 1280, 640, 768, 8).init(device),
rtu2: ResTransformerUpsampleConfig::new(960, 1280, 640, 768, 8).init(device),
rt5: ResTransformerConfig::new(960, 1280, 320, 768, 8).init(device),
rt6: ResTransformerConfig::new(640, 1280, 320, 768, 8).init(device),
rt7: ResTransformerConfig::new(640, 1280, 320, 768, 8).init(device),
};
let norm_out = GroupNormConfig::new(32, 320).init();
let norm_out = GroupNormConfig::new(32, 320).init(device);
let silu_out = SILU::new();
let conv_out = Conv2dConfig::new([320, 4], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init();
.init(device);
UNet {
lin1_time_embed,
@@ -206,16 +206,16 @@ pub struct ResTransformerConfig {
}
impl ResTransformerConfig {
fn init<B: Backend>(&self) -> ResTransformer<B> {
fn init<B: Backend>(&self, device: &B::Device) -> ResTransformer<B> {
let res = ResBlockConfig::new(
self.n_channels_in,
self.n_channels_embed,
self.n_channels_out,
)
.init();
.init(device);
let transformer =
SpatialTransformerConfig::new(self.n_channels_out, self.n_context_state, self.n_head)
.init();
.init(device);
ResTransformer { res, transformer }
}
@@ -243,14 +243,14 @@ pub struct ResUpSampleConfig {
}
impl ResUpSampleConfig {
fn init<B: Backend>(&self) -> ResUpSample<B> {
fn init<B: Backend>(&self, device: &B::Device) -> ResUpSample<B> {
let res = ResBlockConfig::new(
self.n_channels_in,
self.n_channels_embed,
self.n_channels_out,
)
.init();
let upsample = UpsampleConfig::new(self.n_channels_out).init();
.init(device);
let upsample = UpsampleConfig::new(self.n_channels_out).init(device);
ResUpSample { res, upsample }
}
@@ -280,17 +280,17 @@ pub struct ResTransformerUpsampleConfig {
}
impl ResTransformerUpsampleConfig {
fn init<B: Backend>(&self) -> ResTransformerUpsample<B> {
fn init<B: Backend>(&self, device: &B::Device) -> ResTransformerUpsample<B> {
let res = ResBlockConfig::new(
self.n_channels_in,
self.n_channels_embed,
self.n_channels_out,
)
.init();
.init(device);
let transformer =
SpatialTransformerConfig::new(self.n_channels_out, self.n_context_state, self.n_head)
.init();
let upsample = UpsampleConfig::new(self.n_channels_out).init();
.init(device);
let upsample = UpsampleConfig::new(self.n_channels_out).init(device);
ResTransformerUpsample {
res,
@@ -326,22 +326,22 @@ pub struct ResTransformerResConfig {
}
impl ResTransformerResConfig {
fn init<B: Backend>(&self) -> ResTransformerRes<B> {
fn init<B: Backend>(&self, device: &B::Device) -> ResTransformerRes<B> {
let res1 = ResBlockConfig::new(
self.n_channels_in,
self.n_channels_embed,
self.n_channels_out,
)
.init();
.init(device);
let transformer =
SpatialTransformerConfig::new(self.n_channels_out, self.n_context_state, self.n_head)
.init();
.init(device);
let res2 = ResBlockConfig::new(
self.n_channels_in,
self.n_channels_embed,
self.n_channels_out,
)
.init();
.init(device);
ResTransformerRes {
res1,
@@ -373,10 +373,10 @@ pub struct UpsampleConfig {
}
impl UpsampleConfig {
fn init<B: Backend>(&self) -> Upsample<B> {
fn init<B: Backend>(&self, device: &B::Device) -> Upsample<B> {
let conv = Conv2dConfig::new([self.n_channels, self.n_channels], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init();
.init(device);
Upsample { conv }
}
@@ -392,8 +392,7 @@ impl<B: Backend> Upsample<B> {
let [n_batch, n_channel, height, width] = x.dims();
let x = x
.reshape([n_batch, n_channel, height, 1, width, 1])
.repeat(3, 2)
.repeat(5, 2)
.repeat(&[1, 1, 1, 2, 1, 2])
.reshape([n_batch, n_channel, 2 * height, 2 * width]);
self.conv.forward(x)
}
@@ -411,11 +410,11 @@ pub struct DownsampleConfig {
}
impl DownsampleConfig {
fn init<B: Backend>(&self) -> Conv2d<B> {
fn init<B: Backend>(&self, device: &B::Device) -> Conv2d<B> {
Conv2dConfig::new([self.n_channels, self.n_channels], [3, 3])
.with_stride([2, 2])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init()
.init(device)
}
}
@@ -435,12 +434,12 @@ pub struct SpatialTransformerConfig {
}
impl SpatialTransformerConfig {
fn init<B: Backend>(&self) -> SpatialTransformer<B> {
let norm = GroupNormConfig::new(32, self.n_channels).init();
let proj_in = Conv2dConfig::new([self.n_channels, self.n_channels], [1, 1]).init();
fn init<B: Backend>(&self, device: &B::Device) -> SpatialTransformer<B> {
let norm = GroupNormConfig::new(32, self.n_channels).init(device);
let proj_in = Conv2dConfig::new([self.n_channels, self.n_channels], [1, 1]).init(device);
let transformer =
TransformerBlockConfig::new(self.n_channels, self.n_context_state, self.n_head).init();
let proj_out = Conv2dConfig::new([self.n_channels, self.n_channels], [1, 1]).init();
TransformerBlockConfig::new(self.n_channels, self.n_context_state, self.n_head).init(device);
let proj_out = Conv2dConfig::new([self.n_channels, self.n_channels], [1, 1]).init(device);
SpatialTransformer {
norm,
@@ -489,14 +488,14 @@ pub struct TransformerBlockConfig {
}
impl TransformerBlockConfig {
fn init<B: Backend>(&self) -> TransformerBlock<B> {
let norm1 = nn::LayerNormConfig::new(self.n_state).init();
let attn1 = MultiHeadAttentionConfig::new(self.n_state, self.n_state, self.n_head).init();
let norm2 = nn::LayerNormConfig::new(self.n_state).init();
fn init<B: Backend>(&self, device: &B::Device) -> TransformerBlock<B> {
let norm1 = nn::LayerNormConfig::new(self.n_state).init(device);
let attn1 = MultiHeadAttentionConfig::new(self.n_state, self.n_state, self.n_head).init(device);
let norm2 = nn::LayerNormConfig::new(self.n_state).init(device);
let attn2 =
MultiHeadAttentionConfig::new(self.n_state, self.n_context_state, self.n_head).init();
let norm3 = nn::LayerNormConfig::new(self.n_state).init();
let mlp = MLPConfig::new(self.n_state, 4).init();
MultiHeadAttentionConfig::new(self.n_state, self.n_context_state, self.n_head).init(device);
let norm3 = nn::LayerNormConfig::new(self.n_state).init(device);
let mlp = MLPConfig::new(self.n_state, 4).init(device);
TransformerBlock {
norm1,
@@ -534,10 +533,10 @@ pub struct MLPConfig {
}
impl MLPConfig {
pub fn init<B: Backend>(&self) -> MLP<B> {
pub fn init<B: Backend>(&self, device: &B::Device) -> MLP<B> {
let n_state_hidden = self.n_state * self.mult;
let geglu = GEGLUConfig::new(self.n_state, n_state_hidden).init();
let lin = nn::LinearConfig::new(n_state_hidden, self.n_state).init();
let geglu = GEGLUConfig::new(self.n_state, n_state_hidden).init(device);
let lin = nn::LinearConfig::new(n_state_hidden, self.n_state).init(device);
MLP { geglu, lin }
}
@@ -562,9 +561,9 @@ pub struct GEGLUConfig {
}
impl GEGLUConfig {
fn init<B: Backend>(&self) -> GEGLU<B> {
let proj = nn::LinearConfig::new(self.n_state_in, 2 * self.n_state_out).init();
let gelu = GELU::new();
fn init<B: Backend>(&self, device: &B::Device) -> GEGLU<B> {
let proj = nn::LinearConfig::new(self.n_state_in, 2 * self.n_state_out).init(device);
let gelu = Gelu::new();
GEGLU { proj, gelu }
}
@@ -573,7 +572,7 @@ impl GEGLUConfig {
#[derive(Module, Debug)]
pub struct GEGLU<B: Backend> {
proj: nn::Linear<B>,
gelu: GELU,
gelu: Gelu,
}
impl<B: Backend> GEGLU<B> {
@@ -600,7 +599,7 @@ pub struct MultiHeadAttentionConfig {
}
impl MultiHeadAttentionConfig {
fn init<B: Backend>(&self) -> MultiHeadAttention<B> {
fn init<B: Backend>(&self, device: &B::Device) -> MultiHeadAttention<B> {
assert!(
self.n_state % self.n_head == 0,
"State size {} must be a multiple of head size {}",
@@ -611,14 +610,14 @@ impl MultiHeadAttentionConfig {
let n_head = self.n_head;
let query = nn::LinearConfig::new(self.n_state, self.n_state)
.with_bias(false)
.init();
.init(device);
let key = nn::LinearConfig::new(self.n_context_state, self.n_state)
.with_bias(false)
.init();
.init(device);
let value = nn::LinearConfig::new(self.n_context_state, self.n_state)
.with_bias(false)
.init();
let out = nn::LinearConfig::new(self.n_state, self.n_state).init();
.init(device);
let out = nn::LinearConfig::new(self.n_state, self.n_state).init(device);
MultiHeadAttention {
n_head,
@@ -661,24 +660,24 @@ pub struct ResBlockConfig {
}
impl ResBlockConfig {
fn init<B: Backend>(&self) -> ResBlock<B> {
let norm_in = GroupNormConfig::new(32, self.n_channels_in).init();
fn init<B: Backend>(&self, device: &B::Device) -> ResBlock<B> {
let norm_in = GroupNormConfig::new(32, self.n_channels_in).init(device);
let silu_in = SILU::new();
let conv_in = Conv2dConfig::new([self.n_channels_in, self.n_channels_out], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init();
.init(device);
let silu_embed = SILU::new();
let lin_embed = nn::LinearConfig::new(self.n_channels_embed, self.n_channels_out).init();
let lin_embed = nn::LinearConfig::new(self.n_channels_embed, self.n_channels_out).init(device);
let norm_out = GroupNormConfig::new(32, self.n_channels_out).init();
let norm_out = GroupNormConfig::new(32, self.n_channels_out).init(device);
let silu_out = SILU::new();
let conv_out = Conv2dConfig::new([self.n_channels_out, self.n_channels_out], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init();
.init(device);
let skip_connection = if self.n_channels_in != self.n_channels_out {
Some(Conv2dConfig::new([self.n_channels_in, self.n_channels_out], [1, 1]).init())
Some(Conv2dConfig::new([self.n_channels_in, self.n_channels_out], [1, 1]).init(device))
} else {
None
};