Files
burn-stablediffusion-vibecode/src/bin/convert/main.rs

60 lines
1.6 KiB
Rust

use std::env;
use std::process;
use std::error::Error;
use stablediffusion::model::stablediffusion::{StableDiffusion, load::load_stable_diffusion};
use burn::{
config::Config,
module::{Module, Param},
nn,
tensor::{
backend::Backend,
Tensor,
},
};
use burn_tch::{TchBackend, TchDevice};
use burn::record::{self, Recorder, FullPrecisionSettings};
use stablediffusion::binrecorder::{BinFileRecorderBuffered};
fn convert_dump_to_model<B: Backend>(dump_path: &str, model_name: &str, device: &B::Device) -> Result<(), Box<dyn Error>> {
println!("Loading dump...");
let model: StableDiffusion::<B> = load_stable_diffusion(dump_path, device)?;
println!("Saving model...");
save_model_file(model, model_name)?;
Ok(())
}
fn save_model_file<B: Backend>(model: StableDiffusion<B>, name: &str) -> Result<(), record::RecorderError> {
BinFileRecorderBuffered::<FullPrecisionSettings>::new()
.record(
model.into_record(),
name.into(),
)
}
fn main() {
type Backend = TchBackend<f32>;
let device = TchDevice::Cpu;
let args: Vec<String> = env::args().collect();
if args.len() != 3 {
eprintln!("Usage: {} <dump_path> <model_name>", args[0]);
process::exit(1);
}
let dump_path = &args[1];
let model_name = &args[2];
if let Err(e) = convert_dump_to_model::<Backend>(dump_path, model_name, &device) {
eprintln!("Failed to convert dump to model: {:?}", e);
process::exit(1);
}
println!("Successfully converted {} to {}", dump_path, model_name);
}