国产精品电影_久久视频免费_欧美日韩国产激情_成年人视频免费在线播放_日本久久亚洲电影_久久都是精品_66av99_九色精品美女在线_蜜臀a∨国产成人精品_冲田杏梨av在线_欧美精品在线一区二区三区_麻豆mv在线看

微調大模型,AMD MI300X就夠了!跟著這篇博客微調Llama 3.1 405B,效果媲美H100

人工智能 新聞
為了優化訓練,在微調 LLaMA 405B 模型,只計算 LoRA 參數的梯度,保持主模型參數不變。

隨著 AI 模型的參數量越來越大,對算力的需求也水漲船高。

比如最近,Llama-3.1 登上了最強開源大模型的寶座,但超大杯 405B 版本的內存就高達 900 多 GB,這對算力構成了更加苛刻的挑戰。

如何降低算力的使用成本和使用門檻,已經成為許多公司尋求突破的關鍵。Felafax 就是其中的一家創業公司,致力于簡化 AI 訓練集群的搭建流程。

Nikhil Sonti 和 Nikhin Sonti 創立了 Felafax,他們的口號是在構建開源 AI 平臺,為下一代 AI 硬件服務,將機器學習的訓練成本降低 30%。

與英偉達相比,AMD 的 GPU,尤其是 MI300X 系列,提供了更高的性價比,按每美元計算,其性能表現更為出色。

最近,Felafax 的聯合創始人 Nikhil Sonti 發布了一篇博客,詳細分享了如何通過 8 張 AMD MI300X GPU 和 JAX 微調 LLaMA 3.1 405B 模型的方法,所有代碼現已開源。

圖片

Github 鏈接:https://github.com/felafax/felafax

機器之心對博客內容進行了不改變原意的編譯、整理,以下是博客內容:

JAX 尤其適合非英偉達硬件

JAX 是一個強大的機器學習庫,結合了類似 NumPy 的 API、自動微分功能以及 Google 的 XLA 編譯器。它在模型并行化方面提供了優秀的 API,因此非常適合像 LLaMA 3.1 405B 這樣的超大模型訓練。

在使用 AMD 硬件時,JAX 有幾個明顯的優勢:

  • 多硬件并行支持:JAX 采用 XLA(加速線性代數)編譯器,將計算編譯為硬件無關的中間表示(HLO),這意味著同樣的 JAX 代碼無需修改便可高效運行在不同硬件后端,包括 AMD GPU。
  • 獨立于底層硬件:XLA 編譯器的優化策略是通用的,不針對某個特定的硬件平臺。這使得任何支持 XLA 的硬件設備(如 CPU、GPU、TPU)都能受益于這些優化,獲得更好的性能表現。
  • 極高的適應性:從 NVIDIA 轉移到 AMD(或其他硬件)時,JAX 只需做極少的代碼改動。而相較之下,PyTorch 與英偉達的 CUDA 生態系統緊密耦合,遷移過程相對復雜。

因此,JAX 成為了我們在非英偉達硬件上的最佳選擇。

拉取 Docker 鏡像:

docker pull rocm/jax:latest

啟動 Docker 容器:

# Pull the Docker Image:
docker pull rocm/jax:latest 


# Start the Docker Container:
docker run -it -w /workspace --device=/dev/kfd --device=/dev/dri --group-add video \ 
--cap-add=SYS_PTRACE --security-opt seccomp=unconfined --shm-size 16G rocm/jax:latest


# Verify the Installation: 
python3 -c 'import jax; print(jax.devices())'

驗證安裝

python3 -c 'import jax; print (jax.devices ())'

訓練使用了一個配備了 8 張 AMD MI300x GPU 的 AMD 節點。每張 MI300x 擁有 192GB 的 HBM3 內存,性能表現與最新的英偉達 H100 GPU 相比非常出色。

圖片

與英偉達 H100 的比較,來源:TensorWave

訓練 LLaMA 405B:性能與可擴展性

使用 JAX,可以成功地在 AMD GPU 上訓練 LLaMA 405B 模型。我們使用 LoRA 微調,將所有模型權重和 LoRA 參數都設為 bfloat16,LoRA rank 設為 8,LoRA alpha 設為 16:

  • 模型大小:LLaMA 模型的權重占用了約 800GB 的顯存。
  • LoRA 權重 + 優化器狀態:大約占用了 400GB 的顯存。
  • 顯存總使用量:占總顯存的 77%,約 1200GB。
  • 限制:由于 405B 模型的規模過大,batch 大小和序列長度的空間有限,使用的 batch size 為 16,序列長度為 64。
  • JIT 編譯:由于空間限制,無法運行 JIT 編譯版本;它可能需要比急切模式稍多的空間。
  • 訓練速度:使用 JAX 急切模式,約為 35 tokens / 秒。
  • 內存效率:穩定在約 70% 左右。
  • 擴展性:在 8 張 GPU 上,使用 JAX 的擴展性接近線性。

由于硬件和顯存的限制,我們無法運行 JIT 編譯版本的 405B 模型,整個訓練過程是在 JAX 的急切模式下執行的,因此還有很大的進步空間。 

下圖中顯示了在一次微調訓練步驟中,8 張 GPU 的顯存利用率和 rocm-smi 輸出:

GPU 利用率:

圖片

顯存利用率:

圖片

rocm-smi 輸出:

圖片

訓練設置 

將 LLaMA 3.1 從 PyTorch 移植到 JAX 

圖片

此前,Nikhil Sonti 分享過如何將 LLaMA 3.1 從 PyTorch 移植到 JAX。他指出,目前 90% 的大型語言模型(LLM)都運行在 NVIDIA GPU 上,但實際上還有一些同樣強大且性價比更高的替代方案。例如,在 Google TPU 上訓練和部署 Llama 3.1 的成本比 NVIDIA GPU 低約 30%。

然而,支持非 NVIDIA 硬件的開發工具較為匱乏。Sonti 最初嘗試使用 PyTorch XLA 在 TPU 上訓練 Llama 3.1,但過程并不順利。XLA 與 PyTorch 的集成不夠完善,缺少一些關鍵的庫(如 bitsandbytes 無法正常運行),同時還遇到了一些難以解決的 HuggingFace 錯誤。

為此,他決定調整策略,將 Llama 3.1 從 PyTorch 移植到 JAX,成功解決了這些問題。Sonti 還錄制了詳細的教程視頻,并開源了所有代碼:

圖片

  • 方法演示:https://dub.sh/felafax-demo
  • 代碼倉庫:https://github.com/felafax/felafax

加載模型,并把模型參數分片

處理像 LLaMA 405B 這樣的超大模型,需要在多個設備之間高效地進行參數分片。以下是如何通過 JAX 實現這一點的。

在 JAX 中進行參數分片

為了將巨大的 LLaMA 405B 模型高效地分布到 8 張 AMD GPU 上,需要使用 JAX 的設備網格(device mesh)功能。

部署代碼:https://github.com/felafax/felafax/blob/e2a96a0e207e1dc70effde099fe33a9e42a7d5cb/llama3_jax/trainer_engine/jax_utils.py#L69

JAX 的設備網格可以幫助我們把可用的設備組織成一個網格,讓我們可以指定如何把模型的參數和計算分配到不同的 GPU 上。

在本文的設置中,需要創建一個形狀為(1, 8, 1)的網格,并將軸分別命名為數據并行(dp)、全分片數據并行(fsdp)和模型并行(mp)。然后,為模型的每個張量定義特定的分片規則,指定這些維度如何沿著這些網格軸進行分片。

DEVICES = jax.devices () 
DEVICE_COUNT = len (DEVICES) 
DEVICE_MESH = mesh_utils.create_device_mesh ((1, 8, 1)) 
MESH = Mesh (devices=DEVICE_MESH, axis_names=("dp", "fsdp", "mp"))

可視化分片

可以使用以下代碼來可視化分片結果,從而方便地驗證分片規則是否按預期應用。

jax.debug.visualize_array_sharding

分片規則

模型不同組件的分片規則如下所示:

  • 參數如何分片:

參數要在 8 個 GPU 之間分配。例如,LM head(lm_head/kernel)張量有兩個軸,按照 PS ("fsdp", "mp") 進行分片。在本例中是 8 和 1,因此可以看到該張量在第一個軸上沿著 8 個 GPU 被拆分。

  • Non-Replicated 參數:

沒有任何分片規范的參數會在所有設備上進行復制。例如,層歸一化(attention_norm/kernel 和 ffn_norm/kernel)沒有設置分片規范,是 PS (None)。

應用分片函數

在加載模型時,使用以下分片函數逐步對模型權重進行分片:

def make_shard_and_gather_fns (partition_specs):
    def make_shard_fn (partition_spec):
        out_sharding = NamedSharding (mesh, partition_spec)
        def shard_fn (tensor):
            return jax.device_put (tensor, out_sharding).block_until_ready ()
        return shard_fn

    shard_fns = jax.tree_util.tree_map (make_shard_fn, partition_specs)
    return shard_fns

# Create shard functions based on partitioning rules
shard_fns = make_shard_and_gather_fns (partitioning_rules)

這使得我們能夠將每個參數放置在指定的設備上,并按照設定的分片進行處理。

分片訓練 Batch

最初,訓練 Batch 是正常創建的,但在輸入模型之前,需要按照下面的代碼在 GPU 上進行分片:

train_batch = jax.device_put ( train_batch, 
NamedSharding (self.mesh, PS ("dp", "fsdp")))

在這里,我們指定訓練 Batch 應該在 "dp" 和 "fsdp" 軸上進行分片,在本例中分別對應于被分成 1 和 8 份,如果把結果可視化出來,如下所示:

分片前:

圖片

在調用  jax.device_put 之后:

圖片

加入 LoRA

LoRA 通過將權重更新分解為低秩矩陣,減少了可訓練參數的數量,這對于微調大型模型特別有效。以下是在 AMD GPU 上微調 Llama 3.1-405 的 LoRA 的要點:

  • 將 LoRA 參數(lora_a 和 lora_b)與主模型參數分開。
  • 使用 jax.lax.stop_gradient (kernel) 來防止對主模型權重的更新。
  • 使用 lax.dot_general 進行快速、精確控制的矩陣運算。
  • LoRA 輸出在添加到主輸出之前會被縮放為 (self.lora_alpha/self.lora_rank)。

LoRADense 層

在此設定一個自定義的 LoRADense 層,該層集成了 LoRA 參數:

class LoRADense (nn.Module):
    features: int
    lora_rank: int = 8
    lora_alpha: float = 16.0
@nn.compact
def __call__(self, inputs: Any) -> Any:
# Original kernel parameter (frozen)
        kernel = self.param ('kernel', ...)
        y = lax.dot_general (inputs, jax.lax.stop_gradient (kernel), ...)
# LoRA parameters (trainable)
        lora_a = self.variable ('lora_params', 'lora_a', ..., ...)
        lora_b = self.variable ('lora_params', 'lora_b', ..., ...)
# Compute LoRA output
        lora_output = lax.dot_general (inputs, lora_a.value, ...)
        lora_output = lax.dot_general (lora_output, lora_b.value, ...)
# Combine original output with LoRA modifications
        y += (self.lora_alpha/self.lora_rank) * lora_output




        return y.astype (self.dtype)

分片 LoRA 參數

為了高效地在設備之間分配 LoRA 參數,我們也通過 JAX 設定了分片規則,這確保了 LoRA 參數與主模型參數的分片一致,優化了內存使用和計算效率。

LoRA A matrices (lora_a)

LoRA A 矩陣(lora_a)

  • 分片規則:PS ("fsdp", "mp")
  • 可視化結果:如下圖所示,lora_a 參數被分片為 (8, 1),這意味著第一個軸在 8 個設備上進行分片("fsdp" 軸),而第二個軸未進行分片。

圖片

LoRA B 矩陣(lora_b)

  • 分片規則:PS ("mp", "fsdp")
  • 可視化結果:如下圖所示,lora_b 參數被分片為 (1, 8),這意味著第二個軸在 8 個設備上進行分片(fsdp 軸),而第一個軸未進行分片。

圖片

這種分片策略優化了參數的分配,減少了通信開銷,并在訓練過程中增強了并行性。它確保每個設備僅持有一部分 LoRA 參數,使得大模型如 LLaMA 405B 的高效擴展成為可能。

僅更新 LoRA 參數 

為了優化訓練,在微調 LLaMA 405B 模型,只計算 LoRA 參數的梯度,保持主模型參數不變。這個方法減少了內存使用,并加速了訓練,因為只更新較少的參數。可以移步 GitHub 倉庫,查看實現細節。

在訓練過程中,每一步都涉及將一批輸入數據通過模型進行處理。由于只有 LoRA 參數是可訓練的,因此模型的預測和計算的損失僅依賴于這些參數,然后對 LoRA 參數進行反向傳播。只更新這些參數簡化了訓練過程,使得在多個 GPU 上高效微調像 LLaMA 405B 這樣的大型模型成為可能。

更多研究細節,請參考原博客。

責任編輯:張燕妮 來源: 機器之心
相關推薦

2024-07-24 13:58:25

2024-08-16 14:00:00

2024-08-02 14:53:00

2024-07-24 13:18:17

2023-06-07 08:22:59

LLM微調技術

2024-04-15 12:50:00

大型語言模型ReFT

2023-06-14 12:08:51

2024-12-25 13:33:18

2023-08-13 07:44:18

GPU模型英偉達

2024-07-23 09:20:35

2023-06-28 21:47:54

2024-09-09 07:46:16

2025-04-10 07:59:51

2024-09-06 13:00:29

2023-10-20 17:53:05

2024-04-29 06:46:50

2024-12-30 00:01:00

多模態大模型Python

2024-07-29 13:38:06

2024-07-24 09:20:45

點贊
收藏

51CTO技術棧公眾號

中文在线播放一区二区| 91丝袜呻吟高潮美腿白嫩在线观看| 国产资源第一页| 亚洲欧洲制服丝袜| 果冻天美麻豆一区二区国产| 宅男噜噜99国产精品观看免费| 日本高清成人免费播放| 日韩精品诱惑一区?区三区| 亚洲欧美国产中文| 九九热这里只有精品免费看| 国产精品一区二区久久不卡| 伊人影院在线视频| 国产日韩欧美中文| 在线精品一区二区| 成人在线激情视频| 99精品国产在热久久下载| 91文字幕巨乱亚洲香蕉| 九九亚洲精品| 91精品久久久久久久久| 在线一区欧美| 天堂资源在线观看| 日本高清免费不卡视频| 精品毛片久久久久久| 国产精品一区二区在线观看不卡 | 亚洲va欧美va人人爽| 国产成人午夜电影| 久久精品综合网| 少妇高潮一区二区三区| 中文字幕一二三区在线观看| 国产精品丝袜白浆摸在线| 亚洲午夜在线电影| 亚洲激情中文在线| 欧洲精品二区| 免费看又黄又无码的网站| 欧美黄色片免费观看| 精品美女国产在线| 国产一区二区在线视频| 少妇精品久久久| 91免费在线| 人人澡人人爽| 欧美日韩天天操| 久久99精品久久久久久琪琪| 色女孩综合影院| 日本麻豆一区二区三区视频| 在线免费看黄色| 亚洲黄色成人久久久| 91av视频导航| 色婷婷久久久亚洲一区二区三区| 国产精品99久久久久久有的能看| 国产一区二区区别| 在线一区视频观看| 国产秀色在线www免费观看| 欧美乱做爰xxxⅹ久久久| 国产成人av一区二区三区| 欧美极品第一页| 欧美一区二区视频网站| 国产精品久久久久久久久动漫 | 日本三级视频在线播放| 国产 porn| 欧美黄网在线观看| 99国产盗摄| 久久久久久香蕉网| 精品无人区太爽高潮在线播放| 欧美性受xxxx黑人xyx| 中文字幕第一区| av一本久道久久综合久久鬼色| 在线播放精品| 天天影视欧美综合在线观看| japansex久久高清精品| 国内精品久久久久国产| 大胆高清日本a视频| www.69av| 香港三级日本三级a视频| 亚洲精品国产精品国自产观看| 国产精品视频一区二区高潮| 日本伊人精品一区二区三区介绍| 久久国产精品电影| 久久精品国产亚洲| 日韩成人av在线播放| 日韩欧美高清一区| 欧美日韩亚洲综合在线| 亚洲精品中文在线观看| 亚洲欧美另类久久久精品| 成人免费一区二区三区视频 | 成人av国产| 国产精品白丝一区二区三区| 免费观看性欧美大片无片| 日韩欧美久久| 国产欧美自拍| 欧美日韩看看2015永久免费| 香蕉人人精品| 欧美激情影院| 亚洲欧美一区在线| 午夜精品偷拍| 99精品国产99久久久久久福利| 久久天堂成人| 久久精品国产成人一区二区三区| 国产精品456露脸| 国产精品久久久99| 一区二区三区四区在线免费观看| 亚洲一区二区三区精品在线| 欧美丝袜丝nylons| 精品日韩一区二区三区| 日韩中文字幕网| 成人性生交xxxxx网站| 亚洲自拍偷拍色图| 一区二区视频国产| 日本五十路在线| av在线网址观看| 日韩深夜影院| 久久国内精品视频| 国产女人aaa级久久久级| 香港成人在线视频| 一区二区亚洲精品国产| 国产日韩欧美电影在线观看| eeuss中文| 激情综合网五月激情| 菠萝蜜视频国产在线播放| 老司机aⅴ在线精品导航 | 黄色成人在线看| 久久久久久青草| 精品一区二区男人吃奶| 视频一区二区中文字幕| 亚洲国产欧美日韩另类综合| 一区二区三区动漫| 国产综合18久久久久久| 久久久久久久久久久久久久久久久久久| 国产福利在线免费观看| 色琪琪久久se色| a美女胸又www黄视频久久| 日韩你懂的在线观看| 国产精品swag| 日韩porn| 经典一区二区| 国产精品丝袜久久久久久app| 精品爽片免费看久久| 成人免费视频97| 色多多视频在线播放| 羞羞的视频在线看| 久久久9色精品国产一区二区三区| 激情偷乱视频一区二区三区| 日韩一区二区视频| 国产免费一区视频观看免费| 国产成人亚洲综合无码| 蜜桃视频www网站在线观看| 影音先锋国产精品| 天天操天天色综合| 国产欧美日韩中文字幕| 国产成人午夜精品| 欧美国产美女| 日韩av中文字幕一区二区| 99热这里都是精品| 成人精品高清在线| 精品国产鲁一鲁一区二区张丽| 欧美裸体bbwbbwbbw| 91精品国产福利| 韩国日本不卡在线| 精品中文字幕在线观看| 2020国产精品视频| www.av毛片| 黄色片免费在线观看| 盗摄系列偷拍视频精品tp| 久久女同性恋中文字幕| 久久综合伊人77777蜜臀| 黄色一级视频在线播放| 国产亚洲字幕| 亚洲图片欧美视频| 亚洲开心激情网| 亚洲高清视频一区| 成人在线免费看黄| 狠狠入ady亚洲精品| 日韩av在线免费观看一区| 黄网站色视频免费观看 | 国产精品美女www| 日本黄色片在线观看| 国内久久精品视频| 69影院欧美专区视频| 久久精品视频观看| 99re亚洲国产精品| 91av免费看| 成年人在线观看网站| 99精品视频在线观看免费| 日本高清视频一区| caopen在线视频| 国产精品99精品久久免费| 欧日韩在线观看| 欧美日韩在线中文字幕| 久久久天天操| 国产精品国产三级国产专播精品人| 污污视频网站免费观看| 99精品国产在热久久下载| 久久99久久亚洲国产| 91小视频xxxx网站在线| 国产在线视频精品一区| 动漫3d精品一区二区三区| 国产精品玖玖玖在线资源| 日韩精品视频免费| av在线免费播放| 欧美三级一区二区| 深夜福利在线观看直播|