def _check_train_consistency(self, flow):
    """Verify that no node still in training precedes a trained node.

    Every node in ``flow`` is asked ``is_training()`` and the answers of
    neighbouring nodes are compared.

    Raises ``mdp.IsNotTrainableException`` if a node that is still in
    training is directly followed by a node that is not.
    """
    training_flags = mdp.numx.array(
        [node.is_training() for node in flow], dtype='int')
    # Along the flow the flag may stay the same or switch from 0 (trained)
    # to 1 (training); a neighbouring pair going from 1 back to 0 is invalid.
    current, following = training_flags[:-1], training_flags[1:]
    if not mdp.numx.any(following < current):
        return
    raise mdp.IsNotTrainableException(
        "Inconsistent hierarchy! Found a trainable layer in between two "
        "trained layers.")
def train(self, x, msg=None):
    """Train the node and return None, or a result if execution continues.

    The possible return types are None, y, (y, msg), (y, msg, target).
    The last entry in a result tuple must not be None. y can be None if
    the result is a tuple.

    This template method normally calls the train method of the current
    training phase, or another method selected through the message (via
    the magic 'method' key).

    Note that the remaining msg and target values are only used if the
    called method returns something different from None (so an empty dict
    can be used to trigger continued execution).

    x -- Input data, may be None when a message is given.
    msg -- Optional message; when it is None the call is treated as a
        plain training call on x.

    Raises mdp.IsNotTrainableException if the node is not trainable,
    mdp.TrainingFinishedException if it is no longer in training, and
    BiNodeException if both x and msg are None or if the arguments do not
    fit the training method.
    """
    # perform checks, adapted from Node.train
    if not self.is_trainable():
        raise mdp.IsNotTrainableException("This node is not trainable.")
    if not self.is_training():
        message = "The training phase has already finished."
        raise mdp.TrainingFinishedException(message)

    if msg is None:
        # Plain training call without any message handling.
        if x is None:
            message = "Both x and msg are None."
            raise BiNodeException(message)
        # no fall-back on Node.train because we might have a return value
        self._check_input(x)
        try:
            self._check_train_args(x)
        except TypeError:
            message = ("%s training seems to require " % str(self) +
                       "additional arguments, but none were given.")
            raise BiNodeException(message)
        self._train_phase_started = True
        x = self._refcast(x)
        phase_train = self._train_seq[self._train_phase][0]
        return phase_train(x)

    # Message-driven call: pull the magic keys out of the message first,
    # then resolve the method to call and its keyword arguments.
    id_keys = self._get_msg_id_keys(msg)
    target = self._extract_message_key("target", msg, id_keys)
    method_name = self._extract_message_key("method", msg, id_keys)
    phase_train = self._train_seq[self._train_phase][0]
    handler, target = self._get_method(method_name, phase_train, target)
    msg, kwargs = self._extract_method_args(handler, msg, id_keys)

    # perform specific checks
    if x is not None:
        if (not method_name) or (method_name == "train"):
            self._check_input(x)
            try:
                self._check_train_args(x, **kwargs)
            except TypeError:
                message = ("The given additional arguments %s " %
                           str(list(kwargs.keys())) +
                           "are not compatible with training %s." %
                           str(self))
                raise BiNodeException(message)
            self._train_phase_started = True
            x = self._refcast(x)
        elif handler == self._inverse:
            self._pre_inversion_checks(x)

    outcome = handler(x, **kwargs)
    if outcome is None:
        return None
    outcome = self._combine_result(outcome, msg, target)
    # A (None, msg) pair carries no data: drop the remaining msg, so that
    # no manual clearing is required.
    only_leftover_msg = (isinstance(outcome, tuple) and
                         len(outcome) == 2 and
                         outcome[0] is None)
    if only_leftover_msg:
        return None
    return outcome