class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
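    """ScaledMM linear kernel for TPUs, implemented on top of torch_xla's
    `quantized_matmul_int8` custom op.

    Weights are kept as int8 in [out, in] layout with per-output-channel
    scales; activations are quantized dynamically (symmetric, no zero point)
    inside the kernel, so no static input scale is required.
    """
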
    @classmethod
    def get_min_capability(cls) -> int:
        raise NotImplementedError(
            "TPU platform does not have a concept of compute capability; "
            "this method should not be called."
        )

    @classmethod
    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        if not current_platform.is_tpu():
            return False, "ScaledMMXLA requires running on TPU."
        if c.is_static_input_scheme:
            return False, "ScaledMMXLA requires dynamic activation scales."
        if not c.input_symmetric:
            return False, "ScaledMMXLA requires symmetric activation scales."
        if not c.is_channelwise:
            return False, "ScaledMMXLA requires channelwise weight scales"
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
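        """Re-register the quantized weight and scale in the layout expected
        by the XLA kernel and drop the unused activation-quantization
        parameters (input scale, zero point, azp adjustment)."""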
        # WEIGHT
        # [out, in] (different than cutlass_scaled_mm)
        weight = getattr(layer, self.w_q_name)
        replace_parameter(
            layer, self.w_q_name, torch.nn.Parameter(weight.data, requires_grad=False)
        )
        # WEIGHT SCALE
        # XLA kernels support only per-tensor and per-channel.
        # If we have a fused module (QKV, MLP) with per-tensor scales (thus N
        # scales being passed to the kernel), convert to the per-channel case.
        is_fused_module = len(layer.logical_widths) > 1
        weight_scale = getattr(layer, self.w_s_name)
        if is_fused_module and not self.config.is_channelwise:
            weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
        # [out_channel,] (different than cutlass_scaled_mm)
        weight_scale = weight_scale.squeeze(-1)
        replace_parameter(
            layer,
            self.w_s_name,
            torch.nn.Parameter(weight_scale.data, requires_grad=False),
        )
        # Only support symmetric dynamic activation quantization.
        setattr(layer, self.i_s_name, None)
        setattr(layer, self.i_zp_name, None)
        setattr(layer, self.azp_adj_name, None)
        # Filter warning for cond usage in apply_weights. It is okay
        # to specialize the graph since bias is not dynamic.
        warnings.filterwarnings(
            "ignore",
            message="Pred is a Python constant. When used with torch.cond, it specializes on one of the branches.",  # noqa: E501
        )
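
    # Branch callables for torch.cond in apply_weights. Both must accept the
    # same operands ([out, bias]) so either branch can be traced.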
    def no_add_bias(self, x: torch.Tensor, bias: torch.Tensor | None):
        return x

    def add_bias(self, x: torch.Tensor, bias: torch.Tensor | None):
        return x + bias

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
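        """Run the dynamically quantized int8 matmul on TPU and add the bias
        if one is provided."""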
        w_q, w_s, _, _, _ = self._get_weight_params(layer)
        # Required to register custom ops.
        import torch_xla.experimental.custom_kernel  # noqa: F401
        out = torch.ops.xla.quantized_matmul_int8(
            x,
            w_q,
            w_s,
            quantize_activation=True,
        )
        # Explicitly capture control flow to make dynamo happy.
        # https://pytorch.org/docs/main/generated/exportdb/index.html#cond-branch-class-method # noqa: E501
        return cond(bias is None, self.no_add_bias, self.add_bias, [out, bias])