示例#1
0
    def _internal_run(self) -> State:
        """Drive the engine through its epochs and return the resulting state.

        Both termination flags are cleared and ``_init_timers`` is called on
        the current state. ``STARTED`` is then fired, epochs are executed
        until ``state.max_epochs`` is reached or termination is requested,
        and ``COMPLETED`` is fired at the end. Elapsed times are written to
        ``state.times`` before the matching handlers run (so handlers can
        read them) and overwritten afterwards with a value that also
        includes the time spent in those handlers.

        Any exception is logged and forwarded to ``_handle_exception``. The
        dataloader iterator reference is cleared on both the normal and the
        exceptional path.

        Returns:
            State: ``self.state`` after the run.
        """
        self.should_terminate = False
        self.should_terminate_single_epoch = False
        self._init_timers(self.state)
        try:
            run_began = time.time()
            self._fire_event(Events.STARTED)
            while self.state.epoch < self.state.max_epochs and not self.should_terminate:
                self.state.epoch += 1
                self._fire_event(Events.EPOCH_STARTED)

                if self._dataloader_iter is None:
                    self._setup_engine()

                epoch_elapsed = self._run_once_on_dataset()
                # Publish the raw epoch time first so handlers can read it.
                self.state.times[Events.EPOCH_COMPLETED.name] = epoch_elapsed
                handlers_began = time.time()
                # A terminated epoch fires TERMINATE in place of EPOCH_COMPLETED.
                closing_event = (
                    Events.TERMINATE if self.should_terminate
                    else Events.EPOCH_COMPLETED
                )
                self._fire_event(closing_event)
                epoch_elapsed += time.time() - handlers_began
                # Overwrite with the total that includes handler time.
                self.state.times[Events.EPOCH_COMPLETED.name] = epoch_elapsed
                hours, mins, secs = _to_hours_mins_secs(epoch_elapsed)
                epoch_summary = (
                    "Epoch[%s] Complete. Time taken: %02d:%02d:%02d"
                    % (self.state.epoch, hours, mins, secs)
                )
                self.logger.info(epoch_summary)
                if self.should_terminate:
                    break

            run_elapsed = time.time() - run_began
            # Same two-step bookkeeping for the whole run.
            self.state.times[Events.COMPLETED.name] = run_elapsed
            handlers_began = time.time()
            self._fire_event(Events.COMPLETED)
            run_elapsed += time.time() - handlers_began
            self.state.times[Events.COMPLETED.name] = run_elapsed
            hours, mins, secs = _to_hours_mins_secs(run_elapsed)
            run_summary = (
                "Engine run complete. Time taken: %02d:%02d:%02d"
                % (hours, mins, secs)
            )
            self.logger.info(run_summary)

        except BaseException as exc:
            # Clear the iterator before handling, since the handler may re-raise.
            self._dataloader_iter = None
            self.logger.error(
                "Engine run is terminating due to exception: %s.", str(exc))
            self._handle_exception(exc)

        self._dataloader_iter = None
        return self.state
示例#2
0
文件: engine.py 项目: taohu88/tea
    def run(self, dataloader, max_epochs=1, start_epoch=0, iteration=0):
        """
        Run the process function over ``dataloader``, optionally starting
        the epoch and iteration counters from non-zero values.

        A new ``RunningState`` is created from the arguments. ``STARTED`` is
        fired, then each epoch fires ``EPOCH_STARTED``, runs one pass over the
        data and, unless termination was requested, fires ``EPOCH_COMPLETED``.
        ``COMPLETED`` is fired after the loop. Exceptions are logged and
        passed to ``_handle_exception``.

        :param dataloader: handed to ``_run_once_on_dataset`` once per epoch
        :param max_epochs: the loop stops once ``state.epoch`` reaches this value
        :param start_epoch: initial value of the epoch counter
        :param iteration: initial value of the iteration counter
        :return: RunningState
        """
        self.state = RunningState(
            epoch=start_epoch, iteration=iteration, max_epochs=max_epochs)
        try:
            self._logger.info(
                "Engine run starting with max_epochs={}".format(max_epochs))
            run_began = time.time()
            self._fire_event(Events.STARTED)
            while self.state.epoch < max_epochs and not self.should_terminate:
                self.state.epoch += 1
                self._fire_event(Events.EPOCH_STARTED)
                hours, mins, secs = self._run_once_on_dataset(dataloader)
                self._logger.info(
                    "Epoch[%s] Complete. Time taken: %02d:%02d:%02d",
                    self.state.epoch, hours, mins, secs)
                # A terminated epoch does not fire EPOCH_COMPLETED.
                if self.should_terminate:
                    break
                self._fire_event(Events.EPOCH_COMPLETED)

            self._fire_event(Events.COMPLETED)
            # Total time is measured after the COMPLETED handlers have run.
            hours, mins, secs = _to_hours_mins_secs(time.time() - run_began)
            run_summary = (
                "Engine run complete. Time taken %02d:%02d:%02d"
                % (hours, mins, secs)
            )
            self._logger.info(run_summary)

        except BaseException as exc:
            self._logger.error(
                "Engine run is terminating due to exception: %s", str(exc))
            self._handle_exception(exc)

        return self.state
示例#3
0
文件: engine.py 项目: schopfej/ignite
    def _internal_run(self):
        """Execute the epoch loop on the already-prepared ``self.state``.

        Clears both termination flags, fires ``STARTED``, then for every
        epoch fires ``EPOCH_STARTED``, calls ``_setup_engine`` if no
        dataloader iterator exists yet, runs one pass over the data and,
        unless termination was requested, fires ``EPOCH_COMPLETED``.
        ``COMPLETED`` is fired after the loop.

        Exceptions are logged and passed to ``_handle_exception``. The
        dataloader iterator and its cached length are cleared on both the
        normal and the exceptional path.

        Returns:
            ``self.state`` after the run.
        """
        self.should_terminate = False
        self.should_terminate_single_epoch = False
        try:
            run_began = time.time()
            self._fire_event(Events.STARTED)
            while self.state.epoch < self.state.max_epochs and not self.should_terminate:
                self.state.epoch += 1
                self._fire_event(Events.EPOCH_STARTED)

                if self._dataloader_iter is None:
                    self._setup_engine()

                hours, mins, secs = self._run_once_on_dataset()

                self.logger.info(
                    "Epoch[%s] Complete. Time taken: %02d:%02d:%02d",
                    self.state.epoch, hours, mins, secs)
                # A terminated epoch does not fire EPOCH_COMPLETED.
                if self.should_terminate:
                    break
                self._fire_event(Events.EPOCH_COMPLETED)

            self._fire_event(Events.COMPLETED)
            # Total time is measured after the COMPLETED handlers have run.
            hours, mins, secs = _to_hours_mins_secs(time.time() - run_began)
            run_summary = (
                "Engine run complete. Time taken %02d:%02d:%02d"
                % (hours, mins, secs)
            )
            self.logger.info(run_summary)

        except BaseException as exc:
            # Clear iterator state before handling, since the handler may re-raise.
            self._dataloader_iter = None
            self._dataloader_len = None
            self.logger.error(
                "Engine run is terminating due to exception: %s.", str(exc))
            self._handle_exception(exc)

        self._dataloader_iter = None
        self._dataloader_len = None
        return self.state
示例#4
0
文件: engine.py 项目: gq123smu/ignite
    def run(self, data, max_epochs=1):
        """Run the `process_function` over ``data`` for up to ``max_epochs`` epochs.

        A fresh ``State`` is created and both termination flags are cleared.
        ``STARTED`` is fired, then each epoch fires ``EPOCH_STARTED``, runs one
        pass over the data and, unless termination was requested, fires
        ``EPOCH_COMPLETED``. ``COMPLETED`` is fired after the loop. Exceptions
        are logged and passed to ``_handle_exception``.

        Args:
            data (Iterable): Collection of batches allowing repeated iteration (e.g., list or `DataLoader`).
            max_epochs (int, optional): max epochs to run for (default: 1).

        Returns:
            State: output state.

        Note:
            A handler on :attr:`~ignite.engine.Events.ITERATION_STARTED` may replace the input batch by
            writing to `engine.state.batch`:

            .. code-block:: python

                trainer = ...

                @trainer.on(Events.ITERATION_STARTED)
                def switch_batch(engine):
                    engine.state.batch = preprocess_batch(engine.state.batch)

        """

        self.state = State(dataloader=data, max_epochs=max_epochs, metrics={})
        self.should_terminate = False
        self.should_terminate_single_epoch = False

        try:
            self._logger.info(
                "Engine run starting with max_epochs={}.".format(max_epochs))
            run_began = time.time()
            self._fire_event(Events.STARTED)
            while self.state.epoch < max_epochs and not self.should_terminate:
                self.state.epoch += 1
                self._fire_event(Events.EPOCH_STARTED)
                hours, mins, secs = self._run_once_on_dataset()
                self._logger.info(
                    "Epoch[%s] Complete. Time taken: %02d:%02d:%02d",
                    self.state.epoch, hours, mins, secs)
                # A terminated epoch does not fire EPOCH_COMPLETED.
                if self.should_terminate:
                    break
                self._fire_event(Events.EPOCH_COMPLETED)

            self._fire_event(Events.COMPLETED)
            # Total time is measured after the COMPLETED handlers have run.
            hours, mins, secs = _to_hours_mins_secs(time.time() - run_began)
            run_summary = (
                "Engine run complete. Time taken %02d:%02d:%02d"
                % (hours, mins, secs)
            )
            self._logger.info(run_summary)

        except BaseException as exc:
            self._logger.error(
                "Engine run is terminating due to exception: %s.", str(exc))
            self._handle_exception(exc)

        return self.state
示例#5
0
    def run(self, training_data, max_epochs=1):
        """
        Train over ``training_data`` for up to ``max_epochs`` epochs.

        The iteration and epoch counters are reset to zero. ``STARTED`` is
        fired, then each epoch fires ``EPOCH_STARTED``, calls
        ``_train_one_epoch`` and, unless termination was requested, fires
        ``EPOCH_COMPLETED``. ``COMPLETED`` is fired after the loop.
        Exceptions are logged and passed to ``_handle_exception``.

        Parameters
        ----------
        training_data : Iterable
            Collection of training batches allowing repeated iteration (e.g., list or DataLoader)
        max_epochs: int, optional
            max epochs to train for [default=1]

        Returns
        -------
        None
        """
        self.dataloader = training_data
        self.current_iteration = 0
        self.current_epoch = 0

        try:
            self._logger.info(
                "Training starting with max_epochs={}".format(max_epochs))

            self.max_epochs = max_epochs

            run_began = time.time()

            self._fire_event(Events.STARTED)
            while self.current_epoch < max_epochs and not self.should_terminate:
                self.current_epoch += 1
                self._fire_event(Events.EPOCH_STARTED)
                self._train_one_epoch(training_data)
                # A terminated epoch does not fire EPOCH_COMPLETED.
                if self.should_terminate:
                    break
                self._fire_event(Events.EPOCH_COMPLETED)

            self._fire_event(Events.COMPLETED)
            # Total time is measured after the COMPLETED handlers have run.
            hours, mins, secs = _to_hours_mins_secs(time.time() - run_began)
            run_summary = (
                "Training complete. Time taken %02d:%02d:%02d"
                % (hours, mins, secs)
            )
            self._logger.info(run_summary)

        except BaseException as exc:
            self._logger.error("Training is terminating due to exception: %s",
                               str(exc))
            self._handle_exception(exc)
    def _run_once_on_dataset(self):
        """Make one pass over ``self.state.dataloader`` and time it.

        For every batch: store it in ``state.batch``, advance
        ``state.iteration``, fire ``ITERATION_STARTED``, store the process
        function's result in ``state.output`` and fire
        ``ITERATION_COMPLETED``. The pass stops early when either
        termination flag is set; the single-epoch flag is cleared on exit.

        Returns:
            tuple: ``(hours, mins, secs)`` elapsed during the pass.
        """
        pass_began = time.time()

        for batch in self.state.dataloader:
            self.state.batch = batch
            self.state.iteration += 1
            self._fire_event(Events.ITERATION_STARTED)
            self.state.output = self._process_function(self, batch)
            self._fire_event(Events.ITERATION_COMPLETED)
            if not (self.should_terminate or self.should_terminate_single_epoch):
                continue
            # The single-epoch request is consumed by this early exit.
            self.should_terminate_single_epoch = False
            break

        hours, mins, secs = _to_hours_mins_secs(time.time() - pass_began)
        return hours, mins, secs
示例#7
0
文件: engine.py 项目: py361/ignite
    def run(self, data, max_epochs=1):
        """Run the process_function over ``data`` for up to ``max_epochs`` epochs.

        A fresh ``State`` starting at epoch 0 is created. ``STARTED`` is
        fired, then each epoch fires ``EPOCH_STARTED``, runs one pass over the
        data and, unless termination was requested, fires
        ``EPOCH_COMPLETED``. ``COMPLETED`` is fired after the loop.
        Exceptions are logged and passed to ``_handle_exception``.

        Args:
            data (Iterable): Collection of batches allowing repeated iteration (e.g., list or `DataLoader`)
            max_epochs (int, optional): max epochs to run for (default: 1)

        Returns:
            State: output state
        """

        self.state = State(
            dataloader=data, epoch=0, max_epochs=max_epochs, metrics={})

        try:
            self._logger.info(
                "Engine run starting with max_epochs={}".format(max_epochs))
            run_began = time.time()
            self._fire_event(Events.STARTED)
            while self.state.epoch < max_epochs and not self.should_terminate:
                self.state.epoch += 1
                self._fire_event(Events.EPOCH_STARTED)
                hours, mins, secs = self._run_once_on_dataset()
                self._logger.info(
                    "Epoch[%s] Complete. Time taken: %02d:%02d:%02d",
                    self.state.epoch, hours, mins, secs)
                # A terminated epoch does not fire EPOCH_COMPLETED.
                if self.should_terminate:
                    break
                self._fire_event(Events.EPOCH_COMPLETED)

            self._fire_event(Events.COMPLETED)
            # Total time is measured after the COMPLETED handlers have run.
            hours, mins, secs = _to_hours_mins_secs(time.time() - run_began)
            run_summary = (
                "Engine run complete. Time taken %02d:%02d:%02d"
                % (hours, mins, secs)
            )
            self._logger.info(run_summary)

        except BaseException as exc:
            self._logger.error(
                "Engine run is terminating due to exception: %s", str(exc))
            self._handle_exception(exc)

        return self.state
示例#8
0
文件: engine.py 项目: DNGros/ignite
    def _run_once_on_dataset(self, state):
        """Make one pass over ``state.dataloader`` and return the elapsed time.

        For every batch: store it in ``state.batch``, advance
        ``state.iteration``, fire ``ITERATION_STARTED``, store the process
        function's result in ``state.output`` and fire
        ``ITERATION_COMPLETED``. The pass stops early once
        ``self.should_terminate`` is set.

        Exceptions raised during the pass are logged and handed to
        ``_handle_exception``. Timing is computed outside the ``try`` so that,
        when ``_handle_exception`` returns instead of re-raising, the caller
        still receives a ``(hours, mins, secs)`` tuple rather than an
        implicit ``None``.

        Args:
            state: run state that provides ``dataloader`` and receives
                ``batch``, ``iteration`` and ``output``.

        Returns:
            tuple: ``(hours, mins, secs)`` elapsed during the pass, up to the
            point of failure if an exception was handled.
        """
        start_time = time.time()
        try:
            for batch in state.dataloader:
                state.batch = batch
                state.iteration += 1
                self._fire_event(Events.ITERATION_STARTED, state)
                state.output = self._process_function(batch)
                self._fire_event(Events.ITERATION_COMPLETED, state)
                if self.should_terminate:
                    break
        except BaseException as e:
            self._logger.error("Current run is terminating due to exception: %s", str(e))
            self._handle_exception(state, e)

        # Reached on normal completion and after a handled exception alike.
        time_taken = time.time() - start_time
        hours, mins, secs = _to_hours_mins_secs(time_taken)
        return hours, mins, secs
示例#9
0
文件: engine.py 项目: taohu88/tea
    def _run_once_on_dataset(self, dataloader):
        """Make one pass over ``dataloader`` and return the elapsed time.

        For every batch: advance ``state.iteration``, fire
        ``ITERATION_STARTED``, store the process function's result in
        ``state.output`` and fire ``ITERATION_COMPLETED``. The pass stops
        early when either termination flag is set; the single-epoch flag is
        cleared on exit. Exceptions are logged and passed to
        ``_handle_exception``.

        Returns:
            tuple: ``(hours, mins, secs)`` elapsed during the pass.
        """
        pass_began = time.time()

        try:
            for batch in dataloader:
                self.state.iteration += 1
                self._fire_event(Events.ITERATION_STARTED)
                self.state.output = self._process_function(self, batch)
                self._fire_event(Events.ITERATION_COMPLETED)
                if not (self.should_terminate or self.should_terminate_single_epoch):
                    continue
                # The single-epoch request is consumed by this early exit.
                self.should_terminate_single_epoch = False
                break

        except BaseException as exc:
            self._logger.error("Current run is terminating due to exception: %s", str(exc))
            self._handle_exception(exc)

        hours, mins, secs = _to_hours_mins_secs(time.time() - pass_began)
        return hours, mins, secs
示例#10
0
    def _run_once_on_dataset(self, dataloader):
        """Make one pass over ``dataloader`` and return the elapsed time.

        ``dataloader`` is stored on ``self.dataloader``. For every batch:
        advance ``current_iteration``, fire ``ITERATION_STARTED``, call the
        process function (appending any non-``None`` result to
        ``self.history``) and fire ``ITERATION_COMPLETED``. The pass stops
        early once ``self.should_terminate`` is set.

        Exceptions raised during the pass are logged and handed to
        ``_handle_exception``. Timing is computed outside the ``try`` so that,
        when ``_handle_exception`` returns instead of re-raising, the caller
        still receives a ``(hours, mins, secs)`` tuple rather than an
        implicit ``None``.

        Args:
            dataloader: iterable of batches.

        Returns:
            tuple: ``(hours, mins, secs)`` elapsed during the pass, up to the
            point of failure if an exception was handled.
        """
        self.dataloader = dataloader
        start_time = time.time()
        try:
            for batch in dataloader:
                self.current_iteration += 1
                self._fire_event(Events.ITERATION_STARTED)
                step_result = self._process_function(batch)
                if step_result is not None:
                    self.history.append(step_result)

                self._fire_event(Events.ITERATION_COMPLETED)
                if self.should_terminate:
                    break
        except BaseException as e:
            self._logger.error(
                "Current run is terminating due to exception: %s", str(e))
            self._handle_exception(e)

        # Reached on normal completion and after a handled exception alike.
        time_taken = time.time() - start_time
        hours, mins, secs = _to_hours_mins_secs(time_taken)
        return hours, mins, secs
示例#11
0
文件: engine.py 项目: schopfej/ignite
    def _run_once_on_dataset(self):
        """Run the iterations of one epoch and return the elapsed time.

        Batches are pulled from ``self._dataloader_iter`` with ``next``. When
        the iterator is exhausted it is re-created from
        ``self.state.dataloader`` and the fetch is retried. The loop ends when
        the per-epoch counter reaches ``self.state.epoch_length``, when a
        termination flag is set, or when a freshly re-created iterator yields
        nothing.

        Exceptions (anything derived from ``BaseException``) are logged and
        passed to ``self._handle_exception``.

        Returns:
            tuple: ``(hours, mins, secs)`` from ``_to_hours_mins_secs`` applied
            to the wall-clock time spent in this call.
        """
        start_time = time.time()

        # We need to setup iter_counter > 0 if we resume from an iteration
        iter_counter = self._init_iter.pop() if len(self._init_iter) > 0 else 0
        # True only right after the iterator was re-created; a second
        # consecutive StopIteration then means no more data can be obtained.
        should_exit = False
        try:
            while True:
                try:
                    # GET_BATCH_COMPLETED is not fired when next() raises StopIteration.
                    self._fire_event(Events.GET_BATCH_STARTED)
                    batch = next(self._dataloader_iter)
                    self._fire_event(Events.GET_BATCH_COMPLETED)
                    iter_counter += 1
                    should_exit = False
                except StopIteration:
                    # First exhaustion with unknown length: record how many
                    # batches were seen.
                    if self._dataloader_len is None:
                        if iter_counter > 0:
                            self._dataloader_len = iter_counter
                        else:
                            # this can happen when data is finite iterator and epoch_length is equal to its size
                            self._dataloader_len = self.state.iteration

                    # Should exit while loop if we can not iterate
                    if should_exit:
                        if not self._is_done(self.state):
                            warnings.warn(
                                "Data iterator can not provide data anymore but required total number of "
                                "iterations to run is not reached. "
                                "Current iteration: {} vs Total iterations to run : {}"
                                .format(
                                    self.state.iteration,
                                    self.state.epoch_length *
                                    self.state.max_epochs))
                        break

                    # set seed on restart of data iterator
                    self.setup_seed()
                    self._dataloader_iter = iter(self.state.dataloader)

                    should_exit = True

                    # Retry the fetch with the new iterator.
                    continue

                self.state.batch = batch
                self.state.iteration += 1
                self._fire_event(Events.ITERATION_STARTED)
                self.state.output = self._process_function(
                    self, self.state.batch)
                self._fire_event(Events.ITERATION_COMPLETED)

                # TODO: remove refs on batch to avoid high mem consumption ? -> need verification
                # self.state.batch = batch = None

                if self.should_terminate or self.should_terminate_single_epoch:
                    # The single-epoch flag is consumed here; the seed is set
                    # again and the iterator re-created before leaving.
                    self.should_terminate_single_epoch = False
                    self._manual_seed(self.state.seed,
                                      self.state.iteration // iter_counter)
                    self._dataloader_iter = iter(self.state.dataloader)
                    break

                # Epoch is complete once the counter reaches epoch_length.
                if iter_counter == self.state.epoch_length:
                    break

        except BaseException as e:
            self.logger.error(
                "Current run is terminating due to exception: %s.", str(e))
            self._handle_exception(e)

        # Timing is computed on normal completion and after a handled exception.
        time_taken = time.time() - start_time
        hours, mins, secs = _to_hours_mins_secs(time_taken)

        return hours, mins, secs
示例#12
0
    def log_validation_results(trainer):
        """Report training/validation metrics for the epoch and persist history.

        Resets ``trainer.state.iteration`` to 0, prints and records the
        training cross-entropy, runs ``evaluator`` on
        ``validation_dataloader``, then prints the validation cross-entropy,
        Rouge-1 and Rouge-2. A metric that beats the best value stored on
        ``trainer.state`` is wrapped in ``Fore.GREEN`` / ``Style.RESET_ALL``
        and becomes the new best value. Validation values are appended to
        ``trainer.state.validation_history``, the time since
        ``trainer.state.start_time`` is printed and, if ``results_path`` is
        truthy, both histories are written to it as JSON.

        Args:
            trainer: engine whose ``state`` carries the metrics, histories,
                best-so-far values and ``start_time`` read and updated here.
        """

        def _render(label, value, improved):
            # Format one metric, highlighted when it is a new best.
            text = "{}={:.3f}".format(label, value)
            if improved:
                return Fore.GREEN + text + Style.RESET_ALL
            return text

        trainer.state.iteration = 0
        train_metrics = trainer.state.metrics
        print("Epoch[{}] Training X-Entropy={:.3f}".format(
            trainer.state.epoch, train_metrics["x-entropy"]))
        trainer.state.training_history["x-entropy"].append(
            train_metrics["x-entropy"])

        evaluator.run(validation_dataloader)

        valid_metrics = evaluator.state.metrics
        valid_history = trainer.state.validation_history
        valid_metric_strings = []

        # Cross-entropy improves when it decreases.
        xent = valid_metrics["x-entropy"]
        improved = xent < trainer.state.min_valid_xent
        valid_metric_strings.append(_render("X-Entropy", xent, improved))
        if improved:
            trainer.state.min_valid_xent = xent

        # Rouge scores improve when they increase.
        rouge1 = valid_metrics["rouge"]["rouge-1"]
        improved = rouge1 > trainer.state.max_valid_rouge1
        valid_metric_strings.append(_render("Rouge-1", rouge1, improved))
        if improved:
            trainer.state.max_valid_rouge1 = rouge1

        rouge2 = valid_metrics["rouge"]["rouge-2"]
        improved = rouge2 > trainer.state.max_valid_rouge2
        valid_metric_strings.append(_render("Rouge-2", rouge2, improved))
        if improved:
            trainer.state.max_valid_rouge2 = rouge2

        print("Epoch[{}] Validation {}".format(
            trainer.state.epoch,
            " ".join(valid_metric_strings)))

        valid_history["x-entropy"].append(xent)
        valid_history["rouge-1"].append(rouge1)
        valid_history["rouge-2"].append(rouge2)

        hrs, mins, secs = _to_hours_mins_secs(
            time.time() - trainer.state.start_time)
        print("Epoch[{}] Time Taken: {:02.0f}:{:02.0f}:{:02.0f}".format(
            trainer.state.epoch, hrs, mins, secs))

        print()

        if results_path:
            # exist_ok=True already tolerates an existing directory, so no
            # separate exists() check is needed.
            results_path.parent.mkdir(parents=True, exist_ok=True)
            results_path.write_text(
                json.dumps({"training": trainer.state.training_history,
                            "validation": trainer.state.validation_history}))