Example #1
    def __init__(self, args):
        super(LinkPredictor, self).__init__()

        in_channels = args.hidden_channels
        hidden_channels = args.hidden_channels
        out_channels = args.num_tasks
        num_layers = args.lp_num_layers
        norm = args.lp_norm

        if norm.lower() == 'none':
            self.norms = None
        else:
            self.norms = torch.nn.ModuleList()

        self.lins = torch.nn.ModuleList()
        self.lins.append(torch.nn.Linear(in_channels, hidden_channels))

        for _ in range(num_layers - 2):
            self.lins.append(torch.nn.Linear(hidden_channels, hidden_channels))
            if self.norms is not None:
                self.norms.append(norm_layer(norm, hidden_channels))
        self.lins.append(torch.nn.Linear(hidden_channels, out_channels))

        self.dropout = args.dropout
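
A minimal usage sketch, assuming an argparse-style namespace: the attribute names are the ones read above, while the values (and the 'layer' string accepted by norm_layer) are hypothetical:

    from argparse import Namespace

    # Hypothetical values; only the attribute names come from the code above.
    args = Namespace(hidden_channels=256, num_tasks=1,
                     lp_num_layers=3, lp_norm='layer', dropout=0.5)
    predictor = LinkPredictor(args)  # 3 Linear layers, 1 intermediate norm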
Example #2
    def __init__(self, args):
        super(RevGCN, self).__init__()

        self.inject_input = args.inject_input
        self.num_layers = args.num_layers
        self.num_steps = args.num_steps
        self.dropout = args.dropout
        self.block = args.block
        self.group = args.group

        hidden_channels = args.hidden_channels
        num_tasks = args.num_tasks
        conv = args.conv
        aggr = args.gcn_aggr

        t = args.t
        self.learn_t = args.learn_t
        p = args.p
        self.learn_p = args.learn_p
        self.msg_norm = args.msg_norm
        learn_msg_scale = args.learn_msg_scale

        conv_encode_edge = args.conv_encode_edge
        norm = args.norm
        mlp_layers = args.mlp_layers
        node_features_file_path = args.nf_path

        self.use_one_hot_encoding = args.use_one_hot_encoding

        self.checkpoint_grad = False
        if self.num_steps > 15 or hidden_channels > 64:
            self.checkpoint_grad = True
            self.ckp_k = 10

        print('The number of layers {}'.format(self.num_layers),
              'Aggregation method {}'.format(aggr),
              'block: {}'.format(self.block))

        if self.block == 'res+':
            print('LN/BN->ReLU->GraphConv->Res')
        elif self.block == 'res':
            print('GraphConv->LN/BN->ReLU->Res')
        elif self.block == 'dense':
            raise NotImplementedError('To be implemented')
        elif self.block == "plain":
            print('GraphConv->LN/BN->ReLU')
        else:
            raise Exception('Unknown block Type')

        self.gcns = torch.nn.ModuleList()
        self.last_norm = norm_layer(norm, hidden_channels)

        for layer in range(self.num_layers):
            Fms = nn.ModuleList()
            fm = GENBlock(hidden_channels // self.group,
                          hidden_channels // self.group,
                          aggr=aggr,
                          t=t,
                          learn_t=self.learn_t,
                          p=p,
                          learn_p=self.learn_p,
                          msg_norm=self.msg_norm,
                          learn_msg_scale=learn_msg_scale,
                          encode_edge=conv_encode_edge,
                          edge_feat_dim=hidden_channels,
                          norm=norm,
                          mlp_layers=mlp_layers)

            for i in range(self.group):
                if i == 0:
                    Fms.append(fm)
                else:
                    Fms.append(copy.deepcopy(fm))

            invertible_module = memgcn.GroupAdditiveCoupling(Fms,
                                                             group=self.group)

            gcn = memgcn.InvertibleModuleWrapper(fn=invertible_module,
                                                 keep_input=False)

            self.gcns.append(gcn)

        self.node_features = torch.load(node_features_file_path).to(
            args.device)

        if self.use_one_hot_encoding:
            self.node_one_hot_encoder = torch.nn.Linear(8, 8)
            self.node_features_encoder = torch.nn.Linear(
                8 * 2, hidden_channels)
        else:
            self.node_features_encoder = torch.nn.Linear(8, hidden_channels)

        self.edge_encoder = torch.nn.Linear(8, hidden_channels)

        self.node_pred_linear = torch.nn.Linear(hidden_channels, num_tasks)
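
memgcn.GroupAdditiveCoupling and InvertibleModuleWrapper are not part of this listing. As a rough, self-contained sketch of the idea behind them (the real module generalizes to group-many chunks and threads edge features into each Fm), here is a two-group additive coupling; because the map is invertible, keep_input=False can discard input activations and recompute them during the backward pass:

    import torch
    import torch.nn as nn

    class TwoGroupAdditiveCoupling(nn.Module):
        """Sketch only: y1 = x1 + F(x2), y2 = x2 + G(y1)."""

        def __init__(self, f: nn.Module, g: nn.Module):
            super().__init__()
            self.f, self.g = f, g

        def forward(self, x):
            x1, x2 = torch.chunk(x, 2, dim=-1)
            y1 = x1 + self.f(x2)
            y2 = x2 + self.g(y1)
            return torch.cat([y1, y2], dim=-1)

        def inverse(self, y):
            # Recover the inputs from the outputs, so nothing needs caching.
            y1, y2 = torch.chunk(y, 2, dim=-1)
            x2 = y2 - self.g(y1)
            x1 = y1 - self.f(x2)
            return torch.cat([x1, x2], dim=-1)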
Example #3
    def __init__(self, args):
        super(DeeperGCN, self).__init__()

        self.num_layers = args.num_layers
        self.dropout = args.dropout
        self.block = args.block
        self.conv_encode_edge = args.conv_encode_edge
        self.add_virtual_node = args.add_virtual_node

        hidden_channels = args.hidden_channels
        num_tasks = args.num_tasks
        conv = args.conv
        aggr = args.gcn_aggr
        t = args.t
        self.learn_t = args.learn_t
        p = args.p
        self.learn_p = args.learn_p
        self.msg_norm = args.msg_norm
        learn_msg_scale = args.learn_msg_scale

        norm = args.norm
        mlp_layers = args.mlp_layers

        graph_pooling = args.graph_pooling

        print('The number of layers {}'.format(self.num_layers),
              'Aggregation method {}'.format(aggr),
              'block: {}'.format(self.block))
        if self.block == 'res+':
            print('LN/BN->ReLU->GraphConv->Res')
        elif self.block == 'res':
            print('GraphConv->LN/BN->ReLU->Res')
        elif self.block == 'dense':
            raise NotImplementedError('To be implemented')
        elif self.block == "plain":
            print('GraphConv->LN/BN->ReLU')
        else:
            raise Exception('Unknown block Type')

        self.gcns = torch.nn.ModuleList()
        self.norms = torch.nn.ModuleList()

        if self.add_virtual_node:
            self.virtualnode_embedding = torch.nn.Embedding(1, hidden_channels)
            torch.nn.init.constant_(self.virtualnode_embedding.weight.data, 0)

            self.mlp_virtualnode_list = torch.nn.ModuleList()

            for layer in range(self.num_layers - 1):
                self.mlp_virtualnode_list.append(
                    MLP([hidden_channels, hidden_channels],
                        norm=norm,
                        last_act=True))

        for layer in range(self.num_layers):
            if conv == 'gen':
                gcn = GENConv(hidden_channels,
                              hidden_channels,
                              aggr=aggr,
                              t=t,
                              learn_t=self.learn_t,
                              p=p,
                              learn_p=self.learn_p,
                              msg_norm=self.msg_norm,
                              learn_msg_scale=learn_msg_scale,
                              encode_edge=self.conv_encode_edge,
                              bond_encoder=True,
                              norm=norm,
                              mlp_layers=mlp_layers)
            else:
                raise Exception('Unknown Conv Type')
            self.gcns.append(gcn)
            self.norms.append(norm_layer(norm, hidden_channels))

        self.atom_encoder = AtomEncoder(emb_dim=hidden_channels)

        if not self.conv_encode_edge:
            self.bond_encoder = BondEncoder(emb_dim=hidden_channels)

        if graph_pooling == "sum":
            self.pool = global_add_pool
        elif graph_pooling == "mean":
            self.pool = global_mean_pool
        elif graph_pooling == "max":
            self.pool = global_max_pool
        else:
            raise Exception('Unknown Pool Type')

        self.graph_pred_linear = torch.nn.Linear(hidden_channels, num_tasks)
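
The three pooling choices are the standard torch_geometric readouts; a small usage sketch with illustrative shapes:

    import torch
    from torch_geometric.nn import global_mean_pool

    h = torch.randn(5, 16)                 # 5 node embeddings, 16 channels
    batch = torch.tensor([0, 0, 0, 1, 1])  # node-to-graph assignment
    hg = global_mean_pool(h, batch)        # per-graph embeddings, shape [2, 16]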
Example #4
    def __init__(self, args):
        super(DeeperGCN, self).__init__()

        self.num_layers = args.num_layers
        self.dropout = args.dropout
        self.block = args.block

        hidden_channels = args.hidden_channels
        num_tasks = args.num_tasks
        conv = args.conv
        aggr = args.gcn_aggr
        t = args.t
        self.learn_t = args.learn_t
        p = args.p
        self.learn_p = args.learn_p
        self.msg_norm = args.msg_norm
        learn_msg_scale = args.learn_msg_scale

        conv_encode_edge = args.conv_encode_edge
        norm = args.norm
        mlp_layers = args.mlp_layers

        graph_pooling = args.graph_pooling

        print('The number of layers {}'.format(self.num_layers),
              'Aggregation method {}'.format(aggr),
              'block: {}'.format(self.block))
        if self.block == 'res+':
            print('LN/BN->ReLU->GraphConv->Res')
        elif self.block == 'res':
            print('GraphConv->LN/BN->ReLU->Res')
        elif self.block == 'dense':
            raise NotImplementedError('To be implemented')
        elif self.block == "plain":
            print('GraphConv->LN/BN->ReLU')
        else:
            raise Exception('Unknown block Type')

        self.gcns = torch.nn.ModuleList()
        self.norms = torch.nn.ModuleList()

        for layer in range(self.num_layers):

            if conv == 'gen':
                gcn = GENConv(hidden_channels,
                              hidden_channels,
                              aggr=aggr,
                              t=t,
                              learn_t=self.learn_t,
                              p=p,
                              learn_p=self.learn_p,
                              msg_norm=self.msg_norm,
                              learn_msg_scale=learn_msg_scale,
                              encode_edge=conv_encode_edge,
                              edge_feat_dim=hidden_channels,
                              norm=norm,
                              mlp_layers=mlp_layers)
            else:
                raise Exception('Unknown Conv Type')
            self.gcns.append(gcn)
            self.norms.append(norm_layer(norm, hidden_channels))

        self.node_features_encoder = torch.nn.Linear(7, hidden_channels)

        self.edge_encoder = torch.nn.Linear(7, hidden_channels)

        if graph_pooling == "sum":
            self.pool = global_add_pool
        elif graph_pooling == "mean":
            self.pool = global_mean_pool
        elif graph_pooling == "max":
            self.pool = global_max_pool
        else:
            raise Exception('Unknown Pool Type')

        self.graph_pred_linear = torch.nn.Linear(hidden_channels, num_tasks)
Example #5
    def __init__(self, args):
        super(DeeperGCN, self).__init__()

        self.num_layers = args.num_layers
        self.dropout = args.dropout
        self.block = args.block

        self.checkpoint_grad = False

        hidden_channels = args.hidden_channels
        num_tasks = args.num_tasks
        conv = args.conv
        aggr = args.gcn_aggr

        t = args.t
        self.learn_t = args.learn_t
        p = args.p
        self.learn_p = args.learn_p
        self.msg_norm = args.msg_norm
        learn_msg_scale = args.learn_msg_scale

        conv_encode_edge = args.conv_encode_edge
        norm = args.norm
        mlp_layers = args.mlp_layers
        node_features_file_path = args.nf_path

        self.use_one_hot_encoding = args.use_one_hot_encoding

        if aggr in ['softmax_sg', 'softmax', 'power'] and self.num_layers > 50:
            self.checkpoint_grad = True
            self.ckp_k = 3

        if (self.learn_t or self.learn_p) and self.num_layers > 15:
            self.checkpoint_grad = True
            self.ckp_k = 9

        print('The number of layers {}'.format(self.num_layers),
              'Aggregation method {}'.format(aggr),
              'block: {}'.format(self.block))

        if self.block == 'res+':
            print('LN/BN->ReLU->GraphConv->Res')
        elif self.block == 'res':
            print('GraphConv->LN/BN->ReLU->Res')
        elif self.block == 'dense':
            raise NotImplementedError('To be implemented')
        elif self.block == "plain":
            print('GraphConv->LN/BN->ReLU')
        else:
            raise Exception('Unknown block Type')

        self.gcns = torch.nn.ModuleList()
        self.layer_norms = torch.nn.ModuleList()

        for layer in range(self.num_layers):

            if conv == 'gen':
                gcn = GENConv(hidden_channels,
                              hidden_channels,
                              aggr=aggr,
                              t=t,
                              learn_t=self.learn_t,
                              p=p,
                              learn_p=self.learn_p,
                              msg_norm=self.msg_norm,
                              learn_msg_scale=learn_msg_scale,
                              encode_edge=conv_encode_edge,
                              edge_feat_dim=hidden_channels,
                              norm=norm,
                              mlp_layers=mlp_layers)
            else:
                raise Exception('Unknown Conv Type')

            self.gcns.append(gcn)
            self.layer_norms.append(norm_layer(norm, hidden_channels))

        self.node_features = torch.load(node_features_file_path).to(
            args.device)

        if self.use_one_hot_encoding:
            self.node_one_hot_encoder = torch.nn.Linear(8, 8)
            self.node_features_encoder = torch.nn.Linear(
                8 * 2, hidden_channels)
        else:
            self.node_features_encoder = torch.nn.Linear(8, hidden_channels)

        self.edge_encoder = torch.nn.Linear(8, hidden_channels)

        self.node_pred_linear = torch.nn.Linear(hidden_channels, num_tasks)
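
The forward pass is not shown in this listing, so the following is only a plausible sketch of how checkpoint_grad and ckp_k might be consumed: run every ckp_k-th layer under torch.utils.checkpoint so its activations are recomputed during backward rather than stored. The (h, edge_index, edge_attr) layer signature is an assumption.

    import torch.utils.checkpoint as cp

    def run_gcn_layers(layers, h, edge_index, edge_attr,
                       checkpoint_grad=False, ckp_k=3):
        # Sketch: trade recomputation for activation memory on
        # every ckp_k-th layer.
        for i, layer in enumerate(layers):
            if checkpoint_grad and i % ckp_k == 0:
                h = cp.checkpoint(layer, h, edge_index, edge_attr)
            else:
                h = layer(h, edge_index, edge_attr)
        return h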
Example #6
    def __init__(self, args):
        super(DeeperGCN, self).__init__()

        self.edge_num = 2358104
        self.num_layers = args.num_layers
        self.dropout = args.dropout
        self.block = args.block

        self.checkpoint_grad = False

        hidden_channels = args.hidden_channels
        conv = args.conv
        aggr = args.gcn_aggr

        t = args.t
        self.learn_t = args.learn_t
        p = args.p
        self.learn_p = args.learn_p
        self.msg_norm = args.msg_norm
        learn_msg_scale = args.learn_msg_scale

        norm = args.norm
        mlp_layers = args.mlp_layers

        if self.num_layers > 7:
            self.checkpoint_grad = True

        print('The number of layers {}'.format(self.num_layers),
              'Aggregation method {}'.format(aggr),
              'block: {}'.format(self.block))

        # if self.block == 'res+':
        #     print('LN/BN->ReLU->GraphConv->Res')
        # elif self.block == 'res':
        #     print('GraphConv->LN/BN->ReLU->Res')
        # elif self.block == 'dense':
        #     raise NotImplementedError('To be implemented')
        # elif self.block == "plain":
        #     print('GraphConv->LN/BN->ReLU')
        # else:
        #     raise Exception('Unknown block Type')

        self.gcns = torch.nn.ModuleList()
        self.norms = torch.nn.ModuleList()

        self.edge_mask1_train = nn.Parameter(torch.ones(self.edge_num, 1),
                                             requires_grad=True)
        self.edge_mask2_fixed = nn.Parameter(torch.ones(self.edge_num, 1),
                                             requires_grad=False)

        for layer in range(self.num_layers):

            if conv == 'gen':
                gcn = GENConv(hidden_channels,
                              hidden_channels,
                              aggr=aggr,
                              t=t,
                              learn_t=self.learn_t,
                              p=p,
                              learn_p=self.learn_p,
                              msg_norm=self.msg_norm,
                              learn_msg_scale=learn_msg_scale,
                              norm=norm,
                              mlp_layers=mlp_layers)
            else:
                raise Exception('Unknown Conv Type')

            self.gcns.append(gcn)
            self.norms.append(norm_layer(norm, hidden_channels))
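
How the two edge masks are consumed is likewise not shown; a plausible reading of the names is that per-edge quantities are scaled by the trainable mask during training and by the fixed all-ones mask otherwise. A hypothetical sketch with a toy edge count:

    import torch

    edge_num, training = 10, True        # toy size; the model uses 2358104
    mask_train = torch.ones(edge_num, 1, requires_grad=True)
    mask_fixed = torch.ones(edge_num, 1)

    messages = torch.randn(edge_num, 8)  # hypothetical per-edge messages
    mask = mask_train if training else mask_fixed
    messages = messages * mask           # broadcast per-edge gating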
Example #7
File: model.py Project: hansonhl/FLAG
    def __init__(self, args):
        super(DeeperGCN, self).__init__()

        self.num_layers = args.num_layers
        self.dropout = args.dropout
        self.block = args.block

        self.checkpoint_grad = False

        in_channels = args.in_channels
        hidden_channels = args.hidden_channels
        num_tasks = args.num_tasks
        conv = args.conv
        aggr = args.gcn_aggr

        t = args.t
        self.learn_t = args.learn_t
        p = args.p
        self.learn_p = args.learn_p
        self.msg_norm = args.msg_norm
        learn_msg_scale = args.learn_msg_scale

        norm = args.norm
        mlp_layers = args.mlp_layers

        if aggr in ['softmax_sg', 'softmax', 'power'] and self.num_layers > 3:
            self.checkpoint_grad = True
            self.ckp_k = self.num_layers // 2

        print('The number of layers {}'.format(self.num_layers),
              'Aggregation method {}'.format(aggr),
              'block: {}'.format(self.block))

        if self.block == 'res+':
            print('LN/BN->ReLU->GraphConv->Res')
        elif self.block == 'res':
            print('GraphConv->LN/BN->ReLU->Res')
        elif self.block == 'dense':
            raise NotImplementedError('To be implemented')
        elif self.block == "plain":
            print('GraphConv->LN/BN->ReLU')
        else:
            raise Exception('Unknown block Type')

        self.gcns = torch.nn.ModuleList()
        self.norms = torch.nn.ModuleList()

        self.node_features_encoder = torch.nn.Linear(in_channels,
                                                     hidden_channels)
        self.node_pred_linear = torch.nn.Linear(hidden_channels, num_tasks)

        for layer in range(self.num_layers):

            if conv == 'gen':
                gcn = GENConv(hidden_channels,
                              hidden_channels,
                              aggr=aggr,
                              t=t,
                              learn_t=self.learn_t,
                              p=p,
                              learn_p=self.learn_p,
                              msg_norm=self.msg_norm,
                              learn_msg_scale=learn_msg_scale,
                              norm=norm,
                              mlp_layers=mlp_layers)
            elif conv == "gat":
                gcn = GATConv(hidden_channels,
                              hidden_channels,
                              norm=norm,
                              heads=4)
            else:
                raise Exception('Unknown Conv Type')

            self.gcns.append(gcn)
            self.norms.append(norm_layer(norm, hidden_channels))
Example #8
    def __init__(self, args):
        super(DeeperGCN, self).__init__()

        self.edge_num = 2484941
        self.num_layers = args.num_layers
        self.dropout = args.dropout
        self.block = args.block
        self.checkpoint_grad = False

        in_channels = args.in_channels
        hidden_channels = args.hidden_channels
        num_tasks = args.num_tasks
        conv = args.conv
        aggr = args.gcn_aggr

        t = args.t
        self.learn_t = args.learn_t
        p = args.p
        self.learn_p = args.learn_p
        y = args.y
        self.learn_y = args.learn_y

        self.msg_norm = args.msg_norm
        learn_msg_scale = args.learn_msg_scale

        norm = args.norm
        mlp_layers = args.mlp_layers

        if aggr in ['softmax_sg', 'softmax', 'power'] and self.num_layers > 7:
            self.checkpoint_grad = True
            self.ckp_k = self.num_layers // 2

        self.gcns = torch.nn.ModuleList()
        self.norms = torch.nn.ModuleList()

        self.node_features_encoder = torch.nn.Linear(in_channels,
                                                     hidden_channels)
        self.node_pred_linear = torch.nn.Linear(hidden_channels, num_tasks)

        self.edge_mask1_train = nn.Parameter(torch.ones(self.edge_num, 1),
                                             requires_grad=True)
        self.edge_mask2_fixed = nn.Parameter(torch.ones(self.edge_num, 1),
                                             requires_grad=False)

        for layer in range(self.num_layers):

            if conv == 'gen':
                gcn = GENConv(hidden_channels,
                              hidden_channels,
                              aggr=aggr,
                              t=t,
                              learn_t=self.learn_t,
                              p=p,
                              learn_p=self.learn_p,
                              y=y,
                              learn_y=self.learn_y,
                              msg_norm=self.msg_norm,
                              learn_msg_scale=learn_msg_scale,
                              norm=norm,
                              mlp_layers=mlp_layers)
            else:
                raise Exception('Unknown Conv Type')

            self.gcns.append(gcn)
            self.norms.append(norm_layer(norm, hidden_channels))
Example #9
    def __init__(
        self,
        num_layers: int,
        dropout: float,
        block: str,
        conv_encode_edge: bool,
        add_virtual_node: bool,
        hidden_channels: int,
        num_tasks: Optional[int],
        conv: str,
        gcn_aggr: str,
        t: float,
        learn_t: bool,
        p: float,
        learn_p: bool,
        y: float,
        learn_y: bool,
        msg_norm: bool,
        learn_msg_scale: bool,
        norm: str,
        mlp_layers: int,
        graph_pooling: Optional[str],
        node_encoder: bool = False,  # no pooling; produce node-level representations
        encode_atom: bool = True,
    ):
        super(DeeperGCN, self).__init__()
        self.num_layers = num_layers
        self.dropout = dropout
        self.block = block
        self.conv_encode_edge = conv_encode_edge
        self.add_virtual_node = add_virtual_node
        self.encoder = node_encoder
        self.encode_atom = encode_atom

        aggr = gcn_aggr
        self.learn_t = learn_t
        self.learn_p = learn_p
        self.learn_y = learn_y

        self.msg_norm = msg_norm

        print(
            "The number of layers {}".format(self.num_layers),
            "Aggr aggregation method {}".format(aggr),
            "block: {}".format(self.block),
        )
        if self.block == "res+":
            print("LN/BN->ReLU->GraphConv->Res")
        elif self.block == "res":
            print("GraphConv->LN/BN->ReLU->Res")
        elif self.block == "dense":
            raise NotImplementedError("To be implemented")
        elif self.block == "plain":
            print("GraphConv->LN/BN->ReLU")
        else:
            raise Exception("Unknown block Type")

        self.gcns = torch.nn.ModuleList()
        self.norms = torch.nn.ModuleList()

        if self.add_virtual_node:
            self.virtualnode_embedding = torch.nn.Embedding(1, hidden_channels)
            torch.nn.init.constant_(self.virtualnode_embedding.weight.data, 0)

            self.mlp_virtualnode_list = torch.nn.ModuleList()

            for layer in range(self.num_layers - 1):
                self.mlp_virtualnode_list.append(
                    MLP([hidden_channels] * 3, norm=norm))

        for layer in range(self.num_layers):
            if conv == "gen":
                gcn = GENConv(
                    hidden_channels,
                    hidden_channels,
                    aggr=aggr,
                    t=t,
                    learn_t=self.learn_t,
                    p=p,
                    learn_p=self.learn_p,
                    y=y,
                    learn_y=self.learn_y,
                    msg_norm=self.msg_norm,
                    learn_msg_scale=learn_msg_scale,
                    encode_edge=self.conv_encode_edge,
                    bond_encoder=True,
                    norm=norm,
                    mlp_layers=mlp_layers,
                )
            else:
                raise Exception("Unknown Conv Type")
            self.gcns.append(gcn)
            self.norms.append(norm_layer(norm, hidden_channels))

        self.atom_encoder = AtomEncoder(emb_dim=hidden_channels)

        if not self.conv_encode_edge:
            self.bond_encoder = BondEncoder(emb_dim=hidden_channels)

        if not node_encoder:
            if graph_pooling == "sum":
                self.pool = global_add_pool
            elif graph_pooling == "mean":
                self.pool = global_mean_pool
            elif graph_pooling == "max":
                self.pool = global_max_pool
            else:
                raise Exception("Unknown Pool Type")

            self.graph_pred_linear = torch.nn.Linear(hidden_channels,
                                                     num_tasks)
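
Unlike the args-driven variants above, this constructor takes explicit keyword arguments. A hypothetical instantiation that produces node-level representations (node_encoder=True skips the pooling head, so graph_pooling and num_tasks go unused); every value, including the norm and gcn_aggr strings, is illustrative:

    model = DeeperGCN(
        num_layers=7, dropout=0.2, block='res+',
        conv_encode_edge=False, add_virtual_node=False,
        hidden_channels=256, num_tasks=None, conv='gen',
        gcn_aggr='softmax', t=1.0, learn_t=True, p=1.0, learn_p=False,
        y=0.0, learn_y=False, msg_norm=False, learn_msg_scale=False,
        norm='batch', mlp_layers=1, graph_pooling=None,
        node_encoder=True,
    )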