リカレントニューラルネットワーク

0.章の概要

この章では、リカレントネットワーク(RNN:Recurrent Neural Network)について解説しています。リカレントネットワークは時系列を考慮できるニューラルネットワークで、音声データ、機械翻訳、自然言語処理などで使用されています。従来のニューラルネットワーク(例えば、順伝播ネットワーク)では入力をまとめて行うため、時系列データを考慮できませんでしたが、リカレントネットワークではその問題を解決することができました。

1.リカレントニューラルネットワーク

学習キーワード: 順伝播の計算、逆伝搬の計算(Back Propagation Through Time; BPTT)、双方向RNN

概要

リカレントニューラルネットワーク(RNN)は、時系列データや系列データを処理するための強力なモデルです。ここで言う時系列データとは、時間的に変化するデータを指します。自然言語データも時系列データに含まれます。


順伝播の計算
RNNの基本構造
rnn

RNNは、入力データを時間的に処理するために、隠れ層の状態を持ちます。各時刻 \(t\) において、RNNは以下のように計算を行います。

\[ h_t = f(W_h h_{t-1} + W_x x_t + b) \]
\[ o_t = W_y h_t + b_y \] \[ y_t = f(o_t) \]

ここで、

  • \(h_t\): 時刻 \(t\) の隠れ状態
  • \(h_{t-1}\): 時刻 \(t-1\) の隠れ状態
  • \(x_t\): 時刻 \(t\) の入力ベクトル
  • \(W_h\): 隠れ状態の重み行列
  • \(W_x\): 入力の重み行列
  • \(b\): バイアスベクトル
  • \(f\): 活性化関数(通常はtanhやReLU)
  • \(o_t\): 時刻 \(t\) の出力(活性化関数を適用する前の値)
  • \(y_t\): 時刻 \(t\) の出力(活性化関数を適用後の値)
  • \(W_y\): 出力の重み行列
  • \(b_y\): 出力のバイアスベクトル

従来のニューラルネットワークと異なる点は、RNNが時間的に連続する入力データを処理する能力を持つことです。従来のニューラルネットワークでは、入力データをまとめて処理するため、時系列データの連続性を考慮できませんでしたが、RNNは、各時刻の入力データと前の隠れ状態を用いて、現在の隠れ状態を計算することで、時系列データの連続性を考慮することができます。

この特徴により、RNNは時系列データを扱えるようになり、音声データ、機械翻訳、自然言語処理など、多くの領域で活用されています。


逆伝搬の計算(Back Propagation Through Time; BPTT)
逆伝播の計算(BPTT)

RNNの学習には、誤差を逆伝播させるBPTTが用いられます。BPTTは、通常の逆伝播と同じアルゴリズムを使用しますが、RNNの特性上、誤差の伝播も「時間軸を遡る形」で行う必要があります。

BPTTでは、時間ステップを遡ることで誤差の伝播を行いますが、長い時系列データになると「勾配消失問題」や「勾配爆発問題」が生じやすくなり、学習が不安定になります。(この問題を解決するために、LSTMやGRUといった特殊なRNNセルが開発されました。)

また、BPTTは並列計算が難しく、計算コストが高いという問題もありますが、現在では、計算効率を向上させるためにBPTTを一定の時間範囲で打ち切る手法として、Truncated BPTTが用いられています。


双方向RNN

双方向RNN(Bidirectional Recurrent Neural Network、Bi-RNN)は、入力データの過去と未来の両方の文脈を考慮して出力を生成するRNNの拡張版です。通常のRNNは一方向(順方向)のみで入力を処理しますが、Bi-RNNでは順方向と逆方向の両方でデータを処理し、2つの情報を結合して出力を得ます。

bidirectional_rnn
\[ \overrightarrow{h_t} = f(W_x x_t + W_h \overrightarrow{h_{t-1}} + b_h) \] \[ \overleftarrow{h_t} = f(W_x x_t + W_h \overleftarrow{h_{t+1}} + b_h) \] \[ y_t = g(W_out \cdot [ \overrightarrow{h_t} \overleftarrow{h_t} ] + b_out) \]

ここで、

  • \(\overrightarrow{h_t}\): 前向きのRNNの時刻 \(t\) の隠れ状態
  • \(\overleftarrow{h_t}\): 後向きのRNNの時刻 \(t\) の隠れ状態
  • \(y_t\): 出力
  • \(W_x\): 入力の重み行列
  • \(W_h\): 隠れ状態の重み行列
  • \(b_h\): 隠れ状態のバイアスベクトル
  • \(W_out\): 出力の重み行列
  • \(b_out\): 出力のバイアスベクトル
  • \(f\): 活性化関数(通常はtanhやReLU)
  • \(g\): 出力層の活性化関数

双方向RNNは、過去と未来の両方向から文脈情報を取り入れられるため、特に以下のような用途に適しています。

  • 自然言語処理(文の意味理解、品詞タグ付け、機械翻訳など):単語の意味は前後の文脈に依存するため、Bi-RNNは精度を高めます。
  • 音声認識:ある音声が前後の音に依存するケースに有効です。
  • 画像キャプション生成:双方向RNNを活用することで、画像内の物体やその関係性を時系列データとして扱い、キャプションを生成します。

Bi-RNNは計算コストが通常のRNNに比べて2倍になる点には注意が必要ですが、優れた性能を発揮する場面が多いため、広く応用されています。


関連用語の補足
  • 教師強制teacher forcing:長期依存性の問題を回避するため、前の層の隠れ層ではなく、正解ラベルを次の時刻に利用する学習方法のこと。中間層同士の接続がなく、並列処理が可能だが、テスト時は、出力を使用しているため、訓練時データの分布が異なり、精度が低い。
  • スケジュールサンプリング(scheduled sampling):教師強制の課題解決を目的に出てきた手法で、次の層に、正解ラベルを使用するのか出力を使用するのかを、確率で決定する手法。
  • 全結合グラフィカルモデル:過去の全情報を直接入力するように変更したモデルで、BPTTの問題を解決しようとしたが、パラメータ数の多さがデメリットで計算コストが高い。
  • スキップ接続:、離れた過去の情報を現在に接続させ、長期の依存性を解決しようとした手法。

2.ゲート機構

学習キーワード: 勾配消失、忘却ゲート、入力ゲート、出力ゲート、LSTM(長期記憶と短期記憶)、GRU、リセットゲート、メモリーセル

概要

RNNは、長い系列データを扱うため、勾配消失問題が発生しやすいという問題があります。この問題を解決するために、以下で説明するLSTMやGRUなどの改良版が提案されました。

これらのモデルは、ゲート機構を導入することで、重要な情報を長期間保持し、不要な情報を忘れることができます。具体的には、LSTMではセル状態とゲート(入力ゲート、忘却ゲート、出力ゲート)を導入し、GRUではリセットゲートと更新ゲートを導入します。


LSTM(長期記憶と短期記憶)

LSTMは、RNNの改良版の一つで、長期依存性の問題を解決できたモデルとなりました。その構造は、下図のようにRNNの中間層を、メモリと3つのゲートを持つブロック(LSTM blockと呼ばれる)に置き換えたものなっています。

lstm
  • 忘却ゲート:過去の情報を忘れるかどうかを決定します。
  • 入力ゲート:新しい情報をセル状態に追加するかどうかを決定します。
  • 出力ゲート:セル状態から出力する情報を決定します。
  • 現時刻の状態:セル状態を更新します。
  • 現時刻の出力:セル状態から出力します。
  • z:セル状態の更新に使用されます。

これらの数式は以下の通りです。

\[ f_t = \sigma(W_f x_t + U_f h_{t-1} + b_f) \] \[ i_t = \sigma(W_i x_t + U_i h_{t-1} + b_i) \] \[ o_t = \sigma(W_o x_t + U_o h_{t-1} + b_o) \] \[ c_t = f_t \odot c_{t-1} + i_t \odot z_t\] \[ h_t = o_t \odot \tanh(c_t) \] \[ z_t = \tanh(W_z x_t + U_z h_{t-1} + b_z) \]

ここで、\(\sigma\)はシグモイド関数を表し、\(f_t\)は忘却ゲート、\(i_t\)は入力ゲート、\(o_t\)は出力ゲート、\(c_t\)はセル状態、\(h_t\)は隠れ状態、\(z_t\)はセル状態の更新を表します。

のぞき穴結合

のぞき穴結合は、LSTMのゲートにセル状態の情報を追加することで、ゲートの決定にセル状態の影響を与える技術です。通常のLSTMでは、ゲートの入力は前の隠れ状態と現在の入力のみで決定されますが、のぞき穴結合を導入することで、ゲートがセル状態も考慮に入れて調整されるようになります。

\[ f_t = \sigma(W_f x_t + U_f h_{t-1} + V_f c_{t-1} + b_f) \] \[ i_t = \sigma(W_i x_t + U_i h_{t-1} + V_i c_{t-1} + b_i) \] \[ o_t = \sigma(W_o x_t + U_o h_{t-1} + V_o c_{t} + b_o) \] \[ c_t = f_t \odot c_{t-1} + i_t \odot z_t \] \[ h_t = o_t \odot \tanh(c_t) \] \[ z_t = \tanh(W_z x_t + U_z h_{t-1} + b_z) \]

ここで、\(\sigma\)はシグモイド関数を表し、\(f_t\)は忘却ゲート、\(i_t\)は入力ゲート、\(o_t\)は出力ゲート、\(c_t\)はセル状態、\(h_t\)は隠れ状態、\(z_t\)はセル状態の更新を表します。また、\(W_f, U_f, V_f, b_f\)などはモデルのパラメータを表し、\(x_t, h_{t-1}, c_{t-1}\)などは入力や前の時刻の状態を表します。\(o_t\) を計算する時点では、\(c_t\) が計算されているため、\(c_t\) を用いています。

上記数式のように、のぞき穴結合を導入することで、LSTMの性能が向上することが期待されます。


GRU

GRUは、LSTMと似たゲート機構を持つが、よりシンプルな構造をしています。パラメータ数が多く計算コストが高いLSTMの弱点を克服したモデルです。

gru
  • 更新ゲート:新しい情報をどのようにセル状態に追加するかを決定します。
  • リセットゲート:過去の情報をどのように忘れるかを決定します。
  • 新しい隠れ状態:セル状態を更新します。
  • 現時刻の出力:セル状態から出力します。

これらの数式は以下の通りです。

\[ u_t = \sigma(W_u x_t + U_u h_{t-1} + b_u) \] \[ r_t = \sigma(W_r x_t + U_r h_{t-1} + b_r) \] \[ h_t = (1 - u_t) \odot h_{t-1} + u_t \odot \hat{h}_t \] \[ \hat{h}_t = \tanh(U_h x_t + W_h (r_t \odot h_{t-1}) + b_h) \]

ここで、\(\sigma\)はシグモイド関数を表し、\(u_t\)は更新ゲート、\(r_t\)はリセットゲート、\(\hat{h}_t\)は新しい隠れ状態、\(h_t\)は現時刻の隠れ状態を表します。

LSTMとGRUの精度はタスクに応じて異なり、データの特性やタスクの目的に応じて、どちらが適しているかを判断することが重要になります。


関連用語の補足
  • メモリーセル:LSTMやGRUでは、メモリーセル(Cell State)が導入されています。これは、長期的な依存関係を記憶するための機構です。メモリーセルは、各時刻で更新され、忘却ゲートと入力ゲートによって情報が追加または削除されます。
  • 勾配爆発問題:勾配爆発問題はネットワークの深さが増えるにつれて、勾配が指数関数的に増加し、数値が非常に大きくなりすぎて計算が不安定になる問題です。
  • 勾配クリッピング(gradient clipping):勾配爆発問題を解決するために用いられる手法。勾配が一定の閾値を超えた場合に、その勾配を制限することで学習を安定させることができます。具体的には、以下の2種類の方法で勾配をクリップします。
    \[ g' = \min(g, \theta) \]

    これは要素ごとにクリッピングした場合の式で、\(g\)は勾配の値、\(\theta\)は閾値、\(g'\)はクリップされた勾配の値です。

    \[ g' = \frac{g}{\|g\|} \cdot \min(\|g\|, \theta) \]

    これはノルムでクリッピングした場合の式で、\(g\)は勾配の値、\(\theta\)は閾値、\(g'\)はクリップされた勾配の値、\(\|g\|\)は勾配のノルム(大きさ)です。

3.系列変換

学習キーワード: エンコーダ・デコーダ、sequence-to-sequence(seq2seq)、アテンション(注意)機構

概要

「系列変換」は、自然言語処理や音声処理などの分野で、入力されたデータの系列(例えば文章や音声)を別の系列に変換する技術です。具体的には、機械翻訳(原文を別の言語に変換する)や要約生成(長文を短い要約に変換する)、音声認識(音声をテキストに変換する)などで用いられます。


sequence-to-sequence(seq2seq)

seq2seqは、2つのRNN(エンコーダー・デコーダー)を利用し、時系列データを別の時系列データに変換するモデルです。Encoder-Decoderモデルとも呼ばれています。

基本的なseq2seqモデルは、エンコーダとデコーダの2つの主要な構成要素からなります。

seq2seq
エンコーダ(Encoder)→符号化、圧縮

入力された系列(例えば文章)を一つ一つのトークンとして順に処理し、最終的に系列全体をまとめた潜在ベクトル(コンテキストベクトル)を出力するもので、アーキテクチャとして、LSTMやGRUなどが多く使われます。出力されるコンテキストベクトルは、入力系列の要約情報を保持しています。

デコーダ(Decoder)→復号化、展開

エンコーダからのコンテキストベクトルを受け取り、これを元に出力系列を一つずつ生成していきます。生成されたトークンは次のステップでの入力として使われ、モデルが系列全体を段階的に出力します。

seq2seqは複雑な系列変換タスクに適しており、入力と出力の長さが異なる場合でも柔軟に対応できますが、計算コストが高いという問題点があります。


アテンション(注意)機構

従来のseq2seqモデルでは、エンコーダが出力する情報が一つのベクトルに圧縮されるため、長い入力系列では情報が失われやすいという問題があり、これを改善するためにアテンション機構が発明されました。アテンション機構は、デコーダが出力を生成する際に、エンコーダの各出力を適切に重み付けして利用できるようにし、長い系列でも過去情報をを適切に保持します。

※エンコーダの各出力を重みづけすることで、過去の重要なポイントに注意(Attention)することからこのような名前がついています。

attention_encoder
attention_decoder
アテンションの計算手順

アテンションの計算手順は、以下の5つのステップで構成されます。

下記図引用元記事はこちら
attention_process
  1. エンコーダの出力とデコーダの状態を準備:エンコーダの出力 (encoder \(h_t\)) とデコーダの状態 (decoder \(h_t\)) を準備します。
  2. スコアの計算:デコーダの状態 \(Q\)(Query) とエンコーダの状態 \(K\)(Key) の内積(スコア)を計算します。スコアは、エンコーダの各出力に対する重要度を表します。
    \[ score_t = QK^T \]
  3. 重要度の計算:スコアにソフトマックス関数をかけます。
    \[ \alpha_t = \frac{\exp(score_t)}{\sum_{t=1}^T \exp(score_t)} \]
  4. 重要度とエンコーダの出力を掛ける:重要度とエンコーダの出力 \(V\) を掛けます。
    \[ \alpha_t V = \text{softmax}(QK^T) V \]
  5. Context vectorの計算:重要度とエンコーダの出力の要素ごとの積の総和を、Context vector\(C_t\) と呼びます。
    \[ C_t = \sum_{t=1}^T \text{softmax}(QK^T)V \]

ここで、\(Q\)はデコーダの状態、 \(K\) 、\(V\)はエンコーダの状態、\(\alpha_t\) はアテンションウェイト、\(C_t\) はコンテキストベクトルを表します。

K(キー)とV(バリュー)の違い

エンコーダの出力 \(h_t\) は、入力系列を処理して得られる隠れ状態ベクトルです。この出力から、\(K\) (キー)と\(V\)(バリュー)が生成されますが、異なる重み行列を使用して変換されるため、最終的な数値は異なります。

具体的には、次のように表現されます:

\[ K = W_k h_t \] \[ V = W_v h_t \]

ここで、\(W_k\) と \(W_v\) はそれぞれキーとバリューを生成するための異なる重み行列です。

キーワードまとめ

順伝播の計算、逆伝搬の計算(Back Propagation Through Time; BPTT)、双方向RNN、勾配消失、忘却ゲート、入力ゲート、出力ゲート、LSTM(長期記憶と短期記憶)、GRU、リセットゲート、メモリーセル、エンコーダ・デコーダ、seqence-to-sequence(seq2seq)、アテンション(注意)機構