def test_timer():
    """Check Timer values for total, per-batch (averaged) and per-epoch timing.

    Runs a trainer whose every iteration sleeps ``pause_s`` seconds and which
    triggers an equally slow "validation" engine after each epoch, then checks
    the three attached timers against the expected wall-clock durations.
    """
    pause_s = 0.2
    n_iter = 3

    def _train_func(engine, batch):
        time.sleep(pause_s)

    def _test_func(engine, batch):
        time.sleep(pause_s)

    trainer = Engine(_train_func)
    tester = Engine(_test_func)

    # Total run time, mean time per iteration, and time spent inside epochs.
    t_total = Timer()
    t_batch = Timer(average=True)
    t_train = Timer()

    t_total.attach(trainer)
    t_batch.attach(
        trainer,
        pause=Events.ITERATION_COMPLETED,
        resume=Events.ITERATION_STARTED,
        step=Events.ITERATION_COMPLETED,
    )
    t_train.attach(trainer, pause=Events.EPOCH_COMPLETED, resume=Events.EPOCH_STARTED)

    @trainer.on(Events.EPOCH_COMPLETED)
    def run_validation(trainer):
        tester.run(range(n_iter))

    # Run "training"
    trainer.run(range(n_iter))

    def _equal(lhs, rhs):
        # sleep() granularity makes exact comparison flaky; compare to 0.1 s.
        return round(lhs, 1) == round(rhs, 1)

    # Training + validation both sleep n_iter times -> 2 * n_iter * pause_s total.
    assert _equal(t_total.value(), 2 * n_iter * pause_s)
    assert _equal(t_batch.value(), pause_s)
    assert _equal(t_train.value(), n_iter * pause_s)

    t_total.reset()
    assert _equal(t_total.value(), 0.0)
class BasicTimeProfiler:
    """
    BasicTimeProfiler can be used to profile the handlers,
    events, data loading and data processing times.

    Examples:

    .. code-block:: python

        from ignite.contrib.handlers import BasicTimeProfiler

        trainer = Engine(train_updater)

        # Create an object of the profiler and attach an engine to it
        profiler = BasicTimeProfiler()
        profiler.attach(trainer)

        @trainer.on(Events.EPOCH_COMPLETED)
        def log_intermediate_results():
            profiler.print_results(profiler.get_results())

        trainer.run(dataloader, max_epochs=3)

        profiler.write_results('path_to_dir/time_profiling.csv')
    """

    # Events that are never profiled (no per-occurrence timing makes sense).
    events_to_ignore = [
        Events.EXCEPTION_RAISED,
        Events.TERMINATE,
        Events.TERMINATE_SINGLE_EPOCH,
        Events.DATALOADER_STOP_ITERATION,
    ]

    def __init__(self):
        # Three independent timers: one for dataflow (batch fetching), one for
        # the processing function, one reused for every event-handler span.
        self._dataflow_timer = Timer()
        self._processing_timer = Timer()
        self._event_handlers_timer = Timer()

        # Filled by _reset() once the run geometry (epochs/iterations) is known.
        self.dataflow_times = None
        self.processing_times = None
        self.event_handlers_times = None

        # Profiled events; _fmethods/_lmethods are inserted first/appended last
        # on the matching event so they bracket all user handlers.
        self._events = [
            Events.EPOCH_STARTED,
            Events.EPOCH_COMPLETED,
            Events.ITERATION_STARTED,
            Events.ITERATION_COMPLETED,
            Events.GET_BATCH_STARTED,
            Events.GET_BATCH_COMPLETED,
            Events.COMPLETED,
        ]
        self._fmethods = [
            self._as_first_epoch_started,
            self._as_first_epoch_completed,
            self._as_first_iter_started,
            self._as_first_iter_completed,
            self._as_first_get_batch_started,
            self._as_first_get_batch_completed,
            self._as_first_completed,
        ]
        self._lmethods = [
            self._as_last_epoch_started,
            self._as_last_epoch_completed,
            self._as_last_iter_started,
            self._as_last_iter_completed,
            self._as_last_get_batch_started,
            self._as_last_get_batch_completed,
            self._as_last_completed,
        ]

    def _reset(self, num_epochs, total_num_iters):
        # Allocate fixed-size buffers; indices are epoch-1 / iteration-1.
        self.dataflow_times = torch.zeros(total_num_iters)
        self.processing_times = torch.zeros(total_num_iters)
        self.event_handlers_times = {
            Events.STARTED: torch.zeros(1),
            Events.COMPLETED: torch.zeros(1),
            Events.EPOCH_STARTED: torch.zeros(num_epochs),
            Events.EPOCH_COMPLETED: torch.zeros(num_epochs),
            Events.ITERATION_STARTED: torch.zeros(total_num_iters),
            Events.ITERATION_COMPLETED: torch.zeros(total_num_iters),
            Events.GET_BATCH_COMPLETED: torch.zeros(total_num_iters),
            Events.GET_BATCH_STARTED: torch.zeros(total_num_iters),
        }

    def _as_first_started(self, engine):
        # Runs as the very first STARTED handler: size the buffers from the
        # engine state and install the bracketing handlers for every event.
        if hasattr(engine.state.dataloader, "__len__"):
            num_iters_per_epoch = len(engine.state.dataloader)
        else:
            num_iters_per_epoch = engine.state.epoch_length

        self.max_epochs = engine.state.max_epochs
        self.total_num_iters = self.max_epochs * num_iters_per_epoch
        self._reset(self.max_epochs, self.total_num_iters)

        self.event_handlers_names = {
            e: [
                h.__qualname__ if hasattr(h, "__qualname__") else h.__class__.__name__
                for (h, _, _) in engine._event_handlers[e]
                if "BasicTimeProfiler." not in repr(h)  # avoid adding internal handlers into output
            ]
            for e in Events
            if e not in self.events_to_ignore
        }

        # Setup all other handlers:
        engine._event_handlers[Events.STARTED].append((self._as_last_started, (engine, ), {}))

        for e, m in zip(self._events, self._fmethods):
            engine._event_handlers[e].insert(0, (m, (engine, ), {}))

        for e, m in zip(self._events, self._lmethods):
            engine._event_handlers[e].append((m, (engine, ), {}))

        # Let's go
        self._event_handlers_timer.reset()

    def _as_last_started(self, engine):
        # Elapsed time since _as_first_started == time of the user's STARTED handlers.
        self.event_handlers_times[Events.STARTED][0] = self._event_handlers_timer.value()

    def _as_first_epoch_started(self, engine):
        self._event_handlers_timer.reset()

    def _as_last_epoch_started(self, engine):
        t = self._event_handlers_timer.value()
        e = engine.state.epoch - 1
        self.event_handlers_times[Events.EPOCH_STARTED][e] = t

    def _as_first_get_batch_started(self, engine):
        # Dataflow timer starts here; stopped in _as_last_get_batch_completed.
        self._event_handlers_timer.reset()
        self._dataflow_timer.reset()

    def _as_last_get_batch_started(self, engine):
        t = self._event_handlers_timer.value()
        i = engine.state.iteration - 1
        self.event_handlers_times[Events.GET_BATCH_STARTED][i] = t

    def _as_first_get_batch_completed(self, engine):
        self._event_handlers_timer.reset()

    def _as_last_get_batch_completed(self, engine):
        t = self._event_handlers_timer.value()
        i = engine.state.iteration - 1
        self.event_handlers_times[Events.GET_BATCH_COMPLETED][i] = t

        # Dataflow time = GET_BATCH_STARTED -> GET_BATCH_COMPLETED span.
        d = self._dataflow_timer.value()
        self.dataflow_times[i] = d

        self._dataflow_timer.reset()

    def _as_first_iter_started(self, engine):
        self._event_handlers_timer.reset()

    def _as_last_iter_started(self, engine):
        t = self._event_handlers_timer.value()
        i = engine.state.iteration - 1
        self.event_handlers_times[Events.ITERATION_STARTED][i] = t

        # Processing timer spans from after ITERATION_STARTED handlers to the
        # first ITERATION_COMPLETED handler, i.e. the update function itself.
        self._processing_timer.reset()

    def _as_first_iter_completed(self, engine):
        t = self._processing_timer.value()
        i = engine.state.iteration - 1
        self.processing_times[i] = t

        self._event_handlers_timer.reset()

    def _as_last_iter_completed(self, engine):
        t = self._event_handlers_timer.value()
        i = engine.state.iteration - 1
        self.event_handlers_times[Events.ITERATION_COMPLETED][i] = t

    def _as_first_epoch_completed(self, engine):
        self._event_handlers_timer.reset()

    def _as_last_epoch_completed(self, engine):
        t = self._event_handlers_timer.value()
        e = engine.state.epoch - 1
        self.event_handlers_times[Events.EPOCH_COMPLETED][e] = t

    def _as_first_completed(self, engine):
        self._event_handlers_timer.reset()

    def _as_last_completed(self, engine):
        self.event_handlers_times[Events.COMPLETED][0] = self._event_handlers_timer.value()

        # Remove added handlers:
        engine.remove_event_handler(self._as_last_started, Events.STARTED)

        for e, m in zip(self._events, self._fmethods):
            engine.remove_event_handler(m, e)

        for e, m in zip(self._events, self._lmethods):
            engine.remove_event_handler(m, e)

    def attach(self, engine):
        """Attach the profiler to ``engine``; idempotent (won't attach twice)."""
        if not isinstance(engine, Engine):
            raise TypeError("Argument engine should be ignite.engine.Engine, "
                            "but given {}".format(type(engine)))

        if not engine.has_event_handler(self._as_first_started):
            # Inserted at position 0 so it runs before any user STARTED handler.
            engine._event_handlers[Events.STARTED].insert(0, (self._as_first_started, (engine, ), {}))

    @staticmethod
    def _compute_basic_stats(data):
        """Return an OrderedDict of total/min/max/mean/std for non-zero entries."""
        # compute on non-zero data:
        data = data[data > 0]
        out = [("total", torch.sum(data).item() if len(data) > 0 else "not yet triggered")]
        # NOTE(review): min/max/mean/std are only reported for >1 sample; a
        # single sample yields just "total".
        if len(data) > 1:
            out += [
                ("min/index", (torch.min(data).item(), torch.argmin(data).item())),
                ("max/index", (torch.max(data).item(), torch.argmax(data).item())),
                ("mean", torch.mean(data).item()),
                ("std", torch.std(data).item()),
            ]
        return OrderedDict(out)

    def get_results(self):
        """
        Method to fetch the aggregated profiler results after the engine is run

        .. code-block:: python

            results = profiler.get_results()
        """
        total_eh_time = sum([(self.event_handlers_times[e]).sum() for e in Events if e not in self.events_to_ignore])
        return OrderedDict([
            ("processing_stats", self._compute_basic_stats(self.processing_times)),
            ("dataflow_stats", self._compute_basic_stats(self.dataflow_times)),
            (
                "event_handlers_stats",
                dict(
                    [(str(e.name).replace(".", "_"), self._compute_basic_stats(self.event_handlers_times[e]))
                     for e in Events if e not in self.events_to_ignore]
                    + [("total_time", total_eh_time)]
                ),
            ),
            (
                "event_handlers_names",
                {str(e.name).replace(".", "_") + "_names": v for e, v in self.event_handlers_names.items()},
            ),
        ])

    def write_results(self, output_path):
        """
        Method to store the unaggregated profiling results to a csv file

        .. code-block:: python

            profiler.write_results('path_to_dir/awesome_filename.csv')

        Example output:

        .. code-block:: text

            -----------------------------------------------------------------
            epoch iteration processing_stats dataflow_stats Event_STARTED ...
            1.0     1.0     0.00003         0.252387        0.125676
            1.0     2.0     0.00029         0.252342        0.125123
        """
        try:
            import pandas as pd
        except ImportError:
            # Best-effort export: pandas is optional, so just warn and bail.
            print("Need pandas to write results as files")
            return

        iters_per_epoch = self.total_num_iters // self.max_epochs

        # Per-row epoch/iteration columns (1-based, float for stacking).
        epochs = torch.arange(self.max_epochs, dtype=torch.float32).repeat_interleave(iters_per_epoch) + 1
        iterations = torch.arange(self.total_num_iters, dtype=torch.float32) + 1
        processing_stats = self.processing_times
        dataflow_stats = self.dataflow_times

        # Scalar/per-epoch timings are broadcast to one value per iteration row.
        event_started = self.event_handlers_times[Events.STARTED].repeat_interleave(self.total_num_iters)
        event_completed = self.event_handlers_times[Events.COMPLETED].repeat_interleave(self.total_num_iters)
        event_epoch_started = self.event_handlers_times[Events.EPOCH_STARTED].repeat_interleave(iters_per_epoch)
        event_epoch_completed = self.event_handlers_times[Events.EPOCH_COMPLETED].repeat_interleave(iters_per_epoch)

        event_iter_started = self.event_handlers_times[Events.ITERATION_STARTED]
        event_iter_completed = self.event_handlers_times[Events.ITERATION_COMPLETED]
        event_batch_started = self.event_handlers_times[Events.GET_BATCH_STARTED]
        event_batch_completed = self.event_handlers_times[Events.GET_BATCH_COMPLETED]

        results_dump = torch.stack(
            [
                epochs,
                iterations,
                processing_stats,
                dataflow_stats,
                event_started,
                event_completed,
                event_epoch_started,
                event_epoch_completed,
                event_iter_started,
                event_iter_completed,
                event_batch_started,
                event_batch_completed,
            ],
            dim=1,
        ).numpy()

        results_df = pd.DataFrame(
            data=results_dump,
            columns=[
                "epoch",
                "iteration",
                "processing_stats",
                "dataflow_stats",
                "Event_STARTED",
                "Event_COMPLETED",
                "Event_EPOCH_STARTED",
                "Event_EPOCH_COMPLETED",
                "Event_ITERATION_STARTED",
                "Event_ITERATION_COMPLETED",
                "Event_GET_BATCH_STARTED",
                "Event_GET_BATCH_COMPLETED",
            ],
        )
        results_df.to_csv(output_path, index=False)

    @staticmethod
    def print_results(results):
        """
        Method to print the aggregated results from the profiler

        .. code-block:: python

            profiler.print_results(results)

        Example output:

        .. code-block:: text

            ----------------------------------------------------
            | Time profiling stats (in seconds):                |
            ----------------------------------------------------
            total  |  min/index  |  max/index  |  mean  |  std

            Processing function:
            157.46292 | 0.01452/1501 | 0.26905/0 | 0.07730 | 0.01258

            Dataflow:
            6.11384 | 0.00008/1935 | 0.28461/1551 | 0.00300 | 0.02693

            Event handlers:
            2.82721

            - Events.STARTED: []
            0.00000

            - Events.EPOCH_STARTED: []
            0.00006 | 0.00000/0 | 0.00000/17 | 0.00000 | 0.00000

            - Events.ITERATION_STARTED: ['PiecewiseLinear']
            0.03482 | 0.00001/188 | 0.00018/679 | 0.00002 | 0.00001

            - Events.ITERATION_COMPLETED: ['TerminateOnNan']
            0.20037 | 0.00006/866 | 0.00089/1943 | 0.00010 | 0.00003

            - Events.EPOCH_COMPLETED: ['empty_cuda_cache', 'training.<locals>.log_elapsed_time', ]
            2.57860 | 0.11529/0 | 0.14977/13 | 0.12893 | 0.00790

            - Events.COMPLETED: []
            not yet triggered
        """

        def to_str(v):
            # Stats values are either "not yet triggered" strings, (value, index)
            # tuples (min/max), or plain floats.
            if isinstance(v, str):
                return v
            elif isinstance(v, tuple):
                return "{:.5f}/{}".format(v[0], v[1])
            return "{:.5f}".format(v)

        def odict_to_str(d):
            out = " | ".join([to_str(v) for v in d.values()])
            return out

        others = {
            k: odict_to_str(v) if isinstance(v, OrderedDict) else v
            for k, v in results["event_handlers_stats"].items()
        }

        others.update(results["event_handlers_names"])

        output_message = """
----------------------------------------------------
| Time profiling stats (in seconds):                |
----------------------------------------------------
total  |  min/index  |  max/index  |  mean  |  std

Processing function:
{processing_stats}

Dataflow:
{dataflow_stats}

Event handlers:
{total_time:.5f}

- Events.STARTED: {STARTED_names}
{STARTED}

- Events.EPOCH_STARTED: {EPOCH_STARTED_names}
{EPOCH_STARTED}

- Events.ITERATION_STARTED: {ITERATION_STARTED_names}
{ITERATION_STARTED}

- Events.ITERATION_COMPLETED: {ITERATION_COMPLETED_names}
{ITERATION_COMPLETED}

- Events.EPOCH_COMPLETED: {EPOCH_COMPLETED_names}
{EPOCH_COMPLETED}

- Events.COMPLETED: {COMPLETED_names}
{COMPLETED}
""".format(
            processing_stats=odict_to_str(results["processing_stats"]),
            dataflow_stats=odict_to_str(results["dataflow_stats"]),
            **others,
        )
        print(output_message)
        return output_message
class DatasetExperiment(BaseExperiment):
    """Experiment driver wiring a training dataset to an ignite ``Engine``.

    Subclasses must implement :meth:`optimize` (and usually :meth:`infer` and
    :meth:`K_step`); this base class handles trainer construction, per-epoch
    logging, metric writing and timing.

    Fixes over the previous revision: the mutable default argument
    ``gpu_id=[0]`` is replaced by a ``None`` sentinel (same effective default,
    no shared list across instances), and the dead local ``checkpoint = None``
    was removed.
    """

    def __init__(self, train, niter, nepoch, eval=None, gpu_id=None,
                 sacred_run=None, writers=None, root=None, corruption=None,
                 **kwargs):
        """
        Args:
            train: iterable of training batches fed to the trainer engine.
            niter: number of iterations (stored; used by subclasses).
            nepoch: number of epochs to run.
            eval: optional evaluation dataset (name kept for caller
                compatibility although it shadows the builtin).
            gpu_id: list of GPU ids; defaults to ``[0]`` when omitted.
            sacred_run: optional sacred Run; its ``_id`` names the output dir.
            writers: optional writer specs passed to ``init_writers``;
                ignored when ``root`` is None (no output directory exists).
            root: optional output root directory.
            corruption: stored as-is for subclasses.
        """
        super(DatasetExperiment, self).__init__(**kwargs)
        self.train = train
        self.eval = eval
        self.corruption = corruption
        self.niter = niter
        self.nepoch = nepoch
        self.device = 'cuda:0'  # NOTE(review): hard-coded device — confirm intended
        self.sacred_run = sacred_run
        # None sentinel instead of a mutable default list shared across calls.
        self.gpu_id = [0] if gpu_id is None else gpu_id
        if root is not None:
            self.basedir = make_basedir(os.path.join(root, str(sacred_run._id)))
        else:
            # Without an output root there is nowhere to write: disable writers.
            # (self.basedir intentionally stays unset in this case.)
            writers = None
        if writers is not None:
            self.writers = init_writers(*writers, sacred_run=sacred_run, dirname=self.basedir)
        else:
            self.writers = None
        self.create_trainer()

    def create_trainer(self):
        """Build the ignite trainer, attach logging/metric/timing handlers."""
        self.trainer = Engine(self.train_step)
        self.trainer.add_event_handler(Events.EPOCH_COMPLETED, self.K_step)
        self.trainer.add_event_handler(Events.EPOCH_COMPLETED, self.log_metrics, 'train')
        self.trainer.add_event_handler(Events.ITERATION_COMPLETED, self.write_metrics, 'train')
        self.pbar = ProgressBar()
        # Average per-batch time within an epoch; restarted every epoch.
        self.timer = Timer(average=True)
        self.timer.attach(self.trainer,
                          start=Events.EPOCH_STARTED,
                          resume=Events.ITERATION_STARTED,
                          pause=Events.ITERATION_COMPLETED,
                          step=Events.ITERATION_COMPLETED)
        self.trainer.add_event_handler(Events.EPOCH_COMPLETED, self.print_times)

    def print_times(self, engine):
        """Record epoch/batch timings to the writers, then restart the timer."""
        iteration = engine.state.iteration
        if self.writers is not None:
            self.writers.add_scalar('time_per_epoch', self.timer.total, iteration)
            self.writers.add_scalar('time_per_batch', self.timer.value(), iteration)
        # Reset unconditionally so each epoch's numbers are independent.
        self.timer.reset()

    def train_step(self, engine, batch):
        """Single engine update: move the batch to the device and optimize."""
        self.training()
        if isinstance(batch, Mapping):
            return self.optimize(**convert_tensor(batch, self.device))
        else:
            # Only mapping-style batches are supported by this base class.
            raise NotImplementedError

    def infer(self, **kwargs):
        """Inference step; must be provided by subclasses."""
        raise NotImplementedError

    def optimize(self, **kwargs):
        """Optimization step; must be provided by subclasses."""
        raise NotImplementedError

    def K_step(self, **kwargs):
        # NOTE(review): registered as an EPOCH_COMPLETED handler, which calls
        # it with a positional engine argument — subclasses should accept it.
        raise NotImplementedError

    def log_metrics(self, engine, dataset_name):
        """Log current engine metrics through the progress bar."""
        iteration = self.trainer.state.iteration
        epoch = self.trainer.state.epoch
        log = f"EPOCH {epoch}" f" - ITER {iteration} - {dataset_name}"
        if self.sacred_run is not None:
            log = f"ID {self.sacred_run._id} - " + log
        for metric, value in engine.state.metrics.items():
            if value is None:
                value = float('nan')  # keep the line parseable when a metric is missing
            log += f" {metric}: {value:.6f}"
        log += self.extra_log()
        self.pbar.log_message(log)

    def extra_log(self):
        """Hook for subclasses to append text to the per-epoch log line."""
        return ""

    def write_metrics(self, engine, dataset_name):
        """Write each engine metric as ``<dataset>/<metric>`` scalars."""
        if self.writers is not None:
            for metric, value in engine.state.metrics.items():
                name = dataset_name + '/' + metric
                self.writers.add_scalar(name, value, self.trainer.state.iteration)

    def run(self):
        """Run training for ``self.nepoch`` epochs over ``self.train``."""
        self.trainer.run(self.train, max_epochs=self.nepoch)
class HandlersTimeProfiler:
    """
    HandlersTimeProfiler can be used to profile the handlers,
    data loading and data processing times. Custom events are also
    profiled by this profiler

    Examples:

    .. code-block:: python

        from ignite.contrib.handlers import HandlersTimeProfiler

        trainer = Engine(train_updater)

        # Create an object of the profiler and attach an engine to it
        profiler = HandlersTimeProfiler()
        profiler.attach(trainer)

        @trainer.on(Events.EPOCH_COMPLETED)
        def log_intermediate_results():
            profiler.print_results(profiler.get_results())

        trainer.run(dataloader, max_epochs=3)

        profiler.write_results('path_to_dir/time_profiling.csv')
    """

    # Timings below this value are dropped for handlers attached with an event
    # filter (see _create_wrapped_handler).
    # NOTE(review): "THESHOLD" is a typo for "THRESHOLD"; kept because renaming
    # a public class attribute would break external references.
    EVENT_FILTER_THESHOLD_TIME = 0.0001

    def __init__(self) -> None:
        self._dataflow_timer = Timer()
        self._processing_timer = Timer()
        self._event_handlers_timer = Timer()

        self.dataflow_times = []  # type: List[float]
        self.processing_times = []  # type: List[float]
        # event -> handler name -> list of measured durations
        self.event_handlers_times = {}  # type: Dict[EventEnum, Dict[str, List[float]]]

    @staticmethod
    def _get_callable_name(handler: Callable) -> str:
        # get name of the callable handler
        return getattr(handler, "__qualname__", handler.__class__.__name__)

    def _create_wrapped_handler(self, handler: Callable, event: EventEnum) -> Callable:
        """Wrap ``handler`` so each invocation's duration is recorded under ``event``."""
        @functools.wraps(handler)
        def _timeit_handler(*args: Any, **kwargs: Any) -> None:
            self._event_handlers_timer.reset()
            handler(*args, **kwargs)
            t = self._event_handlers_timer.value()
            hname = self._get_callable_name(handler)
            # filter profiled time if the handler was attached to event with event filter
            if not hasattr(handler, "_parent") or t >= self.EVENT_FILTER_THESHOLD_TIME:
                self.event_handlers_times[event][hname].append(t)

        # required to revert back to original handler after profiling
        setattr(_timeit_handler, "_profiler_original", handler)
        return _timeit_handler

    def _timeit_processing(self) -> None:
        # handler used for profiling processing times
        t = self._processing_timer.value()
        self.processing_times.append(t)

    def _timeit_dataflow(self) -> None:
        # handler used for profiling dataflow times
        t = self._dataflow_timer.value()
        self.dataflow_times.append(t)

    def _reset(self, event_handlers_names: Mapping[EventEnum, List[str]]) -> None:
        # reset the variables used for profiling
        self.dataflow_times = []
        self.processing_times = []
        self.event_handlers_times = {e: {h: [] for h in event_handlers_names[e]} for e in event_handlers_names}

    @staticmethod
    def _is_internal_handler(handler: Callable) -> bool:
        # checks whether the handler is internal (profiler's own or a Timer's)
        return any(n in repr(handler) for n in ["HandlersTimeProfiler.", "Timer."])

    def _detach_profiler_handlers(self, engine: Engine) -> None:
        # reverts handlers to original handlers
        for e in engine._event_handlers:
            for i, (func, args, kwargs) in enumerate(engine._event_handlers[e]):
                if hasattr(func, "_profiler_original"):
                    engine._event_handlers[e][i] = (func._profiler_original, args, kwargs)

    def _as_first_started(self, engine: Engine) -> None:
        # wraps original handlers for profiling
        self.event_handlers_names = {
            e: [
                self._get_callable_name(h)
                for (h, _, _) in engine._event_handlers[e]
                if not self._is_internal_handler(h)
            ]
            for e in engine._allowed_events
        }

        self._reset(self.event_handlers_names)

        # Replace each user handler in place with its timing wrapper.
        for e in engine._allowed_events:
            for i, (func, args, kwargs) in enumerate(engine._event_handlers[e]):
                if not self._is_internal_handler(func):
                    engine._event_handlers[e][i] = (self._create_wrapped_handler(func, e), args, kwargs)

        # processing timer
        engine.add_event_handler(Events.ITERATION_STARTED, self._processing_timer.reset)
        # Inserted at index 0 so processing time excludes other ITERATION_COMPLETED handlers.
        engine._event_handlers[Events.ITERATION_COMPLETED].insert(0, (self._timeit_processing, (), {}))

        # dataflow timer
        engine.add_event_handler(Events.GET_BATCH_STARTED, self._dataflow_timer.reset)
        engine._event_handlers[Events.GET_BATCH_COMPLETED].insert(0, (self._timeit_dataflow, (), {}))

        # revert back the wrapped handlers with original handlers at the end
        engine.add_event_handler(Events.COMPLETED, self._detach_profiler_handlers)

    def attach(self, engine: Engine) -> None:
        """Attach HandlersTimeProfiler to the given engine.

        Args:
            engine: the instance of Engine to attach
        """
        if not isinstance(engine, Engine):
            raise TypeError(f"Argument engine should be ignite.engine.Engine, but given {type(engine)}")
        if not engine.has_event_handler(self._as_first_started):
            # Position 0: must run before any user STARTED handler is wrapped.
            engine._event_handlers[Events.STARTED].insert(0, (self._as_first_started, (engine, ), {}))

    def get_results(self) -> List[List[Union[str, float]]]:
        """
        Method to fetch the aggregated profiler results after the engine is run

        .. code-block:: python

            results = profiler.get_results()
        """
        total_eh_time = sum([
            sum(self.event_handlers_times[e][h])
            for e in self.event_handlers_times
            for h in self.event_handlers_times[e]
        ])
        total_eh_time = round(
            float(total_eh_time),
            5,
        )

        def compute_basic_stats(
            times: Union[Sequence, torch.Tensor]
        ) -> List[Union[str, float, Tuple[Union[str, float], Union[str, float]]]]:
            # Returns [total, (min, idx), (max, idx), mean, std]; string
            # placeholders when there is not enough non-zero data.
            data = torch.as_tensor(times, dtype=torch.float32)
            # compute on non-zero data:
            data = data[data > 0]
            total = round(torch.sum(data).item(), 5) if len(data) > 0 else "not triggered"  # type: Union[str, float]
            min_index = ("None", "None")  # type: Tuple[Union[str, float], Union[str, float]]
            max_index = ("None", "None")  # type: Tuple[Union[str, float], Union[str, float]]
            mean = "None"  # type: Union[str, float]
            std = "None"  # type: Union[str, float]
            if len(data) > 0:
                min_index = (round(torch.min(data).item(), 5), torch.argmin(data).item())
                max_index = (round(torch.max(data).item(), 5), torch.argmax(data).item())
                mean = round(torch.mean(data).item(), 5)
                if len(data) > 1:
                    # std is undefined for a single sample
                    std = round(torch.std(data).item(), 5)
            return [total, min_index, max_index, mean, std]

        event_handler_stats = [
            [
                h,
                getattr(e, "name", str(e)),
                *compute_basic_stats(torch.tensor(self.event_handlers_times[e][h], dtype=torch.float32)),
            ]
            for e in self.event_handlers_times
            for h in self.event_handlers_times[e]
        ]
        # Three trailing summary rows consumed positionally by print_results.
        event_handler_stats.append(["Total", "", total_eh_time, "", "", "", ""])
        event_handler_stats.append(["Processing", "None", *compute_basic_stats(self.processing_times)])
        event_handler_stats.append(["Dataflow", "None", *compute_basic_stats(self.dataflow_times)])

        return event_handler_stats

    def write_results(self, output_path: str) -> None:
        """
        Method to store the unaggregated profiling results to a csv file

        Args:
            output_path: file output path containing a filename

        .. code-block:: python

            profiler.write_results('path_to_dir/awesome_filename.csv')

        Example output:

        .. code-block:: text

            -----------------------------------------------------------------
            # processing_stats dataflow_stats training.<locals>.log_elapsed_time (EPOCH_COMPLETED) ...
            1     0.00003         0.252387                  0.125676
            2     0.00029         0.252342                  0.125123
        """
        try:
            import pandas as pd
        except ImportError:
            raise RuntimeError("Need pandas to write results as files")

        processing_stats = torch.tensor(self.processing_times, dtype=torch.float32)
        dataflow_stats = torch.tensor(self.dataflow_times, dtype=torch.float32)

        cols = [processing_stats, dataflow_stats]
        headers = ["processing_stats", "dataflow_stats"]
        for e in self.event_handlers_times:
            for h in self.event_handlers_times[e]:
                headers.append(f"{h} ({getattr(e, 'name', str(e))})")
                cols.append(torch.tensor(self.event_handlers_times[e][h], dtype=torch.float32))

        # Determine maximum length
        max_len = max([x.numel() for x in cols])

        count_col = torch.arange(max_len, dtype=torch.float32) + 1
        cols.insert(0, count_col)
        headers.insert(0, "#")

        # pad all tensors to have same length
        cols = [torch.nn.functional.pad(x, pad=(0, max_len - x.numel()), mode="constant", value=0) for x in cols]

        results_dump = torch.stack(
            cols,
            dim=1,
        ).numpy()

        results_df = pd.DataFrame(
            data=results_dump,
            columns=headers,
        )
        results_df.to_csv(output_path, index=False)

    @staticmethod
    def print_results(results: List[List[Union[str, float]]]) -> None:
        """
        Method to print the aggregated results from the profiler

        Args:
            results: the aggregated results from the profiler

        .. code-block:: python

            profiler.print_results(results)

        Example output:

        .. code-block:: text

            -----------------------------------------  -----------------------  -------------- ...
            Handler                                    Event Name                     Total(s)
            -----------------------------------------  -----------------------  --------------
            run.<locals>.log_training_results          EPOCH_COMPLETED                19.43245
            run.<locals>.log_validation_results        EPOCH_COMPLETED                 2.55271
            run.<locals>.log_time                      EPOCH_COMPLETED                 0.00049
            run.<locals>.log_intermediate_results      EPOCH_COMPLETED                 0.00106
            run.<locals>.log_training_loss             ITERATION_COMPLETED               0.059
            run.<locals>.log_time                      COMPLETED                 not triggered
            -----------------------------------------  -----------------------  --------------
            Total                                                                     22.04571
            -----------------------------------------  -----------------------  --------------
            Processing took total 11.29543s [min/index: 0.00393s/1875, max/index: 0.00784s/0,
             mean: 0.00602s, std: 0.00034s]
            Dataflow took total 16.24365s [min/index: 0.00533s/1874, max/index: 0.01129s/937,
             mean: 0.00866s, std: 0.00113s]
        """
        # adopted implementation of torch.autograd.profiler.build_table
        handler_column_width = max([len(item[0]) for item in results]) + 4  # type: ignore[arg-type]
        event_column_width = max([len(item[1]) for item in results]) + 4  # type: ignore[arg-type]

        DEFAULT_COLUMN_WIDTH = 14

        headers = [
            "Handler",
            "Event Name",
            "Total(s)",
            "Min(s)/IDX",
            "Max(s)/IDX",
            "Mean(s)",
            "Std(s)",
        ]

        # Have to use a list because nonlocal is Py3 only...
        SPACING_SIZE = 2
        row_format_lst = [""]
        header_sep_lst = [""]
        line_length_lst = [-SPACING_SIZE]

        def add_column(padding: int, text_dir: str = ">") -> None:
            # Appends one fixed-width column to the row/separator format strings.
            row_format_lst[0] += "{: " + text_dir + str(padding) + "}" + (" " * SPACING_SIZE)
            header_sep_lst[0] += "-" * padding + (" " * SPACING_SIZE)
            line_length_lst[0] += padding + SPACING_SIZE

        add_column(handler_column_width, text_dir="<")
        add_column(event_column_width, text_dir="<")
        for _ in headers[2:]:
            add_column(DEFAULT_COLUMN_WIDTH)

        row_format = row_format_lst[0]
        header_sep = header_sep_lst[0]

        result = []

        def append(s: str) -> None:
            result.append(s)
            result.append("\n")

        result.append("\n")
        append(header_sep)
        append(row_format.format(*headers))
        append(header_sep)
        # All rows except the last three (Total / Processing / Dataflow summaries).
        for row in results[:-3]:
            # format min/idx and max/idx
            row[3] = "{}/{}".format(*row[3])  # type: ignore[misc]
            row[4] = "{}/{}".format(*row[4])  # type: ignore[misc]
            append(row_format.format(*row))

        append(header_sep)
        # print total handlers time row
        append(row_format.format(*results[-3]))
        append(header_sep)
        summary_format = "{} took total {}s [min/index: {}, max/index: {}, mean: {}s, std: {}s]"
        for row in results[-2:]:
            row[3] = "{}s/{}".format(*row[3])  # type: ignore[misc]
            row[4] = "{}s/{}".format(*row[4])  # type: ignore[misc]
            # NOTE: mutates the caller's rows (drops the event-name column).
            del row[1]
            append(summary_format.format(*row))
        print("".join(result))
class BasicTimeProfiler: """ BasicTimeProfiler can be used to profile the handlers, events, data loading and data processing times. Examples: .. code-block:: python # # Create an object of the profiler and attach an engine to it # profiler = BasicTimeProfiler() trainer = Engine(train_updater) profiler.attach(trainer) trainer.run(dataloader, max_epochs=3) @trainer.on(Events.EPOCH_COMPLETED) def log_intermediate_results(): profiler.print_results(profiler.get_results()) profiler.write_results('path_to_dir/time_profiling.csv') """ events_to_ignore = [ Events.EXCEPTION_RAISED, Events.TERMINATE, Events.TERMINATE_SINGLE_EPOCH, Events.DATALOADER_STOP_ITERATION, ] def __init__(self): self._dataflow_timer = Timer() self._processing_timer = Timer() self._event_handlers_timer = Timer() self.dataflow_times = None self.processing_times = None self.event_handlers_times = None self._events = [ Events.EPOCH_STARTED, Events.EPOCH_COMPLETED, Events.ITERATION_STARTED, Events.ITERATION_COMPLETED, Events.GET_BATCH_STARTED, Events.GET_BATCH_COMPLETED, Events.COMPLETED, ] self._fmethods = [ self._as_first_epoch_started, self._as_first_epoch_completed, self._as_first_iter_started, self._as_first_iter_completed, self._as_first_get_batch_started, self._as_first_get_batch_completed, self._as_first_completed, ] self._lmethods = [ self._as_last_epoch_started, self._as_last_epoch_completed, self._as_last_iter_started, self._as_last_iter_completed, self._as_last_get_batch_started, self._as_last_get_batch_completed, self._as_last_completed, ] def _reset(self, num_epochs, total_num_iters): self.dataflow_times = torch.zeros(total_num_iters) self.processing_times = torch.zeros(total_num_iters) self.event_handlers_times = { Events.STARTED: torch.zeros(1), Events.COMPLETED: torch.zeros(1), Events.EPOCH_STARTED: torch.zeros(num_epochs), Events.EPOCH_COMPLETED: torch.zeros(num_epochs), Events.ITERATION_STARTED: torch.zeros(total_num_iters), Events.ITERATION_COMPLETED: torch.zeros(total_num_iters), 
Events.GET_BATCH_COMPLETED: torch.zeros(total_num_iters), Events.GET_BATCH_STARTED: torch.zeros(total_num_iters), } def _as_first_started(self, engine): if hasattr(engine.state.dataloader, "__len__"): num_iters_per_epoch = len(engine.state.dataloader) else: num_iters_per_epoch = engine.state.epoch_length self.max_epochs = engine.state.max_epochs self.total_num_iters = self.max_epochs * num_iters_per_epoch self._reset(self.max_epochs, self.total_num_iters) self.event_handlers_names = { e: [ h.__qualname__ if hasattr(h, "__qualname__") else h.__class__.__name__ for (h, _, _) in engine._event_handlers[e] if "BasicTimeProfiler." not in repr( h) # avoid adding internal handlers into output ] for e in Events if e not in self.events_to_ignore } # Setup all other handlers: engine._event_handlers[Events.STARTED].append( (self._as_last_started, (engine, ), {})) for e, m in zip(self._events, self._fmethods): engine._event_handlers[e].insert(0, (m, (engine, ), {})) for e, m in zip(self._events, self._lmethods): engine._event_handlers[e].append((m, (engine, ), {})) # Let's go self._event_handlers_timer.reset() def _as_last_started(self, engine): self.event_handlers_times[ Events.STARTED][0] = self._event_handlers_timer.value() def _as_first_epoch_started(self, engine): self._event_handlers_timer.reset() def _as_last_epoch_started(self, engine): t = self._event_handlers_timer.value() e = engine.state.epoch - 1 self.event_handlers_times[Events.EPOCH_STARTED][e] = t def _as_first_get_batch_started(self, engine): self._event_handlers_timer.reset() self._dataflow_timer.reset() def _as_last_get_batch_started(self, engine): t = self._event_handlers_timer.value() i = engine.state.iteration - 1 self.event_handlers_times[Events.GET_BATCH_STARTED][i] = t def _as_first_get_batch_completed(self, engine): self._event_handlers_timer.reset() def _as_last_get_batch_completed(self, engine): t = self._event_handlers_timer.value() i = engine.state.iteration - 1 
        # Tail of _as_last_get_batch_completed (its head is above this chunk):
        # record handler time for GET_BATCH_COMPLETED, then the dataflow time
        # for this iteration, and restart the dataflow timer for the next fetch.
        self.event_handlers_times[Events.GET_BATCH_COMPLETED][i] = t

        d = self._dataflow_timer.value()
        self.dataflow_times[i] = d

        self._dataflow_timer.reset()

    def _as_first_iter_started(self, engine):
        # First ITERATION_STARTED handler: start timing the user handlers
        # registered for this event.
        self._event_handlers_timer.reset()

    def _as_last_iter_started(self, engine):
        # Last ITERATION_STARTED handler: store the accumulated handler time,
        # then start the processing timer (stopped by the first
        # ITERATION_COMPLETED handler below).
        t = self._event_handlers_timer.value()
        i = engine.state.iteration - 1
        self.event_handlers_times[Events.ITERATION_STARTED][i] = t

        self._processing_timer.reset()

    def _as_first_iter_completed(self, engine):
        # First ITERATION_COMPLETED handler: the processing timer now holds
        # the time spent in the user's process function.
        t = self._processing_timer.value()
        i = engine.state.iteration - 1
        self.processing_times[i] = t

        self._event_handlers_timer.reset()

    def _as_last_iter_completed(self, engine):
        t = self._event_handlers_timer.value()
        i = engine.state.iteration - 1
        self.event_handlers_times[Events.ITERATION_COMPLETED][i] = t

    def _as_first_epoch_completed(self, engine):
        self._event_handlers_timer.reset()

    def _as_last_epoch_completed(self, engine):
        t = self._event_handlers_timer.value()
        e = engine.state.epoch - 1
        self.event_handlers_times[Events.EPOCH_COMPLETED][e] = t

    def _as_first_completed(self, engine):
        self._event_handlers_timer.reset()

    def _as_last_completed(self, engine):
        self.event_handlers_times[
            Events.COMPLETED][0] = self._event_handlers_timer.value()

        # Remove added handlers:
        engine.remove_event_handler(self._as_last_started, Events.STARTED)

        for e, m in zip(self._events, self._fmethods):
            engine.remove_event_handler(m, e)

        for e, m in zip(self._events, self._lmethods):
            engine.remove_event_handler(m, e)

    def attach(self, engine):
        """Attach the profiler to `engine`.

        The real instrumentation is installed lazily by `_as_first_started`
        once the run begins; here we only register that bootstrap handler.

        :param engine: ignite Engine to profile
        :raises TypeError: if `engine` is not an ignite Engine
        """
        if not isinstance(engine, Engine):
            raise TypeError("Argument engine should be ignite.engine.Engine, "
                            "but given {}".format(type(engine)))
        if not engine.has_event_handler(self._as_first_started):
            # Inserted at position 0 so it runs before any user STARTED
            # handler; this ordering is what makes the measurements correct.
            engine._event_handlers[Events.STARTED].insert(
                0, (self._as_first_started, (engine, ), {}))

    @staticmethod
    def _compute_basic_stats(data):
        # compute on non-zero data:
        data = data[data > 0]
        out = [
            ("total",
             torch.sum(data).item() if len(data) > 0 else "not yet triggered")
        ]
        # min/max/mean/std only reported with at least two samples (std of a
        # single sample would be nan).
        if len(data) > 1:
            out += [
                ("min/index", (torch.min(data).item(),
                               torch.argmin(data).item())),
                ("max/index", (torch.max(data).item(),
                               torch.argmax(data).item())),
                ("mean", torch.mean(data).item()),
                ("std", torch.std(data).item()),
            ]
        return OrderedDict(out)

    def get_results(self):
        """
        Method to fetch the aggregated profiler results after the engine is run

        .. code-block:: python

            results = profiler.get_results()

        """
        total_eh_time = sum([(self.event_handlers_times[e]).sum()
                             for e in Events
                             if e not in self.events_to_ignore])

        return OrderedDict([
            ("processing_stats",
             self._compute_basic_stats(self.processing_times)),
            ("dataflow_stats", self._compute_basic_stats(self.dataflow_times)),
            (
                "event_handlers_stats",
                dict([(str(e.name).replace(".", "_"),
                       self._compute_basic_stats(self.event_handlers_times[e]))
                      for e in Events if e not in self.events_to_ignore] +
                     [("total_time", total_eh_time)]),
            ),
            (
                "event_handlers_names",
                {
                    str(e.name).replace(".", "_") + "_names": v
                    for e, v in self.event_handlers_names.items()
                },
            ),
        ])

    def write_results(self, output_path):
        """
        Method to store the unaggregated profiling results to a csv file

        .. code-block:: python

            profiler.write_results('path_to_dir/awesome_filename.csv')

        Example output:

        .. code-block:: text

            -----------------------------------------------------------------
            epoch iteration processing_stats dataflow_stats Event_STARTED ...
            1.0     1.0     0.00003         0.252387        0.125676
            1.0     2.0     0.00029         0.252342        0.125123
        """
        try:
            import pandas as pd
        except ImportError:
            print("Need pandas to write results as files")
            return

        # One CSV row per iteration; per-run and per-epoch timings are
        # broadcast across their iterations via repeat_interleave.
        iters_per_epoch = self.total_num_iters // self.max_epochs
        epochs = torch.arange(
            self.max_epochs,
            dtype=torch.float32).repeat_interleave(iters_per_epoch) + 1
        iterations = torch.arange(self.total_num_iters,
                                  dtype=torch.float32) + 1
        processing_stats = self.processing_times
        dataflow_stats = self.dataflow_times

        event_started = self.event_handlers_times[
            Events.STARTED].repeat_interleave(self.total_num_iters)
        event_completed = self.event_handlers_times[
            Events.COMPLETED].repeat_interleave(self.total_num_iters)
        event_epoch_started = self.event_handlers_times[
            Events.EPOCH_STARTED].repeat_interleave(iters_per_epoch)
        event_epoch_completed = self.event_handlers_times[
            Events.EPOCH_COMPLETED].repeat_interleave(iters_per_epoch)

        event_iter_started = self.event_handlers_times[
            Events.ITERATION_STARTED]
        event_iter_completed = self.event_handlers_times[
            Events.ITERATION_COMPLETED]
        event_batch_started = self.event_handlers_times[
            Events.GET_BATCH_STARTED]
        event_batch_completed = self.event_handlers_times[
            Events.GET_BATCH_COMPLETED]

        results_dump = torch.stack(
            [
                epochs,
                iterations,
                processing_stats,
                dataflow_stats,
                event_started,
                event_completed,
                event_epoch_started,
                event_epoch_completed,
                event_iter_started,
                event_iter_completed,
                event_batch_started,
                event_batch_completed,
            ],
            dim=1,
        ).numpy()

        results_df = pd.DataFrame(
            data=results_dump,
            columns=[
                "epoch",
                "iteration",
                "processing_stats",
                "dataflow_stats",
                "Event_STARTED",
                "Event_COMPLETED",
                "Event_EPOCH_STARTED",
                "Event_EPOCH_COMPLETED",
                "Event_ITERATION_STARTED",
                "Event_ITERATION_COMPLETED",
                "Event_GET_BATCH_STARTED",
                "Event_GET_BATCH_COMPLETED",
            ],
        )
        results_df.to_csv(output_path, index=False)

    @staticmethod
    def print_results(results):
        """
        Method to print the aggregated results from the profiler

        .. code-block:: python

            profiler.print_results(results)

        Example output:

        .. code-block:: text

            --------------------------------------------
            - Time profiling results:
            --------------------------------------------

            Processing function time stats (in seconds):
                min/index: (1.3081999895803165e-05, 1)
                max/index: (1.433099987480091e-05, 0)
                mean: 1.3706499885302037e-05
                std: 8.831763693706307e-07
                total: 2.7412999770604074e-05

            Dataflow time stats (in seconds):
                min/index: (5.199999941396527e-05, 0)
                max/index: (0.00010925399692496285, 1)
                mean: 8.062699635047466e-05
                std: 4.048469054396264e-05
                total: 0.0001612539927009493

            Time stats of event handlers (in seconds):
            - Total time spent:
                1.0080009698867798
            - Events.STARTED:
                total: 0.1256754994392395
                Handlers names:
                ['delay_start']

            --------------------------------------------
        """
        def odict_to_str(d):
            out = ""
            for k, v in d.items():
                out += "\t{}: {}\n".format(k, v)
            return out

        # Flatten per-event stats and handler names into one mapping so they
        # can be spliced into the template via **others.
        others = {
            k: odict_to_str(v) if isinstance(v, OrderedDict) else v
            for k, v in results["event_handlers_stats"].items()
        }
        others.update(results["event_handlers_names"])

        output_message = """
--------------------------------------------
- Time profiling results:
--------------------------------------------
Processing function time stats (in seconds):
{processing_stats}
Dataflow time stats (in seconds):
{dataflow_stats}
Time stats of event handlers (in seconds):
- Total time spent:
\t{total_time}
- Events.STARTED:
{STARTED}
Handlers names:
{STARTED_names}
- Events.EPOCH_STARTED:
{EPOCH_STARTED}
Handlers names:
{EPOCH_STARTED_names}
- Events.ITERATION_STARTED:
{ITERATION_STARTED}
Handlers names:
{ITERATION_STARTED_names}
- Events.ITERATION_COMPLETED:
{ITERATION_COMPLETED}
Handlers names:
{ITERATION_COMPLETED_names}
- Events.EPOCH_COMPLETED:
{EPOCH_COMPLETED}
Handlers names:
{EPOCH_COMPLETED_names}
- Events.COMPLETED:
{COMPLETED}
Handlers names:
{COMPLETED_names}
""".format(
            processing_stats=odict_to_str(results["processing_stats"]),
            dataflow_stats=odict_to_str(results["dataflow_stats"]),
            **others,
        )
        print(output_message)
        return output_message
class BasicTimeProfiler(object):
    """
    BasicTimeProfiler can be used to profile the handlers,
    events, data loading and data processing times.

    Examples:

    .. code-block:: python

        #
        # Create an object of the profiler and attach an engine to it
        #
        profiler = BasicTimeProfiler()
        trainer = Engine(train_updater)
        profiler.attach(trainer)
        trainer.run(dataloader, max_epochs=3)
    """
    # NOTE(review): this module defines BasicTimeProfiler more than once;
    # later definitions shadow earlier ones at import time — confirm which
    # version is intended to be public.

    def __init__(self):
        # Three independent timers: data loading, the process function, and
        # the user event handlers.
        self._dataflow_timer = Timer()
        self._processing_timer = Timer()
        self._event_handlers_timer = Timer()

        # Filled in by _reset() once the run size is known.
        self.dataflow_times = None
        self.processing_times = None
        self.event_handlers_times = None

        # Events instrumented with a first/last handler pair; the three lists
        # are kept in lockstep (zipped together on install and removal).
        self._events = [
            Events.EPOCH_STARTED, Events.EPOCH_COMPLETED,
            Events.ITERATION_STARTED, Events.ITERATION_COMPLETED,
            Events.GET_BATCH_STARTED, Events.GET_BATCH_COMPLETED,
            Events.COMPLETED
        ]
        self._fmethods = [
            self._as_first_epoch_started, self._as_first_epoch_completed,
            self._as_first_iter_started, self._as_first_iter_completed,
            self._as_first_get_batch_started,
            self._as_first_get_batch_completed, self._as_first_completed
        ]
        self._lmethods = [
            self._as_last_epoch_started, self._as_last_epoch_completed,
            self._as_last_iter_started, self._as_last_iter_completed,
            self._as_last_get_batch_started,
            self._as_last_get_batch_completed, self._as_last_completed
        ]

    def _reset(self, num_epochs, total_num_iters):
        # Pre-allocate one slot per epoch / iteration for every tracked event.
        self.dataflow_times = torch.zeros(total_num_iters)
        self.processing_times = torch.zeros(total_num_iters)
        self.event_handlers_times = {
            Events.STARTED: torch.zeros(1),
            Events.COMPLETED: torch.zeros(1),
            Events.EPOCH_STARTED: torch.zeros(num_epochs),
            Events.EPOCH_COMPLETED: torch.zeros(num_epochs),
            Events.ITERATION_STARTED: torch.zeros(total_num_iters),
            Events.ITERATION_COMPLETED: torch.zeros(total_num_iters),
            Events.GET_BATCH_COMPLETED: torch.zeros(total_num_iters),
            Events.GET_BATCH_STARTED: torch.zeros(total_num_iters)
        }

    def _as_first_started(self, engine):
        # Runs as the very first STARTED handler (installed by attach()).
        # Sizes the buffers from the dataloader, then installs the first/last
        # handler pairs directly into engine._event_handlers so they bracket
        # every user handler.
        if hasattr(engine.state.dataloader, "__len__"):
            num_iters_per_epoch = len(engine.state.dataloader)
        else:
            # Fall back to the explicitly configured epoch length when the
            # dataloader has no __len__.
            num_iters_per_epoch = engine.state.epoch_length

        self.max_epochs = engine.state.max_epochs
        self.total_num_iters = self.max_epochs * num_iters_per_epoch
        self._reset(self.max_epochs, self.total_num_iters)

        # Snapshot human-readable names of the handlers already registered.
        self.event_handlers_names = {
            e: [
                h.__qualname__ if hasattr(h, "__qualname__")
                else h.__class__.__name__
                for (h, _, _) in engine._event_handlers[e]
            ]
            for e in Events if e != Events.EXCEPTION_RAISED
        }

        # Setup all other handlers:
        engine._event_handlers[Events.STARTED].append(
            (self._as_last_started, (), {}))

        # "first" handlers are inserted at index 0, "last" handlers appended —
        # this ordering is what makes the per-event timings correct.
        for e, m in zip(self._events, self._fmethods):
            engine._event_handlers[e].insert(0, (m, (), {}))

        for e, m in zip(self._events, self._lmethods):
            engine._event_handlers[e].append((m, (), {}))

        # Let's go
        self._event_handlers_timer.reset()

    def _as_last_started(self, engine):
        self.event_handlers_times[
            Events.STARTED][0] = self._event_handlers_timer.value()

    def _as_first_epoch_started(self, engine):
        self._event_handlers_timer.reset()

    def _as_last_epoch_started(self, engine):
        t = self._event_handlers_timer.value()
        e = engine.state.epoch - 1
        self.event_handlers_times[Events.EPOCH_STARTED][e] = t

    def _as_first_get_batch_started(self, engine):
        # Dataflow timing starts when the batch fetch begins.
        self._event_handlers_timer.reset()
        self._dataflow_timer.reset()

    def _as_last_get_batch_started(self, engine):
        t = self._event_handlers_timer.value()
        i = engine.state.iteration - 1
        self.event_handlers_times[Events.GET_BATCH_STARTED][i] = t

    def _as_first_get_batch_completed(self, engine):
        self._event_handlers_timer.reset()

    def _as_last_get_batch_completed(self, engine):
        t = self._event_handlers_timer.value()
        i = engine.state.iteration - 1
        self.event_handlers_times[Events.GET_BATCH_COMPLETED][i] = t

        d = self._dataflow_timer.value()
        self.dataflow_times[i] = d

        self._dataflow_timer.reset()

    def _as_first_iter_started(self, engine):
        self._event_handlers_timer.reset()

    def _as_last_iter_started(self, engine):
        t = self._event_handlers_timer.value()
        i = engine.state.iteration - 1
        self.event_handlers_times[Events.ITERATION_STARTED][i] = t

        # Processing timer runs from here until the first ITERATION_COMPLETED
        # handler, i.e. across the user's process function.
        self._processing_timer.reset()

    def _as_first_iter_completed(self, engine):
        t = self._processing_timer.value()
        i = engine.state.iteration - 1
        self.processing_times[i] = t

        self._event_handlers_timer.reset()

    def _as_last_iter_completed(self, engine):
        t = self._event_handlers_timer.value()
        i = engine.state.iteration - 1
        self.event_handlers_times[Events.ITERATION_COMPLETED][i] = t

    def _as_first_epoch_completed(self, engine):
        self._event_handlers_timer.reset()

    def _as_last_epoch_completed(self, engine):
        t = self._event_handlers_timer.value()
        e = engine.state.epoch - 1
        self.event_handlers_times[Events.EPOCH_COMPLETED][e] = t

    def _as_first_completed(self, engine):
        self._event_handlers_timer.reset()

    def _as_last_completed(self, engine):
        self.event_handlers_times[
            Events.COMPLETED][0] = self._event_handlers_timer.value()

        # Remove added handlers:
        engine.remove_event_handler(self._as_last_started, Events.STARTED)

        for e, m in zip(self._events, self._fmethods):
            engine.remove_event_handler(m, e)

        for e, m in zip(self._events, self._lmethods):
            engine.remove_event_handler(m, e)

    def attach(self, engine):
        """Register the profiler's bootstrap handler on `engine`.

        :param engine: ignite Engine to profile
        :raises TypeError: if `engine` is not an ignite Engine
        """
        if not isinstance(engine, Engine):
            raise TypeError("Argument engine should be ignite.engine.Engine, "
                            "but given {}".format(type(engine)))
        if not engine.has_event_handler(self._as_first_started):
            # Must run before every user STARTED handler, hence insert at 0.
            engine._event_handlers[Events.STARTED]\
                .insert(0, (self._as_first_started, (), {}))

    @staticmethod
    def _compute_basic_stats(data):
        # NOTE(review): with a single sample torch.std() yields nan (see the
        # example output in print_results); untriggered events contribute
        # zeros to these stats.
        return OrderedDict([
            ('min/index', (torch.min(data).item(),
                           torch.argmin(data).item())),
            ('max/index', (torch.max(data).item(),
                           torch.argmax(data).item())),
            ('mean', torch.mean(data).item()),
            ('std', torch.std(data).item()),
            ('total', torch.sum(data).item())
        ])

    def get_results(self):
        """
        Method to fetch the aggregated profiler results after the engine is run

        .. code-block:: python

            results = profiler.get_results()

        """
        events_to_ignore = [Events.EXCEPTION_RAISED]
        total_eh_time = sum([
            sum(self.event_handlers_times[e])
            for e in Events if e not in events_to_ignore
        ])
        return OrderedDict([
            ("processing_stats",
             self._compute_basic_stats(self.processing_times)),
            ("dataflow_stats", self._compute_basic_stats(self.dataflow_times)),
            ("event_handlers_stats",
             dict([(str(e).replace(".", "_"),
                    self._compute_basic_stats(self.event_handlers_times[e]))
                   for e in Events if e not in events_to_ignore] +
                  [("total_time", total_eh_time)])),
            ("event_handlers_names", {
                str(e).replace(".", "_") + "_names": v
                for e, v in self.event_handlers_names.items()
            })
        ])

    def write_results(self, output_path):
        """
        Method to store the unaggregated profiling results to a csv file

        .. code-block:: python

            profiler.write_results('path_to_dir/awesome_filename.csv')

        Example output:

        .. code-block:: text

            -----------------------------------------------------------------
            epoch iteration processing_stats dataflow_stats Event_STARTED ...
            1.0     1.0     0.00003         0.252387        0.125676
            1.0     2.0     0.00029         0.252342        0.125123

        """
        try:
            import pandas as pd
        except ImportError:
            print("Need pandas to write results as files")
            return

        # One row per iteration; run-level and epoch-level timings are
        # broadcast over their iterations with repeat_interleave.
        iters_per_epoch = self.total_num_iters // self.max_epochs
        epochs = torch.arange(self.max_epochs, dtype=torch.float32)\
            .repeat_interleave(iters_per_epoch) + 1
        iterations = torch.arange(self.total_num_iters,
                                  dtype=torch.float32) + 1
        processing_stats = self.processing_times
        dataflow_stats = self.dataflow_times

        event_started = self.event_handlers_times[Events.STARTED]\
            .repeat_interleave(self.total_num_iters)
        event_completed = self.event_handlers_times[Events.COMPLETED]\
            .repeat_interleave(self.total_num_iters)
        event_epoch_started = self.event_handlers_times[Events.EPOCH_STARTED]\
            .repeat_interleave(iters_per_epoch)
        event_epoch_completed = \
            self.event_handlers_times[Events.EPOCH_COMPLETED]\
            .repeat_interleave(iters_per_epoch)

        event_iter_started = self.event_handlers_times[
            Events.ITERATION_STARTED]
        event_iter_completed = self.event_handlers_times[
            Events.ITERATION_COMPLETED]
        event_batch_started = self.event_handlers_times[
            Events.GET_BATCH_STARTED]
        event_batch_completed = self.event_handlers_times[
            Events.GET_BATCH_COMPLETED]

        results_dump = torch.stack([
            epochs, iterations, processing_stats, dataflow_stats,
            event_started, event_completed, event_epoch_started,
            event_epoch_completed, event_iter_started, event_iter_completed,
            event_batch_started, event_batch_completed
        ], dim=1).numpy()

        results_df = pd.DataFrame(
            data=results_dump,
            columns=[
                'epoch', 'iteration', 'processing_stats', 'dataflow_stats',
                'Event_STARTED', 'Event_COMPLETED', 'Event_EPOCH_STARTED',
                'Event_EPOCH_COMPLETED', 'Event_ITERATION_STARTED',
                'Event_ITERATION_COMPLETED', 'Event_GET_BATCH_STARTED',
                'Event_GET_BATCH_COMPLETED'
            ])
        results_df.to_csv(output_path, index=False)

    @staticmethod
    def print_results(results):
        """
        Method to print the aggregated results from the profiler

        .. code-block:: python

            profiler.print_results(results)

        Example output:

        .. code-block:: text

            --------------------------------------------
            - Time profiling results:
            --------------------------------------------

            Processing function time stats (in seconds):
                min/index: (2.9754010029137135e-05, 0)
                max/index: (2.9754010029137135e-05, 0)
                mean: 2.9754010029137135e-05
                std: nan
                total: 2.9754010029137135e-05

            Dataflow time stats (in seconds):
                min/index: (0.2523871660232544, 0)
                max/index: (0.2523871660232544, 0)
                mean: 0.2523871660232544
                std: nan
                total: 0.2523871660232544

            Time stats of event handlers (in seconds):
            - Total time spent:
                1.0080009698867798
            - Events.STARTED:
                min/index: (0.1256754994392395, 0)
                max/index: (0.1256754994392395, 0)
                mean: 0.1256754994392395
                std: nan
                total: 0.1256754994392395
                Handlers names:
                ['BasicTimeProfiler._as_first_started', 'delay_start']

            --------------------------------------------
        """
        def odict_to_str(d):
            out = ""
            for k, v in d.items():
                out += "\t{}: {}\n".format(k, v)
            return out

        # Flatten the per-event stats and handler-name lists into one mapping
        # so they can be spliced into the template via **others.
        others = {
            k: odict_to_str(v) if isinstance(v, OrderedDict) else v
            for k, v in results['event_handlers_stats'].items()
        }
        others.update(results['event_handlers_names'])

        output_message = """
--------------------------------------------
- Time profiling results:
--------------------------------------------
Processing function time stats (in seconds):
{processing_stats}
Dataflow time stats (in seconds):
{dataflow_stats}
Time stats of event handlers (in seconds):
- Total time spent:
\t{total_time}
- Events.STARTED:
{Events_STARTED}
Handlers names:
{Events_STARTED_names}
- Events.EPOCH_STARTED:
{Events_EPOCH_STARTED}
Handlers names:
{Events_EPOCH_STARTED_names}
- Events.ITERATION_STARTED:
{Events_ITERATION_STARTED}
Handlers names:
{Events_ITERATION_STARTED_names}
- Events.ITERATION_COMPLETED:
{Events_ITERATION_COMPLETED}
Handlers names:
{Events_ITERATION_COMPLETED_names}
- Events.EPOCH_COMPLETED:
{Events_EPOCH_COMPLETED}
Handlers names:
{Events_EPOCH_COMPLETED_names}
- Events.COMPLETED:
{Events_COMPLETED}
Handlers names:
{Events_COMPLETED_names}
""".format(processing_stats=odict_to_str(results['processing_stats']),
           dataflow_stats=odict_to_str(results['dataflow_stats']),
           **others)
        print(output_message)
        return output_message
class GANTrainer:
    """Training harness for GAN variants built on an ignite ``Engine``.

    Holds the generator ``G``, discriminator ``D``, their criteria and
    optimizers, and provides ``fit`` plus checkpoint save/load and
    state-dict (de)serialization. The concrete update step comes from one
    of the ``create_*_trainer`` factories, selected by ``gan_type``.
    """

    def __init__(self, G, D, criterionG, criterionD, optimizerG, optimizerD,
                 lr_schedulerG=None, lr_schedulerD=None, make_latent=None,
                 metrics=None, save_path=".", name="GAN", gan_type='gan'):
        """Create a GAN trainer.

        :param G: generator network (moved to the selected device)
        :param D: discriminator network (moved to the selected device)
        :param criterionG: generator loss
        :param criterionD: discriminator loss
        :param optimizerG: optimizer for ``G``
        :param optimizerD: optimizer for ``D``
        :param lr_schedulerG: optional LR scheduler for ``G``
        :param lr_schedulerD: optional LR scheduler for ``D``
        :param make_latent: callable producing latent inputs, forwarded to
            the engine factory (exact semantics defined by that factory)
        :param metrics: optional dict of metric name -> ignite metric
        :param save_path: root directory under which checkpoints are stored
            (``<save_path>/gan_trainer/<name>``)
        :param name: experiment name used in checkpoint file names
        :param gan_type: one of ``'gan'``, ``'acgan'``, ``'cgan'``,
            ``'infogan'``
        :raises ValueError: if ``gan_type`` is not a supported value
        """
        self.G = G
        self.D = D
        self.criterionG = criterionG
        self.criterionD = criterionD
        self.optimizerG = optimizerG
        self.optimizerD = optimizerD
        self.lr_schedulerG = lr_schedulerG
        self.lr_schedulerD = lr_schedulerD
        self.make_latent = make_latent
        self.metrics = metrics or {}
        self.name = name
        root = Path(save_path).expanduser().absolute()
        self.save_path = root / 'gan_trainer' / self.name
        self.metric_history = defaultdict(list)
        # CUDA is a module-level flag defined elsewhere in this file.
        self.device = 'cuda' if CUDA else 'cpu'
        self._timer = Timer()
        self._iterations = 0

        self.G.to(self.device)
        self.D.to(self.device)

        # Dispatch table instead of an if/elif chain. Validation raises
        # ValueError instead of the previous `assert`, so it survives
        # `python -O` (asserts are stripped under optimization).
        create_fns = {
            'gan': create_gan_trainer,
            'acgan': create_acgan_trainer,
            'cgan': create_cgan_trainer,
            'infogan': create_infogan_trainer,
        }
        if gan_type not in create_fns:
            raise ValueError("gan_type must be one of %s, got %r"
                             % (sorted(create_fns), gan_type))
        self.create_fn = create_fns[gan_type]

    def _lr_scheduler_step(self, engine):
        """Advance both LR schedulers (if configured); runs every iteration."""
        if self.lr_schedulerG:
            self.lr_schedulerG.step()
        if self.lr_schedulerD:
            self.lr_schedulerD.step()

    def _attach_timer(self, engine):
        """Reset the wall-clock bookkeeping at the start of a fit() call."""
        self._trained_time = 0
        self._timer.reset()

    def _increment_iteration(self, engine):
        # Persistent counter across fit() calls (engine.state.iteration
        # restarts per run; self._iterations does not).
        self._iterations += 1

    def _log_results(self, engine, log_interval, max_iter):
        """Every ``log_interval`` iterations print elapsed time, an ETA and
        the current metric values, appending them to ``metric_history``."""
        if self._iterations % log_interval != 0:
            return
        i = engine.state.iteration
        elapsed = self._timer.value()
        self._timer.reset()
        self._trained_time += elapsed
        trained_time = self._trained_time
        eta_seconds = int((trained_time / i) * (max_iter - i))
        # Pad the iteration number to the width of max_iter for aligned logs.
        it_fmt = "%" + str(len(str(max_iter))) + "d"
        print(("Iter: " + it_fmt + ", Cost: %.2fs, Eta: %s") %
              (self._iterations, elapsed,
               datetime.timedelta(seconds=eta_seconds)))
        for name, metric in self.metrics.items():
            metric.completed(engine, name)
            metric.reset()
        msg = ""
        for name, val in engine.state.metrics.items():
            msg += "%s: %.4f\t" % (name, val)
            self.metric_history[name].append(val)
        print(msg)

    def fit(self, it, max_iter, log_interval=100, callbacks=()):
        """Run training for ``max_iter`` iterations over iterable ``it``.

        :param it: data iterable passed to the engine
        :param max_iter: total number of iterations to run
        :param log_interval: iterations between log lines
        :param callbacks: extra callables invoked after each iteration with
            this trainer as argument (wrapped via _trainer_callback_wrap)
        :return: the accumulated ``metric_history`` dict
        """
        engine = self.create_fn(self.G, self.D, self.criterionG,
                                self.criterionD, self.optimizerG,
                                self.optimizerD, self.make_latent,
                                self.metrics, self.device)
        self._attach_timer(engine)

        engine.add_event_handler(Events.ITERATION_STARTED,
                                 self._lr_scheduler_step)
        engine.add_event_handler(Events.ITERATION_COMPLETED,
                                 self._increment_iteration)
        engine.add_event_handler(Events.ITERATION_COMPLETED,
                                 self._log_results, log_interval, max_iter)
        for callback in callbacks:
            engine.add_event_handler(Events.ITERATION_COMPLETED,
                                     _trainer_callback_wrap(callback), self)

        # Run
        engine.run(it, max_iter)

        # Return history
        return self.metric_history

    def state_dict(self):
        """Serialize all trainable/bookkeeping state into a plain dict."""
        s = {
            "iterations": self.iterations(),
            "G": self.G.state_dict(),
            "D": self.D.state_dict(),
            "optimizerG": self.optimizerG.state_dict(),
            "optimizerD": self.optimizerD.state_dict(),
            "criterionG": self.criterionG.state_dict(),
            "criterionD": self.criterionD.state_dict(),
            "lr_schedulerG": None,
            "lr_schedulerD": None,
            "metric_history": self.metric_history,
        }
        if self.lr_schedulerG:
            s["lr_schedulerG"] = self.lr_schedulerG.state_dict()
        if self.lr_schedulerD:
            s["lr_schedulerD"] = self.lr_schedulerD.state_dict()
        return s

    def load_state_dict(self, state_dict):
        """Restore state produced by :meth:`state_dict`.

        Scheduler states are only restored when both a scheduler instance
        and a saved state are present.
        """
        # `get` is a module-level multi-key extraction helper defined
        # elsewhere in this file.
        (iterations, G, D, optimizerG, optimizerD, criterionG, criterionD,
         lr_schedulerG, lr_schedulerD, metric_history) = get(
            [
                "iterations", "G", "D", "optimizerG", "optimizerD",
                "criterionG", "criterionD", "lr_schedulerG", "lr_schedulerD",
                "metric_history"
            ], state_dict)
        self._iterations = iterations
        self.G.load_state_dict(G)
        self.D.load_state_dict(D)
        self.optimizerG.load_state_dict(optimizerG)
        self.optimizerD.load_state_dict(optimizerD)
        self.criterionG.load_state_dict(criterionG)
        self.criterionD.load_state_dict(criterionD)
        if self.lr_schedulerG and lr_schedulerG:
            self.lr_schedulerG.load_state_dict(lr_schedulerG)
        if self.lr_schedulerD and lr_schedulerD:
            self.lr_schedulerD.load_state_dict(lr_schedulerD)
        self.metric_history = metric_history

    def save(self, remove_prev=True):
        """Save the trainer checkpoint under ``self.save_path``.

        :param remove_prev: if True, delete the most recent older checkpoint
            when the current iteration count exceeds its recorded one
        """
        d = self.save_path
        d.mkdir(parents=True, exist_ok=True)
        if remove_prev:
            pattern = "%s_trainer*.pth" % self.name
            saves = list(d.glob(pattern))
            if len(saves) != 0:
                fp = max(saves, key=lambda f: f.stat().st_mtime)
                # NOTE(review): assumes the newest file matches this pattern;
                # a non-conforming *.pth name here would raise AttributeError.
                p = "%s_trainer_(?P<iters>[0-9]+).pth" % self.name
                iters = int(re.match(p, fp.name).group('iters'))
                if self.iterations() > iters:
                    fp.unlink()
        filename = "%s_trainer_%d.pth" % (self.name, self.iterations())
        fp = d / filename
        torch.save(self.state_dict(), fp)
        print("Save trainer as %s" % fp)

    def load(self):
        """Load the most recently modified checkpoint for this trainer.

        :raises FileNotFoundError: if no matching checkpoint file exists
        """
        d = self.save_path
        pattern = "%s_trainer*.pth" % self.name
        saves = list(d.glob(pattern))
        if len(saves) == 0:
            raise FileNotFoundError("No checkpoint to load for %s in %s" %
                                    (self.name, self.save_path))
        fp = max(saves, key=lambda f: f.stat().st_mtime)
        self.load_state_dict(torch.load(fp, map_location=self.device))
        print("Load trainer from %s" % fp)

    def iterations(self):
        """Return the total number of completed training iterations."""
        return self._iterations
class Trainer:
    """
    Class which setups the training logic which mainly involves defining
    callback handlers and attaching them to the training loop.
    """
    def __init__(self, model, config, evaluator, data_loader, tb_writer,
                 run_info, logger, checkpoint_dir):
        """
        Creates a new trainer object for training a model.

        :param model: model to train. Needs to inherit from the BaseModel
            class.
        :param config: dictionary containing the whole configuration of the
            experiment
        :param evaluator: Instance of the evaluator class, used to run
            evaluation on a specified schedule
        :param data_loader: pytorch data loader providing the training data
        :param tb_writer: tensorboardX summary writer
        :param run_info: sacred run info for loging training progress
        :param logger: python logger object
        :param checkpoint_dir: directory path for storing checkpoints
        """
        self.run_info = run_info
        self.logger = logger
        self.data_loader = data_loader
        self.evaluator = evaluator
        self.engine = Engine(self._step)
        self.model = model
        self.config = config
        self.train_cfg = config['train']
        self.tb_writer = tb_writer
        self.pbar = ProgressBar(ascii=True, desc='* Epoch')
        self.timer = Timer(average=True)
        # Periodically keeps the N most recent "last" checkpoints.
        self.save_last_checkpoint_handler = ModelCheckpoint(
            checkpoint_dir,
            'last',
            save_interval=self.train_cfg['save_interval'],
            n_saved=self.train_cfg['save_n_last'],
            require_empty=False)
        # Handler registration must happen after all members above exist.
        self.add_handler()

    def run(self):
        """
        Start the training loop which will run until all epochs are complete

        :return:
        """
        self.engine.run(self.data_loader,
                        max_epochs=self.train_cfg['n_epochs'])

    def add_handler(self):
        """
        Adds all the callback handlers to the trainer engine. Should be
        called in the end of the init.

        :return:
        """
        # Learning rate decay
        for lr_s in self.model.schedulers:
            self.engine.add_event_handler(Events.ITERATION_STARTED, lr_s)

        # Checkpoint saving
        self.engine.add_event_handler(Events.EPOCH_STARTED,
                                      self.save_last_checkpoint_handler,
                                      self.model.networks)

        # Progbar: one running average per model metric, shown by the bar.
        monitoring_metrics = self.model.metric_names
        for mm in monitoring_metrics:
            RunningAverage(output_transform=self._extract_loss(mm)).attach(
                self.engine, mm)
        self.pbar.attach(self.engine, metric_names=monitoring_metrics)

        # Timer: averages per-iteration time within each epoch.
        self.timer.attach(self.engine,
                          start=Events.EPOCH_STARTED,
                          resume=Events.ITERATION_STARTED,
                          pause=Events.ITERATION_COMPLETED,
                          step=Events.ITERATION_COMPLETED)

        # Logging
        self.engine.add_event_handler(Events.ITERATION_COMPLETED,
                                      self._handle_log_train_results)
        self.engine.add_event_handler(Events.ITERATION_COMPLETED,
                                      self._handle_log_train_images)
        self.engine.add_event_handler(Events.ITERATION_COMPLETED,
                                      self._handle_run_evaluation)
        self.engine.add_event_handler(Events.EPOCH_COMPLETED,
                                      self._handle_print_times)

        # Exception handling
        self.engine.add_event_handler(Events.EXCEPTION_RAISED,
                                      self._handle_exception)

    def _step(self, engine, batch):
        """
        Definition of a single training step. This function gets
        automatically called by the engine every iteration.

        :param engine: trainer engine
        :param batch: one batch provided by the dataloader
        :return:
        """
        self.model.train()
        self.model.set_input(batch)
        self.model.optimize_parameters()
        return self.model.state

    def _handle_log_train_results(self, engine):
        """
        Handler for writing the losses to tensorboard and sacred.

        :param engine: train engine
        :return:
        """
        # iteration is 1-based, so iteration 1 (and every log_interval after)
        # triggers a log.
        if (engine.state.iteration - 1) % self.train_cfg['log_interval'] == 0:
            metrics = engine.state.metrics  # does not include non scalar metrics, since loggers can not handle this
            for m_name, m_val in metrics.items():
                if m_val is None:
                    raise ValueError(f'Value for {m_name} is None')
                self.run_info.log_scalar("train.%s" % m_name, m_val,
                                         engine.state.iteration)
                self.tb_writer.add_scalar("train/%s" % m_name, m_val,
                                          engine.state.iteration)

            for lr_name, lr_val in self.model.learning_rates.items():
                if lr_val is None:
                    raise ValueError(f'Value for {lr_name} is None')
                self.run_info.log_scalar("train.%s" % lr_name, lr_val,
                                         engine.state.iteration)
                self.tb_writer.add_scalar("train/%s" % lr_name, lr_val,
                                          engine.state.iteration)

    def _handle_log_train_images(self, engine):
        """
        Handler for writing visual samples to tensorboard.

        :param engine: train engine
        :return:
        """
        if (engine.state.iteration - 1) % \
                self.train_cfg['img_log_interval'] == 0:
            for name, visual in self.model.visuals.items():
                # TODO remove the visual.transpose here and put it in the visualization function of the models
                self.tb_writer.add_image('train/%s' % name,
                                         visual.transpose(2, 0, 1),
                                         engine.state.iteration)
            for name, figure in self.model.figures.items():
                self.tb_writer.add_figure('train_metrics/%s' % name, figure,
                                          engine.state.iteration)

    def _handle_run_evaluation(self, engine):
        """
        Handler which will execute evaluation by running the evaluator
        object.

        :param engine: train engine
        :return:
        """
        if (engine.state.iteration - 1) % self.train_cfg['eval_interval'] == 0:
            self.evaluator.run()

    def _handle_exception(self, engine, e):
        """
        Exception handler which ensures that the model gets saved when
        stopped through a keyboard interruption.

        :param engine: train engine
        :param e: the exception which caused the training to stop
        :return:
        """
        # Only save on Ctrl-C after at least one finished iteration; any
        # other exception is re-raised unchanged.
        if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1):
            engine.terminate()
            self.logger.warning(
                'KeyboardInterrupt caught. Exiting gracefully.')
            self.save_last_checkpoint_handler(engine, self.model.networks)
        else:
            raise e

    def _handle_print_times(self, engine):
        """
        Handler for logging timer information for different training and
        evaluation steps.

        :param engine: train engine
        :return:
        """
        self.logger.info('Epoch {} done. Time per batch: {:.3f}[s]'.format(
            engine.state.epoch, self.timer.value()))
        self.timer.reset()

    @staticmethod
    def _extract_loss(key):
        """
        Helper method to return losses for the RunningAverage

        :param key: (str) loss name
        :return: (fn) for the corresponding key
        """
        def _func(losses):
            return losses[key]

        return _func
class BasicTimeProfiler(object):
    """Profile dataflow, processing and event-handler times of an ignite
    Engine.

    NOTE(review): this is one of several BasicTimeProfiler definitions in
    this module; the last one defined wins at import time — confirm which
    version is intended.
    """

    def __init__(self):
        # Independent timers for data loading, the process function and the
        # user event handlers.
        self._dataflow_timer = Timer()
        self._processing_timer = Timer()
        self._event_handlers_timer = Timer()

    def _reset(self, num_iters, num_epochs, total_num_iters):
        # NOTE(review): `num_iters` is accepted but never used in this body.
        self.dataflow_times = torch.zeros(total_num_iters)
        self.processing_times = torch.zeros(total_num_iters)
        self.event_handlers_times = {
            Events.STARTED: torch.zeros(1),
            Events.COMPLETED: torch.zeros(1),
            Events.EPOCH_STARTED: torch.zeros(num_epochs),
            Events.EPOCH_COMPLETED: torch.zeros(num_epochs),
            Events.ITERATION_STARTED: torch.zeros(total_num_iters),
            Events.ITERATION_COMPLETED: torch.zeros(total_num_iters)
        }

    def _as_first_started(self, engine):
        # Runs as the first STARTED handler (installed by attach()): size the
        # buffers, snapshot handler names, then install first/last handler
        # pairs so they bracket every user handler.
        num_iters = engine.state.max_epochs * len(engine.state.dataloader)
        self._reset(len(engine.state.dataloader), engine.state.max_epochs,
                    num_iters)

        self.event_handlers_names = {
            e: [h.__qualname__ if hasattr(h, "__qualname__")
                else h.__class__.__name__
                for (h, _, _) in engine._event_handlers[e]]
            for e in Events if e != Events.EXCEPTION_RAISED
        }

        # Setup all other handlers:
        engine._event_handlers[Events.STARTED].append(
            (self._as_last_started, (), {}))

        # - add the first handlers
        events = [Events.EPOCH_STARTED, Events.EPOCH_COMPLETED,
                  Events.ITERATION_STARTED, Events.ITERATION_COMPLETED,
                  Events.COMPLETED]
        fmethods = [self._as_first_epoch_started,
                    self._as_first_epoch_completed,
                    self._as_first_iter_started,
                    self._as_first_iter_completed,
                    self._as_first_completed]
        lmethods = [self._as_last_epoch_started,
                    self._as_last_epoch_completed,
                    self._as_last_iter_started,
                    self._as_last_iter_completed,
                    self._as_last_completed]

        # "first" handlers go to index 0, "last" handlers are appended; this
        # ordering is what makes the per-event timings correct.
        for e, m in zip(events, fmethods):
            engine._event_handlers[e].insert(0, (m, (), {}))

        for e, m in zip(events, lmethods):
            engine._event_handlers[e].append((m, (), {}))

        # Let's go
        self._event_handlers_timer.reset()

    def _as_last_started(self, engine):
        self.event_handlers_times[
            Events.STARTED][0] = self._event_handlers_timer.value()

    def _as_first_epoch_started(self, engine):
        self._event_handlers_timer.reset()

    def _as_last_epoch_started(self, engine):
        t = self._event_handlers_timer.value()
        e = engine.state.epoch - 1
        self.event_handlers_times[Events.EPOCH_STARTED][e] = t

        # Dataflow timing in this version runs from end of EPOCH_STARTED /
        # ITERATION_COMPLETED until the first ITERATION_STARTED handler.
        self._dataflow_timer.reset()

    def _as_first_iter_started(self, engine):
        t = self._dataflow_timer.value()
        i = engine.state.iteration - 1
        self.dataflow_times[i] = t

        self._event_handlers_timer.reset()

    def _as_last_iter_started(self, engine):
        t = self._event_handlers_timer.value()
        i = engine.state.iteration - 1
        self.event_handlers_times[Events.ITERATION_STARTED][i] = t

        # Processing timer runs across the user's process function.
        self._processing_timer.reset()

    def _as_first_iter_completed(self, engine):
        t = self._processing_timer.value()
        i = engine.state.iteration - 1
        self.processing_times[i] = t

        self._event_handlers_timer.reset()

    def _as_last_iter_completed(self, engine):
        t = self._event_handlers_timer.value()
        i = engine.state.iteration - 1
        self.event_handlers_times[Events.ITERATION_COMPLETED][i] = t

        self._dataflow_timer.reset()

    def _as_first_epoch_completed(self, engine):
        self._event_handlers_timer.reset()

    def _as_last_epoch_completed(self, engine):
        t = self._event_handlers_timer.value()
        e = engine.state.epoch - 1
        self.event_handlers_times[Events.EPOCH_COMPLETED][e] = t

    def _as_first_completed(self, engine):
        self._event_handlers_timer.reset()

    def _as_last_completed(self, engine):
        self.event_handlers_times[
            Events.COMPLETED][0] = self._event_handlers_timer.value()

        # Remove added handlers:
        # `remove_handler` is a module-level helper defined elsewhere in this
        # file.
        remove_handler(engine, self._as_last_started, Events.STARTED)

        # - add the first handlers
        events = [Events.EPOCH_STARTED, Events.EPOCH_COMPLETED,
                  Events.ITERATION_STARTED, Events.ITERATION_COMPLETED,
                  Events.COMPLETED]
        fmethods = [self._as_first_epoch_started,
                    self._as_first_epoch_completed,
                    self._as_first_iter_started,
                    self._as_first_iter_completed,
                    self._as_first_completed]
        lmethods = [self._as_last_epoch_started,
                    self._as_last_epoch_completed,
                    self._as_last_iter_started,
                    self._as_last_iter_completed,
                    self._as_last_completed]

        for e, m in zip(events, fmethods):
            remove_handler(engine, m, e)

        for e, m in zip(events, lmethods):
            remove_handler(engine, m, e)

    def attach(self, engine):
        """Register the profiler's bootstrap handler on `engine`.

        :param engine: ignite Engine to profile
        :raises TypeError: if `engine` is not an ignite Engine
        """
        if not isinstance(engine, Engine):
            raise TypeError("Argument engine should be ignite.engine.Engine, "
                            "but given {}".format(type(engine)))
        if not engine.has_event_handler(self._as_first_started):
            # Must run before every user STARTED handler, hence insert at 0.
            engine._event_handlers[Events.STARTED].insert(
                0, (self._as_first_started, (), {}))

    @staticmethod
    def _compute_basic_stats(data):
        # NOTE(review): with a single sample torch.std() yields nan;
        # untriggered events contribute zeros.
        return OrderedDict([
            ('min/index', (torch.min(data).item(),
                           torch.argmin(data).item())),
            ('max/index', (torch.max(data).item(),
                           torch.argmax(data).item())),
            ('mean', torch.mean(data).item()),
            ('std', torch.std(data).item()),
            ('total', torch.sum(data).item())
        ])

    def get_results(self):
        """Return an OrderedDict with processing/dataflow/handler stats.

        NOTE(review): iterates every member of Events except
        EXCEPTION_RAISED, while _reset() only allocates six event keys —
        verify the Events enum this version targets has no additional
        members (e.g. GET_BATCH_*), otherwise this raises KeyError.
        """
        total_eh_time = sum([sum(self.event_handlers_times[e])
                             for e in Events
                             if e != Events.EXCEPTION_RAISED])

        return OrderedDict([
            ("processing_stats",
             self._compute_basic_stats(self.processing_times)),
            ("dataflow_stats", self._compute_basic_stats(self.dataflow_times)),
            ("event_handlers_stats",
             dict([(str(e).replace(".", "_"),
                    self._compute_basic_stats(self.event_handlers_times[e]))
                   for e in Events if e != Events.EXCEPTION_RAISED] +
                  [("total_time", total_eh_time)])),
            ("event_handlers_names",
             {str(e).replace(".", "_") + "_names": v
              for e, v in self.event_handlers_names.items()})
        ])

    @staticmethod
    def print_results(results):
        """Pretty-print the dict produced by get_results() and return the
        formatted message."""
        def odict_to_str(d):
            out = ""
            for k, v in d.items():
                out += "\t{}: {}\n".format(k, v)
            return out

        # Flatten per-event stats and handler names so they can be spliced
        # into the template via **others.
        others = {k: odict_to_str(v) if isinstance(v, OrderedDict) else v
                  for k, v in results['event_handlers_stats'].items()}
        others.update(results['event_handlers_names'])

        output_message = """
--------------------------------------------
- Time profiling results:
--------------------------------------------
Processing function time stats (in seconds):
{processing_stats}
Dataflow time stats (in seconds):
{dataflow_stats}
Time stats of event handlers (in seconds):
- Total time spent:
\t{total_time}
- Events.STARTED:
{Events_STARTED}
Handlers names:
{Events_STARTED_names}
- Events.EPOCH_STARTED:
{Events_EPOCH_STARTED}
Handlers names:
{Events_EPOCH_STARTED_names}
- Events.ITERATION_STARTED:
{Events_ITERATION_STARTED}
Handlers names:
{Events_ITERATION_STARTED_names}
- Events.ITERATION_COMPLETED:
{Events_ITERATION_COMPLETED}
Handlers names:
{Events_ITERATION_COMPLETED_names}
- Events.EPOCH_COMPLETED:
{Events_EPOCH_COMPLETED}
Handlers names:
{Events_EPOCH_COMPLETED_names}
- Events.COMPLETED:
{Events_COMPLETED}
Handlers names:
{Events_COMPLETED_names}
""".format(processing_stats=odict_to_str(results['processing_stats']),
           dataflow_stats=odict_to_str(results['dataflow_stats']),
           **others)
        print(output_message)
        return output_message

    @staticmethod
    def write_results(output_path):
        # NOTE(review): stub — checks that pandas is importable, then
        # unconditionally raises NotImplementedError.
        try:
            import pandas as pd
        except ImportError:
            print("Need pandas to write results as files")
            return
        raise NotImplementedError("")