def train(self, dataflow: DataLoader, *,
          num_epochs: int = 9999999,
          callbacks: Optional[List[Callback]] = None) -> None:
    """Run the full training loop: `num_epochs` epochs of steps over `dataflow`.

    Wires the callbacks and summary to this trainer, then drives the
    before/after hooks around every step and epoch. A callback may abort
    training by raising ``StopTraining``; ``after_train()`` always runs on
    the way out, even on error.
    """
    self.dataflow = dataflow
    self.steps_per_epoch = len(self.dataflow)
    self.num_epochs = num_epochs

    if callbacks is None:
        callbacks = []
    self.callbacks = Callbacks(callbacks)
    self.summary = Summary()

    try:
        self.callbacks.set_trainer(self)
        self.summary.set_trainer(self)

        self.epoch_num = 0
        self.global_step = 0

        train_start = time.perf_counter()
        self.before_train()

        while self.epoch_num < self.num_epochs:
            self.epoch_num += 1
            self.local_step = 0

            logger.info(f'Epoch {self.epoch_num}/{self.num_epochs} started.')
            epoch_start = time.perf_counter()

            self.before_epoch()
            for batch in self.dataflow:
                self.local_step += 1
                self.global_step += 1
                self.before_step(batch)
                outputs = self.run_step(batch)
                self.after_step(outputs)
                self.trigger_step()
            self.after_epoch()
            took = humanize.naturaldelta(time.perf_counter() - epoch_start)
            logger.info(f'Training finished in {took}.')

            # trigger_epoch (e.g. evaluation/checkpointing) is timed into
            # the "Epoch finished" message but not "Training finished".
            self.trigger_epoch()
            took = humanize.naturaldelta(time.perf_counter() - epoch_start)
            logger.info(f'Epoch finished in {took}.')

        total = humanize.naturaldelta(time.perf_counter() - train_start)
        logger.success(
            f'{self.num_epochs} epochs of training finished in {total}.')
    except StopTraining as e:
        logger.info('Training was stopped by {}.'.format(str(e)))
    finally:
        self.after_train()
def _trigger(self) -> None:
    """Run one noise-injected inference epoch over ``self.dataflow``.

    Temporarily switches every node's ``noise_model_tq`` to noisy 'train'
    mode with this callback's ``noise_total_prob``, runs the epoch under
    ``torch.no_grad()``, and then restores each node's own original
    settings.

    Fixes two defects of the previous implementation:
    - it saved only node 0's settings and wrote them back onto *every*
      node, silently clobbering per-node configuration;
    - an exception during the epoch skipped restoration entirely.
    """
    start_time = time.perf_counter()
    self.callbacks.before_epoch()

    with torch.no_grad():
        # Snapshot each node's own settings so restoration is per-node.
        saved = [(node.noise_model_tq.is_add_noise,
                  node.noise_model_tq.noise_total_prob,
                  node.noise_model_tq.mode)
                 for node in self.trainer.model.nodes]
        try:
            for node in self.trainer.model.nodes:
                node.noise_model_tq.noise_total_prob = self.noise_total_prob
                node.noise_model_tq.is_add_noise = True
                node.noise_model_tq.mode = 'train'

            for feed_dict in tqdm.tqdm(self.dataflow, ncols=0):
                self.callbacks.before_step(feed_dict)
                output_dict = self.trainer.run_step(feed_dict)
                self.callbacks.after_step(output_dict)
        finally:
            # Restore even if a step raised, so the model is never left
            # in forced-noise 'train' mode.
            for node, (is_add_noise, total_prob, mode) in zip(
                    self.trainer.model.nodes, saved):
                node.noise_model_tq.is_add_noise = is_add_noise
                node.noise_model_tq.noise_total_prob = total_prob
                node.noise_model_tq.mode = mode

    self.callbacks.after_epoch()
    logger.info('Inference finished in {}.'.format(
        humanize.naturaldelta(time.perf_counter() - start_time)))
def _trigger_epoch(self) -> None:
    """Log an estimate of the remaining training time.

    Records how long the last epoch took, then extrapolates the mean
    epoch duration over the epochs still to run. Does nothing after the
    final epoch (no time would be left to estimate).
    """
    remaining = self.trainer.num_epochs - self.trainer.epoch_num
    if remaining <= 0:
        return
    self.times.append(time.perf_counter() - self.last_time)
    self.last_time = time.perf_counter()
    eta = remaining * np.mean(self.times)
    logger.info('Estimated time left: {}.'.format(
        humanize.naturaldelta(eta)))
def _trigger(self) -> None:
    """Run one full inference epoch over ``self.dataflow``.

    Executes every batch through the trainer under ``torch.no_grad()``,
    firing the before/after step and epoch callbacks, and logs how long
    the pass took.
    """
    begin = time.perf_counter()
    self.callbacks.before_epoch()
    with torch.no_grad():
        for batch in tqdm.tqdm(self.dataflow, ncols=0):
            self.callbacks.before_step(batch)
            outputs = self.trainer.run_step(batch)
            self.callbacks.after_step(outputs)
    self.callbacks.after_epoch()
    elapsed = humanize.naturaldelta(time.perf_counter() - begin)
    logger.info(f'Inference finished in {elapsed}.')
def train(self, dataflow: DataLoader, *,
          num_epochs: int = 9999999,
          eval_interval: Optional[int] = None,
          splits: Optional[List[str]] = None,
          callbacks: Optional[List[Callback]] = None) -> None:
    """Train for up to ``num_epochs`` epochs over ``dataflow``.

    Args:
        dataflow: training data loader; its length defines
            ``steps_per_epoch``.
        num_epochs: maximum number of epochs to run.
        eval_interval: if given, ``trigger_epoch()`` (evaluation etc.)
            fires only every ``eval_interval`` epochs; otherwise it fires
            every epoch.
        splits: names for per-split summaries; when ``None`` a single
            summary keyed ``"0"`` is used.
        callbacks: extra callbacks to register with the trainer.

    A callback may abort training by raising ``StopTraining``;
    ``after_train()`` always runs, even on error.

    Note: ``eval_interval`` and ``splits`` were previously annotated as
    non-optional (``int = None`` / ``List[str] = None``); PEP 484 requires
    explicit ``Optional`` for None defaults — fixed here, behavior
    unchanged.
    """
    self.dataflow = dataflow
    self.steps_per_epoch = len(self.dataflow)
    self.num_epochs = num_epochs

    if callbacks is None:
        callbacks = []
    self.callbacks = Callbacks(callbacks)

    # One Summary per requested split; otherwise one anonymous split "0".
    if splits is None:
        self.summary = {"0": Summary()}
    else:
        self.summary = {s: Summary(split=s) for s in splits}

    try:
        self.callbacks.set_trainer(self)
        for s in self.summary.values():
            s.set_trainer(self)

        self.epoch_num = 0
        self.global_step = 0

        train_time = time.perf_counter()
        self.before_train()

        while self.epoch_num < self.num_epochs:
            self.epoch_num += 1
            self.local_step = 0

            logger.info("Epoch {}/{} started.".format(
                self.epoch_num, self.num_epochs))
            epoch_time = time.perf_counter()

            self.before_epoch()
            for feed_dict in self.dataflow:
                self.local_step += 1
                self.global_step += 1
                self.before_step(feed_dict)
                output_dict = self.run_step(feed_dict)
                self.after_step(output_dict)
                self.trigger_step()
            self.after_epoch()
            logger.info("Training finished in {}.".format(
                humanize.naturaldelta(time.perf_counter() - epoch_time)))

            # Evaluate every epoch unless an eval interval is set
            # (collapses the previous nested-if/else duplication).
            if eval_interval is None or self.epoch_num % eval_interval == 0:
                self.trigger_epoch()
            logger.info("Epoch finished in {}.".format(
                humanize.naturaldelta(time.perf_counter() - epoch_time)))

        logger.success("{} epochs of training finished in {}.".format(
            self.num_epochs,
            humanize.naturaldelta(time.perf_counter() - train_time),
        ))
    except StopTraining as e:
        logger.info("Training was stopped by {}.".format(str(e)))
    finally:
        self.after_train()