
Unable to run FLAN-T5 inference on GCP TPU v3 (TF 2.16.1) #30901

Open · sumanthratna opened this issue May 20, 2024 · 10 comments
Labels: Core: Tokenization, TensorFlow, TPU

sumanthratna commented May 20, 2024

System Info

  • transformers version: 4.41.0
  • Platform: Linux-5.19.0-1030-gcp-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • Huggingface_hub version: 0.23.0
  • Safetensors version: 0.4.3
  • Accelerate version: not installed
  • Accelerate config: not found
  • PyTorch version (GPU?): not installed (NA)
  • Tensorflow version (GPU?): 2.16.1 (False)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: no, using TPU
  • Using distributed or parallel set-up in script?: no, using TPU VM

Who can help?

@gante, @Rocketknight1, @ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

  1. gcloud compute tpus tpu-vm create my-tpu-vm --zone=us-central1-a --accelerator-type=v3-8 --version=tpu-vm-tf-2.16.1-pjrt
  2. gcloud compute tpus tpu-vm ssh --zone "us-central1-a" "my-tpu-vm" --project $GCP_PROJECT_NAME
  3. python3 -m pip install transformers sentencepiece
  4. Set environment variables and write the test script:
export PJRT_DEVICE=TPU
export TPU_NAME=local
export NEXT_PLUGGABLE_DEVICE_USE_C_API=true
export TF_PLUGGABLE_DEVICE_LIBRARY_PATH=/lib/libtpu.so

cat > test-hf.py <<EOF
from transformers import AutoTokenizer, TFT5ForConditionalGeneration
import tensorflow as tf

resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)

with strategy.scope():
    tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
    model = TFT5ForConditionalGeneration.from_pretrained("google/flan-t5-small")

    task_prefix = "translate English to Romanian: "
    sentences = ["The house is wonderful.", "I like to work in NYC.", "My favorite food is pizza; what is yours?"]

    inputs = tokenizer([task_prefix + sentence for sentence in sentences], return_tensors="tf", padding=True)
    print("inputs", inputs)

    xla_generate = tf.function(model.generate, jit_compile=True)

output_sequences = strategy.run(
    xla_generate,
    args=(),
    kwargs={
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "do_sample": False,
    },
)
print("output_sequences", output_sequences)
print(tokenizer.batch_decode(output_sequences, skip_special_tokens=True))
EOF
  5. python3 test-hf.py

Output and traceback:

inputs {'input_ids': <tf.Tensor: shape=(3, 18), dtype=int32, numpy=
array([[13959,  1566,    12,  3871,    29,    10,    37,   629,    19,
         1627,     5,     1,     0,     0,     0,     0,     0,     0],
       [13959,  1566,    12,  3871,    29,    10,    27,   114,    12,
          161,    16, 13465,     5,     1,     0,     0,     0,     0],
       [13959,  1566,    12,  3871,    29,    10,   499,  1305,   542,
           19,  6871,   117,   125,    19,    39,     7,    58,     1]],
      dtype=int32)>, 'attention_mask': <tf.Tensor: shape=(3, 18), dtype=int32, numpy=
array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
       [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
       [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
      dtype=int32)>}
2024-05-20 04:38:33.191211: I tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.cc:265] Subgraph fingerprint:7549819359939688275
2024-05-20 04:38:33.323311: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:961] model_pruner failed: INVALID_ARGUMENT: Graph does not contain terminal node StatefulPartitionedCall.
output_sequences None
Traceback (most recent call last):
  File "/home/sumanthratna/test-hf2.py", line 30, in <module>
    print(tokenizer.batch_decode(output_sequences, skip_special_tokens=True))
  File "/home/sumanthratna/.local/lib/python3.10/site-packages/transformers/tokenization_utils_base.py", line 3796, in batch_decode
    return [
TypeError: 'NoneType' object is not iterable

The same behavior occurs when using StreamExecutor rather than PJRT, and when removing jit_compile=True from tf.function(model.generate, jit_compile=True).

Expected behavior

I expect generation to succeed and yield the following final output:

['Houseul este minunat.', 'Doresc să lucreze în NYC.', 'Favorite mâncare este pizza; ceea ce este dumneavoastră?']
@ArthurZucker (Collaborator)

Hey! Looking at the traceback and the code, it seems like you are using custom code plus strategy = tf.distribute.TPUStrategy(resolver).
I won't have time to look into it; maybe @Rocketknight1 will? Otherwise I'd recommend asking on the forum!

@Rocketknight1 (Member)

cc @sayakpaul - do we have any examples for XLA generation on TPU?

Also, one thing I'd point out is that in general, jit_compile=True actually causes problems on TPU, because XLA compilation happens implicitly inside a TPUStrategy.scope(). Have you tried without that?
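
For reference, this is roughly what I mean. A minimal sketch (untested on my end, reusing the model and inputs from your reproduction) that drops jit_compile=True and lets the strategy compile implicitly:

import tensorflow as tf
from transformers import AutoTokenizer, TFT5ForConditionalGeneration

resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)

with strategy.scope():
    tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
    model = TFT5ForConditionalGeneration.from_pretrained("google/flan-t5-small")

inputs = tokenizer(
    ["translate English to Romanian: The house is wonderful."],
    return_tensors="tf",
    padding=True,
)

# No jit_compile=True here: XLA compilation already happens implicitly
# inside the TPUStrategy, so a plain tf.function is enough.
generate_fn = tf.function(model.generate)
output_sequences = strategy.run(
    generate_fn,
    kwargs={
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "do_sample": False,
    },
)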

@sayakpaul (Member)

No, I don't think so. I also concur with your suggestion here:

Also, one thing I'd point out is that in general, jit_compile=True actually causes problems on TPU, because XLA compilation happens implicitly inside a TPUStrategy.scope(). Have you tried without that?

@Rocketknight1 (Member)

Understood! @sumanthratna, let me know if you can't get it working, and I'll try to reproduce the issue here and make a working example.

@sumanthratna (Author)

Thanks for the replies! @sayakpaul @Rocketknight1 I did try without jit_compile and saw the same behavior (see my note at the bottom of "Reproduction").

It may also be relevant that I'm not able to run the above code using just XLA (without a tf.distribute strategy); I'll post a reproduction here shortly.

@Rocketknight1 (Member)

Hi @sumanthratna, yeah - our recommendation for TPU debugging is to first get the code working on CPU/GPU with jit_compile=True to enable XLA, then remove jit_compile=True and run it on TPU.

We have a guide specifically on XLA generation with TensorFlow here, which might help!
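
The CPU/GPU debugging step would look something like this; a minimal sketch following the guide (flan-t5-small and the fixed max_length are just example choices):

import tensorflow as tf
from transformers import AutoTokenizer, TFT5ForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
model = TFT5ForConditionalGeneration.from_pretrained("google/flan-t5-small")

# On CPU/GPU, XLA has to be requested explicitly with jit_compile=True.
xla_generate = tf.function(model.generate, jit_compile=True)

# Padding to a fixed length keeps input shapes static, which avoids
# a fresh XLA compilation for every new input length.
inputs = tokenizer(
    ["translate English to Romanian: The house is wonderful."],
    return_tensors="tf",
    padding="max_length",
    max_length=32,
)
output_sequences = xla_generate(**inputs, do_sample=False)
print(tokenizer.batch_decode(output_sequences, skip_special_tokens=True))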

@sumanthratna (Author) commented May 20, 2024

our recommendation for TPU debugging is to first get the code working on CPU/GPU with jit_compile=True to enable XLA, then remove jit_compile=True and run it on TPU.

XLA (using CPU) on a CPU machine

See here for my successful go at running inference with XLA on CPU: https://colab.research.google.com/drive/1KOrB7DBm92isAfvsiQcaMUs1uZGCQyUp?usp=sharing.

XLA (using CPU) on a TPU VM

When I run the above code from the notebook (without jit_compile), I see this:

2024-05-20 14:06:19.470300: W tensorflow/core/framework/op_kernel.cc:1839] OP_REQUIRES failed at xla_ops.cc:580 : NOT_FOUND: could not find registered transfer manager for platform Host -- check target linkage
Traceback (most recent call last):
  File "/home/sumanthratna/test-hf.py", line 13, in <module>
    output_sequences = xla_generate(
  File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/util/traceback_utils.py", line 153, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/eager/execute.py", line 53, in quick_execute
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
tensorflow.python.framework.errors_impl.NotFoundError: could not find registered transfer manager for platform Host -- check target linkage [Op:__inference_generate_7515]
I0520 14:06:20.563726256    8141 work_stealing_thread_pool.cc:269]     WorkStealingThreadPoolImpl::Quiesce
D0520 14:06:20.564341113    8141 init.cc:165]                          grpc_shutdown starts clean-up now

@Rocketknight1 (Member)

Yeah - cc @gante, did we ever try XLA generation on TPU? Should we expect it to work at all, or would users need something more like a simplified manual generation loop?

@gante (Member) commented May 29, 2024

@Rocketknight1 nope, at least I haven't :D

@Rocketknight1 (Member)

Hmm, okay. I'm afraid this is just really untested, @sumanthratna! If you get it working, please let us know and we can document it somewhere, but you might have to code some kind of manual generation loop for TPU instead.
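
As a starting point, here's a rough sketch of what I mean by a manual loop: greedy decoding only, untested on TPU, no KV cache or early stopping, and a real TPU version would need fully static shapes (e.g. a pre-allocated output buffer instead of tf.concat):

import tensorflow as tf
from transformers import AutoTokenizer, TFT5ForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
model = TFT5ForConditionalGeneration.from_pretrained("google/flan-t5-small")

inputs = tokenizer(
    ["translate English to Romanian: The house is wonderful."],
    return_tensors="tf",
    padding="max_length",
    max_length=32,
)

max_new_tokens = 32
batch_size = inputs["input_ids"].shape[0]

# T5 decoding starts from decoder_start_token_id (the pad token for T5).
decoder_input_ids = tf.fill((batch_size, 1), model.config.decoder_start_token_id)

for _ in range(max_new_tokens):
    # Full forward pass each step for simplicity (no KV cache).
    outputs = model(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        decoder_input_ids=decoder_input_ids,
    )
    # Greedy: take the highest-scoring next token for each sequence.
    next_token = tf.argmax(outputs.logits[:, -1, :], axis=-1, output_type=tf.int32)
    decoder_input_ids = tf.concat([decoder_input_ids, next_token[:, None]], axis=-1)

print(tokenizer.batch_decode(decoder_input_ids, skip_special_tokens=True))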
