Beispiel #1
0
def main():
    """Train an SRGAN-style super-resolution GAN.

    Pipeline: MSE-only generator pre-training followed by adversarial
    fine-tuning, with optional adaptive discriminator augmentation (ADA),
    periodic validation (PSNR / SSIM / LPIPS / FID) and TensorBoard logging.

    Relies on module-level constants and helpers defined elsewhere in this
    file (PATCH_SIZE, UPSCALE_FACTOR, BATCH_SIZE, device, AugmentPipe,
    Generator, Discriminator, ...).
    """
    parser = ArgumentParser()
    parser.add_argument("--augmentation", action='store_true')
    parser.add_argument("--train-dataset-percentage", type=float, default=100)
    # float (not int) for consistency with --train-dataset-percentage and to
    # allow fractional percentages; integer command lines still parse fine.
    parser.add_argument("--val-dataset-percentage", type=float, default=100)
    parser.add_argument("--label-smoothing", type=float, default=0.9)
    parser.add_argument("--validation-frequency", type=int, default=1)
    args = parser.parse_args()

    ENABLE_AUGMENTATION = args.augmentation
    TRAIN_DATASET_PERCENTAGE = args.train_dataset_percentage
    VAL_DATASET_PERCENTAGE = args.val_dataset_percentage
    LABEL_SMOOTHING_FACTOR = args.label_smoothing
    VALIDATION_FREQUENCY = args.validation_frequency

    if ENABLE_AUGMENTATION:
        augment_batch = AugmentPipe()
        augment_batch.to(device)
    else:
        # Identity pass-through. Function objects accept attributes, so we
        # give it a `.p` to keep the probability updates below uniform.
        augment_batch = lambda x: x
        augment_batch.p = 0

    # Scale the epoch budget so that the total number of images seen stays
    # roughly constant when training on a dataset subset.
    NUM_ADV_EPOCHS = round(NUM_ADV_BASELINE_EPOCHS /
                           (TRAIN_DATASET_PERCENTAGE / 100))
    NUM_PRETRAIN_EPOCHS = round(NUM_BASELINE_PRETRAIN_EPOCHS /
                                (TRAIN_DATASET_PERCENTAGE / 100))
    VALIDATION_FREQUENCY = round(VALIDATION_FREQUENCY /
                                 (TRAIN_DATASET_PERCENTAGE / 100))

    training_start = datetime.datetime.now().isoformat()

    train_set = TrainDatasetFromFolder(train_dataset_dir,
                                       patch_size=PATCH_SIZE,
                                       upscale_factor=UPSCALE_FACTOR)
    len_train_set = len(train_set)
    # Random subset (sampling without replacement) of the requested size.
    train_set = Subset(
        train_set,
        list(
            np.random.choice(
                np.arange(len_train_set),
                int(len_train_set * TRAIN_DATASET_PERCENTAGE / 100), False)))

    val_set = ValDatasetFromFolder(val_dataset_dir,
                                   upscale_factor=UPSCALE_FACTOR)
    len_val_set = len(val_set)
    val_set = Subset(
        val_set,
        list(
            np.random.choice(np.arange(len_val_set),
                             int(len_val_set * VAL_DATASET_PERCENTAGE / 100),
                             False)))

    train_loader = DataLoader(dataset=train_set,
                              num_workers=8,
                              batch_size=BATCH_SIZE,
                              shuffle=True,
                              pin_memory=True,
                              prefetch_factor=8)
    val_loader = DataLoader(dataset=val_set,
                            num_workers=2,
                            batch_size=VAL_BATCH_SIZE,
                            shuffle=False,
                            pin_memory=True,
                            prefetch_factor=2)

    epoch_validation_hr_dataset = HrValDatasetFromFolder(
        val_dataset_dir)  # Useful to compute FID metric

    results_folder = Path(
        f"results_{training_start}_CS:{PATCH_SIZE}_US:{UPSCALE_FACTOR}x_TRAIN:{TRAIN_DATASET_PERCENTAGE}%_AUGMENTATION:{ENABLE_AUGMENTATION}"
    )
    results_folder.mkdir(exist_ok=True)
    writer = SummaryWriter(str(results_folder / "tensorboard_log"))
    g_net = Generator(n_residual_blocks=NUM_RESIDUAL_BLOCKS,
                      upsample_factor=UPSCALE_FACTOR)
    d_net = Discriminator(patch_size=PATCH_SIZE)
    lpips_metric = lpips.LPIPS(net='alex')

    g_net.to(device=device)
    d_net.to(device=device)
    lpips_metric.to(device=device)

    g_optimizer = optim.Adam(g_net.parameters(), lr=1e-4)
    d_optimizer = optim.Adam(d_net.parameters(), lr=1e-4)

    bce_loss = BCELoss()
    mse_loss = MSELoss()

    bce_loss.to(device=device)
    mse_loss.to(device=device)
    # Per-epoch history, later written to TensorBoard.
    results = {
        'd_total_loss': [],
        'g_total_loss': [],
        'g_adv_loss': [],
        'g_content_loss': [],
        'd_real_mean': [],
        'd_fake_mean': [],
        'psnr': [],
        'ssim': [],
        'lpips': [],
        'fid': [],
        'rt': [],
        'augment_probability': []
    }

    augment_probability = 0
    num_images = len(train_set) * (NUM_PRETRAIN_EPOCHS + NUM_ADV_EPOCHS)
    prediction_list = []
    rt = 0

    for epoch in range(1, NUM_PRETRAIN_EPOCHS + NUM_ADV_EPOCHS + 1):
        train_bar = tqdm(train_loader, ncols=200)
        running_results = {
            'batch_sizes': 0,
            'd_epoch_total_loss': 0,
            'g_epoch_total_loss': 0,
            'g_epoch_adv_loss': 0,
            'g_epoch_content_loss': 0,
            'd_epoch_real_mean': 0,
            'd_epoch_fake_mean': 0,
            'rt': 0,
            'augment_probability': 0
        }
        image_percentage = epoch / (NUM_PRETRAIN_EPOCHS + NUM_ADV_EPOCHS) * 100
        # Single flag for the (previously repeated) validation condition:
        # always validate on the first and last epoch, and every
        # VALIDATION_FREQUENCY epochs otherwise.
        run_validation = (epoch == 1
                          or epoch == NUM_PRETRAIN_EPOCHS + NUM_ADV_EPOCHS
                          or epoch % VALIDATION_FREQUENCY == 0
                          or VALIDATION_FREQUENCY == 1)
        g_net.train()
        d_net.train()

        for data, target in train_bar:
            augment_batch.p = torch.tensor([augment_probability],
                                           device=device)
            batch_size = data.size(0)
            running_results["batch_sizes"] += batch_size
            target = target.to(device)
            data = data.to(device)
            real_labels = torch.ones(batch_size, device=device)
            fake_labels = torch.zeros(batch_size, device=device)

            if epoch > NUM_PRETRAIN_EPOCHS:
                # Discriminator training
                d_optimizer.zero_grad(set_to_none=True)

                d_real_output = d_net(augment_batch(target))
                # One-sided label smoothing on the real labels only.
                d_real_output_loss = bce_loss(
                    d_real_output, real_labels * LABEL_SMOOTHING_FACTOR)

                fake_img = g_net(data)
                d_fake_output = d_net(augment_batch(fake_img))
                d_fake_output_loss = bce_loss(d_fake_output, fake_labels)

                d_total_loss = d_real_output_loss + d_fake_output_loss
                d_total_loss.backward()
                d_optimizer.step()

                d_real_mean = d_real_output.mean()
                d_fake_mean = d_fake_output.mean()

            # Generator training
            g_optimizer.zero_grad(set_to_none=True)

            fake_img = g_net(data)
            if epoch > NUM_PRETRAIN_EPOCHS:
                adversarial_loss = bce_loss(d_net(augment_batch(fake_img)),
                                            real_labels) * ADV_LOSS_BALANCER
                content_loss = mse_loss(fake_img, target)
                g_total_loss = content_loss + adversarial_loss
            else:
                # During pre-training there is no adversarial term; keep a
                # zero tensor so the logging code below stays uniform.
                adversarial_loss = mse_loss(torch.zeros(
                    1, device=device), torch.zeros(
                        1,
                        device=device))  # Logging purposes, it is always zero
                content_loss = mse_loss(fake_img, target)
                g_total_loss = content_loss

            g_total_loss.backward()
            g_optimizer.step()

            # ADA heuristic: rt is the mean sign of D's real outputs; raise
            # the augmentation probability when D is too confident on reals.
            if epoch > NUM_PRETRAIN_EPOCHS and ENABLE_AUGMENTATION:
                prediction_list.append(
                    (torch.sign(d_real_output - 0.5)).tolist())
                if len(prediction_list) == RT_BATCH_SMOOTHING_FACTOR:
                    rt_list = [
                        prediction for sublist in prediction_list
                        for prediction in sublist
                    ]
                    rt = mean(rt_list)
                    if mean(rt_list) > AUGMENT_PROB_TARGET:
                        augment_probability = min(
                            0.85,
                            augment_probability + AUGMENT_PROBABABILITY_STEP)
                    else:
                        augment_probability = max(
                            0.,
                            augment_probability - AUGMENT_PROBABABILITY_STEP)
                    prediction_list.clear()

            running_results['g_epoch_total_loss'] += g_total_loss.to(
                'cpu', non_blocking=True).detach() * batch_size
            running_results['g_epoch_adv_loss'] += adversarial_loss.to(
                'cpu', non_blocking=True).detach() * batch_size
            running_results['g_epoch_content_loss'] += content_loss.to(
                'cpu', non_blocking=True).detach() * batch_size
            if epoch > NUM_PRETRAIN_EPOCHS:
                running_results['d_epoch_total_loss'] += d_total_loss.to(
                    'cpu', non_blocking=True).detach() * batch_size
                running_results['d_epoch_real_mean'] += d_real_mean.to(
                    'cpu', non_blocking=True).detach() * batch_size
                running_results['d_epoch_fake_mean'] += d_fake_mean.to(
                    'cpu', non_blocking=True).detach() * batch_size
                running_results['rt'] += rt * batch_size
                running_results[
                    'augment_probability'] += augment_probability * batch_size

            train_bar.set_description(
                desc=f'[{epoch}/{NUM_ADV_EPOCHS + NUM_PRETRAIN_EPOCHS}] '
                f'Loss_D: {running_results["d_epoch_total_loss"] / running_results["batch_sizes"]:.4f} '
                f'Loss_G: {running_results["g_epoch_total_loss"] / running_results["batch_sizes"]:.4f} '
                f'Loss_G_adv: {running_results["g_epoch_adv_loss"] / running_results["batch_sizes"]:.4f} '
                f'Loss_G_content: {running_results["g_epoch_content_loss"] / running_results["batch_sizes"]:.4f} '
                f'D(x): {running_results["d_epoch_real_mean"] / running_results["batch_sizes"]:.4f} '
                f'D(G(z)): {running_results["d_epoch_fake_mean"] / running_results["batch_sizes"]:.4f} '
                f'rt: {running_results["rt"] / running_results["batch_sizes"]:.4f} '
                f'augment_probability: {running_results["augment_probability"] / running_results["batch_sizes"]:.4f}'
            )

        if run_validation:
            torch.cuda.empty_cache()
            gc.collect()
            g_net.eval()
            # ...
            images_path = results_folder / "training_images_results"
            images_path.mkdir(exist_ok=True)

            with torch.no_grad():
                val_bar = tqdm(val_loader, ncols=160)
                val_results = {
                    'epoch_mse': 0,
                    'epoch_ssim': 0,
                    'epoch_psnr': 0,
                    'epoch_avg_psnr': 0,
                    'epoch_avg_ssim': 0,
                    'epoch_lpips': 0,
                    'epoch_avg_lpips': 0,
                    'epoch_fid': 0,
                    'batch_sizes': 0
                }
                val_images = torch.empty((0, 0))
                epoch_validation_sr_dataset = None
                for lr, val_hr_restore, hr in val_bar:
                    batch_size = lr.size(0)
                    val_results['batch_sizes'] += batch_size
                    hr = hr.to(device=device)
                    lr = lr.to(device=device)

                    sr = g_net(lr)
                    sr = torch.clamp(sr, 0., 1.)
                    # Accumulate SR outputs (as uint8) for the FID pass below.
                    if epoch_validation_sr_dataset is None:
                        epoch_validation_sr_dataset = SingleTensorDataset(
                            (sr.cpu() * 255).to(torch.uint8))

                    else:
                        epoch_validation_sr_dataset = ConcatDataset(
                            (epoch_validation_sr_dataset,
                             SingleTensorDataset(
                                 (sr.cpu() * 255).to(torch.uint8))))

                    batch_mse = ((sr - hr)**2).data.mean()  # Pixel-wise MSE
                    val_results['epoch_mse'] += batch_mse * batch_size
                    batch_ssim = pytorch_ssim.ssim(sr, hr).item()
                    val_results['epoch_ssim'] += batch_ssim * batch_size
                    val_results['epoch_avg_ssim'] = val_results[
                        'epoch_ssim'] / val_results['batch_sizes']
                    # NOTE(review): conventional PSNR is
                    # 20*log10(MAX / sqrt(MSE)); this divides batch MSE by
                    # batch_size instead. Kept as-is to preserve the reported
                    # metric across runs — worth confirming.
                    val_results['epoch_psnr'] += 20 * log10(
                        hr.max() / (batch_mse / batch_size)) * batch_size
                    val_results['epoch_avg_psnr'] = val_results[
                        'epoch_psnr'] / val_results['batch_sizes']
                    # LPIPS expects inputs in [-1, 1], hence the rescaling.
                    val_results['epoch_lpips'] += torch.mean(
                        lpips_metric(hr * 2 - 1, sr * 2 - 1)).to(
                            'cpu', non_blocking=True).detach() * batch_size
                    val_results['epoch_avg_lpips'] = val_results[
                        'epoch_lpips'] / val_results['batch_sizes']

                    val_bar.set_description(
                        desc=
                        f"[converting LR images to SR images] PSNR: {val_results['epoch_avg_psnr']:4f} dB "
                        f"SSIM: {val_results['epoch_avg_ssim']:4f} "
                        f"LPIPS: {val_results['epoch_avg_lpips']:.4f} ")
                    # Keep (restored-HR, HR, SR) triplets for a few images to
                    # log to TensorBoard.
                    if val_images.size(0) * val_images.size(
                            1) < NUM_LOGGED_VALIDATION_IMAGES * 3:
                        if val_images.size(0) == 0:
                            val_images = torch.hstack(
                                (display_transform(CENTER_CROP_SIZE)
                                 (val_hr_restore).unsqueeze(0).transpose(0, 1),
                                 display_transform(CENTER_CROP_SIZE)(
                                     hr.data.cpu()).unsqueeze(0).transpose(
                                         0, 1),
                                 display_transform(CENTER_CROP_SIZE)(
                                     sr.data.cpu()).unsqueeze(0).transpose(
                                         0, 1)))
                        else:
                            val_images = torch.cat((
                                val_images,
                                torch.hstack(
                                    (display_transform(CENTER_CROP_SIZE)(
                                        val_hr_restore).unsqueeze(0).transpose(
                                            0, 1),
                                     display_transform(CENTER_CROP_SIZE)(
                                         hr.data.cpu()).unsqueeze(0).transpose(
                                             0, 1),
                                     display_transform(CENTER_CROP_SIZE)(
                                         sr.data.cpu()).unsqueeze(0).transpose(
                                             0, 1)))))
                val_results['epoch_fid'] = calculate_metrics(
                    epoch_validation_sr_dataset,
                    epoch_validation_hr_dataset,
                    cuda=True,
                    fid=True,
                    verbose=True
                )['frechet_inception_distance']  # Set batch_size=1 if you get memory error (inside calculate metric function)

                val_images = val_images.view(
                    (NUM_LOGGED_VALIDATION_IMAGES // 4, -1, 3,
                     CENTER_CROP_SIZE, CENTER_CROP_SIZE))
                val_save_bar = tqdm(val_images,
                                    desc='[saving validation results]',
                                    ncols=160)

                for index, image_batch in enumerate(val_save_bar, start=1):
                    image_grid = utils.make_grid(image_batch,
                                                 nrow=3,
                                                 padding=5)
                    writer.add_image(
                        f'progress{image_percentage:.1f}_index_{index}.png',
                        image_grid)

        # save loss / scores / psnr /ssim
        results['d_total_loss'].append(running_results['d_epoch_total_loss'] /
                                       running_results['batch_sizes'])
        results['g_total_loss'].append(running_results['g_epoch_total_loss'] /
                                       running_results['batch_sizes'])
        results['g_adv_loss'].append(running_results['g_epoch_adv_loss'] /
                                     running_results['batch_sizes'])
        results['g_content_loss'].append(
            running_results['g_epoch_content_loss'] /
            running_results['batch_sizes'])
        results['d_real_mean'].append(running_results['d_epoch_real_mean'] /
                                      running_results['batch_sizes'])
        results['d_fake_mean'].append(running_results['d_epoch_fake_mean'] /
                                      running_results['batch_sizes'])
        results['rt'].append(running_results['rt'] /
                             running_results['batch_sizes'])
        results['augment_probability'].append(
            running_results['augment_probability'] /
            running_results['batch_sizes'])
        if run_validation:
            results['psnr'].append(val_results['epoch_avg_psnr'])
            results['ssim'].append(val_results['epoch_avg_ssim'])
            results['lpips'].append(val_results['epoch_avg_lpips'])
            results['fid'].append(val_results['epoch_fid'])

        # Validation metrics are only appended on validation epochs, so only
        # log them then; training metrics are logged every epoch.
        for metric, metric_values in results.items():
            if run_validation or metric not in ["psnr", "ssim", "lpips", "fid"]:
                writer.add_scalar(metric, metric_values[-1],
                                  int(image_percentage * num_images * 0.01))

        if run_validation:
            # save model parameters
            models_path = results_folder / "saved_models"
            models_path.mkdir(exist_ok=True)
            torch.save(
                {
                    'progress': image_percentage,
                    'g_net': g_net.state_dict(),
                    # FIX: was g_net.state_dict() — the discriminator weights
                    # were never actually saved.
                    'd_net': d_net.state_dict(),
                    # 'g_optimizer': g_optimizer.state_dict(), Uncomment this if you want resume training
                    # 'd_optimizer': d_optimizer.state_dict(),
                },
                str(models_path / f'progress_{image_percentage:.1f}.tar'))
Beispiel #2
0
def main():
    """Train a torchkge knowledge-graph-embedding model.

    Loads the benchmark by name, trains with 1:1 Bernoulli negative
    sampling, periodically evaluates link prediction on the validation set
    (checkpointing the best Hit@10 model), early-stops when the score stops
    improving, and finally evaluates the best checkpoint on the test set.
    """
    # Define some hyper-parameters for training
    # NOTE(review): `optimizer` is a module-level factory that is rebound
    # below to the optimizer instance it returns — see the call further down.
    global optimizer
    benchmarks = 'Sweden'
    model_name = 'ComplEx'
    opt_method = 'Adagrad'  # "Adagrad" "Adadelta" "Adam" "SGD"
    GDR = False  # whether to incorporate coordinate (geographic) information

    emb_dim = 100  # bilinear model
    lr = 0.0001

    n_epochs = 10000
    train_b_size = 512  # training batch size
    eval_b_size = 256  # batch size for valid/test evaluation
    validation_freq = 10  # validate (and checkpoint the best model) every N epochs
    # Early stopping: abort after this many epochs without a val improvement.
    require_improvement = validation_freq * 5
    # Path of the best (highest entity hits@k) checkpoint.
    model_save_path = './checkpoint/' + benchmarks + '_' + model_name + '_' + opt_method + '.ckpt'
    device = 'cuda:0' if cuda.is_available() else 'cpu'

    # Load dataset: resolve model class and dataset loader by name.
    module = getattr(import_module('torchkge.models'), model_name + 'Model')
    load_data = getattr(import_module('torchkge.utils.datasets'),
                        'load_' + benchmarks)

    print('Loading data...')
    kg_train, kg_val, kg_test = load_data(GDR=GDR)
    print(
        f'Train set: {kg_train.n_ent} entities, {kg_train.n_rel} relations, {kg_train.n_facts} triplets.'
    )
    print(
        f'Valid set: {kg_val.n_facts} triplets, Test set: {kg_test.n_facts} triplets.'
    )

    # Define the model and criterion
    print('Loading model...')
    model = module(emb_dim, kg_train.n_ent, kg_train.n_rel)
    criterion = MSELoss(reduction='sum')

    # Move everything to CUDA if available
    if device == 'cuda:0':
        cuda.empty_cache()
        model.to(device)
        criterion.to(device)
        dataloader = DataLoader(kg_train,
                                batch_size=train_b_size,
                                use_cuda='all')
    else:
        dataloader = DataLoader(kg_train,
                                batch_size=train_b_size,
                                use_cuda=None)

    # Define the torch optimizer to be used (rebinds the module-level
    # `optimizer` factory to the instance it returns).
    optimizer = optimizer(model, opt_method=opt_method, lr=lr)
    sampler = BernoulliNegativeSampler(kg_train)

    start_epoch = 1
    best_score = float('-inf')
    if os.path.exists(model_save_path):  # resume from an existing checkpoint
        start_epoch, best_score = load_ckpt(model_save_path, model, optimizer)
        print(f'loading ckpt sucessful, start on epoch {start_epoch}...')
    print(model)
    print('lr: {}, dim {}, total epoch: {}, device: {}, batch size: {}, optim: {}, GDR: {}'\
    .format(lr, emb_dim, n_epochs, device, train_b_size, opt_method, GDR))

    print('Training...')
    last_improve = start_epoch  # epoch of the last validation-score improvement
    start = time.time()
    for epoch in range(start_epoch, n_epochs + 1):
        running_loss = 0.0
        model.train()
        for batch in dataloader:  # batch index was unused; iterate directly
            if GDR:
                h, t, r, point = batch[0], batch[1], batch[2], batch[3]
                n_h, n_t = sampler.corrupt_batch(h, t,
                                                 r)  # 1:1 negative sampling
                n_point = id2point(n_h, n_t, kg_train.id2point)
                optimizer.zero_grad()

                # forward + backward + optimize
                pos, neg = model(h, t, n_h, n_t, r)
                loss = criterion(pos, neg, point, n_point)
            else:
                h, t, r = batch[0], batch[1], batch[2]
                n_h, n_t = sampler.corrupt_batch(h, t, r)
                optimizer.zero_grad()
                pos, neg = model(h, t, n_h, n_t, r)
                loss = criterion(pos, neg)

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        # Periodic validation and checkpointing of the best model.
        if epoch % validation_freq == 0:
            create_dir_not_exists('./checkpoint')
            model.eval()
            evaluator = LinkPredictionEvaluator(model, kg_val)
            evaluator.evaluate(b_size=eval_b_size, verbose=False)
            _, hit_at_k = evaluator.hit_at_k(10)  # filtered validation hit@k
            print('Epoch [{:>5}/{:>5}] '.format(epoch, n_epochs), end='')
            if hit_at_k > best_score:
                # FIX: update best_score BEFORE saving, so the checkpoint
                # records the new best score rather than the stale previous
                # one (which broke resume-time comparisons).
                best_score = hit_at_k
                save_ckpt(model, optimizer, epoch, best_score, model_save_path)
                improve = '*'  # mark improved results with a star
                last_improve = epoch  # any increase in val hit@k counts
            else:
                improve = ''
            msg = '| mean loss: {:>8.3f}, Time: {}, Val Hit@10: {:>5.2%} {}'
            print(
                msg.format(running_loss / len(dataloader), time_since(start),
                           hit_at_k, improve))
        model.normalize_parameters()
        if epoch - last_improve > require_improvement:
            # Val hit@k has not improved for too many epochs: stop training.
            print("\nNo optimization for a long time, auto-stopping...")
            break

    print('Training done, start evaluate on test data...')
    print('lr: {}, dim {}, device: {}, eval batch size: {}, optim: {}, GDR: {}'\
    .format(lr, emb_dim, device, eval_b_size, opt_method, GDR))

    # Testing the best checkpoint on test dataset
    load_ckpt(model_save_path, model, optimizer)
    model.eval()
    lp_evaluator = LinkPredictionEvaluator(model, kg_test)
    lp_evaluator.evaluate(eval_b_size, verbose=False)
    lp_evaluator.print_results()
    print(f'Total time cost: {time_since(start)}')
Beispiel #3
0
class Trainer(object):
    def __init__(self, cfg):
        """Set up model, optimizer, data loaders, losses and logging from a
        config node.

        Args:
            cfg: configuration object with MODEL, INPUT, SOLVER and
                OUTPUT_DIR sections (presumably a yacs-style CfgNode —
                confirm against the project's config module).
        """
        self.cfg = cfg
        self.device = cfg.MODEL.DEVICE

        # Input geometry. Height/width must be multiples of 32 (asserted
        # below) — presumably because the encoder downsamples 5 times.
        self.height = cfg.INPUT.HEIGHT
        self.width = cfg.INPUT.WIDTH
        self.scales = cfg.INPUT.SCALES
        self.frame_ids = cfg.INPUT.FRAME_IDS
        assert self.height % 32 == 0, "'height' must be a multiple of 32"
        assert self.width % 32 == 0, "'width' must be a multiple of 32"

        self.num_epochs = cfg.SOLVER.NUM_EPOCHS
        self.batch_size = cfg.SOLVER.IMS_PER_BATCH
        self.disparity_smoothness = cfg.SOLVER.DISPARITY_SMOOTHNESS
        self.min_depth = cfg.SOLVER.MIN_DEPTH
        self.max_depth = cfg.SOLVER.MAX_DEPTH

        # Training progress counters: epochs completed and optimizer steps.
        self.epoch = 0
        self.step = 0

        self.output_dir = cfg.OUTPUT_DIR
        self.log_freq = cfg.SOLVER.LOG_FREQ
        self.val_freq = cfg.SOLVER.VAL_FREQ
        # Tensorboard writers
        # NOTE(review): the log-dir name embeds str(datetime.now()), which
        # contains spaces and colons — may be awkward on some filesystems.
        now = datetime.datetime.now()
        self.writers = {}
        for mode in ["train", "valid"]:
            self.writers[mode] = SummaryWriter(
                os.path.join(self.output_dir, "{} {}".format(mode, now)))

        # Model
        self.model = MonodepthModel(cfg)
        self.model.to(self.device)

        # Optimizer
        self.model_optimizer = optim.Adam(self.model.parameters_to_train(),
                                          cfg.SOLVER.BASE_LR)
        self.model_lr_scheduler = optim.lr_scheduler.StepLR(
            self.model_optimizer, cfg.SOLVER.SCHEDULER_STEP_SIZE,
            cfg.SOLVER.SCHEDULER_GAMMA)

        # Data
        self.train_loader = make_data_loader(cfg, is_train=True)
        self.val_loader = make_data_loader(cfg, is_train=False)
        # Persistent iterator so validate() can pull single minibatches.
        self.val_iter = iter(self.val_loader)
        logger.info("Train dataset size: {}".format(
            len(self.train_loader.dataset)))
        logger.info("Valid dataset size: {}".format(
            len(self.val_loader.dataset)))

        # Loss
        self.ssim = SSIM()
        self.ssim.to(self.device)
        self.gps_loss = MSELoss()
        self.gps_loss.to(self.device)

        # Per-scale geometry helpers: backproject depth maps to 3D points
        # and project 3D points into neighbouring camera views.
        self.backproject_depth = {}
        self.project_3d = {}
        for scale in self.scales:
            h = self.height // (2**scale)
            w = self.width // (2**scale)

            self.backproject_depth[scale] = BackprojectDepth(
                self.batch_size, h, w)
            self.backproject_depth[scale].to(self.device)

            self.project_3d[scale] = Project3D(self.batch_size, h, w)
            self.project_3d[scale].to(self.device)

    def train(self):
        """Freeze all trainable parameters except the map-pose networks,
        then run the full training schedule epoch by epoch."""
        # Freeze everything first...
        for param in self.model.parameters_to_train():
            param.requires_grad = False
        # ...then re-enable gradients only for the map-pose subnetworks.
        # NOTE(review): this calls model.parameters() with a module-name
        # list — presumably a custom override on MonodepthModel; confirm.
        pose_modules = ['map_pose_encoder', 'map_pose_decoder']
        for param in self.model.parameters(pose_modules):
            param.requires_grad = True

        while self.epoch < self.num_epochs:
            status = "Epoch {}/{}  LR {}".format(self.epoch + 1,
                                                 self.num_epochs,
                                                 self.get_lr())
            logger.info(status)

            self.run_epoch()
            self.epoch += 1
            self.model_lr_scheduler.step()
            self.checkpoint()

    def run_epoch(self):
        """Run a single epoch of training and validation.

        For every training batch: forward pass, loss computation, one
        optimizer step; periodically logs losses and runs a one-minibatch
        validation based on the global step counter.
        """
        self.model.set_train()

        # The batch index was never used, so iterate the loader directly.
        for inputs in tqdm(self.train_loader):
            inputs, outputs = self.model.process_batch(inputs)
            losses = self.compute_losses(inputs, outputs)

            self.model_optimizer.zero_grad()
            losses["loss"].backward()
            self.model_optimizer.step()

            self.step += 1
            if self.step % self.log_freq == 0:
                self.log_losses(losses, is_train=True)
            if self.step % self.val_freq == 0:
                self.validate()

    def validate(self):
        """Validate the model on a single minibatch and log progress.

        Pulls one batch from the persistent validation iterator, restarting
        it when exhausted, and restores training mode afterwards.
        """
        self.model.set_eval()
        try:
            # FIX: use the next() builtin — the Python-2-style .next()
            # method does not exist on Python 3 / current DataLoader
            # iterators and raises AttributeError.
            inputs = next(self.val_iter)
        except StopIteration:
            # Validation loader exhausted: restart the iterator.
            self.val_iter = iter(self.val_loader)
            inputs = next(self.val_iter)

        with torch.no_grad():
            inputs, outputs = self.model.process_batch(inputs)
            losses = self.compute_losses(inputs, outputs)

            self.log_losses(losses, is_train=False)
            self.log_images(inputs, outputs, is_train=False)
            del inputs, outputs, losses

        self.model.set_train()

    def compute_losses(self, inputs, outputs):
        """Compute the reprojection, map and GPS losses for a minibatch.

        Returns a dict with one image-loss and one map-loss entry per
        scale, a GPS entry, and the combined total under the key "loss".
        """
        losses = {}

        # Warp the source frames and map views into the target frame first;
        # both populate `outputs` in place.
        self.generate_images_pred(inputs, outputs)
        self.generate_map_pred(inputs, outputs)

        scale_total = 0
        for s in self.scales:
            image_term = self.compute_image_loss(inputs, outputs, s)
            map_term = self.compute_map_loss(inputs, outputs, s)
            losses["loss/{}".format(s)] = image_term
            losses["loss/map{}".format(s)] = map_term
            scale_total = scale_total + image_term + map_term

        # Average over scales, then add the (unscaled) GPS term.
        gps_term = self.compute_gps_loss(inputs, outputs)
        losses["loss/gps"] = gps_term
        losses["loss"] = scale_total / len(self.scales) + gps_term
        return losses

    def generate_images_pred(self, inputs, outputs):
        """Generate the warped (reprojected) color images for a minibatch.
        Generated images are saved into the `outputs` dictionary.
        """
        for scale in self.scales:
            disp = outputs[("disp", scale)]
            # Upsample each scale's disparity to full resolution before
            # converting to depth, so all losses compare at input size.
            disp = F.interpolate(disp, [self.height, self.width],
                                 mode="bilinear",
                                 align_corners=False)
            source_scale = 0

            _, depth = disp_to_depth(disp, self.min_depth, self.max_depth)

            outputs[("depth", 0, scale)] = depth

            # Warp every source frame (all frame_ids after the target at
            # index 0) into the target view.
            for i, frame_id in enumerate(self.frame_ids[1:]):

                if frame_id == "s":
                    # Stereo partner: fixed, known camera-to-camera transform.
                    T = inputs["stereo_T"]
                else:
                    # Temporal neighbour: predicted relative camera pose.
                    T = outputs[("cam_T_cam", 0, frame_id)]

                # Backproject target depth to 3D points, then project them
                # into the source camera to get per-pixel sampling coords.
                cam_points = self.backproject_depth[source_scale](
                    depth, inputs[("inv_K", frame_id, source_scale)])
                pix_coords = self.project_3d[source_scale](
                    cam_points, inputs[("K", frame_id, source_scale)], T)

                outputs[("sample", frame_id, scale)] = pix_coords

                # Sample the source frame at the projected coordinates.
                outputs[("color", frame_id, scale)] = F.grid_sample(
                    inputs[("color", frame_id, source_scale)],
                    outputs[("sample", frame_id, scale)],
                    padding_mode="border")

                # Keep the un-warped source frame for identity reprojection.
                outputs[("color_identity", frame_id, scale)] = \
                    inputs[("color", frame_id, source_scale)]

    def generate_map_pred(self, inputs, outputs):
        """Generate the warped (reprojected) map views for a minibatch.
        Generated maps are saved into the `outputs` dictionary.

        Mirrors `generate_images_pred`, but warps the frame-0 map view with
        the predicted map-to-camera transform instead of warping the
        neighbouring color frames.
        """
        for scale in self.scales:
            disp = outputs[("disp", scale)]
            # Upsample each scale's disparity to full resolution before
            # converting to depth.
            disp = F.interpolate(disp, [self.height, self.width],
                                 mode="bilinear",
                                 align_corners=False)
            source_scale = 0

            _, depth = disp_to_depth(disp, self.min_depth, self.max_depth)

            outputs[("depth", 0, scale)] = depth

            # Only the target frame (id 0) has an associated map view.
            frame_id = 0
            T = outputs[("map_cam_T_cam", 0, frame_id)]

            # Backproject depth to 3D points and project into the map view.
            cam_points = self.backproject_depth[source_scale](
                depth, inputs[("inv_K", frame_id, source_scale)])
            pix_coords = self.project_3d[source_scale](cam_points,
                                                       inputs[("K", frame_id,
                                                               source_scale)],
                                                       T)

            outputs[("sample", frame_id, scale)] = pix_coords

            # Sample the map view at the projected coordinates.
            outputs[("map_view", frame_id, scale)] = F.grid_sample(
                inputs[("map_view", frame_id, source_scale)],
                outputs[("sample", frame_id, scale)],
                padding_mode="border")

            # Keep the un-warped map view for identity comparison.
            outputs[("map_view_identity", frame_id, scale)] = \
                inputs[("map_view", frame_id, source_scale)]

    def compute_image_loss(self, inputs, outputs, scale):
        """Photometric loss for one scale: per-pixel minimum reprojection
        with identity automasking (Monodepth2-style), plus edge-aware
        disparity smoothness.

        Args:
            inputs: batch dict; reads target colors at `scale` and scale 0.
            outputs: prediction dict; reads warped colors and disparity,
                writes the "identity_selection/{scale}" automask.
            scale: pyramid scale index (smoothness is down-weighted by 2**scale).

        Returns:
            Scalar tensor: mean min-reprojection loss + weighted smoothness.
        """
        loss = 0
        source_scale = 0
        disp = outputs[("disp", scale)]
        color = inputs[("color", 0, scale)]
        target = inputs[("color", 0, source_scale)]

        # Reprojection error of each warped source frame against the target.
        reprojection_losses = [
            self.compute_reprojection_loss(
                outputs[("color", frame_id, scale)], target)
            for frame_id in self.frame_ids[1:]
        ]
        reprojection_loss = torch.cat(reprojection_losses, 1)

        # Identity (un-warped) error, used to automask static pixels.
        identity_reprojection_losses = [
            self.compute_reprojection_loss(
                inputs[("color", frame_id, source_scale)], target)
            for frame_id in self.frame_ids[1:]
        ]
        identity_reprojection_loss = torch.cat(identity_reprojection_losses, 1)

        # Add tiny noise to break ties between identity and warped losses.
        # Fix: create the noise on the loss tensor's own device instead of
        # hard-coding .cuda(), so CPU-only runs do not crash.
        identity_reprojection_loss += torch.randn(
            identity_reprojection_loss.shape,
            device=identity_reprojection_loss.device) * 0.00001

        combined = torch.cat((identity_reprojection_loss, reprojection_loss),
                             dim=1)

        if combined.shape[1] == 1:
            to_optimise = combined
            # Fix: `idxs` was undefined on this branch (latent NameError at
            # the automask write below). With a single candidate the minimum
            # is trivially index 0.
            idxs = torch.zeros_like(combined[:, 0], dtype=torch.long)
        else:
            to_optimise, idxs = torch.min(combined, dim=1)

        # Pixels where a warped loss won the minimum (index beyond the
        # identity channels) are kept; identity wins are masked out.
        outputs["identity_selection/{}".format(scale)] = (
            idxs > identity_reprojection_loss.shape[1] - 1).float()

        loss += to_optimise.mean()

        # Edge-aware smoothness on mean-normalized disparity; 1e-7 guards
        # against division by zero.
        mean_disp = disp.mean(2, True).mean(3, True)
        norm_disp = disp / (mean_disp + 1e-7)
        smooth_loss = get_smooth_loss(norm_disp, color)

        loss += self.disparity_smoothness * smooth_loss / (2**scale)
        return loss

    def compute_map_loss(self, inputs, outputs, scale):
        """Masked L1 loss between the warped map view at `scale` and the
        full-resolution predicted map.

        Pixels where the target map has no activation across channels
        (channel-wise max is zero) contribute nothing to the loss.
        """
        warped = outputs[("map_view", 0, scale)]
        target = inputs[("map_pred", 0, 0)]

        # Channel-wise max acts as a validity/occupancy mask.
        mask, _ = torch.max(target, dim=1)

        per_pixel = torch.abs(target - warped).mean(1)
        return (per_pixel * mask).mean()

    def compute_gps_loss(self, inputs, outputs):
        """Scale-supervision loss: compare the norm of each predicted
        translation against the norm of the measured GPS displacement,
        summed over all source frames.
        """
        total = 0
        for frame_id in self.frame_ids[1:]:
            predicted = outputs[("translation", 0, frame_id)][:, 0]
            measured = inputs['gps_delta', frame_id]
            predicted_norm = torch.norm(predicted, dim=2)
            measured_norm = torch.norm(measured, dim=1, keepdim=True)
            total = total + self.gps_loss(predicted_norm, measured_norm)
        return total

    def compute_reprojection_loss(self, pred, target):
        """Photometric error between two image batches: a per-pixel blend of
        0.85 * SSIM + 0.15 * L1, averaged over channels (keeping the
        channel dimension as size 1).
        """
        l1 = torch.abs(target - pred).mean(1, True)
        ssim = self.ssim(pred, target).mean(1, True)
        return 0.85 * ssim + 0.15 * l1

    def get_lr(self):
        """Return the learning rate of the optimizer's first param group
        (None if there are no param groups)."""
        groups = self.model_optimizer.param_groups
        if groups:
            return groups[0]['lr']
        return None

    def log_losses(self, losses, is_train=True):
        """Write an event to the tensorboard events file
        """
        writer = self.writers["train" if is_train else "valid"]
        for name, value in losses.items():
            writer.add_scalar("{}".format(name), value, self.step)

    def log_images(self, inputs, outputs, is_train=True):
        """Write a sample of input/predicted images to tensorboard."""
        writer = self.writers["train" if is_train else "valid"]
        step = self.step

        sample_count = min(4, self.batch_size)  # write a maxmimum of four images
        for j in range(sample_count):
            for s in self.scales:
                for frame_id in self.frame_ids:
                    writer.add_image("color_{}_{}/{}".format(frame_id, s, j),
                                     inputs[("color", frame_id, s)][j].data,
                                     step)
                    # Warped predictions only exist for source frames, and
                    # are logged at full resolution only.
                    if frame_id != 0 and s == 0:
                        writer.add_image(
                            "color_pred_{}_{}/{}".format(frame_id, s, j),
                            outputs[("color", frame_id, s)][j].data, step)

                writer.add_image("disp_{}/{}".format(s, j),
                                 normalize_image(outputs[("disp", s)][j]),
                                 step)

                # Automask is single-channel; add a channel axis for add_image.
                writer.add_image(
                    "automask_{}/{}".format(s, j),
                    outputs["identity_selection/{}".format(s)][j][None, ...],
                    step)

                writer.add_image("map_view_{}/{}".format(s, j),
                                 inputs[("map_view", 0, s)][j].data, step)
                writer.add_image("map_pred_{}/{}".format(s, j),
                                 inputs[("map_pred", 0, s)][j].data, step)
                if s == 0:
                    writer.add_image("map_warp_{}/{}".format(s, j),
                                     outputs[("map_view", 0, s)][j].data,
                                     step)

    def checkpoint(self):
        """Save model weights plus trainer state to disk and repoint the
        `latest_weights` symlink at the new epoch folder."""
        models_dir = os.path.join(self.output_dir, "models")
        save_folder = os.path.join(models_dir, "weights_{}".format(self.epoch))
        os.makedirs(save_folder, exist_ok=True)

        logger.info("Saving to {}".format(save_folder))
        self.model.save_model(save_folder)

        # Persist everything needed to resume training from this point.
        trainer_state = {
            'epoch': self.epoch,
            'step': self.step,
            'optimizer': self.model_optimizer.state_dict(),
            'scheduler': self.model_lr_scheduler.state_dict(),
        }
        torch.save(trainer_state, os.path.join(save_folder, "trainer.pth"))

        # Symlink latest model for resuming; replace any stale link.
        latest_model_path = os.path.join(models_dir, "latest_weights")
        if os.path.islink(latest_model_path):
            os.unlink(latest_model_path)
        os.symlink(os.path.basename(save_folder), latest_model_path)

    def load_checkpoint(self, load_optimizer=True):
        """Load model(s) from disk
        """
        save_folder = os.path.join(self.output_dir, "models", "latest_weights")
        assert os.path.isdir(save_folder), "Cannot find folder {}".format(
            save_folder)

        logger.info("Loading from {}".format(save_folder))
        self.model.load_model(save_folder)
        self.model.to(self.device)

        if not load_optimizer:
            return

        # Restore trainer state (epoch, step, scheduler) if present.
        save_path = os.path.join(save_folder, "trainer.pth")
        if not os.path.isfile(save_path):
            logger.info("Could not load trainer")
            return

        logger.info("Loading trainer...")
        trainer_state = torch.load(save_path)
        self.epoch = trainer_state['epoch']
        self.step = trainer_state['step']
        self.model_lr_scheduler.load_state_dict(trainer_state['scheduler'])
        # Optimizer state is saved but deliberately not restored here.
        logger.info(
            "Unresolved issue: Unable to load saved optimizer weights."
        )
        # self.model_optimizer.load_state_dict(trainer_state['optimizer'])