class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
    """LoRA-enabled wrapper around a ``VocabParallelEmbedding`` layer.

    The wrapper owns per-slot LoRA A/B buffers plus a buffer of extra
    vocabulary embeddings, all allocated by ``create_lora_weights`` and
    filled per slot by ``set_lora``. ``forward`` combines the wrapped
    layer's output with the LoRA contribution through
    ``self.punica_wrapper``.
    """

    def __init__(self, base_layer: VocabParallelEmbedding) -> None:
        super().__init__()
        self.base_layer = base_layer
        # Only declared here (no value bound); both attributes are assigned
        # by ``create_lora_weights``.
        self.embeddings_slice: tuple[int, int] | None
        self.embeddings_weights: torch.Tensor | None

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: PretrainedConfig | None = None,
    ) -> None:
        """Allocate the zero-initialised LoRA buffers for this layer.

        Args:
            max_loras: Number of LoRA slots; the leading dimension of every
                buffer allocated here.
            lora_config: Supplies ``lora_extra_vocab_size``,
                ``max_lora_rank`` and ``lora_dtype``.
            model_config: Unused by this method.
        """
        base = self.base_layer

        if base.num_added_embeddings_per_partition > 0:
            # We can start adding lora weights
            added_begin = base.num_org_embeddings_per_partition
            added_end = added_begin + base.num_added_embeddings_per_partition
            # View onto the rows of the base weight that follow the original
            # embeddings of this partition.
            self.embeddings_weights = base.weight.data[added_begin:added_end]
            # Added-vocab range of this shard, expressed relative to the
            # start of the added vocabulary.
            vocab_offset = base.org_vocab_size
            shard = base.shard_indices
            self.embeddings_slice = (
                shard.added_vocab_start_index - vocab_offset,
                shard.added_vocab_end_index - vocab_offset,
            )
            # Clear every row past the original embeddings.
            base.weight.data[added_begin:].fill_(0)
        else:
            self.embeddings_slice = None
            self.embeddings_weights = None

        target_device = base.weight.device
        embed_dim = base.embedding_dim
        extra_vocab = lora_config.lora_extra_vocab_size
        max_rank = lora_config.max_lora_rank

        # (max_loras, extra_vocab, embed_dim), in the base weight's dtype.
        self.embeddings_tensors = torch.zeros(
            (max_loras, extra_vocab, embed_dim),
            dtype=base.weight.dtype,
            device=target_device,
        )
        # (max_loras, org_vocab + extra_vocab, max_rank)
        self.lora_a_stacked = torch.zeros(
            (max_loras, base.org_vocab_size + extra_vocab, max_rank),
            dtype=lora_config.lora_dtype,
            device=target_device,
        )
        # (max_loras, 1, embed_dim, max_rank)
        self.lora_b_stacked = torch.zeros(
            (max_loras, 1, embed_dim, max_rank),
            dtype=lora_config.lora_dtype,
            device=target_device,
        )
        # 2D view sharing storage with ``lora_a_stacked``: the slot and
        # vocab dimensions are merged so it can serve as one embedding table.
        slots, rows, rank = self.lora_a_stacked.shape
        self.lora_a_stacked_2d = self.lora_a_stacked.view(slots * rows, rank)

    def reset_lora(self, index: int):
        """Zero every buffer entry belonging to LoRA slot ``index``."""
        for stacked in (
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.embeddings_tensors,
        ):
            stacked[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: torch.Tensor | None,
    ):
        """Load one adapter's weights into LoRA slot ``index``.

        The slot is cleared first, then ``lora_a`` (transposed) and
        ``lora_b`` are copied into the top-left corner of their stacked
        buffers. When ``embeddings_tensor`` is given it is copied into the
        slot's extra-embedding buffer and, if this layer holds added
        embeddings, the matching range is mirrored into the base weight.
        """
        self.reset_lora(index)

        # NOTE self.lora_a_stacked is row-major, and lora_a is col-major,
        # so we need transpose here
        a_dim0 = lora_a.shape[0]
        a_dim1 = lora_a.shape[1]
        self.lora_a_stacked[index, :a_dim1, :a_dim0].copy_(
            lora_a.T, non_blocking=True
        )

        b_dim0 = lora_b.shape[0]
        b_dim1 = lora_b.shape[1]
        self.lora_b_stacked[index, 0, :b_dim0, :b_dim1].copy_(
            lora_b, non_blocking=True
        )

        if embeddings_tensor is None:
            return

        self.embeddings_tensors[
            index,
            : embeddings_tensor.shape[0],
            : embeddings_tensor.shape[1],
        ].copy_(embeddings_tensor, non_blocking=True)

        if self.embeddings_slice is None:
            return

        # TODO(yard1): Optimize this copy, we don't need to copy
        # everything, just the modified part
        n_slots, n_extra, width = self.embeddings_tensors.shape
        flat_extra = self.embeddings_tensors.view(n_slots * n_extra, width)
        slice_begin, slice_end = self.embeddings_slice[0], self.embeddings_slice[1]
        embeddings = flat_extra[slice_begin:slice_end]
        assert self.embeddings_weights is not None
        self.embeddings_weights[: embeddings.shape[0]].copy_(embeddings)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed ``x`` with the base layer and add the LoRA contribution.

        Args:
            x: Token-id tensor; its first dimension is used as the token
                count when slicing the punica index buffers.

        Returns:
            A tensor with the same shape as the base layer's output.
        """
        base = self.base_layer
        # 1 where the id lies beyond the original vocabulary, else 0.
        added_tokens_mask = torch.where(x > base.org_vocab_size - 1, 1, 0)

        # NB: Don't use torch.narrow here. torch.narrow triggers some
        # Dynamic Shape specialization in torch.compile
        num_tokens = x.shape[0]
        embedding_indices = self.punica_wrapper._embeddings_indices
        indices_1 = embedding_indices[1][:num_tokens]
        indices_0 = embedding_indices[0][:num_tokens]

        # Look up LoRA-A rows in the flattened (slot * vocab) table.
        # NOTE(review): indices_1 presumably shifts each id to its slot's
        # rows -- confirm against the punica wrapper.
        full_lora_a_embeddings = F.embedding(x + indices_1, self.lora_a_stacked_2d)

        # indices_0 is only applied to ids flagged as added tokens.
        full_output = base.forward(x + (indices_0 * added_tokens_mask))
        full_output_org = full_output

        # Merge the two leading dimensions of 3D tensors before handing
        # them to the punica wrapper.
        if full_output.ndim == 3:
            merged_rows = full_output.shape[0] * full_output.shape[1]
            full_output = full_output.view(merged_rows, -1)
        if full_lora_a_embeddings.ndim == 3:
            merged_rows = (
                full_lora_a_embeddings.shape[0] * full_lora_a_embeddings.shape[1]
            )
            full_lora_a_embeddings = full_lora_a_embeddings.view(merged_rows, -1)

        lora_output: torch.Tensor | None = self.punica_wrapper.add_lora_embedding(
            full_output, full_lora_a_embeddings, self.lora_b_stacked, add_input=True
        )
        # On platforms that cannot update in place, the returned tensor is
        # used as the result instead of ``full_output``.
        if not current_platform.can_update_inplace():
            full_output = lora_output

        return full_output.view_as(full_output_org)

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None,
    ) -> bool:
        """Return True only when ``source_layer`` is exactly a
        ``VocabParallelEmbedding`` (subclasses do not match)."""
        return type(source_layer) is VocabParallelEmbedding

    @property
    def weight(self):
        """The wrapped base layer's ``weight``."""
        return self.base_layer.weight