def __init__(
    self,
    initial_size,
    output_size,
    middle_size_1,
    middle_size_2,
    num_nodes,
    num_relations,
    device1,
    device2,
):
    """Build a three-layer RGCN encoder on top of a learnable node-embedding table.

    Args:
        initial_size: Width of the learnable per-node input embedding.
        output_size: Total output width. The third conv emits
            ``output_size - middle_size_2 - middle_size_1 - initial_size``
            channels. Presumably ``forward`` concatenates the embedding and
            all three layer outputs to reach ``output_size`` -- TODO confirm.
        middle_size_1: Output width of the first RGCN layer.
        middle_size_2: Output width of the second RGCN layer.
        num_nodes: Number of rows in the node-embedding table.
        num_relations: Number of relation types given to each ``RGCNConv``.
        device1: Device that holds the embedding table and all conv layers.
        device2: Stored on the instance; not used by this constructor.
    """
    super().__init__()
    self.device1 = device1
    self.device2 = device2
    # Allocate on device1 first and wrap afterwards. Calling ``.to()`` on an
    # ``nn.Parameter`` returns a plain (non-Parameter, non-leaf) tensor
    # whenever a copy is needed, e.g. CPU -> CUDA. That would silently drop
    # the table from ``parameters()`` / ``state_dict()``, and the optimizer
    # would never update it.
    self.emb = nn.parameter.Parameter(
        torch.ones((num_nodes, initial_size), device=device1),
        requires_grad=True,
    )
    # Overwrite the all-ones placeholder. The return value is unused, so
    # glorot_orthogonal is expected to act in place.
    glorot_orthogonal(self.emb, 1)

    self.conv1 = gnn.RGCNConv(
        initial_size, middle_size_1, num_relations, num_bases=12
    ).to(device1)
    self.conv2 = gnn.RGCNConv(
        middle_size_1, middle_size_2, num_relations, num_bases=12
    ).to(device1)
    # The last layer fills whatever width remains of output_size.
    last_width = output_size - middle_size_2 - middle_size_1 - initial_size
    self.conv3 = gnn.RGCNConv(
        middle_size_2,
        last_width,
        num_relations,
        num_bases=12,
    ).to(device1)
    self.drop = nn.Dropout(0.2)
def __init__(self, initial_size, num_nodes, num_relations, *args, **kwargs):
    """Hold only a learnable node-embedding table and a dropout layer.

    Args:
        initial_size: Width of each node's embedding row.
        num_nodes: Number of embedding rows.
        num_relations: Accepted but not used by this constructor.
        *args, **kwargs: Accepted and ignored. Presumably this lets the module
            be built with the same arguments as a fuller encoder -- TODO
            confirm with callers.
    """
    super().__init__()
    table = torch.ones((num_nodes, initial_size))
    self.emb = nn.parameter.Parameter(table, requires_grad=True)
    # The ones are only a placeholder. glorot_orthogonal (scale 1) fills the
    # table; its return value is unused, so it is expected to act in place.
    glorot_orthogonal(self.emb, 1)
    self.drop = nn.Dropout(0.2)
def reset_parameters(self):
    """Re-initialise every learnable piece of the interaction block.

    Initialisation order is fixed, so the random-number stream is consumed
    deterministically:
      * ``lin_rbf``, ``lin_sbf``: Glorot-orthogonal weights (scale 2.0).
      * ``lin_kj``, ``lin_ji``: Glorot-orthogonal weights (scale 2.0) and
        zero biases.
      * ``W``: normal with mean 0 and std ``2 / W.size(0)``.
      * residual layers before the skip, then ``lin`` (Glorot-orthogonal
        weight, zero bias), then residual layers after the skip.
    """
    for basis_proj in (self.lin_rbf, self.lin_sbf):
        glorot_orthogonal(basis_proj.weight, scale=2.0)
    for msg_proj in (self.lin_kj, self.lin_ji):
        glorot_orthogonal(msg_proj.weight, scale=2.0)
        msg_proj.bias.data.fill_(0)

    fan = self.W.size(0)
    self.W.data.normal_(mean=0, std=2 / fan)

    for layer in self.layers_before_skip:
        layer.reset_parameters()
    glorot_orthogonal(self.lin.weight, scale=2.0)
    self.lin.bias.data.fill_(0)
    for layer in self.layers_after_skip:
        layer.reset_parameters()
def reset_parameters(self):
    """Re-initialise the output block.

    * ``lin_rbf`` and ``lin_up`` weights get Glorot-orthogonal init
      (scale 2.0).
    * Every layer in ``lins`` gets the same weight init plus a zero bias.
    * The final ``lin`` weight is filled with zeros.
    """
    for projection in (self.lin_rbf, self.lin_up):
        glorot_orthogonal(projection.weight, scale=2.0)
    for hidden in self.lins:
        glorot_orthogonal(hidden.weight, scale=2.0)
        hidden.bias.data.fill_(0)
    # Zero-initialised final weight (any bias on ``lin`` is left untouched).
    self.lin.weight.data.fill_(0)
def reset_parameters(self):
    """Re-initialise the output block according to ``self.output_init``.

    * ``lin_up`` and every layer in ``lins`` get Glorot-orthogonal weights
      (scale 2.0); the ``lins`` biases are zeroed.
    * The final ``lin`` weight depends on ``self.output_init``:
      ``'zeros'`` fills it with 0, and ``'GlorotOrthogonal'`` uses
      Glorot-orthogonal init (scale 2.0).

    NOTE(review): no ``lin_rbf`` is reset here. If the enclosing class
    defines one, it keeps its current init -- confirm this is intended.

    Raises:
        ValueError: if ``self.output_init`` is not one of the two options.
            The check runs before any weights are touched, so a bad config
            leaves the block unchanged.
    """
    if self.output_init not in ('zeros', 'GlorotOrthogonal'):
        # Previously an unknown value (e.g. a typo) silently skipped the
        # final layer's initialisation.
        raise ValueError(
            f"unknown output_init {self.output_init!r}; "
            "expected 'zeros' or 'GlorotOrthogonal'"
        )

    glorot_orthogonal(self.lin_up.weight, scale=2.0)
    for lin in self.lins:
        glorot_orthogonal(lin.weight, scale=2.0)
        lin.bias.data.fill_(0)

    if self.output_init == 'zeros':
        self.lin.weight.data.fill_(0)
    else:  # 'GlorotOrthogonal' (validated above)
        glorot_orthogonal(self.lin.weight, scale=2.0)
def reset_parameters(self):
    """Re-initialise every learnable piece of the DimeNet++-style interaction block.

    * Basis projections ``lin_rbf1/2`` and ``lin_sbf1/2``: Glorot-orthogonal
      weights (scale 2.0).
    * ``lin_kj`` and ``lin_ji``: Glorot-orthogonal weights and zero biases.
    * ``lin_down`` and ``lin_up``: Glorot-orthogonal weights.
    * Residual layers before the skip, then ``lin`` (Glorot-orthogonal
      weight, zero bias), then residual layers after the skip.
    """
    glorot_orthogonal(self.lin_rbf1.weight, scale=2.0)
    glorot_orthogonal(self.lin_rbf2.weight, scale=2.0)
    glorot_orthogonal(self.lin_sbf1.weight, scale=2.0)
    glorot_orthogonal(self.lin_sbf2.weight, scale=2.0)

    glorot_orthogonal(self.lin_kj.weight, scale=2.0)
    self.lin_kj.bias.data.fill_(0)
    glorot_orthogonal(self.lin_ji.weight, scale=2.0)
    self.lin_ji.bias.data.fill_(0)

    glorot_orthogonal(self.lin_down.weight, scale=2.0)
    glorot_orthogonal(self.lin_up.weight, scale=2.0)

    for res_layer in self.layers_before_skip:
        res_layer.reset_parameters()
    glorot_orthogonal(self.lin.weight, scale=2.0)
    self.lin.bias.data.fill_(0)
    # Bug fix: this loop previously re-reset ``layers_before_skip``, so the
    # after-skip residual layers were never re-initialised.
    for res_layer in self.layers_after_skip:
        res_layer.reset_parameters()
def reset_parameters(self):
    """Give ``lin1`` and ``lin2`` Glorot-orthogonal weights (scale 2.0) and zero biases."""
    # Same order as before: lin1 weight, lin1 bias, lin2 weight, lin2 bias.
    for layer in (self.lin1, self.lin2):
        glorot_orthogonal(layer.weight, scale=2.0)
        layer.bias.data.fill_(0)
def reset_parameters(self):
    """Re-initialise the embedding block.

    * The ``emb`` weight is drawn uniformly from ``[-sqrt(3), sqrt(3)]``,
      a distribution with unit variance.
    * ``lin_rbf_0`` and ``lin`` use their own ``reset_parameters``.
    * ``lin_rbf_1`` gets a Glorot-orthogonal weight (scale 2.0).
    """
    bound = sqrt(3)
    self.emb.weight.data.uniform_(-bound, bound)
    self.lin_rbf_0.reset_parameters()
    self.lin.reset_parameters()
    glorot_orthogonal(self.lin_rbf_1.weight, scale=2.0)
def __init__(self, input_size, num_relations):
    """Create a learnable table with one ``input_size``-wide row per relation type.

    Args:
        input_size: Width of each relation's vector.
        num_relations: Number of rows (relation types).
    """
    super().__init__()
    rel_init = torch.ones((num_relations, input_size))
    self.rel = nn.parameter.Parameter(rel_init, requires_grad=True)
    # The ones are only a placeholder. glorot_orthogonal (scale 1) fills the
    # table; its return value is unused, so it is expected to act in place.
    glorot_orthogonal(self.rel, 1)