Tsubatoの発信記録

主に機械学習やデータサイエンス関連で学んだことを書いています。

ゼロから作るDeep Learning3 フレームワーク編を写経した

0. この投稿の概要

  • フレームワークということでただ動くだけではなく、ユーザビリティやメモリ管理にも気を配っています。そこで本記事はコアとなる自動微分を実現する仕組み、ユーザビリティを向上させる仕組み、メモリ効率を改善する仕組みという3つの軸で学んだことをまとめました。
  • なお本書の最初のセクションはwebで公開されていますので、まずはこちらを確認して購入を検討されると良いかと思います。最初のセクションに一番コアなことが書かれていますので太っ腹ですね。

koki0702.github.io

1. 学んだこと

自動微分を実現する仕組み

  • 変数を格納するVariableクラスと各種関数のベースとなるFunctionクラスを作成、自動微分は基本的にこの2クラスだけで成立している。
    • 特にFunctionクラスは順伝播と逆伝播を行うメソッドを持ち、具体的な計算は継承先のクラスでそれぞれ実装される。
  • 誤差逆伝播を自動化させるために関数と変数のつながりを保持する必要がある。このつながりは順伝播でデータを流すときに作り、この特徴がDefine-by-Runと呼ばれる。
  • 分岐を含む複雑な計算グラフで逆伝播を正しく実現するために、順伝播をした際に変数や関数の世代を記録し、後ろの世代を優先的に逆伝播時に処理する。
  • 勾配もVariableクラスとして保持することで、逆伝播時も計算グラフを作れるようになる。勾配をさらに逆伝播することで高階微分を求めることができる。

ユーザビリティを向上させる仕組み

  • 複数の入力がある関数に対応するため、Functionクラスは可変長引数を使っている。
  • Variable同士、あるいはVariableとndarrayとの計算を+や-のような演算子を使ってできるように演算子オーバーロードをしている。
    • 他にもVariableクラスに格納されている中身(ndarray)を見やすくするため、ndarrayと同様のインスタンス変数(shape, ndim)などをpropertyとして実装。また__repr__メソッドを実装してprintにも対応した。
  • Parameterクラス(Variableと同一)の集約としてLayerクラスを実装。さらにパラメータ更新作業を一括して行うOptimizerクラスを実装。これが無いとパラメータごとに更新する退屈なコードをユーザが書かなければならない。

メモリ効率を改善する仕組み

  • pythonのメモリ管理は参照カウント方式が基本、もし循環参照がある場合は参照カウント方式では削除されないため、循環参照を作らないように心がける。
    • Functionが入出力のVariableを参照し、Variableはそれを生み出したFunctionを参照するため循環参照がある。本書ではFunctionから出力Variableへの参照を弱参照に置き換えた。
  • MeanSquaredErrorのような複数の関数の組み合わせで求められるものは、それ単体としてFunctionとして定義をした。既存の関数の組み合わせで実装すると、途中の変数に対しても計算グラフが生成されてメモリ効率が良くないため。
  • 本書の範囲では順伝播の結果を全てFunctionクラスが保持しているが、関数の中には逆伝播の計算に順伝播の情報が不要のものがある。関数ごとに保持するデータを決めれば不要な情報を保持する必要がなくなる。

2. 感想

  • 実装してみて、普段利用しているフレームワークのコアは自動微分の仕組みにあることに改めて気付かされました。逆に言えば既存のフレームワークの理解を深めるならまずはそこを抑えておいた方が良さそうです。
  • 本書は500ページ近くのボリュームがある本ですが、特に面白い自動微分に関わる実装は前半で終わるので、まずはそこまで写経をしてみると良いかと思います。
  • 次のステップとしてはモチベーションが高いうちにPyTorchのコードリーディングに少しでも挑戦したいです。また、今回とは違うdefine-and-run方式、特にどのようなネットワークの最適化が行われるかも調べたいと思います。