示例#1
0
    def __init__(self, opt, agent, bot, image_idx: int, image_act: Message):
        """
        Initialize the world around a single image act.

        Stores the image bookkeeping values on the instance, computes the image
        source string via ``get_image_src``, and then overwrites the ``'image'``
        field of ``image_act`` with the output of ``ImageLoader.load``.

        :param opt: options; ``opt['image_stack']`` is read here and the whole
            object is forwarded to the parent constructor
        :param agent: forwarded to the parent constructor
        :param bot: forwarded to the parent constructor; ``bot.model_agent.opt``
            is used to build the ``ImageLoader``
        :param image_idx: index stored as ``self.image_idx``
        :param image_act: message whose ``'image'`` entry is read and then
            replaced in place
        """
        super().__init__(opt, agent=agent, bot=bot)

        self.image_stack = opt['image_stack']
        self.image_act = image_act
        self.image_idx = image_idx

        # The user is shown a stringified form of the original image
        raw_image = self.image_act['image']
        self.image_src = get_image_src(image=raw_image)

        # The bot is shown a featurized form, produced by writing the image out
        # to a temporary .jpg and loading it back by file name.
        # NOTE(review): load() reads the file by name while the handle is still
        # open; this assumes save() has flushed its output -- confirm.
        with NamedTemporaryFile(suffix='.jpg') as tmp_file:
            raw_image.save(tmp_file)
            loader = ImageLoader(self.bot.model_agent.opt)
            featurized_image = loader.load(tmp_file.name)
            self.image_act.force_set('image', featurized_image)
def save_image_contexts(task_opt: Opt):
    """
    Collect unique images with contexts and pickle them for the image chat task.

    Repeatedly steps a task world built from ``task_opt``, keeping the first act
    of each step whose image source has not been seen before, until
    ``task_opt['num_examples']`` entries are gathered. Each entry is a dict with
    the keys ``'image_act'`` and ``'context_info'``; the resulting list is
    pickled to ``task_opt['image_context_path']``.

    Note that each image will have BST-style context information saved with it, such
    as persona strings and a pair of lines of dialogue from another dataset.
    TODO: perhaps have the image chat task make use of this context information

    :param task_opt: options used to build the teacher agent and world; must
        provide ``'num_examples'`` and ``'image_context_path'``
    """

    print('Creating teacher to loop over images.')
    teacher_agent = RepeatLabelAgent(task_opt)
    world = create_task(task_opt, teacher_agent)
    num_examples = task_opt['num_examples']

    print('Creating context generator.')
    context_generator = get_context_generator()

    print(
        f'Looping over {num_examples:d} images and pulling a context for each one.'
    )
    collected = []
    seen_srcs = set()
    while len(collected) < num_examples:

        # Advance the world one step and take the first act of that step
        world.parley()
        act = world.get_acts()[0]
        src = get_image_src(image=act['image'])

        if src in seen_srcs:
            # Repeated images (e.g. later turns of an episode) are not kept
            print('\tSkipping non-unique image.')
            continue

        seen_srcs.add(src)
        collected.append(
            {'image_act': act, 'context_info': context_generator.get_context()}
        )
        count = len(collected)
        if count % 5 == 0:
            print(f'Collected {count:d} images.')

    print(f'{len(collected):d} image contexts created.')

    # Persist the whole list as a single pickle
    with open(task_opt['image_context_path'], 'wb') as f:
        pickle.dump(collected, f)