class MainLoop(object):
    """The standard main loop of Blocks.

    In the `MainLoop` a model is trained by a training algorithm using data
    extracted from a data stream. This process is scrupulously documented
    in a log object.

    The `MainLoop` itself does very little: only fetching the data from the
    data stream and feeding it to the algorithm. It expects the extensions
    to do most of the job. A respective callback of every extension is
    called at every stage of training. The extensions should communicate
    between themselves and with the main loop object by means of making
    records in the log. For instance in order to stop the training
    procedure an extension can make a record
    `training_finish_requested=True` in the log. The main loop checks for
    such a record after every batch and every epoch and terminates when it
    finds it.

    The `MainLoop` also handles interruption signals for you. The first
    SIGINT (e.g. the one a program receives when you press Ctrl + C) is
    noted in the log and the main loop finishes gracefully at the end of
    the current epoch. A second SIGINT, or a SIGTERM, makes it finish at
    the end of the current batch instead. In both cases the
    `after_training` extension callbacks are still called.

    Parameters
    ----------
    algorithm : object
        The training algorithm.
    data_stream : instance of :class:`.DataStream`.
        The data stream.
    model : :class:`.AbstractModel` instance, optional
        The model object. The main loop only uses it for an optional
        sanity check against a :class:`DifferentiableCostMinimizer`
        algorithm; it is mainly here for the extensions.
    log : instance of :class:`.TrainingLog`, optional
        The log. When not given, a :class:`.TrainingLog` is created.
    extensions : list of :class:`.TrainingExtension` instances
        The training extensions. Will be called in the same order as given
        here.

    Attributes
    ----------
    profile : :class:`.Profile`
        Keeps track of the times spent in different segments of the
        training loop.

    """
    def __init__(self, algorithm, data_stream,
                 model=None, log=None, extensions=None):
        if log is None:
            log = TrainingLog()
        if extensions is None:
            extensions = []

        self.data_stream = data_stream
        self.algorithm = algorithm
        self.log = log
        self.extensions = extensions

        self.profile = Profile()

        self._model = model

        self.status['training_started'] = False
        self.status['epoch_started'] = False
        self.status['epoch_interrupt_received'] = False
        self.status['batch_interrupt_received'] = False

    @property
    def model(self):
        if not self._model:
            raise AttributeError("no model in this main loop" +
                                 no_model_message)
        return self._model

    @property
    def iteration_state(self):
        """Quick access to the (data stream, epoch iterator) pair."""
        return (self.data_stream, self.epoch_iterator)

    @iteration_state.setter
    def iteration_state(self, value):
        (self.data_stream, self.epoch_iterator) = value

    @property
    def status(self):
        """A shortcut for `self.log.status`."""
        return self.log.status

    def run(self):
        """Starts the main loop.

        The main loop ends when a training extension makes
        a `training_finish_requested` record in the log, or after an
        interrupt signal has been received. Any other exception is
        recorded in the log, the `on_error` callbacks are run and the
        exception is re-raised.

        """
        # This should do nothing if the user has already configured
        # logging, and will at least enable error messages otherwise.
        logging.basicConfig()
        # If this is resumption from a checkpoint, it is crucial to
        # reset `profile.current`. Otherwise, it simply does not hurt.
        self.profile.current = []

        if self._model and isinstance(self.algorithm,
                                      DifferentiableCostMinimizer):
            # Sanity check: model and algorithm should be configured
            # similarly.
            if not self._model.get_objective() == self.algorithm.cost:
                logger.warning("different costs for model and algorithm")
            if not (set(self._model.get_params().values()) ==
                    set(self.algorithm.params)):
                logger.warning("different params for model and algorithm")

        with change_recursion_limit(config.recursion_limit):
            self.original_sigint_handler = signal.signal(
                signal.SIGINT, self._handle_epoch_interrupt)
            self.original_sigterm_handler = signal.signal(
                signal.SIGTERM, self._handle_batch_interrupt)
            try:
                logger.info("Entered the main loop")
                if not self.status['training_started']:
                    for extension in self.extensions:
                        extension.main_loop = self
                    self._run_extensions('before_training')
                    with Timer('initialization', self.profile):
                        self.algorithm.initialize()
                    self.status['training_started'] = True
                # We can not write "else:" here because extensions
                # called "before_training" could have changed the status
                # of the main loop.
                if self.log.status['iterations_done'] > 0:
                    self._run_extensions('on_resumption')
                    self.status['epoch_interrupt_received'] = False
                    self.status['batch_interrupt_received'] = False
                with Timer('training', self.profile):
                    while self._run_epoch():
                        pass
            except TrainingFinish:
                self.log.current_row['training_finished'] = True
            except Exception as e:
                self._restore_signal_handlers()
                # `format_exc()` formats the exception currently being
                # handled. Its only positional parameter is `limit`, so
                # passing the exception object there raises TypeError on
                # Python 3 and would hide the original error.
                self.log.current_row['got_exception'] = traceback.format_exc()
                logger.error("Error occured during training." + error_message)
                try:
                    self._run_extensions('on_error')
                except Exception:
                    logger.error(traceback.format_exc())
                    logger.error("Error occured when running extensions." +
                                 error_in_error_handling_message)
                reraise_as(e)
            finally:
                # Restore the handlers first: this guarantees they are
                # restored even if an `after_training` callback raises, and
                # lets the user interrupt a slow `after_training` callback.
                self._restore_signal_handlers()
                if self.log.current_row.get('training_finished', False):
                    self._run_extensions('after_training')
                if config.profile:
                    self.profile.report()

    def find_extension(self, name):
        """Find an extension with a given name.

        Parameters
        ----------
        name : str
            The name of the extension looked for.

        Notes
        -----
        Will crash if no extension or several extensions are found.

        """
        return unpack([extension for extension in self.extensions
                       if extension.name == name],
                      singleton=True)

    def _run_epoch(self):
        """Run (or continue) one epoch.

        Returns ``False`` if the data stream could not provide a new epoch
        iterator, ``True`` after a completed epoch.

        """
        if not self.status.get('epoch_started', False):
            try:
                self.log.status['received_first_batch'] = False
                self.epoch_iterator = (self.data_stream.
                                       get_epoch_iterator(as_dict=True))
            except StopIteration:
                return False
            self.status['epoch_started'] = True
            self._run_extensions('before_epoch')
        with Timer('epoch', self.profile):
            while self._run_iteration():
                pass
        self.status['epoch_started'] = False
        self.status['epochs_done'] += 1
        # Log might not allow mutating objects, so use += instead of append
        self.status['_epoch_ends'] += [self.status['iterations_done']]
        self._run_extensions('after_epoch')
        self._check_finish_training('epoch')
        return True

    def _run_iteration(self):
        """Process one batch; return ``False`` when the epoch is over."""
        try:
            with Timer('read_data', self.profile):
                batch = next(self.epoch_iterator)
        except StopIteration:
            if not self.log.status['received_first_batch']:
                reraise_as(ValueError("epoch iterator yielded zero batches"))
            return False
        self.log.status['received_first_batch'] = True
        self._run_extensions('before_batch', batch)
        with Timer('train', self.profile):
            self.algorithm.process_batch(batch)
        self.status['iterations_done'] += 1
        self._run_extensions('after_batch', batch)
        self._check_finish_training('batch')
        return True

    def _run_extensions(self, method_name, *args):
        """Dispatch callback `method_name` to every extension, in order."""
        with Timer(method_name, self.profile):
            for extension in self.extensions:
                with Timer(type(extension).__name__, self.profile):
                    extension.dispatch(CallbackName(method_name), *args)

    def _check_finish_training(self, level):
        """Checks whether the current training should be terminated.

        Parameters
        ----------
        level : {'epoch', 'batch'}
            The level at which this check was performed. In some cases, we
            only want to quit after completing the remainder of the epoch.

        Raises
        ------
        TrainingFinish
            If training should stop now.

        """
        # Interrupt flags are read from `status` rather than from the
        # current log row: the row written by a signal handler may no
        # longer be the current one by the time this check runs.
        if (self.log.current_row.get('training_finish_requested', False) or
                self.status.get('batch_interrupt_received', False)):
            raise TrainingFinish
        if (level == 'epoch' and
                self.status.get('epoch_interrupt_received', False)):
            raise TrainingFinish

    def _handle_epoch_interrupt(self, signal_number, frame):
        # Try to complete the current epoch if user presses CTRL + C
        logger.warning('Received epoch interrupt signal.' +
                       epoch_interrupt_message)
        # A second SIGINT will request finishing after the current batch.
        signal.signal(signal.SIGINT, self._handle_batch_interrupt)
        self.log.current_row['epoch_interrupt_received'] = True
        # Add a record to the status. Unlike the log record it will be
        # easy to access at later iterations.
        self.status['epoch_interrupt_received'] = True

    def _handle_batch_interrupt(self, signal_number, frame):
        # After 2nd CTRL + C or SIGTERM signal (from cluster) finish batch
        self._restore_signal_handlers()
        logger.warning('Received batch interrupt signal.' +
                       batch_interrupt_message)
        self.log.current_row['batch_interrupt_received'] = True
        # Add a record to the status. Unlike the log record it will be
        # easy to access at later iterations.
        self.status['batch_interrupt_received'] = True

    def _restore_signal_handlers(self):
        """Reinstall the SIGINT/SIGTERM handlers saved by `run`."""
        signal.signal(signal.SIGINT, self.original_sigint_handler)
        signal.signal(signal.SIGTERM, self.original_sigterm_handler)
class GANMainLoop(object):
    """The standard main loop for GAN.

    Every iteration makes ``k`` discriminator steps followed by a single
    generator step. For each discriminator step a batch of real data is
    read from the data stream, ``minibatches`` samples are generated from
    uniform noise, and both are stacked into one batch for ``algorithm_d``.
    The generator step feeds fresh noise to ``algorithm_g``.

    Parameters
    ----------
    algorithm_g : instance of :class:`~blocks.algorithms.TrainingAlgorithm`
        The training algorithm for the generator. It receives batches of
        the form ``{'noise': noise}``. Also exposed as ``self.algorithm``.
    g_out : object
        Stored as ``self.g_out``; not used by the main loop itself.
    algorithm_d : instance of :class:`~blocks.algorithms.TrainingAlgorithm`
        The training algorithm for the discriminator. It receives batches
        of the form ``{'features': stacked}`` whose first ``minibatches``
        rows are generated samples and whose remaining rows are real data.
    d_out : object
        Stored as ``self.d_out``; not used by the main loop itself.
    data_stream : instance of :class:`.DataStream`.
        The data stream. Should support :class:`AbstractDataStream`
        interface from Fuel. Its batches must contain a ``'features'``
        entry, indexed here as a 2-D ``(samples, dims)`` array.
    false_generated, false_dataset : objects with ``set_value``/``get_value``
        Discriminator error counters. They are reset to zero at the start
        of every iteration; after the discriminator steps their values,
        divided by ``k * minibatches``, are logged as
        ``error_on_generated`` and ``error_on_dataset``. Presumably they
        are updated by ``algorithm_d`` -- confirm against the caller.
    generator_errors : object with ``set_value``/``get_value``
        Same as above for the generator step; logged as
        ``generator_errors`` after division by ``minibatches``.
    generator : callable, optional
        Called as ``generator(noise)``; the first element of the result is
        used as the generated batch. Required for training.
    discriminator : object, optional
        Stored for extensions; not used by the main loop itself.
    noise_per_sample : int
        Size of the uniform noise vector drawn per generated sample.
    k : int
        Number of discriminator steps per generator step. Must be at
        least 1.
    minibatches : int
        Number of samples generated per step. Stacking generated and real
        data assumes the data stream yields batches of exactly this size.
    log : instance of :class:`.TrainingLog`, optional
        The log. When not given, a log is created with the backend
        selected by `log_backend`.
    log_backend : str
        The backend to use for the log. Currently `python` and `sqlite` are
        available. If not given, `config.log_backend` will be used. Ignored
        if `log` is passed.
    extensions : list of :class:`.TrainingExtension` instances
        The training extensions. Will be called in the same order as given
        here.
    observables : list, optional
        Objects with a ``name`` attribute and a ``get_value()`` method;
        their values are written into the log after every iteration.

    Raises
    ------
    ValueError
        If ``k`` is smaller than 1.

    """
    def __init__(self, algorithm_g, g_out, algorithm_d, d_out, data_stream,
                 false_generated, false_dataset, generator_errors,
                 generator=None, discriminator=None, noise_per_sample=10,
                 k=1, minibatches=1, log=None, log_backend=None,
                 extensions=None, observables=None):
        # With k < 1 no discriminator step runs and `_run_iteration` later
        # fails on an undefined batch; reject such values up front.
        if k < 1:
            raise ValueError("k must be at least 1, got {}".format(k))
        if log is None:
            if log_backend is None:
                log_backend = config.log_backend
            log = BACKENDS[log_backend]()
        if extensions is None:
            extensions = []
        # Use a fresh list per instance instead of a shared mutable default.
        if observables is None:
            observables = []

        self.data_stream = data_stream
        # Alias kept so that code expecting `main_loop.algorithm` still
        # finds an algorithm (the generator's).
        self.algorithm = algorithm_g
        self.algorithm_g = algorithm_g
        self.algorithm_d = algorithm_d
        self.log = log
        self.extensions = extensions
        self.g_out = g_out
        self.d_out = d_out
        self.k = k
        self.minibatches = minibatches
        self.noise_per_sample = noise_per_sample
        self.false_generated = false_generated
        self.false_dataset = false_dataset
        self.generator_errors = generator_errors
        self.observables = observables
        self.first_batch = None

        self.profile = Profile()

        self._generator = generator
        self._discriminator = discriminator

        self.status['training_started'] = False
        self.status['epoch_started'] = False
        self.status['epoch_interrupt_received'] = False
        self.status['batch_interrupt_received'] = False

    @property
    def generator(self):
        if not self._generator:
            raise AttributeError("no generator in this main loop" +
                                 no_model_message)
        return self._generator

    @property
    def discriminator(self):
        if not self._discriminator:
            raise AttributeError("no discriminator in this main loop" +
                                 no_model_message)
        return self._discriminator

    @property
    def iteration_state(self):
        """Quick access to the (data stream, epoch iterator) pair."""
        return (self.data_stream, self.epoch_iterator)

    @iteration_state.setter
    def iteration_state(self, value):
        (self.data_stream, self.epoch_iterator) = value

    @property
    def status(self):
        """A shortcut for `self.log.status`."""
        return self.log.status

    def run(self):
        """Starts the main loop.

        The main loop ends when a training extension makes
        a `training_finish_requested` record in the log, or after an
        interrupt signal has been received. Any other exception is
        recorded in the log, the `on_error` callbacks are run and the
        exception is re-raised.

        """
        # This should do nothing if the user has already configured
        # logging, and will at least enable error messages otherwise.
        logging.basicConfig()

        # If this is resumption from a checkpoint, it is crucial to
        # reset `profile.current`. Otherwise, it simply does not hurt.
        self.profile.current = []

        with change_recursion_limit(config.recursion_limit):
            self.original_sigint_handler = signal.signal(
                signal.SIGINT, self._handle_epoch_interrupt)
            self.original_sigterm_handler = signal.signal(
                signal.SIGTERM, self._handle_batch_interrupt)
            try:
                logger.info("Entered the main loop")
                if not self.status['training_started']:
                    for extension in self.extensions:
                        extension.main_loop = self
                    self._run_extensions('before_training')
                    with Timer('initialization', self.profile):
                        self.algorithm_g.initialize()
                        self.algorithm_d.initialize()
                    self.status['training_started'] = True
                # We can not write "else:" here because extensions
                # called "before_training" could have changed the status
                # of the main loop.
                if self.log.status['iterations_done'] > 0:
                    self.log.resume()
                    self._run_extensions('on_resumption')
                    self.status['epoch_interrupt_received'] = False
                    self.status['batch_interrupt_received'] = False
                with Timer('training', self.profile):
                    while self._run_epoch():
                        pass
            except TrainingFinish:
                self.log.current_row['training_finished'] = True
            except Exception as e:
                self._restore_signal_handlers()
                self.log.current_row['got_exception'] = traceback.format_exc()
                logger.error("Error occured during training." + error_message)
                try:
                    self._run_extensions('on_error')
                except Exception:
                    logger.error(traceback.format_exc())
                    logger.error("Error occured when running extensions." +
                                 error_in_error_handling_message)
                reraise_as(e)
            finally:
                self._restore_signal_handlers()
                if self.log.current_row.get('training_finished', False):
                    self._run_extensions('after_training')
                if config.profile:
                    self.profile.report()

    def find_extension(self, name):
        """Find an extension with a given name.

        Parameters
        ----------
        name : str
            The name of the extension looked for.

        Notes
        -----
        Will crash if no extension or several extensions are found.

        """
        return unpack([extension for extension in self.extensions
                       if extension.name == name],
                      singleton=True)

    def _run_epoch(self):
        """Run (or continue) one epoch.

        Returns ``False`` if the data stream could not provide a new epoch
        iterator, ``True`` after a completed epoch.

        """
        if not self.status.get('epoch_started', False):
            try:
                self.log.status['received_first_batch'] = False
                self.epoch_iterator = (self.data_stream.
                                       get_epoch_iterator(as_dict=True))
            except StopIteration:
                return False
            self.status['epoch_started'] = True
            self._run_extensions('before_epoch')
        with Timer('epoch', self.profile):
            while self._run_iteration():
                pass
        self.status['epoch_started'] = False
        self.status['epochs_done'] += 1
        # Log might not allow mutating objects, so use += instead of append
        self.status['_epoch_ends'] += [self.status['iterations_done']]
        self._run_extensions('after_epoch')
        self._check_finish_training('epoch')
        return True

    def _run_iteration(self):
        """Run ``k`` discriminator steps followed by one generator step.

        Returns ``False`` when the epoch iterator is exhausted (discriminator
        steps already made in this iteration are then not followed by a
        generator step), ``True`` otherwise.

        """
        ministeps_made = 0
        # Reset the discriminator error counters; they are read back after
        # the `k` discriminator steps below.
        self.false_generated.set_value(0.)
        self.false_dataset.set_value(0.)
        while ministeps_made < self.k:
            try:
                with Timer('read_data', self.profile):
                    batch = next(self.epoch_iterator)
            except StopIteration:
                if not self.log.status['received_first_batch']:
                    reraise_as(ValueError("epoch iterator yielded zero batches"))
                return False
            self.log.status['received_first_batch'] = True
            self._run_extensions('before_batch', batch)
            batch = batch['features']
            noise = np.random.rand(self.minibatches,
                                   self.noise_per_sample).astype(np.float32)
            generated_batch = self._generator(noise)[0]
            # Generated samples go first, real data after them.
            # NOTE: the two assignments only line up when
            # batch.shape[0] == self.minibatches.
            bound_batch = np.zeros((batch.shape[0] * 2, batch.shape[1]),
                                   dtype=np.float32)
            bound_batch[:self.minibatches, :] = generated_batch
            bound_batch[self.minibatches:, :] = batch
            # Overwrites `generated.npy` in the working directory on every
            # discriminator step.
            np.save('generated.npy', generated_batch)
            bound_batch = {'features': bound_batch}
            with Timer('train', self.profile):
                self.algorithm_d.process_batch(bound_batch)
            ministeps_made += 1

        # Rows are addressed by the iteration number this step completes,
        # since `iterations_done` is only incremented further below.
        row = self.status['iterations_done'] + 1
        false_generated_perc = (self.false_generated.get_value() /
                                (self.k * self.minibatches))
        false_dataset_perc = (self.false_dataset.get_value() /
                              (self.k * self.minibatches))
        self.log[row]['error_on_generated'] = false_generated_perc
        self.log[row]['error_on_dataset'] = false_dataset_perc

        self.generator_errors.set_value(0.)
        noise = np.random.rand(self.minibatches,
                               self.noise_per_sample).astype(np.float32)
        noise_batch = {'noise': noise}
        with Timer('train', self.profile):
            self.algorithm_g.process_batch(noise_batch)
        gen_errors = self.generator_errors.get_value() / self.minibatches
        self.log[row]['generator_errors'] = gen_errors
        for observable in self.observables:
            self.log[row][observable.name] = observable.get_value()

        self.status['iterations_done'] += 1
        self._run_extensions('after_batch', bound_batch)
        self._check_finish_training('batch')
        return True

    def _run_extensions(self, method_name, *args):
        """Dispatch callback `method_name` to every extension, in order."""
        with Timer(method_name, self.profile):
            for extension in self.extensions:
                with Timer(type(extension).__name__, self.profile):
                    extension.dispatch(CallbackName(method_name), *args)

    def _check_finish_training(self, level):
        """Checks whether the current training should be terminated.

        Parameters
        ----------
        level : {'epoch', 'batch'}
            The level at which this check was performed. In some cases, we
            only want to quit after completing the remainder of the epoch.

        Raises
        ------
        TrainingFinish
            If training should stop now.

        """
        # Interrupt flags are read from `status` rather than from the
        # current log row: the row written by a signal handler may no
        # longer be the current one by the time this check runs.
        if (self.log.current_row.get('training_finish_requested', False) or
                self.status.get('batch_interrupt_received', False)):
            raise TrainingFinish
        if (level == 'epoch' and
                self.status.get('epoch_interrupt_received', False)):
            raise TrainingFinish

    def _handle_epoch_interrupt(self, signal_number, frame):
        # Try to complete the current epoch if user presses CTRL + C
        logger.warning('Received epoch interrupt signal.' +
                       epoch_interrupt_message)
        # A second SIGINT will request finishing after the current batch.
        signal.signal(signal.SIGINT, self._handle_batch_interrupt)
        self.log.current_row['epoch_interrupt_received'] = True
        # Add a record to the status. Unlike the log record it will be
        # easy to access at later iterations.
        self.status['epoch_interrupt_received'] = True

    def _handle_batch_interrupt(self, signal_number, frame):
        # After 2nd CTRL + C or SIGTERM signal (from cluster) finish batch
        self._restore_signal_handlers()
        logger.warning('Received batch interrupt signal.' +
                       batch_interrupt_message)
        self.log.current_row['batch_interrupt_received'] = True
        # Add a record to the status. Unlike the log record it will be
        # easy to access at later iterations.
        self.status['batch_interrupt_received'] = True

    def _restore_signal_handlers(self):
        """Reinstall the SIGINT/SIGTERM handlers saved by `run`."""
        signal.signal(signal.SIGINT, self.original_sigint_handler)
        signal.signal(signal.SIGTERM, self.original_sigterm_handler)
class MainLoopWithMultiCGnoBlocks(object):
    """Standalone main loop that handles multiple computation graphs.

    The main loop represents the training loop and handles the necessary
    actions before, during and after training. Given models, an algorithm
    and a data stream, it feeds the algorithm with batches from the data
    stream, calls the extensions when they are needed and keeps all status
    in a log.

    The first SIGINT makes training finish at the end of the current
    epoch; a second SIGINT, or a SIGTERM, makes it finish at the end of
    the current batch.

    Parameters
    ----------
    models : dict
        Mapping from computation graph name to ``blocks.model``.
    algorithm : SGDMultiCG
        The training algorithm.
    data_stream : data stream
        Either a MultiSourceStream or a MultiEncStream.
    num_encs : int
        Number of encoders.
    num_decs : int
        Number of decoders.
    log : blocks log, optional
        The log. When not given, a :class:`TrainingLog` is created.
    extensions : list of main loop extensions, optional
        Called in the same order as given here.

    """
    def __init__(self, models, algorithm, data_stream,
                 num_encs=1, num_decs=1, log=None, extensions=None):
        self.models = models
        self.num_encs = num_encs
        self.num_decs = num_decs
        self.num_cgs = len(models)

        if log is None:
            log = TrainingLog()
        if extensions is None:
            extensions = []

        self.data_stream = data_stream
        self.algorithm = algorithm
        self.log = log
        self.extensions = extensions

        self.profile = Profile()

        self.status['training_started'] = False
        self.status['epoch_started'] = False
        self.status['epoch_interrupt_received'] = False
        self.status['batch_interrupt_received'] = False

    @property
    def iteration_state(self):
        """Quick access to the (data stream, epoch iterator) pair."""
        return (self.data_stream, self.epoch_iterator)

    @iteration_state.setter
    def iteration_state(self, value):
        (self.data_stream, self.epoch_iterator) = value

    @property
    def status(self):
        """A shortcut for `self.log.status`."""
        return self.log.status

    def run(self):
        """Starts the main loop.

        The main loop ends when a training extension makes
        a `training_finish_requested` record in the log, or after an
        interrupt signal has been received. Any other exception is
        recorded in the log, the `on_error` callbacks are run and the
        exception is re-raised.

        """
        # This should do nothing if the user has already configured
        # logging, and will at least enable error messages otherwise.
        logging.basicConfig()
        # If this is resumption from a checkpoint, reset `profile.current`.
        # Otherwise, it simply does not hurt.
        self.profile.current = []

        with change_recursion_limit(cfg.recursion_limit):
            self.original_sigint_handler = signal.signal(
                signal.SIGINT, self._handle_epoch_interrupt)
            self.original_sigterm_handler = signal.signal(
                signal.SIGTERM, self._handle_batch_interrupt)
            try:
                logger.info("Entered the main loop")
                if not self.status['training_started']:
                    for extension in self.extensions:
                        extension.main_loop = self
                    self._run_extensions('before_training')
                    with Timer('initialization', self.profile):
                        self.algorithm.initialize()
                    self.status['training_started'] = True
                # Not an "else:" branch: "before_training" extensions may
                # have changed the status of the main loop.
                if self.log.status['iterations_done'] > 0:
                    self._run_extensions('on_resumption')
                    self.status['epoch_interrupt_received'] = False
                    self.status['batch_interrupt_received'] = False
                with Timer('training', self.profile):
                    while self._run_epoch():
                        pass
            except TrainingFinish:
                self.log.current_row['training_finished'] = True
            except Exception as e:
                self._restore_signal_handlers()
                # `format_exc()` formats the exception being handled; its
                # only positional parameter is `limit`, so passing the
                # exception raises TypeError on Python 3 and hides `e`.
                self.log.current_row['got_exception'] = traceback.format_exc()
                logger.error("Error occured during training." + error_message)
                try:
                    self._run_extensions('on_error')
                except Exception:
                    logger.error(traceback.format_exc())
                    logger.error("Error occured when running extensions." +
                                 error_in_error_handling_message)
                reraise_as(e)
            finally:
                # Restore the handlers first so they are restored even if
                # an `after_training` callback raises, and so that a slow
                # `after_training` callback can be interrupted.
                self._restore_signal_handlers()
                if self.log.current_row.get('training_finished', False):
                    self._run_extensions('after_training')
                if cfg.profile:
                    self.profile.report()

    def find_extension(self, name):
        """Find an extension with a given name.

        Will crash if no extension or several extensions are found.

        """
        return unpack([extension for extension in self.extensions
                       if extension.name == name],
                      singleton=True)

    def _run_epoch(self):
        """Run (or continue) one epoch.

        Returns ``False`` if the data stream could not provide a new epoch
        iterator, ``True`` after a completed epoch.

        """
        if not self.status.get('epoch_started', False):
            try:
                self.log.status['received_first_batch'] = False
                self.epoch_iterator = (self.data_stream.
                                       get_epoch_iterator(as_dict=True))
            except StopIteration:
                return False
            self.status['epoch_started'] = True
            self._run_extensions('before_epoch')
        with Timer('epoch', self.profile):
            while self._run_iteration():
                pass
        self.status['epoch_started'] = False
        self.status['epochs_done'] += 1
        # Log might not allow mutating objects, so use += instead of append
        self.status['_epoch_ends'] += [self.status['iterations_done']]
        self._run_extensions('after_epoch')
        self._check_finish_training('epoch')
        return True

    def _run_iteration(self):
        """Process one batch; return ``False`` when the epoch is over."""
        try:
            with Timer('read_data', self.profile):
                batch = next(self.epoch_iterator)
        except StopIteration:
            if not self.log.status['received_first_batch']:
                reraise_as(ValueError("epoch iterator yielded zero batches"))
            return False
        self.log.status['received_first_batch'] = True
        self._run_extensions('before_batch', batch)
        with Timer('train', self.profile):
            self.algorithm.process_batch(batch)
        self.status['iterations_done'] += 1
        self._run_extensions('after_batch', batch)
        self._check_finish_training('batch')
        return True

    def _run_extensions(self, method_name, *args):
        """Dispatch callback `method_name` to every extension, in order."""
        with Timer(method_name, self.profile):
            for extension in self.extensions:
                with Timer(type(extension).__name__, self.profile):
                    extension.dispatch(CallbackName(method_name), *args)

    def _check_finish_training(self, level):
        """Checks whether the current training should be terminated.

        Parameters
        ----------
        level : {'epoch', 'batch'}
            The level at which this check was performed. Epoch-level
            interrupts only stop training at the end of an epoch.

        Raises
        ------
        TrainingFinish
            If training should stop now.

        """
        # Interrupt flags are read from `status` rather than from the
        # current log row: the row written by a signal handler may no
        # longer be the current one by the time this check runs.
        if (self.log.current_row.get('training_finish_requested', False) or
                self.status.get('batch_interrupt_received', False)):
            raise TrainingFinish
        if (level == 'epoch' and
                self.status.get('epoch_interrupt_received', False)):
            raise TrainingFinish

    def _handle_epoch_interrupt(self, signal_number, frame):
        # Try to complete the current epoch if user presses CTRL + C
        logger.warning('Received epoch interrupt signal.' +
                       epoch_interrupt_message)
        # A second SIGINT will request finishing after the current batch.
        signal.signal(signal.SIGINT, self._handle_batch_interrupt)
        self.log.current_row['epoch_interrupt_received'] = True
        # Add a record to the status. Unlike the log record it will be
        # easy to access at later iterations.
        self.status['epoch_interrupt_received'] = True

    def _handle_batch_interrupt(self, signal_number, frame):
        # After 2nd CTRL + C or SIGTERM signal (from cluster) finish batch
        self._restore_signal_handlers()
        logger.warning('Received batch interrupt signal.' +
                       batch_interrupt_message)
        self.log.current_row['batch_interrupt_received'] = True
        # Add a record to the status. Unlike the log record it will be
        # easy to access at later iterations.
        self.status['batch_interrupt_received'] = True

    def _restore_signal_handlers(self):
        """Reinstall the SIGINT/SIGTERM handlers saved by `run`."""
        signal.signal(signal.SIGINT, self.original_sigint_handler)
        signal.signal(signal.SIGTERM, self.original_sigterm_handler)
class MainLoop(object):
    """The standard main loop of Blocks.

    A `MainLoop` trains a model by repeatedly pulling batches from a data
    stream and handing them to a training algorithm, recording what
    happens in a log. The loop itself is deliberately thin: it only moves
    data from the stream to the algorithm and leaves the real work to the
    extensions, each of which has its callbacks invoked at every stage of
    training.

    Extensions communicate with each other and with the main loop through
    log records. For example, an extension stops training by writing
    `training_finish_requested=True` into the log; the main loop looks for
    this record after every batch and every epoch.

    Interrupt signals are handled gracefully. The first SIGINT (e.g. from
    pressing Ctrl + C) is recorded in the log and training stops at the end
    of the current epoch; a second SIGINT, or a SIGTERM, stops it at the
    end of the current batch. The `after_training` callbacks still run.

    Parameters
    ----------
    algorithm : instance of :class:`~blocks.algorithms.TrainingAlgorithm`
        The training algorithm.
    data_stream : instance of :class:`.DataStream`.
        The data stream. Should support :class:`AbstractDataStream`
        interface from Fuel.
    model : instance of :class:`.ComputationGraph`, optional
        An annotated computation graph, typically represented by
        :class:`ComputationGraph` or :class:`Model` object. The main loop
        uses it only for an optional sanity check; it is here mainly for
        the main loop extensions.
    log : instance of :class:`.TrainingLog`, optional
        The log. When not given, one is created with the backend selected
        by `log_backend`.
    log_backend : str
        The backend to use for the log. Currently `python` and `sqlite` are
        available. If not given, `config.log_backend` will be used. Ignored
        if `log` is passed.
    extensions : list of :class:`.TrainingExtension` instances
        The training extensions. Will be called in the same order as given
        here.

    """
    def __init__(self, algorithm, data_stream, model=None, log=None,
                 log_backend=None, extensions=None):
        if log is None:
            backend_name = (config.log_backend if log_backend is None
                            else log_backend)
            log = BACKENDS[backend_name]()

        self.data_stream = data_stream
        self.algorithm = algorithm
        self.log = log
        self.extensions = [] if extensions is None else extensions
        self.profile = Profile()
        self._model = model

        for flag in ('training_started', 'epoch_started',
                     'epoch_interrupt_received', 'batch_interrupt_received'):
            self.status[flag] = False

    @property
    def model(self):
        current_model = self._model
        if current_model:
            return current_model
        raise AttributeError("no model in this main loop" +
                             no_model_message)

    @property
    def iteration_state(self):
        """The ``(data_stream, epoch_iterator)`` pair, for quick access."""
        return self.data_stream, self.epoch_iterator

    @iteration_state.setter
    def iteration_state(self, value):
        self.data_stream, self.epoch_iterator = value

    @property
    def status(self):
        """Shortcut for ``self.log.status``."""
        return self.log.status

    def run(self):
        """Run training until it is finished or interrupted.

        Training stops when an extension writes a
        `training_finish_requested` record into the log, or when an
        interrupt signal is received. Any other exception is stored in the
        log, the `on_error` callbacks are run and the exception is
        re-raised.

        """
        # No-op when the user already configured logging; otherwise this
        # makes sure error messages are at least printed.
        logging.basicConfig()
        # Resetting `profile.current` is essential when resuming from a
        # checkpoint and harmless otherwise.
        self.profile.current = []

        # Sanity check for the most common setup.
        candidate = self._model
        if (candidate and isinstance(candidate, Model) and
                isinstance(self.algorithm, DifferentiableCostMinimizer)):
            model_parameters = set(candidate.get_parameter_dict().values())
            if model_parameters != set(self.algorithm.parameters):
                logger.warning("different parameters for model and algorithm")

        with change_recursion_limit(config.recursion_limit):
            self.original_sigint_handler = signal.signal(
                signal.SIGINT, self._handle_epoch_interrupt)
            self.original_sigterm_handler = signal.signal(
                signal.SIGTERM, self._handle_batch_interrupt)
            try:
                logger.info("Entered the main loop")
                if not self.status['training_started']:
                    for ext in self.extensions:
                        ext.main_loop = self
                    self._run_extensions('before_training')
                    with Timer('initialization', self.profile):
                        self.algorithm.initialize()
                    self.status['training_started'] = True
                # Deliberately not an "else:" branch: the "before_training"
                # callbacks may have changed the status of the main loop.
                resuming = self.log.status['iterations_done'] > 0
                if resuming:
                    self.log.resume()
                    self._run_extensions('on_resumption')
                    self.status['epoch_interrupt_received'] = False
                    self.status['batch_interrupt_received'] = False
                with Timer('training', self.profile):
                    more_epochs = True
                    while more_epochs:
                        more_epochs = self._run_epoch()
            except TrainingFinish:
                self.log.current_row['training_finished'] = True
            except Exception as e:
                self._restore_signal_handlers()
                self.log.current_row['got_exception'] = traceback.format_exc()
                logger.error("Error occured during training." + error_message)
                try:
                    self._run_extensions('on_error')
                except Exception:
                    logger.error(traceback.format_exc())
                    logger.error("Error occured when running extensions." +
                                 error_in_error_handling_message)
                reraise_as(e)
            finally:
                self._restore_signal_handlers()
                if self.log.current_row.get('training_finished', False):
                    self._run_extensions('after_training')
                if config.profile:
                    self.profile.report()

    def find_extension(self, name):
        """Return the single extension called `name`.

        Parameters
        ----------
        name : str
            The name of the extension looked for.

        Notes
        -----
        Fails (via `unpack`) unless exactly one extension has that name.

        """
        matching = [candidate for candidate in self.extensions
                    if candidate.name == name]
        return unpack(matching, singleton=True)

    def _run_epoch(self):
        """Run (or continue) one epoch.

        Returns ``False`` when the data stream cannot start a new epoch and
        ``True`` once an epoch has been completed.

        """
        starting_new_epoch = not self.status.get('epoch_started', False)
        if starting_new_epoch:
            try:
                self.log.status['received_first_batch'] = False
                self.epoch_iterator = self.data_stream.get_epoch_iterator(
                    as_dict=True)
            except StopIteration:
                return False
            self.status['epoch_started'] = True
            self._run_extensions('before_epoch')

        with Timer('epoch', self.profile):
            batches_left = True
            while batches_left:
                batches_left = self._run_iteration()

        self.status['epoch_started'] = False
        self.status['epochs_done'] += 1
        # The log may forbid in-place mutation, hence += rather than append.
        self.status['_epoch_ends'] += [self.status['iterations_done']]
        self._run_extensions('after_epoch')
        self._check_finish_training('epoch')
        return True

    def _run_iteration(self):
        """Train on one batch; return ``False`` once the epoch is exhausted."""
        try:
            with Timer('read_data', self.profile):
                batch = next(self.epoch_iterator)
        except StopIteration:
            if not self.log.status['received_first_batch']:
                reraise_as(ValueError("epoch iterator yielded zero batches"))
            return False

        self.log.status['received_first_batch'] = True
        self._run_extensions('before_batch', batch)
        with Timer('train', self.profile):
            self.algorithm.process_batch(batch)
        self.status['iterations_done'] += 1
        self._run_extensions('after_batch', batch)
        self._check_finish_training('batch')
        return True

    def _run_extensions(self, method_name, *args):
        """Invoke callback `method_name` on each extension, in order."""
        with Timer(method_name, self.profile):
            for ext in self.extensions:
                with Timer(type(ext).__name__, self.profile):
                    ext.dispatch(CallbackName(method_name), *args)

    def _check_finish_training(self, level):
        """Raise `TrainingFinish` if training should stop now.

        Parameters
        ----------
        level : {'epoch', 'batch'}
            Where the check happens. An epoch-level interrupt only stops
            training once the remainder of the epoch has been processed.

        """
        # Interrupt flags come from `status` rather than from the current
        # log row: the row written by a signal handler may no longer be
        # the current one when this check runs.
        must_stop = (
            self.log.current_row.get('training_finish_requested', False) or
            self.status.get('batch_interrupt_received', False) or
            (level == 'epoch' and
             self.status.get('epoch_interrupt_received', False)))
        if must_stop:
            raise TrainingFinish

    def _handle_epoch_interrupt(self, signal_number, frame):
        # First Ctrl + C: let the current epoch finish.
        logger.warning('Received epoch interrupt signal.' +
                       epoch_interrupt_message)
        # Any further SIGINT escalates to a batch-level interrupt.
        signal.signal(signal.SIGINT, self._handle_batch_interrupt)
        self.log.current_row['epoch_interrupt_received'] = True
        # Mirrored into `status`, which later iterations can read easily.
        self.status['epoch_interrupt_received'] = True

    def _handle_batch_interrupt(self, signal_number, frame):
        # Second Ctrl + C or SIGTERM (e.g. from a cluster): stop after the
        # current batch.
        self._restore_signal_handlers()
        logger.warning('Received batch interrupt signal.' +
                       batch_interrupt_message)
        self.log.current_row['batch_interrupt_received'] = True
        # Mirrored into `status`, which later iterations can read easily.
        self.status['batch_interrupt_received'] = True

    def _restore_signal_handlers(self):
        """Put back the SIGINT/SIGTERM handlers that `run` replaced."""
        signal.signal(signal.SIGINT, self.original_sigint_handler)
        signal.signal(signal.SIGTERM, self.original_sigterm_handler)