Compare commits

..

10 Commits

Author SHA1 Message Date
6cfd6db5a5 updated to the newest version of burn 2026-03-03 15:12:59 +01:00
Hermes
893fb0950d Update to burn v0.14.0 and switch to .mpk model file 2024-10-05 14:19:49 -04:00
Gadersd
9e4d7bd310 Support is hopeless 2023-09-17 12:43:26 -04:00
Gadersd
01b1aea897 Add custom backend to enable flash attention 2023-09-07 12:54:27 -04:00
Gadersd
f4c58c1790 Replace helper functions with native burn functions 2023-09-07 12:23:18 -04:00
Gadersd
a62795347f Use torch by default 2023-08-24 17:12:53 -04:00
Gadersd
1830756917 Use wgpu by default and ndarray for convert 2023-08-08 15:32:21 -04:00
Gadersd
b87273c2be Update README.md 2023-08-07 10:59:57 -04:00
Gadersd
31c24a82ef Update Cargo.toml
Remove unnecessary dependency
2023-08-06 19:55:00 -04:00
Gadersd
c24d37df00 Switch to burn 0.9.0 to gain fast model io without the need for a custom recorder 2023-08-06 19:44:13 -04:00
26 changed files with 1542 additions and 1309 deletions

View File

@@ -6,25 +6,19 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[features] [features]
default = ["torch-backend"]
torch-backend = ["burn-tch"]
wgpu-backend = ["burn-wgpu"] wgpu-backend = ["burn-wgpu"]
[dependencies.burn-tch]
version = "0.8.0"
optional = true
[dependencies.burn-wgpu]
version = "0.8.0"
optional = true
[dependencies] [dependencies]
burn = "0.8.0" burn = "0.20.1"
burn-ndarray = "0.20.1"
burn-tch = "0.20.1"
burn-autodiff = "0.20.1"
burn-wgpu = { version = "0.20.1", optional = true }
tch = "0.22.0"
serde = {version = "1.0.171", features = ["std", "derive"]} serde = {version = "1.0.171", features = ["std", "derive"]}
npy = "0.4.0" npy = "0.4.0"
num-traits = "0.2.15" num-traits = "0.2.15"
rust_tokenizers = "8.1.0" rust_tokenizers = "8.1.0"
regex = "1.9.1" regex = "1.9.1"
image = "0.24.6" image = "0.24.6"
bincode = {version = "2.0.0-alpha.0", features = ["std"]}
cfg-if = "0.1" cfg-if = "0.1"

View File

@@ -2,36 +2,35 @@
Stable-Diffusion-Burn is a Rust-based project which ports the V1 stable diffusion model into the deep learning framework, Burn. This repository is licensed under the MIT Licence. Stable-Diffusion-Burn is a Rust-based project which ports the V1 stable diffusion model into the deep learning framework, Burn. This repository is licensed under the MIT Licence.
## Support The Project
Stable-Diffusion-Burn is a passion project that is open and free to all. I want to empower everyone with reliable AI that can be run by ourselves on our own hardware to ensure that great AI is not limited to the hands of the few. If you support this vision consider supporting me so that I can continue on this journey and produce more projects such as Stable Diffusion XL in Rust.
You can show your support by buying a shirt at https://www.bonfire.com/machine-learning/. The shirt image was, of course, generated by my Rust powered Stable Diffusion! I'd love to release more projects and any support will help make that happen!
Any contribution would be greatly appreciated. Thanks!
## How To Use ## How To Use
### Step 0: Install libtorch v2.4.1
### Step 1: Download the Model and Set Environment Variables ### Step 1: Download the Model and Set Environment Variables
Start by downloading the SDv1-4.bin model provided on HuggingFace. Start by downloading the SDv1-4 model provided on HuggingFace.
```bash ```bash
wget https://huggingface.co/Gadersd/Stable-Diffusion-Burn/resolve/main/V1/SDv1-4.bin wget https://huggingface.co/Gadersd/Stable-Diffusion-Burn/resolve/main/SDv1-4.mpk
``` ```
Next, set the appropriate CUDA version. It may be possible to run the model using wgpu without the need for torch in the future using `cargo run --features wgpu-backend...` but currently wgpu doesn't support buffer sizes large enough for Stable Diffusion.
```bash
export TORCH_CUDA_VERSION=cu113
```
### Step 2: Run the Sample Binary ### Step 2: Run the Sample Binary
Invoke the sample binary provided in the rust code, as shown below: Invoke the sample binary provided in the rust code. By default, torch is used. The WGPU backend is unstable for SD but may work well in the future as burn-wpu is optimized.
```bash ```bash
# torch (at least 6 GB VRAM, possibly less)
# Arguments: <model_type(burn or dump)> <model_name> <unconditional_guidance_scale> <n_diffusion_steps> <prompt> <output_image_name> [cuda, mps, cpu]
# Cuda
cargo run --release --bin sample burn SDv1-4 7.5 20 "An ancient mossy stone." img cuda
# Mps(Mac)
cargo run --release --bin sample burn SDv1-4 7.5 20 "An ancient mossy stone." img mps
# wgpu (UNSTABLE)
# Arguments: <model_type(burn or dump)> <model> <unconditional_guidance_scale> <n_diffusion_steps> <prompt> <output_image> # Arguments: <model_type(burn or dump)> <model> <unconditional_guidance_scale> <n_diffusion_steps> <prompt> <output_image>
cargo run --release --bin sample burn SDv1-4 7.5 20 "An ancient mossy stone." img cargo run --release --features wgpu-backend --bin sample burn SDv1-4 7.5 20 "An ancient mossy stone." img
``` ```
This command will generate an image according to the provided prompt, which will be saved as 'img0.png'. This command will generate an image according to the provided prompt, which will be saved as 'img0.png'.
@@ -40,7 +39,7 @@ This command will generate an image according to the provided prompt, which will
### Optional: Extract and Convert a Fine-Tuned Model ### Optional: Extract and Convert a Fine-Tuned Model
If users are interested in using a fine-tuned version of stable diffusion, the Python scripts provided in this project can be used to transform a weight dump into a Burn model file. If users are interested in using a fine-tuned version of stable diffusion, the Python scripts provided in this project can be used to transform a weight dump into a Burn model file. This does not work on Windows.
```bash ```bash
# Step into the Python directory # Step into the Python directory
@@ -49,6 +48,9 @@ cd python
# Download the model, this is just the base v1.4 model as an example # Download the model, this is just the base v1.4 model as an example
wget https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt wget https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt
# Install tinygrad
pip install -r requirements.txt
# Extract the weights # Extract the weights
CPU=1 python3 dump.py sd-v1-4.ckpt CPU=1 python3 dump.py sd-v1-4.ckpt

BIN
img0.png

Binary file not shown.

Before

Width:  |  Height:  |  Size: 671 KiB

After

Width:  |  Height:  |  Size: 677 KiB

View File

@@ -13,10 +13,11 @@ from collections import namedtuple
from tqdm import tqdm from tqdm import tqdm
from tinygrad.tensor import Tensor from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes, GlobalCounters from tinygrad.helpers import GlobalCounters
from tinygrad import dtypes
from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm, Embedding from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm, Embedding
from extra.utils import download_file #from extra.utils import download_file
from tinygrad.state import torch_load, load_state_dict from tinygrad.nn.state import torch_load, load_state_dict
# TODO: refactor AttnBlock, CrossAttention, CLIPAttention to share code # TODO: refactor AttnBlock, CrossAttention, CLIPAttention to share code

1
python/requirements.txt Normal file
View File

@@ -0,0 +1 @@
tinygrad==0.9.2

139
src/backend.rs Normal file
View File

@@ -0,0 +1,139 @@
use burn::tensor::{activation::softmax, Tensor};
use burn::prelude::Backend;
/*pub type FloatTensor<B, const D: usize> = <B as burn::tensor::backend::Backend>::TensorPrimitive<D>;
pub trait Backend: burn::tensor::backend::Backend {
fn qkv_attention(
q: FloatTensor<Self, 3>,
k: FloatTensor<Self, 3>,
v: FloatTensor<Self, 3>,
mask: Option<FloatTensor<Self, 2>>,
n_head: usize,
) -> FloatTensor<Self, 3> {
qkv_attention(
Tensor::<Self, 3>::from_primitive(q),
Tensor::from_primitive(k),
Tensor::from_primitive(v),
mask.map(|m| Tensor::from_primitive(m)),
n_head,
)
.into_primitive()
}
fn attn_decoder_mask(seq_length: usize, device: &Self::Device) -> FloatTensor<Self, 2> {
attn_decoder_mask::<Self>(seq_length, device).into_primitive()
}
}
use burn::tensor::Float;
use burn_tch::{self, TchElement, TchTensor};
use tch;
impl<E: TchElement> Backend for burn_tch::LibTorch<E> {
fn qkv_attention(
q: FloatTensor<Self, 3>,
k: FloatTensor<Self, 3>,
v: FloatTensor<Self, 3>,
mask: Option<FloatTensor<Self, 2>>,
n_head: usize,
) -> FloatTensor<Self, 2> {
let q = Tensor::from_primitive(q);
let k = Tensor::from_primitive(k);
let v = Tensor::from_primitive(v);
let [n_batch, q_ctx, n_state] = q.dims();
let [_, k_ctx, _] = k.dims();
let n_hstate = n_state / n_head;
let rearrange = |t: Tensor<Self, 3>| {
let [_, n_ctx, _] = t.dims();
t.reshape([n_batch, n_ctx, n_head, n_hstate])
.swap_dims(1, 2)
};
let q = rearrange(q).into_primitive();
let k = rearrange(k).into_primitive();
let v = rearrange(v).into_primitive();
// for some reason torch crashes when mask is None
let mask = mask.unwrap_or_else(|| {
Tensor::<Self, 2, Float>::zeros([q_ctx, k_ctx], &Self::device(&v))
.into_primitive()
});
Tensor::<Self, 4>::from_primitive(TchTensor::new(
tch::Tensor::scaled_dot_product_attention(
&q.tensor,
&k.tensor,
&v.tensor,
Some(mask.tensor),
0.0,
false,
None,
),
))
.swap_dims(1, 2)
.flatten(2, 3)
.into_primitive()
}
}
use burn_autodiff;
impl<B: Backend> Backend for burn_autodiff::Autodiff<B> {}*/
use std::f32::NEG_INFINITY;
pub fn qkv_attention<B: Backend>(
q: Tensor<B, 3>,
k: Tensor<B, 3>,
v: Tensor<B, 3>,
mask: Option<Tensor<B, 2>>,
n_head: usize,
) -> Tensor<B, 3> {
let [n_batch, n_qctx, n_state] = q.dims();
let [_, n_ctx, _] = k.dims();
let scale = (n_state as f64 / n_head as f64).powf(-0.25);
let n_hstate = n_state / n_head;
let q = q
.reshape([n_batch, n_qctx, n_head, n_hstate])
.swap_dims(1, 2)
* scale;
let k = k
.reshape([n_batch, n_ctx, n_head, n_hstate])
.swap_dims(1, 2)
.transpose()
* scale;
let v = v
.reshape([n_batch, n_ctx, n_head, n_hstate])
.swap_dims(1, 2);
let qk = q.matmul(k);
// apply mask
let qk = if let Some(mask) = mask {
qk + mask.slice([0..n_qctx, 0..n_ctx]).unsqueeze::<4>()
} else {
qk
};
// normalize value weightings
let w = softmax(qk, 3);
let o = w.matmul(v).swap_dims(1, 2).flatten(2, 3);
return o;
}
pub fn attn_decoder_mask<B: Backend>(seq_length: usize, device: &B::Device) -> Tensor<B, 2> {
let mut mask = Tensor::<B, 2>::zeros([seq_length, seq_length], device);
for i in 0..(seq_length - 1) {
let values = Tensor::<B, 2>::zeros([1, seq_length - (i + 1)], device).add_scalar(NEG_INFINITY);
mask = mask.slice_assign([i..i + 1, i + 1..seq_length], values);
}
return mask;
}

View File

@@ -1,33 +1,27 @@
use std::env; use std::env;
use std::process;
use std::error::Error; use std::error::Error;
use std::process;
use stablediffusion::model::stablediffusion::{StableDiffusion, load::load_stable_diffusion}; use stablediffusion::model::stablediffusion::{load::load_stable_diffusion, StableDiffusion};
use burn::{ use burn::{
config::Config, config::Config,
module::{Module, Param}, module::{Module, Param},
nn, nn,
tensor::{ tensor::{backend::Backend, Tensor},
backend::Backend,
Tensor,
},
}; };
cfg_if::cfg_if! { use burn_ndarray::{NdArray, NdArrayDevice};
if #[cfg(feature = "torch-backend")] {
use burn_tch::{TchBackend, TchDevice};
} else if #[cfg(feature = "wgpu-backend")] {
use burn_wgpu::{WgpuBackend, WgpuDevice, AutoGraphicsApi};
}
}
use burn::record::{self, Recorder, FullPrecisionSettings}; use burn::record::{self, NamedMpkFileRecorder, FullPrecisionSettings, Recorder};
use stablediffusion::binrecorderfast::{BinFileRecorderBuffered};
fn convert_dump_to_model<B: Backend>(dump_path: &str, model_name: &str, device: &B::Device) -> Result<(), Box<dyn Error>> { fn convert_dump_to_model<B: Backend>(
dump_path: &str,
model_name: &str,
device: &B::Device,
) -> Result<(), Box<dyn Error>> {
println!("Loading dump..."); println!("Loading dump...");
let model: StableDiffusion::<B> = load_stable_diffusion(dump_path, device)?; let model: StableDiffusion<B> = load_stable_diffusion(dump_path, device)?;
println!("Saving model..."); println!("Saving model...");
save_model_file(model, model_name)?; save_model_file(model, model_name)?;
@@ -35,24 +29,16 @@ fn convert_dump_to_model<B: Backend>(dump_path: &str, model_name: &str, device:
Ok(()) Ok(())
} }
fn save_model_file<B: Backend>(model: StableDiffusion<B>, name: &str) -> Result<(), record::RecorderError> { fn save_model_file<B: Backend>(
BinFileRecorderBuffered::<FullPrecisionSettings>::new() model: StableDiffusion<B>,
.record( name: &str,
model.into_record(), ) -> Result<(), record::RecorderError> {
name.into(), NamedMpkFileRecorder::<FullPrecisionSettings>::new().record(model.into_record(), name.into())
)
} }
fn main() { fn main() {
cfg_if::cfg_if! { type Backend = NdArray<f32>;
if #[cfg(feature = "torch-backend")] { let device = NdArrayDevice::Cpu;
type Backend = TchBackend<f32>;
let device = TchDevice::Cpu;
} else if #[cfg(feature = "wgpu-backend")] {
type Backend = WgpuBackend<AutoGraphicsApi, f32, i32>;
let device = WgpuDevice::CPU;
}
}
let args: Vec<String> = env::args().collect(); let args: Vec<String> = env::args().collect();
if args.len() != 3 { if args.len() != 3 {

View File

@@ -1,20 +1,20 @@
use stablediffusion::{tokenizer::SimpleTokenizer, model::stablediffusion::{*, load::load_stable_diffusion}}; use stablediffusion::{
model::stablediffusion::{load::load_stable_diffusion, *},
tokenizer::SimpleTokenizer,
};
use burn::{ use burn::{
config::Config, config::Config,
module::{Module, Param}, module::{Module, Param},
nn, nn,
tensor::{ tensor::{backend::Backend, Tensor},
backend::Backend,
Tensor,
},
}; };
cfg_if::cfg_if! { cfg_if::cfg_if! {
if #[cfg(feature = "torch-backend")] { if #[cfg(feature = "wgpu-backend")] {
use burn_tch::{TchBackend, TchDevice}; use burn_wgpu::{Wgpu, WgpuDevice};
} else if #[cfg(feature = "wgpu-backend")] { } else {
use burn_wgpu::{WgpuBackend, WgpuDevice, AutoGraphicsApi}; use burn_tch::{LibTorch, LibTorchDevice};
} }
} }
@@ -22,29 +22,21 @@ use std::env;
use std::io; use std::io;
use std::process; use std::process;
use burn::record::{self, Recorder, FullPrecisionSettings}; use burn::record::{self, NamedMpkFileRecorder, FullPrecisionSettings, Recorder};
use stablediffusion::binrecorderfast::{BinFileRecorderBuffered};
fn load_stable_diffusion_model_file<B: Backend>(filename: &str) -> Result<StableDiffusion<B>, record::RecorderError> { fn load_stable_diffusion_model_file<B: Backend>(
BinFileRecorderBuffered::<FullPrecisionSettings>::new() filename: &str,
.load(filename.into()) device: &B::Device,
.map(|record| StableDiffusionConfig::new().init().load_record(record)) ) -> Result<StableDiffusion<B>, record::RecorderError> {
NamedMpkFileRecorder::<FullPrecisionSettings>::new()
.load(filename.into(), device)
.map(|record| StableDiffusionConfig::new().init(device).load_record(record))
} }
fn main() { fn main() {
cfg_if::cfg_if! {
if #[cfg(feature = "torch-backend")] {
type Backend = TchBackend<f32>;
let device = TchDevice::Cuda(0);
} else if #[cfg(feature = "wgpu-backend")] {
type Backend = WgpuBackend<AutoGraphicsApi, f32, i32>;
let device = WgpuDevice::BestAvailable;
}
}
let args: Vec<String> = std::env::args().collect(); let args: Vec<String> = std::env::args().collect();
if args.len() != 7 { if args.len() != 7 && args.len() != 8 {
eprintln!("Usage: {} <model_type(burn or dump)> <model_name> <unconditional_guidance_scale> <n_diffusion_steps> <prompt> <output_image_name>", args[0]); eprintln!("Usage: {} <model_type(burn or dump)> <model_name> <unconditional_guidance_scale> <n_diffusion_steps> <prompt> <output_image_name> [device(cuda, mps, cpu)]", args[0]);
process::exit(1); process::exit(1);
} }
@@ -61,11 +53,40 @@ fn main() {
let prompt = &args[5]; let prompt = &args[5];
let output_image_name = &args[6]; let output_image_name = &args[6];
// Optional device parameter
let device_arg = if args.len() == 8 { Some(&args[7]) } else { None };
cfg_if::cfg_if! {
if #[cfg(feature = "wgpu-backend")] {
type Backend = Wgpu;
let device = WgpuDevice::BestAvailable;
} else {
type Backend = LibTorch<f32>;
let device = if let Some(dev_str) = device_arg {
match dev_str.to_lowercase().as_str() {
"cpu" => LibTorchDevice::Cpu,
"mps" => LibTorchDevice::Mps,
s if s.starts_with("cuda") => {
let idx = s[4..].parse().unwrap_or(0);
LibTorchDevice::Cuda(idx)
}
_ => {
eprintln!("Unknown device: {}", dev_str);
process::exit(1);
}
}
} else {
LibTorchDevice::Cuda(0)
};
}
}
println!("Loading tokenizer..."); println!("Loading tokenizer...");
let tokenizer = SimpleTokenizer::new().unwrap(); let tokenizer = SimpleTokenizer::new().unwrap();
println!("Loading model..."); println!("Loading model...");
let sd: StableDiffusion<Backend> = if model_type == "burn" { let sd: StableDiffusion<Backend> = if model_type == "burn" {
load_stable_diffusion_model_file(model_name).unwrap_or_else(|err| { load_stable_diffusion_model_file(model_name, &device).unwrap_or_else(|err| {
eprintln!("Error loading model: {}", err); eprintln!("Error loading model: {}", err);
process::exit(1); process::exit(1);
}) })
@@ -76,20 +97,23 @@ fn main() {
}) })
}; };
let sd = sd.to_device(&device);
let unconditional_context = sd.unconditional_context(&tokenizer); let unconditional_context = sd.unconditional_context(&tokenizer);
let context = sd.context(&tokenizer, prompt).unsqueeze().repeat(0, 2); // generate 2 samples let context = sd.context(&tokenizer, prompt).unsqueeze::<3>(); //.repeat(0, 2); // generate 2 samples
println!("Sampling image..."); println!("Sampling image...");
let images = sd.sample_image(context, unconditional_context, unconditional_guidance_scale, n_steps); let images = sd.sample_image(
context,
unconditional_context,
unconditional_guidance_scale,
n_steps,
);
save_images(&images, output_image_name, 512, 512).unwrap_or_else(|err| { save_images(&images, output_image_name, 512, 512).unwrap_or_else(|err| {
eprintln!("Error saving image: {}", err); eprintln!("Error saving image: {}", err);
process::exit(1); process::exit(1);
}); });
} }
use image::{self, ImageResult, ColorType::Rgb8}; use image::{self, ColorType::Rgb8, ImageResult};
fn save_images(images: &Vec<Vec<u8>>, basepath: &str, width: u32, height: u32) -> ImageResult<()> { fn save_images(images: &Vec<Vec<u8>>, basepath: &str, width: u32, height: u32) -> ImageResult<()> {
for (index, img_data) in images.iter().enumerate() { for (index, img_data) in images.iter().enumerate() {
@@ -104,12 +128,15 @@ fn save_images(images: &Vec<Vec<u8>>, basepath: &str, width: u32, height: u32) -
fn save_test_image() -> ImageResult<()> { fn save_test_image() -> ImageResult<()> {
let width = 256; let width = 256;
let height = 256; let height = 256;
let raw: Vec<_> = (0..width * height).into_iter().flat_map(|i| { let raw: Vec<_> = (0..width * height)
let row = i / width; .into_iter()
let red = (255.0 * row as f64 / height as f64) as u8; .flat_map(|i| {
let row = i / width;
let red = (255.0 * row as f64 / height as f64) as u8;
[red, 0, 0] [red, 0, 0]
}).collect(); })
.collect();
image::save_buffer("red.png", &raw[..], width, height, Rgb8) image::save_buffer("red.png", &raw[..], width, height, Rgb8)
} }

View File

@@ -1,86 +0,0 @@
use bincode;
use burn::record::{PrecisionSettings, Recorder, RecorderError, FileRecorder};
use std::fs::File;
use std::io::{BufReader, BufWriter};
use std::path::PathBuf;
use std::marker::PhantomData;
use serde::{de::DeserializeOwned, Serialize};
//use super::{bin_config, PrecisionSettings, Recorder, RecorderError};
fn bin_config() -> bincode::config::Configuration {
bincode::config::standard()
}
macro_rules! str2reader {
($file:expr) => {{
$file.set_extension(<Self as FileRecorder>::file_extension());
let path = $file.as_path();
File::open(path).map_err(|err| match err.kind() {
std::io::ErrorKind::NotFound => RecorderError::FileNotFound(err.to_string()),
_ => RecorderError::Unknown(err.to_string()),
}).map(|file| BufReader::new(file)) // wrap File in BufReader
}};
}
macro_rules! str2writer {
($file:expr) => {{
$file.set_extension(<Self as FileRecorder>::file_extension());
let path = $file.as_path();
if path.exists() {
//log::info!("File exists, replacing");
std::fs::remove_file(path).map_err(|err| RecorderError::Unknown(err.to_string()))?;
}
File::create(path).map_err(|err| match err.kind() {
std::io::ErrorKind::NotFound => RecorderError::FileNotFound(err.to_string()),
_ => RecorderError::Unknown(err.to_string()),
}).map(|file| BufWriter::new(file)) // wrap File in BufWriter
}};
}
#[derive(Debug, Default, Clone)]
pub struct BinFileRecorderBuffered<S: PrecisionSettings> {
_settings: PhantomData<S>,
}
impl<S: PrecisionSettings> BinFileRecorderBuffered<S> {
pub fn new() -> Self {
BinFileRecorderBuffered {
_settings: PhantomData,
}
}
}
impl<S: PrecisionSettings> FileRecorder for BinFileRecorderBuffered<S> {
fn file_extension() -> &'static str {
"bin"
}
}
impl<S: PrecisionSettings> Recorder for BinFileRecorderBuffered<S> {
type Settings = S;
type RecordArgs = PathBuf;
type RecordOutput = ();
type LoadArgs = PathBuf;
fn save_item<I: Serialize>(
&self,
item: I,
mut file: Self::RecordArgs,
) -> Result<(), RecorderError> {
let config = bin_config();
let mut writer = str2writer!(file)?;
bincode::serde::encode_into_std_write(&item, &mut writer, config)
.map_err(|err| RecorderError::Unknown(err.to_string()))?;
Ok(())
}
fn load_item<I: DeserializeOwned>(&self, mut file: Self::LoadArgs) -> Result<I, RecorderError> {
let mut reader = str2reader!(file)?;
let state =
bincode::serde::decode_from_std_read(&mut reader, bin_config())
.map_err(|err| RecorderError::Unknown(err.to_string()))?;
Ok(state)
}
}

View File

@@ -1,87 +0,0 @@
use burn::{
tensor::{
backend::Backend,
activation::relu,
Tensor,
Int,
Bool,
Float,
TensorKind,
BasicOps,
Numeric,
Element,
},
};
use num_traits::ToPrimitive;
pub fn tensor_max_scalar<B: Backend, const D: usize>(x: Tensor<B, D>, max: f64) -> Tensor<B, D> {
relu(x.sub_scalar(max)).add_scalar(max)
}
pub fn tensor_min_scalar<B: Backend, const D: usize>(x: Tensor<B, D>, min: f64) -> Tensor<B, D> {
-tensor_max_scalar(-x, -min)
}
pub fn tensor_max<B: Backend, const D: usize>(x: Tensor<B, D>, max: Tensor<B, D>) -> Tensor<B, D> {
relu(x - max.clone()) + max
}
pub fn tensor_min<B: Backend, const D: usize>(x: Tensor<B, D>, min: Tensor<B, D>) -> Tensor<B, D> {
-tensor_max(-x, -min)
}
pub fn tensor_log10<B: Backend, const D: usize>(x: Tensor<B, D>) -> Tensor<B, D> {
let ln10 = (10.0f64).ln();
x.log() / ln10
}
pub fn tensor_max_element<B: Backend, const D: usize>(x: Tensor<B, D>) -> f64 {
let flat: Tensor<B, 1> = x.flatten(0, D - 1);
let max_index = flat.clone().argmax(0);
flat.select(0, max_index).into_scalar().to_f64().unwrap()
}
pub fn all_zeros<B: Backend, const D: usize>(x: Tensor<B, D>) -> bool {
x.powf(2.0).sum().into_scalar().to_f64().unwrap() == 0.0
}
pub fn max_dim<B: Backend>(x: Tensor<B, 2>, dim: usize) -> Tensor<B, 2> {
let indices = x.clone().argmax(dim).flatten(0, 1);
x.select(dim, indices)
}
pub fn _10pow<B: Backend, const D: usize>(x: Tensor<B, D>) -> Tensor<B, D> {
let log10 = (10.0f64).ln();
(x * log10).exp()
}
pub fn to_float<B: Backend, const D: usize>(x: Tensor<B, D, Int>) -> Tensor<B, D, Float> {
let device = x.device();
Tensor::from_data(
x
.into_data()
.convert()
).to_device(&device)
}
pub fn to_float_bool<B: Backend, const D: usize>(x: Tensor<B, D, Bool>) -> Tensor<B, D, Float> {
let device = x.device();
Tensor::from_data(
x
.into_data()
.convert()
).to_device(&device)
}
pub fn reverse<B: Backend, const D: usize, K: TensorKind<B> + BasicOps<B> + Numeric<B>>(x: Tensor<B, D, K>, dim: usize) -> Tensor<B, D, K> where <K as BasicOps<B>>::Elem: Element {
let len = x.dims()[dim];
let indices = -Tensor::arange_device(0..len, &x.device()) + (len - 1) as i64;
x.select(dim, indices)
}
pub fn div_roundup(x: usize, y: usize) -> usize {
(x + y - 1) / y
}

View File

@@ -1,4 +1,3 @@
pub mod backend;
pub mod model; pub mod model;
pub mod tokenizer; pub mod tokenizer;
pub mod helper;
pub mod binrecorderfast;

View File

@@ -1,23 +1,32 @@
use burn::{ use burn::tensor::{activation::softmax, backend::Backend, Tensor};
tensor::{
backend::Backend,
activation::softmax,
Tensor,
},
};
use std::f32::NEG_INFINITY; use std::f32::NEG_INFINITY;
pub fn qkv_attention<B: Backend>(q: Tensor<B, 3>, k: Tensor<B, 3>, v: Tensor<B, 3>, mask: Option<Tensor<B, 2>>, n_head: usize) -> Tensor<B, 3> { pub fn qkv_attention<B: Backend>(
q: Tensor<B, 3>,
k: Tensor<B, 3>,
v: Tensor<B, 3>,
mask: Option<Tensor<B, 2>>,
n_head: usize,
) -> Tensor<B, 3> {
let [n_batch, n_qctx, n_state] = q.dims(); let [n_batch, n_qctx, n_state] = q.dims();
let [_, n_ctx, _] = k.dims(); let [_, n_ctx, _] = k.dims();
let scale = (n_state as f64 / n_head as f64).powf(-0.25); let scale = (n_state as f64 / n_head as f64).powf(-0.25);
let n_hstate = n_state / n_head; let n_hstate = n_state / n_head;
let q = q.reshape([n_batch, n_qctx, n_head, n_hstate]).swap_dims(1, 2) * scale; let q = q
let k = k.reshape([n_batch, n_ctx, n_head, n_hstate]).swap_dims(1, 2).transpose() * scale; .reshape([n_batch, n_qctx, n_head, n_hstate])
let v = v.reshape([n_batch, n_ctx, n_head, n_hstate]).swap_dims(1, 2); .swap_dims(1, 2)
* scale;
let k = k
.reshape([n_batch, n_ctx, n_head, n_hstate])
.swap_dims(1, 2)
.transpose()
* scale;
let v = v
.reshape([n_batch, n_ctx, n_head, n_hstate])
.swap_dims(1, 2);
let qk = q.matmul(k); let qk = q.matmul(k);
@@ -36,12 +45,12 @@ pub fn qkv_attention<B: Backend>(q: Tensor<B, 3>, k: Tensor<B, 3>, v: Tensor<B,
} }
pub fn attn_decoder_mask<B: Backend>(seq_length: usize, device: &B::Device) -> Tensor<B, 2> { pub fn attn_decoder_mask<B: Backend>(seq_length: usize, device: &B::Device) -> Tensor<B, 2> {
let mut mask = Tensor::<B, 2>::zeros([seq_length, seq_length]); let mut mask = Tensor::<B, 2>::zeros([seq_length, seq_length], device);
for i in 0..(seq_length - 1) { for i in 0..(seq_length - 1) {
let values = Tensor::<B, 2>::zeros([1, seq_length - (i + 1)]).add_scalar(NEG_INFINITY); let values = Tensor::<B, 2>::zeros([1, seq_length - (i + 1)], device).add_scalar(NEG_INFINITY);
mask = mask.slice_assign([i..i + 1, i + 1..seq_length], values); mask = mask.slice_assign([i..i + 1, i + 1..seq_length], values);
} }
return mask.to_device(device); return mask;
} }

View File

@@ -7,26 +7,35 @@ use burn::{
config::Config, config::Config,
module::{Module, Param}, module::{Module, Param},
nn, nn,
tensor::{ tensor::{backend::Backend, Tensor},
backend::Backend,
Tensor,
},
}; };
use super::*; use super::*;
use crate::model::groupnorm::load::load_group_norm; use crate::model::groupnorm::load::load_group_norm;
fn load_conv_self_attention_block<B: Backend>(path: &str, device: &B::Device) -> Result<ConvSelfAttentionBlock<B>, Box<dyn Error>> { fn load_conv_self_attention_block<B: Backend>(
path: &str,
device: &B::Device,
) -> Result<ConvSelfAttentionBlock<B>, Box<dyn Error>> {
let norm = load_group_norm(&format!("{}/{}", path, "norm"), device)?; let norm = load_group_norm(&format!("{}/{}", path, "norm"), device)?;
let q = load_conv2d(&format!("{}/{}", path, "q"), device)?; let q = load_conv2d(&format!("{}/{}", path, "q"), device)?;
let k = load_conv2d(&format!("{}/{}", path, "k"), device)?; let k = load_conv2d(&format!("{}/{}", path, "k"), device)?;
let v = load_conv2d(&format!("{}/{}", path, "v"), device)?; let v = load_conv2d(&format!("{}/{}", path, "v"), device)?;
let proj_out = load_conv2d(&format!("{}/{}", path, "proj_out"), device)?; let proj_out = load_conv2d(&format!("{}/{}", path, "proj_out"), device)?;
Ok(ConvSelfAttentionBlock { norm, q, k, v, proj_out }) Ok(ConvSelfAttentionBlock {
norm,
q,
k,
v,
proj_out,
})
} }
fn load_resnet_block<B: Backend>(path: &str, device: &B::Device) -> Result<ResnetBlock<B>, Box<dyn Error>> { fn load_resnet_block<B: Backend>(
path: &str,
device: &B::Device,
) -> Result<ResnetBlock<B>, Box<dyn Error>> {
let norm1 = load_group_norm(&format!("{}/{}", path, "norm1"), device)?; let norm1 = load_group_norm(&format!("{}/{}", path, "norm1"), device)?;
let silu1 = SILU {}; let silu1 = SILU {};
let conv1 = load_conv2d(&format!("{}/{}", path, "conv1"), device)?; let conv1 = load_conv2d(&format!("{}/{}", path, "conv1"), device)?;
@@ -35,7 +44,15 @@ fn load_resnet_block<B: Backend>(path: &str, device: &B::Device) -> Result<Resne
let conv2 = load_conv2d(&format!("{}/{}", path, "conv2"), device)?; let conv2 = load_conv2d(&format!("{}/{}", path, "conv2"), device)?;
let nin_shortcut = load_conv2d(&format!("{}/{}", path, "nin_shortcut"), device).ok(); let nin_shortcut = load_conv2d(&format!("{}/{}", path, "nin_shortcut"), device).ok();
Ok(ResnetBlock { norm1, silu1, conv1, norm2, silu2, conv2, nin_shortcut }) Ok(ResnetBlock {
norm1,
silu1,
conv1,
norm2,
silu2,
conv2,
nin_shortcut,
})
} }
fn load_mid<B: Backend>(path: &str, device: &B::Device) -> Result<Mid<B>, Box<dyn Error>> { fn load_mid<B: Backend>(path: &str, device: &B::Device) -> Result<Mid<B>, Box<dyn Error>> {
@@ -43,11 +60,18 @@ fn load_mid<B: Backend>(path: &str, device: &B::Device) -> Result<Mid<B>, Box<dy
let attn = load_conv_self_attention_block(&format!("{}/{}", path, "attn"), device)?; let attn = load_conv_self_attention_block(&format!("{}/{}", path, "attn"), device)?;
let block_2 = load_resnet_block(&format!("{}/{}", path, "block_2"), device)?; let block_2 = load_resnet_block(&format!("{}/{}", path, "block_2"), device)?;
Ok(Mid { block_1, attn, block_2 }) Ok(Mid {
block_1,
attn,
block_2,
})
} }
fn load_padded_conv2d<B: Backend>(path: &str, device: &B::Device) -> Result<PaddedConv2d<B>, Box<dyn Error>> { fn load_padded_conv2d<B: Backend>(
let conv = load_conv2d(&format!("{}/{}", path, "conv"), device)?; path: &str,
device: &B::Device,
) -> Result<PaddedConv2d<B>, Box<dyn Error>> {
let mut conv = load_conv2d(&format!("{}/{}", path, "conv"), device)?;
let channels = load_tensor::<B, 1>("channels", path, device)?; let channels = load_tensor::<B, 1>("channels", path, device)?;
let channels = tensor_to_array_2(channels); let channels = tensor_to_array_2(channels);
@@ -57,35 +81,55 @@ fn load_padded_conv2d<B: Backend>(path: &str, device: &B::Device) -> Result<Padd
let padding = load_tensor::<B, 1>("padding", path, device)?; let padding = load_tensor::<B, 1>("padding", path, device)?;
let padding: [usize; 4] = tensor_to_array(padding); let padding: [usize; 4] = tensor_to_array(padding);
let padding = Padding::new(padding[0], padding[1], padding[2], padding[3]); let padding = PaddingCfg::new(padding[0], padding[1], padding[2], padding[3]);
let mut record = conv.into_record(); //let mut record = conv.into_record();
let mut padded_conv: PaddedConv2d<B> = PaddedConv2dConfig::new(channels, kernel_size, padding).with_stride(stride).init(); let mut padded_conv: PaddedConv2d<B> = PaddedConv2dConfig::new(channels, kernel_size, padding)
let padding_actual = PaddingConfig2d::Explicit(padded_conv.padding_actual[0], padded_conv.padding_actual[1]); .with_stride(stride)
.init(device);
let padding_actual =
PaddingConfig2d::Explicit(padded_conv.padding_actual[0], padded_conv.padding_actual[1]);
record.padding = <PaddingConfig2d as Module<B>>::into_record(padding_actual); conv.padding = burn::module::Ignored(padding_actual);
padded_conv.conv = padded_conv.conv.load_record(record); padded_conv.conv = conv;
//record.padding = <PaddingConfig2d as Module<B>>::into_record(padding_actual);
//padded_conv.conv = padded_conv.conv.load_record(record);
Ok(padded_conv) Ok(padded_conv)
} }
fn load_decoder_block<B: Backend>(path: &str, device: &B::Device) -> Result<DecoderBlock<B>, Box<dyn Error>> { fn load_decoder_block<B: Backend>(
path: &str,
device: &B::Device,
) -> Result<DecoderBlock<B>, Box<dyn Error>> {
let res1 = load_resnet_block(&format!("{}/{}", path, "res1"), device)?; let res1 = load_resnet_block(&format!("{}/{}", path, "res1"), device)?;
let res2 = load_resnet_block(&format!("{}/{}", path, "res2"), device)?; let res2 = load_resnet_block(&format!("{}/{}", path, "res2"), device)?;
let res3 = load_resnet_block(&format!("{}/{}", path, "res3"), device)?; let res3 = load_resnet_block(&format!("{}/{}", path, "res3"), device)?;
let upsampler = load_conv2d(&format!("{}/{}", path, "upsampler"), device).ok(); let upsampler = load_conv2d(&format!("{}/{}", path, "upsampler"), device).ok();
Ok(DecoderBlock { res1, res2, res3, upsampler }) Ok(DecoderBlock {
res1,
res2,
res3,
upsampler,
})
} }
fn load_encoder_block<B: Backend>(path: &str, device: &B::Device) -> Result<EncoderBlock<B>, Box<dyn Error>> { fn load_encoder_block<B: Backend>(
path: &str,
device: &B::Device,
) -> Result<EncoderBlock<B>, Box<dyn Error>> {
let res1 = load_resnet_block(&format!("{}/{}", path, "res1"), device)?; let res1 = load_resnet_block(&format!("{}/{}", path, "res1"), device)?;
let res2 = load_resnet_block(&format!("{}/{}", path, "res2"), device)?; let res2 = load_resnet_block(&format!("{}/{}", path, "res2"), device)?;
let downsampler = load_padded_conv2d(&format!("{}/{}", path, "downsampler"), device).ok(); let downsampler = load_padded_conv2d(&format!("{}/{}", path, "downsampler"), device).ok();
Ok(EncoderBlock { res1, res2, downsampler }) Ok(EncoderBlock {
res1,
res2,
downsampler,
})
} }
fn load_decoder<B: Backend>(path: &str, device: &B::Device) -> Result<Decoder<B>, Box<dyn Error>> { fn load_decoder<B: Backend>(path: &str, device: &B::Device) -> Result<Decoder<B>, Box<dyn Error>> {
@@ -95,15 +139,21 @@ fn load_decoder<B: Backend>(path: &str, device: &B::Device) -> Result<Decoder<B>
let n_block = load_usize::<B>("n_block", path, device)?; let n_block = load_usize::<B>("n_block", path, device)?;
let mut blocks = (0..n_block) let mut blocks = (0..n_block)
.into_iter() .into_iter()
.map(|i| { .map(|i| load_decoder_block::<B>(&format!("{}/blocks/{}", path, i), device))
load_decoder_block::<B>(&format!("{}/blocks/{}", path, i), device) .collect::<Result<Vec<_>, _>>()?;
}).collect::<Result<Vec<_>, _>>()?;
let norm_out = load_group_norm(&format!("{}/{}", path, "norm_out"), device)?; let norm_out = load_group_norm(&format!("{}/{}", path, "norm_out"), device)?;
let silu = SILU {}; let silu = SILU {};
let conv_out = load_conv2d(&format!("{}/{}", path, "conv_out"), device)?; let conv_out = load_conv2d(&format!("{}/{}", path, "conv_out"), device)?;
Ok(Decoder { conv_in, mid, blocks, norm_out, silu, conv_out }) Ok(Decoder {
conv_in,
mid,
blocks,
norm_out,
silu,
conv_out,
})
} }
fn load_encoder<B: Backend>(path: &str, device: &B::Device) -> Result<Encoder<B>, Box<dyn Error>> { fn load_encoder<B: Backend>(path: &str, device: &B::Device) -> Result<Encoder<B>, Box<dyn Error>> {
@@ -113,22 +163,36 @@ fn load_encoder<B: Backend>(path: &str, device: &B::Device) -> Result<Encoder<B>
let n_block = load_usize::<B>("n_block", path, device)?; let n_block = load_usize::<B>("n_block", path, device)?;
let mut blocks = (0..n_block) let mut blocks = (0..n_block)
.into_iter() .into_iter()
.map(|i| { .map(|i| load_encoder_block::<B>(&format!("{}/blocks/{}", path, i), device))
load_encoder_block::<B>(&format!("{}/blocks/{}", path, i), device) .collect::<Result<Vec<_>, _>>()?;
}).collect::<Result<Vec<_>, _>>()?;
let norm_out = load_group_norm(&format!("{}/{}", path, "norm_out"), device)?; let norm_out = load_group_norm(&format!("{}/{}", path, "norm_out"), device)?;
let silu = SILU {}; let silu = SILU {};
let conv_out = load_conv2d(&format!("{}/{}", path, "conv_out"), device)?; let conv_out = load_conv2d(&format!("{}/{}", path, "conv_out"), device)?;
Ok(Encoder { conv_in, mid, blocks, norm_out, silu, conv_out }) Ok(Encoder {
conv_in,
mid,
blocks,
norm_out,
silu,
conv_out,
})
} }
pub fn load_autoencoder<B: Backend>(path: &str, device: &B::Device) -> Result<Autoencoder<B>, Box<dyn Error>> { pub fn load_autoencoder<B: Backend>(
path: &str,
device: &B::Device,
) -> Result<Autoencoder<B>, Box<dyn Error>> {
let encoder = load_encoder(&format!("{}/{}", path, "encoder"), device)?; let encoder = load_encoder(&format!("{}/{}", path, "encoder"), device)?;
let decoder = load_decoder(&format!("{}/{}", path, "decoder"), device)?; let decoder = load_decoder(&format!("{}/{}", path, "decoder"), device)?;
let quant_conv = load_conv2d(&format!("{}/{}", path, "quant_conv"), device)?; let quant_conv = load_conv2d(&format!("{}/{}", path, "quant_conv"), device)?;
let post_quant_conv = load_conv2d(&format!("{}/{}", path, "post_quant_conv"), device)?; let post_quant_conv = load_conv2d(&format!("{}/{}", path, "post_quant_conv"), device)?;
Ok(Autoencoder { encoder, decoder, quant_conv, post_quant_conv }) Ok(Autoencoder {
encoder,
decoder,
quant_conv,
post_quant_conv,
})
} }

View File

@@ -3,35 +3,37 @@ pub mod load;
use burn::{ use burn::{
config::Config, config::Config,
module::{Module, Param}, module::{Module, Param},
nn::{self, PaddingConfig2d, conv::{Conv2d, Conv2dConfig, Conv2dRecord}}, nn::{
self,
conv::{Conv2d, Conv2dConfig, Conv2dRecord},
PaddingConfig2d,
},
tensor::{ tensor::{
activation::{sigmoid, softmax},
backend::Backend, backend::Backend,
activation::{softmax, sigmoid},
module::embedding, module::embedding,
Tensor, Distribution, Int, Tensor,
Distribution,
Int,
}, },
}; };
use crate::helper::div_roundup;
use super::silu::*;
use super::groupnorm::*; use super::groupnorm::*;
use super::attention::qkv_attention; use super::silu::*;
//use crate::backend::Backend as MyBackend;
use crate::backend::{qkv_attention, attn_decoder_mask};
use std::iter; use std::iter;
#[derive(Config, Debug)]
#[derive(Config)]
pub struct AutoencoderConfig {} pub struct AutoencoderConfig {}
impl AutoencoderConfig { impl AutoencoderConfig {
pub fn init<B: Backend>(&self) -> Autoencoder<B> { pub fn init<B: Backend>(&self, device: &B::Device) -> Autoencoder<B> {
let encoder = EncoderConfig::new(vec![(128, 128), (128, 256), (256, 512), (512, 512)], 32, 8).init(); let encoder =
let decoder = DecoderConfig::new(vec![(512, 512), (512, 512), (512, 256), (256, 128)], 32).init(); EncoderConfig::new(vec![(128, 128), (128, 256), (256, 512), (512, 512)], 32, 8).init(device);
let quant_conv = Conv2dConfig::new([8, 8], [1, 1]).init(); let decoder =
let post_quant_conv = Conv2dConfig::new([4, 4], [1, 1]).init(); DecoderConfig::new(vec![(512, 512), (512, 512), (512, 256), (256, 128)], 32).init(device);
let quant_conv = Conv2dConfig::new([8, 8], [1, 1]).init(device);
let post_quant_conv = Conv2dConfig::new([4, 4], [1, 1]).init(device);
Autoencoder { Autoencoder {
encoder, encoder,
@@ -42,7 +44,6 @@ impl AutoencoderConfig {
} }
} }
#[derive(Module, Debug)] #[derive(Module, Debug)]
pub struct Autoencoder<B: Backend> { pub struct Autoencoder<B: Backend> {
encoder: Encoder<B>, encoder: Encoder<B>,
@@ -53,7 +54,7 @@ pub struct Autoencoder<B: Backend> {
impl<B: Backend> Autoencoder<B> { impl<B: Backend> Autoencoder<B> {
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> { pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
self.decode_latent( self.encode_image(x) ) self.decode_latent(self.encode_image(x))
} }
pub fn encode_image(&self, x: Tensor<B, 4>) -> Tensor<B, 4> { pub fn encode_image(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
@@ -70,7 +71,7 @@ impl<B: Backend> Autoencoder<B> {
} }
} }
#[derive(Config)] #[derive(Config, Debug)]
pub struct EncoderConfig { pub struct EncoderConfig {
channels: Vec<(usize, usize)>, channels: Vec<(usize, usize)>,
n_group: usize, n_group: usize,
@@ -78,21 +79,34 @@ pub struct EncoderConfig {
} }
impl EncoderConfig { impl EncoderConfig {
fn init<B: Backend>(&self) -> Encoder<B> { fn init<B: Backend>(&self, device: &B::Device) -> Encoder<B> {
let n_expanded_channels_initial = self.channels.first().map(|f| f.1).expect("Channels must not be empty."); let n_expanded_channels_initial = self
.channels
.first()
.map(|f| f.1)
.expect("Channels must not be empty.");
let n_expanded_channels_final = self.channels.first().unwrap().0; let n_expanded_channels_final = self.channels.first().unwrap().0;
let conv_in = Conv2dConfig::new([3, n_expanded_channels_initial], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init(); let conv_in = Conv2dConfig::new([3, n_expanded_channels_initial], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init(device);
let blocks = self.channels.iter().enumerate().map(|(i, &(n_channel_in, n_channel_out))| { let blocks = self
let downsample = i != self.channels.len() - 1; .channels
EncoderBlockConfig::new(n_channel_in, n_channel_out, downsample).init() .iter()
}).collect(); .enumerate()
.map(|(i, &(n_channel_in, n_channel_out))| {
let downsample = i != self.channels.len() - 1;
EncoderBlockConfig::new(n_channel_in, n_channel_out, downsample).init(device)
})
.collect();
let mid = MidConfig::new(n_expanded_channels_final).init(); let mid = MidConfig::new(n_expanded_channels_final).init(device);
let norm_out = GroupNormConfig::new(self.n_group, n_expanded_channels_final).init(); let norm_out = GroupNormConfig::new(self.n_group, n_expanded_channels_final).init(device);
let silu = SILU::new(); let silu = SILU::new();
let conv_out = Conv2dConfig::new([n_expanded_channels_final, self.n_channels_out], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init(); let conv_out = Conv2dConfig::new([n_expanded_channels_final, self.n_channels_out], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init(device);
Encoder { Encoder {
conv_in, conv_in,
@@ -105,7 +119,6 @@ impl EncoderConfig {
} }
} }
#[derive(Module, Debug)] #[derive(Module, Debug)]
pub struct Encoder<B: Backend> { pub struct Encoder<B: Backend> {
conv_in: Conv2d<B>, conv_in: Conv2d<B>,
@@ -126,34 +139,46 @@ impl<B: Backend> Encoder<B> {
} }
let x = self.mid.forward(x); let x = self.mid.forward(x);
self.conv_out.forward( self.silu.forward( self.norm_out.forward(x) ) ) self.conv_out
.forward(self.silu.forward(self.norm_out.forward(x)))
} }
} }
#[derive(Config, Debug)]
#[derive(Config)]
pub struct DecoderConfig { pub struct DecoderConfig {
channels: Vec<(usize, usize)>, channels: Vec<(usize, usize)>,
n_group: usize, n_group: usize,
} }
impl DecoderConfig { impl DecoderConfig {
fn init<B: Backend>(&self) -> Decoder<B> { fn init<B: Backend>(&self, device: &B::Device) -> Decoder<B> {
let n_expanded_channels = self.channels.first().map(|f| f.0).expect("Channels must not be empty."); let n_expanded_channels = self
.channels
.first()
.map(|f| f.0)
.expect("Channels must not be empty.");
let n_condensed_channels = self.channels.last().unwrap().1; let n_condensed_channels = self.channels.last().unwrap().1;
let conv_in = Conv2dConfig::new([4, n_expanded_channels], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init(); let conv_in = Conv2dConfig::new([4, n_expanded_channels], [3, 3])
let mid = MidConfig::new(n_expanded_channels).init(); .with_padding(PaddingConfig2d::Explicit(1, 1))
.init(device);
let mid = MidConfig::new(n_expanded_channels).init(device);
let blocks = self.channels.iter().enumerate().map(|(i, &(n_channel_in, n_channel_out))| { let blocks = self
let upsample = i != self.channels.len() - 1; .channels
DecoderBlockConfig::new(n_channel_in, n_channel_out, upsample).init() .iter()
}).collect(); .enumerate()
.map(|(i, &(n_channel_in, n_channel_out))| {
let upsample = i != self.channels.len() - 1;
DecoderBlockConfig::new(n_channel_in, n_channel_out, upsample).init(device)
})
.collect();
let norm_out = GroupNormConfig::new(self.n_group, n_condensed_channels).init(); let norm_out = GroupNormConfig::new(self.n_group, n_condensed_channels).init(device);
let silu = SILU::new(); let silu = SILU::new();
let conv_out = Conv2dConfig::new([n_condensed_channels, 3], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init(); let conv_out = Conv2dConfig::new([n_condensed_channels, 3], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init(device);
Decoder { Decoder {
conv_in, conv_in,
@@ -166,7 +191,6 @@ impl DecoderConfig {
} }
} }
#[derive(Module, Debug)] #[derive(Module, Debug)]
pub struct Decoder<B: Backend> { pub struct Decoder<B: Backend> {
conv_in: Conv2d<B>, conv_in: Conv2d<B>,
@@ -187,11 +211,12 @@ impl<B: Backend> Decoder<B> {
x = block.forward(x); x = block.forward(x);
} }
self.conv_out.forward( self.silu.forward( self.norm_out.forward(x) ) ) self.conv_out
.forward(self.silu.forward(self.norm_out.forward(x)))
} }
} }
#[derive(Config)] #[derive(Config, Debug)]
pub struct EncoderBlockConfig { pub struct EncoderBlockConfig {
n_channels_in: usize, n_channels_in: usize,
n_channels_out: usize, n_channels_out: usize,
@@ -199,12 +224,16 @@ pub struct EncoderBlockConfig {
} }
impl EncoderBlockConfig { impl EncoderBlockConfig {
fn init<B: Backend>(&self) -> EncoderBlock<B> { fn init<B: Backend>(&self, device: &B::Device) -> EncoderBlock<B> {
let res1 = ResnetBlockConfig::new(self.n_channels_in, self.n_channels_out).init(); let res1 = ResnetBlockConfig::new(self.n_channels_in, self.n_channels_out).init(device);
let res2 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init(); let res2 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init(device);
let downsampler = if self.downsample { let downsampler = if self.downsample {
let padding = Padding::new(0, 1, 0, 1); let padding = PaddingCfg::new(0, 1, 0, 1);
Some( PaddedConv2dConfig::new([self.n_channels_out, self.n_channels_out], 3, padding).with_stride(2).init() ) Some(
PaddedConv2dConfig::new([self.n_channels_out, self.n_channels_out], 3, padding)
.with_stride(2)
.init(device),
)
} else { } else {
None None
}; };
@@ -236,7 +265,7 @@ impl<B: Backend> EncoderBlock<B> {
} }
} }
#[derive(Config)] #[derive(Config, Debug)]
pub struct DecoderBlockConfig { pub struct DecoderBlockConfig {
n_channels_in: usize, n_channels_in: usize,
n_channels_out: usize, n_channels_out: usize,
@@ -244,12 +273,16 @@ pub struct DecoderBlockConfig {
} }
impl DecoderBlockConfig { impl DecoderBlockConfig {
fn init<B: Backend>(&self) -> DecoderBlock<B> { fn init<B: Backend>(&self, device: &B::Device) -> DecoderBlock<B> {
let res1 = ResnetBlockConfig::new(self.n_channels_in, self.n_channels_out).init(); let res1 = ResnetBlockConfig::new(self.n_channels_in, self.n_channels_out).init(device);
let res2 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init(); let res2 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init(device);
let res3 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init(); let res3 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init(device);
let upsampler = if self.upsample { let upsampler = if self.upsample {
Some( Conv2dConfig::new([self.n_channels_out, self.n_channels_out], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init() ) Some(
Conv2dConfig::new([self.n_channels_out, self.n_channels_out], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init(device),
)
} else { } else {
None None
}; };
@@ -280,10 +313,9 @@ impl<B: Backend> DecoderBlock<B> {
if let Some(d) = self.upsampler.as_ref() { if let Some(d) = self.upsampler.as_ref() {
let [n_batch, n_channel, height, width] = x.dims(); let [n_batch, n_channel, height, width] = x.dims();
let x = x let x = x
.reshape([n_batch, n_channel, height, 1, width, 1]) .reshape([n_batch, n_channel, height, 1, width, 1])
.repeat(3, 2) .repeat(&[1, 1, 1, 2, 1, 2])
.repeat(5, 2) .reshape([n_batch, n_channel, 2 * height, 2 * width]);
.reshape([n_batch, n_channel, 2 * height, 2 * width]);
d.forward(x) d.forward(x)
} else { } else {
x x
@@ -291,18 +323,17 @@ impl<B: Backend> DecoderBlock<B> {
} }
} }
#[derive(Config, Debug)]
#[derive(Config)]
pub struct PaddedConv2dConfig { pub struct PaddedConv2dConfig {
channels: [usize; 2], channels: [usize; 2],
kernel_size: usize, kernel_size: usize,
#[config(default = 1)] #[config(default = 1)]
stride: usize, stride: usize,
padding: Padding, padding: PaddingCfg,
} }
impl PaddedConv2dConfig { impl PaddedConv2dConfig {
fn init<B: Backend>(&self) -> PaddedConv2d<B> { fn init<B: Backend>(&self, device: &B::Device) -> PaddedConv2d<B> {
let calc_padding = |p_left, p_right| { let calc_padding = |p_left, p_right| {
let n = if p_left >= p_right { let n = if p_left >= p_right {
0 0
@@ -320,12 +351,17 @@ impl PaddedConv2dConfig {
let conv = Conv2dConfig::new(self.channels, [self.kernel_size, self.kernel_size]) let conv = Conv2dConfig::new(self.channels, [self.kernel_size, self.kernel_size])
.with_stride([self.stride, self.stride]) .with_stride([self.stride, self.stride])
.with_padding(PaddingConfig2d::Explicit(pad_vertical, pad_horizontal)) .with_padding(PaddingConfig2d::Explicit(pad_vertical, pad_horizontal))
.init(); .init(device);
let kernel_size = self.kernel_size; let kernel_size = self.kernel_size;
let stride = self.stride; let stride = self.stride;
let padding = self.padding; let padding = Padding {
pad_left: self.padding.pad_left,
pad_right: self.padding.pad_right,
pad_top: self.padding.pad_top,
pad_bottom: self.padding.pad_bottom,
};
PaddedConv2d { PaddedConv2d {
conv, conv,
@@ -337,6 +373,10 @@ impl PaddedConv2dConfig {
} }
} }
fn div_roundup(x: usize, y: usize) -> usize {
(x + y - 1) / y
}
#[derive(Module, Debug)] #[derive(Module, Debug)]
pub struct PaddedConv2d<B: Backend> { pub struct PaddedConv2d<B: Backend> {
conv: Conv2d<B>, conv: Conv2d<B>,
@@ -348,27 +388,38 @@ pub struct PaddedConv2d<B: Backend> {
impl<B: Backend> PaddedConv2d<B> { impl<B: Backend> PaddedConv2d<B> {
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> { fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
println!("{} {} {:?} {:?}", self.kernel_size, self.stride, self.padding, self.padding_actual);
let [n_batch, n_channel, height, width] = x.dims(); let [n_batch, n_channel, height, width] = x.dims();
let desired_height = (self.padding.pad_top + self.padding.pad_bottom + height - self.kernel_size) / self.stride + 1; let desired_height = (self.padding.pad_top + self.padding.pad_bottom + height
let desired_width = (self.padding.pad_left + self.padding.pad_right + width - self.kernel_size) / self.stride + 1; - self.kernel_size)
/ self.stride
+ 1;
let desired_width = (self.padding.pad_left + self.padding.pad_right + width
- self.kernel_size)
/ self.stride
+ 1;
let skip_vert = (self.padding_actual[0] - self.padding.pad_top) / self.stride; let skip_vert = (self.padding_actual[0] - self.padding.pad_top) / self.stride;
let skip_hor = (self.padding_actual[1] - self.padding.pad_left) / self.stride; let skip_hor = (self.padding_actual[1] - self.padding.pad_left) / self.stride;
self.conv self.conv.forward(x).slice([
.forward(x) 0..n_batch,
.slice([ 0..n_channel,
0..n_batch, skip_vert..(skip_vert + desired_height),
0..n_channel, skip_hor..(skip_hor + desired_width),
skip_vert..(skip_vert + desired_height), ])
skip_hor..(skip_hor + desired_width)
])
} }
} }
#[derive(Config, Module, Copy, Debug)] #[derive(Config, Debug)]
pub struct PaddingCfg {
pad_left: usize,
pad_right: usize,
pad_top: usize,
pad_bottom: usize,
}
#[derive(Module, Clone, Debug)]
pub struct Padding { pub struct Padding {
pad_left: usize, pad_left: usize,
pad_right: usize, pad_right: usize,
@@ -376,16 +427,16 @@ pub struct Padding {
pad_bottom: usize, pad_bottom: usize,
} }
#[derive(Config)] #[derive(Config, Debug)]
pub struct MidConfig { pub struct MidConfig {
n_channel: usize, n_channel: usize,
} }
impl MidConfig { impl MidConfig {
fn init<B: Backend>(&self) -> Mid<B> { fn init<B: Backend>(&self, device: &B::Device) -> Mid<B> {
let block_1 = ResnetBlockConfig::new(self.n_channel, self.n_channel).init(); let block_1 = ResnetBlockConfig::new(self.n_channel, self.n_channel).init(device);
let attn = ConvSelfAttentionBlockConfig::new(self.n_channel).init(); let attn = ConvSelfAttentionBlockConfig::new(self.n_channel).init(device);
let block_2 = ResnetBlockConfig::new(self.n_channel, self.n_channel).init(); let block_2 = ResnetBlockConfig::new(self.n_channel, self.n_channel).init(device);
Mid { Mid {
block_1, block_1,
@@ -411,21 +462,24 @@ impl<B: Backend> Mid<B> {
} }
} }
#[derive(Config, Debug)]
#[derive(Config)]
pub struct ResnetBlockConfig { pub struct ResnetBlockConfig {
in_channels: usize, in_channels: usize,
out_channels: usize, out_channels: usize,
} }
impl ResnetBlockConfig { impl ResnetBlockConfig {
fn init<B: Backend>(&self) -> ResnetBlock<B> { fn init<B: Backend>(&self, device: &B::Device) -> ResnetBlock<B> {
let norm1 = GroupNormConfig::new(32, self.in_channels).init(); let norm1 = GroupNormConfig::new(32, self.in_channels).init(device);
let conv1 = Conv2dConfig::new([self.in_channels, self.out_channels], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init(); let conv1 = Conv2dConfig::new([self.in_channels, self.out_channels], [3, 3])
let norm2 = GroupNormConfig::new(32, self.out_channels).init(); .with_padding(PaddingConfig2d::Explicit(1, 1))
let conv2 = Conv2dConfig::new([self.out_channels, self.out_channels], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init(); .init(device);
let norm2 = GroupNormConfig::new(32, self.out_channels).init(device);
let conv2 = Conv2dConfig::new([self.out_channels, self.out_channels], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init(device);
let nin_shortcut = if self.in_channels != self.out_channels { let nin_shortcut = if self.in_channels != self.out_channels {
Some( Conv2dConfig::new([self.in_channels, self.out_channels], [1, 1]).init() ) Some(Conv2dConfig::new([self.in_channels, self.out_channels], [1, 1]).init(device))
} else { } else {
None None
}; };
@@ -458,9 +512,12 @@ pub struct ResnetBlock<B: Backend> {
impl<B: Backend> ResnetBlock<B> { impl<B: Backend> ResnetBlock<B> {
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> { fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
let h = self.conv1.forward( self.silu1.forward(self.norm1.forward(x.clone())) ); let h = self
let h = self.conv2.forward( self.silu2.forward(self.norm2.forward(h)) ); .conv1
.forward(self.silu1.forward(self.norm1.forward(x.clone())));
let h = self
.conv2
.forward(self.silu2.forward(self.norm2.forward(h)));
if let Some(ns) = self.nin_shortcut.as_ref() { if let Some(ns) = self.nin_shortcut.as_ref() {
ns.forward(x) + h ns.forward(x) + h
@@ -470,18 +527,18 @@ impl<B: Backend> ResnetBlock<B> {
} }
} }
#[derive(Config)] #[derive(Config, Debug)]
pub struct ConvSelfAttentionBlockConfig { pub struct ConvSelfAttentionBlockConfig {
n_channel: usize, n_channel: usize,
} }
impl ConvSelfAttentionBlockConfig { impl ConvSelfAttentionBlockConfig {
fn init<B: Backend>(&self) -> ConvSelfAttentionBlock<B> { fn init<B: Backend>(&self, device: &B::Device) -> ConvSelfAttentionBlock<B> {
let norm = GroupNormConfig::new(32, self.n_channel).init(); let norm = GroupNormConfig::new(32, self.n_channel).init(device);
let q = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init(); let q = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init(device);
let k = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init(); let k = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init(device);
let v = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init(); let v = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init(device);
let proj_out = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init(); let proj_out = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init(device);
ConvSelfAttentionBlock { ConvSelfAttentionBlock {
norm, norm,
@@ -508,13 +565,41 @@ impl<B: Backend> ConvSelfAttentionBlock<B> {
let h = self.norm.forward(x.clone()); let h = self.norm.forward(x.clone());
let q = self.q.forward(h.clone()).reshape([n_batch, n_channel, height * width]).swap_dims(1, 2); let q = self
let k = self.k.forward(h.clone()).reshape([n_batch, n_channel, height * width]).swap_dims(1, 2); .q
let v = self.v.forward(h).reshape([n_batch, n_channel, height * width]).swap_dims(1, 2); .forward(h.clone())
.reshape([n_batch, n_channel, height * width])
.swap_dims(1, 2);
let k = self
.k
.forward(h.clone())
.reshape([n_batch, n_channel, height * width])
.swap_dims(1, 2);
let v = self
.v
.forward(h)
.reshape([n_batch, n_channel, height * width])
.swap_dims(1, 2);
let wv = qkv_attention(q, k, v, None, 1) /*let wv = Tensor::from_primitive(B::qkv_attention(
.swap_dims(1, 2) q.into_primitive(),
.reshape([n_batch, n_channel, height, width]); k.into_primitive(),
v.into_primitive(),
None,
1,
))
.swap_dims(1, 2)
.reshape([n_batch, n_channel, height, width]);*/
let wv = qkv_attention(
q,
k,
v,
None,
1,
)
.swap_dims(1, 2)
.reshape([n_batch, n_channel, height, width]);
let projected = self.proj_out.forward(wv); let projected = self.proj_out.forward(wv);

View File

@@ -1,14 +1,11 @@
use std::error::Error;
use burn::tensor::ElementConversion; use burn::tensor::ElementConversion;
use std::error::Error;
use burn::{ use burn::{
config::Config, config::Config,
module::{Module, Param}, module::{Module, Param},
nn, nn,
tensor::{ tensor::{backend::Backend, Tensor},
backend::Backend,
Tensor,
},
}; };
use super::*; use super::*;
@@ -28,7 +25,10 @@ pub fn load_mlp<B: Backend>(path: &str, device: &B::Device) -> Result<MLP<B>, Bo
Ok(mlp) Ok(mlp)
} }
pub fn load_multi_head_self_attention<B: Backend>(path: &str, device: &B::Device) -> Result<MultiHeadSelfAttention<B>, Box<dyn Error>> { pub fn load_multi_head_self_attention<B: Backend>(
path: &str,
device: &B::Device,
) -> Result<MultiHeadSelfAttention<B>, Box<dyn Error>> {
let n_head = load_usize::<B>("n_head", path, device)?; let n_head = load_usize::<B>("n_head", path, device)?;
let query = load_linear(&format!("{}/{}", path, "query"), device)?; let query = load_linear(&format!("{}/{}", path, "query"), device)?;
let key = load_linear(&format!("{}/{}", path, "key"), device)?; let key = load_linear(&format!("{}/{}", path, "key"), device)?;
@@ -46,7 +46,10 @@ pub fn load_multi_head_self_attention<B: Backend>(path: &str, device: &B::Device
Ok(mhsa) Ok(mhsa)
} }
pub fn load_residual_decoder_attention_block<B: Backend>(path: &str, device: &B::Device) -> Result<ResidualDecoderAttentionBlock<B>, Box<dyn Error>> { pub fn load_residual_decoder_attention_block<B: Backend>(
path: &str,
device: &B::Device,
) -> Result<ResidualDecoderAttentionBlock<B>, Box<dyn Error>> {
let mlp = load_mlp(&format!("{}/{}", path, "mlp"), device)?; let mlp = load_mlp(&format!("{}/{}", path, "mlp"), device)?;
let attn = load_multi_head_self_attention(&format!("{}/{}", path, "attn"), device)?; let attn = load_multi_head_self_attention(&format!("{}/{}", path, "attn"), device)?;
let attn_ln = load_layer_norm(&format!("{}/{}", path, "attn_ln"), device)?; let attn_ln = load_layer_norm(&format!("{}/{}", path, "attn_ln"), device)?;
@@ -64,14 +67,16 @@ pub fn load_residual_decoder_attention_block<B: Backend>(path: &str, device: &B:
pub fn load_clip<B: Backend>(path: &str, device: &B::Device) -> Result<CLIP<B>, Box<dyn Error>> { pub fn load_clip<B: Backend>(path: &str, device: &B::Device) -> Result<CLIP<B>, Box<dyn Error>> {
let token_embedding = load_embedding(&format!("{}/{}", path, "token_embedding"), device)?; let token_embedding = load_embedding(&format!("{}/{}", path, "token_embedding"), device)?;
let position_embedding = load_tensor("weight", &format!("{}/position_embedding", path), device)?.into(); let position_embedding =
Param::from_tensor(load_tensor("weight", &format!("{}/position_embedding", path), device)?);
let n_layer = load_usize::<B>("n_layer", path, device)?; let n_layer = load_usize::<B>("n_layer", path, device)?;
let mut blocks = (0..n_layer) let mut blocks = (0..n_layer)
.into_iter() .into_iter()
.map(|i| { .map(|i| {
load_residual_decoder_attention_block::<B>(&format!("{}/blocks/{}", path, i), device) load_residual_decoder_attention_block::<B>(&format!("{}/blocks/{}", path, i), device)
}).collect::<Result<Vec<_>, _>>()?; })
.collect::<Result<Vec<_>, _>>()?;
let layer_norm = load_layer_norm(&format!("{}/{}", path, "layer_norm"), device)?; let layer_norm = load_layer_norm(&format!("{}/{}", path, "layer_norm"), device)?;

View File

@@ -5,19 +5,17 @@ use burn::{
module::{Module, Param}, module::{Module, Param},
nn, nn,
tensor::{ tensor::{
activation::{sigmoid, softmax},
backend::Backend, backend::Backend,
activation::{softmax, sigmoid},
module::embedding, module::embedding,
Tensor, Distribution, Int, Tensor,
Distribution,
Int,
}, },
}; };
use crate::model::attention::{qkv_attention, attn_decoder_mask}; //use crate::backend::Backend as MyBackend;
use crate::backend::{qkv_attention, attn_decoder_mask};
#[derive(Config, Debug)]
#[derive(Config)]
pub struct CLIPConfig { pub struct CLIPConfig {
n_vocab: usize, n_vocab: usize,
n_state: usize, n_state: usize,
@@ -27,14 +25,15 @@ pub struct CLIPConfig {
} }
impl CLIPConfig { impl CLIPConfig {
pub fn init<B: Backend>(&self) -> CLIP<B> { pub fn init<B: Backend>(&self, device: &B::Device) -> CLIP<B> {
let token_embedding = nn::EmbeddingConfig::new(self.n_vocab, self.n_state).init(); let token_embedding = nn::EmbeddingConfig::new(self.n_vocab, self.n_state).init(device);
let position_embedding = Tensor::random([self.n_ctx, self.n_state], Distribution::Normal(0.0, 1.0)).into(); let position_embedding =
Param::from_tensor(Tensor::random([self.n_ctx, self.n_state], Distribution::Normal(0.0, 1.0), device));
let blocks = (0..self.n_layer) let blocks = (0..self.n_layer)
.into_iter() .into_iter()
.map(|_| ResidualDecoderAttentionBlockConfig::new(self.n_state, self.n_head).init()) .map(|_| ResidualDecoderAttentionBlockConfig::new(self.n_state, self.n_head).init(device))
.collect(); .collect();
let layer_norm = nn::LayerNormConfig::new(self.n_state).init(); let layer_norm = nn::LayerNormConfig::new(self.n_state).init(device);
CLIP { CLIP {
token_embedding, token_embedding,
@@ -45,8 +44,6 @@ impl CLIPConfig {
} }
} }
#[derive(Module, Debug)] #[derive(Module, Debug)]
pub struct CLIP<B: Backend> { pub struct CLIP<B: Backend> {
token_embedding: nn::Embedding<B>, token_embedding: nn::Embedding<B>,
@@ -59,10 +56,15 @@ impl<B: Backend> CLIP<B> {
pub fn forward(&self, x: Tensor<B, 2, Int>) -> Tensor<B, 3> { pub fn forward(&self, x: Tensor<B, 2, Int>) -> Tensor<B, 3> {
let [n_batch, seq_len] = x.dims(); let [n_batch, seq_len] = x.dims();
//let mask = Tensor::from_primitive(B::attn_decoder_mask(seq_len, &x.device()));
let mask = attn_decoder_mask(seq_len, &x.device()); let mask = attn_decoder_mask(seq_len, &x.device());
let embedded = self.token_embedding.forward(x) let embedded = self.token_embedding.forward(x)
+ self.position_embedding.val().slice([0..seq_len]).unsqueeze(); + self
.position_embedding
.val()
.slice([0..seq_len])
.unsqueeze();
let mut x = embedded; let mut x = embedded;
for block in &self.blocks { for block in &self.blocks {
@@ -73,21 +75,19 @@ impl<B: Backend> CLIP<B> {
} }
} }
#[derive(Config, Debug)]
#[derive(Config)]
pub struct ResidualDecoderAttentionBlockConfig { pub struct ResidualDecoderAttentionBlockConfig {
n_state: usize, n_state: usize,
n_head: usize, n_head: usize,
} }
impl ResidualDecoderAttentionBlockConfig { impl ResidualDecoderAttentionBlockConfig {
pub fn init<B: Backend>(&self) -> ResidualDecoderAttentionBlock<B> { pub fn init<B: Backend>(&self, device: &B::Device) -> ResidualDecoderAttentionBlock<B> {
let attn = MultiHeadSelfAttentionConfig::new(self.n_state, self.n_head).init(); let attn = MultiHeadSelfAttentionConfig::new(self.n_state, self.n_head).init(device);
let attn_ln = nn::LayerNormConfig::new(self.n_state).init(); let attn_ln = nn::LayerNormConfig::new(self.n_state).init(device);
let mlp = MLPConfig::new(self.n_state, 4 * self.n_state).init(); let mlp = MLPConfig::new(self.n_state, 4 * self.n_state).init(device);
let mlp_ln = nn::LayerNormConfig::new(self.n_state).init(); let mlp_ln = nn::LayerNormConfig::new(self.n_state).init(device);
ResidualDecoderAttentionBlock { ResidualDecoderAttentionBlock {
attn, attn,
@@ -114,28 +114,33 @@ impl<B: Backend> ResidualDecoderAttentionBlock<B> {
} }
} }
#[derive(Config)] #[derive(Config, Debug)]
pub struct MultiHeadSelfAttentionConfig { pub struct MultiHeadSelfAttentionConfig {
n_state: usize, n_state: usize,
n_head: usize, n_head: usize,
} }
impl MultiHeadSelfAttentionConfig { impl MultiHeadSelfAttentionConfig {
fn init<B: Backend>(&self) -> MultiHeadSelfAttention<B> { fn init<B: Backend>(&self, device: &B::Device) -> MultiHeadSelfAttention<B> {
assert!(self.n_state % self.n_head == 0, "State size {} must be a multiple of head size {}", self.n_state, self.n_head); assert!(
self.n_state % self.n_head == 0,
"State size {} must be a multiple of head size {}",
self.n_state,
self.n_head
);
let n_head = self.n_head; let n_head = self.n_head;
let query = nn::LinearConfig::new(self.n_state, self.n_state).init(); let query = nn::LinearConfig::new(self.n_state, self.n_state).init(device);
let key = nn::LinearConfig::new(self.n_state, self.n_state).init(); let key = nn::LinearConfig::new(self.n_state, self.n_state).init(device);
let value = nn::LinearConfig::new(self.n_state, self.n_state).init(); let value = nn::LinearConfig::new(self.n_state, self.n_state).init(device);
let out = nn::LinearConfig::new(self.n_state, self.n_state).init(); let out = nn::LinearConfig::new(self.n_state, self.n_state).init(device);
MultiHeadSelfAttention { MultiHeadSelfAttention {
n_head, n_head,
query, query,
key, key,
value, value,
out out,
} }
} }
} }
@@ -155,19 +160,26 @@ impl<B: Backend> MultiHeadSelfAttention<B> {
let k = self.key.forward(x.clone()); let k = self.key.forward(x.clone());
let v = self.value.forward(x); let v = self.value.forward(x);
let wv = qkv_attention(q, k, v, mask, self.n_head); /*let wv = Tensor::from_primitive(B::qkv_attention(
q.into_primitive(),
k.into_primitive(),
v.into_primitive(),
mask.map(|m| m.into_primitive()),
self.n_head,
));*/
let wv = qkv_attention(
q,
k,
v,
mask,
self.n_head,
);
return self.out.forward(wv); return self.out.forward(wv);
} }
} }
#[derive(Config, Debug)] #[derive(Config, Debug)]
pub struct MLPConfig { pub struct MLPConfig {
input_size: usize, input_size: usize,
@@ -175,16 +187,12 @@ pub struct MLPConfig {
} }
impl MLPConfig { impl MLPConfig {
fn init<B: Backend>(&self) -> MLP<B> { fn init<B: Backend>(&self, device: &B::Device) -> MLP<B> {
let fc1 = nn::LinearConfig::new(self.input_size, self.hidden_size).init(); let fc1 = nn::LinearConfig::new(self.input_size, self.hidden_size).init(device);
let gelu = QuickGELU::new(); let gelu = QuickGELU::new();
let fc2 = nn::LinearConfig::new(self.hidden_size, self.input_size).init(); let fc2 = nn::LinearConfig::new(self.hidden_size, self.input_size).init(device);
MLP { MLP { fc1, gelu, fc2 }
fc1,
gelu,
fc2,
}
} }
} }
@@ -217,4 +225,3 @@ impl QuickGELU {
x.clone() * sigmoid(x * 1.702) x.clone() * sigmoid(x * 1.702)
} }
} }

View File

@@ -7,27 +7,31 @@ use burn::{
config::Config, config::Config,
module::{Module, Param}, module::{Module, Param},
nn, nn,
tensor::{ tensor::{backend::Backend, Tensor},
backend::Backend,
Tensor,
},
}; };
pub fn load_group_norm<B: Backend>(path: &str, device: &B::Device) -> Result<GroupNorm<B>, Box<dyn Error>> { pub fn load_group_norm<B: Backend>(
path: &str,
device: &B::Device,
) -> Result<GroupNorm<B>, Box<dyn Error>> {
let n_group = load_usize::<B>("n_group", path, device)?.into(); let n_group = load_usize::<B>("n_group", path, device)?.into();
let n_channel = load_usize::<B>("n_channel", path, device)?.into(); let n_channel = load_usize::<B>("n_channel", path, device)?.into();
let eps = load_f32::<B>("eps", path, device)?.into(); let eps = load_f32::<B>("eps", path, device)?.into();
let gamma = load_tensor::<B, 1>("weight", path, device).ok().unwrap_or_else(|| Tensor::ones_device([n_channel], device)).into(); let gamma = Param::from_tensor(load_tensor::<B, 1>("weight", path, device)
let beta = load_tensor::<B, 1>("bias", path, device).ok().unwrap_or_else(|| Tensor::zeros_device([n_channel], device)).into(); .ok()
.unwrap_or_else(|| Tensor::ones([n_channel], device))
);
let beta = Param::from_tensor(load_tensor::<B, 1>("bias", path, device)
.ok()
.unwrap_or_else(|| Tensor::zeros([n_channel], device))
);
Ok( Ok(GroupNorm {
GroupNorm { n_group,
n_group, n_channel,
n_channel, gamma,
gamma, beta,
beta, eps,
eps, })
}
)
} }

View File

@@ -3,13 +3,10 @@ pub mod load;
use burn::{ use burn::{
config::Config, config::Config,
module::{Module, Param}, module::{Module, Param},
tensor::{ tensor::{backend::Backend, Tensor},
backend::Backend,
Tensor,
},
}; };
#[derive(Config)] #[derive(Config, Debug)]
pub struct GroupNormConfig { pub struct GroupNormConfig {
n_group: usize, n_group: usize,
n_channel: usize, n_channel: usize,
@@ -18,13 +15,18 @@ pub struct GroupNormConfig {
} }
impl GroupNormConfig { impl GroupNormConfig {
pub fn init<B: Backend>(&self) -> GroupNorm<B> { pub fn init<B: Backend>(&self, device: &B::Device) -> GroupNorm<B> {
assert!(self.n_channel % self.n_group == 0, "The number of channels {} must be divisible by the number of groups {}", self.n_channel, self.n_group); assert!(
self.n_channel % self.n_group == 0,
"The number of channels {} must be divisible by the number of groups {}",
self.n_channel,
self.n_group
);
let n_per_group = self.n_channel / self.n_group; let n_per_group = self.n_channel / self.n_group;
let gamma = Tensor::ones([self.n_channel]).into(); let gamma = Param::from_tensor(Tensor::ones([self.n_channel], device));
let beta = Tensor::zeros([self.n_channel]).into(); let beta = Param::from_tensor(Tensor::zeros([self.n_channel], device));
let eps = self.eps; let eps = self.eps;
@@ -56,10 +58,17 @@ impl<B: Backend> GroupNorm<B> {
let mut affine_shape = [1; D]; let mut affine_shape = [1; D];
affine_shape[1] = self.n_channel; affine_shape[1] = self.n_channel;
layernorm( x.reshape([n_batch, self.n_group, num_elements / (n_batch * self.n_group) ]), self.eps ) layernorm(
.reshape(shape) x.reshape([
.mul(self.gamma.val().reshape(affine_shape)) n_batch,
.add(self.beta.val().reshape(affine_shape)) self.n_group,
num_elements / (n_batch * self.n_group),
]),
self.eps,
)
.reshape(shape)
.mul(self.gamma.val().reshape(affine_shape))
.add(self.beta.val().reshape(affine_shape))
} }
} }
@@ -68,5 +77,6 @@ pub fn layernorm<B: Backend, const D: usize>(x: Tensor<B, D>, eps: f64) -> Tenso
//x.sub(mean).div(var.sqrt().add_scalar(eps)) //x.sub(mean).div(var.sqrt().add_scalar(eps))
let u = x.clone() - x.mean_dim(D - 1); let u = x.clone() - x.mean_dim(D - 1);
u.clone().div( (u.clone() * u).mean_dim(D - 1).add_scalar(eps).sqrt() ) u.clone()
.div((u.clone() * u).mean_dim(D - 1).add_scalar(eps).sqrt())
} }

View File

@@ -1,36 +1,41 @@
use std::error::Error;
use std::io::Read;
use npy::{self, NpyData}; use npy::{self, NpyData};
use num_traits::cast::ToPrimitive; use num_traits::cast::ToPrimitive;
use burn::tensor::cast::ToElement;
use burn::prelude::TensorData;
use std::error::Error;
use std::io::Read;
use burn::{ use burn::{
config::Config, config::Config,
module::{Module, Param}, module::{Module, Param},
nn::{self, conv}, nn::{self, conv},
tensor::{ tensor::{backend::Backend, Tensor},
backend::Backend,
Tensor,
Data,
},
}; };
use burn::tensor::ElementConversion; use burn::tensor::ElementConversion;
pub fn numpy_to_tensor<B: Backend, const D: usize>(numpy_data: NpyData<f32>, device: &B::Device) -> Tensor<B, D> { pub fn numpy_to_tensor<B: Backend, const D: usize>(
numpy_data: NpyData<f32>,
device: &B::Device,
) -> Tensor<B, D> {
let mut v = numpy_data.to_vec(); let mut v = numpy_data.to_vec();
let shape: Vec<_> = v[0..D].into_iter().map(|&v| v as usize).collect(); let shape: Vec<_> = v[0..D].into_iter().map(|&v| v as usize).collect();
let data: Vec<B::FloatElem> = v[D..].into_iter().map(|e| e.elem()).collect(); let data: Vec<B::FloatElem> = v[D..].into_iter().map(|e| e.elem()).collect();
Tensor::from_data_device(Data::new(data, shape.into()), device) //Tensor::from_data_device(Data::new(data, shape.into()), device)
Tensor::from_data(TensorData::new(data, shape), device)
} }
pub fn load_tensor<B: Backend, const D: usize>(name: &str, path: &str, device: &B::Device) -> Result<Tensor<B, D>, Box<dyn Error>> { pub fn load_tensor<B: Backend, const D: usize>(
name: &str,
path: &str,
device: &B::Device,
) -> Result<Tensor<B, D>, Box<dyn Error>> {
let tensor_path = format!("{}/{}.npy", path, name); let tensor_path = format!("{}/{}.npy", path, name);
let mut buf = vec![]; let mut buf = vec![];
std::fs::File::open(&tensor_path)? std::fs::File::open(&tensor_path)?.read_to_end(&mut buf)?;
.read_to_end(&mut buf)?;
let tensor_numpy: NpyData<f32> = NpyData::from_bytes(&buf)?; let tensor_numpy: NpyData<f32> = NpyData::from_bytes(&buf)?;
@@ -41,71 +46,79 @@ pub fn load_tensor<B: Backend, const D: usize>(name: &str, path: &str, device: &
Ok(tensor) Ok(tensor)
} }
pub fn load_f32<B: Backend>(name: &str, path: &str, device: &B::Device) -> Result<f32, Box<dyn Error>> { pub fn load_f32<B: Backend>(
load_tensor::<B, 1>(name, path, device).map(|t| t.into_scalar().to_f32().unwrap()) name: &str,
path: &str,
device: &B::Device,
) -> Result<f32, Box<dyn Error>> {
load_tensor::<B, 1>(name, path, device).map(|t| t.into_scalar().to_f32())
} }
pub fn load_usize<B: Backend>(name: &str, path: &str, device: &B::Device) -> Result<usize, Box<dyn Error>> { pub fn load_usize<B: Backend>(
load_tensor::<B, 1>(name, path, device).map(|t| t.into_scalar().to_usize().unwrap()) name: &str,
path: &str,
device: &B::Device,
) -> Result<usize, Box<dyn Error>> {
load_tensor::<B, 1>(name, path, device).map(|t| t.into_scalar().to_usize())
} }
pub fn load_linear<B: Backend>(path: &str, device: &B::Device) -> Result<nn::Linear<B>, Box<dyn Error>> { pub fn load_linear<B: Backend>(
path: &str,
device: &B::Device,
) -> Result<nn::Linear<B>, Box<dyn Error>> {
let weight = load_tensor::<B, 2>("weight", path, device)?; let weight = load_tensor::<B, 2>("weight", path, device)?;
let bias = load_tensor::<B, 1>("bias", path, device).ok(); let bias = load_tensor::<B, 1>("bias", path, device).ok();
let record = nn::LinearRecord { Ok(nn::Linear {
weight: weight.into(), weight: Param::from_tensor(weight),
bias: bias.map(|t| t.into()), bias: bias.map(|t| Param::from_tensor(t)),
}; })
let linear: nn::Linear<B> = nn::LinearConfig::new(3, 3).init_with(record);
Ok(linear)
} }
pub fn load_embedding<B: Backend>(path: &str, device: &B::Device) -> Result<nn::Embedding<B>, Box<dyn Error>> { pub fn load_embedding<B: Backend>(
path: &str,
device: &B::Device,
) -> Result<nn::Embedding<B>, Box<dyn Error>> {
let weight = load_tensor::<B, 2>("weight", path, device)?; let weight = load_tensor::<B, 2>("weight", path, device)?;
let [n_vocab, n_state] = weight.dims();
let record = nn::EmbeddingRecord { Ok(nn::Embedding {
weight: weight.into(), weight: Param::from_tensor(weight),
}; })
let embedding = nn::EmbeddingConfig::new(n_vocab, n_state).init_with(record);
Ok(embedding)
} }
pub fn load_layer_norm<B: Backend>(path: &str, device: &B::Device) -> Result<nn::LayerNorm<B>, Box<dyn Error>> { pub fn load_layer_norm<B: Backend>(
path: &str,
device: &B::Device,
) -> Result<nn::LayerNorm<B>, Box<dyn Error>> {
let weight = load_tensor::<B, 1>("weight", path, device)?; let weight = load_tensor::<B, 1>("weight", path, device)?;
let bias = load_tensor::<B, 1>("bias", path, device)?; let bias = load_tensor::<B, 1>("bias", path, device)?;
let eps = load_f32::<B>("eps", path, device)? as f64; let eps = load_f32::<B>("eps", path, device)? as f64;
let [n_state] = weight.dims(); let [n_state] = weight.dims();
let record = nn::LayerNormRecord { let mut layer_norm = nn::LayerNormConfig::new(n_state).with_epsilon(eps).init(device);
gamma: weight.into(), layer_norm.gamma = Param::from_tensor(weight);
beta: bias.into(), layer_norm.beta = Some(Param::from_tensor(bias));
epsilon: <f64 as Module<B>>::into_record(eps),
};
let layer_norm: nn::LayerNorm<B> = nn::LayerNormConfig::new(n_state).init_with(record);
Ok(layer_norm) Ok(layer_norm)
} }
/*pub fn load_rmsnorm<B: Backend>(path: &str, device: &B::Device) -> Result<RMSNorm<B>, Box<dyn Error>> { /*pub fn load_rmsnorm<B: Backend>(path: &str, device: &B::Device) -> Result<RMSNorm<B>, Box<dyn Error>> {
let weight = load_tensor::<B, 1>("weight", path, device)?; let weight = load_tensor::<B, 1>("weight", path, device)?;
let eps = load_f32::<B>("eps", path, device)?.into(); let eps = load_f32::<B>("eps", path, device)?.into();
let rmsnorm = RMSNorm { let rmsnorm = RMSNorm {
weight: weight.into(), weight: Param::from_tensor(weight),
eps: eps eps: eps
}; };
Ok(rmsnorm) Ok(rmsnorm)
}*/ }*/
pub fn load_conv2d<B: Backend>(path: &str, device: &B::Device) -> Result<conv::Conv2d<B>, Box<dyn Error>> { pub fn load_conv2d<B: Backend>(
path: &str,
device: &B::Device,
) -> Result<conv::Conv2d<B>, Box<dyn Error>> {
let weight = load_tensor::<B, 4>("weight", path, device)?; let weight = load_tensor::<B, 4>("weight", path, device)?;
let bias = load_tensor::<B, 1>("bias", path, device).ok(); let bias = load_tensor::<B, 1>("bias", path, device).ok();
let has_bias = bias.is_some(); let has_bias = bias.is_some();
@@ -127,40 +140,38 @@ pub fn load_conv2d<B: Backend>(path: &str, device: &B::Device) -> Result<conv::C
let padding = tensor_to_array_2(padding); let padding = tensor_to_array_2(padding);
let padding = nn::PaddingConfig2d::Explicit(padding[0], padding[1]); let padding = nn::PaddingConfig2d::Explicit(padding[0], padding[1]);
let mut conv2d = conv::Conv2dConfig::new([n_channels_in, n_channels_out], kernel_size)
.with_stride(stride)
.with_dilation(dilation)
.with_groups(n_group)
.with_padding(padding.clone())
.with_bias(has_bias)
.init(device);
let record = conv::Conv2dRecord { conv2d.weight = Param::from_tensor(weight);
weight: weight.into(), conv2d.bias = bias.map(|t| Param::from_tensor(t));
bias: bias.map(|t| t.into()), conv2d.stride = stride;
stride: <[usize; 2] as Module<B>>::into_record(stride), conv2d.kernel_size = kernel_size;
kernel_size: <[usize; 2] as Module<B>>::into_record(kernel_size), conv2d.dilation = dilation;
dilation: <[usize; 2] as Module<B>>::into_record(dilation), conv2d.groups = n_group;
groups: <usize as Module<B>>::into_record(n_group), conv2d.padding = burn::module::Ignored(padding);
padding: <nn::PaddingConfig2d as Module<B>>::into_record(padding.clone()),
};
let conv2d: conv::Conv2d<B> = conv::Conv2dConfig::new([n_channels_in, n_channels_out], kernel_size)
.with_stride(stride)
.with_dilation(dilation)
.with_groups(n_group)
.with_padding(padding)
.with_bias(has_bias)
.init_with(record);
Ok(conv2d) Ok(conv2d)
} }
pub fn tensor_to_array_2<B: Backend>(x: Tensor<B, 1>) -> [usize; 2] { pub fn tensor_to_array_2<B: Backend>(x: Tensor<B, 1>) -> [usize; 2] {
let vec = x.into_data().value; let vec: Vec<<B as Backend>::FloatElem> = x.into_data().to_vec().unwrap();
assert!(vec.len() == 2, "Tensor length must be 2."); assert!(vec.len() == 2, "Tensor length must be 2.");
[vec[0].to_usize().unwrap(), vec[1].to_usize().unwrap()] [vec[0].to_usize(), vec[1].to_usize()]
} }
pub fn tensor_to_array<const N: usize, B: Backend>(x: Tensor<B, 1>) -> [usize; N] { pub fn tensor_to_array<const N: usize, B: Backend>(x: Tensor<B, 1>) -> [usize; N] {
let vec = x.into_data().value; let vec: Vec<<B as Backend>::FloatElem> = x.into_data().to_vec().unwrap();
assert!(vec.len() == N, "Tensor length must be {}.", N); assert!(vec.len() == N, "Tensor length must be {}.", N);
let mut arr = [0; N]; let mut arr = [0; N];
for (a, t) in arr.iter_mut().zip(vec) { for (a, t) in arr.iter_mut().zip(vec) {
*a = t.to_usize().unwrap(); *a = t.to_usize();
} }
arr arr

View File

@@ -1,11 +1,11 @@
pub mod stablediffusion; pub mod stablediffusion;
pub mod autoencoder; pub mod autoencoder;
pub mod unet;
pub mod clip; pub mod clip;
pub mod unet;
pub mod silu;
pub mod groupnorm;
pub mod attention; pub mod attention;
pub mod groupnorm;
pub mod silu;
pub mod load; pub mod load;

View File

@@ -1,13 +1,8 @@
use burn::{ use burn::{
module::Module, module::Module,
tensor::{ tensor::{activation::sigmoid, backend::Backend, Tensor},
backend::Backend,
activation::sigmoid,
Tensor,
},
}; };
#[derive(Module, Clone, Debug)] #[derive(Module, Clone, Debug)]
pub struct SILU {} pub struct SILU {}

View File

@@ -1,22 +1,24 @@
use std::error::Error;
use burn::tensor::ElementConversion; use burn::tensor::ElementConversion;
use std::error::Error;
use burn::{ use burn::{
config::Config, config::Config,
module::{Module, Param}, module::{Module, Param},
nn, nn,
tensor::{ tensor::{backend::Backend, Tensor},
backend::Backend,
Tensor,
},
}; };
use super::*; use super::*;
use crate::model::{load::*, autoencoder::load::load_autoencoder, unet::load::load_unet, clip::load::load_clip}; use crate::model::{
autoencoder::load::load_autoencoder, clip::load::load_clip, load::*, unet::load::load_unet,
};
pub fn load_stable_diffusion<B: Backend>(path: &str, device: &B::Device) -> Result<StableDiffusion<B>, Box<dyn Error>> { pub fn load_stable_diffusion<B: Backend>(
path: &str,
device: &B::Device,
) -> Result<StableDiffusion<B>, Box<dyn Error>> {
let n_steps = load_usize::<B>("n_steps", path, device)?; let n_steps = load_usize::<B>("n_steps", path, device)?;
let alpha_cumulative_products = load_tensor::<B, 1>("alphas_cumprod", path, device)?.into(); let alpha_cumulative_products = Param::from_tensor(load_tensor::<B, 1>("alphas_cumprod", path, device)?);
let autoencoder = load_autoencoder(&format!("{}/{}", path, "autoencoder"), device)?; let autoencoder = load_autoencoder(&format!("{}/{}", path, "autoencoder"), device)?;
let diffusion = load_unet(&format!("{}/{}", path, "unet"), device)?; let diffusion = load_unet(&format!("{}/{}", path, "unet"), device)?;
let clip = load_clip(&format!("{}/{}", path, "clip"), device)?; let clip = load_clip(&format!("{}/{}", path, "clip"), device)?;
@@ -29,4 +31,3 @@ pub fn load_stable_diffusion<B: Backend>(path: &str, device: &B::Device) -> Resu
clip, clip,
}) })
} }

View File

@@ -3,37 +3,30 @@ pub mod load;
use burn::{ use burn::{
config::Config, config::Config,
module::{Module, Param}, module::{Module, Param},
tensor::{ tensor::{backend::Backend, BasicOps, Distribution, Float, Int, Tensor},
backend::Backend, tensor::cast::ToElement,
Tensor,
Int,
Float,
BasicOps,
Data,
Distribution,
},
}; };
use num_traits::ToPrimitive; use num_traits::ToPrimitive;
//use crate::backend::Backend as MyBackend;
use super::autoencoder::{Autoencoder, AutoencoderConfig}; use super::autoencoder::{Autoencoder, AutoencoderConfig};
use super::clip::{CLIPConfig, CLIP};
use super::unet::{UNet, UNetConfig}; use super::unet::{UNet, UNetConfig};
use super::clip::{CLIP, CLIPConfig};
use crate::tokenizer::SimpleTokenizer; use crate::tokenizer::SimpleTokenizer;
#[derive(Config)] #[derive(Config, Debug)]
pub struct StableDiffusionConfig { pub struct StableDiffusionConfig {}
}
impl StableDiffusionConfig { impl StableDiffusionConfig {
pub fn init<B: Backend>(&self) -> StableDiffusion<B> { pub fn init<B: Backend>(&self, device: &B::Device) -> StableDiffusion<B> {
let n_steps = 1000; let n_steps = 1000;
let alpha_cumulative_products = offset_cosine_schedule_cumprod::<B>(n_steps).into(); let alpha_cumulative_products = Param::from_tensor(offset_cosine_schedule_cumprod::<B>(n_steps as i64, device));
let autoencoder = AutoencoderConfig::new().init(); let autoencoder = AutoencoderConfig::new().init(device);
let diffusion = UNetConfig::new().init(); let diffusion = UNetConfig::new().init(device);
let clip = CLIPConfig::new(49408, 768, 12, 77, 12).init(); let clip = CLIPConfig::new(49408, 768, 12, 77, 12).init(device);
StableDiffusion { StableDiffusion {
n_steps, n_steps,
@@ -55,10 +48,26 @@ pub struct StableDiffusion<B: Backend> {
} }
impl<B: Backend> StableDiffusion<B> { impl<B: Backend> StableDiffusion<B> {
pub fn sample_image(&self, context: Tensor<B, 3>, unconditional_context: Tensor<B, 2>, unconditional_guidance_scale: f64, n_steps: usize) -> Vec<Vec<u8>> { pub fn sample_image(
&self,
context: Tensor<B, 3>,
unconditional_context: Tensor<B, 2>,
unconditional_guidance_scale: f64,
n_steps: usize,
) -> Vec<Vec<u8>> {
let [n_batch, _, _] = context.dims(); let [n_batch, _, _] = context.dims();
let latent = self.sample_latent(context, unconditional_context, unconditional_guidance_scale, n_steps); let latent = self.sample_latent(
context,
unconditional_context,
unconditional_guidance_scale,
n_steps,
);
self.latent_to_image(latent)
}
pub fn latent_to_image(&self, latent: Tensor<B, 4>) -> Vec<Vec<u8>> {
let [n_batch, _, _, _] = latent.dims();
let image = self.autoencoder.decode_latent(latent * (1.0 / 0.18215)); let image = self.autoencoder.decode_latent(latent * (1.0 / 0.18215));
let n_channel = 3; let n_channel = 3;
@@ -74,19 +83,29 @@ impl<B: Backend> StableDiffusion<B> {
.swap_dims(2, 3) .swap_dims(2, 3)
.mul_scalar(255.0); .mul_scalar(255.0);
let flattened: Vec<_> = image. let flattened: Vec<B::FloatElem> = image.into_data().to_vec().unwrap();
into_data().
value;
(0..n_batch).into_iter().map(|b| { (0..n_batch)
let start = b * num_elements_per_image; .into_iter()
let end = start + num_elements_per_image; .map(|b| {
let start = b * num_elements_per_image;
let end = start + num_elements_per_image;
flattened[start..end].into_iter().map(|v| v.to_f64().unwrap().min(255.0).max(0.0).to_u8().unwrap()).collect() flattened[start..end]
}).collect() .into_iter()
.map(|v| v.to_f64().min(255.0).max(0.0) as u8)
.collect()
})
.collect()
} }
pub fn sample_latent(&self, context: Tensor<B, 3>, unconditional_context: Tensor<B, 2>, unconditional_guidance_scale: f64, n_steps: usize) -> Tensor<B, 4> { pub fn sample_latent(
&self,
context: Tensor<B, 3>,
unconditional_context: Tensor<B, 2>,
unconditional_guidance_scale: f64,
n_steps: usize,
) -> Tensor<B, 4> {
let device = context.device(); let device = context.device();
let step_size = self.n_steps / n_steps; let step_size = self.n_steps / n_steps;
@@ -94,7 +113,7 @@ impl<B: Backend> StableDiffusion<B> {
let [n_batches, _, _] = context.dims(); let [n_batches, _, _] = context.dims();
let gen_noise = || { let gen_noise = || {
Tensor::random([n_batches, 4, 64, 64], Distribution::Normal(0.0, 1.0)).to_device(&device) Tensor::random([n_batches, 4, 64, 64], Distribution::Normal(0.0, 1.0), &device)
}; };
let sigma = 0.0; // Use deterministic diffusion let sigma = 0.0; // Use deterministic diffusion
@@ -102,18 +121,34 @@ impl<B: Backend> StableDiffusion<B> {
let mut latent = gen_noise(); let mut latent = gen_noise();
for t in (0..self.n_steps).rev().step_by(step_size) { for t in (0..self.n_steps).rev().step_by(step_size) {
let current_alpha: f64 = self.alpha_cumulative_products.val().slice([t..t + 1]).into_scalar().to_f64().unwrap(); let current_alpha: f64 = self
.alpha_cumulative_products
.val()
.slice([t..t + 1])
.into_scalar()
.to_f64();
let prev_alpha: f64 = if t >= step_size { let prev_alpha: f64 = if t >= step_size {
let i = t - step_size; let i = t - step_size;
self.alpha_cumulative_products.val().slice([i..i + 1]).into_scalar().to_f64().unwrap() self.alpha_cumulative_products
.val()
.slice([i..i + 1])
.into_scalar()
.to_f64()
} else { } else {
1.0 1.0
}; };
let sqrt_noise = (1.0 - current_alpha).sqrt(); let sqrt_noise = (1.0 - current_alpha).sqrt();
let timestep = Tensor::from_ints([t as i32]).to_device(&device); let timestep = Tensor::from_ints([t as i32], &device);
let pred_noise = self.forward_diffuser(latent.clone(), timestep, context.clone(), unconditional_context.clone(), unconditional_guidance_scale); let pred_noise = self.forward_diffuser(
latent.clone(),
timestep,
context.clone(),
unconditional_context.clone(),
unconditional_guidance_scale,
);
let predx0 = (latent - pred_noise.clone() * sqrt_noise) / current_alpha.sqrt(); let predx0 = (latent - pred_noise.clone() * sqrt_noise) / current_alpha.sqrt();
let dir_latent = pred_noise * (1.0 - prev_alpha - sigma * sigma).sqrt(); let dir_latent = pred_noise * (1.0 - prev_alpha - sigma * sigma).sqrt();
@@ -124,21 +159,24 @@ impl<B: Backend> StableDiffusion<B> {
latent latent
} }
fn forward_diffuser(&self, latent: Tensor<B, 4>, timestep: Tensor<B, 1, Int>, context: Tensor<B, 3>, unconditional_context: Tensor<B, 2>, unconditional_guidance_scale: f64) -> Tensor<B, 4> { fn forward_diffuser(
&self,
latent: Tensor<B, 4>,
timestep: Tensor<B, 1, Int>,
context: Tensor<B, 3>,
unconditional_context: Tensor<B, 2>,
unconditional_guidance_scale: f64,
) -> Tensor<B, 4> {
let [n_batch, _, _, _] = latent.dims(); let [n_batch, _, _, _] = latent.dims();
//let latent = latent.repeat(0, 2); //let latent = latent.repeat(0, 2);
let unconditional_latent = self.diffusion.forward( let unconditional_latent = self.diffusion.forward(
latent.clone(), latent.clone(),
timestep.clone(), timestep.clone(),
unconditional_context.unsqueeze().repeat(0, n_batch) unconditional_context.unsqueeze().repeat(&[0, n_batch]),
); );
let conditional_latent = self.diffusion.forward( let conditional_latent = self.diffusion.forward(latent, timestep, context);
latent,
timestep,
context
);
/*let latent = self.diffusion.forward( /*let latent = self.diffusion.forward(
latent.repeat(0, 2), latent.repeat(0, 2),
@@ -149,43 +187,51 @@ impl<B: Backend> StableDiffusion<B> {
let unconditional_latent = latent.clone().slice([0..n_batch]); let unconditional_latent = latent.clone().slice([0..n_batch]);
let conditional_latent = latent.slice([n_batch..2 * n_batch]);*/ let conditional_latent = latent.slice([n_batch..2 * n_batch]);*/
unconditional_latent.clone() + (conditional_latent - unconditional_latent) * unconditional_guidance_scale unconditional_latent.clone()
+ (conditional_latent - unconditional_latent) * unconditional_guidance_scale
} }
pub fn unconditional_context(&self, tokenizer: &SimpleTokenizer) -> Tensor<B, 2> { pub fn unconditional_context(&self, tokenizer: &SimpleTokenizer) -> Tensor<B, 2> {
self.context(tokenizer, "").squeeze(0) self.context(tokenizer, "").squeeze::<2>()
} }
pub fn context(&self, tokenizer: &SimpleTokenizer, text: &str) -> Tensor<B, 3> { pub fn context(&self, tokenizer: &SimpleTokenizer, text: &str) -> Tensor<B, 3> {
let device = &self.devices()[0]; let device = &self.clip.devices()[0];
let text = format!("<|startoftext|>{}<|endoftext|>", text); let text = format!("<|startoftext|>{}<|endoftext|>", text);
let tokenized: Vec<_> = tokenizer.encode(&text).into_iter().map(|v| v as i32).collect(); let tokenized: Vec<_> = tokenizer
.encode(&text)
.into_iter()
.map(|v| v as i32)
.collect();
self.clip.forward(Tensor::from_ints(&tokenized[..]).to_device(device).unsqueeze()) self.clip.forward(
Tensor::<B, 1, Int>::from_ints(&tokenized[..], device)
.unsqueeze(),
)
} }
} }
use crate::helper::to_float;
use std::f64::consts::PI; use std::f64::consts::PI;
fn cosine_schedule<B: Backend>(n_steps: usize) -> Tensor<B, 1> { fn cosine_schedule<B: Backend>(n_steps: i64, device: &B::Device) -> Tensor<B, 1> {
to_float(Tensor::arange(1..n_steps + 1)) Tensor::arange(1..n_steps + 1, device)
.float()
.mul_scalar(PI * 0.5 / n_steps as f64) .mul_scalar(PI * 0.5 / n_steps as f64)
.cos() .cos()
} }
fn offset_cosine_schedule<B: Backend>(n_steps: usize) -> Tensor<B, 1> { fn offset_cosine_schedule<B: Backend>(n_steps: i64, device: &B::Device) -> Tensor<B, 1> {
let min_signal_rate: f64 = 0.02; let min_signal_rate: f64 = 0.02;
let max_signal_rate: f64 = 0.95; let max_signal_rate: f64 = 0.95;
let start_angle = max_signal_rate.acos(); let start_angle = max_signal_rate.acos();
let end_angle = min_signal_rate.acos(); let end_angle = min_signal_rate.acos();
let times = Tensor::arange(1..n_steps + 1); let times = Tensor::arange(1..n_steps + 1, device).float();
let diffusion_angles = to_float(times) * ( (end_angle - start_angle) / n_steps as f64) + start_angle; let diffusion_angles = times * ((end_angle - start_angle) / n_steps as f64) + start_angle;
diffusion_angles.cos() diffusion_angles.cos()
} }
fn offset_cosine_schedule_cumprod<B: Backend>(n_steps: usize) -> Tensor<B, 1> { fn offset_cosine_schedule_cumprod<B: Backend>(n_steps: i64, device: &B::Device) -> Tensor<B, 1> {
offset_cosine_schedule::<B>(n_steps).powf(2.0) offset_cosine_schedule::<B>(n_steps, device).powf_scalar(2.0)
} }

View File

@@ -7,16 +7,16 @@ use burn::{
config::Config, config::Config,
module::{Module, Param}, module::{Module, Param},
nn, nn,
tensor::{ tensor::{backend::Backend, Tensor},
backend::Backend,
Tensor,
},
}; };
use super::*; use super::*;
use crate::model::groupnorm::load::load_group_norm; use crate::model::groupnorm::load::load_group_norm;
pub fn load_res_block<B: Backend>(path: &str, device: &B::Device) -> Result<ResBlock<B>, Box<dyn Error>> { pub fn load_res_block<B: Backend>(
path: &str,
device: &B::Device,
) -> Result<ResBlock<B>, Box<dyn Error>> {
let norm_in = load_group_norm::<B>(&format!("{}/{}", path, "norm_in"), device)?; let norm_in = load_group_norm::<B>(&format!("{}/{}", path, "norm_in"), device)?;
let conv_in = load_conv2d::<B>(&format!("{}/{}", path, "conv_in"), device)?; let conv_in = load_conv2d::<B>(&format!("{}/{}", path, "conv_in"), device)?;
let lin_embed = load_linear::<B>(&format!("{}/{}", path, "lin_embed"), device)?; let lin_embed = load_linear::<B>(&format!("{}/{}", path, "lin_embed"), device)?;
@@ -39,7 +39,10 @@ pub fn load_res_block<B: Backend>(path: &str, device: &B::Device) -> Result<ResB
Ok(res_block) Ok(res_block)
} }
pub fn load_multi_head_attention<B: Backend>(path: &str, device: &B::Device) -> Result<MultiHeadAttention<B>, Box<dyn Error>> { pub fn load_multi_head_attention<B: Backend>(
path: &str,
device: &B::Device,
) -> Result<MultiHeadAttention<B>, Box<dyn Error>> {
let n_head = load_usize::<B>("n_head", path, device)?; let n_head = load_usize::<B>("n_head", path, device)?;
let query = load_linear::<B>(&format!("{}/{}", path, "query"), device)?; let query = load_linear::<B>(&format!("{}/{}", path, "query"), device)?;
let key = load_linear::<B>(&format!("{}/{}", path, "key"), device)?; let key = load_linear::<B>(&format!("{}/{}", path, "key"), device)?;
@@ -57,19 +60,17 @@ pub fn load_multi_head_attention<B: Backend>(path: &str, device: &B::Device) ->
Ok(multi_head_attention) Ok(multi_head_attention)
} }
pub fn load_geglu<B: Backend>(path: &str, device: &B::Device) -> Result<GEGLU<B>, Box<dyn Error>> { pub fn load_geglu<B: Backend>(path: &str, device: &B::Device) -> Result<GEGLU<B>, Box<dyn Error>> {
let proj = load_linear::<B>(&format!("{}/{}", path, "proj"), device)?; let proj = load_linear::<B>(&format!("{}/{}", path, "proj"), device)?;
let geglue = GEGLU { let geglue = GEGLU {
proj: proj, proj: proj,
gelu: GELU::new(), // Assuming GELU::new() initializes a new GELU struct gelu: Gelu::new(), // Assuming Gelu::new() initializes a new Gelu struct
}; };
Ok(geglue) Ok(geglue)
} }
pub fn load_mlp<B: Backend>(path: &str, device: &B::Device) -> Result<MLP<B>, Box<dyn Error>> { pub fn load_mlp<B: Backend>(path: &str, device: &B::Device) -> Result<MLP<B>, Box<dyn Error>> {
let geglu = load_geglu::<B>(&format!("{}/{}", path, "geglu"), device)?; let geglu = load_geglu::<B>(&format!("{}/{}", path, "geglu"), device)?;
let lin = load_linear::<B>(&format!("{}/{}", path, "lin"), device)?; let lin = load_linear::<B>(&format!("{}/{}", path, "lin"), device)?;
@@ -82,8 +83,10 @@ pub fn load_mlp<B: Backend>(path: &str, device: &B::Device) -> Result<MLP<B>, Bo
Ok(mlp) Ok(mlp)
} }
pub fn load_transformer_block<B: Backend>(
pub fn load_transformer_block<B: Backend>(path: &str, device: &B::Device) -> Result<TransformerBlock<B>, Box<dyn Error>> { path: &str,
device: &B::Device,
) -> Result<TransformerBlock<B>, Box<dyn Error>> {
let norm1 = load_layer_norm::<B>(&format!("{}/{}", path, "norm1"), device)?; let norm1 = load_layer_norm::<B>(&format!("{}/{}", path, "norm1"), device)?;
let attn1 = load_multi_head_attention::<B>(&format!("{}/{}", path, "attn1"), device)?; let attn1 = load_multi_head_attention::<B>(&format!("{}/{}", path, "attn1"), device)?;
let norm2 = load_layer_norm::<B>(&format!("{}/{}", path, "norm2"), device)?; let norm2 = load_layer_norm::<B>(&format!("{}/{}", path, "norm2"), device)?;
@@ -103,8 +106,10 @@ pub fn load_transformer_block<B: Backend>(path: &str, device: &B::Device) -> Res
Ok(transformer_block) Ok(transformer_block)
} }
pub fn load_spatial_transformer<B: Backend>(
pub fn load_spatial_transformer<B: Backend>(path: &str, device: &B::Device) -> Result<SpatialTransformer<B>, Box<dyn Error>> { path: &str,
device: &B::Device,
) -> Result<SpatialTransformer<B>, Box<dyn Error>> {
let norm = load_group_norm::<B>(&format!("{}/{}", path, "norm"), device)?; let norm = load_group_norm::<B>(&format!("{}/{}", path, "norm"), device)?;
let proj_in = load_conv2d::<B>(&format!("{}/{}", path, "proj_in"), device)?; let proj_in = load_conv2d::<B>(&format!("{}/{}", path, "proj_in"), device)?;
let transformer = load_transformer_block::<B>(&format!("{}/{}", path, "transformer"), device)?; let transformer = load_transformer_block::<B>(&format!("{}/{}", path, "transformer"), device)?;
@@ -120,24 +125,31 @@ pub fn load_spatial_transformer<B: Backend>(path: &str, device: &B::Device) -> R
Ok(spatial_transformer) Ok(spatial_transformer)
} }
pub fn load_upsample<B: Backend>(
pub fn load_upsample<B: Backend>(path: &str, device: &B::Device) -> Result<Upsample<B>, Box<dyn Error>> { path: &str,
device: &B::Device,
) -> Result<Upsample<B>, Box<dyn Error>> {
let conv = load_conv2d::<B>(&format!("{}/{}", path, "conv"), device)?; let conv = load_conv2d::<B>(&format!("{}/{}", path, "conv"), device)?;
let upsample = Upsample { let upsample = Upsample { conv: conv };
conv: conv,
};
Ok(upsample) Ok(upsample)
} }
pub fn load_downsample<B: Backend>(path: &str, device: &B::Device) -> Result<Downsample<B>, Box<dyn Error>> { pub fn load_downsample<B: Backend>(
path: &str,
device: &B::Device,
) -> Result<Downsample<B>, Box<dyn Error>> {
load_conv2d(path, device) load_conv2d(path, device)
} }
pub fn load_res_transformer_res<B: Backend>(path: &str, device: &B::Device) -> Result<ResTransformerRes<B>, Box<dyn Error>> { pub fn load_res_transformer_res<B: Backend>(
path: &str,
device: &B::Device,
) -> Result<ResTransformerRes<B>, Box<dyn Error>> {
let res1 = load_res_block::<B>(&format!("{}/{}", path, "res1"), device)?; // Assuming load_res_block function let res1 = load_res_block::<B>(&format!("{}/{}", path, "res1"), device)?; // Assuming load_res_block function
let transformer = load_spatial_transformer::<B>(&format!("{}/{}", path, "transformer"), device)?; let transformer =
load_spatial_transformer::<B>(&format!("{}/{}", path, "transformer"), device)?;
let res2 = load_res_block::<B>(&format!("{}/{}", path, "res2"), device)?; let res2 = load_res_block::<B>(&format!("{}/{}", path, "res2"), device)?;
let res_transformer_res = ResTransformerRes { let res_transformer_res = ResTransformerRes {
@@ -149,9 +161,13 @@ pub fn load_res_transformer_res<B: Backend>(path: &str, device: &B::Device) -> R
Ok(res_transformer_res) Ok(res_transformer_res)
} }
pub fn load_res_transformer_upsample<B: Backend>(path: &str, device: &B::Device) -> Result<ResTransformerUpsample<B>, Box<dyn Error>> { pub fn load_res_transformer_upsample<B: Backend>(
path: &str,
device: &B::Device,
) -> Result<ResTransformerUpsample<B>, Box<dyn Error>> {
let res = load_res_block::<B>(&format!("{}/{}", path, "res"), device)?; let res = load_res_block::<B>(&format!("{}/{}", path, "res"), device)?;
let transformer = load_spatial_transformer::<B>(&format!("{}/{}", path, "transformer"), device)?; let transformer =
load_spatial_transformer::<B>(&format!("{}/{}", path, "transformer"), device)?;
let upsample = load_upsample::<B>(&format!("{}/{}", path, "upsample"), device)?; let upsample = load_upsample::<B>(&format!("{}/{}", path, "upsample"), device)?;
let res_transformer_upsample = ResTransformerUpsample { let res_transformer_upsample = ResTransformerUpsample {
@@ -163,8 +179,10 @@ pub fn load_res_transformer_upsample<B: Backend>(path: &str, device: &B::Device)
Ok(res_transformer_upsample) Ok(res_transformer_upsample)
} }
pub fn load_res_upsample<B: Backend>(
pub fn load_res_upsample<B: Backend>(path: &str, device: &B::Device) -> Result<ResUpSample<B>, Box<dyn Error>> { path: &str,
device: &B::Device,
) -> Result<ResUpSample<B>, Box<dyn Error>> {
let res = load_res_block::<B>(&format!("{}/{}", path, "res"), device)?; let res = load_res_block::<B>(&format!("{}/{}", path, "res"), device)?;
let upsample = load_upsample::<B>(&format!("{}/{}", path, "upsample"), device)?; let upsample = load_upsample::<B>(&format!("{}/{}", path, "upsample"), device)?;
@@ -176,10 +194,13 @@ pub fn load_res_upsample<B: Backend>(path: &str, device: &B::Device) -> Result<R
Ok(res_upsample) Ok(res_upsample)
} }
pub fn load_res_transformer<B: Backend>(
pub fn load_res_transformer<B: Backend>(path: &str, device: &B::Device) -> Result<ResTransformer<B>, Box<dyn Error>> { path: &str,
device: &B::Device,
) -> Result<ResTransformer<B>, Box<dyn Error>> {
let res = load_res_block::<B>(&format!("{}/{}", path, "res"), device)?; let res = load_res_block::<B>(&format!("{}/{}", path, "res"), device)?;
let transformer = load_spatial_transformer::<B>(&format!("{}/{}", path, "transformer"), device)?; let transformer =
load_spatial_transformer::<B>(&format!("{}/{}", path, "transformer"), device)?;
let res_transformer = ResTransformer { let res_transformer = ResTransformer {
res: res, res: res,
@@ -189,8 +210,10 @@ pub fn load_res_transformer<B: Backend>(path: &str, device: &B::Device) -> Resul
Ok(res_transformer) Ok(res_transformer)
} }
pub fn load_unet_input_blocks<B: Backend>(
pub fn load_unet_input_blocks<B: Backend>(path: &str, device: &B::Device) -> Result<UNetInputBlocks<B>, Box<dyn Error>> { path: &str,
device: &B::Device,
) -> Result<UNetInputBlocks<B>, Box<dyn Error>> {
let conv = load_conv2d::<B>(&format!("{}/{}", path, "conv"), device)?; let conv = load_conv2d::<B>(&format!("{}/{}", path, "conv"), device)?;
let rt1 = load_res_transformer::<B>(&format!("{}/{}", path, "rt1"), device)?; let rt1 = load_res_transformer::<B>(&format!("{}/{}", path, "rt1"), device)?;
let rt2 = load_res_transformer::<B>(&format!("{}/{}", path, "rt2"), device)?; let rt2 = load_res_transformer::<B>(&format!("{}/{}", path, "rt2"), device)?;
@@ -222,7 +245,10 @@ pub fn load_unet_input_blocks<B: Backend>(path: &str, device: &B::Device) -> Res
Ok(unet_input_blocks) Ok(unet_input_blocks)
} }
pub fn load_unet_output_blocks<B: Backend>(path: &str, device: &B::Device) -> Result<UNetOutputBlocks<B>, Box<dyn Error>> { pub fn load_unet_output_blocks<B: Backend>(
path: &str,
device: &B::Device,
) -> Result<UNetOutputBlocks<B>, Box<dyn Error>> {
let r1 = load_res_block::<B>(&format!("{}/{}", path, "r1"), device)?; let r1 = load_res_block::<B>(&format!("{}/{}", path, "r1"), device)?;
let r2 = load_res_block::<B>(&format!("{}/{}", path, "r2"), device)?; let r2 = load_res_block::<B>(&format!("{}/{}", path, "r2"), device)?;
let ru = load_res_upsample::<B>(&format!("{}/{}", path, "ru"), device)?; let ru = load_res_upsample::<B>(&format!("{}/{}", path, "ru"), device)?;
@@ -252,14 +278,16 @@ pub fn load_unet_output_blocks<B: Backend>(path: &str, device: &B::Device) -> Re
}) })
} }
pub fn load_unet<B: Backend>(path: &str, device: &B::Device) -> Result<UNet<B>, Box<dyn Error>> { pub fn load_unet<B: Backend>(path: &str, device: &B::Device) -> Result<UNet<B>, Box<dyn Error>> {
let lin1_time_embed = load_linear::<B>(&format!("{}/{}", path, "lin1_time_embed"), device)?; let lin1_time_embed = load_linear::<B>(&format!("{}/{}", path, "lin1_time_embed"), device)?;
let silu_time_embed = SILU::new(); // Assuming SILU::new() initializes a new SILU struct let silu_time_embed = SILU::new(); // Assuming SILU::new() initializes a new SILU struct
let lin2_time_embed = load_linear::<B>(&format!("{}/{}", path, "lin2_time_embed"), device)?; let lin2_time_embed = load_linear::<B>(&format!("{}/{}", path, "lin2_time_embed"), device)?;
let input_blocks = load_unet_input_blocks::<B>(&format!("{}/{}", path, "input_blocks"), device)?; let input_blocks =
let middle_block = load_res_transformer_res::<B>(&format!("{}/{}", path, "middle_block"), device)?; load_unet_input_blocks::<B>(&format!("{}/{}", path, "input_blocks"), device)?;
let output_blocks = load_unet_output_blocks::<B>(&format!("{}/{}", path, "output_blocks"), device)?; let middle_block =
load_res_transformer_res::<B>(&format!("{}/{}", path, "middle_block"), device)?;
let output_blocks =
load_unet_output_blocks::<B>(&format!("{}/{}", path, "output_blocks"), device)?;
let norm_out = load_group_norm::<B>(&format!("{}/{}", path, "norm_out"), device)?; let norm_out = load_group_norm::<B>(&format!("{}/{}", path, "norm_out"), device)?;
let silu_out = SILU::new(); // Assuming SILU::new() initializes a new SILU struct let silu_out = SILU::new(); // Assuming SILU::new() initializes a new SILU struct
let conv_out = load_conv2d::<B>(&format!("{}/{}", path, "conv_out"), device)?; let conv_out = load_conv2d::<B>(&format!("{}/{}", path, "conv_out"), device)?;

View File

@@ -3,76 +3,80 @@ pub mod load;
use burn::{ use burn::{
config::Config, config::Config,
module::{Module, Param}, module::{Module, Param},
nn::{self, PaddingConfig2d, GELU, conv::{Conv2d, Conv2dConfig}}, nn::{
tensor::{ self,
backend::Backend, conv::{Conv2d, Conv2dConfig},
activation::softmax, PaddingConfig2d, Gelu,
module::embedding,
Tensor,
Distribution,
Int,
}, },
tensor::{activation::softmax, backend::Backend, module::embedding, Distribution, Int, Tensor},
}; };
use super::silu::*;
use super::groupnorm::*; use super::groupnorm::*;
use crate::helper::to_float; use super::silu::*;
use super::attention::qkv_attention; use super::attention::qkv_attention;
fn timestep_embedding<B: Backend>(
fn timestep_embedding<B: Backend>(timesteps: Tensor<B, 1, Int>, dim: usize, max_period: usize) -> Tensor<B, 2> { timesteps: Tensor<B, 1, Int>,
dim: usize,
max_period: usize,
) -> Tensor<B, 2> {
let half = dim / 2; let half = dim / 2;
let freqs = ( to_float(Tensor::arange_device(0..half, &timesteps.device())) * (-(max_period as f64).ln() / half as f64 ) ).exp(); let freqs = (Tensor::arange(0..half as i64, &timesteps.device()).float()
let args = to_float(timesteps) * freqs; * (-(max_period as f64).ln() / half as f64))
.exp();
let args = timesteps.float() * freqs;
Tensor::cat(vec![args.clone().cos(), args.sin()], 0).unsqueeze() Tensor::cat(vec![args.clone().cos(), args.sin()], 0).unsqueeze()
} }
#[derive(Config, Debug)]
#[derive(Config)]
pub struct UNetConfig {} pub struct UNetConfig {}
impl UNetConfig { impl UNetConfig {
pub fn init<B: Backend>(&self) -> UNet<B> { pub fn init<B: Backend>(&self, device: &B::Device) -> UNet<B> {
let lin1_time_embed = nn::LinearConfig::new(320, 1280).init(); let lin1_time_embed = nn::LinearConfig::new(320, 1280).init(device);
let silu_time_embed = SILU::new(); let silu_time_embed = SILU::new();
let lin2_time_embed = nn::LinearConfig::new(1280, 1280).init(); let lin2_time_embed = nn::LinearConfig::new(1280, 1280).init(device);
let input_blocks = UNetInputBlocks { let input_blocks = UNetInputBlocks {
conv: Conv2dConfig::new([4, 320], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init(), conv: Conv2dConfig::new([4, 320], [3, 3])
rt1: ResTransformerConfig::new(320, 1280, 320, 768, 8).init(), .with_padding(PaddingConfig2d::Explicit(1, 1))
rt2: ResTransformerConfig::new(320, 1280, 320, 768, 8).init(), .init(device),
d1: DownsampleConfig::new(320).init(), rt1: ResTransformerConfig::new(320, 1280, 320, 768, 8).init(device),
rt3: ResTransformerConfig::new(320, 1280, 640, 768, 8).init(), rt2: ResTransformerConfig::new(320, 1280, 320, 768, 8).init(device),
rt4: ResTransformerConfig::new(640, 1280, 640, 768, 8).init(), d1: DownsampleConfig::new(320).init(device),
d2: DownsampleConfig::new(640).init(), rt3: ResTransformerConfig::new(320, 1280, 640, 768, 8).init(device),
rt5: ResTransformerConfig::new(640, 1280, 1280, 768, 8).init(), rt4: ResTransformerConfig::new(640, 1280, 640, 768, 8).init(device),
rt6: ResTransformerConfig::new(1280, 1280, 1280, 768, 8).init(), d2: DownsampleConfig::new(640).init(device),
d3: DownsampleConfig::new(1280).init(), rt5: ResTransformerConfig::new(640, 1280, 1280, 768, 8).init(device),
r1: ResBlockConfig::new(1280, 1280, 1280).init(), rt6: ResTransformerConfig::new(1280, 1280, 1280, 768, 8).init(device),
r2: ResBlockConfig::new(1280, 1280, 1280).init(), d3: DownsampleConfig::new(1280).init(device),
r1: ResBlockConfig::new(1280, 1280, 1280).init(device),
r2: ResBlockConfig::new(1280, 1280, 1280).init(device),
}; };
let middle_block = ResTransformerResConfig::new(1280, 1280, 1280, 768, 8).init(); let middle_block = ResTransformerResConfig::new(1280, 1280, 1280, 768, 8).init(device);
let output_blocks = UNetOutputBlocks { let output_blocks = UNetOutputBlocks {
r1: ResBlockConfig::new(2560, 1280, 1280).init(), r1: ResBlockConfig::new(2560, 1280, 1280).init(device),
r2: ResBlockConfig::new(2560, 1280, 1280).init(), r2: ResBlockConfig::new(2560, 1280, 1280).init(device),
ru: ResUpSampleConfig::new(2560, 1280, 1280).init(), ru: ResUpSampleConfig::new(2560, 1280, 1280).init(device),
rt1: ResTransformerConfig::new(2560, 1280, 1280, 768, 8).init(), rt1: ResTransformerConfig::new(2560, 1280, 1280, 768, 8).init(device),
rt2: ResTransformerConfig::new(2560, 1280, 1280, 768, 8).init(), rt2: ResTransformerConfig::new(2560, 1280, 1280, 768, 8).init(device),
rtu1: ResTransformerUpsampleConfig::new(1920, 1280, 1280, 768, 8).init(), rtu1: ResTransformerUpsampleConfig::new(1920, 1280, 1280, 768, 8).init(device),
rt3: ResTransformerConfig::new(1920, 1280, 640, 768, 8).init(), rt3: ResTransformerConfig::new(1920, 1280, 640, 768, 8).init(device),
rt4: ResTransformerConfig::new(1280, 1280, 640, 768, 8).init(), rt4: ResTransformerConfig::new(1280, 1280, 640, 768, 8).init(device),
rtu2: ResTransformerUpsampleConfig::new(960, 1280, 640, 768, 8).init(), rtu2: ResTransformerUpsampleConfig::new(960, 1280, 640, 768, 8).init(device),
rt5: ResTransformerConfig::new(960, 1280, 320, 768, 8).init(), rt5: ResTransformerConfig::new(960, 1280, 320, 768, 8).init(device),
rt6: ResTransformerConfig::new(640, 1280, 320, 768, 8).init(), rt6: ResTransformerConfig::new(640, 1280, 320, 768, 8).init(device),
rt7: ResTransformerConfig::new(640, 1280, 320, 768, 8).init(), rt7: ResTransformerConfig::new(640, 1280, 320, 768, 8).init(device),
}; };
let norm_out = GroupNormConfig::new(32, 320).init(); let norm_out = GroupNormConfig::new(32, 320).init(device);
let silu_out = SILU::new(); let silu_out = SILU::new();
let conv_out = Conv2dConfig::new([320, 4], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init(); let conv_out = Conv2dConfig::new([320, 4], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init(device);
UNet { UNet {
lin1_time_embed, lin1_time_embed,
@@ -102,7 +106,12 @@ pub struct UNet<B: Backend> {
} }
impl<B: Backend> UNet<B> { impl<B: Backend> UNet<B> {
pub fn forward(&self, x: Tensor<B, 4>, timesteps: Tensor<B, 1, Int>, context: Tensor<B, 3>) -> Tensor<B, 4> { pub fn forward(
&self,
x: Tensor<B, 4>,
timesteps: Tensor<B, 1, Int>,
context: Tensor<B, 3>,
) -> Tensor<B, 4> {
let t_emb = timestep_embedding(timesteps, 320, 10000); let t_emb = timestep_embedding(timesteps, 320, 10000);
let emb = self.lin1_time_embed.forward(t_emb); let emb = self.lin1_time_embed.forward(t_emb);
let emb = self.silu_time_embed.forward(emb); let emb = self.silu_time_embed.forward(emb);
@@ -133,8 +142,6 @@ impl<B: Backend> UNet<B> {
} }
} }
#[derive(Module, Debug)] #[derive(Module, Debug)]
pub struct UNetInputBlocks<B: Backend> { pub struct UNetInputBlocks<B: Backend> {
conv: Conv2d<B>, conv: Conv2d<B>,
@@ -154,18 +161,8 @@ pub struct UNetInputBlocks<B: Backend> {
impl<B: Backend> UNetInputBlocks<B> { impl<B: Backend> UNetInputBlocks<B> {
fn as_array(&self) -> [&dyn UNetBlock<B>; 12] { fn as_array(&self) -> [&dyn UNetBlock<B>; 12] {
[ [
&self.conv, &self.conv, &self.rt1, &self.rt2, &self.d1, &self.rt3, &self.rt4, &self.d2, &self.rt5,
&self.rt1, &self.rt6, &self.d3, &self.r1, &self.r2,
&self.rt2,
&self.d1,
&self.rt3,
&self.rt4,
&self.d2,
&self.rt5,
&self.rt6,
&self.d3,
&self.r1,
&self.r2,
] ]
} }
} }
@@ -189,31 +186,17 @@ pub struct UNetOutputBlocks<B: Backend> {
impl<B: Backend> UNetOutputBlocks<B> { impl<B: Backend> UNetOutputBlocks<B> {
fn as_array(&self) -> [&dyn UNetBlock<B>; 12] { fn as_array(&self) -> [&dyn UNetBlock<B>; 12] {
[ [
&self.r1, &self.r1, &self.r2, &self.ru, &self.rt1, &self.rt2, &self.rtu1, &self.rt3, &self.rt4,
&self.r2, &self.rtu2, &self.rt5, &self.rt6, &self.rt7,
&self.ru,
&self.rt1,
&self.rt2,
&self.rtu1,
&self.rt3,
&self.rt4,
&self.rtu2,
&self.rt5,
&self.rt6,
&self.rt7,
] ]
} }
} }
trait UNetBlock<B: Backend> { trait UNetBlock<B: Backend> {
fn forward(&self, x: Tensor<B, 4>, emb: Tensor<B, 2>, context: Tensor<B, 3>) -> Tensor<B, 4>; fn forward(&self, x: Tensor<B, 4>, emb: Tensor<B, 2>, context: Tensor<B, 3>) -> Tensor<B, 4>;
} }
#[derive(Config)] #[derive(Config, Debug)]
pub struct ResTransformerConfig { pub struct ResTransformerConfig {
n_channels_in: usize, n_channels_in: usize,
n_channels_embed: usize, n_channels_embed: usize,
@@ -223,14 +206,18 @@ pub struct ResTransformerConfig {
} }
impl ResTransformerConfig { impl ResTransformerConfig {
fn init<B: Backend>(&self) -> ResTransformer<B> { fn init<B: Backend>(&self, device: &B::Device) -> ResTransformer<B> {
let res = ResBlockConfig::new(self.n_channels_in, self.n_channels_embed, self.n_channels_out).init(); let res = ResBlockConfig::new(
let transformer = SpatialTransformerConfig::new(self.n_channels_out, self.n_context_state, self.n_head).init(); self.n_channels_in,
self.n_channels_embed,
self.n_channels_out,
)
.init(device);
let transformer =
SpatialTransformerConfig::new(self.n_channels_out, self.n_context_state, self.n_head)
.init(device);
ResTransformer { ResTransformer { res, transformer }
res,
transformer,
}
} }
} }
@@ -248,7 +235,7 @@ impl<B: Backend> UNetBlock<B> for ResTransformer<B> {
} }
} }
#[derive(Config)] #[derive(Config, Debug)]
pub struct ResUpSampleConfig { pub struct ResUpSampleConfig {
n_channels_in: usize, n_channels_in: usize,
n_channels_embed: usize, n_channels_embed: usize,
@@ -256,14 +243,16 @@ pub struct ResUpSampleConfig {
} }
impl ResUpSampleConfig { impl ResUpSampleConfig {
fn init<B: Backend>(&self) -> ResUpSample<B> { fn init<B: Backend>(&self, device: &B::Device) -> ResUpSample<B> {
let res = ResBlockConfig::new(self.n_channels_in, self.n_channels_embed, self.n_channels_out).init(); let res = ResBlockConfig::new(
let upsample = UpsampleConfig::new(self.n_channels_out).init(); self.n_channels_in,
self.n_channels_embed,
self.n_channels_out,
)
.init(device);
let upsample = UpsampleConfig::new(self.n_channels_out).init(device);
ResUpSample { ResUpSample { res, upsample }
res,
upsample,
}
} }
} }
@@ -281,7 +270,7 @@ impl<B: Backend> UNetBlock<B> for ResUpSample<B> {
} }
} }
#[derive(Config)] #[derive(Config, Debug)]
pub struct ResTransformerUpsampleConfig { pub struct ResTransformerUpsampleConfig {
n_channels_in: usize, n_channels_in: usize,
n_channels_embed: usize, n_channels_embed: usize,
@@ -291,10 +280,17 @@ pub struct ResTransformerUpsampleConfig {
} }
impl ResTransformerUpsampleConfig { impl ResTransformerUpsampleConfig {
fn init<B: Backend>(&self) -> ResTransformerUpsample<B> { fn init<B: Backend>(&self, device: &B::Device) -> ResTransformerUpsample<B> {
let res = ResBlockConfig::new(self.n_channels_in, self.n_channels_embed, self.n_channels_out).init(); let res = ResBlockConfig::new(
let transformer = SpatialTransformerConfig::new(self.n_channels_out, self.n_context_state, self.n_head).init(); self.n_channels_in,
let upsample = UpsampleConfig::new(self.n_channels_out).init(); self.n_channels_embed,
self.n_channels_out,
)
.init(device);
let transformer =
SpatialTransformerConfig::new(self.n_channels_out, self.n_context_state, self.n_head)
.init(device);
let upsample = UpsampleConfig::new(self.n_channels_out).init(device);
ResTransformerUpsample { ResTransformerUpsample {
res, res,
@@ -320,7 +316,7 @@ impl<B: Backend> UNetBlock<B> for ResTransformerUpsample<B> {
} }
} }
#[derive(Config)] #[derive(Config, Debug)]
pub struct ResTransformerResConfig { pub struct ResTransformerResConfig {
n_channels_in: usize, n_channels_in: usize,
n_channels_embed: usize, n_channels_embed: usize,
@@ -330,10 +326,22 @@ pub struct ResTransformerResConfig {
} }
impl ResTransformerResConfig { impl ResTransformerResConfig {
fn init<B: Backend>(&self) -> ResTransformerRes<B> { fn init<B: Backend>(&self, device: &B::Device) -> ResTransformerRes<B> {
let res1 = ResBlockConfig::new(self.n_channels_in, self.n_channels_embed, self.n_channels_out).init(); let res1 = ResBlockConfig::new(
let transformer = SpatialTransformerConfig::new(self.n_channels_out, self.n_context_state, self.n_head).init(); self.n_channels_in,
let res2 = ResBlockConfig::new(self.n_channels_in, self.n_channels_embed, self.n_channels_out).init(); self.n_channels_embed,
self.n_channels_out,
)
.init(device);
let transformer =
SpatialTransformerConfig::new(self.n_channels_out, self.n_context_state, self.n_head)
.init(device);
let res2 = ResBlockConfig::new(
self.n_channels_in,
self.n_channels_embed,
self.n_channels_out,
)
.init(device);
ResTransformerRes { ResTransformerRes {
res1, res1,
@@ -359,22 +367,18 @@ impl<B: Backend> UNetBlock<B> for ResTransformerRes<B> {
} }
} }
#[derive(Config, Debug)]
#[derive(Config)]
pub struct UpsampleConfig { pub struct UpsampleConfig {
n_channels: usize, n_channels: usize,
} }
impl UpsampleConfig { impl UpsampleConfig {
fn init<B: Backend>(&self) -> Upsample<B> { fn init<B: Backend>(&self, device: &B::Device) -> Upsample<B> {
let conv = Conv2dConfig::new([self.n_channels, self.n_channels], [3, 3]) let conv = Conv2dConfig::new([self.n_channels, self.n_channels], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1)) .with_padding(PaddingConfig2d::Explicit(1, 1))
.init(); .init(device);
Upsample { Upsample { conv }
conv,
}
} }
} }
@@ -387,10 +391,9 @@ impl<B: Backend> Upsample<B> {
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> { fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
let [n_batch, n_channel, height, width] = x.dims(); let [n_batch, n_channel, height, width] = x.dims();
let x = x let x = x
.reshape([n_batch, n_channel, height, 1, width, 1]) .reshape([n_batch, n_channel, height, 1, width, 1])
.repeat(3, 2) .repeat(&[1, 1, 1, 2, 1, 2])
.repeat(5, 2) .reshape([n_batch, n_channel, 2 * height, 2 * width]);
.reshape([n_batch, n_channel, 2 * height, 2 * width]);
self.conv.forward(x) self.conv.forward(x)
} }
} }
@@ -401,17 +404,17 @@ impl<B: Backend> UNetBlock<B> for Upsample<B> {
} }
} }
#[derive(Config)] #[derive(Config, Debug)]
pub struct DownsampleConfig { pub struct DownsampleConfig {
n_channels: usize, n_channels: usize,
} }
impl DownsampleConfig { impl DownsampleConfig {
fn init<B: Backend>(&self) -> Conv2d<B> { fn init<B: Backend>(&self, device: &B::Device) -> Conv2d<B> {
Conv2dConfig::new([self.n_channels, self.n_channels], [3, 3]) Conv2dConfig::new([self.n_channels, self.n_channels], [3, 3])
.with_stride([2, 2]) .with_stride([2, 2])
.with_padding(PaddingConfig2d::Explicit(1, 1)) .with_padding(PaddingConfig2d::Explicit(1, 1))
.init() .init(device)
} }
} }
@@ -423,10 +426,7 @@ impl<B: Backend> UNetBlock<B> for Conv2d<B> {
} }
} }
#[derive(Config, Debug)]
#[derive(Config)]
pub struct SpatialTransformerConfig { pub struct SpatialTransformerConfig {
n_channels: usize, n_channels: usize,
n_context_state: usize, n_context_state: usize,
@@ -434,11 +434,12 @@ pub struct SpatialTransformerConfig {
} }
impl SpatialTransformerConfig { impl SpatialTransformerConfig {
fn init<B: Backend>(&self) -> SpatialTransformer<B> { fn init<B: Backend>(&self, device: &B::Device) -> SpatialTransformer<B> {
let norm = GroupNormConfig::new(32, self.n_channels).init(); let norm = GroupNormConfig::new(32, self.n_channels).init(device);
let proj_in = Conv2dConfig::new([self.n_channels, self.n_channels], [1, 1]).init(); let proj_in = Conv2dConfig::new([self.n_channels, self.n_channels], [1, 1]).init(device);
let transformer = TransformerBlockConfig::new(self.n_channels, self.n_context_state, self.n_head).init(); let transformer =
let proj_out = Conv2dConfig::new([self.n_channels, self.n_channels], [1, 1]).init(); TransformerBlockConfig::new(self.n_channels, self.n_context_state, self.n_head).init(device);
let proj_out = Conv2dConfig::new([self.n_channels, self.n_channels], [1, 1]).init(device);
SpatialTransformer { SpatialTransformer {
norm, norm,
@@ -465,9 +466,13 @@ impl<B: Backend> SpatialTransformer<B> {
let x = self.norm.forward(x); let x = self.norm.forward(x);
let x = self.proj_in.forward(x); let x = self.proj_in.forward(x);
let x = x.reshape([n_batch, n_channel, height * width]).swap_dims(1, 2); let x = x
.reshape([n_batch, n_channel, height * width])
.swap_dims(1, 2);
let x = self.transformer.forward(x, context) let x = self
.transformer
.forward(x, context)
.swap_dims(1, 2) .swap_dims(1, 2)
.reshape([n_batch, n_channel, height, width]); .reshape([n_batch, n_channel, height, width]);
@@ -475,14 +480,7 @@ impl<B: Backend> SpatialTransformer<B> {
} }
} }
#[derive(Config, Debug)]
#[derive(Config)]
pub struct TransformerBlockConfig { pub struct TransformerBlockConfig {
n_state: usize, n_state: usize,
n_context_state: usize, n_context_state: usize,
@@ -490,13 +488,14 @@ pub struct TransformerBlockConfig {
} }
impl TransformerBlockConfig { impl TransformerBlockConfig {
fn init<B: Backend>(&self) -> TransformerBlock<B> { fn init<B: Backend>(&self, device: &B::Device) -> TransformerBlock<B> {
let norm1 = nn::LayerNormConfig::new(self.n_state).init(); let norm1 = nn::LayerNormConfig::new(self.n_state).init(device);
let attn1 = MultiHeadAttentionConfig::new(self.n_state, self.n_state, self.n_head).init(); let attn1 = MultiHeadAttentionConfig::new(self.n_state, self.n_state, self.n_head).init(device);
let norm2 = nn::LayerNormConfig::new(self.n_state).init(); let norm2 = nn::LayerNormConfig::new(self.n_state).init(device);
let attn2 = MultiHeadAttentionConfig::new(self.n_state, self.n_context_state, self.n_head).init(); let attn2 =
let norm3 = nn::LayerNormConfig::new(self.n_state).init(); MultiHeadAttentionConfig::new(self.n_state, self.n_context_state, self.n_head).init(device);
let mlp = MLPConfig::new(self.n_state, 4).init(); let norm3 = nn::LayerNormConfig::new(self.n_state).init(device);
let mlp = MLPConfig::new(self.n_state, 4).init(device);
TransformerBlock { TransformerBlock {
norm1, norm1,
@@ -521,29 +520,25 @@ pub struct TransformerBlock<B: Backend> {
impl<B: Backend> TransformerBlock<B> { impl<B: Backend> TransformerBlock<B> {
fn forward(&self, x: Tensor<B, 3>, context: Tensor<B, 3>) -> Tensor<B, 3> { fn forward(&self, x: Tensor<B, 3>, context: Tensor<B, 3>) -> Tensor<B, 3> {
let x = x.clone() + self.attn1.forward( self.norm1.forward(x), None); let x = x.clone() + self.attn1.forward(self.norm1.forward(x), None);
let x = x.clone() + self.attn2.forward( self.norm2.forward(x), Some(context)); let x = x.clone() + self.attn2.forward(self.norm2.forward(x), Some(context));
x.clone() + self.mlp.forward( self.norm3.forward(x) ) x.clone() + self.mlp.forward(self.norm3.forward(x))
} }
} }
#[derive(Config, Debug)]
#[derive(Config)]
pub struct MLPConfig { pub struct MLPConfig {
n_state: usize, n_state: usize,
mult: usize, mult: usize,
} }
impl MLPConfig { impl MLPConfig {
pub fn init<B: Backend>(&self) -> MLP<B> { pub fn init<B: Backend>(&self, device: &B::Device) -> MLP<B> {
let n_state_hidden = self.n_state * self.mult; let n_state_hidden = self.n_state * self.mult;
let geglu = GEGLUConfig::new(self.n_state, n_state_hidden).init(); let geglu = GEGLUConfig::new(self.n_state, n_state_hidden).init(device);
let lin = nn::LinearConfig::new(n_state_hidden, self.n_state).init(); let lin = nn::LinearConfig::new(n_state_hidden, self.n_state).init(device);
MLP { MLP { geglu, lin }
geglu,
lin,
}
} }
} }
@@ -555,33 +550,29 @@ pub struct MLP<B: Backend> {
impl<B: Backend> MLP<B> { impl<B: Backend> MLP<B> {
pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> { pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
self.lin.forward( self.geglu.forward(x) ) self.lin.forward(self.geglu.forward(x))
} }
} }
#[derive(Config, Debug)]
#[derive(Config)]
pub struct GEGLUConfig { pub struct GEGLUConfig {
n_state_in: usize, n_state_in: usize,
n_state_out: usize, n_state_out: usize,
} }
impl GEGLUConfig { impl GEGLUConfig {
fn init<B: Backend>(&self) -> GEGLU<B> { fn init<B: Backend>(&self, device: &B::Device) -> GEGLU<B> {
let proj = nn::LinearConfig::new(self.n_state_in, 2 * self.n_state_out).init(); let proj = nn::LinearConfig::new(self.n_state_in, 2 * self.n_state_out).init(device);
let gelu = GELU::new(); let gelu = Gelu::new();
GEGLU { GEGLU { proj, gelu }
proj,
gelu,
}
} }
} }
#[derive(Module, Debug)] #[derive(Module, Debug)]
pub struct GEGLU<B: Backend> { pub struct GEGLU<B: Backend> {
proj: nn::Linear<B>, proj: nn::Linear<B>,
gelu: GELU, gelu: Gelu,
} }
impl<B: Backend> GEGLU<B> { impl<B: Backend> GEGLU<B> {
@@ -591,18 +582,16 @@ impl<B: Backend> GEGLU<B> {
let n_state_out = n_state / 2; let n_state_out = n_state / 2;
let x = projected.clone().slice([0..n_batch, 0..n_ctx, 0..n_state_out]); let x = projected
.clone()
.slice([0..n_batch, 0..n_ctx, 0..n_state_out]);
let gate = projected.slice([0..n_batch, 0..n_ctx, n_state_out..n_state]); let gate = projected.slice([0..n_batch, 0..n_ctx, n_state_out..n_state]);
x * self.gelu.forward(gate) x * self.gelu.forward(gate)
} }
} }
#[derive(Config, Debug)]
#[derive(Config)]
pub struct MultiHeadAttentionConfig { pub struct MultiHeadAttentionConfig {
n_state: usize, n_state: usize,
n_context_state: usize, n_context_state: usize,
@@ -610,21 +599,32 @@ pub struct MultiHeadAttentionConfig {
} }
impl MultiHeadAttentionConfig { impl MultiHeadAttentionConfig {
fn init<B: Backend>(&self) -> MultiHeadAttention<B> { fn init<B: Backend>(&self, device: &B::Device) -> MultiHeadAttention<B> {
assert!(self.n_state % self.n_head == 0, "State size {} must be a multiple of head size {}", self.n_state, self.n_head); assert!(
self.n_state % self.n_head == 0,
"State size {} must be a multiple of head size {}",
self.n_state,
self.n_head
);
let n_head = self.n_head; let n_head = self.n_head;
let query = nn::LinearConfig::new(self.n_state, self.n_state).with_bias(false).init(); let query = nn::LinearConfig::new(self.n_state, self.n_state)
let key = nn::LinearConfig::new(self.n_context_state, self.n_state).with_bias(false).init(); .with_bias(false)
let value = nn::LinearConfig::new(self.n_context_state, self.n_state).with_bias(false).init(); .init(device);
let out = nn::LinearConfig::new(self.n_state, self.n_state).init(); let key = nn::LinearConfig::new(self.n_context_state, self.n_state)
.with_bias(false)
.init(device);
let value = nn::LinearConfig::new(self.n_context_state, self.n_state)
.with_bias(false)
.init(device);
let out = nn::LinearConfig::new(self.n_state, self.n_state).init(device);
MultiHeadAttention { MultiHeadAttention {
n_head, n_head,
query, query,
key, key,
value, value,
out out,
} }
} }
} }
@@ -652,44 +652,32 @@ impl<B: Backend> MultiHeadAttention<B> {
} }
} }
#[derive(Config, Debug)]
#[derive(Config)]
pub struct ResBlockConfig { pub struct ResBlockConfig {
n_channels_in: usize, n_channels_in: usize,
n_channels_embed: usize, n_channels_embed: usize,
n_channels_out: usize, n_channels_out: usize,
} }
impl ResBlockConfig { impl ResBlockConfig {
fn init<B: Backend>(&self) -> ResBlock<B> { fn init<B: Backend>(&self, device: &B::Device) -> ResBlock<B> {
let norm_in = GroupNormConfig::new(32, self.n_channels_in).init(); let norm_in = GroupNormConfig::new(32, self.n_channels_in).init(device);
let silu_in = SILU::new(); let silu_in = SILU::new();
let conv_in = Conv2dConfig::new([self.n_channels_in, self.n_channels_out], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init(); let conv_in = Conv2dConfig::new([self.n_channels_in, self.n_channels_out], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init(device);
let silu_embed = SILU::new(); let silu_embed = SILU::new();
let lin_embed = nn::LinearConfig::new(self.n_channels_embed, self.n_channels_out).init(); let lin_embed = nn::LinearConfig::new(self.n_channels_embed, self.n_channels_out).init(device);
let norm_out = GroupNormConfig::new(32, self.n_channels_out).init(); let norm_out = GroupNormConfig::new(32, self.n_channels_out).init(device);
let silu_out = SILU::new(); let silu_out = SILU::new();
let conv_out = Conv2dConfig::new([self.n_channels_out, self.n_channels_out], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init(); let conv_out = Conv2dConfig::new([self.n_channels_out, self.n_channels_out], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init(device);
let skip_connection = if self.n_channels_in != self.n_channels_out { let skip_connection = if self.n_channels_in != self.n_channels_out {
Some( Conv2dConfig::new([self.n_channels_in, self.n_channels_out], [1, 1]).init() ) Some(Conv2dConfig::new([self.n_channels_in, self.n_channels_out], [1, 1]).init(device))
} else { } else {
None None
}; };
@@ -708,7 +696,6 @@ impl ResBlockConfig {
} }
} }
#[derive(Module, Debug)] #[derive(Module, Debug)]
pub struct ResBlock<B: Backend> { pub struct ResBlock<B: Backend> {
norm_in: GroupNorm<B>, norm_in: GroupNorm<B>,
@@ -751,5 +738,3 @@ impl<B: Backend> UNetBlock<B> for ResBlock<B> {
self.forward(x, emb) self.forward(x, emb)
} }
} }

View File

@@ -1,13 +1,14 @@
use std::collections::HashMap;
use regex::Regex; use regex::Regex;
use std::collections::HashMap;
use std::fs::File; use std::fs::File;
use std::io::{self, BufRead}; use std::io::{self, BufRead};
fn bytes_to_unicode() -> Vec<(u8, char)> { fn bytes_to_unicode() -> Vec<(u8, char)> {
let mut bs: Vec<u8> = ('!' as u8 ..= '~' as u8).into_iter() let mut bs: Vec<u8> = ('!' as u8..='~' as u8)
.chain( ('¡' as u8..='¬' as u8).into_iter() ) .into_iter()
.chain( ('®' as u8..='ÿ' as u8).into_iter() ) .chain(('¡' as u8..='¬' as u8).into_iter())
.chain(('®' as u8..='ÿ' as u8).into_iter())
.collect(); .collect();
let mut cs: Vec<_> = bs.iter().cloned().map(char::from).collect(); let mut cs: Vec<_> = bs.iter().cloned().map(char::from).collect();
@@ -16,25 +17,21 @@ fn bytes_to_unicode() -> Vec<(u8, char)> {
for b in 0u8..=255u8 { for b in 0u8..=255u8 {
if !bs.contains(&b) { if !bs.contains(&b) {
bs.push(b); bs.push(b);
cs.push( char::from_u32(256 + n).unwrap() ); cs.push(char::from_u32(256 + n).unwrap());
n += 1; n += 1;
} }
} }
bs.into_iter() bs.into_iter()
.zip( .zip(cs.into_iter().map(|c| c.into()))
cs.into_iter() .collect()
.map(|c| c.into())
).collect()
} }
fn get_pairs(word: &[String]) -> Vec<(String, String)> { fn get_pairs(word: &[String]) -> Vec<(String, String)> {
let prev = word.into_iter().cloned(); let prev = word.into_iter().cloned();
let next = prev.clone().skip(1); let next = prev.clone().skip(1);
prev prev.zip(next).collect()
.zip(next)
.collect()
} }
fn whitespace_clean(text: &str) -> String { fn whitespace_clean(text: &str) -> String {
@@ -59,9 +56,12 @@ fn load_merges(path: &str) -> io::Result<Vec<(String, String)>> {
Ok(merges) Ok(merges)
} }
fn construct_vocab(chars: impl Iterator<Item=char> + Clone, merges: &[(String, String)]) -> Vec<String> { fn construct_vocab(
chars: impl Iterator<Item = char> + Clone,
merges: &[(String, String)],
) -> Vec<String> {
let iter = chars.map(String::from); let iter = chars.map(String::from);
let mut vocab: Vec<_> = iter.clone().chain( iter.map(|c| c + "</w>") ).collect(); let mut vocab: Vec<_> = iter.clone().chain(iter.map(|c| c + "</w>")).collect();
for merge in merges { for merge in merges {
vocab.push(format!("{}{}", merge.0, merge.1)); vocab.push(format!("{}{}", merge.0, merge.1));
@@ -87,10 +87,10 @@ impl SimpleTokenizer {
let byte_unicode_values = bytes_to_unicode(); let byte_unicode_values = bytes_to_unicode();
let byte_encoder: HashMap<_, _> = byte_unicode_values.iter().cloned().collect(); let byte_encoder: HashMap<_, _> = byte_unicode_values.iter().cloned().collect();
let byte_decoder = byte_encoder.iter().map(|(k,v)| (*v,*k)).collect(); let byte_decoder = byte_encoder.iter().map(|(k, v)| (*v, *k)).collect();
let merges = load_merges("bpe_simple_vocab_16e6.txt")?; let merges = load_merges("bpe_simple_vocab_16e6.txt")?;
let merges = merges[1..49152-256-2+1].to_vec(); let merges = merges[1..49152 - 256 - 2 + 1].to_vec();
let vocab = construct_vocab(byte_unicode_values.into_iter().map(|(_, u)| u), &merges[..]); let vocab = construct_vocab(byte_unicode_values.into_iter().map(|(_, u)| u), &merges[..]);
@@ -104,7 +104,7 @@ impl SimpleTokenizer {
let pat = Regex::new(r"(?i)<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|\p{L}+|\p{N}|[^\s\p{L}\p{N}]+").unwrap(); let pat = Regex::new(r"(?i)<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|\p{L}+|\p{N}|[^\s\p{L}\p{N}]+").unwrap();
Ok( SimpleTokenizer { Ok(SimpleTokenizer {
byte_encoder: byte_encoder, byte_encoder: byte_encoder,
byte_decoder: byte_decoder, byte_decoder: byte_decoder,
encoder: encoder, encoder: encoder,
@@ -112,7 +112,7 @@ impl SimpleTokenizer {
bpe_ranks: bpe_ranks, bpe_ranks: bpe_ranks,
cache: cache, cache: cache,
pat: pat, pat: pat,
} ) })
} }
pub fn bpe(&self, token: &str) -> String { pub fn bpe(&self, token: &str) -> String {
@@ -129,7 +129,8 @@ impl SimpleTokenizer {
} }
loop { loop {
let bigram = pairs.iter() let bigram = pairs
.iter()
.filter(|pair| self.bpe_ranks.contains_key(pair)) .filter(|pair| self.bpe_ranks.contains_key(pair))
.min_by_key(|&pair| self.bpe_ranks[pair]); .min_by_key(|&pair| self.bpe_ranks[pair]);
@@ -141,7 +142,7 @@ impl SimpleTokenizer {
let mut new_word = Vec::new(); let mut new_word = Vec::new();
let mut i = 0; let mut i = 0;
while i < word.len() { while i < word.len() {
if let Some( (j, _) ) = word.iter().enumerate().skip(i).find(|(_, w)| w == &first) { if let Some((j, _)) = word.iter().enumerate().skip(i).find(|(_, w)| w == &first) {
new_word.extend(word[i..j].iter().cloned()); new_word.extend(word[i..j].iter().cloned());
i = j; i = j;
} else { } else {
@@ -178,8 +179,16 @@ impl SimpleTokenizer {
for m in self.pat.find_iter(&cleaned_text) { for m in self.pat.find_iter(&cleaned_text) {
let token = m.as_str(); let token = m.as_str();
let token: String = token.as_bytes().into_iter().map(|b| self.byte_encoder[b]).collect(); let token: String = token
bpe_tokens.extend(self.bpe(&token).split(' ').map(|bpe_token| self.encoder[bpe_token])) .as_bytes()
.into_iter()
.map(|b| self.byte_encoder[b])
.collect();
bpe_tokens.extend(
self.bpe(&token)
.split(' ')
.map(|bpe_token| self.encoder[bpe_token]),
)
} }
return bpe_tokens; return bpe_tokens;
@@ -187,9 +196,7 @@ impl SimpleTokenizer {
pub fn decode(&self, tokens: &[u32]) -> String { pub fn decode(&self, tokens: &[u32]) -> String {
let text: String = tokens.iter().map(|t| self.decoder[t].as_str()).collect(); let text: String = tokens.iter().map(|t| self.decoder[t].as_str()).collect();
let decoded_bytes: Vec<u8> = text.chars() let decoded_bytes: Vec<u8> = text.chars().map(|c| self.byte_decoder[&c]).collect();
.map(|c| self.byte_decoder[&c])
.collect();
String::from_utf8_lossy(&decoded_bytes[..]).replace("</w>", " ") String::from_utf8_lossy(&decoded_bytes[..]).replace("</w>", " ")
} }