
Commit 1a9893f

Adds Serialization Support for QuantizationConfig-based quantized models (#21928)
* Adds serialization support for QuantizationConfig
* fix imports
* remove redundant config resolution logic
* move deserialization to from_config + restructure tests
* from_config should call super
1 parent 86bfab4 commit 1a9893f

14 files changed: 415 additions, 66 deletions
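
In practical terms, a layer's `quantization_config` now survives a `get_config()` / `from_config()` round trip. A minimal sketch of the new behavior; the import paths for `Int8QuantizationConfig` and `AbsMaxQuantizer` are assumptions based on the test file below:

    from keras import layers
    from keras.src.quantizers import AbsMaxQuantizer
    from keras.src.quantizers.quantization_config import Int8QuantizationConfig

    # Quantize a layer with a custom config...
    layer = layers.Dense(10)
    layer.build((None, 5))
    layer.quantize(
        "int8",
        config=Int8QuantizationConfig(weight_quantizer=AbsMaxQuantizer(axis=0)),
    )

    # ...and the config now round-trips through serialization.
    restored = layers.Dense.from_config(layer.get_config())
    assert isinstance(restored.quantization_config, Int8QuantizationConfig)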

keras/src/layers/core/dense.py

Lines changed: 28 additions & 10 deletions
@@ -12,8 +12,8 @@
 from keras.src.layers.input_spec import InputSpec
 from keras.src.layers.layer import Layer
 from keras.src.quantizers.quantization_config import QuantizationConfig
-from keras.src.quantizers.quantization_config import validate_and_resolve_config
 from keras.src.quantizers.quantizers import dequantize_with_sz_map
+from keras.src.saving import serialization_lib


 @keras_export("keras.layers.Dense")
@@ -94,6 +94,7 @@ def __init__(
         bias_constraint=None,
         lora_rank=None,
         lora_alpha=None,
+        quantization_config=None,
         **kwargs,
     ):
         if not isinstance(units, int) or units <= 0:
@@ -115,13 +116,18 @@ def __init__(
         self.lora_rank = lora_rank
         self.lora_alpha = lora_alpha if lora_alpha is not None else lora_rank
         self.lora_enabled = False
+        self.quantization_config = quantization_config
         self.input_spec = InputSpec(min_ndim=2)
         self.supports_masking = True

     def build(self, input_shape):
         kernel_shape = (input_shape[-1], self.units)
         if self.quantization_mode:
-            self.quantized_build(kernel_shape, mode=self.quantization_mode)
+            self.quantized_build(
+                kernel_shape,
+                mode=self.quantization_mode,
+                config=self.quantization_config,
+            )
         if self.quantization_mode not in ("int8", "int4", "gptq"):
             # If the layer is quantized to int8 or int4, `self._kernel` will be
             # added in `self._int8_build` or `_int4_build`. Therefore, we skip
@@ -330,12 +336,25 @@ def get_config(self):
             "bias_regularizer": regularizers.serialize(self.bias_regularizer),
             "kernel_constraint": constraints.serialize(self.kernel_constraint),
             "bias_constraint": constraints.serialize(self.bias_constraint),
+            "quantization_config": serialization_lib.serialize_keras_object(
+                self.quantization_config
+            ),
         }
         if self.lora_rank:
             config["lora_rank"] = self.lora_rank
             config["lora_alpha"] = self.lora_alpha
         return {**base_config, **config}

+    @classmethod
+    def from_config(cls, config):
+        config = config.copy()
+        config["quantization_config"] = (
+            serialization_lib.deserialize_keras_object(
+                config.get("quantization_config", None)
+            )
+        )
+        return super().from_config(config)
+
     @property
     def variable_serialization_spec(self):
         """Returns a dict mapping quantization modes to variable names in order.
@@ -777,27 +796,26 @@ def quantize(self, mode=None, type_check=True, config=None):
         if type_check and (type(self) is not Dense):
             raise self._not_implemented_error(self.quantize)

-        config = validate_and_resolve_config(mode, config)
-        mode = config.mode
+        self.quantization_config = config

         kernel_shape = self._kernel.shape
         if mode == "int8":
             weight_quantizer = QuantizationConfig.weight_quantizer_or_default(
-                config, quantizers.AbsMaxQuantizer(axis=0)
+                self.quantization_config, quantizers.AbsMaxQuantizer(axis=0)
             )
             kernel_value, kernel_scale = weight_quantizer(
                 self._kernel, to_numpy=True
             )
             kernel_scale = ops.squeeze(kernel_scale, axis=0)
             del self._kernel
             # Build variables for int8 mode
-            self.quantized_build(kernel_shape, mode, config)
+            self.quantized_build(kernel_shape, mode, self.quantization_config)
             self._kernel.assign(kernel_value)
             self.kernel_scale.assign(kernel_scale)
         elif mode == "int4":
             # 1. Quantize to int4 values (still int8 dtype, range [-8,7])
             weight_quantizer = QuantizationConfig.weight_quantizer_or_default(
-                config,
+                self.quantization_config,
                 quantizers.AbsMaxQuantizer(
                     axis=0, value_range=(-8, 7), output_dtype="int8"
                 ),
@@ -811,12 +829,12 @@ def quantize(self, mode=None, type_check=True, config=None):
             del self._kernel
             # Build variables using the original kernel shape; _int4_build will
             # compute the packed shape internally.
-            self.quantized_build(kernel_shape, mode, config)
+            self.quantized_build(kernel_shape, mode, self.quantization_config)
             # Assign packed values.
             self._kernel.assign(packed_kernel_value)
             self.kernel_scale.assign(kernel_scale)
         elif mode == "gptq":
-            self.quantized_build(kernel_shape, mode, config)
+            self.quantized_build(kernel_shape, mode, self.quantization_config)
         elif mode == "float8":
             self.quantized_build(kernel_shape, mode)
         else:
@@ -828,7 +846,7 @@ def quantize(self, mode=None, type_check=True, config=None):

         policy_name = mode
         if mode == "gptq":
-            policy_name = config.dtype_policy_string()
+            policy_name = self.quantization_config.dtype_policy_string()
         policy = dtype_policies.get(
             f"{policy_name}_from_{self.dtype_policy.name}"
         )
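
The net effect of these changes is that `quantize()` records the caller-supplied `config` on the instance before building, so `get_config()` can serialize it later. A hedged usage sketch, reusing the imports from the first example above:

    layer = layers.Dense(8)
    layer.build((None, 4))
    layer.quantize(
        "int8",
        config=Int8QuantizationConfig(weight_quantizer=AbsMaxQuantizer(axis=0)),
    )
    # The config is retained on the layer and included in its config dict.
    assert layer.quantization_config is not None
    assert "quantization_config" in layer.get_config()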

keras/src/layers/core/dense_test.py

Lines changed: 120 additions & 0 deletions
@@ -1025,3 +1025,123 @@ def test_gptq_kernel_packing(self):

         quantized_kernel_params = ops.prod(layer.quantized_kernel.shape)
         self.assertEqual(quantized_kernel_params, original_kernel_params // 2)
+
+    def _check_quantizer_config(
+        self, quantizer, valid_class, axis, value_range
+    ):
+        self.assertIsInstance(quantizer, valid_class)
+        self.assertEqual(quantizer.axis, axis)
+
+        # Check the value range only when one was provided.
+        if value_range is not None:
+            self.assertAllEqual(quantizer.value_range, value_range)
+
+    def test_dense_int8_custom_quantizer(self):
+        """
+        Test custom quantizer serialization for dense layer.
+        """
+        # Setup
+        weight_range = (-127, 127)
+        act_range = (-5, 5)
+        config = Int8QuantizationConfig(
+            weight_quantizer=AbsMaxQuantizer(axis=0, value_range=weight_range),
+            activation_quantizer=AbsMaxQuantizer(
+                axis=-1, value_range=act_range
+            ),
+        )
+
+        # Build & Quantize
+        layer = layers.Dense(10)
+        layer.build((None, 5))
+        layer.quantize("int8", config=config)
+
+        # Serialize & Deserialize
+        serialized = layer.get_config()
+        new_layer = layers.Dense.from_config(serialized)
+
+        # Verify
+        self.assertIsInstance(
+            new_layer.quantization_config, Int8QuantizationConfig
+        )
+        self._check_quantizer_config(
+            new_layer.quantization_config.weight_quantizer,
+            AbsMaxQuantizer,
+            axis=(0,),
+            value_range=weight_range,
+        )
+        self._check_quantizer_config(
+            new_layer.quantization_config.activation_quantizer,
+            AbsMaxQuantizer,
+            axis=(-1,),
+            value_range=act_range,
+        )
+
+    def test_dense_int8_weight_only_quantizer(self):
+        """
+        Test custom quantizer serialization for dense layer with
+        weight-only quantization.
+        """
+        # Setup
+        config = Int8QuantizationConfig(
+            weight_quantizer=AbsMaxQuantizer(axis=0),
+            activation_quantizer=None,
+        )
+
+        # Build & Quantize
+        layer = layers.Dense(10)
+        layer.build((None, 5))
+        layer.quantize("int8", config=config)
+
+        # Serialize & Deserialize
+        serialized = layer.get_config()
+        new_layer = layers.Dense.from_config(serialized)
+
+        # Verify
+        self.assertIsInstance(
+            new_layer.quantization_config, Int8QuantizationConfig
+        )
+        self.assertIsInstance(
+            new_layer.quantization_config.weight_quantizer, AbsMaxQuantizer
+        )
+        self.assertIsNone(new_layer.quantization_config.activation_quantizer)
+
+    def test_dense_int4_custom_quantizer(self):
+        """
+        Test custom quantizer serialization for dense layer with
+        int4 quantization.
+        """
+        # Setup
+        weight_range = (-8, 7)
+        act_range = (-2, 2)
+        config = Int4QuantizationConfig(
+            weight_quantizer=AbsMaxQuantizer(axis=0, value_range=weight_range),
+            activation_quantizer=AbsMaxQuantizer(
+                axis=-1, value_range=act_range
+            ),
+        )
+
+        # Build & Quantize
+        layer = layers.Dense(10)
+        layer.build((None, 5))
+        layer.quantize("int4", config=config)
+
+        # Serialize & Deserialize
+        serialized = layer.get_config()
+        new_layer = layers.Dense.from_config(serialized)
+
+        # Verify
+        self.assertIsInstance(
+            new_layer.quantization_config, Int4QuantizationConfig
+        )
+        self._check_quantizer_config(
+            new_layer.quantization_config.weight_quantizer,
+            AbsMaxQuantizer,
+            axis=(0,),
+            value_range=weight_range,
+        )
+        self._check_quantizer_config(
+            new_layer.quantization_config.activation_quantizer,
+            AbsMaxQuantizer,
+            axis=(-1,),
+            value_range=act_range,
+        )
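
The tests above exercise `get_config()` / `from_config()` directly; the same round trip should also hold through full-model saving. An end-to-end sketch, not part of this diff, assuming a Keras 3 environment with the `.keras` format:

    import keras

    inputs = keras.Input((5,))
    outputs = keras.layers.Dense(10)(inputs)
    model = keras.Model(inputs, outputs)
    model.quantize("int8")

    model.save("quantized.keras")
    restored = keras.saving.load_model("quantized.keras")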

keras/src/layers/core/einsum_dense.py

Lines changed: 31 additions & 4 deletions
@@ -18,6 +18,7 @@
 from keras.src.layers.layer import Layer
 from keras.src.quantizers.quantization_config import QuantizationConfig
 from keras.src.quantizers.quantizers import dequantize_with_sz_map
+from keras.src.saving import serialization_lib


 @keras_export("keras.layers.EinsumDense")
@@ -136,6 +137,7 @@ def __init__(
         lora_rank=None,
         lora_alpha=None,
         gptq_unpacked_column_size=None,
+        quantization_config=None,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -156,6 +158,7 @@ def __init__(
         self.lora_alpha = lora_alpha if lora_alpha is not None else lora_rank
         self.lora_enabled = False
         self.gptq_unpacked_column_size = gptq_unpacked_column_size
+        self.quantization_config = quantization_config

     def build(self, input_shape):
         shape_data = _analyze_einsum_string(
@@ -171,6 +174,7 @@ def build(self, input_shape):
             self.quantized_build(
                 kernel_shape,
                 mode=self.quantization_mode,
+                config=self.quantization_config,
             )
         # Skip creating a duplicate kernel variable when the layer is already
         # quantized to int8 or int4, because `quantized_build` has created the
@@ -394,6 +398,9 @@ def get_config(self):
             ),
             "kernel_constraint": constraints.serialize(self.kernel_constraint),
             "bias_constraint": constraints.serialize(self.bias_constraint),
+            "quantization_config": serialization_lib.serialize_keras_object(
+                self.quantization_config
+            ),
         }
         if self.lora_rank:
             config["lora_rank"] = self.lora_rank
@@ -402,6 +409,16 @@ def get_config(self):
             config["gptq_unpacked_column_size"] = self.gptq_unpacked_column_size
         return {**base_config, **config}

+    @classmethod
+    def from_config(cls, config):
+        config = config.copy()
+        config["quantization_config"] = (
+            serialization_lib.deserialize_keras_object(
+                config.get("quantization_config", None)
+            )
+        )
+        return super().from_config(config)
+
     @property
     def variable_serialization_spec(self):
         """Returns a dict mapping quantization modes to variable names in order.
@@ -465,6 +482,10 @@ def _int8_build(self, kernel_shape, config=None):
                 quantizers.AbsMaxQuantizer(axis=self._input_reduced_axes),
             )
         )
+        # If the config provided a default AbsMaxQuantizer, we need to
+        # override the axis to match the equation's reduction axes.
+        if isinstance(self.inputs_quantizer, quantizers.AbsMaxQuantizer):
+            self.inputs_quantizer.axis = tuple(self._input_reduced_axes)
         self._kernel = self.add_weight(
             name="kernel",
             shape=kernel_shape,
@@ -614,6 +635,10 @@ def _int4_build(self, kernel_shape, config=None):
                 quantizers.AbsMaxQuantizer(axis=self._input_reduced_axes),
             )
         )
+        # If the config provided a default AbsMaxQuantizer, we need to
+        # override the axis to match the equation's reduction axes.
+        if isinstance(self.inputs_quantizer, quantizers.AbsMaxQuantizer):
+            self.inputs_quantizer.axis = tuple(self._input_reduced_axes)

         # Choose the axis to perform int4 packing - use the first reduced axis
         # for the kernel (analogous to the input dimension of a Dense layer).
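
Why the new guard matters: a default `AbsMaxQuantizer` carried in a generic config reduces over `axis=0`, but the einsum equation dictates which input axes are actually contracted, so the build methods rewrite the quantizer's axis. A minimal sketch of the situation, with the behavioral claim inferred from the comments above:

    from keras import layers

    # For "ab,bc->ac" the contracted input axis is the last one, not axis 0.
    layer = layers.EinsumDense("ab,bc->ac", output_shape=(4,))
    layer.build((None, 3))
    layer.quantize("int8")
    # After building, the inputs quantizer reduces over the equation's
    # contracted axes rather than the default axis 0.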
@@ -980,14 +1005,16 @@ def quantize(self, mode=None, type_check=True, config=None):
         if type_check and (type(self) is not EinsumDense):
            raise self._not_implemented_error(self.quantize)

+        self.quantization_config = config
+
         kernel_shape = self._kernel.shape
         if mode in ("int8", "int4", "gptq"):
             self._set_quantization_info()

         if mode == "int8":
             # Quantize `self._kernel` to int8 and compute corresponding scale
             weight_quantizer = QuantizationConfig.weight_quantizer_or_default(
-                config,
+                self.quantization_config,
                 quantizers.AbsMaxQuantizer(axis=self._kernel_reduced_axes),
             )
             kernel_value, kernel_scale = weight_quantizer(
@@ -998,7 +1025,7 @@ def quantize(self, mode=None, type_check=True, config=None):
         elif mode == "int4":
             # Quantize to int4 values (stored in int8 dtype, range [-8, 7])
             weight_quantizer = QuantizationConfig.weight_quantizer_or_default(
-                config,
+                self.quantization_config,
                 quantizers.AbsMaxQuantizer(
                     axis=self._kernel_reduced_axes,
                     value_range=(-8, 7),
@@ -1017,7 +1044,7 @@ def quantize(self, mode=None, type_check=True, config=None):
             )
             kernel_value = packed_kernel_value
             del self._kernel
-        self.quantized_build(kernel_shape, mode, config)
+        self.quantized_build(kernel_shape, mode, self.quantization_config)

         # Assign values to the newly created variables.
         if mode in ("int8", "int4"):
@@ -1028,7 +1055,7 @@ def quantize(self, mode=None, type_check=True, config=None):
         if self.dtype_policy.quantization_mode is None:
             policy_name = mode
             if mode == "gptq":
-                policy_name = config.dtype_policy_string()
+                policy_name = self.quantization_config.dtype_policy_string()
             policy = dtype_policies.get(
                 f"{policy_name}_from_{self.dtype_policy.name}"
             )

0 commit comments