Create a simple pure visual agent. #235
Changes from all commits: 1b1c4dc, c25dc84, ba8c91e, a5a8ef4
New file — agent configuration presets (44 added lines):

```python
from agentlab.llm.llm_configs import CHAT_MODEL_ARGS_DICT

from .visual_agent import VisualAgentArgs
from .visual_agent_prompts import PromptFlags
import agentlab.agents.dynamic_prompting as dp
import bgym


# The other flags are ignored for this agent.
DEFAULT_OBS_FLAGS = dp.ObsFlags(
    use_tabs=True,  # will be overridden by the benchmark when set_benchmark is called after initializing the agent
    use_error_logs=True,
    use_past_error_logs=False,
    use_screenshot=True,
    use_som=False,
    openai_vision_detail="auto",
)

DEFAULT_ACTION_FLAGS = dp.ActionFlags(
    action_set=bgym.HighLevelActionSetArgs(subsets=["coord"]),
    long_description=True,
    individual_examples=False,
)


DEFAULT_PROMPT_FLAGS = PromptFlags(
    obs=DEFAULT_OBS_FLAGS,
    action=DEFAULT_ACTION_FLAGS,
    use_thinking=True,
    use_concrete_example=False,
    use_abstract_example=True,
    enable_chat=False,
    extra_instructions=None,
)


VISUAL_AGENT_4o = VisualAgentArgs(
    chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-4o-2024-05-13"],
    flags=DEFAULT_PROMPT_FLAGS,
)


VISUAL_AGENT_CLAUDE_3_5 = VisualAgentArgs(
    chat_model_args=CHAT_MODEL_ARGS_DICT["openrouter/anthropic/claude-3.5-sonnet:beta"],
    flags=DEFAULT_PROMPT_FLAGS,
)
```
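A minimal usage sketch for these presets (assumptions: only the `VisualAgentArgs` API introduced in this PR is used; a full benchmark run would go through AgentLab's experiment tooling instead):

```python
# Sketch only: build a runnable agent from one of the presets above.
agent_args = VISUAL_AGENT_4o
agent_args.set_reproducibility_mode()  # pins temperature to 0 for reproducible runs
agent = agent_args.make_agent()        # yields a VisualAgent wired with these flags
```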
New file `visual_agent.py` (129 added lines):

```python
"""
VisualAgent implementation for AgentLab

This module defines a `VisualAgent` class and its associated arguments for use in the AgentLab framework. \
The `VisualAgent` class is designed to interact with a chat-based model to determine actions based on \
observations. It includes methods for preprocessing observations, generating actions, and managing internal \
state such as past actions and thoughts. The `VisualAgentArgs` class provides configuration options for \
the agent, including model arguments and flags for various behaviors.
"""

from dataclasses import asdict, dataclass

import bgym
from browsergym.experiments.agent import Agent, AgentInfo

from agentlab.agents import dynamic_prompting as dp
from agentlab.agents.agent_args import AgentArgs
from agentlab.llm.chat_api import BaseModelArgs
from agentlab.llm.llm_utils import Discussion, ParseError, SystemMessage, retry
from agentlab.llm.tracking import cost_tracker_decorator

from .visual_agent_prompts import PromptFlags, MainPrompt


@dataclass
class VisualAgentArgs(AgentArgs):
    chat_model_args: BaseModelArgs = None
    flags: PromptFlags = None
    max_retry: int = 4

    def __post_init__(self):
        try:  # some attributes might be missing temporarily due to args.CrossProd for hyperparameter generation
            self.agent_name = f"VisualAgent-{self.chat_model_args.model_name}".replace("/", "_")
        except AttributeError:
            pass

    def set_benchmark(self, benchmark: bgym.Benchmark, demo_mode):
        """Override some flags based on the benchmark."""
        self.flags.obs.use_tabs = benchmark.is_multi_tab

    def set_reproducibility_mode(self):
        self.chat_model_args.temperature = 0

    def prepare(self):
        return self.chat_model_args.prepare_server()

    def close(self):
        return self.chat_model_args.close_server()

    def make_agent(self):
        return VisualAgent(
            chat_model_args=self.chat_model_args, flags=self.flags, max_retry=self.max_retry
        )


class VisualAgent(Agent):

    def __init__(
        self,
        chat_model_args: BaseModelArgs,
        flags: PromptFlags,
        max_retry: int = 4,
    ):
        self.chat_llm = chat_model_args.make_model()
        self.chat_model_args = chat_model_args
        self.max_retry = max_retry

        self.flags = flags
        self.action_set = self.flags.action.action_set.make_action_set()
        self._obs_preprocessor = dp.make_obs_preprocessor(flags.obs)

        self.reset(seed=None)

    def obs_preprocessor(self, obs: dict) -> dict:
        return self._obs_preprocessor(obs)

    @cost_tracker_decorator
    def get_action(self, obs):
        main_prompt = MainPrompt(
            action_set=self.action_set,
            obs=obs,
            actions=self.actions,
            thoughts=self.thoughts,
            flags=self.flags,
        )

        system_prompt = SystemMessage(dp.SystemPrompt().prompt)
        try:
            # TODO: we would need to further shrink the prompt if the retry
            # causes it to be too long

            chat_messages = Discussion([system_prompt, main_prompt.prompt])
            ans_dict = retry(
                self.chat_llm,
                chat_messages,
                n_retry=self.max_retry,
                parser=main_prompt._parse_answer,
            )
            ans_dict["busted_retry"] = 0
            # inferring the number of retries, TODO: make this less hacky
            ans_dict["n_retry"] = (len(chat_messages) - 3) / 2
```
Review comment (Korbit) on lines +102 to +103 — Complex retry calculation with magic numbers

What is the issue?
A complex calculation with magic numbers (3, 2) is used to infer the retry count.

Why this matters
The formula's intent is unclear and forces readers to reverse engineer the logic.

Suggested change
Extract the calculation into a well-named method.
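A possible shape for that refactor (a sketch; `_infer_n_retry` is a hypothetical name, and the explanation of the constants is one reading of the `Discussion` layout, not confirmed by the PR):

```python
# Hypothetical helper on VisualAgent; the name and constant rationale are assumptions.
def _infer_n_retry(self, chat_messages) -> float:
    # Assumed layout: the discussion starts as [system, user]; each failed parse
    # appends an assistant reply plus a corrective user message, and the final
    # successful reply adds one more message, so len = 3 + 2 * n_retry.
    return (len(chat_messages) - 3) / 2
```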
The diff then resumes inside `get_action`:

```python
        except ParseError:
            ans_dict = dict(
                action=None,
                n_retry=self.max_retry + 1,
                busted_retry=1,
            )

        stats = self.chat_llm.get_stats()
        stats["n_retry"] = ans_dict["n_retry"]
        stats["busted_retry"] = ans_dict["busted_retry"]

        self.actions.append(ans_dict["action"])
        self.thoughts.append(ans_dict.get("think", None))

        agent_info = AgentInfo(
            think=ans_dict.get("think", None),
            chat_messages=chat_messages,
            stats=stats,
            extra_info={"chat_model_args": asdict(self.chat_model_args)},
        )
        return ans_dict["action"], agent_info

    def reset(self, seed=None):
        self.seed = seed
        self.thoughts = []
        self.actions = []
```
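For orientation, here is how these pieces would fit together in a single episode (a sketch under assumptions: `env` is a hypothetical BrowserGym-style environment with a gymnasium step API; the real loop lives in AgentLab's experiment runner):

```python
# Sketch only; `env` and its 5-tuple step() signature are assumptions.
agent = VISUAL_AGENT_4o.make_agent()
obs, env_info = env.reset()
done = False
while not done:
    obs = agent.obs_preprocessor(obs)           # dp.make_obs_preprocessor(flags.obs)
    action, agent_info = agent.get_action(obs)  # one LLM call, up to max_retry re-parses
    if action is None:                          # busted retry: no parsable action
        break
    obs, reward, terminated, truncated, env_info = env.step(action)
    done = terminated or truncated
```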
New file `visual_agent_prompts.py` (185 added lines):

```python
"""
Prompt builder for VisualAgent

It is based on the dynamic_prompting module from the agentlab package.
"""

import logging
from dataclasses import dataclass

import bgym

from browsergym.core.action.base import AbstractActionSet

from agentlab.agents import dynamic_prompting as dp
from agentlab.llm.llm_utils import BaseMessage, HumanMessage, image_to_jpg_base64_url


@dataclass
class PromptFlags(dp.Flags):
    """
    A class to represent the flags that control the prompt of the visual agent.
    """

    obs: dp.ObsFlags = None
    action: dp.ActionFlags = None
    use_thinking: bool = True
    use_concrete_example: bool = False
    use_abstract_example: bool = True
    enable_chat: bool = False
    extra_instructions: str | None = None


class SystemPrompt(dp.PromptElement):
    _prompt = """\
You are an agent trying to solve a web task based on the content of the page and
user instructions. You can interact with the page and explore, and send messages to the user. Each time you
submit an action it will be sent to the browser and you will receive a new page."""


def make_instructions(obs: dict, from_chat: bool, extra_instructions: str | None):
    """Convenient wrapper to extract instructions from either the goal or the chat."""
    if from_chat:
        instructions = dp.ChatInstructions(
            obs["chat_messages"], extra_instructions=extra_instructions
        )
    else:
        if sum([msg["role"] == "user" for msg in obs.get("chat_messages", [])]) > 1:
            logging.warning(
                "Agent is in goal mode, but multiple user messages are present in the chat. Consider switching to `enable_chat=True`."
            )
        instructions = dp.GoalInstructions(
            obs["goal_object"], extra_instructions=extra_instructions
        )
    return instructions


class History(dp.PromptElement):
    """Format the actions and thoughts of previous steps."""

    def __init__(self, actions, thoughts) -> None:
        super().__init__()
        prompt_elements = []
        for i, (action, thought) in enumerate(zip(actions, thoughts)):
            prompt_elements.append(
                f"""
## Step {i}
### Thoughts:
{thought}
### Action:
{action}
"""
            )
        self._prompt = "\n".join(prompt_elements) + "\n"


class Observation(dp.PromptElement):
    """Observation of the current step.

    Contains the open tabs, the error logs and the screenshot.
    """

    def __init__(self, obs, flags: dp.ObsFlags) -> None:
        super().__init__()
        self.flags = flags
        self.obs = obs

        # for a multi-tab browser, we need to show the current tab
        self.tabs = dp.Tabs(
            obs,
            visible=lambda: flags.use_tabs,
            prefix="## ",
        )

        # if an error is present, we need to show it
        self.error = dp.Error(
            obs["last_action_error"],
            visible=lambda: flags.use_error_logs and obs["last_action_error"],
            prefix="## ",
        )

    @property
    def _prompt(self) -> str:
        return f"""
# Observation of current step:
{self.tabs.prompt}{self.error.prompt}

"""

    def add_screenshot(self, prompt: BaseMessage) -> BaseMessage:
        if self.flags.use_screenshot:
            if self.flags.use_som:
                screenshot = self.obs["screenshot_som"]
                prompt.add_text(
                    "\n## Screenshot:\nHere is a screenshot of the page, it is annotated with bounding boxes and corresponding bids:"
                )
            else:
                screenshot = self.obs["screenshot"]
                prompt.add_text("\n## Screenshot:\nHere is a screenshot of the page:")
            img_url = image_to_jpg_base64_url(screenshot)
            prompt.add_image(img_url, detail=self.flags.openai_vision_detail)
        return prompt


class MainPrompt(dp.PromptElement):
```
Review comment (Korbit) — Missing class docstring in MainPrompt

What is the issue?
The MainPrompt class lacks a class-level docstring explaining its purpose and usage.

Why this matters
The class appears to be a core component for building prompts, but its role and relationship to other components is unclear.

Suggested change

```python
class MainPrompt(dp.PromptElement):
    """Builds complete prompts for the agent including instructions, observations, history, and examples.

    Combines various prompt components like history, observations, and action prompts into a
    structured format for the LLM to process.
    """
```
|
|
||
| def __init__( | ||
| self, | ||
| action_set: AbstractActionSet, | ||
| obs: dict, | ||
| actions: list[str], | ||
| thoughts: list[str], | ||
| flags: PromptFlags, | ||
| ) -> None: | ||
| super().__init__() | ||
| self.flags = flags | ||
| self.history = History(actions, thoughts) | ||
| self.instructions = make_instructions(obs, flags.enable_chat, flags.extra_instructions) | ||
| self.obs = Observation(obs, self.flags.obs) | ||
|
|
||
| self.action_prompt = dp.ActionPrompt(action_set, action_flags=flags.action) | ||
| self.think = dp.Think(visible=lambda: flags.use_thinking) | ||
|
|
||
| @property | ||
| def _prompt(self) -> HumanMessage: | ||
| prompt = HumanMessage(self.instructions.prompt) | ||
| prompt.add_text( | ||
| f"""\ | ||
| {self.obs.prompt}\ | ||
| {self.history.prompt}\ | ||
| {self.action_prompt.prompt}\ | ||
| {self.think.prompt}\ | ||
| """ | ||
| ) | ||
|
|
||
| if self.flags.use_abstract_example: | ||
| prompt.add_text( | ||
| f""" | ||
| # Abstract Example | ||
|
|
||
| Here is an abstract version of the answer with description of the content of | ||
| each tag. Make sure you follow this structure, but replace the content with your | ||
| answer: | ||
| {self.think.abstract_ex}\ | ||
| {self.action_prompt.abstract_ex}\ | ||
| """ | ||
| ) | ||
|
|
||
| if self.flags.use_concrete_example: | ||
| prompt.add_text( | ||
| f""" | ||
| # Concrete Example | ||
|
|
||
| Here is a concrete example of how to format your answer. | ||
| Make sure to follow the template with proper tags: | ||
| {self.think.concrete_ex}\ | ||
| {self.action_prompt.concrete_ex}\ | ||
| """ | ||
| ) | ||
| return self.obs.add_screenshot(prompt) | ||
|
|
||
| def _parse_answer(self, text_answer): | ||
| ans_dict = {} | ||
| ans_dict.update(self.think.parse_answer(text_answer)) | ||
| ans_dict.update(self.action_prompt.parse_answer(text_answer)) | ||
| return ans_dict | ||
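To make the `History` element concrete, a small rendering sketch (the `coord`-style actions are invented for illustration; the output shape follows directly from `History.__init__` and the `prompt` property inherited from `dp.PromptElement`):

```python
# Illustrative only: the actions/thoughts are made up for the "coord" action set.
h = History(
    actions=['mouse_click(120, 240)', 'keyboard_type("agentlab")'],
    thoughts=["Locate the search box.", "Type the query."],
)
print(h.prompt)
# ## Step 0
# ### Thoughts:
# Locate the search box.
# ### Action:
# mouse_click(120, 240)
#
# ## Step 1
# ### Thoughts:
# Type the query.
# ### Action:
# keyboard_type("agentlab")
```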
Review comment (Korbit) — Non-actionable TODO comment

What is the issue?
The TODO comment in `get_action` is not actionable and lacks context about implementation details.

Why this matters
Future developers won't understand what specifically needs to be implemented or why.

Suggested change

```python
# TODO: Implement dynamic prompt shrinking when retries cause token limits to be exceeded.
```