def __init__(
    self,
    batch_size=100,
    hidden=100,
    lr=0.001,
    layers=3,
    dropout=0.5,
    virtual_node=False,
    conv_radius=3,
    plus=False,
    appnp=False,
    out_dim=1,
):
    """GINE/GIN+ molecule network: stacked gine_layer convs + linear readout.

    batch_size and lr are stored for the surrounding training loop; `plus`
    switches the conv type to "gin+", `conv_radius` caps the per-layer
    receptive field k, and `appnp` appends an APPNP propagation step.
    """
    super().__init__()
    self.hidden = hidden
    self.lr = lr
    self.batch_size = batch_size
    self.k = conv_radius
    self.atomencoder = AtomEncoder(hidden)
    self.conv_type = "gin+" if plus else "gine"
    # All but the last layer: k grows with depth, capped at conv_radius.
    conv_blocks = [
        gine_layer(
            hidden,
            dropout=dropout,
            virtual_node=virtual_node,
            k=min(depth + 1, self.k),
            conv_type=self.conv_type,
            edge_embedding=BondEncoder(emb_dim=hidden),
        )
        for depth in range(layers - 1)
    ]
    # Final layer reads the virtual node but does not update it.
    conv_blocks.append(
        gine_layer(
            hidden,
            dropout=dropout,
            virtual_node=virtual_node,
            virtual_node_agg=False,
            last_layer=True,
            k=min(layers, self.k),
            conv_type=self.conv_type,
            edge_embedding=BondEncoder(emb_dim=hidden),
        )
    )
    self.convs = conv_blocks
    self.main = nn.Sequential(*conv_blocks)
    self.readout = nn.Linear(hidden, out_dim)
    self.virtual_node = virtual_node
    if self.virtual_node:
        # Learned initial state for the virtual node.
        self.v0 = nn.Parameter(torch.zeros(1, hidden), requires_grad=True)
    self.appnp = APPNP(0.8, 5) if appnp else None
    # Loss for (multi-)binary classification targets.
    self.loss_fn = nn.BCEWithLogitsLoss()
def __init__(self, emb_dim, aggr='mean', device='cuda'):
    """GCN convolution with a learnable LAF (scatter) aggregator.

    emb_dim: node/edge embedding size; aggr: LAF aggregation function name.
    """
    super(GCNLafConv, self).__init__()
    self.linear = torch.nn.Linear(emb_dim, emb_dim)
    # Single embedding used as the root/self-loop representation.
    self.root_emb = torch.nn.Embedding(1, emb_dim)
    self.bond_encoder = BondEncoder(emb_dim=emb_dim)
    # Trainable scatter-based aggregation.
    self.aggregator = ScatterAggregationLayer(grad=True, function=aggr,
                                              device=device)
def __init__(self, in_feats, out_feats, rank_dim, norm='both', weight=True,
             bias=True, activation=None, allow_zero_in_degree=False):
    """Low-rank DGL graph convolution.

    w1 is the direct in->out weight; w2 and v form a rank-`rank_dim`
    factorization (w2 takes in_feats + 1 inputs — presumably an extra
    degree/bias column; TODO confirm against forward()).

    Raises DGLError when `norm` is not 'none', 'both' or 'right'.
    """
    super(DGLGraphConv, self).__init__()
    if norm not in ('none', 'both', 'right'):
        raise DGLError(
            'Invalid norm value. Must be either "none", "both" or "right".'
            ' But got "{}".'.format(norm))
    self._in_feats = in_feats
    self._out_feats = out_feats
    self._rank_dim = rank_dim
    self._norm = norm
    self._allow_zero_in_degree = allow_zero_in_degree
    if weight:
        self.w1 = torch.nn.Parameter(torch.Tensor(in_feats, out_feats))
        self.w2 = torch.nn.Parameter(torch.Tensor(in_feats + 1, rank_dim))
        self.v = torch.nn.Parameter(torch.Tensor(rank_dim, out_feats))
    else:
        # NOTE(review): this registers 'weight' although the branch above
        # creates w1/w2/v — confirm reset_parameters()/forward() handle the
        # weight=False case.
        self.register_parameter('weight', None)
    self.bond_encoder = BondEncoder(out_feats)
    self.reset_parameters()
    self._activation = activation
def __init__(self, hidden_channels, out_channels, num_layers=3, dropout=0.5):
    """GINE backbone for OGB molecule graphs.

    Atom/bond encoders feed a stack of GINEConv layers (each with a
    2x-widened BN+ReLU MLP) followed by a linear prediction head.
    """
    super().__init__()
    self.dropout = dropout
    self.atom_encoder = AtomEncoder(hidden_channels)
    self.bond_encoder = BondEncoder(hidden_channels)
    self.convs = torch.nn.ModuleList()
    for _ in range(num_layers):
        # Per-layer update MLP, widened 2x in the middle.
        mlp = Sequential(
            Linear(hidden_channels, 2 * hidden_channels),
            BatchNorm(2 * hidden_channels),
            ReLU(),
            Linear(2 * hidden_channels, hidden_channels),
            BatchNorm(hidden_channels),
            ReLU(),
        )
        self.convs.append(GINEConv(mlp, train_eps=True))
    self.lin = Linear(hidden_channels, out_channels)
def __init__(self, emb_dim):
    """Identity-transform message-passing conv (sum aggregation)."""
    super(IConv, self).__init__(aggr='add')
    # No learned node transform: identity keeps node features unchanged.
    self.iden = torch.nn.Identity(emb_dim, emb_dim)
    self.root_emb = torch.nn.Embedding(1, emb_dim)
    self.bond_encoder = BondEncoder(emb_dim=emb_dim)
def __init__(self, gnn_type, num_tasks, num_layer=4, emb_dim=256, dropout=0.0,
             batch_norm=True, residual=True, graph_pooling="mean"):
    """DGL graph-property net: encoders -> GNN layer stack -> pooling -> head.

    Unknown gnn_type / graph_pooling values silently fall back to
    ChebLayer / mean pooling respectively.
    """
    super().__init__()
    self.num_tasks = num_tasks
    self.num_layer = num_layer
    self.emb_dim = emb_dim
    self.dropout = dropout
    self.batch_norm = batch_norm
    self.residual = residual
    self.graph_pooling = graph_pooling
    self.atom_encoder = AtomEncoder(emb_dim)
    self.bond_encoder = BondEncoder(emb_dim)
    layer_cls = {
        'Cheb_net': ChebLayer,
        'mlp': MLPLayer,
    }.get(gnn_type, ChebLayer)
    self.layers = nn.ModuleList([
        layer_cls(emb_dim, emb_dim, dropout=dropout,
                  batch_norm=batch_norm, residual=residual)
        for _ in range(num_layer)
    ])
    self.pooler = {
        "mean": dgl.mean_nodes,
        "sum": dgl.sum_nodes,
        "max": dgl.max_nodes,
    }.get(graph_pooling, dgl.mean_nodes)
    self.graph_pred_linear = MLPReadout(emb_dim, num_tasks)
def __init__(self, num_layers, num_mlp_layers, hidden_dim, output_dim,
             final_dropout, learn_eps, graph_pooling_type,
             neighbor_pooling_type, norm_type):
    """GIN for OGB molecule graphs with a bond encoder per conv layer.

    Raises NotImplementedError for an unknown graph_pooling_type.
    """
    super(GIN, self).__init__()
    self.num_layers = num_layers
    self.learn_eps = learn_eps
    self.ginlayers = torch.nn.ModuleList()
    self.atom_encoder = AtomEncoder(hidden_dim)
    self.bond_layers = torch.nn.ModuleList()
    # num_layers - 1 conv layers; the atom encoder acts as the input layer.
    for _ in range(self.num_layers - 1):
        mlp = MLP(num_mlp_layers, hidden_dim, hidden_dim * 2, hidden_dim,
                  norm_type)
        self.ginlayers.append(
            GINConv(ApplyNodeFunc(mlp, norm_type), neighbor_pooling_type, 0,
                    self.learn_eps))
        self.bond_layers.append(BondEncoder(hidden_dim))
    self.linears_prediction = nn.Linear(hidden_dim, output_dim)
    self.drop = nn.Dropout(final_dropout)
    pooling_classes = {'sum': SumPooling, 'mean': AvgPooling, 'max': MaxPooling}
    if graph_pooling_type not in pooling_classes:
        raise NotImplementedError
    self.pool = pooling_classes[graph_pooling_type]()
def __init__(self, in_dim, out_dim, aggregator='softmax', beta=1.0,
             learn_beta=False, p=1.0, learn_p=False, msg_norm=False,
             learn_msg_scale=False, mlp_layers=1, eps=1e-7):
    """GENConv: MLP message transform with softmax / power-mean aggregation.

    beta and p become learnable nn.Parameters only when the matching
    learn_* flag is set (beta additionally requires the softmax aggregator);
    otherwise they stay plain floats.
    """
    super(GENConv, self).__init__()
    self.aggr = aggregator
    self.eps = eps
    # MLP widths: in_dim -> 2*in_dim (x mlp_layers-1) -> out_dim.
    dims = [in_dim]
    dims.extend(in_dim * 2 for _ in range(mlp_layers - 1))
    dims.append(out_dim)
    self.mlp = MLP(dims)
    self.msg_norm = MessageNorm(learn_msg_scale) if msg_norm else None
    if learn_beta and self.aggr == 'softmax':
        self.beta = nn.Parameter(torch.Tensor([beta]), requires_grad=True)
    else:
        self.beta = beta
    self.p = nn.Parameter(torch.Tensor([p]), requires_grad=True) if learn_p else p
    self.edge_encoder = BondEncoder(in_dim)
def __init__(self, in_feats, out_feats):
    """GCN conv with bond-feature encoding and a root self-embedding."""
    super(GCNConv, self).__init__()
    # Bias-free projection; normalization is handled in the propagation step.
    self.fc = nn.Linear(in_feats, out_feats, bias=False)
    self.root_emb = nn.Embedding(1, in_feats)
    self.bond_encoder = BondEncoder(in_feats)
    self.reset_parameters()
def __init__(self, emb_dim):
    """GCN conv (sum aggregation) with three linear transforms w1/w2/v."""
    super(GCNConv, self).__init__(aggr='add')
    self.w1 = torch.nn.Linear(emb_dim, emb_dim)
    self.w2 = torch.nn.Linear(emb_dim, emb_dim)
    self.v = torch.nn.Linear(emb_dim, emb_dim)
    self.root_emb = torch.nn.Embedding(1, emb_dim)
    self.bond_encoder = BondEncoder(emb_dim=emb_dim)
    self.reset_parameters()
def __init__(self, emb_dim):
    """GIN conv for OGB: 2x-widened MLP update, learnable eps, bond encoder."""
    super(GINConv_for_OGB, self).__init__(aggr="add")
    self.mlp = torch.nn.Sequential(
        torch.nn.Linear(emb_dim, 2 * emb_dim),
        torch.nn.BatchNorm1d(2 * emb_dim),
        torch.nn.ReLU(),
        torch.nn.Linear(2 * emb_dim, emb_dim),
    )
    # eps initialised to 0 as in the GIN paper; trained as a parameter.
    self.eps = torch.nn.Parameter(torch.Tensor([0]))
    self.bond_encoder = BondEncoder(emb_dim=emb_dim)
def __init__(self, emb_dim):
    """GCN conv: linear node transform, root self-embedding, bond encoder.

    emb_dim (int): node embedding dimensionality.
    """
    super(GCNConv, self).__init__()
    self.linear = nn.Linear(emb_dim, emb_dim)
    self.root_emb = nn.Embedding(1, emb_dim)
    self.bond_encoder = BondEncoder(emb_dim=emb_dim)
def __init__(self, data_info: dict, embed_size: int = 300,
             num_layers: int = 5, dropout: float = 0.5,
             virtual_node: bool = False):
    """Graph Isomorphism Network (GIN) variant introduced in baselines for
    OGB graph property prediction datasets.

    Parameters
    ----------
    data_info : dict
        The information about the input dataset.
    embed_size : int
        Embedding size.
    num_layers : int
        Number of layers.
    dropout : float
        Dropout rate.
    virtual_node : bool
        Whether to use virtual node.
    """
    super(OGBGGIN, self).__init__()
    self.data_info = data_info
    self.embed_size = embed_size
    self.num_layers = num_layers
    self.virtual_node = virtual_node
    if data_info['name'] in ('ogbg-molhiv', 'ogbg-molpcba'):
        # OGB molecule datasets ship categorical atom/bond features.
        self.node_encoder = AtomEncoder(embed_size)
        self.edge_encoders = nn.ModuleList(
            BondEncoder(embed_size) for _ in range(num_layers))
    else:
        # Other datasets: project dense node/edge features linearly.
        self.node_encoder = nn.Linear(data_info['node_feat_size'], embed_size)
        self.edge_encoders = nn.ModuleList(
            nn.Linear(data_info['edge_feat_size'], embed_size)
            for _ in range(num_layers))
    self.conv_layers = nn.ModuleList(
        GINEConv(MLP(embed_size)) for _ in range(num_layers))
    self.dropout = nn.Dropout(dropout)
    self.pool = AvgPooling()
    self.pred = nn.Linear(embed_size, data_info['out_size'])
    if virtual_node:
        # Virtual node starts at zero and is updated by per-layer MLPs.
        self.virtual_emb = nn.Embedding(1, embed_size)
        nn.init.constant_(self.virtual_emb.weight.data, 0)
        self.mlp_virtual = nn.ModuleList(
            MLP(embed_size) for _ in range(num_layers - 1))
        self.virtual_pool = SumPooling()
def __init__(self, emb_dim):
    """GIN conv (sum aggregation).

    emb_dim (int): node embedding dimensionality.
    """
    super(GINConv, self).__init__(aggr="add")
    self.mlp = torch.nn.Sequential(
        torch.nn.Linear(emb_dim, emb_dim),
        torch.nn.BatchNorm1d(emb_dim),
        torch.nn.ReLU(),
        torch.nn.Linear(emb_dim, emb_dim),
    )
    # eps initialised to 0; trained as a parameter.
    self.eps = torch.nn.Parameter(torch.Tensor([0]))
    self.bond_encoder = BondEncoder(emb_dim=emb_dim)
def __init__(self, hidden, config, **kwargs):
    """GIN conv block; config.BN == 'Y' enables batch normalization."""
    super(GinConv, self).__init__(aggr='add', **kwargs)
    self.fea_mlp = Sequential(
        Linear(hidden, hidden), ReLU(),
        Linear(hidden, hidden), ReLU(),
    )
    self.BN = BN(hidden) if config.BN == 'Y' else None
    self.bond_encoder = BondEncoder(emb_dim=hidden)
def __init__(self, hidden=100, out_dim=128, layers=3, dropout=0.5,
             virtual_node=False, k=4, conv_type='gin'):
    """ConvBlock stack over OGB molecule embeddings with mean-pool readout.

    Per-layer receptive field k grows with depth, capped at the configured k.
    """
    super().__init__()
    self.k = k
    self.conv_type = conv_type
    blocks = [
        ConvBlock(hidden, dropout=dropout, virtual_node=virtual_node,
                  k=min(depth + 1, k), conv_type=conv_type,
                  edge_embedding=BondEncoder(emb_dim=hidden))
        for depth in range(layers - 1)
    ]
    # On the last layer, use but do not update the virtual node.
    blocks.append(
        ConvBlock(hidden, dropout=dropout, virtual_node=virtual_node,
                  virtual_node_agg=False, last_layer=True,
                  k=min(layers, k), conv_type=conv_type,
                  edge_embedding=BondEncoder(emb_dim=hidden)))
    self.main = nn.Sequential(
        OGBMolEmbedding(hidden, embed_edge=False,
                        x_as_list=(conv_type == 'gin+')),
        *blocks)
    self.aggregate = nn.Sequential(GlobalPool('mean'),
                                   nn.Linear(hidden, out_dim))
    self.virtual_node = virtual_node
    if self.virtual_node:
        # Learned initial virtual-node state.
        self.v0 = nn.Parameter(torch.zeros(1, hidden), requires_grad=True)
def __init__(self, emb_dim, aggr, device='cuda'):
    """GIN conv with a learnable LAF (scatter) aggregator.

    emb_dim (int): node embedding dimensionality.
    """
    super(GINLafConv, self).__init__()
    self.mlp = torch.nn.Sequential(
        torch.nn.Linear(emb_dim, 2 * emb_dim),
        torch.nn.BatchNorm1d(2 * emb_dim),
        torch.nn.ReLU(),
        torch.nn.Linear(2 * emb_dim, emb_dim),
    )
    self.eps = torch.nn.Parameter(torch.Tensor([0]))
    self.bond_encoder = BondEncoder(emb_dim=emb_dim)
    # Trainable scatter-based aggregation.
    self.aggregator = ScatterAggregationLayer(grad=True, function=aggr,
                                              device=device)
def __init__(self, net_params):
    """EIG network: atom/bond encoders, L EIG layers (the last maps to
    out_dim), optional inter-layer GRU, and a scalar MLP readout."""
    super().__init__()
    hidden_dim = net_params['hidden_dim']
    out_dim = net_params['out_dim']
    in_feat_dropout = net_params['in_feat_dropout']
    dropout = net_params['dropout']
    n_layers = net_params['L']
    self.type_net = net_params['type_net']
    self.pos_enc_dim = net_params['pos_enc_dim']
    if self.pos_enc_dim > 0:
        self.embedding_pos_enc = nn.Linear(self.pos_enc_dim, hidden_dim)
    self.readout = net_params['readout']
    self.graph_norm = net_params['graph_norm']
    self.batch_norm = net_params['batch_norm']
    self.aggregators = net_params['aggregators']
    self.scalers = net_params['scalers']
    self.avg_d = net_params['avg_d']
    self.residual = net_params['residual']
    self.JK = net_params['JK']
    self.edge_feat = net_params['edge_feat']
    edge_dim = net_params['edge_dim']
    pretrans_layers = net_params['pretrans_layers']
    posttrans_layers = net_params['posttrans_layers']
    self.gru_enable = net_params['gru']
    device = net_params['device']
    self.device = device
    self.in_feat_dropout = nn.Dropout(in_feat_dropout)
    self.embedding_h = AtomEncoder(emb_dim=hidden_dim)
    if self.edge_feat:
        self.embedding_e = BondEncoder(emb_dim=edge_dim)

    def make_layer(layer_out_dim):
        # EIGLayer wraps the actual torch module in its .model attribute.
        return EIGLayer(in_dim=hidden_dim, out_dim=layer_out_dim,
                        dropout=dropout, graph_norm=self.graph_norm,
                        batch_norm=self.batch_norm, residual=self.residual,
                        aggregators=self.aggregators, scalers=self.scalers,
                        avg_d=self.avg_d, type_net=self.type_net,
                        edge_features=self.edge_feat, edge_dim=edge_dim,
                        pretrans_layers=pretrans_layers,
                        posttrans_layers=posttrans_layers).model

    # First n_layers - 1 layers keep hidden_dim; the last maps to out_dim.
    self.layers = nn.ModuleList(
        [make_layer(hidden_dim) for _ in range(n_layers - 1)])
    self.layers.append(make_layer(out_dim))
    if self.gru_enable:
        self.gru = GRU(hidden_dim, hidden_dim, device)
    self.MLP_layer = MLPReadout(out_dim, 1)  # 1 out dim since regression problem
def __init__(self, hidden, num_aggr, config, **kwargs):
    """ExpC conv: num_aggr learned aggregations mixed by a Tanh gating MLP.

    config.BN == 'Y' enables batch normalization.
    """
    super(ExpC, self).__init__(aggr='add', **kwargs)
    self.hidden = hidden
    self.num_aggr = num_aggr
    # Feature MLP consumes the concatenation of all aggregations.
    self.fea_mlp = Sequential(
        Linear(hidden * self.num_aggr, hidden), ReLU(),
        Linear(hidden, hidden), ReLU(),
    )
    # Gate over aggregations from the (node, message) pair.
    self.aggr_mlp = Sequential(Linear(hidden * 2, self.num_aggr), Tanh())
    self.BN = BN(hidden) if config.BN == 'Y' else None
    self.bond_encoder = BondEncoder(emb_dim=hidden)
def __init__(self, num_layer=5, emb_dim=100, num_task=2):
    """Plain GIN: num_layer (GINConv, BatchNorm1d) pairs and a linear head."""
    super(GIN, self).__init__()
    self.num_layer = num_layer
    self.gins = torch.nn.ModuleList()
    self.batch_norms = torch.nn.ModuleList()
    for _ in range(self.num_layer):
        self.gins.append(GINConv(emb_dim))
        self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))
    # Encoders for raw molecule node/edge features.
    self.atom_encoder = AtomEncoder(emb_dim)
    self.bond_encoder = BondEncoder(emb_dim)
    self.graph_pred_linear = torch.nn.Linear(emb_dim, num_task)
def __init__(self, K: int, in_channels: int, alpha: float,
             dropout: float = 0., cached: bool = False,
             add_self_loops: bool = False,
             normalize: bool = True, **kwargs):
    """APPNP propagation with bond-encoded edge features.

    add_self_loops defaults to False: added self-loop edges would also
    need an edge_attr, which this variant does not synthesize.
    """
    kwargs.setdefault('aggr', 'add')
    super(APPNP, self).__init__(**kwargs)
    self.K = K
    self.alpha = alpha
    self.dropout = dropout
    self.cached = cached
    self.add_self_loops = add_self_loops
    self.normalize = normalize
    self.bond_encoder = BondEncoder(emb_dim=in_channels)
    # Caches for the (possibly normalized) adjacency representations.
    self._cached_edge_index = None
    self._cached_adj_t = None
def __init__(self, emb_dim):
    """GIN conv whose update MLP uses Bayesian linear layers.

    emb_dim (int): node embedding dimensionality.
    """
    super(BayesianGINConv, self).__init__(aggr="add")
    self.mlp = torch.nn.Sequential(
        bnn.BayesLinear(prior_mu=0, prior_sigma=0.1,
                        in_features=emb_dim, out_features=emb_dim),
        torch.nn.BatchNorm1d(emb_dim),
        torch.nn.ReLU(),
        bnn.BayesLinear(prior_mu=0, prior_sigma=0.1,
                        in_features=emb_dim, out_features=emb_dim),
    )
    # eps initialised to 0; trained as a parameter.
    self.eps = torch.nn.Parameter(torch.Tensor([0]))
    self.bond_encoder = BondEncoder(emb_dim=emb_dim)
def __init__(self, net_params):
    """Graph Transformer: optional Laplacian/WL positional encodings,
    L transformer layers (the last maps to out_dim), 128-dim MLP readout."""
    super().__init__()
    hidden_dim = net_params['hidden_dim']
    num_heads = net_params['n_heads']
    out_dim = net_params['out_dim']
    in_feat_dropout = net_params['in_feat_dropout']
    dropout = net_params['dropout']
    n_layers = net_params['L']
    self.readout = net_params['readout']
    self.layer_norm = net_params['layer_norm']
    self.batch_norm = net_params['batch_norm']
    self.residual = net_params['residual']
    self.edge_feat = net_params['edge_feat']
    self.device = net_params['device']
    self.lap_pos_enc = net_params['lap_pos_enc']
    self.wl_pos_enc = net_params['wl_pos_enc']
    max_wl_role_index = 37  # this is maximum graph size in the dataset
    if self.lap_pos_enc:
        pos_enc_dim = net_params['pos_enc_dim']
        self.embedding_lap_pos_enc = nn.Linear(pos_enc_dim, hidden_dim)
    if self.wl_pos_enc:
        self.embedding_wl_pos_enc = nn.Embedding(max_wl_role_index, hidden_dim)
    self.embedding_h = AtomEncoder(emb_dim=hidden_dim)
    if self.edge_feat:
        self.embedding_e = BondEncoder(emb_dim=hidden_dim)
    else:
        # Without edge features, embed a constant scalar per edge.
        self.embedding_e = nn.Linear(1, hidden_dim)
    self.in_feat_dropout = nn.Dropout(in_feat_dropout)
    self.layers = nn.ModuleList([
        GraphTransformerLayer(hidden_dim, hidden_dim, num_heads, dropout,
                              self.layer_norm, self.batch_norm, self.residual)
        for _ in range(n_layers - 1)
    ])
    self.layers.append(
        GraphTransformerLayer(hidden_dim, out_dim, num_heads, dropout,
                              self.layer_norm, self.batch_norm,
                              self.residual))
    self.MLP_layer = MLPReadout(out_dim, 128)  # 128 out dim since regression problem
def __init__(self, hidden, config, **kwargs):
    """GIN conv block with a configurable final activation.

    config.fea_activation selects the MLP's last activation ('ELU' or
    'ReLU'); config.BN == 'T' enables batch normalization.

    Raises
    ------
    ValueError
        If config.fea_activation is neither 'ELU' nor 'ReLU'. (The
        original code had no else branch, so an unknown value left
        self.fea_activation unset and crashed later with a confusing
        AttributeError when building fea_mlp.)
    """
    super(GinConv, self).__init__(aggr='add', **kwargs)
    if config.fea_activation == 'ELU':
        self.fea_activation = ELU()
    elif config.fea_activation == 'ReLU':
        self.fea_activation = ReLU()
    else:
        # Fail fast with a clear message instead of a later AttributeError.
        raise ValueError(
            "Unsupported fea_activation: {!r}; expected 'ELU' or "
            "'ReLU'.".format(config.fea_activation))
    self.fea_mlp = Sequential(Linear(hidden, hidden), ReLU(),
                              Linear(hidden, hidden), self.fea_activation)
    self.BN = BN(hidden) if config.BN == 'T' else None
    self.bond_encoder = BondEncoder(emb_dim=hidden)
    self.reset_parameters()
def __init__(self, in_channels: int, out_channels: int,
             normalize: bool = False, bias: bool = True, **kwargs):
    """GraphSAGE conv with bond-encoded edge features.

    lin_l transforms the aggregated neighbourhood; lin_r transforms the
    root node (no bias so the bias is applied once overall).
    """
    kwargs.setdefault('aggr', 'mean')
    super(SAGEConv, self).__init__(**kwargs)
    self.bond_encoder = BondEncoder(emb_dim=in_channels)
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.normalize = normalize
    self.lin_l = Linear(in_channels, out_channels, bias=bias)
    self.lin_r = Linear(in_channels, out_channels, bias=False)
    self.reset_parameters()
def __init__(self, num_timesteps=4, emb_dim=300, num_layers=5, drop_ratio=0,
             num_tasks=1, **args):
    """AttentiveFP: GATE/GAT atom convs with GRU updates, plus a
    molecule-level GAT + GRU readout and a linear prediction head."""
    super(AttentiveFP, self).__init__()
    self.num_layers = num_layers
    self.num_timesteps = num_timesteps
    self.drop_ratio = drop_ratio
    self.atom_encoder = AtomEncoder(emb_dim)
    self.bond_encoder = BondEncoder(emb_dim=emb_dim)
    # First layer consumes edge features via GATEConv; later layers are GAT.
    self.atom_convs = torch.nn.ModuleList(
        [GATEConv(emb_dim, emb_dim, emb_dim, drop_ratio)])
    self.atom_grus = torch.nn.ModuleList([GRUCell(emb_dim, emb_dim)])
    for _ in range(num_layers - 1):
        self.atom_convs.append(
            GATConv(emb_dim, emb_dim, dropout=drop_ratio,
                    add_self_loops=False, negative_slope=0.01))
        self.atom_grus.append(GRUCell(emb_dim, emb_dim))
    # Molecule-level attention/update for the readout timesteps.
    self.mol_conv = GATConv(emb_dim, emb_dim, dropout=drop_ratio,
                            add_self_loops=False, negative_slope=0.01)
    self.mol_gru = GRUCell(emb_dim, emb_dim)
    self.graph_pred_linear = Linear(emb_dim, num_tasks)
    self.reset_parameters()
def __init__(self, channels: int, alpha: float, theta: float = None,
             layer: int = None, shared_weights: bool = True,
             cached: bool = False, add_self_loops: bool = False,
             normalize: bool = True, **kwargs):
    """GCNII conv with bond-encoded edge features.

    beta = log(theta / layer + 1) when both theta and layer are supplied
    (supplying only one is an error); otherwise beta stays 1.
    add_self_loops defaults to False because added self-loop edges would
    also need edge_attr.
    """
    kwargs.setdefault('aggr', 'add')
    super(GCN2Conv, self).__init__(**kwargs)
    self.bond_encoder = BondEncoder(emb_dim=channels)
    self.channels = channels
    self.alpha = alpha
    self.beta = 1.
    if theta is not None or layer is not None:
        assert theta is not None and layer is not None
        self.beta = log(theta / layer + 1)
    self.cached = cached
    self.normalize = normalize
    self.add_self_loops = add_self_loops
    self._cached_edge_index = None
    self._cached_adj_t = None
    self.weight1 = Parameter(torch.Tensor(channels, channels))
    if shared_weights:
        # One shared weight matrix for both terms.
        self.register_parameter('weight2', None)
    else:
        self.weight2 = Parameter(torch.Tensor(channels, channels))
    self.reset_parameters()
def __init__(self, dataset, in_dim, out_dim, aggregator='softmax', beta=1.0,
             learn_beta=False, p=1.0, learn_p=False, msg_norm=False,
             learn_msg_scale=False, norm='batch', mlp_layers=1, eps=1e-7):
    """GENConv with dataset-specific edge encoding.

    ogbg-molhiv uses BondEncoder, ogbg-ppa a linear projection; any other
    dataset raises ValueError. beta/p become learnable nn.Parameters only
    when the matching learn_* flag is set.
    """
    super(GENConv, self).__init__()
    self.aggr = aggregator
    self.eps = eps
    # MLP widths: in_dim -> 2*in_dim (x mlp_layers-1) -> out_dim.
    dims = [in_dim]
    dims.extend(in_dim * 2 for _ in range(mlp_layers - 1))
    dims.append(out_dim)
    self.mlp = MLP(dims, norm=norm)
    self.msg_norm = MessageNorm(learn_msg_scale) if msg_norm else None
    if learn_beta and self.aggr == 'softmax':
        self.beta = nn.Parameter(torch.Tensor([beta]), requires_grad=True)
    else:
        self.beta = beta
    self.p = nn.Parameter(torch.Tensor([p]), requires_grad=True) if learn_p else p
    if dataset == 'ogbg-molhiv':
        self.edge_encoder = BondEncoder(in_dim)
    elif dataset == 'ogbg-ppa':
        self.edge_encoder = nn.Linear(in_dim, in_dim)
    else:
        raise ValueError(f'Dataset {dataset} is not supported.')
def __init__(self, in_channels, out_channels, K=2, improved=False,
             cached=False, add_self_loops=False, normalize=True, bias=True,
             **kwargs):
    """SoGCN conv: order-K polynomial filter with bond-encoded edges."""
    kwargs.setdefault('aggr', 'add')
    super(SoGCNConv, self).__init__(**kwargs)
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.K = K
    self.improved = improved
    self.cached = cached
    self.add_self_loops = add_self_loops
    self.normalize = normalize
    self._cached_edge_index = None
    self._cached_adj_t = None
    self.bond_encoder = BondEncoder(emb_dim=in_channels)
    # K + 1 weight matrices: one per polynomial order plus the constant term.
    self.weight = torch.nn.Parameter(
        torch.Tensor(K + 1, in_channels, out_channels))
    if bias:
        self.bias = torch.nn.Parameter(torch.Tensor(out_channels))
    else:
        self.register_parameter('bias', None)
    self.reset_parameters()
def __init__(self, in_channels: int, out_channels: int, num_stacks: int = 1,
             num_layers: int = 1, shared_weights: bool = False,
             act: Optional[Callable] = ReLU(), dropout: float = 0.,
             bias: bool = True, **kwargs):
    """ARMA conv with bond-encoded edge features.

    K parallel stacks of T layers each; shared_weights collapses the
    per-layer weights to a single shared set (T = 1 internally).
    """
    kwargs.setdefault('aggr', 'add')
    super(ARMAConv, self).__init__(**kwargs)
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.num_stacks = num_stacks
    self.num_layers = num_layers
    self.act = act
    self.shared_weights = shared_weights
    self.dropout = dropout
    self.bond_encoder = BondEncoder(emb_dim=in_channels)
    K, T, F_in, F_out = num_stacks, num_layers, in_channels, out_channels
    T = 1 if self.shared_weights else T
    self.init_weight = Parameter(torch.Tensor(K, F_in, F_out))
    # Recurrent weights for layers 2..T (at least one slot even if T == 1).
    self.weight = Parameter(torch.Tensor(max(1, T - 1), K, F_out, F_out))
    self.root_weight = Parameter(torch.Tensor(T, K, F_in, F_out))
    if bias:
        self.bias = Parameter(torch.Tensor(T, K, 1, F_out))
    else:
        self.register_parameter('bias', None)
    self.reset_parameters()