def get_sample(model, prompt, length: int, num_samples: int, allow_linebreak: bool):
    """Sample `num_samples` continuations of `prompt` and trim each one back to
    the last sentence-ending punctuation (preferring one followed by a newline).

    Relies on module-level `tokenizer`, `logger`, `device` and `sample_sequence`.
    Returns a list of `num_samples` trimmed continuation strings.
    """
    logger.info("*" * 200)
    logger.info(prompt)
    # Token ids that generation must avoid: a bracket/paren opener set, plus the
    # newline token when line breaks are disallowed.
    newline_ids = tokenizer.encode('\n')[-1:]
    banned_single = [1] + tokenizer.encode('[')[-1:] + tokenizer.encode('(')[-1:]
    if not allow_linebreak:
        banned_single = banned_single + newline_ids
    context_tokens = tokenizer.encode(prompt)
    out = sample_sequence(
        model=model,
        context=context_tokens,
        length=length,
        temperature=1,
        top_k=0,
        top_p=0.9,
        device=device,
        filter_single=banned_single,
        filter_double=newline_ids,
        num_samples=num_samples,
    ).to('cpu')
    # Strip the (re-decoded) prompt prefix from every sampled row.
    decoded_prompt = tokenizer.decode(context_tokens)
    prompt_chars = len(decoded_prompt)
    texts = [tokenizer.decode(row.tolist())[prompt_chars:] for row in out]
    # Prefer a cut at ".!?" followed by newline; fall back to any ".!?"; else keep all.
    result = []
    for candidate in texts:
        sentence_nl = re.match(r'[\w\W]*[\.!?]\n', candidate)
        sentence = re.match(r'[\w\W]*[\.!?]', candidate)
        if sentence_nl:
            result.append(sentence_nl[0])
        elif sentence:
            result.append(sentence[0])
        else:
            result.append(candidate)
    logger.info("=" * 200)
    logger.info(result)
    return result
def sample(self, prompt: str, length: int, num_samples: int = 1, allow_linebreak: bool = True, stop_token: int = -1):
    """Sample continuations of `prompt`, trimmed to the last sentence boundary.

    When `stop_token` is -1 the fixed-length sampler is used; otherwise
    generation stops early at `stop_token`. Returns a list of strings.
    """
    # Ids the sampler must not emit: bracket/paren openers, plus newline when
    # line breaks are disallowed.
    newline_ids = self.tokenizer.encode('\n')[-1:]
    banned_single = [1] + self.tokenizer.encode('[')[-1:] + self.tokenizer.encode('(')[-1:]
    if not allow_linebreak:
        banned_single = banned_single + newline_ids
    context_tokens = self.tokenizer.encode(prompt)
    # Both samplers share the same keyword set; only the stop-token variant differs.
    common_kwargs = dict(
        model=self.model,
        context=context_tokens,
        length=length,
        temperature=self.temperature,
        top_k=self.top_k,
        top_p=self.top_p,
        device=self.device,
        filter_single=banned_single,
        filter_double=newline_ids,
        num_samples=num_samples,
    )
    if stop_token == -1:
        out = sample_sequence(**common_kwargs).to('cpu')
    else:
        out = sample_sequence_until_token(stop_token=stop_token, **common_kwargs).to('cpu')
    # Strip the (re-decoded) prompt prefix from each sampled row.
    decoded_prompt = self.tokenizer.decode(context_tokens)
    prompt_chars = len(decoded_prompt)
    texts = [self.tokenizer.decode(row.tolist())[prompt_chars:] for row in out]
    # Prefer a cut at ".!?" followed by newline; fall back to any ".!?"; else keep all.
    trimmed = []
    for candidate in texts:
        sentence_nl = re.match(r'[\w\W]*[\.!?]\n', candidate)
        sentence = re.match(r'[\w\W]*[\.!?]', candidate)
        if sentence_nl:
            trimmed.append(sentence_nl[0])
        elif sentence:
            trimmed.append(sentence[0])
        else:
            trimmed.append(candidate)
    return trimmed
def print_sample(model, tokenizer, device, args):
    """Generate a 500-token sample from a near-empty prompt and print it.

    Puts the model in eval mode; `args` is currently unused (kept for a
    uniform call signature across script variants).
    """
    model.eval()
    raw_text = """ """
    context_tokens = tokenizer.encode(raw_text)
    generated = sample_sequence(
        model=model,
        context=context_tokens,
        length=500,
        temperature=1,
        top_k=0,
        top_p=0.9,
        device=device,
        # is_xlnet=bool(args.model_type == "xlnet"),
    )
    # Keep only the newly generated tokens (drop the echoed prompt).
    continuation = generated[0, len(context_tokens):].tolist()
    print(raw_text + tokenizer.decode(continuation))
def print_sample(model, tokenizer, device):
    """Generate a 500-token continuation of a fixed Russian prompt and print it.

    Temporarily switches the model to eval mode and restores train mode before
    returning, so it is safe to call mid-training.
    """
    model.eval()
    raw_text = """ На словах ты Лев Толстой,\n А на деле -"""
    context_tokens = tokenizer.encode(raw_text)
    generated = sample_sequence(
        model=model,
        context=context_tokens,
        length=500,
        temperature=1,
        top_k=0,
        top_p=0.9,
        device=device,
        # is_xlnet=bool(args.model_type == "xlnet"),
    )
    # Keep only the newly generated tokens (drop the echoed prompt).
    continuation = generated[0, len(context_tokens):].tolist()
    print(raw_text + tokenizer.decode(continuation))
    model.train()
def print_sample(model, tokenizer, device, args):
    """Generate a 500-token sample, print it, and save it to args.output_dir/sample.txt.

    Temporarily switches the model to eval mode and restores train mode before
    returning, so it is safe to call mid-training.
    """
    model.eval()
    raw_text = """ Хорошее утро """
    context_tokens = tokenizer.encode(raw_text)
    out = sample_sequence(
        model=model,
        context=context_tokens,
        length=500,
        temperature=1,
        top_k=0,
        top_p=0.9,
        device=device,
        # is_xlnet=bool(args.model_type == "xlnet"),
    )
    # Keep only the newly generated tokens (drop the echoed prompt).
    out = out[0, len(context_tokens):].tolist()
    text = raw_text + tokenizer.decode(out)
    print(text)
    # BUG FIX: encoding must be explicit — the sample contains Cyrillic text, and
    # the platform-default encoding (e.g. cp1252) would raise UnicodeEncodeError.
    with open(os.path.join(args.output_dir, 'sample.txt'), 'w', encoding='utf-8') as f:
        f.write(text)
    model.train()
def print_sample(model, tokenizer, device, args):
    """Generate a 500-token sample, log it, and save it to args.output_dir/sample.txt.

    TPU-aware variant: only the master ordinal writes the file, avoiding
    concurrent writes from multiple workers. Restores train mode on exit.
    """
    model.eval()
    raw_text = """ На словах ты Лев Толстой,\n А на деле -"""
    context_tokens = tokenizer.encode(raw_text)
    out = sample_sequence(model=model,
                          context=context_tokens,
                          length=500,
                          temperature=1,
                          top_k=0,
                          top_p=0.9,
                          device=device,
                          max_input=0
                          # is_xlnet=bool(args.model_type == "xlnet"),
                          )
    # Keep only the newly generated tokens (drop the echoed prompt).
    out = out[0, len(context_tokens):].tolist()
    text = raw_text + tokenizer.decode(out)
    log_info(text)
    if xm.is_master_ordinal():
        # BUG FIX: encoding must be explicit — the text is Cyrillic, and the
        # platform-default encoding (e.g. cp1252) would raise UnicodeEncodeError.
        with open(os.path.join(args.output_dir, 'sample.txt'), 'w', encoding='utf-8') as f:
            f.write(text)
    model.train()
def get_sample(prompt, model, tokenizer, device):
    """Generate one 150-token continuation of `prompt`, trimmed at the last
    sentence-ending punctuation that is followed by a newline (if any).

    Moves the model to `device` and switches it to eval mode as a side effect.
    Returns the (possibly trimmed) continuation string.
    """
    logger.info("*" * 200)
    logger.info(prompt)
    model.to(device)
    model.eval()
    newline_ids = tokenizer.encode('\n')[-1:]
    context_tokens = tokenizer.encode(prompt)
    generated = sample_sequence(
        model=model,
        context=context_tokens,
        length=150,
        temperature=1,
        top_k=0,
        top_p=0.9,
        device=device,
        filter_double=newline_ids,
    )
    # Keep only the newly generated tokens (drop the echoed prompt).
    continuation = generated[0, len(context_tokens):].tolist()
    text = tokenizer.decode(continuation)
    # Cut at the last ".!?" followed by a newline; keep everything otherwise.
    boundary = re.match(r'[\w\W]*[\.!?]\n', text)
    if boundary:
        text = boundary[0]
    logger.info("=" * 200)
    logger.info(text)
    return text
def main():
    """CLI entry point: sample continuations for every validation-set title in
    datasets/misinfo.csv and write them to datasets/val_samples.csv."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
                        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
    parser.add_argument("--model_type", default="gpt2", type=str,
                        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
    parser.add_argument("--prompt", type=str, default="")
    parser.add_argument("--padding_text", type=str, default="")
    parser.add_argument("--xlm_lang", type=str, default="",
                        help="Optional language when used with the XLM model.")
    parser.add_argument("--length", type=int, default=100)
    parser.add_argument("--num_samples", type=int, default=1)
    parser.add_argument("--temperature", type=float, default=1.0,
                        help="temperature of 0 implies greedy sampling")
    parser.add_argument("--repetition_penalty", type=float, default=1.0,
                        help="primarily useful for CTRL model; in that case, use 1.2")
    parser.add_argument("--top_k", type=int, default=0)
    parser.add_argument("--top_p", type=float, default=0.9)
    parser.add_argument("--no_cuda", action='store_true',
                        help="Avoid using CUDA when available")
    parser.add_argument('--seed', type=int, default=42,
                        help="random seed for initialization")
    parser.add_argument('--stop_token', type=str, default='<|endoftext|>',
                        help="Token at which text generation is stopped")
    args = parser.parse_args()

    args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    args.n_gpu = torch.cuda.device_count()
    set_seed(args)

    args.model_type = args.model_type.lower()
    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    # NOTE(review): tokenizer is loaded from the model *type* shortcut, not
    # model_name_or_path — confirm it matches the fine-tuned model's vocab.
    tokenizer = tokenizer_class.from_pretrained(args.model_type)
    model = model_class.from_pretrained(args.model_name_or_path)
    model.to(args.device)
    model.eval()

    # Clamp the requested length to what the model supports.
    if args.length < 0 and model.config.max_position_embeddings > 0:
        args.length = model.config.max_position_embeddings
    elif 0 < model.config.max_position_embeddings < args.length:
        args.length = model.config.max_position_embeddings  # No generation bigger than model size
    elif args.length < 0:
        args.length = MAX_LENGTH  # avoid infinite loop

    logger.info(args)

    val_df = pd.read_csv("datasets/misinfo.csv")
    val_df = val_df[val_df["split"] == "val"]
    for i in range(args.num_samples):
        val_df["sample-" + str(i)] = ""

    for i, row in tqdm(val_df.iterrows()):
        context_tokens = tokenizer.encode(row.title, add_special_tokens=False)
        out = sample_sequence(
            model=model,
            context=context_tokens,
            num_samples=args.num_samples,
            length=args.length,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
            repetition_penalty=args.repetition_penalty,
            is_xlnet=bool(args.model_type == "xlnet"),
            is_xlm_mlm=False,
            xlm_mask_token=None,
            xlm_lang=None,
            device=args.device,
        )
        # Drop the echoed prompt tokens from every sample.
        out = out[:, len(context_tokens):].tolist()
        for j, o in enumerate(out):
            text = tokenizer.decode(o, clean_up_tokenization_spaces=True)
            # BUG FIX: str.find returns -1 when the stop token is absent, and the
            # old slice text[:find(...)] then silently dropped the LAST character
            # (text[:-1]). Only truncate when the stop token actually occurs.
            if args.stop_token:
                stop_idx = text.find(args.stop_token)
                if stop_idx >= 0:
                    text = text[:stop_idx]
            val_df.at[i, "sample-" + str(j)] = text
    val_df.to_csv("datasets/val_samples.csv")