Example #1
def main() -> None:
    tokenizer = Tokenizer(args.vocab_file)
    vocabulary_size = len(tokenizer)
    dataset = SentenceDataset(args.input_file, tokenizer=tokenizer.encode)
    loader = DataLoader(dataset,
                        args.batch_size,
                        shuffle=False,
                        collate_fn=dataset.collate_fn,
                        drop_last=False)

    searcher = BeamSearch(tokenizer.eos_index, beam_size=args.search_width)

    model = VAE(
        num_embeddings=len(tokenizer),
        dim_embedding=args.dim_embedding,
        dim_hidden=args.dim_hidden,
        dim_latent=args.dim_latent,
        num_layers=args.num_layers,
        bidirectional=args.bidirectional,
        dropout=0.,
        word_dropout=0.,
        dropped_index=tokenizer.unk_index,
    ).to(device)
    model.load_state_dict(torch.load(args.checkpoint_file,
                                     map_location=device))
    model.eval()

    print('Generating sentences...')
    all_hypotheses = []
    with torch.no_grad():
        for s in tqdm(loader):
            s = s.to(device)
            length = torch.sum(s != tokenizer.pad_index, dim=-1)
            bsz = s.shape[0]

            mean, logvar = model.encode(s, length)
            # decode from the posterior mean instead of drawing a sample:
            # z = model.reparameterize(mean, logvar)
            z = mean

            # project the latent code to the decoder's initial hidden state
            hidden = model.fc_hidden(z)
            hidden = hidden.view(bsz, -1,
                                 model.dim_hidden).transpose(0,
                                                             1).contiguous()

            start_predictions = torch.full((bsz,), tokenizer.bos_index,
                                           dtype=torch.long, device=device)
            start_state = {'hidden': hidden.permute(1, 0, 2)}  # batch-first for BeamSearch
            predictions, log_probabilities = searcher.search(
                start_predictions, start_state, model.step)

            for preds in predictions:
                tokens = preds[0]  # beams come back sorted by score; keep the best
                tokens = tokens[tokens != tokenizer.eos_index].tolist()
                all_hypotheses.append(tokenizer.decode(tokens))
    print('Done')

    with open(args.output_file, 'w') as f:
        f.write('\n'.join(all_hypotheses))
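The search() call expects a step callback with the AllenNLP BeamSearch contract: take the tokens predicted at the previous step plus a state dict, and return per-token log-probabilities plus the updated state. A minimal sketch of such a callback, assuming a GRU decoder (the layer names and sizes here are hypothetical, not taken from the example):

from typing import Dict, Tuple

import torch
import torch.nn as nn


class DecoderStep(nn.Module):
    def __init__(self, vocab_size=8000, dim_embedding=256, dim_hidden=512,
                 num_layers=1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, dim_embedding)
        self.gru = nn.GRU(dim_embedding, dim_hidden, num_layers,
                          batch_first=True)
        self.fc_output = nn.Linear(dim_hidden, vocab_size)

    def step(self, last_predictions: torch.Tensor,
             state: Dict[str, torch.Tensor]
             ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        # last_predictions: (group_size,) token ids chosen at the previous step.
        # state['hidden'] is kept batch-first, (group_size, num_layers, dim),
        # because BeamSearch reorders state tensors along dim 0; that is why
        # the caller permutes the RNN state before calling search().
        hidden = state['hidden'].permute(1, 0, 2).contiguous()
        emb = self.embedding(last_predictions).unsqueeze(1)  # (group, 1, dim)
        out, hidden = self.gru(emb, hidden)                  # one decoding step
        log_probs = torch.log_softmax(self.fc_output(out.squeeze(1)), dim=-1)
        return log_probs, {'hidden': hidden.permute(1, 0, 2)}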
Example #2
def main() -> None:
    tokenizer = Tokenizer(args.vocab_file)
    vocabulary_size = len(tokenizer)

    searcher = BeamSearch(tokenizer.eos_index, beam_size=args.search_width)

    model = VAE(
        num_embeddings=len(tokenizer),
        dim_embedding=args.dim_embedding,
        dim_hidden=args.dim_hidden,
        dim_latent=args.dim_latent,
        num_layers=args.num_layers,
        bidirectional=args.bidirectional,
        dropout=0.,
        word_dropout=0.,
        dropped_index=tokenizer.unk_index,
    ).to(device)
    model.load_state_dict(torch.load(args.checkpoint_file,
                                     map_location=device))
    model.eval()

    sentence1 = input('Please input sentence1: ')
    sentence2 = input('Please input sentence2: ')

    s1 = [tokenizer.bos_index] + tokenizer.encode(sentence1) + [tokenizer.eos_index]
    s2 = [tokenizer.bos_index] + tokenizer.encode(sentence2) + [tokenizer.eos_index]

    # use the posterior means of the two sentences as interpolation endpoints
    z1, _ = model.encode(
        torch.tensor([s1]).to(device),
        torch.tensor([len(s1)]).to(device))
    z2, _ = model.encode(
        torch.tensor([s2]).to(device),
        torch.tensor([len(s2)]).to(device))

    print("\nGenerate intermediate sentences")
    print("      %s" % sentence1)
    for r in range(1, 10):
        # linear interpolation between the two latent codes, in steps of 0.1
        z = (1 - 0.1 * r) * z1 + 0.1 * r * z2
        hidden = model.fc_hidden(z)
        hidden = hidden.view(1, -1,
                             model.dim_hidden).transpose(0, 1).contiguous()

        start_predictions = torch.full((1,), tokenizer.bos_index,
                                       dtype=torch.long, device=device)
        start_state = {'hidden': hidden.permute(1, 0, 2)}
        predictions, log_probabilities = searcher.search(
            start_predictions, start_state, model.step)

        tokens = predictions[0, 0]
        tokens = tokens[tokens != tokenizer.eos_index].tolist()
        print("[%d:%d] %s" % (10 - r, r, tokenizer.decode(tokens)))
    print("      %s" % sentence2)
Example #3
def main() -> None:
    tokenizer = Tokenizer(args.vocab_file)
    vocabulary_size = len(tokenizer)

    searcher = BeamSearch(tokenizer.eos_index, beam_size=args.search_width)

    model = VAE(
        num_embeddings=len(tokenizer),
        dim_embedding=args.dim_embedding,
        dim_hidden=args.dim_hidden,
        dim_latent=args.dim_latent,
        num_layers=args.num_layers,
        bidirectional=args.bidirectional,
        dropout=0.,
        word_dropout=0.,
        dropped_index=tokenizer.unk_index,
    ).to(device)
    model.load_state_dict(torch.load(args.checkpoint_file,
                                     map_location=device))
    model.eval()

    # draw latent codes from the standard-normal prior and decode them
    z = torch.randn(args.sample_size, args.dim_latent, device=device)
    hidden = model.fc_hidden(z)
    hidden = hidden.view(args.sample_size, -1,
                         model.dim_hidden).transpose(0, 1).contiguous()

    start_predictions = torch.full((args.sample_size,), tokenizer.bos_index,
                                   dtype=torch.long, device=device)
    start_state = {'hidden': hidden.permute(1, 0, 2)}
    predictions, log_probabilities = searcher.search(start_predictions,
                                                     start_state, model.step)

    for pred in predictions:
        tokens = pred[0]
        tokens = tokens[tokens != tokenizer.eos_index].tolist()
        print(tokenizer.decode(tokens))
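One knob worth exposing when sampling from the prior like this: scaling the standard-normal draws acts as a latent temperature. A hypothetical helper, not part of the example:

import torch


def sample_prior(sample_size: int, dim_latent: int, device: torch.device,
                 temperature: float = 1.0) -> torch.Tensor:
    # temperature < 1.0 concentrates samples near the prior mean, trading
    # diversity for more typical sentences; 1.0 recovers torch.randn as above.
    return temperature * torch.randn(sample_size, dim_latent, device=device)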
Example #4
def collate(data: List[str], tokenizer: Tokenizer, block_size: int) -> Batch:
    ids = tokenizer.encode(data, block_size)
    mask = tokenizer.mask(ids)
    return Batch(ids=ids, attention_mask=mask)


def build_data_iterator(tokenizer,
                        dataset,
                        batch_size,
                        block_size,
                        random_sampler=False) -> DataLoader:
    sampler = (RandomSampler(dataset)
               if random_sampler else SequentialSampler(dataset))
    iterator = DataLoader(
        dataset,
        sampler=sampler,
        batch_size=batch_size,
        collate_fn=lambda data: collate(data, tokenizer, block_size),
    )
    return iterator


if __name__ == "__main__":
    tokenizer = Tokenizer("tokenizer.model")
    with open("corpus.txt", encoding="utf-8") as f:
        dataset = f.readlines()
    iterator = build_data_iterator(tokenizer, dataset, 8, 128)
    batch = next(iter(iterator))
    print(tokenizer.decode(batch.ids[0]))  # decode the first sequence in the batch
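For the snippet to run, Batch and Tokenizer.mask must behave roughly as follows; this is an assumed sketch (the pad id of 0 is a guess), not the actual definitions:

from typing import NamedTuple

import torch


class Batch(NamedTuple):
    ids: torch.Tensor             # (batch_size, block_size) token ids, padded
    attention_mask: torch.Tensor  # 1 where a real token sits, 0 at padding


def mask(ids: torch.Tensor, pad_id: int = 0) -> torch.Tensor:
    # Attention mask derived from the padded id matrix.
    return (ids != pad_id).long()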
Example #5
class Seq2SeqModel(nn.Module):
    """
    模型
    """
    def __init__(self, config: BertConfig):
        super(Seq2SeqModel, self).__init__()
        # read sizes off the BERT config
        self.hidden_dim = config.hidden_size
        self.vocab_size = config.vocab_size

        # encoder and decoder
        self.bert = BertModel(config)
        self.decoder = BertLMPredictionHead(
            config, self.bert.embeddings.word_embeddings.weight)

        # load the vocabulary and the tokenizer
        self.word2ix = load_bert_vocab()
        self.tokenizer = Tokenizer(self.word2ix)

    def compute_loss(self, predictions, labels, target_mask):
        """
        target_mask: 0 over sentence a and padding positions, 1 over sentence b.
        """
        predictions = predictions.view(-1, self.vocab_size)
        labels = labels.view(-1)
        target_mask = target_mask.view(-1).float()
        loss = nn.CrossEntropyLoss(ignore_index=0, reduction="none")
        # the mask removes the influence of padding and of sentence a predictions
        return (loss(predictions, labels) * target_mask).sum() / target_mask.sum()

    def forward(self,
                input_tensor,
                token_type_id,
                position_enc=None,
                labels=None,
                device="cpu"):
        '''
        :param input_tensor: input token ids
        :param token_type_id: segment ids (0 for sentence a, 1 for sentence b)
        :param position_enc: position encoding
        :param labels: target sentence to decode
        :param device:
        :return:
        '''
        input_shape = input_tensor.size()

        seq_len = input_shape[1]
        # build the special seq2seq attention mask
        ones = torch.ones((1, 1, seq_len, seq_len),
                          dtype=torch.float32,
                          device=device)
        a_mask = ones.tril()  # lower-triangular (causal) matrix
        s_ex12 = token_type_id.unsqueeze(1).unsqueeze(2).float()
        s_ex13 = token_type_id.unsqueeze(1).unsqueeze(3).float()
        a_mask = (1.0 - s_ex12) * (1.0 - s_ex13) + s_ex13 * a_mask
        # a_mask: (batch_size, 1, seq_len, seq_len)

        enc_layers, _ = self.bert(input_tensor,
                                  position_ids=position_enc,
                                  token_type_ids=token_type_id,
                                  attention_mask=a_mask,
                                  output_all_encoded_layers=True)
        # _ is the pooled output: (batch_size, hidden_size)

        sequence_out = enc_layers[-1]  # take the last encoder layer
        # sequence_out: (batch_size, seq_len, hidden_size)

        predictions = self.decoder(sequence_out)
        # predictions: (batch_size, seq_len, vocab_size)

        if labels is not None:
            # to compute the loss we need a matching target mask;
            # the prediction at the final [SEP] position is not needed, hence :-1
            predictions = predictions[:, :-1].contiguous()
            # predictions: (batch_size, seq_len - 1, vocab_size)

            target_mask = token_type_id[:, 1:].contiguous()
            loss = self.compute_loss(predictions, labels, target_mask)
            return predictions, loss
            return predictions, loss
        else:
            return predictions

    def generate(self, text, out_max_length=50, beam_size=1, device="cpu"):
        # Generate an output sequence for a single input sentence.
        # The maximum input length is derived from the maximum output length;
        # longer inputs are truncated, which is usually acceptable.
        self.out_max_length = out_max_length
        input_max_length = Config.max_length - out_max_length
        token_ids, token_type_ids = self.tokenizer.encode(
            text, max_length=input_max_length)
        token_ids = torch.tensor(token_ids, device=device).view(1, -1)
        token_type_ids = torch.tensor(token_type_ids,
                                      device=device).view(1, -1)
        output_ids = self.beam_search(token_ids,
                                      token_type_ids,
                                      self.word2ix,
                                      beam_size=beam_size,
                                      device=device)
        # decode the ids back into text
        return self.tokenizer.decode(output_ids)

    def beam_search(self,
                    token_ids,
                    token_type_ids,
                    word2ix,
                    beam_size=1,
                    device="cpu"):
        """
        beam-search操作
        """
        sep_id = word2ix["[SEP]"]
        # 用来保存输出序列
        output_ids = [[]]
        # 用来保存累计得分
        output_scores = torch.zeros(token_ids.shape[0], device=device)
        for step in range(self.out_max_length):

            scores = self.forward(token_ids, token_type_ids, device=device)
            if step == 0:
                # tile the inputs beam_size times
                token_ids = token_ids.view(1, -1).repeat(beam_size, 1)
                token_type_ids = token_type_ids.view(1,
                                                     -1).repeat(beam_size, 1)
            # log-probabilities of the final position: (beam_size, vocab_size)
            logit_score = torch.log_softmax(scores, dim=-1)[:, -1]
            logit_score = output_scores.view(-1, 1) + logit_score  # cumulative scores
            # flatten so a single topk runs over all (beam, token) pairs
            logit_score = logit_score.view(-1)
            hype_score, hype_pos = torch.topk(logit_score, beam_size)
            indice1 = hype_pos // scores.shape[-1]  # beam (row) index
            indice2 = hype_pos % scores.shape[-1]  # token (column) index

            # update the candidate set, filtering out finished sequences
            new_hype_scores = []
            new_hype_ids = []
            next_chars = []  # newly predicted tokens, appended to the inputs next step
            for i_1, i_2, score in zip(indice1, indice2, hype_score):
                i_1 = i_1.item()
                i_2 = i_2.item()
                score = score.item()

                hype_id = output_ids[i_1] + [i_2]  # the full sequence so far

                if i_2 == sep_id:
                    # this candidate just produced [SEP], i.e. it is finished
                    if score == torch.max(hype_score).item():
                        # the finished candidate has the best score; return it
                        return hype_id[:-1]
                    else:
                        # finished, but not the best; drop it and shrink the beam
                        beam_size -= 1
                else:
                    new_hype_ids.append(hype_id)
                    new_hype_scores.append(score)
                    next_chars.append(i_2)  # token to append to this candidate's input

            output_ids = new_hype_ids

            output_scores = torch.tensor(new_hype_scores,
                                         dtype=torch.float32,
                                         device=device)
            # Rebuild the inputs: append each new token to its candidate's previous
            # input and feed the result back into BERT to predict the next token.
            token_ids = token_ids[:len(output_ids)].contiguous()  # drop finished rows
            token_type_ids = token_type_ids[:len(output_ids)].contiguous()

            next_chars = torch.tensor(next_chars,
                                      dtype=torch.long,
                                      device=device).view(-1, 1)
            next_token_type_ids = torch.ones_like(next_chars, device=device)
            # concatenate the new tokens onto the running inputs
            token_ids = torch.cat((token_ids, next_chars), dim=1)
            token_type_ids = torch.cat((token_type_ids, next_token_type_ids),
                                       dim=1)
            if beam_size < 1:
                break

        # max length reached: return the highest-scoring sequence
        return output_ids[output_scores.argmax().item()]
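A toy check of the seq2seq mask built in forward(), runnable on its own. Tokens of sentence a (token_type 0) attend to all of a and none of b, while tokens of sentence b attend to all of a plus the causal prefix of b; this UniLM-style masking is what lets a single BERT act as both encoder and decoder:

import torch

token_type_id = torch.tensor([[0, 0, 0, 1, 1]])  # 3 tokens of a, 2 of b
seq_len = token_type_id.shape[1]
ones = torch.ones((1, 1, seq_len, seq_len))
a_mask = ones.tril()
s_ex12 = token_type_id.unsqueeze(1).unsqueeze(2).float()
s_ex13 = token_type_id.unsqueeze(1).unsqueeze(3).float()
a_mask = (1.0 - s_ex12) * (1.0 - s_ex13) + s_ex13 * a_mask
print(a_mask[0, 0])
# tensor([[1., 1., 1., 0., 0.],
#         [1., 1., 1., 0., 0.],
#         [1., 1., 1., 0., 0.],
#         [1., 1., 1., 1., 0.],
#         [1., 1., 1., 1., 1.]])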
Example #6
def main(checkpoint, spm_path, outf, n_words, bptt, seed, use_cuda,
         temperature):
    if seed:
        torch.manual_seed(seed)

    if torch.cuda.is_available():
        if not use_cuda:
            print(
                'WARNING: You have a CUDA device, so you should probably run with --cuda'
            )
    device = torch.device('cuda' if use_cuda else 'cpu')

    if temperature < 1e-3:
        parser.error('--temperature has to be greater than or equal to 1e-3')

    tokenizer = Tokenizer('models/sp_8000.model')

    with open(checkpoint, 'rb') as f:
        model = torch.load(f).to(device)
    model.eval()

    model_type = model.model_type if hasattr(model, 'model_type') else None

    if model_type == 'LSTMTransformer':
        hidden = model.init_hidden(1)
        mems = None
    elif model_type == 'Transformer':
        pass
    else:
        hidden = model.init_hidden(1)

    input = torch.tensor([[1]], dtype=torch.long).to(device)  # 1: BOS id in the SentencePiece model

    s = []
    with torch.no_grad():  # no tracking history
        for i in range(n_words):
            if model_type == 'LSTMTransformer':
                output, hidden, mems = model(input, hidden, mems)
                word_weights = output[-1].squeeze().div(
                    temperature).exp().cpu()
                word_idx = torch.multinomial(word_weights, 1)[0]
                input.fill_(word_idx)
            elif model_type == 'Transformer':
                output = model(input, False)
                word_weights = output[-1].squeeze().div(
                    temperature).exp().cpu()
                word_idx = torch.multinomial(word_weights, 1)[0]
                word_tensor = torch.Tensor([[word_idx]]).long().to(device)
                input = torch.cat([input, word_tensor], 0)[-bptt:]
            else:
                output, hidden = model(input, hidden)
                word_weights = output.squeeze().div(temperature).exp().cpu()
                word_idx = torch.multinomial(word_weights, 1)[0]
                input.fill_(word_idx)

            s.append(int(word_idx))
            if word_idx == 2:  # 2: EOS id, stop generating
                break

    txt = tokenizer.decode(s)
    with open(outf, 'w') as f:
        f.write(txt)

    print(txt)
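Why .div(temperature).exp() before torch.multinomial: exp(logits / T) is proportional to softmax(logits / T), and multinomial does not require normalized weights, so this draws from the temperature-scaled distribution. The same step in isolation (the random logits are a stand-in for the model output):

import torch

vocab_size = 8000                  # matches the sp_8000.model vocabulary
logits = torch.randn(vocab_size)   # stand-in for one step of model output
temperature = 0.8                  # < 1.0 sharpens, > 1.0 flattens
weights = logits.div(temperature).exp()
word_idx = torch.multinomial(weights, num_samples=1).item()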
Example #7
    if args.top_p:  # nucleus (top-p) sampling
        outputs = model.generate(
            input_ids,
            do_sample=True,
            max_length=args.length,
            top_p=args.top_p,
            top_k=0,
            no_repeat_ngram_size=args.no_repeat_ngram_size,
            num_return_sequences=args.num_return_sequences)
    elif args.temperature:  # plain temperature sampling
        outputs = model.generate(
            input_ids,
            do_sample=True,
            max_length=args.length,
            top_k=0,
            temperature=args.temperature,
            no_repeat_ngram_size=args.no_repeat_ngram_size,
            num_return_sequences=args.num_return_sequences)
    else:  # deterministic beam search
        outputs = model.generate(
            input_ids,
            max_length=args.length,
            num_beams=args.num_beams,
            early_stopping=True,
            no_repeat_ngram_size=args.no_repeat_ngram_size,
            num_return_sequences=args.num_return_sequences)

    for i, sample_output in enumerate(outputs):
        print("{}: {}".format(
            i,
            tokenizer.decode(sample_output).split('<EOS>')[0]))
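For context, a minimal self-contained driver for the nucleus-sampling branch above, using a hypothetical gpt2 checkpoint (the actual model, tokenizer, and args of this example are not shown in the snippet):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
model.eval()

input_ids = tokenizer.encode("The meaning of life is", return_tensors="pt")
with torch.no_grad():
    outputs = model.generate(input_ids,
                             do_sample=True,
                             max_length=50,
                             top_p=0.9,
                             top_k=0,
                             no_repeat_ngram_size=2,
                             num_return_sequences=3)
for i, out in enumerate(outputs):
    print(i, tokenizer.decode(out, skip_special_tokens=True))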