Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 33 additions & 30 deletions core/src/main/java/com/google/adk/agents/BaseAgent.java
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,10 @@ public abstract class BaseAgent {
*/
private BaseAgent parentAgent;

private final List<? extends BaseAgent> subAgents;
private final ImmutableList<? extends BaseAgent> subAgents;

private final Optional<List<? extends BeforeAgentCallback>> beforeAgentCallback;
private final Optional<List<? extends AfterAgentCallback>> afterAgentCallback;
private final ImmutableList<? extends BeforeAgentCallback> beforeAgentCallback;
private final ImmutableList<? extends AfterAgentCallback> afterAgentCallback;

/**
* Creates a new BaseAgent.
Expand All @@ -82,9 +82,13 @@ public BaseAgent(
this.name = name;
this.description = description;
this.parentAgent = null;
this.subAgents = subAgents != null ? subAgents : ImmutableList.of();
this.beforeAgentCallback = Optional.ofNullable(beforeAgentCallback);
this.afterAgentCallback = Optional.ofNullable(afterAgentCallback);
this.subAgents = subAgents == null ? ImmutableList.of() : ImmutableList.copyOf(subAgents);
this.beforeAgentCallback =
beforeAgentCallback == null
? ImmutableList.of()
: ImmutableList.copyOf(beforeAgentCallback);
this.afterAgentCallback =
afterAgentCallback == null ? ImmutableList.of() : ImmutableList.copyOf(afterAgentCallback);

// Establish parent relationships for all sub-agents if needed.
for (BaseAgent subAgent : this.subAgents) {
Expand Down Expand Up @@ -144,38 +148,38 @@ public BaseAgent rootAgent() {
/**
* Finds an agent (this or descendant) by name.
*
* @return the agent or descendant with the given name, or {@code null} if not found.
* @return an {@link Optional} containing the agent or descendant with the given name, or {@link
* Optional#empty()} if not found.
*/
public BaseAgent findAgent(String name) {
public Optional<BaseAgent> findAgent(String name) {
if (this.name().equals(name)) {
return this;
return Optional.of(this);
}
return findSubAgent(name);
}

/** Recursively search sub agent by name. */
public @Nullable BaseAgent findSubAgent(String name) {
for (BaseAgent subAgent : subAgents) {
if (subAgent.name().equals(name)) {
return subAgent;
}
BaseAgent result = subAgent.findSubAgent(name);
if (result != null) {
return result;
}
}
return null;
/**
* Recursively search sub agent by name.
*
* @return an {@link Optional} containing the sub agent with the given name, or {@link
* Optional#empty()} if not found.
*/
public Optional<BaseAgent> findSubAgent(String name) {
return subAgents.stream()
.map(subAgent -> subAgent.findAgent(name))
.flatMap(Optional::stream)
.findFirst();
}

public List<? extends BaseAgent> subAgents() {
return subAgents;
}

public Optional<List<? extends BeforeAgentCallback>> beforeAgentCallback() {
public ImmutableList<? extends BeforeAgentCallback> beforeAgentCallback() {
return beforeAgentCallback;
}

public Optional<List<? extends AfterAgentCallback>> afterAgentCallback() {
public ImmutableList<? extends AfterAgentCallback> afterAgentCallback() {
return afterAgentCallback;
}

Expand All @@ -184,17 +188,17 @@ public Optional<List<? extends AfterAgentCallback>> afterAgentCallback() {
*
* <p>This method is only for use by Agent Development Kit.
*/
public List<? extends BeforeAgentCallback> canonicalBeforeAgentCallbacks() {
return beforeAgentCallback.orElse(ImmutableList.of());
public ImmutableList<? extends BeforeAgentCallback> canonicalBeforeAgentCallbacks() {
return beforeAgentCallback;
}

/**
* The resolved afterAgentCallback field as a list.
*
* <p>This method is only for use by Agent Development Kit.
*/
public List<? extends AfterAgentCallback> canonicalAfterAgentCallbacks() {
return afterAgentCallback.orElse(ImmutableList.of());
public ImmutableList<? extends AfterAgentCallback> canonicalAfterAgentCallbacks() {
return afterAgentCallback;
}

/**
Expand Down Expand Up @@ -239,8 +243,7 @@ public Flowable<Event> runAsync(InvocationContext parentContext) {
() ->
callCallback(
beforeCallbacksToFunctions(
invocationContext.pluginManager(),
beforeAgentCallback.orElse(ImmutableList.of())),
invocationContext.pluginManager(), beforeAgentCallback),
invocationContext)
.flatMapPublisher(
beforeEventOpt -> {
Expand All @@ -257,7 +260,7 @@ public Flowable<Event> runAsync(InvocationContext parentContext) {
callCallback(
afterCallbacksToFunctions(
invocationContext.pluginManager(),
afterAgentCallback.orElse(ImmutableList.of())),
afterAgentCallback),
invocationContext)
.flatMapPublisher(Flowable::fromOptional));

Expand Down
25 changes: 16 additions & 9 deletions core/src/main/java/com/google/adk/agents/BaseAgentConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.google.adk.agents;

import com.google.common.collect.ImmutableList;
import java.util.List;

/**
Expand All @@ -27,11 +28,11 @@ public class BaseAgentConfig {
private String name;
private String description = "";
private String agentClass;
private List<AgentRefConfig> subAgents;
private ImmutableList<AgentRefConfig> subAgents = ImmutableList.of();

// Callback configuration (names resolved via ComponentRegistry)
private List<CallbackRef> beforeAgentCallbacks;
private List<CallbackRef> afterAgentCallbacks;
private ImmutableList<CallbackRef> beforeAgentCallbacks = ImmutableList.of();
private ImmutableList<CallbackRef> afterAgentCallbacks = ImmutableList.of();

/** Reference to a callback stored in the ComponentRegistry. */
public static class CallbackRef {
Expand Down Expand Up @@ -131,27 +132,33 @@ public String agentClass() {
return agentClass;
}

public List<AgentRefConfig> subAgents() {
public ImmutableList<AgentRefConfig> subAgents() {
return subAgents;
}

public void setSubAgents(List<AgentRefConfig> subAgents) {
this.subAgents = subAgents;
this.subAgents = subAgents == null ? ImmutableList.of() : ImmutableList.copyOf(subAgents);
}

public List<CallbackRef> beforeAgentCallbacks() {
public ImmutableList<CallbackRef> beforeAgentCallbacks() {
return beforeAgentCallbacks;
}

public void setBeforeAgentCallbacks(List<CallbackRef> beforeAgentCallbacks) {
this.beforeAgentCallbacks = beforeAgentCallbacks;
this.beforeAgentCallbacks =
beforeAgentCallbacks == null
? ImmutableList.of()
: ImmutableList.copyOf(beforeAgentCallbacks);
}

public List<CallbackRef> afterAgentCallbacks() {
public ImmutableList<CallbackRef> afterAgentCallbacks() {
return afterAgentCallbacks;
}

public void setAfterAgentCallbacks(List<CallbackRef> afterAgentCallbacks) {
this.afterAgentCallbacks = afterAgentCallbacks;
this.afterAgentCallbacks =
afterAgentCallbacks == null
? ImmutableList.of()
: ImmutableList.copyOf(afterAgentCallbacks);
}
}
92 changes: 45 additions & 47 deletions core/src/main/java/com/google/adk/agents/CallbackUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
import com.google.errorprone.annotations.CanIgnoreReturnValue;
import io.reactivex.rxjava3.core.Maybe;
import java.util.List;
import org.jspecify.annotations.Nullable;
import java.util.function.Function;
import java.util.stream.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -37,65 +38,62 @@ public final class CallbackUtil {
/**
* Normalizes before-agent callbacks.
*
* @param beforeAgentCallback Callback list (sync or async).
* @return normalized async callbacks, or null if input is null.
* @param beforeAgentCallbacks Callback list (sync or async).
* @return normalized async callbacks, or empty list if input is null.
*/
@CanIgnoreReturnValue
public static @Nullable ImmutableList<BeforeAgentCallback> getBeforeAgentCallbacks(
List<BeforeAgentCallbackBase> beforeAgentCallback) {
if (beforeAgentCallback == null) {
return null;
} else if (beforeAgentCallback.isEmpty()) {
return ImmutableList.of();
} else {
ImmutableList.Builder<BeforeAgentCallback> builder = ImmutableList.builder();
for (BeforeAgentCallbackBase callback : beforeAgentCallback) {
if (callback instanceof BeforeAgentCallback beforeAgentCallbackInstance) {
builder.add(beforeAgentCallbackInstance);
} else if (callback instanceof BeforeAgentCallbackSync beforeAgentCallbackSyncInstance) {
builder.add(
(callbackContext) ->
Maybe.fromOptional(beforeAgentCallbackSyncInstance.call(callbackContext)));
} else {
logger.warn(
"Invalid beforeAgentCallback callback type: {}. Ignoring this callback.",
callback.getClass().getName());
}
}
return builder.build();
}
public static ImmutableList<BeforeAgentCallback> getBeforeAgentCallbacks(
List<BeforeAgentCallbackBase> beforeAgentCallbacks) {
return getCallbacks(
beforeAgentCallbacks,
BeforeAgentCallback.class,
BeforeAgentCallbackSync.class,
sync -> (callbackContext -> Maybe.fromOptional(sync.call(callbackContext))),
"beforeAgentCallbacks");
}

/**
* Normalizes after-agent callbacks.
*
* @param afterAgentCallback Callback list (sync or async).
* @return normalized async callbacks, or null if input is null.
* @return normalized async callbacks, or empty list if input is null.
*/
@CanIgnoreReturnValue
public static @Nullable ImmutableList<AfterAgentCallback> getAfterAgentCallbacks(
public static ImmutableList<AfterAgentCallback> getAfterAgentCallbacks(
List<AfterAgentCallbackBase> afterAgentCallback) {
if (afterAgentCallback == null) {
return null;
} else if (afterAgentCallback.isEmpty()) {
return getCallbacks(
afterAgentCallback,
AfterAgentCallback.class,
AfterAgentCallbackSync.class,
sync -> (callbackContext -> Maybe.fromOptional(sync.call(callbackContext))),
"afterAgentCallback");
}

private static <B, A extends B, S extends B> ImmutableList<A> getCallbacks(
List<B> callbacks,
Class<A> asyncClass,
Class<S> syncClass,
Function<S, A> converter,
String callbackTypeForLogging) {
if (callbacks == null) {
return ImmutableList.of();
} else {
ImmutableList.Builder<AfterAgentCallback> builder = ImmutableList.builder();
for (AfterAgentCallbackBase callback : afterAgentCallback) {
if (callback instanceof AfterAgentCallback afterAgentCallbackInstance) {
builder.add(afterAgentCallbackInstance);
} else if (callback instanceof AfterAgentCallbackSync afterAgentCallbackSyncInstance) {
builder.add(
(callbackContext) ->
Maybe.fromOptional(afterAgentCallbackSyncInstance.call(callbackContext)));
} else {
logger.warn(
"Invalid afterAgentCallback callback type: {}. Ignoring this callback.",
callback.getClass().getName());
}
}
return builder.build();
}
return callbacks.stream()
.flatMap(
callback -> {
if (asyncClass.isInstance(callback)) {
return Stream.of(asyncClass.cast(callback));
} else if (syncClass.isInstance(callback)) {
return Stream.of(converter.apply(syncClass.cast(callback)));
} else {
logger.warn(
"Invalid {} callback type: {}. Ignoring this callback.",
callbackTypeForLogging,
callback.getClass().getName());
return Stream.empty();
}
})
.collect(ImmutableList.toImmutableList());
}

private CallbackUtil() {}
Expand Down
5 changes: 2 additions & 3 deletions core/src/main/java/com/google/adk/agents/LlmAgent.java
Original file line number Diff line number Diff line change
Expand Up @@ -935,9 +935,8 @@ public Optional<String> outputKey() {
return outputKey;
}

@Nullable
public BaseCodeExecutor codeExecutor() {
return codeExecutor.orElse(null);
public Optional<BaseCodeExecutor> codeExecutor() {
return codeExecutor;
}

public Model resolvedModel() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -388,15 +388,15 @@ private Flowable<Event> runOneStep(InvocationContext context) {
String agentToTransfer = event.actions().transferToAgent().get();
logger.debug("Transferring to agent: {}", agentToTransfer);
BaseAgent rootAgent = context.agent().rootAgent();
BaseAgent nextAgent = rootAgent.findAgent(agentToTransfer);
if (nextAgent == null) {
Optional<BaseAgent> nextAgent = rootAgent.findAgent(agentToTransfer);
if (nextAgent.isEmpty()) {
String errorMsg = "Agent not found for transfer: " + agentToTransfer;
logger.error(errorMsg);
return postProcessedEvents.concatWith(
Flowable.error(new IllegalStateException(errorMsg)));
}
return postProcessedEvents.concatWith(
Flowable.defer(() -> nextAgent.runAsync(context)));
Flowable.defer(() -> nextAgent.get().runAsync(context)));
}
return postProcessedEvents;
});
Expand Down Expand Up @@ -574,14 +574,14 @@ public void onError(Throwable e) {
Flowable<Event> events = Flowable.just(event);
if (event.actions().transferToAgent().isPresent()) {
BaseAgent rootAgent = invocationContext.agent().rootAgent();
BaseAgent nextAgent =
Optional<BaseAgent> nextAgent =
rootAgent.findAgent(event.actions().transferToAgent().get());
if (nextAgent == null) {
if (nextAgent.isEmpty()) {
throw new IllegalStateException(
"Agent not found: " + event.actions().transferToAgent().get());
}
Flowable<Event> nextAgentEvents =
nextAgent.runLive(invocationContext);
nextAgent.get().runLive(invocationContext);
events = Flowable.concat(events, nextAgentEvents);
}
return events;
Expand Down
Loading