Example #1
0
class FlickrDataset(Dataset):
    """
    A Pytorch Dataset over Flickr caption data.

    Each item pairs the fixed ``QUESTION`` prompt with one image. Outside of
    training, every caption of the selected split is offered as a label
    candidate.
    """

    def __init__(self, opt, shared=None):
        # ``shared`` is accepted for interface compatibility; it is not used.
        self.opt = opt
        self.datatype = self.opt.get('datatype')
        self.training = self.datatype.startswith('train')
        self.num_epochs = self.opt.get('num_epochs', 0)
        self.image_loader = ImageLoader(opt)
        data_path, self.image_path = _path(opt)
        self._setup_data(data_path, opt.get('unittest', False))
        self.dict_agent = DictionaryAgent(opt)

    def __getitem__(self, index):
        """
        Return ``(index, episode)`` for entry ``index``.

        When ``opt['extract_image']`` is set, only the episode dict (with an
        added ``image_id``) is returned.
        """
        cap = self.data[index]
        image_id = int(cap['filename'].replace('.jpg', ''))
        ep = {
            'text': QUESTION,
            'image': self.get_image(image_id),
            'episode_done': True,
        }
        if self.opt.get('extract_image', False):
            ep['image_id'] = image_id
            return ep

        ep['labels'] = [s['raw'] for s in cap['sentences']]
        ep['valid'] = True
        if 'train' not in self.datatype:
            ep['label_candidates'] = self.cands
        return (index, ep)

    def __len__(self):
        return self.num_episodes()

    def _setup_data(self, data_path, unittest):
        """
        Load the entries of the split that matches ``self.datatype``.

        :param data_path:
            path to a JSON file whose top-level ``images`` list holds entries
            with ``split`` and ``sentences`` keys
        :param unittest:
            if True, keep only the first 10 entries of the split
        """
        with open(data_path) as data_file:
            raw_data = json.load(data_file)['images']

        if 'train' in self.datatype:
            split = 'train'
        elif 'valid' in self.datatype:
            split = 'val'
        else:
            split = 'test'
        self.data = [d for d in raw_data if d['split'] == split]

        if unittest:
            # Truncate the loaded split itself (there is no ``self.caption``
            # attribute on this class).
            self.data = self.data[:10]

        if split != 'train':
            # Candidates are built after truncation so that they always
            # cover exactly the episodes that can be served.
            self.cands = [s['raw'] for d in self.data for s in d['sentences']]

    def get_image(self, image_id):
        """Load ``<image_path>/<image_id>.jpg`` through the image loader."""
        im_path = os.path.join(self.image_path, '%d.jpg' % (image_id))
        return self.image_loader.load(im_path)

    def num_episodes(self):
        return len(self.data)

    def num_examples(self):
        return self.num_episodes()

    def num_images(self):
        return self.num_episodes()

    def test_other_image_modes(self):
        """
        Test non-featurized image modes.
        """
        with testing_utils.tempdir() as tmp:
            image_file = 'tmp.jpg'
            image_path = os.path.join(tmp, image_file)
            image_zip_path = os.path.join(tmp, 'tmp.zip')
            image = Image.new('RGB', (16, 16), color=0)

            with PathManager.open(image_path, 'wb') as fp:
                image.save(fp, 'JPEG')

            # ZipFile only closes files it opened itself, so the handle we pass
            # in must be closed explicitly (after the archive is finalized).
            with PathManager.open(image_zip_path, 'wb') as zip_fp:
                with zipfile.ZipFile(zip_fp, mode='w') as zipf:
                    zipf.write(image_path, arcname=image_file)

            for im in ['raw', 'ascii']:
                loader = ImageLoader({"image_mode": im})
                loader.load(image_path)
                loader.load(f"{image_zip_path}/{image_file}")
Example #3
0
    def __init__(self, opt, agent, bot, image_idx: int, image_act: Message):
        """
        Record the image act and prepare two views of its image: a
        stringified one for the human and a featurized one for the bot.
        """
        super().__init__(opt, agent=agent, bot=bot)

        self.image_stack = opt['image_stack']
        self.image_idx = image_idx
        self.image_act = image_act

        source_image = self.image_act['image']
        # Human-facing view of the image.
        self.image_src = get_image_src(image=source_image)

        # Bot-facing view: write the image to a temporary .jpg, run it through
        # an ImageLoader configured with the bot model agent's opt, and
        # overwrite the act's 'image' field with the loader's output.
        with NamedTemporaryFile(suffix='.jpg') as jpg_file:
            source_image.save(jpg_file)
            featurizer = ImageLoader(self.bot.model_agent.opt)
            featurized = featurizer.load(jpg_file.name)
            self.image_act.force_set('image', featurized)
Example #4
0
class DialogData(object):
    """Provides a data structure for accessing textual dialog data.

    ``data_loader`` is a callable that takes ``opt['datafile']`` and returns
    an iterable, with each item being:

        ``(x, ...), new_episode?``

        Where

        - ``x`` is a query and possibly context

        ``...`` can contain additional fields, specifically

        - ``y`` is an iterable of label(s) for that query
        - ``c`` is an iterable of label candidates that the student can choose from
        - ``i`` is a str path to an image on disk, which will be loaded by the data
          class at request-time. should always point to the raw image file.
        - ``o`` is a json dump with other values
        - ``new_episode?`` is a boolean value specifying whether that example
          is the start of a new episode. If you don't use episodes set this to
          ``True`` every time.


    ``cands`` can be set to provide a list of candidate labels for every example
    in this dataset, which the agent can choose from (the correct answer
    should be in this set).


    ``random`` tells the data class whether or not to visit episodes sequentially
    or randomly when returning examples to the caller. NOTE: this class itself
    does not read ``random``; episode order is decided by the caller of ``get``.
    """

    # Placeholder stored in an entry's candidate slot when its candidates are
    # the very same object as the previous entry's in the same episode.
    _SAME_CANDS = 'same as last time'

    def __init__(self, opt, data_loader=None, cands=None, shared=None, **kwargs):
        # self.data is a list of episodes; each episode is a tuple of entries;
        # each entry is a tuple of (text, labels, cands, image path, other).
        if shared:
            self.image_loader = shared.get('image_loader', None)
            self.data = shared.get('data', [])
            self.cands = shared.get('cands', None)
        else:
            self.image_loader = ImageLoader(opt)
            self.data = []
            self._load(data_loader, opt['datafile'])
            self.cands = None if cands is None else set(sys.intern(c) for c in cands)
        # labels temporarily added to self.cands by build_table, removed on
        # the next call
        self.addedCands = []
        self.copied_cands = False

    def share(self):
        """Return the state that copies of this object should reuse."""
        shared = {
            'data': self.data,
            'cands': self.cands,
            'image_loader': self.image_loader
        }
        return shared

    def __len__(self):
        """Returns total number of entries available. Each episode has at least
        one entry, but might have many more.
        """
        return sum(len(episode) for episode in self.data)

    def _read_episode(self, data_generator):
        """Yield one episode (a tuple of interned entry tuples) at a time.

        :param data_generator: iterable of ``(entry, new_episode)`` pairs
        :raises TypeError: if labels or candidates are a single string
        """
        episode = []
        last_cands = None
        for entry, new in data_generator:
            if new and len(episode) > 0:
                yield tuple(episode)
                episode = []
                last_cands = None

            # intern all strings so we don't store them more than once
            new_entry = []
            if len(entry) > 0:
                # process text if available
                if entry[0] is not None:
                    new_entry.append(sys.intern(entry[0]))
                else:
                    new_entry.append(None)
                if len(entry) > 1:
                    # process labels if available
                    if entry[1] is None:
                        new_entry.append(None)
                    elif hasattr(entry[1], '__iter__') and not isinstance(entry[1], str):
                        # make sure iterable over labels, not single string
                        new_entry.append(tuple(sys.intern(e) for e in entry[1]))
                    else:
                        raise TypeError('Must provide iterable over labels, not a single string.')
                    if len(entry) > 2:
                        # process label candidates if available
                        if entry[2] is None:
                            new_entry.append(None)
                        elif last_cands and entry[2] is last_cands:
                            # if cands are shared, store a placeholder so we
                            # don't store them again; get() resolves it
                            new_entry.append(sys.intern(self._SAME_CANDS))
                        elif hasattr(entry[2], '__iter__') and not isinstance(entry[2], str):
                            # make sure iterable over candidates, not single string
                            last_cands = entry[2]
                            new_entry.append(tuple(
                                sys.intern(e) for e in entry[2]))
                        else:
                            raise TypeError('Must provide iterable over label candidates, not a single string.')
                        if len(entry) > 3:
                            # image path
                            if entry[3] is None:
                                new_entry.append(None)
                            else:
                                new_entry.append(sys.intern(entry[3]))

                            # process other values
                            if len(entry) > 4 and entry[4] is not None:
                                new_entry.append(sys.intern(entry[4]))

            episode.append(tuple(new_entry))

        if len(episode) > 0:
            yield tuple(episode)

    def _load(self, data_loader, datafile):
        """Loads up data from an iterator over tuples described in the class
        docs.
        """
        for episode in self._read_episode(data_loader(datafile)):
            self.data.append(episode)

    def num_episodes(self):
        """Return number of episodes in the dataset."""
        return len(self.data)

    def get(self, episode_idx, entry_idx=0):
        """Return ``(table, end_of_data)`` for one entry of the dataset.

        :param episode_idx: index of the episode in ``self.data``
        :param entry_idx: index of the entry within that episode
        """
        # first look up data
        episode = self.data[episode_idx]
        entry = episode[entry_idx]
        episode_done = entry_idx == len(episode) - 1
        end_of_data = episode_done and episode_idx == len(self.data) - 1

        if len(entry) > 2 and entry[2] == self._SAME_CANDS:
            # Replace the storage placeholder with the real candidate tuple so
            # callers never see the placeholder string as candidates.
            entry = (entry[:2] + (self._previous_cands(episode, entry_idx),)
                     + entry[3:])

        # now pack it in a action-observation dictionary
        table = self.build_table(entry)

        # last entry in this episode
        table['episode_done'] = episode_done
        return table, end_of_data

    def _previous_cands(self, episode, entry_idx):
        """Return the nearest explicit candidate tuple before ``entry_idx``."""
        for prev in reversed(episode[:entry_idx]):
            if len(prev) > 2 and prev[2] is not None and prev[2] != self._SAME_CANDS:
                return prev[2]
        return None

    def build_table(self, entry):
        """Packs an entry into an action-observation dictionary."""
        table = {}
        if entry[0] is not None:
            table['text'] = entry[0]
        if len(entry) > 1:
            if entry[1] is not None:
                table['labels'] = entry[1]
            if len(entry) > 2:
                if entry[2] is not None:
                    table['label_candidates'] = entry[2]
                if len(entry) > 3 and entry[3] is not None:
                    img = self.image_loader.load(entry[3])
                    if img is not None:
                        table['image'] = img
                if len(entry) > 4 and entry[4] is not None:
                    for key, val in json.loads(entry[4]).items():
                        table[key] = val

        if (table.get('labels', None) is not None
                and self.cands is not None):
            if self.addedCands:
                # remove elements in addedCands
                self.cands.difference_update(self.addedCands)
                self.addedCands.clear()
            for label in table['labels']:
                if label not in self.cands:
                    # add labels, queue them for removal next time; copy first
                    # so a set shared with other instances is not modified
                    if not self.copied_cands:
                        self.cands = self.cands.copy()
                        self.copied_cands = True
                    self.cands.add(label)
                    self.addedCands.append(label)
            table['label_candidates'] = self.cands

        # An empty labels tuple has no first label to check.
        if table.get('labels') and 'label_candidates' in table:
            if table['labels'][0] not in table['label_candidates']:
                raise RuntimeError('true label missing from candidate labels')
        return table
Example #5
0
class DefaultDataset(Dataset):
    """
    A Pytorch Dataset.
    """

    def __init__(self, opt):
        self.opt = opt
        opt['image_load_task'] = 'personality_captions'
        self.image_mode = opt.get('image_mode', 'no_image_model')
        self.datatype = self.opt.get('datatype')
        self.training = self.datatype.startswith('train')
        self.include_image = opt.get('include_image')
        self.include_personality = opt.get('include_personality')
        self.num_test_labels = opt.get('num_test_labels', 1)
        data_path, personalities_data_path, self.image_path = _path(opt)
        self.image_loader = ImageLoader(opt)
        self._setup_data(data_path, personalities_data_path)

    @staticmethod
    def add_cmdline_args(argparser):
        """
        Add command line args.
        """
        PersonalityCaptionsTeacher.add_cmdline_args(argparser)

    def _setup_data(self, data_path, personalities_data_path):
        """Load the example list and the personalities list from JSON."""
        print('loading: ' + data_path)
        with open(data_path) as f:
            self.data = json.load(f)
        with open(personalities_data_path) as f:
            self.personalities = json.load(f)

    def __getitem__(self, index):
        """
        Return ``(index, episode)`` for example ``index``.

        When ``opt['extract_image']`` is set, only the episode dict (with an
        added ``image_id``) is returned.
        """
        data = self.data[index]
        extract = self.opt.get('extract_image', False)

        # Only load the image when it is emitted or being extracted; otherwise
        # the result would be thrown away.
        image = None
        if self.include_image or extract:
            image = self.get_image(data['image_hash'])

        ep = {
            'text': data['personality'] if self.include_personality else '',
            'episode_done': True,
            'image': image if self.include_image else None,
        }

        if extract:
            ep['image_id'] = data['image_hash']
            return ep

        five_test_labels = self.num_test_labels == 5 and 'test' in self.datatype
        ep['labels'] = [data['comment']]
        if five_test_labels:
            ep['labels'] += data['additional_comments']
        if not self.training:
            if five_test_labels:
                ep['label_candidates'] = data['500_candidates']
            else:
                ep['label_candidates'] = data['candidates']

        return (index, ep)

    def __len__(self):
        return self.num_episodes()

    def get_image(self, image_id):
        """
        Get image.

        :param image_id:
            id of the image

        :return:
            image from the image loader
        """
        im_path = os.path.join(self.image_path, '{}.jpg'.format(image_id))
        return self.image_loader.load(im_path)

    def num_episodes(self):
        """
        Return number of episodes.
        """
        return len(self.data)

    def num_examples(self):
        """
        Return number of examples.
        """
        return self.num_episodes()

    def num_images(self):
        """
        Return number of distinct images (computed once and cached).
        """
        if not hasattr(self, 'num_imgs'):
            # 'image_hash' is the image identifier used everywhere else in
            # this class (see __getitem__).
            self.num_imgs = len({d['image_hash'] for d in self.data})
        return self.num_imgs
Example #6
0
class FlickrDataset(Dataset):
    """A Pytorch Dataset utilizing streaming"""

    def __init__(self, opt):
        self.opt = opt
        self.use_hdf5 = opt.get('use_hdf5', False)
        self.datatype = self.opt.get('datatype')
        self.training = self.datatype.startswith('train')
        self.num_epochs = self.opt.get('num_epochs', 0)
        self.image_loader = ImageLoader(opt)
        caption_path, self.image_path = _path(opt)
        self._setup_data(caption_path, opt.get('unittest', False))
        if self.use_hdf5:
            try:
                import h5py
            except ModuleNotFoundError as err:
                raise ModuleNotFoundError(
                    'Need to install h5py - `pip install h5py`'
                ) from err
            self.h5py = h5py
            self._setup_image_data()
        self.dict_agent = DictionaryAgent(opt)

    def __getitem__(self, index):
        """
        Return ``(index, episode)``; ``index`` wraps around the dataset size.

        When ``opt['extract_image']`` is set, only the episode dict (with an
        added ``image_id``) is returned.
        """
        index %= self.num_episodes()
        cap = self.caption[index]
        ep = {
            'text': self.dict_agent.txt2vec(QUESTION),
            'image': self.get_image(cap['image_id']),
            'episode_done': True,
        }
        if self.opt.get('extract_image', False):
            ep['image_id'] = cap['image_id']
            return ep

        ep['labels'] = [self.dict_agent.txt2vec(cc) for cc in cap['captions']]
        ep['valid'] = True
        ep['use_hdf5'] = self.use_hdf5
        return (index, ep)

    def __len__(self):
        # While training, the dataset is repeated num_epochs times
        # (100 when num_epochs is not positive).
        num_epochs = self.num_epochs if self.num_epochs > 0 else 100
        num_iters = num_epochs if self.training else 1
        return int(num_iters * self.num_episodes())

    def _load_lens(self):
        # NOTE(review): self.length_datafile is never set in this class.
        with open(self.length_datafile) as length:
            lengths = json.load(length)
            self.num_eps = lengths['num_eps']
            self.num_exs = lengths['num_exs']

    def _setup_data(self, caption_path, unittest):
        """
        Parse the caption file into per-image caption groups.

        Each non-blank line holds an image file name, ``#`` and a caption
        index, a tab, then the caption. The four characters before ``#``
        (presumably ``.jpg``) are dropped to get the integer image id.
        Consecutive lines for the same image are grouped into one entry.

        :param caption_path: path to the caption token file
        :param unittest: if True, keep only the first 10 images
        """
        self.caption = []
        prev_img_id = None
        with open(caption_path) as data_file:
            for line in data_file:
                if not line.strip():
                    continue
                img_id = line.split('#')[0][:-4]
                # Strip the trailing newline so it does not become part of the
                # label text.
                caption = line.split('\t', 1)[1].strip()
                if img_id != prev_img_id:
                    prev_img_id = img_id
                    self.caption.append(
                        {'image_id': int(img_id), 'captions': [caption]})
                else:
                    self.caption[-1]['captions'].append(caption)
        if unittest:
            self.caption = self.caption[:10]
        self.image_paths = {
            os.path.join(self.image_path, '%d.jpg' % (cap['image_id']))
            for cap in self.caption
        }

    def _setup_image_data(self):
        '''hdf5 image dataset'''
        extract_feats(self.opt)
        im = self.opt.get('image_mode')
        hdf5_path = os.path.join(self.image_path,
                                 'mode_{}_noatt.hdf5'.format(im))
        hdf5_file = self.h5py.File(hdf5_path, 'r')
        self.image_dataset = hdf5_file['images']

        image_id_to_idx_path = os.path.join(self.image_path,
                                            'mode_{}_id_to_idx.txt'.format(im))
        with open(image_id_to_idx_path, 'r') as f:
            self.image_id_to_idx = json.load(f)

    def get_image(self, image_id):
        """Return the image from disk, or its row in the hdf5 dataset."""
        if not self.use_hdf5:
            im_path = os.path.join(self.image_path, '%d.jpg' % (image_id))
            return self.image_loader.load(im_path)
        else:
            img_idx = self.image_id_to_idx[str(image_id)]
            return torch.Tensor(self.image_dataset[img_idx])

    def num_episodes(self):
        return len(self.caption)

    def num_examples(self):
        return self.num_episodes()

    def num_images(self):
        return self.num_episodes()
Example #7
0
class DefaultDataset(Dataset):
    """A Pytorch Dataset utilizing streaming."""

    def __init__(self, opt, version='2014'):
        self.opt = opt
        self.use_hdf5 = opt.get('use_hdf5', False)
        self.datatype = self.opt.get('datatype')
        self.training = self.datatype.startswith('train')
        self.num_epochs = self.opt.get('num_epochs', 0)
        self.image_loader = ImageLoader(opt)
        test_info_path, annotation_path, self.image_path = _path(opt, version)
        self._setup_data(test_info_path, annotation_path,
                         opt.get('unittest', False))
        if self.use_hdf5:
            try:
                import h5py
            except ImportError as err:
                raise ImportError('Need to install h5py - `pip install h5py`') from err
            self.h5py = h5py
            self._setup_image_data()
        self.dict_agent = DictionaryAgent(opt)

    def __getitem__(self, index):
        """
        Return ``(index, episode)``; ``index`` wraps around the dataset size.

        Non-test splits read from annotations (and include the caption as
        the label); the test split reads image ids from the test info.
        """
        index %= self.num_episodes()
        if self.datatype.startswith('test'):
            anno = None
            image_id = self.test_info['images'][index]['id']
        else:
            # Look the annotation up once and reuse it for the label below.
            anno = self.annotation['annotations'][index]
            image_id = anno['image_id']
        ep = {
            'text': self.dict_agent.txt2vec(QUESTION),
            'image': self.get_image(image_id),
            'episode_done': True,
        }
        if self.opt.get('extract_image', False):
            ep['image_id'] = image_id
            return ep
        if anno is not None:
            ep['labels'] = [anno['caption']]
        ep['valid'] = True
        ep['use_hdf5'] = self.use_hdf5
        return (index, ep)

    def __len__(self):
        # While training, the dataset is repeated num_epochs times
        # (100 when num_epochs is not positive).
        num_epochs = self.num_epochs if self.num_epochs > 0 else 100
        num_iters = num_epochs if self.training else 1
        return int(num_iters * self.num_episodes())

    def _load_lens(self):
        # NOTE(review): self.length_datafile is never set in this class.
        with open(self.length_datafile) as length:
            lengths = json.load(length)
            self.num_eps = lengths['num_eps']
            self.num_exs = lengths['num_exs']

    def _setup_data(self, test_info_path, annotation_path, unittest):
        """
        Load annotations (non-test) or test image info (test) and record the
        image file paths they reference.

        :param unittest: if True, keep only the first 10 entries
        """
        if not self.datatype.startswith('test'):
            with open(annotation_path) as data_file:
                self.annotation = json.load(data_file)
        else:
            with open(test_info_path) as data_file:
                self.test_info = json.load(data_file)

        if unittest:
            if not self.datatype.startswith('test'):
                self.annotation['annotations'] = self.annotation[
                    'annotations'][:10]
            else:
                self.test_info['images'] = self.test_info['images'][:10]

        self.image_paths = set()
        # Depending on whether we are using the train/val/test set, we need to
        # find the image IDs in annotations or test image info
        if not self.datatype.startswith('test'):
            for anno in self.annotation['annotations']:
                self.image_paths.add(
                    os.path.join(self.image_path,
                                 '%012d.jpg' % (anno['image_id'])))
        else:
            for info in self.test_info['images']:
                self.image_paths.add(
                    os.path.join(self.image_path, '%012d.jpg' % (info['id'])))

    def _setup_image_data(self):
        '''hdf5 image dataset'''
        extract_feats(self.opt)
        im = self.opt.get('image_mode')
        hdf5_path = os.path.join(self.image_path,
                                 'mode_{}_noatt.hdf5'.format(im))
        hdf5_file = self.h5py.File(hdf5_path, 'r')
        self.image_dataset = hdf5_file['images']

        image_id_to_idx_path = os.path.join(self.image_path,
                                            'mode_{}_id_to_idx.txt'.format(im))
        with open(image_id_to_idx_path, 'r') as f:
            self.image_id_to_idx = json.load(f)

    def get_image(self, image_id):
        """Return the image from disk, or its row in the hdf5 dataset."""
        if not self.use_hdf5:
            im_path = os.path.join(self.image_path, '%012d.jpg' % (image_id))
            return self.image_loader.load(im_path)
        else:
            img_idx = self.image_id_to_idx[str(image_id)]
            return torch.Tensor(self.image_dataset[img_idx])

    def num_examples(self):
        if not self.datatype.startswith('test'):
            return len(self.annotation['annotations'])
        else:
            return len(self.test_info['images'])

    def num_episodes(self):
        return self.num_examples()

    def num_images(self):
        if not hasattr(self, 'num_imgs'):
            return self.num_examples()
        return self.num_imgs
Example #8
0
class DefaultDataset(Dataset):
    """A Pytorch Dataset"""

    def __init__(self, opt):
        self.opt = opt
        opt['image_load_task'] = 'image_chat'
        self.image_mode = opt.get('image_mode', 'none')
        self.datatype = self.opt.get('datatype')
        self.training = self.datatype.startswith('train')
        self.include_image = opt.get('include_image')
        self.include_personality = opt.get('include_personality')
        self.num_cands = opt.get('num_cands')
        data_path, personalities_data_path, self.image_path = _path(opt)
        self.image_loader = ImageLoader(opt)
        self._setup_data(data_path, personalities_data_path)

    @staticmethod
    def add_cmdline_args(argparser):
        """Add command line args (delegates to ``ImageChatTeacher``)."""
        ImageChatTeacher.add_cmdline_args(argparser)

    def _setup_data(self, data_path, personalities_data_path):
        """Load the dialog list and the personalities list from JSON."""
        print('loading: ' + data_path)
        with open(data_path) as f:
            self.data = json.load(f)
        with open(personalities_data_path) as f:
            self.personalities = json.load(f)

    def __getitem__(self, index):
        """
        Return ``(index, episode)`` for dialog ``index``.

        The final ``(personality, utterance)`` turn is the target; every
        earlier utterance forms the newline-joined context. When
        ``opt['extract_image']`` is set, only the episode dict (with an added
        ``image_id``) is returned.
        """
        data = self.data[index]
        dialog = data['dialog']
        personality, label = dialog[-1][0], dialog[-1][1]
        # Context is strictly the turns before the target, so the label can
        # never leak into the input (e.g. for two-turn dialogs).
        text = '\n'.join(turn[1] for turn in dialog[:-1])
        if self.include_personality:
            text += personality

        extract = self.opt.get('extract_image', False)
        # Only load the image when it is emitted or being extracted.
        image = None
        if self.include_image or extract:
            image = self.get_image(data['image_hash'])

        ep = {
            'text': text,
            'episode_done': True,
            'image': image if self.include_image else None,
        }

        if extract:
            ep['image_id'] = data['image_hash']
            return ep
        if not self.datatype.startswith('test'):
            ep['labels'] = [label]
        if not self.training:
            ep['label_candidates'] = data['candidates'][-1][self.num_cands]

        return (index, ep)

    def __len__(self):
        return self.num_episodes()

    def get_image(self, image_id):
        """Load ``<image_path>/<image_id>.jpg`` through the image loader."""
        im_path = os.path.join(self.image_path, '{}.jpg'.format(image_id))
        return self.image_loader.load(im_path)

    def num_episodes(self):
        return len(self.data)

    def num_examples(self):
        return self.num_episodes()

    def num_images(self):
        """Return number of distinct images (computed once and cached)."""
        if not hasattr(self, 'num_imgs'):
            # 'image_hash' is the image identifier used in __getitem__.
            self.num_imgs = len({d['image_hash'] for d in self.data})
        return self.num_imgs
Example #9
0
class DefaultDataset(Dataset):
    """A Pytorch Dataset utilizing streaming."""

    def __init__(self, opt, version='2014'):
        self.opt = opt
        self.use_intro = opt.get('use_intro')
        self.num_cands = opt.get('num_cands')
        self.use_hdf5 = opt.get('use_hdf5', False)
        self.datatype = self.opt.get('datatype')
        self.training = self.datatype.startswith('train')
        self.num_epochs = self.opt.get('num_epochs', 0)
        self.image_loader = ImageLoader(opt)
        test_info_path, annotation_path, self.image_path = _path(opt, version)
        self.cands = load_candidates(opt['datapath'], self.datatype, version)
        self._setup_data(test_info_path, annotation_path,
                         opt.get('unittest', False))
        if self.use_hdf5:
            try:
                import h5py
            except ImportError as err:
                raise ImportError('Need to install h5py - `pip install h5py`') from err
            self.h5py = h5py
            self._setup_image_data()
        self.dict_agent = DictionaryAgent(opt)

    @staticmethod
    def add_cmdline_args(argparser):
        agent = argparser.add_argument_group('Comment Battle arguments')
        agent.add_argument('--use_intro',
                           type="bool",
                           default=False,
                           help='Include an intro question with each image \
                                for readability (e.g. for coco_caption, \
                                Describe the above picture in a sentence.)')
        agent.add_argument('--num_cands',
                           type=int,
                           default=150,
                           help='Number of candidates to use during \
                                evaluation, setting to -1 uses all.')

    def __getitem__(self, index):
        """
        Return ``(index, episode)``; ``index`` wraps around the dataset size.

        Validation episodes put the true caption first in
        ``label_candidates``; test episodes only get sampled candidates.
        When ``opt['extract_image']`` is set, only the episode dict (with an
        added ``image_id``) is returned.
        """
        index %= self.num_episodes()
        is_test = self.datatype.startswith('test')
        if is_test:
            anno = None
            image_id = self.test_info['images'][index]['id']
        else:
            anno = self.annotation['annotations'][index]
            image_id = anno['image_id']
        ep = {
            'image': self.get_image(image_id),
            'episode_done': True,
        }
        if self.use_intro:
            ep['text'] = QUESTION
        if self.opt.get('extract_image', False):
            ep['image_id'] = image_id
            return ep
        if not is_test:
            caption = anno['caption']
            ep['labels'] = [caption]
            if not self.datatype.startswith('train'):
                # Can only randomly select from validation set
                candidates = self._sample_candidates(index)
                # Make room for the true caption, which is placed first.
                if caption in candidates:
                    candidates.remove(caption)
                elif candidates:
                    candidates.pop(0)
                ep['label_candidates'] = [caption] + candidates
        else:
            # TESTING: can select from train+test set
            ep['label_candidates'] = self._sample_candidates(index)

        ep['use_hdf5'] = self.use_hdf5
        return (index, ep)

    def _sample_candidates(self, index):
        """
        Return a fresh list of candidates for episode ``index``.

        All candidates when ``num_cands == -1``, otherwise ``num_cands``
        draws (with replacement) seeded by ``index``. A new list is always
        returned so callers may mutate it without corrupting ``self.cands``.
        """
        if self.num_cands == -1:
            return list(self.cands)
        return random.Random(index).choices(self.cands, k=self.num_cands)

    def __len__(self):
        # While training, the dataset is repeated num_epochs times
        # (100 when num_epochs is not positive).
        num_epochs = self.num_epochs if self.num_epochs > 0 else 100
        num_iters = num_epochs if self.training else 1
        return int(num_iters * self.num_episodes())

    def _load_lens(self):
        # NOTE(review): self.length_datafile is never set in this class.
        with open(self.length_datafile) as length:
            lengths = json.load(length)
            self.num_eps = lengths['num_eps']
            self.num_exs = lengths['num_exs']

    def _setup_data(self, test_info_path, annotation_path, unittest):
        """
        Load annotations (non-test) or test image info (test) and record the
        image file paths they reference.

        :param unittest: if True, keep only the first 10 entries
        """
        if not self.datatype.startswith('test'):
            with open(annotation_path) as data_file:
                self.annotation = json.load(data_file)
        else:
            with open(test_info_path) as data_file:
                self.test_info = json.load(data_file)

        if unittest:
            if not self.datatype.startswith('test'):
                self.annotation['annotations'] = self.annotation[
                    'annotations'][:10]
            else:
                self.test_info['images'] = self.test_info['images'][:10]

        self.image_paths = set()
        # Depending on whether we are using the train/val/test set, we need to
        # find the image IDs in annotations or test image info. Paths are
        # built exactly as in get_image.
        if not self.datatype.startswith('test'):
            for anno in self.annotation['annotations']:
                self.image_paths.add(os.path.join(
                    self.image_path, '%012d.jpg' % (anno['image_id'])))
        else:
            for info in self.test_info['images']:
                self.image_paths.add(os.path.join(
                    self.image_path, '%012d.jpg' % (info['id'])))

    def _setup_image_data(self):
        '''hdf5 image dataset'''
        extract_feats(self.opt)
        im = self.opt.get('image_mode')
        hdf5_path = os.path.join(self.image_path,
                                 'mode_{}_noatt.hdf5'.format(im))
        hdf5_file = self.h5py.File(hdf5_path, 'r')
        self.image_dataset = hdf5_file['images']

        image_id_to_idx_path = os.path.join(self.image_path,
                                            'mode_{}_id_to_idx.txt'.format(im))
        with open(image_id_to_idx_path, 'r') as f:
            self.image_id_to_idx = json.load(f)

    def get_image(self, image_id):
        """Return the image from disk, or its row in the hdf5 dataset."""
        if not self.use_hdf5:
            im_path = os.path.join(self.image_path, '%012d.jpg' % (image_id))
            return self.image_loader.load(im_path)
        else:
            img_idx = self.image_id_to_idx[str(image_id)]
            return torch.Tensor(self.image_dataset[img_idx])

    def num_examples(self):
        if not self.datatype.startswith('test'):
            return len(self.annotation['annotations'])
        else:
            return len(self.test_info['images'])

    def num_episodes(self):
        return self.num_examples()

    def num_images(self):
        if not hasattr(self, 'num_imgs'):
            return self.num_examples()
        return self.num_imgs
Example #10
0
class DefaultDataset(Dataset):
    """A Pytorch Dataset utilizing streaming.

    ``version='2014'`` reads one annotation file and filters its ``images``
    entries by their ``split`` field. Any other version reads the annotation
    file (train/valid) or the test-info file (test). Non-train datatypes also
    populate ``self.cands``, which is used to build label candidates.
    """
    def __init__(self, opt, version='2017'):
        self.opt = opt
        self.version = version
        self.use_intro = opt.get('use_intro', False)
        # -1 means "use every candidate"; otherwise sample this many.
        self.num_cands = opt.get('num_cands', -1)
        self.datatype = self.opt.get('datatype')
        # 2014 only: also train on the 'restval' split.
        self.include_rest_val = opt.get('include_rest_val', True)
        self.image_loader = ImageLoader(opt)
        test_info_path, annotation_path, self.image_path = _path(opt, version)
        self._setup_data(test_info_path, annotation_path, opt)

    @staticmethod
    def add_cmdline_args(argparser):
        DefaultTeacher.add_cmdline_args(argparser)

    def __getitem__(self, index):
        """Return ``(index, episode)`` for ``index``.

        When ``opt['extract_image']`` is set, only the bare episode dict (with
        ``image`` and ``image_id``) is returned, without label candidates.
        """
        ep = {'episode_done': True}
        if self.use_intro:
            ep['text'] = QUESTION

        if hasattr(self, 'annotation'):
            anno = self.annotation[index]
        else:
            anno = self.test_info['images'][index]

        if self.version == '2014':
            ep['labels'] = [s['raw'] for s in anno['sentences']]
            ep['image_id'] = anno['cocoid']
            ep['split'] = anno['split']
        elif not self.datatype.startswith('test'):
            ep['image_id'] = anno['image_id']
            ep['labels'] = [anno['caption']]
        else:
            ep['image_id'] = anno['id']

        # Must be an assignment: ``ep['image']: ...`` is an annotation
        # statement, which evaluates the call but never stores the result.
        ep['image'] = self.get_image(ep['image_id'], anno.get('split', None))

        if self.opt.get('extract_image', False):
            return ep

        # Add Label Cands
        if not self.datatype.startswith('train'):
            if self.num_cands == -1:
                ep['label_candidates'] = self.cands
            else:
                # One RNG seeded by the index drives sampling, label choice and
                # shuffling, so each example's candidate list is reproducible.
                rng = random.Random(index)
                candidates = rng.choices(self.cands, k=self.num_cands)
                label = rng.choice(ep.get('labels', ['']))
                if not (label == '' or label in candidates):
                    # Guarantee the true label is among the candidates.
                    candidates.pop(0)
                    candidates.append(label)
                    rng.shuffle(candidates)
                ep['label_candidates'] = candidates

        return (index, ep)

    def __len__(self):
        return self.num_episodes()

    def _load_lens(self):
        # NOTE(review): ``self.length_datafile`` is never assigned in this
        # class, so this raises AttributeError unless something else sets it.
        with open(self.length_datafile) as length:
            lengths = json.load(length)
            self.num_eps = lengths['num_eps']
            self.num_exs = lengths['num_exs']

    def _setup_data(self, test_info_path, annotation_path, opt):
        """Load the examples for ``self.datatype``, plus candidates if not training.

        Sets ``self.annotation`` (2014, or 2017 train/valid) or
        ``self.test_info`` (2017 test). With ``opt['unittest']`` only the first
        10 examples are kept.
        """
        if self.version == '2014':
            with open(annotation_path) as data_file:
                raw_data = json.load(data_file)['images']
            if 'train' in self.datatype:
                self.annotation = [
                    d for d in raw_data if d['split'] == 'train'
                ]
                if self.include_rest_val:
                    self.annotation += [
                        d for d in raw_data if d['split'] == 'restval'
                    ]
            elif 'valid' in self.datatype:
                self.annotation = [d for d in raw_data if d['split'] == 'val']
                self.cands = [
                    l for d in self.annotation
                    for l in [s['raw'] for s in d['sentences']]
                ]
            else:
                self.annotation = [d for d in raw_data if d['split'] == 'test']
                self.cands = [
                    l for d in self.annotation
                    for l in [s['raw'] for s in d['sentences']]
                ]
        else:
            if not self.datatype.startswith('test'):
                print('loading: ' + annotation_path)
                with open(annotation_path) as data_file:
                    self.annotation = json.load(data_file)['annotations']
            else:
                print('loading: ' + test_info_path)
                with open(test_info_path) as data_file:
                    self.test_info = json.load(data_file)
            if not self.datatype.startswith('train'):
                self.cands = load_candidates(opt['datapath'], opt['datatype'],
                                             self.version)
        if opt.get('unittest', False):
            # Truncate whichever container was actually loaded: 2014 test
            # data lives in ``annotation``; only 2017 test uses ``test_info``.
            if hasattr(self, 'annotation'):
                self.annotation = self.annotation[:10]
            else:
                self.test_info['images'] = self.test_info['images'][:10]

    def get_image(self, image_id, split):
        """Load ``<image dir>/<12-digit image_id>.jpg`` via the image loader.

        Images from the 2014 'restval' split are read from the directory
        obtained by replacing 'train' with 'val' in ``self.image_path``.
        """
        if split == 'restval':
            im_path = self.image_path.replace('train', 'val')
        else:
            im_path = self.image_path
        im_path = os.path.join(im_path, '%012d.jpg' % (image_id))
        return self.image_loader.load(im_path)

    def num_examples(self):
        if self.version == '2014' or not self.datatype.startswith('test'):
            return len(self.annotation)
        else:
            # For 2017, we only have annotations for the train and val sets,
            # so for the test set we need to determine how many images we have.
            return len(self.test_info['images'])

    def num_episodes(self):
        return self.num_examples()

    def num_images(self):
        if not hasattr(self, 'num_imgs'):
            return self.num_examples()
        return self.num_imgs
Example #11
0
class OeTeacher(Teacher):
    """VQA v2.0 Open-Ended teacher, which loads the json VQA data and
    implements its own `act` method for interacting with student agent.

    Training samples questions at random. Other datatypes walk the data in
    steps of ``batchsize`` starting at ``batchindex``, so the teachers of a
    batch cover disjoint examples.
    """
    def __init__(self, opt, shared=None):
        super().__init__(opt)
        self.datatype = opt['datatype']
        data_path, annotation_path, self.image_path = _path(opt)

        if shared and 'ques' in shared:
            # Reuse data already loaded by the teacher we were shared from.
            self.ques = shared['ques']
            if 'annotation' in shared:
                self.annotation = shared['annotation']
            self.image_loader = shared['image_loader']
        else:
            self._setup_data(data_path, annotation_path)
            self.image_loader = ImageLoader(opt)

        self.len = len(self.ques['questions'])

        # for ordered data in batch mode (especially, for validation and
        # testing), each teacher in the batch gets a start index and a step
        # size so they all process disparate sets of the data
        self.step_size = opt.get('batchsize', 1)
        self.data_offset = opt.get('batchindex', 0)

        self.reset()

    def __len__(self):
        return self.len

    def reset(self):
        # Reset the dialog so that it is at the start of the epoch,
        # and all metrics are reset.
        super().reset()
        self.lastY = None
        # One step before the offset, so the first act() lands on it.
        self.episode_idx = self.data_offset - self.step_size
        self.epochDone = False

    def observe(self, observation):
        """Process observation for metrics."""
        if self.lastY is not None:
            self.metrics.update(observation, self.lastY)
            self.lastY = None
        return observation

    def act(self):
        """Return the next question and its image; labels only when training."""
        if self.datatype == 'train':
            self.episode_idx = random.randrange(self.len)
        else:
            self.episode_idx = (self.episode_idx + self.step_size) % len(self)
            # Done once the next step would pass the end. Comparing against a
            # single fixed index (len - step_size) only ever matches for one
            # batch offset, so the other batch teachers never flagged the end.
            if self.episode_idx + self.step_size >= len(self):
                self.epochDone = True

        qa = self.ques['questions'][self.episode_idx]
        question = qa['question']
        image_id = qa['image_id']

        img_path = self.image_path + '%012d.jpg' % (image_id)

        action = {
            'image': self.image_loader.load(img_path),
            'text': question,
            'episode_done': True
        }

        if not self.datatype.startswith('test'):
            anno = self.annotation['annotations'][self.episode_idx]
            self.lastY = [ans['answer'] for ans in anno['answers']]

        if self.datatype.startswith('train'):
            action['labels'] = self.lastY

        return action

    def share(self):
        shared = super().share()
        shared['ques'] = self.ques
        if hasattr(self, 'annotation'):
            shared['annotation'] = self.annotation
        shared['image_loader'] = self.image_loader
        return shared

    def _setup_data(self, data_path, annotation_path):
        print('loading: ' + data_path)
        with open(data_path) as data_file:
            self.ques = json.load(data_file)

        # act() never reads annotations for 'test*' datatypes; use the same
        # prefix check so e.g. 'test:stream' does not try to load them.
        if not self.datatype.startswith('test'):
            print('loading: ' + annotation_path)
            with open(annotation_path) as data_file:
                self.annotation = json.load(data_file)
Example #12
0
class QADataCollectionWorld(MTurkTaskWorld):
    """
    World for recording a turker's question and answer given a context.

    Assumes the context is a random context from a given task, e.g. from SQuAD, CBT,
    etc. Turn 0 shows the context (as an HTML table, with an image) and
    collects a question. Turns 1 through ``final_turn_index`` each prompt for
    an answer, and the last of them ends the episode.
    """

    collector_agent_id = 'QA Collector'
    # Index of the last turn: every turn from 1 up to this one asks for an
    # answer, and this one also ends the episode.
    final_turn_index = 5
    answer_prompt = 'Thanks. And what is the answer to your question?'

    def __init__(self, opt, task, mturk_agent):
        import os  # local import: only needed to build the default image path

        self.task = task
        self.mturk_agent = mturk_agent
        self.episodeDone = False
        self.turn_index = -1
        self.context = None
        self.question = None
        self.answer = None
        self.image_loader = ImageLoader(opt)
        # Image shown with the context. Configurable via opt['qa_image_path'];
        # defaults to banana.jpg in this module's directory (presumably where
        # the task ships it -- TODO confirm).
        self.image_path = opt.get('qa_image_path') or os.path.join(
            os.path.dirname(os.path.abspath(__file__)), 'banana.jpg')

    def parley(self):
        """Advance one turn of the QA collection dialog (see class docstring)."""
        self.turn_index = self.turn_index + 1
        ad = {'episode_done': False}
        ad['id'] = self.__class__.collector_agent_id

        if self.turn_index == 0:
            self._collect_question(ad)
        elif 1 <= self.turn_index <= self.final_turn_index:
            is_final = self.turn_index == self.final_turn_index
            ad['text'] = self.answer_prompt
            if is_final:
                ad['episode_done'] = True  # end of episode
            self.mturk_agent.observe(validate(ad))
            self.answer = self.mturk_agent.act()
            if is_final:
                self.episodeDone = True

    def _collect_question(self, ad):
        """Show the context and image to the turker and record their question."""
        # Get context from the task teacher. The last line of its text is
        # dropped (presumably the task's own question -- TODO confirm).
        qa = self.task.act()
        self.context = '\n'.join(qa['text'].split('\n')[:-1])

        # Wrap the context with a prompt telling the turker what to do next.
        ad['text'] = (self._context_table(self.context.split('\n')) +
                      '\n\nPlease provide a question given this context.')
        ad['image'] = self._encode_image(self.image_path)

        self.mturk_agent.observe(validate(ad))
        self.question = self.mturk_agent.act()
        # Can log the turker's question here

    @staticmethod
    def _context_table(context_lines):
        """Render context lines as a bordered HTML table, two cells per row.

        Rows are built from the actual number of lines, so contexts of any
        length render. Cell text is HTML-escaped so characters such as ``<``
        or ``&`` in the context cannot break the markup.
        """
        import html  # local import: only needed for escaping table cells

        rows = []
        for start in range(0, len(context_lines), 2):
            cells = ''.join(' <td>{}</td>'.format(html.escape(line))
                            for line in context_lines[start:start + 2])
            rows.append('<tr>{} </tr>'.format(cells))
        return '<table border="1">{}</table>'.format(' '.join(rows))

    def _encode_image(self, path):
        """Load ``path`` with the image loader and return it as base64 JPEG text."""
        img = self.image_loader.load(path)
        buffered = BytesIO()
        img.save(buffered, format="JPEG")
        return base64.b64encode(buffered.getvalue()).decode('ascii')

    def episode_done(self):
        return self.episodeDone

    def shutdown(self):
        self.task.shutdown()
        self.mturk_agent.shutdown()

    def review_work(self):
        # Can review the work here to accept or reject it
        pass

    def get_custom_task_data(self):
        # brings important data together for the task, to later be used for
        # creating the dataset. If data requires pickling, put it in a field
        # called 'needs-pickle'.
        return {'context': self.context, 'acts': [self.question, self.answer]}
Example #13
0
class SplitTeacher(Teacher):
    """FVQA Teacher, which loads the json VQA data and implements its own
    `act` method for interacting with student agent.

    Use "fvqa:split:X" to choose between splits 0-4 (inclusive), or just
    "fvqa" to use the default split (0).

    Each example takes two acts: the question (with its image), then
    "Which fact supports this answer?". Answers to the question are scored
    in ``metrics`` and answers to the fact prompt in ``factmetrics``.
    """

    def __init__(self, opt, shared=None):
        super().__init__(opt)

        dt = opt['datatype'].split(':')[0]
        if dt not in ('train', 'test'):
            raise RuntimeError('Not valid datatype (only train/test).')

        task = opt.get('task', 'fvqa:split:0')
        task_num = 0  # default to train/split 0
        split = task.split(':')
        if len(split) > 2:
            task_num = split[2]
            if task_num not in [str(i) for i in range(5)]:
                raise RuntimeError('Invalid train/test split ID (0-4 inclusive)')

        if not hasattr(self, 'factmetrics'):
            if shared and shared.get('factmetrics'):
                self.factmetrics = shared['factmetrics']
            else:
                self.factmetrics = Metrics(opt)
        # Set regardless of whether ``factmetrics`` already existed:
        # act() depends on it.
        self.datatype = opt['datatype']
        questions_path, trainset_path, self.image_path = _path(opt)

        if shared and 'ques' in shared:
            self.ques = shared['ques']
        else:
            self._setup_data(questions_path, trainset_path, dt, task_num)
        self.len = len(self.ques)

        self.asked_question = False
        # for ordered data in batch mode (especially, for validation and
        # testing), each teacher in the batch gets a start index and a step
        # size so they all process disparate sets of the data
        self.step_size = opt.get('batchsize', 1)
        self.data_offset = opt.get('batchindex', 0)
        self.image_loader = ImageLoader(opt)

        self.reset()

    def __len__(self):
        return self.len

    def report(self):
        r = super().report()
        r['factmetrics'] = self.factmetrics.report()
        return r

    def reset(self):
        # Reset the dialog so that it is at the start of the epoch,
        # and all metrics are reset.
        super().reset()
        self.lastY = None
        self.episode_idx = self.data_offset - self.step_size
        self.epochDone = False

    def reset_metrics(self):
        super().reset_metrics()
        self.factmetrics.clear()

    def observe(self, observation):
        """Process observation for metrics.

        After the question turn, only ``lastY`` for the answer is scored;
        ``lastY`` is kept because act() still needs the fact label. It is
        cleared after the fact turn is scored.
        """
        if self.lastY is not None:
            if self.asked_question:
                self.metrics.update(observation, self.lastY[0])
            else:
                self.factmetrics.update(observation, self.lastY[1])
                self.lastY = None
        return observation

    def act(self):
        """Alternate between asking a question and asking for its supporting fact."""
        if self.asked_question:
            self.asked_question = False
            action = {'text': 'Which fact supports this answer?', 'episode_done': True}
            if self.datatype.startswith('train'):
                action['labels'] = self.lastY[1]
            if self.datatype != 'train' and self.episode_idx + self.step_size >= len(self):
                self.epochDone = True
            return action

        if self.datatype == 'train':
            self.episode_idx = random.randrange(self.len)
        else:
            self.episode_idx = (self.episode_idx + self.step_size) % len(self)

        self.asked_question = True
        qa = self.ques[self.episode_idx]
        question = qa['question']
        img_path = self.image_path + qa['img_file']

        action = {
            'image': self.image_loader.load(img_path),
            'text': question,
            'episode_done': False
        }

        # Strip the '[' / ']' markers from the fact to get a readable label.
        human_readable = qa['fact_surface'].replace('[', '').replace(']', '')
        self.lastY = [[qa['answer']], [human_readable]]

        if self.datatype.startswith('train'):
            action['labels'] = self.lastY[0]

        return action

    def share(self):
        shared = super().share()
        shared['factmetrics'] = self.factmetrics
        shared['ques'] = self.ques
        if hasattr(self, 'facts'):
            shared['facts'] = self.facts
        return shared

    def _setup_data(self, questions_path, trainset_path, datatype, task_num):
        """Keep only the questions whose image is listed for this split."""
        print('loading: ' + questions_path)
        with open(questions_path) as questions_file:
            questions = json.load(questions_file)
        train_test_images = set()
        with open(os.path.join(trainset_path, '{}_list_{}.txt'.format(datatype, task_num))) as imageset:
            for line in imageset:
                train_test_images.add(line.strip())
        self.ques = [questions[k] for k in sorted(questions.keys()) if questions[k]['img_file'] in train_test_images]
Example #14
0
class OeTeacher(Teacher):
    """
    VQA Open-Ended teacher, which loads the json vqa data and implements its
    own `act` method for interacting with student agent.

    Training samples questions at random. Other datatypes walk the data in
    steps of ``batchsize`` starting at ``batchindex``, so the teachers of a
    batch cover disjoint examples.
    """
    def __init__(self, opt, shared=None):
        super().__init__(opt, shared)
        self.datatype = opt['datatype']
        data_path, annotation_path, self.image_path = _path(opt)

        if shared and 'ques' in shared:
            self.ques = shared['ques']
            if 'annotation' in shared:
                self.annotation = shared['annotation']
        else:
            self._setup_data(data_path, annotation_path)

        # for ordered data in batch mode (especially, for validation and
        # testing), each teacher in the batch gets a start index and a step
        # size so they all process disparate sets of the data
        self.step_size = opt.get('batchsize', 1)
        self.data_offset = opt.get('batchindex', 0)
        self.image_loader = ImageLoader(opt)
        self.reset()

    def __len__(self):
        return len(self.ques['questions'])

    def reset(self):
        # Reset the dialog so that it is at the start of the epoch,
        # and all metrics are reset.
        super().reset()
        self.lastY = None
        # One step before the offset, so the first act() lands on it.
        self.episode_idx = self.data_offset - self.step_size
        self.epochDone = False

    def observe(self, observation):
        """Process observation for metrics."""
        if self.lastY is not None:
            self.metrics.update(observation, self.lastY)
            self.lastY = None
        return observation

    def act(self):
        """Return the next question and its image; labels only when training."""
        if self.datatype == 'train':
            self.episode_idx = random.randrange(len(self))
        else:
            self.episode_idx = (self.episode_idx + self.step_size) % len(self)
            # Done once the next step would pass the end. Comparing against a
            # single fixed index (len - step_size) only ever matches for one
            # batch offset, so the other batch teachers never flagged the end.
            if self.episode_idx + self.step_size >= len(self):
                self.epochDone = True

        qa = self.ques['questions'][self.episode_idx]
        question = qa['question']
        image_id = qa['image_id']

        img_path = self.image_path + '%012d.jpg' % (image_id)

        action = {
            'image': self.image_loader.load(img_path),
            'text': question,
            'episode_done': True
        }

        if not self.datatype.startswith('test'):
            anno = self.annotation['annotations'][self.episode_idx]
            self.lastY = [ans['answer'] for ans in anno['answers']]

        if self.datatype.startswith('train'):
            action['labels'] = self.lastY

        return action

    def share(self):
        shared = super().share()
        shared['ques'] = self.ques
        if hasattr(self, 'annotation'):
            shared['annotation'] = self.annotation
        return shared

    def _setup_data(self, data_path, annotation_path):
        print('loading: ' + data_path)
        with open(data_path) as data_file:
            self.ques = json.load(data_file)

        # act() never reads annotations for 'test*' datatypes; use the same
        # prefix check so e.g. 'test:stream' does not try to load them.
        if not self.datatype.startswith('test'):
            print('loading: ' + annotation_path)
            with open(annotation_path) as data_file:
                self.annotation = json.load(data_file)
Example #15
0
class SplitTeacher(Teacher):
    """FVQA Teacher, which loads the json VQA data and implements its own
    `act` method for interacting with student agent.

    Use "fvqa:split:X" to choose between splits 0-4 (inclusive), or just
    "fvqa" to use the default split (0).

    Each example takes two acts: the question (with its image), then
    "Which fact supports this answer?". Answers to the question are scored
    in ``metrics`` and answers to the fact prompt in ``factmetrics``.
    """
    def __init__(self, opt, shared=None):
        super().__init__(opt)

        dt = opt['datatype'].split(':')[0]
        if dt not in ('train', 'test'):
            raise RuntimeError('Not valid datatype (only train/test).')

        task = opt.get('task', 'fvqa:split:0')
        task_num = 0  # default to train/split 0
        split = task.split(':')
        if len(split) > 2:
            task_num = split[2]
            if task_num not in [str(i) for i in range(5)]:
                raise RuntimeError(
                    'Invalid train/test split ID (0-4 inclusive)')

        if not hasattr(self, 'factmetrics'):
            if shared and shared.get('factmetrics'):
                self.factmetrics = shared['factmetrics']
            else:
                self.factmetrics = Metrics(opt)
        # Set regardless of whether ``factmetrics`` already existed:
        # act() depends on it.
        self.datatype = opt['datatype']
        questions_path, trainset_path, self.image_path = _path(opt)

        if shared and 'ques' in shared:
            self.ques = shared['ques']
        else:
            self._setup_data(questions_path, trainset_path, dt, task_num)
        self.len = len(self.ques)

        self.asked_question = False
        # for ordered data in batch mode (especially, for validation and
        # testing), each teacher in the batch gets a start index and a step
        # size so they all process disparate sets of the data
        self.step_size = opt.get('batchsize', 1)
        self.data_offset = opt.get('batchindex', 0)
        self.image_loader = ImageLoader(opt)

        self.reset()

    def num_examples(self):
        return self.len

    def num_episodes(self):
        return self.len

    def report(self):
        r = super().report()
        r['factmetrics'] = self.factmetrics.report()
        return r

    def reset(self):
        # Reset the dialog so that it is at the start of the epoch,
        # and all metrics are reset.
        super().reset()
        self.lastY = None
        self.episode_idx = self.data_offset - self.step_size
        self.epochDone = False

    def reset_metrics(self):
        super().reset_metrics()
        self.factmetrics.clear()

    def observe(self, observation):
        """Process observation for metrics.

        After the question turn, only ``lastY`` for the answer is scored;
        ``lastY`` is kept because act() still needs the fact label. It is
        cleared after the fact turn is scored.
        """
        if self.lastY is not None:
            if self.asked_question:
                self.metrics.update(observation, self.lastY[0])
            else:
                self.factmetrics.update(observation, self.lastY[1])
                self.lastY = None
        return observation

    def act(self):
        """Alternate between asking a question and asking for its supporting fact."""
        if self.asked_question:
            self.asked_question = False
            action = {
                'text': 'Which fact supports this answer?',
                'episode_done': True
            }
            if self.datatype.startswith('train'):
                action['labels'] = self.lastY[1]
            if (self.datatype != 'train' and
                    self.episode_idx + self.step_size >= self.num_episodes()):
                self.epochDone = True
            return action

        if self.datatype == 'train':
            self.episode_idx = random.randrange(self.len)
        else:
            self.episode_idx = (self.episode_idx +
                                self.step_size) % self.num_episodes()

        self.asked_question = True
        qa = self.ques[self.episode_idx]
        question = qa['question']
        img_path = self.image_path + qa['img_file']

        action = {
            'image': self.image_loader.load(img_path),
            'text': question,
            'episode_done': False
        }

        # Strip the '[' / ']' markers from the fact to get a readable label.
        human_readable = qa['fact_surface'].replace('[', '').replace(']', '')
        self.lastY = [[qa['answer']], [human_readable]]

        if self.datatype.startswith('train'):
            action['labels'] = self.lastY[0]

        return action

    def share(self):
        shared = super().share()
        shared['factmetrics'] = self.factmetrics
        shared['ques'] = self.ques
        if hasattr(self, 'facts'):
            shared['facts'] = self.facts
        return shared

    def _setup_data(self, questions_path, trainset_path, datatype, task_num):
        """Keep only the questions whose image is listed for this split."""
        print('loading: ' + questions_path)
        with open(questions_path) as questions_file:
            questions = json.load(questions_file)
        train_test_images = set()
        with open(
                os.path.join(trainset_path,
                             '{}_list_{}.txt'.format(datatype,
                                                     task_num))) as imageset:
            for line in imageset:
                train_test_images.add(line.strip())
        self.ques = [
            questions[k] for k in sorted(questions.keys())
            if questions[k]['img_file'] in train_test_images
        ]
Example #16
0
class MTurkIGCEvalWorld(MultiAgentDialogWorld):
    """World where an agent rates candidate questions or responses for images.

    For each of ``opt['num_images']`` examples, the agent sees an image, its
    context (and, in the responses round, the question). It then rates the
    candidate questions or responses.
    """
    def __init__(self, opt, agents=None, shared=None, world_tag='NONE'):
        self.turn_idx = 0
        self.task_type = 'sandbox' if opt['is_sandbox'] else 'live'
        self.chat_done = False
        self.world_tag = world_tag
        self.max_resp_time = opt['max_resp_time']  # in secs
        super().__init__(opt, agents, shared)
        self.agents = agents
        self.agent = agents[0]
        self.data = []
        self.exact_match = False
        self.num_images = opt['num_images']
        self.d_rnd = opt.get('dialog_round')
        self.image_path = opt.get('image_path')
        self.task_dir = opt['task_dir']
        # Id of the example currently being rated; None until one is popped.
        self.example_id = None
        opt['image_mode'] = 'raw'
        self.image_loader = ImageLoader(opt)

    def episode_done(self):
        return self.chat_done

    def parley(self):
        """Run the whole rating session: one turn per image.

        The loop stops after ``num_images`` ratings, or early on timeout,
        disconnect, or a malformed rating act. ``self.chat_done`` is True on
        return.
        """
        # Initial Message Value
        control_msg = {'episode_done': False}
        control_msg['id'] = 'SYSTEM'
        # First, we give RATER the image and context.
        while self.turn_idx < self.num_images:
            print(self.world_tag + ' is at turn {}...'.format(self.turn_idx))
            # Send image to turker
            if self.d_rnd == 'questions':
                control_msg['description'] = config_questions[
                    'task_description']
            else:
                control_msg['description'] = config_responses[
                    'task_description']
            self.example_id, igc_example = (
                self.agent.example_generator.pop_example())
            img = self.image_loader.load(
                self.image_id_to_path(self.example_id))
            buffered = BytesIO()
            img.save(buffered, format="JPEG")
            control_msg['image'] = base64.b64encode(
                buffered.getvalue()).decode('ascii')
            control_msg['context'] = igc_example['context']
            # Setup options for rating: (option id, text) pairs.
            if self.d_rnd == 'questions':
                options = list(igc_example['questions'].items())
            else:
                control_msg['question'] = igc_example['question']
                options = list(igc_example['responses'].items())
            random.shuffle(options)
            # Identical texts are shown once; dup_dict maps each kept option
            # id to every id that shared its text.
            options, dup_dict = self.filter_option_duplicates(options)
            control_msg['options'] = [c[1] for c in options]
            # Collect rating from turker
            rate_msg = RATE_MSG if self.d_rnd == 'questions' else RATE_RESPONSE_MSG
            control_msg['text'] = rate_msg.format(self.turn_idx + 1)
            control_msg['new_eval'] = True
            self.agent.observe(validate(control_msg))
            time.sleep(1)
            act = self.agent.act(timeout=self.max_resp_time)
            # First timeout check
            self.check_timeout(act)
            if self.chat_done:
                break
            try:
                ratings = []
                collected_ratings = list(
                    zip([q[0] for q in options], act['ratings']))
                for option_id, rating in collected_ratings:
                    # Give every duplicate of this option the same rating.
                    for other_id in dup_dict[option_id]:
                        ratings.append((other_id, rating))
                igc_example['ratings'] = ratings
            except Exception:
                # Agent disconnected
                break
            igc_example['dialog_round_evaluated'] = self.d_rnd
            self.data.append(igc_example)
            self.turn_idx += 1

        if self.turn_idx == self.num_images:
            control_msg['text'] = CHAT_ENDED_MSG.format(self.num_images)
            self.agent.observe(validate(control_msg))
        self.chat_done = True
        return

    def image_id_to_path(self, image_id):
        """Return ``<image_path>/<image_id>.jpg``.

        Falls back to ``<task_dir>/banana.jpg`` when no image path is
        configured (empty string or None).
        """
        if not self.image_path:
            return os.path.join(self.task_dir, 'banana.jpg')
        return '{}/{}.jpg'.format(self.image_path, image_id)

    def filter_option_duplicates(self, options):
        """Drop options with repeated text.

        Args:
            options: ``[(opt, text), (opt2, text2), ...]``.

        Returns:
            ``(new_options, opt_to_opt)``: the first option seen for each
            text, and a map from each kept option id to all ids sharing its
            text.
        """
        new_options = []
        text_to_opt = {}
        opt_to_opt = {}
        for opt, text in options:
            if text not in text_to_opt:
                text_to_opt[text] = opt
                new_options.append([opt, text])
                opt_to_opt[opt] = [opt]
            else:
                opt_to_opt[text_to_opt[text]].append(opt)
        return new_options, opt_to_opt

    def check_timeout(self, act):
        """Mark the chat done on timeout/disconnect; returns True in that case."""
        if act['text'] == '[TIMEOUT]' and act['episode_done']:
            control_msg = {'episode_done': True}
            control_msg['id'] = 'SYSTEM'
            control_msg['text'] = TIMEOUT_MSG
            for ag in self.agents:
                if ag.id != act['id']:
                    ag.observe(validate(control_msg))
            self.chat_done = True
            return True
        elif act['text'] == '[DISCONNECT]':
            self.chat_done = True
            return True
        else:
            return False

    def save_data(self):
        """Pickle the collected ratings plus worker/HIT ids under ``opt['data_path']``.

        If any agent abandoned, returned, disconnected from, or let the HIT
        expire, the example in progress is pushed back onto the example stack.
        In that case the file name also gets an ``_incomplete`` suffix.
        """
        convo_finished = not any(
            ag.hit_is_abandoned or ag.hit_is_returned or ag.disconnected
            or ag.hit_is_expired for ag in self.agents)
        example_generator = self.agents[0].example_generator
        if not convo_finished and self.example_id is not None:
            # parley() pops examples from the first agent's generator, so
            # return the unfinished one to that same generator.
            example_generator.push_example(self.example_id)
            print("\n**Push image {} back to stack. **\n".format(
                self.example_id))
        example_generator.save_idx_stack()
        data_path = self.opt['data_path']
        if not os.path.exists(data_path):
            os.makedirs(data_path)
        if convo_finished:
            filename = os.path.join(
                data_path,
                '{}_{}_{}.pkl'.format(time.strftime("%Y%m%d-%H%M%S"),
                                      np.random.randint(0, 1000),
                                      self.task_type))
        else:
            filename = os.path.join(
                data_path, '{}_{}_{}_incomplete.pkl'.format(
                    time.strftime("%Y%m%d-%H%M%S"), np.random.randint(0, 1000),
                    self.task_type))
        with open(filename, 'wb') as data_file:
            pickle.dump(
                {
                    'data': self.data,
                    'worker': self.agents[0].worker_id,
                    'hit_id': self.agents[0].hit_id,
                    'assignment_id': self.agents[0].assignment_id
                }, data_file)
        print('{}: Data successfully saved at {}.'.format(
            self.world_tag, filename))

    def review_work(self):
        # NOTE(review): declared global, presumably so joblib can reference
        # the function by name -- confirm it is still needed with the
        # threading backend.
        global review_agent

        def review_agent(ag):
            pass  # auto approve 5 days

        Parallel(n_jobs=len(self.agents),
                 backend='threading')(delayed(review_agent)(agent)
                                      for agent in self.agents)

    def shutdown(self):
        """Shutdown all mturk agents in parallel, otherwise if one mturk agent
        is disconnected then it could prevent other mturk agents from
        completing.
        """
        global shutdown_agent

        def shutdown_agent(agent):
            agent.shutdown()

        Parallel(n_jobs=len(self.agents),
                 backend='threading')(delayed(shutdown_agent)(agent)
                                      for agent in self.agents)
Example #17
0
class VQADataset(Dataset):
    """A Pytorch Dataset utilizing streaming.

    Serves VQA questions together with their image and, for any datatype
    other than ``test``, their annotated answers. Examples are encoded by a
    ``VqaDictionaryAgent``. Images are loaded from ``image_path`` through
    ``ImageLoader``, or, when ``opt['use_hdf5']`` is set, read from a
    pre-extracted hdf5 feature file.
    """

    def __init__(self, opt):
        """Load question/annotation json and, optionally, hdf5 image features.

        :param opt: option dict. Keys read directly here are ``attention``,
            ``use_hdf5``, ``datatype``, ``num_epochs`` and ``unittest``.
            ``opt`` is mutated: ``opt['use_hdf5_extraction']`` is set to
            the ``use_hdf5`` flag.
        :raises ImportError: if ``use_hdf5`` is set and h5py is not installed.
        """
        self.opt = opt
        self.use_att = opt.get('attention', False)
        self.use_hdf5 = opt.get('use_hdf5', False)
        # NOTE(review): presumably consumed by image/feature extraction
        # (ImageLoader / extract_feats); confirm.
        self.opt['use_hdf5_extraction'] = self.use_hdf5
        self.datatype = self.opt.get('datatype')
        self.training = self.datatype.startswith('train')
        self.num_epochs = self.opt.get('num_epochs', 0)
        self.image_loader = ImageLoader(opt)
        data_path, annotation_path, self.image_path = _path(opt)
        self._setup_data(data_path, annotation_path,
                         opt.get('unittest', False))
        if self.use_hdf5:
            try:
                import h5py
            except ImportError as err:
                raise ImportError(
                    'Need to install h5py - `pip install h5py`') from err
            self.h5py = h5py
            self._setup_image_data()
        self.dict_agent = VqaDictionaryAgent(opt)

    def __getitem__(self, index):
        """Return ``(index, encoded_examples)`` for the question at ``index``.

        ``index`` is wrapped modulo the number of questions, because
        ``__len__`` reports several passes over the data when training; the
        wrapped index is the one returned. ``encoded_examples`` is the list
        produced by the dictionary agent, with ``labels`` holding the raw
        answer strings (absent for ``test``), plus ``use_att`` and
        ``use_hdf5`` flags on its first element.

        When ``opt['extract_image']`` is set, the un-encoded example dict
        (with an extra ``image_id`` key) is returned instead of the tuple.
        """
        index %= self.num_episodes()
        qa = self.ques['questions'][index]
        ep = {
            'text': qa['question'],
            'image': self.get_image(qa['image_id']),
            'episode_done': True,
        }
        if self.opt.get('extract_image', False):
            ep['image_id'] = qa['image_id']
            return ep
        if not self.datatype.startswith('test'):
            anno = self.annotation['annotations'][index]
            labels = [ans['answer'] for ans in anno['answers']]
            # The encoder gets its own copy, so in-place changes made during
            # encoding cannot leak into the raw labels restored below.
            ep['labels'] = list(labels)
            # NOTE(review): answers are not checked against
            # dict_agent.ans2ind, so every example is marked valid; confirm
            # whether out-of-vocabulary answers should be flagged.
            ep['valid'] = True
            ep = self.dict_agent.encode_question([ep], self.training)
            ep = self.dict_agent.encode_answer(ep)
            # Expose the raw answer strings rather than the encoder's version.
            ep[0]['labels'] = labels
        else:
            ep['valid'] = True
            ep = self.dict_agent.encode_question([ep], False)
        ep[0]['use_att'] = self.use_att
        ep[0]['use_hdf5'] = self.use_hdf5
        return (index, ep)

    def __len__(self):
        """Return the number of items, counting every training epoch.

        Training reports ``num_epochs`` passes over the questions (100 when
        ``num_epochs`` is not positive); other datatypes report one pass.
        """
        num_epochs = self.num_epochs if self.num_epochs > 0 else 100
        num_iters = num_epochs if self.training else 1
        return int(num_iters * self.num_episodes())

    def _load_lens(self):
        """Load ``num_eps`` / ``num_exs`` from the json at ``length_datafile``.

        NOTE(review): nothing in this class sets ``self.length_datafile``, so
        this raises AttributeError unless a caller assigns it first.
        """
        with open(self.length_datafile) as length:
            lengths = json.load(length)
            self.num_eps = lengths['num_eps']
            self.num_exs = lengths['num_exs']

    def _setup_data(self, data_path, annotation_path, unittest):
        """Load questions (and annotations unless ``test``) from json.

        With ``unittest`` set, only the first 10 questions/annotations are
        kept. Also records the image file path of every question in
        ``self.image_paths``.
        """
        with open(data_path) as data_file:
            self.ques = json.load(data_file)
        if not self.datatype.startswith('test'):
            with open(annotation_path) as data_file:
                self.annotation = json.load(data_file)
        if unittest:
            self.ques['questions'] = self.ques['questions'][:10]
            if not self.datatype.startswith('test'):
                self.annotation['annotations'] = self.annotation[
                    'annotations'][:10]
        self.image_paths = set()
        for qa in self.ques['questions']:
            self.image_paths.add(self.image_path + '%012d.jpg' %
                                 (qa['image_id']))

    def _setup_image_data(self):
        '''hdf5 image dataset.

        Runs feature extraction, then opens the hdf5 file that matches
        ``image_mode`` and the attention setting, and loads the
        image-id -> row-index mapping used by ``get_image``.
        '''
        extract_feats(self.opt)
        im = self.opt.get('image_mode')
        if self.opt.get('attention', False):
            hdf5_path = self.image_path + 'mode_{}.hdf5'.format(im)
        else:
            hdf5_path = self.image_path + 'mode_{}_noatt.hdf5'.format(im)
        hdf5_file = self.h5py.File(hdf5_path, 'r')
        # Keep the handle: ``image_dataset`` is read lazily from this open
        # file in get_image, and holding it allows an explicit close later.
        self.hdf5_file = hdf5_file
        self.image_dataset = hdf5_file['images']

        image_id_to_idx_path = self.image_path + 'mode_{}_id_to_idx.txt'.format(
            im)
        with open(image_id_to_idx_path, 'r') as f:
            self.image_id_to_idx = json.load(f)

    def get_image(self, image_id):
        """Return the image for ``image_id``.

        Without hdf5, the zero-padded jpg at ``image_path`` is passed to
        ``image_loader.load``. With hdf5, the stored row is returned as a
        ``torch.Tensor``.
        """
        if not self.use_hdf5:
            im_path = self.image_path + '%012d.jpg' % (image_id)
            return self.image_loader.load(im_path)
        else:
            # json object keys are strings, so the id is looked up as str.
            img_idx = self.image_id_to_idx[str(image_id)]
            return torch.Tensor(self.image_dataset[img_idx])

    def num_episodes(self):
        """Return the number of questions (one episode per question)."""
        return len(self.ques['questions'])

    def num_examples(self):
        """Return the number of examples; equal to ``num_episodes()``."""
        return self.num_episodes()

    def num_images(self):
        """Return the number of distinct image ids, cached after first call."""
        if not hasattr(self, 'num_imgs'):
            self.num_imgs = len(
                {q['image_id']
                 for q in self.ques['questions']})
        return self.num_imgs