def predict(self, data, pred=None, buckets=8, batch_size=5000, prob=False, **kwargs):
    """Run the model over ``data`` and attach the predictions to the dataset.

    Args:
        data: Input accepted by ``Dataset(self.transform, data)``.
        pred: Optional output path; when given, the predicted sentences are
            written there through ``self.transform.save``.
        buckets: Number of buckets used when building the dataset.
        batch_size: Batch size used when building the dataset.
        prob: When true, a ``Field('probs')`` is appended to the transform.
        **kwargs: Extra options merged into ``self.args``.

    Returns:
        The built ``Dataset``, with one attribute set per key of the dict
        returned by ``self._predict``.
    """
    args = self.args.update(locals())
    init_logger(logger, verbose=args.verbose)

    self.transform.eval()
    if args.prob:
        # NOTE(review): this appends on every call, so repeated calls with
        # prob=True may add the field more than once -- confirm intended.
        self.transform.append(Field('probs'))

    logger.info("Load the data")
    dataset = Dataset(self.transform, data)
    dataset.build(args.batch_size, args.buckets, shuffle=False)
    logger.info(f"\n{dataset}")

    logger.info("Make predictions on the dataset")
    began = datetime.now()
    outputs = self._predict(dataset.loader)
    took = datetime.now() - began

    for attr_name, attr_value in outputs.items():
        setattr(dataset, attr_name, attr_value)

    if pred is not None:
        logger.info(f"Save predicted results to {pred}")
        self.transform.save(pred, dataset.sentences)

    sents_per_sec = len(dataset) / took.total_seconds()
    logger.info(f"{took}s elapsed, {sents_per_sec:.2f} Sents/s")
    return dataset
def train_abs_multi(args):
    """Spawn one training process per GPU and wait for all of them to finish.

    ``args.world_size`` processes are started from the ``spawn``
    multiprocessing context. Process ``k`` runs
    ``run(args, k, error_queue)`` as a daemon, and its pid is registered with
    an ``ErrorHandler`` that listens on the shared error queue.
    """
    init_logger()
    world_size = args.world_size
    spawn_ctx = torch.multiprocessing.get_context('spawn')

    # Shared queue plus a handler that listens for errors in the children.
    error_queue = spawn_ctx.SimpleQueue()
    error_handler = ErrorHandler(error_queue)

    workers = []
    for device_id in range(world_size):
        worker = spawn_ctx.Process(target=run,
                                   args=(args, device_id, error_queue),
                                   daemon=True)
        workers.append(worker)
        worker.start()
        logger.info(" Starting process pid: %d " % worker.pid)
        error_handler.add_child(worker.pid)

    for worker in workers:
        worker.join()
def test_text_abs(args, text_src, script=False):
    """Summarize raw text with the checkpoint at ``args.test_from``.

    Model flags stored in the checkpoint (the keys listed in ``model_flags``)
    overwrite the matching attributes on ``args`` before the model is built.

    Args:
        args: Run options.
        text_src: Raw text source handed to ``data_loader.load_text``.
        script: Forwarded to ``data_loader.load_text``.

    Returns:
        Whatever ``predictor.translate(test_iter, -1)`` returns.
    """
    logger.info('Loading checkpoint from %s' % args.test_from)
    device = "cuda" if args.visible_gpus != '-1' else "cpu"

    checkpoint = torch.load(args.test_from, map_location=lambda storage, loc: storage)
    saved_opts = vars(checkpoint['opt'])
    for flag in saved_opts.keys():
        if flag in model_flags:
            setattr(args, flag, saved_opts[flag])

    model = AbsSummarizer(args, device, checkpoint)
    model.eval()

    test_iter = data_loader.load_text(args, text_src, device, script=script)

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True, cache_dir=args.temp_dir)
    vocab = tokenizer.vocab
    symbols = {
        'BOS': vocab['[unused0]'],
        'EOS': vocab['[unused1]'],
        'PAD': vocab['[PAD]'],
        'EOQ': vocab['[unused2]'],
    }
    predictor = build_predictor(args, tokenizer, symbols, model, logger)
    return predictor.translate(test_iter, -1)
def _abs_checkpoint_files(model_path):
    """Return the ``model_step_*.pt`` files under *model_path*, oldest mtime first."""
    files = sorted(glob.glob(os.path.join(model_path, 'model_step_*.pt')))
    # The name sort only fixes the tie-break order; modification time decides.
    files.sort(key=os.path.getmtime)
    return files


def _abs_checkpoint_step(cp):
    """Parse the training step N from a ``..._step_N.pt`` checkpoint path."""
    return int(cp.split('.')[-2].split('_')[-1])


def validate_abs(args, device_id):
    """Validate and test abstractive-summarization checkpoints.

    With ``args.test_all`` set, checkpoints in ``args.model_path`` are
    validated in mtime order:

    * Checkpoints whose step is below ``args.test_start_from`` (when that is
      not -1) get a placeholder loss of 1e6 and are not validated.
    * Validation stops once the current checkpoint is more than 10 positions
      past the best one seen so far.
    * The 5 lowest-xent checkpoints are then run through ``test_abs``.

    Otherwise the function polls ``args.model_path`` forever and validates and
    tests each newest checkpoint once. It waits 60 s when the newest file is
    still empty and 300 s when no new checkpoint has appeared.
    """
    if args.test_all:
        xent_lst = []
        for i, cp in enumerate(_abs_checkpoint_files(args.model_path)):
            step = _abs_checkpoint_step(cp)
            if args.test_start_from != -1 and step < args.test_start_from:
                xent_lst.append((1e6, cp))
                continue
            xent = validate(args, device_id, cp, step)
            xent_lst.append((xent, cp))
            best_idx = xent_lst.index(min(xent_lst))
            # Early stop: no improvement over the last 10 checkpoints.
            if i - best_idx > 10:
                break
        xent_lst = sorted(xent_lst, key=lambda x: x[0])[:5]
        logger.info('PPL %s' % str(xent_lst))
        for xent, cp in xent_lst:
            test_abs(args, device_id, cp, _abs_checkpoint_step(cp))
        return

    timestep = 0
    while True:
        cp_files = _abs_checkpoint_files(args.model_path)
        if cp_files:
            cp = cp_files[-1]
            time_of_cp = os.path.getmtime(cp)
            if not os.path.getsize(cp) > 0:
                # Checkpoint file exists but is still empty; retry shortly.
                time.sleep(60)
                continue
            if time_of_cp > timestep:
                timestep = time_of_cp
                step = _abs_checkpoint_step(cp)
                validate(args, device_id, cp, step)
                test_abs(args, device_id, cp, step)

        # A checkpoint that appeared while we were evaluating is handled
        # immediately. In every other case (no checkpoints, or nothing newer
        # than what was already evaluated) wait before polling again, so the
        # loop never spins re-globbing the directory.
        cp_files = _abs_checkpoint_files(args.model_path)
        if cp_files and os.path.getmtime(cp_files[-1]) > timestep:
            continue
        time.sleep(300)
def output(self, step, num_steps, learning_rate, start):
    """Log training statistics for the current step through ``logger``.

    The log line reports accuracy, perplexity, cross-entropy, learning rate,
    source/target token throughput and wall-clock seconds since ``start``.
    stdout is flushed afterwards.

    Args:
        step (int): current step.
        num_steps (int): total number of training steps.
        learning_rate (float): current learning rate.
        start (float): a ``time.time()``-style timestamp; the log reports
            ``time.time() - start`` seconds.
    """
    t = self.elapsed_time()
    message = ("Step %2d/%5d; acc: %6.2f; ppl: %5.2f; xent: %4.2f; " +
               "lr: %7.8f; %3.0f/%3.0f tok/s; %6.0f sec")
    # The 1e-5 term avoids division by zero when no time has elapsed yet.
    logger.info(
        message
        % (step, num_steps,
           self.accuracy(),
           self.ppl(),
           self.xent(),
           learning_rate,
           self.n_src_words / (t + 1e-5),
           self.n_words / (t + 1e-5),
           time.time() - start))
    sys.stdout.flush()
def validate(args, device_id, pt, step):
    """Compute the validation cross-entropy of one checkpoint.

    Args:
        args: Run options; model flags stored in the checkpoint (the keys in
            ``model_flags``) overwrite the matching attributes.
        device_id: Forwarded to ``build_trainer``.
        pt: Checkpoint path; an empty string means ``args.test_from``.
        step: Step number forwarded to ``trainer.validate``.

    Returns:
        ``xent()`` of the statistics returned by ``trainer.validate``.
    """
    device = "cuda" if args.visible_gpus != '-1' else "cpu"
    test_from = pt if pt != '' else args.test_from
    logger.info('Loading checkpoint from %s' % test_from)

    checkpoint = torch.load(test_from, map_location=lambda storage, loc: storage)
    saved_opts = vars(checkpoint['opt'])
    for flag in saved_opts.keys():
        if flag in model_flags:
            setattr(args, flag, saved_opts[flag])
    print(args)

    model = AbsSummarizer(args, device, checkpoint)
    model.eval()

    valid_dataset = load_dataset(args, 'valid', shuffle=False)
    valid_iter = data_loader.Dataloader(args, valid_dataset, args.batch_size, device,
                                        shuffle=False, is_test=False)

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True, cache_dir=args.temp_dir)
    vocab = tokenizer.vocab
    symbols = {
        'BOS': vocab['[unused0]'],
        'EOS': vocab['[unused1]'],
        'PAD': vocab['[PAD]'],
        'EOQ': vocab['[unused2]'],
    }

    valid_loss = abs_loss(model.generator, symbols, model.vocab_size, train=False, device=device)
    trainer = build_trainer(args, device_id, model, None, valid_loss)
    return trainer.validate(valid_iter, step).xent()
def build_trainer(args, device_id, model, optims, loss):
    """Build a ``Trainer`` from the run options.

    Args:
        args: Run options; this function reads ``accum_count``,
            ``world_size``, ``gpu_ranks``, ``model_path`` and
            ``report_every``.
        device_id (int): Index into ``args.gpu_ranks``. A negative value
            selects CPU mode: ``gpu_rank = 0`` and ``n_gpu = 0``.
        model: Model handed to the ``Trainer``; when truthy its parameter
            count is logged.
        optims: Optimizers handed to the ``Trainer`` (callers pass ``None``
            for evaluation-only use).
        loss: Loss object handed to the ``Trainer``.

    Returns:
        The configured ``Trainer``. Its ``ReportMgr`` logs to a TensorBoard
        ``SummaryWriter`` rooted at ``args.model_path``.
    """
    grad_accum_count = args.accum_count
    n_gpu = args.world_size

    if device_id >= 0:
        gpu_rank = int(args.gpu_ranks[device_id])
    else:
        gpu_rank = 0
        n_gpu = 0

    print('gpu_rank %d' % gpu_rank)

    tensorboard_log_dir = args.model_path
    writer = SummaryWriter(tensorboard_log_dir, comment="Unmt")
    report_manager = ReportMgr(args.report_every, start_time=-1, tensorboard_writer=writer)

    trainer = Trainer(args, model, optims, loss, grad_accum_count, n_gpu, gpu_rank, report_manager)

    if model:
        n_params = _tally_parameters(model)
        logger.info('* number of parameters: %d' % n_params)

    return trainer
def evaluate(self, data, buckets=8, batch_size=5000, **kwargs):
    """Evaluate the model on ``data``.

    The transform is switched to train mode (``self.transform.train()``)
    before the dataset is loaded.

    Args:
        data: Input accepted by ``Dataset(self.transform, data)``.
        buckets: Number of buckets used when building the dataset.
        batch_size: Batch size used when building the dataset.
        **kwargs: Extra options merged into ``self.args``.

    Returns:
        ``(loss, metric)`` as returned by ``self._evaluate``.
    """
    args = self.args.update(locals())
    init_logger(logger, verbose=args.verbose)

    self.transform.train()
    logger.info("Load the data")
    dataset = Dataset(self.transform, data)
    dataset.build(args.batch_size, args.buckets)
    logger.info(f"\n{dataset}")

    logger.info("Evaluate the dataset")
    began = datetime.now()
    loss, metric = self._evaluate(dataset.loader)
    took = datetime.now() - began

    logger.info(f"loss: {loss:.4f} - {metric}")
    sents_per_sec = len(dataset) / took.total_seconds()
    logger.info(f"{took}s elapsed, {sents_per_sec:.2f} Sents/s")
    return loss, metric
def _save(self, step):
    """Write a checkpoint for ``step`` into ``self.args.model_path``.

    The checkpoint dict contains:

    * ``'model'``: ``self.model.state_dict()``
    * ``'opt'``: ``self.args``
    * ``'optims'``: ``self.optims``

    The save is skipped if a file already exists at the target path, but the
    "Saving checkpoint" message is logged either way.

    Returns:
        ``(checkpoint, checkpoint_path)``.
    """
    state = self.model.state_dict()
    checkpoint = {
        'model': state,
        'opt': self.args,
        'optims': self.optims,
    }
    checkpoint_path = os.path.join(self.args.model_path, 'model_step_%d.pt' % step)
    logger.info("Saving checkpoint %s" % checkpoint_path)
    # Never overwrite a checkpoint that is already on disk.
    if not os.path.exists(checkpoint_path):
        torch.save(checkpoint, checkpoint_path)
    return checkpoint, checkpoint_path
def test_abs(args, device_id, pt, step):
    """Run the abstractive predictor over the test split for one checkpoint.

    Args:
        args: Run options; model flags stored in the checkpoint (the keys in
            ``model_flags``) overwrite the matching attributes.
        device_id: Not used in this function; kept for call-site
            compatibility.
        pt: Checkpoint path; an empty string means ``args.test_from``.
        step: Step number forwarded to ``predictor.translate``.
    """
    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    if pt != '':
        test_from = pt
    else:
        test_from = args.test_from
    logger.info('Loading checkpoint from %s' % test_from)

    checkpoint = torch.load(test_from, map_location=lambda storage, loc: storage)
    # The checkpoint itself is deliberately not printed. It contains every
    # model tensor, so dumping it floods stdout and is very slow.
    opt = vars(checkpoint['opt'])
    for k in opt.keys():
        if k in model_flags:
            setattr(args, k, opt[k])
    print(args)

    model = AbsSummarizer(args, device, checkpoint)
    model.eval()

    test_iter = data_loader.Dataloader(args, load_dataset(args, 'test', shuffle=False),
                                       args.test_batch_size, device,
                                       shuffle=False, is_test=True)
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True, cache_dir=args.temp_dir)
    symbols = {
        'BOS': tokenizer.vocab['[unused0]'],
        'EOS': tokenizer.vocab['[unused1]'],
        'PAD': tokenizer.vocab['[PAD]'],
        'EOQ': tokenizer.vocab['[unused2]'],
    }
    predictor = build_predictor(args, tokenizer, symbols, model, logger)
    predictor.translate(test_iter, step)
def log(self, *args, **kwargs):
    """Forward all positional and keyword arguments unchanged to ``logger.info``."""
    emit = logger.info
    emit(*args, **kwargs)
def _lazy_dataset_loader(pt_file, corpus_type):
    """Load one serialized dataset shard with ``torch.load`` and log its size.

    Args:
        pt_file: Path of the shard file. Note that ``torch.load`` unpickles
            it, so it must come from a trusted source.
        corpus_type: Label used only in the log message (presumably the
            split name -- confirm with callers).

    Returns:
        The loaded object, which must support ``len()``.
    """
    dataset = torch.load(pt_file)
    n_examples = len(dataset)
    logger.info('Loading %s dataset from %s, number of examples: %d' %
                (corpus_type, pt_file, n_examples))
    return dataset
def train(self, train, dev, test, buckets=32, batch_size=5000, lr=8e-4,
          mu=.9, nu=.9, epsilon=1e-12, clip=5.0, decay=.75, decay_steps=5000,
          step_decay_factor=0.5, step_decay_patience=15, epochs=5000,
          patience=100, verbose=True, **kwargs):
    """Train on ``train``, select the model on ``dev``, and report on ``test``.

    Args:
        train, dev, test: Data sources handed (through ``args``) to ``Dataset``.
        buckets, batch_size: Bucketing and batching options. Under
            ``torch.distributed`` the batch size is divided by the world size.
        lr, mu, nu, epsilon: Adam learning rate, betas ``(mu, nu)`` and eps.
        clip: Stored in ``args``; not read directly by this method.
        decay, decay_steps: Used when ``self.args.learning_rate_schedule`` is
            ``'Exponential'``. The gamma is ``decay ** (1 / decay_steps)``.
        step_decay_factor, step_decay_patience: ``ReduceLROnPlateau`` options,
            used when ``self.args.learning_rate_schedule`` is ``'Plateau'``.
        epochs: Number of epochs to run.
        patience: Stored in ``args``; early stopping is currently disabled.
        verbose: Forwarded to ``init_logger``.
        **kwargs: Extra options merged into ``self.args``.

    Whenever the dev metric improves, the master process saves the model to
    ``args.path + '_dev_LP_<lp>_LR_<lr>_LF_<lf>.pt'`` and prunes older ones
    with ``keep_last_n_checkpoint(args.path + '_dev_', n=5)``. After the last
    epoch, that best checkpoint is reloaded and evaluated on ``test``.
    """
    args = self.args.update(locals())
    init_logger(logger, verbose=args.verbose)

    self.transform.train()
    if dist.is_initialized():
        # Each rank processes its share of the global batch.
        args.batch_size = args.batch_size // dist.get_world_size()
    logger.info("Load the data")
    train = Dataset(self.transform, args.train, **args)
    dev = Dataset(self.transform, args.dev)
    test = Dataset(self.transform, args.test)
    train.build(args.batch_size, args.buckets, True, dist.is_initialized())
    dev.build(args.batch_size, args.buckets)
    test.build(args.batch_size, args.buckets)
    logger.info(f"\n{'train:':6} {train}\n{'dev:':6} {dev}\n{'test:':6} {test}\n")

    logger.info(f"{self.model}\n")
    if dist.is_initialized():
        self.model = DDP(self.model, device_ids=[dist.get_rank()], find_unused_parameters=True)
    self.optimizer = Adam(self.model.parameters(), args.lr, (args.mu, args.nu), args.epsilon)
    if self.args.learning_rate_schedule == 'Exponential':
        self.scheduler = ExponentialLR(self.optimizer, args.decay**(1 / args.decay_steps))
    elif self.args.learning_rate_schedule == 'Plateau':
        self.scheduler = ReduceLROnPlateau(
            self.optimizer, 'max',
            factor=args.step_decay_factor,
            patience=args.step_decay_patience,
            verbose=True)

    elapsed = timedelta()
    best_e, best_metric = 1, Metric()
    # Path of the best checkpoint written so far. It falls back to args.path
    # only if no epoch ever improved on dev, so nothing was saved.
    best_path = args.path

    for epoch in range(1, args.epochs + 1):
        start = datetime.now()

        logger.info(f"Epoch {epoch} / {args.epochs}:")
        loss = self._train(train.loader)
        logger.info(f"{'train:':6} - loss: {loss:.4f}")
        loss, dev_metric = self._evaluate(dev.loader)
        logger.info(f"{'dev:':6} - loss: {loss:.4f} - {dev_metric}")
        loss, test_metric = self._evaluate(test.loader)
        logger.info(f"{'test:':6} - loss: {loss:.4f} - {test_metric}")

        t = datetime.now() - start
        # Save the model if it is the best so far on dev.
        if dev_metric > best_metric:
            best_e, best_metric = epoch, dev_metric
            dev_metric_name = '_dev_LP_{:.2f}_LR_{:.2f}_LF_{:.2f}.pt'.format(
                100 * best_metric.lp, 100 * best_metric.lr, 100 * best_metric.lf)
            best_path = args.path + dev_metric_name
            if is_master():
                self.save(best_path)
                # Only the process that writes checkpoints prunes them, so
                # ranks never race to delete the same files.
                keep_last_n_checkpoint(args.path + '_dev_', n=5)
            logger.info(f"{t}s elapsed (saved)\n")
        else:
            logger.info(f"{t}s elapsed\n")
        elapsed += t
        if self.args.learning_rate_schedule == 'Plateau':
            self.scheduler.step(best_metric.score)
        # NOTE: patience-based early stopping (epoch - best_e >= args.patience)
        # is intentionally disabled.

    # Reload the best checkpoint saved above. Plain args.path is never
    # written by this loop.
    loss, metric = self.load(best_path)._evaluate(test.loader)

    logger.info(f"Epoch {best_e} saved")
    logger.info(f"{'dev:':6} - {best_metric}")
    logger.info(f"{'test:':6} - {metric}")
    logger.info(f"{elapsed}s elapsed, {elapsed / epoch}s/epoch")
def train_abs_single(args, device_id):
    """Train the abstractive summarizer on a single device.

    Steps:

    1. Seed ``torch`` and ``random`` from ``args.seed`` and enable
       deterministic cuDNN.
    2. If ``args.train_from`` is set, resume from that checkpoint; model
       flags stored in it (the keys in ``model_flags``) overwrite ``args``.
    3. If ``args.load_from_extractive`` is set, use the ``'model'`` entry of
       that checkpoint as ``bert_from_extractive``.
    4. Build the model and its optimizer(s), then train for
       ``args.train_steps`` steps.

    Args:
        args: Run options.
        device_id (int): CUDA device index; negative values skip the CUDA
            setup.
    """
    init_logger(args.log_file)
    logger.info(str(args))
    device = "cuda" if args.visible_gpus != '-1' else "cpu"
    logger.info('Device ID %d' % device_id)
    logger.info('Device %s' % device)

    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    if device_id >= 0:
        torch.cuda.set_device(device_id)
        torch.cuda.manual_seed(args.seed)

    checkpoint = None
    if args.train_from != '':
        logger.info('Loading checkpoint from %s' % args.train_from)
        checkpoint = torch.load(args.train_from, map_location=lambda storage, loc: storage)
        saved_opts = vars(checkpoint['opt'])
        for flag in saved_opts.keys():
            if flag in model_flags:
                setattr(args, flag, saved_opts[flag])

    bert_from_extractive = None
    if args.load_from_extractive != '':
        logger.info('Loading bert from extractive model %s' % args.load_from_extractive)
        extractive_ckpt = torch.load(args.load_from_extractive,
                                     map_location=lambda storage, loc: storage)
        bert_from_extractive = extractive_ckpt['model']

    # Re-seed after the checkpoint handling, just before the model is built.
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    def train_iter_fct():
        # Builds a fresh, shuffled training iterator on every call.
        train_dataset = load_dataset(args, 'train', shuffle=True)
        return data_loader.Dataloader(args, train_dataset, args.batch_size, device,
                                      shuffle=True, is_test=False)

    model = AbsSummarizer(args, device, checkpoint, bert_from_extractive)
    if args.sep_optim:
        optim = [
            model_builder.build_optim_bert(args, model, checkpoint),
            model_builder.build_optim_dec(args, model, checkpoint),
        ]
    else:
        optim = [model_builder.build_optim(args, model, checkpoint)]

    logger.info(model)

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True, cache_dir=args.temp_dir)
    vocab = tokenizer.vocab
    symbols = {
        'BOS': vocab['[unused0]'],
        'EOS': vocab['[unused1]'],
        'PAD': vocab['[PAD]'],
        'EOQ': vocab['[unused2]'],
    }

    train_loss = abs_loss(model.generator, symbols, model.vocab_size, device,
                          train=True, label_smoothing=args.label_smoothing)

    trainer = build_trainer(args, device_id, model, optim, train_loss)
    trainer.train(train_iter_fct, args.train_steps)
def test(self, test_iter, step, cal_lead=False, cal_oracle=False):
    """Write lead-baseline predictions and gold summaries, then optionally score them.

    For each example with a non-empty ``src_str``:

    * Its source sentences are taken in order, stopping after three unless
      ``cal_oracle`` or ``self.args.recall_eval`` is set.
    * The selected sentences are joined with ``'<q>'`` and written to
      ``<result_path>_step<step>.candidate``.
    * The reference ``tgt_str`` is written to ``<result_path>_step<step>.gold``.

    Examples with an empty source are skipped in both files, so the two files
    stay line-aligned. With ``self.args.recall_eval`` each prediction is
    truncated to the gold length in whitespace tokens. ROUGE is computed and
    logged when ``step != -1`` and ``self.args.report_rouge`` is set.

    Only ``cal_lead=True`` is implemented. Candidate sentence ids are built
    for the lead baseline alone, so processing any batch with
    ``cal_lead=False`` raises ``NotImplementedError``.

    Returns:
        The ``Statistics`` object created at the start; nothing is
        accumulated into it here.
    """
    if not cal_lead and not cal_oracle:
        self.model.eval()
    stats = Statistics()

    can_path = '%s_step%d.candidate' % (self.args.result_path, step)
    gold_path = '%s_step%d.gold' % (self.args.result_path, step)
    with open(can_path, 'w') as save_pred, open(gold_path, 'w') as save_gold:
        with torch.no_grad():
            for batch in test_iter:
                if not cal_lead:
                    raise NotImplementedError(
                        'test() only supports cal_lead=True; oracle/model '
                        'decoding is not implemented here')
                gold = []
                pred = []
                # Lead baseline: every example considers its sentence slots in order.
                selected_ids = [list(range(batch.clss.size(1)))] * batch.batch_size
                for i, ids in enumerate(selected_ids):
                    src_sents = batch.src_str[i]
                    if len(src_sents) == 0:
                        continue
                    _pred = []
                    # Slicing to len(src_sents) guarantees every j is a valid index.
                    for j in ids[:len(src_sents)]:
                        _pred.append(src_sents[j].strip())
                        if (not cal_oracle) and (not self.args.recall_eval) and len(_pred) == 3:
                            break

                    _pred = '<q>'.join(_pred)
                    if self.args.recall_eval:
                        _pred = ' '.join(_pred.split()[:len(batch.tgt_str[i].split())])

                    pred.append(_pred)
                    gold.append(batch.tgt_str[i])

                save_gold.writelines(g.strip() + '\n' for g in gold)
                save_pred.writelines(p.strip() + '\n' for p in pred)

    if step != -1 and self.args.report_rouge:
        rouges = test_rouge(self.args.temp_dir, can_path, gold_path)
        logger.info('Rouges at step %d \n%s' % (step, rouge_results_to_str(rouges)))
    self._report_step(0, step, valid_stats=stats)

    return stats
def train(self, train_iter_fct, train_steps, valid_iter_fct=None, valid_steps=-1):
    """
    The main training loop: iterate over the training data produced by
    `train_iter_fct` until step `train_steps` has been completed.

    Args:
        train_iter_fct(function): a function that returns the train iterator,
           e.g. something like
           train_iter_fct = lambda: generator(*args, **kwargs)
           It is called again every time the iterator is exhausted.
        train_steps(int): last step (inclusive) to train.
        valid_iter_fct(function): accepted for interface compatibility;
           not used by this method.
        valid_steps(int): accepted for interface compatibility; not used.

    The starting step is `self.optims[0]._step + 1`. One step is counted
    every `self.grad_accum_count` kept batches. A checkpoint is written with
    `self._save(step)` on gpu_rank 0 whenever
    `step % self.save_checkpoint_steps == 0`.

    Return:
        The `total_stats` Statistics object handed to
        `_gradient_accumulation` on every step.
    """
    logger.info('Start training...')

    # Resume the step counter from the first optimizer.
    step = self.optims[0]._step + 1

    # Batches kept for the current accumulation window. Leftovers carry over
    # when train_iter is re-created below, since nothing resets them there.
    true_batchs = []
    accum = 0
    normalization = 0
    train_iter = train_iter_fct()

    total_stats = Statistics()
    report_stats = Statistics()
    self._start_report_manager(start_time=total_stats.start_time)

    while step <= train_steps:

        reduce_counter = 0  # incremented below but never read in this method
        for i, batch in enumerate(train_iter):
            # Round-robin sharding: with n_gpu > 0 each rank keeps only the
            # batches whose index matches its gpu_rank.
            if self.n_gpu == 0 or (i % self.n_gpu == self.gpu_rank):

                true_batchs.append(batch)
                # Non-padding target tokens, excluding position 0 (presumably
                # the BOS symbol -- confirm); summed into the normalization
                # term passed to _gradient_accumulation.
                num_tokens = batch.tgt[:, 1:].ne(self.loss.padding_idx).sum()
                normalization += num_tokens.item()
                accum += 1
                if accum == self.grad_accum_count:
                    reduce_counter += 1
                    if self.n_gpu > 1:
                        # Sum the token counts gathered from all ranks.
                        normalization = sum(distributed.all_gather_list(normalization))

                    self._gradient_accumulation(true_batchs, normalization, total_stats, report_stats)

                    report_stats = self._maybe_report_training(step, train_steps, self.optims[0].learning_rate, report_stats)

                    true_batchs = []
                    accum = 0
                    normalization = 0
                    if (step % self.save_checkpoint_steps == 0 and self.gpu_rank == 0):
                        self._save(step)

                    step += 1
                    if step > train_steps:
                        break
        # Iterator exhausted (or the step budget was hit): start a new pass.
        train_iter = train_iter_fct()

    return total_stats