def testDeferredSlotRestoration(self):
  """Deferred restores apply once the matching objects are finally created.

  Two checkpoints are written from the same `root`:
    * "no_slots": saved before `optimizer` is attached, with var == 12.
    * "with_slots": saved with `optimizer` attached, with var == 13. and the
      "m" slot of var == 14.
  A fresh `Checkpoint` restores "with_slots" first and then "no_slots". The
  test checks that `var` ends up 12. (the later restore wins) and that the
  slot value 14. is applied once the new optimizer creates its slot.
  """
  checkpoint_directory = self.get_temp_dir()
  root = trackable_utils.Checkpoint()
  root.var = trackable_utils.add_variable(
      root, name="var", initializer=0.)
  optimizer = adam.AdamOptimizer(0.1)
  if context.executing_eagerly():
    optimizer.minimize(root.var.read_value)
  else:
    # `train_op` only exists when graph building, so it is evaluated here.
    train_op = optimizer.minimize(root.var)
    # Note that `optimizer` has not been added as a dependency of
    # `root`. Create a one-off grouping so that slot variables for `root.var`
    # get initialized too.
    self.evaluate(trackable_utils.gather_initializers(
        trackable_utils.Checkpoint(root=root, optimizer=optimizer)))
    self.evaluate(train_op)
  self.evaluate(state_ops.assign(root.var, 12.))
  no_slots_path = root.save(os.path.join(checkpoint_directory, "no_slots"))
  root.optimizer = optimizer
  self.evaluate(state_ops.assign(root.var, 13.))
  self.evaluate(state_ops.assign(optimizer.get_slot(name="m", var=root.var),
                                 14.))
  slots_path = root.save(os.path.join(checkpoint_directory, "with_slots"))
  new_root = trackable_utils.Checkpoint()
  # Load the slot-containing checkpoint (deferred), then immediately overwrite
  # the non-slot variable (also deferred).
  slot_status = new_root.restore(slots_path)
  no_slot_status = new_root.restore(no_slots_path)
  with self.assertRaises(AssertionError):
    no_slot_status.assert_consumed()
  new_root.var = trackable_utils.add_variable(
      new_root, name="var", shape=[])
  no_slot_status.assert_consumed()
  no_slot_status.run_restore_ops()
  self.assertEqual(12., self.evaluate(new_root.var))
  new_root.optimizer = adam.AdamOptimizer(0.1)
  slot_status.assert_existing_objects_matched()
  # `assertRaisesRegexp` is a deprecated alias that Python 3.12 removed.
  with self.assertRaisesRegex(AssertionError, "beta1_power"):
    slot_status.assert_consumed()
  self.assertEqual(12., self.evaluate(new_root.var))
  if context.executing_eagerly():
    # Slot variables are only created with restoring initializers when
    # executing eagerly.
    self.assertEqual(14., self.evaluate(
        new_root.optimizer.get_slot(name="m", var=new_root.var)))
  else:
    self.assertIs(new_root.optimizer.get_slot(name="m", var=new_root.var),
                  None)
  if context.executing_eagerly():
    new_root.optimizer.minimize(new_root.var.read_value)
  else:
    train_op = new_root.optimizer.minimize(new_root.var)
    # The slot variable now exists; restore() didn't create it, but we should
    # now have a restore op for it.
    slot_status.run_restore_ops()
    self.assertEqual(14., self.evaluate(
        new_root.optimizer.get_slot(name="m", var=new_root.var)))
    self.evaluate(train_op)
  slot_status.assert_consumed()
def testLoadFromNameBasedSaver(self):
  """Save a name-based checkpoint, load it using the object-based API."""
  with test_util.device(use_gpu=True):
    save_path = self._write_name_based_checkpoint()
    root = self._initialized_model()
    self._set_sentinels(root)
    # The sentinel check must fail before any restore; this keeps the later
    # sentinel checks from passing vacuously.
    with self.assertRaises(AssertionError):
      self._check_sentinels(root)
    object_saver = util.TrackableSaver(graph_view.ObjectGraphView(root))
    self._set_sentinels(root)
    status = object_saver.restore(save_path)
    # `assertRaisesRegexp` is a deprecated alias that Python 3.12 removed.
    if context.executing_eagerly():
      # Eager restores take effect immediately.
      self._check_sentinels(root)
      with self.assertRaisesRegex(AssertionError, "OBJECT_CONFIG_JSON"):
        status.assert_consumed()
    else:
      # When graph building, we haven't read any keys, so we don't know
      # whether the restore will be complete.
      with self.assertRaisesRegex(AssertionError, "not restored"):
        status.assert_consumed()
    status.run_restore_ops()
    self._check_sentinels(root)
    # A second restore driven through initialize_or_restore() must also work.
    self._set_sentinels(root)
    status = object_saver.restore(save_path)
    status.initialize_or_restore()
    self._check_sentinels(root)
def testDynamicShapeVariableWithCallableInit(self):
  """Adagrad can update a variable whose static shape is not fully defined.

  `var0` is created with `validate_shape=False`, and the learning rate is
  supplied as a callable instead of a number.
  """
  var0 = variable_scope.get_variable(
      "var0", initializer=constant_op.constant(1.), validate_shape=False)
  self.assertFalse(var0.shape.is_fully_defined())

  grads0 = constant_op.constant(0.1, dtype=dtypes.float32)
  learning_rate = lambda: 3.0

  ada_opt = adagrad.AdagradOptimizer(
      learning_rate, initial_accumulator_value=0.1, use_locking=True)

  eager = context.executing_eagerly()
  if not eager:
    # Graph mode: build the update op once and re-run it every step.
    ada_update = ada_opt.apply_gradients(zip([grads0], [var0]))
  self.evaluate(variables.global_variables_initializer())

  # The variable starts at its initializer value.
  self.assertAllClose([1.0], self.evaluate([var0]))

  # Three Adagrad steps.
  for _ in range(3):
    if eager:
      ada_opt.apply_gradients(zip([grads0], [var0]))
    else:
      self.evaluate(ada_update)

  self.assertAllCloseAccordingToType(
      np.array([-1.6026098728179932]), self.evaluate([var0]))
def testAddWeight(self):
  """`Layer.add_variable` tracks trainable, frozen and regularized variables."""
  layer = base_layers.Layer(name='my_layer')
  graph_mode = not context.executing_eagerly()

  # A plain variable is trainable and scoped under the layer's name.
  trainable_var = layer.add_variable(
      'my_var', [2, 2], initializer=init_ops.zeros_initializer())
  self.assertEqual(trainable_var.name, 'my_layer/my_var:0')
  self.assertEqual(layer.variables, [trainable_var])
  self.assertEqual(layer.trainable_variables, [trainable_var])
  self.assertEqual(layer.non_trainable_variables, [])
  if graph_mode:
    self.assertEqual(
        layer.variables,
        ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES))

  # A non-trainable variable; add_variable must also work outside `build`
  # and `call`.
  frozen_var = layer.add_variable(
      'non_trainable_var', [2, 2],
      initializer=init_ops.zeros_initializer(),
      trainable=False)
  self.assertEqual(layer.variables, [trainable_var, frozen_var])
  self.assertEqual(layer.trainable_variables, [trainable_var])
  self.assertEqual(layer.non_trainable_variables, [frozen_var])
  if graph_mode:
    self.assertEqual(
        len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 1)

  # A regularized variable contributes one entry to `layer.losses`.
  # NOTE(review): an earlier comment said regularizers are GRAPH-mode only,
  # yet this section runs in every mode — confirm.
  regularizer = lambda x: math_ops.reduce_sum(x) * 1e-3
  layer.add_variable(
      'reg_var', [2, 2],
      initializer=init_ops.zeros_initializer(),
      regularizer=regularizer)
  self.assertEqual(len(layer.losses), 1)
def _test_basic_sgd_with_learning_rate_decay(self, sgd, dtype):
  """Runs two steps of `sgd` and checks the decayed updates.

  The expected values assume a learning rate of 3.0 on the first step and
  2.0 on the second. The caller presumably configures `sgd` with that decay
  schedule — confirm against callers.

  Args:
    sgd: optimizer under test.
    dtype: dtype of the variables and gradients.
  """
  var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
  var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
  grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
  grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
  # A fresh zip per call, since zip objects are single-use.
  grads_and_vars = lambda: zip([grads0, grads1], [var0, var1])

  graph_mode = not context.executing_eagerly()
  if graph_mode:
    sgd_op = sgd.apply_gradients(grads_and_vars())
  self.evaluate(variables.global_variables_initializer())

  def take_step():
    if graph_mode:
      self.evaluate(sgd_op)
    else:
      sgd.apply_gradients(grads_and_vars())

  # Step 1 (learning rate 3.0).
  take_step()
  self.assertAllCloseAccordingToType([1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1],
                                     self.evaluate(var0))
  self.assertAllCloseAccordingToType([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01],
                                     self.evaluate(var1))

  # Step 2 (learning rate 2.0).
  take_step()
  self.assertAllCloseAccordingToType(
      [1.0 - 3.0 * 0.1 - 2.0 * 0.1, 2.0 - 3.0 * 0.1 - 2.0 * 0.1],
      self.evaluate(var0))
  self.assertAllCloseAccordingToType(
      [3.0 - 3.0 * 0.01 - 2.0 * 0.01, 4.0 - 3.0 * 0.01 - 2.0 * 0.01],
      self.evaluate(var1))
def test_apply_gradients(self):
  """LossScaleOptimizer skips updates whose gradient is NaN or Inf.

  The gradients NaN, Inf and 0.1 are applied in turn to `x` (initially 1.)
  with learning rate 1. Only the last, finite gradient should change `x`.
  """
  x = variable_scope.get_variable("x", initializer=1., dtype=dtypes.float32)
  grads_dataset = dataset_ops.Dataset.from_tensor_slices(
      [np.nan, np.inf, 0.1])
  itr = grads_dataset.make_one_shot_iterator()

  lr = 1
  base_opt = gd.GradientDescentOptimizer(lr)
  loss_scale_manager = lsm_lib.FixedLossScaleManager(1.e4)
  opt = lso.LossScaleOptimizer(base_opt, loss_scale_manager)

  def train_fn():
    return opt.apply_gradients([(itr.get_next(), x)])

  eager = context.executing_eagerly()
  train_op = None if eager else train_fn()

  expected_output = [1, 1, 1 - 0.1]
  actual_output = []
  self.evaluate(variables.global_variables_initializer())
  for _ in range(3):
    # A NaN or Inf gradient must leave `x` unchanged.
    if eager:
      train_fn()
    else:
      self.evaluate(train_op)
    actual_output.append(self.evaluate(x))
  self.assertAllClose(expected_output, actual_output)
def _test_summary_for_replica_zero_only(self, d):
  """Checks that only replica 0's summary write reaches the event file.

  Every replica calls `summary_ops.write("a", replica_id)`. The event file is
  then expected to hold exactly two events (a file header plus one value),
  and that value is 0.0.

  Args:
    d: the distribution strategy under test. It is used through `scope()`,
      `extended.call_for_each_replica` and `unwrap`.
  """
  # NOTE(review): this temp directory is never removed; self.get_temp_dir()
  # may be preferable — confirm.
  logdir = tempfile.mkdtemp()

  def run_fn():
    """Function executed for each replica."""
    # `summary_writer` is a free variable looked up at call time. It is
    # assigned below, before call_for_each_replica invokes this function.
    with summary_writer.as_default():
      replica_id = ds_context.get_replica_context().replica_id_in_sync_group
      return summary_ops.write("a", replica_id)

  with self.cached_session() as sess, d.scope(), \
      summary_ops.always_record_summaries():
    # We need global_step because summary writing op *always* has global_step
    # as input, even when we always record summary or never record summary.
    global_step = training_util.get_or_create_global_step()
    if not context.executing_eagerly():
      # When executing eagerly, variables are initialized immediately after
      # creation, and its initializer will be None.
      global_step.initializer.run()
    summary_ops.set_step(0)
    summary_writer = summary_ops.create_file_writer(logdir)
    output = d.extended.call_for_each_replica(run_fn)
    unwrapped = d.unwrap(output)
    if not context.executing_eagerly():
      # Graph mode: the writer init, the writes and the close are ops that
      # must be run explicitly, in this order.
      sess.run(summary_writer.init())
      sess.run(unwrapped)
      sess.run(summary_writer.close())

    events = _events_from_logdir(self, logdir)
    # There will be 2 entries: 1 summary file header entry, and 1 entry
    # written by replica 0.
    self.assertLen(events, 2)
    self.assertEqual(events[1].summary.value[0].tag, "a")
    self.assertEqual(events[1].summary.value[0].simple_value, 0.0)
def testBasicWithLearningRateDecay(self):
  """SGD with `decay=0.5` applies two decayed steps, for each float dtype.

  The expected values encode a learning rate of 3.0 on the first step and
  2.0 on the second.
  """
  for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
    with self.cached_session():
      var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
      var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
      grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
      grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
      learning_rate = 3.0
      decay = 0.5
      sgd = gradient_descent.SGD(learning_rate=learning_rate, decay=decay)

      graph_mode = not context.executing_eagerly()
      if graph_mode:
        sgd_op = sgd.apply_gradients(zip([grads0, grads1], [var0, var1]))
      self.evaluate(variables.global_variables_initializer())

      # One (expected var0, expected var1) pair per step, in order.
      expectations = [
          ([1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1],
           [3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01]),
          ([1.0 - 3.0 * 0.1 - 2.0 * 0.1, 2.0 - 3.0 * 0.1 - 2.0 * 0.1],
           [3.0 - 3.0 * 0.01 - 2.0 * 0.01, 4.0 - 3.0 * 0.01 - 2.0 * 0.01]),
      ]
      for expected0, expected1 in expectations:
        if graph_mode:
          self.evaluate(sgd_op)
        else:
          sgd.apply_gradients(zip([grads0, grads1], [var0, var1]))
        self.assertAllCloseAccordingToType(expected0, self.evaluate(var0))
        self.assertAllCloseAccordingToType(expected1, self.evaluate(var1))
def _test_helper(self,
                 inputs,
                 expected_outputs,
                 initial_loss_scale=1.,
                 increment_period=2,
                 multiplier=2):
  """Steps a DynamicLossScale once per element of `inputs`.

  Each element is read as the step's `is_finite` flag. The gradient for that
  step is built from it, and the flag must match `should_apply_gradients`.
  After every update the current loss scale is recorded, and the recorded
  sequence must equal `expected_outputs`.
  """
  loss_scale = loss_scale_module.DynamicLossScale(
      initial_loss_scale=initial_loss_scale,
      increment_period=increment_period,
      multiplier=multiplier)
  itr = _get_example_iter(inputs)
  eager = context.executing_eagerly()

  def update():
    is_finite = itr.get_next()
    grad = self._get_tensor(is_finite)
    update_op, should_apply_gradients = loss_scale.update([grad])
    assert_op = check_ops.assert_equal(should_apply_gradients, is_finite)
    if eager:
      # Eager ops have already run; there is nothing to return.
      return None
    with ops.control_dependencies([assert_op]):
      return array_ops.identity(update_op)

  if not eager:
    update_op = update()
  self.evaluate(variables.global_variables_initializer())

  actual_outputs = []
  for _ in range(len(inputs)):
    if eager:
      update()
    else:
      self.evaluate(update_op)
    actual_outputs.append(self.evaluate(loss_scale()))
  self.assertEqual(actual_outputs, expected_outputs)
def testCriticalSectionInParallelDoesntDeadlockOnError(self):
  """Concurrent CriticalSection runs that hit an Assert fail; they don't hang.

  `fn(i)` asserts that `i` is odd, so the execution with i == 0 always
  errors. The run is repeated 100 times and must raise an op error each time.
  """
  # No eager mode execution of this test because eager does not
  # run fn() in parallel, which is where the deadlock could
  # potentially occur (in graph mode).
  cs = critical_section_ops.CriticalSection(shared_name="cs")
  v = resource_variable_ops.ResourceVariable(0.0, name="v")

  def fn(i):
    error = control_flow_ops.Assert((i % 2) == 1, ["Error"])
    with ops.control_dependencies([error]):
      return v.read_value()

  num_concurrent = 2

  @def_function.function(autograph=False)
  def run_concurrently():
    return [cs.execute(lambda: fn(i)) for i in range(num_concurrent)]

  eager = context.executing_eagerly()
  # Graph mode: trace once and keep the resulting tensors under their own
  # name instead of shadowing the function.
  concurrent_outputs = None if eager else run_concurrently()
  self.evaluate(v.initializer)
  for _ in range(100):
    with self.assertRaisesOpError("Error"):
      if eager:
        run_concurrently()
      else:
        self.evaluate(concurrent_outputs)
def testVariablesAcrossGraphs(self):
  """`optimizer.variables()` only reports variables of the current graph.

  A single MomentumOptimizer minimizes a loss in two separate graphs. In each
  graph it must report exactly two variables, and their names must start
  with the names of that graph's own model variables.
  """
  optimizer = momentum_lib.MomentumOptimizer(0.01, 0.5)

  def check_in_fresh_graph(first_name, second_name):
    """Minimizes sum(first + second) in a new graph and checks variables."""
    with ops.Graph().as_default():
      first = resource_variable_ops.ResourceVariable(
          [1.0, 2.0], dtype=dtypes.float32, name=first_name)
      second = resource_variable_ops.ResourceVariable(
          [3.0, 4.0], dtype=dtypes.float32, name=second_name)
      if context.executing_eagerly():
        loss = lambda: math_ops.reduce_sum(first + second)
      else:
        loss = math_ops.reduce_sum(first + second)
      optimizer.minimize(loss)
      optimizer_variables = optimizer.variables()
      self.assertStartsWith(optimizer_variables[0].name, first_name)
      self.assertStartsWith(optimizer_variables[1].name, second_name)
      # `assertEquals` is a deprecated alias that Python 3.12 removed.
      self.assertEqual(2, len(optimizer_variables))

  check_in_fresh_graph("var0", "var1")
  check_in_fresh_graph("var2", "var3")
def testTrainNetwork(self, distribution, optimizer_fn,
                     use_callable_loss=True):
  """Trains a one-layer example model under `distribution` for 10 steps.

  After every step the kernel and bias are recorded. The test asserts that
  |kernel + bias - 1| never increases from one step to the next.

  Args:
    distribution: the distribution strategy under test.
    optimizer_fn: forwarded to `minimize_loss_example` (presumably an
      optimizer factory — confirm).
    use_callable_loss: forwarded to `minimize_loss_example`.
  """
  with distribution.scope():
    model_fn, dataset_fn, layer = minimize_loss_example(
        optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss)

    ds = distribution.distribute_dataset(dataset_fn)
    if context.executing_eagerly():
      iterator = ds.make_one_shot_iterator()
    else:
      # Graph mode: the iterator is initialized explicitly below.
      iterator = ds.make_initializable_iterator()

    def run_step():
      # NOTE(review): `run_concurrently` follows `layer.built`, presumably so
      # that the call which builds the layer is not run concurrently — confirm.
      return control_flow_ops.group(distribution.unwrap(
          distribution.call_for_each_tower(
              model_fn, iterator.get_next(), run_concurrently=layer.built)))

    if not context.executing_eagerly():
      with self.cached_session() as sess:
        sess.run(iterator.initializer)
        # Rebind `run_step` to a session callable for the graph op, so the
        # training loop below is the same in both modes.
        run_step = sess.make_callable(run_step())
    self.evaluate(variables.global_variables_initializer())

    weights, biases = [], []
    for _ in range(10):
      run_step()

      weights.append(self.evaluate(layer.kernel))
      biases.append(self.evaluate(layer.bias))

    # Per-step |kernel + bias - 1|; it must be non-increasing.
    error = abs(numpy.add(numpy.squeeze(weights), numpy.squeeze(biases)) - 1)
    is_not_increasing = all(y <= x for x, y in zip(error, error[1:]))
    self.assertTrue(is_not_increasing)
def _test_helper(self,
                 inputs,
                 expected_outputs,
                 init_loss_scale=1,
                 incr_every_n_step=2,
                 decr_every_n_nan_or_inf=2):
  """Steps an ExponentialUpdateLossScaleManager once per element of `inputs`.

  The manager scales up by 2 and down by 1/2 according to its schedule.
  After every update the current loss scale is recorded, and the recorded
  sequence must equal `expected_outputs`.
  """
  ratio = 2
  lsm = lsm_lib.ExponentialUpdateLossScaleManager(
      init_loss_scale=init_loss_scale,
      incr_every_n_steps=incr_every_n_step,
      decr_every_n_nan_or_inf=decr_every_n_nan_or_inf,
      incr_ratio=ratio,
      decr_ratio=1. / ratio)
  itr = _GetExampleIter(inputs)

  def update_fn():
    return lsm.update_loss_scale(itr.get_next())

  self.evaluate(variables.global_variables_initializer())

  eager = context.executing_eagerly()
  # Graph mode: build the update op once and re-run it each step.
  update_op = None if eager else update_fn()
  actual_outputs = []
  for _ in range(len(inputs)):
    if eager:
      update_fn()
    else:
      self.evaluate(update_op)
    actual_outputs.append(self.evaluate(lsm.get_loss_scale()))
  self.assertEqual(actual_outputs, expected_outputs)
def add_variable(self, name, shape=None, dtype=None, initializer=None):
  """***Only for use by descendants of Metric***.

  Creates a non-trainable resource variable through
  `variable_scope.get_variable` and registers it with this metric. In graph
  mode the variable is added to GLOBAL_VARIABLES or LOCAL_VARIABLES
  (depending on `self._use_global_variables`) plus METRIC_VARIABLES. In
  eager mode its current value is also remembered in `self._initial_values`.

  Raises:
    RuntimeError: if the metric has already been built.
  """
  if self._built:
    raise RuntimeError("Can't call add_variable() except in build().")
  eager = context.executing_eagerly()
  if eager:
    collections = None
  else:
    scope_collection = (ops.GraphKeys.GLOBAL_VARIABLES
                        if self._use_global_variables
                        else ops.GraphKeys.LOCAL_VARIABLES)
    collections = [scope_collection, ops.GraphKeys.METRIC_VARIABLES]
  # Variables are Checkpointable dependencies of Metrics regardless of the
  # global/local distinction. Users can avoid saving variables by not adding a
  # dependency on the Metric.
  v = self._add_variable_with_custom_getter(
      name=name,
      shape=shape,
      dtype=dtype,
      initializer=initializer,
      trainable=False,
      collections=collections,
      use_resource=True,
      getter=variable_scope.get_variable,
      # Raise duplicate variable exceptions from get_variable rather than
      # Checkpointable.
      overwrite=True)
  self._vars.append(v)
  if eager:
    self._initial_values[v] = v.value()
  return v
def testSaveRestoreMultipleIterator(self):
  """A checkpoint restores each of several iterators to its own position.

  iterator_1 and iterator_2 share a squared, batched dataset; iterator_3
  walks range(10). All three are advanced differently before and after the
  save, and after the restore each must resume where it was when saved.
  """
  checkpoint_directory = self.get_temp_dir()
  checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
  dataset = dataset_ops.Dataset.from_tensor_slices(
      [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
  dataset = dataset.map(math_ops.square).batch(2)

  def next_fn(iterator):
    # Eagerly, `get_next` itself yields values. In graph mode, build the
    # get_next op once and evaluate it on each call.
    if context.executing_eagerly():
      return iterator.get_next
    return functools.partial(self.evaluate, iterator.get_next())

  iterator_1 = dataset.make_one_shot_iterator()
  get_next_1 = next_fn(iterator_1)
  iterator_2 = dataset.make_one_shot_iterator()
  get_next_2 = next_fn(iterator_2)
  dataset_2 = dataset_ops.Dataset.range(10)
  iterator_3 = dataset_2.make_one_shot_iterator()
  get_next_3 = next_fn(iterator_3)
  checkpoint = checkpointable_utils.Checkpoint(
      iterator_1=iterator_1, iterator_2=iterator_2, iterator_3=iterator_3)

  # Advance before saving.
  self.assertAllEqual([1, 4], get_next_1())
  self.assertAllEqual(0, get_next_3())
  self.assertAllEqual(1, get_next_3())
  self.assertAllEqual(2, get_next_3())
  save_path = checkpoint.save(checkpoint_prefix)

  # Advance after saving; these positions must be rolled back.
  self.assertAllEqual([1, 4], get_next_2())
  self.assertAllEqual([9, 16], get_next_2())
  self.assertAllEqual(3, get_next_3())

  checkpoint.restore(save_path).run_restore_ops()
  self.assertAllEqual([9, 16], get_next_1())
  self.assertAllEqual([1, 4], get_next_2())
  self.assertAllEqual(3, get_next_3())
def test_dropout_mask_reuse(self):
  """SimpleRNN applies one input dropout mask to every timestep of a batch.

  The layer is created with recurrent_initializer='zeros', so the recurrent
  state does not affect the output. Identical outputs at timesteps 0 and 1
  therefore mean the same mask was applied at each timestep.
  """
  rnn = keras.layers.SimpleRNN(3, dropout=0.5, kernel_initializer='ones',
                               recurrent_initializer='zeros',
                               return_sequences=True, unroll=True)
  inputs = constant_op.constant(1.0, shape=(6, 2, 5))
  out = rnn(inputs, training=True)
  eager = context.executing_eagerly()
  if not eager:
    self.evaluate(variables_lib.global_variables_initializer())

  def first_two_timesteps(tensor):
    values = self.evaluate(tensor)
    return values[:, 0, :], values[:, 1, :]

  batch_1_t0, batch_1_t1 = first_two_timesteps(out)
  self.assertAllClose(batch_1_t0, batch_1_t1)

  # Simulate calling the layer on a second batch. In eager mode that means
  # a new call; in graph mode the same output tensor is evaluated again.
  out2 = rnn(inputs, training=True) if eager else out
  batch_2_t0, batch_2_t1 = first_two_timesteps(out2)
  self.assertAllClose(batch_2_t0, batch_2_t1)

  # Different batches must get different dropout masks.
  self.assertNotAllClose(batch_1_t0, batch_2_t0)
  self.assertNotAllClose(batch_1_t1, batch_2_t1)
def testRequestNotToCompile(self):
  """With `_XlaCompile: False`, per-op device requests are honored.

  A wholly compiled function places every output on the test device. The
  op-by-op variant keeps `y` on the CPU it was explicitly placed on.
  """
  with self.test_scope():
    def f(x):
      with ops.device('device:CPU:0'):
        y = 2.0 * x
      return x, y

    wholly_compiled_f = def_function.function(f)
    op_by_op_f = function.defun_with_attributes(
        f, attributes={'_XlaCompile': False})

    x = constant_op.constant([0.0, 2.0], name='data')

    # When function is wholly compiled, all outputs will be on the
    # device on which it is run.
    r_x, r_y = wholly_compiled_f(x)
    self.assertAllEqual([0.0, 2.0], r_x)
    self.assertAllEqual([0.0, 4.0], r_y)
    if context.executing_eagerly():
      # backing_device is only available for eager tensors.
      # (`assertRegexpMatches` is a deprecated alias that Python 3.12 removed.)
      self.assertRegex(r_x.backing_device, self.device)
      self.assertRegex(r_y.backing_device, self.device)

    # When function is executed op-by-op, requested devices will be
    # respected.
    r_x, r_y = op_by_op_f(x)
    self.assertAllEqual([0.0, 2.0], r_x)
    self.assertAllEqual([0.0, 4.0], r_y)
    if context.executing_eagerly():
      # backing_device is only available for eager tensors.
      self.assertRegex(r_x.backing_device, self.device)
      self.assertRegex(r_y.backing_device, 'device:CPU:0')
def compress(self, inputs):
  """Compress inputs and store their binary representations into strings.

  The layer is built on first use if needed. Values are then shifted by a
  per-channel offset (centering the quantized CDF range on the medians),
  clipped, rounded to int16, and range-encoded one batch element at a time.

  Args:
    inputs: `Tensor` with values to be compressed.

  Returns:
    String `Tensor` vector containing the compressed representation of each
    batch element of `inputs`.
  """
  with ops.name_scope(self._name_scope()):
    inputs = ops.convert_to_tensor(inputs)
    if not self.built:
      # Check input assumptions set before layer building, e.g. input rank.
      self._assert_input_compatibility(inputs)
      if self.dtype is None:
        # Adopt the input's dtype when none was configured.
        self._dtype = inputs.dtype.base_dtype.name
      self.build(inputs.shape)

    # Check input assumptions set after layer building, e.g. input shape.
    if not context.executing_eagerly():
      self._assert_input_compatibility(inputs)

    ndim = self.input_spec.ndim
    channel_axis = self._channel_axis(ndim)
    # Tuple of slices for expanding dimensions of tensors below.
    # It holds `ndim` None entries (new axes) followed by one full slice, and
    # the entry at `channel_axis` is also a full slice.
    slices = ndim * [None] + [slice(None)]
    slices[channel_axis] = slice(None)
    slices = tuple(slices)

    # Expand dimensions of CDF to input dimensions, keeping the channels along
    # the right dimension.
    cdf = self._quantized_cdf[slices[1:]]
    num_levels = array_ops.shape(cdf)[-1] - 1

    # Bring inputs to the right range by centering the range on the medians.
    half = constant_op.constant(.5, dtype=self.dtype)
    medians = array_ops.squeeze(self._medians, [1, 2])
    offsets = (math_ops.cast(num_levels // 2, self.dtype) + half) - medians
    # Expand offsets to input dimensions and add to inputs.
    values = inputs + offsets[slices[:-1]]

    # Clip to range and cast to integers. Because we have added .5 above, and
    # all values are positive, the cast effectively implements rounding.
    values = math_ops.maximum(values, half)
    values = math_ops.minimum(
        values, math_ops.cast(num_levels, self.dtype) - half)
    values = math_ops.cast(values, dtypes.int16)

    def loop_body(tensor):
      return coder_ops.range_encode(
          tensor, cdf, precision=self.range_coder_precision)
    # map_fn iterates over the first (batch) dimension: one string per element.
    strings = functional_ops.map_fn(
        loop_body, values, dtype=dtypes.string, back_prop=False)

    if not context.executing_eagerly():
      # Pin the static batch dimension of the result to that of `inputs`
      # (presumably map_fn does not propagate it here — confirm).
      strings.set_shape(inputs.shape[:1])

    return strings
def doTestBasic(self, use_resource=False):
  """Runs 3 AdaMax steps per float dtype and compares to a NumPy reference.

  For each dtype, two variables are updated with constant gradients. After
  every step the beta1 power accumulator and both variables are checked
  against `adamax_update_numpy`.

  Args:
    use_resource: if True, use `ResourceVariable`s named "var0_<i>" and
      "var1_<i>", and also check the name of var0's "m" slot.
  """
  for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
    with self.test_session(graph=ops.Graph()):
      # Initialize variables for numpy implementation.
      m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
      var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
      grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
      var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
      grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)

      if use_resource:
        var0 = resource_variable_ops.ResourceVariable(
            var0_np, name="var0_%d" % i)
        var1 = resource_variable_ops.ResourceVariable(
            var1_np, name="var1_%d" % i)
      else:
        var0 = variables.Variable(var0_np)
        var1 = variables.Variable(var1_np)
      grads0 = constant_op.constant(grads0_np)
      grads1 = constant_op.constant(grads1_np)

      opt = adamax.AdaMaxOptimizer()
      # In eager mode this call already performs the first update, which is
      # why the loop below skips applying gradients when t == 1.
      update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
      opt_variables = opt.variables()
      beta1_power = opt._get_beta_accumulators()
      self.assertTrue(beta1_power is not None)
      self.assertIn(beta1_power, opt_variables)

      if not context.executing_eagerly():
        with ops.Graph().as_default():
          # Shouldn't return non-slot variables from other graphs.
          self.assertEqual(0, len(opt.variables()))
        self.evaluate(variables.global_variables_initializer())
        # Fetch params to validate initial values
        self.assertAllClose([1.0, 2.0], self.evaluate(var0))
        self.assertAllClose([3.0, 4.0], self.evaluate(var1))

      beta1_power = opt._get_beta_accumulators()

      # Run 3 steps of AdaMax
      for t in range(1, 4):
        if not context.executing_eagerly():
          self.evaluate(update)
        elif t > 1:
          opt.apply_gradients(zip([grads0, grads1], [var0, var1]))

        # After t updates, the expected beta1 power is 0.9**(t + 1).
        self.assertAllCloseAccordingToType(0.9**(t + 1),
                                           self.evaluate(beta1_power))

        var0_np, m0, v0 = adamax_update_numpy(var0_np, grads0_np, t, m0, v0)
        var1_np, m1, v1 = adamax_update_numpy(var1_np, grads1_np, t, m1, v1)

        # Validate updated params
        self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
        self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
        if use_resource:
          self.assertEqual("var0_%d/AdaMax:0" % (i,),
                           opt.get_slot(var=var0, name="m").name)
def tearDown(self):
  """Disables eager execution and checks that disabling is idempotent."""
  # The second disable must not raise, and eager must stay off after each.
  for _ in range(2):
    ops.disable_eager_execution()
    self.assertFalse(context.executing_eagerly())
def setUp(self):
  """Enables eager execution and checks that enabling is idempotent."""
  # The second enable must not raise, and eager must stay on after each.
  for _ in range(2):
    ops.enable_eager_execution()
    self.assertTrue(context.executing_eagerly())
def _create_variable(self, next_creator, *args, **kwargs):
  """Create a mirrored variable. See `DistributionStrategy.scope`.

  One component variable is created per device through `next_creator`. The
  first device's component supplies the base name and the initial value of
  the other replicas.

  Args:
    next_creator: the next variable creator; it is called once per device.
    *args: forwarded to `next_creator`.
    **kwargs: forwarded to `next_creator` after `collections`,
      `colocate_with` and `tower_local_reduce_method` are popped.

  Returns:
    A `values.MirroredVariable`, or a `values.TowerLocalVariable` when
    `tower_local_reduce_method` is given, wrapping the per-device components.
  """
  # Figure out what collections this variable should be added to.
  # We'll add the MirroredVariable to those collections instead.
  collections = kwargs.pop("collections", None)
  if collections is None:
    collections = [ops.GraphKeys.GLOBAL_VARIABLES]
  else:
    # Work on a copy: TRAINABLE_VARIABLES may be appended below, and that must
    # not leak into a list owned by the caller (or fail on a tuple).
    collections = list(collections)
  # The per-device components themselves are added to no collection.
  kwargs["collections"] = []

  colocate_with = kwargs.pop("colocate_with", None)
  devices = self._get_devices_from(colocate_with)

  tower_local = kwargs.pop("tower_local_reduce_method", None)
  if tower_local is not None:
    kwargs["trainable"] = False

  # TODO(josh11b,apassos): It would be better if variable initialization
  # was never recorded on the tape instead of having to do this manually
  # here.
  with tape.stop_recording():
    index = {}
    for i, d in enumerate(devices):
      with ops.device(d):
        if i > 0:
          # Give replicas meaningful distinct names:
          var0name = index[devices[0]].name.split(":")[0]
          kwargs["name"] = "%s/replica_%d" % (var0name, i)
          # Initialize replicas with the same value:
          if context.executing_eagerly():
            initial_value = index[devices[0]].value()
          else:
            initial_value = index[devices[0]].initial_value
          kwargs["initial_value"] = array_ops.identity(initial_value)
        with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
          v = next_creator(*args, **kwargs)
        assert not isinstance(v, values.DistributedVariable)
        index[d] = v

    if tower_local is None:
      result = values.MirroredVariable(index, index[devices[0]])
    else:
      result = values.TowerLocalVariable(
          index, index[devices[0]], tower_local)

  if not context.executing_eagerly():
    g = ops.get_default_graph()
    # If "trainable" is True, next_creator() will add the member variables
    # to the TRAINABLE_VARIABLES collection, so we manually remove
    # them and replace with the MirroredVariable. We can't set
    # "trainable" to False for next_creator() since that causes functions
    # like implicit_gradients to skip those variables.
    if kwargs.get("trainable", True):
      collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
      trainable_collection = g.get_collection_ref(
          ops.GraphKeys.TRAINABLE_VARIABLES)
      for component in index.values():
        trainable_collection.remove(component)
    g.add_to_collections(collections, result)
  return result
def _TestTrtGraphConverter(self,
                           input_saved_model_dir=None,
                           output_saved_model_dir=None,
                           need_calibration=False,
                           is_dynamic_op=False):
  """General method to test trt_convert.TrtGraphConverter().

  Converts the test graph with `self._ConvertGraph`. It then verifies the
  returned GraphDef and, when `output_saved_model_dir` is set, the GraphDef
  read back from the saved model. Each must contain a `TRTEngineOp_0` node
  of op type `TRTEngineOp`.

  Args:
    input_saved_model_dir: forwarded to `self._ConvertGraph`.
    output_saved_model_dir: forwarded to `self._ConvertGraph`. If set, the
      saved model written there is loaded and verified as well.
    need_calibration: forwarded to `self._ConvertGraph`, also as
      `use_function_backup`. If True, every TRTEngineOp node must carry
      non-empty `calibration_data`, and each verified graph is executed on
      sample inputs.
    is_dynamic_op: forwarded to `self._ConvertGraph`.
  """
  output_graph_def = self._ConvertGraph(
      input_saved_model_dir=input_saved_model_dir,
      output_saved_model_dir=output_saved_model_dir,
      need_calibration=need_calibration,
      is_dynamic_op=is_dynamic_op,
      use_function_backup=need_calibration)

  graph_defs_to_verify = [output_graph_def]

  if output_saved_model_dir:
    if context.executing_eagerly():
      # Eager: read the graph of the default serving signature's function.
      root = load.load(output_saved_model_dir)
      saved_model_graph_def = root.signatures[
          signature_constants
          .DEFAULT_SERVING_SIGNATURE_DEF_KEY].graph.as_graph_def()
    else:
      # Graph: read the GraphDef of the SERVING-tagged MetaGraphDef.
      saved_model_graph_def = saved_model_utils.get_meta_graph_def(
          output_saved_model_dir, tag_constants.SERVING).graph_def
    self.assertTrue(isinstance(saved_model_graph_def, graph_pb2.GraphDef))
    graph_defs_to_verify.append(saved_model_graph_def)

  for graph_def in graph_defs_to_verify:
    node_name_to_op = {node.name: node.op for node in graph_def.node}
    if context.executing_eagerly():
      # In V2 the actual graph could be inside a function.
      for func in graph_def.library.function:
        node_name_to_op.update({node.name: node.op for node in func.node_def})
      self.assertIn("TRTEngineOp_0", node_name_to_op)
      self.assertEqual("TRTEngineOp", node_name_to_op["TRTEngineOp_0"])
    else:
      # Graph mode expects exactly these three nodes.
      self.assertEqual({
          "input": "Placeholder",
          "TRTEngineOp_0": "TRTEngineOp",
          "output": "Identity"
      }, node_name_to_op)

    if need_calibration:
      trt_engine_nodes = [
          node for node in graph_def.node if node.op == "TRTEngineOp"
      ]
      self.assertNotEmpty(trt_engine_nodes)
      for node in trt_engine_nodes:
        self.assertTrue(len(node.attr["calibration_data"].s))
      # Run the calibrated graph.
      # TODO(laigd): consider having some input where the answer is different.
      with ops.Graph().as_default():
        importer.import_graph_def(graph_def, name="")
        with self.session(config=self._GetConfigProto()) as sess:
          for test_data in range(10):
            # The converted graph is expected to compute (x + 1)**2.
            self.assertEqual((test_data + 1.0)**2,
                             sess.run(
                                 "output:0",
                                 feed_dict={"input:0": [[[test_data]]]}))
def create_file_writer_v2(logdir,
                          max_queue=None,
                          flush_millis=None,
                          filename_suffix=None,
                          name=None):
  """Creates a summary file writer for the given log directory.

  Args:
    logdir: a string specifying the directory in which to write an event file.
    max_queue: the largest number of summaries to keep in a queue; will
     flush once the queue gets bigger than this. Defaults to 10.
    flush_millis: the largest interval between flushes. Defaults to 120,000.
    filename_suffix: optional suffix for the event file name. Defaults to `.v2`.
    name: a name for the op that creates the writer.

  Returns:
    A SummaryWriter object (concretely, a `ResourceSummaryWriter`).

  Raises:
    ValueError: if `logdir` is None.
  """
  if logdir is None:
    raise ValueError("logdir cannot be None")
  # Captured before entering init_scope(), which lifts the work below out of
  # any enclosing tf.function; the argument check needs the original context.
  inside_function = ops.inside_function()
  with ops.name_scope(name, "create_file_writer") as scope, ops.device("cpu:0"):
    # Run init inside an init_scope() to hoist it out of tf.functions.
    with ops.init_scope():
      if context.executing_eagerly():
        _check_create_file_writer_args(
            inside_function,
            logdir=logdir,
            max_queue=max_queue,
            flush_millis=flush_millis,
            filename_suffix=filename_suffix)
      logdir = ops.convert_to_tensor(logdir, dtype=dtypes.string)
      if max_queue is None:
        max_queue = constant_op.constant(10)
      if flush_millis is None:
        flush_millis = constant_op.constant(2 * 60 * 1000)
      if filename_suffix is None:
        filename_suffix = constant_op.constant(".v2")
      # Prepend the PID and a process-local UID to the filename suffix to avoid
      # filename collisions within the machine (the filename already contains
      # the hostname to avoid cross-machine collisions).
      unique_prefix = constant_op.constant(".%s.%s" % (os.getpid(), ops.uid()))
      filename_suffix = unique_prefix + filename_suffix
      # Use a unique shared_name to prevent resource sharing.
      if context.executing_eagerly():
        shared_name = context.shared_name()
      else:
        shared_name = ops.name_from_scope_name(scope)  # pylint: disable=protected-access
      # The writer resource is created later through `init_op_fn`, using the
      # tensor arguments prepared above.
      return ResourceSummaryWriter(
          shared_name=shared_name,
          init_op_fn=functools.partial(
              gen_summary_ops.create_summary_file_writer,
              logdir=logdir,
              max_queue=max_queue,
              flush_millis=flush_millis,
              filename_suffix=filename_suffix),
          name=name,
          v2=True)
def testSaveRestore(self):
  """Round-trips model and Adam optimizer state through a checkpoint.

  First checks immediate restoration into the same objects. Then, only when
  executing eagerly, checks deferred restore-on-create into a freshly built
  model/optimizer pair.
  """
  model = MyModel()
  optimizer = adam.AdamOptimizer(0.001)
  root_checkpointable = checkpointable_utils.Checkpoint(
      optimizer=optimizer, model=model)
  input_value = constant_op.constant([[3.]])
  if context.executing_eagerly():
    optimizer.minimize(
        lambda: model(input_value))
  else:
    train_op = optimizer.minimize(model(input_value))
    # TODO(allenl): Make initialization more pleasant when graph building.
    # Touch save_counter so it exists before gathering initializers.
    root_checkpointable.save_counter  # pylint: disable=pointless-statement
    self.evaluate(checkpointable_utils.gather_initializers(
        root_checkpointable))
    self.evaluate(train_op)
  prefix = os.path.join(self.get_temp_dir(), "ckpt")
  # Set known values for the dense bias and its "m" slot, then save.
  self.evaluate(state_ops.assign(model._named_dense.variables[1], [42.]))
  m_bias_slot = optimizer.get_slot(model._named_dense.variables[1], "m")
  self.evaluate(state_ops.assign(m_bias_slot, [1.5]))
  save_path = root_checkpointable.save(file_prefix=prefix)
  # Clobber the saved values so restoration is observable.
  self.evaluate(state_ops.assign(model._named_dense.variables[1], [43.]))
  self.evaluate(state_ops.assign(root_checkpointable.save_counter, 3))
  # Snapshot of the optimizer's variables, compared against the deferred
  # restore below; the final assertions show indices 0/1 are the beta
  # accumulators.
  optimizer_variables = self.evaluate(optimizer.variables())
  self.evaluate(state_ops.assign(m_bias_slot, [-2.]))
  # Immediate restoration
  status = root_checkpointable.restore(save_path=save_path).assert_consumed()
  status.run_restore_ops()
  self.assertAllEqual([42.], self.evaluate(model._named_dense.variables[1]))
  self.assertAllEqual(1, self.evaluate(root_checkpointable.save_counter))
  self.assertAllEqual([1.5], self.evaluate(m_bias_slot))
  if not context.executing_eagerly():
    return  # Restore-on-create is only supported when executing eagerly
  on_create_model = MyModel()
  on_create_optimizer = adam.AdamOptimizer(0.001)
  on_create_root = checkpointable_utils.Checkpoint(
      optimizer=on_create_optimizer, model=on_create_model)
  # Deferred restoration: restore() is called before any variables exist.
  status = on_create_root.restore(save_path=save_path)
  on_create_model(constant_op.constant([[3.]]))  # create variables
  self.assertAllEqual(1, self.evaluate(on_create_root.save_counter))
  self.assertAllEqual([42.],
                      self.evaluate(
                          on_create_model._named_dense.variables[1]))
  on_create_m_bias_slot = on_create_optimizer.get_slot(
      on_create_model._named_dense.variables[1], "m")
  # Optimizer slot variables are created when the original variable is
  # restored.
  self.assertAllEqual([1.5], self.evaluate(on_create_m_bias_slot))
  # NOTE(review): presumably [2:] skips the beta accumulators, which do not
  # exist yet on on_create_optimizer at this point -- confirm.
  self.assertAllEqual(optimizer_variables[2:],
                      self.evaluate(on_create_optimizer.variables()))
  # Creating slots for an unrelated variable also creates the beta
  # accumulators, letting the remaining deferred restorations complete.
  on_create_optimizer._create_slots(
      [resource_variable_ops.ResourceVariable([1.])])
  status.assert_consumed()
  beta1_power, beta2_power = on_create_optimizer._get_beta_accumulators()
  self.assertAllEqual(optimizer_variables[0], self.evaluate(beta1_power))
  self.assertAllEqual(optimizer_variables[1], self.evaluate(beta2_power))
def init_fn():
  """Asserts init_scope() keeps eager mode with a single eager context switch."""
  self.assertTrue(context.executing_eagerly())
  with ops.init_scope():
    self.assertTrue(context.executing_eagerly())
    switch_stack = context.context().context_switches.stack
    self.assertEqual(len(switch_stack), 1)
    only_switch = switch_stack[0]
    self.assertFalse(only_switch.is_building_function)
    self.assertEqual(only_switch.enter_context_fn, context.eager_mode)
def testAddWeight(self):
  """Checks Layer.add_variable bookkeeping for several kinds of variables."""
  layer = base_layers.Layer(name='my_layer')

  # Basic (trainable) variable creation.
  trainable_var = layer.add_variable(
      'my_var', [2, 2], initializer=init_ops.zeros_initializer())
  self.assertEqual(trainable_var.name, 'my_layer/my_var:0')
  self.assertEqual(layer.variables, [trainable_var])
  self.assertEqual(layer.trainable_variables, [trainable_var])
  self.assertEqual(layer.non_trainable_variables, [])
  if not context.executing_eagerly():
    self.assertEqual(
        layer.variables,
        ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES))

  # Non-trainable variable creation; add_variable must also work outside of
  # `build` and `call`.
  frozen_var = layer.add_variable(
      'non_trainable_var', [2, 2],
      initializer=init_ops.zeros_initializer(),
      trainable=False)
  self.assertEqual(layer.variables, [trainable_var, frozen_var])
  self.assertEqual(layer.trainable_variables, [trainable_var])
  self.assertEqual(layer.non_trainable_variables, [frozen_var])
  if not context.executing_eagerly():
    self.assertEqual(
        len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 1)

  # A regularized variable contributes one loss.
  scaled_sum_regularizer = lambda x: math_ops.reduce_sum(x) * 1e-3
  layer.add_variable(
      'reg_var', [2, 2],
      initializer=init_ops.zeros_initializer(),
      regularizer=scaled_sum_regularizer)
  self.assertEqual(len(layer.losses), 1)

  # Sync `ON_READ` variables default to non-trainable.
  on_read_var = layer.add_variable(
      'sync_on_read_var', [2, 2],
      initializer=init_ops.zeros_initializer(),
      synchronization=variable_scope.VariableSynchronization.ON_READ,
      aggregation=variable_scope.VariableAggregation.SUM)
  self.assertEqual(layer.non_trainable_variables, [frozen_var, on_read_var])

  # A regularized variable added from inside a tf.function also registers a
  # loss. The one-element list is a mutable flag shared with the closure so
  # the variable is only added once.
  weight_added = [False]

  @def_function.function
  def function_adds_weight():
    if not weight_added[0]:
      layer.add_variable(
          'reg_var_from_function', [2, 2],
          initializer=init_ops.zeros_initializer(),
          regularizer=scaled_sum_regularizer)
      weight_added[0] = True

  function_adds_weight()
  self.assertEqual(len(layer.losses), 2)
def test_apply_gradients_loss_scale_is_updated(self):
  """Checks non-finite steps are skipped and the loss scale is stepped."""

  class SimpleLossScaleManager(lsm_lib.LossScaleManager):
    """A simple loss scale manager for easier testing.

    The loss scale goes up by 1 after finite gradients and down by 1 after
    non-finite ones.
    """

    def __init__(self, loss_scale):
      self._loss_scale = variable_scope.variable(
          name="loss_scale",
          initial_value=loss_scale,
          dtype=dtypes.float32,
          trainable=False)

    def get_loss_scale(self):
      return self._loss_scale

    def update_loss_scale(self, if_finite_grads):
      grow = lambda: state_ops.assign_add(self._loss_scale, 1)
      shrink = lambda: state_ops.assign_sub(self._loss_scale, 1)
      return control_flow_ops.cond(if_finite_grads, grow, shrink)

  x = variable_scope.get_variable("x", initializer=1., dtype=dtypes.float32)
  # One gradient per step: two non-finite values followed by a finite one.
  grad_dataset = dataset_ops.Dataset.from_tensor_slices([np.nan, np.inf, 0.1])
  grad_iterator = grad_dataset.make_one_shot_iterator()
  learning_rate = 1
  initial_scale = 8
  base_optimizer = gd.GradientDescentOptimizer(learning_rate)
  scale_manager = SimpleLossScaleManager(initial_scale)
  opt = lso.LossScaleOptimizer(base_optimizer, scale_manager)

  def train_step():
    return opt.apply_gradients([(grad_iterator.get_next(), x)])

  if not context.executing_eagerly():
    train_op = train_step()
  self.evaluate(variables.global_variables_initializer())

  # Scale after each step: down, down, then up once the gradient is finite.
  expected_loss_scale = [
      initial_scale - 1, initial_scale - 2, initial_scale - 2 + 1
  ]
  # nan and inf gradients are not applied, so x only moves on the last step.
  expected_output = [1, 1, 1 - 0.1]
  actual_output = []
  for expected_scale in expected_loss_scale:
    if context.executing_eagerly():
      train_step()
    else:
      self.evaluate(train_op)
    actual_output.append(self.evaluate(x))
    self.assertAllClose(expected_scale,
                        self.evaluate(scale_manager._loss_scale))
  self.assertAllClose(expected_output, actual_output)
def skip_unsupported_test_configuration(self, distribution):
  """Skips the current test for configurations that are not supported yet.

  Args:
    distribution: the distribution strategy under test.
  """
  if should_skip_tpu_with_eager(distribution):
    self.skipTest('TPUStrategy does not support eager mode now.')
  # The remaining restrictions only apply to eager execution.
  if not context.executing_eagerly():
    return
  if self.use_numpy:
    self.skipTest('Numpy as inputs is not supported with strategy in eager.')
  if self.use_validation_data:
    self.skipTest('TODO(hongjunchoi): Add test logic for using validation '
                  'data for eager execution.')
def close(self):
  """Flushes and closes the summary writer.

  For eager v2 writers the close is remembered, so later calls return
  immediately instead of closing the resource again.
  """
  tracks_closed_state = self._v2 and context.executing_eagerly()
  if tracks_closed_state and self._closed:
    return
  try:
    # Make sure pending summaries are flushed before the writer is closed.
    with ops.control_dependencies([self.flush()]):
      with ops.device("cpu:0"):
        return gen_summary_ops.close_summary_writer(self._resource)
  finally:
    if tracks_closed_state:
      self._closed = True
def _TestTrtGraphConverter(self,
                           input_saved_model_dir=None,
                           output_saved_model_dir=None,
                           need_calibration=False,
                           is_dynamic_op=False):
  """General method to test trt_convert.TrtGraphConverter().

  Converts the test graph and verifies the converted GraphDef contains the
  expected TRTEngineOp. If `output_saved_model_dir` is given, the GraphDef
  reloaded from the saved model is verified too. With `need_calibration`,
  it also checks that every engine recorded calibration data and runs the
  calibrated graph.

  Args:
    input_saved_model_dir: optional saved model to convert from.
    output_saved_model_dir: optional directory the converted model is saved
      to; when set, the reloaded graph is verified as well.
    need_calibration: whether to run and verify INT8 calibration. Also
      forwarded as `use_function_backup` to `_ConvertGraph`.
    is_dynamic_op: forwarded to `_ConvertGraph`.
  """
  output_graph_def = self._ConvertGraph(
      input_saved_model_dir=input_saved_model_dir,
      output_saved_model_dir=output_saved_model_dir,
      need_calibration=need_calibration,
      is_dynamic_op=is_dynamic_op,
      use_function_backup=need_calibration)
  graph_defs_to_verify = [output_graph_def]

  if output_saved_model_dir:
    if context.executing_eagerly():
      root = load.load(output_saved_model_dir)
      saved_model_graph_def = root.signatures[
          signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
      ].graph.as_graph_def()
    else:
      saved_model_graph_def = saved_model_utils.get_meta_graph_def(
          output_saved_model_dir, tag_constants.SERVING).graph_def
    # assertIsInstance reports the actual type on failure.
    self.assertIsInstance(saved_model_graph_def, graph_pb2.GraphDef)
    graph_defs_to_verify.append(saved_model_graph_def)

  for graph_def in graph_defs_to_verify:
    node_name_to_op = {node.name: node.op for node in graph_def.node}
    if context.executing_eagerly():
      # In V2 the actual graph could be inside a function.
      for func in graph_def.library.function:
        node_name_to_op.update(
            {node.name: node.op for node in func.node_def})
      self.assertIn("TRTEngineOp_0", node_name_to_op)
      self.assertEqual("TRTEngineOp", node_name_to_op["TRTEngineOp_0"])
    else:
      self.assertEqual(
          {
              "input": "Placeholder",
              "TRTEngineOp_0": "TRTEngineOp",
              "output": "Identity"
          }, node_name_to_op)

    if need_calibration:
      trt_engine_nodes = [
          node for node in graph_def.node if node.op == "TRTEngineOp"
      ]
      self.assertNotEmpty(trt_engine_nodes)
      for node in trt_engine_nodes:
        self.assertNotEmpty(node.attr["calibration_data"].s)
      # Run the calibrated graph.
      # TODO(laigd): consider having some input where the answer is different.
      with ops.Graph().as_default():
        importer.import_graph_def(graph_def, name="")
        with self.session(config=self._GetConfigProto()) as sess:
          for test_data in range(10):
            self.assertEqual(
                (test_data + 1.0)**2,
                sess.run(
                    "output:0", feed_dict={"input:0": [[[test_data]]]}))
def __init__(self,
             proc_func,
             cluster_spec,
             rpc_layer=None,
             max_run_time=None,
             capture_std_stream=False,
             grpc_fail_fast=False,
             args=None,
             kwargs=None):
  """Creates a multi-process runner.

  Args:
    proc_func: Function to be run on child processes. This will be run on
      processes for all task types.
    cluster_spec: Dict for cluster spec. The following is an example of a
      cluster with three workers and two ps's.
      {"worker": ["worker0.example.com:2222",
                  "worker1.example.com:2222",
                  "worker2.example.com:2222"],
       "ps": ["ps0.example.com:2222",
              "ps1.example.com:2222"]}
    rpc_layer: RPC layer to use. The documented default is 'grpc+loas'; this
      constructor only stores the value, so a `None` is resolved elsewhere.
    max_run_time: If set, child processes are forced to exit approximately
      this many seconds after `start` is called, via the `signal.alarm()` API.
      This is best effort at the Python level: Python signal handlers do not
      run while lower-level C/C++ code executes, so the exit can be delayed
      for an arbitrarily long time.
    capture_std_stream: Boolean, whether the messages streamed to stdout and
      stderr in subprocesses are captured.
    grpc_fail_fast: Whether GRPC connections between processes should fail
      without retrying. Defaults to False.
    args: Positional arguments to be sent to functions run on processes.
    kwargs: Keyword arguments to be sent to functions run on processes.

  Raises:
    RuntimeError: if `multi_process_runner.test_main()` is not called.
    ValueError: if there is more than one chief in the `cluster_spec`.
  """
  assert cluster_spec is not None
  num_chiefs = len(cluster_spec['chief']) if 'chief' in cluster_spec else 0
  if num_chiefs > 1:
    raise ValueError('If chief exists in the cluster, there must be at most '
                     'one chief. Current `cluster_spec` has {} chiefs.'
                     .format(num_chiefs))
  assert callable(proc_func)

  if not multi_process_lib.using_context_manager():
    raise RuntimeError('`multi_process_runner` is not initialized. '
                       'Please call `multi_process_runner.test_main()` '
                       'within `if __name__ == \'__main__\':` block '
                       'in your python module to properly initialize '
                       '`multi_process_runner`.')

  self._proc_func = proc_func
  self._cluster_spec = cluster_spec
  self._rpc_layer = rpc_layer
  self._max_run_time = max_run_time
  self._capture_std_stream = capture_std_stream
  self._grpc_fail_fast = grpc_fail_fast
  # Falsy args/kwargs (including None) become empty containers.
  self._args = args or ()
  self._kwargs = kwargs or dict()
  self._outstanding_subprocess_count = 0

  # Snapshot the parent's v2 and eager settings so child processes can match.
  self._v2_enabled = tf2.enabled()
  self._executing_eagerly = context.executing_eagerly()
def __init__(self,
             initial_value=None,
             trainable=None,
             caching_device=None,
             name=None,
             dtype=None,
             constraint=None,
             add_initializers_to=None,
             lifted_initializer_graph=None,
             synchronization=None,
             aggregation=None,
             shape=None,
             **unused_kwargs):
  """Creates a variable whose initializer is lifted out of a function.

  Outside a function this simply defers to `ResourceVariable.__init__`.
  Inside a function the initial value is computed in the function, and then:
    * if the outermost (init_scope) context is a graph, the initializer is
      lifted into that graph and the assignment is built there;
    * if both the current and the outermost contexts are eager, the
      assignment runs immediately;
    * otherwise (eager outermost context, graph current context), a
      `(variable, initial_value)` pair is appended to `add_initializers_to`
      and a `cond` on `var_is_initialized_op` assigns only if the variable is
      still uninitialized.

  Args:
    initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
      which is the initial value for the Variable. Can also be a callable with
      no argument that returns the initial value when called. (Note that
      initializer functions from init_ops.py must first be bound to a shape
      before being used here.) A `trackable.CheckpointInitialValue` is
      unwrapped and its restore UID recorded.
    trainable: If `True`, GradientTapes automatically watch uses of this
      Variable.
    caching_device: Optional device string or function describing where the
      Variable should be cached for reading. Defaults to the Variable's
      device. If not `None`, caches on another device. Typical use is to
      cache on the device where the Ops using the Variable reside, to
      deduplicate copying through `Switch` and other conditional statements.
    name: Optional name for the variable. Defaults to `'Variable'` and gets
      uniquified automatically.
    dtype: If set, initial_value will be converted to the given type.
      If None, either the datatype will be kept (if initial_value is
      a Tensor) or float32 will be used (if it is a Python object convertible
      to a Tensor).
    constraint: An optional projection function to be applied to the variable
      after being updated by an `Optimizer` (e.g. used to implement norm
      constraints or value constraints for layer weights). The function must
      take as input the unprojected Tensor representing the value of the
      variable and return the Tensor for the projected value
      (which must have the same shape). Constraints are not safe to
      use when doing asynchronous distributed training. Must be callable.
    add_initializers_to: if not None and not in legacy graph mode, a list to
      which a `(variable, initializer_tensor)` pair is appended, in addition
      to adding the conditional assignment to the function.
    lifted_initializer_graph: FuncGraph to try to lift initializers to.
      NOTE(review): not referenced in this constructor; confirm it is still
      needed.
    synchronization: Indicates when a distributed variable will be
      aggregated. Accepted values are constants defined in the class
      `tf.VariableSynchronization`. By default the synchronization is set to
      `AUTO` and the current `DistributionStrategy` chooses
      when to synchronize.
    aggregation: Indicates how a distributed variable will be aggregated.
      Accepted values are constants defined in the class
      `tf.VariableAggregation`.
    shape: (optional) The shape of this variable. If None, the shape of
      `initial_value` will be used. When setting this argument to
      `tf.TensorShape(None)` (representing an unspecified shape), the variable
      can be assigned with values of different shapes.

  Raises:
    ValueError: If called inside a function and the initial value is not
      specified, or if `constraint` is given but is not callable.
  """
  with ops.init_scope():
    # "Graph mode" here refers to the outermost (init_scope) context.
    self._in_graph_mode = not context.executing_eagerly()
  if not ops.inside_function():
    # If we've been init_scope()d out of the function definition nothing to do
    # here; we can't really do the capturing or conditional logic.
    # NOTE(review): synchronization, aggregation, shape and unused_kwargs are
    # not forwarded on this path.
    resource_variable_ops.ResourceVariable.__init__(
        self, initial_value=initial_value, trainable=trainable,
        caching_device=caching_device, name=name, dtype=dtype,
        constraint=constraint)
    return
  if initial_value is None:
    raise ValueError("initial_value must be specified.")
  init_from_fn = callable(initial_value)

  if constraint is not None and not callable(constraint):
    raise ValueError("The `constraint` argument must be a callable.")

  if isinstance(initial_value, trackable.CheckpointInitialValue):
    # Record the restore UID, then continue with the wrapped plain value.
    self._maybe_initialize_trackable()
    self._update_uid = initial_value.checkpoint_position.restore_uid
    initial_value = initial_value.wrapped_value

  with ops.name_scope(name, "Variable", []
                      if init_from_fn else [initial_value]) as scope_name:
    # ops.device(None) clears any enclosing device for the initializer ops.
    with ops.name_scope("Initializer"), ops.device(None):
      initial_value = ops.convert_to_tensor(
          initial_value() if init_from_fn else initial_value,
          name="initial_value", dtype=dtype)
    assert initial_value is not None

    # Don't use `shape or initial_value.shape` since TensorShape has
    # overridden `__bool__`.
    if shape is None:
      shape = initial_value.shape

  # Use the constructor for UninitializedVariable to start. Outside the name
  # scope so we don't double up the prefix.
  super(UnliftedInitializerVariable, self).__init__(
      trainable=trainable,
      caching_device=caching_device,
      name=name,
      shape=shape,
      dtype=initial_value.dtype,
      constraint=constraint,
      synchronization=synchronization,
      aggregation=aggregation,
      extra_handle_data=initial_value,
      **unused_kwargs)

  with ops.name_scope(scope_name):
    if self._in_graph_mode:
      # Outermost context is a graph: lift the initializer computation out of
      # the function graph into that outer graph.
      with ops.init_scope():
        outer_graph = ops.get_default_graph()
      func_graph = ops.get_default_graph()
      # Function inputs and captured tensors cannot be lifted, so the lift
      # must not depend on them.
      function_placeholders = (
          func_graph.inputs + func_graph.internal_captures)
      placeholder_ops = set(
          [tensor.op for tensor in function_placeholders])
      lifted_initializer = lift_to_graph.lift_to_graph(
          [initial_value], outer_graph,
          disallowed_placeholders=placeholder_ops)[initial_value]
      with ops.init_scope():
        self._initial_value = lifted_initializer
        with ops.name_scope("IsInitialized"):
          self._is_initialized_op = (
              resource_variable_ops.var_is_initialized_op(self._handle))
        # NOTE(review): always true here given the assert above.
        if initial_value is not None:
          with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
            self._initializer_op = resource_variable_ops.assign_variable_op(
                self._handle, lifted_initializer, name=n)
    elif context.executing_eagerly():
      # In this case, both current scope and init scope are eager.
      # Assign_variable_op will be executed immediately. So we don't need to
      # add it to "add_initializers_to" to lift it out.
      with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
        resource_variable_ops.assign_variable_op(
            self._handle, initial_value, name=n)
    else:
      # Init scope is eager but current scope is graph. We will lift out this
      # variable by adding it to "add_initializers_to".
      if add_initializers_to is not None:
        add_initializers_to.append((self, initial_value))

      def assign_fn():
        with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
          resource_variable_ops.assign_variable_op(
              self._handle,
              initial_value,
              name=n)
        # Returning values to keep tf.cond happy.
        return ops.convert_to_tensor(1)

      def not_assign_fn():
        return ops.convert_to_tensor(0)
      # Note: this cond is always guaranteed to run because we're inside a
      # defun which will insert automatic control dependencies. It will only
      # execute assign_fn if lifting failed.
      graph = ops.get_default_graph()

      # Capture the handle ahead of time in order to avoid querying the shape
      # of the handle which helps async execution performance
      graph.capture(self._handle, shape=())
      control_flow_ops.cond(
          resource_variable_ops.var_is_initialized_op(self._handle),
          not_assign_fn, assign_fn)
def _add_variable_with_custom_getter(self, name, shape=None,
                                     dtype=dtypes.float32,
                                     initializer=None, getter=None,
                                     overwrite=False,
                                     **kwargs_for_getter):
  """Restore-on-create for a variable to be saved with this `Checkpointable`.

  If the user has requested that this object, or another `Checkpointable`
  that depends on it, be restored from a checkpoint (deferred loading before
  the variable object exists), `initializer` may be ignored and the value from
  the checkpoint used instead.

  Args:
    name: A name for the variable. Must be unique within this object.
    shape: The shape of the variable.
    dtype: The data type of the variable.
    initializer: The initializer to use. Ignored if there is a deferred
      restoration left over from a call to
      `_restore_from_checkpoint_position`.
    getter: The getter to wrap which actually fetches the variable.
    overwrite: If True, disables unique name and type checks.
    **kwargs_for_getter: Passed to the getter.

  Returns:
    The new variable object.

  Raises:
    ValueError: If the variable name is not unique.
  """
  self._maybe_initialize_checkpointable()
  checkpoint_initializer = None
  with ops.init_scope():
    # Only when executing eagerly can a single-Tensor checkpoint value serve
    # directly as the initializer, instead of initializing and then assigning.
    # The call yields None when there is nothing to restore.
    if context.executing_eagerly():
      checkpoint_initializer = self._preload_simple_restoration(
          name=name, shape=shape)

  if checkpoint_initializer is not None:
    # Several Checkpointable objects may "create" the same variable through
    # custom getters. The one with the highest restore UID (the last one
    # called) must supply the final initializer, so an initializer from a
    # later restore is left alone. Getters that overwrite the initializer
    # anyway are caught by _track_checkpointable, making this best effort.
    superseded = (isinstance(initializer, CheckpointInitialValue) and
                  initializer.restore_uid > checkpoint_initializer.restore_uid)
    if not superseded:
      initializer = checkpoint_initializer
      # Shape is cleared when the checkpoint initializer is used, presumably
      # because that value carries its own shape.
      shape = None

  new_variable = getter(name=name, shape=shape, dtype=dtype,
                        initializer=initializer, **kwargs_for_getter)

  if overwrite and not isinstance(new_variable, CheckpointableBase):
    # TODO(allenl): Some variable types are not yet supported. Remove this
    # fallback once all get_variable() return types are Checkpointable.
    return new_variable
  # Tracking adds the variable as a dependency and runs any queued
  # non-trivial restoration (slot variables included). A consumed
  # initializer is not assigned again.
  return self._track_checkpointable(new_variable, name=name,
                                    overwrite=overwrite)
def initialize_tpu_system(cluster_resolver=None):
  """Initialize the TPU devices.

  Args:
    cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
        which provides information about the TPU cluster. If None, a
        resolver is created with `TPUClusterResolver("")`.
  Returns:
    The tf.tpu.Topology object for the topology of the TPU cluster. The
    topology is also recorded in `_INITIALIZED_TPU_SYSTEMS` under the TPU
    name.

  Raises:
    RuntimeError: If no TPU devices found for eager execution.
  """
  if cluster_resolver is None:
    cluster_resolver = TPUClusterResolver("")
  assert isinstance(cluster_resolver, TPUClusterResolver)

  tpu_name = compat.as_text(cluster_resolver._tpu)  # pylint: disable=protected-access
  if tpu_name in _INITIALIZED_TPU_SYSTEMS:
    # Pass tpu_name so the %s placeholder is actually filled in.
    logging.warning(
        "TPU system %s has already been initialized. "
        "Reinitializing the TPU can cause previously created "
        "variables on TPU to be lost.", tpu_name)

  logging.info("Initializing the TPU system.")

  if context.executing_eagerly():
    # This function looks the way it does for the following non-intuitive
    # reasons. tpu.initialize_system creates a dummy op whose sole purpose is
    # to trigger DistributedTPURewritePass. This pass actually adds real ops
    # that initialize the TPU system. Thus, we can't simply run
    # tpu.initialize_system eagerly. We need to wrap it in defun and trigger
    # the rewrite passes on it.
    @function.defun
    def _tpu_init_fn():
      return tpu.initialize_system()

    tpu_devices = sorted(
        x for x in context.list_devices() if "device:TPU:" in x)

    if not tpu_devices:
      raise RuntimeError("Could not find any TPU devices")

    # Replace the remote TPU device with the remote TPU_SYSTEM system device.
    # As in the remote TPU device case, we will try to compile it instead of
    # running through optimization passes and TF Executor, but TPU_SYSTEM
    # should work.
    tpu_system_device = tpu_devices[0].replace("TPU", "TPU_SYSTEM")

    with ops.device(tpu_system_device):
      output = _tpu_init_fn()
    serialized_topology = output.numpy()
  else:
    master = cluster_resolver.master()
    session_config = config_pb2.ConfigProto(allow_soft_placement=True)
    with ops.Graph().as_default():
      with session_lib.Session(config=session_config, target=master) as sess:
        serialized_topology = sess.run(tpu.initialize_system())

  logging.info("Finished initializing TPU system.")
  tpu_topology = topology.Topology(serialized=serialized_topology)
  _INITIALIZED_TPU_SYSTEMS[tpu_name] = tpu_topology

  return tpu_topology
def doTestBasic(self, use_resource=False, use_callable_params=False):
  """Runs three LazyAdam steps per dtype and compares with a numpy reference.

  Args:
    use_resource: if True, use ResourceVariables named "var0_<i>"/"var1_<i>"
      and also check the "m" slot's name; otherwise use plain Variables.
    use_callable_params: if True, pass every hyperparameter as a zero-argument
      callable instead of a Python float.
  """
  for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
    with self.session(graph=ops.Graph()):
      # Initialize variables for numpy implementation.
      m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
      var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
      grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
      var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
      grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)

      if use_resource:
        var0 = resource_variable_ops.ResourceVariable(
            var0_np, name="var0_%d" % i)
        var1 = resource_variable_ops.ResourceVariable(
            var1_np, name="var1_%d" % i)
      else:
        var0 = variables.Variable(var0_np)
        var1 = variables.Variable(var1_np)
      grads0 = constant_op.constant(grads0_np)
      grads1 = constant_op.constant(grads1_np)

      learning_rate = lambda: 0.001
      beta1 = lambda: 0.9
      beta2 = lambda: 0.999
      epsilon = lambda: 1e-8
      if not use_callable_params:
        learning_rate = learning_rate()
        beta1 = beta1()
        beta2 = beta2()
        epsilon = epsilon()

      # Pass every hyperparameter so use_callable_params covers all of them.
      # The values equal the optimizer defaults, so the numpy reference
      # (adam_update_numpy) is unchanged.
      opt = lazy_adam_optimizer.LazyAdamOptimizer(
          learning_rate=learning_rate,
          beta1=beta1,
          beta2=beta2,
          epsilon=epsilon)
      update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
      opt_variables = opt.variables()
      beta1_power, beta2_power = opt._get_beta_accumulators()
      self.assertIsNotNone(beta1_power)
      # Check the accumulator itself, not a bool (a bool is never None).
      self.assertIsNotNone(beta2_power)
      self.assertIn(beta1_power, opt_variables)
      self.assertIn(beta2_power, opt_variables)

      if not context.executing_eagerly():
        with ops.Graph().as_default():
          # Shouldn't return non-slot variables from other graphs.
          self.assertEqual(0, len(opt.variables()))
        self.evaluate(variables.global_variables_initializer())
        # Fetch params to validate initial values
        self.assertAllClose([1.0, 2.0], self.evaluate(var0))
        self.assertAllClose([3.0, 4.0], self.evaluate(var1))

      beta1_power, beta2_power = opt._get_beta_accumulators()

      # Run 3 steps of Adam
      for t in range(1, 4):
        if not context.executing_eagerly():
          self.evaluate(update)
        elif t > 1:
          # Eagerly, the apply_gradients call above already took step 1.
          opt.apply_gradients(zip([grads0, grads1], [var0, var1]))

        self.assertAllCloseAccordingToType(0.9**(t + 1),
                                           self.evaluate(beta1_power))
        self.assertAllCloseAccordingToType(0.999**(t + 1),
                                           self.evaluate(beta2_power))

        var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
        var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)

        # Validate updated params
        self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
        self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
        if use_resource:
          self.assertEqual("var0_%d/Adam:0" % (i,),
                           opt.get_slot(var=var0, name="m").name)
def saveable_objects_for_op(op, name):
  """Create `SaveableObject`s from an operation.

  Args:
    op: A variable, operation, or SaveableObject to coerce into a
      SaveableObject.
    name: A string name for the SaveableObject.

  Yields:
    `SaveableObject`s which together save/restore `op`.

  Raises:
    TypeError: If `name` is not a string, or if `op` is a tensor that does
      not come from a variable.
    ValueError: For operations with no known conversion to SaveableObject,
      for slice lists that are not consistent slices of one variable, or for
      non-resource variables when executing eagerly.
  """
  if not isinstance(name, six.string_types):
    # Wrap `name` in a tuple so a tuple-typed name cannot break the format.
    raise TypeError(
        "names_to_saveables must be a dict mapping string names to "
        "trackable operations. Name is not a string: %s" % (name,))
  if isinstance(op, saveable_object.SaveableObject):
    yield op
  elif isinstance(op, (list, tuple, variables.PartitionedVariable)):
    if isinstance(op, variables.PartitionedVariable):
      op = list(op)
    # A set of slices.
    slice_name = None
    # pylint: disable=protected-access
    for variable in op:
      if not isinstance(variable, variables.Variable):
        raise ValueError("Slices must all be Variables: %s" % variable)
      if not variable._save_slice_info:
        raise ValueError("Slices must all be slices: %s" % variable)
      if slice_name is None:
        slice_name = variable._save_slice_info.full_name
      elif slice_name != variable._save_slice_info.full_name:
        raise ValueError(
            "Slices must all be from the same tensor: %s != %s" %
            (slice_name, variable._save_slice_info.full_name))
      if variable.op.type in ("Variable", "VariableV2", "AutoReloadVariable"):
        yield ReferenceVariableSaveable(
            variable, variable._save_slice_info.spec, name)
      else:
        yield ResourceVariableSaveable(
            variable, variable._save_slice_info.spec, name)
    # pylint: enable=protected-access
  elif isinstance(op, trackable.Trackable) and not isinstance(
      op, variables.Variable):
    # pylint: disable=protected-access
    for attr, factory in op._gather_saveables_for_checkpoint().items():
      if attr == trackable.VARIABLE_VALUE_KEY:
        # Keep original name for classes masquerading as variables.
        full_name = name
      else:
        full_name = name + "_" + attr
      # Use fresh names instead of rebinding `op`, which is still the
      # Trackable being expanded.
      saveable = factory(full_name) if callable(factory) else factory
      for nested_saveable in saveable_objects_for_op(saveable, saveable.name):
        yield nested_saveable
    # pylint: enable=protected-access
  else:
    # A variable or tensor.
    if isinstance(op, resource_variable_ops.ResourceVariable):
      # pylint: disable=protected-access
      if op._in_graph_mode:
        variable = op._graph_element
      else:
        variable = op
      # pylint: enable=protected-access
      yield ResourceVariableSaveable(variable, "", name)
    else:
      if context.executing_eagerly():
        raise ValueError(
            "Can only save/restore ResourceVariables when "
            "executing eagerly, got type: %s." % type(op))

      variable = ops.internal_convert_to_tensor(op, as_ref=True)
      if not _tensor_comes_from_variable(variable):
        raise TypeError(
            "names_to_saveables must be a dict mapping string "
            "names to Tensors/Variables. Not a variable: %s" % variable)
      if variable.op.type in ("Variable", "VariableV2", "AutoReloadVariable"):
        yield ReferenceVariableSaveable(variable, "", name)
      else:
        yield ResourceVariableSaveable(variable, "", name)
def numpy(self):
  """Returns the current value via `read_value().numpy()`.

  Raises:
    NotImplementedError: If eager execution is not enabled.
  """
  if not context.executing_eagerly():
    raise NotImplementedError(
        "numpy() is only available when eager execution is enabled.")
  return self.read_value().numpy()
def __init__(self, name, read_only_collections=True):
  """Construct a new FuncGraph.

  The graph will inherit its graph key, collections, seed, and distribution
  strategy stack from the current context or graph.

  Args:
    name: the name of the function.
    read_only_collections: whether to not write function graph collections
      back to default graph. Defaults to True. When False, this graph shares
      the outer graph's collections dict. When True, collections are copied,
      except those in WHITELIST_COLLECTIONS, which are shared by reference.
  """
  super(FuncGraph, self).__init__()

  self.name = name
  self.inputs = []
  self.outputs = []
  self.structured_outputs = None
  self._read_only_collections = read_only_collections
  self._weak_variables = []
  # The graph that was the default when this FuncGraph was constructed.
  self.outer_graph = ops.get_default_graph()
  self.captures = collections.OrderedDict()

  self._building_function = True
  # Map from resource tensor name to last op (in program order) which uses
  # this tensor. Used to enforce that execution order matches program order
  # for resource tensors.
  self._last_op_using_resource_tensor = {}

  graph = self.outer_graph

  # pylint: disable=protected-access
  # TODO(b/112906995, nareshmodi): distribution strategy depends on inheriting
  # this stack from the default graph even in eager mode. Maybe it should be
  # part of the eager context? This would also allow us to remove a
  # get_default_graph() call from the function cache lookup.
  self._distribution_strategy_stack = graph._distribution_strategy_stack
  # We ignore device placements from any outer scopes while tracing the
  # function when possible, to avoid hard-coding them in the function
  # graph. "Default" placements come from the PartitionedCallOp's placement,
  # so that the same trace of the Python function may be placed on several
  # different devices and saved functions may be placed on new devices when
  # restored.
  if context.executing_eagerly():
    # Eager: seed and XLA (TPU) compilation come from the eager context; the
    # current eager device is pinned only with a strategy or when compiling.
    self.seed = context.global_seed()
    self._xla_compile = (context.context().device_spec.device_type == "TPU")
    if self._distribution_strategy_stack or self._xla_compile:
      self._add_device_to_stack(context.context().device_name)
  else:
    # Graph: seed, XLA flag and colocation stack come from the outer graph.
    self.seed = graph.seed
    self._xla_compile = getattr(graph, "_xla_compile", False)
    # TODO(allenl): Figure out if we can remove colocation stack
    # specialization (currently used in cond_v2), here and in the cache key.
    self._colocation_stack = graph._colocation_stack.copy()
    if (self._distribution_strategy_stack or self._xla_compile or
        device_stack_has_callable(graph._device_function_stack)):
      # Hard-code devices from device functions in the function body
      self._device_function_stack = graph._device_function_stack.copy()

  if not self._read_only_collections:
    # Share the outer graph's collections dict, so writes are visible there.
    self._collections = graph._collections
  else:
    # Copy non-whitelisted collections; whitelisted ones stay shared by
    # reference via get_collection_ref.
    for collection_name in graph.get_all_collection_keys():
      if collection_name not in WHITELIST_COLLECTIONS:
        self._collections[collection_name] = graph.get_collection(
            collection_name)
    for collection_name in WHITELIST_COLLECTIONS:
      self._collections[collection_name] = graph.get_collection_ref(
          collection_name)

  self._variable_creator_stack = graph._variable_creator_stack
  # Inherit the graph key, since this is used for matching variables in
  # optimizers.
  self._graph_key = graph._graph_key
def func_graph_from_py_func(name,
                            python_func,
                            args,
                            kwargs,
                            signature=None,
                            func_graph=None,
                            experimental_autograph=False,
                            add_control_dependencies=True,
                            arg_names=None,
                            op_return_value=None):
  """Returns a `FuncGraph` generated from `python_func`.

  Args:
    name: an identifier for the function.
    python_func: the Python function to trace.
    args: the positional args with which the Python function should be called;
      ignored if a signature is provided.
    kwargs: the keyword args with which the Python function should be called;
      ignored if a signature is provided.
    signature: a possibly nested sequence of `TensorSpecs` specifying the shapes
      and dtypes of the arguments. When a signature is provided, `args` and
      `kwargs` are ignored, and `python_func` is traced with Tensors conforming
      to `signature`. If `None`, the shapes and dtypes are inferred from the
      inputs.
    func_graph: Optional. An instance of FuncGraph. If provided, we will use
      this graph else a new one is built and returned.
    experimental_autograph: whether to use autograph to compile `python_func`.
      See https://www.tensorflow.org/guide/autograph for more information.
    add_control_dependencies: If True, automatically adds control dependencies
      to ensure program order matches execution order and stateful ops always
      execute.
    arg_names: Optional list of argument names, used to give input placeholders
      recognizable names.
    op_return_value: Optional. A Tensor. If set and `python_func` returns
      Operations, those return values will be replaced with this value. If not
      set, returning an Operation triggers an error.

  Returns:
    A FuncGraph.

  Raises:
    TypeError: If any of `python_func`'s return values is neither `None` nor a
      `Tensor`.
  """
  if op_return_value is not None:
    assert isinstance(op_return_value, ops.Tensor), op_return_value
  if func_graph is None:
    func_graph = FuncGraph(name)
  assert isinstance(func_graph, FuncGraph)
  if add_control_dependencies:
    control_manager = AutomaticControlDependencies
  else:
    control_manager = ops.NullContextmanager
  # `a` is only dereferenced (in `convert`) when add_control_dependencies is
  # True, i.e. when it is an AutomaticControlDependencies instance.
  with func_graph.as_default(), control_manager() as a:
    current_scope = variable_scope.get_variable_scope()
    # Force resource variables while tracing; the previous setting is restored
    # in the `finally` clause below.
    # NOTE(review): "recource" is a misspelling of "resource" (local name only).
    default_use_recource = current_scope.use_resource
    current_scope.set_use_resource(True)

    if signature is not None:
      args = signature
      kwargs = {}

    # Creates and names placeholders for all arguments.
    func_args = _get_defun_inputs_from_args(args, arg_names)
    func_kwargs = _get_defun_inputs_from_kwargs(kwargs)

    # Note: `nest.flatten` sorts by keys, as does `_deterministic_dict_values`.
    # Variables to help check whether mutation happens in calling the function
    # Copy the recursive list, tuple and map structure, but not base objects
    func_args_before = nest.pack_sequence_as(func_args, nest.flatten(func_args))
    func_kwargs_before = nest.pack_sequence_as(
        func_kwargs, nest.flatten(func_kwargs))

    def convert(x):
      """Converts a function output to a Tensor."""
      if x is None:
        return None
      if op_return_value is not None and isinstance(x, ops.Operation):
        # TODO(b/79881896): we currently can't capture external control deps, so
        # this won't work if x needs to be captured (i.e. if python_func returns
        # captured Operations).
        with ops.control_dependencies([x]):
          x = array_ops.identity(op_return_value)
      else:
        try:
          x = ops.convert_to_tensor_or_indexed_slices(x)
        except (ValueError, TypeError):
          raise TypeError(
              "To be compatible with tf.contrib.eager.defun, Python functions "
              "must return zero or more Tensors; in compilation of %s, found "
              "return value of type %s, which is not a Tensor." %
              (str(python_func), type(x)))
      if add_control_dependencies:
        x = a.mark_as_return(x)
      return x

    # The tape records variables watched while tracing; they are reported as
    # `func_graph.variables` below (minus those passed in as arguments).
    this_tape = tape.push_new_tape()
    try:
      if experimental_autograph:
        from tensorflow.python import autograph  # pylint: disable=g-import-not-at-top
        _, original_func = tf_decorator.unwrap(python_func)

        def wrapper(*args, **kwargs):
          return autograph.converted_call(
              original_func, None,
              autograph.ConversionOptions(
                  verbose=True,
                  recursive=True,
                  strip_decorators=(function.defun, def_function.function),
                  optional_features=(),
              ), *args, **kwargs)

        # Wrapping around a decorator allows checks like tf_inspect.getargspec
        # to be accurate.
        converted_func = tf_decorator.make_decorator(original_func, wrapper)
        tf_decorator.rewrap(python_func, original_func, converted_func)

      func_outputs = python_func(*func_args, **func_kwargs)

      # invariant: `func_outputs` contains only Tensors and `None`s.
      func_outputs = nest.map_structure(convert, func_outputs)

      check_mutation(func_args_before, func_args)
      check_mutation(func_kwargs_before, func_kwargs)
    finally:
      tape.pop_tape(this_tape)
      current_scope.set_use_resource(default_use_recource)

    # Variables in `func_args`, `func_kwargs` should be explicit inputs
    # to the function, not captured inputs.
    tape_variables = this_tape.watched_variables()
    arg_variables = set()
    inputs = []
    for arg in nest.flatten(func_args) + nest.flatten(func_kwargs):
      if isinstance(arg, resource_variable_ops.ResourceVariable):
        # Even if an argument variable was not used in the function, we've
        # already manually captured the resource Tensor when creating argument
        # placeholders.
        resource_placeholder = func_graph.captures.pop(arg.handle)
        arg_variables.add(arg)
        inputs.append(resource_placeholder)
      elif isinstance(arg, ops.Tensor):
        inputs.append(arg)
    variables = [v for v in tape_variables if v not in arg_variables]
    # Explicit inputs come first, followed by the remaining captured tensors.
    func_graph.inputs = inputs + list(func_graph.captures.values())

    func_graph.structured_outputs = func_outputs
    # Returning a closed-over tensor does not trigger convert_to_tensor.
    func_graph.outputs.extend(
        func_graph.capture(x)
        for x in flatten(func_graph.structured_outputs)
        if x is not None)

    func_graph.variables = variables

  # Register any other functions defined in the graph.
  with ops.init_scope():
    if context.executing_eagerly():
      for f in func_graph._functions.values():  # pylint: disable=protected-access
        # TODO(ashankar): What about the gradient registry?
        context.add_function(f._c_func.func)  # pylint: disable=protected-access

  return func_graph
def should_skip_tpu_with_eager(distribution):
  """Tells whether a test should be skipped for TPU strategies in eager mode.

  Args:
    distribution: the distribution strategy under test.

  Returns:
    True only when executing eagerly and `distribution` is a `TPUStrategy` or
    `TPUStrategyV1`; False otherwise.
  """
  if not context.executing_eagerly():
    return False
  tpu_strategy_types = (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)
  return isinstance(distribution, tpu_strategy_types)
def call(self, inputs):
  """Applies a spectrally-normalized transposed 2-D convolution.

  The kernel is flattened to 2-D and passed through `compute_spectral_norm`
  (with `self.u`), and the result is used as the transposed-convolution
  kernel. Bias and activation are applied afterwards when configured.

  Args:
    inputs: the input tensor, laid out according to `self.data_format`.

  Returns:
    The output tensor.
  """
  kernel_dims = self.kernel.shape.as_list()
  kernel_2d = tf.reshape(self.kernel, (-1, kernel_dims[-1]))
  normalized_kernel = self.compute_spectral_norm(kernel_2d, self.u, kernel_dims)

  dynamic_shape = array_ops.shape(inputs)
  batch_size = dynamic_shape[0]
  channels_first = self.data_format == 'channels_first'
  h_axis, w_axis = (2, 3) if channels_first else (1, 2)
  height = dynamic_shape[h_axis]
  width = dynamic_shape[w_axis]

  kernel_h, kernel_w = self.kernel_size
  stride_h, stride_w = self.strides
  if self.output_padding is not None:
    out_pad_h, out_pad_w = self.output_padding
  else:
    out_pad_h, out_pad_w = None, None

  # Spatial output sizes of the transposed convolution.
  out_height = conv_utils.deconv_output_length(
      height,
      kernel_h,
      padding=self.padding,
      output_padding=out_pad_h,
      stride=stride_h,
      dilation=self.dilation_rate[0])
  out_width = conv_utils.deconv_output_length(
      width,
      kernel_w,
      padding=self.padding,
      output_padding=out_pad_w,
      stride=stride_w,
      dilation=self.dilation_rate[1])

  if channels_first:
    target_shape = (batch_size, self.filters, out_height, out_width)
  else:
    target_shape = (batch_size, out_height, out_width, self.filters)
  target_shape_tensor = array_ops.stack(target_shape)

  result = K.conv2d_transpose(
      inputs,
      normalized_kernel,
      target_shape_tensor,
      strides=self.strides,
      padding=self.padding,
      data_format=self.data_format,
      dilation_rate=self.dilation_rate)

  # In graph mode the dynamic output shape loses static info; restore it.
  if not context.executing_eagerly():
    result.set_shape(self.compute_output_shape(inputs.shape))

  if self.use_bias:
    result = tf.nn.bias_add(
        result,
        self.bias,
        data_format=conv_utils.convert_data_format(self.data_format, ndim=4))

  if self.activation is None:
    return result
  return self.activation(result)
def all_gather(self, input_tensor, axis, communication_hint='AUTO', timeout=0):
  """All-gather a dense tensor.

  This method must be called inside a tf.function.

  Each replica's piece may differ in size along `axis`. The pieces are padded
  to the largest size along `axis` before gathering, and the padding is sliced
  off afterwards.

  Args:
    input_tensor: a dense tensor. It must have the same rank on all replicas,
      and dimensions other than `axis` need to be the same as well.
    axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the
      range [0, rank(value)).
    communication_hint: string providing hint to runtime for choosing
      collective implementation. Available options are `AUTO`, `NCCL`, and
      `RING`.
    timeout: a float. The timeout in seconds.

  Returns:
    The gathered Tensor.

  Raises:
    RuntimeError: if called in eager mode.
  """
  if context.executing_eagerly():
    raise RuntimeError('all_gather in eager mode is not supported')

  with ops.device(self._device):
    with ops.control_dependencies([array_ops.identity(input_tensor)]):
      # Move `axis` to the front, e.g. shape [2,2,5,1] with axis=3 becomes
      # [1,2,2,5] via perm [3,0,1,2]. It is moved back at the end.
      to_front = array_ops.concat(
          ([axis], math_ops.range(axis),
           math_ops.range(axis + 1, array_ops.rank(input_tensor))),
          axis=0)
      axis_first = array_ops.transpose(input_tensor, perm=to_front)

      # Gather every replica's shape to learn each leading-dim length, then
      # pad the local piece up to the largest one.
      shape_row = array_ops.expand_dims_v2(
          array_ops.shape_v2(axis_first), axis=0)
      all_shapes = self._all_gather(
          shape_row, communication_hint, timeout=timeout)
      lengths = all_shapes[:, 0]
      max_len = math_ops.reduce_max(lengths)
      padded = _pad_util(axis_first, max_len)

      # Gather the equally-sized padded pieces.
      gathered = self._all_gather(padded, communication_hint, timeout=timeout)

      # Keep only the real (unpadded) rows from each replica's slot.
      pieces = []
      for replica in range(self._group_size):
        begin = replica * max_len
        pieces.append(gathered[begin:begin + lengths[replica]])
      unpadded = array_ops.concat(pieces, 0)

      # Inverse of `to_front`, e.g. [1,2,3,0] for the example above.
      to_back = array_ops.concat(
          (math_ops.range(1, axis + 1), [0],
           math_ops.range(axis + 1, array_ops.rank(axis_first))),
          axis=0)
      return array_ops.transpose(unpadded, perm=to_back)
def _allow_variable_partition(self):
  """Returns True when variables may be partitioned (i.e. not eager)."""
  eager = context.executing_eagerly()
  return not eager
def _maybe_run_in_function(fn, run_in_function=False): if not run_in_function or not context.executing_eagerly(): return fn else: return def_function.function()(fn)
def testSaveNormalRestoreMirrored(self):
  """Saves normal variables and restores them into mirrored variables."""
  no_gpu = context.num_gpus() < 1
  if no_gpu and context.executing_eagerly():
    self.skipTest("A GPU is not available for this test in eager mode.")
  checkpoint_path = self._save_normal()
  self._restore_mirrored(checkpoint_path)
def _tpu_function_creator(self, fn):
  """Returns a (cached) callable that replicates `fn` across TPU replicas.

  The returned callable takes `(args, kwargs)` and returns the per-replica
  outputs regrouped with `values.regroup`. It is wrapped in
  `def_function.function` when executing eagerly. Results are cached per `fn`
  in `self._tpu_function_cache`.

  Args:
    fn: the user computation to replicate.

  Returns:
    The replicating callable for `fn`.
  """
  if fn in self._tpu_function_cache:
    return self._tpu_function_cache[fn]

  strategy = self._container_strategy()

  def tpu_function(args, kwargs):
    """TF Function used to replicate the user computation."""
    if kwargs is None:
      kwargs = {}

    # Remove None at the end of args as they are not replicatable
    # If there are None in the middle we can't do anything about it
    # so let those cases fail.
    # For example when Keras model predict is used they pass the targets as
    # None. We want to handle it here so all client libraries don't have to
    # do this as other strategies can handle None values better.
    while args and args[-1] is None:
      args = args[:-1]

    # Used to re-structure flattened output tensors from `tpu.replicate()`
    # into a structured format.
    # A one-element list so `replicated_fn` can rebind the value from inside
    # the closure.
    result = [[]]

    def replicated_fn(replica_id, replica_args, replica_kwargs):
      """Wraps user function to provide replica ID and `Tensor` inputs."""
      with _TPUReplicaContext(strategy, replica_id_in_sync_group=replica_id):
        result[0] = fn(*replica_args, **replica_kwargs)
      return result[0]

    replicate_inputs = []  # By replica.
    for i in range(strategy.num_replicas_in_sync):
      replicate_inputs.append(
          [constant_op.constant(i, dtype=dtypes.int32),
           values.select_replica(i, args),
           values.select_replica(i, kwargs)])

    # Construct and pass `maximum_shapes` so that we could support dynamic
    # shapes using dynamic padder.
    if replicate_inputs:
      maximum_shapes = []
      # Shapes are taken from replica 0's inputs only.
      flattened_list = nest.flatten(replicate_inputs[0])
      for input_tensor in flattened_list:
        if tensor_util.is_tensor(input_tensor):
          maximum_shape = input_tensor.get_shape()
        else:
          # Non-tensor inputs: derive a static shape via NumPy.
          maximum_shape = tensor_shape.TensorShape(np.shape(input_tensor))
        maximum_shapes.append(maximum_shape)
      maximum_shapes = nest.pack_sequence_as(replicate_inputs[0],
                                             maximum_shapes)
    else:
      maximum_shapes = None

    with strategy.scope():
      replicate_outputs = tpu.replicate(replicated_fn, replicate_inputs,
                                        maximum_shapes=maximum_shapes)

    # Remove all no ops that may have been added during 'tpu.replicate()'
    if isinstance(result[0], list):
      result[0] = [
          output for output in result[0] if tensor_util.is_tensor(output)
      ]

    # Workaround for `tpu.replicate` behaviour when single `Tensor` returned.
    if result[0] is None:
      replicate_outputs = [None] * len(replicate_outputs)
    else:
      # Re-nest each replica's flat outputs into the structure `fn` returned.
      replicate_outputs = [
          nest.pack_sequence_as(result[0], nest.flatten(replica_output))
          for replica_output in replicate_outputs
      ]
    device_map = self._device_map  # pylint: disable=protected-access
    return values.regroup(device_map, replicate_outputs)

  if context.executing_eagerly():
    tpu_function = def_function.function(tpu_function)

  self._tpu_function_cache[fn] = tpu_function
  return tpu_function
def _distributed_apply(self, distribution, grads_and_vars, name, apply_state):
  """`apply_gradients` using a `DistributionStrategy`.

  Gradients are SUM-reduced across replicas, each variable is updated through
  `distribution.extended.update`, and then `self._iterations` is incremented
  by one.

  Args:
    distribution: the active `DistributionStrategy`.
    grads_and_vars: (gradient, variable) pairs.
    name: optional name scope; defaults to `self._name`.
    apply_state: forwarded to the per-variable apply methods when they accept
      an `apply_state` argument.

  Returns:
    The op (graph/symbolic case) or result (eager case) of incrementing
    `self._iterations`.
  """
  # NOTE(review): `grads_and_vars` is iterated twice (reduce, then var_list),
  # so it must be a re-iterable sequence, not a one-shot iterator.
  reduced_grads = distribution.extended.batch_reduce_to(
      ds_reduce_util.ReduceOp.SUM, grads_and_vars)
  var_list = [v for _, v in grads_and_vars]
  grads_and_vars = zip(reduced_grads, var_list)

  def apply_grad_to_update_var(var, grad):
    """Apply gradient to variable."""
    if isinstance(var, ops.Tensor):
      raise NotImplementedError("Trying to update a Tensor ", var)

    apply_kwargs = {}
    if isinstance(grad, ops.IndexedSlices):
      # Sparse path: constraints are rejected for sparse updates.
      if var.constraint is not None:
        raise RuntimeError(
            "Cannot use a constraint function on a sparse variable.")
      if "apply_state" in self._sparse_apply_args:
        apply_kwargs["apply_state"] = apply_state
      return self._resource_apply_sparse_duplicate_indices(
          grad.values, var, grad.indices, **apply_kwargs)

    if "apply_state" in self._dense_apply_args:
      apply_kwargs["apply_state"] = apply_state
    update_op = self._resource_apply_dense(grad, var, **apply_kwargs)
    if var.constraint is not None:
      # Apply the constraint only after the dense update has run.
      with ops.control_dependencies([update_op]):
        return var.assign(var.constraint(var))
    else:
      return update_op

  update_ops = []
  with backend.name_scope(name or self._name):
    for grad, var in grads_and_vars:
      scope_name = ("update" if ops.executing_eagerly_outside_functions() else
                    "update_" + var.op.name)
      # Colocate the update with variables to avoid unnecessary communication
      # delays. See b/136304694.
      with backend.name_scope(
          scope_name), distribution.extended.colocate_vars_with(var):
        update_ops.extend(
            distribution.extended.update(
                var, apply_grad_to_update_var, args=(grad,), group=False))

    any_symbolic = any(isinstance(i, ops.Operation) or
                       tf_utils.is_symbolic_tensor(i) for i in update_ops)
    if not context.executing_eagerly() or any_symbolic:
      # If the current context is graph mode or any of the update ops are
      # symbolic then the step update should be carried out under a graph
      # context. (eager updates execute immediately)
      with ops._get_graph_from_inputs(update_ops).as_default():  # pylint: disable=protected-access
        with ops.control_dependencies(update_ops):
          return self._iterations.assign_add(1).op

    return self._iterations.assign_add(1)
def get_layer_class():
  """Returns the IntegerLookup class matching the current execution mode.

  Eager execution uses `integer_lookup.IntegerLookup`; graph mode uses the
  V1 implementation from `integer_lookup_v1`.
  """
  lookup_module = (
      integer_lookup if context.executing_eagerly() else integer_lookup_v1)
  return lookup_module.IntegerLookup
def restore_ops(self):
  """Create or fetch restore ops for this object's attributes.

  Requires that the `Checkpointable` Python object has been bound to an object
  ID in the checkpoint.

  When graph building, restore ops are cached on the checkpoint in
  `restore_ops_by_name`, keyed by each attribute's checkpoint key. Repeated
  calls for the same attributes return the cached ops instead of building
  (and re-registering) duplicates.

  Returns:
    A list of operations when graph building, or an empty list when executing
    eagerly.

  Raises:
    AssertionError: if validating the saveables changes their names.
  """
  saveables = self.checkpointable._gather_saveables_for_checkpoint()  # pylint: disable=protected-access
  # Name saveables based on the name this object had when it was checkpointed.
  named_saveables = {}
  restore_ops = []
  building_graph = not context.executing_eagerly()
  for serialized_tensor in self.object_proto.attributes:
    saveable_factory = saveables.get(serialized_tensor.name, None)
    if saveable_factory is None:
      # Purposefully does not throw an exception if attributes have been added
      # or deleted. Stores unused attributes so an exception can be raised if
      # the user decides to check that everything in the checkpoint was
      # loaded.
      self._checkpoint.unused_attributes.setdefault(
          self.checkpointable, []).append(serialized_tensor.name)
      continue
    if building_graph:
      # Ops are registered below under `saveable.name`. Validation guarantees
      # this equals the checkpoint key, so look the cache up by that same key.
      # The attribute name is a different, shorter identifier.
      existing_op = self._checkpoint.restore_ops_by_name.get(
          serialized_tensor.checkpoint_key, None)
    else:
      existing_op = None
    if existing_op is not None:
      # Reuse the op built by an earlier call instead of creating a duplicate.
      restore_ops.append(existing_op)
      continue
    if callable(saveable_factory):
      saveable = saveable_factory(name=serialized_tensor.checkpoint_key)
    else:
      saveable = saveable_factory
    named_saveables[serialized_tensor.checkpoint_key] = saveable
  if named_saveables:
    validated_saveables = (
        self._checkpoint.builder._ValidateAndSliceInputs(named_saveables))  # pylint: disable=protected-access
    validated_names = set(saveable.name for saveable in validated_saveables)
    if set(named_saveables.keys()) != validated_names:
      raise AssertionError(
          ("Saveable keys changed when validating. Got back %s, was "
           "expecting %s") % (named_saveables.keys(), validated_names))
    all_tensors = self._checkpoint.builder.bulk_restore(
        filename_tensor=self._checkpoint.save_path,
        saveables=validated_saveables,
        preferred_shard=-1,
        restore_sequentially=False)
    # `all_tensors` is one flat list; each saveable consumes len(specs) items.
    saveable_index = 0
    for saveable in validated_saveables:
      num_specs = len(saveable.specs)
      saveable_tensors = all_tensors[saveable_index:saveable_index + num_specs]
      saveable_index += num_specs
      restore_op = saveable.restore(saveable_tensors, restored_shapes=None)
      if building_graph:
        assert saveable.name not in self._checkpoint.restore_ops_by_name
        self._checkpoint.restore_ops_by_name[saveable.name] = restore_op
        restore_ops.append(restore_op)
  return restore_ops
def testMirroredContainer(self):
  """Regrouping a mirrored value's components yields the same container."""
  no_gpu = context.num_gpus() < 1
  if no_gpu and context.executing_eagerly():
    self.skipTest("A GPU is not available for this test in eager mode.")
  components, devices, mirrored = _make_mirrored()
  per_device = dict(zip(devices, components))
  regrouped = values.regroup(per_device)
  self.assertIs(mirrored, regrouped)
def bar():
  """Asserts that this body runs with eager execution enabled."""
  is_eager = context.executing_eagerly()
  self.assertTrue(is_eager)
def testSaveNormalRestoreTowerLocalSum(self):
  """Saves normal variables and restores them as tower-local SUM variables."""
  no_gpu = context.num_gpus() < 1
  if no_gpu and context.executing_eagerly():
    self.skipTest("A GPU is not available for this test in eager mode.")
  checkpoint_path = self._save_normal()
  self._restore_tower_local_sum(checkpoint_path)
def foo():
  """Asserts that this body runs in graph (non-eager) mode."""
  is_eager = context.executing_eagerly()
  self.assertFalse(is_eager)
def model_iteration(model,
                    data,
                    steps_per_epoch=None,
                    epochs=1,
                    verbose=1,
                    callbacks=None,
                    validation_data=None,
                    validation_steps=None,
                    validation_freq=1,
                    class_weight=None,
                    max_queue_size=10,
                    workers=1,
                    use_multiprocessing=False,
                    shuffle=False,
                    initial_epoch=0,
                    mode=ModeKeys.TRAIN,
                    batch_size=None,
                    steps_name='steps',
                    **kwargs):
  """Loop function for arrays of data with modes TRAIN/TEST/PREDICT.

  Arguments:
      model: Keras Model instance.
      data: Either a tuple of NumPy/Tensor inputs (i.e. `(x,)` or `(x, y)` or
        `(x, y, sample_weights)`) or a generator or
        `keras.utils.data_utils.Sequence` object or Eager Iterator or Dataset.
      steps_per_epoch: Total number of steps (batches of samples) before
        declaring one epoch finished and starting the next epoch. Ignored with
        the default value of `None`.
      epochs: Number of times to iterate over the data.
      verbose: 0, 1, or 2. Verbosity mode.
        0 = silent, 1 = progress bar, 2 = one line per epoch.
        Note that the progress bar is not particularly useful when
        logged to a file, so verbose=2 is recommended when not running
        interactively (eg, in a production environment).
      callbacks: List of callbacks to be called during training.
      validation_data: Either a tuple of NumPy/Tensor inputs (i.e. `(x,)` or
        `(x, y)` or `(x, y, sample_weights)`) or a generator or
        `keras.utils.data_utils.Sequence` object or Eager Iterator or Dataset.
      validation_steps: Total number of steps (batches of samples) before
        declaring validation finished.
      validation_freq: Only relevant if validation data is provided. Integer or
        `collections.abc.Container` instance (e.g. list, tuple, etc.). If an
        integer, specifies how many training epochs to run before a new
        validation run is performed, e.g. `validation_freq=2` runs
        validation every 2 epochs. If a Container, specifies the epochs on
        which to run validation, e.g. `validation_freq=[1, 2, 10]` runs
        validation at the end of the 1st, 2nd, and 10th epochs.
      class_weight: Dictionary mapping class indices to a weight for the class.
      max_queue_size: Integer. Maximum size for the generator queue. If
        unspecified, `max_queue_size` will default to 10.
      workers: Integer. Maximum number of processes to spin up when using
        process-based threading. If unspecified, `workers` will default to 1. If
        0, will execute the generator on the main thread.
      use_multiprocessing: Boolean. If `True`, use process-based threading. If
        unspecified, `use_multiprocessing` will default to `False`. Note that
        because this implementation relies on multiprocessing, you should not
        pass non-picklable arguments to the generator as they can't be passed
        easily to children processes.
      shuffle: Boolean. Whether to shuffle the order of the batches at the
        beginning of each epoch. Only used with instances of `Sequence`
        (`keras.utils.Sequence`). Has no effect when `steps_per_epoch` is not
        `None`.
      initial_epoch: Epoch at which to start training (useful for resuming a
        previous training run).
      mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.
      batch_size: Integer batch size or None if unknown. Will only be used if
        `data` is in NumPy/Tensor format.
      steps_name: The string name of the steps argument, either `steps`,
        `validation_steps`, or `steps_per_epoch`. Only used for error message
        formatting.
      **kwargs: Additional arguments for backwards compatibility. `steps` is
        accepted as an alias for `steps_per_epoch`.

  Returns:
      - In TRAIN mode: `History` object.
      - In TEST mode: Evaluation metrics.
      - In PREDICT mode: Outputs of the Model called on inputs.

  Raises:
      ValueError: in case of invalid arguments.
  """
  if 'steps' in kwargs:
    steps_per_epoch = kwargs['steps']

  # Determine the number of steps per epoch and whether we should reset the
  # dataset at the end of each epoch.
  reset_dataset_after_each_epoch = False
  original_dataset = None
  is_dataset = isinstance(data, (dataset_ops.DatasetV2, dataset_ops.DatasetV1))
  if is_dataset:
    original_dataset = data
    if steps_per_epoch is None:
      reset_dataset_after_each_epoch = True
      steps_per_epoch = training_utils_v1.infer_steps_for_dataset(
          model, data, steps_per_epoch, epochs=epochs, steps_name=steps_name)

  # Convert to a format that supports `next(generator)`.
  generator, steps_per_epoch = convert_to_generator_like(
      data,
      steps_per_epoch=steps_per_epoch,
      batch_size=batch_size,
      epochs=epochs - initial_epoch,
      shuffle=shuffle)

  do_validation = validation_data is not None
  is_sequence = isinstance(generator, data_utils.Sequence)
  _validate_arguments(is_sequence, is_dataset, use_multiprocessing, workers,
                      steps_per_epoch, validation_data, validation_steps, mode,
                      kwargs)

  batch_function = _make_execution_function(
      model, mode, class_weight=class_weight)

  # Create the queue for the generator.
  enqueuer = None
  if not is_dataset:
    generator, enqueuer = _make_enqueued_generator(
        generator,
        workers=workers,
        use_multiprocessing=use_multiprocessing,
        max_queue_size=max_queue_size,
        shuffle=shuffle)

  num_samples_or_steps, use_steps = _get_num_samples_or_steps(
      data, steps_per_epoch)

  count_mode = 'steps' if use_steps else 'samples'
  callbacks = cbks.configure_callbacks(
      callbacks,
      model,
      do_validation=do_validation,
      epochs=epochs,
      steps_per_epoch=steps_per_epoch,
      batch_size=batch_size,
      samples=num_samples_or_steps,
      count_mode=count_mode,
      verbose=verbose,
      mode=mode)

  if mode == ModeKeys.PREDICT:
    aggregator = training_utils_v1.OutputsAggregator(
        True, steps=steps_per_epoch)
  else:
    aggregator = training_utils_v1.MetricsAggregator(
        True, steps=steps_per_epoch)

  should_set_learning_phase = context.executing_eagerly() and model.run_eagerly
  if should_set_learning_phase:
    learning_phase_scope = backend.eager_learning_phase_scope(
        1 if mode == ModeKeys.TRAIN else 0)
    learning_phase_scope.__enter__()

  # Everything after this point runs under `try` so the enqueuer and the
  # learning-phase scope entered above are released even if the loop raises.
  try:
    callbacks.model.stop_training = False
    callbacks._call_begin_hook(mode)

    initial_epoch = model._maybe_load_initial_epoch_from_ckpt(
        initial_epoch, mode)

    for epoch in range(initial_epoch, epochs):
      if callbacks.model.stop_training:
        break

      # Setup work for each epoch.
      model.reset_metrics()
      epoch_logs = {}
      if mode == ModeKeys.TRAIN:
        callbacks.on_epoch_begin(epoch, epoch_logs)

      if steps_per_epoch is None:
        # Loop over dataset until `OutOfRangeError` is raised.
        target_steps = np.inf
      else:
        # Loop over dataset for the specified number of steps.
        target_steps = steps_per_epoch

      step = 0
      while step < target_steps:
        batch_data = _get_next_batch(generator)
        if batch_data is None:
          if is_dataset:
            # The dataset passed by the user ran out of batches.
            # Now we know the cardinality of the dataset.
            # If steps_per_epoch was specified, then running out of data is
            # unexpected, so we stop training and inform the user.
            if steps_per_epoch:
              callbacks.model.stop_training = True
              logging.warning(
                  'Your dataset ran out of data; interrupting training. '
                  'Make sure that your dataset can generate at least '
                  '`%s * epochs` batches (in this case, %d batches). '
                  'You may need to use the repeat() function when '
                  'building your dataset.'
                  % (steps_name, steps_per_epoch * epochs))
            elif step > 0:
              steps_per_epoch = step
              aggregator.steps = steps_per_epoch
          else:
            # We ran out of batches while the user passed an iterator (legacy).
            callbacks.model.stop_training = True
            logging.warning(
                'Your dataset iterator ran out of data; '
                'interrupting training. Make sure that your iterator '
                'can generate at least `%s * epochs` '
                'batches (in this case, %d batches). You may need to '
                'use the repeat() function when building your '
                'dataset.' % (steps_name, steps_per_epoch * epochs))
          break

        # `batch_size` used for validation data if validation
        # data is NumPy/EagerTensors.
        batch_size = int(nest.flatten(batch_data)[0].shape[0])

        # Callbacks batch begin.
        batch_logs = {'batch': step, 'size': batch_size}
        callbacks._call_batch_hook(mode, 'begin', step, batch_logs)

        is_deferred = not model._is_compiled
        batch_outs = batch_function(*batch_data)
        if not isinstance(batch_outs, list):
          batch_outs = [batch_outs]

        if step == 0:
          aggregator.create(batch_outs)

          if is_deferred:
            # Set callbacks params. We do this here when model is compiled only
            # in the first iteration of this loop (deferred build scenario).
            cbks.set_callback_parameters(
                callbacks,
                model,
                do_validation=do_validation,
                batch_size=batch_size,
                epochs=epochs,
                steps_per_epoch=steps_per_epoch,
                samples=num_samples_or_steps,
                verbose=verbose,
                mode=mode)

        # Aggregate results.
        aggregator.aggregate(batch_outs)

        # Callbacks batch end.
        batch_logs = cbks.make_logs(model, batch_logs, batch_outs, mode)
        callbacks._call_batch_hook(mode, 'end', step, batch_logs)
        step += 1

        if callbacks.model.stop_training:
          break

      aggregator.finalize()
      results = aggregator.results
      epoch_logs = cbks.make_logs(model, epoch_logs, results, mode)
      if len(results) == 1:
        results = results[0]

      # Run the test loop every epoch during training.
      if (do_validation and
          training_utils_v1.should_run_validation(validation_freq, epoch) and
          not callbacks.model.stop_training):
        val_results = model_iteration(
            model,
            validation_data,
            steps_per_epoch=validation_steps,
            batch_size=batch_size,
            class_weight=class_weight,
            workers=workers,
            use_multiprocessing=use_multiprocessing,
            max_queue_size=max_queue_size,
            callbacks=callbacks,
            verbose=verbose,
            mode=ModeKeys.TEST,
            steps_name='validation_steps')

        if not isinstance(val_results, list):
          val_results = [val_results]
        epoch_logs = cbks.make_logs(
            model, epoch_logs, val_results, mode, prefix='val_')

      if mode == ModeKeys.TRAIN:
        # Epochs only apply to `fit`.
        callbacks.on_epoch_end(epoch, epoch_logs)

      # Recreate dataset iterator for the next epoch.
      if reset_dataset_after_each_epoch and epoch < epochs - 1:
        generator = dataset_ops.make_one_shot_iterator(original_dataset)

    model._successful_loop_finish = True
    callbacks._call_end_hook(mode)
  finally:
    # Same cleanup order as the normal path: stop the enqueuer first, then
    # exit the learning-phase scope.
    if enqueuer is not None:
      enqueuer.stop()

    if should_set_learning_phase:
      learning_phase_scope.__exit__(None, None, None)

  if mode == ModeKeys.TRAIN:
    return model.history
  return results
def has_symbolic_tensors(ls):
  """Reports whether `ls` is, or (as a list/tuple) contains, a tensor.

  Always False when executing eagerly.

  Args:
    ls: a single value, or a list/tuple of values.

  Returns:
    A boolean.
  """
  if context.executing_eagerly():
    return False
  if not isinstance(ls, (list, tuple)):
    return tensor_util.is_tensor(ls)
  return any(tensor_util.is_tensor(item) for item in ls)
def _initialize_multi_worker(self, cluster_resolver):
  """Initializes the object for multi-worker training.

  Reads the cluster spec and task identity from `cluster_resolver` and records
  worker/chief information on `self`. In eager mode, when not in local or
  standalone-client mode, it also configures collective ops (and, when
  available, the coordination service) and starts the in-process server.
  Finally it builds the collective cross-device ops for local devices and
  per-host tensors.

  Args:
    cluster_resolver: provides `cluster_spec()`, `task_type`, `task_id`,
      `rpc_layer` and optionally `port`.

  Raises:
    ValueError: if `task_type` or `task_id` is missing, or if the cluster spec
      has no `worker`, `chief` or `evaluator` tasks.
  """
  cluster_spec = multi_worker_util.normalize_cluster_spec(
      cluster_resolver.cluster_spec())
  task_type = cluster_resolver.task_type
  task_id = cluster_resolver.task_id
  if task_type is None or task_id is None:
    raise ValueError("When `cluster_spec` is given, you must also specify "
                     "`task_type` and `task_id`.")
  self._cluster_spec = cluster_spec
  self._task_type = task_type
  self._task_id = task_id
  self._id_in_cluster = multi_worker_util.id_in_cluster(
      self._cluster_spec, self._task_type, self._task_id)

  self._num_workers = multi_worker_util.worker_count(cluster_spec, task_type)
  if not self._num_workers:
    raise ValueError("No `worker`, `chief` or `evaluator` tasks can be found "
                     "in `cluster_spec`.")

  self._is_chief = multi_worker_util.is_chief(cluster_spec, task_type,
                                              task_id)

  self._worker_device = "/job:%s/task:%d" % (task_type, task_id)
  self._host_input_device = numpy_dataset.SingleDevice(self._worker_device)

  # Collective ops are configured on the eager context only outside
  # local/standalone-client mode.
  if (ops.executing_eagerly_outside_functions() and
      not getattr(self, "_local_or_standalone_client_mode", False)):
    context.context().configure_collective_ops(
        collective_leader=multi_worker_util.collective_leader(
            cluster_spec, task_type, task_id),
        scoped_allocator_enabled_ops=("CollectiveReduce",),
        device_filters=("/job:%s/task:%d" % (task_type, task_id),))
    self._collective_ops_configured = True
    # Configure the coordination service only if none is set yet, and only
    # for chief/worker tasks.
    if context.context().coordination_service is None:
      coordination_service = remote_utils.coordination_service_type(
          cluster_resolver.rpc_layer)
      coordinated_jobs = ["chief", "worker"]
      if coordination_service and task_type in coordinated_jobs:
        context.context().configure_coordination_service(
            service_type=coordination_service,
            service_leader=multi_worker_util.coordination_leader(
                cluster_spec),
            coordinated_jobs=coordinated_jobs)

  # Starting a std server in eager mode and in independent worker mode.
  if (context.executing_eagerly() and
      not getattr(self, "_std_server_started", False) and
      not getattr(self, "_local_or_standalone_client_mode", False)):
    # Checking _local_or_standalone_client_mode as well because we should not
    # create the std server in standalone client mode.
    config_proto = copy.deepcopy(context.context().config)
    config_proto = self._update_config_proto(config_proto)

    # If coordination service is enabled, use its internal heartbeat to detect
    # peer failures instead of the Python-level health check.
    if config_proto.experimental.coordination_config.service_type:
      self._enable_check_health = False

    # Port 0 is used when the resolver does not expose a `port` attribute.
    if hasattr(cluster_resolver, "port"):
      port = cluster_resolver.port
    else:
      port = 0
    server_def = tensorflow_server_pb2.ServerDef(
        cluster=cluster_spec.as_cluster_def(),
        default_session_config=config_proto,
        job_name=task_type,
        task_index=task_id,
        protocol=cluster_resolver.rpc_layer or "grpc",
        port=port)
    context.context().enable_collective_ops(server_def)
    # Guard so a later call does not start a second server.
    self._std_server_started = True
    # The `ensure_initialized` is needed before calling
    # `context.context().devices()`.
    context.context().ensure_initialized()
    logging.info(
        "Enabled multi-worker collective ops with available devices: %r",
        context.context().devices())

  # TODO(yuefengz): The `num_gpus` is only for this particular task. It
  # assumes all workers have the same number of GPUs. We should remove this
  # assumption by querying all tasks for their numbers of GPUs.
  # TODO(b/126786766): TFConfigClusterResolver returns wrong number of GPUs in
  # some cases.
  local_devices, local_device_type = self._initialize_local_devices(
      cluster_resolver, self._worker_device)
  if local_device_type == "TPU":
    tpu_strategy_util.initialize_tpu_system()

  self._collective_keys = cross_device_utils.CollectiveKeys(
      group_key_start=1 + self._collective_key_base)
  # Collective group spans every local device on every worker.
  self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
      devices=local_devices,
      group_size=len(local_devices) * self._num_workers,
      options=self._communication_options,
      collective_keys=self._collective_keys)
  # CrossDeviceOps for per host tensors.
  self._host_cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
      devices=[self._worker_device],
      group_size=self._num_workers,
      options=self._communication_options,
      collective_keys=self._collective_keys)
  super(CollectiveAllReduceExtended, self)._initialize_single_worker(
      local_devices)

  # Add a default device so that ops without specified devices will not end up
  # on other workers.
  self._default_device = "/job:%s/task:%d" % (task_type, task_id)

  # Save the num_devices_per_worker and rpc_layer for configure method.
  self._num_devices_per_worker = len(local_devices)
  self._local_device_type = local_device_type
  self._rpc_layer = cluster_resolver.rpc_layer
  self._warn_nccl_no_gpu()

  if self._enable_check_health and context.executing_eagerly():
    self._start_check_health_thread()
  else:
    logging.info("Check health not enabled.")

  logging.info(
      "MultiWorkerMirroredStrategy with cluster_spec = %r, task_type = %r, "
      "task_id = %r, num_workers = %r, local_devices = %r, "
      "communication = %s", cluster_spec.as_dict(), task_type, task_id,
      self._num_workers, local_devices,
      self._communication_options.implementation)
def _testMinimizeLoss(self, distribution):
  """Runs the minimize-loss check for `distribution` in the active mode.

  Graph mode uses a fixed learning rate of 0.05; eager mode delegates to the
  eager variant with its own defaults. Nothing is returned.
  """
  if not context.executing_eagerly():
    self._test_minimize_loss_graph(distribution, learning_rate=0.05)
    return
  self._test_minimize_loss_eager(distribution)
def py_func(func, inp, Tout, stateful=True, name=None):
  """Wraps a python function and uses it as a TensorFlow op.

  Given a python function `func`, which takes numpy arrays as its
  arguments and returns numpy arrays as its outputs, wrap this function as an
  operation in a TensorFlow graph. The following snippet constructs a simple
  TensorFlow graph that invokes the `np.sinh()` NumPy function as an operation
  in the graph:

  ```python
  def my_func(x):
    # x will be a numpy array with the contents of the placeholder below
    return np.sinh(x)
  input = tf.placeholder(tf.float32)
  y = tf.py_func(my_func, [input], tf.float32)
  ```

  **N.B.** The `tf.py_func()` operation has the following known limitations:

  * The body of the function (i.e. `func`) will not be serialized in a
    `GraphDef`. Therefore, you should not use this function if you need to
    serialize your model and restore it in a different environment.

  * The operation must run in the same address space as the Python program
    that calls `tf.py_func()`. If you are using distributed TensorFlow, you
    must run a `tf.train.Server` in the same process as the program that calls
    `tf.py_func()` and you must pin the created operation to a device in that
    server (e.g. using `with tf.device():`).

  When executing eagerly, `func` is called immediately on the numpy values of
  `inp` and each non-None output is converted to a `Tensor`. On that path
  `Tout` and `stateful` are not consulted, and a single output is returned
  unwrapped rather than as a one-element list.

  Args:
    func: A Python function, which accepts `ndarray` objects as arguments and
      returns a list of `ndarray` objects (or a single `ndarray`). This function
      must accept as many arguments as there are tensors in `inp`, and these
      argument types will match the corresponding `tf.Tensor` objects
      in `inp`. The returned `ndarray`s must match the number and types defined
      in `Tout`.
      Important Note: Input and output numpy `ndarray`s of `func` are not
      guaranteed to be copies. In some cases their underlying memory will be
      shared with the corresponding TensorFlow tensors.
      In-place modification or storing `func` input or return values in
      python datastructures without explicit (np.)copy
      can have non-deterministic consequences.
    inp: A list of `Tensor` objects.
    Tout: A list or tuple of tensorflow data types or a single tensorflow data
      type if there is only one, indicating what `func` returns.
    stateful: (Boolean.) If True, the function should be considered stateful.
      If a function is stateless, when given the same input it will return the
      same output and have no observable side effects. Optimizations such as
      common subexpression elimination are only performed on stateless
      operations.
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` or a single `Tensor` which `func` computes.
  """
  if context.executing_eagerly():
    # Convert first: Python scalars, lists and numpy arrays have no `.numpy()`
    # method, so calling it directly would raise AttributeError for them.
    args = [ops.convert_to_tensor(x).numpy() for x in inp]
    result = nest.flatten(func(*args))
    result = [x if x is None else ops.convert_to_tensor(x) for x in result]
    if len(result) == 1:
      # Mimic the automatic unwrapping in graph-mode py_func.
      result, = result
    return result

  return _internal_py_func(func=func, inp=inp, Tout=Tout, stateful=stateful,
                           eager=False, name=name)
def all_reduce_indexed_slices(self,
                              input_slices,
                              communication_hint='AUTO',
                              timeout=0):
  """All-reduces an `IndexedSlices` across the collective group.

  Must be called inside a `tf.function`; eager execution is rejected.

  The lengths of `input_slices.indices` are gathered from all participants
  first, because the current CollectiveAllGather implementations require
  inputs of consistent length. If every participant reports the same length,
  the values and indices are all-gathered directly. Otherwise the slices are
  densified, the dense tensors are all-reduced, and the result is re-wrapped
  as an `IndexedSlices` whose indices cover every row.

  Args:
    input_slices: the `IndexedSlices` to reduce.
    communication_hint: string hint that lets the runtime choose a collective
      implementation.
    timeout: a float timeout in seconds.

  Returns:
    The reduced `IndexedSlices`.

  Raises:
    RuntimeError: if called while executing eagerly.
  """
  if context.executing_eagerly():
    raise RuntimeError(
        'all_reduce_indexed_slices in eager mode is not supported')

  with ops.device(self._device):
    slice_values = input_slices.values
    slice_indices = input_slices.indices
    dense_shape = input_slices.dense_shape

    def _gather_slices():
      """Aggregates by all-gathering the values and then the indices."""
      gathered_values = self._all_gather(
          slice_values, communication_hint, timeout=timeout)
      # Under NCCL, force the indices gather to run after the values gather.
      deps = [gathered_values] if communication_hint == 'NCCL' else []
      with ops.control_dependencies(deps):
        gathered_indices = self._all_gather(
            slice_indices, communication_hint, timeout=timeout)
      return ops.IndexedSlices(
          values=gathered_values,
          indices=gathered_indices,
          dense_shape=dense_shape)

    def _densify_and_reduce():
      """Aggregates by all-reducing a dense copy of the slices."""
      dense = ops.convert_to_tensor(input_slices)
      summed = self.all_reduce(
          dense, communication_hint=communication_hint, timeout=timeout)
      num_rows = array_ops.shape(summed)[0]
      # control_flow_ops.cond needs both branches to return the same type, so
      # the dense result is wrapped back into an IndexedSlices over all rows.
      return ops.IndexedSlices(
          values=summed,
          indices=math_ops.range(num_rows),
          dense_shape=dense_shape)

    indices_length = array_ops.shape(slice_indices)
    all_lengths = self._all_gather(
        indices_length, communication_hint, timeout=timeout)
    lengths_match = math_ops.equal(
        math_ops.reduce_max(all_lengths), math_ops.reduce_min(all_lengths))
    return control_flow_ops.cond(lengths_match, _gather_slices,
                                 _densify_and_reduce)