示例#1
0
 def _check_train_consistency(self, flow):
     """Verify that no node still in training precedes a finished one.

     Going through ``flow`` in order, a node whose ``is_training()`` is
     false may be followed by a node that is still training, but a node
     that is still training must not be followed by one that is not.

     Raises mdp.IsNotTrainableException when the forbidden order is found.
     """
     flags = [node.is_training() for node in flow]
     states = mdp.numx.array(flags, dtype='int')
     previous, following = states[:-1], states[1:]
     # a negative step (1 -> 0) means a node in training is directly
     # followed by a node that is no longer training
     if not mdp.numx.any(following - previous < 0):
         return
     raise mdp.IsNotTrainableException(
         "Inconsistent hierarchy! Found a trainable layer in between two "
         "trained layers.")
示例#2
0
    def train(self, x, msg=None):
        """Run one training step and tell the caller whether to continue.

        The return value is one of None, y, (y, msg) or (y, msg, target).
        The last entry of a returned tuple must not be None, whereas y may
        be None when the result is a tuple.

        This template method normally calls the training method of the
        current phase (the first entry of the current ``_train_seq`` item),
        or another method requested in the message via the magic 'method'
        key.

        The remaining msg and target values are only used when the called
        method returns something other than None, so returning an empty
        dict can be used to trigger continued execution.

        Raises mdp.IsNotTrainableException if the node is not trainable,
        mdp.TrainingFinishedException if it is no longer in training, and
        BiNodeException if both x and msg are None or if the training
        argument check fails with a TypeError.
        """
        # sanity checks, modelled on Node.train
        if not self.is_trainable():
            raise mdp.IsNotTrainableException("This node is not trainable.")
        if not self.is_training():
            raise mdp.TrainingFinishedException(
                "The training phase has already finished.")

        if msg is None:
            # Plain call without a message; Node.train is not used as a
            # fall-back because a return value may have to be handed back.
            if x is None:
                raise BiNodeException("Both x and msg are None.")
            self._check_input(x)
            try:
                self._check_train_args(x)
            except TypeError:
                complaint = "%s training seems to require " % str(self)
                complaint += "additional arguments, but none were given."
                raise BiNodeException(complaint)
            self._train_phase_started = True
            x = self._refcast(x)
            phase_train = self._train_seq[self._train_phase][0]
            return phase_train(x)

        # Message-driven call: pull the control keys out of the message
        # first, then resolve which method to run and its keyword arguments.
        id_keys = self._get_msg_id_keys(msg)
        target = self._extract_message_key("target", msg, id_keys)
        requested = self._extract_message_key("method", msg, id_keys)
        fallback = self._train_seq[self._train_phase][0]
        method, target = self._get_method(requested, fallback, target)
        msg, kwargs = self._extract_method_args(method, msg, id_keys)

        # checks that depend on which method is about to be executed
        if x is not None:
            is_plain_training = (not requested) or requested == "train"
            if is_plain_training:
                self._check_input(x)
                try:
                    self._check_train_args(x, **kwargs)
                except TypeError:
                    given = str(list(kwargs.keys()))
                    complaint = ("The given additional arguments %s " % given +
                                 "are not compatible with training %s." %
                                 str(self))
                    raise BiNodeException(complaint)
                self._train_phase_started = True
                x = self._refcast(x)
            elif method == self._inverse:
                self._pre_inversion_checks(x)

        outcome = method(x, **kwargs)
        if outcome is None:
            return None
        outcome = self._combine_result(outcome, msg, target)
        # A two-element tuple whose first entry is None carries only a
        # leftover message; it is dropped so no manual clearing is needed.
        keeps_data = (not isinstance(outcome, tuple) or len(outcome) != 2
                      or outcome[0] is not None)
        return outcome if keeps_data else None