Example #1
def main(_):
    """Train FlowNet for a FLAGS.max_steps."""

    with tf.Graph().as_default():

        imgs_0, imgs_1, flows = flownet_tools.get_data(FLAGS.datadir, True)

        # img summary after loading
        flownet.image_summary(imgs_0, imgs_1, "A_input", flows)

        # apply augmentation
        imgs_0, imgs_1, flows = apply_augmentation(imgs_0, imgs_1, flows)

        # model
        calc_flows = architectures.flownet_dropout(imgs_0, imgs_1, flows)

        # img summary of result
        flownet.image_summary(None, None, "E_result", calc_flows)

        global_step = slim.get_or_create_global_step()

        train_op = flownet.create_train_op(global_step)
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True

        saver = tf_saver.Saver(
            max_to_keep=FLAGS.max_checkpoints,
            keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours)

        slim.learning.train(
            train_op,
            logdir=FLAGS.logdir + '/train',
            save_summaries_secs=FLAGS.save_summaries_secs,
            save_interval_secs=FLAGS.save_interval_secs,
            summary_op=tf.summary.merge_all(),
            log_every_n_steps=FLAGS.log_every_n_steps,
            trace_every_n_steps=FLAGS.trace_every_n_steps,
            session_config=config,
            saver=saver,
            number_of_steps=FLAGS.max_steps,
        )
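
This example (and Example #2 below) assumes module-level imports and command-line flags that are not shown. The following is a minimal sketch of what those definitions could look like under TF 1.x with tf.contrib.slim; the flag names are taken from the code above, but the default values and help strings are placeholders, and flownet, flownet_tools and architectures are the project's own modules:

import math

import tensorflow as tf
from tensorflow.contrib import slim
from tensorflow.python.training import saver as tf_saver

import architectures
import flownet
import flownet_tools

FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_string('datadir', 'data/', 'Directory with input images and flows.')
tf.app.flags.DEFINE_string('logdir', 'log/', 'Directory for checkpoints and summaries.')
tf.app.flags.DEFINE_integer('max_steps', 500000, 'Number of training steps.')
tf.app.flags.DEFINE_integer('max_checkpoints', 5, 'Maximum number of checkpoints to keep.')
tf.app.flags.DEFINE_float('keep_checkpoint_every_n_hours', 5.0, 'Keep one checkpoint every n hours.')
tf.app.flags.DEFINE_integer('save_summaries_secs', 60, 'Seconds between summary writes.')
tf.app.flags.DEFINE_integer('save_interval_secs', 300, 'Seconds between checkpoint saves.')
tf.app.flags.DEFINE_integer('log_every_n_steps', 100, 'Logging frequency in steps.')
tf.app.flags.DEFINE_integer('trace_every_n_steps', 1000, 'Tracing frequency in steps.')
# Example #2 additionally uses FLAGS.batchsize, FLAGS.testsize, FLAGS.master
# and FLAGS.eval_interval_secs, defined in the same way.

if __name__ == '__main__':
    tf.app.run()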
Example #2
def main(_):
    """Evaluate FlowNet for FlyingChair test set"""

    with tf.Graph().as_default():
        # Generate tensors from numpy images and flows.
        var_num = 1
        img_0, img_1, flow = flownet_tools.get_data_flow_s(
            FLAGS.datadir, False, var_num)

        imgs_0 = tf.squeeze(tf.stack([img_0 for i in range(FLAGS.batchsize)]))
        imgs_1 = tf.squeeze(tf.stack([img_1 for i in range(FLAGS.batchsize)]))
        flows = tf.squeeze(tf.stack([flow for i in range(FLAGS.batchsize)]))
        # img summary after loading
        flownet.image_summary(imgs_0, imgs_1, "A_input", flows)

        # Get flow tensor from flownet model
        calc_flows = architectures.flownet_dropout(imgs_0, imgs_1, flows)

        flow_mean, confidence, conf_img = var_mean(calc_flows)

        # confidence = tf.image.convert_image_dtype(confidence, tf.uint16)
        # compute EPE / AEE = sqrt((x1-x2)^2 + (y1-y2)^2)
        # https://www.ncbi.nlm.nih.gov/pmc/articles/PMC3478865/

        aee = aee_f(flow, flow_mean, var_num)
        # bilateral solver
        img_0 = tf.squeeze(tf.stack(img_0))
        flow_s = tf.squeeze(tf.stack(flow))
        solved_flow = flownet.bil_solv_var(img_0, flow_mean, confidence,
                                           flow_s)
        aee_bs = aee_f(flow, solved_flow, var_num)

        metrics_to_values, metrics_to_updates = slim.metrics.aggregate_metric_map(
            {
                "AEE": slim.metrics.streaming_mean(aee),
                "AEE_BS": slim.metrics.streaming_mean(aee_bs),
                # "AEE_BS_No_Confidence": slim.metrics.streaming_mean(aee_bs),
            })

        for name, value in metrics_to_values.items():
            tf.summary.scalar(name, value)
        # Define the summaries to write:
        flownet.image_summary(None, None, "FlowNetS_no_mean", calc_flows)
        solved_flows = tf.squeeze(
            tf.stack([solved_flow for i in range(FLAGS.batchsize)]))
        flow_means = tf.squeeze(
            tf.stack([flow_mean for i in range(FLAGS.batchsize)]))
        conf_imgs = tf.squeeze(
            tf.stack([conf_img for i in range(FLAGS.batchsize)]))
        flownet.image_summary(None, None, "FlowNetS BS", solved_flows)
        flownet.image_summary(None, None, "FlowNetS Mean", flow_means)
        flownet.image_summary(conf_imgs, conf_imgs, "Confidence", None)
        # Run the actual evaluation loop.
        num_batches = math.ceil(FLAGS.testsize)

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True

        slim.evaluation.evaluation_loop(
            master=FLAGS.master,
            checkpoint_dir=FLAGS.logdir + '/train',
            logdir=FLAGS.logdir + '/eval_var_flownet_s',
            num_evals=num_batches,
            eval_op=metrics_to_updates.values(),
            eval_interval_secs=FLAGS.eval_interval_secs,
            summary_op=tf.summary.merge_all(),
            session_config=config,
            timeout=60 * 60)
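
The helpers var_mean and aee_f come from the surrounding project and are not shown in these examples. Going by the in-code comment (EPE / AEE = sqrt((x1-x2)^2 + (y1-y2)^2)), the following is a minimal sketch of what aee_f might look like; the signature, the unused var_num argument, and the assumed tensor shapes are guesses, not the project's actual implementation:

def aee_f(gt_flow, calc_flow, var_num):
    """Average endpoint error between ground-truth and estimated flow.

    Sketch only: assumes both tensors carry the flow's u/v components in
    their last dimension, e.g. shape [height, width, 2].
    """
    # per-pixel endpoint error: sqrt((u1 - u2)^2 + (v1 - v2)^2)
    diff = gt_flow - calc_flow
    epe = tf.sqrt(tf.reduce_sum(tf.square(diff), axis=-1))
    # average over all pixels (and, if present, the batch dimension)
    return tf.reduce_mean(epe)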