import json
import os
import random

import numpy as np
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
from PIL import Image

# Project-local helpers. The import paths below are assumptions; point them
# at wherever WIDER_Tokenizer, NPExtractor, and the WIDERTriplet base class
# actually live in this repository.
from .tokenizer import WIDER_Tokenizer
from .np_extractor import NPExtractor
from .wider_triplet import WIDERTriplet
class WIDERTriplet_NP(WIDERTriplet):

    def __init__(self, anno_path, img_dir, mask_dir, vocab_fn,
                 sent_token_length=40, np_token_length=6, num_np_per_sent=6,
                 split='train', transform=None, debug=False):
        # mask_dir is accepted but unused in this variant
        super(WIDERTriplet_NP, self).__init__(anno_path, img_dir, split,
                                              transform, debug)
        self.sent_token_length = sent_token_length
        self.np_token_length = np_token_length
        self.num_np_per_sent = num_np_per_sent
        self.tokenizer = WIDER_Tokenizer(vocab_fn)
        self.np_extractor = NPExtractor()
        print("size of dataset: %d" % len(self.anns))

    def _load_cap(self, index):
        len_sent = self.sent_token_length
        len_np = self.np_token_length
        N = self.num_np_per_sent
        cap = self.anns[index]['captions']
        cap_token = self.tokenizer.tokenize(cap, len_sent)
        cap_token = torch.LongTensor([cap_token])[0]
        # extract noun phrases and tokenize each to a fixed length
        nps = self.np_extractor.sent_parse(cap)
        nps = [torch.LongTensor([self.tokenizer.tokenize(phrase, len_np)])
               for phrase in nps]
        num_nps = min(len(nps), N)
        if len(nps) > 0:
            nps = torch.cat(nps)
        else:
            # guard: the extractor may find no noun phrases in a caption
            nps = torch.ones((0, len_np)).long()
        # truncate to N phrases, or pad with ones (assumed pad index) up to N
        nps = nps[:N] if nps.size(0) > N else torch.cat(
            (nps, torch.ones((N - nps.size(0), len_np)).long()))
        return cap_token, nps, num_nps

    def __getitem__(self, index):
        # load image
        img = self._load_img(index)
        # load caption and its noun phrases
        cap, nps, num_nps = self._load_cap(index)
        # map annotation index to person id, then to integer label
        pid = self.ann2person[index]
        pid = self.person2label[pid]
        return (img, cap, nps, num_nps, pid)
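# A minimal, self-contained sketch of the truncate-or-pad step in _load_cap
# above: however many noun phrases the extractor returns, the output is always
# an (N, len_np) LongTensor plus the true phrase count. The dummy tensors
# below stand in for tokenized phrases; the pad value 1 follows the code
# above (assumed to be the tokenizer's pad index).
def _demo_pad_or_truncate_nps():
    N, len_np = 6, 6
    # pretend the extractor found two phrases, each tokenized to length 6
    nps = [torch.randint(2, 100, (1, len_np)),
           torch.randint(2, 100, (1, len_np))]
    num_nps = min(len(nps), N)
    stacked = torch.cat(nps) if nps else torch.ones((0, len_np)).long()
    padded = stacked[:N] if stacked.size(0) > N else torch.cat(
        (stacked, torch.ones((N - stacked.size(0), len_np)).long()))
    assert padded.shape == (N, len_np) and num_nps == 2
    return padded, num_nps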
class WIDERTriplet_Basic(WIDERTriplet):

    def __init__(self, anno_path, img_dir, vocab_fn, token_length=40,
                 split='train', transform=None, debug=False):
        super(WIDERTriplet_Basic, self).__init__(anno_path, img_dir, split,
                                                 transform, debug)
        self.token_length = token_length
        self.tokenizer = WIDER_Tokenizer(vocab_fn)
        print("size of dataset: %d" % len(self.anns))

    def _load_cap(self, index, i=None):
        cap = self.anns[index]['captions']
        # use the configured token length instead of a hard-coded 40
        cap_token = self.tokenizer.tokenize(cap, self.token_length)
        cap_token = torch.LongTensor([cap_token])[0]
        return cap_token

    def __getitem__(self, index):
        # only the training path is defined for this dataset
        if self.train:
            # load image
            curr_img = self._load_img(index)
            # load caption
            cap = self._load_cap(index)
            # map annotation index to person id, then to integer label
            pid = self.ann2person[index]
            pid = self.person2label[pid]
            return (curr_img, cap, pid)
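# Usage sketch for WIDERTriplet_Basic. The paths below ('data/wider_anns.json',
# 'data/images', 'data/vocab.txt') are placeholders, not paths this repository
# guarantees; swap in your own annotation file, image root, and vocab. It also
# assumes the WIDERTriplet base class sets self.train for split='train'.
def _demo_basic_loader():
    transform = transforms.Compose([
        transforms.Resize((384, 128)),
        transforms.ToTensor(),
    ])
    dataset = WIDERTriplet_Basic(
        anno_path='data/wider_anns.json',  # placeholder path
        img_dir='data/images',             # placeholder path
        vocab_fn='data/vocab.txt',         # placeholder path
        token_length=40,
        split='train',
        transform=transform)
    loader = data.DataLoader(dataset, batch_size=32, shuffle=True)
    for img, cap, pid in loader:
        # img: (B, 3, 384, 128); cap: (B, 40) token ids; pid: (B,) labels
        break
    return loader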
# NOTE: this second definition reuses the name WIDERTriplet_NP; when both
# live in the same module, this later triplet-sampling variant shadows the
# single-item variant defined above.
class WIDERTriplet_NP(WIDERTriplet):

    def __init__(self, anno_path, img_dir, vocab_fn, token_length=40,
                 split='train', transform=None, debug=False):
        super(WIDERTriplet_NP, self).__init__(anno_path, img_dir, split,
                                              transform, debug)
        self.token_length = token_length
        self.tokenizer = WIDER_Tokenizer(vocab_fn)
        self.np_extractor = NPExtractor()
        print("size of dataset: %d" % len(self.anns))

    def _load_cap(self, index, i=None):
        cap = self.anns[index]['captions']
        cap_token = self.tokenizer.tokenize(cap, self.token_length)
        # index [0] keeps the same (token_length,) shape as the other datasets
        cap_token = torch.LongTensor([cap_token])[0]
        # noun phrases are tokenized to a hard-coded length of 6 here; this
        # variant predates the configurable np_token_length used above
        nps = self.np_extractor.sent_parse(cap)
        nps = [torch.LongTensor([self.tokenizer.tokenize(phrase, 6)])
               for phrase in nps]
        # nps = torch.cat(nps)
        return cap_token, nps

    def __getitem__(self, index):
        # sample a positive and a negative index for the triplet loss
        pos_index, neg_index = self._triplet_sample(index)
        # load images
        curr_img = self._load_img(index)
        pos_img = self._load_img(pos_index)
        neg_img = self._load_img(neg_index)
        # load captions and their noun phrases
        cap, nps = self._load_cap(index)
        pos_cap, pos_nps = self._load_cap(pos_index)
        neg_cap, neg_nps = self._load_cap(neg_index)
        # map annotation indices to person ids, then to integer labels
        pid = self.person2label[self.ann2person[index]]
        pos_pid = self.person2label[self.ann2person[pos_index]]
        neg_pid = self.person2label[self.ann2person[neg_index]]
        return (curr_img, pos_img, neg_img,
                cap, pos_cap, neg_cap,
                nps, pos_nps, neg_nps,
                pid, pos_pid, neg_pid)
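# Sketch of consuming the 12-tuple this triplet variant yields. Because nps /
# pos_nps / neg_nps are variable-length Python lists of tensors, the default
# DataLoader collate cannot batch them; a custom collate_fn (not shown here)
# would be needed, so this sketch indexes the dataset directly. It assumes
# _triplet_sample (defined in the WIDERTriplet base class, not shown in this
# file) draws the positive from the same person and the negative from another.
def _demo_triplet_item(dataset, index=0):
    (curr_img, pos_img, neg_img,
     cap, pos_cap, neg_cap,
     nps, pos_nps, neg_nps,
     pid, pos_pid, neg_pid) = dataset[index]
    # anchor and positive share a person label; the negative does not
    assert pid == pos_pid and pid != neg_pid
    return curr_img, cap, nps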
class WIDERTextDataset(data.Dataset):
    """Basic text-only dataset.

    Returns:
        - caption (indexed)
        - image_fn (private_key)
    """

    def __init__(self, anno_path, vocab_fn, split="val", sent_token_length=40,
                 np_token_length=6, num_np_per_sent=6, debug=False):
        super(WIDERTextDataset, self).__init__()
        # load annotations
        with open(anno_path, 'r') as f:
            anns = json.load(f)
        self.anns = [ann for ann in anns if ann['split'] == split]
        self.anns = self.anns[:500] if debug else self.anns
        # init tokenizer
        self.sent_token_length = sent_token_length
        self.np_token_length = np_token_length
        self.num_np_per_sent = num_np_per_sent
        self.tokenizer = WIDER_Tokenizer(vocab_fn)
        # init captions data: one entry per (caption, image) pair
        self.captions, self.images = [], []
        for ann in self.anns:
            for cap in ann['captions']:
                self.captions.append(cap)
                self.images.append(ann['file_path'])
        self.len = len(self.captions)

    def get_all_keys(self):
        return self.images

    def __len__(self):
        return self.len

    def _load_cap(self, index):
        cap = self.captions[index]
        cap_token = self.tokenizer.tokenize(cap, self.sent_token_length)
        cap_token = torch.LongTensor([cap_token])[0]
        return cap_token

    def __getitem__(self, index):
        cap = self._load_cap(index)
        image_fn = self.images[index]
        return cap, image_fn
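# Usage sketch for retrieval evaluation with WIDERTextDataset. The file paths
# are placeholders. Each batch pairs tokenized query captions with the file
# names ("private keys") of their ground-truth images, which is what a
# text-to-image retrieval evaluation needs for scoring.
def _demo_text_queries():
    dataset = WIDERTextDataset(
        anno_path='data/wider_anns.json',  # placeholder path
        vocab_fn='data/vocab.txt',         # placeholder path
        split='val')
    loader = data.DataLoader(dataset, batch_size=64, shuffle=False)
    for cap, image_fn in loader:
        # cap: (B, 40) token ids; image_fn: tuple of B file-path strings
        break
    return dataset.get_all_keys()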
class WIDERTriplet_Part(WIDERTriplet):

    def __init__(self, anno_path, img_dir, mask_dir, vocab_fn,
                 sent_token_length=40, np_token_length=6, num_np_per_sent=6,
                 split='train', transform=None, debug=False):
        super(WIDERTriplet_Part, self).__init__(anno_path, img_dir, split,
                                                transform, debug)
        self.mask_dir = mask_dir
        self.toTensor = transforms.ToTensor()
        self.sent_token_length = sent_token_length
        self.np_token_length = np_token_length
        self.num_np_per_sent = num_np_per_sent
        self.tokenizer = WIDER_Tokenizer(vocab_fn)
        self.np_extractor = NPExtractor()
        print("size of dataset: %d" % len(self.anns))

    def _load_cap(self, index):
        len_sent = self.sent_token_length
        len_np = self.np_token_length
        N = self.num_np_per_sent
        cap = self.anns[index]['captions']
        cap_token = self.tokenizer.tokenize(cap, len_sent)
        cap_token = torch.LongTensor([cap_token])[0]
        # extract noun phrases and tokenize each to a fixed length
        nps = self.np_extractor.sent_parse(cap)
        nps = [torch.LongTensor([self.tokenizer.tokenize(phrase, len_np)])
               for phrase in nps]
        num_nps = min(len(nps), N)
        if len(nps) > 0:
            nps = torch.cat(nps)
        else:
            # guard: the extractor may find no noun phrases in a caption
            nps = torch.ones((0, len_np)).long()
        # truncate to N phrases, or pad with ones (assumed pad index) up to N
        nps = nps[:N] if nps.size(0) > N else torch.cat(
            (nps, torch.ones((N - nps.size(0), len_np)).long()))
        return cap_token, nps, num_nps

    def _load_img(self, index):
        img_fn = self.anns[index]['file_path']
        # masks are stored flat: e.g. 'a/b.jpg' -> 'a_b.npy'
        mask_fn = img_fn.replace('/', '_')[:-4] + '.npy'
        img_path = os.path.join(self.img_dir, img_fn)
        mask_path = os.path.join(self.mask_dir, mask_fn)
        image = Image.open(img_path).convert('RGB')
        mask = torch.from_numpy(np.load(mask_path)).float()
        if self.transform:
            image = self.transform(image)
        # horizontal flip is applied jointly so image and part mask stay aligned
        if random.random() > 0.5:
            image = torch.flip(image, [2])
            mask = torch.flip(mask, [2])
        return image, mask

    def __getitem__(self, index):
        # load image and its part mask
        img, mask = self._load_img(index)
        # load caption and its noun phrases
        cap, nps, num_nps = self._load_cap(index)
        # map annotation index to person id, then to integer label
        pid = self.person2label[self.ann2person[index]]
        return (img, mask, cap, nps, num_nps, pid)
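# Self-contained check of the joint-flip logic in _load_img above: flipping
# image and mask along the same (width) axis keeps part masks spatially
# aligned with the image. Dummy tensors stand in for a real image and mask.
def _demo_joint_flip():
    image = torch.zeros(3, 8, 8)
    mask = torch.zeros(5, 8, 8)   # e.g. 5 part channels
    image[:, 2, 1] = 1.0          # mark one pixel...
    mask[0, 2, 1] = 1.0           # ...and the matching mask cell
    image, mask = torch.flip(image, [2]), torch.flip(mask, [2])
    # the mark moves from column 1 to column 6 in both tensors
    assert image[0, 2, 6] == 1.0 and mask[0, 2, 6] == 1.0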