def on_resume_begin(self, *args: Any, **kwargs: Any) -> None:
    """Re-print the progress table for the epochs recorded before a resume.

    Stores the master bar on the instance, decrements the stored
    ``epoch['curr']`` counter by one, hands the metric state to
    ``clean_up_metric_resume`` and then writes the header row followed by
    one row for each epoch in ``range(epoch_curr - 1)``.

    Args:
        *args: Unused.
        **kwargs: Expected to carry ``master_bar``, the bar object whose
            ``write`` method receives the rows.
    """
    mbar: master_bar = kwargs.get('master_bar')
    self.mbar = mbar

    trainer_state: Dict = self.statemgr.state.get('trainer')
    epoch_state: Dict = trainer_state.get('epoch')
    epoch_curr = epoch_state['curr']
    epoch_state['curr'] = epoch_curr - 1

    metric_state: Dict = self.statemgr.state.get('metric')
    # clean up metric (called with the counter value read before the decrement)
    clean_up_metric_resume(metric_state, epoch_curr)

    # The notebook/console decision is the same for every row: evaluate once.
    as_table = bool(isnotebook())

    # header write
    mbar.write(line_head_builder(metric_state), table=as_table)

    # Per-epoch series are loop-invariant lookups, so fetch them once.
    times = epoch_state.get("time")
    remains = epoch_state.get("remain")
    for epoch in range(epoch_curr - 1):
        tdelta = time_formatter(times[epoch])
        tremain = time_formatter(remains[epoch])
        line = line_builder_resume(metric_state, epoch, tdelta, tremain)
        mbar.write(line, table=as_table)
def on_validate_end(self, *args: Any, **kwargs: Any) -> None:
    """Write the table header when validation ends and ``epoch['curr']`` is 0.

    Args:
        *args: Unused.
        **kwargs: Expected to carry ``master_bar``, the bar object whose
            ``write`` method receives the header.
    """
    mbar: master_bar = kwargs.get('master_bar')
    trainer_state: TrainerState = self.statemgr.get_state('trainer')

    # Nothing to do unless the current-epoch counter is still 0.
    if trainer_state.get_property('epoch')['curr'] != 0:
        return

    # show header for first time
    header = line_head_builder(self.statemgr.get_state('metric'))
    mbar.write(header, table=bool(isnotebook()))
def on_epoch_end(self, *args: Any, **kwargs: Any) -> None:
    """Write the row for the finished epoch and, past epoch 1, refresh the graph.

    Args:
        *args: Unused.
        **kwargs: Expected to carry ``master_bar`` (the bar object written
            to) and ``epoch`` (passed through to ``line_builder``).
    """
    mbar: master_bar = kwargs.get('master_bar')
    epoch: int = kwargs.get("epoch")
    trainer_state: TrainerState = self.statemgr.get_state('trainer')
    metric_state: MetricState = self.statemgr.get_state('metric')

    # Build and emit the table row for this epoch.
    tdelta, tremain = time_delta_remain(trainer_state.get_property('epoch'))
    row = line_builder(metric_state, epoch, tdelta, tremain)
    mbar.write(row, table=bool(isnotebook()))

    # The graph is only touched once the current-epoch counter exceeds 1.
    if trainer_state.get_property('epoch')['curr'] <= 1:
        return

    graph = graph_builder(metric_state, trainer_state)
    mbar.names = ['trn_loss', 'val_loss']
    if isnotebook():
        mbar.update_graph(graph)
def line_head_builder(metric_state: Dict):
    """Build the header row of the progress table.

    The row starts with ``'epoch'``, contains a ``trn_<key>`` / ``val_<key>``
    pair for every key of ``metric_state['train']`` and ends with ``'time'``
    and ``'remain'``.

    Args:
        metric_state: Mapping whose ``'train'`` entry supplies the metric keys.

    Returns:
        The list of column names when ``isnotebook()`` is truthy, otherwise
        whatever ``build_line_console`` returns for that list.
    """
    train: Dict = metric_state.get('train')

    header = ['epoch']
    for name in train.keys():
        header.extend((f'trn_{name}', f'val_{name}'))
    header.extend(('time', 'remain'))

    if not isnotebook():
        return build_line_console(header)
    return header
def line_builder_resume(metric_state: Dict, epoch, tdelta, tremain):
    """Build the table row for one previously recorded epoch.

    The row holds ``epoch + 1``, then for every key of
    ``metric_state['train']`` the train and valid values stored at
    ``[key]['epoch'][epoch]`` formatted with six decimals, then ``tdelta``
    and ``tremain``.

    Args:
        metric_state: Mapping with ``'train'`` and ``'valid'`` entries.
        epoch: Zero-based index into the per-epoch value sequences.
        tdelta: Value rendered into the second-to-last cell.
        tremain: Value rendered into the last cell.

    Returns:
        The list of cell strings when ``isnotebook()`` is truthy, otherwise
        whatever ``build_line_console`` returns for that list.
    """
    train: Dict = metric_state.get('train')
    valid: Dict = metric_state.get('valid')

    cells = [f'{epoch+1}']
    # Train value first, then valid value, for each metric key in turn.
    cells.extend(
        f"{series[key]['epoch'][epoch]:.6f}"
        for key in train.keys()
        for series in (train, valid)
    )
    cells.append(f'{tdelta}')
    cells.append(f'{tremain}')

    return cells if isnotebook() else build_line_console(cells)