
[Functorch] vmap over index_select expands the output #115347

Open
vmoens opened this issue Dec 7, 2023 · 1 comment · May be fixed by #126680
Labels: actionable, high priority, module: functorch, triaged

Comments

vmoens (Contributor) commented Dec 7, 2023

🐛 Describe the bug

When calling torch.vmap over torch.index_select with a batched index, the output is expanded (each selected value is repeated across the length of x) and does not match JAX's result:

Expected (JAX)

from jax import numpy as jnp
import jax
jax.vmap(lambda x, y: x[y], (None, 0))(jnp.arange(3), jnp.array([0, 2]))
# Array([0, 2], dtype=int32)

Actual (PyTorch)

import torch
torch.vmap(lambda x, y: torch.index_select(x, 0, y), (None, 0))(torch.arange(3), torch.tensor([0, 2]))
# tensor([[0, 0, 0],
#         [2, 2, 2]])

Versions

PyTorch '2.2.0.dev20231128'

cc @ezyang @gchanan @zou3519 @kadeng @Chillee @samdow @kshitij12345 @janeyx99

zou3519 added the module: functorch label and removed the triage review label on Dec 7, 2023

zou3519 (Contributor) commented Dec 7, 2023

Yeah, this is broken. Here's a test case just using PyTorch:

import torch
x = torch.arange(3)
y = torch.tensor([0, 2])

def f(x, y):
    return torch.index_select(x, 0, y)

result = torch.vmap(f, (None, 0))(x, y)
# vmap should be equivalent to a for-loop + stack over the vmapped Tensor
expected = torch.stack([f(x, y[0]), f(x, y[1])])
print(result)
print(expected)
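The "for-loop + stack" definition above can be written out as a tiny reference in plain Python, which gives an oracle that does not depend on any batching rule. This is a minimal sketch, not torch.func API: `reference_vmap` and `index_select_1d` are hypothetical helpers, tensors are modeled as nested lists, only `in_axes` entries of `None` or `0` are supported, and `index_select_1d` assumes (as PyTorch's index_select appears to do for a 1-D input) that a 0-d index yields a length-1 result.

```python
def index_select_1d(x, index):
    """Hypothetical model of torch.index_select(x, 0, index) for a 1-D list x.

    `index` is an int (modeling a 0-d index) or a list of ints (a 1-D index).
    A 0-d index is treated as a single-element index, so the result has length 1.
    """
    indices = [index] if isinstance(index, int) else index
    return [x[i] for i in indices]


def reference_vmap(f, in_axes):
    """Hypothetical oracle: vmap as a for-loop over axis 0 plus a stack."""
    def batched(*args):
        # Every mapped (axis 0) argument must share one batch size.
        sizes = {len(a) for a, ax in zip(args, in_axes) if ax == 0}
        if len(sizes) != 1:
            raise ValueError("mapped arguments need one common batch size")
        (batch_size,) = sizes
        results = []
        for b in range(batch_size):
            # Slice mapped args at position b; pass unmapped (None) args through.
            per_example = [a[b] if ax == 0 else a for a, ax in zip(args, in_axes)]
            results.append(f(*per_example))
        return results  # "stack" is modeled as a list of per-example results
    return batched


x = [0, 1, 2]  # models torch.arange(3)
y = [0, 2]     # models torch.tensor([0, 2])
print(reference_vmap(index_select_1d, (None, 0))(x, y))  # [[0], [2]]
```

Whatever the exact expected shape in PyTorch, the reference makes the key point visible: each batch entry should contribute only the element(s) its own index selects, never a copy spanning all of x.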

The first place to look is the index_select decomposition:

m.impl("index_select", index_select_decomp);

If the index_select decomposition is correct, then it's possible the gather batching rule is wrong:

VMAP_SUPPORT(gather, gather_batch_rule);
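The pointer above suggests index_select is decomposed into gather. As a hedged illustration only (a hypothesis about the failure mode, not a reading of the actual functorch code), plain-Python models show that a 1-D gather whose 0-d per-example index has been expanded to x's full length reproduces exactly the reported vmap output, while a length-1 index reproduces the loop-and-stack result. `gather_1d` and the expansion step are hypothetical.

```python
def gather_1d(x, index):
    """Hypothetical model of torch.gather(x, 0, index) for 1-D x: out[i] = x[index[i]]."""
    return [x[i] for i in index]


x = [0, 1, 2]  # models torch.arange(3)
y = [0, 2]     # models torch.tensor([0, 2])

# Hypothesis: each 0-d per-example index is expanded to len(x) before gathering.
suspected = [gather_1d(x, [i] * len(x)) for i in y]
# Loop-and-stack behaviour: the 0-d index acts as a single-element index.
correct = [gather_1d(x, [i]) for i in y]

print(suspected)  # [[0, 0, 0], [2, 2, 2]]
print(correct)    # [[0], [2]]
```

If the real decomposition or batching rule broadcasts the index to the input's shape in this way, that would explain why the output looks "expanded" rather than merely wrong.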

malfet added the triaged label and removed the triage review label on Dec 11, 2023
vmoens linked a pull request on May 20, 2024 that will close this issue

3 participants