# Training entry point for the caption network.
#
# NOTE(review): this chunk is truncated mid-statement -- the body of the
# `while not coord.should_stop():` training loop (and the matching
# except/finally for the `try`) is not visible here, so the code below
# necessarily ends on a dangling loop header, exactly as the original does.
#
# The API names used (tf.train.SummaryWriter, tf.merge_all_summaries,
# tf.initialize_all_variables) are pre-1.0 TensorFlow symbols, so this file
# targets an old (r0.x) TensorFlow release.
import tensorflow as tf
import os
import time

from tensortalk import config
from tensortalk.utils import logger, ensure_dir
from tensortalk.network import CaptionNetwork, TrainInputPipeline

# Presumably "save a checkpoint every this many steps" -- the loop that
# consumes it lies beyond this view; confirm against the full file.
CHECKPOINT_INTERVAL = 2500

if __name__ == '__main__':
    # Queue-based input pipeline over the precomputed training features,
    # configured for 5 epochs.
    input_pipeline = TrainInputPipeline([config.train_features_file], num_epochs=5, batch_size=config.batch_size)
    session = tf.Session()
    # Fix the graph-level RNG seed for reproducibility.
    # NOTE(review): this runs after TrainInputPipeline was constructed; any
    # random ops created by the pipeline would not be covered by the seed --
    # verify that is intentional.
    tf.set_random_seed(1337)
    net = CaptionNetwork(session, input_pipeline)

    # One timestamped log directory per run, created if missing.
    current_logs_path = os.path.join(config.logs_path, str(int(time.time())))
    ensure_dir(current_logs_path)
    summary_writer = tf.train.SummaryWriter(os.path.expanduser(current_logs_path), session.graph.as_graph_def())
    merged_summary = tf.merge_all_summaries()

    session.run(tf.initialize_all_variables())
    # Start the queue-runner threads feeding the input pipeline; the
    # Coordinator allows them all to be stopped together.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=session, coord=coord)

    # Counters for the (unseen) training loop: a running loss window seeded
    # at 20.0 and a processed-item count.
    window_loss = 20.
    processed = 0
    try:
        while not coord.should_stop():
            # ... training-loop body continues beyond this chunk ...
from tensortalk.sampler import BeamSearchSampler

# Caption one or more images from the command line using a trained model.
#
# NOTE(review): argparse, np, PIL, tf, config, ImageManager, CocoManager,
# UserInputPipeline and CaptionNetwork are referenced but not imported in
# this visible chunk -- presumably they come from import lines outside this
# view; confirm before treating the script as standalone.

if __name__ == '__main__':
    # CLI: one or more positional image paths, plus a required --model flag
    # pointing at the serialized weights.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('images', metavar='image', type=str, nargs='+',
                            help='Images for captioning')
    arg_parser.add_argument('--model', dest='model_file', required=True,
                            help='Path to serialized model')
    cli_args = arg_parser.parse_args()

    # Image feature extractor plus the COCO vocabulary used to map sampled
    # word indices back to words.
    image_manager = ImageManager()
    coco_manager = CocoManager(config.train_annotations_file, config.words_count)

    # Build the caption network and restore the serialized weights.
    tf_session = tf.Session()
    pipeline = UserInputPipeline()
    model = CaptionNetwork(tf_session, pipeline)
    model.load(cli_args.model_file)

    sampler = BeamSearchSampler(beam_size=5)
    for image_path in cli_args.images:
        pixels = np.float32(PIL.Image.open(image_path))
        features = image_manager.extract_features(pixels)
        # Beam-search captions of up to 15 tokens for this image.
        candidates = sampler.sample(model, features, size=15)
        print(image_path)
        # Report the last three sampled sequences, decoding each index via
        # the vocabulary (note the -1 offset and the output-vocab limit).
        for candidate in candidates[-3:]:
            tokens = [coco_manager.vocab.get_word(token_id - 1, limit=config.output_words_count - 1)
                      for token_id in candidate]
            print(' '.join(tokens))