diff --git a/Cargo.toml b/Cargo.toml index 6a2f234..1dbeb84 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,13 +5,26 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[features] +default = ["torch-backend"] +torch-backend = ["burn-tch"] +wgpu-backend = ["burn-wgpu"] + +[dependencies.burn-tch] +version = "0.8.0" +optional = true + +[dependencies.burn-wgpu] +version = "0.8.0" +optional = true + [dependencies] burn = "0.8.0" -burn-tch = "0.8.0" serde = {version = "1.0.171", features = ["std", "derive"]} npy = "0.4.0" num-traits = "0.2.15" rust_tokenizers = "8.1.0" regex = "1.9.1" image = "0.24.6" -bincode = {version = "2.0.0-alpha.0", features = ["std"]} \ No newline at end of file +bincode = {version = "2.0.0-alpha.0", features = ["std"]} +cfg-if = "0.1" \ No newline at end of file diff --git a/README.md b/README.md index 491bc22..efc9fbe 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ Start by downloading the SDv1-4.bin model provided on HuggingFace. wget https://huggingface.co/Gadersd/Stable-Diffusion-Burn/resolve/main/V1/SDv1-4.bin ``` -Next, set the appropriate CUDA version. +Next, set the appropriate CUDA version. It may be possible to run the model using wgpu without the need for torch in the future using `cargo run --features wgpu-backend...` but currently wgpu doesn't support buffer sizes large enough for Stable Diffusion. ```bash export TORCH_CUDA_VERSION=cu113 diff --git a/src/bin/convert/main.rs b/src/bin/convert/main.rs index 223aa89..d2f6b53 100644 --- a/src/bin/convert/main.rs +++ b/src/bin/convert/main.rs @@ -14,7 +14,13 @@ use burn::{ }, }; -use burn_tch::{TchBackend, TchDevice}; +cfg_if::cfg_if! { + if #[cfg(feature = "torch-backend")] { + use burn_tch::{TchBackend, TchDevice}; + } else if #[cfg(feature = "wgpu-backend")] { + use burn_wgpu::{WgpuBackend, WgpuDevice, AutoGraphicsApi}; + } +} use burn::record::{self, Recorder, FullPrecisionSettings}; use stablediffusion::binrecorderfast::{BinFileRecorderBuffered}; @@ -38,8 +44,15 @@ fn save_model_file(model: StableDiffusion, name: &str) -> Result< } fn main() { - type Backend = TchBackend; - let device = TchDevice::Cpu; + cfg_if::cfg_if! { + if #[cfg(feature = "torch-backend")] { + type Backend = TchBackend; + let device = TchDevice::Cpu; + } else if #[cfg(feature = "wgpu-backend")] { + type Backend = WgpuBackend; + let device = WgpuDevice::CPU; + } + } let args: Vec = env::args().collect(); if args.len() != 3 { diff --git a/src/bin/sample/main.rs b/src/bin/sample/main.rs index 8efeea0..245e92f 100644 --- a/src/bin/sample/main.rs +++ b/src/bin/sample/main.rs @@ -9,7 +9,14 @@ use burn::{ Tensor, }, }; -use burn_tch::{TchBackend, TchDevice}; + +cfg_if::cfg_if! { + if #[cfg(feature = "torch-backend")] { + use burn_tch::{TchBackend, TchDevice}; + } else if #[cfg(feature = "wgpu-backend")] { + use burn_wgpu::{WgpuBackend, WgpuDevice, AutoGraphicsApi}; + } +} use std::env; use std::io; @@ -25,9 +32,15 @@ fn load_stable_diffusion_model_file(filename: &str) -> Result; - //let device = TchDevice::Cpu; - let device = TchDevice::Cuda(0); + cfg_if::cfg_if! { + if #[cfg(feature = "torch-backend")] { + type Backend = TchBackend; + let device = TchDevice::Cuda(0); + } else if #[cfg(feature = "wgpu-backend")] { + type Backend = WgpuBackend; + let device = WgpuDevice::BestAvailable; + } + } let args: Vec = std::env::args().collect(); if args.len() != 7 {