Cheat Sheet: Migrating a PyTorch script to a single TPU

Mar 10, 2024.

See pytorch.org/xla for up-to-date information and for training with multiple TPUs

Installing PyTorch/XLA for TPU

# Usually pre-installed on TPU instances
pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
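
To confirm that a TPU-backed XLA device is visible after installing, a quick sanity check (a minimal sketch; assumes torch_xla 2.x, where xm.xla_device_hw is available):

import torch_xla.core.xla_model as xm

device = xm.xla_device()
print(device)                    # e.g., xla:0
print(xm.xla_device_hw(device))  # "TPU" on a TPU VM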

Add/modify a few lines of code

import torch
import torch_xla.core.xla_model as xm
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = xm.xla_device()
# optional: automatic mixed precision
# (torch.autocast expects a device-type string such as "xla", not a torch.device)
ctx = torch.autocast(device.type, dtype=torch.bfloat16)

model = Net(...)
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), ...)
dataloader = ...

for x, y in dataloader:
    optimizer.zero_grad()
    x, y = x.to(device), y.to(device)
    with ctx:
        loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()
    if device.type == "xla":
        # Explicitly execute the computational graph: XLA tensors are
        # lazy, unlike CPU/GPU tensors
        # (https://pytorch.org/xla/release/2.2/index.html#xla-tensors-are-lazy)
        xm.mark_step()
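
Optionally, torch_xla's MpDeviceLoader can replace the manual device transfer and mark_step: it moves each batch to the XLA device and executes the graph per batch. A minimal sketch under that assumption (check your torch_xla version's docs for details):

import torch_xla.distributed.parallel_loader as pl

device_loader = pl.MpDeviceLoader(dataloader, device)
for x, y in device_loader:  # batches arrive already on the XLA device
    optimizer.zero_grad()
    with ctx:
        loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()
    # no explicit xm.mark_step() here; MpDeviceLoader inserts it for us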

Warning

Pitfall! If you move a model to the TPU only after defining its optimizer, as below, the model parameters will not update during training.

model = Net(...)
optimizer = torch.optim.AdamW(model.parameters(), ...)
model.to(device)

This is most likely because the optimizer ends up holding references to the pre-move parameters instead of the new XLA ones (arguably a bug), even though the same ordering works perfectly fine on CPUs/GPUs.
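
The fix is the ordering used in the example above: move the model to the device first, then construct the optimizer (or reconstruct the optimizer after the move).

model = Net(...)
model.to(device)  # move to the TPU first...
optimizer = torch.optim.AdamW(model.parameters(), ...)  # ...then build the optimizer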

Saving a model

# torch.save(module.state_dict(), path)
xm.save(module.state_dict(), path)  # moves XLA tensors to CPU before serializing
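
Loading back is then plain PyTorch, since the checkpoint holds CPU tensors (a sketch reusing Net, path, and device from above):

module = Net(...)
module.load_state_dict(torch.load(path))
module.to(device)  # move before (re)creating any optimizer, per the warning above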