def validate(args, device_id, pt, step):
    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    if pt != '':
        test_from = pt
    else:
        test_from = args.test_from
    logger.info('Loading checkpoint from %s' % test_from)
    checkpoint = torch.load(test_from,
                            map_location=lambda storage, loc: storage)

    # Override the flags stored in the checkpoint onto the current args.
    opt = vars(checkpoint['opt'])
    for k in opt.keys():
        if k in model_flags:
            setattr(args, k, opt[k])
    print(args)

    # Rebuild the model and load the checkpoint weights.
    config = XLNetConfig.from_pretrained(args.config_path)
    model = Summarizer(args, device, load_pretrained_bert=False,
                       bert_config=config)
    model.load_cp(checkpoint)
    model.eval()

    valid_iter = Dataloader(args, load_dataset(args, 'valid', shuffle=False),
                            args.batch_size, device,
                            shuffle=False, is_test=False)
    trainer = build_trainer(args, device_id, model, None)
    stats = trainer.validate(valid_iter, step)
    return stats.xent()
def multi_main(args):
    """Spawns one process per GPU."""
    init_logger()
    nb_gpu = args.world_size
    mp = torch.multiprocessing.get_context('spawn')

    # Create a queue (and a handler thread) to listen for errors in the
    # child processes.
    error_queue = mp.SimpleQueue()
    error_handler = ErrorHandler(error_queue)

    # Train with multiprocessing.
    procs = []
    for i in range(nb_gpu):
        device_id = i
        procs.append(mp.Process(target=run,
                                args=(args, device_id, error_queue,),
                                daemon=True))
        procs[i].start()
        logger.info(" Starting process pid: %d " % procs[i].pid)
        error_handler.add_child(procs[i].pid)
    for p in procs:
        p.join()
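# `multi_main` targets a `run` worker that is not shown in this section.
# The block below is a minimal, hedged sketch of such a worker: it assumes
# a `distributed.multi_init(device_id, world_size, gpu_ranks)` helper for
# process-group setup and the single-GPU `train(args, device_id)` entry
# point defined later; the real implementation may differ.
def run(args, device_id, error_queue):
    """Run one training process on one GPU (illustrative sketch)."""
    try:
        # Initialise the distributed process group and check the rank.
        gpu_rank = distributed.multi_init(device_id, args.world_size,
                                          args.gpu_ranks)
        if gpu_rank != args.gpu_ranks[device_id]:
            raise AssertionError("An error occurred in Distributed initialization")
        train(args, device_id)
    except KeyboardInterrupt:
        pass  # killed by the parent process via the error handler
    except Exception:
        # Propagate the traceback to the parent through the error queue.
        import traceback
        error_queue.put((args.gpu_ranks[device_id], traceback.format_exc()))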
def build_trainer(args, device_id, model, optim):
    """
    Configures the GPU device, summary writer and report manager.

    :return trainer: trainer object created with the above arguments
    """
    grad_accum_count = args.accum_count
    n_gpu = args.world_size

    # Configure GPU device.
    if device_id >= 0:
        gpu_rank = int(args.gpu_ranks[device_id])
    else:
        gpu_rank = 0
        n_gpu = 0
    print('gpu_rank %d' % gpu_rank)

    # Configure summary writer.
    tensorboard_log_dir = args.model_path
    writer = SummaryWriter(tensorboard_log_dir, comment="Unmt")

    # Configure report manager.
    report_manager = ReportMgr(args.report_every, start_time=-1,
                               tensorboard_writer=writer)

    # Create the trainer object.
    trainer = Trainer(args, model, optim, grad_accum_count, n_gpu, gpu_rank,
                      report_manager)

    # Log the number of parameters.
    if model:
        n_params = _tally_parameters(model)
        logger.info('* number of parameters: %d' % n_params)
    return trainer
def train(self, train_iter_fct, train_steps):
    logger.info('Start training...')
    step = self.optim._step + 1
    true_batchs = []
    accum = 0
    normalization = 0
    n_gpu = self.n_gpu
    gpu_rank = self.gpu_rank
    grad_accum_count = self.grad_accum_count

    # Iterable of training batches.
    train_iter = train_iter_fct()

    # Configure statistics reporting.
    total_stats = Statistics()
    report_stats = Statistics()
    self._start_report_manager(start_time=total_stats.start_time)

    # Training loop.
    while step <= train_steps:
        reduce_counter = 0
        for i, batch in enumerate(train_iter):
            if n_gpu == 0 or i % n_gpu == gpu_rank:
                true_batchs.append(batch)
                normalization += batch.batch_size
                accum += 1
                if accum == grad_accum_count:
                    reduce_counter += 1
                    if n_gpu > 1:
                        normalization = sum(
                            distributed.all_gather_list(normalization))

                    # Gradient accumulation step for the model.
                    self._gradient_accumulation(true_batchs, normalization,
                                                total_stats, report_stats)

                    # Report statistics for training.
                    report_stats = self._maybe_report_training(
                        step, train_steps, self.optim.learning_rate,
                        report_stats)

                    # Reset accumulation state.
                    true_batchs = []
                    accum = 0
                    normalization = 0
                    if (step % self.save_checkpoint_steps == 0
                            and self.gpu_rank == 0):
                        self._save(step)
                    step += 1
                    if step > train_steps:
                        break
        # Restart the iterator once the dataset has been exhausted.
        train_iter = train_iter_fct()
    return total_stats
def wait_and_validate(args, device_id):
    timestep = 0
    if args.test_all:
        # Validate every saved checkpoint, keep the three with the lowest
        # validation loss, and run the test set on those.
        cp_files = sorted(
            glob.glob(os.path.join(args.model_path, 'model_step_*.pt')))
        cp_files.sort(key=os.path.getmtime)
        xent_lst = []
        for i, cp in enumerate(cp_files):
            step = int(cp.split('.')[-2].split('_')[-1])
            xent = validate(args, device_id, cp, step)
            xent_lst.append((xent, cp))
            # Stop early if there has been no improvement over the last
            # 10 checkpoints.
            max_step = xent_lst.index(min(xent_lst))
            if i - max_step > 10:
                break
        xent_lst = sorted(xent_lst, key=lambda x: x[0])[:3]
        logger.info('PPL %s' % str(xent_lst))
        for xent, cp in xent_lst:
            step = int(cp.split('.')[-2].split('_')[-1])
            test(args, device_id, cp, step)
    else:
        # Watch the model directory and evaluate new checkpoints as they
        # appear.
        while True:
            cp_files = sorted(
                glob.glob(os.path.join(args.model_path, 'model_step_*.pt')))
            cp_files.sort(key=os.path.getmtime)
            if cp_files:
                cp = cp_files[-1]
                time_of_cp = os.path.getmtime(cp)
                if not os.path.getsize(cp) > 0:
                    # Checkpoint is still being written; wait and retry.
                    time.sleep(60)
                    continue
                if time_of_cp > timestep:
                    timestep = time_of_cp
                    step = int(cp.split('.')[-2].split('_')[-1])
                    validate(args, device_id, cp, step)
                    test(args, device_id, cp, step)

            cp_files = sorted(
                glob.glob(os.path.join(args.model_path, 'model_step_*.pt')))
            cp_files.sort(key=os.path.getmtime)
            if cp_files:
                cp = cp_files[-1]
                time_of_cp = os.path.getmtime(cp)
                if time_of_cp > timestep:
                    continue
            else:
                time.sleep(300)
def output(self, step, num_steps, learning_rate, start):
    """Write out statistics to stdout.

    Args:
        step (int): current step
        num_steps (int): total number of steps
        learning_rate (float): current learning rate
        start (float): start time of the step
    """
    t = self.elapsed_time()
    step_fmt = "%2d" % step
    if num_steps > 0:
        step_fmt = "%s/%5d" % (step_fmt, num_steps)
    logger.info(
        ("Step %s; xent: %4.2f; " +
         "lr: %7.7f; %3.0f docs/s; %6.0f sec")
        % (step_fmt,
           self.xent(),
           learning_rate,
           self.n_docs / (t + 1e-5),
           time.time() - start))
    sys.stdout.flush()
def train(args, device_id):
    # Start logger.
    init_logger(args.log_file)

    # Configure training device.
    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    logger.info('Device ID %d' % device_id)
    logger.info('Device %s' % device)

    # Configure manual seeds for reproducibility.
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    # Set CUDA device.
    if device_id >= 0:
        torch.cuda.set_device(device_id)
        torch.cuda.manual_seed(args.seed)

    # Dataloader used for training.
    def train_iter_fct():
        return Dataloader(args, load_dataset(args, 'train', shuffle=True),
                          args.batch_size, device,
                          shuffle=True, is_test=False)

    # Build the model.
    model = Summarizer(args, device, load_pretrained=True)

    # Optionally resume from a checkpoint.
    if args.train_from != '':
        logger.info('Loading checkpoint from %s' % args.train_from)
        checkpoint = torch.load(args.train_from,
                                map_location=lambda storage, loc: storage)
        opt = vars(checkpoint['opt'])
        for k in opt.keys():
            if k in model_flags:
                setattr(args, k, opt[k])
        model.load_cp(checkpoint)
        optim = builder.build_optim(args, model, checkpoint)
    else:
        optim = builder.build_optim(args, model, None)
    logger.info(model)

    # Train the model.
    trainer = build_trainer(args, device_id, model, optim)
    trainer.train(train_iter_fct, args.train_steps)
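# The entry points above (`train`, `multi_main`, `wait_and_validate`) are
# usually tied together by a small driver script. The block below is an
# illustrative sketch only: the flag names (`-mode`, `-world_size`,
# `-visible_gpus`) and the dispatch logic are assumptions, not confirmed
# parts of this code base.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('-mode', default='train',
                        choices=['train', 'validate', 'test'])
    parser.add_argument('-world_size', default=1, type=int)
    parser.add_argument('-visible_gpus', default='-1', type=str)
    # ... remaining flags (data/model paths, batch size, train_steps) omitted
    args = parser.parse_args()

    if args.mode == 'train':
        if args.world_size > 1:
            multi_main(args)  # one process per GPU
        else:
            device_id = 0 if args.visible_gpus != '-1' else -1
            train(args, device_id)
    elif args.mode == 'validate':
        # Watches the model directory and evaluates checkpoints as they land.
        wait_and_validate(args, device_id=0)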
def _save(self, step):
    real_model = self.model
    model_state_dict = real_model.state_dict()

    checkpoint = {
        'model': model_state_dict,
        'opt': self.args,
        'optim': self.optim,
    }
    checkpoint_path = os.path.join(self.args.model_path,
                                   'model_step_%d.pt' % step)
    logger.info("Saving checkpoint %s" % checkpoint_path)

    # Do not overwrite an existing checkpoint for this step.
    if not os.path.exists(checkpoint_path):
        torch.save(checkpoint, checkpoint_path)
    return checkpoint, checkpoint_path
def _format_xlnet(param):
    json_file, args, save_file = param

    # If the output file already exists, skip it.
    if os.path.exists(save_file):
        logger.info('Ignore %s' % save_file)
        return

    xlnet = XLData(args)
    logger.info('Processing %s' % json_file)
    jobs = json.load(open(json_file))
    data_set = []

    # Iterate over the documents in json_file.
    for d in jobs:
        # Generate oracle ids.
        src, tgt = d['src'], d['tgt']
        if args.oracle_mode == 'greedy':
            oracle_ids = greedy(src, tgt, 3)
        elif args.oracle_mode == 'combination':
            oracle_ids = combination(src, tgt, 3)

        # Process the example using the oracle ids.
        xl_data = xlnet.process(src, tgt, oracle_ids)
        if xl_data is None:
            continue
        indexed_tokens, labels, segments_ids, cls_ids, src_txt, tgt_txt = xl_data
        b_data_dict = {
            "src": indexed_tokens,
            "labels": labels,
            "segs": segments_ids,
            "clss": cls_ids,
            "src_txt": src_txt,
            "tgt_txt": tgt_txt,
        }
        data_set.append(b_data_dict)

    # Save the shard with torch.
    logger.info('Saving to %s' % save_file)
    torch.save(data_set, save_file)
    gc.collect()
def _lazy_dataset_loader(pt_file, corpus_type):
    dataset = torch.load(pt_file)
    logger.info('Loading %s dataset from %s, number of examples: %d' %
                (corpus_type, pt_file, len(dataset)))
    return dataset
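# `_lazy_dataset_loader` loads a single shard; the `load_dataset` generator
# used by the `train` and `validate` entry points above is not shown here.
# A plausible sketch is given below, assuming shards are named
# `<data_path>.<corpus_type>.<N>.pt` (the naming scheme and the
# `args.data_path` attribute are assumptions).
def load_dataset(args, corpus_type, shuffle):
    """Yield shards one at a time so the corpus never sits fully in memory."""
    assert corpus_type in ["train", "valid", "test"]
    pts = sorted(glob.glob(args.data_path + '.' + corpus_type + '.[0-9]*.pt'))
    if pts:
        if shuffle:
            random.shuffle(pts)
        for pt in pts:
            yield _lazy_dataset_loader(pt, corpus_type)
    else:
        # Fall back to a single, un-sharded file.
        pt = args.data_path + '.' + corpus_type + '.pt'
        yield _lazy_dataset_loader(pt, corpus_type)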
def log(self, *args, **kwargs):
    logger.info(*args, **kwargs)
def test(self, test_iter, step, cal_lead=False, cal_oracle=False):
    """Evaluate the model on the test set.

    test_iter: test data iterator
    Returns:
        :obj:`nmt.Statistics`: validation loss statistics
    """

    def _get_ngrams(n, text):
        ngram_set = set()
        text_length = len(text)
        max_index_ngram_start = text_length - n
        for i in range(max_index_ngram_start + 1):
            ngram_set.add(tuple(text[i:i + n]))
        return ngram_set

    def _block_tri(c, p):
        # Block a candidate sentence if it shares a trigram with an
        # already selected sentence.
        tri_c = _get_ngrams(3, c.split())
        for s in p:
            tri_s = _get_ngrams(3, s.split())
            if len(tri_c.intersection(tri_s)) > 0:
                return True
        return False

    # Set the model in evaluation mode (unless scoring LEAD/ORACLE baselines).
    if not cal_lead and not cal_oracle:
        self.model.eval()
    stats = Statistics()

    can_path = '%s_step%d.candidate' % (self.args.result_path, step)
    gold_path = '%s_step%d.gold' % (self.args.result_path, step)
    with open(can_path, 'w') as save_pred:
        with open(gold_path, 'w') as save_gold:
            with torch.no_grad():
                for batch in test_iter:
                    src = batch.src
                    labels = batch.labels
                    segs = batch.segs
                    clss = batch.clss
                    mask = batch.mask
                    mask_cls = batch.mask_cls

                    gold = []
                    pred = []

                    if cal_lead:
                        selected_ids = [list(range(batch.clss.size(1)))
                                        ] * batch.batch_size
                    elif cal_oracle:
                        selected_ids = [
                            [j for j in range(batch.clss.size(1))
                             if labels[i][j] == 1]
                            for i in range(batch.batch_size)]
                    else:
                        sent_scores, mask = self.model(src, clss, mask,
                                                       mask_cls)

                        loss = self.loss(sent_scores, labels.float())
                        loss = (loss * mask.float()).sum()
                        batch_stats = Statistics(
                            float(loss.cpu().data.numpy()), len(labels))
                        stats.update(batch_stats)

                        sent_scores = sent_scores + mask.float()
                        sent_scores = sent_scores.cpu().data.numpy()
                        selected_ids = np.argsort(-sent_scores, 1)
                    # selected_ids = np.sort(selected_ids, 1)

                    for i, idx in enumerate(selected_ids):
                        _pred = []
                        if len(batch.src_str[i]) == 0:
                            continue
                        for j in selected_ids[i][:len(batch.src_str[i])]:
                            if j >= len(batch.src_str[i]):
                                continue
                            candidate = batch.src_str[i][j].strip()
                            if self.args.block_trigram:
                                if not _block_tri(candidate, _pred):
                                    _pred.append(candidate)
                            else:
                                _pred.append(candidate)

                            # Keep at most three sentences unless scoring the
                            # oracle or using recall evaluation.
                            if ((not cal_oracle) and
                                    (not self.args.recall_eval) and
                                    len(_pred) == 3):
                                break

                        _pred = '<q>'.join(_pred)
                        if self.args.recall_eval:
                            _pred = ' '.join(
                                _pred.split()[:len(batch.tgt_str[i].split())])

                        pred.append(_pred)
                        gold.append(batch.tgt_str[i])

                    for i in range(len(gold)):
                        save_gold.write(gold[i].strip() + '\n')
                    for i in range(len(pred)):
                        save_pred.write(pred[i].strip() + '\n')

    if step != -1 and self.args.report_rouge:
        rouges = test_rouge(self.args.temp_dir, can_path, gold_path)
        logger.info('Rouges at step %d \n%s' %
                    (step, rouge_results_to_str(rouges)))
    self._report_step(0, step, valid_stats=stats)

    return stats