Example #1

# Imports required by the class below:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import (BatchNorm, GraphConv, global_add_pool,
                                global_max_pool, global_mean_pool,
                                global_sort_pool)
from ogb.graphproppred.mol_encoder import AtomEncoder
class GCN_Net(torch.nn.Module):
    def __init__(self, in_channels, number_hidden_layers, aggr,
                 hidden_out_channel, out_channel, pool_layer, k=1):
        super(GCN_Net, self).__init__()
        self.in_channels = in_channels
        self.number_hidden_layers = number_hidden_layers  # number of hidden GraphConv layers
        self.aggr = aggr  # "add", "mean" or "max"
        self.pool_layer = pool_layer  # 'add', 'max', 'mean' or 'sort'
        self.hidden_out_channel = hidden_out_channel
        self.out_channel = out_channel
        self.atom_encoder = AtomEncoder(emb_dim=self.in_channels)
        self.k = k  # number of nodes kept per graph by global_sort_pool

        
        self.graph_conv_list = nn.ModuleList()
        # input layer
        self.graph_conv_list.append(GraphConv(in_channels=self.in_channels, out_channels=self.hidden_out_channel, aggr=self.aggr))

        # a single BatchNorm shared by every hidden layer
        self.batchnorm = BatchNorm(in_channels=self.hidden_out_channel)

        # hidden layers (range(0) is empty, so no extra guard is needed)
        for i in range(self.number_hidden_layers):
            self.graph_conv_list.append(GraphConv(in_channels=self.hidden_out_channel, out_channels=self.hidden_out_channel, aggr=self.aggr))

        # output layer
        self.graph_conv_list.append(GraphConv(in_channels=self.hidden_out_channel, out_channels=self.out_channel, aggr=self.aggr))

        # readout MLP; with 'sort' pooling, global_sort_pool keeps k nodes per
        # graph (k * out_channel features), so k should stay at its default of
        # 1 for the other pooling modes
        self.linear1 = nn.Linear(self.k * self.out_channel, 16)
        self.linear2 = nn.Linear(16, 1)
            
    def forward(self, data):
        x = self.atom_encoder(data.x)
        edge_index = data.edge_index
        
        for i, layer in enumerate(self.graph_conv_list):
            x = layer(x, edge_index)
            x = F.relu(x)
            # skip batch norm after the output layer (its width is
            # out_channel, not the hidden_out_channel the norm expects)
            if i < len(self.graph_conv_list) - 1:
                x = self.batchnorm(x)
            
        if self.pool_layer == 'add':
            x = global_add_pool(x, data.batch)
        elif self.pool_layer == 'mean':
            x = global_mean_pool(x, data.batch)
        elif self.pool_layer == 'max':
            x = global_max_pool(x, data.batch)
        elif self.pool_layer == 'sort':
            x = global_sort_pool(x, data.batch, self.k)
        else:
            raise ValueError(f"unknown pool_layer: {self.pool_layer!r}")
   
        x = F.relu(self.linear1(x))
        x = self.linear2(x)
        
        return x

    def reset_parameters(self):

        # conv layers
        for graphconv in self.graph_conv_list:
            graphconv.reset_parameters()

        # batch norm
        self.batchnorm.reset_parameters()

        # fully connected
        self.linear1.reset_parameters()
        self.linear2.reset_parameters()
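
A minimal usage sketch for GCN_Net. The ogbg-molhiv dataset and all hyperparameter values here are illustrative assumptions, not taken from the listing above:

from ogb.graphproppred import PygGraphPropPredDataset
from torch_geometric.loader import DataLoader

dataset = PygGraphPropPredDataset(name='ogbg-molhiv')
loader = DataLoader(dataset, batch_size=32, shuffle=True)

model = GCN_Net(in_channels=64, number_hidden_layers=2, aggr='add',
                hidden_out_channel=32, out_channel=16, pool_layer='mean')
model.reset_parameters()

batch = next(iter(loader))
out = model(batch)  # one logit per graph: shape [num_graphs, 1]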
        
Example #2

# Imports required by the class below (as in Example #1, plus FastRGCNConv):
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import (BatchNorm, FastRGCNConv, GraphConv,
                                global_add_pool, global_max_pool,
                                global_mean_pool, global_sort_pool)
from ogb.graphproppred.mol_encoder import AtomEncoder

# NUM_RELATIONS is an assumed value, derived from the relation encoding in
# forward(): 5 bond types x 6 stereo values x 2 conjugation flags
# (OGB molecular edge features)
NUM_RELATIONS = 5 * 6 * 2  # 60 distinct relation ids
class InceptionNet(torch.nn.Module):
    def __init__(self,
                 in_channels,
                 number_hidden_layers,
                 aggr,
                 hidden_out_channel,
                 out_channel,
                 pool_layer,
                 k=1,
                 device=None):
        super(InceptionNet, self).__init__()
        self.pool_layer = pool_layer  # 'add', 'max', 'mean' or 'sort'
        # note: aggr is accepted for API parity with GCN_Net but is not
        # forwarded to the conv layers in this listing
        self.device = device
        self.k = k  # number of nodes kept per graph by global_sort_pool
        self.atom_encoder = AtomEncoder(emb_dim=in_channels)
        self.batchnorm = BatchNorm(in_channels=2 * hidden_out_channel)

        self.rgcn_list = torch.nn.ModuleList()
        self.graphconv_list = torch.nn.ModuleList()
        self.rgcn_list.append(
            FastRGCNConv(in_channels=in_channels,
                         out_channels=hidden_out_channel,
                         num_relations=NUM_RELATIONS))
        self.graphconv_list.append(
            GraphConv(in_channels=in_channels,
                      out_channels=hidden_out_channel))

        # hidden layers (range(0) is empty, so no extra guard is needed)
        for i in range(number_hidden_layers):
            self.rgcn_list.append(
                FastRGCNConv(in_channels=2 * hidden_out_channel,
                             out_channels=hidden_out_channel,
                             num_relations=NUM_RELATIONS))
            self.graphconv_list.append(
                GraphConv(in_channels=2 * hidden_out_channel,
                          out_channels=hidden_out_channel))

        self.rgcn_list.append(
            FastRGCNConv(in_channels=2 * hidden_out_channel,
                         out_channels=out_channel,
                         num_relations=NUM_RELATIONS))
        self.graphconv_list.append(
            GraphConv(in_channels=2 * hidden_out_channel,
                      out_channels=out_channel))

        self.linear1 = nn.Linear(2 * k * out_channel, 16)
        self.linear2 = nn.Linear(16, 1)

    def forward(self, data):
        x = self.atom_encoder(data.x)
        edge_index = data.edge_index
        # collapse the 3-dimensional OGB bond features into a single relation
        # id per edge: bond type (5 values) + 5 * stereo (6 values)
        # + 30 * is_conjugated (2 values)
        edge_attr = (data.edge_attr[:, 0] + data.edge_attr[:, 1] * 5 +
                     data.edge_attr[:, 2] * 30).long().to(self.device)
        for i in range(len(self.rgcn_list)):
            # inception-style block: a relational conv and a plain GraphConv
            # run in parallel on the same input, and their outputs are
            # concatenated channel-wise
            x_rgcn = self.rgcn_list[i](x, edge_index, edge_attr)
            x_gconv = self.graphconv_list[i](x, edge_index)
            x = torch.cat((x_rgcn, x_gconv), dim=1)
            x = F.relu(x)
            # skip batch norm after the output layer (its width is
            # 2 * out_channel, not the 2 * hidden_out_channel the norm expects)
            if i < len(self.rgcn_list) - 1:
                x = self.batchnorm(x)

        if self.pool_layer == 'add':
            x = global_add_pool(x, data.batch)
        elif self.pool_layer == 'mean':
            x = global_mean_pool(x, data.batch)
        elif self.pool_layer == 'max':
            x = global_max_pool(x, data.batch)
        elif self.pool_layer == 'sort':
            x = global_sort_pool(x, data.batch, self.k)
        else:
            raise ValueError(f"unknown pool_layer: {self.pool_layer!r}")

        x = F.relu(self.linear1(x))
        x = self.linear2(x)

        return x

    def reset_parameters(self):

        # conv layers (both parallel branches)
        for rgcn in self.rgcn_list:
            rgcn.reset_parameters()
        for graphconv in self.graphconv_list:
            graphconv.reset_parameters()

        # batch norm
        self.batchnorm.reset_parameters()

        # fully connected
        self.linear1.reset_parameters()
        self.linear2.reset_parameters()
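
A minimal usage sketch for InceptionNet, mirroring Example #1. The dataset and hyperparameters are again illustrative assumptions; the relation encoding in forward() presumes 3-dimensional integer bond features as in OGB molecular graphs:

from ogb.graphproppred import PygGraphPropPredDataset
from torch_geometric.loader import DataLoader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dataset = PygGraphPropPredDataset(name='ogbg-molhiv')
loader = DataLoader(dataset, batch_size=32, shuffle=True)

model = InceptionNet(in_channels=64, number_hidden_layers=2, aggr='add',
                     hidden_out_channel=32, out_channel=16,
                     pool_layer='mean', device=device).to(device)
model.reset_parameters()

batch = next(iter(loader)).to(device)
out = model(batch)  # one logit per graph: shape [num_graphs, 1]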