Example #1
    def initialize(self):
        """Initialize parameters.

        Initialize parameters, such as weight matrices and biases.

        Notes
        -----
        If the brick has not allocated its parameters yet, this method will
        call the :meth:`allocate` method in order to do so.

        """
        if not self.allocated:
            self.allocate()
        if not self.initialization_config_pushed:
            self.push_initialization_config()
        for child in self.children:
            child.initialize()
        try:
            self._initialize()
        except Exception:
            if self.lazy:
                reraise_as("Lazy initialization is enabled, so please make "
                           "sure you have set all the required configuration "
                           "for this method call.")
            else:
                raise
        self.initialized = True
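All of these examples revolve around `reraise_as`. As orientation, here is a minimal Python 3 sketch of the semantics the calls above rely on, assuming the `blocks.utils` behaviour: a string argument extends the message of the exception currently being handled, an exception instance replaces it, and the original traceback is preserved either way. This is an illustration, not the library's implementation.

    import sys

    def reraise_as(new_exc):
        # The exception currently being handled in the enclosing except block.
        orig_type, orig_value, orig_tb = sys.exc_info()
        if isinstance(new_exc, str):
            # A plain string keeps the original type and appends context.
            new_exc = orig_type("{}\n\n{}".format(orig_value, new_exc))
        # Re-raise with the original traceback attached.
        raise new_exc.with_traceback(orig_tb)

With that, `reraise_as("Failed to load the state")` inside an `except` block re-raises the caught exception with extra context appended, while `reraise_as(ValueError(...))` swaps in a new exception type.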
Example #2
    def run(self):
        """Starts the main loop.

        The main loop ends when a training extension makes
        a `training_finish_requested` record in the log.

        """
        # This should do nothing if the user has already configured
        # logging, and will at least enable error messages otherwise.
        logging.basicConfig()

        # If this is resumption from a checkpoint, it is crucial to
        # reset `profile.current`. Otherwise, it simply does not hurt.
        self.profile.current = []

        with change_recursion_limit(config.recursion_limit):
            self.original_sigint_handler = signal.signal(
                signal.SIGINT, self._handle_epoch_interrupt)
            self.original_sigterm_handler = signal.signal(
                signal.SIGTERM, self._handle_batch_interrupt)
            try:
                logger.info("Entered the main loop")
                if not self.status['training_started']:
                    for extension in self.extensions:
                        extension.main_loop = self
                    self._run_extensions('before_training')
                    with Timer('initialization', self.profile):
                        self.algorithm_g.initialize()
                        self.algorithm_d.initialize()
                    self.status['training_started'] = True
                # We cannot write "else:" here because extensions
                # called "before_training" could have changed the status
                # of the main loop.
                if self.log.status['iterations_done'] > 0:
                    self.log.resume()
                    self._run_extensions('on_resumption')
                    self.status['epoch_interrupt_received'] = False
                    self.status['batch_interrupt_received'] = False
                with Timer('training', self.profile):
                    while self._run_epoch():
                        pass
            except TrainingFinish:
                self.log.current_row['training_finished'] = True
            except Exception as e:
                self._restore_signal_handlers()
                self.log.current_row['got_exception'] = traceback.format_exc()
                logger.error("Error occured during training." + error_message)
                try:
                    self._run_extensions('on_error')
                except Exception:
                    logger.error(traceback.format_exc())
                    logger.error("Error occured when running extensions." +
                                 error_in_error_handling_message)
                reraise_as(e)
            finally:
                self._restore_signal_handlers()
                if self.log.current_row.get('training_finished', False):
                    self._run_extensions('after_training')
                if config.profile:
                    self.profile.report()
Example #3
    def run(self):
        """Starts the main loop.

        The main loop ends when a training extension makes
        a `training_finish_requested` record in the log.

        """
        # This should do nothing if the user has already configured
        # logging, and will at least enable error messages otherwise.
        logging.basicConfig()

        if self._model and isinstance(self.algorithm,
                                      DifferentiableCostMinimizer):
            # Sanity check: model and algorithm should be configured
            # similarly.
            if not self._model.get_objective() == self.algorithm.cost:
                logger.warning("different costs for model and algorithm")
            if not (set(self._model.get_params().values()) == set(
                    self.algorithm.params)):
                logger.warning("different params for model and algorithm")

        with change_recursion_limit(config.recursion_limit):
            self.original_sigint_handler = signal.signal(
                signal.SIGINT, self._handle_epoch_interrupt)
            self.original_sigterm_handler = signal.signal(
                signal.SIGTERM, self._handle_batch_interrupt)
            try:
                logger.info("Entered the main loop")
                if not self.status._training_started:
                    for extension in self.extensions:
                        extension.main_loop = self
                    self._run_extensions('before_training')
                    self.algorithm.initialize()
                    self.status._training_started = True
                # We cannot write "else:" here because extensions
                # called "before_training" could have changed the status
                # of the main loop.
                if self.log.status.iterations_done > 0:
                    self._run_extensions('on_resumption')
                    self.status._epoch_interrupt_received = False
                    self.status._batch_interrupt_received = False
                while self._run_epoch():
                    pass
            except TrainingFinish:
                self.log.current_row.training_finished = True
            except Exception as e:
                self._restore_signal_handlers()
                self.log.current_row.got_exception = traceback.format_exc()
                logger.error("Error occurred during training." + error_message)
                try:
                    self._run_extensions('on_error')
                except Exception:
                    logger.error(traceback.format_exc())
                    logger.error("Error occurred when running extensions." +
                                 error_in_error_handling_message)
                reraise_as(e)
            finally:
                if self.log.current_row.training_finished:
                    self._run_extensions('after_training')
                self._restore_signal_handlers()
Example #5
    def run(self):
        """Starts the main loop.

        The main loop ends when a training extension makes
        a `training_finish_requested` record in the log.

        """
        # This should do nothing if the user has already configured
        # logging, and will at least enable error messages otherwise.
        logging.basicConfig()

        if self._model and isinstance(self.algorithm,
                                      DifferentiableCostMinimizer):
            # Sanity check: model and algorithm should be configured
            # similarly.
            if not self._model.get_objective() == self.algorithm.cost:
                logger.warning("different costs for model and algorithm")
            if not (set(self._model.get_params().values()) ==
                    set(self.algorithm.params)):
                logger.warning("different params for model and algorithm")

        with change_recursion_limit(config.recursion_limit):
            self.original_sigint_handler = signal.signal(
                signal.SIGINT, self._handle_epoch_interrupt)
            self.original_sigterm_handler = signal.signal(
                signal.SIGTERM, self._handle_batch_interrupt)
            try:
                logger.info("Entered the main loop")
                if not self.status._training_started:
                    for extension in self.extensions:
                        extension.main_loop = self
                    self._run_extensions('before_training')
                    self.algorithm.initialize()
                    self.status._training_started = True
                # We cannot write "else:" here because extensions
                # called "before_training" could have changed the status
                # of the main loop.
                if self.log.status.iterations_done > 0:
                    self._run_extensions('on_resumption')
                while self._run_epoch():
                    pass
            except TrainingFinish:
                self.log.current_row.training_finished = True
            except Exception as e:
                self._restore_signal_handlers()
                self.log.current_row.got_exception = traceback.format_exc()
                logger.error("Error occurred during training." + error_message)
                try:
                    self._run_extensions('on_error')
                except Exception:
                    logger.error(traceback.format_exc())
                    logger.error("Error occurred when running extensions." +
                                 error_in_error_handling_message)
                reraise_as(e)
            finally:
                if self.log.current_row.training_finished:
                    self._run_extensions('after_training')
                self._restore_signal_handlers()
Example #6
 def process_batch(self, batch):
     try:
         batch = dict_subset(batch, self.buffer_.input_names)
     except KeyError:
         reraise_as("Not all data sources required for monitoring were"
                    " provided. The list of required data sources:"
                    " {}.".format(self.buffer_.input_names))
     if self._accumulate_fun is not None:
         self._accumulate_fun(**batch)
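`dict_subset` picks the monitored input sources out of the batch and raises `KeyError` when one is missing, which the handler turns into a readable message via `reraise_as`. A minimal sketch of that behaviour (the real `blocks.utils.dict_subset` also takes optional arguments, omitted here):

    def dict_subset(dict_, keys):
        # Raises KeyError on the first missing key, which the callers
        # above convert into a descriptive monitoring error.
        return {key: dict_[key] for key in keys}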
Example #7
 def before_training(self):
     if not os.path.exists(self.embeddings):
         logger.info("No embeddings found")
         return
     logger.info("Loading embeddings into the main loop")
     try:
         self.load_to(self.main_loop)
     except Exception:
         reraise_as("Failed to load embeddings")
Example #8
 def process_batch(self, batch):
     try:
         batch = dict_subset(batch, self.buffer_.input_names)
     except KeyError:
         reraise_as(
             "Not all data sources required for monitoring were"
             " provided. The list of required data sources:"
             " {}.".format(self.buffer_.input_names))
     if self._accumulate_fun is not None:
         self._accumulate_fun(**batch)
Example #9
 def before_training(self):
     if not os.path.exists(self.path):
         logger.warning("No dump found")
         return
     logger.info("loading model from {}".format(self.path))
     try:
         self.load_to(self.main_loop)
         self.main_loop.log.current_row[LOADED_FROM] = self.path
     except Exception:
         reraise_as("Failed to load the state")
Example #10
 def do(self, *args, **kwargs):
     if not os.path.exists(self.path):
         logger.warning("No dump found")
         return
     logger.info("loading model from {}".format(self.path))
     try:
         self.load_to(self.main_loop)
         self.main_loop.log.current_row[LOADED_FROM] = self.path
     except Exception:
         reraise_as("Failed to load the state")
Example #11
 def before_training(self):
     if not os.path.exists(self.path):
         logger.warning("No log dump found")
         return
     logger.info("loading log from {}".format(self.path))
     try:
         self.load_to(self.main_loop)
         #self.main_loop.log.current_row[saveload.LOADED_FROM] = self.path
     except Exception:
         reraise_as("Failed to load the state")
Example #12
 def before_training(self):
     if not os.path.exists(self.path_to_folder):
         logger.info("No dump found")
         return
     logger.info("Loading the state from {} into the main loop"
                 .format(self.path_to_folder))
     try:
         self.load_to(self.main_loop)
         self.main_loop.log.current_row[LOADED_FROM] = self.path_to_folder
     except Exception:
         reraise_as("Failed to load the state")
Example #13
 def before_training(self):
     if not os.path.exists(self.path_to_folder):
         logger.info("No dump found")
         return
     logger.info("Loading the state from {} into the main loop".format(
         self.path_to_folder))
     try:
         self.load_to(self.main_loop)
         self.main_loop.log.current_row[LOADED_FROM] = self.path_to_folder
     except Exception:
         reraise_as("Failed to load the state")
Example #14
 def process_batch(self, batch, accumulate_dict):
     try:
         input_names = [v.name for v in self.inputs]
         batch = dict_subset(batch, input_names)
     except KeyError:
         reraise_as("Not all data sources required for monitoring were"
                    " provided. The list of required data sources:"
                    " {}.".format(input_names))
     results_list = self._func(**batch)
     output_names = [v.name for v in self.outputs]
     for name, res in zip(output_names, results_list):
         accumulate_dict[name].append(res)
Example #15
 def process_batch(self, batch):
     try:
         input_names = [v.name for v in self.unique_inputs]
         batch = dict_subset(batch, input_names)
     except KeyError:
         reraise_as("Not all data sources required for monitoring were"
                    " provided. The list of required data sources:"
                    " {}.".format(input_names))
     if self._aggregate_fun is not None:
         numerical_values = self._aggregate_fun(**batch)
         self.monitored_quantities_buffer.aggregate_quantities(
             numerical_values)
Example #16
 def process_batch(self, batch):
     try:
         input_names = [v.name for v in self.unique_inputs]
         batch = dict_subset(batch, input_names)
     except KeyError:
         reraise_as(
             "Not all data sources required for monitoring were"
             " provided. The list of required data sources:"
             " {}.".format(input_names))
     if self._accumulate_fun is not None:
         numerical_values = self._accumulate_fun(**batch)
          for value, var in zip(numerical_values, self.theano_variables):
             self.data[var.name].append(value)
Example #17
 def process_batch(self, batch):
     try:
         input_names = [v.name for v in self.unique_inputs]
         batch = dict_subset(batch, input_names)
     except KeyError:
         reraise_as(
             "Not all data sources required for monitoring were"
             " provided. The list of required data sources:"
             " {}.".format(input_names))
     if self._aggregate_fun is not None:
         numerical_values = self._aggregate_fun(**batch)
         self.monitored_quantities_buffer.aggregate_quantities(
             numerical_values)
Example #18
 def do(self, which_callback, *args):
     print('doing LoadOnlyModel_later')
     #sys.exit(0)
     if not os.path.exists(self.path_to_folder):
         logger.info("No dump found")
         return
     logger.info("Loading the state from {} into the main loop".format(
         self.path_to_folder))
     try:
         self.load_to(self.main_loop)
         self.main_loop.log.current_row[LOADED_FROM] = self.path_to_folder
     except Exception:
         reraise_as("Failed to load the state")
Example #19
 def process_batch(self, batch, accumulate_dict):
     try:
         input_names = [v.name for v in self.inputs]
         batch = dict_subset(batch, input_names)
     except KeyError:
         reraise_as(
             "Not all data sources required for monitoring were"
             " provided. The list of required data sources:"
             " {}.".format(input_names)
         )
     results_list = self._func(**batch)
     output_names = [v.name for v in self.outputs]
     for name, res in zip(output_names, results_list):
         accumulate_dict[name].append(res)
Example #20
 def _run_iteration(self):
     try:
         batch = next(self.epoch_iterator)
     except StopIteration:
         if not self.log.status._received_first_batch:
             reraise_as(ValueError("epoch iterator yielded zero batches"))
         return False
     self.log.status._received_first_batch = True
     self._run_extensions('before_batch', batch)
     self.algorithm.process_batch(batch)
     self.status.iterations_done += 1
     self._run_extensions('after_batch', batch)
     self._check_finish_training('batch')
     return True
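`_run_iteration` returns `True` while batches keep coming and `False` once the epoch iterator is exhausted, so the epoch driver can be a plain loop. A simplified, hedged sketch of that calling pattern (the real Blocks `_run_epoch` carries more bookkeeping):

    def _run_epoch(self):
        # One fresh iterator per epoch; the stream is assumed to expose
        # get_epoch_iterator(as_dict=True) as in Blocks/Fuel.
        self.epoch_iterator = self.data_stream.get_epoch_iterator(
            as_dict=True)
        self._run_extensions('before_epoch')
        while self._run_iteration():
            pass
        self.status['epochs_done'] += 1
        self._run_extensions('after_epoch')
        # Raises TrainingFinish once an extension has requested a stop,
        # which ends the `while self._run_epoch()` loop in run().
        self._check_finish_training('epoch')
        return True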
Example #22
 def _run_iteration(self):
     try:
         with Timer('read_data', self.profile):
             batch = next(self.epoch_iterator)
     except StopIteration:
         if not self.log.status['received_first_batch']:
             reraise_as(ValueError("epoch iterator yielded zero batches"))
         return False
     self.log.status['received_first_batch'] = True
     self._run_extensions('before_batch', batch)
     with Timer('train', self.profile):
         self.algorithm.process_batch(batch)
     self.status['iterations_done'] += 1
     self._run_extensions('after_batch', batch)
     self._check_finish_training('batch')
     return True
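The `Timer('read_data', ...)` and `Timer('train', ...)` blocks attribute wall-clock time to named entries of a profile. A minimal sketch of such a context manager over a plain dict (the real Blocks profile tracks nested sections; this flat version only shows the pattern):

    import time
    from collections import defaultdict

    class Timer(object):
        def __init__(self, name, profile):
            self.name = name
            self.profile = profile  # e.g. a defaultdict(float)

        def __enter__(self):
            self._start = time.perf_counter()

        def __exit__(self, *exc_info):
            # Accumulate elapsed time, also when the body raised.
            self.profile[self.name] += time.perf_counter() - self._start

    profile = defaultdict(float)
    with Timer('read_data', profile):
        time.sleep(0.01)
    print(profile['read_data'])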
Example #24
 def _run_iteration(self):
     ministeps_made = 0 # disabled D learning
     self.false_generated.set_value(0.)
     self.false_dataset.set_value(0.)
     while ministeps_made < self.k:
         try:
             with Timer('read_data', self.profile):
                 batch = next(self.epoch_iterator)
         except StopIteration:
             if not self.log.status['received_first_batch']:
                 reraise_as(ValueError("epoch iterator yielded zero batches"))
             return False
         self.log.status['received_first_batch'] = True
         self._run_extensions('before_batch', batch)
         batch = batch['features']
         noise = np.random.rand(self.minibatches, self.noise_per_sample).astype(np.float32)
         generated_batch = self._generator(noise)[0]
         bound_batch = np.zeros((batch.shape[0] * 2, batch.shape[1]), dtype=np.float32)
         bound_batch[:self.minibatches, :] = generated_batch
         bound_batch[self.minibatches:, :] = batch
         np.save('generated.npy', generated_batch)
         bound_batch = {'features': bound_batch}
         # if self.first_batch is None:
         #     self.first_batch = bound_batch
         with Timer('train', self.profile):
             # self.algorithm_d.process_batch(self.first_batch)
             self.algorithm_d.process_batch(bound_batch)
         ministeps_made += 1
     
     false_generated_perc = self.false_generated.get_value() / (self.k * self.minibatches)
     false_dataset_perc = self.false_dataset.get_value() / (self.k * self.minibatches)
     self.log[self.status['iterations_done'] + 1]['error_on_generated'] = false_generated_perc
     self.log[self.status['iterations_done'] + 1]['error_on_dataset'] = false_dataset_perc
     self.generator_errors.set_value(0.)
     noise = np.random.rand(self.minibatches, self.noise_per_sample).astype(np.float32)
     noise_batch = {'noise': noise}
     with Timer('train', self.profile):
         self.algorithm_g.process_batch(noise_batch)
     gen_errors = self.generator_errors.get_value() / self.minibatches
     self.log[self.status['iterations_done'] + 1]['generator_errors'] = gen_errors
     for o in self.observables:
         self.log[self.status['iterations_done'] + 1][o.name] = o.get_value()
     self.status['iterations_done'] += 1
     self._run_extensions('after_batch', bound_batch)
     self._check_finish_training('batch')
     return True
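The discriminator batch above is built by preallocating an array of twice the batch size and writing generator samples into the top half and real data into the bottom half. A small NumPy sketch of the equivalent layout, with illustrative shapes:

    import numpy as np

    minibatches = 4
    fake = np.zeros((minibatches, 3), dtype=np.float32)  # generator output
    real = np.ones((minibatches, 3), dtype=np.float32)   # dataset batch
    # Rows [0, minibatches) are generated, the rest are real, matching
    # the bound_batch construction in the example.
    bound = np.concatenate([fake, real], axis=0)
    assert bound.shape == (2 * minibatches, 3)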
Example #25
    def run(self):
        logging.basicConfig()

        with change_recursion_limit(cfg.recursion_limit):
            self.original_sigint_handler = signal.signal(
                signal.SIGINT, self._handle_epoch_interrupt)
            self.original_sigterm_handler = signal.signal(
                signal.SIGTERM, self._handle_batch_interrupt)
            try:
                logger.info("Entered the main loop")
                if not self.status['training_started']:
                    for extension in self.extensions:
                        extension.main_loop = self
                    self._run_extensions('before_training')
                    with Timer('initialization', self.profile):
                        self.algorithm.initialize()
                    self.status['training_started'] = True
                if self.log.status['iterations_done'] > 0:
                    self._run_extensions('on_resumption')
                    self.status['epoch_interrupt_received'] = False
                    self.status['batch_interrupt_received'] = False
                with Timer('training', self.profile):
                    while self._run_epoch():
                        pass
            except TrainingFinish:
                self.log.current_row['training_finished'] = True
            except Exception as e:
                self._restore_signal_handlers()
                self.log.current_row['got_exception'] = traceback.format_exc()
                logger.error("Error occurred during training." + error_message)
                try:
                    self._run_extensions('on_error')
                except Exception:
                    logger.error(traceback.format_exc())
                    logger.error("Error occurred when running extensions." +
                                 error_in_error_handling_message)
                reraise_as(e)
            finally:
                if self.log.current_row.get('training_finished', False):
                    self._run_extensions('after_training')
                if cfg.profile:
                    self.profile.report()
                self._restore_signal_handlers()
Example #27
def pickle_dump(*args, **kwargs):
    """A wrapper around pickle's dump that provides informative errors."""
    try:
        cPickle.dump(*args, **kwargs)
    except Exception as e:
        if six.PY3 and '<lambda>' in e.args[0]:
            reraise_as("Pickling failed to pickle a lambda function." +
                       LAMBDA_ERROR)
        if six.PY3 and '<function' in e.args[0] and '<locals>' in e.args[0]:
            reraise_as("Pickling failed to pickle a nested function." +
                       NESTED_FUNCTION_ERROR)
        if six.PY2 and 'function objects' in e.args[0]:
            reraise_as("Pickling failed to pickle a function." +
                       LAMBDA_ERROR + NESTED_FUNCTION_ERROR)
        if ((six.PY2 and 'isinstancemethod' in e.args[0]) or
                (six.PY3 and '<function' in e.args[0] and
                 'attribute lookup' in e.args[0])):
            reraise_as("Pickling failed to pickle a reference to a method." +
                       INSTANCEMETHOD_ERROR)
        reraise_as("Pickling failed." + PICKLING_ERROR)
Example #28
def pickle_dump(*args, **kwargs):
    """A wrapper around pickle's dump that provides informative errors."""
    try:
        cPickle.dump(*args, **kwargs)
    except Exception as e:
        if six.PY3 and '<lambda>' in e.args[0]:
            reraise_as("Pickling failed to pickle a lambda function." +
                       LAMBDA_ERROR)
        if six.PY3 and '<function' in e.args[0] and '<locals>' in e.args[0]:
            reraise_as("Pickling failed to pickle a nested function." +
                       NESTED_FUNCTION_ERROR)
        if six.PY2 and 'function objects' in e.args[0]:
            reraise_as("Pickling failed to pickle a function." + LAMBDA_ERROR +
                       NESTED_FUNCTION_ERROR)
        if ((six.PY2 and 'isinstancemethod' in e.args[0])
                or (six.PY3 and '<function' in e.args[0]
                    and 'attribute lookup' in e.args[0])):
            reraise_as("Pickling failed to pickle a reference to a method." +
                       INSTANCEMETHOD_ERROR)
        reraise_as("Pickling failed." + PICKLING_ERROR)
Example #29
    def allocate(self):
        """Allocate shared variables for parameters.

        Based on the current configuration of this :class:`Brick` create
        Theano shared variables to store the parameters.  After allocation,
        parameters are accessible through the :attr:`params` attribute.

        This method calls the :meth:`allocate` method of all children
        first, allowing the :meth:`_allocate` method to override the
        parameters of the children if needed.

        Raises
        ------
        ValueError
            If the configuration of this brick is insufficient to determine
            the number of parameters or their dimensionality to be
            initialized.

        Notes
        -----
        This method sets the :attr:`params` attribute to an empty list.
        This is in order to ensure that calls to this method completely
        reset the parameters.

        """
        if not self.allocation_config_pushed:
            self.push_allocation_config()
        for child in self.children:
            child.allocate()
        self.params = []
        try:
            self._allocate()
        except Exception:
            if self.lazy:
                reraise_as("Lazy initialization is enabled, so please make "
                           "sure you have set all the required configuration "
                           "for this method call.")
            else:
                raise
        self.allocated = True
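Together with the `initialize` method from Example #1, `allocate` implements a two-phase Brick lifecycle: configuration is pushed down to the children, parameters are allocated bottom-up, and only then filled in. A hypothetical usage sketch, with a `Linear` brick and constructor arguments following the Blocks convention but meant as illustration:

    from blocks.bricks import Linear
    from blocks.initialization import Constant

    brick = Linear(input_dim=10, output_dim=5,
                   weights_init=Constant(0.), biases_init=Constant(0.))
    brick.allocate()    # creates shared variables; brick.params is reset
    brick.initialize()  # allocates first if needed, then sets the values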
Example #31
    def _run_iteration(self):
        ministeps_made = 0
        self.false_generated = 0.
        self.false_dataset = 0.
        while ministeps_made < self.k:
            try:
                with Timer('read_data', self.profile):
                    batch = next(self.epoch_iterator)
            except StopIteration:
                if not self.log.status['received_first_batch']:
                    reraise_as(ValueError("epoch iterator yielded zero batches"))
                return False
            self.log.status['received_first_batch'] = True
            self._run_extensions('before_batch', batch)
            batch = batch['features']
            noise = np.random.rand(self.noise_per_sample, self.minibatches).astype(np.float32)
            generated_batch = self._generator(noise)[0]
            bound_batch = np.zeros((batch.shape[0] * 2, batch.shape[1]), dtype=np.float32)
            bound_batch[:self.minibatches, :] = generated_batch
            bound_batch[self.minibatches:, :] = batch
            bound_batch = {'features': bound_batch}
            with Timer('train', self.profile):
                self.algorithm_d.process_batch(bound_batch)
            ministeps_made += 1
            self.false_generated += self.d_out[:self.minibatches, 0].sum()
            self.false_dataset += self.d_out[self.minibatches:, 1].sum()
        self.false_generated /= self.k * self.minibatches
        self.false_dataset /= self.k * self.minibatches

        noise = np.random.rand(self.noise_per_sample, self.minibatches).astype(np.float32)
        generated_batch = self._generator(noise)[0]
        generated_batch = {'noise': noise}
        with Timer('train', self.profile):
            self.algorithm_g.process_batch(generated_batch)

        self.status['iterations_done'] += 1
        self._run_extensions('after_batch', batch)
        self._check_finish_training('batch')
        return True
Example #32
    def run(self):
        """Starts the main loop.

        The main loop ends when a training extension makes
        a `training_finish_requested` record in the log.

        """
        # This should do nothing if the user has already configured
        # logging, and will at least enable error messages otherwise.
        logging.basicConfig()

        # If this is resumption from a checkpoint, it is crucial to
        # reset `profile.current`. Otherwise, it simply does not hurt.
        self.profile.current = []

        # Sanity check for the most common case
        if (self._model and isinstance(self._model, Model)
                and isinstance(self.algorithm, GradientDescent)):
            if not (set(self._model.get_parameter_dict().values()) == set(
                    self.algorithm.parameters)):
                logger.warning("different parameters for model and algorithm")

        with change_recursion_limit(config.recursion_limit):
            self.original_sigint_handler = signal.signal(
                signal.SIGINT, self._handle_epoch_interrupt)
            self.original_sigterm_handler = signal.signal(
                signal.SIGTERM, self._handle_batch_interrupt)
            try:
                logger.info("Entered the main loop")
                if not self.status['training_started']:
                    for extension in self.extensions:
                        extension.main_loop = self
                    self._run_extensions('before_training')
                    with Timer('initialization', self.profile):
                        self.algorithm.initialize()
                    self.status['training_started'] = True
                # We cannot write "else:" here because extensions
                # called "before_training" could have changed the status
                # of the main loop.
                if self.log.status['iterations_done'] > 0:
                    self.log.resume()
                    self._run_extensions('on_resumption')
                    self.status['epoch_interrupt_received'] = False
                    self.status['batch_interrupt_received'] = False
                with Timer('training', self.profile):
                    while self._run_epoch():
                        pass
            except TrainingFinish:
                self.log.current_row['training_finished'] = True
            except Exception as e:
                self._restore_signal_handlers()
                self.log.current_row['got_exception'] = traceback.format_exc()
                logger.error("Error occured during training." + error_message)
                try:
                    self._run_extensions('on_error', e)
                except Exception:
                    logger.error(traceback.format_exc())
                    logger.error("Error occured when running extensions." +
                                 error_in_error_handling_message)
                reraise_as(e)
            finally:
                self._restore_signal_handlers()
                if self.log.current_row.get('training_finished', False):
                    self._run_extensions('after_training')
                if config.profile:
                    self.profile.report()
Example #33
    def apply(self, bound_application, *args, **kwargs):
        as_dict = kwargs.pop('as_dict', False)
        as_list = kwargs.pop('as_list', False)
        call_id = kwargs.pop('call_id', None)
        if as_list and as_dict:
            raise ValueError

        brick = bound_application.brick

        # Find the names of the inputs to the application method
        args_names, varargs_name, _, _ = inspect.getargspec(
            self.application_function)
        args_names = args_names[1:]

        # Construct the ApplicationCall, used to store data in for this call
        call = ApplicationCall(bound_application)
        call.metadata['call_id'] = call_id
        args = list(args)
        if 'application' in args_names:
            args.insert(args_names.index('application'), bound_application)
        if 'application_call' in args_names:
            args.insert(args_names.index('application_call'), call)

        # Allocate before applying, and optionally initialize
        if not brick.allocated:
            brick.allocate()

        # Annotate all the input variables which are Theano variables

        for i, input_ in enumerate(args):
            if isinstance(input_, tensor.Variable):
                if i < len(args_names):
                    name = args_names[i]
                else:
                    name = "{}_{}".format(varargs_name, i - len(args_names))
                args[i] = copy_and_tag(input_, brick, call, INPUT,
                                       self.name, name)
        for name, input_ in kwargs.items():
            if isinstance(input_, tensor.Variable):
                kwargs[name] = copy_and_tag(input_, brick, call, INPUT,
                                            self.name, name)

        # Run the application method on the annotated variables
        last_brick = self.call_stack[-1] if self.call_stack else None
        if (last_brick and brick is not last_brick and
                brick not in last_brick.children):
            warnings.warn('Brick ' + str(self.call_stack[-1]) + ' tries '
                          'to call brick ' + str(self.brick) + ' which '
                          'is not in the list of its children. This could '
                          'be caused because an @application decorator is '
                          'missing.')
        self.call_stack.append(brick)
        try:
            outputs = self.application_function(brick, *args, **kwargs)
            outputs = pack(outputs)
        finally:
            self.call_stack.pop()

        # Rename and annotate output variables
        for i, output in enumerate(outputs):
            if isinstance(output, tensor.Variable):
                try:
                    name = bound_application.outputs[i]
                except AttributeError:
                    name = "output_{}".format(i)
                except IndexError:
                    reraise_as(ValueError("Unexpected outputs"))
                # TODO Tag with dimensions, axes, etc. for error-checking
                outputs[i] = copy_and_tag(outputs[i], brick, call,
                                          OUTPUT, self.name, name)

        # Return values
        if as_list:
            return outputs
        if as_dict:
            return OrderedDict(zip(bound_application.outputs, outputs))
        return unpack(outputs)
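The `apply` wrapper above is what runs when a decorated application method is called on a brick. A hedged sketch of the decorator side, following the Blocks `@application` convention (the brick body is illustrative):

    from blocks.bricks.base import application, Brick

    class Doubler(Brick):
        @application(inputs=['x'], outputs=['y'])
        def apply(self, x):
            # Calling Doubler().apply(x) routes through the wrapper
            # above, which tags inputs/outputs and manages the call
            # stack of bricks.
            return 2 * x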
Example #34

if __name__ == "__main__":
    # logging.basicConfig(level=logging.INFO)
    # parser = argparse.ArgumentParser(description="Train ALI on CIFAR10")
    # parser.add_argument("--save-path", type=str, default='ali_cifar10.tar',
    #                     help="main loop save path")
    # args = parser.parse_args()
    # create_main_loop(args.save_path).run()

    name = "ali_cifar10_%dlat" % NLAT
    logging.basicConfig(level=logging.INFO)
    sys.stdout = Logger(filename=os.path.join(
        "MIRIAM_%s_%s.log" %
        (name, datetime.datetime.now().strftime("%Y-%m-%d_%H%M"))))
    parser = argparse.ArgumentParser(description="Train ALI on CIFAR-10 %s" %
                                     name)
    parser.add_argument("--save-path",
                        type=str,
                        default='/var/scratch/aiir-mh/%s.tar' % name,
                        help="main loop save path")
    parser.add_argument("--backup-path",
                        type=str,
                        default='/var/scratch/aiir-mh/%s/' % name,
                        help="backup save path")
    args = parser.parse_args()
    try:
        create_main_loop(args.save_path, args.backup_path).run()
    except Exception as e:
        reraise_as(Exception("Extra information", *e.args))
Example #35
    def apply(self, bound_application, *args, **kwargs):
        as_dict = kwargs.pop('as_dict', False)
        as_list = kwargs.pop('as_list', False)
        if as_list and as_dict:
            raise ValueError

        brick = bound_application.brick

        # Find the names of the inputs to the application method
        args_names, varargs_name, _, _ = inspect.getargspec(
            self.application_function)
        args_names = args_names[1:]

        # Construct the ApplicationCall, used to store data in for this call
        call = ApplicationCall(brick, bound_application)
        args = list(args)
        if 'application' in args_names:
            args.insert(args_names.index('application'), bound_application)
        if 'application_call' in args_names:
            args.insert(args_names.index('application_call'), call)

        # Allocate before applying, and optionally initialize
        if not brick.allocated:
            brick.allocate()
        if not brick.initialized and not brick.lazy:
            brick.initialize()

        # Annotate all the input variables which are Theano variables
        def copy_and_tag(variable, role, name):
            """Helper method to copy a variable and annotate it."""
            copy = variable.copy()
            # Theano name
            copy.name = _variable_name(brick.name, self.name, name)
            add_annotation(copy, brick)
            add_annotation(copy, call)
            # Blocks name
            copy.tag.name = name
            add_role(copy, role)
            return copy

        for i, input_ in enumerate(args):
            if isinstance(input_, tensor.Variable):
                if i < len(args_names):
                    name = args_names[i]
                else:
                    name = "{}_{}".format(varargs_name, i - len(args_names))
                args[i] = copy_and_tag(input_, INPUT, name)
        for name, input_ in kwargs.items():
            if isinstance(input_, tensor.Variable):
                kwargs[name] = copy_and_tag(input_, INPUT, name)

        # Run the application method on the annotated variables
        if self.call_stack and brick is not self.call_stack[-1] and \
                brick not in self.call_stack[-1].children:
            raise ValueError('Brick ' + str(self.call_stack[-1]) + ' tries '
                             'to call brick ' + str(self.brick) + ' which '
                             'is not in the list of its children.')
        self.call_stack.append(brick)
        try:
            outputs = self.application_function(brick, *args, **kwargs)
            outputs = pack(outputs)
        finally:
            self.call_stack.pop()

        # Rename and annotate output variables
        for i, output in enumerate(outputs):
            if isinstance(output, tensor.Variable):
                try:
                    name = bound_application.outputs[i]
                except AttributeError:
                    name = "output_{}".format(i)
                except IndexError:
                    reraise_as(ValueError("Unexpected outputs"))
                # TODO Tag with dimensions, axes, etc. for error-checking
                outputs[i] = copy_and_tag(outputs[i], OUTPUT, name)

        # Return values
        if as_list:
            return outputs
        if as_dict:
            return OrderedDict(zip(bound_application.outputs, outputs))
        return unpack(outputs)
Example #36
    def apply(self, bound_application, *args, **kwargs):
        as_dict = kwargs.pop('as_dict', False)
        as_list = kwargs.pop('as_list', False)
        if as_list and as_dict:
            raise ValueError

        brick = bound_application.brick

        # Find the names of the inputs to the application method
        args_names, varargs_name, _, _ = inspect.getargspec(
            self.application_function)
        args_names = args_names[1:]

        # Construct the ApplicationCall, used to store data in for this call
        call = ApplicationCall(bound_application)
        args = list(args)
        if 'application' in args_names:
            args.insert(args_names.index('application'), bound_application)
        if 'application_call' in args_names:
            args.insert(args_names.index('application_call'), call)

        # Allocate before applying, and optionally initialize
        if not brick.allocated:
            brick.allocate()

        # Annotate all the input variables which are Theano variables
        def copy_and_tag(variable, role, name):
            """Helper method to copy a variable and annotate it."""
            copy = variable.copy()
            # Theano name
            copy.name = _variable_name(brick.name, self.name, name)
            add_annotation(copy, brick)
            add_annotation(copy, call)
            # Blocks name
            copy.tag.name = name
            add_role(copy, role)
            return copy

        for i, input_ in enumerate(args):
            if isinstance(input_, tensor.Variable):
                if i < len(args_names):
                    name = args_names[i]
                else:
                    name = "{}_{}".format(varargs_name, i - len(args_names))
                args[i] = copy_and_tag(input_, INPUT, name)
        for name, input_ in kwargs.items():
            if isinstance(input_, tensor.Variable):
                kwargs[name] = copy_and_tag(input_, INPUT, name)

        # Run the application method on the annotated variables
        last_brick = self.call_stack[-1] if self.call_stack else None
        if (last_brick and brick is not last_brick and
                brick not in last_brick.children):
            raise ValueError('Brick ' + str(self.call_stack[-1]) + ' tries '
                             'to call brick ' + str(self.brick) + ' which '
                             'is not in the list of its children.')
        self.call_stack.append(brick)
        try:
            outputs = self.application_function(brick, *args, **kwargs)
            outputs = pack(outputs)
        finally:
            self.call_stack.pop()

        # Rename and annotate output variables
        for i, output in enumerate(outputs):
            if isinstance(output, tensor.Variable):
                try:
                    name = bound_application.outputs[i]
                except AttributeError:
                    name = "output_{}".format(i)
                except IndexError:
                    reraise_as(ValueError("Unexpected outputs"))
                # TODO Tag with dimensions, axes, etc. for error-checking
                outputs[i] = copy_and_tag(outputs[i],
                                          OUTPUT, name)

        # Return values
        if as_list:
            return outputs
        if as_dict:
            return OrderedDict(zip(bound_application.outputs, outputs))
        return unpack(outputs)