Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions src/agentlab/experiments/study.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
from dataclasses import dataclass
from datetime import datetime
import gzip
import logging
from pathlib import Path
import pickle
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path

from bgym import ExpArgs, EnvArgs, Benchmark
import bgym
from bgym import Benchmark, EnvArgs, ExpArgs

from agentlab.agents.agent_args import AgentArgs
from agentlab.analyze import inspect_results
from agentlab.experiments import args
from agentlab.experiments.launch_exp import run_experiments, find_incomplete
from agentlab.experiments.exp_utils import RESULTS_DIR
from agentlab.experiments import reproducibility_util as repro
from agentlab.experiments.exp_utils import RESULTS_DIR
from agentlab.experiments.launch_exp import find_incomplete, run_experiments


@dataclass
Expand All @@ -25,7 +25,7 @@ class Study:
Attributes:
benchmark: Benchmark | str
The benchmark to evaluate the agents on. If a string is provided, it will be
converted to the corresponding benchmark using bgym.BENCHMARKS.
converted to the corresponding benchmark using bgym.DEFAULT_BENCHMARKS.

agent_args: list[AgentArgs]
The list of agents to evaluate.
Expand Down Expand Up @@ -54,7 +54,7 @@ class Study:
def __post_init__(self):
self.uuid = str(datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
if isinstance(self.benchmark, str):
self.benchmark = bgym.BENCHMARKS[self.benchmark]()
self.benchmark = bgym.DEFAULT_BENCHMARKS[self.benchmark]()
if isinstance(self.dir, str):
self.dir = Path(self.dir)
self.make_exp_args_list()
Expand Down
14 changes: 8 additions & 6 deletions tests/experiments/test_reproducibility_util.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from pathlib import Path
import json
import tempfile
import time
from pathlib import Path

import bgym
import pytest

from agentlab.agents.generic_agent import AGENT_4o_MINI
from agentlab.analyze import inspect_results
from agentlab.experiments import reproducibility_util
from agentlab.agents.generic_agent import AGENT_4o_MINI
import pytest
import json
import bgym


@pytest.mark.parametrize(
Expand All @@ -15,7 +17,7 @@
)
def test_get_reproducibility_info(benchmark_name):

benchmark = bgym.BENCHMARKS[benchmark_name]()
benchmark = bgym.DEFAULT_BENCHMARKS[benchmark_name]()

info = reproducibility_util.get_reproducibility_info(
"test_agent", benchmark, "test_id", ignore_changes=True
Expand Down