    def check_encoder_decoder_model(self, config, input_ids, attention_mask,
                                    encoder_hidden_states, decoder_config,
                                    decoder_input_ids, decoder_attention_mask,
                                    **kwargs):
        encoder_model, decoder_model = self.get_encoder_decoder_model(
            config, decoder_config)
        enc_dec_model = EncoderDecoderModel(encoder=encoder_model,
                                            decoder=decoder_model)
        self.assertTrue(enc_dec_model.config.decoder.is_decoder)
        self.assertTrue(enc_dec_model.config.decoder.add_cross_attention)
        self.assertTrue(enc_dec_model.config.is_encoder_decoder)
        enc_dec_model.to(torch_device)
        outputs_encoder_decoder = enc_dec_model(
            input_ids=input_ids,
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
        )

        self.assertEqual(outputs_encoder_decoder[0].shape,
                         (decoder_input_ids.shape +
                          (decoder_config.vocab_size, )))
        self.assertEqual(outputs_encoder_decoder[1].shape,
                         (input_ids.shape + (config.hidden_size, )))
        encoder_outputs = (encoder_hidden_states, )
        outputs_encoder_decoder = enc_dec_model(
            encoder_outputs=encoder_outputs,
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
        )

        self.assertEqual(outputs_encoder_decoder[0].shape,
                         (decoder_input_ids.shape +
                          (decoder_config.vocab_size, )))
        self.assertEqual(outputs_encoder_decoder[1].shape,
                         (input_ids.shape + (config.hidden_size, )))
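
The get_encoder_decoder_model helper referenced above is defined by each concrete test class; for a BERT-based mixin it typically just instantiates the two sub-models from their configs, roughly:

    def get_encoder_decoder_model(self, config, decoder_config):
        # Plain BERT encoder plus a BERT LM head acting as the autoregressive decoder.
        encoder_model = BertModel(config)
        decoder_model = BertLMHeadModel(decoder_config)
        return encoder_model, decoder_model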
Example #2
    def test_encoder_decoder_save_load_from_encoder_decoder_from_pt(self):
        config = self.get_encoder_decoder_config_small()

        # create two random BERT models for bert2bert & initialize weights (+cross_attention weights)
        encoder_pt = BertModel(config.encoder).to(torch_device).eval()
        decoder_pt = BertLMHeadModel(config.decoder).to(torch_device).eval()

        encoder_decoder_pt = EncoderDecoderModel(encoder=encoder_pt, decoder=decoder_pt).to(torch_device).eval()

        input_ids = ids_tensor([13, 5], encoder_pt.config.vocab_size)
        decoder_input_ids = ids_tensor([13, 1], decoder_pt.config.vocab_size)

        pt_input_ids = torch.tensor(input_ids.numpy(), device=torch_device, dtype=torch.long)
        pt_decoder_input_ids = torch.tensor(decoder_input_ids.numpy(), device=torch_device, dtype=torch.long)

        logits_pt = encoder_decoder_pt(input_ids=pt_input_ids, decoder_input_ids=pt_decoder_input_ids).logits

        # PyTorch => TensorFlow
        with tempfile.TemporaryDirectory() as tmp_dirname_1, tempfile.TemporaryDirectory() as tmp_dirname_2:
            encoder_decoder_pt.encoder.save_pretrained(tmp_dirname_1)
            encoder_decoder_pt.decoder.save_pretrained(tmp_dirname_2)
            encoder_decoder_tf = TFEncoderDecoderModel.from_encoder_decoder_pretrained(
                tmp_dirname_1, tmp_dirname_2, encoder_from_pt=True, decoder_from_pt=True
            )

        logits_tf = encoder_decoder_tf(input_ids=input_ids, decoder_input_ids=decoder_input_ids).logits

        max_diff = np.max(np.abs(logits_pt.detach().cpu().numpy() - logits_tf.numpy()))
        self.assertAlmostEqual(max_diff, 0.0, places=3)

        # TensorFlow => PyTorch
        with tempfile.TemporaryDirectory() as tmp_dirname:
            encoder_decoder_tf.save_pretrained(tmp_dirname)
            encoder_decoder_pt = EncoderDecoderModel.from_pretrained(tmp_dirname, from_tf=True)

        logits_pt = encoder_decoder_pt(input_ids=pt_input_ids, decoder_input_ids=pt_decoder_input_ids).logits

        max_diff = np.max(np.abs(logits_pt.detach().cpu().numpy() - logits_tf.numpy()))
        self.assertAlmostEqual(max_diff, 0.0, places=3)
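
ids_tensor above is a small test utility; a minimal sketch that returns random token ids as a TF tensor (so both the .numpy() conversion and the direct TF call work) might look like:

import numpy as np
import tensorflow as tf

def ids_tensor(shape, vocab_size):
    # Random token ids in [0, vocab_size) with the requested shape.
    return tf.constant(np.random.randint(0, vocab_size, size=shape), dtype=tf.int32)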
Example #3
    def __init__(self, config, dataset):
        super(BERT2BERT, self).__init__(config, dataset)

        self.tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
        self.encoder_configure = BertConfig.from_pretrained('bert-base-cased')

        self.decoder_configure = BertConfig.from_pretrained('bert-base-cased')

        self.encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(
            encoder_config=self.encoder_configure,
            decoder_config=self.decoder_configure)

        self.encoder = BertGenerationEncoder.from_pretrained('bert-base-cased',
                                                             bos_token_id=101,
                                                             eos_token_id=102)

        self.decoder = BertGenerationDecoder.from_pretrained(
            'bert-base-cased',
            add_cross_attention=True,
            is_decoder=True,
            bos_token_id=101,
            eos_token_id=102)

        self.encoder_decoder = EncoderDecoderModel(
            encoder=self.encoder,
            decoder=self.decoder,
            config=self.encoder_decoder_config)

        self.sos_token = dataset.sos_token
        self.eos_token = dataset.eos_token
        self.padding_token_idx = self.tokenizer.pad_token_id
        self.max_source_length = config['source_max_seq_length']
        self.max_target_length = config['target_max_seq_length']

        self.loss = nn.CrossEntropyLoss(ignore_index=self.padding_token_idx,
                                        reduction='none')
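
For illustration only, a forward pass for such a wrapper could tokenize the source/target pair, run self.encoder_decoder, and apply the token-level loss defined above; the method name and call pattern below are hypothetical, not taken from the original class:

    def forward(self, source_text, target_text):
        # Tokenize source and target sentences (illustrative call pattern).
        inputs = self.tokenizer(source_text, padding=True, truncation=True,
                                max_length=self.max_source_length, return_tensors='pt')
        targets = self.tokenizer(target_text, padding=True, truncation=True,
                                 max_length=self.max_target_length, return_tensors='pt')

        outputs = self.encoder_decoder(input_ids=inputs.input_ids,
                                       attention_mask=inputs.attention_mask,
                                       decoder_input_ids=targets.input_ids,
                                       decoder_attention_mask=targets.attention_mask)

        # Shift so that tokens < n predict token n, then average the non-pad loss.
        logits = outputs.logits[:, :-1, :]
        labels = targets.input_ids[:, 1:]
        token_loss = self.loss(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
        mask = labels.reshape(-1).ne(self.padding_token_idx).float()
        return (token_loss * mask).sum() / mask.sum()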
Example #4
    def check_encoder_decoder_model_labels(
        self,
        config,
        input_ids,
        attention_mask,
        encoder_hidden_states,
        decoder_config,
        decoder_input_ids,
        decoder_attention_mask,
        labels,
        **kwargs,
    ):
        encoder_model, decoder_model = self.get_encoder_decoder_model(
            config, decoder_config)
        enc_dec_model = EncoderDecoderModel(encoder=encoder_model,
                                            decoder=decoder_model)
        enc_dec_model.to(torch_device)
        outputs_encoder_decoder = enc_dec_model(
            input_ids=input_ids,
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
            labels=labels,
            return_dict=True,
        )

        loss = outputs_encoder_decoder["loss"]
        # check that backprop works
        loss.backward()

        self.assertEqual(outputs_encoder_decoder["logits"].shape,
                         (decoder_input_ids.shape +
                          (decoder_config.vocab_size, )))
        self.assertEqual(
            outputs_encoder_decoder["encoder_last_hidden_state"].shape,
            (input_ids.shape + (config.hidden_size, )))
Example #5
    def __init__(
        self,
        model_save_path: str,
        batch_size: int,
        num_gpus: int,
        max_len: int = 512,
        lr: float = 3e-5,
        weight_decay: float = 1e-4,
        save_step_interval: int = 1000,
        accelerator: str = "ddp",
        precision: int = 16,
        use_amp: bool = True,
    ) -> None:
        super(Bert2Bert, self).__init__(
            model_save_path=model_save_path,
            max_len=max_len,
            batch_size=batch_size,
            num_gpus=num_gpus,
            lr=lr,
            weight_decay=weight_decay,
            save_step_interval=save_step_interval,
            accelerator=accelerator,
            precision=precision,
            use_amp=use_amp,
        )
        encoder_config = BertConfig.from_pretrained("monologg/kobert")
        decoder_config = BertConfig.from_pretrained("monologg/kobert")
        config = EncoderDecoderConfig.from_encoder_decoder_configs(
            encoder_config, decoder_config)

        self.model = EncoderDecoderModel(config)
        self.tokenizer = KoBertTokenizer()

        state_dict = BertModel.from_pretrained("monologg/kobert").state_dict()
        self.model.encoder.load_state_dict(state_dict)
        self.model.decoder.bert.load_state_dict(state_dict, strict=False)
    def create_and_check_encoder_decoder_shared_weights(
            self, config, input_ids, attention_mask, encoder_hidden_states,
            decoder_config, decoder_input_ids, decoder_attention_mask, labels,
            **kwargs):
        torch.manual_seed(0)
        encoder_model, decoder_model = self.get_encoder_decoder_model(
            config, decoder_config)
        model = EncoderDecoderModel(encoder=encoder_model,
                                    decoder=decoder_model)
        model.to(torch_device)
        model.eval()
        # loading the state dict copies the weights but does not tie them
        decoder_state_dict = model.decoder._modules[
            model.decoder.base_model_prefix].state_dict()
        model.encoder.load_state_dict(decoder_state_dict, strict=False)

        torch.manual_seed(0)
        tied_encoder_model, tied_decoder_model = self.get_encoder_decoder_model(
            config, decoder_config)
        config = EncoderDecoderConfig.from_encoder_decoder_configs(
            tied_encoder_model.config,
            tied_decoder_model.config,
            tie_encoder_decoder=True)
        tied_model = EncoderDecoderModel(encoder=tied_encoder_model,
                                         decoder=tied_decoder_model,
                                         config=config)
        tied_model.to(torch_device)
        tied_model.eval()

        model_result = model(
            input_ids=input_ids,
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
        )

        tied_model_result = tied_model(
            input_ids=input_ids,
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
        )

        # check that the tied model has fewer parameters
        self.assertLess(sum(p.numel() for p in tied_model.parameters()),
                        sum(p.numel() for p in model.parameters()))
        random_slice_idx = ids_tensor((1, ), model_result[0].shape[-1]).item()

        # check that outputs are equal
        self.assertTrue(
            torch.allclose(model_result[0][0, :, random_slice_idx],
                           tied_model_result[0][0, :, random_slice_idx],
                           atol=1e-4))

        # check that outputs after saving and loading are equal
        with tempfile.TemporaryDirectory() as tmpdirname:
            tied_model.save_pretrained(tmpdirname)
            tied_model = EncoderDecoderModel.from_pretrained(tmpdirname)
            tied_model.to(torch_device)
            tied_model.eval()

            # check that the tied model has fewer parameters
            self.assertLess(sum(p.numel() for p in tied_model.parameters()),
                            sum(p.numel() for p in model.parameters()))
            random_slice_idx = ids_tensor((1, ),
                                          model_result[0].shape[-1]).item()

            tied_model_result = tied_model(
                input_ids=input_ids,
                decoder_input_ids=decoder_input_ids,
                attention_mask=attention_mask,
                decoder_attention_mask=decoder_attention_mask,
            )

            # check that outputs are equal
            self.assertTrue(
                torch.allclose(model_result[0][0, :, random_slice_idx],
                               tied_model_result[0][0, :, random_slice_idx],
                               atol=1e-4))
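
Outside of the test suite, the same weight tying can be requested when warm-starting from pretrained checkpoints; a minimal sketch (checkpoint names are illustrative):

from transformers import EncoderDecoderModel

# Encoder and decoder share their weights, roughly halving the parameter count.
shared_bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "bert-base-uncased", "bert-base-uncased", tie_encoder_decoder=True
)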
Example #7
def encoder_decoder_example():
	import torch

	from transformers import EncoderDecoderConfig, EncoderDecoderModel
	from transformers import BertConfig, BertTokenizer, GPT2Config, GPT2Tokenizer

	pretrained_model_name = 'bert-base-uncased'
	#pretrained_model_name = 'gpt2'

	if 'bert' in pretrained_model_name:
		# Initialize a BERT bert-base-uncased style configuration.
		config_encoder, config_decoder = BertConfig(), BertConfig()
	elif 'gpt2' in pretrained_model_name:
		config_encoder, config_decoder = GPT2Config(), GPT2Config()
	else:
		print('Invalid model, {}.'.format(pretrained_model_name))
		return

	config = EncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder)

	if 'bert' in pretrained_model_name:
		# Initialize a Bert2Bert model from the bert-base-uncased style configurations.
		model = EncoderDecoderModel(config=config)
		#model = EncoderDecoderModel.from_encoder_decoder_pretrained(pretrained_model_name, pretrained_model_name)  # Initialize Bert2Bert from pre-trained checkpoints.
		tokenizer = BertTokenizer.from_pretrained(pretrained_model_name)
	elif 'gpt2' in pretrained_model_name:
		model = EncoderDecoderModel(config=config)
		tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model_name)

	#print('Configuration of the encoder & decoder:\n{}.\n{}.'.format(model.config.encoder, model.config.decoder))
	#print('Encoder type = {}, decoder type = {}.'.format(type(model.encoder), type(model.decoder)))

	if False:
		# Access the model configuration.
		config_encoder = model.config.encoder
		config_decoder  = model.config.decoder

		# Set decoder config to causal LM.
		config_decoder.is_decoder = True
		config_decoder.add_cross_attention = True

	#--------------------
	input_ids = torch.tensor(tokenizer.encode('Hello, my dog is cute', add_special_tokens=True)).unsqueeze(0)  # Batch size 1.

	if False:
		# Forward.
		outputs = model(input_ids=input_ids, decoder_input_ids=input_ids)

		# Train.
		outputs = model(input_ids=input_ids, decoder_input_ids=input_ids, labels=input_ids)
		loss, logits = outputs.loss, outputs.logits

		# Save the model, including its configuration.
		model.save_pretrained('my-model')

		#--------------------
		# Load model and config from pretrained folder.
		encoder_decoder_config = EncoderDecoderConfig.from_pretrained('my-model')
		model = EncoderDecoderModel.from_pretrained('my-model', config=encoder_decoder_config)

	#--------------------
	# Generate.
	#	REF [site] >>
	#		https://huggingface.co/transformers/internal/generation_utils.html
	#		https://huggingface.co/blog/how-to-generate
	generated = model.generate(input_ids, decoder_start_token_id=model.config.decoder.pad_token_id)
	#generated = model.generate(input_ids, max_length=50, num_beams=5, no_repeat_ngram_size=2, num_return_sequences=5, do_sample=True, top_k=0, temperature=0.7, early_stopping=True, decoder_start_token_id=model.config.decoder.pad_token_id)
	print('Generated = {}.'.format(tokenizer.decode(generated[0], skip_special_tokens=True)))
Example #8
from transformers import (BertGenerationEncoder, BertTokenizer, EncoderDecoderModel,
                          GPT2LMHeadModel)

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Version 1: load encoder-decoder together.
#model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "gpt2")

# Version 2: load pretrained modules separately and join them.
encoder = BertGenerationEncoder.from_pretrained("bert-base-uncased",
                                                bos_token_id=101,
                                                eos_token_id=102)
# add cross attention layers and use the same BOS and EOS tokens.
decoder = GPT2LMHeadModel.from_pretrained("gpt2",
                                          add_cross_attention=True,
                                          is_decoder=True,
                                          bos_token_id=101,
                                          eos_token_id=102)
model = EncoderDecoderModel(encoder=encoder, decoder=decoder)

# encode context the generation is conditioned on
input_ids = tokenizer.encode('I enjoy walking with my cute dog',
                             return_tensors='pt')

# Activate beam search and early_stopping.
# A simple remedy is to introduce n-gram (a.k.a. word sequences of n words) penalties
# as introduced by Paulus et al. (2017) and Klein et al. (2017).
# The most common n-grams penalty makes sure that no n-gram appears twice by
# manually setting the probability of next words that could create an already seen n-gram to 0.
beam_output = model.generate(
    input_ids,
    max_length=50,
    num_beams=5,
    no_repeat_ngram_size=2,
    early_stopping=True,
)
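
Decoding the best beam back to text follows the usual tokenizer pattern:

print("Output: " + tokenizer.decode(beam_output[0], skip_special_tokens=True))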
Example #9
decoder_config = BertGenerationConfig(
    is_decoder = True,
    add_cross_attention = True, # add cross attention layers
    vocab_size = len(decoder_tokenizer),
    # Set required tokens.
    unk_token_id = decoder_tokenizer.vocab["[UNK]"],
    sep_token_id = decoder_tokenizer.vocab["[SEP]"],
    pad_token_id = decoder_tokenizer.vocab["[PAD]"],
    cls_token_id = decoder_tokenizer.vocab["[CLS]"],
    mask_token_id = decoder_tokenizer.vocab["[MASK]"],
    bos_token_id = decoder_tokenizer.vocab["[BOS]"],
    eos_token_id = decoder_tokenizer.vocab["[EOS]"],
    )
# Initialize a brand new bert-based decoder.
decoder = BertGenerationDecoder(config=decoder_config)

# Setup enc-decoder mode.
bert2bert = EncoderDecoderModel(encoder=encoder, decoder=decoder)
bert2bert.config.decoder_start_token_id=decoder_tokenizer.vocab["[CLS]"]
bert2bert.config.pad_token_id=decoder_tokenizer.vocab["[PAD]"]

# Elementary Training.
optimizer = torch.optim.Adam(bert2bert.parameters(), lr=0.000001)
bert2bert.cuda()

for epoch in range(30):
    print("*"*50, "Epoch", epoch, "*"*50)
    if True:
        for batch in tqdm(sierra_dl):
            # tokenize commands and goals.
            inputs = encoder_tokenizer(batch["command"], add_special_tokens=True, return_tensors="pt", padding=True, truncation=True)
            labels = decoder_tokenizer(batch["symbolic_plan_processed"], return_tensors="pt", padding=True, max_length=sierra_ds.max_plan_length, truncation=True, add_special_tokens=True, )
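
            # Sketch of a typical continuation of this training step (assumes the
            # bert2bert model, the optimizer, and the tokenized `inputs`/`labels`
            # batches defined above). In practice, pad positions in the labels are
            # usually set to -100 so the loss ignores them.
            outputs = bert2bert(input_ids=inputs.input_ids.cuda(),
                                attention_mask=inputs.attention_mask.cuda(),
                                labels=labels.input_ids.cuda())
            outputs.loss.backward()
            optimizer.step()
            optimizer.zero_grad()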
Example #10
from transformers import (EncoderDecoderModel, PreTrainedModel, BertTokenizer,
                          BertGenerationEncoder, BertGenerationDecoder)

encoder = BertGenerationEncoder.from_pretrained(
    model_type, bos_token_id=BOS_TOKEN_ID, eos_token_id=EOS_TOKEN_ID
)  # add cross attention layers and use BERT’s cls token as BOS token and sep token as EOS token

decoder = BertGenerationDecoder.from_pretrained(model_type,
                                                add_cross_attention=True,
                                                is_decoder=True,
                                                bos_token_id=BOS_TOKEN_ID,
                                                eos_token_id=EOS_TOKEN_ID)
model = EncoderDecoderModel(encoder=encoder, decoder=decoder).to(device)
def sample_generate(top_k=50,
                    temperature=1.0,
                    model_path='/content/BERT checkpoints/model-9.pth',
                    gpu_id=0):
    # make sure your model is on GPU
    device = torch.device(f"cuda:{gpu_id}")

    # ------------------------LOAD MODEL-----------------
    print('load the model....')
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    bert_encoder = BertConfig.from_pretrained('bert-base-uncased')
    bert_decoder = BertConfig.from_pretrained('bert-base-uncased',
                                              is_decoder=True)
    config = EncoderDecoderConfig.from_encoder_decoder_configs(
        bert_encoder, bert_decoder)
    model = EncoderDecoderModel(config)
    model.load_state_dict(torch.load(model_path, map_location='cuda'))
    model = model.to(device)
    encoder = model.get_encoder()
    decoder = model.get_decoder()
    model.eval()

    print('load success')
    # ------------------------END LOAD MODEL--------------

    # ------------------------LOAD VALIDATE DATA------------------
    test_data = torch.load("/content/test_data.pth")
    test_dataset = TensorDataset(*test_data)
    test_dataloader = DataLoader(dataset=test_dataset,
                                 shuffle=False,
                                 batch_size=1)
    # ------------------------END LOAD VALIDATE DATA--------------

    # ------------------------START GENERATING-------------------
    update_count = 0

    bleu_2scores = 0
    bleu_4scores = 0
    nist_2scores = 0
    nist_4scores = 0
    sentences = []
    meteor_scores = 0

    print('start generating....')
    for batch in test_dataloader:
        with torch.no_grad():
            batch = [item.to(device) for item in batch]

            encoder_input, decoder_input, mask_encoder_input, _ = batch

            past, _ = encoder(encoder_input, mask_encoder_input)

            prev_pred = decoder_input[:, :1]
            sentence = prev_pred

            # decoding loop
            for i in range(100):
                logits = decoder(sentence, encoder_hidden_states=past)

                logits = logits[0][:, -1]
                logits = logits.squeeze(1) / temperature

                logits = top_k_logits(logits, k=top_k)
                probs = F.softmax(logits, dim=-1)
                prev_pred = torch.multinomial(probs, num_samples=1)
                sentence = torch.cat([sentence, prev_pred], dim=-1)
                if prev_pred[0][0] == 102:
                    break

            predict = tokenizer.convert_ids_to_tokens(sentence[0].tolist())

            encoder_input = encoder_input.squeeze(dim=0)
            encoder_input_num = (encoder_input != 0).sum()
            inputs = tokenizer.convert_ids_to_tokens(
                encoder_input[:encoder_input_num].tolist())

            decoder_input = decoder_input.squeeze(dim=0)
            decoder_input_num = (decoder_input != 0).sum()

            reference = tokenizer.convert_ids_to_tokens(
                decoder_input[:decoder_input_num].tolist())
            print('-' * 20 + f"example {update_count}" + '-' * 20)
            print(f"input: {' '.join(inputs)}")
            print(f"output: {' '.join(reference)}")
            print(f"predict: {' '.join(predict)}")

            (temp_bleu_2, temp_bleu_4, temp_nist_2, temp_nist_4,
             temp_meteor_scores) = calculate_metrics(predict[1:-1], reference[1:-1])

            bleu_2scores += temp_bleu_2
            bleu_4scores += temp_bleu_4
            nist_2scores += temp_nist_2
            nist_4scores += temp_nist_4

            meteor_scores += temp_meteor_scores
            sentences.append(" ".join(predict[1:-1]))
            update_count += 1

    entro, dist = cal_entropy(sentences)
    mean_len, var_len = cal_length(sentences)
    print(f'avg: {mean_len}, var: {var_len}')
    print(f'entro: {entro}')
    print(f'dist: {dist}')
    print(f'test bleu_2scores: {bleu_2scores / update_count}')
    print(f'test bleu_4scores: {bleu_4scores / update_count}')
    print(f'test nist_2scores: {nist_2scores / update_count}')
    print(f'test nist_4scores: {nist_4scores / update_count}')
    print(f'test meteor_scores: {meteor_scores / update_count}')
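
The top_k_logits helper is not shown here; one common implementation (a sketch) masks everything outside the k largest logits before sampling:

def top_k_logits(logits, k):
    # Keep only the k largest logits per row; everything else is set to -inf so
    # it receives zero probability after the softmax.
    if k == 0:
        return logits
    values, _ = torch.topk(logits, k)
    min_values = values[:, -1].unsqueeze(-1)
    return torch.where(logits < min_values,
                       torch.full_like(logits, float('-inf')),
                       logits)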
Example #12
def train_model(epochs=10,
                num_gradients_accumulation=4,
                batch_size=4,
                gpu_id=0,
                lr=1e-5,
                load_dir='/content/BERT checkpoints'):
    # make sure your model is on GPU
    device = torch.device(f"cuda:{gpu_id}")

    # ------------------------LOAD MODEL-----------------
    print('load the model....')
    bert_encoder = BertConfig.from_pretrained('bert-base-uncased')
    bert_decoder = BertConfig.from_pretrained('bert-base-uncased',
                                              is_decoder=True)
    config = EncoderDecoderConfig.from_encoder_decoder_configs(
        bert_encoder, bert_decoder)
    model = EncoderDecoderModel(config)

    model = model.to(device)

    print('load success')
    # ------------------------END LOAD MODEL--------------

    # ------------------------LOAD TRAIN DATA------------------
    train_data = torch.load("/content/train_data.pth")
    train_dataset = TensorDataset(*train_data)
    train_dataloader = DataLoader(dataset=train_dataset,
                                  shuffle=True,
                                  batch_size=batch_size)
    val_data = torch.load("/content/validate_data.pth")
    val_dataset = TensorDataset(*val_data)
    val_dataloader = DataLoader(dataset=val_dataset,
                                shuffle=True,
                                batch_size=batch_size)
    # ------------------------END LOAD TRAIN DATA--------------

    # ------------------------SET OPTIMIZER-------------------
    num_train_optimization_steps = len(
        train_dataset) * epochs // batch_size // num_gradients_accumulation

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {
            'params': [p for n, p in param_optimizer
                       if not any(nd in n for nd in no_decay)],
            'weight_decay': 0.01,
        },
        {
            'params': [p for n, p in param_optimizer
                       if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0,
        },
    ]
    optimizer = AdamW(
        optimizer_grouped_parameters,
        lr=lr,
        weight_decay=0.01,
    )
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_train_optimization_steps // 10,
        num_training_steps=num_train_optimization_steps)

    # ------------------------START TRAINING-------------------
    update_count = 0

    start = time.time()
    print('start training....')
    for epoch in range(epochs):
        # ------------------------training------------------------
        model.train()
        losses = 0
        times = 0

        print('\n' + '-' * 20 + f'epoch {epoch}' + '-' * 20)
        for batch in tqdm(train_dataloader):
            batch = [item.to(device) for item in batch]

            encoder_input, decoder_input, mask_encoder_input, mask_decoder_input = batch
            logits = model(input_ids=encoder_input,
                           attention_mask=mask_encoder_input,
                           decoder_input_ids=decoder_input,
                           decoder_attention_mask=mask_decoder_input)

            out = logits[0][:, :-1].contiguous()
            target = decoder_input[:, 1:].contiguous()
            target_mask = mask_decoder_input[:, 1:].contiguous()
            loss = util.sequence_cross_entropy_with_logits(out,
                                                           target,
                                                           target_mask,
                                                           average="token")
            loss.backward()

            losses += loss.item()
            times += 1

            update_count += 1

            if update_count % num_gradients_accumulation == num_gradients_accumulation - 1:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               max_grad_norm)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
        end = time.time()
        print(f'time: {(end - start)}')
        print(f'loss: {losses / times}')
        start = end

        # ------------------------validate------------------------
        model.eval()

        perplexity = 0
        batch_count = 0
        print('\nstart calculating the perplexity....')

        with torch.no_grad():
            for batch in tqdm(val_dataloader):
                batch = [item.to(device) for item in batch]

                encoder_input, decoder_input, mask_encoder_input, mask_decoder_input = batch
                logits = model(input_ids=encoder_input,
                               attention_mask=mask_encoder_input,
                               decoder_input_ids=decoder_input,
                               decoder_attention_mask=mask_decoder_input)

                out = logits[0][:, :-1].contiguous()
                target = decoder_input[:, 1:].contiguous()
                target_mask = mask_decoder_input[:, 1:].contiguous()
                # print(out.shape,target.shape,target_mask.shape)
                loss = util.sequence_cross_entropy_with_logits(out,
                                                               target,
                                                               target_mask,
                                                               average="token")
                perplexity += np.exp(loss.item())
                batch_count += 1

        print(f'\nvalidate perplexity: {perplexity / batch_count}')

        torch.save(
            model.state_dict(),
            os.path.join(os.path.abspath('.'), load_dir,
                         "model-" + str(epoch) + ".pth"))
Example #13
    config_encoder = BertConfig()  # encoder config (bert-base-uncased defaults)
    config_decoder = BertConfig()

    encoder_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    decoder_tokenizer = BasicTokenizer("yz/vocab.txt")


    config_decoder.update({
        "vocab_size": len(decoder_tokenizer.vocab),
        "num_hidden_layers":3,
        "num_attention_heads":3
    })

    config = EncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder)

    # Load the BERT model
    model = EncoderDecoderModel(config=config)
    model.encoder = BertModel.from_pretrained('bert-base-uncased')
    if args.gpu and torch.cuda.is_available():
        model = model.cuda()
    
    loss_fun = nn.CrossEntropyLoss()
    #loss_fun = nn.CrossEntropyLoss(ignore_index=0)

    optimizer = torch.optim.Adam(model.decoder.parameters(), lr=args.lr)

    # Record the start time
    begin_time = datetime.now()
    print("Start training BERT: ", begin_time)

    # Start training
    for epoch in range(args.epoch):