Replace helper functions with native burn functions

This commit is contained in:
Gadersd
2023-09-07 12:23:18 -04:00
parent a62795347f
commit f4c58c1790
20 changed files with 1091 additions and 950 deletions

View File

@@ -1,14 +1,11 @@
use std::error::Error;
use burn::tensor::ElementConversion;
use std::error::Error;
use burn::{
config::Config,
config::Config,
module::{Module, Param},
nn,
tensor::{
backend::Backend,
Tensor,
},
tensor::{backend::Backend, Tensor},
};
use super::*;
@@ -28,7 +25,10 @@ pub fn load_mlp<B: Backend>(path: &str, device: &B::Device) -> Result<MLP<B>, Bo
Ok(mlp)
}
pub fn load_multi_head_self_attention<B: Backend>(path: &str, device: &B::Device) -> Result<MultiHeadSelfAttention<B>, Box<dyn Error>> {
pub fn load_multi_head_self_attention<B: Backend>(
path: &str,
device: &B::Device,
) -> Result<MultiHeadSelfAttention<B>, Box<dyn Error>> {
let n_head = load_usize::<B>("n_head", path, device)?;
let query = load_linear(&format!("{}/{}", path, "query"), device)?;
let key = load_linear(&format!("{}/{}", path, "key"), device)?;
@@ -46,7 +46,10 @@ pub fn load_multi_head_self_attention<B: Backend>(path: &str, device: &B::Device
Ok(mhsa)
}
pub fn load_residual_decoder_attention_block<B: Backend>(path: &str, device: &B::Device) -> Result<ResidualDecoderAttentionBlock<B>, Box<dyn Error>> {
pub fn load_residual_decoder_attention_block<B: Backend>(
path: &str,
device: &B::Device,
) -> Result<ResidualDecoderAttentionBlock<B>, Box<dyn Error>> {
let mlp = load_mlp(&format!("{}/{}", path, "mlp"), device)?;
let attn = load_multi_head_self_attention(&format!("{}/{}", path, "attn"), device)?;
let attn_ln = load_layer_norm(&format!("{}/{}", path, "attn_ln"), device)?;
@@ -64,15 +67,17 @@ pub fn load_residual_decoder_attention_block<B: Backend>(path: &str, device: &B:
pub fn load_clip<B: Backend>(path: &str, device: &B::Device) -> Result<CLIP<B>, Box<dyn Error>> {
let token_embedding = load_embedding(&format!("{}/{}", path, "token_embedding"), device)?;
let position_embedding = load_tensor("weight", &format!("{}/position_embedding", path), device)?.into();
let position_embedding =
load_tensor("weight", &format!("{}/position_embedding", path), device)?.into();
let n_layer = load_usize::<B>("n_layer", path, device)?;
let mut blocks = (0..n_layer)
.into_iter()
.map(|i| {
load_residual_decoder_attention_block::<B>(&format!("{}/blocks/{}", path, i), device)
}).collect::<Result<Vec<_>, _>>()?;
})
.collect::<Result<Vec<_>, _>>()?;
let layer_norm = load_layer_norm(&format!("{}/{}", path, "layer_norm"), device)?;
let clip = CLIP {
@@ -81,6 +86,6 @@ pub fn load_clip<B: Backend>(path: &str, device: &B::Device) -> Result<CLIP<B>,
blocks: blocks,
layer_norm: layer_norm,
};
Ok(clip)
}

View File

@@ -1,35 +1,33 @@
pub mod load;
use burn::{
config::Config,
config::Config,
module::{Module, Param},
nn,
tensor::{
activation::{sigmoid, softmax},
backend::Backend,
activation::{softmax, sigmoid},
module::embedding,
Tensor,
Distribution,
Int,
module::embedding,
Distribution, Int, Tensor,
},
};
use crate::model::attention::{qkv_attention, attn_decoder_mask};
use crate::model::attention::{attn_decoder_mask, qkv_attention};
/// Configuration from which `CLIPConfig::init` builds a `CLIP` module.
#[derive(Config)]
pub struct CLIPConfig {
// First argument to `nn::EmbeddingConfig::new` for the token embedding.
n_vocab: usize,
// Second argument to the token embedding config, second dimension of the
// position embedding, and the size handed to each block and to the layer norm.
n_state: usize,
// Forwarded to `ResidualDecoderAttentionBlockConfig::new` for every block.
n_head: usize,
// First dimension of the randomly initialised position embedding.
n_ctx: usize,
// Number of residual attention blocks that `init` creates.
n_layer: usize,
// NOTE(review): the five fields above are listed again below. This looks like
// the before/after sides of a formatting-only diff rendered without +/- markers;
// confirm against the real file and keep a single copy if this is ever compiled.
n_vocab: usize,
n_state: usize,
n_head: usize,
n_ctx: usize,
n_layer: usize,
}
impl CLIPConfig {
pub fn init<B: Backend>(&self) -> CLIP<B> {
let token_embedding = nn::EmbeddingConfig::new(self.n_vocab, self.n_state).init();
let position_embedding = Tensor::random([self.n_ctx, self.n_state], Distribution::Normal(0.0, 1.0)).into();
let position_embedding =
Tensor::random([self.n_ctx, self.n_state], Distribution::Normal(0.0, 1.0)).into();
let blocks = (0..self.n_layer)
.into_iter()
.map(|_| ResidualDecoderAttentionBlockConfig::new(self.n_state, self.n_head).init())
@@ -37,33 +35,35 @@ impl CLIPConfig {
let layer_norm = nn::LayerNormConfig::new(self.n_state).init();
CLIP {
token_embedding,
position_embedding,
blocks,
layer_norm,
token_embedding,
position_embedding,
blocks,
layer_norm,
}
}
}
/// Module produced by `CLIPConfig::init`; `forward` maps a 2-D integer tensor
/// to a 3-D float tensor.
#[derive(Module, Debug)]
pub struct CLIP<B: Backend> {
// Applied to the input in `forward`.
token_embedding: nn::Embedding<B>,
// 2-D parameter; `forward` slices its first `seq_len` rows and adds them to
// the token embedding output.
position_embedding: Param<Tensor<B, 2>>,
// Applied in order in `forward`, each receiving a clone of the decoder mask.
blocks: Vec<ResidualDecoderAttentionBlock<B>>,
layer_norm: nn::LayerNorm<B>,
// NOTE(review): the four fields above are listed again below. This looks like
// the before/after sides of a formatting-only diff rendered without +/- markers;
// confirm against the real file and keep a single copy if this is ever compiled.
token_embedding: nn::Embedding<B>,
position_embedding: Param<Tensor<B, 2>>,
blocks: Vec<ResidualDecoderAttentionBlock<B>>,
layer_norm: nn::LayerNorm<B>,
}
impl<B: Backend> CLIP<B> {
pub fn forward(&self, x: Tensor<B, 2, Int>) -> Tensor<B, 3> {
let [n_batch, seq_len] = x.dims();
let mask = attn_decoder_mask(seq_len, &x.device());
let embedded = self.token_embedding.forward(x)
+ self.position_embedding.val().slice([0..seq_len]).unsqueeze();
let embedded = self.token_embedding.forward(x)
+ self
.position_embedding
.val()
.slice([0..seq_len])
.unsqueeze();
let mut x = embedded;
for block in &self.blocks {
x = block.forward(x, mask.clone());
@@ -73,37 +73,35 @@ impl<B: Backend> CLIP<B> {
}
}
/// Configuration from which `init` builds a `ResidualDecoderAttentionBlock`.
#[derive(Config)]
pub struct ResidualDecoderAttentionBlockConfig {
// Size passed to the attention config, to both layer norm configs, and to the
// MLP config (whose second argument is `4 * n_state`).
n_state: usize,
// Forwarded to `MultiHeadSelfAttentionConfig::new`.
n_head: usize,
// NOTE(review): both fields are listed again below. This looks like the
// before/after sides of a formatting-only diff rendered without +/- markers;
// confirm against the real file and keep a single copy if this is ever compiled.
n_state: usize,
n_head: usize,
}
impl ResidualDecoderAttentionBlockConfig {
/// Builds a `ResidualDecoderAttentionBlock` by initialising its four
/// sub-modules from `n_state` and `n_head`.
pub fn init<B: Backend>(&self) -> ResidualDecoderAttentionBlock<B> {
let attn = MultiHeadSelfAttentionConfig::new(self.n_state, self.n_head).init();
let attn_ln = nn::LayerNormConfig::new(self.n_state).init();
// The MLP config's second argument is four times `n_state`.
let mlp = MLPConfig::new(self.n_state, 4 * self.n_state).init();
let mlp_ln = nn::LayerNormConfig::new(self.n_state).init();
ResidualDecoderAttentionBlock {
attn,
attn_ln,
mlp,
mlp_ln,
// NOTE(review): the four fields above are listed again below. This looks like
// the before/after sides of a formatting-only diff rendered without +/- markers;
// confirm against the real file and keep a single copy if this is ever compiled.
attn,
attn_ln,
mlp,
mlp_ln,
}
}
}
/// Module produced by `ResidualDecoderAttentionBlockConfig::init`: a
/// self-attention module and an MLP, each paired with its own layer norm.
#[derive(Module, Debug)]
pub struct ResidualDecoderAttentionBlock<B: Backend> {
attn: MultiHeadSelfAttention<B>,
attn_ln: nn::LayerNorm<B>,
mlp: MLP<B>,
mlp_ln: nn::LayerNorm<B>,
// NOTE(review): the four fields above are listed again below. This looks like
// the before/after sides of a formatting-only diff rendered without +/- markers;
// confirm against the real file and keep a single copy if this is ever compiled.
attn: MultiHeadSelfAttention<B>,
attn_ln: nn::LayerNorm<B>,
mlp: MLP<B>,
mlp_ln: nn::LayerNorm<B>,
}
impl<B: Backend> ResidualDecoderAttentionBlock<B> {
@@ -117,12 +115,17 @@ impl<B: Backend> ResidualDecoderAttentionBlock<B> {
/// Configuration from which `init` builds a `MultiHeadSelfAttention`.
/// `init` asserts that `n_state % n_head == 0`.
#[derive(Config)]
pub struct MultiHeadSelfAttentionConfig {
// Used as both sizes of the `nn::LinearConfig::new` calls for the projections
// built in `init`.
n_state: usize,
// Copied into the resulting module's `n_head` field.
n_head: usize,
// NOTE(review): `n_head` is listed again below. This looks like the
// before/after sides of a formatting-only diff rendered without +/- markers;
// confirm against the real file and keep a single copy if this is ever compiled.
n_head: usize,
}
impl MultiHeadSelfAttentionConfig {
fn init<B: Backend>(&self) -> MultiHeadSelfAttention<B> {
assert!(self.n_state % self.n_head == 0, "State size {} must be a multiple of head size {}", self.n_state, self.n_head);
assert!(
self.n_state % self.n_head == 0,
"State size {} must be a multiple of head size {}",
self.n_state,
self.n_head
);
let n_head = self.n_head;
let query = nn::LinearConfig::new(self.n_state, self.n_state).init();
@@ -130,23 +133,23 @@ impl MultiHeadSelfAttentionConfig {
let value = nn::LinearConfig::new(self.n_state, self.n_state).init();
let out = nn::LinearConfig::new(self.n_state, self.n_state).init();
MultiHeadSelfAttention {
n_head,
query,
key,
value,
out
MultiHeadSelfAttention {
n_head,
query,
key,
value,
out,
}
}
}
/// Module produced by `MultiHeadSelfAttentionConfig::init`: a head count plus
/// four linear layers named `query`, `key`, `value` and `out`.
#[derive(Module, Debug)]
pub struct MultiHeadSelfAttention<B: Backend> {
// Copied from the config's `n_head` in `init`.
n_head: usize,
query: nn::Linear<B>,
key: nn::Linear<B>,
value: nn::Linear<B>,
out: nn::Linear<B>,
// NOTE(review): the five fields above are listed again below. This looks like
// the before/after sides of a formatting-only diff rendered without +/- markers;
// confirm against the real file and keep a single copy if this is ever compiled.
n_head: usize,
query: nn::Linear<B>,
key: nn::Linear<B>,
value: nn::Linear<B>,
out: nn::Linear<B>,
}
impl<B: Backend> MultiHeadSelfAttention<B> {
@@ -161,17 +164,10 @@ impl<B: Backend> MultiHeadSelfAttention<B> {
}
}
/// Configuration from which an `MLP` is built.
#[derive(Config, Debug)]
pub struct MLPConfig {
// Second argument of the `nn::LinearConfig::new` call that builds `fc2`.
input_size: usize,
// First argument of the `nn::LinearConfig::new` call that builds `fc2`.
hidden_size: usize,
// NOTE(review): both fields are listed again below. This looks like the
// before/after sides of a formatting-only diff rendered without +/- markers;
// confirm against the real file and keep a single copy if this is ever compiled.
input_size: usize,
hidden_size: usize,
}
impl MLPConfig {
@@ -180,19 +176,15 @@ impl MLPConfig {
let gelu = QuickGELU::new();
let fc2 = nn::LinearConfig::new(self.hidden_size, self.input_size).init();
MLP {
fc1,
gelu,
fc2,
}
MLP { fc1, gelu, fc2 }
}
}
/// Module holding two linear layers (`fc1`, `fc2`) and a `QuickGELU`.
#[derive(Module, Debug)]
pub struct MLP<B: Backend> {
fc1: nn::Linear<B>,
// `QuickGELU` computes `x * sigmoid(1.702 * x)`.
gelu: QuickGELU,
fc2: nn::Linear<B>,
// NOTE(review): the three fields above are listed again below. This looks like
// the before/after sides of a formatting-only diff rendered without +/- markers;
// confirm against the real file and keep a single copy if this is ever compiled.
fc1: nn::Linear<B>,
gelu: QuickGELU,
fc2: nn::Linear<B>,
}
impl<B: Backend> MLP<B> {
@@ -217,4 +209,3 @@ impl QuickGELU {
x.clone() * sigmoid(x * 1.702)
}
}