Пример #1
0
def main(argv=None):
    """Train the experiment named by ``FLAGS.ex`` on its configured dataset.

    Reads the ``dirs`` and ``run`` sections of the experiment config, restricts
    the process to the configured GPU(s) through ``CUDA_VISIBLE_DEVICES``,
    builds the input pipeline for the dataset selected by ``run['dataset']``
    (default ``'kitti'``) and runs a ``Trainer`` for that dataset's
    ``num_iters`` iterations (0 when the key is absent).

    Args:
        argv: Unused; accepted so the function can serve as an app-style entry
            point that is handed the command line.

    Raises:
        ValueError: If ``run['dataset']`` is not one of ``synthia``, ``kitti``,
            ``kitti_ft``, ``cityscapes`` or ``chairs``.
    """
    experiment = Experiment(name=FLAGS.ex, overwrite=FLAGS.ow)
    dirs = experiment.config['dirs']
    run_config = experiment.config['run']

    # 'gpu_list' is either a single integer GPU id or a comma-separated string
    # of ids.
    gpu_list_param = run_config['gpu_list']
    if isinstance(gpu_list_param, int):
        visible_gpus = str(gpu_list_param)
        num_gpus = 1
    else:
        visible_gpus = gpu_list_param
        num_gpus = len(gpu_list_param.split(','))
    os.environ['CUDA_VISIBLE_DEVICES'] = visible_gpus

    # Once CUDA_VISIBLE_DEVICES masks the machine, the visible GPUs are
    # renumbered from 0, so device names must index into the visible set and
    # not reuse the physical id (physical GPU 2 alone is '/gpu:0').
    gpu_list = list(range(num_gpus))
    total_batch_size = run_config['batch_size']
    # The configured batch is split evenly across the GPUs in use.
    gpu_batch_size = int(total_batch_size / max(len(gpu_list), 1))
    devices = ['/gpu:' + str(gpu_num) for gpu_num in gpu_list]

    train_dataset = run_config.get('dataset', 'kitti')

    def make_data(data_cls):
        """Instantiate a dataset class with the shared directory settings."""
        return data_cls(data_dir=dirs['data'],
                        fast_dir=dirs.get('fast'),
                        stat_log_dir=None,
                        development=run_config['development'])

    def load_train_config(section):
        """Return a copy of the base 'train' config overlaid with `section`."""
        config = copy.deepcopy(experiment.config['train'])
        config.update(experiment.config[section])
        convert_input_strings(config, dirs)
        return config

    def run_training(train_fn, eval_fn, config, train_input, **trainer_kwargs):
        """Build a Trainer with the shared settings and run it from step 0."""
        trainer = Trainer(train_fn,
                          eval_fn,
                          params=config,
                          normalization=train_input.get_normalization(),
                          train_summaries_dir=experiment.train_dir,
                          eval_summaries_dir=experiment.eval_dir,
                          experiment=FLAGS.ex,
                          ckpt_dir=experiment.save_dir,
                          debug=FLAGS.debug,
                          interactive_plot=run_config.get('interactive_plot'),
                          devices=devices,
                          **trainer_kwargs)
        trainer.run(0, config.get('num_iters', 0))

    # KITTI data backs the evaluation input for every training dataset, so it
    # is always constructed, even when training on another dataset.
    kdata = make_data(KITTIData)
    einput = KITTIInput(data=kdata,
                        batch_size=1,
                        normalize=False,
                        dims=(384, 1280))

    if train_dataset == 'chairs':
        cconfig = load_train_config('train_chairs')
        cinput = ChairsInput(data=make_data(ChairsData),
                             batch_size=gpu_batch_size,
                             normalize=False,
                             dims=(cconfig['height'], cconfig['width']))
        run_training(
            lambda shift: cinput.input_raw(swap_images=False,
                                           shift=shift * total_batch_size),
            lambda: einput.input_train_2012(),
            cconfig,
            cinput)

    elif train_dataset == 'kitti':
        kconfig = load_train_config('train_kitti')
        kinput = KITTIInput(data=kdata,
                            batch_size=gpu_batch_size,
                            normalize=False,
                            skipped_frames=True,
                            dims=(kconfig['height'], kconfig['width']))
        run_training(
            lambda shift: kinput.input_raw(swap_images=False,
                                           center_crop=True,
                                           shift=shift * total_batch_size),
            lambda: einput.input_train_2012(),
            kconfig,
            kinput)

    elif train_dataset == 'cityscapes':
        csconfig = load_train_config('train_cityscapes')
        # Cityscapes data is fed through the KITTI input pipeline.
        csinput = KITTIInput(data=make_data(CityscapesData),
                             batch_size=gpu_batch_size,
                             normalize=False,
                             skipped_frames=False,
                             dims=(csconfig['height'], csconfig['width']))
        run_training(
            lambda shift: csinput.input_raw(swap_images=False,
                                            center_crop=True,
                                            skip=[0, 1],
                                            shift=shift * total_batch_size),
            lambda: einput.input_train_2012(),
            csconfig,
            csinput)

    elif train_dataset == 'synthia':
        sconfig = load_train_config('train_synthia')
        # Synthia data is fed through the KITTI input pipeline as well.
        sinput = KITTIInput(data=make_data(SynthiaData),
                            batch_size=gpu_batch_size,
                            normalize=False,
                            dims=(sconfig['height'], sconfig['width']))
        run_training(
            lambda shift: sinput.input_raw(swap_images=False,
                                           shift=shift * total_batch_size),
            lambda: einput.input_train_2012(),
            sconfig,
            sinput)

    elif train_dataset == 'kitti_ft':
        ftconfig = load_train_config('train_kitti_ft')
        ftinput = KITTIInput(data=kdata,
                             batch_size=gpu_batch_size,
                             normalize=False,
                             dims=(ftconfig['height'], ftconfig['width']))
        # Supervised fine-tuning: the training input ignores the shift, and
        # both inputs are called with the same argument (40).
        run_training(
            lambda shift: ftinput.input_train_gt(40),
            lambda: einput.input_train_2015(40),
            ftconfig,
            ftinput,
            supervised=True)

    else:
        raise ValueError("Invalid dataset. Dataset must be one of "
                         "{synthia, kitti, kitti_ft, cityscapes, chairs}")

    if not FLAGS.debug:
        experiment.conclude()