分散処理

0.章の概要

この章では、分散処理の基本的な理論と実際に使われているアプローチについて詳しく説明します。分散深層学習の技術を活用することで、計算資源を最も効率的に活用し、大規模なデータセットや複雑なモデルのトレーニングを可能にします。また、後半の連合学習では、内部データを流出せずに、組織間で分散処理を行う方法についても解説しています。

1. 並列分散処理

学習キーワード: 分散深層学習、モデル並列化、データ並列化

概要

分散深層学習とは、計算タスクを複数の計算機やデバイスに分散させて処理を行う技術です。 特に、大規模なデータセットや複雑なモデルのトレーニングにおいては、分散処理を用いた計算資源の効率化が必要になります。

分散処理の方法は、主に以下の2種類に分けられます。

  • モデル並列化
  • データ並列化

モデル並列化はモデルが大きいときに用いられ、データ並列化はデータが大きいときに用いられます。


モデル並列化

1つのモデルのパラメータや計算を分割して複数の計算機、デバイスで処理し、必要に応じて他のプロセスと通信します。バッチサイズをプロセスの数だけ増加(2プロセスならば2倍)させたり、学習を高速化できます。

model_parallelization
  • 特徴: 複雑なモデルを効率的に学習できますが、デバイス間でこの処理を行う場合は、データ通信が大きなボトルネックになってしまいます。
  • 適用例: モデルのサイズがデバイスのメモリ容量を超える場合に適用されます(GPT-4のような大規模モデルや高解像度の画像を処理する多層CNNモデルなど)。

データ並列化

同一モデルを複数のデバイス分コピーし、それぞれ異なるデータを処理します(コピーしたネットワークはレプリカと呼ばれる)。その後、レプリカごとで学習した勾配を数iterationごとに平均化することで同期します。モデルが大きく、1デバイスで処理が難しい場合に用いられます。

data_parallelization
  • 特徴: 各レプリカが独立して計算を行うため、モデル並列化に比べて通信量は少なくなります。
  • 適用例: 大規模データセットを効率的に学習したい時に適用されます。

また、パラメータの合わせ方で、同期型か非同期型か分かれます。

  • 同期型全てのレプリカの計算が終了するのを待ち、全レプリカの勾配が出た後に平均をとり、親モデルのパラメータを更新します。全レプリカの計算が完了するまで待機する必要がありますが、精度が高いとされています。
    synchronous
  • 非同期型各レプリカが独立して計算を行う手法です。各レプリカで学習した勾配をパラメーターサーバに送信し、パラメータサーバ内で平均値が更新されます。新たに学習するときは、にパラメーターサーバから最新の平均勾配を取得し、それを用いて訓練を行っていきます。処理が高速ですが、各レプリカが最新のモデルを使用できないため、学習が不安定になりやすい欠点があります。
    asynchronous

2. 連合学習(Federated Learning)

学習キーワード: クロスデバイス学習、クロスサイロ学習、Federated Averaging、Local Model、Global Model

概要

連合学習はGoogle が2017年に提唱した技術で、データをデバイス間(組織間、個人間)で共有することなく、分散学習を行う手法です。 従来のモデル学習では、様々な組織、個人などから生のデータを集約してモデル作成を行うため、「自社(個人)の情報が洩れる」などのセキュリティ面での欠点がありました。 連合学習では、各デバイス(個人ごとや組織ごと)でLocal Modelをトレーニングし、その結果を集約してGlobal Modelを作成します。 集約するのはLocal Modelであり、生データではないため、プライバシーを保護しつつ、効率的な学習行えるようになりました。

以下に、連合学習の概要図を示します。

federated_learning

連合学習の種類

連合学習は、主にクロスデバイス学習とクロスサイロ学習の2種類に分けられます。

クロスデバイス学習

クロスデバイス学習では、スマートフォンなど多数のIoTデバイスを用いて学習を行います。

  • 例: スマートフォンユーザーの予測変換履歴から、連合学習を用いてキーボード入力予測モデルを学習する。
クロスサイロ学習

クロスサイロ学習では、組織間でデータを共有せずに、協力してモデルをトレーニングする手法です。

  • 例: 病院間で、院外には公開できない患者情報を用いて診断予測モデルの学習を行う。

Federated Averaging

Federated Averagingは、googleが2017年に発表した、連合学習の発端となった手法です。各デバイスで学習したLocal Modelの重みを中央サーバーで統合し、新しいGlobal Modelを生成します。詳細のアルゴリズムは多少複雑ですので、詳しく知りたい方は論文「Communication-Efficient Learning of Deep Networks from Decentralized Data」を参照ください。

Federated Averagingをもとにした連合学習の流れは以下の通りです。

  1. サーバーが全てのクライアントからランダムに数個選択します。
  2. 選択されたクライアントは、Global Modelをもとに、自社や個人で保有するデータを使ってローカルに学習を行い、パラメータを更新します。
  3. 各クライアントは、更新されたモデル(Local Model)のパラメータをサーバーに共有します。
  4. サーバーは共有された各クライアントのパラメーターを平均し、新たなGlobal Modelとします。

関連用語の補足
  • Local Model: 各デバイスがローカルで学習したモデル。
  • Global Model: 複数デバイスで統合されたモデル。

キーワードまとめ

分散深層学習、モデル並列化、データ並列化、クロスデバイス学習、クロスサイロ学習、Federated Averaging、Local Model、Global Model