Example #1
def get_sample(model, prompt, length: int, num_samples: int, allow_linebreak: bool):
    logger.info("*" * 200)
    logger.info(prompt)

    # Token id of '\n'; passed below as filter_double so that repeated
    # newlines are suppressed during sampling.
    filter_n = tokenizer.encode('\n')[-1:]
    # Token ids to filter on every step: id 1 plus the '[' and '(' openers.
    filter_single = [1] + tokenizer.encode('[')[-1:] + tokenizer.encode('(')[-1:]
    # When linebreaks are disallowed, filter the newline token outright.
    filter_single += [] if allow_linebreak else filter_n

    context_tokens = tokenizer.encode(prompt)
    out = sample_sequence(
        model=model,
        context=context_tokens,
        length=length,
        temperature=1,
        top_k=0,
        top_p=0.9,
        device=device,
        filter_single=filter_single,
        filter_double=filter_n,
        num_samples=num_samples,
    ).to('cpu')

    # Re-decode the prompt: decoding can normalize whitespace, so measure
    # its length on the decoded form before stripping it from each sample.
    prompt = tokenizer.decode(context_tokens)
    len_prompt = len(prompt)

    replies = [out[item, :].tolist() for item in range(len(out))]
    text = [tokenizer.decode(item)[len_prompt:] for item in replies]
    # Trim each sample back to the last sentence end, preferring one that is
    # followed by a newline; fall back to the raw text if neither matches.
    reg_text = [re.match(r'[\w\W]*[\.!?]\n', item) for item in text]
    reg_text2 = [re.match(r'[\w\W]*[\.!?]', item) for item in text]
    result = [
        reg_item[0] if reg_item else reg_item2[0] if reg_item2 else item
        for reg_item, reg_item2, item in zip(reg_text, reg_text2, text)
    ]
    logger.info("=" * 200)
    logger.info(result)
    return result
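
A minimal driver for the helper above, assuming a Hugging Face GPT-2 checkpoint; the module-level tokenizer, device, and logger that get_sample reads are stood up here with illustrative names, not taken from the source repo.

import logging

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)
model.eval()

# sample_sequence must also be in scope (it is not shown in these excerpts).
# Three 40-token continuations, linebreaks disallowed.
samples = get_sample(model, "The weather today is", length=40,
                     num_samples=3, allow_linebreak=False)
for s in samples:
    print(repr(s))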
Example #2
    def sample(self,
               prompt: str,
               length: int,
               num_samples: int = 1,
               allow_linebreak: bool = True,
               stop_token: int = -1):
        filter_n = self.tokenizer.encode('\n')[-1:]
        filter_single = [
            1
        ] + self.tokenizer.encode('[')[-1:] + self.tokenizer.encode('(')[-1:]
        filter_single += [] if allow_linebreak else filter_n

        context_tokens = self.tokenizer.encode(prompt)
        # stop_token == -1 means "no stop token": always sample `length` tokens.
        if stop_token == -1:
            out = sample_sequence(model=self.model,
                                  context=context_tokens,
                                  length=length,
                                  temperature=self.temperature,
                                  top_k=self.top_k,
                                  top_p=self.top_p,
                                  device=self.device,
                                  filter_single=filter_single,
                                  filter_double=filter_n,
                                  num_samples=num_samples).to('cpu')
        else:
            out = sample_sequence_until_token(model=self.model,
                                              context=context_tokens,
                                              length=length,
                                              temperature=self.temperature,
                                              top_k=self.top_k,
                                              top_p=self.top_p,
                                              device=self.device,
                                              filter_single=filter_single,
                                              filter_double=filter_n,
                                              num_samples=num_samples,
                                              stop_token=stop_token).to('cpu')

        prompt = self.tokenizer.decode(context_tokens)
        len_prompt = len(prompt)

        replies = [out[item, :].tolist() for item in range(len(out))]
        text = [self.tokenizer.decode(item)[len_prompt:] for item in replies]
        reg_text = [re.match(r'[\w\W]*[\.!?]\n', item) for item in text]
        reg_text2 = [re.match(r'[\w\W]*[\.!?]', item) for item in text]
        result = [
            reg_item[0] if reg_item else reg_item2[0] if reg_item2 else item
            for reg_item, reg_item2, item in zip(reg_text, reg_text2, text)
        ]
        return result
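
A hedged sketch of driving both branches of the method above; gen stands for an instance of the (unshown) sampler class, and the <|endoftext|> id assumes the standard GPT-2 vocabulary.

# Fixed-length mode: stop_token keeps its default of -1,
# so exactly `length` tokens are sampled.
replies = gen.sample("Once upon a time", length=60, num_samples=2)

# Early-stop mode: halt as soon as the model emits <|endoftext|>.
EOT_ID = 50256  # <|endoftext|> in the standard GPT-2 vocabulary
replies = gen.sample("Once upon a time", length=200, stop_token=EOT_ID)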
Example #3
def print_sample(model, tokenizer, device, args):
    model.eval()
    raw_text = """ """
    context_tokens = tokenizer.encode(raw_text)
    out = sample_sequence(
        model=model,
        context=context_tokens,
        length=500,
        temperature=1,
        top_k=0,
        top_p=0.9,
        device=device,
        #is_xlnet=bool(args.model_type == "xlnet"),
    )
    out = out[0, len(context_tokens):].tolist()
    text = raw_text + tokenizer.decode(out)
    print(text)
Example #4
def print_sample(model, tokenizer, device):
    model.eval()
    raw_text = """ На словах ты Лев Толстой,\n А на деле -"""
    context_tokens = tokenizer.encode(raw_text)
    out = sample_sequence(
        model=model,
        context=context_tokens,
        length=500,
        temperature=1,
        top_k=0,
        top_p=0.9,
        device=device,
        #is_xlnet=bool(args.model_type == "xlnet"),
    )
    out = out[0, len(context_tokens):].tolist()
    text = tokenizer.decode(out)
    print(raw_text + text)
    model.train()
Example #5
def print_sample(model, tokenizer, device, args):
    model.eval()
    raw_text = """ Хорошее утро """
    context_tokens = tokenizer.encode(raw_text)
    out = sample_sequence(
        model=model,
        context=context_tokens,
        length=500,
        temperature=1,
        top_k=0,
        top_p=0.9,
        device=device,
        #is_xlnet=bool(args.model_type == "xlnet"),
    )
    out = out[0, len(context_tokens):].tolist()
    text = raw_text + tokenizer.decode(out)
    print(text)

    with open(os.path.join(args.output_dir, 'sample.txt'), 'w') as f:
        f.write(text)

    model.train()
Example #6
def print_sample(model, tokenizer, device, args):
    model.eval()
    raw_text = """ На словах ты Лев Толстой,\n А на деле -"""
    context_tokens = tokenizer.encode(raw_text)
    out = sample_sequence(model=model,
                          context=context_tokens,
                          length=500,
                          temperature=1,
                          top_k=0,
                          top_p=0.9,
                          device=device,
                          max_input=0
                          #is_xlnet=bool(args.model_type == "xlnet"),
                          )
    out = out[0, len(context_tokens):].tolist()
    text = raw_text + tokenizer.decode(out)
    log_info(text)

    if xm.is_master_ordinal():
        with open(os.path.join(args.output_dir, 'sample.txt'), 'w') as f:
            f.write(text)

    model.train()
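
The TPU-aware bits above come from torch_xla: xm is conventionally torch_xla.core.xla_model, and is_master_ordinal() is true on exactly one process, so only one replica writes sample.txt. A minimal sketch of the assumed setup:

# Assumed imports for the TPU variant above.
import torch_xla.core.xla_model as xm

device = xm.xla_device()      # the process-local TPU device
if xm.is_master_ordinal():    # guard side effects to a single replica
    print("only the master process writes sample.txt")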
Example #7
def get_sample(prompt, model, tokenizer, device):
    logger.info("*" * 200)
    logger.info(prompt)

    model.to(device)
    model.eval()

    filter_n = tokenizer.encode('\n')[-1:]
    context_tokens = tokenizer.encode(prompt)
    out = sample_sequence(model=model,
                          context=context_tokens,
                          length=150,
                          temperature=1,
                          top_k=0,
                          top_p=0.9,
                          device=device,
                          filter_double=filter_n)
    out = out[0, len(context_tokens):].tolist()
    text = tokenizer.decode(out)
    # Keep only up to the last sentence ending with punctuation plus a newline.
    result = re.match(r'[\w\W]*[\.!?]\n', text)
    if result:
        text = result[0]
    logger.info("=" * 200)
    logger.info(text)
    return text
Example #8
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pre-trained model or shortcut name selected in the list: "
        + ", ".join(ALL_MODELS))
    parser.add_argument("--model_type",
                        default="gpt2",
                        type=str,
                        help="Model type selected in the list: " +
                        ", ".join(MODEL_CLASSES.keys()))
    parser.add_argument("--prompt", type=str, default="")
    parser.add_argument("--padding_text", type=str, default="")
    parser.add_argument("--xlm_lang",
                        type=str,
                        default="",
                        help="Optional language when used with the XLM model.")
    parser.add_argument("--length", type=int, default=100)
    parser.add_argument("--num_samples", type=int, default=1)
    parser.add_argument("--temperature",
                        type=float,
                        default=1.0,
                        help="temperature of 0 implies greedy sampling")
    parser.add_argument(
        "--repetition_penalty",
        type=float,
        default=1.0,
        help="primarily useful for CTRL model; in that case, use 1.2")
    parser.add_argument("--top_k", type=int, default=0)
    parser.add_argument("--top_p", type=float, default=0.9)
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Avoid using CUDA when available")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--stop_token',
                        type=str,
                        default='<|endoftext|>',
                        help="Token at which text generation is stopped")
    args = parser.parse_args()

    args.device = torch.device(
        "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    args.n_gpu = torch.cuda.device_count()

    set_seed(args)

    args.model_type = args.model_type.lower()
    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    tokenizer = tokenizer_class.from_pretrained(args.model_type)
    model = model_class.from_pretrained(args.model_name_or_path)
    model.to(args.device)
    model.eval()

    if args.length < 0 and model.config.max_position_embeddings > 0:
        args.length = model.config.max_position_embeddings
    elif 0 < model.config.max_position_embeddings < args.length:
        args.length = model.config.max_position_embeddings  # No generation bigger than model size
    elif args.length < 0:
        args.length = MAX_LENGTH  # avoid infinite loop

    logger.info(args)
    val_df = pd.read_csv("datasets/misinfo.csv")
    val_df = val_df[val_df["split"] == "val"]
    for i in range(args.num_samples):
        val_df["sample-" + str(i)] = ""

    for i, row in tqdm(val_df.iterrows()):
        context_tokens = tokenizer.encode(row.title, add_special_tokens=False)
        out = sample_sequence(
            model=model,
            context=context_tokens,
            num_samples=args.num_samples,
            length=args.length,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
            repetition_penalty=args.repetition_penalty,
            is_xlnet=bool(args.model_type == "xlnet"),
            is_xlm_mlm=False,
            xlm_mask_token=None,
            xlm_lang=None,
            device=args.device,
        )

        out = out[:, len(context_tokens):].tolist()
        for j, o in enumerate(out):
            text = tokenizer.decode(o, clean_up_tokenization_spaces=True)
            # Truncate at the stop token, if one was given.
            text = text[:text.find(args.stop_token) if args.stop_token else None]
            val_df.at[i, "sample-" + str(j)] = text
    val_df.to_csv("datasets/val_samples.csv")
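
sample_sequence itself never appears in these excerpts, but every example passes it top_p=0.9. A minimal sketch of the top-p (nucleus) filtering step it presumably applies per token, modeled on the well-known run_generation recipe; an illustration of the technique, not the repos' actual code. Shown for a 1-D logits vector for simplicity.

import torch
import torch.nn.functional as F

def top_p_filter(logits: torch.Tensor, top_p: float = 0.9) -> torch.Tensor:
    """Keep the smallest set of tokens whose cumulative probability
    exceeds top_p; mask everything else to -inf."""
    sorted_logits, sorted_idx = torch.sort(logits, descending=True)
    cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    # Mark tokens outside the nucleus, shifting right by one so the
    # first token that crosses the threshold is still kept.
    remove = cum_probs > top_p
    remove[1:] = remove[:-1].clone()
    remove[0] = False
    logits[sorted_idx[remove]] = -float("inf")
    return logits

# One sampling step: filter, renormalize, draw the next token id.
# logits = model(input_ids)[0][0, -1, :]  # hypothetical forward pass
# next_id = torch.multinomial(F.softmax(top_p_filter(logits), dim=-1), 1)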