use super::GroupNorm; use crate::model::load::*; use std::error::Error; use burn::{ config::Config, module::{Module, Param}, nn, tensor::{ backend::Backend, Tensor, }, }; use super::*; use crate::model::groupnorm::load::load_group_norm; pub fn load_res_block(path: &str, device: &B::Device) -> Result, Box> { let norm_in = load_group_norm::(&format!("{}/{}", path, "norm_in"), device)?; let conv_in = load_conv2d::(&format!("{}/{}", path, "conv_in"), device)?; let lin_embed = load_linear::(&format!("{}/{}", path, "lin_embed"), device)?; let norm_out = load_group_norm::(&format!("{}/{}", path, "norm_out"), device)?; let conv_out = load_conv2d::(&format!("{}/{}", path, "conv_out"), device)?; let skip_connection = load_conv2d::(&format!("{}/{}", path, "skip_connection"), device).ok(); let res_block = ResBlock { norm_in: norm_in, silu_in: SILU::new(), conv_in: conv_in, silu_embed: SILU::new(), lin_embed: lin_embed, norm_out: norm_out, silu_out: SILU::new(), conv_out: conv_out, skip_connection: skip_connection, }; Ok(res_block) } pub fn load_multi_head_attention(path: &str, device: &B::Device) -> Result, Box> { let n_head = load_usize::("n_head", path, device)?; let query = load_linear::(&format!("{}/{}", path, "query"), device)?; let key = load_linear::(&format!("{}/{}", path, "key"), device)?; let value = load_linear::(&format!("{}/{}", path, "value"), device)?; let out = load_linear::(&format!("{}/{}", path, "out"), device)?; let multi_head_attention = MultiHeadAttention { n_head: n_head, query: query, key: key, value: value, out: out, }; Ok(multi_head_attention) } pub fn load_geglu(path: &str, device: &B::Device) -> Result, Box> { let proj = load_linear::(&format!("{}/{}", path, "proj"), device)?; let geglue = GEGLU { proj: proj, gelu: GELU::new(), // Assuming GELU::new() initializes a new GELU struct }; Ok(geglue) } pub fn load_mlp(path: &str, device: &B::Device) -> Result, Box> { let geglu = load_geglu::(&format!("{}/{}", path, "geglu"), device)?; let lin = load_linear::(&format!("{}/{}", path, "lin"), device)?; let mlp = MLP { geglu: geglu, lin: lin, }; Ok(mlp) } pub fn load_transformer_block(path: &str, device: &B::Device) -> Result, Box> { let norm1 = load_layer_norm::(&format!("{}/{}", path, "norm1"), device)?; let attn1 = load_multi_head_attention::(&format!("{}/{}", path, "attn1"), device)?; let norm2 = load_layer_norm::(&format!("{}/{}", path, "norm2"), device)?; let attn2 = load_multi_head_attention::(&format!("{}/{}", path, "attn2"), device)?; let norm3 = load_layer_norm::(&format!("{}/{}", path, "norm3"), device)?; let mlp = load_mlp::(&format!("{}/{}", path, "mlp"), device)?; let transformer_block = TransformerBlock { norm1: norm1, attn1: attn1, norm2: norm2, attn2: attn2, norm3: norm3, mlp: mlp, }; Ok(transformer_block) } pub fn load_spatial_transformer(path: &str, device: &B::Device) -> Result, Box> { let norm = load_group_norm::(&format!("{}/{}", path, "norm"), device)?; let proj_in = load_conv2d::(&format!("{}/{}", path, "proj_in"), device)?; let transformer = load_transformer_block::(&format!("{}/{}", path, "transformer"), device)?; let proj_out = load_conv2d::(&format!("{}/{}", path, "proj_out"), device)?; let spatial_transformer = SpatialTransformer { norm: norm, proj_in: proj_in, transformer: transformer, proj_out: proj_out, }; Ok(spatial_transformer) } pub fn load_upsample(path: &str, device: &B::Device) -> Result, Box> { let conv = load_conv2d::(&format!("{}/{}", path, "conv"), device)?; let upsample = Upsample { conv: conv, }; Ok(upsample) } pub fn load_downsample(path: &str, device: &B::Device) -> Result, Box> { load_conv2d(path, device) } pub fn load_res_transformer_res(path: &str, device: &B::Device) -> Result, Box> { let res1 = load_res_block::(&format!("{}/{}", path, "res1"), device)?; // Assuming load_res_block function let transformer = load_spatial_transformer::(&format!("{}/{}", path, "transformer"), device)?; let res2 = load_res_block::(&format!("{}/{}", path, "res2"), device)?; let res_transformer_res = ResTransformerRes { res1: res1, transformer: transformer, res2: res2, }; Ok(res_transformer_res) } pub fn load_res_transformer_upsample(path: &str, device: &B::Device) -> Result, Box> { let res = load_res_block::(&format!("{}/{}", path, "res"), device)?; let transformer = load_spatial_transformer::(&format!("{}/{}", path, "transformer"), device)?; let upsample = load_upsample::(&format!("{}/{}", path, "upsample"), device)?; let res_transformer_upsample = ResTransformerUpsample { res: res, transformer: transformer, upsample: upsample, }; Ok(res_transformer_upsample) } pub fn load_res_upsample(path: &str, device: &B::Device) -> Result, Box> { let res = load_res_block::(&format!("{}/{}", path, "res"), device)?; let upsample = load_upsample::(&format!("{}/{}", path, "upsample"), device)?; let res_upsample = ResUpSample { res: res, upsample: upsample, }; Ok(res_upsample) } pub fn load_res_transformer(path: &str, device: &B::Device) -> Result, Box> { let res = load_res_block::(&format!("{}/{}", path, "res"), device)?; let transformer = load_spatial_transformer::(&format!("{}/{}", path, "transformer"), device)?; let res_transformer = ResTransformer { res: res, transformer: transformer, }; Ok(res_transformer) } pub fn load_unet_input_blocks(path: &str, device: &B::Device) -> Result, Box> { let conv = load_conv2d::(&format!("{}/{}", path, "conv"), device)?; let rt1 = load_res_transformer::(&format!("{}/{}", path, "rt1"), device)?; let rt2 = load_res_transformer::(&format!("{}/{}", path, "rt2"), device)?; let d1 = load_downsample::(&format!("{}/{}", path, "d1"), device)?; let rt3 = load_res_transformer::(&format!("{}/{}", path, "rt3"), device)?; let rt4 = load_res_transformer::(&format!("{}/{}", path, "rt4"), device)?; let d2 = load_downsample::(&format!("{}/{}", path, "d2"), device)?; let rt5 = load_res_transformer::(&format!("{}/{}", path, "rt5"), device)?; let rt6 = load_res_transformer::(&format!("{}/{}", path, "rt6"), device)?; let d3 = load_downsample::(&format!("{}/{}", path, "d3"), device)?; let r1 = load_res_block::(&format!("{}/{}", path, "r1"), device)?; let r2 = load_res_block::(&format!("{}/{}", path, "r2"), device)?; let unet_input_blocks = UNetInputBlocks { conv: conv, rt1: rt1, rt2: rt2, d1: d1, rt3: rt3, rt4: rt4, d2: d2, rt5: rt5, rt6: rt6, d3: d3, r1: r1, r2: r2, }; Ok(unet_input_blocks) } pub fn load_unet_output_blocks(path: &str, device: &B::Device) -> Result, Box> { let r1 = load_res_block::(&format!("{}/{}", path, "r1"), device)?; let r2 = load_res_block::(&format!("{}/{}", path, "r2"), device)?; let ru = load_res_upsample::(&format!("{}/{}", path, "ru"), device)?; let rt1 = load_res_transformer::(&format!("{}/{}", path, "rt1"), device)?; let rt2 = load_res_transformer::(&format!("{}/{}", path, "rt2"), device)?; let rtu1 = load_res_transformer_upsample::(&format!("{}/{}", path, "rtu1"), device)?; let rt3 = load_res_transformer::(&format!("{}/{}", path, "rt3"), device)?; let rt4 = load_res_transformer::(&format!("{}/{}", path, "rt4"), device)?; let rtu2 = load_res_transformer_upsample::(&format!("{}/{}", path, "rtu2"), device)?; let rt5 = load_res_transformer::(&format!("{}/{}", path, "rt5"), device)?; let rt6 = load_res_transformer::(&format!("{}/{}", path, "rt6"), device)?; let rt7 = load_res_transformer::(&format!("{}/{}", path, "rt7"), device)?; Ok(UNetOutputBlocks { r1, r2, ru, rt1, rt2, rtu1, rt3, rt4, rtu2, rt5, rt6, rt7, }) } pub fn load_unet(path: &str, device: &B::Device) -> Result, Box> { let lin1_time_embed = load_linear::(&format!("{}/{}", path, "lin1_time_embed"), device)?; let silu_time_embed = SILU::new(); // Assuming SILU::new() initializes a new SILU struct let lin2_time_embed = load_linear::(&format!("{}/{}", path, "lin2_time_embed"), device)?; let input_blocks = load_unet_input_blocks::(&format!("{}/{}", path, "input_blocks"), device)?; let middle_block = load_res_transformer_res::(&format!("{}/{}", path, "middle_block"), device)?; let output_blocks = load_unet_output_blocks::(&format!("{}/{}", path, "output_blocks"), device)?; let norm_out = load_group_norm::(&format!("{}/{}", path, "norm_out"), device)?; let silu_out = SILU::new(); // Assuming SILU::new() initializes a new SILU struct let conv_out = load_conv2d::(&format!("{}/{}", path, "conv_out"), device)?; Ok(UNet { lin1_time_embed, silu_time_embed, lin2_time_embed, input_blocks, middle_block, output_blocks, norm_out, silu_out, conv_out, }) }