Esempio n. 1
0
def ssim_loss(X, Y):
    """Return the MS-SSIM loss ``1 - MS-SSIM(X, Y)`` for two image batches.

    The comparison uses ``data_range=255`` and averages over the batch
    (``size_average=True``), so the result is a scalar tensor.

    Args:
        X: image batch; presumably (N, C, H, W) with values in 0..255 —
            TODO confirm with callers.
        Y: image batch with the same shape as ``X``.

    Returns:
        Scalar tensor ``1 - MS-SSIM(X, Y)``.
    """
    # Build the Gaussian window for the actual channel count of the input:
    # a window built for channel=1 cannot be applied to a multi-channel batch.
    channels = X.shape[1]
    ms_ssim_module = MS_SSIM(data_range=255,
                             size_average=True,
                             channel=channels,
                             nonnegative_ssim=False)
    return 1 - ms_ssim_module(X, Y)
    def __init__(self, device, model, lr=2e-4, rendering_loss_type='l1',
                 ssim_loss_weight=0.05):
        """Set up the optimizer, loss functions and loss histories.

        Args:
            device: device handle stored as ``self.device``.
            model: network to optimise; may be wrapped in ``nn.DataParallel``.
            lr: Adam learning rate.
            rendering_loss_type: ``'l1'`` (``nn.L1Loss``) or ``'l2'``
                (``nn.MSELoss``) for the rendered-image loss.
            ssim_loss_weight: weight of the SSIM term; ``0`` disables SSIM.

        Raises:
            ValueError: if ``rendering_loss_type`` is neither ``'l1'`` nor
                ``'l2'``.
        """
        self.device = device
        self.model = model
        self.lr = lr
        self.rendering_loss_type = rendering_loss_type
        self.ssim_loss_weight = ssim_loss_weight
        # A zero weight disables the SSIM term (no SSIM module is built).
        self.use_ssim = self.ssim_loss_weight != 0
        # If False doesn't save losses in loss history
        self.register_losses = True
        # Check if model is multi-gpu
        self.multi_gpu = isinstance(self.model, nn.DataParallel)

        # Initialize optimizer
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)

        # Loss on rendered images: L1 or L2 (MSE).
        if self.rendering_loss_type == 'l1':
            self.loss_func = nn.L1Loss()
        elif self.rendering_loss_type == 'l2':
            self.loss_func = nn.MSELoss()
        else:
            # Fail fast: otherwise self.loss_func is never set and the error
            # surfaces later, far away from the bad argument.
            raise ValueError(
                "rendering_loss_type must be 'l1' or 'l2', got {!r}".format(
                    rendering_loss_type))

        # SSIM on 3-channel images with values in [0, 1] (data_range=1.0).
        if self.use_ssim:
            self.ssim_loss_func = SSIM(data_range=1.0, size_average=True,
                                       channel=3, nonnegative_ssim=False)

        # Loss histories, keyed by loss name.
        self.recorded_losses = ["total", "regression", "ssim"]
        self.loss_history = {loss_type: [] for loss_type in self.recorded_losses}
        self.epoch_loss_history = {loss_type: [] for loss_type in self.recorded_losses}
        self.val_loss_history = {loss_type: [] for loss_type in self.recorded_losses}
Esempio n. 3
0
 def __init__(self):
     """Create the L1 criterion and an SSIM criterion placed on CUDA.

     The SSIM criterion uses an 11-tap Gaussian window (sigma 1.5), a data
     range of 1.0, a window built for 3 channels, and batch averaging.
     """
     super(Loss, self).__init__()
     self.l1 = nn.L1Loss()
     ssim_criterion = SSIM(win_size=11, win_sigma=1.5, data_range=1.0,
                           size_average=True, channel=3)
     # NOTE(review): .cuda() is unconditional, so a CUDA device is required.
     self.criterion_ssim = ssim_criterion.cuda()
Esempio n. 4
0
 def __init__(self, to_cuda):
     """Create the networks and criteria used by this loss.

     Args:
         to_cuda: target passed to ``.to()`` for the VGG19 and SpyNet
             networks (presumably a device or device string — confirm).
     """
     self.vgg = VGG19().to(to_cuda).eval()
     # ``size_average`` is deprecated in PyTorch; ``reduction='mean'`` is
     # the equivalent of ``size_average=True``.
     self.criterion = nn.L1Loss(reduction='mean')
     self.criterionMSE = nn.MSELoss()
     self.ssim_module = SSIM(data_range=1,
                             size_average=True,
                             channel=3,
                             nonnegative_ssim=False)
     # Five weights, increasing by powers of two up to 1.0; presumably one
     # per feature scale — TODO confirm against the forward pass.
     self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]
     self.flownet = spynet.Network().to(to_cuda).eval()
     self.alpha = 50
Esempio n. 5
0
    def __init__(self, opt):
        """Create the SSIM, FID (Inception) and LPIPS evaluation modules.

        Args:
            opt: options object; only ``opt.cuda`` is read here and it decides
                whether the Inception and LPIPS models use the GPU.
        """
        self.ssim_module = SSIM(data_range=255, size_average=True, channel=3)

        self.cuda = opt.cuda
        inception = get_inception_model().eval()
        if self.cuda:
            print("Using CUDA")
            inception = inception.cuda()
        self.fid_module = inception

        self.lpips_module = PerceptualLoss(use_gpu=self.cuda)

        # Presumably filled with real/fake Inception features for FID — confirm.
        self.fid_features_real = []
        self.fid_features_fake = []
Esempio n. 6
0
    def __init__(self, hparams):
        """Build the generator, discriminator, losses and baseline scores.

        Args:
            hparams: hyper-parameter namespace; reads ``batch_size`` and ``nc``.
        """
        super(ProbaModel, self).__init__()
        self.hparams = hparams

        self.batch_size = hparams.batch_size

        # Generator takes nc + 2 input channels and produces 1 channel at 3x scale.
        self.gen = MSRResNet(in_nc=hparams.nc + 2,
                             out_nc=1,
                             nf=48,
                             nb=16,
                             upscale=3)
        self.dis = Discriminator()

        self.loss_MSE = torch.nn.MSELoss()
        # pytorch_msssim requires an odd Gaussian window (it raises
        # "Window size should be odd." otherwise); 11 is the library default
        # that pairs with win_sigma=1.5.
        self.loss_ssim = SSIM(win_size=11,
                              win_sigma=1.5,
                              data_range=1.,
                              size_average=True,
                              channel=1)

        # NOTE(review): machine-specific absolute path; consider making it configurable.
        path = os.path.join('/local/pajot/data/proba_v', "norm.csv")
        # Space-separated "name score" rows -> {name: score}. ``sep`` is passed
        # by keyword because pandas >= 2.0 rejects positional read_csv options.
        self.baseline_cpsnrs = pd.read_csv(
            path, sep=' ', names=['name',
                                  'score']).set_index('name').to_dict()['score']
Esempio n. 7
0
    def main(self):
        """Train the autoencoder-based model, growing the labeled set in cycles.

        Builds the datasets and loaders, the reconstruction criteria
        (``bce``/``l1``/``l2``/``ssim``) and the per-sample classification
        criterion, then trains epoch by epoch. After ``args.labeled_warmup_epochs``
        epochs, once more than ``args.add_labeled_epochs`` epochs pass without a
        new best validation macro-avg recall, ``perform_sampling`` returns new
        loaders and the labeled count grows by ``args.add_labeled``. Training
        ends after ``args.epochs`` epochs or once the labeled count exceeds
        ``args.stop_labeled``. Logs are stored when ``args.store_logs`` is set.

        Returns:
            The trained model (also stored as ``self.model``).
        """
        dataset_class = self.datasets[self.args.dataset](
            root=self.args.root,
            add_labeled=self.args.add_labeled,
            advanced_transforms=True,
            merged=self.args.merged,
            remove_classes=self.args.remove_classes,
            oversampling=self.args.oversampling,
            unlabeled_subset_ratio=self.args.unlabeled_subset,
            seed=self.args.seed,
            start_labeled=self.args.start_labeled)

        _, labeled_dataset, unlabeled_dataset, labeled_indices, unlabeled_indices, test_dataset = \
            dataset_class.get_dataset()

        labeled_loader, unlabeled_loader, val_loader = create_loaders(
            self.args, labeled_dataset, unlabeled_dataset, test_dataset,
            labeled_indices, unlabeled_indices, self.kwargs,
            dataset_class.unlabeled_subset_num)

        # Dataset/loader handed to self.train for the reconstruction objective.
        base_dataset = dataset_class.get_base_dataset_autoencoder()

        base_loader = create_base_loader(base_dataset, self.kwargs,
                                         self.args.batch_size)

        # One row of average reconstruction losses per epoch; assumed to be in
        # the 'bce', 'l1', 'l2', 'ssim' order used as column names when the log
        # is stored — confirm in self.train.
        reconstruction_loss_log = []

        bce_loss = nn.BCELoss().cuda()
        l1_loss = nn.L1Loss()
        l2_loss = nn.MSELoss()
        # NOTE(review): no `channel` is passed, so the SSIM default is used —
        # confirm it matches the reconstructed images.
        ssim_loss = SSIM(size_average=True,
                         data_range=1.0,
                         nonnegative_ssim=True)

        criterions_reconstruction = {
            'bce': bce_loss,
            'l1': l1_loss,
            'l2': l2_loss,
            'ssim': ssim_loss
        }
        # Per-sample classification loss (reduction='none').
        criterion_cl = get_loss(self.args,
                                dataset_class.labeled_class_samples,
                                reduction='none')

        model, optimizer, self.args = create_model_optimizer_autoencoder(
            self.args, dataset_class)

        # NOTE(review): best_loss is updated each epoch but never read here.
        best_loss = np.inf

        metrics_per_cycle = pd.DataFrame([])
        metrics_per_epoch = pd.DataFrame([])
        num_class_per_cycle = pd.DataFrame([])

        # last_best_epochs counts epochs since the last new best macro-avg recall.
        best_recall, best_report, last_best_epochs = 0, None, 0
        # NOTE(review): best_model is tracked but neither returned nor saved here.
        best_model = deepcopy(model)

        self.args.start_epoch = 0
        self.args.weak_supervision_strategy = "random_sampling"
        current_labeled = dataset_class.start_labeled

        for epoch in range(self.args.start_epoch, self.args.epochs):
            cl_train_loss, losses_avg_reconstruction, losses_reconstruction = \
                self.train(labeled_loader, model, criterion_cl, optimizer, last_best_epochs, epoch,
                           criterions_reconstruction, base_loader)
            val_loss, val_report = self.validate(val_loader, model,
                                                 last_best_epochs,
                                                 criterion_cl)

            reconstruction_loss_log.append(losses_avg_reconstruction.tolist())
            best_loss = min(best_loss, losses_reconstruction.avg)

            is_best = val_report['macro avg']['recall'] > best_recall
            last_best_epochs = 0 if is_best else last_best_epochs + 1

            val_report = pd.concat([val_report, cl_train_loss, val_loss],
                                   axis=1)
            metrics_per_epoch = pd.concat([metrics_per_epoch, val_report])

            # After warm-up, once recall has stalled for more than
            # add_labeled_epochs epochs, sample new labeled data.
            if epoch > self.args.labeled_warmup_epochs and last_best_epochs > self.args.add_labeled_epochs:
                metrics_per_cycle = pd.concat([metrics_per_cycle, best_report])

                # NOTE(review): `model` is passed twice below; confirm against
                # perform_sampling's signature.
                labeled_loader, unlabeled_loader, val_loader, labeled_indices, unlabeled_indices = \
                    perform_sampling(self.args, None, None,
                                     epoch, model, labeled_loader, unlabeled_loader,
                                     dataset_class, labeled_indices,
                                     unlabeled_indices, labeled_dataset,
                                     unlabeled_dataset,
                                     test_dataset, self.kwargs, current_labeled,
                                     model)

                current_labeled += self.args.add_labeled
                last_best_epochs = 0

                if self.args.reset_model:
                    model, optimizer, self.args = create_model_optimizer_autoencoder(
                        self.args, dataset_class)

                if self.args.novel_class_detection:
                    # Count labeled samples per class and log one row per cycle.
                    num_classes = [
                        np.sum(
                            np.array(base_dataset.targets)[labeled_indices] ==
                            i) for i in range(len(base_dataset.classes))
                    ]
                    num_class_per_cycle = pd.concat([
                        num_class_per_cycle,
                        pd.DataFrame.from_dict(
                            {
                                cls: num_classes[i]
                                for i, cls in enumerate(base_dataset.classes)
                            },
                            orient='index').T
                    ])

                # Rebuild the classification loss; presumably
                # labeled_class_samples now reflects the new labeled set.
                criterion_cl = get_loss(self.args,
                                        dataset_class.labeled_class_samples,
                                        reduction='none')
            else:
                best_recall = val_report['macro avg'][
                    'recall'] if is_best else best_recall
                best_report = val_report if is_best else best_report
                best_model = deepcopy(model) if is_best else best_model

            if current_labeled > self.args.stop_labeled:
                break

            save_checkpoint(
                self.args, {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_recall,
                }, is_best)

        if self.args.store_logs:
            store_logs(self.args,
                       pd.DataFrame(reconstruction_loss_log,
                                    columns=['bce', 'l1', 'l2', 'ssim']),
                       log_type='ae_loss')
            store_logs(self.args, metrics_per_cycle)
            store_logs(self.args, metrics_per_epoch, log_type='epoch_wise')
            store_logs(self.args, num_class_per_cycle, log_type='novel_class')

        self.model = model
        return model
Esempio n. 8
0
	def __init__(self , max_value = 1.0):
		"""Wrap a per-image SSIM metric configured for 3-channel inputs.

		Args:
			max_value: dynamic range of the inputs; stored and passed to SSIM
				as ``data_range``. SSIM returns one value per image
				(``size_average=False``).
		"""
		super(SSIMLoss, self).__init__()
		self.max_value = max_value
		self.ssim_fn = SSIM(channel = 3, size_average = False, data_range = self.max_value)
Esempio n. 9
0
 def __init__(self, weight):
     """Hold the L1, SSIM and gradient components of the depth loss.

     Args:
         weight: stored as ``self.weight``; how it is applied is not shown here.
     """
     super(DepthLoss, self).__init__()
     self.weight = weight
     self.l1_loss = nn.L1Loss()
     # NOTE(review): SSIM() runs with library defaults (for pytorch_msssim:
     # data_range=255, channel=3) — confirm these suit depth maps.
     self.SSIM = SSIM()
     self.gradient_loss = GradientLoss()
Esempio n. 10
0
    def __init__(self, args, ckp):
        """Build the weighted combination of loss terms described by ``args.loss``.

        ``args.loss`` is a ``'+'``-separated list of ``'<weight>*<type>'`` terms,
        e.g. ``'1*L1+0.1*SSIM'``. Each term is appended to ``self.loss`` as a
        dict with ``'type'``, ``'weight'`` and ``'function'`` keys. A GAN term
        also appends a ``'DIS'`` entry, and a ``'Total'`` entry is appended when
        there is more than one entry. Every non-None function is registered in
        ``self.loss_module``, which is moved to the target device (optionally
        half precision / DataParallel).

        Args:
            args: options namespace; reads ``n_GPUs``, ``loss``, ``rgb_range``,
                ``cpu``, ``precision`` and ``load``.
            ckp: checkpoint helper; ``ckp.dir`` is passed to ``self.load`` when
                ``args.load`` is non-empty.

        Raises:
            ValueError: if a loss type in ``args.loss`` is not recognised.
        """
        super(Loss, self).__init__()
        print('Preparing loss function:')

        self.n_GPUs = args.n_GPUs
        self.loss = []
        self.loss_module = nn.ModuleList()
        for loss in args.loss.split('+'):
            weight, loss_type = loss.split('*')
            if loss_type == 'MSE':  # L2 loss
                loss_function = nn.MSELoss()
            elif loss_type == 'L1':
                loss_function = nn.L1Loss()
            elif 'VGG' in loss_type:
                module = import_module('loss.vgg')
                loss_function = getattr(module,
                                        'VGG')(loss_type[3:],
                                               rgb_range=args.rgb_range)
            elif 'TextureL' in loss_type:
                module = import_module('loss.vgg')
                # NOTE(review): [3:] mirrors the VGG branch; for 'TextureL...'
                # it yields 'tureL...' — confirm VGG parses this as intended.
                loss_function = getattr(module,
                                        'VGG')(loss_type[3:],
                                               rgb_range=args.rgb_range,
                                               texture_loss=True)
            elif 'GAN' in loss_type:
                module = import_module('loss.adversarial')
                loss_function = getattr(module, 'Adversarial')(args, loss_type)
            elif 'TVLoss' in loss_type:
                module = import_module('loss.tvloss')
                loss_function = getattr(module, 'TVLoss')()
            # 'MS-SSIM' must be tested before 'SSIM': every 'MS-SSIM' string
            # also contains 'SSIM', so the reverse order never reaches MS-SSIM.
            # NOTE(review): both modules return a similarity (1 = identical);
            # confirm the forward pass turns it into a loss (e.g. 1 - SSIM).
            elif 'MS-SSIM' in loss_type:
                from pytorch_msssim import MS_SSIM
                loss_function = MS_SSIM(win_sigma=1,
                                        data_range=args.rgb_range,
                                        size_average=True,
                                        channel=3)
            elif 'SSIM' in loss_type:
                from pytorch_msssim import SSIM
                loss_function = SSIM(win_size=7,
                                     win_sigma=1,
                                     data_range=args.rgb_range,
                                     size_average=True,
                                     channel=3)
            else:
                # Without this, an unknown type raises a confusing NameError on
                # the first term or silently reuses the previous term's function.
                raise ValueError('Unknown loss type: {}'.format(loss_type))

            self.loss.append({
                'type': loss_type,
                'weight': float(weight),
                'function': loss_function
            })
            if 'GAN' in loss_type:
                self.loss.append({
                    'type': 'DIS',
                    'weight': 1,
                    'function': None
                })

        if len(self.loss) > 1:
            self.loss.append({'type': 'Total', 'weight': 0, 'function': None})

        for entry in self.loss:
            if entry['function'] is not None:
                print('{:.6f} * {}'.format(entry['weight'], entry['type']))
                self.loss_module.append(entry['function'])

        self.log = torch.Tensor()

        device = torch.device('cpu' if args.cpu else 'cuda')
        self.loss_module.to(device)
        if args.precision == 'half':
            self.loss_module.half()
        if not args.cpu and args.n_GPUs > 1:
            self.loss_module = nn.DataParallel(self.loss_module,
                                               range(args.n_GPUs))

        if args.load != '':
            self.load(ckp.dir, cpu=args.cpu)
Esempio n. 11
0
def ssim_fn_val(sr , hr , max_value = 255):
	"""Return the per-image SSIM of ``sr`` against ``hr``.

	A new SSIM module (window for 3 channels, ``data_range=max_value``,
	``size_average=False``) is built on every call, so the result holds one
	value per image rather than a batch mean.
	"""
	metric = SSIM(channel = 3, data_range = max_value, size_average= False)
	return metric(sr, hr)
Esempio n. 12
0
 def __init__(self, primary_loss="mse", weights=None, asImages=True):
     """Combine a primary reconstruction loss with an SSIM term.

     Args:
         primary_loss: if the string contains ``"bce"`` the primary loss is
             ``BCE_LOSS``, otherwise ``MSE_LOSS`` (both with mean reduction).
         weights: stored as ``self.weights``; ``None`` (the default) gives a
             fresh ``[1.0, 1.0]`` for each instance.
         asImages: flag stored as ``self.asImages``.
     """
     if "bce" in primary_loss:
         self.main_loss = BCE_LOSS(reduction="mean")
     else:
         self.main_loss = MSE_LOSS(reduction="mean")
     self.ssim_loss = SSIM(data_range=1.0, nonnegative_ssim=True)
     # A new list per instance: a list literal as the default would be shared
     # by every instance, so mutating one object's weights would affect all.
     self.weights = [1.0, 1.0] if weights is None else weights
     self.asImages = asImages
Esempio n. 13
0
import csv
import os
import matplotlib.pyplot as plt

import torch
from torch.autograd import Variable

from torchvision.utils import save_image
import torch.nn.functional as F

from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM

# Pick the GPU when one is available, together with the matching legacy
# float tensor constructor.
if torch.cuda.is_available():
    device = torch.device("cuda")
    Tensor = torch.cuda.FloatTensor
else:
    device = torch.device("cpu")
    Tensor = torch.Tensor

# SSIM with a window built for 512 channels, data_range 1.0, one value per
# sample (size_average=False); presumably applied to VGG feature maps, as the
# name suggests — TODO confirm.
ssim_fn_vgg = SSIM(data_range=1.0, size_average=False, channel=512)


def write_to_csv_file(filename, content):
    """Append one row to a CSV file, creating the file if needed.

    Args:
        filename: path of the CSV file.
        content: iterable of field values written as a single row.
    """
    # The csv module requires newline='' so it controls line endings itself;
    # otherwise '\r\n' terminators gain an extra '\r' on Windows. Read access
    # ('a+') was never used, so plain append mode is enough.
    with open(filename, 'a', newline='') as f:
        csv.writer(f).writerow(content)


def get_image(input_tensor):
    """Convert a (C, H, W) tensor to an (H, W, C) NumPy image clamped to [0, 1].

    The tensor is detached first so tensors that require grad can be
    converted; ``.numpy()`` refuses tensors that are part of an autograd graph.
    """
    return input_tensor.detach().clamp(0, 1).permute(1, 2, 0).cpu().numpy()


@torch.no_grad()
def test(epoch, generator, dataloader):
    """Run with gradients disabled; presumably evaluates ``generator`` on ``dataloader``.

    NOTE(review): the body appears truncated in this copy; only the PSNR
    accumulator is initialised. Confirm against the full source.
    """
    psnr_val = 0
Esempio n. 14
0
    def __init__(self, **kwargs):
        """Build the reconstruction engine from hyper-parameters.

        All keyword arguments are stored as ``self.hparams`` via
        ``save_hyperparameters``. This selects the network (``modelID``),
        optionally loads pre-trained weights, selects the loss (``lossID``) and
        assembles the initial, augmentation and motion-corruption transform
        lists. Invalid ``modelID``/``lossID``/``ds_mode`` values and unsupported
        motion settings end the process via ``sys.exit``.
        """
        super().__init__()
        self.save_hyperparameters()

        device = torch.device(
            "cuda:0" if torch.cuda.is_available() and self.hparams.cuda else "cpu")

        if self.hparams.modelID == 0:
            self.net = ResNet(in_channels=self.hparams.in_channels, out_channels=self.hparams.out_channels, res_blocks=self.hparams.model_res_blocks,
                              starting_nfeatures=self.hparams.model_starting_nfeatures, updown_blocks=self.hparams.model_updown_blocks,
                              is_relu_leaky=self.hparams.model_relu_leaky, do_batchnorm=self.hparams.model_do_batchnorm,
                              res_drop_prob=self.hparams.model_drop_prob, is_replicatepad=self.hparams.model_is_replicatepad, out_act=self.hparams.model_out_act, forwardV=self.hparams.model_forwardV,
                              upinterp_algo=self.hparams.model_upinterp_algo, post_interp_convtrans=self.hparams.model_post_interp_convtrans, is3D=self.hparams.is3D)  # TODO think of 2D
            # self.net = nn.Conv3d(in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=1)
        elif self.hparams.modelID == 2:
            self.net = DualSpaceResNet(in_channels=self.hparams.in_channels, out_channels=self.hparams.out_channels, res_blocks=self.hparams.model_res_blocks,
                                        starting_nfeatures=self.hparams.model_starting_nfeatures, updown_blocks=self.hparams.model_updown_blocks,
                                        is_relu_leaky=self.hparams.model_relu_leaky, do_batchnorm=self.hparams.model_do_batchnorm,
                                        res_drop_prob=self.hparams.model_drop_prob, is_replicatepad=self.hparams.model_is_replicatepad, out_act=self.hparams.model_out_act, forwardV=self.hparams.model_forwardV,
                                        upinterp_algo=self.hparams.model_upinterp_algo, post_interp_convtrans=self.hparams.model_post_interp_convtrans, is3D=self.hparams.is3D,
                                        connect_mode=self.hparams.model_dspace_connect_mode, inner_norm_ksp=self.hparams.model_inner_norm_ksp)
        elif self.hparams.modelID == 3:  # Primal-Dual Network, complex Primal
            self.net = PrimalDualNetwork(n_primary=5, n_dual=5, n_iterations=10,
                            use_original_block = True,
                            use_original_init = True,
                            use_complex_primal = True,
                            g_normtype = "magmax",
                            transform = "Fourier",
                            return_abs = True)
        elif self.hparams.modelID == 4:  # Primal-Dual Network, absolute Primal
            self.net = PrimalDualNetwork(n_primary=5, n_dual=5, n_iterations=10,
                            use_original_block = True,
                            use_original_init = True,
                            use_complex_primal = False,
                            g_normtype = "magmax",
                            transform = "Fourier")
        elif self.hparams.modelID == 5:  # Primal-Dual UNet Network, absolute Primal
            self.net = PrimalDualNetwork(n_primary=4, n_dual=5, n_iterations=2,
                            use_original_block = False,
                            use_original_init = False,
                            use_complex_primal = False,
                            g_normtype = "magmax",
                            transform = "Fourier")
        elif self.hparams.modelID == 6:  # Primal-Dual Network v2 (no residual), complex Primal
            self.net = PrimalDualNetworkNoResidue(n_primary=5, n_dual=5, n_iterations=10,
                            use_original_block = True,
                            use_original_init = True,
                            use_complex_primal = True,
                            residuals=False,
                            g_normtype = "magmax",
                            transform = "Fourier",
                            return_abs = True)

        else:
            # TODO: other models
            # NOTE(review): message predates modelIDs 3-6, which are handled above.
            sys.exit("Only ReconResNet and DualSpaceResNet have been implemented so far in ReconEngine")

        if bool(self.hparams.preweights_path):
            print("Pre-weights found, loding...")
            chk = torch.load(self.hparams.preweights_path, map_location='cpu')
            self.net.load_state_dict(chk['state_dict'])

        if self.hparams.lossID == 0:
            if self.hparams.in_channels != 1 or self.hparams.out_channels != 1:
                sys.exit(
                    "Perceptual Loss used here only works for 1 channel input and output")
            self.loss = PerceptualLoss(device=device, loss_model="unet3Dds", resize=None,
                                       loss_type=self.hparams.ploss_type, n_level=self.hparams.ploss_level)  # TODO thinkof 2D
        elif self.hparams.lossID == 1:
            self.loss = nn.L1Loss(reduction='mean')
        elif self.hparams.lossID == 2:
            self.loss = MS_SSIM(channel=self.hparams.out_channels, data_range=1, spatial_dims=3 if self.hparams.is3D else 2, nonnegative_ssim=False).to(device)
        elif self.hparams.lossID == 3:
            self.loss = SSIM(channel=self.hparams.out_channels, data_range=1, spatial_dims=3 if self.hparams.is3D else 2, nonnegative_ssim=False).to(device)
        else:
            sys.exit("Invalid Loss ID")

        self.dataspace = DataSpaceHandler(**self.hparams)

        # ds_mode 0 selects the TorchIO transform family, 1 the PyTorch one.
        if self.hparams.ds_mode == 0:
            trans = tioTransforms
            augs = tioAugmentations
        elif self.hparams.ds_mode == 1:
            trans = pytTransforms
            augs = pytAugmentations
        else:
            # Without this, `trans`/`augs` stay undefined and the code below
            # fails with a NameError instead of a clear message.
            sys.exit("Invalid ds_mode")

        # TODO parameterised everything
        self.init_transforms = []
        self.aug_transforms = []
        self.transforms = []
        if self.hparams.ds_mode == 0 and self.hparams.cannonicalResample:  # Only applicable for TorchIO
            self.init_transforms += [tio.ToCanonical(), tio.Resample('gt')]
        if self.hparams.ds_mode == 0 and self.hparams.forceNormAffine:  # Only applicable for TorchIO
            self.init_transforms += [trans.ForceAffine()]
        if self.hparams.croppad and self.hparams.ds_mode == 1:
            self.init_transforms += [
                trans.CropOrPad(size=self.hparams.input_shape)]
        self.init_transforms += [trans.IntensityNorm(type=self.hparams.norm_type, return_meta=self.hparams.motion_return_meta)]
        # dataspace_transforms = self.dataspace.getTransforms() #TODO: dataspace transforms are not in use
        # self.init_transforms += dataspace_transforms
        if bool(self.hparams.random_crop) and self.hparams.ds_mode == 1:
            self.aug_transforms += [augs.RandomCrop(
                size=self.hparams.random_crop, p=self.hparams.p_random_crop)]
        if self.hparams.p_contrast_augment > 0:
            self.aug_transforms += [augs.getContrastAugs(
                p=self.hparams.p_contrast_augment)]
        # If the task is MoCo and pre-corrupted volumes are not supplied,
        # synthesise motion corruption on the fly.
        if self.hparams.taskID == 1 and not bool(self.hparams.train_path_inp):
            if self.hparams.motion_mode == 0 and self.hparams.ds_mode == 0:
                # Every hparam named 'motionmg_<x>' is forwarded as keyword <x>.
                motion_params = {k.split('motionmg_')[
                    1]: v for k, v in self.hparams.items() if k.startswith('motionmg')}
                self.transforms += [tioMotion.RandomMotionGhostingFast(
                    **motion_params), trans.IntensityNorm()]
            elif self.hparams.motion_mode == 1 and self.hparams.ds_mode == 1 and not self.hparams.is3D:
                self.transforms += [pytMotion.Motion2Dv0(
                    sigma_range=self.hparams.motion_sigma_range, n_threads=self.hparams.motion_n_threads, p=self.hparams.motion_p, return_meta=self.hparams.motion_return_meta)]
            elif self.hparams.motion_mode == 2 and self.hparams.ds_mode == 1 and not self.hparams.is3D:
                self.transforms += [pytMotion.Motion2Dv1(sigma_range=self.hparams.motion_sigma_range, n_threads=self.hparams.motion_n_threads,
                                                         restore_original=self.hparams.motion_restore_original, p=self.hparams.motion_p, return_meta=self.hparams.motion_return_meta)]
            else:
                sys.exit(
                    "Error: invalid motion_mode, ds_mode, is3D combo. Please double check!")

        self.static_metamat = sio.loadmat(self.hparams.static_metamat_file) if bool(
            self.hparams.static_metamat_file) else None
        if self.hparams.taskID == 0 and self.hparams.use_datacon:
            self.datacon = DataConsistency(
                isRadial=self.hparams.is_radial, metadict=self.static_metamat)
        else:
            self.datacon = None

        # For 2D models the last entry of input_shape is dropped.
        input_shape = self.hparams.input_shape if self.hparams.is3D else self.hparams.input_shape[
            :-1]
        self.example_input_array = torch.empty(
            self.hparams.batch_size, self.hparams.in_channels, *input_shape).float()
        self.saver = ResSaver(
            self.hparams.res_path, save_inp=self.hparams.save_inp, do_norm=self.hparams.do_savenorm)
Esempio n. 15
0
 def ssim(self):
     """Build an SSIM module for 3-channel images normalised to [0, 1].

     Returns:
         A new ``SSIM`` instance with ``data_range=1`` and batch averaging.
     """
     module = SSIM(channel=3, data_range=1, size_average=True)
     return module
Esempio n. 16
0

def l1loss(x, y):
    """Return the mean absolute error between ``x`` and ``y``."""
    loss = torch.nn.functional.l1_loss(x, y, reduction='mean')
    return loss


def l2loss(x, y):
    """Return the mean squared error between ``x`` and ``y``."""
    diff = x - y
    return (diff ** 2).mean()


def psnr(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Return the channel-averaged PSNR of each sample, assuming a peak value of 1.

    Per sample and channel this is ``10 * log10(H * W / sum((x - y) ** 2))``
    with the sum over dims 2 and 3; the result is averaged over dim 1, giving
    one value per sample. Identical inputs give ``inf``.
    """
    num_pixels = x.shape[-2] * x.shape[-1]
    squared_error_sum = (x - y).pow(2).sum(dim=(2, 3))
    per_channel_psnr = 10 * torch.log10(num_pixels / squared_error_sum)
    return per_channel_psnr.mean(dim=1)


# Shared metric instances for inputs with a data range of 1.0.
# NOTE(review): `MSSSIM` is not pytorch_msssim's class name (`MS_SSIM`);
# confirm which library provides SSIM/MSSSIM here.
ssim = SSIM(data_range=1.0)
msssim = MSSSIM(data_range=1.0)


def gaussian(x, sigma=1.0):
    """Evaluate the unnormalised Gaussian ``exp(-x**2 / (2 * sigma**2))``."""
    variance = sigma ** 2
    return np.exp(-(x ** 2) / (2 * variance))


def build_gauss_kernel(size=5, sigma=1.0, n_channels=1, device=None):
    """Construct the convolution kernel for a gaussian blur
    See https://en.wikipedia.org/wiki/Gaussian_blur for a definition.
    Overall I first generate a NxNx2 matrix of indices, and then use those to
    calculate the gaussian function on each element. The two dimensional
    Gaussian function is then the product along axis=2.
    Also, in_channels == out_channels == n_channels
    """
Esempio n. 17
0
                 'normal',
                 0.02,
                 gpu_id=device)

# VGG for perceptual loss (only built when the content-loss weight is positive)
if opt.lamb_content > 0:
    vgg = Vgg16()
    init_vgg16(root_path)
    # map_location="cpu" lets weights saved from a GPU load on a CPU-only host;
    # the network is moved to `device` right afterwards anyway.
    vgg.load_state_dict(
        torch.load(os.path.join(root_path, "vgg16.weight"), map_location="cpu"))
    vgg.to(device)

# define loss
criterionL1 = nn.L1Loss().to(device)
criterionL2 = nn.MSELoss().to(device)
criterionMSE = nn.MSELoss().to(device)
# SSIM / MS-SSIM on inputs with a 0..255 range; windows built for 3 or 1 channels.
criterionSSIM = SSIM(data_range=255, size_average=True, channel=3)
criterionMSSSIM1 = MS_SSIM(data_range=255, size_average=True, channel=1)
criterionMSSSIM3 = MS_SSIM(data_range=255, size_average=True, channel=3)

# setup optimizer
optimizer_i = optim.Adam(net_i.parameters(),
                         lr=opt.lr,
                         betas=(opt.beta1, 0.999))
optimizer_r = optim.Adam(net_r.parameters(),
                         lr=opt.lr,
                         betas=(opt.beta1, 0.999))
net_i_scheduler = get_scheduler(optimizer_i, opt)
net_r_scheduler = get_scheduler(optimizer_r, opt)

loss_i_list = []
loss_r_list = []
Esempio n. 18
0
        
        # Identity loss: real_B passed through encoder/decoder, weighted by 5.
        idt_B = decoder(encoder(real_B))
        loss_idt = criterion_l1(idt_B,real_B)*5
        
        # GAN feature matching loss (disabled)
        # loss_feat = 0
        # for j in range(len(pred_fake[0])-1):
        #     loss_feat += criterion_feat(pred_fake[0][j], pred_real[0][j].detach())*1
                        
        # VGG feature matching loss, weighted by 10.
        # NOTE(review): loss_VGG is not used in the total loss below; confirm
        # whether loss_feature (defined outside this view) already covers it.
        loss_VGG = criterion_VGG(fake_B,real_A)*10
        
           
        # SSIM loss (1 - SSIM) between real_A and fake_B, both mapped from
        # [-1, 1] to [0, 1] to match data_range=1.
        # NOTE(review): the SSIM module is rebuilt every time this code runs;
        # it could be created once outside the training step.
        ssim_module = SSIM(data_range=1, size_average=True, channel=3)
        X1 = (real_A + 1)*0.5   # [-1, 1] => [0, 1]
        Y1 = (fake_B + 1)*0.5  
        ssim_A = (1 - ssim_module(X1, Y1))


        # Total loss
        loss = loss_G + loss_feature + loss_idt + ssim_A 
        
        loss.backward()
        
        optimizer_encoder.step()
        optimizer_decoder.step()
        ###################################

        
Esempio n. 19
0
 def __init__(self, alpha=0.84, kernel_w=3, sigma=1.5, channels=1):
     """Configure a mixed L1 + SSIM distance.

     Args:
         alpha: mixing weight stored as ``self.alpha``; how it is applied is
             not shown here.
         kernel_w: SSIM Gaussian window size (``win_size``). pytorch_msssim
             rejects even sizes at evaluation time.
         sigma: SSIM Gaussian window standard deviation (``win_sigma``).
         channels: number of channels the SSIM window is built for.
     """
     # Must call super() first: `super.__init__()` (no call on super) raises
     # TypeError, and an nn.Module-style base needs initialising before
     # submodules are assigned.
     super().__init__()
     self.alpha = alpha
     self.kernel_w = kernel_w
     # Per-sample SSIM (size_average=False) on inputs in [0, 1].
     self.ssim_d = SSIM(win_size=kernel_w, win_sigma=sigma, data_range=1.0, size_average=False, channel=channels)
     self.config = {'Distance' : 'L1 + SSIM', 'alpha': alpha, 'win_size': kernel_w, 'win_sigma': sigma}
Esempio n. 20
0
            ]
        }
    })

    to_tensor = ToTensor()

    def to_pilimagelab(pic):
        """Convert a 3-D (C, H, W) tensor to a LAB-mode PIL image.

        Values are scaled by 255 and cast to bytes, so the input is presumably
        in [0, 1] — confirm with callers.
        """
        as_bytes = pic.mul(255).byte()
        hwc = as_bytes.numpy().transpose(1, 2, 0)
        return PIL.Image.fromarray(hwc, mode='LAB')

    invcielab = InvCIELABTransform()

    ssim = SSIM(win_size=11,
                win_sigma=1.5,
                data_range=255,
                size_average=True,
                channel=3)

    cur_time = time.time()

    log_time = cur_time
    checkpoint_time = cur_time
    backup_time = cur_time

    print('Training')
    eval_iter = iter(eval_loader)
    while True:
        epoch_start = time.time()

        torch.manual_seed(epoch_start)
Esempio n. 21
0
# Only "ours" and VAO are compared; HBAO, NNAO and DeepShading are disabled.
# Add two leading singleton axes to vao (gt/ours are prepared outside this view).
vao = np.expand_dims(np.expand_dims(vao, 0), axis=0)

# Move the arrays to the GPU as tensors.
gt, ours, vao = (torch.from_numpy(arr).cuda() for arr in (gt, ours, vao))

# A single SSIM module (11-tap window, sigma 1.5, range 1.0, one channel,
# batch-averaged) is reused for both comparisons against the ground truth.
ssim_metric = SSIM(win_size=11,
                   win_sigma=1.5,
                   data_range=1.0,
                   size_average=True,
                   channel=1)
ssim_ours = ssim_metric(gt, ours).item()
ssim_vao = ssim_metric(gt, vao).item()

print(ssim_ours, ssim_vao)
# NOTE(review): despite the name, this is the mean absolute (L1) error.
mse_ours = torch.nn.L1Loss()(gt, ours).item()
Esempio n. 22
0
    im_a=trans(io.imread(str1+s+'.png'))
    im_f=trans(io.imread(str2+s+'.png'))
    g_a = trans(io.imread(str3 + s + '.png'))
    g_f = trans(io.imread(str4 + s + '.png'))
    G_f = trans(io.imread(str5 + s + '.png'))
    return im_a,im_f,g_a,g_f,G_f

# Build both networks, wrap them in DataParallel, move them to the GPU and
# load the weights (presumably saved from DataParallel-wrapped models, hence
# the wrapping before load_state_dict — confirm).
net = SDNet()
net2=IFNet()
net=torch.nn.DataParallel(net)
net2=torch.nn.DataParallel(net2)
net.cuda()
net2.cuda()
net.load_state_dict(torch.load(net_dict))
net2.load_state_dict(torch.load(net2_dict))
# Batch-averaged SSIM for 3-channel images in the 0..255 range.
ssim = SSIM(data_range=255, size_average=True, channel=3)
# Presumably running SSIM sums for the three outputs; the loop body that uses
# them is cut off in this copy.
S_a=0
S_f=0
S_c=0
# Iterate over the columns of row 0 of `lis` (one image id per column).
for i in range(0, np.shape(lis)[1]):
    im_a,im_f,g_a,g_f,G_f = read_image(lis[0,i])
    # NOTE(review): only im_a is cast to float32; the other tensors keep the
    # dtype returned by `trans` — confirm they are already float.
    im_a,im_f,g_a,g_f,G_f=torch.tensor(im_a,dtype=torch.float32),torch.tensor(im_f),torch.tensor(g_a),torch.tensor(g_f),torch.tensor(G_f)
    im_a, im_f, g_a, g_f, G_f = im_a.cuda(), im_f.cuda(), g_a.cuda(), g_f.cuda(), G_f.cuda()
    # NOTE(review): eval() could be called once before the loop, and inference
    # is not wrapped in torch.no_grad(), so autograd graphs are built each step.
    net.eval()
    net2.eval()
    out_a, out_f = net(im_a - 90.5, im_f - 120.5)
    output = net2(out_a, out_f)
    output=torch.clamp(output,0,255)
    # ssim already returns a batch mean (size_average=True); torch.mean is a no-op here.
    s1=torch.mean(ssim(out_a , g_a))
    s2=torch.mean(ssim(out_f , g_f))
    s3=torch.mean(ssim(output, G_f))
Esempio n. 23
0
from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM
from torch.autograd import Variable


# Pick the GPU when one is available, together with the matching legacy
# float tensor constructor.
if torch.cuda.is_available():
    device = torch.device("cuda")
    Tensor = torch.cuda.FloatTensor
else:
    device = torch.device("cpu")
    Tensor = torch.Tensor

def psnr_fn(sr, hr, max_value=1.):
    """Return the per-sample peak signal-to-noise ratio (in dB) of ``sr`` vs ``hr``.

    Args:
        sr: Reconstructed batch; dimension 0 is the batch dimension.
        hr: Reference batch, same shape as ``sr``.
        max_value: Peak signal value (1.0 for images in [0, 1], 255 for 8-bit).

    Returns:
        A 1-D tensor with one PSNR value per sample; identical samples give ``inf``.
    """
    # Average the squared error over every non-batch dimension. Flattening makes
    # this work for (N, C, H, W) as before, and also for (N, L) or
    # (N, C, D, H, W) inputs.
    mse = ((sr - hr) ** 2).flatten(start_dim=1).mean(dim=1)
    return 20 * torch.log10(max_value / torch.sqrt(mse))

def mse_fn(sr, hr):
    """Return the per-sample mean squared error between ``sr`` and ``hr``.

    Dimension 0 is treated as the batch dimension; all remaining dimensions
    are averaged, so (N, C, H, W), (N, L) and (N, C, D, H, W) inputs all work.

    Returns:
        A 1-D tensor of length N.
    """
    return ((sr - hr) ** 2).flatten(start_dim=1).mean(dim=1)

# Shared SSIM metric for 3-channel images in [0, 1]; size_average=False keeps
# one score per image instead of a batch mean.
ssim_fn = SSIM(
    data_range=1.0,
    size_average=False,
    channel=3,
)

def ssim_fn_val(sr, hr, max_value=255):
    """Score ``sr`` against ``hr`` with SSIM for 3-channel images in [0, max_value].

    A new SSIM module is constructed on every call, with ``size_average=False``
    so one SSIM value is returned per batch element.
    """
    metric = SSIM(data_range=max_value, size_average=False, channel=3)
    scores = metric(sr, hr)
    return scores

# def get_loss_fn(loss_type , args = None):
# 	if loss_type == 'L1':
# 		return nn.L1Loss()
# 	if loss_type == 'MSE':
# 		return nn.MSELoss()
# 	if loss_type == 'PSNR':
# 		return psnr.PSNR()
# 	if loss_type == 'SSIM':
# 		return SSIM(data_range=1.0, size_average = False, channel = 3)
# 	if loss_type == 'GAN':
# 		return GAN.GAN(args)
Esempio n. 24
0
# Total number of scalar parameters. p.numel() returns an exact int, including
# for 0-d parameters, where np.prod(torch.Size([])) would yield the float 1.0.
params = sum(p.numel() for p in net.parameters())

# Normalization statistics
# NOTE(review): 'norm_type_img' is filled from args.norm_type (the same value as
# 'norm_type') — confirm images are meant to share the volume normalization.
stats = {'norm_type':args.norm_type, 'norm_type_img':args.norm_type, 'mean_imgs':mean_imgs, 'std_images':std_images, 'max_images':max_images,
        'mean_vols':mean_vols, 'std_vols':std_vols, 'max_vols':max_volumes}



# Create loss function and optimizer
loss = nn.MSELoss()
# Only defined when the image-space loss is enabled.
if args.use_img_loss>0:
    loss_img = nn.MSELoss()

# reuse the gaussian kernel with SSIM & MS_SSIM.
# One channel per lenslet; data_range=1 assumes inputs scaled to [0, 1] — TODO confirm.
ssim_module = SSIM(data_range=1, size_average=True, channel=n_lenslets).to(device_repro)

optimizer = torch.optim.Adam(trainable_params, lr=args.learning_rate)

# timers
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

# create gradient scaler for mixed precision training
scaler = GradScaler()

start_epoch = 0
if len(args.checkpoint_XLFMNet)>0:
    # strict=False tolerates missing/unexpected keys in the saved state dict.
    net.load_state_dict(checkpoint_XLFMNet['model_state_dict'], strict=False)
    optimizer.load_state_dict(checkpoint_XLFMNet['optimizer_state_dict'])
    # NOTE(review): resumes from the stored epoch minus one — confirm this
    # off-by-one is intentional (e.g. re-running the last saved epoch).
    start_epoch = checkpoint_XLFMNet['epoch']-1
 def __init__(self, primary_loss="mse", weights=None):
     """Set up a primary loss plus an SSIM loss term.

     Args:
         primary_loss: Name of the primary loss. Any string containing ``"bce"``
             selects ``BCE_LOSS``; anything else selects ``MSE_LOSS``.
         weights: Weights stored on ``self.weights`` (presumably for the primary
             and SSIM terms — confirm against the method that combines them).
             Defaults to ``[1.0, 1.0]``; a fresh list is created per instance.
     """
     # Build only the loss that is actually used (previously an MSE loss was
     # always created and then replaced in the bce case).
     if "bce" in primary_loss:
         self.main_loss = BCE_LOSS(reduction="mean")
     else:
         self.main_loss = MSE_LOSS(reduction="mean")
     self.ssim_loss = SSIM()
     # A mutable list default would be shared by every instance created without
     # explicit weights; create a new one each time instead.
     self.weights = [1.0, 1.0] if weights is None else weights
Esempio n. 26
0
def tensor_ssim_module():
    """Build reusable SSIM and MS-SSIM modules for 3-channel images in [0, 255].

    Constructing the modules once lets callers reuse the Gaussian kernel across
    calls. Previously the modules were built and then discarded (the function
    returned None).

    Returns:
        tuple: ``(ssim_module, ms_ssim_module)``, both with ``size_average=True``.
    """
    ssim_module = SSIM(data_range=255, size_average=True, channel=3)
    ms_ssim_module = MS_SSIM(data_range=255, size_average=True, channel=3)
    return ssim_module, ms_ssim_module
Esempio n. 27
0
 def __init__(self):
     """Create the single-scale SSIM helper used by this loss."""
     # Run the base-class initialiser before assigning any attributes.
     super().__init__()
     # 3-channel inputs in [0, 1], averaged over the batch.
     self.ssim = SSIM(
         data_range=1,
         size_average=True,
         channel=3,
     )
Esempio n. 28
0
# Target image scaled to [0, 1] with batch and channel dims added: (1, 1, H, W).
img1 = torch.from_numpy(npImg1).float().unsqueeze(0).unsqueeze(0) / 255.0
# Random starting image that will be optimised towards img1.
img2 = torch.rand(img1.size())

if torch.cuda.is_available():
    img1 = img1.cuda()
    img2 = img2.cuda()

# img1 stays fixed (no grad); img2 is the optimised leaf tensor. This replaces
# the deprecated Variable(...) wrapper.
img2.requires_grad_(True)

# Both images live in [0, 1], so pass data_range=1 explicitly. The functional
# ssim otherwise uses its default data range (255), which would disagree with
# the loss module below.
ssim_value = ssim(img1, img2, data_range=1).item()
print("Initial ssim:", ssim_value)

# The keyword is `channel` (the previous `channels=1` is not an SSIM argument).
ssim_loss = SSIM(win_size=11,
                 win_sigma=1.5,
                 data_range=1,
                 size_average=True,
                 channel=1)

optimizer = optim.Adam([img2], lr=0.01)

# NOTE(review): there is no iteration cap; this loops until SSIM >= 0.9999.
while ssim_value < 0.9999:
    optimizer.zero_grad()
    _ssim_loss = 1 - ssim_loss(img1, img2)
    _ssim_loss.backward()
    optimizer.step()

    # Monitoring only: no graph is needed for this evaluation.
    with torch.no_grad():
        ssim_value = ssim(img1, img2, data_range=1).item()
    print(ssim_value)

img2_ = (img2 * 255.0).squeeze()
Esempio n. 29
0
import torch
# X, Y: (N, 3, H, W) batches of RGB images. torch.rand samples from [0, 1),
# so every metric below uses data_range=1.0 (8-bit 0~255 images would need 255).
X = torch.rand(4, 3, 512, 512)
Y = torch.rand(4, 3, 512, 512)
#Y = X

# Functional alternatives, kept for reference:
# ssim_val = ssim( X, Y, data_range=1.0, size_average=False) # return (N,)
# ms_ssim_val = ms_ssim( X, Y, data_range=1.0, size_average=False ) #(N,)
# # or set 'size_average=True' to get a scalar value as loss.
# ssim_loss = ssim( X, Y, data_range=1.0, size_average=True) # return a scalar
# ms_ssim_loss = ms_ssim( X, Y, data_range=1.0, size_average=True )

# Module form: each module keeps its 11-wide Gaussian window (sigma 1.5) and
# reuses it on every call; size_average=True reduces to one batch scalar.
ssim_module = SSIM(win_size=11, win_sigma=1.5, data_range=1.0, size_average=True, channel=3)
ms_ssim_module = MS_SSIM(win_size=11, win_sigma=1.5, data_range=1.0, size_average=True, channel=3)

# Turn similarities into losses: a perfect match (similarity 1) gives a loss of 0.
ssim_loss, ms_ssim_loss = 1 - ssim_module(X, Y), 1 - ms_ssim_module(X, Y)

# Draw a fresh pair of random batches (not used further within this snippet).
X = torch.rand(4, 3, 512, 512)
Y = torch.rand(4, 3, 512, 512)
Esempio n. 30
0
    def train(self):
        """Train the autoencoder and return the trained model.

        Each batch is scored with L1, L2 and SSIM criteria. The optimised loss
        is ``L2 + (1 - SSIM)``. Per-epoch averages of the three criteria (SSIM
        logged as ``1 - SSIM``) are collected in ``training_loss_log``; when
        ``store_logs`` is set and the run is not resumed, they are written via
        ``store_logs``. A checkpoint is saved after every epoch and flagged as
        best when the epoch's average loss improves.

        Returns:
            The trained model (also stored on ``self.model``).
        """
        dataset_class = self.datasets[self.args.dataset](
            root=self.args.root,
            add_labeled=self.args.add_labeled,
            advanced_transforms=False,
            merged=self.args.merged,
            remove_classes=self.args.remove_classes,
            oversampling=self.args.oversampling,
            unlabeled_subset_ratio=self.args.unlabeled_subset,
            seed=self.args.seed,
            start_labeled=self.args.start_labeled)

        base_dataset = dataset_class.get_base_dataset_autoencoder()

        train_loader = create_base_loader(base_dataset, self.kwargs,
                                          self.args.batch_size)

        training_loss_log = []

        # Key order defines the column order of the logged losses (l1, l2, ssim).
        criterions = {
            'l1': nn.L1Loss(),
            'l2': nn.MSELoss(),
            'ssim': SSIM(size_average=True,
                         data_range=1.0,
                         nonnegative_ssim=True),
        }

        model, optimizer, self.args = create_model_optimizer_autoencoder(
            self.args, dataset_class)

        best_loss = np.inf

        for epoch in range(self.args.start_epoch,
                           self.args.autoencoder_train_epochs):
            model.train()
            batch_time = AverageMeter()
            losses = AverageMeter()
            losses_sum = np.zeros(len(criterions))

            end = time.time()
            for i, (data_x, _) in enumerate(train_loader):
                data_x = data_x.cuda(non_blocking=True)

                output = model(data_x)

                # Evaluate each criterion once and reuse the tensors for both
                # logging and back-propagation (previously L2 and SSIM ran twice).
                batch_losses = {name: criterion(output, data_x)
                                for name, criterion in criterions.items()}

                # SSIM is a similarity, so log it as the loss 1 - SSIM.
                logged = {name: value.item()
                          for name, value in batch_losses.items()}
                logged['ssim'] = 1 - logged['ssim']
                losses_sum = losses_sum + np.array(list(logged.values()))

                loss = batch_losses['l2'] + (1 - batch_losses['ssim'])

                losses.update(loss.item(), data_x.size(0))

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                batch_time.update(time.time() - end)
                end = time.time()

                if i % self.args.print_freq == 0:
                    print('Epoch: [{0}][{1}/{2}]\t'
                          'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                          'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                              epoch,
                              i,
                              len(train_loader),
                              batch_time=batch_time,
                              loss=losses))

            # Guard against an empty loader (avoids dividing by zero).
            losses_avg = losses_sum / max(len(train_loader), 1)
            training_loss_log.append(losses_avg.tolist())

            is_best = best_loss > losses.avg
            best_loss = min(best_loss, losses.avg)

            save_checkpoint(
                self.args, {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_loss,
                }, is_best)

        if self.args.store_logs and not self.args.resume:
            store_logs(self.args,
                       pd.DataFrame(training_loss_log,
                                    columns=['l1', 'l2', 'ssim']),
                       log_type='ae_loss')

        self.model = model
        return model