def reset_parameters(self):
        """Reset ``psi_1``, every module in ``psi_stack``, ``mlp`` and ``sum_weights``.

        ``sum_weights`` ends up as a 1-D ``Parameter`` of ``num_psi + 1`` ones.
        If such a parameter already exists it is refilled in place instead of
        being replaced, so optimizers that already hold a reference to it keep
        tracking the live tensor and its device/dtype are preserved.
        """
        self.psi_1.reset_parameters()
        for psi in self.psi_stack:
            psi.reset_parameters()

        reset(self.mlp)

        expected_len = self.num_psi + 1
        current = getattr(self, 'sum_weights', None)
        reusable = (
            isinstance(current, Parameter)
            and current.dim() == 1
            and current.numel() == expected_len
        )
        if reusable:
            # In-place fill keeps the parameter object identity stable.
            current.data.fill_(1.0)
        else:
            # First call (or shape changed): create the parameter from scratch.
            self.sum_weights = Parameter(torch.ones(expected_len))
 def reset_parameters(self):
     """Reset ``mlp``, the optional ``msg_norm``, and tensor-valued ``t`` / ``p``.

     ``t`` and ``p`` are only refilled (with ``initial_t`` / ``initial_p``)
     when they are tensors; any other value is left untouched.
     """
     reset(self.mlp)
     if self.msg_norm is not None:
         self.msg_norm.reset_parameters()
     # Test the type only: evaluating a tensor's truthiness would skip the
     # reset when the current value is exactly 0 and raises for tensors with
     # more than one element.
     if isinstance(self.t, Tensor):
         self.t.data.fill_(self.initial_t)
     if isinstance(self.p, Tensor):
         self.p.data.fill_(self.initial_p)
Beispiel #3
0
 def reset_parameters(self):
     """Pass the five sub-networks to ``reset`` and restore both epsilons."""
     submodule_names = (
         'msg_up_nn',
         'msg_boundaries_nn',
         'update_up_nn',
         'update_boundaries_nn',
         'combine_nn',
     )
     for name in submodule_names:
         reset(getattr(self, name))

     # Both epsilons are overwritten in place with the stored initial value.
     for eps_name in ('eps1', 'eps2'):
         getattr(self, eps_name).data.fill_(self.initial_eps)
Beispiel #4
0
 def reset_parameters(self):
     """Reset the edge encoder (if used), all pre/post networks and ``lin``."""
     uses_edge_features = self.edge_dim is not None
     if uses_edge_features:
         self.edge_encoder.reset_parameters()

     # ``pre_nns`` is processed completely before ``post_nns``.
     for group_name in ('pre_nns', 'post_nns'):
         for net in getattr(self, group_name):
             reset(net)

     self.lin.reset_parameters()
Beispiel #5
0
def test_reset():
    """Check that ``reset`` changes the weights of a layer and of stacked layers."""
    # Single linear layer: weights must differ after the reset.
    linear = Lin(16, 16)
    before = linear.weight.clone()
    reset(linear)
    assert linear.weight.tolist() != before.tolist()

    # Sequential container: both linear members (indices 0 and 2) must change.
    stack = Seq(Lin(16, 16), ReLU(), Lin(16, 16))
    snapshots = {idx: stack[idx].weight.clone() for idx in (0, 2)}
    reset(stack)
    for idx, snapshot in snapshots.items():
        assert stack[idx].weight.tolist() != snapshot.tolist()
Beispiel #6
0
 def reset_parameters(self):
     """Re-initialise the four ``*_lin`` members and the relation parameters.

     ``skip`` and ``p_rel`` go through ``ones``; ``a_rel`` and ``m_rel`` go
     through ``glorot``.
     """
     for lin_name in ('k_lin', 'q_lin', 'v_lin', 'a_lin'):
         reset(getattr(self, lin_name))
     for ones_name in ('skip', 'p_rel'):
         ones(getattr(self, ones_name))
     for glorot_name in ('a_rel', 'm_rel'):
         glorot(getattr(self, glorot_name))
 def reset_parameters(self):
     """Reset ``nn``, the root linear layer when enabled, and apply ``zeros`` to ``bias``."""
     reset(self.nn)

     # ``lin`` is only touched when the ``root_weight`` flag is truthy.
     root_enabled = self.root_weight
     if root_enabled:
         self.lin.reset_parameters()

     zeros(self.bias)
 def reset_parameters(self):
     """Reset ``edge_encoder`` and ``mlp``, then restore ``eps`` to ``initial_eps``."""
     for module_name in ('edge_encoder', 'mlp'):
         reset(getattr(self, module_name))

     # Overwrite the epsilon tensor in place.
     eps = self.eps
     eps.data.fill_(self.initial_eps)
 def reset_parameters(self):
     """Pass ``edge_encoder`` and then ``mlp`` to ``reset``."""
     for module_name in ('edge_encoder', 'mlp'):
         reset(getattr(self, module_name))
Beispiel #10
0
 def reset_parameters(self):
     """Pass ``fc1``, ``fc3`` and ``fc4`` to ``reset``, in that order."""
     for layer_name in ('fc1', 'fc3', 'fc4'):
         reset(getattr(self, layer_name))
Beispiel #11
0
 def reset_parameters(self):
     """Pass ``scorer``, ``encoder`` and ``pool`` to ``reset``, in that order."""
     for component_name in ('scorer', 'encoder', 'pool'):
         reset(getattr(self, component_name))
Beispiel #12
0
 def reset_parameters(self):
     """Pass ``gate_nn`` and then ``nn`` to ``reset``."""
     for net_name in ('gate_nn', 'nn'):
         reset(getattr(self, net_name))
Beispiel #13
0
 def reset_parameters(self):
     """Pass the ``layer`` member to ``reset``."""
     wrapped = self.layer
     reset(wrapped)
 def reset_parameters(self):
     """Pass all eight layer members to ``reset`` in their original order."""
     layer_names = (
         'layer1',
         'layerg',
         'layer2',
         'layerg2',
         'layer3',
         'layerg3',
         'layer4',
         'layer5',
     )
     for layer_name in layer_names:
         reset(getattr(self, layer_name))
 def reset_parameters(self):
     """Pass ``mlp1``, ``mlp2`` and ``conv`` to ``reset``, in that order."""
     for module_name in ('mlp1', 'mlp2', 'conv'):
         reset(getattr(self, module_name))
Beispiel #16
0
 def reset_parameters(self):
     """Reset ``nn`` and apply ``uniform`` to ``root`` and ``bias``.

     ``in_channels`` is passed as the first argument to ``uniform`` in both calls.
     """
     reset(self.nn)
     for param_name in ('root', 'bias'):
         uniform(self.in_channels, getattr(self, param_name))
Beispiel #17
0
 def reset_parameters(self):
     """Pass the ``attn`` member to ``reset``."""
     attention = self.attn
     reset(attention)
Beispiel #18
0
 def reset_parameters(self):
     """Reset every module in ``psi1_stack`` and then ``mlp``."""
     for stage in self.psi1_stack:
         stage.reset_parameters()
     # NOTE: resetting ``psi_2`` was commented out in the original code and
     # remains disabled here.
     reset(self.mlp)
Beispiel #19
0
 def reset_parameters(self):
     """Reset ``edge_nn``, ``message_lstm``, the optional ``root`` and zero ``bias``."""
     for module_name in ('edge_nn', 'message_lstm'):
         reset(getattr(self, module_name))

     # ``root`` is only reset when it evaluates as truthy.
     if self.root:
         root_module = self.root
         reset(root_module)

     zeros(self.bias)
 def reset_parameters(self):
     """Restore ``lamb`` to ``initial_lamb`` and reset ``optimizer`` and ``potential``."""
     # Overwrite the stored tensor in place before touching the sub-modules.
     lamb = self.lamb
     lamb.data.fill_(self.initial_lamb)

     for module_name in ('optimizer', 'potential'):
         reset(getattr(self, module_name))
Beispiel #21
0
 def reset_parameters(self):
     """Reset ``nn1`` and restore ``eps`` to ``initial_eps`` in place."""
     network = self.nn1
     reset(network)

     eps = self.eps
     eps.data.fill_(self.initial_eps)
Beispiel #22
0
 def reset_parameters(self):
     """Reset ``psi_1`` and ``psi_2`` via their own methods, then reset ``mlp``."""
     for psi_name in ('psi_1', 'psi_2'):
         getattr(self, psi_name).reset_parameters()
     reset(self.mlp)
Beispiel #23
0
 def reset_parameters(self):
     """Pass the ``conv`` member to ``reset``."""
     convolution = self.conv
     reset(convolution)
Beispiel #24
0
 def __init__(self, encoder, discriminator, decoder=None):
     """Initialise the parent with ``encoder``/``decoder`` and attach the discriminator.

     Args:
         encoder: Forwarded to the parent constructor.
         discriminator: Stored as ``self.discriminator`` and passed to ``reset``.
         decoder: Forwarded to the parent constructor (default ``None``).
     """
     parent = super(ARGA, self)
     parent.__init__(encoder, decoder)

     self.discriminator = discriminator
     # Reset the freshly attached discriminator straight away.
     reset(self.discriminator)
Beispiel #25
0
 def reset_parameters(self):
     """Pass every value stored in the ``convs`` mapping to ``reset``."""
     conv_modules = self.convs.values()
     for module in conv_modules:
         reset(module)
Beispiel #26
0
 def reset_parameters(self):
     """Run the parent class's reset, then reset the discriminator."""
     parent = super(ARGA, self)
     parent.reset_parameters()

     discriminator = self.discriminator
     reset(discriminator)
 def reset_parameters(self):
     """Pass ``local_nn`` and then ``global_nn`` to ``reset``."""
     for net_name in ('local_nn', 'global_nn'):
         reset(getattr(self, net_name))
Beispiel #28
0
 def reset_parameters(self):
     """Pass ``encoder`` and then ``decoder`` to ``reset``."""
     for part_name in ('encoder', 'decoder'):
         reset(getattr(self, part_name))
Beispiel #29
0
 def reset_parameters(self):
     """Pass the ``nn`` member to ``reset``."""
     network = self.nn
     reset(network)
Beispiel #30
0
 def reset_parameters(self):
     """Reset ``proj`` and ``k_lin`` and apply ``glorot`` to ``lin_src``, ``lin_dst`` and ``q``.

     The call order is: ``proj``, ``lin_src``, ``lin_dst``, ``k_lin``, ``q``.
     """
     reset(self.proj)
     for weight_name in ('lin_src', 'lin_dst'):
         glorot(getattr(self, weight_name))

     key_lin = self.k_lin
     key_lin.reset_parameters()

     glorot(self.q)