# SRGAN-style training script. Project-level helpers (Generator, Discriminator,
# AugmentPipe, the *DatasetFromFolder classes, SingleTensorDataset,
# display_transform) and the module-level constants (device, PATCH_SIZE,
# UPSCALE_FACTOR, BATCH_SIZE, VAL_BATCH_SIZE, NUM_RESIDUAL_BLOCKS,
# NUM_ADV_BASELINE_EPOCHS, NUM_BASELINE_PRETRAIN_EPOCHS, ADV_LOSS_BALANCER,
# RT_BATCH_SMOOTHING_FACTOR, AUGMENT_PROB_TARGET, AUGMENT_PROBABABILITY_STEP,
# CENTER_CROP_SIZE, NUM_LOGGED_VALIDATION_IMAGES, train_dataset_dir,
# val_dataset_dir) are assumed to be defined elsewhere in the repository.
import datetime
import gc
from argparse import ArgumentParser
from math import log10
from pathlib import Path
from statistics import mean

import lpips
import numpy as np
import pytorch_ssim
import torch
from torch import optim
from torch.nn import BCELoss, MSELoss
from torch.utils.data import ConcatDataset, DataLoader, Subset
from torch.utils.tensorboard import SummaryWriter
from torch_fidelity import calculate_metrics
from torchvision import utils
from tqdm import tqdm


def main():
    parser = ArgumentParser()
    parser.add_argument("--augmentation", action='store_true')
    parser.add_argument("--train-dataset-percentage", type=float, default=100)
    parser.add_argument("--val-dataset-percentage", type=float, default=100)
    parser.add_argument("--label-smoothing", type=float, default=0.9)
    parser.add_argument("--validation-frequency", type=int, default=1)
    args = parser.parse_args()

    ENABLE_AUGMENTATION = args.augmentation
    TRAIN_DATASET_PERCENTAGE = args.train_dataset_percentage
    VAL_DATASET_PERCENTAGE = args.val_dataset_percentage
    LABEL_SMOOTHING_FACTOR = args.label_smoothing
    VALIDATION_FREQUENCY = args.validation_frequency

    if ENABLE_AUGMENTATION:
        augment_batch = AugmentPipe()
        augment_batch.to(device)
    else:
        augment_batch = lambda x: x  # identity: no augmentation
        augment_batch.p = 0

    # Scale the epoch budget and validation frequency so that the total number
    # of images seen during training stays constant when a dataset subset is used.
    NUM_ADV_EPOCHS = round(NUM_ADV_BASELINE_EPOCHS / (TRAIN_DATASET_PERCENTAGE / 100))
    NUM_PRETRAIN_EPOCHS = round(NUM_BASELINE_PRETRAIN_EPOCHS / (TRAIN_DATASET_PERCENTAGE / 100))
    VALIDATION_FREQUENCY = max(1, round(VALIDATION_FREQUENCY / (TRAIN_DATASET_PERCENTAGE / 100)))

    training_start = datetime.datetime.now().isoformat()

    train_set = TrainDatasetFromFolder(train_dataset_dir, patch_size=PATCH_SIZE,
                                       upscale_factor=UPSCALE_FACTOR)
    len_train_set = len(train_set)
    train_set = Subset(
        train_set,
        list(np.random.choice(np.arange(len_train_set),
                              int(len_train_set * TRAIN_DATASET_PERCENTAGE / 100),
                              False)))
    val_set = ValDatasetFromFolder(val_dataset_dir, upscale_factor=UPSCALE_FACTOR)
    len_val_set = len(val_set)
    val_set = Subset(
        val_set,
        list(np.random.choice(np.arange(len_val_set),
                              int(len_val_set * VAL_DATASET_PERCENTAGE / 100),
                              False)))

    train_loader = DataLoader(dataset=train_set, num_workers=8, batch_size=BATCH_SIZE,
                              shuffle=True, pin_memory=True, prefetch_factor=8)
    val_loader = DataLoader(dataset=val_set, num_workers=2, batch_size=VAL_BATCH_SIZE,
                            shuffle=False, pin_memory=True, prefetch_factor=2)
    epoch_validation_hr_dataset = HrValDatasetFromFolder(val_dataset_dir)  # used to compute the FID metric

    results_folder = Path(
        f"results_{training_start}_CS:{PATCH_SIZE}_US:{UPSCALE_FACTOR}x"
        f"_TRAIN:{TRAIN_DATASET_PERCENTAGE}%_AUGMENTATION:{ENABLE_AUGMENTATION}")
    results_folder.mkdir(exist_ok=True)
    writer = SummaryWriter(str(results_folder / "tensorboard_log"))

    g_net = Generator(n_residual_blocks=NUM_RESIDUAL_BLOCKS, upsample_factor=UPSCALE_FACTOR)
    d_net = Discriminator(patch_size=PATCH_SIZE)
    lpips_metric = lpips.LPIPS(net='alex')
    g_net.to(device=device)
    d_net.to(device=device)
    lpips_metric.to(device=device)

    g_optimizer = optim.Adam(g_net.parameters(), lr=1e-4)
    d_optimizer = optim.Adam(d_net.parameters(), lr=1e-4)

    bce_loss = BCELoss()
    mse_loss = MSELoss()
    bce_loss.to(device=device)
    mse_loss.to(device=device)

    results = {
        'd_total_loss': [], 'g_total_loss': [], 'g_adv_loss': [],
        'g_content_loss': [], 'd_real_mean': [], 'd_fake_mean': [],
        'psnr': [], 'ssim': [], 'lpips': [], 'fid': [], 'rt': [],
        'augment_probability': []
    }

    augment_probability = 0
    num_images = len(train_set) * (NUM_PRETRAIN_EPOCHS + NUM_ADV_EPOCHS)
    prediction_list = []
    rt = 0

    for epoch in range(1, NUM_PRETRAIN_EPOCHS + NUM_ADV_EPOCHS + 1):
        train_bar = tqdm(train_loader, ncols=200)
        running_results = {
            'batch_sizes': 0, 'd_epoch_total_loss': 0, 'g_epoch_total_loss': 0,
            'g_epoch_adv_loss': 0, 'g_epoch_content_loss': 0,
            'd_epoch_real_mean': 0, 'd_epoch_fake_mean': 0, 'rt': 0,
            'augment_probability': 0
        }
        image_percentage = epoch / (NUM_PRETRAIN_EPOCHS + NUM_ADV_EPOCHS) * 100
        # Validate on the first and last epoch and every VALIDATION_FREQUENCY epochs.
        validation_epoch = (epoch == 1
                            or epoch == NUM_PRETRAIN_EPOCHS + NUM_ADV_EPOCHS
                            or VALIDATION_FREQUENCY == 1
                            or epoch % VALIDATION_FREQUENCY == 0)

        g_net.train()
        d_net.train()
        for data, target in train_bar:
            augment_batch.p = torch.tensor([augment_probability], device=device)
            batch_size = data.size(0)
            running_results["batch_sizes"] += batch_size
            target = target.to(device)
            data = data.to(device)
            real_labels = torch.ones(batch_size, device=device)
            fake_labels = torch.zeros(batch_size, device=device)

            if epoch > NUM_PRETRAIN_EPOCHS:
                # Discriminator training
                d_optimizer.zero_grad(set_to_none=True)
                d_real_output = d_net(augment_batch(target))
                d_real_output_loss = bce_loss(d_real_output,
                                              real_labels * LABEL_SMOOTHING_FACTOR)
                fake_img = g_net(data)
                d_fake_output = d_net(augment_batch(fake_img))
                d_fake_output_loss = bce_loss(d_fake_output, fake_labels)
                d_total_loss = d_real_output_loss + d_fake_output_loss
                d_total_loss.backward()
                d_optimizer.step()
                d_real_mean = d_real_output.mean()
                d_fake_mean = d_fake_output.mean()

            # Generator training
            g_optimizer.zero_grad(set_to_none=True)
            fake_img = g_net(data)
            if epoch > NUM_PRETRAIN_EPOCHS:
                adversarial_loss = bce_loss(d_net(augment_batch(fake_img)),
                                            real_labels) * ADV_LOSS_BALANCER
                content_loss = mse_loss(fake_img, target)
                g_total_loss = content_loss + adversarial_loss
            else:
                # Logging purposes only: the adversarial term is always zero
                # during pretraining.
                adversarial_loss = mse_loss(torch.zeros(1, device=device),
                                            torch.zeros(1, device=device))
                content_loss = mse_loss(fake_img, target)
                g_total_loss = content_loss
            g_total_loss.backward()
            g_optimizer.step()

            if epoch > NUM_PRETRAIN_EPOCHS and ENABLE_AUGMENTATION:
                # ADA-style heuristic: r_t is the mean sign of the discriminator's
                # outputs on real images. If it overshoots the target, the
                # discriminator is overfitting, so raise the augmentation
                # probability; otherwise lower it.
                prediction_list.append((torch.sign(d_real_output - 0.5)).tolist())
                if len(prediction_list) == RT_BATCH_SMOOTHING_FACTOR:
                    rt_list = [prediction for sublist in prediction_list
                               for prediction in sublist]
                    rt = mean(rt_list)
                    if rt > AUGMENT_PROB_TARGET:
                        augment_probability = min(
                            0.85, augment_probability + AUGMENT_PROBABABILITY_STEP)
                    else:
                        augment_probability = max(
                            0., augment_probability - AUGMENT_PROBABABILITY_STEP)
                    prediction_list.clear()

            running_results['g_epoch_total_loss'] += g_total_loss.detach().to(
                'cpu', non_blocking=True) * batch_size
            running_results['g_epoch_adv_loss'] += adversarial_loss.detach().to(
                'cpu', non_blocking=True) * batch_size
            running_results['g_epoch_content_loss'] += content_loss.detach().to(
                'cpu', non_blocking=True) * batch_size
            if epoch > NUM_PRETRAIN_EPOCHS:
                running_results['d_epoch_total_loss'] += d_total_loss.detach().to(
                    'cpu', non_blocking=True) * batch_size
                running_results['d_epoch_real_mean'] += d_real_mean.detach().to(
                    'cpu', non_blocking=True) * batch_size
                running_results['d_epoch_fake_mean'] += d_fake_mean.detach().to(
                    'cpu', non_blocking=True) * batch_size
            running_results['rt'] += rt * batch_size
            running_results['augment_probability'] += augment_probability * batch_size

            train_bar.set_description(
                desc=f'[{epoch}/{NUM_ADV_EPOCHS + NUM_PRETRAIN_EPOCHS}] '
                f'Loss_D: {running_results["d_epoch_total_loss"] / running_results["batch_sizes"]:.4f} '
                f'Loss_G: {running_results["g_epoch_total_loss"] / running_results["batch_sizes"]:.4f} '
                f'Loss_G_adv: {running_results["g_epoch_adv_loss"] / running_results["batch_sizes"]:.4f} '
                f'Loss_G_content: {running_results["g_epoch_content_loss"] / running_results["batch_sizes"]:.4f} '
                f'D(x): {running_results["d_epoch_real_mean"] / running_results["batch_sizes"]:.4f} '
                f'D(G(z)): {running_results["d_epoch_fake_mean"] / running_results["batch_sizes"]:.4f} '
                f'rt: {running_results["rt"] / running_results["batch_sizes"]:.4f} '
                f'augment_probability: {running_results["augment_probability"] / running_results["batch_sizes"]:.4f}')

        if validation_epoch:
            torch.cuda.empty_cache()
            gc.collect()
            g_net.eval()
            # ...
            images_path = results_folder / "training_images_results"
            images_path.mkdir(exist_ok=True)

            with torch.no_grad():
                val_bar = tqdm(val_loader, ncols=160)
                val_results = {
                    'epoch_mse': 0, 'epoch_ssim': 0, 'epoch_psnr': 0,
                    'epoch_avg_psnr': 0, 'epoch_avg_ssim': 0, 'epoch_lpips': 0,
                    'epoch_avg_lpips': 0, 'epoch_fid': 0, 'batch_sizes': 0
                }
                val_images = torch.empty((0, 0))
                epoch_validation_sr_dataset = None
                for lr, val_hr_restore, hr in val_bar:
                    batch_size = lr.size(0)
                    val_results['batch_sizes'] += batch_size
                    hr = hr.to(device=device)
                    lr = lr.to(device=device)
                    sr = g_net(lr)
                    sr = torch.clamp(sr, 0., 1.)

                    if epoch_validation_sr_dataset is None:
                        epoch_validation_sr_dataset = SingleTensorDataset(
                            (sr.cpu() * 255).to(torch.uint8))
                    else:
                        epoch_validation_sr_dataset = ConcatDataset(
                            (epoch_validation_sr_dataset,
                             SingleTensorDataset((sr.cpu() * 255).to(torch.uint8))))

                    batch_mse = ((sr - hr) ** 2).data.mean()  # pixel-wise MSE
                    val_results['epoch_mse'] += batch_mse * batch_size
                    batch_ssim = pytorch_ssim.ssim(sr, hr).item()
                    val_results['epoch_ssim'] += batch_ssim * batch_size
                    val_results['epoch_avg_ssim'] = (
                        val_results['epoch_ssim'] / val_results['batch_sizes'])
                    # PSNR = 10 * log10(MAX^2 / MSE); batch_mse is already a mean.
                    val_results['epoch_psnr'] += 10 * log10(
                        hr.max() ** 2 / batch_mse) * batch_size
                    val_results['epoch_avg_psnr'] = (
                        val_results['epoch_psnr'] / val_results['batch_sizes'])
                    # LPIPS expects inputs scaled to [-1, 1].
                    val_results['epoch_lpips'] += torch.mean(
                        lpips_metric(hr * 2 - 1, sr * 2 - 1)).detach().to(
                            'cpu', non_blocking=True) * batch_size
                    val_results['epoch_avg_lpips'] = (
                        val_results['epoch_lpips'] / val_results['batch_sizes'])

                    val_bar.set_description(
                        desc=f"[converting LR images to SR images] "
                        f"PSNR: {val_results['epoch_avg_psnr']:.4f} dB "
                        f"SSIM: {val_results['epoch_avg_ssim']:.4f} "
                        f"LPIPS: {val_results['epoch_avg_lpips']:.4f} ")

                    if val_images.size(0) * val_images.size(1) < NUM_LOGGED_VALIDATION_IMAGES * 3:
                        row = torch.hstack(
                            (display_transform(CENTER_CROP_SIZE)(
                                val_hr_restore).unsqueeze(0).transpose(0, 1),
                             display_transform(CENTER_CROP_SIZE)(
                                 hr.data.cpu()).unsqueeze(0).transpose(0, 1),
                             display_transform(CENTER_CROP_SIZE)(
                                 sr.data.cpu()).unsqueeze(0).transpose(0, 1)))
                        val_images = row if val_images.size(0) == 0 else torch.cat((val_images, row))

                # Set batch_size=1 inside calculate_metrics if you run out of memory.
                val_results['epoch_fid'] = calculate_metrics(
                    epoch_validation_sr_dataset, epoch_validation_hr_dataset,
                    cuda=True, fid=True, verbose=True)['frechet_inception_distance']

                val_images = val_images.view(
                    (NUM_LOGGED_VALIDATION_IMAGES // 4, -1, 3,
                     CENTER_CROP_SIZE, CENTER_CROP_SIZE))
                val_save_bar = tqdm(val_images,
                                    desc='[saving validation results]', ncols=160)
                for index, image_batch in enumerate(val_save_bar, start=1):
                    image_grid = utils.make_grid(image_batch, nrow=3, padding=5)
                    writer.add_image(
                        f'progress{image_percentage:.1f}_index_{index}.png',
                        image_grid)

        # Save loss / scores / PSNR / SSIM
        results['d_total_loss'].append(
            running_results['d_epoch_total_loss'] / running_results['batch_sizes'])
        results['g_total_loss'].append(
            running_results['g_epoch_total_loss'] / running_results['batch_sizes'])
        results['g_adv_loss'].append(
            running_results['g_epoch_adv_loss'] / running_results['batch_sizes'])
        results['g_content_loss'].append(
            running_results['g_epoch_content_loss'] / running_results['batch_sizes'])
        results['d_real_mean'].append(
            running_results['d_epoch_real_mean'] / running_results['batch_sizes'])
        results['d_fake_mean'].append(
            running_results['d_epoch_fake_mean'] / running_results['batch_sizes'])
        results['rt'].append(
            running_results['rt'] / running_results['batch_sizes'])
        results['augment_probability'].append(
            running_results['augment_probability'] / running_results['batch_sizes'])
        if validation_epoch:
            results['psnr'].append(val_results['epoch_avg_psnr'])
            results['ssim'].append(val_results['epoch_avg_ssim'])
            results['lpips'].append(val_results['epoch_avg_lpips'])
            results['fid'].append(val_results['epoch_fid'])

        for metric, metric_values in results.items():
            if validation_epoch or metric not in ["psnr", "ssim", "lpips", "fid"]:
                writer.add_scalar(metric, metric_values[-1],
                                  int(image_percentage * num_images * 0.01))

        if validation_epoch:
            # Save model parameters
            models_path = results_folder / "saved_models"
            models_path.mkdir(exist_ok=True)
            torch.save(
                {
                    'progress': image_percentage,
                    'g_net': g_net.state_dict(),
                    'd_net': d_net.state_dict(),
                    # 'g_optimizer': g_optimizer.state_dict(),  # uncomment to resume training
                    # 'd_optimizer': d_optimizer.state_dict(),
                },
                str(models_path / f'progress_{image_percentage:.1f}.tar'))
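# The validation loop above wraps each batch of uint8 SR images in a
# SingleTensorDataset before handing it to torch-fidelity's calculate_metrics.
# That helper is not defined in this file; the following is a minimal sketch of
# what it is assumed to look like (a Dataset view over an (N, C, H, W) uint8
# tensor), not necessarily the project's actual implementation.
from torch.utils.data import Dataset


class SingleTensorDataset(Dataset):
    """Expose a batched image tensor as a dataset of single images."""

    def __init__(self, tensor):
        self.tensor = tensor

    def __getitem__(self, index):
        # Each item is one C x H x W image, the format torch-fidelity expects.
        return self.tensor[index]

    def __len__(self):
        return self.tensor.size(0)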
# Knowledge-graph embedding training script built on TorchKGE. Project helpers
# (the module-level `optimizer` factory, load_ckpt, save_ckpt,
# create_dir_not_exists, time_since, id2point) are assumed to be defined
# elsewhere in the repository.
import os
import time
from importlib import import_module

import torch
from torch import cuda
from torch.nn import MSELoss
from torchkge.evaluation import LinkPredictionEvaluator
from torchkge.sampling import BernoulliNegativeSampler
from torchkge.utils import DataLoader


def main():
    # Define some hyper-parameters for training
    global optimizer  # rebound below by the module-level optimizer factory
    benchmarks = 'Sweden'
    model_name = 'ComplEx'
    opt_method = 'Adagrad'  # "Adagrad" "Adadelta" "Adam" "SGD"
    GDR = False  # whether to incorporate coordinate (GPS) information
    emb_dim = 100  # bilinear model
    # ent_dim = emb_dim
    # rel_dim = emb_dim
    lr = 0.0001
    # margin = 1.5
    n_epochs = 10000
    train_b_size = 512  # batch size for training
    eval_b_size = 256  # batch size for valid/test evaluation
    validation_freq = 10  # evaluate on the validation set (and save the best model) every N epochs
    require_improvement = validation_freq * 5  # stop if validation Hit@k has not improved for this many epochs
    # Path of the checkpoint with the best Hit@k (entity) score
    model_save_path = './checkpoint/' + benchmarks + '_' + model_name + '_' + opt_method + '.ckpt'
    device = 'cuda:0' if cuda.is_available() else 'cpu'

    # Load dataset
    module = getattr(import_module('torchkge.models'), model_name + 'Model')
    load_data = getattr(import_module('torchkge.utils.datasets'), 'load_' + benchmarks)
    print('Loading data...')
    kg_train, kg_val, kg_test = load_data(GDR=GDR)
    print(f'Train set: {kg_train.n_ent} entities, {kg_train.n_rel} relations, '
          f'{kg_train.n_facts} triplets.')
    print(f'Valid set: {kg_val.n_facts} triplets, Test set: {kg_test.n_facts} triplets.')

    # Define the model and criterion
    print('Loading model...')
    model = module(emb_dim, kg_train.n_ent, kg_train.n_rel)
    # criterion = MarginLoss(margin)
    # criterion = BinaryCrossEntropyLoss()
    criterion = MSELoss(reduction='sum')

    # Move everything to CUDA if available
    if device == 'cuda:0':
        cuda.empty_cache()
        model.to(device)
        criterion.to(device)
        dataloader = DataLoader(kg_train, batch_size=train_b_size, use_cuda='all')
    else:
        dataloader = DataLoader(kg_train, batch_size=train_b_size, use_cuda=None)

    # Define the torch optimizer to be used (rebinds the module-level
    # `optimizer` factory to the constructed optimizer instance)
    optimizer = optimizer(model, opt_method=opt_method, lr=lr)
    # optimizer = Adam(model.parameters(), lr=lr, weight_decay=1e-5)

    sampler = BernoulliNegativeSampler(kg_train)

    start_epoch = 1
    best_score = float('-inf')
    if os.path.exists(model_save_path):
        # A checkpoint exists: load it and resume training
        start_epoch, best_score = load_ckpt(model_save_path, model, optimizer)
        print(f'loading ckpt successful, start on epoch {start_epoch}...')

    print(model)
    print('lr: {}, dim {}, total epoch: {}, device: {}, batch size: {}, optim: {}, GDR: {}'
          .format(lr, emb_dim, n_epochs, device, train_b_size, opt_method, GDR))
    print('Training...')
    last_improve = start_epoch  # epoch of the last validation-score improvement
    start = time.time()
    # last_improve = start
    # save_time = start
    for epoch in range(start_epoch, n_epochs + 1):
        running_loss = 0.0
        model.train()
        for i, batch in enumerate(dataloader):
            if GDR:
                h, t, r, point = batch[0], batch[1], batch[2], batch[3]
                n_h, n_t = sampler.corrupt_batch(h, t, r)  # 1:1 negative sampling
                n_point = id2point(n_h, n_t, kg_train.id2point)
                optimizer.zero_grad()
                # forward + backward + optimize
                pos, neg = model(h, t, n_h, n_t, r)
                # NOTE: this branch assumes a criterion that also accepts the
                # coordinate tensors, unlike the plain MSELoss above.
                loss = criterion(pos, neg, point, n_point)
            else:
                h, t, r = batch[0], batch[1], batch[2]
                n_h, n_t = sampler.corrupt_batch(h, t, r)
                optimizer.zero_grad()
                pos, neg = model(h, t, n_h, n_t, r)
                loss = criterion(pos, neg)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        # Periodic validation
        if epoch % validation_freq == 0:
            create_dir_not_exists('./checkpoint')
            model.eval()
            evaluator = LinkPredictionEvaluator(model, kg_val)
            evaluator.evaluate(b_size=eval_b_size, verbose=False)
            _, hit_at_k = evaluator.hit_at_k(10)  # filtered Hit@10 on the validation set
            print('Epoch [{:>5}/{:>5}] '.format(epoch, n_epochs), end='')
            if hit_at_k > best_score:
                save_ckpt(model, optimizer, epoch, best_score, model_save_path)
                best_score = hit_at_k
                improve = '*'  # mark results that improved on the best score
                last_improve = epoch  # a higher validation Hit@k counts as an improvement
            else:
                improve = ''
            msg = '| mean loss: {:>8.3f}, Time: {}, Val Hit@10: {:>5.2%} {}'
            print(msg.format(running_loss / len(dataloader), time_since(start),
                             hit_at_k, improve))
            model.normalize_parameters()
        if epoch - last_improve > require_improvement:
            # Early stopping: validation Hit@k has not increased for too long
            print("\nNo optimization for a long time, auto-stopping...")
            break

    print('Training done, start evaluating on test data...')
    print('lr: {}, dim {}, device: {}, eval batch size: {}, optim: {}, GDR: {}'
          .format(lr, emb_dim, device, eval_b_size, opt_method, GDR))

    # Test the best checkpoint on the test dataset
    load_ckpt(model_save_path, model, optimizer)
    model.eval()
    lp_evaluator = LinkPredictionEvaluator(model, kg_test)
    lp_evaluator.evaluate(eval_b_size, verbose=False)
    lp_evaluator.print_results()
    # rp_evaluator = RelationPredictionEvaluator(model, kg_test)
    # rp_evaluator.evaluate(eval_b_size, verbose=False)
    # rp_evaluator.print_results()
    print(f'Total time cost: {time_since(start)}')
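# save_ckpt / load_ckpt / create_dir_not_exists are project utilities that the
# training loop above calls but does not define here. The sketches below match
# the call signatures used above and assume a plain torch serialization format
# (and the `os` / `torch` imports at the top of this file); the project's
# actual helpers may store additional state.
def create_dir_not_exists(path):
    # Create the checkpoint directory if it is missing.
    os.makedirs(path, exist_ok=True)


def save_ckpt(model, optimizer, epoch, best_score, path):
    # Persist everything needed to resume training from `epoch`.
    torch.save({
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
        'best_score': best_score,
    }, path)


def load_ckpt(path, model, optimizer):
    # Restore model and optimizer state; return (start_epoch, best_score).
    ckpt = torch.load(path, map_location='cpu')
    model.load_state_dict(ckpt['model'])
    optimizer.load_state_dict(ckpt['optimizer'])
    return ckpt['epoch'], ckpt['best_score']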
# Monodepth-style self-supervised depth Trainer. Project-specific modules
# (MonodepthModel, make_data_loader, SSIM, BackprojectDepth, Project3D,
# disp_to_depth, get_smooth_loss, normalize_image, logger) are assumed to be
# imported from elsewhere in the repository.
import datetime
import os

import torch
import torch.nn.functional as F
from torch import optim
from torch.nn import MSELoss
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm


class Trainer(object):
    def __init__(self, cfg):
        self.cfg = cfg
        self.device = cfg.MODEL.DEVICE
        self.height = cfg.INPUT.HEIGHT
        self.width = cfg.INPUT.WIDTH
        self.scales = cfg.INPUT.SCALES
        self.frame_ids = cfg.INPUT.FRAME_IDS
        assert self.height % 32 == 0, "'height' must be a multiple of 32"
        assert self.width % 32 == 0, "'width' must be a multiple of 32"
        self.num_epochs = cfg.SOLVER.NUM_EPOCHS
        self.batch_size = cfg.SOLVER.IMS_PER_BATCH
        self.disparity_smoothness = cfg.SOLVER.DISPARITY_SMOOTHNESS
        self.min_depth = cfg.SOLVER.MIN_DEPTH
        self.max_depth = cfg.SOLVER.MAX_DEPTH
        self.epoch = 0
        self.step = 0
        self.output_dir = cfg.OUTPUT_DIR
        self.log_freq = cfg.SOLVER.LOG_FREQ
        self.val_freq = cfg.SOLVER.VAL_FREQ

        # Tensorboard writers
        now = datetime.datetime.now()
        self.writers = {}
        for mode in ["train", "valid"]:
            self.writers[mode] = SummaryWriter(
                os.path.join(self.output_dir, "{} {}".format(mode, now)))

        # Model
        self.model = MonodepthModel(cfg)
        self.model.to(self.device)

        # Optimizer
        self.model_optimizer = optim.Adam(self.model.parameters_to_train(),
                                          cfg.SOLVER.BASE_LR)
        self.model_lr_scheduler = optim.lr_scheduler.StepLR(
            self.model_optimizer, cfg.SOLVER.SCHEDULER_STEP_SIZE,
            cfg.SOLVER.SCHEDULER_GAMMA)

        # Data
        self.train_loader = make_data_loader(cfg, is_train=True)
        self.val_loader = make_data_loader(cfg, is_train=False)
        self.val_iter = iter(self.val_loader)
        logger.info("Train dataset size: {}".format(len(self.train_loader.dataset)))
        logger.info("Valid dataset size: {}".format(len(self.val_loader.dataset)))

        # Loss
        self.ssim = SSIM()
        self.ssim.to(self.device)
        self.gps_loss = MSELoss()
        self.gps_loss.to(self.device)

        self.backproject_depth = {}
        self.project_3d = {}
        for scale in self.scales:
            h = self.height // (2 ** scale)
            w = self.width // (2 ** scale)
            self.backproject_depth[scale] = BackprojectDepth(self.batch_size, h, w)
            self.backproject_depth[scale].to(self.device)
            self.project_3d[scale] = Project3D(self.batch_size, h, w)
            self.project_3d[scale].to(self.device)

    def train(self):
        # Freeze everything except the map pose encoder/decoder
        for p in self.model.parameters_to_train():
            p.requires_grad = False
        for p in self.model.parameters(['map_pose_encoder', 'map_pose_decoder']):
            p.requires_grad = True
        while self.epoch < self.num_epochs:
            logger.info("Epoch {}/{} LR {}".format(self.epoch + 1,
                                                   self.num_epochs, self.get_lr()))
            self.run_epoch()
            self.epoch += 1
            self.model_lr_scheduler.step()
            self.checkpoint()

    def run_epoch(self):
        """Run a single epoch of training and validation
        """
        self.model.set_train()
        for _, inputs in enumerate(tqdm(self.train_loader)):
            inputs, outputs = self.model.process_batch(inputs)
            losses = self.compute_losses(inputs, outputs)
            self.model_optimizer.zero_grad()
            losses["loss"].backward()
            self.model_optimizer.step()
            self.step += 1
            if self.step % self.log_freq == 0:
                self.log_losses(losses, is_train=True)
            if self.step % self.val_freq == 0:
                self.validate()

    def validate(self):
        """Validate the model on a single minibatch and log progress
        """
        self.model.set_eval()
        try:
            inputs = next(self.val_iter)
        except StopIteration:
            # Restart the validation iterator once it is exhausted
            self.val_iter = iter(self.val_loader)
            inputs = next(self.val_iter)
        with torch.no_grad():
            inputs, outputs = self.model.process_batch(inputs)
            losses = self.compute_losses(inputs, outputs)
            self.log_losses(losses, is_train=False)
            self.log_images(inputs, outputs, is_train=False)
            del inputs, outputs, losses
        self.model.set_train()

    def compute_losses(self, inputs, outputs):
        """Compute the reprojection and smoothness losses for a minibatch
        """
        losses = {}
        # Create warped images
        self.generate_images_pred(inputs, outputs)
        self.generate_map_pred(inputs, outputs)

        total_loss = 0
        for scale in self.scales:
            img_loss = self.compute_image_loss(inputs, outputs, scale)
            map_loss = self.compute_map_loss(inputs, outputs, scale)
            total_loss += img_loss + map_loss
            losses["loss/{}".format(scale)] = img_loss
            losses["loss/map{}".format(scale)] = map_loss
        total_loss /= len(self.scales)

        gps_loss = self.compute_gps_loss(inputs, outputs)
        losses["loss/gps"] = gps_loss
        losses["loss"] = total_loss + gps_loss
        return losses

    def generate_images_pred(self, inputs, outputs):
        """Generate the warped (reprojected) color images for a minibatch.
        Generated images are saved into the `outputs` dictionary.
        """
        for scale in self.scales:
            disp = outputs[("disp", scale)]
            disp = F.interpolate(disp, [self.height, self.width],
                                 mode="bilinear", align_corners=False)
            source_scale = 0
            _, depth = disp_to_depth(disp, self.min_depth, self.max_depth)
            outputs[("depth", 0, scale)] = depth
            for i, frame_id in enumerate(self.frame_ids[1:]):
                if frame_id == "s":
                    T = inputs["stereo_T"]
                else:
                    T = outputs[("cam_T_cam", 0, frame_id)]
                cam_points = self.backproject_depth[source_scale](
                    depth, inputs[("inv_K", frame_id, source_scale)])
                pix_coords = self.project_3d[source_scale](
                    cam_points, inputs[("K", frame_id, source_scale)], T)
                outputs[("sample", frame_id, scale)] = pix_coords
                outputs[("color", frame_id, scale)] = F.grid_sample(
                    inputs[("color", frame_id, source_scale)],
                    outputs[("sample", frame_id, scale)],
                    padding_mode="border")
                outputs[("color_identity", frame_id, scale)] = \
                    inputs[("color", frame_id, source_scale)]

    def generate_map_pred(self, inputs, outputs):
        """Generate the warped (reprojected) map views for a minibatch.
        Generated images are saved into the `outputs` dictionary.
        """
        for scale in self.scales:
            disp = outputs[("disp", scale)]
            disp = F.interpolate(disp, [self.height, self.width],
                                 mode="bilinear", align_corners=False)
            source_scale = 0
            _, depth = disp_to_depth(disp, self.min_depth, self.max_depth)
            outputs[("depth", 0, scale)] = depth
            frame_id = 0
            T = outputs[("map_cam_T_cam", 0, frame_id)]
            cam_points = self.backproject_depth[source_scale](
                depth, inputs[("inv_K", frame_id, source_scale)])
            pix_coords = self.project_3d[source_scale](
                cam_points, inputs[("K", frame_id, source_scale)], T)
            outputs[("sample", frame_id, scale)] = pix_coords
            outputs[("map_view", frame_id, scale)] = F.grid_sample(
                inputs[("map_view", frame_id, source_scale)],
                outputs[("sample", frame_id, scale)],
                padding_mode="border")
            outputs[("map_view_identity", frame_id, scale)] = \
                inputs[("map_view", frame_id, source_scale)]

    def compute_image_loss(self, inputs, outputs, scale):
        loss = 0
        reprojection_losses = []
        source_scale = 0
        disp = outputs[("disp", scale)]
        color = inputs[("color", 0, scale)]
        target = inputs[("color", 0, source_scale)]

        for frame_id in self.frame_ids[1:]:
            pred = outputs[("color", frame_id, scale)]
            reprojection_losses.append(
                self.compute_reprojection_loss(pred, target))
        reprojection_loss = torch.cat(reprojection_losses, 1)

        identity_reprojection_losses = []
        for frame_id in self.frame_ids[1:]:
            pred = inputs[("color", frame_id, source_scale)]
            identity_reprojection_losses.append(
                self.compute_reprojection_loss(pred, target))
        identity_reprojection_loss = torch.cat(identity_reprojection_losses, 1)
        # add random numbers to break ties
        identity_reprojection_loss += torch.randn(
            identity_reprojection_loss.shape,
            device=identity_reprojection_loss.device) * 0.00001

        combined = torch.cat((identity_reprojection_loss, reprojection_loss), dim=1)
        if combined.shape[1] == 1:
            to_optimise = combined
        else:
            to_optimise, idxs = torch.min(combined, dim=1)
            outputs["identity_selection/{}".format(scale)] = (
                idxs > identity_reprojection_loss.shape[1] - 1).float()
        loss += to_optimise.mean()

        mean_disp = disp.mean(2, True).mean(3, True)
        norm_disp = disp / (mean_disp + 1e-7)
        smooth_loss = get_smooth_loss(norm_disp, color)
        loss += self.disparity_smoothness * smooth_loss / (2 ** scale)
        return loss

    def compute_map_loss(self, inputs, outputs, scale):
        pred = outputs[("map_view", 0, scale)]
        target = inputs[("map_pred", 0, 0)]
        mask, _ = torch.max(target, dim=1)
        abs_diff = torch.abs(target - pred)
        l1_loss = abs_diff.mean(1)
        loss = l1_loss * mask
        loss = loss.mean()
        return loss

    def compute_gps_loss(self, inputs, outputs):
        gps_loss = 0
        for frame_id in self.frame_ids[1:]:
            pred_trans = outputs[("translation", 0, frame_id)][:, 0]
            targ_trans = inputs[("gps_delta", frame_id)]
            # Supervise the magnitude of the predicted translation with the
            # GPS displacement.
            pred_norm = torch.norm(pred_trans, dim=2)
            targ_norm = torch.norm(targ_trans, dim=1, keepdim=True)
            gps_loss += self.gps_loss(pred_norm, targ_norm)
        return gps_loss

    def compute_reprojection_loss(self, pred, target):
        """Computes reprojection loss between a batch of predicted and target images
        """
        abs_diff = torch.abs(target - pred)
        l1_loss = abs_diff.mean(1, True)
        ssim_loss = self.ssim(pred, target).mean(1, True)
        reprojection_loss = 0.85 * ssim_loss + 0.15 * l1_loss
        return reprojection_loss

    def get_lr(self):
        for param_group in self.model_optimizer.param_groups:
            return param_group['lr']

    def log_losses(self, losses, is_train=True):
        """Write an event to the tensorboard events file
        """
        mode = "train" if is_train else "valid"
        writer = self.writers[mode]
        for l, v in losses.items():
            writer.add_scalar("{}".format(l), v, self.step)

    def log_images(self, inputs, outputs, is_train=True):
        mode = "train" if is_train else "valid"
        writer = self.writers[mode]
        num_images = min(4, self.batch_size)  # write a maximum of four images
        for j in range(num_images):
            for s in self.scales:
                for frame_id in self.frame_ids:
                    writer.add_image("color_{}_{}/{}".format(frame_id, s, j),
                                     inputs[("color", frame_id, s)][j].data,
                                     self.step)
                    if s == 0 and frame_id != 0:
                        writer.add_image(
                            "color_pred_{}_{}/{}".format(frame_id, s, j),
                            outputs[("color", frame_id, s)][j].data, self.step)
                writer.add_image("disp_{}/{}".format(s, j),
                                 normalize_image(outputs[("disp", s)][j]),
                                 self.step)
                writer.add_image(
                    "automask_{}/{}".format(s, j),
                    outputs["identity_selection/{}".format(s)][j][None, ...],
                    self.step)
                writer.add_image("map_view_{}/{}".format(s, j),
                                 inputs[("map_view", 0, s)][j].data, self.step)
                writer.add_image("map_pred_{}/{}".format(s, j),
                                 inputs[("map_pred", 0, s)][j].data, self.step)
                if s == 0:
                    writer.add_image("map_warp_{}/{}".format(s, j),
                                     outputs[("map_view", 0, s)][j].data,
                                     self.step)

    def checkpoint(self):
        """Save model weights to disk
        """
        save_folder = os.path.join(self.output_dir, "models",
                                   "weights_{}".format(self.epoch))
        if not os.path.exists(save_folder):
            os.makedirs(save_folder)
        logger.info("Saving to {}".format(save_folder))
        self.model.save_model(save_folder)

        # Save trainer state
        trainer_state = {
            'epoch': self.epoch,
            'step': self.step,
            'optimizer': self.model_optimizer.state_dict(),
            'scheduler': self.model_lr_scheduler.state_dict(),
        }
        save_path = os.path.join(save_folder, "{}.pth".format('trainer'))
        torch.save(trainer_state, save_path)

        # Symlink latest model for resuming
        latest_model_path = os.path.join(self.output_dir, "models",
                                         "latest_weights")
        if os.path.islink(latest_model_path):
            os.unlink(latest_model_path)
        os.symlink(os.path.basename(save_folder), latest_model_path)

    def load_checkpoint(self, load_optimizer=True):
        """Load model(s) from disk
        """
        save_folder = os.path.join(self.output_dir, "models", "latest_weights")
        assert os.path.isdir(save_folder), "Cannot find folder {}".format(save_folder)
        logger.info("Loading from {}".format(save_folder))
        self.model.load_model(save_folder)
        self.model.to(self.device)

        if load_optimizer:
            # Load trainer state
            save_path = os.path.join(save_folder, "{}.pth".format("trainer"))
            if os.path.isfile(save_path):
                logger.info("Loading trainer...")
                trainer_state = torch.load(save_path)
                self.epoch = trainer_state['epoch']
                self.step = trainer_state['step']
                self.model_lr_scheduler.load_state_dict(trainer_state['scheduler'])
                logger.info(
                    "Unresolved issue: Unable to load saved optimizer weights.")
                # self.model_optimizer.load_state_dict(trainer_state['optimizer'])
            else:
                logger.info("Could not load trainer")
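# A minimal usage sketch for the Trainer above. `get_cfg` is a hypothetical
# config constructor (the repository's real entry point may differ); resuming
# relies on the `latest_weights` symlink written by checkpoint().
if __name__ == "__main__":
    cfg = get_cfg()  # hypothetical: build the cfg object this Trainer expects
    trainer = Trainer(cfg)
    if os.path.islink(os.path.join(cfg.OUTPUT_DIR, "models", "latest_weights")):
        trainer.load_checkpoint()  # restore epoch/step/scheduler from disk
    trainer.train()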