def build_step(self, signals):
    """
    Build the step that runs ``self.neuron_step`` through ``tf.numpy_function``.

    The input current (``self.J_data``) and all state signals are gathered and
    passed, together with ``signals.dt``, to the Python neuron function. The first
    returned value is scattered into ``self.output_data`` and the remaining values
    are scattered back into the corresponding state signals.

    Parameters
    ----------
    signals : `.signals.SignalDict`
        Mapping from Signals to Tensors (updated by the operations).
    """
    J = signals.gather(self.J_data)
    states = [signals.gather(x) for x in self.state_data]
    states_dtype = [x.dtype for x in self.state_data]

    if compat.eager_enabled():
        # noop (an argument-less suppress() does not suppress anything)
        control_deps = contextlib.suppress()
    else:
        # we need to make sure that the previous call to this function
        # has completed before the next starts, since we don't know that the
        # functions are thread safe
        control_deps = tf.control_dependencies(self.prev_result)
    with control_deps:
        ret = tf.numpy_function(
            self.neuron_step,
            [signals.dt, J] + states,
            [self.output_data.dtype] + states_dtype,
            name=self.neuron_step.__name__,
        )

    # TensorFlow can squeeze a length-1 output list (i.e. a neuron type with no
    # state) down to a single tensor; flatten so that ret[0] is always the
    # neuron output rather than the first row of it
    ret = tf.nest.flatten(ret)
    neuron_out, state_out = ret[0], ret[1:]
    self.prev_result = [neuron_out]

    neuron_out.set_shape((signals.minibatch_size,) + self.output_data.shape)
    signals.scatter(self.output_data, neuron_out)

    for s, out in zip(self.state_data, state_out):
        out.set_shape((signals.minibatch_size,) + s.shape)
        signals.scatter(s, out)
def test_uneven_validation_split(Simulator):
    """Check which ``validation_split`` values ``sim.fit`` accepts or rejects."""
    net, _, _ = dummies.linear_net()

    def zeros_data():
        # a fresh all-zero array for each argument, as in the original calls
        return np.zeros((10, 10, 1))

    with Simulator(net, minibatch_size=2) as sim:
        sim.compile(optimizer=tf.optimizers.SGD(0), loss=tf.losses.mse)

        # 0.2 of the 10 items presumably gives 2 validation items, compatible
        # with minibatch_size=2
        sim.fit(zeros_data(), zeros_data(), validation_split=0.2)

        # 0.3 of the 10 items (presumably 3) cannot be split into minibatches of 2
        with pytest.raises(ValidationError, match="not evenly divisible"):
            sim.fit(zeros_data(), zeros_data(), validation_split=0.3)

        # regular keras error message when trying to use validation_split with
        # a generator (the wording differs between eager and graph mode)
        if compat.eager_enabled():
            expected_msg = "`validation_split` is only supported for Tensors"
        else:
            expected_msg = "cannot use `validation_split`"
        pairs = zip(zeros_data(), zeros_data())
        with pytest.raises(ValueError, match=expected_msg):
            sim.fit(((x, y) for x, y in pairs), validation_split=0.7)
def test_sequential(seed):
    """Convert a two-layer Dense ``Sequential`` model and verify the result."""
    tf.random.set_seed(seed)

    model = tf.keras.Sequential(
        [tf.keras.layers.Dense(32, input_shape=(4,)), tf.keras.layers.Dense(32)]
    )

    converted = converter.Converter(model, allow_fallback=False)
    assert converted.verify(training=False)

    # TODO: not sure why this is slightly less accurate in graph mode
    tolerance = 1e-8 if compat.eager_enabled() else 1e-7
    assert converted.verify(training=True, atol=tolerance)
def on_epoch_end(self, epoch, logs=None):
    """Write a histogram of every tracked parameter, tagged with the epoch index."""
    requested = [(obj, attr) for _, obj, attr in self.summaries]
    values = self.sim.data.get_params(*requested)

    # summaries are written under eager execution; when the default mode is
    # graph mode, eager mode is entered explicitly
    eager_ctx = (
        contextlib.suppress() if compat.eager_enabled() else context.eager_mode()
    )
    with eager_ctx, self.writer.as_default():
        for (summary_name, _, _), value in zip(self.summaries, values):
            tf.summary.histogram(summary_name, value, step=epoch)
def test_multi_input_warning(Simulator):
    """
    Check the warnings/errors produced when the number of inputs or targets
    passed to the Simulator does not match the number of Nodes or Probes.
    """
    import warnings

    def probe_mismatch():
        # in eager mode a targets/Probes mismatch only warns, whereas in graph
        # mode Keras raises an error instead
        if compat.eager_enabled():
            return pytest.warns(UserWarning, match="does not match number of Probes")
        return pytest.raises(ValueError, match="No data provided for")

    def has_node_mismatch(records):
        return any(
            "does not match number of Nodes" in str(w.message) for w in records
        )

    with nengo.Network() as net:
        inp0 = nengo.Node([0])
        inp1 = nengo.Node([1])
        nengo.Probe(inp0)
        nengo.Probe(inp1)

    with Simulator(net) as sim:
        with pytest.warns(UserWarning, match="does not match number of Nodes"):
            sim.predict(np.zeros((1, 10, 1)))
        with pytest.warns(UserWarning, match="does not match number of Nodes"):
            sim.predict([np.zeros((1, 10, 1))])

        # a matching number of inputs (or a dict naming specific Nodes) must not
        # warn; record every warning (replacement for the deprecated
        # ``pytest.warns(None)``)
        with warnings.catch_warnings(record=True) as recwarn:
            warnings.simplefilter("always")
            sim.predict([np.zeros((1, 10, 1))] * 2)
        assert not has_node_mismatch(recwarn)

        with warnings.catch_warnings(record=True) as recwarn:
            warnings.simplefilter("always")
            sim.predict({inp1: np.zeros((1, 10, 1))})
        assert not has_node_mismatch(recwarn)

        sim.compile(optimizer=tf.optimizers.SGD(0), loss=tf.losses.mse)
        with pytest.warns(UserWarning, match="does not match number of Nodes"):
            sim.fit(np.zeros((1, 10, 1)), [np.zeros((1, 10, 1))] * 2)
        with pytest.warns(UserWarning, match="does not match number of Nodes"):
            sim.evaluate(np.zeros((1, 10, 1)), [np.zeros((1, 10, 1))] * 2)

        # too few targets for the two Probes: check both fit and evaluate
        with probe_mismatch():
            sim.fit([np.zeros((1, 10, 1))] * 2, np.zeros((1, 10, 1)))
        with probe_mismatch():
            sim.evaluate([np.zeros((1, 10, 1))] * 2, np.zeros((1, 10, 1)))
def build_step(self, signals):
    """
    Run ``self.merged_func`` through ``tf.numpy_function`` for one step.

    The time signal, the optional input signal, and every state signal are
    gathered and passed to the Python function. Its first return value is
    scattered into ``self.output_data`` using ``self.mode``, and the remaining
    values are scattered back into the matching state signals with
    ``mode="update"``.
    """
    func_args = [signals.gather(self.time_data)]
    if self.input_data is not None:
        func_args.append(signals.gather(self.input_data))
    func_args.extend(signals.gather(sig) for sig in self.state_data)

    out_types = [self.output_data.dtype]
    out_types.extend(sig.dtype for sig in self.state_data)

    # in graph mode, make each call depend on the previous call's output because
    # the wrapped function is not known to be thread safe; in eager mode an
    # argument-less suppress() is a do-nothing context
    deps = (
        contextlib.suppress()
        if compat.eager_enabled()
        else tf.control_dependencies(self.prev_result)
    )
    with deps:
        raw = tf.numpy_function(
            self.merged_func,
            func_args,
            out_types,
            name=self.merged_func.__name__,
        )

    # TensorFlow squeezes a length-1 output list (no state) to a bare tensor;
    # flatten so the result is always a list
    flat = tf.nest.flatten(raw)
    output, new_states = flat[0], flat[1:]
    self.prev_result = [output]

    output.set_shape(self.output_data.full_shape)
    signals.scatter(self.output_data, output, mode=self.mode)

    for idx, value in enumerate(new_states):
        target = self.state_data[idx]
        value.set_shape(target.full_shape)
        signals.scatter(target, value, mode="update")
def __init__(self, log_dir, sim, objects):
    """
    Create the summary writer and decide which parameter to log for each object.

    Parameters
    ----------
    log_dir
        Directory passed to ``tf.summary.create_file_writer``.
    sim
        The simulator whose parameter values will be logged.
    objects : list
        Ensembles (``encoders`` are logged), Neurons (``bias`` is logged) or
        Connections (``weights`` are logged).

    Raises
    ------
    ValidationError
        If an object has an unsupported type, or is a Connection with no weights.
    """
    super().__init__()

    self.sim = sim

    # the writer is created under eager execution, entering eager mode
    # explicitly when the default is graph mode
    eager_ctx = (
        contextlib.suppress() if compat.eager_enabled() else context.eager_mode()
    )
    with eager_ctx:
        self.writer = tf.summary.create_file_writer(log_dir)

    self.summaries = []
    for obj in objects:
        if isinstance(obj, nengo.Ensemble):
            param = "encoders"
            name = "Ensemble_%s" % obj.label
        elif isinstance(obj, nengo.ensemble.Neurons):
            param = "bias"
            name = "Ensemble.neurons_%s" % obj.ensemble.label
        elif isinstance(obj, nengo.Connection):
            if not compat.conn_has_weights(obj):
                raise ValidationError(
                    "Connection '%s' does not have any weights to log" % obj,
                    "objects",
                )
            param = "weights"
            name = "Connection_%s" % obj.label
        else:
            raise ValidationError(
                "Unknown summary object %s; should be an Ensemble, Neurons, or "
                "Connection" % obj,
                "objects",
            )

        summary_name = utils.sanitize_name("%s_%s" % (name, param))
        self.summaries.append((summary_name, obj, param))
def coerce(self, node, func):
    """
    Performs validation on the function passed to TensorNode, and sets
    ``shape_out`` if necessary.

    Parameters
    ----------
    node : `.TensorNode`
        The node whose ``tensor_func`` parameter is being set.
    func : callable
        The function being assigned to the TensorNode.

    Returns
    -------
    output : callable
        The function after validation is applied.

    Raises
    ------
    ValidationError
        If ``func`` is not callable, or if ``node.shape_out`` is unset and
        inferring the output shape from ``func`` fails. The underlying error is
        chained as the cause.
    """
    output = super().coerce(node, func)

    if not callable(func):
        raise ValidationError(
            "TensorNode output must be a function or Keras Layer",
            attr=self.name,
            obj=node,
        )

    if node.shape_out is None:
        if isinstance(func, tf.keras.layers.Layer):
            # we can use Keras' static shape inference to get the
            # output shape, which avoids having to build/call the layer
            if node.pass_time:
                input_spec = [tf.TensorSpec(())]
            else:
                input_spec = []
            if node.shape_in is not None:
                # a leading batch dimension of 1 is added to the input shape
                input_spec += [tf.TensorSpec((1,) + node.shape_in)]

            # a single input is passed directly rather than as a list
            if len(input_spec) == 1:
                input_spec = input_spec[0]

            ctx = contextlib.suppress() if eager_enabled() else context.eager_mode()
            try:
                with ctx:
                    result = func.compute_output_signature(input_spec)
            except Exception as e:
                raise ValidationError(
                    "Attempting to automatically determine TensorNode output shape "
                    "by calling Layer.compute_output_signature produced an error. "
                    "If you would like to avoid this step, try manually setting "
                    "`TensorNode(..., shape_out=x)`. The error is shown below:\n%s"
                    % repr(e),
                    attr=self.name,
                    obj=node,
                ) from e

        else:
            # plain callables are invoked once on dummy inputs to infer the shape
            if node.pass_time:
                args = (tf.constant(0.0),)
            else:
                args = ()
            if node.shape_in is not None:
                args += (tf.zeros((1,) + node.shape_in),)

            try:
                result = func(*args)
            except Exception as e:
                raise ValidationError(
                    "Attempting to automatically determine TensorNode output shape "
                    "by calling TensorNode function produced an error. "
                    "If you would like to avoid this step, try manually setting "
                    "`TensorNode(..., shape_out=x)`. The error is shown below:\n%s"
                    % e,
                    attr=self.name,
                    obj=node,
                ) from e

        validate_output(result)

        # drop the leading batch dimension added to the dummy inputs above
        node.shape_out = result.shape[1:]

    return output
def test_predict(Simulator, seed):
    """
    Check that ``sim.predict``/``sim.predict_on_batch`` match ``sim.run_steps``
    output for each supported input format.

    Covers no input, numpy arrays (single and multiple batches), tf Tensors
    (eager mode only), dicts keyed by Node or Node label, generators and
    ``tf.data.Dataset``.
    """
    n_steps = 100

    with nengo.Network(seed=seed) as net:
        a = nengo.Node([2], label="a")
        b = nengo.Ensemble(10, 1)
        nengo.Connection(a, b)
        p = nengo.Probe(b)

    with Simulator(net, minibatch_size=4) as sim:
        # 12 input items = 3 batches of minibatch_size 4
        a_vals = np.ones((12, n_steps, 1))
        n_batches = a_vals.shape[0] // sim.minibatch_size

        # reference output with no input data
        sim.run_steps(n_steps)
        data_noinput = sim.data[p]
        sim.reset(include_trainable=False, include_processes=False)

        # reference output for one minibatch of a_vals, tiled n_batches times so
        # it can be compared against predictions on all 12 items
        sim.run_steps(n_steps, data={a: a_vals[:4]})
        data_tile = np.tile(sim.data[p], (n_batches, 1, 1))
        # include_probes=False keeps sim.data[p] from the a_vals[:4] run, which
        # is compared against below
        sim.reset(include_probes=False, include_trainable=False, include_processes=False)

        # no input (also checking batch_size is ignored)
        with pytest.warns(UserWarning, match="Batch size is determined statically"):
            output = sim.predict(n_steps=n_steps, batch_size=-1)
        assert np.allclose(output[p], data_noinput)

        # numpy input (single batch)
        output = sim.predict_on_batch(a_vals[:4])
        assert np.allclose(output[p], sim.data[p])

        # numpy input (multiple batches)
        output = sim.predict(a_vals)
        assert np.allclose(output[p], data_tile)

        # tf input
        if compat.eager_enabled():
            output = sim.predict(tf.constant(a_vals))
            assert np.allclose(output[p], data_tile)

        # dict input
        for key in [a, "a"]:
            output = sim.predict({key: a_vals})
            assert np.allclose(output[p], data_tile)

        # generator input: each element is [input minibatch, per-item step
        # counts] (the second entry presumably fills the n_steps input)
        output = sim.predict(
            (
                [
                    a_vals[i * sim.minibatch_size:(i + 1) * sim.minibatch_size],
                    np.ones((sim.minibatch_size, 1), dtype=np.int32) * n_steps,
                ]
                for i in range(n_batches)
            ),
            steps=n_batches,
        )
        assert np.allclose(output[p], data_tile)

        # dataset input
        dataset = tf.data.Dataset.from_tensor_slices(
            {
                "a": tf.constant(a_vals),
                "n_steps": tf.ones((12, 1), dtype=np.int32) * n_steps,
            }
        ).batch(sim.minibatch_size)

        output = sim.predict(dataset)
        assert np.allclose(output[p], data_tile)
def test_fit(Simulator, seed):
    """
    Check that ``sim.fit`` can learn a two-input task from numpy arrays, tf
    Tensors (eager mode only), generators and ``tf.data.Dataset`` inputs.

    The final training loss (and validation loss where used) must drop below
    5e-4, and zero validation sample weights must give zero validation loss.
    """
    minibatch_size = 4
    n_hidden = 20

    with nengo.Network(seed=seed) as net:
        net.config[nengo.Ensemble].gain = nengo.dists.Choice([1])
        net.config[nengo.Ensemble].bias = nengo.dists.Choice([0])
        net.config[nengo.Connection].synapse = None

        # note: this unusual input setup (two scalar Nodes feeding one 2D Node)
        # exists just so that we can test training with two distinct inputs
        inp_a = nengo.Node([0])
        inp_b = nengo.Node([0])
        inp = nengo.Node(size_in=2)
        nengo.Connection(inp_a, inp[0], transform=1)
        nengo.Connection(inp_b, inp[1], transform=1)

        ens = nengo.Ensemble(n_hidden + 1, n_hidden, neuron_type=nengo.Sigmoid(tau_ref=1))
        out = nengo.Ensemble(1, 1, neuron_type=nengo.Sigmoid(tau_ref=1))
        nengo.Connection(inp, ens.neurons, transform=dists.Glorot())
        nengo.Connection(ens.neurons, out.neurons, transform=dists.Glorot())

        nengo.Probe(out.neurons)

    with Simulator(net, minibatch_size=minibatch_size, unroll_simulation=1, seed=seed) as sim:
        # XOR-like targets: 0.1 when the two inputs are equal, 0.9 otherwise
        x = np.asarray([[[0.0, 0.0]], [[0.0, 1.0]], [[1.0, 0.0]], [[1.0, 1.0]]])
        y = np.asarray([[[0.1]], [[0.9]], [[0.9]], [[0.1]]])

        sim.compile(optimizer=tf.optimizers.Adam(0.01), loss=tf.losses.mse)
        # note: batch_size should be ignored
        with pytest.warns(UserWarning, match="Batch size is determined statically"):
            history = sim.fit(
                [x[..., [0]], x[..., [1]]],
                y,
                validation_data=([x[..., [0]], x[..., [1]]], y),
                epochs=200,
                verbose=0,
                batch_size=-1,
            )
        assert history.history["loss"][-1] < 5e-4
        assert history.history["val_loss"][-1] < 5e-4

        # check that validation_sample_weights work correctly (all-zero weights
        # should give zero validation loss)
        history = sim.fit(
            [x[..., [0]], x[..., [1]]],
            y,
            validation_data=([x[..., [0]], x[..., [1]]], y, np.zeros(y.shape[0])),
            epochs=1,
            verbose=0,
        )
        assert np.allclose(history.history["val_loss"][-1], 0)

        # tf Tensor inputs are only exercised in eager mode
        if compat.eager_enabled():
            sim.reset()
            history = sim.fit(
                [tf.constant(x[..., [0]]), tf.constant(x[..., [1]])],
                tf.constant(y),
                epochs=200,
                verbose=0,
            )
            assert history.history["loss"][-1] < 5e-4

        # generator input: ((input_a, input_b, n_steps), target) tuples; 200
        # items cover epochs * steps_per_epoch = 20 * 10 steps
        sim.reset()
        history = sim.fit(
            (
                ((x[..., [0]], x[..., [1]], np.ones((4, 1), dtype=np.int32)), y)
                for _ in range(200)
            ),
            epochs=20,
            steps_per_epoch=10,
            verbose=0,
        )
        assert history.history["loss"][-1] < 5e-4

        # tf.data.Dataset input (continues from the previous training, no reset)
        history = sim.fit(
            tf.data.Dataset.from_tensors(
                ((x[..., [0]], x[..., [1]], np.ones((4, 1), dtype=np.int32)), y)
            ),
            validation_data=tf.data.Dataset.from_tensors(
                ((x[..., [0]], x[..., [1]], np.ones((4, 1), dtype=np.int32)), y)
            ),
            epochs=200,
            verbose=0,
        )
        assert history.history["loss"][-1] < 5e-4
        assert history.history["val_loss"][-1] < 5e-4
def test_evaluate(Simulator):
    """
    Check ``sim.evaluate`` loss/metric values for single and multiple probes and
    for default, list, tensor (eager mode only) and generator inputs, plus custom
    loss and metric functions.

    Both Nodes output 0 and the Probes are named ``probe`` and ``probe_1``, which
    is what the per-probe loss keys below refer to.
    """
    minibatch_size = 3
    n_steps = 10
    n_batches = 2

    with nengo.Network() as net:
        inp0 = nengo.Node([0])
        inp1 = nengo.Node([0])
        p0 = nengo.Probe(inp0)
        p1 = nengo.Probe(inp1)

    with Simulator(net, minibatch_size=minibatch_size) as sim:
        # single probe: only "probe" has a loss (output 0 vs target 1 -> mse 1)
        sim.compile(loss={"probe": tf.losses.mse})
        targets = np.ones((minibatch_size, n_steps, 1))
        with pytest.warns(UserWarning, match="Batch size is determined statically"):
            loss = sim.evaluate(n_steps=n_steps, y=targets, batch_size=-1)
        assert np.allclose(loss["loss"], 1)
        assert np.allclose(loss["probe_loss"], 1)
        assert "probe_1_loss" not in loss

        # multiple probes: the total loss is the sum of the per-probe losses
        sim.compile(loss=tf.losses.mse)
        loss = sim.evaluate(n_steps=n_steps, y={p0: targets, p1: targets})
        assert np.allclose(loss["loss"], 2)
        assert np.allclose(loss["probe_loss"], 1)
        assert np.allclose(loss["probe_1_loss"], 1)

        # default inputs
        loss = sim.evaluate(
            y={
                p0: np.zeros((minibatch_size, n_steps, 1)),
                p1: np.zeros((minibatch_size, n_steps, 1)),
            },
            n_steps=n_steps,
        )
        assert np.allclose(loss["loss"], 0)
        assert np.allclose(loss["probe_loss"], 0)
        assert np.allclose(loss["probe_1_loss"], 0)

        # list inputs: node 0 gets 1 (matches target), node 1 gets 2 (mse 1)
        inputs = np.ones((minibatch_size * n_batches, n_steps, 1))
        targets = inputs.copy()
        loss = sim.evaluate(x=[inputs, inputs * 2], y={p0: targets, p1: targets})
        assert np.allclose(loss["loss"], 1)
        assert np.allclose(loss["probe_loss"], 0)
        assert np.allclose(loss["probe_1_loss"], 1)

        # tensor inputs
        if compat.eager_enabled():
            loss = sim.evaluate(
                x=[tf.constant(inputs), tf.constant(inputs * 2)],
                y={p0: tf.constant(targets), p1: tf.constant(targets)},
            )
            assert np.allclose(loss["loss"], 1)
            assert np.allclose(loss["probe_loss"], 0)
            assert np.allclose(loss["probe_1_loss"], 1)

        # generator inputs: (inputs dict keyed by Node name plus "n_steps",
        # targets dict keyed by Probe name), one element per batch
        gen = (
            (
                {
                    "node": np.ones((minibatch_size, n_steps, 1)),
                    "node_1": np.ones((minibatch_size, n_steps, 1)) * 2,
                    "n_steps": np.ones((minibatch_size, 1)) * n_steps,
                },
                {
                    "probe": np.ones((minibatch_size, n_steps, 1)),
                    "probe_1": np.ones((minibatch_size, n_steps, 1)),
                },
            )
            for _ in range(n_batches)
        )
        loss = sim.evaluate(gen, steps=n_batches)
        assert np.allclose(loss["loss"], 1)
        assert np.allclose(loss["probe_loss"], 0)
        assert np.allclose(loss["probe_1_loss"], 1)

        # check custom objective
        def constant_error(y_true, y_pred):
            return tf.constant(3.0)

        sim.compile(loss={p0: constant_error})
        assert np.allclose(
            sim.evaluate(y={p0: np.zeros((minibatch_size, n_steps, 1))}, n_steps=n_steps)["loss"],
            3,
        )

        # test metrics: mse losses are 1 (target 1) and 4 (target 2), total 5;
        # metrics are reported only for the probes they were assigned to
        sim.compile(
            loss=tf.losses.mse,
            metrics={p0: constant_error, p1: [constant_error, "mae"]},
        )
        output = sim.evaluate(
            y={
                p0: np.ones((minibatch_size, n_steps, 1)),
                p1: np.ones((minibatch_size, n_steps, 1)) * 2,
            },
            n_steps=n_steps,
        )
        assert np.allclose(output["loss"], 5)
        assert np.allclose(output["probe_loss"], 1)
        assert np.allclose(output["probe_1_loss"], 4)
        assert np.allclose(output["probe_constant_error"], 3)
        assert np.allclose(output["probe_1_constant_error"], 3)
        assert "probe_mae" not in output
        assert np.allclose(output["probe_1_mae"], 2)
def on_train_end(self, logs=None):
    """Close the summary writer once training has finished."""
    if compat.eager_enabled():
        self.writer.close()
    else:
        # the writer is closed under eager execution, as it was created there
        with context.eager_mode():
            self.writer.close()
def call(self, inputs, training=None, progress=None, stateful=False):
    """
    Constructs the graph elements to simulate the model.

    Parameters
    ----------
    inputs : list of ``tf.Tensor``
        Input layers/tensors for the network (must match the structure defined in
        `.build_inputs`). The last element is read as the number of steps to run
        (its ``[0, 0]`` entry is stored in ``self.steps_to_run``).
    training : bool
        Whether the network is being run in training or inference mode. If None,
        uses the symbolic Keras learning phase variable. A ``learning_phase``
        config setting on the model, if set, overrides this argument.
    progress : `.utils.ProgressBar`
        Progress bar for construction stage.
    stateful : bool
        Whether or not to build the model to support preserving the internal state
        between executions.

    Returns
    -------
    probe_arrays : list of ``tf.Tensor``
        Tensors representing the output of all the Probes in the network (order
        corresponding to ``self.model.probes``, which is the order the Probes were
        instantiated), followed by one final element, ``steps_run``.

    Raises
    ------
    BuildError
        If ``training`` is True but the graph was created with
        ``inference_only=True``.
    """
    # a model-level "learning_phase" setting takes precedence over the argument
    override_training = config.get_setting(self.model, "learning_phase", None)
    training = training if override_training is None else override_training

    super().call(inputs, training=training)

    if training is True and self.inference_only:
        raise BuildError(
            "TensorGraph was created with inference_only=True; cannot be called "
            "with training=%s" % training
        )

    tf.random.set_seed(self.seed)

    if progress is None:
        progress = utils.NullProgressBar()

    # reset signaldict
    self.signals.reset()

    # create these constants once here for reuse in different operators
    self.signals.dt = tf.constant(self.dt, self.dtype)
    self.signals.dt_val = self.dt  # store the actual value as well
    self.signals.zero = tf.constant(0, self.dtype)
    self.signals.one = tf.constant(1, self.dtype)

    # set up invariant inputs
    with trackable.no_automatic_dependency_tracking_scope(self):
        self.node_inputs = {}
    for n, inp in zip(self.invariant_inputs, inputs):
        # specify shape of inputs (keras sometimes loses this shape information)
        inp.set_shape([self.minibatch_size, inp.shape[1], n.size_out])

        self.node_inputs[n] = inp

    # the last input holds the number of steps to run (presumably the same
    # value for every minibatch item, since only [0, 0] is read)
    self.steps_to_run = inputs[-1][0, 0]

    # set up build config
    # TODO: it would be nicer if buildconfig was static (i.e. find a separate
    #  way to pass around `training`)
    build_config = builder.BuildConfig(
        inference_only=self.inference_only,
        lif_smoothing=config.get_setting(self.model, "lif_smoothing"),
        cpu_only=self.device == "/cpu:0" or not utils.tf_gpu_installed,
        rng=np.random.RandomState(self.seed),
        training=(
            tf.keras.backend.learning_phase() if training is None else training
        ),
    )

    # pre-build stage
    with progress.sub("pre-build stage", max_value=len(self.plan)) as sub:
        self.op_builder.build_pre(self.signals, build_config, sub)

    # build stage (either a symbolic loop or a fully unrolled graph)
    with progress.sub("build stage", max_value=len(self.plan) * self.unroll) as sub:
        steps_run, probe_arrays, final_internal_state, final_base_params = (
            self._build_loop(sub) if self.use_loop else self._build_no_loop(sub)
        )

    # store these so that they can be accessed after the initial build
    with trackable.no_automatic_dependency_tracking_scope(self):
        self.steps_run = steps_run
        self.probe_arrays = probe_arrays
        self.final_internal_state = final_internal_state
        self.final_base_params = final_base_params

    # logging
    logger.info(
        "Number of reads: %d", sum(x for x in self.signals.read_types.values())
    )
    for x in self.signals.read_types.items():
        logger.info(" %s: %d", *x)
    logger.info(
        "Number of writes: %d", sum(x for x in self.signals.write_types.values())
    )
    for x in self.signals.write_types.items():
        logger.info(" %s: %d", *x)

    # note: always return steps_run so that the simulation will run for the given
    # number of steps, even if there are no output probes
    outputs = list(probe_arrays.values()) + [steps_run]

    updates = []
    if stateful:
        # update saved state
        for var, val in zip(self.saved_state.values(), final_internal_state):
            updates.append(var.assign(val))

    # if any of the base params have changed (due to online learning rules) then we
    # also need to assign those back to the original variable (so that their
    # values will persist). any parameters targeted by online learning rules
    # will be minibatched, so we only need to update the minibatched params.
    for (key, var), val in zip(self.base_params.items(), final_base_params):
        # the last entry of each base_arrays_init record is its "minibatched"
        # flag (see the unpacking in get_initializer inside build)
        try:
            minibatched = self.base_arrays_init["non_trainable"][key][-1]
        except KeyError:
            minibatched = self.base_arrays_init["trainable"][key][-1]

        if minibatched:
            updates.append(var.assign(val))

    logger.info("Number of state updates: %d", len(updates))

    # in graph mode, tie the assignments to the outputs so that they run
    # whenever the outputs are computed
    if not compat.eager_enabled() and len(updates) > 0:
        with tf.control_dependencies(updates):
            outputs = [tf.identity(x) for x in outputs]

    return outputs
def build(self, input_shape=None):
    """
    Create any Variables used in the model.

    This creates one weight per (signal type, dtype, shape) group in
    ``self.base_arrays_init["trainable"]``/``["non_trainable"]`` (stored in
    ``self.base_params``). It also creates a non-trainable ``tf.Variable`` for
    each ``"state"`` group whose initial value matters (stored in
    ``self.saved_state``). Finally it (re)builds any Keras Layers used by
    TensorNodes, restoring the values of previously built weights.

    Parameters
    ----------
    input_shape : list of tuple of int
        Shapes of all the inputs to this layer.
    """

    super().build(input_shape)

    tf.random.set_seed(self.seed)

    def get_initializer(init_vals):
        """
        Use more efficient initializers if possible to save memory.

        ``init_vals`` is a ``(values, shapes, dtype, minibatched)`` tuple. Returns
        ``(initializer, full_shape, dtype)``, where ``initializer`` is None if
        every value is None.
        """

        values, shapes, dtype, minibatched = init_vals

        # initial value of None means that the initial value isn't used, so we
        # can use anything for the initial value
        if all(v is None for v in values):
            initializer = None
        elif all(v is None or np.all(v == 0) for v in values):
            initializer = tf.initializers.zeros()
        elif all(v is None or np.all(v == 1) for v in values):
            initializer = tf.initializers.ones()
        else:
            # minibatched signals are concatenated along axis 1, others along 0
            val = tf.concat(
                [
                    tf.zeros(s, dtype)
                    if v is None
                    else tf.cast(tf.broadcast_to(v, s), dtype)
                    for v, s in zip(values, shapes)
                ],
                axis=1 if minibatched else 0,
            )
            # the shape/dtype arguments are ignored; the precomputed value is
            # returned as-is
            initializer = lambda shape=None, dtype=None: val

        # figure out shape of full concatenated initial value
        # (minibatched is used as an index: False -> axis 0, True -> axis 1,
        # matching the concat axis above)
        shape = list(shapes[0])
        shape[minibatched] = sum(x[minibatched] for x in shapes)

        return initializer, tuple(shape), dtype

    # save initializers so that we can reset the model later
    with trackable.no_automatic_dependency_tracking_scope(self):
        self.initial_values = {}

    # variables for model parameters
    with trackable.no_automatic_dependency_tracking_scope(self):
        self.base_params = OrderedDict()
    assert len(self.base_params) == 0
    for sig_type in ("trainable", "non_trainable"):
        for k, v in self.base_arrays_init[sig_type].items():
            initializer, shape, dtype = get_initializer(v)
            assert initializer is not None  # params should never be set
            self.base_params[k] = self.add_weight(
                initializer=initializer,
                shape=shape,
                dtype=dtype,
                trainable=sig_type == "trainable",
                name="base_params/%s_%s_%s"
                % (sig_type, dtype, "_".join(str(x) for x in shape)),
            )
            self.initial_values[k] = initializer

    logger.debug("created base param variables")
    logger.debug([str(x) for x in self.base_params.values()])

    # variables to save the internal state of simulation between runs
    with trackable.no_automatic_dependency_tracking_scope(self):
        self.saved_state = OrderedDict()
    for k, v in self.base_arrays_init["state"].items():
        initializer, shape, dtype = get_initializer(v)
        if initializer is not None:
            # don't need to save the state for signals where the initial value
            # doesn't matter
            # NOTE(review): the lambda closes over the loop variables
            # initializer/shape/dtype; this relies on tf.Variable calling
            # initial_value during construction (if it were deferred, every
            # variable would see the last iteration's values) -- confirm
            self.saved_state[k] = tf.Variable(
                initial_value=lambda: initializer(shape=shape, dtype=dtype),
                shape=shape,
                dtype=dtype,
                trainable=False,
                name="saved_state/%s_%s" % (dtype, "_".join(str(x) for x in shape)),
            )
            self.initial_values[k] = initializer

    logger.debug("created saved state variables")
    logger.debug([str(x) for x in self.saved_state.values()])

    # call build on any TensorNode Layers

    def unbuild(layer):
        """Recursively mark ``layer`` and its sub-layers as unbuilt."""
        assert layer.built

        # clear any losses attached to layer (they will be recreated in the
        # build step, so we don't want to keep around any losses
        # associated with the previous build)
        # note: not clearing layer._losses, because those are manually added
        # by the user (not created during the build process)
        layer._eager_losses = []
        layer._callable_losses = []

        layer.built = False

        for sub in layer._layers:
            if isinstance(sub, tf.keras.layers.Layer):
                unbuild(sub)

    # every TensorNode operator whose func is a Keras Layer
    layer_ops = [
        op
        for ops in self.plan
        if isinstance(ops[0], tensor_node.SimTensorNode)
        for op in ops
        if isinstance(op.func, tf.keras.layers.Layer)
    ]
    weight_gets = []
    weight_sets = []
    for op in layer_ops:
        if op.func in self._layers:
            # already built this layer
            continue

        # input shape: optional scalar time, then optional (minibatch,) + shape_in;
        # a single shape is passed unwrapped
        if op.time is None:
            shape_in = []
        else:
            shape_in = [()]
        if op.input is not None:
            shape_in += [(self.minibatch_size,) + op.shape_in]
        if len(shape_in) == 1:
            shape_in = shape_in[0]

        if op.func.built:
            # we rebuild the layer (even if it is already built),
            # because we need to build the weights within the TensorGraph
            # context

            # save the weight values so they can be restored
            # exactly inside the tensornode
            weights = op.func.weights
            weight_gets.extend(weights)

            # clear the results of previous build
            unbuild(op.func)
        else:
            weights = None

        with tf.name_scope(op.func.name):
            op.func.build(shape_in)

        # NOTE(review): pairing old and new weights assumes op.func.weights lists
        # them in the same order before and after the rebuild -- confirm
        if weights is not None:
            weight_sets.extend(op.func.weights)

        # add op func to _layers so that any weights are collected
        self._layers.append(op.func)

    if len(weight_gets) > 0:
        # do all the weight getting/setting in one go, for efficiency reasons

        # match the fetch context to the context in which the weights were created
        ctx = (
            weight_gets[0].graph.as_default()
            if hasattr(weight_gets[0], "graph")
            else context.eager_mode()
        )
        with ctx:
            weight_vals = tf.keras.backend.batch_get_value(weight_gets)
        tf.keras.backend.batch_set_value(zip(weight_sets, weight_vals))

    if not compat.eager_enabled():
        # initialize state variables (need to do this manually because we're not
        # adding them to self.weights)
        tf.keras.backend.batch_get_value(
            [var.initializer for var in self.saved_state.values()]
        )