"""Darwin-35B-A3B-Opus AWQ-INT4 custom classes (NOESIS v14.7). Requires trust_remote_code=True. INT4 everywhere: all nn.Linear → Linear4bit (nibble uint8). Experts: Darwin35BExpertsInt4 (same nibble format, 3D tensors). lm_head: BF16 (AWQ standard — output projection kept full precision). """ from __future__ import annotations import gc import torch import torch.nn as nn import torch.nn.functional as F from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import ( Qwen3_5MoeForCausalLM, Qwen3_5MoeTextConfig, ) _SKIP_QUANT = {"in_proj_a", "in_proj_b", "lm_head"} # SSM matrices + lm_head → BF16 def _dequant_nibble_2d(weight_i4: torch.Tensor, weight_scale_i4: torch.Tensor, group_size: int = 128) -> torch.Tensor: """weight_i4 [out, in//2] uint8 + weight_scale_i4 [n_groups, out] fp16 → bf16 [out, in].""" out_f, packed = weight_i4.shape in_f = packed * 2 n_groups = weight_scale_i4.shape[0] lo = (weight_i4 & 0x0F).to(torch.int32) hi = (weight_i4 >> 4).to(torch.int32) W = torch.stack([lo, hi], dim=-1).reshape(out_f, in_f).float() - 8.0 W = W.reshape(out_f, n_groups, in_f // n_groups) s = weight_scale_i4.T.float().unsqueeze(-1) # [out, n_groups, 1] return (W * s).reshape(out_f, in_f).to(torch.bfloat16) def _dequant_nibble_3d(q4: torch.Tensor, scales: torch.Tensor, group_size: int = 128) -> torch.Tensor: """q4 [n_exp, out, in//2] uint8 + scales [n_exp, out, n_groups] fp16 → bf16.""" n_exp, out_f, packed = q4.shape in_f = packed * 2 n_groups = scales.shape[-1] lo = (q4 & 0x0F).to(torch.int32) hi = (q4 >> 4).to(torch.int32) W = torch.stack([lo, hi], dim=-1).reshape(n_exp, out_f, in_f).float() - 8.0 W = W.reshape(n_exp, out_f, n_groups, in_f // n_groups) return (W * scales.float().unsqueeze(-1)).reshape(n_exp, out_f, in_f).to(torch.bfloat16) class Linear4bit(nn.Module): """AWQ-INT4 linear layer (nibble-packed uint8, dequantize per forward).""" def __init__(self, in_features: int, out_features: int, bias: bool = False, group_size: int = 128): super().__init__() self.in_features = in_features self.out_features = out_features self.group_size = group_size ng = in_features // group_size self.register_buffer("weight_i4", torch.zeros(out_features, in_features // 2, dtype=torch.uint8)) self.register_buffer("weight_scale_i4", torch.zeros(ng, out_features, dtype=torch.float16)) if bias: self.register_buffer("bias_buf", torch.zeros(out_features)) else: self.bias_buf = None def forward(self, x: torch.Tensor) -> torch.Tensor: W = _dequant_nibble_2d(self.weight_i4, self.weight_scale_i4, self.group_size) return F.linear(x.to(torch.bfloat16), W, self.bias_buf) def _replace_linear_int4(module: nn.Module, group_size: int = 128, _path: str = "") -> None: """Recursively replace qualifying nn.Linear with Linear4bit. Skips SSM state matrices (in_proj_a, in_proj_b) — stored as BF16.""" for name, child in list(module.named_children()): full = f"{_path}.{name}" if _path else name if any(s in full for s in _SKIP_QUANT): continue if (isinstance(child, nn.Linear) and min(child.in_features, child.out_features) >= group_size and not isinstance(child, Linear4bit)): setattr(module, name, Linear4bit(child.in_features, child.out_features, child.bias is not None, group_size)) else: _replace_linear_int4(child, group_size, full) class Darwin35BExpertsInt4(nn.Module): """AWQ-INT4 MoE expert block (nibble uint8, dequantize per forward).""" def __init__(self, config: Qwen3_5MoeTextConfig): super().__init__() self.config = config self.act_fn = nn.SiLU() self.register_buffer("gate_up_proj_q4", torch.empty(0, dtype=torch.uint8)) self.register_buffer("gate_up_proj_scales", torch.empty(0, dtype=torch.float16)) self.register_buffer("gate_up_proj_zeros", torch.empty(0, dtype=torch.int8)) self.register_buffer("down_proj_q4", torch.empty(0, dtype=torch.uint8)) self.register_buffer("down_proj_scales", torch.empty(0, dtype=torch.float16)) self.register_buffer("down_proj_zeros", torch.empty(0, dtype=torch.int8)) self._group_size = 128 @property def gate_up_proj(self) -> torch.Tensor: return _dequant_nibble_3d(self.gate_up_proj_q4, self.gate_up_proj_scales, self._group_size) @property def down_proj(self) -> torch.Tensor: return _dequant_nibble_3d(self.down_proj_q4, self.down_proj_scales, self._group_size) def forward(self, hidden_states, routing_weights, selected_experts, batch_size, sequence_length, hidden_dim): gup = self.gate_up_proj.to(hidden_states.dtype) down = self.down_proj.to(hidden_states.dtype) out = torch.zeros((batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device) for i in range(self.config.num_experts): mask = (selected_experts == i) if not mask.any(): continue idx = mask.nonzero(as_tuple=True) gate, up = (hidden_states[idx] @ gup[i].T).chunk(2, dim=-1) out.index_add_(0, idx[0], (self.act_fn(gate) * up) @ down[i].T * routing_weights[idx].unsqueeze(-1)) return out class Darwin35BForCausalLMInt4(Qwen3_5MoeForCausalLM): """Darwin-35B AWQ-INT4: nibble INT4 for all nn.Linear + expert blocks.""" def __init__(self, config): super().__init__(config) tcfg = config.text_config if hasattr(config, "text_config") else config _replace_linear_int4(self, group_size=128) for layer in self.model.layers: if hasattr(layer, "mlp") and hasattr(layer.mlp, "experts"): old = layer.mlp.experts layer.mlp.experts = Darwin35BExpertsInt4(tcfg) del old gc.collect() Darwin35BForCausalLMInt4._auto_class = "AutoModelForCausalLM"