Factor out deinterleaving of bf16 vectors for MatVecs. #166
Conversation
jan-wassenberg left a comment
Nice, great to see this coming together :) Thanks for sending the PR. Some suggestions:
gemma/ops.h (outdated):

    const hn::ScalableTag<float> df;
    const auto vec_dequant = hwy::AllocateAligned<float>(kInner);
Allocation can be quite slow, let's move this into gemma.cc's Activations. That would require plumbing through an extra tmp arg here, and the std::array storage should probably be the largest per-call size * max number of threads (say 128 or 256). Would you prefer if I made this change?
@jan-wassenberg Sure. Thanks for the help!
Done in #173, see even_odd :)
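A rough sketch of the preallocation idea described above, with hypothetical names (kMaxThreads, kMaxInner, and the Activations shape are illustrative assumptions, not the actual #173 code):

    #include <cstddef>

    #include "hwy/aligned_allocator.h"

    // Sketch: allocate the deinterleave scratch once, outside the MatVec
    // hot path. Size = largest per-call size * max number of threads;
    // both constants here are assumptions for illustration.
    constexpr size_t kMaxThreads = 128;

    template <size_t kMaxInner>
    struct Activations {
      // Thread t would use even_odd.get() + t * kMaxInner as its strip.
      hwy::AlignedFreeUniquePtr<float[]> even_odd =
          hwy::AllocateAligned<float>(kMaxInner * kMaxThreads);
    };

Paying the allocation cost once at startup keeps hwy::AllocateAligned out of the per-call path, which is the point of the suggestion.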
@jan-wassenberg Thank you for reviewing! I'll branch on native BF16 support and clean up those near-duplicate MatVecAdd implementations, then turn off this PR's draft bit.
@jan-wassenberg One more question: what's the best way to check that the target doesn't have native bf16 product/dot-product support (e.g., AVX512_BF16)? You previously pointed me at a Highway PR, but it looks like Copybara scrubbed the branch when the PR was dropped.
We can check
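A minimal sketch of such a check, assuming Highway's HWY_NATIVE_DOT_BF16 feature-test macro is the one meant (an assumption; the thread does not spell it out):

    #include "hwy/highway.h"

    // Assumption: HWY_NATIVE_DOT_BF16 is defined per target when bf16
    // pairs can be multiplied natively (e.g., AVX512_BF16); otherwise
    // fall back to promoting bf16 to f32 strips first.
    #if defined(HWY_NATIVE_DOT_BF16)
    // Use the native bf16 dot-product path directly.
    #else
    // Deinterleave bf16 into f32 even/odd strips and use f32 MulAdd.
    #endif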
gemma/ops.h (outdated):

    // vector to even-odd layout.
    template <bool kAdd, size_t kOuter, size_t kInner, typename ArrayT,
              typename VecT, typename AddT,
              std::enable_if_t<
Consider replacing with HWY_IF_SAME2(VecT, float, hwy::bfloat16_t).
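For reference, a sketch of what the constrained signature might look like with that macro, assuming HWY_IF_SAME2(T, U1, U2) enables the overload when T matches either listed type; the parameter list is abbreviated and illustrative, not the PR's actual signature:

    #include "hwy/base.h"  // hwy::bfloat16_t, HWY_IF_SAME2 (assumed)

    // Overload participates in resolution only when VecT is float or
    // hwy::bfloat16_t, replacing the std::enable_if_t tail above.
    template <bool kAdd, size_t kOuter, size_t kInner, typename ArrayT,
              typename VecT, typename AddT,
              HWY_IF_SAME2(VecT, float, hwy::bfloat16_t)>
    HWY_NOINLINE void MatVecAdd(const ArrayT& mat,
                                const VecT* HWY_RESTRICT vec,
                                const AddT* HWY_RESTRICT add,
                                float* HWY_RESTRICT out);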
Also inline ProjQ and ProjKV lambdas, add missing includes/deps for ops_test. PiperOrigin-RevId: 629460608
@jan-wassenberg Done. Native bf16 checks added. Additionally, 59ebecc fixes a bug I introduced in 6a78a23: that commit affected overload resolution such that the specialization was never called. That's now fixed by moving the bulk of MatVecAdd into detail::MatVecAddInner and selecting between even-odd and linear layouts inside an if constexpr. Using if constexpr ensures the choice is downstream of MatVecAdd's type inference.
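A minimal sketch of that dispatch shape (detail::MatVecAddInner is named in the comment, but the parameters and layout branches here are illustrative assumptions):

    #include <cstddef>
    #include <type_traits>

    #include "hwy/base.h"  // hwy::bfloat16_t

    namespace detail {

    template <bool kAdd, size_t kOuter, size_t kInner, typename ArrayT,
              typename VecT, typename AddT>
    void MatVecAddInner(const ArrayT& mat, const VecT* vec, const AddT* add,
                        float* even_odd, float* out) {
      // The branch lives inside the instantiated template, so it is
      // resolved after (downstream of) MatVecAdd's type inference.
      if constexpr (std::is_same_v<VecT, hwy::bfloat16_t>) {
        // bf16 vector: deinterleave into even/odd f32 strips, then MatVec.
      } else {
        // f32 vector: use the linear layout directly.
      }
    }

    }  // namespace detail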
I see even_odd storage is merged to dev. I'll merge.
jan-wassenberg left a comment
Nice, looks good to me, thanks for updating! Can you give it a quick sanity check also with sfp weights (e.g. those prefixed 1.1 on Kaggle) to make sure that also still works?
Already did. Works great.

Thanks for confirming :D

Internal CI caught some unused vars. Please fix :)
samkaufman left a comment
Oops. Hopefully that sorts it.
Remove extra Dot() overload. MatVecAdd always adds; use MatVecT<kAdd> if conditional. Remove unused MatVecAddLoop and MatVecLoop. No longer tsan-verify even_odd. PiperOrigin-RevId: 631377279
Disable it for float32 because there is not enough benefit. PiperOrigin-RevId: 631788326
This specializes bf16-f32 and bf16-bf16 vector-matrix multiplications to first convert bf16 vectors into f32 buffers laid out as vector-length strips of even- and odd-indexed values.
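A minimal sketch of the deinterleaving step, assuming Highway's PromoteEvenTo/PromoteOddTo in a static-dispatch build; buffer layout and names are illustrative, not the PR's actual code:

    #include <cstddef>

    #include "hwy/highway.h"

    namespace hn = hwy::HWY_NAMESPACE;

    // Split a bf16 vector into two f32 buffers: even-indexed lanes and
    // odd-indexed lanes. num is assumed to be a multiple of 2 * Lanes(df).
    void DeinterleaveBF16ToF32(const hwy::bfloat16_t* HWY_RESTRICT in,
                               size_t num, float* HWY_RESTRICT even,
                               float* HWY_RESTRICT odd) {
      const hn::ScalableTag<float> df;
      const hn::Repartition<hwy::bfloat16_t, decltype(df)> dbf;  // 2x lanes
      const size_t NF = hn::Lanes(df);
      for (size_t i = 0; i < num; i += 2 * NF) {
        const auto v = hn::LoadU(dbf, in + i);  // 2*NF bf16 values
        hn::StoreU(hn::PromoteEvenTo(df, v), df, even + i / 2);
        hn::StoreU(hn::PromoteOddTo(df, v), df, odd + i / 2);
      }
    }

Downstream dot products can then run plain f32 FMA over the even and odd strips and combine the partial sums, avoiding per-element bf16 conversion in the inner loop.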
The 2B, bf16 model running on my Zen 1 machine sees ~10% throughput improvements to single-threaded prefill, single-threaded generation, and multi-threaded prefill, but only a marginal improvement to multi-threaded generation throughput.
This PR does not implement support for SFP.
TODOs: