def __init__(self, log_dir, sim, objects):
    super().__init__()

    self.sim = sim

    # we do all the summary writing in eager mode, so that it will be
    # executed as the callback is called
    with context.eager_mode():
        self.writer = tf.summary.create_file_writer(log_dir)

    self.summaries = []
    for obj in objects:
        if isinstance(
            obj, (nengo.Ensemble, nengo.ensemble.Neurons, nengo.Connection)
        ):
            if isinstance(obj, nengo.Ensemble):
                param = "encoders"
                name = "Ensemble_%s" % obj.label
            elif isinstance(obj, nengo.ensemble.Neurons):
                param = "bias"
                name = "Ensemble.neurons_%s" % obj.ensemble.label
            elif isinstance(obj, nengo.Connection):
                if not compat.conn_has_weights(obj):
                    raise ValidationError(
                        "Connection '%s' does not have any weights to log"
                        % obj,
                        "objects",
                    )
                param = "weights"
                name = "Connection_%s" % obj.label

            self.summaries.append(
                (utils.sanitize_name("%s_%s" % (name, param)), obj, param)
            )
        else:
            raise ValidationError(
                "Unknown summary object %s; should be an Ensemble, Neurons, "
                "or Connection" % obj,
                "objects",
            )
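# ---------------------------------------------------------------------------
# Usage sketch (not in the original source): assuming the ``__init__`` above
# belongs to the ``NengoSummaries`` Keras callback in ``nengo_dl.callbacks``,
# it could be attached to training like this. The network, data shapes, and
# the "logs" directory below are illustrative, not taken from the source.

import numpy as np
import nengo
import nengo_dl

with nengo.Network() as net:
    inp = nengo.Node([0.5])
    ens = nengo.Ensemble(10, 1, label="my_ens")
    # explicit transform so the connection has weights to log
    conn = nengo.Connection(inp, ens, transform=1.0, label="my_conn")
    probe = nengo.Probe(ens)

with nengo_dl.Simulator(net, minibatch_size=4) as sim:
    sim.compile(optimizer="sgd", loss="mse")
    summaries = nengo_dl.callbacks.NengoSummaries(
        "logs", sim, [ens, ens.neurons, conn]
    )
    # encoder/bias/weight summaries are written to TensorBoard each epoch
    sim.fit(
        {inp: np.zeros((4, 1, 1))},
        {probe: np.zeros((4, 1, 1))},
        callbacks=[summaries],
    )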
def mark_signals(self):
    """
    Mark all the signals in ``self.model`` according to whether they
    represent trainable parameters of the model (parameters that can be
    optimized by deep learning methods).

    Trainable parameters include connection weights, ensemble encoders, and
    neuron biases, unless one of those signals is targeted by a Nengo
    learning rule (in which case the learning rule update would conflict
    with the deep learning optimization).

    Users can manually specify whether signals are trainable or not using
    the config system (e.g.,
    ``net.config[nengo.Ensemble].trainable = False``).

    The trainable attribute will be set to one of three values:

    - ``True``: Signal is trainable
    - ``False``: Signal could be trainable, but has been set to
      non-trainable (e.g., because the user manually configured that
      object not to be trainable).
    - ``None``: Signal is never trainable (e.g., simulator state)
    """

    def get_trainable(parent_configs, obj):
        """Looks up the current value of ``obj.trainable``."""

        if self.inference_only:
            return False

        # default to 1 (so that we can distinguish between an object being
        # set to trainable vs defaulting to trainable)
        trainable = 1

        # we go from top down (so lower level settings will override)
        for cfg in parent_configs:
            try:
                cfg_trainable = getattr(cfg[obj], "trainable", None)
            except ConfigError:
                # object not configured in this network config
                cfg_trainable = None

            if cfg_trainable is not None:
                trainable = cfg_trainable

        return trainable

    def mark_network(parent_configs, net):
        """Recursively marks the signals for objects within each
        subnetwork."""

        parent_configs = parent_configs + [net.config]

        for subnet in net.networks:
            mark_network(parent_configs, subnet)

        # encoders and biases are trainable
        for ens in net.ensembles:
            ens_trainable = get_trainable(parent_configs, ens)

            self.model.sig[ens]["encoders"].trainable = ens_trainable
            self.model.sig[ens]["encoders"].minibatched = False

            if not isinstance(ens.neuron_type, Direct):
                neurons_trainable = get_trainable(parent_configs, ens.neurons)
                if neurons_trainable is 1:  # noqa: F632
                    neurons_trainable = ens_trainable

                self.model.sig[ens.neurons]["bias"].trainable = (
                    neurons_trainable
                )
                self.model.sig[ens.neurons]["bias"].minibatched = False

        # connection weights are trainable
        for conn in net.connections:
            # note: this doesn't include probe connections, since they
            # aren't added to the network
            if compat.conn_has_weights(conn):
                self.model.sig[conn]["weights"].trainable = get_trainable(
                    parent_configs, conn
                )
                self.model.sig[conn]["weights"].minibatched = False

        # parameters can't be modified by an online Nengo learning rule
        # and offline training at the same time. (it is possible in
        # theory, but it complicates things a lot and is probably not a
        # common use case). we also make those signals minibatched
        # (they wouldn't be normally), because we want to be able to
        # learn independently in each minibatch
        for conn in net.connections:
            rule = conn.learning_rule
            if rule is not None:
                if isinstance(rule, dict):
                    rule = list(rule.values())
                elif not isinstance(rule, list):
                    rule = [rule]

                for r in rule:
                    if r.modifies in ("weights", "decoders"):
                        obj = conn
                        attr = "weights"
                    elif r.modifies == "encoders":
                        obj = conn.post_obj
                        attr = "encoders"
                    else:
                        raise NotImplementedError

                    if self.model.sig[obj][attr].trainable is True:
                        warnings.warn(
                            "%s has a learning rule and is also set "
                            "to be trainable; this is likely to "
                            "produce strange training behaviour." % obj
                        )
                    else:
                        self.model.sig[obj][attr].trainable = False

                    self.model.sig[obj][attr].minibatched = True

    if self.model.toplevel is None:
        warnings.warn(
            "No top-level network in model; assuming no trainable "
            "parameters",
            UserWarning,
        )
    else:
        mark_network([], self.model.toplevel)

        # the connections to connection probes are not trainable, but
        # also not minibatched
        probe_seeds = [self.model.seeds[p] for p in self.model.probes]
        for obj, seed in self.model.seeds.items():
            if isinstance(obj, Connection) and seed in probe_seeds:
                if compat.conn_has_weights(obj):
                    self.model.sig[obj]["weights"].trainable = None
                    self.model.sig[obj]["weights"].minibatched = False

    # time/step are not minibatched and not trainable
    self.model.step.trainable = None
    self.model.step.minibatched = False
    self.model.time.trainable = None
    self.model.time.minibatched = False

    # fill in defaults for all other signals
    # signals are not trainable by default, and views take on the
    # properties of their bases
    for op in self.model.operators:
        for sig in op.all_signals:
            if not hasattr(sig.base, "trainable"):
                sig.base.trainable = None

            if not hasattr(sig.base, "minibatched"):
                sig.base.minibatched = not sig.base.trainable

            if not hasattr(sig, "trainable"):
                sig.trainable = sig.base.trainable

            if not hasattr(sig, "minibatched"):
                sig.minibatched = sig.base.minibatched
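# ---------------------------------------------------------------------------
# Usage sketch (not in the original source): a minimal example of the
# config-system control over trainability that ``mark_signals`` reads,
# assuming the standard ``nengo_dl.configure_settings`` entry point; the
# network itself is illustrative.

import nengo
import nengo_dl

with nengo.Network() as net:
    # expose the ``trainable`` attribute in this network's config
    nengo_dl.configure_settings(trainable=None)

    # freeze all ensemble encoders during offline optimization, as in the
    # docstring above; mark_signals will then see trainable=False for them
    net.config[nengo.Ensemble].trainable = False

    ens = nengo.Ensemble(10, 1)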