Пример #1
0
        default=(),
        help=
        'Additional dataset configuration values passed as command line arguments'
    )

    args = parser.parse_args()
    args.config_values = util.parse_config(values=args.config_values)

    # TFRecords utility
    if args.tf_records:
        from shapeworld import tf_util

    # does not include variant, as loading data for generation is not expected
    dataset = Dataset.create(dtype=args.type,
                             name=args.name,
                             language=args.language,
                             config=args.config,
                             **args.config_values)
    sys.stdout.write('{time} {dataset}\n'.format(
        time=datetime.now().strftime('%H:%M:%S'), dataset=dataset))
    if args.config is None:
        if args.config_values:
            sys.stdout.write('         config: {config}\n'.format(
                config=args.config_values))
    else:
        sys.stdout.write(
            '         config: {config}\n'.format(config=args.config))
        if args.config_values:
            sys.stdout.write('                 {config}\n'.format(
                config=args.config_values))
    sys.stdout.flush()
Пример #2
0
def main(_):
    """Evaluate the latest checkpoint of a captioning model on ShapeWorld data.

    Restores the newest checkpoint found in ``<log_dir>/<exp_tag>/train``,
    runs ``num_imgs`` inference steps on ``FLAGS.data_partition``, prints the
    reference and inferred caption of each step, and reports caption accuracy
    and perplexity both on stdout and as scalar summaries written to
    ``<log_dir>/<exp_tag>/test``.

    A step counts as correct when every inferred token (after dropping the
    tokens in ``parser.token_filter``) also occurs in the reference caption,
    i.e. the comparison is order-insensitive. Steps where either caption
    squeezes down to a scalar are counted as misses and are excluded from the
    accuracy.

    Args:
        _: Unused positional argument (presumably the argv list handed over by
            ``tf.app.run``; confirm against the entry point).
    """
    # FILESYSTEM SETUP ------------------------------------------------------------
    assert FLAGS.data_dir, "Must specify data location!"
    assert FLAGS.log_dir, "Must specify experiment to log to!"
    assert FLAGS.exp_tag, "Must specify experiment tag subfolder to log_dir %s" % FLAGS.log_dir
    assert FLAGS.parse_type

    # Folder setup for saving summaries and loading checkpoints
    save_root = os.path.join(FLAGS.log_dir, FLAGS.exp_tag)
    test_path = os.path.join(save_root, "test")
    if not tf.gfile.IsDirectory(test_path):
        tf.gfile.MakeDirs(test_path)
    train_path = os.path.join(save_root, "train")

    # Validate the checkpoint before logging it, so that a missing checkpoint
    # is not reported as "Loading checkpoint None".
    model_ckpt = tf.train.latest_checkpoint(train_path)
    assert model_ckpt, "Checkpoints could not be loaded, check that train_path %s exists" % train_path
    tf.logging.info("Loading checkpoint %s", model_ckpt)

    # Sanity check graph reset
    tf.reset_default_graph()
    tf.logging.info("Clean graph reset...")

    dataset = Dataset.create(dtype=FLAGS.dtype,
                             name=FLAGS.name,
                             config=FLAGS.data_dir)
    dataset.pixel_noise_stddev = 0.1
    dataset.random_sampling = False

    # Get parsing and parameter feats
    params = Config(mode="test", sw_specification=dataset.specification())

    # Parse decoding arg from CLI
    params.decode_type = FLAGS.decode_type
    assert params.decode_type in ['greedy', 'sample', 'beam']

    # MODEL SETUP ------------------------------------------------------------
    g = tf.Graph()
    with g.as_default():
        parser = SimpleBatchParser(src_vocab=dataset.vocabularies['language'],
                                   batch_type=FLAGS.parse_type)
        # Only the id -> token mapping is needed to print captions.
        _, rev_vocab = parser.get_vocab()
        params.vocab_size = len(parser.tgt_vocab)

        batch = tf_util.batch_records(dataset,
                                      mode=FLAGS.data_partition,
                                      batch_size=params.batch_size)
        model = CaptioningModel(config=params, batch_parser=parser)
        model.build_model(batch)

        restore_model = tf.train.Saver()

        tf.logging.info("Network built...")

    # TESTING SETUP ------------------------------------------------------------
    if FLAGS.num_imgs < 1:
        num_imgs = params.instances_per_shard * params.num_shards
    else:
        num_imgs = FLAGS.num_imgs
    tf.logging.info("Running test for %d images", num_imgs)

    test_writer = tf.summary.FileWriter(logdir=test_path, graph=g)

    def _mean_std(values):
        # np.mean / np.std of an empty list warn and return nan; make that
        # outcome explicit and silent.
        if not values:
            return float("nan"), float("nan")
        return float(np.mean(values)), float(np.std(values))

    with tf.Session(graph=g,
                    config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        # Launch data loading queues
        coordinator = tf.train.Coordinator()
        queue_threads = tf.train.start_queue_runners(sess=sess,
                                                     coord=coordinator)
        try:
            # Model restoration
            restore_model.restore(sess, model_ckpt)
            tf.logging.info("Model restored!")

            # Trained model does not need initialisation. Init the vocab conversion tables
            sess.run([tf.tables_initializer()])

            # Freeze graph
            sess.graph.finalize()

            global_step = tf.train.global_step(sess, model.global_step)
            tf.logging.info("Successfully loaded %s at global step = %d.",
                            os.path.basename(model_ckpt), global_step)

            start_test_time = time.time()
            scores = []  # 1 for a correct caption, 0 for a well-formed but wrong one
            incorrects = []  # (reference, inferred) token sets of wrong captions
            misses = []  # For incorrectly formed captions
            perplexities = []

            for b_idx in range(num_imgs):
                reference_caps, inf_decoder_outputs, batch_perplexity = sess.run(
                    fetches=[
                        model.reference_captions, model.inf_decoder_output,
                        model.batch_perplexity
                    ],
                    feed_dict={model.phase: 0})

                ref_cap = reference_caps.squeeze()
                inf_cap = inf_decoder_outputs.sample_id.squeeze()
                perplexities.append(batch_perplexity)

                # Both captions must still be sequences after squeezing;
                # iterating a 0-d array would raise.
                if ref_cap.ndim == 0 or inf_cap.ndim == 0:
                    print("Skipping %d as inf_cap %s is malformed" %
                          (b_idx, inf_cap))
                    misses.append(1)
                    continue

                print("%d REF -> %s | INF -> %s" % (
                    b_idx,
                    " ".join(rev_vocab[r] for r in ref_cap),
                    " ".join(rev_vocab[r] for r in inf_cap)))

                # Strip <S>, </S> and any irrelevant tokens; compare as sets
                # for order insensitivity.
                ref_tokens = {tok for tok in ref_cap
                              if int(tok) not in parser.token_filter}
                inf_tokens = {tok for tok in inf_cap
                              if int(tok) not in parser.token_filter}

                if inf_tokens <= ref_tokens:
                    scores.append(1)
                else:
                    # Wrong captions must count as 0, otherwise the mean
                    # accuracy is always 1.0.
                    scores.append(0)
                    incorrects.append((ref_tokens, inf_tokens))

            # Overall scores for checkpoint
            avg_acc, std_acc = _mean_std(scores)
            print("Accuracy: %s -> %.5f ± %.5f | Misses: %d " %
                  (FLAGS.parse_type, avg_acc, std_acc, len(misses)))

            avg_perplexity, std_perplexity = _mean_std(perplexities)
            print("------------")
            print("PERPLEXITY -> %.5f +- %.5f" % (avg_perplexity, std_perplexity))

            new_summ = tf.Summary()
            for tag_fmt, value in (("%s/avg_acc_%s", avg_acc),
                                   ("%s/std_acc_%s", std_acc),
                                   ("%s/perplexity_avg_%s", avg_perplexity),
                                   ("%s/perplexity_std_%s", std_perplexity)):
                new_summ.value.add(tag=tag_fmt % (FLAGS.data_partition, FLAGS.name),
                                   simple_value=value)

            test_writer.add_summary(new_summ, global_step)
            test_writer.flush()
        finally:
            # Always stop the queue-runner threads, even if evaluation fails.
            coordinator.request_stop()
            coordinator.join(threads=queue_threads)

        end_time = time.time() - start_test_time
        tf.logging.info('Testing complete in %.2f-secs/%.2f-mins/%.2f-hours',
                        end_time, end_time / 60, end_time / (60 * 60))
Пример #3
0
def main(args):
    """Run a trained CLEVR-style model on a single example, interactively, or in batch.

    Depending on ``args`` this function does one of four things:

    * answers ``args.question`` about ``args.image``;
    * runs an interactive question loop on ``args.image``;
    * evaluates a ShapeWorld dataset (``--sw_name`` / ``--sw_config``), and
      optionally writes predictions and an HTML visualization;
    * evaluates the CLEVR HDF5 question/feature files.

    Args:
        args: Parsed command-line namespace. The attributes read here include
            the ``sw_*`` ShapeWorld options, the model paths
            (``baseline_model``, ``program_generator``, ``execution_engine``),
            ``model_type``, ``vocab_json``, ``batch_size`` and the CLEVR input
            paths.

    Returns:
        None. It returns early, after printing a usage message, when no usable
        model path is given.
    """
    if args.debug_every <= 1:
        pdb.set_trace()

    missing_model_msg = (
        "Must give either --baseline_model or --program_generator and --execution_engine"
    )

    if args.sw_name is not None or args.sw_config is not None:
        assert args.image is None and args.question is None

        from shapeworld import Dataset, torch_util
        from shapeworld.datasets import clevr_util

        dataset = Dataset.create(
            dtype=args.sw_type,
            name=args.sw_name,
            variant=args.sw_variant,
            language=args.sw_language,
            config=args.sw_config,
        )
        print("ShapeWorld dataset: {} (variant: {})".format(
            dataset, args.sw_variant))
        print("Config: " + str(args.sw_config))

        # args.sw_name is None when the dataset is specified via --sw_config
        # alone, and calling .startswith on it would raise AttributeError. In
        # that case, fall back to the created dataset's own name, which is the
        # same attribute used for the output directories below.
        sw_name = args.sw_name if args.sw_name is not None else dataset.name
        is_clevr = sw_name.startswith("clevr")

        class ShapeWorldDataLoader(torch_util.ShapeWorldDataLoader):
            """Yield ShapeWorld batches as (question, image, feats, answer, program_seq, program_json)."""

            def __iter__(self):
                for batch in super(ShapeWorldDataLoader, self).__iter__():
                    if "caption" in batch:
                        question = batch["caption"].long()
                    else:
                        question = batch["question"].long()
                    if args.sw_features == 1:
                        image = batch["world_features"]
                    else:
                        image = batch["world"]
                    feats = image
                    if "agreement" in batch:
                        answer = batch["agreement"].long()
                    else:
                        answer = batch["answer"].long()
                    if "caption_model" in batch:
                        assert is_clevr or args.sw_program == 3
                        program_seq = batch["caption_model"]
                    elif "question_model" in batch:
                        program_seq = batch["question_model"]
                    elif "caption" in batch and args.sw_program == 1:
                        program_seq = batch["caption_pn"].long()
                    elif "caption" in batch and args.sw_program == 2:
                        program_seq = batch["caption_rpn"].long()
                    else:
                        program_seq = [None]
                    program_json = dict()
                    yield question, image, feats, answer, program_seq, program_json

        # The program vocabulary lives next to whichever model file was given
        # (first match wins, in this priority order).
        vocab_prefix = next(
            (path for path in (args.program_generator, args.execution_engine,
                               args.baseline_model) if path is not None),
            None,
        )
        if vocab_prefix is None:
            # Previously this crashed with UnboundLocalError on `vocab`.
            print(missing_model_msg)
            return
        with open(vocab_prefix + ".vocab", "r") as filehandle:
            vocab = json.load(filehandle)
        program_token_to_idx = vocab["program_token_to_idx"]

        include_model = args.model_type in ("PG", "EE", "PG+EE") and (
            is_clevr or args.sw_program == 3)
        if include_model:
            program_length = 27  # programs are padded with <NULL> up to this length

            def preprocess(model):
                if is_clevr:
                    program_prefix = vr.programs.list_to_prefix(
                        model["program"])
                else:
                    program_prefix = clevr_util.parse_program(mode=0,
                                                              model=model)
                program_str = vr.programs.list_to_str(program_prefix)
                program_tokens = tokenize(program_str)
                program_encoded = encode(program_tokens, program_token_to_idx)
                program_encoded += [
                    program_token_to_idx["<NULL>"]
                    for _ in range(program_length - len(program_encoded))
                ]
                return np.asarray(program_encoded, dtype=np.int64)

            if is_clevr:
                preprocessing = dict(question_model=preprocess)
            else:
                preprocessing = dict(caption_model=preprocess)

        elif args.sw_program in (1, 2):

            def preprocess(caption_pn):
                # Shift non-zero symbols up by 2, mark the first padding
                # position with 2 and prepend 1 (presumably end/start markers;
                # confirm against the program vocabulary). A new array is
                # built so the caller's array is not mutated in place.
                caption_pn = caption_pn + (caption_pn > 0) * 2
                for n, symbol in enumerate(caption_pn):
                    if symbol == 0:
                        caption_pn[n] = 2
                        break
                return np.concatenate(([1], caption_pn))

            if args.sw_program == 1:
                preprocessing = dict(caption_pn=preprocess)
            else:
                preprocessing = dict(caption_rpn=preprocess)

        else:
            preprocessing = None

        dataset = torch_util.ShapeWorldDataset(
            dataset=dataset,
            mode=(None if args.sw_mode == "none" else args.sw_mode),
            include_model=include_model,
            epoch=(args.num_samples is None),
            preprocessing=preprocessing,
        )

        loader = ShapeWorldDataLoader(dataset=dataset,
                                      batch_size=args.batch_size)

    model = None
    if args.model_type in ("CNN", "LSTM", "CNN+LSTM", "CNN+LSTM+SA"):
        assert args.baseline_model is not None
        print("Loading baseline model from", args.baseline_model)
        model, _ = utils.load_baseline(args.baseline_model)
        if args.vocab_json is not None:
            new_vocab = utils.load_vocab(args.vocab_json)
            model.rnn.expand_vocab(new_vocab["question_token_to_idx"])
    elif args.program_generator is not None and args.execution_engine is not None:
        pg, _ = utils.load_program_generator(args.program_generator,
                                             args.model_type)
        ee, _ = utils.load_execution_engine(args.execution_engine,
                                            verbose=False,
                                            model_type=args.model_type)
        if args.vocab_json is not None:
            new_vocab = utils.load_vocab(args.vocab_json)
            pg.expand_encoder_vocab(new_vocab["question_token_to_idx"])
        model = (pg, ee)
    elif args.model_type == "FiLM":
        assert args.baseline_model is not None
        pg, _ = utils.load_program_generator(args.baseline_model,
                                             args.model_type)
        ee, _ = utils.load_execution_engine(args.baseline_model,
                                            verbose=False,
                                            model_type=args.model_type)
        if args.vocab_json is not None:
            new_vocab = utils.load_vocab(args.vocab_json)
            pg.expand_encoder_vocab(new_vocab["question_token_to_idx"])
        model = (pg, ee)
    else:
        print(missing_model_msg)
        return

    if torch.cuda.is_available():
        dtype = torch.cuda.FloatTensor
    else:
        dtype = torch.FloatTensor

    if args.question is not None and args.image is not None:
        run_single_example(args, model, dtype, args.question)
    # Interactive mode
    elif (args.image is not None and args.input_question_h5 is None
          and args.input_features_h5 is None):
        feats_var = extract_image_features(args, dtype)
        print(colored("Ask me something!", "cyan"))
        while True:
            # Get user question
            question_raw = input(">>> ")
            run_single_example(args, model, dtype, question_raw, feats_var)
    elif args.sw_name is not None or args.sw_config is not None:
        predictions, visualization = run_batch(args, model, dtype, loader)
        if args.sw_pred_dir is not None:
            assert args.sw_pred_name is not None
            pred_dir = os.path.join(
                args.sw_pred_dir,
                dataset.dataset.type,
                dataset.dataset.name,
                dataset.dataset.variant,
            )
            if not os.path.isdir(pred_dir):
                os.makedirs(pred_dir)
            id2word = dataset.dataset.vocabulary(value_type="language")
            pred_path = os.path.join(
                pred_dir, args.sw_pred_name + "-" + args.sw_mode + ".txt")
            with open(pred_path, "w") as filehandle:
                filehandle.write("".join(
                    "{} {} {}\n".format(correct, agreement,
                                        " ".join(id2word[c] for c in caption))
                    for correct, agreement, caption in zip(
                        predictions["correct"],
                        predictions["agreement"],
                        predictions["caption"],
                    )))
            print("Predictions saved")
        if args.sw_vis_dir is not None:
            assert args.sw_vis_name is not None
            from io import BytesIO
            from shapeworld.world import World

            vis_dir = os.path.join(
                args.sw_vis_dir,
                dataset.dataset.type,
                dataset.dataset.name,
                dataset.dataset.variant,
            )
            image_dir = os.path.join(vis_dir, args.sw_mode, "images")
            if not os.path.isdir(image_dir):
                os.makedirs(image_dir)
            # Move the channel axis last (N, C, H, W -> N, H, W, C) for rendering.
            worlds = np.transpose(visualization["world"], (0, 2, 3, 1))
            for n in range(worlds.shape[0]):
                image = World.get_image(world_array=worlds[n])
                image_bytes = BytesIO()
                try:
                    image.save(image_bytes, format="png")
                    with open(os.path.join(image_dir, "world-{}.png".format(n)),
                              "wb") as filehandle:
                        filehandle.write(image_bytes.getvalue())
                finally:
                    image_bytes.close()
            vis_path = os.path.join(
                vis_dir, args.sw_vis_name + "-" + args.sw_mode + ".html")
            with open(vis_path, "w") as filehandle:
                html = dataset.dataset.get_html(
                    generated=visualization,
                    image_format="png",
                    image_dir=(args.sw_mode + "/images/"),
                )
                filehandle.write(html)
            print("Visualization saved")
    else:
        vocab = load_vocab(args)
        loader_kwargs = {
            "question_h5": args.input_question_h5,
            "feature_h5": args.input_features_h5,
            "vocab": vocab,
            "batch_size": args.batch_size,
        }
        if args.family_split_file is not None:
            with open(args.family_split_file, "r") as f:
                loader_kwargs["question_families"] = json.load(f)
        with ClevrDataLoader(**loader_kwargs) as loader:
            run_batch(args, model, dtype, loader)
Пример #4
0
    # TFRecords utility
    if args.tf_records:
        from shapeworld import tf_util

    # tensorflow verbosity
    if args.verbosity >= 2:
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'
    else:
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

    # dataset
    exclude_values = ('world', ) if args.features else ('world_features', )
    dataset = Dataset.create(dtype=args.type,
                             name=args.name,
                             variant=args.variant,
                             language=args.language,
                             config=args.config,
                             exclude_values=exclude_values,
                             **args.config_values)

    # information about dataset and model
    if args.verbosity >= 1:
        sys.stdout.write('{time} train {model} on {dataset}\n'.format(
            time=datetime.now().strftime('%H:%M:%S'),
            model=args.model,
            dataset=dataset))
        if args.config is None:
            if args.config_values:
                sys.stdout.write('         config: {config}\n'.format(
                    config=args.config_values))
        else: