Exemplo n.º 1
0
 def __init__(self, input_dim, output_dim):
     """Create a bias-free linear projection initialised with ``nn.init.eye_``.

     Args:
         input_dim: Number of input features of the linear layer.
         output_dim: Number of output features of the linear layer.
     """
     super(Log_w_DisNet, self).__init__()
     projection = nn.Linear(input_dim, output_dim, bias=False)
     # Module.cuda() moves the parameters in place and returns the module.
     self.dimLinear = projection.cuda() if check_gpu() else projection
     # Identity initialisation is used here rather than Kaiming-uniform.
     nn.init.eye_(self.dimLinear.weight)
Exemplo n.º 2
0
    def forward(self, x):
        """Score ``x`` against the dictionary and turn distances into class scores.

        The input is encoded, its distances to the codebooks are computed, and
        a softmax over the negatively scaled distances (along ``dim=1``) is
        multiplied with the one-hot labels held by ``self.dictLayer``.
        ``self.classifier`` is not used by this method.

        Args:
            x: Input passed to ``self.encoding``.

        Returns:
            tuple: ``(classifier_output, dist)``, both moved to the CPU.
            ``dist`` is the unscaled distance tensor.
        """
        feature = self.encoding(x)
        dist = self.calc_distance_between_codebooks(feature)
        if check_gpu():
            dist = dist.cuda()

        # Negative scale: smaller distances receive larger softmax weights.
        dist2 = nn.functional.softmax(dist * -5, dim=1)

        # Put the label matrix on the same device as the weights. The previous
        # unconditional ``.cuda()`` failed on hosts without a GPU even though
        # the rest of this method is guarded by ``check_gpu()``.
        # NOTE(review): presumably one_hot returns an (n_atom, n_class) matrix
        # so the product is (batch, n_class) — confirm against its definition.
        one_hot_labels = (
            one_hot(self.n_class, self.dictLayer.labels).float().to(dist2.device)
        )
        classifier_output = dist2 @ one_hot_labels
        return classifier_output.cpu(), dist.cpu()
Exemplo n.º 3
0
    def forward(self, x, y, dij=None):
        """Return the row-wise norm between paired, flattened ``LogEigFunction`` outputs.

        Both inputs go through ``LogEigFunction.apply`` and are flattened to
        ``(size(0), -1)``. ``pair`` is then called with ``y`` first and ``x``
        second; its result is split column-wise at the flattened width of
        ``x`` and the norm of the difference of the two halves is taken
        along ``dim=1``.

        Args:
            x: First input tensor.
            y: Second input tensor.
            dij: Optional extra argument forwarded unchanged to ``pair``.

        Returns:
            Tensor with one norm value per row of the paired tensor.
        """
        # NOTE(review): presumably a matrix logarithm via eigendecomposition —
        # confirm against LogEigFunction's definition.
        log_x = LogEigFunction.apply(x)
        log_x = log_x.view(log_x.size(0), -1)
        log_y = LogEigFunction.apply(y)
        log_y = log_y.view(log_y.size(0), -1)

        paired = pair(log_y, log_x, dij)
        if check_gpu():
            paired = paired.cuda()

        # Columns before split_at form one half of each pair, the rest the other.
        split_at = log_x.size(1)
        first_half = paired[:, :split_at].contiguous()
        second_half = paired[:, split_at:].contiguous()

        # The small constant is added to the difference before the norm.
        return torch.norm(first_half - second_half + 1e-16, dim=1)
Exemplo n.º 4
0
    def __init__(self, args):
        """Assemble the feature stack, dictionary layer, distance module and classifier.

        Args:
            args: Configuration object. This method reads ``ep``, ``dims`` (a
                comma-separated string of integers), ``n_atom``, ``margin1``,
                ``margin2``, ``metric_method`` (``"log"``, ``"log_w"`` or
                ``"jbld"``), ``log_dim`` (only for ``"log_w"``), ``n_class``,
                ``lambda1``, ``lambda2``, ``n_fc`` and ``n_fc_node`` (only
                when ``n_fc`` is not 1).

        Raises:
            NotImplementedError: If ``args.metric_method`` is not one of the
                supported values.
        """
        super(BMS_Net, self).__init__()
        self.ep = args.ep
        dims = [int(token) for token in args.dims.split(",")]

        # BiMap + ReEig for every consecutive pair of dims except the last
        # pair, which gets a BiMap with no ReEig after it.
        stages = []
        for dim_in, dim_out in zip(dims[:-2], dims[1:-1]):
            stages += [BiMap(dim_in, dim_out), ReEig(self.ep)]
        stages.append(BiMap(dims[-2], dims[-1]))
        self.feature = nn.Sequential(*stages)

        self.dictLayer = RieDictionaryLayer(args.n_atom, dims[-1], args.margin1)

        metric = args.metric_method
        if metric == "log":
            self.distFun = Log_DisNet(dims[-1] ** 2)
        elif metric == "log_w":
            self.distFun = Log_w_DisNet(dims[-1] ** 2, args.log_dim)
        elif metric == "jbld":
            self.distFun = JBLD_DisNet()
        else:
            raise NotImplementedError

        self.margin2 = args.margin2
        self.n_class = args.n_class
        # A loss term is enabled only when its weight is non-zero.
        self.use_intra_loss = args.lambda2 != 0
        self.use_triplet_loss = args.lambda1 != 0

        # Bias-free classifier over the n_atom inputs: a single linear layer
        # when n_fc == 1, otherwise hidden layers of width n_fc_node with
        # in-place ReLU before every layer after the first.
        if args.n_fc == 1:
            head_layers = [nn.Linear(args.n_atom, args.n_class, bias=False)]
        else:
            head_layers = [nn.Linear(args.n_atom, args.n_fc_node, bias=False)]
            for _ in range(args.n_fc - 2):
                head_layers += [
                    nn.ReLU(True),
                    nn.Linear(args.n_fc_node, args.n_fc_node, bias=False),
                ]
            head_layers += [
                nn.ReLU(True),
                nn.Linear(args.n_fc_node, args.n_class, bias=False),
            ]

        self.classifier = nn.Sequential(*head_layers)
        if check_gpu():
            self.classifier = self.classifier.cuda()