def _test_minimize_loss_graph(self, task_type, task_id, num_gpus):
    """Run ten graph-mode update steps and check the error goes down.

    A 1x1 `kernel` variable is created under the strategy scope. Each step
    computes the gradient of `(matmul(x, kernel) - 1) ** 2` per replica,
    sum-reduces it onto the variable and applies `assign_sub(0.05 * g)`.
    The variable value read before the first update is compared with the
    value read after the last one.

    Args:
      task_type: Task type of this worker. A falsy value selects the local
        (single worker) path.
      task_id: Task index; forwarded to `_get_test_objects` and to
        `multi_worker_util.is_chief`.
      num_gpus: Forwarded to `_get_test_objects`.
    """
    d, master_target, sess_config = self._get_test_objects(
        task_type, task_id, num_gpus)

    if task_type:
        # Multi-worker
        assert hasattr(d.extended, '_cluster_spec') and d.extended._cluster_spec
        cluster = d.extended._cluster_spec.as_dict()
        # A cluster spec without a WORKER entry contributes zero workers
        # instead of failing on len(None).
        num_workers = len(cluster.get(WORKER, ()))
        if CHIEF in cluster:
            num_workers += 1
    else:
        # local
        num_workers = 1

    with ops.Graph().as_default(), \
            self.cached_session(target=master_target,
                                config=sess_config) as sess, \
            d.scope():
        kernel = strategy_test_lib.create_variable_like_keras_layer(
            'kernel', (1, 1), dtypes.float32,
        )

        def loss_fn(x):
            y = array_ops.reshape(
                math_ops.matmul(x, kernel), []) - constant_op.constant(1.)
            return y * y

        # TODO(yuefengz, apassos): eager.backprop.implicit_grad is not safe for
        # multiple graphs (b/111216820).
        def grad_fn(x):
            loss = loss_fn(x)
            var_list = (variables.trainable_variables() + ops.get_collection(
                ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
            grads = gradients.gradients(loss, var_list)
            return list(zip(grads, var_list))

        def update(v, g):
            return v.assign_sub(0.05 * g, use_locking=True)

        one = constant_op.constant([[1.]])

        def step():
            """Perform one optimization step."""
            # Run forward & backward to get gradients, variables list.
            g_v = d.extended.call_for_each_replica(grad_fn, args=(one, ))

            # Update the variables using the gradients and the update()
            # function.
            before_list = []
            after_list = []
            for g, v in g_v:
                fetched = d.extended.read_var(v)
                before_list.append(fetched)
                # The control dependencies order the ops: read the old
                # value, then reduce and update, then read the new value.
                with ops.control_dependencies([fetched]):
                    # TODO(yuefengz): support non-Mirrored variable as
                    # destinations.
                    g = d.extended.reduce_to(
                        reduce_util.ReduceOp.SUM, g, destinations=v)
                    with ops.control_dependencies(
                            d.extended.update(
                                v, update, args=(g, ), group=False)):
                        after_list.append(d.extended.read_var(v))
            return before_list, after_list

        before_out, after_out = step()

        if (not task_type or multi_worker_util.is_chief(
                d.extended._cluster_spec, task_type, task_id)):
            self.evaluate(variables.global_variables_initializer())

        # Workers waiting for chief worker's initializing variables.
        # try/finally guarantees the condition's lock is released even if
        # the wait is interrupted by an exception; otherwise the remaining
        # workers would block forever on acquire().
        self._init_condition.acquire()
        try:
            self._init_reached += 1
            while self._init_reached != num_workers:
                self._init_condition.wait()
            self._init_condition.notify_all()
        finally:
            self._init_condition.release()

        for i in range(10):
            b, a = sess.run((before_out, after_out))
            if i == 0:
                # Keep the value read before the very first update.
                before, = b
            # Overwritten every iteration, so it ends up holding the value
            # read after the last update.
            after, = a

        error_before = abs(before - 1)
        error_after = abs(after - 1)
        # Error should go down
        self.assertLess(error_after, error_before)
def testRunStepsWithOutputContext(self, distribution, optimizer_fn, is_tpu):
    """Check the outputs recorded on the output context of a multi-step run.

    A `matmul(x, kernel) + bias` model is minimized through
    `experimental_run_steps_on_iterator` with `iterations=2`. The test
    asserts the non-tensor output stored by the model function, hands the
    three recorded loss outputs to `_verify_loss_output`, then calls the
    step five times and asserts that neither the loss nor
    `|kernel + bias - 1|` ever increases.

    Args:
      distribution: Strategy object under whose scope the model is built.
      optimizer_fn: Zero-argument callable returning the optimizer.
      is_tpu: Not read by this method.
    """
    with distribution.scope():

        def dataset_fn():
            repeated = dataset_ops.Dataset.from_tensors([[1.]]).repeat()
            # TODO(priyag): batch with drop_remainder=True causes shapes to be
            # fully defined for TPU. Remove this when XLA supports dynamic
            # shapes.
            return repeated.batch(batch_size=1, drop_remainder=True)

        optimizer = optimizer_fn()
        kernel = strategy_test_lib.create_variable_like_keras_layer(
            "kernel", (1, 1), dtypes.float32)
        bias = strategy_test_lib.create_variable_like_keras_layer(
            "bias", (1, ), dtypes.float32)

        extra_key = "foo"
        extra_value = "bar"

        def model_fn(output_context, x):
            """A very simple model written by the user."""

            def loss_fn():
                affine = nn_ops.bias_add(math_ops.matmul(x, kernel), bias)
                y = array_ops.reshape(affine, []) - constant_op.constant(1.)
                return y * y

            # Only the v2 optimizer branch is handed the variable list.
            if strategy_test_lib.is_optimizer_v2_instance(optimizer):
                minimize_args = (loss_fn, lambda: [kernel, bias])
            else:
                minimize_args = (loss_fn, )
            train_op = optimizer.minimize(*minimize_args)

            loss = loss_fn()
            output_context.set_last_step_output(
                name="replica_loss_reduced",
                output=loss,
                reduce_op=reduce_util.ReduceOp.MEAN)
            output_context.set_non_tensor_output(extra_key, extra_value)
            return (train_op, loss)

        def step_fn(output_context, inputs):
            per_replica = distribution.extended.call_for_each_replica(
                model_fn, args=(output_context, inputs))
            (train_op, loss) = per_replica
            output_context.set_last_step_output(
                name="cross_replica_loss_reduced",
                output=loss,
                reduce_op=reduce_util.ReduceOp.MEAN)
            output_context.set_last_step_output(
                name="cross_replica_loss_not_reduced",
                output=loss)
            return distribution.group(train_op)

        iterator = self._get_iterator(distribution, dataset_fn)

        def run_step():
            initial_loss = lambda: constant_op.constant(1e7)
            # Initial values corresponding to reduced losses are just single
            # tensors. But for non reduced losses, we need to have initial
            # values that are of the same structure as non reduced losses. In
            # MirroredStrategy, this will be a list of losses, in TPUStrategy
            # it will be single tensor. Using `call_for_each_replica` followed
            # by `experimental_local_results` gives us the desired initial
            # value structure.
            not_reduced = distribution.experimental_local_results(
                distribution.extended.call_for_each_replica(initial_loss))
            initial_loop_values = {
                "replica_loss_reduced": initial_loss(),
                "cross_replica_loss_reduced": initial_loss(),
                "cross_replica_loss_not_reduced": not_reduced,
            }
            ctx = distribution.extended.experimental_run_steps_on_iterator(
                step_fn,
                iterator,
                iterations=2,
                initial_loop_values=initial_loop_values)

            self.assertEqual(
                {extra_key: (extra_value, )}, ctx.non_tensor_outputs)

            # (output name, whether that output was recorded with a reduce op)
            expectations = (
                ("replica_loss_reduced", True),
                ("cross_replica_loss_reduced", True),
                ("cross_replica_loss_not_reduced", False),
            )
            for output_name, is_reduced in expectations:
                self._verify_loss_output(
                    initial_loss(),
                    loss_output=ctx.last_step_outputs[output_name],
                    reduced=is_reduced,
                    distribution=distribution)
            return (ctx.run_op, ctx.last_step_outputs["replica_loss_reduced"])

        # In graph mode the ops are built once and wrapped in a session
        # callable; otherwise `run_step` itself is called on each iteration.
        step_callable = run_step
        if not context.executing_eagerly():
            with self.cached_session() as sess:
                step_callable = sess.make_callable(run_step())
        self.evaluate(variables_lib.global_variables_initializer())

        kernel_values = []
        bias_values = []
        loss_values = []
        for _ in range(5):
            _, current_loss = step_callable()
            loss_values.append(current_loss)
            kernel_values.append(self.evaluate(kernel))
            bias_values.append(self.evaluate(bias))

        loss_is_not_increasing = all(
            later <= earlier
            for earlier, later in zip(loss_values, loss_values[1:]))
        self.assertTrue(loss_is_not_increasing)

        summed = numpy.add(
            numpy.squeeze(kernel_values), numpy.squeeze(bias_values))
        error = abs(summed - 1)
        error_is_not_increasing = all(
            later <= earlier for earlier, later in zip(error, error[1:]))
        self.assertTrue(error_is_not_increasing)