def __init__(self, models, algorithm, data_stream,
                 num_encs=1, num_decs=1, log=None, extensions=None):
        """
        models : dict, mapping cg_name to blocks.model
        algorithm : SGDMultiCG
        data_stream : data stream, either MultiSourceStream or MultiEncStream
        num_encs : int, number of encoders
        num_decs : int, number of decoders
        log : blocks.log, the logger object
        extensions : blocks.extensions, the main loop extensions
        """
        self.models = models
        self.num_encs = num_encs
        self.num_decs = num_decs
        self.num_cgs = len(models)

        if log is None:
            log = TrainingLog()
        if extensions is None:
            extensions = []

        self.data_stream = data_stream
        self.algorithm = algorithm
        self.log = log
        self.extensions = extensions

        self.profile = Profile()

        self.status['training_started'] = False
        self.status['epoch_started'] = False
        self.status['epoch_interrupt_received'] = False
        self.status['batch_interrupt_received'] = False
Example #2
    def __init__(self,
                 algorithm,
                 data_stream,
                 model=None,
                 log=None,
                 log_backend=None,
                 extensions=None):
        if log is None:
            if log_backend is None:
                log_backend = config.log_backend
            log = BACKENDS[log_backend]()
        if extensions is None:
            extensions = []

        self.data_stream = data_stream
        self.epoch_iterator = None
        self.algorithm = algorithm
        self.log = log
        self.extensions = extensions

        self.profile = Profile()

        self._model = model

        self.status['training_started'] = False
        self.status['epoch_started'] = False
        self.status['epoch_interrupt_received'] = False
        self.status['batch_interrupt_received'] = False
    def __init__(self, algorithm_g, g_out, algorithm_d, d_out, data_stream, 
                 false_generated, false_dataset,
                 generator=None, discriminator=None, noise_per_sample=10, k=1,
                 minibatches=1, log=None, log_backend=None, extensions=None):
        if log is None:
            if log_backend is None:
                log_backend = config.log_backend
            log = BACKENDS[log_backend]()
        if extensions is None:
            extensions = []

        self.data_stream = data_stream
        self.algorithm = algorithm_g
        self.algorithm_g = algorithm_g
        self.algorithm_d = algorithm_d
        self.log = log
        self.extensions = extensions
        self.g_out = g_out
        self.d_out = d_out
        self.k = k
        self.minibatches = minibatches
        self.noise_per_sample = noise_per_sample
        self.false_generated = false_generated
        self.false_dataset = false_dataset

        self.profile = Profile()

        self._generator = generator

        self._discriminator = discriminator

        self.status['training_started'] = False
        self.status['epoch_started'] = False
        self.status['epoch_interrupt_received'] = False
        self.status['batch_interrupt_received'] = False
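
# A minimal, hedged sketch of the BACKENDS registry that the snippets
# above dispatch on when `log` is not given. In Blocks the real mapping
# lives in blocks.log; the optional sqlite import is an assumption about
# how such a registry is typically populated.
from blocks.log import TrainingLog  # pure-Python log backend

BACKENDS = {'python': TrainingLog}
try:
    from blocks.log import SQLiteLog
    BACKENDS['sqlite'] = SQLiteLog
except ImportError:  # the sqlite backend is optional
    pass

# `log = BACKENDS[log_backend]()` is then equivalent to
# `log = TrainingLog()` for the default 'python' backend.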
Example #5
    def __init__(self,
                 algorithm,
                 data_stream,
                 model=None,
                 log=None,
                 extensions=None):
        if log is None:
            log = TrainingLog()
        if extensions is None:
            extensions = []

        self.data_stream = data_stream
        self.algorithm = algorithm
        self.log = log
        self.extensions = extensions

        self.profile = Profile()

        self._model = model

        self.status['training_started'] = False
        self.status['epoch_started'] = False
        self.status['epoch_interrupt_received'] = False
        self.status['batch_interrupt_received'] = False
Example #8
class MainLoop(object):
    """The standard main loop of Blocks.

    In the `MainLoop` a model is trained by a training algorithm using data
    extracted from a data stream. This process is scrupulously documented
    in a log object.

    The `MainLoop` itself does very little: only fetching the data from the
    data stream and feeding it to the algorithm. It expects the extensions
    to do most of the work. The respective callback of every extension is
    called at every stage of training. The extensions should communicate
    with each other and with the main loop object by making records in
    the log. For instance, in order to stop the training procedure an
    extension can make a record `training_finish_requested=True` in the
    log. The main loop checks for such a record after every batch and
    every epoch and terminates when it finds one.

    The `MainLoop` also handles the interruption signal SIGINT for you
    (e.g. the one a program receives when you press Ctrl + C). It notes
    this event in the log, and at the next iteration or epoch end the
    main loop finishes gracefully, calling all necessary extension
    callbacks and waiting until they finish.

    Parameters
    ----------
    algorithm : object
        The training algorithm.
    data_stream : instance of :class:`.DataStream`.
        The data stream.
    model : :class:`.AbstractModel` instance, optional
        The model object. It is entirely transparent for the main loop
        but may be used by extensions.
    log : instance of :class:`.TrainingLog`, optional
        The log. When not given, a :class:`.TrainingLog` is created.
    extensions : list of :class:`.TrainingExtension` instances
        The training extensions. Will be called in the same order as given
        here.
    profile : :class:`.Profile`
        Keeps track of the time spent in different segments of the training
        loop.

    """
    def __init__(self,
                 algorithm,
                 data_stream,
                 model=None,
                 log=None,
                 extensions=None):
        if log is None:
            log = TrainingLog()
        if extensions is None:
            extensions = []

        self.data_stream = data_stream
        self.algorithm = algorithm
        self.log = log
        self.extensions = extensions

        self.profile = Profile()

        self._model = model

        self.status['training_started'] = False
        self.status['epoch_started'] = False
        self.status['epoch_interrupt_received'] = False
        self.status['batch_interrupt_received'] = False

    @property
    def model(self):
        if not self._model:
            raise AttributeError("no model in this main loop" +
                                 no_model_message)
        return self._model

    @property
    def iteration_state(self):
        """Quick access to the (data stream, epoch iterator) pair."""
        return (self.data_stream, self.epoch_iterator)

    @iteration_state.setter
    def iteration_state(self, value):
        (self.data_stream, self.epoch_iterator) = value

    @property
    def status(self):
        """A shortcut for `self.log.status`."""
        return self.log.status

    def run(self):
        """Starts the main loop.

        The main loop ends when a training extension makes
        a `training_finish_requested` record in the log.

        """
        # This should do nothing if the user has already configured
        # logging, and will at least enable error messages otherwise.
        logging.basicConfig()

        if self._model and isinstance(self.algorithm,
                                      DifferentiableCostMinimizer):
            # Sanity check: model and algorithm should be configured
            # similarly.
            if self._model.get_objective() != self.algorithm.cost:
                logger.warning("different costs for model and algorithm")
            if (set(self._model.get_params().values()) !=
                    set(self.algorithm.params)):
                logger.warning("different params for model and algorithm")

        with change_recursion_limit(config.recursion_limit):
            self.original_sigint_handler = signal.signal(
                signal.SIGINT, self._handle_epoch_interrupt)
            self.original_sigterm_handler = signal.signal(
                signal.SIGTERM, self._handle_batch_interrupt)
            try:
                logger.info("Entered the main loop")
                if not self.status['training_started']:
                    for extension in self.extensions:
                        extension.main_loop = self
                    self._run_extensions('before_training')
                    with Timer('initialization', self.profile):
                        self.algorithm.initialize()
                    self.status['training_started'] = True
                # We can not write "else:" here because extensions
                # called "before_training" could have changed the status
                # of the main loop.
                if self.log.status['iterations_done'] > 0:
                    self._run_extensions('on_resumption')
                    self.status['epoch_interrupt_received'] = False
                    self.status['batch_interrupt_received'] = False
                with Timer('training', self.profile):
                    while self._run_epoch():
                        pass
            except TrainingFinish:
                self.log.current_row['training_finished'] = True
            except Exception as e:
                self._restore_signal_handlers()
                self.log.current_row['got_exception'] = traceback.format_exc()
                logger.error("Error occurred during training." + error_message)
                try:
                    self._run_extensions('on_error')
                except Exception:
                    logger.error(traceback.format_exc())
                    logger.error("Error occurred when running extensions." +
                                 error_in_error_handling_message)
                reraise_as(e)
            finally:
                if self.log.current_row.get('training_finished', False):
                    self._run_extensions('after_training')
                if config.profile:
                    self.profile.report()
                self._restore_signal_handlers()

    def find_extension(self, name):
        """Find an extension with a given name.

        Parameters
        ----------
        name : str
            The name of the extension looked for.

        Notes
        -----
        Raises an error if no extension or more than one extension with
        the given name is found.

        """
        return unpack([extension for extension in self.extensions
                       if extension.name == name], singleton=True)

    def _run_epoch(self):
        if not self.status.get('epoch_started', False):
            try:
                self.log.status['received_first_batch'] = False
                self.epoch_iterator = (self.data_stream.get_epoch_iterator(
                    as_dict=True))
            except StopIteration:
                return False
            self.status['epoch_started'] = True
            self._run_extensions('before_epoch')
        with Timer('epoch', self.profile):
            while self._run_iteration():
                pass
        self.status['epoch_started'] = False
        self.status['epochs_done'] += 1
        self.status['_epoch_ends'].append(self.status['iterations_done'])
        self._run_extensions('after_epoch')
        self._check_finish_training('epoch')
        return True

    def _run_iteration(self):
        try:
            with Timer('read_data', self.profile):
                batch = next(self.epoch_iterator)
        except StopIteration:
            if not self.log.status['received_first_batch']:
                reraise_as(ValueError("epoch iterator yielded zero batches"))
            return False
        self.log.status['received_first_batch'] = True
        self._run_extensions('before_batch', batch)
        with Timer('train', self.profile):
            self.algorithm.process_batch(batch)
        self.status['iterations_done'] += 1
        self._run_extensions('after_batch', batch)
        self._check_finish_training('batch')
        return True

    def _run_extensions(self, method_name, *args):
        with Timer(method_name, self.profile):
            for extension in self.extensions:
                with Timer(type(extension).__name__, self.profile):
                    extension.dispatch(CallbackName(method_name), *args)

    def _check_finish_training(self, level):
        """Checks whether the current training should be terminated.

        Parameters
        ----------
        level : {'epoch', 'batch'}
            The level at which this check was performed. In some cases, we
            only want to quit after completing the remainder of the epoch.

        """
        # In case when keyboard interrupt is handled right at the end of
        # the iteration the corresponding log record can be found only in
        # the previous row.
        if (self.log.current_row.get('training_finish_requested', False)
                or self.status.get('batch_interrupt_received', False)):
            raise TrainingFinish
        if (level == 'epoch'
                and self.status.get('epoch_interrupt_received', False)):
            raise TrainingFinish

    def _handle_epoch_interrupt(self, signal_number, frame):
        # Try to complete the current epoch if user presses CTRL + C
        logger.warning('Received epoch interrupt signal.' +
                       epoch_interrupt_message)
        signal.signal(signal.SIGINT, self._handle_batch_interrupt)
        self.log.current_row['epoch_interrupt_received'] = True
        # Add a record to the status. Unlike the log record it will be
        # easy to access at later iterations.
        self.status['epoch_interrupt_received'] = True

    def _handle_batch_interrupt(self, signal_number, frame):
        # After 2nd CTRL + C or SIGTERM signal (from cluster) finish batch
        self._restore_signal_handlers()
        logger.warning('Received batch interrupt signal.' +
                       batch_interrupt_message)
        self.log.current_row['batch_interrupt_received'] = True
        # Add a record to the status. Unlike the log record it will be
        # easy to access at later iterations.
        self.status['batch_interrupt_received'] = True

    def _restore_signal_handlers(self):
        signal.signal(signal.SIGINT, self.original_sigint_handler)
        signal.signal(signal.SIGTERM, self.original_sigterm_handler)
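
# A minimal, hedged usage sketch for the MainLoop above. `cost`,
# `parameters` and `train_stream` stand in for user code; note the
# GradientDescent keyword is `params` in older Blocks releases and
# `parameters` in later ones.
from blocks.algorithms import GradientDescent, Scale
from blocks.extensions import FinishAfter, Printing

algorithm = GradientDescent(cost=cost, params=parameters,
                            step_rule=Scale(learning_rate=0.1))
main_loop = MainLoop(algorithm=algorithm, data_stream=train_stream,
                     extensions=[FinishAfter(after_n_epochs=5), Printing()])
main_loop.run()
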
class GANMainLoop(object):
    """The standard main loop for GAN.

    Parameters
    ----------
    algorithm_g : instance of :class:`~blocks.algorithms.TrainingAlgorithm`
        The training algorithm for the generator.
    algorithm_d : instance of :class:`~blocks.algorithms.TrainingAlgorithm`
        The training algorithm for the discriminator.
    data_stream : instance of :class:`.DataStream`.
        The data stream. Should support :class:`AbstractDataStream`
        interface from Fuel.
    generator : instance of :class:`.ComputationGraph`
    discriminator : instance of :class:`.ComputationGraph`
    log : instance of :class:`.TrainingLog`, optional
        The log. When not given, a :class:`.TrainingLog` is created.
    log_backend : str
        The backend to use for the log. Currently `python` and `sqlite` are
        available. If not given, `config.log_backend` will be used. Ignored
        if `log` is passed.
    extensions : list of :class:`.TrainingExtension` instances
        The training extensions. Will be called in the same order as given
        here.
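    g_out, d_out : theano shared variables
        Generator and discriminator outputs; stored on the main loop but
        not used by the loop body itself.
    false_generated, false_dataset : theano shared variables
        Accumulators for the discriminator errors on generated and on
        dataset samples; reset at the start of every iteration and
        written back to the log as error rates.
    generator_errors : theano shared variable
        Accumulator for the generator errors, handled like the above.
    noise_per_sample : int
        Dimensionality of the noise vectors fed to the generator.
    k : int
        Number of discriminator updates per generator update.
    minibatches : int
        Number of noise samples drawn per update.
    observables : list of theano shared variables, optional
        Extra quantities whose values are recorded in the log after
        every iteration.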

    """
    def __init__(self, algorithm_g, g_out, algorithm_d, d_out, data_stream,
                 false_generated, false_dataset, generator_errors,
                 generator=None, discriminator=None, noise_per_sample=10, k=1,
                 minibatches=1, log=None, log_backend=None, extensions=None,
                 observables=None):
        if log is None:
            if log_backend is None:
                log_backend = config.log_backend
            log = BACKENDS[log_backend]()
        if extensions is None:
            extensions = []
        if observables is None:
            # Avoid a mutable default argument shared across instances.
            observables = []

        self.data_stream = data_stream
        self.algorithm = algorithm_g
        self.algorithm_g = algorithm_g
        self.algorithm_d = algorithm_d
        self.log = log
        self.extensions = extensions
        self.g_out = g_out
        self.d_out = d_out
        self.k = k
        self.minibatches = minibatches
        self.noise_per_sample = noise_per_sample
        self.false_generated = false_generated
        self.false_dataset = false_dataset
        self.generator_errors = generator_errors
        self.observables = observables
        self.first_batch = None

        self.profile = Profile()

        self._generator = generator

        self._discriminator = discriminator

        self.status['training_started'] = False
        self.status['epoch_started'] = False
        self.status['epoch_interrupt_received'] = False
        self.status['batch_interrupt_received'] = False

    @property
    def generator(self):
        if not self._generator:
            raise AttributeError("no generator in this main loop" +
                                 no_model_message)
        return self._generator

    @property
    def discriminator(self):
        if not self._discriminator:
            raise AttributeError("no discriminator in this main loop" +
                                 no_model_message)
        return self._discriminator

    @property
    def iteration_state(self):
        """Quick access to the (data stream, epoch iterator) pair."""
        return (self.data_stream, self.epoch_iterator)

    @iteration_state.setter
    def iteration_state(self, value):
        (self.data_stream, self.epoch_iterator) = value

    @property
    def status(self):
        """A shortcut for `self.log.status`."""
        return self.log.status

    def run(self):
        """Starts the main loop.

        The main loop ends when a training extension makes
        a `training_finish_requested` record in the log.

        """
        # This should do nothing if the user has already configured
        # logging, and will at least enable error messages otherwise.
        logging.basicConfig()

        # If this is resumption from a checkpoint, it is crucial to
        # reset `profile.current`. Otherwise, it simply does not hurt.
        self.profile.current = []

        with change_recursion_limit(config.recursion_limit):
            self.original_sigint_handler = signal.signal(
                signal.SIGINT, self._handle_epoch_interrupt)
            self.original_sigterm_handler = signal.signal(
                signal.SIGTERM, self._handle_batch_interrupt)
            try:
                logger.info("Entered the main loop")
                if not self.status['training_started']:
                    for extension in self.extensions:
                        extension.main_loop = self
                    self._run_extensions('before_training')
                    with Timer('initialization', self.profile):
                        self.algorithm_g.initialize()
                        self.algorithm_d.initialize()
                    self.status['training_started'] = True
                # We can not write "else:" here because extensions
                # called "before_training" could have changed the status
                # of the main loop.
                if self.log.status['iterations_done'] > 0:
                    self.log.resume()
                    self._run_extensions('on_resumption')
                    self.status['epoch_interrupt_received'] = False
                    self.status['batch_interrupt_received'] = False
                with Timer('training', self.profile):
                    while self._run_epoch():
                        pass
            except TrainingFinish:
                self.log.current_row['training_finished'] = True
            except Exception as e:
                self._restore_signal_handlers()
                self.log.current_row['got_exception'] = traceback.format_exc()
                logger.error("Error occured during training." + error_message)
                try:
                    self._run_extensions('on_error')
                except Exception:
                    logger.error(traceback.format_exc())
                    logger.error("Error occured when running extensions." +
                                 error_in_error_handling_message)
                reraise_as(e)
            finally:
                self._restore_signal_handlers()
                if self.log.current_row.get('training_finished', False):
                    self._run_extensions('after_training')
                if config.profile:
                    self.profile.report()

    def find_extension(self, name):
        """Find an extension with a given name.

        Parameters
        ----------
        name : str
            The name of the extension looked for.

        Notes
        -----
        Raises an error if no extension or more than one extension with
        the given name is found.

        """
        return unpack([extension for extension in self.extensions
                       if extension.name == name], singleton=True)

    def _run_epoch(self):
        if not self.status.get('epoch_started', False):
            try:
                self.log.status['received_first_batch'] = False
                self.epoch_iterator = (self.data_stream.
                                       get_epoch_iterator(as_dict=True))
            except StopIteration:
                return False
            self.status['epoch_started'] = True
            self._run_extensions('before_epoch')
        with Timer('epoch', self.profile):
            while self._run_iteration():
                pass
        self.status['epoch_started'] = False
        self.status['epochs_done'] += 1
        # Log might not allow mutating objects, so use += instead of append
        self.status['_epoch_ends'] += [self.status['iterations_done']]
        self._run_extensions('after_epoch')
        self._check_finish_training('epoch')
        return True

    def _run_iteration(self):
        # Run k discriminator ministeps before each generator update.
        ministeps_made = 0
        self.false_generated.set_value(0.)
        self.false_dataset.set_value(0.)
        while ministeps_made < self.k:
            try:
                with Timer('read_data', self.profile):
                    batch = next(self.epoch_iterator)
            except StopIteration:
                if not self.log.status['received_first_batch']:
                    reraise_as(ValueError("epoch iterator yielded zero batches"))
                return False
            self.log.status['received_first_batch'] = True
            self._run_extensions('before_batch', batch)
            batch = batch['features']
            # Draw fresh noise and map it through the compiled generator.
            noise = np.random.rand(self.minibatches,
                                   self.noise_per_sample).astype(np.float32)
            generated_batch = self._generator(noise)[0]
            # Stack generated and real samples into one discriminator
            # batch; this layout assumes `minibatches` equals the
            # stream's batch size.
            bound_batch = np.zeros((batch.shape[0] * 2, batch.shape[1]),
                                   dtype=np.float32)
            bound_batch[:self.minibatches, :] = generated_batch
            bound_batch[self.minibatches:, :] = batch
            # Dump the latest generated samples for offline inspection.
            np.save('generated.npy', generated_batch)
            bound_batch = {'features': bound_batch}
            # if self.first_batch is None:
            #     self.first_batch = bound_batch
            with Timer('train', self.profile):
                # self.algorithm_d.process_batch(self.first_batch)
                self.algorithm_d.process_batch(bound_batch)
            ministeps_made += 1
        
        false_generated_perc = self.false_generated.get_value() / (self.k * self.minibatches)
        false_dataset_perc = self.false_dataset.get_value() / (self.k * self.minibatches)
        self.log[self.status['iterations_done'] + 1]['error_on_generated'] = false_generated_perc
        self.log[self.status['iterations_done'] + 1]['error_on_dataset'] = false_dataset_perc
        self.generator_errors.set_value(0.)
        noise = np.random.rand(self.minibatches, self.noise_per_sample).astype(np.float32)
        noise_batch = {'noise': noise}
        with Timer('train', self.profile):
            self.algorithm_g.process_batch(noise_batch)
        gen_errors = self.generator_errors.get_value() / self.minibatches
        self.log[self.status['iterations_done'] + 1]['generator_errors'] = gen_errors
        for o in self.observables:
            self.log[self.status['iterations_done'] + 1][o.name] = o.get_value()
        self.status['iterations_done'] += 1
        self._run_extensions('after_batch', bound_batch)
        self._check_finish_training('batch')
        return True

    def _run_extensions(self, method_name, *args):
        with Timer(method_name, self.profile):
            for extension in self.extensions:
                with Timer(type(extension).__name__, self.profile):
                    extension.dispatch(CallbackName(method_name), *args)

    def _check_finish_training(self, level):
        """Checks whether the current training should be terminated.

        Parameters
        ----------
        level : {'epoch', 'batch'}
            The level at which this check was performed. In some cases, we
            only want to quit after completing the remainder of the epoch.

        """
        # In case when keyboard interrupt is handled right at the end of
        # the iteration the corresponding log record can be found only in
        # the previous row.
        if (self.log.current_row.get('training_finish_requested', False) or
                self.status.get('batch_interrupt_received', False)):
            raise TrainingFinish
        if (level == 'epoch' and
                self.status.get('epoch_interrupt_received', False)):
            raise TrainingFinish

    def _handle_epoch_interrupt(self, signal_number, frame):
        # Try to complete the current epoch if user presses CTRL + C
        logger.warning('Received epoch interrupt signal.' +
                       epoch_interrupt_message)
        signal.signal(signal.SIGINT, self._handle_batch_interrupt)
        self.log.current_row['epoch_interrupt_received'] = True
        # Add a record to the status. Unlike the log record it will be
        # easy to access at later iterations.
        self.status['epoch_interrupt_received'] = True

    def _handle_batch_interrupt(self, signal_number, frame):
        # After 2nd CTRL + C or SIGTERM signal (from cluster) finish batch
        self._restore_signal_handlers()
        logger.warning('Received batch interrupt signal.' +
                       batch_interrupt_message)
        self.log.current_row['batch_interrupt_received'] = True
        # Add a record to the status. Unlike the log record it will be
        # easy to access at later iterations.
        self.status['batch_interrupt_received'] = True

    def _restore_signal_handlers(self):
        signal.signal(signal.SIGINT, self.original_sigint_handler)
        signal.signal(signal.SIGTERM, self.original_sigterm_handler)
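
# A hedged construction sketch for the GANMainLoop above. The shared
# accumulators mirror how _run_iteration resets and reads them back;
# `algorithm_g`, `algorithm_d`, `g_out`, `d_out`, `generator_fn`,
# `discriminator_cg` and `train_stream` stand in for user code.
import numpy as np
import theano
from blocks.extensions import FinishAfter, Printing

false_generated = theano.shared(np.float32(0.))   # D errors on fakes
false_dataset = theano.shared(np.float32(0.))     # D errors on real data
generator_errors = theano.shared(np.float32(0.))  # G errors against D

main_loop = GANMainLoop(algorithm_g=algorithm_g, g_out=g_out,
                        algorithm_d=algorithm_d, d_out=d_out,
                        data_stream=train_stream,
                        false_generated=false_generated,
                        false_dataset=false_dataset,
                        generator_errors=generator_errors,
                        generator=generator_fn,  # compiled: noise -> samples
                        discriminator=discriminator_cg,
                        noise_per_sample=10, k=1, minibatches=128,
                        extensions=[FinishAfter(after_n_epochs=20),
                                    Printing()])
main_loop.run()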
Example #10
class MainLoop(object):
    """The standard main loop of Blocks.

    In the `MainLoop` a model is trained by a training algorithm using data
    extracted from a data stream. This process is scrupulously documented
    in a log object.

    The `MainLoop` itself does very little: only fetching the data from the
    data stream and feeding it to the algorithm. It expects the extensions
    to do most of the work. The respective callback of every extension is
    called at every stage of training. The extensions should communicate
    with each other and with the main loop object by making records in
    the log. For instance, in order to stop the training procedure an
    extension can make a record `training_finish_requested=True` in the
    log. The main loop checks for such a record after every batch and
    every epoch and terminates when it finds one.

    The `MainLoop` also handles the interruption signal SIGINT for you
    (e.g. the one a program receives when you press Ctrl + C). It notes
    this event in the log, and at the next iteration or epoch end the
    main loop finishes gracefully, calling all necessary extension
    callbacks and waiting until they finish.

    Parameters
    ----------
    algorithm : instance of :class:`~blocks.algorithms.TrainingAlgorithm`
        The training algorithm.
    data_stream : instance of :class:`.DataStream`.
        The data stream. Should support :class:`AbstractDataStream`
        interface from Fuel.
    model : instance of :class:`.ComputationGraph`, optional
        An annotated computation graph, typically represented
        by :class:`ComputationGraph` or :class:`Model` object. The main
        loop object uses the model only for optional sanity checks, it is
        here mainly for the main loop extensions.
    log : instance of :class:`.TrainingLog`, optional
        The log. When not given, a :class:`.TrainingLog` is created.
    log_backend : str
        The backend to use for the log. Currently `python` and `sqlite` are
        available. If not given, `config.log_backend` will be used. Ignored
        if `log` is passed.
    extensions : list of :class:`.TrainingExtension` instances
        The training extensions. Will be called in the same order as given
        here.

    """
    def __init__(self, algorithm, data_stream, model=None, log=None,
                 log_backend=None, extensions=None):
        if log is None:
            if log_backend is None:
                log_backend = config.log_backend
            log = BACKENDS[log_backend]()
        if extensions is None:
            extensions = []

        self.data_stream = data_stream
        self.algorithm = algorithm
        self.log = log
        self.extensions = extensions

        self.profile = Profile()

        self._model = model

        self.status['training_started'] = False
        self.status['epoch_started'] = False
        self.status['epoch_interrupt_received'] = False
        self.status['batch_interrupt_received'] = False

    @property
    def model(self):
        if not self._model:
            raise AttributeError("no model in this main loop" +
                                 no_model_message)
        return self._model

    @property
    def iteration_state(self):
        """Quick access to the (data stream, epoch iterator) pair."""
        return (self.data_stream, self.epoch_iterator)

    @iteration_state.setter
    def iteration_state(self, value):
        (self.data_stream, self.epoch_iterator) = value

    @property
    def status(self):
        """A shortcut for `self.log.status`."""
        return self.log.status

    def run(self):
        """Starts the main loop.

        The main loop ends when a training extension makes
        a `training_finish_requested` record in the log.

        """
        # This should do nothing if the user has already configured
        # logging, and will at least enable error messages otherwise.
        logging.basicConfig()

        # If this is resumption from a checkpoint, it is crucial to
        # reset `profile.current`. Otherwise, it simply does not hurt.
        self.profile.current = []

        # Sanity check for the most common case
        if (self._model and isinstance(self._model, Model) and
                isinstance(self.algorithm, DifferentiableCostMinimizer)):
            if not (set(self._model.get_parameter_dict().values()) ==
                    set(self.algorithm.parameters)):
                logger.warning("different parameters for model and algorithm")

        with change_recursion_limit(config.recursion_limit):
            self.original_sigint_handler = signal.signal(
                signal.SIGINT, self._handle_epoch_interrupt)
            self.original_sigterm_handler = signal.signal(
                signal.SIGTERM, self._handle_batch_interrupt)
            try:
                logger.info("Entered the main loop")
                if not self.status['training_started']:
                    for extension in self.extensions:
                        extension.main_loop = self
                    self._run_extensions('before_training')
                    with Timer('initialization', self.profile):
                        self.algorithm.initialize()
                    self.status['training_started'] = True
                # We can not write "else:" here because extensions
                # called "before_training" could have changed the status
                # of the main loop.
                if self.log.status['iterations_done'] > 0:
                    self.log.resume()
                    self._run_extensions('on_resumption')
                    self.status['epoch_interrupt_received'] = False
                    self.status['batch_interrupt_received'] = False
                with Timer('training', self.profile):
                    while self._run_epoch():
                        pass
            except TrainingFinish:
                self.log.current_row['training_finished'] = True
            except Exception as e:
                self._restore_signal_handlers()
                self.log.current_row['got_exception'] = traceback.format_exc()
                logger.error("Error occured during training." + error_message)
                try:
                    self._run_extensions('on_error')
                except Exception:
                    logger.error(traceback.format_exc())
                    logger.error("Error occured when running extensions." +
                                 error_in_error_handling_message)
                reraise_as(e)
            finally:
                self._restore_signal_handlers()
                if self.log.current_row.get('training_finished', False):
                    self._run_extensions('after_training')
                if config.profile:
                    self.profile.report()

    def find_extension(self, name):
        """Find an extension with a given name.

        Parameters
        ----------
        name : str
            The name of the extension looked for.

        Notes
        -----
        Raises an error if no extension or more than one extension with
        the given name is found.

        """
        return unpack([extension for extension in self.extensions
                       if extension.name == name], singleton=True)

    def _run_epoch(self):
        if not self.status.get('epoch_started', False):
            try:
                self.log.status['received_first_batch'] = False
                self.epoch_iterator = (self.data_stream.
                                       get_epoch_iterator(as_dict=True))
            except StopIteration:
                return False
            self.status['epoch_started'] = True
            self._run_extensions('before_epoch')
        with Timer('epoch', self.profile):
            while self._run_iteration():
                pass
        self.status['epoch_started'] = False
        self.status['epochs_done'] += 1
        # Log might not allow mutating objects, so use += instead of append
        self.status['_epoch_ends'] += [self.status['iterations_done']]
        self._run_extensions('after_epoch')
        self._check_finish_training('epoch')
        return True

    def _run_iteration(self):
        try:
            with Timer('read_data', self.profile):
                batch = next(self.epoch_iterator)
        except StopIteration:
            if not self.log.status['received_first_batch']:
                reraise_as(ValueError("epoch iterator yielded zero batches"))
            return False
        self.log.status['received_first_batch'] = True
        self._run_extensions('before_batch', batch)
        with Timer('train', self.profile):
            self.algorithm.process_batch(batch)
        self.status['iterations_done'] += 1
        self._run_extensions('after_batch', batch)
        self._check_finish_training('batch')
        return True

    def _run_extensions(self, method_name, *args):
        with Timer(method_name, self.profile):
            for extension in self.extensions:
                with Timer(type(extension).__name__, self.profile):
                    extension.dispatch(CallbackName(method_name), *args)

    def _check_finish_training(self, level):
        """Checks whether the current training should be terminated.

        Parameters
        ----------
        level : {'epoch', 'batch'}
            The level at which this check was performed. In some cases, we
            only want to quit after completing the remainder of the epoch.

        """
        # In case when keyboard interrupt is handled right at the end of
        # the iteration the corresponding log record can be found only in
        # the previous row.
        if (self.log.current_row.get('training_finish_requested', False) or
                self.status.get('batch_interrupt_received', False)):
            raise TrainingFinish
        if (level == 'epoch' and
                self.status.get('epoch_interrupt_received', False)):
            raise TrainingFinish

    def _handle_epoch_interrupt(self, signal_number, frame):
        # Try to complete the current epoch if user presses CTRL + C
        logger.warning('Received epoch interrupt signal.' +
                       epoch_interrupt_message)
        signal.signal(signal.SIGINT, self._handle_batch_interrupt)
        self.log.current_row['epoch_interrupt_received'] = True
        # Add a record to the status. Unlike the log record it will be
        # easy to access at later iterations.
        self.status['epoch_interrupt_received'] = True

    def _handle_batch_interrupt(self, signal_number, frame):
        # After 2nd CTRL + C or SIGTERM signal (from cluster) finish batch
        self._restore_signal_handlers()
        logger.warning('Received batch interrupt signal.' +
                       batch_interrupt_message)
        self.log.current_row['batch_interrupt_received'] = True
        # Add a record to the status. Unlike the log record it will be
        # easy to access at later iterations.
        self.status['batch_interrupt_received'] = True

    def _restore_signal_handlers(self):
        signal.signal(signal.SIGINT, self.original_sigint_handler)
        signal.signal(signal.SIGTERM, self.original_sigterm_handler)
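
# A hedged sketch of the extension protocol described in the docstring
# above: an extension requests a graceful stop by writing
# `training_finish_requested=True` into the current log row, which
# _check_finish_training picks up after the batch or epoch. The record
# name 'cost' is an assumption.
import math

from blocks.extensions import SimpleExtension


class StopOnNaN(SimpleExtension):
    """Request a graceful stop when a monitored record becomes NaN."""
    def __init__(self, record_name='cost', **kwargs):
        kwargs.setdefault('after_batch', True)
        super(StopOnNaN, self).__init__(**kwargs)
        self.record_name = record_name

    def do(self, which_callback, *args):
        value = self.main_loop.log.current_row.get(self.record_name)
        if value is not None and math.isnan(value):
            self.main_loop.log.current_row[
                'training_finish_requested'] = True
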
class MainLoopWithMultiCGnoBlocks(object):
    """
    Standalone MainLoop in order to handle multi CG without blocks.
    --------------------------------------------------------
    Main-loop represents the training loop and handles necessary actions
    before, during and after the training. Given a model, algorithm and
    data_stream, main loop feeds the algorithm with the batches from
    data_stream and updates the model. Main loop calls the defined extensions
    when they are needed. Also keeps a log which contains all the status.
    """

    def __init__(self, models, algorithm, data_stream,
                 num_encs=1, num_decs=1, log=None, extensions=None):
        """
        models : dict, mapping cg_name to blocks.model
        algorithm : SGDMultiCG
        data_stream : data stream, either MultiSourceStream or MultiEncStream
        num_encs : int, number of encoders
        num_decs : int, number of decoders
        log : blocks.log, the logger object
        extensions : blocks.extensions, the main loop extensions
        """
        self.models = models
        self.num_encs = num_encs
        self.num_decs = num_decs
        self.num_cgs = len(models)

        if log is None:
            log = TrainingLog()
        if extensions is None:
            extensions = []

        self.data_stream = data_stream
        self.algorithm = algorithm
        self.log = log
        self.extensions = extensions

        self.profile = Profile()

        self.status['training_started'] = False
        self.status['epoch_started'] = False
        self.status['epoch_interrupt_received'] = False
        self.status['batch_interrupt_received'] = False

    @property
    def iteration_state(self):
        """Quick access to the (data stream, epoch iterator) pair."""
        return (self.data_stream, self.epoch_iterator)

    @iteration_state.setter
    def iteration_state(self, value):
        (self.data_stream, self.epoch_iterator) = value

    @property
    def status(self):
        """A shortcut for `self.log.status`."""
        return self.log.status

    def run(self):
        logging.basicConfig()

        with change_recursion_limit(cfg.recursion_limit):
            self.original_sigint_handler = signal.signal(
                signal.SIGINT, self._handle_epoch_interrupt)
            self.original_sigterm_handler = signal.signal(
                signal.SIGTERM, self._handle_batch_interrupt)
            try:
                logger.info("Entered the main loop")
                if not self.status['training_started']:
                    for extension in self.extensions:
                        extension.main_loop = self
                    self._run_extensions('before_training')
                    with Timer('initialization', self.profile):
                        self.algorithm.initialize()
                    self.status['training_started'] = True
                if self.log.status['iterations_done'] > 0:
                    self._run_extensions('on_resumption')
                    self.status['epoch_interrupt_received'] = False
                    self.status['batch_interrupt_received'] = False
                with Timer('training', self.profile):
                    while self._run_epoch():
                        pass
            except TrainingFinish:
                self.log.current_row['training_finished'] = True
            except Exception as e:
                self._restore_signal_handlers()
                self.log.current_row['got_exception'] = traceback.format_exc()
                logger.error("Error occurred during training." + error_message)
                try:
                    self._run_extensions('on_error')
                except Exception:
                    logger.error(traceback.format_exc())
                    logger.error("Error occurred when running extensions." +
                                 error_in_error_handling_message)
                reraise_as(e)
            finally:
                if self.log.current_row.get('training_finished', False):
                    self._run_extensions('after_training')
                if cfg.profile:
                    self.profile.report()
                self._restore_signal_handlers()

    def find_extension(self, name):
        """Find an extension with a given name."""
        return unpack([extension for extension in self.extensions
                       if extension.name == name], singleton=True)

    def _run_epoch(self):
        if not self.status.get('epoch_started', False):
            try:
                self.log.status['received_first_batch'] = False
                self.epoch_iterator = (self.data_stream.
                                       get_epoch_iterator(as_dict=True))
            except StopIteration:
                return False
            self.status['epoch_started'] = True
            self._run_extensions('before_epoch')
        with Timer('epoch', self.profile):
            while self._run_iteration():
                pass
        self.status['epoch_started'] = False
        self.status['epochs_done'] += 1
        self.status['_epoch_ends'].append(self.status['iterations_done'])
        self._run_extensions('after_epoch')
        self._check_finish_training('epoch')
        return True

    def _run_iteration(self):
        try:
            with Timer('read_data', self.profile):
                batch = next(self.epoch_iterator)
        except StopIteration:
            if not self.log.status['received_first_batch']:
                reraise_as(ValueError("epoch iterator yielded zero batches"))
            return False
        self.log.status['received_first_batch'] = True
        self._run_extensions('before_batch', batch)
        with Timer('train', self.profile):
            self.algorithm.process_batch(batch)
        self.status['iterations_done'] += 1
        self._run_extensions('after_batch', batch)
        self._check_finish_training('batch')
        return True

    def _run_extensions(self, method_name, *args):
        with Timer(method_name, self.profile):
            for extension in self.extensions:
                with Timer(type(extension).__name__, self.profile):
                    extension.dispatch(CallbackName(method_name), *args)

    def _check_finish_training(self, level):
        """Checks whether the current training should be terminated."""
        # In case when keyboard interrupt is handled right at the end of
        # the iteration the corresponding log record can be found only in
        # the previous row.
        if (self.log.current_row.get('training_finish_requested', False) or
                self.status.get('batch_interrupt_received', False)):
            raise TrainingFinish
        if (level == 'epoch' and
                self.status.get('epoch_interrupt_received', False)):
            raise TrainingFinish

    def _handle_epoch_interrupt(self, signal_number, frame):
        # Try to complete the current epoch if user presses CTRL + C
        logger.warning('Received epoch interrupt signal.' +
                       epoch_interrupt_message)
        signal.signal(signal.SIGINT, self._handle_batch_interrupt)
        self.log.current_row['epoch_interrupt_received'] = True
        # Add a record to the status. Unlike the log record it will be
        # easy to access at later iterations.
        self.status['epoch_interrupt_received'] = True

    def _handle_batch_interrupt(self, signal_number, frame):
        # After 2nd CTRL + C or SIGTERM signal (from cluster) finish batch
        self._restore_signal_handlers()
        logger.warning('Received batch interrupt signal.' +
                       batch_interrupt_message)
        self.log.current_row['batch_interrupt_received'] = True
        # Add a record to the status. Unlike the log record it will be
        # easy to access at later iterations.
        self.status['batch_interrupt_received'] = True

    def _restore_signal_handlers(self):
        signal.signal(signal.SIGINT, self.original_sigint_handler)
        signal.signal(signal.SIGTERM, self.original_sigterm_handler)
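
# A hedged construction sketch for MainLoopWithMultiCGnoBlocks above.
# `SGDMultiCG`, `MultiEncStream` and the model objects are assumptions
# taken from the docstring, not verified APIs.
models = {'cg_0': model_enc0_dec0,   # one entry per computation graph
          'cg_1': model_enc1_dec0}
main_loop = MainLoopWithMultiCGnoBlocks(
    models=models,
    algorithm=sgd_multi_cg,          # an SGDMultiCG instance
    data_stream=multi_enc_stream,    # MultiSourceStream or MultiEncStream
    num_encs=2, num_decs=1,
    extensions=extensions)
main_loop.run()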