Example No. 1
import os

import numpy as np
import PIL.Image
import tensorflow as tf

# ImageManager, IMAGE_BATCH_SIZE, config and logger come from the surrounding project.


def prepare_features(coco_manager, image_source_dir, target):
    # Runs the COCO images through the image feature extractor, pairs each one with its
    # tokenized captions and serializes the results to a TFRecord file at `target`.
    image_manager = ImageManager()

    processed = 0
    storage = tf.python_io.TFRecordWriter(target)

    image_batch = []
    text_batch = []
    len_batch = []

    for img_id in coco_manager.img_ids():
        logger().info('%d items processed', processed)
        processed += 1

        try:
            img = coco_manager.load_images(img_id)[0]

            raw_data = np.float32(PIL.Image.open(os.path.join(image_source_dir, img['file_name'])))
            if raw_data.ndim == 2:
                # Broadcast a single-channel (grayscale) image to three identical channels.
                raw_data = raw_data.reshape(raw_data.shape + (1,)) + np.zeros((1, 1, 3), dtype=np.int32)
                logger().warn('Found grayscale image, fixed')

            text_data_batch = np.int64(coco_manager.sents(img_id))

            # Fixed-size caption buffers: one row per sentence, slot 0 gets words_count + 1
            # (serving as a sequence-start id), the rest hold up to max_len word ids.
            text_buffer = np.zeros((config.sents_per_sample, config.max_len + 1), dtype=np.int64)
            len_buffer = np.zeros(config.sents_per_sample, dtype=np.int64)

            for idx, text_data in enumerate(text_data_batch[:config.sents_per_sample, :config.max_len]):
                text_buffer[idx, 0] = config.words_count + 1
                text_buffer[idx, 1 : 1 + text_data.shape[-1]] = text_data
                len_buffer[idx] = text_data.shape[-1]

            image_batch.append(raw_data)
            text_batch.append(text_buffer)
            len_batch.append(len_buffer)

        except Exception as e:
            # Log the image id: `img` may be unbound if load_images() itself failed.
            logger().error('Failed processing image %s: %s', img_id, e)

        # A full batch is ready: run the image feature extractor and serialize the examples.
        if len(image_batch) == IMAGE_BATCH_SIZE:
            image_data = image_manager.extract_features(image_batch)

            for image, text, length in zip(image_data, text_batch, len_batch):
                example = tf.train.Example()
                example.features.feature['image'].float_list.value.extend([float(value) for value in image])
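                # 'text' holds the decoder input (start id + word ids); 'result' is the same
                # sequence shifted left by one position, i.e. the per-step prediction target.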
                example.features.feature['text'].int64_list.value.extend(text[:, :-1].reshape(-1))
                example.features.feature['result'].int64_list.value.extend(text[:, 1:].reshape(-1))
                example.features.feature['len'].int64_list.value.extend(length)

                storage.write(example.SerializeToString())

            image_batch = []
            text_batch = []
            len_batch = []
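
The features written above ('image', 'text', 'result', 'len') can be read back by mirroring their names and flattened shapes. The sketch below is an illustration only, not code from the project: it assumes the same project-wide `config` module, the same 1.x-style TensorFlow API as the snippet, and a hypothetical `image_feature_dim` parameter matching whatever ImageManager.extract_features produces per image.

import tensorflow as tf

def parse_prepared_example(serialized, image_feature_dim):
    # Mirror of the features serialized by prepare_features above;
    # `config` is the project-wide config module used there.
    features = tf.parse_single_example(serialized, features={
        'image': tf.FixedLenFeature([image_feature_dim], tf.float32),
        'text': tf.FixedLenFeature([config.sents_per_sample * config.max_len], tf.int64),
        'result': tf.FixedLenFeature([config.sents_per_sample * config.max_len], tf.int64),
        'len': tf.FixedLenFeature([config.sents_per_sample], tf.int64),
    })
    # Restore the (sents_per_sample, max_len) caption layout that the writer flattened.
    text = tf.reshape(features['text'], [config.sents_per_sample, config.max_len])
    result = tf.reshape(features['result'], [config.sents_per_sample, config.max_len])
    return features['image'], text, result, features['len']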
Example No. 2
    # Excerpt from the training driver: `session`, `net`, `current_logs_path`, `CHECKPOINT_INTERVAL`,
    # `config`, `ensure_dir` and `logger` come from the enclosing scope; the summary calls use the
    # pre-1.0 TensorFlow API (SummaryWriter / merge_all_summaries / initialize_all_variables).
    ensure_dir(current_logs_path)
    summary_writer = tf.train.SummaryWriter(os.path.expanduser(current_logs_path), session.graph.as_graph_def())
    merged_summary = tf.merge_all_summaries()

    session.run(tf.initialize_all_variables())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=session, coord=coord)

    # Exponentially smoothed training loss (decay 0.95), seeded with a rough initial value.
    window_loss = 20.
    processed = 0

    try:
        while not coord.should_stop():
            _, loss, summary = session.run([net.train_task, net.loss, merged_summary])
            processed += 1
            window_loss = 0.95 * window_loss + 0.05 * loss
            logger().info('Batch %d completed, smoothed loss is %f', processed, window_loss)
            summary_writer.add_summary(summary, processed)
            if processed % CHECKPOINT_INTERVAL == 0:
                # Integer division keeps the checkpoint index whole under Python 3 as well.
                epochs_completed = processed // CHECKPOINT_INTERVAL
                weights_file = config.weights_file_template % epochs_completed
                net.save(weights_file)
                logger().info('Model checkpoint saved to %s', weights_file)

    except tf.errors.OutOfRangeError:
        # The input queue ran dry: no more training data.
        print('Done')
    finally:
        coord.request_stop()

    coord.join(threads)
    session.close()
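
The queue runners started in this excerpt imply an input pipeline defined elsewhere in the project. Below is a minimal sketch of what such a pipeline could look like for the TFRecord file produced in Example No. 1; it reuses the hypothetical parse_prepared_example helper from the sketch after that example, and tfrecord_path, batch_size and image_feature_dim are assumed parameters, not values taken from the project.

import tensorflow as tf

def example_input_pipeline(tfrecord_path, batch_size, image_feature_dim):
    # File queue that tf.train.start_queue_runners keeps populated during training.
    filename_queue = tf.train.string_input_producer([tfrecord_path])
    reader = tf.TFRecordReader()
    _, serialized = reader.read(filename_queue)
    image, text, result, length = parse_prepared_example(serialized, image_feature_dim)
    # Shuffled mini-batches; tensors of this kind would sit behind ops like net.loss above.
    return tf.train.shuffle_batch(
        [image, text, result, length],
        batch_size=batch_size, capacity=2000, min_after_dequeue=500)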