def __init__(self, c, args):
    """Build the stack of hyperbolic feed-forward (HNN) layers.

    Args:
        c: Curvature handed to the base-class constructor. The layers below
            read the curvature back as ``self.c`` (presumably assigned by the
            base class -- confirm).
        args: Configuration namespace. This method reads ``manifold``,
            ``num_layers``, ``dropout`` and ``bias``, and passes the whole
            object to ``hyp_layers.get_dim_act_curv``.

    Raises:
        AssertionError: if ``args.num_layers`` is not greater than 1.
    """
    super(HNN, self).__init__(c)
    # ``args.manifold`` names a class in the ``manifolds`` module; instantiate it.
    manifold_cls = getattr(manifolds, args.manifold)
    self.manifold = manifold_cls()
    assert args.num_layers > 1
    # Only the per-layer widths and activations are used; the third value
    # returned by get_dim_act_curv is discarded.
    dims, acts, _ = hyp_layers.get_dim_act_curv(args)
    # Transition k maps width dims[k] -> dims[k + 1] with activation acts[k].
    # NOTE(review): this passes one curvature (self.c) to HNNLayer, whereas
    # HNNDecoder.__init__ passes separate c_in / c_out -- confirm which
    # signature hyp_layers.HNNLayer actually has.
    layer_list = [
        hyp_layers.HNNLayer(
            self.manifold,
            dims[k],
            dims[k + 1],
            self.c,
            args.dropout,
            acts[k],
            args.bias,
        )
        for k in range(len(dims) - 1)
    ]
    self.layers = nn.Sequential(*layer_list)
    # Flag presumably read by the encoder base class to decide whether the
    # graph/adjacency is passed to the layers -- confirm against the base class.
    self.encode_graph = False
def __init__(self, c, args, task):
    """Build the decoder head for ``task``.

    Args:
        c: Curvature passed to the base-class constructor. When
            ``args.cuda != -1`` a local copy is rewrapped as
            ``torch.Tensor([c])`` on ``args.device`` before it is handed to
            the 'rec' layers; otherwise it is used as given.
        args: Configuration namespace. Always reads ``manifold`` and ``cuda``
            (and ``device`` when ``cuda != -1``). 'nc' reads ``dim``,
            ``n_classes``, ``bias`` and ``dropout``. 'rec' reads
            ``num_layers``, ``num_dec_layers``, ``dropout`` and ``bias``, and
            passes ``args`` to ``hyp_layers.get_dim_act_curv``.
        task: ``'nc'`` builds a single ``Linear`` classifier
            (``args.dim`` -> ``args.n_classes``, identity activation).
            ``'rec'`` builds ``args.num_dec_layers`` ``HNNLayer``s that walk
            the encoder widths in reverse order.

    Raises:
        AssertionError: for 'rec' when ``args.num_layers`` is not positive.
        ValueError: for 'rec' when ``args.num_dec_layers`` exceeds the number
            of width transitions available from the reversed encoder dims.
        RuntimeError: for any other ``task``.
    """
    super(HNNDecoder, self).__init__(c)
    self.manifold = getattr(manifolds, args.manifold)()

    def identity(x):
        return x

    if args.cuda != -1:
        # NOTE(review): only the local ``c`` is converted; on CPU it keeps its
        # original form, and any ``self.c`` stored by the base class is not
        # updated. If ``c`` is a trainable tensor, rewrapping it here creates
        # a new tensor -- confirm this is intended.
        c = torch.Tensor([c]).to(args.device)

    if task == 'nc':
        self.input_dim = args.dim
        self.output_dim = args.n_classes
        self.bias = args.bias
        self.classifier = Linear(self.input_dim, self.output_dim, args.dropout, identity, self.bias)
        self.decode_adj = False
    elif task == 'rec':
        # NOTE(review): this guards num_layers, while the loop below is sized
        # by num_dec_layers -- confirm which one is meant to be validated.
        assert args.num_layers > 0
        dims, acts, _ = hyp_layers.get_dim_act_curv(args)
        # Decode by walking the encoder widths/activations back to front.
        dims = dims[::-1]
        acts = acts[::-1]
        num_dec_layers = args.num_dec_layers
        if num_dec_layers > len(dims) - 1:
            raise ValueError(
                f"num_dec_layers={num_dec_layers} exceeds the {len(dims) - 1} "
                f"width transitions available for decoding"
            )
        hnn_layers = []
        for i in range(num_dec_layers):
            is_last = i == num_dec_layers - 1
            # The final decoder layer has no activation and no output
            # curvature. This is keyed on the actual last decoder layer, so it
            # still holds when num_dec_layers < len(acts).
            act = identity if is_last else acts[i]
            c_out = None if is_last else c
            # NOTE(review): HNN.__init__ in this file passes a single curvature
            # to HNNLayer; this call passes c_in and c_out separately --
            # confirm against hyp_layers.HNNLayer's signature.
            hnn_layers.append(
                hyp_layers.HNNLayer(
                    self.manifold, dims[i], dims[i + 1], c, c_out, args.dropout, act, args.bias
                )
            )
        self.decoder = nn.Sequential(*hnn_layers)
        self.decode_adj = False
    else:
        raise RuntimeError('Unknown task')