Example #1
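Both examples assume the following imports, plus module-level helpers from the surrounding project (logger, settings, top_k_top_p_filtering, clear_lines, use_ptoolkit, format_result, output and, for Example #2, entmax_bisect) that are not reproduced here:

import torch
import torch.nn.functional as F
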
def sample_sequence(model,
                    length,
                    context,
                    temperature=1,
                    top_k=0,
                    top_p=0.9,
                    repetition_penalty=1.0,
                    repetition_penalty_range=512,
                    repetition_penalty_slope=3.33,
                    device="cpu",
                    stop_tokens=None,
                    tokenizer=None):
    """Actually generate the tokens"""
    logger.debug(
        'temp: {}    top_k: {}    top_p: {}    rep-pen: {}    rep-pen-range: {}    rep-pen-slope: {}'
        .format(temperature, top_k, top_p, repetition_penalty,
                repetition_penalty_range, repetition_penalty_slope))
    context_tokens = context
    context = torch.tensor(context, dtype=torch.long, device=device)
    # context = context.repeat(num_samples, 1)
    generated = context
    USE_PAST = True
    next_token = context
    pasts = None
    clines = 0

    penalty = None
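    # Pre-compute a penalty ramp over the last `repetition_penalty_range` positions:
    # a soft-sign-shaped curve (steepness set by `repetition_penalty_slope`) rescaled
    # to run from 1.0 for the oldest position in the window up to `repetition_penalty`
    # for the newest, so recently generated tokens are penalized harder than older ones.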
    if (repetition_penalty_range is not None
            and repetition_penalty_slope is not None
            and repetition_penalty_range > 0):
        penalty = (torch.arange(repetition_penalty_range) /
                   (repetition_penalty_range - 1)) * 2. - 1
        penalty = (repetition_penalty_slope *
                   penalty) / (1 + torch.abs(penalty) *
                               (repetition_penalty_slope - 1))
        penalty = 1 + ((penalty + 1) / 2) * (repetition_penalty - 1)

    with torch.no_grad():
        for j in range(length):
            # why would we ever not use past?
            # are `generated` and `next_token` always the same thing?
            if not USE_PAST:
                input_ids_next = generated
                pasts = None
            else:
                input_ids_next = next_token

            # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet/CTRL (cached hidden-states)
            model_kwargs = {"past": pasts, "use_cache": True}
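            # ('past' was renamed to 'past_key_values' in later transformers releases;
            # prepare_inputs_for_generation may expect that key instead.)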
            model_inputs = model.prepare_inputs_for_generation(
                generated.unsqueeze(0), **model_kwargs)
            model_outputs = model(**model_inputs, return_dict=True)
            logits, pasts = model_outputs.logits, model_outputs.past_key_values
            logits = logits[0, -1, :].float()

            # Originally the order was Temperature, Repetition Penalty, then top-k/p
            if settings.getboolean('top-p-first'):
                logits = top_k_top_p_filtering(logits,
                                               top_k=top_k,
                                               top_p=top_p)

            logits = logits / (temperature if temperature > 0 else 1.0)

            # repetition penalty from CTRL (https://arxiv.org/abs/1909.05858) plus range limit
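            # Repeated tokens get their logit divided by the penalty when positive and
            # multiplied by it when negative, so a repeat always becomes less likely.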
            if repetition_penalty != 1.0:
                if penalty is not None:
                    penalty_len = min(generated.shape[0],
                                      repetition_penalty_range)
                    penalty_context = generated[-repetition_penalty_range:]
                    score = torch.gather(logits, 0, penalty_context)
                    penalty = penalty.type(score.dtype).to(score.device)
                    penalty_window = penalty[-penalty_len:]
                    score = torch.where(score < 0, score * penalty_window,
                                        score / penalty_window)
                    logits.scatter_(0, penalty_context, score)
                else:
                    score = torch.gather(logits, 0, generated)
                    score = torch.where(score < 0, score * repetition_penalty,
                                        score / repetition_penalty)
                    logits.scatter_(0, generated, score)

            if not settings.getboolean('top-p-first'):
                logits = top_k_top_p_filtering(logits,
                                               top_k=top_k,
                                               top_p=top_p)

            if temperature == 0:  # greedy sampling:
                next_token = torch.argmax(logits, dim=-1).unsqueeze(-1)
            else:
                next_token = torch.multinomial(F.softmax(logits, dim=-1),
                                               num_samples=1)
            generated = torch.cat((generated, next_token), dim=-1)
            # Decode into plain text
            o = generated[len(context_tokens):].tolist()
            generated.text = tokenizer.decode(
                o,
                clean_up_tokenization_spaces=False,
                skip_special_tokens=True)
            if use_ptoolkit():
                clear_lines(clines)
                generated.text = format_result(generated.text)
                clines = output(generated.text, "ai-text")
            if ((stop_tokens is not None) and (j > 4)
                    and (next_token[0] in stop_tokens)):
                # Require a minimum number of tokens (j > 4) before stopping: the model
                # sometimes starts with whitespace that gets stripped anyway, so this
                # avoids stopping immediately on "\n " or similar.
                logger.debug(
                    "Stopping generation: found one of the stop tokens `%s` in `%s` at step %s",
                    stop_tokens,
                    next_token,
                    j,
                )
                break
    clear_lines(clines)
    return generated
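
Both examples call a top_k_top_p_filtering helper that is defined elsewhere in the project. A minimal sketch of what such a helper typically looks like, modeled on the widely circulated Hugging Face run_generation example and assuming 1-D logits as in the calls above:

import torch
import torch.nn.functional as F


def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")):
    """Filter a 1-D logits tensor using top-k and/or nucleus (top-p) filtering."""
    top_k = min(top_k, logits.size(-1))
    if top_k > 0:
        # Drop every logit smaller than the k-th largest one.
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Drop tokens whose cumulative probability exceeds the threshold...
        sorted_indices_to_remove = cumulative_probs > top_p
        # ...but shift right so the first token above the threshold is kept.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits
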
Example #2
def sample_sequence(
        model,
        length,
        context,
        temperature=1,
        top_k=0,
        top_p=0.9,
        repetition_penalty=1.0,
        device="cpu",
        stop_tokens=None,
        tokenizer=None
):
    """Actually generate the tokens"""
    logger.debug(
        'temp: {}    top_k: {}    top_p: {}    rep-pen: {}'.format(temperature, top_k, top_p, repetition_penalty))
    context_tokens = context
    context = torch.tensor(context, dtype=torch.long, device=device)
    # context = context.repeat(num_samples, 1)
    generated = context
    USE_PAST = True
    next_token = context
    pasts = None
    clines = 0
    with torch.no_grad():
        for j in range(length):
            # why would we ever not use past?
            # are `generated` and `next_token` always the same thing?
            if not USE_PAST:
                input_ids_next = generated
                pasts = None
            else:
                input_ids_next = next_token

            # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet/CTRL (cached hidden-states)
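            # (if `model` is a plain transformers model, this tuple unpacking matches
            # older releases that returned `(logits, past)` by default; newer releases
            # return a ModelOutput object instead)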
            logits, pasts = model(input_ids=input_ids_next, past=pasts)
            logits = logits[-1, :].float()

            # TODO: rewrite this logic
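            # entmax_bisect comes from the third-party `entmax` package; alpha=1.0
            # recovers softmax and larger alpha values give sparser distributions.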
            if settings.getboolean('sparse-gen'):
                probs = entmax_bisect(
                    logits, dim=-1, alpha=settings.getfloat('sparse-level'))
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                # Originally the order was Temperature, Repetition Penalty, then top-k/p
                if settings.getboolean('top-p-first'):
                    logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)

                logits = logits / (temperature if temperature > 0 else 1.0)

                # repetition penalty from CTRL (https://arxiv.org/abs/1909.05858)
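                # Unlike Example #1, this divides negative logits too, which actually
                # nudges already-unlikely repeated tokens slightly closer to zero
                # (i.e. makes them a bit more likely).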
                for k in set(generated.tolist()):
                    logits[k] /= repetition_penalty

                if not settings.getboolean('top-p-first'):
                    logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)

                if temperature == 0:  # greedy sampling:
                    next_token = torch.argmax(logits, dim=-1).unsqueeze(-1)
                else:
                    next_token = torch.multinomial(
                        F.softmax(logits, dim=-1), num_samples=1
                    )
            generated = torch.cat((generated, next_token), dim=-1)
            # Decode into plain text
            o = generated[len(context_tokens):].tolist()
            generated.text = tokenizer.decode(
                o, clean_up_tokenization_spaces=False, skip_special_tokens=True
            )
            if use_ptoolkit():
                clear_lines(clines)
                generated.text = format_result(generated.text)
                clines = output(generated.text, "ai-text")
            if (
                    (stop_tokens is not None)
                    and (j > 4)
                    and (next_token[0] in stop_tokens)
            ):
                # Require a minimum number of tokens (j > 4) before stopping: the model
                # sometimes starts with whitespace that gets stripped anyway, so this
                # avoids stopping immediately on "\n " or similar.
                logger.debug(
                    "Stopping generation: found one of the stop tokens `%s` in `%s` at step %s",
                    stop_tokens,
                    next_token,
                    j,
                )
                break
    clear_lines(clines)
    return generated
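
For reference, a self-contained sketch (only torch required; parameter values chosen purely for illustration) of the ranged repetition-penalty ramp built in Example #1, showing how the slope keeps old tokens in the window essentially unpenalized while recent ones approach the full penalty:

import torch


def penalty_ramp(rep_pen=1.2, rep_pen_range=8, rep_pen_slope=3.33):
    # Linear ramp from -1 (oldest position in the window) to +1 (newest).
    p = (torch.arange(rep_pen_range) / (rep_pen_range - 1)) * 2.0 - 1.0
    # Soft-sign-shaped curve: endpoints stay at -1/+1, the transition is steepest in
    # the middle of the window and flattens toward both ends.
    p = (rep_pen_slope * p) / (1 + torch.abs(p) * (rep_pen_slope - 1))
    # Rescale from [-1, 1] to [1, rep_pen].
    return 1 + ((p + 1) / 2) * (rep_pen - 1)


print(penalty_ramp())
# oldest -> newest, roughly: 1.00, 1.01, 1.03, 1.06, 1.14, 1.17, 1.19, 1.20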