Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions python/restate/context_managers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
#
# Copyright (c) 2023-2025 - Restate Software, Inc., Restate GmbH
#
# This file is part of the Restate SDK for Python,
# which is released under the MIT license.
#
# You can find a copy of the license in file LICENSE in the root
# directory of this repository or package, or at
# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE
#
"""
contextvar utility for async context managers.
"""

import contextvars
from contextlib import asynccontextmanager
from typing import (
Any,
AsyncContextManager,
AsyncGenerator,
Callable,
Generic,
ParamSpec,
TypeVar,
)

P = ParamSpec("P")
T = TypeVar("T")


class contextvar(Generic[P, T]):
"""
A type-safe decorator for asynccontextmanager functions that captures the yielded value in a ContextVar.
This is useful when integrating with frameworks that only support None yielded values from context managers.

Example usage:
```python
@contextvar
@asynccontextmanager
async def my_resource() -> AsyncIterator[str]:
yield "hi"

async def usage_example():
async with my_resource():
print(my_resource.value) # prints "hi"
```


"""

def __init__(self, func: Callable[P, AsyncContextManager[T]]):
self.func = func
self._value_var: contextvars.ContextVar[T | None] = contextvars.ContextVar("value")

@property
def value(self) -> T:
"""Return the value yielded by the wrapped context manager."""
val = self._value_var.get()
if val is None:
raise LookupError("Context manager value accessed outside of context manager scope (has not been entered yet)")
return val

def __call__(self, *args: P.args, **kwargs: P.kwargs) -> AsyncContextManager[None]:
@asynccontextmanager
async def wrapper() -> AsyncGenerator[None, Any]:
async with self.func(*args, **kwargs) as value:
token = self._value_var.set(value)
try:
yield # we make it yield None, as the value is accessible via .value()
finally:
self._value_var.reset(token)

return wrapper()
3 changes: 2 additions & 1 deletion python/restate/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@
"""This module contains internal extensions apis"""

from .server_context import current_context
from .context_managers import contextvar

__all__ = ["current_context"]
__all__ = ["current_context", "contextvar"]
5 changes: 4 additions & 1 deletion python/restate/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from dataclasses import dataclass
from datetime import timedelta
from inspect import Signature
from typing import Any, Callable, Awaitable, Dict, Generic, Literal, Optional, TypeVar
from typing import Any, AsyncContextManager, Callable, Awaitable, Dict, Generic, List, Literal, Optional, TypeVar

from restate.retry_policy import InvocationRetryPolicy

Expand Down Expand Up @@ -150,6 +150,7 @@ class Handler(Generic[I, O]):
enable_lazy_state: Optional[bool] = None
ingress_private: Optional[bool] = None
invocation_retry_policy: Optional[InvocationRetryPolicy] = None
context_managers: Optional[List[Callable[[], AsyncContextManager[None]]]] = None


# disable too many arguments warning
Expand All @@ -172,6 +173,7 @@ def make_handler(
enable_lazy_state: Optional[bool] = None,
ingress_private: Optional[bool] = None,
invocation_retry_policy: Optional[InvocationRetryPolicy] = None,
context_managers: Optional[List[Callable[[], AsyncContextManager[None]]]] = None,
) -> Handler[I, O]:
"""
Factory function to create a handler.
Expand Down Expand Up @@ -225,6 +227,7 @@ def make_handler(
enable_lazy_state=enable_lazy_state,
ingress_private=ingress_private,
invocation_retry_policy=invocation_retry_policy,
context_managers=context_managers,
)

vars(wrapped)[RESTATE_UNIQUE_HANDLER_SYMBOL] = handler
Expand Down
9 changes: 9 additions & 0 deletions python/restate/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def __init__(
enable_lazy_state: typing.Optional[bool] = None,
ingress_private: typing.Optional[bool] = None,
invocation_retry_policy: typing.Optional[InvocationRetryPolicy] = None,
context_managers: typing.Optional[typing.List[typing.Callable[[], typing.AsyncContextManager[None]]]] = None,
):
self.service_tag = ServiceTag("object", name, description, metadata)
self.handlers = {}
Expand All @@ -97,6 +98,7 @@ def __init__(
self.enable_lazy_state = enable_lazy_state
self.ingress_private = ingress_private
self.invocation_retry_policy = invocation_retry_policy
self.context_managers = context_managers

@property
def name(self):
Expand All @@ -122,6 +124,7 @@ def handler(
enable_lazy_state: typing.Optional[bool] = None,
ingress_private: typing.Optional[bool] = None,
invocation_retry_policy: typing.Optional[InvocationRetryPolicy] = None,
context_managers: typing.Optional[typing.List[typing.Callable[[], typing.AsyncContextManager[None]]]] = None,
) -> typing.Callable[[T], T]:
"""
Decorator for defining a handler function.
Expand Down Expand Up @@ -184,6 +187,11 @@ def wrapped(*args, **kwargs):
return fn(*args, **kwargs)

signature = inspect.signature(fn, eval_str=True)
combined_context_managers = (
(self.context_managers or []) + (context_managers or [])
if self.context_managers or context_managers
else None
)
handler = make_handler(
self.service_tag,
handler_io,
Expand All @@ -201,6 +209,7 @@ def wrapped(*args, **kwargs):
enable_lazy_state,
ingress_private,
invocation_retry_policy,
combined_context_managers,
)
self.handlers[handler.name] = handler
return wrapped
Expand Down
6 changes: 5 additions & 1 deletion python/restate/server_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"""This module contains the restate context implementation based on the server"""

import asyncio
from contextlib import AsyncExitStack
import contextvars
import copy
from random import Random
Expand Down Expand Up @@ -342,7 +343,10 @@ async def enter(self):
token = _restate_context_var.set(self)
try:
in_buffer = self.invocation.input_buffer
out_buffer = await invoke_handler(handler=self.handler, ctx=self, in_buffer=in_buffer)
async with AsyncExitStack() as stack:
for manager in self.handler.context_managers or []:
await stack.enter_async_context(manager())
out_buffer = await invoke_handler(handler=self.handler, ctx=self, in_buffer=in_buffer)
restate_context_is_replaying.set(False)
self.vm.sys_write_output_success(bytes(out_buffer))
self.vm.sys_end()
Expand Down
12 changes: 12 additions & 0 deletions python/restate/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def __init__(
idempotency_retention: typing.Optional[timedelta] = None,
ingress_private: typing.Optional[bool] = None,
invocation_retry_policy: typing.Optional[InvocationRetryPolicy] = None,
context_managers: typing.Optional[typing.List[typing.Callable[[], typing.AsyncContextManager[None]]]] = None,
) -> None:
self.service_tag = ServiceTag("service", name, description, metadata)
self.handlers: typing.Dict[str, Handler] = {}
Expand All @@ -90,6 +91,7 @@ def __init__(
self.idempotency_retention = idempotency_retention
self.ingress_private = ingress_private
self.invocation_retry_policy = invocation_retry_policy
self.context_managers = context_managers

@property
def name(self):
Expand All @@ -112,6 +114,7 @@ def handler(
idempotency_retention: typing.Optional[timedelta] = None,
ingress_private: typing.Optional[bool] = None,
invocation_retry_policy: typing.Optional[InvocationRetryPolicy] = None,
context_managers: typing.Optional[typing.List[typing.Callable[[], typing.AsyncContextManager[None]]]] = None,
) -> typing.Callable[[T], T]:
"""
Decorator for defining a handler function.
Expand Down Expand Up @@ -170,6 +173,14 @@ def wrapped(*args, **kwargs):
return fn(*args, **kwargs)

signature = inspect.signature(fn, eval_str=True)

# combine context managers or leave None if both are None
combined_context_managers = (
(self.context_managers or []) + (context_managers or [])
if self.context_managers or context_managers
else None
)

handler = make_handler(
self.service_tag,
handler_io,
Expand All @@ -187,6 +198,7 @@ def wrapped(*args, **kwargs):
None,
ingress_private,
invocation_retry_policy,
combined_context_managers,
)
self.handlers[handler.name] = handler
return wrapped
Expand Down
13 changes: 13 additions & 0 deletions python/restate/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def __init__(
enable_lazy_state: typing.Optional[bool] = None,
ingress_private: typing.Optional[bool] = None,
invocation_retry_policy: typing.Optional[InvocationRetryPolicy] = None,
context_managers: typing.Optional[typing.List[typing.Callable[[], typing.AsyncContextManager[None]]]] = None,
):
self.service_tag = ServiceTag("workflow", name, description, metadata)
self.handlers = {}
Expand All @@ -102,6 +103,7 @@ def __init__(
self.enable_lazy_state = enable_lazy_state
self.ingress_private = ingress_private
self.invocation_retry_policy = invocation_retry_policy
self.context_managers = context_managers

@property
def name(self):
Expand All @@ -125,6 +127,7 @@ def main(
enable_lazy_state: typing.Optional[bool] = None,
ingress_private: typing.Optional[bool] = None,
invocation_retry_policy: typing.Optional[InvocationRetryPolicy] = None,
context_managers: typing.Optional[typing.List[typing.Callable[[], typing.AsyncContextManager[None]]]] = None,
) -> typing.Callable[[T], T]:
"""
Mark this handler as a workflow entry point.
Expand Down Expand Up @@ -182,6 +185,7 @@ def main(
enable_lazy_state=enable_lazy_state,
ingress_private=ingress_private,
invocation_retry_policy=invocation_retry_policy,
context_managers=context_managers,
)

def handler(
Expand All @@ -199,6 +203,7 @@ def handler(
enable_lazy_state: typing.Optional[bool] = None,
ingress_private: typing.Optional[bool] = None,
invocation_retry_policy: typing.Optional[InvocationRetryPolicy] = None,
context_managers: typing.Optional[typing.List[typing.Callable[[], typing.AsyncContextManager[None]]]] = None,
) -> typing.Callable[[T], T]:
"""
Decorator for defining a handler function.
Expand Down Expand Up @@ -256,6 +261,7 @@ def handler(
enable_lazy_state,
ingress_private,
invocation_retry_policy,
context_managers,
)

# pylint: disable=R0914
Expand All @@ -276,6 +282,7 @@ def _add_handler(
enable_lazy_state: typing.Optional[bool] = None,
ingress_private: typing.Optional[bool] = None,
invocation_retry_policy: typing.Optional["InvocationRetryPolicy"] = None,
context_managers: typing.Optional[typing.List[typing.Callable[[], typing.AsyncContextManager[None]]]] = None,
) -> typing.Callable[[T], T]:
"""
Decorator for defining a handler function.
Expand Down Expand Up @@ -342,6 +349,11 @@ def wrapped(*args, **kwargs):

signature = inspect.signature(fn, eval_str=True)
description = inspect.getdoc(fn)
combined_context_managers = (
(self.context_managers or []) + (context_managers or [])
if self.context_managers or context_managers
else None
)
handler = make_handler(
service_tag=self.service_tag,
handler_io=handler_io,
Expand All @@ -359,6 +371,7 @@ def wrapped(*args, **kwargs):
enable_lazy_state=enable_lazy_state,
ingress_private=ingress_private,
invocation_retry_policy=invocation_retry_policy,
context_managers=combined_context_managers,
)
self.handlers[handler.name] = handler
return wrapped
Expand Down
33 changes: 25 additions & 8 deletions tests/ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE
#

from contextlib import asynccontextmanager
from restate.extensions import contextvar

import restate
from restate import (
Context,
Service,
HarnessEnvironment,
)

from restate import Context, Service, HarnessEnvironment, extensions
import pytest

# ----- Asyncio fixtures
Expand All @@ -35,9 +35,7 @@ def anyio_backend():


def magic_function():
from restate.extensions import current_context

ctx = current_context()
ctx = extensions.current_context()
assert ctx is not None
return ctx.request().id

Expand All @@ -48,6 +46,20 @@ async def greet(ctx: Context, name: str) -> str:
return f"Hello {id}!"


# -- context manager


@contextvar
@asynccontextmanager
async def my_resource_manager():
yield "hello"


@greeter.handler(context_managers=[my_resource_manager])
async def greet_with_cm(ctx: Context, name: str) -> str:
return my_resource_manager.value


@pytest.fixture(scope="session")
async def restate_test_harness():
async with restate.create_test_harness(
Expand All @@ -62,3 +74,8 @@ async def restate_test_harness():
async def test_greeter(restate_test_harness: HarnessEnvironment):
greeting = await restate_test_harness.client.service_call(greet, arg="bob")
assert greeting.startswith("Hello ")


async def test_greeter_with_cm(restate_test_harness: HarnessEnvironment):
greeting = await restate_test_harness.client.service_call(greet_with_cm, arg="bob")
assert greeting == "hello"