forked from ggml-org/llama.cpp
DeepSeek-V3.2-Exp #9
Draft: createthis wants to merge 461 commits into master from deepseek_v3_2_exp
- Removed the forced CPU backend assignment of kvaware_indices:
  - src/llama-sparse-topk.cpp: deleted the block that moved the result to backend_cpu. It now stays on the backend where it is produced.
  - src/llama-model.cpp: removed both instances of ggml_backend_sched_set_tensor_backend(sched, kvaware_indices, backend_cpu), so we no longer bounce indices to the host in the MLA and MHA sparse paths.
- Gated the debug-only float32 cast of the indices:
  - src/llama-sparse-topk.cpp: only cast to F32 and log the F32 indices when LLAMA_SPARSE_DEBUG is set. This cuts extra graph nodes and copies in normal runs.
- Increased the default Top-K token tile size:
  - src/llama-sparse-topk.cpp: default TILE_T raised from 32 to 128, still overridable via LLAMA_SPARSE_TOPK_TILE_T.
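As a rough sketch of the env-var gating described above (the helper names are assumptions; this is not the actual code in src/llama-sparse-topk.cpp), both knobs boil down to a getenv check:

#include <cstdlib>

// Hypothetical helpers mirroring the two knobs named in this PR.
// LLAMA_SPARSE_DEBUG and LLAMA_SPARSE_TOPK_TILE_T come from the PR text;
// the function names and their placement are illustrative only.
static bool sparse_debug_enabled() {
    const char * v = std::getenv("LLAMA_SPARSE_DEBUG");
    return v != nullptr && v[0] != '\0';  // cast indices to F32 and log them only when set
}

static int sparse_topk_tile_t() {
    const char * v = std::getenv("LLAMA_SPARSE_TOPK_TILE_T");
    return (v != nullptr && v[0] != '\0') ? std::atoi(v) : 128;  // default raised from 32 to 128
}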
Truncated commit-message excerpts from the branch (leading text cut off in the commit list):
- …branches, so we avoid the extra backend hop to CPU after apply_sparse_attention_kvaware
- …is vendoring. Update .gitignore.
- …public or protected in order to have an external method call.
- …gymnastics so we can feed fp8 indexer data to the WMMA HGRP kernel.
- …inner loop now reads FP8 K codes instead of F32 K. The launch currently passes null FP8 pointers.
- …kernel by quantizing from F32 K inside ggml_cuda_indexer_logits_fused_device when WMMA is used and DeepGEMM is not
- …numerically aligned with the CPU reference
- …end-to-end at the GGML/llama level. The fused indexer op is still not consuming the sidecar in CUDA (that's the next step), but all the plumbing is there.
- …gather kernel took 0.17 ms. This one takes 0.019 ms. A clear win.
- …in merge commit 184076. This brings that code in. However, there is a problem: Radix Sort is turned off because GPT 5.1 thinks we will never have a tile row count high enough to use it. I believe this points to an architectural issue on our end, because I know Radix Sort is a critical performance feature of this kernel. I'm investigating.
- …useRadixSort true/false.
Don't merge. WIP.
When switching to this branch from deepseek_v3_2_exp_simple, you need to run:

Then recompile the project (I am assuming you have a single Blackwell 6000 Pro and 768 GB of system RAM):

You can add LLAMA_SPARSE_PROF=1 to get performance profiling of the kernels. There are a lot of different kernels in this branch gated by env vars. This is probably the most interesting config, as it uses the vendored vLLM top-k kernel:
LLAMA_FP8_INDEXER_CACHE=1 LLAMA_SPARSE_TOPK_VLLM=1 ./build/bin/llama-server \
  --model /data2/DeepSeek-V3.2-Exp-GGUF/q4_k_m/DeepSeek-V3.2-Exp-Q4_K_M-00001-of-00009.gguf \
  --alias DeepSeek-V3.2-Exp:671b-q4_k_m \
  --no-webui \
  --numa numactl \
  --threads 32 \
  --ctx-size 163840 \
  --n-gpu-layers 62 \
  -ot "blk\.(3|4|5|6|7|8|9)\.ffn_.*=CUDA0" \
  -ot exps=CPU \
  -ub 4096 -b 4096 \
  --seed 3407 \
  --temp 0.6 \
  --top-p 0.95 \
  --min-p 0.1 \
  --log-colors on \
  --flash-attn on \
  --host 0.0.0.0 \
  --prio 2 \
  --jinja \
  --port 11434

3.96 tok/s
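For per-kernel timing on the same run, the LLAMA_SPARSE_PROF variable mentioned above can simply be prepended to the invocation (illustrative; the remaining flags are the same as in the command above and are elided here):

LLAMA_SPARSE_PROF=1 LLAMA_FP8_INDEXER_CACHE=1 LLAMA_SPARSE_TOPK_VLLM=1 ./build/bin/llama-server ...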