コード例 #1
0
    def generate_world(self):
        """Create a world populated with ``self.num_entities`` entities.

        Entities are sampled one at a time via ``self.sample_entity`` and
        offered to the world; a sample the world rejects is discarded.

        Returns:
            The populated world, or ``None`` if the attempt budget of
            ``num_entities * MAX_ATTEMPTS`` samples is exhausted before
            enough entities were accepted.
        """
        world = World(self.world_size, self.world_color)
        if self.num_entities == 0:
            return world

        placed = 0
        # -1 marks "nothing sampled yet"; None marks "previous sample was
        # rejected"; otherwise this holds the entity accepted last.
        previous = -1
        for _ in range(self.num_entities * self.__class__.MAX_ATTEMPTS):
            candidate = self.sample_entity(world=world, last_entity=previous)
            accepted = world.add_entity(
                candidate,
                collision_tolerance=self.collision_tolerance,
                collision_shade_difference=self.collision_shade_difference,
                boundary_tolerance=self.boundary_tolerance)
            if not accepted:
                previous = None
                continue
            previous = candidate
            placed += 1
            if placed == self.num_entities:
                break
        else:
            # Loop ran to completion without reaching the target count.
            return None

        if self.collision_tolerance:
            world.sort_entities()

        return world
コード例 #2
0
    def generate_train_world(self):
        """Create a training world of ``self.num_entities`` entities.

        Works like a plain placement loop, except that any sampled entity
        whose (shape, color, texture) name triple is listed in
        ``self.invalid_combinations`` is discarded without being offered to
        the world.

        Returns:
            The populated world, or ``None`` if the attempt budget of
            ``num_entities * MAX_ATTEMPTS`` samples is exhausted first.
        """
        world = World(self.world_size, self.world_color)
        if self.num_entities == 0:
            return world

        placed = 0
        # -1: nothing sampled yet; None: previous sample was discarded;
        # otherwise the entity accepted last.
        previous = -1
        for _ in range(self.num_entities * self.__class__.MAX_ATTEMPTS):
            candidate = self.sample_entity(world=world, last_entity=previous)
            attributes = (candidate.shape.name, candidate.color.name,
                          candidate.texture.name)
            # Short-circuit: a forbidden combination is never offered to the
            # world at all.
            accepted = (
                attributes not in self.invalid_combinations
                and world.add_entity(
                    candidate,
                    collision_tolerance=self.collision_tolerance,
                    collision_shade_difference=self.collision_shade_difference,
                    boundary_tolerance=self.boundary_tolerance))
            if not accepted:
                previous = None
                continue
            previous = candidate
            placed += 1
            if placed == self.num_entities:
                break
        else:
            # Attempt budget used up before enough entities were placed.
            return None

        if self.collision_tolerance:
            world.sort_entities()

        return world
コード例 #3
0
 def sample_entity(self,
                   world: World,
                   last_entity,
                   combinations=None,
                   location_range: Dict = None):
     """Sample one random entity at a random location of ``world``.

     Args:
         world: World that supplies the random location.
         last_entity: ``-1`` or any non-None value re-draws
             ``self.provoke_collision``; ``None`` keeps the current flag.
         combinations: If given, passed to ``Entity.random_instance`` in
             place of the generator's shapes/colors/textures.
         location_range: Optional extra keyword arguments forwarded to
             ``world.random_location``.

     Returns:
         The entity produced by ``Entity.random_instance``.
     """
     # Both conditions lead to the same re-draw of the collision flag.
     if last_entity == -1 or last_entity is not None:
         self.provoke_collision = random() < self.provoke_collision_rate

     location_kwargs = {} if location_range is None else location_range
     center = world.random_location(
         provoke_collision=self.provoke_collision, **location_kwargs)

     shared_kwargs = dict(
         center=center,
         rotation=self.rotation,
         size_range=self.size_range,
         distortion_range=self.distortion_range,
         shade_range=self.shade_range)
     if combinations is None:
         attribute_kwargs = dict(
             shapes=self.shapes,
             colors=self.colors,
             textures=self.textures)
     else:
         attribute_kwargs = dict(combinations=combinations)
     return Entity.random_instance(**shared_kwargs, **attribute_kwargs)
コード例 #4
0
 def serialize_value(value,
                     value_name,
                     value_type,
                     write_file,
                     id2word=None):
     """Serialize a batch of values through the ``write_file`` callback.

     Args:
         value: Iterable of items ('int', 'float', 'vector', 'text',
             'world') or a JSON-serializable object ('model').
         value_name: Base name of the file(s) to write.
         value_type: One of 'int', 'float', 'vector', 'text', 'world',
             'model'; any other type writes nothing.
         write_file: Callable ``write_file(name, content)``; it is also
             called with ``binary=True`` for world images.
         id2word: Index-to-word lookup, required for 'text'.

     Line-based types are written to ``<value_name>.txt`` with one item per
     line, worlds to ``<value_name>-<n>.bmp`` and models to
     ``<value_name>.json``.
     """
     if value_type == 'int':
         lines = (str(int(item)) for item in value)
     elif value_type == 'float':
         lines = (str(float(item)) for item in value)
     elif value_type == 'vector':
         lines = (','.join(str(component) for component in vector)
                  for vector in value)
     elif value_type == 'text':
         assert id2word
         # Falsy word ids are left out of the written text.
         lines = (' '.join(id2word[word_id] for word_id in text if word_id)
                  for text in value)
     elif value_type == 'world':
         for index, world in enumerate(value):
             picture = World.get_image(world=world)
             with BytesIO() as buffer:
                 picture.save(buffer, format='bmp')
                 write_file('{}-{}.bmp'.format(value_name, index),
                            buffer.getvalue(),
                            binary=True)
         return
     elif value_type == 'model':
         write_file(value_name + '.json', json.dumps(value))
         return
     else:
         return

     write_file(value_name + '.txt', '\n'.join(lines) + '\n')
コード例 #5
0
def images_iter(directory, mode):
    """Yield ``(identifier, (world1, world2, world3))`` for stored images.

    Walks ``<directory>/<mode>/images`` ('validation' is mapped to 'dev').
    The top level must contain no files and its sub-directories no further
    directories. Every file name must start with ``<mode>-``; the identifier
    is the name without that prefix and without its last six characters, and
    must end in '-0' to '-3'. Each image is cut into three column ranges
    (0-100, 150-250, 300-end) which are each turned into a world.
    """
    if mode == 'validation':
        mode = 'dev'
    images_root = os.path.join(directory, mode, 'images')
    prefix = mode + '-'
    for current, subdirs, filenames in os.walk(images_root):
        if current == images_root:
            # Images live one level down; the top level holds directories only.
            assert not filenames
            continue
        assert not subdirs
        for name in filenames:
            assert name.startswith(prefix)
            identifier = name[len(prefix): -6]
            assert identifier[-2:] in ('-0', '-1', '-2', '-3')
            path = os.path.join(current, name)
            with open(path, 'rb') as filehandle:
                pixels = np.array(object=Image.open(fp=filehandle))
                left = World.from_image(pixels[:, :100, :])
                middle = World.from_image(pixels[:, 150:250, :])
                right = World.from_image(pixels[:, 300:, :])
            yield identifier, (left, middle, right)
コード例 #6
0
ファイル: generic.py プロジェクト: codeaudit/ShapeWorld
 def generate_validation_world(self):
     """Create a validation world with a randomly chosen entity count.

     The target count is drawn from ``self.validation_entity_counts`` and the
     shape/color/texture pools are drawn with ``GenericGenerator.choose``.
     If ``self.validation_combinations`` is non-empty, the first entity is
     sampled from those combinations; the remaining entities are sampled
     from the chosen pools.

     Returns:
         The populated world, or ``None`` if an attempt budget is exhausted
         before the required entities could be placed.
     """
     world = World(self.world_size, self.world_color, self.noise_range)
     n_entity = choice(self.validation_entity_counts)
     if n_entity == 0:
         return world
     shapes = GenericGenerator.choose(self.validation_shapes,
                                      self.shapes_range)
     colors = GenericGenerator.choose(self.validation_colors,
                                      self.colors_range)
     textures = GenericGenerator.choose(self.validation_textures,
                                        self.textures_range)
     n = 0
     if self.validation_combinations:
         # Place one entity drawn from the validation combinations first.
         for _ in range(GenericGenerator.MAX_ATTEMPTS):
             entity = Entity.random_instance(
                 center=world.random_location(),
                 rotation=self.rotation,
                 size_range=self.size_range,
                 distortion_range=self.distortion_range,
                 shade_range=self.shade_range,
                 combinations=self.validation_combinations)
             if world.add_entity(
                     entity,
                     boundary_tolerance=self.boundary_tolerance,
                     collision_tolerance=self.collision_tolerance):
                 break
         else:
             return None
         n = 1
     # Only fill up when entities are still missing. Without this guard a
     # target count of 1 that is already met by the entity above would still
     # run one fill iteration and could add a second, unrequested entity.
     if n < n_entity:
         for _ in range(n_entity * GenericGenerator.MAX_ATTEMPTS):
             entity = Entity.random_instance(
                 center=world.random_location(),
                 shapes=shapes,
                 size_range=self.size_range,
                 distortion_range=self.distortion_range,
                 rotation=self.rotation,
                 colors=colors,
                 shade_range=self.shade_range,
                 textures=textures)
             # add_entity's result is added directly to the running count.
             n += world.add_entity(
                 entity,
                 boundary_tolerance=self.boundary_tolerance,
                 collision_tolerance=self.collision_tolerance)
             if n >= n_entity:
                 break
         else:
             return None
     if self.collision_tolerance:
         world.sort_entities()
     return world
コード例 #7
0
 def generate_train_world(self):
     """Create a training world of ``self.num_entities`` entities.

     Entities are drawn from the selected shapes/colors/textures. A sample
     whose (shape, color, texture) name triple appears in the validation or
     test combinations is skipped. The collision-provoking flag is re-drawn
     after every entity that was placed successfully.

     Returns:
         The populated world, or ``None`` if the attempt budget of
         ``num_entities * MAX_ATTEMPTS`` samples is exhausted first.
     """
     world = World(self.world_size, self.world_color)
     if self.num_entities == 0:
         return world

     placed = 0
     provoke = random() < self.provoke_collision_rate
     for _ in range(self.num_entities * self.__class__.MAX_ATTEMPTS):
         location = world.random_location(provoke_collision=provoke)
         candidate = Entity.random_instance(
             center=location,
             shapes=self.selected_shapes,
             size_range=self.size_range,
             distortion_range=self.distortion_range,
             rotation=self.rotation,
             colors=self.selected_colors,
             shade_range=self.shade_range,
             textures=self.selected_textures)
         attributes = (candidate.shape.name, candidate.color.name,
                       candidate.texture.name)
         held_out = (attributes in self.validation_combinations
                     or attributes in self.test_combinations)
         if held_out:
             continue
         accepted = world.add_entity(
             candidate,
             boundary_tolerance=self.boundary_tolerance,
             collision_tolerance=self.collision_tolerance)
         if not accepted:
             continue
         placed += 1
         # The flag is re-drawn before the completion check, so the random
         # stream advances even after the final entity.
         provoke = random() < self.provoke_collision_rate
         if placed == self.num_entities:
             break
     else:
         return None

     if self.collision_tolerance:
         world.sort_entities()
     return world
コード例 #8
0
def images_iter(directory, mode, parts=None):
    """Yield one world per image in ``<directory>/images/<split>``.

    ``split`` is 'val' for mode 'validation' and the mode itself otherwise,
    with ``parts[mode]`` appended when ``parts`` is given. The directory must
    be flat, and its files are expected to be named
    ``CLEVR_<split>_<six-digit index>.png`` for consecutive indices starting
    at 0; a missing index raises ``ValueError``.
    """
    if mode == 'validation':
        split = 'val'
    else:
        split = mode
    if parts is not None:
        split += parts[mode]
    images_dir = os.path.join(directory, 'images', split)
    for current, subdirs, filenames in os.walk(images_dir):
        assert current == images_dir
        assert not subdirs
        total = len(filenames)
        for index in range(total):
            expected = 'CLEVR_{}_{:0>6}.png'.format(split, index)
            # Doubles as an existence check: raises ValueError if absent.
            filenames.remove(expected)
            path = os.path.join(current, expected)
            with open(path, 'rb') as filehandle:
                picture = Image.open(fp=filehandle)
                world = World.from_image(picture)
            yield world
コード例 #9
0
        def semparse_evaluate(world, caption):
            """Evaluate a caption for agreement against a specified world.

            The caption is parsed to MRS, the first parse that converts to a
            non-empty DMRS is analyzed, and the first resulting analysis is
            applied to the world.

            Args:
                world: World model, passed to ``World.from_model`` and to the
                    agreement computation.
                caption: Caption sentence to parse.

            Returns:
                A ``ParseScore`` whose four flag arguments mark the outcome:
                ``(0, 0, 0, 1)`` when there is no parse, ``(0, 0, 1, 0)``
                when the DMRS conversion or the analysis fails,
                ``(1, 0, 0, 0)`` when the agreement equals 1.0 and
                ``(0, 1, 0, 0)`` otherwise.
            """

            # Get the set of MRS parses
            mrs_iter = next(Ace_.parse(sentence_list=[caption]))

            if not mrs_iter:
                return ParseScore(caption, 0, 0, 0, 1)

            # Pre-bind the result so the check after the loop is well-defined
            # even if the collection is truthy but yields no parse.
            dmrs_conversion = None

            # For each MRS parse, attempt a conversion to DMRS
            for mrs in mrs_iter:
                try:
                    dmrs_conversion = mrs.convert_to(cls=Dmrs, copy_nodes=True)

                    # If a DMRS conversion is found, break loop
                    if dmrs_conversion:
                        break
                # Catch a bad parsing err
                except Exception as exc:
                    print(exc)
                    return ParseScore(caption, 0, 0, 1, 0)

            if not dmrs_conversion:
                return ParseScore(caption, 0, 0, 1, 0)

            # DMRS parse found. Attempt to convert to ShapeWorld predicate analysis
            analyses = Analyzer_.analyze(dmrs=dmrs_conversion)

            # If analyses is empty either bad parse or out of the supported logical scope for SW
            if not analyses:
                print("Empty analyses. Does not fit logical scope of SW")
                return ParseScore(caption, 0, 0, 1, 0)

            # From here on `caption` is the first analysis rather than the
            # input sentence, and that object is what ends up in the score.
            caption = analyses[0]
            world_ = World.from_model(model=world)
            predication = PragmaticalPredication(agreeing=world_.entities)
            caption.apply_to_predication(predication=predication)
            agreement = caption.agreement(predication=predication,
                                          world=world)

            if agreement == 1.0:
                return ParseScore(caption, 1, 0, 0, 0)
            return ParseScore(caption, 0, 1, 0, 0)
コード例 #10
0
    def generate_validation_world(self):
        """Create a validation world of ``self.num_entities`` entities.

        Falls back to ``generate_train_world`` when no validation
        combinations are configured. Otherwise, if a validation combination
        rate is set, the first entity is sampled from the validation
        combinations; each further entity is sampled from the validation
        combinations, the validation space or the unrestricted space,
        depending on two random draws that are repeated after every
        placement. Samples whose (shape, color, texture) name triple is in
        ``self.invalid_validation_combinations`` are discarded.

        Returns:
            The populated world, or ``None`` if an attempt budget is
            exhausted before the required entities could be placed.
        """
        if self.validation_combinations is None:
            return self.generate_train_world()

        world = World(self.world_size, self.world_color)
        n = 0
        # -1: nothing sampled yet; None: previous sample was discarded;
        # otherwise the entity accepted last.
        last_entity = -1

        if self.num_entities == 0:
            return world

        if self.validation_combination_rate is not None:
            # Place the first entity from the validation combinations. The
            # number of tries is capped so that an unplaceable entity makes
            # the generator give up (None) instead of looping forever.
            for _ in range(self.__class__.MAX_ATTEMPTS):
                entity = self.sample_entity(world=world, last_entity=last_entity, combinations=self.validation_combinations)
                if world.add_entity(entity, boundary_tolerance=self.boundary_tolerance, collision_tolerance=self.collision_tolerance, collision_shade_difference=self.collision_shade_difference):
                    n += 1
                    last_entity = entity
                    break
                else:
                    last_entity = None
            else:
                return None

        # NOTE(review): if validation_combination_rate is None, nothing has
        # been placed yet and a one-entity request returns an empty world
        # here -- confirm whether that configuration can occur.
        if self.num_entities == 1:
            return world

        pick_space = random() < self.validation_space_rate
        pick_combination = pick_space and random() < self.validation_combination_rate
        for _ in range(self.num_entities * self.__class__.MAX_ATTEMPTS):
            if pick_combination:
                entity = self.sample_entity(world=world, last_entity=last_entity, combinations=self.validation_combinations)
            elif pick_space:
                entity = self.sample_entity(world=world, last_entity=last_entity, combinations=self.validation_space)
            else:
                entity = self.sample_entity(world=world, last_entity=last_entity)
            combination = (entity.shape.name, entity.color.name, entity.texture.name)
            if combination in self.invalid_validation_combinations:
                last_entity = None
            elif world.add_entity(entity, collision_tolerance=self.collision_tolerance, collision_shade_difference=self.collision_shade_difference, boundary_tolerance=self.boundary_tolerance):
                n += 1
                if n == self.num_entities:
                    break
                last_entity = entity
                # Re-draw the sampling mode for the next entity.
                pick_space = random() < self.validation_space_rate
                pick_combination = pick_space and random() < self.validation_combination_rate
            else:
                last_entity = None
        else:
            return None

        if self.collision_tolerance:
            world.sort_entities()

        return world
コード例 #11
0
 def deserialize_value(value_name, value_type, read_file, word2id=None):
     """Read back a batch of values through the ``read_file`` callback.

     Args:
         value_name: Base name of the file(s) to read.
         value_type: One of 'int', 'float', 'vector', 'text', 'world',
             'model'; any other type yields ``None``.
         read_file: Callable ``read_file(name)`` returning the file content;
             for worlds it is called with ``binary=True`` and must return
             ``None`` once no further image exists.
         word2id: Word-to-index lookup, required for 'text'.

     Returns:
         A list of parsed items, or the decoded JSON object for 'model'.
     """
     if value_type == 'world':
         worlds = []
         index = 0
         while True:
             raw = read_file('{}-{}.bmp'.format(value_name, index),
                             binary=True)
             if raw is None:
                 # The first missing image file ends the sequence.
                 break
             picture = Image.open(BytesIO(raw))
             worlds.append(World.from_image(picture))
             index += 1
         return worlds

     if value_type == 'model':
         return json.loads(read_file(value_name + '.json'))

     if value_type == 'text':
         assert word2id
         content = read_file(value_name + '.txt')
         # The text after the final newline (normally empty) is dropped.
         return [[word2id[word] for word in line.split(' ')]
                 for line in content.split('\n')[:-1]]

     if value_type == 'int':
         content = read_file(value_name + '.txt')
         return [int(token) for token in content.split()]

     if value_type == 'float':
         content = read_file(value_name + '.txt')
         return [float(token) for token in content.split()]

     if value_type == 'vector':
         content = read_file(value_name + '.txt')
         return [[float(component) for component in row.split(',')]
                 for row in content.split()]

     return None
コード例 #12
0
 def generate_test_world(self):
     """Create a test world of ``self.num_entities`` entities.

     If ``self.test_combinations`` is non-empty, the first entity is sampled
     from those combinations; all remaining entities are sampled without
     that restriction.

     Returns:
         The populated world, or ``None`` if an attempt budget is exhausted
         before the required entities could be placed.
     """
     world = World(self.world_size, self.world_color)
     if self.num_entities == 0:
         return world
     n = 0
     # -1: nothing sampled yet; None: previous sample was rejected;
     # otherwise the entity accepted last.
     last_entity = -1
     if self.test_combinations:
         # Place the first entity from the test combinations. The number of
         # tries is capped so that an unplaceable entity makes the generator
         # give up (None) instead of looping forever.
         for _ in range(self.__class__.MAX_ATTEMPTS):
             entity = self.sample_entity(
                 world=world,
                 last_entity=last_entity,
                 combinations=self.test_combinations)
             if world.add_entity(
                     entity,
                     boundary_tolerance=self.boundary_tolerance,
                     collision_tolerance=self.collision_tolerance):
                 last_entity = entity
                 n += 1
                 break
             else:
                 last_entity = None
         else:
             return None
     if n < self.num_entities:
         for _ in range(self.num_entities * self.__class__.MAX_ATTEMPTS):
             entity = self.sample_entity(world=world,
                                         last_entity=last_entity)
             if world.add_entity(
                     entity,
                     boundary_tolerance=self.boundary_tolerance,
                     collision_tolerance=self.collision_tolerance):
                 last_entity = entity
                 n += 1
                 if n == self.num_entities:
                     break
             else:
                 last_entity = None
         else:
             return None
     if self.collision_tolerance:
         world.sort_entities()
     return world
コード例 #13
0
                assert False

            elif args.clevr_format:
                from shapeworld.world import World
                from shapeworld.datasets.clevr_util import parse_program
                assert args.type == 'agreement'
                worlds = generated['world']
                captions = generated['caption']
                captions_length = generated['caption_length']
                captions_model = generated.get('caption_model')
                agreements = generated['agreement']
                for n in range(len(worlds)):
                    index = (shard - 1) * args.instances + n
                    filename = 'world_{}.png'.format(index)
                    image_bytes = BytesIO()
                    World.get_image(world_array=worlds[n]).save(image_bytes,
                                                                format='png')
                    with open(os.path.join(directory, filename),
                              'wb') as filehandle:
                        filehandle.write(image_bytes.getvalue())
                    image_bytes.close()
                    id2word = dataset.vocabulary(value_type='language')
                    if 'alternatives' in generated:
                        captions_iter = zip(captions[n], captions_length[n],
                                            captions_model[n], agreements[n])
                    else:
                        captions_iter = zip(
                            (captions[n], ), (captions_length[n], ),
                            (captions_model[n], ), (agreements[n], ))
                    for caption, caption_length, caption_model, agreement in captions_iter:
                        if agreement == 1.0:
                            answer = 'true'
コード例 #14
0
def main(args):
    """Load a model and run it in one of four modes selected by ``args``.

    1. Single example: both ``args.question`` and ``args.image`` are given.
    2. Interactive: only ``args.image`` is given (no question/feature h5
       inputs); questions are read from stdin in an endless loop.
    3. ShapeWorld batch: ``args.sw_name`` or ``args.sw_config`` is given; a
       ShapeWorld dataset and loader are built, the batch is run, and
       predictions and/or an HTML visualization are optionally written.
    4. CLEVR batch: otherwise, data is read from the h5 inputs through
       ``ClevrDataLoader``.

    Prints a message and returns early when no model can be loaded from the
    given arguments.
    """
    # Drop into the debugger right away when debugging is requested.
    if args.debug_every <= 1:
        pdb.set_trace()

    if args.sw_name is not None or args.sw_config is not None:
        assert args.image is None and args.question is None

        from shapeworld import Dataset, torch_util
        from shapeworld.datasets import clevr_util

        class ShapeWorldDataLoader(torch_util.ShapeWorldDataLoader):
            # Re-shapes every ShapeWorld batch into the 6-tuple
            # (question, image, feats, answer, program_seq, program_json).
            def __iter__(self):
                for batch in super(ShapeWorldDataLoader, self).__iter__():
                    if "caption" in batch:
                        question = batch["caption"].long()
                    else:
                        question = batch["question"].long()
                    if args.sw_features == 1:
                        image = batch["world_features"]
                    else:
                        image = batch["world"]
                    feats = image
                    if "agreement" in batch:
                        answer = batch["agreement"].long()
                    else:
                        answer = batch["answer"].long()
                    if "caption_model" in batch:
                        assert args.sw_name.startswith(
                            "clevr") or args.sw_program == 3
                        program_seq = batch["caption_model"]
                        # .apply_(callable=(lambda model: clevr_util.parse_program(mode=0, model=model)))
                    elif "question_model" in batch:
                        program_seq = batch["question_model"]
                    elif "caption" in batch:
                        if args.sw_program == 1:
                            program_seq = batch["caption_pn"].long()
                        elif args.sw_program == 2:
                            program_seq = batch["caption_rpn"].long()
                        else:
                            program_seq = [None]
                    else:
                        program_seq = [None]
                    # program_seq = torch.IntTensor([0 for _ in batch['question']])
                    program_json = dict()
                    yield question, image, feats, answer, program_seq, program_json

        dataset = Dataset.create(
            dtype=args.sw_type,
            name=args.sw_name,
            variant=args.sw_variant,
            language=args.sw_language,
            config=args.sw_config,
        )
        print("ShapeWorld dataset: {} (variant: {})".format(
            dataset, args.sw_variant))
        print("Config: " + str(args.sw_config))

        # The vocabulary is stored next to the checkpoint as "<path>.vocab".
        # NOTE(review): if none of the three checkpoint arguments is set,
        # `vocab` stays unbound and the lookup below raises -- confirm that
        # the argument parser rules this out.
        if args.program_generator is not None:
            with open(args.program_generator + ".vocab", "r") as filehandle:
                vocab = json.load(filehandle)
        elif args.execution_engine is not None:
            with open(args.execution_engine + ".vocab", "r") as filehandle:
                vocab = json.load(filehandle)
        elif args.baseline_model is not None:
            with open(args.baseline_model + ".vocab", "r") as filehandle:
                vocab = json.load(filehandle)
        program_token_to_idx = vocab["program_token_to_idx"]

        # NOTE(review): args.sw_name may be None when only --sw_config is
        # given, in which case .startswith would fail -- confirm.
        include_model = args.model_type in ("PG", "EE", "PG+EE") and (
            args.sw_name.startswith("clevr") or args.sw_program == 3)
        if include_model:

            def preprocess(model):
                # Turn a model into a program token index array that is
                # padded with <NULL> up to 27 entries.
                if args.sw_name.startswith("clevr"):
                    program_prefix = vr.programs.list_to_prefix(
                        model["program"])
                else:
                    program_prefix = clevr_util.parse_program(mode=0,
                                                              model=model)
                program_str = vr.programs.list_to_str(program_prefix)
                program_tokens = tokenize(program_str)
                program_encoded = encode(program_tokens, program_token_to_idx)
                program_encoded += [
                    program_token_to_idx["<NULL>"]
                    for _ in range(27 - len(program_encoded))
                ]
                return np.asarray(program_encoded, dtype=np.int64)

            if args.sw_name.startswith("clevr"):
                preprocessing = dict(question_model=preprocess)
            else:
                preprocessing = dict(caption_model=preprocess)

        elif args.sw_program in (1, 2):

            def preprocess(caption_pn):
                # Shift all positive symbols up by 2, replace the first 0
                # by 2 and prepend a 1.
                caption_pn += (caption_pn > 0) * 2
                for n, symbol in enumerate(caption_pn):
                    if symbol == 0:
                        caption_pn[n] = 2
                        break
                caption_pn = np.concatenate(([1], caption_pn))
                return caption_pn

            if args.sw_program == 1:
                preprocessing = dict(caption_pn=preprocess)
            else:
                preprocessing = dict(caption_rpn=preprocess)

        else:
            preprocessing = None

        # From here on `dataset` is the torch wrapper; the ShapeWorld dataset
        # itself is reached through `dataset.dataset` below.
        dataset = torch_util.ShapeWorldDataset(
            dataset=dataset,
            mode=(None if args.sw_mode == "none" else args.sw_mode),
            include_model=include_model,
            epoch=(args.num_samples is None),
            preprocessing=preprocessing,
        )

        loader = ShapeWorldDataLoader(dataset=dataset,
                                      batch_size=args.batch_size)

    # Load the model: a single baseline, or a (program generator, execution
    # engine) pair.
    model = None
    if args.model_type in ("CNN", "LSTM", "CNN+LSTM", "CNN+LSTM+SA"):
        assert args.baseline_model is not None
        print("Loading baseline model from", args.baseline_model)
        model, _ = utils.load_baseline(args.baseline_model)
        if args.vocab_json is not None:
            new_vocab = utils.load_vocab(args.vocab_json)
            model.rnn.expand_vocab(new_vocab["question_token_to_idx"])
    elif args.program_generator is not None and args.execution_engine is not None:
        pg, _ = utils.load_program_generator(args.program_generator,
                                             args.model_type)
        ee, _ = utils.load_execution_engine(args.execution_engine,
                                            verbose=False,
                                            model_type=args.model_type)
        if args.vocab_json is not None:
            new_vocab = utils.load_vocab(args.vocab_json)
            pg.expand_encoder_vocab(new_vocab["question_token_to_idx"])
        model = (pg, ee)
    elif args.model_type == "FiLM":
        # Both FiLM parts are loaded from the same baseline checkpoint path.
        assert args.baseline_model is not None
        pg, _ = utils.load_program_generator(args.baseline_model,
                                             args.model_type)
        ee, _ = utils.load_execution_engine(args.baseline_model,
                                            verbose=False,
                                            model_type=args.model_type)
        if args.vocab_json is not None:
            new_vocab = utils.load_vocab(args.vocab_json)
            pg.expand_encoder_vocab(new_vocab["question_token_to_idx"])
        model = (pg, ee)
    else:
        print(
            "Must give either --baseline_model or --program_generator and --execution_engine"
        )
        return

    if torch.cuda.is_available():
        dtype = torch.cuda.FloatTensor
    else:
        dtype = torch.FloatTensor
    if args.question is not None and args.image is not None:
        run_single_example(args, model, dtype, args.question)
    # Interactive mode
    elif (args.image is not None and args.input_question_h5 is None
          and args.input_features_h5 is None):
        feats_var = extract_image_features(args, dtype)
        print(colored("Ask me something!", "cyan"))
        # No exit condition: the loop runs until the process is interrupted.
        while True:
            # Get user question
            question_raw = input(">>> ")
            run_single_example(args, model, dtype, question_raw, feats_var)
    elif args.sw_name is not None or args.sw_config is not None:
        predictions, visualization = run_batch(args, model, dtype, loader)
        if args.sw_pred_dir is not None:
            assert args.sw_pred_name is not None
            pred_dir = os.path.join(
                args.sw_pred_dir,
                dataset.dataset.type,
                dataset.dataset.name,
                dataset.dataset.variant,
            )
            if not os.path.isdir(pred_dir):
                os.makedirs(pred_dir)
            id2word = dataset.dataset.vocabulary(value_type="language")
            # One line per prediction: "<correct> <agreement> <caption words>".
            with open(
                    os.path.join(
                        pred_dir,
                        args.sw_pred_name + "-" + args.sw_mode + ".txt"),
                    "w",
            ) as filehandle:
                filehandle.write("".join(
                    "{} {} {}\n".format(correct, agreement, " ".join(
                        id2word[c] for c in caption))
                    for correct, agreement, caption in zip(
                        predictions["correct"],
                        predictions["agreement"],
                        predictions["caption"],
                    )))
            print("Predictions saved")
        if args.sw_vis_dir is not None:
            assert args.sw_vis_name is not None
            from io import BytesIO
            from shapeworld.world import World

            vis_dir = os.path.join(
                args.sw_vis_dir,
                dataset.dataset.type,
                dataset.dataset.name,
                dataset.dataset.variant,
            )
            image_dir = os.path.join(vis_dir, args.sw_mode, "images")
            if not os.path.isdir(image_dir):
                os.makedirs(image_dir)
            # Move axis 1 to the end before rendering the images.
            # presumably channels-first -> channels-last; TODO confirm
            worlds = np.transpose(visualization["world"], (0, 2, 3, 1))
            for n in range(worlds.shape[0]):
                image = World.get_image(world_array=worlds[n])
                image_bytes = BytesIO()
                image.save(image_bytes, format="png")
                with open(os.path.join(image_dir, "world-{}.png".format(n)),
                          "wb") as filehandle:
                    filehandle.write(image_bytes.getvalue())
                image_bytes.close()
            with open(
                    os.path.join(
                        vis_dir,
                        args.sw_vis_name + "-" + args.sw_mode + ".html"),
                    "w",
            ) as filehandle:
                html = dataset.dataset.get_html(
                    generated=visualization,
                    image_format="png",
                    image_dir=(args.sw_mode + "/images/"),
                )
                filehandle.write(html)
            print("Visualization saved")
    else:
        # CLEVR batch mode: questions and features come from h5 files.
        vocab = load_vocab(args)
        loader_kwargs = {
            "question_h5": args.input_question_h5,
            "feature_h5": args.input_features_h5,
            "vocab": vocab,
            "batch_size": args.batch_size,
        }
        if args.family_split_file is not None:
            with open(args.family_split_file, "r") as f:
                loader_kwargs["question_families"] = json.load(f)
        with ClevrDataLoader(**loader_kwargs) as loader:
            run_batch(args, model, dtype, loader)
コード例 #15
0
 def deserialize_value(value_name,
                       value_type,
                       read_file,
                       num_concat_worlds=0,
                       word2id=None):
     """Read back a batch of values through the ``read_file`` callback.

     ``value_type`` is first split by ``alternatives_type`` into a base type
     and an alternatives flag. With alternatives, every entry holds several
     items: separated by ';' within a line for numbers and vectors, and as
     newline-separated lines within blank-line-separated groups for text.

     Args:
         value_name: Base name of the file(s) to read.
         value_type: Type description understood by ``alternatives_type``;
             an unhandled base type yields ``None``.
         read_file: Callable ``read_file(name)``; called with
             ``binary=True`` for world images.
         num_concat_worlds: If non-zero, the number of worlds tiled into the
             single image ``<value_name>.bmp``; otherwise worlds are read
             from ``<value_name>-<n>.bmp`` until a file is missing.
         word2id: Word-to-index lookup, required for 'text'.

     Returns:
         A list of parsed items, or the decoded JSON object for 'model'.
     """
     value_type, alts = alternatives_type(value_type=value_type)

     if value_type == 'int':
         rows = read_file(value_name + '.txt').split('\n')[:-1]
         if alts:
             return [[int(item) for item in row.split(';')] for row in rows]
         return [int(row) for row in rows]

     if value_type == 'float':
         rows = read_file(value_name + '.txt').split('\n')[:-1]
         if alts:
             return [[float(item) for item in row.split(';')]
                     for row in rows]
         return [float(row) for row in rows]

     if value_type == 'vector(int)':
         rows = read_file(value_name + '.txt').split('\n')[:-1]
         if alts:
             return [[[int(component) for component in vector.split(',')]
                      for vector in row.split(';')]
                     for row in rows]
         return [[int(component) for component in row.split(',')]
                 for row in rows]

     if value_type == 'vector(float)':
         rows = read_file(value_name + '.txt').split('\n')[:-1]
         if alts:
             return [[[float(component) for component in vector.split(',')]
                      for vector in row.split(';')]
                     for row in rows]
         return [[float(component) for component in row.split(',')]
                 for row in rows]

     if value_type == 'text':
         assert word2id
         content = read_file(value_name + '.txt')
         if alts:
             # Groups of alternative captions are separated by blank lines.
             return [[[word2id[word] for word in caption.split(' ')]
                      for caption in group.split('\n')]
                     for group in content.split('\n\n')[:-1]]
         return [[word2id[word] for word in caption.split(' ')]
                 for caption in content.split('\n')[:-1]]

     if value_type == 'world':
         if not num_concat_worlds:
             # One image file per world, numbered from 0 until one is missing.
             worlds = []
             index = 0
             while True:
                 raw = read_file('{}-{}.bmp'.format(value_name, index),
                                 binary=True)
                 if raw is None:
                     break
                 worlds.append(World.from_image(Image.open(BytesIO(raw))))
                 index += 1
             return worlds

         # All worlds are tiled row by row into one image, `size` per row.
         size = ceil(sqrt(num_concat_worlds))
         raw = read_file(value_name + '.bmp', binary=True)
         assert raw is not None
         grid = World.from_image(Image.open(BytesIO(raw)))
         num_rows = ceil(num_concat_worlds / size)
         height = grid.shape[0] // num_rows
         assert grid.shape[0] % num_rows == 0
         width = grid.shape[1] // size
         assert grid.shape[1] % size == 0
         tiles = []
         for y in range(num_rows):
             # Only the last row may be partially filled.
             if y < num_concat_worlds // size:
                 num_columns = size
             else:
                 num_columns = num_concat_worlds % size
             for x in range(num_columns):
                 tiles.append(grid[y * height:(y + 1) * height,
                                   x * width:(x + 1) * width, :])
         return tiles

     if value_type == 'model':
         return json.loads(read_file(value_name + '.json'))

     return None
コード例 #16
0
 def serialize_value(value,
                     value_name,
                     value_type,
                     write_file,
                     concat_worlds=False,
                     id2word=None):
     """Encode a batch of values and hand the result to ``write_file``.

     The type string is first passed to ``alternatives_type`` (defined
     elsewhere), which yields the base type and an ``alts`` flag. When
     ``alts`` is true every record of ``value`` is itself a sequence of
     entries instead of a single entry.

     Formats by base type:

     * ``'int'`` / ``'float'``: ``<value_name>.txt``, one record per line,
       entries of a record joined with ``';'`` when ``alts`` is set.
     * ``'vector(int)'`` / ``'vector(float)'``: ``<value_name>.txt``, one
       record per line, components joined with ``','`` and (with ``alts``)
       vectors joined with ``';'``.
     * ``'text'``: ``<value_name>.txt``, word ids mapped through
       ``id2word`` and joined with single spaces; falsy ids (e.g. ``0``)
       are skipped. With ``alts`` a record is one text per line and
       records are separated by a blank line.
     * ``'world'``: with ``concat_worlds`` all worlds are tiled into one
       image ``<value_name>.bmp``; otherwise each world is written to
       ``<value_name>-<n>.bmp``.
     * ``'model'``: ``<value_name>.json`` produced by ``json.dumps``.

     Any other base type is ignored and nothing is written.

     Args:
         value: Sequence of values to store (any JSON-serializable object
             for ``'model'``).
         value_name: Base file name (without extension).
         value_type: Type string understood by ``alternatives_type``.
         write_file: Callable ``write_file(name, data)`` /
             ``write_file(name, data, binary=True)``.
         concat_worlds: Whether worlds are tiled into a single image.
         id2word: Lookup from word id to word, required for ``'text'``.

     Raises:
         AssertionError: if ``id2word`` is falsy for ``'text'``.
         ZeroDivisionError: if ``concat_worlds`` is set and ``value`` is
             empty (the tile grid would have zero columns).
     """
     value_type, alts = alternatives_type(value_type=value_type)

     if value_type == 'int':
         if alts:
             value = '\n'.join(';'.join(str(int(x)) for x in xs)
                               for xs in value) + '\n'
         else:
             value = '\n'.join(str(int(x)) for x in value) + '\n'
         write_file(value_name + '.txt', value)

     elif value_type == 'float':
         if alts:
             value = '\n'.join(';'.join(str(float(x)) for x in xs)
                               for xs in value) + '\n'
         else:
             value = '\n'.join(str(float(x)) for x in value) + '\n'
         write_file(value_name + '.txt', value)

     elif value_type == 'vector(int)' or value_type == 'vector(float)':
         if value_type == 'vector(int)':
             # Coerce through int() like the scalar 'int' branch does, so
             # components such as 3.0 or True are written as "3" / "1" and
             # can be parsed back with int() on deserialization.
             def component(x):
                 return str(int(x))
         else:
             component = str
         if alts:
             value = '\n'.join(';'.join(','.join(component(x) for x in vector)
                                        for vector in vectors)
                               for vectors in value) + '\n'
         else:
             value = '\n'.join(','.join(component(x) for x in vector)
                               for vector in value) + '\n'
         write_file(value_name + '.txt', value)

     elif value_type == 'text':
         assert id2word
         if alts:
             value = '\n\n'.join('\n'.join(' '.join(id2word[word_id]
                                                    for word_id in text
                                                    if word_id)
                                           for text in texts)
                                 for texts in value) + '\n\n'
         else:
             value = '\n'.join(' '.join(id2word[word_id]
                                        for word_id in text if word_id)
                               for text in value) + '\n'
         write_file(value_name + '.txt', value)

     elif value_type == 'world':
         if concat_worlds:
             count = len(value)
             # Tile the worlds into a grid with ceil(sqrt(n)) columns.
             size = ceil(sqrt(count))
             rows = []
             for y in range(ceil(count / size)):
                 start = y * size
                 row = [value[start + x]
                        for x in range(min(size, count - start))]
                 # Pad an incomplete last row with blank tiles so that all
                 # rows have the same width before stacking.
                 row += [np.zeros_like(a=value[0])
                         for _ in range(size - len(row))]
                 rows.append(np.concatenate(row, axis=1))
             worlds = np.concatenate(rows, axis=0)
             image = World.get_image(world_array=worlds)
             # The context manager closes the buffer even if saving or
             # writing raises.
             with BytesIO() as image_bytes:
                 image.save(image_bytes, format='bmp')
                 write_file(value_name + '.bmp',
                            image_bytes.getvalue(),
                            binary=True)
         else:
             for n, world in enumerate(value):
                 image = World.get_image(world_array=world)
                 with BytesIO() as image_bytes:
                     image.save(image_bytes, format='bmp')
                     write_file('{}-{}.bmp'.format(value_name, n),
                                image_bytes.getvalue(),
                                binary=True)

     elif value_type == 'model':
         value = json.dumps(value)
         write_file(value_name + '.json', value)