Example #1
0
    def _default_iteration_writer(self, engine: Engine, writer: SummaryWriter):
        """Write iteration-level values from Ignite ``engine.state`` to TensorBoard.

        The output of ``self.output_transform(engine.state.output)`` is expected to be
        either a scalar or a dict of name -> scalar; by default this is the loss of
        the current iteration. Non-scalar entries are skipped with a warning.

        Args:
            engine (ignite.engine): Ignite Engine, it can be a trainer, validator or evaluator.
            writer (SummaryWriter): TensorBoard writer, created in TensorBoardHandler.

        """
        loss = self.output_transform(engine.state.output)
        if loss is None:
            return  # nothing to record when the transformed output is empty
        step = engine.state.iteration
        if isinstance(loss, dict):
            # plot each scalar entry under its own tag, in deterministic key order
            for name in sorted(loss):
                value = loss[name]
                if is_scalar(value):
                    writer.add_scalar(name, value.item() if torch.is_tensor(value) else value, step)
                else:
                    # multi-dimensional values cannot be plotted as scalars; skip them
                    warnings.warn('ignoring non-scalar output in TensorBoardStatsHandler,'
                                  ' make sure `output_transform(engine.state.output)` returns'
                                  ' a scalar or dictionary of key and scalar pairs to avoid this warning.'
                                  ' {}:{}'.format(name, type(value)))
        elif is_scalar(loss):
            writer.add_scalar(self.tag_name, loss.item() if torch.is_tensor(loss) else loss, step)
        else:
            warnings.warn('ignoring non-scalar output in TensorBoardStatsHandler,'
                          ' make sure `output_transform(engine.state.output)` returns'
                          ' a scalar or a dictionary of key and scalar pairs to avoid this warning.'
                          ' {}'.format(type(loss)))
        writer.flush()
Example #2
0
    def _default_iteration_print(self, engine: Engine):
        """Log iteration-level values from Ignite ``engine.state`` via ``self.logger``.

        The output of ``self.output_transform(engine.state.output)`` is expected to be
        either a scalar or a dict of name -> scalar; by default this is the loss of
        the current iteration. Non-scalar entries are skipped with a warning.

        Args:
            engine (ignite.engine): Ignite Engine, it can be a trainer, validator or evaluator.

        """
        loss = self.output_transform(engine.state.output)
        if loss is None:
            return  # nothing to log when the transformed output is empty

        # collect formatted key/value fragments, then join (same result as repeated +=)
        pieces = []
        if isinstance(loss, dict):
            for key in sorted(loss):
                item = loss[key]
                if is_scalar(item):
                    pieces.append(self.key_var_format.format(
                        key, item.item() if torch.is_tensor(item) else item))
                else:
                    # multi-dimensional values cannot be printed as scalars; skip them
                    warnings.warn(
                        'ignoring non-scalar output in StatsHandler,'
                        ' make sure `output_transform(engine.state.output)` returns'
                        ' a scalar or dictionary of key and scalar pairs to avoid this warning.'
                        ' {}:{}'.format(key, type(item)))
        elif is_scalar(loss):
            pieces.append(self.key_var_format.format(
                self.tag_name, loss.item() if torch.is_tensor(loss) else loss))
        else:
            warnings.warn(
                'ignoring non-scalar output in StatsHandler,'
                ' make sure `output_transform(engine.state.output)` returns'
                ' a scalar or a dictionary of key and scalar pairs to avoid this warning.'
                ' {}'.format(type(loss)))

        out_str = ''.join(pieces)
        if not out_str:
            return  # everything was skipped, nothing to print

        # convert the global iteration counter into a 1-based position within the epoch
        iters_per_epoch = engine.state.epoch_length
        cur_iter = (engine.state.iteration - 1) % iters_per_epoch + 1

        base_str = "Epoch: {}/{}, Iter: {}/{} --".format(
            engine.state.epoch, engine.state.max_epochs, cur_iter, iters_per_epoch)

        self.logger.info(' '.join([base_str, out_str]))