def __init__(self, c, args):
    """Set up a lookup-table encoder, optionally followed by MLP layers.

    The embedding table ``self.lt`` is either initialised through
    ``self.manifold.init_weights`` (trainable) or loaded with ``np.load``
    from ``args.pretrained_embeddings`` (not trainable). MLP layers are
    only stacked on top when a pretrained table is used and
    ``args.num_layers > 0``; otherwise ``self.layers`` is an empty
    ``nn.Sequential``.

    Args:
        c: Forwarded unchanged to the parent initialiser. ``self.c`` is
            read afterwards, so the parent presumably stores it there --
            TODO confirm against the base class.
        args: Configuration object. Attributes read here: ``manifold``,
            ``use_feats``, ``n_nodes``, ``dim`` (random init only),
            ``pretrained_embeddings`` and ``num_layers``; plus
            ``feat_dim``, ``dropout`` and ``bias`` when MLP layers are
            built.

    Raises:
        AssertionError: If the loaded table's first dimension differs
            from ``args.n_nodes``.
    """
    super(Shallow, self).__init__(c)
    self.manifold = getattr(manifolds, args.manifold)()
    self.use_feats = args.use_feats

    # Kept on the instance because reset_parameteres reads
    # self.pretrained_embeddings, self.n_nodes and self.weights; without
    # these assignments it would fail with AttributeError.
    self.n_nodes = args.n_nodes
    self.pretrained_embeddings = args.pretrained_embeddings

    # One flag for both decisions below, so a falsy-but-not-None value
    # (e.g. "") cannot select random init and still build the MLP stack.
    use_pretrained = bool(args.pretrained_embeddings)

    if use_pretrained:
        weights = torch.Tensor(np.load(args.pretrained_embeddings))
        assert weights.shape[0] == args.n_nodes, "The embeddings you passed seem to be for another dataset."
        trainable = False
    else:
        # Allocate the uninitialised table only when it is actually used.
        weights = self.manifold.init_weights(
            torch.Tensor(args.n_nodes, args.dim), self.c)
        trainable = True

    self.weights = weights
    self.lt = manifolds.ManifoldParameter(weights, trainable, self.manifold, self.c)
    # Same int64 index tensor as LongTensor(list(range(n))), built without
    # the intermediate Python list.
    self.all_nodes = torch.arange(args.n_nodes)

    layers = []
    if use_pretrained and args.num_layers > 0:
        # MLP layers after pre-trained embeddings
        dims, acts = get_dim_act(args)
        # The first layer consumes the embedding, plus the node features
        # when use_feats is set.
        dims[0] = weights.shape[1] + (args.feat_dim if self.use_feats else 0)
        for i, (in_dim, out_dim) in enumerate(zip(dims, dims[1:])):
            layers.append(
                Linear(in_dim, out_dim, args.dropout, acts[i], args.bias))
    self.layers = nn.Sequential(*layers)
    self.encode_graph = False
def reset_parameteres(self):
    """Rebuild the embedding table and reset every layer in ``self.layers``.

    When ``self.pretrained_embeddings`` is truthy the table is reloaded
    from that path with ``np.load`` and marked non-trainable; otherwise
    ``self.weights`` is handed to ``self.manifold.init_weights`` and the
    result is marked trainable. ``self.lt`` and ``self.all_nodes`` are then
    replaced, and ``reset_parameters()`` is called on each entry of
    ``self.layers``.

    Relies on ``self.pretrained_embeddings``, ``self.weights``,
    ``self.n_nodes``, ``self.manifold``, ``self.c`` and ``self.layers``
    already being set on the instance.

    Raises:
        AssertionError: If the reloaded table's first dimension differs
            from ``self.n_nodes``.
    """
    pretrained_path = self.pretrained_embeddings
    if pretrained_path:
        table = torch.Tensor(np.load(pretrained_path))
        assert table.shape[0] == self.n_nodes, "The embeddings you passed seem to be for another dataset."
        requires_grad = False
    else:
        table = self.manifold.init_weights(self.weights, self.c)
        requires_grad = True

    self.lt = manifolds.ManifoldParameter(
        table, requires_grad, self.manifold, self.c)

    node_ids = list(range(self.n_nodes))
    self.all_nodes = torch.LongTensor(node_ids)

    for sub_layer in self.layers:
        sub_layer.reset_parameters()
def __init__(self, users_items, args):
    """Create the encoder and the joint user/item embedding table.

    Args:
        users_items: Two-element iterable unpacked into
            ``(num_users, num_items)``.
        args: Configuration object. Attributes read here: ``c``,
            ``n_nodes``, ``margin``, ``weight_decay``, ``num_layers``,
            ``embedding_dim`` and ``scale``. It is also passed whole to the
            encoder and kept as ``self.args``.
    """
    super(HGCFModel, self).__init__()

    # Plain configuration values.
    self.args = args
    self.nnodes = args.n_nodes
    self.margin = args.margin
    self.weight_decay = args.weight_decay
    self.num_layers = args.num_layers
    self.num_users, self.num_items = users_items

    self.c = torch.tensor([args.c]).to(default_device())
    self.manifold = manifolds.Hyperboloid()
    self.encoder = encoders.HGCF(self.c, args)

    # One embedding row per user followed by one per item.
    total_rows = self.num_users + self.num_items
    table = nn.Embedding(num_embeddings=total_rows,
                         embedding_dim=args.embedding_dim)
    self.embedding = table.to(default_device())

    # state_dict() returns a detached tensor that shares storage with the
    # weight, so the in-place uniform_ below updates the embedding itself.
    raw_weight = self.embedding.state_dict()['weight']
    raw_weight.uniform_(-args.scale, args.scale)

    # Pass the freshly drawn values through the manifold's expmap0, then
    # wrap the result as a trainable ManifoldParameter.
    mapped = nn.Parameter(self.manifold.expmap0(raw_weight, self.c))
    self.embedding.weight = manifolds.ManifoldParameter(
        mapped, True, self.manifold, self.c)