Add YaRN and Dynamic-YaRN RoPE Scaling Methods #30910

Open · wants to merge 1 commit into base: main

Conversation

mig-mfreitas

What does this PR do?

YaRN (Yet another RoPE extension method) combines the NTK-By-Parts Interpolation and Attention Scaling methods, improving upon existing RoPE interpolation methods for longer context window sizes.

Fine-tuned models maintain their original performance across benchmarks while enabling efficient extrapolation and transfer learning for quicker convergence, especially in compute-limited environments.

We implement YaRN and Dynamic-YaRN for the following list of models:

  • LLaMA
  • Falcon
  • GPT-NeoX
  • Olmo
  • Persimmon
  • Phi
  • StableLM
  • OpenLLaMA

New unit tests are added to assert YaRN's correct behavior on both short and long sequence inputs.

For more details, please refer to https://arxiv.org/abs/2309.00071.
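
For a quick intuition of the method, here is a minimal, self-contained sketch of the two ingredients mentioned above (NTK-by-parts interpolation of the inverse frequencies plus attention scaling). The function name and default values are illustrative only and are not the code added by this PR:

```python
import math

import torch


def yarn_inv_freq_and_scale(dim=128, base=10000.0, scaling_factor=16.0,
                            original_max_position_embeddings=2048,
                            beta_fast=32.0, beta_slow=1.0):
    pos_freqs = base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)
    extrapolated = 1.0 / pos_freqs                     # original RoPE frequencies
    interpolated = 1.0 / (scaling_factor * pos_freqs)  # position-interpolated frequencies

    # Dimension index at which a given number of rotations fits into the original
    # context window: dimensions rotating faster than beta_fast keep the original
    # frequencies (extrapolation), dimensions rotating slower than beta_slow are
    # interpolated, and a linear ramp blends the region in between.
    def correction_dim(num_rotations):
        return (dim * math.log(original_max_position_embeddings / (num_rotations * 2 * math.pi))) / (
            2 * math.log(base)
        )

    low = math.floor(correction_dim(beta_fast))
    high = math.ceil(correction_dim(beta_slow))
    ramp = torch.clamp(
        (torch.arange(dim // 2, dtype=torch.float32) - low) / max(high - low, 1e-3), 0.0, 1.0
    )
    inv_freq = interpolated * ramp + extrapolated * (1.0 - ramp)

    # Attention scaling ("mscale") applied to the cos/sin caches, as in the YaRN paper.
    attention_scale = 0.1 * math.log(scaling_factor) + 1.0
    return inv_freq, attention_scale
```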

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@gante

Co-authored-by: Miguel Almeida <miguel.pessanha.almeida@tecnico.ulisboa.pt>
@amyeroberts
Collaborator

cc @ArthurZucker too

@ArthurZucker
Collaborator

Hey! Thanks a lot for taking the time to implement this! 🤗
cc @gante: with this and Phi-3's new method, I'm not sure it makes sense to add all of them in the modeling code, WDYT?

This is not specific to Llama, so we might need some modularity for this!

@miguelm-almeida
Contributor

🤗

@gante gante left a comment (Member)

Hello Miguel and Miguel! 👋 (@miguelm-almeida @mig-mfreitas)

I have a couple of requests regarding user experience and recent changes in our repo. My number one suggestion would be to delete the diff in all models except Llama, and leave the model copies for another PR. It's much faster for everyone (you and me) to iterate on a single model, and then copy the design when we're happy 🤗 In this particular case (RoPE models), we also have different implementations that we need to iron out on our end before adding YaRN there.

⚠️ Please treat all my comments as if they were made on Llama, and not on Falcon. Some of the suggested changes only work on architectures that are up to date, like Llama (and unlike Falcon).

Finally: one of the goals of this PR should be to be able to load the original YaRN models using transformers. Currently, there are some models on the Hub that have custom code (e.g. https://huggingface.co/NousResearch/Yarn-Llama-2-7b-64k). At the moment, these models require passing trust_remote_code=True to from_pretrained (which loads the custom code in the repo). With this PR, we remove the need for that flag and use the code in transformers instead :)
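
For concreteness, a minimal sketch of the loading flow being described here, using the NousResearch checkpoint linked above:

```python
from transformers import AutoModelForCausalLM

# Today: the YaRN checkpoints on the Hub ship their own modeling code, so loading
# them requires opting into remote code execution.
model = AutoModelForCausalLM.from_pretrained(
    "NousResearch/Yarn-Llama-2-7b-64k", trust_remote_code=True
)

# Goal of this PR: the same checkpoint loads with the in-repo YaRN implementation,
# with no trust_remote_code flag needed.
model = AutoModelForCausalLM.from_pretrained("NousResearch/Yarn-Llama-2-7b-64k")
```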

Ping me if you have further questions (and feel free to ping me by email if I'm taking too long to reply) 🤗

@@ -66,13 +66,31 @@ class OpenLlamaConfig(PretrainedConfig):
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports four scaling

models in the deprecated folder should not be updated :) (let's remove the changes to open_llama)

`{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how
these scaling strategies behave:
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
experimental feature, subject to breaking API changes in future versions.
yarn_rope_scaling (`Dict`, *optional*, defaults to `{'original_max_position_embeddings': 2048, 'extrapolation_factor': 1.0, 'attention_factor': 1.0, 'beta_fast': 32.0, 'beta_slow': 1.0, 'finetuned': False}`):

Two comments:

  1. The contents of yarn_rope_scaling should be part of rope_scaling.

A single config dict for everything related to RoPE scaling is preferable, so we can easily upgrade it to a standalone config class in a future PR :)

It would also allow loading existing models by the original authors, e.g. https://huggingface.co/NousResearch/Yarn-Llama-2-7b-64k (have a look at their custom config code and the model's config file -- both assume all RoPE scaling params are in rope_scaling).


  2. We have found through experience that the best default in config files is no default :) That way, we (Hugging Face):
    a) don't have to push changes to repositories on the Hub in case we find bugs;
    b) can easily distinguish defaults from user-defined values that happen to be equal to the default.

If point 1 is addressed, then no change is needed to the existing default (None). Defaults in the classes and the validation code are very helpful, though!
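
For illustration, one possible shape of the merged dict; the key names follow the yarn_rope_scaling default shown in the diff above, and the factor value is made up:

```python
# Hypothetical merged `rope_scaling` entry in config.json, folding the YaRN-specific
# fields into the existing structure.
rope_scaling = {
    "type": "yarn",
    "factor": 32.0,  # illustrative scaling factor
    "original_max_position_embeddings": 2048,
    "attention_factor": 1.0,
    "beta_fast": 32.0,
    "beta_slow": 1.0,
}
```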

@@ -201,3 +229,55 @@ def _rope_scaling_validation(self):
)
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")

# Copied from transformers.models.llama.configuration_llama.LlamaConfig._yarn_rope_scaling_validation
def _yarn_rope_scaling_validation(self):

Likewise, the contents of this function should be moved into _rope_scaling_validation, and the flags should only be checked if the RoPE scaling method is a YaRN one and the flags exist in the dictionary.
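
A rough sketch of what the merged validation could look like, assuming the combined rope_scaling dict from the comment above (the accepted type names and key set are illustrative, not final):

```python
def _rope_scaling_validation(self):
    # Validate the shared fields first, then the YaRN-specific ones, and only
    # when a YaRN strategy is selected and the key is actually present.
    if self.rope_scaling is None:
        return

    rope_type = self.rope_scaling.get("type")
    factor = self.rope_scaling.get("factor")
    if factor is None or not isinstance(factor, float) or factor <= 1.0:
        raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {factor}")

    if rope_type in ("yarn", "dynamic-yarn"):
        for key in ("attention_factor", "beta_fast", "beta_slow"):
            value = self.rope_scaling.get(key)
            if value is not None and not isinstance(value, (int, float)):
                raise ValueError(f"`rope_scaling`'s {key} field must be a number, got {value}")
```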

@@ -162,12 +188,14 @@ def __init__(
self.max_position_embeddings = max_position_embeddings
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.yarn_rope_scaling = yarn_rope_scaling

Suggested change
self.yarn_rope_scaling = yarn_rope_scaling

(as per the comment above)

Comment on lines +94 to +95
extrapolation_factor (`float`, defaults to 1):
Factor to adjust the n-dimensional rotational scaling for extrapolation.

I see this parameter (extrapolation_factor) is in the original implementation. However, if we dig further, we can see that it is not used in practice (unless I'm missing something -- feel free to correct me!):

  1. The default value of 1.0 does not change the computation;
  2. There are no references to it in the YaRN paper;
  3. I couldn't find any YaRN model on the Hub that has set this parameter in config.json, meaning the default of 1.0 is always used;
  4. All references in the original repo use the default value;
  5. In an older PR, the author writes "extrapolation_factor and ntk_factor are used for validation purposes, and should not be changed unless it is necessary."

As such, I believe we can:

  1. delete this variable from the config
  2. delete all related code :)

return 1.0
return 0.1 * math.log(scaling_factor) + 1.0

def forward(self, x, seq_len=None):
@gante gante May 27, 2024 (Member)

See the pattern we have in LlamaRotaryEmbedding.forward -- the pattern was changed a few minor versions ago from the one you have here, where sin and cos are cached, to a different one. The new pattern is faster and is compatible with torch.compile.

From a quick glance: I think you may be able to call super().forward and simply apply * self.mscale on the results
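
If that works out, the override could be as small as the sketch below; the class name, the mscale attribute, and the parent forward signature (returning the cos/sin pair) are assumptions on my part, not the final API:

```python
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding


class LlamaYarnRotaryEmbedding(LlamaRotaryEmbedding):
    # Sketch of the suggested pattern: reuse the parent rotary embedding and
    # rescale its outputs by the YaRN attention factor ("mscale").
    def __init__(self, *args, mscale=1.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.mscale = mscale

    def forward(self, x, position_ids):
        cos, sin = super().forward(x, position_ids)
        return cos * self.mscale, sin * self.mscale
```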

Parameter to set the boundary for extrapolation (only) in the linear ramp function.
beta_slow (`float`, *optional*, defaults to 1):
Parameter to set the boundary for interpolation (only) in the linear ramp function.
finetuned (`bool`, *optional*, defaults to `False`):

This can also be removed (see the comments on the new dynamic class)

self._sin_cached[:seq_len, ...].to(dtype=x.dtype),
)

def yarn(self, device):

Suggested change
def yarn(self, device):
def compute_yarn_scaling(self, device):

(Or a similar name. Let's use descriptive function names :) )

device,
)

if finetuned:

Suggested change
if finetuned:
if self.max_position_embeddings != self.original_max_position_embeddings:

This should be true for fine-tuned models, saving us a flag :)

Comment on lines +526 to +529
with self.assertRaises(AssertionError):
torch.testing.assert_close(yarn_cos_long, original_cos_long)
with self.assertRaises(AssertionError):
torch.testing.assert_close(yarn_sin_long, original_sin_long)

let's also check that yarn_sin/cos_short != original_sin/cos_short (i.e. that applying YaRN should change all values)
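
Something along these lines, mirroring the existing assertions; the *_short variable names are assumed to follow the same pattern as the *_long ones in the test above:

```python
# Additional assertions being requested: YaRN should also change the short-sequence values.
with self.assertRaises(AssertionError):
    torch.testing.assert_close(yarn_cos_short, original_cos_short)
with self.assertRaises(AssertionError):
    torch.testing.assert_close(yarn_sin_short, original_sin_short)
```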

@miguelm-almeida
Contributor

Thank you very much for the in-depth review and suggestions! We'll iterate on it and get back to you shortly 🤗
