Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 18 additions & 6 deletions src/agentlab/agents/dynamic_prompting.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,24 +443,36 @@ def __init__(self, visible: bool = True) -> None:


class GoalInstructions(PromptElement):
def __init__(self, goal, visible: bool = True, extra_instructions=None) -> None:
def __init__(self, goal_object, visible: bool = True, extra_instructions=None) -> None:
super().__init__(visible)
self._prompt = f"""\
self._prompt = [
dict(
type="text",
text=f"""\
# Instructions
Review the current state of the page and all other information to find the best
possible next action to accomplish your goal. Your answer will be interpreted
and executed by a program, make sure to follow the formatting instructions.

## Goal:
{goal}
"""
""",
)
]

self._prompt += goal_object

if extra_instructions:
self._prompt += f"""
self._prompt += [
dict(
type="text",
text=f"""

## Extra instructions:

{extra_instructions}
"""
""",
)
]


class ChatInstructions(PromptElement):
Expand Down
2 changes: 2 additions & 0 deletions src/agentlab/agents/generic_agent/generic_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def get_action(self, obs):
main_prompt = MainPrompt(
action_set=self.action_set,
obs_history=self.obs_history,
goal_object=obs["goal_object"],
actions=self.actions,
memories=self.memories,
thoughts=self.thoughts,
Expand Down Expand Up @@ -268,3 +269,4 @@ def get_action_post_hoc(agent: GenericAgent, obs: dict, ans_dict: dict):
output += f"\n<action>\n{action}\n</action>"

return system_prompt, instruction_prompt, output
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

delete

return system_prompt, instruction_prompt, output
7 changes: 4 additions & 3 deletions src/agentlab/agents/generic_agent/generic_agent_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(
self,
action_set: AbstractActionSet,
obs_history: list[dict],
goal_object: list[dict],
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no duplicate please (even if not elegant)

actions: list[str],
memories: list[str],
thoughts: list[str],
Expand All @@ -71,7 +72,7 @@ def __init__(
"Agent is in goal mode, but multiple user messages are present in the chat. Consider switching to `enable_chat=True`."
)
self.instructions = dp.GoalInstructions(
obs_history[-1]["goal"], extra_instructions=flags.extra_instructions
goal_object, extra_instructions=flags.extra_instructions
)

self.obs = dp.Observation(obs_history[-1], self.flags.obs)
Expand All @@ -93,9 +94,9 @@ def time_for_caution():

@property
def _prompt(self) -> HumanMessage:
prompt = HumanMessage(
prompt = HumanMessage(self.instructions.prompt)
prompt.add_text(
f"""\
{self.instructions.prompt}\
{self.obs.prompt}\
{self.history.prompt}\
{self.action_prompt.prompt}\
Expand Down
33 changes: 30 additions & 3 deletions src/agentlab/llm/llm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ def image_to_jpg_base64_url(image: np.ndarray | Image.Image):
class BaseMessage(dict):
def __init__(self, role: str, content: Union[str, list[dict]]):
self["role"] = role
self["content"] = content
self["content"] = deepcopy(content)

def __str__(self) -> str:
if isinstance(self["content"], str):
Expand Down Expand Up @@ -365,10 +365,30 @@ def to_markdown(self):
# add texts between ticks and images
if elem["type"] == "text":
res.append(f"\n```\n{elem['text']}\n```\n")
elif elem["type"] == "image":
res.append(f"![image]({elem['url']})")
elif elem["type"] == "image_url":
img_str = (
elem["image_url"]
if isinstance(elem["image_url"], str)
else elem["image_url"]["url"]
)
res.append(f"![image]({img_str})")
return "\n".join(res)

def merge(self):
"""Merges content elements of type 'text' if they are adjacent."""
if isinstance(self["content"], str):
return
new_content = []
for elem in self["content"]:
if elem["type"] == "text":
if new_content and new_content[-1]["type"] == "text":
new_content[-1]["text"] += "\n" + elem["text"]
else:
new_content.append(elem)
else:
new_content.append(elem)
self["content"] = new_content


class SystemMessage(BaseMessage):
def __init__(self, content: Union[str, list[dict]]):
Expand Down Expand Up @@ -397,13 +417,19 @@ def __init__(self, messages: Union[list[BaseMessage], BaseMessage] = None):
def last_message(self):
return self.messages[-1]

def merge(self):
for m in self.messages:
m.merge()

def __str__(self) -> str:
return "\n".join(str(m) for m in self.messages)

def to_string(self):
self.merge()
return str(self)

def to_openai(self):
self.merge()
return self.messages

def add_message(
Expand Down Expand Up @@ -444,6 +470,7 @@ def __getitem__(self, key):
return self.messages[key]

def to_markdown(self):
self.merge()
return "\n".join([f"Message {i}\n{m.to_markdown()}\n" for i, m in enumerate(self.messages)])


Expand Down
12 changes: 12 additions & 0 deletions tests/agents/test_generic_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
OBS_HISTORY = [
{
"goal": "do this and that",
"goal_object": [{"type": "text", "text": "do this and that"}],
"chat_messages": [{"role": "user", "message": "do this and that"}],
"pruned_html": html_template.format(1),
"axtree_txt": "[1] Click me",
Expand All @@ -32,6 +33,7 @@
},
{
"goal": "do this and that",
"goal_object": [{"type": "text", "text": "do this and that"}],
"chat_messages": [{"role": "user", "message": "do this and that"}],
"pruned_html": html_template.format(2),
"axtree_txt": "[1] Click me",
Expand All @@ -40,13 +42,15 @@
},
{
"goal": "do this and that",
"goal_object": [{"type": "text", "text": "do this and that"}],
"chat_messages": [{"role": "user", "message": "do this and that"}],
"pruned_html": html_template.format(3),
"axtree_txt": "[1] Click me",
"focused_element_bid": "45-256",
"last_action_error": "Hey, there is an error now",
},
]
GOAL_OBJECT = [{"type": "text", "text": "do this and that"}]
ACTIONS = ["click('41')", "click('42')"]
MEMORIES = ["memory A", "memory B"]
THOUGHTS = ["thought A", "thought B"]
Expand Down Expand Up @@ -164,6 +168,7 @@ def test_shrinking_observation():
prompt_maker = MainPrompt(
action_set=dp.HighLevelActionSet(),
obs_history=OBS_HISTORY,
goal_object=GOAL_OBJECT,
actions=ACTIONS,
memories=MEMORIES,
thoughts=THOUGHTS,
Expand Down Expand Up @@ -208,6 +213,7 @@ def test_main_prompt_elements_gone_one_at_a_time(flag_name: str, expected_prompt
MainPrompt(
action_set=flags.action.action_set.make_action_set(),
obs_history=OBS_HISTORY,
goal_object=GOAL_OBJECT,
actions=ACTIONS,
memories=memories,
thoughts=THOUGHTS,
Expand All @@ -230,6 +236,7 @@ def test_main_prompt_elements_present():
MainPrompt(
action_set=dp.HighLevelActionSet(),
obs_history=OBS_HISTORY,
goal_object=GOAL_OBJECT,
actions=ACTIONS,
memories=MEMORIES,
thoughts=THOUGHTS,
Expand All @@ -250,3 +257,8 @@ def test_main_prompt_elements_present():
test_main_prompt_elements_present()
for flag, expected_prompts in FLAG_EXPECTED_PROMPT:
test_main_prompt_elements_gone_one_at_a_time(flag, expected_prompts)
test_main_prompt_elements_gone_one_at_a_time(flag, expected_prompts)
test_main_prompt_elements_gone_one_at_a_time(flag, expected_prompts)
test_main_prompt_elements_gone_one_at_a_time(flag, expected_prompts)
test_main_prompt_elements_gone_one_at_a_time(flag, expected_prompts)
test_main_prompt_elements_gone_one_at_a_time(flag, expected_prompts)
31 changes: 30 additions & 1 deletion tests/llm/test_llm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,9 +242,38 @@ def hello_world():
assert llm_utils.extract_code_blocks(text) == expected_output


def test_message_merge_only_text():
content = [
{"type": "text", "text": "Hello, world!"},
{"type": "text", "text": "This is a test."},
]
message = llm_utils.BaseMessage(role="system", content=content)
message.merge()
assert len(message["content"]) == 1
assert message["content"][0]["text"] == "Hello, world!\nThis is a test."


def test_message_merge_text_image():
content = [
{"type": "text", "text": "Hello, world!"},
{"type": "text", "text": "This is a test."},
{"type": "image_url", "image_url": "this is a base64 image"},
{"type": "text", "text": "This is another test."},
{"type": "text", "text": "Goodbye, world!"},
]
message = llm_utils.BaseMessage(role="system", content=content)
message.merge()
assert len(message["content"]) == 3
assert message["content"][0]["text"] == "Hello, world!\nThis is a test."
assert message["content"][1]["image_url"] == "this is a base64 image"
assert message["content"][2]["text"] == "This is another test.\nGoodbye, world!"


if __name__ == "__main__":
# test_retry_parallel()
# test_rate_limit_max_wait_time()
# test_successful_parse_before_max_retries()
# test_unsuccessful_parse_before_max_retries()
test_extract_code_blocks()
# test_extract_code_blocks()
# test_message_merge_only_text()
test_message_merge_text_image()