Naukowcy z Apple proponują metodę Cut Cross-Entropy (CCE): Usprawniony sposób obliczania straty krzyżowo-entropijnej w uczeniu maszynowym, bez zapisywania wszystkich logitsów w pamięci globalnej

Postępy w modelach językowych o dużej skali (LLM) zrewolucjonizowały przetwarzanie języka naturalnego, obejmując takie zastosowania jak generowanie tekstów, tłumaczenie oraz streszczanie. Modele te opierają się na olbrzymich zbiorach danych, dużej liczbie parametrów i rozbudowanych słownikach, co wymaga zaawansowanych technik zarządzania wymaganiami obliczeniowymi i pamięciowymi. Kluczowym elementem treningu modeli LLM jest obliczanie funkcji straty opierającej się na entropii krzyżowej. Choć jest to centralny aspekt wpływający na dokładność modelu, wiąże się również z istotnymi wyzwaniami pamięciowymi, zwłaszcza ze względu na rozmiar i złożoność słownika.

Problemy związane z pamięcią w dużych modelach językowych

Wymagania pamięciowe na etapie obliczania entropii krzyżowej stanowią poważne ograniczenie podczas trenowania dużych modeli językowych, szczególnie gdy rozmiar słownika osiąga setki tysięcy tokenów. Problem staje się szczególnie zauważalny w modelach takich jak Gemma 2 (2B), gdzie samo obliczanie entropii krzyżowej może pochłonąć do 24 GB pamięci, co stanowi aż 90% całkowitego zużycia pamięci podczas treningu. Takie ograniczenia wymuszają zmniejszenie rozmiarów partii danych (batch size) i konieczność kompromisów między wydajnością modelu a wykonalnością obliczeniową, co stanowi główną przeszkodę w skalowaniu.

Dotychczasowe podejścia i ich ograniczenia

Dotychczasowe metody mające na celu redukcję zużycia pamięci, takie jak FlashAttention czy hierarchiczne słowniki, koncentrowały się na konkretnych elementach, takich jak mechanizm samouważności (self-attention), lecz nie rozwiązały problemu obciążenia warstwy entropii krzyżowej. Metody podziału na fragmenty (chunking) zmniejszają wymagania pamięciowe, ale wprowadzają opóźnienia, co ogranicza ich praktyczne zastosowanie. Co więcej, te podejścia nie w pełni wykorzystują rzadkość gradientów ani optymalizacji sprzętowych, pozostawiając przestrzeń na dalsze udoskonalenia.

Innowacyjne rozwiązanie: Cut Cross-Entropy (CCE)

Badacze z Apple zaprezentowali nową metodę o nazwie Cut Cross-Entropy (CCE), zaprojektowaną w celu przezwyciężenia problemów pamięciowych związanych z modelami o dużym słowniku. W przeciwieństwie do tradycyjnych metod, które przechowują wszystkie logity dla tokenów w pamięci, CCE dynamicznie oblicza jedynie niezbędne logity i wykonuje redukcje log-sum-exp w pamięci na chipie. Dzięki temu nie ma potrzeby tworzenia dużych macierzy w pamięci GPU, co znacząco redukuje zużycie pamięci. Na przykład, w modelu Gemma 2, zużycie pamięci na obliczenia związane z funkcją straty zmniejszyło się z 24 GB do zaledwie 1 MB, a całkowite zużycie pamięci przez głowicę klasyfikatora spadło z 28 GB do 1 GB.

Zasada działania CCE

Sercem metody CCE jest efektywna strategia obliczeń, która wykorzystuje niestandardowe jądra CUDA do przetwarzania osadzeń i wykonywania redukcji. Dzięki obliczaniu logitów w locie i unikaniu przechowywania pośrednich danych w pamięci, metoda ta korzysta z pamięci współdzielonej GPU, która jest szybsza i bardziej wydajna niż tradycyjna pamięć globalna. Ponadto, filtrowanie gradientów selektywnie pomija obliczenia, które mają niewielki wpływ na gradient, wykorzystując wewnętrzną rzadkość macierzy softmax. Sortowanie słownika optymalizuje przetwarzanie, grupując tokeny o istotnym wkładzie, minimalizując tym samym marnowane obliczenia. Połączenie tych innowacji umożliwia stworzenie mechanizmu obliczania funkcji straty, który jest zarówno wydajny pamięciowo, jak i niskoopóźnieniowy.

Korzyści płynące z CCE

Zyski wydajnościowe wynikające z zastosowania CCE są imponujące. Redukcje pamięci umożliwiły dziesięciokrotne zwiększenie rozmiaru partii danych dla mniejszych modeli, takich jak GPT-2, oraz 1,5-krotne zwiększenie dla większych modeli, takich jak Llama 2 (13B). Przepustowość treningowa pozostała na tym samym poziomie, a wyniki eksperymentalne wykazały stabilną konwergencję, równą tradycyjnym metodom. Dla partii danych składającej się z 8192 tokenów i słownika o wielkości 256 000, CCE osiągnęło szczytowe zużycie pamięci na poziomie zaledwie 1 MB w porównaniu do 28 GB w tradycyjnych metodach. Testy stabilności treningu na modelach takich jak Llama 3 (8B) oraz Phi 3.5 Mini potwierdziły niezawodność CCE, z krzywymi strat nieodróżnialnymi od tych uzyskanych w tradycyjnych podejściach.

Kluczowe wnioski z badań

Badania nad metodą CCE dostarczają kilku istotnych wniosków:

1. Znacząca redukcja pamięci – CCE redukuje zużycie pamięci na obliczenia entropii krzyżowej do zaledwie 1 MB w dużych modelach, takich jak Gemma 2 (2B).
2. Poprawa skalowalności – Metoda pozwala na użycie większych partii danych, co wspiera bardziej efektywne wykorzystanie zasobów obliczeniowych, kluczowe dla trenowania rozbudowanych modeli.
3. Zyski w efektywności – Niestandardowe jądra CUDA i filtrowanie gradientów zapewniają, że redukcja pamięci nie wpływa negatywnie na prędkość treningu ani konwergencję modelu.
4. Zastosowanie w praktyce – Metoda jest elastyczna i może być stosowana w różnych architekturach, z potencjalnymi aplikacjami w klasyfikacji obrazów i uczeniu kontrastywnym.
5. Potencjał na przyszłość – Zdolność CCE do obsługi dużych słowników przy minimalnym wpływie na pamięć może umożliwić trenowanie jeszcze bardziej rozbudowanych modeli w przyszłości, z lepszym zrównoważeniem procesów obliczeniowych.

Podsumowanie

Metoda Cut Cross-Entropy (CCE) stanowi przełomowe rozwiązanie w trenowaniu dużych modeli językowych, eliminując krytyczny problem związany z pamięciochłonnymi warstwami entropii krzyżowej. Poprzez zastosowanie dynamicznego obliczania logitów, filtrowania gradientów i sortowania słownika, CCE oferuje dramatyczne redukcje zużycia pamięci bez poświęcania prędkości czy dokładności. To innowacyjne podejście nie tylko zwiększa efektywność obecnych modeli, ale także otwiera drogę do bardziej skalowalnych i zrównoważonych architektur w przyszłości, umożliwiając rozwój dużych modeli maszynowego uczenia.