예제 #1
0
 def _run_iteration(self):
     """Pull one batch from the epoch iterator and train on it.

     Returns
     -------
     bool
         ``False`` when the epoch iterator is exhausted, ``True`` after a
         batch has been processed and the per-batch extensions have run.

     Raises
     ------
     ValueError
         Re-raised through ``reraise_as`` when the iterator is exhausted
         before yielding a single batch in the current epoch.

     """
     try:
         with Timer('read_data', self.profile):
             current_batch = next(self.epoch_iterator)
     except StopIteration:
         # An epoch that ends before delivering any batch is treated as an
         # error rather than as a normal end of epoch.
         epoch_was_empty = not self.log.status['received_first_batch']
         if epoch_was_empty:
             reraise_as(ValueError("epoch iterator yielded zero batches"))
         return False

     self.log.status['received_first_batch'] = True
     self._run_extensions('before_batch', current_batch)
     with Timer('train', self.profile):
         self.algorithm.process_batch(current_batch)
     self.status['iterations_done'] += 1
     self._run_extensions('after_batch', current_batch)
     self._check_finish_training('batch')
     return True
예제 #2
0
    def run(self):
        """Start the main loop and run epochs until training stops.

        While the loop runs, SIGINT is routed to
        ``self._handle_epoch_interrupt`` and SIGTERM to
        ``self._handle_batch_interrupt``. ``self._restore_signal_handlers()``
        is called on every exit path.

        Training stops when ``self._run_epoch()`` returns a falsy value or
        a ``TrainingFinish`` exception is raised.

        Raises
        ------
        Exception
            Any other exception raised during training is:

            - recorded in ``self.log.current_row['got_exception']``,
            - passed through the extensions' ``on_error`` callbacks,
            - re-raised via ``reraise_as``.

        """
        # Does nothing if logging is already configured; otherwise it makes
        # sure that at least error messages are emitted.
        logging.basicConfig()

        with change_recursion_limit(cfg.recursion_limit):
            # signal.signal returns the previous handler; keep it so that it
            # can be restored when the loop exits.
            self.original_sigint_handler = signal.signal(
                signal.SIGINT, self._handle_epoch_interrupt)
            self.original_sigterm_handler = signal.signal(
                signal.SIGTERM, self._handle_batch_interrupt)
            try:
                logger.info("Entered the main loop")
                if not self.status['training_started']:
                    for extension in self.extensions:
                        extension.main_loop = self
                    self._run_extensions('before_training')
                    with Timer('initialization', self.profile):
                        self.algorithm.initialize()
                    self.status['training_started'] = True
                # Deliberately not an ``else``: extensions called in
                # "before_training" may have changed the loop's status.
                if self.log.status['iterations_done'] > 0:
                    self._run_extensions('on_resumption')
                    self.status['epoch_interrupt_received'] = False
                    self.status['batch_interrupt_received'] = False
                with Timer('training', self.profile):
                    while self._run_epoch():
                        pass
            except TrainingFinish:
                self.log.current_row['training_finished'] = True
            except Exception as e:
                self._restore_signal_handlers()
                # format_exc() formats the exception currently being handled.
                # Its optional argument is an integer `limit`, not an
                # exception. Passing `e` raises TypeError on Python 3 and
                # would mask the original error.
                self.log.current_row['got_exception'] = traceback.format_exc()
                logger.error("Error occured during training." + error_message)
                try:
                    self._run_extensions('on_error')
                except Exception:
                    logger.error(traceback.format_exc())
                    logger.error("Error occured when running extensions." +
                                 error_in_error_handling_message)
                reraise_as(e)
            finally:
                if self.log.current_row.get('training_finished', False):
                    self._run_extensions('after_training')
                if cfg.profile:
                    self.profile.report()
                self._restore_signal_handlers()
예제 #3
0
 def _run_epoch(self):
     """Run one epoch, opening a new epoch iterator unless resuming mid-epoch.

     Returns
     -------
     bool
         ``False`` when the data stream raises ``StopIteration`` while a new
         epoch iterator is requested, ``True`` once an epoch has completed
         and the ``after_epoch`` extensions have run.

     """
     if not self.status.get('epoch_started', False):
         self.log.status['received_first_batch'] = False
         # Only requesting the iterator can signal "no more epochs".
         try:
             self.epoch_iterator = self.data_stream.get_epoch_iterator(
                 as_dict=True)
         except StopIteration:
             return False
         self.status['epoch_started'] = True
         self._run_extensions('before_epoch')

     with Timer('epoch', self.profile):
         while self._run_iteration():
             pass

     # End-of-epoch bookkeeping; record the iteration count at which this
     # epoch ended.
     loop_status = self.status
     loop_status['epoch_started'] = False
     loop_status['epochs_done'] += 1
     loop_status['_epoch_ends'].append(loop_status['iterations_done'])

     self._run_extensions('after_epoch')
     self._check_finish_training('epoch')
     return True
예제 #4
0
 def _run_extensions(self, method_name, *args):
     """Dispatch the callback ``method_name`` to every extension, in order.

     The whole dispatch runs inside a ``Timer`` labelled ``method_name``.
     Each extension's call runs inside a nested ``Timer`` labelled with the
     extension's class name. Both timers are given ``self.profile``.
     Extra positional ``args`` are forwarded to each ``dispatch`` call.

     """
     with Timer(method_name, self.profile):
         for ext in self.extensions:
             timer_label = type(ext).__name__
             with Timer(timer_label, self.profile):
                 ext.dispatch(CallbackName(method_name), *args)
예제 #5
0
    def run(self):
        """Starts the main loop.

        The main loop ends when a training extension makes
        a `training_finish_requested` record in the log.

        While the loop runs, SIGINT is routed to
        ``self._handle_epoch_interrupt`` and SIGTERM to
        ``self._handle_batch_interrupt``. ``self._restore_signal_handlers()``
        is called on every exit path.

        Raises
        ------
        Exception
            Any exception other than ``TrainingFinish`` raised during
            training is:

            - recorded in ``self.log.current_row['got_exception']``,
            - passed through the extensions' ``on_error`` callbacks,
            - re-raised via ``reraise_as``.

        """
        # This should do nothing if the user has already configured
        # logging, and will at least enable error messages otherwise.
        logging.basicConfig()

        if self._model and isinstance(self.algorithm,
                                      DifferentiableCostMinimizer):
            # Sanity check: model and algorithm should be configured
            # similarly. Mismatches are only warned about, not fatal.
            if not self._model.get_objective() == self.algorithm.cost:
                logger.warning("different costs for model and algorithm")
            if not (set(self._model.get_params().values()) == set(
                    self.algorithm.params)):
                logger.warning("different params for model and algorithm")

        with change_recursion_limit(config.recursion_limit):
            # signal.signal returns the previous handler; keep it so that it
            # can be restored when the loop exits.
            self.original_sigint_handler = signal.signal(
                signal.SIGINT, self._handle_epoch_interrupt)
            self.original_sigterm_handler = signal.signal(
                signal.SIGTERM, self._handle_batch_interrupt)
            try:
                logger.info("Entered the main loop")
                if not self.status['training_started']:
                    for extension in self.extensions:
                        extension.main_loop = self
                    self._run_extensions('before_training')
                    with Timer('initialization', self.profile):
                        self.algorithm.initialize()
                    self.status['training_started'] = True
                # We can not write "else:" here because extensions
                # called "before_training" could have changed the status
                # of the main loop.
                if self.log.status['iterations_done'] > 0:
                    self._run_extensions('on_resumption')
                    self.status['epoch_interrupt_received'] = False
                    self.status['batch_interrupt_received'] = False
                with Timer('training', self.profile):
                    while self._run_epoch():
                        pass
            except TrainingFinish:
                self.log.current_row['training_finished'] = True
            except Exception as e:
                self._restore_signal_handlers()
                # format_exc() formats the exception currently being handled.
                # Its optional argument is an integer `limit`, not an
                # exception. Passing `e` raises TypeError on Python 3 and
                # would mask the original error.
                self.log.current_row['got_exception'] = traceback.format_exc()
                logger.error("Error occured during training." + error_message)
                try:
                    self._run_extensions('on_error')
                except Exception:
                    logger.error(traceback.format_exc())
                    logger.error("Error occured when running extensions." +
                                 error_in_error_handling_message)
                reraise_as(e)
            finally:
                if self.log.current_row.get('training_finished', False):
                    self._run_extensions('after_training')
                if config.profile:
                    self.profile.report()
                self._restore_signal_handlers()
예제 #6
0
    def run(self):
        """Starts the main loop.

        The main loop ends when a training extension makes
        a `training_finish_requested` record in the log.

        While the loop runs, SIGINT is routed to
        ``self._handle_epoch_interrupt`` and SIGTERM to
        ``self._handle_batch_interrupt``. ``self._restore_signal_handlers()``
        is called on every exit path.

        Raises
        ------
        Exception
            Any exception other than ``TrainingFinish`` raised during
            training is:

            - recorded in ``self.log.current_row['got_exception']``,
            - passed (as an argument) to the extensions' ``on_error``
              callbacks,
            - re-raised through ``reraise_as``.

        """
        # This should do nothing if the user has already configured
        # logging, and will at least enable error messages otherwise.
        logging.basicConfig()

        # If this is resumption from a checkpoint, it is crucial to
        # reset `profile.current`. Otherwise, it simply does not hurt.
        self.profile.current = []

        # Check the model only if it wants to be checked, i.e. it defines
        # a `check_sanity` method.
        if hasattr(self._model, 'check_sanity'):
            self._model.check_sanity(self.algorithm)

        with change_recursion_limit(config.recursion_limit):
            # signal.signal returns the previous handler; keep it so that it
            # can be restored when the loop exits.
            self.original_sigint_handler = signal.signal(
                signal.SIGINT, self._handle_epoch_interrupt)
            self.original_sigterm_handler = signal.signal(
                signal.SIGTERM, self._handle_batch_interrupt)
            try:
                logger.info("Entered the main loop")
                # First-time start: attach extensions, run their
                # "before_training" hooks and initialize the algorithm.
                if not self.status['training_started']:
                    for extension in self.extensions:
                        extension.main_loop = self
                    self._run_extensions('before_training')
                    with Timer('initialization', self.profile):
                        self.algorithm.initialize()
                    self.status['training_started'] = True
                # We can not write "else:" here because extensions
                # called "before_training" could have changed the status
                # of the main loop.
                if self.log.status['iterations_done'] > 0:
                    self.log.resume()
                    self._run_extensions('on_resumption')
                    # Interrupt flags recorded before the checkpoint must
                    # not stop the resumed run.
                    self.status['epoch_interrupt_received'] = False
                    self.status['batch_interrupt_received'] = False
                with Timer('training', self.profile):
                    while self._run_epoch():
                        pass
            except TrainingFinish:
                self.log.current_row['training_finished'] = True
            except Exception as e:
                # Restore the original handlers before the error-handling
                # extensions run.
                self._restore_signal_handlers()
                self.log.current_row['got_exception'] = traceback.format_exc()
                logger.error("Error occured during training." + error_message)
                try:
                    self._run_extensions('on_error', e)
                except Exception:
                    # A failing `on_error` extension is logged but must not
                    # replace the original training error.
                    logger.error(traceback.format_exc())
                    logger.error("Error occured when running extensions." +
                                 error_in_error_handling_message)
                reraise_as(e)
            finally:
                # Handlers are restored first on every exit path (this is a
                # repeat call when the `except Exception` branch ran).
                self._restore_signal_handlers()
                # "after_training" runs only when the log marks training as
                # finished (set above when TrainingFinish is caught).
                if self.log.current_row.get('training_finished', False):
                    self._run_extensions('after_training')
                if config.profile:
                    self.profile.report()