
Conversation

@sambhavnoobcoder
Contributor

Implement ensure_weight_tying for trainable_token_indices

Summary

This PR implements consistent weight tying behavior for trainable_token_indices as specified in issue #2864. It extends the ensure_weight_tying parameter (introduced in PR #2803) to work with trainable_token_indices, providing users explicit control over weight tying between embeddings and LM head.

Fixes #2864 (trainable_token_indices portion)


Problem Statement

Background

PEFT models sometimes need to handle tied weights between embedding layers and LM head layers (when tie_word_embeddings=True). The ensure_weight_tying parameter was introduced in PR #2803 to give users explicit control over this behavior for modules_to_save. However, the same control was missing for trainable_token_indices.

The Issue

The issue identified that the weight tying behavior for trainable_token_indices was not consistent across different scenarios. Specifically, there were four cases that needed to be implemented (a minimal configuration sketch follows the list):

  1. Untied model with ensure_weight_tying=True: Should warn users that weight tying cannot be applied
  2. Tied model with ensure_weight_tying=True and different indices: Should error, as it's impossible to tie adapters with different token indices
  3. Tied model with ensure_weight_tying=False and different indices: Should treat layers as separate (backwards compatibility behavior)
  4. Tied model with ensure_weight_tying=True and same indices: Should apply weight tying correctly
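The cases above refer to the user-facing LoraConfig options. A minimal sketch of the dict-format configuration (a sketch only; the tiny test model and target_modules below are illustrative assumptions, not part of this PR):

    from transformers import AutoModelForCausalLM
    from peft import LoraConfig, get_peft_model

    base_model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-OPTForCausalLM")

    config = LoraConfig(
        target_modules=["q_proj", "v_proj"],
        # dict format: per-layer token indices for the trainable tokens adapter
        trainable_token_indices={"embed_tokens": [0, 1, 2], "lm_head": [0, 1, 2]},
        # request explicit weight tying between the embedding and lm_head adapters (case 4)
        ensure_weight_tying=True,
    )
    peft_model = get_peft_model(base_model, config)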

Solution Approach

Implementation Strategy:

  1. Check weight tying configuration early (before creating wrappers)
  2. Detect if user specified both embedding and lm_head layers in dict format
  3. Check if their token indices match or differ
  4. Apply appropriate logic based on the configuration matrix from the issue
  5. Skip creating wrappers for layers that will be tied later

Changes Made

1. Updated Configuration Documentation

File: src/peft/tuners/lora/config.py

Updated the ensure_weight_tying parameter docstring to clarify that it now applies to both modules_to_save and trainable_token_indices, making the documentation consistent with the implementation.

2. Implemented Weight Tying Logic

File: src/peft/utils/other.py

Added comprehensive logic within the existing trainable_token_indices handling block:

Key Components:

  • Early Detection: Check weight tying configuration before creating any wrappers
  • Layer Detection: Identify if both embedding and lm_head layers are specified
  • Index Comparison: Determine if token indices match between the layers
  • Skip Logic: Prevent double-wrapping by skipping layers that will be tied
  • Warning System: Inform users when their configuration cannot be applied
  • Error Handling: Raise clear errors for contradictory configurations
  • Backwards Compatibility: Preserve existing behavior when ensure_weight_tying=False

Four Cases Implemented (a simplified dispatch sketch follows the list):

  1. Case 1 - Warning for Untied Models:

    • When: weights_tied=False + ensure_weight_tying=True
    • Action: Issue warning that weight tying cannot be applied
    • Rationale: Model doesn't have tied weights, so user's request cannot be fulfilled
  2. Case 2 - Error for Contradictory Configuration:

    • When: weights_tied=True + ensure_weight_tying=True + different indices
    • Action: Raise ValueError with clear explanation
    • Rationale: Cannot tie adapters that operate on different token indices
  3. Case 3 - Backwards Compatibility:

    • When: weights_tied=True + ensure_weight_tying=False + different indices
    • Action: Treat layers as separate (no tying)
    • Rationale: User explicitly opted out, respect their choice even if model supports tying
  4. Case 4 - Apply Tying:

    • When: Other combinations where tying is appropriate
    • Action: Create tied adapters that share parameters
    • Rationale: Normal weight tying behavior
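A simplified, self-contained sketch of the dispatch over these four cases (not the actual PEFT code; the warning and error texts are taken from the tests quoted later in this thread, and the flag names mirror the variables discussed in the review):

    import warnings

    def resolve_weight_tying(weights_tied, ensure_weight_tying, indices_mismatch):
        """Return True if the tied-adapter path should be taken."""
        if not weights_tied:
            if ensure_weight_tying:
                # Case 1: the model has no tied weights, so the request cannot be honored
                warnings.warn("ensure_weight_tying=True but the model does not have tied weights")
            return False
        if indices_mismatch:
            if ensure_weight_tying:
                # Case 2: contradictory configuration
                raise ValueError("Cannot ensure weight tying when different token indices are specified")
            # Case 3: backwards compatibility, treat the layers as separate
            return False
        # Case 4: weights are tied and the indices match (or only one layer is targeted)
        return True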

3. Comprehensive Test Suite

File: tests/test_trainable_tokens.py

Added 7 new test methods covering all scenarios:

Test Coverage:

  • test_ensure_weight_tying_warns_when_model_not_tied_list_format: Verifies warning for list format
  • test_ensure_weight_tying_warns_when_model_not_tied_dict_format: Verifies warning for dict format
  • test_weight_tying_bc_different_indices_treated_separately: Verifies backwards compatibility
  • test_ensure_weight_tying_errors_with_different_indices: Verifies error for contradictory config
  • test_ensure_weight_tying_applied_with_same_indices: Verifies tying with same indices
  • test_weight_tying_bc_same_indices_applied: Verifies BC for same indices
  • test_ensure_weight_tying_with_single_layer: Verifies list format tying

Testing Results

New Tests

All 7 new tests listed above pass successfully.

Backwards Compatibility

This implementation maintains full backwards compatibility:

  • Default Behavior Unchanged: ensure_weight_tying defaults to False, preserving existing behavior
  • No Breaking Changes: Existing code continues to work without modification
  • Opt-in Enhancement: Users must explicitly set ensure_weight_tying=True to use new features
  • BC Mode Preserved: When ensure_weight_tying=False, existing automatic tying still works for compatible configurations


Checklist

  • Implementation follows the specification in issue #2864 (Deal with weight tying consistently)
  • All 7 new tests pass
  • Backwards compatibility maintained
  • Documentation updated (docstring)
  • Code is scoped only to trainable_token_indices
  • Error messages are clear and actionable
  • Warning messages inform users appropriately

cc: @BenjaminBossan


Member

@BenjaminBossan BenjaminBossan left a comment

Thanks a lot for handling the update of weight tying for trainable tokens. What's there already looks quite good, but I wonder if we can simplify the implementation, please check my suggestions.

Regarding the tests, I wanted to map the tests you wrote onto the table from #2864, this is what I ended up with:

| weights tied | ensure_weight_tying | LoraConfig trainable_token_indices | result | test |
|---|---|---|---|---|
| False | False | [1, 2, 3] | trainable tokens on embeddings only | |
| False | True | [1, 2, 3] | warn & trainable tokens on embeddings only | test_ensure_weight_tying_warns_when_model_not_tied_list_format |
| True | False | [1, 2, 3] | tied trainable tokens | |
| True | True | [1, 2, 3] | tied trainable tokens | test_ensure_weight_tying_with_single_layer |
| False | False | {"lm_head": [1,2], "embed_tokens": [1,2]} | treat as separate | |
| False | True | {"lm_head": [1,2], "embed_tokens": [1,2]} | warn & treat as separate | |
| True | False | {"lm_head": [1,2], "embed_tokens": [1,2]} | tied trainable tokens | test_weight_tying_bc_same_indices_applied |
| True | True | {"lm_head": [1,2], "embed_tokens": [1,2]} | tied trainable tokens | test_ensure_weight_tying_applied_with_same_indices |
| False | False | {"lm_head": [1,2], "embed_tokens": [3,4]} | treat as separate | |
| False | True | {"lm_head": [1,2], "embed_tokens": [3,4]} | warn & treat as separate | |
| True | False | {"lm_head": [1,2], "embed_tokens": [3,4]} | *treat as separate | test_weight_tying_bc_different_indices_treated_separately |
| True | True | {"lm_head": [1,2], "embed_tokens": [3,4]} | *error | test_ensure_weight_tying_errors_with_different_indices |

Does this look right to you? I think it means there are still a few gaps in the tests, could you please provide the missing ones? Some tests could be combined via pytest.mark.parametrize if the expected outcomes are the same.
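For illustration, the two warning tests could be collapsed roughly like this (a sketch only, written as a method of the existing test class; the fixture name, target_modules, and warning message follow the snippets quoted in this thread and are assumptions beyond that):

    # assumes the test file's existing imports (pytest, LoraConfig, get_peft_model) and fixtures
    @pytest.mark.parametrize(
        "trainable_token_indices",
        [
            [1, 2, 3],  # list format
            {"embed_tokens": [1, 2], "lm_head": [1, 2]},  # dict format
        ],
    )
    def test_ensure_weight_tying_warns_when_model_not_tied(self, model_weight_untied, recwarn, trainable_token_indices):
        config = LoraConfig(
            target_modules=["q_proj"],
            trainable_token_indices=trainable_token_indices,
            ensure_weight_tying=True,
        )
        get_peft_model(model_weight_untied, config)
        expected = "ensure_weight_tying=True but the model does not have tied weights"
        warnings_list = [w.message.args[0] for w in recwarn]
        assert any(expected in msg for msg in warnings_list)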

]
assert warnings_found

def test_ensure_weight_tying_warns_when_model_not_tied_dict_format(self, model_weight_untied, recwarn):
Member

This test can be merged with test_ensure_weight_tying_warns_when_model_not_tied_list_format by parametrizing the trainable_token_indices argument.

Contributor Author

resolved in 232c6e7

Comment on lines 992 to 996
warnings_list = [w.message.args[0] for w in recwarn]
warnings_found = [
msg for msg in warnings_list if "ensure_weight_tying=True but the model does not have tied weights" in msg
]
assert warnings_found
Member

I think it's a bit more elegant to do:

expected = ...
assert any(expected in msg for msg in warnings_list)

Contributor Author

resolved in 232c6e7

ensure_weight_tying=True,
)

with pytest.raises(ValueError) as e:
Member

Let's use:

msg = "Cannot ensure weight tying when different token indices are specified"
with pytest.raises(ValueError, match=msg):

Contributor Author

resolved in 232c6e7

ensure_weight_tying = getattr(peft_config, "ensure_weight_tying", False)

# Check if we're dealing with dict format that specifies both embed_tokens and lm_head
is_dict_format = isinstance(peft_config.trainable_token_indices, dict)
Member

I don't think we need is_dict_format. The check below, len(target_layers) > 1, is already enough, is it not?

Contributor Author

Yes, re-reviewed this and simplified the logic significantly. Reference 232c6e7 for the implementation.

Comment on lines 1487 to 1490
if "embed" in key_lower and not ("lm" in key_lower or "head" in key_lower):
embed_key = key
elif "lm_head" in key_lower or ("head" in key_lower and "lm" not in key_lower):
lm_head_key = key
Member

I wonder if we overcomplicate things here. If there are multiple target_layers, can we not just compare them to the tied weights? Is it important to identify here which one is for the embedding and which one is for the LM head?

Below, you're using the names for the error message, which is a nice touch, but if we can refrain from guessing here, it would be worth it to make the error message more generic IMO.

Contributor Author

I revisited this and removed the string-matching logic (checking for "embed", "lm_head", etc.); the code now directly compares target layers against model._tied_weights_keys and the actual embedding layer. The error message is now generic, showing all conflicting tied layers instead of assuming specific names.

indices_mismatch = True
else:
# Same indices - if weights are tied and we're applying tying, skip lm_head (it'll be tied later)
if weights_tied and not (not ensure_weight_tying and False): # Will apply tying
Member

This check makes no sense to me, why and False?

Contributor Author

resolved in 232c6e7

@sambhavnoobcoder
Contributor Author

About the test coverage: the table looks correct. I've filled all 6 gaps in the test coverage:

  • Added 2 new standalone test functions (test_untied_model_list_format_no_ensure and test_tied_model_list_format_no_ensure)
  • Expanded the parametrized test_ensure_weight_tying_warns_when_model_not_tied from 2 to 4 scenarios (adding the dict format cases)
  • Added parametrized test_untied_model_dict_no_ensure covering 2 scenarios (same and different indices)

@sambhavnoobcoder
Contributor Author

sambhavnoobcoder commented Oct 29, 2025

@BenjaminBossan Thank you for the detailed review. I have made all the changes and would appreciate it if you could have a look at it again. I'll make any changes necessary.

Member

@BenjaminBossan BenjaminBossan left a comment

Thanks for iterating on the PR and extending the tests. I still have a few comments, please check.

As a general remark, the logic for handling weight tying in trainable tokens is inherently quite complex. Therefore, I focused on checking if the implementation is clear and simple while keeping the functionality intact. When I found code that I thought could be improved in this regard, I added a comment. But I would also kindly ask you to double check if you can find anything that can be simplified and apply it, even if I haven't commented on it. This will help with the long term health of the PEFT code base 🙏

weights_tied = (
model_config.get("tie_word_embeddings", False)
# some models may be misconfigured to have weight tying enabled but don't define tied weights keys
and model._tied_weights_keys is not None
Member

This could theoretically raise an AttributeError if used with a non-HF transformers model, right? It's not so likely in practice, since a non-HF transformers model is unlikely to have a model config with tie_word_embeddings, but let's still use getattr here to be safe. I would also assign this to a variable, as it's used 3 times in total.

# Check if any of the target layers correspond to tied weights in the model
# Instead of guessing layer names, compare against actual tied weight keys
# Extract module names from tied weights keys (remove the weight attribute name)
tied_module_names = {".".join(key.split(".")[:-1]) for key in model._tied_weights_keys}
Member

I'd say this is simpler:

Suggested change
tied_module_names = {".".join(key.split(".")[:-1]) for key in model._tied_weights_keys}
tied_module_names = {key.rpartition(".")[0] for key in model._tied_weights_keys}

I saw that the existing code does the same thing as you did here, but let's still try to improve :) (feel free to adjust the existing code below too).

break

# Find which target layers are in the tied weights (including the embedding source)
for target_layer in target_layers:
Member

I'd rename target_layer to target_layer_name to make it clear that it's the name, not the module itself.

has_both_layers = True
# Check if all tied layers have the same indices
first_indices = target_layers[tied_layer_keys[0]]
indices_match = all(target_layers[key] == first_indices for key in tied_layer_keys[1:])
Member

I don't think we need both indices_match and indices_mismatch, it's a bit redundant. I think it's easiest to eliminate the former.

for name, module in model.named_modules():
if module is embedding_module:
# Get just the last part of the name for matching with target_layers
embedding_name = name.split(".")[-1]
Member

Although the logic in this loop is fine, it can be a bit confusing: What would it mean if the embedding_module is not found? This should never happen, right? So I'm wondering if we can do something like:

embedding_name = next(n.split(".")[-1] for n, m in model.named_modules() if m is embedding_module)

This would raise an error if embedding_module is not found instead of leaving embedding_name = None. What's your opinion?

if weights_tied and ensure_weight_tying and has_both_layers and indices_mismatch:
# Build more generic error message showing the conflicting layers
tied_layers_info = ", ".join([f"{key}: {target_layers[key]}" for key in tied_layer_keys])
raise ValueError(
Member

Can we not raise this error immediately after indices_mismatch was determined? The earlier we can raise, the better. It should also make the check simpler, as we only need to check for if indices_mismatch.

# Since indices match here, indices_mismatch=False, so this simplifies to: we apply tying
# Skip all tied modules except the embedding (first one in tied_layer_keys)
# The embedding is typically first, but to be safe, skip modules in _tied_weights_keys
for key in tied_layer_keys:
Member

I'm wondering if we cannot simply take the intersection between the two:

layers_to_skip = set(tied_layer_keys) & tied_module_names

This approach would fail if we have a substring match but not a full string match, which is what you cover with tied_module.endswith(key). However, I don't see what would need to happen for a substring-only match, and AFAICT, the tests also never reach that point. Could you please explain?

and isinstance(model.get_input_embeddings(), TrainableTokensWrapper)
):
# the embedding layer is modified and we want weight tying.
and not (not ensure_weight_tying and has_both_layers and indices_mismatch)
Member

This conditional is a bit hard to read IMO, let's try to simplify. So for and not (not ensure_weight_tying ... let's move it out of the parenthesis, i.e. it becomes and ensure_weight_tying. As for indices_mismatch, this can only ever be True if has_both_layers is also True, right? So we don't really need to check both.


if len(target_layers) > 1 and weights_tied and model._tied_weights_keys:
# Check if any of the target layers correspond to tied weights in the model
# Instead of guessing layer names, compare against actual tied weight keys
Member

This comment can be removed IMO.

assert lm_head_adapter.token_indices["default"] == [1, 2]

def test_weight_tying_bc_same_indices_applied(self, model_weight_tied):
"""Backwards compatibility: same indices should have weight tying even when ensure_weight_tying=False"""
Member

This is not really for BC, is it? I think this is just the general expected behavior. The BC part is only for cases where the behavior might not be what the users expects but we cannot change it now because it would be backwards incompatible.

@BenjaminBossan
Member

@sambhavnoobcoder Are you still working on this?

@sambhavnoobcoder
Contributor Author

Oh hi @BenjaminBossan, actually I had resolved all the comments already, I just forgot to tag you for a review. I just resolved a small merge conflict, so this is ready for your review/merging now.

Member

@BenjaminBossan BenjaminBossan left a comment

Thanks for addressing my latest comments, the logic is much cleaner now. I have one small comment, otherwise the PR looks good.


if len(target_layers) > 1 and weights_tied and tied_weights_keys:
# Extract module names from tied weights keys (remove the weight attribute name)
tied_module_names = {key.rpartition(".")[0] for key in tied_weights_keys}
Member

Let's use _get_module_names_tied_with_embedding to determine tied_module_names here too (you already use it below). After this change, it means we also don't need tied_weights_keys = getattr(model, "_tied_weights_keys", None) anymore.

Contributor Author

Okay, I missed this somehow. Resolved in 88f5f22.

@sambhavnoobcoder
Contributor Author

@BenjaminBossan resolved this too—please re-review, and we can merge if all looks good.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@BenjaminBossan
Member

Thanks @sambhavnoobcoder, the PR looks good from my side.

@githubnemo could you please also do a review 🙏?

@sambhavnoobcoder
Contributor Author

Hi @githubnemo,
Could you please look into this as well? I would appreciate your review, and I'll make any changes needed so that we can merge accordingly.

Collaborator

@githubnemo githubnemo left a comment

Two comments, I might have misunderstood something though :)

Comment on lines 1492 to 1494
if target_layer_name == embedding_name:
tied_layer_keys.append(target_layer_name)
continue
Collaborator

Given input embedding layer encoder.embed_tokens, then embedding_name = "embed_tokens" and therefore targeting encoder.embed_tokens would not be considered here (since "encoder.embed_tokens" != "embed_tokens"), right?

We probably need to check embedding_name.endswith(target_layer_name). If there are conflicts with other layer names (e.g., separate embeddings both called embed_tokens in different submodules) the user can specify the target names more specifically and resolve the conflict. Similarly to how the comparison happens in the for-loop below.

Contributor Author

Done.

Collaborator

Hmm I don't think that the change is sufficient.

Consider an encoder-decoder model like BART or Marian: it has several layers called embed_tokens, and these don't necessarily have to be tied.

We can imagine a model made up of two such models:

class MegaModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        config1 = MarianConfig()
        self.m1 = MarianModel(config1)

        config2 = MarianConfig()
        config2.share_encoder_decoder_embeddings = False
        self.m2 = MarianModel(config2)

    def get_input_embeddings(self):
        return self.m1.encoder.embed_tokens

    # [...]

peft_config = LoraConfig(
    trainable_token_indices={
        "m1.encoder.embed_tokens": [1,2,3],
        "m2.encoder.embed_tokens": [4,5,6],
    },
    ensure_weight_tying=True,
)
peft_model = get_peft_model(MegaModel(), peft_config)

In this case we will target all embed_tokens in MegaModel, i.e. MegaModel.m1.encoder.embed_tokens, MegaModel.m1.decoder.embed_tokens, MegaModel.m2.encoder.embed_tokens, MegaModel.m2.decoder.embed_tokens.

For the code we're discussing this would mean that:

  1. embedding_name == "embed_tokens"
  2. target_layers == ["m1.encoder.embed_tokens", "m2.encoder.embed_tokens"]
  3. for each target_layer_name: target_layer_name.endswith(embedding_name)

So we would raise an error because the indices don't match. But this is wrong since m2 is not related to the embeddings of m1. And the user has no way to further specify which layers they meant, since we're stripping the embedding layer's name:

            embedding_name = next(n.split(".")[-1] for n, m in model.named_modules() if m is embedding_module)

I think it would be better to turn it around:

  1. use the full embedding name
  2. check embedding_name.endswith(target_layer_name)

This is what I was getting at in my initial comment. In that case the user can specify target layer embed_tokens, which will match the embedding layer m1.encoder.embed_tokens, but if there is a conflict like in the scenario above, the user can be more specific and write target layer m1.encoder.embed_tokens to resolve the conflict.

Ideally we would also have a unit test for this in the form of the example above.

Contributor Author

done in 616fb80

Collaborator

The fix looks reasonable, thanks. I think the test needs some work, commented there.

Contributor Author

the test fix is addressed in 456cd36 now

# There might be the possibility that we have output weights that are tied to the input weights.
# In that case we will tie any module that wants tied weights to the token adapter to make sure that
# any modification is reflected in the tied layers as well.
tied_weights_module_names = _get_module_names_tied_with_embedding(model)
Collaborator

Just a heads-up: this code path is very similar to TrainableTokensModel.inject_adapter and would benefit from having the same ensure weight-tying logic. Maybe it makes sense to refactor this code into something more general but it's OK not to to keep the scope narrow.

Contributor Author

added fix in models.py

Collaborator

Can you elaborate what this fix does?

Contributor Author

The fix in commit 769b8b0 sorts out a bug in inject_adapter model.py, where weight tying was quietly failing when modules were already wrapped. The problem showed up when users explicitly targeted both the embedding and tied layers, like with target_modules=["embed_tokens", "lm_head"]. In that case, the old code called _create_and_replace_dict, which then used update_layer(). But update_layer() would just exit early if a tied_adapter was passed, so the tying never actually got applied.

For the fix, I added a check: if the module is already a TrainableTokensLayer, we now fully replace it with a new tied module, instead of attempting to update the existing one.

As for your suggestion to bring in the same ensure_weight_tying logic from other.py—I didn't include the full validation (like checking for mismatched indices or respecting the ensure_weight_tying flag), because standalone TrainableTokensConfig only accepts a single token_indices list, not a dict that maps different indices to specific layer names. That means scenarios needing that validation (e.g., mismatched indices across tied layers) can't happen in standalone mode—they only crop up when combining with LoRA, which goes through the other.py path.

Would you rather I add the full validation logic to inject_adapter for consistency's sake? Please do point out if you think I am missing something here. Or do you think this is solid as is?

Collaborator

The fix in commit 769b8b0 sorts out a bug in inject_adapter model.py, where weight tying was quietly failing when modules were already wrapped. The problem showed up when users explicitly targeted both the embedding and tied layers, like with target_modules=["embed_tokens", "lm_head"]. In that case, the old code called _create_and_replace_dict, which then used update_layer(). But update_layer() would just exit early if a tied_adapter was passed, so the tying never actually got applied.

For the fix, I added a check: if the module is already a TrainableTokensLayer, we now fully replace it with a new tied module, instead of attempting to update the existing one.

Ah, I see. I think that doing it this way is fine. At first I thought that this would break when we have multiple trainable tokens adapters with target_modules=["embed_tokens", "lm_head"] but since the lm_head adapter is tied with embed_tokens anyway, it will just be updated to the embed_tokens version, inheriting the other adapter settings.

In any case, if there was a bug it would be nice to have a test for it:

  • one for the case you described
  • one for the multi-adapter case (adding multiple trainable token adapters to the same model and layers)

As for your suggestion to bring in the same ensure_weight_tying logic from other.py—I didn't include the full validation (like checking for mismatched indices or respecting the ensure_weight_tying flag), because standalone TrainableTokensConfig only accepts a single token_indices list, not a dict that maps different indices to specific layer names. That means scenarios needing that validation (e.g., mismatched indices across tied layers) can't happen in standalone mode—they only crop up when combining with LoRA, which goes through the other.py path.

Would you rather I add the full validation logic to inject_adapter for consistency's sake? Please do point out if you think I am missing something here. Or do you think this is solid as is?

No, that sounds reasonable. Thanks :)

Contributor Author

@sambhavnoobcoder sambhavnoobcoder Dec 15, 2025

Cool. I agree that those test cases could help. I wrote two for my own validation, and they pass, so I'm pushing them here as well so that the testing stays consistent. You'll find them in 6ed72f7.

@sambhavnoobcoder sambhavnoobcoder force-pushed the trainable-tokens-weight-tying branch from b76a094 to c62aa56 Compare December 9, 2025 10:25
@sambhavnoobcoder
Contributor Author

Thank you for the review @githubnemo. I have made the necessary fixes according to your review and also ran make style for style consistency. Hoping you could look into this again and re-review. I'll make any more changes needed ASAP.

@sambhavnoobcoder
Contributor Author

sambhavnoobcoder commented Dec 15, 2025

Thank you @githubnemo. Addressed all the comments in commits 456cd36 and 6ed72f7, and responded in the relevant threads as well. Hoping you could review once again and point out if there is anything else to improve here. I'll make those changes ASAP.

Collaborator

@githubnemo githubnemo left a comment

Thanks for the update, two small nits. After those we're ready to go :)

self._tied_weights_keys = [
"m1.encoder.embed_tokens.weight",
"m1.decoder.embed_tokens.weight",
# m2 has no tied weights since tie_word_embeddings=False
Collaborator

The mapping is there regardless of the state of config2.tie_word_embeddings, so for consistency we should add those either way.

Contributor Author

done in 7e1821d

Comment on lines 1333 to 1334
"""Test the fix for when user explicitly targets both embedding and tied layers.
Collaborator

nit: this docstring and the one of test_multiple_trainable_token_adapters_same_model talk about a fix, which makes sense in the context of this PR but not when looking at the tests later. Let's rephrase this so that a future reader will understand what is meant.

Contributor Author

done in 5f86f3a

@sambhavnoobcoder
Contributor Author

Hi @githubnemo,
Thank you for the review. I have made the changes as requested, both very small, in 7e1821d and 5f86f3a. Hoping you could re-review this once more, and then we can take this forward.

Collaborator

@githubnemo githubnemo left a comment

LGTM. Thank you for your work :)

Member

@BenjaminBossan BenjaminBossan left a comment

Thanks for the great work, the PR is basically ready. I just saw a small handful of things to improve in the tests. After this, it's time to merge.

def test_mega_model_multiple_embed_tokens_specific_targeting(self):
"""Test that users can specify full paths to disambiguate multiple embed_tokens layers.
This tests the scenario described by the maintainer where a composite model has multiple sub-models, each with
Member

"described by the maintainer" is not necessary for the description :)

Contributor Author

Done in 69dbbe7.

This tests the scenario described by the maintainer where a composite model has multiple sub-models, each with
their own embed_tokens, and users need to target them independently with different token indices.
"""
from transformers import BartConfig, BartModel
Member

Let's make the import global.

Contributor Author

done 1969250

# They should NOT share delta parameters (model doesn't have tied weights)
assert embed_adapter.trainable_tokens_delta is not lm_head_adapter.trainable_tokens_delta

def test_mega_model_multiple_embed_tokens_specific_targeting(self):
Member

Let's not use "mega" but instead "composite" throughout this test.

Contributor Author

done eb20b83

@sambhavnoobcoder
Contributor Author

Thank you @BenjaminBossan. Those were some silly errors on my part; I'll keep them in mind next time for sure. I've made the fixes and would appreciate it if you could review, please.

@BenjaminBossan
Member

@sambhavnoobcoder There is still an issue, this time from hf-doc-builder, when it runs doc-builder style src/peft tests docs/source --max_len 119 --check_only. However, I can't replicate the issue locally, not sure what's happening. Perhaps you could try? It's version v5.0.0.

@sambhavnoobcoder
Contributor Author

@BenjaminBossan I think this should pass now, I've made the necessary change.

@BenjaminBossan
Member

Thanks for taking care of the docs @sambhavnoobcoder. There is a failing test, could you please check?

@sambhavnoobcoder
Contributor Author

Hi @BenjaminBossan, I think this should be fixed now. Even though the tests passed locally for me before (not sure if that might be due to the operating system of my local device, since the failing tests seem to indicate Windows or Ubuntu), I was able to write a small script to replicate the error and made the fix. I tested it out and the fix works fine now, so the changes from b56d14c should pass this test.

@BenjaminBossan
Member

Even though the tests passed locally for me before (not sure if that might be due to the operating system of my local device, since the failing tests seem to indicate Windows or Ubuntu), I was able to write a small script to replicate the error

Could you please share this script? I also cannot replicate the error (using Ubuntu with and without GPU) so I would like to test for myself. The reason why it's important is that there are quite a few changes in the upcoming transformers v5 release and I want to ensure that the PR still works with those (the CI on the PR only checks the latest release).

@sambhavnoobcoder
Contributor Author

Yes, sure @BenjaminBossan, here is the script; running it reproduces the error from CI:

from transformers import BartConfig, BartModel
import torch

class CompositeModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        config1 = BartConfig(vocab_size=100, d_model=32, encoder_layers=1, decoder_layers=1, 
                            encoder_attention_heads=2, decoder_attention_heads=2, 
                            encoder_ffn_dim=64, decoder_ffn_dim=64)
        self.m1 = BartModel(config1)
        config2 = BartConfig(vocab_size=100, d_model=32, encoder_layers=1, decoder_layers=1, 
                            encoder_attention_heads=2, decoder_attention_heads=2, 
                            encoder_ffn_dim=64, decoder_ffn_dim=64)
        config2.tie_word_embeddings = False
        self.m2 = BartModel(config2)
        self.config = config1
        
        # NO FIX - Comment out these two lines to reproduce the error
        # self.m1._tied_weights_keys = None
        # self.m2._tied_weights_keys = None
        
        self._tied_weights_keys = {
            'm1.decoder.embed_tokens.weight': 'm1.encoder.embed_tokens.weight',
            'm2.decoder.embed_tokens.weight': 'm2.encoder.embed_tokens.weight',
        }

model = CompositeModel()

# This simulates the PEFT main branch code from src/peft/utils/other.py lines 1677-1681
tied_weights_keys = {}
for module_name, module in model.named_modules():
    module_tied_weights_keys = getattr(module, '_tied_weights_keys', None)
    if module_tied_weights_keys and not module_name:
        tied_weights_keys.update(module_tied_weights_keys)
    elif module_tied_weights_keys:
        # This line calls .items() expecting a dict, but BartModel has a list
        tied_weights_keys.update(
            {f"{module_name}.{k}": f"{module_name}.{v}" for k, v in module_tied_weights_keys.items()}
        )

print("Success!")

Result WITHOUT fix:

AttributeError: 'list' object has no attribute 'items'

The error occurs because BartModel's _tied_weights_keys is a list but PEFT main branch code calls .items() on it.

With the fix , the script runs successfully because PEFT skips None values in the if module_tied_weights_keys check. Result WITH fix:

✓ Success! No errors with the fix.

I totally understand your concern regarding the upgrade and compatibility with transformers v5; I am also staying up to date with that and recently read the article about it. I also want to ensure this remains compatible with the new version and doesn't cause any errors, so kindly feel free to look into this closely. I'll also give it a look and will resolve any reviews/comments from your end as soon as possible.
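For reference, a hedged sketch of tolerating both formats when collecting tied module names (a hypothetical helper, not the actual PEFT code):

    def normalize_tied_weights_keys(tied_weights_keys):
        """Return the set of tied module names from either the old list format
        (e.g. ["lm_head.weight"]) or the new dict format
        (e.g. {"lm_head.weight": "model.embed_tokens.weight"})."""
        if not tied_weights_keys:
            return set()
        keys = tied_weights_keys.keys() if isinstance(tied_weights_keys, dict) else tied_weights_keys
        return {key.rpartition(".")[0] for key in keys}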

@BenjaminBossan
Member

Thanks for the small reproducer @sambhavnoobcoder but for me, this snippet passes. I also don't see why it should fail. Instead, I tried this:

from peft.utils.other import _get_module_names_tied_with_embedding

class CompositeModel(torch.nn.Module):
    # same as yours

model = CompositeModel()

_get_module_names_tied_with_embedding(model)

Here I get a failure, but it's a different one:

ValueError: The supplied model implements _tied_weights_keys as a dict but doesn't implement 'get_input_embeddings' so we can't determine which weights are tied to embeddings.

I think this is a legit error. So in the end I wonder if this really reproduces the error we saw earlier in CI. Could you please double-check?

@sambhavnoobcoder
Contributor Author

sambhavnoobcoder commented Jan 7, 2026

Curiouser and curiouser. @BenjaminBossan I spent some time trying to debug this. Here are my findings so far:

Summary / TL;DR:

I think the current code/branch is correct; you probably couldn't replicate due to a merge issue, but you should be able to now. I would appreciate it if you could test once more with my reproducer script above. If it still fails, I'll look into it again.

Long version / Detailed observations:

The error seems to stem from changes in the main branch (commit b4f56dd on Dec 3, 2025, "FIX Transformers v5 fixes #2934"), which added a loop to iterate through ALL sub-modules when collecting _tied_weights_keys. This loop calls .items() on sub-modules' _tied_weights_keys, expecting a dict. In the CompositeModel test, BartModel sub-modules have _tied_weights_keys as a list (e.g., ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight']), leading to AttributeError: 'list' object has no attribute 'items'. The error manifests in the merged state (PR branch + main), where the CompositeModel test combines with main's sub-module iteration code.

I've also gone through a detailed timeline of how things evolved here. I paid special attention to the transformers v5 update and any related fixes, since I didn't want anything to break because of that here.

Finally, I ran the same tests as above after merging changes from main, with and without the fix. Without the fix I was able to replicate the error, and with the fix it was resolved. I was also able to replicate your error as well, and the fix seems to resolve it.

So I think that the fix is proper and also ensures compatibility with:

  • Transformers v4 (list-format _tied_weights_keys).
  • Transformers v5 (dict-format _tied_weights_keys).

So I think you should be able to replicate the error now and verify it as well. Hence, could you rerun the same script, once in full and once with the fix commented out, and let me know whether you are able to reproduce the error? Based on your observations, I will take the necessary actions.

In case it still doesn't work for you, could you point to some potential areas you think I should look at to debug this, perhaps specifically w.r.t. transformers v5?

@BenjaminBossan
Member

Ah I see now, I had to delete my tracking branch of your branch and re-create it, for some reason it was not up-to-date.

The issue only occurs with transformers < v5 (or main branch) and the cause is that we're mixing old style _tied_weights_keys (list) and new style _tied_weights_keys (dict). The CompositeModel uses the new style (v5) but when we use transformers < v5, the BartModels still use the old style. The _get_module_names_tied_with_embedding function assumes it's either/or, not a mix of both, which I think is a reasonable assumption.

What it means for your test is that you should ensure that _tied_weights_keys is consistent. You could check the type of self.m1._tied_weights_keys: if it's a list, self._tied_weights_keys should also be a list, otherwise it should be a dict as it is right now. Please add a comment to explain why this is happening.
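A possible shape for that consistency check inside the test's CompositeModel.__init__ (a sketch only, assuming the sub-models expose the attribute as a list on transformers < v5 and a dict on v5, as described above):

    # after self.m1 and self.m2 are created:
    if isinstance(self.m1._tied_weights_keys, dict):
        # transformers v5 style: mapping of tied weight -> source weight
        self._tied_weights_keys = {
            "m1.decoder.embed_tokens.weight": "m1.encoder.embed_tokens.weight",
            "m2.decoder.embed_tokens.weight": "m2.encoder.embed_tokens.weight",
        }
    else:
        # transformers < v5 style: plain list of tied weight names
        self._tied_weights_keys = [
            "m1.encoder.embed_tokens.weight",
            "m1.decoder.embed_tokens.weight",
            "m2.encoder.embed_tokens.weight",
            "m2.decoder.embed_tokens.weight",
        ]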

@sambhavnoobcoder
Contributor Author

sambhavnoobcoder commented Jan 7, 2026

Hi @BenjaminBossan,
Thank you for the feedback! I've updated the code to check the type and ensure format consistency as you suggested. I considered using list format when sub-models use list format, but that would incorrectly imply all layers are tied to each other (including across sub-models). The dict format correctly expresses independent tied weights within each sub-model while avoiding the format-mixing issue. I initially used list format (commit 456cd36), then switched to dict for v5 compatibility (commit 7e1821d). The current fix handles the v4/v5 transition by checking types and ensuring consistency.

If I have misunderstood something, or any other changes are needed, kindly let me know; I'll take care of them ASAP.
