def reset_parameters(self):
    """Reset ``psi_1``, every entry of ``psi_stack``, ``mlp`` and ``sum_weights``.

    ``sum_weights`` ends up as a 1-D parameter of ``num_psi + 1`` ones.
    When a parameter of that shape already exists it is refilled in place,
    so optimizers that hold a reference to it keep working and its
    device/dtype are preserved. Otherwise a fresh ``Parameter`` is created.
    """
    self.psi_1.reset_parameters()
    for psi in self.psi_stack:
        psi.reset_parameters()
    reset(self.mlp)

    num_weights = self.num_psi + 1
    weights = getattr(self, 'sum_weights', None)
    if isinstance(weights, Parameter) and weights.shape == (num_weights, ):
        # In-place fill: rebinding the attribute would orphan the tensor an
        # existing optimizer is tracking and move the weights back to the
        # default device/dtype.
        weights.data.fill_(1.0)
    else:
        # First call (attribute not created yet) or the expected size changed.
        self.sum_weights = Parameter(torch.ones(num_weights))
def reset_parameters(self):
    """Reset ``mlp``, the optional ``msg_norm``, and tensor-valued ``t``/``p``.

    ``t`` and ``p`` are only touched when they are ``Tensor`` objects; in that
    case they are overwritten in place with ``initial_t`` / ``initial_p``.
    """
    reset(self.mlp)
    if self.msg_norm is not None:
        self.msg_norm.reset_parameters()
    # Check the type only. Testing the tensor's truthiness as well would skip
    # the reset whenever the current value is exactly 0, and would raise for
    # a tensor holding more than one element.
    if isinstance(self.t, Tensor):
        self.t.data.fill_(self.initial_t)
    if isinstance(self.p, Tensor):
        self.p.data.fill_(self.initial_p)
def reset_parameters(self):
    """Pass the message, update and combine networks to ``reset``.

    Afterwards ``eps1`` and ``eps2`` are overwritten in place with
    ``initial_eps``.
    """
    networks = (
        self.msg_up_nn,
        self.msg_boundaries_nn,
        self.update_up_nn,
        self.update_boundaries_nn,
        self.combine_nn,
    )
    for network in networks:
        reset(network)
    for eps in (self.eps1, self.eps2):
        eps.data.fill_(self.initial_eps)
def reset_parameters(self):
    """Reset the edge encoder (if any), the pre/post networks and ``lin``.

    The edge encoder is only reset when ``edge_dim`` is not ``None``.
    """
    if self.edge_dim is not None:
        self.edge_encoder.reset_parameters()
    # Pre-networks first, then post-networks, in their stored order.
    for module in (*self.pre_nns, *self.post_nns):
        reset(module)
    self.lin.reset_parameters()
def test_reset():
    """Check that ``reset`` changes the weights of a ``Lin`` and of the
    ``Lin`` entries inside a ``Seq``."""
    linear = Lin(16, 16)
    weight_before = linear.weight.clone()
    reset(linear)
    assert linear.weight.tolist() != weight_before.tolist()

    stack = Seq(Lin(16, 16), ReLU(), Lin(16, 16))
    first_before = stack[0].weight.clone()
    last_before = stack[2].weight.clone()
    reset(stack)
    assert stack[0].weight.tolist() != first_before.tolist()
    assert stack[2].weight.tolist() != last_before.tolist()
def reset_parameters(self):
    """Reset the four ``*_lin`` members and re-initialize the remaining weights.

    ``skip`` and ``p_rel`` go through ``ones``; ``a_rel`` and ``m_rel`` go
    through ``glorot``. The call order matches the attribute order below.
    """
    for linear in (self.k_lin, self.q_lin, self.v_lin, self.a_lin):
        reset(linear)
    for weight in (self.skip, self.p_rel):
        ones(weight)
    for weight in (self.a_rel, self.m_rel):
        glorot(weight)
def reset_parameters(self):
    """Reset ``nn``, the root ``lin`` (when ``root_weight`` is set) and ``bias``.

    NOTE(review): the original formatting was lost; ``zeros(self.bias)`` is
    assumed to run unconditionally rather than inside the ``if`` - confirm.
    """
    reset(self.nn)
    if self.root_weight:
        self.lin.reset_parameters()
    zeros(self.bias)
def reset_parameters(self):
    """Pass ``edge_encoder`` and ``mlp`` to ``reset`` and restore ``eps``.

    ``eps`` is overwritten in place with ``initial_eps``.
    """
    for module in (self.edge_encoder, self.mlp):
        reset(module)
    self.eps.data.fill_(self.initial_eps)
def reset_parameters(self):
    """Pass ``edge_encoder`` and then ``mlp`` to ``reset``."""
    for module in (self.edge_encoder, self.mlp):
        reset(module)
def reset_parameters(self):
    """Pass ``fc1``, ``fc3`` and ``fc4`` to ``reset``, in that order."""
    for layer in (self.fc1, self.fc3, self.fc4):
        reset(layer)
def reset_parameters(self):
    """Pass ``scorer``, ``encoder`` and ``pool`` to ``reset``, in that order."""
    for module in (self.scorer, self.encoder, self.pool):
        reset(module)
def reset_parameters(self):
    """Pass ``gate_nn`` and then ``nn`` to ``reset``."""
    for module in (self.gate_nn, self.nn):
        reset(module)
def reset_parameters(self):
    """Pass the single wrapped ``layer`` to ``reset``."""
    reset(self.layer)
def reset_parameters(self):
    """Pass every ``layer*`` / ``layerg*`` member to ``reset``.

    The order is the same as in the tuple below.
    """
    layers = (
        self.layer1,
        self.layerg,
        self.layer2,
        self.layerg2,
        self.layer3,
        self.layerg3,
        self.layer4,
        self.layer5,
    )
    for layer in layers:
        reset(layer)
def reset_parameters(self):
    """Pass ``mlp1``, ``mlp2`` and ``conv`` to ``reset``, in that order."""
    for module in (self.mlp1, self.mlp2, self.conv):
        reset(module)
def reset_parameters(self):
    """Reset ``nn`` and re-initialize ``root`` and ``bias`` with ``uniform``.

    Both ``uniform`` calls receive ``in_channels`` as their first argument.
    """
    reset(self.nn)
    for weight in (self.root, self.bias):
        uniform(self.in_channels, weight)
def reset_parameters(self):
    """Pass the ``attn`` member to ``reset``."""
    reset(self.attn)
def reset_parameters(self):
    """Reset every module in ``psi1_stack`` and then pass ``mlp`` to ``reset``."""
    for module in self.psi1_stack:
        module.reset_parameters()
    # NOTE(review): the original carried a commented-out call
    # `self.psi_2.reset_parameters()` here - confirm whether psi_2 should
    # also be reset.
    reset(self.mlp)
def reset_parameters(self):
    """Reset ``edge_nn``, ``message_lstm``, the optional ``root`` and ``bias``.

    ``root`` is only passed to ``reset`` when it is truthy.

    NOTE(review): the original formatting was lost; ``zeros(self.bias)`` is
    assumed to run unconditionally rather than inside the ``if`` - confirm.
    """
    for module in (self.edge_nn, self.message_lstm):
        reset(module)
    if self.root:
        reset(self.root)
    zeros(self.bias)
def reset_parameters(self):
    """Restore ``lamb`` to ``initial_lamb`` and reset ``optimizer``/``potential``.

    ``lamb`` is overwritten in place before the two ``reset`` calls.
    """
    self.lamb.data.fill_(self.initial_lamb)
    for module in (self.optimizer, self.potential):
        reset(module)
def reset_parameters(self):
    """Pass ``nn1`` to ``reset`` and overwrite ``eps`` in place with ``initial_eps``."""
    reset(self.nn1)
    self.eps.data.fill_(self.initial_eps)
def reset_parameters(self):
    """Reset ``psi_1`` and ``psi_2``, then pass ``mlp`` to ``reset``."""
    for psi in (self.psi_1, self.psi_2):
        psi.reset_parameters()
    reset(self.mlp)
def reset_parameters(self):
    """Pass the ``conv`` member to ``reset``."""
    reset(self.conv)
def __init__(self, encoder, discriminator, decoder=None):
    """Build the model from an encoder, a discriminator and an optional decoder.

    ``encoder`` and ``decoder`` are forwarded to the parent constructor.
    ``discriminator`` is stored on the instance and then passed to ``reset``.
    """
    super(ARGA, self).__init__(encoder, decoder)

    # Store first, then reset through the attribute that was just set.
    self.discriminator = discriminator
    reset(self.discriminator)
def reset_parameters(self):
    """Pass every value stored in the ``convs`` mapping to ``reset``."""
    for module in self.convs.values():
        reset(module)
def reset_parameters(self):
    """Run the parent class's ``reset_parameters``, then reset ``discriminator``."""
    super(ARGA, self).reset_parameters()
    reset(self.discriminator)
def reset_parameters(self):
    """Pass ``local_nn`` and then ``global_nn`` to ``reset``."""
    for network in (self.local_nn, self.global_nn):
        reset(network)
def reset_parameters(self):
    """Pass ``encoder`` and then ``decoder`` to ``reset``."""
    for module in (self.encoder, self.decoder):
        reset(module)
def reset_parameters(self):
    """Pass the wrapped ``nn`` member to ``reset``."""
    reset(self.nn)
def reset_parameters(self):
    """Reset ``proj`` and ``k_lin`` and apply ``glorot`` to the remaining weights.

    Call order: ``proj``, ``lin_src``, ``lin_dst``, ``k_lin``, ``q``.
    """
    reset(self.proj)
    for weight in (self.lin_src, self.lin_dst):
        glorot(weight)
    self.k_lin.reset_parameters()
    glorot(self.q)