@@ -92,8 +92,17 @@ private static TF_Output toNativeOutputs(List<Operand<?>> outputs) {
         new TF_Output(Pointer.malloc((long) outputs.size() * Pointer.sizeof(TF_Output.class)));
 
     for (int i = 0; i < outputs.size(); ++i) {
-      var output = outputs.get(i).asOutput();
+      Operand<?> operand = outputs.get(i);
       var nativeOutput = nativeOutputs.getPointer(i);
+
+      // Convention: null Operand => NoGradient
+      if (operand == null) {
+        nativeOutput.oper((TF_Operation) null);
+        nativeOutput.index(0);
+        continue;
+      }
+
+      var output = operand.asOutput();
       nativeOutput.oper(((GraphOperation) output.op()).getUnsafeNativeHandle());
       nativeOutput.index(output.index());
     }
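
For gradient authors, the upshot of this convention is that a gradient function may now return null for an input that has no meaningful gradient, and the native side will see NoGradient instead of the NullPointerException the old conversion would have thrown. A minimal sketch (the op type and gradient are hypothetical):

    import java.util.Arrays;
    import org.tensorflow.TensorFlow;

    // Gradient for a hypothetical two-input op where only the first input is
    // differentiable: the null in slot 1 is converted to NoGradient by toNativeOutputs.
    TensorFlow.registerCustomGradient(
        "MyTwoInputOp",
        (tf, op, gradInputs) -> Arrays.asList(tf.identity(gradInputs.get(0)), null));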
TensorFlow.java:
@@ -39,6 +39,7 @@
 import org.tensorflow.internal.c_api.TF_Library;
 import org.tensorflow.internal.c_api.TF_Status;
 import org.tensorflow.op.CustomGradient;
+import org.tensorflow.op.GradientDispatch;
 import org.tensorflow.op.RawCustomGradient;
 import org.tensorflow.op.RawOpInputs;
 import org.tensorflow.op.annotation.OpInputsMetadata;
@@ -207,7 +208,10 @@ public static synchronized boolean registerCustomGradient(
     if (hasGradient(opType)) {
       return false;
     }
-    TFJ_GradFuncAdapter g = RawCustomGradient.adapter(gradient);
+
+    GradientDispatch.putRaw(opType, gradient);
+    TFJ_GradFuncAdapter g = GradientDispatch.adapter();
+
     if (!TFJ_RegisterCustomGradient(opType, g)) {
       return false;
     }
@@ -255,7 +259,10 @@ public static synchronized <T extends RawOpInputs<?>> boolean registerCustomGradient(
     if (hasGradient(opType)) {
       return false;
     }
-    TFJ_GradFuncAdapter g = CustomGradient.adapter(gradient, inputClass);
+
+    GradientDispatch.putTyped(opType, gradient, inputClass);
+    TFJ_GradFuncAdapter g = GradientDispatch.adapter();
+
     if (!TFJ_RegisterCustomGradient(opType, g)) {
       return false;
     }
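
From the caller's perspective the public API is unchanged: both overloads still return false when a gradient already exists, and true on success. What changed is the plumbing: the gradient is now recorded in GradientDispatch, and the same shared adapter is handed to the native registry for every op type. A minimal usage sketch of the raw overload (the op type string and the gradient math are illustrative assumptions, not part of this change):

    import java.util.Arrays;
    import org.tensorflow.Operand;
    import org.tensorflow.TensorFlow;
    import org.tensorflow.types.TFloat32;

    // Registers dL/dx = 2 * dy for a hypothetical single-input op "MyCustomOp".
    TensorFlow.registerCustomGradient(
        "MyCustomOp",
        (tf, op, gradInputs) -> {
          @SuppressWarnings("unchecked") // sketch only; real code should verify dtypes
          Operand<TFloat32> dy = (Operand<TFloat32>) gradInputs.get(0);
          return Arrays.asList(tf.math.mul(dy, tf.constant(2.0f)));
        });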
CustomGradient.java:
@@ -17,10 +17,15 @@
 package org.tensorflow.op;
 
 import java.util.List;
+import org.bytedeco.javacpp.PointerPointer;
 import org.tensorflow.Operand;
 import org.tensorflow.Output;
 import org.tensorflow.TensorFlow;
 import org.tensorflow.internal.c_api.TFJ_GradFuncAdapter;
+import org.tensorflow.internal.c_api.TFJ_GraphId;
+import org.tensorflow.internal.c_api.TFJ_Scope;
+import org.tensorflow.internal.c_api.TF_Operation;
+import org.tensorflow.internal.c_api.TF_Output;
 
 /**
  * A custom gradient for ops of type {@link T}. Should be registered using {@link
@@ -57,6 +62,31 @@ public interface CustomGradient<T extends RawOpInputs> {
    */
   static <T extends RawOpInputs<?>> TFJ_GradFuncAdapter adapter(
       CustomGradient<T> gradient, Class<T> opClass) {
-    return new TypedGradientAdapter<T>(gradient, opClass);
+
+    final TypedGradientAdapter<T> impl = new TypedGradientAdapter<T>(gradient, opClass);
+
+    // IMPORTANT:
+    // Return a *direct* TFJ_GradFuncAdapter subclass, so JavaCPP reliably materializes a function
+    // pointer thunk for the native side. Some call paths may pass NULL if we return a deeper
+    // subclass.
+    return new TFJ_GradFuncAdapter() {
+      @Override
+      public int call(
+          TFJ_GraphId nativeGraphId,
+          TFJ_Scope nativeScope,
+          TF_Operation nativeOperation,
+          TF_Output nativeGradInputs,
+          int nativeGradInputsLength,
+          PointerPointer nativeGradOutputsPtr) {
+
+        return impl.call(
+            nativeGraphId,
+            nativeScope,
+            nativeOperation,
+            nativeGradInputs,
+            nativeGradInputsLength,
+            nativeGradOutputsPtr);
+      }
+    };
   }
 }
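
For context, a typed gradient that flows through this adapter might look like the following sketch. It assumes a generated inputs class in the shape the dispatcher below requires (a public constructor taking a GraphOperation, input fields like `x`); Square.Inputs and the Class-based overload of TensorFlow.registerCustomGradient are used illustratively, and registration would still return false if the op already has a native gradient:

    import java.util.Arrays;
    import org.tensorflow.Operand;
    import org.tensorflow.TensorFlow;
    import org.tensorflow.op.CustomGradient;
    import org.tensorflow.op.math.Square;
    import org.tensorflow.types.TFloat32;

    // d(x^2)/dx = 2 * x * dy, returned in input order (one entry per op input).
    CustomGradient<Square.Inputs> squareGrad =
        (tf, inputs, gradInputs) -> {
          @SuppressWarnings("unchecked")
          Operand<TFloat32> x = (Operand<TFloat32>) inputs.x;
          @SuppressWarnings("unchecked")
          Operand<TFloat32> dy = (Operand<TFloat32>) gradInputs.get(0);
          return Arrays.asList(tf.math.mul(tf.constant(2.0f), tf.math.mul(x, dy)));
        };

    TensorFlow.registerCustomGradient(Square.Inputs.class, squareGrad);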
DispatchingGradientAdapter.java (new file):
@@ -0,0 +1,145 @@
/* Copyright 2026 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=======================================================================
*/
package org.tensorflow.op;

import java.lang.reflect.Constructor;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import org.tensorflow.AbstractGradientAdapter;
import org.tensorflow.Graph;
import org.tensorflow.GraphOperation;
import org.tensorflow.Operand;
import org.tensorflow.Output;
import org.tensorflow.internal.c_api.TFJ_Scope;

/**
 * Dispatching adapter for Java-side custom gradient registration.
 *
 * <p>This class mirrors the behavior of TensorFlow Python's {@code tf.RegisterGradient} mechanism
 * by providing a centralized dispatch layer for custom gradients in the Java API.
 *
 * <p>Gradients may be registered in one of two forms for a given op type:
 *
 * <ul>
 *   <li>A raw gradient ({@link RawCustomGradient}) operating directly on {@link GraphOperation}
 *       and {@link Output} objects.
 *   <li>A typed gradient ({@link CustomGradient}) operating on generated {@link RawOpInputs}
 *       subclasses.
 * </ul>
 *
 * <p>For any given op type, exactly one gradient definition is permitted: either raw or typed.
 * Duplicate registrations, or attempts to mix raw and typed gradients for the same op type, are
 * rejected to prevent ambiguous dispatch behavior.
 *
 * <p>At runtime, {@link #apply(Graph, TFJ_Scope, GraphOperation, List)} determines the operation
 * type and dispatches to the corresponding registered gradient implementation.
 */
final class DispatchingGradientAdapter extends AbstractGradientAdapter {

  private final ConcurrentMap<String, RawCustomGradient> raw = new ConcurrentHashMap<>();
  private final ConcurrentMap<String, TypedEntry<?>> typed = new ConcurrentHashMap<>();

  private static String dupMsg(String opType, String existingKind, String newKind) {
    return "A "
        + existingKind
        + " gradient is already registered for op type '"
        + opType
        + "'. Raw and typed registrations are mutually exclusive; cannot register "
        + newKind
        + ".";
  }

  static final class TypedEntry<T extends RawOpInputs<?>> {
    final CustomGradient<T> grad;
    final Class<T> inputClass;
    final Constructor<T> ctor;

    TypedEntry(CustomGradient<T> grad, Class<T> inputClass) {
      this.grad = grad;
      this.inputClass = inputClass;
      try {
        this.ctor = inputClass.getConstructor(GraphOperation.class);
      } catch (NoSuchMethodException e) {
        throw new IllegalArgumentException(
            "Inputs class " + inputClass.getName() + " must have a public ctor(GraphOperation).",
            e);
      }
    }
  }

  void putRaw(String opType, RawCustomGradient g) {
    if (typed.containsKey(opType)) {
      throw new IllegalStateException(dupMsg(opType, "typed", "raw"));
    }
    RawCustomGradient prev = raw.putIfAbsent(opType, g);
    if (prev != null) {
      throw new IllegalStateException(
          "A raw gradient is already registered for op type '" + opType + "'.");
    }
  }

  <T extends RawOpInputs<?>> void putTyped(
      String opType, CustomGradient<T> g, Class<T> inputClass) {
    if (raw.containsKey(opType)) {
      throw new IllegalStateException(dupMsg(opType, "raw", "typed"));
    }
    TypedEntry<?> prev = typed.putIfAbsent(opType, new TypedEntry<>(g, inputClass));
    if (prev != null) {
      throw new IllegalStateException(
          "A typed gradient is already registered for op type '" + opType + "'.");
    }
  }

  @Override
  protected List<Operand<?>> apply(
      Graph graph, TFJ_Scope scope, GraphOperation operation, List<Output<?>> gradInputs) {

    final String opType = operation.type();

    RawCustomGradient rg = raw.get(opType);
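    // Raw is checked first; putRaw/putTyped guarantee at most one kind exists per op type,
    // so this ordering can never silently prefer one registration over another.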
[Review comment] Collaborator: This logic prefers raw gradients over typed ones, but there isn't anything documented about why it prefers them, or whether it makes sense to add both raw and typed gradients for the same op. It would be good to clarify this, and if it doesn't make sense to have both kinds of gradients, the adapter should reject them in the puts.

[Review comment] Collaborator: Can you add javadoc to the top of this class noting its overall purpose (to provide Java-side dispatching for gradients, mirroring TF-Python), that it only accepts either raw or typed gradients for a given op, and that it rejects duplicate registrations?
    if (rg != null) {
      // NativeScope & Ops constructors are package-private => must be in org.tensorflow.op
      Scope nativeScope =
          new NativeScope(scope, graph, operation.name()).withSubScope(operation.name());
      return rg.call(new Ops(nativeScope), operation, gradInputs);
    }

    @SuppressWarnings("rawtypes")
    TypedEntry te = typed.get(opType);
    if (te != null) {
      return applyTyped(graph, scope, operation, gradInputs, te);
    }

    throw new IllegalStateException("No Java custom gradient registered for op type: " + opType);
  }

  private <T extends RawOpInputs<?>> List<Operand<?>> applyTyped(
      Graph graph,
      TFJ_Scope scope,
      GraphOperation operation,
      List<Output<?>> gradInputs,
      TypedEntry<T> te) {
    try {
      T inputs = te.ctor.newInstance(operation);
      Scope nativeScope =
          new NativeScope(scope, graph, operation.name()).withSubScope(operation.name());
      return te.grad.call(new Ops(nativeScope), inputs, gradInputs);
    } catch (ReflectiveOperationException e) {
      throw new RuntimeException("Failed to instantiate inputs for " + te.inputClass.getName(), e);
    }
  }
}
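
The registration rules can be exercised directly against the adapter, as a sketch (this is package-private API, so it only compiles from within org.tensorflow.op; the op type and gradients are hypothetical):

    DispatchingGradientAdapter adapter = new DispatchingGradientAdapter();

    // First registration for an op type wins; this raw gradient just passes the
    // upstream gradients through unchanged.
    adapter.putRaw("MyOp", (tf, op, grads) -> new java.util.ArrayList<>(grads));

    // A second adapter.putRaw("MyOp", ...) now throws IllegalStateException, and so
    // does adapter.putTyped("MyOp", someGrad, SomeInputs.class): raw and typed
    // registrations are mutually exclusive per op type.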
GradientDispatch.java (new file):
@@ -0,0 +1,40 @@
/* Copyright 2026 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=======================================================================
*/
package org.tensorflow.op;

import org.tensorflow.internal.c_api.TFJ_GradFuncAdapter;

/** Public bridge to a single native gradient adapter. */
public final class GradientDispatch {

  // package-private adapter that can access NativeScope/Ops constructors
  static final DispatchingGradientAdapter ADAPTER = new DispatchingGradientAdapter();

  private GradientDispatch() {}

  public static TFJ_GradFuncAdapter adapter() {
    return ADAPTER;
  }

  public static void putRaw(String opType, RawCustomGradient gradient) {
    ADAPTER.putRaw(opType, gradient);
  }

  public static <T extends RawOpInputs<?>> void putTyped(
      String opType, CustomGradient<T> gradient, Class<T> inputClass) {
    ADAPTER.putTyped(opType, gradient, inputClass);
  }
}
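
Both TensorFlow.registerCustomGradient overloads route through these wrappers, so every op type shares the single ADAPTER instance, and therefore a single native function pointer; per-op dispatch happens in Java. A sketch of the internal flow, mirroring the registerCustomGradient hunks above (minus the hasGradient guard and the native TFJ_RegisterCustomGradient call; the op type is hypothetical):

    // Record the gradient, keyed by op type...
    GradientDispatch.putRaw("MyOp", (tf, op, grads) -> new java.util.ArrayList<>(grads));

    // ...then hand the shared adapter to the native registry. adapter() returns the
    // same DispatchingGradientAdapter instance on every call, for every op type.
    TFJ_GradFuncAdapter g = GradientDispatch.adapter();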
RawCustomGradient.java:
@@ -17,11 +17,16 @@
 package org.tensorflow.op;
 
 import java.util.List;
+import org.bytedeco.javacpp.PointerPointer;
 import org.tensorflow.GraphOperation;
 import org.tensorflow.Operand;
 import org.tensorflow.Output;
 import org.tensorflow.TensorFlow;
 import org.tensorflow.internal.c_api.TFJ_GradFuncAdapter;
+import org.tensorflow.internal.c_api.TFJ_GraphId;
+import org.tensorflow.internal.c_api.TFJ_Scope;
+import org.tensorflow.internal.c_api.TF_Operation;
+import org.tensorflow.internal.c_api.TF_Output;
 
 /**
  * A custom gradient for an op of unspecified type. Should be registered using {@link
@@ -54,6 +59,30 @@ public interface RawCustomGradient {
    * TensorFlow#registerCustomGradient(String, RawCustomGradient)}.
    */
   static TFJ_GradFuncAdapter adapter(RawCustomGradient gradient) {
-    return new RawGradientAdapter(gradient);
+    final RawGradientAdapter impl = new RawGradientAdapter(gradient);
+
+    // IMPORTANT:
+    // Return a *direct* TFJ_GradFuncAdapter subclass, so JavaCPP reliably materializes a function
+    // pointer thunk for the native side. Some call paths may pass NULL if we return a deeper
+    // subclass.
+    return new TFJ_GradFuncAdapter() {
+      @Override
+      public int call(
+          TFJ_GraphId nativeGraphId,
+          TFJ_Scope nativeScope,
+          TF_Operation nativeOperation,
+          TF_Output nativeGradInputs,
+          int nativeGradInputsLength,
+          PointerPointer nativeGradOutputsPtr) {
+
+        return impl.call(
+            nativeGraphId,
+            nativeScope,
+            nativeOperation,
+            nativeGradInputs,
+            nativeGradInputsLength,
+            nativeGradOutputsPtr);
+      }
+    };
   }
 }