def test_prepare_bart_decoder_inputs(self):
    """Check decoder padding/causal mask construction with and without output_past."""
    config, *_ = self._get_config_and_data(output_past=False)
    # input_ids is only consulted for its .device when decoder_input_ids is passed
    input_ids = _long_tensor(([4, 4, 2]))
    decoder_input_ids = _long_tensor([[26388, 2, config.pad_token_id]])
    neg = LARGE_NEGATIVE

    decoder_input_ids, decoder_attn_mask = _prepare_bart_decoder_inputs(config, input_ids, decoder_input_ids)
    # Lower-triangular causal structure, plus the trailing pad token fully masked.
    causal_and_padding = torch.tensor(
        [
            [0, neg, neg],
            [0, 0, neg],
            [neg, neg, neg],  # last position is pad, so nothing may attend to it
        ]
    ).to(input_ids.device)
    self.assertEqual((1, 1, 3, 3), decoder_attn_mask.size())
    self.assertTrue(torch.eq(causal_and_padding, decoder_attn_mask).all())

    # With output_past=True the causal part is dropped; only padding stays masked.
    config, *_ = self._get_config_and_data(output_past=True)
    padding_only = torch.tensor(
        [[0, 0, 0], [0, 0, 0], [neg, neg, neg]]  # last position is pad
    ).to(input_ids.device)
    _, mask_without_causal = _prepare_bart_decoder_inputs(config, input_ids, decoder_input_ids)
    self.assertEqual((1, 1, 3, 3), mask_without_causal.size())
    self.assertTrue(torch.eq(padding_only, mask_without_causal).all())

    # No pad tokens and no causal mask: everything may attend to everything.
    decoder_input_ids = _long_tensor([[0, 26388, 4133, 2]])
    _, mask_unrestricted = _prepare_bart_decoder_inputs(
        config, input_ids, decoder_input_ids
    )
    self.assertTrue(torch.eq(mask_unrestricted, 0).all())
def test_advanced_inputs(self):
    """Internally built decoder masks must match masks passed in explicitly."""
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
    config.use_cache = False
    # Force the last two encoder positions to be padding.
    inputs_dict["input_ids"][:, -2:] = config.pad_token_id
    decoder_input_ids, decoder_attn_mask, causal_mask = _prepare_bart_decoder_inputs(
        config, inputs_dict["input_ids"]
    )
    model = BartModel(config).to(torch_device).eval()

    baseline = model(**inputs_dict)[0]
    with_explicit_mask = model(
        decoder_attention_mask=invert_mask(decoder_attn_mask),
        decoder_input_ids=decoder_input_ids,
        **inputs_dict,
    )[0]
    _assert_tensors_equal(with_explicit_mask, baseline)

    all_zero_mask = torch.zeros_like(decoder_attn_mask)
    unmasked_features = model(decoder_attention_mask=all_zero_mask, **inputs_dict)[0]
    self.assertIsInstance(unmasked_features, torch.Tensor)  # no hidden states or attentions
    expected_shape = (self.model_tester.batch_size, self.model_tester.seq_length, config.d_model)
    self.assertEqual(unmasked_features.size(), expected_shape)
    if decoder_attn_mask.min().item() < -1e3:  # some tokens were masked
        self.assertFalse((baseline == unmasked_features).all().item())

    # A long-dtype encoder attention mask should behave like the default one.
    long_mask_features = model(
        inputs_dict["input_ids"], attention_mask=inputs_dict["attention_mask"].long()
    )[0]
    _assert_tensors_equal(long_mask_features, baseline)
def forward(
    self,
    input_ids,
    decoder_input_ids,
    attention_mask=None,
    decoder_padding_mask=None,
    encoder_outputs=None,
    return_encoder_outputs=False,
):
    """Encode `input_ids`, decode `decoder_input_ids`, and score the encoder's
    pooled state with a frozen adversary.

    Returns (vocab logits of shape [batch, tgt_len-1, vocab_size], adversary
    outputs), or just `encoder_outputs` when `return_encoder_outputs` is True.
    """
    if attention_mask is None:
        # BART convention: True/1 marks padding positions to be ignored.
        attention_mask = input_ids == self.config.pad_token_id
    if encoder_outputs is None:
        encoder_outputs = self.encoder(input_ids, attention_mask=attention_mask)
    if return_encoder_outputs:
        return encoder_outputs
    assert encoder_outputs is not None
    assert decoder_input_ids is not None
    # Teacher forcing: the decoder consumes targets shifted by one (drop last token).
    decoder_input_ids = decoder_input_ids[:, :-1]
    _, decoder_padding_mask, decoder_causal_mask = _prepare_bart_decoder_inputs(
        self.config,
        input_ids=None,
        decoder_input_ids=decoder_input_ids,
        decoder_padding_mask=decoder_padding_mask,
        causal_mask_dtype=self.shared.weight.dtype,
    )
    # One always-attend slot (False = not masked) for the extra encoder state,
    # followed by the original mask past the sentence prefix.
    # NOTE(review): assumes positions [0, max_sent_len + 2) are a sentence
    # prefix replaced by encoder_outputs[1] — confirm against the encoder.
    # FIX: was `.bool().cuda()`, which broke CPU runs and non-default GPUs;
    # allocate on the input's device instead.
    attention_mask2 = torch.cat(
        (
            torch.zeros(input_ids.shape[0], 1, dtype=torch.bool, device=input_ids.device),
            attention_mask[:, self.config.max_sent_len + 2:],
        ),
        dim=1,
    )
    # decoder: cross-attend to [pooled state, remaining encoder states]
    decoder_outputs = self.decoder(
        decoder_input_ids,
        torch.cat(
            (encoder_outputs[1], encoder_outputs[0][:, self.config.max_sent_len + 2:]),
            dim=1,
        ),
        decoder_padding_mask=decoder_padding_mask,
        decoder_causal_mask=decoder_causal_mask,
        encoder_attention_mask=attention_mask2,
    )[0]
    batch_size = decoder_outputs.shape[0]
    outputs = self.linear(decoder_outputs.contiguous().view(-1, self.config.d_model))
    outputs = outputs.view(batch_size, -1, self.config.vocab_size)

    # Freeze the discriminator so generator updates don't train it.
    # FIX: was `p.required_grad = False` — a typo that only created a dead
    # attribute and left the adversary's parameters trainable.
    for p in self.adversary.parameters():
        p.requires_grad = False
    adv_outputs = self.adversary(encoder_outputs[1])
    return outputs, adv_outputs
def test_prepare_bart_decoder_inputs(self):
    """The returned causal mask must be lower-triangular (-inf above the diagonal)."""
    config, *_ = self._get_config_and_data()
    input_ids = _long_tensor(([4, 4, 2]))
    decoder_input_ids = _long_tensor([[26388, 2, config.pad_token_id]])
    neg_inf = float("-inf")
    decoder_input_ids, decoder_attn_mask, causal_mask = _prepare_bart_decoder_inputs(
        config, input_ids, decoder_input_ids
    )
    # never attend to the final token, because it is pad
    triangular = torch.tensor(
        [
            [0, neg_inf, neg_inf],
            [0, 0, neg_inf],
            [0, 0, 0],
        ]
    ).to(input_ids.device)
    self.assertEqual(decoder_input_ids.size(), decoder_attn_mask.size())
    self.assertTrue(torch.eq(triangular, causal_mask).all())
def test_advanced_inputs(self):
    """Check weight tying, N(0, init_std) initialization, and mask equivalence."""
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
    decoder_input_ids, decoder_attn_mask = _prepare_bart_decoder_inputs(
        config, inputs_dict["input_ids"]
    )
    model = BartModel(config)
    model.to(torch_device)
    model.eval()

    # Encoder embeddings must be tied to the shared embedding table.
    tied = (model.encoder.embed_tokens.weight == model.shared.weight).all().item()
    self.assertTrue(tied)

    def _check_var(module):
        """Check that we initialized various parameters from N(0, config.init_std)."""
        self.assertAlmostEqual(torch.std(module.weight).item(), config.init_std, 2)

    for submodule in (
        model.encoder.embed_tokens,
        model.encoder.layers[0].self_attn.k_proj,
        model.encoder.layers[0].fc1,
        model.encoder.embed_positions,
    ):
        _check_var(submodule)

    baseline = model.forward(**inputs_dict)[0]
    with_explicit_mask = model.forward(
        decoder_attention_mask=decoder_attn_mask,
        decoder_input_ids=decoder_input_ids,
        **inputs_dict,
    )[0]
    _assert_tensors_equal(with_explicit_mask, baseline)

    all_zero_mask = torch.zeros_like(decoder_attn_mask)
    unmasked_features = model.forward(decoder_attention_mask=all_zero_mask, **inputs_dict)[0]
    self.assertIsInstance(unmasked_features, torch.Tensor)  # no hidden states or attentions
    self.assertEqual(
        unmasked_features.size(),
        (self.model_tester.batch_size, self.model_tester.seq_length, config.d_model),
    )
    if decoder_attn_mask.min().item() < -1e3:  # some tokens were masked
        self.assertFalse((baseline == unmasked_features).all().item())

    # A long-dtype encoder attention mask should behave like the default one.
    long_mask_features = model.forward(
        inputs_dict["input_ids"], attention_mask=inputs_dict["attention_mask"].long()
    )[0]
    _assert_tensors_equal(long_mask_features, baseline)
def forward(
    self,
    input_ids,
    attention_mask=None,
    decoder_input_ids=None,
    encoder_outputs: Optional[Tuple] = None,
    decoder_attention_mask=None,
    decoder_cached_states=None,
    use_cache=False,
):
    """Run the encoder (unless cached outputs are supplied) and then the
    decoder; return decoder outputs followed by encoder outputs, with falsey
    entries filtered out of each."""
    # Build decoder masks only on the full forward pass; during incremental
    # decoding (use_cache) masking is handled by the cached decoder path.
    if use_cache:
        dec_padding_mask = None
        dec_causal_mask = None
    else:
        decoder_input_ids, dec_padding_mask, dec_causal_mask = _prepare_bart_decoder_inputs(
            self.config,
            input_ids,
            decoder_input_ids=decoder_input_ids,
            decoder_padding_mask=decoder_attention_mask,
            causal_mask_dtype=self.shared.weight.dtype,
        )
    assert decoder_input_ids is not None

    if encoder_outputs is None:
        if input_ids is None:
            # Nothing to encode; placeholder tuple keeps downstream indexing valid.
            encoder_outputs = (None, None)
        else:
            encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
    assert isinstance(encoder_outputs, tuple)

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    decoder_outputs = self.decoder(
        decoder_input_ids,
        encoder_outputs[0],
        attention_mask,
        dec_padding_mask,
        decoder_causal_mask=dec_causal_mask,
        decoder_cached_states=decoder_cached_states,
        use_cache=use_cache,
    )
    # Attention/hidden-state slots may be [] or None when not requested.
    decoder_outputs = _filter_out_falsey_values(decoder_outputs)
    assert isinstance(decoder_outputs[0], torch.Tensor)
    encoder_outputs = _filter_out_falsey_values(encoder_outputs)
    return decoder_outputs + encoder_outputs
def forward(self, input_ids, output_ids, input_mask, output_mask):
    """Encode `input_ids`, decode toward `output_ids`, and project to vocab logits.

    NOTE(review): `output_mask` is accepted but never used here — the decoder
    padding mask is rebuilt from `output_ids`; confirm callers expect that.
    """
    # enc_states[0]: [batch_size, max_length, hidden_size]
    enc_states = self.encoder(input_ids=input_ids, attention_mask=input_mask)
    dec_input_ids, dec_padding_mask, dec_causal_mask = modeling_bart._prepare_bart_decoder_inputs(
        self.config, input_ids=output_ids
    )
    dec_features, _, _, _ = self.decoder(
        input_ids=dec_input_ids,
        encoder_padding_mask=input_mask,
        decoder_padding_mask=dec_padding_mask,
        decoder_causal_mask=dec_causal_mask,
        encoder_hidden_states=enc_states[0],
    )
    return self.linear(dec_features)
def forward(
    self,
    input_ids,
    attention_mask=None,
    decoder_input_ids=None,
    encoder_outputs: Optional[Tuple] = None,
    decoder_attention_mask=None,
    decoder_past_key_values=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    return_tuple=None,
    **kwargs,
):
    """Seq2seq forward pass: encode (or reuse `encoder_outputs`), then decode.

    Returns a plain tuple when `return_tuple` is truthy, otherwise a
    `Seq2SeqModelOutput`. Unset flags fall back to the corresponding
    `self.config` defaults.
    """
    if decoder_input_ids is None:
        # Without explicit decoder inputs they are derived from input_ids,
        # which is incompatible with incremental decoding.
        use_cache = False

    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple

    # make masks if user doesn't supply
    if not use_cache:
        decoder_input_ids, decoder_padding_mask, causal_mask = _prepare_bart_decoder_inputs(
            self.config,
            input_ids,
            decoder_input_ids=decoder_input_ids,
            decoder_padding_mask=decoder_attention_mask,
            causal_mask_dtype=self.shared.weight.dtype,
        )
    else:
        decoder_padding_mask, causal_mask = None, None
    # FIX: removed stray `causal_mask[0, 1] = 0`. On the use_cache path
    # causal_mask is None, so it raised TypeError; on the other path it
    # punched a hole in the causal mask, letting position 0 attend to the
    # future position 1.

    assert decoder_input_ids is not None

    if encoder_outputs is None:
        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_tuple=return_tuple,
        )
    # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_tuple=False
    elif not return_tuple and not isinstance(encoder_outputs, BaseModelOutput):
        encoder_outputs = BaseModelOutput(
            last_hidden_state=encoder_outputs[0],
            hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
            attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
        )

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    decoder_outputs = self.decoder(
        decoder_input_ids,
        encoder_outputs[0],
        attention_mask,
        decoder_padding_mask,
        decoder_causal_mask=causal_mask,
        decoder_past_key_values=decoder_past_key_values,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_tuple=return_tuple,
    )

    if return_tuple:
        return decoder_outputs + encoder_outputs

    return Seq2SeqModelOutput(
        last_hidden_state=decoder_outputs.last_hidden_state,
        decoder_past_key_values=decoder_outputs.past_key_values,
        decoder_hidden_states=decoder_outputs.hidden_states,
        decoder_attentions=decoder_outputs.attentions,
        encoder_last_hidden_state=encoder_outputs.last_hidden_state,
        encoder_hidden_states=encoder_outputs.hidden_states,
        encoder_attentions=encoder_outputs.attentions,
    )
def sample_generate(
    top_k=50,
    temperature=0.7,
    decoder_path="/content/BART_CheckPoints/model-9.pth",
    batch_size=1,
    gpu_id=0,
):
    """Sample replies from a fine-tuned BART checkpoint over the saved test set
    and print BLEU / NIST / METEOR / entropy / length statistics.

    Args:
        top_k: keep only the k most likely tokens at each step.
        temperature: logit temperature for sampling.
        decoder_path: path to the saved state_dict.
        batch_size: dataloader batch size (generation below assumes 1).
        gpu_id: CUDA device index to run on.
    """
    # make sure your model is on GPU
    device = torch.device(f"cuda:{gpu_id}")

    print('load model')
    # ------------------------LOAD MODEL-----------------
    tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
    model = BartLM()
    # FIX: map_location was hard-coded to 'cuda' (device 0), ignoring gpu_id.
    model.load_state_dict(torch.load(decoder_path, map_location=device))
    model = model.to(device)
    model.eval()
    print('load success')
    # ------------------------END LOAD MODEL--------------

    # ------------------------LOAD VALIDATE DATA------------------
    test_data = torch.load("/content/test_data.pth")
    test_dataset = TensorDataset(*test_data)
    test_dataloader = DataLoader(dataset=test_dataset, shuffle=False, batch_size=batch_size)
    # ------------------------END LOAD VALIDATE DATA--------------

    # ------------------------START SAMPLE GENERETE-------------------
    update_count = 0
    bleu_2scores = 0
    bleu_4scores = 0
    nist_2scores = 0
    nist_4scores = 0
    meteor_scores = 0
    sentences = []
    print('start generate....')

    for batch in test_dataloader:
        with torch.no_grad():
            batch = [item.to(device) for item in batch]
            encoder_input, decoder_input, mask_encoder_input, _ = batch
            past = model.encoder(encoder_input, mask_encoder_input)

            prev_pred = decoder_input[:, :1]
            sentence = prev_pred

            # autoregressive decoding loop, capped at 100 generated tokens
            for i in range(100):
                # NOTE(review): `config` is read from module scope here —
                # confirm it matches the checkpoint's configuration.
                decoder_input_ids, decoder_padding_mask, causal_mask = _prepare_bart_decoder_inputs(
                    config, input_ids=sentence
                )
                logits, _, _, _ = model.decoder(
                    input_ids=decoder_input_ids,
                    encoder_padding_mask=mask_encoder_input,
                    decoder_padding_mask=decoder_padding_mask,
                    decoder_causal_mask=causal_mask,
                    encoder_hidden_states=past[0],
                )
                logits = model.linear(logits)
                logits = logits[:, -1]
                logits = logits.squeeze(1) / temperature
                logits = top_k_logits(logits, k=top_k)
                probs = F.softmax(logits, dim=-1)
                prev_pred = torch.multinomial(probs, num_samples=1)
                sentence = torch.cat([sentence, prev_pred], dim=-1)
                # NOTE(review): 102 is presumably the end-of-sequence id —
                # confirm against the tokenizer (BART's eos is usually 2).
                if prev_pred[0][0] == 102:
                    break

            predict = tokenizer.convert_ids_to_tokens(sentence[0].tolist())
            target = decoder_input.squeeze(dim=0)
            target_num = (target != 0).sum()
            inputs = encoder_input.squeeze(dim=0)
            input_num = (inputs != 0).sum()
            inputs = tokenizer.convert_ids_to_tokens(inputs[:input_num].tolist())
            reference = tokenizer.convert_ids_to_tokens(target[:target_num].tolist())

            print('-' * 20 + f"example {update_count}" + '-' * 20)
            print("input: {}".format(re.sub("Ġ", "", " ".join(inputs))))
            print("output: {}".format(re.sub("Ġ", "", " ".join(reference))))
            print("predict: {}".format(re.sub("Ġ", "", " ".join(predict))))

            temp_bleu_2, \
            temp_bleu_4, \
            temp_nist_2, \
            temp_nist_4, \
            temp_meteor_scores = calculate_metrics(predict[1:-1], reference[1:-1])

            bleu_2scores += temp_bleu_2
            bleu_4scores += temp_bleu_4
            nist_2scores += temp_nist_2
            nist_4scores += temp_nist_4
            meteor_scores += temp_meteor_scores
            sentences.append(" ".join(predict[1:-1]))
            update_count += 1

    # FIX: guard against an empty test set, which previously crashed with
    # ZeroDivisionError in the averages below.
    if update_count == 0:
        print('no test examples found; nothing was generated')
        return

    entro, dist = cal_entropy(sentences)
    mean_len, var_len = cal_length(sentences)
    print(f'avg: {mean_len}, var: {var_len}')
    print(f'entro: {entro}')
    print(f'dist: {dist}')
    print(f'test bleu_2scores: {bleu_2scores / update_count}')
    print(f'test bleu_4scores: {bleu_4scores / update_count}')
    print(f'test nist_2scores: {nist_2scores / update_count}')
    print(f'test nist_4scores: {nist_4scores / update_count}')
    print(f'test meteor_scores: {meteor_scores / update_count}')
def forward(
    self,
    input_ids,
    attention_mask=None,
    decoder_input_ids=None,
    encoder_outputs: Optional[Tuple] = None,
    decoder_attention_mask=None,
    decoder_cached_states=None,
    use_cache=False,
    final_layer=None,
):
    """Encode once, then run one or more decoder branches over the encoder output.

    During training (final_layer is None) `decoder_input_ids` is expected to be
    a list — one entry per decoder branch — and every branch is run. When
    `final_layer` is an int, only that branch is run; at generation time the
    decoder inputs arrive un-listed and are wrapped here.
    Returns (list of filtered per-branch decoder outputs, filtered encoder outputs).
    """
    # Encode unless the caller already supplies encoder outputs; a (None,)
    # placeholder keeps downstream tuple indexing valid when there is no input.
    if encoder_outputs is None and input_ids is not None:
        encoder_outputs = self.encoder(
            input_ids=input_ids, attention_mask=attention_mask
        )
    elif encoder_outputs is None:
        encoder_outputs = (None,)
    assert isinstance(encoder_outputs, tuple)

    # Pad the per-branch lists with None so the zip below stays aligned with
    # decoder_input_ids. NOTE(review): assumes decoder_input_ids is a list at
    # this point when these defaults are needed — confirm callers guarantee it.
    if decoder_cached_states is None:
        decoder_cached_states = [None] * len(decoder_input_ids)
    if decoder_attention_mask is None:
        decoder_attention_mask = [None] * len(decoder_input_ids)

    all_dec_outputs = []

    if isinstance(final_layer, int):
        if isinstance(decoder_input_ids, list):
            # Select the single requested branch from the per-branch lists.
            decoder_input_ids = [decoder_input_ids[final_layer]]
            decoder_attention_mask = [decoder_attention_mask[final_layer]]
            decoder_cached_states = [decoder_cached_states[final_layer]]
        else:
            # decoder doesn't come in multi output: generation time.
            # Wrap the bare inputs so the loop below can treat both cases alike.
            decoder_input_ids = [decoder_input_ids]
            decoder_attention_mask = [decoder_attention_mask]
            decoder_cached_states = (
                decoder_cached_states
                if decoder_cached_states[0] is None
                else [decoder_cached_states]
            )

    # If final_layer is None (i.e. at training time), every decoder branch is
    # computed; otherwise only the selected branch runs.
    for idx, (d_input_ids, d_attn_mask, d_cached_states) in enumerate(
        zip(decoder_input_ids, decoder_attention_mask, decoder_cached_states)
    ):
        # make masks if user doesn't supply; with use_cache the cached decoder
        # path handles masking itself.
        if not use_cache:
            (
                d_input_ids,
                d_padding_mask,
                causal_mask,
            ) = _prepare_bart_decoder_inputs(
                self.config,
                input_ids,
                decoder_input_ids=d_input_ids,
                decoder_padding_mask=d_attn_mask,
                causal_mask_dtype=self.shared.weight.dtype,
            )
        else:
            d_padding_mask, causal_mask = None, None
        assert decoder_input_ids is not None

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        decoder_outputs = self.decoder(
            d_input_ids,
            encoder_outputs[0],
            attention_mask,
            d_padding_mask,
            decoder_causal_mask=causal_mask,
            decoder_cached_states=d_cached_states,
            use_cache=use_cache,
            # When no branch was selected, tell the decoder which branch this is.
            final_layer=(final_layer if final_layer is not None else idx),
        )
        all_dec_outputs.append(decoder_outputs)

    # Attention and hidden_states will be [] or None if they aren't needed.
    decoder_outputs = [
        _filter_out_falsey_values(d_outs) for d_outs in all_dec_outputs
    ]
    assert isinstance(decoder_outputs[0][0], torch.Tensor)
    encoder_outputs = _filter_out_falsey_values(encoder_outputs)
    return (decoder_outputs, encoder_outputs)