65 changes: 51 additions & 14 deletions .github/scripts/torchao_model_releases/quantize_and_upload.py
```diff
@@ -242,6 +242,29 @@ def _untie_weights_and_save_locally(model_id):
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 """
 
+_int8_int4_hqq_quant_code = """
+from torchao.quantization.quant_api import (
+    IntxWeightOnlyConfig,
+    Int8DynamicActivationIntxWeightConfig,
+    ModuleFqnToConfig,
+)
+from torchao.quantization.granularity import PerGroup, PerAxis
+embedding_config = IntxWeightOnlyConfig(
+    weight_dtype=torch.int8,
+    granularity=PerAxis(0),
+    intx_choose_qparams_algorithm="hqq_scale_only",
+)
+linear_config = Int8DynamicActivationIntxWeightConfig(
+    weight_dtype=torch.int4,
+    weight_granularity=PerGroup(32),
+    intx_choose_qparams_algorithm="hqq_scale_only",
+)
+quant_config = ModuleFqnToConfig({{"_default": linear_config, "model.embed_tokens": embedding_config}})
+quantization_config = TorchAoConfig(quant_type=quant_config, include_input_output_embeddings=True, modules_to_not_convert=[])
+quantized_model = AutoModelForCausalLM.from_pretrained(model_to_quantize, device_map="cuda:0", torch_dtype=torch.bfloat16, quantization_config=quantization_config)
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+"""
+
 _awq_int4_quant_code = """
 from torchao.quantization import Int4WeightOnlyConfig, quantize_
 from torchao.prototype.awq import (
```
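Note that `_int8_int4_hqq_quant_code` is a model-card template string, which is why the braces on the `ModuleFqnToConfig` line are doubled: they survive the later `MODEL_CARD.format(...)` call. For trying the same HQQ recipe outside the release script, here is a minimal standalone sketch; the model id and the widened `filter_fn` are assumptions for illustration, not part of this PR:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from torchao.quantization import quantize_
from torchao.quantization.quant_api import (
    IntxWeightOnlyConfig,
    Int8DynamicActivationIntxWeightConfig,
    ModuleFqnToConfig,
)
from torchao.quantization.granularity import PerGroup, PerAxis

model_id = "Qwen/Qwen2.5-0.5B"  # placeholder; any causal LM should work
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Same configs as the template above: HQQ picks the scales ("scale_only"),
# int8 per-axis weights for the embedding, int8-activation/int4-weight linears.
embedding_config = IntxWeightOnlyConfig(
    weight_dtype=torch.int8,
    granularity=PerAxis(0),
    intx_choose_qparams_algorithm="hqq_scale_only",
)
linear_config = Int8DynamicActivationIntxWeightConfig(
    weight_dtype=torch.int4,
    weight_granularity=PerGroup(32),
    intx_choose_qparams_algorithm="hqq_scale_only",
)

# filter_fn widened to Embedding so "model.embed_tokens" is visited; this is
# an assumption about driving quantize_ directly, not code from the PR.
quantize_(
    model,
    ModuleFqnToConfig({"_default": linear_config, "model.embed_tokens": embedding_config}),
    filter_fn=lambda m, fqn: isinstance(m, (torch.nn.Linear, torch.nn.Embedding)),
)

inputs = tokenizer("The capital of France is", return_tensors="pt")
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=16)[0]))
```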
```diff
@@ -589,14 +612,8 @@ def quantize_and_upload(
     push_to_user_id: str,
     populate_model_card_template: bool,
 ):
-    _int8_int4_linear_config = Int8DynamicActivationIntxWeightConfig(
-        weight_dtype=torch.int4,
-        weight_granularity=PerGroup(32),
-    )
-    _int8_int4_embedding_config = IntxWeightOnlyConfig(
-        weight_dtype=torch.int8,
-        granularity=PerAxis(0),
-    )
+    is_mobile = quant in ["INT8-INT4", "INT8-INT4-HQQ"]
+
     quant_to_config = {
         "FP8": Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),
         "INT4": Int4WeightOnlyConfig(
```
```diff
@@ -606,8 +623,28 @@
         ),
         "INT8-INT4": ModuleFqnToConfig(
             {
-                "_default": _int8_int4_linear_config,
-                "model.embed_tokens": _int8_int4_embedding_config,
+                "_default": Int8DynamicActivationIntxWeightConfig(
+                    weight_dtype=torch.int4,
+                    weight_granularity=PerGroup(32),
+                ),
+                "model.embed_tokens": IntxWeightOnlyConfig(
+                    weight_dtype=torch.int8,
+                    granularity=PerAxis(0),
+                ),
             }
         ),
+        "INT8-INT4-HQQ": ModuleFqnToConfig(
+            {
+                "_default": Int8DynamicActivationIntxWeightConfig(
+                    weight_dtype=torch.int4,
+                    weight_granularity=PerGroup(32),
+                    intx_choose_qparams_algorithm="hqq_scale_only",
+                ),
+                "model.embed_tokens": IntxWeightOnlyConfig(
+                    weight_dtype=torch.int8,
+                    granularity=PerAxis(0),
+                    intx_choose_qparams_algorithm="hqq_scale_only",
+                ),
+            }
+        ),
     }
```
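As a reading aid for the `_default` / FQN split above: the table lets one mapping drive both the linear layers and the embedding. The dispatch rule resolves per module roughly as sketched below; this is an illustrative reimplementation, not the PR's code, and the real lookup lives inside torchao's `ModuleFqnToConfig` handling:

```python
# Hypothetical sketch of the dispatch rule behind an FQN-keyed config table:
# an exact fully-qualified-name match wins, everything else gets "_default".
def resolve_config(fqn: str, table: dict):
    if fqn in table:
        return table[fqn]
    return table.get("_default")

table = {"_default": "linear_config", "model.embed_tokens": "embedding_config"}
assert resolve_config("model.layers.0.self_attn.q_proj", table) == "linear_config"
assert resolve_config("model.embed_tokens", table) == "embedding_config"
```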
```diff
@@ -616,12 +653,13 @@
         "FP8": _fp8_quant_code,
         "INT4": _int4_quant_code,
         "INT8-INT4": _int8_int4_quant_code,
+        "INT8-INT4-HQQ": _int8_int4_hqq_quant_code,
         "AWQ-INT4": _awq_int4_quant_code,
     }
 
     # preparation
     model_to_quantize = model_id
-    if quant == "INT8-INT4":
+    if is_mobile:
         model_to_quantize = _untie_weights_and_save_locally(model_to_quantize)
 
     # quantization
```
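The `is_mobile` gate now also covers the weight-untying step for the HQQ variant. The rationale, as far as this diff shows: both mobile paths give `model.embed_tokens` a different config than the `_default` used for linears such as `lm_head`, which is not possible while the two share one tensor, so `_untie_weights_and_save_locally` clones them apart first. A quick way to see whether a checkpoint ties them (illustrative snippet, not from this script; placeholder model id):

```python
import torch
from transformers import AutoModelForCausalLM

# Illustrative check: tied input/output embeddings share storage, which is
# the condition the untying step resolves before quantization.
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")  # placeholder id
tied = (
    model.get_input_embeddings().weight.data_ptr()
    == model.get_output_embeddings().weight.data_ptr()
)
print(f"embeddings tied: {tied}")
```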
```diff
@@ -666,7 +704,7 @@
         quant_config = quant_to_config[quant]
 
         torchao_config_kwargs = {}
-        if "INT8-INT4" in quant:
+        if is_mobile:
             torchao_config_kwargs["modules_to_not_convert"] = []
             torchao_config_kwargs["include_input_output_embeddings"] = True
 
```
```diff
@@ -688,7 +726,6 @@
     save_to_user_id = username if push_to_user_id is None else push_to_user_id
     save_to = f"{save_to_user_id}/{MODEL_NAME}-{quant}"
     untied_model_path = 'f"{{MODEL_NAME}}-untied-weights"'
-    is_mobile = quant == "INT8-INT4"
     quantized_model_id = save_to
     # model card
     content = MODEL_CARD.format(
```
```diff
@@ -775,7 +812,7 @@
     parser.add_argument(
         "--quant",
         type=str,
-        help="Quantization method. Options are FP8, INT4, INT8-INT4, AWQ-INT4",
+        help="Quantization method. Options are FP8, INT4, INT8-INT4, INT8-INT4-HQQ, AWQ-INT4",
     )
     parser.add_argument(
         "--tasks",
```