def sample_generate(top_k: int = 50,
                    temperature: float = 1.0,
                    decoder_path: str = 'decoder.pth',
                    batch_size: int = 1,
                    gpu_id: int = 0):
    """Sample a reply for every test example and report generation metrics.

    For each (context, reference) pair the encoder is run once, then up to
    100 tokens are drawn with top-k / temperature sampling, stopping early
    at BERT's [SEP] token (id 102 in the bert-base-chinese vocab).
    Per-example BLEU-2/4, NIST-2/4 and METEOR scores are accumulated and
    averaged; corpus-level entropy, distinct-n and length statistics are
    printed at the end.

    NOTE(review): the decoding/metric code indexes ``sentence[0]`` and
    squeezes dim 0, so it effectively assumes batch_size == 1 — confirm
    before raising batch_size.
    NOTE(review): the final averages divide by ``update_count`` and will
    raise ZeroDivisionError if the test set is empty.
    """
    # make sure your model is on GPU
    device = torch.device(f"cuda:{gpu_id}")

    print('load model')
    #------------------------LOAD MODEL-----------------
    tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
    encoder = TransformerEncoder()
    encoder.load_state_dict(torch.load("encoder.pth"))
    encoder = encoder.to(device)
    encoder.eval()

    decoder = TransformerDecoderLM()
    decoder.load_state_dict(torch.load(decoder_path))
    decoder = decoder.to(device)
    decoder.eval()

    print('load success')
    #------------------------END LOAD MODEL--------------

    #------------------------LOAD VALIDATE DATA------------------
    test_data = torch.load("../test_data.pth")
    test_dataset = TensorDataset(*test_data)
    test_dataloader = DataLoader(dataset=test_dataset,
                                 shuffle=False,
                                 batch_size=batch_size)
    #------------------------END LOAD VALIDATE DATA--------------

    #------------------------START SAMPLE GENERATE-------------------
    update_count = 0

    # Running sums of per-example metric scores; averaged at the end.
    bleu_2scores = 0
    bleu_4scores = 0
    nist_2scores = 0
    nist_4scores = 0

    sen_length = 0  # NOTE(review): never updated or read — dead local.
    meteor_scores = 0

    sentences = []  # generated replies, kept for entropy/length stats
    print('start generate....')

    for batch in test_dataloader:
        with torch.no_grad():
            batch = [item.to(device) for item in batch]

            encoder_input, decoder_input, mask, _ = batch

            # Encode the context once; `past` carries the cached states
            # the decoder attends over on every step.
            _, past = encoder(encoder_input, mask)

            sentence = []

            # Seed decoding with the first decoder-input token (the
            # sequence-start slot of the reference).
            prev_pred = decoder_input[:, :1]
            sentence.append(prev_pred)

            length = 1
            # decoding loop
            for i in range(100):
                # Extend the attention mask by one position for the token
                # about to be generated.
                mask = F.pad(mask, (0, 1), "constant", 1.0)
                # past_length presumably tells the decoder how many
                # generated positions precede this token — TODO confirm
                # against TransformerDecoderLM.
                logits, past = decoder(prev_pred,
                                       mask,
                                       past=past,
                                       past_length=length)
                # Single-step logits -> temperature-scaled, top-k filtered
                # distribution, then sample one token id.
                logits = logits.squeeze(1) / temperature
                logits = top_k_logits(logits, k=top_k)
                probs = F.softmax(logits, dim=-1)
                prev_pred = torch.multinomial(probs, num_samples=1)
                sentence.append(prev_pred)
                if prev_pred[0][0] == 102:
                    # 102 == [SEP] in bert-base-chinese: end of reply.
                    break
                length += 1

            sentence = torch.cat(sentence, dim=-1)

            predict = tokenizer.convert_ids_to_tokens(sentence[0].tolist())
            target = decoder_input.squeeze(dim=0)
            # Count non-padding tokens (0 is assumed to be the pad id —
            # TODO confirm against the data-preparation code).
            target_num = (target != 0).sum()
            inputs = encoder_input.squeeze(dim=0)
            input_num = (inputs != 0).sum()
            inputs = tokenizer.convert_ids_to_tokens(
                inputs[:input_num].tolist())
            reference = tokenizer.convert_ids_to_tokens(
                target[:target_num].tolist())

            print('-' * 20 + f"example {update_count}" + '-' * 20)
            print(f"input: {''.join(inputs)}")
            print(f"output: {''.join(reference)}")
            print(f"predict: {''.join(predict)}")

            # [1:-1] strips the boundary tokens at both ends before scoring.
            temp_bleu_2, \
            temp_bleu_4, \
            temp_nist_2, \
            temp_nist_4, \
            temp_meteor_scores = calculate_metrics(predict[1:-1], reference[1:-1])

            bleu_2scores += temp_bleu_2
            bleu_4scores += temp_bleu_4
            nist_2scores += temp_nist_2
            nist_4scores += temp_nist_4

            meteor_scores += temp_meteor_scores
            sentences.append(" ".join(predict[1:-1]))
            update_count += 1

    # Corpus-level statistics over all generated replies.
    entro, dist = cal_entropy(sentences)
    mean_len, var_len = cal_length(sentences)
    print(f'avg: {mean_len}, var: {var_len}')
    print(f'entro: {entro}')
    print(f'dist: {dist}')
    print(f'test bleu_2scores: {bleu_2scores / update_count}')
    print(f'test bleu_4scores: {bleu_4scores / update_count}')
    print(f'test nist_2scores: {nist_2scores / update_count}')
    print(f'test nist_4scores: {nist_4scores / update_count}')
    print(f'test meteor_scores: {meteor_scores / update_count}')
# Esempio n. 2  (example separator left over from the code-sample source — not code)
# 0
def _dataset_perplexity(dataloader, encoder, decoder, device):
    """Return the mean per-batch perplexity of ``decoder`` over ``dataloader``.

    Each batch must contain (encoder_input, decoder_input,
    mask_encoder_input, mask_decoder_input).  The encoder output is fed to
    the decoder as cached ``past`` states; the token-averaged cross-entropy
    of the shifted decoder targets is exponentiated per batch and averaged.
    """
    perplexity = 0
    batch_count = 0
    with torch.no_grad():
        for batch in dataloader:
            batch = [item.to(device) for item in batch]

            encoder_input, decoder_input, mask_encoder_input, mask_decoder_input = batch

            _, past = encoder(encoder_input, mask_encoder_input)

            mask = torch.cat([mask_encoder_input, mask_decoder_input], dim=1)
            logits, _ = decoder(decoder_input, mask, past=past, past_length=0)

            # Shift by one: predict token t+1 from position t, masking out
            # padding positions in the target.
            out = logits[:, :-1].contiguous()
            target = decoder_input[:, 1:].contiguous()
            target_mask = mask_decoder_input[:, 1:].contiguous()

            loss = util.sequence_cross_entropy_with_logits(out,
                                                           target,
                                                           target_mask,
                                                           average="token")
            perplexity += np.exp(loss.item())
            batch_count += 1
    return perplexity / batch_count


def calculate_perplexity(batch_size=1, gpu_id=0, decoder_path='decoder.pth'):
    """Report decoder perplexity on the train, validate and test splits.

    Parameters
    ----------
    batch_size : int
        Evaluation batch size.
    gpu_id : int
        CUDA device index the models are moved to.
    decoder_path : str
        Path of the decoder checkpoint to evaluate.
    """
    # make sure your model is on GPU
    device = torch.device(f"cuda:{gpu_id}")

    #------------------------LOAD MODEL-----------------
    print('load the model....')
    encoder = TransformerEncoder()
    encoder.load_state_dict(torch.load("encoder.pth"))
    encoder = encoder.to(device)
    encoder.eval()

    decoder = TransformerDecoderLM()
    decoder.load_state_dict(torch.load(decoder_path))
    decoder = decoder.to(device)
    decoder.eval()

    print('load success')
    #------------------------END LOAD MODEL--------------

    # One shared evaluation loop per split instead of three copy-pasted
    # loops; each dataset is loaded right before it is scored so only one
    # split is resident at a time.
    for split_name, split_path in (("train", "train_data.pth"),
                                   ("validate", "validate_data.pth"),
                                   ("test", "test_data.pth")):
        split_data = torch.load(split_path)
        dataloader = DataLoader(dataset=TensorDataset(*split_data),
                                shuffle=False,
                                batch_size=batch_size)
        print(f'start calculate the {split_name} perplexity....')
        ppl = _dataset_perplexity(dataloader, encoder, decoder, device)
        print(f'{split_name} perplexity: {ppl}')
# Esempio n. 3  (example separator left over from the code-sample source — not code)
# 0
def train_model(epochs=10,
                num_gradients_accumulation=4,
                batch_size=4,
                gpu_id=0,
                lr=1e-5,
                load_dir='decoder_model'):
    """Fine-tune the decoder on the train split, validating after each epoch.

    Only the decoder is optimized — the encoder is frozen and used purely
    as a feature extractor.  After every epoch the validation perplexity is
    printed and a decoder checkpoint is written to
    ``load_dir/<epoch>decoder.pth``.

    Parameters
    ----------
    epochs : int
        Number of passes over the training set.
    num_gradients_accumulation : int
        Batches whose gradients are accumulated before each optimizer step.
    batch_size : int
        Mini-batch size for both training and validation.
    gpu_id : int
        CUDA device index the models are moved to.
    lr : float
        Learning rate for OpenAIAdam.
    load_dir : str
        Directory (relative to the CWD) where checkpoints are saved.
    """
    # make sure your model is on GPU
    device = torch.device(f"cuda:{gpu_id}")

    #------------------------LOAD MODEL-----------------
    print('load the model....')
    encoder = TransformerEncoder()
    decoder = TransformerDecoderLM()

    encoder.load_state_dict(torch.load("encoder.pth"))
    decoder.load_state_dict(torch.load("decoder.pth"))

    encoder = encoder.to(device)
    decoder = decoder.to(device)

    # Fix: the encoder's parameters are never handed to the optimizer, so it
    # is effectively frozen — keep it in eval mode (deterministic behavior
    # during both training and validation).
    encoder.eval()

    print('load success')
    #------------------------END LOAD MODEL--------------

    # Fix: make sure the checkpoint directory exists before training starts;
    # previously torch.save crashed at the end of epoch 0 if it was missing.
    save_dir = os.path.join(os.path.abspath('.'), load_dir)
    os.makedirs(save_dir, exist_ok=True)

    #------------------------LOAD TRAIN DATA------------------
    train_data = torch.load("train_data.pth")
    train_dataset = TensorDataset(*train_data)
    train_dataloader = DataLoader(dataset=train_dataset,
                                  shuffle=True,
                                  batch_size=batch_size)
    val_data = torch.load("validate_data.pth")
    val_dataset = TensorDataset(*val_data)
    val_dataloader = DataLoader(dataset=val_dataset,
                                shuffle=True,
                                batch_size=batch_size)
    #------------------------END LOAD TRAIN DATA--------------

    #------------------------SET OPTIMIZER-------------------
    num_train_optimization_steps = len(
        train_dataset) * epochs // batch_size // num_gradients_accumulation

    # Standard transformer weight-decay split: no decay for biases and
    # LayerNorm parameters.
    param_optimizer = list(decoder.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    optimizer = OpenAIAdam(optimizer_grouped_parameters,
                           lr=lr,
                           warmup=0.01,
                           max_grad_norm=1.0,
                           weight_decay=0.01,
                           t_total=num_train_optimization_steps)
    #------------------------END SET OPTIMIZER--------------

    #------------------------START TRAINING-------------------
    update_count = 0

    start = time.time()
    print('start training....')
    for epoch in range(epochs):
        #------------------------training------------------------
        decoder.train()
        losses = 0
        times = 0
        for batch in train_dataloader:
            batch = [item.to(device) for item in batch]

            encoder_input, decoder_input, mask_encoder_input, mask_decoder_input = batch

            # Fix: run the frozen encoder without grad tracking.  Previously
            # gradients accumulated in the encoder's parameters on every
            # backward() and were never stepped or zeroed (the optimizer only
            # holds decoder params), wasting memory/compute.  The decoder's
            # own gradients are unaffected — `past` is only an input.
            with torch.no_grad():
                _, past = encoder(encoder_input, mask_encoder_input)

            mask = torch.cat([mask_encoder_input, mask_decoder_input], dim=1)
            logits, _ = decoder(decoder_input, mask, past=past, past_length=0)

            # Shift by one: predict token t+1 from position t, masking padding.
            out = logits[:, :-1].contiguous()
            target = decoder_input[:, 1:].contiguous()
            target_mask = mask_decoder_input[:, 1:].contiguous()

            loss = util.sequence_cross_entropy_with_logits(out,
                                                           target,
                                                           target_mask,
                                                           average="token")
            loss.backward()

            losses += loss.item()
            times += 1

            update_count += 1

            # Step once every num_gradients_accumulation batches.
            if update_count % num_gradients_accumulation == num_gradients_accumulation - 1:
                optimizer.step()
                optimizer.zero_grad()
        end = time.time()
        print('-' * 20 + f'epoch {epoch}' + '-' * 20)
        print(f'time: {(end - start)}')
        print(f'loss: {losses / times}')
        start = end

        #------------------------validate------------------------
        decoder.eval()

        perplexity = 0
        batch_count = 0
        print('start calculate the perplexity....')

        with torch.no_grad():
            for batch in val_dataloader:
                batch = [item.to(device) for item in batch]

                encoder_input, decoder_input, mask_encoder_input, mask_decoder_input = batch

                _, past = encoder(encoder_input, mask_encoder_input)

                mask = torch.cat([mask_encoder_input, mask_decoder_input],
                                 dim=1)
                logits, _ = decoder(decoder_input,
                                    mask,
                                    past=past,
                                    past_length=0)

                out = logits[:, :-1].contiguous()
                target = decoder_input[:, 1:].contiguous()
                target_mask = mask_decoder_input[:, 1:].contiguous()

                loss = util.sequence_cross_entropy_with_logits(out,
                                                               target,
                                                               target_mask,
                                                               average="token")
                perplexity += np.exp(loss.item())
                batch_count += 1

        print(f'validate perplexity: {perplexity / batch_count}')

        torch.save(decoder.state_dict(),
                   os.path.join(save_dir, str(epoch) + "decoder.pth"))
# Esempio n. 4  (example separator left over from the code-sample source — not code)
# 0
def sample_generate(top_k=50,
                    temperature=1.0,
                    decoder_path='decoder.pth',
                    batch_size=1,
                    show_num=10,
                    gpu_id=0):
    """Generate and print sample replies for the first ``show_num`` test cases.

    For each example the encoder context is computed once, then exactly 100
    tokens are drawn with top-k / temperature sampling (no early stop on
    [SEP] here — the printed prediction is simply truncated to the first
    100 characters).

    Parameters
    ----------
    top_k : int
        Number of highest-probability logits kept for sampling.
    temperature : float
        Softmax temperature; higher values flatten the distribution.
    decoder_path : str
        Decoder checkpoint to load.
    batch_size : int
        DataLoader batch size.  NOTE(review): the printing code only looks
        at element 0 of each batch, so sequences past the first are skipped.
    show_num : int
        Number of examples to print before stopping.
    gpu_id : int
        CUDA device index the models are moved to.
    """
    # make sure your model is on GPU
    device = torch.device(f"cuda:{gpu_id}")

    print('load model')
    #------------------------LOAD MODEL-----------------
    tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
    encoder = TransformerEncoder()
    encoder.load_state_dict(torch.load("encoder.pth"))
    encoder = encoder.to(device)
    encoder.eval()

    decoder = TransformerDecoderLM()
    decoder.load_state_dict(torch.load(decoder_path))
    decoder = decoder.to(device)
    decoder.eval()

    print('load success')
    #------------------------END LOAD MODEL--------------

    #------------------------LOAD VALIDATE DATA------------------
    val_data = torch.load("test_data.pth")
    val_dataset = TensorDataset(*val_data)
    val_dataloader = DataLoader(dataset=val_dataset,
                                shuffle=False,
                                batch_size=batch_size)
    #------------------------END LOAD VALIDATE DATA--------------

    #------------------------START SAMPLE GENERATE-------------------
    update_count = 0
    print('start validate....')

    for batch in val_dataloader:
        with torch.no_grad():
            batch = [item.to(device) for item in batch]

            encoder_input, decoder_input, mask, _ = batch

            # Encode the context once; `past` carries the cached states the
            # decoder attends over on every step.
            _, past = encoder(encoder_input, mask)

            sentence = []

            # Seed decoding with the first decoder-input token.
            prev_pred = decoder_input[:, :1]
            sentence.append(prev_pred)

            length = 1
            # decoding loop
            for i in range(100):
                # Extend the attention mask by one position for the token
                # about to be generated.
                mask = F.pad(mask, (0, 1), "constant", 1.0)
                logits, past = decoder(prev_pred,
                                       mask,
                                       past=past,
                                       past_length=length)
                # Single-step logits -> temperature-scaled, top-k filtered
                # distribution, then sample one token id.
                logits = logits.squeeze(1) / temperature
                logits = top_k_logits(logits, k=top_k)
                probs = F.softmax(logits, dim=-1)
                prev_pred = torch.multinomial(probs, num_samples=1)
                sentence.append(prev_pred)
                length += 1

            sentence = torch.cat(sentence, dim=-1)

            res = "".join(tokenizer.convert_ids_to_tokens(
                sentence[0].tolist()))
            inputs = "".join(
                tokenizer.convert_ids_to_tokens(encoder_input[0].tolist()))
            target = "".join(
                tokenizer.convert_ids_to_tokens(decoder_input[0].tolist()))

            print('-' * 20 + f'Case {update_count}' + '-' * 20)
            print('-' * 20 + 'Input' + '-' * 20)
            print(inputs)
            print('')

            # Fix: section label typo corrected ("Predcit" -> "Predict").
            print('-' * 20 + 'Predict' + '-' * 20)
            print(res[:100])
            print('')

            print('-' * 20 + 'Target' + '-' * 20)
            print(target)
            print('')

            update_count += 1
            if update_count == show_num:
                break