def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path, transfo_xl_config_file,
                                             pytorch_dump_folder_path, transfo_xl_dataset_file):
    if transfo_xl_dataset_file:
        # Convert a pre-processed corpus (see original TensorFlow repo)
        with open(transfo_xl_dataset_file, "rb") as fp:
            corpus = pickle.load(fp, encoding="latin1")
        # Save vocabulary and dataset cache as dictionaries (should be better than pickles for the long term)
        pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_FILES_NAMES['pretrained_vocab_file']
        print("Save vocabulary to {}".format(pytorch_vocab_dump_path))
        corpus_vocab_dict = corpus.vocab.__dict__
        torch.save(corpus_vocab_dict, pytorch_vocab_dump_path)

        corpus_dict_no_vocab = corpus.__dict__
        corpus_dict_no_vocab.pop('vocab', None)
        pytorch_dataset_dump_path = pytorch_dump_folder_path + '/' + CORPUS_NAME
        print("Save dataset to {}".format(pytorch_dataset_dump_path))
        torch.save(corpus_dict_no_vocab, pytorch_dataset_dump_path)

    if tf_checkpoint_path:
        # Convert a pre-trained TensorFlow model
        config_path = os.path.abspath(transfo_xl_config_file)
        tf_path = os.path.abspath(tf_checkpoint_path)
        print("Converting Transformer XL checkpoint from {} with config at {}".format(tf_path, config_path))

        # Initialise the PyTorch model
        if transfo_xl_config_file == "":
            config = TransfoXLConfig()
        else:
            config = TransfoXLConfig.from_json_file(transfo_xl_config_file)
        print("Building PyTorch model from configuration: {}".format(str(config)))
        model = TransfoXLLMHeadModel(config)

        model = load_tf_weights_in_transfo_xl(model, config, tf_path)

        # Save the PyTorch model
        pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME)
        pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME)
        print("Save PyTorch model to {}".format(os.path.abspath(pytorch_weights_dump_path)))
        torch.save(model.state_dict(), pytorch_weights_dump_path)
        print("Save configuration file to {}".format(os.path.abspath(pytorch_config_dump_path)))
        with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
            f.write(config.to_json_string())
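# A minimal sketch of a command-line entry point for the converter above.
# The flag names and defaults here are illustrative assumptions, not
# necessarily those of the original conversion script.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--tf_checkpoint_path", default="", type=str,
                        help="Path to the TensorFlow checkpoint to convert (optional).")
    parser.add_argument("--transfo_xl_config_file", default="", type=str,
                        help="Path to the TransfoXLConfig JSON file; empty uses the default config.")
    parser.add_argument("--pytorch_dump_folder_path", required=True, type=str,
                        help="Folder that receives the converted PyTorch weights and config.")
    parser.add_argument("--transfo_xl_dataset_file", default="", type=str,
                        help="Path to a pickled pre-processed corpus to convert (optional).")
    args = parser.parse_args()
    convert_transfo_xl_checkpoint_to_pytorch(args.tf_checkpoint_path,
                                             args.transfo_xl_config_file,
                                             args.pytorch_dump_folder_path,
                                             args.transfo_xl_dataset_file)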
def create_transfo_xl_lm_head(self, config, input_ids_1, input_ids_2, lm_labels):
    model = TransfoXLLMHeadModel(config)
    model.eval()

    lm_logits_1, mems_1 = model(input_ids_1)
    loss_1, _, mems_1 = model(input_ids_1, labels=lm_labels)
    lm_logits_2, mems_2 = model(input_ids_2, mems=mems_1)
    loss_2, _, mems_2 = model(input_ids_2, labels=lm_labels, mems=mems_1)

    outputs = {
        "loss_1": loss_1,
        "mems_1": mems_1,
        "lm_logits_1": lm_logits_1,
        "loss_2": loss_2,
        "mems_2": mems_2,
        "lm_logits_2": lm_logits_2,
    }
    return outputs
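# A hedged usage sketch of how a test might consume the dict returned above.
# The shape checks are illustrative assumptions about the model's output
# contract (logits over the vocabulary, one memory tensor per layer), not
# assertions copied from the original test suite.
def check_transfo_xl_lm_head_output(self, result, config, batch_size, seq_length):
    # Logits cover the full vocabulary at every position
    self.parent.assertEqual(list(result["lm_logits_1"].size()),
                            [batch_size, seq_length, config.vocab_size])
    # Memories: one cached hidden-state tensor per layer, reusable across segments
    self.parent.assertEqual(len(result["mems_1"]), config.n_layer)
    self.parent.assertEqual(len(result["mems_2"]), config.n_layer)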
def main():
    parser = argparse.ArgumentParser()
    add_dict_options(parser, ARGS)
    args = parser.parse_args()
    set_seed(args.seed)

    prefix_sampler = torch.load(args.prefix_file)
    tokenizer = TransfoXLTokenizer.from_pretrained(args.transfo_model)
    model = TransfoXLLMHeadModel.from_pretrained(args.transfo_model)
    # Load fine-tuned weights; map_location keeps tensors on CPU until .cuda()
    model.load_state_dict(torch.load(args.resume, map_location=lambda s, l: s))
    model.cuda()

    sampler = SampleBatch(model, tokenizer, prefix_sampler)
    for _ in tqdm(range(args.num_samples)):
        print(sampler.simple_sample(pair=args.paired))
def main():
    parser = argparse.ArgumentParser()
    add_dict_options(parser, ARGS)
    args = parser.parse_args()
    set_seed(args.seed)

    if args.prefix_file:
        prefix_sampler = torch.load(args.prefix_file)

    # Pick the model family to sample from
    if args.transfo:
        tokenizer = TransfoXLTokenizer.from_pretrained(args.transfo_model)
        model = TransfoXLLMHeadModel.from_pretrained(args.transfo_model)
    elif args.bert:
        tokenizer = BertTokenizer.from_pretrained(args.bert_model)
        model = BertForMaskedLM.from_pretrained(args.bert_model)
    else:
        tokenizer = GPT2Tokenizer.from_pretrained(args.gpt2_model)
        model = GPT2LMHeadModel.from_pretrained(args.gpt2_model)
        init_sos(model)

    if args.resume:
        model.load_state_dict(torch.load(args.resume, map_location=lambda s, l: s))
    if not args.simple_sample:
        model = nn.DataParallel(model)
    model.cuda()

    if args.bert:
        # Masked-LM augmentation: randomly mask ~20% of the tokens in each
        # sentence, then let BERT regenerate the masked positions.
        text_batches = list(split(list(sys.stdin), 128))
        for text_batch in tqdm(text_batches, desc='Augmenting'):
            for _ in range(args.num_samples):
                mtext_batch = [' '.join('[MASK]' if (random.random() < 0.2 and '\t' not in x) else x
                                        for x in sent.split(' ')) for sent in text_batch]
                print('\n'.join(x.replace('[SEP]', '\t').strip()
                                for x in augment_texts(model, tokenizer, mtext_batch, max_len=args.msl)))
                sys.stdout.flush()
        return

    sample_batches = [SampleBatch(model, tokenizer, prefix_sampler)
                      for _ in range(args.num_buffers)]
    if args.simple_sample:
        for _ in tqdm(range(args.num_samples)):
            print(sample_batches[0].simple_sample(pair=args.paired, transfo=args.transfo))
            sys.stdout.flush()
        return

    n_output = 0
    pbar = tqdm(total=args.num_samples, desc='Generating')
    while n_output < args.num_samples:
        try:
            sample_batch = random.choice(sample_batches)
            sample_batch.try_add_sample()
            fin_texts = sample_batch.step(pair=args.paired)
        except ValueError:
            sample_batch.try_add_sample()
            continue
        for fin_text in fin_texts:
            if n_output >= args.num_samples:
                return
            print(fin_text.replace(EOS_TOKEN, '').replace('<eos>', '\t'))
            sys.stdout.flush()
            pbar.update(1)
            n_output += 1
        # Periodically rebalance the sample buffers
        if (n_output + 1) % args.balance_every == 0:
            pbar.set_postfix(dict(last_balance=n_output))
            SampleBatch.balance(sample_batches)
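# The BERT branch above chunks stdin into batches of 128 lines via a `split`
# helper. A minimal sketch of what such a helper could look like (an
# assumption -- the repo's own implementation may differ):
def split(items, n):
    """Yield successive chunks of at most n items from a list."""
    for i in range(0, len(items), n):
        yield items[i:i + n]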
def main():
    parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model')
    parser.add_argument('--model_name', type=str, default='transfo-xl-wt103',
                        help='pretrained model name')
    parser.add_argument('--split', type=str, default='test',
                        choices=['all', 'valid', 'test'],
                        help='which split to evaluate')
    parser.add_argument('--batch_size', type=int, default=10,
                        help='batch size')
    parser.add_argument('--tgt_len', type=int, default=128,
                        help='number of tokens to predict')
    parser.add_argument('--ext_len', type=int, default=0,
                        help='length of the extended context')
    parser.add_argument('--mem_len', type=int, default=1600,
                        help='length of the retained memory (previous hidden states)')
    parser.add_argument('--clamp_len', type=int, default=1000,
                        help='max positional embedding index')
    parser.add_argument('--no_cuda', action='store_true',
                        help='do not use CUDA even when it is available')
    parser.add_argument('--work_dir', type=str, required=True,
                        help='path to the work_dir')
    parser.add_argument('--no_log', action='store_true',
                        help='do not log the eval result')
    parser.add_argument('--same_length', action='store_true',
                        help='set same length attention with masking')
    parser.add_argument('--server_ip', type=str, default='',
                        help='can be used for distant debugging')
    parser.add_argument('--server_port', type=str, default='',
                        help='can be used for distant debugging')
    args = parser.parse_args()
    assert args.ext_len >= 0, 'extended context length must be non-negative'

    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

    device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    logger.info("device: {}".format(device))

    # Load a pre-processed dataset.
    # You can also build the corpus yourself using the TransfoXLCorpus methods.
    # The pre-processing involves computing word frequencies (to prepare the
    # adaptive input and softmax) and tokenizing the dataset. The pre-processed
    # corpus is a conversion of the original TensorFlow data (using the
    # conversion script).
    tokenizer = TransfoXLTokenizer.from_pretrained(args.model_name)
    corpus = TransfoXLCorpus.from_pretrained(args.model_name)
    ntokens = len(corpus.vocab)

    va_iter = corpus.get_iterator('valid', args.batch_size, args.tgt_len,
                                  device=device, ext_len=args.ext_len)
    te_iter = corpus.get_iterator('test', args.batch_size, args.tgt_len,
                                  device=device, ext_len=args.ext_len)

    # Load a pre-trained model
    model = TransfoXLLMHeadModel.from_pretrained(args.model_name)
    model = model.to(device)

    logger.info('Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}'.format(
        args.batch_size, args.tgt_len, args.ext_len, args.mem_len, args.clamp_len))

    model.reset_length(args.tgt_len, args.ext_len, args.mem_len)
    if args.clamp_len > 0:
        model.clamp_len = args.clamp_len
    if args.same_length:
        model.same_length = True

    ###############################################################################
    # Evaluation code
    ###############################################################################
    def evaluate(eval_iter):
        # Turn on evaluation mode, which disables dropout.
        model.eval()
        total_len, total_loss = 0, 0.
        start_time = time.time()
        with torch.no_grad():
            mems = None
            for idx, (data, target, seq_len) in enumerate(eval_iter):
                ret = model(data, target, mems)
                loss, mems = ret
                loss = loss.mean()
                # Weight each segment's loss by its length for a token-level average
                total_loss += seq_len * loss.item()
                total_len += seq_len
            total_time = time.time() - start_time
        logger.info('Time : {:.2f}s, {:.2f}ms/segment'.format(
            total_time, 1000 * total_time / (idx + 1)))
        return total_loss / total_len

    # Run on the requested split(s).
    if args.split == 'all':
        test_loss = evaluate(te_iter)
        valid_loss = evaluate(va_iter)
    elif args.split == 'valid':
        valid_loss = evaluate(va_iter)
        test_loss = None
    elif args.split == 'test':
        test_loss = evaluate(te_iter)
        valid_loss = None

    def format_log(loss, split):
        log_str = '| {0} loss {1:5.2f} | {0} ppl {2:9.3f} '.format(
            split, loss, math.exp(loss))
        return log_str

    log_str = ''
    if valid_loss is not None:
        log_str += format_log(valid_loss, 'valid')
    if test_loss is not None:
        log_str += format_log(test_loss, 'test')

    logger.info('=' * 100)
    logger.info(log_str)
    logger.info('=' * 100)
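# Sanity check for the loss -> perplexity conversion used in format_log above:
# evaluate() returns a token-weighted mean cross-entropy in nats, and
# perplexity is its exponential. The loss value below is illustrative only.
import math
mean_loss_nats = 3.20            # hypothetical mean token loss
ppl = math.exp(mean_loss_nats)   # ~24.5 perplexity
assert 24.0 < ppl < 25.0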
def main():
    def evaluate(data_source, split_encode=False):
        model.eval()
        total_loss = 0
        total_words = 0
        total_n = 0
        batch_idx = 0
        for batch in data_source:
            _, queries = batch
            try:
                queries, mask, total_chars, words = transfo_encode(
                    tokenizer, queries, sos_idx, split_encode=split_encode,
                    condition_model=args.conditioned_model)
            except KeyError:
                continue
            total_words += words
            mask = torch.Tensor(mask).cuda()
            queries = torch.LongTensor(queries).cuda()
            with torch.no_grad():
                # Next-token prediction: inputs and targets are shifted by one position
                output = model(queries[:, :-1])[0].permute(0, 2, 1)
                targets = queries[:, 1:]
                crit = criterion(output, targets)
                mask_tot = mask[:, 1:].sum()
                raw_loss = (crit * mask[:, 1:]).sum() / mask_tot
                loss = raw_loss
            total_loss += raw_loss.item() * mask_tot.item()
            total_n += total_chars
        cur_loss = total_loss / total_n
        elapsed = time.time() - start_time
        word_ppl = math.exp(total_loss / total_words)
        dual_print('-' * 89)
        dual_print('| end of epoch {:3d} | lr {:05.5f} | ms/batch {:5.2f} | '
                   'loss {:5.2f} | ppl {:8.2f} | bpc {:8.3f}'.format(
                       epoch, optimizer.param_groups[0]['lr'],
                       elapsed * 1000 / args.log_interval, cur_loss, word_ppl,
                       cur_loss / math.log(2)))
        dual_print('-' * 89)
        return cur_loss / math.log(2)

    parser = argparse.ArgumentParser()
    add_dict_options(parser, ARGS)
    args = parser.parse_args()
    set_seed(args.seed)

    sd = torch.load(args.cache_file)
    tokenizer = TransfoXLTokenizer.from_pretrained(args.transfo_model, cache_dir='transfo-model')
    model = TransfoXLLMHeadModel.from_pretrained(args.transfo_model, cache_dir='transfo-model')
    if args.reset:
        model.apply(model.init_weights)
    sos_idx = None
    if not args.use_sos:
        sos_idx = None

    train_ds, dev_ds, test_ds = sd['splits']
    criterion = nn.CrossEntropyLoss(reduction='none')
    train_loader = tud.DataLoader(train_ds, batch_size=args.train_batch_size,
                                  shuffle=True, drop_last=args.drop_last)
    dev_loader = tud.DataLoader(dev_ds, batch_size=args.eval_batch_size,
                                shuffle=False, drop_last=args.drop_last)
    test_loader = tud.DataLoader(test_ds, batch_size=args.eval_batch_size,
                                 shuffle=False, drop_last=args.drop_last)

    # No weight decay on bias parameters
    no_decay = ['bias']
    params = list(model.named_parameters())
    optimizer_grouped_parameters = [
        {'params': [p for n, p in params if not any(nd in n for nd in no_decay)],
         'weight_decay': args.weight_decay},
        {'params': [p for n, p in params if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0}
    ]
    num_train_optimization_steps = args.num_train_epochs * len(train_loader)
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, correct_bias=False)
    scheduler = WarmupLinearSchedule(optimizer,
                                     warmup_steps=int(args.warmup_proportion * num_train_optimization_steps),
                                     t_total=num_train_optimization_steps)

    if args.resume:
        model.load_state_dict(torch.load(args.resume, map_location=lambda s, l: s))
    if args.test_eval:
        # Interactive sampling loop
        while True:
            query = input("> ")
            print(sample_query(model, tokenizer, query))
    model = nn.DataParallel(model).cuda()

    start_time = time.time()
    best_bpc = 1000000
    epoch = 0  # defined up front so evaluate() can report it when training is skipped
    if not args.do_train:
        evaluate(test_loader, split_encode=False)
        return

    for epoch in range(args.num_train_epochs):
        epoch += 1
        total_loss = 0
        total_words = 0
        total_n = 0
        batch_idx = 0
        for batch in train_loader:
            model.train()
            _, queries = batch
            try:
                queries, mask, total_chars, words = transfo_encode(
                    tokenizer, queries, sos_idx, split_encode=args.split_encode,
                    condition_model=args.conditioned_model)
            except KeyError:
                dual_print('Skipped batch')
                continue
            total_words += words
            mask = torch.Tensor(mask).cuda()
            queries = torch.LongTensor(queries).cuda()
            optimizer.zero_grad()
            output = model(queries[:, :-1])[0].permute(0, 2, 1)
            targets = queries[:, 1:]
            crit = criterion(output, targets)
            mask_tot = mask[:, 1:].sum()
            # Average the per-token cross-entropy over non-padding positions only
            raw_loss = (crit * mask[:, 1:]).sum() / mask_tot
            loss = raw_loss
            loss.backward()
            optimizer.step()
            scheduler.step()

            total_loss += raw_loss.item() * mask_tot.item()
            total_n += total_chars
            if batch_idx % args.log_interval == 0 and batch_idx > 0:
                cur_loss = total_loss / total_n
                word_ppl = math.exp(total_loss / total_words)
                total_words = 0
                elapsed = time.time() - start_time
                dual_print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:05.5f} | ms/batch {:5.2f} | '
                           'loss {:5.2f} | ppl {:8.2f} | bpc {:8.3f}'.format(
                               epoch, batch_idx, len(train_loader), optimizer.param_groups[0]['lr'],
                               elapsed * 1000 / args.log_interval, cur_loss, word_ppl,
                               cur_loss / math.log(2)))
                total_loss = 0
                total_n = 0
                start_time = time.time()
            batch_idx += 1
        bpc = evaluate(dev_loader)
        if bpc < best_bpc:
            best_bpc = bpc
            # Unwrap DataParallel before saving so the checkpoint loads without it
            torch.save(model.module.state_dict(), args.save)
    evaluate(test_loader)
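# The training loop reports bpc as cur_loss / math.log(2): cross-entropy is
# accumulated in nats and normalized by character count (total_n), and
# dividing nats by ln 2 converts it to bits per character. Illustrative check:
import math
nats_per_char = 1.10                      # hypothetical character-level loss
bpc = nats_per_char / math.log(2)         # ~1.587 bits per character
assert 1.58 < bpc < 1.60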