#!/bin/bash
#
# Launch a PyTorch training script with torchrun from inside a Flux job.
# Loads the requested ROCm / cray-mpich modules, activates a Python virtual
# environment, exports the libfabric (CXI), RCCL and MIOpen settings below,
# and finally execs torchrun with one worker per visible GPU.
#
# Launch with:
#   flux run -N x -n x -c96 -g 4 -ompibind=omp_proc_bind,omp_places -onosetpgrp \
#     -ofastload=on --setattr=rdzv_get_en=0 ./run_flux_train.sh \
#     --rocm-version 7.2.1 --mpich-version 9.1.0 \
#     --pytorch-venv /path/to/venv --training-script train.py

#######################################
# Print the command-line synopsis.
# Outputs: usage text on stdout (error paths redirect it to stderr).
#######################################
usage() {
  cat <<EOF
Usage: ${0##*/} --pytorch-venv PATH [options]

Required:
  --pytorch-venv PATH        Python virtual environment to activate

Options:
  --rocm-version VERSION     ROCm module version to load (default: 7.2.1)
  --mpich-version VERSION    cray-mpich module version to load (default: 9.1.0)
  --training-script SCRIPT   Training script passed to torchrun (default: train.py)
  -h, --help                 Show this help and exit
EOF
}

#######################################
# Abort with a usage message when an option is missing its value.
# Arguments: $1 - option name, $2 - the value that followed it (may be empty)
# Returns:   nothing on success; exits 1 when the value is empty
#######################################
require_arg() {
  if [[ -z "$2" ]]; then
    echo "Missing value for argument: $1" >&2
    usage >&2
    exit 1
  fi
}

# Defaults; each may be overridden on the command line.
ROCM_VERSION="7.2.1"
MPICH_VERSION="9.1.0"
PYTORCH_VENV=""
TRAINING_SCRIPT_NAME="train.py"

while [[ $# -gt 0 ]]; do
  case "$1" in
    --rocm-version)
      require_arg "$1" "$2"
      ROCM_VERSION="$2"
      shift 2
      ;;
    --mpich-version)
      require_arg "$1" "$2"
      MPICH_VERSION="$2"
      shift 2
      ;;
    --pytorch-venv)
      require_arg "$1" "$2"
      PYTORCH_VENV="$2"
      shift 2
      ;;
    --training-script)
      require_arg "$1" "$2"
      TRAINING_SCRIPT_NAME="$2"
      shift 2
      ;;
    --help|-h)
      usage
      exit 0
      ;;
    *)
      echo "Unknown argument: $1" >&2
      usage >&2
      exit 1
      ;;
  esac
done

if [[ -z "$PYTORCH_VENV" ]]; then
  echo "Missing required argument: --pytorch-venv" >&2
  usage >&2
  exit 1
fi

export ROCM_VERSION
export MPICH_VERSION
export PYTORCH_VENV
export TRAINING_SCRIPT_NAME

# Module output is discarded, as in the original script.
module load PrgEnv-gnu gcc-native/11.2 \
  "cray-mpich/${MPICH_VERSION}" "rocm/${ROCM_VERSION}" &>/dev/null

# Point to your python virtual environment. Fail early instead of silently
# continuing with whatever python happens to be on PATH.
if [[ ! -f "${PYTORCH_VENV}/bin/activate" ]]; then
  echo "Virtual environment not found: ${PYTORCH_VENV}/bin/activate" >&2
  exit 1
fi
# shellcheck disable=SC1091  # path is only known at run time
source "${PYTORCH_VENV}/bin/activate"

export TENSILE_SOLUTION_SELECTION_METHOD=2 # If using ROCm 7 (For now)
export FI_MR_CACHE_MONITOR=kdreg2

# Rendezvous endpoint for torchrun. Assignment and export are split so the
# pipeline's exit status is not masked by `export`.
MASTER_ADDR=$(flux hostlist local | /bin/hostlist -n 2)
export MASTER_ADDR
export MASTER_PORT=23456

# MIOpen caches live under /var/tmp/$USER.
export MIOPEN_DISABLE_CACHE=0
export MIOPEN_USER_DB_PATH="/var/tmp/${USER}/MIOpen_user_db"
export MIOPEN_CUSTOM_CACHE_DIR="/var/tmp/${USER}/MIOpen_custom_cache"

# Prepend the ROCm LLVM libraries. The ${VAR:+...} form avoids a trailing ':'
# (which the loader treats as the current directory) when the variable is empty.
export LD_LIBRARY_PATH="/opt/rocm-${ROCM_VERSION}/llvm/lib${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}"

# RCCL network plugin and libfabric CXI provider settings.
export NCCL_NET_PLUGIN=librccl-net.so
export FI_CXI_RDZV_PROTO=alt_read
export FI_CXI_RDZV_THRESHOLD=0
export FI_CXI_RDZV_GET_MIN=0
export FI_CXI_RDZV_EAGER_SIZE=0
export FI_CXI_DEFAULT_TX_SIZE=1024
export FI_CXI_DISABLE_HOST_REGISTER=1
export FI_CXI_DEFAULT_CQ_SIZE=131072
export FI_CXI_RX_MATCH_MODE=hybrid
export NCCL_CROSS_NIC=1
export NCCL_SOCKET_IFNAME=hsi0

export PYTORCH_MIOPEN_SUGGEST_NHWC=1
export LOG_RANK=0
export PYTHONUNBUFFERED=1
export OMP_NUM_THREADS=1

# Locate the RCCL plugin build whose directory name contains the first three
# characters of the ROCm version (e.g. "7.2"). Only the first match is used;
# -print -quit requires GNU find.
PLUGIN_PREFIX=$(find /collab/usr/global/tools/rccl/toss_4_x86_64_ib_cray/ \
  -maxdepth 1 -type d -regex ".*rocm-${ROCM_VERSION:0:3}.*" -print -quit)
if [[ -n "$PLUGIN_PREFIX" ]]; then
  export LD_LIBRARY_PATH="${PLUGIN_PREFIX}/install/lib:${LD_LIBRARY_PATH}"
else
  echo "Warning: no RCCL plugin directory found for rocm-${ROCM_VERSION:0:3}" >&2
fi

# Can set generically with
# `NUM_PROC=$(python -c "import torch; print(torch.cuda.device_count())")`
export NUM_PROC=4
export ROCR_VISIBLE_DEVICES=0,1,2,3 # Set appropriate to NUM_PROC setting

# torchrun takes its node count from the Flux job environment; without it the
# command line would be malformed, so stop with a clear message instead.
: "${FLUX_JOB_SIZE:?must be set; launch this script with 'flux run'}"

# TRAINING_SCRIPT_NAME is left unquoted on purpose: this preserves the
# original word splitting, so a value such as "train.py --flag" still expands
# to separate arguments.
# shellcheck disable=SC2086
PYTORCH_ALLOC_CONF="expandable_segments:False" \
  torchrun --nnodes "${FLUX_JOB_SIZE}" --nproc-per-node="${NUM_PROC}" \
  --rdzv_backend c10d --rdzv_endpoint="${MASTER_ADDR}:${MASTER_PORT}" \
  --local-ranks-filter "${LOG_RANK}" --role rank --tee 3 $TRAINING_SCRIPT_NAME