def test_prepare_bart_decoder_inputs(self):
        config, *_ = self._get_config_and_data(output_past=False)
        input_ids = _long_tensor([4, 4, 2])  # only used for .device if decoder_input_ids is passed
        decoder_input_ids = _long_tensor([[26388, 2, config.pad_token_id]])
        ignore = LARGE_NEGATIVE
        decoder_input_ids, decoder_attn_mask = _prepare_bart_decoder_inputs(config, input_ids, decoder_input_ids)
        expected_mask = torch.tensor(
            [
                [0, ignore, ignore],
                [0, 0, ignore],
                [ignore, ignore, ignore],  # the final position is pad, so it attends to nothing
            ]
        ).to(input_ids.device)
        self.assertEqual(decoder_attn_mask.size(), (1, 1, 3, 3))
        self.assertTrue(torch.eq(expected_mask, decoder_attn_mask).all())

        # Test no causal mask
        config, *_ = self._get_config_and_data(output_past=True)
        expected_just_padding_mask = torch.tensor(
            [[0, 0, 0], [0, 0, 0], [ignore, ignore, ignore]]  # the final position is pad, so it attends to nothing
        ).to(input_ids.device)
        _, decoder_attn_mask_no_causal_mask = _prepare_bart_decoder_inputs(config, input_ids, decoder_input_ids)
        self.assertEqual(decoder_attn_mask_no_causal_mask.size(), (1, 1, 3, 3))
        self.assertTrue(torch.eq(expected_just_padding_mask, decoder_attn_mask_no_causal_mask).all())

        decoder_input_ids = _long_tensor([[0, 26388, 4133, 2]])
        # Attend to everything if no pad tokens and no causal mask
        _, decoder_attn_mask_no_padding_no_causal_mask = _prepare_bart_decoder_inputs(
            config, input_ids, decoder_input_ids
        )
        self.assertTrue(torch.eq(decoder_attn_mask_no_padding_no_causal_mask, 0).all())
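For reference, a minimal sketch (not the transformers implementation) that reproduces the expected tensors in this test: an upper-triangular causal part plus fully masked rows for pad queries. make_decoder_mask and the LARGE_NEGATIVE constant here are illustrative stand-ins.

import torch

LARGE_NEGATIVE = -1e8  # stand-in; the test imports its own constant

def make_decoder_mask(decoder_input_ids, pad_token_id, causal=True):
    """Sketch: build a (bsz, 1, tgt_len, tgt_len) additive attention mask."""
    bsz, tgt_len = decoder_input_ids.shape
    base = torch.zeros(tgt_len, tgt_len)
    if causal:
        # LARGE_NEGATIVE above the diagonal: position i sees only j <= i.
        base = torch.triu(torch.full((tgt_len, tgt_len), LARGE_NEGATIVE), diagonal=1)
    mask = base[None, None].expand(bsz, 1, tgt_len, tgt_len)
    # A pad position used as a query attends to nothing at all.
    pad_queries = decoder_input_ids.eq(pad_token_id)[:, None, :, None]
    return mask.masked_fill(pad_queries, LARGE_NEGATIVE)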
Example #2
    def test_advanced_inputs(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.use_cache = False
        inputs_dict["input_ids"][:, -2:] = config.pad_token_id
        decoder_input_ids, decoder_attn_mask, causal_mask = _prepare_bart_decoder_inputs(
            config, inputs_dict["input_ids"]
        )
        model = BartModel(config).to(torch_device).eval()

        decoder_features_with_created_mask = model(**inputs_dict)[0]
        decoder_features_with_passed_mask = model(
            decoder_attention_mask=invert_mask(decoder_attn_mask), decoder_input_ids=decoder_input_ids, **inputs_dict
        )[0]
        _assert_tensors_equal(decoder_features_with_passed_mask, decoder_features_with_created_mask)
        useless_mask = torch.zeros_like(decoder_attn_mask)
        decoder_features = model(decoder_attention_mask=useless_mask, **inputs_dict)[0]
        self.assertTrue(isinstance(decoder_features, torch.Tensor))  # no hidden states or attentions
        self.assertEqual(
            decoder_features.size(), (self.model_tester.batch_size, self.model_tester.seq_length, config.d_model)
        )
        if decoder_attn_mask.min().item() < -1e3:  # some tokens were masked
            self.assertFalse((decoder_features_with_created_mask == decoder_features).all().item())

        # Test different encoder attention masks
        decoder_features_with_long_encoder_mask = model(
            inputs_dict["input_ids"], attention_mask=inputs_dict["attention_mask"].long()
        )[0]
        _assert_tensors_equal(decoder_features_with_long_encoder_mask, decoder_features_with_created_mask)
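The invert_mask helper used above is not defined in this snippet; a plausible one-liner, assuming the usual convention that 1/True marks tokens to keep:

def invert_mask(attention_mask):
    # Flip the convention: 1/True (keep) -> False, 0 (pad) -> True.
    assert attention_mask.dim() == 2
    return attention_mask.eq(0)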
Example #3
    def forward(
        self,
        input_ids,
        decoder_input_ids,
        attention_mask=None,
        decoder_padding_mask=None,
        encoder_outputs=None,
        return_encoder_outputs=False,
    ):
        if attention_mask is None:
            # True at padding positions (the inverted convention BART uses internally)
            attention_mask = input_ids == self.config.pad_token_id

        if encoder_outputs is None:
            encoder_outputs = self.encoder(input_ids,
                                           attention_mask=attention_mask)

        if return_encoder_outputs:
            return encoder_outputs

        assert encoder_outputs is not None
        assert decoder_input_ids is not None

        decoder_input_ids = decoder_input_ids[:, :-1]

        _, decoder_padding_mask, decoder_causal_mask = _prepare_bart_decoder_inputs(
            self.config,
            input_ids=None,
            decoder_input_ids=decoder_input_ids,
            decoder_padding_mask=decoder_padding_mask,
            causal_mask_dtype=self.shared.weight.dtype,
        )

        attention_mask2 = torch.cat(
            (torch.zeros(input_ids.shape[0], 1, dtype=torch.bool, device=input_ids.device),
             attention_mask[:, self.config.max_sent_len + 2:]),
            dim=1)

        # decoder
        decoder_outputs = self.decoder(
            decoder_input_ids,
            torch.cat((encoder_outputs[1],
                       encoder_outputs[0][:, self.config.max_sent_len + 2:]),
                      dim=1),
            decoder_padding_mask=decoder_padding_mask,
            decoder_causal_mask=decoder_causal_mask,
            encoder_attention_mask=attention_mask2,
        )[0]

        batch_size = decoder_outputs.shape[0]
        outputs = self.linear(decoder_outputs.contiguous().view(
            -1, self.config.d_model))
        outputs = outputs.view(batch_size, -1, self.config.vocab_size)

        # discriminator (frozen, so no gradients flow into the adversary)
        for p in self.adversary.parameters():
            p.requires_grad = False
        adv_outputs = self.adversary(encoder_outputs[1])

        return outputs, adv_outputs
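The adversary-freezing loop above is a common pattern; a small reusable helper (a sketch, nothing model-specific):

def freeze(module):
    """Disable gradient updates for every parameter of the module."""
    for p in module.parameters():
        p.requires_grad = False

Wrapping the adversary call in torch.no_grad() would achieve the same effect for a single forward pass without touching parameter flags.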
Example #4
    def test_prepare_bart_decoder_inputs(self):
        config, *_ = self._get_config_and_data()
        input_ids = _long_tensor([4, 4, 2])  # only used for .device
        decoder_input_ids = _long_tensor([[26388, 2, config.pad_token_id]])
        ignore = float("-inf")
        decoder_input_ids, decoder_attn_mask, causal_mask = _prepare_bart_decoder_inputs(
            config, input_ids, decoder_input_ids
        )
        expected_causal_mask = torch.tensor(
            [[0, ignore, ignore], [0, 0, ignore], [0, 0, 0]]  # each position attends only to itself and earlier positions
        ).to(input_ids.device)
        self.assertEqual(decoder_attn_mask.size(), decoder_input_ids.size())
        self.assertTrue(torch.eq(expected_causal_mask, causal_mask).all())
Example #5
    def test_advanced_inputs(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        decoder_input_ids, decoder_attn_mask = _prepare_bart_decoder_inputs(
            config, inputs_dict["input_ids"])
        model = BartModel(config)
        model.to(torch_device)
        model.eval()
        # test init
        self.assertTrue(
            (model.encoder.embed_tokens.weight == model.shared.weight
             ).all().item())

        def _check_var(module):
            """Check that we initialized various parameters from N(0, config.init_std)."""
            self.assertAlmostEqual(
                torch.std(module.weight).item(), config.init_std, 2)

        _check_var(model.encoder.embed_tokens)
        _check_var(model.encoder.layers[0].self_attn.k_proj)
        _check_var(model.encoder.layers[0].fc1)
        _check_var(model.encoder.embed_positions)

        decoder_features_with_created_mask = model(**inputs_dict)[0]
        decoder_features_with_passed_mask = model(
            decoder_attention_mask=decoder_attn_mask,
            decoder_input_ids=decoder_input_ids,
            **inputs_dict)[0]
        _assert_tensors_equal(decoder_features_with_passed_mask,
                              decoder_features_with_created_mask)
        useless_mask = torch.zeros_like(decoder_attn_mask)
        decoder_features = model(decoder_attention_mask=useless_mask,
                                 **inputs_dict)[0]
        self.assertTrue(isinstance(
            decoder_features, torch.Tensor))  # no hidden states or attentions
        self.assertEqual(decoder_features.size(),
                         (self.model_tester.batch_size,
                          self.model_tester.seq_length, config.d_model))
        if decoder_attn_mask.min().item() < -1e3:  # some tokens were masked
            self.assertFalse(
                (decoder_features_with_created_mask == decoder_features
                 ).all().item())

        # Test different encoder attention masks
        decoder_features_with_long_encoder_mask = model(
            inputs_dict["input_ids"],
            attention_mask=inputs_dict["attention_mask"].long())[0]
        _assert_tensors_equal(decoder_features_with_long_encoder_mask,
                              decoder_features_with_created_mask)
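For context, the N(0, init_std) initialization that _check_var exercises typically looks like the following sketch (not the exact transformers code):

import torch.nn as nn

def init_weights(module, init_std):
    # Draw Linear and Embedding weights from N(0, init_std); zero the biases.
    if isinstance(module, (nn.Linear, nn.Embedding)):
        module.weight.data.normal_(mean=0.0, std=init_std)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()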
Example #6
    def forward(
        self,
        input_ids,
        attention_mask=None,
        decoder_input_ids=None,
        encoder_outputs: Optional[Tuple] = None,
        decoder_attention_mask=None,
        decoder_cached_states=None,
        use_cache=False,
    ):

        # make masks if user doesn't supply
        if not use_cache:
            (
                decoder_input_ids,
                decoder_padding_mask,
                causal_mask,
            ) = _prepare_bart_decoder_inputs(
                self.config,
                input_ids,
                decoder_input_ids=decoder_input_ids,
                decoder_padding_mask=decoder_attention_mask,
                causal_mask_dtype=self.shared.weight.dtype,
            )
        else:
            decoder_padding_mask, causal_mask = None, None

        assert decoder_input_ids is not None
        if input_ids is None and encoder_outputs is None:
            encoder_outputs = (None, None)
        elif encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids, attention_mask=attention_mask
            )
        assert isinstance(encoder_outputs, tuple)
        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
        decoder_outputs = self.decoder(
            decoder_input_ids,
            encoder_outputs[0],
            attention_mask,
            decoder_padding_mask,
            decoder_causal_mask=causal_mask,
            decoder_cached_states=decoder_cached_states,
            use_cache=use_cache,
        )
        # Attention and hidden_states will be [] or None if they aren't needed
        decoder_outputs: Tuple = _filter_out_falsey_values(decoder_outputs)
        assert isinstance(decoder_outputs[0], torch.Tensor)
        encoder_outputs: Tuple = _filter_out_falsey_values(encoder_outputs)
        return decoder_outputs + encoder_outputs
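_filter_out_falsey_values is not shown here; one plausible definition, assuming the tuples mix tensors, lists, and None:

import torch

def _filter_out_falsey_values(tup):
    """Drop entries that are None or [] while keeping every tensor."""
    return tuple(x for x in tup if isinstance(x, torch.Tensor) or x)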
Example #7
    def forward(self, input_ids, output_ids, input_mask, output_mask):
        # encoder_hidden_states: [batch_size, max_length, hidden_size]
        encoder_hidden_states = self.encoder(input_ids=input_ids,
                                             attention_mask=input_mask)
        # out: [batch_size, max_length, hidden_size]
        decoder_input_ids, decoder_padding_mask, causal_mask = modeling_bart._prepare_bart_decoder_inputs(
            self.config, input_ids=output_ids)
        out, _, _, _ = self.decoder(
            input_ids=decoder_input_ids,
            encoder_padding_mask=input_mask,
            decoder_padding_mask=decoder_padding_mask,
            decoder_causal_mask=causal_mask,
            encoder_hidden_states=encoder_hidden_states[0])
        out = self.linear(out)
        return out
Example #8
    def forward(
        self,
        input_ids,
        attention_mask=None,
        decoder_input_ids=None,
        encoder_outputs: Optional[Tuple] = None,
        decoder_attention_mask=None,
        decoder_past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_tuple=None,
        **kwargs,
    ):

        if decoder_input_ids is None:
            use_cache = False

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else
                                self.config.output_hidden_states)
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple

        # make masks if user doesn't supply
        if not use_cache:
            decoder_input_ids, decoder_padding_mask, causal_mask = _prepare_bart_decoder_inputs(
                self.config,
                input_ids,
                decoder_input_ids=decoder_input_ids,
                decoder_padding_mask=decoder_attention_mask,
                causal_mask_dtype=self.shared.weight.dtype,
            )
        else:
            decoder_padding_mask, causal_mask = None, None

        # Custom tweak: additionally let position 0 attend to position 1.
        if causal_mask is not None:  # causal_mask is None when use_cache=True
            causal_mask[0, 1] = 0

        assert decoder_input_ids is not None

        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_tuple=return_tuple,
            )
        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_tuple=False
        elif not return_tuple and not isinstance(encoder_outputs,
                                                 BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1]
                if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2]
                if len(encoder_outputs) > 2 else None,
            )

        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
        decoder_outputs = self.decoder(
            decoder_input_ids,
            encoder_outputs[0],
            attention_mask,
            decoder_padding_mask,
            decoder_causal_mask=causal_mask,
            decoder_past_key_values=decoder_past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_tuple=return_tuple,
        )

        if return_tuple:
            return decoder_outputs + encoder_outputs

        return Seq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            decoder_past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )
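A short usage note on the two return styles above (hypothetical call, assuming a configured model and inputs):

outputs = model(input_ids, return_tuple=True)   # plain tuple of decoder then encoder outputs
last_hidden = outputs[0]

outputs = model(input_ids, return_tuple=False)  # structured Seq2SeqModelOutput
last_hidden = outputs.last_hidden_state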
Example #9
def sample_generate(
        top_k=50,
        temperature=0.7,
        decoder_path="/content/BART_CheckPoints/model-9.pth",
        batch_size=1,
        gpu_id=0
):
    # make sure your model is on GPU
    device = torch.device(f"cuda:{gpu_id}")

    print('load model')
    # ------------------------LOAD MODEL-----------------
    tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
    model = BartLM()
    model.load_state_dict(torch.load(decoder_path, map_location=device))
    model = model.to(device)
    model.eval()

    print('load success')
    # ------------------------END LOAD MODEL--------------

    # ------------------------LOAD VALIDATE DATA------------------
    test_data = torch.load("/content/test_data.pth")
    test_dataset = TensorDataset(*test_data)
    test_dataloader = DataLoader(dataset=test_dataset, shuffle=False, batch_size=batch_size)
    # ------------------------END LOAD VALIDATE DATA--------------

    # ------------------------START SAMPLE GENERATE-------------------
    update_count = 0

    bleu_2scores = 0
    bleu_4scores = 0
    nist_2scores = 0
    nist_4scores = 0

    sen_length = 0
    meteor_scores = 0

    sentences = []
    print('start generating...')

    for batch in test_dataloader:
        with torch.no_grad():
            #############################################################
            batch = [item.to(device) for item in batch]

            encoder_input, decoder_input, mask_encoder_input, _ = batch

            past = model.encoder(encoder_input, mask_encoder_input)

            prev_pred = decoder_input[:, :1]
            sentence = prev_pred

            # decoding loop
            for i in range(100):  # sample at most 100 tokens
                decoder_input_ids, decoder_padding_mask, causal_mask = _prepare_bart_decoder_inputs(
                    model.config, input_ids=sentence)
                logits, _, _, _ = model.decoder(input_ids=decoder_input_ids, encoder_padding_mask=mask_encoder_input,
                                                decoder_padding_mask=decoder_padding_mask,
                                                decoder_causal_mask=causal_mask,
                                                encoder_hidden_states=past[0])
                logits = model.linear(logits)
                logits = logits[:, -1]
                logits = logits.squeeze(1) / temperature

                logits = top_k_logits(logits, k=top_k)
                probs = F.softmax(logits, dim=-1)
                prev_pred = torch.multinomial(probs, num_samples=1)
                sentence = torch.cat([sentence, prev_pred], dim=-1)
                if prev_pred[0][0] == 102:  # assumed end-of-sequence id for this checkpoint
                    break

            predict = tokenizer.convert_ids_to_tokens(sentence[0].tolist())
            target = decoder_input.squeeze(dim=0)
            target_num = (target != 0).sum()
            inputs = encoder_input.squeeze(dim=0)
            input_num = (inputs != 0).sum()
            inputs = tokenizer.convert_ids_to_tokens(inputs[:input_num].tolist())
            reference = tokenizer.convert_ids_to_tokens(target[:target_num].tolist())

            print('-' * 20 + f"example {update_count}" + '-' * 20)
            print("input: {}".format(re.sub("Ġ", "", " ".join(inputs))))
            print("output: {}".format(re.sub("Ġ", "", " ".join(reference))))
            print("predict: {}".format(re.sub("Ġ", "", " ".join(predict))))

            (temp_bleu_2, temp_bleu_4, temp_nist_2, temp_nist_4,
             temp_meteor_scores) = calculate_metrics(predict[1:-1], reference[1:-1])

            bleu_2scores += temp_bleu_2
            bleu_4scores += temp_bleu_4
            nist_2scores += temp_nist_2
            nist_4scores += temp_nist_4

            meteor_scores += temp_meteor_scores
            sentences.append(" ".join(predict[1:-1]))
            update_count += 1

    entro, dist = cal_entropy(sentences)
    mean_len, var_len = cal_length(sentences)
    print(f'avg: {mean_len}, var: {var_len}')
    print(f'entro: {entro}')
    print(f'dist: {dist}')
    print(f'test bleu_2scores: {bleu_2scores / update_count}')
    print(f'test bleu_4scores: {bleu_4scores / update_count}')
    print(f'test nist_2scores: {nist_2scores / update_count}')
    print(f'test nist_4scores: {nist_4scores / update_count}')
    print(f'test meteor_scores: {meteor_scores / update_count}')
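top_k_logits is called in the decoding loop above but not defined here; a common formulation (a sketch, not necessarily the author's version):

import torch

def top_k_logits(logits, k):
    """Keep the k largest logits per row; push everything else to -inf."""
    if k <= 0:
        return logits
    kth_largest = torch.topk(logits, k, dim=-1).values[..., -1, None]
    return logits.masked_fill(logits < kth_largest, float("-inf"))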
Example #10
    def forward(
        self,
        input_ids,
        attention_mask=None,
        decoder_input_ids=None,
        encoder_outputs: Optional[Tuple] = None,
        decoder_attention_mask=None,
        decoder_cached_states=None,
        use_cache=False,
        final_layer=None,
    ):
        if encoder_outputs is None and input_ids is not None:
            encoder_outputs = self.encoder(
                input_ids=input_ids, attention_mask=attention_mask
            )
        elif encoder_outputs is None:
            encoder_outputs = (None,)
        assert isinstance(encoder_outputs, tuple)

        if decoder_cached_states is None:
            decoder_cached_states = [None] * len(decoder_input_ids)

        if decoder_attention_mask is None:
            decoder_attention_mask = [None] * len(decoder_input_ids)

        all_dec_outputs = []
        if isinstance(final_layer, int):
            if isinstance(decoder_input_ids, list):
                decoder_input_ids = [decoder_input_ids[final_layer]]
                decoder_attention_mask = [decoder_attention_mask[final_layer]]
                decoder_cached_states = [decoder_cached_states[final_layer]]
            else:  # decoder inputs came as a single tensor, not per-branch lists: generation time
                decoder_input_ids = [decoder_input_ids]
                decoder_attention_mask = [decoder_attention_mask]
                decoder_cached_states = (
                    decoder_cached_states
                    if decoder_cached_states[0] is None
                    else [decoder_cached_states]
                )

        # If final_layer is None (i.e. at training time), it will compute both outputs
        # Otherwise a specified decoder branch will be used
        for idx, (d_input_ids, d_attn_mask, d_cached_states) in enumerate(
            zip(decoder_input_ids, decoder_attention_mask, decoder_cached_states)
        ):
            # make masks if user doesn't supply
            if not use_cache:
                (
                    d_input_ids,
                    d_padding_mask,
                    causal_mask,
                ) = _prepare_bart_decoder_inputs(
                    self.config,
                    input_ids,
                    decoder_input_ids=d_input_ids,
                    decoder_padding_mask=d_attn_mask,
                    causal_mask_dtype=self.shared.weight.dtype,
                )
            else:
                d_padding_mask, causal_mask = None, None

            assert d_input_ids is not None
            # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
            decoder_outputs = self.decoder(
                d_input_ids,
                encoder_outputs[0],
                attention_mask,
                d_padding_mask,
                decoder_causal_mask=causal_mask,
                decoder_cached_states=d_cached_states,
                use_cache=use_cache,
                final_layer=(final_layer if final_layer is not None else idx),
            )
            all_dec_outputs.append(decoder_outputs)

        # Attention and hidden_states will be [] or None if they aren't needed
        decoder_outputs = [
            _filter_out_falsey_values(d_outs) for d_outs in all_dec_outputs
        ]
        assert isinstance(decoder_outputs[0][0], torch.Tensor)
        encoder_outputs = _filter_out_falsey_values(encoder_outputs)
        return (decoder_outputs, encoder_outputs)
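Hypothetical training-time call of the multi-branch forward above (names and shapes invented for illustration); final_layer=None runs every decoder branch, while an integer selects one:

import torch

input_ids = torch.ones(2, 7, dtype=torch.long)
decoder_ids = [torch.ones(2, 5, dtype=torch.long) for _ in range(2)]

dec_outs, enc_outs = model(
    input_ids=input_ids,
    decoder_input_ids=decoder_ids,  # one tensor per decoder branch
    final_layer=None,               # None -> compute both branches
)
first_branch_features = dec_outs[0][0]  # (batch, tgt_len, d_model)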