Skip to content

Commit ca23fce

Browse files
Refactors AbsMaxQuantizer to accept axis in __call__ (#21931)
1 parent 1a9893f commit ca23fce

File tree

8 files changed

+73
-31
lines changed

8 files changed

+73
-31
lines changed

keras/src/layers/core/dense.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -413,7 +413,7 @@ def quantized_build(self, kernel_shape, mode, config=None):
413413
def _int8_build(self, kernel_shape, config=None):
414414
self.inputs_quantizer = (
415415
QuantizationConfig.activation_quantizer_or_default(
416-
config, quantizers.AbsMaxQuantizer(axis=-1)
416+
config, quantizers.AbsMaxQuantizer()
417417
)
418418
)
419419

@@ -526,7 +526,7 @@ def _int4_build(self, kernel_shape, config=None):
526526
# Per-channel int8 quantizer for the last axis (features).
527527
self.inputs_quantizer = (
528528
QuantizationConfig.activation_quantizer_or_default(
529-
config, quantizers.AbsMaxQuantizer(axis=-1)
529+
config, quantizers.AbsMaxQuantizer()
530530
)
531531
)
532532
input_dim, output_dim = kernel_shape
@@ -618,7 +618,7 @@ def grad_fn(*args, upstream=None):
618618

619619
output_scale = kernel_scale
620620
if self.inputs_quantizer:
621-
inputs, inputs_scale = self.inputs_quantizer(inputs)
621+
inputs, inputs_scale = self.inputs_quantizer(inputs, axis=-1)
622622
output_scale = ops.multiply(output_scale, inputs_scale)
623623

624624
x = ops.matmul(inputs, kernel)
@@ -674,7 +674,7 @@ def grad_fn(*args, upstream=None):
674674
output_scale = kernel_scale
675675

676676
if self.inputs_quantizer:
677-
inputs, inputs_scale = self.inputs_quantizer(inputs)
677+
inputs, inputs_scale = self.inputs_quantizer(inputs, axis=-1)
678678
output_scale = ops.multiply(output_scale, inputs_scale)
679679

680680
x = ops.matmul(inputs, unpacked_kernel)

keras/src/layers/core/dense_test.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
class DenseTest(testing.TestCase):
2626
@parameterized.named_parameters(
27-
("int8", "int8", {"axis": 0}, {"axis": -1}),
27+
("int8", "int8", {"axis": 0}, {}),
2828
(
2929
"int4",
3030
"int4",
@@ -62,7 +62,6 @@ def test_dense_quantize_config(
6262
if activation_quantizer_args is not None:
6363
# Verify inputs_quantizer is set correctly
6464
self.assertIsInstance(layer.inputs_quantizer, AbsMaxQuantizer)
65-
self.assertEqual(layer.inputs_quantizer.axis, (-1,))
6665
else:
6766
# Verify inputs_quantizer is None
6867
self.assertIsNone(layer.inputs_quantizer)

keras/src/layers/core/einsum_dense.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -479,13 +479,12 @@ def _int8_build(self, kernel_shape, config=None):
479479
self.inputs_quantizer = (
480480
QuantizationConfig.activation_quantizer_or_default(
481481
config,
482-
quantizers.AbsMaxQuantizer(axis=self._input_reduced_axes),
482+
quantizers.AbsMaxQuantizer(),
483483
)
484484
)
485485
# Record the equation's input reduction axes so they can be passed
486486
# as the `axis` argument when the inputs quantizer is called.
487-
if isinstance(self.inputs_quantizer, quantizers.AbsMaxQuantizer):
488-
self.inputs_quantizer.axis = tuple(self._input_reduced_axes)
487+
self.quantization_axis = tuple(self._input_reduced_axes)
489488
self._kernel = self.add_weight(
490489
name="kernel",
491490
shape=kernel_shape,
@@ -632,13 +631,12 @@ def _int4_build(self, kernel_shape, config=None):
632631
self.inputs_quantizer = (
633632
QuantizationConfig.activation_quantizer_or_default(
634633
config,
635-
quantizers.AbsMaxQuantizer(axis=self._input_reduced_axes),
634+
quantizers.AbsMaxQuantizer(),
636635
)
637636
)
638637
# Record the equation's input reduction axes so they can be passed
639638
# as the `axis` argument when the inputs quantizer is called.
640-
if isinstance(self.inputs_quantizer, quantizers.AbsMaxQuantizer):
641-
self.inputs_quantizer.axis = tuple(self._input_reduced_axes)
639+
self.quantization_axis = tuple(self._input_reduced_axes)
642640

643641
# Choose the axis to perform int4 packing - use the first reduced axis
644642
# for the kernel (analogous to the input dimension of a Dense layer).
@@ -761,7 +759,9 @@ def grad_fn(*args, upstream=None):
761759
return (inputs_grad, None, None)
762760

763761
if self.inputs_quantizer:
764-
inputs, inputs_scale = self.inputs_quantizer(inputs)
762+
inputs, inputs_scale = self.inputs_quantizer(
763+
inputs, axis=self.quantization_axis
764+
)
765765
# Align `inputs_scale` axes with the output
766766
# for correct broadcasting
767767
inputs_scale = self._adjust_scale_for_quant(
@@ -858,7 +858,9 @@ def grad_fn(*args, upstream=None):
858858

859859
# Quantize inputs per `self.inputs_quantizer`.
860860
if self.inputs_quantizer:
861-
inputs_q, inputs_scale = self.inputs_quantizer(inputs)
861+
inputs_q, inputs_scale = self.inputs_quantizer(
862+
inputs, axis=self.quantization_axis
863+
)
862864
# Align `inputs_scale` axes with the output
863865
# for correct broadcasting
864866
inputs_scale = self._adjust_scale_for_quant(

keras/src/layers/core/einsum_dense_test.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@ def test_einsum_dense_quantize(
7171
if activation_quantizer_args is not None:
7272
# Verify inputs_quantizer is set correctly
7373
self.assertIsInstance(layer.inputs_quantizer, AbsMaxQuantizer)
74-
self.assertEqual(layer.inputs_quantizer.axis, (1,))
7574
else:
7675
# Verify inputs_quantizer is None
7776
self.assertIsNone(layer.inputs_quantizer)

keras/src/layers/core/reversible_embedding_test.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,5 +258,4 @@ def test_reversible_embedding_int8_custom_quantizer(self):
258258
)
259259
quantizer = new_layer.quantization_config.weight_quantizer
260260
self.assertIsInstance(quantizer, AbsMaxQuantizer)
261-
self.assertEqual(quantizer.axis, (-1,))
262261
self.assertAllEqual(quantizer.value_range, weight_range)

keras/src/quantizers/quantization_config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def __init__(self, weight_quantizer=None, activation_quantizer="default"):
7676
from keras.src.quantizers.quantizers import AbsMaxQuantizer
7777

7878
if activation_quantizer == "default":
79-
activation_quantizer = AbsMaxQuantizer(axis=-1)
79+
activation_quantizer = AbsMaxQuantizer()
8080
super().__init__(weight_quantizer, activation_quantizer)
8181
if self.weight_quantizer is not None:
8282
if self.weight_quantizer.output_dtype != "int8":
@@ -105,7 +105,7 @@ def __init__(self, weight_quantizer=None, activation_quantizer="default"):
105105
from keras.src.quantizers.quantizers import AbsMaxQuantizer
106106

107107
if activation_quantizer == "default":
108-
activation_quantizer = AbsMaxQuantizer(axis=-1)
108+
activation_quantizer = AbsMaxQuantizer()
109109
super().__init__(weight_quantizer, activation_quantizer)
110110
if self.weight_quantizer is not None:
111111
if self.weight_quantizer.value_range != (-8, 7):

keras/src/quantizers/quantizers.py

Lines changed: 51 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,23 @@ def abs_max_quantize(
7373
epsilon=backend.epsilon(),
7474
to_numpy=False,
7575
):
76+
"""
77+
Quantizes the input tensor using the absolute maximum quantization scheme.
78+
79+
Args:
80+
inputs: Input tensor to quantize.
81+
axis: Axis along which to compute the quantization range.
82+
value_range: Tuple of the minimum and maximum values of the quantization
83+
range.
84+
dtype: Data type of the quantized output.
85+
epsilon: Small value to avoid division by zero.
86+
to_numpy: Whether to perform the quantization in numpy. This performs
87+
the computation on the host CPU and can be useful for saving memory
88+
on the device. If False, the computation is performed on the device.
89+
90+
Returns:
91+
A tuple of the quantized tensor and the scale.
92+
"""
7693
if to_numpy:
7794
# Save memory on the device using numpy
7895
original_dtype = backend.standardize_dtype(inputs.dtype)
@@ -105,15 +122,18 @@ def abs_max_quantize(
105122
class AbsMaxQuantizer(Quantizer):
106123
def __init__(
107124
self,
108-
axis,
125+
axis=None, # Deprecated, provide axis in __call__ instead.
109126
value_range=(-127, 127),
110127
epsilon=backend.epsilon(),
111128
output_dtype="int8",
112129
):
113130
Quantizer.__init__(self, output_dtype=output_dtype)
114-
if isinstance(axis, int):
115-
axis = (axis,)
116-
self.axis = tuple(axis)
131+
if axis is not None:
132+
if isinstance(axis, int):
133+
axis = (axis,)
134+
self.axis = tuple(axis)
135+
else:
136+
self.axis = None
117137
self.value_range = value_range
118138
self.epsilon = epsilon
119139
if output_dtype == "int8":
@@ -124,10 +144,31 @@ def __init__(
124144
f"value_range={value_range}"
125145
)
126146

127-
def __call__(self, x, to_numpy=False):
147+
def __call__(self, x, axis=None, to_numpy=False):
148+
"""
149+
Quantizes the input tensor.
150+
151+
Args:
152+
x: Input tensor to quantize.
153+
axis: Axis along which to compute the quantization range. If None,
154+
uses the axis specified in the constructor. If None and no axis
155+
was specified in the constructor, defaults to -1.
156+
to_numpy: Whether to perform the quantization in numpy. This
157+
performs the computation on the host CPU and can be useful for
158+
saving memory on the device. If False, the computation is
159+
performed on the device.
160+
161+
Returns:
162+
A tuple of the quantized tensor and the scale.
163+
"""
164+
if axis is None:
165+
axis = self.axis
166+
if axis is None:
167+
# Default to -1 if no axis is specified
168+
axis = -1
128169
quantized_x, scale = abs_max_quantize(
129170
x,
130-
self.axis,
171+
axis,
131172
self.value_range,
132173
self.output_dtype,
133174
self.epsilon,
@@ -136,12 +177,14 @@ def __call__(self, x, to_numpy=False):
136177
return quantized_x, scale
137178

138179
def get_config(self):
139-
return {
140-
"axis": self.axis,
180+
config = {
141181
"value_range": self.value_range,
142182
"epsilon": self.epsilon,
143183
"output_dtype": self.output_dtype,
144184
}
185+
if self.axis is not None:
186+
config["axis"] = self.axis
187+
return config
145188

146189

147190
def adjust_and_nudge(min_range, max_range, num_bits, narrow_range):

keras/src/quantizers/quantizers_test.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
class QuantizersTest(testing.TestCase):
1919
def test_get_method(self):
20-
quantizer = quantizers.get("abs_max_quantizer", axis=-1)
20+
quantizer = quantizers.get("abs_max_quantizer")
2121
self.assertTrue(quantizer, quantizers.AbsMaxQuantizer)
2222

2323
quantizer = quantizers.get(None)
@@ -28,10 +28,10 @@ def test_get_method(self):
2828

2929
def test_abs_max_quantizer(self):
3030
values = random.uniform([3, 4, 5], minval=-1, maxval=1, dtype="float32")
31-
quantizer = quantizers.AbsMaxQuantizer(axis=-1)
31+
quantizer = quantizers.AbsMaxQuantizer()
3232

3333
# Test quantizing
34-
quantized_values, scale = quantizer(values)
34+
quantized_values, scale = quantizer(values, axis=-1)
3535
self.assertDType(quantized_values, "int8")
3636
self.assertDType(scale, "float32")
3737
self.assertEqual(tuple(quantized_values.shape), (3, 4, 5))
@@ -53,11 +53,11 @@ def test_abs_max_quantizer(self):
5353
values = random.uniform(
5454
[3, 4, 5], minval=-1, maxval=1, dtype="bfloat16"
5555
)
56-
quantized_values, scale = quantizer(values)
56+
quantized_values, scale = quantizer(values, axis=-1)
5757
self.assertDType(quantized_values, "int8")
5858
self.assertDType(scale, "bfloat16")
5959
values = random.uniform([3, 4, 5], minval=-1, maxval=1, dtype="float16")
60-
quantized_values, scale = quantizer(values)
60+
quantized_values, scale = quantizer(values, axis=-1)
6161
self.assertDType(quantized_values, "int8")
6262
self.assertDType(scale, "float16")
6363

0 commit comments

Comments
 (0)