def elmo_input_reshape(inputs: TextFieldTensors, batch_size: int, number_targets: int, batch_size_num_targets: int) -> TextFieldTensors:
    '''
    Flatten the batch and target dimensions of every indexer tensor so that
    ELMO / character embedders can consume them.

    NOTE: This does not work for the huggingface transformers, as their token
    indexers emit extra keys beyond the token ids (mask ids, segment ids)
    that would also need reshaping, which is not handled yet. A more
    appropriate workaround would be to use `target_sequences` as in the
    `InContext` model, generating contextualised targets from the context
    rather than embedding the bare target words.

    :param inputs: Token indexer dictionary: keys name the token indexer and
                   the values are Tensors of shape
                   (Batch Size, Number Targets, Sequence Length)
    :param batch_size: The Batch Size
    :param number_targets: The max number of targets in the batch
    :param batch_size_num_targets: Batch Size * number of targets
    :returns: When `inputs` contains an `elmo` or 'token_characters' key,
              every tensor is reshaped to
              (Batch Size * Number Targets, Sequence Length); otherwise
              `inputs` is returned unchanged.
    '''
    # Only ELMO / character indexers need the flattened view; anything else
    # is passed through untouched.
    if 'elmo' not in inputs and 'token_characters' not in inputs:
        return inputs

    reshaped: TextFieldTensors = {}
    for indexer_name, tensor_dict in inputs.items():
        # Collapse the first two dims (batch, targets) and keep the rest.
        reshaped[indexer_name] = {
            tensor_name: tensor.view(batch_size_num_targets, *tensor.shape[2:])
            for tensor_name, tensor in tensor_dict.items()
        }
    return reshaped
def mask_tokens(self, inputs: TextFieldTensors) -> Tuple[TextFieldTensors, TextFieldTensors]:
    '''
    Corrupt token ids for masked-language-model training and keep the
    uncorrupted ids as the prediction targets.

    For every tensor in `inputs`, each position is independently replaced by
    `self.mask_idx` with probability `self.mask_probability`; among the
    positions that were NOT masked, each is replaced by a random id drawn
    uniformly from [1, self.vocab_size) with probability
    `self.replace_probability`. The targets are a clone taken before any
    corruption.

    NOTE(review): the corruption happens in place on the tensors inside
    `inputs`, and the returned masked dict aliases those same tensors —
    callers must not expect `inputs` to survive unmodified. Presumably
    intentional (mirrors common MLM collator implementations) — confirm.

    :param inputs: Token indexer dictionary of token-id tensors.
    :returns: Tuple of (masked inputs, targets) with the same nested
              key structure as `inputs`.
    '''
    corrupted: TextFieldTensors = dict()
    originals: TextFieldTensors = dict()
    for field_name, indexer_tensors in inputs.items():
        corrupted[field_name] = dict()
        originals[field_name] = dict()
        for indexer_name, token_ids in indexer_tensors.items():
            # Snapshot the clean ids before any in-place edits.
            untouched = token_ids.clone()
            mask_probs = torch.full(untouched.shape, self.mask_probability, device=token_ids.device)
            is_masked = torch.bernoulli(mask_probs).bool()
            token_ids[is_masked] = self.mask_idx
            # Random replacement only applies to positions left unmasked.
            replace_probs = torch.full(untouched.shape, self.replace_probability, device=token_ids.device)
            is_replaced = torch.bernoulli(replace_probs).bool() & ~is_masked
            drawn = torch.randint(
                low=1,
                high=self.vocab_size,
                size=untouched.shape,
                dtype=torch.long,
                device=token_ids.device,
            )
            token_ids[is_replaced] = drawn[is_replaced]
            corrupted[field_name][indexer_name] = token_ids
            originals[field_name][indexer_name] = untouched
    return corrupted, originals