diff --git a/README.md b/README.md index f524cb2..250d4d3 100644 --- a/README.md +++ b/README.md @@ -217,6 +217,7 @@ Open http://localhost:5000 in your browser. │ Flask Server (app.py) │ │ - GET / → Serve UI │ │ - POST /generate → Stream generated images via SSE │ +│ - POST /stop → Signal generation to stop │ │ - GET /out/ → Serve saved images │ └─────────────────────────┬───────────────────────────────────┘ │ @@ -225,6 +226,7 @@ Open http://localhost:5000 in your browser. │ - Singleton pattern keeps model in GPU memory │ │ - Thread-safe generation with locking │ │ - Yields images one-by-one for streaming │ +│ - Stop flag checked between images for cancellation │ └─────────────────────────┬───────────────────────────────────┘ │ ┌─────────────────────────▼───────────────────────────────────┐ @@ -295,6 +297,14 @@ When vary modes are enabled, the corresponding slider hides and low/high range i - No waiting for entire batch to finish - Each image card shows: seed, steps, guidance scale, prompt, link to saved file +### Stop Generation + +- **Stop button** appears during batch generation +- Signals the pipeline to stop after the current image completes +- Already-generated images are preserved +- The generation mutex is released, allowing new generations immediately +- Useful when you notice a mistake in your prompt mid-batch + ### Settings Management - **Export**: Download current settings as JSON file @@ -346,3 +356,20 @@ data: {"index":1,"total":1,"filename":"...","seed":12345,"steps":20,"guidance_sc data: {"done":true} ``` + +### POST /stop + +Signals the generation loop to stop after the current image. The frontend also aborts the SSE connection. + +Request: Empty body + +Response (JSON): +```json +{"success": true} +``` + +**Implementation notes:** +- Sets a `_stop_requested` flag on the pipeline singleton +- The generation loop checks this flag before and after each image +- The flag is cleared when generation starts or when stop is processed +- Thread-safe: the flag is checked while holding the generation lock diff --git a/app.py b/app.py index ccfba37..e7e3e9b 100644 --- a/app.py +++ b/app.py @@ -66,6 +66,12 @@ def generate(): return Response(generate_events(), mimetype='text/event-stream') +@app.route("/stop", methods=["POST"]) +def stop(): + pipeline.stop() + return jsonify({"success": True}) + + @app.route("/out/") def serve_image(filename): return send_from_directory("out", filename) diff --git a/sd_pipeline.py b/sd_pipeline.py index 4b12f40..968aed7 100644 --- a/sd_pipeline.py +++ b/sd_pipeline.py @@ -76,6 +76,7 @@ class SDPipeline: return self._initialized = True self._generation_lock = threading.Lock() + self._stop_requested = False self.device = "cuda" self.pipe = None self.model_path = os.environ.get("SD_MODEL_PATH", "./models/realistic-vision-v51") @@ -165,19 +166,33 @@ class SDPipeline: self.pipe.set_adapters(adapter_names, adapter_weights=adapter_weights) print(f"Loaded {len(adapter_names)} LoRA(s)") + def stop(self): + """Signal the generation loop to stop.""" + self._stop_requested = True + def generate_stream(self, options: GenerationOptions): """Generate images and yield results one by one.""" if self.pipe is None: self.load() seed = options.seed if options.seed is not None else self._random_seed() + self._stop_requested = False with self._generation_lock: for i in range(options.count): + if self._stop_requested: + self._stop_requested = False + return + params = self._compute_params(options, seed, i) full_prompt = f"{options.prompt}, {self.quality_keywords}" if options.add_quality_keywords else options.prompt image = self._generate_image(full_prompt, options.negative_prompt, params, options.width, options.height) + + if self._stop_requested: + self._stop_requested = False + return + result = self._save_and_encode(image, options, params, full_prompt, i) yield result diff --git a/static/style.css b/static/style.css index 70c5b85..05fa15b 100644 --- a/static/style.css +++ b/static/style.css @@ -46,7 +46,7 @@ label { font-weight: 500; } -textarea, input[type="number"] { +textarea, input[type="number"], select { width: 100%; padding: 12px; border: 1px solid #0f3460; @@ -168,6 +168,36 @@ button[type="submit"]:disabled { cursor: not-allowed; } +.button-row { + display: flex; + gap: 10px; +} + +.button-row button[type="submit"] { + flex: 1; +} + +#stop-btn { + padding: 15px 25px; + background: #7b2d26; + border: none; + border-radius: 6px; + color: #fff; + font-size: 16px; + font-weight: 600; + cursor: pointer; + transition: background 0.2s; +} + +#stop-btn:hover { + background: #a33d33; +} + +#stop-btn:disabled { + background: #555; + cursor: not-allowed; +} + .settings-buttons { display: flex; gap: 10px; diff --git a/templates/index.html b/templates/index.html index cb11a36..106a69b 100644 --- a/templates/index.html +++ b/templates/index.html @@ -110,7 +110,10 @@ - +
+ + +
@@ -151,6 +154,7 @@ const form = document.getElementById('generate-form'); const generateBtn = document.getElementById('generate-btn'); + const stopBtn = document.getElementById('stop-btn'); const statusDiv = document.getElementById('status'); const statusText = document.getElementById('status-text'); const progressContainer = document.getElementById('progress-container'); @@ -167,6 +171,8 @@ let timePerImage = null; let progressInterval = null; let imageStartTime = null; + let isGenerating = false; + let abortController = null; const incrementSeedCheckbox = document.getElementById('increment-seed'); const varyGuidanceCheckbox = document.getElementById('vary-guidance'); @@ -428,6 +434,19 @@ } } + stopBtn.addEventListener('click', async () => { + stopBtn.disabled = true; + stopBtn.textContent = 'Stopping...'; + try { + await fetch('/stop', { method: 'POST' }); + if (abortController) { + abortController.abort(); + } + } catch (e) { + console.error('Failed to stop:', e); + } + }); + form.addEventListener('submit', async (e) => { e.preventDefault(); @@ -437,8 +456,13 @@ return; } + isGenerating = true; + abortController = new AbortController(); generateBtn.disabled = true; generateBtn.textContent = 'Generating...'; + stopBtn.style.display = 'inline-block'; + stopBtn.disabled = false; + stopBtn.textContent = 'Stop'; results.innerHTML = ''; showProgress(true); @@ -485,7 +509,8 @@ headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify(data) + body: JSON.stringify(data), + signal: abortController.signal }); const reader = response.body.getReader(); @@ -542,12 +567,19 @@ } } } catch (error) { - setStatus('Error: ' + error.message, 'error'); + if (error.name === 'AbortError') { + setStatus(imageCount > 0 ? `Stopped after ${imageCount} image(s)` : 'Generation stopped', 'success'); + } else { + setStatus('Error: ' + error.message, 'error'); + } stopProgressTimer(); showProgress(false); } finally { + isGenerating = false; + abortController = null; generateBtn.disabled = false; generateBtn.textContent = 'Generate'; + stopBtn.style.display = 'none'; } });