Пример #1
0
def main(argv=None):
    """Evaluate one or more experiments on a dataset split and display the results.

    All settings come from the module-level ``FLAGS``:

    * ``FLAGS.gpu`` -- exported as ``CUDA_VISIBLE_DEVICES``.
    * ``FLAGS.dataset`` -- one of 'kitti', 'nao', 'chairs', 'sintel', 'mdb'.
    * ``FLAGS.variant`` -- split name; ``data_input.input_<variant>`` is the
      input function handed to each experiment.
    * ``FLAGS.num`` -- number of pairs. For the known test splits a value of
      -1 is replaced by a fixed pair count (presumably the size of that split).
    * ``FLAGS.directory`` -- sub-directory name, used only for 'nao'.
    * ``FLAGS.ex`` -- comma-separated experiment names.

    Args:
        argv: Unused.

    Raises:
        ValueError: if ``FLAGS.dataset`` is not a supported dataset.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu

    print("-- evaluating: on {} pairs from {}/{}".format(
        FLAGS.num, FLAGS.dataset, FLAGS.variant))

    default_config = config_dict()
    dirs = default_config['dirs']

    if FLAGS.dataset == 'kitti':
        data = KITTIData(dirs['data'], development=True)
        data_input = KITTIInput(data,
                                batch_size=1,
                                normalize=False,
                                dims=(384, 1280))
        # test_2015 / test_2012 are KITTI splits: -1 falls back to their
        # fixed pair counts.
        if FLAGS.variant == 'test_2015' and FLAGS.num == -1:
            FLAGS.num = 200
        elif FLAGS.variant == 'test_2012' and FLAGS.num == -1:
            FLAGS.num = 195
    elif FLAGS.dataset == 'nao':
        data = NaoData(dirs['data'], development=True)
        data_input = NaoInput(data,
                              batch_size=1,
                              normalize=False,
                              dir_name=FLAGS.directory,
                              dims=(192, 256))
    elif FLAGS.dataset == 'chairs':
        data = ChairsData(dirs['data'], development=True)
        data_input = ChairsInput(data,
                                 batch_size=1,
                                 normalize=False,
                                 dims=(384, 512))
    elif FLAGS.dataset == 'sintel':
        data = SintelData(dirs['data'], development=True)
        data_input = SintelInput(data,
                                 batch_size=1,
                                 normalize=False,
                                 dims=(512, 1024))
        # Must stay inside the sintel branch; at the outer level it would
        # apply to every dataset and break the elif chain below.
        if FLAGS.variant in ['test_clean', 'test_final'] and FLAGS.num == -1:
            FLAGS.num = 552
    elif FLAGS.dataset == 'mdb':
        data = MiddleburyData(dirs['data'], development=True)
        data_input = MiddleburyInput(data,
                                     batch_size=1,
                                     normalize=False,
                                     dims=(512, 640))
        if FLAGS.variant == 'test' and FLAGS.num == -1:
            FLAGS.num = 12
    else:
        raise ValueError(
            "Invalid dataset. Dataset must be one of "
            "{kitti, nao, chairs, sintel, mdb}")

    # Resolved once and passed to every experiment below.
    input_fn = getattr(data_input, 'input_' + FLAGS.variant)

    results = []
    for name in FLAGS.ex.split(','):
        result, image_names = _evaluate_experiment(name, input_fn, data_input)
        results.append(result)

    # image_names comes from the last experiment evaluated.
    display(results, image_names)
Пример #2
0
def main(argv=None):
    """Train a model on the dataset named by the experiment's run config.

    The experiment is selected by ``FLAGS.ex`` (``FLAGS.ow`` controls
    overwriting). Its ``run`` section supplies ``gpu_list``, ``batch_size``,
    ``development``, the optional ``dataset`` (default 'kitti') and
    ``interactive_plot``. For each dataset, the shared ``train`` section is
    deep-copied and overlaid with ``train_<dataset>`` to form the trainer
    parameters. Training runs for that config's ``num_iters`` (default 0).

    Args:
        argv: Unused.

    Raises:
        ValueError: if the configured dataset is not one of
            synthia, kitti, kitti_ft, cityscapes, chairs.
    """
    experiment = Experiment(
        name=FLAGS.ex,
        overwrite=FLAGS.ow)
    dirs = experiment.config['dirs']
    run_config = experiment.config['run']

    gpu_list_param = run_config['gpu_list']

    # Once CUDA_VISIBLE_DEVICES is set, the selected GPUs are renumbered from
    # 0 inside this process. Device names must therefore be 0..n-1, whatever
    # physical ids were selected (this also holds for a single int id).
    if isinstance(gpu_list_param, int):
        os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_list_param)
        num_gpus = 1
    else:
        os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list_param
        num_gpus = len(gpu_list_param.split(','))
    gpu_list = list(range(num_gpus))
    # Per-device share of the global batch (truncated towards zero).
    gpu_batch_size = int(run_config['batch_size'] / max(len(gpu_list), 1))
    devices = ['/gpu:' + str(gpu_num) for gpu_num in gpu_list]

    train_dataset = run_config.get('dataset', 'kitti')

    def train_config(section):
        """Deep-copy the shared 'train' config, overlay `section`, convert strings."""
        config = copy.deepcopy(experiment.config['train'])
        config.update(experiment.config[section])
        convert_input_strings(config, dirs)
        return config

    def run_trainer(train_batch_fn, eval_batch_fn, params, normalization,
                    **extra):
        """Build a Trainer with the options shared by all datasets and run it."""
        num_iters = params.get('num_iters', 0)
        tr = Trainer(train_batch_fn,
                     eval_batch_fn,
                     params=params,
                     normalization=normalization,
                     train_summaries_dir=experiment.train_dir,
                     eval_summaries_dir=experiment.eval_dir,
                     experiment=FLAGS.ex,
                     ckpt_dir=experiment.save_dir,
                     debug=FLAGS.debug,
                     interactive_plot=run_config.get('interactive_plot'),
                     devices=devices,
                     **extra)
        tr.run(0, num_iters)

    kdata = KITTIData(data_dir=dirs['data'],
                      fast_dir=dirs.get('fast'),
                      stat_log_dir=None,
                      development=run_config['development'])
    # Fixed-size KITTI inputs: einput feeds the flow evaluation of every
    # dataset; epinput feeds the pose (odometry) evaluation of 'kitti'.
    einput = KITTIInput(data=kdata,
                        batch_size=1,
                        normalize=False,
                        dims=(384, 1280))
    epinput = KITTIInput(data=kdata,
                         batch_size=1,
                         normalize=False,
                         dims=(384, 1280))

    if train_dataset == 'chairs':
        cconfig = train_config('train_chairs')
        cdata = ChairsData(data_dir=dirs['data'],
                           fast_dir=dirs.get('fast'),
                           stat_log_dir=None,
                           development=run_config['development'])
        cinput = ChairsInput(data=cdata,
                             batch_size=gpu_batch_size,
                             normalize=False,
                             dims=(cconfig['height'], cconfig['width']))
        run_trainer(
            lambda shift: cinput.input_raw(
                swap_images=False,
                shift=shift * run_config['batch_size']),
            lambda: einput.input_train_2012(),
            params=cconfig,
            normalization=cinput.get_normalization())

    elif train_dataset == 'kitti':
        kconfig = train_config('train_kitti')
        kinput = KITTIInput(data=kdata,
                            batch_size=gpu_batch_size,
                            normalize=False,
                            skipped_frames=True,
                            dims=(kconfig['height'], kconfig['width']))
        run_trainer(
            lambda shift: kinput.input_raw(
                swap_images=False,
                center_crop=True,
                shift=shift * run_config['batch_size'],
                epipolar_weight=kconfig.get('epipolar_weight', None)),
            lambda: einput.input_train_2012(),
            params=kconfig,
            normalization=kinput.get_normalization(),
            eval_pose_batch_fn=lambda: epinput.input_odometry(),
            eval_pose_summaries_dir=experiment.eval_pose_dir)

    elif train_dataset == 'cityscapes':
        kconfig = train_config('train_cityscapes')
        cdata = CityscapesData(data_dir=dirs['data'],
                               fast_dir=dirs.get('fast'),
                               stat_log_dir=None,
                               development=run_config['development'])
        kinput = KITTIInput(data=cdata,
                            batch_size=gpu_batch_size,
                            normalize=False,
                            skipped_frames=False,
                            dims=(kconfig['height'], kconfig['width']))
        run_trainer(
            lambda shift: kinput.input_raw(
                swap_images=False,
                center_crop=True,
                skip=[0, 1],
                shift=shift * run_config['batch_size']),
            lambda: einput.input_train_2012(),
            params=kconfig,
            normalization=kinput.get_normalization())

    elif train_dataset == 'synthia':
        sconfig = train_config('train_synthia')
        sdata = SynthiaData(data_dir=dirs['data'],
                            fast_dir=dirs.get('fast'),
                            stat_log_dir=None,
                            development=run_config['development'])
        sinput = KITTIInput(data=sdata,
                            batch_size=gpu_batch_size,
                            normalize=False,
                            dims=(sconfig['height'], sconfig['width']))
        run_trainer(
            lambda shift: sinput.input_raw(
                swap_images=False,
                center_crop=True,
                shift=shift * run_config['batch_size'],
                epipolar_weight=sconfig.get('epipolar_weight', None)),
            lambda: einput.input_train_2012(),
            params=sconfig,
            normalization=sinput.get_normalization())

    elif train_dataset == 'kitti_ft':
        ftconfig = train_config('train_kitti_ft')
        ftinput = KITTIInput(data=kdata,
                             batch_size=gpu_batch_size,
                             normalize=False,
                             dims=(ftconfig['height'], ftconfig['width']))
        # Supervised fine-tuning: the training batch ignores `shift`.
        run_trainer(
            lambda shift: ftinput.input_train_gt(40),
            lambda: einput.input_train_2015(40),
            params=ftconfig,
            normalization=ftinput.get_normalization(),
            supervised=True)

    else:
        raise ValueError(
            "Invalid dataset. Dataset must be one of "
            "{synthia, kitti, kitti_ft, cityscapes, chairs}")

    if not FLAGS.debug:
        experiment.conclude()
Пример #3
0
def main(argv=None):
    """Train a model on the dataset named by the experiment's run config.

    The experiment is selected by ``FLAGS.ex`` (``FLAGS.ow`` controls
    overwriting). Its ``run`` section supplies ``gpu_list``, ``batch_size``,
    ``development``, the optional ``dataset`` (default 'kitti') and
    ``interactive_plot``. For each dataset, the shared ``train`` section is
    deep-copied and overlaid with ``train_<dataset>`` to form the trainer
    parameters. Training runs for that config's ``num_iters`` (default 0).
    The completion time is printed at the end.

    Args:
        argv: Unused.

    Raises:
        ValueError: if the configured dataset is not one of
            synthia, kitti, kitti_ft, cityscapes, chairs, nao.
    """
    experiment = Experiment(name=FLAGS.ex, overwrite=FLAGS.ow)
    dirs = experiment.config['dirs']
    run_config = experiment.config['run']

    gpu_list_param = run_config['gpu_list']

    # Once CUDA_VISIBLE_DEVICES is set, the selected GPUs are renumbered from
    # 0 inside this process. Device names must therefore be 0..n-1, whatever
    # physical ids were selected (this also holds for a single int id).
    if isinstance(gpu_list_param, int):
        os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_list_param)
        num_gpus = 1
    else:
        os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list_param
        num_gpus = len(gpu_list_param.split(','))
    gpu_list = list(range(num_gpus))
    # Per-device share of the global batch (truncated towards zero).
    gpu_batch_size = int(run_config['batch_size'] / max(len(gpu_list), 1))
    devices = ['/gpu:' + str(gpu_num) for gpu_num in gpu_list]

    train_dataset = run_config.get('dataset', 'kitti')

    print("train_dataset:", train_dataset)

    def train_config(section):
        """Deep-copy the shared 'train' config, overlay `section`, convert strings."""
        config = copy.deepcopy(experiment.config['train'])
        config.update(experiment.config[section])
        convert_input_strings(config, dirs)
        return config

    def run_trainer(train_batch_fn, eval_batch_fn, params, normalization,
                    **extra):
        """Build a Trainer with the options shared by all datasets and run it."""
        num_iters = params.get('num_iters', 0)
        tr = Trainer(train_batch_fn,
                     eval_batch_fn,
                     params=params,
                     normalization=normalization,
                     train_summaries_dir=experiment.train_dir,
                     eval_summaries_dir=experiment.eval_dir,
                     experiment=FLAGS.ex,
                     ckpt_dir=experiment.save_dir,
                     debug=FLAGS.debug,
                     interactive_plot=run_config.get('interactive_plot'),
                     devices=devices,
                     **extra)
        tr.run(0, num_iters)

    kdata = KITTIData(data_dir=dirs['data'],
                      fast_dir=dirs.get('fast'),
                      stat_log_dir=None,
                      development=run_config['development'])
    # Fixed-size KITTI input used for evaluation by every dataset except nao.
    einput = KITTIInput(data=kdata,
                        batch_size=1,
                        normalize=False,
                        dims=(384, 1280))

    if train_dataset == 'chairs':
        cconfig = train_config('train_chairs')
        cdata = ChairsData(data_dir=dirs['data'],
                           fast_dir=dirs.get('fast'),
                           stat_log_dir=None,
                           development=run_config['development'])
        cinput = ChairsInput(data=cdata,
                             batch_size=gpu_batch_size,
                             normalize=False,
                             dims=(cconfig['height'], cconfig['width']))
        run_trainer(
            lambda shift: cinput.input_raw(
                swap_images=False,
                shift=shift * run_config['batch_size']),
            lambda: einput.input_train_2012(),
            params=cconfig,
            normalization=cinput.get_normalization())

    elif train_dataset == 'kitti':
        kconfig = train_config('train_kitti')
        kinput = KITTIInput(data=kdata,
                            batch_size=gpu_batch_size,
                            normalize=False,
                            skipped_frames=True,
                            dims=(kconfig['height'], kconfig['width']))
        run_trainer(
            lambda shift: kinput.input_raw(
                swap_images=False,
                center_crop=True,
                shift=shift * run_config['batch_size']),
            lambda: einput.input_train_2012(),
            params=kconfig,
            normalization=kinput.get_normalization())

    elif train_dataset == 'cityscapes':
        kconfig = train_config('train_cityscapes')
        cdata = CityscapesData(data_dir=dirs['data'],
                               fast_dir=dirs.get('fast'),
                               stat_log_dir=None,
                               development=run_config['development'])
        kinput = KITTIInput(data=cdata,
                            batch_size=gpu_batch_size,
                            normalize=False,
                            skipped_frames=False,
                            dims=(kconfig['height'], kconfig['width']))
        run_trainer(
            lambda shift: kinput.input_raw(
                swap_images=False,
                center_crop=True,
                skip=[0, 1],
                shift=shift * run_config['batch_size']),
            lambda: einput.input_train_2012(),
            params=kconfig,
            normalization=kinput.get_normalization())

    elif train_dataset == 'synthia':
        sconfig = train_config('train_synthia')
        sdata = SynthiaData(data_dir=dirs['data'],
                            fast_dir=dirs.get('fast'),
                            stat_log_dir=None,
                            development=run_config['development'])
        sinput = KITTIInput(data=sdata,
                            batch_size=gpu_batch_size,
                            normalize=False,
                            dims=(sconfig['height'], sconfig['width']))
        run_trainer(
            lambda shift: sinput.input_raw(
                swap_images=False,
                shift=shift * run_config['batch_size']),
            lambda: einput.input_train_2012(),
            params=sconfig,
            normalization=sinput.get_normalization())

    elif train_dataset == 'nao':
        # Adapted from the synthia branch.
        nconfig = train_config('train_nao')
        ndata = NaoData(dirs['data'], development=False)
        # NOTE(review): batch size, directory and dims are hard-coded here and
        # ignore gpu_batch_size and nconfig's height/width -- confirm intended.
        ninput = NaoInput(ndata,
                          batch_size=1,
                          normalize=False,
                          dir_name='grey400',
                          dims=(192, 256))
        # TODO: there is no evaluation input for nao yet. It is unclear whether
        # the KITTI evaluation input used by the other datasets
        # (einput.input_train_2012) would be appropriate, so the eval batch
        # function returns None for now.
        run_trainer(
            lambda shift: ninput.input_consecutive(return_shape=False),
            lambda: None,
            params=nconfig,
            normalization=ninput.get_normalization())

    elif train_dataset == 'kitti_ft':
        ftconfig = train_config('train_kitti_ft')
        ftinput = KITTIInput(data=kdata,
                             batch_size=gpu_batch_size,
                             normalize=False,
                             dims=(ftconfig['height'], ftconfig['width']))
        # Supervised fine-tuning: the training batch ignores `shift`.
        run_trainer(
            lambda shift: ftinput.input_train_gt(40),
            lambda: einput.input_train_2015(40),
            params=ftconfig,
            normalization=ftinput.get_normalization(),
            supervised=True)

    else:
        raise ValueError("Invalid dataset. Dataset must be one of "
                         "{synthia, kitti, kitti_ft, cityscapes, chairs, nao}")

    if not FLAGS.debug:
        experiment.conclude()

    print('done with training at:')
    print(str(datetime.now()))