Code example #1
class BERT:
    def __init__(self):
        # tokenizer
        self.tokenizer = BertTokenizer.from_pretrained(BERT_TYPE)

        # special tokens
        self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        self.tokenizer.add_special_tokens(
            {'additional_special_tokens': dataset_tokens})

        # chess tokens
        self.tokenizer.add_tokens(get_chess_tokens())

        # model
        self.configuration = BertConfig.from_pretrained(BERT_TYPE)
        self.configuration.is_decoder = True

        self.model = BertLMHeadModel.from_pretrained(
            BERT_TYPE, config=self.configuration).cuda()

        self.model.resize_token_embeddings(len(self.tokenizer))

    def load_model(self, model_path):
        self.model = BertLMHeadModel(self.configuration)
        self.model.load_state_dict(torch.load(model_path))
Code example #2
File: dialogBERT.py  Project: guxd/DialogBERT
    def __init__(self, args, base_model_name='bert-base-uncased'):
        super(DialogBERT, self).__init__()

        if args.language == 'chinese': base_model_name = 'bert-base-chinese'

        self.tokenizer = BertTokenizer.from_pretrained(base_model_name,
                                                       cache_dir='./cache/')
        if args.model_size == 'tiny':
            self.encoder_config = BertConfig(vocab_size=30522,
                                             hidden_size=256,
                                             num_hidden_layers=6,
                                             num_attention_heads=2,
                                             intermediate_size=1024)
            self.utt_encoder = BertForPreTraining(self.encoder_config)
        elif args.model_size == 'small':
            self.encoder_config = BertConfig(vocab_size=30522,
                                             hidden_size=512,
                                             num_hidden_layers=8,
                                             num_attention_heads=4,
                                             intermediate_size=2048)
            self.utt_encoder = BertForPreTraining(self.encoder_config)
        else:
            self.encoder_config = BertConfig.from_pretrained(
                base_model_name, cache_dir='./cache/')
            self.utt_encoder = BertForPreTraining.from_pretrained(
                base_model_name,
                config=self.encoder_config,
                cache_dir='./cache/')

        self.context_encoder = BertModel(
            self.encoder_config)  # context encoder: encode context to vector

        self.mlm_mode = 'mse'  # 'mdn', 'mse'
        if self.mlm_mode == 'mdn':
            self.context_mlm_trans = MixtureDensityNetwork(
                self.encoder_config.hidden_size,
                self.encoder_config.hidden_size, 3)
        else:
            self.context_mlm_trans = BertPredictionHeadTransform(
                self.encoder_config
            )  # transform context hidden states back to utterance encodings

        self.dropout = nn.Dropout(self.encoder_config.hidden_dropout_prob)
        self.context_order_trans = SelfSorting(self.encoder_config.hidden_size)
        #       self.context_order_trans = MLP(self.encoder_config.hidden_size, '200-200-200', 1)

        self.decoder_config = deepcopy(self.encoder_config)
        self.decoder_config.is_decoder = True
        self.decoder_config.add_cross_attention = True
        self.decoder = BertLMHeadModel(self.decoder_config)
Code example #3
    def __init__(
        self,
        SentenceEncoder,
        device,
        ContextEncoder,
        no_contextencoder_before_languagemodel=False,
    ):
        super().__init__()

        self.sentence_encoder = SentenceEncoder

        # Context Encoder
        if ContextEncoder == "GRUContextEncoder":
            self.context_encoder = GRUContextEncoder(input_size=768,
                                                     hidden_size=768)
        elif ContextEncoder == "PoolContextEncoder":
            self.context_encoder = PoolContextEncoder(input_size=768,
                                                      hidden_size=768)

        self.decoder = BertLMHeadModel.from_pretrained(
            "bert-base-uncased",
            is_decoder=True,
            add_cross_attention=True,
            output_hidden_states=True,
        )

        self.mpp_classifier = nn.Linear(768, 5)

        self.device = device
        self.no_contextencoder_before_languagemodel = (
            no_contextencoder_before_languagemodel)
Code example #4
File: __init__.py  Project: biggoron/phonetizer
    def __init__(self, device='cpu', model=None):
        vocabsize = 37
        max_length = 50
        encoder_config = BertConfig(vocab_size=vocabsize,
                                    max_position_embeddings=max_length + 64,
                                    num_attention_heads=4,
                                    num_hidden_layers=4,
                                    hidden_size=128,
                                    type_vocab_size=1)
        encoder = BertModel(config=encoder_config)

        vocabsize = 33
        max_length = 50
        decoder_config = BertConfig(vocab_size=vocabsize,
                                    max_position_embeddings=max_length + 64,
                                    num_attention_heads=4,
                                    num_hidden_layers=4,
                                    hidden_size=128,
                                    type_vocab_size=1,
                                    add_cross_attention=True,
                                    is_decoder=True)
        decoder = BertLMHeadModel(config=decoder_config)

        # Define encoder decoder model
        self.model = EncoderDecoderModel(encoder=encoder, decoder=decoder)
        self.model.to(device)
        self.device = device
        if model is not None:
            self.model.load_state_dict(torch.load(model))
Code example #5
def chat(folder_bert, voc, testing=False):
    tf.random.set_seed(1)
    tokenizer = BertTokenizer(vocab_file=folder_bert + voc)
    if testing:
        tokens = tokenizer.tokenize("jeg tror det skal regne")
        print(tokens)
        ids = tokenizer.convert_tokens_to_ids(tokens)
        print(ids)
        print("Vocab size:", len(tokenizer.vocab))

    config = BertConfig.from_json_file(folder_bert + "/config.json")
    model = BertLMHeadModel.from_pretrained(folder_bert, config=config)
    while True:
        text = input(">>User: ")
        # The original line was garbled here; the encode/generate step below is a
        # plausible minimal reconstruction, not the original code.
        input_ids = tokenizer.encode(text, return_tensors="pt")
        sample_output = model.generate(input_ids, max_length=50)
        print("Bot: {}".format(
            tokenizer.decode(sample_output[:, input_ids.shape[-1]:][0],
                             skip_special_tokens=True)))
Code example #6
    def create_and_check_decoder_model_past_large_inputs(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        config.is_decoder = True
        config.add_cross_attention = True
        model = BertLMHeadModel(config=config).to(torch_device).eval()

        # first forward pass
        outputs = model(
            input_ids,
            attention_mask=input_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=True,
        )
        past_key_values = outputs.past_key_values

        # create hypothetical multiple next tokens and extend to next_input_ids
        next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
        next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)

        # append to next input_ids and attention mask
        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
        next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)

        output_from_no_past = model(
            next_input_ids,
            attention_mask=next_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_hidden_states=True,
        )["hidden_states"][0]
        output_from_past = model(
            next_tokens,
            attention_mask=next_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            output_hidden_states=True,
        )["hidden_states"][0]

        # select random slice
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()

        self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])

        # test that outputs are equal for slice
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
Code example #7
 def create_and_check_bert_for_causal_lm(
     self,
     config,
     input_ids,
     token_type_ids,
     input_mask,
     sequence_labels,
     token_labels,
     choice_labels,
     encoder_hidden_states,
     encoder_attention_mask,
 ):
     model = BertLMHeadModel(config=config)
     model.to(torch_device)
     model.eval()
     loss, prediction_scores = model(input_ids,
                                     attention_mask=input_mask,
                                     token_type_ids=token_type_ids,
                                     labels=token_labels)
     result = {
         "loss": loss,
         "prediction_scores": prediction_scores,
     }
     self.parent.assertListEqual(
         list(result["prediction_scores"].size()),
         [self.batch_size, self.seq_length, self.vocab_size])
     self.check_loss_output(result)
Code example #8
 def create_and_check_model_for_causal_lm_as_decoder(
     self,
     config,
     input_ids,
     token_type_ids,
     input_mask,
     sequence_labels,
     token_labels,
     choice_labels,
     encoder_hidden_states,
     encoder_attention_mask,
 ):
     config.add_cross_attention = True
     model = BertLMHeadModel(config=config)
     model.to(torch_device)
     model.eval()
     result = model(
         input_ids,
         attention_mask=input_mask,
         token_type_ids=token_type_ids,
         labels=token_labels,
         encoder_hidden_states=encoder_hidden_states,
         encoder_attention_mask=encoder_attention_mask,
     )
     result = model(
         input_ids,
         attention_mask=input_mask,
         token_type_ids=token_type_ids,
         labels=token_labels,
         encoder_hidden_states=encoder_hidden_states,
     )
     self.parent.assertEqual(
         result.logits.shape,
         (self.batch_size, self.seq_length, self.vocab_size))
Code example #9
    def test_encoder_decoder_save_load_from_encoder_decoder_from_pt(self):
        config = self.get_encoder_decoder_config_small()

        # create two random BERT models for bert2bert & initialize weights (+cross_attention weights)
        encoder_pt = BertModel(config.encoder).to(torch_device).eval()
        decoder_pt = BertLMHeadModel(config.decoder).to(torch_device).eval()

        encoder_decoder_pt = EncoderDecoderModel(encoder=encoder_pt, decoder=decoder_pt).to(torch_device).eval()

        input_ids = ids_tensor([13, 5], encoder_pt.config.vocab_size)
        decoder_input_ids = ids_tensor([13, 1], decoder_pt.config.vocab_size)

        pt_input_ids = torch.tensor(input_ids.numpy(), device=torch_device, dtype=torch.long)
        pt_decoder_input_ids = torch.tensor(decoder_input_ids.numpy(), device=torch_device, dtype=torch.long)

        logits_pt = encoder_decoder_pt(input_ids=pt_input_ids, decoder_input_ids=pt_decoder_input_ids).logits

        # PyTorch => TensorFlow
        with tempfile.TemporaryDirectory() as tmp_dirname_1, tempfile.TemporaryDirectory() as tmp_dirname_2:
            encoder_decoder_pt.encoder.save_pretrained(tmp_dirname_1)
            encoder_decoder_pt.decoder.save_pretrained(tmp_dirname_2)
            encoder_decoder_tf = TFEncoderDecoderModel.from_encoder_decoder_pretrained(
                tmp_dirname_1, tmp_dirname_2, encoder_from_pt=True, decoder_from_pt=True
            )

        logits_tf = encoder_decoder_tf(input_ids=input_ids, decoder_input_ids=decoder_input_ids).logits

        max_diff = np.max(np.abs(logits_pt.detach().cpu().numpy() - logits_tf.numpy()))
        self.assertAlmostEqual(max_diff, 0.0, places=3)

        # Make sure `from_pretrained` following `save_pretrained` work and give the same result
        with tempfile.TemporaryDirectory() as tmp_dirname:
            encoder_decoder_tf.save_pretrained(tmp_dirname)
            encoder_decoder_tf = TFEncoderDecoderModel.from_pretrained(tmp_dirname)

            logits_tf_2 = encoder_decoder_tf(input_ids=input_ids, decoder_input_ids=decoder_input_ids).logits

            max_diff = np.max(np.abs(logits_tf_2.numpy() - logits_tf.numpy()))
            self.assertAlmostEqual(max_diff, 0.0, places=3)

        # TensorFlow => PyTorch
        with tempfile.TemporaryDirectory() as tmp_dirname:
            encoder_decoder_tf.save_pretrained(tmp_dirname)
            encoder_decoder_pt = EncoderDecoderModel.from_pretrained(tmp_dirname, from_tf=True)

        max_diff = np.max(np.abs(logits_pt.detach().cpu().numpy() - logits_tf.numpy()))
        self.assertAlmostEqual(max_diff, 0.0, places=3)
Code example #10
    def __init__(self, latent_dim: int = 512):
        super(TextDecoder, self).__init__()

        # Tokenizer
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

        # Decoder model
        config = BertConfig.from_pretrained("bert-base-uncased")
        config.is_decoder = True
        config.add_cross_attention = True

        self.decoder_model = BertLMHeadModel.from_pretrained(
            "bert-base-uncased", config=config)

        decoder_input_size = 768
        self.linear = nn.Linear(latent_dim, decoder_input_size)

        # Identifier to signal to the trainer to put the label in the decode call
        self.needs_labels = True
Code example #11
File: dialogBERT.py  Project: guxd/DialogBERT
 def from_pretrained(self, model_dir):
     self.encoder_config = BertConfig.from_pretrained(model_dir)
     self.tokenizer = BertTokenizer.from_pretrained(
         path.join(model_dir, 'tokenizer'),
         do_lower_case=args.do_lower_case)
     self.utt_encoder = BertForPreTraining.from_pretrained(
         path.join(model_dir, 'utt_encoder'))
     self.context_encoder = BertForSequenceClassification.from_pretrained(
         path.join(model_dir, 'context_encoder'))
     self.context_mlm_trans = BertPredictionHeadTransform(
         self.encoder_config)
     self.context_mlm_trans.load_state_dict(
         torch.load(path.join(model_dir, 'context_mlm_trans.pkl')))
     self.context_order_trans = SelfSorting(self.encoder_config.hidden_size)
     self.context_order_trans.load_state_dict(
         torch.load(path.join(model_dir, 'context_order_trans.pkl')))
     self.decoder_config = BertConfig.from_pretrained(model_dir)
     self.decoder = BertLMHeadModel.from_pretrained(
         path.join(model_dir, 'decoder'))
Code example #12
 def get_encoder_decoder_models(self):
     encoder_model = BertModel.from_pretrained("bert-base-uncased")
     decoder_model = BertLMHeadModel.from_pretrained("bert-base-uncased", config=self.get_decoder_config())
     return {"encoder": encoder_model, "decoder": decoder_model}
Code example #13
 def get_encoder_decoder_model(self, config, decoder_config):
     encoder_model = Wav2Vec2Model(config).eval()
     decoder_model = BertLMHeadModel(decoder_config).eval()
     return encoder_model, decoder_model
Code example #14
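This excerpt begins mid-script: mask_word, tokenizer, and text are used below without being defined. A minimal setup sketch that would produce them is shown first (an assumption added for readability, not part of the original snippet; the model, text, and variable names are illustrative):

import torch
from transformers import BertTokenizer, BertForMaskedLM

# Assumed setup (not from the original source): a masked-LM head whose logits
# at the [MASK] position yield the mask_word distribution ranked below.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
mlm_model = BertForMaskedLM.from_pretrained('bert-base-uncased')

text = "The capital of France is [MASK]."
inputs = tokenizer(text, return_tensors="pt")
mask_index = torch.where(inputs["input_ids"][0] == tokenizer.mask_token_id)[0]

with torch.no_grad():
    logits = mlm_model(**inputs).logits
mask_word = logits[0, mask_index, :].softmax(dim=1)  # shape: [1, vocab_size]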
top_10 = torch.topk(mask_word, 10, dim=1)[1][0]
for token in top_10:
    word = tokenizer.decode([token])
    new_sentence = text.replace(tokenizer.mask_token, word)
    print(new_sentence)

# get the top candidate word only
top_word = torch.argmax(mask_word, dim=1)
print(tokenizer.decode(top_word))

### Example 2: Language Modeling
print('### Example 2: Language Modeling')
# the task of predicting the best word to follow or continue a sentence given all the words already in the sentence.
model = BertLMHeadModel.from_pretrained(
    'bert-base-uncased',
    return_dict=True,
    #  is_decoder = True if we want to use this model as a standalone model for predicting the next best word in the sequence.
    is_decoder=True,
    cache_dir=os.getenv("cache_dir", "../../models"))

text = "A knife is very "
input = tokenizer.encode_plus(text, return_tensors="pt")
output = model(**input).logits[:, -1, :]
softmax = F.softmax(output, -1)
index = torch.argmax(softmax, dim=-1)
x = tokenizer.decode(index)
print(text + " " + x)

### Example 3: Next Sentence Prediction
print('### Example 3: Next Sentence Prediction')
# Next Sentence Prediction is the task of predicting whether one sentence follows another sentence.
model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased',
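The snippet above is cut off inside the BertForNextSentencePrediction.from_pretrained call. A minimal sketch of how next-sentence prediction is typically scored follows (an assumed completion for illustration; the sentences and variable names are placeholders, not the original code):

import torch
from transformers import BertTokenizer, BertForNextSentencePrediction

# Hypothetical completion (assumption): load the NSP head and score a sentence pair.
nsp_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
nsp_model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')

sentence_a = "The weather is nice today."
sentence_b = "Let's go for a walk."
encoding = nsp_tokenizer(sentence_a, sentence_b, return_tensors="pt")

with torch.no_grad():
    logits = nsp_model(**encoding).logits
# Label 0 means "sentence B follows sentence A"; label 1 means it is a random sentence.
print("B follows A:", logits.argmax(dim=1).item() == 0)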
Code example #15
File: dialogBERT.py  Project: guxd/DialogBERT
class DialogBERT(nn.Module):
    '''Hierarchical BERT for dialog v5 with two features:
    - Masked context utterances prediction with direct MSE matching of their vectors
    - Energy-based Utterance order prediction: A novel approach to shuffle the context and predict the original order with distributed order prediction'''

    # TODO: 1. Enhance sorting net
    #       2. Better data loader for permutation (avoid returning perm_id and use max(pos_ids) instead)

    def __init__(self, args, base_model_name='bert-base-uncased'):
        super(DialogBERT, self).__init__()

        if args.language == 'chinese': base_model_name = 'bert-base-chinese'

        self.tokenizer = BertTokenizer.from_pretrained(base_model_name,
                                                       cache_dir='./cache/')
        if args.model_size == 'tiny':
            self.encoder_config = BertConfig(vocab_size=30522,
                                             hidden_size=256,
                                             num_hidden_layers=6,
                                             num_attention_heads=2,
                                             intermediate_size=1024)
            self.utt_encoder = BertForPreTraining(self.encoder_config)
        elif args.model_size == 'small':
            self.encoder_config = BertConfig(vocab_size=30522,
                                             hidden_size=512,
                                             num_hidden_layers=8,
                                             num_attention_heads=4,
                                             intermediate_size=2048)
            self.utt_encoder = BertForPreTraining(self.encoder_config)
        else:
            self.encoder_config = BertConfig.from_pretrained(
                base_model_name, cache_dir='./cache/')
            self.utt_encoder = BertForPreTraining.from_pretrained(
                base_model_name,
                config=self.encoder_config,
                cache_dir='./cache/')

        self.context_encoder = BertModel(
            self.encoder_config)  # context encoder: encode context to vector

        self.mlm_mode = 'mse'  # 'mdn', 'mse'
        if self.mlm_mode == 'mdn':
            self.context_mlm_trans = MixtureDensityNetwork(
                self.encoder_config.hidden_size,
                self.encoder_config.hidden_size, 3)
        else:
            self.context_mlm_trans = BertPredictionHeadTransform(
                self.encoder_config
            )  # transform context hidden states back to utterance encodings

        self.dropout = nn.Dropout(self.encoder_config.hidden_dropout_prob)
        self.context_order_trans = SelfSorting(self.encoder_config.hidden_size)
        #       self.context_order_trans = MLP(self.encoder_config.hidden_size, '200-200-200', 1)

        self.decoder_config = deepcopy(self.encoder_config)
        self.decoder_config.is_decoder = True
        self.decoder_config.add_cross_attention = True
        self.decoder = BertLMHeadModel(self.decoder_config)

    def init_weights(self, m):  # Initialize Linear Weight for GAN
        if isinstance(m, nn.Linear):
            m.weight.data.uniform_(-0.08,
                                   0.08)  #nn.init.xavier_normal_(m.weight)
            nn.init.constant_(m.bias, 0.)

    @classmethod
    def from_pretrained(self, model_dir):
        self.encoder_config = BertConfig.from_pretrained(model_dir)
        self.tokenizer = BertTokenizer.from_pretrained(
            path.join(model_dir, 'tokenizer'),
            do_lower_case=args.do_lower_case)
        self.utt_encoder = BertForPreTraining.from_pretrained(
            path.join(model_dir, 'utt_encoder'))
        self.context_encoder = BertForSequenceClassification.from_pretrained(
            path.join(model_dir, 'context_encoder'))
        self.context_mlm_trans = BertPredictionHeadTransform(
            self.encoder_config)
        self.context_mlm_trans.load_state_dict(
            torch.load(path.join(model_dir, 'context_mlm_trans.pkl')))
        self.context_order_trans = SelfSorting(self.encoder_config.hidden_size)
        self.context_order_trans.load_state_dict(
            torch.load(path.join(model_dir, 'context_order_trans.pkl')))
        self.decoder_config = BertConfig.from_pretrained(model_dir)
        self.decoder = BertLMHeadModel.from_pretrained(
            path.join(model_dir, 'decoder'))

    def save_pretrained(self, output_dir):
        def save_module(model, save_path):
            torch.save(model.state_dict(), save_path)  # save the submodule passed in

        def make_list_dirs(dir_list):
            for dir_ in dir_list:
                os.makedirs(dir_, exist_ok=True)

        make_list_dirs([
            path.join(output_dir, name) for name in
            ['tokenizer', 'utt_encoder', 'context_encoder', 'decoder']
        ])
        model_to_save = self.module if hasattr(self, 'module') else self
        model_to_save.encoder_config.save_pretrained(
            output_dir)  # Save configuration file
        model_to_save.tokenizer.save_pretrained(
            path.join(output_dir, 'tokenizer'))
        model_to_save.utt_encoder.save_pretrained(
            path.join(output_dir, 'utt_encoder'))
        model_to_save.context_encoder.save_pretrained(
            path.join(output_dir, 'context_encoder'))
        save_module(model_to_save.context_mlm_trans,
                    path.join(output_dir, 'context_mlm_trans.pkl'))
        save_module(model_to_save.context_order_trans,
                    path.join(output_dir, 'context_order_trans.pkl'))
        model_to_save.decoder_config.save_pretrained(
            output_dir)  # Save configuration file
        model_to_save.decoder.save_pretrained(path.join(output_dir, 'decoder'))

    def utt_encoding(self, context, utts_attn_mask):
        batch_size, max_ctx_len, max_utt_len = context.size(
        )  #context: [batch_size x diag_len x max_utt_len]

        utts = context.view(
            -1, max_utt_len)  # [(batch_size*diag_len) x max_utt_len]
        utts_attn_mask = utts_attn_mask.view(-1, max_utt_len)
        _, utts_encodings, *_ = self.utt_encoder.bert(utts, utts_attn_mask)
        utts_encodings = utts_encodings.view(batch_size, max_ctx_len, -1)
        return utts_encodings

    def context_encoding(self, context, utts_attn_mask, ctx_attn_mask):
        #with torch.no_grad():
        utt_encodings = self.utt_encoding(context, utts_attn_mask)
        context_hiddens, pooled_output, *_ = self.context_encoder(
            None, ctx_attn_mask, None, None, None, utt_encodings)
        # context_hiddens:[batch_size x ctx_len x dim]; pooled_output=[batch_size x dim]

        return context_hiddens, pooled_output

    def train_dialog_flow(self, context, context_utts_attn_mask,
                          context_attn_mask, context_lm_targets,
                          context_position_perm_id, context_position_ids,
                          response):
        """
        only train the dialog flow model
        """
        self.context_encoder.train()  # set the module in training mode.
        self.context_mlm_trans.train()

        context_hiddens, context_encoding = self.context_encoding(
            context, context_utts_attn_mask, context_attn_mask)
        lm_pred_encodings = self.context_mlm_trans(
            self.dropout(context_hiddens))

        context_lm_targets[context_lm_targets == -100] = 0
        ctx_lm_mask = context_lm_targets.sum(2)
        if (ctx_lm_mask > 0).sum() == 0: ctx_lm_mask[0, 0] = 1
        lm_pred_encodings = lm_pred_encodings[ctx_lm_mask > 0]
        context_lm_targets = context_lm_targets[ctx_lm_mask > 0]
        context_lm_targets_attn_mask = context_utts_attn_mask[ctx_lm_mask > 0]

        with torch.no_grad():
            _, lm_tgt_encodings, *_ = self.utt_encoder.bert(
                context_lm_targets, context_lm_targets_attn_mask)

        loss_ctx_mlm = MSELoss()(lm_pred_encodings,
                                 lm_tgt_encodings)  # [num_selected_utts x dim]

        # context order prediction
        if isinstance(self.context_order_trans, SelfSorting):
            sorting_scores = self.context_order_trans(context_hiddens,
                                                      context_attn_mask)
        else:
            sorting_scores = self.context_order_trans(context_hiddens)
        sorting_pad_mask = context_attn_mask == 0
        sorting_pad_mask[
            context_position_perm_id <
            1] = True  # exclude single-turn and unshuffled dialogs
        loss_ctx_uop = listNet(sorting_scores, context_position_ids,
                               sorting_pad_mask)
        #loss_ctx_uop = listMLE(sorting_scores, context_position_ids, sorting_pad_mask)

        loss = loss_ctx_mlm + loss_ctx_uop

        return {
            'loss': loss,
            'loss_ctx_mlm': loss_ctx_mlm,
            'loss_ctx_uop': loss_ctx_uop
        }

    def train_decoder(self, context, context_utts_attn_mask, context_attn_mask,
                      context_lm_targets, context_position_perm_id,
                      context_position_ids, response):
        """
         only train the decoder
         """
        self.decoder.train()

        with torch.no_grad():
            context_hiddens, context_encoding = self.context_encoding(
                context, context_utts_attn_mask, context_attn_mask)

        ## train decoder
        dec_input, dec_target = response[:, :-1].contiguous(
        ), response[:, 1:].clone()

        dec_output, *_ = self.decoder(
            dec_input,
            dec_input.ne(self.tokenizer.pad_token_id).long(),
            None,
            None,
            None,
            None,
            encoder_hidden_states=context_hiddens,
            encoder_attention_mask=context_attn_mask,
        )

        batch_size, seq_len, vocab_size = dec_output.size()
        dec_target[response[:, 1:] == self.tokenizer.pad_token_id] = -100
        dec_target[context_position_perm_id >
                   1] = -100  # ignore responses whose contexts are shuffled
        loss_decoder = CrossEntropyLoss()(dec_output.view(-1, vocab_size),
                                          dec_target.view(-1))

        results = {'loss': loss_decoder, 'loss_decoder': loss_decoder}

        return results

    def forward(self, context, context_utts_attn_mask, context_attn_mask,
                context_mlm_targets, context_position_perm_id,
                context_position_ids, response):
        self.train()
        batch_size, max_ctx_len, max_utt_len = context.size(
        )  #context: [batch_size x diag_len x max_utt_len]

        context_hiddens, context_encoding = self.context_encoding(
            context, context_utts_attn_mask, context_attn_mask)

        ## train dialog flow modeling
        context_mlm_targets[context_mlm_targets == -100] = 0
        ctx_mlm_mask = context_mlm_targets.sum(2)  #[batch_size x num_utts]
        if (ctx_mlm_mask > 0).sum() == 0: ctx_mlm_mask[0, 0] = 1
        ctx_mlm_mask = ctx_mlm_mask > 0

        with torch.no_grad():
            _, mlm_tgt_encodings, *_ = self.utt_encoder.bert(
                context_mlm_targets[ctx_mlm_mask],
                context_utts_attn_mask[ctx_mlm_mask])

        if self.mlm_mode == 'mdn':  # mixture density network
            mlm_pred_pi, mlm_pred_normal = self.context_mlm_trans(
                self.dropout(context_hiddens[ctx_mlm_mask]))
            loss_ctx_mlm = self.context_mlm_trans.loss(mlm_pred_pi,
                                                       mlm_pred_normal,
                                                       mlm_tgt_encodings)
        else:  # simply mean square loss
            mlm_pred_encodings = self.context_mlm_trans(
                self.dropout(context_hiddens[ctx_mlm_mask]))
            loss_ctx_mlm = MSELoss()(
                mlm_pred_encodings,
                mlm_tgt_encodings)  # [num_selected_utts x dim]

        # context order prediction
        if isinstance(self.context_order_trans, SelfSorting):
            sorting_scores = self.context_order_trans(context_hiddens,
                                                      context_attn_mask)
        else:
            sorting_scores = self.context_order_trans(context_hiddens)
        sorting_pad_mask = context_attn_mask == 0
        sorting_pad_mask[
            context_position_perm_id <
            1] = True  # exclude single-turn and unshuffled dialogs
        loss_ctx_uop = listNet(sorting_scores, context_position_ids,
                               sorting_pad_mask)
        #loss_ctx_uop = listMLE(sorting_scores, context_position_ids, sorting_pad_mask)

        ## train decoder
        dec_input, dec_target = response[:, :-1].contiguous(
        ), response[:, 1:].clone()

        dec_output, *_ = self.decoder(
            dec_input,
            dec_input.ne(self.tokenizer.pad_token_id).long(),
            None,
            None,
            None,
            None,
            encoder_hidden_states=context_hiddens,
            encoder_attention_mask=context_attn_mask,
        )

        batch_size, seq_len, vocab_size = dec_output.size()
        dec_target[response[:, 1:] == self.tokenizer.pad_token_id] = -100
        dec_target[context_position_perm_id >
                   1] = -100  # ignore responses whose context was shuffled
        loss_decoder = CrossEntropyLoss()(dec_output.view(-1, vocab_size),
                                          dec_target.view(-1))

        loss = loss_ctx_mlm + loss_ctx_uop + loss_decoder

        results = {
            'loss': loss,
            'loss_ctx_mlm': loss_ctx_mlm,
            'loss_ctx_uop': loss_ctx_uop,
            'loss_decoder': loss_decoder
        }

        return results

    def validate(self, context, context_utts_attn_mask, context_attn_mask,
                 context_lm_targets, context_position_perm_id,
                 context_position_ids, response):
        results = self.train_decoder(context, context_utts_attn_mask,
                                     context_attn_mask, context_lm_targets,
                                     context_position_perm_id,
                                     context_position_ids, response)
        return results['loss'].item()

    def generate(self, input_batch, max_len=30, num_samples=1, mode='sample'):
        self.eval()
        device = next(self.parameters()).device
        context, context_utts_attn_mask, context_attn_mask = [
            t.to(device) for t in input_batch[:3]
        ]
        ground_truth = input_batch[6].numpy()

        context_hiddens, context_encoding = self.context_encoding(
            context, context_utts_attn_mask, context_attn_mask)

        generated = torch.zeros(
            (num_samples, 1), dtype=torch.long,
            device=device).fill_(self.tokenizer.cls_token_id)
        # [batch_sz x 1] (1=seq_len)

        sample_lens = torch.ones((num_samples, 1),
                                 dtype=torch.long,
                                 device=device)
        len_inc = torch.ones((num_samples, 1), dtype=torch.long, device=device)
        for _ in range(max_len):
            outputs, *_ = self.decoder(
                generated,
                generated.ne(self.tokenizer.pad_token_id).long(),
                None,
                None,
                None,
                None,
                encoder_hidden_states=context_hiddens,
                encoder_attention_mask=context_attn_mask,
            )  # [batch_size x seq_len x vocab_size]
            next_token_logits = outputs[:,
                                        -1, :] / self.decoder_config.temperature

            # repetition penalty from CTRL (https://arxiv.org/abs/1909.05858)
            for i in range(num_samples):
                for _ in set(generated[i].tolist()):
                    next_token_logits[
                        i, _] /= self.decoder_config.repetition_penalty

            filtered_logits = top_k_top_p_filtering(
                next_token_logits,
                top_k=self.decoder_config.top_k,
                top_p=self.decoder_config.top_p)
            if mode == 'greedy':  # greedy sampling:
                next_token = torch.argmax(filtered_logits,
                                          dim=-1).unsqueeze(-1)
            else:
                next_token = torch.multinomial(torch.softmax(filtered_logits,
                                                             dim=-1),
                                               num_samples=num_samples)
            next_token[len_inc == 0] = self.tokenizer.pad_token_id
            generated = torch.cat((generated, next_token), dim=1)
            len_inc = len_inc * (
                next_token != self.tokenizer.sep_token_id).long(
                )  # stop increasing length (set 0 bit) when EOS is encountered
            if len_inc.sum() < 1: break
            sample_lens = sample_lens + len_inc

        # to numpy
        sample_words = generated.data.cpu().numpy()
        sample_lens = sample_lens.data.cpu().numpy()

        context = context.data.cpu().numpy()
        return sample_words, sample_lens, context, ground_truth  # nparray: [repeat x seq_len]
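listNet, SelfSorting, MixtureDensityNetwork and top_k_top_p_filtering are imported from elsewhere in the DialogBERT project and are not shown in this excerpt. For the utterance order-prediction objective only, a ListNet-style ranking loss with the same calling convention (scores, target position ids, padding mask) could be sketched as follows (an illustration under those assumptions, not the project's implementation):

import torch
import torch.nn.functional as F

def listnet_loss(scores, target_positions, pad_mask, eps=1e-10):
    # Sketch of a ListNet objective: cross-entropy between the softmax of the
    # predicted sorting scores and the softmax of the (negated) true positions,
    # so that earlier utterances are treated as more "relevant".
    keep = (~pad_mask).any(dim=-1)  # drop dialogs where every slot is masked out
    if keep.sum() == 0:
        return scores.new_zeros(())
    scores = scores[keep].masked_fill(pad_mask[keep], float('-inf'))
    targets = (-target_positions[keep].float()).masked_fill(pad_mask[keep], float('-inf'))
    pred = F.softmax(scores, dim=-1)   # padded slots get probability 0
    true = F.softmax(targets, dim=-1)
    return -(true * (pred + eps).log()).sum(dim=-1).mean()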
Code example #16
encoder = BertModel(config=encoder_config)

vocabsize = decparams["vocab_size"]
max_length = decparams["max_length"]
decoder_config = BertConfig(
    vocab_size=vocabsize,
    max_position_embeddings=max_length + 64,  # this should be some large value
    num_attention_heads=decparams["num_attn_heads"],
    num_hidden_layers=decparams["num_hidden_layers"],
    hidden_size=decparams["hidden_size"],
    type_vocab_size=1,
    is_decoder=True,
    add_cross_attention=True)  # Very Important

decoder = BertLMHeadModel(config=decoder_config)

# Define encoder decoder model
model = EncoderDecoderModel(encoder=encoder, decoder=decoder)
model.to(device)


def count_parameters(mdl):
    return sum(p.numel() for p in mdl.parameters() if p.requires_grad)


print(f'The encoder has {count_parameters(encoder):,} trainable parameters')
print(f'The decoder has {count_parameters(decoder):,} trainable parameters')
print(f'The model has {count_parameters(model):,} trainable parameters')

optimizer = optim.Adam(model.parameters(), lr=modelparams['lr'])
Code example #17
 def get_encoder_decoder_model(self, config, decoder_config):
     encoder_model = BertModel(config)
     decoder_model = BertLMHeadModel(decoder_config)
     return encoder_model, decoder_model
Code example #18
File: model.py  Project: icapucap/bert
def build_model(config):

    src_tokenizer = BertTokenizer.from_pretrained(
        'bert-base-multilingual-cased')
    tgt_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    tgt_tokenizer.bos_token = '<s>'
    tgt_tokenizer.eos_token = '</s>'

    # hidden_size and intermediate_size are both w.r.t. all the attention heads.
    # hidden_size should be divisible by num_attention_heads.
    encoder_config = BertConfig(
        vocab_size=src_tokenizer.vocab_size,
        hidden_size=config.hidden_size,
        num_hidden_layers=config.num_hidden_layers,
        num_attention_heads=config.num_attention_heads,
        intermediate_size=config.intermediate_size,
        hidden_act=config.hidden_act,
        hidden_dropout_prob=config.dropout_prob,
        attention_probs_dropout_prob=config.dropout_prob,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12)

    decoder_config = BertConfig(
        vocab_size=tgt_tokenizer.vocab_size,
        hidden_size=config.hidden_size,
        num_hidden_layers=config.num_hidden_layers,
        num_attention_heads=config.num_attention_heads,
        intermediate_size=config.intermediate_size,
        hidden_act=config.hidden_act,
        hidden_dropout_prob=config.dropout_prob,
        attention_probs_dropout_prob=config.dropout_prob,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        is_decoder=True)

    encoder = BertModel(encoder_config)

    encoder_embeddings = torch.nn.Embedding(
        src_tokenizer.vocab_size,
        config.hidden_size,
        padding_idx=src_tokenizer.pad_token_id)
    decoder_embeddings = torch.nn.Embedding(
        tgt_tokenizer.vocab_size,
        config.hidden_size,
        padding_idx=tgt_tokenizer.pad_token_id)

    decoder = BertLMHeadModel(decoder_config)
    encoder.set_input_embeddings(encoder_embeddings.cuda())
    decoder.set_input_embeddings(decoder_embeddings.cuda())

    # #Create encoder and decoder embedding layers.
    # encoder_embeddings = torch.nn.Embedding(src_tokenizer.vocab_size, config.hidden_size, padding_idx=src_tokenizer.pad_token_id)
    # decoder_embeddings = torch.nn.Embedding(tgt_tokenizer.vocab_size, config.hidden_size, padding_idx=tgt_tokenizer.pad_token_id)

    # encoder = BertModel(encoder_config)
    # encoder.set_input_embeddings(encoder_embeddings.cuda())

    # decoder = BertForMaskedLM(decoder_config)
    # decoder.set_input_embeddings(decoder_embeddings.cuda())
    model = TranslationModel(config, src_tokenizer, tgt_tokenizer, encoder,
                             decoder)
    model.cuda()

    tokenizers = ED({'src': src_tokenizer, 'tgt': tgt_tokenizer})

    return model, tokenizers
Code example #19
 def get_encoder_decoder_model(self, config, decoder_config):
     encoder_model = Speech2TextEncoder(config).eval()
     decoder_model = BertLMHeadModel(decoder_config).eval()
     return encoder_model, decoder_model
Code example #20
 def load_model(self, model_path):
     self.model = BertLMHeadModel(self.configuration)
     self.model.load_state_dict(torch.load(model_path))