Add first successful sampling implementation

This commit is contained in:
Gadersd
2023-08-04 17:01:44 -04:00
parent 3bf5b8c477
commit 77f30aefa7
9 changed files with 42 additions and 34 deletions

View File

@@ -35,7 +35,7 @@ pub fn qkv_attention<B: Backend>(q: Tensor<B, 3>, k: Tensor<B, 3>, v: Tensor<B,
return o;
}
-pub fn attn_decoder_mask<B: Backend>(seq_length: usize) -> Tensor<B, 2> {
+pub fn attn_decoder_mask<B: Backend>(seq_length: usize, device: &B::Device) -> Tensor<B, 2> {
let mut mask = Tensor::<B, 2>::zeros([seq_length, seq_length]);
for i in 0..(seq_length - 1) {
@@ -43,5 +43,5 @@ pub fn attn_decoder_mask<B: Backend>(seq_length: usize) -> Tensor<B, 2> {
mask = mask.slice_assign([i..i + 1, i + 1..seq_length], values);
}
-return mask;
+return mask.to_device(device);
}