斷點續訓加速
更新時間 2025-09-09 12:30:43
最近更新時間: 2025-09-09 12:30:43
分享文章
萬卡規模國產化集群下,斷點續訓在5類故障下實現1分鐘檢測、5分鐘內定位、15分鐘內恢復訓練。
測試數據及代碼準備
| 數據集 | 數據集大小 | 使用模型 |
|---|---|---|
| Wikipedia-en (1M條) | 9.1GB | Llama2-70B /Llama3.1-405B |
使用預處理為MindRecord格式的Wikipedia-en (1M條)數據集,上傳到對象存儲,并由對象存儲下載到平臺HPFS。
測試代碼在gitlab倉庫下載到本地,并放置于/work/home下。
腳本和任務準備
按照下面修改run.sh腳本
#! /bin/bash
# huijuformers的絕對路徑, 需要修改path_to_huijuformers
export BASE_DIR=/work/data/llama2_9216/huijuformers
## 以下為平臺自動注入的環境變量
# yaml文件中需要修改的環境變量
export BATCH_SIZE=1
export EPOCHS=350
export LEARNING_RATE=6.e-5
export DATA_PARALLEL=256
export MODEL_PARALLEL=4
export PIPELINE_STAGE=9
# 模型微調相關
export FINETUNE_MODEL_TYPE=llama2_70b_base # 合并為一個參數,與模型存放文件夾名稱一致(與后端溝通過)
export FINETUNING_TYPE=ALL
export TIME_TAG=$(date +"%m%d-%H%M")
# # 數據相關
# export DATASET_PATH=${BASE_DIR}/data
# export DATASET_FILE=original_data.json # 需要修改
# ## 以下為平臺后端需要自行更改后傳入的環境變量
# # 平臺數據格式轉換,專用數據調試時用不到
# export DATASET_TMP_PATH=${BASE_DIR}/data/processed_data/${FINETUNE_MODEL_TYPE}
# mkdir -p ${DATASET_TMP_PATH}
# # 模型輸入
# # 專業模式,平臺訓練時需要按照平臺的掛載路徑去修改這一塊的變量
# export CHECKPOINT_DIR=''
# # 低代碼模式,微調時約定掛載為下面的路徑
# # export CHECKPOINT_DIR=/work/mount/publicModel/${FINETUNE_MODEL_TYPE}/${FINETUNE_MODEL_TYPE}
# # 輸出文件夾路徑,run_mode為訓練模式,如train,lora,full
run_mode=train
export OUTPUT_DIR=${BASE_DIR}/output/${FINETUNE_MODEL_TYPE}/${run_mode}/${TIME_TAG}
export OUTPUT_ROOT_DIR=${BASE_DIR}/output/${FINETUNE_MODEL_TYPE}/${run_mode}
rm -rf ${OUTPUT_DIR}/resume_record
# 獲取節點IP、名稱,記錄至文件
echo $(hostname -I | awk '{print $1}'),$NODE_NAME >> ${BASE_DIR}/output/nodes
sed -i '/pam_limits.so/s/^/#/' /etc/pam.d/sshd
# 啟動腳本
cd ${BASE_DIR}/bin/scripts
# apt install netcat -y
# 微調
# bash finetune.sh
# 預訓練
export MS_TOPO_TIMEOUT=7200
bash train.sh啟動訓練任務
點擊訓練任務頁面的新建任務,按照如下的示例配置訓練命令和模型掛載等,然后啟動任務。
Llama2-70B萬卡測試結果
在平臺使用9216卡對Llama2-70B進行預訓練,萬卡規模國產化集群下,斷點續訓在5類故障下實現1分鐘檢測、5分鐘內定位、15分鐘內恢復訓練。
整體結果
訓練環境
| 服務器型號 | Atlas 800T A2 |
|---|---|
| NPU型號 | 910B2(64GB) |
| 驅動版本 | 23.0.3 |
| CANN | 8.0.RC2 |
| Python | 3.10.14 |
| MindSpore | 2.3.0 |
| Mindformers | 1.2.0 |
訓練配置
| Epochs | 350 |
|---|---|
| Learning Rate | 6.e-5 |
| Global Batch Size | 32768 |
| Batch Size | 1 |
| Micro Batch Size | 128 |
| Sequence Length | 4096 |
| Data Parallel (DP) | 256 |
| Model Parallel (MP) | 4 |
| Pipeline Parallel (PP) | 9 |
| max_device_memory | 54GB |
| jit_level | O2 |
訓練結果
| 吞吐量(tokens/s/p) | 366.915 |
|---|---|
| MFU - 芯片算力(%) | 43.061 |
| MFU - CUBE算力(%) | 45.867 |
斷點續訓
斷點CheckPoint總大小:22T,其中0卡斷點CheckPoint大小:2.9G。
故障1:業務故障,kill所有python進程
| 故障檢測時間(Min) | 7.2s |
|---|---|
| 故障處理耗時(Min) | 231.7s, 3.86min |
| 故障恢復耗時(Min) | 458s, 7.63min |
| CKPT加載時間(Min) | 0.28min |
| 0卡CKPT加載速度(GB/s) | 0.99 |
故障2:節點心跳故障,把node上label去掉
| 故障檢測時間(Min) | 18.9s, 0.315min |
|---|---|
| 故障處理耗時(Min) | 279.8s, 4.64min |
| 故障恢復耗時(Min) | 478s, 7.96min |
| 0卡CKPT加載時間(Min) | 0.3min |
| 0卡CKPT加載速度(GB/s) | 1.01 |
故障3:節點down故障,reboot
| 故障檢測時間(Min) | 18.3s, 0.3min |
|---|---|
| 故障處理耗時(Min) | 465s, 7.75min |
| 故障恢復耗時(Min) | 546s, 9.1min |
| 0卡CKPT加載時間(Min) | 0.98min |
| 0卡CKPT加載速度(GB/s) | 0.1 |
故障4:網絡故障,網卡link down
| 故障檢測時間(Min) | 895s(600s HCCL), 14.9min |
|---|---|
| 故障處理耗時(Min) | 300s, 5min |
| 故障恢復耗時(Min) | 472.1s, 7.86min |
| 0卡CKPT加載時間(Min) | 0.32min |
| 0卡CKPT加載速度(GB/s) | 0.96 |
故障5:PCIE故障,模擬NPU掉卡
| 故障檢測時間(Min) | 78.7s, 1.3min |
|---|---|
| 故障處理耗時(Min) | 267.1s, 4.4min |
| 故障恢復耗時(Min) | 516.9s, 8.6min |
| 0卡CKPT加載時間(Min) | 0.32min |
| 0卡CKPT加載速度(GB/s) | 0.98 |