Graph Representation Learning using Graph Convolution / Attention Network

Graph Convolution Network and Graph Attention Network

Sarvesh Khetan

A. Supervised Learning Tasks

1. Graph Level Regression Task

There have been several deep learning based architectures proposed to solve this task:

2. Graph Level Classification Task

Just like the regression task, we can also have a classification task: for instance, given a molecule, predict whether or not it is an HIV inhibitor !!

B. Unsupervised / Self Supervised Learning Tasks

Above we saw supervised learning methods to determine graph embeddings, but there are unsupervised learning approaches too !! There are two famous architectures, namely:

Graph Convolution Network (GCN)

Method 1 : Global Pooling Method

Step 1 : Generating Mature Node Embeddings using Graph Convolution Networks (GCN)

In the non-learning method we saw the idea that a node's embedding can be improved using information from its surrounding nodes, but there was no learning involved … it was just a simple aggregation using an element-wise sum / average.

Now the idea is that, before performing the aggregation, we will project the embeddings from one dimension to another using a FFNN, and then perform the aggregation (using sum / average) on the projected embeddings !! The three variants are written out below.

Summation as Aggregation Function

Asymmetric Averaging as Aggregation Function

Symmetric Averaging as Aggregation Function
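For concreteness, here is one standard way to write these three update rules (my notation, not the article's original figures): $h_u$ is the current embedding of node $u$, $N(v)$ the set of neighbours of node $v$, $W$ the learnable projection matrix of the FFNN, and $\sigma$ a non-linearity. The symmetric variant is the normalization used in the GCN paper by Kipf and Welling.

$$h_v' = \sigma\Big(W \sum_{u \in N(v)} h_u\Big) \qquad \text{(summation)}$$

$$h_v' = \sigma\Big(W \, \frac{1}{|N(v)|} \sum_{u \in N(v)} h_u\Big) \qquad \text{(asymmetric averaging)}$$

$$h_v' = \sigma\Big(W \sum_{u \in N(v)} \frac{h_u}{\sqrt{|N(v)|\,|N(u)|}}\Big) \qquad \text{(symmetric averaging)}$$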

Step 2 : Aggregation — Combining Mature Node Embeddings to get Graph Embeddings

The same aggregation methods that you saw in the non-learning method apply here as well.

Code Implementation (From Scratch and using Pytorch Geometric Library)

Now once you have the graph embeddings, you can use any regression algorithm like linear regression / decision trees / neural networks / …. to solve the regression problem. Below I have used a neural network on top of the graph embeddings; a minimal sketch of the whole pipeline follows.
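The gist itself is not reproduced here, so this is a minimal from-scratch sketch of what such an implementation could look like; the names (GCNLayer, GraphRegressor), the hidden size, and the choice of symmetric averaging are my own illustrative assumptions, not the article's exact code.

```python
import torch
import torch.nn as nn

class GCNLayer(nn.Module):
    """One GCN layer: project with a linear map, then symmetric-normalized neighbourhood averaging."""
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.proj = nn.Linear(in_dim, out_dim)

    def forward(self, x, adj):
        # adj: (N, N) adjacency matrix; add self-loops so a node keeps its own information
        adj = adj + torch.eye(adj.size(0), device=adj.device)
        deg = adj.sum(dim=1)                          # node degrees
        d_inv_sqrt = torch.diag(deg.pow(-0.5))        # D^{-1/2}
        norm_adj = d_inv_sqrt @ adj @ d_inv_sqrt      # D^{-1/2} A D^{-1/2}
        return torch.relu(norm_adj @ self.proj(x))    # aggregate the projected neighbour embeddings

class GraphRegressor(nn.Module):
    """Stack GCN layers, global-mean-pool the mature node embeddings, then regress with a small MLP."""
    def __init__(self, in_dim, hidden_dim=64):
        super().__init__()
        self.gcn1 = GCNLayer(in_dim, hidden_dim)
        self.gcn2 = GCNLayer(hidden_dim, hidden_dim)
        self.head = nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, 1))

    def forward(self, x, adj):
        h = self.gcn2(self.gcn1(x, adj), adj)   # step 1: mature node embeddings
        g = h.mean(dim=0)                       # step 2: global mean pooling -> graph embedding
        return self.head(g)                     # regression head on the graph embedding

# The same pipeline in PyTorch Geometric would roughly use GCNConv + global_mean_pool:
# from torch_geometric.nn import GCNConv, global_mean_pool
```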

Method 2 : Dummy / Virtual / Super — Node Method

Same as what we saw in the non-learning method using a GNN, except that here we use a GCN to mature the node embeddings; a rough sketch of the super-node trick is shown below.
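As a hedged illustration of the super-node idea (the helper name add_virtual_node and the zero-initialised features are my assumptions), one could grow the adjacency matrix by one node that is connected to every real node, run the GCN as usual, and read off that node's mature embedding as the graph embedding.

```python
import torch

def add_virtual_node(x, adj):
    """Append a super node connected to every real node (illustrative helper, not from the article)."""
    n, d = x.shape
    x = torch.cat([x, torch.zeros(1, d, device=x.device)], dim=0)  # features of the virtual node (zeros here)
    adj = torch.nn.functional.pad(adj, (0, 1, 0, 1))               # grow adjacency to (n+1, n+1)
    adj[n, :n] = 1.0                                               # virtual node -> every real node
    adj[:n, n] = 1.0                                               # every real node -> virtual node
    return x, adj

# After running the GCN layers, the last row h[-1] is the virtual node's
# mature embedding and can be used directly as the graph embedding.
```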

Method 3 : Hierarchical Pooling Method

Same as what we saw in the non-learning method using a GNN, except that here we use a GCN to mature the node embeddings.

Code Implementation (From Scratch and using Pytorch Geometric Library)
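The gist is again not embedded here, so below is a minimal PyTorch Geometric sketch of hierarchical pooling using TopKPooling between GCNConv layers; the layer sizes and the choice of TopKPooling are my assumptions (the original code may have used a different pooling operator such as DiffPool).

```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, TopKPooling, global_mean_pool

class HierarchicalGCN(torch.nn.Module):
    def __init__(self, in_dim, hidden_dim=64):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hidden_dim)
        self.pool1 = TopKPooling(hidden_dim, ratio=0.5)   # keep the top 50% of nodes
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.pool2 = TopKPooling(hidden_dim, ratio=0.5)
        self.lin = torch.nn.Linear(hidden_dim, 1)

    def forward(self, x, edge_index, batch):
        x = F.relu(self.conv1(x, edge_index))
        x, edge_index, _, batch, _, _ = self.pool1(x, edge_index, batch=batch)  # coarsen the graph
        x = F.relu(self.conv2(x, edge_index))
        x, edge_index, _, batch, _, _ = self.pool2(x, edge_index, batch=batch)  # coarsen again
        return self.lin(global_mean_pool(x, batch))        # read out the pooled graph embedding
```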

Graph Attention Network (GAT)

Method 1 : Global Pooling Method

Step 1 : Generating Mature Node Embeddings using GAT

Note : The normalization formula has a correction; it should be exp(a) everywhere instead of exp(alpha).

Alternatively, instead of using a neural network to calculate the attention scores, they could also have simply used a dot product !!
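For reference, the standard single-head GAT formulation (written with $\exp(e)$ inside the softmax, consistent with the correction above) is:

$$e_{ij} = \text{LeakyReLU}\big(a^{\top} [\, W h_i \,\|\, W h_j \,]\big), \qquad
\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k \in N(i)} \exp(e_{ik})}, \qquad
h_i' = \sigma\Big(\sum_{j \in N(i)} \alpha_{ij} \, W h_j\Big)$$

The dot-product alternative mentioned above would simply replace the first expression with $e_{ij} = (W h_i)^{\top} (W h_j)$.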

Matrix Implementation

Above I have explained only single-head attention, while real-life implementations use multi-head attention !!

Step 2 : Aggregation — Combining Mature Node Embeddings to get Graph Embeddings

The same aggregation methods that you saw in the non-learning method apply here as well.

Code Implementation (using Pytorch Geometric Library)
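Since the gist is not reproduced here, a minimal PyTorch Geometric sketch using GATConv with multi-head attention plus global mean pooling might look like the following; the sizes, the number of heads, and the task head are illustrative assumptions.

```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import GATConv, global_mean_pool

class GATGraphModel(torch.nn.Module):
    def __init__(self, in_dim, hidden_dim=64, heads=4):
        super().__init__()
        # heads=4 -> multi-head attention; the outputs of the heads are concatenated
        self.gat1 = GATConv(in_dim, hidden_dim, heads=heads, concat=True)
        self.gat2 = GATConv(hidden_dim * heads, hidden_dim, heads=1, concat=False)
        self.lin = torch.nn.Linear(hidden_dim, 1)

    def forward(self, x, edge_index, batch):
        x = F.elu(self.gat1(x, edge_index))     # step 1: mature node embeddings via attention
        x = F.elu(self.gat2(x, edge_index))
        g = global_mean_pool(x, batch)          # step 2: global pooling -> graph embeddings
        return self.lin(g)                      # task head (regression / classification logits)
```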

Method 2 : Dummy / Virtual / Super — Node Method

Same as what we saw in the non-learning method using a GNN, except that here we use a GAT to mature the node embeddings.

Method 3 : Hierarchical Pooling Method

Same as what we saw in the non-learning method using a GNN, except that here we use a GAT to mature the node embeddings.

Unsupervised / Self-Supervised Learning for Graph Embeddings

DeepWalk (2014) by Stony Brook University

Step 1 : Generating Mature Node Embeddings

For a given graph, we first generate sequences of nodes using random walks.

Each random walk denotes one datapoint used to train the neural network.

Now these sequences of nodes can be thought of as sequences of words forming sentences.

Hence you can now use all the techniques from NLP, like next-word prediction / masked-word prediction / …… The authors of this paper used the Word2Vec model architecture to predict the next word / node (you could also use other SOTA techniques available in NLP today for next-word prediction, like LSTMs / Transformers / ….). A minimal sketch of this pipeline is shown below.
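As a rough sketch (the use of networkx and gensim, and all hyperparameters, are my assumptions rather than the paper's reference code), you can generate uniform random walks and feed them to Word2Vec as "sentences":

```python
import random
import networkx as nx
from gensim.models import Word2Vec

def random_walk(graph, start, walk_length=10):
    """Uniform random walk starting at `start`; each walk is one 'sentence' of node ids."""
    walk = [start]
    for _ in range(walk_length - 1):
        neighbors = list(graph.neighbors(walk[-1]))
        if not neighbors:
            break
        walk.append(random.choice(neighbors))
    return [str(n) for n in walk]               # Word2Vec expects string tokens

graph = nx.karate_club_graph()                  # toy example graph
walks = [random_walk(graph, node) for node in graph.nodes() for _ in range(20)]

# Skip-gram Word2Vec over the walks gives the mature node embeddings
model = Word2Vec(walks, vector_size=64, window=5, min_count=1, sg=1, epochs=5)
node_embedding = model.wv[str(0)]               # embedding of node 0
```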

Step 2 : Aggregation — Combining Mature Node Embeddings to get Graph Embeddings

Now once you have the mature node embeddings, you can do an element-wise sum / average over them to get the embedding of the entire graph !!

Node2Vec (2016) by Stanford

Step 1 : Generating Mature Node Embeddings

This algorithm uses some of the ideas presented by DeepWalk but goes a step deeper. Instead of using the vanilla random walks introduced by DeepWalk, it uses biased random walks. A biased random walk uses a combination of the DFS and BFS algorithms to extract the walks, and this combination is controlled by two parameters, p (return parameter) and q (in-out parameter).

Basically, if p is large the walk rarely returns to the node it just came from, so it explores more of the graph, and if p is small the walk stays local. A similar but opposite behaviour happens with q: if q is small the walk does exploration (DFS-like behaviour), and if q is large it stays local (BFS-like behaviour). More details can be found in the original paper.

As a high-level overview, the simplest analogy for a random walk is walking itself. Imagine that each step you take is determined probabilistically, so at each point in time you have moved in a certain direction based on a probabilistic outcome. The algorithm explores the relationship between each step you take and its distance from the initial starting point.

Now you might wonder how these probabilities of moving from one node to another are calculated. Node2Vec introduces the following formula for determining the probability of moving to node x given that you were previously at node v.
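That formula (with $c_i$ denoting the $i$-th node of the walk and $E$ the edge set) is:

$$P(c_i = x \mid c_{i-1} = v) = \begin{cases} \dfrac{\pi_{vx}}{Z} & \text{if } (v, x) \in E \\[4pt] 0 & \text{otherwise} \end{cases}$$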

where Z is the normalization constant and π_vx is the unnormalized transition probability between nodes v and x [4]. Clearly, if there is no edge connecting v and x, then the probability will be 0; but if there is an edge, we get a normalized probability of going from v to x.

The paper states that the easiest way to introduce a bias to influence the random walks would be to have a weight associated with each edge. However, that wouldn't work in the case of unweighted networks. To resolve this, the authors introduce a guided random walk governed by two parameters p and q: p controls the probability of the walk going back to the previous node, and q controls the probability of the walk passing through a previously unseen part of the graph [4].
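Concretely, the unnormalized transition probability is set to $\pi_{vx} = \alpha_{pq}(t, x) \cdot w_{vx}$, where $t$ is the node visited just before $v$ and $w_{vx}$ is the edge weight (1 for unweighted graphs), with the search bias defined as:

$$\alpha_{pq}(t, x) = \begin{cases} \dfrac{1}{p} & \text{if } d_{tx} = 0 \\[4pt] 1 & \text{if } d_{tx} = 1 \\[4pt] \dfrac{1}{q} & \text{if } d_{tx} = 2 \end{cases}$$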

where d_tx represents the shortest-path distance between nodes t and x.

Step 2 : Aggregation — Combining Mature Node Embeddings to get Graph Embeddings

Now once you have the mature node embeddings, you can do an element-wise sum / average over them to get the embedding of the entire graph !! A rough end-to-end sketch follows below.
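As a hedged end-to-end sketch: PyTorch Geometric ships a Node2Vec implementation of these biased walks plus the skip-gram objective (biased sampling with p, q ≠ 1 additionally needs the torch-cluster package); the toy graph, hyperparameters, and mean-pooling step below are illustrative assumptions, not values from the paper.

```python
import torch
from torch_geometric.nn import Node2Vec

# tiny toy graph: an undirected path 0-1-2-3, stored as a (2, num_edges) edge_index
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                           [1, 0, 2, 1, 3, 2]])

model = Node2Vec(edge_index, embedding_dim=64, walk_length=20, context_size=10,
                 walks_per_node=10, p=1.0, q=0.5, num_negative_samples=1, sparse=True)
loader = model.loader(batch_size=128, shuffle=True)
optimizer = torch.optim.SparseAdam(list(model.parameters()), lr=0.01)

model.train()
for epoch in range(10):
    for pos_rw, neg_rw in loader:               # positive / negative random-walk samples
        optimizer.zero_grad()
        loss = model.loss(pos_rw, neg_rw)       # skip-gram objective with negative sampling
        loss.backward()
        optimizer.step()

node_embeddings = model().detach()              # (num_nodes, 64) mature node embeddings
graph_embedding = node_embeddings.mean(dim=0)   # step 2: element-wise average pooling
```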
