def build_insertion_example(
    self, source_tokens, labels, target_insertions=None):
  """Constructs the masked input TF Example for the insertion model.

  Args:
    source_tokens: List of source tokens.
    labels: List of edit label tuples (base_tag, num_insertions).
    target_insertions: Inserted target tokens per source token. Only provided
      when constructing training examples.

  Returns:
    A feed_dict containing input features to be fed to a predictor (for
    inference) or to be converted to a tf.Example (for model training). If the
    labels don't contain any insertions, returns None.
  """
  masked_tokens, target_tokens = self.build_insertion_tokens(
      source_tokens, labels, target_insertions)
  if constants.MASK not in masked_tokens:
    # No need to create an insertion example, so return None.
    return None
  return utils.build_feed_dict(
      masked_tokens,
      self.tokenizer,
      target_tokens=target_tokens,
      max_seq_length=self._max_seq_length,
      max_predictions_per_seq=self._max_predictions_per_seq)
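# A hypothetical call sketch for build_insertion_example. The builder
# instance, tag names, and token values below are illustrative assumptions,
# not fixtures from this codebase; labels follow the
# (base_tag, num_insertions) tuples documented above, here requesting one
# insertion after 'cat'.
#
#   source_tokens = ['[CLS]', 'the', 'cat', 'sat', '[SEP]']
#   labels = [('KEEP', 0), ('KEEP', 0), ('KEEP', 1), ('KEEP', 0), ('KEEP', 0)]
#   target_insertions = [[], [], ['big'], [], []]
#   feed_dict = builder.build_insertion_example(
#       source_tokens, labels, target_insertions=target_insertions)
#   # feed_dict is None whenever no label requests an insertion, since no
#   # MASK token is produced.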
def test_build_feed_dict(self, source, target, masks):
  feed_dict = utils.build_feed_dict(source, self._tokenizer, target)
  for i, mask_id in enumerate(feed_dict['masked_lm_ids'][0]):
    # Ignore padding.
    if mask_id == 0:
      continue
    # convert_tokens_to_ids expects a list of tokens, so wrap the single
    # expected token before looking up its id.
    self.assertEqual(
        mask_id, self._tokenizer.convert_tokens_to_ids([masks[i]])[0])
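# A hypothetical parameter set for the test above (the real decorator values
# live alongside the test). The source already contains a [MASK] token and
# `masks` lists the target token expected at each masked position, in order.
#
#   @parameterized.parameters({
#       'source': ['[CLS]', 'a', '[MASK]', '[SEP]'],
#       'target': ['[CLS]', 'a', 'b', '[SEP]'],
#       'masks': ['b'],
#   })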
def _convert_source_sentences_into_batch(self, source_sentences, is_insertion):
  """Converts source sentences into a batch."""
  batch_dictionaries = []
  for source_sentence in source_sentences:
    if is_insertion:
      # Note source_sentence is the output from the tagging model and
      # therefore already tokenized.
      example = utils.build_feed_dict(
          source_sentence.split(' '),
          self._builder.tokenizer,
          max_seq_length=self._sequence_length,
          max_predictions_per_seq=self._max_predictions)
      assert example is not None, (
          f'Source sentence {source_sentence} returned None when '
          'converting to insertion example.')
      # Previously the code produced an output with a batch size of 1; this
      # dimension is removed, as arbitrary batching is now done here.
      example = dict(example)
      for k, v in example.items():
        example[k] = v[0]
      # Note masked_lm_ids and masked_lm_weights are filled with zeros.
      batch_dictionaries.append({
          'input_word_ids': np.array(example['input_ids']),
          'input_mask': np.array(example['input_mask']),
          'input_type_ids': np.array(example['segment_ids']),
          'masked_lm_positions': np.array(example['masked_lm_positions']),
          'masked_lm_ids': np.array(example['masked_lm_ids']),
          'masked_lm_weights': np.array(example['masked_lm_weights'])
      })
    else:
      example, _ = self._builder.build_bert_example(
          [source_sentence], target=None, is_test_time=True)
      assert example is not None, (
          f'Tagging could not convert {source_sentence}.')
      dict_element = {
          'input_word_ids': np.array(example.features['input_ids']),
          'input_mask': np.array(example.features['input_mask']),
          'input_type_ids': np.array(example.features['segment_ids']),
      }
      batch_dictionaries.append(dict_element)
  # Convert from a list of dictionaries to a dictionary of lists.
  batch_list = {
      k: np.array([dic[k] for dic in batch_dictionaries])
      for k in batch_dictionaries[0]
  }
  return batch_dictionaries, batch_list
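# Hypothetical usage of the batching helper above; the predictor and model
# names are illustrative assumptions. Each value in batch_list is stacked
# along a leading batch dimension, so the whole dictionary can be fed to the
# insertion model in a single forward pass.
#
#   _, batch_list = predictor._convert_source_sentences_into_batch(
#       ['[CLS] the [MASK] sat [SEP]'], is_insertion=True)
#   predictions = insertion_model(batch_list)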
def create_insertion_example(
    self, source_tokens, labels, source_indexes, target_tokens):
  """Creates training/test features for the insertion model.

  Args:
    source_tokens: List of source tokens.
    labels: List of label IDs, which correspond to a list of labels (KEEP,
      DELETE, MASK|1, MASK|2...).
    source_indexes: List of next tokens (see pointing converter for more
      details) (ordered by source tokens).
    target_tokens: List of target tokens.

  Returns:
    A dictionary of features needed by the TensorFlow insertion model.
  """
  # Reorder the source sentence, add MASK tokens, and add deleted tokens
  # (to both source_tokens and target_tokens).
  masked_tokens, target_tokens = self._create_masked_source(
      source_tokens, labels, source_indexes, target_tokens)
  if target_tokens and constants.MASK not in masked_tokens:
    # Generate random MASKs.
    if self._do_random_mask:
      # Don't mask the start or end token.
      indexes = list(range(1, len(masked_tokens) - 1))
      random.shuffle(indexes)
      # Limit MASK to ~10% of the source tokens.
      indexes = indexes[:int(len(masked_tokens) * 0.1)]
      for index in indexes:
        masked_tokens[index] = constants.MASK
    elif self._do_lazy_generation:
      return None
  return utils.build_feed_dict(
      masked_tokens,
      self._tokenizer,
      target_tokens=target_tokens,
      max_seq_length=self._max_seq_length,
      max_predictions_per_seq=self._max_predictions_per_seq)
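# A minimal standalone sketch of the random-masking step above, assuming
# '[MASK]' stands in for constants.MASK. It mirrors the selection logic only;
# the token values are illustrative.
import random

masked_tokens = ['[CLS]', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', '[SEP]']
# Don't mask the start or end token.
indexes = list(range(1, len(masked_tokens) - 1))
random.shuffle(indexes)
# Keep ~10% of the sequence as MASK positions: int(10 * 0.1) == 1 here.
indexes = indexes[:int(len(masked_tokens) * 0.1)]
for index in indexes:
  masked_tokens[index] = '[MASK]'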