Example #1
    # __call__ of a BART denoising data collator; assumes np (numpy),
    # BatchEncoding, and shift_tokens_right are in scope.
    def __call__(self, examples: List[Dict[str, List[int]]]) -> BatchEncoding:
        # convert the list of examples to a dict of arrays and tensorize input
        batch = BatchEncoding({
            k: np.array([example[k] for example in examples])
            for k in examples[0]
        })
        batch["labels"] = batch["input_ids"].copy()
        batch["decoder_input_ids"] = shift_tokens_right(
            batch["labels"], self.tokenizer.pad_token_id,
            self.decoder_start_token_id)
        # permuting sentences
        do_permute = False
        if self.permute_sentence_ratio > 0.0:
            batch["input_ids"] = self.permute_sentences(batch["input_ids"])
            do_permute = True

        # mask spans of tokens (text infilling in the BART paper)
        if self.mask_ratio:
            batch["input_ids"], batch["labels"] = self.span_mask_tokens(
                batch["input_ids"], batch["labels"], do_permute)

        # ignore pad tokens
        batch["attention_mask"] = (batch["input_ids"] !=
                                   self.tokenizer.pad_token_id).astype(int)
        batch["decoder_attention_mask"] = (
            batch["decoder_input_ids"] !=
            self.tokenizer.pad_token_id).astype(int)
        return batch
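
Example #1 builds its decoder inputs with shift_tokens_right. Below is a
minimal NumPy sketch of that helper, consistent with how it is called here
(prepend the decoder start token, drop the last token, map any -100 label
positions back to the pad token); the upstream implementation may differ in
details.

import numpy as np

def shift_tokens_right(input_ids: np.ndarray, pad_token_id: int,
                       decoder_start_token_id: int) -> np.ndarray:
    # Shift every row one position to the right and place the decoder
    # start token in the first column.
    shifted = np.zeros_like(input_ids)
    shifted[:, 1:] = input_ids[:, :-1]
    shifted[:, 0] = decoder_start_token_id
    # Replace any -100 masking values with the real pad token.
    return np.where(shifted == -100, pad_token_id, shifted)
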
Example #2
    def prepare_config_and_inputs(self):
        # random token ids clipped to start at 3 so the special ids
        # 0-2 (bos/pad/eos) never occur mid-sequence
        input_ids = np.clip(
            ids_tensor([self.batch_size, self.seq_length - 1],
                       self.vocab_size), 3, self.vocab_size)
        # append eos (id 2) to every row
        input_ids = np.concatenate((input_ids, 2 * np.ones(
            (self.batch_size, 1), dtype=np.int64)), -1)

        # teacher-forcing decoder inputs: pad_token_id=1, decoder start=2
        decoder_input_ids = shift_tokens_right(input_ids, 1, 2)

        config = BartConfig(
            vocab_size=self.vocab_size,
            d_model=self.hidden_size,
            encoder_layers=self.num_hidden_layers,
            decoder_layers=self.num_hidden_layers,
            encoder_attention_heads=self.num_attention_heads,
            decoder_attention_heads=self.num_attention_heads,
            encoder_ffn_dim=self.intermediate_size,
            decoder_ffn_dim=self.intermediate_size,
            dropout=self.hidden_dropout_prob,
            attention_dropout=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            eos_token_id=self.eos_token_id,
            bos_token_id=self.bos_token_id,
            pad_token_id=self.pad_token_id,
            initializer_range=self.initializer_range,
            use_cache=False,
        )
        inputs_dict = prepare_bart_inputs_dict(config, input_ids,
                                               decoder_input_ids)
        return config, inputs_dict
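
The inputs_dict above comes from prepare_bart_inputs_dict, which the snippet
assumes is in scope. A plausible sketch, assuming the helper only derives
attention masks from config.pad_token_id (the real test helper may accept
additional arguments):

def prepare_bart_inputs_dict(config, input_ids, decoder_input_ids,
                             attention_mask=None,
                             decoder_attention_mask=None):
    # Derive attention masks from the pad token unless they were
    # passed in explicitly.
    if attention_mask is None:
        attention_mask = np.where(input_ids != config.pad_token_id, 1, 0)
    if decoder_attention_mask is None:
        decoder_attention_mask = np.where(
            decoder_input_ids != config.pad_token_id, 1, 0)
    return {
        "input_ids": input_ids,
        "decoder_input_ids": decoder_input_ids,
        "attention_mask": attention_mask,
        "decoder_attention_mask": decoder_attention_mask,
    }

The test below exercises shift_tokens_right directly on a hand-built batch:
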
    def test_shift_tokens_right(self):
        input_ids = np.array([[71, 82, 18, 33, 2, 1, 1],
                              [68, 34, 26, 58, 30, 82, 2]], dtype=np.int64)
        shifted = shift_tokens_right(input_ids, 1, 2)
        n_pad_before = np.equal(input_ids, 1).astype(np.float32).sum()
        n_pad_after = np.equal(shifted, 1).astype(np.float32).sum()
        self.assertEqual(shifted.shape, input_ids.shape)
        self.assertEqual(n_pad_after, n_pad_before - 1)
        self.assertTrue(np.equal(shifted[:, 0], 2).all())
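
For concreteness, running the shift on the test batch with the sketch from
Example #1 yields:

input_ids = np.array([[71, 82, 18, 33, 2, 1, 1],
                      [68, 34, 26, 58, 30, 82, 2]], dtype=np.int64)
shift_tokens_right(input_ids, 1, 2)
# array([[ 2, 71, 82, 18, 33,  2,  1],
#        [ 2, 68, 34, 26, 58, 30, 82]])
# Every row gains a leading 2 (the decoder start token) and drops its
# last token, so exactly one pad token (id 1) is shifted out: two pads
# before, one after, and column 0 is all 2s, matching the assertions.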