Example 1
def main(argv=None):
    experiment = Experiment(
        name=FLAGS.ex,
        overwrite=FLAGS.ow)
    dirs = experiment.config['dirs']
    run_config = experiment.config['run']

    gpu_list_param = run_config['gpu_list']

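    # gpu_list may be a single device index (int) or a comma-separated
    # string such as '0,1'. CUDA_VISIBLE_DEVICES is set to match, and the
    # global batch is split evenly across the visible devices below.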
    if isinstance(gpu_list_param, int):
        gpu_list = [gpu_list_param]
        os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_list_param)
    else:
        gpu_list = list(range(len(gpu_list_param.split(','))))
        os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list_param
    gpu_batch_size = int(run_config['batch_size'] / max(len(gpu_list), 1))
    devices = ['/gpu:' + str(gpu_num) for gpu_num in gpu_list]

    train_dataset = run_config.get('dataset', 'kitti')

    kdata = KITTIData(data_dir=dirs['data'],
                      fast_dir=dirs.get('fast'),
                      stat_log_dir=None,
                      development=run_config['development'])
    einput = KITTIInput(data=kdata,
                        batch_size=1,
                        normalize=False,
                        dims=(384, 1280))
    epinput = KITTIInput(data=kdata,
                         batch_size=1,
                         normalize=False,
                         dims=(384, 1280))

    if train_dataset == 'chairs':
        cconfig = copy.deepcopy(experiment.config['train'])
        cconfig.update(experiment.config['train_chairs'])
        convert_input_strings(cconfig, dirs)
        citers = cconfig.get('num_iters', 0)
        cdata = ChairsData(data_dir=dirs['data'],
                           fast_dir=dirs.get('fast'),
                           stat_log_dir=None,
                           development=run_config['development'])
        cinput = ChairsInput(data=cdata,
                             batch_size=gpu_batch_size,
                             normalize=False,
                             dims=(cconfig['height'], cconfig['width']))
        tr = Trainer(
              lambda shift: cinput.input_raw(swap_images=False,
                                             shift=shift * run_config['batch_size']),
              lambda: einput.input_train_2012(),
              params=cconfig,
              normalization=cinput.get_normalization(),
              train_summaries_dir=experiment.train_dir,
              eval_summaries_dir=experiment.eval_dir,
              experiment=FLAGS.ex,
              ckpt_dir=experiment.save_dir,
              debug=FLAGS.debug,
              interactive_plot=run_config.get('interactive_plot'),
              devices=devices)
        tr.run(0, citers)

    elif train_dataset == 'kitti':
        kconfig = copy.deepcopy(experiment.config['train'])
        kconfig.update(experiment.config['train_kitti'])
        convert_input_strings(kconfig, dirs)
        kiters = kconfig.get('num_iters', 0)
        kinput = KITTIInput(data=kdata,
                            batch_size=gpu_batch_size,
                            normalize=False,
                            skipped_frames=True,
                            dims=(kconfig['height'], kconfig['width']))
        tr = Trainer(
              lambda shift: kinput.input_raw(swap_images=False,
                                             center_crop=True,
                                             shift=shift * run_config['batch_size'],
                                             epipolar_weight=kconfig.get('epipolar_weight', None)),
              lambda: einput.input_train_2012(),
              params=kconfig,
              normalization=kinput.get_normalization(),
              train_summaries_dir=experiment.train_dir,
              eval_summaries_dir=experiment.eval_dir,
              experiment=FLAGS.ex,
              ckpt_dir=experiment.save_dir,
              debug=FLAGS.debug,
              interactive_plot=run_config.get('interactive_plot'),
              devices=devices,
              eval_pose_batch_fn=lambda: epinput.input_odometry(),
              eval_pose_summaries_dir=experiment.eval_pose_dir)
        tr.run(0, kiters)

    elif train_dataset == 'cityscapes':
        kconfig = copy.deepcopy(experiment.config['train'])
        kconfig.update(experiment.config['train_cityscapes'])
        convert_input_strings(kconfig, dirs)
        kiters = kconfig.get('num_iters', 0)
        cdata = CityscapesData(data_dir=dirs['data'],
                               fast_dir=dirs.get('fast'),
                               stat_log_dir=None,
                               development=run_config['development'])
        kinput = KITTIInput(data=cdata,
                            batch_size=gpu_batch_size,
                            normalize=False,
                            skipped_frames=False,
                            dims=(kconfig['height'], kconfig['width']))
        tr = Trainer(
              lambda shift: kinput.input_raw(swap_images=False,
                                             center_crop=True,
                                             skip=[0, 1],
                                             shift=shift * run_config['batch_size']),
              lambda: einput.input_train_2012(),
              params=kconfig,
              normalization=kinput.get_normalization(),
              train_summaries_dir=experiment.train_dir,
              eval_summaries_dir=experiment.eval_dir,
              experiment=FLAGS.ex,
              ckpt_dir=experiment.save_dir,
              debug=FLAGS.debug,
              interactive_plot=run_config.get('interactive_plot'),
              devices=devices)
        tr.run(0, kiters)

    elif train_dataset == 'synthia':
        sconfig = copy.deepcopy(experiment.config['train'])
        sconfig.update(experiment.config['train_synthia'])
        convert_input_strings(sconfig, dirs)
        siters = sconfig.get('num_iters', 0)
        sdata = SynthiaData(data_dir=dirs['data'],
                            fast_dir=dirs.get('fast'),
                            stat_log_dir=None,
                            development=run_config['development'])
        sinput = KITTIInput(data=sdata,
                            batch_size=gpu_batch_size,
                            normalize=False,
                            dims=(sconfig['height'], sconfig['width']))
        tr = Trainer(
              lambda shift: sinput.input_raw(swap_images=False,
                                             center_crop=True,
                                             shift=shift * run_config['batch_size'],
                                             epipolar_weight=sconfig.get('epipolar_weight', None)),
              lambda: einput.input_train_2012(),
              params=sconfig,
              normalization=sinput.get_normalization(),
              train_summaries_dir=experiment.train_dir,
              eval_summaries_dir=experiment.eval_dir,
              experiment=FLAGS.ex,
              ckpt_dir=experiment.save_dir,
              debug=FLAGS.debug,
              interactive_plot=run_config.get('interactive_plot'),
              devices=devices)
        tr.run(0, siters)

    elif train_dataset == 'kitti_ft':
        ftconfig = copy.deepcopy(experiment.config['train'])
        ftconfig.update(experiment.config['train_kitti_ft'])
        convert_input_strings(ftconfig, dirs)
        ftiters = ftconfig.get('num_iters', 0)
        ftinput = KITTIInput(data=kdata,
                             batch_size=gpu_batch_size,
                             normalize=False,
                             dims=(ftconfig['height'], ftconfig['width']))
        tr = Trainer(
              lambda shift: ftinput.input_train_gt(40),
              lambda: einput.input_train_2015(40),
              supervised=True,
              params=ftconfig,
              normalization=ftinput.get_normalization(),
              train_summaries_dir=experiment.train_dir,
              eval_summaries_dir=experiment.eval_dir,
              experiment=FLAGS.ex,
              ckpt_dir=experiment.save_dir,
              debug=FLAGS.debug,
              interactive_plot=run_config.get('interactive_plot'),
              devices=devices)
        tr.run(0, ftiters)

    else:
        raise ValueError(
            "Invalid dataset. Dataset must be one of "
            "{synthia, kitti, kitti_ft, cityscapes, chairs}")

    if not FLAGS.debug:
        experiment.conclude()
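
The GPU bookkeeping at the top of this main recurs in every example below: an int picks one device, a comma-separated string picks several, and the global batch size is divided across them. A minimal standalone sketch of that logic (parse_gpu_config is a hypothetical helper, not part of the source):

import os

def parse_gpu_config(gpu_list_param, batch_size):
    """Return (devices, per_gpu_batch_size) from an int or a '0,1' string."""
    if isinstance(gpu_list_param, int):
        gpu_list = [gpu_list_param]
        os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_list_param)
    else:
        # Once masked, TensorFlow renumbers the visible GPUs from 0.
        gpu_list = list(range(len(gpu_list_param.split(','))))
        os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list_param
    devices = ['/gpu:%d' % gpu_num for gpu_num in gpu_list]
    # The global batch is split evenly across the visible devices.
    per_gpu_batch_size = batch_size // max(len(gpu_list), 1)
    return devices, per_gpu_batch_size

# e.g. parse_gpu_config('0,1', 8) -> (['/gpu:0', '/gpu:1'], 4)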
Example 2
def main(argv=None):
    experiment = Experiment(name=FLAGS.ex, overwrite=FLAGS.ow)
    dirs = experiment.config['dirs']
    run_config = experiment.config['run']

    gpu_list_param = run_config['gpu_list']

    if isinstance(gpu_list_param, int):
        gpu_list = [gpu_list_param]
        os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_list_param)
    else:
        gpu_list = list(range(len(gpu_list_param.split(','))))
        os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list_param
    gpu_batch_size = int(run_config['batch_size'] / max(len(gpu_list), 1))
    devices = ['/gpu:' + str(gpu_num) for gpu_num in gpu_list]
    from tensorflow.python.client import device_lib
    print('using device ', device_lib.list_local_devices())

    train_dataset = run_config.get('dataset', 'kitti')

    # kdata and einput are referenced by the eval lambdas and by the
    # 'kitti'/'kitti_ft' branches below, so they must stay defined.
    kdata = KITTIData(data_dir=dirs['data'],
                      fast_dir=dirs.get('fast'),
                      stat_log_dir=None,
                      development=run_config['development'])
    einput = KITTIInput(data=kdata,
                        batch_size=1,
                        normalize=False,
                        dims=(384, 1280))

    if train_dataset == 'chairs':
        cconfig = copy.deepcopy(experiment.config['train'])
        cconfig.update(experiment.config['train_chairs'])
        convert_input_strings(cconfig, dirs)
        citers = cconfig.get('num_iters', 0)
        cdata = ChairsData(data_dir=dirs['data'],
                           fast_dir=dirs.get('fast'),
                           stat_log_dir=None,
                           development=run_config['development'])
        cinput = ChairsInput(data=cdata,
                             batch_size=gpu_batch_size,
                             normalize=False,
                             dims=(cconfig['height'], cconfig['width']))
        tr = Trainer(lambda shift: cinput.input_raw(
            swap_images=False, shift=shift * run_config['batch_size']),
                     lambda: einput.input_train_2012(),
                     params=cconfig,
                     normalization=cinput.get_normalization(),
                     train_summaries_dir=experiment.train_dir,
                     eval_summaries_dir=experiment.eval_dir,
                     experiment=FLAGS.ex,
                     ckpt_dir=experiment.save_dir,
                     debug=FLAGS.debug,
                     interactive_plot=run_config.get('interactive_plot'),
                     devices=devices)
        tr.run(0, citers)

    elif train_dataset == 'kitti':
        kconfig = copy.deepcopy(experiment.config['train'])
        kconfig.update(experiment.config['train_kitti'])
        convert_input_strings(kconfig, dirs)
        kiters = kconfig.get('num_iters', 0)
        kinput = KITTIInput(data=kdata,
                            batch_size=gpu_batch_size,
                            normalize=False,
                            skipped_frames=True,
                            dims=(kconfig['height'], kconfig['width']))
        tr = Trainer(lambda shift: kinput.input_raw(
            swap_images=False,
            center_crop=True,
            shift=shift * run_config['batch_size']),
                     lambda: einput.input_train_2012(),
                     params=kconfig,
                     normalization=kinput.get_normalization(),
                     train_summaries_dir=experiment.train_dir,
                     eval_summaries_dir=experiment.eval_dir,
                     experiment=FLAGS.ex,
                     ckpt_dir=experiment.save_dir,
                     debug=FLAGS.debug,
                     interactive_plot=run_config.get('interactive_plot'),
                     devices=devices)
        tr.run(0, kiters)

    elif train_dataset == 'cityscapes':
        kconfig = copy.deepcopy(experiment.config['train'])
        kconfig.update(experiment.config['train_cityscapes'])
        convert_input_strings(kconfig, dirs)
        kiters = kconfig.get('num_iters', 0)
        cdata = CityscapesData(data_dir=dirs['data'],
                               fast_dir=dirs.get('fast'),
                               stat_log_dir=None,
                               development=run_config['development'])
        kinput = KITTIInput(data=cdata,
                            batch_size=gpu_batch_size,
                            normalize=False,
                            skipped_frames=False,
                            dims=(kconfig['height'], kconfig['width']))
        tr = Trainer(lambda shift: kinput.input_raw(
            swap_images=False,
            center_crop=True,
            skip=[0, 1],
            shift=shift * run_config['batch_size']),
                     lambda: einput.input_train_2012(),
                     params=kconfig,
                     normalization=kinput.get_normalization(),
                     train_summaries_dir=experiment.train_dir,
                     eval_summaries_dir=experiment.eval_dir,
                     experiment=FLAGS.ex,
                     ckpt_dir=experiment.save_dir,
                     debug=FLAGS.debug,
                     interactive_plot=run_config.get('interactive_plot'),
                     devices=devices)
        tr.run(0, kiters)

    elif train_dataset == 'synthia':
        sconfig = copy.deepcopy(experiment.config['train'])
        sconfig.update(experiment.config['train_synthia'])
        convert_input_strings(sconfig, dirs)
        siters = sconfig.get('num_iters', 0)
        sdata = SynthiaData(data_dir=dirs['data'],
                            fast_dir=dirs.get('fast'),
                            stat_log_dir=None,
                            development=run_config['development'])
        sinput = KITTIInput(data=sdata,
                            batch_size=gpu_batch_size,
                            normalize=False,
                            dims=(sconfig['height'], sconfig['width']))
        tr = Trainer(lambda shift: sinput.input_raw(
            swap_images=False, shift=shift * run_config['batch_size']),
                     lambda: einput.input_train_2012(),
                     params=sconfig,
                     normalization=sinput.get_normalization(),
                     train_summaries_dir=experiment.train_dir,
                     eval_summaries_dir=experiment.eval_dir,
                     experiment=FLAGS.ex,
                     ckpt_dir=experiment.save_dir,
                     debug=FLAGS.debug,
                     interactive_plot=run_config.get('interactive_plot'),
                     devices=devices)
        tr.run(0, siters)

    elif train_dataset == 'kitti_ft':
        ftconfig = copy.deepcopy(experiment.config['train'])
        ftconfig.update(experiment.config['train_kitti_ft'])
        convert_input_strings(ftconfig, dirs)
        ftiters = ftconfig.get('num_iters', 0)
        ftinput = KITTIInput(data=kdata,
                             batch_size=gpu_batch_size,
                             normalize=False,
                             dims=(ftconfig['height'], ftconfig['width']))
        tr = Trainer(lambda shift: ftinput.input_train_gt(40),
                     lambda: einput.input_train_2015(40),
                     supervised=True,
                     params=ftconfig,
                     normalization=ftinput.get_normalization(),
                     train_summaries_dir=experiment.train_dir,
                     eval_summaries_dir=experiment.eval_dir,
                     experiment=FLAGS.ex,
                     ckpt_dir=experiment.save_dir,
                     debug=FLAGS.debug,
                     interactive_plot=run_config.get('interactive_plot'),
                     devices=devices)
        tr.run(0, ftiters)

    elif train_dataset == 'cartgripper':
        cconfig = copy.deepcopy(experiment.config['train'])
        cconfig.update(experiment.config['cartgripper'])
        convert_input_strings(cconfig, dirs)
        citers = cconfig.get('num_iters', 0)

        from e2eflow.cartgripper.read_tf_records2 import build_tfrecord_input

        conf = {}
        DATA_DIR = (os.environ['VMPC_DATA_DIR'] +
                    '/cartgripper_startgoal_large4step/train')
        conf['data_dir'] = DATA_DIR  # directory containing data files
        conf['skip_frame'] = 1
        conf['train_val_split'] = 0.95
        conf['sequence_length'] = 4  # 48  # sequence length, including context frames
        conf['batch_size'] = experiment.config['run']['batch_size']
        conf['context_frames'] = 2
        conf['image_only'] = ''
        conf['orig_size'] = [480, 640]
        conf['visualize'] = False

        # global_step_ = tf.placeholder(tf.int32, name="global_step")
        # train_im = sel_images(train_image, global_step_, citers, 4)
        # val_im = sel_images(val_image, global_step_, citers, 4)

        def make_train(iter_offset):
            train_image = build_tfrecord_input(conf, training=True)
            use_size = tf.constant([384, 512])
            im0 = tf.image.resize_images(train_image[:, 0],
                                         use_size,
                                         method=tf.image.ResizeMethod.BILINEAR)
            im1 = tf.image.resize_images(train_image[:, 1],
                                         use_size,
                                         method=tf.image.ResizeMethod.BILINEAR)
            return [im0, im1]

        def make_val(iter_offset):
            val_image = build_tfrecord_input(conf, training=False)
            use_size = tf.constant([384, 512])
            # Resize per frame: val_image is [batch, time, H, W, C], and
            # tf.image.resize_images only accepts 3-D or 4-D tensors.
            im0 = tf.image.resize_images(val_image[:, 0],
                                         use_size,
                                         method=tf.image.ResizeMethod.BILINEAR)
            im1 = tf.image.resize_images(val_image[:, 1],
                                         use_size,
                                         method=tf.image.ResizeMethod.BILINEAR)
            return [im0, im1]

        tr = Trainer(
            make_train,
            make_val,
            params=cconfig,
            normalization=[np.array([0., 0., 0.], dtype=np.float32),
                           1.],  # TODO: try with normalization
            train_summaries_dir=experiment.train_dir,
            eval_summaries_dir=experiment.eval_dir,
            experiment=FLAGS.ex,
            ckpt_dir=experiment.save_dir,
            debug=FLAGS.debug,
            interactive_plot=run_config.get('interactive_plot'),
            devices=devices)
        tr.run(0, citers)

    else:
        raise ValueError("Invalid dataset. Dataset must be one of "
                         "{synthia, kitti, kitti_ft, cityscapes, chairs, "
                         "cartgripper}")

    if not FLAGS.debug:
        experiment.conclude()
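
The cartgripper branch is the one novel piece of this variant: make_train and make_val pull a short clip from the tfrecord reader and bilinearly resize frames 0 and 1 to the flow network's input resolution. A minimal sketch of that pattern under TF1 semantics (the [batch, time, height, width, channels] clip layout is an assumption inferred from the indexing above):

import tensorflow as tf

def first_two_frames_resized(clip, height=384, width=512):
    # clip: float tensor [batch, time, H, W, C]; resize each of the
    # first two frames independently, since resize_images is 4-D only.
    use_size = tf.constant([height, width])
    im0 = tf.image.resize_images(clip[:, 0], use_size,
                                 method=tf.image.ResizeMethod.BILINEAR)
    im1 = tf.image.resize_images(clip[:, 1], use_size,
                                 method=tf.image.ResizeMethod.BILINEAR)
    return [im0, im1]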
Example 3
def _evaluate_experiment(name, input_fn, data_input):
    normalize_fn = data_input._normalize_image
    resized_h = data_input.dims[0]
    resized_w = data_input.dims[1]

    current_config = config_dict('../config.ini')
    exp_dir = os.path.join(current_config['dirs']['log'], 'ex', name)
    config_path = os.path.join(exp_dir, 'config.ini')
    if not os.path.isfile(config_path):
        config_path = '../config.ini'
    if not os.path.isdir(exp_dir) or not tf.train.get_checkpoint_state(
            exp_dir):
        exp_dir = os.path.join(current_config['dirs']['checkpoints'], name)
    config = config_dict(config_path)
    params = config['train']
    convert_input_strings(params, config_dict('../config.ini')['dirs'])
    dataset_params_name = 'train_' + FLAGS.dataset
    if dataset_params_name in config:
        params.update(config[dataset_params_name])
    ckpt = tf.train.get_checkpoint_state(exp_dir)
    if not ckpt:
        raise RuntimeError("Error: experiment must contain a checkpoint")
    ckpt_path = exp_dir + "/" + os.path.basename(ckpt.model_checkpoint_path)

    with tf.Graph().as_default():  #, tf.device('gpu:' + FLAGS.gpu):
        inputs = input_fn()
        im1, im2, input_shape = inputs[:3]
        truth = inputs[3:]
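        # Whatever follows the two images and the shape is ground truth:
        # 4 tensors (KITTI occ/noc flow + validity masks), 2 (one flow
        # field + mask), or 0 (unlabeled data, visualization only).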

        height, width, _ = tf.unstack(tf.squeeze(input_shape), num=3, axis=0)
        im1 = resize_input(im1, height, width, resized_h, resized_w)
        im2 = resize_input(im2, height, width, resized_h,
                           resized_w)  # TODO adapt train.py

        _, flow, flow_bw = unsupervised_loss(
            (im1, im2),
            normalization=data_input.get_normalization(),
            params=params,
            augment=False,
            return_flow=True)

        im1 = resize_output(im1, height, width, 3)
        im2 = resize_output(im2, height, width, 3)
        flow = resize_output_flow(flow, height, width, 2)
        flow_bw = resize_output_flow(flow_bw, height, width, 2)

        flow_fw_int16 = flow_to_int16(flow)
        flow_bw_int16 = flow_to_int16(flow_bw)

        im1_pred = image_warp(im2, flow)
        im1_diff = tf.abs(im1 - im1_pred)
        #im2_diff = tf.abs(im1 - im2)

        #flow_bw_warped = image_warp(flow_bw, flow)

        if len(truth) == 4:
            flow_occ, mask_occ, flow_noc, mask_noc = truth
            flow_occ = resize_output_crop(flow_occ, height, width, 2)
            flow_noc = resize_output_crop(flow_noc, height, width, 2)
            mask_occ = resize_output_crop(mask_occ, height, width, 1)
            mask_noc = resize_output_crop(mask_noc, height, width, 1)

            #div = divergence(flow_occ)
            #div_bw = divergence(flow_bw)
            occ_pred, def_pred = occlusion(flow, flow_bw)
            disocc_pred = forward_warp(flow_bw) < DISOCC_THRESH
            disocc_fw_pred = forward_warp(flow) < DISOCC_THRESH
            image_slots = [
                ((im1 * 0.5 + im2 * 0.5) / 255, 'overlay'),
                (im1_diff / 255, 'brightness error'),
                #(im1 / 255, 'first image', 1, 0),
                #(im2 / 255, 'second image', 1, 0),
                #(im2_diff / 255, '|first - second|', 1, 2),
                (flow_to_color(flow), 'flow'),
                #(flow_to_color(flow_bw), 'flow bw prediction'),
                #(tf.image.rgb_to_grayscale(im1_diff) > 20, 'diff'),
                #(occ_pred, 'occ'),
                #(def_pred, 'disocc'),
                #(disocc_pred, 'reverse disocc'),
                #(disocc_fw_pred, 'forward disocc prediction'),
                #(div, 'div'),
                #(div < -2, 'neg div'),
                #(div > 5, 'pos div'),
                #(flow_to_color(flow_occ, mask_occ), 'flow truth'),
                (flow_error_image(flow, flow_occ, mask_occ, mask_noc),
                 'flow error')  #  (blue: correct, red: wrong, dark: occluded)
            ]

            # list of (scalar_op, title)
            scalar_slots = [
                (flow_error_avg(flow_noc, flow, mask_noc), 'EPE_noc'),
                (flow_error_avg(flow_occ, flow, mask_occ), 'EPE_all'),
                (outlier_pct(flow_noc, flow, mask_noc), 'outliers_noc'),
                (outlier_pct(flow_occ, flow, mask_occ), 'outliers_all')
            ]
        elif len(truth) == 2:
            flow_gt, mask = truth
            flow_gt = resize_output_crop(flow_gt, height, width, 2)
            mask = resize_output_crop(mask, height, width, 1)

            image_slots = [
                ((im1 * 0.5 + im2 * 0.5) / 255, 'overlay'),
                (im1_diff / 255, 'brightness error'),
                (flow_to_color(flow), 'flow'),
                (flow_to_color(flow_gt, mask), 'gt'),
            ]

            # list of (scalar_op, title)
            scalar_slots = [(flow_error_avg(flow_gt, flow, mask), 'EPE_all')]
        else:
            image_slots = [
                (im1 / 255, 'first image'),
                #(im1_pred / 255, 'warped second image', 0, 1),
                (im1_diff / 255, 'warp error'),
                #(im2 / 255, 'second image', 1, 0),
                #(im2_diff / 255, '|first - second|', 1, 2),
                (flow_to_color(flow), 'flow prediction')
            ]
            scalar_slots = []

        num_ims = len(image_slots)
        image_ops = [t[0] for t in image_slots]
        scalar_ops = [t[0] for t in scalar_slots]
        image_names = [t[1] for t in image_slots]
        scalar_names = [t[1] for t in scalar_slots]
        all_ops = image_ops + scalar_ops

        image_lists = []
        averages = np.zeros(len(scalar_ops))
        sess_config = tf.ConfigProto(allow_soft_placement=True)

        exp_out_dir = os.path.join('../out', name)
        if FLAGS.output_visual or FLAGS.output_benchmark:
            if os.path.isdir(exp_out_dir):
                shutil.rmtree(exp_out_dir)
            os.makedirs(exp_out_dir)
            shutil.copyfile(config_path, os.path.join(exp_out_dir,
                                                      'config.ini'))

        with tf.Session(config=sess_config) as sess:
            saver = tf.train.Saver(tf.global_variables())
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())

            restore_networks(sess, params, ckpt, ckpt_path)

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            # TODO adjust for batch_size > 1 (also need to change image_lists appending)
            max_iter = FLAGS.num if FLAGS.num > 0 else None

            try:
                num_iters = 0
                while not coord.should_stop() and (max_iter is None
                                                   or num_iters != max_iter):
                    all_results = sess.run(
                        [flow, flow_bw, flow_fw_int16, flow_bw_int16] +
                        all_ops)
                    flow_fw_res, flow_bw_res, flow_fw_int16_res, \
                        flow_bw_int16_res = all_results[:4]
                    all_results = all_results[4:]
                    image_results = all_results[:num_ims]
                    scalar_results = all_results[num_ims:]
                    iterstr = str(num_iters).zfill(6)
                    if FLAGS.output_visual:
                        path_col = os.path.join(exp_out_dir,
                                                iterstr + '_flow.png')
                        path_overlay = os.path.join(exp_out_dir,
                                                    iterstr + '_img.png')
                        path_error = os.path.join(exp_out_dir,
                                                  iterstr + '_err.png')
                        write_rgb_png(image_results[0] * 255, path_overlay)
                        write_rgb_png(image_results[1] * 255, path_col)
                        write_rgb_png(image_results[2] * 255, path_error)
                    if FLAGS.output_benchmark:
                        path_fw = os.path.join(exp_out_dir, iterstr)
                        if FLAGS.output_png:
                            write_rgb_png(flow_fw_int16_res,
                                          path_fw + '_10.png',
                                          bitdepth=16)
                        else:
                            write_flo(flow_fw_res, path_fw + '_10.flo')
                        if FLAGS.output_backward:
                            path_bw = os.path.join(exp_out_dir,
                                                   iterstr + '_01.png')
                            write_rgb_png(flow_bw_int16_res,
                                          path_bw,
                                          bitdepth=16)
                    if num_iters < FLAGS.num_vis:
                        image_lists.append(image_results)
                    averages += scalar_results
                    if num_iters > 0:
                        sys.stdout.write('\r')
                    num_iters += 1
                    sys.stdout.write("-- evaluating '{}': {}/{}".format(
                        name, num_iters, max_iter))
                    sys.stdout.flush()
            except tf.errors.OutOfRangeError:
                pass
            print()

            averages /= num_iters

            coord.request_stop()
            coord.join(threads)

    for t, avg in zip(scalar_slots, averages):
        _, scalar_name = t
        print("({}) {} = {}".format(name, scalar_name, avg))

    return image_lists, image_names
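
For reference, the scalar slots gathered above are the standard KITTI flow measures: average endpoint error (EPE) over valid pixels and the percentage of outlier pixels. A minimal NumPy sketch of what flow_error_avg and outlier_pct compute (the 3 px / 5% outlier threshold is the KITTI convention and an assumption about this codebase):

import numpy as np

def flow_metrics(flow_gt, flow, mask):
    # flow_gt, flow: [H, W, 2]; mask: [H, W, 1], 1 where gt is valid.
    err = np.linalg.norm(flow - flow_gt, axis=-1)  # per-pixel endpoint error
    valid = mask[..., 0] > 0
    epe = err[valid].mean()
    mag = np.linalg.norm(flow_gt, axis=-1)
    # KITTI outlier: EPE > 3 px and > 5% of the gt flow magnitude.
    outlier_pct = 100.0 * ((err > 3.0) & (err > 0.05 * mag))[valid].mean()
    return epe, outlier_pct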
Example 4
def main(argv=None):
    experiment = Experiment(name=FLAGS.ex, overwrite=FLAGS.ow)
    dirs = experiment.config['dirs']
    run_config = experiment.config['run']

    gpu_list_param = run_config['gpu_list']

    if isinstance(gpu_list_param, int):
        gpu_list = [gpu_list_param]
        os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_list_param)
    else:
        gpu_list = list(range(len(gpu_list_param.split(','))))
        os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list_param
    gpu_batch_size = int(run_config['batch_size'] / max(len(gpu_list), 1))
    devices = ['/gpu:' + str(gpu_num) for gpu_num in gpu_list]

    train_dataset = run_config.get('dataset', 'kitti')

    print("train_dataset:", train_dataset)

    kdata = KITTIData(data_dir=dirs['data'],
                      fast_dir=dirs.get('fast'),
                      stat_log_dir=None,
                      development=run_config['development'])
    einput = KITTIInput(data=kdata,
                        batch_size=1,
                        normalize=False,
                        dims=(384, 1280))

    if train_dataset == 'chairs':
        cconfig = copy.deepcopy(experiment.config['train'])
        cconfig.update(experiment.config['train_chairs'])
        convert_input_strings(cconfig, dirs)
        citers = cconfig.get('num_iters', 0)
        cdata = ChairsData(data_dir=dirs['data'],
                           fast_dir=dirs.get('fast'),
                           stat_log_dir=None,
                           development=run_config['development'])
        cinput = ChairsInput(data=cdata,
                             batch_size=gpu_batch_size,
                             normalize=False,
                             dims=(cconfig['height'], cconfig['width']))
        tr = Trainer(lambda shift: cinput.input_raw(
            swap_images=False, shift=shift * run_config['batch_size']),
                     lambda: einput.input_train_2012(),
                     params=cconfig,
                     normalization=cinput.get_normalization(),
                     train_summaries_dir=experiment.train_dir,
                     eval_summaries_dir=experiment.eval_dir,
                     experiment=FLAGS.ex,
                     ckpt_dir=experiment.save_dir,
                     debug=FLAGS.debug,
                     interactive_plot=run_config.get('interactive_plot'),
                     devices=devices)
        tr.run(0, citers)

    elif train_dataset == 'kitti':
        kconfig = copy.deepcopy(experiment.config['train'])
        kconfig.update(experiment.config['train_kitti'])
        convert_input_strings(kconfig, dirs)
        kiters = kconfig.get('num_iters', 0)
        kinput = KITTIInput(data=kdata,
                            batch_size=gpu_batch_size,
                            normalize=False,
                            skipped_frames=True,
                            dims=(kconfig['height'], kconfig['width']))
        tr = Trainer(lambda shift: kinput.input_raw(
            swap_images=False,
            center_crop=True,
            shift=shift * run_config['batch_size']),
                     lambda: einput.input_train_2012(),
                     params=kconfig,
                     normalization=kinput.get_normalization(),
                     train_summaries_dir=experiment.train_dir,
                     eval_summaries_dir=experiment.eval_dir,
                     experiment=FLAGS.ex,
                     ckpt_dir=experiment.save_dir,
                     debug=FLAGS.debug,
                     interactive_plot=run_config.get('interactive_plot'),
                     devices=devices)
        tr.run(0, kiters)

    elif train_dataset == 'cityscapes':
        kconfig = copy.deepcopy(experiment.config['train'])
        kconfig.update(experiment.config['train_cityscapes'])
        convert_input_strings(kconfig, dirs)
        kiters = kconfig.get('num_iters', 0)
        cdata = CityscapesData(data_dir=dirs['data'],
                               fast_dir=dirs.get('fast'),
                               stat_log_dir=None,
                               development=run_config['development'])
        kinput = KITTIInput(data=cdata,
                            batch_size=gpu_batch_size,
                            normalize=False,
                            skipped_frames=False,
                            dims=(kconfig['height'], kconfig['width']))
        tr = Trainer(lambda shift: kinput.input_raw(
            swap_images=False,
            center_crop=True,
            skip=[0, 1],
            shift=shift * run_config['batch_size']),
                     lambda: einput.input_train_2012(),
                     params=kconfig,
                     normalization=kinput.get_normalization(),
                     train_summaries_dir=experiment.train_dir,
                     eval_summaries_dir=experiment.eval_dir,
                     experiment=FLAGS.ex,
                     ckpt_dir=experiment.save_dir,
                     debug=FLAGS.debug,
                     interactive_plot=run_config.get('interactive_plot'),
                     devices=devices)
        tr.run(0, kiters)

    elif train_dataset == 'synthia':
        sconfig = copy.deepcopy(experiment.config['train'])
        sconfig.update(experiment.config['train_synthia'])
        convert_input_strings(sconfig, dirs)
        siters = sconfig.get('num_iters', 0)
        sdata = SynthiaData(data_dir=dirs['data'],
                            fast_dir=dirs.get('fast'),
                            stat_log_dir=None,
                            development=run_config['development'])
        sinput = KITTIInput(data=sdata,
                            batch_size=gpu_batch_size,
                            normalize=False,
                            dims=(sconfig['height'], sconfig['width']))
        tr = Trainer(lambda shift: sinput.input_raw(
            swap_images=False, shift=shift * run_config['batch_size']),
                     lambda: einput.input_train_2012(),
                     params=sconfig,
                     normalization=sinput.get_normalization(),
                     train_summaries_dir=experiment.train_dir,
                     eval_summaries_dir=experiment.eval_dir,
                     experiment=FLAGS.ex,
                     ckpt_dir=experiment.save_dir,
                     debug=FLAGS.debug,
                     interactive_plot=run_config.get('interactive_plot'),
                     devices=devices)
        tr.run(0, siters)

    elif train_dataset == 'nao':
        # copied and adjusted from the synthia branch
        nconfig = copy.deepcopy(experiment.config['train'])
        nconfig.update(experiment.config['train_nao'])
        convert_input_strings(nconfig, dirs)
        niters = nconfig.get('num_iters', 0)
        ndata = NaoData(dirs['data'], development=False)
        ninput = NaoInput(ndata,
                          batch_size=1,
                          normalize=False,
                          dir_name='grey400',
                          dims=(192, 256))
        tr = Trainer(
            lambda shift: ninput.input_consecutive(return_shape=False),
            # lambda: einput.input_train_2012(),    # todo: is this appropriate for nao data? what does it do?
            lambda: None,
            params=nconfig,
            normalization=ninput.get_normalization(),
            train_summaries_dir=experiment.train_dir,
            eval_summaries_dir=experiment.eval_dir,
            experiment=FLAGS.ex,
            ckpt_dir=experiment.save_dir,
            debug=FLAGS.debug,
            interactive_plot=run_config.get('interactive_plot'),
            devices=devices)
        tr.run(0, niters)

    elif train_dataset == 'kitti_ft':
        ftconfig = copy.deepcopy(experiment.config['train'])
        ftconfig.update(experiment.config['train_kitti_ft'])
        convert_input_strings(ftconfig, dirs)
        ftiters = ftconfig.get('num_iters', 0)
        ftinput = KITTIInput(data=kdata,
                             batch_size=gpu_batch_size,
                             normalize=False,
                             dims=(ftconfig['height'], ftconfig['width']))
        tr = Trainer(lambda shift: ftinput.input_train_gt(40),
                     lambda: einput.input_train_2015(40),
                     supervised=True,
                     params=ftconfig,
                     normalization=ftinput.get_normalization(),
                     train_summaries_dir=experiment.train_dir,
                     eval_summaries_dir=experiment.eval_dir,
                     experiment=FLAGS.ex,
                     ckpt_dir=experiment.save_dir,
                     debug=FLAGS.debug,
                     interactive_plot=run_config.get('interactive_plot'),
                     devices=devices)
        tr.run(0, ftiters)

    else:
        raise ValueError("Invalid dataset. Dataset must be one of "
                         "{synthia, kitti, kitti_ft, cityscapes, chairs, nao}")

    if not FLAGS.debug:
        experiment.conclude()

    print('done with training at:')
    print(str(datetime.now()))
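
All four mains read FLAGS.ex, FLAGS.ow and FLAGS.debug without defining them; in a TF1 codebase like this they are typically declared once with the flags module and dispatched via tf.app.run(). A hedged sketch of that wiring (the help strings and defaults are assumptions, not from the source):

import tensorflow as tf

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('ex', 'default', 'Name of the experiment.')
tf.app.flags.DEFINE_boolean('ow', False, 'Overwrite an existing experiment.')
tf.app.flags.DEFINE_boolean('debug', False, 'Run in debug mode.')

if __name__ == '__main__':
    tf.app.run()  # parses flags, then calls main(argv)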