You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
53 lines
1.8 KiB
53 lines
1.8 KiB
#!/usr/bin/env python3
|
|
"""Convert a diffusers-format model to a single safetensors checkpoint."""
|
|
|
|
import argparse
|
|
import torch
|
|
from pathlib import Path
|
|
|
|
|
|
# Renaming tables from diffusers parameter names to the original (CompVis /
# LDM) Stable Diffusion 1.x checkpoint names. Adapted from diffusers'
# scripts/convert_diffusers_to_original_stable_diffusion.py. Unless noted
# otherwise, each entry is a (stable-diffusion fragment, diffusers fragment)
# pair applied with plain substring replacement, in list order.

# UNet parameters whose whole name differs (diffusers name -> SD name).
_UNET_EXACT_MAP = {
    "time_embedding.linear_1.weight": "time_embed.0.weight",
    "time_embedding.linear_1.bias": "time_embed.0.bias",
    "time_embedding.linear_2.weight": "time_embed.2.weight",
    "time_embedding.linear_2.bias": "time_embed.2.bias",
    "conv_in.weight": "input_blocks.0.0.weight",
    "conv_in.bias": "input_blocks.0.0.bias",
    "conv_norm_out.weight": "out.0.weight",
    "conv_norm_out.bias": "out.0.bias",
    "conv_out.weight": "out.2.weight",
    "conv_out.bias": "out.2.bias",
}

# Sub-module renames applied inside every UNet resnet block.
_UNET_RESNET_MAP = [
    ("in_layers.0", "norm1"),
    ("in_layers.2", "conv1"),
    ("out_layers.0", "norm2"),
    ("out_layers.3", "conv2"),
    ("emb_layers.1", "time_emb_proj"),
    ("skip_connection", "conv_shortcut"),
]


def _build_unet_layer_map():
    """Return (SD prefix, diffusers prefix) pairs for the SD 1.x UNet blocks.

    The block counts (4 down/up blocks, 2 resnets per down block, 3 per up
    block) are hard-coded for the SD 1.x UNet layout.
    """
    pairs = []
    for i in range(4):
        for j in range(2):
            # Down-block resnets and attentions; down_blocks.3 has no attention.
            pairs.append((f"input_blocks.{3 * i + j + 1}.0.", f"down_blocks.{i}.resnets.{j}."))
            if i < 3:
                pairs.append((f"input_blocks.{3 * i + j + 1}.1.", f"down_blocks.{i}.attentions.{j}."))
        for j in range(3):
            # Up-block resnets and attentions; up_blocks.0 has no attention.
            pairs.append((f"output_blocks.{3 * i + j}.0.", f"up_blocks.{i}.resnets.{j}."))
            if i > 0:
                pairs.append((f"output_blocks.{3 * i + j}.1.", f"up_blocks.{i}.attentions.{j}."))
        if i < 3:
            # The last down block has no downsampler and the last up block
            # has no upsampler.
            pairs.append((f"input_blocks.{3 * (i + 1)}.0.op.", f"down_blocks.{i}.downsamplers.0.conv."))
            pairs.append((f"output_blocks.{3 * i + 2}.{1 if i == 0 else 2}.", f"up_blocks.{i}.upsamplers.0."))
    pairs.append(("middle_block.1.", "mid_block.attentions.0."))
    for j in range(2):
        pairs.append((f"middle_block.{2 * j}.", f"mid_block.resnets.{j}."))
    return pairs


_UNET_LAYER_MAP = _build_unet_layer_map()


def _build_vae_map():
    """Return (SD fragment, diffusers fragment) pairs for the SD 1.x VAE."""
    pairs = [
        ("nin_shortcut", "conv_shortcut"),
        ("norm_out", "conv_norm_out"),
        ("mid.attn_1.", "mid_block.attentions.0."),
    ]
    for i in range(4):
        # Encoder down blocks hold two resnets each.
        for j in range(2):
            pairs.append((f"encoder.down.{i}.block.{j}.", f"encoder.down_blocks.{i}.resnets.{j}."))
        if i < 3:
            # No resampling layer on the last down block / last up block.
            pairs.append((f"down.{i}.downsample.", f"down_blocks.{i}.downsamplers.0."))
            pairs.append((f"up.{3 - i}.upsample.", f"up_blocks.{i}.upsamplers.0."))
        # Decoder up blocks hold three resnets and are numbered in reverse.
        for j in range(3):
            pairs.append((f"decoder.up.{3 - i}.block.{j}.", f"decoder.up_blocks.{i}.resnets.{j}."))
    # Mid-block resnets (same naming in encoder and decoder).
    for i in range(2):
        pairs.append((f"mid.block_{i + 1}.", f"mid_block.resnets.{i}."))
    return pairs


_VAE_MAP = _build_vae_map()

# Renames inside the VAE mid-block attention. Both the current diffusers
# names (to_q, ...) and the older ones (query, ...) are accepted.
_VAE_ATTN_MAP = [
    ("norm.", "group_norm."),
    ("q.", "to_q."),
    ("k.", "to_k."),
    ("v.", "to_v."),
    ("proj_out.", "to_out.0."),
    ("q.", "query."),
    ("k.", "key."),
    ("v.", "value."),
    ("proj_out.", "proj_attn."),
]

# Attention projections stored as Linear in diffusers but as 1x1 Conv2d in
# the original VAE.
_VAE_CONV_ATTN_WEIGHTS = ("q", "k", "v", "proj_out")


def _convert_unet_key(hf_key):
    """Map one diffusers UNet parameter name to its SD 1.x name."""
    if hf_key in _UNET_EXACT_MAP:
        return _UNET_EXACT_MAP[hf_key]
    key = hf_key
    if "resnets" in hf_key:
        for sd_part, hf_part in _UNET_RESNET_MAP:
            key = key.replace(hf_part, sd_part)
    for sd_part, hf_part in _UNET_LAYER_MAP:
        key = key.replace(hf_part, sd_part)
    return key


def convert_unet_state_dict(unet_state_dict):
    """Return *unet_state_dict* with keys renamed to the SD 1.x UNet layout.

    Tensors are passed through unchanged. The ``model.diffusion_model.``
    prefix is not added here.
    """
    return {_convert_unet_key(k): v for k, v in unet_state_dict.items()}


def convert_vae_state_dict(vae_state_dict):
    """Return *vae_state_dict* with keys renamed to the SD 1.x VAE layout.

    Mid-block attention q/k/v/proj_out weights that are 2-D (Linear) are
    reshaped to ``(out, in, 1, 1)`` to match the original 1x1 Conv2d
    layers. The ``first_stage_model.`` prefix is not added here.
    """
    converted = {}
    for hf_key, tensor in vae_state_dict.items():
        key = hf_key
        for sd_part, hf_part in _VAE_MAP:
            key = key.replace(hf_part, sd_part)
        if "attentions" in hf_key:
            for sd_part, hf_part in _VAE_ATTN_MAP:
                key = key.replace(hf_part, sd_part)
            if tensor.ndim == 2 and any(
                key.endswith(f"mid.attn_1.{name}.weight") for name in _VAE_CONV_ATTN_WEIGHTS
            ):
                tensor = tensor.reshape(*tensor.shape, 1, 1)
        converted[key] = tensor
    return converted


def convert(model_path: str, output_path: str, half: bool = True):
    """Convert a diffusers Stable Diffusion pipeline into one .safetensors file.

    The UNet and VAE parameter names are renamed to the original (LDM)
    checkpoint layout and prefixed with ``model.diffusion_model.`` and
    ``first_stage_model.`` respectively. Text-encoder parameters are
    prefixed with ``cond_stage_model.transformer.``. The renaming tables
    and the text-encoder prefix target SD 1.x checkpoints.

    Args:
        model_path: Passed straight to
            ``StableDiffusionPipeline.from_pretrained``.
        output_path: Destination file. Missing parent directories are
            created.
        half: Load (and therefore save) weights as float16 when True,
            float32 otherwise.
    """
    # Imported lazily so ``--help`` works without the heavy dependencies.
    from diffusers import StableDiffusionPipeline
    from safetensors.torch import save_file

    dtype = torch.float16 if half else torch.float32

    print(f"Loading diffusers model from {model_path}...")
    pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=dtype)

    state_dict = {}

    print("Converting UNet...")
    for k, v in convert_unet_state_dict(pipe.unet.state_dict()).items():
        state_dict[f"model.diffusion_model.{k}"] = v

    print("Converting text encoder...")
    for k, v in pipe.text_encoder.state_dict().items():
        state_dict[f"cond_stage_model.transformer.{k}"] = v

    print("Converting VAE...")
    for k, v in convert_vae_state_dict(pipe.vae.state_dict()).items():
        state_dict[f"first_stage_model.{k}"] = v

    # safetensors refuses to serialize non-contiguous tensors.
    state_dict = {k: v.contiguous() for k, v in state_dict.items()}

    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    print(f"Saving to {output_path}...")
    save_file(state_dict, output_path)
    print("Done!")
|
|
|
|
|
|
# Where stable-diffusion-webui looks for checkpoints; used when no explicit
# output path is given on the command line.
DEFAULT_MODELS_DIR = "/var/opt/stable-diffusion-webui/data/models/Stable-diffusion"


def default_output_path(model_path: Path, models_dir: str = DEFAULT_MODELS_DIR) -> str:
    """Return ``<models_dir>/<model name>.safetensors`` for *model_path*.

    The model name is the last path component. For paths such as ``.`` or
    ``..``, whose last component is not a real name, the name of the
    resolved directory is used instead.
    """
    name = model_path.name
    if name in ("", ".", ".."):
        name = model_path.resolve().name
    return str(Path(models_dir) / f"{name}.safetensors")


def main():
    """Parse command-line arguments and run ``convert``."""
    parser = argparse.ArgumentParser(description="Convert diffusers model to safetensors")
    parser.add_argument("model_path", help="Path to diffusers model directory")
    parser.add_argument("output_path", nargs="?", help="Output safetensors file path (default: model name in SD models dir)")
    parser.add_argument("--full", action="store_true", help="Use float32 instead of float16")
    parser.add_argument(
        "--models-dir",
        default=DEFAULT_MODELS_DIR,
        help="SD models directory used when output_path is omitted (default: %(default)s)",
    )
    args = parser.parse_args()

    model_path = Path(args.model_path)
    # An empty output_path falls back to the default, as before.
    output_path = args.output_path or default_output_path(model_path, args.models_dir)

    convert(str(model_path), output_path, half=not args.full)
|
|
|
|
|
|
# Run the command-line interface only when executed as a script; importing
# this module does not trigger a conversion.
if __name__ == "__main__":
    main()
|
|
|