Example #1
    def testOutOfRangeError(self):
        with self.test_session():
            [tfrecord_path] = test_utils.create_tfrecord_files(
                self.get_temp_dir(), num_files=1)

        key, value = parallel_reader.single_pass_read(
            tfrecord_path, reader_class=io_ops.TFRecordReader)
        init_op = variables.local_variables_initializer()

        with self.test_session() as sess:
            sess.run(init_op)
            with queues.QueueRunners(sess):
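                # The single file is expected to hold fewer than num_reads
                # records, so reading in a loop must eventually raise
                # OutOfRangeError.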
                num_reads = 11
                with self.assertRaises(errors_impl.OutOfRangeError):
                    for _ in range(num_reads):
                        sess.run([key, value])
Example #2
    def testTFRecordSeparateGetDataset(self):
        dataset_dir = tempfile.mkdtemp(
            prefix=os.path.join(self.get_temp_dir(), 'tfrecord_separate_get'))

        height = 300
        width = 280

        with self.test_session():
            provider = dataset_data_provider.DatasetDataProvider(
                _create_tfrecord_dataset(dataset_dir))
        [image] = provider.get(['image'])
        [label] = provider.get(['label'])
        image = _resize_image(image, height, width)

        with session.Session('') as sess:
            with queues.QueueRunners(sess):
                image, label = sess.run([image, label])
            self.assertListEqual([height, width, 3], list(image.shape))
            self.assertListEqual([1], list(label.shape))
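Examples #2 and #4 call two helpers, _create_tfrecord_dataset and _resize_image, that are defined elsewhere in the test module and are not shown here. A minimal sketch of what a resize helper like _resize_image might look like, assuming it only resizes the decoded image and pins the static shape that the assertions check, is given below; the real helper may differ.

import tensorflow as tf


def _resize_image(image, height, width):
    # Hypothetical stand-in for the test helper: resize the decoded image to
    # a fixed size and set the static shape expected by the shape assertions.
    image = tf.image.resize_images(image, [height, width])
    image.set_shape([height, width, 3])
    return image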
Example #3
    def testTFRecordReader(self):
        with self.test_session():
            [tfrecord_path] = test_utils.create_tfrecord_files(
                self.get_temp_dir(), num_files=1)

        key, value = parallel_reader.single_pass_read(
            tfrecord_path, reader_class=io_ops.TFRecordReader)
        init_op = variables.local_variables_initializer()

        with self.test_session() as sess:
            sess.run(init_op)
            with queues.QueueRunners(sess):
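                # Every key written by test_utils.create_tfrecord_files is
                # expected to contain 'flowers', so each read should count.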
                flowers = 0
                num_reads = 9
                for _ in range(num_reads):
                    current_key, _ = sess.run([key, value])
                    if 'flowers' in str(current_key):
                        flowers += 1
                self.assertGreater(flowers, 0)
                self.assertEqual(flowers, num_reads)
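parallel_reader.single_pass_read reads the given sources exactly once. A hand-rolled approximation (an assumption about its behavior, not the library's actual implementation) might look like the sketch below; note that a one-epoch filename queue keeps its epoch counter in a local variable, which is why the tests above call variables.local_variables_initializer() first.

import tensorflow as tf


def single_pass_read_sketch(data_source, reader_class):
    # Rough stand-in for parallel_reader.single_pass_read: one epoch,
    # no shuffling, a single reader instance.
    filename_queue = tf.train.string_input_producer(
        [data_source], num_epochs=1, shuffle=False)
    reader = reader_class()
    key, value = reader.read(filename_queue)
    return key, value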
Example #4
    def testTFRecordDataset(self):
        dataset_dir = tempfile.mkdtemp(
            prefix=os.path.join(self.get_temp_dir(), 'tfrecord_dataset'))

        height = 300
        width = 280

        with self.test_session():
            test_dataset = _create_tfrecord_dataset(dataset_dir)
            provider = dataset_data_provider.DatasetDataProvider(test_dataset)
            key, image, label = provider.get(['record_key', 'image', 'label'])
            image = _resize_image(image, height, width)

            with session.Session('') as sess:
                with queues.QueueRunners(sess):
                    key, image, label = sess.run([key, image, label])
            split_key = key.decode('utf-8').split(':')
            self.assertEqual(2, len(split_key))
            self.assertEqual(test_dataset.data_sources[0], split_key[0])
            self.assertTrue(split_key[1].isdigit())
            self.assertListEqual([height, width, 3], list(image.shape))
            self.assertListEqual([1], list(label.shape))
Example #5
def main(argv=()):
    del argv

    batch_size = FLAGS.batch_size_per_gpu * FLAGS.num_gpus

    data_stream_init = utils.setup_data_stream_genome(
        "train",
        batch_size=FLAGS.init_batch_size,
        image_res=FLAGS.image_res,
    )
    (image_init_batch, class_init_batch, box_init_batch) = data_stream_init

    data_stream_train = utils.setup_data_stream_genome(
        "train", batch_size=batch_size, image_res=FLAGS.image_res)
    (image_train_batch, class_train_batch, box_train_batch) = data_stream_train

    data_stream_val = utils.setup_data_stream_genome("val",
                                                     batch_size=batch_size,
                                                     image_res=FLAGS.image_res)
    (image_val_batch, class_val_batch, box_val_batch) = data_stream_val

    def model_template(images, labels, boxes, stage):
        return models.model_detection(images, labels, boxes, stage)

    model_factory = tf.make_template("detection", model_template)
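    # tf.make_template shares variables across calls, so the init, train, and
    # validation graphs built from model_factory below reuse the same weights.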

    tf.GLOBAL = {}

    # Init
    tf.GLOBAL["init"] = True
    tf.GLOBAL["dropout"] = 0.0

    with tf.device("/cpu:0"):
        _ = model_factory(image_init_batch, [class_init_batch], box_init_batch,
                          0)
    ## Train
    tf.GLOBAL["init"] = False
    tf.GLOBAL["dropout"] = 0.5

    imgs_train = tf.split(image_train_batch, FLAGS.num_gpus, 0)
    class_train = tf.split(class_train_batch, FLAGS.num_gpus, 0)
    boxes_train = tf.split(box_train_batch, FLAGS.num_gpus, 0)

    min_stage = tf.placeholder(shape=[], dtype=tf.int32)
    stage_train = tf.random_uniform([], min_stage, 5, dtype=tf.int32)
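    # Each training step runs a single stage sampled uniformly from
    # [min_stage, 5).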

    loss_train = 0.0
    for i in range(FLAGS.num_gpus):
        with tf.device("gpu:%i" % i if FLAGS.mode == "gpu" else "/cpu:0"):
            _, loss = model_factory(imgs_train[i], [class_train[i]],
                                    boxes_train[i], stage_train)
            loss_train = loss_train + loss

    loss_train /= FLAGS.num_gpus

    # Optimization
    learning_rate = tf.Variable(0.0001)
    update_lr = learning_rate.assign(FLAGS.decay * learning_rate)
    optimizer = tf.train.AdamOptimizer(learning_rate, 0.95, 0.9995)
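    # colocate_gradients_with_ops keeps each gradient on the same device as
    # its forward op, so the per-GPU towers above back-propagate in place.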
    train_step = optimizer.minimize(loss_train,
                                    colocate_gradients_with_ops=True)

    train_bpd_ph = tf.placeholder(shape=[], dtype=tf.float32)
    summary_train = {
        i: tf.summary.scalar("train_bpd_stage%i" % i, train_bpd_ph)
        for i in range(5)
    }

    ## Val
    tf.GLOBAL["init"] = False
    tf.GLOBAL["dropout"] = 0.0

    imgs_val = tf.split(image_val_batch, FLAGS.num_gpus, 0)
    class_val = tf.split(class_val_batch, FLAGS.num_gpus, 0)
    boxes_val = tf.split(box_val_batch, FLAGS.num_gpus, 0)
    stage_val = tf.random_uniform([], 0, 5, dtype=tf.int32)

    loss_val = 0.0
    label_p_val, point_p_val = [], []
    for i in range(FLAGS.num_gpus):
        with tf.device("gpu:%i" % i if FLAGS.mode == "gpu" else "/cpu:0"):
            [label_p_v, point_p_v], loss = model_factory(
                imgs_val[i], [class_val[i]], boxes_val[i], stage_val)
            loss_val = loss_val + loss
            label_p_val.append(label_p_v)
            point_p_val.append(point_p_v)

    loss_val /= FLAGS.num_gpus
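    # Concatenate the per-GPU validation predictions along the batch dimension.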
    label_p_val = [tf.concat(l, axis=0) for l in zip(*label_p_val)]
    point_p_val = [tf.concat(l, axis=0) for l in zip(*point_p_val)]

    val_bpd_ph = tf.placeholder(shape=[], dtype=tf.float32)
    summary_val = {
        i: tf.summary.scalar("val_bpd_stage%i" % i, val_bpd_ph)
        for i in range(5)
    }

    # Counters
    global_step, val_step = tf.Variable(1), tf.Variable(1)
    update_global_step = global_step.assign_add(1)
    update_val_step = val_step.assign_add(1)

    ## Inits
    var_init_1 = [
        v for v in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        if v.name.find("image_parser") >= 0
    ]
    var_init_2 = [
        v for v in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        if v.name.find("detector") >= 0
    ]
    var_rest = list(
        set(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)) -
        set(var_init_1 + var_init_2))

    init_ops = [
        tf.variables_initializer(v_l)
        for v_l in [var_init_1, var_init_2, var_rest]
    ]

    ####
    image_summary_placeholder = tf.placeholder(dtype=tf.float32)
    image_summary_sample_val = tf.summary.image("validation_samples",
                                                image_summary_placeholder,
                                                max_outputs=256)
    saver = tf.train.Saver()

    # tf.get_default_graph().finalize()
    with tf.Session() as sess:
        with queues.QueueRunners(sess):

            default_model_meta = os.path.join(FLAGS.tb_log_dir, "main",
                                              "model.ckpt.meta")
            default_model_file = os.path.join(FLAGS.tb_log_dir, "main",
                                              "model.ckpt")
            rerun = False
            if tf.gfile.Exists(default_model_meta):
                print("Model is loading...")
                saver.restore(sess, default_model_file)
                rerun = True
            else:
                # Initialization (split into several steps to work around a
                # TensorFlow bug)
                _ = [sess.run(init_op) for init_op in init_ops]

                if FLAGS.use_pretrained:
                    utils.optimistic_restore(sess, "")
                    sess.run(global_step.assign(1))
                    sess.run(val_step.assign(1))
                    sess.run(learning_rate.assign(0.0001))

            # Summary writers
            summary_writer_main = tf.summary.FileWriter(
                "%s/%s" % (FLAGS.tb_log_dir, "main"), sess.graph)
            # Visualize validation ground truth

            if not rerun:
                (imgs_sample, box_cls_sample, boxes_sample) = sess.run(
                    [image_val_batch, class_val_batch, box_val_batch])

                boxes_sample = np.concatenate(
                    [boxes_sample[..., :2][:, None],
                     boxes_sample[..., 2:][:, None]], 1)
                imgs_with_box = utils.visualize(imgs_sample, box_cls_sample,
                                                boxes_sample, utils.LABEL_MAP)

                s = sess.run(
                    image_summary_sample_val,
                    {image_summary_placeholder: np.array(imgs_with_box)})
                summary_writer_main.add_summary(s, 0)

            # Run training

            n_iter_train = (utils.SST_COUNTS["train"] // batch_size
                            if FLAGS.iter_cap <= 0 else FLAGS.iter_cap)
            n_iter_val = (utils.SST_COUNTS["val"] // batch_size
                          if FLAGS.iter_cap <= 0 else FLAGS.iter_cap)

            max_iter = FLAGS.num_epochs * n_iter_train

            buf_loss = defaultdict(list)
            val_i = 0
            while not FLAGS.run_test:

                # Training step
                (_, loss_v, stage_v, train_i, val_i) = sess.run([
                    train_step, loss_train, stage_train, global_step, val_step
                ], {min_stage: 0})

                buf_loss[stage_v].append(loss_v)

                # Update global counter and learning rate
                sess.run([update_global_step, update_lr])

                # Log training error
                if train_i % FLAGS.log_training_loss == 0:

                    for i in range(5):
                        s = sess.run(summary_train[i],
                                     {train_bpd_ph: np.mean(buf_loss[i])})
                        summary_writer_main.add_summary(s, train_i)
                    buf_loss = defaultdict(list)

                # Log val error and visualize samples
                if train_i % FLAGS.log_val_loss == 0:

                    buf_loss = defaultdict(list)
                    for i in range(n_iter_val):
                        loss_v, stage_v = sess.run([loss_val, stage_val])
                        buf_loss[stage_v].append(loss_v)

                    for i in range(5):
                        s = sess.run(summary_val[i],
                                     {val_bpd_ph: np.mean(buf_loss[i])})
                        summary_writer_main.add_summary(s, val_i)
                    buf_loss = defaultdict(list)

                    # Sample detections

                    label_np = np.zeros((batch_size, 41))
                    boxes_np = np.zeros((batch_size, 56, 56, 4))

                    # stage 0: pick the argmax class label and one-hot encode it
                    l = sess.run(
                        label_p_val, {
                            image_val_batch: imgs_sample,
                            class_val_batch: label_np,
                            box_val_batch: boxes_np,
                            stage_val: 0
                        })[0]
                    l = np.argmax(l, axis=1)
                    label_np[range(batch_size), l] = 1

                    # stages 1-4: pick the argmax location for each box-coordinate channel
                    for ii in range(4):
                        l = sess.run(
                            point_p_val, {
                                image_val_batch: imgs_sample,
                                class_val_batch: label_np,
                                box_val_batch: boxes_np,
                                stage_val: ii + 1
                            })[ii]
                        l = (l == np.amax(l, axis=(1, 2),
                                          keepdims=True)).astype("int32")
                        boxes_np[:, :, :, ii:ii + 1] = l

                    # Visualize the predicted boxes and labels
                    boxes_np = np.concatenate([
                        boxes_np[..., :2][:, None], boxes_np[..., 2:][:, None]
                    ], 1)
                    imgs_with_box = utils.visualize(imgs_sample, label_np,
                                                    boxes_np, utils.LABEL_MAP)

                    image_summary_det = tf.summary.image(
                        "detection_samples%i" % val_i,
                        image_summary_placeholder,
                        max_outputs=256)

                    s = sess.run(
                        image_summary_det,
                        {image_summary_placeholder: np.array(imgs_with_box)})
                    summary_writer_main.add_summary(s, 0)

                    # Save model
                    saver.save(
                        sess,
                        os.path.join(FLAGS.tb_log_dir, "main", "model.ckpt"))
                    saver.save(
                        sess,
                        os.path.join(FLAGS.tb_log_dir, "main",
                                     "model%i.ckpt" % val_i))

                    sess.run([update_val_step])

                # Terminate
                if train_i > max_iter:
                    break

            if FLAGS.run_test:
                pass
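Example #5 depends on a set of command-line FLAGS defined elsewhere in the script. A plausible sketch of those definitions with tf.app.flags is shown below; the flag names are taken from the code above, but the default values and help strings are placeholders, not the original ones.

import tensorflow as tf

flags = tf.app.flags
flags.DEFINE_integer("batch_size_per_gpu", 8, "Batch size on each GPU.")
flags.DEFINE_integer("num_gpus", 1, "Number of GPUs (towers) to use.")
flags.DEFINE_integer("init_batch_size", 16, "Batch size for the init pass.")
flags.DEFINE_integer("image_res", 224, "Input image resolution.")
flags.DEFINE_string("mode", "gpu", "'gpu' to place towers on GPUs, else CPU.")
flags.DEFINE_float("decay", 0.999995, "Per-step learning-rate decay factor.")
flags.DEFINE_string("tb_log_dir", "/tmp/detection", "Summary/checkpoint dir.")
flags.DEFINE_boolean("use_pretrained", False, "Restore pretrained weights.")
flags.DEFINE_integer("iter_cap", 0, "If > 0, overrides iterations per epoch.")
flags.DEFINE_integer("num_epochs", 10, "Number of training epochs.")
flags.DEFINE_integer("log_training_loss", 100, "Steps between train summaries.")
flags.DEFINE_integer("log_val_loss", 1000, "Steps between validation passes.")
flags.DEFINE_boolean("run_test", False, "Skip training; run the test branch.")
FLAGS = flags.FLAGS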