リカレントニューラルネットワーク
- 0.概要
- 1.リカレントニューラルネットワーク
- 2.ゲート機構
- 3.系列変換
- キーワードまとめ
Contents
0.章の概要
この章では、リカレントネットワーク(RNN:Recurrent Neural Network)について解説しています。リカレントネットワークは時系列を考慮できるニューラルネットワークで、音声データ、機械翻訳、自然言語処理などで使用されています。従来のニューラルネットワーク(例えば、順伝播ネットワーク)では入力をまとめて行うため、時系列データを考慮できませんでしたが、リカレントネットワークではその問題を解決することができました。
1.リカレントニューラルネットワーク
学習キーワード: 順伝播の計算、逆伝搬の計算(Back Propagation Through Time; BPTT)、双方向RNN
概要
リカレントニューラルネットワーク(RNN)は、時系列データや系列データを処理するための強力なモデルです。ここで言う時系列データとは、時間的に変化するデータを指します。自然言語データも時系列データに含まれます。
順伝播の計算
RNNの基本構造
RNNは、入力データを時間的に処理するために、隠れ層の状態を持ちます。各時刻 \(t\) において、RNNは以下のように計算を行います。
ここで、
- \(h_t\): 時刻 \(t\) の隠れ状態
- \(h_{t-1}\): 時刻 \(t-1\) の隠れ状態
- \(x_t\): 時刻 \(t\) の入力ベクトル
- \(W_h\): 隠れ状態の重み行列
- \(W_x\): 入力の重み行列
- \(b\): バイアスベクトル
- \(f\): 活性化関数(通常はtanhやReLU)
- \(o_t\): 時刻 \(t\) の出力(活性化関数を適用する前の値)
- \(y_t\): 時刻 \(t\) の出力(活性化関数を適用後の値)
- \(W_y\): 出力の重み行列
- \(b_y\): 出力のバイアスベクトル
従来のニューラルネットワークと異なる点は、RNNが時間的に連続する入力データを処理する能力を持つことです。従来のニューラルネットワークでは、入力データをまとめて処理するため、時系列データの連続性を考慮できませんでしたが、RNNは、各時刻の入力データと前の隠れ状態を用いて、現在の隠れ状態を計算することで、時系列データの連続性を考慮することができます。
この特徴により、RNNは時系列データを扱えるようになり、音声データ、機械翻訳、自然言語処理など、多くの領域で活用されています。
逆伝搬の計算(Back Propagation Through Time; BPTT)
逆伝播の計算(BPTT)
RNNの学習には、誤差を逆伝播させるBPTTが用いられます。BPTTは、通常の逆伝播と同じアルゴリズムを使用しますが、RNNの特性上、誤差の伝播も「時間軸を遡る形」で行う必要があります。
BPTTでは、時間ステップを遡ることで誤差の伝播を行いますが、長い時系列データになると「勾配消失問題」や「勾配爆発問題」が生じやすくなり、学習が不安定になります。(この問題を解決するために、LSTMやGRUといった特殊なRNNセルが開発されました。)
また、BPTTは並列計算が難しく、計算コストが高いという問題もありますが、現在では、計算効率を向上させるためにBPTTを一定の時間範囲で打ち切る手法として、Truncated BPTTが用いられています。
双方向RNN
双方向RNN(Bidirectional Recurrent Neural Network、Bi-RNN)は、入力データの過去と未来の両方の文脈を考慮して出力を生成するRNNの拡張版です。通常のRNNは一方向(順方向)のみで入力を処理しますが、Bi-RNNでは順方向と逆方向の両方でデータを処理し、2つの情報を結合して出力を得ます。
ここで、
- \(\overrightarrow{h_t}\): 前向きのRNNの時刻 \(t\) の隠れ状態
- \(\overleftarrow{h_t}\): 後向きのRNNの時刻 \(t\) の隠れ状態
- \(y_t\): 出力
- \(W_x\): 入力の重み行列
- \(W_h\): 隠れ状態の重み行列
- \(b_h\): 隠れ状態のバイアスベクトル
- \(W_out\): 出力の重み行列
- \(b_out\): 出力のバイアスベクトル
- \(f\): 活性化関数(通常はtanhやReLU)
- \(g\): 出力層の活性化関数
双方向RNNは、過去と未来の両方向から文脈情報を取り入れられるため、特に以下のような用途に適しています。
- 自然言語処理(文の意味理解、品詞タグ付け、機械翻訳など):単語の意味は前後の文脈に依存するため、Bi-RNNは精度を高めます。
- 音声認識:ある音声が前後の音に依存するケースに有効です。
- 画像キャプション生成:双方向RNNを活用することで、画像内の物体やその関係性を時系列データとして扱い、キャプションを生成します。
Bi-RNNは計算コストが通常のRNNに比べて2倍になる点には注意が必要ですが、優れた性能を発揮する場面が多いため、広く応用されています。
関連用語の補足
- 教師強制teacher forcing:長期依存性の問題を回避するため、前の層の隠れ層ではなく、正解ラベルを次の時刻に利用する学習方法のこと。中間層同士の接続がなく、並列処理が可能だが、テスト時は、出力を使用しているため、訓練時データの分布が異なり、精度が低い。
- スケジュールサンプリング(scheduled sampling):教師強制の課題解決を目的に出てきた手法で、次の層に、正解ラベルを使用するのか出力を使用するのかを、確率で決定する手法。
- 全結合グラフィカルモデル:過去の全情報を直接入力するように変更したモデルで、BPTTの問題を解決しようとしたが、パラメータ数の多さがデメリットで計算コストが高い。
- スキップ接続:、離れた過去の情報を現在に接続させ、長期の依存性を解決しようとした手法。
2.ゲート機構
学習キーワード: 勾配消失、忘却ゲート、入力ゲート、出力ゲート、LSTM(長期記憶と短期記憶)、GRU、リセットゲート、メモリーセル
概要
RNNは、長い系列データを扱うため、勾配消失問題が発生しやすいという問題があります。この問題を解決するために、以下で説明するLSTMやGRUなどの改良版が提案されました。
これらのモデルは、ゲート機構を導入することで、重要な情報を長期間保持し、不要な情報を忘れることができます。具体的には、LSTMではセル状態とゲート(入力ゲート、忘却ゲート、出力ゲート)を導入し、GRUではリセットゲートと更新ゲートを導入します。
LSTM(長期記憶と短期記憶)
LSTMは、RNNの改良版の一つで、長期依存性の問題を解決できたモデルとなりました。その構造は、下図のようにRNNの中間層を、メモリと3つのゲートを持つブロック(LSTM blockと呼ばれる)に置き換えたものなっています。
- 忘却ゲート:過去の情報を忘れるかどうかを決定します。
- 入力ゲート:新しい情報をセル状態に追加するかどうかを決定します。
- 出力ゲート:セル状態から出力する情報を決定します。
- 現時刻の状態:セル状態を更新します。
- 現時刻の出力:セル状態から出力します。
- z:セル状態の更新に使用されます。
これらの数式は以下の通りです。
ここで、\(\sigma\)はシグモイド関数を表し、\(f_t\)は忘却ゲート、\(i_t\)は入力ゲート、\(o_t\)は出力ゲート、\(c_t\)はセル状態、\(h_t\)は隠れ状態、\(z_t\)はセル状態の更新を表します。
のぞき穴結合
のぞき穴結合は、LSTMのゲートにセル状態の情報を追加することで、ゲートの決定にセル状態の影響を与える技術です。通常のLSTMでは、ゲートの入力は前の隠れ状態と現在の入力のみで決定されますが、のぞき穴結合を導入することで、ゲートがセル状態も考慮に入れて調整されるようになります。
ここで、\(\sigma\)はシグモイド関数を表し、\(f_t\)は忘却ゲート、\(i_t\)は入力ゲート、\(o_t\)は出力ゲート、\(c_t\)はセル状態、\(h_t\)は隠れ状態、\(z_t\)はセル状態の更新を表します。また、\(W_f, U_f, V_f, b_f\)などはモデルのパラメータを表し、\(x_t, h_{t-1}, c_{t-1}\)などは入力や前の時刻の状態を表します。\(o_t\) を計算する時点では、\(c_t\) が計算されているため、\(c_t\) を用いています。
上記数式のように、のぞき穴結合を導入することで、LSTMの性能が向上することが期待されます。
GRU
GRUは、LSTMと似たゲート機構を持つが、よりシンプルな構造をしています。パラメータ数が多く計算コストが高いLSTMの弱点を克服したモデルです。
- 更新ゲート:新しい情報をどのようにセル状態に追加するかを決定します。
- リセットゲート:過去の情報をどのように忘れるかを決定します。
- 新しい隠れ状態:セル状態を更新します。
- 現時刻の出力:セル状態から出力します。
これらの数式は以下の通りです。
ここで、\(\sigma\)はシグモイド関数を表し、\(u_t\)は更新ゲート、\(r_t\)はリセットゲート、\(\hat{h}_t\)は新しい隠れ状態、\(h_t\)は現時刻の隠れ状態を表します。
LSTMとGRUの精度はタスクに応じて異なり、データの特性やタスクの目的に応じて、どちらが適しているかを判断することが重要になります。
関連用語の補足
- メモリーセル:LSTMやGRUでは、メモリーセル(Cell State)が導入されています。これは、長期的な依存関係を記憶するための機構です。メモリーセルは、各時刻で更新され、忘却ゲートと入力ゲートによって情報が追加または削除されます。
- 勾配爆発問題:勾配爆発問題はネットワークの深さが増えるにつれて、勾配が指数関数的に増加し、数値が非常に大きくなりすぎて計算が不安定になる問題です。
- 勾配クリッピング(gradient clipping):勾配爆発問題を解決するために用いられる手法。勾配が一定の閾値を超えた場合に、その勾配を制限することで学習を安定させることができます。具体的には、以下の2種類の方法で勾配をクリップします。
\[ g' = \min(g, \theta) \]
これは要素ごとにクリッピングした場合の式で、\(g\)は勾配の値、\(\theta\)は閾値、\(g'\)はクリップされた勾配の値です。
\[ g' = \frac{g}{\|g\|} \cdot \min(\|g\|, \theta) \]これはノルムでクリッピングした場合の式で、\(g\)は勾配の値、\(\theta\)は閾値、\(g'\)はクリップされた勾配の値、\(\|g\|\)は勾配のノルム(大きさ)です。
3.系列変換
学習キーワード: エンコーダ・デコーダ、sequence-to-sequence(seq2seq)、アテンション(注意)機構
概要
「系列変換」は、自然言語処理や音声処理などの分野で、入力されたデータの系列(例えば文章や音声)を別の系列に変換する技術です。具体的には、機械翻訳(原文を別の言語に変換する)や要約生成(長文を短い要約に変換する)、音声認識(音声をテキストに変換する)などで用いられます。
sequence-to-sequence(seq2seq)
seq2seqは、2つのRNN(エンコーダー・デコーダー)を利用し、時系列データを別の時系列データに変換するモデルです。Encoder-Decoderモデルとも呼ばれています。
基本的なseq2seqモデルは、エンコーダとデコーダの2つの主要な構成要素からなります。
エンコーダ(Encoder)→符号化、圧縮
入力された系列(例えば文章)を一つ一つのトークンとして順に処理し、最終的に系列全体をまとめた潜在ベクトル(コンテキストベクトル)を出力するもので、アーキテクチャとして、LSTMやGRUなどが多く使われます。出力されるコンテキストベクトルは、入力系列の要約情報を保持しています。
デコーダ(Decoder)→復号化、展開
エンコーダからのコンテキストベクトルを受け取り、これを元に出力系列を一つずつ生成していきます。生成されたトークンは次のステップでの入力として使われ、モデルが系列全体を段階的に出力します。
seq2seqは複雑な系列変換タスクに適しており、入力と出力の長さが異なる場合でも柔軟に対応できますが、計算コストが高いという問題点があります。
アテンション(注意)機構
従来のseq2seqモデルでは、エンコーダが出力する情報が一つのベクトルに圧縮されるため、長い入力系列では情報が失われやすいという問題があり、これを改善するためにアテンション機構が発明されました。アテンション機構は、デコーダが出力を生成する際に、エンコーダの各出力を適切に重み付けして利用できるようにし、長い系列でも過去情報をを適切に保持します。
※エンコーダの各出力を重みづけすることで、過去の重要なポイントに注意(Attention)することからこのような名前がついています。
アテンションの計算手順
アテンションの計算手順は、以下の5つのステップで構成されます。
下記図引用元記事はこちら- エンコーダの出力とデコーダの状態を準備:エンコーダの出力 (encoder \(h_t\)) とデコーダの状態 (decoder \(h_t\)) を準備します。
- スコアの計算:デコーダの状態 \(Q\)(Query) とエンコーダの状態 \(K\)(Key) の内積(スコア)を計算します。スコアは、エンコーダの各出力に対する重要度を表します。
\[ score_t = QK^T \]
- 重要度の計算:スコアにソフトマックス関数をかけます。
\[ \alpha_t = \frac{\exp(score_t)}{\sum_{t=1}^T \exp(score_t)} \]
- 重要度とエンコーダの出力を掛ける:重要度とエンコーダの出力 \(V\) を掛けます。
\[ \alpha_t V = \text{softmax}(QK^T) V \]
- Context vectorの計算:重要度とエンコーダの出力の要素ごとの積の総和を、Context vector\(C_t\) と呼びます。
\[ C_t = \sum_{t=1}^T \text{softmax}(QK^T)V \]
ここで、\(Q\)はデコーダの状態、 \(K\) 、\(V\)はエンコーダの状態、\(\alpha_t\) はアテンションウェイト、\(C_t\) はコンテキストベクトルを表します。
K(キー)とV(バリュー)の違い
エンコーダの出力 \(h_t\) は、入力系列を処理して得られる隠れ状態ベクトルです。この出力から、\(K\) (キー)と\(V\)(バリュー)が生成されますが、異なる重み行列を使用して変換されるため、最終的な数値は異なります。
具体的には、次のように表現されます:
ここで、\(W_k\) と \(W_v\) はそれぞれキーとバリューを生成するための異なる重み行列です。
キーワードまとめ
順伝播の計算、逆伝搬の計算(Back Propagation Through Time; BPTT)、双方向RNN、勾配消失、忘却ゲート、入力ゲート、出力ゲート、LSTM(長期記憶と短期記憶)、GRU、リセットゲート、メモリーセル、エンコーダ・デコーダ、seqence-to-sequence(seq2seq)、アテンション(注意)機構