"""models to learn tensor representations of graphs with specific relational
properties using graph neural networks"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.nn as pyg_nn
import torch_geometric.utils as pyg_utils
from tart.utils.model_utils import get_device, get_torch_tensor_type
class Preprocess(nn.Module):
    """Preprocessor for graph features.

    Gathers the tensorized node and edge features named in ``args`` and
    concatenates them along the last dimension. Every configured feature
    name ``<name>`` is looked up on the batch under the key ``"<name>_t"``.

    Node features are appended to the existing ``batch.node_feature``; edge
    features are concatenated into ``batch.edge_feature``.

    Args:
        dim_in: width of the incoming ``batch.node_feature`` tensor.
        args: configuration object providing ``node_feats``,
            ``node_feat_dims``, ``edge_feats`` and ``edge_feat_dims``.
    """

    def __init__(self, dim_in, args):
        super().__init__()
        self.dim_in = dim_in
        self.node_feat = args.node_feats
        self.node_feat_dims = args.node_feat_dims
        self.edge_feat = args.edge_feats
        self.edge_feat_dim = args.edge_feat_dims

    @property
    def dim_out(self):
        """Width of ``batch.node_feature`` after :meth:`forward` has run."""
        return self.dim_in + sum(self.node_feat_dims)

    @staticmethod
    def _collect(batch, keys, kind):
        """Return the ``"<key>_t"`` tensors of ``batch`` for every key in ``keys``.

        Args:
            batch: object indexable by feature key.
            keys: feature names (without the ``"_t"`` suffix).
            kind: ``"Node"`` or ``"Edge"``; only used in the error message.

        Raises:
            ValueError: if a requested feature is ``None`` on the batch.
        """
        tensors = []
        for key in keys:
            feat = batch[key + "_t"]
            if feat is None:
                raise ValueError("{} feature {} is None".format(kind, key))
            if feat.dim() == 3:
                # reshape [batch_size, 1, n] to [batch_size, n]; squeeze(1) is
                # a no-op when the middle axis is not of size 1
                feat = feat.squeeze(1)
            tensors.append(feat)
        return tensors

    def forward(self, batch):
        """Concatenate the configured features onto ``batch`` and return it.

        ``batch`` is modified in place: ``node_feature`` is always rewritten,
        ``edge_feature`` only when at least one edge feature is configured
        (``torch.cat`` rejects an empty list, so in that case any existing
        ``edge_feature`` is left untouched).

        Raises:
            ValueError: if a configured feature is ``None`` on the batch.
        """
        node_feat_list = [batch.node_feature]
        node_feat_list.extend(self._collect(batch, self.node_feat, "Node"))
        edge_feat_list = self._collect(batch, self.edge_feat, "Edge")

        # .float() casts to float32 without leaving the tensors' device
        # (.type(torch.FloatTensor) would force a copy to the CPU).
        batch.node_feature = torch.cat(node_feat_list, dim=-1).float()
        if edge_feat_list:
            batch.edge_feature = torch.cat(edge_feat_list, dim=-1).float()
        return batch
class SubgraphEmbedder(nn.Module):
    """Model for order embeddings.

    Holds a GNN ``encoder`` that embeds graphs into a vector space and scores
    (target, query) embedding pairs by how much the query violates the order
    constraint ``query <= target`` (element-wise).

    A small ``classifier`` MLP is also constructed; no method of this class
    uses it, but it is kept so existing checkpoints keep loading.

    Args:
        input_dim: node feature width passed to the encoder.
        hidden_dim: hidden and output width of the encoder.
        args: configuration object; ``args.margin`` is read here, the rest is
            forwarded to ``BasicGNN``.
    """

    def __init__(self, input_dim, hidden_dim, args):
        super().__init__()
        self.encoder = BasicGNN(input_dim, hidden_dim, hidden_dim, args)
        self.margin = args.margin
        self.use_intersection = False
        self.classifier = nn.Sequential(nn.Linear(1, 2), nn.LogSoftmax(dim=1))

    def forward(self, emb_targets, emb_queries):
        """Return the two embeddings unchanged, as a ``(targets, queries)`` pair."""
        return emb_targets, emb_queries

    @staticmethod
    def _order_violation(emb_targets, emb_queries):
        """Per-pair violation ``sum(max(0, z_q - z_t) ** 2)`` over dim 1.

        ``torch.clamp`` works on whatever device the embeddings live on, so
        no separately allocated zero tensor is needed.
        """
        return torch.sum(torch.clamp(emb_queries - emb_targets, min=0) ** 2, dim=1)

    def criterion(self, pred, labels):
        """Loss function for subgraph ordering in embedding space.

        error = amount of violation (if b is a subgraph of a).
        For + examples, train to minimize error -> 0;
        for - examples, train the error to be at least ``self.margin``.

        Args:
            pred: pair ``(emb_targets, emb_queries)``.
            labels: per-pair labels; entries equal to 0 mark negative pairs.

        Returns:
            Scalar tensor: the summed loss over all pairs.
        """
        emb_targets, emb_queries = pred
        error = self._order_violation(emb_targets, emb_queries)

        # rewrite loss for -ve examples: hinge on the margin
        negatives = labels == 0
        error[negatives] = torch.clamp(self.margin - error, min=0)[negatives]

        return torch.sum(error)

    def predict(self, pred):
        """Inference API: score how strongly queries violate the subgraph order.

        Args:
            pred (List<emb_t, emb_q>): embeddings of pairs of graphs

        Returns:
            1-D tensor of violation scores, one per pair; 0 means the query
            embedding is element-wise <= the target embedding.
        """
        emb_targets, emb_queries = pred
        return self._order_violation(emb_targets, emb_queries)

    def predictv2(self, pred, dim_ratio=0.1):
        """Inference API v2: predict if queries are subgraphs of targets.

        A pair is predicted as "subgraph" when the number of embedding
        dimensions that violate the order constraint is strictly below
        ``int(dim_ratio * emb_size)``.

        Args:
            pred (List<emb_t, emb_q>): embeddings of pairs of graphs, each of
                shape ``(batch_size, emb_size)``.
            dim_ratio: fraction of the embedding dimensions allowed to
                violate the constraint (default 0.1, i.e. 10%).

        Returns:
            ``(predictions, scores)``: ``predictions`` is a
            ``(batch_size, 1)`` float tensor of 0/1 flags, ``scores`` is a
            ``(batch_size,)`` tensor holding the fraction of non-violating
            dimensions.
        """
        emb_targets, emb_queries = pred
        batch_size, emb_size = emb_targets.shape
        max_vio_dims = int(dim_ratio * emb_size)

        subtract = emb_queries - emb_targets
        assert subtract.shape == (batch_size, emb_size)

        # 1 if violating the order constraint. .float() stays on the
        # embeddings' device, so no CPU/GPU specific tensor type is required.
        indicator = (subtract > 0).float()

        # count the number of violating dimensions per pair
        indicator_sum = indicator.sum(dim=1)
        assert indicator_sum.shape == (batch_size,)

        # 1 => violations in fewer than max_vio_dims dimensions => subgraph
        # 0 => otherwise => not a subgraph
        predictions = (indicator_sum < max_vio_dims).view(-1, 1).float()
        scores = 1 - indicator_sum / emb_size

        return predictions, scores
class BasicGNN(nn.Module):
    """Basic GNN model with the following configurable options:

    - number of layers (``args.n_layers``)
    - aggregation type (``args.agg_type``: ``"GCN"``, ``"GIN"`` or ``"GINE"``)
    - skip connections (``args.skip``: ``"all"``, ``"learnable"`` or anything
      else for plain layer-to-layer chaining)

    Pipeline: feature preprocessing -> input linear layer -> ``n_layers``
    aggregation layers -> sum pooling with ``global_add_pool`` -> output MLP.

    Args:
        input_dim: width of the raw node feature tensor.
        hidden_dim: width of every aggregation layer.
        output_dim: width of the returned embedding.
        args: configuration object providing ``dropout``, ``n_layers``,
            ``skip``, ``agg_type``, ``node_feats``, ``node_feat_dims``,
            ``edge_feats`` and ``edge_feat_dims``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, args):
        super().__init__()
        self.dropout = args.dropout
        self.n_layers = args.n_layers
        self.skip = args.skip
        self.agg_type = args.agg_type
        self.node_feats = args.node_feats
        self.node_feat_dims = args.node_feat_dims
        self.edge_feats = args.edge_feats
        self.edge_feat_dims = args.edge_feat_dims

        # add a preprocessor; it widens the node features
        self.feat_preprocess = Preprocess(input_dim, args)
        input_dim = self.feat_preprocess.dim_out

        # MODULE: INPUT
        self.pre_mp = nn.Sequential(nn.Linear(input_dim, hidden_dim))

        # MODULES(k): Graph Aggregation/Convolution
        agg_module = self.get_agg_layer(type=args.agg_type)
        self.aggregates = nn.ModuleList()

        # add learnable skip params: row i gates the inputs of layer i
        if args.skip == "learnable":
            self.learnable_skip = nn.Parameter(torch.ones(self.n_layers, self.n_layers))

        dense_skip = args.skip in ("all", "learnable")
        for layer in range(args.n_layers):
            # with dense skips a layer receives the input embedding plus the
            # output of every preceding layer, each hidden_dim wide
            hidden_input_dim = hidden_dim * (layer + 1) if dense_skip else hidden_dim
            self.aggregates.append(agg_module(hidden_input_dim, hidden_dim))

        # MODULE: OUTPUT
        post_input_dim = hidden_dim * (args.n_layers + 1)
        self.post_mp = nn.Sequential(
            nn.Linear(post_input_dim, hidden_dim),
            nn.Dropout(args.dropout),
            nn.LeakyReLU(0.1),
            nn.Linear(hidden_dim, output_dim),
            nn.ReLU(),
            # the previous layer emits output_dim features, so this layer must
            # consume output_dim (identical to before when
            # output_dim == hidden_dim, which is the only case that ran)
            nn.Linear(output_dim, 256),
            nn.ReLU(),
            nn.Linear(256, output_dim),
        )

    def get_agg_layer(self, type):
        """Return a factory ``f(in_dim, out_dim)`` for the aggregation layer.

        Args:
            type: one of ``"GCN"``, ``"GIN"`` or ``"GINE"``.

        Raises:
            ValueError: if ``type`` is not a recognized layer type.
        """
        # graph convolution
        if type == "GCN":
            return pyg_nn.GCNConv
        # graph isomorphism + weighted edges
        if type == "GIN":
            return lambda i, h: WeightedGINConv(nn.Sequential(nn.Linear(i, h), nn.ReLU(), nn.Linear(h, h)))
        # graph isomorphism net + edge features
        if type == "GINE":
            return lambda i, h: pyg_nn.GINEConv(
                nn.Sequential(nn.Linear(i, h), nn.ReLU(), nn.Linear(h, h)),
                edge_dim=sum(self.edge_feat_dims),
            )
        # fail here instead of returning None, which would only surface later
        # as "'NoneType' object is not callable"
        raise ValueError("unrecognized model type: {!r}".format(type))

    def forward(self, data):
        """Embed every graph of a batch.

        Args:
            data: batch object exposing ``node_feature``, ``edge_index``,
                ``edge_feature`` and ``batch``. It is preprocessed in place
                (and tagged with ``preprocessed = True``) the first time it
                is seen.

        Returns:
            Tensor with one row per graph, produced by the output MLP from
            the sum-pooled node embeddings.
        """
        # preprocess (if reqd)
        if self.feat_preprocess is not None and not hasattr(data, "preprocessed"):
            data = self.feat_preprocess(data)
            data.preprocessed = True

        x = data.node_feature
        # the width must match what the input layer was built for
        assert x.shape[1] == self.pre_mp[0].in_features, "node feature dim mismatch"

        # MOVE TO DEVICE (including the graph assignment vector, which is
        # combined with the node embeddings during pooling)
        device = get_device()
        x = x.to(device)
        edge_index = data.edge_index.to(device)
        # NOTE(review): edge_attr is passed as the third positional argument
        # to every aggregator; confirm its shape suits the chosen layer type.
        edge_attr = data.edge_feature.to(device)
        batch = data.batch
        if batch is not None:
            batch = batch.to(device)

        # pre mlp
        x = self.pre_mp(x)
        emb = x  # [n_nodes, hidden * (layers so far + 1)]
        all_emb = x.unsqueeze(1)  # [n_nodes, layers so far + 1, hidden]

        # aggregate-combine loop (k iterations)
        for i, aggregate in enumerate(self.aggregates):
            # aggregate
            if self.skip == "learnable":
                skip_vals = self.learnable_skip[i, : i + 1].unsqueeze(0).unsqueeze(-1)
                curr_emb = all_emb * torch.sigmoid(skip_vals)  # select inputs
                curr_emb = curr_emb.view(x.size(0), -1)
                x = aggregate(curr_emb, edge_index, edge_attr)
            elif self.skip == "all":
                x = aggregate(emb, edge_index, edge_attr)
            else:
                x = aggregate(x, edge_index, edge_attr)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)

            # combine
            emb = torch.cat((emb, x), 1)
            if self.skip == "learnable":
                # NOTE: record this layer's output so later layers gate real
                # layer outputs; without this the gates only ever weighted
                # broadcast copies of the input embedding. This changes the
                # numerical output of models using skip == "learnable".
                all_emb = torch.cat((all_emb, x.unsqueeze(1)), 1)

        # pooling
        emb = pyg_nn.global_add_pool(emb, batch)

        # post MLP
        emb = self.post_mp(emb)
        return emb

    def loss(self, pred, label):
        """Negative log-likelihood loss of ``pred`` against ``label``."""
        return F.nll_loss(pred, label)
# GINConv + weighted edges adapted from NeuroMatch
# [https://arxiv.org/abs/2007.03092]
#
# UPDATE: GINConv does not take into account any edge features.
# However, GINEConv [https://arxiv.org/abs/1905.12265] does.
class WeightedGINConv(pyg_nn.MessagePassing):
    """GIN-style convolution for PyG whose messages can be scaled per edge.

    Messages are aggregated with ``aggr="add"``; the wrapped network ``nn``
    is applied to ``(1 + eps) * x`` plus the aggregated messages.

    Args:
        nn: module applied to the combined node representation.
        eps: initial value of ``eps``.
        train_eps: if true ``eps`` is a trainable parameter, otherwise a
            registered buffer.
        **kwargs: forwarded to ``MessagePassing.__init__``.
    """

    def __init__(self, nn, eps=0, train_eps=False, **kwargs):
        super().__init__(aggr="add", **kwargs)
        self.nn = nn
        self.initial_eps = eps
        eps_tensor = torch.Tensor([eps])
        if not train_eps:
            self.register_buffer("eps", eps_tensor)
        else:
            self.eps = torch.nn.Parameter(eps_tensor)
        self.reset_parameters()

    def reset_parameters(self):
        """Restore ``eps`` to the value given at construction time."""
        self.eps.data.fill_(self.initial_eps)

    def forward(self, x, edge_index, edge_weight=None):
        """Run one convolution step.

        Self loops are stripped from ``edge_index`` (together with their
        weights) before propagation; 1-D ``x`` is treated as a single
        feature column.
        """
        if x.dim() == 1:
            x = x.unsqueeze(-1)
        edge_index, edge_weight = pyg_utils.remove_self_loops(edge_index, edge_weight)
        self_term = (1 + self.eps) * x
        neighbour_term = self.propagate(edge_index, x=x, edge_weight=edge_weight)
        return self.nn(self_term + neighbour_term)

    def message(self, x_j, edge_weight):
        """Return the neighbour features, scaled by ``edge_weight`` when given."""
        if edge_weight is None:
            return x_j
        return edge_weight.view(-1, 1) * x_j

    def __repr__(self):
        return f"{self.__class__.__name__}(nn={self.nn})"