Replace helper functions with native burn functions

This commit is contained in:
Gadersd
2023-09-07 12:23:18 -04:00
parent a62795347f
commit f4c58c1790
20 changed files with 1091 additions and 950 deletions

View File

@@ -1,23 +1,32 @@
use burn::{
tensor::{
backend::Backend,
activation::softmax,
Tensor,
},
};
use burn::tensor::{activation::softmax, backend::Backend, Tensor};
use std::f32::NEG_INFINITY;
pub fn qkv_attention<B: Backend>(q: Tensor<B, 3>, k: Tensor<B, 3>, v: Tensor<B, 3>, mask: Option<Tensor<B, 2>>, n_head: usize) -> Tensor<B, 3> {
pub fn qkv_attention<B: Backend>(
q: Tensor<B, 3>,
k: Tensor<B, 3>,
v: Tensor<B, 3>,
mask: Option<Tensor<B, 2>>,
n_head: usize,
) -> Tensor<B, 3> {
let [n_batch, n_qctx, n_state] = q.dims();
let [_, n_ctx, _] = k.dims();
let scale = (n_state as f64 / n_head as f64).powf(-0.25);
let n_hstate = n_state / n_head;
let q = q.reshape([n_batch, n_qctx, n_head, n_hstate]).swap_dims(1, 2) * scale;
let k = k.reshape([n_batch, n_ctx, n_head, n_hstate]).swap_dims(1, 2).transpose() * scale;
let v = v.reshape([n_batch, n_ctx, n_head, n_hstate]).swap_dims(1, 2);
let q = q
.reshape([n_batch, n_qctx, n_head, n_hstate])
.swap_dims(1, 2)
* scale;
let k = k
.reshape([n_batch, n_ctx, n_head, n_hstate])
.swap_dims(1, 2)
.transpose()
* scale;
let v = v
.reshape([n_batch, n_ctx, n_head, n_hstate])
.swap_dims(1, 2);
let qk = q.matmul(k);
@@ -44,4 +53,4 @@ pub fn attn_decoder_mask<B: Backend>(seq_length: usize, device: &B::Device) -> T
}
return mask.to_device(device);
}
}