def testTrainNetworkByCallForEachReplica(self, distribution, optimizer_fn,
                                         use_callable_loss):
    """Train through `call_for_each_replica` and check the error never grows.

    Runs ten training steps, reading the layer's kernel and bias after each
    one, then asserts that `abs(kernel + bias - 1)` is non-increasing from one
    step to the next.

    Args:
      distribution: Strategy object; this test uses its `scope()`, `group()`
        and `extended.call_for_each_replica()`.
      optimizer_fn: Zero-argument callable that builds the optimizer.
      use_callable_loss: Forwarded unchanged to `minimize_loss_example`.
    """
    with distribution.scope():
        model_fn, dataset_fn, layer = minimize_loss_example(
            optimizer_fn(),
            use_bias=True,
            use_callable_loss=use_callable_loss)
        iterator = self._get_iterator(distribution, dataset_fn)

        def build_train_op():
            per_replica = distribution.extended.call_for_each_replica(
                model_fn, args=(iterator.get_next(),))
            return distribution.group(per_replica)

        # Eager mode calls the builder directly on every step.
        train_step = build_train_op
        if not tf.executing_eagerly():
            # Graph mode: build the op once and wrap it in a session callable.
            with self.cached_session() as sess:
                train_step = sess.make_callable(build_train_op())
            self.evaluate(tf.compat.v1.global_variables_initializer())

        kernel_history = []
        bias_history = []
        for _ in range(10):
            train_step()
            kernel_history.append(self.evaluate(layer.kernel))
            bias_history.append(self.evaluate(layer.bias))

        combined = numpy.add(
            numpy.squeeze(kernel_history), numpy.squeeze(bias_history))
        error = abs(combined - 1)
        # Each step's error must be no larger than the one before it.
        is_not_increasing = all(
            later <= earlier for earlier, later in zip(error, error[1:]))
        self.assertTrue(is_not_increasing)
def testOptimizerInsideModelFn(self, distribution, optimizer_fn):
    """Check which variables are created when the optimizer runs in model_fn.

    A variable creator is installed that records the name of every variable
    created while the model is built and one training step is run. The
    recorded names are then compared, as a set, against the names listed for
    the optimizer in `VAR_MAP_V1` / `VAR_MAP_V2`, plus one `/replica_<i>`
    entry per listed name for each parameter device after the first.

    Args:
      distribution: Strategy object; this test uses its `scope()`, `group()`
        and `extended` API.
      optimizer_fn: Zero-argument callable that builds the optimizer.
    """
    if not tf.executing_eagerly() and tf.compat.v1.control_flow_v2_enabled():
        self.skipTest("b/138751864")

    created_variables = []

    def appending_creator(next_creator, **kwargs):
        # Delegate creation, then remember the resulting variable's name.
        variable = next_creator(**kwargs)
        created_variables.append(variable.name)
        return variable

    # Creator scope needs to be set before it's used inside
    # `distribution.scope`.
    with tf.variable_creator_scope(appending_creator), distribution.scope():
        optimizer = optimizer_fn()
        model_fn, dataset_fn, _ = minimize_loss_example(
            optimizer, use_bias=True, use_callable_loss=True)

        def step_fn(ctx, inputs):
            del ctx  # Unused
            return distribution.group(
                distribution.extended.call_for_each_replica(
                    model_fn, args=(inputs,)))

        iterator = self._get_iterator(distribution, dataset_fn)

        def run_step():
            return distribution.extended.experimental_run_steps_on_iterator(
                step_fn, iterator, iterations=1).run_op

        if not tf.executing_eagerly():
            # Graph mode: build the op once and wrap it in a session callable.
            with self.cached_session() as sess:
                run_step = sess.make_callable(run_step())
        self.evaluate(tf.compat.v1.global_variables_initializer())
        run_step()

        def get_expected_variables(num_parameter_devices):
            # The optimizer's name selects the entry in the V1/V2 name map.
            name = optimizer._name
            if isinstance(optimizer, optimizer_v2.OptimizerV2):
                variables = VAR_MAP_V2[name]
            else:
                variables = VAR_MAP_V1[name]

            # Devices after the first add a "/replica_<i>" copy of each name.
            extended_variables = [
                v + "/replica_{}".format(replica)
                for v in variables
                for replica in range(1, num_parameter_devices)
            ]
            variables = list(variables) + extended_variables
            return {v + ":0" for v in variables}

        self.assertEqual(
            get_expected_variables(
                len(distribution.extended.parameter_devices)),
            set(created_variables))