Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/doc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -59,4 +59,4 @@ jobs:
steps:
- name: Deploy to GitHub Pages
id: deployment
uses: actions/deploy-pages@v4
uses: actions/deploy-pages@v4
29 changes: 8 additions & 21 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,30 +11,17 @@ repos:
- id: check-merge-conflict
- id: detect-private-key

- repo: https://github.com/psf/black
rev: 23.7.0
hooks:
- id: black
language_version: python3.10
args: [--line-length=100]

- repo: https://github.com/pycqa/isort
rev: 5.12.0
hooks:
- id: isort
args: ["--profile", "black", "--filter-files"]

- repo: https://github.com/pycqa/flake8
rev: 6.1.0
- repo: https://github.com/myint/autoflake
rev: v2.2.0
hooks:
- id: flake8
additional_dependencies: [flake8-docstrings]
args: [
"--max-line-length=100",
"--max-complexity=20",
"--select=C,E,F,W,B,B950",
"--ignore=E203,E266,E501,W503",
]
- id: autoflake
args: [
--in-place,
--remove-all-unused-imports,
--ignore-init-module-imports
]
Comment on lines +16 to +24

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Replacing black, isort, and flake8 with only autoflake is a significant step back for code quality. autoflake only handles unused imports, while the previous tools enforced code style and caught a wide range of potential bugs.

I strongly recommend using a more comprehensive tool. ruff is a modern, high-performance tool that can replace flake8, isort, autoflake, and even black's formatting. Here is a suggested configuration using ruff for both linting and formatting.

  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.2.1
    hooks:
      - id: ruff
        args: [--fix, --exit-non-zero-on-fix]
      - id: ruff-format


- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.7.0
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -152,4 +152,4 @@ If you use AgentJet in your research, please cite:
<div align="center">

[⭐ Star Us](https://github.com/modelscope/AgentJet) · [Report Bug](https://github.com/modelscope/AgentJet/issues) · [Request Feature](https://github.com/modelscope/AgentJet/issues)
</div>
</div>
5 changes: 2 additions & 3 deletions ajet/backbone/main_vllm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import atexit
import os
import sys
from types import SimpleNamespace
Expand Down Expand Up @@ -83,7 +82,7 @@ def submit_chat_completions(self, messages, sampling_params, request_id, tools=[
"content": message["content"],
"tool_calls": message.get("tool_calls", None),
"tokens": [
TokenAndProbVllmDebug(t) for t in completion.choices[0].logprobs.content # type: ignore
TokenAndProbVllmDebug(t) for t in completion.choices[0].logprobs.content # type: ignore
],
}
)
Expand Down Expand Up @@ -131,7 +130,7 @@ async def submit_chat_completions_async(self, messages, sampling_params, request
"content": message["content"],
"tool_calls": message.get("tool_calls", None),
"tokens": [
TokenAndProbVllmDebug(t) for t in completion.choices[0].logprobs.content # type: ignore
TokenAndProbVllmDebug(t) for t in completion.choices[0].logprobs.content # type: ignore
],
}
)
Expand Down
11 changes: 4 additions & 7 deletions ajet/backbone/trainer_trinity.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import asyncio
import os
from typing import Dict, List, Literal, Optional, cast

import asyncio
import datasets
import openai
import swanlab

from loguru import logger
from transformers import AutoTokenizer
from typing import Dict, List, Literal, Optional, cast
from trinity.buffer.reader import READER
from trinity.buffer.reader.file_reader import TaskFileReader, _HFBatchReader
from trinity.buffer.schema import FORMATTER
Expand All @@ -19,9 +19,7 @@
from trinity.utils.monitor import MONITOR, Monitor

from ajet.backbone.warm_up import warm_up_process
from ajet.context_tracker.multiagent_tracking import (
MultiAgentContextTracker,
)
from ajet.context_tracker.multiagent_tracking import MultiAgentContextTracker
from ajet.schema.trajectory import Sample
from ajet.task_reader import dict_to_ajet_task
from ajet.task_rollout.native_parallel_worker import DynamicRolloutManager
Expand Down Expand Up @@ -65,7 +63,6 @@ def __init__(
)

def convert_task(self, task: TrinityTask):
from ajet.schema.task import Task
assert isinstance(task.raw_task, dict)
return dict_to_ajet_task(task.raw_task)

Expand Down
2 changes: 1 addition & 1 deletion ajet/backbone/warm_up.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,4 +101,4 @@ def warm_up_process(config):
experiment_name = config.ajet.experiment_name
init_parallel_rollout_logger(experiment_name)
warm_up_task_judge_when_needed(config)
clean_up_tmp_ajet_dir(config)
clean_up_tmp_ajet_dir(config)
2 changes: 1 addition & 1 deletion ajet/context_tracker/base_tracker.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import List, Tuple, Union
from typing import List, Union, Tuple, Dict, Optional, Any
from typing import List, Union, Tuple, Dict, Optional
Comment on lines 1 to +2

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

These two import statements from typing are partially redundant. They can be merged into a single, sorted import statement for better readability and maintainability.

Suggested change
from typing import List, Tuple, Union
from typing import List, Union, Tuple, Dict, Optional, Any
from typing import List, Union, Tuple, Dict, Optional
from typing import Dict, List, Optional, Tuple, Union

from ajet.schema.task import WorkflowTask

from ajet.schema.extended_msg import (
Expand Down
1 change: 0 additions & 1 deletion ajet/context_tracker/timeline_merging/timeline_merging.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import List

from beast_logger import print_listofdict

from ajet.context_tracker.basic_tracker import ExtendedMessage

Expand Down
4 changes: 1 addition & 3 deletions ajet/schema/convertion.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@
from openai.types.chat.chat_completion_message import ChatCompletionMessage
from agentscope.model import ChatResponse as AgentScopeChatResponse
from openai.types.completion_usage import CompletionUsage
from typing import Any, Callable, Dict, List, Literal, Type, Union
from typing import List, Type
from agentscope.message import TextBlock, ToolUseBlock
from agentscope._utils._common import _json_loads_with_repair
from pydantic import BaseModel
from agentscope.model import ChatResponse


def convert_llm_proxy_response_to_oai_response(llm_proxy_response):
Expand Down Expand Up @@ -106,4 +105,3 @@ def convert_llm_proxy_response_to_agentscope_response(
)

return parsed_response

2 changes: 1 addition & 1 deletion ajet/schema/logprob.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@
class TokenAndProb(BaseModel):
token_id: int
logprob: float
decoded_string: str
decoded_string: str
2 changes: 1 addition & 1 deletion ajet/task_reader/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,4 +123,4 @@ def dict_to_ajet_task(task_dict: dict) -> Task:
task_id=task_dict.get("task_id", ""),
env_type=task_dict.get("env_type", ""),
metadata=task_dict.get("metadata", {}),
)
)
3 changes: 1 addition & 2 deletions ajet/task_rollout/async_llm_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,13 @@
import json
import time
import uuid
from typing import Any, Callable, Dict, List, Literal, Type, Union
from typing import Any, Callable, Dict, List, Literal, Union



from loguru import logger
from omegaconf import DictConfig
from pydantic import BaseModel
from transformers.tokenization_utils import PreTrainedTokenizer
from vllm.entrypoints.openai.tool_parsers.hermes_tool_parser import Hermes2ProToolParser
from vllm.outputs import RequestOutput as VerlVllmRequestOutput

Expand Down
2 changes: 0 additions & 2 deletions ajet/task_runner/base_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from threading import Lock
from typing import Any, Callable, Union, Type
from multiprocessing import Process, Queue
from unittest import result

from ajet.context_tracker.basic_tracker import BaseContextTracker
from ajet.schema.task import WorkflowOutput, WorkflowTask
Expand Down Expand Up @@ -117,4 +116,3 @@ def run_user_workflow(

else:
raise ValueError(f"Unsupported wrapper type: {self.wrapper_type}")

3 changes: 1 addition & 2 deletions ajet/task_runner/general_runner.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from venv import logger

from ajet import AjetTuner
from ajet import Workflow, WorkflowOutput
from ajet import WorkflowOutput
from ajet.context_tracker.multiagent_tracking import (
MultiAgentContextTracker,
)
Expand Down
2 changes: 1 addition & 1 deletion ajet/tuner.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, Literal, Callable, Union, Type
from typing import TYPE_CHECKING, Callable, Union, Type

from ajet.context_tracker.multiagent_tracking import (
MultiAgentContextTracker,
Expand Down
1 change: 0 additions & 1 deletion ajet/tuner_lib/weight_tuner/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

from ajet.tuner_lib.weight_tuner.as_agentscope_model import AgentScopeModelTuner
from ajet.tuner_lib.weight_tuner.as_oai_sdk_model import OpenaiClientModelTuner

7 changes: 2 additions & 5 deletions ajet/tuner_lib/weight_tuner/as_agentscope_model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import TYPE_CHECKING, Any, Literal, Type
from typing import Any, Literal, Type

from agentscope._utils._common import _create_tool_from_base_model
from agentscope.model import ChatModelBase, ChatResponse, DashScopeChatModel
from agentscope.model import ChatResponse, DashScopeChatModel
from loguru import logger
from pydantic import BaseModel

Expand All @@ -10,9 +10,6 @@
)
from ajet.task_rollout.async_llm_bridge import AgentScopeLlmProxyWithTracker

if TYPE_CHECKING:
from ajet import Workflow


class AgentScopeModelTuner(DashScopeChatModel):
"""
Expand Down
13 changes: 2 additions & 11 deletions ajet/tuner_lib/weight_tuner/as_oai_baseurl_apikey.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,13 @@
import os
import asyncio
from typing import TYPE_CHECKING, Any, List, Callable, Literal, Type, Union
from loguru import logger
from typing import Any
from pydantic import BaseModel, Field
from ajet.context_tracker.multiagent_tracking import (
MultiAgentContextTracker,
)
from ajet.task_rollout.async_llm_bridge import OpenaiLlmProxyWithTracker
from ajet.utils.magic_mock import SpecialMagicMock
from openai.types.chat.chat_completion import ChatCompletion
from openai.resources.chat.chat import Chat, AsyncChat
from openai.resources.chat.chat import AsyncChat
from openai.resources.completions import AsyncCompletions
from openai import OpenAI, AsyncOpenAI
from ajet.utils.networking import find_free_port
from .experimental.as_oai_model_client import generate_auth_token

if TYPE_CHECKING:
from ajet import Workflow

class MockAsyncCompletions(AsyncCompletions):
async def create(self, *args, **kwargs) -> Any: # type: ignore
Expand Down
15 changes: 3 additions & 12 deletions ajet/tuner_lib/weight_tuner/as_oai_sdk_model.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,12 @@
import asyncio
from typing import TYPE_CHECKING, Any, List, Callable, Literal, Type, Union
from loguru import logger
from pydantic import BaseModel
from typing import Any, List, Callable
from ajet.context_tracker.multiagent_tracking import (
MultiAgentContextTracker,
)
from ajet.task_rollout.async_llm_bridge import OpenaiLlmProxyWithTracker
from ajet.utils.magic_mock import SpecialMagicMock
from openai.types.chat.chat_completion import ChatCompletion
from openai.resources.chat.chat import Chat, AsyncChat
from openai.resources.chat.chat import AsyncChat
from openai.resources.completions import AsyncCompletions
from openai import OpenAI, AsyncOpenAI

if TYPE_CHECKING:
from ajet import Workflow
from openai import AsyncOpenAI


class MockAsyncCompletions(AsyncCompletions):
Expand Down Expand Up @@ -80,5 +73,3 @@ async def create(
)
assert isinstance(response_gen, ChatCompletion)
return response_gen


Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

from loguru import logger
from pydantic import BaseModel
from fastapi import FastAPI, Header, HTTPException, Request, Body
from fastapi import FastAPI, Header, HTTPException, Request
from contextlib import asynccontextmanager
from multiprocessing import Process
from concurrent.futures import ThreadPoolExecutor
Expand Down Expand Up @@ -239,5 +239,3 @@ def start_interchange_server(config) -> int:

# return port
return port


2 changes: 1 addition & 1 deletion ajet/utils/async_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,4 +67,4 @@ def _patched_del(self) -> None:
AsyncHttpxClientWrapper.__del__ = _patched_del
print("Applied httpx aclose patch.")
except ImportError:
pass
pass
2 changes: 1 addition & 1 deletion ajet/utils/lowlevel_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,4 @@ def debug_task_init(self, coro, loop=None, name=None, context=None):
asyncio.create_task = debug_create_task
asyncio.AbstractEventLoop.create_task = debug_loop_create_task

patch_task_creation()
patch_task_creation()
2 changes: 1 addition & 1 deletion ajet/utils/metric_helper/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@ def update_metrics(context_tracker_arr, metrics:dict):
metrics.update(tool_metrics)
if reward_metrics:
metrics.update(reward_metrics)
return
return
Loading
Loading