class FlickrDataset(Dataset): """A Pytorch Dataset utilizing streaming""" def __init__(self, opt, shared=None): self.opt = opt self.datatype = self.opt.get('datatype') self.training = self.datatype.startswith('train') self.num_epochs = self.opt.get('num_epochs', 0) self.image_loader = ImageLoader(opt) data_path, self.image_path = _path(opt) self._setup_data(data_path, opt.get('unittest', False)) self.dict_agent = DictionaryAgent(opt) def __getitem__(self, index): cap = self.data[index] image_id = int(cap['filename'].replace('.jpg', '')) ep = { 'text': QUESTION, 'image': self.get_image(image_id), 'episode_done': True, } if self.opt.get('extract_image', False): ep['image_id'] = image_id return ep ep['labels'] = [s['raw'] for s in cap['sentences']] ep['valid'] = True if 'train' not in self.datatype: ep['label_candidates'] = self.cands return (index, ep) def __len__(self): return self.num_episodes() def _setup_data(self, data_path, unittest): with open(data_path) as data_file: raw_data = json.load(data_file)['images'] if 'train' in self.datatype: self.data = [d for d in raw_data if d['split'] == 'train'] elif 'valid' in self.datatype: self.data = [d for d in raw_data if d['split'] == 'val'] self.cands = [l for d in self.data for l in [s['raw'] for s in d['sentences']]] else: self.data = [d for d in raw_data if d['split'] == 'test'] self.cands = [l for d in self.data for l in [s['raw'] for s in d['sentences']]] if unittest: self.data = self.data[:10] def get_image(self, image_id): im_path = os.path.join(self.image_path, '%d.jpg' % (image_id)) return self.image_loader.load(im_path) def num_episodes(self): return len(self.data) def num_examples(self): return self.num_episodes() def num_images(self): return self.num_episodes()
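A minimal sketch (not from the source) of draining a dataset like the one above with a torch DataLoader. Because __getitem__ returns (index, episode_dict) pairs whose values include PIL images and raw strings, the default tensor collation does not apply, so a pass-through collate_fn is assumed; building a complete ParlAI opt is out of scope here and left commented.

from torch.utils.data import DataLoader

def passthrough_collate(batch):
    # batch is a list of (index, episode_dict) tuples; return it unchanged
    return batch

# dataset = FlickrDataset(opt)  # opt built by ParlAI's option parser elsewhere
# loader = DataLoader(dataset, batch_size=4, shuffle=True,
#                     collate_fn=passthrough_collate)
# for pairs in loader:
#     for index, episode in pairs:
#         print(index, episode['text'])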
def test_other_image_modes(self): """ Test non-featurized image modes. """ with testing_utils.tempdir() as tmp: image_file = 'tmp.jpg' image_path = os.path.join(tmp, image_file) image_zip_path = os.path.join(tmp, 'tmp.zip') image = Image.new('RGB', (16, 16), color=0) with PathManager.open(image_path, 'wb') as fp: image.save(fp, 'JPEG') with zipfile.ZipFile(PathManager.open(image_zip_path, 'wb'), mode='w') as zipf: zipf.write(image_path, arcname=image_file) for im in ['raw', 'ascii']: loader = ImageLoader({"image_mode": im}) loader.load(image_path) loader.load(f"{image_zip_path}/{image_file}")
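The same two call patterns outside the test harness, as a short sketch: ImageLoader is configured with an image_mode entry in its opt dict, and a file inside a zip archive is addressed by appending the member name to the archive path (both patterns appear in the test above; the paths below are placeholders).

from parlai.core.image_featurizers import ImageLoader

raw_loader = ImageLoader({'image_mode': 'raw'})       # returns the PIL image
img = raw_loader.load('/path/to/photo.jpg')           # plain file on disk
zipped = raw_loader.load('/path/to/archive.zip/photo.jpg')  # member inside a zip

ascii_loader = ImageLoader({'image_mode': 'ascii'})   # returns a text rendering
print(ascii_loader.load('/path/to/photo.jpg'))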
def __init__(self, opt, agent, bot, image_idx: int, image_act: Message): super().__init__(opt, agent=agent, bot=bot) self.image_stack = opt['image_stack'] self.image_idx = image_idx self.image_act = image_act # Get a stringified version of the image to show the user orig_image = self.image_act['image'] self.image_src = get_image_src(image=orig_image) # Get a featurized version of the image to show the bot with NamedTemporaryFile(suffix='.jpg') as f: orig_image.save(f) image_loader = ImageLoader(self.bot.model_agent.opt) self.image_act.force_set('image', image_loader.load(f.name))
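The featurization step above, restated in isolation: ImageLoader works from file paths, so an in-memory PIL image is first persisted to a temporary file and then loaded back through the featurizer. The image_mode value below is a placeholder; the snippet above takes it from the bot's own opt.

from tempfile import NamedTemporaryFile
from PIL import Image
from parlai.core.image_featurizers import ImageLoader

pil_image = Image.new('RGB', (64, 64))
image_loader = ImageLoader({'image_mode': 'raw'})  # placeholder mode
with NamedTemporaryFile(suffix='.jpg') as f:
    pil_image.save(f)
    f.flush()  # make sure the bytes hit disk before reloading by name
    features = image_loader.load(f.name)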
class DialogData(object): """Provides a data structure for accessing textual dialog data. ``data_loader`` is an iterable, with each call returning: ``(x, ...), new_episode?`` Where - ``x`` is a query and possibly context ``...`` can contain additional fields, specifically - ``y`` is an iterable of label(s) for that query - ``c`` is an iterable of label candidates that the student can choose from - ``i`` is a str path to an image on disk, which will be loaded by the data class at request-time. should always point to the raw image file. - ``o`` is a json dump with other values - ``new_episode?`` is a boolean value specifying whether that example is the start of a new episode. If you don't use episodes set this to ``True`` every time. ``cands`` can be set to provide a list of candidate labels for every example in this dataset, which the agent can choose from (the correct answer should be in this set). ``random`` tells the data class whether or not to visit episodes sequentially or randomly when returning examples to the caller. """ def __init__(self, opt, data_loader=None, cands=None, shared=None, **kwargs): # self.data is a list of episodes # each episode is a tuple of entries # each entry is a tuple of values for the action/observation table if shared: self.image_loader = shared.get('image_loader', None) self.data = shared.get('data', []) self.cands = shared.get('cands', None) else: self.image_loader = ImageLoader(opt) self.data = [] self._load(data_loader, opt['datafile']) self.cands = None if cands == None else set(sys.intern(c) for c in cands) self.addedCands = [] self.copied_cands = False def share(self): shared = { 'data': self.data, 'cands': self.cands, 'image_loader': self.image_loader } return shared def __len__(self): """Returns total number of entries available. Each episode has at least one entry, but might have many more. """ return sum(len(episode) for episode in self.data) def _read_episode(self, data_generator): """Reads one episode at a time from the provided iterator over entries. 
""" episode = [] last_cands = None for entry, new in data_generator: if new and len(episode) > 0: yield tuple(episode) episode = [] last_cands = None # intern all strings so we don't store them more than once new_entry = [] if len(entry) > 0: # process text if available if entry[0] is not None: new_entry.append(sys.intern(entry[0])) else: new_entry.append(None) if len(entry) > 1: # process labels if available if entry[1] is None: new_entry.append(None) elif hasattr(entry[1], '__iter__') and type(entry[1]) is not str: # make sure iterable over labels, not single string new_entry.append(tuple(sys.intern(e) for e in entry[1])) else: raise TypeError('Must provide iterable over labels, not a single string.') if len(entry) > 2: # process label candidates if available if entry[2] is None: new_entry.append(None) elif last_cands and entry[2] is last_cands: # if cands are shared, say "same" so we # don't store them again new_entry.append( sys.intern('same as last time')) elif hasattr(entry[2], '__iter__') and type(entry[2]) is not str: # make sure iterable over candidates, not single string last_cands = entry[2] new_entry.append(tuple( sys.intern(e) for e in entry[2])) else: raise TypeError('Must provide iterable over label candidates, not a single string.') if len(entry) > 3: if entry[3] is None: new_entry.append(None) else: new_entry.append(sys.intern(entry[3])) # process other values if len(entry) > 4 and entry[4] is not None: new_entry.append(sys.intern(entry[4])) episode.append(tuple(new_entry)) if len(episode) > 0: yield tuple(episode) def _load(self, data_loader, datafile): """Loads up data from an iterator over tuples described in the class docs. """ for episode in self._read_episode(data_loader(datafile)): self.data.append(episode) def num_episodes(self): """Return number of episodes in the dataset.""" return len(self.data) def get(self, episode_idx, entry_idx=0): """Returns a specific entry from the dataset.""" # first look up data episode = self.data[episode_idx] entry = episode[entry_idx] episode_done = entry_idx == len(episode) - 1 end_of_data = episode_done and episode_idx == len(self.data) - 1 # now pack it in a action-observation dictionary table = self.build_table(entry) # last entry in this episode table['episode_done'] = episode_done return table, end_of_data def build_table(self, entry): """Packs an entry into an action-observation dictionary.""" table = {} if entry[0] is not None: table['text'] = entry[0] if len(entry) > 1: if entry[1] is not None: table['labels'] = entry[1] if len(entry) > 2: if entry[2] is not None: table['label_candidates'] = entry[2] if len(entry) > 3 and entry[3] is not None: img = self.image_loader.load(entry[3]) if img is not None: table['image'] = img if len(entry) > 4 and entry[4] is not None: for key, val in json.loads(entry[4]).items(): table[key] = val if (table.get('labels', None) is not None and self.cands is not None): if self.addedCands: # remove elements in addedCands self.cands.difference_update(self.addedCands) self.addedCands.clear() for label in table['labels']: if label not in self.cands: # add labels, queue them for removal next time if not self.copied_cands: self.cands = self.cands.copy() self.copied_cands = True self.cands.add(label) self.addedCands.append(label) table['label_candidates'] = self.cands if 'labels' in table and 'label_candidates' in table: if table['labels'][0] not in table['label_candidates']: raise RuntimeError('true label missing from candidate labels') return table
class DefaultDataset(Dataset): """ A Pytorch Dataset. """ def __init__(self, opt): self.opt = opt opt['image_load_task'] = 'personality_captions' self.image_mode = opt.get('image_mode', 'no_image_model') self.datatype = self.opt.get('datatype') self.training = self.datatype.startswith('train') self.include_image = opt.get('include_image') self.include_personality = opt.get('include_personality') self.num_test_labels = opt.get('num_test_labels', 1) data_path, personalities_data_path, self.image_path = _path(opt) self.image_loader = ImageLoader(opt) self._setup_data(data_path, personalities_data_path) @staticmethod def add_cmdline_args(argparser): """ Add command line args. """ PersonalityCaptionsTeacher.add_cmdline_args(argparser) def _setup_data(self, data_path, personalities_data_path): print('loading: ' + data_path) with open(data_path) as f: self.data = json.load(f) with open(personalities_data_path) as f: self.personalities = json.load(f) def __getitem__(self, index): data = self.data[index] image = self.get_image(data['image_hash']) ep = { 'text': data['personality'] if self.include_personality else '', 'episode_done': True, 'image': image if self.include_image else None, } if self.opt.get('extract_image', False): ep['image_id'] = data['image_hash'] return ep ep['labels'] = [data['comment']] if self.num_test_labels == 5 and 'test' in self.datatype: ep['labels'] += data['additional_comments'] if not self.training: if self.num_test_labels == 5 and 'test' in self.datatype: ep['label_candidates'] = data['500_candidates'] else: ep['label_candidates'] = data['candidates'] return (index, ep) def __len__(self): return self.num_episodes() def get_image(self, image_id): """ Get image. :param image_id: id of the image :return: image from the image loader """ im_path = os.path.join(self.image_path, '{}.jpg'.format(image_id)) return self.image_loader.load(im_path) def num_episodes(self): """ Return number of episodes. """ return len(self.data) def num_examples(self): """ Return number of examples. """ return self.num_episodes() def num_images(self): """ Return number of images. """ if not hasattr(self, 'num_imgs'): self.num_imgs = len({d['image_num'] for d in self.data}) return self.num_imgs
class FlickrDataset(Dataset): """A Pytorch Dataset utilizing streaming""" def __init__(self, opt): self.opt = opt self.use_hdf5 = opt.get('use_hdf5', False) self.datatype = self.opt.get('datatype') self.training = self.datatype.startswith('train') self.num_epochs = self.opt.get('num_epochs', 0) self.image_loader = ImageLoader(opt) caption_path, self.image_path = _path(opt) self._setup_data(caption_path, opt.get('unittest', False)) if self.use_hdf5: try: import h5py self.h5py = h5py except ModuleNotFoundError: raise ModuleNotFoundError('Need to install h5py - `pip install h5py`') self._setup_image_data() self.dict_agent = DictionaryAgent(opt) def __getitem__(self, index): index %= self.num_episodes() cap = self.caption[index] ep = { 'text': self.dict_agent.txt2vec(QUESTION), 'image': self.get_image(cap['image_id']), 'episode_done': True, } if self.opt.get('extract_image', False): ep['image_id'] = cap['image_id'] return ep ep['labels'] = [self.dict_agent.txt2vec(cc) for cc in cap['captions']] ep['valid'] = True ep['use_hdf5'] = self.use_hdf5 return (index, ep) def __len__(self): num_epochs = self.num_epochs if self.num_epochs > 0 else 100 num_iters = num_epochs if self.training else 1 return int(num_iters * self.num_episodes()) def _load_lens(self): with open(self.length_datafile) as length: lengths = json.load(length) self.num_eps = lengths['num_eps'] self.num_exs = lengths['num_exs'] def _setup_data(self, caption_path, unittest): with open(caption_path) as data_file: self.caption = [] prev_img_id = None for line in data_file: img_id = line.split('#')[0][:-4] caption = line.split('\t')[1] if img_id != prev_img_id: prev_img_id = img_id to_add = {} to_add['image_id'] = int(img_id) to_add['captions'] = [caption] self.caption.append(to_add) else: self.caption[-1]['captions'].append(caption) if unittest: self.caption = self.caption[:10] self.image_paths = set() for cap in self.caption: self.image_paths.add(os.path.join(self.image_path, '%d.jpg' % (cap['image_id']))) def _setup_image_data(self): '''hdf5 image dataset''' extract_feats(self.opt) im = self.opt.get('image_mode') hdf5_path = self.image_path + 'mode_{}_noatt.hdf5'.format(im) hdf5_file = self.h5py.File(hdf5_path, 'r') self.image_dataset = hdf5_file['images'] image_id_to_idx_path = self.image_path + 'mode_{}_id_to_idx.txt'.format(im) with open(image_id_to_idx_path, 'r') as f: self.image_id_to_idx = json.load(f) def get_image(self, image_id): if not self.use_hdf5: im_path = os.path.join(self.image_path, '%d.jpg' % (image_id)) return self.image_loader.load(im_path) else: img_idx = self.image_id_to_idx[str(image_id)] return torch.Tensor(self.image_dataset[img_idx]) def num_episodes(self): return len(self.caption) def num_examples(self): return self.num_episodes() def num_images(self): return self.num_episodes()
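When use_hdf5 is on, get_image() above bypasses ImageLoader and reads a precomputed feature row from the mode_<image_mode>_noatt.hdf5 file, using the companion *_id_to_idx.txt json to map image ids to row indices. A standalone sketch of that lookup (prefix, mode, and image id are made up):

import json
import h5py
import torch

image_path = '/data/flickr30k/'   # hypothetical prefix used to build the file names
mode = 'resnet152'                # hypothetical image_mode
with h5py.File(image_path + 'mode_{}_noatt.hdf5'.format(mode), 'r') as feats_file:
    with open(image_path + 'mode_{}_id_to_idx.txt'.format(mode)) as idx_file:
        id_to_idx = json.load(idx_file)
    row = feats_file['images'][id_to_idx['36979']]  # '36979' is an invented image id
    feature = torch.Tensor(row)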
class DefaultDataset(Dataset): """A Pytorch Dataset utilizing streaming.""" def __init__(self, opt, version='2014'): self.opt = opt self.use_hdf5 = opt.get('use_hdf5', False) self.datatype = self.opt.get('datatype') self.training = self.datatype.startswith('train') self.num_epochs = self.opt.get('num_epochs', 0) self.image_loader = ImageLoader(opt) test_info_path, annotation_path, self.image_path = _path(opt, version) self._setup_data(test_info_path, annotation_path, opt.get('unittest', False)) if self.use_hdf5: try: import h5py self.h5py = h5py except ImportError: raise ImportError('Need to install h5py - `pip install h5py`') self._setup_image_data() self.dict_agent = DictionaryAgent(opt) def __getitem__(self, index): index %= self.num_episodes() image_id = None if not self.datatype.startswith('test'): anno = self.annotation['annotations'][index] image_id = anno['image_id'] else: image_id = self.test_info['images'][index]['id'] ep = { 'text': self.dict_agent.txt2vec(QUESTION), 'image': self.get_image(image_id), 'episode_done': True, } if self.opt.get('extract_image', False): ep['image_id'] = image_id return ep if not self.datatype.startswith('test'): anno = self.annotation['annotations'][index] ep['labels'] = [anno['caption']] ep['valid'] = True else: ep['valid'] = True ep['use_hdf5'] = self.use_hdf5 return (index, ep) def __len__(self): num_epochs = self.num_epochs if self.num_epochs > 0 else 100 num_iters = num_epochs if self.training else 1 return int(num_iters * self.num_episodes()) def _load_lens(self): with open(self.length_datafile) as length: lengths = json.load(length) self.num_eps = lengths['num_eps'] self.num_exs = lengths['num_exs'] def _setup_data(self, test_info_path, annotation_path, unittest): if not self.datatype.startswith('test'): with open(annotation_path) as data_file: self.annotation = json.load(data_file) else: with open(test_info_path) as data_file: self.test_info = json.load(data_file) if unittest: if not self.datatype.startswith('test'): self.annotation['annotations'] = self.annotation[ 'annotations'][:10] else: self.test_info['images'] = self.test_info['images'][:10] self.image_paths = set() # Depending on whether we are using the train/val/test set, we need to # find the image IDs in annotations or test image info if not self.datatype.startswith('test'): for anno in self.annotation['annotations']: self.image_paths.add( os.path.join(self.image_path, '%012d.jpg' % (anno['image_id']))) else: for info in self.test_info['images']: self.image_paths.add( os.path.join(self.image_path, '%012d.jpg' % (info['id']))) def _setup_image_data(self): '''hdf5 image dataset''' extract_feats(self.opt) im = self.opt.get('image_mode') hdf5_path = os.path.join(self.image_path, 'mode_{}_noatt.hdf5'.format(im)) hdf5_file = self.h5py.File(hdf5_path, 'r') self.image_dataset = hdf5_file['images'] image_id_to_idx_path = os.path.join(self.image_path, 'mode_{}_id_to_idx.txt'.format(im)) with open(image_id_to_idx_path, 'r') as f: self.image_id_to_idx = json.load(f) def get_image(self, image_id): if not self.use_hdf5: im_path = os.path.join(self.image_path, '%012d.jpg' % (image_id)) return self.image_loader.load(im_path) else: img_idx = self.image_id_to_idx[str(image_id)] return torch.Tensor(self.image_dataset[img_idx]) def num_examples(self): if not self.datatype.startswith('test'): return len(self.annotation['annotations']) else: return len(self.test_info['images']) def num_episodes(self): return self.num_examples() def num_images(self): if not hasattr(self, 'num_imgs'): return 
self.num_examples() return self.num_imgs
class DefaultDataset(Dataset): """A Pytorch Dataset""" def __init__(self, opt): self.opt = opt opt['image_load_task'] = 'image_chat' self.image_mode = opt.get('image_mode', 'none') self.datatype = self.opt.get('datatype') self.training = self.datatype.startswith('train') self.include_image = opt.get('include_image') self.include_personality = opt.get('include_personality') self.num_cands = opt.get('num_cands') data_path, personalities_data_path, self.image_path = _path(opt) self.image_loader = ImageLoader(opt) self._setup_data(data_path, personalities_data_path) @staticmethod def add_cmdline_args(argparser): ImageChatTeacher.add_cmdline_args(argparser) def _setup_data(self, data_path, personalities_data_path): print('loading: ' + data_path) with open(data_path) as f: self.data = json.load(f) with open(personalities_data_path) as f: self.personalities = json.load(f) def __getitem__(self, index): data = self.data[index] dialog = data['dialog'] personality = dialog[-1][0] text = '' if len(dialog) > 1: text = '\n'.join((dialog[0][1], dialog[1][1])) if self.include_personality: text += personality label = dialog[-1][1] image = self.get_image(data['image_hash']) ep = { 'text': text, 'episode_done': True, 'image': image if self.include_image else None, } if self.opt.get('extract_image', False): ep['image_id'] = data['image_hash'] return ep if not self.opt['datatype'].startswith('test'): ep['labels'] = [label] if not self.training: ep['label_candidates'] = data['candidates'][-1][self.num_cands] return (index, ep) def __len__(self): return self.num_episodes() def get_image(self, image_id): im_path = os.path.join(self.image_path, '{}.jpg'.format(image_id)) return self.image_loader.load(im_path) def num_episodes(self): return len(self.data) def num_examples(self): return self.num_episodes() def num_images(self): if not hasattr(self, 'num_imgs'): self.num_imgs = len({d['image_num'] for d in self.data}) return self.num_imgs
class DefaultDataset(Dataset): """A Pytorch Dataset utilizing streaming.""" def __init__(self, opt, version='2014'): self.opt = opt self.use_intro = opt.get('use_intro') self.num_cands = opt.get('num_cands') self.use_hdf5 = opt.get('use_hdf5', False) self.datatype = self.opt.get('datatype') self.training = self.datatype.startswith('train') self.num_epochs = self.opt.get('num_epochs', 0) self.image_loader = ImageLoader(opt) test_info_path, annotation_path, self.image_path = _path(opt, version) self.cands = load_candidates(opt['datapath'], self.datatype, version) self._setup_data(test_info_path, annotation_path, opt.get('unittest', False)) if self.use_hdf5: try: import h5py self.h5py = h5py except ImportError: raise ImportError('Need to install h5py - `pip install h5py`') self._setup_image_data() self.dict_agent = DictionaryAgent(opt) @staticmethod def add_cmdline_args(argparser): agent = argparser.add_argument_group('Comment Battle arguments') agent.add_argument('--use_intro', type="bool", default=False, help='Include an intro question with each image \ for readability (e.g. for coco_caption, \ Describe the above picture in a sentence.)') agent.add_argument('--num_cands', type=int, default=150, help='Number of candidates to use during \ evaluation, setting to -1 uses all.') def __getitem__(self, index): index %= self.num_episodes() image_id = None if not self.datatype.startswith('test'): anno = self.annotation['annotations'][index] image_id = anno['image_id'] else: image_id = self.test_info['images'][index]['id'] ep = { 'image': self.get_image(image_id), 'episode_done': True, } if self.use_intro: ep['text'] = QUESTION if self.opt.get('extract_image', False): ep['image_id'] = image_id return ep if not self.datatype.startswith('test'): anno = self.annotation['annotations'][index] ep['labels'] = [anno['caption']] if not self.datatype.startswith('train'): if self.num_cands == -1: candidates = self.cands else: # Can only randomly select from validation set candidates = random.Random(index).choices(self.cands, k=self.num_cands) if anno['caption'] not in candidates: candidates.pop(0) else: candidates.remove(anno['caption']) candidate_labels = [anno['caption']] candidate_labels += candidates ep['label_candidates'] = candidate_labels else: # TESTING if self.num_cands == -1: candidates = self.cands else: # Can select from train+test set candidates = random.Random(index).choices(self.cands, k=self.num_cands) ep['label_candidates'] = candidates ep['use_hdf5'] = self.use_hdf5 return (index, ep) def __len__(self): num_epochs = self.num_epochs if self.num_epochs > 0 else 100 num_iters = num_epochs if self.training else 1 return int(num_iters * self.num_episodes()) def _load_lens(self): with open(self.length_datafile) as length: lengths = json.load(length) self.num_eps = lengths['num_eps'] self.num_exs = lengths['num_exs'] def _setup_data(self, test_info_path, annotation_path, unittest): if not self.datatype.startswith('test'): with open(annotation_path) as data_file: self.annotation = json.load(data_file) else: with open(test_info_path) as data_file: self.test_info = json.load(data_file) if unittest: if not self.datatype.startswith('test'): self.annotation['annotations'] = self.annotation[ 'annotations'][:10] else: self.test_info['images'] = self.test_info['images'][:10] self.image_paths = set() # Depending on whether we are using the train/val/test set, we need to # find the image IDs in annotations or test image info if not self.datatype.startswith('test'): for anno in self.annotation['annotations']: 
self.image_paths.add(self.image_path + '%012d.jpg' % (anno['image_id'])) else: for info in self.test_info['images']: self.image_paths.add(self.image_path + '%012d.jpg' % (info['id'])) def _setup_image_data(self): '''hdf5 image dataset''' extract_feats(self.opt) im = self.opt.get('image_mode') hdf5_path = os.path.join(self.image_path, 'mode_{}_noatt.hdf5'.format(im)) hdf5_file = self.h5py.File(hdf5_path, 'r') self.image_dataset = hdf5_file['images'] image_id_to_idx_path = os.path.join(self.image_path, 'mode_{}_id_to_idx.txt'.format(im)) with open(image_id_to_idx_path, 'r') as f: self.image_id_to_idx = json.load(f) def get_image(self, image_id): if not self.use_hdf5: im_path = os.path.join(self.image_path, '%012d.jpg' % (image_id)) return self.image_loader.load(im_path) else: img_idx = self.image_id_to_idx[str(image_id)] return torch.Tensor(self.image_dataset[img_idx]) def num_examples(self): if not self.datatype.startswith('test'): return len(self.annotation['annotations']) else: return len(self.test_info['images']) def num_episodes(self): return self.num_examples() def num_images(self): if not hasattr(self, 'num_imgs'): return self.num_examples() return self.num_imgs
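The evaluation branch above builds candidate lists deterministically: sampling is seeded with the example index so repeated passes see the same candidates, and one sampled copy of the gold caption (or the first candidate) is dropped before the gold is placed at the front. The same logic as a standalone helper (names are invented):

import random

def build_candidates(index, gold, all_captions, num_cands=150):
    candidates = random.Random(index).choices(all_captions, k=num_cands)
    if gold in candidates:
        candidates.remove(gold)   # drop one sampled copy of the gold caption
    else:
        candidates.pop(0)         # keep the final list at num_cands entries
    return [gold] + candidates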
class DefaultDataset(Dataset): """A Pytorch Dataset utilizing streaming.""" def __init__(self, opt, version='2017'): self.opt = opt self.version = version self.use_intro = opt.get('use_intro', False) self.num_cands = opt.get('num_cands', -1) self.datatype = self.opt.get('datatype') self.include_rest_val = opt.get('include_rest_val', True) self.image_loader = ImageLoader(opt) test_info_path, annotation_path, self.image_path = _path(opt, version) self._setup_data(test_info_path, annotation_path, opt) @staticmethod def add_cmdline_args(argparser): DefaultTeacher.add_cmdline_args(argparser) def __getitem__(self, index): ep = {'episode_done': True} if self.use_intro: ep['text'] = QUESTION if hasattr(self, 'annotation'): anno = self.annotation[index] else: anno = self.test_info['images'][index] if self.version == '2014': ep['labels'] = [s['raw'] for s in anno['sentences']] ep['image_id'] = anno['cocoid'] ep['split'] = anno['split'] elif not self.datatype.startswith('test'): ep['image_id'] = anno['image_id'] ep['labels'] = [anno['caption']] else: ep['image_id'] = anno['id'] ep['image'] = self.get_image(ep['image_id'], anno.get('split', None)) if self.opt.get('extract_image', False): return ep # Add Label Cands if not self.datatype.startswith('train'): if self.num_cands == -1: ep['label_candidates'] = self.cands else: candidates = random.Random(index).choices(self.cands, k=self.num_cands) label = random.choice(ep.get('labels', [''])) if not (label == '' or label in candidates): candidates.pop(0) candidates.append(label) random.shuffle(candidates) ep['label_candidates'] = candidates return (index, ep) def __len__(self): return self.num_episodes() def _load_lens(self): with open(self.length_datafile) as length: lengths = json.load(length) self.num_eps = lengths['num_eps'] self.num_exs = lengths['num_exs'] def _setup_data(self, test_info_path, annotation_path, opt): if self.version == '2014': with open(annotation_path) as data_file: raw_data = json.load(data_file)['images'] if 'train' in self.datatype: self.annotation = [ d for d in raw_data if d['split'] == 'train' ] if self.include_rest_val: self.annotation += [ d for d in raw_data if d['split'] == 'restval' ] elif 'valid' in self.datatype: self.annotation = [d for d in raw_data if d['split'] == 'val'] self.cands = [ l for d in self.annotation for l in [s['raw'] for s in d['sentences']] ] else: self.annotation = [d for d in raw_data if d['split'] == 'test'] self.cands = [ l for d in self.annotation for l in [s['raw'] for s in d['sentences']] ] else: if not self.datatype.startswith('test'): print('loading: ' + annotation_path) with open(annotation_path) as data_file: self.annotation = json.load(data_file)['annotations'] else: print('loading: ' + test_info_path) with open(test_info_path) as data_file: self.test_info = json.load(data_file) if not self.datatype.startswith('train'): self.cands = load_candidates(opt['datapath'], opt['datatype'], self.version) if opt.get('unittest', False): if not self.datatype.startswith('test'): self.annotation = self.annotation[:10] else: self.test_info['images'] = self.test_info['images'][:10] def get_image(self, image_id, split): if split == 'restval': im_path = self.image_path.replace('train', 'val') else: im_path = self.image_path im_path = os.path.join(im_path, '%012d.jpg' % (image_id)) return self.image_loader.load(im_path) def num_examples(self): if self.version == '2014' or not self.datatype.startswith('test'): return len(self.annotation) else: # For 2017, we only have annotations for the train and val sets, # 
so for the test set we need to determine how many images we have. return len(self.test_info['images']) def num_episodes(self): return self.num_examples() def num_images(self): if not hasattr(self, 'num_imgs'): return self.num_examples() return self.num_imgs
class OeTeacher(Teacher): """VQA v2.0 Open-Ended teacher, which loads the json VQA data and implements its own `act` method for interacting with the student agent. """ def __init__(self, opt, shared=None): super().__init__(opt) self.datatype = opt['datatype'] data_path, annotation_path, self.image_path = _path(opt) if shared and 'ques' in shared: self.ques = shared['ques'] if 'annotation' in shared: self.annotation = shared['annotation'] self.image_loader = shared['image_loader'] else: self._setup_data(data_path, annotation_path) self.image_loader = ImageLoader(opt) self.len = len(self.ques['questions']) # for ordered data in batch mode (especially, for validation and # testing), each teacher in the batch gets a start index and a step # size so they all process disparate sets of the data self.step_size = opt.get('batchsize', 1) self.data_offset = opt.get('batchindex', 0) self.reset() def __len__(self): return self.len def reset(self): # Reset the dialog so that it is at the start of the epoch, # and all metrics are reset. super().reset() self.lastY = None self.episode_idx = self.data_offset - self.step_size def observe(self, observation): """Process observation for metrics.""" if self.lastY is not None: self.metrics.update(observation, self.lastY) self.lastY = None return observation def act(self): if self.datatype == 'train': self.episode_idx = random.randrange(self.len) else: self.episode_idx = (self.episode_idx + self.step_size) % len(self) if self.episode_idx == len(self) - self.step_size: self.epochDone = True qa = self.ques['questions'][self.episode_idx] question = qa['question'] image_id = qa['image_id'] img_path = self.image_path + '%012d.jpg' % (image_id) action = { 'image': self.image_loader.load(img_path), 'text': question, 'episode_done': True } if not self.datatype.startswith('test'): anno = self.annotation['annotations'][self.episode_idx] self.lastY = [ans['answer'] for ans in anno['answers']] if self.datatype.startswith('train'): action['labels'] = self.lastY return action def share(self): shared = super().share() shared['ques'] = self.ques if hasattr(self, 'annotation'): shared['annotation'] = self.annotation shared['image_loader'] = self.image_loader return shared def _setup_data(self, data_path, annotation_path): print('loading: ' + data_path) with open(data_path) as data_file: self.ques = json.load(data_file) if self.datatype != 'test': print('loading: ' + annotation_path) with open(annotation_path) as data_file: self.annotation = json.load(data_file)
class QADataCollectionWorld(MTurkTaskWorld): """ World for recording a turker's question and answer given a context. Assumes the context is a random context from a given task, e.g. from SQuAD, CBT, etc. """ collector_agent_id = 'QA Collector' def __init__(self, opt, task, mturk_agent): self.task = task self.mturk_agent = mturk_agent self.episodeDone = False self.turn_index = -1 self.context = None self.question = None self.answer = None self.image_loader = ImageLoader(opt) def parley(self): # Each turn starts from the QA Collector agent # self.turn_index = (self.turn_index + 1) % 2 self.turn_index = self.turn_index + 1 ad = {'episode_done': False} ad['id'] = self.__class__.collector_agent_id if self.turn_index == 0: # At the first turn, the QA Collector agent provides the context # and prompts the turker to ask a question regarding the context # Get context from SQuAD teacher agent qa = self.task.act() self.context = '\n'.join(qa['text'].split('\n')[:-1]) context_list = self.context.split('\n') table = "<table border=\"1\"><tr> <td>{}</td> <td>{}</td> </tr> <tr> <td>{}</td> <td>{}</td> </tr> </table>".format( context_list[0], context_list[1], context_list[2], context_list[3]) # Wrap the context with a prompt telling the turker what to do next # ad['text'] = ( # self.context + '\n\nPlease provide a question given this context.' # ) ad['text'] = (self.context + '\n\nPlease provide a question given this context.') img = self.image_loader.load( "/Users/shiquan/PycharmProjects/ParlAI/parlai/mturk/tasks/qa_data_collection/banana.jpg" ) buffered = BytesIO() img.save(buffered, format="JPEG") encoded = str( base64.b64encode(buffered.getvalue()).decode('ascii')) ad['image'] = encoded self.mturk_agent.observe(validate(ad)) self.question = self.mturk_agent.act() # Can log the turker's question here if self.turn_index >= 1 and self.turn_index < 5: # At subsequent turns, the QA Collector collects the turker's # question from the previous turn, and then prompts the # turker to provide the answer # A prompt telling the turker what to do next ad['text'] = 'Thanks. And what is the answer to your question?' self.mturk_agent.observe(validate(ad)) self.answer = self.mturk_agent.act() if self.turn_index == 5: # At the final turn, the QA Collector collects the turker's # question from the previous turn, prompts the # turker to provide the answer, and ends the episode # A prompt telling the turker what to do next ad['text'] = 'Thanks. And what is the answer to your question?' ad['episode_done'] = True # end of episode self.mturk_agent.observe(validate(ad)) self.answer = self.mturk_agent.act() self.episodeDone = True def episode_done(self): return self.episodeDone def shutdown(self): self.task.shutdown() self.mturk_agent.shutdown() def review_work(self): # Can review the work here to accept or reject it pass def get_custom_task_data(self): # brings important data together for the task, to later be used for # creating the dataset. If data requires pickling, put it in a field # called 'needs-pickle'. return {'context': self.context, 'acts': [self.question, self.answer]}
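Both MTurk worlds in this section ship the image to the worker as a base64 string rather than as a file path. That encoding step in isolation (the input path is a placeholder):

import base64
from io import BytesIO
from PIL import Image

img = Image.open('/path/to/banana.jpg')
buffered = BytesIO()
img.save(buffered, format='JPEG')
encoded = base64.b64encode(buffered.getvalue()).decode('ascii')
message = {'id': 'SYSTEM', 'text': 'Please look at this image.', 'image': encoded, 'episode_done': False}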
class SplitTeacher(Teacher): """FVQA Teacher, which loads the json VQA data and implements its own `act` method for interacting with student agent. Use "fvqa:split:X" to choose between splits 0-4 (inclusive), or just "fvqa" to use the default split (0). """ def __init__(self, opt, shared=None): super().__init__(opt) dt = opt['datatype'].split(':')[0] if dt not in ('train', 'test'): raise RuntimeError('Not valid datatype (only train/test).') task = opt.get('task', 'fvqa:split:0') task_num = 0 # default to train/split 0 split = task.split(':') if len(split) > 2: task_num = split[2] if task_num not in [str(i) for i in range(5)]: raise RuntimeError('Invalid train/test split ID (0-4 inclusive)') if not hasattr(self, 'factmetrics'): if shared and shared.get('factmetrics'): self.factmetrics = shared['factmetrics'] else: self.factmetrics = Metrics(opt) self.datatype = opt['datatype'] questions_path, trainset_path, self.image_path = _path(opt) if shared and 'ques' in shared: self.ques = shared['ques'] else: self._setup_data(questions_path, trainset_path, dt, task_num) self.len = len(self.ques) self.asked_question = False # for ordered data in batch mode (especially, for validation and # testing), each teacher in the batch gets a start index and a step # size so they all process disparate sets of the data self.step_size = opt.get('batchsize', 1) self.data_offset = opt.get('batchindex', 0) self.image_loader = ImageLoader(opt) self.reset() def __len__(self): return self.len def report(self): r = super().report() r['factmetrics'] = self.factmetrics.report() return r def reset(self): # Reset the dialog so that it is at the start of the epoch, # and all metrics are reset. super().reset() self.lastY = None self.episode_idx = self.data_offset - self.step_size self.epochDone = False def reset_metrics(self): super().reset_metrics() self.factmetrics.clear() def observe(self, observation): """Process observation for metrics.""" if self.lastY is not None: if self.asked_question: self.metrics.update(observation, self.lastY[0]) else: self.factmetrics.update(observation, self.lastY[1]) self.lastY = None return observation def act(self): if self.asked_question: self.asked_question = False action = {'text': 'Which fact supports this answer?', 'episode_done': True} if self.datatype.startswith('train'): action['labels'] = self.lastY[1] if self.datatype != 'train' and self.episode_idx + self.step_size >= len(self): self.epochDone = True return action if self.datatype == 'train': self.episode_idx = random.randrange(self.len) else: self.episode_idx = (self.episode_idx + self.step_size) % len(self) self.asked_question = True qa = self.ques[self.episode_idx] question = qa['question'] img_path = self.image_path + qa['img_file'] action = { 'image': self.image_loader.load(img_path), 'text': question, 'episode_done': False } human_readable = qa['fact_surface'].replace('[', '').replace(']', '') self.lastY = [[qa['answer']], [human_readable]] if self.datatype.startswith('train'): action['labels'] = self.lastY[0] return action def share(self): shared = super().share() shared['factmetrics'] = self.factmetrics shared['ques'] = self.ques if hasattr(self, 'facts'): shared['facts'] = self.facts return shared def _setup_data(self, questions_path, trainset_path, datatype, task_num): print('loading: ' + questions_path) with open(questions_path) as questions_file: questions = json.load(questions_file) train_test_images = set() with open(os.path.join(trainset_path, '{}_list_{}.txt'.format(datatype, task_num))) as imageset: for line in 
imageset: train_test_images.add(line.strip()) self.ques = [questions[k] for k in sorted(questions.keys()) if questions[k]['img_file'] in train_test_images]
class OeTeacher(Teacher): """ VQA Open-Ended teacher, which loads the json vqa data and implements its own `act` method for interacting with student agent. """ def __init__(self, opt, shared=None): super().__init__(opt, shared) self.datatype = opt['datatype'] data_path, annotation_path, self.image_path = _path(opt) if shared and 'ques' in shared: self.ques = shared['ques'] if 'annotation' in shared: self.annotation = shared['annotation'] else: self._setup_data(data_path, annotation_path) # for ordered data in batch mode (especially, for validation and # testing), each teacher in the batch gets a start index and a step # size so they all process disparate sets of the data self.step_size = opt.get('batchsize', 1) self.data_offset = opt.get('batchindex', 0) self.image_loader = ImageLoader(opt) self.reset() def __len__(self): return len(self.ques['questions']) def reset(self): # Reset the dialog so that it is at the start of the epoch, # and all metrics are reset. super().reset() self.lastY = None self.episode_idx = self.data_offset - self.step_size def observe(self, observation): """Process observation for metrics.""" if self.lastY is not None: self.metrics.update(observation, self.lastY) self.lastY = None return observation def act(self): if self.datatype == 'train': self.episode_idx = random.randrange(len(self)) else: self.episode_idx = (self.episode_idx + self.step_size) % len(self) if self.episode_idx == len(self) - self.step_size: self.epochDone = True qa = self.ques['questions'][self.episode_idx] question = qa['question'] image_id = qa['image_id'] img_path = self.image_path + '%012d.jpg' % (image_id) action = { 'image': self.image_loader.load(img_path), 'text': question, 'episode_done': True } if not self.datatype.startswith('test'): anno = self.annotation['annotations'][self.episode_idx] self.lastY = [ans['answer'] for ans in anno['answers']] if self.datatype.startswith('train'): action['labels'] = self.lastY return action def share(self): shared = super().share() shared['ques'] = self.ques if hasattr(self, 'annotation'): shared['annotation'] = self.annotation return shared def _setup_data(self, data_path, annotation_path): print('loading: ' + data_path) with open(data_path) as data_file: self.ques = json.load(data_file) if self.datatype != 'test': print('loading: ' + annotation_path) with open(annotation_path) as data_file: self.annotation = json.load(data_file)
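The offset/stride bookkeeping shared by the teachers in this section (step_size = batchsize, data_offset = batchindex) partitions ordered data across a batch of teachers: teacher k starts at k - batchsize and advances by batchsize on every act(), so the teachers visit disjoint example indices. A tiny worked illustration with invented sizes:

num_examples = 10
step_size = 3                          # opt['batchsize']
for data_offset in range(step_size):   # opt['batchindex'] for each teacher
    episode_idx = data_offset - step_size
    visited = []
    for _ in range(4):                 # four calls to act()
        episode_idx = (episode_idx + step_size) % num_examples
        visited.append(episode_idx)
    print(data_offset, visited)        # teacher 0 -> [0, 3, 6, 9], teacher 1 -> [1, 4, 7, 0], ...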
class SplitTeacher(Teacher): """FVQA Teacher, which loads the json VQA data and implements its own `act` method for interacting with student agent. Use "fvqa:split:X" to choose between splits 0-4 (inclusive), or just "fvqa" to use the default split (0). """ def __init__(self, opt, shared=None): super().__init__(opt) dt = opt['datatype'].split(':')[0] if dt not in ('train', 'test'): raise RuntimeError('Not valid datatype (only train/test).') task = opt.get('task', 'fvqa:split:0') task_num = 0 # default to train/split 0 split = task.split(':') if len(split) > 2: task_num = split[2] if task_num not in [str(i) for i in range(5)]: raise RuntimeError( 'Invalid train/test split ID (0-4 inclusive)') if not hasattr(self, 'factmetrics'): if shared and shared.get('factmetrics'): self.factmetrics = shared['factmetrics'] else: self.factmetrics = Metrics(opt) self.datatype = opt['datatype'] questions_path, trainset_path, self.image_path = _path(opt) if shared and 'ques' in shared: self.ques = shared['ques'] else: self._setup_data(questions_path, trainset_path, dt, task_num) self.len = len(self.ques) self.asked_question = False # for ordered data in batch mode (especially, for validation and # testing), each teacher in the batch gets a start index and a step # size so they all process disparate sets of the data self.step_size = opt.get('batchsize', 1) self.data_offset = opt.get('batchindex', 0) self.image_loader = ImageLoader(opt) self.reset() def num_examples(self): return self.len def num_episodes(self): return self.len def report(self): r = super().report() r['factmetrics'] = self.factmetrics.report() return r def reset(self): # Reset the dialog so that it is at the start of the epoch, # and all metrics are reset. super().reset() self.lastY = None self.episode_idx = self.data_offset - self.step_size self.epochDone = False def reset_metrics(self): super().reset_metrics() self.factmetrics.clear() def observe(self, observation): """Process observation for metrics.""" if self.lastY is not None: if self.asked_question: self.metrics.update(observation, self.lastY[0]) else: self.factmetrics.update(observation, self.lastY[1]) self.lastY = None return observation def act(self): if self.asked_question: self.asked_question = False action = { 'text': 'Which fact supports this answer?', 'episode_done': True } if self.datatype.startswith('train'): action['labels'] = self.lastY[1] if self.datatype != 'train' and self.episode_idx + self.step_size >= self.num_episodes( ): self.epochDone = True return action if self.datatype == 'train': self.episode_idx = random.randrange(self.len) else: self.episode_idx = (self.episode_idx + self.step_size) % self.num_episodes() self.asked_question = True qa = self.ques[self.episode_idx] question = qa['question'] img_path = self.image_path + qa['img_file'] action = { 'image': self.image_loader.load(img_path), 'text': question, 'episode_done': False } human_readable = qa['fact_surface'].replace('[', '').replace(']', '') self.lastY = [[qa['answer']], [human_readable]] if self.datatype.startswith('train'): action['labels'] = self.lastY[0] return action def share(self): shared = super().share() shared['factmetrics'] = self.factmetrics shared['ques'] = self.ques if hasattr(self, 'facts'): shared['facts'] = self.facts return shared def _setup_data(self, questions_path, trainset_path, datatype, task_num): print('loading: ' + questions_path) with open(questions_path) as questions_file: questions = json.load(questions_file) train_test_images = set() with open( os.path.join(trainset_path, 
'{}_list_{}.txt'.format(datatype, task_num))) as imageset: for line in imageset: train_test_images.add(line.strip()) self.ques = [ questions[k] for k in sorted(questions.keys()) if questions[k]['img_file'] in train_test_images ]
class MTurkIGCEvalWorld(MultiAgentDialogWorld): """World where an agent observes 5 images and 3 comments about the images, and ranks the comments """ def __init__(self, opt, agents=None, shared=None, world_tag='NONE'): self.turn_idx = 0 self.task_type = 'sandbox' if opt['is_sandbox'] else 'live' self.chat_done = False self.world_tag = world_tag self.max_resp_time = opt['max_resp_time'] # in secs super().__init__(opt, agents, shared) self.agents = agents self.agent = agents[0] self.data = [] self.exact_match = False self.num_images = opt['num_images'] self.d_rnd = opt.get('dialog_round') self.image_path = opt.get('image_path') self.task_dir = opt['task_dir'] opt['image_mode'] = 'raw' self.image_loader = ImageLoader(opt) def episode_done(self): return self.chat_done def parley(self): """RATER is given an image, context (and possibly some questions) and is asked to rate the responses. """ # Initial Message Value control_msg = {'episode_done': False} control_msg['id'] = 'SYSTEM' """First, we give RATER the image and context """ while self.turn_idx < self.num_images: print(self.world_tag + ' is at turn {}...'.format(self.turn_idx)) # Send image to turker if self.d_rnd == 'questions': control_msg['description'] = config_questions[ 'task_description'] else: control_msg['description'] = config_responses[ 'task_description'] self.example_id, igc_example = self.agent.example_generator.pop_example( ) img = self.image_loader.load(self.image_id_to_path( self.example_id)) buffered = BytesIO() img.save(buffered, format="JPEG") encoded = str( base64.b64encode(buffered.getvalue()).decode('ascii')) control_msg['image'] = encoded control_msg['context'] = igc_example['context'] """ Setup Options for rating """ if self.d_rnd == 'questions': options = [(k, v) for k, v in igc_example['questions'].items()] else: control_msg['question'] = igc_example['question'] options = [(k, v) for k, v in igc_example['responses'].items()] random.shuffle(options) options, dup_dict = self.filter_option_duplicates(options) control_msg['options'] = [c[1] for c in options] # Collect rating from turker rate_msg = RATE_MSG if self.d_rnd == 'questions' else RATE_RESPONSE_MSG control_msg['text'] = rate_msg.format(self.turn_idx + 1) control_msg['new_eval'] = True self.agent.observe(validate(control_msg)) time.sleep(1) act = self.agent.act(timeout=self.max_resp_time) # First timeout check self.check_timeout(act) if self.chat_done: break try: ratings = [] collected_ratings = list( zip([q[0] for q in options], act['ratings'])) for opt, rating in collected_ratings: for other_opt in dup_dict[opt]: ratings.append((other_opt, rating)) igc_example['ratings'] = ratings except Exception: # Agent disconnected break igc_example['dialog_round_evaluated'] = self.d_rnd self.data.append(igc_example) self.turn_idx += 1 if self.turn_idx == self.num_images: control_msg['text'] = CHAT_ENDED_MSG.format(self.num_images) self.agent.observe(validate(control_msg)) self.chat_done = True return def image_id_to_path(self, image_id): if self.image_path == '': return os.path.join(self.task_dir, 'banana.jpg') else: return '{}/{}.jpg'.format(self.image_path, image_id) def filter_option_duplicates(self, options): # options = [(opt, text), (opt2, text2), ...] 
new_options = [] text_to_opt = {} opt_to_opt = {} for opt, text in options: if text not in text_to_opt: text_to_opt[text] = opt new_options.append([opt, text]) opt_to_opt[opt] = [opt] else: opt_to_opt[text_to_opt[text]].append(opt) return new_options, opt_to_opt def check_timeout(self, act): if act['text'] == '[TIMEOUT]' and act['episode_done']: control_msg = {'episode_done': True} control_msg['id'] = 'SYSTEM' control_msg['text'] = TIMEOUT_MSG for ag in self.agents: if ag.id != act['id']: ag.observe(validate(control_msg)) self.chat_done = True return True elif act['text'] == '[DISCONNECT]': self.chat_done = True return True else: return False def save_data(self): convo_finished = True for ag in self.agents: if (ag.hit_is_abandoned or ag.hit_is_returned or ag.disconnected or ag.hit_is_expired): convo_finished = False if not convo_finished: ag.example_generator.push_example(self.example_id) print("\n**Push image {} back to stack. **\n".format( self.example_id)) self.agents[0].example_generator.save_idx_stack() data_path = self.opt['data_path'] if not os.path.exists(data_path): os.makedirs(data_path) if convo_finished: filename = os.path.join( data_path, '{}_{}_{}.pkl'.format(time.strftime("%Y%m%d-%H%M%S"), np.random.randint(0, 1000), self.task_type)) else: filename = os.path.join( data_path, '{}_{}_{}_incomplete.pkl'.format( time.strftime("%Y%m%d-%H%M%S"), np.random.randint(0, 1000), self.task_type)) pickle.dump( { 'data': self.data, 'worker': self.agents[0].worker_id, 'hit_id': self.agents[0].hit_id, 'assignment_id': self.agents[0].assignment_id }, open(filename, 'wb')) print('{}: Data successfully saved at {}.'.format( self.world_tag, filename)) def review_work(self): global review_agent def review_agent(ag): pass # auto approve 5 days Parallel(n_jobs=len(self.agents), backend='threading')(delayed(review_agent)(agent) for agent in self.agents) def shutdown(self): """Shutdown all mturk agents in parallel, otherwise if one mturk agent is disconnected then it could prevent other mturk agents from completing. """ global shutdown_agent def shutdown_agent(agent): agent.shutdown() Parallel(n_jobs=len(self.agents), backend='threading')(delayed(shutdown_agent)(agent) for agent in self.agents)
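filter_option_duplicates above collapses options whose text is identical and records which original keys each surviving option stands for, so a single rating can later be copied back to every duplicate. A small standalone restatement of that logic with invented option keys:

options = [('a', 'Nice photo!'), ('b', 'Nice photo!'), ('c', 'Where is this?')]
new_options, text_to_opt, opt_to_opt = [], {}, {}
for key, text in options:
    if text not in text_to_opt:
        text_to_opt[text] = key
        new_options.append([key, text])
        opt_to_opt[key] = [key]
    else:
        opt_to_opt[text_to_opt[text]].append(key)
# new_options == [['a', 'Nice photo!'], ['c', 'Where is this?']]
# opt_to_opt  == {'a': ['a', 'b'], 'c': ['c']}
# a rating collected for 'a' is then recorded for both 'a' and 'b'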
class VQADataset(Dataset): """A Pytorch Dataset utilizing streaming""" def __init__(self, opt): self.opt = opt self.use_att = opt.get('attention', False) self.use_hdf5 = opt.get('use_hdf5', False) self.opt['use_hdf5_extraction'] = self.use_hdf5 self.datatype = self.opt.get('datatype') self.training = self.datatype.startswith('train') self.num_epochs = self.opt.get('num_epochs', 0) self.image_loader = ImageLoader(opt) data_path, annotation_path, self.image_path = _path(opt) self._setup_data(data_path, annotation_path, opt.get('unittest', False)) if self.use_hdf5: try: import h5py self.h5py = h5py except ImportError: raise ImportError('Need to install h5py - `pip install h5py`') self._setup_image_data() self.dict_agent = VqaDictionaryAgent(opt) def __getitem__(self, index): index %= self.num_episodes() qa = self.ques['questions'][index] ep = { 'text': qa['question'], 'image': self.get_image(qa['image_id']), 'episode_done': True, } if self.opt.get('extract_image', False): ep['image_id'] = qa['image_id'] return ep if not self.datatype.startswith('test'): anno = self.annotation['annotations'][index] labels = [ans['answer'] for ans in anno['answers']] ep['labels'] = [ans['answer'] for ans in anno['answers']] ep['valid'] = True if 'mc_label' in ep: if not ep['mc_label'][0] in self.dict_agent.ans2ind: ep['valid'] = False ep = self.dict_agent.encode_question([ep], self.training) ep = self.dict_agent.encode_answer(ep) ep[0]['labels'] = labels else: ep['valid'] = True ep = self.dict_agent.encode_question([ep], False) ep[0]['use_att'] = self.use_att ep[0]['use_hdf5'] = self.use_hdf5 return (index, ep) def __len__(self): num_epochs = self.num_epochs if self.num_epochs > 0 else 100 num_iters = num_epochs if self.training else 1 return int(num_iters * self.num_episodes()) def _load_lens(self): with open(self.length_datafile) as length: lengths = json.load(length) self.num_eps = lengths['num_eps'] self.num_exs = lengths['num_exs'] def _setup_data(self, data_path, annotation_path, unittest): with open(data_path) as data_file: self.ques = json.load(data_file) if not self.datatype.startswith('test'): with open(annotation_path) as data_file: self.annotation = json.load(data_file) if unittest: self.ques['questions'] = self.ques['questions'][:10] if not self.datatype.startswith('test'): self.annotation['annotations'] = self.annotation[ 'annotations'][:10] self.image_paths = set() for qa in self.ques['questions']: self.image_paths.add(self.image_path + '%012d.jpg' % (qa['image_id'])) def _setup_image_data(self): '''hdf5 image dataset''' extract_feats(self.opt) im = self.opt.get('image_mode') if self.opt.get('attention', False): hdf5_path = self.image_path + 'mode_{}.hdf5'.format(im) else: hdf5_path = self.image_path + 'mode_{}_noatt.hdf5'.format(im) hdf5_file = self.h5py.File(hdf5_path, 'r') self.image_dataset = hdf5_file['images'] image_id_to_idx_path = self.image_path + 'mode_{}_id_to_idx.txt'.format( im) with open(image_id_to_idx_path, 'r') as f: self.image_id_to_idx = json.load(f) def get_image(self, image_id): if not self.use_hdf5: im_path = self.image_path + '%012d.jpg' % (image_id) return self.image_loader.load(im_path) else: img_idx = self.image_id_to_idx[str(image_id)] return torch.Tensor(self.image_dataset[img_idx]) def num_episodes(self): return len(self.ques['questions']) def num_examples(self): return self.num_episodes() def num_images(self): if not hasattr(self, 'num_imgs'): self.num_imgs = len( {q['image_id'] for q in self.ques['questions']}) return self.num_imgs