Example #1
def getGenerator():
    colPrint(
        "\nInitializing AI Engine! (This might take a few minutes)\n",
        colors["loading-message"],
    )
    models = [x for x in Path('models').iterdir() if x.is_dir()]
    if not models:
        raise FileNotFoundError(
            'There are no models in the models directory! You must download a pytorch compatible model!'
        )
    elif len(models) > 1:
        colPrint(
            "You have multiple models in your models folder. Please select one to load:",
            colors['message'])
        for n, model_path in enumerate(models):
            colPrint("{}: {}".format(n, model_path.name), colors['menu'])

        model = models[getNumberInput(len(models) - 1)]
    else:
        model = models[0]
        logger.info("Using model: " + str(model))
    return GPT2Generator(
        model_path=model,
        generate_num=settings.getint("generate-num"),
        temperature=settings.getfloat("temp"),
        top_k=settings.getint("top-keks"),
        top_p=settings.getfloat("top-p"),
        repetition_penalty=settings.getfloat("rep-pen"),
    )
    def generate(self,
                 context,
                 prompt='',
                 temperature=None,
                 top_p=None,
                 top_k=None,
                 repetition_penalty=None,
                 depth=0):
        assert (top_k is not None)
        assert (temperature is not None)
        assert (top_p)
        assert (repetition_penalty)
        # logger.debug("BEFORE PROMPT_REPLACE: `%r`", prompt)

        # prompt = [self.prompt_replace(p) for p in prompt]

        # logger.debug("AFTER PROMPT_REPLACE is: `%r`", repr(prompt))
        assert (prompt + context)

        text = self.generate_raw(context,
                                 prompt,
                                 temperature=temperature,
                                 top_k=top_k,
                                 top_p=top_p,
                                 repetition_penalty=repetition_penalty,
                                 stop_tokens=torch.tensor(
                                     [[self.tokenizer.eos_token_id]]))

        logger.debug("Generated result is: `%r`", repr(text))

        result = self.result_replace(text)

        if (depth > 6) and len(result) == 0:
            # Sometimes it keeps generating a story starting with an action (">"); if it's tried a few times and it keeps
            # happening, let's let it keep action text which starts with ">".
            # We could just blacklist that token and force it to generate something else. TODO
            result = self.result_replace(text, allow_action=True)
            logger.info(
                "Model generated empty text after formatting `%r`. Trying to format less with allow_action=True. `%r`",
                text,
                result,
            )

            # same here as above
        if len(result) == 0:
            if depth < 20:
                logger.info("Model generated empty text trying again %r",
                            depth)
                return self.generate(context,
                                     prompt,
                                     temperature=temperature,
                                     top_p=top_p,
                                     top_k=top_k,
                                     repetition_penalty=repetition_penalty,
                                     depth=depth + 1)
            else:
                logger.warning(
                    "Model generated empty text %r times. Try another action",
                    depth)
        return result
Example #3
    def __init__(self,
                 generate_num=60,
                 temperature=0.4,
                 top_k=40,
                 top_p=0.9,
                 dtype=DTYPE,
                 model_path: Union[str,
                                   Path] = Path('models',
                                                'pytorch-gpt2-xl-aid2-v5'),
                 repetition_penalty=1,
                 repetition_penalty_range=512,
                 repetition_penalty_slope=3.33):
        self.generate_num = generate_num
        self.temp = temperature
        self.top_k = top_k
        self.top_p = top_p
        self.samples = 1
        self.dtype = dtype
        self.repetition_penalty = repetition_penalty
        self.repetition_penalty_range = repetition_penalty_range
        self.repetition_penalty_slope = repetition_penalty_slope
        self.batch_size = 1
        self.max_history_tokens = 1024 - generate_num
        self.stop_token = "<|endoftext|>"

        if isinstance(model_path, str):
            self.checkpoint_path = model_path
            logger.warning(
                f"Using DEBUG MODE! This will load one of the generic (non-finetuned) GPT2 models. "
                f"Selected: {model_path}")
        elif isinstance(model_path, Path):
            self.checkpoint_path = model_path
            if not self.checkpoint_path.exists():
                raise FileNotFoundError(
                    "Could not find {} Make sure to download a pytorch model and put it in the models directory!"
                    .format(str(self.checkpoint_path)))
        else:
            raise ValueError(
                f"model_path must be either str or Path, got {type(model_path)}"
            )

        self.device = torch.device("cuda" if self.dtype ==
                                   torch.float16 else "cpu")
        logger.info("Using device={}, checkpoint={}, dtype={}".format(
            self.device, str(self.checkpoint_path), self.dtype))

        # Load tokenizer and model
        model_class, tokenizer_class = MODEL_CLASSES[
            "gpt2-experimental"] if settings.getboolean(
                "gpt2_experimental") else MODEL_CLASSES["gpt2"]
        if "gpt-neo" in str(model_path):
            self.max_history_tokens = 2048 - generate_num
            model_class = GPTNeoForCausalLM
        self.tokenizer = tokenizer_class.from_pretrained(
            str(self.checkpoint_path))
        self.model = model_class.from_pretrained(str(self.checkpoint_path))
        self.model.to(self.dtype).to(self.device)
        self.model.eval()
    def __init__(
        self,
        generate_num=60,
        temperature=0.4,
        top_k=40,
        top_p=0.9,
        dtype=DTYPE,
        model_path: Union[str, Path] = Path('models', 'gpt-neo-2.7B-horni'),
        repetition_penalty=1,
    ):
        self.generate_num = generate_num
        self.temp = temperature
        self.top_k = top_k
        self.top_p = top_p
        self.samples = 1
        self.dtype = dtype
        self.repetition_penalty = repetition_penalty
        self.batch_size = 1
        self.max_history_tokens = 1024 - generate_num
        self.stop_token = "<|endoftext|>"

        if isinstance(model_path, str):
            self.checkpoint_path = model_path
            logger.warning(
                f"Using DEBUG MODE! This will load one of the generic (non-finetuned) GPT2 models. "
                f"Selected: {model_path}")
        elif isinstance(model_path, Path):
            self.checkpoint_path = model_path
            if not self.checkpoint_path.exists():
                raise FileNotFoundError(
                    "Could not find {} Make sure to download a pytorch model and put it in the models directory!"
                    .format(str(self.checkpoint_path)))
        else:
            raise ValueError(
                f"model_path must be either str or Path, got {type(model_path)}"
            )

        self.device = torch.device("cuda" if self.dtype ==
                                   torch.float16 else "cpu")
        logger.info("Using device={}, checkpoint={}, dtype={}".format(
            self.device, str(self.checkpoint_path), self.dtype))

        # Load tokenizer and model
        model_class, tokenizer_class = MODEL_CLASSES["gpt_neo"]
        self.checkpoint = torch.load(Path(model_path, 'pytorch_model.bin'),
                                     map_location='cpu')
        self.tokenizer = tokenizer_class.from_pretrained(Path(model_path))
        self.model = model_class.from_pretrained(model_path,
                                                 state_dict=self.checkpoint)
        self.model.to(self.dtype).to(self.device)
        self.model.eval()
Example #5
    def __init__(
        self,
        generate_num=60,
        temperature=0.4,
        top_k=40,
        top_p=0.9,
        dtype=DTYPE,
        model_path=Path('models', 'pytorch-gpt2-xl-aid2-v5'),
        censor=False,
        repetition_penalty=1,
    ):
        self.generate_num = generate_num
        self.temp = temperature
        self.top_k = top_k
        self.top_p = top_p
        self.censor = censor
        self.samples = 1
        self.dtype = dtype
        self.repetition_penalty = repetition_penalty
        self.batch_size = 1
        self.max_history_tokens = 1024 - generate_num
        self.stop_token = "<|endoftext|>"

        self.checkpoint_path = model_path
        if not self.checkpoint_path.exists():
            raise FileNotFoundError(
                "Could not find {} Make sure to download a pytorch model and put it in the models directory!"
                .format(str(self.checkpoint_path)))

        if os.environ.get("DEBUG_GPT2", False):
            self.checkpoint_path = Path('gpt2')
            logger.warning(
                "using DEBUG_GPT2 MODE! This is just for devs to quickly check a small GPT2 model with poor output"
            )
        self.device = torch.device("cuda" if self.dtype ==
                                   torch.float16 else "cpu")
        logger.info("Using device={}, checkpoint={}, dtype={}".format(
            self.device, str(self.checkpoint_path), self.dtype))

        # Load tokenizer and model
        model_class, tokenizer_class = MODEL_CLASSES["gpt2"]
        self.tokenizer = tokenizer_class.from_pretrained(self.checkpoint_path)
        self.model = model_class.from_pretrained(self.checkpoint_path)
        self.model.to(self.dtype).to(self.device)
        self.model.eval()
Example #6
    def get_action(self):
        # We want the story to stay on track, but not so on track that it loops.
        # The actions can be quite random, and this helps inject some user-curated randomness
        # and prevent loops. So let's make the actions quite random, and prevent duplicates while we're at it.

        # what to feed to model?
        mem_ind = random.randint(1, 6)  # How many steps to include
        sample = random.randint(0, 1)  # Random steps from history?
        include_prompt = random.randint(0, 1)  # Include the initial prompt?
        predicates = ['You try to ', 'You say "', 'You start to ',
                      '"']  # The model has to continue from here

        predicate = random.sample(predicates, 1)[0]
        action_prompt = self.story_manager.story_context(
            mem_ind, sample, include_prompt)
        action_prompt[-1] = action_prompt[-1].strip() + "\n> " + predicate

        result_raw = self.story_manager.generator.generate_raw(
            action_prompt,
            generate_num=settings.getint("action-generate-num"),
            temperature=settings.getfloat("action-temp"),
            stop_tokens=self.story_manager.generator.tokenizer.encode(
                ["<|endoftext|>", "\n", ">"])
            # stop_tokens=self.generator.tokenizer.encode(['>', '<|endoftext|>'])
        )
        logger.info(
            "get_action (mem_ind=%s, sample=%s, include_prompt=%s, predicate=`%r`) -> %r",
            mem_ind, sample, include_prompt, predicate, result_raw)
        result = predicate + result_raw.lstrip()
        result = clean_suggested_action(
            result, min_length=settings.getint("action-min-length"))
        # Sometimes the suggestion starts with "You"; we will add that back later anyway, so remove it here
        result = re.sub("^ ?[Yy]ou try to ?", "You ", result)
        result = re.sub("^ ?[Yy]ou start to ?", "You ", result)
        result = re.sub("^ ?[Yy]ou say \"", "\"", result)
        result = re.sub("^ ?[Yy]ou ?", "", result)
        return result
Example #7
    def generate(self, prompt, options=None, seed=None, depth=0):
        logger.debug("BEFORE PROMPT_REPLACE: `%r`", prompt)

        prompt = [self.prompt_replace(p) for p in prompt]

        # logger.debug("AFTER PROMPT_REPLACE is: `%r`", repr(prompt))

        text = self.generate_raw(prompt,
                                 stop_tokens=self.tokenizer.encode(
                                     ["<|endoftext|>", ">"]))

        logger.debug("Generated result is: `%r`", repr(text))

        result = self.result_replace(text)

        if (depth > 6) and len(result) == 0:
            # Sometimes it keeps generating a story starting with an action (">"); if it's tried a few times and it keeps
            # happening, let's let it keep action text which starts with ">".
            result = self.result_replace(text, allow_action=True)
            logger.info(
                "Model generated empty text after formatting `%r`. Trying to format less with allow_action=True. `%r`",
                text,
                result,
            )

        if len(result) == 0:
            if depth < 20:
                logger.info("Model generated empty text trying again %r",
                            depth)
                return self.generate(prompt + [" {}".format(depth)],
                                     seed=depth,
                                     depth=depth + 1)
            else:
                logger.warning(
                    "Model generated empty text %r times. Try another action",
                    depth)
        return result
Example #8
def play(generator):
    print("\n")

    with open(Path("interface", "mainTitle.txt"), "r",
              encoding="utf-8") as file:
        colPrint(file.read(), colors["title"], wrap=False)

    with open(Path("interface", "subTitle.txt"), "r",
              encoding="utf-8") as file:
        cols = termWidth
        for line in file:
            line = re.sub(r'\n', '', line)
            line = line[:cols]
            # Fill in the graphic using reverse-video mode, substituted into the areas between |'s
            colPrint(
                re.sub(r'\|[ _]*(\||$)',
                       lambda x: '\x1B[7m' + x.group(0) + '\x1B[27m', line),
                colors['subtitle'], False)

    print()
    colPrint(
        "Go to https://github.com/cloveranon/Clover-Edition/ or email [email protected] for bug reports, help, and feature requests.",
        colors['subsubtitle'])

    while True:
        # May be needed to avoid out of mem
        gc.collect()
        torch.cuda.empty_cache()

        print("\n\n")

        colPrint(
            "0: Pick Prompt From File (Default if you type nothing)\n1: Write Custom Prompt",
            colors['menu'])

        if getNumberInput(1) == 1:
            with open(Path("interface", "prompt-instructions.txt"),
                      "r",
                      encoding="utf-8") as file:
                colPrint(file.read(), colors["instructions"], False)
            prompt = colInput("Prompt>", colors["main-prompt"],
                              colors["user-text"])
            context = colInput("Context>", colors["main-prompt"],
                               colors["user-text"])
            filename = colInput(
                "Name to save prompt as? (Leave blank for no save): ",
                colors["query"],
                colors["user-text"],
            )
            filename = re.sub(
                "-$", "",
                re.sub("^-", "", re.sub("[^a-zA-Z0-9_-]+", "-", filename)))
            if filename != "":
                with open(Path("prompts", filename + ".txt"),
                          "w",
                          encoding="utf-8") as f:
                    f.write(context + "\n" + prompt)
        else:
            prompt, context = selectFile()
        assert (prompt + context)

        instructions()

        print()
        colPrint("Generating story...", colors["loading-message"])

        story = newStory(generator, prompt, context)

        while True:
            # Generate suggested actions
            act_alts = settings.getint("action-sugg")
            if act_alts > 0:

                # TODO change this to two messages for different colors
                suggested_actions = []
                colPrint("\nSuggested actions:", colors["selection-value"])
                action_suggestion_lines = 2
                for i in range(act_alts):
                    suggested_action = story.getSuggestion()
                    if len(suggested_action.strip()) > 0:
                        j = len(suggested_actions)
                        suggested_actions.append(suggested_action)
                        suggestion = "{}> {}".format(j, suggested_action)
                        action_suggestion_lines += colPrint(
                            suggestion, colors["selection-value"])
                print()

            bell()
            action = colInput("> You ", colors["main-prompt"],
                              colors["user-text"])

            # Clear suggestions and user input
            if act_alts > 0:
                action_suggestion_lines += 2
                if not IN_COLAB:
                    clear_lines(action_suggestion_lines)

                    # Show user input again
                    # colPrint("\n> " + action.rstrip(), colors["user-text"], end="")

            setRegex = re.search("^/set ([^ ]+) ([^ ]+)$", action)
            if setRegex:
                if setRegex.group(1) in settings:
                    currentSettingValue = settings[setRegex.group(1)]
                    colPrint(
                        "Current Value of {}: {}     Changing to: {}".format(
                            setRegex.group(1), currentSettingValue,
                            setRegex.group(2)))
                    settings[setRegex.group(1)] = setRegex.group(2)
                    colPrint("Save config file?", colors["query"])
                    colPrint("Saving an invalid option will corrupt file!",
                             colors["error"])
                    if (colInput(
                            "y/n? >",
                            colors["selection-prompt"],
                            colors["selection-value"],
                    ) == "y"):
                        with open("config.ini", "w", encoding="utf-8") as file:
                            config.write(file)
                else:
                    colPrint("Invalid Setting", colors["error"])
                    instructions()
            elif action == "/menu":
                break
            elif action == "/restart":
                print()
                colPrint("Restarting story...", colors["loading-message"])

                story = newStory(generator, story.prompt, context)
                continue
            elif action == "/quit":
                exit()
            elif action == "/help":
                instructions()
            elif action == "/print":
                print("\nPRINTING\n")
                #TODO colorize printed story
                colPrint(str(story), colors["print-story"])
            elif action == '/retry':

                if len(story.story) == 1:
                    print()
                    colPrint("Restarting story...", colors["loading-message"])
                    story = newStory(generator, story.prompt, context)
                    continue
                else:
                    newaction = story.story[-1][0]

                colPrint(newaction, colors['user-text'], end='')
                story.story = story.story[:-1]
                result = "\n" + story.act(newaction)[0]

                if len(story.story) >= 2:
                    similarity = get_similarity(result, story.story[-2][1][0])
                    if similarity > 0.9:
                        story.story = story.story[:-1]
                        colPrint(
                            "Woops that action caused the model to start looping. Try a different action to prevent that.",
                            colors["error"],
                        )
                        continue
                colPrint(result, colors["ai-text"])

                continue

            elif action == '/revert':

                if len(story.story) == 1:
                    colPrint("You can't go back any farther. ",
                             colors["error"])
                    continue

                story.story = story.story[:-1]
                colPrint("Last action reverted. ", colors["message"])
                if len(story.story) < 2:
                    colPrint(story.prompt, colors["ai-text"])
                colPrint(story.story[-1][1][0], colors["ai-text"])

                continue

            elif action == "/alter":
                story.story[-1][1][0] = alterText(story.story[-1][1][0])
                if len(story.story) < 2:
                    colPrint(story.prompt, colors["ai-text"])
                else:
                    colPrint("\n" + story.story[-1][0] + "\n",
                             colors["transformed-user-text"])
                colPrint("\n" + story.story[-1][1][0] + "\n\n",
                         colors["ai-text"])

            elif action == "/prompt":
                story.prompt = alterText(story.prompt)
                if len(story.story) < 2:
                    colPrint(story.prompt, colors["ai-text"])
                else:
                    colPrint("\n" + story.story[-1][0] + "\n",
                             colors["transformed-user-text"])
                colPrint("\n" + story.story[-1][1][0] + "\n\n",
                         colors["ai-text"])

            else:
                if act_alts > 0:
                    # Options to select a suggestion action
                    if action in [
                            str(i) for i in range(len(suggested_actions))
                    ]:
                        action = suggested_actions[int(action)]

                original_action = action
                action = action.strip()
                #TODO debug stuff to delete
                if action != original_action:
                    logger.debug("STRIPPED WHITE SPACE OFF ACTION %r vs %r",
                                 original_action, action)

                # Crop actions to a max length
                #action = action[:4096]

                if action != "":

                    # Roll a 20 sided dice to make things interesting
                    d = random.randint(1, 20)
                    logger.debug("roll d20=%s", d)

                    # If it says 'You say "' then it's still dialogue. Normalise it by removing `You say `; we will add it again soon
                    action = re.sub("^ ?[Yy]ou say [\"']", '"', action)
                    if any(action.lstrip().startswith(t) for t in ['"', "'"]):
                        if settings.getboolean("action-d20"):
                            action = d20ify_speech(action, d)
                        else:
                            action = "You say " + action
                        logger.info(
                            "%r. %r, %r", action,
                            any(action.lstrip().startswith(t)
                                for t in ['"', "'"]),
                            settings.getboolean("action-d20"))
                    else:
                        action = first_to_second_person(action)
                        if not action.lower().startswith(
                                "you ") and not action.lower().startswith(
                                    "i "):
                            action = action[0].lower() + action[1:]
                            # roll a d20
                            if settings.getboolean("action-d20"):
                                action = d20ify_action(action, d)
                            else:
                                action = "You " + action

                        if action[-1] not in [".", "?", "!"]:
                            action = action + "."

                action = "\n> " + action + "\n"

                colPrint(
                    "\n>" + action.lstrip().lstrip("> \n"),
                    colors["transformed-user-text"],
                )
                #TODO check if leading white space makes sense
                result = "\n" + story.act(action)[0]

                #TODO: Replace all this nonsense
                if len(story.story) >= 2:
                    similarity = get_similarity(result, story.story[-2][1][0])
                    if similarity > 0.9:
                        story.story = story.story[:-1]
                        colPrint(
                            "Woops that action caused the model to start looping. Try a different action to prevent that.",
                            colors["error"],
                        )
                        continue

                if player_won(result):
                    colPrint(result + "\n CONGRATS YOU WIN", colors["message"])
                    break
                elif player_died(result):
                    colPrint(result, colors["ai-text"])
                    colPrint("YOU DIED. GAME OVER", colors["error"])
                    colPrint(
                        "\nOptions:\n0)Start a new game\n1)\"I'm not dead yet!\" (If you didn't actually die)",
                        colors["menu"],
                    )
                    choice = getNumberInput(1)
                    if choice == 0:
                        break
                    else:
                        colPrint("Sorry about that...where were we?",
                                 colors["query"])
                colPrint(result, colors["ai-text"])
Example #9
def _in_colab():
    try:
        from IPython import get_ipython
        if (not get_ipython()) or (
                'IPKernelApp' not in get_ipython().config):  # pragma: no cover
            raise ImportError("console")
        if 'VSCODE_PID' in os.environ:  # pragma: no cover
            raise ImportError("vscode")
    except ImportError:
        if get_terminal_size()[0] == 0 or 'google.colab' in sys.modules:
            return True
        return False
    else:
        return True


IN_COLAB = _in_colab()
logger.info("Colab detected: {}".format(IN_COLAB))
IN_COLAB = IN_COLAB or settings.getboolean('colab-mode')
if IN_COLAB:
    logger.warning(
        "Colab mode enabled, disabling line clearing and readline to avoid colab bugs."
    )
else:
    try:
        import readline
        logger.info(
            'readline has been imported. This enables a number of editing features but may cause bugs for colab users.'
        )
    except ModuleNotFoundError:
        pass

termWidth = get_terminal_size()[0]
Example #10
import torch
import torch.nn.functional as F
import re
from gpt2 import GPT2LMHeadModelExperimental
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM
from getconfig import settings, logger
from utils import cut_trailing_sentence, output, clear_lines, format_result, use_ptoolkit

if not settings.getboolean('force-cpu') and not torch.cuda.is_available():
    logger.warning('CUDA is not available, you are limited to CPU only.')

DTYPE = torch.float32 if ((not torch.cuda.is_available()) or
                          settings.getboolean('force-cpu')) else torch.float16
logger.info('Cuda Available: {}    Force CPU: {}    Precision: {}'.format(
    torch.cuda.is_available(), settings.getboolean('force-cpu'),
    '32-bit' if DTYPE == torch.float32 else '16-bit'))

# warnings.filterwarnings("ignore")
MODEL_CLASSES = {
    "gpt2": (GPT2LMHeadModel, GPT2Tokenizer),
    "gpt2-experimental": (GPT2LMHeadModelExperimental, GPT2Tokenizer),
}


def getTokens(tokenizer, l):
    return tokenizer.encode(l)


# The tokenizer does not preserve whitespace at the front of the string,
# so we will append something else to the front of the string and then remove it after tokenization.
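
# A minimal sketch (not a helper from this codebase) of the prepend-and-strip trick described
# above: prefix a known sentinel, tokenize, then drop the sentinel's tokens again so any
# leading whitespace in the real text survives. The function name and sentinel below are
# illustrative assumptions, not names used elsewhere in this project.
def _encode_preserving_leading_space(tokenizer, text, sentinel="|"):
    sentinel_ids = tokenizer.encode(sentinel)
    ids = tokenizer.encode(sentinel + text)
    # If BPE merged the sentinel with the start of the text, a different sentinel is needed.
    assert ids[:len(sentinel_ids)] == sentinel_ids, "sentinel merged with text; pick another one"
    return ids[len(sentinel_ids):]

# Usage sketch:
#   tok = GPT2Tokenizer.from_pretrained(str(model_path))
#   ids = _encode_preserving_leading_space(tok, " leading space preserved")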
Example #11
    logger.warning(
        "Colab mode enabled, disabling line clearing and readline to avoid colab bugs."
    )
else:
    try:
        if settings.getboolean('prompt-toolkit'):
            from inline_editor import edit_multiline
            from prompt_toolkit import prompt as ptprompt
            from prompt_toolkit import print_formatted_text
            from prompt_toolkit.styles import Style
            from prompt_toolkit.formatted_text import to_formatted_text, HTML
        else:
            raise ModuleNotFoundError

        logger.info(
            'Python Prompt Toolkit has been imported. This enables a number of editing features but may cause bugs for colab users.'
        )
    except (ImportError, ModuleNotFoundError):
        try:
            settings['prompt-toolkit'] = "off"
            import readline

            logger.info(
                'readline has been imported. This enables a number of editing features but may cause bugs for colab users.'
            )
        except (ImportError, ModuleNotFoundError):
            pass


def pad_text(text, width, sep=' '):
    while len(text) < width:
Example #12
def _is_notebook():
    try:
        from IPython import get_ipython
        if (not get_ipython()) or (
                'IPKernelApp' not in get_ipython().config):  # pragma: no cover
            raise ImportError("console")
        if 'VSCODE_PID' in os.environ:  # pragma: no cover
            raise ImportError("vscode")
    except ImportError:
        if get_terminal_size()[0] == 0:
            return True
        return False
    else:
        return True


is_notebook = _is_notebook()
logger.info("Colab detected: {}".format(is_notebook))
if is_notebook:
    logger.warning(
        "Colab detected, disabling line clearing and readline to avoid colab bugs."
    )
if not is_notebook:
    try:
        import readline
    except ModuleNotFoundError:
        pass

termWidth = get_terminal_size()[0]
if termWidth < 5:
    logger.warning("Your detected terminal width is: " +
                   str(get_terminal_size()[0]))
    termWidth = 999999999
Example #13
def _is_notebook():
    # from https://github.com/tqdm/tqdm/blob/master/tqdm/autonotebook.py
    try:
        from IPython import get_ipython
        if (not get_ipython()) or (
                'IPKernelApp' not in get_ipython().config):  # pragma: no cover
            raise ImportError("console")
        if 'VSCODE_PID' in os.environ:  # pragma: no cover
            raise ImportError("vscode")
    except ImportError:
        return False
    else:
        return True


is_notebook = _is_notebook()
logger.info("Notebook detected: {}".format(is_notebook))

termWidth = get_terminal_size()[0]
if termWidth < 5:
    logger.warning("Your detected terminal width is: " +
                   str(get_terminal_size()[0]))
    termWidth = 999999999


# ECMA-48 set graphics codes for the curious. Check out "man console_codes" (a small illustration follows after colPrint below)
def colPrint(text, col="0", wrap=True, end=None):
    if wrap:
        width = settings.getint("text-wrap-width")
        width = 999999999 if width < 2 else width
        width = min(width, termWidth)
        text = textwrap.fill(text, width, replace_whitespace=False)
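
# For illustration only (not the project's colPrint): the ECMA-48 "Select Graphic Rendition"
# sequences mentioned above take the form "\x1b[<codes>m". Code 7 enables the reverse-video
# mode used for the title art, 27 turns it back off, and 0 resets all attributes. The helper
# name below is made up for this sketch.
def _sgr(text, codes="0"):
    return "\x1b[" + str(codes) + "m" + text + "\x1b[0m"

# Example:
#   print(_sgr("reverse video", 7))   # swap foreground and background colors
#   print(_sgr("bright red", 91))     # bright-red foreground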
Example #14
def sample_sequence(
    model,
    length,
    context,
    num_samples=1,
    temperature=1,
    top_k=0,
    top_p=0.9,
    repetition_penalty=1.0,
    is_xlnet=False,
    is_xlm_mlm=False,
    xlm_mask_token=None,
    xlm_lang=None,
    device="cpu",
    stop_tokens=None,
):
    context = torch.tensor(context, dtype=torch.long, device=device)
    context = context.unsqueeze(0).repeat(num_samples, 1)
    generated = context
    USE_PAST = True
    next_token = context
    outputs = None
    with torch.no_grad():
        for j in range(length):
            if USE_PAST:
                past = outputs[1] if outputs is not None else None
                inputs = {"input_ids": next_token, "past": past}
            else:
                inputs = {"input_ids": generated}

            outputs = model(
                **inputs
            )  # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet/CTRL (cached hidden-states)
            next_token_logits = outputs[0][:, -1, :] / (
                temperature if temperature > 0 else 1.0)

            # repetition penalty from CTRL (https://arxiv.org/abs/1909.05858)
            for i in range(num_samples):
                for k in set(generated[i].tolist()):
                    next_token_logits[i, k] /= repetition_penalty

            filtered_logits = top_k_top_p_filtering(next_token_logits,
                                                    top_k=top_k,
                                                    top_p=top_p).float()
            if temperature == 0:  # greedy sampling:
                next_token = torch.argmax(filtered_logits,
                                          dim=-1).unsqueeze(-1)
            else:
                next_token = torch.multinomial(F.softmax(filtered_logits,
                                                         dim=-1),
                                               num_samples=1)
            generated = torch.cat((generated, next_token), dim=1)
            if ((stop_tokens is not None) and (j > 4)
                    and (next_token[0][0] in stop_tokens)):
                # Why the minimum number of tokens (j > X)? Because sometimes the model starts with whitespace, which we will strip away anyway. Having a minimum number of tokens before we stop usually means we don't just stop because of "\n " or similar
                logger.info(
                    "Stopping generation as we found stop tokens. One of `%s`, in '%s'. token generated `%s`",
                    stop_tokens,
                    next_token,
                    j,
                )
                break
    return generated
Example #15
import os
from pathlib import Path
import itertools
import torch
import torch.nn.functional as F

from transformers import GPT2LMHeadModel, GPT2Tokenizer

from getconfig import settings, logger
from story.utils import cut_trailing_sentence

DTYPE = torch.float32 if ((not torch.cuda.is_available()) or
                          settings.getboolean('force-cpu')) else torch.float16
logger.info('Cuda Available: {}    Force CPU: {}    DTYPE: {}'.format(
    torch.cuda.is_available(), settings.getboolean('force-cpu'), DTYPE))

# warnings.filterwarnings("ignore")
MODEL_CLASSES = {
    "gpt2": (GPT2LMHeadModel, GPT2Tokenizer),
}


def top_k_top_p_filtering(logits,
                          top_k=0,
                          top_p=0.0,
                          filter_value=-float("Inf")):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits distribution shape (batch size x vocabulary size)
            top_k > 0: keep only top k tokens with highest probability (top-k filtering).
            top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).