Example #1
    def set_last_mlp(self, last_mlp_opt):

        if len(last_mlp_opt.nn) > 2:
            self.FC_layer = MLP(last_mlp_opt.nn[:-1])
            self.FC_layer.add_module(
                "last", Lin(last_mlp_opt.nn[-2], last_mlp_opt.nn[-1]))
        elif len(last_mlp_opt.nn) == 2:
            self.FC_layer = Seq(Lin(last_mlp_opt.nn[-2], last_mlp_opt.nn[-1]))
        else:
            self.FC_layer = torch.nn.Identity()
Example #2
 def __init__(self, input_size, n_classes, k_graph=False):
     super(GAT, self).__init__()
     self.conv1 = GATConv(input_size, 64, heads=4, dropout=0.)
     self.conv2 = GATConv(256, 256, heads=3, dropout=0.)
     #self.conv3 = GATConv(1024, 256, heads=4, dropout=0., concat=False)
     self.pool = global_max_pool
     self.mlp = Seq(MLP([1024, 512]), Dropout(0.2), MLP([512, 256]),
                    Dropout(0.2), Lin(256, n_classes))
     #self.mlp = Lin(256,n_classes)
     self.k_graph = k_graph
Example #3
    def __init__(self, out_channels, k=5, aggr='max'):
        super().__init__()

        self.conv1 = DynamicEdgeConv(MLP([2 * 3, 64, 64, 64]), k, aggr)
        self.conv2 = DynamicEdgeConv(MLP([2 * 64, 128]), k, aggr)
        self.lin1 = MLP([128 + 64, 1024])

        self.mlp = Seq(
            MLP([1024, 512]), Dropout(0.5), MLP([512, 256]), Dropout(0.5),
            Lin(256, out_channels))
Example #4
    def __init__(self, input_size, embedding_size, n_classes, dropout=True, k=5, aggr='max', pool_op='max'):
        super(DECSeq, self).__init__()
        # self.bn0 = BN(input_size)
        # self.bn1 = BN(64)
        # self.bn2 = BN(128)
        self.conv1 = EdgeConv(MLP([2 * 3, 64, 64, 64], batch_norm=True), aggr)
        self.conv2 = DynamicEdgeConv(MLP([2 * 64, 128], batch_norm=True), k, aggr)
        self.lin1 = MLP([128 + 64, 1024])
        if pool_op == 'max':
            self.pool = global_max_pool

        if dropout:
            self.mlp = Seq(
                MLP([1024, 512], batch_norm=True), Dropout(0.5), MLP([512, 256], batch_norm=True), Dropout(0.5),
                Lin(256, n_classes))
        else:
            self.mlp = Seq(
                MLP([1024, 512]), MLP([512, 256]),
                Lin(256, n_classes))
Example #5
    def __init__(self, option, model_type, dataset, modules):
        # Extract parameters from the dataset
        # Assemble encoder / decoder
        UnwrappedUnetBasedModel.__init__(self, option, model_type, dataset,
                                         modules)

        # Build final MLP
        last_mlp_opt = option.mlp_cls

        self.out_channels = option.out_channels
        in_feat = last_mlp_opt.nn[0]
        self.FC_layer = Seq()
        for i in range(1, len(last_mlp_opt.nn)):
            self.FC_layer.add_module(
                str(i),
                Seq(*[
                    Lin(in_feat, last_mlp_opt.nn[i], bias=False),
                    FastBatchNorm1d(last_mlp_opt.nn[i],
                                    momentum=last_mlp_opt.bn_momentum),
                    LeakyReLU(0.2),
                ]),
            )
            in_feat = last_mlp_opt.nn[i]

        if last_mlp_opt.dropout:
            self.FC_layer.add_module("Dropout",
                                     Dropout(p=last_mlp_opt.dropout))

        self.FC_layer.add_module("Last",
                                 Lin(in_feat, self.out_channels, bias=False))
        self.mode = option.loss_mode
        self.normalize_feature = option.normalize_feature
        self.loss_names = ["loss_reg", "loss"]

        self.lambda_reg = self.get_from_opt(option,
                                            ["loss_weights", "lambda_reg"])
        if self.lambda_reg:
            self.loss_names += ["loss_regul"]

        self.lambda_internal_losses = self.get_from_opt(
            option, ["loss_weights", "lambda_internal_losses"])

        self.visual_names = ["data_visual"]
Example #6
    def __init__(self, opt):
        super(DenseDeepGCN, self).__init__()
        channels = opt.n_filters
        k = opt.kernel_size
        act = opt.act
        norm = opt.norm
        bias = opt.bias
        epsilon = opt.epsilon
        stochastic = opt.stochastic
        conv = opt.conv
        c_growth = channels
        self.n_blocks = opt.n_blocks

        self.knn = DenseDilatedKnnGraph(k, 1, stochastic, epsilon)
        self.head = GraphConv2d(opt.in_channels, channels, conv, act, norm,
                                bias)

        if opt.block.lower() == 'res':
            self.backbone = Seq(*[
                ResDynBlock2d(channels, k, 1 +
                              i, conv, act, norm, bias, stochastic, epsilon)
                for i in range(self.n_blocks - 1)
            ])
        elif opt.block.lower() == 'dense':
            self.backbone = Seq(*[
                DenseDynBlock2d(channels + c_growth * i, c_growth, k, 1 +
                                i, conv, act, norm, bias, stochastic, epsilon)
                for i in range(self.n_blocks - 1)
            ])
        else:
            raise NotImplementedError(
                '{} is not implemented. Please check.\n'.format(opt.block))
        self.fusion_block = BasicConv(
            [channels + c_growth * (self.n_blocks - 1), 1024], act, norm, bias)
        self.prediction = Seq(*[
            BasicConv([channels + c_growth *
                       (self.n_blocks - 1) + 1024, 512], act, norm, bias),
            BasicConv([512, 256], act, norm, bias),
            torch.nn.Dropout(p=opt.dropout),
            BasicConv([256, opt.n_classes], None, None, bias)
        ])

        self.model_init()
Example #7
 def __init__(self, input_size, n_classes=2, embedding_size=128, hidden_size=256):
     super(BiLSTM, self).__init__()
     self.emb_size = embedding_size
     self.h_size = hidden_size
     self.mlp = MLP([input_size, embedding_size])
     self.lstm = nn.LSTM(embedding_size, hidden_size,
                         bidirectional=True, batch_first=True)
     self.lin = Seq(
         MLP([hidden_size, 256]), Dropout(0.5),
         MLP([256, 128]), Dropout(0.5),
         nn.Linear(128, n_classes))
Example #8
    def __init__(self, in_features, out_features, k):
        super(DirectionalEdgeConv, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.k = k

        self.mlp = Seq(Linear(3+2*in_features, out_features*2),
                       ReLU(),
                       Linear(out_features*2, out_features))

        self.econv = EdgeConv(self.mlp, aggr='max')
Example #9
 def __init__(self, n_features, n_outputs, dim=95):
     super(GINAttNet, self).__init__()
     # Preparation of the Graph Isomorphism Convolutional Layer
     nn1 = Seq(Linear(n_features, dim), ReLU(), Linear(dim, dim))
     self.conv1 = GINConv(nn1)
     self.bn1 = torch.nn.BatchNorm1d(dim)
     # Preparation of the Attention Layer
     self.conv2 = GATConv(dim, dim, heads=1, dropout=0.3)
     # Preparation of the Fully Connected Layer
     self.fc1 = Linear(dim, 2 * dim)
     self.fc2 = Linear(2 * dim, n_outputs)
Example #10
 def __init__(self, config):
     NodeModelBase.__init__(self, config)
     self.dim_lh = config['node_model_latent_mlp_hidden_size']
     self.dim_l = config['l_outc']
     self.node_mlp_2 = Seq(Linear(self.mlp2_inc, self.mlp2_hs1),
                           LayerNorm(self.mlp2_hs1), ReLU(),
                           Linear(self.mlp2_hs1, self.dim_lh),
                           LayerNorm(self.dim_lh), ReLU())
     self.mlp_m = Linear(self.dim_lh, self.dim_l)
     self.mlp_v = Linear(self.dim_lh, self.dim_l)
     self.mlp_x = Linear(self.dim_l, self.dim_out)
Example #11
 def __init__(self, dof=6, act="Sigmoid", dropout=0.5, graph_input=False, transform=None):
     super(Net, self).__init__()
     self.graph_input = graph_input
     self.transform = transform
     self.pointnet = PointNetXX(act=act, dropout=dropout)
     self.fcs = Seq(
         eval(act)(), Lin(64 * 2, 64), Dropout(dropout),
         eval(act)(), Lin(64, dof)
     )
     self.sx = torch.nn.Parameter(torch.tensor(-2.5), requires_grad=True)
     self.sq = torch.nn.Parameter(torch.tensor(-2.5), requires_grad=True)
Example #12
def MLP(channels, enable_group_norm=True):
    if enable_group_norm:
        num_groups = [0]
        for i in range(1, len(channels)):
            if channels[i] >= 32:
                num_groups.append(channels[i] // 32)
            else:
                num_groups.append(1)
        return Seq(*[
            Seq(torch.nn.utils.weight_norm(Lin(channels[i - 1], channels[i])),
                LeakyReLU(
                    negative_slope=0.2), GroupNorm(num_groups[i], channels[i]))
            for i in range(1, len(channels))
        ])
    else:
        return Seq(*[
            Seq(torch.nn.utils.weight_norm(Lin(channels[i - 1], channels[i])),
                LeakyReLU(negative_slope=0.2))
            for i in range(1, len(channels))
        ])
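A minimal usage sketch for the factory above (illustrative, not from the source): it assumes the usual aliases Seq = torch.nn.Sequential and Lin = torch.nn.Linear plus LeakyReLU and GroupNorm from torch.nn; the channel widths are arbitrary.

import torch
from torch.nn import Sequential as Seq, Linear as Lin, LeakyReLU, GroupNorm

# Illustrative widths: a 3 -> 64 -> 128 MLP. With group norm enabled, each block
# is weight-normalized Linear -> LeakyReLU(0.2) -> GroupNorm, where the number
# of groups is channels // 32 (or 1 for widths below 32).
mlp = MLP([3, 64, 128], enable_group_norm=True)
x = torch.randn(10, 3)   # 10 samples, 3 features each
out = mlp(x)             # -> shape (10, 128)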
Example #13
    def __init__(self, input_size, embedding_size, n_classes, aggr='max', k=5, pool_op='max', same_size=False):
        super(ECnet, self).__init__()
        self.conv1 = EdgeConv(MLP([2 * 3, 64, 64, 64]), aggr)
        self.conv2 = EdgeConv(MLP([2 * 64, 128]), aggr)
        self.lin1 = MLP([128 + 64, 1024])
        if pool_op == 'max':
            self.pool = global_max_pool

        self.mlp = Seq(
            MLP([1024, 512]), Dropout(0.5), MLP([512, 256]), Dropout(0.5),
            Lin(256, n_classes))
Example #14
    def __init__(self):
        super(EdgeBlock, self).__init__()
        # A sequential container. Modules will be added to it in the order they
        # are passed in the constructor. Alternatively, an ordered dict of
        # modules can also be passed in.

        # Applies a linear transformation to the incoming data: y = xA^T + b
        self.edge_mlp = Seq(
            Lin(48 * 2, 128),  # changed 2 to 6
            BatchNorm1d(128),
            ReLU(),
            Lin(128, 128))
Example #15
def test_point_conv():
    in_channels, out_channels = (16, 32)
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
    num_nodes = edge_index.max().item() + 1
    x = torch.randn((num_nodes, in_channels))
    pos = torch.rand((num_nodes, 3))
    norm = torch.nn.functional.normalize(torch.rand((num_nodes, 3)), dim=1)

    local_nn = Seq(Lin(in_channels + 4, 32), ReLU(), Lin(32, out_channels))
    global_nn = Seq(Lin(out_channels, out_channels))
    conv = PPFConv(local_nn, global_nn)
    assert conv.__repr__() == (
        'PPFConv(local_nn=Sequential(\n'
        '  (0): Linear(in_features=20, out_features=32, bias=True)\n'
        '  (1): ReLU()\n'
        '  (2): Linear(in_features=32, out_features=32, bias=True)\n'
        '), global_nn=Sequential(\n'
        '  (0): Linear(in_features=32, out_features=32, bias=True)\n'
        '))')
    assert conv(x, pos, norm, edge_index).size() == (num_nodes, out_channels)
Example #16
    def __init__(self):
        super(GlobalModel_ONE, self).__init__()
        hidden = HIDDEN_GRAPH_ONE
        in_channels = ENCODING_NODE_1 + ENCODING_EDGE_1

        self.global_mlp = Seq(
            Lin(in_channels, hidden),
            LeakyReLU(),
            LayerNorm(hidden),
            #Lin(hidden, hidden), LeakyReLU(),
            Lin(hidden, NO_GRAPH_FEATURES_ONE)).apply(init_weights)
Example #17
def test_reset():
    nn = Lin(16, 16)
    w = nn.weight.clone()
    reset(nn)
    assert not nn.weight.tolist() == w.tolist()

    nn = Seq(Lin(16, 16), ReLU(), Lin(16, 16))
    w_1, w_2 = nn[0].weight.clone(), nn[2].weight.clone()
    reset(nn)
    assert not nn[0].weight.tolist() == w_1.tolist()
    assert not nn[2].weight.tolist() == w_2.tolist()
Example #18
 def __init__(self,
              in_channels,
              n_growth=32,
              kernel_size=3,
              bias=True,
              act='relu',
              norm=True,
              res_scale=1.,
              n_layers=3):
     super(DenseBlock, self).__init__()
     self.res_scale = res_scale
     convs = Seq(*[
         DenseConv(in_channels +
                   i * n_growth, n_growth, kernel_size, bias, act, norm)
         for i in range(n_layers - 1)
     ])
     tail = Conv(in_channels + (n_layers - 1) * n_growth, in_channels,
                 kernel_size, bias, act, norm)
     self.body = Seq(*[convs, tail])
     self.n_layers = n_layers
Example #19
    def __init__(self, node_channels, edge_channels, out_channels, aggr='add'):
        super(NNConvLayer, self).__init__()
        self.node_channels = node_channels
        self.edge_channels = edge_channels
        self.out_channels = out_channels

        self.aggr = aggr
        self.mlp_edge = Seq(nn.Linear(self.edge_channels, out_channels), nn.BatchNorm1d(out_channels), nn.ReLU(),
                            nn.Linear(out_channels, node_channels * out_channels))
        self.mlp_output = nn.ReLU()
        self.nn_conv = NNConv(self.node_channels, self.out_channels, self.mlp_edge, aggr=self.aggr)
Example #20
def test_nn_conv():
    in_channels, out_channels = (16, 32)
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
    num_nodes = edge_index.max().item() + 1
    x = torch.randn((num_nodes, in_channels))
    pseudo = torch.rand((edge_index.size(1), 3))

    nn = Seq(Lin(3, 32), ReLU(), Lin(32, in_channels * out_channels))
    conv = NNConv(in_channels, out_channels, nn)
    assert conv.__repr__() == 'NNConv(16, 32)'
    assert conv(x, edge_index, pseudo).size() == (num_nodes, out_channels)
Example #21
def test_gine_conv():
    x1 = torch.randn(4, 16)
    x2 = torch.randn(2, 16)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    row, col = edge_index
    value = torch.randn(row.size(0), 16)
    adj = SparseTensor(row=row, col=col, value=value, sparse_sizes=(4, 4))

    nn = Seq(Lin(16, 32), ReLU(), Lin(32, 32))
    conv = GINEConv(nn, train_eps=True)
    assert conv.__repr__() == (
        'GINEConv(nn=Sequential(\n'
        '  (0): Linear(in_features=16, out_features=32, bias=True)\n'
        '  (1): ReLU()\n'
        '  (2): Linear(in_features=32, out_features=32, bias=True)\n'
        '))')
    out = conv(x1, edge_index, value)
    assert out.size() == (4, 32)
    assert conv(x1, edge_index, value, size=(4, 4)).tolist() == out.tolist()
    assert conv(x1, adj.t()).tolist() == out.tolist()

    if is_full_test():
        t = '(Tensor, Tensor, OptTensor, Size) -> Tensor'
        jit = torch.jit.script(conv.jittable(t))
        assert jit(x1, edge_index, value).tolist() == out.tolist()
        assert jit(x1, edge_index, value, size=(4, 4)).tolist() == out.tolist()

        t = '(Tensor, SparseTensor, OptTensor, Size) -> Tensor'
        jit = torch.jit.script(conv.jittable(t))
        assert jit(x1, adj.t()).tolist() == out.tolist()

    adj = adj.sparse_resize((4, 2))
    out1 = conv((x1, x2), edge_index, value)
    out2 = conv((x1, None), edge_index, value, (4, 2))
    assert out1.size() == (2, 32)
    assert out2.size() == (2, 32)
    assert conv((x1, x2), edge_index, value, (4, 2)).tolist() == out1.tolist()
    assert conv((x1, x2), adj.t()).tolist() == out1.tolist()
    assert conv((x1, None), adj.t()).tolist() == out2.tolist()

    if is_full_test():
        t = '(OptPairTensor, Tensor, OptTensor, Size) -> Tensor'
        jit = torch.jit.script(conv.jittable(t))
        assert jit((x1, x2), edge_index, value).tolist() == out1.tolist()
        assert jit((x1, x2), edge_index, value,
                   size=(4, 2)).tolist() == out1.tolist()
        assert jit((x1, None), edge_index, value,
                   size=(4, 2)).tolist() == out2.tolist()

        t = '(OptPairTensor, SparseTensor, OptTensor, Size) -> Tensor'
        jit = torch.jit.script(conv.jittable(t))
        assert jit((x1, x2), adj.t()).tolist() == out1.tolist()
        assert jit((x1, None), adj.t()).tolist() == out2.tolist()
Example #22
 def __init__(self, config):
     EdgeModelBase.__init__(self, config)
     self.dim_x1 = self.dim_in
     self.dim_lh = config['edge_model_latent_mlp_hidden_size']
     self.dim_l = config['l_outc']
     self.mlp_h = Seq(Linear(self.dim_x1, self.dim_h1),
                      LayerNorm(self.dim_h1), ReLU(),
                      Linear(self.dim_h1, self.dim_lh),
                      LayerNorm(self.dim_lh), ReLU())
     self.mlp_m = Linear(self.dim_lh, self.dim_l)
     self.mlp_v = Linear(self.dim_lh, self.dim_l)
     self.mlp_w = Linear(self.dim_l, self.dim_out)
Example #23
def MLP(channels, batch_norm=True):
    """
    Return an MLP with layer widths given by 'channels'.
    """
    mlps = []
    for i in range(1, len(channels)):
        mlp = nn.Conv2d(in_channels=channels[i - 1],
                        out_channels=channels[i],
                        kernel_size=(1, 1),
                        stride=(1, 1),
                        bias=True)
        nn.init.kaiming_normal_(mlp.weight)
        mlps.append(mlp)
    if batch_norm:
        mlp = Seq(*[
            Seq(mlps[i - 1], BN(channels[i]), ReLU())
            for i in range(1, len(channels))
        ])
    else:
        mlp = Seq(*[Seq(mlps[i - 1], ReLU()) for i in range(1, len(channels))])
    return mlp
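A hedged usage sketch for the 1x1-convolution variant above (illustrative, not from the source), assuming Seq, BN, and ReLU are the usual aliases for torch.nn.Sequential, torch.nn.BatchNorm2d, and torch.nn.ReLU. Because the layers are nn.Conv2d with 1x1 kernels, inputs must be 4-D, e.g. a point-cloud batch laid out as (batch, channels, num_points, 1).

import torch
import torch.nn as nn
from torch.nn import Sequential as Seq, BatchNorm2d as BN, ReLU

mlp = MLP([3, 64, 128])            # 3 input channels -> 128 output channels
pts = torch.randn(8, 3, 1024, 1)   # (batch, channels, num_points, 1)
out = mlp(pts)                     # -> shape (8, 128, 1024, 1)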
Example #24
    def __init__(self, config):
        super(EdgeModel, self).__init__()

        in_channels = config['e_inc'] + 2*config['n_inc'] + config['u_inc']
        hs1 = config['edge_model_mlp1_hidden_sizes'][0]
        hs2 = config['edge_model_mlp1_hidden_sizes'][1]

        self.edge_mlp = Seq(Linear(in_channels, hs1), 
                            ReLU(), 
                            Linear(hs1, hs2),
                            ReLU(),
                            Linear(hs2, config['e_outc']))
Example #25
def make_mlp(in_channels, mlp_channels, batch_norm=True):
    assert len(mlp_channels) >= 1
    layers = []

    for c in mlp_channels:
        layers += [Lin(in_channels, c)]
        if batch_norm: layers += [BatchNorm1d(c)]
        layers += [ReLU()]

        in_channels = c

    return Seq(*layers)
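A small sketch of calling make_mlp (sizes are illustrative; Lin, BatchNorm1d, ReLU, and Seq are assumed to be the usual torch.nn layers and aliases):

import torch
from torch.nn import Linear as Lin, BatchNorm1d, ReLU, Sequential as Seq

net = make_mlp(16, [64, 32])   # Lin(16, 64)-BN-ReLU, then Lin(64, 32)-BN-ReLU
x = torch.randn(5, 16)
out = net(x)                   # -> shape (5, 32)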
Example #26
    def __init__(self, in_channels, out_channels, dim_model, k=16):
        super().__init__()
        self.k = k

        # a dummy feature is created if none is given
        in_channels = max(in_channels, 1)

        # first block
        self.mlp_input = MLP([in_channels, dim_model[0]])

        self.transformer_input = TransformerBlock(
            in_channels=dim_model[0],
            out_channels=dim_model[0],
        )

        # backbone layers
        self.transformers_up = torch.nn.ModuleList()
        self.transformers_down = torch.nn.ModuleList()
        self.transition_up = torch.nn.ModuleList()
        self.transition_down = torch.nn.ModuleList()

        for i in range(0, len(dim_model) - 1):

            # Add Transition Down block followed by a Point Transformer block
            self.transition_down.append(
                TransitionDown(in_channels=dim_model[i],
                               out_channels=dim_model[i + 1],
                               k=self.k))

            self.transformers_down.append(
                TransformerBlock(in_channels=dim_model[i + 1],
                                 out_channels=dim_model[i + 1]))

            # Add Transition Up block followed by Point Transformer block
            self.transition_up.append(
                TransitionUp(in_channels=dim_model[i + 1],
                             out_channels=dim_model[i]))

            self.transformers_up.append(
                TransformerBlock(in_channels=dim_model[i],
                                 out_channels=dim_model[i]))

        # summit layers
        self.mlp_summit = MLP([dim_model[-1], dim_model[-1]], batch_norm=False)

        self.transformer_summit = TransformerBlock(
            in_channels=dim_model[-1],
            out_channels=dim_model[-1],
        )

        # class score computation
        self.mlp_output = Seq(Lin(dim_model[0], 64), ReLU(), Lin(64, 64),
                              ReLU(), Lin(64, out_channels))
Example #27
def test_dynamic_edge_conv_conv():
    x = torch.randn((4, 16))

    nn = Seq(Lin(32, 16), ReLU(), Lin(16, 32))
    conv = DynamicEdgeConv(nn, k=6)
    assert conv.__repr__() == (
        'DynamicEdgeConv(nn=Sequential(\n'
        '  (0): Linear(in_features=32, out_features=16, bias=True)\n'
        '  (1): ReLU()\n'
        '  (2): Linear(in_features=16, out_features=32, bias=True)\n'
        '), k=6)')
    assert conv(x).size() == (4, 32)
Example #28
    def __init__(self):
        super(Net, self).__init__()

        self.edge_mlp = Seq(Linear(128, 256), ReLU(), Linear(256, 128))
        self.node_mlp = Seq(Linear(128, 256), ReLU(), Linear(256, 128))
        self.global_mlp = Seq(Linear(128, 256), ReLU(), Linear(256, 128))

        def edge_model(source, target, edge_attr, u):
            out = torch.cat([source, target, edge_attr], dim=1)
            return self.edge_mlp(out)

        def node_model(x, edge_index, edge_attr, u):
            row, col = edge_index
            out = torch.cat([x[col], edge_attr], dim=1)
            out = self.node_mlp(out)
            return scatter_mean(out, row, dim=0, dim_size=x.size(0))

        def global_model(x, edge_index, edge_attr, u, batch):
            out = torch.cat([u, scatter_mean(x, batch, dim=0)], dim=1)
            return self.global_mlp(out)

        self.op = MetaLayer(edge_model, node_model, global_model)
Example #29
    def __init__(self, k=30, emb_dims=1024):
        super(GlobalFeat, self).__init__()
        self.k = k
        self.conv1 = Seq(nn.Conv2d(6, 64, kernel_size=1, bias=False),
                         nn.BatchNorm2d(64),
                         nn.LeakyReLU(negative_slope=0.2))

        self.conv2 = Seq(nn.Conv2d(128, 64, kernel_size=1, bias=False),
                         nn.BatchNorm2d(64),
                         nn.LeakyReLU(negative_slope=0.2))

        self.conv3 = Seq(nn.Conv2d(128, 128, kernel_size=1, bias=False),
                         nn.BatchNorm2d(128),
                         nn.LeakyReLU(negative_slope=0.2))

        self.conv4 = Seq(nn.Conv2d(256, 256, kernel_size=1, bias=False),
                         nn.BatchNorm2d(256),
                         nn.LeakyReLU(negative_slope=0.2))

        self.conv5 = Seq(nn.Conv1d(512, emb_dims, kernel_size=1, bias=False),
                         nn.BatchNorm1d(emb_dims),
                         nn.LeakyReLU(negative_slope=0.2))
Example #30
 def __init__(self,
              n_node_features,
              n_edge_features,
              n_global_features,
              n_hiddens,
              n_targets,
              use_batch_norm=False):
     super().__init__()
     if use_batch_norm:
         self.edge_mlp = Seq(
             Lin(2 * n_node_features + n_edge_features + n_global_features,
                 n_hiddens), LeakyReLU(), Lin(n_hiddens, n_hiddens),
             LeakyReLU(), Lin(n_hiddens, n_targets), BatchNorm1d(n_targets))
     else:
         self.edge_mlp = Seq(
             Lin(2 * n_node_features + n_edge_features + n_global_features,
                 n_hiddens),
             LeakyReLU(),
             Lin(n_hiddens, n_hiddens),
             LeakyReLU(),
             Lin(n_hiddens, n_targets),
         )
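The input width 2 * n_node_features + n_edge_features + n_global_features suggests that, per edge, the source and target node features are concatenated with the edge and (broadcast) global features before the MLP. Below is a minimal, hypothetical forward sketch under that assumption; the names src, dst, edge_attr, and u are illustrative and not from the source.

 def forward(self, src, dst, edge_attr, u):
     # src, dst: [num_edges, n_node_features]; edge_attr: [num_edges, n_edge_features]
     # u: [num_edges, n_global_features], globals already gathered per edge
     out = torch.cat([src, dst, edge_attr, u], dim=1)
     return self.edge_mlp(out)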