コード例 #1
0
    def build_step(self, signals):
        """
        Build the ops that advance the generic neuron model by one timestep.

        ``self.neuron_step`` is executed through ``tf.numpy_function`` with
        ``dt``, the input current and the current state values; its first output
        is scattered into ``self.output_data`` and the remaining outputs are
        scattered into the corresponding ``self.state_data`` signals.

        Parameters
        ----------
        signals
            Object providing ``gather``/``scatter``, ``dt`` and
            ``minibatch_size`` (presumably the simulator's SignalDict).
        """
        J = signals.gather(self.J_data)
        states = [signals.gather(x) for x in self.state_data]
        states_dtype = [x.dtype for x in self.state_data]

        if compat.eager_enabled():
            # noop
            control_deps = contextlib.suppress()
        else:
            # we need to make sure that the previous call to this function
            # has completed before the next starts, since we don't know that the
            # functions are thread safe
            control_deps = tf.control_dependencies(self.prev_result)

        with control_deps:
            ret = tf.numpy_function(
                self.neuron_step,
                [signals.dt, J] + states,
                [self.output_data.dtype] + states_dtype,
                name=self.neuron_step.__name__,
            )

        # TensorFlow may squeeze a length-1 output list (i.e. when the neuron
        # has no state) into a single tensor; flatten so that ``ret[0]`` is
        # always the neuron output rather than the first row of that tensor
        ret = tf.nest.flatten(ret)
        neuron_out, state_out = ret[0], ret[1:]

        self.prev_result = [neuron_out]

        neuron_out.set_shape((signals.minibatch_size,) + self.output_data.shape)
        signals.scatter(self.output_data, neuron_out)

        # one output per state signal, in the same order as ``self.state_data``
        for sig, new_val in zip(self.state_data, state_out):
            new_val.set_shape((signals.minibatch_size,) + sig.shape)
            signals.scatter(sig, new_val)
コード例 #2
0
def test_uneven_validation_split(Simulator):
    net, _, _ = dummies.linear_net()

    def blank():
        # fresh (10, 10, 1) array of zeros for each use
        return np.zeros((10, 10, 1))

    with Simulator(net, minibatch_size=2) as sim:
        sim.compile(optimizer=tf.optimizers.SGD(0), loss=tf.losses.mse)

        # 10 items split 8/2 -> both parts divide into minibatches of 2
        sim.fit(blank(), blank(), validation_split=0.2)

        # 10 items split 7/3 -> cannot be divided into minibatches of 2
        with pytest.raises(ValidationError, match="not evenly divisible"):
            sim.fit(blank(), blank(), validation_split=0.3)

        # validation_split combined with a generator produces the regular keras
        # error message (which differs between eager and graph mode)
        if compat.eager_enabled():
            expected_msg = "`validation_split` is only supported for Tensors"
        else:
            expected_msg = "cannot use `validation_split`"

        with pytest.raises(ValueError, match=expected_msg):
            pairs = ((x, y) for x, y in zip(blank(), blank()))
            sim.fit(pairs, validation_split=0.7)
コード例 #3
0
def test_sequential(seed):
    tf.random.set_seed(seed)

    # two stacked dense layers; the first fixes the input shape
    model = tf.keras.Sequential()
    for layer in (
        tf.keras.layers.Dense(32, input_shape=(4,)),
        tf.keras.layers.Dense(32),
    ):
        model.add(layer)

    conv = converter.Converter(model, allow_fallback=False)
    assert conv.verify(training=False)

    # TODO: not sure why this is slightly less accurate in graph mode
    training_atol = 1e-8 if compat.eager_enabled() else 1e-7
    assert conv.verify(training=True, atol=training_atol)
コード例 #4
0
ファイル: callbacks.py プロジェクト: Sreerag-ibtl/nengo-dl
    def on_epoch_end(self, epoch, logs=None):
        """
        Write a histogram summary of every tracked parameter after each epoch.

        Parameter values are fetched in one call via ``self.sim.data.get_params``
        and written with ``self.writer``, tagged with ``epoch`` as the step.
        """

        requested = [(obj, attr) for _, obj, attr in self.summaries]
        summary_vals = self.sim.data.get_params(*requested)

        # the writer is used in eager mode (entered temporarily if we are not
        # already executing eagerly)
        if compat.eager_enabled():
            eager_ctx = contextlib.suppress()
        else:
            eager_ctx = context.eager_mode()

        tags = [tag for tag, _, _ in self.summaries]
        with eager_ctx, self.writer.as_default():
            for tag, val in zip(tags, summary_vals):
                tf.summary.histogram(tag, val, step=epoch)
コード例 #5
0
def test_multi_input_warning(Simulator):
    """
    Check the warnings emitted when the number of provided input/target arrays
    does not match the number of Nodes/Probes in the network.
    """
    import warnings

    with nengo.Network() as net:
        inp0 = nengo.Node([0])
        inp1 = nengo.Node([1])
        nengo.Probe(inp0)
        nengo.Probe(inp1)

    def node_mismatch():
        return pytest.warns(UserWarning, match="does not match number of Nodes")

    def probe_mismatch():
        # in eager mode we expect our own warning; in graph mode the mismatch
        # surfaces as a keras ValueError instead
        if compat.eager_enabled():
            return pytest.warns(UserWarning, match="does not match number of Probes")
        return pytest.raises(ValueError, match="No data provided for")

    def assert_no_node_warning(func, *args):
        # ``pytest.warns(None)`` is deprecated (pytest 7) and removed (pytest 8),
        # so record warnings with the standard library instead
        with warnings.catch_warnings(record=True) as recwarn:
            warnings.simplefilter("always")
            func(*args)
        assert not any(
            "does not match number of Nodes" in str(w.message) for w in recwarn
        )

    with Simulator(net) as sim:
        with node_mismatch():
            sim.predict(np.zeros((1, 10, 1)))

        with node_mismatch():
            sim.predict([np.zeros((1, 10, 1))])

        assert_no_node_warning(sim.predict, [np.zeros((1, 10, 1))] * 2)
        assert_no_node_warning(sim.predict, {inp1: np.zeros((1, 10, 1))})

        sim.compile(optimizer=tf.optimizers.SGD(0), loss=tf.losses.mse)
        with node_mismatch():
            sim.fit(np.zeros((1, 10, 1)), [np.zeros((1, 10, 1))] * 2)

        with node_mismatch():
            sim.evaluate(np.zeros((1, 10, 1)), [np.zeros((1, 10, 1))] * 2)

        # check both fit and evaluate for a Probe-count mismatch (mirrors the
        # Node-count checks above)
        with probe_mismatch():
            sim.fit([np.zeros((1, 10, 1))] * 2, np.zeros((1, 10, 1)))

        with probe_mismatch():
            sim.evaluate([np.zeros((1, 10, 1))] * 2, np.zeros((1, 10, 1)))
コード例 #6
0
    def build_step(self, signals):
        """
        Build the ops that run the merged step function for one timestep.

        ``self.merged_func`` is called through ``tf.numpy_function`` with the
        time, the (optional) input and the current state values.  Its first output
        is scattered into ``self.output_data`` (using ``self.mode``) and the
        remaining outputs update the corresponding ``self.state_data`` signals.
        """
        # arguments are gathered in order: time, optional input, then state
        func_args = [signals.gather(self.time_data)]
        if self.input_data is not None:
            func_args.append(signals.gather(self.input_data))
        func_args.extend(signals.gather(sig) for sig in self.state_data)

        out_dtypes = [self.output_data.dtype] + [sig.dtype for sig in self.state_data]

        if not compat.eager_enabled():
            # the step function may not be thread safe, so make sure the
            # previous call has completed before this one starts
            control_deps = tf.control_dependencies(self.prev_result)
        else:
            # eager execution: nothing to do
            control_deps = contextlib.suppress()

        with control_deps:
            raw = tf.numpy_function(
                self.merged_func,
                func_args,
                out_dtypes,
                name=self.merged_func.__name__,
            )

        # TensorFlow will automatically squeeze length-1 outputs (if there is
        # no state), so normalize back to a flat list
        output, *new_state = tf.nest.flatten(raw)

        self.prev_result = [output]

        output.set_shape(self.output_data.full_shape)
        signals.scatter(self.output_data, output, mode=self.mode)
        for sig, val in zip(self.state_data, new_state):
            val.set_shape(sig.full_shape)
            signals.scatter(sig, val, mode="update")
コード例 #7
0
ファイル: callbacks.py プロジェクト: Sreerag-ibtl/nengo-dl
    def __init__(self, log_dir, sim, objects):
        """
        Create the summary writer and resolve which parameter to log per object.

        Ensembles log ``encoders``, Neurons log ``bias`` and Connections log
        ``weights``; any other object, or a Connection without weights, raises a
        ``ValidationError``.
        """
        super().__init__()

        self.sim = sim

        # the file writer is created in eager mode (entered if necessary)
        if compat.eager_enabled():
            eager_ctx = contextlib.suppress()
        else:
            eager_ctx = context.eager_mode()
        with eager_ctx:
            self.writer = tf.summary.create_file_writer(log_dir)

        self.summaries = []
        for obj in objects:
            if isinstance(obj, nengo.Ensemble):
                param = "encoders"
                name = "Ensemble_%s" % obj.label
            elif isinstance(obj, nengo.ensemble.Neurons):
                param = "bias"
                name = "Ensemble.neurons_%s" % obj.ensemble.label
            elif isinstance(obj, nengo.Connection):
                if not compat.conn_has_weights(obj):
                    raise ValidationError(
                        "Connection '%s' does not have any weights to log" % obj,
                        "objects",
                    )
                param = "weights"
                name = "Connection_%s" % obj.label
            else:
                raise ValidationError(
                    "Unknown summary object %s; should be an Ensemble, Neurons, or "
                    "Connection" % obj,
                    "objects",
                )

            tag = utils.sanitize_name("%s_%s" % (name, param))
            self.summaries.append((tag, obj, param))
コード例 #8
0
ファイル: tensor_node.py プロジェクト: Sreerag-ibtl/nengo-dl
    def coerce(self, node, func):
        """
        Performs validation on the function passed to TensorNode, and sets
        ``shape_out`` if necessary.

        When ``node.shape_out`` is None the output shape is inferred: for a Keras
        Layer via ``Layer.compute_output_signature`` (so the layer does not need to
        be built/called); otherwise by calling ``func`` with dummy arguments (a
        scalar time value if ``node.pass_time``, and a zero tensor of shape
        ``(1,) + node.shape_in`` if ``node.shape_in`` is set).  ``node.shape_out``
        is then set to the inferred shape without its leading dimension.

        Parameters
        ----------
        node : `.TensorNode`
            The node whose ``tensor_func`` parameter is being set.
        func : callable
            The function being assigned to the TensorNode.

        Returns
        -------
        output : callable
            The function after validation is applied.

        Raises
        ------
        ValidationError
            If ``func`` is not callable, or if automatic output shape inference
            fails (the underlying error is chained as ``__cause__``).
        """

        output = super().coerce(node, func)

        if not callable(func):
            raise ValidationError(
                "TensorNode output must be a function or Keras Layer",
                attr=self.name,
                obj=node,
            )

        if node.shape_out is None:
            if isinstance(func, tf.keras.layers.Layer):
                # we can use Keras' static shape inference to get the
                # output shape, which avoids having to build/call the layer
                if node.pass_time:
                    input_spec = [tf.TensorSpec(())]
                else:
                    input_spec = []
                if node.shape_in is not None:
                    input_spec += [tf.TensorSpec((1,) + node.shape_in)]

                # a single input is passed as a bare spec rather than a list
                if len(input_spec) == 1:
                    input_spec = input_spec[0]

                ctx = contextlib.suppress() if eager_enabled() else context.eager_mode()

                try:
                    with ctx:
                        result = func.compute_output_signature(input_spec)
                except Exception as e:
                    raise ValidationError(
                        "Attempting to automatically determine TensorNode output shape "
                        "by calling Layer.compute_output_signature produced an error. "
                        "If you would like to avoid this step, try manually setting "
                        "`TensorNode(..., shape_out=x)`. The error is shown below:\n%s"
                        % repr(e),
                        attr=self.name,
                        obj=node,
                    ) from e

            else:
                # call the function directly on dummy inputs
                if node.pass_time:
                    args = (tf.constant(0.0),)
                else:
                    args = ()
                if node.shape_in is not None:
                    args += (tf.zeros((1,) + node.shape_in),)

                try:
                    result = func(*args)
                except Exception as e:
                    raise ValidationError(
                        "Attempting to automatically determine TensorNode output shape "
                        "by calling TensorNode function produced an error. "
                        "If you would like to avoid this step, try manually setting "
                        "`TensorNode(..., shape_out=x)`. The error is shown below:\n%s"
                        % e,
                        attr=self.name,
                        obj=node,
                    ) from e

            validate_output(result)

            # drop the leading (batch) dimension of size 1 added above
            node.shape_out = result.shape[1:]

        return output
コード例 #9
0
def test_predict(Simulator, seed):
    n_steps = 100

    with nengo.Network(seed=seed) as net:
        node = nengo.Node([2], label="a")
        ens = nengo.Ensemble(10, 1)
        nengo.Connection(node, ens)
        probe = nengo.Probe(ens)

    with Simulator(net, minibatch_size=4) as sim:
        batch = sim.minibatch_size
        inputs = np.ones((12, n_steps, 1))
        n_batches = inputs.shape[0] // batch

        # reference output using the node's default value
        sim.run_steps(n_steps)
        expected_default = sim.data[probe]
        sim.reset(include_trainable=False, include_processes=False)

        # reference output for one batch of explicit input, tiled to cover
        # every batch in ``inputs``
        sim.run_steps(n_steps, data={node: inputs[:4]})
        expected_tiled = np.tile(sim.data[probe], (n_batches, 1, 1))
        sim.reset(
            include_probes=False, include_trainable=False, include_processes=False
        )

        def assert_tiled(result):
            assert np.allclose(result[probe], expected_tiled)

        # no input (batch_size should be ignored, with a warning)
        with pytest.warns(UserWarning, match="Batch size is determined statically"):
            result = sim.predict(n_steps=n_steps, batch_size=-1)
        assert np.allclose(result[probe], expected_default)

        # numpy input, single batch
        result = sim.predict_on_batch(inputs[:4])
        assert np.allclose(result[probe], sim.data[probe])

        # numpy input, multiple batches
        assert_tiled(sim.predict(inputs))

        # tensor input (only checked in eager mode)
        if compat.eager_enabled():
            assert_tiled(sim.predict(tf.constant(inputs)))

        # dict input, keyed by object and by name
        for key in (node, "a"):
            assert_tiled(sim.predict({key: inputs}))

        # generator input
        batches = (
            [
                inputs[i * batch:(i + 1) * batch],
                np.ones((batch, 1), dtype=np.int32) * n_steps,
            ]
            for i in range(n_batches)
        )
        assert_tiled(sim.predict(batches, steps=n_batches))

        # tf.data.Dataset input
        dataset = tf.data.Dataset.from_tensor_slices(
            {
                "a": tf.constant(inputs),
                "n_steps": tf.ones((12, 1), dtype=np.int32) * n_steps,
            }
        ).batch(batch)
        assert_tiled(sim.predict(dataset))
コード例 #10
0
def test_fit(Simulator, seed):
    minibatch_size = 4
    n_hidden = 20

    with nengo.Network(seed=seed) as net:
        net.config[nengo.Ensemble].gain = nengo.dists.Choice([1])
        net.config[nengo.Ensemble].bias = nengo.dists.Choice([0])
        net.config[nengo.Connection].synapse = None

        # two separate input nodes feed a single 2D node, so that training
        # with two distinct inputs gets tested
        inp_a = nengo.Node([0])
        inp_b = nengo.Node([0])
        inp = nengo.Node(size_in=2)
        nengo.Connection(inp_a, inp[0], transform=1)
        nengo.Connection(inp_b, inp[1], transform=1)

        hidden = nengo.Ensemble(
            n_hidden + 1, n_hidden, neuron_type=nengo.Sigmoid(tau_ref=1)
        )
        readout = nengo.Ensemble(1, 1, neuron_type=nengo.Sigmoid(tau_ref=1))
        nengo.Connection(inp, hidden.neurons, transform=dists.Glorot())
        nengo.Connection(hidden.neurons, readout.neurons, transform=dists.Glorot())

        nengo.Probe(readout.neurons)

    with Simulator(
        net, minibatch_size=minibatch_size, unroll_simulation=1, seed=seed
    ) as sim:
        x = np.asarray([[[0.0, 0.0]], [[0.0, 1.0]], [[1.0, 0.0]], [[1.0, 1.0]]])
        y = np.asarray([[[0.1]], [[0.9]], [[0.9]], [[0.1]]])

        def split_x():
            # one input array per input node
            return [x[..., [0]], x[..., [1]]]

        def full_inputs():
            # both node inputs plus the n_steps input
            return (x[..., [0]], x[..., [1]], np.ones((4, 1), dtype=np.int32))

        def final(hist, key):
            return hist.history[key][-1]

        sim.compile(optimizer=tf.optimizers.Adam(0.01), loss=tf.losses.mse)

        # note: batch_size should be ignored (with a warning)
        with pytest.warns(UserWarning, match="Batch size is determined statically"):
            history = sim.fit(
                split_x(),
                y,
                validation_data=(split_x(), y),
                epochs=200,
                verbose=0,
                batch_size=-1,
            )
        assert final(history, "loss") < 5e-4
        assert final(history, "val_loss") < 5e-4

        # zero validation sample weights should give a zero validation loss
        history = sim.fit(
            split_x(),
            y,
            validation_data=(split_x(), y, np.zeros(y.shape[0])),
            epochs=1,
            verbose=0,
        )
        assert np.allclose(final(history, "val_loss"), 0)

        if compat.eager_enabled():
            sim.reset()
            history = sim.fit(
                [tf.constant(x[..., [0]]), tf.constant(x[..., [1]])],
                tf.constant(y),
                epochs=200,
                verbose=0,
            )
            assert final(history, "loss") < 5e-4

        # generator input
        sim.reset()
        history = sim.fit(
            ((full_inputs(), y) for _ in range(200)),
            epochs=20,
            steps_per_epoch=10,
            verbose=0,
        )
        assert final(history, "loss") < 5e-4

        # tf.data.Dataset input
        train_data = tf.data.Dataset.from_tensors((full_inputs(), y))
        val_data = tf.data.Dataset.from_tensors((full_inputs(), y))
        history = sim.fit(
            train_data,
            validation_data=val_data,
            epochs=200,
            verbose=0,
        )
        assert final(history, "loss") < 5e-4
        assert final(history, "val_loss") < 5e-4
コード例 #11
0
def test_evaluate(Simulator):
    minibatch_size = 3
    n_steps = 10
    n_batches = 2

    with nengo.Network() as net:
        inp0 = nengo.Node([0])
        inp1 = nengo.Node([0])
        p0 = nengo.Probe(inp0)
        p1 = nengo.Probe(inp1)

    def filled(value):
        # (minibatch_size, n_steps, 1) float array with every entry == value
        return np.ones((minibatch_size, n_steps, 1)) * value

    def assert_losses(result, total, probe, probe_1):
        assert np.allclose(result["loss"], total)
        assert np.allclose(result["probe_loss"], probe)
        assert np.allclose(result["probe_1_loss"], probe_1)

    with Simulator(net, minibatch_size=minibatch_size) as sim:
        # single probe
        sim.compile(loss={"probe": tf.losses.mse})
        targets = filled(1)
        with pytest.warns(UserWarning, match="Batch size is determined statically"):
            loss = sim.evaluate(n_steps=n_steps, y=targets, batch_size=-1)
        assert np.allclose(loss["loss"], 1)
        assert np.allclose(loss["probe_loss"], 1)
        assert "probe_1_loss" not in loss

        # multiple probes
        sim.compile(loss=tf.losses.mse)
        loss = sim.evaluate(n_steps=n_steps, y={p0: targets, p1: targets})
        assert_losses(loss, 2, 1, 1)

        # default inputs
        loss = sim.evaluate(y={p0: filled(0), p1: filled(0)}, n_steps=n_steps)
        assert_losses(loss, 0, 0, 0)

        # list inputs
        inputs = np.ones((minibatch_size * n_batches, n_steps, 1))
        targets = inputs.copy()
        loss = sim.evaluate(x=[inputs, inputs * 2], y={p0: targets, p1: targets})
        assert_losses(loss, 1, 0, 1)

        # tensor inputs
        if compat.eager_enabled():
            loss = sim.evaluate(
                x=[tf.constant(inputs), tf.constant(inputs * 2)],
                y={p0: tf.constant(targets), p1: tf.constant(targets)},
            )
            assert_losses(loss, 1, 0, 1)

        # generator input, keyed by name
        batches = (
            (
                {
                    "node": filled(1),
                    "node_1": filled(2),
                    "n_steps": np.ones((minibatch_size, 1)) * n_steps,
                },
                {"probe": filled(1), "probe_1": filled(1)},
            )
            for _ in range(n_batches)
        )
        assert_losses(sim.evaluate(batches, steps=n_batches), 1, 0, 1)

        # check custom objective
        def constant_error(y_true, y_pred):
            return tf.constant(3.0)

        sim.compile(loss={p0: constant_error})
        result = sim.evaluate(y={p0: filled(0)}, n_steps=n_steps)
        assert np.allclose(result["loss"], 3)

        # test metrics
        sim.compile(
            loss=tf.losses.mse,
            metrics={p0: constant_error, p1: [constant_error, "mae"]},
        )
        output = sim.evaluate(y={p0: filled(1), p1: filled(2)}, n_steps=n_steps)
        assert_losses(output, 5, 1, 4)
        assert np.allclose(output["probe_constant_error"], 3)
        assert np.allclose(output["probe_1_constant_error"], 3)
        assert "probe_mae" not in output
        assert np.allclose(output["probe_1_mae"], 2)
コード例 #12
0
ファイル: callbacks.py プロジェクト: Sreerag-ibtl/nengo-dl
    def on_train_end(self, logs=None):
        """Close the summary writer once training has finished."""

        # the writer is closed in eager mode, entering it only when needed
        if compat.eager_enabled():
            self.writer.close()
            return

        with context.eager_mode():
            self.writer.close()
コード例 #13
0
ファイル: tensor_graph.py プロジェクト: Sreerag-ibtl/nengo-dl
    def call(self, inputs, training=None, progress=None, stateful=False):
        """
        Constructs the graph elements to simulate the model.

        Parameters
        ----------
        inputs : list of ``tf.Tensor``
            Input layers/tensors for the network (must match the structure defined in
            `.build_inputs`).  The last element is read as the number of steps to
            run.
        training : bool
            Whether the network is being run in training or inference mode.  If None,
            uses the symbolic Keras learning phase variable.  Overridden by the
            model's ``learning_phase`` config setting when that setting is not None.
        progress : `.utils.ProgressBar`
            Progress bar for construction stage.
        stateful : bool
            Whether or not to build the model to support preserving the internal state
            between executions.

        Returns
        -------
        probe_arrays : list of ``tf.Tensor``
            Tensors representing the output of all the Probes in the network (order
            corresponding to ``self.model.probes``, which is the order the Probes were
            instantiated), followed by the ``steps_run`` tensor as the final element.

        Raises
        ------
        BuildError
            If ``training`` (after any config override) is True but this graph was
            created with ``inference_only=True``.
        """

        # a "learning_phase" config setting on the model takes precedence over the
        # ``training`` argument
        override_training = config.get_setting(self.model, "learning_phase", None)
        training = training if override_training is None else override_training

        super().call(inputs, training=training)

        if training is True and self.inference_only:
            raise BuildError(
                "TensorGraph was created with inference_only=True; cannot be called "
                "with training=%s" % training
            )

        tf.random.set_seed(self.seed)

        if progress is None:
            progress = utils.NullProgressBar()

        # reset signaldict
        self.signals.reset()

        # create these constants once here for reuse in different operators
        self.signals.dt = tf.constant(self.dt, self.dtype)
        self.signals.dt_val = self.dt  # store the actual value as well
        self.signals.zero = tf.constant(0, self.dtype)
        self.signals.one = tf.constant(1, self.dtype)

        # set up invariant inputs
        # (the dict is created outside Keras' automatic dependency tracking)
        with trackable.no_automatic_dependency_tracking_scope(self):
            self.node_inputs = {}
        for n, inp in zip(self.invariant_inputs, inputs):
            # specify shape of inputs (keras sometimes loses this shape information)
            inp.set_shape([self.minibatch_size, inp.shape[1], n.size_out])

            self.node_inputs[n] = inp

        # the final input holds the number of steps to run; the value is taken
        # from its first element
        self.steps_to_run = inputs[-1][0, 0]

        # set up build config
        # TODO: it would be nicer if buildconfig was static (i.e. find a separate
        #  way to pass around `training`)
        build_config = builder.BuildConfig(
            inference_only=self.inference_only,
            lif_smoothing=config.get_setting(self.model, "lif_smoothing"),
            cpu_only=self.device == "/cpu:0" or not utils.tf_gpu_installed,
            rng=np.random.RandomState(self.seed),
            training=(
                tf.keras.backend.learning_phase() if training is None else training
            ),
        )

        # pre-build stage
        with progress.sub("pre-build stage", max_value=len(self.plan)) as sub:
            self.op_builder.build_pre(self.signals, build_config, sub)

        # build stage
        # (uses a symbolic loop if ``self.use_loop``, otherwise an unrolled build)
        with progress.sub("build stage", max_value=len(self.plan) * self.unroll) as sub:
            steps_run, probe_arrays, final_internal_state, final_base_params = (
                self._build_loop(sub) if self.use_loop else self._build_no_loop(sub)
            )

        # store these so that they can be accessed after the initial build
        with trackable.no_automatic_dependency_tracking_scope(self):
            self.steps_run = steps_run
            self.probe_arrays = probe_arrays
            self.final_internal_state = final_internal_state
            self.final_base_params = final_base_params

        # logging
        logger.info(
            "Number of reads: %d", sum(x for x in self.signals.read_types.values())
        )
        for x in self.signals.read_types.items():
            logger.info("    %s: %d", *x)
        logger.info(
            "Number of writes: %d", sum(x for x in self.signals.write_types.values())
        )
        for x in self.signals.write_types.items():
            logger.info("    %s: %d", *x)

        # note: always return steps_run so that the simulation will run for the given
        # number of steps, even if there are no output probes
        outputs = list(probe_arrays.values()) + [steps_run]

        updates = []
        if stateful:
            # update saved state
            for var, val in zip(self.saved_state.values(), final_internal_state):
                updates.append(var.assign(val))

        # if any of the base params have changed (due to online learning rules) then we
        # also need to assign those back to the original variable (so that their
        # values will persist). any parameters targeted by online learning rules
        # will be minibatched, so we only need to update the minibatched params.
        for (key, var), val in zip(self.base_params.items(), final_base_params):
            # the last entry of the init tuple is the "minibatched" flag; keys live
            # in either the non_trainable or the trainable table
            try:
                minibatched = self.base_arrays_init["non_trainable"][key][-1]
            except KeyError:
                minibatched = self.base_arrays_init["trainable"][key][-1]

            if minibatched:
                updates.append(var.assign(val))

        logger.info("Number of state updates: %d", len(updates))

        # in graph mode, make the outputs depend on the assign ops so that the
        # updates are executed whenever the outputs are computed
        if not compat.eager_enabled() and len(updates) > 0:
            with tf.control_dependencies(updates):
                outputs = [tf.identity(x) for x in outputs]

        return outputs
コード例 #14
0
ファイル: tensor_graph.py プロジェクト: Sreerag-ibtl/nengo-dl
    def build(self, input_shape=None):
        """
        Create any Variables used in the model.

        This creates the base parameter variables (trainable and non-trainable),
        the saved-state variables, and (re)builds any Keras Layers used as
        TensorNode functions inside this layer's context.  For layers that were
        already built, their existing weight values are copied into the rebuilt
        weights.

        Parameters
        ----------
        input_shape : list of tuple of int
            Shapes of all the inputs to this layer.
        """

        super().build(input_shape)

        tf.random.set_seed(self.seed)

        def get_initializer(init_vals):
            """
            Use more efficient initializers if possible to save memory.

            ``init_vals`` is unpacked as ``(values, shapes, dtype, minibatched)``.
            Returns ``(initializer, shape, dtype)``, where ``initializer`` is None
            if every value is None, and ``shape`` is the shape of all the values
            concatenated along axis 1 (if ``minibatched``) or axis 0.
            """

            values, shapes, dtype, minibatched = init_vals

            # initial value of None means that the initial value isn't used, so we
            # can use anything for the initial value
            if all(v is None for v in values):
                initializer = None
            elif all(v is None or np.all(v == 0) for v in values):
                initializer = tf.initializers.zeros()
            elif all(v is None or np.all(v == 1) for v in values):
                initializer = tf.initializers.ones()
            else:
                val = tf.concat(
                    [
                        tf.zeros(s, dtype)
                        if v is None
                        else tf.cast(tf.broadcast_to(v, s), dtype)
                        for v, s in zip(values, shapes)
                    ],
                    axis=1 if minibatched else 0,
                )
                # ignore the requested shape/dtype and return the precomputed value
                initializer = lambda shape=None, dtype=None: val

            # figure out shape of full concatenated initial value
            # (``minibatched`` is used as a 0/1 index here: axis 1 if True, else 0)
            shape = list(shapes[0])
            shape[minibatched] = sum(x[minibatched] for x in shapes)

            return initializer, tuple(shape), dtype

        # save initializers so that we can reset the model later
        with trackable.no_automatic_dependency_tracking_scope(self):
            self.initial_values = {}

        # variables for model parameters
        with trackable.no_automatic_dependency_tracking_scope(self):
            self.base_params = OrderedDict()
        assert len(self.base_params) == 0
        for sig_type in ("trainable", "non_trainable"):
            for k, v in self.base_arrays_init[sig_type].items():
                initializer, shape, dtype = get_initializer(v)
                assert initializer is not None  # params should never be set
                self.base_params[k] = self.add_weight(
                    initializer=initializer,
                    shape=shape,
                    dtype=dtype,
                    trainable=sig_type == "trainable",
                    name="base_params/%s_%s_%s"
                    % (sig_type, dtype, "_".join(str(x) for x in shape)),
                )

                self.initial_values[k] = initializer

        logger.debug("created base param variables")
        logger.debug([str(x) for x in self.base_params.values()])

        # variables to save the internal state of simulation between runs
        # (created with tf.Variable directly, i.e. not through add_weight)
        with trackable.no_automatic_dependency_tracking_scope(self):
            self.saved_state = OrderedDict()
        for k, v in self.base_arrays_init["state"].items():
            initializer, shape, dtype = get_initializer(v)
            if initializer is not None:
                # don't need to save the state for signals where the initial value
                # doesn't matter
                # NOTE(review): this lambda closes over the loop variables
                # (initializer/shape/dtype); it relies on tf.Variable invoking the
                # callable during construction — confirm if this ever changes
                self.saved_state[k] = tf.Variable(
                    initial_value=lambda: initializer(shape=shape, dtype=dtype),
                    shape=shape,
                    dtype=dtype,
                    trainable=False,
                    name="saved_state/%s_%s" % (dtype, "_".join(str(x) for x in shape)),
                )

                self.initial_values[k] = initializer

        logger.debug("created saved state variables")
        logger.debug([str(x) for x in self.saved_state.values()])

        # call build on any TensorNode Layers

        def unbuild(layer):
            """Mark ``layer`` (and, recursively, its sublayers) as not built."""
            assert layer.built

            # clear any losses attached to layer (they will be recreated in the
            # build step, so we don't want to keep around any losses
            # associated with the previous build)
            # note: not clearing layer._losses, because those are manually added
            # by the user (not created during the build process)
            layer._eager_losses = []
            layer._callable_losses = []

            layer.built = False

            for sub in layer._layers:
                if isinstance(sub, tf.keras.layers.Layer):
                    unbuild(sub)

        # every op in a SimTensorNode group whose function is a Keras Layer
        layer_ops = [
            op
            for ops in self.plan
            if isinstance(ops[0], tensor_node.SimTensorNode)
            for op in ops
            if isinstance(op.func, tf.keras.layers.Layer)
        ]
        # previous-build weights to read from / rebuilt weights to write into
        weight_gets = []
        weight_sets = []
        for op in layer_ops:
            if op.func in self._layers:
                # already built this layer
                continue

            # build shape: () for the time input (if present), then the
            # minibatched input shape (if present); a single entry is unwrapped
            if op.time is None:
                shape_in = []
            else:
                shape_in = [()]
            if op.input is not None:
                shape_in += [(self.minibatch_size,) + op.shape_in]
            if len(shape_in) == 1:
                shape_in = shape_in[0]

            if op.func.built:
                # we rebuild the layer (even if it is already built),
                # because we need to build the weights within the TensorGraph
                # context

                # save the weight values so they can be restored
                # exactly inside the tensornode
                weights = op.func.weights
                weight_gets.extend(weights)

                # clear the results of previous build
                unbuild(op.func)
            else:
                weights = None

            with tf.name_scope(op.func.name):
                op.func.build(shape_in)

            if weights is not None:
                weight_sets.extend(op.func.weights)

            # add op func to _layers so that any weights are collected
            self._layers.append(op.func)

        if len(weight_gets) > 0:
            # do all the weight getting/setting in one go, for efficiency reasons

            # match the fetch context to the context in which the weights were created
            ctx = (
                weight_gets[0].graph.as_default()
                if hasattr(weight_gets[0], "graph")
                else context.eager_mode()
            )
            with ctx:
                weight_vals = tf.keras.backend.batch_get_value(weight_gets)

            tf.keras.backend.batch_set_value(zip(weight_sets, weight_vals))

        if not compat.eager_enabled():
            # initialize state variables (need to do this manually because we're not
            # adding them to self.weights)
            tf.keras.backend.batch_get_value(
                [var.initializer for var in self.saved_state.values()]
            )