src/agentlab/agents/generic_agent/generic_agent.py (2 additions, 81 deletions)
@@ -10,19 +10,19 @@
 
 from copy import deepcopy
 from dataclasses import asdict, dataclass
+from functools import partial
 from warnings import warn
 
 import bgym
 from browsergym.experiments.agent import Agent, AgentInfo
 
 from agentlab.agents import dynamic_prompting as dp
 from agentlab.agents.agent_args import AgentArgs
-from agentlab.llm.chat_api import BaseModelArgs, make_system_message, make_user_message
+from agentlab.llm.chat_api import BaseModelArgs
 from agentlab.llm.llm_utils import Discussion, ParseError, SystemMessage, retry
 from agentlab.llm.tracking import cost_tracker_decorator
 
 from .generic_agent_prompt import GenericPromptFlags, MainPrompt
-from functools import partial
 
 
 @dataclass
@@ -200,82 +200,3 @@ def _get_maxes(self):
             else 20  # dangerous to change the default value here?
         )
         return max_prompt_tokens, max_trunc_itr
-
-
-from functools import partial
-
-
-def get_action_post_hoc(agent: GenericAgent, obs: dict, ans_dict: dict):
-    """
-    Get the action post-hoc for the agent.
-
-    This function is used to get the action after the agent has already been run.
-    Its goal is to recreate the prompt and the output of the agent a posteriori.
-    The purpose is to build datasets for training the agents.
-
-    Args:
-        agent (GenericAgent): The agent for which the action is being determined.
-        obs (dict): The observation dictionary to append to the agent's history.
-        ans_dict (dict): The answer dictionary containing the plan, step, memory, think, and action.
-
-    Returns:
-        Tuple[str, str]: The complete prompt used for the agent and the reconstructed output based on the answer dictionary.
-    """
-    system_prompt = dp.SystemPrompt().prompt
-
-    agent.obs_history.append(obs)
-
-    main_prompt = MainPrompt(
-        action_set=agent.action_set,
-        obs_history=agent.obs_history,
-        actions=agent.actions,
-        memories=agent.memories,
-        thoughts=agent.thoughts,
-        previous_plan=agent.plan,
-        step=agent.plan_step,
-        flags=agent.flags,
-    )
-
-    max_prompt_tokens, max_trunc_itr = agent._get_maxes()
-
-    fit_function = partial(
-        dp.fit_tokens,
-        max_prompt_tokens=max_prompt_tokens,
-        model_name=agent.chat_model_args.model_name,
-        max_iterations=max_trunc_itr,
-    )
-
-    instruction_prompt = fit_function(shrinkable=main_prompt)
-
-    if isinstance(instruction_prompt, list):
-        # NOTE: this is when we have images
-        instruction_prompt = instruction_prompt[0]["text"]
-
-    # TODO: make sure the bid is in the prompt
-
-    output = ""
-
-    # TODO: validate this
-    agent.plan = ans_dict.get("plan", agent.plan)
-    if agent.plan != "No plan yet":
-        output += f"\n<plan>\n{agent.plan}\n</plan>\n"
-
-    # TODO: is plan_step something that the agent outputs?
-    agent.plan_step = ans_dict.get("step", agent.plan_step)
-
-    memory = ans_dict.get("memory", None)
-    agent.memories.append(memory)
-    if memory is not None:
-        output += f"\n<memory>\n{memory}\n</memory>\n"
-
-    thought = ans_dict.get("think", None)
-    agent.thoughts.append(thought)
-    if thought is not None:
-        output += f"\n<think>\n{thought}\n</think>\n"
-
-    action = ans_dict["action"]
-    agent.actions.append(action)
-    if action is not None:
-        output += f"\n<action>\n{action}\n</action>"
-
-    return system_prompt, instruction_prompt, output
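
For context on the removed helper: get_action_post_hoc takes a recorded observation and the parsed answer dict from an earlier run and rebuilds the (system prompt, instruction prompt, output) triple so that prompt/completion pairs can be assembled for training. Below is a minimal sketch of how it might be driven over a logged episode; the build_training_examples helper and the (obs, ans_dict) episode format are assumptions for illustration, not AgentLab API.

def build_training_examples(agent, episode):
    """Rebuild (prompt, completion) pairs from a logged episode.

    `episode` is assumed to be an iterable of (obs, ans_dict) pairs captured
    during a previous run of the agent; this structure is hypothetical.
    """
    examples = []
    for obs, ans_dict in episode:
        # get_action_post_hoc mutates the agent's state (obs_history, actions,
        # memories, thoughts, plan), so steps must be replayed in order on a
        # freshly constructed agent.
        system_prompt, instruction_prompt, output = get_action_post_hoc(agent, obs, ans_dict)
        examples.append(
            {
                "system": system_prompt,        # static system prompt
                "prompt": instruction_prompt,   # token-fitted main prompt
                "completion": output,           # <plan>/<memory>/<think>/<action> blocks
            }
        )
    return examples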