MPS Support #67

Open
wants to merge 4 commits into
base: main

@ludsvick ludsvick commented Mar 28, 2025

🔥 Summary

Adds support for PyTorch's Metal Performance Shaders (MPS) backend to improve performance on macOS with Apple Silicon or AMD GPUs. Also updates the Cache class with a device-aware update method.

📖 📷 Additional context

  • Updates to kv_cache are now handled by the Cache update method
    • Uses static kernel ops in CUDA Graph mode
    • Falls back to slicing on unsupported devices
  • generate and vq_vae_encode_decode updated to include MPS support
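
To illustrate the idea behind the device-aware update, here is a minimal sketch, not the code in this PR. The tensor layout, the method names `_update_static` and `_update_slice`, the helper `make_cache`, and the choice of `index_copy_` as the "static" op are all assumptions made for illustration; the actual Cache class may differ.

```python
from dataclasses import dataclass

import torch


@dataclass
class Cache:
    # Assumed layout (hypothetical): (batch, heads, max_seq_len, head_dim)
    key_states: torch.Tensor
    value_states: torch.Tensor

    def _update_static(self, pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> None:
        # In-place write indexed by a tensor. No Python integer is read from
        # the device, so shapes and control flow stay fixed between calls,
        # which is what CUDA graph capture needs.
        self.key_states.index_copy_(2, pos, k)
        self.value_states.index_copy_(2, pos, v)

    def _update_slice(self, pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> None:
        # Fallback: plain slice assignment. Assumes the new tokens occupy
        # consecutive positions starting at pos[0].
        start = int(pos[0])
        end = start + k.shape[2]
        self.key_states[:, :, start:end] = k
        self.value_states[:, :, start:end] = v

    def update(self, pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> None:
        # Pick the write strategy from the device the cache lives on.
        if self.key_states.device.type == "cuda":
            self._update_static(pos, k, v)
        else:
            self._update_slice(pos, k, v)


def make_cache(device: str = "cpu") -> Cache:
    # Small hypothetical cache used only to exercise the sketch.
    shape = (1, 2, 8, 4)
    return Cache(torch.zeros(shape, device=device), torch.zeros(shape, device=device))
```

Both paths write the same values; they differ only in whether the position is consumed as a tensor or read back as a Python integer.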

🛠 Testing instructions

Mesh Generation, CPU

import torch
import trimesh
from cube3d.inference.engine import Engine

# load ckpt
config_path = "cube3d/configs/open_model.yaml"
gpt_ckpt_path = "model_weights/shape_gpt.safetensors"
shape_ckpt_path = "model_weights/shape_tokenizer.safetensors"
engine = Engine(
    config_path, 
    gpt_ckpt_path, 
    shape_ckpt_path, 
    device=torch.device("cpu"),
)

# inference
input_prompt = "A scaly gecko with a long winding tail"
mesh_v_f = engine.t2s([input_prompt], use_kv_cache=True, resolution_base=6.5, top_p=0.0)

# save output
vertices, faces = mesh_v_f[0][0], mesh_v_f[0][1]
_ = trimesh.Trimesh(vertices=vertices, faces=faces).export("outputs/cpu_output.obj")

Output

generating: 100%|█████████████████████████████| 512/512 [00:56<00:00,  8.98it/s]
extracting geometry: 100%|█████████████████████| 8/8 [05:28<00:00, 41.03s/chunk]

Mesh Generation, MPS

import torch
import trimesh
from cube3d.inference.engine import Engine

# load ckpt
config_path = "cube3d/configs/open_model.yaml"
gpt_ckpt_path = "model_weights/shape_gpt.safetensors"
shape_ckpt_path = "model_weights/shape_tokenizer.safetensors"
engine = Engine(
    config_path, 
    gpt_ckpt_path, 
    shape_ckpt_path, 
    device=torch.device("mps"),
)

# inference
input_prompt = "A scaly gecko with a long winding tail"
mesh_v_f = engine.t2s([input_prompt], use_kv_cache=True, resolution_base=6.5, top_p=0.0)

# save output
vertices, faces = mesh_v_f[0][0], mesh_v_f[0][1]
_ = trimesh.Trimesh(vertices=vertices, faces=faces).export("outputs/mps_output.obj")

Output

generating: 100%|█████████████████████████████| 512/512 [01:12<00:00,  7.08it/s]
extracting geometry: 100%|█████████████████████| 8/8 [00:17<00:00,  2.15s/chunk]
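
For a rough comparison of the two runs, the elapsed times printed in the progress bars above can be converted to seconds and compared. This is only arithmetic on the logged numbers from a single run per device; it excludes model loading, and the helper name `to_seconds` is introduced here for illustration.

```python
def to_seconds(mm_ss: str) -> int:
    """Convert a tqdm-style 'MM:SS' elapsed time to seconds."""
    minutes, seconds = mm_ss.split(":")
    return int(minutes) * 60 + int(seconds)


# Elapsed times copied from the two Output sections above.
timings = {
    "cpu": {"generating": "00:56", "extracting geometry": "05:28"},
    "mps": {"generating": "01:12", "extracting geometry": "00:17"},
}

# Total elapsed seconds per device across both stages.
totals = {
    device: sum(to_seconds(t) for t in stages.values())
    for device, stages in timings.items()
}

# Ratios greater than 1 mean MPS was faster than CPU for that stage.
generation_ratio = to_seconds("00:56") / to_seconds("01:12")
geometry_ratio = to_seconds("05:28") / to_seconds("00:17")
overall_ratio = totals["cpu"] / totals["mps"]

print(totals)                      # {'cpu': 384, 'mps': 89}
print(round(generation_ratio, 2))  # 0.78
print(round(geometry_ratio, 1))    # 19.3
print(round(overall_ratio, 1))     # 4.3
```

In this run, token generation was somewhat slower on MPS, while geometry extraction was roughly 19 times faster, giving about a 4.3x reduction in total elapsed time.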

✅ Checklist

  • Provide testing instructions
  • Update relevant documentation

@ludsvick ludsvick requested a review from a team as a code owner March 28, 2025 20:48
Collaborator

@animan42 animan42 left a comment

Thank you for getting this out quickly! I will test it on some CUDA devices later today.

@@ -125,7 +125,11 @@ def run_shape_decode(
 help="Path to save the recovered mesh file.",
 )
 args = parser.parse_args()
-device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+device = (
Collaborator

Let's move device selection to a utility function in https://github.com/Roblox/cube/blob/main/cube3d/inference/utils.py since it's being used at least twice.

Author

select_device() added to utils 👍
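
For readers following the thread, a device-selection helper of this kind might look like the sketch below. Only the name `select_device` comes from the comment above; the signature, the CUDA-then-MPS-then-CPU priority order, and the body are assumptions, and the function actually added to utils may differ.

```python
import torch


def select_device() -> torch.device:
    """Return the best available device (assumed order: CUDA, MPS, CPU)."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # torch.backends.mps is absent in older PyTorch builds, so look it up defensively.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")
    return torch.device("cpu")
```

A caller would then replace the inline conditional with `device = select_device()`, so every entry point shares one selection policy.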

@animan42 animan42 linked an issue Mar 29, 2025 that may be closed by this pull request

Successfully merging this pull request may close these issues.

Metal Performance Shader (MPS) Integration