import threading
import datetime
import random
import base64
import io
import json
import os

from diffusers import (
    StableDiffusionPipeline,
    StableDiffusionXLPipeline,
    DPMSolverMultistepScheduler,
)
import torch

from models import GenerationOptions, ImageParams, ImageMetadata, ImageResult


# --- Model Loaders ---
# To add a new model type, create a loader function and register it in MODEL_LOADERS

def load_sd15(model_path, device, is_single_file):
    """Load Stable Diffusion 1.5 model."""
    if is_single_file:
        pipe = StableDiffusionPipeline.from_single_file(
            model_path,
            torch_dtype=torch.float16,
        )
        pipe.safety_checker = None
        pipe.requires_safety_checker = False
    else:
        pipe = StableDiffusionPipeline.from_pretrained(
            model_path,
            torch_dtype=torch.float16,
            safety_checker=None,
        )
    return pipe


def load_sdxl(model_path, device, is_single_file):
    """Load Stable Diffusion XL model."""
    if is_single_file:
        pipe = StableDiffusionXLPipeline.from_single_file(
            model_path,
            torch_dtype=torch.float16,
        )
    else:
        pipe = StableDiffusionXLPipeline.from_pretrained(
            model_path,
            torch_dtype=torch.float16,
        )
    return pipe


MODEL_LOADERS = {
    "sd15": load_sd15,
    "sdxl": load_sdxl,
}
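
# For illustration only: the registration pattern described in the note above,
# with a hypothetical "sd21" model type (name and loader are assumptions, not
# part of this service). SD 2.x checkpoints use the same pipeline class as 1.5.
#
# def load_sd21(model_path, device, is_single_file):
#     """Load a Stable Diffusion 2.x model."""
#     if is_single_file:
#         return StableDiffusionPipeline.from_single_file(model_path, torch_dtype=torch.float16)
#     return StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
#
# MODEL_LOADERS["sd21"] = load_sd21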


# --- Pipeline Manager ---
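# Configuration comes from environment variables read in SDPipeline.__init__
# below. An illustrative setup (hypothetical paths):
#   SD_MODEL_PATH=./models/sd_xl_base_1.0.safetensors SD_MODEL_TYPE=sdxl
#   SD_LOW_VRAM=1 SD_LORA_STACK=./loras/detail.safetensors:0.8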


class SDPipeline:
    _instance = None
    _lock = threading.Lock()

    def __new__(cls):
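        # Double-checked locking: check once without the lock for speed, then
        # re-check under the lock so only one thread ever creates the instance.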
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
                    cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        if self._initialized:
            return
        self._initialized = True
        self._generation_lock = threading.Lock()
        self.device = "cuda"
        self.pipe = None
        self.model_path = os.environ.get("SD_MODEL_PATH", "./models/realistic-vision-v51")
        self.model_type = os.environ.get("SD_MODEL_TYPE", "sd15")
        self.low_vram = os.environ.get("SD_LOW_VRAM", "").lower() in ("1", "true", "yes")
        self.lora_stack = self._parse_lora_stack(os.environ.get("SD_LORA_STACK", ""))
        self.quality_keywords = "hyper detail, Canon50, cinematic lighting, realistic, f/1.4, ISO 200, 1/160s, 8K, RAW, unedited"

    def _parse_lora_stack(self, lora_env: str) -> list[tuple[str, float]]:
        """Parse SD_LORA_STACK env var into list of (path, weight) tuples.

        Format: path/to/lora.safetensors:0.8,path/to/other.safetensors:0.5
        """
        if not lora_env.strip():
            return []

        result = []
        for entry in lora_env.split(","):
            entry = entry.strip()
            if not entry:
                continue
            if ":" in entry:
                # rsplit keeps any ":" inside the path itself intact
                path, weight_str = entry.rsplit(":", 1)
                weight = float(weight_str)
            else:
                path = entry
                weight = 1.0
            result.append((path, weight))
        return result

    def load(self):
        """Load the model into GPU memory."""
        if self.pipe is not None:
            return

        if not os.path.exists(self.model_path):
            raise FileNotFoundError(f"Model not found: {self.model_path}")

        if self.model_type not in MODEL_LOADERS:
            available = ", ".join(MODEL_LOADERS.keys())
            raise ValueError(f"Unknown model type '{self.model_type}'. Available: {available}")

        print(f"Loading model ({self.model_type}) from {self.model_path}...")

        is_single_file = self.model_path.endswith((".safetensors", ".ckpt"))
        loader = MODEL_LOADERS[self.model_type]
        self.pipe = loader(self.model_path, self.device, is_single_file)

        self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(
            self.pipe.scheduler.config,
            use_karras_sigmas=True,
        )

        if self.low_vram:
            self.pipe.enable_sequential_cpu_offload()
            self.pipe.vae.enable_slicing()
            self.pipe.vae.enable_tiling()
            print("Low VRAM mode: enabled sequential CPU offload and VAE slicing/tiling")
        else:
            self.pipe = self.pipe.to(self.device)
            self.pipe.enable_attention_slicing()

        self._load_loras()

        print("Model loaded successfully!")

    def _load_loras(self):
        """Load LoRA weights from SD_LORA_STACK configuration."""
        if not self.lora_stack:
            return

        adapter_names = []
        adapter_weights = []

        for i, (path, weight) in enumerate(self.lora_stack):
            if not os.path.exists(path):
                print(f"Warning: LoRA not found, skipping: {path}")
                continue

            adapter_name = f"lora_{i}"
            print(f"Loading LoRA: {path} (weight={weight})")
            self.pipe.load_lora_weights(path, adapter_name=adapter_name)
            adapter_names.append(adapter_name)
            adapter_weights.append(weight)

        if adapter_names:
            self.pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)
            print(f"Loaded {len(adapter_names)} LoRA(s)")

    def generate_stream(self, options: GenerationOptions):
        """Generate images and yield results one by one."""
        if self.pipe is None:
            self.load()

        seed = options.seed if options.seed is not None else self._random_seed()

        with self._generation_lock:
            for i in range(options.count):
                params = self._compute_params(options, seed, i)
                full_prompt = f"{options.prompt}, {self.quality_keywords}" if options.add_quality_keywords else options.prompt

                image = self._generate_image(full_prompt, options.negative_prompt, params, options.width, options.height)
                result = self._save_and_encode(image, options, params, full_prompt, i)
                yield result

    def _compute_params(self, options: GenerationOptions, seed: int, index: int) -> ImageParams:
        """Compute generation parameters for a single image."""
        current_seed = seed + index if options.increment_seed else seed

        if options.vary_guidance and options.count > 1:
            t = index / (options.count - 1)
            current_guidance = options.guidance_low + t * (options.guidance_high - options.guidance_low)
        else:
            current_guidance = options.guidance_scale

        if options.vary_steps and options.count > 1:
            t = index / (options.count - 1)
            current_steps = int(options.steps_low + t * (options.steps_high - options.steps_low))
        else:
            current_steps = options.steps

        return ImageParams(
            seed=current_seed,
            steps=current_steps,
            guidance_scale=current_guidance,
        )

    def _generate_image(self, prompt: str, negative_prompt: str, params: ImageParams, width: int | None, height: int | None):
        """Run the diffusion pipeline to generate a single image."""
        if self.low_vram:
            torch.cuda.empty_cache()

        # With sequential CPU offload enabled, the generator must live on the CPU.
        gen_device = "cpu" if self.low_vram else self.device
        generator = torch.Generator(device=gen_device)
        generator.manual_seed(params.seed)

        kwargs = {
            "prompt": prompt,
            "num_inference_steps": params.steps,
            "guidance_scale": params.guidance_scale,
            "generator": generator,
        }
        if negative_prompt:
            kwargs["negative_prompt"] = negative_prompt
        if width:
            kwargs["width"] = width
        if height:
            kwargs["height"] = height

        with torch.no_grad():
            result = self.pipe(**kwargs)
        return result.images[0]

    def _save_and_encode(self, image, options: GenerationOptions, params: ImageParams, full_prompt: str, index: int) -> ImageResult:
        """Save image to disk and encode as base64."""
        dt = datetime.datetime.now().strftime("%y-%m-%d_%H-%M-%S")
        base_file = f"out/{dt}_{params.seed}"

        os.makedirs("out", exist_ok=True)  # ensure the output directory exists
        image.save(f"{base_file}.jpg")

        width = options.width or image.width
        height = options.height or image.height

        metadata = ImageMetadata(
            prompt=options.prompt,
            negative_prompt=options.negative_prompt,
            seed=params.seed,
            steps=params.steps,
            guidance_scale=params.guidance_scale,
            width=width,
            height=height,
            add_quality_keywords=options.add_quality_keywords,
            full_prompt=full_prompt,
        )

        with open(f"{base_file}.json", "w") as f:
            json.dump(metadata.to_dict(), f, indent=2)

        buffer = io.BytesIO()
        image.save(buffer, format="JPEG")
        b64_image = base64.b64encode(buffer.getvalue()).decode("utf-8")

        return ImageResult(
            index=index + 1,
            total=options.count,
            filename=f"{dt}_{params.seed}.jpg",
            url=f"/out/{dt}_{params.seed}.jpg",
            base64=f"data:image/jpeg;base64,{b64_image}",
            metadata=metadata,
        )

    def _random_seed(self, length=9):
        """Generate a random seed with the specified number of digits."""
        random.seed()
        min_val = 10 ** (length - 1)
        max_val = 10 ** length - 1
        return random.randint(min_val, max_val)


pipeline = SDPipeline()
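

# A minimal usage sketch, not part of the service itself: it assumes
# GenerationOptions accepts keyword arguments matching the attribute names read
# above (prompt, count, seed, steps, guidance_scale, etc.) -- verify against
# models.py before relying on it.
if __name__ == "__main__":
    demo_options = GenerationOptions(
        prompt="a lighthouse at dusk",  # illustrative prompt
        negative_prompt="",
        seed=None,  # None -> a random 9-digit seed is drawn
        count=1,
        increment_seed=True,
        vary_guidance=False,
        guidance_scale=7.0,
        guidance_low=5.0,
        guidance_high=9.0,
        vary_steps=False,
        steps=25,
        steps_low=15,
        steps_high=40,
        width=512,
        height=512,
        add_quality_keywords=True,
    )
    for image_result in pipeline.generate_stream(demo_options):
        print(f"Saved {image_result.index}/{image_result.total}: {image_result.filename}")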