Exemplo n.º 1
0
    def sum_losses(self, losses: List[Loss]) -> Loss:
        """Return the weighted sum of ``losses``.

        The loss at position ``k`` is multiplied by ``self.get_coef()[k]``.
        Each coefficient is ``.detach()``-ed first, so this sum does not
        propagate gradients into the coefficients themselves.
        """
        total = Loss.ZERO()
        weights = self.get_coef()
        for position, term in enumerate(losses):
            total = total + term * weights[position].detach()
        return total
 def forward(self, m1: ProbabilityMeasure, m2: ProbabilityMeasure):
     """Compute the mean of ``self.loss`` between two probability measures.

     ``self.loss`` receives ``(m1.probability, m1.coord, m2.probability,
     m2.coord)``. When ``self.border`` is truthy, only entries strictly
     greater than it are kept, and ``Loss.ZERO()`` is returned if none
     survive. A falsy ``border`` (e.g. ``None`` or ``0``) disables filtering.
     """
     values = self.loss(
         m1.probability, m1.coord,
         m2.probability, m2.coord,
     )
     if not self.border:
         return Loss(values.mean())

     kept = values[values > self.border]
     if kept.shape[0] == 0:
         return Loss.ZERO()
     return Loss(kept.mean())
Exemplo n.º 3
0
 def add_generator_loss(self, loss: nn.Module, weight=1.0):
     """Return ``self`` combined with one extra weighted loss term.

     A new ``GANLossObject`` is built from two callables and
     ``self.discriminator``, then combined via ``self.__add__``:

     * ``(dx, dy)`` -> always ``Loss.ZERO()`` (presumably the
       discriminator-side term — confirm against ``GANLossObject``);
     * ``(dgz, real, fake)`` -> ``Loss(loss(fake[0], real[0].detach()) * weight)``.
     """
     def zero_term(dx, dy):
         return Loss.ZERO()

     def weighted_term(dgz, real, fake):
         # real[0] is detached so no gradient flows back through it.
         return Loss(loss(fake[0], real[0].detach()) * weight)

     extra = GANLossObject(
         zero_term,
         weighted_term,
         self.discriminator,
     )
     return self.__add__(extra)
Exemplo n.º 4
0
 def __call__(self, *args, **kwargs):
     """Evaluate ``self.penalty`` on the calls selected by ``self.cond``.

     Every call first increments the call counter ``self.CLLLeT4uK`` and
     passes the new value to ``self.cond``. If that is falsy, the call
     returns ``Loss.ZERO()``.

     Otherwise the arguments are forwarded to ``self.penalty``:
     * With ``self.preproc`` set, they go through ``preproc`` first. Its
       result is unpacked positionally, and the original keyword arguments
       are not passed on to ``penalty``.
     * Without it, ``args`` and ``kwargs`` are passed as-is.
     """
     self.CLLLeT4uK += 1
     if not self.cond(self.CLLLeT4uK):
         return Loss.ZERO()

     if not self.preproc:
         return self.penalty(*args, **kwargs)
     return self.penalty(*self.preproc(*args, **kwargs))
def hm_svoego_roda_loss(pred, target, coef=1.0, l1_coef=0.0, mse_coef=0.0005):
    """Combined heatmap loss: BCE on heatmaps plus MSE/L1 on derived coords.

    Both heatmaps are converted with ``UniformMeasure2DFactory.from_heatmap``;
    the coordinate terms compare the resulting ``.coord`` attributes.

    Args:
        pred: predicted heatmap tensor. ``nn.BCELoss`` requires values in
            [0, 1] (presumably post-sigmoid — confirm with callers).
        target: target heatmap tensor, same shape as ``pred``.
        coef: weight of the BCE term; also scales the MSE term.
        l1_coef: weight of the L1 coordinate term (0 disables it).
        mse_coef: relative weight of the MSE coordinate term; the effective
            weight is ``mse_coef * coef``. The default 0.0005 reproduces the
            previously hard-coded value.

    Returns:
        ``Loss`` wrapping ``bce*coef + mse*(mse_coef*coef) + l1*l1_coef``.
        Returns ``Loss.ZERO()`` and prints a message if ``pred``/``target``
        contain NaN or if the BCE value is NaN.
    """
    # Check the raw inputs before doing any derived work on them.
    if torch.isnan(pred).any() or torch.isnan(target).any():
        print("nan in hm")
        return Loss.ZERO()

    bce = nn.BCELoss()(pred, target)
    if torch.isnan(bce).any():
        print("nan in bce")
        return Loss.ZERO()

    # The measures are only needed for the coordinate terms, so they are
    # built after the NaN guards rather than unconditionally up front.
    pred_mes = UniformMeasure2DFactory.from_heatmap(pred)
    target_mes = UniformMeasure2DFactory.from_heatmap(target)

    mse = nn.MSELoss()(pred_mes.coord, target_mes.coord)
    l1 = nn.L1Loss()(pred_mes.coord, target_mes.coord)
    return Loss(bce * coef + mse * (mse_coef * coef) + l1 * l1_coef)