def __init__(self, config, data_loader):
    """Set parameters of the neural network and its training."""
    self.ssim_loss = SSIM()
    self.generator = None
    self.discriminator = None
    self.distance_based_loss = None
    self.g_optimizer = None
    self.d_optimizer = None

    # Architecture and optimizer hyperparameters.
    self.g_conv_dim = 128
    self.beta1 = 0.9
    self.beta2 = 0.999
    self.learning_rate = 0.0001

    # Training configuration taken from `config`.
    self.image_size = config.image_size
    self.num_epochs = config.num_epochs
    self.distance_weight = config.distance_weight
    self.noise = config.noise
    self.residual = config.residual
    self.data_loader = data_loader
    self.generate_path = config.generate_path
    self.model_path = config.model_path

    # Optional TensorBoard logging.
    self.tensorboard = config.tensorboard
    if self.tensorboard:
        self.tb_writer = tensorboardX.SummaryWriter(
            filename_suffix='_%s_%s' % (config.distance_weight, config.dataset))
        self.tb_graph_added = False

    self.build_model()
def train(epoch, loader, model, optimizer, scheduler, device):
    loader = tqdm(loader)

    criterion = SSIM()  # alternative: nn.MSELoss()
    latent_loss_weight = 0.25
    sample_size = 25

    mse_sum = 0
    mse_n = 0

    for i, (img, label) in enumerate(loader):
        model.zero_grad()

        img = img.to(device)
        out, latent_loss = model(img)

        # Map the [-1, 1] normalized tensors back to [0, 1] before computing the SSIM-based loss.
        recon_loss = criterion(out * 0.5 + 0.5, img * 0.5 + 0.5)
        latent_loss = latent_loss.mean()
        loss = recon_loss + latent_loss_weight * latent_loss
        loss.backward()

        optimizer.step()
        # Step the scheduler after the optimizer so the first learning-rate value is not skipped.
        if scheduler is not None:
            scheduler.step()

        mse_sum += recon_loss.item() * img.shape[0]
        mse_n += img.shape[0]

        lr = optimizer.param_groups[0]['lr']
        loader.set_description(
            (
                f'epoch: {epoch + 1}; mse: {recon_loss.item():.5f}; '
                f'latent: {latent_loss.item():.3f}; avg mse: {mse_sum / mse_n:.5f}; '
                f'lr: {lr:.5f}'
            )
        )

        if i % 100 == 0:
            # Periodically dump a grid of inputs and reconstructions.
            model.eval()
            sample = img[:sample_size]
            with torch.no_grad():
                out, _ = model(sample)

            utils.save_image(
                torch.cat([sample, out], 0),
                f'sample/{str(epoch + 1).zfill(5)}_{str(i).zfill(5)}.png',
                nrow=sample_size,
                normalize=True,
                range=(-1, 1),
            )
            model.train()
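# Hedged note (not part of the snippet above): minimizing the SSIM criterion directly,
# as train() does, only makes sense if the SSIM class already returns a dissimilarity.
# If it instead returns the similarity in [0, 1], a thin wrapper such as the hypothetical
# SSIMLoss below turns it into a minimizable reconstruction loss.
import torch.nn as nn


class SSIMLoss(nn.Module):
    """Wrap a similarity-returning SSIM module so that lower values mean better reconstructions."""

    def __init__(self, ssim_module):
        super().__init__()
        self.ssim = ssim_module

    def forward(self, pred, target):
        # Higher SSIM means more similar, so minimize 1 - SSIM.
        return 1.0 - self.ssim(pred, target)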
def train(self):
    self.scheduler.step()
    self.loss.step()
    epoch = self.scheduler.last_epoch + 1
    lr = self.scheduler.get_lr()[0]
    self.args.Noisy = True

    self.ckp.write_log(
        '[Epoch {}]\tLearning rate: {:.2e}'.format(epoch, Decimal(lr))
    )
    self.loss.start_log()
    self.model.train()

    criterion_ssim = SSIM(window_size=11, size_average=False)
    criterion_ssim = criterion_ssim.cuda()

    timer_data, timer_model = utility.timer(), utility.timer()
    for batch, (lr, hr, _, idx_scale) in enumerate(self.loader_train):
        lr, hr = self.prepare(lr, hr)
        timer_data.hold()
        timer_model.tic()

        self.optimizer.zero_grad()
        sr = self.model(lr, idx_scale)
        loss = self.loss(sr, hr)  # + self.ssim * criterion_ssim(sr, hr)
        loss.backward()
        if self.args.gclip > 0:
            utils.clip_grad_value_(
                self.model.parameters(),
                self.args.gclip
            )
        self.optimizer.step()

        timer_model.hold()

        if (batch + 1) % self.args.print_every == 0:
            self.ckp.write_log('[{}/{}]\t{}\t{:.1f}+{:.1f}s'.format(
                (batch + 1) * self.args.batch_size,
                len(self.loader_train.dataset),
                self.loss.display_loss(batch),
                timer_model.release(),
                timer_data.release()))

        timer_data.tic()

    self.loss.end_log(len(self.loader_train))
    self.error_last = self.loss.log[-1, -1]
def __init__(self):
    super(Model, self).__init__()
    # Losses and pooling used throughout the model.
    self.cross_entropy = nn.CrossEntropyLoss()
    self.mse = nn.MSELoss(reduction='mean')
    self.l1 = nn.L1Loss()
    self.SL1 = nn.SmoothL1Loss()
    self.ssim = SSIM(window_size=11)
    self.avg = nn.AdaptiveAvgPool2d(1)

    self.predict_net = PredictNet(64)
    self.device = next(self.predict_net.parameters()).device
    self.resnet = resnet50_backbone(pretrained=True)

    self.regression = nn.Sequential(
        nn.Conv2d(1 * 2, 1, kernel_size=1, stride=1, padding=0),
        nn.LeakyReLU(inplace=True),
    )
    self.res_out = nn.Sequential(
        nn.Conv2d(2048, 1024, kernel_size=1, stride=1, padding=0),
        nn.LeakyReLU(inplace=True),
        nn.Conv2d(1024, 512, kernel_size=1, stride=1, padding=0),
        nn.LeakyReLU(inplace=True),
        nn.Conv2d(512, 64, kernel_size=1, stride=1, padding=0),
        nn.LeakyReLU(inplace=True),
    )

    init_nets = [self.regression, self.predict_net, self.res_out]
    for net in init_nets:
        net.apply(weights_init_xavier)

    def get_log_diff_fn(eps=0.2):
        # Constants for the normalized log-difference map: 2*log(255) and the
        # maximum value attainable for the given eps.
        log_255_sq = np.float32(2 * np.log(255.0)).item()  # python float
        max_val = np.float32(log_255_sq - np.log(eps)).item()  # python float
        log_255_sq = torch.tensor(log_255_sq, dtype=torch.float32, device=self.device)
        max_val = torch.tensor(max_val, dtype=torch.float32, device=self.device)

        def log_diff_fn(in_a, in_b):
            diff = 255.0 * (in_a - in_b)
            val = log_255_sq - torch.log(diff ** 2 + eps)
            return val / max_val

        return log_diff_fn

    self.log_diff_fn = get_log_diff_fn(1)
    self.downsample_filter = DownSampleFilter()
    self.upsample_filter = UpSampleFilter()
def compute_loss_ssim(self, img_pyramid, img_warped_pyramid, occ_mask_list):
    loss_list = []
    for scale in range(self.num_scales):
        img, img_warped, occ_mask = img_pyramid[scale], img_warped_pyramid[scale], occ_mask_list[scale]
        divider = occ_mask.mean((1, 2, 3))
        occ_mask_pad = occ_mask.repeat(1, 3, 1, 1)
        # SSIM is used here as a functional that returns a per-pixel similarity map.
        ssim = SSIM(img * occ_mask_pad, img_warped * occ_mask_pad)
        loss_ssim = torch.clamp((1.0 - ssim) / 2.0, 0, 1).mean((1, 2, 3))
        # Rescale by the mask's mean coverage so smaller visible regions are not under-weighted.
        loss_ssim = loss_ssim / (divider + 1e-12)
        loss_list.append(loss_ssim[:, None])
    loss = torch.cat(loss_list, 1).sum(1)
    return loss
def __init__(self, model, lr=1e-1, n_epochs=10, verbose=True, dir_base='./output/checkpoints'):
    self.model = model
    self.lr = lr
    self.n_epochs = n_epochs
    self.verbose = verbose

    # Initializations
    self.dir_base = dir_base
    if not os.path.exists(self.dir_base):
        os.makedirs(self.dir_base)
    self.dir_log = f'{dir_base}/log.txt'
    self.best_summary_loss = 10**5
    self.epoch = 0

    self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    self.model.to(self.device)

    # Define the optimizer
    self.params = [p for p in self.model.parameters() if p.requires_grad]
    self.optimizer = torch.optim.AdamW(self.params, lr=self.lr)
    self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer=self.optimizer,
        mode='min',
        factor=0.5,
        patience=1,
        verbose=True,
        threshold=0.00005,
        threshold_mode='abs',
        cooldown=0,
        min_lr=1e-8,
        eps=1e-8)

    # Define the loss
    self.criterion = SSIM()

    self.log('====================================================')
    self.log(f'Fitter prepared | Time: {datetime.utcnow().isoformat()} | Device: {self.device}')
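# Hedged sketch (names are assumptions, not the original Fitter API): one optimization
# step consistent with the setup above, where the SSIM-based criterion is minimized and
# ReduceLROnPlateau is expected to be stepped elsewhere on the validation summary loss.
def train_one_step(fitter, images, targets):
    fitter.model.train()
    fitter.optimizer.zero_grad()
    preds = fitter.model(images.to(fitter.device))
    # The criterion defined in __init__ is SSIM(); it is assumed here to return a value
    # where lower is better, matching how best_summary_loss is tracked.
    loss = fitter.criterion(preds, targets.to(fitter.device))
    loss.backward()
    fitter.optimizer.step()
    return loss.item()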
def compute_pairwise_loss(self, tgt_img, ref_img, tgt_depth, ref_depth, pose, intrinsic):
    ref_img_warped, valid_mask, projected_depth, computed_depth = inverse_warp2(
        ref_img, tgt_depth, ref_depth, pose, intrinsic, 'zeros')

    diff_img = (tgt_img - ref_img_warped).abs()
    diff_depth = ((computed_depth - projected_depth).abs() /
                  (computed_depth + projected_depth).abs()).clamp(0, 1)

    ssim_map = (0.5 * (1 - SSIM(tgt_img, ref_img_warped))).clamp(0, 1)
    diff_img = 0.15 * diff_img + 0.85 * ssim_map

    # Modified in 01.19.2020
    # weight_mask = (1 - diff_depth)
    # diff_img = diff_img * weight_mask

    # compute loss
    reconstruction_loss = diff_img.mean()
    geometry_consistency_loss = diff_depth.mean()
    # reconstruction_loss = mean_on_mask(diff_img, valid_mask)
    # geometry_consistency_loss = mean_on_mask(diff_depth, valid_mask)

    return reconstruction_loss, geometry_consistency_loss
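# Minimal sketch (an assumption, not the SSIM actually imported by these snippets) of a
# functional per-pixel SSIM: both compute_loss_ssim and compute_pairwise_loss call
# SSIM(x, y) like a plain function and expect a similarity map the same size as x.
import torch
import torch.nn.functional as F


def ssim_map(x, y, C1=0.01 ** 2, C2=0.03 ** 2):
    # Local statistics via 3x3 average pooling; reflection padding keeps the spatial size.
    x = F.pad(x, (1, 1, 1, 1), mode='reflect')
    y = F.pad(y, (1, 1, 1, 1), mode='reflect')
    mu_x = F.avg_pool2d(x, 3, 1)
    mu_y = F.avg_pool2d(y, 3, 1)
    sigma_x = F.avg_pool2d(x * x, 3, 1) - mu_x ** 2
    sigma_y = F.avg_pool2d(y * y, 3, 1) - mu_y ** 2
    sigma_xy = F.avg_pool2d(x * y, 3, 1) - mu_x * mu_y
    numerator = (2 * mu_x * mu_y + C1) * (2 * sigma_xy + C2)
    denominator = (mu_x ** 2 + mu_y ** 2 + C1) * (sigma_x + sigma_y + C2)
    return numerator / denominator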
def main():
    global args, best_loss
    global logger
    global device, kwargs

    if args.model_type == 'fcn':
        filter_list = [
            1,
            int(args.model_multiplier * 4),
            int(args.model_multiplier * 8),
            int(args.model_multiplier * 16),
            int(args.model_multiplier * 32),
            int(args.model_multiplier * 64),
            10
        ]
        print('Model filter sizes list is {}'.format(filter_list))
        model = VAE(filters=filter_list,
                    dilations=[1, 1, 1, 1, 1, 1],
                    paddings=[0, 0, 0, 0, 0, 0],
                    strides=[1, 1, 2, 1, 2, 2],
                    decoder_kernels=[3, 4, 4, 4, 4, 4],
                    decoder_paddings=[1, 0, 0, 0, 0, 0],
                    decoder_strides=[1, 1, 1, 2, 2, 1],
                    split_filter=args.split_filter).to(device)
        print(model)
    elif args.model_type == 'fcns_1n':
        filter_list = [
            1,
            int(args.model_multiplier * 4),
            int(args.model_multiplier * 8),
            int(args.model_multiplier * 16),
        ]
        print('Model filter sizes list is {}'.format(filter_list))
        model = VAE1N(filters=filter_list,
                      dilations=[1, 1, 1],
                      paddings=[1, 1, 1],
                      strides=[2, 2, 2],
                      decoder_kernels=[4, 4, 3],
                      decoder_paddings=[1, 1, 1],
                      decoder_strides=[2, 2, 2],
                      latent_space_size=10).to(device)
    elif args.model_type == 'fcns':
        model = VAESimplifiedFC().to(device)
    elif args.model_type == 'fc':
        model = VAEBaseline(latent_space_size=args.latent_space_size).to(device)
    elif args.model_type == 'fc_conv':
        model = VAEBaselineConv(latent_space_size=args.latent_space_size).to(device)

    # Only finetunable params are passed to the optimizer.
    if args.optimizer.startswith('adam'):
        optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=args.lr)
    elif args.optimizer.startswith('rmsprop'):
        optimizer = torch.optim.RMSprop(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=args.lr)
    elif args.optimizer.startswith('sgd'):
        optimizer = torch.optim.SGD(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=args.lr)
    else:
        raise ValueError('Optimizer not supported')

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_loss = checkpoint['best_loss']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    if args.predict:
        pass
    elif args.evaluate:
        pass
    else:
        if args.dataset_type == 'fmnist':
            train_dataset = FMNISTDataset(mode='train',
                                          random_state=args.seed,
                                          use_augs=args.do_augs)
            val_dataset = FMNISTDataset(mode='val',
                                        random_state=args.seed,
                                        use_augs=False)
            train_loader = torch.utils.data.DataLoader(
                train_dataset,
                batch_size=args.batch_size,
                shuffle=True,
                drop_last=False,
                **kwargs)
            val_loader = torch.utils.data.DataLoader(
                val_dataset,
                batch_size=args.batch_size,
                shuffle=True,
                drop_last=False,
                **kwargs)
        elif args.dataset_type == 'mnist':
            train_loader = torch.utils.data.DataLoader(
                datasets.MNIST('../data',
                               train=True,
                               download=True,
                               transform=transforms.ToTensor()),
                batch_size=args.batch_size,
                shuffle=True,
                **kwargs)
            val_loader = torch.utils.data.DataLoader(
                datasets.MNIST('../data',
                               train=False,
                               transform=transforms.ToTensor()),
                batch_size=args.batch_size,
                shuffle=True,
                **kwargs)

        criterion = VAELoss(
            use_running_mean=args.do_running_mean,
            image_loss_type=args.image_loss_type,
            image_loss_weight=args.img_loss_weight,
            kl_loss_weight=args.kl_loss_weight,
            ssim_window_size=args.ssim_window_size,
            latent_space_size=args.latent_space_size).to(device)
        # criterion = loss_function
        ssim = SSIM(window_size=args.ssim_window_size, size_average=True).to(device)

        scheduler = MultiStepLR(optimizer, milestones=[args.m1, args.m2], gamma=0.1)

        for epoch in range(args.start_epoch, args.epochs):
            # train for one epoch
            train_loss, train_img_loss, train_kl_loss, train_ssim = train(
                train_loader, model, criterion, ssim, optimizer, epoch)

            # evaluate on validation set
            val_loss, val_img_loss, val_kl_loss, val_ssim = validate(
                val_loader, model, criterion, ssim)

            scheduler.step()

            # ============ TensorBoard logging ============ #
            # Log the scalar values
            if args.tensorboard:
                info = {
                    'eph_tr_loss': train_loss,
                    'eph_tr_ssim': train_ssim,
                    'eph_val_loss': val_loss,
                    'eph_val_ssim': val_ssim,
                    'eph_tr_img_loss': train_img_loss,
                    'eph_tr_kl_loss': train_kl_loss,
                    'eph_val_img_loss': val_img_loss,
                    'eph_val_kl_loss': val_kl_loss,
                }
                for tag, value in info.items():
                    logger.scalar_summary(tag, value, epoch + 1)

            # remember best (lowest) validation loss and save checkpoint
            is_best = val_loss < best_loss
            best_loss = min(val_loss, best_loss)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'optimizer': optimizer.state_dict(),
                    'state_dict': model.state_dict(),
                    'best_loss': best_loss,
                },
                is_best,
                'weights/{}_checkpoint.pth.tar'.format(str(args.lognumber)),
                'weights/{}_best.pth.tar'.format(str(args.lognumber)))
G = model.VGG_VAE(5)

D.apply(weights_init)
G.apply(weights_init)
D.cuda()
G.cuda()

print(D)
print(G)

# Discriminator loss and optimizer.
D_criterion = torch.nn.BCEWithLogitsLoss().cuda()
D_optimizer = torch.optim.SGD(D.parameters(), lr=1e-3)

# Generator losses (adversarial plus reconstruction terms) and optimizer.
G_criterion = torch.nn.BCEWithLogitsLoss().cuda()
G_l1 = torch.nn.L1Loss().cuda()
G_msssim = MSSSIM().cuda()
G_ssim = SSIM().cuda()
G_optimizer = torch.optim.Adam(G.parameters(), lr=1e-3)

pathlib.Path(sample_output).mkdir(parents=True, exist_ok=True)
pathlib.Path(os.path.join(sample_output, "images")).mkdir(parents=True, exist_ok=True)

# Loss thresholds and flags controlling whether D and G are updated.
d_loss = 0
g_loss = 0
d_to_g_threshold = 0.5
g_to_d_threshold = 0.3
train_d = True
train_g = True
conditional_training = False
_si = 1
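# Hedged sketch (not part of the original script): a common way to combine the generator
# reconstruction criteria defined above is a weighted mix of (1 - MS-SSIM) and L1 with
# alpha = 0.84, following Zhao et al., "Loss Functions for Image Restoration with Neural
# Networks". The function name and weighting are assumptions, and MSSSIM is assumed to
# return a similarity in [0, 1].
def mixed_recon_loss(fake, real, alpha=0.84):
    # Structural term from MS-SSIM plus a pixel-wise L1 term.
    return alpha * (1.0 - G_msssim(fake, real)) + (1.0 - alpha) * G_l1(fake, real)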