Example no. 1
0
import tensorflow as tf
import os
import time

from tensortalk import config
from tensortalk.utils import logger, ensure_dir
from tensortalk.network import CaptionNetwork, TrainInputPipeline

CHECKPOINT_INTERVAL = 2500

if __name__ == '__main__':
    # Training entry point: stream pre-extracted image features from
    # config.train_features_file in batches, for 5 epochs.
    input_pipeline = TrainInputPipeline([config.train_features_file], num_epochs=5, batch_size=config.batch_size)

    session = tf.Session()
    # NOTE(review): the seed is set after the input pipeline is constructed,
    # so ops created above may not be covered by it -- confirm intended.
    tf.set_random_seed(1337)
    net = CaptionNetwork(session, input_pipeline)

    # Write TensorBoard summaries into a fresh directory named by the
    # current Unix timestamp, so each run gets its own log folder.
    current_logs_path = os.path.join(config.logs_path, str(int(time.time())))
    ensure_dir(current_logs_path)
    summary_writer = tf.train.SummaryWriter(os.path.expanduser(current_logs_path), session.graph.as_graph_def())
    merged_summary = tf.merge_all_summaries()

    session.run(tf.initialize_all_variables())
    # Start the queue-runner threads that feed the input pipeline;
    # the coordinator signals them to stop (e.g. when epochs are exhausted).
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=session, coord=coord)

    window_loss = 20.  # running/smoothed loss estimate, seeded at 20 -- presumably tuned; verify
    processed = 0      # count of examples (or batches -- can't tell from here) consumed so far

    try:
        while not coord.should_stop():
            # NOTE(review): the training-loop body is missing -- this chunk of the
            # source is truncated at this point.
Example no. 2
0
from tensortalk.sampler import BeamSearchSampler

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('images', metavar='image', type=str, nargs='+',
        help='Images for captioning')
    parser.add_argument('--model', dest='model_file', required=True,
        help='Path to serialized model')
    args = parser.parse_args()

    image_manager = ImageManager()
    coco_manager = CocoManager(config.train_annotations_file, config.words_count)
    weights_file = args.model_file

    session = tf.Session()
    input_pipeline = UserInputPipeline()

    model = CaptionNetwork(session, input_pipeline)
    model.load(weights_file)

    sampler = BeamSearchSampler(beam_size=5)

    for img_name in args.images:
        img = np.float32(PIL.Image.open(img_name))
        img_features = image_manager.extract_features(img)

        sequences = sampler.sample(model, img_features, size=15)
        print img_name
        for sequence in sequences[-3:]:
            words = [coco_manager.vocab.get_word(word_idx - 1, limit=config.output_words_count - 1) for word_idx in sequence]
            print ' '.join(words)