113 lines
5.2 KiB
Python
113 lines
5.2 KiB
Python
import asyncio
|
|
import logging
|
|
import warnings
|
|
from typing import Dict, List, Tuple # Import List and Tuple
|
|
from weakref import WeakKeyDictionary
|
|
|
|
from comfykiosk import generate_seed
|
|
from comfykiosk.image_sources import ImageSource
|
|
from comfykiosk.workflow import Workflow # Import Workflow
|
|
|
|
|
|
class PreparedGenPool(ImageSource):
    """Image source that serves pre-generated images from per-workflow seed queues.

    For each workflow, the seeds of images that were already generated and saved
    are kept in a FIFO list. ``get_image`` pops the oldest seed for the requested
    workflow and loads that image through ``saver``. A background task
    (``replenish``) keeps the queues topped up using ``generator`` and ``saver``.

    Queues live in a ``WeakKeyDictionary`` keyed by workflow object. A workflow's
    queue therefore disappears once nothing else references that workflow.
    """

    def __init__(self, bucket_max: int = 10, batch_size: int = 5, registered_workflows: Dict = None,
                 max_retries: int = 10, initial_delay: float = 1.0, max_delay: float = 60.0, **kwargs):
        """Configure the pool.

        Args:
            bucket_max: Queue length at which a workflow counts as full. This is a
                soft cap: a batch that starts below it may overshoot by up to
                ``batch_size - 1`` images.
            batch_size: Number of images generated for one workflow before the pool
                re-evaluates which workflow to fill next.
            registered_workflows: Mapping whose *values* are the workflows to keep
                stocked (presumably handle -> Workflow; TODO confirm against callers).
                Defaults to a new empty dict. An explicitly passed dict, even an
                empty one, is kept by reference.
            max_retries: Maximum attempts per image before giving up on it.
            initial_delay: First back-off delay in seconds after a failed attempt.
            max_delay: Upper bound in seconds on any single back-off delay.
            **kwargs: ``generator`` (required) and ``saver`` are read from here.

        Raises:
            ValueError: If no ``generator`` keyword argument is supplied.
        """
        super().__init__()

        self.generator = kwargs.get('generator')
        if self.generator is None:
            raise ValueError("The 'generator' argument is required for PreparedGenPool.")
        self.saver = kwargs.get('saver')

        self.bucket_max = bucket_max
        self.replenish_batch_size = batch_size
        # workflow -> list of seeds (oldest first) whose images are ready to serve.
        self.image_queues = WeakKeyDictionary()
        self.replenish_task = None
        # `or {}` would silently swap out a caller-supplied empty dict. Keep the
        # caller's object so it can be populated after construction.
        self.registered_workflows = registered_workflows if registered_workflows is not None else {}

        # Retry configuration
        self.max_retries = max_retries
        self.initial_delay = initial_delay
        self.max_delay = max_delay

    async def replenish(self):
        """Top up the workflow queues until every one is full or has failed.

        Repeatedly picks the registered workflow with the shortest queue and
        generates one batch for it. Batching amortizes the cost of switching
        models on the backend. The loop stops when the shortest queue has reached
        ``bucket_max``.

        A workflow whose image generation exhausts its retries is skipped for the
        rest of this run. Otherwise a persistently failing workflow would stay the
        shortest queue and be re-selected forever, starving all others. It is
        tried again on the next call to ``replenish``.
        """
        # Check for no assigned variable, not an empty list
        if self.registered_workflows is None:
            warnings.warn("Assign `registered_workflows` to the `PreparedGenPool` instance, or pass it to a ComfyKiosk instance.", UserWarning)
            return  # Nothing to iterate over; continuing would raise AttributeError.

        logging.info("Replenishing image queue...")

        failed_workflows = set()
        while True:  # Keep at it until there's nothing left.
            # Queue size for every workflow still eligible in this run.
            items: List[Tuple[int, Workflow]] = [
                (len(self.image_queues.get(workflow, [])), workflow)
                for workflow in self.registered_workflows.values()
                if workflow not in failed_workflows
            ]
            if not items:
                break  # No (remaining) workflows to fill

            # Compare by queue size only; ties go to the first registered workflow.
            min_queue_size, target_workflow = min(items, key=lambda item: item[0])
            if min_queue_size >= self.bucket_max:
                break  # Even the shortest queue is full

            if not await self._generate_batch(target_workflow):
                failed_workflows.add(target_workflow)

        self._log_buffer_status()

    async def _generate_batch(self, workflow: Workflow) -> bool:
        """Generate up to ``replenish_batch_size`` images for *workflow*.

        Each image is retried with exponential back-off. The delay carries over
        between images and resets with the next batch. If an image exhausts its
        retries, the rest of the batch is abandoned.

        Note that only one request is made to the backend at any time. Each
        request therefore waits for any queue at the ComfyUI server to clear
        before it is serviced.

        Returns:
            True if every image in the batch was generated and queued, False if
            an image exhausted its retries.
        """
        # At least one attempt per image. With zero attempts nothing is ever
        # awaited, and the caller's loop would spin without yielding.
        attempts = max(1, self.max_retries)
        delay = self.initial_delay

        for _ in range(self.replenish_batch_size):
            for attempt in range(1, attempts + 1):
                try:
                    seed = generate_seed()
                    image_data, _media_type = await self.generator.generate_image(seed=seed, workflow=workflow)
                    await self.saver.save_image(seed, image_data, workflow=workflow)
                except Exception as e:
                    if attempt >= attempts:
                        logging.error("Failed to generate image after %d attempts: %s", attempts, e)
                        return False
                    logging.warning("Error generating image (attempt %d/%d): %s", attempt, attempts, e)
                    await asyncio.sleep(min(delay, self.max_delay))
                    # Exponential backoff, capped so the value cannot grow without bound.
                    delay = min(delay * 2, self.max_delay)
                else:
                    self.image_queues.setdefault(workflow, []).append(seed)
                    break  # Success - continue to next image
        return True

    def _log_buffer_status(self):
        """Log how many ready images each tracked workflow currently holds."""
        if not self.image_queues:
            logging.info("Buffer Status: All queues are empty")
            return
        status = ["Buffer Status:"]
        status.extend(f" {workflow.handle}: {len(queue)} images" for workflow, queue in self.image_queues.items())
        logging.info("\n".join(status))

    def start_replenish(self):
        """Schedule ``replenish`` on the running loop unless a replenish task is still active.

        Must be called from within a running event loop (``asyncio.create_task``).
        """
        if self.replenish_task is None or self.replenish_task.done():  # Start only if not already running
            self.replenish_task = asyncio.create_task(self.replenish())
            self.replenish_task.add_done_callback(self._replenish_finished)

    async def get_image(self, seed=None, workflow: Workflow = None):
        """Serve the oldest pre-generated image for *workflow*.

        This also kicks off a background replenish if one is not already running.

        Args:
            seed: Ignored. The seed is always taken from the workflow's queue.
            workflow: Workflow whose queue to draw from (required).

        Returns:
            Whatever ``saver.get_image`` returns for the popped seed.

        Raises:
            ValueError: If *workflow* is None.
            asyncio.QueueEmpty: If no pre-generated image is ready for *workflow*.
        """
        if workflow is None:
            raise ValueError("The 'workflow' argument is required for PreparedGenPool.get_image().")

        self.start_replenish()

        image_queue = self.image_queues.get(workflow)
        if not image_queue:
            raise asyncio.QueueEmpty()
        seed = image_queue.pop(0)
        return await self.saver.get_image(seed, workflow=workflow)

    def _replenish_finished(self, task):
        """Done-callback for the replenish task: clear the handle and surface failures."""
        # Done-callbacks run later via the event loop. start_replenish() may have
        # stored a newer task in the meantime. Only clear the field if it still
        # refers to the task that finished. Otherwise a second concurrent
        # replenisher could start, and the running task would lose its only
        # reference.
        if self.replenish_task is task:
            self.replenish_task = None
        if task.cancelled():
            return
        exc = task.exception()  # Also marks the exception as retrieved.
        if exc is not None:
            logging.error("Image queue replenishment failed", exc_info=exc)

    async def on_app_startup(self):
        """Start the replenishment loop when the FastAPI app starts"""
        self.start_replenish()
|
|
|