vllm.v1.attention.backends.mla.flashmla_sparse

MIN_HEADS_FOR_BF16_PREFILL module-attribute

MIN_HEADS_FOR_BF16_PREFILL = 32
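This threshold is consumed by FlashMLASparseMetadataBuilder.build: when the local head count is below it and the fp8_ds_mla KV cache is in use, fp8_use_mixed_batch is set and attention dispatches to _forward_fp8_kv_mixed_batch, which avoids the head-padding overhead of the BF16 prefill kernel.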

NOTE: FlashMLA Sparse uses an fp8 KV cache with the following format.

In the "FP8 with scale" format, each token's KV cache entry is 656 bytes, structured as:

- First 512 bytes: the "quantized NoPE" part, containing 512 float8_e4m3 values.
- Next 16 bytes: scale factors, containing 4 float32 values. The first float32 is the scale for the first 128 float8_e4m3 values, the second for the next 128, and so on.
- Last 128 bytes: the "RoPE" part, containing 64 bfloat16 values. This part is left unquantized to preserve accuracy.

logger module-attribute

logger = init_logger(__name__)

FlashMLASparseBackend

Bases: AttentionBackend

Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
class FlashMLASparseBackend(AttentionBackend):
    accept_output_buffer: bool = True
    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.bfloat16]
    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
        "auto",
        "bfloat16",
        "fp8_ds_mla",
    ]

    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
        return [64]

    @staticmethod
    def get_name() -> str:
        return "FLASHMLA_SPARSE"

    @staticmethod
    def get_builder_cls() -> type["FlashMLASparseMetadataBuilder"]:
        return FlashMLASparseMetadataBuilder

    @staticmethod
    def get_impl_cls() -> type["FlashMLASparseImpl"]:
        return FlashMLASparseImpl

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        return [576]

    @classmethod
    def is_mla(cls) -> bool:
        return True

    @classmethod
    def is_sparse(cls) -> bool:
        return True

    @classmethod
    def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
        return capability.major in [9, 10]

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,  # assumed to be 1 for MLA
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        if cache_dtype_str == "fp8_ds_mla":
            # custom storage format is 656 bytes
            # see FlashMLA README.md for details
            return (num_blocks, block_size, 656)
        else:
            return (num_blocks, block_size, head_size)

accept_output_buffer class-attribute instance-attribute

accept_output_buffer: bool = True

supported_dtypes class-attribute

supported_dtypes: list[dtype] = [bfloat16]

supported_kv_cache_dtypes class-attribute

supported_kv_cache_dtypes: list[CacheDType] = [
    "auto",
    "bfloat16",
    "fp8_ds_mla",
]

get_builder_cls staticmethod

get_builder_cls() -> type[FlashMLASparseMetadataBuilder]
Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
@staticmethod
def get_builder_cls() -> type["FlashMLASparseMetadataBuilder"]:
    return FlashMLASparseMetadataBuilder

get_impl_cls staticmethod

get_impl_cls() -> type[FlashMLASparseImpl]
Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
@staticmethod
def get_impl_cls() -> type["FlashMLASparseImpl"]:
    return FlashMLASparseImpl

get_kv_cache_shape staticmethod

get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
    cache_dtype_str: str = "auto",
) -> tuple[int, ...]
Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
@staticmethod
def get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,  # assumed to be 1 for MLA
    head_size: int,
    cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
    if cache_dtype_str == "fp8_ds_mla":
        # custom storage format is 656 bytes
        # see FlashMLA README.md for details
        return (num_blocks, block_size, 656)
    else:
        return (num_blocks, block_size, head_size)
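For example (illustrative values; MLA uses a single KV head, head size 576, and block size 64):

# fp8_ds_mla packs each token into a custom 656-byte entry, so the last
# dimension is in bytes:
FlashMLASparseBackend.get_kv_cache_shape(
    num_blocks=1024, block_size=64, num_kv_heads=1, head_size=576,
    cache_dtype_str="fp8_ds_mla",
)  # -> (1024, 64, 656)

# "auto" / "bfloat16" store the latent + RoPE head directly, so the last
# dimension is head_size:
FlashMLASparseBackend.get_kv_cache_shape(
    num_blocks=1024, block_size=64, num_kv_heads=1, head_size=576,
)  # -> (1024, 64, 576)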

get_name staticmethod

get_name() -> str
Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
@staticmethod
def get_name() -> str:
    return "FLASHMLA_SPARSE"

get_supported_head_sizes classmethod

get_supported_head_sizes() -> list[int]
Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
    return [576]

get_supported_kernel_block_sizes staticmethod

get_supported_kernel_block_sizes() -> list[
    int | MultipleOf
]
Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
    return [64]

is_mla classmethod

is_mla() -> bool
Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
@classmethod
def is_mla(cls) -> bool:
    return True

is_sparse classmethod

is_sparse() -> bool
Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
@classmethod
def is_sparse(cls) -> bool:
    return True

supports_compute_capability classmethod

supports_compute_capability(
    capability: DeviceCapability,
) -> bool
Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
@classmethod
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
    return capability.major in [9, 10]

FlashMLASparseImpl

Bases: MLACommonBaseImpl[FlashMLASparseMetadata]

Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
    @staticmethod
    def _compute_fp8_decode_padded_heads(num_heads: int) -> int:
        # FP8 decode kernel only supports h_q = 64 or 128
        # Compute padded head count for decode
        return 64 if num_heads <= 64 else 128

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None,
        attn_type: str,
        kv_sharing_target_layer_name: str | None,
        # MLA Specific Arguments
        topk_indice_buffer: torch.Tensor | None = None,
        indexer: Optional["Indexer"] = None,
        **mla_args,
    ) -> None:
        super().__init__(
            num_heads,
            head_size,
            scale,
            num_kv_heads,
            alibi_slopes,
            sliding_window,
            kv_cache_dtype,
            logits_soft_cap,
            attn_type,
            kv_sharing_target_layer_name,
            **mla_args,
        )
        self.softmax_scale = scale
        assert indexer is not None
        self.topk_indices_buffer: torch.Tensor | None = indexer.topk_indices_buffer
        # Prefill BF16 kernel requires 64 on Hopper, 128 on Blackwell
        self.prefill_padding = (
            128 if current_platform.is_device_capability_family(100) else 64
        )
        self.fp8_decode_padded_heads = self._compute_fp8_decode_padded_heads(num_heads)

        if kv_cache_dtype == "fp8_ds_mla":
            # Reserve workspace during initialization
            vllm_config = get_current_vllm_config()
            assert vllm_config is not None and vllm_config.model_config is not None
            prefill_workspace_size = get_prefill_workspace_size(
                vllm_config.model_config.max_model_len
            )
            self.prefill_workspace_shape = (prefill_workspace_size, head_size)
            (self.prefill_bf16_workspace,) = (
                current_workspace_manager().get_simultaneous(
                    (self.prefill_workspace_shape, torch.bfloat16)
                )
            )

    def _forward_bf16_kv(
        self,
        q: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        topk_indices: torch.Tensor,
        attn_metadata: FlashMLASparseMetadata,
    ) -> torch.Tensor:
        # Convert per-request indices to global slots (decode) or workspace
        # offsets (prefill).
        topk_indices = triton_convert_req_index_to_global_index(
            attn_metadata.req_id_per_token,
            attn_metadata.block_table,
            topk_indices,
            BLOCK_SIZE=attn_metadata.block_size,
            NUM_TOPK_TOKENS=topk_indices.shape[1],
        )

        return self._bf16_flash_mla_kernel(q, kv_c_and_k_pe_cache, topk_indices)

    def _forward_fp8_kv_separate_prefill_decode(
        self,
        q: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        topk_indices: torch.Tensor,
        attn_metadata: FlashMLASparseMetadata,
    ) -> torch.Tensor:
        fp8_metadata = attn_metadata.fp8_extra_metadata
        assert isinstance(fp8_metadata, FlashMLASparseMetadata.FP8SeparatePrefillDecode)
        num_decodes = fp8_metadata.num_decodes

        prefill_request_ids = None
        prefill_workspace_starts = None
        has_prefill_workspace = False
        if fp8_metadata.prefill is not None:
            prefill_request_ids = fp8_metadata.prefill.request_ids
            prefill_workspace_starts = fp8_metadata.prefill.workspace_starts
            has_prefill_workspace = True

        # Convert per-request indices to global slots (decode) or workspace
        # offsets (prefill).
        # For FP8 cache: prefill uses workspace mapping (upconverted to BF16)
        # For BF16 cache: always use global cache slots (no workspace)
        # prefill_workspace_starts has been adjusted in-place per chunk so
        # prefill indices automatically come out chunk-local
        topk_indices = triton_convert_req_index_to_global_index(
            attn_metadata.req_id_per_token,
            attn_metadata.block_table,
            topk_indices,
            BLOCK_SIZE=attn_metadata.block_size,
            NUM_TOPK_TOKENS=topk_indices.shape[1],
            HAS_PREFILL_WORKSPACE=has_prefill_workspace,
            prefill_workspace_request_ids=prefill_request_ids,
            prefill_workspace_starts=prefill_workspace_starts,
        )

        fp8_metadata = attn_metadata.fp8_extra_metadata
        assert isinstance(fp8_metadata, FlashMLASparseMetadata.FP8SeparatePrefillDecode)

        def _fp8_decode(q: torch.Tensor, topk_indices: torch.Tensor) -> torch.Tensor:
            # Reshape q: (num_decode_tokens, num_heads, head_dim)
            #         -> (num_decodes, seq_len, num_heads, head_dim)
            q = reshape_query_for_spec_decode(q, num_decodes)
            seq_len = q.shape[1]
            # Reshape topk_indices: (num_decode_tokens, topk)
            #                    -> (num_decodes, seq_len, topk)
            topk_indices = topk_indices.view(num_decodes, seq_len, -1)
            assert fp8_metadata.decode is not None
            attn_out, _ = self._fp8_flash_mla_kernel(
                q=q,
                kv_c_and_k_pe_cache=kv_c_and_k_pe_cache,
                topk_indices=topk_indices,
                kernel_metadata=fp8_metadata.decode.kernel_metadata,
            )
            # Reshape output: (num_decodes, seq_len, num_heads, head_dim_v)
            #              -> (num_decode_tokens, num_heads, head_dim_v)
            return reshape_attn_output_for_spec_decode(attn_out)

        num_decode_tokens = fp8_metadata.num_decode_tokens
        num_prefill_tokens = fp8_metadata.num_prefill_tokens

        # Pure decode: direct call without allocation
        if num_decode_tokens > 0 and num_prefill_tokens == 0:
            assert fp8_metadata.decode is not None
            attn_out = _fp8_decode(q, topk_indices)
        else:
            # Mixed or pure prefill: allocate output tensor
            attn_out = q.new_empty(
                (attn_metadata.num_actual_tokens, self.num_heads, self.kv_lora_rank),
                dtype=q.dtype,
                device=q.device,
            )

            if num_decode_tokens > 0:
                attn_out[:num_decode_tokens] = _fp8_decode(
                    q[:num_decode_tokens], topk_indices[:num_decode_tokens]
                )

            assert fp8_metadata.prefill is not None
            for chunk in fp8_metadata.prefill.chunks:
                chunk_workspace = self.prefill_bf16_workspace[: chunk.chunk_tot_seqlen]
                ops.cp_gather_and_upconvert_fp8_kv_cache(
                    kv_c_and_k_pe_cache,
                    chunk_workspace,
                    chunk.block_table,
                    chunk.seq_lens,
                    chunk.workspace_starts,
                    len(chunk.block_table),
                )

                chunk_q = q[chunk.tokens_slice]
                chunk_topk_indices_workspace = topk_indices[chunk.tokens_slice]

                attn_out[chunk.tokens_slice] = self._bf16_flash_mla_kernel(
                    chunk_q,
                    chunk_workspace,
                    chunk_topk_indices_workspace,
                )

        return attn_out

    def _forward_fp8_kv_mixed_batch(
        self,
        q: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        topk_indices: torch.Tensor,
        attn_metadata: FlashMLASparseMetadata,
    ) -> torch.Tensor:
        """Mixed batch FP8 forward path that treats all tokens as one batch.

        This is equivalent to main branch's approach and avoids the BF16
        prefill kernel which has head padding overhead when num_heads is small.
        Used when use_mixed_batch is True.
        """
        # Convert per-request indices to global slots (decode) or workspace
        # offsets (prefill).
        topk_indices = triton_convert_req_index_to_global_index(
            attn_metadata.req_id_per_token,
            attn_metadata.block_table,
            topk_indices,
            BLOCK_SIZE=attn_metadata.block_size,
            NUM_TOPK_TOKENS=topk_indices.shape[1],
        )

        assert attn_metadata.fp8_extra_metadata is not None
        assert isinstance(
            attn_metadata.fp8_extra_metadata, FlashMLASparseMetadata.FP8KernelMetadata
        )
        fp8_metadata = attn_metadata.fp8_extra_metadata

        _attn_out, _ = self._fp8_flash_mla_kernel(
            q=q.unsqueeze(0),  # unsqueeze to add batch_dim: (T, H, D) -> (1, T, H, D)
            kv_c_and_k_pe_cache=kv_c_and_k_pe_cache,
            topk_indices=topk_indices.unsqueeze(0),  # (T, topk) -> (1, T, topk)
            kernel_metadata=fp8_metadata,
        )

        # Output is (1, T, H, D_v), squeeze back to (T, H, D_v)
        return _attn_out.squeeze(0)

    def _fp8_flash_mla_kernel(
        self,
        q: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        topk_indices: torch.Tensor,
        kernel_metadata: FlashMLASparseMetadata.FP8KernelMetadata,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # q shape: (batch, seq_len, num_heads, head_dim)
        actual_num_heads = q.size(2)
        padded_num_heads = self.fp8_decode_padded_heads

        # Pad query if needed (kernel only supports h_q = 64 or 128)
        if actual_num_heads < padded_num_heads:
            logger.warning_once(
                f"Padding num_heads from {actual_num_heads} to "
                f"{padded_num_heads} for FP8 sparse decode kernel"
            )
            q_padded = q.new_zeros((q.size(0), q.size(1), padded_num_heads, q.size(3)))
            q_padded[:, :, :actual_num_heads, :] = q
            q = q_padded

        out, lse = flash_mla_with_kvcache(
            q=q,
            k_cache=kv_c_and_k_pe_cache.view(torch.uint8).unsqueeze(-2),
            block_table=kernel_metadata.dummy_block_table,
            head_dim_v=512,
            cache_seqlens=kernel_metadata.cache_lens,
            tile_scheduler_metadata=kernel_metadata.scheduler_metadata,
            is_fp8_kvcache=True,
            indices=topk_indices,
            softmax_scale=self.softmax_scale,
        )

        # Slice output back to actual head count if we padded
        if actual_num_heads < padded_num_heads:
            out = out[:, :, :actual_num_heads, :]

        return out, lse

    def _bf16_flash_mla_kernel(
        self,
        q: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        topk_indices: torch.Tensor,
    ) -> torch.Tensor:
        num_tokens = q.shape[0]
        kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(
            -1, 1, kv_c_and_k_pe_cache.shape[-1]
        )

        # NOTE(Chen): kernel requires num_local_head to be a multiple of
        # 64 on Hopper and 128 on Blackwell
        if self.num_heads % self.prefill_padding != 0:
            assert self.prefill_padding % self.num_heads == 0
            logger.warning_once(
                f"Padding num_heads from {self.num_heads} to "
                f"{self.prefill_padding} for BF16 sparse prefill kernel"
            )
            q_padded = q.new_empty((q.shape[0], self.prefill_padding, q.shape[2]))
            q_padded[:, : self.num_heads, :] = q
            q = q_padded

        topk_indices = topk_indices.view(num_tokens, 1, -1)
        output = flash_mla_sparse_fwd(
            q, kv_c_and_k_pe_cache, topk_indices, self.softmax_scale
        )[0]
        output = output[:, : self.num_heads, :]
        return output

    def forward(
        self,
        layer: AttentionLayer,
        q: torch.Tensor,
        k_c_normed: torch.Tensor,  # key in unified attn
        k_pe: torch.Tensor,  # value in unified attn
        kv_cache: torch.Tensor,
        attn_metadata: FlashMLASparseMetadata | None,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # NOTE(lucas): for the sparse FlashMLA kernels the kernels want to use
        # MQA 576/512 approach for both prefill and decode

        assert output is not None, "Output tensor must be provided."

        if output_scale is not None or output_block_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported for MLACommonImpl"
            )

        if attn_metadata is None:
            # Dummy run - no need to allocate buffers
            # The zero fill is required when used with DP + EP
            # to ensure all ranks within a DP group compute the
            # same expert outputs.
            return output.fill_(0)

        num_actual_toks = attn_metadata.num_actual_tokens

        # Inputs and outputs may be padded for CUDA graphs

        q = q[:num_actual_toks, ...]
        k_c_normed = k_c_normed[:num_actual_toks, ...]
        k_pe = k_pe[:num_actual_toks, ...]
        assert self.topk_indices_buffer is not None
        topk_indices = self.topk_indices_buffer[:num_actual_toks]

        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        # Convert from (B, N, P) to (N, B, P)
        q_nope = q_nope.transpose(0, 1)
        # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
        ql_nope = torch.bmm(q_nope, self.W_UK_T)
        # Convert from (N, B, L) to (B, N, L)
        ql_nope = ql_nope.transpose(0, 1)

        use_fp8_cache = self.kv_cache_dtype == "fp8_ds_mla"

        q = torch.cat([ql_nope, q_pe], dim=-1)

        # write the latent and rope to kv cache
        if kv_cache.numel() > 0:
            ops.concat_and_cache_mla(
                k_c_normed,
                k_pe.squeeze(1),
                kv_cache,
                attn_metadata.slot_mapping.flatten(),
                kv_cache_dtype=self.kv_cache_dtype,
                scale=layer._k_scale,
            )

        if not use_fp8_cache:
            attn_out = self._forward_bf16_kv(q, kv_cache, topk_indices, attn_metadata)
        elif attn_metadata.fp8_use_mixed_batch:
            attn_out = self._forward_fp8_kv_mixed_batch(
                q, kv_cache, topk_indices, attn_metadata
            )
        else:
            attn_out = self._forward_fp8_kv_separate_prefill_decode(
                q, kv_cache, topk_indices, attn_metadata
            )

        self._v_up_proj(attn_out, out=output[:num_actual_toks])
        return output

fp8_decode_padded_heads instance-attribute

fp8_decode_padded_heads = _compute_fp8_decode_padded_heads(
    num_heads
)

prefill_padding instance-attribute

prefill_padding = (
    128 if is_device_capability_family(100) else 64
)

prefill_workspace_shape instance-attribute

prefill_workspace_shape = (
    prefill_workspace_size,
    head_size,
)

softmax_scale instance-attribute

softmax_scale = scale

topk_indices_buffer instance-attribute

topk_indices_buffer: Tensor | None = topk_indices_buffer

__init__

__init__(
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: list[float] | None,
    sliding_window: int | None,
    kv_cache_dtype: str,
    logits_soft_cap: float | None,
    attn_type: str,
    kv_sharing_target_layer_name: str | None,
    topk_indice_buffer: Tensor | None = None,
    indexer: Optional[Indexer] = None,
    **mla_args,
) -> None
Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
def __init__(
    self,
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: list[float] | None,
    sliding_window: int | None,
    kv_cache_dtype: str,
    logits_soft_cap: float | None,
    attn_type: str,
    kv_sharing_target_layer_name: str | None,
    # MLA Specific Arguments
    topk_indice_buffer: torch.Tensor | None = None,
    indexer: Optional["Indexer"] = None,
    **mla_args,
) -> None:
    super().__init__(
        num_heads,
        head_size,
        scale,
        num_kv_heads,
        alibi_slopes,
        sliding_window,
        kv_cache_dtype,
        logits_soft_cap,
        attn_type,
        kv_sharing_target_layer_name,
        **mla_args,
    )
    self.softmax_scale = scale
    assert indexer is not None
    self.topk_indices_buffer: torch.Tensor | None = indexer.topk_indices_buffer
    # Prefill BF16 kernel requires 64 on Hopper, 128 on Blackwell
    self.prefill_padding = (
        128 if current_platform.is_device_capability_family(100) else 64
    )
    self.fp8_decode_padded_heads = self._compute_fp8_decode_padded_heads(num_heads)

    if kv_cache_dtype == "fp8_ds_mla":
        # Reserve workspace during initialization
        vllm_config = get_current_vllm_config()
        assert vllm_config is not None and vllm_config.model_config is not None
        prefill_workspace_size = get_prefill_workspace_size(
            vllm_config.model_config.max_model_len
        )
        self.prefill_workspace_shape = (prefill_workspace_size, head_size)
        (self.prefill_bf16_workspace,) = (
            current_workspace_manager().get_simultaneous(
                (self.prefill_workspace_shape, torch.bfloat16)
            )
        )

_bf16_flash_mla_kernel

_bf16_flash_mla_kernel(
    q: Tensor,
    kv_c_and_k_pe_cache: Tensor,
    topk_indices: Tensor,
) -> Tensor
Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
def _bf16_flash_mla_kernel(
    self,
    q: torch.Tensor,
    kv_c_and_k_pe_cache: torch.Tensor,
    topk_indices: torch.Tensor,
) -> torch.Tensor:
    num_tokens = q.shape[0]
    kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(
        -1, 1, kv_c_and_k_pe_cache.shape[-1]
    )

    # NOTE(Chen): kernel requires num_local_head to be a multiple of
    # 64 on Hopper and 128 on Blackwell
    if self.num_heads % self.prefill_padding != 0:
        assert self.prefill_padding % self.num_heads == 0
        logger.warning_once(
            f"Padding num_heads from {self.num_heads} to "
            f"{self.prefill_padding} for BF16 sparse prefill kernel"
        )
        q_padded = q.new_empty((q.shape[0], self.prefill_padding, q.shape[2]))
        q_padded[:, : self.num_heads, :] = q
        q = q_padded

    topk_indices = topk_indices.view(num_tokens, 1, -1)
    output = flash_mla_sparse_fwd(
        q, kv_c_and_k_pe_cache, topk_indices, self.softmax_scale
    )[0]
    output = output[:, : self.num_heads, :]
    return output

_compute_fp8_decode_padded_heads staticmethod

_compute_fp8_decode_padded_heads(num_heads: int) -> int
Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
@staticmethod
def _compute_fp8_decode_padded_heads(num_heads: int) -> int:
    # FP8 decode kernel only supports h_q = 64 or 128
    # Compute padded head count for decode
    return 64 if num_heads <= 64 else 128
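For example:

FlashMLASparseImpl._compute_fp8_decode_padded_heads(16)   # -> 64
FlashMLASparseImpl._compute_fp8_decode_padded_heads(64)   # -> 64
FlashMLASparseImpl._compute_fp8_decode_padded_heads(128)  # -> 128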

_forward_bf16_kv

_forward_bf16_kv(
    q: Tensor,
    kv_c_and_k_pe_cache: Tensor,
    topk_indices: Tensor,
    attn_metadata: FlashMLASparseMetadata,
) -> Tensor
Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
def _forward_bf16_kv(
    self,
    q: torch.Tensor,
    kv_c_and_k_pe_cache: torch.Tensor,
    topk_indices: torch.Tensor,
    attn_metadata: FlashMLASparseMetadata,
) -> torch.Tensor:
    # Convert per-request indices to global slots (decode) or workspace
    # offsets (prefill).
    topk_indices = triton_convert_req_index_to_global_index(
        attn_metadata.req_id_per_token,
        attn_metadata.block_table,
        topk_indices,
        BLOCK_SIZE=attn_metadata.block_size,
        NUM_TOPK_TOKENS=topk_indices.shape[1],
    )

    return self._bf16_flash_mla_kernel(q, kv_c_and_k_pe_cache, topk_indices)
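The Triton kernel above performs this mapping on the GPU; a plain PyTorch sketch of the same idea, assuming the usual block_table lookup and treating negative indices as padding (both assumptions of this sketch, not guarantees of the kernel), might look like:

import torch

def convert_req_index_to_global_index_ref(
    req_id_per_token: torch.Tensor,  # (num_tokens,)
    block_table: torch.Tensor,       # (num_reqs, max_blocks_per_req)
    topk_indices: torch.Tensor,      # (num_tokens, topk)
    block_size: int,
) -> torch.Tensor:
    req = req_id_per_token[:, None].expand_as(topk_indices)
    block = topk_indices.clamp(min=0) // block_size
    offset = topk_indices.clamp(min=0) % block_size
    global_idx = block_table[req.long(), block.long()] * block_size + offset
    # Leave padding entries (assumed negative) unchanged.
    return torch.where(topk_indices < 0, topk_indices, global_idx)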

_forward_fp8_kv_mixed_batch

_forward_fp8_kv_mixed_batch(
    q: Tensor,
    kv_c_and_k_pe_cache: Tensor,
    topk_indices: Tensor,
    attn_metadata: FlashMLASparseMetadata,
) -> Tensor

Mixed batch FP8 forward path that treats all tokens as one batch.

This is equivalent to main branch's approach and avoids the BF16 prefill kernel which has head padding overhead when num_heads is small. Used when use_mixed_batch is True.

Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
def _forward_fp8_kv_mixed_batch(
    self,
    q: torch.Tensor,
    kv_c_and_k_pe_cache: torch.Tensor,
    topk_indices: torch.Tensor,
    attn_metadata: FlashMLASparseMetadata,
) -> torch.Tensor:
    """Mixed batch FP8 forward path that treats all tokens as one batch.

    This is equivalent to main branch's approach and avoids the BF16
    prefill kernel which has head padding overhead when num_heads is small.
    Used when use_mixed_batch is True.
    """
    # Convert per-request indices to global slots (decode) or workspace
    # offsets (prefill).
    topk_indices = triton_convert_req_index_to_global_index(
        attn_metadata.req_id_per_token,
        attn_metadata.block_table,
        topk_indices,
        BLOCK_SIZE=attn_metadata.block_size,
        NUM_TOPK_TOKENS=topk_indices.shape[1],
    )

    assert attn_metadata.fp8_extra_metadata is not None
    assert isinstance(
        attn_metadata.fp8_extra_metadata, FlashMLASparseMetadata.FP8KernelMetadata
    )
    fp8_metadata = attn_metadata.fp8_extra_metadata

    _attn_out, _ = self._fp8_flash_mla_kernel(
        q=q.unsqueeze(0),  # unsqueeze to add batch_dim: (T, H, D) -> (1, T, H, D)
        kv_c_and_k_pe_cache=kv_c_and_k_pe_cache,
        topk_indices=topk_indices.unsqueeze(0),  # (T, topk) -> (1, T, topk)
        kernel_metadata=fp8_metadata,
    )

    # Output is (1, T, H, D_v), squeeze back to (T, H, D_v)
    return _attn_out.squeeze(0)

_forward_fp8_kv_separate_prefill_decode

_forward_fp8_kv_separate_prefill_decode(
    q: Tensor,
    kv_c_and_k_pe_cache: Tensor,
    topk_indices: Tensor,
    attn_metadata: FlashMLASparseMetadata,
) -> Tensor
Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
def _forward_fp8_kv_separate_prefill_decode(
    self,
    q: torch.Tensor,
    kv_c_and_k_pe_cache: torch.Tensor,
    topk_indices: torch.Tensor,
    attn_metadata: FlashMLASparseMetadata,
) -> torch.Tensor:
    fp8_metadata = attn_metadata.fp8_extra_metadata
    assert isinstance(fp8_metadata, FlashMLASparseMetadata.FP8SeparatePrefillDecode)
    num_decodes = fp8_metadata.num_decodes

    prefill_request_ids = None
    prefill_workspace_starts = None
    has_prefill_workspace = False
    if fp8_metadata.prefill is not None:
        prefill_request_ids = fp8_metadata.prefill.request_ids
        prefill_workspace_starts = fp8_metadata.prefill.workspace_starts
        has_prefill_workspace = True

    # Convert per-request indices to global slots (decode) or workspace
    # offsets (prefill).
    # For FP8 cache: prefill uses workspace mapping (upconverted to BF16)
    # For BF16 cache: always use global cache slots (no workspace)
    # prefill_workspace_starts has been adjusted in-place per chunk so
    # prefill indices automatically come out chunk-local
    topk_indices = triton_convert_req_index_to_global_index(
        attn_metadata.req_id_per_token,
        attn_metadata.block_table,
        topk_indices,
        BLOCK_SIZE=attn_metadata.block_size,
        NUM_TOPK_TOKENS=topk_indices.shape[1],
        HAS_PREFILL_WORKSPACE=has_prefill_workspace,
        prefill_workspace_request_ids=prefill_request_ids,
        prefill_workspace_starts=prefill_workspace_starts,
    )

    fp8_metadata = attn_metadata.fp8_extra_metadata
    assert isinstance(fp8_metadata, FlashMLASparseMetadata.FP8SeparatePrefillDecode)

    def _fp8_decode(q: torch.Tensor, topk_indices: torch.Tensor) -> torch.Tensor:
        # Reshape q: (num_decode_tokens, num_heads, head_dim)
        #         -> (num_decodes, seq_len, num_heads, head_dim)
        q = reshape_query_for_spec_decode(q, num_decodes)
        seq_len = q.shape[1]
        # Reshape topk_indices: (num_decode_tokens, topk)
        #                    -> (num_decodes, seq_len, topk)
        topk_indices = topk_indices.view(num_decodes, seq_len, -1)
        assert fp8_metadata.decode is not None
        attn_out, _ = self._fp8_flash_mla_kernel(
            q=q,
            kv_c_and_k_pe_cache=kv_c_and_k_pe_cache,
            topk_indices=topk_indices,
            kernel_metadata=fp8_metadata.decode.kernel_metadata,
        )
        # Reshape output: (num_decodes, seq_len, num_heads, head_dim_v)
        #              -> (num_decode_tokens, num_heads, head_dim_v)
        return reshape_attn_output_for_spec_decode(attn_out)

    num_decode_tokens = fp8_metadata.num_decode_tokens
    num_prefill_tokens = fp8_metadata.num_prefill_tokens

    # Pure decode: direct call without allocation
    if num_decode_tokens > 0 and num_prefill_tokens == 0:
        assert fp8_metadata.decode is not None
        attn_out = _fp8_decode(q, topk_indices)
    else:
        # Mixed or pure prefill: allocate output tensor
        attn_out = q.new_empty(
            (attn_metadata.num_actual_tokens, self.num_heads, self.kv_lora_rank),
            dtype=q.dtype,
            device=q.device,
        )

        if num_decode_tokens > 0:
            attn_out[:num_decode_tokens] = _fp8_decode(
                q[:num_decode_tokens], topk_indices[:num_decode_tokens]
            )

        assert fp8_metadata.prefill is not None
        for chunk in fp8_metadata.prefill.chunks:
            chunk_workspace = self.prefill_bf16_workspace[: chunk.chunk_tot_seqlen]
            ops.cp_gather_and_upconvert_fp8_kv_cache(
                kv_c_and_k_pe_cache,
                chunk_workspace,
                chunk.block_table,
                chunk.seq_lens,
                chunk.workspace_starts,
                len(chunk.block_table),
            )

            chunk_q = q[chunk.tokens_slice]
            chunk_topk_indices_workspace = topk_indices[chunk.tokens_slice]

            attn_out[chunk.tokens_slice] = self._bf16_flash_mla_kernel(
                chunk_q,
                chunk_workspace,
                chunk_topk_indices_workspace,
            )

    return attn_out

_fp8_flash_mla_kernel

_fp8_flash_mla_kernel(
    q: Tensor,
    kv_c_and_k_pe_cache: Tensor,
    topk_indices: Tensor,
    kernel_metadata: FP8KernelMetadata,
) -> tuple[Tensor, Tensor]
Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
def _fp8_flash_mla_kernel(
    self,
    q: torch.Tensor,
    kv_c_and_k_pe_cache: torch.Tensor,
    topk_indices: torch.Tensor,
    kernel_metadata: FlashMLASparseMetadata.FP8KernelMetadata,
) -> tuple[torch.Tensor, torch.Tensor]:
    # q shape: (batch, seq_len, num_heads, head_dim)
    actual_num_heads = q.size(2)
    padded_num_heads = self.fp8_decode_padded_heads

    # Pad query if needed (kernel only supports h_q = 64 or 128)
    if actual_num_heads < padded_num_heads:
        logger.warning_once(
            f"Padding num_heads from {actual_num_heads} to "
            f"{padded_num_heads} for FP8 sparse decode kernel"
        )
        q_padded = q.new_zeros((q.size(0), q.size(1), padded_num_heads, q.size(3)))
        q_padded[:, :, :actual_num_heads, :] = q
        q = q_padded

    out, lse = flash_mla_with_kvcache(
        q=q,
        k_cache=kv_c_and_k_pe_cache.view(torch.uint8).unsqueeze(-2),
        block_table=kernel_metadata.dummy_block_table,
        head_dim_v=512,
        cache_seqlens=kernel_metadata.cache_lens,
        tile_scheduler_metadata=kernel_metadata.scheduler_metadata,
        is_fp8_kvcache=True,
        indices=topk_indices,
        softmax_scale=self.softmax_scale,
    )

    # Slice output back to actual head count if we padded
    if actual_num_heads < padded_num_heads:
        out = out[:, :, :actual_num_heads, :]

    return out, lse
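A shape-only sketch of the pad-and-slice behaviour, assuming 16 query heads (so the kernel sees the padded count of 64); the kernel call itself is elided:

import torch

batch, seq_len, num_heads, head_dim = 2, 1, 16, 576
padded_num_heads = 64  # _compute_fp8_decode_padded_heads(16)

q = torch.randn(batch, seq_len, num_heads, head_dim, dtype=torch.bfloat16)
q_padded = q.new_zeros((batch, seq_len, padded_num_heads, head_dim))
q_padded[:, :, :num_heads, :] = q  # extra heads stay zero
# ... flash_mla_with_kvcache runs on q_padded and returns (batch, seq_len, 64, 512) ...
out = torch.zeros(batch, seq_len, padded_num_heads, 512, dtype=torch.bfloat16)
out = out[:, :, :num_heads, :]     # slice back to the real head count
assert out.shape == (batch, seq_len, num_heads, 512)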

forward

forward(
    layer: AttentionLayer,
    q: Tensor,
    k_c_normed: Tensor,
    k_pe: Tensor,
    kv_cache: Tensor,
    attn_metadata: FlashMLASparseMetadata | None,
    output: Tensor | None = None,
    output_scale: Tensor | None = None,
    output_block_scale: Tensor | None = None,
) -> Tensor
Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
def forward(
    self,
    layer: AttentionLayer,
    q: torch.Tensor,
    k_c_normed: torch.Tensor,  # key in unified attn
    k_pe: torch.Tensor,  # value in unified attn
    kv_cache: torch.Tensor,
    attn_metadata: FlashMLASparseMetadata | None,
    output: torch.Tensor | None = None,
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
    # NOTE(lucas): for the sparse FlashMLA kernels the kernels want to use
    # MQA 576/512 approach for both prefill and decode

    assert output is not None, "Output tensor must be provided."

    if output_scale is not None or output_block_scale is not None:
        raise NotImplementedError(
            "fused output quantization is not yet supported for MLACommonImpl"
        )

    if attn_metadata is None:
        # Dummy run - no need to allocate buffers
        # The zero fill is required when used with DP + EP
        # to ensure all ranks within a DP group compute the
        # same expert outputs.
        return output.fill_(0)

    num_actual_toks = attn_metadata.num_actual_tokens

    # Inputs and outputs may be padded for CUDA graphs

    q = q[:num_actual_toks, ...]
    k_c_normed = k_c_normed[:num_actual_toks, ...]
    k_pe = k_pe[:num_actual_toks, ...]
    assert self.topk_indices_buffer is not None
    topk_indices = self.topk_indices_buffer[:num_actual_toks]

    q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
    # Convert from (B, N, P) to (N, B, P)
    q_nope = q_nope.transpose(0, 1)
    # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
    ql_nope = torch.bmm(q_nope, self.W_UK_T)
    # Convert from (N, B, L) to (B, N, L)
    ql_nope = ql_nope.transpose(0, 1)

    use_fp8_cache = self.kv_cache_dtype == "fp8_ds_mla"

    q = torch.cat([ql_nope, q_pe], dim=-1)

    # write the latent and rope to kv cache
    if kv_cache.numel() > 0:
        ops.concat_and_cache_mla(
            k_c_normed,
            k_pe.squeeze(1),
            kv_cache,
            attn_metadata.slot_mapping.flatten(),
            kv_cache_dtype=self.kv_cache_dtype,
            scale=layer._k_scale,
        )

    if not use_fp8_cache:
        attn_out = self._forward_bf16_kv(q, kv_cache, topk_indices, attn_metadata)
    elif attn_metadata.fp8_use_mixed_batch:
        attn_out = self._forward_fp8_kv_mixed_batch(
            q, kv_cache, topk_indices, attn_metadata
        )
    else:
        attn_out = self._forward_fp8_kv_separate_prefill_decode(
            q, kv_cache, topk_indices, attn_metadata
        )

    self._v_up_proj(attn_out, out=output[:num_actual_toks])
    return output

FlashMLASparseMetadata dataclass

Bases: AttentionMetadata

Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
@dataclass
class FlashMLASparseMetadata(AttentionMetadata):
    num_reqs: int
    max_query_len: int
    max_seq_len: int

    num_actual_tokens: int  # Number of tokens excluding padding.
    query_start_loc: torch.Tensor
    slot_mapping: torch.Tensor

    block_table: torch.Tensor
    req_id_per_token: torch.Tensor
    block_size: int = 64
    topk_tokens: int = 2048

    @dataclass
    class FP8KernelMetadata:
        scheduler_metadata: FlashMLASchedMeta
        dummy_block_table: torch.Tensor
        cache_lens: torch.Tensor

    @dataclass
    class FP8SeparatePrefillDecode:
        @dataclass
        class Decode:
            kernel_metadata: "FlashMLASparseMetadata.FP8KernelMetadata"
            decode_query_len: int  # needed for reshape in spec decode

        @dataclass
        class Prefill:
            # Sequence lengths (context + query) for prefill requests
            # Shape: [num_prefill_reqs]
            seq_lens: torch.Tensor

            # Request ID for each token: -1 for decode tokens, request index
            # (0, 1, 2, ...) for prefill tokens.
            # Shape: [num_actual_tokens]
            request_ids: torch.Tensor

            # Workspace start offsets for all prefill requests
            # Shape: [num_prefill_reqs], adjusted in-place per chunk to be
            # 0-indexed within each chunk. Used to map prefill tokens to workspace
            # offsets in convert_logical_index_to_physical_index
            workspace_starts: torch.Tensor

            @dataclass
            class Chunk:
                """Metadata for a chunk of prefill requests.

                Prefill requests may be chunked to fit within the fixed workspace size.
                """

                seq_lens: torch.Tensor
                tokens_slice: slice
                block_table: torch.Tensor
                req_start_idx: int
                workspace_starts: torch.Tensor
                chunk_tot_seqlen: int

            chunks: list[Chunk]

        num_prefills: int = 0
        num_decodes: int = 0
        num_prefill_tokens: int = 0
        num_decode_tokens: int = 0

        decode: Decode | None = None
        prefill: Prefill | None = None

    fp8_extra_metadata: FP8SeparatePrefillDecode | FP8KernelMetadata | None = None
    fp8_use_mixed_batch: bool = False

block_size class-attribute instance-attribute

block_size: int = 64

block_table instance-attribute

block_table: Tensor

fp8_extra_metadata class-attribute instance-attribute

fp8_extra_metadata: (
    FP8SeparatePrefillDecode | FP8KernelMetadata | None
) = None

fp8_use_mixed_batch class-attribute instance-attribute

fp8_use_mixed_batch: bool = False

max_query_len instance-attribute

max_query_len: int

max_seq_len instance-attribute

max_seq_len: int

num_actual_tokens instance-attribute

num_actual_tokens: int

num_reqs instance-attribute

num_reqs: int

query_start_loc instance-attribute

query_start_loc: Tensor

req_id_per_token instance-attribute

req_id_per_token: Tensor

slot_mapping instance-attribute

slot_mapping: Tensor

topk_tokens class-attribute instance-attribute

topk_tokens: int = 2048

FP8KernelMetadata dataclass

Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
@dataclass
class FP8KernelMetadata:
    scheduler_metadata: FlashMLASchedMeta
    dummy_block_table: torch.Tensor
    cache_lens: torch.Tensor

cache_lens instance-attribute

cache_lens: Tensor

dummy_block_table instance-attribute

dummy_block_table: Tensor

scheduler_metadata instance-attribute

scheduler_metadata: FlashMLASchedMeta

__init__

__init__(
    scheduler_metadata: FlashMLASchedMeta,
    dummy_block_table: Tensor,
    cache_lens: Tensor,
) -> None

FP8SeparatePrefillDecode dataclass

Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
@dataclass
class FP8SeparatePrefillDecode:
    @dataclass
    class Decode:
        kernel_metadata: "FlashMLASparseMetadata.FP8KernelMetadata"
        decode_query_len: int  # needed for reshape in spec decode

    @dataclass
    class Prefill:
        # Sequence lengths (context + query) for prefill requests
        # Shape: [num_prefill_reqs]
        seq_lens: torch.Tensor

        # Request ID for each token: -1 for decode tokens, request index
        # (0, 1, 2, ...) for prefill tokens.
        # Shape: [num_actual_tokens]
        request_ids: torch.Tensor

        # Workspace start offsets for all prefill requests
        # Shape: [num_prefill_reqs], adjusted in-place per chunk to be
        # 0-indexed within each chunk. Used to map prefill tokens to workspace
        # offsets in convert_logical_index_to_physical_index
        workspace_starts: torch.Tensor

        @dataclass
        class Chunk:
            """Metadata for a chunk of prefill requests.

            Prefill requests may be chunked to fit within the fixed workspace size.
            """

            seq_lens: torch.Tensor
            tokens_slice: slice
            block_table: torch.Tensor
            req_start_idx: int
            workspace_starts: torch.Tensor
            chunk_tot_seqlen: int

        chunks: list[Chunk]

    num_prefills: int = 0
    num_decodes: int = 0
    num_prefill_tokens: int = 0
    num_decode_tokens: int = 0

    decode: Decode | None = None
    prefill: Prefill | None = None

decode class-attribute instance-attribute

decode: Decode | None = None

num_decode_tokens class-attribute instance-attribute

num_decode_tokens: int = 0

num_decodes class-attribute instance-attribute

num_decodes: int = 0

num_prefill_tokens class-attribute instance-attribute

num_prefill_tokens: int = 0

num_prefills class-attribute instance-attribute

num_prefills: int = 0

prefill class-attribute instance-attribute

prefill: Prefill | None = None

Decode dataclass

Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
@dataclass
class Decode:
    kernel_metadata: "FlashMLASparseMetadata.FP8KernelMetadata"
    decode_query_len: int  # needed for reshape in spec decode
decode_query_len instance-attribute
decode_query_len: int
kernel_metadata instance-attribute
kernel_metadata: FP8KernelMetadata
__init__
__init__(
    kernel_metadata: FP8KernelMetadata,
    decode_query_len: int,
) -> None

Prefill dataclass

Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
@dataclass
class Prefill:
    # Sequence lengths (context + query) for prefill requests
    # Shape: [num_prefill_reqs]
    seq_lens: torch.Tensor

    # Request ID for each token: -1 for decode tokens, request index
    # (0, 1, 2, ...) for prefill tokens.
    # Shape: [num_actual_tokens]
    request_ids: torch.Tensor

    # Workspace start offsets for all prefill requests
    # Shape: [num_prefill_reqs], adjusted in-place per chunk to be
    # 0-indexed within each chunk. Used to map prefill tokens to workspace
    # offsets in convert_logical_index_to_physical_index
    workspace_starts: torch.Tensor

    @dataclass
    class Chunk:
        """Metadata for a chunk of prefill requests.

        Prefill requests may be chunked to fit within the fixed workspace size.
        """

        seq_lens: torch.Tensor
        tokens_slice: slice
        block_table: torch.Tensor
        req_start_idx: int
        workspace_starts: torch.Tensor
        chunk_tot_seqlen: int

    chunks: list[Chunk]
chunks instance-attribute
chunks: list[Chunk]
request_ids instance-attribute
request_ids: Tensor
seq_lens instance-attribute
seq_lens: Tensor
workspace_starts instance-attribute
workspace_starts: Tensor
Chunk dataclass

Metadata for a chunk of prefill requests.

Prefill requests may be chunked to fit within the fixed workspace size.

Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
@dataclass
class Chunk:
    """Metadata for a chunk of prefill requests.

    Prefill requests may be chunked to fit within the fixed workspace size.
    """

    seq_lens: torch.Tensor
    tokens_slice: slice
    block_table: torch.Tensor
    req_start_idx: int
    workspace_starts: torch.Tensor
    chunk_tot_seqlen: int
block_table instance-attribute
block_table: Tensor
chunk_tot_seqlen instance-attribute
chunk_tot_seqlen: int
req_start_idx instance-attribute
req_start_idx: int
seq_lens instance-attribute
seq_lens: Tensor
tokens_slice instance-attribute
tokens_slice: slice
workspace_starts instance-attribute
workspace_starts: Tensor
__init__
__init__(
    seq_lens: Tensor,
    tokens_slice: slice,
    block_table: Tensor,
    req_start_idx: int,
    workspace_starts: Tensor,
    chunk_tot_seqlen: int,
) -> None
__init__
__init__(
    seq_lens: Tensor,
    request_ids: Tensor,
    workspace_starts: Tensor,
    chunks: list[Chunk],
) -> None

__init__

__init__(
    num_prefills: int = 0,
    num_decodes: int = 0,
    num_prefill_tokens: int = 0,
    num_decode_tokens: int = 0,
    decode: Decode | None = None,
    prefill: Prefill | None = None,
) -> None

__init__

__init__(
    num_reqs: int,
    max_query_len: int,
    max_seq_len: int,
    num_actual_tokens: int,
    query_start_loc: Tensor,
    slot_mapping: Tensor,
    block_table: Tensor,
    req_id_per_token: Tensor,
    block_size: int = 64,
    topk_tokens: int = 2048,
    fp8_extra_metadata: FP8SeparatePrefillDecode
    | FP8KernelMetadata
    | None = None,
    fp8_use_mixed_batch: bool = False,
) -> None
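A hypothetical, minimal construction for the BF16-cache path (values and dtypes are illustrative only; the FP8 extras keep their defaults):

import torch

meta = FlashMLASparseMetadata(
    num_reqs=2,
    max_query_len=4,
    max_seq_len=128,
    num_actual_tokens=6,
    query_start_loc=torch.tensor([0, 4, 6], dtype=torch.int32),
    slot_mapping=torch.arange(6, dtype=torch.int64),
    block_table=torch.zeros((2, 2), dtype=torch.int32),
    req_id_per_token=torch.tensor([0, 0, 0, 0, 1, 1], dtype=torch.int32),
    block_size=64,
    topk_tokens=2048,
)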

FlashMLASparseMetadataBuilder

Bases: AttentionMetadataBuilder[FlashMLASparseMetadata]

Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetadata]):
    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ) -> None:
        self.vllm_config = vllm_config
        self.layer_names = layer_names
        cache_config = vllm_config.cache_config
        self.kv_cache_spec = kv_cache_spec
        self.model_config = vllm_config.model_config
        parallel_config = vllm_config.parallel_config
        self.device = device

        # Treat requests with query length <= 1 as decodes to match the
        # DeepGEMM indexer constraint (fp8_paged_mqa_logits only supports next_n <= 2)
        self._init_reorder_batch_threshold(1, supports_spec_as_decode=True)

        props = torch.cuda.get_device_properties(device)
        sm_count = props.multi_processor_count

        self.num_heads = self.model_config.get_num_attention_heads(parallel_config)
        self.mla_dims = get_mla_dims(self.model_config)
        # FP8 decode kernel only supports h_q = 64 or 128, so we need to pad
        self.fp8_decode_padded_heads = (
            FlashMLASparseImpl._compute_fp8_decode_padded_heads(self.num_heads)
        )

        self.topk_tokens = vllm_config.model_config.hf_config.index_topk
        self.use_fp8_kv_cache = cache_config.cache_dtype == "fp8_ds_mla"
        max_num_seqs = vllm_config.scheduler_config.max_num_seqs
        # Shape: [max_num_seqs], all elements = topk_tokens (constant for full-CG)
        self.topk_tokens_tensor = torch.full(
            (max_num_seqs,), self.topk_tokens, device=device, dtype=torch.int32
        )
        # Shape: [max_num_seqs], all elements = max_model_len
        self.max_model_len_tensor = torch.full(
            (max_num_seqs,),
            self.model_config.max_model_len,
            device=device,
            dtype=torch.int32,
        )
        # this is ignored by `flash_mla_with_kvcache` if indices not None
        self.dummy_block_table = torch.empty(
            (max_num_seqs, 1), dtype=torch.int32, device=self.device
        )

        # Equation taken from FlashMLA/csrc/api/sparse_decode.h
        # For sparse FP8 decode, the formula depends on architecture:
        # - SM90 (Hopper): num_sm_parts = num_sms / s_q / (h_q/64)
        # - SM100 (Blackwell head64/head64x2): num_sm_parts = num_sms / s_q
        # - SM100 (Blackwell head128): num_sm_parts = num_sms / s_q / 2
        # For max buffer size, use s_q = 1 (the case that produces largest output)
        # Use padded head count since that's what will be passed to the kernel
        h_q = self.fp8_decode_padded_heads
        if current_platform.is_device_capability_family(100):
            # SM100 head64 or head64x2 uses full SM count
            max_num_sm_parts = sm_count
        else:
            # SM90 uses h_q/64 divisor
            max_num_sm_parts = sm_count // max(1, h_q // 64)
        self.tile_scheduler_metadata_buffer = torch.empty(
            # TileSchedulerMetaDataSize = 8
            # see: FlashMLA/csrc/params.h
            (max_num_sm_parts, 8),
            dtype=torch.int32,
            device=device,
        )
        # Sized for per-request batching (num_decodes + 1)
        self.num_splits_buffer = torch.empty(
            (max_num_seqs + 1,),
            dtype=torch.int32,
            device=device,
        )
        self.req_id_per_token_buffer = torch.empty(
            (vllm_config.scheduler_config.max_num_batched_tokens,),
            dtype=torch.int32,
            device=device,
        )

    def _build_fp8_mixed_decode_prefill(
        self,
        common_attn_metadata: CommonAttentionMetadata,
    ) -> "FlashMLASparseMetadata.FP8KernelMetadata":
        """Build FP8 metadata treating all tokens as one mixed batch.

        This matches main branch's approach and avoids the BF16 prefill kernel
        which has head padding overhead when num_heads is small (high TP case).
        """
        num_tokens = common_attn_metadata.num_actual_tokens

        # Use padded head count since that's what the kernel will see
        padded_heads = self.fp8_decode_padded_heads

        # Build metadata for all tokens as a single batch
        scheduler_metadata, _ = get_mla_metadata(
            cache_seqlens=self.topk_tokens_tensor[:1],  # Single batch
            num_q_tokens_per_head_k=num_tokens * padded_heads,
            topk=self.topk_tokens,
            num_heads_q=padded_heads,
            num_heads_k=1,
            is_fp8_kvcache=True,
        )

        fp8_metadata = FlashMLASparseMetadata.FP8KernelMetadata(
            scheduler_metadata=scheduler_metadata,
            cache_lens=self.max_model_len_tensor[:1],
            dummy_block_table=self.dummy_block_table[:1],
        )

        return fp8_metadata

    def _build_fp8_separate_prefill_decode(
        self,
        common_attn_metadata: CommonAttentionMetadata,
    ) -> "FlashMLASparseMetadata.FP8SeparatePrefillDecode":
        num_tokens = common_attn_metadata.num_actual_tokens

        (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) = (
            split_decodes_and_prefills(
                common_attn_metadata,
                decode_threshold=self.reorder_batch_threshold or 1,
                require_uniform=True,
            )
        )

        FP8Meta = FlashMLASparseMetadata.FP8SeparatePrefillDecode
        fp8_metadata = FP8Meta(
            num_decodes=num_decodes,
            num_prefills=num_prefills,
            num_decode_tokens=num_decode_tokens,
            num_prefill_tokens=num_prefill_tokens,
        )

        # Extract prefill sequence lengths (context + query, not just query)
        # Decode requests come first in the batch, prefill requests follow
        prefill_seq_lens = None
        prefill_request_id = None
        prefill_workspace_starts = None
        prefill_chunks = None

        # For pure decode batches, prefill_request_id will be None
        # For mixed batches, it will have -1 for decode and request_id for prefill
        if num_prefills > 0:
            seq_lens_cpu = common_attn_metadata.seq_lens.cpu()
            seq_lens = common_attn_metadata.seq_lens
            query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu

            prefill_seq_lens_cpu = seq_lens_cpu[num_decodes:]
            prefill_seq_lens = seq_lens[num_decodes:]

            # Build prefill_request_id: -1 for decode, request index for
            # prefill. This enables a single
            # convert_logical_index_to_physical_index call for all tokens
            prefill_request_id = torch.full(
                (num_tokens,), -1, dtype=torch.int32, device=self.device
            )
            # Map prefill tokens to their request IDs (0, 1, 2, ...)
            for req_idx in range(num_prefills):
                # Get query token range for this prefill request
                global_req_idx = num_decodes + req_idx
                req_query_start = query_start_loc_cpu[global_req_idx]
                req_query_end = query_start_loc_cpu[global_req_idx + 1]
                prefill_request_id[req_query_start:req_query_end] = req_idx

            # will be adjusted by chunk loop
            prefill_workspace_starts_cpu = torch.zeros(
                num_prefills, dtype=torch.int32, pin_memory=True
            )
            prefill_workspace_starts_cpu[1:] = torch.cumsum(
                prefill_seq_lens_cpu[:-1], dim=0
            )
            # populated by non-blocking copy after prefill_workspace_starts_cpu is
            # updated by each chunk
            prefill_workspace_starts = torch.empty(
                num_prefills, dtype=torch.int32, device=self.device
            )

            # Chunk prefill requests to fit within workspace size
            max_prefill_buffer_size = get_prefill_workspace_size(
                self.vllm_config.model_config.max_model_len
            )
            chunk_bounds = split_prefill_chunks(
                prefill_seq_lens_cpu, max_prefill_buffer_size
            )

            prefill_chunks = []
            for chunk_start, chunk_end in chunk_bounds:
                # Adjust workspace_starts in-place per chunk to be
                # 0-indexed within each chunk
                # Example: seq_lens=[10,15,20,5], chunks=[[0,2],[2,4]]
                #   Initial: workspace_starts=[0,10,25,45]
                #   After:   workspace_starts=[0,10,0,20]
                #           (chunk 0 starts at 0, chunk 1 starts at 0)
                offset = prefill_workspace_starts_cpu[chunk_start].item()
                prefill_workspace_starts_cpu[chunk_start:chunk_end] -= offset

                chunk_seq_lens = prefill_seq_lens[chunk_start:chunk_end]
                chunk_tot_seqlen = prefill_seq_lens_cpu[chunk_start:chunk_end].sum()
                token_start = query_start_loc_cpu[num_decodes + chunk_start].item()
                token_end = query_start_loc_cpu[num_decodes + chunk_end].item()
                tokens_slice = slice(token_start, token_end)

                # Create chunk view of gpu tensor
                chunk_workspace_starts = prefill_workspace_starts[chunk_start:chunk_end]
                chunk_block_table = common_attn_metadata.block_table_tensor[
                    num_decodes + chunk_start : num_decodes + chunk_end
                ]

                prefill_chunks.append(
                    FP8Meta.Prefill.Chunk(
                        seq_lens=chunk_seq_lens,
                        tokens_slice=tokens_slice,
                        block_table=chunk_block_table,
                        req_start_idx=chunk_start,
                        workspace_starts=chunk_workspace_starts,
                        chunk_tot_seqlen=chunk_tot_seqlen,
                    )
                )

            prefill_workspace_starts.copy_(
                prefill_workspace_starts_cpu, non_blocking=True
            )

            fp8_metadata.prefill = FP8Meta.Prefill(
                seq_lens=prefill_seq_lens,
                request_ids=prefill_request_id,
                workspace_starts=prefill_workspace_starts,
                chunks=prefill_chunks,
            )

        if num_decodes > 0:
            # Compute decode_query_len for spec decode (uniform due to require_uniform)
            query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
            decode_query_len = (query_start_loc_cpu[1] - query_start_loc_cpu[0]).item()

            # Use padded head count since that's what the kernel will see
            padded_heads = self.fp8_decode_padded_heads
            scheduler_metadata, _ = get_mla_metadata(
                cache_seqlens=self.topk_tokens_tensor[:num_decodes],
                num_q_tokens_per_head_k=decode_query_len * padded_heads,
                topk=self.topk_tokens,
                num_heads_q=padded_heads,
                num_heads_k=1,
                is_fp8_kvcache=True,
            )

            kernel_meta = FlashMLASparseMetadata.FP8KernelMetadata(
                scheduler_metadata=scheduler_metadata,
                dummy_block_table=self.dummy_block_table[:num_decodes],
                cache_lens=self.max_model_len_tensor[:num_decodes],
            )
            fp8_metadata.decode = FP8Meta.Decode(
                kernel_metadata=kernel_meta,
                decode_query_len=decode_query_len,
            )

        return fp8_metadata

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> FlashMLASparseMetadata:
        cm = common_attn_metadata
        num_tokens = cm.num_actual_tokens
        starts = np.asarray(cm.query_start_loc_cpu, dtype=np.int32)
        seg_lengths = np.diff(starts)
        req_id_per_token = np.repeat(
            np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths
        )
        # Zero-fill for cudagraphs
        self.req_id_per_token_buffer.fill_(0)
        self.req_id_per_token_buffer[: req_id_per_token.shape[0]].copy_(
            torch.from_numpy(req_id_per_token), non_blocking=True
        )
        req_id_per_token = self.req_id_per_token_buffer[:num_tokens]

        fp8_extra_metadata: (
            FlashMLASparseMetadata.FP8SeparatePrefillDecode
            | FlashMLASparseMetadata.FP8KernelMetadata
            | None
        ) = None
        fp8_use_mixed_batch = self.num_heads < MIN_HEADS_FOR_BF16_PREFILL
        if self.use_fp8_kv_cache:
            if fp8_use_mixed_batch:
                fp8_extra_metadata = self._build_fp8_mixed_decode_prefill(cm)
            else:
                fp8_extra_metadata = self._build_fp8_separate_prefill_decode(cm)

        metadata = FlashMLASparseMetadata(
            num_reqs=cm.num_reqs,
            max_query_len=cm.max_query_len,
            max_seq_len=cm.max_seq_len,
            num_actual_tokens=cm.num_actual_tokens,
            query_start_loc=cm.query_start_loc,
            slot_mapping=cm.slot_mapping,
            block_table=cm.block_table_tensor,
            req_id_per_token=req_id_per_token,
            block_size=self.kv_cache_spec.block_size,
            topk_tokens=self.topk_tokens,
            fp8_extra_metadata=fp8_extra_metadata,
            fp8_use_mixed_batch=fp8_use_mixed_batch,
        )

        return metadata

_cudagraph_support class-attribute

_cudagraph_support: AttentionCGSupport = UNIFORM_BATCH

device instance-attribute

device = device

dummy_block_table instance-attribute

dummy_block_table = empty(
    (max_num_seqs, 1), dtype=int32, device=device
)

fp8_decode_padded_heads instance-attribute

fp8_decode_padded_heads = _compute_fp8_decode_padded_heads(
    num_heads
)

kv_cache_spec instance-attribute

kv_cache_spec = kv_cache_spec

layer_names instance-attribute

layer_names = layer_names

max_model_len_tensor instance-attribute

max_model_len_tensor = full(
    (max_num_seqs,),
    max_model_len,
    device=device,
    dtype=int32,
)

mla_dims instance-attribute

mla_dims = get_mla_dims(model_config)

model_config instance-attribute

model_config = model_config

num_heads instance-attribute

num_heads = get_num_attention_heads(parallel_config)

num_splits_buffer instance-attribute

num_splits_buffer = empty(
    (max_num_seqs + 1,), dtype=int32, device=device
)

req_id_per_token_buffer instance-attribute

req_id_per_token_buffer = empty(
    (max_num_batched_tokens,), dtype=int32, device=device
)

tile_scheduler_metadata_buffer instance-attribute

tile_scheduler_metadata_buffer = empty(
    (max_num_sm_parts, 8), dtype=int32, device=device
)

topk_tokens instance-attribute

topk_tokens = index_topk

topk_tokens_tensor instance-attribute

topk_tokens_tensor = full(
    (max_num_seqs,), topk_tokens, device=device, dtype=int32
)

use_fp8_kv_cache instance-attribute

use_fp8_kv_cache = cache_dtype == 'fp8_ds_mla'

vllm_config instance-attribute

vllm_config = vllm_config

__init__

__init__(
    kv_cache_spec: AttentionSpec,
    layer_names: list[str],
    vllm_config: VllmConfig,
    device: device,
) -> None
Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
def __init__(
    self,
    kv_cache_spec: AttentionSpec,
    layer_names: list[str],
    vllm_config: VllmConfig,
    device: torch.device,
) -> None:
    self.vllm_config = vllm_config
    self.layer_names = layer_names
    cache_config = vllm_config.cache_config
    self.kv_cache_spec = kv_cache_spec
    self.model_config = vllm_config.model_config
    parallel_config = vllm_config.parallel_config
    self.device = device

    # Treat requests with query length <= 1 as decodes to match the
    # DeepGEMM indexer constraint (fp8_paged_mqa_logits only supports next_n <= 2)
    self._init_reorder_batch_threshold(1, supports_spec_as_decode=True)

    props = torch.cuda.get_device_properties(device)
    sm_count = props.multi_processor_count

    self.num_heads = self.model_config.get_num_attention_heads(parallel_config)
    self.mla_dims = get_mla_dims(self.model_config)
    # FP8 decode kernel only supports h_q = 64 or 128, so we need to pad
    self.fp8_decode_padded_heads = (
        FlashMLASparseImpl._compute_fp8_decode_padded_heads(self.num_heads)
    )

    self.topk_tokens = vllm_config.model_config.hf_config.index_topk
    self.use_fp8_kv_cache = cache_config.cache_dtype == "fp8_ds_mla"
    max_num_seqs = vllm_config.scheduler_config.max_num_seqs
    # Shape: [max_num_seqs], all elements = topk_tokens (constant for full-CG)
    self.topk_tokens_tensor = torch.full(
        (max_num_seqs,), self.topk_tokens, device=device, dtype=torch.int32
    )
    # Shape: [max_num_seqs], all elements = max_model_len
    self.max_model_len_tensor = torch.full(
        (max_num_seqs,),
        self.model_config.max_model_len,
        device=device,
        dtype=torch.int32,
    )
    # this is ignored by `flash_mla_with_kvcache` if indices not None
    self.dummy_block_table = torch.empty(
        (max_num_seqs, 1), dtype=torch.int32, device=self.device
    )

    # Equation taken from FlashMLA/csrc/api/sparse_decode.h
    # For sparse FP8 decode, the formula depends on architecture:
    # - SM90 (Hopper): num_sm_parts = num_sms / s_q / (h_q/64)
    # - SM100 (Blackwell head64/head64x2): num_sm_parts = num_sms / s_q
    # - SM100 (Blackwell head128): num_sm_parts = num_sms / s_q / 2
    # For max buffer size, use s_q = 1 (the case that produces largest output)
    # Use padded head count since that's what will be passed to the kernel
    h_q = self.fp8_decode_padded_heads
    if current_platform.is_device_capability_family(100):
        # SM100 head64 or head64x2 uses full SM count
        max_num_sm_parts = sm_count
    else:
        # SM90 uses h_q/64 divisor
        max_num_sm_parts = sm_count // max(1, h_q // 64)
    self.tile_scheduler_metadata_buffer = torch.empty(
        # TileSchedulerMetaDataSize = 8
        # see: FlashMLA/csrc/params.h
        (max_num_sm_parts, 8),
        dtype=torch.int32,
        device=device,
    )
    # Sized for per-request batching (num_decodes + 1)
    self.num_splits_buffer = torch.empty(
        (max_num_seqs + 1,),
        dtype=torch.int32,
        device=device,
    )
    self.req_id_per_token_buffer = torch.empty(
        (vllm_config.scheduler_config.max_num_batched_tokens,),
        dtype=torch.int32,
        device=device,
    )
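
The architecture-dependent sizing of the tile-scheduler buffer above can be illustrated with a small, hedged sketch; the SM count and head values below are examples (132 is a common H100 SM count), not queried from a real device, and the helper name is made up for this sketch.

def max_sm_parts(sm_count: int, padded_heads: int, is_sm100_family: bool) -> int:
    # Mirrors the buffer-sizing logic in __init__ with s_q = 1:
    # SM100 uses the full SM count as an upper bound (head64 / head64x2),
    # SM90 divides by h_q / 64.
    if is_sm100_family:
        return sm_count
    return sm_count // max(1, padded_heads // 64)

print(max_sm_parts(132, 128, is_sm100_family=False))  # 66 SM partitions on SM90
print(max_sm_parts(132, 128, is_sm100_family=True))   # 132 SM partitions on SM100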

_build_fp8_mixed_decode_prefill

_build_fp8_mixed_decode_prefill(
    common_attn_metadata: CommonAttentionMetadata,
) -> FP8KernelMetadata

Build FP8 metadata treating all tokens as one mixed batch.

This matches the main branch's approach and avoids the BF16 prefill kernel, which incurs head-padding overhead when num_heads is small (the high-TP case).

Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
def _build_fp8_mixed_decode_prefill(
    self,
    common_attn_metadata: CommonAttentionMetadata,
) -> "FlashMLASparseMetadata.FP8KernelMetadata":
    """Build FP8 metadata treating all tokens as one mixed batch.

    This matches main branch's approach and avoids the BF16 prefill kernel
    which has head padding overhead when num_heads is small (high TP case).
    """
    num_tokens = common_attn_metadata.num_actual_tokens

    # Use padded head count since that's what the kernel will see
    padded_heads = self.fp8_decode_padded_heads

    # Build metadata for all tokens as a single batch
    scheduler_metadata, _ = get_mla_metadata(
        cache_seqlens=self.topk_tokens_tensor[:1],  # Single batch
        num_q_tokens_per_head_k=num_tokens * padded_heads,
        topk=self.topk_tokens,
        num_heads_q=padded_heads,
        num_heads_k=1,
        is_fp8_kvcache=True,
    )

    fp8_metadata = FlashMLASparseMetadata.FP8KernelMetadata(
        scheduler_metadata=scheduler_metadata,
        cache_lens=self.max_model_len_tensor[:1],
        dummy_block_table=self.dummy_block_table[:1],
    )

    return fp8_metadata
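
To make the contrast with the decode call in _build_fp8_separate_prefill_decode concrete, here is a hedged sketch of the query dimension handed to get_mla_metadata on each path; all counts are made-up example values.

# Example values only, not taken from a real run.
num_tokens, padded_heads = 48, 64         # mixed-batch path: whole batch
num_decodes, decode_query_len = 8, 2      # separate path: uniform decodes only

mixed_q_per_head_k = num_tokens * padded_heads           # 3072, single logical request
separate_q_per_head_k = decode_query_len * padded_heads  # 128, one row per decode request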

_build_fp8_separate_prefill_decode

_build_fp8_separate_prefill_decode(
    common_attn_metadata: CommonAttentionMetadata,
) -> FP8SeparatePrefillDecode
Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
def _build_fp8_separate_prefill_decode(
    self,
    common_attn_metadata: CommonAttentionMetadata,
) -> "FlashMLASparseMetadata.FP8SeparatePrefillDecode":
    num_tokens = common_attn_metadata.num_actual_tokens

    (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) = (
        split_decodes_and_prefills(
            common_attn_metadata,
            decode_threshold=self.reorder_batch_threshold or 1,
            require_uniform=True,
        )
    )

    FP8Meta = FlashMLASparseMetadata.FP8SeparatePrefillDecode
    fp8_metadata = FP8Meta(
        num_decodes=num_decodes,
        num_prefills=num_prefills,
        num_decode_tokens=num_decode_tokens,
        num_prefill_tokens=num_prefill_tokens,
    )

    # Extract prefill sequence lengths (context + query, not just query)
    # Decode requests come first in the batch, prefill requests follow
    prefill_seq_lens = None
    prefill_request_id = None
    prefill_workspace_starts = None
    prefill_chunks = None

    # For pure decode batches, prefill_request_id will be None
    # For mixed batches, it will have -1 for decode and request_id for prefill
    if num_prefills > 0:
        seq_lens_cpu = common_attn_metadata.seq_lens.cpu()
        seq_lens = common_attn_metadata.seq_lens
        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu

        prefill_seq_lens_cpu = seq_lens_cpu[num_decodes:]
        prefill_seq_lens = seq_lens[num_decodes:]

        # Build prefill_request_id: -1 for decode, request index for
        # prefill. This enables a single
        # convert_logical_index_to_physical_index call for all tokens
        prefill_request_id = torch.full(
            (num_tokens,), -1, dtype=torch.int32, device=self.device
        )
        # Map prefill tokens to their request IDs (0, 1, 2, ...)
        for req_idx in range(num_prefills):
            # Get query token range for this prefill request
            global_req_idx = num_decodes + req_idx
            req_query_start = query_start_loc_cpu[global_req_idx]
            req_query_end = query_start_loc_cpu[global_req_idx + 1]
            prefill_request_id[req_query_start:req_query_end] = req_idx

        # will be adjusted by chunk loop
        prefill_workspace_starts_cpu = torch.zeros(
            num_prefills, dtype=torch.int32, pin_memory=True
        )
        prefill_workspace_starts_cpu[1:] = torch.cumsum(
            prefill_seq_lens_cpu[:-1], dim=0
        )
        # populated by non-blocking copy after prefill_workspace_starts_cpu is
        # updated by each chunk
        prefill_workspace_starts = torch.empty(
            num_prefills, dtype=torch.int32, device=self.device
        )

        # Chunk prefill requests to fit within workspace size
        max_prefill_buffer_size = get_prefill_workspace_size(
            self.vllm_config.model_config.max_model_len
        )
        chunk_bounds = split_prefill_chunks(
            prefill_seq_lens_cpu, max_prefill_buffer_size
        )

        prefill_chunks = []
        for chunk_start, chunk_end in chunk_bounds:
            # Adjust workspace_starts in-place per chunk to be
            # 0-indexed within each chunk
            # Example: seq_lens=[10,15,20,5], chunks=[[0,2],[2,4]]
            #   Initial: workspace_starts=[0,10,25,45]
            #   After:   workspace_starts=[0,10,0,20]
            #           (chunk 0 starts at 0, chunk 1 starts at 0)
            offset = prefill_workspace_starts_cpu[chunk_start].item()
            prefill_workspace_starts_cpu[chunk_start:chunk_end] -= offset

            chunk_seq_lens = prefill_seq_lens[chunk_start:chunk_end]
            chunk_tot_seqlen = prefill_seq_lens_cpu[chunk_start:chunk_end].sum()
            token_start = query_start_loc_cpu[num_decodes + chunk_start].item()
            token_end = query_start_loc_cpu[num_decodes + chunk_end].item()
            tokens_slice = slice(token_start, token_end)

            # Create chunk view of gpu tensor
            chunk_workspace_starts = prefill_workspace_starts[chunk_start:chunk_end]
            chunk_block_table = common_attn_metadata.block_table_tensor[
                num_decodes + chunk_start : num_decodes + chunk_end
            ]

            prefill_chunks.append(
                FP8Meta.Prefill.Chunk(
                    seq_lens=chunk_seq_lens,
                    tokens_slice=tokens_slice,
                    block_table=chunk_block_table,
                    req_start_idx=chunk_start,
                    workspace_starts=chunk_workspace_starts,
                    chunk_tot_seqlen=chunk_tot_seqlen,
                )
            )

        prefill_workspace_starts.copy_(
            prefill_workspace_starts_cpu, non_blocking=True
        )

        fp8_metadata.prefill = FP8Meta.Prefill(
            seq_lens=prefill_seq_lens,
            request_ids=prefill_request_id,
            workspace_starts=prefill_workspace_starts,
            chunks=prefill_chunks,
        )

    if num_decodes > 0:
        # Compute decode_query_len for spec decode (uniform due to require_uniform)
        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
        decode_query_len = (query_start_loc_cpu[1] - query_start_loc_cpu[0]).item()

        # Use padded head count since that's what the kernel will see
        padded_heads = self.fp8_decode_padded_heads
        scheduler_metadata, _ = get_mla_metadata(
            cache_seqlens=self.topk_tokens_tensor[:num_decodes],
            num_q_tokens_per_head_k=decode_query_len * padded_heads,
            topk=self.topk_tokens,
            num_heads_q=padded_heads,
            num_heads_k=1,
            is_fp8_kvcache=True,
        )

        kernel_meta = FlashMLASparseMetadata.FP8KernelMetadata(
            scheduler_metadata=scheduler_metadata,
            dummy_block_table=self.dummy_block_table[:num_decodes],
            cache_lens=self.max_model_len_tensor[:num_decodes],
        )
        fp8_metadata.decode = FP8Meta.Decode(
            kernel_metadata=kernel_meta,
            decode_query_len=decode_query_len,
        )

    return fp8_metadata
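
The per-chunk re-basing of workspace starts above can be reproduced standalone; this hedged sketch reuses the example from the in-code comments and takes the chunk bounds as given rather than computing them with the real split_prefill_chunks helper.

import torch

prefill_seq_lens_cpu = torch.tensor([10, 15, 20, 5], dtype=torch.int32)
chunk_bounds = [(0, 2), (2, 4)]  # assumed output of split_prefill_chunks

workspace_starts = torch.zeros(4, dtype=torch.int32)
workspace_starts[1:] = torch.cumsum(prefill_seq_lens_cpu[:-1], dim=0)
print(workspace_starts.tolist())  # [0, 10, 25, 45]

for chunk_start, chunk_end in chunk_bounds:
    # Re-base each chunk so its first request starts at workspace offset 0.
    offset = workspace_starts[chunk_start].item()
    workspace_starts[chunk_start:chunk_end] -= offset

print(workspace_starts.tolist())  # [0, 10, 0, 20]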

build

build(
    common_prefix_len: int,
    common_attn_metadata: CommonAttentionMetadata,
    fast_build: bool = False,
) -> FlashMLASparseMetadata
Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
def build(
    self,
    common_prefix_len: int,
    common_attn_metadata: CommonAttentionMetadata,
    fast_build: bool = False,
) -> FlashMLASparseMetadata:
    cm = common_attn_metadata
    num_tokens = cm.num_actual_tokens
    starts = np.asarray(cm.query_start_loc_cpu, dtype=np.int32)
    seg_lengths = np.diff(starts)
    req_id_per_token = np.repeat(
        np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths
    )
    # Zero-fill for cudagraphs
    self.req_id_per_token_buffer.fill_(0)
    self.req_id_per_token_buffer[: req_id_per_token.shape[0]].copy_(
        torch.from_numpy(req_id_per_token), non_blocking=True
    )
    req_id_per_token = self.req_id_per_token_buffer[:num_tokens]

    fp8_extra_metadata: (
        FlashMLASparseMetadata.FP8SeparatePrefillDecode
        | FlashMLASparseMetadata.FP8KernelMetadata
        | None
    ) = None
    fp8_use_mixed_batch = self.num_heads < MIN_HEADS_FOR_BF16_PREFILL
    if self.use_fp8_kv_cache:
        if fp8_use_mixed_batch:
            fp8_extra_metadata = self._build_fp8_mixed_decode_prefill(cm)
        else:
            fp8_extra_metadata = self._build_fp8_separate_prefill_decode(cm)

    metadata = FlashMLASparseMetadata(
        num_reqs=cm.num_reqs,
        max_query_len=cm.max_query_len,
        max_seq_len=cm.max_seq_len,
        num_actual_tokens=cm.num_actual_tokens,
        query_start_loc=cm.query_start_loc,
        slot_mapping=cm.slot_mapping,
        block_table=cm.block_table_tensor,
        req_id_per_token=req_id_per_token,
        block_size=self.kv_cache_spec.block_size,
        topk_tokens=self.topk_tokens,
        fp8_extra_metadata=fp8_extra_metadata,
        fp8_use_mixed_batch=fp8_use_mixed_batch,
    )

    return metadata
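
As a hedged illustration of the FP8 path selection in build() above: the helper below and the 128-head / tensor-parallel figures are invented for the example, and the per-rank head count is assumed to be the total query heads divided across TP ranks.

MIN_HEADS_FOR_BF16_PREFILL = 32  # module-level constant, repeated here for the sketch

def uses_fp8_mixed_batch(total_query_heads: int, tp_size: int) -> bool:
    heads_per_rank = total_query_heads // tp_size
    return heads_per_rank < MIN_HEADS_FOR_BF16_PREFILL

print(uses_fp8_mixed_batch(128, 2))  # False: 64 heads/rank -> separate prefill/decode metadata
print(uses_fp8_mixed_batch(128, 8))  # True: 16 heads/rank -> single mixed-batch metadata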

_convert_req_index_to_global_index_kernel

_convert_req_index_to_global_index_kernel(
    req_id_ptr,
    block_table_ptr,
    token_indices_ptr,
    out_ptr,
    prefill_request_id_ptr,
    workspace_starts_ptr,
    max_num_blocks_per_req: constexpr,
    BLOCK_SIZE: constexpr,
    BLOCK_N: constexpr,
    HAS_PREFILL: constexpr,
    bt_stride0,
    bt_stride1,
    ti_stride0,
    ti_stride1,
    out_stride0,
    out_stride1,
)
Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
@triton.jit
def _convert_req_index_to_global_index_kernel(
    req_id_ptr,  # int32 [num_tokens]
    block_table_ptr,  # int32 [num_requests, max_num_blocks_per_req]
    token_indices_ptr,  # int32 [num_tokens, NUM_TOPK_TOKENS]
    out_ptr,  # int32 [num_tokens, NUM_TOPK_TOKENS]
    prefill_request_id_ptr,  # int32 [num_tokens], -1 for decode, >=0 for prefill
    workspace_starts_ptr,  # int32 [num_prefill_reqs+1] or nullptr
    # shapes (compile-time where possible)
    max_num_blocks_per_req: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    BLOCK_N: tl.constexpr,  # tile width along columns
    HAS_PREFILL: tl.constexpr,
    # strides (in elements)
    bt_stride0,
    bt_stride1,
    ti_stride0,
    ti_stride1,
    out_stride0,
    out_stride1,
):
    # program_id(0) -> token_id (row)
    # program_id(1) -> tile index along columns
    token_id = tl.program_id(0)
    tile_id = tl.program_id(1)

    # Each program covers BLOCK_N consecutive columns
    indice_id = tile_id * BLOCK_N + tl.arange(0, BLOCK_N)

    # Load request id for this token (no mask: grid is exact)
    req = tl.load(req_id_ptr + token_id)

    # Load token indices for this tile
    ti_ptr = token_indices_ptr + token_id * ti_stride0 + indice_id * ti_stride1
    tok = tl.load(ti_ptr)  # int32

    # Only token == -1 should propagate as -1
    is_invalid_tok = tok < 0
    is_prefill = False
    if HAS_PREFILL:
        prefill_req_id = tl.load(prefill_request_id_ptr + token_id)
        is_prefill = prefill_req_id >= 0
    # Compute block id and in-block offset
    block_id = tok // BLOCK_SIZE
    inblock_off = tok % BLOCK_SIZE

    # Guard block_table access
    valid_block = (block_id < max_num_blocks_per_req) & (block_id >= 0)
    bt_ptr = block_table_ptr + req * bt_stride0 + block_id * bt_stride1
    is_invalid_tok |= ~valid_block
    base = tl.load(bt_ptr, mask=valid_block & ~is_prefill, other=0)
    out_val = base * BLOCK_SIZE + inblock_off

    # Override with prefill output if prefill is enabled
    if HAS_PREFILL:
        workspace_start = tl.load(
            workspace_starts_ptr + prefill_req_id, mask=is_prefill, other=0
        )
        prefill_out = workspace_start + tok
        out_val = tl.where(is_prefill, prefill_out, out_val)
    out_val = tl.where(is_invalid_tok, -1, out_val)

    # Store results
    out_ptr_ij = out_ptr + token_id * out_stride0 + indice_id * out_stride1
    tl.store(out_ptr_ij, out_val)
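
The kernel's launch geometry (one program per token row and per BLOCK_N-wide column tile) can be sketched as follows; num_tokens is a made-up example, while NUM_TOPK_TOKENS and BLOCK_N are the defaults of the wrapper further below.

num_tokens = 48            # example value
NUM_TOPK_TOKENS = 2048     # wrapper default
BLOCK_N = 128              # wrapper default
grid = (num_tokens, NUM_TOPK_TOKENS // BLOCK_N)  # (48, 16); grid is exact, no column masking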

get_prefill_workspace_size

get_prefill_workspace_size(max_model_len: int)
Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
def get_prefill_workspace_size(max_model_len: int):
    # NOTE(Lucas): 5 is a magic number for controlling the prefill buffer size.
    # May be tuned later.
    # Memory usage: 5 * max_model_len * 576 * 2 bytes
    #   Example: DeepSeek-V3.2 with max_model_len=163840 ->
    #            5 * 163840 * 576 * 2 = ~900 MB
    # This fits nicely below the typical MoE workspace size of >2GB so this is "free"
    return max_model_len * 5
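
The memory-usage comment above can be checked with a few lines of arithmetic; head_dim=576 and 2 bytes per element follow the comment, and the function name exists only for this sketch.

def prefill_workspace_bytes(max_model_len: int, head_dim: int = 576, elem_bytes: int = 2) -> int:
    # 5 * max_model_len workspace tokens, each a 576-wide bf16 row.
    return 5 * max_model_len * head_dim * elem_bytes

print(prefill_workspace_bytes(163_840) / 2**20)  # 900.0 MiB, matching the example above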

triton_convert_req_index_to_global_index

triton_convert_req_index_to_global_index(
    req_id: Tensor,
    block_table: Tensor,
    token_indices: Tensor,
    BLOCK_SIZE: int = 64,
    NUM_TOPK_TOKENS: int = 2048,
    BLOCK_N: int = 128,
    HAS_PREFILL_WORKSPACE: bool = False,
    prefill_workspace_request_ids: Tensor | None = None,
    prefill_workspace_starts: Tensor | None = None,
)

out[token_id, indice_id] =
    block_table[req_id[token_id],
                token_indices[token_id, indice_id] // BLOCK_SIZE] * BLOCK_SIZE
    + token_indices[token_id, indice_id] % BLOCK_SIZE

Only when token_indices[token_id, indice_id] == -1 do we output -1. For safety, we also output -1 if the derived block_id would be out-of-bounds.

When HAS_PREFILL_WORKSPACE is True, prefill tokens are mapped to workspace offsets instead of global cache slots. prefill_workspace_request_ids and prefill_workspace_starts must be provided.

prefill_workspace_request_ids: int32 [num_tokens], -1 for decode, else the prefill request index (maps to prefill_workspace_starts)

prefill_workspace_starts: int32 [num_prefills], 0-indexed workspace starts for each prefill request

Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
def triton_convert_req_index_to_global_index(
    req_id: torch.Tensor,  # int32 [num_tokens]
    block_table: torch.Tensor,  # int32 [num_requests, max_num_blocks_per_req]
    token_indices: torch.Tensor,  # int32 [num_tokens, NUM_TOPK_TOKENS]
    BLOCK_SIZE: int = 64,
    NUM_TOPK_TOKENS: int = 2048,
    BLOCK_N: int = 128,  # tile width along columns
    HAS_PREFILL_WORKSPACE: bool = False,
    prefill_workspace_request_ids: torch.Tensor | None = None,
    prefill_workspace_starts: torch.Tensor | None = None,
):
    """
    out[token_id, indice_id] =
        block_table[req_id[token_id],
            token_indices[token_id, indice_id] // BLOCK_SIZE] * BLOCK_SIZE
        + token_indices[token_id, indice_id] % BLOCK_SIZE

    Only when token_indices[token_id, indice_id] == -1 do we output -1.
    For safety, we also output -1 if the derived block_id would be
        out-of-bounds.

    When HAS_PREFILL_WORKSPACE is True, prefill tokens are mapped to workspace offsets
    instead of global cache slots. prefill_workspace_request_ids and
    prefill_workspace_starts must be provided.

    prefill_workspace_request_ids: int32 [num_tokens], -1 for decode else
        prefill request index (maps to prefill_workspace_starts)
    prefill_workspace_starts: int32 [num_prefills], 0-indexed workspace
        starts for each prefill request
    """
    assert req_id.dtype == torch.int32
    assert block_table.dtype == torch.int32
    assert token_indices.dtype == torch.int32
    assert token_indices.shape[1] == NUM_TOPK_TOKENS
    assert NUM_TOPK_TOKENS % BLOCK_N == 0, (
        f"NUM_TOPK_TOKENS ({NUM_TOPK_TOKENS}) must be divisible by BLOCK_N ({BLOCK_N})"
    )

    if HAS_PREFILL_WORKSPACE:
        assert prefill_workspace_request_ids is not None
        assert prefill_workspace_starts is not None
        assert prefill_workspace_request_ids.dtype == torch.int32
        assert prefill_workspace_starts.dtype == torch.int32

    num_tokens = req_id.shape[0]
    max_num_blocks_per_req = block_table.shape[1]
    tiles_per_row = NUM_TOPK_TOKENS // BLOCK_N

    # Ensure contiguous tensors on the same device
    req_id_c = req_id.contiguous()
    block_table_c = block_table.contiguous()
    token_indices_c = token_indices.contiguous()
    out = torch.empty_like(token_indices_c)

    # Strides in elements
    bt_stride0, bt_stride1 = block_table_c.stride()
    ti_stride0, ti_stride1 = token_indices_c.stride()
    out_stride0, out_stride1 = out.stride()

    # Prepare prefill pointers
    if HAS_PREFILL_WORKSPACE:
        assert prefill_workspace_request_ids is not None  # for mypy
        assert prefill_workspace_starts is not None  # for mypy
        assert prefill_workspace_request_ids.is_contiguous()
        assert prefill_workspace_starts.is_contiguous()

    # Exact 2D grid: tokens × column tiles
    grid = (num_tokens, tiles_per_row)

    _convert_req_index_to_global_index_kernel[grid](
        req_id_c,
        block_table_c,
        token_indices_c,
        out,
        prefill_workspace_request_ids,
        prefill_workspace_starts,
        # shapes / constexprs
        max_num_blocks_per_req,
        BLOCK_SIZE,
        BLOCK_N,
        HAS_PREFILL_WORKSPACE,
        # strides
        bt_stride0,
        bt_stride1,
        ti_stride0,
        ti_stride1,
        out_stride0,
        out_stride1,
    )
    return out
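
For reference, the decode-side mapping computed by the kernel can be written as a slow, hedged PyTorch sketch; it follows the formula in the docstring, ignores the prefill-workspace branch, and is not the actual implementation.

import torch

def convert_req_index_to_global_index_ref(
    req_id: torch.Tensor,         # int32 [num_tokens]
    block_table: torch.Tensor,    # int32 [num_requests, max_num_blocks_per_req]
    token_indices: torch.Tensor,  # int32 [num_tokens, NUM_TOPK_TOKENS]
    block_size: int = 64,
) -> torch.Tensor:
    block_id = token_indices // block_size
    inblock_off = token_indices % block_size
    # -1 token indices (and out-of-range block ids) propagate as -1, as in the kernel.
    valid = (token_indices >= 0) & (block_id < block_table.shape[1])
    rows = req_id.unsqueeze(1).expand_as(block_id).long()
    cols = block_id.clamp(min=0).long()
    out = block_table[rows, cols] * block_size + inblock_off
    return torch.where(valid, out, torch.full_like(out, -1))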