@@ -73,6 +73,23 @@ def abs_max_quantize(
7373 epsilon = backend .epsilon (),
7474 to_numpy = False ,
7575):
76+ """
77+ Quantizes the input tensor using the absolute maximum quantization scheme.
78+
79+ Args:
80+ inputs: Input tensor to quantize.
81+ axis: Axis along which to compute the quantization range.
82+ value_range: Tuple of the minimum and maximum values of the quantization
83+ range.
84+ dtype: Data type of the quantized output.
85+ epsilon: Small value to avoid division by zero.
86+ to_numpy: Whether to perform the quantization in numpy. This performs
87+ the computation on the host CPU and can be useful for saving memory
88+ on the device. If False, the computation is performed on the device.
89+
90+ Returns:
91+ A tuple of the quantized tensor and the scale.
92+ """
7693 if to_numpy :
7794 # Save memory on the device using numpy
7895 original_dtype = backend .standardize_dtype (inputs .dtype )
@@ -105,15 +122,18 @@ def abs_max_quantize(
105122class AbsMaxQuantizer (Quantizer ):
106123 def __init__ (
107124 self ,
108- axis ,
125+ axis = None , # Deprecated, provide axis in __call__ instead.
109126 value_range = (- 127 , 127 ),
110127 epsilon = backend .epsilon (),
111128 output_dtype = "int8" ,
112129 ):
113130 Quantizer .__init__ (self , output_dtype = output_dtype )
114- if isinstance (axis , int ):
115- axis = (axis ,)
116- self .axis = tuple (axis )
131+ if axis is not None :
132+ if isinstance (axis , int ):
133+ axis = (axis ,)
134+ self .axis = tuple (axis )
135+ else :
136+ self .axis = None
117137 self .value_range = value_range
118138 self .epsilon = epsilon
119139 if output_dtype == "int8" :
@@ -124,10 +144,31 @@ def __init__(
124144 f"value_range={ value_range } "
125145 )
126146
127- def __call__ (self , x , to_numpy = False ):
147+ def __call__ (self , x , axis = None , to_numpy = False ):
148+ """
149+ Quantizes the input tensor.
150+
151+ Args:
152+ x: Input tensor to quantize.
153+ axis: Axis along which to compute the quantization range. If None,
154+ uses the axis specified in the constructor. If None and no axis
155+ was specified in the constructor, defaults to -1.
156+ to_numpy: Whether to perform the quantization in numpy. This
157+ performs the computation on the host CPU and can be useful for
158+ saving memory on the device. If False, the computation is
159+ performed on the device.
160+
161+ Returns:
162+ A tuple of the quantized tensor and the scale.
163+ """
164+ if axis is None :
165+ axis = self .axis
166+ if axis is None :
167+ # Default to -1 if no axis is specified
168+ axis = - 1
128169 quantized_x , scale = abs_max_quantize (
129170 x ,
130- self . axis ,
171+ axis ,
131172 self .value_range ,
132173 self .output_dtype ,
133174 self .epsilon ,
@@ -136,12 +177,14 @@ def __call__(self, x, to_numpy=False):
136177 return quantized_x , scale
137178
138179 def get_config (self ):
139- return {
140- "axis" : self .axis ,
180+ config = {
141181 "value_range" : self .value_range ,
142182 "epsilon" : self .epsilon ,
143183 "output_dtype" : self .output_dtype ,
144184 }
185+ if self .axis is not None :
186+ config ["axis" ] = self .axis
187+ return config
145188
146189
147190def adjust_and_nudge (min_range , max_range , num_bits , narrow_range ):
0 commit comments