Tsubatoの発信記録

主に機械学習やデータサイエンス関連で学んだことを書いています。

Andrej KarpathyのGPT解説動画

本記事の概要

TeslaでAI開発のディレクターを務め、現在はChatGPTで有名なOpenAIで働くAndrej KarpathyのGPT解説動画[Let's build GPT: from scratch, in code, spelled out.]を紹介します。
www.youtube.com

動画の概要

  • ChatGPTにも使用されている言語モデルGPT3と同等のモデルを実装していきます。


https://upload.wikimedia.org/wikipedia/commons/9/91/Full_GPT_architecture.png

  • データセットは1MB程度の小さなものなのでそこまでの精度は出ません。あくまでモデルのアーキテクチャを学ぶのが目的のようです。
  • 実装は全てPyTorchで、もちろん動画で実装されるコードは全て公開されています。GitHub - karpathy/ng-video-lecture
  • 2時間の動画ですが、データの前処理から始まり、シンプルなモデルから徐々にGPTに近づけていくスタイルでとても理解しやすい解説でした。

個人的な学び

  • GPTの構造はTransformerの元論文、Attention Is All You NeedのDecoder側(右半分)を微修正したものとのこと。


  • 一方でコンピュータビジョンに使われるVision TransformerはEncoder側を使っているので、それぞれが別の用途に使われていくのは興味深いです。もっともTransformerの根幹であるMulti head attentionはEncoder/Decoder両方にあるので両者に大きな差異はないですが…
  • 余談ですが、Vision Transformerには優れた日本語の本が出ていますのでこちらもオススメです。特に2章の実装の解説はSelf Attentionを理解する上でとても参考になるので、ここだけでも読んでみても良いかと思います。

  • 実装面では以下のtrilを利用した過去の入力の効率的な平均計算が参考になりました。
import torch
import torch.nn.functional as F

B, T, C = 2, 8, 2
x = torch.randn(B, T, C)
tril = torch.tril(torch.ones(T, T))
wei = torch.zeros((T, T))
wei = wei.masked_fill(tril==0, float("-inf"))
wei = F.softmax(wei, dim=1)
wei @ x
  • GPTでは過去の語句を元に次の語句を予測するというタスクのため、未来の語句をマスクして過去の語句の特徴だけを考慮する必要があります。
  • そこで行列の対角よりも上側の要素を全て0にするtrilを用いたtrickが使われています。単純に平均をとる場合なら以下のような実装で良いですが、GPTではqueryとkeyの行列積で重み付けをするため、より汎用性の高いsoftmaxを使った実装が用いられています。
tril = torch.tril(torch.ones(T, T))
wei = wei / wei.sum(1, keepdim=True)
wei @ x