Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 178 additions & 8 deletions src/a2a/contrib/tasks/vertex_task_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,34 @@
import base64
import json

from dataclasses import dataclass
from typing import Any

from a2a.types import (
Artifact,
DataPart,
FilePart,
FileWithBytes,
FileWithUri,
Message,
Part,
Role,
Task,
TaskState,
TaskStatus,
TextPart,
)


_ORIGINAL_METADATA_KEY = 'originalMetadata'
_EXTENSIONS_KEY = 'extensions'
_REFERENCE_TASK_IDS_KEY = 'referenceTaskIds'
_PART_METADATA_KEY = 'partMetadata'
_PART_TYPES_KEY = 'partTypes'
_METADATA_VERSION_KEY = '__vertex_compat_v'
_METADATA_VERSION_NUMBER = 1.0


_TO_SDK_TASK_STATE = {
vertexai_types.A2aTaskState.STATE_UNSPECIFIED: TaskState.unknown,
vertexai_types.A2aTaskState.SUBMITTED: TaskState.submitted,
Expand Down Expand Up @@ -52,6 +66,62 @@ def to_stored_task_state(task_state: TaskState) -> vertexai_types.A2aTaskState:
)


def to_stored_metadata(
original_metadata: dict[str, Any] | None,
extensions: list[str] | None,
reference_task_ids: list[str] | None,
parts: list[Part],
) -> dict[str, Any]:
"""Packs original metadata, extensions, and part types/metadata into a storage dictionary."""
metadata: dict[str, Any] = {_METADATA_VERSION_KEY: _METADATA_VERSION_NUMBER}
if original_metadata:
metadata[_ORIGINAL_METADATA_KEY] = original_metadata
if extensions:
metadata[_EXTENSIONS_KEY] = extensions
if reference_task_ids:
metadata[_REFERENCE_TASK_IDS_KEY] = reference_task_ids

part_types = []
part_metadata = []
for part in parts:
part_types.append('data' if isinstance(part.root, DataPart) else '')
part_metadata.append(part.root.metadata)

metadata[_PART_TYPES_KEY] = part_types
metadata[_PART_METADATA_KEY] = part_metadata

return metadata


@dataclass
class _UnpackedMetadata:
original_metadata: dict[str, Any] | None = None
extensions: list[str] | None = None
reference_task_ids: list[str] | None = None
part_metadata: list[dict[str, Any] | None] | None = None
part_types: list[str] | None = None


def to_sdk_metadata(
stored_metadata: dict[str, Any] | None,
) -> _UnpackedMetadata:
"""Unpacks metadata, extensions, and part types/metadata from a storage dictionary."""
if not stored_metadata:
return _UnpackedMetadata()

version = stored_metadata.get(_METADATA_VERSION_KEY)
if version is None:
return _UnpackedMetadata(original_metadata=stored_metadata)

return _UnpackedMetadata(
original_metadata=stored_metadata.get(_ORIGINAL_METADATA_KEY),
extensions=stored_metadata.get(_EXTENSIONS_KEY),
reference_task_ids=stored_metadata.get(_REFERENCE_TASK_IDS_KEY),
part_metadata=stored_metadata.get(_PART_METADATA_KEY),
part_types=stored_metadata.get(_PART_TYPES_KEY),
)


def to_stored_part(part: Part) -> genai_types.Part:
"""Converts a SDK Part to a proto Part."""
if isinstance(part.root, TextPart):
Expand Down Expand Up @@ -82,29 +152,42 @@ def to_stored_part(part: Part) -> genai_types.Part:
raise ValueError(f'Unsupported part type: {type(part.root)}')


def to_sdk_part(stored_part: genai_types.Part) -> Part:
def to_sdk_part(
stored_part: genai_types.Part,
part_metadata: dict[str, Any] | None = None,
part_type: str = '',
) -> Part:
"""Converts a proto Part to a SDK Part."""
if stored_part.text:
return Part(root=TextPart(text=stored_part.text))
return Part(
root=TextPart(text=stored_part.text, metadata=part_metadata)
)
if stored_part.inline_data:
mime_type = stored_part.inline_data.mime_type
if part_type == 'data' and mime_type == 'application/json':
data_dict = json.loads(stored_part.inline_data.data or b'{}')
return Part(root=DataPart(data=data_dict, metadata=part_metadata))

encoded_bytes = base64.b64encode(
stored_part.inline_data.data or b''
).decode('utf-8')
return Part(
root=FilePart(
file=FileWithBytes(
mime_type=stored_part.inline_data.mime_type,
mime_type=mime_type,
bytes=encoded_bytes,
)
),
metadata=part_metadata,
)
)
if stored_part.file_data:
return Part(
root=FilePart(
file=FileWithUri(
mime_type=stored_part.file_data.mime_type,
uri=stored_part.file_data.file_uri,
)
uri=stored_part.file_data.file_uri or '',
),
metadata=part_metadata,
)
)

Expand All @@ -115,15 +198,93 @@ def to_stored_artifact(artifact: Artifact) -> vertexai_types.TaskArtifact:
"""Converts a SDK Artifact to a proto TaskArtifact."""
return vertexai_types.TaskArtifact(
artifact_id=artifact.artifact_id,
display_name=artifact.name,
description=artifact.description,
parts=[to_stored_part(part) for part in artifact.parts],
metadata=to_stored_metadata(
original_metadata=artifact.metadata,
extensions=artifact.extensions,
reference_task_ids=None,
parts=artifact.parts,
),
)


def to_sdk_artifact(stored_artifact: vertexai_types.TaskArtifact) -> Artifact:
"""Converts a proto TaskArtifact to a SDK Artifact."""
unpacked_meta = to_sdk_metadata(stored_artifact.metadata)
part_metadata_list = unpacked_meta.part_metadata or []
part_types = unpacked_meta.part_types or []

parts = []
for i, part in enumerate(stored_artifact.parts or []):
meta: dict[str, Any] | None = None
if i < len(part_metadata_list):
meta = part_metadata_list[i]
ptype = ''
if i < len(part_types):
ptype = part_types[i]
parts.append(to_sdk_part(part, part_metadata=meta, part_type=ptype))

return Artifact(
artifact_id=stored_artifact.artifact_id,
parts=[to_sdk_part(part) for part in stored_artifact.parts],
name=stored_artifact.display_name,
description=stored_artifact.description,
extensions=unpacked_meta.extensions,
metadata=unpacked_meta.original_metadata,
parts=parts,
)


def to_stored_message(
message: Message | None,
) -> vertexai_types.TaskMessage | None:
"""Converts a SDK Message to a proto Message."""
if not message:
return None
role = message.role.value if message.role else ''
return vertexai_types.TaskMessage(
message_id=message.message_id,
role=role,
parts=[to_stored_part(part) for part in message.parts],
metadata=to_stored_metadata(
original_metadata=message.metadata,
extensions=message.extensions,
reference_task_ids=message.reference_task_ids,
parts=message.parts,
),
)


def to_sdk_message(
stored_msg: vertexai_types.TaskMessage | None,
) -> Message | None:
"""Converts a proto Message to a SDK Message."""
if not stored_msg:
return None
unpacked_meta = to_sdk_metadata(stored_msg.metadata)
part_metadata_list = unpacked_meta.part_metadata or []
part_types = unpacked_meta.part_types or []

parts = []
for i, part in enumerate(stored_msg.parts or []):
part_metadata: dict[str, Any] | None = None
if i < len(part_metadata_list):
part_metadata = part_metadata_list[i]
part_type = ''
if i < len(part_types):
part_type = part_types[i]
parts.append(
to_sdk_part(part, part_metadata=part_metadata, part_type=part_type)
)

return Message(
message_id=stored_msg.message_id,
role=Role(stored_msg.role),
extensions=unpacked_meta.extensions,
reference_task_ids=unpacked_meta.reference_task_ids,
metadata=unpacked_meta.original_metadata,
parts=parts,
)


Expand All @@ -133,6 +294,11 @@ def to_stored_task(task: Task) -> vertexai_types.A2aTask:
context_id=task.context_id,
metadata=task.metadata,
state=to_stored_task_state(task.status.state),
status_details=vertexai_types.TaskStatusDetails(
task_message=to_stored_message(task.status.message)
)
if task.status.message
else None,
output=vertexai_types.TaskOutput(
artifacts=[
to_stored_artifact(artifact)
Expand All @@ -144,10 +310,14 @@ def to_stored_task(task: Task) -> vertexai_types.A2aTask:

def to_sdk_task(a2a_task: vertexai_types.A2aTask) -> Task:
"""Converts a proto A2aTask to a SDK Task."""
msg: Message | None = None
if a2a_task.status_details and a2a_task.status_details.task_message:
msg = to_sdk_message(a2a_task.status_details.task_message)

return Task(
id=a2a_task.name.split('/')[-1],
context_id=a2a_task.context_id,
status=TaskStatus(state=to_sdk_task_state(a2a_task.state)),
status=TaskStatus(state=to_sdk_task_state(a2a_task.state), message=msg),
metadata=a2a_task.metadata or {},
artifacts=[
to_sdk_artifact(artifact)
Expand Down
Loading
Loading