JAX is a relatively new contender for the top spot amongst machine learning libraries, up against the (currently) more established TensorFlow and PyTorch.
I am coming around to the speed and philosophy of JAX. Furthermore, DeepMind and Google Research teams are using JAX to develop a variety of their state-of-the-art models, so we’d better follow along!
JAX’s developers make up for any feature or API immaturity with their awesome responsiveness and helpfulness. I’ve had a few teething problems with the library, raised issues on the JAX GitHub project, and without fail received prompt and helpful responses. Here I am sharing one answer I got recently. I think it deserves more exposure because it unblocked me dramatically and I could not find it documented anywhere. Relevant for at least jax==0.2.25.
Shout out to GitHub user and jax developer (amongst many other things) skye for the info in this post!
And shout out to the TPU Research Cloud – Google’s free TPU quota for researchers. It was using these machines that I ran into this problem.
How to run multiple Jax programmes, one per TPU
If your machine has more than one device, you might want to run several programmes concurrently, one per device. JAX cannot do this by default.
Example: you are doing a hyperparameter search on a programme designed to run on a single device (which is usually the default). You have 8 devices, 32 CPU cores and plenty of RAM, so you want to run 8 experiments at once.
Don’t confuse this scenario with mapping a function or a whole programme across multiple devices. That can be done cleanly with the built-in pmap.
Specifically, JAX cannot detect devices in a new session once JAX has been imported in another session. This means I could only run one experiment at a time on, for example, the 4 TPU chips of a Google Cloud TPU machine. I raised an issue to this effect.
I was pointed to a refresher on TPU architecture (fair enough!). In summary, there are two cores per chip. Jax shows each core as a TpuDevice, so it will show 8 devices (cores) when there are 4 chips available.
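To see this on your own machine, a quick check along these lines lists what JAX can see. It is only a sketch: run it in its own Python process with nothing else using the TPU, and note that the count of 8 applies to a 4-chip machine like the one described here. On a laptop it will most likely just report a single CPU device.

```python
import jax

# Each TPU core shows up as one device, so a 4-chip machine
# with 2 cores per chip should report 8 devices here.
devices = jax.devices()
print(f"{jax.device_count()} device(s) visible")

for d in devices:
    # id, platform ("tpu", "gpu" or "cpu") and a human-readable device kind
    print(d.id, d.platform, d.device_kind)
```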
There are particular environment variables that will allow you to run one JAX programme per chip. Set them before importing jax, as follows.

```python
# 4x processes: 1 chip each (2 cores per chip)
import os

os.environ["TPU_CHIPS_PER_HOST_BOUNDS"] = "1,1,1"
os.environ["TPU_HOST_BOUNDS"] = "1,1,1"
# Different per process:
os.environ["TPU_VISIBLE_DEVICES"] = "0"  # "1", "2", "3"
# Pick a unique port per process
os.environ["TPU_MESH_CONTROLLER_ADDRESS"] = "localhost:8476"
os.environ["TPU_MESH_CONTROLLER_PORT"] = "8476"

# IMPORTANT - this must be imported after setting environment variables
import jax

# Works on jax==0.2.25; newer releases may need .devices() instead
print(jax.numpy.ones(1).device())
```

For convenience, I made a bash script to kick off experiments.

```bash
DEVICE_NUM=$1  # first argument
# specific to my programme - defines the experiment
CONFIG_NUM=$2  # second argument
# unique port for each host
ADDRESS=$((8476 + DEVICE_NUM))

export TPU_CHIPS_PER_HOST_BOUNDS=1,1,1
export TPU_HOST_BOUNDS=1,1,1
export TPU_VISIBLE_DEVICES=$DEVICE_NUM
export TPU_MESH_CONTROLLER_ADDRESS=localhost:$ADDRESS
export TPU_MESH_CONTROLLER_PORT=$ADDRESS

# This will be specific to your programme
python main.py --config-num "$CONFIG_NUM"
```
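The same per-chip settings could also be driven from a small Python launcher rather than bash. The following is a rough sketch, not something I have run on a TPU: the helper names `tpu_env_for_chip` and `launch_all` are made up for illustration, and using the chip index as the `--config-num` value is purely an assumption about how you number your experiments.

```python
import os
import subprocess
import sys

BASE_PORT = 8476  # same starting port as in the settings above


def tpu_env_for_chip(chip_index, base_port=BASE_PORT):
    """Return the TPU environment variables for a process pinned to one chip."""
    port = base_port + chip_index  # unique port per process
    return {
        "TPU_CHIPS_PER_HOST_BOUNDS": "1,1,1",
        "TPU_HOST_BOUNDS": "1,1,1",
        "TPU_VISIBLE_DEVICES": str(chip_index),
        "TPU_MESH_CONTROLLER_ADDRESS": f"localhost:{port}",
        "TPU_MESH_CONTROLLER_PORT": str(port),
    }


def launch_all(num_chips=4, script="main.py"):
    """Start one subprocess per chip and wait for all of them to finish."""
    procs = []
    for chip in range(num_chips):
        # Child inherits the current environment plus the per-chip overrides,
        # so the variables are in place before the child imports jax.
        env = {**os.environ, **tpu_env_for_chip(chip)}
        # Assumption: experiment config N runs on chip N.
        cmd = [sys.executable, script, "--config-num", str(chip)]
        procs.append(subprocess.Popen(cmd, env=env))
    return [p.wait() for p in procs]


# Usage (on the TPU machine, next to main.py):
# exit_codes = launch_all()
```

Because each child gets its variables through `env=`, the launcher itself never needs to import jax, which avoids the parent process grabbing the TPU first.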