    def __init__(self,
                 d_in,
                 d_degree,
                 degree_as_tag,
                 retain_features,
                 d_msg,
                 d_up,
                 d_h,
                 seed,
                 activation_name,
                 bn,
                 aggr='add',
                 msg_kind='general',
                 eps=0,
                 train_eps=False,
                 flow='source_to_target',
                 **kwargs):

        super(MPNN_sparse, self).__init__()

        d_msg = d_in if d_msg is None else d_msg

        self.flow = flow
        self.aggr = aggr
        self.msg_kind = msg_kind

        self.degree_as_tag = degree_as_tag
        self.retain_features = retain_features

        if degree_as_tag:
            d_in = d_in + d_degree if retain_features else d_degree

        if msg_kind == 'gin':
            msg_input_dim = None
            self.initial_eps = eps
            if train_eps:
                self.eps = torch.nn.Parameter(torch.Tensor([eps]))
            else:
                self.register_buffer('eps', torch.Tensor([eps]))
            self.eps.data.fill_(self.initial_eps)
            self.msg_fn = None
            update_input_dim = d_in

        elif msg_kind == 'general':
            msg_input_dim = 2 * d_in
            self.msg_fn = mlp(msg_input_dim, d_msg, d_h, seed, activation_name,
                              bn)
            update_input_dim = d_in + d_msg

        else:
            raise NotImplementedError(
                'msg kind {} is not currently supported.'.format(msg_kind))

        self.update_fn = mlp(update_input_dim, d_up, d_h, seed,
                             activation_name, bn)

        return
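# --- Illustrative sketch (not part of the original example) -----------------
# A minimal, self-contained version of the eps handling used in the 'gin'
# branch above (trainable parameter vs. fixed buffer), together with the
# standard GIN-style node update it enables. Class name and forward signature
# are hypothetical, for illustration only.
import torch
import torch.nn as nn


class EpsUpdate(nn.Module):
    def __init__(self, eps=0.0, train_eps=False):
        super().__init__()
        if train_eps:
            # eps is learned jointly with the rest of the model
            self.eps = nn.Parameter(torch.Tensor([eps]))
        else:
            # eps is fixed, but still follows .to(device) and state_dict
            self.register_buffer('eps', torch.Tensor([eps]))

    def forward(self, x_central, aggregated_msgs):
        # GIN update: (1 + eps) * h_v + aggregated neighbour messages
        return (1 + self.eps) * x_central + aggregated_msgs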
    def __init__(self,
                 d_in,
                 d_ef,
                 d_id,
                 d_degree,
                 degree_as_tag,
                 retain_features,
                 id_scope,
                 d_msg,
                 d_up,
                 d_h,
                 seed,
                 activation_name,
                 bn,
                 aggr='add',
                 msg_kind='ogb',
                 eps=0,
                 train_eps=False,
                 flow='source_to_target',
                 **kwargs):

        super(GSN_edge_sparse_ogb, self).__init__()
        
        self.flow = flow
        self.aggr = aggr
        self.msg_kind = msg_kind
        self.id_scope = id_scope
        
        self.degree_as_tag = degree_as_tag
        self.retain_features = retain_features

        if degree_as_tag:
            d_in = d_in + d_degree if retain_features else d_degree

        # INPUT_DIMS
        if msg_kind == 'ogb':
            
            self.initial_eps = eps
            if train_eps:
                self.eps = torch.nn.Parameter(torch.Tensor([eps]))
            else:
                self.register_buffer('eps', torch.Tensor([eps]))
            self.eps.data.fill_(self.initial_eps)
            
            update_input_dim = d_in
            
        else:
            raise NotImplementedError('msg kind {} is not currently supported.'.format(msg_kind))

        self.update_fn = mlp(update_input_dim, d_up, d_h, seed, activation_name, bn)

        return
    def __init__(self,
                 d_in,
                 d_ef,
                 d_id,
                 d_degree,
                 degree_as_tag,
                 retain_features,
                 id_scope,
                 d_msg,
                 d_up,
                 d_h,
                 seed,
                 activation_name,
                 bn,
                 aggr='add',
                 msg_kind='general',
                 eps=0,
                 train_eps=False,
                 flow='source_to_target',
                 **kwargs):

        super(GSN_edge_sparse, self).__init__()

        d_msg = d_in if d_msg is None else d_msg

        self.flow = flow
        self.aggr = aggr
        self.msg_kind = msg_kind
        self.id_scope = id_scope

        self.degree_as_tag = degree_as_tag
        self.retain_features = retain_features

        if degree_as_tag:
            d_in = d_in + d_degree if retain_features else d_degree

        # INPUT_DIMS
        if msg_kind == 'gin':

            # dummy variable for self loop edge features
            self.central_node_edge_encoder = central_encoder(
                kwargs['edge_embedding'], d_ef, extend=kwargs['extend_dims'])
            d_ef = self.central_node_edge_encoder.d_out

            if self.id_scope == 'local':
                # dummy variable for self loop edge counts
                self.central_node_id_encoder = central_encoder(
                    kwargs['id_embedding'], d_id, extend=kwargs['extend_dims'])
                d_id = self.central_node_id_encoder.d_out

            msg_input_dim = None
            self.initial_eps = eps
            if train_eps:
                self.eps = torch.nn.Parameter(torch.Tensor([eps]))
            else:
                self.register_buffer('eps', torch.Tensor([eps]))
            self.eps.data.fill_(self.initial_eps)

            self.msg_fn = None
            update_input_dim = d_in + d_id + d_ef

        elif msg_kind == 'general':
            if id_scope == 'local':
                msg_input_dim = 2 * d_in + d_id + d_ef
            else:
                msg_input_dim = 2 * (d_in + d_id) + d_ef
            # MSG_FUN
            self.msg_fn = mlp(msg_input_dim, d_msg, d_h, seed, activation_name,
                              bn)
            update_input_dim = d_in + d_msg

        else:
            raise NotImplementedError(
                'msg kind {} is not currently supported.'.format(msg_kind))

        self.update_fn = mlp(update_input_dim, d_up, d_h, seed,
                             activation_name, bn)

        return
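# --- Illustrative sketch (not part of the original example) -----------------
# The dimension bookkeeping above (d_in = d_in + d_degree if retain_features
# else d_degree) corresponds to the following feature construction; the
# function name and argument names are hypothetical.
import torch


def apply_degree_as_tag(x, degree_emb, retain_features=True):
    # either use the degree embedding alone as the node representation,
    # or append it to the existing node features
    if not retain_features:
        return degree_emb
    return torch.cat([x, degree_emb], dim=-1)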
    def __init__(self,
                 in_features,
                 out_features,
                 encoder_ids,
                 d_in_id,
                 in_edge_features=None,
                 d_in_node_encoder=None,
                 d_in_edge_encoder=None,
                 encoder_degrees=None,
                 d_degree=None,
                 **kwargs):

        super(GNN_OGB, self).__init__()

        seed = kwargs['seed']

        #-------------- Initializations

        self.model_name = kwargs['model_name']
        self.readout = kwargs['readout'] if kwargs['readout'] is not None else 'sum'
        self.dropout_features = kwargs['dropout_features']
        self.bn = kwargs['bn']
        self.final_projection = kwargs['final_projection']
        self.residual = kwargs['residual']
        self.inject_ids = kwargs['inject_ids']
        self.vn = kwargs['vn']

        id_scope = kwargs['id_scope']
        d_msg = kwargs['d_msg']
        d_out = kwargs['d_out']
        d_h = kwargs['d_h']
        aggr = kwargs['aggr'] if kwargs['aggr'] is not None else 'add'
        flow = kwargs['flow'] if kwargs['flow'] is not None else 'target_to_source'
        msg_kind = kwargs['msg_kind'] if kwargs['msg_kind'] is not None else 'general'
        train_eps = (kwargs['train_eps'] if kwargs['train_eps'] is not None
                     else [False] * len(d_out))
        activation_mlp = kwargs['activation_mlp']
        bn_mlp = kwargs['bn_mlp']
        jk_mlp = kwargs['jk_mlp']
        degree_embedding = (kwargs['degree_embedding']
                            if kwargs['degree_as_tag'][0] else 'None')
        degree_as_tag = kwargs['degree_as_tag']
        retain_features = kwargs['retain_features']

        encoders_kwargs = {
            'seed': seed,
            'activation_mlp': activation_mlp,
            'bn_mlp': bn_mlp,
            'aggr': kwargs['multi_embedding_aggr'],
            'features_scope': kwargs['features_scope']
        }

        #-------------- Input node embedding
        self.input_node_encoder = DiscreteEmbedding(
            kwargs['input_node_encoder'], in_features, d_in_node_encoder,
            kwargs['d_out_node_encoder'], **encoders_kwargs)
        d_in = self.input_node_encoder.d_out

        #-------------- Virtual node embedding
        if self.vn:
            vn_encoder_kwargs = copy.deepcopy(encoders_kwargs)
            vn_encoder_kwargs['init'] = 'zeros'
            self.vn_encoder = DiscreteEmbedding(kwargs['input_vn_encoder'], 1,
                                                [1],
                                                kwargs['d_out_vn_encoder'],
                                                **vn_encoder_kwargs)
            d_in_vn = self.vn_encoder.d_out

        #-------------- Edge embedding (for each GNN layer)
        self.edge_encoder = []
        d_ef = []
        for i in range(len(d_out)):
            edge_encoder_layer = DiscreteEmbedding(
                kwargs['edge_encoder'], in_edge_features, d_in_edge_encoder,
                kwargs['d_out_edge_encoder'][i], **encoders_kwargs)
            self.edge_encoder.append(edge_encoder_layer)
            d_ef.append(edge_encoder_layer.d_out)

        self.edge_encoder = nn.ModuleList(self.edge_encoder)

        # -------------- Identifier embedding (for each GNN layer)
        self.id_encoder = []
        d_id = []
        num_id_encoders = len(d_out) if kwargs['inject_ids'] else 1
        for i in range(num_id_encoders):
            id_encoder_layer = DiscreteEmbedding(kwargs['id_embedding'],
                                                 len(d_in_id), d_in_id,
                                                 kwargs['d_out_id_embedding'],
                                                 **encoders_kwargs)
            self.id_encoder.append(id_encoder_layer)
            d_id.append(id_encoder_layer.d_out)

        self.id_encoder = nn.ModuleList(self.id_encoder)

        #-------------- Degree embedding
        self.degree_encoder = DiscreteEmbedding(
            degree_embedding, 1, d_degree, kwargs['d_out_degree_embedding'],
            **encoders_kwargs)
        d_degree = self.degree_encoder.d_out

        #-------------- GNN layers w/ bn
        self.conv = []
        self.batch_norms = []
        self.mlp_vn = []
        for i in range(len(d_out)):

            if i > 0 and self.vn:
                #-------------- vn msg function
                mlp_vn_temp = mlp(d_in_vn, kwargs['d_out_vn'][i - 1], d_h[i],
                                  seed, activation_mlp, bn_mlp)
                self.mlp_vn.append(mlp_vn_temp)
                d_in_vn = kwargs['d_out_vn'][i - 1]

            kwargs_filter = {
                'd_in': d_in,
                'd_degree': d_degree,
                'degree_as_tag': degree_as_tag[i],
                'retain_features': retain_features[i],
                'd_msg': d_msg[i],
                'd_up': d_out[i],
                'd_h': d_h[i],
                'seed': seed,
                'activation_name': activation_mlp,
                'bn': bn_mlp,
                'aggr': aggr,
                'msg_kind': msg_kind,
                'eps': 0,
                'train_eps': train_eps[i],
                'flow': flow,
                'd_ef': d_ef[i],
                'edge_embedding': kwargs['edge_encoder'],
                'id_embedding': kwargs['id_embedding'],
                'extend_dims': kwargs['extend_dims']
            }

            use_ids = ((i > 0 and kwargs['inject_ids']) or
                       (i == 0)) and (self.model_name == 'GSN_edge_sparse_ogb')

            if use_ids:
                filter_fn = GSN_edge_sparse_ogb
                kwargs_filter['d_id'] = d_id[i] if self.inject_ids else d_id[0]
                kwargs_filter['id_scope'] = id_scope
            else:
                filter_fn = MPNN_edge_sparse_ogb
            self.conv.append(filter_fn(**kwargs_filter))

            bn_layer = nn.BatchNorm1d(d_out[i]) if self.bn[i] else None
            self.batch_norms.append(bn_layer)

            d_in = d_out[i]

        self.conv = nn.ModuleList(self.conv)
        self.batch_norms = nn.ModuleList(self.batch_norms)
        if kwargs['vn']:
            self.mlp_vn = nn.ModuleList(self.mlp_vn)

        #-------------- Readout
        if self.readout == 'sum':
            self.global_pool = global_add_pool_sparse
        elif self.readout == 'mean':
            self.global_pool = global_mean_pool_sparse
        else:
            raise ValueError("Invalid graph pooling type.")

        #-------------- Virtual node aggregation operator
        if self.vn:
            if kwargs['vn_pooling'] == 'sum':
                self.global_vn_pool = global_add_pool_sparse
            elif kwargs['vn_pooling'] == 'mean':
                self.global_vn_pool = global_mean_pool_sparse
            else:
                raise ValueError("Invalid graph virtual node pooling type.")

        self.lin_proj = nn.Linear(d_out[-1], out_features)

        #-------------- Activation fn (same across the network)

        self.activation = choose_activation(kwargs['activation'])

        return
    def __init__(self, encoder_name, d_in_features, d_in_encoder,
                 d_out_encoder, **kwargs):

        super(DiscreteEmbedding, self).__init__()

        #-------------- various different embedding layers
        kwargs['init'] = None if 'init' not in kwargs else kwargs['init']

        self.encoder_name = encoder_name
        # d_in_features: input feature size (e.g. if already one hot encoded),
        # d_in_encoder: number of unique values that will be encoded (size of embedding vocabulary)

        #-------------- fill embedding with zeros
        if encoder_name == 'zero_encoder':
            self.encoder = zero_encoder(d_out_encoder)
            d_out = d_out_encoder

        #-------------- linear projection
        elif encoder_name == 'linear':
            self.encoder = nn.Linear(d_in_features, d_out_encoder, bias=True)
            d_out = d_out_encoder

        #-------------- mlp
        elif encoder_name == 'mlp':
            self.encoder = mlp(d_in_features, d_out_encoder, d_out_encoder,
                               kwargs['seed'], kwargs['activation_mlp'],
                               kwargs['bn_mlp'])
            d_out = d_out_encoder

        #-------------- multi hot encoding of categorical data
        elif encoder_name == 'one_hot_encoder':
            self.encoder = one_hot_encoder(d_in_encoder)
            d_out = sum(d_in_encoder)

        #-------------- embedding of categorical data (linear projection without bias of one hot encodings)
        elif encoder_name == 'embedding':
            self.encoder = multi_embedding(d_in_encoder, d_out_encoder,
                                           kwargs['aggr'], kwargs['init'])
            if kwargs['aggr'] == 'concat':
                d_out = len(d_in_encoder) * d_out_encoder
            else:
                d_out = d_out_encoder

        #-------------- for ogb: multi hot encoding of node features
        elif encoder_name == 'atom_one_hot_encoder':
            full_atom_feature_dims = (get_atom_feature_dims()
                                      if kwargs['features_scope'] == 'full'
                                      else get_atom_feature_dims()[:2])
            self.encoder = one_hot_encoder(full_atom_feature_dims)
            d_out = sum(full_atom_feature_dims)

        #-------------- for ogb: multi hot encoding of edge features
        elif encoder_name == 'bond_one_hot_encoder':
            full_bond_feature_dims = (get_bond_feature_dims()
                                      if kwargs['features_scope'] == 'full'
                                      else get_bond_feature_dims()[:2])
            self.encoder = one_hot_encoder(full_bond_feature_dims)
            d_out = sum(full_bond_feature_dims)

        #-------------- for ogb: embedding of node features
        elif encoder_name == 'atom_encoder':
            self.encoder = AtomEncoder(d_out_encoder)
            d_out = d_out_encoder

        #-------------- for ogb: embedding of edge features
        elif encoder_name == 'bond_encoder':
            self.encoder = BondEncoder(emb_dim=d_out_encoder)
            d_out = d_out_encoder

        #-------------- no embedding, use as is
        elif encoder_name == 'None':
            self.encoder = None
            d_out = d_in_features

        else:
            raise NotImplementedError(
                'Encoder {} is not currently supported.'.format(encoder_name))

        self.d_out = d_out

        return
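# --- Illustrative sketch (not part of the original example) -----------------
# The 'one_hot_encoder' branch above reports d_out = sum(d_in_encoder), i.e.
# each categorical column is one-hot encoded with its own vocabulary size and
# the results are concatenated. A minimal version of that behaviour
# (hypothetical function name; the repository's one_hot_encoder may differ):
import torch
import torch.nn.functional as F


def multi_one_hot(x, vocab_sizes):
    # x: LongTensor of shape [num_items, num_categorical_features]
    # returns a float tensor of shape [num_items, sum(vocab_sizes)]
    return torch.cat(
        [F.one_hot(x[:, i], num_classes=n).float()
         for i, n in enumerate(vocab_sizes)],
        dim=-1)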
    def __init__(self,
                 in_features,
                 out_features,
                 encoder_ids,
                 d_in_id,
                 in_edge_features=None,
                 d_in_node_encoder=None,
                 d_in_edge_encoder=None,
                 encoder_degrees=None,
                 d_degree=None,
                 **kwargs):

        super(MLPSubstructures, self).__init__()

        seed = kwargs['seed']

        #-------------- Initializations

        self.model_name = kwargs['model_name']
        self.readout = kwargs['readout'] if kwargs['readout'] is not None else 'sum'
        self.dropout_features = kwargs['dropout_features']
        self.bn = kwargs['bn']
        self.degree_as_tag = kwargs['degree_as_tag']
        self.retain_features = kwargs['retain_features']
        self.id_scope = kwargs['id_scope']

        d_out = kwargs['d_out']
        d_h = kwargs['d_h']
        activation_mlp = kwargs['activation_mlp']
        bn_mlp = kwargs['bn_mlp']
        jk_mlp = kwargs['jk_mlp']
        degree_embedding = (kwargs['degree_embedding']
                            if kwargs['degree_as_tag'][0] else 'None')

        encoders_kwargs = {
            'seed': seed,
            'activation_mlp': activation_mlp,
            'bn_mlp': bn_mlp,
            'aggr': kwargs['multi_embedding_aggr']
        }

        #-------------- Input node embedding
        self.input_node_encoder = DiscreteEmbedding(
            kwargs['input_node_encoder'], in_features, d_in_node_encoder,
            kwargs['d_out_node_encoder'], **encoders_kwargs)
        d_in = self.input_node_encoder.d_out

        #-------------- Edge embedding
        self.edge_encoder = DiscreteEmbedding(kwargs['edge_encoder'],
                                              in_edge_features,
                                              d_in_edge_encoder,
                                              kwargs['d_out_edge_encoder'][0],
                                              **encoders_kwargs)
        d_ef = self.edge_encoder.d_out

        #-------------- Identifier embedding
        self.id_encoder = DiscreteEmbedding(kwargs['id_embedding'],
                                            len(d_in_id), d_in_id,
                                            kwargs['d_out_id_embedding'],
                                            **encoders_kwargs)
        d_id = self.id_encoder.d_out

        #-------------- Degree embedding
        self.degree_encoder = DiscreteEmbedding(
            degree_embedding, 1, d_degree, kwargs['d_out_degree_embedding'],
            **encoders_kwargs)
        d_degree = self.degree_encoder.d_out

        #-------------- edge-wise MLP w/ bn

        if self.degree_as_tag[0] and self.retain_features[0]:
            mlp_input_dim = 2 * (d_in + d_degree)
        elif self.degree_as_tag[0] and not self.retain_features[0]:
            mlp_input_dim = 2 * d_degree
        else:
            mlp_input_dim = 2 * d_in

        if self.id_scope == 'global':
            mlp_input_dim += 2 * d_id
        else:
            mlp_input_dim += d_id

        if kwargs['edge_encoder'] != 'None':
            mlp_input_dim += d_ef

        filter_fn = mlp(mlp_input_dim, d_out[0], d_h[0], seed, activation_mlp,
                        bn_mlp)
        self.conv = filter_fn
        self.batch_norms = nn.BatchNorm1d(d_out[0]) if self.bn[0] else None

        if jk_mlp:
            final_jk_layer = mlp(d_out[0], out_features, d_h[0], seed,
                                 activation_mlp, bn_mlp)
        else:
            final_jk_layer = nn.Linear(d_out[0], out_features)

        self.lin_proj = final_jk_layer

        #-------------- Readout
        if self.readout == 'sum':
            self.global_pool = global_add_pool_sparse
        elif self.readout == 'mean':
            self.global_pool = global_mean_pool_sparse
        else:
            raise ValueError("Invalid graph pooling type.")

        #-------------- Activation fn (same across the network)
        self.activation = choose_activation(kwargs['activation'])

        return
    def __init__(self,
                 d_in,
                 d_ef,
                 d_degree,
                 degree_as_tag,
                 retain_features,
                 d_msg,
                 d_up,
                 d_h,
                 seed,
                 activation_name,
                 bn,
                 aggr='add',
                 msg_kind='general',
                 eps=0,
                 train_eps=False,
                 flow='source_to_target',
                 **kwargs):

        super(MPNN_edge_sparse, self).__init__()

        d_msg = d_in if d_msg is None else d_msg

        self.flow = flow
        self.aggr = aggr
        self.msg_kind = msg_kind

        self.degree_as_tag = degree_as_tag
        self.retain_features = retain_features

        if degree_as_tag:
            d_in = d_in + d_degree if retain_features else d_degree

        # INPUT_DIMS
        if msg_kind == 'gin':

            # dummy variable for self loop edge features
            self.central_node_edge_encoder = central_encoder(
                kwargs['edge_embedding'], d_ef, extend=kwargs['extend_dims'])
            d_ef = self.central_node_edge_encoder.d_out

            msg_input_dim = None
            self.initial_eps = eps
            if train_eps:
                self.eps = torch.nn.Parameter(torch.Tensor([eps]))
            else:
                self.register_buffer('eps', torch.Tensor([eps]))
            self.eps.data.fill_(self.initial_eps)
            self.msg_fn = None
            update_input_dim = d_in + d_ef

        elif msg_kind == 'general':
            msg_input_dim = 2 * d_in + d_ef
            self.msg_fn = mlp(msg_input_dim, d_msg, d_h, seed, activation_name,
                              bn)
            update_input_dim = d_in + d_msg

        else:
            raise NotImplementedError(
                'msg kind {} is not currently supported.'.format(msg_kind))

        self.update_fn = mlp(update_input_dim, d_up, d_h, seed,
                             activation_name, bn)

        return
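# --- Illustrative sketch (not part of the original example) -----------------
# For msg_kind == 'general' the message MLP above takes an input of size
# 2 * d_in + d_ef, i.e. the features of both endpoints plus the edge
# features. A sketch of that per-edge input (hypothetical function name):
import torch


def general_message_input(x_target, x_source, edge_attr):
    # concatenate receiver features, sender features and edge features
    return torch.cat([x_target, x_source, edge_attr], dim=-1)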
    def __init__(self,
                 in_features,
                 out_features,
                 encoder_ids,
                 d_in_id,
                 in_edge_features=None,
                 d_in_node_encoder=None,
                 d_in_edge_encoder=None,
                 encoder_degrees=None,
                 d_degree=None,
                 **kwargs):

        super(GNNSubstructures, self).__init__()

        seed = kwargs['seed']

        #-------------- Initializations

        self.model_name = kwargs['model_name']
        self.readout = kwargs['readout'] if kwargs['readout'] is not None else 'sum'
        self.dropout_features = kwargs['dropout_features']
        self.bn = kwargs['bn']
        self.final_projection = kwargs['final_projection']
        self.inject_ids = kwargs['inject_ids']
        self.inject_edge_features = kwargs['inject_edge_features']
        self.random_features = kwargs['random_features']

        id_scope = kwargs['id_scope']
        d_msg = kwargs['d_msg']
        d_out = kwargs['d_out']
        d_h = kwargs['d_h']
        aggr = kwargs['aggr'] if kwargs['aggr'] is not None else 'add'
        flow = kwargs['flow'] if kwargs['flow'] is not None else 'target_to_source'
        msg_kind = kwargs['msg_kind'] if kwargs['msg_kind'] is not None else 'general'
        train_eps = (kwargs['train_eps'] if kwargs['train_eps'] is not None
                     else [False] * len(d_out))
        activation_mlp = kwargs['activation_mlp']
        bn_mlp = kwargs['bn_mlp']
        jk_mlp = kwargs['jk_mlp']
        degree_embedding = (kwargs['degree_embedding']
                            if kwargs['degree_as_tag'][0] else 'None')
        degree_as_tag = kwargs['degree_as_tag']
        retain_features = kwargs['retain_features']

        encoders_kwargs = {
            'seed': seed,
            'activation_mlp': activation_mlp,
            'bn_mlp': bn_mlp,
            'aggr': kwargs['multi_embedding_aggr']
        }

        #-------------- Input node embedding
        self.input_node_encoder = DiscreteEmbedding(
            kwargs['input_node_encoder'], in_features, d_in_node_encoder,
            kwargs['d_out_node_encoder'], **encoders_kwargs)
        d_in = self.input_node_encoder.d_out
        if self.random_features:
            self.r_d_out = d_out[0]
            d_in = d_in + self.r_d_out

        #-------------- Edge embedding (for each GNN layer)
        self.edge_encoder = []
        d_ef = []
        num_edge_encoders = len(d_out) if kwargs['inject_edge_features'] else 1
        for i in range(num_edge_encoders):
            edge_encoder_layer = DiscreteEmbedding(
                kwargs['edge_encoder'], in_edge_features, d_in_edge_encoder,
                kwargs['d_out_edge_encoder'][i], **encoders_kwargs)
            self.edge_encoder.append(edge_encoder_layer)
            d_ef.append(edge_encoder_layer.d_out)

        self.edge_encoder = nn.ModuleList(self.edge_encoder)

        #-------------- Identifier embedding (for each GNN layer)
        self.id_encoder = []
        d_id = []
        num_id_encoders = len(d_out) if kwargs['inject_ids'] else 1
        for i in range(num_id_encoders):
            id_encoder_layer = DiscreteEmbedding(kwargs['id_embedding'],
                                                 len(d_in_id), d_in_id,
                                                 kwargs['d_out_id_embedding'],
                                                 **encoders_kwargs)
            self.id_encoder.append(id_encoder_layer)
            d_id.append(id_encoder_layer.d_out)

        self.id_encoder = nn.ModuleList(self.id_encoder)

        #-------------- Degree embedding
        self.degree_encoder = DiscreteEmbedding(
            degree_embedding, 1, d_degree, kwargs['d_out_degree_embedding'],
            **encoders_kwargs)
        d_degree = self.degree_encoder.d_out

        #-------------- GNN layers w/ bn and jk
        self.conv = []
        self.batch_norms = []
        self.lin_proj = []
        for i in range(len(d_out)):

            kwargs_filter = {
                'd_in': d_in,
                'd_degree': d_degree,
                'degree_as_tag': degree_as_tag[i],
                'retain_features': retain_features[i],
                'd_msg': d_msg[i],
                'd_up': d_out[i],
                'd_h': d_h[i],
                'd_ef': d_ef[i] if self.inject_edge_features else d_ef[0],
                'seed': seed,
                'activation_name': activation_mlp,
                'bn': bn_mlp,
                'aggr': aggr,
                'msg_kind': msg_kind,
                'eps': 0,
                'train_eps': train_eps[i],
                'flow': flow,
                'edge_embedding': kwargs['edge_encoder'],
                'id_embedding': kwargs['id_embedding'],
                'extend_dims': kwargs['extend_dims']
            }

            use_ids = ((i > 0 and kwargs['inject_ids']) or
                       (i == 0)) and (self.model_name
                                      in {'GSN_sparse', 'GSN_edge_sparse'})
            use_efs = ((i > 0 and kwargs['inject_edge_features']) or
                       (i == 0)) and (self.model_name in {
                           'GSN_edge_sparse', 'MPNN_edge_sparse'
                       })
            if use_ids:
                filter_fn = GSN_edge_sparse if use_efs else GSN_sparse
                kwargs_filter['d_id'] = d_id[i] if self.inject_ids else d_id[0]
                kwargs_filter['id_scope'] = id_scope
            else:
                filter_fn = MPNN_edge_sparse if use_efs else MPNN_sparse
            self.conv.append(filter_fn(**kwargs_filter))

            if self.final_projection[i]:
                # if desired, jk projections can be performed
                # by an mlp instead of a simple linear layer;
                if jk_mlp:
                    jk_layer = mlp(d_in, out_features, d_h[i], seed,
                                   activation_mlp, bn_mlp)
                else:
                    jk_layer = nn.Linear(d_in, out_features)
            else:
                jk_layer = None
            self.lin_proj.append(jk_layer)

            bn_layer = nn.BatchNorm1d(d_out[i]) if self.bn[i] else None
            self.batch_norms.append(bn_layer)

            d_in = d_out[i]

        if self.final_projection[-1]:
            # if desired, jk projections can be performed
            # by an mlp instead of a simple linear layer;
            if jk_mlp:
                final_jk_layer = mlp(d_in, out_features, d_h[-1], seed,
                                     activation_mlp, bn_mlp)
            else:
                final_jk_layer = nn.Linear(d_in, out_features)
        else:
            final_jk_layer = None
        self.lin_proj.append(final_jk_layer)

        self.conv = nn.ModuleList(self.conv)
        self.lin_proj = nn.ModuleList(self.lin_proj)
        self.batch_norms = nn.ModuleList(self.batch_norms)

        #-------------- Readout
        if self.readout == 'sum':
            self.global_pool = global_add_pool_sparse
        elif self.readout == 'mean':
            self.global_pool = global_mean_pool_sparse
        else:
            raise ValueError("Invalid graph pooling type.")

        #-------------- Activation fn (same across the network)
        self.activation = choose_activation(kwargs['activation'])

        return
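# --- Illustrative sketch (not part of the original example) -----------------
# The 'sum' / 'mean' readouts selected above pool node embeddings per graph.
# A dense sketch of sum readout under the usual PyG-style batch vector;
# the repository's global_add_pool_sparse may be implemented differently.
import torch


def global_add_pool_sketch(x, batch, num_graphs):
    # x: [num_nodes, d] node embeddings,
    # batch: [num_nodes] LongTensor with the graph index of every node
    out = x.new_zeros(num_graphs, x.size(-1))
    return out.index_add(0, batch, x)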