Add ability to load dump or burn model in sample binary
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
use stablediffusion::{tokenizer::SimpleTokenizer, model::stablediffusion::*};
|
||||
use stablediffusion::{tokenizer::SimpleTokenizer, model::stablediffusion::{*, load::load_stable_diffusion}};
|
||||
|
||||
use burn::{
|
||||
config::Config,
|
||||
@@ -30,30 +30,39 @@ fn main() {
|
||||
|
||||
let args: Vec<String> = std::env::args().collect();
|
||||
if args.len() != 6 {
|
||||
eprintln!("Usage: {} <model_name> <unconditional_guidance_scale> <n_diffusion_steps> <prompt> <output_image_name>", args[0]);
|
||||
eprintln!("Usage: {} <model_type(burn or dump)> <model_name> <unconditional_guidance_scale> <n_diffusion_steps> <prompt> <output_image_name>", args[0]);
|
||||
process::exit(1);
|
||||
}
|
||||
|
||||
let model_name = &args[1];
|
||||
let unconditional_guidance_scale: f64 = args[2].parse().unwrap_or_else(|_| {
|
||||
let model_type = &args[1];
|
||||
let model_name = &args[2];
|
||||
let unconditional_guidance_scale: f64 = args[3].parse().unwrap_or_else(|_| {
|
||||
eprintln!("Error: Invalid unconditional guidance scale.");
|
||||
process::exit(1);
|
||||
});
|
||||
let n_steps: usize = args[3].parse().unwrap_or_else(|_| {
|
||||
let n_steps: usize = args[4].parse().unwrap_or_else(|_| {
|
||||
eprintln!("Error: Invalid number of diffusion steps.");
|
||||
process::exit(1);
|
||||
});
|
||||
let prompt = &args[4];
|
||||
let output_image_name = &args[5];
|
||||
|
||||
let prompt = &args[5];
|
||||
let output_image_name = &args[6];
|
||||
|
||||
println!("Loading tokenizer...");
|
||||
let tokenizer = SimpleTokenizer::new().unwrap();
|
||||
println!("Loading model...");
|
||||
let sd: StableDiffusion<Backend> = load_stable_diffusion_model_file(model_name).unwrap_or_else(|err| {
|
||||
let sd: StableDiffusion<Backend> = if model_type == "burn" {
|
||||
load_stable_diffusion_model_file(model_name).unwrap_or_else(|err| {
|
||||
eprintln!("Error loading model: {}", err);
|
||||
process::exit(1);
|
||||
});
|
||||
})
|
||||
} else {
|
||||
load_stable_diffusion(model_name, &device).unwrap_or_else(|err| {
|
||||
eprintln!("Error loading model dump: {}", err);
|
||||
process::exit(1);
|
||||
})
|
||||
};
|
||||
|
||||
|
||||
let sd = sd.to_device(&device);
|
||||
|
||||
let unconditional_context = sd.unconditional_context(&tokenizer);
|
||||
|
||||
Reference in New Issue
Block a user