Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ajet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"WorkflowOutput",
"AjetTuner",
"AgentJetJob",
"bp",
"bp"
]

__version__ = "0.1.0"
2 changes: 1 addition & 1 deletion ajet/utils/launch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def verify_python_env(args, exp_config):
time.sleep(5)
raise ImportError(cause + " " + solution)
elif args.backbone == "verl":
if not any([v in verl.__version__ for v in ["0.5.0.post", "0.7.0.post"]]): # you must install via `pip install -e .[verl]` to get every dependency right
if not any([v in verl.__version__ for v in ["0.5.0.post", "0.5.0.dev", "0.7.0.post"]]): # you must install via `pip install -e .[verl]` to get every dependency right
cause = "Python environment does not match current backbone 'verl'."
solution = "Please `cd /path/to/project/AgentJet` and run `(uv) pip install -e .[verl]` to install the correct environment."
print_dict(
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ dev = [
"mypy>=1.7.0",
"pytest>=8.0.0",
"pytest-json-ctrf",
"langchain>=1.2.3",
]

reward = [
Expand Down Expand Up @@ -112,4 +113,4 @@ known_third_party = ["wandb"]


[project.urls]
"Homepage" = "https://github.com/modelscope/AgentJet"
"Homepage" = "https://github.com/modelscope/AgentJet"
199 changes: 199 additions & 0 deletions tutorial/example_learn2ask/learn2ask_langchain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@

import re
import time
import asyncio
import threading

from agentscope.message import Msg

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The Msg class is imported from agentscope.message but is not used within this file. It's good practice to remove unused imports to keep the code clean and maintainable.

from loguru import logger

from ajet import AjetTuner, Workflow, WorkflowOutput, WorkflowTask
from ajet.utils.robust_dashscope import RobustDashScopeChatModel

# System prompt for the doctor agent: ask exactly one concise question per turn,
# or emit the literal token "<stop />" once enough information has been gathered.
system_prompt = """# Task
You are a medical assistant. Your task is to understand the ongoing conversation and continue the medical inquiry in English.

## Guidelines
- Each response must contain exactly one clear and concise medical question with 2 to 3 answer choices.
- Do not repeat any previous question.
- Your response must be a single sentence.
- If enough information has been gathered to make a medication suggestion, output only: <stop />
"""

# Judge prompt template for grading the doctor's last message. The bare `{}`
# below is a str.format placeholder filled with the reference (ground-truth)
# information via reward_prompt.format(truth_info).
reward_prompt = """# Task
You are an evaluation assistant. The user will provide a dialogue history between a doctor and a patient. You must analyze the dialogue and evaluate the doctor's last message.

# Grading Policy
## Format Score
- 1.0: The doctor's last message contains exactly **one question**.
- 0.5: The doctor's last message contains **two questions**.
- 0.0: The doctor's last message contains **three or more questions**.

## Content Score
Reference Information contains the information that the doctor has not known.

- 1.0: The question(s) **directly ask about** item in the Reference Information.
- 0.1: The question(s) are a general type of question that could be asked for any symptoms.
- 0.0: The question(s) are **irrelevant** to all items in the Reference Information.

### You should

- ONLY if the doctor asks a question that helps to collect information and diagnose the patient, it is a good question.
- A ambiguous question should get 0.
- For example, the doctor asks "How long have you been feeling this way?", but "this way" is not clear in the previous messages.
- For example, the doctor asks "Do you feel bad?". This is a meaningless question that does not provide any useful information.

# Reference Information

{}

# Output Format
<think>Explain your reasoning for the format and content scores clearly and concisely.</think>
<format_score>Insert only the format score as a float (e.g., 1.0, 0.5, 0.0)</format_score>
<content_score>Insert only the content score as a float (e.g., 1.0, 0.5, 0.0)</content_score>

> ✅ Important:
> - Output **exactly** the three tags shown above.
> - Do **not** include any additional text, explanation, or formatting outside the tags.
> - Scores must be based **only** on the doctor's **last message** and the provided Reference Information.
> - Ensure clarity and precision in your evaluation reasoning within the `<think>` tag.
"""


llm = RobustDashScopeChatModel("qwen-plus", stream=False)


async def llm_reward(init_messages: list[dict], response: str, truth_info: str):
    """Score the doctor's latest reply with an LLM judge.

    Args:
        init_messages: Prior dialogue as OpenAI-style ``{"role", "content"}`` dicts.
        response: The doctor's newest (assistant) message to be graded.
        truth_info: Reference information substituted into the judge prompt.

    Returns:
        A dict of judge tags, e.g. ``{"think": ..., "format_score": ...,
        "content_score": ...}``, or ``None`` when all retries failed.
    """

    def format_messages(messages: list[dict]) -> str:
        # Render the dialogue as a plain "patient:/doctor:" transcript for the judge.
        result_str = ""
        for msg in messages:
            if msg["role"] == "user":
                result_str += f"patient: {msg['content']}\n"
            if msg["role"] == "assistant":
                result_str += f"doctor: {msg['content']}\n"
        return result_str

    def parse_tag_string(text: str):
        # Extract <tag>value</tag> pairs (e.g. <format_score>1.0</format_score>).
        pattern = r"<(\w+)>(.*?)</\1>"
        matches = re.findall(pattern, text)
        result = {}
        for tag, value in matches:
            result[tag] = value
        return result

    history = format_messages([] + init_messages + [{"role": "assistant", "content": response}])
    messages = [
        {"role": "system", "content": reward_prompt.format(truth_info)},
        {"role": "user", "content": history},
    ]

    try_count, max_retries = 0, 5
    while try_count <= max_retries:
        try:

            async def get_content():
                from agentscope.model import ChatResponse

                # Use a distinct name so the outer `response` parameter stays
                # visible for the error log in the except branch below.
                llm_response = await llm(messages)

                if isinstance(llm_response, ChatResponse):
                    res = "".join([x["text"] for x in llm_response.content if "text" in x])
                else:
                    res = ""
                    async for chunk in llm_response:
                        res += "".join([x["text"] for x in chunk.content if "text" in x])
                return res

            content = await get_content()
            return parse_tag_string(content)
        except Exception as e:
            # BUGFIX: try_count was never incremented, so a persistent error
            # retried forever and the abort branch was unreachable.
            try_count += 1
            if try_count > max_retries:
                logger.warning("retried too many times, abort task.")
                return None
            logger.warning(f"error: {e}, response:{response}, retrying...")
            # BUGFIX: time.sleep blocked the event loop during backoff;
            # asyncio.sleep suspends only this coroutine.
            await asyncio.sleep(2**try_count)
    return None


async def reward_fn(init_messages: list[dict], response: str, truth_action: str, truth_info: str):
    """Combine decision, content, and format rewards for one doctor turn.

    content_score: R_a, the reward for response quality
    action_score: R_s, the reward for decision correctness
    format_score: P, the reward for response format
    """
    predicted_action = "continue" if "<stop />" not in response else "stop"

    if predicted_action != truth_action:
        # Wrong stop/continue decision: every reward component is zeroed.
        action_score = format_score = content_score = 0.0
    else:
        action_score = 1.0
        if truth_action == "continue":
            # Delegate question-quality grading to the LLM judge.
            score_dict = await llm_reward(init_messages, response, truth_info)
            if score_dict is None:
                format_score = content_score = 0.0
            else:
                format_score = float(score_dict.get("format_score", 0.0))
                content_score = float(score_dict.get("content_score", 0.0))
        else:
            content_score = 1.0
            # A clean stop must be exactly the stop token, nothing else.
            format_score = 1.0 if response == "<stop />" else 0.0

    # treat as self.train_mode == "Ra+Rs", the default setting
    return action_score * (1 + 2 * content_score) + format_score


# Cap concurrent judge calls at 16. BUGFIX: the previous threading.Semaphore
# was polled in a busy-wait loop from async code; asyncio.Semaphore suspends
# waiters cooperatively without blocking or polling the event loop.
_reward_semaphore = asyncio.Semaphore(16)


async def reward_fn_with_semaphore(*args, **kwargs):
    """Run `reward_fn` while holding the shared concurrency semaphore.

    At most 16 reward evaluations are in flight at once; additional callers
    wait asynchronously until a slot frees up. Arguments are forwarded to
    `reward_fn` unchanged and its result is returned as-is.
    """
    async with _reward_semaphore:
        return await reward_fn(*args, **kwargs)
Comment on lines +148 to +163

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current implementation uses threading.Semaphore with a polling loop (while not get_sem_ok...) in an asyncio context. This is inefficient as it involves busy-waiting. A more idiomatic and efficient approach for asynchronous code is to use asyncio.Semaphore, which integrates with the async with statement for cleaner and non-blocking synchronization.

Suggested change
_reward_semaphore = threading.Semaphore(16)
async def reward_fn_with_semaphore(*args, **kwargs):
get_sem_ok = False
while not get_sem_ok:
get_sem_ok = _reward_semaphore.acquire(blocking=False)
if not get_sem_ok:
await asyncio.sleep(1)
try:
fn_result = await reward_fn(*args, **kwargs)
finally:
_reward_semaphore.release()
return fn_result
_reward_semaphore = asyncio.Semaphore(16)
async def reward_fn_with_semaphore(*args, **kwargs):
async with _reward_semaphore:
return await reward_fn(*args, **kwargs)



class ExampleLearn2Ask(Workflow):
    # BUGFIX: the name was "math_agent_workflow", copy-pasted from the math
    # example; give this workflow a descriptive identifier of its own.
    name: str = "learn2ask_langchain_workflow"

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The workflow name math_agent_workflow seems to be a copy-paste from another example. To improve clarity and avoid potential confusion, it should be renamed to something more descriptive of its actual function, like learn2ask_langchain_workflow.

Suggested change
name: str = "math_agent_workflow"
name: str = "learn2ask_langchain_workflow"


async def execute(self, workflow_task: WorkflowTask, tuner: AjetTuner) -> WorkflowOutput:
from langchain_openai import ChatOpenAI
from langchain.agents import create_agent

messages = workflow_task.task.init_messages
assert isinstance(messages, list)
truth_action = workflow_task.task.metadata["decision_truth"] or "continue"
truth_info = workflow_task.task.metadata["info_truth"]

llm_info=tuner.as_oai_baseurl_apikey()

llm=ChatOpenAI(
base_url=llm_info.base_url,
api_key=lambda:llm_info.api_key,
)

agent=create_agent(
model=llm,
system_prompt=system_prompt,
)

msg=[
{"role": x["role"], "content": x["content"]} for x in messages
]
result = agent.invoke({
"messages": msg, # type: ignore
})

response = result["messages"][-1].content
reward = await reward_fn_with_semaphore(msg, response, truth_action, truth_info)
return WorkflowOutput(reward=reward)
57 changes: 57 additions & 0 deletions tutorial/example_math_agent/math_agent_langchain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from loguru import logger
from ajet import AjetTuner, Workflow, WorkflowOutput, WorkflowTask
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat import ChatCompletionMessageToolCall
from textwrap import dedent

import json
import asyncio
import requests
from langchain.agents import create_agent
Comment on lines +1 to +10

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This file contains several unused imports, including loguru, openai.types, json, asyncio, and requests. Removing these will improve code readability and maintainability. It's also a good practice to group and order imports according to PEP 8.

Suggested change
from loguru import logger
from ajet import AjetTuner, Workflow, WorkflowOutput, WorkflowTask
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat import ChatCompletionMessageToolCall
from textwrap import dedent
import json
import asyncio
import requests
from langchain.agents import create_agent
from textwrap import dedent
from langchain.agents import create_agent
from ajet import AjetTuner, Workflow, WorkflowOutput, WorkflowTask



# ------------------------------------------------------
# Simple version - no tool call
# ------------------------------------------------------


class ExampleMathLearn(Workflow):
    """A minimal math-solving workflow driven by a LangChain agent (no tool calls)."""

    name: str = "math_agent_workflow"
    # BUGFIX: the prompt previously ended with "\\boxed{{}}"; doubled braces
    # are a str.format escape, but this string is never passed through
    # .format(), so the model saw a literal "{{}}".
    system_prompt: str = dedent("""
        You are an agent specialized in solving math problems.
        Please solve the math problem given to you.
        You can write and execute Python code to perform calculation or verify your answer.
        You should return your final answer within \\boxed{}.
        """)

    async def execute(self, workflow_task: WorkflowTask, tuner: AjetTuner) -> WorkflowOutput:  # type: ignore
        """Solve the task's main query and return the agent's final answer.

        Returns a WorkflowOutput with reward=None (no grading in this simple
        example) and the agent's last message under metadata["final_answer"].
        """
        # Resolve the tuner-provided OpenAI-compatible endpoint.
        url_and_apikey = tuner.as_oai_baseurl_apikey()
        base_url = url_and_apikey.base_url
        api_key = url_and_apikey.api_key

        from langchain_openai import ChatOpenAI

        llm = ChatOpenAI(
            base_url=base_url,
            # Passed as a callable so the key is re-read per request —
            # presumably to pick up rotated tuner credentials; TODO confirm.
            api_key=lambda: api_key,
        )
        agent = create_agent(
            model=llm,
            system_prompt=self.system_prompt,
        )

        # take out query
        query = workflow_task.task.main_query

        # BUGFIX: await the async entry point instead of the blocking
        # `invoke`, which would stall the event loop during generation.
        response = await agent.ainvoke({
            "messages": [
                {
                    "role": "user",
                    "content": query,
                }
            ],
        })

        final_answer = response["messages"][-1].content
        return WorkflowOutput(reward=None, metadata={"final_answer": final_answer})
Loading