Add YaRN and Dynamic-YaRN RoPE Scaling Methods #30910
base: main
Conversation
YaRN (Yet another RoPE extension method) combines the NTK-By-Parts Interpolation and Attention Scaling methods, improving upon existing RoPE interpolation methods for longer context window sizes. Fine-tuned models maintain their original performance across benchmarks while enabling efficient extrapolation and transfer learning for quicker convergence, especially in compute-limited environments.

We implement YaRN and Dynamic-YaRN for the following list of models:

- LLaMA
- Falcon
- GPT-NeoX
- Olmo
- Persimmon
- Phi
- StableLM
- OpenLLaMA

New unit tests are added to assert YaRN's correct behavior on both short and long sequence inputs.

For more details, please refer to https://arxiv.org/abs/2309.00071.

Co-authored-by: Miguel Almeida <miguel.pessanha.almeida@tecnico.ulisboa.pt>
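For readers unfamiliar with the method, the two ingredients can be sketched roughly as follows. This is a minimal, illustrative reimplementation based on the YaRN paper and the reference repo, not the exact code in this PR; all function names here are our own, and the unused `extrapolation_factor` discussed later in the review is omitted.

```python
import math
import torch

def find_correction_dim(num_rotations, dim, base=10000.0, max_pos=2048):
    # Dimension index whose frequency completes `num_rotations` full rotations
    # over `max_pos` positions (inverting the RoPE wavelength formula).
    return (dim * math.log(max_pos / (num_rotations * 2 * math.pi))) / (2 * math.log(base))

def find_correction_range(beta_fast, beta_slow, dim, base=10000.0, max_pos=2048):
    low = math.floor(find_correction_dim(beta_fast, dim, base, max_pos))
    high = math.ceil(find_correction_dim(beta_slow, dim, base, max_pos))
    return max(low, 0), min(high, dim - 1)  # clamp to valid dimension indices

def linear_ramp_mask(low, high, steps):
    if low == high:
        high += 0.001  # avoid division by zero
    ramp = (torch.arange(steps, dtype=torch.float32) - low) / (high - low)
    return torch.clamp(ramp, 0.0, 1.0)

def yarn_inv_freq(dim, base=10000.0, scale=16.0, max_pos=2048, beta_fast=32.0, beta_slow=1.0):
    # NTK-by-parts: blend interpolated and extrapolated frequencies per dimension.
    pos_freqs = base ** (torch.arange(0, dim, 2).float() / dim)
    inv_freq_extrapolation = 1.0 / pos_freqs            # unchanged RoPE frequencies
    inv_freq_interpolation = 1.0 / (scale * pos_freqs)  # position-interpolated frequencies
    low, high = find_correction_range(beta_fast, beta_slow, dim, base, max_pos)
    # mask == 1 -> keep extrapolation (high-frequency dims); mask == 0 -> interpolate
    mask = 1.0 - linear_ramp_mask(low, high, dim // 2)
    return inv_freq_interpolation * (1.0 - mask) + inv_freq_extrapolation * mask

def yarn_mscale(scale=16.0):
    # Attention scaling ("temperature"); same formula as in the PR diff below.
    if scale <= 1.0:
        return 1.0
    return 0.1 * math.log(scale) + 1.0
```

Each resulting inverse frequency lies between the fully interpolated and fully extrapolated values for its dimension, and `yarn_mscale` is the scalar applied to the cos/sin caches.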
cc @ArthurZucker too

Hey! Thanks a lot for taking the time to implement this! 🤗 This is unrelated to Llama so we might need some modularity for this!

🤗
Hello Miguel and Miguel! 👋 (@miguelm-almeida @mig-mfreitas )

I have a couple of requests regarding user experience and recent changes in our repo:

1. My number one suggestion would be to delete the diff in all models except `Llama`, and leave the model copies for another PR. It's much faster for everyone (you and me) to iterate over a single model, and then copy the design when we're happy 🤗 In this particular case (RoPE models), we also have different implementations that we need to iron out on our end before adding YaRN there.
2. (…) `Llama`, and not on `Falcon`. Some of the suggested changes only work on architectures that are up to date, like `Llama` (and unlike `Falcon`).

Finally: one of the goals of this PR should be to be able to load the original YaRN models using `transformers`. Currently, there are some models on the Hub that have custom code (e.g. https://huggingface.co/NousResearch/Yarn-Llama-2-7b-64k). At the moment, these models require passing `trust_remote_code=True` to `from_pretrained` (which loads the custom code in the repo). With this PR, we remove the need for that flag and would use the code in `transformers` instead :)

Ping me if you have further questions (and feel free to ping me by email if I'm taking too long to reply) 🤗
@@ -66,13 +66,31 @@ class OpenLlamaConfig(PretrainedConfig):
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports four scaling
Models in the `deprecated` folder should not be updated :) (let's remove the changes on `open_llama`)
`{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how
these scaling strategies behave:
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
experimental feature, subject to breaking API changes in future versions.
yarn_rope_scaling (`Dict`, *optional*, defaults to `{'original_max_position_embeddings': 2048, 'extrapolation_factor': 1.0, 'attention_factor': 1.0, 'beta_fast': 32.0, 'beta_slow': 1.0, 'finetuned': False}`):
Two comments:

1. The contents of `yarn_rope_scaling` should be part of `rope_scaling`. A single config dict for everything related to RoPE scaling is preferable, so we can easily upgrade it to a standalone config class in a future PR :) It would also allow loading existing models by the original authors, e.g. https://huggingface.co/NousResearch/Yarn-Llama-2-7b-64k (have a look at their custom config code and the model's config file -- both assume all rope scaling params are in `rope_scaling`).
2. We have found through experience that the best default in config files is no default :) That way, we (huggingface):
   a) don't have to push changes to repositories on the Hub in case we find bugs;
   b) can easily distinguish defaults from user-defined values that happen to be equal to the default.

If point 1 is addressed, then no change is needed to the existing default (`None`). Defaults in the classes and the validation code are very helpful, though!
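Concretely, a merged dict might look like this. The exact key set is hypothetical and shown only to illustrate the suggestion; the linked Hub checkpoints use at least `type`, `factor`, and `original_max_position_embeddings` in their `rope_scaling` entry.

```python
# Hypothetical merged layout: YaRN-specific keys live alongside the existing
# "type"/"factor" keys, instead of in a separate `yarn_rope_scaling` dict.
rope_scaling = {
    "type": "yarn",                            # existing key: scaling strategy
    "factor": 16.0,                            # existing key: scaling factor (> 1)
    "original_max_position_embeddings": 2048,  # YaRN: pre-training context length
    "beta_fast": 32.0,                         # YaRN: extrapolation boundary
    "beta_slow": 1.0,                          # YaRN: interpolation boundary
    "attention_factor": 1.0,                   # YaRN: attention temperature scaling
}

# Per the review comment, a checkpoint may omit the YaRN-specific keys entirely:
# defaults belong in the modeling/validation code, not in config.json.
minimal = {"type": "yarn", "factor": 16.0}
```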
@@ -201,3 +229,55 @@ def _rope_scaling_validation(self):
)
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")

# Copied from transformers.models.llama.configuration_llama.LlamaConfig._yarn_rope_scaling_validation
def _yarn_rope_scaling_validation(self):
Likewise, the contents of this function should be moved into `_rope_scaling_validation`, and the flags should only be checked if the rope scaling method is a yarn one and the flags exist in the dictionary.
@@ -162,12 +188,14 @@ def __init__(
self.max_position_embeddings = max_position_embeddings
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.yarn_rope_scaling = yarn_rope_scaling
Suggested change: remove the `self.yarn_rope_scaling = yarn_rope_scaling` line (as per the comment above)
extrapolation_factor (`float`, defaults to 1):
Factor to adjust the n-dimensional rotational scaling for extrapolation.
I see this parameter (`extrapolation_factor`) is in the original implementation. However, if we dig further, we can see that it is not used in practice (unless I'm missing something -- feel free to correct me!):

- The default value of `1.0` does not change the computation;
- There are no references to it in the YaRN paper;
- I couldn't find any YaRN model on the Hub that has set this parameter in `config.json`, meaning the default `1` is always used;
- All references in the original repo use the default value;
- In an older PR, the author writes "extrapolation_factor and ntk_factor are used for validation purposes, and should not be changed unless it is necessary."

As such, I believe we can:

- delete this variable from the config
- delete all related code :)
return 1.0
return 0.1 * math.log(scaling_factor) + 1.0

def forward(self, x, seq_len=None):
See the pattern we have in `LlamaRotaryEmbedding.forward` -- the pattern was changed a few minor versions ago from the one you have here, where `sin` and `cos` are cached, to a different one. The new pattern is faster and is compatible with `torch.compile`.

From a quick glance: I think you may be able to call `super().forward` and simply apply `* self.mscale` on the results.
Parameter to set the boundary for extrapolation (only) in the linear ramp function.
beta_slow (`float`, *optional*, defaults to 1):
Parameter to set the boundary for interpolation (only) in the linear ramp function.
finetuned (`bool`, *optional*, defaults to `False`):
This can also be removed (see the comments on the new dynamic class)
self._sin_cached[:seq_len, ...].to(dtype=x.dtype),
)

def yarn(self, device):
Suggested change:
`def yarn(self, device):` → `def compute_yarn_scaling(self, device):`

(Or a similar name. Let's use descriptive function names :) )
device,
)

if finetuned:
Suggested change:
`if finetuned:` → `if self.max_position_embeddings != self.original_max_position_embeddings:`

This should be true for fine-tuned models, saving us a flag :)
with self.assertRaises(AssertionError):
torch.testing.assert_close(yarn_cos_long, original_cos_long)
with self.assertRaises(AssertionError):
torch.testing.assert_close(yarn_sin_long, original_sin_long)
let's also check that yarn_sin/cos_short != original_sin/cos_short (i.e. that applying yarn should change all values)
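That extra check could look like the following. The tensors here are illustrative stand-ins for the short-sequence outputs of the two embedding classes (in the real test they would come from the modules themselves), and the `1.2773` factor is just `0.1 * ln(16) + 1` for a scaling factor of 16:

```python
import torch

# Stand-ins for the short-sequence cos caches of the original and YaRN
# rotary embeddings (the real test computes these from the modules).
original_cos_short = torch.cos(torch.arange(32, dtype=torch.float32)).reshape(8, 4)
yarn_cos_short = original_cos_short * 1.2773  # mscale != 1 rescales the cache

# YaRN should change the values even for sequences SHORTER than the original
# context window, because the attention scaling factor always applies.
assert not torch.allclose(yarn_cos_short, original_cos_short)
```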
Thank you very much for this in-depth review and suggestions! We'll iterate on it and reach back shortly 🤗
Who can review?
@gante