# comfykiosk/comfykiosk/comfy_networking.py
# 2024-12-25 09:30:53 -05:00
#
# 166 lines
# 6.6 KiB
# Python

import http.cookiejar
import json
import logging
import urllib.error
import urllib.parse
import urllib.request
import uuid
from typing import List

import httpx
import websockets
class DowngradeInfoFilter(logging.Filter):
    """Logging filter that relabels INFO records as DEBUG and never drops a record.

    Only ``levelno`` and ``levelname`` of INFO-level records are rewritten;
    records at any other level pass through untouched. The record object itself
    is modified, so handlers that see it afterwards get the DEBUG level.
    """

    def filter(self, record):
        is_info = record.levelno == logging.INFO
        if is_info:
            # Rewrite both fields so the formatted "%(levelname)s" agrees with
            # the numeric level handlers compare against.
            record.levelname = 'DEBUG'
            record.levelno = logging.DEBUG
        # This filter only relabels; it never suppresses a record.
        return True
# Attach the INFO->DEBUG relabelling filter to the "httpx" logger. Logger-level
# filters apply to records logged directly through that logger, which then
# reach handlers as DEBUG instead of INFO.
httpx_logger = logging.getLogger(name="httpx")
httpx_logger.addFilter(filter=DowngradeInfoFilter())
# Function to handle redirects and store cookies
async def open_websocket_connection(comfyui_url):
client_id = str(uuid.uuid4())
cookie_jar = http.cookiejar.CookieJar() # Initialize a cookie jar
opener = urllib.request.build_opener(urllib.request.HTTPCookieProcessor(cookie_jar))
urllib.request.install_opener(opener) # Install the opener to handle cookies globally
try:
ws = await websockets.connect(f"ws://{comfyui_url}/ws?clientId={client_id}")
return ws, client_id
except websockets.InvalidStatusCode as e:
if e.status_code in (301, 302, 307, 308): # Check for redirect status codes
print(f"Redirect detected: {e.status_code}")
location = e.headers.get("Location")
if location:
print(f"Following redirect to: {location}")
# Make a request to the redirect URL to store cookies
try:
urllib.request.urlopen(location)
except Exception as redirect_request_error:
print(f"Error following redirect: {redirect_request_error}")
raise
print(f"Retrying websocket connection to original URL: {comfyui_url}")
return await open_websocket_connection(comfyui_url) # Retry with original URL and stored cookies
else:
print("Redirect location not found.")
raise
else:
print(f"Failed to open websocket connection: {e}")
raise
except Exception as e:
print(f"Failed to open websocket connection: {e}")
raise
def queue_prompt(comfyui_url, prompt, client_id):
    """Submit *prompt* to ComfyUI's ``POST /prompt`` endpoint and return the decoded JSON reply.

    The request goes through ``urllib.request.urlopen``, i.e. whichever global
    opener is installed at call time (``open_websocket_connection`` installs one
    backed by a cookie jar). This is a blocking call, not a coroutine.

    Args:
        comfyui_url: server address without a scheme (interpolated after ``http://``).
        prompt: workflow sent as the ``"prompt"`` field; must be JSON-serializable.
        client_id: value sent as ``"client_id"``. Presumably this must match the
            websocket's ``clientId`` so progress messages reach that socket (confirm
            against the ComfyUI API).

    Returns:
        The parsed JSON response body (``execute_comfyui_prompt`` reads its
        ``"prompt_id"`` key).

    Raises:
        urllib.error.HTTPError: non-success status; code, reason and body are printed first.
        Exception: any other failure (connection error, bad JSON) is printed and re-raised.
    """
    payload = json.dumps({"prompt": prompt, "client_id": client_id}).encode('utf-8')
    req = urllib.request.Request(
        f"http://{comfyui_url}/prompt",
        data=payload,
        headers={'Content-Type': 'application/json'},
    )
    try:
        # The context manager closes the HTTP response instead of leaving it to the GC.
        with urllib.request.urlopen(req) as response:
            return json.loads(response.read())
    except urllib.error.HTTPError as e:
        # errors='replace': a non-UTF-8 error body must not turn the HTTPError
        # into a UnicodeDecodeError raised from inside this handler.
        error_body = e.read().decode('utf-8', errors='replace')
        print(f"Failed to queue prompt. HTTPError: {e.code} - {e.reason}. Response body: {error_body}")
        raise
    except Exception as e:
        print(f"Failed to queue prompt. Unexpected error: {e}")
        raise
async def track_progress(prompt, ws, prompt_id):
    """Wait on the ComfyUI websocket until *prompt_id* finishes executing.

    Reads frames from ``ws`` until a text frame arrives that decodes to a message
    of type ``"executing"`` whose ``data["node"]`` is ``None`` and whose
    ``data["prompt_id"]`` equals *prompt_id*. That message is treated as
    "execution is done". Binary frames are skipped without decoding.

    If ``ws.recv()`` raises a ``websockets`` exception (e.g. the connection
    closed), the error is printed and the function returns normally. Callers
    cannot tell that case apart from completion.

    Args:
        prompt: the submitted workflow; unused, kept for API compatibility.
        ws: open websocket exposing an awaitable ``recv()``.
        prompt_id: id of the queued prompt to wait for.
    """
    while True:
        try:
            out = await ws.recv()
        except websockets.exceptions.WebSocketException as e:  # ConnectionClosed* are subclasses
            print(f"Websocket connection closed: {e}")
            return
        if not isinstance(out, str):
            continue  # Binary frame: not a JSON status message.
        message = json.loads(out)
        if message['type'] != 'executing':
            continue
        data = message['data']
        if data['node'] is None and data['prompt_id'] == prompt_id:
            return  # Execution is done
async def get_history(prompt_id, comfyui_url):
    """Fetch ComfyUI's history entry for *prompt_id* via ``GET /history/{prompt_id}``.

    Args:
        prompt_id: id of the queued prompt.
        comfyui_url: server address without a scheme (interpolated after ``http://``).

    Returns:
        The value stored under *prompt_id* in the JSON response (``get_images``
        reads its ``"outputs"`` key).

    Raises:
        httpx.HTTPError: the request failed or returned an error status (printed first).
        KeyError: the response has no entry for *prompt_id*.
        RuntimeError: the entry's ``status["status_str"]`` is ``"error"``. The
            message includes the ``exception_message`` of every
            ``execution_error`` message found.
    """
    async with httpx.AsyncClient() as client:
        try:
            response = await client.get(f"http://{comfyui_url}/history/{prompt_id}")
            response.raise_for_status()
        except httpx.HTTPError as e:
            print(f"Failed to get history. HTTPError: {e}")
            raise
        # Decode the JSON body once and reuse the entry.
        history_entry = response.json()[prompt_id]
    # Read "status" defensively: an entry without it is returned as-is.
    comfyui_status: dict = history_entry.get("status") or {}
    if comfyui_status.get("status_str") == "error":
        details = []
        # Messages are indexed as [kind, payload]; only "execution_error" ones are inspected.
        for message in comfyui_status.get("messages") or []:
            if message[0] == "execution_error":
                exception_message = message[1].get("exception_message")
                print(f"ComfyUI threw an exception: {exception_message}")
                details.append(str(exception_message))
        # RuntimeError is the type this path already surfaced (a bare `raise` with
        # no active exception raises RuntimeError), so existing handlers keep working.
        raise RuntimeError(
            f"ComfyUI reported an error for prompt {prompt_id}: "
            + ("; ".join(details) or "no execution_error details")
        )
    return history_entry
async def get_image(filename, subfolder, folder_type, comfyui_url) -> bytes:
    """Download a single file from ComfyUI's ``/view`` endpoint.

    The file is selected by the ``filename``, ``subfolder`` and ``type`` query
    parameters, encoded with ``urllib.parse.urlencode``.

    Args:
        filename: value for the ``filename`` query parameter.
        subfolder: value for the ``subfolder`` query parameter.
        folder_type: value for the ``type`` query parameter.
        comfyui_url: server address without a scheme (interpolated after ``http://``).

    Returns:
        The raw response body.

    Raises:
        httpx.HTTPError: the request failed or returned an error status; the
            error is printed before being re-raised.
    """
    query_string = urllib.parse.urlencode(
        {"filename": filename, "subfolder": subfolder, "type": folder_type}
    )
    view_url = f"http://{comfyui_url}/view?{query_string}"
    async with httpx.AsyncClient() as http_client:
        try:
            resp = await http_client.get(view_url)
            resp.raise_for_status()
        except httpx.HTTPError as err:
            print(f"Failed to get image. HTTPError: {err}")
            raise
        else:
            return resp.content
async def get_images(prompt_id, server_address, allow_preview=False) -> List[dict]:
    """Collect the images listed in a prompt's history and download their bytes.

    Every entry of ``history["outputs"][<node>]["images"]`` yields one dict with
    ``"file_name"`` and ``"type"``. ``"image_data"`` (the bytes returned by
    ``get_image``) is added for images of type ``"output"``, and for type
    ``"temp"`` only when *allow_preview* is true. Other images are still listed,
    but without ``"image_data"``.

    Args:
        prompt_id: id of the finished prompt.
        server_address: ComfyUI server address, passed to ``get_history`` and ``get_image``.
        allow_preview: also download ``"temp"`` images.

    Returns:
        One dict per image, in the order they appear in the history.
    """
    output_images = []
    history = await get_history(prompt_id, server_address)
    for node_output in history['outputs'].values():
        for image in node_output.get('images', []):
            # A fresh dict per image: sharing one dict per node made every image
            # of a multi-image node alias (and overwrite) the same entry.
            output_data = {}
            image_type = image['type']
            if image_type == 'output' or (allow_preview and image_type == 'temp'):
                output_data['image_data'] = await get_image(
                    image['filename'], image['subfolder'], image_type, server_address
                )
            output_data['file_name'] = image['filename']
            output_data['type'] = image_type
            output_images.append(output_data)
    return output_images
async def execute_comfyui_prompt(comfyui_url, prompt):
    """Run *prompt* on a ComfyUI server and return its images.

    Steps:
        1. Open the websocket.
        2. Queue the prompt under that socket's client id.
        3. Wait for ``track_progress`` to return.
        4. Close the socket.
        5. Fetch the results with ``get_images``, using its default
           ``allow_preview=False``.

    Args:
        comfyui_url: server address without a scheme.
        prompt: workflow passed to ``queue_prompt``.

    Returns:
        The list built by ``get_images``.
    """
    ws, client_id = await open_websocket_connection(comfyui_url)
    try:
        # Queued only after the socket is open, presumably so no progress
        # messages for this client id are missed.
        queued_prompt = queue_prompt(comfyui_url, prompt, client_id)
        prompt_id = queued_prompt['prompt_id']
        await track_progress(prompt, ws, prompt_id)
    finally:
        # Close the socket even when queueing or tracking raises, so it is not leaked.
        await ws.close()
    return await get_images(prompt_id, comfyui_url)