Use wgpu by default and ndarray for convert
This commit is contained in:
@@ -59,6 +59,11 @@ impl<B: Backend> StableDiffusion<B> {
|
||||
let [n_batch, _, _] = context.dims();
|
||||
|
||||
let latent = self.sample_latent(context, unconditional_context, unconditional_guidance_scale, n_steps);
|
||||
self.latent_to_image(latent)
|
||||
}
|
||||
|
||||
pub fn latent_to_image(&self, latent: Tensor<B, 4>) -> Vec<Vec<u8>> {
|
||||
let [n_batch, _, _, _] = latent.dims();
|
||||
let image = self.autoencoder.decode_latent(latent * (1.0 / 0.18215));
|
||||
|
||||
let n_channel = 3;
|
||||
@@ -157,7 +162,7 @@ impl<B: Backend> StableDiffusion<B> {
|
||||
}
|
||||
|
||||
pub fn context(&self, tokenizer: &SimpleTokenizer, text: &str) -> Tensor<B, 3> {
|
||||
let device = &self.devices()[0];
|
||||
let device = &self.clip.devices()[0];
|
||||
let text = format!("<|startoftext|>{}<|endoftext|>", text);
|
||||
let tokenized: Vec<_> = tokenizer.encode(&text).into_iter().map(|v| v as i32).collect();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user