def data_and_model_loader(config,
                          pretrained_depth_path,
                          pretrained_pose_path,
                          seq=None,
                          load_depth=True):
    if seq is None:
        seq = config['test_seq']
    else:
        seq = [seq]
    device = torch.device(
        "cuda") if torch.cuda.is_available() else torch.device("cpu")
    test_dset = KittiLoaderPytorch(
        config, [seq, seq, seq],
        mode='test',
        transform_img=get_data_transforms(config)['test'])
    test_dset_loaders = torch.utils.data.DataLoader(
        test_dset,
        batch_size=config['minibatch'],
        shuffle=False,
        num_workers=6)
    eval_dsets = {'test': test_dset_loaders}

    if load_depth:
        depth_model = models.depth_model(config).to(device)
    pose_model = models.pose_model(config).to(device)

    if pretrained_depth_path is not None and load_depth:
        depth_model.load_state_dict(torch.load(pretrained_depth_path))
    if pretrained_pose_path is not None:
        pose_model.load_state_dict(torch.load(pretrained_pose_path))

    pose_model.train(False)
    pose_model.eval()
    if load_depth:
        depth_model.train(False)
        depth_model.eval()
    else:
        depth_model = None

    mmodels = [depth_model, pose_model]
    return test_dset_loaders, mmodels, device
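
# A minimal usage sketch (assuming a parsed `config` dict and checkpoint files from
# a prior training run; the file names and the '09' sequence id are placeholders,
# not taken from the code above):
#
#   loaders, (depth_model, pose_model), device = data_and_model_loader(
#       config,
#       pretrained_depth_path='results/some-run/depth-best-loss.pth',
#       pretrained_pose_path='results/some-run/pose-best-loss.pth',
#       seq='09')
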
# Example #2
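# This fragment assumes that `config`, `seq`, `device`, and a results-directory
# path named `dir` are defined by the surrounding (omitted) code; `dir` here is
# presumably not the Python builtin of the same name.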
    ### dataset and model loading
    from data.kitti_loader_stereo import KittiLoaderPytorch
    test_dset = KittiLoaderPytorch(
        config, [[seq], [seq], [seq]],
        mode='test',
        transform_img=get_data_transforms(config)['test'])
    test_dset_loaders = torch.utils.data.DataLoader(
        test_dset,
        batch_size=config['minibatch'],
        shuffle=False,
        num_workers=6)

    import models.packetnet_depth_and_egomotion as models_packetnet
    import models.depth_and_egomotion as models

    depth_model = models.depth_model(config).to(device)
    pose_model = models_packetnet.pose_model(config).to(device)
    pretrained_depth_path = glob.glob(
        '{}/**depth**best-loss-val_seq-**-test_seq-{}**.pth'.format(dir, ''))[0]
    pretrained_pose_path = glob.glob(
        '{}/**pose**best-loss-val_seq-**-test_seq-{}**.pth'.format(dir, ''))[0]
    depth_model.load_state_dict(torch.load(pretrained_depth_path))
    pose_model.load_state_dict(torch.load(pretrained_pose_path))
    pose_model.train(False).eval()
    depth_model.train(False).eval()

    ### Plane Model
    from models.plane_net import PlaneModel, scale_recovery
    from losses import Plane_Height_loss
    plane_loss = Plane_Height_loss(config)
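

# `exp_lr_scheduler` is called inside main() below but is not defined in this
# excerpt. A minimal sketch of such a helper (the 0.1 decay factor and the
# "decay every lr_decay_epoch epochs" schedule are assumptions; the original
# helper may differ):
def exp_lr_scheduler(optimizer, epoch, lr_decay=0.1, lr_decay_epoch=5):
    """Multiply each parameter group's learning rate by lr_decay once every
    lr_decay_epoch epochs."""
    if epoch == 0 or epoch % lr_decay_epoch != 0:
        return optimizer
    for param_group in optimizer.param_groups:
        param_group['lr'] *= lr_decay
    return optimizer

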
def main():
    results = {}
    config['device'] = torch.device(
        "cuda") if torch.cuda.is_available() else torch.device("cpu")
    start = time.time()
    now = datetime.datetime.now()
    ts = '{}-{}-{}-{}-{}'.format(now.year, now.month, now.day, now.hour,
                                 now.minute)
    ''' Load Pretrained Models'''
    pretrained_depth_path, pretrained_pose_path = None, None
    if config['load_pretrained_depth']:
        pretrained_depth_path = glob.glob(
            '{}/**depth**best-loss-val_seq-**-test_seq-**.pth'.format(
                config['pretrained_dir']))[0]
    if config['load_pretrained_pose']:
        pretrained_pose_path = glob.glob(
            '{}/**pose**best-loss-val_seq-**-test_seq-**.pth'.format(
                config['pretrained_dir']))[0]
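    # Note: without recursive=True, '**' in these glob patterns behaves like '*',
    # so they simply match checkpoint file names directly inside pretrained_dir.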

    ## skip epoch 0 if the pose model is pretrained (epoch 0 initializes with VO)
    if config['load_pretrained_pose'] and config['load_pretrained_depth']:
        epochs = range(1, config['num_epochs'])
    else:
        epochs = range(0, config['num_epochs'])

    ### load the models
    depth_model = models_monodepth.depth_model(config).to(config['device'])
    pose_model = models_packetnet.pose_model(config).to(config['device'])

    if pretrained_depth_path is not None and config['load_pretrained_depth']:
        depth_model.load_state_dict(torch.load(pretrained_depth_path))
    if pretrained_pose_path is not None and config['load_pretrained_pose']:
        pose_model.load_state_dict(torch.load(pretrained_pose_path))
    models = [depth_model, pose_model]
    ## Load the pretrained plane estimator if using the plane loss
    if config['l_scale_recovery']:
        from models.plane_net import PlaneModel
        plane_model = PlaneModel(config).to(config['device'])
        pretrained_plane_path = glob.glob('{}/**plane**.pth'.format(
            config['pretrained_plane_dir']))[0]
        plane_model.load_state_dict(torch.load(pretrained_plane_path))
        for param in plane_model.parameters():
            param.requires_grad = False
    else:
        plane_model = None

    if config['freeze_depthnet']: print('Freezing depth network weights.')
    if config['freeze_posenet']: print('Freezing pose network weights.')
    for param in depth_model.parameters():
        param.requires_grad = not config['freeze_depthnet']
    for param in pose_model.parameters():
        param.requires_grad = not config['freeze_posenet']

    params = [
        {'params': depth_model.parameters()},
        {'params': pose_model.parameters(), 'lr': 2 * config['lr']},
    ]
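    # The second parameter group above overrides the base learning rate so that
    # the pose network trains at twice config['lr'].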
    loss = losses.Compute_Loss(config, plane_model=plane_model)
    optimizer = torch.optim.Adam(params,
                                 lr=config['lr'],
                                 weight_decay=config['wd'])  #, amsgrad=True)
    trainer = Trainer(config, models, loss, optimizer)
    cudnn.benchmark = True

    best_val_loss = {}
    best_loss_epoch = {}
    for key, dset in eval_dsets.items():
        best_val_loss[key] = 1e5

    for epoch in epochs:
        optimizer = exp_lr_scheduler(
            optimizer, epoch, lr_decay_epoch=config['lr_decay_epoch']
        )  ## reduce learning rate as training progresses
        print("Epoch {}".format(epoch))
        train_losses = trainer.forward(dset_loaders['train'], epoch, 'train')
        with torch.no_grad():
            val_losses = trainer.forward(dset_loaders['val'], epoch, 'val')


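        # Set up the TensorBoard writers on the first epoch that is trained
        # (epoch 0, or epoch 1 when a pretrained pose/depth pair skips epoch 0).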
        if epoch == 0 or (epoch == 1 and config['load_pretrained_pose']):
            val_writer = SummaryWriter(
                comment="tw-val-{}-test_seq-{}_val".format(
                    args.val_seq[0], args.test_seq[0]))
            train_writer = SummaryWriter(
                comment="tw-val-{}-test_seq-{}_train".format(
                    args.val_seq[0], args.test_seq[0]))

        if train_losses is not None and val_losses is not None:
            for key, value in train_losses.items():
                train_writer.add_scalar('{}'.format(key), value, epoch + 1)
                val_writer.add_scalar('{}'.format(key), val_losses[key],
                                      epoch + 1)

        for key, dset in eval_dsets.items():
            print("{} Set, Epoch {}".format(key, epoch))

            if epoch > 0:
                ###plot images, depth map, explainability mask
                img_array, disparity, exp_mask, d = test_depth_and_reconstruction(
                    device, models, dset, config, epoch=epoch)

                source_disp, reconstructed_disp, d_masks = d[0], d[1], d[2]
                img_array = plot_img_array(img_array)
                train_writer.add_image(key + '/imgs', img_array, epoch + 1)
                for i, d in enumerate(disparity):
                    train_writer.add_image(
                        key + '/depth-{}/target-depth'.format(i), plot_disp(d),
                        epoch + 1)

                    ### For depth consistency
                    train_writer.add_image(
                        key + '/depth-{}/source-depth'.format(i),
                        plot_disp(source_disp[i]), epoch + 1)
                    train_writer.add_image(
                        key + '/depth-{}/reconstructed-depth'.format(i),
                        plot_disp(reconstructed_disp[i]), epoch + 1)

                d_masks = plot_img_array(d_masks)
                train_writer.add_image(key + '/depth/masks', d_masks,
                                       epoch + 1)

                exp_mask = plot_img_array(exp_mask)
                train_writer.add_image(key + '/exp_mask', exp_mask, epoch + 1)

                ###evaluate trajectories
                if args.data_format == 'odometry':
                    est_lie_alg, gt_lie_alg, est_traj, gt_traj, errors = test_trajectory(
                        config, device, models, dset, epoch)

                    errors = plot_6_by_1(est_lie_alg - gt_lie_alg,
                                         title='6x1 Errors')
                    pose_vec_est = plot_6_by_1(est_lie_alg, title='Est')
                    pose_vec_gt = plot_6_by_1(gt_lie_alg, title='GT')
                    est_traj_img = plot_multi_traj(est_traj, 'Est.', gt_traj,
                                                   'GT', key + ' Set')

                    train_writer.add_image(key + '/est_traj', est_traj_img,
                                           epoch + 1)
                    train_writer.add_image(key + '/errors', errors, epoch + 1)
                    train_writer.add_image(key + '/gt_lie_alg', pose_vec_gt,
                                           epoch + 1)
                    train_writer.add_image(key + '/est_lie_alg', pose_vec_est,
                                           epoch + 1)

                    results[key] = {
                        'val_seq': args.val_seq,
                        'test_seq': args.test_seq,
                        'epochs': epoch + 1,
                        'est_pose_vecs': est_lie_alg,
                        'gt_pose_vecs': gt_lie_alg,
                        'est_traj': est_traj,
                        'gt_traj': gt_traj,
                    }

                if args.save_results:  ##Save the best models
                    os.makedirs('results/{}'.format(config['date']),
                                exist_ok=True)

                    current_val_loss = (val_losses['l_reconstruct_forward'] +
                                        val_losses['l_reconstruct_inverse'])
                    if current_val_loss < best_val_loss[key] and epoch > 0:  # and epoch > 2*(config['iterations']-1):
                        best_val_loss[key] = current_val_loss
                        best_loss_epoch[key] = epoch
                        depth_dict_loss = depth_model.state_dict()
                        pose_dict_loss = pose_model.state_dict()
                        if key == 'val':
                            print("Lowest validation loss (saving model)")
                            torch.save(
                                depth_dict_loss,
                                'results/{}/{}-depth-best-loss-val_seq-{}-test_seq-{}.pth'
                                .format(config['date'], ts, args.val_seq[0],
                                        args.test_seq[0]))
                            torch.save(
                                pose_dict_loss,
                                'results/{}/{}-pose-best-loss-val_seq-{}-test_seq-{}.pth'
                                .format(config['date'], ts, args.val_seq[0],
                                        args.test_seq[0]))

                        if args.data_format == 'odometry':
                            results[key]['best_loss_epoch'] = best_loss_epoch[
                                key]
                            save_obj(
                                results,
                                'results/{}/{}-results-val_seq-{}-test_seq-{}'.
                                format(config['date'], ts, args.val_seq[0],
                                       args.test_seq[0]))
                        save_obj(config,
                                 'results/{}/config'.format(config['date']))
                        with open("results/{}/config.txt".format(config['date']),
                                  "w") as f:
                            f.write(str(config))
    save_obj(loss.scale_factor_list,
             'results/{}/scale_factor'.format(config['date']))
    duration = timeSince(start)
    print("Training complete (duration: {})".format(duration))