Exemplo n.º 1
0
    def on_resume_begin(self, *args: Any, **kwargs: Any) -> None:
        """Re-print the metric table when training is resumed.

        Steps:
          1. Decrements the trainer's ``epoch['curr']`` counter by one.
          2. Passes the metric state to ``clean_up_metric_resume`` together
             with the pre-decrement epoch value.
          3. Writes the header row produced by ``line_head_builder``.
          4. Writes one row per epoch index in ``range(epoch_curr - 1)``,
             formatted by ``line_builder_resume``.

        Args:
            **kwargs: must contain ``master_bar``, the progress bar that the
                rows are written to. It is also stored on ``self.mbar``.
        """
        self.mbar: master_bar = kwargs.get('master_bar')
        # Reuse the bar fetched above rather than reading kwargs a second time.
        mbar: master_bar = self.mbar

        trainer_state: Dict = self.statemgr.state.get('trainer')
        epoch_state: Dict = trainer_state.get('epoch')
        epoch_curr = epoch_state['curr']
        # NOTE(review): presumably rolls back so the interrupted epoch is
        # re-run after resuming -- confirm against the trainer loop.
        epoch_state['curr'] = epoch_curr - 1

        metric_state: Dict = self.statemgr.state.get('metric')
        # clean up metric
        clean_up_metric_resume(metric_state, epoch_curr)

        # The notebook check selects table vs plain-text output for every
        # write below; evaluate it once instead of per row.
        as_table = bool(isnotebook())

        # header write
        mbar.write(line_head_builder(metric_state), table=as_table)

        # These lists are loop-invariant, so fetch them once.
        times = epoch_state.get("time")
        remains = epoch_state.get("remain")
        for epoch in range(epoch_curr - 1):
            tdelta = time_formatter(times[epoch])
            tremain = time_formatter(remains[epoch])
            line = line_builder_resume(metric_state, epoch, tdelta, tremain)
            mbar.write(line, table=as_table)
Exemplo n.º 2
0
 def on_validate_end(self, *args: Any, **kwargs: Any) -> None:
     """Write the metric-table header to the master bar while the epoch counter is 0.

     The header row comes from ``line_head_builder``. It is written as a table
     in a notebook and as plain text otherwise. For any other value of the
     trainer's ``epoch['curr']`` this hook does nothing.
     """
     bar: master_bar = kwargs.get('master_bar')
     trainer_state: TrainerState = self.statemgr.get_state('trainer')
     if trainer_state.get_property('epoch')['curr'] != 0:
         # The header is only emitted while the counter is still 0.
         return

     metric_state: MetricState = self.statemgr.get_state('metric')
     header = line_head_builder(metric_state)
     bar.write(header, table=bool(isnotebook()))
Exemplo n.º 3
0
    def on_epoch_end(self, *args: Any, **kwargs: Any) -> None:
        """Write the row for the finished epoch and possibly refresh the loss graph.

        The row is built by ``line_builder`` from three inputs:
          - the metric state,
          - the ``epoch`` keyword argument,
          - the elapsed/remaining times returned by ``time_delta_remain``.

        Once the trainer's ``epoch['curr']`` counter is greater than 1, a
        graph is built by ``graph_builder`` and the bar's series names are set.
        The graph is only pushed to the bar when running in a notebook.
        """
        bar: master_bar = kwargs.get('master_bar')
        epoch: int = kwargs.get("epoch")
        trainer_state: TrainerState = self.statemgr.get_state('trainer')
        metric_state: MetricState = self.statemgr.get_state('metric')

        elapsed, remaining = time_delta_remain(trainer_state.get_property('epoch'))
        row = line_builder(metric_state, epoch, elapsed, remaining)
        bar.write(row, table=bool(isnotebook()))

        if trainer_state.get_property('epoch')['curr'] <= 1:
            return

        graph = graph_builder(metric_state, trainer_state)
        bar.names = ['trn_loss', 'val_loss']
        if isnotebook():
            bar.update_graph(graph)
Exemplo n.º 4
0
def line_head_builder(metric_state: Dict):
    """Build the header row of the metric table.

    The columns are, in order:
      - ``epoch``;
      - a ``trn_<name>``, ``val_<name>`` pair for every key of
        ``metric_state['train']``;
      - ``time`` and ``remain``.

    Returns:
        The list of column labels when running in a notebook. Otherwise,
        whatever ``build_line_console`` returns for that list.
    """
    metric_names: Dict = metric_state.get('train')

    columns = ['epoch']
    columns.extend(
        f'{split}_{name}'
        for name in metric_names.keys()
        for split in ('trn', 'val')
    )
    columns.extend(('time', 'remain'))

    return columns if isnotebook() else build_line_console(columns)
Exemplo n.º 5
0
def line_builder_resume(metric_state: Dict, epoch, tdelta, tremain):
    """Build one metric-table row from stored per-epoch history.

    The row contains, in order:
      - the 1-based epoch number;
      - for each key of ``metric_state['train']``, the train value and then
        the valid value at index ``epoch`` of the key's ``'epoch'`` list,
        each formatted to 6 decimal places;
      - ``tdelta`` and ``tremain``, formatted via f-strings.

    Returns:
        The list of cells when running in a notebook. Otherwise, whatever
        ``build_line_console`` returns for that list.
    """
    trn_metrics: Dict = metric_state.get('train')
    val_metrics: Dict = metric_state.get('valid')

    cells = [f'{epoch+1}']
    for name in trn_metrics.keys():
        cells.extend((
            f"{trn_metrics[name]['epoch'][epoch]:.6f}",
            f"{val_metrics[name]['epoch'][epoch]:.6f}",
        ))
    cells += [f'{tdelta}', f'{tremain}']

    return cells if isnotebook() else build_line_console(cells)