def test_text_abs(args, device_id, pt, step):
    """Decode the test split with an abstractive summarizer checkpoint.

    Loads the checkpoint (``pt`` if non-empty, else ``args.test_from``),
    restores whitelisted model flags onto ``args``, builds the test
    dataloader and a bert-base-uncased tokenizer, then runs the predictor.
    """
    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    test_from = pt if pt != '' else args.test_from
    logger.info('Loading checkpoint from %s' % test_from)
    ckpt = torch.load(test_from, map_location=lambda storage, loc: storage)

    # Re-apply the architecture-defining flags the checkpoint was trained with.
    for flag, value in vars(ckpt['opt']).items():
        if flag in model_flags:
            setattr(args, flag, value)
    print(args)

    model = AbsSummarizer(args, device, ckpt)
    model.eval()

    test_iter = data_loader.Dataloader(args,
                                       load_dataset(args, 'test', shuffle=False),
                                       args.test_batch_size, device,
                                       shuffle=False, is_test=True)
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased',
                                              do_lower_case=True,
                                              cache_dir=args.temp_dir)
    vocab = tokenizer.vocab
    # Special ids used by the decoder: begin/end of sequence, padding, end of question.
    symbols = {'BOS': vocab['[unused0]'], 'EOS': vocab['[unused1]'],
               'PAD': vocab['[PAD]'], 'EOQ': vocab['[unused2]']}

    predictor = build_predictor(args, tokenizer, symbols, model, logger)
    predictor.translate(test_iter, step)
def test_text_abs(args, device_id, pt, step):
    """Decode the test split with an abstractive summarizer checkpoint,
    using the bert-base-chinese tokenizer for output detokenization.
    """
    device = "cuda" if args.visible_gpus != "-1" else "cpu"
    if pt != "":
        test_from = pt
    else:
        test_from = args.test_from
    logger.info("Loading checkpoint from %s" % test_from)

    checkpoint = torch.load(test_from, map_location=lambda storage, loc: storage)
    # Copy architecture flags saved in the checkpoint back onto args.
    saved_opt = vars(checkpoint["opt"])
    for key in saved_opt.keys():
        if key in model_flags:
            setattr(args, key, saved_opt[key])
    print(args)

    model = AbsSummarizer(args, device, checkpoint)
    model.eval()

    test_iter = data_loader.Dataloader(
        args,
        load_dataset(args, "test", shuffle=False),
        args.test_batch_size,
        device,
        shuffle=False,
        is_test=True,
    )
    tokenizer = BertTokenizer.from_pretrained("bert-base-chinese",
                                              do_lower_case=True,
                                              cache_dir=args.temp_dir)
    # Control-token ids expected by the beam-search predictor.
    symbols = dict(BOS=tokenizer.vocab["[unused0]"],
                   EOS=tokenizer.vocab["[unused1]"],
                   PAD=tokenizer.vocab["[PAD]"],
                   EOQ=tokenizer.vocab["[unused2]"])

    predictor = build_predictor(args, tokenizer, symbols, model, logger)
    predictor.translate(test_iter, step)
def validate(args, device_id, pt, step):
    """Evaluate an ExtSummarizer checkpoint on the 'dev' split.

    Returns:
        float: validation cross-entropy (``stats.xent()``).
    """
    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    test_from = pt if pt != '' else args.test_from
    logger.info('Loading checkpoint from %s' % test_from)

    ckpt = torch.load(test_from, map_location=lambda storage, loc: storage)
    # Restore model-defining flags recorded at training time.
    for flag, value in vars(ckpt['opt']).items():
        if flag in model_flags:
            setattr(args, flag, value)
    print(args)

    model = ExtSummarizer(args, device, ckpt)
    model.eval()

    valid_iter = data_loader.Dataloader(args,
                                        load_dataset(args, 'dev', shuffle=False),
                                        args.batch_size, device,
                                        shuffle=False, is_test=False)
    trainer = build_trainer(args, device_id, model, None)
    stats = trainer.validate(valid_iter, step)
    return stats.xent()
def test_abs(args, device_id, pt, step, tokenizer):
    """Decode the test split with an abstractive checkpoint, using a
    caller-supplied tokenizer and the module-level ``special_token`` map
    to resolve the BOS/EOS/PAD/EOQ control-token ids.
    """
    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    test_from = pt if pt != '' else args.test_from
    logger.info('Loading checkpoint from %s' % test_from)

    ckpt = torch.load(test_from, map_location=lambda storage, loc: storage)
    # Re-apply whitelisted flags stored in the checkpoint.
    for flag, value in vars(ckpt['opt']).items():
        if flag in model_flags:
            setattr(args, flag, value)
    print(args)

    model = AbsSummarizer(args, device, ckpt)
    model.eval()

    test_iter = data_loader.Dataloader(args,
                                       load_dataset(args, 'test', shuffle=False),
                                       args.test_batch_size, device,
                                       shuffle=False, is_test=True)
    # Map symbolic names through the project-level special_token table,
    # then through the tokenizer vocabulary to concrete ids.
    symbols = {name: tokenizer.vocab[special_token[tok]]
               for name, tok in (('BOS', '[BOS]'), ('EOS', '[EOS]'),
                                 ('PAD', '[PAD]'), ('EOQ', '[QOS]'))}

    predictor = build_predictor(args, tokenizer, symbols, model, logger)
    predictor.translate(test_iter, step)
def test_abs(args, device_id, pt, step):
    """Decode the test split with an abstractive summarizer checkpoint.

    Uses the same tokenizer as the data-preprocessing pipeline
    (``BertData(args).tokenizer``) so decoding ids match the prepared data.

    Fix: removed a large block of commented-out alternate tokenizer setup
    (dead code) that obscured the live logic; behavior is unchanged.
    """
    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    if pt != '':
        test_from = pt
    else:
        test_from = args.test_from
    logger.info('Loading checkpoint from %s' % test_from)
    checkpoint = torch.load(test_from, map_location=lambda storage, loc: storage)

    # Restore architecture-defining flags saved with the checkpoint.
    opt = vars(checkpoint['opt'])
    for k in opt.keys():
        if k in model_flags:
            setattr(args, k, opt[k])
    print(args)

    model = AbsSummarizer(args, device, checkpoint)
    model.eval()

    test_iter = data_loader.Dataloader(args,
                                       load_dataset(args, 'test', shuffle=False),
                                       args.test_batch_size, device,
                                       shuffle=False, is_test=True)
    # Reuse the preprocessing tokenizer so the decode vocabulary matches data prep.
    tokenizer = BertData(args).tokenizer
    symbols = {'BOS': tokenizer.convert_tokens_to_ids('[unused0]'),
               'EOS': tokenizer.convert_tokens_to_ids('[unused1]'),
               'PAD': tokenizer.convert_tokens_to_ids('[PAD]'),
               'EOQ': tokenizer.convert_tokens_to_ids('[unused2]')}

    predictor = build_predictor(args, tokenizer, symbols, model, logger)
    predictor.translate(test_iter, step)
def validate(args, device_id, pt, step):
    """Evaluate an AbsSummarizer checkpoint on the 'valid' split and
    return the validation cross-entropy (``stats.xent()``).

    Chooses between a BART tokenizer (``args.bart``) and a multilingual
    BERT tokenizer, and builds the matching control-symbol id table.
    """
    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    if (pt != ''):
        test_from = pt
    else:
        test_from = args.test_from
    logger.info('Loading checkpoint from %s' % test_from)
    # map_location keeps tensors on CPU regardless of where they were saved.
    checkpoint = torch.load(test_from, map_location=lambda storage, loc: storage)
    # Re-apply whitelisted architecture flags stored with the checkpoint.
    opt = vars(checkpoint['opt'])
    for k in opt.keys():
        if (k in model_flags):
            setattr(args, k, opt[k])
    print(args)
    model = AbsSummarizer(args, device, checkpoint)
    model.eval()

    valid_iter = data_loader.Dataloader(args, load_dataset(args, 'valid', shuffle=False),
                                        args.batch_size, device,
                                        shuffle=False, is_test=False)
    if args.bart:
        # NOTE(review): hard-coded user-local path; this will only work on the
        # original author's machine — should come from args. TODO confirm/fix.
        tokenizer = AutoTokenizer.from_pretrained('/home/ybai/downloads/bart',
                                                  do_lower_case=True,
                                                  cache_dir=args.temp_dir,
                                                  local_files_only=False)
        # BART-style tokenizers expose ids via .encoder; the 'madeupword'
        # entries serve as spare control tokens.
        symbols = {'BOS': tokenizer.encoder['madeupword0000'],
                   'EOS': tokenizer.encoder['madeupword0001'],
                   'PAD': tokenizer.encoder['<pad>'],
                   'EOQ': tokenizer.encoder['madeupword0002']}
    else:
        tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-uncased',
                                                  do_lower_case=True,
                                                  cache_dir=args.temp_dir,
                                                  local_files_only=True)
        symbols = {'BOS': tokenizer.vocab['[unused1]'],
                   'EOS': tokenizer.vocab['[unused2]'],
                   'PAD': tokenizer.vocab['[PAD]'],
                   'EOQ': tokenizer.vocab['[unused3]']}

    valid_loss = abs_loss(model.generator, symbols, model.vocab_size,
                          train=False, device=device)
    trainer = build_trainer(args, device_id, model, None, valid_loss)
    stats = trainer.validate(valid_iter, step)
    return stats.xent()
def validate(args, device_id, pt, step):
    """Evaluate an AbsSummarizer checkpoint on the 'valid' split.

    Symbols and tokenizer come from ``get_symbol_and_tokenizer`` keyed on
    ``args.encoder``. Returns the validation cross-entropy.
    """
    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    test_from = pt if pt != '' else args.test_from
    logger.info('Loading checkpoint from %s' % test_from)

    ckpt = torch.load(test_from, map_location=lambda storage, loc: storage)
    # Restore whitelisted model flags from the checkpoint onto args.
    for flag, value in vars(ckpt['opt']).items():
        if flag in model_flags:
            setattr(args, flag, value)
    print(args)

    symbols, tokenizer = get_symbol_and_tokenizer(args.encoder, args.temp_dir)
    model = AbsSummarizer(args, device, ckpt, symbols=symbols)
    model.eval()

    valid_iter = data_loader.Dataloader(args,
                                        load_dataset(args, 'valid', shuffle=False),
                                        args.batch_size, device,
                                        shuffle=False, is_test=False,
                                        tokenizer=tokenizer)
    valid_loss = abs_loss(model.generator, symbols, model.vocab_size,
                          train=False, device=device)
    trainer = build_trainer(args, device_id, model, None, valid_loss)
    return trainer.validate(valid_iter, step).xent()
def validate(args, device_id, pt, step):
    """Evaluate an extractive Summarizer checkpoint on the 'valid' split,
    logging the run under a Comet experiment's test context.

    Returns:
        float: validation cross-entropy (``stats.xent()``).
    """
    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    if (pt != ''):
        test_from = pt
    else:
        test_from = args.test_from
    logger.info('Loading checkpoint from %s' % test_from)
    # Load onto CPU first; the model moves itself to `device`.
    checkpoint = torch.load(test_from, map_location=lambda storage, loc: storage)
    # Re-apply whitelisted flags recorded at training time.
    opt = vars(checkpoint['opt'])
    for k in opt.keys():
        if (k in model_flags):
            setattr(args, k, opt[k])
    print(args)

    # Model is built from an explicit BERT config (no pretrained download),
    # then the checkpoint weights are loaded.
    config = BertConfig.from_json_file(args.bert_config_path)
    model = Summarizer(args, device, load_pretrained_bert=False, bert_config=config)
    model.load_cp(checkpoint)
    model.eval()

    valid_iter = data_loader.Dataloader(args, load_dataset(args, 'valid', shuffle=False),
                                        args.batch_size, device,
                                        shuffle=False, is_test=False)
    trainer = build_trainer(args, device_id, model, None)
    # comet_experiment.log_parameters(config)
    # Run validation under Comet's test context so metrics are tagged as "test".
    with comet_experiment.test():
        stats = trainer.validate(valid_iter, step)
    return stats.xent()
def test_text_abs(args, device_id, pt, step):
    """Decode the test split with an abstractive summarizer checkpoint,
    selecting a BART or multilingual-BERT tokenizer via ``args.bart``.

    Fix: removed commented-out dead code (an alternate hard-coded local
    BART path); behavior is unchanged.
    """
    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    if pt != '':
        test_from = pt
    else:
        test_from = args.test_from
    logger.info('Loading checkpoint from %s' % test_from)
    checkpoint = torch.load(test_from, map_location=lambda storage, loc: storage)

    # Restore architecture-defining flags saved with the checkpoint.
    opt = vars(checkpoint['opt'])
    for k in opt.keys():
        if k in model_flags:
            setattr(args, k, opt[k])
    print(args)

    model = AbsSummarizer(args, device, checkpoint)
    model.eval()

    test_iter = data_loader.Dataloader(args,
                                       load_dataset(args, 'test', shuffle=False),
                                       args.test_batch_size, device,
                                       shuffle=False, is_test=True)

    # Extra special tokens registered for Chinese tokenization.
    add_token_list = ['[unused1]', '[unused2]', '[unused3]', '[unused4]', '[unused5]']
    if args.bart:
        tokenizer = AutoTokenizer.from_pretrained('bart-base', do_lower_case=True,
                                                  cache_dir=args.temp_dir,
                                                  local_files_only=False)
        # BART tokenizers expose ids via .encoder; 'madeupword' entries are
        # spare control tokens.
        symbols = {'BOS': tokenizer.encoder['madeupword0000'],
                   'EOS': tokenizer.encoder['madeupword0001'],
                   'PAD': tokenizer.encoder['<pad>'],
                   'EOQ': tokenizer.encoder['madeupword0002']}
    else:
        tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-uncased',
                                                  do_lower_case=True,
                                                  cache_dir=args.temp_dir,
                                                  local_files_only=False,
                                                  additional_special_tokens=add_token_list)
        symbols = {'BOS': tokenizer.vocab['[unused1]'],
                   'EOS': tokenizer.vocab['[unused2]'],
                   'PAD': tokenizer.vocab['[PAD]'],
                   'EOQ': tokenizer.vocab['[unused3]']}

    predictor = build_predictor(args, tokenizer, symbols, model, logger)
    predictor.translate(test_iter, step)
def train(self, train_iter_fct, train_steps, valid_iter_fct=None, valid_steps=-1):
    """
    The main training loop: iterates over training data (from
    `train_iter_fct`) and periodically runs validation on the 'test'
    split every `self.args.report_every` steps.

    Args:
        train_iter_fct(function): a function that returns the train
            iterator. e.g. something like
            train_iter_fct = lambda: generator(*args, **kwargs)
        train_steps(int): total number of optimization steps to run.
        valid_iter_fct(function): same as train_iter_fct, for valid data
            (unused here — validation data is rebuilt inline; TODO confirm
            this is intended).
        valid_steps(int): unused in this implementation.

    Return:
        Statistics: accumulated training statistics.
    """
    logger.info('Start training...')
    # step = self.optim._step + 1
    # Resume the step counter from the (first) optimizer's state.
    step = self.optims[0]._step + 1
    true_batchs = []
    accum = 0
    normalization = 0
    train_iter = train_iter_fct()

    total_stats = Statistics()
    report_stats = Statistics()
    self._start_report_manager(start_time=total_stats.start_time)

    while step <= train_steps:
        reduce_counter = 0
        for i, batch in enumerate(train_iter):
            # In multi-GPU mode each rank consumes every n_gpu-th batch.
            if self.n_gpu == 0 or (i % self.n_gpu == self.gpu_rank):
                true_batchs.append(batch)
                # Normalize the loss by the number of non-pad target tokens
                # (target is shifted by one: tgt[:, 1:]).
                num_tokens = batch.tgt[:, 1:].ne(self.loss.padding_idx).sum()
                normalization += num_tokens.item()
                accum += 1
                if accum == self.grad_accum_count:
                    reduce_counter += 1
                    if self.n_gpu > 1:
                        # Sum token counts across ranks so gradients are
                        # normalized by the global token count.
                        normalization = sum(
                            distributed.all_gather_list(normalization))

                    self._gradient_accumulation(true_batchs, normalization,
                                                total_stats, report_stats)

                    report_stats = self._maybe_report_training(
                        step, train_steps, self.optims[0].learning_rate,
                        report_stats)

                    # Periodic inline validation: rebuild the loader,
                    # tokenizer and loss, evaluate, then resume training.
                    if step % self.args.report_every == 0:
                        self.model.eval()
                        logger.info('Model in set eval state')
                        valid_iter = data_loader.Dataloader(
                            self.args,
                            load_dataset(self.args, 'test', shuffle=False),
                            self.args.batch_size, "cuda",
                            shuffle=False, is_test=True)
                        # NOTE(review): tokenizer is loaded from
                        # self.args.model_path — presumably a fine-tuned
                        # vocab dir; confirm it contains tokenizer files.
                        tokenizer = BertTokenizer.from_pretrained(
                            self.args.model_path, do_lower_case=True)
                        symbols = {
                            'BOS': tokenizer.vocab['[unused1]'],
                            'EOS': tokenizer.vocab['[unused2]'],
                            'PAD': tokenizer.vocab['[PAD]'],
                            'EOQ': tokenizer.vocab['[unused3]']
                        }
                        valid_loss = abs_loss(self.model.generator, symbols,
                                              self.model.vocab_size,
                                              train=False,
                                              device="cuda")
                        trainer = build_trainer(self.args, 0, self.model,
                                                None, valid_loss)
                        stats = trainer.validate(valid_iter, step)
                        self.report_manager.report_step(
                            self.optims[0].learning_rate, step,
                            train_stats=None, valid_stats=stats)
                        # Back to training mode after validation.
                        self.model.train()

                    true_batchs = []
                    accum = 0
                    normalization = 0
                    if (step % self.save_checkpoint_steps == 0
                            and self.gpu_rank == 0):
                        self._save(step)

                    step += 1
                    if step > train_steps:
                        break
        # Exhausted the iterator: start a fresh epoch.
        train_iter = train_iter_fct()

    return total_stats
def test_model(self, extractor, corpus_type='test', block_trigram=True, quick_test=False):
    """Evaluate a sentence extractor combined with this guided-abstractive model.

    Builds a test-mode args namespace, loads the abstractive checkpoint
    (``self.model_file``), constructs a predictor, and scores the
    extractor+predictor pipeline over ``corpus_type``.

    Returns:
        The average F1 reported by ``test_ext_abs``.
    """
    logger.info('Test SentExt model (%s) and GuidAbs model (%s) ...' %
                (extractor.name, self.model_file))
    testname = '%s_guidabs_%s' % (extractor.name,
                                  'blocktrigram' if block_trigram else 'noblocktrigram')

    # build args
    args = self._build_abs_args()
    args.mode = 'test'
    args.bert_data_path = path.join(self.data_path, 'cnndm')
    args.model_path = self.result_path
    args.log_file = path.join(self.result_path, 'test_varextabs.%s.log' % testname)
    args.result_path = path.join(self.result_path, 'cnndm_' + testname)
    args.block_trigram = block_trigram
    init_logger(args.log_file)

    # load abs model
    abs_model_file = self.model_file
    logger.info('Loading abs model %s' % abs_model_file)
    # Checkpoint files are named ..._<step>.pt — recover the step number.
    step_abs = int(abs_model_file.split('.')[-2].split('_')[-1])
    checkpoint = torch.load(abs_model_file,
                            map_location=lambda storage, loc: storage)
    model_abs = model_bld.AbsSummarizer(args, args.device, checkpoint)
    model_abs.eval()

    # init model testers
    tokenizer = BertTokenizer.from_pretrained(path.join(
        args.bert_model_path, model_abs.bert.model_name),
        do_lower_case=True, cache_dir=args.temp_dir)
    # Control-token ids for the beam-search predictor.
    symbols = {
        'BOS': tokenizer.vocab['[unused0]'],
        'EOS': tokenizer.vocab['[unused1]'],
        'PAD': tokenizer.vocab['[PAD]'],
        'EOQ': tokenizer.vocab['[unused2]']
    }
    predictor = pred_abs.build_predictor(args, tokenizer, symbols, model_abs, logger)

    # keep_order=True so predictions can be aligned back to their inputs.
    test_iter = data_ldr.Dataloader(args,
                                    data_ldr.load_dataset(args, corpus_type, shuffle=False),
                                    args.test_batch_size, args.device,
                                    shuffle=False, is_test=True, keep_order=True)
    logger.info('Generating Ext/GuidAbs results %s ...' % args.result_path)
    avg_f1 = test_ext_abs(logger, args, extractor, predictor, 0, step_abs,
                          test_iter, quick_test=quick_test)
    return avg_f1
def validate(args, device_id, pt, step):
    """Evaluate an AbsSummarizer checkpoint on the 'valid' split.

    If the tokenizer vocabulary has exactly 31748 entries, the vocab file
    on disk is patched once by appending [unused1]..[unused7] and the
    tokenizer is reloaded (side effect: mutates args.bert_model/vocab.txt).

    Fix: the vocab file is now opened with a context manager instead of a
    manual open()/close() pair, so the handle is closed even if write()
    raises; behavior is otherwise unchanged.

    Returns:
        float: validation cross-entropy (``stats.xent()``).
    """
    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    if pt != '':
        test_from = pt
    else:
        test_from = args.test_from
    logger.info('Loading checkpoint from %s' % test_from)
    checkpoint = torch.load(test_from, map_location=lambda storage, loc: storage)

    # Restore whitelisted model flags recorded at training time.
    opt = vars(checkpoint['opt'])
    for k in opt.keys():
        if k in model_flags:
            setattr(args, k, opt[k])
    print(args)

    valid_iter = data_loader.Dataloader(args,
                                        load_dataset(args, 'valid', shuffle=False),
                                        args.batch_size, device,
                                        shuffle=False, is_test=False)

    # Cased multilingual BERT must not be lower-cased; everything else is.
    if args.bert_model == 'bert-base-multilingual-cased':
        tokenizer = BertTokenizer.from_pretrained(
            'bert-base-multilingual-cased', do_lower_case=False,
            cache_dir=args.temp_dir)
    else:
        tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                                  do_lower_case=True,
                                                  cache_dir=args.temp_dir)
    print(len(tokenizer.vocab))
    if len(tokenizer.vocab) == 31748:
        # One-time on-disk patch: append the control tokens the decoder
        # needs, then reload the tokenizer to pick them up.
        with open(args.bert_model + "/vocab.txt", "a") as f:
            f.write(
                "\n[unused1]\n[unused2]\n[unused3]\n[unused4]\n[unused5]\n[unused6]\n[unused7]"
            )
        tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                                  do_lower_case=True)
        print(len(tokenizer.vocab))

    symbols = {
        'BOS': tokenizer.vocab['[unused1]'],
        'EOS': tokenizer.vocab['[unused2]'],
        'PAD': tokenizer.vocab['[PAD]'],
        'EOQ': tokenizer.vocab['[unused3]']
    }

    model = AbsSummarizer(args, device, checkpoint)
    model.eval()

    valid_loss = abs_loss(model.generator, symbols, model.vocab_size,
                          train=False, device=device)
    trainer = build_trainer(args, device_id, model, None, valid_loss)
    stats = trainer.validate(valid_iter, step)
    return stats.xent()
def summarize(text):
    """End-to-end single-document summarization.

    Writes `text` as a one-story CNN/DM-style file, runs the three
    preprocessing stages (tokenize -> format_to_lines -> format_to_bert),
    decodes the resulting shard with the module-level model/trainer
    globals, deletes all intermediate files, and returns the summary.

    NOTE(review): each stage builds a fresh ArgumentParser and calls
    parse_args(), which reads the host process's real sys.argv — any
    command-line arguments leak into every stage. Confirm intended.
    """
    # Stage 0: write the input as a raw story file ("tim" is a dummy highlight
    # so the file has the expected @highlight structure).
    with io.open('../raw_stories/test.story', 'w', encoding="utf8") as file:
        file.write(text.strip() + "\n\n@highlight\n\n" + "tim")

    # TOKENIZE
    # raw_stories -> merged_stories_tokenized
    parser = argparse.ArgumentParser()
    parser.add_argument("-mode", default='', type=str,
                        help='format_to_lines or format_to_bert')
    parser.add_argument(
        "-oracle_mode", default='greedy', type=str,
        help=
        'how to generate oracle summaries, greedy or combination, combination will generate more accurate oracles but take much longer time.'
    )
    parser.add_argument("-map_path", default='../data/')
    parser.add_argument("-raw_path", default='../raw_stories/')
    parser.add_argument("-save_path", default='../merged_stories_tokenized/')
    parser.add_argument("-shard_size", default=2000, type=int)
    parser.add_argument('-min_nsents', default=3, type=int)
    parser.add_argument('-max_nsents', default=100, type=int)
    parser.add_argument('-min_src_ntokens', default=5, type=int)
    parser.add_argument('-max_src_ntokens', default=200, type=int)
    parser.add_argument("-lower", type=str2bool, nargs='?', const=True, default=True)
    parser.add_argument('-log_file', default='../logs/cnndm.log')
    parser.add_argument(
        '-dataset', default='',
        help='train, valid or test, defaul will process all datasets')
    parser.add_argument('-n_cpus', default=2, type=int)
    args = parser.parse_args()
    data_builder.tokenize(args)

    # FORMAT TO LINES
    # merged_stories_tokenized -> my_json_data
    parser = argparse.ArgumentParser()
    parser.add_argument("-mode", default='', type=str,
                        help='format_to_lines or format_to_bert')
    parser.add_argument(
        "-oracle_mode", default='greedy', type=str,
        help=
        'how to generate oracle summaries, greedy or combination, combination will generate more accurate oracles but take much longer time.'
    )
    parser.add_argument("-map_path", default='../data/')
    parser.add_argument("-raw_path", default='../merged_stories_tokenized/')
    parser.add_argument("-save_path", default='../my_json_data/')
    parser.add_argument("-shard_size", default=2000, type=int)
    parser.add_argument('-min_nsents', default=3, type=int)
    parser.add_argument('-max_nsents', default=100, type=int)
    parser.add_argument('-min_src_ntokens', default=5, type=int)
    parser.add_argument('-max_src_ntokens', default=200, type=int)
    parser.add_argument("-lower", type=str2bool, nargs='?', const=True, default=True)
    parser.add_argument('-log_file', default='../logs/cnndm.log')
    parser.add_argument(
        '-dataset', default='',
        help='train, valid or test, defaul will process all datasets')
    parser.add_argument('-n_cpus', default=2, type=int)
    args = parser.parse_args()
    data_builder.format_to_lines_only_test(args)

    # FORMAT TO BERT
    # my_json_data -> bert_data_final
    parser = argparse.ArgumentParser()
    parser.add_argument("-mode", default='', type=str,
                        help='format_to_lines or format_to_bert')
    parser.add_argument(
        "-oracle_mode", default='greedy', type=str,
        help=
        'how to generate oracle summaries, greedy or combination, combination will generate more accurate oracles but take much longer time.'
    )
    parser.add_argument("-map_path", default='../data/')
    parser.add_argument("-raw_path", default='../my_json_data/')
    parser.add_argument("-save_path", default='../bert_data_final/')
    parser.add_argument("-shard_size", default=2000, type=int)
    parser.add_argument('-min_nsents', default=3, type=int)
    parser.add_argument('-max_nsents', default=100, type=int)
    parser.add_argument('-min_src_ntokens', default=5, type=int)
    parser.add_argument('-max_src_ntokens', default=200, type=int)
    parser.add_argument("-lower", type=str2bool, nargs='?', const=True, default=True)
    parser.add_argument('-log_file', default='../../logs/preprocess.log')
    parser.add_argument(
        '-dataset', default='test',
        help='train, valid or test, defaul will process all datasets')
    parser.add_argument('-n_cpus', default=2, type=int)
    args = parser.parse_args()
    data_builder.format_to_bert(args)

    # GENERATE SUMMARY
    # NOTE(review): model_args, device, device_id, model and step are
    # module-level globals set up elsewhere — verify they are initialized
    # before summarize() is called.
    test_iter = data_loader.Dataloader(model_args,
                                       load_dataset(model_args, 'test', shuffle=False),
                                       model_args.batch_size, device,
                                       shuffle=False, is_test=True)
    trainer = build_trainer(model_args, device_id, model, None)
    result_string = trainer.test(test_iter, step)

    # Clean up every intermediate artifact produced above.
    os.remove("../raw_stories/test.story")
    os.remove("../merged_stories_tokenized/test.story.json")
    os.remove("../my_json_data/test.0.json")
    os.remove("../bert_data_final/test.0.bert.pt")
    return result_string
def baseline(args, cal_lead=False, cal_oracle=False):
    """Compute LEAD or ORACLE extractive baselines over the test split.

    Writes predictions and references to '<result_path>.<mode>.pred' /
    '.gold', then logs length ratio, BLEU and ROUGE over those files.

    Fix: the two output files are now opened with context managers, so
    they are closed even if scoring raises mid-loop (the original manual
    open()/close() pair leaked both handles on exceptions).

    Args:
        cal_lead: use the LEAD baseline (first-k sentences).
        cal_oracle: use the greedy ROUGE oracle; otherwise LEAD selection
            logic still runs under mode name "oracle".
    """
    test_iter = data_loader.Dataloader(args, load_dataset(args, 'test', shuffle=False),
                                       args.test_batch_size, args.test_batch_ex_size,
                                       'cpu', shuffle=False, is_test=True)
    if cal_lead:
        mode = "lead"
    else:
        mode = "oracle"

    rouge = Rouge()
    pred_path = '%s.%s.pred' % (args.result_path, mode)
    gold_path = '%s.%s.gold' % (args.result_path, mode)
    with open(pred_path, 'w', encoding='utf-8') as save_pred, \
            open(gold_path, 'w', encoding='utf-8') as save_gold:
        with torch.no_grad():
            for batch in test_iter:
                summaries = batch.summ_txt
                origin_sents = batch.original_str
                ex_segs = batch.ex_segs
                # Prefix sums turn per-example sentence counts into slice offsets.
                ex_segs = [sum(ex_segs[:i]) for i in range(len(ex_segs) + 1)]

                for idx in range(len(summaries)):
                    summary = summaries[idx]
                    txt = origin_sents[ex_segs[idx]:ex_segs[idx + 1]]
                    if cal_oracle:
                        # Greedy oracle: repeatedly add the sentence that most
                        # improves ROUGE-1 + ROUGE-L F1 against the reference.
                        selected = []
                        max_rouge = 0.
                        while len(selected) < args.ranking_max_k:
                            cur_max_rouge = max_rouge
                            cur_id = -1
                            for i in range(len(txt)):
                                if i in selected:
                                    continue
                                c = selected + [i]
                                temp_txt = " ".join([txt[j] for j in c])
                                rouge_score = rouge.get_scores(temp_txt, summary)
                                rouge_1 = rouge_score[0]["rouge-1"]["f"]
                                rouge_l = rouge_score[0]["rouge-l"]["f"]
                                rouge_score = rouge_1 + rouge_l
                                if rouge_score > cur_max_rouge:
                                    cur_max_rouge = rouge_score
                                    cur_id = i
                            if cur_id == -1:
                                # No remaining sentence improves the score.
                                break
                            selected.append(cur_id)
                            max_rouge = cur_max_rouge
                        pred_txt = " ".join([txt[j] for j in selected])
                    else:
                        # LEAD: first k sentences, k scaled by the window size.
                        k = min(max(len(txt) // (2 * args.win_size + 1), 1),
                                args.ranking_max_k)
                        pred_txt = " ".join(txt[:k])
                    save_gold.write(summary + "\n")
                    save_pred.write(pred_txt + "\n")
                save_gold.flush()
                save_pred.flush()

    length = test_length(pred_path, gold_path)
    bleu = test_bleu(pred_path, gold_path)
    file_rouge = FilesRouge(hyp_path=pred_path, ref_path=gold_path)
    pred_rouges = file_rouge.get_scores(avg=True)
    logger.info('Length ratio:\n%s' % str(length))
    logger.info('Bleu:\n%.2f' % (bleu * 100))
    logger.info('Rouges:\n%s' % rouge_results_to_str(pred_rouges))
def train_iter_fct():
    """Build a fresh shuffled Dataloader over the training split."""
    train_set = load_dataset(args, 'train', shuffle=True)
    return data_loader.Dataloader(args, train_set, args.batch_size,
                                  device, shuffle=True, is_test=False)
) logging.info('Optimizer = Adam') json_path = os.path.join(args.model_dir, 'params.json') assert os.path.isfile( json_path), "No json configuration found at {}".format(json_path) logging.info("json path : " + json_path) #Read params file params = utils.Params(json_path) for item in params.dict: logging.info(item + " : " + str(params.dict[item])) #Generate Dataloader logging.info("Generating the dataloader") dataloader = data_loader.Dataloader(params) logging.info("Done loading the Dataloader") # use GPU if available cuda_present = torch.cuda.is_available() #Boolean cuda_present = False if cuda_present: logging.info("using CUDA") else: logging.info("cuda not available, using CPU") logging.info("Loading model and weights") for t in range(1):
def train(self, train_iter_fct, train_steps, valid_iter_fct=None, valid_steps=-1):
    """
    The main training loop: iterates over training data (from
    `train_iter_fct`) for `self.args.max_epoch` epochs, validating at
    every checkpoint step and keeping only the
    `self.args.save_model_count` best checkpoints by validation loss.

    Args:
        train_iter_fct(function): a function that returns the train
            iterator. e.g. something like
            train_iter_fct = lambda: generator(*args, **kwargs)
        train_steps(int): hard cap on optimization steps.
        valid_iter_fct(function): unused — validation data is rebuilt
            inline; TODO confirm this is intended.
        valid_steps(int): unused in this implementation.

    Return:
        (Statistics, list): accumulated training stats and the heap of
        (-valid_loss, checkpoint_path) entries that survived pruning.
    """
    logger.info('Start training...')
    # step = self.optim._step + 1
    step = self.optim._step + 1
    true_batchs = []
    accum = 0
    normalization = 0
    # Min-heap over NEGATED validation losses: the heap root is the WORST
    # (largest-loss) surviving checkpoint, so it is the one evicted first.
    neg_valid_loss = []  # minheap, minimum value at top
    heapq.heapify(neg_valid_loss)  # use neg loss to find top-k largest neg loss

    total_stats = Statistics()
    report_stats = Statistics()
    self._start_report_manager(start_time=total_stats.start_time)

    #select_counts = np.random.choice(range(3), train_steps + 1)
    cur_epoch = 0
    train_iter = train_iter_fct()
    #logger.info('Current Epoch:%d' % cur_epoch)
    #logger.info('maxEpoch:%d' % self.args.max_epoch)
    #while step <= train_steps:
    while cur_epoch < self.args.max_epoch:
        reduce_counter = 0
        logger.info('Current Epoch:%d' % cur_epoch)
        for i, batch in enumerate(train_iter):
            # In multi-GPU mode each rank consumes every n_gpu-th batch.
            if self.n_gpu == 0 or (i % self.n_gpu == self.gpu_rank):
                # from batch.labels, add selected sent index to batch
                # after teacher forcing, use model selected sentences
                # or infer scores of batch and get selected sent index
                # then add selected sent index to the batch
                true_batchs.append(batch)
                #normalization += batch.batch_size ##loss normalized wrong
                # NOTE(review): assignment (not +=) means only the LAST
                # batch's size is used when grad_accum_count > 1 — the
                # inline comment says this is deliberate per-minibatch
                # loss recording; confirm.
                normalization = batch.batch_size  ##loss recorded corresponds to each minibatch
                accum += 1
                if accum == self.grad_accum_count:
                    reduce_counter += 1
                    if self.n_gpu > 1:
                        # Aggregate normalization across ranks.
                        normalization = sum(distributed
                                            .all_gather_list
                                            (normalization))

                    self._gradient_accumulation(
                        true_batchs, normalization, total_stats,
                        report_stats)

                    report_stats = self._maybe_report_training(
                        step, train_steps,
                        self.optim.learning_rate,
                        report_stats)

                    true_batchs = []
                    accum = 0
                    normalization = 0
                    if (step % self.save_checkpoint_steps == 0 and self.gpu_rank == 0):
                        # Validate with a 10x batch size (inference is cheaper).
                        valid_iter = data_loader.Dataloader(self.args,
                                                            load_dataset(self.args, 'valid', shuffle=False),
                                                            self.args.batch_size * 10, self.device,
                                                            shuffle=False, is_test=True)  #batch_size train: 3000, test: 60000
                        stats = self.validate(valid_iter, step, self.args.valid_by_rouge)
                        self.model.train()  # back to training
                        cur_valid_loss = stats.xent()
                        checkpoint_path = os.path.join(self.args.model_path, 'model_step_%d.pt' % step)
                        if len(neg_valid_loss) < self.args.save_model_count:
                            # Still below quota: always save.
                            self._save(step)
                            heapq.heappush(neg_valid_loss, (-cur_valid_loss, checkpoint_path))
                        else:
                            if -cur_valid_loss > neg_valid_loss[0][0]:
                                # Better than the current worst survivor:
                                # push the new entry, evict the worst
                                # (heap root) and delete its file, then save.
                                heapq.heappush(neg_valid_loss, (-cur_valid_loss, checkpoint_path))
                                worse_loss, worse_model = heapq.heappop(neg_valid_loss)
                                os.remove(worse_model)
                                self._save(step)
                            #else do not save it
                        logger.info('step_%d:%s' % (step, str(neg_valid_loss)))
                    step += 1
                    if step > train_steps:
                        break
        # Epoch finished: rebuild the iterator for the next pass.
        cur_epoch += 1
        train_iter = train_iter_fct()

    return total_stats, neg_valid_loss
def train_iter_fct():
    """Return a new shuffled training Dataloader (one per epoch)."""
    return data_loader.Dataloader(
        args,
        load_dataset(args, 'train', shuffle=True),
        args.batch_size,
        device,
        shuffle=True,
        is_test=False,
    )
def abs_decoder(args, device_id, pt):
    """Run abstractive decoding with an extractive Summarizer backbone
    plus a pointer-style Decoder and initial-state module, all restored
    from the same checkpoint.

    The decoding step number is parsed from the checkpoint filename
    (expected to end in ..._<step>.pt).
    """
    step = int(pt.split('.')[-2].split('_')[-1])
    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    if (pt != ''):
        test_from = pt
    else:
        test_from = args.test_from
    logger.info('Loading checkpoint from %s' % test_from)
    checkpoint = torch.load(test_from, map_location=lambda storage, loc: storage)
    # Restore whitelisted model flags saved with the checkpoint.
    opt = vars(checkpoint['opt'])
    for k in opt.keys():
        if (k in model_flags):
            setattr(args, k, opt[k])
    print(args)

    config = BertConfig.from_json_file(args.bert_config_path)
    model = Summarizer(args, device, load_pretrained_bert=False, bert_config=config)

    # decoder: hidden size is half the BERT hidden size
    # (2*hidden_dim = embedding_size); shares BERT's embedding table.
    decoder = Decoder(model.bert.model.config.hidden_size // 2,
                      model.bert.model.config.vocab_size,
                      model.bert.model.config.hidden_size,
                      model.bert.model.embeddings,
                      device,
                      logger)
    # get initial decoder state s_t
    s_t_1 = get_initial_s(model.bert.model.config.hidden_size, device)

    # All three modules load their weights from the same checkpoint.
    model.load_cp(checkpoint)
    s_t_1.load_cp(checkpoint)
    decoder.load_cp(checkpoint)
    model.eval()
    decoder.eval()
    s_t_1.eval()

    # tokenizer, nlp
    # never_split keeps control tokens intact; no_word_piece disables
    # sub-word splitting (custom tokenizer option — not in stock
    # BertTokenizer; TODO confirm this is a project fork).
    tokenizer = BertTokenizer.from_pretrained(
        'bert-base-uncased',
        do_lower_case=True,
        never_split=('[SEP]', '[CLS]', '[PAD]', '[unused0]', '[unused1]',
                     '[unused2]', '[UNK]'),
        no_word_piece=True)
    # NOTE(review): hard-coded user-local CoreNLP path — will only work on
    # the original author's machine.
    nlp = StanfordCoreNLP(r'/home1/bqw/stanford-corenlp-full-2018-10-05')
    # nlp.logging_level = 10

    test_iter = data_loader.Dataloader(args, load_dataset(args, 'test', shuffle=False),
                                       args.batch_size, device,
                                       shuffle=False, is_test=True)
    trainer = build_trainer(args, device_id, model, None,
                            decoder=decoder,
                            get_s_t=s_t_1,
                            device=device_id,
                            tokenizer=tokenizer,
                            nlp=nlp,
                            extract_num=args.extract_num)
    trainer.abs_decode(test_iter, step)