예제 #1
0
파일: hyperAE.py 프로젝트: neale/HyperMT
def z_loss(args, real, fake):
    """Compute the weighted latent loss and a separate "trick" BCE term.

    ``fake`` is pushed towards 1 with a mean-squared-error penalty. ``real``
    is treated as logits and pushed towards 0 with binary cross entropy.
    Their sum is scaled by 10. A second BCE term that pushes ``real``
    towards 1 is returned separately and is not scaled.

    Args:
        args: Not used by this function; kept so existing call sites work.
        real (torch.Tensor): logits scored with BCE-with-logits.
        fake (torch.Tensor): values compared against ones with MSE.

    Returns:
        tuple: ``(loss_z, d_real_trick)``, where
        ``loss_z = 10 * (MSE(fake, 1) + BCE_logits(real, 0))`` and
        ``d_real_trick = BCE_logits(real, 1)``. Both are scalar tensors
        (default ``'mean'`` reduction).
    """
    # Build each target from the tensor it is compared against, so that
    # ``real`` and ``fake`` may have different shapes (e.g. batch sizes).
    ones_fake = torch.ones_like(fake)
    zeros_real = torch.zeros_like(real)
    ones_real = torch.ones_like(real)

    # torch.nn.functional has no ``mse``; the mean squared error is ``mse_loss``.
    d_fake = F.mse_loss(fake, ones_fake)
    # NOTE(review): MSE is used for ``fake`` but BCE-with-logits for ``real``;
    # this asymmetry is preserved from the original. Confirm it is intended.
    d_real = F.binary_cross_entropy_with_logits(real, zeros_real)
    d_real_trick = F.binary_cross_entropy_with_logits(real, ones_real)
    loss_z = 10 * (d_fake + d_real)
    return loss_z, d_real_trick
예제 #2
0
def psnr(input: torch.Tensor, target: torch.Tensor,
         max_val: float) -> torch.Tensor:
    r"""Compute the PSNR between two images.

    PSNR is the Peak Signal-to-Noise Ratio, a log-scaled measure derived from
    the mean squared error. Given an m x n image, the PSNR is:

    .. math::

        \text{PSNR} = 10 \log_{10} \bigg(\frac{\text{MAX}_I^2}{MSE(I,T)}\bigg)

    where

    .. math::

        \text{MSE}(I,T) = \frac{1}{mn}\sum_{i=0}^{m-1}\sum_{j=0}^{n-1} [I(i,j) - T(i,j)]^2

    and :math:`\text{MAX}_I` is the maximum possible input value
    (e.g. for floating point images :math:`\text{MAX}_I=1`).

    Args:
        input (torch.Tensor): the input image with arbitrary shape :math:`(*)`.
        target (torch.Tensor): the target image, same shape as ``input``.
        max_val (float): The maximum possible value in the input tensor.

    Return:
        torch.Tensor: the computed PSNR as a scalar tensor.

    Raises:
        TypeError: if either argument is not a ``torch.Tensor``, or if the
            two tensors differ in shape.

    Examples:
        >>> ones = torch.ones(1)
        >>> psnr(ones, 1.2 * ones, 2.) # 10 * log(4/((1.2-1)**2)) / log(10)
        tensor(20.0000)

    Reference:
        https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio#Definition
    """
    # Each message reports the type of the argument that actually failed.
    if not isinstance(input, torch.Tensor):
        raise TypeError(f"Expected torch.Tensor but got {type(input)}.")

    if not isinstance(target, torch.Tensor):
        raise TypeError(f"Expected torch.Tensor but got {type(target)}.")

    # TypeError (not ValueError) is kept for compatibility with existing callers.
    if input.shape != target.shape:
        raise TypeError(
            f"Expected tensors of equal shapes, but got {input.shape} and {target.shape}"
        )

    # ``mse`` is the module-level mean-squared-error helper (imported elsewhere).
    return 10.0 * torch.log10(
        max_val**2 / mse(input, target, reduction='mean'))
예제 #3
0
def metric_fn(model, data):
    """Return the mean squared error of ``model`` on one batch.

    Args:
        model: a callable mapping ``img`` to a prediction tensor.
        data: a 2-item sequence ``(img, target)``.

    Returns:
        torch.Tensor: the mean squared error between the model's prediction
        and ``target`` (scalar, default ``'mean'`` reduction).
    """
    img, target = data
    prediction = model(img)
    # torch.nn.functional exposes the mean squared error as ``mse_loss``;
    # there is no ``F.mse``.
    return F.mse_loss(prediction, target)
예제 #4
0
def evaluate(generator,
             model,
             criterion,
             optimizer,
             normalizer,
             device,
             task="train",
             verbose=False):
    """Run ``model`` over every batch of ``generator`` in train, val or test mode.

    Each batch from ``generator`` is unpacked as
    ``(inputs, target, batch_comp, batch_ids)``. ``inputs`` is an iterable of
    tensors that are moved to ``device`` and passed positionally to ``model``.
    The model output is split into two halves along dim 1: the first half is
    denormalised into predictions, and the second half is used as ``log_std``.

    Args:
        generator: iterable of batches; ``len(generator)`` sizes the progress bar.
        model: the network; switched to ``eval()`` for val/test and
            ``train()`` for train.
        criterion: called as ``criterion(output, log_std, normalised_target)``
            for train and val.
        optimizer: stepped once per batch when ``task == "train"``.
        normalizer: provides ``norm``, ``denorm`` and ``std``.
        device: target device for the inputs and the normalised target.
        task: one of ``"train"``, ``"val"`` or ``"test"``.
        verbose: show the progress bar when true.

    Returns:
        For ``"test"``: ``(ids, comps, targets, preds, stds)`` as flat lists.
        For ``"train"``/``"val"``: ``(loss_avg, mae_avg, rmse_avg)`` taken from
        ``AverageMeter.avg`` (presumably batch-size-weighted running means;
        note that RMSE is averaged per batch, not recomputed globally).

    Raises:
        NameError: if ``task`` is not one of the three allowed values.

    NOTE(review): val/test passes run without ``torch.no_grad()``, so autograd
    graphs are still built. Consider disabling gradients if the model does not
    need them at evaluation time.
    """
    is_test = task == "test"

    if is_test:
        model.eval()
        ids_out, comps_out, targets_out, preds_out, stds_out = [], [], [], [], []
    else:
        loss_meter = AverageMeter()
        rmse_meter = AverageMeter()
        mae_meter = AverageMeter()
        if task == "val":
            model.eval()
        elif task == "train":
            model.train()
        else:
            raise NameError("Only train, val or test is allowed as task")

    with trange(len(generator), disable=(not verbose)) as progress:
        for inputs, target, batch_comp, batch_ids in generator:
            # The loss is computed against the normalised target on the device.
            target_norm = normalizer.norm(target).to(device)

            output, log_std = model(
                *(tensor.to(device) for tensor in inputs)
            ).chunk(2, dim=1)

            # Predictions back in the original target units, on the CPU.
            pred = normalizer.denorm(output.data.cpu())

            if is_test:
                # Presumably the aleatoric std rescaled to target units.
                std = torch.exp(log_std).data.cpu() * normalizer.std
                ids_out.extend(batch_ids)
                comps_out.extend(batch_comp)
                targets_out.extend(target.view(-1).tolist())
                preds_out.extend(pred.view(-1).tolist())
                stds_out.extend(std.view(-1).tolist())
            else:
                loss = criterion(output, log_std, target_norm)
                batch_size = target.size(0)
                loss_meter.update(loss.data.cpu().item(), batch_size)
                mae_meter.update(mae(pred, target), batch_size)
                rmse_meter.update(mse(pred, target).sqrt_(), batch_size)

                if task == "train":
                    # One optimisation step per batch.
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

            progress.update()

    if is_test:
        return ids_out, comps_out, targets_out, preds_out, stds_out
    return loss_meter.avg, mae_meter.avg, rmse_meter.avg
예제 #5
0
 def MSE(self, preds):
     """Return the mean squared error of ``preds`` against ``self.targets``.

     Delegates to the module-level ``mse`` helper (imported elsewhere).
     """
     # Uses the built-in loss helper instead of a hand-written formula.
     expected = self.targets
     return mse(preds, expected)