Example #1
0
    def __init__(self,
                 input_size,
                 embedding_size,
                 n_classes,
                 dropout=False,
                 k=5,
                 aggr='max',
                 pool_op='max'):
        """Build the EdgeConv -> DynamicEdgeConv classifier.

        Layers: an ``EdgeConv`` stage, a ``DynamicEdgeConv`` stage using
        ``k`` neighbours, a 1024-wide ``MLP`` (``lin1``), a global pooling
        function (``pool``) and a classification head (``mlp``).

        Args:
            input_size: number of features per input node; the first
                EdgeConv's MLP consumes ``2 * input_size`` features.
            embedding_size: accepted for interface compatibility; not used
                by this constructor.
            n_classes: output width of the final ``Lin`` in the head.
            dropout: if true, insert ``Dropout(0.5)`` after each hidden MLP
                of the head (those MLPs are also built with
                ``batch_norm=True``).
            k: neighbour count passed to ``DynamicEdgeConv``.
            aggr: aggregation scheme passed to both edge convolutions.
            pool_op: graph-level pooling; only ``'max'`` is supported.

        Raises:
            ValueError: if ``pool_op`` is not ``'max'``.
        """
        if pool_op != 'max':
            # Previously an unsupported value silently left ``self.pool``
            # undefined, deferring the failure to first use.
            raise ValueError(
                "unsupported pool_op {!r}; only 'max' is supported".format(
                    pool_op))
        super(DECSeq, self).__init__()
        self.conv1 = EdgeConv(
            MLP([2 * input_size, 64, 64, 64], batch_norm=True), aggr)
        self.conv2 = DynamicEdgeConv(MLP([2 * 64, 128], batch_norm=True), k,
                                     aggr)
        # Input width 128 + 64 presumably matches the conv2 (128) and conv1
        # (64) outputs concatenated in forward() -- not visible here.
        self.lin1 = MLP([128 + 64, 1024])
        self.pool = global_max_pool

        if dropout:
            self.mlp = Seq(MLP([1024, 512], batch_norm=True), Dropout(0.5),
                           MLP([512, 256], batch_norm=True), Dropout(0.5),
                           Lin(256, n_classes))
        else:
            # NOTE(review): unlike the dropout branch, these MLPs are not
            # given batch_norm=True; depending on MLP's default this may
            # also change normalisation -- confirm intended.
            self.mlp = Seq(MLP([1024, 512]), MLP([512, 256]),
                           Lin(256, n_classes))
Example #2
0
 def __init__(self,
              input_size,
              embedding_size,
              n_classes,
              dropout=True,
              k=5,
              aggr='max',
              pool_op='max',
              k_global=25):
     """Build the EdgeConv classifier with a global GAT stage.

     Layers: an ``EdgeConv`` stage, a ``DynamicEdgeConv`` stage using ``k``
     neighbours, a 1024-wide ``MLP`` (``lin1``), a global pooling function
     (``pool``), an MLP stack down to 32 features (``mlp``), a ``GATConv``
     (``conv_glob``) and a final ``Lin`` producing ``n_classes`` outputs.

     Args:
         input_size: accepted for interface compatibility; currently not
             used (conv1 is hard-wired to ``2 * 3`` input features).
         embedding_size: accepted for interface compatibility; not used.
         n_classes: output width of ``self.lin``.
         dropout: if true, insert ``Dropout(0.5)`` between the MLPs of
             ``self.mlp``.
         k: neighbour count passed to ``DynamicEdgeConv``.
         aggr: aggregation scheme passed to both edge convolutions.
         pool_op: graph-level pooling; only ``'max'`` is supported.
         k_global: stored as ``self.k_global``; its use is outside this
             constructor.

     Raises:
         ValueError: if ``pool_op`` is not ``'max'``.
     """
     if pool_op != 'max':
         # Previously an unsupported value silently left ``self.pool``
         # undefined, deferring the failure to first use.
         raise ValueError(
             "unsupported pool_op {!r}; only 'max' is supported".format(
                 pool_op))
     super(DECSeqGlob, self).__init__()
     self.k_global = k_global
     # NOTE(review): input_size is ignored here; the 2 * 3 width presumably
     # assumes 3-D coordinates as node features -- confirm against forward().
     self.conv1 = EdgeConv(MLP([2 * 3, 64, 64, 64]), aggr)
     self.conv2 = DynamicEdgeConv(MLP([2 * 64, 128]), k, aggr)
     # Input width 128 + 64 presumably matches the conv2 (128) and conv1 (64)
     # outputs concatenated in forward() -- not visible here.
     self.lin1 = MLP([128 + 64, 1024])
     self.pool = global_max_pool
     if dropout:
         self.mlp = Seq(MLP([1024, 512]), Dropout(0.5), MLP([512, 256]),
                        Dropout(0.5), MLP([256, 32]))
     else:
         self.mlp = Seq(MLP([1024, 512]), MLP([512, 256]), MLP([256, 32]))
     # With concat=True, GATConv emits heads * out_channels = 8 * 32 = 256
     # features, which is the input width of self.lin.
     self.lin = Lin(256, n_classes)
     self.conv_glob = GATConv(32, 32, heads=8, dropout=0.5, concat=True)