logo

GNN

图神经网络

王哲峰 / 2022-07-15


目录

TODO

GNN 解决什么问题

近十年来(从 2012 年 AlexNet 开始计算),深度学习在计算机视觉(CV)和自然语言处理(NLP)等领域得到的长足的发展, 深度神经网络对于图像和文字等欧几里得数据(Euclidean data)可以进行较好的处理,之所以被称为欧几里得数据, 是由于这类数据位于 $n$ 维欧几里得空间 $\mathbb{R}^{n}$ 中(如 AlexNet 将所有图像的尺寸都预处理成 224x224x3)。 常见的有表格二维的欧几里得数据,RGB 图像数据是三维欧几里得数据,长宽两个维度加一个颜色/通道维度,如果再加上 batch,就是四维

然而,现实世界是复杂的,如社交网络,一个人的朋友数量是不固定的,也很难排个顺序, 这类复杂的非欧几里得数据(non-Euclidean),没有上下左右,没有顺序,没有坐标参考点, 难以用方方正正的(grid-like)矩阵/张量表示,为了把不规则的脚(非欧数据)穿进标准的鞋(神经网络)里, 之前干了不少削足适履的事,效果不太好,于是,问题变成了:能否设计一种新的鞋,使它能适合不规则的脚呢?

非欧数据的场景很多,除了上面提到的社交网络,其他例子如:计算机网络,病毒传播路径, 交通运输网络(地铁网络),食物链,粒子网络(物理学家描述基本粒子生存的关系,有点类似家谱), 说到家谱,家谱也是,(生物)神经网络(神经网络本来就是生物学术语,现在人工神经网络 ANN 太多, 鸠占鹊巢了), 基因控制网络,分子结构,知识图谱,推荐系统,论文引用网络等等。 这些场景的非欧数据用图(Graph)来表达是最合适的,但是, 经典的深度学习网络(ANN, CNN, RNN)却难以处理这些非欧数据,于是, 图神经网络(GNN)应运而生,GNN 以图作为输入,输出各种下游任务的预测结果

下游任务包括但不限于:

图介绍

图(Graph)是图论的研究对象,图论是欧拉在研究哥尼斯堡七桥问题过程中,创造出来的新数学分支

网络(Graph / Network)视为一个系统,以 $G(N, E)$ 表示,由两种元素组成: 顶点/节点(Vertex/Node),以 $N$ 表示,和边/链接(Edge/Link),以 $E$ 表示。 顶点和边具有属性(Attribute),边可能有方向(有向图 Directed Graph)。 社交网络中,人是顶点,人和人之间的关系是边,人/顶点的属性比如年龄、性别、职业、爱好等构成了一个向量,类似的,边也可用向量来表示

img

图表示

图本身也具有表达其自身的全局属性,来描述整个图

邻接矩阵

如何用数学表示图中顶点的关系呢?最常见的方法是邻接矩阵(Adjacency Matrix), 下图中 A 和 B、C、E 相连,故第一行和第一列对应的位置为 1,其余位置为 0

img

如果将图片的像素表达为图,下左图表示图片的像素值,深色表示 1,浅色表示 0, 右图为该图片对应的图,中间为对应的邻接矩阵,蓝色表示 1,白色表示 0。 随着图的顶点数($n$)增多,邻接矩阵矩阵的规模($n^{2}$)迅速增大, 一张百万($10^{6}$)像素的照片, 对应的邻接矩阵的大小就是($10^{6} \times 10^{6} = 10^{12}$), 计算时容易内存溢出,而且其中大多数值为 0,很稀疏

img

文本也可以用邻接矩阵表示,但是问题也是类似的,很大很稀疏:

img

邻接列表

也可以选用边来表示图,即邻接列表(Adjacency List),这可以大幅减少对空间的消耗,因为实际的边比所有可能的边(邻接矩阵)数量往往小很多

img

类似的例子有很多:

神经网络特点

节点特征的表达学习

消息传递

节点嵌入的计算

附录

图表示实现

class Vertex:
    def __init__(self, vertex):
        self.name = vertex
        self.neighbors = []
        
    def add_neighbor(self, neighbor):
        if isinstance(neighbor, Vertex):
            if neighbor.name not in self.neighbors:
                self.neighbors.append(neighbor.name)
                neighbor.neighbors.append(self.name)
                self.neighbors = sorted(self.neighbors)
                neighbor.neighbors = sorted(neighbor.neighbors)
        else:
            return False
        
    def add_neighbors(self, neighbors):
        for neighbor in neighbors:
            if isinstance(neighbor, Vertex):
                if neighbor.name not in self.neighbors:
                    self.neighbors.append(neighbor.name)
                    neighbor.neighbors.append(self.name)
                    self.neighbors = sorted(self.neighbors)
                    neighbor.neighbors = sorted(neighbor.neighbors)
            else:
                return False
        
    def __repr__(self):
        return str(self.neighbors)


class Graph:
    def __init__(self):
        self.vertices = {}
    
    def add_vertex(self, vertex):
        if isinstance(vertex, Vertex):
            self.vertices[vertex.name] = vertex.neighbors

            
    def add_vertices(self, vertices):
        for vertex in vertices:
            if isinstance(vertex, Vertex):
                self.vertices[vertex.name] = vertex.neighbors
            
    def add_edge(self, vertex_from, vertex_to):
        if isinstance(vertex_from, Vertex) and isinstance(vertex_to, Vertex):
            vertex_from.add_neighbor(vertex_to)
            if isinstance(vertex_from, Vertex) and isinstance(vertex_to, Vertex):
                self.vertices[vertex_from.name] = vertex_from.neighbors
                self.vertices[vertex_to.name] = vertex_to.neighbors
                
    def add_edges(self, edges):
        for edge in edges:
            self.add_edge(edge[0],edge[1])          
    
    def adjacencyList(self):
        if len(self.vertices) >= 1:
                return [str(key) + ":" + str(self.vertices[key]) for key in self.vertices.keys()]  
        else:
            return dict()
        
    def adjacencyMatrix(self):
        if len(self.vertices) >= 1:
            self.vertex_names = sorted(g.vertices.keys())
            self.vertex_indices = dict(zip(self.vertex_names, range(len(self.vertex_names)))) 
            import numpy as np
            self.adjacency_matrix = np.zeros(shape=(len(self.vertices),len(self.vertices)))
            for i in range(len(self.vertex_names)):
                for j in range(i, len(self.vertices)):
                    for el in g.vertices[self.vertex_names[i]]:
                        j = g.vertex_indices[el]
                        self.adjacency_matrix[i,j] = 1
            return self.adjacency_matrix
        else:
            return dict()              

                     
def graph(g):
    """
    Function to print a graph as adjacency list and adjacency matrix.
    """
    return str(g.adjacencyList()) + '\n' + '\n' + str(g.adjacencyMatrix())

##############################################################################
a = Vertex('A')
b = Vertex('B')
c = Vertex('C')
d = Vertex('D')
e = Vertex('E')

a.add_neighbors([b, c, e]) 
b.add_neighbors([a, c])
c.add_neighbors([b, d, a, e])
d.add_neighbor(c)
e.add_neighbors([a, c])
        
g = Graph()
print(graph(g))
print()
g.add_vertices([a,b,c,d,e])
g.add_edge(b,d)
print(graph(g))
[
    "A:['B', 'C', 'E']", 
    "C:['A', 'B', 'D', 'E']", 
    "B:['A', 'C', 'D']", 
    "E:['A', 'C']", 
    "D:['B', 'C']"
]

[[ 0.  1.  1.  0.  1.]
 [ 1.  0.  1.  1.  0.]
 [ 1.  1.  0.  1.  1.]
 [ 0.  1.  1.  0.  0.]
 [ 1.  0.  1.  0.  0.]]

参考