Ejemplo n.º 1
0
def main(args):
    """Optionally break into the debugger, then prepare a ShapeWorld dataset.

    When ``args.sw_name`` is ``None`` nothing beyond the debugger check happens
    and ``None`` is returned. Otherwise an ``agreement`` ShapeWorld dataset is
    created, described on stdout, and wrapped in
    ``torch_util.ShapeWorldDataset``. A ``ShapeWorldDataLoader`` adapter class
    is also defined; it turns ShapeWorld batches into
    ``(question, image, feats, answer, program_seq, program_json)`` tuples.
    """
    if args.debug_every <= 1:
        pdb.set_trace()

    # Only the ShapeWorld path is handled here.
    if args.sw_name is None:
        return

    # A ShapeWorld run cannot be combined with a single image/question input.
    assert args.image is None and args.question is None

    from shapeworld import Dataset, torch_util
    from shapeworld.datasets import clevr_util

    class ShapeWorldDataLoader(torch_util.ShapeWorldDataLoader):
        """Adapter yielding the six-tuple batch layout used downstream."""

        def __init__(self, **kwargs):
            super().__init__(**kwargs)

        def __iter__(self):
            for sw_batch in super().__iter__():
                caption_ids = sw_batch['caption'].long()
                world = sw_batch['world']
                agreement = sw_batch['agreement'].long()
                if 'caption_model' not in sw_batch:
                    # No caption model: a dummy program id per caption.
                    programs = torch.IntTensor([0 for _ in sw_batch['caption']])
                else:
                    programs = sw_batch['caption_model'].apply_(
                        callable=(lambda model: clevr_util.parse_program(mode=0, model=model)))
                # The world array serves as both the image and the features.
                yield caption_ids, world, world, agreement, programs, dict()

    sw_dataset = Dataset.create(
        dtype='agreement',
        name=args.sw_name,
        variant=args.sw_variant,
        language=args.sw_language,
        config=args.sw_config,
    )
    print('ShapeWorld dataset: {} (variant: {})'.format(sw_dataset, args.sw_variant))
    print('Config: ' + str(args.sw_config))

    # The CLI value 'none' means "no specific mode".
    sw_mode = None if args.sw_mode == 'none' else args.sw_mode
    dataset = torch_util.ShapeWorldDataset(  # include_model=True
        dataset=sw_dataset, mode=sw_mode, epoch=(args.num_samples is None))
Ejemplo n.º 2
0
def main(_):
    """Train the shape2seq captioning model on a ShapeWorld dataset.

    All configuration is read from the module-level ``FLAGS``. Steps:

    * Create ``<log_dir>/<exp_tag>/{train,eval,test}`` if missing.
    * Build a ``CaptioningModel`` graph fed by ``tf_util.batch_records``,
      optionally with GloVe-initialised embeddings.
    * Restore the CNN weights via ``model.init_fn``.
    * Train for ``params.num_epochs`` epochs.

    Average-loss summaries are written to the train directory at roughly
    25/50/75/100% of every epoch. A checkpoint is saved at the end of every
    epoch.

    Args:
        _: Unused positional argument (kept for the ``main(argv)`` calling
            convention).

    Raises:
        AssertionError: if a required flag is missing.
        ValueError: if the ShapeWorld dataset cannot be created (the
            underlying error is attached as ``__cause__``).
    """
    # FILESYSTEM SETUP ------------------------------------------------------------
    assert FLAGS.data_dir, "Must specify data location!"
    assert FLAGS.log_dir, "Must specify experiment to log to!"
    assert FLAGS.exp_tag, "Must specify experiment tag subfolder to log_dir %s" % FLAGS.log_dir
    assert FLAGS.cnn_ckpt, "Must specify where to load CNN checkpoint from!"
    assert FLAGS.variant, "Must specific shapeworld variant"

    # Build saving folders
    save_root = FLAGS.log_dir + os.sep + FLAGS.exp_tag
    train_path = save_root + os.sep + "train"
    eval_path = save_root + os.sep + "eval"
    test_path = save_root + os.sep + "test"

    # Check each folder on its own, so that a partially created experiment
    # directory (e.g. train/ present but test/ missing) gets completed.
    for label, path in (("training", train_path), ("eval", eval_path),
                        ("test", test_path)):
        if tf.gfile.IsDirectory(path):
            tf.logging.info("Using %s directory: %s", label, path)
        else:
            tf.gfile.MakeDirs(path)
            tf.logging.info("Creating %s directory: %s", label, path)

    # Sanity check
    tf.reset_default_graph()
    tf.logging.info("Clean graph reset...")

    try:
        dataset = Dataset.create(dtype=FLAGS.dtype,
                                 name=FLAGS.name,
                                 variant=FLAGS.variant,
                                 config=FLAGS.data_dir)
        dataset.pixel_noise_stddev = 0.1
    except Exception as exc:
        # Chain the original error so the real cause stays visible.
        raise ValueError(
            "variant=%s did not point to a valid Shapeworld dataset" %
            FLAGS.variant) from exc

    # Get parsing and parameter feats
    params = Config(mode="train", sw_specification=dataset.specification())
    params.cnn_checkpoint = FLAGS.cnn_ckpt
    params.batch_size = FLAGS.batch_size

    # MODEL SETUP ------------------------------------------------------------
    g = tf.Graph()
    with g.as_default():
        parser = FullSequenceBatchParser(
            src_vocab=dataset.vocabularies['language'])
        params.vocab_size = len(parser.tgt_vocab)

        batch = tf_util.batch_records(dataset,
                                      mode="train",
                                      batch_size=params.batch_size)
        model = CaptioningModel(config=params, batch_parser=parser)

        if FLAGS.glove_dir:
            tf.logging.info("Loading GloVe Embeddings...")
            gl = GloveLoader(vocab=parser.tgt_vocab,
                             glove_dir=FLAGS.glove_dir,
                             dims=FLAGS.glove_dim,
                             load_new=False)
            glove_initials = gl.get_embeddings_matrix()
            tf.logging.info("Building model with GloVe initialisation...")
            model.build_model(batch, embedding_init=glove_initials)
        else:
            tf.logging.info("Building model without GloVe initialisation...")
            model.build_model(batch)
        tf.logging.info("Network built...")

        # TRAINING OPERATION SETUP ------------------------------------------------------------
        # Make the train op depend on UPDATE_OPS (e.g. batch-norm statistics).
        with tf.control_dependencies(tf.get_collection(
                tf.GraphKeys.UPDATE_OPS)):
            train_op = tf.contrib.layers.optimize_loss(
                loss=model.batch_loss,
                global_step=model.global_step,
                learning_rate=params.initial_learning_rate,
                optimizer=params.optimizer,
                clip_gradients=params.clip_gradients,
            )

        logging_saver = tf.train.Saver(
            max_to_keep=params.max_checkpoints_to_keep)
        summary_op = tf.summary.merge_all()

    train_writer = tf.summary.FileWriter(logdir=train_path, graph=g)

    tf.logging.info('###' * 20)
    tf.logging.info("Beginning shape2seq network training for %d steps" %
                    params.num_total_steps)

    with tf.Session(graph=g,
                    config=tf.ConfigProto(allow_soft_placement=True)) as sess:

        tf.logging.info("### Trainable Variables")
        for var in tf.trainable_variables():
            print("-> %s" % var.op.name)

        coordinator = tf.train.Coordinator()
        queue_threads = tf.train.start_queue_runners(sess=sess,
                                                     coord=coordinator)
        try:
            # Initialise everything
            sess.run([tf.global_variables_initializer(), tf.tables_initializer()])

            tf.logging.info("Restoring CNN...")
            model.init_fn(sess)

            start_train_time = time.time()

            # Losses accumulated since the last logging point.
            logging_loss = []
            # Logging points at ~[25%, 50%, 75%, 100%] of an epoch, expressed as
            # the number of steps completed *within the current epoch*.
            # Comparing against the in-epoch count (not the global step) makes
            # logging fire in every epoch.
            steps_per_epoch = params.num_steps_per_epoch
            quarter_offsets = np.linspace(0, steps_per_epoch, 4,
                                          endpoint=False, dtype=np.int32)
            logging_points = frozenset(
                int(steps_per_epoch - offset) for offset in quarter_offsets)

            for c_epoch in range(0, params.num_epochs):
                tf.logging.info("Running epoch %d" % c_epoch)
                epoch_start = steps_per_epoch * c_epoch
                for c_step in trange(epoch_start, epoch_start + steps_per_epoch):
                    steps_done_in_epoch = c_step - epoch_start + 1
                    if steps_done_in_epoch in logging_points:
                        _, loss_, summaries = sess.run(
                            fetches=[train_op, model.batch_loss, summary_op])

                        loss_ = logging_loss + [loss_]
                        logging_loss = []

                        avg_loss = np.mean(loss_).squeeze()
                        global_step = tf.train.global_step(sess, model.global_step)
                        new_summ = tf.Summary()
                        new_summ.value.add(tag="train/avg_loss",
                                           simple_value=avg_loss)
                        train_writer.add_summary(new_summ, global_step)
                        train_writer.add_summary(summaries, global_step)
                        train_writer.flush()

                        tf.logging.info(
                            " -> Average loss step %d, for last %d steps: %.5f" %
                            (c_step, len(loss_), avg_loss))

                    # Run without summaries
                    else:
                        _, loss_, = sess.run(fetches=[train_op, model.batch_loss])
                        logging_loss.append(loss_)

                logging_saver.save(sess=sess,
                                   save_path=train_path + os.sep + "model",
                                   global_step=tf.train.global_step(
                                       sess, model.global_step))
        finally:
            # Always stop the input-queue threads and release the writer,
            # even when training raises.
            coordinator.request_stop()
            coordinator.join(threads=queue_threads)
            train_writer.close()

        end_time = time.time() - start_train_time
        tf.logging.info('Training complete in %.2f-secs/%.2f-mins/%.2f-hours',
                        end_time, end_time / 60, end_time / (60 * 60))
Ejemplo n.º 3
0
    )

    # Parse the command line, then turn the raw config-value arguments into a
    # mapping (it is splatted as keyword arguments into Dataset.create below).
    args = parser.parse_args()
    args.config_values = util.parse_config(values=args.config_values)

    # TFRecords utility
    # tf_util is only imported when TFRecords output is requested.
    if args.tf_records:
        from shapeworld import tf_util

    # --v1: select version 1 via util.set_version (semantics defined in util).
    if args.v1:
        util.set_version(1)

    # does not include variant, as loading data for generation is not expected
    dataset = Dataset.create(dtype=args.type,
                             name=args.name,
                             language=args.language,
                             config=args.config,
                             **args.config_values)
    # Print a timestamped description of the dataset, followed by its
    # configuration: the config argument (if any) and any extra config values.
    sys.stdout.write('{time} {dataset}\n'.format(
        time=datetime.now().strftime('%H:%M:%S'), dataset=dataset))
    if args.config is None:
        if args.config_values:
            sys.stdout.write('         config: {config}\n'.format(
                config=args.config_values))
    else:
        sys.stdout.write(
            '         config: {config}\n'.format(config=args.config))
        if args.config_values:
            # Continuation line, aligned under the config value above.
            sys.stdout.write('                 {config}\n'.format(
                config=args.config_values))
    sys.stdout.flush()
Ejemplo n.º 4
0
def main(_):
    """Evaluate a trained shape2seq captioning model on ShapeWorld data.

    Restores the latest checkpoint from ``<log_dir>/<exp_tag>/train``. It then
    feeds ``FLAGS.num_imgs`` instances from the ``validation`` and ``test``
    partitions through the model, one at a time. When ``FLAGS.num_imgs < 1``,
    ``instances_per_shard * num_shards`` instances are used instead. Each
    inferred caption is scored with the parser's semantic parser.

    For every partition, separately, the function:

    * prints perplexity and agreement/false/out-of-scope/ungrammatical rates,
    * writes them as summaries to ``<log_dir>/<exp_tag>/test_2``,
    * writes the per-caption scores to a CSV in that folder. The CSV is
      skipped when no caption could be scored.

    Args:
        _: Unused positional argument (kept for the ``main(argv)`` calling
            convention).

    Raises:
        AssertionError: if a required flag is missing, no checkpoint is
            found, or ``FLAGS.decode_type`` is not greedy/sample/beam.
        ValueError: if the ShapeWorld dataset cannot be created (the
            underlying error is attached as ``__cause__``).
    """
    # FILESYSTEM SETUP ------------------------------------------------------------
    assert FLAGS.data_dir, "Must specify data location!"
    assert FLAGS.log_dir, "Must specify experiment to log to!"
    assert FLAGS.exp_tag, "Must specify experiment tag subfolder to log_dir %s" % FLAGS.log_dir
    assert FLAGS.variant, "Must specific shapeworld variant"

    # Folder setup for saving summaries and loading checkpoints
    save_root = FLAGS.log_dir + os.sep + FLAGS.exp_tag
    test_path = save_root + os.sep + "test_2"
    if not tf.gfile.IsDirectory(test_path):
        tf.gfile.MakeDirs(test_path)

    train_path = save_root + os.sep + "train"

    # Most recent checkpoint in train_path (falsy when none is found).
    model_ckpt = tf.train.latest_checkpoint(train_path)
    tf.logging.info("Loading checkpoint %s", model_ckpt)
    assert model_ckpt, "Checkpoints could not be loaded, check that train_path %s exists" % train_path

    # Sanity check graph reset
    tf.reset_default_graph()
    tf.logging.info("Clean graph reset...")

    try:
        dataset = Dataset.create(dtype=FLAGS.dtype,
                                 name=FLAGS.name,
                                 variant=FLAGS.variant,
                                 config=FLAGS.data_dir)
        dataset.pixel_noise_stddev = 0.1
        dataset.random_sampling = False
    except Exception as exc:
        # Chain the original error so the real cause stays visible.
        raise ValueError(
            "variant=%s did not point to a valid Shapeworld dataset" %
            FLAGS.variant) from exc

    # Get parsing and parameter feats
    params = Config(mode="test", sw_specification=dataset.specification())

    # Parse decoding arg from CLI
    params.decode_type = FLAGS.decode_type
    assert params.decode_type in ['greedy', 'sample', 'beam']

    # MODEL SETUP ------------------------------------------------------------
    g = tf.Graph()
    with g.as_default():
        parser = FullSequenceBatchParser(
            src_vocab=dataset.vocabularies['language'])
        # Only the id -> token mapping is needed to decode captions.
        _, rev_vocab = parser.get_vocab()
        params.vocab_size = len(parser.tgt_vocab)

        # Placeholders; each sess.run below feeds one instance wrapped in a
        # one-element list.
        caption_pl = tf.placeholder(dtype=tf.int32,
                                    shape=(params.batch_size,
                                           FLAGS.max_seq_len))
        caption_len_pl = tf.placeholder(dtype=tf.int32,
                                        shape=(params.batch_size, ))
        world_pl = tf.placeholder(dtype=tf.float32,
                                  shape=(params.batch_size, 64, 64, 3))
        batch = {
            "caption": caption_pl,
            "caption_length": caption_len_pl,
            "world": world_pl
        }

        model = CaptioningModel(config=params, batch_parser=parser)
        model.build_model(batch)

        restore_model = tf.train.Saver()
        tf.logging.info("Network built...")

    # TESTING SETUP ------------------------------------------------------------

    if FLAGS.num_imgs < 1:
        num_imgs = params.instances_per_shard * params.num_shards
    else:
        num_imgs = FLAGS.num_imgs

    tf.logging.info("Running test for %d images", num_imgs)

    test_writer = tf.summary.FileWriter(logdir=test_path, graph=g)
    start_test_time = time.time()

    with tf.Session(graph=g,
                    config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        # Model restoration
        restore_model.restore(sess, model_ckpt)
        tf.logging.info("Model restored!")

        # Trained model does not need initialisation. Init the vocab conversion tables
        sess.run([tf.tables_initializer()])

        #  Freeze graph
        sess.graph.finalize()

        # Nothing below trains, so the global step stays constant from here on.
        global_step = tf.train.global_step(sess, model.global_step)
        tf.logging.info("Successfully loaded %s at global step = %d.",
                        os.path.basename(model_ckpt), global_step)

        sem_parser = parser.build_semparser()

        for data_partition in ['validation', 'test']:
            # Fresh accumulators per partition, so validation results do not
            # leak into the test statistics (or its CSV).
            misses = []
            cap_scores = []
            perplexities = []

            tf.logging.info("Loading Shapeworld data...")
            idx_batch = dataset.generate(n=num_imgs,
                                         mode=data_partition,
                                         include_model=True)

            # Dict of lists -> list of dicts
            idx_batch = [{k: v[idx]
                          for k, v in idx_batch.items()}
                         for idx in range(0, num_imgs)]

            tf.logging.info("Data loaded!..")

            for b_idx, batch in enumerate(idx_batch):
                reference_caps, inf_decoder_outputs, batch_perplexity = sess.run(
                    fetches=[
                        model.reference_captions, model.inf_decoder_output,
                        model.batch_perplexity
                    ],
                    feed_dict={
                        model.phase: 0,
                        world_pl: [batch['world']],
                        caption_pl: [batch['caption']],
                        caption_len_pl: [batch['caption_length']]
                    })

                perplexities.append(batch_perplexity)
                ref_cap = reference_caps.squeeze()
                inf_cap = inf_decoder_outputs.sample_id.squeeze()

                # Both captions must be iterable token sequences to be decoded;
                # squeeze() yields a 0-d array for single-token outputs.
                if ref_cap.ndim > 0 and inf_cap.ndim > 0:
                    ref_cap = " ".join(rev_vocab[r] for r in ref_cap
                                       if r != parser.pad_token_id)
                    inf_cap = " ".join([
                        rev_vocab[w] for w in filter(
                            lambda y: y != parser.pad_token_id and y != parser.
                            eos_token_id, inf_cap)
                    ])

                    try:
                        cap_scores.append(
                            sem_parser(batch['world_model'], inf_cap))
                    except Exception as exc:
                        print("Uncaught failure")
                        print(exc)
                        continue

                    print("-------------------------------------------")
                    print("%d | REF -> %s | INF -> %s" %
                          (b_idx, ref_cap, inf_cap))
                    print(cap_scores[-1])

                else:
                    print("Skipping %d as inf_cap %s is malformed" %
                          (b_idx, inf_cap))
                    misses.append(inf_cap)

            avg_perplexity = np.mean(perplexities).squeeze()
            std_perplexity = np.std(perplexities).squeeze()
            agree_rate = np.mean([sc.agreement for sc in cap_scores])
            false_rate = np.mean([sc.false for sc in cap_scores])
            ooscope_rate = np.mean([sc.out_of_scope for sc in cap_scores])
            ungramm_rate = np.mean([sc.ungrammatical for sc in cap_scores])

            print("--------------------------")
            print("PERPLEXITY -> %.5f +- %.5f" %
                  (avg_perplexity, std_perplexity))
            print("AGREEMENT RATE -> %.2f" % agree_rate)
            print("FALSE RATE -> %.2f" % false_rate)
            print("OOSCOPE RATE -> %.2f" % ooscope_rate)
            print("UNGRAMMATICAL RATE -> %.2f" % ungramm_rate)
            # Number of skipped (malformed) outputs, not a sum of token ids.
            print("misses -> %d" % len(misses))

            new_summ = tf.Summary()
            new_summ.value.add(tag="%s/perplexity_avg_%s_%s" %
                               (data_partition, FLAGS.decode_type, FLAGS.name),
                               simple_value=avg_perplexity)
            new_summ.value.add(tag="%s/perplexity_std_%s_%s" %
                               (data_partition, FLAGS.decode_type, FLAGS.name),
                               simple_value=std_perplexity)
            new_summ.value.add(tag="%s/agree_rate_%s_%s" %
                               (data_partition, FLAGS.decode_type, FLAGS.name),
                               simple_value=agree_rate)
            new_summ.value.add(tag="%s/false_rate_%s_%s" %
                               (data_partition, FLAGS.decode_type, FLAGS.name),
                               simple_value=false_rate)
            new_summ.value.add(tag="%s/ooscope_rate_%s_%s" %
                               (data_partition, FLAGS.decode_type, FLAGS.name),
                               simple_value=ooscope_rate)
            new_summ.value.add(tag="%s/ungramm_rate_%s_%s" %
                               (data_partition, FLAGS.decode_type, FLAGS.name),
                               simple_value=ungramm_rate)

            test_writer.add_summary(new_summ, global_step)
            test_writer.flush()

            test_outputs_fname = test_path + os.sep + "caps_%d_%s_%s.csv" % (
                global_step, data_partition, FLAGS.decode_type)

            # The CSV header comes from the first score's fields, so there is
            # nothing to write when no caption was scored.
            if not cap_scores:
                tf.logging.warning("No caption scores for %s; not writing %s",
                                   data_partition, test_outputs_fname)
                continue

            with open(test_outputs_fname, 'w', newline='') as fh:
                writer = csv.writer(fh, delimiter=',')
                writer.writerow(cap_scores[0]._fields)
                writer.writerows(list(c) for c in cap_scores)

        test_writer.close()

        end_time = time.time() - start_test_time
        tf.logging.info('Testing complete in %.2f-secs/%.2f-mins/%.2f-hours',
                        end_time, end_time / 60, end_time / (60 * 60))
Ejemplo n.º 5
0
def main(args):
  """Train on ShapeWorld data or on pre-extracted CLEVR-style HDF5 files.

  If ``args.randomize_checkpoint_path == 1``, a random 6-digit suffix is
  appended to ``args.checkpoint_path`` first.

  ShapeWorld branch (taken when ``args.sw_name`` or ``args.sw_config`` is set):

  * Writes a vocabulary JSON to ``<checkpoint_path>.vocab``. It is built
    from the dataset vocabulary, or extended from
    ``args.program_generator_start_from + '.vocab'``.
  * Builds ShapeWorld train/validation loaders and runs ``train_loop``.

  HDF5 branch (otherwise):

  * Optionally copies the HDF5 inputs to ``/tmp`` first.
  * Trains with ``ClevrDataLoader`` instances.
  * When cleanup is requested, removes the ``/tmp`` copies even if
    training raises.

  Args:
    args: Parsed command-line namespace. Mutated in place
      (``checkpoint_path``, ``feature_dim``, ``vocab_json``, and the
      ``*_h5`` paths when local copies are used).
  """
  if args.randomize_checkpoint_path == 1:
    name, ext = os.path.splitext(args.checkpoint_path)
    num = random.randint(1, 1000000)
    args.checkpoint_path = '%s_%06d%s' % (name, num, ext)
  print('Will save checkpoints to %s' % args.checkpoint_path)

  if args.sw_name is not None or args.sw_config is not None:
    from shapeworld import Dataset, torch_util
    from shapeworld.datasets import clevr_util

    class ShapeWorldDataLoader(torch_util.ShapeWorldDataLoader):
      """Adapter yielding (question, image, feats, answer, program_seq, program_json)."""

      def __iter__(self):
        for batch in super(ShapeWorldDataLoader, self).__iter__():
          question = batch['caption'].long()
          image = batch['world']
          feats = batch['world']
          answer = batch['agreement'].long()
          if 'caption_model' in batch:
            program_seq = batch['caption_model'].apply_(callable=(lambda model: clevr_util.parse_program(mode=0, model=model)))
          else:
            # No caption model: a dummy program id per caption.
            program_seq = torch.IntTensor([0 for _ in batch['caption']])
          program_json = dict()
          yield question, image, feats, answer, program_seq, program_json

    dataset = Dataset.create(dtype='agreement', name=args.sw_name, variant=args.sw_variant,
      language=args.sw_language, config=args.sw_config)
    print('ShapeWorld dataset: {} (variant: {})'.format(dataset, args.sw_variant))
    print('Config: ' + str(args.sw_config))

    vocab_path = args.checkpoint_path + '.vocab'
    if args.program_generator_start_from is None:
      # Shift dataset word ids by 2 so that no word uses ids 1 and 2
      # (<START>/<END>). The word at index 0 (presumably padding) maps to 0,
      # like <NULL>.
      question_token_to_idx = {
        word: index + 2 if index > 0 else 0
        for word, index in dataset.vocabularies['language'].items()
      }
      question_token_to_idx['<NULL>'] = 0
      question_token_to_idx['<START>'] = 1
      question_token_to_idx['<END>'] = 2
      vocab = dict(
        question_token_to_idx=question_token_to_idx,
        program_token_to_idx={'<NULL>': 0, '<START>': 1, '<END>': 2},  # missing!!! (special tokens only)
        answer_token_to_idx={'false': 0, 'true': 1}
      )
    else:
      with open(args.program_generator_start_from + '.vocab', 'r') as filehandle:
        vocab = json.load(filehandle)
      # Append dataset words unknown to the loaded vocabulary. The dict is the
      # one held by `vocab`, so the extension is saved below.
      question_token_to_idx = vocab['question_token_to_idx']
      index = len(question_token_to_idx)
      for word in dataset.vocabularies['language']:
        if word not in question_token_to_idx:
          question_token_to_idx[word] = index
          index += 1
    with open(vocab_path, 'w') as filehandle:
      json.dump(vocab, filehandle)

    args.feature_dim = ','.join(str(n) for n in reversed(dataset.world_shape()))
    args.vocab_json = vocab_path

    train_dataset = torch_util.ShapeWorldDataset(dataset=dataset, mode='train')  # , include_model=True)
    train_loader = ShapeWorldDataLoader(dataset=train_dataset, batch_size=args.batch_size)  # num_workers=1

    if args.sw_mixer == 1:
      # One validation loader per sub-dataset of the mixer.
      val_loader = list()
      for d in dataset.datasets:
        val_dataset = torch_util.ShapeWorldDataset(dataset=d, mode='validation', epoch=(args.num_val_samples is None))
        val_loader.append(ShapeWorldDataLoader(dataset=val_dataset, batch_size=args.batch_size))  # num_workers=1
    else:
      val_dataset = torch_util.ShapeWorldDataset(dataset=dataset, mode='validation', epoch=(args.num_val_samples is None))
      val_loader = ShapeWorldDataLoader(dataset=val_dataset, batch_size=args.batch_size)  # num_workers=1

    train_loop(args, train_loader, val_loader)

  else:
    vocab = utils.load_vocab(args.vocab_json)

    local_copies = ('/tmp/train_questions.h5', '/tmp/train_features.h5',
                    '/tmp/val_questions.h5', '/tmp/val_features.h5')
    if args.use_local_copies == 1:
      shutil.copy(args.train_question_h5, '/tmp/train_questions.h5')
      shutil.copy(args.train_features_h5, '/tmp/train_features.h5')
      shutil.copy(args.val_question_h5, '/tmp/val_questions.h5')
      shutil.copy(args.val_features_h5, '/tmp/val_features.h5')
      args.train_question_h5 = '/tmp/train_questions.h5'
      args.train_features_h5 = '/tmp/train_features.h5'
      args.val_question_h5 = '/tmp/val_questions.h5'
      args.val_features_h5 = '/tmp/val_features.h5'

    try:
      question_families = None
      if args.family_split_file is not None:
        with open(args.family_split_file, 'r') as f:
          question_families = json.load(f)

      train_loader_kwargs = {
        'question_h5': args.train_question_h5,
        'feature_h5': args.train_features_h5,
        'vocab': vocab,
        'batch_size': args.batch_size,
        'shuffle': args.shuffle_train_data == 1,
        'question_families': question_families,
        'max_samples': args.num_train_samples,
        'num_workers': args.loader_num_workers,
      }
      val_loader_kwargs = {
        'question_h5': args.val_question_h5,
        'feature_h5': args.val_features_h5,
        'vocab': vocab,
        'batch_size': args.batch_size,
        'question_families': question_families,
        'max_samples': args.num_val_samples,
        'num_workers': args.loader_num_workers,
      }

      with ClevrDataLoader(**train_loader_kwargs) as train_loader, \
           ClevrDataLoader(**val_loader_kwargs) as val_loader:
        train_loop(args, train_loader, val_loader)
    finally:
      # Remove the /tmp copies even when training fails, so that large HDF5
      # files are not left behind.
      if args.use_local_copies == 1 and args.cleanup_local_copies == 1:
        for tmp_path in local_copies:
          os.remove(tmp_path)
Ejemplo n.º 6
0
def main(args):
    """Run a trained model on ShapeWorld data, one example, interactively, or on HDF5 data.

    The mode depends on ``args``:

    * ShapeWorld (``args.sw_name`` or ``args.sw_config`` set): builds a
      ShapeWorld loader and runs ``run_batch``. It optionally writes
      predictions (``args.sw_pred_dir``) and an HTML visualisation
      (``args.sw_vis_dir``).
    * ``args.question`` and ``args.image`` set: answers that single question.
    * Only ``args.image`` set (no HDF5 inputs): interactive question loop.
    * Otherwise: ``run_batch`` over ``ClevrDataLoader`` HDF5 inputs.

    The model is a baseline, a program-generator/execution-engine pair, or a
    FiLM model, chosen from ``args.model_type`` and the model path arguments.
    If no model vocabulary/path is available, a usage message is printed and
    ``None`` is returned. Baseline and FiLM types instead assert that
    ``--baseline_model`` is set.
    """
    if args.debug_every <= 1:
        pdb.set_trace()

    if args.sw_name is not None or args.sw_config is not None:
        assert args.image is None and args.question is None

        from shapeworld import Dataset, torch_util
        from shapeworld.datasets import clevr_util

        # sw_name may be None when the dataset is selected by --sw_config only.
        is_clevr = args.sw_name is not None and args.sw_name.startswith("clevr")

        class ShapeWorldDataLoader(torch_util.ShapeWorldDataLoader):
            def __iter__(self):
                for batch in super(ShapeWorldDataLoader, self).__iter__():
                    if "caption" in batch:
                        question = batch["caption"].long()
                    else:
                        question = batch["question"].long()
                    if args.sw_features == 1:
                        image = batch["world_features"]
                    else:
                        image = batch["world"]
                    feats = image
                    if "agreement" in batch:
                        answer = batch["agreement"].long()
                    else:
                        answer = batch["answer"].long()
                    if "caption_model" in batch:
                        assert is_clevr or args.sw_program == 3
                        program_seq = batch["caption_model"]
                        # .apply_(callable=(lambda model: clevr_util.parse_program(mode=0, model=model)))
                    elif "question_model" in batch:
                        program_seq = batch["question_model"]
                    elif "caption" in batch:
                        if args.sw_program == 1:
                            program_seq = batch["caption_pn"].long()
                        elif args.sw_program == 2:
                            program_seq = batch["caption_rpn"].long()
                        else:
                            program_seq = [None]
                    else:
                        program_seq = [None]
                    # program_seq = torch.IntTensor([0 for _ in batch['question']])
                    program_json = dict()
                    yield question, image, feats, answer, program_seq, program_json

        dataset = Dataset.create(
            dtype=args.sw_type,
            name=args.sw_name,
            variant=args.sw_variant,
            language=args.sw_language,
            config=args.sw_config,
        )
        print("ShapeWorld dataset: {} (variant: {})".format(
            dataset, args.sw_variant))
        print("Config: " + str(args.sw_config))

        # The vocabulary is stored next to whichever model checkpoint is given.
        vocab = None
        if args.program_generator is not None:
            with open(args.program_generator + ".vocab", "r") as filehandle:
                vocab = json.load(filehandle)
        elif args.execution_engine is not None:
            with open(args.execution_engine + ".vocab", "r") as filehandle:
                vocab = json.load(filehandle)
        elif args.baseline_model is not None:
            with open(args.baseline_model + ".vocab", "r") as filehandle:
                vocab = json.load(filehandle)
        if vocab is None:
            # Without any model path no model can be loaded below either.
            print(
                "Must give either --baseline_model or --program_generator and --execution_engine"
            )
            return
        program_token_to_idx = vocab["program_token_to_idx"]

        include_model = args.model_type in ("PG", "EE", "PG+EE") and (
            is_clevr or args.sw_program == 3)
        if include_model:

            def preprocess(model):
                # Encode the world/caption model as a NULL-padded program of
                # length 27.
                if is_clevr:
                    program_prefix = vr.programs.list_to_prefix(
                        model["program"])
                else:
                    program_prefix = clevr_util.parse_program(mode=0,
                                                              model=model)
                program_str = vr.programs.list_to_str(program_prefix)
                program_tokens = tokenize(program_str)
                program_encoded = encode(program_tokens, program_token_to_idx)
                program_encoded += [
                    program_token_to_idx["<NULL>"]
                    for _ in range(27 - len(program_encoded))
                ]
                return np.asarray(program_encoded, dtype=np.int64)

            if is_clevr:
                preprocessing = dict(question_model=preprocess)
            else:
                preprocessing = dict(caption_model=preprocess)

        elif args.sw_program in (1, 2):

            def preprocess(caption_pn):
                # Shift symbols by 2 (in place), mark the first 0 as <END> (2)
                # and prepend <START> (1).
                caption_pn += (caption_pn > 0) * 2
                for n, symbol in enumerate(caption_pn):
                    if symbol == 0:
                        caption_pn[n] = 2
                        break
                caption_pn = np.concatenate(([1], caption_pn))
                return caption_pn

            if args.sw_program == 1:
                preprocessing = dict(caption_pn=preprocess)
            else:
                preprocessing = dict(caption_rpn=preprocess)

        else:
            preprocessing = None

        dataset = torch_util.ShapeWorldDataset(
            dataset=dataset,
            mode=(None if args.sw_mode == "none" else args.sw_mode),
            include_model=include_model,
            epoch=(args.num_samples is None),
            preprocessing=preprocessing,
        )

        loader = ShapeWorldDataLoader(dataset=dataset,
                                      batch_size=args.batch_size)

    model = None
    if args.model_type in ("CNN", "LSTM", "CNN+LSTM", "CNN+LSTM+SA"):
        assert args.baseline_model is not None
        print("Loading baseline model from", args.baseline_model)
        model, _ = utils.load_baseline(args.baseline_model)
        if args.vocab_json is not None:
            new_vocab = utils.load_vocab(args.vocab_json)
            model.rnn.expand_vocab(new_vocab["question_token_to_idx"])
    elif args.program_generator is not None and args.execution_engine is not None:
        pg, _ = utils.load_program_generator(args.program_generator,
                                             args.model_type)
        ee, _ = utils.load_execution_engine(args.execution_engine,
                                            verbose=False,
                                            model_type=args.model_type)
        if args.vocab_json is not None:
            new_vocab = utils.load_vocab(args.vocab_json)
            pg.expand_encoder_vocab(new_vocab["question_token_to_idx"])
        model = (pg, ee)
    elif args.model_type == "FiLM":
        assert args.baseline_model is not None
        pg, _ = utils.load_program_generator(args.baseline_model,
                                             args.model_type)
        ee, _ = utils.load_execution_engine(args.baseline_model,
                                            verbose=False,
                                            model_type=args.model_type)
        if args.vocab_json is not None:
            new_vocab = utils.load_vocab(args.vocab_json)
            pg.expand_encoder_vocab(new_vocab["question_token_to_idx"])
        model = (pg, ee)
    else:
        print(
            "Must give either --baseline_model or --program_generator and --execution_engine"
        )
        return

    if torch.cuda.is_available():
        dtype = torch.cuda.FloatTensor
    else:
        dtype = torch.FloatTensor
    if args.question is not None and args.image is not None:
        run_single_example(args, model, dtype, args.question)
    # Interactive mode
    elif (args.image is not None and args.input_question_h5 is None
          and args.input_features_h5 is None):
        feats_var = extract_image_features(args, dtype)
        print(colored("Ask me something!", "cyan"))
        while True:
            # Get user question
            question_raw = input(">>> ")
            run_single_example(args, model, dtype, question_raw, feats_var)
    elif args.sw_name is not None or args.sw_config is not None:
        predictions, visualization = run_batch(args, model, dtype, loader)
        if args.sw_pred_dir is not None:
            assert args.sw_pred_name is not None
            pred_dir = os.path.join(
                args.sw_pred_dir,
                dataset.dataset.type,
                dataset.dataset.name,
                dataset.dataset.variant,
            )
            if not os.path.isdir(pred_dir):
                os.makedirs(pred_dir)
            id2word = dataset.dataset.vocabulary(value_type="language")
            # One line per instance: "<correct> <agreement> <caption words>".
            with open(
                    os.path.join(
                        pred_dir,
                        args.sw_pred_name + "-" + args.sw_mode + ".txt"),
                    "w",
            ) as filehandle:
                filehandle.write("".join(
                    "{} {} {}\n".format(correct, agreement, " ".join(
                        id2word[c] for c in caption))
                    for correct, agreement, caption in zip(
                        predictions["correct"],
                        predictions["agreement"],
                        predictions["caption"],
                    )))
            print("Predictions saved")
        if args.sw_vis_dir is not None:
            assert args.sw_vis_name is not None
            from io import BytesIO
            from shapeworld.world import World

            vis_dir = os.path.join(
                args.sw_vis_dir,
                dataset.dataset.type,
                dataset.dataset.name,
                dataset.dataset.variant,
            )
            image_dir = os.path.join(vis_dir, args.sw_mode, "images")
            if not os.path.isdir(image_dir):
                os.makedirs(image_dir)
            # Move axis 1 to the end (presumably channels-first -> channels-last).
            worlds = np.transpose(visualization["world"], (0, 2, 3, 1))
            for n in range(worlds.shape[0]):
                image = World.get_image(world_array=worlds[n])
                image_bytes = BytesIO()
                image.save(image_bytes, format="png")
                with open(os.path.join(image_dir, "world-{}.png".format(n)),
                          "wb") as filehandle:
                    filehandle.write(image_bytes.getvalue())
                image_bytes.close()
            with open(
                    os.path.join(
                        vis_dir,
                        args.sw_vis_name + "-" + args.sw_mode + ".html"),
                    "w",
            ) as filehandle:
                html = dataset.dataset.get_html(
                    generated=visualization,
                    image_format="png",
                    image_dir=(args.sw_mode + "/images/"),
                )
                filehandle.write(html)
            print("Visualization saved")
    else:
        vocab = load_vocab(args)
        loader_kwargs = {
            "question_h5": args.input_question_h5,
            "feature_h5": args.input_features_h5,
            "vocab": vocab,
            "batch_size": args.batch_size,
        }
        if args.family_split_file is not None:
            with open(args.family_split_file, "r") as f:
                loader_kwargs["question_families"] = json.load(f)
        with ClevrDataLoader(**loader_kwargs) as loader:
            run_batch(args, model, dtype, loader)
Ejemplo n.º 7
0
def main(_):
    """Evaluate the latest captioning checkpoint on a ShapeWorld dataset.

    Restores the latest checkpoint found under ``<log_dir>/<exp_tag>/train``.
    Runs ``num_imgs`` decoding steps (one ``sess.run`` each) on the
    ``FLAGS.data_partition`` split and prints every reference/inferred caption
    pair. Reports caption accuracy and perplexity on stdout and as TensorBoard
    summaries written to ``<log_dir>/<exp_tag>/test``.

    Accuracy is computed over well-formed caption pairs only. A pair scores 1
    when every non-filtered inferred token also appears in the reference
    caption (order-insensitive), and 0 otherwise. Pairs where either squeezed
    caption is 0-dimensional are counted separately as misses.

    Args:
        _: Unused (presumably the argv list passed by ``tf.app.run`` --
            TODO confirm).

    Raises:
        AssertionError: if a required flag is missing, no checkpoint exists in
            the train directory, or ``FLAGS.decode_type`` is not one of
            ``'greedy'``, ``'sample'`` or ``'beam'``.
    """

    def _mean_std(values):
        """Return ``(mean, std)`` of ``values`` as floats; ``(nan, nan)`` if empty."""
        if not values:
            return float("nan"), float("nan")
        return float(np.mean(values)), float(np.std(values))

    # FILESYSTEM SETUP ------------------------------------------------------------
    assert FLAGS.data_dir, "Must specify data location!"
    assert FLAGS.log_dir, "Must specify experiment to log to!"
    assert FLAGS.exp_tag, "Must specify experiment tag subfolder to log_dir %s" % FLAGS.log_dir
    assert FLAGS.parse_type, "Must specify a caption parse type!"

    # Summaries are written to <save_root>/test; checkpoints are read from <save_root>/train
    save_root = FLAGS.log_dir + os.sep + FLAGS.exp_tag
    test_path = save_root + os.sep + "test"
    if not tf.gfile.IsDirectory(test_path):
        tf.gfile.MakeDirs(test_path)
    train_path = save_root + os.sep + "train"

    model_ckpt = tf.train.latest_checkpoint(train_path)
    # Validate before logging so a missing checkpoint is not reported as "Loading checkpoint None"
    assert model_ckpt, "Checkpoints could not be loaded, check that train_path %s exists" % train_path
    tf.logging.info("Loading checkpoint %s", model_ckpt)

    # Sanity check graph reset
    tf.reset_default_graph()
    tf.logging.info("Clean graph reset...")

    dataset = Dataset.create(dtype=FLAGS.dtype,
                             name=FLAGS.name,
                             config=FLAGS.data_dir)
    # Evaluation-time dataset settings (fixed pixel noise; random_sampling disabled)
    dataset.pixel_noise_stddev = 0.1
    dataset.random_sampling = False

    # Get parsing and parameter feats
    params = Config(mode="test", sw_specification=dataset.specification())

    # Parse decoding arg from CLI
    params.decode_type = FLAGS.decode_type
    assert params.decode_type in ['greedy', 'sample', 'beam']

    # MODEL SETUP ------------------------------------------------------------
    g = tf.Graph()
    with g.as_default():
        parser = SimpleBatchParser(src_vocab=dataset.vocabularies['language'],
                                   batch_type=FLAGS.parse_type)
        # rev_vocab is indexed by token id to obtain printable tokens
        _, rev_vocab = parser.get_vocab()
        params.vocab_size = len(parser.tgt_vocab)

        batch = tf_util.batch_records(dataset,
                                      mode=FLAGS.data_partition,
                                      batch_size=params.batch_size)
        model = CaptioningModel(config=params, batch_parser=parser)
        model.build_model(batch)

        restore_model = tf.train.Saver()

        tf.logging.info("Network built...")

    # TESTING SETUP ------------------------------------------------------------

    # A non-positive --num_imgs falls back to instances_per_shard * num_shards
    # (presumably the full split size -- TODO confirm)
    if FLAGS.num_imgs < 1:
        num_imgs = params.instances_per_shard * params.num_shards
    else:
        num_imgs = FLAGS.num_imgs
    tf.logging.info("Running test for %d images", num_imgs)

    test_writer = tf.summary.FileWriter(logdir=test_path, graph=g)

    with tf.Session(graph=g,
                    config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        # Launch data loading queues
        coordinator = tf.train.Coordinator()
        queue_threads = tf.train.start_queue_runners(sess=sess,
                                                     coord=coordinator)
        try:
            # Model restoration
            restore_model.restore(sess, model_ckpt)
            tf.logging.info("Model restored!")

            # Trained model does not need initialisation. Init the vocab conversion tables
            sess.run([tf.tables_initializer()])

            # Freeze graph
            sess.graph.finalize()

            global_step = tf.train.global_step(sess, model.global_step)
            tf.logging.info("Successfully loaded %s at global step = %d.",
                            os.path.basename(model_ckpt), global_step)

            start_test_time = time.time()
            scores = []  # 1 per correct caption, 0 per well-formed but wrong caption
            incorrects = []  # (ref_tokens, inf_tokens) for well-formed but wrong captions
            misses = []  # step indices whose captions were malformed
            perplexities = []
            token_filter = parser.token_filter

            for b_idx in range(num_imgs):
                reference_caps, inf_decoder_outputs, batch_perplexity = sess.run(
                    fetches=[
                        model.reference_captions, model.inf_decoder_output,
                        model.batch_perplexity
                    ],
                    feed_dict={model.phase: 0})
                perplexities.append(batch_perplexity)

                # NOTE(review): squeeze() followed by token-wise iteration assumes
                # one caption per sess.run (batch_size == 1) -- confirm.
                ref_cap = reference_caps.squeeze()
                inf_cap = inf_decoder_outputs.sample_id.squeeze()

                # A 0-d array cannot be iterated, so BOTH captions must be checked
                if ref_cap.ndim == 0 or inf_cap.ndim == 0:
                    print("Skipping %d as caption is malformed (ref_cap=%s, inf_cap=%s)" %
                          (b_idx, ref_cap, inf_cap))
                    misses.append(b_idx)
                    continue

                print("%d REF -> %s | INF -> %s" % (
                    b_idx,
                    " ".join(rev_vocab[r] for r in ref_cap),
                    " ".join(rev_vocab[r] for r in inf_cap)))

                # Strip tokens in parser.token_filter (e.g. <S>, </S>) and compare
                # as sets for order insensitivity
                ref_tokens = {int(tok) for tok in ref_cap if int(tok) not in token_filter}
                inf_tokens = {int(tok) for tok in inf_cap if int(tok) not in token_filter}

                # NOTE(review): subset test, so an empty inferred caption counts as
                # correct -- confirm this is intended.
                if inf_tokens <= ref_tokens:
                    scores.append(1)
                else:
                    scores.append(0)
                    incorrects.append((ref_tokens, inf_tokens))

            # Overall scores for checkpoint
            avg_acc, std_acc = _mean_std(scores)
            print("Accuracy: %s -> %.5f ± %.5f | Misses: %d " %
                  (FLAGS.parse_type, avg_acc, std_acc, len(misses)))

            avg_perplexity, std_perplexity = _mean_std(perplexities)
            print("------------")
            print("PERPLEXITY -> %.5f +- %.5f" % (avg_perplexity, std_perplexity))

            new_summ = tf.Summary()
            for metric, value in (("avg_acc", avg_acc),
                                  ("std_acc", std_acc),
                                  ("perplexity_avg", avg_perplexity),
                                  ("perplexity_std", std_perplexity)):
                new_summ.value.add(tag="%s/%s_%s" %
                                   (FLAGS.data_partition, metric, FLAGS.name),
                                   simple_value=value)

            test_writer.add_summary(new_summ,
                                    tf.train.global_step(sess, model.global_step))
            test_writer.flush()
        finally:
            # Always stop the input queue threads and release the event file,
            # even when restoration or evaluation fails
            coordinator.request_stop()
            coordinator.join(threads=queue_threads)
            test_writer.close()

        end_time = time.time() - start_test_time
        tf.logging.info('Testing complete in %.2f-secs/%.2f-mins/%.2f-hours',
                        end_time, end_time / 60, end_time / (60 * 60))
Ejemplo n.º 8
0
    # TFRecords utility
    if args.tf_records:
        from shapeworld import tf_util

    # tensorflow verbosity
    if args.verbosity >= 2:
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'
    else:
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

    # dataset
    exclude_values = ('world', ) if args.features else ('world_features', )
    dataset = Dataset.create(dtype=args.type,
                             name=args.name,
                             variant=args.variant,
                             language=args.language,
                             config=args.config,
                             exclude_values=exclude_values,
                             **args.config_values)

    # information about dataset and model
    if args.verbosity >= 1:
        sys.stdout.write('{time} train {model} on {dataset}\n'.format(
            time=datetime.now().strftime('%H:%M:%S'),
            model=args.model,
            dataset=dataset))
        if args.config is None:
            if args.config_values:
                sys.stdout.write('         config: {config}\n'.format(
                    config=args.config_values))
        else: