vllm.model_executor.layers.fla.ops.chunk_scaled_dot_kkt
chunk_scaled_dot_kkt_fwd
 chunk_scaled_dot_kkt_fwd(
    k: Tensor,
    g: Tensor | None = None,
    beta: Tensor | None = None,
    cu_seqlens: LongTensor | None = None,
    chunk_size: int = 64,
    output_dtype: dtype = float32,
) -> Tensor
Compute beta * K * K^T within each chunk of chunk_size tokens.
Parameters:
| Name | Type | Description | Default | 
|---|---|---|---|
| k | Tensor | The key tensor of shape | required |
| beta | Tensor | The beta tensor of shape | None |
| g | Tensor | The cumulative sum of the gate tensor of shape | None |
| cu_seqlens | LongTensor | The cumulative sequence lengths of the input tensor, used when variable-length sequences are packed along the token dimension. | None |
| chunk_size | int | The chunk size. | 64 |
| output_dtype | dtype | The dtype of the output tensor. | float32 |
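When cu_seqlens is given, the input holds several sequences packed one after another, with sequence i occupying positions cu_seqlens[i] up to (but not including) cu_seqlens[i + 1]. A minimal sketch, assuming chunks are formed per sequence so that no chunk crosses a sequence boundary, shows how such a packed layout can be enumerated into (sequence, chunk) pairs. The helper name chunk_indices_from_cu_seqlens is hypothetical and is not part of the vLLM API.

```python
def chunk_indices_from_cu_seqlens(
    cu_seqlens: list[int], chunk_size: int
) -> list[tuple[int, int]]:
    """Hypothetical helper: list (sequence index, chunk index) pairs.

    Sequence i occupies packed positions cu_seqlens[i] .. cu_seqlens[i + 1] - 1.
    Each sequence is split into ceil(len / chunk_size) chunks, so no chunk
    crosses a sequence boundary and a sequence's last chunk may be partial.
    """
    pairs: list[tuple[int, int]] = []
    for seq_idx in range(len(cu_seqlens) - 1):
        seq_len = cu_seqlens[seq_idx + 1] - cu_seqlens[seq_idx]
        num_chunks = (seq_len + chunk_size - 1) // chunk_size  # integer ceil
        for chunk_idx in range(num_chunks):
            pairs.append((seq_idx, chunk_idx))
    return pairs


# Two sequences of 5 and 3 tokens packed into 8 positions, chunk_size = 4.
print(chunk_indices_from_cu_seqlens([0, 5, 8], chunk_size=4))
# prints [(0, 0), (0, 1), (1, 0)]
```

The first sequence needs two chunks because 5 tokens exceed one chunk of 4; the second fits in a single partial chunk.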
Returns:
| Type | Description | 
|---|---|
| Tensor | beta * K * K^T of shape |
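To make the chunked computation concrete, here is a plain-Python reference for a single sequence and a single head. It is a sketch under assumptions that this page does not state: the output has one row per token and one column per position in a chunk; row t only covers tokens of its own chunk; only earlier tokens s < t are kept (strictly lower-triangular per chunk); row t is scaled by beta[t]; and when g is given, each entry is multiplied by exp(g[t] - g[s]). The name scaled_dot_kkt_reference is hypothetical; the real function works on tensors, not Python lists.

```python
import math
from typing import Optional


def scaled_dot_kkt_reference(
    k: list[list[float]],
    beta: list[float],
    g: Optional[list[float]] = None,
    chunk_size: int = 64,
) -> list[list[float]]:
    """Sketch of chunked beta * K * K^T for one sequence and one head.

    Returns a T x chunk_size matrix: entry [t][j] relates token t to token
    s = chunk_start + j of the same chunk. Assumption: only s < t is kept
    (strictly lower-triangular per chunk); all other entries stay zero.
    """
    T = len(k)
    A = [[0.0] * chunk_size for _ in range(T)]
    for t in range(T):
        chunk_start = (t // chunk_size) * chunk_size
        for s in range(chunk_start, t):  # earlier tokens in the same chunk
            score = beta[t] * sum(a * b for a, b in zip(k[t], k[s]))
            if g is not None:
                # g is a cumulative sum, so g[t] - g[s] is the gate accumulated
                # between positions s and t.
                score *= math.exp(g[t] - g[s])
            A[t][s - chunk_start] = score
    return A


k = [[1.0, 0.0], [1.0, 1.0], [0.0, 2.0]]
beta = [1.0, 0.5, 2.0]
for row in scaled_dot_kkt_reference(k, beta, chunk_size=4):
    print(row)
# prints [0.0, 0.0, 0.0, 0.0]
#        [0.5, 0.0, 0.0, 0.0]
#        [0.0, 4.0, 0.0, 0.0]
```

Because each row is restricted to its own chunk, the cost grows with T * chunk_size rather than T * T, which is the point of computing K * K^T chunk by chunk.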
Source code in vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py
chunk_scaled_dot_kkt_fwd_kernel
 chunk_scaled_dot_kkt_fwd_kernel(
    k,
    beta,
    g,
    A,
    cu_seqlens,
    chunk_indices,
    T,
    H: constexpr,
    Hg: constexpr,
    K: constexpr,
    BT: constexpr,
    BK: constexpr,
    IS_VARLEN: constexpr,
    USE_G: constexpr,
)
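The kernel is listed without a description. The constexpr parameters suggest a Triton kernel, and a plain-Python sketch can illustrate two ideas its parameters point to. Assumptions, none of which this page confirms: the kernel is launched with one program per (chunk, batch * head) pair; H is the number of heads of beta, g, and the output, while Hg is the number of key heads, with H a multiple of Hg so several heads share one key head; BK is the tile width used to loop over the key dimension K. The functions program_coordinates and tiled_chunk_dot are hypothetical and only mirror the index arithmetic and the tiled accumulation, without masking, gating, or varlen handling.

```python
def program_coordinates(i_t: int, i_bh: int, H: int, Hg: int) -> tuple[int, int, int, int]:
    """Map a (chunk, batch*head) program id pair to (chunk, batch, head, key head).

    Assumption: H heads are grouped onto Hg key heads, H % Hg == 0.
    """
    i_b, i_h = divmod(i_bh, H)
    i_hg = i_h // (H // Hg)  # consecutive groups of H // Hg heads share a key head
    return i_t, i_b, i_h, i_hg


def tiled_chunk_dot(k_chunk: list[list[float]], beta_chunk: list[float], BK: int) -> list[list[float]]:
    """Accumulate beta[t] * <k_t, k_s> over the key dimension in BK-wide tiles.

    k_chunk holds BT rows of K floats for one chunk and one key head.
    Summing partial dot products tile by tile gives the same result as one
    full dot product; this is the pattern a loop over ceil(K / BK) tiles uses.
    """
    BT, K = len(k_chunk), len(k_chunk[0])
    acc = [[0.0] * BT for _ in range(BT)]
    for start in range(0, K, BK):
        stop = min(start + BK, K)  # the last tile may be narrower than BK
        for t in range(BT):
            for s in range(BT):
                partial = sum(k_chunk[t][d] * k_chunk[s][d] for d in range(start, stop))
                acc[t][s] += beta_chunk[t] * partial
    return acc


print(program_coordinates(3, 5, H=4, Hg=2))  # prints (3, 1, 1, 0)
print(tiled_chunk_dot([[1.0, 2.0, 3.0], [0.0, 1.0, 1.0]], [1.0, 2.0], BK=2))
# prints [[14.0, 5.0], [10.0, 4.0]]
```

Tiling over K keeps each partial product small enough to stay in on-chip memory on a GPU, which is the usual reason for a separate BK block size alongside the chunk size BT.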