예제 #1
0
def main():
    """Train a DenseFuse model on COCO using an MS-SSIM criterion."""
    torch.cuda.manual_seed_all(Args.seed)

    # Resize -> grayscale -> tensor; channel count matches the model input.
    preprocess = transforms.Compose([
        transforms.Resize(Args.resized),
        transforms.Grayscale(Args.num_channel),
        transforms.ToTensor()
    ])
    dataset = COCO(Args.train_path, transform=preprocess)
    loader = DataLoader(dataset,
                        batch_size=Args.batch_size,
                        shuffle=True,
                        num_workers=min(4, Args.batch_size),
                        pin_memory=True)
    loaders = {'train': loader}

    model = DenseFuse(num_channel=Args.num_channel)
    model.to(Args.device)

    optimizer = optim.Adam(model.parameters(), lr=Args.lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', patience=1, factor=Args.lr_decay_factor)

    # Single-channel MS-SSIM on [0, 1] tensors (ToTensor output range).
    criterions = {'ms_ssim': MS_SSIM(data_range=1.0, size_average=True, channel=1)}

    # Timestamped checkpoint directory name, e.g. 'Mon-Jan--1-12-00-00-2024'.
    stamp = time.ctime().replace(' ', '-').replace(':', '-')
    ckptPath = Args.ckptPath.joinpath(stamp)
    train(loaders,
          model,
          criterions,
          optimizer,
          Args.num_epochs,
          ckptPath,
          scheduler=scheduler)
예제 #2
0
def ssim_loss(X, Y):
    """Return the MS-SSIM loss, ``1 - MS_SSIM(X, Y)``, as a scalar tensor.

    X, Y: batches of non-negative single-channel images in [0, 255]
    (the module is built with ``channel=1`` — TODO confirm callers pass
    grayscale batches; the old comments mentioned RGB).
    """
    # NOTE: the original also built an SSIM module and computed an SSIM
    # loss whose result shadowed this function's name and was then
    # discarded; that dead computation has been removed.
    ms_ssim_module = MS_SSIM(data_range=255,
                             size_average=True,
                             channel=1,
                             nonnegative_ssim=False)
    return 1 - ms_ssim_module(X, Y)
예제 #3
0
    def __init__(self, name,
                 dip_n_iter=8000, net='skip',
                 lr=0.001, reg_std=1./100,
                 w_proj_loss=1.0, w_perceptual_loss=0.0,
                 w_ssim_loss=0.0, w_tv_loss=0.0, randomize_projs=None,
                 channels=(16, 32, 64, 128, 256)):
        """Configure a DIP-style reconstructor.

        name: experiment/reconstructor name, forwarded to the base class.
        dip_n_iter: number of optimization iterations.
        net: backbone architecture key.
        lr / reg_std: optimizer learning rate and input-noise std.
        w_*: weights of the individual loss terms.
        randomize_projs: optional projection-randomization setting.
        channels: per-level channel widths of the backbone.

        Raises ValueError for an unknown ``net``.
        """
        super(DgrReconstructor, self).__init__(name)
        self.n_iter = dip_n_iter
        # BUGFIX: was an assert (stripped under -O); validate explicitly.
        if net not in ('skip', 'skipV2', 'skipV3', 'unet', 'dncnn'):
            raise ValueError('unknown net: {}'.format(net))
        self.net = net
        # BUGFIX: the default used to be a shared mutable list; the default
        # is now a tuple, copied into a fresh list per instance.
        self.channels = list(channels)
        self.lr = lr
        self.reg_std = reg_std
        # loss weights
        self.w_proj_loss = w_proj_loss
        self.w_perceptual_loss = w_perceptual_loss
        self.w_tv_loss = w_tv_loss
        self.w_ssim_loss = w_ssim_loss
        self.randomize_projs = randomize_projs
        # loss functions
        self.mse = torch.nn.MSELoss().to(self.DEVICE)
        self.ssim = MS_SSIM(data_range=1.0, size_average=True, channel=self.IMAGE_DEPTH).to(self.DEVICE)
        self.perceptual = VGGPerceptualLoss(resize=True).to(self.DEVICE)

        # state populated later by the reconstruction pipeline
        self.gt = None
        self.noisy = None
        self.FOCUS = None
        self.log_dir = None
예제 #4
0
파일: utils.py 프로젝트: btolooshams/crsae
    def forward(self, input, target):
        """Weighted mix of MS-SSIM loss and L1 loss between input and target."""
        # self.a weights the MS-SSIM term; (1 - self.a) weights the L1 term.
        ms_ssim_fn = MS_SSIM(win_size=self.win_size,
                             data_range=self.data_range,
                             channel=self.channel)
        ms_ssim_term = 1 - ms_ssim_fn(input, target)
        l1_term = torch.nn.L1Loss()(input, target)
        return self.a * ms_ssim_term + (1 - self.a) * l1_term
예제 #5
0
def criterions(name):
    """Return the loss module registered under *name*.

    Known names: "mse", "l1", "lpips" (CUDA VGG-LPIPS), "ms_ssim".

    Raises ValueError for an unknown name (the original silently
    returned None, deferring the failure to the first loss call).
    """
    if name == "mse":
        return nn.MSELoss()
    if name == "l1":
        return nn.L1Loss()
    if name == "lpips":
        return lpips.LPIPS(net="vgg").cuda()
    if name == "ms_ssim":
        return MS_SSIM(data_range=1.0)
    raise ValueError("unknown criterion: {!r}".format(name))
예제 #6
0
 def __init__(self, eps=1e-6, lambda_=1):
     """Charbonnier + MS-SSIM combined loss.

     eps: Charbonnier smoothing constant.
     lambda_: weight of the MS-SSIM component.
     """
     super(CharbonnierLossPlusMSSSIM, self).__init__()
     self.eps = eps
     self.lambda_ = lambda_
     # Build the MS-SSIM module once so the Gaussian window is reused
     # across forward passes.
     self.ms_ssim_module = MS_SSIM(
         win_size=11, win_sigma=1.5, data_range=1.0,
         size_average=True, channel=3)
예제 #7
0
    def __init__(self, kernel_w=3, sigma=1.5, channels=1, weights=None):
        """MS-SSIM-based distance; *weights* controls the pyramid depth."""
        super().__init__()
        self.kernel_w = kernel_w
        self.sigma = sigma

        # The number of weights determines the depth of the MS-SSIM pyramid.
        # The standard 5 levels are too deep for MNIST-resolution images, so
        # default to a 4-level variant.
        self.weights = [0.0516, 0.32949, 0.34622, 0.27261] if weights is None else weights

        self.ssim_d = MS_SSIM(win_size=kernel_w,
                              win_sigma=sigma,
                              data_range=1.0,
                              channel=channels,
                              weights=self.weights,
                              size_average=False)

        self.config = {'Distance': 'SSIM', 'win_size': kernel_w, 'win_sigma': sigma}
예제 #8
0
    def set_criterion(self, level=None):
        """Build the loss list and coefficient tensor for the given level."""
        assert level in [0, 1, 2, 3, 4, 5], 'unknown level'

        # Every level starts from plain MSE with unit weight.
        losses = [nn.MSELoss()]
        weights = [1.0]

        if level < 4:
            # Finer levels add a VGG19 perceptual term with a small weight.
            vgg_loss = VGG19Loss(['relu5_4'])
            vgg_loss.to(self.device)
            losses.append(vgg_loss)
            weights.append(0.01)

        if level == 0:
            # MS-SSIM is a similarity, hence the negative coefficient:
            # maximizing similarity means subtracting it from the loss.
            losses.append(MS_SSIM(data_range=1.0, size_average=True, channel=3))
            weights.append(-0.1 if self.perceptual else -0.01)

        self.criterions = losses
        self.coefficients = torch.FloatTensor(weights).to(self.device)
예제 #9
0
    def __init__(self, args, ckp):
        """Assemble the composite loss described by ``args.loss``.

        ``args.loss`` has the form 'w1*TYPE1+w2*TYPE2+...'; each TYPE maps
        to a loss module and each w to its float weight. GAN losses register
        an extra 'DIS' bookkeeping entry, and a 'Total' entry is appended
        when more than one term exists.

        Raises ValueError for an unrecognized loss type.
        """
        super(Loss, self).__init__()
        print('Preparing loss function:')

        self.n_GPUs = args.n_GPUs
        self.loss = []
        self.loss_module = nn.ModuleList()
        for loss in args.loss.split('+'):
            weight, loss_type = loss.split('*')
            if loss_type == 'MSE':  # L2 loss
                loss_function = nn.MSELoss()
            elif loss_type == 'L1':
                loss_function = nn.L1Loss()
            elif loss_type.find('VGG') >= 0:
                module = import_module('loss.vgg')
                loss_function = getattr(module,
                                        'VGG')(loss_type[3:],
                                               rgb_range=args.rgb_range)
            elif loss_type.find('TextureL') >= 0:
                module = import_module('loss.vgg')
                loss_function = getattr(module,
                                        'VGG')(loss_type[3:],
                                               rgb_range=args.rgb_range,
                                               texture_loss=True)
            elif loss_type.find('GAN') >= 0:
                module = import_module('loss.adversarial')
                loss_function = getattr(module, 'Adversarial')(args, loss_type)
            elif loss_type.find('TVLoss') >= 0:
                module = import_module('loss.tvloss')
                loss_function = getattr(module, 'TVLoss')()

            elif loss_type.find('MS-SSIM') >= 0:
                # BUGFIX: this test must run BEFORE the plain 'SSIM' test.
                # 'MS-SSIM' contains 'SSIM', so with the old order this
                # branch was unreachable and an SSIM loss was silently
                # built instead of the requested MS-SSIM.
                from pytorch_msssim import MS_SSIM
                loss_function = MS_SSIM(win_sigma=1,
                                        data_range=args.rgb_range,
                                        size_average=True,
                                        channel=3)
            elif loss_type.find('SSIM') >= 0:
                from pytorch_msssim import SSIM
                loss_function = SSIM(win_size=7,
                                     win_sigma=1,
                                     data_range=args.rgb_range,
                                     size_average=True,
                                     channel=3)
            else:
                # BUGFIX: unknown types used to fall through and reuse (or
                # NameError on) a stale loss_function; fail loudly instead.
                raise ValueError('unknown loss type: {}'.format(loss_type))

            self.loss.append({
                'type': loss_type,
                'weight': float(weight),
                'function': loss_function
            })
            if loss_type.find('GAN') >= 0:
                # GAN terms also track the discriminator under a 'DIS' slot.
                self.loss.append({
                    'type': 'DIS',
                    'weight': 1,
                    'function': None
                })

        if len(self.loss) > 1:
            self.loss.append({'type': 'Total', 'weight': 0, 'function': None})

        for l in self.loss:
            if l['function'] is not None:
                print('{:.6f} * {}'.format(l['weight'], l['type']))
                self.loss_module.append(l['function'])

        # Per-epoch loss log, filled in by the training loop.
        self.log = torch.Tensor()

        device = torch.device('cpu' if args.cpu else 'cuda')
        self.loss_module.to(device)
        if args.precision == 'half': self.loss_module.half()
        if not args.cpu and args.n_GPUs > 1:
            self.loss_module = nn.DataParallel(self.loss_module,
                                               range(args.n_GPUs))

        if args.load != '': self.load(ckp.dir, cpu=args.cpu)
예제 #10
0
 def __init__(self, alpha=0.84):
     """Reconstruction loss mixing MS-SSIM (weight alpha) with L1 (1 - alpha)."""
     super(ReconstructionLoss, self).__init__()
     # 0.84 follows the common MS-SSIM+L1 mixing weight from the literature
     # used by this codebase — TODO confirm against the forward pass.
     self.alpha = alpha
     self.l1 = nn.L1Loss()
     self.ms_ssim = MS_SSIM(data_range=1, size_average=True)
예제 #11
0
                 0.02,
                 gpu_id=device)

# VGG for perceptual loss
# Only build and load VGG16 when the content-loss weight is active.
if opt.lamb_content > 0:
    vgg = Vgg16()
    init_vgg16(root_path)
    vgg.load_state_dict(torch.load(os.path.join(root_path, "vgg16.weight")))
    vgg.to(device)

# define loss
# NOTE(review): criterionL2 and criterionMSE are both nn.MSELoss — presumably
# two names for readability at different call sites; confirm before merging.
criterionL1 = nn.L1Loss().to(device)
criterionL2 = nn.MSELoss().to(device)
criterionMSE = nn.MSELoss().to(device)
# SSIM/MS-SSIM configured for 0-255 images; separate 1- and 3-channel
# MS-SSIM instances serve grayscale and RGB outputs respectively.
criterionSSIM = SSIM(data_range=255, size_average=True, channel=3)
criterionMSSSIM1 = MS_SSIM(data_range=255, size_average=True, channel=1)
criterionMSSSIM3 = MS_SSIM(data_range=255, size_average=True, channel=3)

# setup optimizer
# Both networks share the same learning rate and Adam betas.
optimizer_i = optim.Adam(net_i.parameters(),
                         lr=opt.lr,
                         betas=(opt.beta1, 0.999))
optimizer_r = optim.Adam(net_r.parameters(),
                         lr=opt.lr,
                         betas=(opt.beta1, 0.999))
net_i_scheduler = get_scheduler(optimizer_i, opt)
net_r_scheduler = get_scheduler(optimizer_r, opt)

# Running per-epoch loss histories for the two networks.
loss_i_list = []
loss_r_list = []
for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
예제 #12
0
                                   channel).to(img1.device).type(img1.dtype)
            self.window = window
            self.channel = channel

        return ssim(img1,
                    img2,
                    window=window,
                    window_size=self.window_size,
                    size_average=self.size_average)


class MSSSIM(torch.nn.Module):
    """Multi-scale SSIM wrapped as an ``nn.Module`` around ``msssim()``."""

    def __init__(self, window_size=11, size_average=True, channel=3):
        super(MSSSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        # NOTE(review): channel is stored but forward() never passes it on
        # to msssim — confirm whether msssim infers it from the input.
        self.channel = channel

    def forward(self, img1, img2):
        # TODO: store window between calls if possible
        return msssim(img1,
                      img2,
                      window_size=self.window_size,
                      size_average=self.size_average)


# Pretrained VGG19 feature extractor, frozen in eval mode, for perceptual terms.
vgg_model = VGG19().cuda()
vgg_model = vgg_model.eval()
GAN_loss_calculator = GANLoss()
# MS-SSIM over RGB images normalized to [0, 1].
mssim_calculator = MS_SSIM(data_range=1.0, size_average=True, channel=3)
예제 #13
0
def tensor_ssim_module():
    """Build SSIM and MS-SSIM modules sharing the same kernel settings.

    Returns the pair ``(ssim_module, ms_ssim_module)``. The original built
    both modules and discarded them (returning None), making the function a
    no-op; returning them is backward-compatible since old callers ignored
    the return value.
    """
    # reuse the gaussian kernel with SSIM & MS_SSIM.
    ssim_module = SSIM(data_range=255, size_average=True, channel=3)
    ms_ssim_module = MS_SSIM(data_range=255, size_average=True, channel=3)
    return ssim_module, ms_ssim_module
예제 #14
0
def do_learn(opt, run_dir="./runs"):
    """Run the full GAN + auto-encoder training loop described by *opt*.

    Trains an Encoder (reconstruction loss), a Discriminator (real/fake
    loss) and a Generator (adversarial loss) over opt.n_epochs epochs,
    optionally logging to TensorBoard. Saves samples at the end.

    Bug fixes relative to the original: a string was raised for an unknown
    optimizer (TypeError at runtime), the non-slerp branch of gen_z read
    ``z`` before assigning it, the 'hinge' fake-loss branch assigned
    ``real_loss`` instead of ``fake_loss``, and an unknown opt.GAN_loss
    only printed before crashing with NameError.
    """
    print('Starting ', opt.run_path)
    path_data = os.path.join(run_dir, opt.run_path)
    # ----------
    #  Tensorboard
    # ----------
    if do_tensorboard:
        # stats are stored in "runs", within subfolder opt.run_path.
        writer = SummaryWriter(log_dir=path_data)

    # Create a time tag
    import datetime
    try:
        tag = datetime.datetime.now().isoformat(sep='_', timespec='seconds')
    except TypeError:
        # Python 3.5 and below
        # 'timespec' is an invalid keyword argument for this function
        tag = datetime.datetime.now().replace(microsecond=0).isoformat(sep='_')
    tag = tag.replace(':', '-')

    # Configure data loader
    dataloader = load_data(opt.datapath,
                           opt.img_size,
                           opt.batch_size,
                           rand_hflip=opt.rand_hflip,
                           rand_affine=opt.rand_affine)

    if opt.do_SSIM:
        # NEW: we use https://github.com/VainF/pytorch-msssim instead of https://github.com/SpikeAI/pytorch-msssim
        from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM
        E_loss = MS_SSIM(win_size=opt.window_size,
                         data_range=1,
                         size_average=True,
                         channel=3)

    else:
        E_loss = torch.nn.MSELoss(reduction='sum')

    sigmoid = torch.nn.Sigmoid()

    # Initialize generator and discriminator
    generator = Generator(opt)
    discriminator = Discriminator(opt)
    encoder = Encoder(opt)

    if opt.verbose:
        print_network(generator)
        print_network(discriminator)
        print_network(encoder)

    # Anti-diagonal mask used to zero self-correlations in the lambdaG term.
    eye = 1 - torch.eye(opt.batch_size)
    use_cuda = True if torch.cuda.is_available() else False
    if use_cuda:
        print("Running on GPU : ", torch.cuda.get_device_name())
        generator.cuda()
        discriminator.cuda()
        encoder.cuda()
        E_loss.cuda()
        eye = eye.cuda()

        Tensor = torch.cuda.FloatTensor
    else:
        print("Running on CPU ")
        Tensor = torch.FloatTensor

    # Initialize weights
    if opt.init_weight:
        generator.apply(weights_init_normal)
        discriminator.apply(weights_init_normal)
        encoder.apply(weights_init_normal)

    # Optimizers
    if opt.optimizer == 'rmsprop':
        # https://pytorch.org/docs/stable/optim.html#torch.optim.RMSprop
        opts = dict(momentum=1 - opt.beta1, alpha=opt.beta2)
        optimizer = torch.optim.RMSprop
    elif opt.optimizer == 'adam':
        # https://pytorch.org/docs/stable/optim.html#torch.optim.Adam
        opts = dict(betas=(opt.beta1, opt.beta2))
        optimizer = torch.optim.Adam
    elif opt.optimizer == 'sgd':
        opts = dict(momentum=1 - opt.beta1,
                    nesterov=True,
                    weight_decay=1 - opt.beta2)
        optimizer = torch.optim.SGD
    else:
        # BUGFIX: the original raised a bare string (a TypeError at runtime).
        raise ValueError('wrong optimizer')

    optimizer_G = optimizer(generator.parameters(), lr=opt.lrG, **opts)
    optimizer_D = optimizer(discriminator.parameters(), lr=opt.lrD, **opts)
    if opt.do_joint:
        import itertools
        optimizer_E = optimizer(itertools.chain(encoder.parameters(),
                                                generator.parameters()),
                                lr=opt.lrE,
                                **opts)
    else:
        optimizer_E = optimizer(encoder.parameters(), lr=opt.lrE, **opts)

    # TODO parameterize scheduler !

    # ----------
    #  Training
    # ----------

    nb_batch = len(dataloader)

    stat_record = init_hist(opt.n_epochs, nb_batch)

    # https://github.com/soumith/dcgan.torch/issues/14  dribnet commented on 21 Mar 2016
    # https://arxiv.org/abs/1609.04468
    def slerp(val, low, high):
        """Spherical interpolation between batches of latent vectors."""
        corr = np.diag(
            (low / np.linalg.norm(low)) @ (high / np.linalg.norm(high)).T)
        omega = np.arccos(np.clip(corr, -1, 1))[:, None]
        so = np.sin(omega)
        out = np.sin(
            (1.0 - val) * omega) / so * low + np.sin(val * omega) / so * high
        # L'Hopital's rule/LERP
        out[so[:, 0] == 0, :] = (
            1.0 - val) * low[so[:, 0] == 0, :] + val * high[so[:, 0] == 0, :]
        return out

    def norm2(z):
        """
        L2-norm of a tensor.

        outputs a scalar
        """
        return (z**2).sum().sqrt()

    def gen_z(imgs=None, rho=.25, do_slerp=opt.do_slerp):
        """
        Generate noise in the feature space.

        outputs a vector
        """
        if imgs is not None:
            z_imgs = encoder(imgs).cpu().numpy()
            if do_slerp:
                z_shuffle = z_imgs.copy()
                z_shuffle = z_shuffle[torch.randperm(opt.batch_size), :]
                z = slerp(rho, z_imgs, z_shuffle)
            else:
                # BUGFIX: the original read z before ever assigning it
                # (UnboundLocalError); draw the random latent first. np.sqrt
                # is used here because norm2() targets torch tensors.
                z = np.random.normal(0, 1, (opt.batch_size, opt.latent_dim))
                z /= np.sqrt((z**2).sum())
                z_imgs /= np.sqrt((z_imgs**2).sum())
                z = (1 - rho) * z_imgs + rho * z
                z /= np.sqrt((z**2).sum())
        else:
            z = np.random.normal(0, 1, (opt.batch_size, opt.latent_dim))
        # convert to tensor
        return Variable(Tensor(z), requires_grad=False)

    def gen_noise(imgs):
        """
        Generate noise in the image space

        outputs an image
        """
        v_noise = np.random.normal(0, 1, imgs.shape)  # one random image
        # one contrast value per image
        v_noise *= np.abs(
            np.random.normal(0, 1, (imgs.shape[0], opt.channels, 1, 1)))
        # convert to tensor
        v_noise = Variable(Tensor(v_noise), requires_grad=False)
        return v_noise

    # Fixed z vector used to draw comparable samples across epochs.
    fixed_noise = gen_z()
    real_imgs_samples = None

    t_total = time.time()
    for i_epoch, epoch in enumerate(range(1, opt.n_epochs + 1)):
        t_epoch = time.time()
        for iteration, (imgs, _) in enumerate(dataloader):
            t_batch = time.time()

            # ---------------------
            #  Train Encoder
            # ---------------------
            for p in generator.parameters():
                p.requires_grad = opt.do_joint
            for p in encoder.parameters():
                p.requires_grad = True
            # the following is not necessary as we do not use D here and only optimize ||G(E(x)) - x ||^2
            for p in discriminator.parameters():
                p.requires_grad = False  # to avoid learning D when learning E

            real_imgs = Variable(imgs.type(Tensor), requires_grad=False)

            # init samples used to visualize performance of the AE
            if real_imgs_samples is None:
                real_imgs_samples = real_imgs[:opt.N_samples]

            # add noise here to real_imgs
            real_imgs_ = real_imgs * 1.
            if opt.E_noise > 0:
                real_imgs_ += opt.E_noise * gen_noise(real_imgs)

            z_imgs = encoder(real_imgs_)
            decoded_imgs = generator(z_imgs)

            # Loss measures Encoder's ability to generate vectors suitable with the generator
            e_loss = 1. - E_loss(real_imgs, decoded_imgs)

            if opt.lambdaE > 0:
                # We wish to make sure the intermediate vector z_imgs get closer to a iid normal (centered gausian of variance 1)
                e_loss += opt.lambdaE * (torch.sum(z_imgs) / opt.batch_size /
                                         opt.latent_dim).pow(2)
                e_loss += opt.lambdaE * (torch.sum(z_imgs.pow(2)) /
                                         opt.batch_size / opt.latent_dim -
                                         1).pow(2).pow(.5)

            # Backward
            optimizer_E.zero_grad()
            e_loss.backward()
            optimizer_E.step()

            # One-sided label smoothing: targets drawn near 1 (real) and 0 (fake).
            valid_smooth = np.random.uniform(opt.valid_smooth,
                                             1.0 - (1 - opt.valid_smooth) / 2,
                                             (opt.batch_size, 1))
            valid_smooth = Variable(Tensor(valid_smooth), requires_grad=False)
            fake_smooth = np.random.uniform((1 - opt.valid_smooth) / 2,
                                            1 - opt.valid_smooth,
                                            (opt.batch_size, 1))
            fake_smooth = Variable(Tensor(fake_smooth), requires_grad=False)

            # ---------------------
            #  Train Discriminator
            # ---------------------
            # Discriminator Requires grad, Encoder + Generator requires_grad = False
            for p in discriminator.parameters():
                p.requires_grad = True
            for p in generator.parameters():
                p.requires_grad = False  # to avoid computation
            for p in encoder.parameters():
                p.requires_grad = False  # to avoid computation

            # Configure input
            real_imgs = Variable(imgs.type(Tensor), requires_grad=False)
            real_imgs_ = real_imgs * 1.
            if opt.D_noise > 0:
                real_imgs_ += opt.D_noise * gen_noise(real_imgs)
            if opt.do_insight:
                # the discriminator can not access the images directly but only
                # what is visible through the auto-encoder
                real_imgs_ = generator(encoder(real_imgs_))

            # Discriminator decision (in logit units)
            # TODO : group images by sub-batches and train to discriminate from all together
            # should allow to avoid mode collapse
            logit_d_x = discriminator(real_imgs_)

            if opt.GAN_loss == 'wasserstein':
                # weight clipping
                for p in discriminator.parameters():
                    p.data.clamp_(-0.01, 0.01)

            # Measure discriminator's ability to classify real from generated samples
            if opt.GAN_loss == 'ian':
                # eq. 14 in https://arxiv.org/pdf/1701.00160.pdf
                real_loss = -torch.sum(1 / (1. - 1 / sigmoid(logit_d_x)))
            elif opt.GAN_loss == 'hinge':
                # TODO check if we use p or log p
                real_loss = nn.ReLU()(valid_smooth - sigmoid(logit_d_x)).mean()
            elif opt.GAN_loss == 'wasserstein':
                real_loss = torch.mean(
                    torch.abs(valid_smooth - sigmoid(logit_d_x)))
            elif opt.GAN_loss == 'alternative':
                # https://www.inference.vc/an-alternative-update-rule-for-generative-adversarial-networks/
                real_loss = -torch.sum(torch.log(sigmoid(logit_d_x)))
            elif opt.GAN_loss == 'alternativ2':
                # https://www.inference.vc/an-alternative-update-rule-for-generative-adversarial-networks/
                real_loss = -torch.sum(
                    torch.log(sigmoid(logit_d_x) / (1. - sigmoid(logit_d_x))))
            elif opt.GAN_loss == 'alternativ3':
                # to maximize D(x), we minimize  - sum(logit_d_x)
                real_loss = -torch.sum(logit_d_x)
            elif opt.GAN_loss == 'original':
                real_loss = F.binary_cross_entropy(sigmoid(logit_d_x),
                                                   valid_smooth)
            else:
                # BUGFIX: a print here left real_loss unbound (NameError later).
                raise ValueError('GAN_loss not defined: {}'.format(opt.GAN_loss))

            # Generate a batch of fake images and learn the discriminator to treat them as such
            z = gen_z(imgs=real_imgs_)
            gen_imgs = generator(z)
            if opt.D_noise > 0: gen_imgs += opt.D_noise * gen_noise(real_imgs)

            # Discriminator decision for fake data
            logit_d_fake = discriminator(gen_imgs.detach())
            # Measure discriminator's ability to classify real from generated samples
            if opt.GAN_loss == 'wasserstein':
                fake_loss = torch.mean(sigmoid(logit_d_fake))
            elif opt.GAN_loss == 'hinge':
                # TODO check if we use p or log p
                # BUGFIX: this branch assigned real_loss (overwriting the real
                # term and leaving fake_loss unbound); it defines fake_loss.
                fake_loss = nn.ReLU()(1.0 + sigmoid(logit_d_fake)).mean()
            elif opt.GAN_loss == 'alternative':
                # https://www.inference.vc/an-alternative-update-rule-for-generative-adversarial-networks/
                fake_loss = -torch.sum(torch.log(1 - sigmoid(logit_d_fake)))
            elif opt.GAN_loss == 'alternativ2':
                # https://www.inference.vc/an-alternative-update-rule-for-generative-adversarial-networks/
                fake_loss = torch.sum(
                    torch.log(
                        sigmoid(logit_d_fake) / (1. - sigmoid(logit_d_fake))))
            elif opt.GAN_loss == 'alternativ3':
                # to minimize D(G(z)), we minimize sum(logit_d_fake)
                fake_loss = torch.sum(logit_d_fake)
            elif opt.GAN_loss in ['original', 'ian']:
                fake_loss = F.binary_cross_entropy(sigmoid(logit_d_fake),
                                                   fake_smooth)
            else:
                # BUGFIX: a print here left fake_loss unbound (NameError later).
                raise ValueError('GAN_loss not defined: {}'.format(opt.GAN_loss))

            # Backward
            optimizer_D.zero_grad()
            real_loss.backward()
            fake_loss.backward()
            # apply the gradients
            optimizer_D.step()

            # -----------------
            #  Train Generator
            # -----------------
            for p in generator.parameters():
                p.requires_grad = True
            for p in discriminator.parameters():
                p.requires_grad = False  # to avoid computation
            for p in encoder.parameters():
                p.requires_grad = False  # to avoid computation

            # Generate a batch of fake images
            z = gen_z(imgs=real_imgs_)
            gen_imgs = generator(z)
            if opt.G_noise > 0: gen_imgs += opt.G_noise * gen_noise(real_imgs)

            # New discriminator decision (since we just updated D)
            logit_d_g_z = discriminator(gen_imgs)

            # Loss functions
            # Loss measures generator's ability to fool the discriminator
            if opt.GAN_loss == 'ian':
                # eq. 14 in https://arxiv.org/pdf/1701.00160.pdf
                # https://en.wikipedia.org/wiki/Logit
                g_loss = -torch.sum(
                    sigmoid(logit_d_g_z) / (1 - sigmoid(logit_d_g_z)))
            elif opt.GAN_loss == 'wasserstein' or opt.GAN_loss == 'hinge':
                g_loss = torch.mean(
                    torch.abs(valid_smooth - sigmoid(logit_d_g_z)))
            elif opt.GAN_loss == 'alternative':
                # https://www.inference.vc/an-alternative-update-rule-for-generative-adversarial-networks/
                g_loss = -torch.sum(torch.log(sigmoid(logit_d_g_z)))
            elif opt.GAN_loss == 'alternativ2':
                # https://www.inference.vc/an-alternative-update-rule-for-generative-adversarial-networks/
                g_loss = -torch.sum(
                    torch.log(
                        sigmoid(logit_d_g_z) / (1. - sigmoid(logit_d_g_z))))
            elif opt.GAN_loss == 'alternativ3':
                # to maximize D(G(z)), we minimize - sum(logit_d_g_z)
                g_loss = -torch.sum(logit_d_g_z)
            elif opt.GAN_loss == 'original':
                # https://pytorch.org/docs/stable/nn.html?highlight=bcewithlogitsloss#torch.nn.BCEWithLogitsLoss
                g_loss = F.binary_cross_entropy(sigmoid(logit_d_g_z),
                                                valid_smooth)
            else:
                # BUGFIX: a print here left g_loss unbound (NameError later).
                raise ValueError('GAN_loss not defined: {}'.format(opt.GAN_loss))

            # penalize low variability in a batch, that is, mode collapse
            # TODO maximize sum of the distances to the nearest neighbors
            if opt.lambdaG > 0:
                e_g_z = encoder(gen_imgs)  # get normal vectors
                Xcorr = torch.tensordot(e_g_z, torch.transpose(e_g_z, 0, 1),
                                        1) / opt.latent_dim
                Xcorr *= eye  # set the diagonal elements to zero
                g_loss += opt.lambdaG * torch.sum(Xcorr.pow(2)).pow(.5)

            # Backward
            optimizer_G.zero_grad()
            g_loss.backward()
            # apply the gradients
            optimizer_G.step()

            # -----------------
            #  Recording stats
            # -----------------
            d_loss = real_loss + fake_loss

            # Convert logits to probabilities for reporting.
            d_fake = sigmoid(logit_d_fake)
            d_x = sigmoid(logit_d_x)
            d_g_z = sigmoid(logit_d_g_z)
            print(
                "%s [Epoch %d/%d] [Batch %d/%d] [E loss: %f] [D loss: %f] [G loss: %f] [D(x) %f] [D(G(z)) %f] [D(G(z')) %f] [Time: %fs]"
                % (opt.run_path, epoch, opt.n_epochs, iteration + 1,
                   len(dataloader), e_loss.item(), d_loss.item(),
                   g_loss.item(), torch.mean(d_x), torch.mean(d_fake),
                   torch.mean(d_g_z), time.time() - t_batch))
            # Save Losses and scores for Tensorboard
            save_hist_batch(stat_record, iteration, i_epoch, g_loss, d_loss,
                            e_loss, d_x, d_g_z)

        if do_tensorboard:
            # Tensorboard save
            writer.add_scalar('loss/E', e_loss.item(), global_step=epoch)
            try:
                writer.add_histogram('coeffs/E_x', z_imgs, global_step=epoch)
            except:
                pass
            writer.add_scalar('loss/G', g_loss.item(), global_step=epoch)
            writer.add_scalar('score/D_g_z',
                              np.mean(stat_record["d_g_z_mean"]),
                              global_step=epoch)
            writer.add_scalar('loss/D', d_loss.item(), global_step=epoch)

            writer.add_scalar('score/D_x',
                              np.mean(stat_record["d_x_mean"]),
                              global_step=epoch)

            # Save samples
            if epoch % opt.sample_interval == 0:
                """
                Use generator model and noise vector to generate images.
                Save them to tensorboard
                """
                generator.eval()
                gen_imgs = generator(fixed_noise)
                from torchvision.utils import make_grid
                grid = make_grid(gen_imgs,
                                 normalize=True,
                                 nrow=16,
                                 range=(0, 1))
                writer.add_image('Generated images', grid, epoch)
                generator.train()
                """
                Use auto-encoder model and original images to generate images.
                Save them to tensorboard

                """
                generator.eval()
                encoder.eval()
                enc_imgs = encoder(real_imgs_samples)
                dec_imgs = generator(enc_imgs)
                grid_dec = make_grid(dec_imgs,
                                     normalize=True,
                                     nrow=16,
                                     range=(0, 1))
                writer.add_image('Auto-encoded', grid_dec, epoch)
                generator.train()
                encoder.train()

        print("[Epoch Time: ", time.time() - t_epoch, "s]")

    sampling(fixed_noise, generator, path_data, epoch, tag, nrow=16)

    # for scheduler in schedulers: scheduler.step()
    t_final = time.gmtime(time.time() - t_total)
    print("[Total Time: ",
          t_final.tm_mday - 1,
          "j:",
          time.strftime("%Hh:%Mm:%Ss", t_final),
          "]",
          sep='')

    if do_tensorboard:
        writer.close()
예제 #15
0
    def __init__(self, **kwargs):
        """Build the reconstruction engine: network, loss, transforms, data handling.

        All hyperparameters arrive via ``**kwargs`` and are frozen into
        ``self.hparams`` by ``save_hyperparameters()``; everything below is
        driven purely by those hparams (modelID, lossID, ds_mode, taskID, ...).
        """
        super().__init__()
        self.save_hyperparameters()

        # Device is resolved here only to place the (perceptual / SSIM) loss
        # modules; the network itself is moved by the training framework.
        device = torch.device(
            "cuda:0" if torch.cuda.is_available() and self.hparams.cuda else "cpu")

        # ------------------------------------------------------------------
        # Network selection, keyed on hparams.modelID:
        #   0: image-space ReconResNet, 2: DualSpaceResNet,
        #   3-6: Primal-Dual variants (see per-branch comments below).
        # ------------------------------------------------------------------
        if self.hparams.modelID == 0:
            self.net = ResNet(in_channels=self.hparams.in_channels, out_channels=self.hparams.out_channels, res_blocks=self.hparams.model_res_blocks,
                              starting_nfeatures=self.hparams.model_starting_nfeatures, updown_blocks=self.hparams.model_updown_blocks,
                              is_relu_leaky=self.hparams.model_relu_leaky, do_batchnorm=self.hparams.model_do_batchnorm,
                              res_drop_prob=self.hparams.model_drop_prob, is_replicatepad=self.hparams.model_is_replicatepad, out_act=self.hparams.model_out_act, forwardV=self.hparams.model_forwardV,
                              upinterp_algo=self.hparams.model_upinterp_algo, post_interp_convtrans=self.hparams.model_post_interp_convtrans, is3D=self.hparams.is3D)  # TODO think of 2D
            # self.net = nn.Conv3d(in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=1)
        elif self.hparams.modelID == 2:
            # Same backbone arguments as modelID 0, plus dual-space (k-space)
            # connection mode and inner k-space normalisation.
            self.net = DualSpaceResNet(in_channels=self.hparams.in_channels, out_channels=self.hparams.out_channels, res_blocks=self.hparams.model_res_blocks,
                                        starting_nfeatures=self.hparams.model_starting_nfeatures, updown_blocks=self.hparams.model_updown_blocks,
                                        is_relu_leaky=self.hparams.model_relu_leaky, do_batchnorm=self.hparams.model_do_batchnorm,
                                        res_drop_prob=self.hparams.model_drop_prob, is_replicatepad=self.hparams.model_is_replicatepad, out_act=self.hparams.model_out_act, forwardV=self.hparams.model_forwardV,
                                        upinterp_algo=self.hparams.model_upinterp_algo, post_interp_convtrans=self.hparams.model_post_interp_convtrans, is3D=self.hparams.is3D,
                                        connect_mode=self.hparams.model_dspace_connect_mode, inner_norm_ksp=self.hparams.model_inner_norm_ksp)
        elif self.hparams.modelID == 3: #Primal-Dual Network, complex Primal
            self.net = PrimalDualNetwork(n_primary=5, n_dual=5, n_iterations=10,
                            use_original_block = True,
                            use_original_init = True,
                            use_complex_primal = True,
                            g_normtype = "magmax",
                            transform = "Fourier",
                            return_abs = True)
        elif self.hparams.modelID == 4: #Primal-Dual Network, absolute Primal
            self.net = PrimalDualNetwork(n_primary=5, n_dual=5, n_iterations=10,
                            use_original_block = True,
                            use_original_init = True,
                            use_complex_primal = False,
                            g_normtype = "magmax",
                            transform = "Fourier")
        elif self.hparams.modelID == 5: #Primal-Dual UNet Network, absolute Primal
            self.net = PrimalDualNetwork(n_primary=4, n_dual=5, n_iterations=2,
                            use_original_block = False,
                            use_original_init = False,
                            use_complex_primal = False,
                            g_normtype = "magmax",
                            transform = "Fourier")
        elif self.hparams.modelID == 6: #Primal-Dual Network v2 (no residual), complex Primal
            self.net = PrimalDualNetworkNoResidue(n_primary=5, n_dual=5, n_iterations=10,
                            use_original_block = True,
                            use_original_init = True,
                            use_complex_primal = True,
                            residuals=False,
                            g_normtype = "magmax",
                            transform = "Fourier",
                            return_abs = True)

        else:
            # TODO: other models
            sys.exit("Only ReconResNet and DualSpaceResNet have been implemented so far in ReconEngine")

        # Optional warm start: load pre-trained weights into the freshly built net.
        if bool(self.hparams.preweights_path):
            print("Pre-weights found, loding...")
            chk = torch.load(self.hparams.preweights_path, map_location='cpu')
            self.net.load_state_dict(chk['state_dict'])

        # ------------------------------------------------------------------
        # Loss selection, keyed on hparams.lossID:
        #   0: perceptual, 1: L1, 2: MS-SSIM, 3: SSIM.
        # ------------------------------------------------------------------
        if self.hparams.lossID == 0:
            # The bundled perceptual network is single-channel only.
            if self.hparams.in_channels != 1 or self.hparams.out_channels != 1:
                sys.exit(
                    "Perceptual Loss used here only works for 1 channel input and output")
            self.loss = PerceptualLoss(device=device, loss_model="unet3Dds", resize=None,
                                       loss_type=self.hparams.ploss_type, n_level=self.hparams.ploss_level)  # TODO thinkof 2D
        elif self.hparams.lossID == 1:
            self.loss = nn.L1Loss(reduction='mean')
        elif self.hparams.lossID == 2:
            self.loss = MS_SSIM(channel=self.hparams.out_channels, data_range=1, spatial_dims=3 if self.hparams.is3D else 2, nonnegative_ssim=False).to(device)
        elif self.hparams.lossID == 3:
            self.loss = SSIM(channel=self.hparams.out_channels, data_range=1, spatial_dims=3 if self.hparams.is3D else 2, nonnegative_ssim=False).to(device)
        else:
            sys.exit("Invalid Loss ID")

        self.dataspace = DataSpaceHandler(**self.hparams)

        # Transform/augmentation namespaces: TorchIO (ds_mode 0) or pure PyTorch (ds_mode 1).
        # NOTE(review): any other ds_mode leaves `trans`/`augs` undefined and
        # raises NameError a few lines below — confirm valid values are only 0/1.
        if self.hparams.ds_mode == 0:
            trans = tioTransforms
            augs = tioAugmentations
        elif self.hparams.ds_mode == 1:
            trans = pytTransforms
            augs = pytAugmentations

        # TODO parameterised everything
        # Three transform stages: init (always), aug (probabilistic), task-specific.
        self.init_transforms = []
        self.aug_transforms = []
        self.transforms = []
        if self.hparams.ds_mode == 0 and self.hparams.cannonicalResample:  # Only applicable for TorchIO
            self.init_transforms += [tio.ToCanonical(), tio.Resample('gt')]
        if self.hparams.ds_mode == 0 and self.hparams.forceNormAffine:  # Only applicable for TorchIO
            self.init_transforms += [trans.ForceAffine()]
        if self.hparams.croppad and self.hparams.ds_mode == 1:
            self.init_transforms += [
                trans.CropOrPad(size=self.hparams.input_shape)]
        self.init_transforms += [trans.IntensityNorm(type=self.hparams.norm_type, return_meta=self.hparams.motion_return_meta)]
        # dataspace_transforms = self.dataspace.getTransforms() #TODO: dataspace transforms are not in use
        # self.init_transforms += dataspace_transforms
        if bool(self.hparams.random_crop) and self.hparams.ds_mode == 1:
            self.aug_transforms += [augs.RandomCrop(
                size=self.hparams.random_crop, p=self.hparams.p_random_crop)]
        if self.hparams.p_contrast_augment > 0:
            self.aug_transforms += [augs.getContrastAugs(
                p=self.hparams.p_contrast_augment)]
        # if the task if MoCo and pre-corrupted vols are not supplied
        if self.hparams.taskID == 1 and not bool(self.hparams.train_path_inp):
            if self.hparams.motion_mode == 0 and self.hparams.ds_mode == 0:
                # Collect every 'motionmg_*' hparam, stripped of its prefix,
                # as keyword args for the TorchIO motion+ghosting transform.
                motion_params = {k.split('motionmg_')[
                    1]: v for k, v in self.hparams.items() if k.startswith('motionmg')}
                self.transforms += [tioMotion.RandomMotionGhostingFast(
                    **motion_params), trans.IntensityNorm()]
            elif self.hparams.motion_mode == 1 and self.hparams.ds_mode == 1 and not self.hparams.is3D:
                self.transforms += [pytMotion.Motion2Dv0(
                    sigma_range=self.hparams.motion_sigma_range, n_threads=self.hparams.motion_n_threads, p=self.hparams.motion_p, return_meta=self.hparams.motion_return_meta)]
            elif self.hparams.motion_mode == 2 and self.hparams.ds_mode == 1 and not self.hparams.is3D:
                self.transforms += [pytMotion.Motion2Dv1(sigma_range=self.hparams.motion_sigma_range, n_threads=self.hparams.motion_n_threads,
                                                         restore_original=self.hparams.motion_restore_original, p=self.hparams.motion_p, return_meta=self.hparams.motion_return_meta)]
            else:
                sys.exit(
                    "Error: invalid motion_mode, ds_mode, is3D combo. Please double check!")

        # Optional static metadata (.mat) used for radial data consistency.
        self.static_metamat = sio.loadmat(self.hparams.static_metamat_file) if bool(
            self.hparams.static_metamat_file) else None
        if self.hparams.taskID == 0 and self.hparams.use_datacon:
            self.datacon = DataConsistency(
                isRadial=self.hparams.is_radial, metadict=self.static_metamat)
        else:
            self.datacon = None

        # Example input (for graph logging / shape tracing): drop the last
        # spatial dim of input_shape when running in 2D.
        input_shape = self.hparams.input_shape if self.hparams.is3D else self.hparams.input_shape[
            :-1]
        self.example_input_array = torch.empty(
            self.hparams.batch_size, self.hparams.in_channels, *input_shape).float()
        self.saver = ResSaver(
            self.hparams.res_path, save_inp=self.hparams.save_inp, do_norm=self.hparams.do_savenorm)
예제 #16
0
def loss_new_msssim(x, y):
    """Return the MS-SSIM loss (1 - MS-SSIM) between two 2-channel batches.

    The metric is configured with data_range=10; inputs are expected to be
    (N, 2, H, W) tensors on the same device.
    """
    metric = MS_SSIM(data_range=10, channel=2)
    return 1 - metric(x, y)
예제 #17
0
import torch

# Two independent random batches standing in for RGB images: (N, C, H, W).
X = torch.rand(4, 3, 512, 512)
Y = torch.rand(4, 3, 512, 512)

# Instantiate the metric modules once so their Gaussian windows are built a
# single time and reused across calls (cheaper than the functional forms).
_metric_kwargs = dict(win_size=11,
                      win_sigma=1.5,
                      data_range=1.0,
                      size_average=True,
                      channel=3)
ssim_module = SSIM(**_metric_kwargs)
ms_ssim_module = MS_SSIM(**_metric_kwargs)

# Similarity is 1 for identical images, so (1 - similarity) is a minimizable loss.
ssim_loss = 1 - ssim_module(X, Y)
ms_ssim_loss = 1 - ms_ssim_module(X, Y)

# Fresh random batches for any follow-up experiment.
X = torch.rand(4, 3, 512, 512)
Y = torch.rand(4, 3, 512, 512)
예제 #18
0
def MyDNN(opt):
    # ----------------------------------------
    #       Network training parameters
    # ----------------------------------------

    # cudnn benchmark
    cudnn.benchmark = opt.cudnn_benchmark

    # configurations
    save_folder = os.path.join(opt.save_path, opt.task)
    sample_folder = os.path.join(opt.sample_path, opt.task)
    if not os.path.exists(save_folder):
        os.makedirs(save_folder)
    if not os.path.exists(sample_folder):
        os.makedirs(sample_folder)

    # Loss functions
    criterion_L1 = torch.nn.L1Loss().cuda()
    criterion_L2 = torch.nn.MSELoss().cuda()
    mse_loss = nn.MSELoss().cuda()
    ms_ssim_module = MS_SSIM(data_range=2,
                             size_average=True,
                             channel=3,
                             nonnegative_ssim=True)
    # Pretrained VGG
    # vgg = MINCFeatureExtractor(opt).cuda()
    # Initialize Generator
    generator = utils.create_MyDNN(opt)
    use_checkpoint = False
    if use_checkpoint:
        checkpoint_path = './MyDNN1_denoise_epoch175_bs1'
        # Load a pre-trained network
        pretrained_net = torch.load(checkpoint_path + '.pth')
        load_dict(generator, pretrained_net)
        print('Generator is loaded!')
    # To device
    if opt.multi_gpu:
        generator = nn.DataParallel(generator)
        generator = generator.cuda()
    else:
        generator = generator.cuda()

    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(),
                                   lr=opt.lr_g,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)

    # Learning rate decrease
    def adjust_learning_rate(opt, epoch, iteration, optimizer):
        #Set the learning rate to the initial LR decayed by "lr_decrease_factor" every "lr_decrease_epoch" epochs
        if opt.lr_decrease_mode == 'epoch':
            lr = opt.lr_g * (opt.lr_decrease_factor
                             **(epoch // opt.lr_decrease_epoch))
            if epoch < 200:
                lr = 0.0001
            if epoch >= 200:
                lr = 0.00005
            if epoch >= 300:
                lr = 0.00001
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        if opt.lr_decrease_mode == 'iter':
            lr = opt.lr_g * (opt.lr_decrease_factor
                             **(iteration // opt.lr_decrease_iter))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        return lr

    # Save the model if pre_train == True
    def save_model(opt, epoch, iteration, len_dataset, generator, val_PSNR,
                   best_PSNR):
        """Save the model at "checkpoint_interval" and its multiple"""
        if opt.save_best_model and best_PSNR == val_PSNR:
            torch.save(generator,
                       'final_%s_epoch%d_best.pth' % (opt.task, epoch))
            print('The best model is successfully saved at epoch %d' % (epoch))
        if opt.multi_gpu == True:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch
                        == 0) and (iteration % len_dataset == 0):
                    if opt.save_name_mode:
                        torch.save(
                            generator.module, 'MyDNN1_%s_epoch%d_bs%d.pth' %
                            (opt.task, epoch, opt.batch_size))
                        print(
                            'The trained model is successfully saved at epoch %d'
                            % (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    if opt.save_name_mode:
                        torch.save(
                            generator.module, 'MyDNN1_%s_iter%d_bs%d.pth' %
                            (opt.task, iteration, opt.batch_size))
                        print(
                            'The trained model is successfully saved at iteration %d'
                            % (iteration))
        else:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch
                        == 0) and (iteration % len_dataset == 0):
                    if opt.save_name_mode:
                        torch.save(
                            generator, 'final_%s_epoch%d_bs%d.pth' %
                            (opt.task, epoch, opt.batch_size))
                        print(
                            'The trained model is successfully saved at epoch %d'
                            % (epoch))

            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    if opt.save_name_mode:
                        torch.save(
                            generator, 'final_%s_iter%d_bs%d.pth' %
                            (opt.task, iteration, opt.batch_size))
                        print(
                            'The trained model is successfully saved at iteration %d'
                            % (iteration))

    # ----------------------------------------
    #             Network dataset
    # ----------------------------------------

    # Define the dataloader
    # trainset = dataset.TestDataset(opt)
    trainset = dataset.Noise2CleanDataset(opt)
    print('The overall number of training images:', len(trainset))
    testset = dataset.TestDataset(opt)
    valset = dataset.ValDataset(opt)
    print('The overall number of val images:', len(valset))
    # Define the dataloader
    dataloader = DataLoader(trainset,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers,
                            pin_memory=True)
    val_loader = DataLoader(valset,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers,
                            pin_memory=True)
    test_loader = DataLoader(testset,
                             batch_size=opt.batch_size,
                             shuffle=True,
                             num_workers=opt.num_workers,
                             pin_memory=True)
    # ----------------------------------------
    #                 Training
    # ----------------------------------------

    # Count start time
    prev_time = time.time()
    best_PSNR = 0
    # For loop training
    for epoch in range(opt.epochs):
        total_loss = 0
        total_ploss = 0
        total_sobel = 0
        total_Lap = 0
        for i, (true_input, simulated_input, true_target,
                noise_level_map) in enumerate(dataloader):

            # To device
            true_input = true_input.cuda()
            true_target = true_target.cuda()
            simulated_input = simulated_input.cuda()
            noise_level_map = noise_level_map.cuda()
            # Train Generator
            optimizer_G.zero_grad()
            pre_clean = generator(true_input)

            # Parse through VGGMINC layers
            # features_y = vgg(pre_clean)
            # features_x = vgg(true_input)
            # content_loss =  criterion_L2(features_y, features_x).

            pre = pre_clean[0, :, :, :].data.permute(1, 2, 0).cpu().numpy()
            pre = rgb2gray(pre)
            true = true_input[0, :, :, :].data.permute(1, 2, 0).cpu().numpy()
            true = rgb2gray(true)
            laplacian_pre = cv2.Laplacian(pre, cv2.CV_32F)  #CV_64F为图像深度
            laplacian_gt = cv2.Laplacian(true, cv2.CV_32F)  #CV_64F为图像深度
            sobel_pre = 0.5 * (cv2.Sobel(pre, cv2.CV_32F, 1, 0, ksize=5) +
                               cv2.Sobel(pre, cv2.CV_32F, 0, 1, ksize=5)
                               )  #1,0参数表示在x方向求一阶导数
            sobel_gt = 0.5 * (cv2.Sobel(true, cv2.CV_32F, 1, 0, ksize=5) +
                              cv2.Sobel(true, cv2.CV_32F, 0, 1, ksize=5)
                              )  #0,1参数表示在y方向求一阶导数
            sobel_loss = mean_squared_error(sobel_pre, sobel_gt)
            laplacian_loss = mean_squared_error(laplacian_pre, laplacian_gt)
            # L1 Loss
            Pixellevel_L1_Loss = criterion_L1(pre_clean, true_target)

            # MS-SSIM loss
            ms_ssim_loss = 1 - ms_ssim_module(pre_clean + 1, true_target + 1)

            # Overall Loss and optimize
            loss = Pixellevel_L1_Loss + 0.5 * laplacian_loss
            # loss =  Pixellevel_L1_Loss
            loss.backward()
            optimizer_G.step()

            # Determine approximate time left
            iters_done = epoch * len(dataloader) + i
            iters_left = opt.epochs * len(dataloader) - iters_done
            time_left = datetime.timedelta(seconds=iters_left *
                                           (time.time() - prev_time))
            prev_time = time.time()
            total_loss = Pixellevel_L1_Loss.item() + total_loss
            # total_ploss = content_loss.item() + total_ploss
            total_sobel = sobel_loss + total_sobel
            total_Lap = laplacian_loss + total_Lap

            # # Print log
            print(
                "\r[Epoch %d/%d] [Batch %d/%d] [Pixellevel L1 Loss: %.4f] [laplacian_loss Loss: %.4f] [sobel_loss Loss: %.4f] Time_left: %s"
                % ((epoch + 1), opt.epochs, i, len(dataloader),
                   Pixellevel_L1_Loss.item(), laplacian_loss.item(),
                   sobel_loss.item(), time_left))
            img_list = [pre_clean, true_target, true_input]
            name_list = ['pred', 'gt', 'noise']
            utils.save_sample_png(sample_folder=sample_folder,
                                  sample_name='MyDNN_MS_epoch%d' % (epoch + 1),
                                  img_list=img_list,
                                  name_list=name_list,
                                  pixel_max_cnt=255)

            # Learning rate decrease at certain epochs
            lr = adjust_learning_rate(opt, (epoch + 1), (iters_done + 1),
                                      optimizer_G)
        print(
            "\r[Epoch %d/%d] [Batch %d/%d] [Pixellevel L1 Loss: %.4f] [laplacian_loss Loss: %.4f] [sobel_loss Loss: %.4f] Time_left: %s"
            % ((epoch + 1), opt.epochs, i, len(dataloader), total_loss / 320,
               total_Lap / 320, total_sobel / 320, time_left))
        ### Validation
        val_PSNR = 0
        be_PSNR = 0
        num_of_val_image = 0

        for j, (true_input, simulated_input, true_target,
                noise_level_map) in enumerate(val_loader):

            # To device
            # A is for input image, B is for target image
            true_input = true_input.cuda()
            true_target = true_target.cuda()

            # Forward propagation
            with torch.no_grad():
                pre_clean = generator(true_input)

            # Accumulate num of image and val_PSNR
            num_of_val_image += true_input.shape[0]
            val_PSNR += utils.psnr(pre_clean, true_target,
                                   255) * true_input.shape[0]
            be_PSNR += utils.psnr(true_input, true_target,
                                  255) * true_input.shape[0]
        val_PSNR = val_PSNR / num_of_val_image
        be_PSNR = be_PSNR / num_of_val_image

        # Record average PSNR
        print('PSNR at epoch %d: %.4f' % ((epoch + 1), val_PSNR))
        print('PSNR before denoising %d: %.4f' % ((epoch + 1), be_PSNR))
        best_PSNR = max(val_PSNR, best_PSNR)
        # Save model at certain epochs or iterations
        save_model(opt, (epoch + 1), (iters_done + 1), len(dataloader),
                   generator, val_PSNR, best_PSNR)