Example #1
# Imports assumed for the PyTorch Geometric 1.x era this snippet targets;
# topk/filter_adj moved to other modules in later releases.
import torch
from torch_geometric.nn import FastRGCNConv, RGCNConv
from torch_geometric.nn.pool.topk_pool import topk, filter_adj
from torch_geometric.utils import softmax


class RGCNSAGPooling(torch.nn.Module):
    def __init__(self,
                 in_channels,
                 num_relations,
                 ratio=0.5,
                 min_score=None,
                 multiplier=1,
                 nonlinearity=torch.tanh,
                 rgcn_func="FastRGCNConv",
                 **kwargs):
        super(RGCNSAGPooling, self).__init__()

        self.in_channels = in_channels
        self.ratio = ratio
        self.gnn = (FastRGCNConv(in_channels, 1, num_relations, **kwargs)
                    if rgcn_func == "FastRGCNConv"
                    else RGCNConv(in_channels, 1, num_relations, **kwargs))
        self.min_score = min_score
        self.multiplier = multiplier
        self.nonlinearity = nonlinearity

        self.reset_parameters()

    def reset_parameters(self):
        self.gnn.reset_parameters()

    def forward(self, x, edge_index, edge_attr=None, batch=None, attn=None):
        """"""
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))

        attn = x if attn is None else attn
        attn = attn.unsqueeze(-1) if attn.dim() == 1 else attn
        score = self.gnn(attn, edge_index, edge_attr).view(-1)

        if self.min_score is None:
            score = self.nonlinearity(score)
        else:
            score = softmax(score, batch)

        perm = topk(score, self.ratio, batch, self.min_score)
        x = x[perm] * score[perm].view(-1, 1)
        x = self.multiplier * x if self.multiplier != 1 else x

        batch = batch[perm]
        edge_index, edge_attr = filter_adj(edge_index,
                                           edge_attr,
                                           perm,
                                           num_nodes=score.size(0))

        return x, edge_index, edge_attr, batch, perm, score[perm]

    def __repr__(self):
        return '{}({}, {}, {}={}, multiplier={})'.format(
            self.__class__.__name__, self.gnn.__class__.__name__,
            self.in_channels,
            'ratio' if self.min_score is None else 'min_score',
            self.ratio if self.min_score is None else self.min_score,
            self.multiplier)
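
A minimal usage sketch for the pooling layer above (not part of the original source; graph size, feature width, and relation count are made-up illustrative values):

# Hypothetical toy input: 4 nodes with 8 features and 2 relation types.
x = torch.randn(4, 8)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])
edge_type = torch.tensor([0, 1, 0, 1])

pool = RGCNSAGPooling(in_channels=8, num_relations=2, ratio=0.5)
x, edge_index, edge_type, batch, perm, score = pool(x, edge_index, edge_type)
print(x.shape)  # roughly half of the nodes remain after pooling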
Example #2
    def __init__(self, in_channels, out_channels, num_relations):
        super().__init__()
        self.conv1 = FastRGCNConv(in_channels, 16, num_relations, num_bases=30)
        self.conv2 = FastRGCNConv(16,
                                  out_channels,
                                  num_relations,
                                  num_bases=30)
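
The excerpt shows only the constructor; a plausible completion (an assumption, not part of the original) is a two-layer forward pass over edge_index and edge_type, with torch.nn.functional imported as F:

    # Hypothetical forward pass assumed to accompany the constructor above.
    def forward(self, x, edge_index, edge_type):
        x = F.relu(self.conv1(x, edge_index, edge_type))
        x = self.conv2(x, edge_index, edge_type)
        return F.log_softmax(x, dim=-1)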
Example #3
    def __init__(self, in_channels, num_relations, ratio=0.5, min_score=None,
                 multiplier=1, nonlinearity=torch.tanh, rgcn_func="FastRGCNConv", **kwargs):
        super(RGCNSAGPooling, self).__init__()

        self.in_channels = in_channels
        self.ratio = ratio
        self.gnn = (FastRGCNConv(in_channels, 1, num_relations, **kwargs)
                    if rgcn_func == "FastRGCNConv"
                    else RGCNConv(in_channels, 1, num_relations, **kwargs))
        self.min_score = min_score
        self.multiplier = multiplier
        self.nonlinearity = nonlinearity

        self.reset_parameters()

Example #4
def test_rgcn_conv_equality(conf):
    num_bases, num_blocks = conf

    x1 = torch.randn(4, 4)
    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], [0, 0, 1, 0, 1, 1]])
    edge_type = torch.tensor([0, 1, 1, 0, 0, 1])

    edge_index = torch.tensor([
        [0, 1, 1, 2, 2, 3, 0, 1, 1, 2, 2, 3],
        [0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1],
    ])
    edge_type = torch.tensor([0, 1, 1, 0, 0, 1, 2, 3, 3, 2, 2, 3])

    torch.manual_seed(12345)
    conv1 = RGCNConv(4, 32, 4, num_bases, num_blocks)

    torch.manual_seed(12345)
    conv2 = FastRGCNConv(4, 32, 4, num_bases, num_blocks)

    out1 = conv1(x1, edge_index, edge_type)
    out2 = conv2(x1, edge_index, edge_type)
    assert torch.allclose(out1, out2, atol=1e-6)

    if num_blocks is None:
        out1 = conv1(None, edge_index, edge_type)
        out2 = conv2(None, edge_index, edge_type)
        assert torch.allclose(out1, out2, atol=1e-6)
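
The conf argument is a (num_bases, num_blocks) pair; the excerpt omits the test harness, but it would typically be supplied along these lines (an assumed pytest setup, not shown in the original):

# Assumed parametrization over (num_bases, num_blocks) pairs.  With pytest,
# @pytest.mark.parametrize('conf', [(None, None), (2, None), (None, 2)])
# would decorate the test above; without pytest it can be driven by hand:
for conf in [(None, None), (2, None), (None, 2)]:
    test_rgcn_conv_equality(conf)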
Example #5
    def __init__(self,
                 in_channels,
                 number_hidden_layers,
                 aggr,
                 hidden_out_channel,
                 out_channel,
                 pool_layer,
                 k=1,
                 device=None):
        super(InceptionNet, self).__init__()
        self.pool_layer = pool_layer  # 'add', 'max', 'mean' or 'sort'
        self.device = device
        self.k = k
        self.atom_encoder = AtomEncoder(emb_dim=in_channels)
        self.batchnorm = BatchNorm(in_channels=2 * hidden_out_channel)

        self.rgcn_list = torch.nn.ModuleList()
        self.graphconv_list = torch.nn.ModuleList()
        self.rgcn_list.append(
            FastRGCNConv(in_channels=in_channels,
                         out_channels=hidden_out_channel,
                         num_relations=NUM_RELATIONS))
        self.graphconv_list.append(
            GraphConv(in_channels=in_channels,
                      out_channels=hidden_out_channel))

        if number_hidden_layers != 0:
            for i in range(number_hidden_layers):
                self.rgcn_list.append(
                    FastRGCNConv(in_channels=2 * hidden_out_channel,
                                 out_channels=hidden_out_channel,
                                 num_relations=NUM_RELATIONS))
                self.graphconv_list.append(
                    GraphConv(in_channels=2 * hidden_out_channel,
                              out_channels=hidden_out_channel))

        self.rgcn_list.append(
            FastRGCNConv(in_channels=2 * hidden_out_channel,
                         out_channels=out_channel,
                         num_relations=NUM_RELATIONS))
        self.graphconv_list.append(
            GraphConv(in_channels=2 * hidden_out_channel,
                      out_channels=out_channel))

        self.linear1 = nn.Linear(2 * k * out_channel, 16)
        self.linear2 = nn.Linear(16, 1)
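
The doubled channel counts above come from running the relational and the plain-graph branch in parallel and concatenating their outputs; a rough forward-pass sketch for the 'sort' pooling option (assumed, since the excerpt ends after __init__, and assuming global_sort_pool and torch.nn.functional as F are imported) could look like this:

    # Hypothetical forward pass illustrating the parallel-branch design.
    def forward(self, x, edge_index, edge_type, batch):
        x = self.atom_encoder(x)
        for rgcn, gconv in zip(self.rgcn_list[:-1], self.graphconv_list[:-1]):
            h1 = rgcn(x, edge_index, edge_type)   # relation-aware branch
            h2 = gconv(x, edge_index)             # relation-agnostic branch
            x = F.relu(self.batchnorm(torch.cat([h1, h2], dim=-1)))
        h1 = self.rgcn_list[-1](x, edge_index, edge_type)
        h2 = self.graphconv_list[-1](x, edge_index)
        x = torch.cat([h1, h2], dim=-1)           # 2 * out_channel per node
        x = global_sort_pool(x, batch, self.k)    # [num_graphs, 2 * k * out_channel]
        x = F.relu(self.linear1(x))
        return self.linear2(x)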
Example #6
    def __init__(self, config):
        super(Encoder, self).__init__()
        #        self.initializer = Initializer(config)
        layer = EncoderLayer(config)
        #        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
        self.layer = nn.ModuleList([layer])
        #        self.conv = FastRGCNConv(config.hidden_size,config.hidden_size)
        self.conv3 = FastRGCNConv(config.hidden_size,
                                  config.hidden_size,
                                  25,  # num_relations
                                  num_bases=128)
        self.conv2 = torch.nn.ModuleList()
        self.conv22 = torch.nn.ModuleList()

        for i in range(3):
            self.conv2.append(DNAConv(config.hidden_size, 32, 2, 0.1))
            self.conv22.append(DNAConv(config.hidden_size, 32, 2, 0.1))

        self.hidden_size = config.hidden_size
        #        self.conv2 = DNAConv(config.hidden_size,32,16,0.1)

        #        self.conv2 = AGNNConv(config.hidden_size,config.hidden_size)
        self.norm = nn.LayerNorm([512, config.hidden_size], 1e-05)
Example #7
    def __init__(self,
                 in_channels,
                 hidden_channels,
                 out_channels,
                 depth,
                 pool_ratios=0.5,
                 sum_res=True,
                 act=F.relu,
                 num_relations=4):
        super(GraphRUNet, self).__init__()
        assert depth >= 1
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.depth = depth
        self.pool_ratios = repeat(pool_ratios, depth)
        self.act = act
        self.sum_res = sum_res

        channels = hidden_channels

        self.down_convs = torch.nn.ModuleList()
        self.pools = torch.nn.ModuleList()
        self.down_convs.append(
            FastRGCNConv(in_channels, channels, num_relations=num_relations))
        for i in range(depth):
            self.pools.append(TopKPooling(channels, self.pool_ratios[i]))
            self.down_convs.append(GCNConv(channels, channels))

        in_channels = channels if sum_res else 2 * channels

        self.up_convs = torch.nn.ModuleList()
        for i in range(depth - 1):
            self.up_convs.append(GCNConv(channels, channels))
        self.up_convs.append(GCNConv(channels, channels))

        self.reset_parameters()
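
A minimal instantiation sketch for the class above (illustrative values only; it assumes the rest of the class, e.g. forward and reset_parameters, is defined as in the original repository):

# Hypothetical sizes; depth, channel widths and relation count are made up.
model = GraphRUNet(in_channels=16, hidden_channels=32, out_channels=8,
                   depth=3, pool_ratios=0.5, num_relations=4)
print(model)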