def sample_sequence(
        model,
        length,
        context,
        temperature=1,
        top_k=0,
        top_p=0.9,
        repetition_penalty=1.0,
        repetition_penalty_range=512,
        repetition_penalty_slope=3.33,
        device="cpu",
        stop_tokens=None,
        tokenizer=None):
    """Actually generate the tokens"""
    logger.debug(
        'temp: {} top_k: {} top_p: {} rep-pen: {} rep-pen-range: {} rep-pen-slope: {}'.format(
            temperature, top_k, top_p, repetition_penalty,
            repetition_penalty_range, repetition_penalty_slope))
    context_tokens = context
    context = torch.tensor(context, dtype=torch.long, device=device)
    # context = context.repeat(num_samples, 1)
    generated = context
    USE_PAST = True
    next_token = context
    pasts = None
    clines = 0

    # Precompute the slope-shaped penalty window: a multiplier per position that
    # ramps from ~1.0 (oldest token in the window) up to repetition_penalty (newest).
    penalty = None
    if repetition_penalty_range is not None and repetition_penalty_slope is not None and repetition_penalty_range > 0:
        penalty = (torch.arange(repetition_penalty_range) / (repetition_penalty_range - 1)) * 2. - 1
        penalty = (repetition_penalty_slope * penalty) / (1 + torch.abs(penalty) * (repetition_penalty_slope - 1))
        penalty = 1 + ((penalty + 1) / 2) * (repetition_penalty - 1)

    with torch.no_grad():
        for j in range(length):
            # why would we ever not use past?
            # are generated and next_token always the same thing?
            if not USE_PAST:
                input_ids_next = generated
                pasts = None
            else:
                input_ids_next = next_token

            # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet/CTRL (cached hidden-states)
            model_kwargs = {"past": pasts, "use_cache": True}
            model_inputs = model.prepare_inputs_for_generation(
                generated.unsqueeze(0), **model_kwargs)
            model_outputs = model(**model_inputs, return_dict=True)
            logits, pasts = model_outputs.logits, model_outputs.past_key_values
            logits = logits[0, -1, :].float()

            # Originally the order was Temperature, Repetition Penalty, then top-k/p
            if settings.getboolean('top-p-first'):
                logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)

            logits = logits / (temperature if temperature > 0 else 1.0)

            # repetition penalty from CTRL (https://arxiv.org/abs/1909.05858), plus range limit
            if repetition_penalty != 1.0:
                if penalty is not None:
                    penalty_len = min(generated.shape[0], repetition_penalty_range)
                    penalty_context = generated[-repetition_penalty_range:]
                    score = torch.gather(logits, 0, penalty_context)
                    penalty = penalty.type(score.dtype).to(score.device)
                    penalty_window = penalty[-penalty_len:]
                    score = torch.where(score < 0, score * penalty_window, score / penalty_window)
                    logits.scatter_(0, penalty_context, score)
                else:
                    score = torch.gather(logits, 0, generated)
                    score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
                    logits.scatter_(0, generated, score)

            if not settings.getboolean('top-p-first'):
                logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)

            if temperature == 0:  # greedy sampling
                next_token = torch.argmax(logits, dim=-1).unsqueeze(-1)
            else:
                next_token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
            generated = torch.cat((generated, next_token), dim=-1)

            # Decode into plain text
            o = generated[len(context_tokens):].tolist()
            generated.text = tokenizer.decode(
                o, clean_up_tokenization_spaces=False, skip_special_tokens=True)
            if use_ptoolkit():
                clear_lines(clines)
                generated.text = format_result(generated.text)
                clines = output(generated.text, "ai-text")
            if (
                (stop_tokens is not None)
                and (j > 4)
                and (next_token[0] in stop_tokens)
            ):
                # Why require a minimum number of tokens (j > 4)? Sometimes the model starts with
                # whitespace, which gets stripped away anyway. Requiring a few tokens before stopping
                # means we don't stop just because of "\n " or similar.
                logger.debug(
                    "Stopping generation: found one of the stop tokens `%s` in `%s` at step %s",
                    stop_tokens,
                    next_token,
                    j,
                )
                break
    clear_lines(clines)
    return generated
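

# Illustrative sketch (not part of the original module): the range/slope repetition
# penalty above builds a per-position multiplier that ramps from roughly 1.0 for the
# oldest token in the window up to the full `repetition_penalty` for the most recent
# token. The helper below (hypothetical name `_demo_repetition_penalty_ramp`) simply
# re-runs that arithmetic on a tiny window so the shape is easy to inspect.
def _demo_repetition_penalty_ramp(rep_range=8, slope=3.33, rep_pen=1.2):
    """Print the penalty multipliers the range/slope logic produces for a small window."""
    import torch  # local import so the sketch stays self-contained

    # Linear ramp from -1 (oldest position in the window) to +1 (newest position).
    penalty = (torch.arange(rep_range) / (rep_range - 1)) * 2.0 - 1
    # Sigmoid-like squash controlled by the slope: a steeper slope gives a sharper transition.
    penalty = (slope * penalty) / (1 + torch.abs(penalty) * (slope - 1))
    # Map from [-1, 1] into [1, rep_pen].
    penalty = 1 + ((penalty + 1) / 2) * (rep_pen - 1)
    print(penalty)  # oldest token ~1.0 (barely penalised), newest ~rep_pen (fully penalised)
    return penalty

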
def sample_sequence(
        model,
        length,
        context,
        temperature=1,
        top_k=0,
        top_p=0.9,
        repetition_penalty=1.0,
        device="cpu",
        stop_tokens=None,
        tokenizer=None):
    """Actually generate the tokens"""
    logger.debug(
        'temp: {} top_k: {} top_p: {} rep-pen: {}'.format(
            temperature, top_k, top_p, repetition_penalty))
    context_tokens = context
    context = torch.tensor(context, dtype=torch.long, device=device)
    # context = context.repeat(num_samples, 1)
    generated = context
    USE_PAST = True
    next_token = context
    pasts = None
    clines = 0
    with torch.no_grad():
        for j in range(length):
            # why would we ever not use past?
            # are generated and next_token always the same thing?
            if not USE_PAST:
                input_ids_next = generated
                pasts = None
            else:
                input_ids_next = next_token

            # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet/CTRL (cached hidden-states)
            logits, pasts = model(input_ids=input_ids_next, past=pasts)
            logits = logits[-1, :].float()

            # TODO: rewrite this logic
            if settings.getboolean('sparse-gen'):
                probs = entmax_bisect(logits, dim=-1, alpha=settings.getfloat('sparse-level'))
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                # Originally the order was Temperature, Repetition Penalty, then top-k/p
                if settings.getboolean('top-p-first'):
                    logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)

                logits = logits / (temperature if temperature > 0 else 1.0)

                # repetition penalty from CTRL (https://arxiv.org/abs/1909.05858)
                for k in set(generated.tolist()):
                    logits[k] /= repetition_penalty

                if not settings.getboolean('top-p-first'):
                    logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)

                if temperature == 0:  # greedy sampling
                    next_token = torch.argmax(logits, dim=-1).unsqueeze(-1)
                else:
                    next_token = torch.multinomial(
                        F.softmax(logits, dim=-1), num_samples=1
                    )
            generated = torch.cat((generated, next_token), dim=-1)

            # Decode into plain text
            o = generated[len(context_tokens):].tolist()
            generated.text = tokenizer.decode(
                o, clean_up_tokenization_spaces=False, skip_special_tokens=True
            )
            if use_ptoolkit():
                clear_lines(clines)
                generated.text = format_result(generated.text)
                clines = output(generated.text, "ai-text")
            if (
                (stop_tokens is not None)
                and (j > 4)
                and (next_token[0] in stop_tokens)
            ):
                # Why require a minimum number of tokens (j > 4)? Sometimes the model starts with
                # whitespace, which gets stripped away anyway. Requiring a few tokens before stopping
                # means we don't stop just because of "\n " or similar.
                logger.debug(
                    "Stopping generation: found one of the stop tokens `%s` in `%s` at step %s",
                    stop_tokens,
                    next_token,
                    j,
                )
                break
    clear_lines(clines)
    return generated
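

# Both versions of sample_sequence above rely on a `top_k_top_p_filtering` helper that is
# defined elsewhere in the project. For reference, the sketch below (hypothetical name
# `_top_k_top_p_filtering_sketch`) follows the widely used reference implementation of
# top-k / nucleus (top-p) filtering; the project's own helper may differ in details.
def _top_k_top_p_filtering_sketch(logits, top_k=0, top_p=0.0, filter_value=-float("inf")):
    """Mask a 1-D logits vector so only the top-k / top-p tokens remain sampleable."""
    import torch
    import torch.nn.functional as F  # local imports so the sketch stays self-contained

    top_k = min(top_k, logits.size(-1))  # safety: never ask for more tokens than exist
    if top_k > 0:
        # Drop every token whose logit is below the k-th largest logit.
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value
    if top_p > 0.0:
        # Keep the smallest set of tokens whose cumulative probability exceeds top_p.
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift right so the first token that crosses the threshold is also kept.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits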