updated to the newest version of burn
This commit is contained in:
16
Cargo.toml
16
Cargo.toml
@@ -8,17 +8,13 @@ edition = "2021"
|
|||||||
[features]
|
[features]
|
||||||
wgpu-backend = ["burn-wgpu"]
|
wgpu-backend = ["burn-wgpu"]
|
||||||
|
|
||||||
[dependencies.burn-wgpu]
|
|
||||||
package = "burn-wgpu"
|
|
||||||
git = "https://github.com/burn-rs/burn.git"
|
|
||||||
optional = true
|
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
burn = "0.14.0"
|
burn = "0.20.1"
|
||||||
burn-ndarray = "0.14.0"
|
burn-ndarray = "0.20.1"
|
||||||
burn-tch = "0.14.0"
|
burn-tch = "0.20.1"
|
||||||
burn-autodiff = "0.14.0"
|
burn-autodiff = "0.20.1"
|
||||||
tch = "0.15.0"
|
burn-wgpu = { version = "0.20.1", optional = true }
|
||||||
|
tch = "0.22.0"
|
||||||
serde = {version = "1.0.171", features = ["std", "derive"]}
|
serde = {version = "1.0.171", features = ["std", "derive"]}
|
||||||
npy = "0.4.0"
|
npy = "0.4.0"
|
||||||
num-traits = "0.2.15"
|
num-traits = "0.2.15"
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ use burn::{
|
|||||||
|
|
||||||
cfg_if::cfg_if! {
|
cfg_if::cfg_if! {
|
||||||
if #[cfg(feature = "wgpu-backend")] {
|
if #[cfg(feature = "wgpu-backend")] {
|
||||||
use burn_wgpu::{WgpuBackend, WgpuDevice, AutoGraphicsApi};
|
use burn_wgpu::{Wgpu, WgpuDevice};
|
||||||
} else {
|
} else {
|
||||||
use burn_tch::{LibTorch, LibTorchDevice};
|
use burn_tch::{LibTorch, LibTorchDevice};
|
||||||
}
|
}
|
||||||
@@ -58,7 +58,7 @@ fn main() {
|
|||||||
|
|
||||||
cfg_if::cfg_if! {
|
cfg_if::cfg_if! {
|
||||||
if #[cfg(feature = "wgpu-backend")] {
|
if #[cfg(feature = "wgpu-backend")] {
|
||||||
type Backend = WgpuBackend<AutoGraphicsApi, f32, i32>;
|
type Backend = Wgpu;
|
||||||
let device = WgpuDevice::BestAvailable;
|
let device = WgpuDevice::BestAvailable;
|
||||||
} else {
|
} else {
|
||||||
type Backend = LibTorch<f32>;
|
type Backend = LibTorch<f32>;
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ use crate::backend::{qkv_attention, attn_decoder_mask};
|
|||||||
|
|
||||||
use std::iter;
|
use std::iter;
|
||||||
|
|
||||||
#[derive(Config)]
|
#[derive(Config, Debug)]
|
||||||
pub struct AutoencoderConfig {}
|
pub struct AutoencoderConfig {}
|
||||||
|
|
||||||
impl AutoencoderConfig {
|
impl AutoencoderConfig {
|
||||||
@@ -71,7 +71,7 @@ impl<B: Backend> Autoencoder<B> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Config)]
|
#[derive(Config, Debug)]
|
||||||
pub struct EncoderConfig {
|
pub struct EncoderConfig {
|
||||||
channels: Vec<(usize, usize)>,
|
channels: Vec<(usize, usize)>,
|
||||||
n_group: usize,
|
n_group: usize,
|
||||||
@@ -144,7 +144,7 @@ impl<B: Backend> Encoder<B> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Config)]
|
#[derive(Config, Debug)]
|
||||||
pub struct DecoderConfig {
|
pub struct DecoderConfig {
|
||||||
channels: Vec<(usize, usize)>,
|
channels: Vec<(usize, usize)>,
|
||||||
n_group: usize,
|
n_group: usize,
|
||||||
@@ -216,7 +216,7 @@ impl<B: Backend> Decoder<B> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Config)]
|
#[derive(Config, Debug)]
|
||||||
pub struct EncoderBlockConfig {
|
pub struct EncoderBlockConfig {
|
||||||
n_channels_in: usize,
|
n_channels_in: usize,
|
||||||
n_channels_out: usize,
|
n_channels_out: usize,
|
||||||
@@ -265,7 +265,7 @@ impl<B: Backend> EncoderBlock<B> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Config)]
|
#[derive(Config, Debug)]
|
||||||
pub struct DecoderBlockConfig {
|
pub struct DecoderBlockConfig {
|
||||||
n_channels_in: usize,
|
n_channels_in: usize,
|
||||||
n_channels_out: usize,
|
n_channels_out: usize,
|
||||||
@@ -323,7 +323,7 @@ impl<B: Backend> DecoderBlock<B> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Config)]
|
#[derive(Config, Debug)]
|
||||||
pub struct PaddedConv2dConfig {
|
pub struct PaddedConv2dConfig {
|
||||||
channels: [usize; 2],
|
channels: [usize; 2],
|
||||||
kernel_size: usize,
|
kernel_size: usize,
|
||||||
@@ -427,7 +427,7 @@ pub struct Padding {
|
|||||||
pad_bottom: usize,
|
pad_bottom: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Config)]
|
#[derive(Config, Debug)]
|
||||||
pub struct MidConfig {
|
pub struct MidConfig {
|
||||||
n_channel: usize,
|
n_channel: usize,
|
||||||
}
|
}
|
||||||
@@ -462,7 +462,7 @@ impl<B: Backend> Mid<B> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Config)]
|
#[derive(Config, Debug)]
|
||||||
pub struct ResnetBlockConfig {
|
pub struct ResnetBlockConfig {
|
||||||
in_channels: usize,
|
in_channels: usize,
|
||||||
out_channels: usize,
|
out_channels: usize,
|
||||||
@@ -527,7 +527,7 @@ impl<B: Backend> ResnetBlock<B> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Config)]
|
#[derive(Config, Debug)]
|
||||||
pub struct ConvSelfAttentionBlockConfig {
|
pub struct ConvSelfAttentionBlockConfig {
|
||||||
n_channel: usize,
|
n_channel: usize,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ use burn::{
|
|||||||
//use crate::backend::Backend as MyBackend;
|
//use crate::backend::Backend as MyBackend;
|
||||||
use crate::backend::{qkv_attention, attn_decoder_mask};
|
use crate::backend::{qkv_attention, attn_decoder_mask};
|
||||||
|
|
||||||
#[derive(Config)]
|
#[derive(Config, Debug)]
|
||||||
pub struct CLIPConfig {
|
pub struct CLIPConfig {
|
||||||
n_vocab: usize,
|
n_vocab: usize,
|
||||||
n_state: usize,
|
n_state: usize,
|
||||||
@@ -75,7 +75,7 @@ impl<B: Backend> CLIP<B> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Config)]
|
#[derive(Config, Debug)]
|
||||||
pub struct ResidualDecoderAttentionBlockConfig {
|
pub struct ResidualDecoderAttentionBlockConfig {
|
||||||
n_state: usize,
|
n_state: usize,
|
||||||
n_head: usize,
|
n_head: usize,
|
||||||
@@ -114,7 +114,7 @@ impl<B: Backend> ResidualDecoderAttentionBlock<B> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Config)]
|
#[derive(Config, Debug)]
|
||||||
pub struct MultiHeadSelfAttentionConfig {
|
pub struct MultiHeadSelfAttentionConfig {
|
||||||
n_state: usize,
|
n_state: usize,
|
||||||
n_head: usize,
|
n_head: usize,
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ use burn::{
|
|||||||
tensor::{backend::Backend, Tensor},
|
tensor::{backend::Backend, Tensor},
|
||||||
};
|
};
|
||||||
|
|
||||||
#[derive(Config)]
|
#[derive(Config, Debug)]
|
||||||
pub struct GroupNormConfig {
|
pub struct GroupNormConfig {
|
||||||
n_group: usize,
|
n_group: usize,
|
||||||
n_channel: usize,
|
n_channel: usize,
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ use burn::{
|
|||||||
config::Config,
|
config::Config,
|
||||||
module::{Module, Param},
|
module::{Module, Param},
|
||||||
nn::{self, conv},
|
nn::{self, conv},
|
||||||
tensor::{backend::Backend, Data, Tensor},
|
tensor::{backend::Backend, Tensor},
|
||||||
};
|
};
|
||||||
|
|
||||||
use burn::tensor::ElementConversion;
|
use burn::tensor::ElementConversion;
|
||||||
@@ -98,7 +98,7 @@ pub fn load_layer_norm<B: Backend>(
|
|||||||
|
|
||||||
let mut layer_norm = nn::LayerNormConfig::new(n_state).with_epsilon(eps).init(device);
|
let mut layer_norm = nn::LayerNormConfig::new(n_state).with_epsilon(eps).init(device);
|
||||||
layer_norm.gamma = Param::from_tensor(weight);
|
layer_norm.gamma = Param::from_tensor(weight);
|
||||||
layer_norm.beta = Param::from_tensor(bias);
|
layer_norm.beta = Some(Param::from_tensor(bias));
|
||||||
|
|
||||||
Ok(layer_norm)
|
Ok(layer_norm)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ pub mod load;
|
|||||||
use burn::{
|
use burn::{
|
||||||
config::Config,
|
config::Config,
|
||||||
module::{Module, Param},
|
module::{Module, Param},
|
||||||
tensor::{backend::Backend, BasicOps, Data, Distribution, Float, Int, Tensor},
|
tensor::{backend::Backend, BasicOps, Distribution, Float, Int, Tensor},
|
||||||
tensor::cast::ToElement,
|
tensor::cast::ToElement,
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -16,7 +16,7 @@ use super::clip::{CLIPConfig, CLIP};
|
|||||||
use super::unet::{UNet, UNetConfig};
|
use super::unet::{UNet, UNetConfig};
|
||||||
use crate::tokenizer::SimpleTokenizer;
|
use crate::tokenizer::SimpleTokenizer;
|
||||||
|
|
||||||
#[derive(Config)]
|
#[derive(Config, Debug)]
|
||||||
pub struct StableDiffusionConfig {}
|
pub struct StableDiffusionConfig {}
|
||||||
|
|
||||||
impl StableDiffusionConfig {
|
impl StableDiffusionConfig {
|
||||||
@@ -192,7 +192,7 @@ impl<B: Backend> StableDiffusion<B> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn unconditional_context(&self, tokenizer: &SimpleTokenizer) -> Tensor<B, 2> {
|
pub fn unconditional_context(&self, tokenizer: &SimpleTokenizer) -> Tensor<B, 2> {
|
||||||
self.context(tokenizer, "").squeeze(0)
|
self.context(tokenizer, "").squeeze::<2>()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn context(&self, tokenizer: &SimpleTokenizer, text: &str) -> Tensor<B, 3> {
|
pub fn context(&self, tokenizer: &SimpleTokenizer, text: &str) -> Tensor<B, 3> {
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ fn timestep_embedding<B: Backend>(
|
|||||||
Tensor::cat(vec![args.clone().cos(), args.sin()], 0).unsqueeze()
|
Tensor::cat(vec![args.clone().cos(), args.sin()], 0).unsqueeze()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Config)]
|
#[derive(Config, Debug)]
|
||||||
pub struct UNetConfig {}
|
pub struct UNetConfig {}
|
||||||
|
|
||||||
impl UNetConfig {
|
impl UNetConfig {
|
||||||
@@ -196,7 +196,7 @@ trait UNetBlock<B: Backend> {
|
|||||||
fn forward(&self, x: Tensor<B, 4>, emb: Tensor<B, 2>, context: Tensor<B, 3>) -> Tensor<B, 4>;
|
fn forward(&self, x: Tensor<B, 4>, emb: Tensor<B, 2>, context: Tensor<B, 3>) -> Tensor<B, 4>;
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Config)]
|
#[derive(Config, Debug)]
|
||||||
pub struct ResTransformerConfig {
|
pub struct ResTransformerConfig {
|
||||||
n_channels_in: usize,
|
n_channels_in: usize,
|
||||||
n_channels_embed: usize,
|
n_channels_embed: usize,
|
||||||
@@ -235,7 +235,7 @@ impl<B: Backend> UNetBlock<B> for ResTransformer<B> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Config)]
|
#[derive(Config, Debug)]
|
||||||
pub struct ResUpSampleConfig {
|
pub struct ResUpSampleConfig {
|
||||||
n_channels_in: usize,
|
n_channels_in: usize,
|
||||||
n_channels_embed: usize,
|
n_channels_embed: usize,
|
||||||
@@ -270,7 +270,7 @@ impl<B: Backend> UNetBlock<B> for ResUpSample<B> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Config)]
|
#[derive(Config, Debug)]
|
||||||
pub struct ResTransformerUpsampleConfig {
|
pub struct ResTransformerUpsampleConfig {
|
||||||
n_channels_in: usize,
|
n_channels_in: usize,
|
||||||
n_channels_embed: usize,
|
n_channels_embed: usize,
|
||||||
@@ -316,7 +316,7 @@ impl<B: Backend> UNetBlock<B> for ResTransformerUpsample<B> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Config)]
|
#[derive(Config, Debug)]
|
||||||
pub struct ResTransformerResConfig {
|
pub struct ResTransformerResConfig {
|
||||||
n_channels_in: usize,
|
n_channels_in: usize,
|
||||||
n_channels_embed: usize,
|
n_channels_embed: usize,
|
||||||
@@ -367,7 +367,7 @@ impl<B: Backend> UNetBlock<B> for ResTransformerRes<B> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Config)]
|
#[derive(Config, Debug)]
|
||||||
pub struct UpsampleConfig {
|
pub struct UpsampleConfig {
|
||||||
n_channels: usize,
|
n_channels: usize,
|
||||||
}
|
}
|
||||||
@@ -404,7 +404,7 @@ impl<B: Backend> UNetBlock<B> for Upsample<B> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Config)]
|
#[derive(Config, Debug)]
|
||||||
pub struct DownsampleConfig {
|
pub struct DownsampleConfig {
|
||||||
n_channels: usize,
|
n_channels: usize,
|
||||||
}
|
}
|
||||||
@@ -426,7 +426,7 @@ impl<B: Backend> UNetBlock<B> for Conv2d<B> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Config)]
|
#[derive(Config, Debug)]
|
||||||
pub struct SpatialTransformerConfig {
|
pub struct SpatialTransformerConfig {
|
||||||
n_channels: usize,
|
n_channels: usize,
|
||||||
n_context_state: usize,
|
n_context_state: usize,
|
||||||
@@ -480,7 +480,7 @@ impl<B: Backend> SpatialTransformer<B> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Config)]
|
#[derive(Config, Debug)]
|
||||||
pub struct TransformerBlockConfig {
|
pub struct TransformerBlockConfig {
|
||||||
n_state: usize,
|
n_state: usize,
|
||||||
n_context_state: usize,
|
n_context_state: usize,
|
||||||
@@ -526,7 +526,7 @@ impl<B: Backend> TransformerBlock<B> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Config)]
|
#[derive(Config, Debug)]
|
||||||
pub struct MLPConfig {
|
pub struct MLPConfig {
|
||||||
n_state: usize,
|
n_state: usize,
|
||||||
mult: usize,
|
mult: usize,
|
||||||
@@ -554,7 +554,7 @@ impl<B: Backend> MLP<B> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Config)]
|
#[derive(Config, Debug)]
|
||||||
pub struct GEGLUConfig {
|
pub struct GEGLUConfig {
|
||||||
n_state_in: usize,
|
n_state_in: usize,
|
||||||
n_state_out: usize,
|
n_state_out: usize,
|
||||||
@@ -591,7 +591,7 @@ impl<B: Backend> GEGLU<B> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Config)]
|
#[derive(Config, Debug)]
|
||||||
pub struct MultiHeadAttentionConfig {
|
pub struct MultiHeadAttentionConfig {
|
||||||
n_state: usize,
|
n_state: usize,
|
||||||
n_context_state: usize,
|
n_context_state: usize,
|
||||||
@@ -652,7 +652,7 @@ impl<B: Backend> MultiHeadAttention<B> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Config)]
|
#[derive(Config, Debug)]
|
||||||
pub struct ResBlockConfig {
|
pub struct ResBlockConfig {
|
||||||
n_channels_in: usize,
|
n_channels_in: usize,
|
||||||
n_channels_embed: usize,
|
n_channels_embed: usize,
|
||||||
|
|||||||
Reference in New Issue
Block a user