Implement ensure_weight_tying for trainable_token_indices (#2864) #2870
base: main
Conversation
BenjaminBossan
left a comment
Thanks a lot for handling the update of weight tying of trainable tokens. What's there already looks quite good, but I wonder if we can simplify the implementation, please check my suggestions.
Regarding the tests, I wanted to map the tests you wrote onto the table from #2864, this is what I ended up with:
| weights tied | ensure_weight_tying | LoraConfig trainable_token_indices | result | test |
|---|---|---|---|---|
| False | False | [1, 2, 3] | trainable tokens on embeddings only | |
| False | True | [1, 2, 3] | warn & trainable tokens on embeddings only | test_ensure_weight_tying_warns_when_model_not_tied_list_format |
| True | False | [1, 2, 3] | tied trainable tokens | |
| True | True | [1, 2, 3] | tied trainable tokens | test_ensure_weight_tying_with_single_layer |
| False | False | {"lm_head": [1,2], "embed_tokens": [1,2]} | treat as separate | |
| False | True | {"lm_head": [1,2], "embed_tokens": [1,2]} | warn & treat as separate | |
| True | False | {"lm_head": [1,2], "embed_tokens": [1,2]} | tied trainable tokens | test_weight_tying_bc_same_indices_applied |
| True | True | {"lm_head": [1,2], "embed_tokens": [1,2]} | tied trainable tokens | test_ensure_weight_tying_applied_with_same_indices |
| False | False | {"lm_head": [1,2], "embed_tokens": [3,4]} | treat as separate | |
| False | True | {"lm_head": [1,2], "embed_tokens": [3,4]} | warn & treat as separate | |
| True | False | {"lm_head": [1,2], "embed_tokens": [3,4]} | *treat as separate | test_weight_tying_bc_different_indices_treated_separately |
| True | True | {"lm_head": [1,2], "embed_tokens": [3,4]} | *error | test_ensure_weight_tying_errors_with_different_indices |
Does this look right to you? I think it means there are still a few gaps in the tests, could you please provide the missing ones? Some tests could be combined via pytest.mark.parametrize if the expected outcomes are the same.
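For orientation, a minimal sketch of how one row of this table translates into a config; the checkpoint name and LoRA targets are illustrative assumptions, not taken from the PR:

```python
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# assumed tiny checkpoint; whether its weights are tied depends on tie_word_embeddings in its config
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")

# row "weights tied=True, ensure_weight_tying=True, same indices for both layers"
config = LoraConfig(
    target_modules=["q_proj", "v_proj"],  # illustrative LoRA targets
    trainable_token_indices={"embed_tokens": [1, 2], "lm_head": [1, 2]},
    ensure_weight_tying=True,
)
peft_model = get_peft_model(model, config)  # expected result per the table: tied trainable tokens
```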
tests/test_trainable_tokens.py
Outdated
    ]
    assert warnings_found


def test_ensure_weight_tying_warns_when_model_not_tied_dict_format(self, model_weight_untied, recwarn):
This test can be merged with test_ensure_weight_tying_warns_when_model_not_tied_list_format by parametrizing the trainable_token_indices argument.
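A rough sketch of what that parametrization could look like; the fixture name follows the test file above, while the config details (e.g. the target module) are assumptions:

```python
import pytest
from peft import LoraConfig, get_peft_model


# in the real test file this would be a method on the test class (with self);
# model_weight_untied is the fixture used by the existing tests
@pytest.mark.parametrize(
    "trainable_token_indices",
    [[1, 2, 3], {"embed_tokens": [1, 2, 3]}],  # list format and dict format
)
def test_ensure_weight_tying_warns_when_model_not_tied(model_weight_untied, trainable_token_indices, recwarn):
    config = LoraConfig(
        target_modules=["q_proj"],  # assumed LoRA target; the real test mirrors the existing ones
        trainable_token_indices=trainable_token_indices,
        ensure_weight_tying=True,
    )
    get_peft_model(model_weight_untied, config)
    expected = "ensure_weight_tying=True but the model does not have tied weights"
    assert any(expected in str(w.message) for w in recwarn)
```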
resolved in 232c6e7
tests/test_trainable_tokens.py
Outdated
warnings_list = [w.message.args[0] for w in recwarn]
warnings_found = [
    msg for msg in warnings_list if "ensure_weight_tying=True but the model does not have tied weights" in msg
]
assert warnings_found
I think it's a bit more elegant to do:
expected = ...
assert any(expected in msg for msg in warnings_list)
resolved in 232c6e7
tests/test_trainable_tokens.py
Outdated
    ensure_weight_tying=True,
)

with pytest.raises(ValueError) as e:
Let's use:
msg = "Cannot ensure weight tying when different token indices are specified"
with pytest.raises(ValueError, match=msg):There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
resolved in 232c6e7
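One side note on this suggestion (not from the review itself): the `match` argument is treated as a regular expression by pytest, so if the message ever contains regex metacharacters it needs escaping, e.g.:

```python
import re

import pytest

msg = "Cannot ensure weight tying when different token indices are specified"
# match= is applied with re.search, so escape the text if it may contain regex metacharacters
with pytest.raises(ValueError, match=re.escape(msg)):
    raise ValueError(msg)
```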
src/peft/utils/other.py
Outdated
ensure_weight_tying = getattr(peft_config, "ensure_weight_tying", False)

# Check if we're dealing with dict format that specifies both embed_tokens and lm_head
is_dict_format = isinstance(peft_config.trainable_token_indices, dict)
I don't think we need is_dict_format. The check below, len(target_layers) > 1, is already enough, is it not?
Yes, re-reviewed this and simplified the logic significantly. Reference 232c6e7 for the implementation.
src/peft/utils/other.py
Outdated
| if "embed" in key_lower and not ("lm" in key_lower or "head" in key_lower): | ||
| embed_key = key | ||
| elif "lm_head" in key_lower or ("head" in key_lower and "lm" not in key_lower): | ||
| lm_head_key = key |
I wonder if we overcomplicate things here. If there are multiple target_layers, can we not just compare them to the tied weights? Is it important to identify here which one is for the embedding and which one is for the LM head?
Below, you're using the names for the error message, which is a nice touch, but if we can refrain from guessing here, it would be worth it to make the error message more generic IMO.
I re-looked at this and removed the string matching logic (checking for "embed", "lm_head", etc.) and now directly compare target layers against model._tied_weights_keys and the actual embedding layer. The error message is now generic, showing all conflicting tied layers instead of assuming specific names.
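A condensed sketch of that comparison, with variable names taken from the surrounding diff context (the exact matching details are assumptions):

```python
# modules whose weights are tied to the embedding, e.g. "lm_head" for a typical causal LM
tied_module_names = {key.rpartition(".")[0] for key in model._tied_weights_keys}

# the requested trainable-token layers that fall into that tied group
tied_layer_keys = [name for name in target_layers if name in tied_module_names]
```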
src/peft/utils/other.py
Outdated
    indices_mismatch = True
else:
    # Same indices - if weights are tied and we're applying tying, skip lm_head (it'll be tied later)
    if weights_tied and not (not ensure_weight_tying and False):  # Will apply tying
This check makes no sense to me, why and False?
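(For reference, the puzzling clause always evaluates to True, so the whole condition collapses to weights_tied:)

```python
# (anything and False) is False, and not False is True,
# so the extra clause never changes the outcome
for weights_tied in (True, False):
    for ensure_weight_tying in (True, False):
        assert (weights_tied and not (not ensure_weight_tying and False)) == weights_tied
```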
resolved in 232c6e7
About the test coverage, the table looks correct. I've filled all 6 gaps in the test coverage.

@BenjaminBossan Thank you for the detailed review. I have made all the changes and would appreciate it if you could have a look at it again. I'll make any changes necessary.
BenjaminBossan
left a comment
Thanks for iterating on the PR and extending the tests. I still have a few comments, please check.
As a general remark, the logic for handling weight tying in trainable tokens is inherently quite complex. Therefore, I focused on checking if the implementation is clear and simple while keeping the functionality intact. When I found code that I thought could be improved in this regard, I added a comment. But I would also kindly ask you to double check if you can find anything that can be simplified and apply it, even if I haven't commented on it. This will help with the long term health of the PEFT code base 🙏
src/peft/utils/other.py
Outdated
weights_tied = (
    model_config.get("tie_word_embeddings", False)
    # some models may be misconfigured to have weight tying enabled but don't define tied weights keys
    and model._tied_weights_keys is not None
This could theoretically raise an AttributeError if used with a non-HF transformers model, right? It's not so likely in practice, since a non-HF transformers model is unlikely to have a model config with tie_word_embeddings, but let's still use getattr here to be safe. I would also assign this to a variable, as it's used 3 times in total.
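A small sketch of the suggestion, wrapped in a helper purely for illustration (the helper name is hypothetical):

```python
def get_weights_tied(model, model_config: dict) -> bool:
    # getattr avoids an AttributeError on non-transformers models without _tied_weights_keys,
    # and the value is assigned once so it can be reused
    tied_weights_keys = getattr(model, "_tied_weights_keys", None)
    return bool(
        model_config.get("tie_word_embeddings", False)
        # some models may be misconfigured to have weight tying enabled but don't define tied weights keys
        and tied_weights_keys is not None
    )
```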
src/peft/utils/other.py
Outdated
# Check if any of the target layers correspond to tied weights in the model
# Instead of guessing layer names, compare against actual tied weight keys
# Extract module names from tied weights keys (remove the weight attribute name)
tied_module_names = {".".join(key.split(".")[:-1]) for key in model._tied_weights_keys}
I'd say this is simpler:
| tied_module_names = {".".join(key.split(".")[:-1]) for key in model._tied_weights_keys} | |
| tied_module_names = {key.rpartition(".")[0] for key in model._tied_weights_keys} |
I saw that the existing code does the same thing as you did here, but let's still try to improve :) (feel free to adjust the existing code below too).
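(A quick check that the two spellings are equivalent for keys of this shape:)

```python
key = "lm_head.weight"
assert ".".join(key.split(".")[:-1]) == "lm_head"
assert key.rpartition(".")[0] == "lm_head"

key = "model.decoder.lm_head.weight"
assert key.rpartition(".")[0] == "model.decoder.lm_head"
```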
src/peft/utils/other.py
Outdated
        break

# Find which target layers are in the tied weights (including the embedding source)
for target_layer in target_layers:
I'd rename target_layer to target_layer_name to make it clear that it's the name, not the module itself.
src/peft/utils/other.py
Outdated
has_both_layers = True
# Check if all tied layers have the same indices
first_indices = target_layers[tied_layer_keys[0]]
indices_match = all(target_layers[key] == first_indices for key in tied_layer_keys[1:])
I don't think we need both indices_match and indices_mismatch, it's a bit redundant. I think it's easiest to eliminate the former.
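A sketch of keeping only the mismatch flag, with illustrative data (variable names follow the diff context):

```python
# illustrative data: two tied target layers with differing indices
target_layers = {"embed_tokens": [1, 2], "lm_head": [3, 4]}
tied_layer_keys = ["embed_tokens", "lm_head"]

# derive only indices_mismatch; the "match" case is simply "not indices_mismatch"
first_indices = target_layers[tied_layer_keys[0]]
indices_mismatch = any(target_layers[key] != first_indices for key in tied_layer_keys[1:])
assert indices_mismatch
```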
src/peft/utils/other.py
Outdated
for name, module in model.named_modules():
    if module is embedding_module:
        # Get just the last part of the name for matching with target_layers
        embedding_name = name.split(".")[-1]
Although the logic in this loop is fine, it can be a bit confusing: What would it mean if the embedding_module is not found? This should never happen, right? So I'm wondering if we can do something like:
embedding_name = next(n.split(".")[-1] for n, m in model.named_modules() if m is embedding_module)

This would raise an error if embedding_module is not found instead of leaving embedding_name = None. What's your opinion?
src/peft/utils/other.py
Outdated
if weights_tied and ensure_weight_tying and has_both_layers and indices_mismatch:
    # Build more generic error message showing the conflicting layers
    tied_layers_info = ", ".join([f"{key}: {target_layers[key]}" for key in tied_layer_keys])
    raise ValueError(
Can we not raise this error immediately after indices_mismatch was determined? The earlier we can raise, the better. It should also make the check simpler, as we only need to check for if indices_mismatch.
src/peft/utils/other.py
Outdated
# Since indices match here, indices_mismatch=False, so this simplifies to: we apply tying
# Skip all tied modules except the embedding (first one in tied_layer_keys)
# The embedding is typically first, but to be safe, skip modules in _tied_weights_keys
for key in tied_layer_keys:
I'm wondering if we cannot simply take the intersection between the two:
layers_to_skip = set(tied_layer_keys) & tied_module_names
This approach would fail if we have a substring match but not a full string match, which is what you cover with tied_module.endswith(key). However, I don't see what would need to happen for a substring-only match, and AFAICT, the tests also never reach that point. Could you please explain?
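(To make the distinction concrete: the two approaches only differ when a target key is a suffix of a tied module name rather than an exact match, e.g. with these hypothetical names:)

```python
tied_module_names = {"model.decoder.lm_head"}  # hypothetical fully qualified tied module
tied_layer_keys = ["lm_head"]                  # what the user targeted

exact_match = set(tied_layer_keys) & tied_module_names
suffix_match = [k for k in tied_layer_keys if any(m.endswith(k) for m in tied_module_names)]

assert exact_match == set()          # plain intersection misses it
assert suffix_match == ["lm_head"]   # endswith-style matching catches it
```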
src/peft/utils/other.py
Outdated
    and isinstance(model.get_input_embeddings(), TrainableTokensWrapper)
):
    # the embedding layer is modified and we want weight tying.
    and not (not ensure_weight_tying and has_both_layers and indices_mismatch)
This conditional is a bit hard to read IMO, let's try to simplify. So for and not (not ensure_weight_tying ... let's move it out of the parenthesis, i.e. it becomes and ensure_weight_tying. As for indices_mismatch, this can only ever be True if has_both_layers is also True, right? So we don't really need to check both.
src/peft/utils/other.py
Outdated

if len(target_layers) > 1 and weights_tied and model._tied_weights_keys:
    # Check if any of the target layers correspond to tied weights in the model
    # Instead of guessing layer names, compare against actual tied weight keys
This comment can be removed IMO.
tests/test_trainable_tokens.py
Outdated
    assert lm_head_adapter.token_indices["default"] == [1, 2]


def test_weight_tying_bc_same_indices_applied(self, model_weight_tied):
    """Backwards compatibility: same indices should have weight tying even when ensure_weight_tying=False"""
This is not really for BC, is it? I think this is just the general expected behavior. The BC part is only for cases where the behavior might not be what the users expects but we cannot change it now because it would be backwards incompatible.
@sambhavnoobcoder Are you still working on this?

Oh hi @BenjaminBossan, actually I had resolved all the comments already, I just forgot to tag you for a review. I just resolved a small merge conflict, so this is ready for your review / merging now.
BenjaminBossan
left a comment
Thanks for addressing my latest comments, the logic is much cleaner now. I have one small comment, otherwise the PR looks good.
src/peft/utils/other.py
Outdated

if len(target_layers) > 1 and weights_tied and tied_weights_keys:
    # Extract module names from tied weights keys (remove the weight attribute name)
    tied_module_names = {key.rpartition(".")[0] for key in tied_weights_keys}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's use _get_module_names_tied_with_embedding to determine tied_module_names here too (you already use it below). After this change, it means we also don't need tied_weights_keys = getattr(model, "_tied_weights_keys", None) anymore.
Okay, I missed this somehow. Resolved this in 88f5f22.
@BenjaminBossan resolved this too, please re-review, and we can merge if all looks good.

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Thanks @sambhavnoobcoder, the PR looks good from my side. @githubnemo could you please also do a review 🙏?

Hi @githubnemo,
githubnemo
left a comment
Two comments, I might have misunderstood something though :)
src/peft/utils/other.py
Outdated
if target_layer_name == embedding_name:
    tied_layer_keys.append(target_layer_name)
    continue
Given input embedding layer encoder.embed_tokens, then embedding_name = "embed_tokens" and therefore targeting encoder.embed_tokens would not be considered here (since "encoder.embed_tokens" != "embed_tokens"), right?
We probably need to check embedding_name.endswith(target_layer_name). If there are conflicts with other layer names (e.g., separate embeddings both called embed_tokens in different submodules) the user can specify the target names more specifically and resolve the conflict. Similarly to how the comparison happens in the for-loop below.
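(Illustrating the mismatch being described, with the names from the example above:)

```python
embedding_name = "embed_tokens"             # stripped name of the input embedding
target_layer_name = "encoder.embed_tokens"  # user explicitly targets the nested embedding

# the exact comparison skips it even though it is the embedding layer
assert target_layer_name != embedding_name
# a suffix-style comparison between the two names would match
assert target_layer_name.endswith(embedding_name)
```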
Done.
Hmm I don't think that the change is sufficient.
Consider an encoder-decoder model like BART or Marian - it has several layers called embed_tokens, these don't have to be tied necessarily.
We can imagine a model made up of two such models:
import torch
from transformers import MarianConfig, MarianModel
from peft import LoraConfig, get_peft_model


class MegaModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        config1 = MarianConfig()
        self.m1 = MarianModel(config1)
        config2 = MarianConfig()
        config2.share_encoder_decoder_embeddings = False
        self.m2 = MarianModel(config2)

    def get_input_embeddings(self):
        return self.m1.encoder.embed_tokens

    # [...]


peft_config = LoraConfig(
    trainable_token_indices={
        "m1.encoder.embed_tokens": [1, 2, 3],
        "m2.encoder.embed_tokens": [4, 5, 6],
    },
    ensure_weight_tying=True,
)
peft_model = get_peft_model(MegaModel(), peft_config)

In this case we will target all embed_tokens in MegaModel, i.e. MegaModel.m1.encoder.embed_tokens, MegaModel.m1.decoder.embed_tokens, MegaModel.m2.encoder.embed_tokens, MegaModel.m2.decoder.embed_tokens.
For the code we're discussing this would mean that:
- embedding_name == "embed_tokens"
- target_layers == ["m1.encoder.embed_tokens", "m2.encoder.embed_tokens"]
- for each target_layer_name: target_layer_name.endswith(embedding_name)
So we would raise an error because the indices don't match. But this is wrong since m2 is not related to the embeddings of m1. But the user has no way to further specify which layers they meant since we're stripping
the embedding layer's name:
embedding_name = next(n.split(".")[-1] for n, m in model.named_modules() if m is embedding_module)
I think it would be better to turn it around:
- use the full embedding name
- check embedding_name.endswith(target_layer_name)
This is what I was getting at in my initial comment. In that case user can specify target layer embed_tokens and will match the embedding layer m1.encoder.embed_tokens but, if there is a conflict like in the scenario above, the user can be more specific and write target layer m1.encoder.embed_tokens and resolve the conflict.
Ideally we would also have a unit test for this in the form of the example above.
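(A concrete check of the turned-around comparison, using the MegaModel names from above:)

```python
embedding_name = "m1.encoder.embed_tokens"  # full name of the input embedding

# a short target still matches the embedding layer ...
assert embedding_name.endswith("embed_tokens")
# ... while a fully qualified target for the other sub-model does not
assert not embedding_name.endswith("m2.encoder.embed_tokens")
# ... and the user can disambiguate by writing the full path
assert embedding_name.endswith("m1.encoder.embed_tokens")
```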
Done in 616fb80.
The fix looks reasonable, thanks. I think the test needs some work, commented there.
The test fix is addressed in 456cd36 now.
# There might be the possibility that we have output weights that are tied to the input weights.
# In that case we will tie any module that wants tied weights to the token adapter to make sure that
# any modification is reflected in the tied layers as well.
tied_weights_module_names = _get_module_names_tied_with_embedding(model)
Just a heads-up: this code path is very similar to TrainableTokensModel.inject_adapter and would benefit from having the same ensure weight-tying logic. Maybe it makes sense to refactor this code into something more general but it's OK not to to keep the scope narrow.
Added the fix in models.py.
Can you elaborate what this fix does?
The fix in commit 769b8b0 sorts out a bug in inject_adapter in model.py, where weight tying was quietly failing when modules were already wrapped. The problem showed up when users explicitly targeted both the embedding and tied layers, like with target_modules=["embed_tokens", "lm_head"]. In that case, the old code called _create_and_replace_dict, which then used update_layer(). But update_layer() would just exit early if a tied_adapter was passed, so the tying never actually got applied.
For the fix, I added a check: if the module is already a TrainableTokensLayer, we now fully replace it with a new tied module, instead of attempting to update the existing one.
As for your suggestion to bring in the same ensure_weight_tying logic from other.py—I didn't include the full validation (like checking for mismatched indices or respecting the ensure_weight_tying flag), because standalone TrainableTokensConfig only accepts a single token_indices list, not a dict that maps different indices to specific layer names. That means scenarios needing that validation (e.g., mismatched indices across tied layers) can't happen in standalone mode—they only crop up when combining with LoRA, which goes through the other.py path.
Would you rather I add the full validation logic to inject_adapter for consistency's sake? Please do point out if you think I am missing something here. Or do you think this is solid as is?
> The fix in commit 769b8b0 sorts out a bug in inject_adapter in model.py, where weight tying was quietly failing when modules were already wrapped. The problem showed up when users explicitly targeted both the embedding and tied layers, like with target_modules=["embed_tokens", "lm_head"]. In that case, the old code called _create_and_replace_dict, which then used update_layer(). But update_layer() would just exit early if a tied_adapter was passed, so the tying never actually got applied.
>
> For the fix, I added a check: if the module is already a TrainableTokensLayer, we now fully replace it with a new tied module, instead of attempting to update the existing one.
Ah, I see. I think that doing it this way is fine. At first I thought that this will break when we have multiple trainable tokens adapters with target_modules=["embed_tokens", "lm_head"] but since the lm_head adapter is tied with embed_tokens anyway, it will just be updated to the embed_tokens version, inheriting the other adapter settings.
In any case, if there was a bug it would be nice to have a test for it:
- one for the case you described
- one for the multi-adapter case (adding multiple trainable token adapters to the same model and layers)
> As for your suggestion to bring in the same ensure_weight_tying logic from other.py—I didn't include the full validation (like checking for mismatched indices or respecting the ensure_weight_tying flag), because standalone TrainableTokensConfig only accepts a single token_indices list, not a dict that maps different indices to specific layer names. That means scenarios needing that validation (e.g., mismatched indices across tied layers) can't happen in standalone mode—they only crop up when combining with LoRA, which goes through the other.py path.
>
> Would you rather I add the full validation logic to inject_adapter for consistency's sake? Please do point out if you think I am missing something here. Or do you think this is solid as is?
No, that sounds reasonable. Thanks :)
Cool. I agree that those test cases could help. I wrote two for my own validation, and they pass, so I'm pushing them here as well so that the entire testing stays consistent. You'll find them in 6ed72f7.
b76a094 to c62aa56
Thank you for the review @githubnemo. I have made the necessary fixes according to your review and also ran make style for style consistency. Hoping you could look into this again and re-review. I'll make any more changes needed asap.

Thank you @githubnemo. Addressed all the comments in the commits 456cd36 and 6ed72f7, and responded in the necessary threads as well. Hoping you could review once again and point out if there is anything else to improve here. I'll make those changes asap.
githubnemo
left a comment
Thanks for the update, two small nits. After those we're ready to go :)
tests/test_trainable_tokens.py
Outdated
self._tied_weights_keys = [
    "m1.encoder.embed_tokens.weight",
    "m1.decoder.embed_tokens.weight",
    # m2 has no tied weights since tie_word_embeddings=False
The mapping is there regardless of the state of config2.tie_word_embeddings, so for consistency we should add those either way.
Done in 7e1821d.
tests/test_trainable_tokens.py
Outdated
| """Test the fix for when user explicitly targets both embedding and tied layers. | ||
nit: this docstring and the one of test_multiple_trainable_token_adapters_same_model talk about a fix, which makes sense in the context of this PR but not when looking at the tests later. Let's rephrase this so that a future reader will understand what is meant.
Done in 5f86f3a.
Hi @githubnemo,
githubnemo
left a comment
LGTM. Thank you for your work :)
BenjaminBossan
left a comment
Thanks for the great work, the PR is basically ready. I just saw a small handful of things to improve in the tests. After this, it's time to merge.
tests/test_trainable_tokens.py
Outdated
def test_mega_model_multiple_embed_tokens_specific_targeting(self):
    """Test that users can specify full paths to disambiguate multiple embed_tokens layers.
    This tests the scenario described by the maintainer where a composite model has multiple sub-models, each with
"described by the maintainer" is not necessary for the description :)
Done in 69dbbe7.
tests/test_trainable_tokens.py
Outdated
    This tests the scenario described by the maintainer where a composite model has multiple sub-models, each with
    their own embed_tokens, and users need to target them independently with different token indices.
    """
    from transformers import BartConfig, BartModel
Let's make the import global.
Done in 1969250.
tests/test_trainable_tokens.py
Outdated
    # They should NOT share delta parameters (model doesn't have tied weights)
    assert embed_adapter.trainable_tokens_delta is not lm_head_adapter.trainable_tokens_delta


def test_mega_model_multiple_embed_tokens_specific_targeting(self):
Let's not use "mega" but instead "composite" throughout this test.
Done in eb20b83.
Thank you @BenjaminBossan. Those were some silly errors on my part, I'll keep them in mind for next time. Made the fixes, I would appreciate it if you could review please.

@sambhavnoobcoder There is still an issue, this time from

@BenjaminBossan I think this should pass now, made the necessary change.

Thanks for taking care of the docs @sambhavnoobcoder. There is a failing test, could you please check?

Hi @BenjaminBossan, I think this should be fixed now. Even though the tests passed locally for me even before, I'm not sure if that might be due to the underlying operating system of my local device (since the failing tests do seem to indicate Windows or Ubuntu), but I was able to write a small script to replicate the error, made the fix, and tested it out. I think the fix works fine now, so the changes from b56d14c should pass this test.

Could you please share this script? I also cannot replicate the error (using Ubuntu with and without GPU) so I would like to test for myself. The reason why it's important is that there are quite a few changes in the upcoming transformers v5 release and I want to ensure that the PR still works with those (the CI on the PR only checks the latest release).

Yes sure @BenjaminBossan, here is the script; running it produces the error from CI. Result WITHOUT fix: the error occurs because BartModel's _tied_weights_keys is a list but PEFT main branch code calls .items() on it. Result WITH fix: the script runs successfully because PEFT skips None values in the if module_tied_weights_keys check. I totally understand your concern regarding the upgrade and compatibility with transformers v5; I am also staying updated with that and recently read the article about it. I would also want to ensure this remains compatible with the new version and not cause any errors, so kindly feel free to look into this closely. I'll also give it a look and will resolve any reviews/comments from your end as soon as possible.
Thanks for the small reproducer @sambhavnoobcoder but for me, this snippet passes. I also don't see why it should fail. Instead, I tried this:

import torch

from peft.utils.other import _get_module_names_tied_with_embedding


class CompositeModel(torch.nn.Module):
    # same as yours
    ...


model = CompositeModel()
_get_module_names_tied_with_embedding(model)

Here I get a failure, but it's a different one:

I think this is a legit error. So in the end I wonder if this really reproduces the error we saw earlier in CI. Could you please double-check?
Curiouser and curiouser. @BenjaminBossan I spent some time trying to debug this. Here's my finding so far:

Summary / TLDR: I think the current code/branch is correct; you probably couldn't replicate due to a merge issue, but you should be able to replicate it now. I would appreciate it if you could test this once more with my reproducer script above. If it still fails, I'll look into it once again.

Long version / Detailed observation: The error seems to stem from changes in the main branch (commit b4f56dd on Dec 3, 2025, "FIX Transformers v5 fixes #2934"), which added a loop to iterate through ALL sub-modules when collecting _tied_weights_keys. This loop calls .items() on sub-modules' _tied_weights_keys, expecting a dict. In the CompositeModel test, BartModel sub-modules have _tied_weights_keys as a list (e.g., ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight']), leading to AttributeError: 'list' object has no attribute 'items'. The error manifests in the merged state (PR branch + main), where the CompositeModel test combines with main's sub-module iteration code.

I've also gone through a detailed timeline of how things evolved here. I paid special attention to the transformers v5 update and any related fixes, since I didn't want anything to break because of that here. Finally, I ran the same tests as above after merging the changes from main, with and without the fix. Without the fix I was able to replicate the error, and with the fix it was resolved. I was also able to replicate your error as well, and the fix seems to resolve it. So I think that the fix is proper and also ensures compatibility with:

So I think you should be able to replicate the error now and verify it as well. Hence I would ask if you could rerun the same script again, once in full and once commenting out the fix, and inform me whether you are able to replicate/reproduce the error now or not. Based on your observations, I will take the necessary actions accordingly. In case it still doesn't work for you, perhaps you could point to some potential points you think I should look at to debug this, perhaps specifically w.r.t. transformers v5 if necessary?
Ah I see now, I had to delete my tracking branch of your branch and re-create it, for some reason it was not up-to-date. The issue only occurs with transformers < v5 (or main branch) and the cause is that we're mixing old style

What it means for your test is that you should ensure that
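(A small illustration of the two _tied_weights_keys shapes being mixed here; the exact new-style mapping shown is an assumption, only the list-vs-dict distinction matters:)

```python
# old style (transformers < v5): a flat list of tied weight names
old_style = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]

# new style (transformers v5): a mapping, so code iterates over it with .items()
new_style = {"lm_head.weight": "model.embed_tokens.weight"}  # illustrative mapping only

list(new_style.items())  # fine
try:
    old_style.items()
except AttributeError as exc:
    print(exc)  # 'list' object has no attribute 'items'
```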
Hi @BenjaminBossan,

If I have misunderstood something, or any other changes are needed, kindly inform me and I'll take care of them asap.
Implement ensure_weight_tying for trainable_token_indices

Summary

This PR implements consistent weight tying behavior for trainable_token_indices as specified in issue #2864. It extends the ensure_weight_tying parameter (introduced in PR #2803) to work with trainable_token_indices, providing users explicit control over weight tying between embeddings and LM head.

Fixes #2864 (trainable_token_indices portion)

Problem Statement

Background

PEFT models sometimes need to handle tied weights between embedding layers and LM head layers (when tie_word_embeddings=True). The ensure_weight_tying parameter was introduced in PR #2803 to give users explicit control over this behavior for modules_to_save. However, the same control was missing for trainable_token_indices.

The Issue

The issue identified that the weight tying behavior for trainable_token_indices was not consistent across different scenarios. Specifically, there were four cases that needed to be implemented:

Solution Approach

Implementation Strategy:

Changes Made

1. Updated Configuration Documentation

File: src/peft/tuners/lora/config.py

Updated the ensure_weight_tying parameter docstring to clarify that it now applies to both modules_to_save and trainable_token_indices, making the documentation consistent with the implementation.

2. Implemented Weight Tying Logic

File: src/peft/utils/other.py

Added comprehensive logic within the existing trainable_token_indices handling block.

Key Components:
- ensure_weight_tying=False

Four Cases Implemented:
- Case 1 - Warning for Untied Models: weights_tied=False + ensure_weight_tying=True
- Case 2 - Error for Contradictory Configuration: weights_tied=True + ensure_weight_tying=True + different indices
- Case 3 - Backwards Compatibility: weights_tied=True + ensure_weight_tying=False + different indices
- Case 4 - Apply Tying:
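For illustration, a minimal sketch of Case 2; the checkpoint name and LoRA targets are illustrative assumptions, not taken from the PR:

```python
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# assumed tiny checkpoint with tie_word_embeddings=True
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")

config = LoraConfig(
    target_modules=["q_proj"],  # illustrative LoRA target
    # different indices on the tied layers while asking for guaranteed tying
    trainable_token_indices={"embed_tokens": [1, 2], "lm_head": [3, 4]},
    ensure_weight_tying=True,
)

# weights tied + ensure_weight_tying=True + different indices -> ValueError (Case 2)
get_peft_model(model, config)
```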
3. Comprehensive Test Suite

File: tests/test_trainable_tokens.py

Added 7 new test methods covering all scenarios.

Test Coverage:
- test_ensure_weight_tying_warns_when_model_not_tied_list_format: Verifies warning for list format
- test_ensure_weight_tying_warns_when_model_not_tied_dict_format: Verifies warning for dict format
- test_weight_tying_bc_different_indices_treated_separately: Verifies backwards compatibility
- test_ensure_weight_tying_errors_with_different_indices: Verifies error for contradictory config
- test_ensure_weight_tying_applied_with_same_indices: Verifies tying with same indices
- test_weight_tying_bc_same_indices_applied: Verifies BC for same indices
- test_ensure_weight_tying_with_single_layer: Verifies list format tying

Testing Results

New Tests

All 7 new tests pass successfully:
- test_ensure_weight_tying_warns_when_model_not_tied_list_format
- test_ensure_weight_tying_warns_when_model_not_tied_dict_format
- test_weight_tying_bc_different_indices_treated_separately
- test_ensure_weight_tying_errors_with_different_indices
- test_ensure_weight_tying_applied_with_same_indices
- test_weight_tying_bc_same_indices_applied
- test_ensure_weight_tying_with_single_layer

Backwards Compatibility

This implementation maintains full backwards compatibility:

✅ Default Behavior Unchanged: ensure_weight_tying defaults to False, preserving existing behavior
✅ No Breaking Changes: Existing code continues to work without modification
✅ Opt-in Enhancement: Users must explicitly set ensure_weight_tying=True to use new features
✅ BC Mode Preserved: When ensure_weight_tying=False, existing automatic tying still works for compatible configurations

Screenshots

Checklist

cc: @BenjaminBossan