def _default_iteration_writer(self, engine: Engine, writer: SummaryWriter):
    """Record the current iteration's output as TensorBoard scalar(s).

    ``self.output_transform`` is applied to ``engine.state.output``. The result is
    written against the global step ``engine.state.iteration`` as follows:

    * ``None``: nothing is written and the writer is not flushed.
    * dict: every scalar entry is written under its own key, visiting keys in
      sorted order. Non-scalar entries trigger a warning and are skipped.
    * scalar: written under ``self.tag_name``.
    * anything else: a warning is issued and nothing is written.

    Tensor values are converted with ``.item()`` before writing. The writer is
    flushed after every non-``None`` output.

    Args:
        engine (ignite.engine): Ignite Engine, it can be a trainer, validator or evaluator.
        writer (SummaryWriter): TensorBoard writer, created in TensorBoardHandler.
    """
    output = self.output_transform(engine.state.output)
    if output is None:
        return  # empty output: skip writing (and flushing) entirely

    if not isinstance(output, dict):
        if is_scalar(output):
            writer.add_scalar(
                self.tag_name,
                output.item() if torch.is_tensor(output) else output,
                engine.state.iteration,
            )
        else:
            # multi-dimensional output cannot be plotted as a single scalar
            warnings.warn('ignoring non-scalar output in TensorBoardStatsHandler,'
                          ' make sure `output_transform(engine.state.output)` returns'
                          ' a scalar or a dictionary of key and scalar pairs to avoid this warning.'
                          ' {}'.format(type(output)))
        writer.flush()
        return

    for key in sorted(output):
        entry = output[key]
        if is_scalar(entry):
            writer.add_scalar(
                key,
                entry.item() if torch.is_tensor(entry) else entry,
                engine.state.iteration,
            )
        else:
            # multi-dimensional entries cannot be plotted as a single scalar
            warnings.warn('ignoring non-scalar output in TensorBoardStatsHandler,'
                          ' make sure `output_transform(engine.state.output)` returns'
                          ' a scalar or dictionary of key and scalar pairs to avoid this warning.'
                          ' {}:{}'.format(key, type(entry)))
    writer.flush()
def _default_iteration_print(self, engine: Engine):
    """Log the current iteration's output via ``self.logger.info``.

    ``self.output_transform`` is applied to ``engine.state.output``:

    * ``None``: nothing is logged.
    * dict: each scalar entry is formatted with ``self.key_var_format`` using its
      key, visiting keys in sorted order. Non-scalar entries trigger a warning and
      are skipped.
    * scalar: formatted with ``self.key_var_format`` using ``self.tag_name``.
    * anything else: a warning is issued.

    Tensor values are converted with ``.item()`` before formatting. If no value
    could be formatted, nothing is logged. Otherwise the line is prefixed with
    ``"Epoch: <epoch>/<max_epochs>, Iter: <iter>/<epoch_length> --"``.

    When ``engine.state.epoch_length`` is falsy (e.g. ``None`` while Ignite has not
    yet determined the epoch size), the raw global iteration is shown instead of
    the within-epoch index. This avoids a ``TypeError``/``ZeroDivisionError`` from
    the modulo.

    Args:
        engine (ignite.engine): Ignite Engine, it can be a trainer, validator or evaluator.
    """
    loss = self.output_transform(engine.state.output)
    if loss is None:
        return  # no printing if the output is empty

    parts = []
    if isinstance(loss, dict):
        for name in sorted(loss):
            value = loss[name]
            if not is_scalar(value):
                warnings.warn(
                    'ignoring non-scalar output in StatsHandler,'
                    ' make sure `output_transform(engine.state.output)` returns'
                    ' a scalar or dictionary of key and scalar pairs to avoid this warning.'
                    ' {}:{}'.format(name, type(value)))
                continue  # not printing multi dimensional output
            parts.append(self.key_var_format.format(
                name, value.item() if torch.is_tensor(value) else value))
    elif is_scalar(loss):
        parts.append(self.key_var_format.format(
            self.tag_name, loss.item() if torch.is_tensor(loss) else loss))
    else:
        # not printing multi dimensional output
        warnings.warn(
            'ignoring non-scalar output in StatsHandler,'
            ' make sure `output_transform(engine.state.output)` returns'
            ' a scalar or a dictionary of key and scalar pairs to avoid this warning.'
            ' {}'.format(type(loss)))

    out_str = ''.join(parts)
    if not out_str:
        return  # no value to print

    num_iterations = engine.state.epoch_length
    current_iteration = engine.state.iteration
    if num_iterations:
        # map the global iteration counter to a 1-based index within the current epoch
        current_iteration = (current_iteration - 1) % num_iterations + 1
    current_epoch = engine.state.epoch
    num_epochs = engine.state.max_epochs

    base_str = "Epoch: {}/{}, Iter: {}/{} --".format(
        current_epoch, num_epochs, current_iteration, num_iterations)
    self.logger.info(' '.join([base_str, out_str]))