class BERT:
    def __init__(self):
        # tokenizer
        self.tokenizer = BertTokenizer.from_pretrained(BERT_TYPE)
        # special tokens
        self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        self.tokenizer.add_special_tokens(
            {'additional_special_tokens': dataset_tokens})
        # chess tokens
        self.tokenizer.add_tokens(get_chess_tokens())
        # model
        self.configuration = BertConfig.from_pretrained(BERT_TYPE)
        self.configuration.is_decoder = True
        self.model = BertLMHeadModel.from_pretrained(
            BERT_TYPE, config=self.configuration).cuda()
        self.model.resize_token_embeddings(len(self.tokenizer))

    def load_model(self, model_path):
        self.model = BertLMHeadModel(self.configuration)
        # resize to the extended vocabulary (see __init__) so the checkpoint's
        # embedding matrix matches before loading the weights
        self.model.resize_token_embeddings(len(self.tokenizer))
        self.model.load_state_dict(torch.load(model_path))
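# --- Added usage sketch (not from the original project). Assumes BERT_TYPE, dataset_tokens
# and get_chess_tokens() are defined elsewhere in that project, and that a CUDA device is
# available because the wrapper moves the model to the GPU. ---
import torch

bert = BERT()
enc = bert.tokenizer("e2e4 e7e5", return_tensors="pt").to("cuda")
with torch.no_grad():
    logits = bert.model(**enc).logits            # [batch, seq_len, vocab_size]
next_move_id = logits[0, -1].argmax().item()     # greedy pick for the next token
print(bert.tokenizer.decode([next_move_id]))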
def __init__(
    self,
    SentenceEncoder,
    device,
    ContextEncoder,
    no_contextencoder_before_languagemodel=False,
):
    super().__init__()
    self.sentence_encoder = SentenceEncoder

    # Context Encoder
    if ContextEncoder == "GRUContextEncoder":
        self.context_encoder = GRUContextEncoder(input_size=768, hidden_size=768)
    elif ContextEncoder == "PoolContextEncoder":
        self.context_encoder = PoolContextEncoder(input_size=768, hidden_size=768)

    self.decoder = BertLMHeadModel.from_pretrained(
        "bert-base-uncased",
        is_decoder=True,
        add_cross_attention=True,
        output_hidden_states=True,
    )
    self.mpp_classifier = nn.Linear(768, 5)
    self.device = device
    self.no_contextencoder_before_languagemodel = (
        no_contextencoder_before_languagemodel)
def __init__(self, device='cpu', model=None):
    vocabsize = 37
    max_length = 50
    encoder_config = BertConfig(vocab_size=vocabsize,
                                max_position_embeddings=max_length + 64,
                                num_attention_heads=4,
                                num_hidden_layers=4,
                                hidden_size=128,
                                type_vocab_size=1)
    encoder = BertModel(config=encoder_config)

    vocabsize = 33
    max_length = 50
    decoder_config = BertConfig(vocab_size=vocabsize,
                                max_position_embeddings=max_length + 64,
                                num_attention_heads=4,
                                num_hidden_layers=4,
                                hidden_size=128,
                                type_vocab_size=1,
                                is_decoder=True,
                                add_cross_attention=True)
    decoder = BertLMHeadModel(config=decoder_config)

    # Define encoder decoder model
    self.model = EncoderDecoderModel(encoder=encoder, decoder=decoder)
    self.model.to(device)
    self.device = device
    if model is not None:
        self.model.load_state_dict(torch.load(model))
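# --- Added smoke-test sketch for the tiny encoder-decoder configuration above (encoder
# vocab 37, decoder vocab 33). The wrapper class name is not shown in the snippet, so the
# same EncoderDecoderModel is recreated inline purely for illustration. ---
import torch
from transformers import BertConfig, BertModel, BertLMHeadModel, EncoderDecoderModel

enc = BertModel(BertConfig(vocab_size=37, max_position_embeddings=114,
                           num_attention_heads=4, num_hidden_layers=4,
                           hidden_size=128, type_vocab_size=1))
dec_cfg = BertConfig(vocab_size=33, max_position_embeddings=114,
                     num_attention_heads=4, num_hidden_layers=4,
                     hidden_size=128, type_vocab_size=1,
                     is_decoder=True, add_cross_attention=True)
seq2seq = EncoderDecoderModel(encoder=enc, decoder=BertLMHeadModel(dec_cfg))

src = torch.randint(0, 37, (2, 10))              # [batch, src_len]
tgt = torch.randint(0, 33, (2, 8))               # [batch, tgt_len]
out = seq2seq(input_ids=src,
              decoder_input_ids=tgt[:, :-1],
              labels=tgt[:, 1:].contiguous())    # teacher-forced next-token loss
print(out.loss.item(), out.logits.shape)         # logits: [2, 7, 33]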
def chat(folder_bert, voc, testing=False):
    tf.random.set_seed(1)
    tokenizer = BertTokenizer(vocab_file=folder_bert + voc)
    if testing:
        tokens = tokenizer.tokenize("jeg tror det skal regne")
        print(tokens)
        ids = tokenizer.convert_tokens_to_ids(tokens)
        print(ids)
        print("Vocab size:", len(tokenizer.vocab))

    config = BertConfig.from_json_file(folder_bert + "/config.json")
    model = BertLMHeadModel.from_pretrained(folder_bert, config=config)

    while True:
        text = input(">>User: ")
        # NOTE: the encode/generate calls below are a reconstruction; the exact
        # sampling settings are not shown in the snippet.
        input_ids = tokenizer.encode(text, return_tensors="pt")
        sample_output = model.generate(input_ids,
                                       do_sample=True,
                                       max_length=50,
                                       top_k=50)
        print("Bot: {}".format(tokenizer.decode(sample_output[0])))
        print("Bot: {}".format(
            tokenizer.decode(sample_output[:, input_ids.shape[-1]:][0],
                             skip_special_tokens=True)))
def create_and_check_decoder_model_past_large_inputs(
    self,
    config,
    input_ids,
    token_type_ids,
    input_mask,
    sequence_labels,
    token_labels,
    choice_labels,
    encoder_hidden_states,
    encoder_attention_mask,
):
    config.is_decoder = True
    config.add_cross_attention = True
    model = BertLMHeadModel(config=config).to(torch_device).eval()

    # first forward pass
    outputs = model(
        input_ids,
        attention_mask=input_mask,
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=encoder_attention_mask,
        use_cache=True,
    )
    past_key_values = outputs.past_key_values

    # create hypothetical multiple next tokens and extend to next_input_ids
    next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
    next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)

    # append to next input_ids and attention mask
    next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
    next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)

    output_from_no_past = model(
        next_input_ids,
        attention_mask=next_attention_mask,
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=encoder_attention_mask,
        output_hidden_states=True,
    )["hidden_states"][0]
    output_from_past = model(
        next_tokens,
        attention_mask=next_attention_mask,
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=encoder_attention_mask,
        past_key_values=past_key_values,
        output_hidden_states=True,
    )["hidden_states"][0]

    # select random slice
    random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
    output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
    output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()

    self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])

    # test that outputs are equal for slice
    self.parent.assertTrue(
        torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
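# --- Added sketch of the caching pattern this test exercises: during incremental decoding
# only the newest token is fed, and past_key_values supplies the keys/values for the prefix.
# A tiny, randomly initialized decoder config is used purely for illustration. ---
import torch
from transformers import BertConfig, BertLMHeadModel

cfg = BertConfig(vocab_size=100, hidden_size=64, num_hidden_layers=2,
                 num_attention_heads=2, intermediate_size=128, is_decoder=True)
lm = BertLMHeadModel(cfg).eval()

prefix = torch.randint(0, 100, (1, 5))
out = lm(prefix, use_cache=True)
next_tok = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
# only the new token is processed; the cache stands in for the earlier positions
out_step = lm(next_tok, past_key_values=out.past_key_values, use_cache=True)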
def create_and_check_bert_for_causal_lm(
    self,
    config,
    input_ids,
    token_type_ids,
    input_mask,
    sequence_labels,
    token_labels,
    choice_labels,
    encoder_hidden_states,
    encoder_attention_mask,
):
    model = BertLMHeadModel(config=config)
    model.to(torch_device)
    model.eval()
    loss, prediction_scores = model(input_ids,
                                    attention_mask=input_mask,
                                    token_type_ids=token_type_ids,
                                    labels=token_labels)
    result = {
        "loss": loss,
        "prediction_scores": prediction_scores,
    }
    self.parent.assertListEqual(
        list(result["prediction_scores"].size()),
        [self.batch_size, self.seq_length, self.vocab_size])
    self.check_loss_output(result)
def create_and_check_model_for_causal_lm_as_decoder(
    self,
    config,
    input_ids,
    token_type_ids,
    input_mask,
    sequence_labels,
    token_labels,
    choice_labels,
    encoder_hidden_states,
    encoder_attention_mask,
):
    config.add_cross_attention = True
    model = BertLMHeadModel(config=config)
    model.to(torch_device)
    model.eval()
    result = model(
        input_ids,
        attention_mask=input_mask,
        token_type_ids=token_type_ids,
        labels=token_labels,
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=encoder_attention_mask,
    )
    result = model(
        input_ids,
        attention_mask=input_mask,
        token_type_ids=token_type_ids,
        labels=token_labels,
        encoder_hidden_states=encoder_hidden_states,
    )
    self.parent.assertEqual(
        result.logits.shape,
        (self.batch_size, self.seq_length, self.vocab_size))
def test_encoder_decoder_save_load_from_encoder_decoder_from_pt(self):
    config = self.get_encoder_decoder_config_small()

    # create two random BERT models for bert2bert & initialize weights (+cross_attention weights)
    encoder_pt = BertModel(config.encoder).to(torch_device).eval()
    decoder_pt = BertLMHeadModel(config.decoder).to(torch_device).eval()
    encoder_decoder_pt = EncoderDecoderModel(encoder=encoder_pt,
                                             decoder=decoder_pt).to(torch_device).eval()

    input_ids = ids_tensor([13, 5], encoder_pt.config.vocab_size)
    decoder_input_ids = ids_tensor([13, 1], decoder_pt.config.vocab_size)

    pt_input_ids = torch.tensor(input_ids.numpy(), device=torch_device, dtype=torch.long)
    pt_decoder_input_ids = torch.tensor(decoder_input_ids.numpy(),
                                        device=torch_device,
                                        dtype=torch.long)

    logits_pt = encoder_decoder_pt(input_ids=pt_input_ids,
                                   decoder_input_ids=pt_decoder_input_ids).logits

    # PyTorch => TensorFlow
    with tempfile.TemporaryDirectory() as tmp_dirname_1, tempfile.TemporaryDirectory() as tmp_dirname_2:
        encoder_decoder_pt.encoder.save_pretrained(tmp_dirname_1)
        encoder_decoder_pt.decoder.save_pretrained(tmp_dirname_2)
        encoder_decoder_tf = TFEncoderDecoderModel.from_encoder_decoder_pretrained(
            tmp_dirname_1, tmp_dirname_2, encoder_from_pt=True, decoder_from_pt=True
        )

    logits_tf = encoder_decoder_tf(input_ids=input_ids,
                                   decoder_input_ids=decoder_input_ids).logits

    max_diff = np.max(np.abs(logits_pt.detach().cpu().numpy() - logits_tf.numpy()))
    self.assertAlmostEqual(max_diff, 0.0, places=3)

    # Make sure `from_pretrained` following `save_pretrained` work and give the same result
    with tempfile.TemporaryDirectory() as tmp_dirname:
        encoder_decoder_tf.save_pretrained(tmp_dirname)
        encoder_decoder_tf = TFEncoderDecoderModel.from_pretrained(tmp_dirname)

        logits_tf_2 = encoder_decoder_tf(input_ids=input_ids,
                                         decoder_input_ids=decoder_input_ids).logits

        max_diff = np.max(np.abs(logits_tf_2.numpy() - logits_tf.numpy()))
        self.assertAlmostEqual(max_diff, 0.0, places=3)

    # TensorFlow => PyTorch
    with tempfile.TemporaryDirectory() as tmp_dirname:
        encoder_decoder_tf.save_pretrained(tmp_dirname)
        encoder_decoder_pt = EncoderDecoderModel.from_pretrained(tmp_dirname, from_tf=True)

    max_diff = np.max(np.abs(logits_pt.detach().cpu().numpy() - logits_tf.numpy()))
    self.assertAlmostEqual(max_diff, 0.0, places=3)
def __init__(self, latent_dim: int = 512):
    super(TextDecoder, self).__init__()

    # Tokenizer
    self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    # Decoder model
    config = BertConfig.from_pretrained("bert-base-uncased")
    config.is_decoder = True
    config.add_cross_attention = True
    self.decoder_model = BertLMHeadModel.from_pretrained(
        "bert-base-uncased", config=config)

    decoder_input_size = 768
    self.linear = nn.Linear(latent_dim, decoder_input_size)

    # Identifier to signal to the trainer to put the label in the decode call
    self.needs_labels = True
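# --- Added usage sketch (assumed, not from the original project): project the latent code
# to the decoder's hidden size and let it condition generation through cross-attention. ---
import torch

decoder = TextDecoder(latent_dim=512)
latent = torch.randn(1, 512)
memory = decoder.linear(latent).unsqueeze(1)     # [batch, 1, 768] pseudo "encoder" states
tok = decoder.tokenizer("a short caption of the image", return_tensors="pt")
out = decoder.decoder_model(input_ids=tok["input_ids"],
                            attention_mask=tok["attention_mask"],
                            encoder_hidden_states=memory,
                            labels=tok["input_ids"])  # BertLMHeadModel shifts labels internally
print(out.loss)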
def get_encoder_decoder_models(self):
    encoder_model = BertModel.from_pretrained("bert-base-uncased")
    decoder_model = BertLMHeadModel.from_pretrained("bert-base-uncased",
                                                    config=self.get_decoder_config())
    return {"encoder": encoder_model, "decoder": decoder_model}
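# --- Added sketch (not from the original tests): an encoder/decoder pair like the one
# returned above is typically wrapped into a single seq2seq model; the one-call warm start
# below sets is_decoder and add_cross_attention on the decoder automatically. ---
from transformers import EncoderDecoderModel

bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "bert-base-uncased", "bert-base-uncased")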
def get_encoder_decoder_model(self, config, decoder_config):
    encoder_model = Wav2Vec2Model(config).eval()
    decoder_model = BertLMHeadModel(decoder_config).eval()
    return encoder_model, decoder_model
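# --- Added sketch (not from the original tests): in recent transformers releases a Wav2Vec2
# encoder and a BERT LM-head decoder like the pair above are composed with the
# SpeechEncoderDecoderModel wrapper; the decoder config is assumed to already have
# is_decoder=True and add_cross_attention=True. ---
from transformers import SpeechEncoderDecoderModel

speech2text = SpeechEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)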
top_10 = torch.topk(mask_word, 10, dim=1)[1][0]
for token in top_10:
    word = tokenizer.decode([token])
    new_sentence = text.replace(tokenizer.mask_token, word)
    print(new_sentence)

# get the top candidate word only
top_word = torch.argmax(mask_word, dim=1)
print(tokenizer.decode(top_word))

### Example 2: Language Modeling
print('### Example 2: Language Modeling')
# the task of predicting the best word to follow or continue a sentence given all the words already in the sentence.
model = BertLMHeadModel.from_pretrained(
    'bert-base-uncased',
    return_dict=True,
    # is_decoder=True if we want to use this model as a standalone model for predicting the next best word in the sequence.
    is_decoder=True,
    cache_dir=os.getenv("cache_dir", "../../models"))

text = "A knife is very "
input = tokenizer.encode_plus(text, return_tensors="pt")
output = model(**input).logits[:, -1, :]
softmax = F.softmax(output, -1)
index = torch.argmax(softmax, dim=-1)
x = tokenizer.decode(index)
print(text + " " + x)

### Example 3: Next Sentence Prediction
print('### Example 3: Next Sentence Prediction')
# Next Sentence Prediction is the task of predicting whether one sentence follows another sentence.
model = BertForNextSentencePrediction.from_pretrained(
    'bert-base-uncased', cache_dir=os.getenv("cache_dir", "../../models"))
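# --- Added sketch (assumed continuation, not the author's exact code): querying the NSP head
# built above; index 0 means the candidate sentence follows the prompt, index 1 means it does not. ---
prompt = "A knife is a kitchen tool."
candidate = "It is used for cutting."
encoding = tokenizer(prompt, candidate, return_tensors="pt")
nsp_logits = model(**encoding).logits
print(F.softmax(nsp_logits, dim=-1))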
class DialogBERT(nn.Module):
    '''Hierarchical BERT for dialog v5 with two features:
    - Masked context utterances prediction with direct MSE matching of their vectors
    - Energy-based Utterance order prediction: A novel approach to shuffle the context
      and predict the original order with distributed order prediction'''

    # TODO: 1. Enhance sorting net
    #       2. Better data loader for permutation (avoid returning perm_id and use max(pos_ids) instead)

    def __init__(self, args, base_model_name='bert-base-uncased'):
        super(DialogBERT, self).__init__()

        if args.language == 'chinese':
            base_model_name = 'bert-base-chinese'

        self.tokenizer = BertTokenizer.from_pretrained(base_model_name,
                                                       cache_dir='./cache/')
        if args.model_size == 'tiny':
            self.encoder_config = BertConfig(vocab_size=30522,
                                             hidden_size=256,
                                             num_hidden_layers=6,
                                             num_attention_heads=2,
                                             intermediate_size=1024)
            self.utt_encoder = BertForPreTraining(self.encoder_config)
        elif args.model_size == 'small':
            self.encoder_config = BertConfig(vocab_size=30522,
                                             hidden_size=512,
                                             num_hidden_layers=8,
                                             num_attention_heads=4,
                                             intermediate_size=2048)
            self.utt_encoder = BertForPreTraining(self.encoder_config)
        else:
            self.encoder_config = BertConfig.from_pretrained(
                base_model_name, cache_dir='./cache/')
            self.utt_encoder = BertForPreTraining.from_pretrained(
                base_model_name,
                config=self.encoder_config,
                cache_dir='./cache/')

        self.context_encoder = BertModel(
            self.encoder_config)  # context encoder: encode context to vector

        self.mlm_mode = 'mse'  # 'mdn', 'mse'
        if self.mlm_mode == 'mdn':
            self.context_mlm_trans = MixtureDensityNetwork(
                self.encoder_config.hidden_size,
                self.encoder_config.hidden_size, 3)
        else:
            self.context_mlm_trans = BertPredictionHeadTransform(
                self.encoder_config
            )  # transform context hidden states back to utterance encodings

        self.dropout = nn.Dropout(self.encoder_config.hidden_dropout_prob)
        self.context_order_trans = SelfSorting(self.encoder_config.hidden_size)
        # self.context_order_trans = MLP(self.encoder_config.hidden_size, '200-200-200', 1)

        self.decoder_config = deepcopy(self.encoder_config)
        self.decoder_config.is_decoder = True
        self.decoder_config.add_cross_attention = True
        self.decoder = BertLMHeadModel(self.decoder_config)

    def init_weights(self, m):
        # Initialize Linear Weight for GAN
        if isinstance(m, nn.Linear):
            m.weight.data.uniform_(-0.08, 0.08)
            # nn.init.xavier_normal_(m.weight)
            nn.init.constant_(m.bias, 0.)
    @classmethod
    def from_pretrained(self, model_dir):
        self.encoder_config = BertConfig.from_pretrained(model_dir)
        self.tokenizer = BertTokenizer.from_pretrained(
            path.join(model_dir, 'tokenizer'), do_lower_case=args.do_lower_case)
        self.utt_encoder = BertForPreTraining.from_pretrained(
            path.join(model_dir, 'utt_encoder'))
        self.context_encoder = BertForSequenceClassification.from_pretrained(
            path.join(model_dir, 'context_encoder'))
        self.context_mlm_trans = BertPredictionHeadTransform(self.encoder_config)
        self.context_mlm_trans.load_state_dict(
            torch.load(path.join(model_dir, 'context_mlm_trans.pkl')))
        self.context_order_trans = SelfSorting(self.encoder_config.hidden_size)
        self.context_order_trans.load_state_dict(
            torch.load(path.join(model_dir, 'context_order_trans.pkl')))
        self.decoder_config = BertConfig.from_pretrained(model_dir)
        self.decoder = BertLMHeadModel.from_pretrained(
            path.join(model_dir, 'decoder'))

    def save_pretrained(self, output_dir):
        def save_module(model, save_path):
            torch.save(model.state_dict(), save_path)

        def make_list_dirs(dir_list):
            for dir_ in dir_list:
                os.makedirs(dir_, exist_ok=True)

        make_list_dirs([
            path.join(output_dir, name)
            for name in ['tokenizer', 'utt_encoder', 'context_encoder', 'decoder']
        ])
        model_to_save = self.module if hasattr(self, 'module') else self
        model_to_save.encoder_config.save_pretrained(output_dir)  # Save configuration file
        model_to_save.tokenizer.save_pretrained(path.join(output_dir, 'tokenizer'))
        model_to_save.utt_encoder.save_pretrained(path.join(output_dir, 'utt_encoder'))
        model_to_save.context_encoder.save_pretrained(
            path.join(output_dir, 'context_encoder'))
        save_module(model_to_save.context_mlm_trans,
                    path.join(output_dir, 'context_mlm_trans.pkl'))
        save_module(model_to_save.context_order_trans,
                    path.join(output_dir, 'context_order_trans.pkl'))
        model_to_save.decoder_config.save_pretrained(output_dir)  # Save configuration file
        model_to_save.decoder.save_pretrained(path.join(output_dir, 'decoder'))

    def utt_encoding(self, context, utts_attn_mask):
        batch_size, max_ctx_len, max_utt_len = context.size()
        # context: [batch_size x diag_len x max_utt_len]

        utts = context.view(-1, max_utt_len)  # [(batch_size*diag_len) x max_utt_len]
        utts_attn_mask = utts_attn_mask.view(-1, max_utt_len)
        _, utts_encodings, *_ = self.utt_encoder.bert(utts, utts_attn_mask)
        utts_encodings = utts_encodings.view(batch_size, max_ctx_len, -1)
        return utts_encodings

    def context_encoding(self, context, utts_attn_mask, ctx_attn_mask):
        # with torch.no_grad():
        utt_encodings = self.utt_encoding(context, utts_attn_mask)
        context_hiddens, pooled_output, *_ = self.context_encoder(
            None, ctx_attn_mask, None, None, None, utt_encodings)
        # context_hiddens: [batch_size x ctx_len x dim]; pooled_output: [batch_size x dim]
        return context_hiddens, pooled_output

    def train_dialog_flow(self, context, context_utts_attn_mask, context_attn_mask,
                          context_lm_targets, context_position_perm_id,
                          context_position_ids, response):
        """ only train the dialog flow model """
        self.context_encoder.train()  # set the module in training mode
        self.context_mlm_trans.train()

        context_hiddens, context_encoding = self.context_encoding(
            context, context_utts_attn_mask, context_attn_mask)
        lm_pred_encodings = self.context_mlm_trans(self.dropout(context_hiddens))

        context_lm_targets[context_lm_targets == -100] = 0
        ctx_lm_mask = context_lm_targets.sum(2)
        if (ctx_lm_mask > 0).sum() == 0:
            ctx_lm_mask[0, 0] = 1
        lm_pred_encodings = lm_pred_encodings[ctx_lm_mask > 0]
        context_lm_targets = context_lm_targets[ctx_lm_mask > 0]
        context_lm_targets_attn_mask = context_utts_attn_mask[ctx_lm_mask > 0]

        with torch.no_grad():
            _, lm_tgt_encodings, *_ = self.utt_encoder.bert(
                context_lm_targets, context_lm_targets_attn_mask)

        loss_ctx_mlm = MSELoss()(lm_pred_encodings,
                                 lm_tgt_encodings)  # [num_selected_utts x dim]

        # context order prediction
        if isinstance(self.context_order_trans, SelfSorting):
            sorting_scores = self.context_order_trans(context_hiddens, context_attn_mask)
        else:
            sorting_scores = self.context_order_trans(context_hiddens)
        sorting_pad_mask = context_attn_mask == 0
        sorting_pad_mask[
            context_position_perm_id < 1] = True  # exclude single-turn and unshuffled dialogs
        loss_ctx_uop = listNet(sorting_scores, context_position_ids, sorting_pad_mask)
        # loss_ctx_uop = listMLE(sorting_scores, context_position_ids, sorting_pad_mask)

        loss = loss_ctx_mlm + loss_ctx_uop

        return {
            'loss': loss,
            'loss_ctx_mlm': loss_ctx_mlm,
            'loss_ctx_uop': loss_ctx_uop
        }

    def train_decoder(self, context, context_utts_attn_mask, context_attn_mask,
                      context_lm_targets, context_position_perm_id,
                      context_position_ids, response):
        """ only train the decoder """
        self.decoder.train()

        with torch.no_grad():
            context_hiddens, context_encoding = self.context_encoding(
                context, context_utts_attn_mask, context_attn_mask)

        ## train decoder
        dec_input, dec_target = response[:, :-1].contiguous(), response[:, 1:].clone()
        dec_output, *_ = self.decoder(
            dec_input,
            dec_input.ne(self.tokenizer.pad_token_id).long(),
            None, None, None, None,
            encoder_hidden_states=context_hiddens,
            encoder_attention_mask=context_attn_mask,
        )
        batch_size, seq_len, vocab_size = dec_output.size()
        dec_target[response[:, 1:] == self.tokenizer.pad_token_id] = -100
        dec_target[context_position_perm_id > 1] = -100  # ignore responses whose contexts are shuffled
        loss_decoder = CrossEntropyLoss()(dec_output.view(-1, vocab_size),
                                          dec_target.view(-1))

        results = {'loss': loss_decoder, 'loss_decoder': loss_decoder}
        return results

    def forward(self, context, context_utts_attn_mask, context_attn_mask,
                context_mlm_targets, context_position_perm_id,
                context_position_ids, response):
        self.train()
        batch_size, max_ctx_len, max_utt_len = context.size()
        # context: [batch_size x diag_len x max_utt_len]

        context_hiddens, context_encoding = self.context_encoding(
            context, context_utts_attn_mask, context_attn_mask)

        ## train dialog flow modeling
        context_mlm_targets[context_mlm_targets == -100] = 0
        ctx_mlm_mask = context_mlm_targets.sum(2)  # [batch_size x num_utts]
        if (ctx_mlm_mask > 0).sum() == 0:
            ctx_mlm_mask[0, 0] = 1
        ctx_mlm_mask = ctx_mlm_mask > 0

        with torch.no_grad():
            _, mlm_tgt_encodings, *_ = self.utt_encoder.bert(
                context_mlm_targets[ctx_mlm_mask],
                context_utts_attn_mask[ctx_mlm_mask])

        if self.mlm_mode == 'mdn':  # mixture density network
            mlm_pred_pi, mlm_pred_normal = self.context_mlm_trans(
                self.dropout(context_hiddens[ctx_mlm_mask]))
            loss_ctx_mlm = self.context_mlm_trans.loss(mlm_pred_pi,
                                                       mlm_pred_normal,
                                                       mlm_tgt_encodings)
        else:  # simply mean square loss
            mlm_pred_encodings = self.context_mlm_trans(
                self.dropout(context_hiddens[ctx_mlm_mask]))
            loss_ctx_mlm = MSELoss()(
                mlm_pred_encodings,
                mlm_tgt_encodings)  # [num_selected_utts x dim]

        # context order prediction
        if isinstance(self.context_order_trans, SelfSorting):
            sorting_scores = self.context_order_trans(context_hiddens, context_attn_mask)
        else:
            sorting_scores = self.context_order_trans(context_hiddens)
        sorting_pad_mask = context_attn_mask == 0
        sorting_pad_mask[
            context_position_perm_id < 1] = True  # exclude single-turn and unshuffled dialogs
        loss_ctx_uop = listNet(sorting_scores, context_position_ids, sorting_pad_mask)
        # loss_ctx_uop = listMLE(sorting_scores, context_position_ids, sorting_pad_mask)

        ## train decoder
        dec_input, dec_target = response[:, :-1].contiguous(), response[:, 1:].clone()
        dec_output, *_ = self.decoder(
            dec_input,
            dec_input.ne(self.tokenizer.pad_token_id).long(),
            None, None, None, None,
            encoder_hidden_states=context_hiddens,
            encoder_attention_mask=context_attn_mask,
        )
        batch_size, seq_len, vocab_size = dec_output.size()
        dec_target[response[:, 1:] == self.tokenizer.pad_token_id] = -100
        dec_target[context_position_perm_id > 1] = -100  # ignore responses whose context was shuffled
        loss_decoder = CrossEntropyLoss()(dec_output.view(-1, vocab_size),
                                          dec_target.view(-1))

        loss = loss_ctx_mlm + loss_ctx_uop + loss_decoder

        results = {
            'loss': loss,
            'loss_ctx_mlm': loss_ctx_mlm,
            'loss_ctx_uop': loss_ctx_uop,
            'loss_decoder': loss_decoder
        }
        return results

    def validate(self, context, context_utts_attn_mask, context_attn_mask,
                 context_lm_targets, context_position_perm_id,
                 context_position_ids, response):
        results = self.train_decoder(context, context_utts_attn_mask,
                                     context_attn_mask, context_lm_targets,
                                     context_position_perm_id,
                                     context_position_ids, response)
        return results['loss'].item()

    def generate(self, input_batch, max_len=30, num_samples=1, mode='sample'):
        self.eval()
        device = next(self.parameters()).device
        context, context_utts_attn_mask, context_attn_mask = [
            t.to(device) for t in input_batch[:3]
        ]
        ground_truth = input_batch[6].numpy()

        context_hiddens, context_encoding = self.context_encoding(
            context, context_utts_attn_mask, context_attn_mask)

        generated = torch.zeros(
            (num_samples, 1), dtype=torch.long,
            device=device).fill_(self.tokenizer.cls_token_id)  # [batch_sz x 1] (1=seq_len)
        sample_lens = torch.ones((num_samples, 1), dtype=torch.long, device=device)
        len_inc = torch.ones((num_samples, 1), dtype=torch.long, device=device)

        for _ in range(max_len):
            outputs, *_ = self.decoder(
                generated,
                generated.ne(self.tokenizer.pad_token_id).long(),
                None, None, None, None,
                encoder_hidden_states=context_hiddens,
                encoder_attention_mask=context_attn_mask,
            )  # [batch_size x seq_len x vocab_size]
            next_token_logits = outputs[:, -1, :] / self.decoder_config.temperature

            # repetition penalty from CTRL (https://arxiv.org/abs/1909.05858)
            for i in range(num_samples):
                for _ in set(generated[i].tolist()):
                    next_token_logits[i, _] /= self.decoder_config.repetition_penalty

            filtered_logits = top_k_top_p_filtering(
                next_token_logits,
                top_k=self.decoder_config.top_k,
                top_p=self.decoder_config.top_p)
            if mode == 'greedy':  # greedy sampling:
                next_token = torch.argmax(filtered_logits, dim=-1).unsqueeze(-1)
            else:
                next_token = torch.multinomial(torch.softmax(filtered_logits, dim=-1),
                                               num_samples=num_samples)
            next_token[len_inc == 0] = self.tokenizer.pad_token_id
            generated = torch.cat((generated, next_token), dim=1)
            len_inc = len_inc * (next_token != self.tokenizer.sep_token_id).long()
            # stop increasing length (set 0 bit) when EOS is encountered
            if len_inc.sum() < 1:
                break
            sample_lens = sample_lens + len_inc
        # to numpy
        sample_words = generated.data.cpu().numpy()
        sample_lens = sample_lens.data.cpu().numpy()
        context = context.data.cpu().numpy()
        return sample_words, sample_lens, context, ground_truth  # nparray: [repeat x seq_len]
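# --- Added sketch: DialogBERT expects an argparse-style namespace; a minimal (assumed)
# setup for the randomly initialized tiny variant looks like this. ---
from argparse import Namespace

dialog_model = DialogBERT(Namespace(language='english', model_size='tiny'))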
encoder = BertModel(config=encoder_config)

vocabsize = decparams["vocab_size"]
max_length = decparams["max_length"]
decoder_config = BertConfig(
    vocab_size=vocabsize,
    max_position_embeddings=max_length + 64,  # this should be some large value
    num_attention_heads=decparams["num_attn_heads"],
    num_hidden_layers=decparams["num_hidden_layers"],
    hidden_size=decparams["hidden_size"],
    type_vocab_size=1,
    is_decoder=True,
    add_cross_attention=True)  # Very Important
decoder = BertLMHeadModel(config=decoder_config)

# Define encoder decoder model
model = EncoderDecoderModel(encoder=encoder, decoder=decoder)
model.to(device)


def count_parameters(mdl):
    return sum(p.numel() for p in mdl.parameters() if p.requires_grad)


print(f'The encoder has {count_parameters(encoder):,} trainable parameters')
print(f'The decoder has {count_parameters(decoder):,} trainable parameters')
print(f'The model has {count_parameters(model):,} trainable parameters')

optimizer = optim.Adam(model.parameters(), lr=modelparams['lr'])
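# --- Added sketch of one training step with the model and optimizer built above;
# src_ids, src_mask and tgt_ids are assumed to come from the project's data loader. ---
model.train()
out = model(input_ids=src_ids.to(device),
            attention_mask=src_mask.to(device),
            decoder_input_ids=tgt_ids[:, :-1].to(device),
            labels=tgt_ids[:, 1:].contiguous().to(device))  # next-token cross-entropy
out.loss.backward()
optimizer.step()
optimizer.zero_grad()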
def get_encoder_decoder_model(self, config, decoder_config):
    encoder_model = BertModel(config)
    decoder_model = BertLMHeadModel(decoder_config)
    return encoder_model, decoder_model
def build_model(config):
    src_tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')
    tgt_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    tgt_tokenizer.bos_token = '<s>'
    tgt_tokenizer.eos_token = '</s>'

    # hidden_size and intermediate_size are both wrt all the attention heads.
    # Should be divisible by num_attention_heads
    encoder_config = BertConfig(
        vocab_size=src_tokenizer.vocab_size,
        hidden_size=config.hidden_size,
        num_hidden_layers=config.num_hidden_layers,
        num_attention_heads=config.num_attention_heads,
        intermediate_size=config.intermediate_size,
        hidden_act=config.hidden_act,
        hidden_dropout_prob=config.dropout_prob,
        attention_probs_dropout_prob=config.dropout_prob,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12)

    decoder_config = BertConfig(
        vocab_size=tgt_tokenizer.vocab_size,
        hidden_size=config.hidden_size,
        num_hidden_layers=config.num_hidden_layers,
        num_attention_heads=config.num_attention_heads,
        intermediate_size=config.intermediate_size,
        hidden_act=config.hidden_act,
        hidden_dropout_prob=config.dropout_prob,
        attention_probs_dropout_prob=config.dropout_prob,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        is_decoder=True)

    encoder = BertModel(encoder_config)
    encoder_embeddings = torch.nn.Embedding(
        src_tokenizer.vocab_size,
        config.hidden_size,
        padding_idx=src_tokenizer.pad_token_id)
    decoder_embeddings = torch.nn.Embedding(
        tgt_tokenizer.vocab_size,
        config.hidden_size,
        padding_idx=tgt_tokenizer.pad_token_id)
    decoder = BertLMHeadModel(decoder_config)
    encoder.set_input_embeddings(encoder_embeddings.cuda())
    decoder.set_input_embeddings(decoder_embeddings.cuda())

    # # Create encoder and decoder embedding layers.
    # encoder_embeddings = torch.nn.Embedding(src_tokenizer.vocab_size, config.hidden_size, padding_idx=src_tokenizer.pad_token_id)
    # decoder_embeddings = torch.nn.Embedding(tgt_tokenizer.vocab_size, config.hidden_size, padding_idx=tgt_tokenizer.pad_token_id)
    # encoder = BertModel(encoder_config)
    # encoder.set_input_embeddings(encoder_embeddings.cuda())
    # decoder = BertForMaskedLM(decoder_config)
    # decoder.set_input_embeddings(decoder_embeddings.cuda())

    model = TranslationModel(config, src_tokenizer, tgt_tokenizer, encoder, decoder)
    model.cuda()

    tokenizers = ED({'src': src_tokenizer, 'tgt': tgt_tokenizer})
    return model, tokenizers
def get_encoder_decoder_model(self, config, decoder_config):
    encoder_model = Speech2TextEncoder(config).eval()
    decoder_model = BertLMHeadModel(decoder_config).eval()
    return encoder_model, decoder_model