Fix: Support past_key_values in model.generate for multi-turn conversations #3653
Conversation
Vk/past key vals
for more information, see https://pre-commit.ci
Code Review
This pull request provides a comprehensive and well-described fix for handling past_key_values in multi-turn conversations, addressing shape mismatches and position_ids issues. The changes are well-structured, and the inclusion of a dedicated test case is excellent for ensuring the fix is robust and prevents regressions. My review includes a critical bug fix for a potential NameError and a suggestion to remove redundant code to improve maintainability.
💡 Codex Review
unsloth/unsloth/models/llama.py
Lines 215 to 219 in f60e842
| "dtype": self.dtype, | |
| "cache_position": torch.arange( | |
| cache_length, cache_length + 1, device = input_ids.device | |
| ), | |
| "batch_size": bs, |
In _fast_prepare_inputs_for_generation, the kwargs built at these lines reference cache_length and bs, but the prior assignment bs, cache_length = input_ids.shape was removed. When past_key_values is supplied and the model implements _prepare_4d_causal_attention_mask_with_cache_position, this branch now raises UnboundLocalError before the forward call, so generation with cached keys/values will crash instead of running.
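As a minimal sketch of the remedy Codex describes, the assignment can simply be restored before the kwargs are built. `build_cache_kwargs` below is a hypothetical standalone wrapper used only to show the ordering, not the actual `_fast_prepare_inputs_for_generation`:

```python
import torch

def build_cache_kwargs(input_ids, dtype):
    # Hypothetical illustration: derive bs and cache_length from input_ids
    # *before* the kwargs that reference them are constructed, so the
    # cached-KV branch cannot hit UnboundLocalError.
    bs, cache_length = input_ids.shape
    return {
        "dtype": dtype,
        "cache_position": torch.arange(
            cache_length, cache_length + 1, device = input_ids.device
        ),
        "batch_size": bs,
    }
```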
if position_ids is not None:
    # Robust fix: Slice position_ids if it's longer than input_ids
    # Handle both 1D and 2D position_ids
    if position_ids.dim() == 2:
        if position_ids.shape[1] > input_ids.shape[1]:
Define position_ids before non-classification PEFT forward
The newly added PEFT branch checks if position_ids is not None but position_ids is neither a parameter nor initialized in this function, so the guard itself raises NameError for every non-classification call. This prevents standard PEFT inference or training from reaching the underlying model at all, even when callers do not intend to pass position IDs.
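One hedged way to address this is to bind `position_ids` from `kwargs` (defaulting to `None`) before the guard runs. The function name and signature below are illustrative assumptions, not Unsloth's actual `PeftModel_fast_forward`:

```python
def forward_with_optional_position_ids(model, input_ids, **kwargs):
    # Bind position_ids with a None default so callers that omit it never
    # trigger a NameError in the guard below.
    position_ids = kwargs.pop("position_ids", None)
    if position_ids is not None and position_ids.shape[-1] > input_ids.shape[1]:
        # Trim to the current input length (works for 1D and 2D position_ids).
        position_ids = position_ids[..., -input_ids.shape[1]:]
    return model(input_ids = input_ids, position_ids = position_ids, **kwargs)
```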
Fix: Support `past_key_values` in `model.generate`
for more information, see https://pre-commit.ci
@vivekkalyanarangan30 Thanks for this! So is the goal to allow yourself to pass in custom `past_key_values`?
@danielhanchen yes, for a long system prompt the pre-computed KV cache can be passed here as `past_key_values`.
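For context, a hedged sketch of that usage pattern once the fix is in place; the model name, prompts, and generation settings below are illustrative placeholders:

```python
import torch
from unsloth import FastLanguageModel

# Illustrative checkpoint; any FastLanguageModel should follow the same pattern.
model, tokenizer = FastLanguageModel.from_pretrained(
    "unsloth/llama-3-8b-bnb-4bit", max_seq_length = 2048, load_in_4bit = True,
)
FastLanguageModel.for_inference(model)

# 1) Prefill the long system prompt once and keep its KV cache.
system_ids = tokenizer(
    "You are a helpful assistant. <long system prompt>", return_tensors = "pt"
).input_ids.to(model.device)
with torch.no_grad():
    prefill = model(input_ids = system_ids, use_cache = True)
past_key_values = prefill.past_key_values

# 2) On each turn, append the new user prompt and reuse the cached prefix.
turn_ids = tokenizer(
    "\nUser: What does this PR fix?", return_tensors = "pt", add_special_tokens = False
).input_ids.to(model.device)
full_ids = torch.cat([system_ids, turn_ids], dim = -1)
output = model.generate(
    input_ids = full_ids,
    past_key_values = past_key_values,
    max_new_tokens = 64,
)
print(tokenizer.decode(output[0], skip_special_tokens = True))
```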
/gemini review
Code Review
This pull request provides a comprehensive fix for handling past_key_values in multi-turn conversations, addressing a critical RuntimeError. The changes are well-structured and logical, including the introduction of a conditional fast path for single-token decoding, robust handling of position_ids, and a safety slicing mechanism. The addition of a dedicated test case (test_issue_497.py) is excellent for verifying the fix and preventing regressions. My main suggestion is to refactor a small piece of duplicated code into a shared utility function to improve maintainability. Overall, this is a high-quality contribution that significantly improves the model's generation capabilities.
if position_ids is not None:
    # Robust fix: Slice position_ids if it's longer than input_ids
    # Handle both 1D and 2D position_ids
    if position_ids.dim() == 2:
        if position_ids.shape[1] > input_ids.shape[1]:
            position_ids = position_ids[:, -input_ids.shape[1] :]
    elif position_ids.dim() == 1:
        if position_ids.shape[0] > input_ids.shape[1]:
            position_ids = position_ids[-input_ids.shape[1] :]
This logic for robustly slicing position_ids is also present in unsloth/models/llama.py (in PeftModel_fast_forward). To improve maintainability and avoid code duplication, consider refactoring this block into a shared utility function.
For example, you could create a helper in a shared utility file:
def slice_position_ids_if_needed(position_ids, input_ids):
    """Slices position_ids to match the length of input_ids if they are longer."""
    if position_ids is None:
        return None
    if position_ids.dim() == 2:
        if position_ids.shape[1] > input_ids.shape[1]:
            return position_ids[:, -input_ids.shape[1]:]
    elif position_ids.dim() == 1:
        if position_ids.shape[0] > input_ids.shape[1]:
            return position_ids[-input_ids.shape[1]:]
    return position_ids

This would simplify the code here and in llama.py, making future maintenance easier.
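If that refactor is adopted, a quick sanity check of the helper might look like this (assuming the function above is in scope; the shapes are arbitrary examples):

```python
import torch

# position_ids covering a full 10-token history are trimmed to match a
# 3-token prefill input; a None input passes through unchanged.
input_ids = torch.zeros((1, 3), dtype = torch.long)
position_ids = torch.arange(10).unsqueeze(0)  # shape (1, 10)
assert slice_position_ids_if_needed(position_ids, input_ids).shape == (1, 3)
assert slice_position_ids_if_needed(None, input_ids) is None
```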
@danielhanchen shall I incorporate this review comment?
Datta0 left a comment
There's an effort to rework attention, so we might need to ensure that this works properly alongside those changes.
@Datta0 any idea when we can proceed with this?
Hey @vivekkalyanarangan30
Added a KVCache performance comparison example
for more information, see https://pre-commit.ci
@Datta0 I have provided a script to test outputs and performance, hope this helps!
comparison file cleanup
for more information, see https://pre-commit.ci
Description
This PR resolves Issue #497, enabling `FastLanguageModel` to correctly handle `past_key_values` during generation, specifically for multi-turn conversations where a new prompt is appended to an existing history.

The Problem
Previously, passing `past_key_values` to `model.generate` caused a `RuntimeError` (shape mismatch) or `IndexError` during the prefill phase (processing the new user prompt). This occurred because:
- `LlamaModel_fast_forward_inference` assumes a single-token input (`q_len=1`) for decoding. However, during the prefill step of a multi-turn conversation, the input contains multiple tokens (the new prompt), causing a shape mismatch in the attention mechanism.
- The `_fast_prepare_inputs_for_generation` function did not correctly slice or generate `position_ids` for the new tokens, leading to mismatches when `transformers` passed them to the model.
- `transformers` passed unsliced `position_ids` (matching the full sequence length) to the forward pass, causing crashes when the model expected `position_ids` matching the sliced `input_ids`.

The Solution
This PR implements a robust fix across `llama.py` and `mistral.py`:
- Conditional fast path: Modified `CausalLM_fast_forward` (Llama) and `MistralForCausalLM_fast_forward` (Mistral) to only use the optimized single-token inference kernel when `input_ids.shape[1] == 1`. For multi-token inputs (prefill), it falls back to the standard forward pass (which is still optimized with Unsloth's attention kernels but handles sequence processing correctly); see the sketch after this list.
- `position_ids` handling: Added logic in `_fast_prepare_inputs_for_generation` to correctly slice `input_ids` and generate/slice `position_ids` to match the new tokens.
- Safety slicing: If `position_ids` are passed with a length greater than `input_ids` (which can happen if `transformers` ignores the prepared inputs), they are automatically sliced to match the input length. This prevents `RuntimeError` and `IndexError` regardless of the upstream behavior.
- Fixed a `ValueError` where `cache_implementation="dynamic"` was being set even when `past_key_values` were provided.
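A simplified sketch of that conditional dispatch; `use_fast_single_token_kernel` and `run_standard_forward` are placeholder names standing in for Unsloth's internal code paths, not real functions:

```python
def causal_lm_fast_forward(self, input_ids, past_key_values = None, **kwargs):
    # Decode step: exactly one new token, so the optimized single-token
    # inference kernel is safe to use.
    if past_key_values is not None and input_ids.shape[1] == 1:
        return use_fast_single_token_kernel(self, input_ids, past_key_values, **kwargs)
    # Prefill / multi-token input (e.g. a new user turn appended to cached
    # history): fall back to the standard forward pass.
    return run_standard_forward(self, input_ids, past_key_values = past_key_values, **kwargs)
```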
Verification
- Added `tests/test_issue_497.py`, which reproduces the multi-turn conversation scenario and asserts that generation completes successfully with correct output.
- Single-token decoding continues to use the optimized `LlamaModel_fast_forward_inference` kernel, ensuring no regression in generation speed.

Checklist
- Fixes Support `past_key_values` #497
- Tests added (`tests/test_issue_497.py`)