Example no. 1
0
def test_timer():
    """Exercise ``Timer`` attached to a trainer/tester pair of engines.

    Checks the total elapsed time, the per-iteration average and the
    training-only time, then verifies that ``reset()`` zeroes a timer.
    """
    delay = 0.2
    num_iters = 3

    def _train_func(engine, batch):
        time.sleep(delay)

    def _test_func(engine, batch):
        time.sleep(delay)

    trainer = Engine(_train_func)
    tester = Engine(_test_func)

    total_timer = Timer()
    batch_timer = Timer(average=True)
    train_timer = Timer()

    total_timer.attach(trainer)
    batch_timer.attach(
        trainer,
        resume=Events.ITERATION_STARTED,
        pause=Events.ITERATION_COMPLETED,
        step=Events.ITERATION_COMPLETED,
    )
    train_timer.attach(trainer,
                       resume=Events.EPOCH_STARTED,
                       pause=Events.EPOCH_COMPLETED)

    @trainer.on(Events.EPOCH_COMPLETED)
    def run_validation(trainer):
        tester.run(range(num_iters))

    # Run "training": every trainer epoch triggers one full tester run,
    # so total wall time covers both engines.
    trainer.run(range(num_iters))

    def _close(lhs, rhs):
        # One-decimal comparison tolerates scheduler jitter in sleep().
        return round(lhs, 1) == round(rhs, 1)

    assert _close(total_timer.value(), 2 * num_iters * delay)
    assert _close(batch_timer.value(), delay)
    assert _close(train_timer.value(), num_iters * delay)

    total_timer.reset()
    assert _close(total_timer.value(), 0.0)
Example no. 2
0
class BasicTimeProfiler:
    """
    BasicTimeProfiler can be used to profile the handlers,
    events, data loading and data processing times.

    Examples:

    .. code-block:: python

        from ignite.contrib.handlers import BasicTimeProfiler

        trainer = Engine(train_updater)

        # Create an object of the profiler and attach an engine to it
        profiler = BasicTimeProfiler()
        profiler.attach(trainer)

        @trainer.on(Events.EPOCH_COMPLETED)
        def log_intermediate_results():
            profiler.print_results(profiler.get_results())

        trainer.run(dataloader, max_epochs=3)

        profiler.write_results('path_to_dir/time_profiling.csv')

    """

    # Events that carry no useful timing information and are excluded from
    # both the collected stats and the reported handler names.
    events_to_ignore = [
        Events.EXCEPTION_RAISED,
        Events.TERMINATE,
        Events.TERMINATE_SINGLE_EPOCH,
        Events.DATALOADER_STOP_ITERATION,
    ]

    def __init__(self):
        # One independent timer per profiled category.
        self._dataflow_timer = Timer()
        self._processing_timer = Timer()
        self._event_handlers_timer = Timer()

        # Result tensors; allocated lazily in _reset() once the number of
        # epochs/iterations is known (at Events.STARTED).
        self.dataflow_times = None
        self.processing_times = None
        self.event_handlers_times = None

        # The three lists below are kept in the same order on purpose:
        # they are zip()-ed together in _as_first_started(), pairing each
        # event with its "first" (timer reset) and "last" (record elapsed)
        # profiling handler.
        self._events = [
            Events.EPOCH_STARTED,
            Events.EPOCH_COMPLETED,
            Events.ITERATION_STARTED,
            Events.ITERATION_COMPLETED,
            Events.GET_BATCH_STARTED,
            Events.GET_BATCH_COMPLETED,
            Events.COMPLETED,
        ]
        self._fmethods = [
            self._as_first_epoch_started,
            self._as_first_epoch_completed,
            self._as_first_iter_started,
            self._as_first_iter_completed,
            self._as_first_get_batch_started,
            self._as_first_get_batch_completed,
            self._as_first_completed,
        ]
        self._lmethods = [
            self._as_last_epoch_started,
            self._as_last_epoch_completed,
            self._as_last_iter_started,
            self._as_last_iter_completed,
            self._as_last_get_batch_started,
            self._as_last_get_batch_completed,
            self._as_last_completed,
        ]

    def _reset(self, num_epochs, total_num_iters):
        """(Re)allocate the result tensors: one slot per epoch or per
        iteration as appropriate; STARTED/COMPLETED fire once per run."""
        self.dataflow_times = torch.zeros(total_num_iters)
        self.processing_times = torch.zeros(total_num_iters)
        self.event_handlers_times = {
            Events.STARTED: torch.zeros(1),
            Events.COMPLETED: torch.zeros(1),
            Events.EPOCH_STARTED: torch.zeros(num_epochs),
            Events.EPOCH_COMPLETED: torch.zeros(num_epochs),
            Events.ITERATION_STARTED: torch.zeros(total_num_iters),
            Events.ITERATION_COMPLETED: torch.zeros(total_num_iters),
            Events.GET_BATCH_COMPLETED: torch.zeros(total_num_iters),
            Events.GET_BATCH_STARTED: torch.zeros(total_num_iters),
        }

    def _as_first_started(self, engine):
        """Runs as the very first STARTED handler: sizes the storage and
        installs the profiling handlers around the user's handlers."""
        # Prefer the dataloader length; fall back to the engine's declared
        # epoch_length when the dataloader is not sized.
        if hasattr(engine.state.dataloader, "__len__"):
            num_iters_per_epoch = len(engine.state.dataloader)
        else:
            num_iters_per_epoch = engine.state.epoch_length

        self.max_epochs = engine.state.max_epochs
        self.total_num_iters = self.max_epochs * num_iters_per_epoch
        self._reset(self.max_epochs, self.total_num_iters)

        # Snapshot the user-registered handler names per event for reporting.
        self.event_handlers_names = {
            e: [
                h.__qualname__
                if hasattr(h, "__qualname__") else h.__class__.__name__
                for (h, _, _) in engine._event_handlers[e]
                if "BasicTimeProfiler." not in repr(
                    h)  # avoid adding internal handlers into output
            ]
            for e in Events if e not in self.events_to_ignore
        }

        # Setup all other handlers:
        engine._event_handlers[Events.STARTED].append(
            (self._as_last_started, (engine, ), {}))

        # _fmethods must run BEFORE the user's handlers (insert at 0) and
        # _lmethods AFTER them (append), so the elapsed time brackets
        # exactly the user handlers.
        for e, m in zip(self._events, self._fmethods):
            engine._event_handlers[e].insert(0, (m, (engine, ), {}))

        for e, m in zip(self._events, self._lmethods):
            engine._event_handlers[e].append((m, (engine, ), {}))

        # Let's go
        self._event_handlers_timer.reset()

    def _as_last_started(self, engine):
        # Time spent in the user's STARTED handlers (timer was reset at the
        # end of _as_first_started).
        self.event_handlers_times[
            Events.STARTED][0] = self._event_handlers_timer.value()

    def _as_first_epoch_started(self, engine):
        self._event_handlers_timer.reset()

    def _as_last_epoch_started(self, engine):
        # Record handler time for this epoch (epoch is 1-based).
        t = self._event_handlers_timer.value()
        e = engine.state.epoch - 1
        self.event_handlers_times[Events.EPOCH_STARTED][e] = t

    def _as_first_get_batch_started(self, engine):
        # Dataflow timing starts here and ends in
        # _as_last_get_batch_completed.
        self._event_handlers_timer.reset()
        self._dataflow_timer.reset()

    def _as_last_get_batch_started(self, engine):
        t = self._event_handlers_timer.value()
        i = engine.state.iteration - 1
        self.event_handlers_times[Events.GET_BATCH_STARTED][i] = t

    def _as_first_get_batch_completed(self, engine):
        self._event_handlers_timer.reset()

    def _as_last_get_batch_completed(self, engine):
        t = self._event_handlers_timer.value()
        i = engine.state.iteration - 1
        self.event_handlers_times[Events.GET_BATCH_COMPLETED][i] = t

        # Total time to fetch the batch (includes the GET_BATCH_* handlers
        # that ran in between).
        d = self._dataflow_timer.value()
        self.dataflow_times[i] = d

        self._dataflow_timer.reset()

    def _as_first_iter_started(self, engine):
        self._event_handlers_timer.reset()

    def _as_last_iter_started(self, engine):
        t = self._event_handlers_timer.value()
        i = engine.state.iteration - 1
        self.event_handlers_times[Events.ITERATION_STARTED][i] = t

        # Processing time covers the engine's process_function, measured
        # until _as_first_iter_completed.
        self._processing_timer.reset()

    def _as_first_iter_completed(self, engine):
        t = self._processing_timer.value()
        i = engine.state.iteration - 1
        self.processing_times[i] = t

        self._event_handlers_timer.reset()

    def _as_last_iter_completed(self, engine):
        t = self._event_handlers_timer.value()
        i = engine.state.iteration - 1
        self.event_handlers_times[Events.ITERATION_COMPLETED][i] = t

    def _as_first_epoch_completed(self, engine):
        self._event_handlers_timer.reset()

    def _as_last_epoch_completed(self, engine):
        t = self._event_handlers_timer.value()
        e = engine.state.epoch - 1
        self.event_handlers_times[Events.EPOCH_COMPLETED][e] = t

    def _as_first_completed(self, engine):
        self._event_handlers_timer.reset()

    def _as_last_completed(self, engine):
        self.event_handlers_times[
            Events.COMPLETED][0] = self._event_handlers_timer.value()

        # Remove added handlers:
        engine.remove_event_handler(self._as_last_started, Events.STARTED)

        for e, m in zip(self._events, self._fmethods):
            engine.remove_event_handler(m, e)

        for e, m in zip(self._events, self._lmethods):
            engine.remove_event_handler(m, e)

    def attach(self, engine):
        """Attach the profiler to *engine*.

        Only registers :meth:`_as_first_started`; everything else is
        installed lazily when the engine starts. Idempotent.
        """
        if not isinstance(engine, Engine):
            raise TypeError("Argument engine should be ignite.engine.Engine, "
                            "but given {}".format(type(engine)))

        if not engine.has_event_handler(self._as_first_started):
            engine._event_handlers[Events.STARTED].insert(
                0, (self._as_first_started, (engine, ), {}))

    @staticmethod
    def _compute_basic_stats(data):
        """Return an OrderedDict of total/min/max/mean/std over the
        non-zero entries of *data* (zeros mean "not measured")."""
        # compute on non-zero data:
        data = data[data > 0]
        out = [
            ("total",
             torch.sum(data).item() if len(data) > 0 else "not yet triggered")
        ]
        # min/max/mean/std only make sense with at least two samples.
        if len(data) > 1:
            out += [
                ("min/index", (torch.min(data).item(),
                               torch.argmin(data).item())),
                ("max/index", (torch.max(data).item(),
                               torch.argmax(data).item())),
                ("mean", torch.mean(data).item()),
                ("std", torch.std(data).item()),
            ]
        return OrderedDict(out)

    def get_results(self):
        """
        Method to fetch the aggregated profiler results after the engine is run

        .. code-block:: python

            results = profiler.get_results()

        """
        # Grand total of time spent in all profiled event handlers.
        total_eh_time = sum([(self.event_handlers_times[e]).sum()
                             for e in Events
                             if e not in self.events_to_ignore])

        return OrderedDict([
            ("processing_stats",
             self._compute_basic_stats(self.processing_times)),
            ("dataflow_stats", self._compute_basic_stats(self.dataflow_times)),
            (
                "event_handlers_stats",
                dict([(str(e.name).replace(".", "_"),
                       self._compute_basic_stats(self.event_handlers_times[e]))
                      for e in Events if e not in self.events_to_ignore] +
                     [("total_time", total_eh_time)]),
            ),
            (
                "event_handlers_names",
                {
                    str(e.name).replace(".", "_") + "_names": v
                    for e, v in self.event_handlers_names.items()
                },
            ),
        ])

    def write_results(self, output_path):
        """
        Method to store the unaggregated profiling results to a csv file

        .. code-block:: python

            profiler.write_results('path_to_dir/awesome_filename.csv')

        Example output:

        .. code-block:: text

            -----------------------------------------------------------------
            epoch iteration processing_stats dataflow_stats Event_STARTED ...
            1.0     1.0        0.00003         0.252387        0.125676
            1.0     2.0        0.00029         0.252342        0.125123

        """
        try:
            import pandas as pd
        except ImportError:
            print("Need pandas to write results as files")
            return

        iters_per_epoch = self.total_num_iters // self.max_epochs

        # 1-based epoch/iteration index columns.
        epochs = torch.arange(
            self.max_epochs,
            dtype=torch.float32).repeat_interleave(iters_per_epoch) + 1
        iterations = torch.arange(self.total_num_iters,
                                  dtype=torch.float32) + 1
        processing_stats = self.processing_times
        dataflow_stats = self.dataflow_times

        # Per-run and per-epoch values are broadcast across iterations so
        # that every column has total_num_iters rows.
        event_started = self.event_handlers_times[
            Events.STARTED].repeat_interleave(self.total_num_iters)
        event_completed = self.event_handlers_times[
            Events.COMPLETED].repeat_interleave(self.total_num_iters)
        event_epoch_started = self.event_handlers_times[
            Events.EPOCH_STARTED].repeat_interleave(iters_per_epoch)
        event_epoch_completed = self.event_handlers_times[
            Events.EPOCH_COMPLETED].repeat_interleave(iters_per_epoch)

        event_iter_started = self.event_handlers_times[
            Events.ITERATION_STARTED]
        event_iter_completed = self.event_handlers_times[
            Events.ITERATION_COMPLETED]
        event_batch_started = self.event_handlers_times[
            Events.GET_BATCH_STARTED]
        event_batch_completed = self.event_handlers_times[
            Events.GET_BATCH_COMPLETED]

        results_dump = torch.stack(
            [
                epochs,
                iterations,
                processing_stats,
                dataflow_stats,
                event_started,
                event_completed,
                event_epoch_started,
                event_epoch_completed,
                event_iter_started,
                event_iter_completed,
                event_batch_started,
                event_batch_completed,
            ],
            dim=1,
        ).numpy()

        results_df = pd.DataFrame(
            data=results_dump,
            columns=[
                "epoch",
                "iteration",
                "processing_stats",
                "dataflow_stats",
                "Event_STARTED",
                "Event_COMPLETED",
                "Event_EPOCH_STARTED",
                "Event_EPOCH_COMPLETED",
                "Event_ITERATION_STARTED",
                "Event_ITERATION_COMPLETED",
                "Event_GET_BATCH_STARTED",
                "Event_GET_BATCH_COMPLETED",
            ],
        )
        results_df.to_csv(output_path, index=False)

    @staticmethod
    def print_results(results):
        """
        Method to print the aggregated results from the profiler

        .. code-block:: python

            profiler.print_results(results)

        Example output:

        .. code-block:: text

             ----------------------------------------------------
            | Time profiling stats (in seconds):                 |
             ----------------------------------------------------
            total  |  min/index  |  max/index  |  mean  |  std

            Processing function:
            157.46292 | 0.01452/1501 | 0.26905/0 | 0.07730 | 0.01258

            Dataflow:
            6.11384 | 0.00008/1935 | 0.28461/1551 | 0.00300 | 0.02693

            Event handlers:
            2.82721

            - Events.STARTED: []
            0.00000

            - Events.EPOCH_STARTED: []
            0.00006 | 0.00000/0 | 0.00000/17 | 0.00000 | 0.00000

            - Events.ITERATION_STARTED: ['PiecewiseLinear']
            0.03482 | 0.00001/188 | 0.00018/679 | 0.00002 | 0.00001

            - Events.ITERATION_COMPLETED: ['TerminateOnNan']
            0.20037 | 0.00006/866 | 0.00089/1943 | 0.00010 | 0.00003

            - Events.EPOCH_COMPLETED: ['empty_cuda_cache', 'training.<locals>.log_elapsed_time', ]
            2.57860 | 0.11529/0 | 0.14977/13 | 0.12893 | 0.00790

            - Events.COMPLETED: []
            not yet triggered

        """
        def to_str(v):
            # Render a stat value: pass strings through, format
            # (value, index) tuples as "value/index", floats to 5 places.
            if isinstance(v, str):
                return v
            elif isinstance(v, tuple):
                return "{:.5f}/{}".format(v[0], v[1])
            return "{:.5f}".format(v)

        def odict_to_str(d):
            out = " | ".join([to_str(v) for v in d.values()])
            return out

        # Flatten event-handler stats into template-ready strings.
        others = {
            k: odict_to_str(v) if isinstance(v, OrderedDict) else v
            for k, v in results["event_handlers_stats"].items()
        }

        others.update(results["event_handlers_names"])

        output_message = """
 ----------------------------------------------------
| Time profiling stats (in seconds):                 |
 ----------------------------------------------------
total  |  min/index  |  max/index  |  mean  |  std

Processing function:
{processing_stats}

Dataflow:
{dataflow_stats}

Event handlers:
{total_time:.5f}

- Events.STARTED: {STARTED_names}
{STARTED}

- Events.EPOCH_STARTED: {EPOCH_STARTED_names}
{EPOCH_STARTED}

- Events.ITERATION_STARTED: {ITERATION_STARTED_names}
{ITERATION_STARTED}

- Events.ITERATION_COMPLETED: {ITERATION_COMPLETED_names}
{ITERATION_COMPLETED}

- Events.EPOCH_COMPLETED: {EPOCH_COMPLETED_names}
{EPOCH_COMPLETED}

- Events.COMPLETED: {COMPLETED_names}
{COMPLETED}
""".format(
            processing_stats=odict_to_str(results["processing_stats"]),
            dataflow_stats=odict_to_str(results["dataflow_stats"]),
            **others,
        )
        print(output_message)
        return output_message
class DatasetExperiment(BaseExperiment):
    """Experiment that trains over a dataset with an ignite ``Engine``.

    Wires an ``Engine`` around :meth:`train_step`, attaches per-epoch
    logging/metric handlers and a per-batch timer. Subclasses must
    implement :meth:`optimize`, :meth:`infer` and :meth:`K_step`.
    """

    def __init__(self, train, niter, nepoch, eval=None, gpu_id=None,
                 sacred_run=None, writers=None, root=None, corruption=None,
                 **kwargs):
        """
        Args:
            train: training dataset/dataloader passed to ``Engine.run``.
            niter: number of iterations (stored for subclasses).
            nepoch: number of epochs to run.
            eval: optional evaluation dataset. NOTE: the name shadows the
                ``eval`` builtin but is kept for caller compatibility.
            gpu_id: list of GPU ids; defaults to ``[0]``.
            sacred_run: optional sacred run object (used for run id and log
                prefixes).
            writers: optional writer specs forwarded to ``init_writers``.
            root: output root directory; when ``None``, no writers are
                created and nothing is persisted.
            corruption: optional data-corruption spec stored for subclasses.
        """
        super().__init__(**kwargs)

        self.train = train
        self.eval = eval

        self.corruption = corruption

        self.niter = niter
        self.nepoch = nepoch
        self.device = 'cuda:0'

        self.sacred_run = sacred_run

        # Avoid a mutable default argument while preserving the historical
        # default of [0].
        self.gpu_id = [0] if gpu_id is None else gpu_id

        if root is not None:
            self.basedir = make_basedir(os.path.join(root, str(sacred_run._id)))
        else:
            # Without an output root there is nowhere to write, so disable
            # writers regardless of what was passed.
            writers = None

        if writers is not None:
            self.writers = init_writers(*writers, sacred_run=sacred_run, dirname=self.basedir)
        else:
            self.writers = None

        self.create_trainer()

    def create_trainer(self):
        """Build the training engine, progress bar and batch timer."""
        self.trainer = Engine(self.train_step)
        self.trainer.add_event_handler(Events.EPOCH_COMPLETED, self.K_step)
        self.trainer.add_event_handler(Events.EPOCH_COMPLETED, self.log_metrics, 'train')
        self.trainer.add_event_handler(Events.ITERATION_COMPLETED, self.write_metrics, 'train')

        self.pbar = ProgressBar()

        # Average time per batch; total time per epoch is read from
        # timer.total in print_times().
        self.timer = Timer(average=True)
        self.timer.attach(self.trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED,
                          pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED)

        self.trainer.add_event_handler(Events.EPOCH_COMPLETED, self.print_times)

    def print_times(self, engine):
        """Write epoch/batch timings to the writers and reset the timer."""
        iteration = engine.state.iteration
        if self.writers is not None:
            self.writers.add_scalar('time_per_epoch', self.timer.total, iteration)
            self.writers.add_scalar('time_per_batch', self.timer.value(), iteration)
        self.timer.reset()

    def train_step(self, engine, batch):
        """Engine process function: run one optimization step on *batch*."""
        # self.training() presumably switches the model to train mode —
        # defined on BaseExperiment; confirm there.
        self.training()
        if isinstance(batch, Mapping):
            return self.optimize(**convert_tensor(batch, self.device))
        else:
            raise NotImplementedError

    def infer(self, **kwargs):
        """Inference step; must be provided by subclasses."""
        raise NotImplementedError

    def optimize(self, **kwargs):
        """Single optimization step; must be provided by subclasses."""
        raise NotImplementedError

    def K_step(self, **kwargs):
        """Per-epoch update hook; must be provided by subclasses."""
        raise NotImplementedError

    def log_metrics(self, engine, dataset_name):
        """Log current engine metrics through the progress bar."""
        iteration = self.trainer.state.iteration
        epoch = self.trainer.state.epoch
        log = f"EPOCH {epoch} - ITER {iteration} - {dataset_name}"

        if self.sacred_run is not None:
            log = f"ID {self.sacred_run._id} - " + log

        for metric, value in engine.state.metrics.items():
            # Guard against None metrics so the format spec does not raise.
            if value is None:
                value = float('nan')
            log += f" {metric}: {value:.6f}"

        log += self.extra_log()
        self.pbar.log_message(log)

    def extra_log(self):
        """Hook for subclasses to append extra text to the log line."""
        return ""

    def write_metrics(self, engine, dataset_name):
        """Write engine metrics to the writers, namespaced by dataset."""
        if self.writers is not None:
            for metric, value in engine.state.metrics.items():
                name = dataset_name + '/' + metric
                self.writers.add_scalar(name, value, self.trainer.state.iteration)

    def run(self):
        """Run training for the configured number of epochs."""
        self.trainer.run(self.train, max_epochs=self.nepoch)
Example no. 4
0
class HandlersTimeProfiler:
    """
    HandlersTimeProfiler can be used to profile the handlers,
    data loading and data processing times. Custom events are also
    profiled by this profiler

    Examples:

    .. code-block:: python

        from ignite.contrib.handlers import HandlersTimeProfiler

        trainer = Engine(train_updater)

        # Create an object of the profiler and attach an engine to it
        profiler = HandlersTimeProfiler()
        profiler.attach(trainer)

        @trainer.on(Events.EPOCH_COMPLETED)
        def log_intermediate_results():
            profiler.print_results(profiler.get_results())

        trainer.run(dataloader, max_epochs=3)

        profiler.write_results('path_to_dir/time_profiling.csv')

    """

    # Times below this are discarded for handlers attached with an event
    # filter (their wrapper fires on every event, even filtered-out ones).
    # NOTE(review): name misspells "THRESHOLD" but is public API; kept as-is.
    EVENT_FILTER_THESHOLD_TIME = 0.0001

    def __init__(self) -> None:
        # Independent timers for the three profiled categories.
        self._dataflow_timer = Timer()
        self._processing_timer = Timer()
        self._event_handlers_timer = Timer()

        self.dataflow_times = []  # type: List[float]
        self.processing_times = []  # type: List[float]
        # Per-event, per-handler-name lists of elapsed times.
        self.event_handlers_times = {
        }  # type: Dict[EventEnum, Dict[str, List[float]]]

    @staticmethod
    def _get_callable_name(handler: Callable) -> str:
        # get name of the callable handler
        return getattr(handler, "__qualname__", handler.__class__.__name__)

    def _create_wrapped_handler(self, handler: Callable,
                                event: EventEnum) -> Callable:
        """Return a wrapper around *handler* that records its run time."""
        @functools.wraps(handler)
        def _timeit_handler(*args: Any, **kwargs: Any) -> None:
            self._event_handlers_timer.reset()
            handler(*args, **kwargs)
            t = self._event_handlers_timer.value()
            hname = self._get_callable_name(handler)
            # filter profiled time if the handler was attached to event with event filter
            if not hasattr(handler,
                           "_parent") or t >= self.EVENT_FILTER_THESHOLD_TIME:
                self.event_handlers_times[event][hname].append(t)

        # required to revert back to original handler after profiling
        setattr(_timeit_handler, "_profiler_original", handler)
        return _timeit_handler

    def _timeit_processing(self) -> None:
        # handler used for profiling processing times
        t = self._processing_timer.value()
        self.processing_times.append(t)

    def _timeit_dataflow(self) -> None:
        # handler used for profiling dataflow times
        t = self._dataflow_timer.value()
        self.dataflow_times.append(t)

    def _reset(self, event_handlers_names: Mapping[EventEnum,
                                                   List[str]]) -> None:
        # reset the variables used for profiling
        self.dataflow_times = []
        self.processing_times = []

        # One empty time-list per (event, handler-name) pair.
        self.event_handlers_times = {
            e: {h: []
                for h in event_handlers_names[e]}
            for e in event_handlers_names
        }

    @staticmethod
    def _is_internal_handler(handler: Callable) -> bool:
        # checks whether the handler is internal
        return any(n in repr(handler)
                   for n in ["HandlersTimeProfiler.", "Timer."])

    def _detach_profiler_handlers(self, engine: Engine) -> None:
        # reverts handlers to original handlers
        for e in engine._event_handlers:
            for i, (func, args,
                    kwargs) in enumerate(engine._event_handlers[e]):
                if hasattr(func, "_profiler_original"):
                    engine._event_handlers[e][i] = (func._profiler_original,
                                                    args, kwargs)

    def _as_first_started(self, engine: Engine) -> None:
        # wraps original handlers for profiling

        # Snapshot user handler names per event (internal handlers excluded)
        # so _reset can allocate per-handler storage.
        self.event_handlers_names = {
            e: [
                self._get_callable_name(h)
                for (h, _, _) in engine._event_handlers[e]
                if not self._is_internal_handler(h)
            ]
            for e in engine._allowed_events
        }

        self._reset(self.event_handlers_names)

        # Replace each user handler in place with its timing wrapper.
        for e in engine._allowed_events:
            for i, (func, args,
                    kwargs) in enumerate(engine._event_handlers[e]):
                if not self._is_internal_handler(func):
                    engine._event_handlers[e][i] = (
                        self._create_wrapped_handler(func, e), args, kwargs)

        # processing timer
        engine.add_event_handler(Events.ITERATION_STARTED,
                                 self._processing_timer.reset)
        # insert(0, ...) so the reading happens before user handlers run.
        engine._event_handlers[Events.ITERATION_COMPLETED].insert(
            0, (self._timeit_processing, (), {}))

        # dataflow timer
        engine.add_event_handler(Events.GET_BATCH_STARTED,
                                 self._dataflow_timer.reset)
        engine._event_handlers[Events.GET_BATCH_COMPLETED].insert(
            0, (self._timeit_dataflow, (), {}))

        # revert back the wrapped handlers with original handlers at the end
        engine.add_event_handler(Events.COMPLETED,
                                 self._detach_profiler_handlers)

    def attach(self, engine: Engine) -> None:
        """Attach HandlersTimeProfiler to the given engine.

        Args:
            engine: the instance of Engine to attach
        """
        if not isinstance(engine, Engine):
            raise TypeError(
                f"Argument engine should be ignite.engine.Engine, but given {type(engine)}"
            )

        # Idempotent: only installs the bootstrap handler once, and first.
        if not engine.has_event_handler(self._as_first_started):
            engine._event_handlers[Events.STARTED].insert(
                0, (self._as_first_started, (engine, ), {}))

    def get_results(self) -> List[List[Union[str, float]]]:
        """
        Method to fetch the aggregated profiler results after the engine is run

        .. code-block:: python

            results = profiler.get_results()

        """
        # Grand total over every (event, handler) time list.
        total_eh_time = sum([
            sum(self.event_handlers_times[e][h])
            for e in self.event_handlers_times
            for h in self.event_handlers_times[e]
        ])
        total_eh_time = round(
            float(total_eh_time),
            5,
        )

        def compute_basic_stats(
            times: Union[Sequence, torch.Tensor]
        ) -> List[Union[str, float, Tuple[Union[str, float], Union[str,
                                                                   float]]]]:
            # Stats over non-zero samples; "None"/"not triggered" strings
            # stand in for stats that cannot be computed.
            data = torch.as_tensor(times, dtype=torch.float32)
            # compute on non-zero data:
            data = data[data > 0]
            total = round(torch.sum(data).item(), 5) if len(
                data) > 0 else "not triggered"  # type: Union[str, float]
            min_index = ("None", "None"
                         )  # type: Tuple[Union[str, float], Union[str, float]]
            max_index = ("None", "None"
                         )  # type: Tuple[Union[str, float], Union[str, float]]
            mean = "None"  # type: Union[str, float]
            std = "None"  # type: Union[str, float]
            if len(data) > 0:
                min_index = (round(torch.min(data).item(),
                                   5), torch.argmin(data).item())
                max_index = (round(torch.max(data).item(),
                                   5), torch.argmax(data).item())
                mean = round(torch.mean(data).item(), 5)
                # std needs at least two samples.
                if len(data) > 1:
                    std = round(torch.std(data).item(), 5)
            return [total, min_index, max_index, mean, std]

        # One row per handler: [name, event, total, min/idx, max/idx, mean, std].
        event_handler_stats = [[
            h,
            getattr(e, "name", str(e)),
            *compute_basic_stats(
                torch.tensor(self.event_handlers_times[e][h],
                             dtype=torch.float32)),
        ] for e in self.event_handlers_times
                               for h in self.event_handlers_times[e]]
        # Trailing summary rows consumed positionally by print_results.
        event_handler_stats.append(
            ["Total", "", total_eh_time, "", "", "", ""])
        event_handler_stats.append([
            "Processing", "None", *compute_basic_stats(self.processing_times)
        ])
        event_handler_stats.append(
            ["Dataflow", "None", *compute_basic_stats(self.dataflow_times)])

        return event_handler_stats

    def write_results(self, output_path: str) -> None:
        """
        Method to store the unaggregated profiling results to a csv file

        Args:
            output_path: file output path containing a filename

        .. code-block:: python

            profiler.write_results('path_to_dir/awesome_filename.csv')

        Example output:

        .. code-block:: text

            -----------------------------------------------------------------
            # processing_stats dataflow_stats training.<locals>.log_elapsed_time (EPOCH_COMPLETED) ...
            1     0.00003         0.252387                          0.125676
            2     0.00029         0.252342                          0.125123

        """
        try:
            import pandas as pd
        except ImportError:
            raise RuntimeError("Need pandas to write results as files")

        processing_stats = torch.tensor(self.processing_times,
                                        dtype=torch.float32)
        dataflow_stats = torch.tensor(self.dataflow_times, dtype=torch.float32)

        # One column per handler, named "handler (EVENT)".
        cols = [processing_stats, dataflow_stats]
        headers = ["processing_stats", "dataflow_stats"]
        for e in self.event_handlers_times:
            for h in self.event_handlers_times[e]:
                headers.append(f"{h} ({getattr(e, 'name', str(e))})")
                cols.append(
                    torch.tensor(self.event_handlers_times[e][h],
                                 dtype=torch.float32))
        # Determine maximum length
        max_len = max([x.numel() for x in cols])

        # 1-based row counter column.
        count_col = torch.arange(max_len, dtype=torch.float32) + 1
        cols.insert(0, count_col)
        headers.insert(0, "#")

        # pad all tensors to have same length
        cols = [
            torch.nn.functional.pad(x,
                                    pad=(0, max_len - x.numel()),
                                    mode="constant",
                                    value=0) for x in cols
        ]

        results_dump = torch.stack(
            cols,
            dim=1,
        ).numpy()

        results_df = pd.DataFrame(
            data=results_dump,
            columns=headers,
        )
        results_df.to_csv(output_path, index=False)

    @staticmethod
    def print_results(results: List[List[Union[str, float]]]) -> None:
        """
        Method to print the aggregated results from the profiler

        Args:
            results: the aggregated results from the profiler

        .. code-block:: python

            profiler.print_results(results)

        Example output:

        .. code-block:: text

            -----------------------------------------  -----------------------  -------------- ...
            Handler                                    Event Name                     Total(s)
            -----------------------------------------  -----------------------  --------------
            run.<locals>.log_training_results          EPOCH_COMPLETED                19.43245
            run.<locals>.log_validation_results        EPOCH_COMPLETED                 2.55271
            run.<locals>.log_time                      EPOCH_COMPLETED                 0.00049
            run.<locals>.log_intermediate_results      EPOCH_COMPLETED                 0.00106
            run.<locals>.log_training_loss             ITERATION_COMPLETED               0.059
            run.<locals>.log_time                      COMPLETED                 not triggered
            -----------------------------------------  -----------------------  --------------
            Total                                                                     22.04571
            -----------------------------------------  -----------------------  --------------
            Processing took total 11.29543s [min/index: 0.00393s/1875, max/index: 0.00784s/0,
             mean: 0.00602s, std: 0.00034s]
            Dataflow took total 16.24365s [min/index: 0.00533s/1874, max/index: 0.01129s/937,
             mean: 0.00866s, std: 0.00113s]

        """
        # adopted implementation of torch.autograd.profiler.build_table
        # Column widths sized to the widest handler/event name plus padding.
        handler_column_width = max([len(item[0]) for item in results
                                    ]) + 4  # type: ignore[arg-type]
        event_column_width = max([len(item[1]) for item in results
                                  ]) + 4  # type: ignore[arg-type]

        DEFAULT_COLUMN_WIDTH = 14

        headers = [
            "Handler",
            "Event Name",
            "Total(s)",
            "Min(s)/IDX",
            "Max(s)/IDX",
            "Mean(s)",
            "Std(s)",
        ]

        # Have to use a list because nonlocal is Py3 only...
        SPACING_SIZE = 2
        row_format_lst = [""]
        header_sep_lst = [""]
        line_length_lst = [-SPACING_SIZE]

        def add_column(padding: int, text_dir: str = ">") -> None:
            # Append one fixed-width column to the row format and separator.
            row_format_lst[0] += "{: " + text_dir + str(padding) + "}" + (
                " " * SPACING_SIZE)
            header_sep_lst[0] += "-" * padding + (" " * SPACING_SIZE)
            line_length_lst[0] += padding + SPACING_SIZE

        add_column(handler_column_width, text_dir="<")
        add_column(event_column_width, text_dir="<")
        for _ in headers[2:]:
            add_column(DEFAULT_COLUMN_WIDTH)

        row_format = row_format_lst[0]
        header_sep = header_sep_lst[0]

        result = []

        def append(s: str) -> None:
            result.append(s)
            result.append("\n")

        result.append("\n")
        append(header_sep)
        append(row_format.format(*headers))
        append(header_sep)

        # Last three rows are the Total/Processing/Dataflow summaries
        # appended by get_results(); everything before is per-handler.
        for row in results[:-3]:
            # format min/idx and max/idx
            row[3] = "{}/{}".format(*row[3])  # type: ignore[misc]
            row[4] = "{}/{}".format(*row[4])  # type: ignore[misc]

            append(row_format.format(*row))

        append(header_sep)
        # print total handlers time row
        append(row_format.format(*results[-3]))
        append(header_sep)

        summary_format = "{} took total {}s [min/index: {}, max/index: {}, mean: {}s, std: {}s]"
        for row in results[-2:]:
            row[3] = "{}s/{}".format(*row[3])  # type: ignore[misc]
            row[4] = "{}s/{}".format(*row[4])  # type: ignore[misc]
            # Drop the placeholder event-name column before formatting.
            del row[1]
            append(summary_format.format(*row))
        print("".join(result))
Example no. 5
0
class BasicTimeProfiler:
    """
    BasicTimeProfiler can be used to profile the handlers,
    events, data loading and data processing times.

    Examples:

    .. code-block:: python

        #
        # Create an object of the profiler and attach an engine to it
        #
        profiler = BasicTimeProfiler()
        trainer = Engine(train_updater)
        profiler.attach(trainer)

        @trainer.on(Events.EPOCH_COMPLETED)
        def log_intermediate_results():
            profiler.print_results(profiler.get_results())

        trainer.run(dataloader, max_epochs=3)

        profiler.write_results('path_to_dir/time_profiling.csv')

    """

    # Engine events that are never profiled (error/termination bookkeeping).
    events_to_ignore = [
        Events.EXCEPTION_RAISED,
        Events.TERMINATE,
        Events.TERMINATE_SINGLE_EPOCH,
        Events.DATALOADER_STOP_ITERATION,
    ]

    def __init__(self):
        """Create the internal timers; result buffers are allocated lazily
        in ``_reset`` once the run size is known (see ``_as_first_started``).
        """
        self._dataflow_timer = Timer()
        self._processing_timer = Timer()
        self._event_handlers_timer = Timer()

        # Populated by _reset(): per-iteration tensors / per-event dict.
        self.dataflow_times = None
        self.processing_times = None
        self.event_handlers_times = None

        # For each event in _events, the matching _fmethods entry is inserted
        # as the FIRST handler and the matching _lmethods entry is appended as
        # the LAST handler, so the user handlers of that event are bracketed
        # by a timer reset and a timer read.
        self._events = [
            Events.EPOCH_STARTED,
            Events.EPOCH_COMPLETED,
            Events.ITERATION_STARTED,
            Events.ITERATION_COMPLETED,
            Events.GET_BATCH_STARTED,
            Events.GET_BATCH_COMPLETED,
            Events.COMPLETED,
        ]
        self._fmethods = [
            self._as_first_epoch_started,
            self._as_first_epoch_completed,
            self._as_first_iter_started,
            self._as_first_iter_completed,
            self._as_first_get_batch_started,
            self._as_first_get_batch_completed,
            self._as_first_completed,
        ]
        self._lmethods = [
            self._as_last_epoch_started,
            self._as_last_epoch_completed,
            self._as_last_iter_started,
            self._as_last_iter_completed,
            self._as_last_get_batch_started,
            self._as_last_get_batch_completed,
            self._as_last_completed,
        ]

    def _reset(self, num_epochs, total_num_iters):
        """Allocate zeroed timing buffers sized for the upcoming run."""
        self.dataflow_times = torch.zeros(total_num_iters)
        self.processing_times = torch.zeros(total_num_iters)
        # One slot per occurrence of each event over the whole run.
        self.event_handlers_times = {
            Events.STARTED: torch.zeros(1),
            Events.COMPLETED: torch.zeros(1),
            Events.EPOCH_STARTED: torch.zeros(num_epochs),
            Events.EPOCH_COMPLETED: torch.zeros(num_epochs),
            Events.ITERATION_STARTED: torch.zeros(total_num_iters),
            Events.ITERATION_COMPLETED: torch.zeros(total_num_iters),
            Events.GET_BATCH_COMPLETED: torch.zeros(total_num_iters),
            Events.GET_BATCH_STARTED: torch.zeros(total_num_iters),
        }

    def _as_first_started(self, engine):
        """First STARTED handler: size the buffers, record user handler names
        and install the per-event bracketing handlers (see ``__init__``).
        """
        if hasattr(engine.state.dataloader, "__len__"):
            num_iters_per_epoch = len(engine.state.dataloader)
        else:
            # Dataloader has no length (e.g. an iterator) -> fall back to the
            # epoch_length the engine was configured with.
            num_iters_per_epoch = engine.state.epoch_length

        self.max_epochs = engine.state.max_epochs
        self.total_num_iters = self.max_epochs * num_iters_per_epoch
        self._reset(self.max_epochs, self.total_num_iters)

        # Snapshot the names of the user's handlers per event, excluding this
        # profiler's own internal handlers (filtered via repr).
        self.event_handlers_names = {
            e: [
                h.__qualname__
                if hasattr(h, "__qualname__") else h.__class__.__name__
                for (h, _, _) in engine._event_handlers[e]
                if "BasicTimeProfiler." not in repr(
                    h)  # avoid adding internal handlers into output
            ]
            for e in Events if e not in self.events_to_ignore
        }

        # Setup all other handlers:
        # NOTE(review): this mutates the engine's private _event_handlers
        # lists directly so the bracketing handlers land strictly first/last.
        engine._event_handlers[Events.STARTED].append(
            (self._as_last_started, (engine, ), {}))

        for e, m in zip(self._events, self._fmethods):
            engine._event_handlers[e].insert(0, (m, (engine, ), {}))

        for e, m in zip(self._events, self._lmethods):
            engine._event_handlers[e].append((m, (engine, ), {}))

        # Let's go
        self._event_handlers_timer.reset()

    def _as_last_started(self, engine):
        """Last STARTED handler: store total time of user STARTED handlers."""
        self.event_handlers_times[
            Events.STARTED][0] = self._event_handlers_timer.value()

    def _as_first_epoch_started(self, engine):
        """Start timing the user's EPOCH_STARTED handlers."""
        self._event_handlers_timer.reset()

    def _as_last_epoch_started(self, engine):
        """Record time spent in EPOCH_STARTED handlers for this epoch."""
        t = self._event_handlers_timer.value()
        e = engine.state.epoch - 1  # engine epochs are 1-based
        self.event_handlers_times[Events.EPOCH_STARTED][e] = t

    def _as_first_get_batch_started(self, engine):
        """Start timing both GET_BATCH_STARTED handlers and dataflow."""
        self._event_handlers_timer.reset()
        self._dataflow_timer.reset()

    def _as_last_get_batch_started(self, engine):
        """Record time spent in GET_BATCH_STARTED handlers this iteration."""
        t = self._event_handlers_timer.value()
        i = engine.state.iteration - 1  # engine iterations are 1-based
        self.event_handlers_times[Events.GET_BATCH_STARTED][i] = t

    def _as_first_get_batch_completed(self, engine):
        """Start timing the user's GET_BATCH_COMPLETED handlers."""
        self._event_handlers_timer.reset()

    def _as_last_get_batch_completed(self, engine):
        """Record GET_BATCH_COMPLETED handler time and the dataflow time
        (span from GET_BATCH_STARTED to here) for this iteration.
        """
        t = self._event_handlers_timer.value()
        i = engine.state.iteration - 1
        self.event_handlers_times[Events.GET_BATCH_COMPLETED][i] = t

        d = self._dataflow_timer.value()
        self.dataflow_times[i] = d

        self._dataflow_timer.reset()

    def _as_first_iter_started(self, engine):
        """Start timing the user's ITERATION_STARTED handlers."""
        self._event_handlers_timer.reset()

    def _as_last_iter_started(self, engine):
        """Record ITERATION_STARTED handler time; start the processing timer
        (the process_function runs between this event and ITERATION_COMPLETED).
        """
        t = self._event_handlers_timer.value()
        i = engine.state.iteration - 1
        self.event_handlers_times[Events.ITERATION_STARTED][i] = t

        self._processing_timer.reset()

    def _as_first_iter_completed(self, engine):
        """Record processing time for this iteration; start timing the user's
        ITERATION_COMPLETED handlers.
        """
        t = self._processing_timer.value()
        i = engine.state.iteration - 1
        self.processing_times[i] = t

        self._event_handlers_timer.reset()

    def _as_last_iter_completed(self, engine):
        """Record time spent in ITERATION_COMPLETED handlers."""
        t = self._event_handlers_timer.value()
        i = engine.state.iteration - 1
        self.event_handlers_times[Events.ITERATION_COMPLETED][i] = t

    def _as_first_epoch_completed(self, engine):
        """Start timing the user's EPOCH_COMPLETED handlers."""
        self._event_handlers_timer.reset()

    def _as_last_epoch_completed(self, engine):
        """Record time spent in EPOCH_COMPLETED handlers for this epoch."""
        t = self._event_handlers_timer.value()
        e = engine.state.epoch - 1
        self.event_handlers_times[Events.EPOCH_COMPLETED][e] = t

    def _as_first_completed(self, engine):
        """Start timing the user's COMPLETED handlers."""
        self._event_handlers_timer.reset()

    def _as_last_completed(self, engine):
        """Record COMPLETED handler time, then detach all bracketing handlers
        installed by ``_as_first_started`` so the engine is left clean.
        """
        self.event_handlers_times[
            Events.COMPLETED][0] = self._event_handlers_timer.value()

        # Remove added handlers:
        engine.remove_event_handler(self._as_last_started, Events.STARTED)

        for e, m in zip(self._events, self._fmethods):
            engine.remove_event_handler(m, e)

        for e, m in zip(self._events, self._lmethods):
            engine.remove_event_handler(m, e)

    def attach(self, engine):
        """Attach the profiler to `engine` by installing ``_as_first_started``
        as the very first STARTED handler. Idempotent.

        Raises:
            TypeError: if `engine` is not an ignite Engine.
        """
        if not isinstance(engine, Engine):
            raise TypeError("Argument engine should be ignite.engine.Engine, "
                            "but given {}".format(type(engine)))

        if not engine.has_event_handler(self._as_first_started):
            # Insert at position 0 so sizing/setup runs before user handlers.
            engine._event_handlers[Events.STARTED].insert(
                0, (self._as_first_started, (engine, ), {}))

    @staticmethod
    def _compute_basic_stats(data):
        """Return OrderedDict of total/min/max/mean/std over non-zero entries.

        Zero slots are treated as "not measured". min/max/mean/std are only
        included with at least two samples (std of one sample is undefined).
        """
        # compute on non-zero data:
        data = data[data > 0]
        out = [
            ("total",
             torch.sum(data).item() if len(data) > 0 else "not yet triggered")
        ]
        if len(data) > 1:
            out += [
                ("min/index", (torch.min(data).item(),
                               torch.argmin(data).item())),
                ("max/index", (torch.max(data).item(),
                               torch.argmax(data).item())),
                ("mean", torch.mean(data).item()),
                ("std", torch.std(data).item()),
            ]
        return OrderedDict(out)

    def get_results(self):
        """
        Method to fetch the aggregated profiler results after the engine is run

        .. code-block:: python

            results = profiler.get_results()

        """
        # Sum handler time over every profiled event.
        total_eh_time = sum([(self.event_handlers_times[e]).sum()
                             for e in Events
                             if e not in self.events_to_ignore])

        return OrderedDict([
            ("processing_stats",
             self._compute_basic_stats(self.processing_times)),
            ("dataflow_stats", self._compute_basic_stats(self.dataflow_times)),
            (
                "event_handlers_stats",
                # Keys like "EPOCH_STARTED" (event name with dots replaced).
                dict([(str(e.name).replace(".", "_"),
                       self._compute_basic_stats(self.event_handlers_times[e]))
                      for e in Events if e not in self.events_to_ignore] +
                     [("total_time", total_eh_time)]),
            ),
            (
                "event_handlers_names",
                {
                    str(e.name).replace(".", "_") + "_names": v
                    for e, v in self.event_handlers_names.items()
                },
            ),
        ])

    def write_results(self, output_path):
        """
        Method to store the unaggregated profiling results to a csv file

        .. code-block:: python

            profiler.write_results('path_to_dir/awesome_filename.csv')

        Example output:

        .. code-block:: text

            -----------------------------------------------------------------
            epoch iteration processing_stats dataflow_stats Event_STARTED ...
            1.0     1.0        0.00003         0.252387        0.125676
            1.0     2.0        0.00029         0.252342        0.125123

        """
        try:
            import pandas as pd
        except ImportError:
            # pandas is an optional dependency: degrade to a no-op with a hint.
            print("Need pandas to write results as files")
            return

        iters_per_epoch = self.total_num_iters // self.max_epochs

        # Broadcast per-run / per-epoch timings to one row per iteration so
        # every column of the CSV has total_num_iters entries.
        epochs = torch.arange(
            self.max_epochs,
            dtype=torch.float32).repeat_interleave(iters_per_epoch) + 1
        iterations = torch.arange(self.total_num_iters,
                                  dtype=torch.float32) + 1
        processing_stats = self.processing_times
        dataflow_stats = self.dataflow_times

        event_started = self.event_handlers_times[
            Events.STARTED].repeat_interleave(self.total_num_iters)
        event_completed = self.event_handlers_times[
            Events.COMPLETED].repeat_interleave(self.total_num_iters)
        event_epoch_started = self.event_handlers_times[
            Events.EPOCH_STARTED].repeat_interleave(iters_per_epoch)
        event_epoch_completed = self.event_handlers_times[
            Events.EPOCH_COMPLETED].repeat_interleave(iters_per_epoch)

        event_iter_started = self.event_handlers_times[
            Events.ITERATION_STARTED]
        event_iter_completed = self.event_handlers_times[
            Events.ITERATION_COMPLETED]
        event_batch_started = self.event_handlers_times[
            Events.GET_BATCH_STARTED]
        event_batch_completed = self.event_handlers_times[
            Events.GET_BATCH_COMPLETED]

        # Columns stacked side by side -> (total_num_iters, 12) array.
        results_dump = torch.stack(
            [
                epochs,
                iterations,
                processing_stats,
                dataflow_stats,
                event_started,
                event_completed,
                event_epoch_started,
                event_epoch_completed,
                event_iter_started,
                event_iter_completed,
                event_batch_started,
                event_batch_completed,
            ],
            dim=1,
        ).numpy()

        results_df = pd.DataFrame(
            data=results_dump,
            columns=[
                "epoch",
                "iteration",
                "processing_stats",
                "dataflow_stats",
                "Event_STARTED",
                "Event_COMPLETED",
                "Event_EPOCH_STARTED",
                "Event_EPOCH_COMPLETED",
                "Event_ITERATION_STARTED",
                "Event_ITERATION_COMPLETED",
                "Event_GET_BATCH_STARTED",
                "Event_GET_BATCH_COMPLETED",
            ],
        )
        results_df.to_csv(output_path, index=False)

    @staticmethod
    def print_results(results):
        """
        Method to print the aggregated results from the profiler

        .. code-block:: python

            profiler.print_results(results)

        Example output:

        .. code-block:: text

            --------------------------------------------
            - Time profiling results:
            --------------------------------------------
            Processing function time stats (in seconds):
                min/index: (1.3081999895803165e-05, 1)
                max/index: (1.433099987480091e-05, 0)
                mean: 1.3706499885302037e-05
                std: 8.831763693706307e-07
                total: 2.7412999770604074e-05

            Dataflow time stats (in seconds):
                min/index: (5.199999941396527e-05, 0)
                max/index: (0.00010925399692496285, 1)
                mean: 8.062699635047466e-05
                std: 4.048469054396264e-05
                total: 0.0001612539927009493


            Time stats of event handlers (in seconds):
            - Total time spent:
                1.0080009698867798

            - Events.STARTED:
                total: 0.1256754994392395

            Handlers names:
            ['delay_start']
            --------------------------------------------
        """
        def odict_to_str(d):
            # Render an OrderedDict of stats as "\tkey: value" lines.
            out = ""
            for k, v in d.items():
                out += "\t{}: {}\n".format(k, v)
            return out

        # Pre-render the per-event stat dicts; names lists pass through as-is.
        others = {
            k: odict_to_str(v) if isinstance(v, OrderedDict) else v
            for k, v in results["event_handlers_stats"].items()
        }

        others.update(results["event_handlers_names"])

        output_message = """
--------------------------------------------
- Time profiling results:
--------------------------------------------

Processing function time stats (in seconds):
{processing_stats}

Dataflow time stats (in seconds):
{dataflow_stats}

Time stats of event handlers (in seconds):
- Total time spent:
\t{total_time}

- Events.STARTED:
{STARTED}
Handlers names:
{STARTED_names}

- Events.EPOCH_STARTED:
{EPOCH_STARTED}
Handlers names:
{EPOCH_STARTED_names}

- Events.ITERATION_STARTED:
{ITERATION_STARTED}
Handlers names:
{ITERATION_STARTED_names}

- Events.ITERATION_COMPLETED:
{ITERATION_COMPLETED}
Handlers names:
{ITERATION_COMPLETED_names}

- Events.EPOCH_COMPLETED:
{EPOCH_COMPLETED}
Handlers names:
{EPOCH_COMPLETED_names}

- Events.COMPLETED:
{COMPLETED}
Handlers names:
{COMPLETED_names}

""".format(
            processing_stats=odict_to_str(results["processing_stats"]),
            dataflow_stats=odict_to_str(results["dataflow_stats"]),
            **others,
        )
        print(output_message)
        return output_message
class BasicTimeProfiler(object):
    """
    BasicTimeProfiler can be used to profile the handlers,
    events, data loading and data processing times.

    Examples:

    .. code-block:: python

        #
        # Create an object of the profiler and attach an engine to it
        #
        profiler = BasicTimeProfiler()
        trainer = Engine(train_updater)
        profiler.attach(trainer)
        trainer.run(dataloader, max_epochs=3)

    """
    def __init__(self):
        """Create the internal timers; result buffers are allocated lazily in
        ``_reset`` once the run size is known (see ``_as_first_started``).
        """
        self._dataflow_timer = Timer()
        self._processing_timer = Timer()
        self._event_handlers_timer = Timer()

        # Populated by _reset(): per-iteration tensors / per-event dict.
        self.dataflow_times = None
        self.processing_times = None
        self.event_handlers_times = None

        # For each event in _events, the matching _fmethods entry is inserted
        # as the FIRST handler and the matching _lmethods entry is appended as
        # the LAST handler, bracketing the user's handlers with a timer
        # reset/read pair.
        self._events = [
            Events.EPOCH_STARTED, Events.EPOCH_COMPLETED,
            Events.ITERATION_STARTED, Events.ITERATION_COMPLETED,
            Events.GET_BATCH_STARTED, Events.GET_BATCH_COMPLETED,
            Events.COMPLETED
        ]
        self._fmethods = [
            self._as_first_epoch_started, self._as_first_epoch_completed,
            self._as_first_iter_started, self._as_first_iter_completed,
            self._as_first_get_batch_started,
            self._as_first_get_batch_completed, self._as_first_completed
        ]
        self._lmethods = [
            self._as_last_epoch_started, self._as_last_epoch_completed,
            self._as_last_iter_started, self._as_last_iter_completed,
            self._as_last_get_batch_started, self._as_last_get_batch_completed,
            self._as_last_completed
        ]

    def _reset(self, num_epochs, total_num_iters):
        """Allocate zeroed timing buffers sized for the upcoming run."""
        self.dataflow_times = torch.zeros(total_num_iters)
        self.processing_times = torch.zeros(total_num_iters)
        # One slot per occurrence of each event over the whole run.
        self.event_handlers_times = {
            Events.STARTED: torch.zeros(1),
            Events.COMPLETED: torch.zeros(1),
            Events.EPOCH_STARTED: torch.zeros(num_epochs),
            Events.EPOCH_COMPLETED: torch.zeros(num_epochs),
            Events.ITERATION_STARTED: torch.zeros(total_num_iters),
            Events.ITERATION_COMPLETED: torch.zeros(total_num_iters),
            Events.GET_BATCH_COMPLETED: torch.zeros(total_num_iters),
            Events.GET_BATCH_STARTED: torch.zeros(total_num_iters)
        }

    def _as_first_started(self, engine):
        """First STARTED handler: size the buffers, record handler names and
        install the per-event bracketing handlers (see ``__init__``).

        NOTE(review): unlike newer revisions, this version does not filter the
        profiler's own handlers out of ``event_handlers_names`` — the example
        output in ``print_results`` shows 'BasicTimeProfiler._as_first_started'
        listed among the handler names.
        """
        if hasattr(engine.state.dataloader, "__len__"):
            num_iters_per_epoch = len(engine.state.dataloader)
        else:
            # Dataloader has no length (e.g. an iterator) -> fall back to the
            # epoch_length the engine was configured with.
            num_iters_per_epoch = engine.state.epoch_length

        self.max_epochs = engine.state.max_epochs
        self.total_num_iters = self.max_epochs * num_iters_per_epoch
        self._reset(self.max_epochs, self.total_num_iters)

        self.event_handlers_names = {
            e: [
                h.__qualname__
                if hasattr(h, "__qualname__") else h.__class__.__name__
                for (h, _, _) in engine._event_handlers[e]
            ]
            for e in Events if e != Events.EXCEPTION_RAISED
        }

        # Setup all other handlers:
        # NOTE(review): handlers are registered with empty args `()`, which
        # assumes the engine passes itself to handlers when firing events —
        # an older ignite handler-call convention; verify against the ignite
        # version in use.
        engine._event_handlers[Events.STARTED].append(
            (self._as_last_started, (), {}))

        for e, m in zip(self._events, self._fmethods):
            engine._event_handlers[e].insert(0, (m, (), {}))

        for e, m in zip(self._events, self._lmethods):
            engine._event_handlers[e].append((m, (), {}))

        # Let's go
        self._event_handlers_timer.reset()

    def _as_last_started(self, engine):
        """Last STARTED handler: store total time of user STARTED handlers."""
        self.event_handlers_times[
            Events.STARTED][0] = self._event_handlers_timer.value()

    def _as_first_epoch_started(self, engine):
        """Start timing the user's EPOCH_STARTED handlers."""
        self._event_handlers_timer.reset()

    def _as_last_epoch_started(self, engine):
        """Record time spent in EPOCH_STARTED handlers for this epoch."""
        t = self._event_handlers_timer.value()
        e = engine.state.epoch - 1  # engine epochs are 1-based
        self.event_handlers_times[Events.EPOCH_STARTED][e] = t

    def _as_first_get_batch_started(self, engine):
        """Start timing both GET_BATCH_STARTED handlers and dataflow."""
        self._event_handlers_timer.reset()
        self._dataflow_timer.reset()

    def _as_last_get_batch_started(self, engine):
        """Record time spent in GET_BATCH_STARTED handlers this iteration."""
        t = self._event_handlers_timer.value()
        i = engine.state.iteration - 1  # engine iterations are 1-based
        self.event_handlers_times[Events.GET_BATCH_STARTED][i] = t

    def _as_first_get_batch_completed(self, engine):
        """Start timing the user's GET_BATCH_COMPLETED handlers."""
        self._event_handlers_timer.reset()

    def _as_last_get_batch_completed(self, engine):
        """Record GET_BATCH_COMPLETED handler time and the dataflow time
        (span from GET_BATCH_STARTED to here) for this iteration.
        """
        t = self._event_handlers_timer.value()
        i = engine.state.iteration - 1
        self.event_handlers_times[Events.GET_BATCH_COMPLETED][i] = t

        d = self._dataflow_timer.value()
        self.dataflow_times[i] = d

        self._dataflow_timer.reset()

    def _as_first_iter_started(self, engine):
        """Start timing the user's ITERATION_STARTED handlers."""
        self._event_handlers_timer.reset()

    def _as_last_iter_started(self, engine):
        """Record ITERATION_STARTED handler time; start the processing timer
        (the process_function runs between this event and ITERATION_COMPLETED).
        """
        t = self._event_handlers_timer.value()
        i = engine.state.iteration - 1
        self.event_handlers_times[Events.ITERATION_STARTED][i] = t

        self._processing_timer.reset()

    def _as_first_iter_completed(self, engine):
        """Record processing time for this iteration; start timing the user's
        ITERATION_COMPLETED handlers.
        """
        t = self._processing_timer.value()
        i = engine.state.iteration - 1
        self.processing_times[i] = t

        self._event_handlers_timer.reset()

    def _as_last_iter_completed(self, engine):
        """Record time spent in ITERATION_COMPLETED handlers."""
        t = self._event_handlers_timer.value()
        i = engine.state.iteration - 1
        self.event_handlers_times[Events.ITERATION_COMPLETED][i] = t

    def _as_first_epoch_completed(self, engine):
        """Start timing the user's EPOCH_COMPLETED handlers."""
        self._event_handlers_timer.reset()

    def _as_last_epoch_completed(self, engine):
        """Record time spent in EPOCH_COMPLETED handlers for this epoch."""
        t = self._event_handlers_timer.value()
        e = engine.state.epoch - 1
        self.event_handlers_times[Events.EPOCH_COMPLETED][e] = t

    def _as_first_completed(self, engine):
        """Start timing the user's COMPLETED handlers."""
        self._event_handlers_timer.reset()

    def _as_last_completed(self, engine):
        """Record COMPLETED handler time, then detach all bracketing handlers
        installed by ``_as_first_started`` so the engine is left clean.
        """
        self.event_handlers_times[
            Events.COMPLETED][0] = self._event_handlers_timer.value()

        # Remove added handlers:
        engine.remove_event_handler(self._as_last_started, Events.STARTED)

        for e, m in zip(self._events, self._fmethods):
            engine.remove_event_handler(m, e)

        for e, m in zip(self._events, self._lmethods):
            engine.remove_event_handler(m, e)

    def attach(self, engine):
        """Attach the profiler to `engine` by installing ``_as_first_started``
        as the very first STARTED handler. Idempotent.

        Raises:
            TypeError: if `engine` is not an ignite Engine.
        """
        if not isinstance(engine, Engine):
            raise TypeError("Argument engine should be ignite.engine.Engine, "
                            "but given {}".format(type(engine)))

        if not engine.has_event_handler(self._as_first_started):
            # Insert at position 0 so sizing/setup runs before user handlers.
            engine._event_handlers[Events.STARTED]\
                .insert(0, (self._as_first_started, (), {}))

    @staticmethod
    def _compute_basic_stats(data):
        """Return OrderedDict of min/max/mean/std/total over all of `data`.

        NOTE(review): unlike newer revisions, this version does not filter out
        zero (unmeasured) slots, and std over a single sample is NaN — the
        example output in ``print_results`` shows 'std: nan'.
        """
        return OrderedDict([
            ('min/index', (torch.min(data).item(), torch.argmin(data).item())),
            ('max/index', (torch.max(data).item(), torch.argmax(data).item())),
            ('mean', torch.mean(data).item()), ('std', torch.std(data).item()),
            ('total', torch.sum(data).item())
        ])

    def get_results(self):
        """
        Method to fetch the aggregated profiler results after the engine is run

        .. code-block:: python

            results = profiler.get_results()

        """
        events_to_ignore = [Events.EXCEPTION_RAISED]
        # Sum handler time over every profiled event.
        total_eh_time = sum([
            sum(self.event_handlers_times[e]) for e in Events
            if e not in events_to_ignore
        ])
        return OrderedDict([
            ("processing_stats",
             self._compute_basic_stats(self.processing_times)),
            ("dataflow_stats", self._compute_basic_stats(self.dataflow_times)),
            ("event_handlers_stats",
             # Keys like "Events_STARTED" (str(event) with dots replaced).
             dict([(str(e).replace(".", "_"),
                    self._compute_basic_stats(self.event_handlers_times[e]))
                   for e in Events if e not in events_to_ignore] +
                  [("total_time", total_eh_time)])),
            ("event_handlers_names", {
                str(e).replace(".", "_") + "_names": v
                for e, v in self.event_handlers_names.items()
            })
        ])

    def write_results(self, output_path):
        """
        Method to store the unaggregated profiling results to a csv file

        .. code-block:: python

            profiler.write_results('path_to_dir/awesome_filename.csv')

        Example output:

        .. code-block:: text

            -----------------------------------------------------------------
            epoch iteration processing_stats dataflow_stats Event_STARTED ...
            1.0     1.0        0.00003         0.252387        0.125676
            1.0     2.0        0.00029         0.252342        0.125123

        """
        try:
            import pandas as pd
        except ImportError:
            # pandas is an optional dependency: degrade to a no-op with a hint.
            print("Need pandas to write results as files")
            return

        iters_per_epoch = self.total_num_iters // self.max_epochs

        # Broadcast per-run / per-epoch timings to one row per iteration so
        # every column of the CSV has total_num_iters entries.
        epochs = torch.arange(self.max_epochs, dtype=torch.float32)\
            .repeat_interleave(iters_per_epoch) + 1
        iterations = torch.arange(self.total_num_iters,
                                  dtype=torch.float32) + 1
        processing_stats = self.processing_times
        dataflow_stats = self.dataflow_times

        event_started = self.event_handlers_times[Events.STARTED]\
            .repeat_interleave(self.total_num_iters)
        event_completed = self.event_handlers_times[Events.COMPLETED]\
            .repeat_interleave(self.total_num_iters)
        event_epoch_started = self.event_handlers_times[Events.EPOCH_STARTED]\
            .repeat_interleave(iters_per_epoch)
        event_epoch_completed = self.event_handlers_times[Events.EPOCH_COMPLETED]\
            .repeat_interleave(iters_per_epoch)

        event_iter_started = self.event_handlers_times[
            Events.ITERATION_STARTED]
        event_iter_completed = self.event_handlers_times[
            Events.ITERATION_COMPLETED]
        event_batch_started = self.event_handlers_times[
            Events.GET_BATCH_STARTED]
        event_batch_completed = self.event_handlers_times[
            Events.GET_BATCH_COMPLETED]

        # Columns stacked side by side -> (total_num_iters, 12) array.
        results_dump = torch.stack([
            epochs, iterations, processing_stats, dataflow_stats,
            event_started, event_completed, event_epoch_started,
            event_epoch_completed, event_iter_started, event_iter_completed,
            event_batch_started, event_batch_completed
        ],
                                   dim=1).numpy()

        results_df = pd.DataFrame(
            data=results_dump,
            columns=[
                'epoch', 'iteration', 'processing_stats', 'dataflow_stats',
                'Event_STARTED', 'Event_COMPLETED', 'Event_EPOCH_STARTED',
                'Event_EPOCH_COMPLETED', 'Event_ITERATION_STARTED',
                'Event_ITERATION_COMPLETED', 'Event_GET_BATCH_STARTED',
                'Event_GET_BATCH_COMPLETED'
            ])
        results_df.to_csv(output_path, index=False)

    @staticmethod
    def print_results(results):
        """
        Method to print the aggregated results from the profiler

        .. code-block:: python

            profiler.print_results(results)

        Example output:

        .. code-block:: text

            --------------------------------------------
            - Time profiling results:
            --------------------------------------------
            Processing function time stats (in seconds):
                min/index: (2.9754010029137135e-05, 0)
                max/index: (2.9754010029137135e-05, 0)
                mean: 2.9754010029137135e-05
                std: nan
                total: 2.9754010029137135e-05

            Dataflow time stats (in seconds):
                min/index: (0.2523871660232544, 0)
                max/index: (0.2523871660232544, 0)
                mean: 0.2523871660232544
                std: nan
                total: 0.2523871660232544

            Time stats of event handlers (in seconds):
            - Total time spent:
                1.0080009698867798

            - Events.STARTED:
                min/index: (0.1256754994392395, 0)
                max/index: (0.1256754994392395, 0)
                mean: 0.1256754994392395
                std: nan
                total: 0.1256754994392395

            Handlers names:
            ['BasicTimeProfiler._as_first_started', 'delay_start']
            --------------------------------------------
        """
        def odict_to_str(d):
            # Render an OrderedDict of stats as "\tkey: value" lines.
            out = ""
            for k, v in d.items():
                out += "\t{}: {}\n".format(k, v)
            return out

        # Pre-render the per-event stat dicts; names lists pass through as-is.
        others = {
            k: odict_to_str(v) if isinstance(v, OrderedDict) else v
            for k, v in results['event_handlers_stats'].items()
        }

        others.update(results['event_handlers_names'])

        output_message = """
--------------------------------------------
- Time profiling results:
--------------------------------------------

Processing function time stats (in seconds):
{processing_stats}

Dataflow time stats (in seconds):
{dataflow_stats}

Time stats of event handlers (in seconds):
- Total time spent:
\t{total_time}

- Events.STARTED:
{Events_STARTED}
Handlers names:
{Events_STARTED_names}

- Events.EPOCH_STARTED:
{Events_EPOCH_STARTED}
Handlers names:
{Events_EPOCH_STARTED_names}

- Events.ITERATION_STARTED:
{Events_ITERATION_STARTED}
Handlers names:
{Events_ITERATION_STARTED_names}

- Events.ITERATION_COMPLETED:
{Events_ITERATION_COMPLETED}
Handlers names:
{Events_ITERATION_COMPLETED_names}

- Events.EPOCH_COMPLETED:
{Events_EPOCH_COMPLETED}
Handlers names:
{Events_EPOCH_COMPLETED_names}

- Events.COMPLETED:
{Events_COMPLETED}
Handlers names:
{Events_COMPLETED_names}

""".format(processing_stats=odict_to_str(results['processing_stats']),
           dataflow_stats=odict_to_str(results['dataflow_stats']),
           **others)
        print(output_message)
        return output_message
Ejemplo n.º 7
0
class GANTrainer:
    """Convenience training harness for GAN variants built on ignite.

    Wraps an engine produced by one of the ``create_*_trainer`` factories and
    adds LR scheduling, iteration counting, periodic metric logging and
    save/load of the full trainer state.
    """

    def __init__(self,
                 G,
                 D,
                 criterionG,
                 criterionD,
                 optimizerG,
                 optimizerD,
                 lr_schedulerG=None,
                 lr_schedulerD=None,
                 make_latent=None,
                 metrics=None,
                 save_path=".",
                 name="GAN",
                 gan_type='gan'):
        """
        :param G: generator network
        :param D: discriminator network
        :param criterionG: loss used for the generator
        :param criterionD: loss used for the discriminator
        :param optimizerG: optimizer updating G
        :param optimizerD: optimizer updating D
        :param lr_schedulerG: optional LR scheduler for G, stepped per iteration
        :param lr_schedulerD: optional LR scheduler for D, stepped per iteration
        :param make_latent: callable producing latent inputs for the factory
        :param metrics: optional dict of ignite metrics to track
        :param save_path: root directory under which checkpoints are stored
        :param name: trainer name, used in checkpoint file names
        :param gan_type: one of 'gan', 'acgan', 'cgan', 'infogan'
        """
        self.G = G
        self.D = D
        self.criterionG = criterionG
        self.criterionD = criterionD
        self.optimizerG = optimizerG
        self.optimizerD = optimizerD
        self.lr_schedulerG = lr_schedulerG
        self.lr_schedulerD = lr_schedulerD
        self.make_latent = make_latent
        self.metrics = metrics or {}
        self.name = name
        root = Path(save_path).expanduser().absolute()
        # Checkpoints live under <save_path>/gan_trainer/<name>/
        self.save_path = root / 'gan_trainer' / self.name

        self.metric_history = defaultdict(list)
        self.device = 'cuda' if CUDA else 'cpu'
        self._timer = Timer()
        self._iterations = 0

        self.G.to(self.device)
        self.D.to(self.device)

        # Dispatch table: supported GAN flavour -> engine factory.
        create_fns = {
            'gan': create_gan_trainer,
            'acgan': create_acgan_trainer,
            'cgan': create_cgan_trainer,
            'infogan': create_infogan_trainer,
        }
        if gan_type not in create_fns:
            # Raise instead of assert so validation survives `python -O`.
            raise ValueError("gan_type must be one of %s, got %r" %
                             (sorted(create_fns), gan_type))
        self.create_fn = create_fns[gan_type]

    def _lr_scheduler_step(self, engine):
        """Step both LR schedulers (when configured), once per iteration."""
        if self.lr_schedulerG:
            self.lr_schedulerG.step()
        if self.lr_schedulerD:
            self.lr_schedulerD.step()

    def _attach_timer(self, engine):
        """Reset wall-clock bookkeeping at the start of a `fit` call."""
        self._trained_time = 0
        self._timer.reset()

    def _increment_iteration(self, engine):
        """Count iterations across `fit` calls (persisted in checkpoints)."""
        self._iterations += 1

    def _log_results(self, engine, log_interval, max_iter):
        """Every `log_interval` iterations: print progress/ETA and move the
        attached metrics' current values into `metric_history`."""
        if self._iterations % log_interval != 0:
            return
        i = engine.state.iteration
        elapsed = self._timer.value()
        self._timer.reset()
        self._trained_time += elapsed
        trained_time = self._trained_time
        # ETA extrapolates from the average time per engine iteration.
        eta_seconds = int((trained_time / i) * (max_iter - i))
        it_fmt = "%" + str(len(str(max_iter))) + "d"
        print(("Iter: " + it_fmt + ", Cost: %.2fs, Eta: %s") %
              (self._iterations, elapsed,
               datetime.timedelta(seconds=eta_seconds)))

        for name, metric in self.metrics.items():
            metric.completed(engine, name)
            metric.reset()

        msg = ""
        for name, val in engine.state.metrics.items():
            msg += "%s: %.4f\t" % (name, val)
            self.metric_history[name].append(val)
        print(msg)

    def fit(self, it, max_iter, log_interval=100, callbacks=()):
        """Run training for `max_iter` iterations over iterable `it`.

        :param it: data iterable fed to the engine
        :param max_iter: total number of iterations to run
        :param log_interval: iterations between progress printouts
        :param callbacks: extra callables invoked after each iteration
        :return: the accumulated `metric_history` dict
        """
        engine = self.create_fn(self.G, self.D, self.criterionG,
                                self.criterionD, self.optimizerG,
                                self.optimizerD, self.make_latent,
                                self.metrics, self.device)
        self._attach_timer(engine)

        engine.add_event_handler(Events.ITERATION_STARTED,
                                 self._lr_scheduler_step)

        engine.add_event_handler(Events.ITERATION_COMPLETED,
                                 self._increment_iteration)
        engine.add_event_handler(Events.ITERATION_COMPLETED, self._log_results,
                                 log_interval, max_iter)

        for callback in callbacks:
            engine.add_event_handler(Events.ITERATION_COMPLETED,
                                     _trainer_callback_wrap(callback), self)

        # Run
        engine.run(it, max_iter)

        # Return history
        return self.metric_history

    def state_dict(self):
        """Collect everything needed to resume training into one dict."""
        s = {
            "iterations": self.iterations(),
            "G": self.G.state_dict(),
            "D": self.D.state_dict(),
            "optimizerG": self.optimizerG.state_dict(),
            "optimizerD": self.optimizerD.state_dict(),
            "criterionG": self.criterionG.state_dict(),
            "criterionD": self.criterionD.state_dict(),
            "lr_schedulerG": None,
            "lr_schedulerD": None,
            "metric_history": self.metric_history,
        }
        if self.lr_schedulerG:
            s["lr_schedulerG"] = self.lr_schedulerG.state_dict()
        if self.lr_schedulerD:
            s["lr_schedulerD"] = self.lr_schedulerD.state_dict()
        return s

    def load_state_dict(self, state_dict):
        """Restore all trainer components from a dict made by `state_dict`."""
        iterations, G, D, optimizerG, optimizerD, criterionG, criterionD, lr_schedulerG, lr_schedulerD, metric_history = get(
            [
                "iterations", "G", "D", "optimizerG", "optimizerD",
                "criterionG", "criterionD", "lr_schedulerG", "lr_schedulerD",
                "metric_history"
            ], state_dict)
        self._iterations = iterations
        self.G.load_state_dict(G)
        self.D.load_state_dict(D)
        self.optimizerG.load_state_dict(optimizerG)
        self.optimizerD.load_state_dict(optimizerD)
        self.criterionG.load_state_dict(criterionG)
        self.criterionD.load_state_dict(criterionD)
        # Schedulers are restored only when configured on both sides.
        if self.lr_schedulerG and lr_schedulerG:
            self.lr_schedulerG.load_state_dict(lr_schedulerG)
        if self.lr_schedulerD and lr_schedulerD:
            self.lr_schedulerD.load_state_dict(lr_schedulerD)
        self.metric_history = metric_history

    def save(self, remove_prev=True):
        """Save a checkpoint named after the current iteration count.

        :param remove_prev: when True, delete the most recent existing
            checkpoint if it belongs to an earlier iteration.
        """
        d = self.save_path
        d.mkdir(parents=True, exist_ok=True)

        if remove_prev:
            pattern = "%s_trainer*.pth" % self.name
            saves = list(d.glob(pattern))
            if len(saves) != 0:
                fp = max(saves, key=lambda f: f.stat().st_mtime)
                p = "%s_trainer_(?P<iters>[0-9]+).pth" % self.name
                m = re.match(p, fp.name)
                # The glob can match names without an iteration suffix;
                # guard against a failed regex parse instead of crashing
                # on `None.group(...)`.
                if m is not None and self.iterations() > int(m.group('iters')):
                    fp.unlink()

        filename = "%s_trainer_%d.pth" % (self.name, self.iterations())
        fp = d / filename
        torch.save(self.state_dict(), fp)
        print("Save trainer as %s" % fp)

    def load(self):
        """Load the most recently modified checkpoint for this trainer.

        :raises FileNotFoundError: when no matching checkpoint exists.
        """
        d = self.save_path
        pattern = "%s_trainer*.pth" % self.name
        saves = list(d.glob(pattern))
        if len(saves) == 0:
            raise FileNotFoundError("No checkpoint to load for %s in %s" %
                                    (self.name, self.save_path))
        fp = max(saves, key=lambda f: f.stat().st_mtime)
        self.load_state_dict(torch.load(fp, map_location=self.device))
        print("Load trainer from %s" % fp)

    def iterations(self):
        """Return the total number of completed training iterations."""
        return self._iterations
class Trainer:
    """
    Class which sets up the training logic, which mainly involves defining callback handlers and attaching them to
    the training loop.
    """
    def __init__(self, model, config, evaluator, data_loader, tb_writer,
                 run_info, logger, checkpoint_dir):
        """
        Creates a new trainer object for training a model.
        :param model: model to train. Needs to inherit from the BaseModel class.
        :param config: dictionary containing the whole configuration of the experiment
        :param evaluator: Instance of the evaluator class, used to run evaluation on a specified schedule
        :param config: dictionary containing the whole configuration of the experiment
        :param data_loader: pytorch data loader providing the training data
        :param tb_writer: tensorboardX summary writer
        :param run_info: sacred run info for logging training progress
        :param logger: python logger object
        :param checkpoint_dir: directory path for storing checkpoints
        """
        self.run_info = run_info
        self.logger = logger
        self.data_loader = data_loader
        self.evaluator = evaluator
        # The engine drives the loop, calling self._step once per batch.
        self.engine = Engine(self._step)
        self.model = model
        self.config = config
        # Shortcut to the training sub-config (intervals, epoch count, ...).
        self.train_cfg = config['train']
        self.tb_writer = tb_writer

        self.pbar = ProgressBar(ascii=True, desc='* Epoch')
        # average=True: timer.value() reports the mean time per step()
        # instead of the accumulated total.
        self.timer = Timer(average=True)
        # Rolling "last" checkpoints; require_empty=False permits reusing a
        # non-empty checkpoint directory (e.g. when resuming a run).
        self.save_last_checkpoint_handler = ModelCheckpoint(
            checkpoint_dir,
            'last',
            save_interval=self.train_cfg['save_interval'],
            n_saved=self.train_cfg['save_n_last'],
            require_empty=False)

        self.add_handler()

    def run(self):
        """
        Start the training loop which will run until all epochs are complete
        :return:
        """
        self.engine.run(self.data_loader,
                        max_epochs=self.train_cfg['n_epochs'])

    def add_handler(self):
        """
        Adds all the callback handlers to the trainer engine. Should be called in the end of the init.
        :return:
        """
        # Learning rate decay: schedulers are stepped once per iteration.
        for lr_s in self.model.schedulers:
            self.engine.add_event_handler(Events.ITERATION_STARTED, lr_s)

        # Checkpoint saving.
        # NOTE(review): fires at EPOCH_STARTED, so the state saved is the one
        # from *before* the epoch runs, and save_interval is counted in
        # epochs — confirm this schedule is intended.
        self.engine.add_event_handler(Events.EPOCH_STARTED,
                                      self.save_last_checkpoint_handler,
                                      self.model.networks)

        # Progress bar: show a running average of every model metric.
        monitoring_metrics = self.model.metric_names
        for mm in monitoring_metrics:
            RunningAverage(output_transform=self._extract_loss(mm)).attach(
                self.engine, mm)
        self.pbar.attach(self.engine, metric_names=monitoring_metrics)

        # Timer: paused between batches, so it measures pure iteration time.
        self.timer.attach(self.engine,
                          start=Events.EPOCH_STARTED,
                          resume=Events.ITERATION_STARTED,
                          pause=Events.ITERATION_COMPLETED,
                          step=Events.ITERATION_COMPLETED)

        # Logging (scalars, images, scheduled evaluation, epoch timing).
        self.engine.add_event_handler(Events.ITERATION_COMPLETED,
                                      self._handle_log_train_results)
        self.engine.add_event_handler(Events.ITERATION_COMPLETED,
                                      self._handle_log_train_images)
        self.engine.add_event_handler(Events.ITERATION_COMPLETED,
                                      self._handle_run_evaluation)
        self.engine.add_event_handler(Events.EPOCH_COMPLETED,
                                      self._handle_print_times)

        # Exception handling (graceful save on Ctrl-C).
        self.engine.add_event_handler(Events.EXCEPTION_RAISED,
                                      self._handle_exception)

    def _step(self, engine, batch):
        """
        Definition of a single training step. This function gets automatically called by the engine every iteration.
        :param engine: trainer engine
        :param batch: one batch provided by the dataloader
        :return: the model state; consumed by the RunningAverage output
            transforms created in add_handler
        """
        self.model.train()
        self.model.set_input(batch)
        self.model.optimize_parameters()
        return self.model.state

    def _handle_log_train_results(self, engine):
        """
        Handler for writing the losses to tensorboard and sacred.
        :param engine: train engine
        :return:
        """
        # `iteration - 1` makes the interval fire on the very first iteration.
        if (engine.state.iteration - 1) % self.train_cfg['log_interval'] == 0:
            metrics = engine.state.metrics  # does not include non scalar metrics, since loggers can not handle this

            for m_name, m_val in metrics.items():
                if m_val is None:
                    raise ValueError(f'Value for {m_name} is None')
                self.run_info.log_scalar("train.%s" % m_name, m_val,
                                         engine.state.iteration)
                self.tb_writer.add_scalar("train/%s" % m_name, m_val,
                                          engine.state.iteration)

            for lr_name, lr_val in self.model.learning_rates.items():
                if lr_val is None:
                    raise ValueError(f'Value for {lr_name} is None')
                self.run_info.log_scalar("train.%s" % lr_name, lr_val,
                                         engine.state.iteration)
                self.tb_writer.add_scalar("train/%s" % lr_name, lr_val,
                                          engine.state.iteration)

    def _handle_log_train_images(self, engine):
        """
        Handler for writing visual samples to tensorboard.
        :param engine: train engine
        :return:
        """
        if (engine.state.iteration -
                1) % self.train_cfg['img_log_interval'] == 0:
            for name, visual in self.model.visuals.items():
                # TODO remove the visual.transpose here and put it in the visualization function of the models
                # transpose HWC -> CHW, the layout tensorboardX expects
                self.tb_writer.add_image('train/%s' % name,
                                         visual.transpose(2, 0, 1),
                                         engine.state.iteration)
            for name, figure in self.model.figures.items():
                self.tb_writer.add_figure('train_metrics/%s' % name, figure,
                                          engine.state.iteration)

    def _handle_run_evaluation(self, engine):
        """
        Handler which will execute evaluation by running the evaluator object.
        :param engine: train engine
        :return:
        """
        if (engine.state.iteration - 1) % self.train_cfg['eval_interval'] == 0:
            self.evaluator.run()

    def _handle_exception(self, engine, e):
        """
        Exception handler which ensures that the model gets saved when stopped through a keyboard interruption.
        :param engine: train engine
        :param e: the exception which caused the training to stop
        :return:
        """
        # Only save when at least one iteration ran; anything other than
        # KeyboardInterrupt is re-raised unchanged.
        if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1):
            engine.terminate()
            self.logger.warning(
                'KeyboardInterrupt caught. Exiting gracefully.')
            self.save_last_checkpoint_handler(engine, self.model.networks)
        else:
            raise e

    def _handle_print_times(self, engine):
        """
        Handler for logging timer information for different training and evaluation steps.
        :param engine: train engine
        :return:
        """
        self.logger.info('Epoch {} done. Time per batch: {:.3f}[s]'.format(
            engine.state.epoch, self.timer.value()))
        self.timer.reset()

    @staticmethod
    def _extract_loss(key):
        """
        Helper method to return losses for the RunningAverage
        :param key: (str) loss name
        :return: (fn) for the corresponding key
        """
        def _func(losses):
            return losses[key]

        return _func
Ejemplo n.º 9
0
class BasicTimeProfiler(object):
    """Profile an ignite ``Engine`` run.

    Three quantities are measured with independent ``Timer`` objects:

    * dataflow time    -- gap between one iteration completing and the next
      starting (time spent fetching the batch),
    * processing time  -- ITERATION_STARTED -> ITERATION_COMPLETED (the
      engine's process function),
    * event-handler time -- per event, time spent in the user's handlers.

    It works by inserting itself as the *first* and *last* handler of each
    standard event (directly via ``engine._event_handlers`` so ordering is
    guaranteed) and removing those handlers when the run completes.
    """

    def __init__(self):
        self._dataflow_timer = Timer()
        self._processing_timer = Timer()
        self._event_handlers_timer = Timer()

    def _reset(self, num_iters, num_epochs, total_num_iters):
        """Allocate zeroed result buffers sized for the upcoming run."""
        self.dataflow_times = torch.zeros(total_num_iters)
        self.processing_times = torch.zeros(total_num_iters)
        self.event_handlers_times = {
            Events.STARTED: torch.zeros(1),
            Events.COMPLETED: torch.zeros(1),
            Events.EPOCH_STARTED: torch.zeros(num_epochs),
            Events.EPOCH_COMPLETED: torch.zeros(num_epochs),
            Events.ITERATION_STARTED: torch.zeros(total_num_iters),
            Events.ITERATION_COMPLETED: torch.zeros(total_num_iters),
        }

    def _bracketed_events(self):
        """Return (event, first_handler, last_handler) triples for every
        event whose handler time is bracketed by a reset/read pair.

        Shared by _as_first_started (install) and _as_last_completed
        (removal) so the two lists cannot drift apart.
        """
        return [
            (Events.EPOCH_STARTED,
             self._as_first_epoch_started, self._as_last_epoch_started),
            (Events.EPOCH_COMPLETED,
             self._as_first_epoch_completed, self._as_last_epoch_completed),
            (Events.ITERATION_STARTED,
             self._as_first_iter_started, self._as_last_iter_started),
            (Events.ITERATION_COMPLETED,
             self._as_first_iter_completed, self._as_last_iter_completed),
            (Events.COMPLETED,
             self._as_first_completed, self._as_last_completed),
        ]

    def _as_first_started(self, engine):
        # Sizing requires a sized dataloader; assumes one epoch is one full
        # pass over engine.state.dataloader -- TODO confirm for custom loaders.
        num_iters = engine.state.max_epochs * len(engine.state.dataloader)
        self._reset(len(engine.state.dataloader), engine.state.max_epochs,
                    num_iters)

        # Snapshot the names of the user handlers attached to each profiled
        # event. Iterate our own buffer keys (not the full Events enum) so
        # names and stats always cover exactly the same events and new enum
        # members can never trigger a lookup we have no buffer for.
        self.event_handlers_names = {
            e: [h.__qualname__ if hasattr(h, "__qualname__")
                else h.__class__.__name__
                for (h, _, _) in engine._event_handlers[e]]
            for e in self.event_handlers_times
        }

        # This method already runs first on STARTED; close that measurement
        # by appending the matching "last" handler.
        engine._event_handlers[Events.STARTED].append(
            (self._as_last_started, (), {}))

        # Bracket every other profiled event with a timer reset (inserted
        # first) and a timer read (appended last).
        for event, first, last in self._bracketed_events():
            engine._event_handlers[event].insert(0, (first, (), {}))
            engine._event_handlers[event].append((last, (), {}))

        # Let's go
        self._event_handlers_timer.reset()

    def _as_last_started(self, engine):
        self.event_handlers_times[Events.STARTED][0] = \
            self._event_handlers_timer.value()

    def _as_first_epoch_started(self, engine):
        self._event_handlers_timer.reset()

    def _as_last_epoch_started(self, engine):
        t = self._event_handlers_timer.value()
        e = engine.state.epoch - 1
        self.event_handlers_times[Events.EPOCH_STARTED][e] = t

        # Next measured interval: dataflow before the epoch's first iteration.
        self._dataflow_timer.reset()

    def _as_first_iter_started(self, engine):
        # Time since the previous iteration completed == batch fetch time.
        t = self._dataflow_timer.value()
        i = engine.state.iteration - 1
        self.dataflow_times[i] = t

        self._event_handlers_timer.reset()

    def _as_last_iter_started(self, engine):
        t = self._event_handlers_timer.value()
        i = engine.state.iteration - 1
        self.event_handlers_times[Events.ITERATION_STARTED][i] = t

        # Everything until ITERATION_COMPLETED is the process function.
        self._processing_timer.reset()

    def _as_first_iter_completed(self, engine):
        t = self._processing_timer.value()
        i = engine.state.iteration - 1
        self.processing_times[i] = t

        self._event_handlers_timer.reset()

    def _as_last_iter_completed(self, engine):
        t = self._event_handlers_timer.value()
        i = engine.state.iteration - 1
        self.event_handlers_times[Events.ITERATION_COMPLETED][i] = t

        self._dataflow_timer.reset()

    def _as_first_epoch_completed(self, engine):
        self._event_handlers_timer.reset()

    def _as_last_epoch_completed(self, engine):
        t = self._event_handlers_timer.value()
        e = engine.state.epoch - 1
        self.event_handlers_times[Events.EPOCH_COMPLETED][e] = t

    def _as_first_completed(self, engine):
        self._event_handlers_timer.reset()

    def _as_last_completed(self, engine):
        self.event_handlers_times[Events.COMPLETED][0] = \
            self._event_handlers_timer.value()

        # The run is over: detach every handler this profiler installed
        # (except _as_first_started, which `attach` added and which is
        # needed again if the engine is re-run).
        remove_handler(engine, self._as_last_started, Events.STARTED)
        for event, first, last in self._bracketed_events():
            remove_handler(engine, first, event)
            remove_handler(engine, last, event)

    def attach(self, engine):
        """Attach the profiler to ``engine`` (idempotent).

        :raises TypeError: when ``engine`` is not an ignite ``Engine``.
        """
        if not isinstance(engine, Engine):
            raise TypeError("Argument engine should be ignite.engine.Engine, "
                            "but given {}".format(type(engine)))

        # Insert at position 0 so profiling starts before any user handler.
        if not engine.has_event_handler(self._as_first_started):
            engine._event_handlers[Events.STARTED].insert(
                0, (self._as_first_started, (), {}))

    @staticmethod
    def _compute_basic_stats(data):
        """Return min/max (with their indices), mean, std and sum of a 1-D
        tensor as an ``OrderedDict`` of Python scalars.

        Note: ``std`` is NaN for a single-element tensor.
        """
        return OrderedDict([
            ('min/index', (torch.min(data).item(), torch.argmin(data).item())),
            ('max/index', (torch.max(data).item(), torch.argmax(data).item())),
            ('mean', torch.mean(data).item()),
            ('std', torch.std(data).item()),
            ('total', torch.sum(data).item())
        ])

    def get_results(self):
        """Return an ``OrderedDict`` with stats for the processing function,
        the dataflow and the per-event handler times, plus handler names.

        Iterates the internally tracked buffers instead of the ``Events``
        enum, so enum members that are never profiled (EXCEPTION_RAISED or
        events added in newer ignite versions) cannot cause a KeyError.
        """
        total_eh_time = sum(sum(t)
                            for t in self.event_handlers_times.values())
        event_stats = [
            (str(e).replace(".", "_"), self._compute_basic_stats(t))
            for e, t in self.event_handlers_times.items()
        ]
        return OrderedDict([
            ("processing_stats",
             self._compute_basic_stats(self.processing_times)),
            ("dataflow_stats",
             self._compute_basic_stats(self.dataflow_times)),
            ("event_handlers_stats",
             dict(event_stats + [("total_time", total_eh_time)])),
            ("event_handlers_names",
             {str(e).replace(".", "_") + "_names": v
              for e, v in self.event_handlers_names.items()}),
        ])

    @staticmethod
    def print_results(results):
        """Pretty-print the dict returned by :meth:`get_results` and return
        the formatted message."""

        def odict_to_str(d):
            out = ""
            for k, v in d.items():
                out += "\t{}: {}\n".format(k, v)
            return out

        # Per-event stats become indented text blocks; scalars pass through.
        others = {k: odict_to_str(v) if isinstance(v, OrderedDict) else v
                  for k, v in results['event_handlers_stats'].items()}

        others.update(results['event_handlers_names'])

        output_message = """
--------------------------------------------
- Time profiling results:
--------------------------------------------

Processing function time stats (in seconds):
{processing_stats}

Dataflow time stats (in seconds):
{dataflow_stats}

Time stats of event handlers (in seconds):
- Total time spent:
\t{total_time}

- Events.STARTED:
{Events_STARTED}
Handlers names:
{Events_STARTED_names}

- Events.EPOCH_STARTED:
{Events_EPOCH_STARTED}
Handlers names:
{Events_EPOCH_STARTED_names}

- Events.ITERATION_STARTED:
{Events_ITERATION_STARTED}
Handlers names:
{Events_ITERATION_STARTED_names}

- Events.ITERATION_COMPLETED:
{Events_ITERATION_COMPLETED}
Handlers names:
{Events_ITERATION_COMPLETED_names}

- Events.EPOCH_COMPLETED:
{Events_EPOCH_COMPLETED}
Handlers names:
{Events_EPOCH_COMPLETED_names}

- Events.COMPLETED:
{Events_COMPLETED}
Handlers names:
{Events_COMPLETED_names}

""".format(processing_stats=odict_to_str(results['processing_stats']),
           dataflow_stats=odict_to_str(results['dataflow_stats']),
           **others)
        print(output_message)
        return output_message

    @staticmethod
    def write_results(output_path):
        """Placeholder: persist the profiling results to ``output_path``.

        Not implemented yet; will require pandas for the CSV export.
        """
        try:
            import pandas as pd  # noqa: F401 -- availability check only
        except ImportError:
            print("Need pandas to write results as files")
            return

        raise NotImplementedError("")