Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions src/agentlab/experiments/study.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from concurrent.futures import ProcessPoolExecutor
import gzip
import logging
import os
Expand Down Expand Up @@ -498,6 +499,8 @@ def _init_worker(server_queue: Queue):
A queue of object implementing BaseServer to initialize (or anything with a init
method).
"""
print("initializing server instance with on process", os.getpid())
print(f"using queue {server_queue}")
server_instance = server_queue.get() # type: "WebArenaInstanceVars"
logger.warning(f"Initializing server instance {server_instance} from process {os.getpid()}")
server_instance.init()
Expand All @@ -510,6 +513,42 @@ def _run_study(study: Study, n_jobs, parallel_backend, strict_reproducibility, n

@dataclass
class ParallelStudies(SequentialStudies):
parallel_servers: list[BaseServer] | int = None

def _run(
self,
n_jobs=1,
parallel_backend="ray",
strict_reproducibility=False,
n_relaunch=3,
):
parallel_servers = self.parallel_servers
if isinstance(parallel_servers, int):
parallel_servers = [BaseServer() for _ in range(parallel_servers)]

server_queue = Manager().Queue()
for server in parallel_servers:
server_queue.put(server)

with ProcessPoolExecutor(
max_workers=len(parallel_servers), initializer=_init_worker, initargs=(server_queue,)
) as executor:
# Create list of arguments for each study
study_args = [
(study, n_jobs, parallel_backend, strict_reproducibility, n_relaunch)
for study in self.studies
]

# Submit all tasks and wait for completion
futures = [executor.submit(_run_study, *args) for args in study_args]

# Wait for all futures to complete and raise any exceptions
for future in futures:
future.result()


@dataclass
class ParallelStudies_alt(SequentialStudies):

parallel_servers: list[BaseServer] | int = None

Expand Down
11 changes: 10 additions & 1 deletion tests/experiments/test_study.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
from agentlab.llm.chat_api import CheatMiniWoBLLMArgs
from agentlab.experiments.study import ParallelStudies, make_study, Study
from agentlab.experiments.multi_server import WebArenaInstanceVars
import logging


logging.getLogger().setLevel(logging.INFO)


def _make_agent_args_list():
Expand All @@ -28,13 +32,18 @@ def manual_test_launch_parallel_study_webarena():
server_instance_2 = server_instance_1.clone()
server_instance_2.base_url = "http://webarena-slow.eastus.cloudapp.azure.com"
parallel_servers = [server_instance_1, server_instance_2]
# parallel_servers = [server_instance_2]

for server in parallel_servers:
print(server)

study = make_study(
agent_args_list, benchmark="webarena_tiny", parallel_servers=parallel_servers
agent_args_list,
benchmark="webarena_tiny",
parallel_servers=parallel_servers,
ignore_dependencies=True,
)
study.override_max_steps(2)
assert isinstance(study, ParallelStudies)

study.run(n_jobs=4, parallel_backend="ray", n_relaunch=1)
Expand Down
Loading