Example #1
    def __init__(self, model, dt, unroll_simulation, dtype, minibatch_size,
                 device, progress):
        self.model = model
        self.dt = dt
        self.unroll = unroll_simulation
        self.dtype = dtype
        self.minibatch_size = minibatch_size
        self.device = device
        self.graph = tf.Graph()
        self.signals = signals.SignalDict(self.dtype, self.minibatch_size)
        self.inference_only = config.get_setting(model, "inference_only",
                                                 False)

        # find invariant inputs (nodes that don't receive any input other
        # than the simulation time). we'll compute these outside the simulation
        # and feed in the result.
        if self.model.toplevel is None:
            self.invariant_inputs = OrderedDict()
        else:
            self.invariant_inputs = OrderedDict(
                (n, n.output) for n in self.model.toplevel.all_nodes if
                n.size_in == 0 and not isinstance(n, tensor_node.TensorNode))
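            # (e.g. ``nengo.Node(np.sin)`` has ``size_in == 0``; its output
            # depends only on the simulation time, so it can be precomputed)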

        # filter unused operators
        # remove TimeUpdate because it is executed as part of the simulation
        # loop, not part of the step plan. remove input nodes because they
        # are executed outside the simulation.
        node_processes = [
            n.output for n in self.invariant_inputs
            if isinstance(n.output, Process)
        ]
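        # (these are the processes of the invariant input nodes; their
        # ``SimProcess`` operators are filtered out below, since the node
        # outputs are computed outside the simulation and fed in)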
        operators = [
            op for op in self.model.operators
            if not (isinstance(op, TimeUpdate) or
                    (isinstance(op, SimPyFunc) and op.x is None) or
                    (isinstance(op, SimProcess) and op.input is None
                     and op.process in node_processes))
        ]

        # mark trainable signals
        self.mark_signals()

        logger.info("Initial plan length: %d", len(operators))

        # apply graph simplification functions
        simplifications = config.get_setting(model, "simplifications", [
            graph_optimizer.remove_constant_copies,
            graph_optimizer.remove_unmodified_resets,
            graph_optimizer.remove_zero_incs,
            graph_optimizer.remove_identity_muls,
        ])

        with progress.sub("operator simplificaton", max_value=None):
            old_operators = []
            while len(old_operators) != len(operators) or any(
                    x is not y for x, y in zip(operators, old_operators)):
                old_operators = operators
                for simp in simplifications:
                    operators = simp(operators)

        # group mergeable operators
        planner = config.get_setting(model, "planner",
                                     graph_optimizer.tree_planner)

        with progress.sub("merging operators", max_value=None):
            plan = planner(operators)

        # TODO: we could also merge operators sequentially (e.g., combine
        # a copy and dotinc into one op), as long as the intermediate signal
        # is only written to by one op and read by one op

        # order signals/operators to promote contiguous reads
        sorter = config.get_setting(model, "sorter",
                                    graph_optimizer.order_signals)

        with progress.sub("ordering signals", max_value=None):
            sigs, self.plan = sorter(plan, n_passes=10)

        # create base arrays and map Signals to TensorSignals (views on those
        # base arrays)
        with progress.sub("creating signals", max_value=None):
            self.create_signals(sigs)

        logger.info("Optimized plan length: %d", len(self.plan))
        logger.info("Number of base arrays: %d", len(self.base_arrays_init))

        # initialize op builder
        build_config = builder.BuildConfig(
            inference_only=self.inference_only,
            lif_smoothing=config.get_setting(self.model, "lif_smoothing"),
            cpu_only=self.device == "/cpu:0" or not utils.tf_gpu_installed,
        )
        self.op_builder = builder.Builder(self.plan, self.graph, self.signals,
                                          build_config)
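
For context, a minimal sketch of how a graph like this is typically driven. The network below is illustrative; in practice the class is constructed internally by ``nengo_dl.Simulator`` rather than instantiated directly:

    import nengo
    import numpy as np

    import nengo_dl

    with nengo.Network() as net:
        stim = nengo.Node(np.sin)  # invariant input (size_in == 0)
        ens = nengo.Ensemble(10, 1)
        nengo.Connection(stim, ens)
        nengo.Probe(ens)

    # the Simulator builds the TensorGraph from the compiled model
    with nengo_dl.Simulator(
        net, dt=0.001, minibatch_size=1, unroll_simulation=1
    ) as sim:
        sim.run_steps(50)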

Example #2

    def call(self, inputs, training=None, progress=None, stateful=False):
        """
        Constructs the graph elements to simulate the model.

        Parameters
        ----------
        inputs : list of ``tf.Tensor``
            Input layers/tensors for the network (must match the structure defined in
            `.build_inputs`).
        training : bool
            Whether the network is being run in training or inference mode.  If None,
            uses the symbolic Keras learning phase variable.
        progress : `.utils.ProgressBar`
            Progress bar for construction stage.
        stateful : bool
            Whether or not to build the model to support preserving the internal state
            between executions.

        Returns
        -------
        probe_arrays : list of ``tf.Tensor``
            Tensors representing the output of all the Probes in the network (order
            corresponding to ``self.model.probes``, which is the order the Probes were
            instantiated).
        """

        super().call(inputs, training=training)

        if training == 1 and self.inference_only:
            raise BuildError(
                "TensorGraph was created with inference_only=True; cannot be built "
                "with training=%s" % training)

        tf.random.set_seed(self.seed)

        if progress is None:
            progress = utils.NullProgressBar()

        # reset signaldict
        self.signals.reset()

        # create these constants once here for reuse in different operators
        self.signals.dt = tf.constant(self.dt, self.dtype)
        self.signals.dt_val = self.dt  # store the actual value as well
        self.signals.zero = tf.constant(0, self.dtype)
        self.signals.one = tf.constant(1, self.dtype)

        # set up invariant inputs
        with trackable.no_automatic_dependency_tracking_scope(self):
            self.node_inputs = {}
        for n, inp in zip(self.invariant_inputs, inputs):
            # specify shape of inputs (keras sometimes loses this shape information)
            inp.set_shape([self.minibatch_size, inp.shape[1], n.size_out])

            self.node_inputs[n] = inp

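        # the final entry of ``inputs`` is the ``n_steps`` tensor; pull out the
        # scalar number of steps to run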
        self.steps_to_run = inputs[-1][0, 0]

        # initialize op builder
        build_config = builder.BuildConfig(
            inference_only=self.inference_only,
            lif_smoothing=config.get_setting(self.model, "lif_smoothing"),
            cpu_only=self.device == "/cpu:0" or not utils.tf_gpu_installed,
            rng=np.random.RandomState(self.seed),
            training=(tf.keras.backend.learning_phase()
                      if training is None else training),
        )
        self.op_builder = builder.Builder(self.plan, self.signals,
                                          build_config)

        # pre-build stage
        with progress.sub("pre-build stage", max_value=len(self.plan)) as sub:
            self.op_builder.build_pre(sub)

        # build stage
        with progress.sub("build stage",
                          max_value=len(self.plan) * self.unroll) as sub:
            steps_run, probe_arrays, final_internal_state, final_base_params = (
                self._build_loop(sub)
                if self.use_loop else self._build_no_loop(sub))
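        # (with ``use_loop=False`` the unrolled steps are built directly into
        # the graph, avoiding ``tf.while_loop`` overhead at the cost of graph
        # size)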

        # store these so that they can be accessed after the initial build
        with trackable.no_automatic_dependency_tracking_scope(self):
            self.steps_run = steps_run
            self.probe_arrays = probe_arrays
            self.final_internal_state = final_internal_state
            self.final_base_params = final_base_params

        # logging
        logger.info("Number of reads: %d",
                    sum(x for x in self.signals.read_types.values()))
        for x in self.signals.read_types.items():
            logger.info("    %s: %d", *x)
        logger.info("Number of writes: %d",
                    sum(x for x in self.signals.write_types.values()))
        for x in self.signals.write_types.items():
            logger.info("    %s: %d", *x)

        # note: always return steps_run so that the simulation will run for the given
        # number of steps, even if there are no output probes
        outputs = list(probe_arrays.values()) + [steps_run]

        updates = []
        if stateful:
            # update saved state
            updates.extend(
                var.assign(val) for var, val in zip(self.saved_state.values(),
                                                    final_internal_state))

        # if any of the base params have changed (due to online learning rules) then we
        # also need to assign those back to the original variable (so that their
        # values will persist). any parameters targeted by online learning rules
        # will be minibatched, so we only need to update the minibatched params.
        for (key, var), val in zip(self.base_params.items(),
                                   final_base_params):
            try:
                minibatched = self.base_arrays_init["non_trainable"][key][-1]
            except KeyError:
                minibatched = self.base_arrays_init["trainable"][key][-1]

            if minibatched:
                updates.append(var.assign(val))

        logger.info("Number of variable updates: %d", len(updates))

        if len(updates) > 0:
            with tf.control_dependencies(updates):
                outputs = [tf.identity(x) for x in outputs]

        return outputs

Example #3

    def __init__(
        self, model, dt, unroll_simulation, minibatch_size, device, progress, seed
    ):
        super().__init__(
            name="TensorGraph",
            dynamic=False,
            trainable=not config.get_setting(model, "inference_only", False),
            dtype=config.get_setting(model, "dtype", "float32"),
            batch_size=minibatch_size,
        )

        self.model = model
        self.dt = dt
        self.unroll = unroll_simulation
        self.use_loop = config.get_setting(model, "use_loop", True)
        self.minibatch_size = minibatch_size
        self.device = device
        self.seed = seed
        self.inference_only = not self.trainable
        self.signals = signals.SignalDict(self.dtype, self.minibatch_size)

        # find invariant inputs (nodes that don't receive any input other
        # than the simulation time). we'll compute these outside the simulation
        # and feed in the result.
        if self.model.toplevel is None:
            self.invariant_inputs = OrderedDict()
        else:
            self.invariant_inputs = OrderedDict(
                (n, n.output)
                for n in self.model.toplevel.all_nodes
                if n.size_in == 0 and not isinstance(n, tensor_node.TensorNode)
            )

        # remove input nodes because they are executed outside the simulation
        node_processes = [
            n.output for n in self.invariant_inputs if isinstance(n.output, Process)
        ]
        operators = [
            op
            for op in self.model.operators
            if not (
                (isinstance(op, SimPyFunc) and op.x is None)
                or (
                    isinstance(op, SimProcess)
                    and op.input is None
                    and op.process in node_processes
                )
            )
        ]

        # mark trainable signals
        self.mark_signals()

        logger.info("Initial plan length: %d", len(operators))

        # apply graph simplification functions
        simplifications = config.get_setting(
            model, "simplifications", graph_optimizer.default_simplifications,
        )

        with progress.sub("operator simplificaton", max_value=None):
            old_operators = []
            while len(old_operators) != len(operators) or any(
                x is not y for x, y in zip(operators, old_operators)
            ):
                old_operators = operators
                for simp in simplifications:
                    operators = simp(operators)

        # group mergeable operators
        planner = config.get_setting(model, "planner", graph_optimizer.tree_planner)

        with progress.sub("merging operators", max_value=None):
            plan = planner(operators)

        # TODO: we could also merge operators sequentially (e.g., combine
        # a copy and dotinc into one op), as long as the intermediate signal
        # is only written to by one op and read by one op

        # order signals/operators to promote contiguous reads
        sorter = config.get_setting(model, "sorter", graph_optimizer.order_signals)

        with progress.sub("ordering signals", max_value=None):
            sigs, self.plan = sorter(plan, n_passes=10)

        # create base arrays and map Signals to TensorSignals (views on those
        # base arrays)
        with progress.sub("creating signals", max_value=None):
            self.create_signals(sigs)

        # generate unique names for layer inputs/outputs
        # this follows the TensorFlow unique naming scheme, so if multiple objects are
        # created with the same name, they will be named like name, name_1, name_2
        # (note: name collisions are detected case-insensitively)
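        # e.g. probes labeled "out", "OUT", "out" become "out", "OUT_1", "out_2"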
        self.io_names = {}
        name_count = defaultdict(int)
        for obj in list(self.invariant_inputs.keys()) + self.model.probes:
            name = (
                type(obj).__name__.lower()
                if obj.label is None
                else utils.sanitize_name(obj.label)
            )

            key = name.lower()

            if name_count[key] > 0:
                name += "_%d" % name_count[key]

            self.io_names[obj] = name
            name_count[key] += 1

        # set up op builder
        self.op_builder = builder.Builder(self.plan)

        # logging
        logger.info("Optimized plan length: %d", len(self.plan))
        logger.info(
            "Number of base arrays: (%s, %d), (%s, %d), (%s, %d)",
            *sum(((k, len(x)) for k, x in self.base_arrays_init.items()), ()),
        )
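
To make the naming scheme in the last example concrete, a small sketch using the standard nengo API (the variable names are illustrative, and the final comment shows what the scheme above would assign):

    import nengo

    with nengo.Network() as net:
        a = nengo.Node([0.0], label="stim")
        b = nengo.Node([0.0], label="stim")  # duplicate label
        p = nengo.Probe(a)                   # unlabeled, falls back to type name

    # under the scheme above: io_names == {a: "stim", b: "stim_1", p: "probe"}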