DGL part 4
import torch
import torch.nn as nn
import torch.nn.functional as F
def gcn_message(edges):
    # In: a batch of graph edges
    # Out: a message containing the features of each source node
    return {'msg' : edges.src['h']}

def gcn_reduce(nodes):
    # In: a batch of graph nodes
    # Out: the new 'h' features for each node, obtained by summing the received messages
    return {'h' : torch.sum(nodes.mailbox['msg'], dim=1)}
# Define the GCNLayer module
class GCNLayer(nn.Module):
    def __init__(self, in_feats, out_feats):
        super(GCNLayer, self).__init__()
        self.linear = nn.Linear(in_feats, out_feats)

    def forward(self, g, inputs):
        # In: the graph and the input node features
        # Set the input features on all nodes
        g.ndata['h'] = inputs
        # Trigger message passing on all edges
        g.send(g.edges(), gcn_message)
        # Trigger aggregation at all nodes
        g.recv(g.nodes(), gcn_reduce)
        # Get the updated node features
        h = g.ndata.pop('h')
        # Apply the linear transformation
        return self.linear(h)
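
# The GCN module instantiated below is not defined in this snippet; here is a
# minimal sketch, assuming (as the comments below suggest) that it simply
# stacks two GCNLayer instances with a ReLU non-linearity in between.
class GCN(nn.Module):
    def __init__(self, in_feats, hidden_size, num_classes):
        super(GCN, self).__init__()
        self.gcn1 = GCNLayer(in_feats, hidden_size)
        self.gcn2 = GCNLayer(hidden_size, num_classes)

    def forward(self, g, inputs):
        # First graph convolution followed by a ReLU
        h = torch.relu(self.gcn1(g, inputs))
        # Second graph convolution outputs one score per class
        return self.gcn2(g, h)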
# We start with 34 input dimensions (a one-hot feature vector per node in the 34-node graph).
# The first layer shrinks them to 5 dimensions (an arbitrary hidden size).
# The second layer outputs 2 dimensions (one score per class).
net = GCN(34, 5, 2)
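
# A minimal usage sketch, not part of the original gist. Assumptions: a
# DGL 0.4-era API (matching the send/recv calls above) and a toy 34-node
# ring graph standing in for the karate club graph used in this series.
import dgl

g = dgl.DGLGraph()
g.add_nodes(34)
# Connect node i to node i+1 so every node receives exactly one message
src = list(range(34))
dst = [(i + 1) % 34 for i in range(34)]
g.add_edges(src, dst)

inputs = torch.eye(34)      # one-hot feature vector for each node
logits = net(g, inputs)     # shape (34, 2): one score per class for each node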