Example #1
import ignite.distributed as idist
from ignite.engine import Engine, Events


class DataflowBenchmark:
    def __init__(self, num_iters=100, prepare_batch=None):

        from ignite.handlers import Timer

        device = idist.device()

        def upload_to_gpu(engine, batch):
            if prepare_batch is not None:
                x, y = prepare_batch(batch, device=device, non_blocking=False)

        self.num_iters = num_iters
        self.benchmark_dataflow = Engine(upload_to_gpu)

        @self.benchmark_dataflow.on(Events.ITERATION_COMPLETED(once=num_iters))
        def stop_benchmark_dataflow(engine):
            engine.terminate()

        if idist.get_rank() == 0:

            @self.benchmark_dataflow.on(
                Events.ITERATION_COMPLETED(every=num_iters // 100))
            def show_progress_benchmark_dataflow(engine):
                print(".", end=" ")

        self.timer = Timer(average=False)
        self.timer.attach(
            self.benchmark_dataflow,
            start=Events.EPOCH_STARTED,
            resume=Events.ITERATION_STARTED,
            pause=Events.ITERATION_COMPLETED,
            step=Events.ITERATION_COMPLETED,
        )

    def attach(self, trainer, train_loader):

        from torch.utils.data import DataLoader

        @trainer.on(Events.STARTED)
        def run_benchmark(_):
            if idist.get_rank() == 0:
                print("-" * 50)
                print(" - Dataflow benchmark")

            self.benchmark_dataflow.run(train_loader)
            t = self.timer.value()

            if idist.get_rank() == 0:
                print(" ")
                print(" Total time ({} iterations) : {:.5f} seconds".format(
                    self.num_iters, t))
                print(" time per iteration         : {} seconds".format(
                    t / self.num_iters))

                if isinstance(train_loader, DataLoader):
                    num_images = train_loader.batch_size * self.num_iters
                    print(" number of images / s       : {}".format(
                        num_images / t))

                print("-" * 50)
Example #2
import time

from ignite.engine import Engine, Events
from ignite.handlers import Timer


def test_timer():
    sleep_t = 0.2
    n_iter = 3

    def _train_func(engine, batch):
        time.sleep(sleep_t)

    def _test_func(engine, batch):
        time.sleep(sleep_t)

    trainer = Engine(_train_func)
    tester = Engine(_test_func)

    t_total = Timer()
    t_batch = Timer(average=True)
    t_train = Timer()

    t_total.attach(trainer)
    t_batch.attach(
        trainer,
        pause=Events.ITERATION_COMPLETED,
        resume=Events.ITERATION_STARTED,
        step=Events.ITERATION_COMPLETED,
    )
    t_train.attach(trainer,
                   pause=Events.EPOCH_COMPLETED,
                   resume=Events.EPOCH_STARTED)

    @trainer.on(Events.EPOCH_COMPLETED)
    def run_validation(trainer):
        tester.run(range(n_iter))

    # Run "training"
    trainer.run(range(n_iter))

    def _equal(lhs, rhs):
        return round(lhs, 1) == round(rhs, 1)

    assert _equal(t_total.value(), (2 * n_iter * sleep_t))
    assert _equal(t_batch.value(), (sleep_t))
    assert _equal(t_train.value(), (n_iter * sleep_t))

    t_total.reset()
    assert _equal(t_total.value(), 0.0)
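The same Timer can also be driven by hand, without an Engine; a short sketch of the average=True semantics exercised by the test above:

import time
from ignite.handlers import Timer

t = Timer(average=True)
t.pause()                  # the timer starts running on construction, so pause it first
for _ in range(3):
    t.resume()
    time.sleep(0.1)        # the section being timed
    t.pause()
    t.step()               # marks one measured step
print(t.value())           # ~0.1: accumulated time divided by the number of steps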
Example #3
def evaluate_model(run_name, model, optimizer, device, loss_name, loss_params, chosen_diseases, dataloader,
                   experiment_mode="debug", base_dir=utils.BASE_DIR):
    # Create tester engine
    tester = Engine(utilsT.get_step_fn(model, optimizer, device, loss_name, loss_params, training=False))

    loss_metric = RunningAverage(output_transform=lambda x: x[0], alpha=1)
    loss_metric.attach(tester, loss_name)
    
    utilsT.attach_metrics(tester, chosen_diseases, "prec", Precision, True)
    utilsT.attach_metrics(tester, chosen_diseases, "recall", Recall, True)
    utilsT.attach_metrics(tester, chosen_diseases, "acc", Accuracy, True)
    utilsT.attach_metrics(tester, chosen_diseases, "roc_auc", utilsT.RocAucMetric, False)
    utilsT.attach_metrics(tester, chosen_diseases, "cm", ConfusionMatrix,
                          get_transform_fn=utilsT.get_transform_cm, metric_args=(2,))

    timer = Timer(average=True)
    timer.attach(tester, start=Events.EPOCH_STARTED, step=Events.EPOCH_COMPLETED)

    # Save metrics
    log_metrics = list(ALL_METRICS)

    # Run test
    print("Testing...")
    tester.run(dataloader, 1)
    

    # Capture time
    secs_per_epoch = timer.value()
    duration_per_epoch = utils.duration_to_str(int(secs_per_epoch))
    print("Time elapsed in epoch: ", duration_per_epoch)

    # Copy metrics dict
    metrics = dict()
    original_metrics = tester.state.metrics
    for metric_name in log_metrics:
        for disease_name in chosen_diseases:
            key = metric_name + "_" + disease_name
            if key not in original_metrics:
                print("Metric not found in tester, skipping: ", key)
                continue

            metrics[key] = original_metrics[key]
            
    # Copy CMs
    for disease_name in chosen_diseases:
        key = "cm_" + disease_name
        if key not in original_metrics:
            print("CM not found in tester, skipping: ", key)
            continue
        
        cm = original_metrics[key]
        metrics[key] = cm.numpy().tolist()
    
    # Save to file
    folder = os.path.join(base_dir, "results", experiment_mode)
    os.makedirs(folder, exist_ok=True)
    
    fname = os.path.join(folder, run_name + ".json")
    with open(fname, "w+") as f:
        json.dump(metrics, f)   
    print("Saved metrics to: ", fname)
    
    return metrics
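The timer pattern used above, in isolation on a toy engine: with start=Events.EPOCH_STARTED, step=Events.EPOCH_COMPLETED and the default pause at Events.COMPLETED, value() reports the duration of the single evaluation epoch. A sketch with a placeholder step function:

import time
from ignite.engine import Engine, Events
from ignite.handlers import Timer

eng = Engine(lambda engine, batch: time.sleep(0.001))  # stand-in for the real step function
epoch_timer = Timer(average=True)
epoch_timer.attach(eng, start=Events.EPOCH_STARTED, step=Events.EPOCH_COMPLETED)
eng.run(range(100), max_epochs=1)
print("seconds for the epoch ~", epoch_timer.value())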
Example #4
import functools
from typing import Any, Callable, Dict, List, Mapping, Sequence, Tuple, Union

import torch

from ignite.engine import Engine, EventEnum, Events
from ignite.handlers import Timer


class HandlersTimeProfiler:
    """
    HandlersTimeProfiler can be used to profile the handlers,
    data loading and data processing times. Custom events are also
    profiled by this profiler.

    Examples:

    .. code-block:: python

        from ignite.contrib.handlers import HandlersTimeProfiler

        trainer = Engine(train_updater)

        # Create an object of the profiler and attach an engine to it
        profiler = HandlersTimeProfiler()
        profiler.attach(trainer)

        @trainer.on(Events.EPOCH_COMPLETED)
        def log_intermediate_results():
            profiler.print_results(profiler.get_results())

        trainer.run(dataloader, max_epochs=3)

        profiler.write_results('path_to_dir/time_profiling.csv')

    """

    EVENT_FILTER_THESHOLD_TIME = 0.0001

    def __init__(self) -> None:
        self._dataflow_timer = Timer()
        self._processing_timer = Timer()
        self._event_handlers_timer = Timer()

        self.dataflow_times = []  # type: List[float]
        self.processing_times = []  # type: List[float]
        self.event_handlers_times = {
        }  # type: Dict[EventEnum, Dict[str, List[float]]]

    @staticmethod
    def _get_callable_name(handler: Callable) -> str:
        # get name of the callable handler
        return getattr(handler, "__qualname__", handler.__class__.__name__)

    def _create_wrapped_handler(self, handler: Callable,
                                event: EventEnum) -> Callable:
        @functools.wraps(handler)
        def _timeit_handler(*args: Any, **kwargs: Any) -> None:
            self._event_handlers_timer.reset()
            handler(*args, **kwargs)
            t = self._event_handlers_timer.value()
            hname = self._get_callable_name(handler)
            # filter profiled time if the handler was attached to event with event filter
            if not hasattr(handler,
                           "_parent") or t >= self.EVENT_FILTER_THESHOLD_TIME:
                self.event_handlers_times[event][hname].append(t)

        # required to revert back to original handler after profiling
        setattr(_timeit_handler, "_profiler_original", handler)
        return _timeit_handler

    def _timeit_processing(self) -> None:
        # handler used for profiling processing times
        t = self._processing_timer.value()
        self.processing_times.append(t)

    def _timeit_dataflow(self) -> None:
        # handler used for profiling dataflow times
        t = self._dataflow_timer.value()
        self.dataflow_times.append(t)

    def _reset(self, event_handlers_names: Mapping[EventEnum,
                                                   List[str]]) -> None:
        # reset the variables used for profiling
        self.dataflow_times = []
        self.processing_times = []

        self.event_handlers_times = {
            e: {h: []
                for h in event_handlers_names[e]}
            for e in event_handlers_names
        }

    @staticmethod
    def _is_internal_handler(handler: Callable) -> bool:
        # checks whether the handler is internal
        return any(n in repr(handler)
                   for n in ["HandlersTimeProfiler.", "Timer."])

    def _detach_profiler_handlers(self, engine: Engine) -> None:
        # reverts handlers to original handlers
        for e in engine._event_handlers:
            for i, (func, args,
                    kwargs) in enumerate(engine._event_handlers[e]):
                if hasattr(func, "_profiler_original"):
                    engine._event_handlers[e][i] = (func._profiler_original,
                                                    args, kwargs)

    def _as_first_started(self, engine: Engine) -> None:
        # wraps original handlers for profiling

        self.event_handlers_names = {
            e: [
                self._get_callable_name(h)
                for (h, _, _) in engine._event_handlers[e]
                if not self._is_internal_handler(h)
            ]
            for e in engine._allowed_events
        }

        self._reset(self.event_handlers_names)

        for e in engine._allowed_events:
            for i, (func, args,
                    kwargs) in enumerate(engine._event_handlers[e]):
                if not self._is_internal_handler(func):
                    engine._event_handlers[e][i] = (
                        self._create_wrapped_handler(func, e), args, kwargs)

        # processing timer
        engine.add_event_handler(Events.ITERATION_STARTED,
                                 self._processing_timer.reset)
        engine._event_handlers[Events.ITERATION_COMPLETED].insert(
            0, (self._timeit_processing, (), {}))

        # dataflow timer
        engine.add_event_handler(Events.GET_BATCH_STARTED,
                                 self._dataflow_timer.reset)
        engine._event_handlers[Events.GET_BATCH_COMPLETED].insert(
            0, (self._timeit_dataflow, (), {}))

        # revert back the wrapped handlers with original handlers at the end
        engine.add_event_handler(Events.COMPLETED,
                                 self._detach_profiler_handlers)

    def attach(self, engine: Engine) -> None:
        """Attach HandlersTimeProfiler to the given engine.

        Args:
            engine: the instance of Engine to attach
        """
        if not isinstance(engine, Engine):
            raise TypeError(
                f"Argument engine should be ignite.engine.Engine, but given {type(engine)}"
            )

        if not engine.has_event_handler(self._as_first_started):
            engine._event_handlers[Events.STARTED].insert(
                0, (self._as_first_started, (engine, ), {}))

    def get_results(self) -> List[List[Union[str, float]]]:
        """
        Method to fetch the aggregated profiler results after the engine is run

        .. code-block:: python

            results = profiler.get_results()

        """
        total_eh_time = sum([
            sum(self.event_handlers_times[e][h])
            for e in self.event_handlers_times
            for h in self.event_handlers_times[e]
        ])
        total_eh_time = round(
            float(total_eh_time),
            5,
        )

        def compute_basic_stats(
            times: Union[Sequence, torch.Tensor]
        ) -> List[Union[str, float, Tuple[Union[str, float], Union[str,
                                                                   float]]]]:
            data = torch.as_tensor(times, dtype=torch.float32)
            # compute on non-zero data:
            data = data[data > 0]
            total = round(torch.sum(data).item(), 5) if len(
                data) > 0 else "not triggered"  # type: Union[str, float]
            min_index = ("None", "None"
                         )  # type: Tuple[Union[str, float], Union[str, float]]
            max_index = ("None", "None"
                         )  # type: Tuple[Union[str, float], Union[str, float]]
            mean = "None"  # type: Union[str, float]
            std = "None"  # type: Union[str, float]
            if len(data) > 0:
                min_index = (round(torch.min(data).item(),
                                   5), torch.argmin(data).item())
                max_index = (round(torch.max(data).item(),
                                   5), torch.argmax(data).item())
                mean = round(torch.mean(data).item(), 5)
                if len(data) > 1:
                    std = round(torch.std(data).item(), 5)
            return [total, min_index, max_index, mean, std]

        event_handler_stats = [[
            h,
            getattr(e, "name", str(e)),
            *compute_basic_stats(
                torch.tensor(self.event_handlers_times[e][h],
                             dtype=torch.float32)),
        ] for e in self.event_handlers_times
                               for h in self.event_handlers_times[e]]
        event_handler_stats.append(
            ["Total", "", total_eh_time, "", "", "", ""])
        event_handler_stats.append([
            "Processing", "None", *compute_basic_stats(self.processing_times)
        ])
        event_handler_stats.append(
            ["Dataflow", "None", *compute_basic_stats(self.dataflow_times)])

        return event_handler_stats

    def write_results(self, output_path: str) -> None:
        """
        Method to store the unaggregated profiling results to a csv file

        Args:
            output_path: file output path containing a filename

        .. code-block:: python

            profiler.write_results('path_to_dir/awesome_filename.csv')

        Example output:

        .. code-block:: text

            -----------------------------------------------------------------
            # processing_stats dataflow_stats training.<locals>.log_elapsed_time (EPOCH_COMPLETED) ...
            1     0.00003         0.252387                          0.125676
            2     0.00029         0.252342                          0.125123

        """
        try:
            import pandas as pd
        except ImportError:
            raise RuntimeError("Need pandas to write results as files")

        processing_stats = torch.tensor(self.processing_times,
                                        dtype=torch.float32)
        dataflow_stats = torch.tensor(self.dataflow_times, dtype=torch.float32)

        cols = [processing_stats, dataflow_stats]
        headers = ["processing_stats", "dataflow_stats"]
        for e in self.event_handlers_times:
            for h in self.event_handlers_times[e]:
                headers.append(f"{h} ({getattr(e, 'name', str(e))})")
                cols.append(
                    torch.tensor(self.event_handlers_times[e][h],
                                 dtype=torch.float32))
        # Determine maximum length
        max_len = max([x.numel() for x in cols])

        count_col = torch.arange(max_len, dtype=torch.float32) + 1
        cols.insert(0, count_col)
        headers.insert(0, "#")

        # pad all tensors to have same length
        cols = [
            torch.nn.functional.pad(x,
                                    pad=(0, max_len - x.numel()),
                                    mode="constant",
                                    value=0) for x in cols
        ]

        results_dump = torch.stack(
            cols,
            dim=1,
        ).numpy()

        results_df = pd.DataFrame(
            data=results_dump,
            columns=headers,
        )
        results_df.to_csv(output_path, index=False)

    @staticmethod
    def print_results(results: List[List[Union[str, float]]]) -> None:
        """
        Method to print the aggregated results from the profiler

        Args:
            results: the aggregated results from the profiler

        .. code-block:: python

            profiler.print_results(results)

        Example output:

        .. code-block:: text

            -----------------------------------------  -----------------------  -------------- ...
            Handler                                    Event Name                     Total(s)
            -----------------------------------------  -----------------------  --------------
            run.<locals>.log_training_results          EPOCH_COMPLETED                19.43245
            run.<locals>.log_validation_results        EPOCH_COMPLETED                 2.55271
            run.<locals>.log_time                      EPOCH_COMPLETED                 0.00049
            run.<locals>.log_intermediate_results      EPOCH_COMPLETED                 0.00106
            run.<locals>.log_training_loss             ITERATION_COMPLETED               0.059
            run.<locals>.log_time                      COMPLETED                 not triggered
            -----------------------------------------  -----------------------  --------------
            Total                                                                     22.04571
            -----------------------------------------  -----------------------  --------------
            Processing took total 11.29543s [min/index: 0.00393s/1875, max/index: 0.00784s/0,
             mean: 0.00602s, std: 0.00034s]
            Dataflow took total 16.24365s [min/index: 0.00533s/1874, max/index: 0.01129s/937,
             mean: 0.00866s, std: 0.00113s]

        """
        # adopted implementation of torch.autograd.profiler.build_table
        handler_column_width = max([len(item[0]) for item in results
                                    ]) + 4  # type: ignore[arg-type]
        event_column_width = max([len(item[1]) for item in results
                                  ]) + 4  # type: ignore[arg-type]

        DEFAULT_COLUMN_WIDTH = 14

        headers = [
            "Handler",
            "Event Name",
            "Total(s)",
            "Min(s)/IDX",
            "Max(s)/IDX",
            "Mean(s)",
            "Std(s)",
        ]

        # Have to use a list because nonlocal is Py3 only...
        SPACING_SIZE = 2
        row_format_lst = [""]
        header_sep_lst = [""]
        line_length_lst = [-SPACING_SIZE]

        def add_column(padding: int, text_dir: str = ">") -> None:
            row_format_lst[0] += "{: " + text_dir + str(padding) + "}" + (
                " " * SPACING_SIZE)
            header_sep_lst[0] += "-" * padding + (" " * SPACING_SIZE)
            line_length_lst[0] += padding + SPACING_SIZE

        add_column(handler_column_width, text_dir="<")
        add_column(event_column_width, text_dir="<")
        for _ in headers[2:]:
            add_column(DEFAULT_COLUMN_WIDTH)

        row_format = row_format_lst[0]
        header_sep = header_sep_lst[0]

        result = []

        def append(s: str) -> None:
            result.append(s)
            result.append("\n")

        result.append("\n")
        append(header_sep)
        append(row_format.format(*headers))
        append(header_sep)

        for row in results[:-3]:
            # format min/idx and max/idx
            row[3] = "{}/{}".format(*row[3])  # type: ignore[misc]
            row[4] = "{}/{}".format(*row[4])  # type: ignore[misc]

            append(row_format.format(*row))

        append(header_sep)
        # print total handlers time row
        append(row_format.format(*results[-3]))
        append(header_sep)

        summary_format = "{} took total {}s [min/index: {}, max/index: {}, mean: {}s, std: {}s]"
        for row in results[-2:]:
            row[3] = "{}s/{}".format(*row[3])  # type: ignore[misc]
            row[4] = "{}s/{}".format(*row[4])  # type: ignore[misc]
            del row[1]
            append(summary_format.format(*row))
        print("".join(result))
Example #5
from collections import OrderedDict

import torch

from ignite.engine import Engine, Events
from ignite.handlers import Timer


class BasicTimeProfiler:
    """
    BasicTimeProfiler can be used to profile the handlers,
    events, data loading and data processing times.

    Examples:

    .. code-block:: python

        #
        # Create an object of the profiler and attach an engine to it
        #
        profiler = BasicTimeProfiler()
        trainer = Engine(train_updater)
        profiler.attach(trainer)

        @trainer.on(Events.EPOCH_COMPLETED)
        def log_intermediate_results():
            profiler.print_results(profiler.get_results())

        trainer.run(dataloader, max_epochs=3)

        profiler.write_results('path_to_dir/time_profiling.csv')

    """

    events_to_ignore = [
        Events.EXCEPTION_RAISED,
        Events.TERMINATE,
        Events.TERMINATE_SINGLE_EPOCH,
        Events.DATALOADER_STOP_ITERATION,
    ]

    def __init__(self):
        self._dataflow_timer = Timer()
        self._processing_timer = Timer()
        self._event_handlers_timer = Timer()

        self.dataflow_times = None
        self.processing_times = None
        self.event_handlers_times = None

        self._events = [
            Events.EPOCH_STARTED,
            Events.EPOCH_COMPLETED,
            Events.ITERATION_STARTED,
            Events.ITERATION_COMPLETED,
            Events.GET_BATCH_STARTED,
            Events.GET_BATCH_COMPLETED,
            Events.COMPLETED,
        ]
        self._fmethods = [
            self._as_first_epoch_started,
            self._as_first_epoch_completed,
            self._as_first_iter_started,
            self._as_first_iter_completed,
            self._as_first_get_batch_started,
            self._as_first_get_batch_completed,
            self._as_first_completed,
        ]
        self._lmethods = [
            self._as_last_epoch_started,
            self._as_last_epoch_completed,
            self._as_last_iter_started,
            self._as_last_iter_completed,
            self._as_last_get_batch_started,
            self._as_last_get_batch_completed,
            self._as_last_completed,
        ]

    def _reset(self, num_epochs, total_num_iters):
        self.dataflow_times = torch.zeros(total_num_iters)
        self.processing_times = torch.zeros(total_num_iters)
        self.event_handlers_times = {
            Events.STARTED: torch.zeros(1),
            Events.COMPLETED: torch.zeros(1),
            Events.EPOCH_STARTED: torch.zeros(num_epochs),
            Events.EPOCH_COMPLETED: torch.zeros(num_epochs),
            Events.ITERATION_STARTED: torch.zeros(total_num_iters),
            Events.ITERATION_COMPLETED: torch.zeros(total_num_iters),
            Events.GET_BATCH_COMPLETED: torch.zeros(total_num_iters),
            Events.GET_BATCH_STARTED: torch.zeros(total_num_iters),
        }

    def _as_first_started(self, engine):
        if hasattr(engine.state.dataloader, "__len__"):
            num_iters_per_epoch = len(engine.state.dataloader)
        else:
            num_iters_per_epoch = engine.state.epoch_length

        self.max_epochs = engine.state.max_epochs
        self.total_num_iters = self.max_epochs * num_iters_per_epoch
        self._reset(self.max_epochs, self.total_num_iters)

        self.event_handlers_names = {
            e: [
                h.__qualname__
                if hasattr(h, "__qualname__") else h.__class__.__name__
                for (h, _, _) in engine._event_handlers[e]
                if "BasicTimeProfiler." not in repr(
                    h)  # avoid adding internal handlers into output
            ]
            for e in Events if e not in self.events_to_ignore
        }

        # Setup all other handlers:
        engine._event_handlers[Events.STARTED].append(
            (self._as_last_started, (engine, ), {}))

        for e, m in zip(self._events, self._fmethods):
            engine._event_handlers[e].insert(0, (m, (engine, ), {}))

        for e, m in zip(self._events, self._lmethods):
            engine._event_handlers[e].append((m, (engine, ), {}))

        # Let's go
        self._event_handlers_timer.reset()

    def _as_last_started(self, engine):
        self.event_handlers_times[
            Events.STARTED][0] = self._event_handlers_timer.value()

    def _as_first_epoch_started(self, engine):
        self._event_handlers_timer.reset()

    def _as_last_epoch_started(self, engine):
        t = self._event_handlers_timer.value()
        e = engine.state.epoch - 1
        self.event_handlers_times[Events.EPOCH_STARTED][e] = t

    def _as_first_get_batch_started(self, engine):
        self._event_handlers_timer.reset()
        self._dataflow_timer.reset()

    def _as_last_get_batch_started(self, engine):
        t = self._event_handlers_timer.value()
        i = engine.state.iteration - 1
        self.event_handlers_times[Events.GET_BATCH_STARTED][i] = t

    def _as_first_get_batch_completed(self, engine):
        self._event_handlers_timer.reset()

    def _as_last_get_batch_completed(self, engine):
        t = self._event_handlers_timer.value()
        i = engine.state.iteration - 1
        self.event_handlers_times[Events.GET_BATCH_COMPLETED][i] = t

        d = self._dataflow_timer.value()
        self.dataflow_times[i] = d

        self._dataflow_timer.reset()

    def _as_first_iter_started(self, engine):
        self._event_handlers_timer.reset()

    def _as_last_iter_started(self, engine):
        t = self._event_handlers_timer.value()
        i = engine.state.iteration - 1
        self.event_handlers_times[Events.ITERATION_STARTED][i] = t

        self._processing_timer.reset()

    def _as_first_iter_completed(self, engine):
        t = self._processing_timer.value()
        i = engine.state.iteration - 1
        self.processing_times[i] = t

        self._event_handlers_timer.reset()

    def _as_last_iter_completed(self, engine):
        t = self._event_handlers_timer.value()
        i = engine.state.iteration - 1
        self.event_handlers_times[Events.ITERATION_COMPLETED][i] = t

    def _as_first_epoch_completed(self, engine):
        self._event_handlers_timer.reset()

    def _as_last_epoch_completed(self, engine):
        t = self._event_handlers_timer.value()
        e = engine.state.epoch - 1
        self.event_handlers_times[Events.EPOCH_COMPLETED][e] = t

    def _as_first_completed(self, engine):
        self._event_handlers_timer.reset()

    def _as_last_completed(self, engine):
        self.event_handlers_times[
            Events.COMPLETED][0] = self._event_handlers_timer.value()

        # Remove added handlers:
        engine.remove_event_handler(self._as_last_started, Events.STARTED)

        for e, m in zip(self._events, self._fmethods):
            engine.remove_event_handler(m, e)

        for e, m in zip(self._events, self._lmethods):
            engine.remove_event_handler(m, e)

    def attach(self, engine):
        if not isinstance(engine, Engine):
            raise TypeError("Argument engine should be ignite.engine.Engine, "
                            "but given {}".format(type(engine)))

        if not engine.has_event_handler(self._as_first_started):
            engine._event_handlers[Events.STARTED].insert(
                0, (self._as_first_started, (engine, ), {}))

    @staticmethod
    def _compute_basic_stats(data):
        # compute on non-zero data:
        data = data[data > 0]
        out = [
            ("total",
             torch.sum(data).item() if len(data) > 0 else "not yet triggered")
        ]
        if len(data) > 1:
            out += [
                ("min/index", (torch.min(data).item(),
                               torch.argmin(data).item())),
                ("max/index", (torch.max(data).item(),
                               torch.argmax(data).item())),
                ("mean", torch.mean(data).item()),
                ("std", torch.std(data).item()),
            ]
        return OrderedDict(out)

    def get_results(self):
        """
        Method to fetch the aggregated profiler results after the engine is run

        .. code-block:: python

            results = profiler.get_results()

        """
        total_eh_time = sum([(self.event_handlers_times[e]).sum()
                             for e in Events
                             if e not in self.events_to_ignore])

        return OrderedDict([
            ("processing_stats",
             self._compute_basic_stats(self.processing_times)),
            ("dataflow_stats", self._compute_basic_stats(self.dataflow_times)),
            (
                "event_handlers_stats",
                dict([(str(e.name).replace(".", "_"),
                       self._compute_basic_stats(self.event_handlers_times[e]))
                      for e in Events if e not in self.events_to_ignore] +
                     [("total_time", total_eh_time)]),
            ),
            (
                "event_handlers_names",
                {
                    str(e.name).replace(".", "_") + "_names": v
                    for e, v in self.event_handlers_names.items()
                },
            ),
        ])

    def write_results(self, output_path):
        """
        Method to store the unaggregated profiling results to a csv file

        .. code-block:: python

            profiler.write_results('path_to_dir/awesome_filename.csv')

        Example output:

        .. code-block:: text

            -----------------------------------------------------------------
            epoch iteration processing_stats dataflow_stats Event_STARTED ...
            1.0     1.0        0.00003         0.252387        0.125676
            1.0     2.0        0.00029         0.252342        0.125123

        """
        try:
            import pandas as pd
        except ImportError:
            print("Need pandas to write results as files")
            return

        iters_per_epoch = self.total_num_iters // self.max_epochs

        epochs = torch.arange(
            self.max_epochs,
            dtype=torch.float32).repeat_interleave(iters_per_epoch) + 1
        iterations = torch.arange(self.total_num_iters,
                                  dtype=torch.float32) + 1
        processing_stats = self.processing_times
        dataflow_stats = self.dataflow_times

        event_started = self.event_handlers_times[
            Events.STARTED].repeat_interleave(self.total_num_iters)
        event_completed = self.event_handlers_times[
            Events.COMPLETED].repeat_interleave(self.total_num_iters)
        event_epoch_started = self.event_handlers_times[
            Events.EPOCH_STARTED].repeat_interleave(iters_per_epoch)
        event_epoch_completed = self.event_handlers_times[
            Events.EPOCH_COMPLETED].repeat_interleave(iters_per_epoch)

        event_iter_started = self.event_handlers_times[
            Events.ITERATION_STARTED]
        event_iter_completed = self.event_handlers_times[
            Events.ITERATION_COMPLETED]
        event_batch_started = self.event_handlers_times[
            Events.GET_BATCH_STARTED]
        event_batch_completed = self.event_handlers_times[
            Events.GET_BATCH_COMPLETED]

        results_dump = torch.stack(
            [
                epochs,
                iterations,
                processing_stats,
                dataflow_stats,
                event_started,
                event_completed,
                event_epoch_started,
                event_epoch_completed,
                event_iter_started,
                event_iter_completed,
                event_batch_started,
                event_batch_completed,
            ],
            dim=1,
        ).numpy()

        results_df = pd.DataFrame(
            data=results_dump,
            columns=[
                "epoch",
                "iteration",
                "processing_stats",
                "dataflow_stats",
                "Event_STARTED",
                "Event_COMPLETED",
                "Event_EPOCH_STARTED",
                "Event_EPOCH_COMPLETED",
                "Event_ITERATION_STARTED",
                "Event_ITERATION_COMPLETED",
                "Event_GET_BATCH_STARTED",
                "Event_GET_BATCH_COMPLETED",
            ],
        )
        results_df.to_csv(output_path, index=False)

    @staticmethod
    def print_results(results):
        """
        Method to print the aggregated results from the profiler

        .. code-block:: python

            profiler.print_results(results)

        Example output:

        .. code-block:: text

            --------------------------------------------
            - Time profiling results:
            --------------------------------------------
            Processing function time stats (in seconds):
                min/index: (1.3081999895803165e-05, 1)
                max/index: (1.433099987480091e-05, 0)
                mean: 1.3706499885302037e-05
                std: 8.831763693706307e-07
                total: 2.7412999770604074e-05

            Dataflow time stats (in seconds):
                min/index: (5.199999941396527e-05, 0)
                max/index: (0.00010925399692496285, 1)
                mean: 8.062699635047466e-05
                std: 4.048469054396264e-05
                total: 0.0001612539927009493


            Time stats of event handlers (in seconds):
            - Total time spent:
                1.0080009698867798

            - Events.STARTED:
                total: 0.1256754994392395

            Handlers names:
            ['delay_start']
            --------------------------------------------
        """
        def odict_to_str(d):
            out = ""
            for k, v in d.items():
                out += "\t{}: {}\n".format(k, v)
            return out

        others = {
            k: odict_to_str(v) if isinstance(v, OrderedDict) else v
            for k, v in results["event_handlers_stats"].items()
        }

        others.update(results["event_handlers_names"])

        output_message = """
--------------------------------------------
- Time profiling results:
--------------------------------------------

Processing function time stats (in seconds):
{processing_stats}

Dataflow time stats (in seconds):
{dataflow_stats}

Time stats of event handlers (in seconds):
- Total time spent:
\t{total_time}

- Events.STARTED:
{STARTED}
Handlers names:
{STARTED_names}

- Events.EPOCH_STARTED:
{EPOCH_STARTED}
Handlers names:
{EPOCH_STARTED_names}

- Events.ITERATION_STARTED:
{ITERATION_STARTED}
Handlers names:
{ITERATION_STARTED_names}

- Events.ITERATION_COMPLETED:
{ITERATION_COMPLETED}
Handlers names:
{ITERATION_COMPLETED_names}

- Events.EPOCH_COMPLETED:
{EPOCH_COMPLETED}
Handlers names:
{EPOCH_COMPLETED_names}

- Events.COMPLETED:
{COMPLETED}
Handlers names:
{COMPLETED_names}

""".format(
            processing_stats=odict_to_str(results["processing_stats"]),
            dataflow_stats=odict_to_str(results["dataflow_stats"]),
            **others,
        )
        print(output_message)
        return output_message
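A minimal sketch of BasicTimeProfiler on a toy engine; write_results needs pandas installed, and the update function is a placeholder:

import time
from ignite.engine import Engine

def toy_update(engine, batch):
    time.sleep(0.001)  # stand-in for the real processing function

toy_trainer = Engine(toy_update)
profiler = BasicTimeProfiler()
profiler.attach(toy_trainer)
toy_trainer.run(range(20), max_epochs=2)

# one row per iteration: epoch, iteration, processing/dataflow and per-event handler times
profiler.write_results("basic_time_profiling.csv")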
Example #6
from ignite.engine import Events
from ignite.handlers import Timer


class ModelTimer(object):
    """Timer object can be used to measure average time.

    This timer class compute averaged time for:
        - "batch": time of each iteration
        - "data": time of getting batch data
        - "forward": time of model forwarding

    Exmaple:

        >>> from ignite.engine import Engine, Events
        >>> from ignite.handlers import Timer
        >>> trainer = Engine(function)
        >>> timer = ModelTimer()
        >>> timer.attach(trainer)

    Version:
        - ignite>=0.4.2
    """
    def __init__(self):
        self.data_timer = Timer(average=True)
        self.nn_timer = Timer(average=True)
        self.timer = Timer(average=True)

    def attach(self, engine):
        self.data_timer.attach(
            engine,
            start=Events.EPOCH_STARTED,
            resume=Events.GET_BATCH_STARTED,
            pause=Events.GET_BATCH_COMPLETED,
            step=Events.ITERATION_COMPLETED,
        )
        self.nn_timer.attach(
            engine,
            start=Events.EPOCH_STARTED,
            resume=Events.GET_BATCH_COMPLETED,
            pause=Events.GET_BATCH_STARTED,
            step=Events.ITERATION_COMPLETED,
        )
        self.timer.attach(
            engine,
            start=Events.EPOCH_STARTED,
            step=Events.ITERATION_COMPLETED,
        )

    def __str__(self):
        return "{:.3f} batch, {:.3f} data, {:.3f} forward".format(
            self.timer.value(), self.data_timer.value(), self.nn_timer.value())

    def value(self):
        """Return averaged time for each phase.

        Returns:
            dict: time
        """
        time = {
            "data": self.data_timer.value(),
            "batch": self.timer.value(),
            "forward": self.nn_timer.value()
        }
        return time
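A runnable sketch of ModelTimer on a toy engine; the step function and the range "data" are placeholders:

import time
from ignite.engine import Engine, Events

def step_fn(engine, batch):
    time.sleep(0.002)  # stands in for forward/backward

trainer = Engine(step_fn)
model_timer = ModelTimer()
model_timer.attach(trainer)

@trainer.on(Events.EPOCH_COMPLETED)
def log_timings(engine):
    print("epoch {}: {}".format(engine.state.epoch, model_timer))

trainer.run(range(50), max_epochs=2)
print(model_timer.value())  # {"data": ..., "batch": ..., "forward": ...}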
Example #7
def run(args):
    
    # distributed training variable
    args.world_size = 1
    
    if args.fp16:
        assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."
        
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        if args.arch in model_names:
            model = models.__dict__[args.arch](pretrained=True)
        elif args.arch in local_model_names:
            model = local_models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating new model '{}'".format(args.arch))
        if args.arch in model_names:
            model = models.__dict__[args.arch](pretrained=False)
        elif args.arch in local_model_names:
            model = local_models.__dict__[args.arch](pretrained=False)
            
    if args.arch == 'inception_v3':
        model.aux_logits=False

    device = 'cuda'
    model.to(device)
    
    ### data loaders
    train_loader, val_loader = get_data_loaders(args)
    
    ### optimizers
    if args.opt == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.opt == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), args.lr,
                                    weight_decay=args.weight_decay)

    if args.opt == 'adamw':
        optimizer = torch.optim.AdamW(model.parameters(), args.lr,
                                    weight_decay=args.weight_decay)

    if args.opt == 'radam':
        optimizer = local_optimisers.RAdam(model.parameters(), args.lr,
                                          betas=(0.9, 0.999), eps=1e-8,
                                           weight_decay=args.weight_decay)

    # will need my own train loop function to make the fp16 from apex work
    if args.fp16:
        model, optimizer = amp.initialize(model, optimizer,
                                      opt_level=args.opt_level,
                                      keep_batchnorm_fp32=args.keep_batchnorm_fp32,
                                      loss_scale=args.loss_scale
                                      )

    
    if args.wandb_logging:    
        wandb.watch(model)
    
    loss_fn = F.cross_entropy
    accuracy = Accuracy()
    
    scheduler = OneCycleLR(optimizer, num_steps=args.epochs, lr_range=(args.lr/10, args.lr))

    
    # dali loaders need extra class?
    if args.fp16:
        trainer = create_supervised_fp16_trainer(model, optimizer, loss_fn,
                                        device=device, non_blocking=False,
                                        prepare_batch=prepare_dali_batch)
        evaluator = create_supervised_fp16_evaluator(model,
                                            metrics={'cross_entropy': Loss(loss_fn),
                                                    'accuracy': accuracy},
                                            device=device, non_blocking=False ,
                                            prepare_batch=prepare_dali_batch)                                        
    else:
        trainer = create_supervised_trainer(model, optimizer, loss_fn,
                                        device=device, non_blocking=False,
                                        prepare_batch=prepare_dali_batch)

        evaluator = create_supervised_evaluator(model,
                                            metrics={'cross_entropy': Loss(loss_fn),
                                                    'accuracy': accuracy},
                                            device=device, non_blocking=False ,
                                            prepare_batch=prepare_dali_batch)
        
    timer = Timer(average=True)
    timer.attach(trainer,
                start=Events.EPOCH_STARTED,
                resume=Events.ITERATION_STARTED,
                pause=Events.ITERATION_COMPLETED,
                step=Events.ITERATION_COMPLETED)
    
    epoch_timer = Timer(average=True)
    epoch_timer.attach(trainer,
                       start=Events.EPOCH_STARTED,
                       resume=Events.EPOCH_STARTED,
                       pause=Events.EPOCH_COMPLETED,
                       step=Events.EPOCH_COMPLETED)

    def score_function(engine):
        # we want to stop training at 94%
        # early stopping just looks at metric no changing?
        # how to threshold this?
        accur = engine.state.metrics['accuracy']
        if accur > 0.94:
            res = True
        else:
            res = accur

        return res
    
    #es_handler = EarlyStopping(patience=5, score_function=score_function, trainer=trainer)
    # Note: the handler is attached to an *Evaluator* (runs one epoch on validation dataset).
    #evaluator.add_event_handler(Events.COMPLETED, es_handler)

    desc = "ITERATION - loss: {:.2f}"
    pbar = tqdm(
        initial=0, leave=False, total=len(train_loader),
        desc=desc.format(0)
    )

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        iter = (engine.state.iteration - 1) % len(train_loader) + 1

        if iter % args.log_interval == 0:
            pbar.desc = desc.format(engine.state.output)
            pbar.update(args.log_interval)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        scheduler.step()
        pbar.refresh()
        evaluator.run(train_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        avg_loss = metrics['cross_entropy']
        
        tqdm.write(
            "Training Results - Epoch: {} LR: {:.2f} Avg loss: {:.2f} Avg accuracy: {:.2f}".format(
                    engine.state.epoch, scheduler.get_lr(), avg_loss, avg_accuracy, )
                
        )
        
        print("Average Batch Time: {:.3f} Epoch Time: {:.3f} ".format(timer.value(), epoch_timer.value()))

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        avg_loss = metrics['cross_entropy']

        tqdm.write(
            "Validation Results - Epoch: {} Avg loss: {:.2f} Avg accuracy: {:.2f}".format(
                    engine.state.epoch, avg_loss, avg_accuracy, )
        )

        pbar.n = pbar.last_print_n = 0
        
        save_checkpoint({'epoch': engine.state.epoch + 1,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'val_loss': avg_loss,
                        'optimizer': optimizer.state_dict()},
                        False)

    trainer.run(train_loader, max_epochs=args.epochs)
    pbar.close()
    
    print("Average Batch Time: {:.3f} Epoch Time: {:.3f} ".format(timer.value(), epoch_timer.value()))
Example #8
class GANTrainer:
    def __init__(self,
                 G,
                 D,
                 criterionG,
                 criterionD,
                 optimizerG,
                 optimizerD,
                 lr_schedulerG=None,
                 lr_schedulerD=None,
                 metrics={},
                 device=None,
                 save_path=".",
                 name="Net"):

        self.G = G
        self.D = D
        self.criterionG = criterionG
        self.criterionD = criterionD
        self.optimizerG = optimizerG
        self.optimizerD = optimizerD
        self.lr_schedulerG = lr_schedulerG
        self.lr_schedulerD = lr_schedulerD
        self.metrics = metrics
        self.device = device or ('cuda'
                                 if torch.cuda.is_available() else 'cpu')
        self.save_path = save_path
        self.name = name

        self.metric_history = defaultdict(list)
        self._print_callbacks = set([lambda msg: print(msg, end='')])
        self._weixin_logined = False
        self._timer = Timer()
        self._epochs = 0

        self.G.to(self.device)
        self.D.to(self.device)
        # self._create_evaluator()

    def _create_engine(self):
        engine = create_gan_trainer(self.G, self.D, self.criterionG,
                                    self.criterionD, self.optimizerG,
                                    self.optimizerD, self.lr_schedulerG,
                                    self.lr_schedulerD, self.metrics,
                                    self.device)
        engine.add_event_handler(Events.EPOCH_STARTED, self._lr_scheduler_step)
        self._timer.attach(engine, start=Events.EPOCH_STARTED)
        return engine

    def _on(self, event_name, f, *args, **kwargs):
        cancel_event(self.engine, event_name, f)
        self.engine.add_event_handler(event_name, f, *args, **kwargs)

    def _fprint(self, msg):
        for f in self._print_callbacks:
            try:
                f(msg)
            except Exception as e:
                pass

    def _enable_send_weixin(self):
        if self._weixin_logined:
            self._print_callbacks.add(send_weixin)
        else:
            print("Weixin is not logged in.")

    def _disable_send_weixin(self):
        self._print_callbacks.discard(send_weixin)

    def _lr_scheduler_step(self, engine):
        if self.lr_schedulerG:
            self.lr_schedulerG.step()
        if self.lr_schedulerD:
            self.lr_schedulerD.step()

    def _log_epochs(self, engine, epochs):
        self._epochs += 1
        self._fprint(
            "Epoch %d/%d\n" %
            (self._epochs, self._epochs + epochs - engine.state.epoch))

    def _evaluate(self, engine, val_loader):
        self.evaluator.run(val_loader)

    def _log_results(self, engine, validate):
        elapsed = int(self._timer.value())
        msg = ""
        msg += "elapsed: %ds\t" % elapsed
        for name, val in engine.state.metrics.items():
            msg += "%s: %.4f\t" % (name, val)
            self.metric_history[name].append(val)
        msg += '\n'
        if validate:
            msg += "validate ------\t"
            for name, val in self.evaluator.state.metrics.items():
                msg += "%s: %.4f\t" % (name, val)
                self.metric_history["val_" + name].append(val)
            msg += "\n"
        self._fprint(msg)

    def fit(self,
            train_loader,
            epochs,
            val_loader=None,
            send_weixin=False,
            save_per_epochs=None,
            callbacks=[]):
        validate = val_loader is not None
        # Weixin
        if send_weixin:
            self._enable_send_weixin()

        # Create engine
        engine = self._create_engine()

        # Register events
        engine.add_event_handler(Events.EPOCH_STARTED, self._log_epochs,
                                 epochs)

        if validate:
            engine.add_event_handler(Events.EPOCH_COMPLETED, self._evaluate,
                                     val_loader)
        engine.add_event_handler(Events.EPOCH_COMPLETED, self._log_results,
                                 validate)

        # Set checkpoint
        if save_per_epochs:
            checkpoint_handler = ModelCheckpoint(self.save_path,
                                                 self.name,
                                                 save_per_epochs,
                                                 save_as_state_dict=True,
                                                 require_empty=False)
            checkpoint_handler._iteration = self.epochs()
            engine.add_event_handler(Events.EPOCH_COMPLETED,
                                     checkpoint_handler, {"trainer": self})

        for callback in callbacks:
            engine.add_event_handler(Events.EPOCH_COMPLETED,
                                     _callback_wrapper(callback), self)

        # Run
        engine.run(train_loader, epochs)

        # Destroy
        self._disable_send_weixin()

        # Return history
        hist = {
            metric: hist[-epochs:]
            for metric, hist in self.metric_history.items()
        }
        if not validate:
            hist = keyfilter(lambda k: not k.startswith("val_"), hist)
        return hist

    def state_dict(self):
        s = {
            "epochs": self.epochs(),
            "models": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "criterion": self.criterion.state_dict(),
            "lr_scheduler": None,
            "metric_history": self.metric_history,
        }
        if self.lr_scheduler:
            s["lr_scheduler"] = self.lr_scheduler.state_dict()
        return s

    def load_state_dict(self, state_dict):
        epochs, model, optimizer, criterion, lr_scheduler, metric_history = get(
            [
                "epochs", "models", "optimizer", "criterion", "lr_scheduler",
                "metric_history"
            ], state_dict)
        self._epochs = epochs
        self.model.load_state_dict(model)
        self.optimizer.load_state_dict(optimizer)
        self.criterion.load_state_dict(criterion)
        if self.lr_scheduler and lr_scheduler:
            self.lr_scheduler.load_state_dict(lr_scheduler)
        self.metric_history = metric_history

    def epochs(self):
        return self._epochs

    def login_weixin(self, save_path='.'):
        import itchat
        itchat.logout()
        itchat.auto_login(hotReload=True,
                          enableCmdQR=2,
                          statusStorageDir=os.path.join(
                              save_path, 'weixin.pkl'))
        self._weixin_logined = True

    def logout_weixin(self):
        import itchat
        itchat.logout()
        self._weixin_logined = False

    def set_lr(self, lr):
        set_lr(lr, self.optimizer, self.lr_scheduler)
class Trainer:
    def __init__(self,
                 model,
                 criterion,
                 optimizer,
                 lr_scheduler=None,
                 metrics=None,
                 test_metrics=None,
                 save_path=".",
                 name="Net"):

        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.metrics = metrics or {}
        self.test_metrics = test_metrics
        if test_metrics is None:
            self.test_metrics = self.metrics.copy()
            if 'loss' in self.metrics and isinstance(self.metrics['loss'], TrainLoss):
                self.test_metrics['loss'] = Loss(criterion=criterion)
        self.save_path = os.path.join(save_path, 'trainer')
        self.name = name

        current_time = datetime.now().strftime('%b%d_%H-%M-%S')
        log_dir = os.path.join(save_path, 'runs', self.name, current_time)
        self.writer = SummaryWriter(log_dir)

        self.metric_history = defaultdict(list)
        self.device = 'cuda' if CUDA else 'cpu'
        self._timer = Timer()
        self._epochs = 0

        self.model.to(self.device)

    def _log_epochs(self, engine, epochs):
        print(
            "Epoch %d/%d" %
            (self._epochs + 1, self._epochs + 1 + epochs - engine.state.epoch))

    def _lr_scheduler_step(self, engine):
        data_loader = engine.state.dataloader
        iteration = engine.state.iteration - 1
        iters_per_epoch = len(data_loader)
        cur_iter = iteration % iters_per_epoch
        if self.lr_scheduler:
            self.lr_scheduler.step(self.epochs() +
                                   (cur_iter / iters_per_epoch))

    def _attach_timer(self, engine):
        self._timer.attach(engine, start=Events.EPOCH_STARTED)

    def _increment_epoch(self, engine):
        self._epochs += 1

    def _log_results(self, engine):
        elapsed = int(self._timer.value())
        msg = "elapsed: %ds\t" % elapsed
        for name, val in engine.state.metrics.items():
            if isinstance(val, float):
                msg += "%s: %.4f\t" % (name, val)
                self.writer.add_scalar(name, val, self.epochs())
            else:
                msg += "%s: %s\t" % (name, val)
                for i, v in enumerate(val):
                    self.writer.add_scalar("%s-%d" % (name, i + 1), v,
                                           self.epochs())
            self.metric_history[name].append(val)
        print(msg)

    def _log_val_results(self, engine, evaluator, per_epochs=1):
        if engine.state.epoch % per_epochs != 0:
            return
        msg = "validate ------\t"
        for name, val in evaluator.state.metrics.items():
            if isinstance(val, float):
                msg += "%s: %.4f\t" % (name, val)
                self.writer.add_scalar(name, val, self.epochs())
            else:
                msg += "%s: %s\t" % (name, val)
                for i, v in enumerate(val):
                    self.writer.add_scalar("%s-%d" % (name, i + 1), v,
                                           self.epochs())
            self.metric_history["val_" + name].append(val)
        print(msg)

    def fit(self,
            train_loader,
            epochs=1,
            val_loader=None,
            save=None,
            iterations=None,
            callbacks=()):

        engine = create_supervised_trainer(self.model, self.criterion,
                                           self.optimizer, self.metrics,
                                           self.device)
        self._attach_timer(engine)

        engine.add_event_handler(Events.ITERATION_STARTED,
                                 self._lr_scheduler_step)

        engine.add_event_handler(Events.EPOCH_STARTED, self._log_epochs,
                                 epochs)

        if val_loader is not None:
            if isinstance(val_loader, tuple):
                val_loader, eval_per_epochs = val_loader
            else:
                eval_per_epochs = 1
            evaluator = create_supervised_evaluator(self.model,
                                                    self.test_metrics,
                                                    self.device)
            engine.add_event_handler(Events.EPOCH_COMPLETED, _evaluate,
                                     evaluator, val_loader, eval_per_epochs)

        engine.add_event_handler(Events.EPOCH_COMPLETED, self._increment_epoch)
        engine.add_event_handler(Events.EPOCH_COMPLETED, self._log_results)
        if val_loader is not None:
            engine.add_event_handler(Events.EPOCH_COMPLETED,
                                     self._log_val_results, evaluator,
                                     eval_per_epochs)

        # Set checkpoint
        if save:
            checkpoint_handler = save.parse(self)
            engine.add_event_handler(Events.EPOCH_COMPLETED,
                                     checkpoint_handler, {"trainer": self})

        for callback in callbacks:
            engine.add_event_handler(Events.EPOCH_COMPLETED, wrap(callback),
                                     self)

        if iterations:
            engine.add_event_handler(Events.ITERATION_COMPLETED,
                                     _terminate_on_iterations, iterations)
            epochs = 1000

        # Run
        engine.run(train_loader, epochs)

        # Return history
        hist = {
            metric: hist[-epochs:]
            for metric, hist in self.metric_history.items()
        }
        if val_loader is None:
            hist = keyfilter(lambda k: not k.startswith("val_"), hist)
        return hist

    def fit1(self,
             train_loader,
             epochs,
             validate=None,
             save=None,
             callbacks=()):
        validate = ValSet.parse(validate, self)

        engine = create_supervised_trainer(self.model, self.criterion,
                                           self.optimizer, self.metrics,
                                           self.device)
        self._attach_timer(engine)

        engine.add_event_handler(Events.EPOCH_STARTED, self._log_epochs,
                                 epochs)
        engine.add_event_handler(Events.ITERATION_STARTED,
                                 self._lr_scheduler_step)

        engine.add_event_handler(Events.EPOCH_COMPLETED, self._increment_epoch)
        engine.add_event_handler(Events.EPOCH_COMPLETED, self._log_results)

        engine.add_event_handler(Events.EPOCH_COMPLETED,
                                 wrap(validate.evaluate))
        engine.add_event_handler(Events.EPOCH_COMPLETED,
                                 wrap(validate.log_results))

        # Set checkpoint
        if save:
            checkpoint_handler = save.parse(self)
            engine.add_event_handler(Events.EPOCH_COMPLETED,
                                     checkpoint_handler, {"trainer": self})

        for callback in callbacks:
            engine.add_event_handler(Events.EPOCH_COMPLETED, wrap(callback),
                                     self)

        engine.run(train_loader, epochs)

        # Return history
        hist = {
            metric: hist[-epochs:]
            for metric, hist in self.metric_history.items()
        }
        return hist

    def fit2(self, train_loader, epochs=1, save=None, callbacks=()):

        engine = create_supervised_trainer(self.model, self.criterion,
                                           self.optimizer, self.metrics,
                                           self.device)

        engine.add_event_handler(Events.ITERATION_STARTED,
                                 self._lr_scheduler_step)

        self._timer.attach(engine, start=Events.EPOCH_STARTED)
        engine.add_event_handler(Events.EPOCH_STARTED, self._log_epochs,
                                 epochs)

        engine.add_event_handler(Events.EPOCH_COMPLETED, self._increment_epoch)
        engine.add_event_handler(Events.EPOCH_COMPLETED, self._log_results)

        # Set checkpoint
        if save:
            checkpoint_handler = save.parse(self)
            engine.add_event_handler(Events.EPOCH_COMPLETED,
                                     checkpoint_handler, {"trainer": self})

        for callback in callbacks:
            engine.add_event_handler(Events.EPOCH_COMPLETED, wrap(callback),
                                     self)

        # Run
        engine.run(train_loader, epochs)

        # Return history
        hist = {
            metric: hist[-epochs:]
            for metric, hist in self.metric_history.items()
        }
        hist = keyfilter(lambda k: not k.startswith("val_"), hist)
        return hist

    def state_dict(self):
        s = {
            "epochs": self.epochs(),
            "model": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "lr_scheduler": None,
            "metric_history": self.metric_history,
        }
        if self.lr_scheduler:
            s["lr_scheduler"] = self.lr_scheduler.state_dict()
        return s

    def load_state_dict(self, state_dict):
        epochs, model, optimizer, lr_scheduler, metric_history = get(
            ["epochs", "model", "optimizer", "lr_scheduler", "metric_history"],
            state_dict)
        self._epochs = epochs
        self.model.load_state_dict(model)
        self.optimizer.load_state_dict(optimizer)
        if self.lr_scheduler and lr_scheduler:
            self.lr_scheduler.load_state_dict(lr_scheduler)
        self.metric_history = metric_history

    def save(self):
        d = Path(self.save_path)
        d.mkdir(parents=True, exist_ok=True)
        filename = "%s_trainer_%d.pth" % (self.name, self.epochs())
        fp = d / filename
        torch.save(self.state_dict(), fp)
        print("Save trainer as %s" % fp)

    def load(self):
        d = Path(self.save_path)
        pattern = "%s_trainer*.pth" % self.name
        saves = list(d.glob(pattern))
        if len(saves) == 0:
            raise FileNotFoundError("No checkpoint to load for %s in %s" %
                                    (self.name, self.save_path))
        fp = max(saves, key=lambda f: f.stat().st_mtime)
        self.load_state_dict(torch.load(fp, map_location=self.device))
        print("Load trainer from %s" % fp)

    def epochs(self):
        return self._epochs

    def evaluate(self, test_loader, evaluate_metrics=None):
        if evaluate_metrics is None:
            evaluate_metrics = self.test_metrics
        evaluator = create_supervised_evaluator(self.model, evaluate_metrics,
                                                self.device)
        return evaluator.run(test_loader).metrics

    def set_lr(self, lr):
        set_lr(lr, self.optimizer, self.lr_scheduler)
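
The save()/load() pair above follows a simple convention: checkpoints are written as "<name>_trainer_<epoch>.pth" and load() restores the most recently modified match. A small standalone sketch of that selection rule; the directory and name in the usage comment are illustrative only.

from pathlib import Path

def latest_checkpoint(save_path, name):
    """Return the newest '<name>_trainer*.pth' file under save_path (illustrative helper)."""
    saves = list(Path(save_path).glob("%s_trainer*.pth" % name))
    if not saves:
        raise FileNotFoundError("No checkpoint to load for %s in %s" % (name, save_path))
    # The most recently modified file wins, mirroring load() above.
    return max(saves, key=lambda f: f.stat().st_mtime)

# e.g. latest_checkpoint("./checkpoints", "demo") -> newest "./checkpoints/demo_trainer_*.pth"
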
Example #10
class GANTrainer:
    def __init__(self,
                 G,
                 D,
                 criterionG,
                 criterionD,
                 optimizerG,
                 optimizerD,
                 lr_schedulerG=None,
                 lr_schedulerD=None,
                 make_latent=None,
                 metrics=None,
                 save_path=".",
                 name="GAN",
                 gan_type='gan'):

        self.G = G
        self.D = D
        self.criterionG = criterionG
        self.criterionD = criterionD
        self.optimizerG = optimizerG
        self.optimizerD = optimizerD
        self.lr_schedulerG = lr_schedulerG
        self.lr_schedulerD = lr_schedulerD
        self.make_latent = make_latent
        self.metrics = metrics or {}
        self.name = name
        root = Path(save_path).expanduser().absolute()
        self.save_path = root / 'gan_trainer' / self.name

        self.metric_history = defaultdict(list)
        self.device = 'cuda' if CUDA else 'cpu'
        self._timer = Timer()
        self._iterations = 0

        self.G.to(self.device)
        self.D.to(self.device)

        assert gan_type in ['gan', 'acgan', 'cgan', 'infogan']
        if gan_type == 'gan':
            self.create_fn = create_gan_trainer
        elif gan_type == 'acgan':
            self.create_fn = create_acgan_trainer
        elif gan_type == 'cgan':
            self.create_fn = create_cgan_trainer
        elif gan_type == 'infogan':
            self.create_fn = create_infogan_trainer

    def _lr_scheduler_step(self, engine):
        if self.lr_schedulerG:
            self.lr_schedulerG.step()
        if self.lr_schedulerD:
            self.lr_schedulerD.step()

    def _attach_timer(self, engine):
        self._trained_time = 0
        self._timer.reset()

    def _increment_iteration(self, engine):
        self._iterations += 1

    def _log_results(self, engine, log_interval, max_iter):
        if self._iterations % log_interval != 0:
            return
        i = engine.state.iteration
        elapsed = self._timer.value()
        self._timer.reset()
        self._trained_time += elapsed
        trained_time = self._trained_time
        eta_seconds = int((trained_time / i) * (max_iter - i))
        it_fmt = "%" + str(len(str(max_iter))) + "d"
        print(("Iter: " + it_fmt + ", Cost: %.2fs, Eta: %s") %
              (self._iterations, elapsed,
               datetime.timedelta(seconds=eta_seconds)))

        for name, metric in self.metrics.items():
            metric.completed(engine, name)
            metric.reset()

        msg = ""
        for name, val in engine.state.metrics.items():
            msg += "%s: %.4f\t" % (name, val)
            self.metric_history[name].append(val)
        print(msg)

    def fit(self, it, max_iter, log_interval=100, callbacks=()):

        engine = self.create_fn(self.G, self.D, self.criterionG,
                                self.criterionD, self.optimizerG,
                                self.optimizerD, self.make_latent,
                                self.metrics, self.device)
        self._attach_timer(engine)

        engine.add_event_handler(Events.ITERATION_STARTED,
                                 self._lr_scheduler_step)

        engine.add_event_handler(Events.ITERATION_COMPLETED,
                                 self._increment_iteration)
        engine.add_event_handler(Events.ITERATION_COMPLETED, self._log_results,
                                 log_interval, max_iter)

        for callback in callbacks:
            engine.add_event_handler(Events.ITERATION_COMPLETED,
                                     _trainer_callback_wrap(callback), self)

        # Run
        engine.run(it, max_iter)

        # Return history
        return self.metric_history

    def state_dict(self):
        s = {
            "iterations": self.iterations(),
            "G": self.G.state_dict(),
            "D": self.D.state_dict(),
            "optimizerG": self.optimizerG.state_dict(),
            "optimizerD": self.optimizerD.state_dict(),
            "criterionG": self.criterionG.state_dict(),
            "criterionD": self.criterionD.state_dict(),
            "lr_schedulerG": None,
            "lr_schedulerD": None,
            "metric_history": self.metric_history,
        }
        if self.lr_schedulerG:
            s["lr_schedulerG"] = self.lr_schedulerG.state_dict()
        if self.lr_schedulerD:
            s["lr_schedulerD"] = self.lr_schedulerD.state_dict()
        return s

    def load_state_dict(self, state_dict):
        iterations, G, D, optimizerG, optimizerD, criterionG, criterionD, lr_schedulerG, lr_schedulerD, metric_history = get(
            [
                "iterations", "G", "D", "optimizerG", "optimizerD",
                "criterionG", "criterionD", "lr_schedulerG", "lr_schedulerD",
                "metric_history"
            ], state_dict)
        self._iterations = iterations
        self.G.load_state_dict(G)
        self.D.load_state_dict(D)
        self.optimizerG.load_state_dict(optimizerG)
        self.optimizerD.load_state_dict(optimizerD)
        self.criterionG.load_state_dict(criterionG)
        self.criterionD.load_state_dict(criterionD)
        if self.lr_schedulerG and lr_schedulerG:
            self.lr_schedulerG.load_state_dict(lr_schedulerG)
        if self.lr_schedulerD and lr_schedulerD:
            self.lr_schedulerD.load_state_dict(lr_schedulerD)
        self.metric_history = metric_history

    def save(self, remove_prev=True):
        d = self.save_path
        d.mkdir(parents=True, exist_ok=True)

        if remove_prev:
            pattern = "%s_trainer*.pth" % self.name
            saves = list(d.glob(pattern))
            if len(saves) != 0:
                fp = max(saves, key=lambda f: f.stat().st_mtime)
                p = "%s_trainer_(?P<iters>[0-9]+).pth" % self.name
                iters = int(re.match(p, fp.name).group('iters'))
                if self.iterations() > iters:
                    fp.unlink()

        filename = "%s_trainer_%d.pth" % (self.name, self.iterations())
        fp = d / filename
        torch.save(self.state_dict(), fp)
        print("Save trainer as %s" % fp)

    def load(self):
        d = self.save_path
        pattern = "%s_trainer*.pth" % self.name
        saves = list(d.glob(pattern))
        if len(saves) == 0:
            raise FileNotFoundError("No checkpoint to load for %s in %s" %
                                    (self.name, self.save_path))
        fp = max(saves, key=lambda f: f.stat().st_mtime)
        self.load_state_dict(torch.load(fp, map_location=self.device))
        print("Load trainer from %s" % fp)

    def iterations(self):
        return self._iterations
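
The ETA printed by _log_results above is the average time per iteration so far multiplied by the number of remaining iterations. A self-contained sketch of that arithmetic; the numbers in the usage comment are illustrative only.

import datetime

def eta(trained_time, current_iter, max_iter):
    """Estimated remaining time, given the total time spent and the current iteration."""
    seconds = int((trained_time / current_iter) * (max_iter - current_iter))
    return datetime.timedelta(seconds=seconds)

# e.g. eta(trained_time=120.0, current_iter=400, max_iter=1000) -> 0:03:00
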
Example #11
def train_model(
    name="",
    resume="",
    base_dir=utils.BASE_DIR,
    model_name="v0",
    chosen_diseases=None,
    n_epochs=10,
    batch_size=4,
    oversample=False,
    max_os=None,
    shuffle=False,
    opt="sgd",
    opt_params={},
    loss_name="wbce",
    loss_params={},
    train_resnet=False,
    log_metrics=None,
    flush_secs=120,
    train_max_images=None,
    val_max_images=None,
    test_max_images=None,
    experiment_mode="debug",
    save=True,
    save_cms=True,  # Note that in this case, save_cms (to disk) includes write_cms (to TB)
    write_graph=False,
    write_emb=False,
    write_emb_img=False,
    write_img=False,
    image_format="RGB",
    multiple_gpu=False,
):

    # Choose GPU
    device = utilsT.get_torch_device()
    print("Using device: ", device)

    # Common folders
    dataset_dir = os.path.join(base_dir, "dataset")

    # Dataset handling
    print("Loading train dataset...")
    train_dataset, train_dataloader = utilsT.prepare_data(
        dataset_dir,
        "train",
        chosen_diseases,
        batch_size,
        oversample=oversample,
        max_os=max_os,
        shuffle=shuffle,
        max_images=train_max_images,
        image_format=image_format,
    )
    train_samples, _ = train_dataset.size()

    print("Loading val dataset...")
    val_dataset, val_dataloader = utilsT.prepare_data(
        dataset_dir,
        "val",
        chosen_diseases,
        batch_size,
        max_images=val_max_images,
        image_format=image_format,
    )
    val_samples, _ = val_dataset.size()

    # Should be the same as chosen_diseases
    chosen_diseases = list(train_dataset.classes)
    print("Chosen diseases: ", chosen_diseases)

    if resume:
        # Load model and optimizer
        model, model_name, optimizer, opt, loss_name, loss_params, chosen_diseases = models.load_model(
            base_dir, resume, experiment_mode="", device=device)
        model.train(True)
    else:
        # Create model
        model = models.init_empty_model(model_name,
                                        chosen_diseases,
                                        train_resnet=train_resnet).to(device)

        # Create optimizer
        OptClass = optimizers.get_optimizer_class(opt)
        optimizer = OptClass(model.parameters(), **opt_params)
        # print("OPT: ", opt_params)

    # Allow multiple GPUs
    if multiple_gpu:
        model = DataParallel(model)

    # Tensorboard log options
    run_name = utils.get_timestamp()
    if name:
        run_name += "_{}".format(name)

    if len(chosen_diseases) == 1:
        run_name += "_{}".format(chosen_diseases[0])
    elif len(chosen_diseases) == 14:
        run_name += "_all"

    log_dir = get_log_dir(base_dir, run_name, experiment_mode=experiment_mode)

    print("Run name: ", run_name)
    print("Saved TB in: ", log_dir)

    writer = SummaryWriter(log_dir=log_dir, flush_secs=flush_secs)

    # Create validator engine
    validator = Engine(
        utilsT.get_step_fn(model, optimizer, device, loss_name, loss_params,
                           False))

    val_loss = RunningAverage(output_transform=lambda x: x[0], alpha=1)
    val_loss.attach(validator, loss_name)

    utilsT.attach_metrics(validator, chosen_diseases, "prec", Precision, True)
    utilsT.attach_metrics(validator, chosen_diseases, "recall", Recall, True)
    utilsT.attach_metrics(validator, chosen_diseases, "acc", Accuracy, True)
    utilsT.attach_metrics(validator, chosen_diseases, "roc_auc",
                          utilsT.RocAucMetric, False)
    utilsT.attach_metrics(validator,
                          chosen_diseases,
                          "cm",
                          ConfusionMatrix,
                          get_transform_fn=utilsT.get_transform_cm,
                          metric_args=(2, ))
    utilsT.attach_metrics(validator,
                          chosen_diseases,
                          "positives",
                          RunningAverage,
                          get_transform_fn=utilsT.get_count_positives)

    # Create trainer engine
    trainer = Engine(
        utilsT.get_step_fn(model, optimizer, device, loss_name, loss_params,
                           True))

    train_loss = RunningAverage(output_transform=lambda x: x[0], alpha=1)
    train_loss.attach(trainer, loss_name)

    utilsT.attach_metrics(trainer, chosen_diseases, "acc", Accuracy, True)
    utilsT.attach_metrics(trainer, chosen_diseases, "prec", Precision, True)
    utilsT.attach_metrics(trainer, chosen_diseases, "recall", Recall, True)
    utilsT.attach_metrics(trainer, chosen_diseases, "roc_auc",
                          utilsT.RocAucMetric, False)
    utilsT.attach_metrics(trainer,
                          chosen_diseases,
                          "cm",
                          ConfusionMatrix,
                          get_transform_fn=utilsT.get_transform_cm,
                          metric_args=(2, ))
    utilsT.attach_metrics(trainer,
                          chosen_diseases,
                          "positives",
                          RunningAverage,
                          get_transform_fn=utilsT.get_count_positives)

    timer = Timer(average=True)
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 step=Events.EPOCH_COMPLETED)

    # TODO: Early stopping
    #     def score_function(engine):
    #         val_loss = engine.state.metrics[loss_name]
    #         return -val_loss

    #     handler = EarlyStopping(patience=10, score_function=score_function, trainer=trainer)
    #     validator.add_event_handler(Events.COMPLETED, handler)

    # Metrics callbacks
    if log_metrics is None:
        log_metrics = list(ALL_METRICS)

    def _write_metrics(run_type, metrics, epoch, wall_time):
        loss = metrics.get(loss_name, 0)

        writer.add_scalar("Loss/" + run_type, loss, epoch, wall_time)

        for metric_base_name in log_metrics:
            for disease in chosen_diseases:
                metric_value = metrics.get(
                    "{}_{}".format(metric_base_name, disease), -1)
                writer.add_scalar(
                    "{}_{}/{}".format(metric_base_name, disease, run_type),
                    metric_value, epoch, wall_time)

    @trainer.on(Events.EPOCH_COMPLETED)
    def tb_write_metrics(trainer):
        epoch = trainer.state.epoch
        max_epochs = trainer.state.max_epochs

        # Run on evaluation
        validator.run(val_dataloader, 1)

        # Common time
        wall_time = time.time()

        # Log all metrics to TB
        _write_metrics("train", trainer.state.metrics, epoch, wall_time)
        _write_metrics("val", validator.state.metrics, epoch, wall_time)

        train_loss = trainer.state.metrics.get(loss_name, 0)
        val_loss = validator.state.metrics.get(loss_name, 0)

        tb_write_histogram(writer, model, epoch, wall_time)

        print("Finished epoch {}/{}, loss {:.3f}, val loss {:.3f} (took {})".
              format(epoch, max_epochs, train_loss, val_loss,
                     utils.duration_to_str(int(timer._elapsed()))))

    # Hparam dict
    hparam_dict = {
        "resume": resume,
        "n_diseases": len(chosen_diseases),
        "diseases": ",".join(chosen_diseases),
        "n_epochs": n_epochs,
        "batch_size": batch_size,
        "shuffle": shuffle,
        "model_name": model_name,
        "opt": opt,
        "loss": loss_name,
        "samples (train, val)": "{},{}".format(train_samples, val_samples),
        "train_resnet": train_resnet,
        "multiple_gpu": multiple_gpu,
    }

    def copy_params(params_dict, base_name):
        for name, value in params_dict.items():
            hparam_dict["{}_{}".format(base_name, name)] = value

    copy_params(loss_params, "loss")
    copy_params(opt_params, "opt")
    print("HPARAM: ", hparam_dict)

    # Train
    print("-" * 50)
    print("Training...")
    trainer.run(train_dataloader, n_epochs)

    # Capture time
    secs_per_epoch = timer.value()
    duration_per_epoch = utils.duration_to_str(int(secs_per_epoch))
    print("Average time per epoch: ", duration_per_epoch)
    print("-" * 50)

    ## Write all hparams
    hparam_dict["duration_per_epoch"] = duration_per_epoch

    # FIXME: this is commented to avoid having too many hparams in TB frontend
    # metrics
    #     def copy_metrics(engine, engine_name):
    #         for metric_name, metric_value in engine.state.metrics.items():
    #             hparam_dict["{}_{}".format(engine_name, metric_name)] = metric_value
    #     copy_metrics(trainer, "train")
    #     copy_metrics(validator, "val")

    print("Writing TB hparams")
    writer.add_hparams(hparam_dict, {})

    # Save model to disk
    if save:
        print("Saving model...")
        models.save_model(base_dir, run_name, model_name, experiment_mode,
                          hparam_dict, trainer, model, optimizer)

    # Write graph to TB
    if write_graph:
        print("Writing TB graph...")
        tb_write_graph(writer, model, train_dataloader, device)

    # Write embeddings to TB
    if write_emb:
        print("Writing TB embeddings...")
        image_size = 256 if write_emb_img else 0

        # FIXME: be able to select images (balanced, train vs val, etc)
        image_list = list(train_dataset.label_index["FileName"])[:1000]
        # disease = chosen_diseases[0]
        # positive = train_dataset.label_index[train_dataset.label_index[disease] == 1]
        # negative = train_dataset.label_index[train_dataset.label_index[disease] == 0]
        # positive_images = list(positive["FileName"])[:25]
        # negative_images = list(negative["FileName"])[:25]
        # image_list = positive_images + negative_images

        all_images, all_embeddings, all_predictions, all_ground_truths = gen_embeddings(
            model,
            train_dataset,
            device,
            image_list=image_list,
            image_size=image_size)
        tb_write_embeddings(
            writer,
            chosen_diseases,
            all_images,
            all_embeddings,
            all_predictions,
            all_ground_truths,
            global_step=n_epochs,
            use_images=write_emb_img,
            tag="1000_{}".format("img" if write_emb_img else "no_img"),
        )

    # Save confusion matrices (they are expensive to calculate afterwards)
    if save_cms:
        print("Saving confusion matrices...")
        # Ensure the folder exists
        cms_dir = os.path.join(base_dir, "cms", experiment_mode)
        os.makedirs(cms_dir, exist_ok=True)
        base_fname = os.path.join(cms_dir, run_name)

        n_diseases = len(chosen_diseases)

        def extract_cms(metrics):
            """Extract confusion matrices from a metrics dict."""
            cms = []
            for disease in chosen_diseases:
                key = "cm_" + disease
                if key not in metrics:
                    cm = np.array([[-1, -1], [-1, -1]])
                else:
                    cm = metrics[key].numpy()

                cms.append(cm)
            return np.array(cms)

        # Train confusion matrix
        train_cms = extract_cms(trainer.state.metrics)
        np.save(base_fname + "_train", train_cms)
        tb_write_cms(writer, "train", chosen_diseases, train_cms)

        # Validation confusion matrix
        val_cms = extract_cms(validator.state.metrics)
        np.save(base_fname + "_val", val_cms)
        tb_write_cms(writer, "val", chosen_diseases, val_cms)

        # Combined confusion matrices (train + val)
        all_cms = train_cms + val_cms
        np.save(base_fname + "_all", all_cms)

        # Print to console
        if len(chosen_diseases) == 1:
            print("Train CM: ")
            print(train_cms[0])
            print("Val CM: ")
            print(val_cms[0])


#             print("Train CM 2: ")
#             print(trainer.state.metrics["cm_" + chosen_diseases[0]])
#             print("Val CM 2: ")
#             print(validator.state.metrics["cm_" + chosen_diseases[0]])

    if write_img:
        # NOTE: this option is not recommended; use the Testing notebook to plot and analyze images

        print("Writing images to TB...")

        test_dataset, test_dataloader = utilsT.prepare_data(
            dataset_dir,
            "test",
            chosen_diseases,
            batch_size,
            max_images=test_max_images,
        )

        # TODO: add a way to select images?
        # image_list = list(test_dataset.label_index["FileName"])[:3]

        # Examples in test_dataset (with bboxes available):
        image_list = [
            # "00010277_000.png", # (Effusion, Infiltrate, Mass, Pneumonia)
            # "00018427_004.png", # (Atelectasis, Effusion, Mass)
            # "00021703_001.png", # (Atelectasis, Effusion, Infiltrate)
            # "00028640_008.png", # (Effusion, Infiltrate)
            # "00019124_104.png", # (Pneumothorax)
            # "00019124_090.png", # (Nodule)
            # "00020318_007.png", # (Pneumothorax)
            "00000003_000.png",  # (0)
            # "00000003_001.png", # (0)
            # "00000003_002.png", # (0)
            "00000732_005.png",  # (Cardiomegaly, Pneumothorax)
            # "00012261_001.png", # (Cardiomegaly, Pneumonia)
            # "00013249_033.png", # (Cardiomegaly, Pneumonia)
            # "00029808_003.png", # (Cardiomegaly, Pneumonia)
            # "00022215_012.png", # (Cardiomegaly, Pneumonia)
            # "00011402_007.png", # (Cardiomegaly, Pneumonia)
            # "00019018_007.png", # (Cardiomegaly, Infiltrate)
            # "00021009_001.png", # (Cardiomegaly, Infiltrate)
            # "00013670_151.png", # (Cardiomegaly, Infiltrate)
            # "00005066_030.png", # (Cardiomegaly, Infiltrate, Effusion)
            "00012288_000.png",  # (Cardiomegaly)
            "00008399_007.png",  # (Cardiomegaly)
            "00005532_000.png",  # (Cardiomegaly)
            "00005532_014.png",  # (Cardiomegaly)
            "00005532_016.png",  # (Cardiomegaly)
            "00005827_000.png",  # (Cardiomegaly)
            # "00006912_007.png", # (Cardiomegaly)
            # "00007037_000.png", # (Cardiomegaly)
            # "00007043_000.png", # (Cardiomegaly)
            # "00012741_004.png", # (Cardiomegaly)
            # "00007551_020.png", # (Cardiomegaly)
            # "00007735_040.png", # (Cardiomegaly)
            # "00008339_010.png", # (Cardiomegaly)
            # "00008365_000.png", # (Cardiomegaly)
            # "00012686_003.png", # (Cardiomegaly)
        ]

        tb_write_images(writer, model, test_dataset, chosen_diseases, n_epochs,
                        device, image_list)

    # Close TB writer
    if experiment_mode != "debug":
        writer.close()

    # Run post_train
    print("-" * 50)
    print("Running post_train...")

    print("Loading test dataset...")
    test_dataset, test_dataloader = utilsT.prepare_data(
        dataset_dir,
        "test",
        chosen_diseases,
        batch_size,
        max_images=test_max_images)

    save_cms_with_names(run_name, experiment_mode, model, test_dataset,
                        test_dataloader, chosen_diseases)

    evaluate_model(run_name,
                   model,
                   optimizer,
                   device,
                   loss_name,
                   loss_params,
                   chosen_diseases,
                   test_dataloader,
                   experiment_mode=experiment_mode,
                   base_dir=base_dir)

    # Return values for debugging
    model_run = ModelRun(model, run_name, model_name, chosen_diseases)
    if experiment_mode == "debug":
        model_run.save_debug_data(writer, trainer, validator, train_dataset,
                                  train_dataloader, val_dataset,
                                  val_dataloader)

    return model_run
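
train_model above reports an average time per epoch from a Timer stepped on EPOCH_COMPLETED. A minimal, self-contained sketch of that averaging pattern, using Timer.attach's default start/pause events and a dummy step function that only sleeps:

import time
from ignite.engine import Engine, Events
from ignite.handlers import Timer

def _step(engine, batch):
    time.sleep(0.01)  # stand-in for a real training iteration

trainer = Engine(_step)

# Stepping once per epoch makes value() the elapsed time divided by the number
# of completed epochs, i.e. the average seconds per epoch.
epoch_timer = Timer(average=True)
epoch_timer.attach(trainer, step=Events.EPOCH_COMPLETED)

trainer.run(range(5), max_epochs=3)
print("average seconds per epoch: %.3f" % epoch_timer.value())
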
Example #12
class Trainer:
    """
    Class which sets up the training logic, which mainly involves defining callback handlers and attaching them
    to the training loop.
    """
    def __init__(self, model, config, evaluator, data_loader, tb_writer,
                 run_info, logger, checkpoint_dir):
        """
        Creates a new trainer object for training a model.
        :param model: model to train. Needs to inherit from the BaseModel class.
        :param config: dictionary containing the whole configuration of the experiment
        :param evaluator: Instance of the evaluator class, used to run evaluation on a specified schedule
        :param data_loader: pytorch data loader providing the training data
        :param tb_writer: tensorboardX summary writer
        :param run_info: sacred run info for logging training progress
        :param logger: python logger object
        :param checkpoint_dir: directory path for storing checkpoints
        """
        self.run_info = run_info
        self.logger = logger
        self.data_loader = data_loader
        self.evaluator = evaluator
        self.engine = Engine(self._step)
        self.model = model
        self.config = config
        self.train_cfg = config['train']
        self.tb_writer = tb_writer

        self.pbar = ProgressBar(ascii=True, desc='* Epoch')
        self.timer = Timer(average=True)
        self.save_last_checkpoint_handler = ModelCheckpoint(
            checkpoint_dir,
            'last',
            save_interval=self.train_cfg['save_interval'],
            n_saved=self.train_cfg['save_n_last'],
            require_empty=False)

        self.add_handler()

    def run(self):
        """
        Start the training loop which will run until all epochs are complete
        :return:
        """
        self.engine.run(self.data_loader,
                        max_epochs=self.train_cfg['n_epochs'])

    def add_handler(self):
        """
        Adds all the callback handlers to the trainer engine. Should be called at the end of __init__.
        :return:
        """
        # Learning rate decay
        for lr_s in self.model.schedulers:
            self.engine.add_event_handler(Events.ITERATION_STARTED, lr_s)

        # Checkpoint saving
        self.engine.add_event_handler(Events.EPOCH_STARTED,
                                      self.save_last_checkpoint_handler,
                                      self.model.networks)

        # Progbar
        monitoring_metrics = self.model.metric_names
        for mm in monitoring_metrics:
            RunningAverage(output_transform=self._extract_loss(mm)).attach(
                self.engine, mm)
        self.pbar.attach(self.engine, metric_names=monitoring_metrics)

        # Timer
        self.timer.attach(self.engine,
                          start=Events.EPOCH_STARTED,
                          resume=Events.ITERATION_STARTED,
                          pause=Events.ITERATION_COMPLETED,
                          step=Events.ITERATION_COMPLETED)

        # Logging
        self.engine.add_event_handler(Events.ITERATION_COMPLETED,
                                      self._handle_log_train_results)
        self.engine.add_event_handler(Events.ITERATION_COMPLETED,
                                      self._handle_log_train_images)
        self.engine.add_event_handler(Events.ITERATION_COMPLETED,
                                      self._handle_run_evaluation)
        self.engine.add_event_handler(Events.EPOCH_COMPLETED,
                                      self._handle_print_times)

        # Exception handling
        self.engine.add_event_handler(Events.EXCEPTION_RAISED,
                                      self._handle_exception)

    def _step(self, engine, batch):
        """
        Definition of a single training step. This function gets automatically called by the engine every iteration.
        :param engine: trainer engine
        :param batch: one batch provided by the dataloader
        :return:
        """
        self.model.train()
        self.model.set_input(batch)
        self.model.optimize_parameters()
        return self.model.state

    def _handle_log_train_results(self, engine):
        """
        Handler for writing the losses to tensorboard and sacred.
        :param engine: train engine
        :return:
        """
        if (engine.state.iteration - 1) % self.train_cfg['log_interval'] == 0:
            metrics = engine.state.metrics  # does not include non-scalar metrics, since the loggers cannot handle them

            for m_name, m_val in metrics.items():
                if m_val is None:
                    raise ValueError(f'Value for {m_name} is None')
                self.run_info.log_scalar("train.%s" % m_name, m_val,
                                         engine.state.iteration)
                self.tb_writer.add_scalar("train/%s" % m_name, m_val,
                                          engine.state.iteration)

            for lr_name, lr_val in self.model.learning_rates.items():
                if lr_val is None:
                    raise ValueError(f'Value for {lr_name} is None')
                self.run_info.log_scalar("train.%s" % lr_name, lr_val,
                                         engine.state.iteration)
                self.tb_writer.add_scalar("train/%s" % lr_name, lr_val,
                                          engine.state.iteration)

    def _handle_log_train_images(self, engine):
        """
        Handler for writing visual samples to tensorboard.
        :param engine: train engine
        :return:
        """
        if (engine.state.iteration -
                1) % self.train_cfg['img_log_interval'] == 0:
            for name, visual in self.model.visuals.items():
                # TODO remove the visual.transpose here and put it in the visualization function of the models
                self.tb_writer.add_image('train/%s' % name,
                                         visual.transpose(2, 0, 1),
                                         engine.state.iteration)
            for name, figure in self.model.figures.items():
                self.tb_writer.add_figure('train_metrics/%s' % name, figure,
                                          engine.state.iteration)

    def _handle_run_evaluation(self, engine):
        """
        Handler which will execute evaluation by running the evaluator object.
        :param engine: train engine
        :return:
        """
        if (engine.state.iteration - 1) % self.train_cfg['eval_interval'] == 0:
            self.evaluator.run()

    def _handle_exception(self, engine, e):
        """
        Exception handler which ensures that the model gets saved when stopped through a keyboard interruption.
        :param engine: train engine
        :param e: the exception which caused the training to stop
        :return:
        """
        if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1):
            engine.terminate()
            self.logger.warning(
                'KeyboardInterrupt caught. Exiting gracefully.')
            self.save_last_checkpoint_handler(engine, self.model.networks)
        else:
            raise e

    def _handle_print_times(self, engine):
        """
        Handler for logging timer information for different training and evaluation steps.
        :param engine: train engine
        :return:
        """
        self.logger.info('Epoch {} done. Time per batch: {:.3f}[s]'.format(
            engine.state.epoch, self.timer.value()))
        self.timer.reset()

    @staticmethod
    def _extract_loss(key):
        """
        Helper method to return losses for the RunningAverage
        :param key: (str) loss name
        :return: (fn) for the corresponding key
        """
        def _func(losses):
            return losses[key]

        return _func
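
The _extract_loss helper above exists because RunningAverage's output_transform receives whatever the step function returns (here the model state) and must pull out a single scalar. A minimal, self-contained sketch of the same pattern, with a dummy step function that returns a dict of named losses:

from ignite.engine import Engine, Events
from ignite.metrics import RunningAverage

def _step(engine, batch):
    # Stand-in for the model state: a dict of named scalar losses.
    return {"loss_G": 0.5, "loss_D": 0.25}

engine = Engine(_step)

# Equivalent to RunningAverage(output_transform=Trainer._extract_loss("loss_G"))
RunningAverage(output_transform=lambda losses: losses["loss_G"]).attach(engine, "loss_G")

@engine.on(Events.EPOCH_COMPLETED)
def report(engine):
    print("running loss_G:", engine.state.metrics["loss_G"])

engine.run(range(4), max_epochs=1)
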
Example #13
class BasicTimeProfiler(object):

    def __init__(self):        
        self._dataflow_timer = Timer()
        self._processing_timer = Timer()
        self._event_handlers_timer = Timer()

    def _reset(self, num_iters, num_epochs, total_num_iters):
        self.dataflow_times = torch.zeros(total_num_iters)
        self.processing_times = torch.zeros(total_num_iters)        
        self.event_handlers_times = {
            Events.STARTED: torch.zeros(1),
            Events.COMPLETED: torch.zeros(1),
            Events.EPOCH_STARTED: torch.zeros(num_epochs),
            Events.EPOCH_COMPLETED: torch.zeros(num_epochs),
            Events.ITERATION_STARTED: torch.zeros(total_num_iters),
            Events.ITERATION_COMPLETED: torch.zeros(total_num_iters)
        }

    def _as_first_started(self, engine):
        num_iters = engine.state.max_epochs * len(engine.state.dataloader)
        self._reset(len(engine.state.dataloader), engine.state.max_epochs, num_iters)
        
        self.event_handlers_names = {
            e: [h.__qualname__ if hasattr(h, "__qualname__") else h.__class__.__name__ 
                for (h, _, _) in engine._event_handlers[e]]
            for e in Events if e != Events.EXCEPTION_RAISED
        }

        # Setup all other handlers:
        engine._event_handlers[Events.STARTED].append((self._as_last_started, (), {}))
        #  - add the first handlers 
        events = [Events.EPOCH_STARTED, Events.EPOCH_COMPLETED, 
                  Events.ITERATION_STARTED, Events.ITERATION_COMPLETED, 
                  Events.COMPLETED]
        fmethods = [self._as_first_epoch_started, self._as_first_epoch_completed, 
                   self._as_first_iter_started, self._as_first_iter_completed,
                   self._as_first_completed]
        lmethods = [self._as_last_epoch_started, self._as_last_epoch_completed, 
                   self._as_last_iter_started, self._as_last_iter_completed,
                   self._as_last_completed]

        for e, m in zip(events, fmethods):
            engine._event_handlers[e].insert(0, (m, (), {}))

        for e, m in zip(events, lmethods):
            engine._event_handlers[e].append((m, (), {}))

        # Let's go
        self._event_handlers_timer.reset()

    def _as_last_started(self, engine):        
        self.event_handlers_times[Events.STARTED][0] = self._event_handlers_timer.value()

    def _as_first_epoch_started(self, engine):
        self._event_handlers_timer.reset()

    def _as_last_epoch_started(self, engine):
        t = self._event_handlers_timer.value()
        e = engine.state.epoch - 1
        self.event_handlers_times[Events.EPOCH_STARTED][e] = t

        self._dataflow_timer.reset()

    def _as_first_iter_started(self, engine):        
        t = self._dataflow_timer.value()
        i = engine.state.iteration - 1
        self.dataflow_times[i] = t

        self._event_handlers_timer.reset()

    def _as_last_iter_started(self, engine):
        t = self._event_handlers_timer.value()
        i = engine.state.iteration - 1
        self.event_handlers_times[Events.ITERATION_STARTED][i] = t

        self._processing_timer.reset()

    def _as_first_iter_completed(self, engine):
        t = self._processing_timer.value()
        i = engine.state.iteration - 1
        self.processing_times[i] = t

        self._event_handlers_timer.reset()        

    def _as_last_iter_completed(self, engine):
        t = self._event_handlers_timer.value()
        i = engine.state.iteration - 1
        self.event_handlers_times[Events.ITERATION_COMPLETED][i] = t

        self._dataflow_timer.reset()

    def _as_first_epoch_completed(self, engine):
        self._event_handlers_timer.reset()

    def _as_last_epoch_completed(self, engine):
        t = self._event_handlers_timer.value()
        e = engine.state.epoch - 1
        self.event_handlers_times[Events.EPOCH_COMPLETED][e] = t

    def _as_first_completed(self, engine):
        self._event_handlers_timer.reset()

    def _as_last_completed(self, engine):
        self.event_handlers_times[Events.COMPLETED][0] = self._event_handlers_timer.value()
        
        # Remove added handlers:
        remove_handler(engine, self._as_last_started, Events.STARTED)
        
        #  - remove the first/last handlers added in _as_first_started
        events = [Events.EPOCH_STARTED, Events.EPOCH_COMPLETED, 
                  Events.ITERATION_STARTED, Events.ITERATION_COMPLETED, 
                  Events.COMPLETED]
        fmethods = [self._as_first_epoch_started, self._as_first_epoch_completed, 
                   self._as_first_iter_started, self._as_first_iter_completed,
                   self._as_first_completed]
        lmethods = [self._as_last_epoch_started, self._as_last_epoch_completed, 
                   self._as_last_iter_started, self._as_last_iter_completed,
                   self._as_last_completed]

        for e, m in zip(events, fmethods):
            remove_handler(engine, m, e)          

        for e, m in zip(events, lmethods):
            remove_handler(engine, m, e)

    def attach(self, engine):
        if not isinstance(engine, Engine):
            raise TypeError("Argument engine should be ignite.engine.Engine, "
                            "but given {}".format(type(engine)))
        
        if not engine.has_event_handler(self._as_first_started):
            engine._event_handlers[Events.STARTED].insert(0, (self._as_first_started, (), {}))
    
    @staticmethod
    def _compute_basic_stats(data):    
        return OrderedDict([
            ('min/index', (torch.min(data).item(), torch.argmin(data).item())),
            ('max/index', (torch.max(data).item(), torch.argmax(data).item())),
            ('mean', torch.mean(data).item()),
            ('std', torch.std(data).item()),
            ('total', torch.sum(data).item())
        ])
            
    def get_results(self):
        total_eh_time = sum([sum(self.event_handlers_times[e]) for e in Events if e != Events.EXCEPTION_RAISED])
        return OrderedDict([
            ("processing_stats", self._compute_basic_stats(self.processing_times)),
            ("dataflow_stats", self._compute_basic_stats(self.dataflow_times)),
            ("event_handlers_stats",
                dict([(str(e).replace(".","_"), self._compute_basic_stats(self.event_handlers_times[e])) for e in Events 
                      if e != Events.EXCEPTION_RAISED] + 
                     [("total_time", total_eh_time)])
            ),
            ("event_handlers_names", {str(e).replace(".","_") + "_names": v 
                                      for e, v in self.event_handlers_names.items()})
        ])

    @staticmethod
    def print_results(results):

        def odict_to_str(d):
            out = ""
            for k, v in d.items():
                out += "\t{}: {}\n".format(k, v)
            return out

        others = {k: odict_to_str(v) if isinstance(v, OrderedDict) else v 
                  for k, v in results['event_handlers_stats'].items()}
        
        others.update(results['event_handlers_names'])
        
        output_message = """
--------------------------------------------
- Time profiling results:
--------------------------------------------
              
Processing function time stats (in seconds):
{processing_stats}

Dataflow time stats (in seconds):
{dataflow_stats}
        
Time stats of event handlers (in seconds):
- Total time spent:
\t{total_time}

- Events.STARTED:
{Events_STARTED}
Handlers names:
{Events_STARTED_names}

- Events.EPOCH_STARTED:
{Events_EPOCH_STARTED}
Handlers names:
{Events_EPOCH_STARTED_names}

- Events.ITERATION_STARTED:
{Events_ITERATION_STARTED}
Handlers names:
{Events_ITERATION_STARTED_names}

- Events.ITERATION_COMPLETED:
{Events_ITERATION_COMPLETED}
Handlers names:
{Events_ITERATION_COMPLETED_names}

- Events.EPOCH_COMPLETED:
{Events_EPOCH_COMPLETED}
Handlers names:
{Events_EPOCH_COMPLETED_names}

- Events.COMPLETED:
{Events_COMPLETED}
Handlers names:
{Events_COMPLETED_names}

""".format(processing_stats=odict_to_str(results['processing_stats']), 
           dataflow_stats=odict_to_str(results['dataflow_stats']), 
           **others)
        print(output_message)
        return output_message
    
    @staticmethod
    def write_results(output_path):
        try:
            import pandas as pd
        except ImportError:
            print("Need pandas to write results as files")
            return
        
        raise NotImplementedError("")
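
To make the shape of the aggregated output concrete, here is a standalone sketch applying the same aggregation as _compute_basic_stats above to an illustrative tensor of per-iteration times:

from collections import OrderedDict

import torch

def basic_stats(data):
    """Same aggregation as BasicTimeProfiler._compute_basic_stats above."""
    return OrderedDict([
        ("min/index", (torch.min(data).item(), torch.argmin(data).item())),
        ("max/index", (torch.max(data).item(), torch.argmax(data).item())),
        ("mean", torch.mean(data).item()),
        ("std", torch.std(data).item()),
        ("total", torch.sum(data).item()),
    ])

# Illustrative per-iteration processing times, in seconds
print(basic_stats(torch.tensor([0.012, 0.011, 0.015, 0.010])))
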
Example #14
class DatasetExperiment(BaseExperiment):
    def __init__(self, train, niter, nepoch, eval=None, gpu_id=[0],
                 sacred_run=None, writers=None, root=None, corruption=None,
                 **kwargs):
        super(DatasetExperiment, self).__init__(**kwargs)

        self.train = train
        self.eval = eval

        self.corruption = corruption

        self.niter = niter
        self.nepoch = nepoch
        self.device = 'cuda:0'

        self.sacred_run = sacred_run

        self.gpu_id = gpu_id

        if root is not None:
            self.basedir = make_basedir(os.path.join(root, str(sacred_run._id)))
        else:
            writers = None
            checkpoint = None

        if writers is not None:
            self.writers = init_writers(*writers, sacred_run=sacred_run, dirname=self.basedir)
        else:
            self.writers = None

        self.create_trainer()

    def create_trainer(self):
        self.trainer = Engine(self.train_step)
        self.trainer.add_event_handler(Events.EPOCH_COMPLETED, self.K_step)
        self.trainer.add_event_handler(Events.EPOCH_COMPLETED, self.log_metrics, 'train')
        self.trainer.add_event_handler(Events.ITERATION_COMPLETED, self.write_metrics, 'train')

        self.pbar = ProgressBar()

        self.timer = Timer(average=True)
        self.timer.attach(self.trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED,
                          pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED)

        self.trainer.add_event_handler(Events.EPOCH_COMPLETED, self.print_times)

    def print_times(self, engine):
        iteration = engine.state.iteration
        if self.writers is not None:
            self.writers.add_scalar('time_per_epoch', self.timer.total, iteration)
            self.writers.add_scalar('time_per_batch', self.timer.value(), iteration)
        self.timer.reset()

    def train_step(self, engine, batch):
        self.training()
        if isinstance(batch, Mapping):
            return self.optimize(**convert_tensor(batch, self.device))
        else:
            raise NotImplementedError

    def infer(self, **kwargs):
        raise NotImplementedError

    def optimize(self, **kwargs):
        raise NotImplementedError

    def K_step(self, **kwargs):
        raise NotImplementedError

    def log_metrics(self, engine, dataset_name):
        iteration = self.trainer.state.iteration
        epoch = self.trainer.state.epoch
        log = f"EPOCH {epoch}" f" - ITER {iteration} - {dataset_name}"

        if self.sacred_run is not None:
            log = f"ID {self.sacred_run._id} - " + log

        for metric, value in engine.state.metrics.items():
            if value is None:
                value = float('nan')
            log += f" {metric}: {value:.6f}"

        log += self.extra_log()
        self.pbar.log_message(log)

    def extra_log(self):
        return ""

    def write_metrics(self, engine, dataset_name):
        if self.writers is not None:
            for metric, value in engine.state.metrics.items():
                name = dataset_name + '/' + metric
                self.writers.add_scalar(name, value, self.trainer.state.iteration)

    def run(self):
        self.trainer.run(self.train, max_epochs=self.nepoch)
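
DatasetExperiment attaches its Timer so that it resumes and pauses around each iteration: value() is then the average time per batch and total the time spent inside iterations since the last reset, the two scalars logged by print_times above. A minimal, self-contained sketch of that attachment, with a dummy step function:

import time
from ignite.engine import Engine, Events
from ignite.handlers import Timer

engine = Engine(lambda e, batch: time.sleep(0.005))  # stand-in training step

timer = Timer(average=True)
timer.attach(engine,
             start=Events.EPOCH_STARTED,
             resume=Events.ITERATION_STARTED,
             pause=Events.ITERATION_COMPLETED,
             step=Events.ITERATION_COMPLETED)

@engine.on(Events.EPOCH_COMPLETED)
def report(engine):
    print("time per batch: %.4f s, time in iterations this epoch: %.4f s"
          % (timer.value(), timer.total))
    timer.reset()

engine.run(range(20), max_epochs=2)
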
Example #15
class BasicTimeProfiler(object):
    """
    BasicTimeProfiler can be used to profile the handlers,
    events, data loading and data processing times.

    Examples:

    .. code-block:: python

        #
        # Create an object of the profiler and attach an engine to it
        #
        profiler = BasicTimeProfiler()
        trainer = Engine(train_updater)
        profiler.attach(trainer)
        trainer.run(dataloader, max_epochs=3)

    """
    def __init__(self):
        self._dataflow_timer = Timer()
        self._processing_timer = Timer()
        self._event_handlers_timer = Timer()

        self.dataflow_times = None
        self.processing_times = None
        self.event_handlers_times = None

        self._events = [
            Events.EPOCH_STARTED, Events.EPOCH_COMPLETED,
            Events.ITERATION_STARTED, Events.ITERATION_COMPLETED,
            Events.GET_BATCH_STARTED, Events.GET_BATCH_COMPLETED,
            Events.COMPLETED
        ]
        self._fmethods = [
            self._as_first_epoch_started, self._as_first_epoch_completed,
            self._as_first_iter_started, self._as_first_iter_completed,
            self._as_first_get_batch_started,
            self._as_first_get_batch_completed, self._as_first_completed
        ]
        self._lmethods = [
            self._as_last_epoch_started, self._as_last_epoch_completed,
            self._as_last_iter_started, self._as_last_iter_completed,
            self._as_last_get_batch_started, self._as_last_get_batch_completed,
            self._as_last_completed
        ]

    def _reset(self, num_epochs, total_num_iters):
        self.dataflow_times = torch.zeros(total_num_iters)
        self.processing_times = torch.zeros(total_num_iters)
        self.event_handlers_times = {
            Events.STARTED: torch.zeros(1),
            Events.COMPLETED: torch.zeros(1),
            Events.EPOCH_STARTED: torch.zeros(num_epochs),
            Events.EPOCH_COMPLETED: torch.zeros(num_epochs),
            Events.ITERATION_STARTED: torch.zeros(total_num_iters),
            Events.ITERATION_COMPLETED: torch.zeros(total_num_iters),
            Events.GET_BATCH_COMPLETED: torch.zeros(total_num_iters),
            Events.GET_BATCH_STARTED: torch.zeros(total_num_iters)
        }

    def _as_first_started(self, engine):
        if hasattr(engine.state.dataloader, "__len__"):
            num_iters_per_epoch = len(engine.state.dataloader)
        else:
            num_iters_per_epoch = engine.state.epoch_length

        self.max_epochs = engine.state.max_epochs
        self.total_num_iters = self.max_epochs * num_iters_per_epoch
        self._reset(self.max_epochs, self.total_num_iters)

        self.event_handlers_names = {
            e: [
                h.__qualname__
                if hasattr(h, "__qualname__") else h.__class__.__name__
                for (h, _, _) in engine._event_handlers[e]
            ]
            for e in Events if e != Events.EXCEPTION_RAISED
        }

        # Setup all other handlers:
        engine._event_handlers[Events.STARTED].append(
            (self._as_last_started, (), {}))

        for e, m in zip(self._events, self._fmethods):
            engine._event_handlers[e].insert(0, (m, (), {}))

        for e, m in zip(self._events, self._lmethods):
            engine._event_handlers[e].append((m, (), {}))

        # Let's go
        self._event_handlers_timer.reset()

    def _as_last_started(self, engine):
        self.event_handlers_times[
            Events.STARTED][0] = self._event_handlers_timer.value()

    def _as_first_epoch_started(self, engine):
        self._event_handlers_timer.reset()

    def _as_last_epoch_started(self, engine):
        t = self._event_handlers_timer.value()
        e = engine.state.epoch - 1
        self.event_handlers_times[Events.EPOCH_STARTED][e] = t

    def _as_first_get_batch_started(self, engine):
        self._event_handlers_timer.reset()
        self._dataflow_timer.reset()

    def _as_last_get_batch_started(self, engine):
        t = self._event_handlers_timer.value()
        i = engine.state.iteration - 1
        self.event_handlers_times[Events.GET_BATCH_STARTED][i] = t

    def _as_first_get_batch_completed(self, engine):
        self._event_handlers_timer.reset()

    def _as_last_get_batch_completed(self, engine):
        t = self._event_handlers_timer.value()
        i = engine.state.iteration - 1
        self.event_handlers_times[Events.GET_BATCH_COMPLETED][i] = t

        d = self._dataflow_timer.value()
        self.dataflow_times[i] = d

        self._dataflow_timer.reset()

    def _as_first_iter_started(self, engine):
        self._event_handlers_timer.reset()

    def _as_last_iter_started(self, engine):
        t = self._event_handlers_timer.value()
        i = engine.state.iteration - 1
        self.event_handlers_times[Events.ITERATION_STARTED][i] = t

        self._processing_timer.reset()

    def _as_first_iter_completed(self, engine):
        t = self._processing_timer.value()
        i = engine.state.iteration - 1
        self.processing_times[i] = t

        self._event_handlers_timer.reset()

    def _as_last_iter_completed(self, engine):
        t = self._event_handlers_timer.value()
        i = engine.state.iteration - 1
        self.event_handlers_times[Events.ITERATION_COMPLETED][i] = t

    def _as_first_epoch_completed(self, engine):
        self._event_handlers_timer.reset()

    def _as_last_epoch_completed(self, engine):
        t = self._event_handlers_timer.value()
        e = engine.state.epoch - 1
        self.event_handlers_times[Events.EPOCH_COMPLETED][e] = t

    def _as_first_completed(self, engine):
        self._event_handlers_timer.reset()

    def _as_last_completed(self, engine):
        self.event_handlers_times[
            Events.COMPLETED][0] = self._event_handlers_timer.value()

        # Remove added handlers:
        engine.remove_event_handler(self._as_last_started, Events.STARTED)

        for e, m in zip(self._events, self._fmethods):
            engine.remove_event_handler(m, e)

        for e, m in zip(self._events, self._lmethods):
            engine.remove_event_handler(m, e)

    def attach(self, engine):
        if not isinstance(engine, Engine):
            raise TypeError("Argument engine should be ignite.engine.Engine, "
                            "but given {}".format(type(engine)))

        if not engine.has_event_handler(self._as_first_started):
            engine._event_handlers[Events.STARTED]\
                .insert(0, (self._as_first_started, (), {}))

    @staticmethod
    def _compute_basic_stats(data):
        return OrderedDict([
            ('min/index', (torch.min(data).item(), torch.argmin(data).item())),
            ('max/index', (torch.max(data).item(), torch.argmax(data).item())),
            ('mean', torch.mean(data).item()), ('std', torch.std(data).item()),
            ('total', torch.sum(data).item())
        ])

    def get_results(self):
        """
        Method to fetch the aggregated profiler results after the engine is run

        .. code-block:: python

            results = profiler.get_results()

        """
        events_to_ignore = [Events.EXCEPTION_RAISED]
        total_eh_time = sum([
            sum(self.event_handlers_times[e]) for e in Events
            if e not in events_to_ignore
        ])
        return OrderedDict([
            ("processing_stats",
             self._compute_basic_stats(self.processing_times)),
            ("dataflow_stats", self._compute_basic_stats(self.dataflow_times)),
            ("event_handlers_stats",
             dict([(str(e).replace(".", "_"),
                    self._compute_basic_stats(self.event_handlers_times[e]))
                   for e in Events if e not in events_to_ignore] +
                  [("total_time", total_eh_time)])),
            ("event_handlers_names", {
                str(e).replace(".", "_") + "_names": v
                for e, v in self.event_handlers_names.items()
            })
        ])

    def write_results(self, output_path):
        """
        Method to store the unaggregated profiling results to a csv file

        .. code-block:: python

            profiler.write_results('path_to_dir/awesome_filename.csv')

        Example output:

        .. code-block:: text

            -----------------------------------------------------------------
            epoch iteration processing_stats dataflow_stats Event_STARTED ...
            1.0     1.0        0.00003         0.252387        0.125676
            1.0     2.0        0.00029         0.252342        0.125123

        """
        try:
            import pandas as pd
        except ImportError:
            print("Need pandas to write results as files")
            return

        iters_per_epoch = self.total_num_iters // self.max_epochs

        epochs = torch.arange(self.max_epochs, dtype=torch.float32)\
            .repeat_interleave(iters_per_epoch) + 1
        iterations = torch.arange(self.total_num_iters,
                                  dtype=torch.float32) + 1
        processing_stats = self.processing_times
        dataflow_stats = self.dataflow_times
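        # Run-level and epoch-level handler times are repeated below so that every per-iteration row carries a value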

        event_started = self.event_handlers_times[Events.STARTED]\
            .repeat_interleave(self.total_num_iters)
        event_completed = self.event_handlers_times[Events.COMPLETED]\
            .repeat_interleave(self.total_num_iters)
        event_epoch_started = self.event_handlers_times[Events.EPOCH_STARTED]\
            .repeat_interleave(iters_per_epoch)
        event_epoch_completed = self.event_handlers_times[Events.EPOCH_COMPLETED]\
            .repeat_interleave(iters_per_epoch)

        event_iter_started = self.event_handlers_times[
            Events.ITERATION_STARTED]
        event_iter_completed = self.event_handlers_times[
            Events.ITERATION_COMPLETED]
        event_batch_started = self.event_handlers_times[
            Events.GET_BATCH_STARTED]
        event_batch_completed = self.event_handlers_times[
            Events.GET_BATCH_COMPLETED]

        results_dump = torch.stack([
            epochs, iterations, processing_stats, dataflow_stats,
            event_started, event_completed, event_epoch_started,
            event_epoch_completed, event_iter_started, event_iter_completed,
            event_batch_started, event_batch_completed
        ], dim=1).numpy()

        results_df = pd.DataFrame(
            data=results_dump,
            columns=[
                'epoch', 'iteration', 'processing_stats', 'dataflow_stats',
                'Event_STARTED', 'Event_COMPLETED', 'Event_EPOCH_STARTED',
                'Event_EPOCH_COMPLETED', 'Event_ITERATION_STARTED',
                'Event_ITERATION_COMPLETED', 'Event_GET_BATCH_STARTED',
                'Event_GET_BATCH_COMPLETED'
            ])
        results_df.to_csv(output_path, index=False)

    @staticmethod
    def print_results(results):
        """
        Method to print the aggregated results from the profiler

        .. code-block:: python

            profiler.print_results(results)

        Example output:

        .. code-block:: text

            --------------------------------------------
            - Time profiling results:
            --------------------------------------------
            Processing function time stats (in seconds):
                min/index: (2.9754010029137135e-05, 0)
                max/index: (2.9754010029137135e-05, 0)
                mean: 2.9754010029137135e-05
                std: nan
                total: 2.9754010029137135e-05

            Dataflow time stats (in seconds):
                min/index: (0.2523871660232544, 0)
                max/index: (0.2523871660232544, 0)
                mean: 0.2523871660232544
                std: nan
                total: 0.2523871660232544

            Time stats of event handlers (in seconds):
            - Total time spent:
                1.0080009698867798

            - Events.STARTED:
                min/index: (0.1256754994392395, 0)
                max/index: (0.1256754994392395, 0)
                mean: 0.1256754994392395
                std: nan
                total: 0.1256754994392395

            Handlers names:
            ['BasicTimeProfiler._as_first_started', 'delay_start']
            --------------------------------------------
        """
        def odict_to_str(d):
            out = ""
            for k, v in d.items():
                out += "\t{}: {}\n".format(k, v)
            return out

        others = {
            k: odict_to_str(v) if isinstance(v, OrderedDict) else v
            for k, v in results['event_handlers_stats'].items()
        }

        others.update(results['event_handlers_names'])

        output_message = """
--------------------------------------------
- Time profiling results:
--------------------------------------------

Processing function time stats (in seconds):
{processing_stats}

Dataflow time stats (in seconds):
{dataflow_stats}

Time stats of event handlers (in seconds):
- Total time spent:
\t{total_time}

- Events.STARTED:
{Events_STARTED}
Handlers names:
{Events_STARTED_names}

- Events.EPOCH_STARTED:
{Events_EPOCH_STARTED}
Handlers names:
{Events_EPOCH_STARTED_names}

- Events.ITERATION_STARTED:
{Events_ITERATION_STARTED}
Handlers names:
{Events_ITERATION_STARTED_names}

- Events.ITERATION_COMPLETED:
{Events_ITERATION_COMPLETED}
Handlers names:
{Events_ITERATION_COMPLETED_names}

- Events.EPOCH_COMPLETED:
{Events_EPOCH_COMPLETED}
Handlers names:
{Events_EPOCH_COMPLETED_names}

- Events.COMPLETED:
{Events_COMPLETED}
Handlers names:
{Events_COMPLETED_names}

""".format(processing_stats=odict_to_str(results['processing_stats']),
           dataflow_stats=odict_to_str(results['dataflow_stats']),
           **others)
        print(output_message)
        return output_message
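The results returned by get_results above are plain nested OrderedDict objects, so they can be consumed programmatically as well as printed. A minimal sketch (not part of the original example), assuming an ignite version compatible with the snippet above; the dummy trainer and the file path are placeholders, and the dictionary keys mirror the get_results implementation:

from ignite.engine import Engine

# Dummy trainer only to demonstrate the profiler API defined above.
trainer = Engine(lambda engine, batch: batch)

profiler = BasicTimeProfiler()
profiler.attach(trainer)
trainer.run(range(100), max_epochs=2)

results = profiler.get_results()
print("mean dataflow time per iteration: {:.5f} s".format(
    results["dataflow_stats"]["mean"]))
print("total time in event handlers: {:.5f} s".format(
    float(results["event_handlers_stats"]["total_time"])))

# The per-iteration timings can also be dumped to disk:
profiler.write_results("path_to_dir/time_profiling.csv")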
Exemple #16
0
class BasicTimeProfiler:
    """
    BasicTimeProfiler can be used to profile the handlers,
    events, data loading and data processing times.

    Examples:

    .. code-block:: python

        from ignite.contrib.handlers import BasicTimeProfiler

        trainer = Engine(train_updater)

        # Create an object of the profiler and attach an engine to it
        profiler = BasicTimeProfiler()
        profiler.attach(trainer)

        @trainer.on(Events.EPOCH_COMPLETED)
        def log_intermediate_results():
            profiler.print_results(profiler.get_results())

        trainer.run(dataloader, max_epochs=3)

        profiler.write_results('path_to_dir/time_profiling.csv')

    """

    events_to_ignore = [
        Events.EXCEPTION_RAISED,
        Events.TERMINATE,
        Events.TERMINATE_SINGLE_EPOCH,
        Events.DATALOADER_STOP_ITERATION,
    ]

    def __init__(self):
        self._dataflow_timer = Timer()
        self._processing_timer = Timer()
        self._event_handlers_timer = Timer()

        self.dataflow_times = None
        self.processing_times = None
        self.event_handlers_times = None

        self._events = [
            Events.EPOCH_STARTED,
            Events.EPOCH_COMPLETED,
            Events.ITERATION_STARTED,
            Events.ITERATION_COMPLETED,
            Events.GET_BATCH_STARTED,
            Events.GET_BATCH_COMPLETED,
            Events.COMPLETED,
        ]
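        # For each profiled event, a "first" method below resets the handlers
        # timer and a matching "last" method stores the elapsed time, so all
        # user handlers attached to that event are bracketed by the pair.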
        self._fmethods = [
            self._as_first_epoch_started,
            self._as_first_epoch_completed,
            self._as_first_iter_started,
            self._as_first_iter_completed,
            self._as_first_get_batch_started,
            self._as_first_get_batch_completed,
            self._as_first_completed,
        ]
        self._lmethods = [
            self._as_last_epoch_started,
            self._as_last_epoch_completed,
            self._as_last_iter_started,
            self._as_last_iter_completed,
            self._as_last_get_batch_started,
            self._as_last_get_batch_completed,
            self._as_last_completed,
        ]

    def _reset(self, num_epochs, total_num_iters):
        self.dataflow_times = torch.zeros(total_num_iters)
        self.processing_times = torch.zeros(total_num_iters)
        self.event_handlers_times = {
            Events.STARTED: torch.zeros(1),
            Events.COMPLETED: torch.zeros(1),
            Events.EPOCH_STARTED: torch.zeros(num_epochs),
            Events.EPOCH_COMPLETED: torch.zeros(num_epochs),
            Events.ITERATION_STARTED: torch.zeros(total_num_iters),
            Events.ITERATION_COMPLETED: torch.zeros(total_num_iters),
            Events.GET_BATCH_COMPLETED: torch.zeros(total_num_iters),
            Events.GET_BATCH_STARTED: torch.zeros(total_num_iters),
        }

    def _as_first_started(self, engine):
        if hasattr(engine.state.dataloader, "__len__"):
            num_iters_per_epoch = len(engine.state.dataloader)
        else:
            num_iters_per_epoch = engine.state.epoch_length

        self.max_epochs = engine.state.max_epochs
        self.total_num_iters = self.max_epochs * num_iters_per_epoch
        self._reset(self.max_epochs, self.total_num_iters)

        self.event_handlers_names = {
            e: [
                h.__qualname__ if hasattr(h, "__qualname__")
                else h.__class__.__name__
                for (h, _, _) in engine._event_handlers[e]
                # avoid adding the profiler's internal handlers to the output
                if "BasicTimeProfiler." not in repr(h)
            ]
            for e in Events if e not in self.events_to_ignore
        }

        # Setup all other handlers:
        engine._event_handlers[Events.STARTED].append(
            (self._as_last_started, (engine, ), {}))

        for e, m in zip(self._events, self._fmethods):
            engine._event_handlers[e].insert(0, (m, (engine, ), {}))

        for e, m in zip(self._events, self._lmethods):
            engine._event_handlers[e].append((m, (engine, ), {}))

        # Let's go
        self._event_handlers_timer.reset()

    def _as_last_started(self, engine):
        self.event_handlers_times[
            Events.STARTED][0] = self._event_handlers_timer.value()

    def _as_first_epoch_started(self, engine):
        self._event_handlers_timer.reset()

    def _as_last_epoch_started(self, engine):
        t = self._event_handlers_timer.value()
        e = engine.state.epoch - 1
        self.event_handlers_times[Events.EPOCH_STARTED][e] = t

    def _as_first_get_batch_started(self, engine):
        self._event_handlers_timer.reset()
        self._dataflow_timer.reset()

    def _as_last_get_batch_started(self, engine):
        t = self._event_handlers_timer.value()
        i = engine.state.iteration - 1
        self.event_handlers_times[Events.GET_BATCH_STARTED][i] = t

    def _as_first_get_batch_completed(self, engine):
        self._event_handlers_timer.reset()

    def _as_last_get_batch_completed(self, engine):
        t = self._event_handlers_timer.value()
        i = engine.state.iteration - 1
        self.event_handlers_times[Events.GET_BATCH_COMPLETED][i] = t

        d = self._dataflow_timer.value()
        self.dataflow_times[i] = d

        self._dataflow_timer.reset()

    def _as_first_iter_started(self, engine):
        self._event_handlers_timer.reset()

    def _as_last_iter_started(self, engine):
        t = self._event_handlers_timer.value()
        i = engine.state.iteration - 1
        self.event_handlers_times[Events.ITERATION_STARTED][i] = t

        self._processing_timer.reset()
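        # The processing timer started above is read in
        # _as_first_iter_completed, so it measures the time spent in the
        # engine's process_function for this iteration.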

    def _as_first_iter_completed(self, engine):
        t = self._processing_timer.value()
        i = engine.state.iteration - 1
        self.processing_times[i] = t

        self._event_handlers_timer.reset()

    def _as_last_iter_completed(self, engine):
        t = self._event_handlers_timer.value()
        i = engine.state.iteration - 1
        self.event_handlers_times[Events.ITERATION_COMPLETED][i] = t

    def _as_first_epoch_completed(self, engine):
        self._event_handlers_timer.reset()

    def _as_last_epoch_completed(self, engine):
        t = self._event_handlers_timer.value()
        e = engine.state.epoch - 1
        self.event_handlers_times[Events.EPOCH_COMPLETED][e] = t

    def _as_first_completed(self, engine):
        self._event_handlers_timer.reset()

    def _as_last_completed(self, engine):
        self.event_handlers_times[
            Events.COMPLETED][0] = self._event_handlers_timer.value()

        # Remove added handlers:
        engine.remove_event_handler(self._as_last_started, Events.STARTED)

        for e, m in zip(self._events, self._fmethods):
            engine.remove_event_handler(m, e)

        for e, m in zip(self._events, self._lmethods):
            engine.remove_event_handler(m, e)

    def attach(self, engine):
        if not isinstance(engine, Engine):
            raise TypeError("Argument engine should be ignite.engine.Engine, "
                            "but given {}".format(type(engine)))

        if not engine.has_event_handler(self._as_first_started):
            engine._event_handlers[Events.STARTED].insert(
                0, (self._as_first_started, (engine, ), {}))

    @staticmethod
    def _compute_basic_stats(data):
        # compute on non-zero data:
        data = data[data > 0]
        out = [
            ("total",
             torch.sum(data).item() if len(data) > 0 else "not yet triggered")
        ]
        if len(data) > 1:
            out += [
                ("min/index", (torch.min(data).item(),
                               torch.argmin(data).item())),
                ("max/index", (torch.max(data).item(),
                               torch.argmax(data).item())),
                ("mean", torch.mean(data).item()),
                ("std", torch.std(data).item()),
            ]
        return OrderedDict(out)

    def get_results(self):
        """
        Method to fetch the aggregated profiler results after the engine is run

        .. code-block:: python

            results = profiler.get_results()

        """
        total_eh_time = sum([(self.event_handlers_times[e]).sum()
                             for e in Events
                             if e not in self.events_to_ignore])

        return OrderedDict([
            ("processing_stats",
             self._compute_basic_stats(self.processing_times)),
            ("dataflow_stats", self._compute_basic_stats(self.dataflow_times)),
            (
                "event_handlers_stats",
                dict([(str(e.name).replace(".", "_"),
                       self._compute_basic_stats(self.event_handlers_times[e]))
                      for e in Events if e not in self.events_to_ignore] +
                     [("total_time", total_eh_time)]),
            ),
            (
                "event_handlers_names",
                {
                    str(e.name).replace(".", "_") + "_names": v
                    for e, v in self.event_handlers_names.items()
                },
            ),
        ])

    def write_results(self, output_path):
        """
        Method to store the unaggregated profiling results to a csv file

        .. code-block:: python

            profiler.write_results('path_to_dir/awesome_filename.csv')

        Example output:

        .. code-block:: text

            -----------------------------------------------------------------
            epoch iteration processing_stats dataflow_stats Event_STARTED ...
            1.0     1.0        0.00003         0.252387        0.125676
            1.0     2.0        0.00029         0.252342        0.125123

        """
        try:
            import pandas as pd
        except ImportError:
            print("Need pandas to write results as files")
            return

        iters_per_epoch = self.total_num_iters // self.max_epochs

        epochs = torch.arange(
            self.max_epochs,
            dtype=torch.float32).repeat_interleave(iters_per_epoch) + 1
        iterations = torch.arange(self.total_num_iters,
                                  dtype=torch.float32) + 1
        processing_stats = self.processing_times
        dataflow_stats = self.dataflow_times
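
        # Events.STARTED / COMPLETED are recorded once per run and EPOCH_*
        # once per epoch; repeat_interleave expands them to one value per
        # iteration so that every column below has total_num_iters rows.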

        event_started = self.event_handlers_times[
            Events.STARTED].repeat_interleave(self.total_num_iters)
        event_completed = self.event_handlers_times[
            Events.COMPLETED].repeat_interleave(self.total_num_iters)
        event_epoch_started = self.event_handlers_times[
            Events.EPOCH_STARTED].repeat_interleave(iters_per_epoch)
        event_epoch_completed = self.event_handlers_times[
            Events.EPOCH_COMPLETED].repeat_interleave(iters_per_epoch)

        event_iter_started = self.event_handlers_times[
            Events.ITERATION_STARTED]
        event_iter_completed = self.event_handlers_times[
            Events.ITERATION_COMPLETED]
        event_batch_started = self.event_handlers_times[
            Events.GET_BATCH_STARTED]
        event_batch_completed = self.event_handlers_times[
            Events.GET_BATCH_COMPLETED]

        results_dump = torch.stack(
            [
                epochs,
                iterations,
                processing_stats,
                dataflow_stats,
                event_started,
                event_completed,
                event_epoch_started,
                event_epoch_completed,
                event_iter_started,
                event_iter_completed,
                event_batch_started,
                event_batch_completed,
            ],
            dim=1,
        ).numpy()

        results_df = pd.DataFrame(
            data=results_dump,
            columns=[
                "epoch",
                "iteration",
                "processing_stats",
                "dataflow_stats",
                "Event_STARTED",
                "Event_COMPLETED",
                "Event_EPOCH_STARTED",
                "Event_EPOCH_COMPLETED",
                "Event_ITERATION_STARTED",
                "Event_ITERATION_COMPLETED",
                "Event_GET_BATCH_STARTED",
                "Event_GET_BATCH_COMPLETED",
            ],
        )
        results_df.to_csv(output_path, index=False)

    @staticmethod
    def print_results(results):
        """
        Method to print the aggregated results from the profiler

        .. code-block:: python

            profiler.print_results(results)

        Example output:

        .. code-block:: text

             ----------------------------------------------------
            | Time profiling stats (in seconds):                 |
             ----------------------------------------------------
            total  |  min/index  |  max/index  |  mean  |  std

            Processing function:
            157.46292 | 0.01452/1501 | 0.26905/0 | 0.07730 | 0.01258

            Dataflow:
            6.11384 | 0.00008/1935 | 0.28461/1551 | 0.00300 | 0.02693

            Event handlers:
            2.82721

            - Events.STARTED: []
            0.00000

            - Events.EPOCH_STARTED: []
            0.00006 | 0.00000/0 | 0.00000/17 | 0.00000 | 0.00000

            - Events.ITERATION_STARTED: ['PiecewiseLinear']
            0.03482 | 0.00001/188 | 0.00018/679 | 0.00002 | 0.00001

            - Events.ITERATION_COMPLETED: ['TerminateOnNan']
            0.20037 | 0.00006/866 | 0.00089/1943 | 0.00010 | 0.00003

            - Events.EPOCH_COMPLETED: ['empty_cuda_cache', 'training.<locals>.log_elapsed_time', ]
            2.57860 | 0.11529/0 | 0.14977/13 | 0.12893 | 0.00790

            - Events.COMPLETED: []
            not yet triggered

        """
        def to_str(v):
            if isinstance(v, str):
                return v
            elif isinstance(v, tuple):
                return "{:.5f}/{}".format(v[0], v[1])
            return "{:.5f}".format(v)

        def odict_to_str(d):
            out = " | ".join([to_str(v) for v in d.values()])
            return out

        others = {
            k: odict_to_str(v) if isinstance(v, OrderedDict) else v
            for k, v in results["event_handlers_stats"].items()
        }

        others.update(results["event_handlers_names"])

        output_message = """
 ----------------------------------------------------
| Time profiling stats (in seconds):                 |
 ----------------------------------------------------
total  |  min/index  |  max/index  |  mean  |  std

Processing function:
{processing_stats}

Dataflow:
{dataflow_stats}

Event handlers:
{total_time:.5f}

- Events.STARTED: {STARTED_names}
{STARTED}

- Events.EPOCH_STARTED: {EPOCH_STARTED_names}
{EPOCH_STARTED}

- Events.ITERATION_STARTED: {ITERATION_STARTED_names}
{ITERATION_STARTED}

- Events.ITERATION_COMPLETED: {ITERATION_COMPLETED_names}
{ITERATION_COMPLETED}

- Events.EPOCH_COMPLETED: {EPOCH_COMPLETED_names}
{EPOCH_COMPLETED}

- Events.COMPLETED: {COMPLETED_names}
{COMPLETED}
""".format(
            processing_stats=odict_to_str(results["processing_stats"]),
            dataflow_stats=odict_to_str(results["dataflow_stats"]),
            **others,
        )
        print(output_message)
        return output_message
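Because write_results above stores one row per iteration, the resulting csv file can be reloaded for custom analysis. A minimal sketch, assuming pandas is installed and reusing the placeholder path from the class docstring; the column names follow the DataFrame built in write_results:

import pandas as pd

# Reload the per-iteration dump produced by BasicTimeProfiler.write_results.
df = pd.read_csv("path_to_dir/time_profiling.csv")

# Per-epoch mean processing and dataflow times (column names as defined in
# write_results above).
print(df.groupby("epoch")[["processing_stats", "dataflow_stats"]].mean())

# Iteration with the slowest process_function call.
slowest = df.loc[df["processing_stats"].idxmax(), ["epoch", "iteration"]]
print("slowest iteration:", slowest.to_dict())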
Exemple #17
0
            print(
                timestamp +
                " Test set Results  CLAMP -  Avg ssim: {:.2f} Avg psnr: {:.2f} Avg loss: {:.2f}"
                .format(test_metrics['ssimclamp'], test_metrics['psnrclamp'],
                        0))
            experiment.log_metric("test_ssim clamp", test_metrics['ssimclamp'])
            experiment.log_metric("test_psnr clamp", test_metrics['psnrclamp'])
        else:
            print(
                timestamp +
                " Test set Results -  Avg mae: {:.2f} Avg mse: {:.2f} Avg loss: {:.2f}"
                .format(test_metrics['mae'], test_metrics['mse'], 0))
            experiment.log_metric("test_mae", test_metrics['mae'])
            experiment.log_metric("test_mse", test_metrics['mse'])
        experiment.log_metric("evaluate_test_timer",
                              evaluate_test_timer.value())
        print("evaluate_test_timer ", evaluate_test_timer.value())
        # experiment.log_metric("test_loss", test_metrics['loss'])
    else:
        # Persist trainer, model and optimizer state with Checkpoint/DiskSaver
        # (see the ignite docs on saving and loading).
        to_save = {'trainer': trainer, 'model': model, 'optimizer': optimizer}
        save_handler = Checkpoint(to_save,
                                  DiskSaver('saved_model/' + args.task_id,
                                            create_dir=True,
                                            atomic=True),
                                  filename_prefix=args.task_id,
                                  n_saved=3)

        save_handler_best = Checkpoint(
            to_save,
            DiskSaver('saved_model_best/' + args.task_id,