Code example #1
    def check_encoder_decoder_model_from_pretrained_configs(
            self, config, input_ids, attention_mask, encoder_hidden_states,
            decoder_config, decoder_input_ids, decoder_attention_mask,
            **kwargs):
        encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(
            config, decoder_config)
        self.assertTrue(encoder_decoder_config.decoder.is_decoder)

        enc_dec_model = EncoderDecoderModel(encoder_decoder_config)
        enc_dec_model.to(torch_device)
        enc_dec_model.eval()

        self.assertTrue(enc_dec_model.config.is_encoder_decoder)

        outputs_encoder_decoder = enc_dec_model(
            input_ids=input_ids,
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
            return_dict=True,
        )

        self.assertEqual(outputs_encoder_decoder["logits"].shape,
                         (decoder_input_ids.shape +
                          (decoder_config.vocab_size, )))
        self.assertEqual(
            outputs_encoder_decoder["encoder_last_hidden_state"].shape,
            (input_ids.shape + (config.hidden_size, )))
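A minimal sketch (assuming the standard transformers API; the tiny BertConfig values are arbitrary) of the config composition this test exercises:

from transformers import BertConfig, EncoderDecoderConfig, EncoderDecoderModel

encoder_config = BertConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4)
decoder_config = BertConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4)

# from_encoder_decoder_configs marks the decoder half as a decoder (the flag
# asserted above) and also enables its cross-attention
config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)
assert config.decoder.is_decoder and config.decoder.add_cross_attention

model = EncoderDecoderModel(config=config)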
Code example #2
    def check_encoder_decoder_model_labels(self, config, input_ids,
                                           attention_mask,
                                           encoder_hidden_states,
                                           decoder_config, decoder_input_ids,
                                           decoder_attention_mask, labels,
                                           **kwargs):
        encoder_model, decoder_model = self.get_encoder_decoder_model(
            config, decoder_config)
        enc_dec_model = EncoderDecoderModel(encoder=encoder_model,
                                            decoder=decoder_model)
        enc_dec_model.to(torch_device)
        outputs_encoder_decoder = enc_dec_model(
            input_ids=input_ids,
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
            labels=labels,
            return_dict=True,
        )

        loss = outputs_encoder_decoder["loss"]
        # check that backprop works
        loss.backward()

        self.assertEqual(outputs_encoder_decoder["logits"].shape,
                         (decoder_input_ids.shape +
                          (decoder_config.vocab_size, )))
        self.assertEqual(
            outputs_encoder_decoder["encoder_last_hidden_state"].shape,
            (input_ids.shape + (config.hidden_size, )))
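The labels path can be reproduced outside the harness with random tensors; a minimal sketch (sizes are arbitrary):

import torch
from transformers import BertConfig, EncoderDecoderConfig, EncoderDecoderModel

cfg = EncoderDecoderConfig.from_encoder_decoder_configs(
    BertConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4),
    BertConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4))
model = EncoderDecoderModel(config=cfg)

input_ids = torch.randint(0, cfg.encoder.vocab_size, (2, 8))
decoder_input_ids = torch.randint(0, cfg.decoder.vocab_size, (2, 6))

# with labels present the model returns a cross-entropy loss over the logits,
# so backward() exercises the full encoder-decoder graph
out = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids,
            labels=decoder_input_ids, return_dict=True)
out.loss.backward()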
Code example #3
    def create_and_check_bert_encoder_decoder_model(
            self, config, input_ids, attention_mask, encoder_hidden_states,
            decoder_config, decoder_input_ids, decoder_attention_mask,
            **kwargs):
        encoder_model = BertModel(config)
        decoder_model = BertForMaskedLM(decoder_config)
        enc_dec_model = EncoderDecoderModel(encoder=encoder_model,
                                            decoder=decoder_model)
        enc_dec_model.to(torch_device)
        outputs_encoder_decoder = enc_dec_model(
            input_ids=input_ids,
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
        )

        self.assertEqual(outputs_encoder_decoder[0].shape,
                         (decoder_input_ids.shape +
                          (decoder_config.vocab_size, )))
        self.assertEqual(outputs_encoder_decoder[1].shape,
                         (input_ids.shape + (config.hidden_size, )))
        encoder_outputs = (encoder_hidden_states, )
        outputs_encoder_decoder = enc_dec_model(
            encoder_outputs=encoder_outputs,
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
        )

        self.assertEqual(outputs_encoder_decoder[0].shape,
                         (decoder_input_ids.shape +
                          (decoder_config.vocab_size, )))
        self.assertEqual(outputs_encoder_decoder[1].shape,
                         (input_ids.shape + (config.hidden_size, )))
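Note: this example predates return_dict, so the outputs are indexed positionally; in the transformers version it targets, index 0 is the decoder logits and index 1 the encoder's last hidden state, matching the dictionary keys used in the return_dict examples above.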
Code example #4
    def check_save_and_load(self, config, input_ids, attention_mask,
                            encoder_hidden_states, decoder_config,
                            decoder_input_ids, decoder_attention_mask,
                            **kwargs):
        encoder_model, decoder_model = self.get_encoder_decoder_model(
            config, decoder_config)
        enc_dec_model = EncoderDecoderModel(encoder=encoder_model,
                                            decoder=decoder_model)
        enc_dec_model.to(torch_device)
        enc_dec_model.eval()
        with torch.no_grad():
            outputs = enc_dec_model(
                input_ids=input_ids,
                decoder_input_ids=decoder_input_ids,
                attention_mask=attention_mask,
                decoder_attention_mask=decoder_attention_mask,
            )
            out_2 = outputs[0].cpu().numpy()
            out_2[np.isnan(out_2)] = 0

            with tempfile.TemporaryDirectory() as tmpdirname:
                enc_dec_model.save_pretrained(tmpdirname)
                enc_dec_model = EncoderDecoderModel.from_pretrained(tmpdirname)
                enc_dec_model.to(torch_device)

                after_outputs = enc_dec_model(
                    input_ids=input_ids,
                    decoder_input_ids=decoder_input_ids,
                    attention_mask=attention_mask,
                    decoder_attention_mask=decoder_attention_mask,
                )
                out_1 = after_outputs[0].cpu().numpy()
                out_1[np.isnan(out_1)] = 0
                max_diff = np.amax(np.abs(out_1 - out_2))
                self.assertLessEqual(max_diff, 1e-5)
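Outside the test, the same round trip is just save_pretrained followed by from_pretrained; a sketch ("prajjwal1/bert-tiny" is an arbitrary small public checkpoint, fetched on first use):

import tempfile
from transformers import EncoderDecoderModel

model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "prajjwal1/bert-tiny", "prajjwal1/bert-tiny")
with tempfile.TemporaryDirectory() as tmpdirname:
    model.save_pretrained(tmpdirname)
    reloaded = EncoderDecoderModel.from_pretrained(tmpdirname)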
Code example #5
    def create_and_check_bert_encoder_decoder_model_lm_labels(
            self, config, input_ids, attention_mask, encoder_hidden_states,
            decoder_config, decoder_input_ids, decoder_attention_mask,
            lm_labels, **kwargs):
        encoder_model = BertModel(config)
        decoder_model = BertForMaskedLM(decoder_config)
        enc_dec_model = EncoderDecoderModel(encoder=encoder_model,
                                            decoder=decoder_model)
        enc_dec_model.to(torch_device)
        outputs_encoder_decoder = enc_dec_model(
            input_ids=input_ids,
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
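            # lm_labels is the legacy argument name from older transformers
            # releases; current versions expect labels instead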
            lm_labels=lm_labels,
        )

        lm_loss = outputs_encoder_decoder[0]
        self.check_loss_output(lm_loss)
        # check that backprop works
        lm_loss.backward()

        self.assertEqual(outputs_encoder_decoder[1].shape,
                         (decoder_input_ids.shape +
                          (decoder_config.vocab_size, )))
        self.assertEqual(outputs_encoder_decoder[2].shape,
                         (input_ids.shape + (config.hidden_size, )))
Code example #6
    def check_encoder_decoder_model_output_attentions(
        self,
        config,
        input_ids,
        attention_mask,
        encoder_hidden_states,
        decoder_config,
        decoder_input_ids,
        decoder_attention_mask,
        labels,
        **kwargs
    ):
        # make the decoder inputs a different shape from the encoder inputs to harden the test
        decoder_input_ids = decoder_input_ids[:, :-1]
        decoder_attention_mask = decoder_attention_mask[:, :-1]
        encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
        enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
        enc_dec_model.to(torch_device)
        outputs_encoder_decoder = enc_dec_model(
            input_ids=input_ids,
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
            output_attentions=True,
        )

        encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
        self.assertEqual(len(encoder_attentions), config.num_hidden_layers)

        self.assertEqual(
            encoder_attentions[0].shape[-3:], (config.num_attention_heads, input_ids.shape[-1], input_ids.shape[-1])
        )

        decoder_attentions = outputs_encoder_decoder["decoder_attentions"]
        num_decoder_layers = (
            decoder_config.num_decoder_layers
            if hasattr(decoder_config, "num_decoder_layers")
            else decoder_config.num_hidden_layers
        )
        self.assertEqual(len(decoder_attentions), num_decoder_layers)

        self.assertEqual(
            decoder_attentions[0].shape[-3:],
            (decoder_config.num_attention_heads, decoder_input_ids.shape[-1], decoder_input_ids.shape[-1]),
        )

        cross_attentions = outputs_encoder_decoder["cross_attentions"]
        self.assertEqual(len(cross_attentions), num_decoder_layers)

        cross_attention_input_seq_len = decoder_input_ids.shape[-1] * (
            1 + (decoder_config.ngram if hasattr(decoder_config, "ngram") else 0)
        )
        self.assertEqual(
            cross_attentions[0].shape[-3:],
            (decoder_config.num_attention_heads, cross_attention_input_seq_len, input_ids.shape[-1]),
        )
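Attention tensors have the shape (batch, heads, query_len, key_len). In cross-attention the queries come from the decoder sequence and the keys from the encoder sequence, which is why the two trailing dimensions differ; the ngram factor only applies to ProphetNet-style decoders that predict several future tokens per position.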
Code example #7
    def check_encoder_decoder_model_generate(self, input_ids, config, decoder_config, **kwargs):
        encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
        enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
        enc_dec_model.to(torch_device)

        # Bert does not have a bos token id, so use pad_token_id instead
        generated_output = enc_dec_model.generate(
            input_ids, decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id
        )
        self.assertEqual(generated_output.shape, (input_ids.shape[0],) + (decoder_config.max_length,))
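The same call works against pretrained checkpoints; a sketch ("prajjwal1/bert-tiny" is an arbitrary small model):

import torch
from transformers import EncoderDecoderModel

model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "prajjwal1/bert-tiny", "prajjwal1/bert-tiny")
input_ids = torch.randint(0, model.config.encoder.vocab_size, (1, 8))
# BERT has no bos token, so decoding starts from an explicit token id
generated = model.generate(
    input_ids, decoder_start_token_id=model.config.decoder.pad_token_id)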
Code example #8
    def check_encoder_decoder_model_output_attentions(
        self,
        config,
        input_ids,
        attention_mask,
        encoder_hidden_states,
        decoder_config,
        decoder_input_ids,
        decoder_attention_mask,
        labels,
        **kwargs
    ):
        encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
        enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
        enc_dec_model.to(torch_device)
        outputs_encoder_decoder = enc_dec_model(
            input_ids=input_ids,
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
            output_attentions=True,
            return_dict=True,
        )

        encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
        self.assertEqual(len(encoder_attentions), config.num_hidden_layers)

        self.assertListEqual(
            list(encoder_attentions[0].shape[-3:]),
            [config.num_attention_heads, input_ids.shape[-1], input_ids.shape[-1]],
        )

        decoder_attentions = outputs_encoder_decoder["decoder_attentions"]
        num_decoder_layers = (
            decoder_config.num_decoder_layers
            if hasattr(decoder_config, "num_decoder_layers")
            else decoder_config.num_hidden_layers
        )
        self.assertEqual(len(decoder_attentions), num_decoder_layers)

        self.assertListEqual(
            list(decoder_attentions[0].shape[-3:]),
            [decoder_config.num_attention_heads, decoder_input_ids.shape[-1], decoder_input_ids.shape[-1]],
        )

        cross_attentions = outputs_encoder_decoder["cross_attentions"]
        self.assertEqual(len(cross_attentions), num_decoder_layers)

        cross_attention_input_seq_len = input_ids.shape[-1] * (
            1 + (decoder_config.ngram if hasattr(decoder_config, "ngram") else 0)
        )
        self.assertListEqual(
            list(cross_attentions[0].shape[-3:]),
            [decoder_config.num_attention_heads, cross_attention_input_seq_len, decoder_input_ids.shape[-1]],
        )
Code example #9
    def check_encoder_decoder_model(
        self,
        config,
        input_ids,
        attention_mask,
        encoder_hidden_states,
        decoder_config,
        decoder_input_ids,
        decoder_attention_mask,
        **kwargs,
    ):
        encoder_model, decoder_model = self.get_encoder_decoder_model(
            config, decoder_config)
        enc_dec_model = EncoderDecoderModel(encoder=encoder_model,
                                            decoder=decoder_model)
        self.assertTrue(enc_dec_model.config.decoder.is_decoder)
        self.assertTrue(enc_dec_model.config.decoder.add_cross_attention)
        self.assertTrue(enc_dec_model.config.is_encoder_decoder)
        enc_dec_model.to(torch_device)
        outputs_encoder_decoder = enc_dec_model(
            input_ids=input_ids,
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
            return_dict=True,
        )
        self.assertEqual(outputs_encoder_decoder["logits"].shape,
                         (decoder_input_ids.shape +
                          (decoder_config.vocab_size, )))
        self.assertEqual(
            outputs_encoder_decoder["encoder_last_hidden_state"].shape,
            (input_ids.shape + (config.hidden_size, )))

        encoder_outputs = BaseModelOutput(
            last_hidden_state=encoder_hidden_states)
        outputs_encoder_decoder = enc_dec_model(
            encoder_outputs=encoder_outputs,
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
            return_dict=True,
        )

        self.assertEqual(outputs_encoder_decoder["logits"].shape,
                         (decoder_input_ids.shape +
                          (decoder_config.vocab_size, )))
        self.assertEqual(
            outputs_encoder_decoder["encoder_last_hidden_state"].shape,
            (input_ids.shape + (config.hidden_size, )))
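Passing a prebuilt BaseModelOutput as encoder_outputs skips the encoder forward pass entirely, which is how generation loops reuse one source encoding across many decoding steps. Reusing the names from the test above, a sketch:

encoder_outputs = enc_dec_model.get_encoder()(
    input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
outputs = enc_dec_model(encoder_outputs=encoder_outputs,
                        decoder_input_ids=decoder_input_ids,
                        return_dict=True)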
Code example #10
    def check_encoder_decoder_model(self, config, input_ids, attention_mask,
                                    encoder_hidden_states, decoder_config,
                                    decoder_input_ids, decoder_attention_mask,
                                    **kwargs):
        encoder_model, decoder_model = self.get_encoder_decoder_model(
            config, decoder_config)
        enc_dec_model = EncoderDecoderModel(encoder=encoder_model,
                                            decoder=decoder_model)
        self.assertTrue(enc_dec_model.config.decoder.is_decoder)
        self.assertTrue(enc_dec_model.config.decoder.add_cross_attention)
        self.assertTrue(enc_dec_model.config.is_encoder_decoder)
        enc_dec_model.to(torch_device)
        outputs_encoder_decoder = enc_dec_model(
            input_ids=input_ids,
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
        )

        self.assertEqual(outputs_encoder_decoder[0].shape,
                         (decoder_input_ids.shape +
                          (decoder_config.vocab_size, )))
        self.assertEqual(outputs_encoder_decoder[1].shape,
                         (input_ids.shape + (config.hidden_size, )))
        encoder_outputs = (encoder_hidden_states, )
        outputs_encoder_decoder = enc_dec_model(
            encoder_outputs=encoder_outputs,
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
        )

        self.assertEqual(outputs_encoder_decoder[0].shape,
                         (decoder_input_ids.shape +
                          (decoder_config.vocab_size, )))
        self.assertEqual(outputs_encoder_decoder[1].shape,
                         (input_ids.shape + (config.hidden_size, )))
Code example #11
def get_model(args):
    if args.model_path:
        model = EncoderDecoderModel.from_pretrained(args.model_path)
        src_tokenizer = BertTokenizer.from_pretrained(
            os.path.join(args.model_path, "src_tokenizer")
        )
        tgt_tokenizer = GPT2Tokenizer.from_pretrained(
            os.path.join(args.model_path, "tgt_tokenizer")
        )
        tgt_tokenizer.build_inputs_with_special_tokens = types.MethodType(
            build_inputs_with_special_tokens, tgt_tokenizer
        )
        if local_rank == 0 or local_rank == -1:
            print("model and tokenizer load from save success")
    else:
        src_tokenizer = BertTokenizer.from_pretrained(args.src_pretrain_dataset_name)
        tgt_tokenizer = GPT2Tokenizer.from_pretrained(args.tgt_pretrain_dataset_name)
        tgt_tokenizer.add_special_tokens(
            {"bos_token": "[BOS]", "eos_token": "[EOS]", "pad_token": "[PAD]"}
        )
        tgt_tokenizer.build_inputs_with_special_tokens = types.MethodType(
            build_inputs_with_special_tokens, tgt_tokenizer
        )
        encoder = BertGenerationEncoder.from_pretrained(args.src_pretrain_dataset_name)
        decoder = GPT2LMHeadModel.from_pretrained(
            args.tgt_pretrain_dataset_name, add_cross_attention=True, is_decoder=True
        )
        decoder.resize_token_embeddings(len(tgt_tokenizer))
        decoder.config.bos_token_id = tgt_tokenizer.bos_token_id
        decoder.config.eos_token_id = tgt_tokenizer.eos_token_id
        decoder.config.vocab_size = len(tgt_tokenizer)
        decoder.config.add_cross_attention = True
        decoder.config.is_decoder = True
        model_config = EncoderDecoderConfig.from_encoder_decoder_configs(
            encoder.config, decoder.config
        )
        model = EncoderDecoderModel(
            encoder=encoder, decoder=decoder, config=model_config
        )
    if local_rank != -1:
        model = model.to(device)
    if args.ngpu > 1:
        print("{}/{} GPU start".format(local_rank, torch.cuda.device_count()))
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank
        )
    optimizer, scheduler = get_optimizer_and_schedule(args, model)
    return model, src_tokenizer, tgt_tokenizer, optimizer, scheduler
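When no tokenizer surgery or custom heads are needed, the manual wiring above collapses into a single call; a sketch with the stock public checkpoints:

from transformers import EncoderDecoderModel

# from_encoder_decoder_pretrained sets is_decoder and add_cross_attention
# on the decoder automatically
model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "bert-base-uncased", "gpt2")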
Code example #12
    def create_and_check_encoder_decoder_shared_weights(
            self, config, input_ids, attention_mask, encoder_hidden_states,
            decoder_config, decoder_input_ids, decoder_attention_mask, labels,
            **kwargs):
        torch.manual_seed(0)
        encoder_model, decoder_model = self.get_encoder_decoder_model(
            config, decoder_config)
        model = EncoderDecoderModel(encoder=encoder_model,
                                    decoder=decoder_model)
        model.to(torch_device)
        model.eval()
        # load_state_dict copies the weights but does not tie them
        decoder_state_dict = model.decoder._modules[
            model.decoder.base_model_prefix].state_dict()
        model.encoder.load_state_dict(decoder_state_dict, strict=False)

        torch.manual_seed(0)
        tied_encoder_model, tied_decoder_model = self.get_encoder_decoder_model(
            config, decoder_config)
        config = EncoderDecoderConfig.from_encoder_decoder_configs(
            tied_encoder_model.config,
            tied_decoder_model.config,
            tie_encoder_decoder=True)
        tied_model = EncoderDecoderModel(encoder=tied_encoder_model,
                                         decoder=tied_decoder_model,
                                         config=config)
        tied_model.to(torch_device)
        tied_model.eval()

        model_result = model(
            input_ids=input_ids,
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
        )

        tied_model_result = tied_model(
            input_ids=input_ids,
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
        )

        # check that the tied model has fewer parameters
        self.assertLess(sum(p.numel() for p in tied_model.parameters()),
                        sum(p.numel() for p in model.parameters()))
        random_slice_idx = ids_tensor((1, ), model_result[0].shape[-1]).item()

        # check that outputs are equal
        self.assertTrue(
            torch.allclose(model_result[0][0, :, random_slice_idx],
                           tied_model_result[0][0, :, random_slice_idx],
                           atol=1e-4))

        # check that outputs after saving and loading are equal
        with tempfile.TemporaryDirectory() as tmpdirname:
            tied_model.save_pretrained(tmpdirname)
            tied_model = EncoderDecoderModel.from_pretrained(tmpdirname)
            tied_model.to(torch_device)
            tied_model.eval()

            # check that the tied model has fewer parameters
            self.assertLess(sum(p.numel() for p in tied_model.parameters()),
                            sum(p.numel() for p in model.parameters()))
            random_slice_idx = ids_tensor((1, ),
                                          model_result[0].shape[-1]).item()

            tied_model_result = tied_model(
                input_ids=input_ids,
                decoder_input_ids=decoder_input_ids,
                attention_mask=attention_mask,
                decoder_attention_mask=decoder_attention_mask,
            )

            # check that outputs are equal
            self.assertTrue(
                torch.allclose(model_result[0][0, :, random_slice_idx],
                               tied_model_result[0][0, :, random_slice_idx],
                               atol=1e-4))
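The tied variant exercised above can also be built straight from configs; a minimal sketch (tiny sizes are arbitrary, and both halves must share an architecture for tying to apply):

from transformers import BertConfig, EncoderDecoderConfig, EncoderDecoderModel

cfg = EncoderDecoderConfig.from_encoder_decoder_configs(
    BertConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4),
    BertConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4),
    tie_encoder_decoder=True)
tied_model = EncoderDecoderModel(config=cfg)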
Code example #13
def sample_generate(top_k=50,
                    temperature=1.0,
                    model_path='/content/BERT checkpoints/model-9.pth',
                    gpu_id=0):
    # make sure your model is on GPU
    device = torch.device(f"cuda:{gpu_id}")

    # ------------------------LOAD MODEL-----------------
    print('load the model....')
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    bert_encoder = BertConfig.from_pretrained('bert-base-uncased')
    bert_decoder = BertConfig.from_pretrained('bert-base-uncased',
                                              is_decoder=True)
    config = EncoderDecoderConfig.from_encoder_decoder_configs(
        bert_encoder, bert_decoder)
    model = EncoderDecoderModel(config)
    model.load_state_dict(torch.load(model_path, map_location='cuda'))
    model = model.to(device)
    encoder = model.get_encoder()
    decoder = model.get_decoder()
    model.eval()

    print('load success')
    # ------------------------END LOAD MODEL--------------

    # ------------------------LOAD VALIDATE DATA------------------
    test_data = torch.load("/content/test_data.pth")
    test_dataset = TensorDataset(*test_data)
    test_dataloader = DataLoader(dataset=test_dataset,
                                 shuffle=False,
                                 batch_size=1)
    # ------------------------END LOAD VALIDATE DATA--------------

    # ------------------------START GENERATE-------------------
    update_count = 0

    bleu_2scores = 0
    bleu_4scores = 0
    nist_2scores = 0
    nist_4scores = 0
    sentences = []
    meteor_scores = 0

    print('start generating....')
    for batch in test_dataloader:
        with torch.no_grad():
            batch = [item.to(device) for item in batch]

            encoder_input, decoder_input, mask_encoder_input, _ = batch

            # older transformers API: the encoder returns a plain tuple,
            # so past receives the last hidden state
            past, _ = encoder(encoder_input, mask_encoder_input)

            prev_pred = decoder_input[:, :1]
            sentence = prev_pred

            # decoding loop
            for i in range(100):
                logits = decoder(sentence, encoder_hidden_states=past)

                logits = logits[0][:, -1]
                logits = logits.squeeze(1) / temperature

                logits = top_k_logits(logits, k=top_k)
                probs = F.softmax(logits, dim=-1)
                prev_pred = torch.multinomial(probs, num_samples=1)
                sentence = torch.cat([sentence, prev_pred], dim=-1)
                if prev_pred[0][0] == 102:
                    break

            predict = tokenizer.convert_ids_to_tokens(sentence[0].tolist())

            encoder_input = encoder_input.squeeze(dim=0)
            encoder_input_num = (encoder_input != 0).sum()
            inputs = tokenizer.convert_ids_to_tokens(
                encoder_input[:encoder_input_num].tolist())

            decoder_input = decoder_input.squeeze(dim=0)
            decoder_input_num = (decoder_input != 0).sum()

            reference = tokenizer.convert_ids_to_tokens(
                decoder_input[:decoder_input_num].tolist())
            print('-' * 20 + f"example {update_count}" + '-' * 20)
            print(f"input: {' '.join(inputs)}")
            print(f"output: {' '.join(reference)}")
            print(f"predict: {' '.join(predict)}")

            (temp_bleu_2, temp_bleu_4, temp_nist_2, temp_nist_4,
             temp_meteor_scores) = calculate_metrics(predict[1:-1],
                                                     reference[1:-1])

            bleu_2scores += temp_bleu_2
            bleu_4scores += temp_bleu_4
            nist_2scores += temp_nist_2
            nist_4scores += temp_nist_4

            meteor_scores += temp_meteor_scores
            sentences.append(" ".join(predict[1:-1]))
            update_count += 1

    entro, dist = cal_entropy(sentences)
    mean_len, var_len = cal_length(sentences)
    print(f'avg: {mean_len}, var: {var_len}')
    print(f'entro: {entro}')
    print(f'dist: {dist}')
    print(f'test bleu_2scores: {bleu_2scores / update_count}')
    print(f'test bleu_4scores: {bleu_4scores / update_count}')
    print(f'test nist_2scores: {nist_2scores / update_count}')
    print(f'test nist_4scores: {nist_4scores / update_count}')
    print(f'test meteor_scores: {meteor_scores / update_count}')
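On recent transformers versions the hand-rolled loop above can be delegated to generate() with sampling flags; a sketch reusing the loop's tensors and assuming, as the loop does, that the decoder sequences begin with BERT's [CLS] (101) and end at [SEP] (102):

generated = model.generate(
    encoder_input,
    attention_mask=mask_encoder_input,
    do_sample=True,
    top_k=top_k,
    temperature=temperature,
    max_length=100,
    decoder_start_token_id=101,
    eos_token_id=102,
    pad_token_id=0,
)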
Code example #14
def train_model(epochs=10,
                num_gradients_accumulation=4,
                batch_size=4,
                gpu_id=0,
                lr=1e-5,
                load_dir='/content/BERT checkpoints'):
    # make sure your model is on GPU
    device = torch.device(f"cuda:{gpu_id}")

    # ------------------------LOAD MODEL-----------------
    print('load the model....')
    bert_encoder = BertConfig.from_pretrained('bert-base-uncased')
    bert_decoder = BertConfig.from_pretrained('bert-base-uncased',
                                              is_decoder=True)
    config = EncoderDecoderConfig.from_encoder_decoder_configs(
        bert_encoder, bert_decoder)
    model = EncoderDecoderModel(config)

    model = model.to(device)

    print('load success')
    # ------------------------END LOAD MODEL--------------

    # ------------------------LOAD TRAIN DATA------------------
    train_data = torch.load("/content/train_data.pth")
    train_dataset = TensorDataset(*train_data)
    train_dataloader = DataLoader(dataset=train_dataset,
                                  shuffle=True,
                                  batch_size=batch_size)
    val_data = torch.load("/content/validate_data.pth")
    val_dataset = TensorDataset(*val_data)
    val_dataloader = DataLoader(dataset=val_dataset,
                                shuffle=True,
                                batch_size=batch_size)
    # ------------------------END LOAD TRAIN DATA--------------

    # ------------------------SET OPTIMIZER-------------------
    num_train_optimization_steps = len(
        train_dataset) * epochs // batch_size // num_gradients_accumulation

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    optimizer = AdamW(
        optimizer_grouped_parameters,
        lr=lr,
        weight_decay=0.01,
    )
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_train_optimization_steps // 10,
        num_training_steps=num_train_optimization_steps)

    # ------------------------START TRAINING-------------------
    update_count = 0

    start = time.time()
    print('start training....')
    for epoch in range(epochs):
        # ------------------------training------------------------
        model.train()
        losses = 0
        times = 0

        print('\n' + '-' * 20 + f'epoch {epoch}' + '-' * 20)
        for batch in tqdm(train_dataloader):
            batch = [item.to(device) for item in batch]

            encoder_input, decoder_input, mask_encoder_input, mask_decoder_input = batch
            logits = model(input_ids=encoder_input,
                           attention_mask=mask_encoder_input,
                           decoder_input_ids=decoder_input,
                           decoder_attention_mask=mask_decoder_input)

            out = logits[0][:, :-1].contiguous()
            target = decoder_input[:, 1:].contiguous()
            target_mask = mask_decoder_input[:, 1:].contiguous()
            loss = util.sequence_cross_entropy_with_logits(out,
                                                           target,
                                                           target_mask,
                                                           average="token")
            loss.backward()

            losses += loss.item()
            times += 1

            update_count += 1

            if update_count % num_gradients_accumulation == num_gradients_accumulation - 1:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               max_grad_norm)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
        end = time.time()
        print(f'time: {(end - start)}')
        print(f'loss: {losses / times}')
        start = end

        # ------------------------validate------------------------
        model.eval()

        perplexity = 0
        batch_count = 0
        print('\nstart calculate the perplexity....')

        with torch.no_grad():
            for batch in tqdm(val_dataloader):
                batch = [item.to(device) for item in batch]

                encoder_input, decoder_input, mask_encoder_input, mask_decoder_input = batch
                logits = model(input_ids=encoder_input,
                               attention_mask=mask_encoder_input,
                               decoder_input_ids=decoder_input,
                               decoder_attention_mask=mask_decoder_input)

                out = logits[0][:, :-1].contiguous()
                target = decoder_input[:, 1:].contiguous()
                target_mask = mask_decoder_input[:, 1:].contiguous()
                # print(out.shape,target.shape,target_mask.shape)
                loss = util.sequence_cross_entropy_with_logits(out,
                                                               target,
                                                               target_mask,
                                                               average="token")
                perplexity += np.exp(loss.item())
                batch_count += 1

        print(f'\nvalidate perplexity: {perplexity / batch_count}')

        torch.save(
            model.state_dict(),
            os.path.join(os.path.abspath('.'), load_dir,
                         "model-" + str(epoch) + ".pth"))
Code example #15
vocabsize = decparams["vocab_size"]
max_length = decparams["max_length"]
decoder_config = BertConfig(
    vocab_size=vocabsize,
    max_position_embeddings=max_length + 64,  # this should be some large value
    num_attention_heads=decparams["num_attn_heads"],
    num_hidden_layers=decparams["num_hidden_layers"],
    hidden_size=decparams["hidden_size"],
    type_vocab_size=1,
    is_decoder=True)  # Very Important

decoder = BertForMaskedLM(config=decoder_config)
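
The encoder handed to EncoderDecoderModel below is defined earlier in the original file; a mirror-image construction would look like this (a sketch, with encparams as a hypothetical dict analogous to decparams):

encoder_config = BertConfig(
    vocab_size=encparams["vocab_size"],
    max_position_embeddings=max_length + 64,
    num_attention_heads=encparams["num_attn_heads"],
    num_hidden_layers=encparams["num_hidden_layers"],
    hidden_size=encparams["hidden_size"],
    type_vocab_size=1)
encoder = BertModel(config=encoder_config)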

# Define encoder decoder model
model = EncoderDecoderModel(encoder=encoder, decoder=decoder)
model.to(device)


def count_parameters(mdl):
    return sum(p.numel() for p in mdl.parameters() if p.requires_grad)


print(f'The encoder has {count_parameters(encoder):,} trainable parameters')
print(f'The decoder has {count_parameters(decoder):,} trainable parameters')
print(f'The model has {count_parameters(model):,} trainable parameters')

optimizer = optim.Adam(model.parameters(), lr=modelparams['lr'])
criterion = nn.NLLLoss(ignore_index=de_tokenizer.pad_token_id)

num_train_batches = len(train_dataloader)
num_valid_batches = len(valid_dataloader)
Code example #16
File: __init__.py  Project: biggoron/phonetizer
class PhonetizerModel:

    phon_tokenizer = {
        'e': 7,
        'i': 8,
        'R': 9,
        'a': 10,
        'o': 11,
        't': 12,
        's': 13,
        'l': 14,
        'k': 15,
        'p': 16,
        'm': 17,
        'n': 18,
        'd': 19,
        'y': 20,
        '@': 21,
        'f': 22,
        'z': 23,
        'b': 24,
        '§': 25,
        'v': 26,
        '2': 27,
        '1': 28,
        'Z': 29,
        'g': 30,
        'u': 31,
        'S': 32
    }
    phon_untokenizer = {v: k for k, v in phon_tokenizer.items()}
    char_tokenizer = {
        'e': 7,
        'i': 8,
        'a': 9,
        'r': 10,
        'o': 11,
        's': 12,
        't': 13,
        'n': 14,
        'l': 15,
        'é': 16,
        'c': 17,
        'p': 18,
        'u': 19,
        'm': 20,
        'd': 21,
        '-': 22,
        'h': 23,
        'g': 24,
        'b': 25,
        'v': 26,
        'f': 27,
        'k': 28,
        'y': 29,
        'x': 30,
        'è': 31,
        'ï': 32,
        'j': 33,
        'z': 34,
        'w': 35,
        'q': 36
    }

    def __init__(self, device='cpu', model=None):
        vocabsize = 37
        max_length = 50
        encoder_config = BertConfig(vocab_size=vocabsize,
                                    max_position_embeddings=max_length + 64,
                                    num_attention_heads=4,
                                    num_hidden_layers=4,
                                    hidden_size=128,
                                    type_vocab_size=1)
        encoder = BertModel(config=encoder_config)

        vocabsize = 33
        max_length = 50
        decoder_config = BertConfig(vocab_size=vocabsize,
                                    max_position_embeddings=max_length + 64,
                                    num_attention_heads=4,
                                    num_hidden_layers=4,
                                    hidden_size=128,
                                    type_vocab_size=1,
                                    add_cross_attention=True,
                                    is_decoder=True)
        decoder = BertLMHeadModel(config=decoder_config)

        # Define encoder decoder model
        self.model = EncoderDecoderModel(encoder=encoder, decoder=decoder)
        self.model.to(device)
        self.device = device
        if model is not None:
            self.model.load_state_dict(torch.load(model))

    def phonetize(self, word):
        word = word.replace('à', 'a')
        word = word.replace('û', 'u')
        word = word.replace('ù', 'u')
        word = word.replace('î', 'i')
        word = word.replace('ç', 'ss')
        word = word.replace('ô', 'o')
        word = word.replace('â', 'a')
        word = word.replace('qu', 'k')
        word = word.replace('ê', 'e')
        assert set(word).issubset(set(PhonetizerModel.char_tokenizer.keys()))
        encoded = torch.tensor(
            [0] + [PhonetizerModel.char_tokenizer[p] for p in word] + [2])
        output = self.model.generate(
            encoded.unsqueeze(0).to(self.device),
            max_length=50,
            decoder_start_token_id=0,
            eos_token_id=2,
            pad_token_id=1,
        ).detach().cpu().numpy()[0]
        bound = np.where(output == 2)[0][0] if 2 in output else 1000
        phon_pred = ''.join([
            PhonetizerModel.phon_untokenizer[c] for c in output[:bound]
            if c > 6
        ])
        return phon_pred

    def check_phonetization_error(self, word, phon):
        prediction = self.phonetize(word)[:5]
        score = pairwise2.align.globalms(list(phon[:5]),
                                         list(prediction),
                                         2,
                                         -1,
                                         -1,
                                         -.5,
                                         score_only=True,
                                         gap_char=['-']) / len(phon[:5])
        return score
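
Usage sketch (with randomly initialized weights unless model= points at a saved state dict, so the predicted phoneme string is meaningless until trained):

phonetizer = PhonetizerModel(device='cpu')
print(phonetizer.phonetize('bonjour'))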