2024-05-02
I’ve had luck getting GPUs quickly on graham.computecanada.ca, so I suggest you use something else.
On the login node:
Previously I was explicitly loading a newer StdEnv, but for now the latest one is the default. So this step isn't necessary right now; I'm keeping it for future reference.
module load StdEnv/2023
Find the least stale python and cuda:
module spider python
module spider cuda
Grab specific versions:
module load python/3.11.5
module load cuda/12.2
Or, if you're feeling lucky, do what I usually do and just module load python cuda
Create a Python virtual environment; I'll name it jaxenv:
python -m venv ~/jaxenv
Activate and update it. Using --no-index makes pip go after the cluster-built wheels, which prevents dependency hell.
source ~/jaxenv/bin/activate
python -m pip install --upgrade pip --no-index
python -m pip install jax --no-index
Get an interactive session with a GPU just to kick the tires:
salloc --time=0:20:00 --mem=3500 --gres=gpu:1 --account=def-rsadve
Once you're in that worker's shell:
source ~/jaxenv/bin/activate
python
Once you've got an interactive Python session:
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)
This should report the GPU backend.
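As far as I know, jax.lib.xla_bridge is an internal module that newer JAX releases have deprecated, so here is a sketch of the same check using the public jax.default_backend() and jax.devices() calls instead. The helper name describe_backend and the shape of the dictionary it returns are my own invention, not anything JAX provides. It also reports cleanly if JAX isn't importable, for example if you forgot to activate jaxenv.

```python
import importlib.util


def describe_backend():
    """Summarize what JAX can see (hypothetical helper, not part of JAX)."""
    # If the virtual environment is not active, JAX will not be importable.
    if importlib.util.find_spec("jax") is None:
        return {"jax_installed": False, "backend": None, "devices": []}

    import jax

    return {
        "jax_installed": True,
        # Name of the default backend: "gpu", "cpu" or "tpu".
        "backend": jax.default_backend(),
        # Every device JAX can see on the default backend.
        "devices": [str(d) for d in jax.devices()],
    }


info = describe_backend()
print(info)
```

On a GPU node with the environment activated, I'd expect backend to come back as gpu. If it says cpu, the usual suspects are a missing cuda module or a session that was allocated without --gres=gpu:1.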
Then sometimes I’ll run a little GPU code to see if anything exciting blows up.
In an interactive session with a GPU attached:
import jax.numpy as jnp
from jax import grad, jit, vmap, random
key = random.key(0)
x = random.normal(key, (3000,3000), dtype=jnp.float32)
jnp.dot(x, x.T).block_until_ready() # runs on the GPU
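The block_until_ready() call matters if you want to time anything: JAX dispatches work asynchronously, so without it you'd mostly be measuring how long it takes to queue the operation. Below is a rough timing sketch along the same lines as the snippet above. The function name time_matmul and its default sizes are hypothetical choices of mine, and it returns None when JAX isn't importable so it can't blow up on a login node.

```python
import importlib.util
import time


def time_matmul(n=1000, repeats=3):
    """Average seconds per n-by-n matmul, or None if JAX is unavailable."""
    if importlib.util.find_spec("jax") is None:
        return None

    import jax.numpy as jnp
    from jax import random

    x = random.normal(random.key(0), (n, n), dtype=jnp.float32)

    # Warm-up call: the first run includes compilation and device setup.
    jnp.dot(x, x.T).block_until_ready()

    start = time.perf_counter()
    for _ in range(repeats):
        # Wait for the result so we time the computation, not just the dispatch.
        jnp.dot(x, x.T).block_until_ready()
    return (time.perf_counter() - start) / repeats


elapsed = time_matmul()
print("JAX not available" if elapsed is None else f"{elapsed:.6f} s per matmul")
```

Running it once on the GPU node and once on a CPU-only node gives a crude feel for whether the GPU is actually doing the work.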