# Shared imports assumed by every snippet in this listing. The special-token
# strings and the `config` values below are assumptions inferred from how they
# are used in the code, not taken verbatim from the original project.
# ChatDataset and make_batch are assumed to be defined elsewhere in the project.
from itertools import chain
from typing import List, Optional, Tuple

import torch
import torch.nn.functional as F
import transformers
from torch.utils.data import DataLoader
from transformers import OpenAIGPTDoubleHeadsModel, OpenAIGPTTokenizer

bos, eos, pad = "<bos>", "<eos>", "<pad>"            # assumed special tokens
speaker_self, speaker_other = "<self>", "<other>"    # assumed speaker tokens
config = {  # assumed structure; only the keys are taken from the code below
    "eval": {"temperature": 0.7, "top_p": 0.9},
    "bot": {"max_token_history": 500, "token_limit": 50},
}


def chance_reply(history: List[Tuple[bool, str]], tokenizer: OpenAIGPTTokenizer,
                 model: OpenAIGPTDoubleHeadsModel, device) -> float:
    """Returns the model's probability that the bot (`speaker_self`) speaks next."""
    model.to(device)

    # build the network inputs: alternate speaker tokens with message tokens
    inputs = [bos]
    token_types = [speaker_other if len(history) > 0 and not history[0][0] else speaker_self]
    for user, text in history:
        inputs.append(speaker_self if user else speaker_other)
        token_types.append(speaker_self if user else speaker_other)
        for token in tokenizer.tokenize(text):
            inputs.append(token)
            token_types.append(speaker_self if user else speaker_other)

    cutoff = 500  # feed the model at most the 500 most recent tokens
    input_ids = tokenizer.convert_tokens_to_ids(inputs)
    token_type_ids = tokenizer.convert_tokens_to_ids(token_types)

    model.eval()

    model_out = model(torch.tensor([input_ids[-cutoff:]], dtype=torch.long).to(device),
                      token_type_ids=torch.tensor([token_type_ids[-cutoff:]], dtype=torch.long).to(device))
    logits = model_out.logits[0, -1, :] / config["eval"]["temperature"]

    logits = filter_logits(logits, tokenizer, True, whitelist=[speaker_self, speaker_other])
    probs = F.softmax(logits, dim=-1)

    speaker_self_token = tokenizer.convert_tokens_to_ids(speaker_self)

    return probs[speaker_self_token].item()
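
# Hedged usage sketch for chance_reply: the "openai-gpt" checkpoint and the
# special-token registration below are illustrative assumptions, not the
# original project's setup.
def demo_chance_reply() -> float:
    tokenizer = OpenAIGPTTokenizer.from_pretrained("openai-gpt")
    tokenizer.add_special_tokens({
        "bos_token": bos, "eos_token": eos, "pad_token": pad,
        "additional_special_tokens": [speaker_self, speaker_other]})
    model = OpenAIGPTDoubleHeadsModel.from_pretrained("openai-gpt")
    model.resize_token_embeddings(len(tokenizer))  # account for added tokens
    history = [(False, "hey, are you around?"), (True, "yeah, what's up?")]
    return chance_reply(history, tokenizer, model, torch.device("cpu"))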
def build_inputs(history: List[Tuple[bool, List[str]]],
                 reply: Optional[Tuple[bool, List[str]]],
                 tokenizer: transformers.OpenAIGPTTokenizer,
                 populate_lm_labels=False,
                 with_eos=True,
                 with_reply=True):
    if with_reply:
        history = history + [reply]
    # prefix each message with its speaker token
    sequence = [[speaker_self if user else speaker_other] + words
                for user, words in history]
    sequence[0] = [bos] + sequence[0]
    if with_eos:
        sequence[-1] = sequence[-1] + [eos]
    words = list(chain(*sequence))
    segments = list(
        chain(*[[speaker_self if user else speaker_other] * len(seq)
                for (user, _), seq in zip(history, sequence)]))
    input_ids = tokenizer.convert_tokens_to_ids(words)
    mc_token_ids = len(input_ids) - 1  # index of the last token, for the multiple-choice head
    token_type_ids = tokenizer.convert_tokens_to_ids(segments)
    lm_labels = [-100] * len(input_ids)  # -100 is ignored by the LM loss
    if populate_lm_labels:
        # train the LM head only on the final (reply) segment
        prefix_len = sum(len(s) for s in sequence[:-1])
        lm_labels = [-100] * prefix_len + tokenizer.convert_tokens_to_ids(sequence[-1])
    return input_ids, mc_token_ids, token_type_ids, lm_labels
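
# Worked sketch of build_inputs (token strings follow the assumed constants
# above): for one incoming message and a reply, the flattened sequence is
#   <bos> <other> ...history tokens... <self> ...reply tokens... <eos>
# token_type_ids repeat each segment's speaker token, mc_token_ids points at
# the final token, and lm_labels mask everything before the reply segment.
def demo_build_inputs(tokenizer: transformers.OpenAIGPTTokenizer):
    history = [(False, tokenizer.tokenize("hi there"))]
    reply = (True, tokenizer.tokenize("hello"))
    return build_inputs(history, reply, tokenizer, populate_lm_labels=True)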
def filter_logits(logits: torch.Tensor, tokenizer: OpenAIGPTTokenizer,
                  use_whitelist: bool, whitelist: Optional[List[str]] = None,
                  blacklist: Optional[List[str]] = None) -> torch.Tensor:
    # mask: 1 where the logit should be set to -inf, 0 where it is kept
    if use_whitelist:
        whitelist = tokenizer.convert_tokens_to_ids(whitelist)
        mask = torch.ones_like(logits)  # build on logits' device/dtype
        mask[whitelist] = 0
    else:
        blacklist = tokenizer.convert_tokens_to_ids(blacklist)
        mask = torch.zeros_like(logits)
        mask[blacklist] = 1

    indices = mask == 1
    logits[indices] = float("-inf")  # note: modifies logits in place
    return logits
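
# Hedged example of filter_logits with a whitelist, mirroring its use in
# chance_reply above; tokenizer setup follows the demo_chance_reply sketch.
def demo_filter_logits(tokenizer: OpenAIGPTTokenizer) -> torch.Tensor:
    logits = torch.randn(len(tokenizer))
    logits = filter_logits(logits, tokenizer, True,
                           whitelist=[speaker_self, speaker_other])
    # all probability mass now falls on the two speaker tokens
    return F.softmax(logits, dim=-1)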
    # Excerpted from a tokenizer test class; self.vocab_file and
    # self.merges_file are fixtures defined on that class.
    def test_full_tokenizer(self):
        tokenizer = OpenAIGPTTokenizer(self.vocab_file, self.merges_file)

        text = "lower"
        bpe_tokens = ["low", "er</w>"]
        tokens = tokenizer.tokenize(text)
        self.assertListEqual(tokens, bpe_tokens)

        input_tokens = tokens + ["<unk>"]
        input_bpe_tokens = [14, 15, 20]
        self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
def get_data_loader(dataset: ChatDataset,
                    tokenizer: transformers.OpenAIGPTTokenizer,
                    batch_size: int = 4,
                    shuffle: bool = True,
                    num_workers: int = 0) -> DataLoader:
    pad_token_id = tokenizer.convert_tokens_to_ids("<pad>")
    loader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=shuffle,
                        num_workers=num_workers,
                        # make_batch (defined elsewhere in the project) pads
                        # variable-length examples into rectangular tensors
                        collate_fn=lambda x: make_batch(x, pad_token_id))
    return loader
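
# generate_from_history below calls top_p_sample, which is not shown in this
# listing. A standard nucleus-sampling filter with the same signature might
# look like this sketch (an assumption, not the original implementation):
def top_p_sample(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    # mask out tokens once the cumulative probability exceeds top_p,
    # always keeping at least the single most likely token
    to_remove = cumulative > top_p
    to_remove[..., 1:] = to_remove[..., :-1].clone()
    to_remove[..., 0] = False
    logits[sorted_indices[to_remove]] = float("-inf")
    return logits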
def generate_from_history(history: List[Tuple[bool, str]], tokenizer: OpenAIGPTTokenizer,
                          model: OpenAIGPTDoubleHeadsModel, device,
                          token_blacklist: Optional[List[str]] = None) -> List[str]:
    """Generates one or more utterances given the messages preceding them.

    :param history: a list of (user, message) tuples, where `user` is a bool
                    selecting the speaker token (`speaker_self` when True,
                    `speaker_other` otherwise) and `message` is the raw text
    :param tokenizer: the tokenizer
    :param model: the model
    :param device: the PyTorch device to run on
    :param token_blacklist: a list of tokens the network must not generate"""

    model.to(device)

    # build the network inputs: alternate speaker tokens with message tokens
    output = []
    inputs = [bos]
    token_types = [speaker_other if len(history) > 0 and not history[0][0] else speaker_self]
    for user, text in history:
        inputs.append(speaker_self if user else speaker_other)
        token_types.append(speaker_self if user else speaker_other)
        for token in tokenizer.tokenize(text):
            inputs.append(token)
            token_types.append(speaker_self if user else speaker_other)
    inputs.append(speaker_self)
    token_types.append(speaker_self)

    input_ids = tokenizer.convert_tokens_to_ids(inputs)
    token_type_ids = tokenizer.convert_tokens_to_ids(token_types)

    model.eval()

    eos_token = tokenizer.convert_tokens_to_ids(eos)
    speaker_self_token = tokenizer.convert_tokens_to_ids(speaker_self)
    speaker_other_token = tokenizer.convert_tokens_to_ids(speaker_other)

    cutoff = config["bot"]["max_token_history"]
    for i in range(config["bot"]["token_limit"]):
        model_out = model(torch.tensor([input_ids[-cutoff:]], dtype=torch.long).to(device),
                          token_type_ids=torch.tensor([token_type_ids[-cutoff:]], dtype=torch.long).to(device))
        logits = model_out.logits[0, -1, :] / config["eval"]["temperature"]
        blacklist = [bos, eos, pad] + token_blacklist
        logits = filter_logits(logits, tokenizer, False, blacklist=blacklist)
        logits = top_p_sample(logits, config["eval"]["top_p"])
        # print("{} -> {}".format(tokenizer.convert_ids_to_tokens(output[-5:]), tokenizer.convert_ids_to_tokens(torch.topk(logits, 5)[1])))
        probs = F.softmax(logits, dim=-1)
        prev = torch.multinomial(probs, 1).item()
        input_ids.append(prev)
        token_type_ids.append(speaker_self_token)
        output.append(prev)
        if prev in (speaker_other_token, eos_token):
            break

    output = tokenizer.convert_ids_to_tokens(output)
    current_msg = []
    messages = []
    for token in output:
        if token in (speaker_self, eos, speaker_other):
            # a speaker or eos token closes the message being built
            messages.append(tokenizer.convert_tokens_to_string(current_msg))
            current_msg = []
        else:
            current_msg.append(token)
    if len(current_msg) > 0:
        messages.append(tokenizer.convert_tokens_to_string(current_msg))
    return messages
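
# End-to-end usage sketch, under the same illustrative assumptions as
# demo_chance_reply above (checkpoint name and special-token registration).
def demo_generate() -> List[str]:
    tokenizer = OpenAIGPTTokenizer.from_pretrained("openai-gpt")
    tokenizer.add_special_tokens({
        "bos_token": bos, "eos_token": eos, "pad_token": pad,
        "additional_special_tokens": [speaker_self, speaker_other]})
    model = OpenAIGPTDoubleHeadsModel.from_pretrained("openai-gpt")
    model.resize_token_embeddings(len(tokenizer))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    history = [(False, "what are you doing this weekend?")]
    return generate_from_history(history, tokenizer, model, device,
                                 token_blacklist=[])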