def _internal_run(self) -> State:
    """Drive the whole run: fire lifecycle events, run each epoch, and
    record per-epoch and total-run durations in ``self.state.times``.

    Returns:
        State: the engine state after the run finishes or is terminated.
    """
    # Reset both termination flags so a stop requested during a previous
    # run does not leak into this one.
    self.should_terminate = self.should_terminate_single_epoch = False
    self._init_timers(self.state)
    try:
        start_time = time.time()
        self._fire_event(Events.STARTED)
        while self.state.epoch < self.state.max_epochs and not self.should_terminate:
            self.state.epoch += 1
            self._fire_event(Events.EPOCH_STARTED)
            # Lazily build the dataloader iterator on the first epoch
            # (or after it was dropped by the exception path below).
            if self._dataloader_iter is None:
                self._setup_engine()
            time_taken = self._run_once_on_dataset()
            # Store the epoch time BEFORE firing the completion event so
            # handlers can read it; it is written again below to also
            # include the time the handlers themselves took.
            self.state.times[Events.EPOCH_COMPLETED.name] = time_taken
            handlers_start_time = time.time()
            if self.should_terminate:
                self._fire_event(Events.TERMINATE)
            else:
                self._fire_event(Events.EPOCH_COMPLETED)
            time_taken += time.time() - handlers_start_time
            # Update the stored epoch time to account for handler execution.
            self.state.times[Events.EPOCH_COMPLETED.name] = time_taken
            hours, mins, secs = _to_hours_mins_secs(time_taken)
            self.logger.info(
                "Epoch[%s] Complete. Time taken: %02d:%02d:%02d"
                % (self.state.epoch, hours, mins, secs))
            if self.should_terminate:
                break
        time_taken = time.time() - start_time
        # Same two-step pattern for the whole run: expose the elapsed time
        # to COMPLETED handlers first, then fold their own runtime in.
        self.state.times[Events.COMPLETED.name] = time_taken
        handlers_start_time = time.time()
        self._fire_event(Events.COMPLETED)
        time_taken += time.time() - handlers_start_time
        self.state.times[Events.COMPLETED.name] = time_taken
        hours, mins, secs = _to_hours_mins_secs(time_taken)
        self.logger.info(
            "Engine run complete. Time taken: %02d:%02d:%02d" % (hours, mins, secs))
    except BaseException as e:
        # Drop the iterator so a subsequent run starts on a fresh epoch.
        self._dataloader_iter = None
        self.logger.error(
            "Engine run is terminating due to exception: %s.", str(e))
        # NOTE(review): presumably re-raises when no exception handler is
        # attached — confirm against _handle_exception's implementation.
        self._handle_exception(e)
    self._dataloader_iter = None
    return self.state
def run(self, dataloader, max_epochs=1, start_epoch=0, iteration=0):
    """Run the process_function over ``dataloader``, supporting resume
    from a given epoch and iteration.

    :param dataloader: iterable of batches to process
    :param start_epoch: which epoch to start with
    :param iteration: which iteration to start with
    :param max_epochs: total number of epochs to run up to
    :return: RunningState
    """
    self.state = RunningState(epoch=start_epoch,
                              iteration=iteration,
                              max_epochs=max_epochs)
    try:
        self._logger.info("Engine run starting with max_epochs={}".format(max_epochs))
        run_started = time.time()
        self._fire_event(Events.STARTED)
        while not self.should_terminate and self.state.epoch < max_epochs:
            self.state.epoch += 1
            self._fire_event(Events.EPOCH_STARTED)
            h, m, s = self._run_once_on_dataset(dataloader)
            self._logger.info("Epoch[%s] Complete. Time taken: %02d:%02d:%02d",
                              self.state.epoch, h, m, s)
            if self.should_terminate:
                break
            self._fire_event(Events.EPOCH_COMPLETED)
        self._fire_event(Events.COMPLETED)
        h, m, s = _to_hours_mins_secs(time.time() - run_started)
        self._logger.info("Engine run complete. Time taken %02d:%02d:%02d" % (h, m, s))
    except BaseException as e:
        self._logger.error("Engine run is terminating due to exception: %s", str(e))
        self._handle_exception(e)
    return self.state
def _internal_run(self):
    """Run the engine loop: fire lifecycle events and iterate epochs,
    dropping the cached dataloader iterator/length on exit or error.

    Returns the engine state when the run finishes or is terminated.
    """
    self.should_terminate = self.should_terminate_single_epoch = False
    try:
        run_started = time.time()
        self._fire_event(Events.STARTED)
        while not self.should_terminate and self.state.epoch < self.state.max_epochs:
            self.state.epoch += 1
            self._fire_event(Events.EPOCH_STARTED)
            # Build the data iterator lazily on first use.
            if self._dataloader_iter is None:
                self._setup_engine()
            h, m, s = self._run_once_on_dataset()
            self.logger.info(
                "Epoch[%s] Complete. Time taken: %02d:%02d:%02d",
                self.state.epoch, h, m, s)
            if self.should_terminate:
                break
            self._fire_event(Events.EPOCH_COMPLETED)
        self._fire_event(Events.COMPLETED)
        h, m, s = _to_hours_mins_secs(time.time() - run_started)
        self.logger.info("Engine run complete. Time taken %02d:%02d:%02d" % (h, m, s))
    except BaseException as e:
        self._dataloader_iter = self._dataloader_len = None
        self.logger.error(
            "Engine run is terminating due to exception: %s.", str(e))
        self._handle_exception(e)
    self._dataloader_iter = self._dataloader_len = None
    return self.state
def run(self, data, max_epochs=1):
    """Run the `process_function` over the passed data.

    Args:
        data (Iterable): Collection of batches allowing repeated
            iteration (e.g., list or `DataLoader`).
        max_epochs (int, optional): max epochs to run for (default: 1).

    Returns:
        State: output state.

    Note:
        A handler attached to
        :attr:`~ignite.engine.Events.ITERATION_STARTED` may replace
        `engine.state.batch`; whatever is stored there is what gets
        passed to `process_function`.
    """
    self.state = State(dataloader=data, max_epochs=max_epochs, metrics={})
    # Clear any termination request left over from an earlier run.
    self.should_terminate = self.should_terminate_single_epoch = False
    try:
        self._logger.info(
            "Engine run starting with max_epochs={}.".format(max_epochs))
        run_started = time.time()
        self._fire_event(Events.STARTED)
        while not self.should_terminate and self.state.epoch < max_epochs:
            self.state.epoch += 1
            self._fire_event(Events.EPOCH_STARTED)
            h, m, s = self._run_once_on_dataset()
            self._logger.info(
                "Epoch[%s] Complete. Time taken: %02d:%02d:%02d",
                self.state.epoch, h, m, s)
            if self.should_terminate:
                break
            self._fire_event(Events.EPOCH_COMPLETED)
        self._fire_event(Events.COMPLETED)
        h, m, s = _to_hours_mins_secs(time.time() - run_started)
        self._logger.info(
            "Engine run complete. Time taken %02d:%02d:%02d" % (h, m, s))
    except BaseException as e:
        self._logger.error(
            "Engine run is terminating due to exception: %s.", str(e))
        self._handle_exception(e)
    return self.state
def run(self, training_data, max_epochs=1):
    """Run the training loop over ``training_data``.

    Fires the engine lifecycle events while iterating epochs via
    ``_train_one_epoch``. (NOTE(review): validation and best-parameter
    tracking are presumably performed by registered event handlers —
    confirm against the handler registrations.)

    Parameters
    ----------
    training_data : Iterable
        Collection of training batches allowing repeated iteration
        (e.g., list or DataLoader)
    max_epochs : int, optional
        max epochs to train for [default=1]

    Returns
    -------
    None
    """
    self.dataloader = training_data
    self.current_iteration = 0
    self.current_epoch = 0
    try:
        self._logger.info(
            "Training starting with max_epochs={}".format(max_epochs))
        self.max_epochs = max_epochs
        training_started = time.time()
        self._fire_event(Events.STARTED)
        while not self.should_terminate and self.current_epoch < max_epochs:
            self.current_epoch += 1
            self._fire_event(Events.EPOCH_STARTED)
            self._train_one_epoch(training_data)
            if self.should_terminate:
                break
            self._fire_event(Events.EPOCH_COMPLETED)
        self._fire_event(Events.COMPLETED)
        h, m, s = _to_hours_mins_secs(time.time() - training_started)
        self._logger.info("Training complete. Time taken %02d:%02d:%02d" % (h, m, s))
    except BaseException as e:
        self._logger.error("Training is terminating due to exception: %s", str(e))
        self._handle_exception(e)
def _run_once_on_dataset(self):
    """Run one pass over ``self.state.dataloader``, firing the iteration
    events around each batch and storing batch/output on the state.

    Returns the elapsed time as an ``(hours, mins, secs)`` tuple.
    """
    epoch_started = time.time()
    for current_batch in self.state.dataloader:
        self.state.batch = current_batch
        self.state.iteration += 1
        self._fire_event(Events.ITERATION_STARTED)
        self.state.output = self._process_function(self, current_batch)
        self._fire_event(Events.ITERATION_COMPLETED)
        stop_requested = self.should_terminate or self.should_terminate_single_epoch
        if stop_requested:
            # The single-epoch flag is one-shot: consume it here.
            self.should_terminate_single_epoch = False
            break
    return _to_hours_mins_secs(time.time() - epoch_started)
def run(self, data, max_epochs=1):
    """Runs the process_function over the passed data.

    Args:
        data (Iterable): Collection of batches allowing repeated
            iteration (e.g., list or `DataLoader`)
        max_epochs (int, optional): max epochs to run for (default: 1)

    Returns:
        State: output state
    """
    self.state = State(dataloader=data, epoch=0, max_epochs=max_epochs, metrics={})
    try:
        self._logger.info(
            "Engine run starting with max_epochs={}".format(max_epochs))
        run_started = time.time()
        self._fire_event(Events.STARTED)
        while not self.should_terminate and self.state.epoch < max_epochs:
            self.state.epoch += 1
            self._fire_event(Events.EPOCH_STARTED)
            h, m, s = self._run_once_on_dataset()
            self._logger.info(
                "Epoch[%s] Complete. Time taken: %02d:%02d:%02d",
                self.state.epoch, h, m, s)
            if self.should_terminate:
                break
            self._fire_event(Events.EPOCH_COMPLETED)
        self._fire_event(Events.COMPLETED)
        h, m, s = _to_hours_mins_secs(time.time() - run_started)
        self._logger.info(
            "Engine run complete. Time taken %02d:%02d:%02d" % (h, m, s))
    except BaseException as e:
        self._logger.error(
            "Engine run is terminating due to exception: %s", str(e))
        self._handle_exception(e)
    return self.state
def _run_once_on_dataset(self, state):
    """Run ``_process_function`` over one pass of ``state.dataloader``.

    Fires ITERATION_STARTED / ITERATION_COMPLETED around each batch and
    stores the current batch and output on ``state``.

    Args:
        state: run state carrying ``dataloader``, ``iteration``,
            ``batch`` and ``output``.

    Returns:
        tuple: ``(hours, mins, secs)`` the pass took — or implicitly
        ``None`` when an exception is caught below. NOTE(review): a
        caller that unpacks the result would then fail; presumably
        ``_handle_exception`` re-raises — confirm.
    """
    try:
        start_time = time.time()
        for batch in state.dataloader:
            state.batch = batch
            state.iteration += 1
            self._fire_event(Events.ITERATION_STARTED, state)
            state.output = self._process_function(batch)
            self._fire_event(Events.ITERATION_COMPLETED, state)
            # A terminate request stops mid-epoch.
            if self.should_terminate:
                break
        time_taken = time.time() - start_time
        hours, mins, secs = _to_hours_mins_secs(time_taken)
        return hours, mins, secs
    except BaseException as e:
        self._logger.error("Current run is terminating due to exception: %s", str(e))
        self._handle_exception(state, e)
def _run_once_on_dataset(self, dataloader):
    """Run one pass over ``dataloader``, firing iteration events around
    each batch and storing the output on ``self.state``.

    Returns the elapsed time as an ``(hours, mins, secs)`` tuple; the
    timing is reported even when an exception was handled mid-pass.
    """
    epoch_started = time.time()
    try:
        for current_batch in dataloader:
            self.state.iteration += 1
            self._fire_event(Events.ITERATION_STARTED)
            self.state.output = self._process_function(self, current_batch)
            self._fire_event(Events.ITERATION_COMPLETED)
            if self.should_terminate or self.should_terminate_single_epoch:
                # The single-epoch flag is one-shot: consume it here.
                self.should_terminate_single_epoch = False
                break
    except BaseException as e:
        self._logger.error("Current run is terminating due to exception: %s", str(e))
        self._handle_exception(e)
    return _to_hours_mins_secs(time.time() - epoch_started)
def _run_once_on_dataset(self, dataloader):
    """Run one pass over ``dataloader``, appending non-``None`` step
    results to ``self.history`` and firing iteration events per batch.

    Returns ``(hours, mins, secs)`` on success; when an exception is
    caught it is delegated to ``_handle_exception`` and ``None`` is
    returned implicitly.
    """
    self.dataloader = dataloader
    try:
        epoch_started = time.time()
        for current_batch in dataloader:
            self.current_iteration += 1
            self._fire_event(Events.ITERATION_STARTED)
            outcome = self._process_function(current_batch)
            if outcome is not None:
                self.history.append(outcome)
            self._fire_event(Events.ITERATION_COMPLETED)
            if self.should_terminate:
                break
        return _to_hours_mins_secs(time.time() - epoch_started)
    except BaseException as e:
        self._logger.error(
            "Current run is terminating due to exception: %s", str(e))
        self._handle_exception(e)
def _run_once_on_dataset(self):
    """Run up to ``state.epoch_length`` iterations from the cached
    ``self._dataloader_iter``, handling mid-epoch resume, iterator
    exhaustion/restart, and re-seeding on restarts.

    Returns:
        tuple: ``(hours, mins, secs)`` the epoch took.
    """
    start_time = time.time()
    # We need to set up iter_counter > 0 if we resume from an iteration.
    iter_counter = self._init_iter.pop() if len(self._init_iter) > 0 else 0
    should_exit = False
    try:
        while True:
            try:
                self._fire_event(Events.GET_BATCH_STARTED)
                batch = next(self._dataloader_iter)
                self._fire_event(Events.GET_BATCH_COMPLETED)
                iter_counter += 1
                should_exit = False
            except StopIteration:
                # First exhaustion reveals the dataloader's length if it
                # was not known yet.
                if self._dataloader_len is None:
                    if iter_counter > 0:
                        self._dataloader_len = iter_counter
                    else:
                        # this can happen when data is finite iterator and
                        # epoch_length is equal to its size
                        self._dataloader_len = self.state.iteration
                # Two StopIterations in a row: the iterator cannot provide
                # data at all, so exit the while loop.
                if should_exit:
                    if not self._is_done(self.state):
                        warnings.warn(
                            "Data iterator can not provide data anymore but required total number of "
                            "iterations to run is not reached. "
                            "Current iteration: {} vs Total iterations to run : {}".format(
                                self.state.iteration,
                                self.state.epoch_length * self.state.max_epochs))
                    break
                # set seed on restart of data iterator
                self.setup_seed()
                self._dataloader_iter = iter(self.state.dataloader)
                should_exit = True
                continue
            self.state.batch = batch
            self.state.iteration += 1
            self._fire_event(Events.ITERATION_STARTED)
            self.state.output = self._process_function(
                self, self.state.batch)
            self._fire_event(Events.ITERATION_COMPLETED)
            # TODO: remove refs on batch to avoid high mem consumption ? -> need verification
            # self.state.batch = batch = None
            if self.should_terminate or self.should_terminate_single_epoch:
                # The single-epoch flag is one-shot: consume it here.
                self.should_terminate_single_epoch = False
                # NOTE(review): iteration // iter_counter presumably recovers
                # the epoch index for deterministic re-seeding — confirm.
                # iter_counter is >= 1 here since at least one batch was
                # fetched before the iteration events fired.
                self._manual_seed(self.state.seed, self.state.iteration // iter_counter)
                self._dataloader_iter = iter(self.state.dataloader)
                break
            if iter_counter == self.state.epoch_length:
                break
    except BaseException as e:
        self.logger.error(
            "Current run is terminating due to exception: %s.", str(e))
        self._handle_exception(e)
    time_taken = time.time() - start_time
    hours, mins, secs = _to_hours_mins_secs(time_taken)
    return hours, mins, secs
def _colorize_if_improved(text, improved):
    """Wrap ``text`` in green ANSI color codes when ``improved`` is true."""
    if improved:
        return Fore.GREEN + text + Style.RESET_ALL
    return text


def log_validation_results(trainer):
    """Log training/validation metrics for the finished epoch and persist history.

    Resets the trainer's iteration counter, prints the training
    cross-entropy, runs ``evaluator`` on ``validation_dataloader``,
    prints validation metrics (highlighted in green when they improve on
    the best seen so far, updating the stored bests), appends everything
    to the trainer's histories, and — when ``results_path`` is set —
    writes the accumulated history out as JSON.

    Args:
        trainer: engine whose ``state`` carries ``metrics``, the
            training/validation histories, best-metric attributes and
            ``start_time``.
    """
    trainer.state.iteration = 0
    train_metrics = trainer.state.metrics
    print("Epoch[{}] Training X-Entropy={:.3f}".format(
        trainer.state.epoch, train_metrics["x-entropy"]))
    trainer.state.training_history["x-entropy"].append(
        train_metrics["x-entropy"])

    # NOTE(review): `evaluator`, `validation_dataloader` and `results_path`
    # are captured from the enclosing scope — confirm they are defined there.
    evaluator.run(validation_dataloader)
    valid_metrics = evaluator.state.metrics
    valid_history = trainer.state.validation_history
    valid_metric_strings = []

    # Cross-entropy: lower is better.
    xent = valid_metrics["x-entropy"]
    xent_improved = xent < trainer.state.min_valid_xent
    valid_metric_strings.append(
        _colorize_if_improved("X-Entropy={:.3f}".format(xent), xent_improved))
    if xent_improved:
        trainer.state.min_valid_xent = xent

    # Rouge scores: higher is better. One loop replaces the previously
    # duplicated highlight/update branches.
    for label, key, best_attr in (("Rouge-1", "rouge-1", "max_valid_rouge1"),
                                  ("Rouge-2", "rouge-2", "max_valid_rouge2")):
        score = valid_metrics["rouge"][key]
        improved = score > getattr(trainer.state, best_attr)
        valid_metric_strings.append(
            _colorize_if_improved("{}={:.3f}".format(label, score), improved))
        if improved:
            setattr(trainer.state, best_attr, score)

    print("Epoch[{}] Validation {}".format(
        trainer.state.epoch, " ".join(valid_metric_strings)))

    valid_history["x-entropy"].append(valid_metrics["x-entropy"])
    valid_history["rouge-1"].append(valid_metrics["rouge"]["rouge-1"])
    valid_history["rouge-2"].append(valid_metrics["rouge"]["rouge-2"])

    hrs, mins, secs = _to_hours_mins_secs(
        time.time() - trainer.state.start_time)
    print("Epoch[{}] Time Taken: {:02.0f}:{:02.0f}:{:02.0f}".format(
        trainer.state.epoch, hrs, mins, secs))
    print()

    if results_path:
        # exist_ok=True already tolerates an existing directory, so the
        # previous explicit .exists() pre-check was redundant.
        results_path.parent.mkdir(parents=True, exist_ok=True)
        results_path.write_text(
            json.dumps({"training": trainer.state.training_history,
                        "validation": trainer.state.validation_history}))