Example #1
def _clear_grid(grid_size: int) -> None:
    """
    Clear the grid from the console output by clearing lines equal to the number of rows in the
    grid plus row separators.

    :param int grid_size: number of rows in the grid to clear
    :return: None
    """
    utils.clear_lines(2 * grid_size - 1)
Example #2
def _play(grid: List[List[str]], time_limit: int = DEFAULT_TIME_LIMIT) -> List[str]:
    """
    Play a single-player game of Boggle with the given grid of letters. The player has time_limit
    seconds to find as many words in the grid as possible and enter them in the prompt.

    :param List[List[str]] grid: grid of letters to find words in
    :param int time_limit: time limit (in seconds) for player entries
    :return: list of all player entries within the time limit.
    :rtype: List[str]
    """
    # Conceal the grid until the player is ready
    _render_concealed_grid(grid)
    print(
        "You will have {} seconds to find as many words as you can. Ready?".format(
            time_limit
        )
    )
    input()
    utils.clear_lines(2)
    _clear_grid(len(grid))
    _render_grid(grid)
    return _prompt_player(time_limit)
Example #3
def _prompt_player(time_limit: int = DEFAULT_TIME_LIMIT) -> List[str]:
    """
    :param int time_limit: time limit (in seconds) for player entries
    :return: list of all player entries within the time limit
    :rtype: List[str]
    """
    print(
        "Enter as many words as you can find in the next {} seconds".format(time_limit)
    )
    timer = threading.Timer(time_limit, _thread.interrupt_main)
    player_entries = []
    try:
        timer.start()
        while True:
            # Capitalize all inputs
            player_entries.append(input().upper())
            # Clear user input to avoid pushing grid out of view
            utils.clear_lines(1)
    except KeyboardInterrupt:
        pass
    timer.cancel()
    print("Time's up!")
    return player_entries
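
The time limit in `_prompt_player` works by arming a `threading.Timer` with `_thread.interrupt_main`, which raises KeyboardInterrupt in the main thread and breaks out of the blocking `input()` call. A minimal, self-contained sketch of the same pattern (the `timed_input` name is illustrative, not from the original code):

import _thread
import threading

def timed_input(seconds):
    # Arm a timer that raises KeyboardInterrupt in the main thread when it fires.
    timer = threading.Timer(seconds, _thread.interrupt_main)
    entries = []
    try:
        timer.start()
        while True:
            # input() blocks until a line is entered or KeyboardInterrupt arrives.
            entries.append(input())
    except KeyboardInterrupt:
        pass
    finally:
        # Cancel the timer so it cannot fire later if the loop was
        # interrupted early (e.g. by a real Ctrl+C).
        timer.cancel()
    return entries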
Example #4
def sample_sequence(model,
                    length,
                    context,
                    temperature=1,
                    top_k=0,
                    top_p=0.9,
                    repetition_penalty=1.0,
                    repetition_penalty_range=512,
                    repetition_penalty_slope=3.33,
                    device="cpu",
                    stop_tokens=None,
                    tokenizer=None):
    """Actually generate the tokens"""
    logger.debug(
        'temp: {}    top_k: {}    top_p: {}    rep-pen: {}    rep-pen-range: {}    rep-pen-slope: {}'
        .format(temperature, top_k, top_p, repetition_penalty,
                repetition_penalty_range, repetition_penalty_slope))
    context_tokens = context
    context = torch.tensor(context, dtype=torch.long, device=device)
    # context = context.repeat(num_samples, 1)
    generated = context
    USE_PAST = True
    next_token = context
    pasts = None
    clines = 0

    penalty = None
    if (repetition_penalty_range is not None and repetition_penalty_slope is not None
            and repetition_penalty_range > 0):
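        # Build a per-position penalty ramp over the most recent
        # `repetition_penalty_range` tokens:
        #   1. evenly spaced positions in [-1, 1], oldest token to newest
        #   2. an algebraic sigmoid: `repetition_penalty_slope` sets how sharply
        #      the ramp rises around the middle of the window
        #   3. rescale [-1, 1] to [1, repetition_penalty], so the oldest tokens
        #      are barely penalized while the newest receive the full penalty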
        penalty = (torch.arange(repetition_penalty_range) /
                   (repetition_penalty_range - 1)) * 2. - 1
        penalty = (repetition_penalty_slope *
                   penalty) / (1 + torch.abs(penalty) *
                               (repetition_penalty_slope - 1))
        penalty = 1 + ((penalty + 1) / 2) * (repetition_penalty - 1)

    with torch.no_grad():
        for j in range(length):
            # why would we ever not use past?
            # is generated and next_token always same thing?
            if not USE_PAST:
                input_ids_next = generated
                pasts = None
            else:
                input_ids_next = next_token

            # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet/CTRL (cached hidden-states)
            model_kwargs = {"past": pasts, "use_cache": True}
            model_inputs = model.prepare_inputs_for_generation(
                generated.unsqueeze(0), **model_kwargs)
            model_outputs = model(**model_inputs, return_dict=True)
            logits, pasts = model_outputs.logits, model_outputs.past_key_values
            logits = logits[0, -1, :].float()

            # Originally the order was Temperature, Repetition Penalty, then top-k/p
            if settings.getboolean('top-p-first'):
                logits = top_k_top_p_filtering(logits,
                                               top_k=top_k,
                                               top_p=top_p)

            logits = logits / (temperature if temperature > 0 else 1.0)

            # repetition penalty from CTRL (https://arxiv.org/abs/1909.05858) plus range limit
            if repetition_penalty != 1.0:
                if penalty is not None:
                    penalty_len = min(generated.shape[0],
                                      repetition_penalty_range)
                    penalty_context = generated[-repetition_penalty_range:]
                    score = torch.gather(logits, 0, penalty_context)
                    penalty = penalty.type(score.dtype).to(score.device)
                    penalty_window = penalty[-penalty_len:]
                    score = torch.where(score < 0, score * penalty_window,
                                        score / penalty_window)
                    logits.scatter_(0, penalty_context, score)
                else:
                    score = torch.gather(logits, 0, generated)
                    score = torch.where(score < 0, score * repetition_penalty,
                                        score / repetition_penalty)
                    logits.scatter_(0, generated, score)

            if not settings.getboolean('top-p-first'):
                logits = top_k_top_p_filtering(logits,
                                               top_k=top_k,
                                               top_p=top_p)

            if temperature == 0:  # greedy sampling:
                next_token = torch.argmax(logits, dim=-1).unsqueeze(-1)
            else:
                next_token = torch.multinomial(F.softmax(logits, dim=-1),
                                               num_samples=1)
            generated = torch.cat((generated, next_token), dim=-1)
            # Decode into plain text
            o = generated[len(context_tokens):].tolist()
            generated.text = tokenizer.decode(
                o,
                clean_up_tokenization_spaces=False,
                skip_special_tokens=True)
            if use_ptoolkit():
                clear_lines(clines)
                generated.text = format_result(generated.text)
                clines = output(generated.text, "ai-text")
            if ((stop_tokens is not None) and (j > 4)
                    and (next_token[0] in stop_tokens)):
                # Why require a minimum number of tokens (j > 4)? Because the model
                # sometimes starts with whitespace, which gets stripped away anyway.
                # Requiring a few tokens before stopping means we don't stop just
                # because of "\n " or similar.
                logger.debug(
                    "Stopping generation on stop token: one of `%s` matched `%s` at token %s",
                    stop_tokens,
                    next_token,
                    j,
                )
                break
    clear_lines(clines)
    return generated
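
`top_k_top_p_filtering` is called above but not defined in these examples. A minimal sketch of what it presumably looks like for 1-D logits, following the widely circulated HuggingFace sampling example (the `filter_value` default is an assumption):

import torch
import torch.nn.functional as F

def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('inf')):
    if top_k > 0:
        # Keep only the top_k highest-scoring tokens.
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value
    if top_p > 0.0:
        # Keep the smallest set of tokens whose cumulative probability exceeds top_p.
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift right so the first token over the threshold is still kept.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits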
Example #5
def sample_sequence(
        model,
        length,
        context,
        temperature=1,
        top_k=0,
        top_p=0.9,
        repetition_penalty=1.0,
        device="cpu",
        stop_tokens=None,
        tokenizer=None
):
    """Actually generate the tokens"""
    logger.debug(
        'temp: {}    top_k: {}    top_p: {}    rep-pen: {}'.format(temperature, top_k, top_p, repetition_penalty))
    context_tokens = context
    context = torch.tensor(context, dtype=torch.long, device=device)
    # context = context.repeat(num_samples, 1)
    generated = context
    USE_PAST = True
    next_token = context
    pasts = None
    clines = 0
    with torch.no_grad():
        for j in range(length):
            # why would we ever not use past?
            # is generated and next_token always same thing?
            if not USE_PAST:
                input_ids_next = generated
                pasts = None
            else:
                input_ids_next = next_token

            # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet/CTRL (cached hidden-states)
            logits, pasts = model(input_ids=input_ids_next, past=pasts)
            logits = logits[-1, :].float()

            # TODO: rewrite this logic
            if settings.getboolean('sparse-gen'):
                probs = entmax_bisect(logits, dim=-1, alpha=settings.getfloat('sparse-level'))
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                # Originally the order was Temperature, Repetition Penalty, then top-k/p
                if settings.getboolean('top-p-first'):
                    logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)

                logits = logits / (temperature if temperature > 0 else 1.0)

                # repetition penalty from CTRL (https://arxiv.org/abs/1909.05858)
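                # Note: unconditional division boosts tokens with negative logits
                # (dividing moves them toward zero), which is why the previous
                # example multiplies negative scores and divides positive ones.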
                for k in set(generated.tolist()):
                    logits[k] /= repetition_penalty

                if not settings.getboolean('top-p-first'):
                    logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)

                if temperature == 0:  # greedy sampling:
                    next_token = torch.argmax(logits, dim=-1).unsqueeze(-1)
                else:
                    next_token = torch.multinomial(
                        F.softmax(logits, dim=-1), num_samples=1
                    )
            generated = torch.cat((generated, next_token), dim=-1)
            # Decode into plain text
            o = generated[len(context_tokens):].tolist()
            generated.text = tokenizer.decode(
                o, clean_up_tokenization_spaces=False, skip_special_tokens=True
            )
            if use_ptoolkit():
                clear_lines(clines)
                generated.text = format_result(generated.text)
                clines = output(generated.text, "ai-text")
            if (
                    (stop_tokens is not None)
                    and (j > 4)
                    and (next_token[0] in stop_tokens)
            ):
                # Why require a minimum number of tokens (j > 4)? Because the model
                # sometimes starts with whitespace, which gets stripped away anyway.
                # Requiring a few tokens before stopping means we don't stop just
                # because of "\n " or similar.
                logger.debug(
                    "Stopping generation on stop token: one of `%s` matched `%s` at token %s",
                    stop_tokens,
                    next_token,
                    j,
                )
                break
    clear_lines(clines)
    return generated
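
The `sparse-gen` branch calls `entmax_bisect`, presumably from the deep-spin `entmax` package, where alpha interpolates between softmax (alpha=1, dense) and sparsemax (alpha=2); intermediate values assign exactly zero probability to low-scoring tokens, making a top-k/top-p cutoff unnecessary. A usage sketch under that assumption:

import torch
from entmax import entmax_bisect  # pip install entmax

logits = torch.tensor([2.0, 1.0, 0.5, -1.0])
# Unlike softmax, entmax can give low-scoring tokens exactly zero
# probability, so torch.multinomial will never sample them.
probs = entmax_bisect(logits, alpha=1.5, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)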
Example #6
def edit_multiline(default_text=""):
    kb = KeyBindings()

    @kb.add('c-q')
    @kb.add('escape', 'enter')
    def exit_(event):
        """
        Pressing Ctrl-Q, Alt+Enter or Esc + Enter will exit the editor.
        """
        event.app.exit(textf.text)

    @kb.add('c-c')
    def do_copy(event):
        data = textf.buffer.copy_selection()
        get_app().clipboard.set_data(data)

    @kb.add('c-x', eager=True)
    def do_cut(event):
        data = textf.buffer.cut_selection()
        get_app().clipboard.set_data(data)

    @kb.add('c-z')
    def do_undo(event):
        textf.buffer.undo()

    @kb.add('c-y')
    def do_redo(event):
        textf.buffer.redo()

    @kb.add('c-a')
    def do_select_all(event):
        textf.buffer.cursor_position = 0
        textf.buffer.start_selection()
        textf.buffer.cursor_position = len(textf.buffer.text)
        update_stored_pos(None)

    @kb.add('c-v')
    def do_paste(event):
        textf.buffer.paste_clipboard_data(get_app().clipboard.get_data())

    @kb.add('left')
    def kb_left(event):
        textf.buffer.selection_state = None
        if textf.buffer.cursor_position != 0 and textf.text[
                textf.buffer.cursor_position - 1] == '\n':
            textf.buffer.cursor_up()
            textf.buffer.cursor_right(len(textf.text))
        else:
            textf.buffer.cursor_left()
        update_stored_pos(None)

    @kb.add('right')
    def kb_right(event):
        textf.buffer.selection_state = None
        if textf.buffer.cursor_position < len(textf.text) and textf.text[
                textf.buffer.cursor_position] == '\n':
            textf.buffer.cursor_down()
            textf.buffer.cursor_left(len(textf.text))

        else:
            textf.buffer.cursor_right()
        update_stored_pos(None)

    @kb.add('home')
    def kb_home(event):
        textf.buffer.selection_state = None
        width = getTermWidth()
        doc = textf.document
        if textf.buffer.cursor_position == doc._line_start_indexes[
                cursor_row()] + int(cursor_col() / width) * width:
            textf.buffer.cursor_position = doc._line_start_indexes[
                cursor_row()]
        else:
            textf.buffer.cursor_position = doc._line_start_indexes[
                cursor_row()] + int(cursor_col() / width) * width
        update_stored_pos(None)

    @kb.add('end')
    def kb_end(event):
        textf.buffer.selection_state = None
        width = getTermWidth()
        doc = textf.document
        row = cursor_row()
        if textf.buffer.cursor_position == doc._line_start_indexes[row] + (
                int(cursor_col() / width) + 1) * width - 1:
            textf.buffer.cursor_position = doc._line_start_indexes[row] + len(
                doc.current_line)
        else:
            textf.buffer.cursor_position = min(
                doc._line_start_indexes[row] +
                (int(cursor_col() / width) + 1) * width - 1,
                doc._line_start_indexes[row] + len(doc.current_line))
        update_stored_pos(None)

    @kb.add('up')
    def kb_up(event):
        textf.freezestore = True
        width = getTermWidth()
        doc = textf.document
        textf.buffer.selection_state = None
        col = cursor_col()
        row = cursor_row()
        if width > 9000:  # A failsafe in case the terminal size is incorrectly detected
            textf.buffer.cursor_up()
            return

        if col >= width:  # Move one row up staying on the same line
            textf.buffer.cursor_position = doc._line_start_indexes[row] + int(
                col / width - 1) * width + textf.stored_cursor_pos
        elif row >= 1:  # Moving up to a different line
            prevlinelen = len(doc.lines[row - 1])

            textf.buffer.cursor_position = min(
                doc._line_start_indexes[row] - 1,
                doc._line_start_indexes[row - 1] +
                int(prevlinelen / width) * width + textf.stored_cursor_pos)
        else:  # Cursor is on the first row of first line
            textf.buffer.cursor_position = 0
            textf.freezestore = False
            update_stored_pos(None)

    @kb.add('down')
    def kb_down(event):
        textf.freezestore = True
        width = getTermWidth()
        doc = textf.document
        textf.buffer.selection_state = None
        col = cursor_col()
        row = cursor_row()
        nextlinelen = len(doc.lines[row + 1]) if row < len(doc.lines) - 1 else -1
        if width > 9000:  # A failsafe in case the terminal size is incorrectly detected
            textf.buffer.cursor_down()
            return

        if col <= len(doc.current_line) - width:  # Move one row down, staying on the same line
            textf.buffer.cursor_position = doc._line_start_indexes[row] + int(
                col / width + 1) * width + textf.stored_cursor_pos
        elif nextlinelen < 0:  # Move to the very end
            textf.buffer.cursor_position = len(textf.text)
            textf.freezestore = False
            update_stored_pos(None)
        # Move to the end of the same line the cursor is on
        elif col != len(doc.lines[row]) and textf.stored_cursor_pos >= len(
                doc.lines[row]) - int(len(doc.lines[row]) / width) * width:
            textf.buffer.cursor_position = doc._line_start_indexes[row + 1] - 1
        else:  # Move to a different line
            textf.buffer.cursor_position = min(
                doc._line_start_indexes[row + 1] + nextlinelen,
                doc._line_start_indexes[row + 1] + textf.stored_cursor_pos)

    textf = TextArea()
    bottom_bar_text = FormattedTextControl(
        text='\nCurrently editing. Press Ctrl+Q, Alt+Enter or Esc + Enter to exit.')
    bottom_bar = Window(content=bottom_bar_text)

    root_container = HSplit([
        textf,
        bottom_bar,
    ])

    layout = Layout(root_container)

    app = Application(key_bindings=kb,
                      layout=layout,
                      enable_page_navigation_bindings=True,
                      full_screen=False)
    textf.freezestore = False
    textf.text = default_text
    textf.buffer.cursor_position = len(textf.buffer.text)

    # Find the row the cursor is on.
    # A custom implementation, written for fear of race conditions in the built-in one.
    def cursor_row():
        i = 0
        line_starts = textf.document._line_start_indexes
        while i < len(line_starts) and textf.buffer.cursor_position >= line_starts[i]:
            i += 1
        return i - 1

    # Find the column the cursor is on.
    # There is a built-in function, but it appears to suffer from a race condition when used here.
    def cursor_col():
        i = textf.buffer.cursor_position - 1
        while i >= 0 and textf.text[i] != '\n':
            i -= 1
        return textf.buffer.cursor_position - i - 1

    def update_stored_pos(event):
        if not event:
            textf.freezestore = False
        if textf.freezestore:
            textf.freezestore = False
            return
        width = getTermWidth()
        col = cursor_col()
        textf.stored_cursor_pos = col - int(col / width) * width

    textf.buffer.on_cursor_position_changed += update_stored_pos
    update_stored_pos(None)

    text = app.run()

    clear_lines(1)

    return text
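
`clear_lines` and `getTermWidth` are used throughout these examples but never shown. A plausible minimal implementation with the standard library and ANSI escape codes (the originals may differ):

import shutil
import sys

def getTermWidth():
    # Fall back to 80 columns when the terminal size cannot be detected.
    return shutil.get_terminal_size((80, 24)).columns

def clear_lines(n):
    # Move the cursor up one line and erase it, n times.
    for _ in range(n):
        sys.stdout.write('\x1b[1A\x1b[2K')
    sys.stdout.flush()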