Ejemplo n.º 1
0
    def test_generate_batch(self):
        """Time a single validation batch and report what it contains.

        A ``KittiGenerator`` is built from the fixture's ``dataset_path``,
        ``depth_base`` and ``depth_step``; one validation batch is requested
        with ``multipatch=False`` and one patch set per image. The method
        prints the generation time, the shape of ``batch['target']`` and the
        summed ``nbytes`` of the arrays in ``batch['planes']`` (a rough
        approximation of GPU memory use), then passes the batch to
        ``reprojected_images_plot``. No assertion is made.
        """
        kitti = KittiGenerator(self.dataset_path, self.depth_base,
                               self.depth_step)

        patches_per_image = 1
        tic = time.time()
        batch = kitti.validation_batch(multipatch=False,
                                       num_set_same_img=patches_per_image)
        elapsed = time.time() - tic
        print("PSV Time for batch with %s patches per image: (%.3f sec)"
              % (patches_per_image, elapsed))

        target_shape = str(batch['target'].shape)
        print("Batch shape is -> %s" % target_shape)

        # Total bytes held by every array in the 'planes' mapping.
        planes = batch['planes']
        total_bytes = sum(planes[name].nbytes for name in planes)
        print("Batch size: %s bytes" % total_bytes)

        # NOTE: only the plot is produced. Writing images to ``out_dir`` is
        # disabled (``save_batch_images(batch, out_dir)`` is not called),
        # despite the "saving" messages printed around the plot.
        out_dir = "/Users/boyander/MASTER_CVC/DepthEstimation/DepthEstimation/batch_test"
        print("saving batch: %s" % out_dir)
        reprojected_images_plot(batch)
        print("Done saving batch!")
Ejemplo n.º 2
0
def worker_extraction(params, out_queue):
    """Endlessly extract KITTI batches and push them onto ``out_queue``.

    Args:
        params: object exposing ``kitti_path``, ``base_depth``,
            ``depth_step`` and ``patches_per_set``.
        out_queue: any object with a ``put`` method; each produced batch is
            passed to it.

    The function never returns. Any ``Exception`` raised while extracting
    or enqueueing a batch is printed with its traceback and the loop
    carries on, so a persistent failure will repeat indefinitely.
    """
    # A negative value asks OpenCV to use its default thread count
    # (see cv2.setNumThreads documentation).
    cv2.setNumThreads(-1)
    kitti = KittiGenerator(params.kitti_path,
                           params.base_depth,
                           params.depth_step)
    while True:
        try:
            out_queue.put(
                kitti.next_batch(num_set_same_img=params.patches_per_set))
        except Exception as err:
            print("Extraction Worker Exception: %s" % err)
            traceback.print_exc()
Ejemplo n.º 3
0
    def test_meanzero(self):
        """Compare target means produced with and without ``meanzero``.

        Two ``InputOrganizer`` instances that differ only in the ``meanzero``
        flag are fed the same multipatch batch from a ``KittiGenerator``.
        The method prints:

        - the mean of the target entry in each resulting feed dict;
        - the difference of their absolute values.

        No assertion is made; the output is informational.
        """
        input_organizer_zero = InputOrganizer(batch_size=self.batch_size,
                                              meanzero=True)
        input_organizer_normal = InputOrganizer(batch_size=self.batch_size,
                                                meanzero=False)

        generator = KittiGenerator(self.dataset_path, self.depth_base,
                                   self.depth_step)
        batch = generator.next_batch(multipatch=True)

        b_zero = input_organizer_zero.get_feed_dict([batch])
        b_normal = input_organizer_normal.get_feed_dict([batch])

        mean_zero = np.mean(
            b_zero[input_organizer_zero.get_target_placeholder().name])
        mean_normal = np.mean(
            b_normal[input_organizer_normal.get_target_placeholder().name])
        print("Mean on target is -> (%s for organizer meanzero)"
              " (%s for organizer normal)" % (mean_zero, mean_normal))

        # Compare the meanzero organizer against the normal one. Subtracting
        # abs(mean_zero) from itself would always yield 0 and say nothing.
        sub_zero = np.subtract(np.abs(mean_zero), np.abs(mean_normal))

        print("Zero substraction is %s" % (sub_zero))
Ejemplo n.º 4
0
def run_training(FLAGS):
    """Train the network built by ``inference`` on KITTI batches.

    Builds the graph (placeholders, inference, loss, train op, summaries)
    and opens a ``tf.Session`` with on-demand GPU memory growth. It then
    runs ``FLAGS.max_steps`` training steps on batches from a
    ``GeneratorQueued``. Every ``FLAGS.print_step`` steps it prints the
    loss and timings and writes summaries. Every ``FLAGS.validate_step``
    steps, and on the final step, it saves a checkpoint and runs
    ``do_eval`` on a fixed validation batch.

    Args:
        FLAGS: configuration object. Attributes read:

            - model: ``batch_size``, ``mean_zero``, ``num_planes``,
              ``learning_rate``;
            - output: ``traindir``;
            - data: ``kitti_path``, ``depth_base``, ``depth_step``,
              ``patches_per_set``, ``extraction_workers``,
              ``aggregation_workers``;
            - loop control: ``max_steps``, ``print_step``,
              ``validate_step``.

    Summaries and checkpoints go to a new directory under
    ``FLAGS.traindir``, named ``<timestamp>-<hostname>-<git hash>``.

    Any ``Exception`` raised during setup or training is printed with its
    traceback and not re-raised. The session and summary writer are closed
    whether training completes or fails.
    """
    sess = None
    summary_writer = None
    try:
        # Build the model into a fresh graph rather than the global default.
        with tf.Graph().as_default():
            # Placeholders for the images from all cameras.
            input_organizer = InputOrganizer(batch_size=FLAGS.batch_size,
                                             meanzero=FLAGS.mean_zero,
                                             num_planes=FLAGS.num_planes)

            # Ops that compute predictions from the inference model.
            net_out = inference(input_organizer,
                                num_planes=FLAGS.num_planes,
                                batch_size=FLAGS.batch_size)

            print("Graph built! continuing...")
            loss = lossF(net_out, input_organizer.get_target_placeholder())

            print("Learning rate is: %s" % FLAGS.learning_rate)
            train_op = training(loss, FLAGS.learning_rate)

            print("Merging summaries continuing...")
            # merge_all() returns None when no summary ops were registered.
            summary_op = tf.summary.merge_all()

            print("Initialize variables...")
            # global_variables_initializer supersedes the deprecated
            # initialize_all_variables.
            init = tf.global_variables_initializer()

            print("Create a saver for writing training checkpoints...")
            saver = tf.train.Saver()

            print("Starting session...")
            config = tf.ConfigProto()
            # Grow GPU memory on demand instead of reserving it all upfront.
            config.gpu_options.allow_growth = True
            sess = tf.Session(config=config)

            print("Creating SummaryWritter...")
            git_hash = get_git_revision_short_hash()
            summary_name = datetime.now().strftime("%Y_%B_%d_%H_%M_%S")
            summary_name = "%s-%s-%s" % (summary_name, socket.gethostname(),
                                         git_hash)
            summary_dir = os.path.join(FLAGS.traindir, summary_name)

            # FileWriter creates summary_dir if it does not already exist.
            summary_writer = tf.summary.FileWriter(summary_dir, sess.graph)
            print("Started SummaryWriter -> %s" % summary_dir)

            sess.run(init)

            print("Starting multiprocessing queue generator...")
            # IMPORTANT: this defines the generator used for training batches.
            kitti_params = KittiParams(FLAGS.kitti_path, FLAGS.depth_base,
                                       FLAGS.depth_step, FLAGS.patches_per_set)
            # NOTE(review): the generator's worker processes are not stopped
            # explicitly here; confirm whether GeneratorQueued exposes a
            # shutdown method that should be called in the finally block.
            generator = GeneratorQueued(
                kitti_params,
                input_organizer,
                batch_size=FLAGS.batch_size,
                extraction_workers=FLAGS.extraction_workers,
                aggregation_workers=FLAGS.aggregation_workers)

            print("Reading validation batch...")
            validation_gene = KittiGenerator(FLAGS.kitti_path,
                                             FLAGS.depth_base,
                                             FLAGS.depth_step)
            validation_batch = input_organizer.get_feed_dict([
                validation_gene.validation_batch(
                    num_set_same_img=FLAGS.batch_size)
            ])
            # Only the feed dict is needed from here on.
            del validation_gene
            print("Done reading validation Batch!!")

            print(
                "Done! Start training loop, validate and save every (%s steps)..."
                % FLAGS.validate_step)
            max_steps = FLAGS.max_steps
            for step in range(max_steps):

                # Fetch the next training feed dict (time spent waiting).
                start_time1 = time.time()
                feed_dict = generator.get_batch()
                duration_images = time.time() - start_time1

                # One optimisation step; the train_op result is discarded.
                start_time2 = time.time()
                _, loss_value = sess.run([train_op, loss], feed_dict=feed_dict)
                duration_net = time.time() - start_time2

                print('=== Step %d ===' % step)
                if step % FLAGS.print_step == 0:
                    print(
                        '=== Step %d: loss = %.2f -> images:(%.3f sec), net:(%.3f sec) ==='
                        % (step, loss_value, duration_images, duration_net))
                    # Running a None op would raise, so only write summaries
                    # when some exist.
                    if summary_op is not None:
                        summary_str = sess.run(summary_op, feed_dict=feed_dict)
                        summary_writer.add_summary(summary_str, step)
                        summary_writer.flush()

                # Checkpoint and evaluate every validate_step steps and on
                # the final step.
                if (step + 1) % FLAGS.validate_step == 0 or (step +
                                                             1) == max_steps:
                    print("(Step: %s) Checkpoint, saving model." % step)
                    checkpoint_file = os.path.join(summary_dir, 'checkpoint')
                    saver.save(sess, checkpoint_file, global_step=step)
                    do_eval(sess, validation_batch, net_out,
                            input_organizer.get_target_placeholder())
    except Exception as e:
        print("Exception on TRAIN: %s" % e)
        traceback.print_exc()
    finally:
        # Runs on success as well as failure so the event file and the
        # session are always released.
        if summary_writer is not None:
            summary_writer.close()
        if sess is not None:
            sess.close()