示例#1
0
    def __init__(self, log_dir, sim, objects):
        """
        Create the summary file writer and collect the parameters to log.

        Parameters
        ----------
        log_dir
            Location passed to ``tf.summary.create_file_writer``.
        sim
            Simulator object; stored as ``self.sim`` (presumably used later to
            read the logged parameter values -- confirm against the rest of
            the class).
        objects : iterable
            Objects whose parameters will be logged: ``nengo.Ensemble``
            (``encoders``), ``nengo.ensemble.Neurons`` (``bias``), or
            ``nengo.Connection`` (``weights``).

        Raises
        ------
        ValidationError
            If an object is not one of the supported types, or if a
            Connection does not have any weights.
        """
        super().__init__()

        self.sim = sim

        # we do all the summary writing in eager mode, so that it will be executed
        # as the callback is called
        with context.eager_mode():
            self.writer = tf.summary.create_file_writer(log_dir)

        # each entry is (sanitized summary name, object, parameter name)
        self.summaries = []
        for obj in objects:
            if isinstance(obj, nengo.Ensemble):
                param = "encoders"
                name = "Ensemble_%s" % obj.label
            elif isinstance(obj, nengo.ensemble.Neurons):
                param = "bias"
                name = "Ensemble.neurons_%s" % obj.ensemble.label
            elif isinstance(obj, nengo.Connection):
                if not compat.conn_has_weights(obj):
                    raise ValidationError(
                        "Connection '%s' does not have any weights to log"
                        % obj,
                        "objects",
                    )
                param = "weights"
                name = "Connection_%s" % obj.label
            else:
                raise ValidationError(
                    "Unknown summary object %s; should be an Ensemble, Neurons, or "
                    "Connection" % obj,
                    "objects",
                )

            self.summaries.append(
                (utils.sanitize_name("%s_%s" % (name, param)), obj, param))
    def mark_signals(self):
        """
        Mark all the signals in ``self.model`` according to whether they
        represent trainable parameters of the model (parameters that can be
        optimized by deep learning methods).

        Trainable parameters include connection weights, ensemble encoders, and
        neuron biases, unless one of those signals is targeted by a Nengo
        learning rule (otherwise the learning rule update would conflict with
        the deep learning optimization).

        Users can manually specify whether signals are trainable or not using
        the config system (e.g.,
        ``net.config[nengo.Ensemble].trainable = False``).

        The trainable attribute will be set to one of these values:

        - ``True``: Signal has been explicitly configured to be trainable.
        - ``1``: Signal is trainable by default (no config value applied).
          This is truthy like ``True``, but lets later checks distinguish an
          explicit ``True`` from the default; only an explicit ``True`` on a
          signal targeted by a learning rule triggers a warning.
        - ``False``: Signal could be trainable, but has been set to
          non-trainable (e.g., because the user manually configured that
          object not to be trainable, the model is ``inference_only``, or a
          learning rule targets the signal).
        - ``None``: Signal is never trainable (e.g., simulator state).

        Every signal also receives a ``minibatched`` attribute. Parameters are
        not minibatched unless a learning rule targets them. Signals without
        an explicit setting default to ``not trainable`` on their base, and
        views inherit both attributes from their base.
        """

        def get_trainable(parent_configs, obj, default=1):
            """
            Look up the effective value of ``obj.trainable``.

            Configs are visited from the top-level network down, so settings
            in lower-level networks override those above them. Returns
            ``default`` if no config sets a value (the default ``1`` marks
            "trainable by default", as opposed to an explicit ``True``), and
            always ``False`` when the model is ``inference_only``.
            """
            if self.inference_only:
                return False

            trainable = default
            for cfg in parent_configs:
                try:
                    cfg_trainable = getattr(cfg[obj], "trainable", None)
                except ConfigError:
                    # object not configured in this network config
                    continue

                if cfg_trainable is not None:
                    trainable = cfg_trainable

            return trainable

        def mark_network(parent_configs, net):
            """Recursively marks the signals for objects within each subnetwork."""

            parent_configs = parent_configs + [net.config]

            for subnet in net.networks:
                mark_network(parent_configs, subnet)

            # encoders and biases are trainable
            for ens in net.ensembles:
                ens_trainable = get_trainable(parent_configs, ens)

                self.model.sig[ens]["encoders"].trainable = ens_trainable
                self.model.sig[ens]["encoders"].minibatched = False

                if not isinstance(ens.neuron_type, Direct):
                    # neurons inherit the ensemble's setting unless a config
                    # sets ``trainable`` for the neurons themselves
                    neurons_trainable = get_trainable(
                        parent_configs, ens.neurons, default=ens_trainable)

                    bias = self.model.sig[ens.neurons]["bias"]
                    bias.trainable = neurons_trainable
                    bias.minibatched = False

            # connection weights are trainable
            for conn in net.connections:
                # note: this doesn't include probe connections, since they
                # aren't added to the network
                if compat.conn_has_weights(conn):
                    weights = self.model.sig[conn]["weights"]
                    weights.trainable = get_trainable(parent_configs, conn)
                    weights.minibatched = False

            # parameters can't be modified by an online Nengo learning rule
            # and offline training at the same time. (it is possible in
            # theory, but it complicates things a lot and is probably not a
            # common use case). we also make those signals minibatched
            # (they wouldn't be normally), because we want to be able to
            # learn independently in each minibatch
            for conn in net.connections:
                rule = conn.learning_rule
                if rule is None:
                    continue

                if isinstance(rule, dict):
                    rule = list(rule.values())
                elif not isinstance(rule, list):
                    rule = [rule]

                for r in rule:
                    if r.modifies in ("weights", "decoders"):
                        obj, attr = conn, "weights"
                    elif r.modifies == "encoders":
                        obj, attr = conn.post_obj, "encoders"
                    else:
                        raise NotImplementedError(
                            "Learning rules that modify %r are not supported"
                            % (r.modifies,))

                    sig = self.model.sig[obj][attr]
                    # only an explicit ``True`` (not the default ``1``) warns
                    if sig.trainable is True:
                        warnings.warn(
                            "%s has a learning rule and is also set "
                            "to be trainable; this is likely to "
                            "produce strange training behaviour." % obj)
                    else:
                        sig.trainable = False

                    sig.minibatched = True

        if self.model.toplevel is None:
            warnings.warn(
                "No top-level network in model; assuming no trainable parameters",
                UserWarning,
            )
        else:
            mark_network([], self.model.toplevel)

            # the connections to connection probes are not trainable, but
            # also not minibatched (a set gives O(1) membership tests)
            probe_seeds = {self.model.seeds[p] for p in self.model.probes}
            for obj, seed in self.model.seeds.items():
                if isinstance(obj, Connection) and seed in probe_seeds:
                    if compat.conn_has_weights(obj):
                        self.model.sig[obj]["weights"].trainable = None
                        self.model.sig[obj]["weights"].minibatched = False

        # time/step are not minibatched and not trainable
        self.model.step.trainable = None
        self.model.step.minibatched = False
        self.model.time.trainable = None
        self.model.time.minibatched = False

        # fill in defaults for all other signals
        # signals are not trainable by default, and views take on the
        # properties of their bases
        for op in self.model.operators:
            for sig in op.all_signals:
                if not hasattr(sig.base, "trainable"):
                    sig.base.trainable = None

                if not hasattr(sig.base, "minibatched"):
                    sig.base.minibatched = not sig.base.trainable

                if not hasattr(sig, "trainable"):
                    sig.trainable = sig.base.trainable

                if not hasattr(sig, "minibatched"):
                    sig.minibatched = sig.base.minibatched