Conversation

@samkaufman
Contributor

@samkaufman samkaufman commented Apr 29, 2024

This specializes bf16-f32 and bf16-bf16 vector-matrix multiplications to first convert bf16 vectors into f32 buffers of vector-length strips of even- and odd-indexed values.
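To make the even-odd layout concrete, here is a minimal scalar sketch of the conversion described above. It is not the PR's code: `BF16Bits`, `BF16ToF32`, `F32ToBF16Truncate`, and `ToEvenOddF32` are hypothetical names, and `kLanes` stands in for the number of f32 lanes in one SIMD vector. The assumption is that each strip of `2 * kLanes` bf16 inputs becomes `kLanes` even-indexed values followed by `kLanes` odd-indexed values.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>

// Hypothetical stand-in for a bf16 value: the upper 16 bits of an IEEE f32.
using BF16Bits = uint16_t;

// Widens one bf16 value to f32 by placing its bits in the upper half.
inline float BF16ToF32(BF16Bits bits) {
  const uint32_t widened = static_cast<uint32_t>(bits) << 16;
  float out;
  std::memcpy(&out, &widened, sizeof(out));
  return out;
}

// Truncating f32 -> bf16, only used to build inputs for this sketch.
inline BF16Bits F32ToBF16Truncate(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return static_cast<BF16Bits>(bits >> 16);
}

// Converts `num` bf16 values to f32, strip by strip. Each strip covers
// 2 * kLanes inputs: its first kLanes outputs are the even-indexed inputs,
// its next kLanes outputs are the odd-indexed inputs.
template <size_t kLanes>
void ToEvenOddF32(const BF16Bits* in, size_t num, float* out) {
  assert(num % (2 * kLanes) == 0);  // sketch assumes whole strips only
  for (size_t strip = 0; strip < num; strip += 2 * kLanes) {
    for (size_t i = 0; i < kLanes; ++i) {
      out[strip + i] = BF16ToF32(in[strip + 2 * i]);               // even
      out[strip + kLanes + i] = BF16ToF32(in[strip + 2 * i + 1]);  // odd
    }
  }
}
```

With `kLanes = 2`, the inputs 0 through 7 are stored as 0, 2, 1, 3, 4, 6, 5, 7. A dot product stays correct under this permutation as long as the other operand is read in the same even/odd order.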

The 2B, bf16 model running on my Zen 1 machine sees ~10% throughput improvements to single-threaded prefill, single-threaded generation, and multi-threaded prefill, but only a marginal improvement to multi-threaded generation throughput.

This PR does not implement support for SFP.

TODOs:

  • Check for native BF16 support.
  • Lift out allocations of f32 vectors. (@jan-wassenberg volunteered to handle this.)
  • Clean up some MatVecAdd code duplication.

Member

@jan-wassenberg jan-wassenberg left a comment

Nice, great to see this coming together :) Thanks for sending the PR. Some suggestions:

gemma/ops.h Outdated

const hn::ScalableTag<float> df;

const auto vec_dequant = hwy::AllocateAligned<float>(kInner);
Member

Allocation can be quite slow, let's move this into gemma.cc's Activations. That would require plumbing through an extra tmp arg here, and the std::array storage should probably be the largest per-call size * max number of threads (say 128 or 256). Would you prefer if I made this change?
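One possible shape for that preallocated storage, as a rough sketch only: a fixed `std::array` sized as the largest per-call size times the maximum thread count, with each thread taking its own slice. `Scratch`, `ForThread`, `kMaxPerCall`, and the value 1024 are hypothetical; only the "per-call size * max threads" sizing and the 128 figure come from the comment above.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

// Assumed sizes for illustration; the real values depend on the model.
constexpr size_t kMaxPerCall = 1024;  // largest f32 scratch one call needs
constexpr size_t kMaxThreads = 128;   // upper bound on worker threads

// Hypothetical holder for scratch space that is allocated once and reused,
// instead of calling an allocator inside every matrix-vector product.
struct Scratch {
  alignas(64) std::array<float, kMaxPerCall * kMaxThreads> storage;

  // Returns the slice owned by `thread`; slices never overlap, so threads
  // can write their converted vectors without synchronization.
  float* ForThread(size_t thread) {
    assert(thread < kMaxThreads);
    return storage.data() + thread * kMaxPerCall;
  }
};
```

The caller would then pass `ForThread(thread)` down as the extra tmp argument. At these sizes the struct is 512 KiB, so it should live in a long-lived object rather than on the stack of a hot function.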

Contributor Author

@jan-wassenberg Sure. Thanks for the help!

Member

Done in #173, see even_odd :)

@samkaufman
Contributor Author

samkaufman commented Apr 29, 2024

@jan-wassenberg Thank you for reviewing! I'll branch on native BF16 support and clean up those near-duplicate MatVecAdd implementations, then turn off this PR's draft bit.

@samkaufman
Contributor Author

@jan-wassenberg One more question: what's the best way to check that the target doesn't have native bf16 product/dot-product support (e.g., AVX512_BF16)? You previously pointed me at a highway PR, but it looks like Copybara scrubbed the branch when the PR was dropped.

@jan-wassenberg jan-wassenberg mentioned this pull request Apr 30, 2024
67 tasks
@jan-wassenberg
Member

@jan-wassenberg One more question: what's the best way to check that the target doesn't have native bf16 product/dot-product support (e.g., AVX512_BF16)? You previously pointed me at a highway PR, but it looks like Copybara scrubbed the branch when the PR was dropped.

We can check #if defined(HWY_NATIVE_DOT_BF16) && HWY_NATIVE_DOT_BF16 :)
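A minimal sketch of how that check could gate the conversion path. The macro test is the one quoted above; `kNativeDotBF16` and `UseEvenOddF32` are hypothetical names, and the Highway headers that would define the macro are deliberately left out, so in isolation this compiles to the non-native branch.

```cpp
// In real code, Highway's headers would be included before this point and
// would define HWY_NATIVE_DOT_BF16 on targets with native bf16 dot products.
#if defined(HWY_NATIVE_DOT_BF16) && HWY_NATIVE_DOT_BF16
constexpr bool kNativeDotBF16 = true;
#else
constexpr bool kNativeDotBF16 = false;
#endif

// Hypothetical helper: only convert bf16 vectors to even-odd f32 buffers
// when the target cannot multiply bf16 values natively.
constexpr bool UseEvenOddF32() { return !kNativeDotBF16; }
```

Testing `defined(...)` first keeps the condition well-formed when the macro is absent, without relying on undefined identifiers evaluating to 0 in `#if`.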

gemma/ops.h Outdated
// vector to even-odd layout.
template <bool kAdd, size_t kOuter, size_t kInner, typename ArrayT,
typename VecT, typename AddT,
std::enable_if_t<
Member

Consider replacing with HWY_IF_SAME2(VecT, float, hwy::bfloat16_t).

copybara-service bot pushed a commit that referenced this pull request Apr 30, 2024
Also inline ProjQ and ProjKV lambdas,
add missing includes/deps for ops_test.

PiperOrigin-RevId: 629460608
@samkaufman
Contributor Author

samkaufman commented Apr 30, 2024

@jan-wassenberg Done. Native bf16 checks added.

Additionally, 59ebecc fixes a bug I introduced in 6a78a23. That commit affected overload resolution such that the specialization was never called. That's now fixed by moving the bulk of MatVecAdd into detail::MatVecAddInner and selecting between even-odd and linear layouts inside a constexpr branch. Using a constexpr branch ensures that the selection is all downstream of MatVecAdd's type inference.
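The pattern can be sketched roughly as follows, assuming the constexpr in question is an `if constexpr`. This is not the code from 59ebecc: `BF16`, `Layout`, the simplified signatures, and the exact type condition are assumptions for illustration, and the inner function only reports which layout was chosen instead of doing arithmetic.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <type_traits>

struct BF16 { uint16_t bits; };  // hypothetical stand-in for hwy::bfloat16_t

enum class Layout { kLinear, kEvenOdd };

namespace detail {
// Holds the shared bulk of the work; the layout is a template parameter, so
// there is no second overload competing during overload resolution.
template <bool kEvenOdd, typename VecT>
Layout MatVecAddInner(const VecT* vec, size_t num) {
  assert(vec != nullptr);
  (void)num;  // a real implementation would loop over `num` elements
  return kEvenOdd ? Layout::kEvenOdd : Layout::kLinear;
}
}  // namespace detail

// Single entry point: MatT and VecT are deduced first, and only then does
// the compile-time branch pick a layout.
template <typename MatT, typename VecT>
Layout MatVecAdd(const MatT* mat, const VecT* vec, size_t num) {
  (void)mat;
  // Assumed condition: bf16 matrix with an f32 or bf16 vector.
  if constexpr (std::is_same_v<MatT, BF16> &&
                (std::is_same_v<VecT, float> || std::is_same_v<VecT, BF16>)) {
    return detail::MatVecAddInner</*kEvenOdd=*/true>(vec, num);
  } else {
    return detail::MatVecAddInner</*kEvenOdd=*/false>(vec, num);
  }
}
```

Because there is one `MatVecAdd` template rather than two overloads, a call can no longer silently bind to the generic version and skip the specialized path.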

@samkaufman
Contributor Author

I see even_odd storage is merged to dev. I'll merge.

@samkaufman samkaufman marked this pull request as ready for review April 30, 2024 23:24
Member

@jan-wassenberg jan-wassenberg left a comment

Nice, looks good to me, thanks for updating! Can you give it a quick sanity check also with sfp weights (e.g. those prefixed 1.1 on Kaggle) to make sure that also still works?

@samkaufman
Contributor Author

Can you give it a quick sanity check also with sfp weights (e.g. those prefixed 1.1 on Kaggle) to make sure that also still works?

Already did. Works great.

@jan-wassenberg
Member

Thanks for confirming :D

@jan-wassenberg jan-wassenberg added the copybara-import Trigger Copybara for merging pull requests label May 1, 2024
@jan-wassenberg
Member

Internal CI caught some unused vars:

third_party/gemma_cpp/gemma/ops.h:101:14: error: unused variable 'odd' [-Werror,-Wunused-variable]
  101 |   const auto odd = Set(du32, 0xFFFF0000u);
      |              ^~~
third_party/gemma_cpp/gemma/ops.h:364: error: unused variable 'df' [-Werror,-Wunused-variable]
  364 |   const hn::ScalableTag<float> df;
      |                                ^~
third_party/gemma_cpp/gemma/ops.h:366: error: unused variable 'kNumStrips' [-Werror,-Wunused-variable]
  366 |   constexpr size_t kNumStrips = kOuter / kRowsPerStrip;

Please fix :)

Contributor Author

@samkaufman samkaufman left a comment

Oops. Hopefully that sorts it.

@copybara-service copybara-service bot merged commit 6eeef2e into google:dev May 3, 2024
copybara-service bot pushed a commit that referenced this pull request May 7, 2024
Remove extra Dot() overload
MatVecAdd always adds, use MatVecT<kAdd> if conditional.
Remove unused MatVecAddLoop and MatVecLoop
No longer tsan-verify even_odd

PiperOrigin-RevId: 631377279
copybara-service bot pushed a commit that referenced this pull request May 8, 2024
Disable it for float32 because there is not enough benefit.

PiperOrigin-RevId: 631788326