166 lines
6.6 KiB
Python
166 lines
6.6 KiB
Python
import http.cookiejar
import json
import logging
import urllib.error
import urllib.parse
import urllib.request
import uuid
from typing import List

import httpx
import websockets


class DowngradeInfoFilter(logging.Filter):
    """Logging filter that relabels INFO records as DEBUG and never drops a record."""

    def filter(self, record):
        """Rewrite an INFO-level *record* in place to DEBUG.

        Records at any other level are left untouched.

        Returns:
            Always ``True``, so every record continues on to the handlers.
        """
        if record.levelno != logging.INFO:
            return True
        # Change both the numeric level (used by handler level checks) and the
        # display name (used by formatters), or output would still say "INFO".
        record.levelno, record.levelname = logging.DEBUG, 'DEBUG'
        return True


# Attach DowngradeInfoFilter to the "httpx" logger so INFO records logged on it
# are relabelled as DEBUG. Logger-level filters only see records logged on this
# logger itself, not records propagated up from child loggers.
httpx_logger = logging.getLogger(name="httpx")
httpx_logger.addFilter(filter=DowngradeInfoFilter())


# Function to handle redirects and store cookies
async def open_websocket_connection(comfyui_url, *, max_redirects=5):
    """Open a websocket to ComfyUI's ``/ws`` endpoint under a fresh client id.

    A ``CookieJar``-backed urllib opener is installed process-wide, so later
    ``urllib.request.urlopen`` calls made through the global opener send any
    cookies collected here. If the handshake fails with a redirect status
    (301/302/307/308) that carries a ``Location`` header, that URL is fetched
    over HTTP so its cookies land in the jar. The handshake is then retried
    against the original URL, at most ``max_redirects`` times.

    Args:
        comfyui_url: ``host[:port]`` of the ComfyUI server, without a scheme.
        max_redirects: Maximum number of redirect/retry rounds before the
            redirect error is re-raised.

    Returns:
        ``(ws, client_id)``: the open websocket connection and the client id
        embedded in its URL.

    Raises:
        websockets.InvalidStatusCode: the handshake failed with a non-redirect
            status, with a redirect lacking ``Location``, or after more than
            ``max_redirects`` redirects.
        Exception: any other connection error, or an error while fetching the
            redirect target, is printed and re-raised unchanged.
    """
    client_id = str(uuid.uuid4())
    # Build the jar once per call. Re-creating it on every retry (as a recursive
    # retry did) throws away the cookies that were just collected.
    cookie_jar = http.cookiejar.CookieJar()
    opener = urllib.request.build_opener(urllib.request.HTTPCookieProcessor(cookie_jar))
    urllib.request.install_opener(opener)  # process-global side effect
    ws_url = f"ws://{comfyui_url}/ws?clientId={client_id}"

    redirects_followed = 0
    while True:
        try:
            ws = await websockets.connect(ws_url)
            return ws, client_id
        except websockets.InvalidStatusCode as e:
            if e.status_code not in (301, 302, 307, 308):  # not a redirect
                print(f"Failed to open websocket connection: {e}")
                raise

            print(f"Redirect detected: {e.status_code}")
            location = e.headers.get("Location")
            if not location:
                print("Redirect location not found.")
                raise
            if redirects_followed >= max_redirects:
                # Bounded retries: an unbounded retry loop never terminates when
                # the server keeps redirecting.
                print(f"Giving up after {redirects_followed} redirects.")
                raise

            # Resolve a relative Location against the HTTP form of the server URL;
            # an absolute Location is returned unchanged by urljoin.
            location = urllib.parse.urljoin(f"http://{comfyui_url}/", location)
            print(f"Following redirect to: {location}")
            try:
                # Fetched only so HTTPCookieProcessor stores the cookies; the
                # with-block closes the response instead of leaking it.
                with urllib.request.urlopen(location):
                    pass
            except Exception as redirect_request_error:
                print(f"Error following redirect: {redirect_request_error}")
                raise

            redirects_followed += 1
            # NOTE(review): cookie_jar is never passed to websockets.connect, so
            # the cookies reach later urllib requests but are not sent on this
            # websocket retry. TODO confirm whether the retry needs a Cookie header.
            print(f"Retrying websocket connection to original URL: {comfyui_url}")
        except Exception as e:
            print(f"Failed to open websocket connection: {e}")
            raise


def queue_prompt(comfyui_url, prompt, client_id):
    """Submit *prompt* to ComfyUI's ``POST /prompt`` endpoint.

    Args:
        comfyui_url: ``host[:port]`` of the ComfyUI server, without a scheme.
        prompt: JSON-serialisable workflow to queue.
        client_id: Websocket client id to associate with the queued prompt.

    Returns:
        The decoded JSON response body. ``execute_comfyui_prompt`` reads its
        ``"prompt_id"`` key.

    Raises:
        urllib.error.HTTPError: the server answered with an error status. The
            status, reason and response body are printed first.
        Exception: any other failure (connection error, invalid JSON) is
            printed and re-raised.
    """
    payload = {"prompt": prompt, "client_id": client_id}
    req = urllib.request.Request(
        f"http://{comfyui_url}/prompt",
        data=json.dumps(payload).encode('utf-8'),
        headers={'Content-Type': 'application/json'},
    )
    try:
        # The with-block closes the HTTP response once its body has been read.
        with urllib.request.urlopen(req) as response:
            return json.loads(response.read())
    except urllib.error.HTTPError as e:
        # errors='replace': an undecodable error body must not replace the
        # HTTPError callers expect with a UnicodeDecodeError.
        error_body = e.read().decode('utf-8', errors='replace')
        print(f"Failed to queue prompt. HTTPError: {e.code} - {e.reason}. Response body: {error_body}")
        raise
    except Exception as e:
        print(f"Failed to queue prompt. Unexpected error: {e}")
        raise


async def track_progress(prompt, ws, prompt_id):
    """Wait until ComfyUI reports that *prompt_id* has finished executing.

    Reads text frames from *ws* as JSON messages until an ``executing``
    message arrives whose ``node`` is ``None`` and whose ``prompt_id``
    matches. Binary frames and messages of any other type are skipped.

    Args:
        prompt: The queued workflow. Not read here; kept so existing callers
            keep working.
        ws: An open websocket connection with an awaitable ``recv()``.
        prompt_id: Id returned by ``queue_prompt`` for this workflow.

    Returns:
        ``None``. If the websocket closes or errors first, the error is
        printed and the function returns without raising.

    Raises:
        json.JSONDecodeError: a text frame is not valid JSON.
    """
    while True:
        try:
            raw = await ws.recv()
        except websockets.WebSocketException as e:
            # ConnectionClosedOK and ConnectionClosedError both subclass
            # WebSocketException, so this single clause covers them.
            print(f"Websocket connection closed: {e}")
            return

        if not isinstance(raw, str):
            continue  # binary frames are not JSON status messages

        message = json.loads(raw)
        if message.get('type') != 'executing':
            continue

        data = message.get('data') or {}
        # Only an explicit ``"node": None`` marks completion; a missing key does not.
        finished = 'node' in data and data['node'] is None
        if finished and data.get('prompt_id') == prompt_id:
            return  # Execution is done


async def get_history(prompt_id, comfyui_url):
    """Fetch the history entry for *prompt_id* from ComfyUI's ``/history`` endpoint.

    Args:
        prompt_id: Id of a prompt previously queued with ``queue_prompt``.
        comfyui_url: ``host[:port]`` of the ComfyUI server, without a scheme.

    Returns:
        ``response.json()[prompt_id]``, the history entry whose ``"outputs"``
        mapping ``get_images`` iterates.

    Raises:
        httpx.HTTPError: the request failed or returned an error status. The
            error is printed before being re-raised.
        RuntimeError: the entry's ``status.status_str`` is ``"error"`` and its
            messages contain an ``execution_error``.
        KeyError: *prompt_id* is not present in the response.
    """
    async with httpx.AsyncClient() as client:
        try:
            response = await client.get(f"http://{comfyui_url}/history/{prompt_id}")
            response.raise_for_status()
        except httpx.HTTPError as e:
            print(f"Failed to get history. HTTPError: {e}")
            raise

    # Parse the body once; it was fully read by client.get above.
    entry = response.json()[prompt_id]
    comfyui_status: dict = entry.get("status") or {}
    if comfyui_status.get("status_str") == "error":
        for message in comfyui_status.get("messages", []):
            if message[0] == "execution_error":
                exception_message = message[1].get("exception_message")
                print(f"ComfyUI threw an exception: {exception_message}")
                # A bare `raise` here has no active exception and only produces
                # "RuntimeError: No active exception to reraise". Keep RuntimeError
                # for existing callers, but carry the real message.
                raise RuntimeError(f"ComfyUI threw an exception: {exception_message}")

    return entry


async def get_image(filename, subfolder, folder_type, comfyui_url) -> bytes:
    """Download one image's raw bytes from ComfyUI's ``/view`` endpoint.

    Args:
        filename: File name as listed in the prompt's history outputs.
        subfolder: Sub-folder value from the same history entry.
        folder_type: Sent as the ``type`` query parameter (``get_images``
            passes the image's ``type``, e.g. ``"output"`` or ``"temp"``).
        comfyui_url: ``host[:port]`` of the ComfyUI server, without a scheme.

    Returns:
        The response body as bytes.

    Raises:
        httpx.HTTPError: the request failed or returned an error status. The
            error is printed before being re-raised.
    """
    query = urllib.parse.urlencode(
        {"filename": filename, "subfolder": subfolder, "type": folder_type}
    )
    view_url = f"http://{comfyui_url}/view?{query}"

    async with httpx.AsyncClient() as client:
        try:
            resp = await client.get(view_url)
            resp.raise_for_status()
        except httpx.HTTPError as err:
            print(f"Failed to get image. HTTPError: {err}")
            raise
        return resp.content


async def get_images(prompt_id, server_address, allow_preview=False) -> List[dict]:
    """Collect the images listed in the history outputs of *prompt_id*.

    One dict is produced per entry in each output node's ``"images"`` list,
    in history order. Each dict has ``file_name`` and ``type`` keys, plus
    ``image_data`` (bytes) when the image was downloaded. Images of type
    ``"output"`` are always downloaded. ``"temp"`` images are downloaded only
    when *allow_preview* is true. All other entries are still listed, but
    without ``image_data``.

    Args:
        prompt_id: Id of the finished prompt.
        server_address: ``host[:port]`` of the ComfyUI server, without a scheme.
        allow_preview: Also download ``"temp"`` images.

    Returns:
        The list of per-image dicts described above.

    Raises:
        Whatever ``get_history`` or ``get_image`` raise.
    """
    output_images = []

    history = await get_history(prompt_id, server_address)
    for node_output in history['outputs'].values():
        for image in node_output.get('images', []):
            image_type = image['type']
            # A fresh dict per image. Reusing one dict per node made every entry
            # appended for that node alias the same object, so all of them showed
            # the last image, possibly with an earlier image's bytes.
            output_data = {}
            if image_type == 'output' or (allow_preview and image_type == 'temp'):
                output_data['image_data'] = await get_image(
                    image['filename'], image['subfolder'], image_type, server_address
                )
            output_data['file_name'] = image['filename']
            output_data['type'] = image_type
            output_images.append(output_data)

    return output_images


async def execute_comfyui_prompt(comfyui_url, prompt):
    """Queue *prompt* on ComfyUI, wait for it to finish, and return its images.

    Opens the websocket first so its client id can be attached to the queued
    prompt. It waits on that socket for completion, closes it, then downloads
    the results.

    Args:
        comfyui_url: ``host[:port]`` of the ComfyUI server, without a scheme.
        prompt: JSON-serialisable workflow to run.

    Returns:
        The list returned by ``get_images`` for the queued prompt, with
        ``allow_preview`` left at its default of ``False``.
    """
    ws, client_id = await open_websocket_connection(comfyui_url)
    try:
        # NOTE(review): queue_prompt performs blocking urllib I/O inside this
        # coroutine and stalls the event loop for the duration of the request.
        queued_prompt = queue_prompt(comfyui_url, prompt, client_id)
        prompt_id = queued_prompt['prompt_id']
        await track_progress(prompt, ws, prompt_id)
    finally:
        # Close the websocket even when queueing or progress tracking fails.
        await ws.close()

    images = await get_images(prompt_id, comfyui_url)
    return images