Conversation

@vivekkalyanarangan30

Description

This PR resolves Issue #497, enabling FastLanguageModel to correctly handle past_key_values during generation, specifically for multi-turn conversations where a new prompt is appended to an existing history.

The Problem

Previously, passing past_key_values to model.generate caused a RuntimeError (shape mismatch) or IndexError during the prefill phase (processing the new user prompt). This occurred for the following reasons (a sketch of the failing call pattern is shown after the list):

  1. Optimized Inference Path Assumption: Unsloth's LlamaModel_fast_forward_inference assumes a single-token input (q_len=1) for decoding. However, during the prefill step of a multi-turn conversation, the input contains multiple tokens (the new prompt), causing a shape mismatch in the attention mechanism.
  2. Missing/Incorrect position_ids: The _fast_prepare_inputs_for_generation function did not correctly slice or generate position_ids for the new tokens, leading to mismatches when transformers passed them to the model.
  3. Shape Mismatches: In some environments, transformers passed unsliced position_ids (matching the full sequence length) to the forward pass, causing crashes when the model expected position_ids matching the sliced input_ids.
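
A rough sketch of the failing call pattern, for orientation (the model name, prompts, and cache retrieval via return_dict_in_generate are illustrative assumptions, not code from the linked issue):

from unsloth import FastLanguageModel
import torch

# Illustrative reproduction sketch: turn 1 is generated normally; turn 2 appends a
# new user prompt to the history and reuses the KV cache from the first call.
model, tokenizer = FastLanguageModel.from_pretrained("unsloth/llama-3-8b-bnb-4bit")
FastLanguageModel.for_inference(model)

turn_1 = tokenizer("You are a helpful assistant.\nUser: Hi!", return_tensors = "pt").to("cuda")
out_1  = model.generate(**turn_1, max_new_tokens = 32, use_cache = True,
                        return_dict_in_generate = True)

# Turn 2: append the new prompt to the generated history and pass the cached state.
turn_2   = tokenizer("\nUser: And what is 2 + 2?", return_tensors = "pt").to("cuda")
full_ids = torch.cat([out_1.sequences, turn_2.input_ids], dim = -1)
out_2 = model.generate(input_ids = full_ids, max_new_tokens = 32,
                       past_key_values = out_1.past_key_values)  # crashed here before this fix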

The Solution

This PR implements a robust fix across llama.py and mistral.py:

  1. Conditional Fast Path: Modified CausalLM_fast_forward (Llama) and MistralForCausalLM_fast_forward (Mistral) to use the optimized single-token inference kernel only when input_ids.shape[1] == 1 (see the sketch after this list). For multi-token inputs (prefill), it falls back to the standard forward pass (which is still optimized with Unsloth's attention kernels but handles sequence processing correctly).
  2. Robust position_ids Handling: Added logic in _fast_prepare_inputs_for_generation to correctly slice input_ids and generate/slice position_ids to match the new tokens.
  3. Safety Slicing: Added a safety check in the forward pass of both Llama and Mistral models. If position_ids are passed with a length greater than input_ids (which can happen if transformers ignores the prepared inputs), they are automatically sliced to match the input length. This prevents RuntimeError and IndexError regardless of the upstream behavior.
  4. Cache Implementation Conflict: Fixed a ValueError where cache_implementation="dynamic" was being set even when past_key_values were provided.
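
A minimal sketch of the conditional dispatch in point 1 (simplified stand-ins, not the literal code in llama.py):

# Simplified sketch of the conditional fast path. `fast_single_token_decode` and
# `standard_forward` stand in for LlamaModel_fast_forward_inference and the
# regular (still Unsloth-accelerated) forward pass respectively.
def causal_lm_forward_sketch(input_ids, past_key_values, fast_single_token_decode, standard_forward):
    if past_key_values is not None and input_ids.shape[1] == 1:
        # Single-token decode step: keep the optimized inference kernel.
        return fast_single_token_decode(input_ids, past_key_values)
    # Multi-token prefill (e.g. a new prompt appended to cached history):
    # fall back to the standard forward pass, which handles longer sequences.
    return standard_forward(input_ids, past_key_values)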

Verification

  • New Test Case: Added tests/test_issue_497.py which reproduces the multi-turn conversation scenario and asserts that generation completes successfully with correct output.
  • Manual Verification: Verified that the fix works on Google Colab (T4 GPU) where the issue was originally reported.
  • Performance: Confirmed that single-token decoding still utilizes the optimized LlamaModel_fast_forward_inference kernel, ensuring no regression in generation speed.


@gemini-code-assist
Contributor

Summary of Changes

Hello @vivekkalyanarangan30, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the FastLanguageModel's capability to manage multi-turn conversations by rectifying issues related to past_key_values and position_ids during generation. The changes ensure that the model can seamlessly process new prompts appended to existing conversational history, leveraging optimized inference paths where appropriate while maintaining correctness and preventing common runtime errors.

Highlights

  • Multi-turn Conversation Support: Enables FastLanguageModel to correctly handle past_key_values during generation for multi-turn conversations, resolving previous RuntimeError and IndexError issues.
  • Conditional Fast Path for Inference: The optimized single-token inference kernel is now conditionally applied only when input_ids has a sequence length of 1, ensuring the model falls back to the standard forward pass for multi-token prefill steps.
  • Robust position_ids Handling: Implements correct slicing and generation of position_ids for new tokens when past_key_values are present, preventing mismatches and ensuring accurate positional encoding.
  • Safety Slicing for position_ids: Adds a safeguard to automatically slice position_ids if their length exceeds input_ids length, making the forward pass more robust against upstream transformers behavior.
  • cache_implementation Conflict Resolution: Fixes a ValueError by ensuring cache_implementation="dynamic" is only set when past_key_values are not already provided, preventing conflicts.
  • Comprehensive Verification: Includes a new dedicated test case (tests/test_issue_497.py), manual GPU verification on Google Colab, and confirmation of no performance regression for standard single-token decoding.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request provides a comprehensive and well-described fix for handling past_key_values in multi-turn conversations, addressing shape mismatches and position_ids issues. The changes are well-structured, and the inclusion of a dedicated test case is excellent for ensuring the fix is robust and prevents regressions. My review includes a critical bug fix for a potential NameError and a suggestion to remove redundant code to improve maintainability.


@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

"dtype": self.dtype,
"cache_position": torch.arange(
cache_length, cache_length + 1, device = input_ids.device
),
"batch_size": bs,

P0: Reintroduce bs/cache_length before preparing attention mask

In _fast_prepare_inputs_for_generation, the kwargs built at these lines reference cache_length and bs, but the prior assignment bs, cache_length = input_ids.shape was removed. When past_key_values is supplied and the model implements _prepare_4d_causal_attention_mask_with_cache_position, this branch now raises UnboundLocalError before the forward call, so generation with cached keys/values will crash instead of running.
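
A minimal sketch of the fix being requested here, written as a hypothetical helper (not the actual function in llama.py):

import torch

def build_cache_position_kwargs(input_ids, dtype):
    # Hypothetical helper: unpack input_ids.shape first so `bs` and `cache_length`
    # are bound before the kwargs below reference them (the assignment flagged as
    # removed in the review above).
    bs, cache_length = input_ids.shape
    return {
        "dtype": dtype,
        "cache_position": torch.arange(
            cache_length, cache_length + 1, device = input_ids.device
        ),
        "batch_size": bs,
    }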


Comment on lines +1548 to +1552
if position_ids is not None:
    # Robust fix: Slice position_ids if it's longer than input_ids
    # Handle both 1D and 2D position_ids
    if position_ids.dim() == 2:
        if position_ids.shape[1] > input_ids.shape[1]:


P0: Define position_ids before non-classification PEFT forward

The newly added PEFT branch checks if position_ids is not None but position_ids is neither a parameter nor initialized in this function, so the guard itself raises NameError for every non-classification call. This prevents standard PEFT inference or training from reaching the underlying model at all, even when callers do not intend to pass position IDs.
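
One minimal way to satisfy that guard, sketched with an assumed signature (the real PeftModel_fast_forward and its underlying call may differ):

def peft_fast_forward_sketch(self, input_ids, *args, position_ids = None, **kwargs):
    # Hypothetical sketch: declare position_ids with a None default (or pull it
    # from kwargs) so the `if position_ids is not None` guard cannot raise NameError.
    if position_ids is not None:
        # Slice to match input_ids, handling both 1D and 2D position_ids.
        if position_ids.dim() == 2 and position_ids.shape[1] > input_ids.shape[1]:
            position_ids = position_ids[:, -input_ids.shape[1]:]
        elif position_ids.dim() == 1 and position_ids.shape[0] > input_ids.shape[1]:
            position_ids = position_ids[-input_ids.shape[1]:]
    return self.base_model(input_ids, *args, position_ids = position_ids, **kwargs)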


@danielhanchen
Contributor

@vivekkalyanarangan30 Thanks for this! So is the goal to allow yourself to pass in custom past_key_values tensors? Is this correct?

@vivekkalyanarangan30
Author

vivekkalyanarangan30 commented Nov 30, 2025

@danielhanchen yes, for a long system prompt the pre-computed KV cache can be passed in past_key_values, achieving prompt caching. That gives straight speedups and also unlocks techniques like speculative decoding.
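
As a rough illustration of that prompt-caching use case (a sketch with placeholder model name and prompts, not the benchmark script mentioned later in this thread):

import torch
from unsloth import FastLanguageModel

model, tokenizer = FastLanguageModel.from_pretrained("unsloth/llama-3-8b-bnb-4bit")
FastLanguageModel.for_inference(model)

# Pre-compute the KV cache for a long system prompt once.
long_system_prompt = "You are a support agent. Follow these policies: ..."  # placeholder
system_ids = tokenizer(long_system_prompt, return_tensors = "pt").input_ids.to("cuda")
with torch.no_grad():
    cache = model(system_ids, use_cache = True).past_key_values

# Reuse the cached system prompt when answering a user turn.
user_ids = tokenizer("\nUser: summarize the policies above.", return_tensors = "pt").input_ids.to("cuda")
full_ids = torch.cat([system_ids, user_ids], dim = -1)
output = model.generate(input_ids = full_ids, past_key_values = cache, max_new_tokens = 64)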

@danielhanchen
Contributor

/gemini review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request provides a comprehensive fix for handling past_key_values in multi-turn conversations, addressing a critical RuntimeError. The changes are well-structured and logical, including the introduction of a conditional fast path for single-token decoding, robust handling of position_ids, and a safety slicing mechanism. The addition of a dedicated test case (test_issue_497.py) is excellent for verifying the fix and preventing regressions. My main suggestion is to refactor a small piece of duplicated code into a shared utility function to improve maintainability. Overall, this is a high-quality contribution that significantly improves the model's generation capabilities.

Comment on lines +275 to +283
if position_ids is not None:
    # Robust fix: Slice position_ids if it's longer than input_ids
    # Handle both 1D and 2D position_ids
    if position_ids.dim() == 2:
        if position_ids.shape[1] > input_ids.shape[1]:
            position_ids = position_ids[:, -input_ids.shape[1] :]
    elif position_ids.dim() == 1:
        if position_ids.shape[0] > input_ids.shape[1]:
            position_ids = position_ids[-input_ids.shape[1] :]
Contributor


Severity: medium

This logic for robustly slicing position_ids is also present in unsloth/models/llama.py (in PeftModel_fast_forward). To improve maintainability and avoid code duplication, consider refactoring this block into a shared utility function.

For example, you could create a helper in a shared utility file:

def slice_position_ids_if_needed(position_ids, input_ids):
    """Slices position_ids to match the length of input_ids if they are longer."""
    if position_ids is None:
        return None
    
    if position_ids.dim() == 2:
        if position_ids.shape[1] > input_ids.shape[1]:
            return position_ids[:, -input_ids.shape[1]:]
    elif position_ids.dim() == 1:
        if position_ids.shape[0] > input_ids.shape[1]:
            return position_ids[-input_ids.shape[1]:]
    
    return position_ids

This would simplify the code here and in llama.py, making future maintenance easier.
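
The call sites in llama.py and mistral.py would then reduce to a single line, something like:

position_ids = slice_position_ids_if_needed(position_ids, input_ids)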


@danielhanchen shall I incorporate this review comment?

Collaborator

@Datta0 Datta0 left a comment


There's an effort underway to rework attention, so we might need to ensure that this works properly alongside those changes.

@vivekkalyanarangan30
Author

@Datta0 any idea when we can proceed with this?

@Datta0
Collaborator

Datta0 commented Dec 11, 2025

Hey @vivekkalyanarangan30
Thanks for the wonderful contributions. To top it off, if you can provide a notebook or script that does a small comparison between the baseline (default way) and passing in custom KV pairs, that would be of great help.
Once I test that, we can proceed :)

@vivekkalyanarangan30
Author

@Datta0 I have provided a script to test outputs and performance, hope this helps!



Development

Successfully merging this pull request may close these issues.

Allow passing in custom past_key_values
