예제 #1
0
 def serialize_value(value,
                     value_name,
                     value_type,
                     write_file,
                     id2word=None):
     """Serialize ``value`` according to ``value_type`` and hand it to ``write_file``.

     Args:
         value: Iterable of entries to serialize (or, for ``'model'``, any
             object accepted by ``json.dumps``).
         value_name: Base name of the output file(s), without extension.
         value_type: One of ``'int'``, ``'float'``, ``'vector'``, ``'text'``,
             ``'world'`` or ``'model'``. Any other type is silently ignored.
         write_file: Callable invoked as ``write_file(filename, content)``;
             world images are passed with the extra keyword ``binary=True``.
         id2word: Lookup from word id to word string; must be truthy for
             ``'text'`` values.

     Output layout:
         * ``'int'`` / ``'float'`` / ``'vector'`` / ``'text'`` go to
           ``<value_name>.txt`` with one entry per line and a trailing
           newline. Vector components are comma-separated, text words are
           space-separated, and falsy word ids are dropped.
         * ``'world'`` writes one ``<value_name>-<index>.bmp`` per entry.
         * ``'model'`` writes ``<value_name>.json``.
     """
     if value_type == 'world':
         for index, world in enumerate(value):
             buffer = BytesIO()
             World.get_image(world=world).save(buffer, format='bmp')
             payload = buffer.getvalue()
             buffer.close()
             write_file('{}-{}.bmp'.format(value_name, index),
                        payload,
                        binary=True)
         return

     if value_type == 'model':
         write_file(value_name + '.json', json.dumps(value))
         return

     # The remaining types are all line-oriented text files; build the lines
     # first and share the final join/write.
     if value_type == 'int':
         lines = [str(int(item)) for item in value]
     elif value_type == 'float':
         lines = [str(float(item)) for item in value]
     elif value_type == 'vector':
         lines = [','.join(map(str, row)) for row in value]
     elif value_type == 'text':
         assert id2word
         lines = [
             ' '.join([id2word[token] for token in sentence if token])
             for sentence in value
         ]
     else:
         # Unknown value types produce no output.
         return

     write_file(value_name + '.txt', '\n'.join(lines) + '\n')
예제 #2
0
                assert False

            elif args.clevr_format:
                from shapeworld.world import World
                from shapeworld.datasets.clevr_util import parse_program
                assert args.type == 'agreement'
                worlds = generated['world']
                captions = generated['caption']
                captions_length = generated['caption_length']
                captions_model = generated.get('caption_model')
                agreements = generated['agreement']
                for n in range(len(worlds)):
                    index = (shard - 1) * args.instances + n
                    filename = 'world_{}.png'.format(index)
                    image_bytes = BytesIO()
                    World.get_image(world_array=worlds[n]).save(image_bytes,
                                                                format='png')
                    with open(os.path.join(directory, filename),
                              'wb') as filehandle:
                        filehandle.write(image_bytes.getvalue())
                    image_bytes.close()
                    id2word = dataset.vocabulary(value_type='language')
                    if 'alternatives' in generated:
                        captions_iter = zip(captions[n], captions_length[n],
                                            captions_model[n], agreements[n])
                    else:
                        captions_iter = zip(
                            (captions[n], ), (captions_length[n], ),
                            (captions_model[n], ), (agreements[n], ))
                    for caption, caption_length, caption_model, agreement in captions_iter:
                        if agreement == 1.0:
                            answer = 'true'
예제 #3
0
def main(args):
    """Load a trained model and run it in the mode selected by ``args``.

    After the model is loaded, exactly one of these modes runs (checked in
    this order):

    1. Single example: ``args.question`` and ``args.image`` are both given.
    2. Interactive: ``args.image`` is given without question/feature HDF5
       inputs; questions are read from stdin in an endless loop.
    3. ShapeWorld: ``args.sw_name`` or ``args.sw_config`` is given; the model
       is evaluated with ``run_batch`` on a ShapeWorld data loader, and
       predictions / an HTML visualization are optionally written to
       ``args.sw_pred_dir`` / ``args.sw_vis_dir``.
    4. Otherwise: ``run_batch`` over a ``ClevrDataLoader`` built from
       ``args.input_question_h5`` and ``args.input_features_h5``.

    Args:
        args: Parsed command-line namespace; only the attributes read below
            are required.

    Returns:
        None. If no model can be selected from the arguments, a usage message
        is printed and the function returns early.

    Raises:
        AssertionError: If mutually dependent arguments are inconsistent
            (e.g. ShapeWorld options combined with ``--image``/``--question``).
    """
    usage_message = (
        "Must give either --baseline_model or --program_generator and --execution_engine"
    )

    # Drop into the debugger right away when debugging is requested.
    if args.debug_every <= 1:
        pdb.set_trace()

    if args.sw_name is not None or args.sw_config is not None:
        assert args.image is None and args.question is None

        from shapeworld import Dataset, torch_util
        from shapeworld.datasets import clevr_util

        class ShapeWorldDataLoader(torch_util.ShapeWorldDataLoader):
            """Adapts ShapeWorld batches to the 6-tuple yielded below."""

            def __iter__(self):
                for batch in super(ShapeWorldDataLoader, self).__iter__():
                    # Captions and questions share the "question" slot.
                    if "caption" in batch:
                        question = batch["caption"].long()
                    else:
                        question = batch["question"].long()
                    if args.sw_features == 1:
                        image = batch["world_features"]
                    else:
                        image = batch["world"]
                    # The same tensor is yielded as both image and features.
                    feats = image
                    if "agreement" in batch:
                        answer = batch["agreement"].long()
                    else:
                        answer = batch["answer"].long()
                    if "caption_model" in batch:
                        assert args.sw_name.startswith(
                            "clevr") or args.sw_program == 3
                        program_seq = batch["caption_model"]
                    elif "question_model" in batch:
                        program_seq = batch["question_model"]
                    elif "caption" in batch:
                        if args.sw_program == 1:
                            program_seq = batch["caption_pn"].long()
                        elif args.sw_program == 2:
                            program_seq = batch["caption_rpn"].long()
                        else:
                            program_seq = [None]
                    else:
                        program_seq = [None]
                    # No program JSON is available here; yield an empty dict.
                    program_json = dict()
                    yield question, image, feats, answer, program_seq, program_json

        dataset = Dataset.create(
            dtype=args.sw_type,
            name=args.sw_name,
            variant=args.sw_variant,
            language=args.sw_language,
            config=args.sw_config,
        )
        print("ShapeWorld dataset: {} (variant: {})".format(
            dataset, args.sw_variant))
        print("Config: " + str(args.sw_config))

        # The vocabulary file sits next to whichever model checkpoint was given.
        if args.program_generator is not None:
            vocab_path = args.program_generator + ".vocab"
        elif args.execution_engine is not None:
            vocab_path = args.execution_engine + ".vocab"
        elif args.baseline_model is not None:
            vocab_path = args.baseline_model + ".vocab"
        else:
            # Previously this fell through and crashed with an unbound
            # ``vocab``; no model can be loaded without one of these paths,
            # so report it the same way the model selection below does.
            print(usage_message)
            return
        with open(vocab_path, "r") as filehandle:
            vocab = json.load(filehandle)
        program_token_to_idx = vocab["program_token_to_idx"]

        # NOTE(review): ``args.sw_name`` may be None when only ``--sw_config``
        # is given, in which case ``startswith`` raises -- confirm the
        # intended behavior for that combination.
        include_model = args.model_type in ("PG", "EE", "PG+EE") and (
            args.sw_name.startswith("clevr") or args.sw_program == 3)
        if include_model:

            def preprocess(model):
                """Encode a program model as a fixed-length int64 id array."""
                if args.sw_name.startswith("clevr"):
                    program_prefix = vr.programs.list_to_prefix(
                        model["program"])
                else:
                    program_prefix = clevr_util.parse_program(mode=0,
                                                              model=model)
                program_str = vr.programs.list_to_str(program_prefix)
                program_tokens = tokenize(program_str)
                program_encoded = encode(program_tokens, program_token_to_idx)
                # Right-pad with <NULL> up to 27 tokens; longer programs are
                # left unpadded.
                program_encoded += [
                    program_token_to_idx["<NULL>"]
                    for _ in range(27 - len(program_encoded))
                ]
                return np.asarray(program_encoded, dtype=np.int64)

            if args.sw_name.startswith("clevr"):
                preprocessing = dict(question_model=preprocess)
            else:
                preprocessing = dict(caption_model=preprocess)

        elif args.sw_program in (1, 2):

            def preprocess(caption_pn):
                """Shift symbol ids and add boundary symbols (in place).

                Non-zero symbols are shifted up by 2, the first 0 is replaced
                by 2, and a 1 is prepended -- presumably reserving 0/1/2 for
                padding/start/end; TODO confirm against the vocabulary.
                """
                caption_pn += (caption_pn > 0) * 2
                for n, symbol in enumerate(caption_pn):
                    if symbol == 0:
                        caption_pn[n] = 2
                        break
                caption_pn = np.concatenate(([1], caption_pn))
                return caption_pn

            if args.sw_program == 1:
                preprocessing = dict(caption_pn=preprocess)
            else:
                preprocessing = dict(caption_rpn=preprocess)

        else:
            preprocessing = None

        dataset = torch_util.ShapeWorldDataset(
            dataset=dataset,
            mode=(None if args.sw_mode == "none" else args.sw_mode),
            include_model=include_model,
            epoch=(args.num_samples is None),
            preprocessing=preprocessing,
        )

        loader = ShapeWorldDataLoader(dataset=dataset,
                                      batch_size=args.batch_size)

    # ---- Model selection -------------------------------------------------
    model = None
    if args.model_type in ("CNN", "LSTM", "CNN+LSTM", "CNN+LSTM+SA"):
        assert args.baseline_model is not None
        print("Loading baseline model from", args.baseline_model)
        model, _ = utils.load_baseline(args.baseline_model)
        if args.vocab_json is not None:
            new_vocab = utils.load_vocab(args.vocab_json)
            model.rnn.expand_vocab(new_vocab["question_token_to_idx"])
    elif args.program_generator is not None and args.execution_engine is not None:
        pg, _ = utils.load_program_generator(args.program_generator,
                                             args.model_type)
        ee, _ = utils.load_execution_engine(args.execution_engine,
                                            verbose=False,
                                            model_type=args.model_type)
        if args.vocab_json is not None:
            new_vocab = utils.load_vocab(args.vocab_json)
            pg.expand_encoder_vocab(new_vocab["question_token_to_idx"])
        model = (pg, ee)
    elif args.model_type == "FiLM":
        # Both halves of the FiLM model are loaded from the same checkpoint.
        assert args.baseline_model is not None
        pg, _ = utils.load_program_generator(args.baseline_model,
                                             args.model_type)
        ee, _ = utils.load_execution_engine(args.baseline_model,
                                            verbose=False,
                                            model_type=args.model_type)
        if args.vocab_json is not None:
            new_vocab = utils.load_vocab(args.vocab_json)
            pg.expand_encoder_vocab(new_vocab["question_token_to_idx"])
        model = (pg, ee)
    else:
        print(usage_message)
        return

    if torch.cuda.is_available():
        dtype = torch.cuda.FloatTensor
    else:
        dtype = torch.FloatTensor

    # ---- Run mode --------------------------------------------------------
    if args.question is not None and args.image is not None:
        run_single_example(args, model, dtype, args.question)
    # Interactive mode
    elif (args.image is not None and args.input_question_h5 is None
          and args.input_features_h5 is None):
        # Features are extracted once and reused for every question.
        feats_var = extract_image_features(args, dtype)
        print(colored("Ask me something!", "cyan"))
        while True:
            # Get user question
            question_raw = input(">>> ")
            run_single_example(args, model, dtype, question_raw, feats_var)
    elif args.sw_name is not None or args.sw_config is not None:
        predictions, visualization = run_batch(args, model, dtype, loader)
        if args.sw_pred_dir is not None:
            assert args.sw_pred_name is not None
            pred_dir = os.path.join(
                args.sw_pred_dir,
                dataset.dataset.type,
                dataset.dataset.name,
                dataset.dataset.variant,
            )
            # exist_ok avoids the check-then-create race of isdir + makedirs.
            os.makedirs(pred_dir, exist_ok=True)
            id2word = dataset.dataset.vocabulary(value_type="language")
            pred_path = os.path.join(
                pred_dir, args.sw_pred_name + "-" + args.sw_mode + ".txt")
            with open(pred_path, "w") as filehandle:
                # One line per instance: "<correct> <agreement> <caption words>".
                filehandle.write("".join(
                    "{} {} {}\n".format(correct, agreement, " ".join(
                        id2word[c] for c in caption))
                    for correct, agreement, caption in zip(
                        predictions["correct"],
                        predictions["agreement"],
                        predictions["caption"],
                    )))
            print("Predictions saved")
        if args.sw_vis_dir is not None:
            assert args.sw_vis_name is not None
            from io import BytesIO
            from shapeworld.world import World

            vis_dir = os.path.join(
                args.sw_vis_dir,
                dataset.dataset.type,
                dataset.dataset.name,
                dataset.dataset.variant,
            )
            image_dir = os.path.join(vis_dir, args.sw_mode, "images")
            os.makedirs(image_dir, exist_ok=True)
            # Move axis 1 to the end before handing the arrays to World
            # (presumably channels-first -> channels-last; TODO confirm).
            worlds = np.transpose(visualization["world"], (0, 2, 3, 1))
            for n in range(worlds.shape[0]):
                image = World.get_image(world_array=worlds[n])
                # The context manager releases the buffer even if saving fails.
                with BytesIO() as image_bytes:
                    image.save(image_bytes, format="png")
                    with open(
                            os.path.join(image_dir,
                                         "world-{}.png".format(n)),
                            "wb") as filehandle:
                        filehandle.write(image_bytes.getvalue())
            html_path = os.path.join(
                vis_dir, args.sw_vis_name + "-" + args.sw_mode + ".html")
            with open(html_path, "w") as filehandle:
                html = dataset.dataset.get_html(
                    generated=visualization,
                    image_format="png",
                    image_dir=(args.sw_mode + "/images/"),
                )
                filehandle.write(html)
            print("Visualization saved")
    else:
        vocab = load_vocab(args)
        loader_kwargs = {
            "question_h5": args.input_question_h5,
            "feature_h5": args.input_features_h5,
            "vocab": vocab,
            "batch_size": args.batch_size,
        }
        if args.family_split_file is not None:
            with open(args.family_split_file, "r") as f:
                loader_kwargs["question_families"] = json.load(f)
        with ClevrDataLoader(**loader_kwargs) as loader:
            run_batch(args, model, dtype, loader)
예제 #4
0
 def serialize_value(value,
                     value_name,
                     value_type,
                     write_file,
                     concat_worlds=False,
                     id2word=None):
     """Serialize ``value`` according to ``value_type`` and hand it to ``write_file``.

     ``value_type`` is first passed through ``alternatives_type`` (defined
     elsewhere), which yields the base type and a flag ``alts``. When ``alts``
     is truthy, every entry of ``value`` is itself a sequence of alternatives.

     Args:
         value: Iterable of entries to serialize (or, for ``'model'``, any
             object accepted by ``json.dumps``).
         value_name: Base name of the output file(s), without extension.
         value_type: Type descriptor; base types handled are ``'int'``,
             ``'float'``, ``'vector(int)'``, ``'vector(float)'``, ``'text'``,
             ``'world'`` and ``'model'``. Other base types are ignored.
         write_file: Callable invoked as ``write_file(filename, content)``;
             images are passed with the extra keyword ``binary=True``.
         concat_worlds: For ``'world'`` values, tile all worlds into a single
             near-square grid image ``<value_name>.bmp`` instead of writing
             one ``<value_name>-<index>.bmp`` per world.
         id2word: Lookup from word id to word string; must be truthy for
             ``'text'`` values.

     Text-file layout (``<value_name>.txt``): one entry per line with a
     trailing newline. Alternatives of an int/float/vector entry are joined
     with ``';'``; vector components with ``','``. For text, words are
     space-separated, falsy word ids are dropped, alternatives go on
     consecutive lines and entries are separated by a blank line.

     Raises:
         AssertionError: If ``value_type`` is text and ``id2word`` is falsy.
     """

     def save_image(world_array, filename):
         # Render to an in-memory BMP; the context manager releases the
         # buffer even if rendering or writing fails.
         with BytesIO() as image_bytes:
             World.get_image(world_array=world_array).save(image_bytes,
                                                           format='bmp')
             write_file(filename, image_bytes.getvalue(), binary=True)

     value_type, alts = alternatives_type(value_type=value_type)
     if value_type == 'int':
         if alts:
             value = '\n'.join(';'.join(str(int(x)) for x in xs)
                               for xs in value) + '\n'
         else:
             value = '\n'.join(str(int(x)) for x in value) + '\n'
         write_file(value_name + '.txt', value)
     elif value_type == 'float':
         if alts:
             value = '\n'.join(';'.join(str(float(x)) for x in xs)
                               for xs in value) + '\n'
         else:
             value = '\n'.join(str(float(x)) for x in value) + '\n'
         write_file(value_name + '.txt', value)
     elif value_type in ('vector(int)', 'vector(float)'):
         if alts:
             value = '\n'.join(';'.join(','.join(str(x) for x in vector)
                                        for vector in vectors)
                               for vectors in value) + '\n'
         else:
             value = '\n'.join(','.join(str(x) for x in vector)
                               for vector in value) + '\n'
         write_file(value_name + '.txt', value)
     elif value_type == 'text':
         assert id2word
         if alts:
             value = '\n\n'.join('\n'.join(' '.join(id2word[word_id]
                                                    for word_id in text
                                                    if word_id)
                                           for text in texts)
                                 for texts in value) + '\n\n'
         else:
             value = '\n'.join(' '.join(id2word[word_id]
                                        for word_id in text if word_id)
                               for text in value) + '\n'
         write_file(value_name + '.txt', value)
     elif value_type == 'world':
         count = len(value)
         if concat_worlds:
             if count == 0:
                 # Nothing to tile. Previously this divided by a grid size of
                 # zero; now it writes no file, like the per-world branch.
                 return
             # Smallest square grid that fits all worlds.
             size = ceil(sqrt(count))
             rows = []
             for start in range(0, count, size):
                 row = [
                     value[index]
                     for index in range(start, min(start + size, count))
                 ]
                 # Pad an incomplete last row with all-zero worlds so every
                 # row has the same width.
                 row += [
                     np.zeros_like(a=value[0])
                     for _ in range(size - len(row))
                 ]
                 rows.append(np.concatenate(row, axis=1))
             save_image(np.concatenate(rows, axis=0), value_name + '.bmp')
         else:
             for n in range(count):
                 save_image(value[n], '{}-{}.bmp'.format(value_name, n))
     elif value_type == 'model':
         value = json.dumps(value)
         write_file(value_name + '.json', value)