【深層学習】Graph Attention Networks(GAT)を理解する
深層学習やGNN(Graph Neural Network)関連の論文を読み漁っていると、Graph Attention Network(GAT, グラフアテンションネットワーク)に関する論文を目にすることが多くあると思います。 GAT(は、GNN(Graph Neural Network)やGCN(Graph Convolutional Network)の枠組みに、Attentionの機構を取り入れることで、深層学習の予測精度を上げたモデルとなっています。 この記事では、2018年に発表された元論文に基づいて、GATについて分かりやすく解説していきます。
Graph Attention Networkの概要
GATは、一言で言えば、深層学習でグラフ構造の分類などを行うことができるGCNと呼ばれる手法に、Attentionの機構を取り入れた手法です。 Attentionの機構とは、入力されたデータに対し、どこに注目するかを動的に決定する仕組みで、このAttentionの機構をGCN(グラフ畳み込みネットワーク)に取り入れることで、分類、識別、予測精度を上げています。 このため、GCNについて理解が浅い場合、GATについていまいち理解できない可能性があります。 GCNについては、下記の記事で分かりやすく解説しているので、ぜひ参考にしてみてください。 [画像なし 【深層学習】GCN(グラフ畳み込みネットワーク)をわかりやすく解説する GCN(Graph Convolution Network)は、GNN(Grap…](https://disassemble-channel.com/deep-learning-gcn/) GCNとGATの大きな違いは、ノードの畳み込みの際の係数(これがいわゆるAttention係数です)が大きく違います。 通常のGCNでは、あるノードの次の層の特徴量(潜在変数)を求める際には、隣接するノードの特徴量に線形の重みをかけたものの和に、活性化関数(ReLUなど)を通すことで次の層に値を伝搬させていましたが(※この辺りがよくわからない人は、先にGCNの記事をご覧ください) その隣接ノードの特徴量の線形和を求める際に、隣接ノードをどれも対等に扱っていたのに対し、GATでは、隣接するノードを台頭ではなく、重要度(Attention係数)の概念を取り入れたものになります。 イメージとしては次のようになります。 上図では、ノード1と隣接するノード2, 3, 4のうち、エッジが太線になっているように、ノード3が重要で、次に4のノード、そして2のノードの順に重要度をつけたイメージとなります。 GATは隣接するノードにこのような概念を入れた仕組みとなっています。 GATでは、この重要度を、attention scoreで、$\alpha$としています。 GATでは、各ノードに対して、他のノードからの重要度 $\alpha_$を計算します。上図では、ノード1に対する$\alpha_ ~ \alpha_$を示しています。 ここで、他のノードだけでなく、自分のノードの重要度$\alpha_$も考えることに注意してください。 GATでは、このAttention係数(Attention coefficient)を利用して、畳み込み計算を行います。 GCNの畳み込み計算は次の方に書くことができました。 GCNにおけるMessage Passing
\begin \begin \bm_^ = \sigma \biggl ( \sum_(i) \cup \ < i\>> \frac \sqrt> (\bm^T \bm^) \biggr ) \end \end(1)式は、$l+1$番目の層における、ノード$i$の特徴量の更新式をnode-wiseの表現したものとなっています。(この表現についても、先述のGCNの記事で解説しています) ここで、隣接ノードの特徴量は、$\bm^T \bm^$で、そこに係数$\frac \sqrt>$がついているのが分かりますね。 この係数$\frac \sqrt>$は、GCNの場合だと、ノード$i$とその隣接ノード$j$の字数の平方根をとっているだけですが、GATではこの係数(Attention係数)をもう少し、賢く求めていきましょう、というのが発想となっています。 なんとなくGATとGCNの違いについては分かりましたか? 多分、GCNについて理解していないとなかなか難しいと思うので、GATの前にGCNについて理解を深めることをおすすめします。 では、GATではattention 係数をどのように計算するのか、以降で解説していきます。
GATにおけるAttention係数
では、GATで、各ノード間の繋がりの重要度を示す、アテンション係数 $\alpha$をどのように定義するのかみていきましょう。 まず、特徴量の更新をしたい$i$番目のノードの特徴量を$\bm \in \mathbb^$、ノード$i$に隣接するノード$j$の特徴量を$\bm \in \mathbb^$とすると、GATではこのノード$i$とノード$j$の繋がりのAttention係数$e_$を次のように定義します。 Graph Attention NetworkにおけるAttention係数$e_$の定義
\begin e_ = \bm(\bm, \bm) \endこの$e_$ノード$i$とノード$j$の関連度の重要性を示しています。 ここで、$\bm \in \mathbb^$はGCNでも登場しているように、特徴量を線形変換する重みパラメータです。次元は、$F$から$F’$に変わっていることに注意してください。 また、関数$a$は、shared attention mechanism(attention メカニズム, 注意機構)とよばれており、さまざまな関数系が考えられます。 たとえば、内積のような計算だと、似ているベクトルの内積の演算をすると、値が大きくなりますね。このような関数系がAttention メカニズム $a$のイメージとなります。 ここで、$e_$は決定された後、次のsoftmax関数によって、正規化されます。正規化されたAttention係数が、上述で説明した$\alpha_$に対応します。
\begin \alpha_ = \operatorname(e_) =\frac< \sum_> exp(e_)> \end \begin e_ = \operatorname (\bm[W\bm || W \bm]) \endこれで、Attention係数 $\alpha_$を得ることができました。 ここで、$||$という見かけない演算子がありますが、これは単純に2つのベクトルを連結(concatenationさせているだけです。 難しく考える必要性がありません。今回は、$F’$のベクトルを2つ連結しているので、$W\bm || W \bm$はサイズ$2F’$のベクトルとなります。 論文中では上のような絵が登場します。よく見るとすごく簡単です。 下の丸が全て同一のベクトルで、このベクトルを単層のニューラルネットに入れて、正規化されたAttention係数を得ているだけです。
Graph Attention Network の全体像
関連記事 【深層学習】GATを用いた多変量異常検知 GDNの論文を解説今回は、GAT(Graph Attention Network)を利用した異常検.
グラフニューラルネットワーク 2022/11/27 【GNN】Message Passing Neural Network(MPNN)を解説するMPNN(Message Passing Neural Network)は、GN.
グラフニューラルネットワーク 2022/11/26 【深層学習】GCN(グラフ畳み込みネットワーク)をわかりやすく解説するGCN(Graph Convolution Network)は、GNN(Grap.
グラフニューラルネットワーク 2022/11/23 【GNN】Message Passing Neural Network(MPNN)を解説する ロジスティック回帰の理論と実装をわかりやすく解説 M 機械学習と情報技術航空宇宙の研究者が運営する理工系技術ブログ。 大学教養〜専門レベルの数学・物理・工学を「数式の導出 + Python実装」で解説します。