---
orphan: true
---

# Jax on LUMI-G

The instructions that follow describe the installation of Jax on LUMI-G using an overlay and EasyBuild. The procedure is based on the following resources:

- An EasyBuild configuration file for Jax:

  ```
  wget https://462000008.lumidata.eu/easyconfigs/jax-0.4.13-rocm-5.6.1-python-3.10-singularity-20231130.eb
  ```

- A script for creating an overlay for the Jax container:

  ```
  wget https://462000008.lumidata.eu/userscripts/create_overlay.sh
  ```

## Preparation: Set the Location for Your EasyBuild Installation

Before you start installing software with EasyBuild on LUMI-G, it is essential to define where your installations should go. The procedure below details the necessary steps. If you have already installed software with EasyBuild on LUMI-G, you can skip this section.

### Step 1: Define the EasyBuild Installation Directory

By default, EasyBuild installs software in the `$HOME/EasyBuild` directory. To avoid consuming the limited storage space in your home directory, it is recommended to install software within your project directory. This not only helps manage storage quotas but also allows you to share installations with your project team.

Set the `EBU_USER_PREFIX` environment variable to point to your chosen directory within the `/project/project_*` or `/project/project_*/username` paths. Add the following to your `.bashrc` file:

```
export EBU_USER_PREFIX=/project/project_465000000/EasyBuild
```

or, for a user-specific installation:

```
export EBU_USER_PREFIX=/project/project_465000000/username/EasyBuild
```

Ensure you replace `465000000` with your actual project number and `username` with your real username. By setting the `EBU_USER_PREFIX` variable to a directory within your project space, you alleviate space constraints in your home directory and promote resource sharing among your project team members.

### Step 2: Load the LUMI Software Stack

Loading the LUMI software stack ensures you are working with a consistent software environment. For example, to load the LUMI/23.09 stack, use the following command:

```{code-block} bash
module load LUMI/23.09
```

This command should automatically load the appropriate partition module.

### Step 3: Load the EasyBuild Module

Now load EasyBuild with the following module command, which will confirm the installation directory:

```{code-block} bash
module load EasyBuild-user
```

To verify EasyBuild's configuration:

```
eb --show-config
```

**Tip for users with multiple projects**: Do not change `EBU_USER_PREFIX` while a LUMI module is loaded. For more details and advanced configurations, refer to the [LUMI documentation](https://docs.lumi-supercomputer.eu/software/installing/easybuild/#preparation-set-the-location-for-your-easybuild-installation).

## Installing Jax in Your Project Environment Using EasyBuild

On LUMI-G, Singularity images with Jax are located in the folder:

```bash
/appl/local/containers/easybuild-sif-images/
```

Find a suitable image. At the time of writing, the image

```bash
lumi-jax-rocm-5.6.1-python-3.10-jax-0.4.13-dockerhash-1e625e0cfb23.sif
```

is available.
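To see which Jax images are currently provided, you can simply list the directory; the `grep` filter below is only an illustration:

```bash
# List the Jax container images available on LUMI-G
ls /appl/local/containers/easybuild-sif-images/ | grep -i jax
```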
Create an EasyBuild configuration file, e.g. `jax-0.4.13-rocm-5.6.1-python-3.10-singularity.eb`, with the following content:

```{code-block} python
---
linenos:
emphasize-lines: 7
---
easyblock = 'MakeCp'

name = 'jax'
version = '0.4.13'
versionsuffix = '-rocm-5.6.1-python-3.10-singularity-20231130'

local_sif = 'lumi-jax-rocm-5.6.1-python-3.10-jax-0.4.13-dockerhash-1e625e0cfb23.sif'

homepage = 'https://jax.readthedocs.io/'

whatis = [
    'Description: JAX is Autograd and XLA, brought together for high-performance numerical computing.'
]

description = """
This module provides a container with jax %(version)s.

The module defines a number of environment variables:
* SIF and SIFJAX: The full path and name of the Singularity SIF file
  to use with singularity exec etc.
* SINGULARITY_BINDPATH: Mounts the necessary directories from the system,
  including /users, /project, /scratch, and /flash so that you should be
  able to use your regular directories in the container.
* RUNSCRIPTS and RUNSCRIPTSJAX: The directory with some sample runscripts.

Note that this container uses a Conda environment internally. When in the
container, the command to activate the container is contained in the
environment variable WITH_CONDA.
"""

toolchain = SYSTEM

sources = [
    {
        'filename':    local_sif,
        'extract_cmd': '/bin/cp %s .'
    },
]

skipsteps = ['build']

files_to_copy = [
    ([local_sif], '.'),
]

local_runscript_python_simple = """
#!/bin/bash -e

# Start conda environment inside the container
\$WITH_CONDA

# Run application
python "\$@"
"""

postinstallcmds = [
    'mkdir -p %(installdir)s/runscripts',
    f'cat >%(installdir)s/runscripts/conda-python-simple <<EOF {local_runscript_python_simple}EOF',
    'chmod a+x %(installdir)s/runscripts/conda-python-simple',
]
```

The listing above is abbreviated: the complete easyconfig, which also defines the `SIF`, `SIFJAX`, `SINGULARITY_BINDPATH`, `RUNSCRIPTS`, and `RUNSCRIPTSJAX` environment variables described in the `description` field, can be downloaded with the `wget` command given at the top of this page.
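Once the easyconfig file is in place, the module can be built and loaded. The exact workflow depends on your EasyBuild configuration; the sketch below assumes that copying the SIF image into the directory from which you run `eb` is enough for EasyBuild to find it as a source file — adjust this step to your configured source path if necessary:

```bash
# Make the SIF image available to EasyBuild as a source file (assumption: eb
# searches the current directory; adapt to your sourcepath configuration)
cp /appl/local/containers/easybuild-sif-images/lumi-jax-rocm-5.6.1-python-3.10-jax-0.4.13-dockerhash-1e625e0cfb23.sif .

# Build the module from the easyconfig
eb jax-0.4.13-rocm-5.6.1-python-3.10-singularity.eb

# Load the resulting module (the name follows from name, version and versionsuffix)
module load jax/0.4.13-rocm-5.6.1-python-3.10-singularity-20231130
```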
## Creating an Overlay for the Jax Container

The Jax container can be extended with additional Python packages by building a SquashFS overlay. The `create_overlay.sh` script referenced at the top of this page installs `mpi4py`, `mpi4jax`, `PyTorch`, `torchvision`, and `torchaudio` on top of the container and packs them into an `overlay.squashfs` file.

1. Download the overlay creation script:

   ```bash
   wget https://462000008.lumidata.eu/userscripts/create_overlay.sh
   ```

2. Review the script and adjust the package versions if needed. Its main part is shown below; the helper output functions (`print_step_start`, `print_ok`, `print_fail`, `print_instruction`, `print_keyvalue`, and `handle_error`) and the redirection of all output to `createoverlay.log` are omitted from this excerpt — refer to the downloaded script for the complete version.

   ```bash
   set -xeuo pipefail
   trap 'handle_error' ERR

   MPI4PI_VERSION=3.1.5
   MPI4JAX_VERSION=0.3.14.post7
   PYTORCH_VERSION=2.1.2
   TORCHVISION_VERSION=0.16.2
   TORCHAUDIO_VERSION=2.1.2

   echo -e "\n Will create an overlay with packages:\n"
   print_keyvalue "mpi4py" $MPI4PI_VERSION
   print_keyvalue "mpi4jax" $MPI4JAX_VERSION
   print_keyvalue "pytorch" $PYTORCH_VERSION
   print_keyvalue "torchvision" $TORCHVISION_VERSION
   print_keyvalue "torchaudio" $TORCHAUDIO_VERSION
   echo

   print_step_start "Checking if the jax module is loaded"
   if [[ ! ":$LOADEDMODULES:" == *":jax/0.4.13-rocm-5.6.1-python-3.10-singularity-20231130:"* ]]
   then
       print_fail
       print_instruction "Please install and/or load the jax container module"
   else
       print_ok
   fi

   print_step_start "Checking if the \$SIF variable is set"
   if [[ -z ${SIF+x} ]]
   then
       print_fail
       print_instruction "The jax module is loaded but the \$SIF environment variable is not set"
       exit
   else
       print_ok
       echo
   fi

   print_keyvalue "base container" $(basename $SIF)
   echo

   STARTDIR=$PWD

   print_step_start "Creating build directory"
   rm -rf $XDG_RUNTIME_DIR/overlay_tmp
   mkdir $XDG_RUNTIME_DIR/overlay_tmp
   cd $XDG_RUNTIME_DIR/overlay_tmp
   print_ok

   print_step_start "Downloading source"
   curl -LO https://github.com/mpi4jax/mpi4jax/archive/refs/tags/v$MPI4JAX_VERSION.tar.gz
   curl -LO https://github.com/mpi4py/mpi4py/releases/download/$MPI4PI_VERSION/mpi4py-$MPI4PI_VERSION.tar.gz
   print_ok

   print_step_start "Writing installation script"
   cat << EOF > install.sh
   \$WITH_CONDA
   set -xeuo pipefail

   export TMPDIR=\$PWD
   export INSTALLDIR=\$PWD/staging/\$(dirname \$(dirname \$(which python)) | cut -c2-)
   export PYMAJMIN=\$(python -c "import sys; print(f'{sys.version_info[0]}.{sys.version_info[1]}')")
   export PYTHONPATH=\$INSTALLDIR/lib/python\$PYMAJMIN/site-packages

   tar xf mpi4py-$MPI4PI_VERSION.tar.gz
   cd mpi4py-$MPI4PI_VERSION
   cp /opt/cray/pe/python/3.10.10/lib/python3.10/site-packages/mpi4py/mpi.cfg .
   pip install --no-cache-dir --prefix=\$INSTALLDIR --no-build-isolation .
   cd ..

   tar xf v$MPI4JAX_VERSION.tar.gz
   cd mpi4jax-$MPI4JAX_VERSION/
   pip install --no-cache-dir --prefix=\$INSTALLDIR --no-build-isolation .
   cd ..

   pip3 install --no-cache-dir --prefix=\$INSTALLDIR \\
       --index-url https://download.pytorch.org/whl/rocm5.6 \\
       torch==$PYTORCH_VERSION+rocm5.6 torchvision==$TORCHVISION_VERSION+rocm5.6 torchaudio==$TORCHAUDIO_VERSION+rocm5.6

   exit
   EOF
   chmod u+x ./install.sh
   print_ok

   print_step_start "Running installation script"
   singularity exec -B$PWD $SIF ./install.sh 1>&2
   print_ok

   print_step_start "Creating SquashFS image"
   chmod -R 777 staging/
   mksquashfs staging/ overlay.squashfs -no-xattrs -processors 16 1>&2
   rm -rf staging
   print_ok

   print_step_start "Testing"
   for package in mpi4py mpi4jax torch
   do
       singularity exec --overlay=overlay.squashfs $SIF /runscripts/conda-python-simple -c "import $package"
   done
   print_ok

   print_step_start "Cleaning up"
   cp overlay.squashfs $STARTDIR/
   cd $STARTDIR
   rm -rf $XDG_RUNTIME_DIR/overlay_tmp
   print_ok

   echo -e "\n$(tput setaf 4)Overlay written to$(tput sgr0) $STARTDIR/overlay.squashfs\n"

   print_instruction "To run with the overlay, use the following commands:"
   cat << EOF
   module load LUMI/23.09
   module load jax/0.4.13-rocm-5.6.1-python-3.10-singularity-20231130
   singularity exec --overlay=$STARTDIR/overlay.squashfs \$SIF /runscripts/conda-python-simple <your_script.py>
   EOF
   ```

3. Make the script executable by running `chmod +x create_overlay.sh`.

4. Execute the script by running `./create_overlay.sh`. The script runs several checks to ensure you have loaded the appropriate Jax environment and that the `SIF` (Singularity Image File) environment variable is set. It then downloads the required packages, runs an installation script inside the Singularity container, and creates a `SquashFS` image containing all installed packages.

5. The generated `overlay.squashfs` file is placed in the starting directory. You can then use the overlay with the following commands:

   ```{code-block} bash
   module load LUMI/$LUMI_STACK_VERSION
   module load jax/0.4.13-rocm-5.6.1-python-3.10-singularity-20231130
   singularity exec --overlay=$STARTDIR/overlay.squashfs $SIF /runscripts/conda-python-simple <your_script.py>
   ```

   Replace `<your_script.py>` with the path to your Python script, `$LUMI_STACK_VERSION` with the appropriate version number for your LUMI installation, and `$STARTDIR` with the directory in which you ran `create_overlay.sh`.

6. After creating the overlay, you can use it to execute your Python scripts in an environment with the installed packages. The overlay contains everything needed to run `mpi4py`, `mpi4jax`, `PyTorch`, `torchvision`, and `torchaudio` with the Jax container.

7. To ensure everything is working correctly, the script includes a testing step in which it imports the installed packages using Python within the Singularity container.

8. Once the overlay has been tested and confirmed to work, the script performs cleanup, removing the temporary build directory and leaving you with just the `overlay.squashfs` file.

9. It is advisable to document the versions of the packages included in your overlay and the Jax container version you are using. This helps maintain reproducibility and eases future updates or modifications to your computational environment.

10. The overlay creation script logs all its activity to a file named `createoverlay.log`. If you encounter any issues while creating the overlay, this log file will help you diagnose what went wrong.

In summary, the `create_overlay.sh` script lets you easily augment the computational environment provided by the Jax container with additional packages needed for your work. This approach provides a flexible and reproducible way to manage your software dependencies on LUMI-G.
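As an additional manual check, mirroring the import test the script itself performs, you can verify that the packages in the overlay import correctly; the overlay path below is a placeholder to adapt:

```bash
module load LUMI/23.09
module load jax/0.4.13-rocm-5.6.1-python-3.10-singularity-20231130

# Adjust the overlay path to wherever create_overlay.sh wrote overlay.squashfs
singularity exec --overlay=/path/to/overlay.squashfs $SIF \
    /runscripts/conda-python-simple -c "import mpi4py, mpi4jax, torch; print('overlay OK')"
```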
This concludes the guide on how to create an overlay for the Jax container on LUMI-G. Remember to adjust the package versions in the script as necessary to fit the requirements of your specific project or workflow.

## Running JAX Jobs on LUMI-G with Slurm

In the following, we show how to run a JAX job on LUMI-G using Slurm. The example Python script below demonstrates a distributed JAX computation using `pmap` across multiple GPUs on multiple nodes.

The script performs a simple convolution operation on an array in a distributed manner. It demonstrates process index assignment, input array preparation, and the distributed execution of a JAX operation using `pmap`. At the end of the script, the results are gathered on the root process and printed using MPI's gather function.

### Example Python Script for Distributed JAX Computation

```{code-block} python
---
linenos:
emphasize-lines: 10, 19-20, 25, 32, 43, 57
---
import jax
import numpy as np
import jax.numpy as jnp
from mpi4py import MPI

# Initialize MPI
comm = MPI.COMM_WORLD
rank = comm.Get_rank()

jax.distributed.initialize()

def convolve(x, w):
    output = []
    for i in range(1, len(x) - 1):
        output.append(jnp.dot(x[i-1:i+2], w))
    return jnp.array(output)

# Determine the process index and the number of processes
process_index = jax.process_index()
n_devices = jax.process_count()

# Create the input arrays based on the process index.
# Each process gets its own unique slice of data: if process_index is 0,
# it selects the first row; if 1, the second row, and so on.
xs = np.arange(5 * n_devices).reshape(n_devices, 5)[process_index:process_index+1]
ws = np.array([2., 3., 4.])  # Shared weights array

# Ensure ws has an extra dimension to match the shape of xs
ws = ws.reshape(1, -1)

# Now we can apply pmap with the corrected data shape
distributed_convolve = jax.pmap(convolve, axis_name='p')

# Using the convolve operation in a distributed context
print(f"----- xs (process_index={process_index}) -----")
print(xs)
print("----- ws -----")
print(ws)
print("----- distributed_convolve(xs, ws) -----")
print(distributed_convolve(xs, ws))

# Each process executes its portion of the distributed convolve operation
local_output = distributed_convolve(xs, ws)

# Convert the local JAX array to a NumPy array on each process
local_output_np = np.array(local_output)

# Gather all the results on the root process
if rank == 0:
    # Prepare a container to hold the received data from all processes:
    # one row per process, with the shape of each local output
    gathered_outputs = np.empty([n_devices, *local_output_np.shape[1:]], dtype=local_output_np.dtype)
else:
    gathered_outputs = None

# Use MPI's gather function to collect all arrays on the root process
comm.Gather(local_output_np, gathered_outputs, root=0)

# Now the root process can print the combined array
if rank == 0:
    print("Gathered outputs on root process:")
    print(gathered_outputs)
```
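Before running the full multi-node job, it can be useful to check that JAX inside the container actually sees the GPUs on a node. The single-node test below is only an illustrative sketch; the account number, time limit, and module version are placeholders taken from the examples on this page and should be adapted to your setup:

```bash
module load LUMI/23.09
module load jax/0.4.13-rocm-5.6.1-python-3.10-singularity-20231130

# Request a single task on one GPU node and print the devices JAX detects
srun --account=project_465000000 --partition=standard-g \
     --nodes=1 --ntasks=1 --gpus-per-node=8 --time=00:05:00 \
     singularity exec $SIFJAX \
     /runscripts/conda-python-simple -c "import jax; print(jax.devices())"
```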
### Preparing Your Slurm Job Script

Below is an example Slurm job script that demonstrates how to execute a distributed JAX computation using `pmap` with MPI. Save the script as `run_jax_job.sh` and adapt the parameters to the needs of your job, such as `--account`, which should specify your own project account.

```{code-block} bash
---
linenos:
emphasize-lines: 3, 6-7, 10-12, 15, 22-30
---
#!/bin/bash -e
#SBATCH --job-name=distributed_jax_example
#SBATCH --account=project_465000000
#SBATCH --time=00:04:00
#SBATCH --partition=standard-g
#SBATCH --nodes=4
#SBATCH --ntasks-per-node=8
#SBATCH --gpus-per-node=8

module purge
module load LUMI/23.09
module load jax

# Specify CPU binding for each task to improve performance
CPU_BIND="map_cpu:49,57,17,25,1,9,33,41"

n_nodes=$SLURM_NNODES                      # Number of nodes allocated
n_gpus=$((n_nodes*$SLURM_GPUS_PER_NODE))   # Total number of GPUs
n_tasks=$n_gpus                            # Number of tasks equals the number of GPUs

# Run the job using srun
srun \
  -N $n_nodes \
  -n $n_gpus \
  --cpu-bind=${CPU_BIND} \
  --gpus-per-node=$SLURM_GPUS_PER_NODE \
  singularity exec \
  --overlay=path-to-overlay/overlay.squashfs \
  $SIFJAX \
  bash -c '$WITH_CONDA; python distributed_pmap_mpi.py'
```

- On line 3, replace `project_465000000` with your own project account number.
- On lines 6-7, adjust the number of nodes. Keep 8 GPUs per node, as the entire node is allocated to your job on the `standard-g` partition.
- On lines 10-12, we purge any previously loaded modules and load the LUMI and JAX modules.
- On line 15, we specify the CPU binding for each task; for more information, see the [LUMI documentation](https://docs.lumi-supercomputer.eu/runjobs/scheduled-jobs/distribution-binding/#gpu-binding).
- On lines 22-30, we run the job with `srun`, specifying the number of nodes, tasks, and GPUs, as well as the overlay file and the JAX container image to use. Replace `path-to-overlay` with the directory containing your `overlay.squashfs`. Inside the container, the Conda environment is activated via the `WITH_CONDA` environment variable before the Python script `distributed_pmap_mpi.py` is executed.

### Submitting Your Job

Once your Slurm job script is configured, submit the job by executing the following command in the terminal:

```{code-block} bash
sbatch run_jax_job.sh
```
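After submission, you can check the status of the job and, once it has finished, inspect its output; for example:

```bash
# Show your pending and running jobs
squeue -u $USER

# Inspect the job output once it has completed
# (by default Slurm writes it to slurm-<jobid>.out in the submission directory)
cat slurm-<jobid>.out
```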