def all_gather_stats_list(stat_list, max_size=4096):
    """
    Gather a `Statistics` list across all processes/nodes

    Args:
        stat_list(list([`Statistics`])): list of statistics objects to
            gather across all processes/nodes
        max_size(int): max buffer size to use

    Returns:
        our_stats(list([`Statistics`])): list of updated stats
    """
    from torch.distributed import get_rank
    from distributed import all_gather_list

    # Get a list of world_size lists with len(stat_list) Statistics objects
    all_stats = all_gather_list(stat_list, max_size=max_size)

    our_rank = get_rank()
    our_stats = all_stats[our_rank]
    for other_rank, stats in enumerate(all_stats):
        if other_rank == our_rank:
            continue
        for i, stat in enumerate(stats):
            our_stats[i].update(stat, update_n_src_words=True)
    return our_stats
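# ---------------------------------------------------------------------------
# Hedged usage sketch for all_gather_stats_list (not part of the original
# code): it illustrates the gather-then-merge pattern with a toy Stats class
# and a simulated all_gather result, since a real call needs an initialized
# torch.distributed process group.
class ToyStats:
    """Illustrative stand-in for the Statistics class (assumption)."""

    def __init__(self, loss, n_words):
        self.loss = loss
        self.n_words = n_words

    def update(self, other, update_n_src_words=False):
        self.loss += other.loss
        self.n_words += other.n_words


# Pretend all_gather_list returned one stat list per rank (world_size = 2).
all_stats = [
    [ToyStats(2.0, 100)],  # rank 0
    [ToyStats(3.0, 150)],  # rank 1
]
our_rank = 0
our_stats = all_stats[our_rank]
for other_rank, stats in enumerate(all_stats):
    if other_rank == our_rank:
        continue
    for i, stat in enumerate(stats):
        our_stats[i].update(stat, update_n_src_words=True)
print(our_stats[0].loss, our_stats[0].n_words)  # 5.0 250
# ---------------------------------------------------------------------------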
def train(self, train_iter_fct, train_steps):
    logger.info('Start training...')
    step = self.optim._step + 1
    true_batchs = []
    accum = 0
    normalization = 0
    n_gpu = self.n_gpu
    gpu_rank = self.gpu_rank
    grad_accum_count = self.grad_accum_count

    # Iterable of training batches.
    train_iter = train_iter_fct()

    # Configure statistics report.
    total_stats = Statistics()
    report_stats = Statistics()
    self._start_report_manager(start_time=total_stats.start_time)

    # Training loop.
    while step <= train_steps:
        reduce_counter = 0
        for i, batch in enumerate(train_iter):
            if n_gpu == 0 or i % n_gpu == gpu_rank:
                true_batchs.append(batch)
                normalization += batch.batch_size
                accum += 1
                if accum == grad_accum_count:
                    reduce_counter += 1
                    if n_gpu > 1:
                        normalization = sum(
                            distributed.all_gather_list(normalization))

                    # Gradient accumulation for model.
                    self._gradient_accumulation(true_batchs, normalization,
                                                total_stats, report_stats)

                    # Report statistics for training.
                    report_stats = self._maybe_report_training(
                        step, train_steps, self.optim.learning_rate,
                        report_stats)

                    # Initialize variables
                    true_batchs = []
                    accum = 0
                    normalization = 0
                    if step % self.save_checkpoint_steps == 0 and self.gpu_rank == 0:
                        self._save(step)

                    step += 1
                    if step > train_steps:
                        break
        train_iter = train_iter_fct()

    return total_stats
def train(self, train_iter_fct, train_steps, valid_iter_fct=None,
          valid_steps=-1):
    """
    The main training loop, iterating over training data (i.e. `train_iter_fct`)
    and running validation (i.e. iterating over `valid_iter_fct`).

    Args:
        train_iter_fct(function): a function that returns the train
            iterator. e.g. something like
            train_iter_fct = lambda: generator(*args, **kwargs)
        valid_iter_fct(function): same as train_iter_fct, for valid data
        train_steps(int):
        valid_steps(int):
        save_checkpoint_steps(int):

    Return:
        None
    """
    logger.info('Start training...')

    step = self.optim._step + 1
    true_batchs = []
    accum = 0
    normalization = 0
    train_iter = train_iter_fct()

    total_stats = Statistics()
    report_stats = Statistics()
    self._start_report_manager(start_time=total_stats.start_time)

    while step <= train_steps:
        reduce_counter = 0
        for i, batch in enumerate(train_iter):
            if self.n_gpu == 0 or (i % self.n_gpu == self.gpu_rank):
                true_batchs.append(batch)
                normalization += batch.batch_size
                accum += 1
                if accum == self.grad_accum_count:
                    reduce_counter += 1
                    if self.n_gpu > 1:
                        normalization = sum(
                            distributed.all_gather_list(normalization))

                    self._gradient_accumulation(true_batchs, normalization,
                                                total_stats, report_stats)

                    report_stats = self._maybe_report_training(
                        step, train_steps, self.optim.learning_rate,
                        report_stats)

                    true_batchs = []
                    accum = 0
                    normalization = 0
                    if (step % self.save_checkpoint_steps == 0
                            and self.gpu_rank == 0):
                        self._save(step)

                    step += 1
                    if step > train_steps:
                        break
        train_iter = train_iter_fct()

    return total_stats
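# ---------------------------------------------------------------------------
# Hedged sketch (not from the original code) of the `train_iter_fct` contract
# described in the docstring above: a zero-argument callable that builds a
# fresh batch iterator every time it is called, so the trainer can restart it
# at the end of each pass over the data. The Batch class and sizes below are
# illustrative only.
class Batch:
    def __init__(self, batch_size):
        self.batch_size = batch_size


def make_train_iter_fct(batch_sizes):
    def train_iter_fct():
        return iter(Batch(b) for b in batch_sizes)
    return train_iter_fct


train_iter_fct = make_train_iter_fct([32, 32, 16])
for epoch in range(2):
    for batch in train_iter_fct():  # a new iterator for every pass
        print(epoch, batch.batch_size)
# ---------------------------------------------------------------------------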
def train(self, train_iter_fct, train_steps, valid_iter_fct=None,
          valid_steps=-1):
    """
    The main training loop, iterating over training data (i.e. `train_iter_fct`)
    and running validation (i.e. iterating over `valid_iter_fct`).

    Args:
        train_iter_fct(function): a function that returns the train
            iterator. e.g. something like
            train_iter_fct = lambda: generator(*args, **kwargs)
        valid_iter_fct(function): same as train_iter_fct, for valid data
        train_steps(int):
        valid_steps(int):
        save_checkpoint_steps(int):

    Return:
        None
    """
    logger.info('Start training...')

    step = self.optims[0]._step + 1
    true_batchs = []
    accum = 0
    tgt_tokens = 0
    src_tokens = 0
    sents = 0
    examples = 0
    train_iter = train_iter_fct()

    total_stats = Statistics()
    report_stats = Statistics()
    self._start_report_manager(start_time=total_stats.start_time)

    while step <= train_steps:
        for i, batch in enumerate(train_iter):
            if self.n_gpu == 0 or (i % self.n_gpu == self.gpu_rank):
                true_batchs.append(batch)
                tgt_tokens += batch.tgt[:, 1:].ne(
                    self.abs_loss.padding_idx).sum().item()
                src_tokens += batch.src[:, 1:].ne(
                    self.abs_loss.padding_idx).sum().item()
                sents += batch.src.size(0)
                examples += batch.tgt.size(0)
                accum += 1
                if accum == self.grad_accum_count:
                    if self.n_gpu > 1:
                        tgt_tokens = sum(
                            distributed.all_gather_list(tgt_tokens))
                        src_tokens = sum(
                            distributed.all_gather_list(src_tokens))
                        sents = sum(distributed.all_gather_list(sents))
                        examples = sum(
                            distributed.all_gather_list(examples))

                    normalization = (tgt_tokens, src_tokens, sents, examples)

                    self._gradient_calculation(true_batchs, normalization,
                                               total_stats, report_stats, step)

                    report_stats = self._maybe_report_training(
                        step, train_steps, self.optims[0].learning_rate,
                        report_stats)

                    true_batchs = []
                    accum = 0
                    src_tokens = 0
                    tgt_tokens = 0
                    sents = 0
                    examples = 0
                    if (step % self.save_checkpoint_steps == 0
                            and self.gpu_rank == 0):
                        self._save(step)

                    step += 1
                    if step > train_steps:
                        break
        train_iter = train_iter_fct()

    return total_stats
def train(self, args, train_iter_fct, train_steps, valid_iter_fct=None,
          valid_steps=-1):
    """
    The main training loop, iterating over training data (i.e. `train_iter_fct`)
    and running validation (i.e. iterating over `valid_iter_fct`).

    Args:
        train_iter_fct(function): a function that returns the train
            iterator. e.g. something like
            train_iter_fct = lambda: generator(*args, **kwargs)
        valid_iter_fct(function): same as train_iter_fct, for valid data
        train_steps(int):
        valid_steps(int):
        save_checkpoint_steps(int):

    Return:
        None
    """
    logger.info('Start training...')

    step = self.optims[0]._step + 1
    true_batchs = []
    accum = 0
    normalization = 0
    train_iter = train_iter_fct()

    total_stats = Statistics()
    report_stats = Statistics()
    self._start_report_manager(start_time=total_stats.start_time)

    while step <= train_steps:
        reduce_counter = 0
        for i, batch in enumerate(train_iter):
            if self.n_gpu == 0 or (i % self.n_gpu == self.gpu_rank):
                true_batchs.append(batch)
                num_tokens = batch.tgt[:, 1:].ne(self.loss.padding_idx).sum()
                normalization += num_tokens.item()
                accum += 1
                if accum == self.grad_accum_count:
                    reduce_counter += 1
                    if self.n_gpu > 1:
                        normalization = sum(
                            distributed.all_gather_list(normalization))

                    self._gradient_accumulation(true_batchs, normalization,
                                                total_stats, report_stats,
                                                step, train_steps)

                    self._maybe_report_training(
                        step, train_steps, self.optims[0].learning_rate,
                        report_stats)

                    acc, ppl, xent, lr = self._maybe_return_stats(
                        step, train_steps, self.optims[0].learning_rate,
                        report_stats)

                    if step == train_steps:
                        log_param(f"{args.mode}_lr", lr)
                        log_metric(f"{args.mode}_acc", acc)
                        log_metric(f"{args.mode}_ppl", ppl)
                        log_metric(f"{args.mode}_xent", xent)

                    true_batchs = []
                    accum = 0
                    normalization = 0
                    if (step % self.save_checkpoint_steps == 0
                            and self.gpu_rank == 0):
                        self._save(step)

                    step += 1
                    if step > train_steps:
                        break
        train_iter = train_iter_fct()

    return total_stats
def train(self, train_iter_fct, train_steps, valid_iter_fct=None,
          valid_steps=-1):
    """
    The main training loop, iterating over training data (i.e. `train_iter_fct`)
    and running validation (i.e. iterating over `valid_iter_fct`).

    Args:
        train_iter_fct(function): a function that returns the train
            iterator. e.g. something like
            train_iter_fct = lambda: generator(*args, **kwargs)
        valid_iter_fct(function): same as train_iter_fct, for valid data
        train_steps(int):
        valid_steps(int):
        save_checkpoint_steps(int):

    Return:
        None
    """
    logger.info('Start training...')

    step = self.optims[0]._step + 1
    true_batchs = []
    accum = 0
    normalization = 0
    train_iter = train_iter_fct()

    total_stats = Statistics()
    report_stats = Statistics()
    self._start_report_manager(start_time=total_stats.start_time)

    while step <= train_steps:
        reduce_counter = 0
        for i, batch in enumerate(train_iter):
            if self.n_gpu == 0 or (i % self.n_gpu == self.gpu_rank):
                true_batchs.append(batch)
                num_tokens = batch.tgt[:, 1:].ne(self.loss.padding_idx).sum()
                normalization += num_tokens.item()
                accum += 1
                if accum == self.grad_accum_count:
                    reduce_counter += 1
                    if self.n_gpu > 1:
                        normalization = sum(
                            distributed.all_gather_list(normalization))

                    self._gradient_accumulation(true_batchs, normalization,
                                                total_stats, report_stats)

                    report_stats = self._maybe_report_training(
                        step, train_steps, self.optims[0].learning_rate,
                        report_stats)

                    if step % self.args.report_every == 0:
                        self.model.eval()
                        logger.info('Model set to eval state')
                        valid_iter = data_loader.Dataloader(
                            self.args,
                            load_dataset(self.args, 'test', shuffle=False),
                            self.args.batch_size, "cuda",
                            shuffle=False, is_test=True)
                        tokenizer = BertTokenizer.from_pretrained(
                            self.args.model_path, do_lower_case=True)
                        symbols = {
                            'BOS': tokenizer.vocab['[unused1]'],
                            'EOS': tokenizer.vocab['[unused2]'],
                            'PAD': tokenizer.vocab['[PAD]'],
                            'EOQ': tokenizer.vocab['[unused3]']
                        }
                        valid_loss = abs_loss(self.model.generator, symbols,
                                              self.model.vocab_size,
                                              train=False, device="cuda")
                        trainer = build_trainer(self.args, 0, self.model,
                                                None, valid_loss)
                        stats = trainer.validate(valid_iter, step)
                        self.report_manager.report_step(
                            self.optims[0].learning_rate, step,
                            train_stats=None, valid_stats=stats)
                        self.model.train()

                    true_batchs = []
                    accum = 0
                    normalization = 0
                    if (step % self.save_checkpoint_steps == 0
                            and self.gpu_rank == 0):
                        self._save(step)

                    step += 1
                    if step > train_steps:
                        break
        train_iter = train_iter_fct()

    return total_stats
def validate(self, valid_iter_fct, step=0):
    """Main validation process of MTL.

    Args:
        valid_iter_fct (function): returns an instance of
            data.data_loader.MetaDataloader.
    """
    logger.info('Start validating...')

    step = 0
    ckpt_step = self.optims[0]._step  # resume the step recorded in optims
    true_sup_batchs = []
    true_qry_batchs = []
    accum = 0
    task_accum = 0
    sup_normalization = 0
    qry_normalization = 0

    # Dataloader
    valid_iter = valid_iter_fct()  # class Dataloader

    # Reporter and Statistics
    report_outer_stats = Statistics()
    report_inner_stats = Statistics()
    self._start_report_manager(start_time=report_outer_stats.start_time)

    # Make sure the accumulation of gradient is correct
    assert self.args.accum_count == self.args.num_batch_in_task

    while step <= self.args.train_steps:
        for i, (sup_batch, qry_batch) in enumerate(valid_iter):
            if self.n_gpu == 0 or (i % self.n_gpu == self.gpu_rank):
                # Collect batches (= self.grad_accum_count) as real batch
                true_sup_batchs.append(sup_batch)
                true_qry_batchs.append(qry_batch)

                # Count non-padding words in batches
                sup_num_tokens = sup_batch.tgt[:, 1:].ne(
                    self.loss.padding_idx).sum()
                qry_num_tokens = qry_batch.tgt[:, 1:].ne(
                    self.loss.padding_idx).sum()
                sup_normalization += sup_num_tokens.item()
                qry_normalization += qry_num_tokens.item()

                # Gradient normalization over tasks
                qry_normalization = qry_normalization * self.args.num_task

                accum += 1
                if accum == self.args.num_batch_in_task:
                    task_accum += 1

                    # NOTE: Clear optimizer state
                    self.optims_inner[task_accum - 1][0].optimizer.clear_states()
                    self.optims_inner[task_accum - 1][1].optimizer.clear_states()

                    # =============== Inner Update ================
                    # Sum up non-padding words from multi-GPU
                    if self.n_gpu > 1:
                        sup_normalization = sum(
                            distributed.all_gather_list(sup_normalization))

                    inner_step = 1
                    while inner_step <= self.args.inner_train_steps:
                        # Compute gradient and update
                        self._maml_inner_gradient_accumulation(
                            true_sup_batchs, sup_normalization,
                            report_inner_stats, inner_step, task_accum,
                            inference_mode=True)

                        # Report training progress (if args.report_every is reached)
                        report_inner_stats = self._maybe_report_inner_training(
                            inner_step, self.args.inner_train_steps,
                            self.optims_inner[task_accum - 1][0].learning_rate,
                            self.optims_inner[task_accum - 1][1].learning_rate,
                            report_inner_stats)

                        inner_step += 1
                        if inner_step > self.args.inner_train_steps:
                            break
                    # =============================================

                    # =============== Outer No Update ================
                    self.model.eval()
                    # Calculate loss only, no update of the initialization
                    self._valid(true_qry_batchs, report_outer_stats, ckpt_step)
                    # Clean fast weights
                    self.model._clean_fast_weights_mode()
                    self.model.train()
                    # ================================================

                    # Reset
                    true_sup_batchs = []
                    true_qry_batchs = []
                    accum = 0
                    sup_normalization = 0
                    qry_normalization = 0

                    if task_accum == self.args.num_task:
                        # Reset
                        task_accum = 0

                        # Check steps to stop
                        step += 1
                        if step > self.args.train_steps:
                            break

        # End of an epoch: reload and reset
        valid_iter = valid_iter_fct()

    # Report average result after all validation steps
    self._report_step(0, ckpt_step, valid_stats=report_outer_stats)  # first arg is lr
    self.report_manager.tensorboard_writer.flush()  # force the log to be written

    return report_outer_stats
def train(self, train_iter_fct):
    """Main training process of MTL.

    Args:
        train_iter_fct (function): returns an instance of
            data.data_loader.MetaDataloader.
    """
    logger.info('Start training... (' + str(self.args.maml_type) + ')')

    step = self.optims[0]._step + 1  # resume the step recorded in optims
    true_sup_batchs = []
    true_qry_batchs = []
    accum = 0
    task_accum = 0
    sup_normalization = 0
    qry_normalization = 0

    # Dataloader
    train_iter = train_iter_fct()  # class Dataloader

    # Reporter and Statistics
    report_outer_stats = Statistics()
    report_inner_stats = Statistics()
    self._start_report_manager(start_time=report_outer_stats.start_time)

    # Currently only MAML is supported
    assert self.args.maml_type == 'maml'
    # Make sure the accumulation of gradient is correct
    assert self.args.accum_count == self.args.num_batch_in_task

    while step <= self.args.train_steps:
        # NOTE: Outer loop
        for i, (sup_batch, qry_batch) in enumerate(train_iter):
            if self.n_gpu == 0 or (i % self.n_gpu == self.gpu_rank):
                # Collect batches (= self.grad_accum_count) as real batch
                true_sup_batchs.append(sup_batch)
                true_qry_batchs.append(qry_batch)

                # Count non-padding words in batches
                sup_num_tokens = sup_batch.tgt[:, 1:].ne(
                    self.loss.padding_idx).sum()
                qry_num_tokens = qry_batch.tgt[:, 1:].ne(
                    self.loss.padding_idx).sum()
                sup_normalization += sup_num_tokens.item()
                qry_normalization += qry_num_tokens.item()

                accum += 1
                if accum == self.args.num_batch_in_task:
                    task_accum += 1

                    # =============== Inner Update ================
                    # Sum up non-padding words from multi-GPU
                    if self.n_gpu > 1:
                        sup_normalization = sum(
                            distributed.all_gather_list(sup_normalization))

                    inner_step = 1
                    while inner_step <= self.args.inner_train_steps:
                        # NOTE: Inner loop
                        # Compute gradient and update
                        self._maml_inner_gradient_accumulation(
                            true_sup_batchs, sup_normalization,
                            report_inner_stats, inner_step, task_accum)

                        # Report training progress (if args.report_every is reached)
                        report_inner_stats = self._maybe_report_inner_training(
                            inner_step, self.args.inner_train_steps,
                            self.optims_inner[task_accum - 1][0].learning_rate,
                            self.optims_inner[task_accum - 1][1].learning_rate,
                            report_inner_stats)

                        inner_step += 1
                        if inner_step > self.args.inner_train_steps:
                            break

                    # =============== Outer Update ================
                    # Sum up non-padding words from multi-GPU
                    if self.n_gpu > 1:
                        qry_normalization = sum(
                            distributed.all_gather_list(qry_normalization))

                    # Compute gradient and update
                    self._maml_outter_gradient_accumulation(
                        true_qry_batchs, qry_normalization,
                        report_outer_stats, step, inner_step, task_accum)

                    if task_accum == self.args.num_task:
                        # Calculate gradient norm
                        total_norm = 0.0
                        for p in self.model.parameters():
                            if p.grad is not None:
                                param_norm = p.grad.data.norm(2)
                                total_norm += param_norm.item() ** 2
                        total_norm = total_norm ** (1. / 2)
                    # =============================================

                    # Reset
                    true_sup_batchs = []
                    true_qry_batchs = []
                    accum = 0
                    sup_normalization = 0
                    qry_normalization = 0

                    if task_accum == self.args.num_task:
                        # Report training progress (if args.report_every is reached)
                        report_outer_stats = self._maybe_report_training(
                            step, self.args.train_steps,
                            self.optims[0].learning_rate,
                            self.optims[1].learning_rate,
                            report_outer_stats)

                        # Reset
                        task_accum = 0

                        # Save
                        if (step % self.save_checkpoint_steps == 0
                                and self.gpu_rank == 0):
                            self._save(step)

                        # Check steps to stop
                        step += 1
                        if step > self.args.train_steps:
                            break

        # End of an epoch: reload and reset
        train_iter = train_iter_fct()

    self.report_manager.tensorboard_writer.flush()  # force the log to be written
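# ---------------------------------------------------------------------------
# Minimal, self-contained sketch (assumption: a plain torch module, not the
# original summarization model) of the gradient-norm computation above:
# accumulate the squared L2 norm of every parameter gradient, then take the
# square root to obtain the global norm.
import torch

model = torch.nn.Linear(4, 2)
loss = model(torch.randn(3, 4)).sum()
loss.backward()

total_norm = 0.0
for p in model.parameters():
    if p.grad is not None:
        param_norm = p.grad.data.norm(2)
        total_norm += param_norm.item() ** 2
total_norm = total_norm ** 0.5
print(total_norm)
# ---------------------------------------------------------------------------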
def train(self, train_iter_fct, train_steps, valid_iter_fct=None,
          valid_steps=-1):
    """
    The main training loop, iterating over training data (i.e. `train_iter_fct`)
    and running validation (i.e. iterating over `valid_iter_fct`).

    Args:
        train_iter_fct(function): a function that returns the train
            iterator. e.g. something like
            train_iter_fct = lambda: generator(*args, **kwargs)
        valid_iter_fct(function): same as train_iter_fct, for valid data
        train_steps(int):
        valid_steps(int):
        save_checkpoint_steps(int):

    Return:
        None
    """
    logger.info('Start training...')

    step = self.optims[0]._step + 1
    true_batchs = []
    accum = 0
    normalization = 0

    # Dataloader
    train_iter = train_iter_fct()  # class Dataloader

    # Reporter and Statistics
    report_stats = Statistics()
    self._start_report_manager(start_time=report_stats.start_time)

    while step <= train_steps:
        for i, batch in enumerate(train_iter):
            if self.n_gpu == 0 or (i % self.n_gpu == self.gpu_rank):
                # Collect batches (= self.grad_accum_count) as real batch
                true_batchs.append(batch)

                # Count non-padding words in batches
                num_tokens = batch.tgt[:, 1:].ne(self.loss.padding_idx).sum()
                normalization += num_tokens.item()

                accum += 1
                if accum == self.grad_accum_count:
                    # Sum up non-padding words from multi-GPU
                    if self.n_gpu > 1:
                        normalization = sum(
                            distributed.all_gather_list(normalization))

                    # Compute gradient and update
                    self._gradient_accumulation(true_batchs, normalization,
                                                report_stats)

                    # Call self.report_manager to report training progress
                    report_stats = self._maybe_report_training(
                        step, train_steps, self.optims[0].learning_rate,
                        self.optims[1].learning_rate, report_stats)

                    # Reset
                    true_batchs = []
                    accum = 0
                    normalization = 0

                    # Save
                    if (step % self.save_checkpoint_steps == 0
                            and self.gpu_rank == 0):
                        self._save(step)

                    # Check steps to stop
                    step += 1
                    if step > train_steps:
                        break

        # End of an epoch: reload data
        train_iter = train_iter_fct()

    self.report_manager.tensorboard_writer.flush()  # force the log to be written
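# ---------------------------------------------------------------------------
# Hedged sketch (illustrative tensors, not the original Batch objects) of the
# normalization term used above: count non-padding target tokens, skipping the
# first position, so the loss can be averaged per generated token.
import torch

padding_idx = 0
tgt = torch.tensor([[2, 5, 7, 3, 0],
                    [2, 4, 3, 0, 0]])  # 0 = PAD, 2 = BOS, 3 = EOS (assumed ids)
num_tokens = tgt[:, 1:].ne(padding_idx).sum().item()
print(num_tokens)  # 5 non-padding tokens after the first position
# ---------------------------------------------------------------------------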
def train(self, train_iter_fct, train_steps, valid_iter_fct=None,
          valid_steps=-1):
    """
    The main training loop, iterating over training data (i.e. `train_iter_fct`)
    and running validation (i.e. iterating over `valid_iter_fct`).

    Args:
        train_iter_fct(function): a function that returns the train
            iterator. e.g. something like
            train_iter_fct = lambda: generator(*args, **kwargs)
        valid_iter_fct(function): same as train_iter_fct, for valid data
        train_steps(int):
        valid_steps(int):
        save_checkpoint_steps(int):

    Return:
        None
    """
    logger.info('Start training...')

    list_xent_train = list()
    list_xent_valid = list()

    step = self.optims[0]._step
    true_batchs = []
    accum = 0
    normalization = 0
    train_iter = train_iter_fct()

    total_stats = Statistics()
    report_stats = Statistics()
    self._start_report_manager(start_time=total_stats.start_time)

    while step <= train_steps:
        reduce_counter = 0
        for i, batch in enumerate(train_iter):
            if self.n_gpu == 0 or (i % self.n_gpu == self.gpu_rank):
                true_batchs.append(batch)
                num_tokens = batch.tgt[:, 1:].ne(self.loss.padding_idx).sum()
                normalization += num_tokens.item()
                accum += 1
                if accum == self.grad_accum_count:
                    reduce_counter += 1
                    if self.n_gpu > 1:
                        normalization = sum(
                            distributed.all_gather_list(normalization))

                    self._gradient_accumulation(true_batchs, normalization,
                                                total_stats, report_stats)

                    report_stats, train_xent = self._maybe_report_training(
                        step, train_steps, self.optims[0].learning_rate,
                        report_stats)

                    if step % 100 == 0:
                        _, train_xent = train_xent
                        list_xent_train.append(train_xent)

                        valid_iter = valid_iter_fct()
                        list_xent_valid.append(
                            self.validate(valid_iter=valid_iter,
                                          step=step).xent())

                        self.save_plot(
                            list_xent_train, list_xent_valid,
                            "/content/drive/MyDrive/Colab Notebooks/Models/PreSumm/save_plot_Baseline_" + str(step) + ".txt")

                        arr_train = np.arange(start=0,
                                              stop=len(list_xent_train) * 100,
                                              step=100)
                        arr_valid = np.arange(start=0,
                                              stop=len(list_xent_valid) * 100,
                                              step=100)

                        plt.clf()
                        plt.title("Loss Function")
                        plt.plot(arr_valid, list_xent_valid, color="orange",
                                 label="Validate")
                        plt.plot(arr_train, list_xent_train, color="blue",
                                 label="Train")
                        plt.ylabel("Cross Entropy (Xent)")
                        plt.xlabel("Step")
                        plt.legend()
                        plt.savefig(str(step) + ".jpg")

                    true_batchs = []
                    accum = 0
                    normalization = 0
                    if (step % self.save_checkpoint_steps == 0
                            and self.gpu_rank == 0):
                        self._save(step)

                    step += 1
                    if step > train_steps:
                        break
        train_iter = train_iter_fct()

    return total_stats
def train(self, train_iter_fct, train_steps, valid_iter_fct=None,
          valid_steps=-1):
    """
    The main training loop, iterating over training data (i.e. `train_iter_fct`)
    and running validation (i.e. iterating over `valid_iter_fct`).

    Args:
        train_iter_fct(function): a function that returns the train
            iterator. e.g. something like
            train_iter_fct = lambda: generator(*args, **kwargs)
        valid_iter_fct(function): same as train_iter_fct, for valid data
        train_steps(int):
        valid_steps(int):
        save_checkpoint_steps(int):

    Return:
        None
    """
    logger.info('Start training...')

    step = self.optim._step + 1
    true_batchs = []
    accum = 0
    normalization = 0

    neg_valid_loss = []  # min-heap, minimum value at top
    heapq.heapify(neg_valid_loss)  # store negated losses so the largest ones stay on top

    total_stats = Statistics()
    report_stats = Statistics()
    self._start_report_manager(start_time=total_stats.start_time)

    # select_counts = np.random.choice(range(3), train_steps + 1)
    cur_epoch = 0
    train_iter = train_iter_fct()

    # while step <= train_steps:
    while cur_epoch < self.args.max_epoch:
        reduce_counter = 0
        logger.info('Current Epoch:%d' % cur_epoch)
        for i, batch in enumerate(train_iter):
            if self.n_gpu == 0 or (i % self.n_gpu == self.gpu_rank):
                # from batch.labels, add selected sent index to batch
                # after teacher forcing, use model selected sentences
                # or infer scores of batch and get selected sent index
                # then add selected sent index to the batch
                true_batchs.append(batch)
                # normalization += batch.batch_size  # old version: loss normalized incorrectly
                normalization = batch.batch_size  # loss recorded per minibatch
                accum += 1
                if accum == self.grad_accum_count:
                    reduce_counter += 1
                    if self.n_gpu > 1:
                        normalization = sum(
                            distributed.all_gather_list(normalization))

                    self._gradient_accumulation(true_batchs, normalization,
                                                total_stats, report_stats)

                    report_stats = self._maybe_report_training(
                        step, train_steps, self.optim.learning_rate,
                        report_stats)

                    true_batchs = []
                    accum = 0
                    normalization = 0
                    if (step % self.save_checkpoint_steps == 0
                            and self.gpu_rank == 0):
                        valid_iter = data_loader.Dataloader(
                            self.args,
                            load_dataset(self.args, 'valid', shuffle=False),
                            self.args.batch_size * 10, self.device,
                            shuffle=False, is_test=True)
                        # batch_size train: 3000, test: 60000
                        stats = self.validate(valid_iter, step,
                                              self.args.valid_by_rouge)
                        self.model.train()  # back to training

                        cur_valid_loss = stats.xent()
                        checkpoint_path = os.path.join(
                            self.args.model_path, 'model_step_%d.pt' % step)
                        # if len(neg_valid_loss) < self.args.save_model_count:
                        self._save(step)
                        heapq.heappush(neg_valid_loss,
                                       (-cur_valid_loss, checkpoint_path))
                        # else:
                        #     if -cur_valid_loss > neg_valid_loss[0][0]:
                        #         heapq.heappush(neg_valid_loss,
                        #                        (-cur_valid_loss, checkpoint_path))
                        #         worse_loss, worse_model = heapq.heappop(neg_valid_loss)
                        #         os.remove(worse_model)
                        #         self._save(step)
                        #     # else: do not save it
                        logger.info('step_%d:%s' % (step, str(neg_valid_loss)))

                    step += 1
                    if step > train_steps:
                        break

        cur_epoch += 1
        train_iter = train_iter_fct()

    return total_stats, neg_valid_loss
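# ---------------------------------------------------------------------------
# Hedged sketch of the commented-out "keep only the k best checkpoints" logic
# above: store (-loss, path) tuples in a min-heap so the worst kept checkpoint
# sits at the top and is evicted first. `save_model_count`, the losses and the
# paths are illustrative values, not taken from the original run.
import heapq

save_model_count = 3
neg_valid_loss = []

for step, cur_valid_loss in enumerate([4.2, 3.9, 4.5, 3.1, 4.0], start=1):
    checkpoint_path = 'model_step_%d.pt' % step
    if len(neg_valid_loss) < save_model_count:
        heapq.heappush(neg_valid_loss, (-cur_valid_loss, checkpoint_path))
    elif -cur_valid_loss > neg_valid_loss[0][0]:
        # a real trainer would also os.remove() the evicted checkpoint here
        heapq.heappushpop(neg_valid_loss, (-cur_valid_loss, checkpoint_path))

print(sorted((-neg, path) for neg, path in neg_valid_loss))
# keeps the three lowest validation losses: 3.1, 3.9 and 4.0
# ---------------------------------------------------------------------------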
def train(self, train_iter_fct, train_steps, valid_iter_fct=None,
          valid_steps=-1):
    """
    The main training loop, iterating over training data (i.e. `train_iter_fct`)
    and running validation (i.e. iterating over `valid_iter_fct`).

    Example log line:
    [2020-07-18 15:29:09,747 INFO] Step 550/ 4000; acc: 25.00; ppl: 212.10;
    xent: 5.36; lr: 0.00000039; 0/ 2 tok/s; 3145 sec

    Args:
        train_iter_fct(function): a function that returns the train
            iterator. e.g. something like
            train_iter_fct = lambda: generator(*args, **kwargs)
        valid_iter_fct(function): same as train_iter_fct, for valid data
        train_steps(int):
        valid_steps(int):
        save_checkpoint_steps(int):

    Return:
        None
    """
    logger.info('Start training...')

    step = self.optims[0]._step + 1
    true_batchs = []
    accum = 0
    normalization = 0
    train_iter = train_iter_fct()

    total_stats = Statistics()
    report_stats = Statistics()
    self._start_report_manager(start_time=total_stats.start_time)

    least_loss = float('inf')

    while step <= train_steps:
        reduce_counter = 0
        for i, batch in enumerate(train_iter):
            if self.n_gpu == 0 or (i % self.n_gpu == self.gpu_rank):
                true_batchs.append(batch)
                num_tokens = batch.tgt[:, 1:].ne(self.loss.padding_idx).sum()
                normalization += num_tokens.item()
                accum += 1
                if accum == self.grad_accum_count:
                    reduce_counter += 1
                    if self.n_gpu > 1:
                        normalization = sum(
                            distributed.all_gather_list(normalization))

                    self._gradient_accumulation(true_batchs, normalization,
                                                total_stats, report_stats)

                    report_stats = self._maybe_report_training(
                        step, train_steps, self.optims[0].learning_rate,
                        report_stats)

                    true_batchs = []
                    accum = 0
                    normalization = 0
                    if step > self.save_checkpoint_steps and step % 50 == 0:
                        if self.current_loss < least_loss:
                            least_loss = self.current_loss
                            if self.gpu_rank == 0:
                                self._save(step)

                    step += 1
                    if step > train_steps:
                        break
        train_iter = train_iter_fct()

    return total_stats