Example #1
import torch
import torch.nn.functional as F
from typing import List, Optional, Tuple

from transformers import OpenAIGPTDoubleHeadsModel, OpenAIGPTTokenizer

# bos, eos, pad, speaker_self, speaker_other, config, filter_logits and
# top_p_sample are project-level definitions assumed to be in scope.


def chance_reply(history: List[Tuple[bool, str]], tokenizer: OpenAIGPTTokenizer,
                 model: OpenAIGPTDoubleHeadsModel, device) -> float:
    """Returns the probability that speaker_self (the bot) is the next speaker."""
    model.to(device)

    # build the network inputs
    inputs = [bos]
    # the bos token's type matches whoever speaks first (default: speaker_self)
    token_types = [speaker_other if len(history) > 0 and not history[0][0] else speaker_self]
    for user, text in history:
        inputs.append(speaker_self if user else speaker_other)
        token_types.append(speaker_self if user else speaker_other)
        for token in tokenizer.tokenize(text):
            inputs.append(token)
            token_types.append(speaker_self if user else speaker_other)

    cutoff = 500  # feed only the most recent 500 tokens to the model
    input_ids = tokenizer.convert_tokens_to_ids(inputs)
    token_type_ids = tokenizer.convert_tokens_to_ids(token_types)

    model.eval()

    # forward pass over the most recent window; take the next-token logits
    # and apply the configured temperature
    model_out = model(torch.tensor([input_ids[-cutoff:]], dtype=torch.long).to(device),
                      token_type_ids=torch.tensor([token_type_ids[-cutoff:]], dtype=torch.long).to(device))
    logits = model_out.logits[0, -1, :] / config["eval"]["temperature"]

    # restrict the distribution to the two speaker tokens: who speaks next?
    logits = filter_logits(logits, tokenizer, True, whitelist=[speaker_self, speaker_other])
    probs = F.softmax(logits, dim=-1)

    speaker_self_token = tokenizer.convert_tokens_to_ids(speaker_self)

    return probs[speaker_self_token].item()
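
A hypothetical usage sketch (not from the project): chance_reply returns a probability, so a caller might threshold it to decide whether to generate a reply with generate_from_history (Example #3). The 0.5 threshold and the send() helper are assumptions for illustration.

# history, tokenizer, model and device come from the surrounding application.
p = chance_reply(history, tokenizer, model, device)
if p > 0.5:  # illustrative threshold, not project config
    for msg in generate_from_history(history, tokenizer, model, device):
        send(msg)  # send() is a placeholder for the bot's transport layer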
Example #2
    # Excerpted from a unittest.TestCase; self.vocab_file and self.merges_file
    # are paths to toy BPE fixture files (see the sketch after this test).
    def test_full_tokenizer(self):
        tokenizer = OpenAIGPTTokenizer(self.vocab_file, self.merges_file)

        text = "lower"
        bpe_tokens = ["low", "er</w>"]
        tokens = tokenizer.tokenize(text)
        self.assertListEqual(tokens, bpe_tokens)

        input_tokens = tokens + ["<unk>"]
        input_bpe_tokens = [14, 15, 20]
        self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
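
A minimal fixture sketch for the test above, assuming the toy vocab/merges layout used by the transformers tokenizer tests; the exact entries here are an assumption, chosen so that "lower" splits into ["low", "er</w>"] and the asserted ids 14, 15 and 20 line up.

import json
import os
import tempfile

# Assumed toy BPE files: "low" must land at id 14, "er</w>" at 15, "<unk>" at 20.
vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
         "w</w>", "r</w>", "t</w>", "lo", "low", "er</w>",
         "low</w>", "lowest</w>", "newer</w>", "wider</w>", "<unk>"]
merges = ["#version: 0.2", "l o", "lo w", "e r</w>", ""]

tmpdir = tempfile.mkdtemp()
vocab_file = os.path.join(tmpdir, "vocab.json")
merges_file = os.path.join(tmpdir, "merges.txt")
with open(vocab_file, "w") as f:
    json.dump({tok: i for i, tok in enumerate(vocab)}, f)
with open(merges_file, "w") as f:
    f.write("\n".join(merges))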
Example #3
def generate_from_history(history: List[Tuple[bool, str]], tokenizer: OpenAIGPTTokenizer,
                          model: OpenAIGPTDoubleHeadsModel, device,
                          token_blacklist: Optional[List[str]] = None) -> List[str]:
    """Generates one or more utterances given the messages preceding them.

    :param history: a list of (user, message) tuples, where user is a boolean
        indicating whether the sender is the user and message is the text
    :param tokenizer: the tokenizer
    :param model: the model
    :param device: the PyTorch device to run on
    :param token_blacklist: a list of tokens the network must not generate
    :return: the generated messages, one string per utterance
    """

    model.to(device)

    # build the network inputs
    output = []
    inputs = [bos]
    token_types = [speaker_other if len(history) > 0 and not history[0][0] else speaker_self]
    for user, text in history:
        inputs.append(speaker_self if user else speaker_other)
        token_types.append(speaker_self if user else speaker_other)
        for token in tokenizer.tokenize(text):
            inputs.append(token)
            token_types.append(speaker_self if user else speaker_other)
    # prompt the model to start the bot's turn
    inputs.append(speaker_self)
    token_types.append(speaker_self)

    input_ids = tokenizer.convert_tokens_to_ids(inputs)
    token_type_ids = tokenizer.convert_tokens_to_ids(token_types)

    model.eval()

    eos_token = tokenizer.convert_tokens_to_ids(eos)
    speaker_self_token = tokenizer.convert_tokens_to_ids(speaker_self)
    speaker_other_token = tokenizer.convert_tokens_to_ids(speaker_other)

    cutoff = config["bot"]["max_token_history"]
    # autoregressive sampling loop, capped at the configured token limit
    for _ in range(config["bot"]["token_limit"]):
        model_out = model(torch.tensor([input_ids[-cutoff:]], dtype=torch.long).to(device),
                          token_type_ids=torch.tensor([token_type_ids[-cutoff:]], dtype=torch.long).to(device))
        logits = model_out.logits[0, -1, :] / config["eval"]["temperature"]
        blacklist = [bos, eos, pad] + (token_blacklist or [])  # token_blacklist may be None
        logits = filter_logits(logits, tokenizer, False, blacklist=blacklist)
        logits = top_p_sample(logits, config["eval"]["top_p"])
        # print("{} -> {}".format(tokenizer.convert_ids_to_tokens(output[-5:]), tokenizer.convert_ids_to_tokens(torch.topk(logits, 5)[1])))
        probs = F.softmax(logits, dim=-1)
        prev = torch.multinomial(probs, 1).item()
        input_ids.append(prev)
        token_type_ids.append(speaker_self_token)
        output.append(prev)
        # stop once the model hands the turn over (eos is blacklisted above,
        # so in practice the speaker_other token is what ends generation)
        if prev in (speaker_other_token, eos_token):
            break

    output = tokenizer.convert_ids_to_tokens(output)
    # split the generated tokens into separate messages at speaker/eos markers
    current_msg = []
    messages = []
    for i in output:
        if i in (speaker_self, eos, speaker_other):
            messages.append(tokenizer.convert_tokens_to_string(current_msg))
            current_msg = []
        else:
            current_msg.append(i)
    if len(current_msg) > 0:
        messages.append(tokenizer.convert_tokens_to_string(current_msg))
    return messages
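
filter_logits and top_p_sample are project helpers not shown in this excerpt. Below is a minimal sketch consistent with both call sites above, assuming they operate on a 1-D logits vector; the project's actual implementations may differ.

def filter_logits(logits: torch.Tensor, tokenizer: OpenAIGPTTokenizer,
                  use_whitelist: bool, blacklist: Optional[List[str]] = None,
                  whitelist: Optional[List[str]] = None) -> torch.Tensor:
    """Masks disallowed tokens to -inf over a 1-D logits vector (sketch)."""
    out = torch.full_like(logits, float("-inf")) if use_whitelist else logits.clone()
    if use_whitelist:
        keep = tokenizer.convert_tokens_to_ids(whitelist or [])
        out[keep] = logits[keep]  # only whitelisted tokens stay sampleable
    else:
        bad = tokenizer.convert_tokens_to_ids(blacklist or [])
        out[bad] = float("-inf")  # blacklisted tokens can never be sampled
    return out


def top_p_sample(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    """Standard nucleus filtering (sketch): keep the smallest set of tokens
    whose cumulative probability exceeds top_p; mask the rest to -inf."""
    sorted_logits, sorted_idx = torch.sort(logits, descending=True)
    cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    to_remove = cum_probs > top_p
    to_remove[1:] = to_remove[:-1].clone()  # shift right: keep the token that crosses top_p
    to_remove[0] = False
    out = logits.clone()
    out[sorted_idx[to_remove]] = float("-inf")
    return out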