Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 86 additions & 30 deletions src/wavespeed/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import os
import sys
import uuid
from typing import Optional

from ._config_module import install_config_module
Expand Down Expand Up @@ -83,22 +84,41 @@ def _detect_serverless_env() -> Optional[str]:
The serverless environment type ("runpod", "waverless") or None
if not running in a known serverless environment.
"""
# Check for RunPod environment
if os.environ.get("RUNPOD_POD_ID"):
return "runpod"

# Check for native Waverless environment
if os.environ.get("WAVERLESS_POD_ID"):
if os.environ.get("WAVERLESS_ENDPOINT_ID"):
return "waverless"

# Check for RunPod environment
if os.environ.get("RUNPOD_ENDPOINT_ID"):
return "runpod"

return None


def _resolve_url(url_template: Optional[str], pod_id: str) -> Optional[str]:
"""Replace pod ID placeholder in URL template.
def _generate_pod_id(endpoint_id: Optional[str], raw_pod_id: Optional[str]) -> str:
"""Generate or resolve pod_id.

Note: Only $RUNPOD_POD_ID is replaced here. The $ID placeholder is
replaced later at runtime with the actual job ID in http._handle_result.
Priority: raw_pod_id > DEVICE_ID > auto-generate

Args:
endpoint_id: The endpoint identifier.
raw_pod_id: The raw pod_id from environment variable.

Returns:
The resolved pod_id.
"""
if raw_pod_id:
return raw_pod_id
device_id = os.environ.get("DEVICE_ID")
if device_id:
return device_id
prefix = endpoint_id or "worker"
return f"{prefix}-{uuid.uuid4().hex}"


def _resolve_runpod_url(url_template: Optional[str], pod_id: str) -> Optional[str]:
"""Replace pod ID placeholder in RunPod URL template.

Args:
url_template: URL template with $RUNPOD_POD_ID placeholder.
Expand All @@ -112,26 +132,55 @@ def _resolve_url(url_template: Optional[str], pod_id: str) -> Optional[str]:
return url_template.replace("$RUNPOD_POD_ID", pod_id)


def _resolve_waverless_url(url_template: Optional[str], pod_id: str) -> Optional[str]:
"""Replace pod ID placeholder in Waverless URL template.

Args:
url_template: URL template with $WAVERLESS_POD_ID placeholder.
pod_id: The worker/pod ID to substitute.

Returns:
URL with $WAVERLESS_POD_ID placeholder replaced, or None if template is None.
"""
if not url_template:
return None
return url_template.replace("$WAVERLESS_POD_ID", pod_id)


def _load_runpod_serverless_config() -> None:
"""Load RunPod environment variables into serverless config."""
# Endpoint identification (load first for pod_id generation)
serverless.endpoint_id = os.environ.get("RUNPOD_ENDPOINT_ID")
serverless.project_id = os.environ.get("RUNPOD_PROJECT_ID")

# Worker identification
serverless.pod_id = os.environ.get("RUNPOD_POD_ID") or ""
raw_pod_id = os.environ.get("RUNPOD_POD_ID")
serverless.pod_id = _generate_pod_id(serverless.endpoint_id, raw_pod_id)
serverless.pod_hostname = os.environ.get("RUNPOD_POD_HOSTNAME", serverless.pod_id)

# API endpoint templates
serverless.webhook_get_job = os.environ.get("RUNPOD_WEBHOOK_GET_JOB")
serverless.webhook_post_output = os.environ.get("RUNPOD_WEBHOOK_POST_OUTPUT")
serverless.webhook_post_stream = os.environ.get("RUNPOD_WEBHOOK_POST_STREAM")
serverless.webhook_ping = os.environ.get("RUNPOD_WEBHOOK_PING")

# Resolved API endpoints (with pod_id substituted)
serverless.job_get_url = _resolve_url(serverless.webhook_get_job, serverless.pod_id)
serverless.job_done_url = _resolve_url(
# Resolved API endpoints (with $RUNPOD_POD_ID substituted)
job_get_url = _resolve_runpod_url(serverless.webhook_get_job, serverless.pod_id)
# job_get_url also needs $ID replaced with worker ID (like runpod-python)
if job_get_url:
job_get_url = job_get_url.replace("$ID", serverless.pod_id)
serverless.job_get_url = job_get_url

# job_done_url keeps $ID for runtime replacement with job_id
serverless.job_done_url = _resolve_runpod_url(
serverless.webhook_post_output, serverless.pod_id
)
serverless.job_stream_url = _resolve_url(
serverless.job_stream_url = _resolve_runpod_url(
serverless.webhook_post_stream, serverless.pod_id
)
serverless.ping_url = _resolve_url(serverless.webhook_ping, serverless.pod_id)
serverless.ping_url = _resolve_runpod_url(
serverless.webhook_ping, serverless.pod_id
)

# Authentication
serverless.api_key = os.environ.get("RUNPOD_AI_API_KEY")
Expand All @@ -142,11 +191,6 @@ def _load_runpod_serverless_config() -> None:
log_level = os.environ.get("RUNPOD_DEBUG_LEVEL")
serverless.log_level = log_level or "INFO"

# Endpoint identification
serverless.endpoint_id = os.environ.get("RUNPOD_ENDPOINT_ID")
serverless.project_id = os.environ.get("RUNPOD_PROJECT_ID")
serverless.pod_hostname = os.environ.get("RUNPOD_POD_HOSTNAME")

# Timing and concurrency
ping_interval = os.environ.get("RUNPOD_PING_INTERVAL")
if ping_interval:
Expand All @@ -163,36 +207,48 @@ def _load_runpod_serverless_config() -> None:

def _load_waverless_serverless_config() -> None:
"""Load Waverless environment variables into serverless config."""
# Endpoint identification (load first for pod_id generation)
serverless.endpoint_id = os.environ.get("WAVERLESS_ENDPOINT_ID")
# Endpoint identification (endpoint_id already set above)
serverless.project_id = os.environ.get("WAVERLESS_PROJECT_ID")

# Worker identification
serverless.pod_id = os.environ.get("WAVERLESS_POD_ID") or ""
raw_pod_id = os.environ.get("WAVERLESS_POD_ID")
serverless.pod_id = _generate_pod_id(serverless.endpoint_id, raw_pod_id)
serverless.pod_hostname = os.environ.get(
"WAVERLESS_POD_HOSTNAME", serverless.pod_id
)

# API endpoint templates
serverless.webhook_get_job = os.environ.get("WAVERLESS_WEBHOOK_GET_JOB")
serverless.webhook_post_output = os.environ.get("WAVERLESS_WEBHOOK_POST_OUTPUT")
serverless.webhook_post_stream = os.environ.get("WAVERLESS_WEBHOOK_POST_STREAM")
serverless.webhook_ping = os.environ.get("WAVERLESS_WEBHOOK_PING")

# Resolved API endpoints (with pod_id substituted)
serverless.job_get_url = _resolve_url(serverless.webhook_get_job, serverless.pod_id)
serverless.job_done_url = _resolve_url(
# Resolved API endpoints (with $WAVERLESS_POD_ID substituted)
job_get_url = _resolve_waverless_url(serverless.webhook_get_job, serverless.pod_id)
# job_get_url also needs $ID replaced with worker ID (like runpod)
if job_get_url:
job_get_url = job_get_url.replace("$ID", serverless.pod_id)
serverless.job_get_url = job_get_url

# job_done_url keeps $ID for runtime replacement with job_id
serverless.job_done_url = _resolve_waverless_url(
serverless.webhook_post_output, serverless.pod_id
)
serverless.job_stream_url = _resolve_url(
serverless.job_stream_url = _resolve_waverless_url(
serverless.webhook_post_stream, serverless.pod_id
)
serverless.ping_url = _resolve_url(serverless.webhook_ping, serverless.pod_id)
serverless.ping_url = _resolve_waverless_url(
serverless.webhook_ping, serverless.pod_id
)

# Authentication
serverless.api_key = os.environ.get("WAVERLESS_API_KEY")

# Logging
serverless.log_level = os.environ.get("WAVERLESS_LOG_LEVEL", "INFO")

# Endpoint identification
serverless.endpoint_id = os.environ.get("WAVERLESS_ENDPOINT_ID")
serverless.project_id = os.environ.get("WAVERLESS_PROJECT_ID")
serverless.pod_hostname = os.environ.get("WAVERLESS_POD_HOSTNAME")

# Timing and concurrency
ping_interval = os.environ.get("WAVERLESS_PING_INTERVAL")
if ping_interval:
Expand Down
9 changes: 9 additions & 0 deletions src/wavespeed/serverless/modules/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,15 @@ def __init__(self, config: Dict[str, Any]):
tags=["Status"],
)

# Health check endpoint
router.add_api_route(
"/health",
lambda: {"status": "ok"},
methods=["GET"],
summary="Health check",
tags=["Status"],
)

self.app.include_router(router)

def start(
Expand Down
53 changes: 36 additions & 17 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,51 +2,70 @@

import unittest

from wavespeed.config import _resolve_url, serverless
from wavespeed.config import _resolve_runpod_url, _resolve_waverless_url, serverless


class TestResolveUrl(unittest.TestCase):
"""Tests for the _resolve_url function."""
class TestResolveRunpodUrl(unittest.TestCase):
"""Tests for the _resolve_runpod_url function."""

def test_replaces_runpod_pod_id(self):
"""Test that $RUNPOD_POD_ID is replaced with pod_id."""
template = "https://api.runpod.ai/v2/endpoint/job-done/$RUNPOD_POD_ID"
result = _resolve_url(template, "my-pod-123")
result = _resolve_runpod_url(template, "my-pod-123")
self.assertEqual(
result, "https://api.runpod.ai/v2/endpoint/job-done/my-pod-123"
)

def test_preserves_id_placeholder(self):
"""Test that $ID is NOT replaced - it's for job ID at runtime."""
template = "https://api.runpod.ai/v2/endpoint/job-done/$RUNPOD_POD_ID/$ID"
result = _resolve_url(template, "my-pod-123")
# $ID should remain as placeholder for job ID replacement later
result = _resolve_runpod_url(template, "my-pod-123")
self.assertEqual(
result, "https://api.runpod.ai/v2/endpoint/job-done/my-pod-123/$ID"
)

def test_handles_none_template(self):
"""Test that None template returns None."""
result = _resolve_url(None, "my-pod-123")
result = _resolve_runpod_url(None, "my-pod-123")
self.assertIsNone(result)

def test_handles_empty_template(self):
"""Test that empty template returns None (falsy check)."""
result = _resolve_url("", "my-pod-123")
def test_no_placeholders(self):
"""Test URL without any placeholders."""
template = "https://api.example.com/endpoint"
result = _resolve_runpod_url(template, "my-pod-123")
self.assertEqual(result, "https://api.example.com/endpoint")


class TestResolveWaverlessUrl(unittest.TestCase):
"""Tests for the _resolve_waverless_url function."""

def test_replaces_waverless_pod_id_placeholder(self):
"""Test that $WAVERLESS_POD_ID is replaced with pod_id."""
template = "https://api.wavespeed.ai/v2/test/job-take/$WAVERLESS_POD_ID"
result = _resolve_waverless_url(template, "my-pod-123")
self.assertEqual(
result, "https://api.wavespeed.ai/v2/test/job-take/my-pod-123"
)

def test_preserves_id_placeholder(self):
"""Test that $ID is NOT replaced - it's for job/worker ID at runtime."""
template = "https://api.wavespeed.ai/v2/test/job-done/$WAVERLESS_POD_ID/$ID"
result = _resolve_waverless_url(template, "my-pod-123")
self.assertEqual(
result, "https://api.wavespeed.ai/v2/test/job-done/my-pod-123/$ID"
)

def test_handles_none_template(self):
"""Test that None template returns None."""
result = _resolve_waverless_url(None, "my-pod-123")
self.assertIsNone(result)

def test_no_placeholders(self):
"""Test URL without any placeholders."""
template = "https://api.example.com/endpoint"
result = _resolve_url(template, "my-pod-123")
result = _resolve_waverless_url(template, "my-pod-123")
self.assertEqual(result, "https://api.example.com/endpoint")

def test_multiple_pod_id_placeholders(self):
"""Test multiple $RUNPOD_POD_ID placeholders are all replaced."""
template = "https://api.runpod.ai/$RUNPOD_POD_ID/test/$RUNPOD_POD_ID"
result = _resolve_url(template, "pod-456")
self.assertEqual(result, "https://api.runpod.ai/pod-456/test/pod-456")


class TestServerlessConfig(unittest.TestCase):
"""Tests for serverless config loading."""
Expand Down
Loading