import random

import torch
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm

import data_util
import ibp
import vocabulary

# EntailmentLabels, TextClassificationDataset, and the attack_surface objects
# are assumed to be defined elsewhere in this module/repo.


def collate_examples(examples):
    """
    Turns a list of examples into a workable batch, padding every sequence
    to the length of the longest example in the batch.
    """
    if len(examples) == 1:
        return examples[0]
    B = len(examples)
    max_len = max(ex['x'].shape[1] for ex in examples)
    x_vals = []
    choice_mats = []
    choice_masks = []
    y = torch.zeros((B, 1))
    lengths = torch.zeros((B,), dtype=torch.long)
    masks = torch.zeros((B, max_len))
    for i, ex in enumerate(examples):
        x_vals.append(ex['x'].val)
        choice_mats.append(ex['x'].choice_mat)
        choice_masks.append(ex['x'].choice_mask)
        cur_len = ex['x'].shape[1]
        masks[i, :cur_len] = 1  # 1 for real tokens, 0 for padding
        y[i, 0] = ex['y']
        lengths[i] = ex['lengths'][0]
    # Concatenate along the batch dimension, zero-padding all other dims.
    x_vals = data_util.multi_dim_padded_cat(x_vals, 0).long()
    choice_mats = data_util.multi_dim_padded_cat(choice_mats, 0).long()
    choice_masks = data_util.multi_dim_padded_cat(choice_masks, 0).long()
    return {
        'x': ibp.DiscreteChoiceTensor(x_vals, choice_mats, choice_masks, masks),
        'y': y,
        'mask': masks,
        'lengths': lengths,
    }
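# --- Illustration (not from the original source) ----------------------------
# collate_examples leans on data_util.multi_dim_padded_cat to merge per-example
# tensors whose non-batch dimensions differ (e.g. choice counts C). The repo's
# implementation isn't shown here; the sketch below is only an assumption about
# its contract: concatenate along `dim`, zero-padding every other dimension to
# the per-dimension maximum.
def multi_dim_padded_cat_sketch(tensors, dim=0):
    ndim = tensors[0].dim()
    out_shape = [max(t.shape[d] for t in tensors) for d in range(ndim)]
    out_shape[dim] = sum(t.shape[dim] for t in tensors)
    out = tensors[0].new_zeros(out_shape)
    offset = 0
    for t in tensors:
        index = [slice(0, t.shape[d]) for d in range(ndim)]
        index[dim] = slice(offset, offset + t.shape[dim])
        out[tuple(index)] = t
        offset += t.shape[dim]
    return out

# Example: choice matrices of shape (1, 2, 3, 1) and (1, 3, 2, 1) merge into
# a zero-padded (2, 3, 3, 1) batch:
#   multi_dim_padded_cat_sketch(
#       [torch.ones((1, 2, 3, 1)), torch.ones((1, 3, 2, 1))], 0).shape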
@classmethod
def from_raw_data(cls, raw_data, vocab, attack_surface=None, truncate_to=None,
                  downsample_to=None, downsample_shard=0):
    if downsample_to:
        raw_data = raw_data[downsample_shard * downsample_to:
                            (downsample_shard + 1) * downsample_to]
    examples = []
    for x, y in raw_data:
        all_words = [w.lower() for w in x.split()]
        if attack_surface:
            all_swaps = attack_surface.get_swaps(all_words)
            words = [w for w in all_words if w in vocab]
            swaps = [s for w, s in zip(all_words, all_swaps) if w in vocab]
            choices = [[w] + cur_swaps for w, cur_swaps in zip(words, swaps)]
        else:
            words = [w for w in all_words if w in vocab]  # Delete UNK words
        # The input may be a single out-of-vocab word.
        is_oov_example = (len(words) == 0)
        if is_oov_example:
            words = [vocabulary.NULL_TOKEN]
        if truncate_to:
            words = words[:truncate_to]
            if attack_surface:
                # Keep choices aligned with the truncated words; otherwise the
                # choice tensor would be longer than x and shapes would mismatch.
                choices = choices[:truncate_to]
        word_idxs = [vocab.get_index(w) for w in words]
        x_torch = torch.tensor(word_idxs).view(1, -1, 1)  # (1, T, d)
        if attack_surface and not is_oov_example:
            choices_word_idxs = [
                torch.tensor([vocab.get_index(c) for c in c_list], dtype=torch.long)
                for c_list in choices
            ]
            if any(0 in c.view(-1).tolist() for c in choices_word_idxs):
                raise ValueError("UNK tokens found")
            choices_torch = pad_sequence(
                choices_word_idxs, batch_first=True).unsqueeze(2).unsqueeze(0)  # (1, T, C, 1)
            choices_mask = (choices_torch.squeeze(-1) != 0).long()  # (1, T, C)
        else:
            choices_torch = x_torch.view(1, -1, 1, 1)  # (1, T, 1, 1)
            choices_mask = torch.ones_like(x_torch.view(1, -1, 1))
        mask_torch = torch.ones((1, len(word_idxs)))
        x_bounded = ibp.DiscreteChoiceTensor(
            x_torch, choices_torch, choices_mask, mask_torch)
        y_torch = torch.tensor(y, dtype=torch.float).view(1, 1)
        lengths_torch = torch.tensor(len(word_idxs)).view(1)
        examples.append(
            dict(x=x_bounded, y=y_torch, mask=mask_torch, lengths=lengths_torch))
    return cls(raw_data, vocab, examples)
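# --- Illustration (not from the original source) ----------------------------
# A hypothetical walk-through of the shapes from_raw_data builds for a single
# example. StubVocab is invented for this sketch; it only mirrors the interface
# the code above assumes (`w in vocab`, vocab.get_index), with index 0 reserved
# so the `0 in c.view(-1)` UNK check above makes sense.
class StubVocab:
    def __init__(self, words):
        self._index = {w: i + 1 for i, w in enumerate(words)}  # 0 reserved

    def __contains__(self, w):
        return w in self._index

    def get_index(self, w):
        return self._index[w]


_vocab = StubVocab(['the', 'movie', 'was', 'great'])
_words = [w.lower() for w in 'The movie was GREAT'.split() if w.lower() in _vocab]
_word_idxs = [_vocab.get_index(w) for w in _words]
_x_torch = torch.tensor(_word_idxs).view(1, -1, 1)
# _x_torch.shape == torch.Size([1, 4, 1]): (batch=1, T=4 tokens, d=1)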
@classmethod
def process_example(cls, inpt, y, vocab, attack_surface, skip_prem=True,
                    prepend_null=False):
    example = {}
    for idx, sequence in enumerate(['prem', 'hypo']):
        x = inpt[idx]
        all_words = x.split()
        if prepend_null:
            all_words = ['<NULL>'] + all_words
        words = [w for w in all_words if w in vocab]  # Delete UNK words
        word_idxs = [vocab.get_index(w) for w in words]
        if len(word_idxs) < 1:
            raise ValueError(
                "Sequence:\n\t{}\nis all UNK words in sample:\n\t{}\n".format(
                    x, inpt))
        x_torch = torch.tensor(word_idxs).view(1, -1, 1)  # (1, T, d)
        if attack_surface and not (skip_prem and sequence == 'prem'):
            # Don't try to swap the NULL token.
            swap_words = all_words[1:] if prepend_null else all_words
            all_swaps = attack_surface.get_swaps(swap_words)
            if prepend_null:
                # Add an empty swaps list at index 0 for NULL.
                all_swaps = [[]] + all_swaps
            swaps = [s for w, s in zip(all_words, all_swaps) if w in vocab]
            choices = [[w] + cur_swaps for w, cur_swaps in zip(words, swaps)]
            choices_word_idxs = [
                torch.tensor([vocab.get_index(c) for c in c_list], dtype=torch.long)
                for c_list in choices
            ]
            if any(0 in c.view(-1).tolist() for c in choices_word_idxs):
                raise ValueError("UNK tokens found")
            choices_torch = pad_sequence(
                choices_word_idxs, batch_first=True).unsqueeze(2).unsqueeze(0)  # (1, T, C, 1)
            choices_mask = (choices_torch.squeeze(-1) != 0).long()  # (1, T, C)
        else:
            choices_torch = x_torch.view(1, -1, 1, 1)  # (1, T, 1, 1)
            choices_mask = torch.ones_like(x_torch.view(1, -1, 1))
        mask_torch = torch.ones((1, len(word_idxs)))
        x_bounded = ibp.DiscreteChoiceTensor(
            x_torch, choices_torch, choices_mask, mask_torch)
        lengths_torch = torch.tensor(len(word_idxs)).view(1)
        example[sequence] = dict(x=x_bounded, mask=mask_torch,
                                 lengths=lengths_torch)
    # One-hot gold label over the entailment classes.
    example['y'] = torch.zeros((1, len(EntailmentLabels)), dtype=torch.float)
    example['y'][0, y.value] = 1
    return example
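# --- Illustration (not from the original source) ----------------------------
# Shape trace for the choice-tensor construction above, assuming two in-vocab
# words with 2 and 1 substitution choices respectively (made-up indices).
# pad_sequence pads the shorter choice list with index 0, which the mask then
# zeroes out.
_choices_word_idxs = [torch.tensor([5, 9]), torch.tensor([7])]
_choices_torch = pad_sequence(
    _choices_word_idxs, batch_first=True).unsqueeze(2).unsqueeze(0)
# _choices_torch.shape == torch.Size([1, 2, 2, 1]): (1, T=2, C=2, 1)
_choices_mask = (_choices_torch.squeeze(-1) != 0).long()
# _choices_mask == tensor([[[1, 1], [1, 0]]]): the padded slot is masked out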
def augment(self, dataset):
    new_examples = []
    for ex in tqdm(dataset.examples):
        new_examples.append(ex)
        x_orig = ex['x']  # (1, T, 1)
        choices = []
        for i in range(x_orig.shape[1]):
            # masked_select expects a bool mask (uint8 masks are deprecated).
            cur_choices = torch.masked_select(
                x_orig.choice_mat[0, i, :, 0],
                x_orig.choice_mask[0, i, :].bool())
            choices.append(cur_choices)
        for _ in range(self.augment_by):
            # Sample one valid substitution uniformly at random per position.
            x_new = torch.stack([
                choices[i][random.randrange(len(choices[i]))]
                for i in range(len(choices))
            ]).view(1, -1, 1)
            x_bounded = ibp.DiscreteChoiceTensor(
                x_new, x_orig.choice_mat, x_orig.choice_mask,
                x_orig.sequence_mask)
            ex_new = dict(ex)
            ex_new['x'] = x_bounded
            new_examples.append(ex_new)
    return TextClassificationDataset(None, dataset.vocab, new_examples)
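# --- Illustration (not from the original source) ----------------------------
# Minimal sketch of the per-position sampling at the core of augment(): for
# each token, masked_select keeps only the valid substitution choices, then
# one is drawn uniformly at random. All values below are made up.
_choice_mat = torch.tensor([[[[5], [9]], [[7], [0]]]])  # (1, T=2, C=2, 1)
_choice_mask = torch.tensor([[[1, 1], [1, 0]]])         # (1, T=2, C=2)
_position_choices = [
    torch.masked_select(_choice_mat[0, i, :, 0], _choice_mask[0, i, :].bool())
    for i in range(_choice_mat.shape[1])
]
_x_new = torch.stack(
    [c[random.randrange(len(c))] for c in _position_choices]).view(1, -1, 1)
# _x_new.shape == torch.Size([1, 2, 1]); position 0 is 5 or 9, position 1 is 7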
def collate_examples(examples):
    """
    Turns a list of examples into a workable batch, padding premises and
    hypotheses separately to their respective maximum lengths.
    """
    if len(examples) == 1:
        return examples[0]
    B = len(examples)
    max_prem_len = max(ex['prem']['x'].shape[1] for ex in examples)
    prem_vals = []
    prem_choice_mats = []
    prem_choice_masks = []
    prem_lengths = torch.zeros((B,), dtype=torch.long)
    prem_masks = torch.zeros((B, max_prem_len))
    max_hypo_len = max(ex['hypo']['x'].shape[1] for ex in examples)
    hypo_vals = []
    hypo_choice_mats = []
    hypo_choice_masks = []
    hypo_lengths = torch.zeros((B,), dtype=torch.long)
    hypo_masks = torch.zeros((B, max_hypo_len))
    gold_ys = []
    for i, ex in enumerate(examples):
        prem_vals.append(ex['prem']['x'].val)
        prem_choice_mats.append(ex['prem']['x'].choice_mat)
        prem_choice_masks.append(ex['prem']['x'].choice_mask)
        cur_prem_len = ex['prem']['x'].shape[1]
        prem_masks[i, :cur_prem_len] = 1
        prem_lengths[i] = ex['prem']['lengths'][0]
        hypo_vals.append(ex['hypo']['x'].val)
        hypo_choice_mats.append(ex['hypo']['x'].choice_mat)
        hypo_choice_masks.append(ex['hypo']['x'].choice_mask)
        cur_hypo_len = ex['hypo']['x'].shape[1]
        hypo_masks[i, :cur_hypo_len] = 1
        hypo_lengths[i] = ex['hypo']['lengths'][0]
        gold_ys.append(ex['y'])
    prem_vals = data_util.multi_dim_padded_cat(prem_vals, 0).long()
    prem_choice_mats = data_util.multi_dim_padded_cat(prem_choice_mats, 0).long()
    prem_choice_masks = data_util.multi_dim_padded_cat(prem_choice_masks, 0).long()
    hypo_vals = data_util.multi_dim_padded_cat(hypo_vals, 0).long()
    hypo_choice_mats = data_util.multi_dim_padded_cat(hypo_choice_mats, 0).long()
    hypo_choice_masks = data_util.multi_dim_padded_cat(hypo_choice_masks, 0).long()
    y = torch.cat(gold_ys, 0)
    return {
        'prem': {
            'x': ibp.DiscreteChoiceTensor(prem_vals, prem_choice_mats,
                                          prem_choice_masks, prem_masks),
            'mask': prem_masks,
            'lengths': prem_lengths,
        },
        'hypo': {
            'x': ibp.DiscreteChoiceTensor(hypo_vals, hypo_choice_mats,
                                          hypo_choice_masks, hypo_masks),
            'mask': hypo_masks,
            'lengths': hypo_lengths,
        },
        'y': y,
    }
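# --- Illustration (not from the original source) ----------------------------
# Self-contained demo of the mask/length bookkeeping above, assuming a batch
# of two examples whose premises have 3 and 5 tokens (hypothetical numbers).
_B, _prem_lens = 2, [3, 5]
_prem_masks = torch.zeros((_B, max(_prem_lens)))
for _i, _n in enumerate(_prem_lens):
    _prem_masks[_i, :_n] = 1
# _prem_masks == tensor([[1., 1., 1., 0., 0.],
#                        [1., 1., 1., 1., 1.]])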