예제 #1
0
 def __init__(self, c, args):
     """Construct the encoder as a stack of hyperbolic feed-forward layers.

     Args:
         c: Curvature handed to the base-class constructor. The layers are
             built with ``self.c``, which is presumably derived from ``c`` by
             that constructor (not visible here -- confirm).
         args: Configuration namespace. This method reads ``manifold``,
             ``num_layers``, ``dropout`` and ``bias``, and passes the whole
             object to ``hyp_layers.get_dim_act_curv``.

     Raises:
         AssertionError: if ``args.num_layers`` is not greater than 1.
     """
     super(HNN, self).__init__(c)
     # Resolve the manifold class by its configured name and instantiate it.
     self.manifold = getattr(manifolds, args.manifold)()
     assert args.num_layers > 1
     dims, acts, _ = hyp_layers.get_dim_act_curv(args)
     # One HNNLayer per consecutive (in_dim, out_dim) pair of ``dims``;
     # layer k uses activation ``acts[k]``. The third value returned by the
     # helper (presumably per-layer curvatures, judging by its name) is unused.
     self.layers = nn.Sequential(*[
         hyp_layers.HNNLayer(self.manifold, d_in, d_out, self.c,
                             args.dropout, acts[k], args.bias)
         for k, (d_in, d_out) in enumerate(zip(dims, dims[1:]))
     ])
     # NOTE(review): presumably signals that this encoder does not consume
     # the graph adjacency -- confirm against the caller.
     self.encode_graph = False
예제 #2
0
    def __init__(self, c, args, task):
        """Build the decoder head for ``task``.

        Args:
            c: Curvature value. It is passed unchanged to the base-class
                constructor. When ``args.cuda != -1`` it is then wrapped in a
                one-element tensor on ``args.device``, and that tensor is what
                the ``'rec'`` decoder layers receive.
            args: Configuration namespace. Always reads ``manifold``, ``cuda``
                and (on GPU) ``device``. For ``'nc'`` it also reads ``dim``,
                ``n_classes``, ``bias`` and ``dropout``. For ``'rec'`` it also
                reads ``num_layers``, ``num_dec_layers``, ``dropout`` and
                ``bias``, and passes ``args`` to ``hyp_layers.get_dim_act_curv``.
            task: ``'nc'`` builds a single ``Linear`` classifier in
                ``self.classifier``. ``'rec'`` builds an ``nn.Sequential`` of
                ``hyp_layers.HNNLayer`` modules in ``self.decoder``.

        Raises:
            AssertionError: if ``task == 'rec'`` and ``args.num_layers <= 0``.
            RuntimeError: if ``task`` is neither ``'nc'`` nor ``'rec'``.
        """
        super(HNNDecoder, self).__init__(c)
        self.manifold = getattr(manifolds, args.manifold)()

        # Place the curvature on the configured device unless cuda == -1
        # (NOTE(review): -1 presumably means "CPU only" -- confirm in arg parsing).
        if args.cuda != -1:
            c = torch.Tensor([c]).to(args.device)

        if task == 'nc':
            self.input_dim = args.dim
            self.output_dim = args.n_classes
            self.bias = args.bias
            # Identity activation: no non-linearity on the classifier output.
            self.classifier = Linear(self.input_dim, self.output_dim, args.dropout, lambda x: x, self.bias)
            self.decode_adj = False

        elif task == 'rec':
            assert args.num_layers > 0

            dims, acts, _ = hyp_layers.get_dim_act_curv(args)
            # Walk the dimension list backwards (presumably mirroring the encoder).
            dims = dims[::-1]
            # Reverse the activations and replace the final one with identity.
            acts = acts[::-1][:-1] + [lambda x: x]

            num_dec_layers = args.num_dec_layers
            # NOTE(review): the identity activation above lands on the last
            # decoder layer only when num_dec_layers == len(acts); and
            # num_dec_layers must be <= len(dims) - 1 or dims[i + 1] raises
            # IndexError. Confirm the configuration guarantees this.
            hnn_layers = []
            for i in range(num_dec_layers):
                is_last = i == num_dec_layers - 1
                hnn_layers.append(
                    hyp_layers.HNNLayer(
                        self.manifold,
                        dims[i],
                        dims[i + 1],
                        c,  # input curvature
                        # Output curvature is None on the final layer
                        # (semantics defined by HNNLayer -- not visible here).
                        None if is_last else c,
                        args.dropout,
                        acts[i],
                        args.bias,
                    )
                )

            self.decoder = nn.Sequential(*hnn_layers)
            self.decode_adj = False
        else:
            raise RuntimeError('Unknown task')