def main():
    """Entry point: train DenseFuse on the COCO dataset with an MS-SSIM criterion."""
    torch.cuda.manual_seed_all(Args.seed)

    # Image preprocessing pipeline: resize, force the configured channel
    # count via grayscale conversion, then convert to a tensor.
    preprocess = transforms.Compose([
        transforms.Resize(Args.resized),
        transforms.Grayscale(Args.num_channel),
        transforms.ToTensor()
    ])
    dataset = COCO(Args.train_path, transform=preprocess)
    loader = DataLoader(dataset,
                        batch_size=Args.batch_size,
                        shuffle=True,
                        num_workers=min(4, Args.batch_size),
                        pin_memory=True)

    # Model and optimisation setup.
    model = DenseFuse(num_channel=Args.num_channel)
    model.to(Args.device)
    optimizer = optim.Adam(model.parameters(), lr=Args.lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', patience=1, factor=Args.lr_decay_factor)

    # Single-channel MS-SSIM over [0, 1] images, averaged to a scalar loss.
    msssim_criterion = MS_SSIM(data_range=1.0, size_average=True, channel=1)

    # Timestamped checkpoint directory (spaces/colons are not path friendly).
    stamp = time.ctime().replace(' ', '-').replace(':', '-')
    checkpoint_dir = Args.ckptPath.joinpath(stamp)

    train({'train': loader},
          model,
          {'ms_ssim': msssim_criterion},
          optimizer,
          Args.num_epochs,
          checkpoint_dir,
          scheduler=scheduler)
def ssim_loss(X, Y):
    """Return the MS-SSIM loss (1 - MS-SSIM) between two image batches.

    Args:
        X: (N, 1, H, W) batch of non-negative images in the 0~255 range.
        Y: (N, 1, H, W) batch with the same shape as X.

    Returns:
        A scalar tensor (size_average=True) usable as a training loss.
    """
    # The previous version also built an SSIM module and computed
    # `1 - ssim_module(X, Y)` only to discard the result — the function
    # always returned just the MS-SSIM loss. The dead SSIM computation
    # (which rebuilt a gaussian kernel on every call) has been removed.
    ms_ssim_module = MS_SSIM(data_range=255, size_average=True, channel=1,
                             nonnegative_ssim=False)
    return 1 - ms_ssim_module(X, Y)
def __init__(self, name, dip_n_iter=8000, net='skip', lr=0.001, reg_std=1./100,
             w_proj_loss=1.0, w_perceptual_loss=0.0, w_ssim_loss=0.0,
             w_tv_loss=0.0, randomize_projs=None, channels=None):
    """Deep-image-prior style reconstructor configuration.

    Args:
        name: reconstructor name forwarded to the base class.
        dip_n_iter: number of DIP optimisation iterations.
        net: backbone architecture; one of 'skip', 'skipV2', 'skipV3',
            'unet', 'dncnn'.
        lr: learning rate for the DIP optimisation.
        reg_std: std of the input-noise regularisation.
        w_proj_loss / w_perceptual_loss / w_ssim_loss / w_tv_loss:
            weights of the individual loss terms.
        randomize_projs: optional projection-randomisation setting.
        channels: per-level channel counts of the backbone; defaults to
            [16, 32, 64, 128, 256].
    """
    super(DgrReconstructor, self).__init__(name)
    # Fix: avoid a shared mutable default argument — `None` now stands
    # for the standard 5-level channel progression.
    if channels is None:
        channels = [16, 32, 64, 128, 256]
    self.n_iter = dip_n_iter
    assert net in ['skip', 'skipV2', 'skipV3', 'unet', 'dncnn']
    self.net = net
    self.channels = channels
    self.lr = lr
    self.reg_std = reg_std
    # loss weights
    self.w_proj_loss = w_proj_loss
    self.w_perceptual_loss = w_perceptual_loss
    self.w_tv_loss = w_tv_loss
    self.w_ssim_loss = w_ssim_loss
    self.randomize_projs = randomize_projs
    # loss functions (self.DEVICE / self.IMAGE_DEPTH come from the base class)
    self.mse = torch.nn.MSELoss().to(self.DEVICE)
    self.ssim = MS_SSIM(data_range=1.0, size_average=True,
                        channel=self.IMAGE_DEPTH).to(self.DEVICE)
    self.perceptual = VGGPerceptualLoss(resize=True).to(self.DEVICE)
    # Filled in later by the reconstruction pipeline.
    self.gt = None
    self.noisy = None
    self.FOCUS = None
    self.log_dir = None
def forward(self, input, target):
    """Convex combination of MS-SSIM loss and L1 loss.

    Returns self.a * (1 - MS-SSIM(input, target))
          + (1 - self.a) * L1(input, target).
    """
    # NOTE(review): a fresh MS_SSIM module (and its gaussian window) is
    # built on every call; caching it in __init__ would be cheaper.
    ms_ssim_term = 1 - MS_SSIM(win_size=self.win_size,
                               data_range=self.data_range,
                               channel=self.channel)(input, target)
    l1_term = torch.nn.L1Loss()(input, target)
    return self.a * ms_ssim_term + (1 - self.a) * l1_term
def criterions(name):
    """Return a loss module for the given criterion name.

    Args:
        name: one of "mse", "l1", "lpips", "ms_ssim".

    Returns:
        The corresponding loss module ("lpips" is moved to CUDA).

    Raises:
        ValueError: if `name` is not a known criterion (previously the
        function silently returned None, deferring the failure to the
        first call of the missing loss).
    """
    if name == "mse":
        return nn.MSELoss()
    elif name == "l1":
        return nn.L1Loss()
    elif name == "lpips":
        return lpips.LPIPS(net="vgg").cuda()
    elif name == "ms_ssim":
        return MS_SSIM(data_range=1.0)
    raise ValueError("unknown criterion: {}".format(name))
def __init__(self, eps=1e-6, lambda_=1):
    """Combined Charbonnier + MS-SSIM loss.

    Args:
        eps: Charbonnier smoothing constant.
        lambda_: weight balancing the two loss terms.
    """
    super(CharbonnierLossPlusMSSSIM, self).__init__()
    self.eps, self.lambda_ = eps, lambda_
    # 3-channel MS-SSIM over [0, 1] images, averaged to a scalar,
    # with the default 11x11 / sigma 1.5 gaussian window made explicit.
    self.ms_ssim_module = MS_SSIM(win_size=11,
                                  win_sigma=1.5,
                                  data_range=1.0,
                                  size_average=True,
                                  channel=3)
def __init__(self, kernel_w=3, sigma=1.5, channels=1, weights=None):
    """MS-SSIM based distance with a configurable gaussian window.

    Args:
        kernel_w: gaussian window size.
        sigma: gaussian window sigma.
        channels: number of image channels.
        weights: per-level pyramid weights; the number of weights sets the
            pyramid depth. The standard 5 levels are too deep for
            MNIST-sized images, so the default is a 4-level pyramid.
    """
    super().__init__()
    self.kernel_w = kernel_w
    self.sigma = sigma
    self.weights = [0.0516, 0.32949, 0.34622, 0.27261] if weights is None else weights
    # Per-sample distances (size_average=False keeps the batch dimension).
    self.ssim_d = MS_SSIM(win_size=kernel_w,
                          win_sigma=sigma,
                          data_range=1.0,
                          channel=channels,
                          weights=self.weights,
                          size_average=False)
    self.config = {'Distance': 'SSIM', 'win_size': kernel_w, 'win_sigma': sigma}
def set_criterion(self, level=None):
    """Assemble loss terms and their coefficients for a pyramid level.

    Every level uses pixel-wise MSE; levels 0-3 add a VGG19 perceptual
    term; level 0 (full resolution) additionally maximises MS-SSIM,
    hence its negative coefficient.
    """
    assert level in [0, 1, 2, 3, 4, 5], 'unknown level'
    terms = [nn.MSELoss()]
    coeffs = [1.0]
    if level < 4:
        vgg_term = VGG19Loss(['relu5_4'])
        vgg_term.to(self.device)
        terms.append(vgg_term)
        coeffs.append(0.01)
    if level == 0:
        terms.append(MS_SSIM(data_range=1.0, size_average=True, channel=3))
        # Stronger MS-SSIM weight when the perceptual flag is on.
        coeffs.append(-0.1 if self.perceptual else -0.01)
    self.criterions = terms
    self.coefficients = torch.FloatTensor(coeffs).to(self.device)
def __init__(self, args, ckp):
    """Build the composite training loss from the `args.loss` spec string.

    `args.loss` is a '+'-separated list of 'weight*TYPE' terms, e.g.
    '1*L1+0.1*MS-SSIM'. Each recognised TYPE is instantiated and recorded
    in `self.loss` (list of term dicts) and `self.loss_module`
    (nn.ModuleList of the callable terms).

    Bug fix: the 'MS-SSIM' branch was unreachable because
    `loss_type.find('SSIM') >= 0` also matches 'MS-SSIM'; the MS-SSIM
    check now runs before the plain SSIM check. An unknown TYPE now
    raises NotImplementedError instead of silently reusing the previous
    term's loss_function (or raising a confusing NameError).
    """
    super(Loss, self).__init__()
    print('Preparing loss function:')

    self.n_GPUs = args.n_GPUs
    self.loss = []
    self.loss_module = nn.ModuleList()
    for loss in args.loss.split('+'):
        weight, loss_type = loss.split('*')
        if loss_type == 'MSE':  # L2 loss
            loss_function = nn.MSELoss()
        elif loss_type == 'L1':
            loss_function = nn.L1Loss()
        elif loss_type.find('VGG') >= 0:
            module = import_module('loss.vgg')
            loss_function = getattr(module, 'VGG')(loss_type[3:],
                                                   rgb_range=args.rgb_range)
        elif loss_type.find('TextureL') >= 0:
            module = import_module('loss.vgg')
            loss_function = getattr(module, 'VGG')(loss_type[3:],
                                                   rgb_range=args.rgb_range,
                                                   texture_loss=True)
        elif loss_type.find('GAN') >= 0:
            module = import_module('loss.adversarial')
            loss_function = getattr(module, 'Adversarial')(args, loss_type)
        elif loss_type.find('TVLoss') >= 0:
            module = import_module('loss.tvloss')
            loss_function = getattr(module, 'TVLoss')()
        elif loss_type.find('MS-SSIM') >= 0:
            # Must be tested BEFORE plain SSIM: 'SSIM' is a substring of
            # 'MS-SSIM', so the order of these two branches matters.
            from pytorch_msssim import MS_SSIM
            loss_function = MS_SSIM(win_sigma=1, data_range=args.rgb_range,
                                    size_average=True, channel=3)
        elif loss_type.find('SSIM') >= 0:
            from pytorch_msssim import SSIM
            loss_function = SSIM(win_size=7, win_sigma=1,
                                 data_range=args.rgb_range,
                                 size_average=True, channel=3)
        else:
            raise NotImplementedError(
                'unknown loss type: {}'.format(loss_type))

        self.loss.append({
            'type': loss_type,
            'weight': float(weight),
            'function': loss_function
        })
        if loss_type.find('GAN') >= 0:
            # GAN terms also track the discriminator loss separately.
            self.loss.append({'type': 'DIS', 'weight': 1, 'function': None})

    if len(self.loss) > 1:
        # Extra slot for logging the weighted total.
        self.loss.append({'type': 'Total', 'weight': 0, 'function': None})

    for l in self.loss:
        if l['function'] is not None:
            print('{:.6f} * {}'.format(l['weight'], l['type']))
            self.loss_module.append(l['function'])

    self.log = torch.Tensor()

    device = torch.device('cpu' if args.cpu else 'cuda')
    self.loss_module.to(device)
    if args.precision == 'half':
        self.loss_module.half()
    if not args.cpu and args.n_GPUs > 1:
        self.loss_module = nn.DataParallel(self.loss_module,
                                           range(args.n_GPUs))

    if args.load != '':
        self.load(ckp.dir, cpu=args.cpu)
def __init__(self, alpha=0.84):
    """Reconstruction loss mixing L1 and MS-SSIM.

    Args:
        alpha: mixing weight between the two terms (0.84 follows the
            common "Loss functions for image restoration" setting).
    """
    super(ReconstructionLoss, self).__init__()
    self.alpha = alpha
    # Pixel-wise term.
    self.l1 = nn.L1Loss()
    # Structural term, averaged to a scalar over [0, 1] images
    # (default channel count of MS_SSIM is used here).
    self.ms_ssim = MS_SSIM(data_range=1, size_average=True)
0.02, gpu_id=device) # VGG for perceptual loss if opt.lamb_content > 0: vgg = Vgg16() init_vgg16(root_path) vgg.load_state_dict(torch.load(os.path.join(root_path, "vgg16.weight"))) vgg.to(device) # define loss criterionL1 = nn.L1Loss().to(device) criterionL2 = nn.MSELoss().to(device) criterionMSE = nn.MSELoss().to(device) criterionSSIM = SSIM(data_range=255, size_average=True, channel=3) criterionMSSSIM1 = MS_SSIM(data_range=255, size_average=True, channel=1) criterionMSSSIM3 = MS_SSIM(data_range=255, size_average=True, channel=3) # setup optimizer optimizer_i = optim.Adam(net_i.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) optimizer_r = optim.Adam(net_r.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) net_i_scheduler = get_scheduler(optimizer_i, opt) net_r_scheduler = get_scheduler(optimizer_r, opt) loss_i_list = [] loss_r_list = [] for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
channel).to(img1.device).type(img1.dtype)
            # Cache the rebuilt window so later calls with the same channel
            # count reuse it. NOTE(review): indentation of this fragment is
            # reconstructed — the start of this forward() lies outside this
            # chunk; verify against the full file.
            self.window = window
            self.channel = channel
        return ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average)


class MSSSIM(torch.nn.Module):
    """Multi-scale SSIM wrapper around the module-level `msssim` function."""

    def __init__(self, window_size=11, size_average=True, channel=3):
        super(MSSSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        # NOTE(review): `channel` is stored but never forwarded to
        # `msssim` below — confirm whether that is intentional.
        self.channel = channel

    def forward(self, img1, img2):
        # TODO: store window between calls if possible
        return msssim(img1, img2, window_size=self.window_size, size_average=self.size_average)


# Module-level singletons used by the training code: a frozen VGG19 feature
# extractor, a GAN loss helper, and a reusable 3-channel MS-SSIM scorer.
vgg_model = VGG19().cuda()
vgg_model = vgg_model.eval()
GAN_loss_calculator = GANLoss()
mssim_calculator = MS_SSIM(data_range=1.0, size_average=True, channel=3)
def tensor_ssim_module():
    """Build reusable SSIM and MS-SSIM modules for 3-channel 0~255 images.

    Returns:
        (ssim_module, ms_ssim_module): module instances that cache their
        gaussian kernel so it is shared across calls.
    """
    # reuse the gaussian kernel with SSIM & MS_SSIM.
    ssim_module = SSIM(data_range=255, size_average=True, channel=3)
    ms_ssim_module = MS_SSIM(data_range=255, size_average=True, channel=3)
    # Fix: the modules were previously constructed and discarded (implicit
    # `return None`), making the function a no-op for callers.
    return ssim_module, ms_ssim_module
def do_learn(opt, run_dir="./runs"):
    """Train an encoder/generator/discriminator triple (AE + GAN).

    Per batch: (1) train the encoder (and optionally the generator,
    `opt.do_joint`) to minimise `1 - E_loss(x, G(E(x)))`; (2) train the
    discriminator on real vs generated images under the selected
    `opt.GAN_loss` formulation; (3) train the generator to fool the
    freshly updated discriminator. Stats and sample grids are pushed to
    TensorBoard when `do_tensorboard` is set.

    NOTE(review): indentation is reconstructed from a collapsed source;
    in particular the per-epoch TensorBoard block placement should be
    verified against the original file.
    """
    print('Starting ', opt.run_path)
    path_data = os.path.join(run_dir, opt.run_path)
    # ----------
    #  Tensorboard
    # ----------
    if do_tensorboard:
        # stats are stored in "runs", within subfolder opt.run_path.
        writer = SummaryWriter(log_dir=path_data)
    # Create a time tag
    import datetime
    try:
        tag = datetime.datetime.now().isoformat(sep='_', timespec='seconds')
    except TypeError:
        # Python 3.5 and below:
        # 'timespec' is an invalid keyword argument for this function
        tag = datetime.datetime.now().replace(microsecond=0).isoformat(sep='_')
    tag = tag.replace(':', '-')
    # Configure data loader
    dataloader = load_data(opt.datapath, opt.img_size, opt.batch_size,
                           rand_hflip=opt.rand_hflip, rand_affine=opt.rand_affine)
    if opt.do_SSIM:
        # from pytorch_msssim import NMSSSIM
        # E_loss = NMSSSIM(window_size=opt.window_size, val_range=1., size_average=True, channel=3, normalize=True)
        # from pytorch_msssim import NSSIM #as neg_SSIM
        # E_loss = NSSIM(window_size=opt.window_size, val_range=1., size_average=True)
        # NEW: we use https://github.com/VainF/pytorch-msssim instead of https://github.com/SpikeAI/pytorch-msssim
        # from pytorch_msssim import msssim, ssim
        from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM
        E_loss = MS_SSIM(win_size=opt.window_size, data_range=1, size_average=True, channel=3)
    else:
        E_loss = torch.nn.MSELoss(reduction='sum')
    sigmoid = torch.nn.Sigmoid()
    # Initialize generator and discriminator
    generator = Generator(opt)
    discriminator = Discriminator(opt)
    encoder = Encoder(opt)
    if opt.verbose:
        print_network(generator)
        print_network(discriminator)
        print_network(encoder)
    # All-ones matrix with a zero diagonal; used below to mask Xcorr.
    eye = 1 - torch.eye(opt.batch_size)
    use_cuda = True if torch.cuda.is_available() else False
    if use_cuda:
        # print("Nombre de GPU : ", torch.cuda.device_count())
        print("Running on GPU : ", torch.cuda.get_device_name())
        # if torch.cuda.device_count() > opt.GPU:
        #     torch.cuda.set_device(opt.GPU)
        generator.cuda()
        discriminator.cuda()
        # adversarial_loss.cuda()
        encoder.cuda()
        # MSE_loss.cuda()
        E_loss.cuda()
        eye = eye.cuda()
        Tensor = torch.cuda.FloatTensor
    else:
        print("Running on CPU ")
        Tensor = torch.FloatTensor
    # Initialize weights
    if opt.init_weight:
        generator.apply(weights_init_normal)
        discriminator.apply(weights_init_normal)
        encoder.apply(weights_init_normal)
    # Optimizers
    if opt.optimizer == 'rmsprop':
        # https://pytorch.org/docs/stable/optim.html#torch.optim.RMSprop
        opts = dict(momentum=1 - opt.beta1, alpha=opt.beta2)
        optimizer = torch.optim.RMSprop
    elif opt.optimizer == 'adam':
        # https://pytorch.org/docs/stable/optim.html#torch.optim.Adam
        opts = dict(betas=(opt.beta1, opt.beta2))
        optimizer = torch.optim.Adam
    elif opt.optimizer == 'sgd':
        opts = dict(momentum=1 - opt.beta1, nesterov=True, weight_decay=1 - opt.beta2)
        optimizer = torch.optim.SGD
    else:
        # NOTE(review): raising a plain string is a TypeError in Python 3;
        # this should be e.g. `raise ValueError('wrong optimizer')`.
        raise ('wrong optimizer')
    optimizer_G = optimizer(generator.parameters(), lr=opt.lrG, **opts)
    optimizer_D = optimizer(discriminator.parameters(), lr=opt.lrD, **opts)
    if opt.do_joint:
        # Joint mode: the encoder step also updates the generator.
        import itertools
        optimizer_E = optimizer(itertools.chain(encoder.parameters(), generator.parameters()), lr=opt.lrE, **opts)
    else:
        optimizer_E = optimizer(encoder.parameters(), lr=opt.lrE, **opts)
    # TODO parameterize scheduler !
    # gamma = .1 ** (1 / opt.n_epochs)
    # schedulers = []
    # for optimizer in [optimizer_G, optimizer_D, optimizer_E]:
    #     schedulers.append(torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=gamma))
    # ----------
    #  Training
    # ----------
    nb_batch = len(dataloader)
    stat_record = init_hist(opt.n_epochs, nb_batch)

    # https://github.com/soumith/dcgan.torch/issues/14 dribnet commented on 21 Mar 2016
    # https://arxiv.org/abs/1609.04468
    def slerp(val, low, high):
        # Spherical interpolation between batches of latent vectors.
        corr = np.diag( (low / np.linalg.norm(low)) @ (high / np.linalg.norm(high)).T)
        omega = np.arccos(np.clip(corr, -1, 1))[:, None]
        so = np.sin(omega)
        out = np.sin( (1.0 - val) * omega) / so * low + np.sin(val * omega) / so * high
        # L'Hopital's rule/LERP for (near-)parallel vectors where so == 0.
        out[so[:, 0] == 0, :] = ( 1.0 - val) * low[so[:, 0] == 0, :] + val * high[so[:, 0] == 0, :]
        return out

    def norm2(z):
        """ L2-norm of a tensor. outputs a scalar """
        # return torch.mean(z.pow(2)).pow(.5)
        return (z**2).sum().sqrt()

    def gen_z(imgs=None, rho=.25, do_slerp=opt.do_slerp):
        """ Generate noise in the feature space. outputs a vector """
        if not imgs is None:
            z_imgs = encoder(imgs).cpu().numpy()
            if do_slerp:
                # Interpolate each latent towards a shuffled batch mate.
                z_shuffle = z_imgs.copy()
                z_shuffle = z_shuffle[torch.randperm(opt.batch_size), :]
                z = slerp(rho, z_imgs, z_shuffle)
            else:
                # NOTE(review): `z` is used before assignment in this
                # branch — calling gen_z(imgs=..., do_slerp=False) raises
                # UnboundLocalError. A fresh random `z` is presumably
                # missing here; confirm against the original repo.
                z /= norm2(z)
                z_imgs /= norm2(z_imgs)
                z = (1 - rho) * z_imgs + rho * z
            z /= norm2(z)
        else:
            z = np.random.normal(0, 1, (opt.batch_size, opt.latent_dim))
        # convert to tensor
        return Variable(Tensor(z), requires_grad=False)

    def gen_noise(imgs):
        """ Generate noise in the image space outputs an image """
        v_noise = np.random.normal(0, 1, imgs.shape)  # one random image
        # one contrast value per image
        v_noise *= np.abs( np.random.normal(0, 1, (imgs.shape[0], opt.channels, 1, 1)))
        # convert to tensor
        v_noise = Variable(Tensor(v_noise), requires_grad=False)
        return v_noise

    # Fixed z vector used to draw comparable samples across epochs.
    fixed_noise = gen_z()
    real_imgs_samples = None
    # z_zeros = Variable(Tensor(opt.batch_size, opt.latent_dim).fill_(0), requires_grad=False)
    # z_ones = Variable(Tensor(opt.batch_size, opt.latent_dim).fill_(1), requires_grad=False)
    # Adversarial ground truths
    # valid = Variable(Tensor(opt.batch_size, 1).fill_(1), requires_grad=False)
    # fake = Variable(Tensor(opt.batch_size, 1).fill_(0), requires_grad=False)
    t_total = time.time()
    for i_epoch, epoch in enumerate(range(1, opt.n_epochs + 1)):
        t_epoch = time.time()
        for iteration, (imgs, _) in enumerate(dataloader):
            t_batch = time.time()
            # ---------------------
            #  Train Encoder
            # ---------------------
            for p in generator.parameters():
                p.requires_grad = opt.do_joint
            for p in encoder.parameters():
                p.requires_grad = True
            # the following is not necessary as we do not use D here and only optimize ||G(E(x)) - x ||^2
            for p in discriminator.parameters():
                p.requires_grad = False  # to avoid learning D when learning E
            real_imgs = Variable(imgs.type(Tensor), requires_grad=False)
            # init samples used to visualize performance of the AE
            if real_imgs_samples is None:
                real_imgs_samples = real_imgs[:opt.N_samples]
            # add noise here to real_imgs
            real_imgs_ = real_imgs * 1.
            if opt.E_noise > 0:
                real_imgs_ += opt.E_noise * gen_noise(real_imgs)
            z_imgs = encoder(real_imgs_)
            decoded_imgs = generator(z_imgs)
            # Loss measures Encoder's ability to generate vectors suitable with the generator
            # (E_loss is a similarity in [0, 1] for MS-SSIM, hence 1 - E_loss).
            e_loss = 1. - E_loss(real_imgs, decoded_imgs)
            # energy = 1.  # E_loss(real_imgs, zero_target)  # normalize on the energy of imgs
            # if opt.do_joint:
            #     e_loss = E_loss(real_imgs, decoded_imgs) / energy
            # else:
            #     e_loss = E_loss(real_imgs, decoded_imgs.detach()) / energy
            if opt.lambdaE > 0:
                # We wish to make sure the intermediate vector z_imgs get closer to a iid normal (centered gausian of variance 1)
                e_loss += opt.lambdaE * (torch.sum(z_imgs) / opt.batch_size / opt.latent_dim).pow(2)
                e_loss += opt.lambdaE * (torch.sum(z_imgs.pow(2)) / opt.batch_size / opt.latent_dim - 1).pow(2).pow(.5)
            # Backward
            optimizer_E.zero_grad()
            e_loss.backward()
            optimizer_E.step()
            # One-sided label smoothing: targets drawn near 1 for real,
            # near 0 for fake.
            valid_smooth = np.random.uniform(opt.valid_smooth, 1.0 - (1 - opt.valid_smooth) / 2, (opt.batch_size, 1))
            valid_smooth = Variable(Tensor(valid_smooth), requires_grad=False)
            fake_smooth = np.random.uniform((1 - opt.valid_smooth) / 2, 1 - opt.valid_smooth, (opt.batch_size, 1))
            fake_smooth = Variable(Tensor(fake_smooth), requires_grad=False)
            # ---------------------
            #  Train Discriminator
            # ---------------------
            # Discriminator Requires grad, Encoder + Generator requires_grad = False
            for p in discriminator.parameters():
                p.requires_grad = True
            for p in generator.parameters():
                p.requires_grad = False  # to avoid computation
            for p in encoder.parameters():
                p.requires_grad = False  # to avoid computation
            # Configure input
            real_imgs = Variable(imgs.type(Tensor), requires_grad=False)
            real_imgs_ = real_imgs * 1.
            if opt.D_noise > 0:
                real_imgs_ += opt.D_noise * gen_noise(real_imgs)
            if opt.do_insight:
                # the discriminator can not access the images directly but only
                # what is visible through the auto-encoder
                real_imgs_ = generator(encoder(real_imgs_))
            # Discriminator decision (in logit units)
            # TODO : group images by sub-batches and train to discriminate from all together
            # should allow to avoid mode collapse
            logit_d_x = discriminator(real_imgs_)
            # ---------------------
            #  Train Discriminator
            # ---------------------
            if opt.GAN_loss == 'wasserstein':
                # weight clipping
                for p in discriminator.parameters():
                    p.data.clamp_(-0.01, 0.01)
            # Measure discriminator's ability to classify real from generated samples
            if opt.GAN_loss == 'ian':
                # eq. 14 in https://arxiv.org/pdf/1701.00160.pdf
                real_loss = -torch.sum(1 / (1. - 1 / sigmoid(logit_d_x)))
            elif opt.GAN_loss == 'hinge':
                # TODO check if we use p or log p
                real_loss = nn.ReLU()(valid_smooth - sigmoid(logit_d_x)).mean()
            elif opt.GAN_loss == 'wasserstein':
                real_loss = torch.mean( torch.abs(valid_smooth - sigmoid(logit_d_x)))
            elif opt.GAN_loss == 'alternative':
                # https://www.inference.vc/an-alternative-update-rule-for-generative-adversarial-networks/
                real_loss = -torch.sum(torch.log(sigmoid(logit_d_x)))
            elif opt.GAN_loss == 'alternativ2':
                # https://www.inference.vc/an-alternative-update-rule-for-generative-adversarial-networks/
                real_loss = -torch.sum( torch.log(sigmoid(logit_d_x) / (1. - sigmoid(logit_d_x))))
            elif opt.GAN_loss == 'alternativ3':
                # to maximize D(x), we minimize - sum(logit_d_x)
                real_loss = -torch.sum(logit_d_x)
            elif opt.GAN_loss == 'original':
                real_loss = F.binary_cross_entropy(sigmoid(logit_d_x), valid_smooth)
            else:
                print('GAN_loss not defined', opt.GAN_loss)
            # Generate a batch of fake images and learn the discriminator to treat them as such
            z = gen_z(imgs=real_imgs_)
            gen_imgs = generator(z)
            if opt.D_noise > 0:
                gen_imgs += opt.D_noise * gen_noise(real_imgs)
            # Discriminator decision for fake data
            logit_d_fake = discriminator(gen_imgs.detach())
            # Measure discriminator's ability to classify real from generated samples
            if opt.GAN_loss == 'wasserstein':
                fake_loss = torch.mean(sigmoid(logit_d_fake))
            elif opt.GAN_loss == 'hinge':
                # TODO check if we use p or log p
                # NOTE(review): this assigns `real_loss`, clobbering the
                # real-sample loss above and leaving `fake_loss` unbound
                # for the hinge case (NameError at fake_loss.backward()).
                # It almost certainly should be `fake_loss = ...`.
                real_loss = nn.ReLU()(1.0 + sigmoid(logit_d_fake)).mean()
            elif opt.GAN_loss == 'alternative':
                # https://www.inference.vc/an-alternative-update-rule-for-generative-adversarial-networks/
                fake_loss = -torch.sum(torch.log(1 - sigmoid(logit_d_fake)))
            elif opt.GAN_loss == 'alternativ2':
                # https://www.inference.vc/an-alternative-update-rule-for-generative-adversarial-networks/
                fake_loss = torch.sum( torch.log( sigmoid(logit_d_fake) / (1. - sigmoid(logit_d_fake))))
            elif opt.GAN_loss == 'alternativ3':
                # to minimize D(G(z)), we minimize sum(logit_d_fake)
                fake_loss = torch.sum(logit_d_fake)
            elif opt.GAN_loss in ['original', 'ian']:
                fake_loss = F.binary_cross_entropy(sigmoid(logit_d_fake), fake_smooth)
            else:
                print('GAN_loss not defined', opt.GAN_loss)
            # Backward
            optimizer_D.zero_grad()
            real_loss.backward()
            fake_loss.backward()
            # apply the gradients
            optimizer_D.step()
            # -----------------
            #  Train Generator
            # -----------------
            for p in generator.parameters():
                p.requires_grad = True
            for p in discriminator.parameters():
                p.requires_grad = False  # to avoid computation
            for p in encoder.parameters():
                p.requires_grad = False  # to avoid computation
            # Generate a batch of fake images
            z = gen_z(imgs=real_imgs_)
            gen_imgs = generator(z)
            if opt.G_noise > 0:
                gen_imgs += opt.G_noise * gen_noise(real_imgs)
            # New discriminator decision (since we just updated D)
            logit_d_g_z = discriminator(gen_imgs)
            # Loss functions
            # Loss measures generator's ability to fool the discriminator
            if opt.GAN_loss == 'ian':
                # eq. 14 in https://arxiv.org/pdf/1701.00160.pdf
                # https://en.wikipedia.org/wiki/Logit
                g_loss = -torch.sum( sigmoid(logit_d_g_z) / (1 - sigmoid(logit_d_g_z)))
            elif opt.GAN_loss == 'wasserstein' or opt.GAN_loss == 'hinge':
                g_loss = torch.mean( torch.abs(valid_smooth - sigmoid(logit_d_g_z)))
            elif opt.GAN_loss == 'alternative':
                # https://www.inference.vc/an-alternative-update-rule-for-generative-adversarial-networks/
                g_loss = -torch.sum(torch.log(sigmoid(logit_d_g_z)))
            elif opt.GAN_loss == 'alternativ2':
                # https://www.inference.vc/an-alternative-update-rule-for-generative-adversarial-networks/
                g_loss = -torch.sum( torch.log( sigmoid(logit_d_g_z) / (1. - sigmoid(logit_d_g_z))))
                # g_loss = torch.sum(torch.log(1./sigmoid(logit_d_g_z) - 1.))
            elif opt.GAN_loss == 'alternativ3':
                # to maximize D(G(z)), we minimize - sum(logit_d_g_z)
                g_loss = -torch.sum(logit_d_g_z)
            elif opt.GAN_loss == 'original':
                # https://pytorch.org/docs/stable/nn.html?highlight=bcewithlogitsloss#torch.nn.BCEWithLogitsLoss
                # adversarial_loss = torch.nn.BCEWithLogitsLoss()  # eq. 8 in https://arxiv.org/pdf/1701.00160.pdf
                # https://medium.com/swlh/gan-to-generate-images-of-cars-5f706ca88da
                # adversarial_loss = torch.nn.BCE()  # eq. 8 in https://arxiv.org/pdf/1701.00160.pdf
                g_loss = F.binary_cross_entropy(sigmoid(logit_d_g_z), valid_smooth)
            else:
                print('GAN_loss not defined', opt.GAN_loss)
            # penalize low variability in a batch, that is, mode collapse
            # TODO maximize sum of the distances to the nearest neighbors
            if opt.lambdaG > 0:
                e_g_z = encoder(gen_imgs)
                # get normal vectors
                Xcorr = torch.tensordot(e_g_z, torch.transpose(e_g_z, 0, 1), 1) / opt.latent_dim
                Xcorr *= eye  # set the diagonal elements to zero
                g_loss += opt.lambdaG * torch.sum(Xcorr.pow(2)).pow(.5)
            # Backward
            optimizer_G.zero_grad()
            g_loss.backward()
            # apply the gradients
            optimizer_G.step()
            # -----------------
            #  Recording stats
            # -----------------
            d_loss = real_loss + fake_loss
            # Compensation for the BCE-with-logits: report probabilities.
            d_fake = sigmoid(logit_d_fake)
            d_x = sigmoid(logit_d_x)
            d_g_z = sigmoid(logit_d_g_z)
            print( "%s [Epoch %d/%d] [Batch %d/%d] [E loss: %f] [D loss: %f] [G loss: %f] [D(x) %f] [D(G(z)) %f] [D(G(z')) %f] [Time: %fs]" % (opt.run_path, epoch, opt.n_epochs, iteration + 1, len(dataloader), e_loss.item(), d_loss.item(), g_loss.item(), torch.mean(d_x), torch.mean(d_fake), torch.mean(d_g_z), time.time() - t_batch))
            # Save Losses and scores for Tensorboard
            save_hist_batch(stat_record, iteration, i_epoch, g_loss, d_loss, e_loss, d_x, d_g_z)
        if do_tensorboard:
            # Tensorboard save (uses the last batch's losses of this epoch)
            writer.add_scalar('loss/E', e_loss.item(), global_step=epoch)
            # writer.add_histogram('coeffs/z', z, global_step=epoch)
            try:
                writer.add_histogram('coeffs/E_x', z_imgs, global_step=epoch)
            except:
                # NOTE(review): bare except silently swallows all errors here.
                pass
            # writer.add_histogram('image/x', real_imgs, global_step=epoch)
            # try:
            #     writer.add_histogram('image/E_G_x', decoded_imgs, global_step=epoch)
            # except:
            #     pass
            # try:
            #     writer.add_histogram('image/G_z', gen_imgs, global_step=epoch)
            # except:
            #     pass
            writer.add_scalar('loss/G', g_loss.item(), global_step=epoch)
            # writer.add_scalar('score/D_fake', hist["d_fake_mean"][i], global_step=epoch)
            # print(stat_record["d_g_z_mean"])
            writer.add_scalar('score/D_g_z', np.mean(stat_record["d_g_z_mean"]), global_step=epoch)
            writer.add_scalar('loss/D', d_loss.item(), global_step=epoch)
            writer.add_scalar('score/D_x', np.mean(stat_record["d_x_mean"]), global_step=epoch)
            # Save samples
            if epoch % opt.sample_interval == 0:
                """ Use generator model and noise vector to generate images. Save them to tensorboard """
                generator.eval()
                gen_imgs = generator(fixed_noise)
                from torchvision.utils import make_grid
                grid = make_grid(gen_imgs, normalize=True, nrow=16, range=(0, 1))
                writer.add_image('Generated images', grid, epoch)
                generator.train()
                """ Use auto-encoder model and original images to generate images. Save them to tensorboard """
                # grid_imgs = make_grid(real_imgs_samples, normalize=True, nrow=8, range=(0, 1))
                # writer.add_image('Images/original', grid_imgs, epoch)
                generator.eval()
                encoder.eval()
                enc_imgs = encoder(real_imgs_samples)
                dec_imgs = generator(enc_imgs)
                grid_dec = make_grid(dec_imgs, normalize=True, nrow=16, range=(0, 1))
                # writer.add_image('Images/auto-encoded', grid_dec, epoch)
                writer.add_image('Auto-encoded', grid_dec, epoch)
                generator.train()
                encoder.train()
            # writer.add_graph(encoder, real_imgs_samples)
            # writer.add_graph(generator, enc_imgs)
            # writer.add_graph(discriminator, real_imgs_samples)
        # # if epoch % opt.sample_interval == 0 :
        # #     sampling(fixed_noise, generator, path_data, epoch, tag)
        # #     do_plot(hist, start_epoch, epoch)
        print("[Epoch Time: ", time.time() - t_epoch, "s]")
        sampling(fixed_noise, generator, path_data, epoch, tag, nrow=16)
        # for scheduler in schedulers: scheduler.step()
    t_final = time.gmtime(time.time() - t_total)
    print("[Total Time: ", t_final.tm_mday - 1, "j:", time.strftime("%Hh:%Mm:%Ss", t_final), "]", sep='')
    if do_tensorboard:
        writer.close()
def __init__(self, **kwargs):
    """ReconEngine constructor: builds the network, loss, and data
    transforms from the hyperparameters passed as keyword arguments
    (saved via `save_hyperparameters` and read back from `self.hparams`).

    modelID: 0=ReconResNet, 2=DualSpaceResNet, 3-6=Primal-Dual variants.
    lossID: 0=Perceptual, 1=L1, 2=MS-SSIM, 3=SSIM.
    NOTE(review): there is no `modelID == 1` branch — value 1 falls into
    the sys.exit below; confirm whether that is intentional.
    """
    super().__init__()
    self.save_hyperparameters()
    device = torch.device(
        "cuda:0" if torch.cuda.is_available() and self.hparams.cuda else "cpu")
    if self.hparams.modelID == 0:
        self.net = ResNet(in_channels=self.hparams.in_channels, out_channels=self.hparams.out_channels, res_blocks=self.hparams.model_res_blocks, starting_nfeatures=self.hparams.model_starting_nfeatures, updown_blocks=self.hparams.model_updown_blocks, is_relu_leaky=self.hparams.model_relu_leaky,
                          do_batchnorm=self.hparams.model_do_batchnorm, res_drop_prob=self.hparams.model_drop_prob, is_replicatepad=self.hparams.model_is_replicatepad, out_act=self.hparams.model_out_act, forwardV=self.hparams.model_forwardV, upinterp_algo=self.hparams.model_upinterp_algo, post_interp_convtrans=self.hparams.model_post_interp_convtrans, is3D=self.hparams.is3D)  # TODO think of 2D
        # self.net = nn.Conv3d(in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=1)
    elif self.hparams.modelID == 2:
        self.net = DualSpaceResNet(in_channels=self.hparams.in_channels, out_channels=self.hparams.out_channels, res_blocks=self.hparams.model_res_blocks, starting_nfeatures=self.hparams.model_starting_nfeatures, updown_blocks=self.hparams.model_updown_blocks, is_relu_leaky=self.hparams.model_relu_leaky,
                                   do_batchnorm=self.hparams.model_do_batchnorm, res_drop_prob=self.hparams.model_drop_prob, is_replicatepad=self.hparams.model_is_replicatepad, out_act=self.hparams.model_out_act, forwardV=self.hparams.model_forwardV, upinterp_algo=self.hparams.model_upinterp_algo, post_interp_convtrans=self.hparams.model_post_interp_convtrans, is3D=self.hparams.is3D,
                                   connect_mode=self.hparams.model_dspace_connect_mode, inner_norm_ksp=self.hparams.model_inner_norm_ksp)
    elif self.hparams.modelID == 3:  # Primal-Dual Network, complex Primal
        self.net = PrimalDualNetwork(n_primary=5, n_dual=5, n_iterations=10,
                                     use_original_block = True,
                                     use_original_init = True,
                                     use_complex_primal = True,
                                     g_normtype = "magmax",
                                     transform = "Fourier", return_abs = True)
    elif self.hparams.modelID == 4:  # Primal-Dual Network, absolute Primal
        self.net = PrimalDualNetwork(n_primary=5, n_dual=5, n_iterations=10,
                                     use_original_block = True,
                                     use_original_init = True,
                                     use_complex_primal = False,
                                     g_normtype = "magmax",
                                     transform = "Fourier")
    elif self.hparams.modelID == 5:  # Primal-Dual UNet Network, absolute Primal
        self.net = PrimalDualNetwork(n_primary=4, n_dual=5, n_iterations=2,
                                     use_original_block = False,
                                     use_original_init = False,
                                     use_complex_primal = False,
                                     g_normtype = "magmax",
                                     transform = "Fourier")
    elif self.hparams.modelID == 6:  # Primal-Dual Network v2 (no residual), complex Primal
        self.net = PrimalDualNetworkNoResidue(n_primary=5, n_dual=5, n_iterations=10,
                                              use_original_block = True,
                                              use_original_init = True,
                                              use_complex_primal = True,
                                              residuals=False,
                                              g_normtype = "magmax",
                                              transform = "Fourier", return_abs = True)
    else:  # TODO: other models
        sys.exit("Only ReconResNet and DualSpaceResNet have been implemented so far in ReconEngine")
    # Optionally warm-start from a checkpoint.
    if bool(self.hparams.preweights_path):
        print("Pre-weights found, loding...")
        chk = torch.load(self.hparams.preweights_path, map_location='cpu')
        self.net.load_state_dict(chk['state_dict'])
    # Loss selection.
    if self.hparams.lossID == 0:
        if self.hparams.in_channels != 1 or self.hparams.out_channels != 1:
            sys.exit(
                "Perceptual Loss used here only works for 1 channel input and output")
        self.loss = PerceptualLoss(device=device, loss_model="unet3Dds", resize=None,
                                   loss_type=self.hparams.ploss_type, n_level=self.hparams.ploss_level)  # TODO thinkof 2D
    elif self.hparams.lossID == 1:
        self.loss = nn.L1Loss(reduction='mean')
    elif self.hparams.lossID == 2:
        self.loss = MS_SSIM(channel=self.hparams.out_channels, data_range=1,
                            spatial_dims=3 if self.hparams.is3D else 2, nonnegative_ssim=False).to(device)
    elif self.hparams.lossID == 3:
        self.loss = SSIM(channel=self.hparams.out_channels, data_range=1,
                         spatial_dims=3 if self.hparams.is3D else 2, nonnegative_ssim=False).to(device)
    else:
        sys.exit("Invalid Loss ID")
    self.dataspace = DataSpaceHandler(**self.hparams)
    # ds_mode selects the transform backend: 0=TorchIO, 1=pure PyTorch.
    # NOTE(review): any other ds_mode leaves `trans`/`augs` unbound and
    # fails later with NameError.
    if self.hparams.ds_mode == 0:
        trans = tioTransforms
        augs = tioAugmentations
    elif self.hparams.ds_mode == 1:
        trans = pytTransforms
        augs = pytAugmentations
    # TODO parameterised everything
    self.init_transforms = []
    self.aug_transforms = []
    self.transforms = []
    if self.hparams.ds_mode == 0 and self.hparams.cannonicalResample:  # Only applicable for TorchIO
        self.init_transforms += [tio.ToCanonical(), tio.Resample('gt')]
    if self.hparams.ds_mode == 0 and self.hparams.forceNormAffine:  # Only applicable for TorchIO
        self.init_transforms += [trans.ForceAffine()]
    if self.hparams.croppad and self.hparams.ds_mode == 1:
        self.init_transforms += [
            trans.CropOrPad(size=self.hparams.input_shape)]
    self.init_transforms += [trans.IntensityNorm(type=self.hparams.norm_type,
                                                 return_meta=self.hparams.motion_return_meta)]
    # dataspace_transforms = self.dataspace.getTransforms()  # TODO: dataspace transforms are not in use
    # self.init_transforms += dataspace_transforms
    if bool(self.hparams.random_crop) and self.hparams.ds_mode == 1:
        self.aug_transforms += [augs.RandomCrop(
            size=self.hparams.random_crop, p=self.hparams.p_random_crop)]
    if self.hparams.p_contrast_augment > 0:
        self.aug_transforms += [augs.getContrastAugs(
            p=self.hparams.p_contrast_augment)]
    # if the task if MoCo and pre-corrupted vols are not supplied
    if self.hparams.taskID == 1 and not bool(self.hparams.train_path_inp):
        if self.hparams.motion_mode == 0 and self.hparams.ds_mode == 0:
            # Collect all 'motionmg_*' hyperparameters as kwargs.
            motion_params = {k.split('motionmg_')[
                1]: v for k, v in self.hparams.items() if k.startswith('motionmg')}
            self.transforms += [tioMotion.RandomMotionGhostingFast(
                **motion_params), trans.IntensityNorm()]
        elif self.hparams.motion_mode == 1 and self.hparams.ds_mode == 1 and not self.hparams.is3D:
            self.transforms += [pytMotion.Motion2Dv0(
                sigma_range=self.hparams.motion_sigma_range, n_threads=self.hparams.motion_n_threads, p=self.hparams.motion_p, return_meta=self.hparams.motion_return_meta)]
        elif self.hparams.motion_mode == 2 and self.hparams.ds_mode == 1 and not self.hparams.is3D:
            self.transforms += [pytMotion.Motion2Dv1(sigma_range=self.hparams.motion_sigma_range, n_threads=self.hparams.motion_n_threads,
                                                     restore_original=self.hparams.motion_restore_original, p=self.hparams.motion_p, return_meta=self.hparams.motion_return_meta)]
        else:
            sys.exit(
                "Error: invalid motion_mode, ds_mode, is3D combo. Please double check!")
    # Optional static metadata (e.g. sampling masks) from a .mat file.
    self.static_metamat = sio.loadmat(self.hparams.static_metamat_file) if bool(
        self.hparams.static_metamat_file) else None
    if self.hparams.taskID == 0 and self.hparams.use_datacon:
        self.datacon = DataConsistency(
            isRadial=self.hparams.is_radial, metadict=self.static_metamat)
    else:
        self.datacon = None
    # Example input for Lightning's graph/summary utilities; 2D mode drops
    # the last spatial dimension.
    input_shape = self.hparams.input_shape if self.hparams.is3D else self.hparams.input_shape[
        :-1]
    self.example_input_array = torch.empty(
        self.hparams.batch_size, self.hparams.in_channels, *input_shape).float()
    self.saver = ResSaver(
        self.hparams.res_path, save_inp=self.hparams.save_inp, do_norm=self.hparams.do_savenorm)
def loss_new_msssim(x, y):
    """Return the MS-SSIM loss ``1 - MS_SSIM(x, y)`` for two image batches.

    Parameters
    ----------
    x, y : torch.Tensor
        Batches of shape (N, 2, H, W) — the metric is constructed with
        ``channel=2`` and ``data_range=10``, so inputs are presumably
        2-channel maps with values in [0, 10]; TODO confirm against callers.

    Returns
    -------
    torch.Tensor
        Scalar (batch-averaged) loss; 0 for identical inputs.
    """
    # Build the MS_SSIM module once and reuse it across calls: the original
    # re-created the module (and its Gaussian windows) on every invocation.
    metric = getattr(loss_new_msssim, "_metric", None)
    if metric is None:
        metric = MS_SSIM(data_range=10, channel=2)
        loss_new_msssim._metric = metric
    return 1 - metric(x, y)
import torch

# Standalone demo of the pytorch-msssim API on random RGB batches.
# X, Y: (N, 3, H, W) float tensors sampled uniformly in [0, 1).
X = torch.rand(4, 3, 512, 512)
Y = torch.rand(4, 3, 512, 512)
# Y = X

# Functional form, kept for reference:
#   ssim_val = ssim(X, Y, data_range=1.0, size_average=False)       # (N,)
#   ms_ssim_val = ms_ssim(X, Y, data_range=1.0, size_average=False)  # (N,)
# With size_average=True a scalar is returned, directly usable as a loss:
#   ssim_loss = ssim(X, Y, data_range=1.0, size_average=True)
#   ms_ssim_loss = ms_ssim(X, Y, data_range=1.0, size_average=True)

# Module form re-uses the Gaussian window between calls.
ssim_module = SSIM(channel=3, data_range=1.0, win_size=11, win_sigma=1.5, size_average=True)
ms_ssim_module = MS_SSIM(channel=3, data_range=1.0, win_size=11, win_sigma=1.5, size_average=True)

ssim_loss = 1 - ssim_module(X, Y)
ms_ssim_loss = 1 - ms_ssim_module(X, Y)

# Fresh random batches for any follow-up experiments.
X = torch.rand(4, 3, 512, 512)
Y = torch.rand(4, 3, 512, 512)
def MyDNN(opt):
    """Train the MyDNN denoising generator configured by ``opt``.

    Builds the model, optimizer and dataloaders, runs the training loop with
    an L1 + 0.5 * Laplacian-MSE objective, validates PSNR every epoch and
    saves checkpoints via the nested ``save_model`` helper.

    Parameters
    ----------
    opt : namespace-like
        Training configuration; the attributes referenced below (paths, task
        name, batch size, lr schedule, save modes, ...) must exist.

    Side effects: creates ``opt.save_path/opt.task`` and
    ``opt.sample_path/opt.task``, writes sample PNGs and ``.pth`` checkpoint
    files, prints progress to stdout.  Returns ``None``.
    """
    # ----------------------------------------
    #       Network training parameters
    # ----------------------------------------
    # cudnn benchmark: autotune conv algorithms (helps with fixed input sizes)
    cudnn.benchmark = opt.cudnn_benchmark
    # configurations: output folders for checkpoints and sample images
    save_folder = os.path.join(opt.save_path, opt.task)
    sample_folder = os.path.join(opt.sample_path, opt.task)
    if not os.path.exists(save_folder):
        os.makedirs(save_folder)
    if not os.path.exists(sample_folder):
        os.makedirs(sample_folder)
    # Loss functions
    criterion_L1 = torch.nn.L1Loss().cuda()
    criterion_L2 = torch.nn.MSELoss().cuda()  # NOTE(review): only referenced by the commented-out perceptual loss
    mse_loss = nn.MSELoss().cuda()  # NOTE(review): unused below
    # data_range=2 because both operands are shifted by +1 before the MS-SSIM
    # call below — presumably images live in [-1, 1]; TODO confirm
    ms_ssim_module = MS_SSIM(data_range=2, size_average=True, channel=3, nonnegative_ssim=True)
    # Pretrained VGG
    # vgg = MINCFeatureExtractor(opt).cuda()
    # Initialize Generator
    generator = utils.create_MyDNN(opt)
    use_checkpoint = False  # flip to True to resume from the checkpoint below
    if use_checkpoint:
        checkpoint_path = './MyDNN1_denoise_epoch175_bs1'
        # Load a pre-trained network
        pretrained_net = torch.load(checkpoint_path + '.pth')
        load_dict(generator, pretrained_net)
        print('Generator is loaded!')
    # To device
    if opt.multi_gpu:
        generator = nn.DataParallel(generator)
        generator = generator.cuda()
    else:
        generator = generator.cuda()
    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr_g, betas=(opt.b1, opt.b2), weight_decay=opt.weight_decay)

    # Learning rate decrease
    def adjust_learning_rate(opt, epoch, iteration, optimizer):
        # Set the learning rate to the initial LR decayed by "lr_decrease_factor"
        # every "lr_decrease_epoch" epochs.
        # NOTE(review): in 'epoch' mode the decayed lr is immediately overridden
        # by the hard-coded 200/300-epoch step schedule below, so
        # lr_decrease_factor has no effect in that mode.
        # NOTE(review): if lr_decrease_mode is neither 'epoch' nor 'iter',
        # `return lr` raises UnboundLocalError.
        if opt.lr_decrease_mode == 'epoch':
            lr = opt.lr_g * (opt.lr_decrease_factor ** (epoch // opt.lr_decrease_epoch))
            if epoch < 200:
                lr = 0.0001
            if epoch >= 200:
                lr = 0.00005
            if epoch >= 300:
                lr = 0.00001
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        if opt.lr_decrease_mode == 'iter':
            lr = opt.lr_g * (opt.lr_decrease_factor ** (iteration // opt.lr_decrease_iter))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        return lr

    # Save the model if pre_train == True
    def save_model(opt, epoch, iteration, len_dataset, generator, val_PSNR, best_PSNR):
        """Save the model at "checkpoint_interval" and its multiple"""
        # Best-so-far checkpoint (saves the whole model object, not a state_dict)
        if opt.save_best_model and best_PSNR == val_PSNR:
            torch.save(generator, 'final_%s_epoch%d_best.pth' % (opt.task, epoch))
            print('The best model is successfully saved at epoch %d' % (epoch))
        if opt.multi_gpu == True:
            # DataParallel case: unwrap via .module before saving
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch == 0) and (iteration % len_dataset == 0):
                    if opt.save_name_mode:
                        torch.save(generator.module, 'MyDNN1_%s_epoch%d_bs%d.pth' % (opt.task, epoch, opt.batch_size))
                        print('The trained model is successfully saved at epoch %d' % (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    if opt.save_name_mode:
                        torch.save(generator.module, 'MyDNN1_%s_iter%d_bs%d.pth' % (opt.task, iteration, opt.batch_size))
                        print('The trained model is successfully saved at iteration %d' % (iteration))
        else:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch == 0) and (iteration % len_dataset == 0):
                    if opt.save_name_mode:
                        torch.save(generator, 'final_%s_epoch%d_bs%d.pth' % (opt.task, epoch, opt.batch_size))
                        print('The trained model is successfully saved at epoch %d' % (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    if opt.save_name_mode:
                        torch.save(generator, 'final_%s_iter%d_bs%d.pth' % (opt.task, iteration, opt.batch_size))
                        print('The trained model is successfully saved at iteration %d' % (iteration))

    # ----------------------------------------
    #             Network dataset
    # ----------------------------------------
    # Define the dataloader
    # trainset = dataset.TestDataset(opt)
    trainset = dataset.Noise2CleanDataset(opt)
    print('The overall number of training images:', len(trainset))
    testset = dataset.TestDataset(opt)
    valset = dataset.ValDataset(opt)
    print('The overall number of val images:', len(valset))
    # Define the dataloader
    # NOTE(review): shuffle=True on the val/test loaders is unusual — metrics
    # are order-independent here, but consider shuffle=False for reproducible
    # sample dumps.  test_loader is never used below.
    dataloader = DataLoader(trainset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.num_workers, pin_memory=True)
    val_loader = DataLoader(valset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.num_workers, pin_memory=True)
    test_loader = DataLoader(testset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.num_workers, pin_memory=True)

    # ----------------------------------------
    #                 Training
    # ----------------------------------------
    # Count start time
    prev_time = time.time()
    best_PSNR = 0
    # For loop training
    for epoch in range(opt.epochs):
        # per-epoch running sums for the log line after the batch loop
        total_loss = 0
        total_ploss = 0
        total_sobel = 0
        total_Lap = 0
        for i, (true_input, simulated_input, true_target, noise_level_map) in enumerate(dataloader):
            # To device
            # NOTE(review): simulated_input / noise_level_map are moved to GPU
            # but never used in this loop.
            true_input = true_input.cuda()
            true_target = true_target.cuda()
            simulated_input = simulated_input.cuda()
            noise_level_map = noise_level_map.cuda()
            # Train Generator
            optimizer_G.zero_grad()
            pre_clean = generator(true_input)
            # Parse through VGGMINC layers
            # features_y = vgg(pre_clean)
            # features_x = vgg(true_input)
            # content_loss = criterion_L2(features_y, features_x).
            # Grayscale copies of the FIRST batch item only, detached to numpy
            # for the OpenCV edge losses below.
            pre = pre_clean[0, :, :, :].data.permute(1, 2, 0).cpu().numpy()
            pre = rgb2gray(pre)
            true = true_input[0, :, :, :].data.permute(1, 2, 0).cpu().numpy()
            true = rgb2gray(true)
            # CV_32F output depth (CV_64F would be double precision)
            laplacian_pre = cv2.Laplacian(pre, cv2.CV_32F)
            laplacian_gt = cv2.Laplacian(true, cv2.CV_32F)
            # (1, 0) -> first derivative along x; (0, 1) -> along y
            sobel_pre = 0.5 * (cv2.Sobel(pre, cv2.CV_32F, 1, 0, ksize=5) + cv2.Sobel(pre, cv2.CV_32F, 0, 1, ksize=5))
            sobel_gt = 0.5 * (cv2.Sobel(true, cv2.CV_32F, 1, 0, ksize=5) + cv2.Sobel(true, cv2.CV_32F, 0, 1, ksize=5))
            sobel_loss = mean_squared_error(sobel_pre, sobel_gt)
            laplacian_loss = mean_squared_error(laplacian_pre, laplacian_gt)
            # L1 Loss
            Pixellevel_L1_Loss = criterion_L1(pre_clean, true_target)
            # MS-SSIM loss (+1 shift maps presumed [-1,1] data into [0,2])
            # NOTE(review): ms_ssim_loss is computed but NOT added to `loss`.
            ms_ssim_loss = 1 - ms_ssim_module(pre_clean + 1, true_target + 1)
            # Overall Loss and optimize
            # NOTE(review): laplacian_loss comes from detached numpy/cv2 data,
            # so it is a constant w.r.t. the network parameters — it shifts the
            # reported loss but contributes no gradient; only the L1 term trains.
            loss = Pixellevel_L1_Loss + 0.5 * laplacian_loss
            # loss = Pixellevel_L1_Loss
            loss.backward()
            optimizer_G.step()
            # Determine approximate time left
            iters_done = epoch * len(dataloader) + i
            iters_left = opt.epochs * len(dataloader) - iters_done
            time_left = datetime.timedelta(seconds=iters_left * (time.time() - prev_time))
            prev_time = time.time()
            # accumulate epoch totals (sobel/laplacian are numpy scalars)
            total_loss = Pixellevel_L1_Loss.item() + total_loss
            # total_ploss = content_loss.item() + total_ploss
            total_sobel = sobel_loss + total_sobel
            total_Lap = laplacian_loss + total_Lap
            # # Print log
            print("\r[Epoch %d/%d] [Batch %d/%d] [Pixellevel L1 Loss: %.4f] [laplacian_loss Loss: %.4f] [sobel_loss Loss: %.4f] Time_left: %s" %
                  ((epoch + 1), opt.epochs, i, len(dataloader), Pixellevel_L1_Loss.item(), laplacian_loss.item(), sobel_loss.item(), time_left))
        # Sample the last batch of the epoch as PNGs (file name is epoch-based)
        img_list = [pre_clean, true_target, true_input]
        name_list = ['pred', 'gt', 'noise']
        utils.save_sample_png(sample_folder=sample_folder, sample_name='MyDNN_MS_epoch%d' % (epoch + 1), img_list=img_list, name_list=name_list, pixel_max_cnt=255)
        # Learning rate decrease at certain epochs
        lr = adjust_learning_rate(opt, (epoch + 1), (iters_done + 1), optimizer_G)
        # Epoch-average log line.
        # NOTE(review): the divisor 320 is a hard-coded dataset/iteration count;
        # should presumably be len(dataloader) — TODO confirm.
        print("\r[Epoch %d/%d] [Batch %d/%d] [Pixellevel L1 Loss: %.4f] [laplacian_loss Loss: %.4f] [sobel_loss Loss: %.4f] Time_left: %s" %
              ((epoch + 1), opt.epochs, i, len(dataloader), total_loss / 320, total_Lap / 320, total_sobel / 320, time_left))
        ### Validation
        val_PSNR = 0
        be_PSNR = 0  # PSNR of the noisy input vs. target ("before" denoising)
        num_of_val_image = 0
        for j, (true_input, simulated_input, true_target, noise_level_map) in enumerate(val_loader):
            # To device
            # A is for input image, B is for target image
            true_input = true_input.cuda()
            true_target = true_target.cuda()
            # Forward propagation
            with torch.no_grad():
                pre_clean = generator(true_input)
            # Accumulate num of image and val_PSNR (weighted by batch size so a
            # smaller final batch doesn't skew the average)
            num_of_val_image += true_input.shape[0]
            val_PSNR += utils.psnr(pre_clean, true_target, 255) * true_input.shape[0]
            be_PSNR += utils.psnr(true_input, true_target, 255) * true_input.shape[0]
        val_PSNR = val_PSNR / num_of_val_image
        be_PSNR = be_PSNR / num_of_val_image
        # Record average PSNR
        print('PSNR at epoch %d: %.4f' % ((epoch + 1), val_PSNR))
        print('PSNR before denoising %d: %.4f' % ((epoch + 1), be_PSNR))
        best_PSNR = max(val_PSNR, best_PSNR)
        # Save model at certain epochs or iterations
        save_model(opt, (epoch + 1), (iters_done + 1), len(dataloader), generator, val_PSNR, best_PSNR)