Replace helper functions with native burn functions

This commit is contained in:
Gadersd
2023-09-07 12:23:18 -04:00
parent a62795347f
commit f4c58c1790
20 changed files with 1091 additions and 950 deletions

View File

@@ -1,20 +1,22 @@
use std::error::Error;
use burn::tensor::ElementConversion;
use std::error::Error;
use burn::{
config::Config,
config::Config,
module::{Module, Param},
nn,
tensor::{
backend::Backend,
Tensor,
},
tensor::{backend::Backend, Tensor},
};
use super::*;
use crate::model::{load::*, autoencoder::load::load_autoencoder, unet::load::load_unet, clip::load::load_clip};
use crate::model::{
autoencoder::load::load_autoencoder, clip::load::load_clip, load::*, unet::load::load_unet,
};
pub fn load_stable_diffusion<B: Backend>(path: &str, device: &B::Device) -> Result<StableDiffusion<B>, Box<dyn Error>> {
pub fn load_stable_diffusion<B: Backend>(
path: &str,
device: &B::Device,
) -> Result<StableDiffusion<B>, Box<dyn Error>> {
let n_steps = load_usize::<B>("n_steps", path, device)?;
let alpha_cumulative_products = load_tensor::<B, 1>("alphas_cumprod", path, device)?.into();
let autoencoder = load_autoencoder(&format!("{}/{}", path, "autoencoder"), device)?;
@@ -22,11 +24,10 @@ pub fn load_stable_diffusion<B: Backend>(path: &str, device: &B::Device) -> Resu
let clip = load_clip(&format!("{}/{}", path, "clip"), device)?;
Ok(StableDiffusion {
n_steps,
alpha_cumulative_products,
autoencoder,
diffusion,
clip,
n_steps,
alpha_cumulative_products,
autoencoder,
diffusion,
clip,
})
}