Example #1
0
 def __init__(self, reduction="mean"):
     """Create the operators and loss weights stored on this G_Loss instance.

     Args:
         reduction: Forwarded unchanged to the parent class initializer.
             Defaults to "mean".
     """
     super(G_Loss, self).__init__(reduction)
     # Operators are instantiated once here and kept as attributes; how they
     # are combined is defined outside this method.
     self.sig = SigmoidCrossEntropyWithLogits()
     self.l1_loss = nn.L1Loss()
     self.ones = ops.OnesLike()
     # Weights are copied from `args`, a name that is not a parameter of this
     # method (presumably a module-level config object -- TODO confirm).
     self.LAMBDA_GAN = args.LAMBDA_GAN
     self.LAMBDA_L1 = args.LAMBDA_L1
    def construct(self, x, seq_lengths):
        """Compute the ReverseSequence operation on ``x``.

        Gather indices are built per (batch, step) position and the result is
        read from ``x`` with ``GatherNd``. For each position the index
        ``seq_lengths - 1 - step`` is used; wherever that value is negative
        the untouched ``step`` index is used instead.

        NOTE(review): assumes ``self.make_shape(shape, dtype, axis)`` returns
        an index grid of ``shape`` counting along ``axis`` -- the helper is not
        visible here, confirm. The fixed ``(1, 0, 2)`` permutation implies a
        rank-3 ``x``.

        Args:
            x: Input tensor; ``self.batch_dim`` / ``self.seq_dim`` select its
                batch and sequence axes.
            seq_lengths: Per-batch sequence lengths; its dtype is reused for
                the generated indices.

        Returns:
            The tensor gathered from ``x`` with the computed indices.
        """
        n_batch = x.shape[self.batch_dim]
        n_steps = x.shape[self.seq_dim]
        idx_dtype = seq_lengths.dtype
        grid_shape = (n_batch, n_steps)

        # seq_lengths - 1, later reshaped into a column so it broadcasts
        # against the (batch, step) grid.
        last_valid = ops.Sub()(seq_lengths, ops.OnesLike()(seq_lengths))

        row_idx = self.make_shape(grid_shape, idx_dtype, 0)
        step_idx = self.make_shape(grid_shape, idx_dtype, 1)

        last_valid = last_valid.view(-1, 1)
        mirrored = ops.Sub()(last_valid, step_idx)

        # Negative entries fall back to the plain step index.
        is_negative = ops.Less()(mirrored, ops.ZerosLike()(mirrored))
        mirrored = ops.Select()(is_negative, step_idx, mirrored)

        # Add a trailing axis so the two index grids can be concatenated
        # into (row, step) pairs.
        mirrored = ops.ExpandDims()(mirrored, 2)
        row_idx = ops.ExpandDims()(row_idx, 2)

        source = x
        if self.batch_dim > self.seq_dim:
            # Swap the first two axes of the indices and of the input.
            perm = (1, 0, 2)
            row_idx = ops.Transpose()(row_idx, perm)
            mirrored = ops.Transpose()(mirrored, perm)
            source = ops.Transpose()(x, perm)
        gather_idx = ops.Concat(2)((row_idx, mirrored))

        return ops.GatherNd()(source, gather_idx)
Example #3
0
 def __init__(self, mode="lsgan", reduction='mean'):
     """Select the loss implementation used by this GANLoss instance.

     Args:
         mode: "lsgan" selects ``nn.MSELoss``; "vanilla" selects
             ``BCEWithLogits``. Defaults to "lsgan".
         reduction: Passed to the selected loss constructor.
             Defaults to 'mean'.

     Raises:
         NotImplementedError: If ``mode`` is neither "lsgan" nor "vanilla".
     """
     super(GANLoss, self).__init__()
     self.loss = None
     self.ones = ops.OnesLike()

     # Reject unknown modes up front, then pick between the two supported ones.
     if mode not in ("lsgan", "vanilla"):
         raise NotImplementedError(
             f'GANLoss {mode} not recognized, we support lsgan and vanilla.'
         )

     if mode == "lsgan":
         self.loss = nn.MSELoss(reduction)
     else:
         self.loss = BCEWithLogits(reduction)
Example #4
0
 def __init__(self, reduction="mean"):
     """Create the operators and the loss weight stored on this D_Loss instance.

     Args:
         reduction: Forwarded unchanged to the parent class initializer.
             Defaults to "mean".
     """
     super(D_Loss, self).__init__(reduction)
     # Operators are instantiated once here and kept as attributes; how they
     # are combined is defined outside this method.
     self.sig = SigmoidCrossEntropyWithLogits()
     self.ones = ops.OnesLike()
     self.zeros = ops.ZerosLike()
     # Weight is copied from `args`, a name that is not a parameter of this
     # method (presumably a module-level config object -- TODO confirm).
     self.LAMBDA_Dis = args.LAMBDA_Dis
Example #5
0
    def construct(self, inputs, targets):
        """Compute the ranking loss over the hardest pair of every anchor.

        For each row of ``inputs`` the largest distance to a sample with the
        same label and the smallest distance to a sample with a different
        label are selected, then both vectors are passed to
        ``self.ranking_loss``.

        Args:
            inputs: feature matrix with shape (batch_size, feat_dim)
            targets: ground truth labels, one per row of ``inputs``; they are
                broadcast to (batch_size, batch_size) to build the label mask

        Returns:
            The value of ``self.ranking_loss(dist_an, dist_ap, y)`` where
            ``y`` is all ones.
        """
        n = inputs.shape[0]

        # Operator instances (named so that no Python builtin is shadowed).
        square = P.Pow()
        row_sum = P.ReduceSum(keep_dims=True)
        expand = P.BroadcastTo((n, n))
        transpose = P.Transpose()
        mul = P.Mul()
        add = P.Add()
        sqrt = P.Sqrt()
        equal = P.Equal()
        ones_like = P.OnesLike()

        # Pairwise distance via ||a||^2 + ||b||^2 - 2 * a.b
        # (replace by the official operator when merged).
        dist = square(inputs, 2)
        dist = row_sum(dist, axis=1)
        dist = expand(dist)
        dist = dist + transpose(dist, (1, 0))

        cross = P.matmul(inputs, transpose(inputs, (1, 0)))
        cross = mul(-2, cross)
        dist = add(dist, cross)
        # Clip before the square root for numerical stability; the lower
        # bound keeps every distance strictly positive.
        dist = P.composite.clip_by_value(
            dist, clip_value_min=1e-12, clip_value_max=100000000
        )
        dist = sqrt(dist)

        # For each anchor, find the hardest positive and negative.
        targets = expand(targets)
        mask = equal(targets, transpose(targets, (1, 0)))
        hardest_pos = []
        hardest_neg = []

        for i in range(n):
            row_mask = mask[i]
            row_dist = dist[i]
            # -1.0 is the "nothing found yet" sentinel; real distances are
            # positive because of the clip above.
            minval = -1.0
            maxval = -1.0
            for j in range(n):
                same_label = row_mask[j]
                d = row_dist[j]
                if same_label:
                    if d > maxval:
                        maxval = d
                elif minval == -1 or d < minval:
                    minval = d

            if (not isinstance(minval, Tensor)
                    or not isinstance(maxval, Tensor) or minval == -1.0
                    or maxval == -1.0):
                # Dump diagnostics for an anchor that lacks a valid positive
                # or negative before the failure below.
                if self.error_msg is not None:
                    print("Error Msg", file=self.error_msg)
                    print("mask {} is".format(i), file=self.error_msg)
                    print(row_mask, file=self.error_msg)
                    print("dist is:", file=self.error_msg)
                    print(row_dist, file=self.error_msg)
                    print(maxval, file=self.error_msg)
                    print(minval, file=self.error_msg)
                    print(type(maxval), file=self.error_msg)
                    print(type(minval), file=self.error_msg)
                    self.error_msg.flush()

            # If a sentinel survived (e.g. the batch holds no sample with a
            # different label), it is still a plain float and the
            # ``.asnumpy()`` call raises AttributeError here.
            hardest_pos.append(maxval.asnumpy())
            hardest_neg.append(minval.asnumpy())

        # NOTE(review): both vectors are rebuilt from host-side numpy values,
        # so presumably they are no longer connected to ``inputs`` for
        # automatic differentiation -- confirm before training on this loss.
        dist_ap = Tensor(hardest_pos, ms.float32)
        dist_an = Tensor(hardest_neg, ms.float32)

        # Compute ranking hinge loss
        y = ones_like(dist_an)
        loss = self.ranking_loss(dist_an, dist_ap, y)
        return loss


# class GradOriTripletLoss(nn.Cell)
#     def __init__(self, net):
#         super(GradOriTripletLoss, self).__init__()
#         self.net = net
#         self.grad_op = P.GradOperation(get_all=True)
#
#     def construct(self, inputs, targets):
#         gradient_function = self.grad_op(self.net)
#         return gradient_function(inputs, targets)
Example #6
0
 def __init__(self, G_A, G_B, use_identity=True):
     """Store the two wrapped objects and the identity flag on this Generator.

     Args:
         G_A: Stored as ``self.G_A`` (presumably one of the two generator
             networks -- TODO confirm against the caller).
         G_B: Stored as ``self.G_B`` (presumably the other generator
             network -- TODO confirm against the caller).
         use_identity: Stored as ``self.use_identity``; how it is consumed
             is not visible in this method. Defaults to True.
     """
     super(Generator, self).__init__()
     self.G_A = G_A
     self.G_B = G_B
     # Operator instance kept as an attribute for use outside this method.
     self.ones = ops.OnesLike()
     self.use_identity = use_identity