JAX on LUMI-G

The instructions that follow describe the installation of JAX on LUMI-G using EasyBuild and a container overlay. The procedure is based on the following resources:

  • An EasyBuild configuration file for JAX:

wget https://462000008.lumidata.eu/easyconfigs/jax-0.4.13-rocm-5.6.1-python-3.10-singularity-20231130.eb
  • A script for creating an overlay for the JAX container:

wget https://462000008.lumidata.eu/userscripts/create_overlay.sh

Preparation: Set the Location for Your EasyBuild Installation

Before you start installing software with EasyBuild on LUMI-G, it’s essential to define where your installations should go. The procedure below details the necessary steps. If you have already installed software with EasyBuild on LUMI-G, you can skip this section.

Step 1: Define EasyBuild Installation Directory

By default, EasyBuild installs software in the $HOME/EasyBuild directory. To avoid consuming the limited storage space in your home directory, it is recommended to install software within your project directory. This not only helps manage storage quotas but also allows you to share installations with your project team.

Update the EBU_USER_PREFIX environment variable to point to your chosen directory within the /project/project_* or /project/project_*/username paths:

Add the following to your .bashrc file:

export EBU_USER_PREFIX=/project/project_465000000/EasyBuild

or, for a user-specific installation:

export EBU_USER_PREFIX=/project/project_465000000/username/EasyBuild

Ensure you replace 465000000 with your actual project number, and username with your real username.

By setting the EBU_USER_PREFIX variable to a directory within your project space, you alleviate space constraints in your home directory and promote resource sharing among your project team members.
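As a minimal sketch (assuming the shared, project-wide layout above), you can create the directory and make the new setting active in your current shell:

# Create the installation directory and pick up the new setting
# (replace 465000000 with your own project number).
mkdir -p /project/project_465000000/EasyBuild
source ~/.bashrc
echo $EBU_USER_PREFIX   # should print the directory you chose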

Step 2: Load the LUMI Software Stack

Loading the LUMI software stack ensures you’re working with a consistent software environment. For example, to load the LUMI/23.09 stack, use the following command:

module load LUMI/23.09

This command should automatically load the appropriate partition module.
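If you want to be explicit about the partition you are targeting, you can also load a partition module yourself; for example (a sketch — note that the container installation below uses partition/container instead):

module load LUMI/23.09 partition/G
module list   # check which LUMI and partition modules are active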

Step 3: Load the EasyBuild Module

Now, load EasyBuild with the following module command, which will also report the user installation directory it is configured to use:

module load EasyBuild-user

To verify EasyBuild’s configuration:

eb --show-config

Tip for users with multiple projects: Do not change EBU_USER_PREFIX when a LUMI module is loaded.
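A minimal sketch of switching projects safely (the second project number is hypothetical): unload the software stack first, then change the prefix, then reload:

module purge
export EBU_USER_PREFIX=/project/project_465000001/EasyBuild   # hypothetical second project
module load LUMI/23.09
module load EasyBuild-user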

For more details and advanced configurations, refer to the LUMI documentation.

Installing JAX into Your Project Environment Using EasyBuild

On LUMI-G, locate a Singularity image with JAX in the directory:

/appl/local/containers/easybuild-sif-images/
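For example, to list the JAX images currently provided there (the available image names change over time):

ls /appl/local/containers/easybuild-sif-images/ | grep -i jax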

Find a suitable image. At the time of writing, the following image is available:

lumi-jax-rocm-5.6.1-python-3.10-jax-0.4.13-dockerhash-1e625e0cfb23.sif

Create an EasyBuild configuration file, e.g. jax-0.4.13-rocm-5.6.1-python-3.10-singularity.eb, with the following content:

easyblock = 'MakeCp'

name = 'jax'
version = '0.4.13'
versionsuffix = '-rocm-5.6.1-python-3.10-singularity-20231130'

local_sif = 'lumi-jax-rocm-5.6.1-python-3.10-jax-0.4.13-dockerhash-1e625e0cfb23.sif'

homepage = 'https://jax.readthedocs.io/'

whatis = [
    'Description: JAX is Autograd and XLA, brought together for high-performance numerical computing.'
]

description = """
This module provides a container with jax %(version)s.

The module defines a number of environment variables:
* SIF and SIFJAX: The full path and name of the Singularity SIF file
  to use with singularity exec etc.
* SINGULARITY_BINDPATH: Mounts the necessary directories from the system,
  including /users, /project, /scratch, and /flash so that you should be
  able to use your regular directories in the container.
* RUNSCRIPTS and RUNSCRIPTSJAX: The directory with some sample
  runscripts.

Note that this container uses a Conda environment internally. When in
the container, the command to activate the container is contained in the
environment variable WITH_CONDA.
"""

toolchain = SYSTEM

sources = [
    {
        'filename': local_sif,
        'extract_cmd': '/bin/cp %s .'
    },
]

skipsteps = ['build']

files_to_copy = [
    ([local_sif], '.'),
]

local_runscript_python_simple = """
#!/bin/bash -e

# Start conda environment inside the container
\$WITH_CONDA

# Run application
python "\$@"
"""

postinstallcmds = [
    'mkdir -p %(installdir)s/runscripts',
    f'cat >%(installdir)s/runscripts/conda-python-simple <<EOF {local_runscript_python_simple}EOF',
    'chmod a+x %(installdir)s/runscripts/conda-python-simple',
]

modextravars = {
    'RUNSCRIPTS': '%(installdir)s/runscripts',
    'RUNSCRIPTSJAX': '%(installdir)s/runscripts',
    'SINGULARITY_BIND': '/var/spool/slurmd,/opt/cray,/usr/lib64/libcxi.so.1,/usr/lib64/libjansson.so.4,' +
                        '%(installdir)s/runscripts:/runscripts,' +
                        '/pfs,/scratch,/projappl,/project,/flash,/appl',
}

sanity_check_paths = {
    'files': ['runscripts/conda-python-simple'],
    'dirs': [],
}

modluafooter = f"""
-- Call a routine to set the various environment variables.
create_container_vars('{local_sif}', 'jax', '%(installdir)s')
"""

moduleclass = 'devel'

In the local_sif line, change the name of the Singularity image to the one you want to use.

Then, install the module using EasyBuild:

module load LUMI/23.09
module load partition/container
module load EasyBuild-user
eb jax-0.4.13-rocm-5.6.1-python-3.10-singularity.eb
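Once the build finishes, you can do a quick sanity check (a sketch, assuming the module name produced by the recipe above): load the module and import JAX through the bundled runscript. A login node has no GPUs, so checking the version is sufficient here.

module load LUMI/23.09
module load jax/0.4.13-rocm-5.6.1-python-3.10-singularity-20231130
singularity exec $SIF /runscripts/conda-python-simple -c "import jax; print(jax.__version__)"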

Creating an Overlay for the JAX Container

In cases where you need to install additional packages into the JAX container environment provided by LUMI-G, an overlay filesystem can be created. Here’s how to do it:

  1. Ensure you have already installed and loaded the JAX container module, following the steps for installing JAX via EasyBuild outlined above.

  2. Save the following script as create_overlay.sh. In this example, it installs mpi4py, mpi4jax, PyTorch, torchvision, and torchaudio.

#!/bin/bash

handle_error() {
  print_fail
  echo -e "\nsee createoverlay.log\n"
}

print_ok() {
  echo -e "\033[50D\033[45C[$(tput setaf 2) OK $(tput sgr0)]"
}

print_fail() {
  echo -e "\033[50D\033[45C[$(tput setaf 1)FAIL$(tput sgr0)]"
}

print_step_start() {
  echo -en " "$1
}

print_instruction() {
  echo -e "\n $(tput setaf 11)$1$(tput sgr0)\n"
}

print_keyvalue() {
  printf "$(tput setaf 4)%20s:$(tput sgr0) %s\n" "$1" "$2"
}

exec 2> createoverlay.log
set -xeuo pipefail
trap 'handle_error' ERR

MPI4PY_VERSION=3.1.5
MPI4JAX_VERSION=0.3.14.post7
PYTORCH_VERSION=2.1.2
TORCHVISION_VERSION=0.16.2
TORCHAUDIO_VERSION=2.1.2

echo -e "\n Will create an overlay with packages:\n"

print_keyvalue "mpi4py" $MPI4PI_VERSION
print_keyvalue "mpi4jax" $MPI4JAX_VERSION
print_keyvalue "pytorch" $PYTORCH_VERSION
print_keyvalue "torchvision" $TORCHVISION_VERSION
print_keyvalue "torchaudio" $TORCHAUDIO_VERSION
echo

print_step_start "Checking if the jax module is loaded"

if [[ ! ":$LOADEDMODULES:" == *":jax/0.4.13-rocm-5.6.1-python-3.10-singularity-20231130:"* ]]
then
  print_fail
  print_instruction Please install and/or load the jax container module
else
  print_ok
fi

print_step_start "Checking if the \$SIF variable is set"

if [[ -z ${SIF+x} ]]
then
  print_fail
  print_instruction "The jax module is loaded but the \$SIF environment variable is not set"
  exit 1
else
  print_ok
  echo
fi

print_keyvalue "base container" $(basename $SIF)
echo

STARTDIR=$PWD

print_step_start "Creating build directory"

rm -rf $XDG_RUNTIME_DIR/overlay_tmp

mkdir $XDG_RUNTIME_DIR/overlay_tmp
cd $XDG_RUNTIME_DIR/overlay_tmp

print_ok

print_step_start "Downloading source"

curl -LO https://github.com/mpi4jax/mpi4jax/archive/refs/tags/v$MPI4JAX_VERSION.tar.gz
curl -LO https://github.com/mpi4py/mpi4py/releases/download/$MPI4PY_VERSION/mpi4py-$MPI4PY_VERSION.tar.gz

print_ok
print_step_start "Writing installation script"

cat <<EOF > install.sh
\$WITH_CONDA

set -xeuo pipefail

export TMPDIR=\$PWD

export INSTALLDIR=\$PWD/staging/\$(dirname \$(dirname \$(which python)) | cut -c2-)
export PYMAJMIN=\$(python -c "import sys; print(f'{sys.version_info[0]}.{sys.version_info[1]}')")
export PYTHONPATH=\$INSTALLDIR/lib/python\$PYMAJMIN/site-packages

tar xf mpi4py-$MPI4PY_VERSION.tar.gz
cd mpi4py-$MPI4PY_VERSION
cp /opt/cray/pe/python/3.10.10/lib/python3.10/site-packages/mpi4py/mpi.cfg .
pip install --no-cache-dir --prefix=\$INSTALLDIR --no-build-isolation .
cd ..

tar xf v$MPI4JAX_VERSION.tar.gz
cd mpi4jax-$MPI4JAX_VERSION/
pip install --no-cache-dir --prefix=\$INSTALLDIR --no-build-isolation .
cd ..

pip3 install --no-cache-dir --prefix=\$INSTALLDIR \\
             --index-url https://download.pytorch.org/whl/rocm5.6 \\
             torch==$PYTORCH_VERSION+rocm5.6 torchvision==$TORCHVISION_VERSION+rocm5.6 torchaudio==$TORCHAUDIO_VERSION+rocm5.6

exit
EOF

chmod u+x ./install.sh

print_ok
print_step_start "Running installation script"

singularity exec -B$PWD $SIF ./install.sh 1>&2

print_ok
print_step_start "Creating SquashFS image"

chmod -R 777 staging/
mksquashfs staging/ overlay.squashfs -no-xattrs -processors 16 1>&2
rm -rf staging

print_ok
print_step_start "Testing"

for package in mpi4py mpi4jax torch
do
  singularity exec --overlay=overlay.squashfs $SIF /runscripts/conda-python-simple -c "import $package"
done

print_ok
print_step_start "Cleaning up"

cp overlay.squashfs $STARTDIR/
cd $STARTDIR

rm -rf $XDG_RUNTIME_DIR/overlay_tmp

print_ok

echo -e "\n$(tput setaf 4)Overlay written to$(tput sgr0) $STARTDIR/overlay.squashfs\n"

print_instruction "To run with the overlay, use the following commands:"

cat <<EOF
 module load LUMI/$LUMI_STACK_VERSION
 module load jax/0.4.13-rocm-5.6.1-python-3.10-singularity-20231130
 singularity exec --overlay=$STARTDIR/overlay.squashfs \$SIF /runscripts/conda-python-simple <python-script.py>

EOF
  3. Make the script executable by running chmod +x create_overlay.sh.

  4. Execute the script by running ./create_overlay.sh. The script runs several checks to ensure you have loaded the appropriate JAX environment and that the SIF (Singularity Image File) environment variable is set. It then downloads the required packages, runs an installation script inside the Singularity container, and creates a SquashFS image containing all installed packages.

  5. The generated overlay.squashfs file will be placed in the starting directory. You can then use this overlay with the following commands:

module load LUMI/$LUMI_STACK_VERSION
module load jax/0.4.13-rocm-5.6.1-python-3.10-singularity-20231130
singularity exec --overlay=$STARTDIR/overlay.squashfs $SIF /runscripts/conda-python-simple <python-script.py>

Replace <python-script.py> with the path to your Python script, $LUMI_STACK_VERSION with the LUMI stack version you used (e.g. 23.09), and $STARTDIR with the directory in which you ran create_overlay.sh. The script prints these commands with the actual values filled in when it finishes.

  6. After creating the overlay, you can use it to execute your Python scripts in an environment with the installed packages. The overlay contains everything needed to run mpi4py, mpi4jax, PyTorch, torchvision, and torchaudio with the JAX container.

  7. To ensure everything is working correctly, the script includes a testing step in which it imports the installed packages using Python inside the Singularity container.

  8. Once the overlay has been tested and confirmed to work, the script performs cleanup, removing the temporary build directory and leaving you with just the overlay.squashfs file.

  9. It is advisable to document the versions of the packages included in your overlay and the JAX container version you are using. This helps maintain reproducibility and eases future updates or modifications to your computational environment.

  10. The overlay creation script logs all of its activity to a file named createoverlay.log. If you encounter any issues during the creation of the overlay, this log will help diagnose what went wrong; see the example below.
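For example, to inspect the last commands and error messages recorded during a failed run:

tail -n 30 createoverlay.log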

In summary, by using the create_overlay.sh script, you can easily augment the computational environment provided by the JAX container with additional packages needed for your work. This approach provides a flexible and reproducible way to manage your software dependencies on LUMI-G.

This concludes the guide on creating an overlay for the JAX container on LUMI-G. Remember to adjust the package versions in the script as necessary to fit the requirements of your specific project or workflow.

Running JAX Jobs on LUMI-G with Slurm

In the following, we show how to run a JAX job on LUMI-G using Slurm. We use the example Python script below to demonstrate a distributed JAX computation with pmap across multiple GPUs on multiple nodes.

The example Python script performs a simple convolution operation on an array in a distributed manner. It demonstrates process index assignment, input array preparation, and the distributed execution of a JAX operation using pmap. At the end of the script, the results are gathered on the root process using MPI’s Gather function and printed.

Example Python Script for Distributed JAX Computation

import jax
import numpy as np
import jax.numpy as jnp

from mpi4py import MPI

# Initialize MPI
comm = MPI.COMM_WORLD
rank = comm.Get_rank()

jax.distributed.initialize()

def convolve(x, w):
    output = []
    for i in range(1, len(x) - 1):
        output.append(jnp.dot(x[i-1:i+2], w))
    return jnp.array(output)

# Determine the process index and the number of processes
process_index = jax.process_index()
n_devices = jax.process_count()

# Create the input arrays based on the process index
# Each process gets its own unique slice of data.
# For example, if process_index is 0, it will select the first row; if 1, the second row, and so on.
xs = np.arange(5 * n_devices).reshape(n_devices, 5)[process_index:process_index+1]
ws = np.array([2., 3., 4.])  # Shared weights array

# Ensure ws has an extra dimension to match the shape of xs
ws = ws.reshape(1, -1)

# Now we can apply pmap with the corrected data shape
distributed_convolve = jax.pmap(convolve, axis_name='p')

# Using the convolve operation in a distributed context
print(f"----- xs (process_index={process_index}) -----")
print(xs)
print("----- ws -----")
print(ws)
print("----- distributed_convolve(xs, ws) -----")
print(distributed_convolve(xs, ws))

# Each process executes its portion of the distributed convolve operation
local_output = distributed_convolve(xs, ws)

# Convert the local JAX array to a NumPy array on each process
local_output_np = np.array(local_output)

# Gather all the results on the root process
if rank == 0:
    # Prepare a container to hold the received data from all processes
    # The size is the total number of processes times the size of each local_output
    gathered_outputs = np.empty([n_devices, *local_output_np.shape[1:]], dtype=local_output_np.dtype)
else:
    gathered_outputs = None

# Use MPI's gather function to collect all arrays on the root process
comm.Gather(local_output_np, gathered_outputs, root=0)

# Now the root process can print the combined array
if rank == 0:
    print("Gathered outputs on root process:")
    print(gathered_outputs)
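The convolution logic itself can be checked on a login node without MPI or GPUs. The sketch below runs the same sliding dot product for a single row inside the container (it assumes the jax module is loaded so that $SIF is set; JAX may warn that it falls back to the CPU). With x = [0, 1, 2, 3, 4] and w = [2, 3, 4] the expected result is [11, 20, 29].

singularity exec $SIF /runscripts/conda-python-simple -c '
import numpy as np, jax.numpy as jnp
x = np.arange(5.0)                 # one row of xs: [0. 1. 2. 3. 4.]
w = np.array([2., 3., 4.])         # shared weights
out = [jnp.dot(x[i-1:i+2], w) for i in range(1, len(x) - 1)]
print(jnp.array(out))              # expected: [11. 20. 29.]
'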

Preparing Your Slurm Job Script

Below is an example Slurm job script that executes the distributed JAX computation using pmap with MPI. Save the Python script above as distributed_pmap_mpi.py and the job script as run_jax_job.sh, and adapt the parameters to the needs of your job, such as --account, which should specify your own project account.

 1  #!/bin/bash -e
 2  #SBATCH --job-name=distributed_jax_example
 3  #SBATCH --account=project_465000000
 4  #SBATCH --time=00:04:00
 5  #SBATCH --partition=standard-g
 6  #SBATCH --nodes=4
 7  #SBATCH --ntasks-per-node=8
 8  #SBATCH --gpus-per-node=8
 9
10  module purge
11  module load LUMI/23.09
12  module load jax
13
14  # Specifying CPU binding for each task to improve performance
15  CPU_BIND="map_cpu:49,57,17,25,1,9,33,41"
16
17  n_nodes=$SLURM_NNODES  # Number of nodes allocated
18  n_gpus=$((n_nodes*$SLURM_GPUS_PER_NODE))  # Total number of GPUs
19  n_tasks=$n_gpus  # Number of tasks is equal to the number of GPUs
20
21  # Run the job using srun
22  srun \
23    -N $n_nodes \
24    -n $n_gpus \
25    --cpu-bind=${CPU_BIND} \
26    --gpus-per-node=$SLURM_GPUS_PER_NODE \
27    singularity exec \
28      --overlay=path-to-overlay/overlay.squashfs \
29      $SIFJAX \
30      bash -c '$WITH_CONDA; python distributed_pmap_mpi.py'

  • On line 3 replace project_465000000 with your own project account number.

  • On line 6, adjust the number of nodes as needed. Keep 8 tasks and 8 GPUs per node (lines 7-8), as entire nodes are allocated to your job on the standard-g partition.

  • On lines 10-12, we purge any previously loaded modules and load the LUMI stack and the JAX container module.

  • On line 15, we specify the CPU binding for each task; for more information, see the LUMI documentation.

  • On lines 22-30, we run the job using srun, specifying the number of nodes, tasks, and GPUs, as well as the overlay file and the JAX container image to use. Finally, we execute the Python script distributed_pmap_mpi.py after activating the container’s Conda environment via $WITH_CONDA. Adjust path-to-overlay to point to the directory containing your overlay.squashfs.

Submitting Your Job

Once you have your Slurm job script configured, you can submit your job by executing the following command in the terminal:

sbatch run_jax_job.sh
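After submission, you can check the job’s status and, since the job script above does not set --output, follow Slurm’s default output file (a sketch; <jobid> is the ID printed by sbatch):

squeue -u $USER                 # is the job pending or running?
tail -f slurm-<jobid>.out       # follow the job output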