Example #1: PPLM-based generation with multiprocessing
import json
from functools import partial
from pathlib import Path

import pandas as pd
import torch.multiprocessing as mp  # assumed: torch's multiprocessing so model tensors can be shared across worker processes
from tqdm import tqdm

# Project-internal helpers assumed to be importable: load_cache, PPLMGeneration


def pplm(prompts: pd.Series, max_len: int, num_samples: int, batch_size: int,
         class_label: int, num_iterations: int, model_name_or_path: str,
         out_file: Path):
    # Set up PPLM with multiprocessing
    generator = PPLMGeneration(model_name_or_path, device=0)
    ctx = mp.get_context('spawn')
    generator.model.share_memory()
    generator.classifier.share_memory()
    pplm_func = partial(generator.__call__,
                        class_label=class_label,
                        num_iterations=num_iterations,
                        length=max_len)

    # Repeat prompts
    prompts = prompts.repeat(num_samples)

    # Resume generation
    num_cached_generations = 0
    for generation in load_cache(out_file):
        yield generation
        num_cached_generations += 1

    # Generate with prompts
    prompts = prompts[num_cached_generations:]
    with ctx.Pool(processes=batch_size) as pool:
        for batch in tqdm(pool.imap(pplm_func, prompts),
                          total=len(prompts),
                          desc='Generation',
                          dynamic_ncols=True):
            for generation in batch:
                with out_file.open('a') as f:
                    print(json.dumps(generation), file=f)
                yield generation
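
A minimal invocation sketch for this helper. The prompt texts, sample counts, PPLM hyperparameters, and output path below are hypothetical, and PPLMGeneration plus load_cache are project-internal helpers assumed to be importable:

# Hypothetical driver code: consume the generator and collect all generations.
prompts = pd.Series(["The book was", "In a shocking turn of events,"])
generations = list(pplm(prompts,
                        max_len=20,
                        num_samples=25,
                        batch_size=4,
                        class_label=0,
                        num_iterations=10,
                        model_name_or_path='gpt2-medium',
                        out_file=Path('pplm_generations.jsonl')))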
Example #2: batched GPT-2 generation helper
import json
import math
from pathlib import Path

import pandas as pd
import torch
from tqdm import tqdm

# Project-internal helpers assumed to be importable: load_cache, batchify, GPT2Generation


def _gpt2_helper(prompts: pd.Series, max_len: int, num_samples: int,
                 batch_size: int, generator: GPT2Generation, out_file: Path,
                 **generate_kwargs):
    # Repeat prompts
    prompts = prompts.repeat(num_samples)

    # Resume generation
    num_cached_generations = 0
    for generation in load_cache(out_file):
        yield generation
        num_cached_generations += 1

    # Generate with prompts
    prompts = prompts[num_cached_generations:]
    for prompt in tqdm(batchify(prompts, batch_size),
                       total=math.ceil(len(prompts) / batch_size),
                       desc='GPT-2 Generation',
                       dynamic_ncols=True,
                       postfix={'batch_size': batch_size}):
        # Generate
        try:
            batch = generator.generate(prompt, max_len, **generate_kwargs)
        except RuntimeError as e:
            print("Error during generation with prompt:", prompt)
            print(e)
            print("Emptying CUDA cache and retrying...")
            torch.cuda.empty_cache()

            batch = ["GENERATION_ERROR_CUDA"] * len(prompt)

        for generation in batch:
            with out_file.open('a') as f:
                print(json.dumps(generation), file=f)
            yield generation
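
A usage sketch, under the assumption that GPT2Generation can be constructed with defaults (its constructor is not shown here); all other values are hypothetical:

# Hypothetical driver code for the GPT-2 helper.
generator = GPT2Generation()  # assumed default constructor
generations = list(_gpt2_helper(prompts=pd.Series(["The movie was"]),
                                max_len=20,
                                num_samples=25,
                                batch_size=16,
                                generator=generator,
                                out_file=Path('gpt2_generations.jsonl')))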
Example #3: Perspective API scoring worker setup
    # Note: this __init__ is a snippet from a Perspective API scoring worker class
    # (the class definition is not shown). It assumes a multiprocessing module
    # imported as mp, Path from pathlib, and the project-internal load_cache
    # helper are in scope.
    def __init__(self, out_file: Path, total: int, rate_limit: int):
        if not rate_limit:
            print("Disabling Perspective API (rps is 0)")
            self.enabled = False
            return
        self.enabled = True
        self.requests_handled = set()
        for response in load_cache(out_file):
            self.requests_handled.add(response['request_id'])
        total -= len(self.requests_handled)

        # Setup worker thread
        self.task_queue = mp.Queue()
        self.process = mp.Process(target=self.perspective_worker,
                                  args=(self.task_queue, out_file, total,
                                        rate_limit))
        self.process.start()
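
A hypothetical instantiation, assuming this __init__ belongs to a worker class named here PerspectiveWorker (the actual class name is not shown in the snippet); the argument values are placeholders:

# Hypothetical: create the scoring worker.
scorer = PerspectiveWorker(out_file=Path('perspective.jsonl'),
                           total=25_000,   # hypothetical: number of texts to score
                           rate_limit=25)  # requests per second; 0 disables the API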
Example #4: generation via the transformers text-generation pipeline
import json
from pathlib import Path

import pandas as pd
import torch
from tqdm import tqdm
from transformers import pipeline

# Project-internal helper assumed to be importable: load_cache


def _pipeline_helper(prompts: pd.Series, model_name_or_path: str, max_len: int,
                     num_samples: int, out_file: Path, **generate_kwargs):
    # Load cached generations
    num_cached_generations = 0
    for generation in load_cache(out_file):
        yield generation
        num_cached_generations += 1
    assert num_cached_generations % num_samples == 0

    # Remove prompts that have already been generated with
    prompts = prompts[num_cached_generations // num_samples:]
    if prompts.empty:
        return

    # Setup model
    generator = pipeline('text-generation', model=model_name_or_path, device=0)
    print("Created pipeline with model:", generator.model.__class__.__name__)

    # Generate with prompts
    for prompt in tqdm(prompts, desc='Generation', dynamic_ncols=True):
        # Generate
        # FIXME: this is a hack
        ctx_len = len(generator.tokenizer.tokenize(prompt))
        try:
            batch = generator(prompt,
                              num_return_sequences=num_samples,
                              clean_up_tokenization_spaces=True,
                              do_sample=True,
                              top_k=0,
                              top_p=0.9,
                              max_length=ctx_len + max_len,
                              return_prompt=False,
                              **generate_kwargs)
            batch = map(lambda g: g['generated_text'][len(prompt):], batch)
        except RuntimeError as e:
            print("Error during generation with prompt:", prompt)
            print(e)
            print("Emptying CUDA cache and continuing...")
            torch.cuda.empty_cache()

            batch = ["GENERATION_ERROR_CUDA"] * num_samples

        for generation in batch:
            with out_file.open('a') as f:
                print(json.dumps(generation), file=f)
            yield generation
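
A usage sketch with hypothetical values; pipeline comes from Hugging Face transformers, and load_cache is the only project-internal dependency here:

# Hypothetical driver code for the pipeline-based helper.
generations = list(_pipeline_helper(prompts=pd.Series(["The weather today is"]),
                                    model_name_or_path='gpt2',
                                    max_len=20,
                                    num_samples=25,
                                    out_file=Path('pipeline_generations.jsonl')))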