Last active: December 20, 2019 23:04
-
-
Save juliensimon/fb07ea736fd34bd863decd2c70b3c48e to your computer and use it in GitHub Desktop.
DGL part 4
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch.nn as nn | |
import torch.nn.functional as F | |
def gcn_message(edge):
    """Message function: send each edge's source-node features as its message.

    Args:
        edge: the edges object DGL passes to a message function; only
            ``edge.src['h']`` (the source nodes' ``'h'`` features) is read.

    Returns:
        A dict mapping ``'msg'`` to the source-node ``'h'`` features, which the
        reduce function later reads from each node's mailbox.
    """
    # Read from the parameter actually passed in (the old body referred to an
    # undefined name ``edges``).
    return {'msg': edge.src['h']}
def gcn_reduce(node):
    """Reduce function: sum the messages each node received.

    Args:
        node: the nodes object DGL passes to a reduce function; only
            ``node.mailbox['msg']`` is read.

    Returns:
        A dict mapping ``'h'`` to the per-node sum of received messages. This
        becomes the node's new ``'h'`` feature.
    """
    # Sum over dim 1, i.e. across the messages delivered to each node
    # (presumably the mailbox is (num_nodes, num_msgs, feat_dim) — confirm
    # against the DGL version in use). The tensor method is used so this
    # function does not depend on a module-level ``import torch``, which this
    # file does not have.
    return {'h': node.mailbox['msg'].sum(dim=1)}
# Define the GCNLayer module
class GCNLayer(nn.Module):
    """One graph-convolution layer: sum neighbour features, then a linear map.

    Args:
        in_feats: size of each input node feature vector.
        out_feats: size of each output node feature vector.
    """

    def __init__(self, in_feats, out_feats):
        super(GCNLayer, self).__init__()
        self.linear = nn.Linear(in_feats, out_feats)

    def forward(self, g, inputs):
        """Run one round of message passing over ``g`` and project the result.

        Args:
            g: the DGL graph to propagate over. Any existing ``g.ndata['h']``
                is overwritten during the call and removed before returning.
            inputs: node feature tensor, presumably shaped
                (num_nodes, in_feats) — TODO confirm against callers.

        Returns:
            ``self.linear`` applied to the aggregated node features.
        """
        # Set input features for all nodes
        g.ndata['h'] = inputs
        # Send a message along every edge and aggregate at every node in one
        # fused call. This is equivalent to send(g.edges()) + recv(g.nodes()),
        # but send/recv are no longer available on DGLGraph in DGL >= 0.5.
        g.update_all(gcn_message, gcn_reduce)
        # Pop (rather than read) so the temporary 'h' does not linger on g
        h = g.ndata.pop('h')
        # Perform linear transformation
        return self.linear(h)
# Instantiate the network. GCN is not defined in this snippet; it is
# presumably a two-layer stack built from GCNLayer — confirm where it lives.
net = GCN(
    34,  # input size: one feature per node of the 34-node graph
    5,   # hidden size produced by the first layer (arbitrary choice)
    2,   # output size, one value per class. NOTE(review): likely raw scores
         # (logits), not probabilities, unless GCN applies a softmax.
)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.