Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ serde = ["dacite", "pydantic", "msgspec"]
client = ["httpx[http2]"]
adk = ["google-adk>=1.20.0"]
openai = ["openai-agents>=0.6.1"]
pydantic_ai = ["pydantic-ai-slim>=1.35.0"]
pydantic_ai = ["pydantic-ai-slim>=1.68.0"]

[build-system]
requires = ["maturin>=1.6,<2.0"]
Expand Down
22 changes: 22 additions & 0 deletions python/restate/ext/pydantic/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,34 @@
import typing

from restate import ObjectContext, Context
from restate.server_context import current_context

from ._agent import RestateAgent
from ._model import RestateModelWrapper
from ._serde import PydanticTypeAdapter
from ._toolset import RestateContextRunToolSet

def restate_object_context() -> ObjectContext:
"""Get the current Restate ObjectContext."""
ctx = current_context()
if ctx is None:
raise RuntimeError("No Restate context found.")
return typing.cast(ObjectContext, ctx)


def restate_context() -> Context:
"""Get the current Restate Context."""
ctx = current_context()
if ctx is None:
raise RuntimeError("No Restate context found.")
return ctx


__all__ = [
"RestateModelWrapper",
"RestateAgent",
"PydanticTypeAdapter",
"RestateContextRunToolSet",
"restate_object_context",
"restate_context",
]
72 changes: 38 additions & 34 deletions python/restate/ext/pydantic/_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from pydantic_ai import models
from pydantic_ai._run_context import AgentDepsT
from pydantic_ai.agent.abstract import AbstractAgent, EventStreamHandler, RunOutputDataT, Instructions
from pydantic_ai.agent.abstract import AbstractAgent, AgentMetadata, EventStreamHandler, RunOutputDataT, Instructions
from pydantic_ai.agent.wrapper import WrapperAgent
from pydantic_ai.builtin_tools import AbstractBuiltinTool
from pydantic_ai.exceptions import UserError
Expand All @@ -32,56 +32,53 @@ class RestateAgent(WrapperAgent[AgentDepsT, OutputDataT]):
"""An agent that integrates with Restate framework for building resilient applications.

This agent wraps an existing agent with Restate context capabilities, providing
automatic retries and durable execution for all operations. By default, tool calls
are automatically wrapped with Restate's execution model.
automatic retries and durable execution for model calls and MCP tool calls.

The Restate context is available within your tools via `restate_context()`,
giving you features like RPC calls, timers, and multi-step operations.

Example:
...

weather = restate.Service('weather')
from restate.ext.pydantic import restate_context

weather_agent = Agent(...)

@weather_agent.tool
async def get_weather(ctx: RunContext, city: str) -> dict:
return await restate_context().run_typed(...)


agent = RestateAgent(weather_agent)

@weather.handler()
async def get_weather(ctx: restate.Context, city: str):
agent_service = restate.Service('agent')

@agent_service.handler()
async def run(ctx: restate.Context, city: str):
result = await agent.run(f'What is the weather in {city}?')
return result.output
...

For advanced scenarios, you can disable automatic tool wrapping by setting
`disable_auto_wrapping_tools=True`. This allows direct usage of Restate context
within your tools for features like RPC calls, timers, and multi-step operations.

When automatic wrapping is disabled, function tools will NOT be automatically executed
within Restate's `ctx.run()` context, giving you full control over how the
Restate context is used within your tool implementations.
But model calls, and MCP tool calls will still be automatically wrapped.
For simple tools that don't need direct Restate context access, you can enable
automatic wrapping by setting `auto_wrap_tools=True`. This will automatically
execute function tools within Restate's `ctx.run()` context.

Example:
...

@dataclass
WeatherDeps:
...
restate_context: Context
from restate.ext.pydantic import restate_context

weather_agent = Agent(..., deps_type=WeatherDeps, ...)
weather_agent = RestateAgent(weather_agent, auto_wrap_tools=True)

@weather_agent.tool
async def get_lat_lng(ctx: RunContext[WeatherDeps], location_description: str) -> LatLng:
restate_context = ctx.deps.restate_context
lat = await restate_context.run(...) # <---- note the direct usage of the restate context
lng = await restate_context.run(...)
return LatLng(lat, lng)

async def get_weather(ctx: RunContext, city: str) -> dict:
return await fetch_weather(...)

agent = RestateAgent(weather_agent)

weather = restate.Service('weather')
agent_service = restate.Service('agent')

@weather.handler()
@agent_service.handler()
async def get_weather(ctx: restate.Context, city: str):
result = await agent.run(f'What is the weather in {city}?', deps=WeatherDeps(restate_context=ctx, ...))
result = await agent.run(f'What is the weather in {city}?')
return result.output
...

Expand All @@ -92,7 +89,7 @@ def __init__(
wrapped: AbstractAgent[AgentDepsT, OutputDataT],
*,
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
disable_auto_wrapping_tools: bool = False,
auto_wrap_tools: bool = False,
run_options: RunOptions | None = None,
):
super().__init__(wrapped)
Expand All @@ -102,7 +99,7 @@ def __init__(
)

self._event_stream_handler = event_stream_handler
self._disable_auto_wrapping_tools = disable_auto_wrapping_tools
self._auto_wrap_tools = auto_wrap_tools

if run_options is None:
run_options = RunOptions(max_attempts=3)
Expand All @@ -111,7 +108,7 @@ def __init__(

def set_context(toolset: AbstractToolset[AgentDepsT]) -> AbstractToolset[AgentDepsT]:
"""Set the Restate context for the toolset, wrapping tools if needed."""
if isinstance(toolset, FunctionToolset) and not disable_auto_wrapping_tools:
if isinstance(toolset, FunctionToolset) and auto_wrap_tools:
return RestateContextRunToolSet(toolset, run_options)
try:
from pydantic_ai.mcp import MCPServer
Expand Down Expand Up @@ -144,7 +141,7 @@ def event_stream_handler(self) -> EventStreamHandler[AgentDepsT] | None:
handler = self._event_stream_handler or super().event_stream_handler
if handler is None:
return None
if self._disable_auto_wrapping_tools:
if not self._auto_wrap_tools:
return handler
return self.wrapped_event_stream_handler

Expand Down Expand Up @@ -184,6 +181,7 @@ async def run(
model_settings: ModelSettings | None = None,
usage_limits: UsageLimits | None = None,
usage: RunUsage | None = None,
metadata: AgentMetadata[AgentDepsT] | None = None,
infer_name: bool = True,
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
builtin_tools: Sequence[AbstractBuiltinTool | BuiltinToolFunc[AgentDepsT]] | None = None,
Expand All @@ -204,6 +202,7 @@ async def run(
model_settings: ModelSettings | None = None,
usage_limits: UsageLimits | None = None,
usage: RunUsage | None = None,
metadata: AgentMetadata[AgentDepsT] | None = None,
infer_name: bool = True,
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
builtin_tools: Sequence[AbstractBuiltinTool | BuiltinToolFunc[AgentDepsT]] | None = None,
Expand All @@ -223,6 +222,7 @@ async def run(
model_settings: ModelSettings | None = None,
usage_limits: UsageLimits | None = None,
usage: RunUsage | None = None,
metadata: AgentMetadata[AgentDepsT] | None = None,
infer_name: bool = True,
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
builtin_tools: Sequence[AbstractBuiltinTool | BuiltinToolFunc[AgentDepsT]] | None = None,
Expand Down Expand Up @@ -281,6 +281,7 @@ async def main():
model_settings=model_settings,
usage_limits=usage_limits,
usage=usage,
metadata=metadata,
infer_name=infer_name,
toolsets=toolsets,
builtin_tools=builtin_tools,
Expand All @@ -301,6 +302,7 @@ def run_stream(
model_settings: ModelSettings | None = None,
usage_limits: UsageLimits | None = None,
usage: RunUsage | None = None,
metadata: AgentMetadata[AgentDepsT] | None = None,
infer_name: bool = True,
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
builtin_tools: Sequence[AbstractBuiltinTool | BuiltinToolFunc[AgentDepsT]] | None = None,
Expand All @@ -321,6 +323,7 @@ def run_stream(
model_settings: ModelSettings | None = None,
usage_limits: UsageLimits | None = None,
usage: RunUsage | None = None,
metadata: AgentMetadata[AgentDepsT] | None = None,
infer_name: bool = True,
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
builtin_tools: Sequence[AbstractBuiltinTool | BuiltinToolFunc[AgentDepsT]] | None = None,
Expand All @@ -341,6 +344,7 @@ async def run_stream(
model_settings: ModelSettings | None = None,
usage_limits: UsageLimits | None = None,
usage: RunUsage | None = None,
metadata: AgentMetadata[AgentDepsT] | None = None,
infer_name: bool = True,
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
builtin_tools: Sequence[AbstractBuiltinTool | BuiltinToolFunc[AgentDepsT]] | None = None,
Expand Down
Loading
Loading