Draft
Changes from all commits
33 commits
0c5c04c
Experimental FP8 KV Cache. Not wired into anything yet.
createthis Dec 6, 2025
952cc97
Wired into the build, but not yet wired into DeepSeek V3.2
createthis Dec 6, 2025
513ea61
More FP8 KV changes
createthis Dec 7, 2025
8981f5a
Flesh out get_k
createthis Dec 7, 2025
a523479
Add a test for the fp8 kv cache
createthis Dec 7, 2025
49a0651
Test passing
createthis Dec 7, 2025
5112464
LLAMA_DEEPSEEK32_FP8_K=1 env var and wiring up.
createthis Dec 7, 2025
e952efa
Add the FP8 pack custom op hook and replacing the unsafe pointer‑base…
createthis Dec 7, 2025
3f17d34
Add GGML_OP_KV_DSMLA_PACK
createthis Dec 7, 2025
78da439
FP8 K is inferring again.
createthis Dec 7, 2025
18201a9
Vendor FlashMLA mla decode sm100 kernel.
createthis Dec 10, 2025
44d5c6f
Fix the newline situation. Bots. Sigh.
createthis Dec 10, 2025
696d060
Extend the printf profiling for flashmla
createthis Dec 10, 2025
7ea5669
Passing test.
createthis Dec 10, 2025
15bee54
Bring back the detail output. Bots. Sigh.
createthis Dec 10, 2025
e0e0f52
Merge branch 'deepseek_v3_2_exp_fp8_kv_cache' into deepseek_v3_2_exp_…
createthis Dec 10, 2025
d0185ca
- apply_sparse_attention_kvaware now accepts the FP8 KV blob
createthis Dec 11, 2025
0ac4361
Rip vendors/flashmla/sm100/decode/sparse_fp8/splitkv_mla.cu out of the
createthis Dec 11, 2025
63c581b
Fix build
createthis Dec 11, 2025
b190c19
Add sparse MLA decode glue code unit test.
createthis Dec 12, 2025
8e7701a
Fix LLAMA_SPARSE_DEBUG=1 during inference.
createthis Dec 12, 2025
7e91faf
Add tests/test-indexer-fp8-fused-glue.cpp
createthis Dec 12, 2025
0f6b45c
Add tests for fp8 indexer glue code. Still trying to track down
createthis Dec 12, 2025
a7cc37b
Add two more tests on our quest for accuracy:
createthis Dec 12, 2025
09fdef4
Two more tests
createthis Dec 12, 2025
bca3f42
- FP8 K quantization uses UE8M0, matching vLLM’s indexer_k_quant_and_…
createthis Dec 13, 2025
d4ac521
GPT 5.2 says these changes are important for correctness. It doesn't
createthis Dec 13, 2025
458a802
New LLAMA_INDEXER_FP8_TC=1 kernel. This is experimental and will
createthis Dec 14, 2025
8a0b70d
Add passing test for naive FP8 kernel. This will eventually become FP8
createthis Dec 14, 2025
686d84f
Add LLAMA_INDEXER_FP8_TC=1 FP8 tensor core MMA Lightning Indexer Kern…
createthis Dec 15, 2025
48f09a8
- Rewrote the FP8 cache quantization kernel to be shape-correct
createthis Dec 15, 2025
8a09b93
Remove .orig file. Oops. Didn't mean to commit that.
createthis Dec 15, 2025
3b8e6f5
Ensure that the FP8 indexer cache is not used if the env var is not set.
createthis Dec 15, 2025
15 changes: 15 additions & 0 deletions ggml/include/ggml-cuda-indexer.h
@@ -79,6 +79,21 @@ void ggml_cuda_mask_window_ends_device_to_host(struct ggml_backend_cuda_context
void ggml_cuda_mask_window_ends_device_to_host_simple(const float * dMask, int N_kv, int T, int * hEnds);
void ggml_cuda_mask_window_starts_device_to_host_simple(const float * dMask, int N_kv, int T, int * hStarts);


// Optional FP8 tensor-core lightning indexer launcher (DeepSeek V3.2)
// Expects FP8 K/Q and scales already prepared; computes logits [kv, Tc]
void ggml_cuda_indexer_logits_fp8_tc_hgrp_launch(
struct ggml_backend_cuda_context & ctx,
const unsigned char * K_fp8, // [kv, D] FP8 E4M3 codes
const float * K_sf, // [kv] per-row K scale (UE8M0)
const unsigned char * Q_fp8, // [Tc*H, D] FP8 E4M3 codes
const float * Q_sf, // [Tc*H] per-row Q scale (UE8M0)
const float * W, // [H, Tc]
const float * k_scale, // [kv]
int D, int H, int Tc, int kv,
const int * starts, const int * ends,
float * Out);

#ifdef __cplusplus
}
#endif
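
For orientation, a minimal call sketch (not part of this diff) of how the new launcher might be driven; the d_* buffer names and the concrete D/H/Tc values are illustrative assumptions, with shapes taken from the comments above.

// Hypothetical call site (illustrative only). All pointers are device pointers laid out
// per the comments in ggml-cuda-indexer.h; the d_* names are placeholders, not from this PR.
//   d_k_fp8 : kv * D bytes of FP8 E4M3 codes       d_k_sf   : kv floats (UE8M0 row scales)
//   d_q_fp8 : Tc * H * D bytes of FP8 E4M3 codes   d_q_sf   : Tc * H floats (UE8M0 row scales)
//   d_w     : H * Tc floats                        d_kscale : kv floats
//   d_starts / d_ends : per-column [start,end) windows over the kv axis
//   d_out   : kv * Tc floats, receives the indexer logits
ggml_cuda_indexer_logits_fp8_tc_hgrp_launch(
    ctx,
    d_k_fp8, d_k_sf,
    d_q_fp8, d_q_sf,
    d_w, d_kscale,
    /*D=*/128, /*H=*/64, /*Tc=*/1, /*kv=*/n_kv,   // DeepSeek V3.2 indexer shape; placeholder values
    d_starts, d_ends,
    d_out);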
7 changes: 7 additions & 0 deletions ggml/include/ggml.h
@@ -567,6 +567,7 @@ extern "C" {
GGML_OP_INDEXER_FUSED,
GGML_OP_SPARSE_MLA_DECODE,
GGML_OP_GLU,
GGML_OP_KV_DSMLA_PACK,

GGML_OP_INDEXER_K_CACHE_FP8,

@@ -764,6 +765,12 @@ extern "C" {


// Variant that accepts optional per-column windows [start,end)

GGML_API struct ggml_tensor * ggml_kv_dsmla_pack(
struct ggml_context * ctx,
struct ggml_tensor * k_latent_rope,
struct ggml_tensor * k_idxs,
struct ggml_tensor * k_blob);
GGML_API struct ggml_tensor * ggml_sparse_topk_radix_ex(
struct ggml_context * ctx,
struct ggml_tensor * scores,
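A hypothetical usage sketch (not taken from this diff): the new op would typically be recorded into the decode graph so that the per-token latent+RoPE K rows get packed into the persistent FP8 KV blob at the slots named by k_idxs; the tensor names below are assumptions.

// Illustrative only; k_latent_rope, k_idxs and kv_fp8_blob are placeholder tensor names.
struct ggml_tensor * packed = ggml_kv_dsmla_pack(ctx, k_latent_rope, k_idxs, kv_fp8_blob);
ggml_build_forward_expand(gf, packed); // evaluated on the CUDA backend, where the op is implemented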
5 changes: 5 additions & 0 deletions ggml/src/ggml-cpu/ggml-cpu.c
@@ -2362,6 +2362,11 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
{
n_tasks = n_threads;
} break;
case GGML_OP_KV_DSMLA_PACK:
{
// trivial metadata op for FP8 KV; handled only on CUDA backend
n_tasks = 1;
} break;
case GGML_OP_NONE:
{
n_tasks = 1;
41 changes: 41 additions & 0 deletions ggml/src/ggml-cuda/CMakeLists.txt
@@ -48,6 +48,7 @@ if (CUDAToolkit_FOUND)
list(APPEND GGML_SOURCES_CUDA "vllm_topk_per_row.cu")
list(APPEND GGML_SOURCES_CUDA "fp8_indexer_cache.cu")
list(FILTER GGML_SOURCES_CUDA EXCLUDE REGEX "ggml-cuda-deepgemm\.cu$")
list(FILTER GGML_SOURCES_CUDA EXCLUDE REGEX "indexer-fp8-tc\.cu$")
list(APPEND GGML_SOURCES_CUDA "topk-radix.cu" indexer-fused.cu sparse-mla-decode.cu)

file(GLOB SRCS "template-instances/fattn-mma*.cu")
@@ -77,6 +78,28 @@ if (CUDAToolkit_FOUND)

list(FILTER GGML_SOURCES_CUDA EXCLUDE REGEX "mqa_attn_return_logits_kernel\\.cu$")



# FlashMLA sparse FP8 MLA decode (SM100/SM120) as an OBJECT library (experimental)
add_library(flashmla_sparse_kernels OBJECT
${CMAKE_CURRENT_SOURCE_DIR}/vendors/flashmla/smxx/get_mla_metadata.cu
)
# Compile these kernels for SM120A (Blackwell / TCGen05), same as DeepGEMM and lightning kernels.
set_property(TARGET flashmla_sparse_kernels PROPERTY CUDA_ARCHITECTURES "120a")
set_property(TARGET flashmla_sparse_kernels PROPERTY POSITION_INDEPENDENT_CODE ON)
target_include_directories(flashmla_sparse_kernels PRIVATE
${CMAKE_SOURCE_DIR}/ggml/include
${CMAKE_SOURCE_DIR}/ggml/src
${CMAKE_SOURCE_DIR}/ggml/src/ggml-cuda
${CMAKE_CURRENT_SOURCE_DIR}/vendors/flashmla
${CMAKE_CURRENT_SOURCE_DIR}/vendors/flashmla/sm100
${CMAKE_CURRENT_SOURCE_DIR}/vendors/flashmla/smxx
${CMAKE_CURRENT_SOURCE_DIR}/vendors/cutlass/include
)
target_compile_options(flashmla_sparse_kernels PRIVATE
$<$<COMPILE_LANGUAGE:CUDA>:--expt-relaxed-constexpr --expt-extended-lambda --use_fast_math>)
target_sources(ggml-cuda PRIVATE $<TARGET_OBJECTS:flashmla_sparse_kernels>)

# DeepGEMM FP8 paged MQA logits (SM100-style) as a separate OBJECT library
add_library(deepgemm_kernels OBJECT
${CMAKE_CURRENT_SOURCE_DIR}/ggml-cuda-deepgemm.cu
@@ -97,6 +120,24 @@ if (CUDAToolkit_FOUND)
# Inject the compiled objects into ggml-cuda
target_sources(ggml-cuda PRIVATE $<TARGET_OBJECTS:deepgemm_kernels>)

# Dedicated FP8 TC indexer kernel (DeepSeek V3.2) as an OBJECT library,
# compiled only for Blackwell/TCGen05-style architectures (SM120A).
add_library(indexer_fp8_tc_kernels OBJECT
${CMAKE_CURRENT_SOURCE_DIR}/indexer-fp8-tc.cu
)
set_property(TARGET indexer_fp8_tc_kernels PROPERTY CUDA_ARCHITECTURES "120a")
set_property(TARGET indexer_fp8_tc_kernels PROPERTY POSITION_INDEPENDENT_CODE ON)
target_include_directories(indexer_fp8_tc_kernels PRIVATE
${CMAKE_SOURCE_DIR}/ggml/include
${CMAKE_SOURCE_DIR}/ggml/src
${CMAKE_SOURCE_DIR}/ggml/src/ggml-cuda
${CMAKE_CURRENT_SOURCE_DIR}/vendors/tilelang/fp8_lightning_indexer
${CMAKE_CURRENT_SOURCE_DIR}/vendors/cutlass/include
)
target_compile_options(indexer_fp8_tc_kernels PRIVATE
$<$<COMPILE_LANGUAGE:CUDA>:--expt-relaxed-constexpr --expt-extended-lambda --use_fast_math>)
target_sources(ggml-cuda PRIVATE $<TARGET_OBJECTS:indexer_fp8_tc_kernels>)

# Build the Lightning Indexer kernel as its own OBJECT library so we can
# give it a different CUDA arch than the rest of the project.
add_library(lightning_kernels OBJECT
63 changes: 63 additions & 0 deletions ggml/src/ggml-cuda/flashmla-sparse-mla.cu
@@ -0,0 +1,63 @@
#include "common.cuh"
#include "../../include/ggml-cuda-indexer.h"

// FlashMLA sparse MLA metadata helper is currently SM100-only. For now we
// exclude it from the SM120 build and always fall back to the GGML sparse MLA
// decode kernel. Once we have an SM120-compatible implementation we can
// re-enable this include and wire it appropriately.
// #include "vendors/flashmla/smxx/get_mla_metadata.h"

extern "C" void ggml_cuda_sparse_mla_decode_device(
ggml_backend_cuda_context & ctx,
const float * q,
const float * k,
const float * v,
const int32_t * topk,
int D, int Hq, int Hkv, int Dv,
int N, int Ksel,
float kq_scale, float softcap,
float * out);

extern "C" void ggml_cuda_sparse_mla_decode_flashmla_sm100(
ggml_backend_cuda_context & ctx,
const float * q,
const float * k,
const float * v,
const int32_t * topk,
const unsigned char * kv_blob,
int Dq,
int Hq,
int Hkv,
int Dv,
int Nkv,
int K,
float kq_scale,
float softcap,
float * out) {
    (void)kv_blob; // kv_blob is not consumed until the FlashMLA sparse FP8 path below is wired in
#if CUDART_VERSION < 12000
    (void)ctx; (void)q; (void)k; (void)v; (void)topk;
    (void)Dq; (void)Hq; (void)Hkv; (void)Dv; (void)Nkv; (void)K;
    (void)kq_scale; (void)softcap; (void)out;
    return;
#else
// For now, only handle DeepSeek V3.2-shaped decode (Dq=576, Dv=512, Hkv=1)
// and fall back to the GGML kernel otherwise.
if (Dq != 576 || Dv != 512 || Hkv != 1) {
ggml_cuda_sparse_mla_decode_device(ctx, q, k, v, topk,
Dq, Hq, Hkv, Dv, Nkv, K,
kq_scale, softcap, out);
return;
}

// TODO: Implement full FlashMLA SM100 sparse FP8 decode wiring here by
// constructing BF16 Q, DS-MLA FP8 kcache view, sparse indices, metadata
// (GetDecodingMetadataParams), and DecodingParams, then calling:
// sm100::run_flash_splitkv_mla_fp8_sparse_kernel(params, stream);
// run_flash_mla_combine_kernel<cutlass::bfloat16_t>(params, stream);
// For now, still delegate to the GGML F32 kernel until the mapping is
// completed.
ggml_cuda_sparse_mla_decode_device(ctx, q, k, v, topk,
Dq, Hq, Hkv, Dv, Nkv, K,
kq_scale, softcap, out);
#endif
}
88 changes: 49 additions & 39 deletions ggml/src/ggml-cuda/fp8_indexer_cache.cu
@@ -1,5 +1,6 @@
#include "common.cuh"
#include <cuda_fp16.h>
#include <cmath>
// FP8 indexer K cache quantization for DeepSeek V3.2.
// Layout matches vLLM's DeepseekV32IndexerCache indexer_k_quant_and_cache:
// kv_cache: [num_blocks, cache_block_size, cache_stride]
@@ -34,60 +35,71 @@ __global__ void k_indexer_fp8_quant_and_cache_kernel(
int quant_bs,
int cache_block_size,
int cache_stride) {
constexpr int VEC_SIZE = 4; // we process 4 floats (16B) at a time
const int64_t token_idx = blockIdx.x;
const int64_t head_dim_idx =
(blockIdx.y * blockDim.y * blockDim.x +
threadIdx.y * blockDim.x + threadIdx.x) * VEC_SIZE;
if (head_dim_idx >= head_dim) return;
// One warp processes one (token, quant_block) along head_dim.
// quant_bs is in bytes (== elements) for FP8 codes.
constexpr int VEC_SIZE = 4; // 4 floats per lane

const int64_t token_idx = (int64_t) blockIdx.x;
const int64_t slot = slot_map[token_idx];
if (slot < 0) return; // padded token
if (slot < 0) return;

const int64_t block_idx = slot / cache_block_size;
const int64_t block_offset = slot % cache_block_size;
// Load a vector of VEC_SIZE values from K
const int64_t k_offset_vec = (token_idx * (int64_t) head_dim + head_dim_idx) / VEC_SIZE;
float2 packed = reinterpret_cast<const float2*>(k)[k_offset_vec];
float *vals = reinterpret_cast<float*>(&packed);
// Compute local amax over this vector

const int64_t blocks_per_row = (head_dim + quant_bs - 1) / quant_bs;
const int64_t qblk = (int64_t) blockIdx.y;
if (qblk >= blocks_per_row) return;

const int64_t block_start = qblk * (int64_t) quant_bs;
const int64_t lane = (int64_t) (threadIdx.x & 31);

float v[VEC_SIZE] = {0.f, 0.f, 0.f, 0.f};
#pragma unroll
for (int i = 0; i < VEC_SIZE; ++i) {
const int64_t in_block = lane * VEC_SIZE + i;
const int64_t d = block_start + in_block;
if (in_block < quant_bs && d < head_dim) {
v[i] = k[token_idx * (int64_t) head_dim + d];
}
}

float amax = 0.0f;
#pragma unroll
for (int i = 0; i < VEC_SIZE; ++i) {
amax = fmaxf(amax, fabsf(vals[i]));
amax = fmaxf(amax, fabsf(v[i]));
}

#if __CUDA_ARCH__ >= 700
// Warp-wide reduction of amax within this quant block.
for (int mask = 16; mask > 0; mask >>= 1) {
#ifdef USE_ROCM
#ifdef USE_ROCM
amax = fmaxf(amax, __shfl_xor_sync(uint64_t(-1), amax, mask));
#else
#else
amax = fmaxf(amax, __shfl_xor_sync(unsigned(-1), amax, mask));
#endif
#endif
}
#endif

float scale = fmaxf(amax, 1e-4f) / 448.0f;
// Base offset of this block in kv_cache
const int64_t block_base =
block_idx * (int64_t) cache_block_size * cache_stride;
// FP8 values region: [cache_block_size * head_dim] bytes per block
const int64_t vals_base = block_base + block_offset * (int64_t) head_dim;
const int64_t dst_offset = vals_base + head_dim_idx;
scale = exp2f(ceilf(log2f(scale)));

const int64_t block_base = block_idx * (int64_t) cache_block_size * cache_stride;
const int64_t vals_base = block_base + block_offset * (int64_t) head_dim;

#pragma unroll
for (int i = 0; i < VEC_SIZE; ++i) {
if (head_dim_idx + i < head_dim) {
float v = vals[i];
float scaled = v / scale;
uint8_t code = f32_to_fp8e4m3(scaled);
kv_cache[dst_offset + i] = code;
const int64_t in_block = lane * VEC_SIZE + i;
const int64_t d = block_start + in_block;
if (in_block < quant_bs && d < head_dim) {
kv_cache[vals_base + d] = f32_to_fp8e4m3(v[i] / scale);
}
}
// Write FP32 scale for this quant block. The block index along head_dim is
// (block_offset * head_dim + head_dim_idx) / quant_bs.
if (threadIdx.x == 0 && threadIdx.y == 0) {
const int64_t block_linear = block_offset * (int64_t) head_dim + head_dim_idx;
const int64_t scale_block_idx = block_linear / quant_bs;

if (lane == 0) {
const int64_t scales_base = block_base + (int64_t) cache_block_size * head_dim;
const int64_t scale_byte_offset = scales_base + scale_block_idx * 4; // 4 bytes per FP32 scale
*reinterpret_cast<float*>(&kv_cache[scale_byte_offset]) = scale;
const int64_t scale_block_idx = block_offset * blocks_per_row + qblk;
const int64_t scale_byte_offset = scales_base + scale_block_idx * 4;
*reinterpret_cast<float *>(&kv_cache[scale_byte_offset]) = scale;

}
}
} // namespace ggml_cuda_fp8_indexer
@@ -102,10 +114,8 @@ extern "C" void ggml_cuda_indexer_k_cache_fp8_quantize(
int cache_block_size,
int cache_stride) {
cudaStream_t stream = ctx.stream();
constexpr int VEC_SIZE = 4;
dim3 grid(num_tokens,
(head_dim + quant_bs * VEC_SIZE - 1) / (quant_bs * VEC_SIZE));
dim3 block(32, VEC_SIZE);
dim3 grid(num_tokens, (head_dim + quant_bs - 1) / quant_bs);
dim3 block(32, 1);
ggml_cuda_fp8_indexer::k_indexer_fp8_quant_and_cache_kernel<<<grid, block, 0, stream>>>(
dK, dKvCache, dSlotMap, head_dim, quant_bs, cache_block_size, cache_stride);
}
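As a cross-check on the kernel above, here is a minimal host-side reference (not part of the PR) of the per-block math and cache layout it implements: each quant_bs-wide slice of a K row gets a UE8M0 (power-of-two) scale derived from its amax, the values are written as FP8 E4M3 codes into the first cache_block_size * head_dim bytes of the block, and the FP32 scales follow; the encode_e4m3 callback stands in for the device-side f32_to_fp8e4m3 helper.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstring>

// Host reference of what one (token, quant block) warp does on the GPU (illustrative only).
// block_base points at the start of the cache block the token's slot maps to, i.e.
// kv_cache + block_idx * cache_block_size * cache_stride.
template <typename EncodeE4M3>  // stand-in for f32_to_fp8e4m3; returns the 8-bit E4M3 code
void quantize_k_row_ref(const float * k_row, uint8_t * block_base,
                        int64_t block_offset, int head_dim, int quant_bs,
                        int cache_block_size, EncodeE4M3 encode_e4m3) {
    const int64_t blocks_per_row = (head_dim + quant_bs - 1) / quant_bs;
    uint8_t * vals   = block_base + block_offset * (int64_t) head_dim;        // FP8 codes region
    uint8_t * scales = block_base + (int64_t) cache_block_size * head_dim;    // FP32 scales region

    for (int64_t qblk = 0; qblk < blocks_per_row; ++qblk) {
        const int64_t d0 = qblk * quant_bs;
        const int64_t d1 = std::min<int64_t>(d0 + quant_bs, head_dim);

        float amax = 0.0f;
        for (int64_t d = d0; d < d1; ++d) {
            amax = std::max(amax, std::fabs(k_row[d]));
        }

        float scale = std::max(amax, 1e-4f) / 448.0f;      // 448 = largest finite E4M3 magnitude
        scale = std::exp2(std::ceil(std::log2(scale)));    // round up to a power of two (UE8M0)

        for (int64_t d = d0; d < d1; ++d) {
            vals[d] = encode_e4m3(k_row[d] / scale);
        }

        const int64_t scale_idx = block_offset * blocks_per_row + qblk;       // one FP32 scale per quant block
        std::memcpy(scales + scale_idx * 4, &scale, sizeof(float));
    }
}

The power-of-two (UE8M0) constraint on the scale matches the vLLM indexer_k_quant_and_cache layout referenced in the file header, so dequantization reduces to an exponent adjustment rather than a full multiply.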