Add custom backend to enable flash attention
This commit is contained in:
@@ -16,9 +16,9 @@ use burn::{
|
||||
},
|
||||
};
|
||||
|
||||
use super::attention::qkv_attention;
|
||||
use super::groupnorm::*;
|
||||
use super::silu::*;
|
||||
use crate::backend::Backend as MyBackend;
|
||||
|
||||
use std::iter;
|
||||
|
||||
@@ -51,7 +51,7 @@ pub struct Autoencoder<B: Backend> {
|
||||
post_quant_conv: Conv2d<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Autoencoder<B> {
|
||||
impl<B: MyBackend> Autoencoder<B> {
|
||||
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
self.decode_latent(self.encode_image(x))
|
||||
}
|
||||
@@ -128,7 +128,7 @@ pub struct Encoder<B: Backend> {
|
||||
conv_out: Conv2d<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Encoder<B> {
|
||||
impl<B: MyBackend> Encoder<B> {
|
||||
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let x = self.conv_in.forward(x);
|
||||
|
||||
@@ -200,7 +200,7 @@ pub struct Decoder<B: Backend> {
|
||||
conv_out: Conv2d<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Decoder<B> {
|
||||
impl<B: MyBackend> Decoder<B> {
|
||||
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let x = self.conv_in.forward(x);
|
||||
let x = self.mid.forward(x);
|
||||
@@ -383,10 +383,6 @@ pub struct PaddedConv2d<B: Backend> {
|
||||
|
||||
impl<B: Backend> PaddedConv2d<B> {
|
||||
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
println!(
|
||||
"{} {} {:?} {:?}",
|
||||
self.kernel_size, self.stride, self.padding, self.padding_actual
|
||||
);
|
||||
let [n_batch, n_channel, height, width] = x.dims();
|
||||
|
||||
let desired_height = (self.padding.pad_top + self.padding.pad_bottom + height
|
||||
@@ -444,7 +440,7 @@ pub struct Mid<B: Backend> {
|
||||
block_2: ResnetBlock<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Mid<B> {
|
||||
impl<B: MyBackend> Mid<B> {
|
||||
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let x = self.block_1.forward(x);
|
||||
let x = self.attn.forward(x);
|
||||
@@ -550,7 +546,7 @@ pub struct ConvSelfAttentionBlock<B: Backend> {
|
||||
proj_out: Conv2d<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> ConvSelfAttentionBlock<B> {
|
||||
impl<B: MyBackend> ConvSelfAttentionBlock<B> {
|
||||
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let [n_batch, n_channel, height, width] = x.dims();
|
||||
|
||||
@@ -572,9 +568,15 @@ impl<B: Backend> ConvSelfAttentionBlock<B> {
|
||||
.reshape([n_batch, n_channel, height * width])
|
||||
.swap_dims(1, 2);
|
||||
|
||||
let wv = qkv_attention(q, k, v, None, 1)
|
||||
.swap_dims(1, 2)
|
||||
.reshape([n_batch, n_channel, height, width]);
|
||||
let wv = Tensor::from_primitive(B::qkv_attention(
|
||||
q.into_primitive(),
|
||||
k.into_primitive(),
|
||||
v.into_primitive(),
|
||||
None,
|
||||
1,
|
||||
))
|
||||
.swap_dims(1, 2)
|
||||
.reshape([n_batch, n_channel, height, width]);
|
||||
|
||||
let projected = self.proj_out.forward(wv);
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ use burn::{
|
||||
},
|
||||
};
|
||||
|
||||
use crate::model::attention::{attn_decoder_mask, qkv_attention};
|
||||
use crate::backend::Backend as MyBackend;
|
||||
|
||||
#[derive(Config)]
|
||||
pub struct CLIPConfig {
|
||||
@@ -51,11 +51,11 @@ pub struct CLIP<B: Backend> {
|
||||
layer_norm: nn::LayerNorm<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> CLIP<B> {
|
||||
impl<B: MyBackend> CLIP<B> {
|
||||
pub fn forward(&self, x: Tensor<B, 2, Int>) -> Tensor<B, 3> {
|
||||
let [n_batch, seq_len] = x.dims();
|
||||
|
||||
let mask = attn_decoder_mask(seq_len, &x.device());
|
||||
let mask = Tensor::from_primitive(B::attn_decoder_mask(seq_len, &x.device()));
|
||||
|
||||
let embedded = self.token_embedding.forward(x)
|
||||
+ self
|
||||
@@ -104,7 +104,7 @@ pub struct ResidualDecoderAttentionBlock<B: Backend> {
|
||||
mlp_ln: nn::LayerNorm<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> ResidualDecoderAttentionBlock<B> {
|
||||
impl<B: MyBackend> ResidualDecoderAttentionBlock<B> {
|
||||
fn forward(&self, x: Tensor<B, 3>, mask: Tensor<B, 2>) -> Tensor<B, 3> {
|
||||
let x = x.clone() + self.attn.forward(self.attn_ln.forward(x), Some(mask));
|
||||
let x = x.clone() + self.mlp.forward(self.mlp_ln.forward(x));
|
||||
@@ -152,13 +152,19 @@ pub struct MultiHeadSelfAttention<B: Backend> {
|
||||
out: nn::Linear<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> MultiHeadSelfAttention<B> {
|
||||
impl<B: MyBackend> MultiHeadSelfAttention<B> {
|
||||
pub fn forward(&self, x: Tensor<B, 3>, mask: Option<Tensor<B, 2>>) -> Tensor<B, 3> {
|
||||
let q = self.query.forward(x.clone());
|
||||
let k = self.key.forward(x.clone());
|
||||
let v = self.value.forward(x);
|
||||
|
||||
let wv = qkv_attention(q, k, v, mask, self.n_head);
|
||||
let wv = Tensor::from_primitive(B::qkv_attention(
|
||||
q.into_primitive(),
|
||||
k.into_primitive(),
|
||||
v.into_primitive(),
|
||||
mask.map(|m| m.into_primitive()),
|
||||
self.n_head,
|
||||
));
|
||||
|
||||
return self.out.forward(wv);
|
||||
}
|
||||
|
||||
@@ -8,6 +8,8 @@ use burn::{
|
||||
|
||||
use num_traits::ToPrimitive;
|
||||
|
||||
use crate::backend::Backend as MyBackend;
|
||||
|
||||
use super::autoencoder::{Autoencoder, AutoencoderConfig};
|
||||
use super::clip::{CLIPConfig, CLIP};
|
||||
use super::unet::{UNet, UNetConfig};
|
||||
@@ -44,7 +46,7 @@ pub struct StableDiffusion<B: Backend> {
|
||||
clip: CLIP<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> StableDiffusion<B> {
|
||||
impl<B: MyBackend> StableDiffusion<B> {
|
||||
pub fn sample_image(
|
||||
&self,
|
||||
context: Tensor<B, 3>,
|
||||
|
||||
Reference in New Issue
Block a user