Exemplo n.º 1
0
def masked_loss(
    got,
    exp,
    throughput,
    exp_mask,
    eps: float = 1e-10,
    trim: int = 0,
    mask_weight: float = 1,
    with_logits: bool = True,
    tone_mapping: bool = False,
):
    """Combined color + mask loss of a rendered image against a reference.

    Pixels where ``throughput > 0`` and ``exp_mask == 1`` are "active": there
    ``got`` is compared with ``exp`` using L2 + RMSE + L1 + (-log SSIM).
    All other pixels are "misses": there ``throughput`` is fit to
    ``exp_mask`` with binary cross-entropy.

    Args:
        got: predicted image, 4-D channels-last (the SSIM call permutes it
            to NCHW).
        exp: reference image, same shape as ``got``.
        throughput: per-pixel values whose trailing singleton dim is squeezed
            to build the active mask; BCE input on misses.
        exp_mask: reference mask, same shape as ``throughput``; BCE target.
        eps: unused; kept for interface compatibility.
        trim: unused; kept for interface compatibility.
        mask_weight: weight of the mask (BCE) term.
        with_logits: if True, ``throughput`` is treated as logits in the BCE.
        tone_mapping: if True, apply ``x / (1 + x)`` to both (masked) images
            before the color loss.

    Returns:
        ``mask_weight * mask_loss + 10 * color_loss``; a term is 0 when its
        pixel set is empty.
    """
    active = ((throughput > 0) & (exp_mask == 1)).squeeze(-1)
    misses = ~active

    color_loss = 0
    if active.any():
        # Zero out inactive pixels instead of gathering them, so SSIM still
        # sees a full image grid.
        got_active = got * active[..., None]
        exp_active = exp * active[..., None]
        if tone_mapping:
            # Reinhard-style x / (1 + x): maps non-negative values into [0, 1).
            got_active = got_active / (1 + got_active)
            exp_active = exp_active / (1 + exp_active)

        l1_loss = F.l1_loss(got_active, exp_active)
        l2_loss = F.mse_loss(got_active, exp_active)
        # Clamp before sqrt so the gradient stays finite when l2_loss == 0.
        rmse_loss = l2_loss.clamp(min=1e-10).sqrt()
        ssim_loss = -ssim(got_active.permute(0, 3, 1, 2),
                          exp_active.permute(0, 3, 1, 2),
                          data_range=1,
                          size_average=True).log()
        color_loss = l2_loss + rmse_loss + l1_loss + ssim_loss

    # This case is hit if the mask intersects nothing
    mask_loss = 0
    if misses.any():
        loss_fn = (F.binary_cross_entropy_with_logits if with_logits
                   else F.binary_cross_entropy)
        mask_loss = loss_fn(
            throughput[misses].reshape(-1, 1),
            exp_mask[misses].reshape(-1, 1),
        )
    return mask_weight * mask_loss + 10 * color_loss
Exemplo n.º 2
0
    def predict(self, lr_img_name: str, img_name: str, save_img=True):
        """
        Super-resolves a single input image with both models and prints each
        model's PSNR and SSIM against the high resolution label, plus the
        mean of each model's recorded training loss. Optionally saves both
        super-resolved images.

        Parameters
        ----------
        lr_img_name : str
            The absolute path to the low resolution image. It is also the key
            into ``self.val`` that yields the matching high resolution image.
        img_name : str
            Path prefix of the saved images; the model's string form and
            ``"_1.jpg"`` / ``"_2.jpg"`` are appended to it.
        save_img : bool, optional
            Whether the images should be saved or not, by default True
        """
        with torch.no_grad():
            img_lr = util.uint2tensor4(util.imread_uint(lr_img_name))
            img_lr = img_lr.to(self.device)

            # Open HR image label
            img_hr = util.uint2tensor4(util.imread_uint(self.val[lr_img_name]))
            img_hr = img_hr.cpu()

            # Predict
            prediction1 = self.model1(img_lr).cpu()
            prediction2 = self.model2(img_lr).cpu()

            # Quality of each prediction against the HR label.
            # NOTE(review): ssim() is called without data_range, so
            # pytorch_msssim's default (255) applies — confirm the tensor scale.
            psnr1 = util.calculate_psnr(prediction1.numpy(), img_hr.numpy())
            ssim_1 = ssim(prediction1, img_hr)
            psnr2 = util.calculate_psnr(prediction2.numpy(), img_hr.numpy())
            ssim_2 = ssim(prediction2, img_hr)

            print(f"PSNR for model 1 ({str(self.model1)}): {psnr1}")
            print(f"SSIM for model 1 ({str(self.model1)}): {ssim_1}")
            print(f"PSNR for model 2 ({str(self.model2)}): {psnr2}")
            print(f"SSIM for model 2 ({str(self.model2)}): {ssim_2}")

            avg_loss1 = np.mean(self.lc1["loss"])
            avg_loss2 = np.mean(self.lc2["loss"])
            # Both lines use ": " before the value (model 1 used to lack the space).
            print(f"Average Loss for model 1 ({str(self.model1)}): {avg_loss1}")
            print(f"Average Loss for model 2 ({str(self.model2)}): {avg_loss2}")

            # Save image if needed
            if save_img:
                print("Saving images")
                img1 = util.tensor2uint(prediction1)
                img2 = util.tensor2uint(prediction2)

                util.imsave(img1, img_name + str(self.model1) + "_1.jpg")
                util.imsave(img2, img_name + str(self.model2) + "_2.jpg")
def train_for_model(dtn, train_mri_tensor, train_pet_tensor):
    """Train the fusion network ``dtn`` on paired MRI / PET batches.

    Relies on module-level ``learning_rate``, ``EPOCH``, ``batch_size`` and
    ``device``. Each step minimises
    ``(1 - lamda) * (SSIM loss vs MRI + SSIM loss vs PET)
    + lamda * (MSE vs MRI + MSE vs PET)`` with Adam. After the final epoch
    the weights are saved to ``./fusionDFP.pth``.

    Returns:
        list of float: the total loss of every optimisation step.
    """
    # define the optimizers and loss functions
    optimizer = torch.optim.Adam(
        dtn.parameters(), lr=learning_rate)  # optimize all dtn parameters
    l2_loss = nn.MSELoss()  # MSEloss

    # perform the training
    counter = 0
    lamda = 0.7  # weight of the L2 term; the SSIM term gets 1 - lamda
    loss_history = []
    for epoch in range(EPOCH):
        # NOTE(review): 555 looks like a hard-coded dataset size — confirm.
        batch_idxs = 555 // batch_size
        for idx in range(batch_idxs):
            start, stop = idx * batch_size, (idx + 1) * batch_size
            b_x = train_mri_tensor[start:stop, :, :, :].to(device)
            b_y = train_pet_tensor[start:stop, :, :, :].to(device)
            counter += 1
            output = dtn(b_x, b_y)  # dtn output

            ssim_loss_mri = 1 - ssim(output, b_x, data_range=1)
            ssim_loss_pet = 1 - ssim(output, b_y, data_range=1)
            ssim_total = ssim_loss_mri + ssim_loss_pet
            l2_total = l2_loss(output, b_x) + l2_loss(output, b_y)
            loss_total = (1 - lamda) * ssim_total + lamda * l2_total

            optimizer.zero_grad()
            loss_total.backward()
            optimizer.step()
            loss_history.append(loss_total.item())

            if counter % 25 == 0:
                print(
                    "Epoch: [%2d],step: [%2d], mri_ssim: [%.8f], pet_ssim: [%.8f],  total_ssim: [%.8f], total_l2: [%.8f], total_loss: [%.8f]"
                    % (epoch, counter, ssim_loss_mri, ssim_loss_pet,
                       ssim_total, l2_total, loss_total))

        # Checkpoint once the last epoch has processed *all* its batches
        # (this used to sit inside the batch loop and return after one batch).
        if epoch == EPOCH - 1:
            torch.save(dtn.state_dict(),
                       './fusionDFP.pth',
                       _use_new_zipfile_serialization=False)

    return loss_history
Exemplo n.º 4
0
def calculate_mssim(minibatch, reconstr_image, size_average=True):
    """ compute 1 - (MS-)SSIM between a minibatch and its reconstruction

    Inputs whose smaller spatial side is too small for pytorch_msssim's
    MS-SSIM fall back to single-scale SSIM (with ``nonnegative_ssim=True``).
    MS-SSIM requires that side to be strictly greater than 160 with its
    default 11x11 window.

    :param minibatch: the input minibatch (4-D, or 5-D with the temporal
        dimension at index 1)
    :param reconstr_image: the reconstructed image, same layout as minibatch
    :param size_average: average the score over the batch if True, otherwise
        return one score per sample
    :returns: the dissimilarity ``1 - score`` (for 5-D inputs, the mean of
        the per-frame values over the temporal dimension)
    :rtype: torch.Tensor

    """
    if minibatch.dim() == 5 and reconstr_image.dim() == 5:
        # special case where we have temporal dim: score each frame, then average
        msssim = torch.cat([
            calculate_mssim(minibatch[:, i, ::],
                            reconstr_image[:, i, ::],
                            size_average=size_average).unsqueeze(-1)
            for i in range(minibatch.shape[1])
        ], -1)
        return torch.mean(msssim, -1)

    smallest_dim = min(minibatch.shape[-1], minibatch.shape[-2])
    if minibatch.dtype != reconstr_image.dtype:
        minibatch = minibatch.type(reconstr_image.dtype)

    # ms_ssim asserts min(H, W) > (win_size - 1) * 2**4 == 160 (4 downsamplings),
    # so a side of exactly 160 must also take the single-scale path.
    if smallest_dim <= 160:
        return 1 - ssim(X=minibatch,
                        Y=reconstr_image,
                        data_range=1,
                        size_average=size_average,
                        nonnegative_ssim=True)

    return 1 - ms_ssim(
        X=minibatch, Y=reconstr_image, data_range=1, size_average=size_average)
Exemplo n.º 5
0
 def forward(self, predictions, labels):
     """Return ``1 - SSIM(predictions, labels)``, averaged over the batch.

     SSIM uses ``data_range=1.0`` and ``nonnegative_ssim=True`` (pytorch_msssim
     then clamps negative per-channel SSIM to 0).
     """
     similarity = ssim(
         predictions,
         labels,
         data_range=1.0,
         size_average=True,
         nonnegative_ssim=True,
     )
     return 1 - similarity
Exemplo n.º 6
0
def validation(epoch, val_DLoader):  # val_DDataset,
    """Evaluate the module-level ``model`` with an SSIM loss and log the result.

    Relies on module-level ``model``, ``cuda``, ``save_dir`` and
    ``LEARNING_RATE``. For each batch, ``batch_yx[1]`` is fed to ``model``
    (which returns a pair whose first element is the output). The batch loss
    is ``1 - mean SSIM(batch_yx[0], output)`` with ``data_range=1``.

    The reported loss is the mean of the per-batch losses (NaN for an empty
    loader). It is printed and appended to ``<save_dir>/validation_result.txt``.
    """
    val_loss = 0.0
    num_batches = 0
    start_time = time.time()

    # No gradients are needed for validation; this also avoids building graphs.
    with torch.no_grad():
        for batch_yx in val_DLoader:
            batch_x, batch_y = batch_yx[0], batch_yx[1]
            if cuda:
                batch_x, batch_y = batch_x.cuda(), batch_y.cuda()

            output, _illu_estim = model(batch_y)
            ssim_value = torch.mean(ssim(batch_x, output, data_range=1, size_average=False))
            loss = 1 - ssim_value

            val_loss += loss.item()
            num_batches += 1

    elapsed_time = time.time() - start_time
    # Each batch loss is already a per-sample mean, so average over batches.
    # (Previously divided by the 0-based last index times batch_size.)
    avg_loss = val_loss / num_batches if num_batches else float('nan')

    print('epoch val = %4d , loss = %4.10f , time = %4.2f s' % (
        epoch + 1, avg_loss, elapsed_time))

    with open(save_dir + '/validation_result.txt', "a+") as f:
        f.write('epoch = %4d , loss = %4.10f , lr = %2.4f, time = %4.2fs \n' % (
            epoch + 1, avg_loss, LEARNING_RATE, elapsed_time))
Exemplo n.º 7
0
def compund_mssim_l1_loss(pred: torch.Tensor, gt: torch.Tensor,
                          mean: torch.Tensor, std: torch.Tensor):
    """Blend of L1 and SSIM losses: ``0.16 * L1 + 0.84 * (1 - SSIM)``.

    ``mean`` and ``std`` are accepted but unused.

    NOTE(review): despite the name this uses single-scale SSIM, and no
    ``data_range`` is passed, so pytorch_msssim's default of 255 applies —
    confirm that matches the scale of ``pred``/``gt``.
    """
    alpha = 0.84  # weight of the SSIM term
    l1_term = F.l1_loss(pred, gt)
    ssim_term = 1 - ssim(gt, pred, size_average=True, nonnegative_ssim=True)
    return (1 - alpha) * l1_term + alpha * ssim_term
Exemplo n.º 8
0
    def optimize_parameters(self, step):
        """Run one generator optimisation step and record its loss terms.

        The total loss is:
          * the weighted pixel loss;
          * plus, if ``self.cri_CX`` is set, a weighted contextual loss on
            ``self.netF`` features;
          * plus, if ``self.cri_ssim`` is set, ``self.l_ssim_w * (1 - SSIM)``,
            where SSIM is single-scale (``'ssim'``) or MS-SSIM (``'ms-ssim'``).

        Args:
            step: unused; kept for interface compatibility.

        Raises:
            ValueError: if ``self.cri_ssim`` is set to an unsupported value
                (previously this surfaced as a NameError after the forward pass).
        """
        if self.cri_ssim and self.cri_ssim not in ('ssim', 'ms-ssim'):
            raise ValueError(
                "cri_ssim must be 'ssim' or 'ms-ssim', got %r" % (self.cri_ssim,))

        self.optimizer_G.zero_grad()
        self.fake_H = self.netG(self.var_L)

        l_g_total = 0
        l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H)
        l_g_total += l_pix
        if self.cri_CX:
            real_fea = self.netF(self.real_H)
            fake_fea = self.netF(self.fake_H)
            l_CX = self.l_CX_w * self.cri_CX(real_fea, fake_fea)
            l_g_total += l_CX
        if self.cri_ssim:
            if self.cri_ssim == 'ssim':
                ssim_val = ssim(self.fake_H, self.real_H, win_size=self.ssim_window,
                                data_range=1.0, size_average=True)
            else:  # 'ms-ssim'
                # pytorch_msssim uses one scale per weight; these are the first
                # four of its five default weights, so MS-SSIM runs at 4 scales.
                weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363]).to(
                    self.fake_H.device, dtype=self.fake_H.dtype)
                ssim_val = ms_ssim(self.fake_H, self.real_H, win_size=self.ssim_window,
                                   data_range=1.0, size_average=True, weights=weights)
            l_ssim = self.l_ssim_w * (1 - ssim_val)
            l_g_total += l_ssim

        l_g_total.backward()
        self.optimizer_G.step()

        # set log
        self.log_dict['l_pix'] = l_pix.item()
        if self.cri_CX:
            self.log_dict['l_CX'] = l_CX.item()
        if self.cri_ssim:
            self.log_dict['l_ssim'] = l_ssim.item()
Exemplo n.º 9
0
def loss(x, y, loss):
    """
    Only used for testing / validation

    x := predicted next frame
    y := true next frame
    loss := name of the loss <mse|mae|bce|bcel|ssim>

    For 'ssim', x and y are sequences of 4-D batches (e.g. T x B x C x H x W).
    The result is ``1 - SSIM`` averaged over the sequence, where each step's
    ``data_range`` is the maximum of that step's target.

    Raises IOError for an unknown loss name.
    """
    if loss not in ('mse', 'mae', 'bce', 'bcel', 'ssim'):
        raise IOError(
            '[ERROR] Use a valid loss function <mse|mae|bce|bcel|ssim>')

    if loss != 'ssim':
        # Element-wise losses map directly onto torch.nn.functional.
        fn_name = {
            'mse': 'mse_loss',
            'mae': 'l1_loss',
            'bce': 'binary_cross_entropy',
            'bcel': 'binary_cross_entropy_with_logits',
        }[loss]
        return getattr(f, fn_name)(x, y)

    # pytorch_msssim only takes 4-D tensors: score each time step separately.
    total = 0
    for t, pred_t in enumerate(x):
        true_t = y[t]
        # ssim is a similarity, so 1 - ssim turns it into a loss
        total += 1 - ssim(pred_t, true_t, data_range=torch.max(true_t))
    return total / len(x)
    def inference(self, label, geo, image=None):
        """Generate an image from ``label`` / ``geo`` without tracking gradients.

        Args:
            label: label input, passed to ``self.encode_input``.
            geo: optional geometry input, passed as ``geometry``.
            image: optional real image; when given, SSIM and MS-SSIM between
                the generated and real image are computed.

        Returns:
            tuple: ``(fake_image, metrics)``. ``metrics`` is empty when
            ``image`` is None. Otherwise it holds per-sample ``'ssim'`` and
            ``'ms_ssim'`` tensors (``size_average=False``). Both images are
            mapped with ``(x + 1) / 2`` before scoring with ``data_range=1``;
            this assumes they lie in [-1, 1] — TODO confirm.
        """
        # Encode Inputs
        image = Variable(image) if image is not None else None
        geometry = Variable(geo) if geo is not None else None
        concat_input, real_image = self.encode_input(label,
                                                     image,
                                                     geometry=geometry,
                                                     infer=True)

        # torch.no_grad() exists since 0.4. The old code applied it only on
        # 0.4.x, so every newer version built an autograd graph here. The
        # debug print of the undefined name `input_label` (NameError) is gone.
        with torch.no_grad():
            # Fake Generation
            fake_image = self.netG.forward(concat_input)

            metrics = {}
            if image is not None:  # metrics
                GT_image = Variable((real_image + 1.) / 2.)
                fake_normalized = (fake_image + 1.) / 2.
                metrics['ssim'] = ssim(fake_normalized,
                                       GT_image,
                                       data_range=1,
                                       size_average=False,
                                       nonnegative_ssim=True)
                metrics['ms_ssim'] = ms_ssim(fake_normalized,
                                             GT_image,
                                             data_range=1,
                                             size_average=False)

        return fake_image, metrics
Exemplo n.º 11
0
    def get_values(self):
        """
        Scores both models on every validation pair.

        For each low resolution path in ``self.val``, both models upscale the
        image. The wall-clock time of each forward pass is recorded, along
        with the PSNR and SSIM of the result against the high resolution label.

        Returns
        -------
        dict of dict of list
            ``{"psnr": ..., "ssim": ..., "times": ...}``; each maps
            ``"model1"`` / ``"model2"`` to one value per validation instance.
        """
        models = (("model1", self.model1), ("model2", self.model2))
        psnr = {key: [] for key, _ in models}
        ssim_ = {key: [] for key, _ in models}
        times = {key: [] for key, _ in models}

        with torch.no_grad():
            for lr_img_name in tqdm(list(self.val.keys())):
                img_lr = util.uint2tensor4(util.imread_uint(lr_img_name)).to(self.device)
                img_hr = util.uint2tensor4(util.imread_uint(self.val[lr_img_name])).cpu()
                hr_array = img_hr.numpy()

                for key, model in models:
                    tic = time()
                    prediction = model(img_lr).cpu()
                    times[key].append(time() - tic)

                    psnr[key].append(util.calculate_psnr(prediction.numpy(), hr_array))
                    ssim_[key].append(ssim(prediction, img_hr))

        return {"psnr": psnr, "ssim": ssim_, "times": times}
Exemplo n.º 12
0
    def step(self, data, level, train=False, augmentation=True):
        """Run one training or evaluation step and record losses, metrics and images.

        Args:
            data: mapping with ``'raw'`` (model input) and ``'rgb'`` (target)
                tensors; both go through ``expand`` after moving to
                ``self.device``.
            level: forwarded to ``self.model`` and ``self.set_criterion``.
                LPIPS is recorded only when ``level == 0``.
            train: if True, the model is put in train mode and ``total_loss``
                is back-propagated. The optimizer (and scheduler, if any) is
                then stepped. Otherwise the model is put in eval mode and the
                forward runs under ``torch.no_grad()``.
            augmentation: in training, one random transform index
                ``k in range(8)`` is applied to both input and target. In
                evaluation, the output is averaged over all 8 transforms, each
                mapped back with ``augment(..., inverse=True)``.

        Side effects:
            Appends to ``self.loss`` (``'mse'``/``'vgg'``/``'msssim'`` per
            criterion, plus ``'total'``) and ``self.metrics`` (``'psnr'``,
            ``'ssim'``, optionally ``'lpips'``). Overwrites ``self.images``
            (``'raw'``, ``'enhanced'``, ``'rgb'``). Calls
            ``torch.cuda.empty_cache()``.
        """
        if train:
            assert self.optimizer is not None, 'optimizer should be set before training'
            self.model.train()
        else:
            self.model.eval()

        # Criteria are built lazily on first use.
        if self.criterions is None:
            self.set_criterion(level)

        # NOTE(review): pass_context() presumably is a no-op context manager so
        # training keeps autograd enabled — confirm.
        with pass_context() if train else torch.no_grad():
            input = expand(data['raw'].to(self.device))
            target = expand(data['rgb'].to(self.device))

            if augmentation and train:
                # Same random transform for input and target keeps them aligned.
                k = random.randrange(8)
                input = augment(input, k)
                target = augment(target, k)
                enhanced = self.model(input, level)
            elif augmentation and not train:
                # Evaluation: average the model over all 8 transforms, undoing each.
                enhanced = torch.mean(torch.stack([augment(self.model(augment(input, k), level), k, inverse=True) for k in range(8)], dim=0), dim=0)
            else:
                enhanced = self.model(input, level)

            # Criteria are matched by position: index 0 is logged as 'mse' (on
            # the expanded tensors), 1 as 'vgg' and 2 as 'msssim' (both on the
            # shrunk tensors). Any criteria past index 2 are ignored.
            # NOTE(review): with no criteria, total_loss stays a float and the
            # .cpu() call below would fail.
            total_loss = 0.0
            for i, (coefficient, criterion) in enumerate(zip(self.coefficients, self.criterions)):
                if i==0:
                    mse_loss = coefficient * criterion(enhanced, target)
                    total_loss += mse_loss
                    self.loss['mse'].append(mse_loss.cpu().detach().numpy())
                elif i==1:
                    vgg_loss = coefficient * criterion(shrink(enhanced), shrink(target))
                    total_loss += vgg_loss
                    self.loss['vgg'].append(vgg_loss.cpu().detach().numpy())
                elif i==2:
                    msssim_loss = coefficient * criterion(shrink(enhanced), shrink(target))
                    total_loss += msssim_loss
                    self.loss['msssim'].append(msssim_loss.cpu().detach().numpy())

            if train:
                self.optimizer.zero_grad()
                total_loss.backward()
                self.optimizer.step()
                # The scheduler is stepped once per batch, not per epoch.
                if self.scheduler is not None:
                    self.scheduler.step()

        # Metrics and stored images use the shrunk tensors.
        input = shrink(input)
        enhanced = shrink(enhanced)
        target = shrink(target)

        self.loss['total'].append(total_loss.cpu().detach().numpy())
        self.metrics['psnr'].append(psnr(enhanced, target).detach().cpu().numpy())
        self.metrics['ssim'].append(ssim(enhanced, target, data_range=1.0, size_average=True).detach().cpu().numpy())
        if level==0: self.metrics['lpips'].append(self.lpips(enhanced, target).detach().cpu().numpy())
        # Only the first 3 channels of the input are kept for display.
        self.images['raw'] = input.detach().cpu()[:,:3,:,:]
        self.images['enhanced'] = enhanced.detach().cpu()
        self.images['rgb'] = target.detach().cpu()

        torch.cuda.empty_cache()
Exemplo n.º 13
0
 def backward_G(self, total_it, writer):
     """Compute the generator loss and back-propagate it.

     ``loss_G = -mean(D(x_gen)) + lmbda1 * L1(x_gen, real)
     + lmbda2 * (1 - SSIM(real, x_gen)) + 10000 * perceptual(bottleneck, real)``.

     ``total_it`` and ``writer`` are unused. SSIM uses pytorch_msssim's
     default ``data_range`` (255) since none is passed.
     """
     critic_scores = self.netD(self.x_gen)
     self.loss_G_GAN = -critic_scores.mean()
     self.loss_G_L1 = self.criterionL1(self.x_gen, self.real)
     self.ssim_loss = 1 - ssim(self.real, self.x_gen)
     self.perc_loss = self.criterionPerc(self.bottleneck, self.real)

     weighted_l1 = self.lmbda1 * self.loss_G_L1
     weighted_ssim = self.lmbda2 * self.ssim_loss
     weighted_perc = 10000 * self.perc_loss
     self.loss_G = self.loss_G_GAN + weighted_l1 + weighted_ssim + weighted_perc
     self.loss_G.backward()
Exemplo n.º 14
0
def norm(img1, img2):
    """Print the MSE, ``1 - SSIM`` and ``1 - MS-SSIM`` between two batches.

    SSIM and MS-SSIM use pytorch_msssim's defaults (``data_range=255``).
    Nothing is returned.
    """
    mse_value = nn.MSELoss()(img1, img2)
    print(f"loss_r:{mse_value.item()}")

    ssim_distance = 1 - pytorch_msssim.ssim(img1, img2)
    print(f"SSIM_h:{ssim_distance.item()}")

    ms_ssim_distance = 1 - pytorch_msssim.ms_ssim(img1, img2)
    print(f"MS_SSIM_h:{ms_ssim_distance.item()}")
Exemplo n.º 15
0
def train_epoch(model, train_dataloader, optimizer, global_step):
    """Train one epoch with gradient accumulation.

    Gradients of ``args.update_iters`` consecutive batches are accumulated
    (each summed-MSE loss is divided by ``args.update_iters``) before one
    optimizer step. After every step:
      * encoder parameters are clamped to be non-negative;
      * the learning rate is divided by 10 every ``args.decay_steps`` global
        steps;
      * a summary is written via ``add_summary``.

    Relies on module-level ``args``, ``crop_tensor`` and ``add_summary``.

    Returns:
        int: the updated global step.
    """
    model.train()
    start_time = time.time()
    iter_loss = 0
    optimizer.zero_grad()
    criterion = nn.MSELoss(reduction='sum')

    for idx, (imgs, lbls) in enumerate(train_dataloader):
        imgs, lbls = imgs.float().cuda(), lbls.float().cuda()
        out = model(imgs)
        out = crop_tensor(out, lbls.shape[2], lbls.shape[3])
        loss = criterion(out, lbls) / args.update_iters
        loss.backward()
        # Detach: accumulating the raw loss would keep every batch's autograd
        # graph alive until the next update.
        iter_loss += loss.detach()

        # Update after every `update_iters` batches. The old condition
        # (idx > 0 and idx % n == 0) accumulated n + 1 batches in the first window.
        if (idx + 1) % args.update_iters == 0:
            p_epoch_idx = (
                global_step * args.update_iters) // len(train_dataloader)
            l2_norm = torch.sqrt(torch.mean(torch.square(imgs - lbls))) * 255.
            # Logging metric only: computed once, without autograd.
            # NOTE(review): this compares the input with the output, not the
            # labels — confirm that is intended.
            with torch.no_grad():
                batch_ssim = ssim(imgs, out, data_range=1.)
            print("Epoch(%s) - Iter (%s) - Loss: %.4f - SSIM: %.4f" % (
                p_epoch_idx,
                (idx + 1) // args.update_iters,
                iter_loss.item(),
                batch_ssim))
            global_step = global_step + 1
            # nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            optimizer.step()
            for param in model.encoder.parameters():
                param.data.clamp_(0.)
            if global_step > 0 and global_step % args.decay_steps == 0:
                for param_group in optimizer.param_groups:
                    param_group['lr'] /= 10
            curr_lr = optimizer.param_groups[0]['lr']
            optimizer.zero_grad()
            add_summary(global_step,
                        iter_loss, l2_norm, curr_lr,
                        batch_ssim,
                        imgs, lbls, out)
            iter_loss = 0
    print("Finished training epoch in %s" % int(time.time() - start_time))
    return global_step
Exemplo n.º 16
0
def ssim_similarity(inpts1, inpts2):
    """Batch-averaged SSIM (``data_range=1``) between two image batches.

    Each input is tiled 3 times along dim 1 before scoring; this assumes
    single-channel NCHW inputs — TODO confirm.
    """
    tiled_a, tiled_b = (batch.repeat(1, 3, 1, 1) for batch in (inpts1, inpts2))
    return ssim(tiled_a, tiled_b, data_range=1, size_average=True)
Exemplo n.º 17
0
 def calculate_score(self, originals, reconstruction, device):
     """Batch-averaged SSIM between originals and their reconstructions.

     Both inputs are reshaped to ``(N, 3, 32, 32)``, with N taken from each
     tensor's own dim 0. They are moved to ``device`` and scored with
     ``data_range=1``.
     """
     image_shape = (3, 32, 32)
     originals = originals.view(originals.size(0), *image_shape).to(device)
     reconstruction = reconstruction.view(reconstruction.size(0), *image_shape).to(device)
     return ssim(originals, reconstruction, data_range=1, size_average=True)
Exemplo n.º 18
0
 def backward_G(self, total_it, writer):
     """Compute the generator loss, log the critic's mean score and back-propagate.

     ``loss_G = 10 * criterionG(D(x_gen), True) + lmbda1 * L1(x_gen, real)
     + lmbda2 * (1 - SSIM(real, x_gen))``. The mean of ``D(x_gen)`` is printed
     and written to ``writer`` under ``'G_pred_real'`` at step ``total_it``.
     """
     pred_fake = self.netD(self.x_gen)
     mean_pred = torch.mean(pred_fake).item()
     print(f"G_pred_real: {mean_pred:.6f}")
     writer.add_scalar('G_pred_real', mean_pred, total_it)

     self.loss_G_GAN = self.criterionG(pred_fake, True)
     self.loss_G_L1 = self.criterionL1(self.x_gen, self.real)
     # No data_range is passed, so pytorch_msssim's default (255) applies.
     self.ssim_loss = 1 - ssim(self.real, self.x_gen)
     self.loss_G = (10 * self.loss_G_GAN
                    + self.lmbda1 * self.loss_G_L1
                    + self.lmbda2 * self.ssim_loss)
     self.loss_G.backward()
Exemplo n.º 19
0
    def test(self):
        """Evaluate ``self.model`` on ``self.loader`` and print average metrics.

        Each batch compares ``self.model(data['left_image'])`` with
        ``data['bg_image']`` using:
          * L1 loss;
          * SSIM (pytorch_msssim defaults, i.e. ``data_range=255``);
          * PSNR with a peak value of 255.

        The loader is traversed ``self.args.epochs`` times; averages are
        taken over every batch seen (0.0 if none).
        """
        num_batches = 0
        sum_loss = 0.0
        sum_ssim = 0.0
        sum_psnr = 0.0

        self.model.eval()
        with torch.no_grad():
            for _epoch in range(self.args.epochs):
                for data in self.loader:
                    data = to_device(data, self.device)
                    left = data['left_image']
                    bg_image = data['bg_image']

                    disps = self.model(left)

                    l_loss = l1loss(disps, bg_image)
                    ssim_val = ssim(disps, bg_image)
                    # PSNR = 20 * log10(MAX / RMSE); the old code used
                    # 10 * log10(MAX / RMSE), which is half the true value.
                    psnr_val = 20 * torch.log10(
                        255 / torch.sqrt(mseloss(disps, bg_image)))

                    sum_loss += l_loss.item()
                    sum_ssim += ssim_val.item()
                    sum_psnr += psnr_val.item()
                    num_batches += 1

        average_loss = average_ssim = average_PSNR = 0.0
        if num_batches:
            average_loss = sum_loss / num_batches
            average_ssim = sum_ssim / num_batches
            average_PSNR = sum_psnr / num_batches

        print('average loss:', average_loss, '\naverage_ssim:', average_ssim,
              '\naverage_PSNR:', average_PSNR)
Exemplo n.º 20
0
def train_for_newModel(net,first_im,second_im):
    """Train the fusion network ``net`` on paired image stacks.

    Relies on module-level ``learning_rate``, ``EPOCH``, ``batch_size`` and
    ``device``. Each step minimises
    ``lamda * SSIM_loss + (1 - lamda) * L2_loss``, where each term is a
    ``gamma``-weighted mix of the loss against ``first_im`` and ``second_im``.
    After the final epoch the weights are saved to
    ``./ourDB_DFFuse_0314.pth``.

    Returns:
        list of float: the total loss of every optimisation step.
    """
    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
    mse = nn.MSELoss()
    lamda = 0.7       # share of the SSIM term (L2 gets 1 - lamda)
    gamma_ssim = 0.5  # share of the first image in the SSIM term
    gamma_l2 = 0.5    # share of the first image in the L2 term
    history = []
    step = 0

    for epoch in range(EPOCH):
        num_batches = 555 // batch_size
        for batch in range(num_batches):
            lo = batch * batch_size
            hi = lo + batch_size
            b_x = first_im[lo:hi, :, :, :].to(device)
            b_y = second_im[lo:hi, :, :, :].to(device)

            step += 1
            fused = net(b_x, b_y)

            ssim_loss_mri = 1 - ssim(fused, b_x, data_range=1)
            ssim_loss_pet = 1 - ssim(fused, b_y, data_range=1)
            l2_loss_mri = mse(fused, b_x)
            l2_loss_pet = mse(fused, b_y)
            ssim_total = gamma_ssim * ssim_loss_mri + (1 - gamma_ssim) * ssim_loss_pet
            l2_total = gamma_l2 * l2_loss_mri + (1 - gamma_l2) * l2_loss_pet
            loss_total = lamda * ssim_total + (1 - lamda) * l2_total

            optimizer.zero_grad()
            loss_total.backward()
            optimizer.step()
            history.append(loss_total.item())

            if step % 20 == 0:
                print(
                    "Epoch: [%2d],step: [%2d], mri_ssim_loss: [%.8f], pet_ssim_loss: [%.8f],  total_ssim_loss: [%.8f], total_l2_loss: [%.8f], total_loss: [%.8f]"
                    % (epoch, step, ssim_loss_mri, ssim_loss_pet, ssim_total, l2_total, loss_total))

        if epoch == EPOCH - 1:
            torch.save(net.state_dict(), './ourDB_DFFuse_0314.pth')

    return history
Exemplo n.º 21
0
def run_tests(
  num_samples=32,
):
  """Render test views, print image-quality metrics and save comparison plots.

  Up to 100 views are processed, each pairing a camera in ``cam_to_worlds``
  with a light in ``light_locs``. For each view:
    * ``num_samples`` path-traced renders are averaged;
    * the result and the reference (both clamped to [0, 1]) are saved side by
      side after a 1/2.2 gamma;
    * L1, MSE and PSNR are recorded.

  MS-SSIM and SSIM are then computed over all rendered views at once, against
  the same clamped references. Relies on module-level scene state. Returns None.
  """
  l1_losses = []
  l2_losses = []
  psnr_losses = []
  gots = []
  exps = []
  num = 100
  with torch.no_grad():
    for i, (c2w, lp) in enumerate(zip(tqdm(cam_to_worlds[:num]), light_locs)):
      exp = exp_imgs[i].clamp(min=0, max=1)
      cameras = NeRFCamera(cam_to_world=c2w.unsqueeze(0), focal=focal, device=device)
      lights = PointLights(intensity=[1,1,1], location=lp[None,...], scale=300, device=device)
      # Average num_samples independent path-traced renders.
      got = None
      for _ in range(num_samples):
        sample = pt.pathtrace(
          density_field,
          size=SIZE, chunk_size=min(SIZE, 100), bundle_size=1, bsdf=learned_bsdf,
          integrator=integrator,
          # 0 is for comparison, 1 is for display
          background=0,
          cameras=cameras, lights=lights, device=device, silent=True,
          w_isect=True,
        )[0]
        if got is None: got = sample
        else: got += sample
      got /= num_samples
      got = got.clamp(min=0, max=1)
      save_plot(
        exp ** (1/2.2), got ** (1/2.2),
        f"outputs/path_nerv_armadillo_{i:03}.png",
      )
      l1_losses.append(F.l1_loss(exp,got).item())
      mse = F.mse_loss(exp,got)
      l2_losses.append(mse.item())
      psnr = mse2psnr(mse).item()
      psnr_losses.append(psnr)
      gots.append(got)
      # Keep the clamped reference so SSIM sees the same data as L1/L2/PSNR,
      # and exactly one reference per rendered view.
      exps.append(exp)
  print("Avg l1 loss", np.mean(l1_losses))
  print("Avg l2 loss", np.mean(l2_losses))
  print("Avg PSNR loss", np.mean(psnr_losses))
  with torch.no_grad():
    # Images are stored channels-last; pytorch_msssim expects NCHW.
    gots = torch.stack(gots, dim=0).permute(0, 3, 1, 2)
    exps = torch.stack(exps, dim=0).permute(0, 3, 1, 2)
    # takes a lot of memory
    torch.cuda.empty_cache()
    # NOTE: these are similarities (higher is better) despite the "loss" label.
    ssim_loss = ms_ssim(gots, exps, data_range=1, size_average=True).item()
    print("MS-SSIM loss", ssim_loss)

    ssim_loss = ssim(gots, exps, data_range=1, size_average=True).item()
    print("SSIM loss", ssim_loss)
  return
Exemplo n.º 22
0
    def __getmetrics__(self, sr, hr):
        """Return ``(ssim, psnr)`` of a super-resolved batch vs. its ground truth.

        ``ssim`` is the batch mean of the per-image SSIM (``data_range=1``,
        ``nonnegative_ssim=True``) as a Python float. ``psnr`` is
        ``10 * log10(max(hr)**2 / self.content_loss(sr, hr))``; this presumably
        assumes ``content_loss`` is an MSE — confirm.
        """
        with torch.no_grad():
            per_image = pytorch_msssim.ssim(sr,
                                            hr,
                                            data_range=1,
                                            nonnegative_ssim=True,
                                            size_average=False)  # (N,)
            mean_ssim = torch.mean(per_image).item()
            peak_squared = hr.max() ** 2
            psnr = 10 * log10(peak_squared / self.content_loss(sr, hr))

        return mean_ssim, psnr
Exemplo n.º 23
0
def val(args, epoch, G, criterion_l1, dataloader, device, writer):
    """Validate the video generator ``G`` and log L1, SSIM and PSNR.

    Each video is reconstructed from its first frame, a random latent of size
    ``args.dim_z`` and its class. L1 is computed on the raw tensors. For SSIM
    and PSNR, every frame is scored as an individual image after mapping
    [-1, 1] to [0, 1]. The per-batch averages are written to ``writer`` under
    ``val/l1_recon``, ``val/ssim`` and ``val/psnr``, and printed.

    ``G`` is left in eval mode.
    """
    with torch.no_grad():

        G.eval()

        l1_loss = []
        ssim_loss = []
        psnr_loss = []

        for i, (vid, cls) in enumerate(dataloader):

            vid = vid.to(device)
            img = vid[:, :, 0, :, :]  # first frame; vid is (B, C, T, H, W)
            cls = cls.to(device)

            bs = vid.size(0)
            z = torch.randn(bs, args.dim_z).to(device)

            vid_recon = G(img, z, cls)

            # l1 loss
            err_l1 = criterion_l1(vid_recon, vid)
            l1_loss.append(err_l1.item())

            # (B, C, T, H, W) -> (B*T, C, H, W), using the tensors' own sizes
            # instead of the previously hard-coded 3 x 64 x 64.
            vid = vid.transpose(2, 1).contiguous()
            vid = vid.view(-1, vid.size(2), vid.size(3), vid.size(4))
            vid_recon = vid_recon.transpose(2, 1).contiguous()
            vid_recon = vid_recon.view(
                -1, vid_recon.size(2), vid_recon.size(3), vid_recon.size(4))

            # ssim
            vid = (vid + 1) / 2  # [0, 1]
            vid_recon = (vid_recon + 1) / 2  # [0, 1]
            err_ssim = ssim(vid, vid_recon, data_range=1, size_average=False)
            ssim_loss.append(err_ssim.mean().item())

            # psnr
            err_psnr = psnr(vid, vid_recon)
            psnr_loss.append(err_psnr.mean().item())

        l1_avg = sum(l1_loss) / len(l1_loss)
        ssim_avg = sum(ssim_loss) / len(ssim_loss)
        psnr_avg = sum(psnr_loss) / len(psnr_loss)

        writer.add_scalar('val/l1_recon', l1_avg, epoch)
        writer.add_scalar('val/ssim', ssim_avg, epoch)
        writer.add_scalar('val/psnr', psnr_avg, epoch)
        writer.flush()

        print("[Epoch %d/%d] [l1: %f] [ssim: %f] [psnr: %f]" %
              (epoch, args.max_epoch, l1_avg, ssim_avg, psnr_avg))
Exemplo n.º 24
0
    def test(self):
        """Visually inspect ``self.model`` on ``self.loader``.

        For every batch, ``self.model(data['left_image'])`` is compared with
        ``data['bg_image']`` (L1 loss, SSIM and ``psnr1``). The shapes are
        printed, and prediction and target are shown side by side with
        matplotlib. The loader is traversed ``self.args.epochs`` times.
        """
        i = 0
        sum_loss = 0.0
        sum_ssim = 0.0
        average_ssim = 0.0
        average_loss = 0.0
        for epoch in range(self.args.epochs):

            self.model.eval()

            for data in self.loader:
                i = i + 1

                data = to_device(data, self.device)

                left = data['left_image']
                bg_image = data['bg_image']

                disps = self.model(left)

                print(disps.shape)

                l_loss = l1loss(disps,bg_image)
                ssim_loss = ssim(disps,bg_image)
                # This line used tab indentation among space-indented lines,
                # which Python 3 rejects with a TabError.
                # NOTE(review): the value is never used.
                psnr222 = psnr1(disps,bg_image)

                sum_loss = sum_loss + l_loss.item()
                sum_ssim += ssim_loss.item()

                # NOTE(review): running averages are computed but never reported.
                average_ssim = sum_ssim / i
                average_loss = sum_loss / i

                disp_show = disps.squeeze()
                bg_show = bg_image.squeeze()
                print(bg_show.shape)
                plt.figure()
                plt.subplot(1,2,1)
                plt.imshow(disp_show.data.cpu().numpy())
                plt.subplot(1,2,2)
                plt.imshow(bg_show.data.cpu().numpy())
                plt.show()
Exemplo n.º 25
0
    def _save_lc_values(self):
        """
        Records validation PSNR and SSIM for the learning curves.

        ``self.model`` is scored on ``Trainer.ITEMS_PER_CALCULATION``
        validation pairs drawn at random (with replacement) from ``self.val``.
        The mean PSNR and mean SSIM are appended to ``self.psnr_values`` and
        ``self.ssim_values``. The model is in eval mode during the evaluation
        and is returned to train mode afterwards, even if an error occurs.
        """
        psnr = []
        ssim_ = []
        self.model.eval()
        try:
            with torch.no_grad():
                for lr_img_name in np.random.choice(list(self.val.keys()),
                                                    Trainer.ITEMS_PER_CALCULATION):
                    # Open LR image
                    img_lr = util.uint2tensor4(util.imread_uint(lr_img_name))
                    img_lr = img_lr.to(self.device)

                    # Open HR image label
                    img_hr = util.uint2tensor4(util.imread_uint
                                               (self.val[lr_img_name]))
                    img_hr = img_hr.cpu()

                    # Predict
                    prediction = self.model(img_lr).cpu()

                    # Performance measures on validation data for the
                    # learning curves; SSIM is stored as a plain float so
                    # np.mean works on a homogeneous list.
                    psnr.append(util.calculate_psnr(prediction.numpy(),
                                                    img_hr.numpy()))
                    ssim_.append(ssim(prediction, img_hr).item())
        finally:
            # Return the model to training mode
            self.model.train()

        # Save the performance evaluation measures to the Trainer
        self.psnr_values.append(np.mean(psnr))
        self.ssim_values.append(np.mean(ssim_))
Exemplo n.º 26
0
def ssim_loss_fct(input, target, data_min: float, data_max: float,
                  **kwargs) -> torch.Tensor:
    """Compute per-sample SSIM between ``input`` and ``target``, then reduce it.

    Both tensors are shifted by ``data_min`` so the data range starts at 0.
    They get a singleton dimension inserted at position 1 before being passed
    to ``ssim`` with ``data_range = data_max - data_min`` and
    ``size_average=False``.

    If either tensor contains a NaN, a warning is emitted and ``ssim`` is not
    called. Instead, a zero tensor with ``input.size(0)`` entries is used.

    The per-sample values are finally passed, together with ``**kwargs``, to
    ``reduction_fct``, and its result is returned.
    """
    value_range = data_max - data_min

    # shift both tensors so that the lower end of the data range sits at 0
    shifted_input, shifted_target = input - data_min, target - data_min

    nan_free = not (torch.isnan(input).any() or torch.isnan(target).any())
    if nan_free:
        per_sample = ssim(shifted_input.unsqueeze(1),
                          shifted_target.unsqueeze(1),
                          data_range=value_range,
                          size_average=False)
    else:
        warnings.warn(
            "The SSIM is not reliable when called upon tensors with nan values"
        )
        per_sample = input.new_zeros(size=(input.size(0), ))

    return reduction_fct(per_sample, **kwargs)
Exemplo n.º 27
0
def evaluate(model, val_dataloader, global_step):
    """Perform evaluation on the validation set.

    Each ``(imgs, lbls)`` batch of ``val_dataloader`` is moved to the GPU as
    float tensors and run through ``model``. The prediction is cropped with
    ``crop_tensor`` to the label's spatial size (dims 2 and 3). Two scores are
    then recorded:

    * the summed squared error between prediction and label
      (``nn.MSELoss(reduction='sum')``), and
    * the SSIM between prediction and label (``data_range=1.``).

    The per-batch scores are averaged and logged to ``writer`` under
    ``val/mse`` and ``val/ssim`` at ``global_step``. The model's previous
    train/eval mode is restored before returning.

    Returns
    -------
    tuple of torch.Tensor
        Mean per-batch MSE and mean per-batch SSIM.
    """
    mse = nn.MSELoss(reduction='sum')
    print("Starting evaluation..")
    l_mse, l_ssim = [], []
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for imgs, lbls in val_dataloader:
                imgs, lbls = imgs.float().cuda(), lbls.float().cuda()
                out = model(imgs)
                out = crop_tensor(out, lbls.shape[2], lbls.shape[3])
                l_mse.append(mse(out, lbls).item())
                # Score the prediction against the ground truth, the same
                # pair the MSE above uses (not the network input).
                l_ssim.append(ssim(out, lbls, data_range=1.).item())
            print("Completed evaluation..")
            d_mse = torch.Tensor(l_mse).mean()
            d_ssim = torch.Tensor(l_ssim).mean()
            writer.add_scalar("val/mse", d_mse, global_step)
            writer.add_scalar("val/ssim", d_ssim, global_step)
    finally:
        # Don't leave a model that was training stuck in eval mode
        # (dropout / batch-norm behave differently in eval).
        model.train(was_training)
    return d_mse, d_ssim
Exemplo n.º 28
0
def discard_frames(folder, video_id, min_avg_value=10, ssim_threshold=0.3, ssim_range=.5):
    '''
    Remove frame-sequence folders that look unusable.

    Every entry of ``folder`` whose name starts with ``video_id`` is treated
    as one sequence folder. Its frames are read in sorted filename order, and
    the sequence is discarded if any of these holds:

    - one of the first 5 frames is dark (mean pixel value < ``min_avg_value``);
    - the SSIM between two consecutive frames (among the first 5) is below
      ``ssim_threshold``, which suggests a frame boundary;
    - the spread between the largest and smallest consecutive-frame SSIM
      exceeds ``ssim_range``.

    Discarded folder paths are appended to ``to_remove_lmd.txt`` inside
    ``args.dataset_folder``, and the folders are deleted from disk.
    '''
    seq_folders = [os.path.join(folder, f) for f in os.listdir(folder) if f.startswith(video_id)]

    for seq_folder in tqdm(seq_folders, desc=f'Discarding frames {folder}'):
        # os.listdir order is arbitrary; sort so that "consecutive" frames are
        # consecutive in time (assumes filenames sort temporally, e.g.
        # zero-padded indices -- TODO confirm naming scheme).
        frame_names = sorted(os.listdir(seq_folder))
        imgs = [imageio.imread(os.path.join(seq_folder, f)) for f in frame_names]
        # (H, W, C) arrays -> (1, C, H, W) float tensors on the GPU;
        # permute(2, 0, 1) requires 3-D (multi-channel) images.
        imgs = [torch.tensor(img).permute(2, 0, 1).unsqueeze(dim=0).cuda().float() for img in imgs]

        removed = False
        # The checks below inspect the first 5 frames (4 consecutive pairs).
        for i in range(5):
            if imgs[i].mean().item() < min_avg_value:
                print(f'removed {seq_folder}')
                removed = True
                break

        ssims = [ssim(imgs[i], imgs[i + 1], win_size=11).item() for i in range(4)]
        if min(ssims) < ssim_threshold:
            print(seq_folder, f'min lower than {ssim_threshold}')
            removed = True

        if max(ssims) - min(ssims) > ssim_range:
            print(f'deleting {seq_folder}')
            removed = True

        if removed:
            with open(os.path.join(args.dataset_folder, 'to_remove_lmd.txt'), 'a') as f:
                f.write(seq_folder + '\n')
            # Delete only after the log entry has been written and the file closed.
            shutil.rmtree(seq_folder)
Exemplo n.º 29
0
    def process(self, inputs, outputs):
        """
        Accumulate SSIM and MS-SSIM scores for one evaluated image.

        Args:
            inputs: the inputs to a CFPN model, a list holding exactly one
                dict whose 'image' entry is the original image in (C, H, W)
                layout.

            outputs: the outputs of a CFPN model, indexable by
                ``self.eval_img``; that entry is the reconstructed image
                batch with 3 channels on dim 1.

        The original image is resized with ``self.transform`` and moved to the
        reconstruction's device. The reconstruction is cropped to the original's
        height and width. Per-image SSIM and MS-SSIM (``data_range=255``) are
        then appended to ``self.ssim_vals`` and ``self.ms_ssim_vals``.
        """
        assert (len(inputs) == 1)

        # Resize the input the same way the model preprocesses it (per the
        # original note: longest side at most 512). The transform works on an
        # HWC numpy array, so convert there and back.
        hwc_image = inputs[0]['image'].permute(1, 2, 0)
        resize = self.transform.get_transform(hwc_image)
        resized = resize.apply_image(hwc_image.numpy())
        chw_image = torch.tensor(resized).permute(2, 0, 1)
        target_device = outputs[self.eval_img].get_device()
        orig_image = chw_image.unsqueeze(0).float().to(target_device)

        # Crop away any padding the model added beyond the original size.
        height, width = orig_image.shape[2], orig_image.shape[3]
        reconstruct_image = outputs[self.eval_img].float()[:, :, :height, :width]

        assert orig_image.shape[1] == 3, "original image must have 3 channels"
        assert reconstruct_image.shape[1] == 3, "reconstructed image must have 3 channels"

        with torch.no_grad():
            per_image_ssim = ssim(reconstruct_image, orig_image,
                                  data_range=255, size_average=False)
            per_image_ms_ssim = ms_ssim(reconstruct_image, orig_image,
                                        data_range=255, size_average=False)
            self.ssim_vals.extend(per_image_ssim)
            self.ms_ssim_vals.extend(per_image_ms_ssim)
Exemplo n.º 30
0
    def cal_batch_color_loss(self, batch_fake_img, batch_real_img):
        """
        Color loss between generated and real images, based on HSV hue and saturation.

        Each batch is mapped with ``(x + 1) / 2``, which presumably takes
        [-1, 1] images to [0, 1] (TODO confirm the input range). Channels 0, 1
        and 2 are read as R, G and B and converted to hue and saturation maps
        in [0, 1]. The loss is::

            (1 - ssim(hue_real, hue_fake)) + (1 - ssim(sat_real, sat_fake))

        computed with ``pytorch_msssim.ssim`` at its default settings.

        Args:
            batch_fake_img: generated images; indexed as (N, C, H, W) with at
                least 3 channels.
            batch_real_img: reference images with the same layout.

        Returns:
            The loss tensor produced by the expression above.
        """
        def _hue_saturation(batch_img):
            """Return (hue, saturation) maps, each unsqueezed at dim 0."""
            img = (batch_img + 1) / 2
            r, g, b = img[:, 0, :, :], img[:, 1, :, :], img[:, 2, :, :]

            mx = torch.max(torch.max(r, g), b)
            mn = torch.min(torch.min(r, g), b)
            df = mx - mn
            # Pixels with df == 0 (grey) or mx == 0 (black) are overwritten with
            # 0 below. Dividing by zero there would still create inf/NaN values
            # whose gradients leak through torch.where as NaN. Using 1 as the
            # denominator for those pixels keeps the forward result identical
            # and the gradients finite.
            safe_df = torch.where(df == 0, torch.ones_like(df), df)
            safe_mx = torch.where(mx == 0, torch.ones_like(mx), mx)

            hue = mx
            hue = torch.where(mx == r, ((60 * ((g - b) / safe_df) + 360) % 360) / 360,
                              hue)
            hue = torch.where(mx == g, ((60 * ((b - r) / safe_df) + 120) % 360) / 360,
                              hue)
            hue = torch.where(mx == b, ((60 * ((r - g) / safe_df) + 240) % 360) / 360,
                              hue)
            hue = torch.where(mx == mn, torch.full_like(hue, 0), hue)

            sat = torch.where(mx == 0, torch.full_like(df, 0), df / safe_mx)

            # Unsqueezing at dim 0 gives shape (1, N, H, W): the batch becomes
            # the channel axis. This is kept from the original; with
            # size_average=True the SSIM mean is the same as for (N, 1, H, W).
            return torch.unsqueeze(hue, 0), torch.unsqueeze(sat, 0)

        fake_h, fake_s = _hue_saturation(batch_fake_img)
        real_h, real_s = _hue_saturation(batch_real_img)

        # NOTE(review): pytorch_msssim.ssim defaults to data_range=255, while
        # hue/saturation lie in [0, 1]. The stabilising constants then dominate
        # and SSIM stays close to 1. Consider data_range=1 (this changes the
        # loss scale, so re-tune loss weights if you do).
        batch_color_loss = 1 - pytorch_msssim.ssim(
            real_h, fake_h) + 1 - pytorch_msssim.ssim(real_s, fake_s)
        return batch_color_loss