Replace helper functions with native burn functions
This commit is contained in:
@@ -1,20 +1,22 @@
|
||||
use std::error::Error;
|
||||
use burn::tensor::ElementConversion;
|
||||
use std::error::Error;
|
||||
|
||||
use burn::{
|
||||
config::Config,
|
||||
config::Config,
|
||||
module::{Module, Param},
|
||||
nn,
|
||||
tensor::{
|
||||
backend::Backend,
|
||||
Tensor,
|
||||
},
|
||||
tensor::{backend::Backend, Tensor},
|
||||
};
|
||||
|
||||
use super::*;
|
||||
use crate::model::{load::*, autoencoder::load::load_autoencoder, unet::load::load_unet, clip::load::load_clip};
|
||||
use crate::model::{
|
||||
autoencoder::load::load_autoencoder, clip::load::load_clip, load::*, unet::load::load_unet,
|
||||
};
|
||||
|
||||
pub fn load_stable_diffusion<B: Backend>(path: &str, device: &B::Device) -> Result<StableDiffusion<B>, Box<dyn Error>> {
|
||||
pub fn load_stable_diffusion<B: Backend>(
|
||||
path: &str,
|
||||
device: &B::Device,
|
||||
) -> Result<StableDiffusion<B>, Box<dyn Error>> {
|
||||
let n_steps = load_usize::<B>("n_steps", path, device)?;
|
||||
let alpha_cumulative_products = load_tensor::<B, 1>("alphas_cumprod", path, device)?.into();
|
||||
let autoencoder = load_autoencoder(&format!("{}/{}", path, "autoencoder"), device)?;
|
||||
@@ -22,11 +24,10 @@ pub fn load_stable_diffusion<B: Backend>(path: &str, device: &B::Device) -> Resu
|
||||
let clip = load_clip(&format!("{}/{}", path, "clip"), device)?;
|
||||
|
||||
Ok(StableDiffusion {
|
||||
n_steps,
|
||||
alpha_cumulative_products,
|
||||
autoencoder,
|
||||
diffusion,
|
||||
clip,
|
||||
n_steps,
|
||||
alpha_cumulative_products,
|
||||
autoencoder,
|
||||
diffusion,
|
||||
clip,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user