comfykiosk/comfykiosk/image_sources/pregenerate.py
2024-12-25 09:30:53 -05:00

113 lines
5.2 KiB
Python

import asyncio
import logging
import warnings
from typing import Dict, List, Tuple # Import List and Tuple
from weakref import WeakKeyDictionary
from comfykiosk import generate_seed
from comfykiosk.image_sources import ImageSource
from comfykiosk.workflow import Workflow # Import Workflow
class PreparedGenPool(ImageSource):
    """Image source that serves pre-generated images from per-workflow buffers.

    A background task (``replenish``) keeps a list of seeds per registered
    workflow topped up towards ``bucket_max``. It generates images with
    ``generator`` and stores them with ``saver``. ``get_image`` pops the oldest
    buffered seed for the requested workflow and loads that image back through
    ``saver``.

    Args:
        bucket_max: Target number of buffered images per workflow.
        batch_size: Images generated for one workflow before the pool
            re-evaluates which workflow is most depleted.
        registered_workflows: Mapping whose *values* are the workflows to keep
            buffered. ``None`` is stored as an empty dict.
        max_retries: Attempts per image before giving up on that image.
        initial_delay: First backoff sleep (seconds) after a failed attempt.
        max_delay: Upper bound (seconds) for any single backoff sleep.
        **kwargs: Must contain ``generator``. ``saver`` is also read from here;
            it is used by ``replenish`` and ``get_image``, so it presumably must
            be set before either runs.

    Raises:
        ValueError: If no ``generator`` is supplied.
    """

    def __init__(self, bucket_max: int = 10, batch_size: int = 5, registered_workflows: Dict = None,
                 max_retries: int = 10, initial_delay: float = 1.0, max_delay: float = 60.0, **kwargs):
        super().__init__()
        self.generator = kwargs.get('generator')
        if self.generator is None:
            raise ValueError("The 'generator' argument is required for PreparedGenPool.")
        self.saver = kwargs.get('saver')
        self.bucket_max = bucket_max
        self.replenish_batch_size = batch_size
        # Workflow -> list of buffered seeds (oldest first). Weak keys, so a
        # workflow object that is dropped elsewhere also drops its buffer.
        self.image_queues = WeakKeyDictionary()
        self.replenish_task = None
        self.registered_workflows = registered_workflows or {}
        # Retry configuration
        self.max_retries = max_retries
        self.initial_delay = initial_delay
        self.max_delay = max_delay

    async def replenish(self):
        """Fill every registered workflow's queue up to ``bucket_max``.

        The loop picks the workflow with the fewest buffered images and
        generates a batch for it, until every workflow is full.

        A workflow whose whole batch fails is skipped for the rest of this
        pass. This stops one broken workflow, or an unreachable backend, from
        spinning the loop forever and starving the other workflows. The next
        ``start_replenish`` call (made by every ``get_image``) tries it again.
        """
        # Check for no assigned variable, not an empty list
        if self.registered_workflows is None:
            warnings.warn("Assign `registered_workflows` to the `PreparedGenPool` instance, or pass it to a ComfyKiosk instance.", UserWarning)
            return  # Nothing to replenish; `.values()` on None would raise.
        logging.info("Replenishing image queue...")
        failed_workflows = set()
        while True:  # Keep at it until there's nothing left.
            target_workflow = self._pick_target(failed_workflows)
            if target_workflow is None:
                break  # Nothing registered, everything full, or only failing workflows left
            produced = await self._generate_batch(target_workflow)
            if produced == 0:
                logging.error(
                    f"No images could be generated for workflow {target_workflow!r}; "
                    "skipping it until the next replenishment pass."
                )
                failed_workflows.add(target_workflow)
            self._log_buffer_status()

    def _pick_target(self, skip):
        """Return the non-skipped workflow with the smallest queue, or None if it is already full."""
        items: List[Tuple[int, Workflow]] = [
            (len(self.image_queues.get(workflow, [])), workflow)
            for workflow in self.registered_workflows.values()
            if workflow not in skip
        ]
        if not items:
            return None  # No workflows registered (or all of them failed this pass)
        # Compare by queue size only; ties go to the first workflow in mapping order.
        min_queue_size, target_workflow = min(items, key=lambda item: item[0])
        if min_queue_size >= self.bucket_max:
            return None  # Even the emptiest queue is full
        return target_workflow

    async def _generate_batch(self, workflow) -> int:
        """Generate up to ``replenish_batch_size`` images for *workflow*; return how many succeeded."""
        # Done in batches because switching between workflows likely means loading new models into memory
        # This can more than triple the time it takes to generate an image. Batching images amortizes that cost.
        # Note that only one request is made to the backend at any time, so we must wait for any queue at the
        # ComfyUI server to be clear before our request is serviced, for each image.
        # Exponential backoff; the delay is shared across this batch and resets with the next batch.
        delay = self.initial_delay
        produced = 0
        for _ in range(self.replenish_batch_size):
            for attempt in range(1, self.max_retries + 1):
                try:
                    seed = generate_seed()
                    image_data, _media_type = await self.generator.generate_image(seed=seed, workflow=workflow)
                    await self.saver.save_image(seed, image_data, workflow=workflow)
                except Exception as e:
                    if attempt >= self.max_retries:
                        logging.error(f"Failed to generate image after {self.max_retries} attempts: {e}")
                        break  # Give up on this image, move on to the next one
                    logging.warning(f"Error generating image (attempt {attempt}/{self.max_retries}): {e}")
                    await asyncio.sleep(min(delay, self.max_delay))
                    delay = min(delay * 2, self.max_delay)  # Exponential backoff, kept bounded
                else:
                    self.image_queues.setdefault(workflow, []).append(seed)
                    produced += 1
                    break  # Success - continue to next image
        return produced

    def _log_buffer_status(self):
        """Log how many seeds are currently buffered for each workflow."""
        if not self.image_queues:
            logging.info("Buffer Status: All queues are empty")
            return
        status = ["Buffer Status:"]
        for workflow, queue in self.image_queues.items():
            status.append(f" {workflow.handle}: {len(queue)} images")
        logging.info("\n".join(status))

    def start_replenish(self):
        """Schedule ``replenish`` as a task unless one is already running.

        Uses ``asyncio.create_task``, so it must be called while an event loop is running.
        """
        if self.replenish_task is None or self.replenish_task.done():  # Start only if not already running
            self.replenish_task = asyncio.create_task(self.replenish())
            self.replenish_task.add_done_callback(self._replenish_finished)

    async def get_image(self, seed=None, workflow: Workflow = None):
        """Return the oldest pre-generated image for *workflow*.

        The *seed* argument is ignored: the seed comes from the buffer. Every
        call also (re)starts replenishment.

        Raises:
            ValueError: If *workflow* is None.
            asyncio.QueueEmpty: If nothing is buffered for *workflow* yet.
        """
        if workflow is None:
            raise ValueError("The 'workflow' argument is required for PreparedGenPool.get_image().")
        self.start_replenish()
        image_queue = self.image_queues.get(workflow)
        if not image_queue:
            raise asyncio.QueueEmpty()
        seed = image_queue.pop(0)
        return await self.saver.get_image(seed, workflow=workflow)

    def _replenish_finished(self, task):
        """Done-callback for the replenish task: release the handle and report failures."""
        # Only clear the handle if it still refers to *this* task. This callback
        # runs after the task is done, and start_replenish() (which checks .done())
        # may already have stored a newer task. Clearing that reference would let a
        # duplicate replenish loop start and leave the running task unreferenced.
        if self.replenish_task is task:
            self.replenish_task = None
        if task.cancelled():
            return
        # Retrieve the exception so it is logged rather than silently lost.
        exc = task.exception()
        if exc is not None:
            logging.error("Image replenishment task failed", exc_info=exc)

    async def on_app_startup(self):
        """Start the replenishment loop when the FastAPI app starts"""
        self.start_replenish()