def _assert_in_default_state(t):
    """Assert, through test case `t`, that only the default context is active.

    Checks that the current replica context and strategy are the default
    ones, that there is no cross-replica context, and that `has_strategy()`
    reports False.

    Args:
      t: a test-case object providing `assertIs` / `assertFalse`.
    """
    default_replica_ctx = ds_context._get_default_replica_context()
    t.assertIs(default_replica_ctx, ds_context.get_replica_context())
    t.assertIs(None, ds_context.get_cross_replica_context())
    t.assertFalse(ds_context.in_cross_replica_context())

    default_strategy = ds_context._get_default_strategy()
    t.assertIs(default_strategy, ds_context.get_strategy())
    t.assertFalse(ds_context.has_strategy())
def get_updates(self, loss, params):
    """Build and return the list of update ops for the wrapped optimizer.

    Args:
      loss: loss value handed to `self.optimizer.compute_gradients`.
      params: variables to compute gradients for; may be empty.

    Returns:
      `self.updates`, the list of update ops.
    """
    if not distribution_strategy_context.has_strategy():
        if not params:
            self.updates = [state_ops.assign_add(self.iterations, 1)]
            return self.updates
        # The list starts out empty because the iterations variable is
        # incremented inside optimizer.apply_gradients().
        self.updates = []
        grads = self.optimizer.compute_gradients(loss, params)
        opt_update = self.optimizer.apply_gradients(
            grads, global_step=self.iterations)
    else:
        self.updates = []
        if params:
            grads = self.optimizer.compute_gradients(loss, params)
        else:
            # After the model vars have been created, get_updates is called a
            # second time with an empty `params`; in that case
            # compute_gradients must be called without a variable list.
            grads = self.optimizer.compute_gradients(loss)
        global_step = training_util.get_global_step()
        opt_update = self.optimizer.apply_gradients(grads, global_step)

    self.updates.append(opt_update)
    return self.updates
def merge_fn(dist, s):
    """Check the cross-replica state seen by a merge call, then tag `s`.

    Asserts that `dist` is the default strategy, that the call runs in a
    cross-replica context, and that `has_strategy()` is False.

    Returns:
      `"foo_" + s`.
    """
    default_strategy = ds_context._get_default_strategy()
    self.assertIs(default_strategy, dist)
    self.assertIs(None, ds_context.get_replica_context())
    self.assertIs(dist, ds_context.get_cross_replica_context())
    self.assertTrue(ds_context.in_cross_replica_context())
    self.assertIs(dist, ds_context.get_strategy())
    self.assertFalse(ds_context.has_strategy())
    tagged = "foo_" + s
    return tagged
def _clip_gradients(self, grads):
    """Clip gradients according to the clipnorm and clipvalue attributes.

    Args:
      grads: list of gradients to clip.

    Returns:
      The clipped gradients. When neither `self.clipnorm` nor
      `self.clipvalue` is set, `grads` is returned unchanged.

    Raises:
      ValueError: if clipping is configured while a distribution strategy is
        active.
    """
    if self.clipnorm is None and self.clipvalue is None:
        return grads

    # One guard for both clipping modes; the check and the message used to be
    # duplicated in each branch.
    if distribute_ctx.has_strategy():
        raise ValueError(
            "Gradient clipping in the optimizer "
            "(by setting clipnorm or clipvalue) is currently "
            "unsupported when using a distribution strategy.")

    if self.clipnorm is not None:
        grads = [clip_ops.clip_by_norm(g, self.clipnorm) for g in grads]
    if self.clipvalue is not None:
        grads = [
            clip_ops.clip_by_value(g, -self.clipvalue, self.clipvalue)
            for g in grads
        ]
    return grads
def _moments(self, inputs, reduction_axes, keep_dims):
    """Compute mean and variance of `inputs`, zeroing them for empty input.

    The zeroing only happens when executing eagerly outside functions and a
    distribution strategy is active.

    Returns:
      A `(mean, variance)` pair.
    """
    mean, variance = nn.moments(inputs, reduction_axes, keep_dims=keep_dims)
    # TODO(b/129279393): Support zero batch input in non DistributionStrategy
    # code as well.
    # TODO(b/130185866): Support zero batch input in graph mode.
    eager_outside_functions = ops.executing_eagerly_outside_functions()
    if eager_outside_functions and distribution_strategy_context.has_strategy():
        inputs_size = array_ops.size(inputs)
        zero_mean = K.zeros_like(mean)
        mean = array_ops.where(inputs_size > 0, mean, zero_mean)
        zero_variance = K.zeros_like(variance)
        variance = array_ops.where(inputs_size > 0, variance, zero_variance)
    return mean, variance
def decorated(metric_obj, *args):
    """Decorated function with merge_call.

    Calls `result_fn` directly when no strategy / replica context is
    available, otherwise routes it through `replica_context.merge_call`. The
    resulting op is stored on `metric_obj._call_result` and returned.
    """
    has_strategy = distribution_strategy_context.has_strategy()
    replica_context = distribution_strategy_context.get_replica_context()

    if has_strategy and replica_context is not None:
        # TODO(psv): Test distribution of metrics using different distribution
        # strategies.

        # merge_call invokes the merge function with the distribution object
        # as its first parameter; this wrapper absorbs that parameter so that
        # the result function does not need to accept it.
        def merge_fn_wrapper(distribution, merge_fn, *args):
            # We receive a `PerReplica` merge function. All entries are
            # identical copies of the function passed below, so take the
            # first one.
            local_fn = distribution.experimental_local_results(merge_fn)[0]
            result = local_fn(*args)
            # Wrap in identity so that the control dependency between the
            # update op from `update_state` and the result works when the
            # result is a tensor.
            return array_ops.identity(result)

        # merge_call leaves replica mode and computes the value in
        # cross-replica mode.
        result_t = replica_context.merge_call(
            merge_fn_wrapper, args=(result_fn,) + args)
    else:
        raw_result = result_fn(*args)
        # Results are wrapped in a `tf.identity` op to ensure correct
        # execution order.
        simple_types = (ops.Tensor, variables_module.Variable, float, int)
        if isinstance(raw_result, simple_types):
            result_t = array_ops.identity(raw_result)
        elif isinstance(raw_result, dict):
            result_t = {
                key: array_ops.identity(value)
                for key, value in raw_result.items()
            }
        else:
            try:
                result_t = array_ops.identity(raw_result)
            except (ValueError, TypeError):
                raise RuntimeError(
                    'The output of `metric.result()` can only be a single '
                    'Tensor/Variable, or a dict of Tensors/Variables. '
                    'For metric %s, got result %s.' %
                    (metric_obj.name, raw_result))

    # Save the result op for the train/test execution functions: it is the
    # result op generated with a control dep to the updates for these
    # workflows.
    metric_obj._call_result = result_t
    return result_t
def _get_tensor(self, is_finite):
    """Return 1.0 if `is_finite` holds, else NaN (per replica under a strategy).

    Under a distribution strategy only the replica with id 0 receives the
    possibly-NaN tensor; every other replica receives 1.0.
    """
    tensor = control_flow_ops.cond(is_finite, lambda: 1., lambda: float('NaN'))
    if distribution_strategy_context.has_strategy():

        def get():
            replica_ctx = distribution_strategy_context.get_replica_context()
            rep_id = replica_ctx.replica_id_in_sync_group
            is_first = math_ops.equal(rep_id, 0)
            return control_flow_ops.cond(is_first, lambda: tensor, lambda: 1.)

        distribution = distribution_strategy_context.get_strategy()
        return distribution.extended.call_for_each_replica(get)
    return tensor
def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Apply momentum updates (with optional weight decay) to the variables.

    Args:
      grads_and_vars: iterable of `(gradient, variable)` pairs; pairs with a
        `None` member are skipped.
      global_step: accepted but not used by this implementation.
      name: name for the returned grouped op.

    Returns:
      A `tf.group` of all assignment ops.

    Raises:
      RuntimeError: if called in a cross-replica context under a strategy.
    """
    assignments = []
    for (grad, param) in grads_and_vars:
        if grad is None or param is None:
            continue

        param_name = _get_variable_name(param.name)
        m = tf.get_variable(
            name=param_name + "/momentum",
            shape=param.shape.as_list(),
            dtype=param.dtype,
            trainable=False,
            initializer=tf.zeros_initializer())
        next_m = self.momentum * m + grad

        # The update is scaled by loss_scaling, so its scale has to be
        # restored here.
        update = next_m
        update /= self.loss_scaling

        apply_decay = _do_use_weight_decay(
            param_name, self.weight_decay_rate, self.exclude_from_weight_decay)
        if apply_decay:
            update += self.weight_decay_rate * param

        update_with_lr = self.learning_rate * update
        next_param = param - update_with_lr

        if not distribute_ctx.has_strategy():
            assign_params = [param.assign(next_param), m.assign(next_m)]
        else:
            # Handle DistributionStrategy case.
            if distribute_ctx.in_cross_replica_context():
                raise RuntimeError(
                    "Use `_distributed_apply()` instead of "
                    "`apply_gradients()` in a cross-replica context.")
            replica_ctx = distribute_ctx.get_replica_context()
            assign_params = replica_ctx.merge_call(
                assign_vars, args=((param, m), (next_param, next_m)))
        assignments.extend(assign_params)

        if _need_centering(param_name, self.darknet_gn, self.upsample_gn):
            # Center the weights only after the parameter has been assigned.
            with tf.control_dependencies(assign_params):
                param_identity = tf.identity(param)
                centering_op = _centering_weights(param, param_identity)
                assignments.append(centering_op)

    if self.use_moving_avg:
        moving_avg_ops = _create_moving_avg(
            grads_and_vars, self.moving_avg_decay)
        assignments.extend(moving_avg_ops)
    return tf.group(*assignments, name=name)
def run_fn():
    """Check the replica-context state and the test strategy's hooks."""
    ctx = ds_context.get_replica_context()
    self.assertTrue(ctx is not None)
    self.assertIs(None, ds_context.get_cross_replica_context())
    self.assertFalse(ds_context.in_cross_replica_context())
    self.assertTrue(ds_context.has_strategy())
    self.assertIs(dist, ds_context.get_strategy())

    merged = ctx.merge_call(None, test_arg="foo")
    self.assertEqual("foo", merged)

    expected_value = _get_test_variable(
        "bar",
        variable_scope.VariableSynchronization.AUTO,
        variable_scope.VariableAggregation.NONE)
    created = variable_scope.variable(1.0, name="bar")
    self.assertDictEqual(expected_value, created)
def run_fn():
    """Verify that `dist` is active in replica context, then exercise it."""
    replica_ctx = ds_context.get_replica_context()
    self.assertTrue(replica_ctx is not None)
    self.assertIs(None, ds_context.get_cross_replica_context())
    self.assertFalse(ds_context.in_cross_replica_context())
    self.assertTrue(ds_context.has_strategy())
    self.assertIs(dist, ds_context.get_strategy())
    self.assertEqual(
        "foo", replica_ctx.merge_call(None, test_arg="foo"))

    sync_mode = variable_scope.VariableSynchronization.AUTO
    aggregation_mode = variable_scope.VariableAggregation.NONE
    expected_value = _get_test_variable("bar", sync_mode, aggregation_mode)
    self.assertDictEqual(
        expected_value, variable_scope.variable(1.0, name="bar"))
def _moments(self, inputs, reduction_axes, keep_dims):
    """Compute mean and variance, substituting zeros for empty input.

    The substitution (via `tf_utils.smart_cond`) only applies when a
    distribution strategy is active.

    Returns:
      A `(mean, variance)` pair.
    """
    mean, variance = nn.moments(inputs, reduction_axes, keep_dims=keep_dims)
    # TODO(b/129279393): Support zero batch input in non DistributionStrategy
    # code as well.
    if not distribution_strategy_context.has_strategy():
        return mean, variance

    inputs_size = array_ops.size(inputs)
    mean = tf_utils.smart_cond(
        inputs_size > 0,
        lambda: mean,
        lambda: K.zeros_like(mean))
    variance = tf_utils.smart_cond(
        inputs_size > 0,
        lambda: variance,
        lambda: K.zeros_like(variance))
    return mean, variance
def merge_fn(dist, s):
    """Assert default-strategy cross-replica state, then return `"foo_" + s`."""
    ctx_lib = distribution_strategy_context
    self.assertIs(ctx_lib._get_default_strategy(), dist)
    self.assertIs(None, ctx_lib.get_replica_context())
    self.assertIs(dist, ctx_lib.get_cross_replica_context())
    self.assertTrue(ctx_lib.in_cross_replica_context())
    self.assertIs(dist, ctx_lib.get_strategy())
    self.assertFalse(ctx_lib.has_strategy())
    return "foo_" + s
def _moments(self, inputs, reduction_axes, keep_dims):
    """Compute mean and variance, zeroing them when `inputs` is empty.

    The zeroing is only applied under a distribution strategy and when the
    static shape of `inputs` is not fully defined.

    Returns:
      A `(mean, variance)` pair.
    """
    mean, variance = nn.moments(inputs, reduction_axes, keep_dims=keep_dims)
    # TODO(b/129279393): Support zero batch input in non DistributionStrategy
    # code as well.
    needs_empty_guard = (
        distribution_strategy_context.has_strategy()
        and not inputs.shape.is_fully_defined())
    if needs_empty_guard:
        inputs_size = array_ops.size(inputs)
        mean = array_ops.where(
            inputs_size > 0, mean, K.zeros_like(mean))
        variance = array_ops.where(
            inputs_size > 0, variance, K.zeros_like(variance))
    return mean, variance
def __init__(self, copy_from=None, state=None, alg=None):
    """Creates a generator.

    The new generator will be initialized by one of the following ways, with
    decreasing precedence:
    (1) If `copy_from` is not None, the new generator is initialized by
        copying information from another generator.
    (2) If `state` and `alg` are not None (they must be set together), the
        new generator is initialized by a state.

    Args:
      copy_from: a generator to be copied from.
      state: a vector of dtype STATE_TYPE representing the initial state of
        the RNG, whose length and semantics are algorithm-specific. If it's a
        variable, the generator will reuse it instead of creating a new
        variable.
      alg: the RNG algorithm. Possible values are
        `tf.random.Algorithm.PHILOX` for the Philox algorithm and
        `tf.random.Algorithm.THREEFRY` for the ThreeFry algorithm (see paper
        'Parallel Random Numbers: As Easy as 1, 2, 3'
        [https://www.thesalmons.org/john/random123/papers/random123sc11.pdf]).
        The string names `"philox"` and `"threefry"` can also be used. Note
        `PHILOX` guarantees the same numbers are produced (given the same
        random state) across all architectures (CPU, GPU, XLA etc).

    Raises:
      AssertionError: if `copy_from` is given together with `alg` or `state`,
        or if `copy_from` is None and `alg` or `state` is missing.
    """
    # TODO(b/175072242): Remove distribution-strategy dependencies in this file.
    if ds_context.has_strategy():
        self._distribution_strategy = ds_context.get_strategy()
    else:
        self._distribution_strategy = None

    if copy_from is not None:
        # All other arguments should be None. Each one is tested with
        # `is None` explicitly: `(alg or state) is None` relied on
        # truthiness, which let a falsy non-None `alg` through and called
        # `bool()` on tensor-like arguments.
        assert alg is None and state is None
        self._state_var = self._create_variable(
            copy_from.state, dtype=STATE_TYPE, trainable=False)
        self._alg = copy_from.algorithm
    else:
        assert alg is not None and state is not None
        alg = stateless_random_ops.convert_alg_to_int(alg)
        if isinstance(state, variables.Variable):
            # Reuse the caller's variable instead of creating a new one.
            _check_state_shape(state.shape, alg)
            self._state_var = state
        else:
            state = _convert_to_state_tensor(state)
            _check_state_shape(state.shape, alg)
            self._state_var = self._create_variable(
                state, dtype=STATE_TYPE, trainable=False)
        self._alg = alg
def update(self, grads):
    """Updates loss scale based on if gradients are finite in current step.

    Args:
      grads: nested structure of gradients; flattened before use.

    Returns:
      A pair `(update_op, should_apply_gradients)` where the second element
      is the finiteness check result.
    """
    grads = nest.flatten(grads)
    if not distribution_strategy_context.has_strategy():
        is_finite = _is_all_finite(grads)
    else:
        distribution = distribution_strategy_context.get_cross_replica_context()

        def get_is_finite(grads):
            # Cast to float, because booleans cannot be reduced with
            # DistributionStrategy.
            all_finite = _is_all_finite(grads)
            return math_ops.cast(all_finite, dtypes.float32)

        is_finite_float = distribution.extended.call_for_each_replica(
            get_is_finite, args=(grads,))
        reduced_is_finite_float = distribution.reduce(
            reduce_util.ReduceOp.SUM, is_finite_float, axis=None)
        # All replicas are finite iff the summed indicator equals the
        # replica count.
        is_finite = math_ops.equal(
            reduced_is_finite_float, distribution.num_replicas_in_sync)

    def update_if_finite_grads():
        """Update assuming the gradients are finite."""

        def incr_loss_scale():
            scaled_up = math_ops.minimum(
                self._current_loss_scale * self._multiplier, 2**32)
            return control_flow_ops.group(
                _assign_if_finite(self._current_loss_scale, scaled_up),
                self._num_good_steps.assign(0))

        return control_flow_ops.cond(
            self._num_good_steps + 1 >= self._increment_period,
            incr_loss_scale,
            lambda: _op_in_graph_mode(self._num_good_steps.assign_add(1)))

    def update_if_not_finite_grads():
        """Update assuming the gradients are nonfinite."""
        scaled_down = math_ops.maximum(
            self._current_loss_scale / self._multiplier, 1)
        return control_flow_ops.group(
            self._num_good_steps.assign(0),
            self._current_loss_scale.assign(scaled_down))

    update_op = control_flow_ops.cond(
        is_finite, update_if_finite_grads, update_if_not_finite_grads)
    return update_op, is_finite
def model_fn(features, labels, mode):
    """model_fn for keras Estimator.

    Args:
      features: input features handed to the cloned model.
      labels: labels handed to the cloned model.
      mode: one of the `ModeKeys` values.

    Returns:
      A `model_fn_lib.EstimatorSpec` for the requested mode.
    """
    model = _clone_and_build_model(mode, keras_model, custom_objects, features,
                                   labels)
    # We need to make sure that the output names of the last layer in the model
    # is the same for each of the cloned models. This is required for mirrored
    # strategy when we call regroup.
    if distribution_strategy_context.has_strategy():
        # Compile once instead of once per output name.
        clone_suffix = re.compile(r'_\d$')
        model_output_names = [
            clone_suffix.sub('', name) for name in model.output_names
        ]
    else:
        model_output_names = model.output_names

    # Get inputs to EstimatorSpec
    predictions = dict(zip(model_output_names, model.outputs))

    loss = None
    train_op = None
    eval_metric_ops = None

    # Mode keys are compared by value: identity (`is`) only matches when the
    # caller passes the very same constant object.
    # Set loss and metric only during train and evaluate.
    if mode != ModeKeys.PREDICT:
        if mode == ModeKeys.TRAIN:
            model._make_train_function()  # pylint: disable=protected-access
        else:
            model._make_test_function()  # pylint: disable=protected-access
        loss = model.total_loss

        eval_metric_ops = _convert_keras_metrics_to_estimator(model)

    # Set train_op only during train.
    if mode == ModeKeys.TRAIN:
        train_op = model.train_function.updates_op

    if not model._is_graph_network:
        # Reset model state to original state,
        # to avoid `model_fn` being destructive for the initial model argument.
        models.in_place_subclassed_model_state_restoration(keras_model)

    return model_fn_lib.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops,
        export_outputs={
            _DEFAULT_SERVING_KEY: export_lib.PredictOutput(predictions)
        })
def _get_tensor(self, is_finite):
    """Build a tensor that is 1.0 when `is_finite` is true and NaN otherwise.

    With a distribution strategy active, the result is produced per replica:
    replica 0 gets that tensor and all other replicas get 1.0.
    """
    tensor = control_flow_ops.cond(
        is_finite, lambda: 1., lambda: float('NaN'))

    strategy_active = distribution_strategy_context.has_strategy()
    if not strategy_active:
        return tensor

    def get():
        rep_id = (
            distribution_strategy_context.get_replica_context()
            .replica_id_in_sync_group)
        on_first_replica = math_ops.equal(rep_id, 0)
        return control_flow_ops.cond(
            on_first_replica, lambda: tensor, lambda: 1.)

    strategy = distribution_strategy_context.get_strategy()
    return strategy.extended.call_for_each_replica(get)
def enumerate_epochs(self):
    """Yields `(epoch, tf.data.Iterator)`.

    Iterates epochs from `self._initial_epoch` up to `self._epochs`, stopping
    early once `self._insufficient_data` is set, and calls
    `self._adapter.on_epoch_end()` after each yielded epoch.
    """
    data_iterator = iter(self._dataset)
    for epoch in range(self._initial_epoch, self._epochs):
        if self._insufficient_data:  # Set by `catch_stop_iteration`.
            break
        if self._adapter.should_recreate_iterator():
            if not ds_context.has_strategy():
                data_iterator = iter(self._dataset)
            else:
                # TODO(b/138326910): remove this when MultiDeviceIterator is a
                # CompositeTensor (unless this is more efficient)
                data_iterator._initializer  # pylint: disable=pointless-statement, protected-access
        yield epoch, data_iterator
        self._adapter.on_epoch_end()
def testScope(self):
    """Entering `dist.scope()` activates `dist`; leaving restores the default."""
    _assert_in_default_state(self)
    dist = _TestStrategy()
    with dist.scope():
        self.assertIs(None, ds_context.get_replica_context())
        self.assertIs(dist, ds_context.get_cross_replica_context())
        self.assertTrue(ds_context.in_cross_replica_context())
        self.assertTrue(ds_context.has_strategy())
        self.assertIs(dist, ds_context.get_strategy())

        expected_value = _get_test_variable(
            "baz",
            variable_scope.VariableSynchronization.AUTO,
            variable_scope.VariableAggregation.NONE)
        created = variable_scope.variable(1.0, name="baz")
        self.assertDictEqual(expected_value, created)
    _assert_in_default_state(self)
def testScope(self):
    """Check context state inside and outside of a `_TestStrategy` scope."""
    _assert_in_default_state(self)

    dist = _TestStrategy()
    scope = dist.scope()
    with scope:
        # Inside the scope we are in cross-replica context of `dist`.
        self.assertIs(None, ds_context.get_replica_context())
        self.assertIs(dist, ds_context.get_cross_replica_context())
        self.assertTrue(ds_context.in_cross_replica_context())
        self.assertTrue(ds_context.has_strategy())
        self.assertIs(dist, ds_context.get_strategy())

        sync_mode = variable_scope.VariableSynchronization.AUTO
        aggregation_mode = variable_scope.VariableAggregation.NONE
        expected_value = _get_test_variable("baz", sync_mode, aggregation_mode)
        self.assertDictEqual(
            expected_value, variable_scope.variable(1.0, name="baz"))

    _assert_in_default_state(self)
def default_model_compile(model, lr, loss='mean_absolute_error'):
    """Compile `model` with RectifiedAdam and the project's default metrics.

    Gradient-norm clipping (`clipnorm=1.`) is only requested when the global
    mixed-precision policy has no loss scale and no distribution strategy is
    active.

    Args:
      model: the model to compile.
      lr: learning rate for the optimizer.
      loss: a loss accepted by `model.compile`, or one of the shortcuts
        `'compound_mssim'` / `'mssim'`.
    """
    precision_policy = mixed_precision.global_policy()
    distributed = distribute_ctx.has_strategy()
    use_clipnorm = precision_policy.loss_scale is None and not distributed
    opt_kwargs = {'clipnorm': 1.} if use_clipnorm else {}

    if loss == 'compound_mssim':
        loss = compound_l1_mssim_loss
    elif loss == 'mssim':
        loss = partial(compound_l1_mssim_loss, alpha=0.9999)
        loss.__name__ = "mssim"

    optimizer = tfa.optimizers.RectifiedAdam(lr=lr, **opt_kwargs)
    model.compile(
        optimizer=optimizer,
        loss=loss,
        metrics=['mean_squared_error', keras_psnr, keras_ssim],
    )
def _assign_moving_average(self, variable, value, momentum, inputs_size):
    """Subtract `(variable - value) * (1 - momentum)` from `variable`.

    Under a distribution strategy the delta is replaced by zeros when
    `inputs_size` is not positive.

    Returns:
      The result of `state_ops.assign_sub` on `variable`.
    """
    scope_values = [variable, value, momentum]
    with ops.name_scope(None, 'AssignMovingAvg', scope_values) as scope:
        with ops.colocate_with(variable):
            var_dtype = variable.dtype.base_dtype
            decay = ops.convert_to_tensor(1.0 - momentum, name='decay')
            if decay.dtype != var_dtype:
                decay = math_ops.cast(decay, var_dtype)
            cast_value = math_ops.cast(value, variable.dtype)
            update_delta = (variable - cast_value) * decay
            # TODO(b/129279393): Support zero batch input in non
            # DistributionStrategy code as well.
            if distribution_strategy_context.has_strategy():
                update_delta = tf_utils.smart_cond(
                    inputs_size > 0,
                    lambda: update_delta,
                    lambda: K.zeros_like(update_delta))
            return state_ops.assign_sub(variable, update_delta, name=scope)
def compute_weighted_loss(losses,
                          sample_weight=None,
                          reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE):
    """Delegate to `losses_utils.compute_weighted_loss` after a strategy check.

    Raises:
      ValueError: if a distribution strategy is active and `reduction` is
        `AUTO` or `SUM_OVER_BATCH_SIZE`.
    """
    strategy_active = distribution_strategy_context.has_strategy()
    if strategy_active and reduction in {
            tf.keras.losses.Reduction.AUTO,
            tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE}:
        message = (
            'Please use `tf.keras.losses.Reduction.SUM` or `tf.keras.losses.Reduction.NONE` for loss reduction when '
            'losses are used with `tf.distribute.Strategy` outside of the built-in training loops. You can implement '
            '`tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE` using global batch size like:\n'
            '```\n'
            'with strategy.scope():\n'
            ' loss_obj = tf.keras.losses.CategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)\n'
            '....\n'
            ' loss = tf.reduce_sum(loss_obj(labels, predictions)) * (1. / global_batch_size)\n'
            '```\n'
            'Please see https://www.tensorflow.org/tutorials/distribute/custom_training for more details.')
        raise ValueError(message)
    return losses_utils.compute_weighted_loss(
        losses, sample_weight=sample_weight, reduction=reduction)
def update(self, grads):
    """Updates loss scale based on if gradients are finite in current step.

    Returns:
      A pair `(update_op, should_apply_gradients)`; the second element is the
      finiteness check result.
    """
    if not distribution_strategy_context.has_strategy():
        is_finite = _is_all_finite(grads)
    else:
        distribution = distribution_strategy_context.get_cross_replica_context()

        def get_is_finite(grads):
            # Booleans cannot be reduced with DistributionStrategy, so the
            # check result is cast to float first.
            all_finite = _is_all_finite(grads)
            return math_ops.cast(all_finite, dtypes.float32)

        is_finite_float = distribution.extended.call_for_each_replica(
            get_is_finite, args=(grads,))
        reduced_is_finite_float = distribution.reduce(
            reduce_util.ReduceOp.SUM, is_finite_float, axis=None)
        is_finite = math_ops.equal(
            reduced_is_finite_float, distribution.num_replicas_in_sync)

    def update_if_finite_grads():
        """Update assuming the gradients are finite."""

        def incr_loss_scale():
            scaled_up = self._current_loss_scale * self._multiplier
            return control_flow_ops.group(
                _assign_if_finite(self._current_loss_scale, scaled_up),
                self._num_good_steps.assign(0))

        return control_flow_ops.cond(
            self._num_good_steps + 1 >= self._increment_period,
            incr_loss_scale,
            lambda: _op_in_graph_mode(self._num_good_steps.assign_add(1)))

    def update_if_not_finite_grads():
        """Update assuming the gradients are nonfinite."""
        scaled_down = math_ops.maximum(
            self._current_loss_scale / self._multiplier, 1)
        return control_flow_ops.group(
            self._num_good_steps.assign(0),
            self._current_loss_scale.assign(scaled_down))

    update_op = control_flow_ops.cond(
        is_finite, update_if_finite_grads, update_if_not_finite_grads)
    return update_op, is_finite
def _create_variable(*args, **kwargs):
    """Creates a variable, refusing to do so inside a strategy scope.

    Args:
      *args: positional arguments passed along to `variables.Variable`.
      **kwargs: keyword arguments passed along to `variables.Variable`.

    Returns:
      The created variable.

    Raises:
      ValueError: if a distribution strategy is active.
    """
    if ds_context.has_strategy():
        # TODO(wangpeng): Link to the RNG guide for solutions in such cases.
        raise ValueError(
            "Creating a generator within a strategy scope is disallowed, because "
            "there is ambiguity on how to replicate a generator (e.g. should it be "
            "copied so that each replica gets the same random numbers, or 'split' "
            "so that each replica gets different random numbers).")
    return variables.Variable(*args, **kwargs)
def _assign_moving_average(self, variable, value, momentum):
    """Subtract `(variable - value) * (1 - momentum)` from `variable`.

    The update ops are colocated with `variable`, except when the active
    distribution strategy's class is named `TPUStrategy`.

    Args:
      variable: the moving-average variable to update.
      value: the new value to move towards; cast to `variable.dtype`.
      momentum: the momentum; `1 - momentum` is used as the decay.

    Returns:
      The result of `state_ops.assign_sub` on `variable`.
    """
    with ops.name_scope(None, 'AssignMovingAvg',
                        [variable, value, momentum]) as scope:
        # TODO(b/120571621): We want to avoid colocating the variables here
        # since TPUStrategy does not implement replica local variables.
        # Remove this hack once we support TPULocalVariables.
        is_tpu_strategy = False
        if distribution_strategy_context.has_strategy():
            distribute = distribution_strategy_context.get_strategy()
            if distribute.__class__.__name__ == 'TPUStrategy':
                is_tpu_strategy = True

        def _assign():
            decay = ops.convert_to_tensor(1.0 - momentum, name='decay')
            if decay.dtype != variable.dtype.base_dtype:
                decay = math_ops.cast(decay, variable.dtype.base_dtype)
            update_delta = (
                variable - math_ops.cast(value, variable.dtype)) * decay
            return state_ops.assign_sub(variable, update_delta, name=scope)

        # `is_tpu_strategy` used to be computed and then ignored, so the
        # colocation the comment above says to avoid was always applied.
        if is_tpu_strategy:
            return _assign()
        with ops.colocate_with(variable):
            return _assign()
def _update_weights(self, fast_weights, slow_weights, alpha):
    """Synchronize slow and fast weights every `self.k` iterations.

    When `self._iterations % self.k == 0`, each slow weight moves towards its
    fast weight by a factor `alpha`, and the fast weight is then reset to the
    updated slow weight.

    Args:
      fast_weights: iterable of fast-weight variables.
      slow_weights: iterable of slow-weight variables, paired with
        `fast_weights` by position.
      alpha: interpolation factor for the slow-weight update.
    """

    def _update_slow_weight(slow_weight, fast_weight, a):
        slow_weight.assign_add(a * (fast_weight - slow_weight))

    def _update_fast_weight(fast_weight, slow_weight):
        fast_weight.assign(slow_weight)

    if tf.equal(tf.cast(self._iterations, tf.float32) % self.k, 0):
        if distribution_strategy_context.has_strategy():
            # `call_for_each_replica` lives on the strategy's `extended`
            # object. `get_replica_context()` returns a ReplicaContext (or
            # None), neither of which has `.extended`, so fetch the strategy.
            distribution = distribution_strategy_context.get_strategy()
            for fast, slow in zip(fast_weights, slow_weights):
                distribution.extended.call_for_each_replica(
                    _update_slow_weight, args=(slow, fast.value(), alpha))
                distribution.extended.call_for_each_replica(
                    _update_fast_weight, args=(fast, slow.value()))
        else:
            for fast, slow in zip(fast_weights, slow_weights):
                _update_slow_weight(slow, fast.value(), alpha)
                _update_fast_weight(fast, slow.value())
def _var_key(var):
    """Key for representing a primary variable, for looking up slots.

    If a distribution strategy is active and `var` has a `_primary_var`
    attribute, the primary variable is used. The key is `_shared_name` when
    the variable has an `op` attribute and `_unique_id` otherwise.

    Args:
      var: the variable.

    Returns:
      the unique name of the variable.
    """
    # pylint: disable=protected-access
    if distribute_ctx.has_strategy() and hasattr(var, "_primary_var"):
        var = var._primary_var
    return var._shared_name if hasattr(var, "op") else var._unique_id
def strategy_supports_loss_scaling():
    """Returns True if the current Strategy supports loss scaling.

    With no strategy active this is always True; otherwise the active
    strategy must be an instance of one of the listed strategy classes.
    """
    if distribution_strategy_context.has_strategy():
        strategy = distribution_strategy_context.get_strategy()
        # Strategies are supported if either there is only one replica or if
        # variables are replicated per device. Otherwise, the current
        # model.fit() implementation and most custom training loops
        # incorrectly unscale the gradients. Currently, gradients are unscaled
        # once per compute replica, but they should be unscaled once per
        # variable replica. When there is one variable replica for each
        # compute replica, this works fine, but otherwise issues will occur.
        # TODO(reedwm): Support all strategies.
        supported_strategies = (
            collective_all_reduce_strategy.CollectiveAllReduceStrategy,
            collective_all_reduce_strategy.CollectiveAllReduceStrategyV1,
            one_device_strategy.OneDeviceStrategy,
            one_device_strategy.OneDeviceStrategyV1,
            mirrored_strategy.MirroredStrategy,
            mirrored_strategy.MirroredStrategyV1,
        )
        return isinstance(strategy, supported_strategies)
    return True
def _var_key(var):
    """Return the slot-lookup key for a (possibly distributed) variable.

    Under an active distribution strategy, a variable exposing
    `_primary_var` is first replaced by that primary variable.

    Args:
      var: the variable.

    Returns:
      `var._shared_name` if the variable has an `op` attribute, otherwise
      `var._unique_id`.
    """
    # pylint: disable=protected-access
    strategy_active = distribute_ctx.has_strategy()
    if strategy_active and hasattr(var, "_primary_var"):
        var = var._primary_var
    if not hasattr(var, "op"):
        return var._unique_id
    return var._shared_name
def _get_reduction(self):
    """Handles `AUTO` reduction cases and returns the reduction value.

    Returns:
      `SUM_OVER_BATCH_SIZE` when `self.reduction` is `AUTO`, otherwise
      `self.reduction`.

    Raises:
      ValueError: if a distribution strategy is active and `self.reduction`
        is `AUTO` or `SUM_OVER_BATCH_SIZE`.
    """
    if distribution_strategy_context.has_strategy() and (
        self.reduction == losses_utils.ReductionV2.AUTO or
        self.reduction == losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE):
        # The example in the message spells the enum as `Reduction.NONE`,
        # matching the names used earlier in the same message; it used to say
        # `reduction.NONE`, which does not exist.
        raise ValueError(
            'Please use `tf.keras.losses.Reduction.SUM` or '
            '`tf.keras.losses.Reduction.NONE` for loss reduction when losses are '
            'used with `tf.distribute.Strategy` outside of the built-in training '
            'loops. You can implement '
            '`tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE` using global batch '
            'size like:\n```\nwith strategy.scope():\n'
            ' loss_obj = tf.keras.losses.CategoricalCrossentropy('
            'reduction=tf.keras.losses.Reduction.NONE)\n....\n'
            ' loss = tf.reduce_sum(loss_obj(labels, predictions)) * '
            '(1. / global_batch_size)\n```\nPlease see '
            'https://www.tensorflow.org/alpha/tutorials/distribute/training_loops'
            ' for more details.')

    if self.reduction == losses_utils.ReductionV2.AUTO:
        return losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE
    return self.reduction
def _get_distribution_strategy(model):
    """Get the model's distribution strategy.

    Returns `model._distribution_strategy` when it is set; otherwise the
    current strategy from the distribution-strategy context.

    Raises:
      ValueError: if the model has no strategy of its own but a strategy
        scope is active at call time.
    """
    if model._distribution_strategy:
        return model._distribution_strategy

    # Use the default strategy if no strategy was present at compile.
    # Validate there is no actual strategy scope active at execution
    # time.
    strategy = distribution_strategy_context.get_strategy()
    if not distribution_strategy_context.has_strategy():
        return strategy
    raise ValueError(
        'Model was compiled without any active distribution strategy, '
        'but there is an execution-time distribution '
        'strategy scope of (%s). '
        'Try to make sure your code looks similar to the following.\n'
        'with strategy.scope():\n'
        ' model=_create_model()\n'
        ' model.compile(...)\n'
        ' model.fit(...)' % strategy)
def _get_reduction(self):
    """Handles `AUTO` reduction cases and returns the reduction value.

    Returns:
      `SUM_OVER_BATCH_SIZE` when `self.reduction` is `AUTO`, otherwise
      `self.reduction`.

    Raises:
      ValueError: if a distribution strategy is active and `self.reduction`
        is `AUTO` or `SUM_OVER_BATCH_SIZE`.
    """
    if distribution_strategy_context.has_strategy() and (
        self.reduction == losses_utils.ReductionV2.AUTO or
        self.reduction == losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE):
        # The example in the message spells the enum as `Reduction.NONE`,
        # matching the names used earlier in the same message; it used to say
        # `reduction.None`, which does not exist.
        raise ValueError(
            'Please use `tf.keras.losses.Reduction.SUM` or '
            '`tf.keras.losses.Reduction.NONE` for loss reduction when losses are '
            'used with `tf.distribute.Strategy` outside of the built-in training '
            'loops. You can implement '
            '`tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE` using global batch '
            'size like:\n```\nwith strategy.scope():\n'
            ' loss_obj = tf.keras.losses.CategoricalCrossentropy('
            'reduction=tf.keras.losses.Reduction.NONE)\n....\n'
            ' loss = tf.reduce_sum(loss_obj(labels, predictions)) * '
            '(1. / global_batch_size)\n```\nPlease see '
            'https://www.tensorflow.org/alpha/tutorials/distribute/training_loops'
            ' for more details.')

    if self.reduction == losses_utils.ReductionV2.AUTO:
        return losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE
    return self.reduction
def _centering_weights(weight, weight_identity):
    """Center and L2-normalize `weight_identity`, then assign it to `weight`.

    The mean over axes `[0, 1, 2]` is subtracted, and the result is divided
    by the 2-norm computed over the flattened leading axes (in float32).

    Args:
      weight: the variable that receives the normalized value.
      weight_identity: the tensor the normalized value is computed from.

    Returns:
      A one-element list holding the assign op.

    Raises:
      RuntimeError: if called in a cross-replica context under a strategy.
    """
    # When using group norm, normalizing the weights variable may give a
    # better result.
    mean = tf.reduce_mean(weight_identity, axis=[0, 1, 2])
    centered_weight = weight_identity - mean
    flattened = tf.reshape(centered_weight, [-1, centered_weight.shape[-1]])
    weight_norm = linalg_ops.norm(
        tf.cast(flattened, dtype=tf.float32), ord=2, axis=-2)
    normed_weight = centered_weight / tf.cast(weight_norm, dtype=weight.dtype)

    if not distribute_ctx.has_strategy():
        assign_op = weight.assign(normed_weight)
    else:
        # Handle DistributionStrategy case.
        if distribute_ctx.in_cross_replica_context():
            raise RuntimeError(
                "Use `_distributed_apply()` instead of "
                "`apply_gradients()` in a cross-replica context.")
        replica_ctx = distribute_ctx.get_replica_context()
        assign_op = replica_ctx.merge_call(
            assign_vars, args=(weight, normed_weight))[0]
    return [assign_op]
def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Apply gradients to variables.

    This is the second part of `minimize()`. It returns an `Operation` that
    applies gradients.

    Args:
      grads_and_vars: List of (gradient, variable) pairs as returned by
        `compute_gradients()`.
      global_step: Optional `Variable` to increment by one after the
        variables have been updated.
      name: Optional name for the returned operation. Default to the name
        passed to the `Optimizer` constructor.

    Returns:
      An `Operation` that applies the specified gradients. If `global_step`
      was not None, that operation also increments `global_step`.

    Raises:
      TypeError: If `grads_and_vars` is malformed.
      ValueError: If none of the variables have gradients.
      RuntimeError: If you should use `_distributed_apply()` instead.
    """
    # This is a default implementation of apply_gradients() that can be shared
    # by most optimizers. It relies on the subclass implementing the following
    # methods: _create_slots(), _prepare(), _apply_dense(), and _apply_sparse().

    # TODO(isaprykin): Get rid of `has_strategy()` check by
    # always calling _distributed_apply(), using the default distribution
    # as needed.
    if distribute_ctx.has_strategy():
        # Handle DistributionStrategy case.
        if distribute_ctx.in_cross_replica_context():
            raise RuntimeError("Use `_distributed_apply()` instead of "
                               "`apply_gradients()` in a cross-replica context.")

        grads_and_vars = get_filtered_grad_fn(lambda: grads_and_vars)()
        # Hand the actual update over to `_distributed_apply` via merge_call.
        return distribute_ctx.get_replica_context().merge_call(
            self._distributed_apply, args=(grads_and_vars, global_step, name))

    # No DistributionStrategy case.
    grads_and_vars = tuple(grads_and_vars)  # Make sure repeat iteration works.
    if not grads_and_vars:
        raise ValueError("No variables provided.")
    # Build (gradient, variable, processor) triples; `None` gradients are
    # kept here and skipped later.
    converted_grads_and_vars = []
    for g, v in grads_and_vars:
        if g is not None:
            try:
                # Convert the grad to Tensor or IndexedSlices if necessary.
                g = ops.convert_to_tensor_or_indexed_slices(g)
            except TypeError:
                raise TypeError(
                    "Gradient must be convertible to a Tensor"
                    " or IndexedSlices, or None: %s" % g)
            if not isinstance(g, (ops.Tensor, ops.IndexedSlices)):
                raise TypeError(
                    "Gradient must be a Tensor, IndexedSlices, or None: %s" % g)
        p = _get_processor(v)
        converted_grads_and_vars.append((g, v, p))

    converted_grads_and_vars = tuple(converted_grads_and_vars)
    # Only variables that actually received a gradient get slots.
    var_list = [v for g, v, _ in converted_grads_and_vars if g is not None]
    if not var_list:
        raise ValueError("No gradients provided for any variable: %s." %
                         ([str(v) for _, v, _ in converted_grads_and_vars],))
    with ops.init_scope():
        self._create_slots(var_list)
    update_ops = []
    with ops.name_scope(name, self._name) as name:
        self._prepare()
        for grad, var, processor in converted_grads_and_vars:
            if grad is None:
                continue
            # We colocate all ops created in _apply_dense or _apply_sparse
            # on the same device as the variable.
            # TODO(apassos): figure out how to get the variable name here.
            if context.executing_eagerly() or isinstance(
                var,
                resource_variable_ops.ResourceVariable) and not var._in_graph_mode:  # pylint: disable=protected-access
                scope_name = ""
            else:
                scope_name = var.op.name
            with ops.name_scope("update_" + scope_name), ops.colocate_with(var):
                update_ops.append(processor.update_op(self, grad))
        if global_step is None:
            apply_updates = self._finish(update_ops, name)
        else:
            # Increment `global_step` only after the finished update op ran.
            with ops.control_dependencies([self._finish(update_ops, "update")]):
                with ops.colocate_with(global_step):
                    if isinstance(global_step, resource_variable_ops.ResourceVariable):
                        # TODO(apassos): the implicit read in assign_add is slow; consider
                        # making it less so.
                        apply_updates = resource_variable_ops.assign_add_variable_op(
                            global_step.handle,
                            ops.convert_to_tensor(1, dtype=global_step.dtype),
                            name=name)
                    else:
                        apply_updates = state_ops.assign_add(global_step, 1, name=name)

        if not context.executing_eagerly():
            # Register the op in the TRAIN_OP collection, once.
            if isinstance(apply_updates, ops.Tensor):
                apply_updates = apply_updates.op
            train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
            if apply_updates not in train_op:
                train_op.append(apply_updates)

        return apply_updates
def _compute_gradients_until_finite(
    distribution, loss_scale_gradient_tapes, loss_scale, target, sources,
    output_gradients, unconnected_gradients):
  """Compute gradients and update the loss scale until the gradients are finite.

  This must be called in a cross-replica context.

  This is a function instead of a method of LossScaleGradientTape, as the `self`
  parameter would be meaningless. There is one LossScaleGradientTape per
  replica, but this function is called once total (not per replica), so there
  cannot be a singular `self` parameter.

  The loop body always runs at least once. It is then repeated for as long as
  the second value returned by `loss_scale.update(grads)` (`ready_to_update`)
  is false and `loss_scale()` is greater than 1.

  Args:
    distribution: The distribution strategy in effect.
    loss_scale_gradient_tapes: A PerReplica value of LossScaleGradientTapes.
      Contains the LossScaleGradientTape of each replica.
    loss_scale: The loss scale to use to scale the loss and unscale the
      gradient.
    target: a list or nested structure of Tensors or Variables to be
      differentiated.
    sources: a list or nested structure of Tensors or Variables. `target` will
      be differentiated against elements in `sources`.
    output_gradients: Passed to GradientTape.gradient.
    unconnected_gradients: Passed to GradientTape.gradient.

  Returns:
    The gradients of `target` with respect to `sources`, packed into the same
    structure as `sources`. Entries whose gradient was None are returned as
    None.
  """
  # Autograph cannot convert this function, so we must use an explicit
  # tf.while_loop.
  # TODO(b/143572314): Fix Autograph so that it can convert this function, then
  # replace the tf.while_loop with a Python while loop.

  # For convenience, we only deal with flattened sources
  flattened_sources = nest.flatten(sources)

  # Define the initial loop variables of the while loop.

  # Dummy value for initial_grads. The first iteration of the loop will
  # overwrite `grads` to the actual gradients.
  initial_grads = flattened_sources
  if distribution_strategy_context.has_strategy():
    # A while_loop requires the initial values to have the same types as the
    # return values from the body. However, 'initial_grads' may have type
    # 'DistributionVariable', while body returns a 'PerReplica'. While both
    # types subclass 'DistributedValues', while_loop will still throw an error.
    # So we convert 'initial_grads' to be PerReplica values.
    # TODO(b/146084534): Once the bug is fixed, remove this special case.
    initial_grads = _convert_to_per_replicas(distribution, initial_grads)
  initial_ready_to_update = False
  initial_is_first_iteration = True

  def cond(grads, ready_to_update, is_first_iteration):
    """The condition of the while loop."""
    del grads
    # Equivalent to:
    # `is_first_iteration or (not ready_to_update and loss_scale() > 1)`
    return math_ops.logical_or(
        is_first_iteration,
        math_ops.logical_and(
            math_ops.logical_not(ready_to_update),
            math_ops.greater(loss_scale(), 1)))

  # Boolean list specifying whether each gradient is None or not. Set by body().
  is_nones = []

  def body(grads, ready_to_update, is_first_iteration):
    """The body of the while loop."""
    # The incoming loop variables are not read; every iteration recomputes
    # them from scratch.
    del grads, ready_to_update, is_first_iteration

    def replica_fn(gradient_tape, target, flattened_sources, output_gradients,
                   initial_grads):
      """Scales the loss, computes the gradients, and unscales the gradients."""
      loss_scale_val = loss_scale()
      with gradient_tape:  # re-enter gradient tape so it sees the loss scaling
        scaled_target = nest.map_structure(
            lambda t: t * math_ops.cast(loss_scale_val, t.dtype), target)

      scaled_grads = super(LossScaleGradientTape, gradient_tape).gradient(
          scaled_target, flattened_sources, output_gradients,
          unconnected_gradients)

      # Slice-assign so the list owned by the enclosing function is updated in
      # place rather than rebinding a local name.
      is_nones[:] = [g is None for g in scaled_grads]
      inv_loss_scale = 1.0 / loss_scale_val
      grads = []  # The unscaled gradients
      for g, initial_grad in zip(scaled_grads, initial_grads):
        if g is not None:
          # We call ensure_shape as shape information can be lost for certain
          # ops, such as tf.transpose, if the op is called in a tf.function and
          # has inputs created outside the tf.function.
          # TODO(b/132092188): Remove ensure_shape call after this has been
          # fixed.
          g = array_ops.ensure_shape(g, initial_grad.shape)
          grads.append(g * math_ops.cast(inv_loss_scale, g.dtype))
        else:
          # We cannot return None from a tf.while_loop, so we pass a dummy
          # tensor instead. We use initial_grad as a dummy tensor as it has the
          # correct shape and dtype. We replace it with None outside the while
          # loop.
          grads.append(initial_grad)
      return grads

    # Switch to a replica-context to compute gradients once per replica.
    grads = distribution.experimental_run_v2(
        replica_fn, args=(loss_scale_gradient_tapes, target, flattened_sources,
                          output_gradients, initial_grads))
    # Check for non-finite gradients possibly resulting from scaling.
    _, ready_to_update = loss_scale.update(grads)
    is_first_iteration = False
    return grads, ready_to_update, is_first_iteration

  grads, _, _ = control_flow_ops.while_loop(
      cond, body, [initial_grads, initial_ready_to_update,
                   initial_is_first_iteration],
  )
  # Swap the dummy tensors that stood in for missing gradients back to None.
  grads = [None if is_none else g for g, is_none in zip(grads, is_nones)]
  grads = nest.pack_sequence_as(sources, grads)
  return grads
def fit(
    self, model, x=None, y=None, batch_size=None, epochs=1, verbose=1,
    callbacks=None, validation_split=0., validation_data=None, shuffle=True,
    class_weight=None, sample_weight=None, initial_epoch=0,
    steps_per_epoch=None, validation_steps=None, validation_freq=1,
    max_queue_size=10, workers=1, use_multiprocessing=False, **kwargs):
  """Run the training loop for `model`, with optional per-epoch validation.

  The inputs are turned into data adapters by `_process_training_inputs`, the
  resulting datasets are distributed with the model's distribution strategy,
  and `run_one_epoch` is called once per epoch for training and, when a
  validation adapter exists and `validation_freq` allows it, once more for
  evaluation.

  NOTE(review): the parameter list looks like the public Keras `Model.fit`
  API; the descriptions below only state what this function itself does with
  each argument -- confirm detailed semantics against the helpers it calls.

  Args:
    model: The model to train. Its `optimizer`, `history`, `reset_metrics()`
      and the private helpers `_validate_or_infer_batch_size` and
      `_maybe_load_initial_epoch_from_ckpt` are used here.
    x: Input data; forwarded to `_process_training_inputs`.
    y: Target data; forwarded to `_process_training_inputs`.
    batch_size: Batch size. Validated or inferred through
      `model._validate_or_infer_batch_size` and then adjusted by
      `dist_utils.process_batch_and_step_size`.
    epochs: Exclusive upper bound of the epoch index; epochs run over
      `range(initial_epoch, epochs)`.
    verbose: If truthy (and a sample or step count is known),
      `_print_train_info` is called before training. Also forwarded to
      `training_context.on_start`.
    callbacks: Callbacks checked by `dist_utils.validate_callbacks` and
      passed to `cbks.configure_callbacks`.
    validation_split: Forwarded to `dist_utils.process_batch_and_step_size`
      and `_process_training_inputs`.
    validation_data: Forwarded to `_process_training_inputs`.
    shuffle: Forwarded to `_process_training_inputs`.
    class_weight: Forwarded to `_process_training_inputs` as `class_weights`.
    sample_weight: Forwarded to `_process_training_inputs` as
      `sample_weights`.
    initial_epoch: Epoch index to start from; replaced by the return value of
      `model._maybe_load_initial_epoch_from_ckpt` before the loop starts.
    steps_per_epoch: Number of steps per training epoch. When falsy it is
      taken from the training adapter's `get_size()` and, if still None, from
      `training_utils.infer_steps_for_dataset`.
    validation_steps: Number of steps per validation run. When falsy it is
      taken from the validation adapter's `get_size()` or, failing that, from
      `training_utils.infer_steps_for_dataset`.
    validation_freq: Forwarded to `training_utils.should_run_validation` to
      decide whether validation runs after a given epoch.
    max_queue_size: Forwarded to `_process_training_inputs`.
    workers: Forwarded to `_process_training_inputs`.
    use_multiprocessing: Forwarded to `_process_training_inputs`.
    **kwargs: Accepted but not referenced anywhere in this function body.

  Returns:
    `model.history`.
  """
  batch_size = model._validate_or_infer_batch_size(
      batch_size, steps_per_epoch, x)

  strategy = _get_distribution_strategy(model)
  batch_size, steps_per_epoch = dist_utils.process_batch_and_step_size(
      strategy,
      x,
      batch_size,
      steps_per_epoch,
      ModeKeys.TRAIN,
      validation_split=validation_split)
  dist_utils.validate_callbacks(input_callbacks=callbacks,
                                optimizer=model.optimizer)
  # Enter tf.distribute.Strategy scope.
  with strategy.scope():
    training_data_adapter, validation_adapter = _process_training_inputs(
        model,
        x,
        y,
        batch_size=batch_size,
        epochs=epochs,
        sample_weights=sample_weight,
        class_weights=class_weight,
        validation_split=validation_split,
        steps_per_epoch=steps_per_epoch,
        shuffle=shuffle,
        validation_data=validation_data,
        validation_steps=validation_steps,
        distribution_strategy=strategy,
        max_queue_size=max_queue_size,
        workers=workers,
        use_multiprocessing=use_multiprocessing)

    total_samples = _get_total_number_of_samples(training_data_adapter)
    # Progress is counted in samples when the total is known, else in steps.
    use_sample = total_samples is not None
    do_validation = (validation_adapter is not None)

    # Queried before `steps_per_epoch` is overwritten below, so the adapter
    # sees the value produced by `process_batch_and_step_size`.
    recreate_training_iterator = (
        training_data_adapter.should_recreate_iterator(steps_per_epoch))
    if not steps_per_epoch:
      # TODO(b/139762795): Add step inference for when steps is None to
      # prevent end of sequence warning message.
      steps_per_epoch = training_data_adapter.get_size()

    # tf.print('{} on {} steps.'.format(ModeKeys.TRAIN, steps_per_epoch))
    training_context = TrainingContext()

    training_dataset = training_data_adapter.get_dataset()
    # Raise an error if steps_per_epoch isn't specified but the dataset
    # is infinite.
    # TODO(scottzhu): This check should probably happen in the adapter
    inferred_steps = training_utils.infer_steps_for_dataset(
        model,
        training_dataset,
        steps_per_epoch,
        steps_name='steps_per_epoch',
        epochs=0)

    steps_per_epoch = (
        inferred_steps if steps_per_epoch is None else steps_per_epoch)

    training_dataset = strategy.experimental_distribute_dataset(
        training_dataset)

    training_function = training_v2_utils._get_or_make_execution_function(
        model, ModeKeys.TRAIN)

    # Created lazily inside the epoch loop.
    training_data_iter = None
    if do_validation:
      validation_dataset = validation_adapter.get_dataset()
      if not validation_steps:
        # Raise an error if validation_steps isn't specified but the
        # validation dataset is infinite.
        validation_steps = (
            validation_adapter.get_size() or
            training_utils.infer_steps_for_dataset(
                model,
                validation_dataset,
                validation_steps,
                steps_name='validation_steps'))
      eval_function = training_v2_utils._get_or_make_execution_function(
          model, ModeKeys.TEST)
      # Created lazily the first time validation runs.
      eval_data_iter = None
      validation_dataset = strategy.experimental_distribute_dataset(
          validation_dataset)
      val_total_samples = _get_total_number_of_samples(validation_adapter)
    else:
      val_total_samples = None

    if verbose and (total_samples or steps_per_epoch):
      _print_train_info(total_samples, steps_per_epoch, val_total_samples,
                        validation_steps)

    training_callbacks = cbks.configure_callbacks(
        callbacks,
        model,
        do_validation=do_validation,
        batch_size=batch_size,
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
        samples=total_samples or steps_per_epoch,
        count_mode='samples' if use_sample else 'steps',
        verbose=0,  # Handle ProgBarLogger separately in this loop.
        mode=ModeKeys.TRAIN)

    with training_context.on_start(model, training_callbacks, use_sample,
                                   verbose, ModeKeys.TRAIN):

      initial_epoch = model._maybe_load_initial_epoch_from_ckpt(
          initial_epoch, ModeKeys.TRAIN)

      for epoch in range(initial_epoch, epochs):
        if training_context.callbacks.model.stop_training:
          break

        # Training
        with training_context.on_epoch(epoch, ModeKeys.TRAIN) as epoch_logs:
          model.reset_metrics()
          if training_data_iter is None or recreate_training_iterator:
            if (training_data_iter is not None and
                distribution_strategy_context.has_strategy()):
              # TODO(kaftan): remove this when MultiDeviceIterator is a
              ## compositetensor (unless this is more efficient)
              training_data_iter._initializer  # pylint: disable=pointless-statement
            else:
              training_data_iter = iter(training_dataset)

          training_result = run_one_epoch(
              model,
              training_data_iter,
              training_function,
              dataset_size=training_data_adapter.get_size(),
              batch_size=training_data_adapter.batch_size(),
              strategy=strategy,
              steps_per_epoch=steps_per_epoch,
              num_samples=total_samples,
              mode=ModeKeys.TRAIN,
              training_context=training_context,
              total_epochs=epochs)
          cbks.make_logs(model, epoch_logs, training_result, ModeKeys.TRAIN)

          # In the case of steps_per_epoch = None, the final cardinality will
          # be determined when the inputs are fully consumed (eg dataset or
          # generator). Update the steps_per_epoch to the new value.
          if (steps_per_epoch is None and
              training_context.progbar.progbar.target is not None):
            steps_per_epoch = training_context.progbar.progbar.target

          # Evaluation
          if (do_validation and
              training_utils.should_run_validation(validation_freq, epoch) and
              not training_callbacks.model.stop_training):
            if (eval_data_iter is not None and
                distribution_strategy_context.has_strategy()):
              # TODO(kaftan): remove this when MultiDeviceIterator is a
              ## compositetensor (unless this is more efficient)
              eval_data_iter._initializer  # pylint: disable=pointless-statement
            else:
              eval_data_iter = iter(validation_dataset)

            # NOTE(review): `count_mode` below is derived from the *training*
            # `use_sample` flag, not from `val_total_samples` -- confirm this
            # is intended.
            validation_callbacks = cbks.configure_callbacks(
                training_callbacks,
                model,
                batch_size=batch_size,
                epochs=1,
                steps_per_epoch=validation_steps,
                samples=val_total_samples or validation_steps,
                count_mode='samples' if use_sample else 'steps',
                verbose=0,  # Handle ProgBarLogger separately in this loop.
                mode=ModeKeys.TEST)

            eval_context = TrainingContext()
            with eval_context.on_start(
                model,
                validation_callbacks,
                use_sample,
                verbose=0,
                mode=ModeKeys.TEST):
              with eval_context.on_epoch(epoch, ModeKeys.TEST):
                model.reset_metrics()
                eval_result = run_one_epoch(
                    model,
                    eval_data_iter,
                    eval_function,
                    dataset_size=validation_adapter.get_size(),
                    batch_size=validation_adapter.batch_size(),
                    strategy=strategy,
                    steps_per_epoch=validation_steps,
                    num_samples=val_total_samples,
                    mode=ModeKeys.TEST,
                    training_context=eval_context,
                    total_epochs=1)
                # Validation results are merged into the training epoch's
                # logs under a `val_` prefix.
                cbks.make_logs(model, epoch_logs, eval_result, ModeKeys.TEST,
                               prefix='val_')

  return model.history
def is_default_strategy(strategy):
  """Report whether `strategy` leaves the strategy context in its default state.

  Enters `strategy.scope()` and checks `has_strategy()` from inside it.

  Args:
    strategy: An object exposing a `scope()` context manager.

  Returns:
    True if `distribution_strategy_context.has_strategy()` is false while
    inside `strategy.scope()`, False otherwise.
  """
  with strategy.scope():
    strategy_registered = distribution_strategy_context.has_strategy()
  return not strategy_registered
def apply_gradients(self, grads_and_vars, name=None):
  """Apply gradients to variables.

  This is the second part of `minimize()`. It returns an `Operation` that
  applies gradients.

  Args:
    grads_and_vars: List of (gradient, variable) pairs as returned by
      `compute_gradients()`.
    name: Optional name for the returned operation.  Default to the name
      passed to the `Optimizer` constructor.

  Returns:
    The value returned by `merge_update_step(update_ops, self.iterations)`,
    where `update_ops` holds one update per (gradient, variable) pair.
    NOTE(review): presumably this also increments `self.iterations` -- confirm
    against `merge_update_step`.

  Raises:
    NotImplementedError: If a variable in `grads_and_vars` is an `ops.Tensor`.
    RuntimeError: If a gradient is an `ops.IndexedSlices` and its variable
      has a constraint function.
    TypeError: If `grads_and_vars` is malformed (presumably raised by
      `_filter_grads`; not visible here).
    ValueError: If none of the variables have gradients (presumably raised
      by `_filter_grads`; not visible here).
  """
  grads_and_vars = _filter_grads(grads_and_vars)
  var_list = [v for (_, v) in grads_and_vars]

  if distribute_ctx.has_strategy():
    # Under a distribution strategy the gradients are merged first and then
    # paired back up with the variables they belong to.
    reduced_grads = merge_grads(grads_and_vars)
    grads_and_vars = zip(reduced_grads, var_list)

  self._create_hypers()
  with ops.init_scope():
    self._create_slots(var_list)

  self._prepare(var_list)

  def update_grad_to_var(grad, var):
    """Apply gradient to variable."""
    if isinstance(var, ops.Tensor):
      # Format into a single message; passing `var` as a second exception
      # argument would render the error as a tuple.
      raise NotImplementedError("Trying to update a Tensor %s" % (var,))
    if isinstance(grad, ops.IndexedSlices):
      if var.constraint is not None:
        raise RuntimeError(
            "Cannot use a constraint function on a sparse variable.")
      return self._resource_apply_sparse_duplicate_indices(
          grad.values, var, grad.indices)
    update_op = self._resource_apply_dense(grad, var)
    if var.constraint is None:
      return update_op
    # Re-apply the constraint only after the dense update has run.
    with ops.control_dependencies([update_op]):
      return var.assign(var.constraint(var))

  with ops.name_scope(name, self._name) as name:
    # Loop-invariant: evaluate once instead of once per variable.
    eagerly_outside_functions = ops.executing_eagerly_outside_functions()
    update_ops = []
    for grad, var in grads_and_vars:
      scope_name = "" if eagerly_outside_functions else "_" + var.op.name
      with ops.name_scope("update" + scope_name):
        update_ops.append(update_grad_to_var(grad, var))
    # control dependencies does not work in per replica mode, please change
    # this once b/118841692 is fixed.
    # with ops.control_dependencies(update_ops):
    #   apply_updates = self._iterations.assign_add(1).op
    apply_updates = merge_update_step(update_ops, self.iterations)
    return apply_updates
def decorated(metric_obj, *args):
  """Evaluate the wrapped result function and cache its output on the metric.

  Depending on the distribution context, the result is computed either
  directly under `variable_sync_on_read_context()` or in cross-replica mode
  through `replica_context.merge_call`. In both cases the output goes through
  `array_ops.identity`, is stored on `metric_obj._call_result`, and returned.

  Args:
    metric_obj: The metric being evaluated; its `name` is used in the error
      message and the result is cached on it.
    *args: Positional arguments forwarded to `result_fn`.

  Returns:
    The value produced by `array_ops.identity`, or a dict of such values when
    `result_fn` returned a dict on the non-merge path.

  Raises:
    RuntimeError: If the result is neither a Tensor/Variable/float/int nor a
      dict, and `array_ops.identity` rejects it with ValueError or TypeError.
  """
  strategy_active = distribution_strategy_context.has_strategy()
  replica_ctx = distribution_strategy_context.get_replica_context()

  # The purpose of using `merge_call` to call `result()` is to trigger cross
  # replica aggregation of metric state variables (SyncOnReadVariable). After
  # we introduced `variable_sync_on_read_context`, in principle there is no
  # need to use `merge_call` here. However the branch still exists because:
  #
  # 1. Keras V1 training code sometimes assumes `result_t` is the same tensor
  #    across replicas (achieved by `merge_call`). With
  #    `variable_sync_on_read_context` each replica gets their own tensors
  #    residing on replica's device, thus breaking the assumption.
  # 2. Keras c/fit creates a tf.function (a.k.a, train_function) that returns
  #    the metric values of the first replica. With
  #    `variable_sync_on_read_context` since each replica gets their own
  #    tensors, the metric result tensors on the non-first replicas are not in
  #    the return value of train_function, making TF graph optimizer prune the
  #    branch that computes and aggregates those metric results. As a result,
  #    if NCCL is used to do the aggregation, the program will hang because
  #    NCCL ops are only launched on the non-pruned first replica.
  #
  # We condition on strategy.extended._use_merge_call() since we know if it is
  # false, the program uses `jit_compile` to compile replica fn, meaning it is
  # not V1 training (hence #1 is okay), and no pruning will happen as
  # compiled functions are not inlined (hence #2 is okay).
  use_merge_call = (
      strategy_active and replica_ctx is not None and
      distribution_strategy_context.get_strategy().extended._use_merge_call())

  if use_merge_call:
    # TODO(psv): Test distribution of metrics using different distribution
    # strategies.

    # merge_call passes the distribution object as the first argument of the
    # merge function. This wrapper absorbs it so that the result function
    # does not need such a parameter.
    def merge_fn_wrapper(distribution, merge_fn, *call_args):
      # `merge_fn` arrives as a `PerReplica` value. Every entry is an
      # identical copy of the function passed below, so use the first one.
      local_fn = distribution.experimental_local_results(merge_fn)[0]
      # Identity keeps the control dependency between the update op from
      # `update_state` and the result working when the result is a tensor.
      return array_ops.identity(local_fn(*call_args))

    # Leave replica mode and compute the value in cross replica mode.
    result_t = replica_ctx.merge_call(
        merge_fn_wrapper, args=(result_fn,) + args)
  else:
    with distribution_strategy_context.variable_sync_on_read_context():
      raw_result = result_fn(*args)
      # Results need to be wrapped in a `tf.identity` op to ensure
      # correct execution order.
      directly_wrappable = (ops.Tensor, variables_module.Variable, float, int)
      if isinstance(raw_result, directly_wrappable):
        result_t = array_ops.identity(raw_result)
      elif isinstance(raw_result, dict):
        result_t = {}
        for key, value in raw_result.items():
          result_t[key] = array_ops.identity(value)
      else:
        # Unknown type: let `identity` try, and translate its rejection into
        # an error that names the offending metric.
        try:
          result_t = array_ops.identity(raw_result)
        except (ValueError, TypeError):
          raise RuntimeError(
              'The output of `metric.result()` can only be a single '
              'Tensor/Variable, or a dict of Tensors/Variables. '
              'For metric %s, got result %s.' %
              (metric_obj.name, raw_result))

  # We are saving the result op here to be used in train/test execution
  # functions. This basically gives the result op that was generated with a
  # control dep to the updates for these workflows.
  metric_obj._call_result = result_t
  return result_t
def _support_zero_size_input(self):
  """Look up the `experimental_enable_get_next_as_optional` strategy flag.

  Returns:
    The falsy value of `has_strategy()` when no strategy is registered.
    Otherwise the `experimental_enable_get_next_as_optional` attribute of the
    current strategy's `extended` object, or False if that attribute is
    missing.
  """
  strategy_registered = distribution_strategy_context.has_strategy()
  if not strategy_registered:
    # Hand back the falsy value itself, exactly as a short-circuiting `and`
    # would.
    return strategy_registered
  extended = distribution_strategy_context.get_strategy().extended
  return getattr(extended, 'experimental_enable_get_next_as_optional', False)
def is_default_strategy(strategy):
  """Report whether `strategy` leaves the strategy context in its default state.

  Enters `strategy.scope()` and checks `has_strategy()` from inside it.

  Args:
    strategy: An object exposing a `scope()` context manager.

  Returns:
    True if `distribution_strategy_context.has_strategy()` is false while
    inside `strategy.scope()`, False otherwise.
  """
  with strategy.scope():
    strategy_registered = distribution_strategy_context.has_strategy()
  return not strategy_registered