def main(args):
    """Optionally break into pdb, then prepare a ShapeWorld agreement dataset.

    When ``args.sw_name`` is set, builds a ShapeWorld ``agreement`` dataset
    from the ``args.sw_*`` options, prints a short description of it, and
    wraps it in ``torch_util.ShapeWorldDataset``.

    A local ``ShapeWorldDataLoader`` subclass is also declared. It re-packs
    each batch as the tuple
    ``(question, image, feats, answer, program_seq, program_json)``.
    It is not instantiated within this function.
    """
    if args.debug_every <= 1:
        pdb.set_trace()

    if args.sw_name is not None:
        # ShapeWorld input is exclusive with an explicit image/question.
        assert args.image is None and args.question is None
        from shapeworld import Dataset, torch_util
        from shapeworld.datasets import clevr_util

        def to_program(model):
            return clevr_util.parse_program(mode=0, model=model)

        class ShapeWorldDataLoader(torch_util.ShapeWorldDataLoader):

            def __init__(self, **kwargs):
                super().__init__(**kwargs)

            def __iter__(self):
                for sw_batch in super().__iter__():
                    world = sw_batch['world']
                    if 'caption_model' in sw_batch:
                        programs = sw_batch['caption_model'].apply_(callable=to_program)
                    else:
                        # One placeholder program id (0) per caption.
                        programs = torch.IntTensor([0 for _ in sw_batch['caption']])
                    # image and feats are both the raw world tensor.
                    yield (sw_batch['caption'].long(), world, world,
                           sw_batch['agreement'].long(), programs, {})

        sw_dataset = Dataset.create(dtype='agreement', name=args.sw_name,
                                    variant=args.sw_variant,
                                    language=args.sw_language,
                                    config=args.sw_config)
        print('ShapeWorld dataset: {} (variant: {})'.format(sw_dataset, args.sw_variant))
        print('Config: ' + str(args.sw_config))
        sw_mode = None if args.sw_mode == 'none' else args.sw_mode
        dataset = torch_util.ShapeWorldDataset(dataset=sw_dataset,  # include_model=True
                                               mode=sw_mode,
                                               epoch=(args.num_samples is None))
def main(_):
    """Train the shape2seq captioning model, checkpointing after every epoch.

    All configuration is read from the module-level ``FLAGS``.

    - Creates ``<log_dir>/<exp_tag>/{train,eval,test}`` where missing.
    - Builds a ShapeWorld dataset and a ``CaptioningModel``, GloVe-initialised
      when ``FLAGS.glove_dir`` is set.
    - Runs ``params.num_epochs`` epochs of ``params.num_steps_per_epoch``
      steps each.
    - Logs the average training loss to TensorBoard at 25%, 50%, 75% and 100%
      of every epoch.
    - Saves a checkpoint to the train directory at the end of each epoch.

    Raises:
        AssertionError: if a required flag is missing.
        ValueError: if the ShapeWorld dataset cannot be created.
    """
    # FILESYSTEM SETUP ------------------------------------------------------------
    assert FLAGS.data_dir, "Must specify data location!"
    assert FLAGS.log_dir, "Must specify experiment to log to!"
    assert FLAGS.exp_tag, "Must specify experiment tag subfolder to log_dir %s" % FLAGS.log_dir
    assert FLAGS.cnn_ckpt, "Must specify where to load CNN checkpoint from!"
    assert FLAGS.variant, "Must specify shapeworld variant"

    # Build saving folders. Each one is checked on its own, so a missing
    # eval/test folder is still created when the train folder already exists.
    save_root = FLAGS.log_dir + os.sep + FLAGS.exp_tag
    train_path = save_root + os.sep + "train"
    eval_path = save_root + os.sep + "eval"
    test_path = save_root + os.sep + "test"
    for label, path in (("training", train_path), ("eval", eval_path), ("test", test_path)):
        if tf.gfile.IsDirectory(path):
            tf.logging.info("Using %s directory: %s", label, path)
        else:
            tf.gfile.MakeDirs(path)
            tf.logging.info("Creating %s directory: %s", label, path)

    # Sanity check
    tf.reset_default_graph()
    tf.logging.info("Clean graph reset...")

    try:
        dataset = Dataset.create(dtype=FLAGS.dtype, name=FLAGS.name,
                                 variant=FLAGS.variant, config=FLAGS.data_dir)
        dataset.pixel_noise_stddev = 0.1
    except Exception as exc:
        raise ValueError(
            "variant=%s did not point to a valid Shapeworld dataset" % FLAGS.variant) from exc

    # Get parsing and parameter feats
    params = Config(mode="train", sw_specification=dataset.specification())
    params.cnn_checkpoint = FLAGS.cnn_ckpt
    params.batch_size = FLAGS.batch_size

    # MODEL SETUP ------------------------------------------------------------
    g = tf.Graph()
    with g.as_default():
        parser = FullSequenceBatchParser(src_vocab=dataset.vocabularies['language'])
        params.vocab_size = len(parser.tgt_vocab)

        batch = tf_util.batch_records(dataset, mode="train", batch_size=params.batch_size)
        model = CaptioningModel(config=params, batch_parser=parser)

        if FLAGS.glove_dir:
            tf.logging.info("Loading GloVe Embeddings...")
            gl = GloveLoader(vocab=parser.tgt_vocab, glove_dir=FLAGS.glove_dir,
                             dims=FLAGS.glove_dim, load_new=False)
            glove_initials = gl.get_embeddings_matrix()
            tf.logging.info("Building model with GloVe initialisation...")
            model.build_model(batch, embedding_init=glove_initials)
        else:
            tf.logging.info("Building model without GloVe initialisation...")
            model.build_model(batch)

        tf.logging.info("Network built...")

        # TRAINING OPERATION SETUP ------------------------------------------------------------
        # Run ops registered in UPDATE_OPS together with every optimiser step.
        with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
            train_op = tf.contrib.layers.optimize_loss(
                loss=model.batch_loss,
                global_step=model.global_step,
                learning_rate=params.initial_learning_rate,
                optimizer=params.optimizer,
                clip_gradients=params.clip_gradients,
            )

        logging_saver = tf.train.Saver(max_to_keep=params.max_checkpoints_to_keep)
        summary_op = tf.summary.merge_all()

    train_writer = tf.summary.FileWriter(logdir=train_path, graph=g)

    tf.logging.info('###' * 20)
    tf.logging.info("Beginning shape2seq network training for %d steps" % params.num_total_steps)

    with tf.Session(graph=g, config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        tf.logging.info("### Trainable Variables")
        for var in tf.trainable_variables():
            print("-> %s" % var.op.name)

        coordinator = tf.train.Coordinator()
        queue_threads = tf.train.start_queue_runners(sess=sess, coord=coordinator)
        try:
            # Initialise everything
            sess.run([tf.global_variables_initializer(), tf.tables_initializer()])
            tf.logging.info("Restoring CNN...")
            model.init_fn(sess)

            start_train_time = time.time()

            # Completed-step counts within an epoch at which the averaged loss
            # is logged: roughly 25%, 50%, 75% and 100% of the epoch.
            steps_per_epoch = params.num_steps_per_epoch
            offsets = np.linspace(0, steps_per_epoch, 4, endpoint=False, dtype=np.int32)
            logging_points = {int(steps_per_epoch - o) for o in offsets}
            logging_loss = []  # losses since the last logging point

            for c_epoch in range(params.num_epochs):
                tf.logging.info("Running epoch %d" % c_epoch)
                epoch_start = steps_per_epoch * c_epoch
                for c_step in trange(epoch_start, epoch_start + steps_per_epoch):
                    # c_step is a global counter. Compare progress *within* the
                    # epoch so logging fires in every epoch, not just the first.
                    if (c_step - epoch_start + 1) in logging_points:
                        _, loss_, summaries = sess.run(
                            fetches=[train_op, model.batch_loss, summary_op])
                        window_losses = logging_loss + [loss_]
                        logging_loss = []
                        avg_loss = np.mean(window_losses).squeeze()

                        step = tf.train.global_step(sess, model.global_step)
                        new_summ = tf.Summary()
                        new_summ.value.add(tag="train/avg_loss", simple_value=avg_loss)
                        train_writer.add_summary(new_summ, step)
                        train_writer.add_summary(summaries, step)
                        train_writer.flush()
                        tf.logging.info(
                            " -> Average loss step %d, for last %d steps: %.5f"
                            % (c_step, len(window_losses), avg_loss))
                    else:
                        # Run without summaries
                        _, loss_ = sess.run(fetches=[train_op, model.batch_loss])
                        logging_loss.append(loss_)

                logging_saver.save(sess=sess,
                                   save_path=train_path + os.sep + "model",
                                   global_step=tf.train.global_step(sess, model.global_step))
        finally:
            # Stop the input queue threads even if training fails.
            coordinator.request_stop()
            coordinator.join(threads=queue_threads)

    train_writer.close()
    end_time = time.time() - start_train_time
    tf.logging.info('Training complete in %.2f-secs/%.2f-mins/%.2f-hours',
                    end_time, end_time / 60, end_time / (60 * 60))
)  # NOTE(review): closes a call that begins before this fragment (not visible here)
# Parse command-line arguments.
args = parser.parse_args()
# config_values is expanded with ** into Dataset.create below, so parse_config
# is expected to return a mapping.
args.config_values = util.parse_config(values=args.config_values)

# TFRecords utility
if args.tf_records:
    from shapeworld import tf_util

# Switch the util module to version 1 when requested.
if args.v1:
    util.set_version(1)

# does not include variant, as loading data for generation is not expected
dataset = Dataset.create(dtype=args.type, name=args.name, language=args.language, config=args.config, **args.config_values)
# Report a timestamp and the dataset, then the config/config_values in use.
sys.stdout.write('{time} {dataset}\n'.format(
    time=datetime.now().strftime('%H:%M:%S'), dataset=dataset))
if args.config is None:
    if args.config_values:
        sys.stdout.write(' config: {config}\n'.format(
            config=args.config_values))
else:
    sys.stdout.write(
        ' config: {config}\n'.format(config=args.config))
    if args.config_values:
        sys.stdout.write(' {config}\n'.format(
            config=args.config_values))
sys.stdout.flush()
def main(_):
    """Evaluate the latest shape2seq checkpoint on ShapeWorld validation and test data.

    All configuration is read from the module-level ``FLAGS``.

    - Restores the newest checkpoint in ``<log_dir>/<exp_tag>/train``.
    - For each of the ``'validation'`` and ``'test'`` partitions, generates
      ``num_imgs`` examples, decodes one caption per example, and scores it
      against the example's world model with the parser's semantic parser.
    - Prints per-partition perplexity and caption-score rates and writes them
      as TensorBoard summaries to ``<log_dir>/<exp_tag>/test_2``.
    - Writes the raw caption scores to a CSV file in that same directory.

    Raises:
        AssertionError: if a required flag is missing, no checkpoint is found,
            or ``FLAGS.decode_type`` is unsupported.
        ValueError: if the ShapeWorld dataset cannot be created.
    """
    # FILESYSTEM SETUP ------------------------------------------------------------
    assert FLAGS.data_dir, "Must specify data location!"
    assert FLAGS.log_dir, "Must specify experiment to log to!"
    assert FLAGS.exp_tag, "Must specify experiment tag subfolder to log_dir %s" % FLAGS.log_dir
    assert FLAGS.variant, "Must specify shapeworld variant"

    # Folder setup for saving summaries and loading checkpoints
    save_root = FLAGS.log_dir + os.sep + FLAGS.exp_tag
    test_path = save_root + os.sep + "test_2"
    if not tf.gfile.IsDirectory(test_path):
        tf.gfile.MakeDirs(test_path)

    train_path = save_root + os.sep + "train"
    model_ckpt = tf.train.latest_checkpoint(train_path)  # Get checkpoint to load
    tf.logging.info("Loading checkpoint %s", model_ckpt)
    assert model_ckpt, "Checkpoints could not be loaded, check that train_path %s exists" % train_path

    # Sanity check graph reset
    tf.reset_default_graph()
    tf.logging.info("Clean graph reset...")

    try:
        dataset = Dataset.create(dtype=FLAGS.dtype, name=FLAGS.name,
                                 variant=FLAGS.variant, config=FLAGS.data_dir)
        dataset.pixel_noise_stddev = 0.1
        dataset.random_sampling = False
    except Exception as exc:
        raise ValueError(
            "variant=%s did not point to a valid Shapeworld dataset" % FLAGS.variant) from exc

    # Get parsing and parameter feats
    params = Config(mode="test", sw_specification=dataset.specification())

    # Parse decoding arg from CLI
    params.decode_type = FLAGS.decode_type
    assert params.decode_type in ['greedy', 'sample', 'beam']

    # MODEL SETUP ------------------------------------------------------------
    g = tf.Graph()
    with g.as_default():
        parser = FullSequenceBatchParser(src_vocab=dataset.vocabularies['language'])
        vocab, rev_vocab = parser.get_vocab()
        params.vocab_size = len(parser.tgt_vocab)

        # Examples are fed one at a time below.
        # NOTE(review): that presumably requires params.batch_size == 1 — confirm.
        caption_pl = tf.placeholder(dtype=tf.int32, shape=(params.batch_size, FLAGS.max_seq_len))
        caption_len_pl = tf.placeholder(dtype=tf.int32, shape=(params.batch_size,))
        world_pl = tf.placeholder(dtype=tf.float32, shape=(params.batch_size, 64, 64, 3))
        batch = {
            "caption": caption_pl,
            "caption_length": caption_len_pl,
            "world": world_pl,
        }

        model = CaptioningModel(config=params, batch_parser=parser)
        model.build_model(batch)
        restore_model = tf.train.Saver()

    tf.logging.info("Network built...")

    # TESTING SETUP ------------------------------------------------------------
    if FLAGS.num_imgs < 1:
        num_imgs = params.instances_per_shard * params.num_shards
    else:
        num_imgs = FLAGS.num_imgs
    tf.logging.info("Running test for %d images", num_imgs)

    test_writer = tf.summary.FileWriter(logdir=test_path, graph=g)
    start_test_time = time.time()

    with tf.Session(graph=g, config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        # Model restoration
        restore_model.restore(sess, model_ckpt)
        tf.logging.info("Model restored!")

        # Trained model does not need initialisation. Init the vocab conversion tables
        sess.run([tf.tables_initializer()])

        # Freeze graph
        sess.graph.finalize()

        # Get global step
        global_step = tf.train.global_step(sess, model.global_step)
        tf.logging.info("Successfully loaded %s at global step = %d.",
                        os.path.basename(model_ckpt), global_step)

        sem_parser = parser.build_semparser()
        tag_suffix = "%s_%s" % (FLAGS.decode_type, FLAGS.name)

        for data_partition in ['validation', 'test']:
            # Reset metrics per partition so test numbers are not mixed
            # with validation numbers.
            misses = []
            cap_scores = []
            perplexities = []

            tf.logging.info("Loading Shapeworld data...")
            generated = dataset.generate(n=num_imgs, mode=data_partition, include_model=True)
            # Dict of lists -> list of dicts
            examples = [{k: v[idx] for k, v in generated.items()} for idx in range(num_imgs)]
            tf.logging.info("Data loaded!..")

            for b_idx, example in enumerate(examples):
                reference_caps, inf_decoder_outputs, batch_perplexity = sess.run(
                    fetches=[
                        model.reference_captions,
                        model.inf_decoder_output,
                        model.batch_perplexity,
                    ],
                    feed_dict={
                        model.phase: 0,
                        world_pl: [example['world']],
                        caption_pl: [example['caption']],
                        caption_len_pl: [example['caption_length']],
                    })
                perplexities.append(batch_perplexity)

                ref_cap = reference_caps.squeeze()
                inf_cap = inf_decoder_outputs.sample_id.squeeze()

                # After squeeze(), a 0-d array holds a single token and cannot
                # be treated as a token sequence.
                if ref_cap.ndim == 0 or inf_cap.ndim == 0:
                    print("Skipping %d as inf_cap %s is malformed" % (b_idx, inf_cap))
                    misses.append(inf_cap)
                    continue

                ref_text = " ".join(rev_vocab[r] for r in ref_cap if r != parser.pad_token_id)
                inf_text = " ".join(
                    rev_vocab[w] for w in inf_cap
                    if w != parser.pad_token_id and w != parser.eos_token_id)

                try:
                    cap_scores.append(sem_parser(example['world_model'], inf_text))
                except Exception as exc:
                    # Best effort: report and skip captions the semantic
                    # parser cannot handle instead of aborting the run.
                    print("Uncaught failure")
                    print(exc)
                    continue

                print("-------------------------------------------")
                print("%d | REF -> %s | INF -> %s" % (b_idx, ref_text, inf_text))
                print(cap_scores[-1])

            avg_perplexity = np.mean(perplexities).squeeze()
            std_perplexity = np.std(perplexities).squeeze()
            agree_rate = np.mean([sc.agreement for sc in cap_scores])
            false_rate = np.mean([sc.false for sc in cap_scores])
            ooscope_rate = np.mean([sc.out_of_scope for sc in cap_scores])
            ungramm_rate = np.mean([sc.ungrammatical for sc in cap_scores])

            print("--------------------------")
            print("PERPLEXITY -> %.5f +- %.5f" % (avg_perplexity, std_perplexity))
            print("AGREEMENT RATE -> %.2f" % agree_rate)
            print("FALSE RATE -> %.2f" % false_rate)
            print("OOSCOPE RATE -> %.2f" % ooscope_rate)
            print("UNGRAMMATICAL RATE -> %.2f" % ungramm_rate)
            # misses holds one malformed caption per entry; report the count.
            print("misses -> %d" % len(misses))

            new_summ = tf.Summary()
            for metric, value in (("perplexity_avg", avg_perplexity),
                                  ("perplexity_std", std_perplexity),
                                  ("agree_rate", agree_rate),
                                  ("false_rate", false_rate),
                                  ("ooscope_rate", ooscope_rate),
                                  ("ungramm_rate", ungramm_rate)):
                new_summ.value.add(tag="%s/%s_%s" % (data_partition, metric, tag_suffix),
                                   simple_value=value)
            test_writer.add_summary(new_summ, global_step)
            test_writer.flush()

            test_outputs_fname = test_path + os.sep + "caps_%d_%s_%s.csv" % (
                global_step, data_partition, FLAGS.decode_type)
            if cap_scores:
                # The csv module requires newline='' on the file it writes to.
                with open(test_outputs_fname, 'w', newline='') as fh:
                    writer = csv.writer(fh, delimiter=',')
                    writer.writerow(cap_scores[0]._fields)
                    writer.writerows(list(c) for c in cap_scores)
            else:
                tf.logging.warning("No caption scores for %s; not writing %s",
                                   data_partition, test_outputs_fname)

    test_writer.close()
    end_time = time.time() - start_test_time
    tf.logging.info('Testing complete in %.2f-secs/%.2f-mins/%.2f-hours',
                    end_time, end_time / 60, end_time / (60 * 60))
def main(args):
    """Build training/validation loaders from ShapeWorld or CLEVR data and run ``train_loop``.

    When ``args.randomize_checkpoint_path == 1``, a random six-digit suffix is
    appended to ``args.checkpoint_path``.

    ShapeWorld mode (``args.sw_name`` or ``args.sw_config`` given):
    - Builds an ``agreement`` dataset.
    - Creates a question vocabulary, or extends one loaded from
      ``<program_generator_start_from>.vocab``, and writes it to
      ``<checkpoint_path>.vocab``.
    - Sets ``args.feature_dim`` and ``args.vocab_json`` from the dataset.
    - Trains with ShapeWorld loaders.

    CLEVR mode (otherwise):
    - Loads ``args.vocab_json``.
    - Optionally copies the HDF5 inputs to ``/tmp`` first; those copies are
      removed afterwards when ``args.cleanup_local_copies == 1``, even if
      training fails.
    - Trains with ``ClevrDataLoader``s.
    """
    if args.randomize_checkpoint_path == 1:
        name, ext = os.path.splitext(args.checkpoint_path)
        num = random.randint(1, 1000000)
        args.checkpoint_path = '%s_%06d%s' % (name, num, ext)
    print('Will save checkpoints to %s' % args.checkpoint_path)

    if args.sw_name is not None or args.sw_config is not None:
        from shapeworld import Dataset, torch_util
        from shapeworld.datasets import clevr_util

        class ShapeWorldDataLoader(torch_util.ShapeWorldDataLoader):
            """Yield batches as (question, image, feats, answer, program_seq, program_json)."""

            def __iter__(self):
                for batch in super(ShapeWorldDataLoader, self).__iter__():
                    question = batch['caption'].long()
                    image = batch['world']
                    feats = batch['world']
                    answer = batch['agreement'].long()
                    if 'caption_model' in batch:
                        program_seq = batch['caption_model'].apply_(
                            callable=(lambda model: clevr_util.parse_program(mode=0, model=model)))
                    else:
                        program_seq = torch.IntTensor([0 for _ in batch['caption']])
                    program_json = dict()
                    yield question, image, feats, answer, program_seq, program_json

        dataset = Dataset.create(dtype='agreement', name=args.sw_name,
                                 variant=args.sw_variant, language=args.sw_language,
                                 config=args.sw_config)
        print('ShapeWorld dataset: {} (variant: {})'.format(dataset, args.sw_variant))
        print('Config: ' + str(args.sw_config))

        language_vocab = dataset.vocabularies['language']
        if args.program_generator_start_from is None:
            # Dataset indices > 0 are shifted up by 2 (index 0 stays 0),
            # freeing ids 1 and 2 for <START> and <END>.
            question_token_to_idx = {
                word: index + 2 if index > 0 else 0
                for word, index in language_vocab.items()
            }
            question_token_to_idx['<NULL>'] = 0
            question_token_to_idx['<START>'] = 1
            question_token_to_idx['<END>'] = 2
            vocab = dict(
                question_token_to_idx=question_token_to_idx,
                # NOTE(review): only special tokens; no program vocabulary is built here.
                program_token_to_idx={'<NULL>': 0, '<START>': 1, '<END>': 2},
                answer_token_to_idx={'false': 0, 'true': 1},
            )
        else:
            with open(args.program_generator_start_from + '.vocab', 'r') as filehandle:
                vocab = json.load(filehandle)
            # Extend the loaded mapping in place with words it does not know yet.
            question_token_to_idx = vocab['question_token_to_idx']
            index = len(question_token_to_idx)
            for word in language_vocab:
                if word not in question_token_to_idx:
                    question_token_to_idx[word] = index
                    index += 1

        with open(args.checkpoint_path + '.vocab', 'w') as filehandle:
            json.dump(vocab, filehandle)

        args.feature_dim = ','.join(str(n) for n in reversed(dataset.world_shape()))
        args.vocab_json = args.checkpoint_path + '.vocab'

        train_dataset = torch_util.ShapeWorldDataset(dataset=dataset, mode='train')  # , include_model=True)
        train_loader = ShapeWorldDataLoader(dataset=train_dataset, batch_size=args.batch_size)  # num_workers=1

        if args.sw_mixer == 1:
            # One validation loader per component dataset of the mixer.
            val_loader = list()
            for d in dataset.datasets:
                val_dataset = torch_util.ShapeWorldDataset(
                    dataset=d, mode='validation', epoch=(args.num_val_samples is None))
                val_loader.append(ShapeWorldDataLoader(dataset=val_dataset,
                                                       batch_size=args.batch_size))  # num_workers=1
        else:
            val_dataset = torch_util.ShapeWorldDataset(
                dataset=dataset, mode='validation', epoch=(args.num_val_samples is None))
            val_loader = ShapeWorldDataLoader(dataset=val_dataset,
                                              batch_size=args.batch_size)  # num_workers=1

        train_loop(args, train_loader, val_loader)

    else:
        vocab = utils.load_vocab(args.vocab_json)

        # args attribute -> local /tmp copy used when use_local_copies == 1.
        local_copies = {
            'train_question_h5': '/tmp/train_questions.h5',
            'train_features_h5': '/tmp/train_features.h5',
            'val_question_h5': '/tmp/val_questions.h5',
            'val_features_h5': '/tmp/val_features.h5',
        }
        if args.use_local_copies == 1:
            for attr, local_path in local_copies.items():
                shutil.copy(getattr(args, attr), local_path)
            for attr, local_path in local_copies.items():
                setattr(args, attr, local_path)

        try:
            question_families = None
            if args.family_split_file is not None:
                with open(args.family_split_file, 'r') as f:
                    question_families = json.load(f)

            train_loader_kwargs = {
                'question_h5': args.train_question_h5,
                'feature_h5': args.train_features_h5,
                'vocab': vocab,
                'batch_size': args.batch_size,
                'shuffle': args.shuffle_train_data == 1,
                'question_families': question_families,
                'max_samples': args.num_train_samples,
                'num_workers': args.loader_num_workers,
            }
            val_loader_kwargs = {
                'question_h5': args.val_question_h5,
                'feature_h5': args.val_features_h5,
                'vocab': vocab,
                'batch_size': args.batch_size,
                'question_families': question_families,
                'max_samples': args.num_val_samples,
                'num_workers': args.loader_num_workers,
            }

            with ClevrDataLoader(**train_loader_kwargs) as train_loader, \
                    ClevrDataLoader(**val_loader_kwargs) as val_loader:
                train_loop(args, train_loader, val_loader)
        finally:
            # Remove the /tmp copies even when training raises.
            if args.use_local_copies == 1 and args.cleanup_local_copies == 1:
                for local_path in local_copies.values():
                    os.remove(local_path)
def main(args):
    """Run a trained model on ShapeWorld data, CLEVR HDF5 data, or user questions.

    Data source: when ``args.sw_name`` or ``args.sw_config`` is set, a
    ShapeWorld dataset and loader are built. Its vocabulary comes from
    whichever of ``program_generator``, ``execution_engine`` or
    ``baseline_model`` is given (in that order).

    Model, selected by ``args.model_type`` and the given paths:
    - a baseline (CNN/LSTM variants),
    - a program generator + execution engine pair, or
    - FiLM, loaded from ``args.baseline_model``.
    If none applies, a message is printed and the function returns.

    Mode:
    - a single question on an image;
    - an interactive prompt on an image (ends on end-of-input);
    - ShapeWorld batch evaluation, with an optional predictions text file and
      HTML visualisation;
    - otherwise CLEVR batch evaluation via ``ClevrDataLoader``.
    """
    if args.debug_every <= 1:
        pdb.set_trace()

    if args.sw_name is not None or args.sw_config is not None:
        assert args.image is None and args.question is None

        from shapeworld import Dataset, torch_util
        from shapeworld.datasets import clevr_util

        # sw_name may be None when only sw_config is given, so guard before startswith.
        is_clevr = args.sw_name is not None and args.sw_name.startswith("clevr")

        class ShapeWorldDataLoader(torch_util.ShapeWorldDataLoader):

            def __iter__(self):
                for batch in super(ShapeWorldDataLoader, self).__iter__():
                    if "caption" in batch:
                        question = batch["caption"].long()
                    else:
                        question = batch["question"].long()

                    if args.sw_features == 1:
                        image = batch["world_features"]
                    else:
                        image = batch["world"]
                    feats = image

                    if "agreement" in batch:
                        answer = batch["agreement"].long()
                    else:
                        answer = batch["answer"].long()

                    if "caption_model" in batch:
                        assert is_clevr or args.sw_program == 3
                        program_seq = batch["caption_model"]
                    elif "question_model" in batch:
                        program_seq = batch["question_model"]
                    elif "caption" in batch:
                        if args.sw_program == 1:
                            program_seq = batch["caption_pn"].long()
                        elif args.sw_program == 2:
                            program_seq = batch["caption_rpn"].long()
                        else:
                            program_seq = [None]
                    else:
                        program_seq = [None]

                    program_json = dict()
                    yield question, image, feats, answer, program_seq, program_json

        dataset = Dataset.create(
            dtype=args.sw_type,
            name=args.sw_name,
            variant=args.sw_variant,
            language=args.sw_language,
            config=args.sw_config,
        )
        print("ShapeWorld dataset: {} (variant: {})".format(dataset, args.sw_variant))
        print("Config: " + str(args.sw_config))

        # The vocabulary lives next to whichever model checkpoint was given.
        vocab_source = next(
            (path for path in (args.program_generator, args.execution_engine, args.baseline_model)
             if path is not None),
            None)
        if vocab_source is None:
            print("Must give either --baseline_model or --program_generator and --execution_engine")
            return
        with open(vocab_source + ".vocab", "r") as filehandle:
            vocab = json.load(filehandle)
        program_token_to_idx = vocab["program_token_to_idx"]

        include_model = args.model_type in ("PG", "EE", "PG+EE") and (
            is_clevr or args.sw_program == 3)
        if include_model:

            def preprocess(model):
                if is_clevr:
                    program_prefix = vr.programs.list_to_prefix(model["program"])
                else:
                    program_prefix = clevr_util.parse_program(mode=0, model=model)
                program_str = vr.programs.list_to_str(program_prefix)
                program_tokens = tokenize(program_str)
                program_encoded = encode(program_tokens, program_token_to_idx)
                # Right-pad with <NULL> to 27 tokens; longer programs are kept as-is.
                program_encoded += [
                    program_token_to_idx["<NULL>"]
                    for _ in range(27 - len(program_encoded))
                ]
                return np.asarray(program_encoded, dtype=np.int64)

            if is_clevr:
                preprocessing = dict(question_model=preprocess)
            else:
                preprocessing = dict(caption_model=preprocess)

        elif args.sw_program in (1, 2):

            def preprocess(caption_pn):
                # Work on a copy so the array handed in is not modified in place.
                caption_pn = np.array(caption_pn)
                # Shift non-zero symbols up by 2, replace the first 0 with 2 and
                # prepend 1. Presumably these are the <END>/<START> ids — TODO confirm.
                caption_pn += (caption_pn > 0) * 2
                for n, symbol in enumerate(caption_pn):
                    if symbol == 0:
                        caption_pn[n] = 2
                        break
                caption_pn = np.concatenate(([1], caption_pn))
                return caption_pn

            if args.sw_program == 1:
                preprocessing = dict(caption_pn=preprocess)
            else:
                preprocessing = dict(caption_rpn=preprocess)

        else:
            preprocessing = None

        dataset = torch_util.ShapeWorldDataset(
            dataset=dataset,
            mode=(None if args.sw_mode == "none" else args.sw_mode),
            include_model=include_model,
            epoch=(args.num_samples is None),
            preprocessing=preprocessing,
        )

        loader = ShapeWorldDataLoader(dataset=dataset, batch_size=args.batch_size)

    model = None
    if args.model_type in ("CNN", "LSTM", "CNN+LSTM", "CNN+LSTM+SA"):
        assert args.baseline_model is not None
        print("Loading baseline model from", args.baseline_model)
        model, _ = utils.load_baseline(args.baseline_model)
        if args.vocab_json is not None:
            new_vocab = utils.load_vocab(args.vocab_json)
            model.rnn.expand_vocab(new_vocab["question_token_to_idx"])
    elif args.program_generator is not None and args.execution_engine is not None:
        pg, _ = utils.load_program_generator(args.program_generator, args.model_type)
        ee, _ = utils.load_execution_engine(args.execution_engine, verbose=False,
                                            model_type=args.model_type)
        if args.vocab_json is not None:
            new_vocab = utils.load_vocab(args.vocab_json)
            pg.expand_encoder_vocab(new_vocab["question_token_to_idx"])
        model = (pg, ee)
    elif args.model_type == "FiLM":
        assert args.baseline_model is not None
        pg, _ = utils.load_program_generator(args.baseline_model, args.model_type)
        ee, _ = utils.load_execution_engine(args.baseline_model, verbose=False,
                                            model_type=args.model_type)
        if args.vocab_json is not None:
            new_vocab = utils.load_vocab(args.vocab_json)
            pg.expand_encoder_vocab(new_vocab["question_token_to_idx"])
        model = (pg, ee)
    else:
        print("Must give either --baseline_model or --program_generator and --execution_engine")
        return

    if torch.cuda.is_available():
        dtype = torch.cuda.FloatTensor
    else:
        dtype = torch.FloatTensor

    if args.question is not None and args.image is not None:
        run_single_example(args, model, dtype, args.question)

    # Interactive mode
    elif (args.image is not None and args.input_question_h5 is None
          and args.input_features_h5 is None):
        feats_var = extract_image_features(args, dtype)
        print(colored("Ask me something!", "cyan"))
        while True:
            # Get user question; end-of-input (e.g. Ctrl-D) ends the session.
            try:
                question_raw = input(">>> ")
            except EOFError:
                break
            run_single_example(args, model, dtype, question_raw, feats_var)

    elif args.sw_name is not None or args.sw_config is not None:
        predictions, visualization = run_batch(args, model, dtype, loader)

        if args.sw_pred_dir is not None:
            assert args.sw_pred_name is not None
            pred_dir = os.path.join(
                args.sw_pred_dir,
                dataset.dataset.type,
                dataset.dataset.name,
                dataset.dataset.variant,
            )
            os.makedirs(pred_dir, exist_ok=True)
            id2word = dataset.dataset.vocabulary(value_type="language")
            pred_path = os.path.join(pred_dir, args.sw_pred_name + "-" + args.sw_mode + ".txt")
            with open(pred_path, "w") as filehandle:
                # One line per example: "<correct> <agreement> <caption words>".
                filehandle.write("".join(
                    "{} {} {}\n".format(correct, agreement,
                                        " ".join(id2word[c] for c in caption))
                    for correct, agreement, caption in zip(
                        predictions["correct"],
                        predictions["agreement"],
                        predictions["caption"],
                    )))
            print("Predictions saved")

        if args.sw_vis_dir is not None:
            assert args.sw_vis_name is not None
            from io import BytesIO
            from shapeworld.world import World

            vis_dir = os.path.join(
                args.sw_vis_dir,
                dataset.dataset.type,
                dataset.dataset.name,
                dataset.dataset.variant,
            )
            image_dir = os.path.join(vis_dir, args.sw_mode, "images")
            os.makedirs(image_dir, exist_ok=True)
            # Move the axis at position 1 to the end before rendering each world.
            worlds = np.transpose(visualization["world"], (0, 2, 3, 1))
            for n in range(worlds.shape[0]):
                image = World.get_image(world_array=worlds[n])
                with BytesIO() as image_bytes:
                    image.save(image_bytes, format="png")
                    with open(os.path.join(image_dir, "world-{}.png".format(n)), "wb") as filehandle:
                        filehandle.write(image_bytes.getvalue())
            html_path = os.path.join(vis_dir, args.sw_vis_name + "-" + args.sw_mode + ".html")
            with open(html_path, "w") as filehandle:
                html = dataset.dataset.get_html(
                    generated=visualization,
                    image_format="png",
                    image_dir=(args.sw_mode + "/images/"),
                )
                filehandle.write(html)
            print("Visualization saved")

    else:
        vocab = load_vocab(args)
        loader_kwargs = {
            "question_h5": args.input_question_h5,
            "feature_h5": args.input_features_h5,
            "vocab": vocab,
            "batch_size": args.batch_size,
        }
        if args.family_split_file is not None:
            with open(args.family_split_file, "r") as f:
                loader_kwargs["question_families"] = json.load(f)
        with ClevrDataLoader(**loader_kwargs) as loader:
            run_batch(args, model, dtype, loader)
def main(_):
    """Evaluate the latest checkpoint by order-insensitive token matching.

    All configuration is read from ``FLAGS``.

    - Restores the newest checkpoint in ``<log_dir>/<exp_tag>/train``.
    - Runs ``num_imgs`` decoding steps on records read through
      ``tf_util.batch_records`` from ``FLAGS.data_partition``.
    - Counts a decoded caption as correct when every token it keeps after
      filtering ``parser.token_filter`` also appears in the filtered reference.
    - Excludes malformed decodes from accuracy and reports them as misses.
    - Prints accuracy and perplexity and writes them as summaries to
      ``<log_dir>/<exp_tag>/test``.
    """
    # FILESYSTEM SETUP ------------------------------------------------------------
    assert FLAGS.data_dir, "Must specify data location!"
    assert FLAGS.log_dir, "Must specify experiment to log to!"
    assert FLAGS.exp_tag, "Must specify experiment tag subfolder to log_dir %s" % FLAGS.log_dir
    assert FLAGS.parse_type

    # Folder setup for saving summaries and loading checkpoints
    save_root = FLAGS.log_dir + os.sep + FLAGS.exp_tag
    test_path = save_root + os.sep + "test"
    if not tf.gfile.IsDirectory(test_path):
        tf.gfile.MakeDirs(test_path)

    train_path = save_root + os.sep + "train"
    model_ckpt = tf.train.latest_checkpoint(train_path)  # Get checkpoint to load
    tf.logging.info("Loading checkpoint %s", model_ckpt)
    assert model_ckpt, "Checkpoints could not be loaded, check that train_path %s exists" % train_path

    # Sanity check graph reset
    tf.reset_default_graph()
    tf.logging.info("Clean graph reset...")

    dataset = Dataset.create(dtype=FLAGS.dtype, name=FLAGS.name, config=FLAGS.data_dir)
    dataset.pixel_noise_stddev = 0.1
    dataset.random_sampling = False

    # Get parsing and parameter feats
    params = Config(mode="test", sw_specification=dataset.specification())

    # Parse decoding arg from CLI
    params.decode_type = FLAGS.decode_type
    assert params.decode_type in ['greedy', 'sample', 'beam']

    # MODEL SETUP ------------------------------------------------------------
    g = tf.Graph()
    with g.as_default():
        parser = SimpleBatchParser(src_vocab=dataset.vocabularies['language'],
                                   batch_type=FLAGS.parse_type)
        vocab, rev_vocab = parser.get_vocab()
        params.vocab_size = len(parser.tgt_vocab)

        batch = tf_util.batch_records(dataset, mode=FLAGS.data_partition,
                                      batch_size=params.batch_size)
        model = CaptioningModel(config=params, batch_parser=parser)
        model.build_model(batch)
        restore_model = tf.train.Saver()

    tf.logging.info("Network built...")

    # TESTING SETUP ------------------------------------------------------------
    if FLAGS.num_imgs < 1:
        num_imgs = params.instances_per_shard * params.num_shards
    else:
        num_imgs = FLAGS.num_imgs
    tf.logging.info("Running test for %d images", num_imgs)

    test_writer = tf.summary.FileWriter(logdir=test_path, graph=g)

    with tf.Session(graph=g, config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        # Launch data loading queues
        coordinator = tf.train.Coordinator()
        queue_threads = tf.train.start_queue_runners(sess=sess, coord=coordinator)
        try:
            # Model restoration
            restore_model.restore(sess, model_ckpt)
            tf.logging.info("Model restored!")

            # Trained model does not need initialisation. Init the vocab conversion tables
            sess.run([tf.tables_initializer()])

            # Freeze graph
            sess.graph.finalize()

            # Get global step
            global_step = tf.train.global_step(sess, model.global_step)
            tf.logging.info("Successfully loaded %s at global step = %d.",
                            os.path.basename(model_ckpt), global_step)

            start_test_time = time.time()
            corrects = []    # 1 per correct, 0 per well-formed but wrong caption
            incorrects = []  # (reference, inferred) token sets of wrong captions
            misses = []      # one entry per malformed caption
            perplexities = []

            for b_idx in range(num_imgs):
                reference_caps, inf_decoder_outputs, batch_perplexity = sess.run(
                    fetches=[
                        model.reference_captions,
                        model.inf_decoder_output,
                        model.batch_perplexity,
                    ],
                    feed_dict={model.phase: 0})

                ref_cap = reference_caps.squeeze()
                inf_cap = inf_decoder_outputs.sample_id.squeeze()
                perplexities.append(batch_perplexity)

                # After squeeze(), a 0-d array holds a single token and
                # cannot be iterated as a caption.
                if ref_cap.ndim == 0 or inf_cap.ndim == 0:
                    print("Skipping %d as inf_cap %s is malformed" % (b_idx, inf_cap))
                    misses.append(1)
                    continue

                print("%d REF -> %s | INF -> %s" % (
                    b_idx,
                    " ".join(rev_vocab[r] for r in ref_cap),
                    " ".join(rev_vocab[r] for r in inf_cap)))

                # Drop tokens in parser.token_filter (e.g. <S>, </S>) and
                # compare as sets for order insensitivity.
                ref_tokens = {int(tok) for tok in ref_cap if int(tok) not in parser.token_filter}
                inf_tokens = {int(tok) for tok in inf_cap if int(tok) not in parser.token_filter}
                if inf_tokens <= ref_tokens:
                    corrects.append(1)
                else:
                    corrects.append(0)
                    incorrects.append((ref_tokens, inf_tokens))

            # Overall scores for checkpoint
            avg_acc = np.mean(corrects).squeeze()
            std_acc = np.std(corrects).squeeze()
            print("Accuracy: %s -> %.5f ± %.5f | Misses: %d " %
                  (FLAGS.parse_type, avg_acc, std_acc, len(misses)))

            avg_perplexity = np.mean(perplexities).squeeze()
            std_perplexity = np.std(perplexities).squeeze()
            print("------------")
            print("PERPLEXITY -> %.5f +- %.5f" % (avg_perplexity, std_perplexity))

            new_summ = tf.Summary()
            new_summ.value.add(tag="%s/avg_acc_%s" % (FLAGS.data_partition, FLAGS.name),
                               simple_value=avg_acc)
            new_summ.value.add(tag="%s/std_acc_%s" % (FLAGS.data_partition, FLAGS.name),
                               simple_value=std_acc)
            new_summ.value.add(tag="%s/perplexity_avg_%s" % (FLAGS.data_partition, FLAGS.name),
                               simple_value=avg_perplexity)
            new_summ.value.add(tag="%s/perplexity_std_%s" % (FLAGS.data_partition, FLAGS.name),
                               simple_value=std_perplexity)
            test_writer.add_summary(new_summ, global_step)
            test_writer.flush()
        finally:
            # Stop the input queue threads even if evaluation fails.
            coordinator.request_stop()
            coordinator.join(threads=queue_threads)

    test_writer.close()
    end_time = time.time() - start_test_time
    tf.logging.info('Testing complete in %.2f-secs/%.2f-mins/%.2f-hours',
                    end_time, end_time / 60, end_time / (60 * 60))
# TFRecords utility if args.tf_records: from shapeworld import tf_util # tensorflow verbosity if args.verbosity >= 2: os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0' else: os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # dataset exclude_values = ('world', ) if args.features else ('world_features', ) dataset = Dataset.create(dtype=args.type, name=args.name, variant=args.variant, language=args.language, config=args.config, exclude_values=exclude_values, **args.config_values) # information about dataset and model if args.verbosity >= 1: sys.stdout.write('{time} train {model} on {dataset}\n'.format( time=datetime.now().strftime('%H:%M:%S'), model=args.model, dataset=dataset)) if args.config is None: if args.config_values: sys.stdout.write(' config: {config}\n'.format( config=args.config_values)) else: