Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions src/agentlab/experiments/exp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,11 @@

def run_exp(exp_arg: ExpArgs, *dependencies, avg_step_timeout=60):
"""Run exp_args.run() with a timeout and handle dependencies."""
episode_timeout = _episode_timeout(exp_arg, avg_step_timeout=avg_step_timeout)
with timeout_manager(seconds=episode_timeout):
return exp_arg.run()
# episode_timeout = _episode_timeout(exp_arg, avg_step_timeout=avg_step_timeout)
# logger.warning(f"Running {exp_arg.exp_id} with timeout of {episode_timeout} seconds.")
# with timeout_manager(seconds=episode_timeout):
# this timeout method is not robust enough. using ray.cancel instead
return exp_arg.run()


def _episode_timeout(exp_arg: ExpArgs, avg_step_timeout=60):
Expand Down Expand Up @@ -62,13 +64,12 @@ def timeout_manager(seconds: int = None):

def alarm_handler(signum, frame):

logger.warning(
f"Operation timed out after {seconds}s, sending SIGINT and raising TimeoutError."
)
logger.warning(f"Operation timed out after {seconds}s, raising TimeoutError.")
# send sigint
os.kill(os.getpid(), signal.SIGINT)
# os.kill(os.getpid(), signal.SIGINT) # this doesn't seem to do much I don't know why

# Still raise TimeoutError for immediate handling
# This works, but it doesn't seem enough to kill the job
raise TimeoutError(f"Operation timed out after {seconds} seconds")

previous_handler = signal.signal(signal.SIGALRM, alarm_handler)
Expand Down
66 changes: 57 additions & 9 deletions src/agentlab/experiments/graph_execution_ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@

# # Disable Ray log deduplication
# os.environ["RAY_DEDUP_LOGS"] = "0"

import time
import ray
import bgym
from agentlab.experiments.exp_utils import run_exp
from agentlab.experiments.exp_utils import run_exp, _episode_timeout
from ray.util import state
import logging

logger = logging.getLogger(__name__)

run_exp = ray.remote(run_exp)

Expand All @@ -15,25 +18,70 @@ def execute_task_graph(exp_args_list: list[bgym.ExpArgs], avg_step_timeout=60):
"""Execute a task graph in parallel while respecting dependencies using Ray."""

exp_args_map = {exp_args.exp_id: exp_args for exp_args in exp_args_list}
tasks = {}
task_map = {}

def get_task(exp_arg: bgym.ExpArgs):
if exp_arg.exp_id not in tasks:
if exp_arg.exp_id not in task_map:
# Get all dependency tasks first
dependency_tasks = [get_task(exp_args_map[dep_key]) for dep_key in exp_arg.depends_on]

# Create new task that depends on the dependency results
tasks[exp_arg.exp_id] = run_exp.remote(
task_map[exp_arg.exp_id] = run_exp.remote(
exp_arg, *dependency_tasks, avg_step_timeout=avg_step_timeout
)
return tasks[exp_arg.exp_id]
return task_map[exp_arg.exp_id]

# Build task graph
for exp_arg in exp_args_list:
get_task(exp_arg)

# Execute all tasks and gather results
max_timeout = max([_episode_timeout(exp_args, avg_step_timeout) for exp_args in exp_args_list])
return poll_for_timeout(task_map, max_timeout, poll_interval=max_timeout * 0.1)


def poll_for_timeout(tasks: dict[str, ray.ObjectRef], timeout: float, poll_interval: float = 1.0):
"""Cancel tasks that exceeds the timeout

I tried various different methods for killing a job that hangs. so far it's
the only one that seems to work reliably (hopefully)
"""
task_list = list(tasks.values())
task_ids = list(tasks.keys())
results = ray.get(list(tasks.values()))

return {task_id: result for task_id, result in zip(task_ids, results)}
logger.warning(f"Any task exceeding {timeout} seconds will be cancelled.")

while True:
ready, not_ready = ray.wait(task_list, num_returns=len(task_list), timeout=poll_interval)
for task in not_ready:
elapsed_time = get_elapsed_time(task)
# print(f"Task {task.task_id().hex()} elapsed time: {elapsed_time}")
if elapsed_time is not None and elapsed_time > timeout:
msg = f"Task {task.task_id().hex()} hase been running for {elapsed_time}s, more than the timeout: {timeout}s."
if elapsed_time < timeout + 60:
logger.warning(msg + " Cancelling task.")
ray.cancel(task, force=False, recursive=False)
else:
logger.warning(msg + " Force killing.")
ray.cancel(task, force=True, recursive=False)
if len(ready) == len(task_list):
results = []
for task in ready:
try:
result = ray.get(task)
except Exception as e:
result = e
results.append(result)

return {task_id: result for task_id, result in zip(task_ids, results)}


def get_elapsed_time(task_ref: ray.ObjectRef):
task_id = task_ref.task_id().hex()
task_info = state.get_task(task_id, address="auto")
if task_info and task_info.start_time_ms is not None:
start_time_s = task_info.start_time_ms / 1000.0 # Convert ms to s
current_time_s = time.time()
elapsed_time = current_time_s - start_time_s
return elapsed_time
else:
return None # Task has not started yet
12 changes: 6 additions & 6 deletions tests/experiments/test_launch_exp.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import math
import tempfile
from pathlib import Path

Expand Down Expand Up @@ -63,19 +64,18 @@ def _test_launch_system(backend="ray", cause_timeout=False):
if row.stack_trace is not None:
print(row.stack_trace)
if cause_timeout:
assert row.err_msg is not None
assert "Timeout" in row.err_msg
assert row.cum_reward == 0
# assert row.err_msg is not None
assert math.isnan(row.cum_reward) or row.cum_reward == 0
else:
assert row.err_msg is None
assert row.cum_reward == 1.0

study_summary = inspect_results.summarize_study(results_df)
assert len(study_summary) == 1
assert study_summary.std_err.iloc[0] == 0
assert study_summary.n_completed.iloc[0] == "3/3"

if not cause_timeout:
assert study_summary.n_completed.iloc[0] == "3/3"
assert study_summary.avg_reward.iloc[0] == 1.0


Expand All @@ -91,7 +91,7 @@ def test_launch_system_ray():
_test_launch_system(backend="ray")


def _test_timeout_ray():
def test_timeout_ray():
_test_launch_system(backend="ray", cause_timeout=True)


Expand Down Expand Up @@ -120,7 +120,7 @@ def test_4o_mini_on_miniwob_tiny_test():


if __name__ == "__main__":
_test_timeout_ray()
test_timeout_ray()
# test_4o_mini_on_miniwob_tiny_test()
# test_launch_system_ray()
# test_launch_system_sequntial()