示例#1
0
 def __init__(
     self,
     initial_size,
     output_size,
     middle_size_1,
     middle_size_2,
     num_nodes,
     num_relations,
     device1,
     device2,
 ):
     """Create the node-embedding table and three stacked RGCN layers.

     The embedding and all three ``RGCNConv`` layers are placed on
     ``device1``; ``device2`` is only stored on the instance here.

     Args:
         initial_size: Width of the learnable per-node embedding; also the
             input width of ``conv1``.
         output_size: Target total width. ``conv3``'s output width is
             ``output_size - middle_size_2 - middle_size_1 - initial_size``,
             so the embedding width plus the three conv output widths sum to
             ``output_size`` (presumably the forward pass concatenates them
             -- confirm against ``forward``).
         middle_size_1: Output width of ``conv1``.
         middle_size_2: Output width of ``conv2``.
         num_nodes: Number of rows in the embedding table.
         num_relations: Number of relation types given to each ``RGCNConv``.
         device1: Device for the embedding and the conv layers.
         device2: Stored as ``self.device2``; not used in this constructor.
     """
     super().__init__()
     self.device1 = device1
     self.device2 = device2

     # Allocate the embedding directly on ``device1`` instead of calling
     # ``.to(device1)`` on an already-built Parameter: ``Tensor.to`` returns a
     # new, plain (non-Parameter, non-leaf) tensor whenever a copy is needed,
     # so ``emb`` would not be registered in ``self.parameters()`` and the
     # optimizer would never update it.
     self.emb = nn.parameter.Parameter(
         torch.ones((num_nodes, initial_size), device=device1),
         requires_grad=True,
     )
     # The all-ones fill is a placeholder; glorot_orthogonal is applied to it
     # with scale 1 (presumably overwriting it in place).
     glorot_orthogonal(self.emb, 1)

     # Modules are moved in place by ``.to``, so these stay registered.
     self.conv1 = gnn.RGCNConv(initial_size,
                               middle_size_1,
                               num_relations,
                               num_bases=12).to(device1)
     self.conv2 = gnn.RGCNConv(middle_size_1,
                               middle_size_2,
                               num_relations,
                               num_bases=12).to(device1)
     # Remaining width so that all four widths add up to ``output_size``.
     conv3_size = output_size - middle_size_2 - middle_size_1 - initial_size
     self.conv3 = gnn.RGCNConv(
         middle_size_2,
         conv3_size,
         num_relations,
         num_bases=12,
     ).to(device1)
     self.drop = nn.Dropout(0.2)
示例#2
0
 def __init__(self, initial_size, num_nodes, num_relations, *args,
              **kwargs):
     """Create a learnable node-embedding table and a dropout layer.

     Args:
         initial_size: Number of columns of the embedding table.
         num_nodes: Number of rows of the embedding table.
         num_relations: Accepted but not used in this constructor.
         *args, **kwargs: Accepted and ignored here (presumably so this class
             can be constructed with the same arguments as a sibling model --
             confirm with callers).
     """
     super().__init__()
     # Start from an all-ones tensor; glorot_orthogonal (scale 1) is then
     # applied to the registered parameter.
     placeholder = torch.ones((num_nodes, initial_size))
     self.emb = nn.Parameter(placeholder, requires_grad=True)
     glorot_orthogonal(self.emb, 1)
     self.drop = nn.Dropout(0.2)
示例#3
0
 def reset_parameters(self):
     """Re-initialise every learnable tensor of this block.

     ``lin_rbf``, ``lin_sbf``, ``lin_kj``, ``lin_ji`` and ``lin`` weights get
     ``glorot_orthogonal`` with scale 2.0. The ``lin_kj``, ``lin_ji`` and
     ``lin`` biases are zeroed. ``W`` is redrawn from a normal distribution
     with mean 0 and std ``2 / W.size(0)``. Every residual layer before and
     after the skip connection resets itself.
     """
     # The initialiser call order is kept identical to the original
     # (glorot_orthogonal presumably consumes the RNG, so order matters for
     # seeded reproducibility).
     for projection in (self.lin_rbf, self.lin_sbf):
         glorot_orthogonal(projection.weight, scale=2.0)
     for linear in (self.lin_kj, self.lin_ji):
         glorot_orthogonal(linear.weight, scale=2.0)
         linear.bias.data.zero_()

     w_std = 2 / self.W.size(0)
     self.W.data.normal_(mean=0, std=w_std)

     for layer in self.layers_before_skip:
         layer.reset_parameters()
     glorot_orthogonal(self.lin.weight, scale=2.0)
     self.lin.bias.data.zero_()
     for layer in self.layers_after_skip:
         layer.reset_parameters()
示例#4
0
 def reset_parameters(self):
     """Re-initialise this output block's weights.

     ``lin_rbf``, ``lin_up`` and every hidden layer in ``lins`` get
     ``glorot_orthogonal`` weights with scale 2.0. The hidden biases are
     zeroed, and the final ``lin`` weight is set to all zeros.
     """
     for projection in (self.lin_rbf, self.lin_up):
         glorot_orthogonal(projection.weight, scale=2.0)
     for hidden in self.lins:
         glorot_orthogonal(hidden.weight, scale=2.0)
         hidden.bias.data.zero_()
     # The final projection's weight starts at exactly zero.
     self.lin.weight.data.zero_()
示例#5
0
文件: model.py 项目: zhy5186612/DIG
 def reset_parameters(self):
     """Re-initialise this output block's weights.

     ``lin_up`` and every layer in ``lins`` get ``glorot_orthogonal`` weights
     with scale 2.0, and the biases of ``lins`` are zeroed. The final ``lin``
     weight is initialised according to ``self.output_init``:

     * ``'zeros'``: filled with zeros;
     * ``'GlorotOrthogonal'``: ``glorot_orthogonal`` with scale 2.0.

     Raises:
         ValueError: If ``self.output_init`` is not one of the values above.
             The check runs before any weight is touched, so a bad value
             cannot leave the block partially re-initialised. Previously an
             unknown value was silently ignored and ``lin.weight`` kept its
             old values.
     """
     if self.output_init not in ('zeros', 'GlorotOrthogonal'):
         raise ValueError(
             "Unsupported output_init {!r}; expected 'zeros' or "
             "'GlorotOrthogonal'.".format(self.output_init))

     glorot_orthogonal(self.lin_up.weight, scale=2.0)
     for lin in self.lins:
         glorot_orthogonal(lin.weight, scale=2.0)
         lin.bias.data.fill_(0)

     # The two supported modes are mutually exclusive.
     if self.output_init == 'zeros':
         self.lin.weight.data.fill_(0)
     else:
         glorot_orthogonal(self.lin.weight, scale=2.0)
示例#6
0
    def reset_parameters(self):
        """Re-initialise every learnable tensor of this interaction block.

        All listed linear weights (``lin_rbf1/2``, ``lin_sbf1/2``,
        ``lin_kj``, ``lin_ji``, ``lin_down``, ``lin_up``, ``lin``) get
        ``glorot_orthogonal`` with scale 2.0. The ``lin_kj``, ``lin_ji`` and
        ``lin`` biases are zeroed. Every residual layer before and after the
        skip connection resets itself.
        """
        glorot_orthogonal(self.lin_rbf1.weight, scale=2.0)
        glorot_orthogonal(self.lin_rbf2.weight, scale=2.0)
        glorot_orthogonal(self.lin_sbf1.weight, scale=2.0)
        glorot_orthogonal(self.lin_sbf2.weight, scale=2.0)

        glorot_orthogonal(self.lin_kj.weight, scale=2.0)
        self.lin_kj.bias.data.fill_(0)
        glorot_orthogonal(self.lin_ji.weight, scale=2.0)
        self.lin_ji.bias.data.fill_(0)

        glorot_orthogonal(self.lin_down.weight, scale=2.0)
        glorot_orthogonal(self.lin_up.weight, scale=2.0)

        for res_layer in self.layers_before_skip:
            res_layer.reset_parameters()
        glorot_orthogonal(self.lin.weight, scale=2.0)
        self.lin.bias.data.fill_(0)
        # Must iterate ``layers_after_skip``: iterating ``layers_before_skip``
        # a second time would reset those layers twice and never
        # re-initialise the residual layers after the skip connection.
        for res_layer in self.layers_after_skip:
            res_layer.reset_parameters()
示例#7
0
 def reset_parameters(self):
     """Give ``lin1`` and ``lin2`` scale-2.0 glorot_orthogonal weights and zero biases."""
     for layer in (self.lin1, self.lin2):
         glorot_orthogonal(layer.weight, scale=2.0)
         layer.bias.data.zero_()
示例#8
0
文件: model.py 项目: zhy5186612/DIG
 def reset_parameters(self):
     """Re-initialise the embedding table and the projection layers.

     ``emb.weight`` is redrawn uniformly from ``[-sqrt(3), sqrt(3)]`` (a
     distribution with unit variance). ``lin_rbf_0`` and ``lin`` use their
     own ``reset_parameters``. ``lin_rbf_1.weight`` gets ``glorot_orthogonal``
     with scale 2.0.
     """
     bound = sqrt(3)
     self.emb.weight.data.uniform_(-bound, bound)
     for layer in (self.lin_rbf_0, self.lin):
         layer.reset_parameters()
     glorot_orthogonal(self.lin_rbf_1.weight, scale=2.0)
示例#9
0
 def __init__(self, input_size, num_relations):
     """Create a learnable table with one ``input_size``-wide row per relation.

     Args:
         input_size: Number of columns of the relation table.
         num_relations: Number of rows of the relation table.
     """
     super().__init__()
     # All-ones starting fill; glorot_orthogonal (scale 1) is then applied to
     # the registered parameter.
     initial = torch.ones((num_relations, input_size))
     self.rel = nn.Parameter(initial, requires_grad=True)
     glorot_orthogonal(self.rel, 1)