# Example #1
def elmo_input_reshape(inputs: TextFieldTensors, batch_size: int,
                       number_targets: int,
                       batch_size_num_targets: int) -> TextFieldTensors:
    '''
    Flatten the target dimension of ELMO / character token-indexer tensors.

    NOTE: This does not work for the huggingface transformers, as their token
    indexers produce additional keys besides token ids (e.g. mask ids and
    segment ids) that would also need handling, which we have not had time to
    do yet. A more appropriate workaround would be to use `target_sequences`
    as in the `InContext` model, generating contextualised targets from the
    context rather than using the target words as-is without context.

    :param inputs: The token indexer dictionary where the keys state the token
                   indexer and the values are the Tensors that are of shape
                   (Batch Size, Number Targets, Sequence Length)
    :param batch_size: The Batch Size
    :param number_targets: The max number of targets in the batch
    :param batch_size_num_targets: Batch Size * number of targets
    :returns: If the inputs contains an `elmo` or 'token_characters' key, a
              new dictionary with every tensor reshaped to
              (Batch Size * Number Targets, Sequence Length) so that it can be
              processed by the ELMO or character embedder/encoder. Otherwise
              the inputs are returned untouched.
    '''
    # Nothing to reshape for plain token-id indexers.
    if 'elmo' not in inputs and 'token_characters' not in inputs:
        return inputs
    reshaped: TextFieldTensors = {}
    for indexer_name, tensor_dict in inputs.items():
        # Merge the first two dims (batch, targets); keep any trailing dims.
        reshaped[indexer_name] = {
            tensor_name: tensor.view(batch_size_num_targets, *tensor.shape[2:])
            for tensor_name, tensor in tensor_dict.items()
        }
    return reshaped
# Example #2
    def mask_tokens(self, inputs: TextFieldTensors) -> Tuple[TextFieldTensors, TextFieldTensors]:
        """Apply MLM-style corruption to every tensor in ``inputs``.

        Each token tensor is corrupted in two independent steps: positions
        drawn with probability ``self.mask_probability`` are overwritten with
        ``self.mask_idx``; of the remaining positions, those drawn with
        probability ``self.replace_probability`` are overwritten with a random
        token id in ``[1, self.vocab_size)``.

        NOTE: the input tensors are modified in place — the returned corrupted
        dictionary holds the very same tensor objects as ``inputs``.

        :param inputs: Nested text-field dictionary of long token-id tensors.
        :returns: A tuple of (corrupted inputs, uncorrupted copies), each with
                  the same nested key structure as ``inputs``.
        """
        corrupted = dict()
        originals = dict()
        for field_name, field in inputs.items():
            corrupted[field_name] = dict()
            originals[field_name] = dict()
            for indexer_key, token_ids in field.items():
                # Snapshot before any in-place edits; serves as the target.
                uncorrupted = token_ids.clone()

                # Step 1: positions replaced by the dedicated mask token.
                mask_draw = torch.full(uncorrupted.shape, self.mask_probability,
                                       device=token_ids.device)
                mask_positions = torch.bernoulli(mask_draw).bool()
                token_ids[mask_positions] = self.mask_idx

                # Step 2: of the not-masked positions, some get a random token.
                replace_draw = torch.full(uncorrupted.shape, self.replace_probability,
                                          device=token_ids.device)
                random_positions = torch.bernoulli(replace_draw).bool() & ~mask_positions
                random_ids = torch.randint(
                    low=1, high=self.vocab_size, size=uncorrupted.shape,
                    dtype=torch.long, device=token_ids.device,
                )
                token_ids[random_positions] = random_ids[random_positions]

                corrupted[field_name][indexer_key] = token_ids
                originals[field_name][indexer_key] = uncorrupted
        return corrupted, originals