def _assert_in_default_state(t):
  t.assertIs(distribution_strategy_context._get_default_tower_context(),
             distribution_strategy_context.get_tower_context())
  t.assertIs(None, distribution_strategy_context.get_cross_tower_context())
  t.assertIs(distribution_strategy_context._get_default_distribution_strategy(),
             distribution_strategy_context.get_distribution_strategy())
  t.assertFalse(distribution_strategy_context.has_distribution_strategy())
def create_slot(primary, val, name, colocate_with_primary=True):
  """Create a slot initialized to the given value.

  The type of the slot is determined by the given value.

  Args:
    primary: The primary `Variable` or `Tensor`.
    val: A `Tensor` specifying the initial value of the slot.
    name: Name to use for the slot variable.
    colocate_with_primary: Boolean.  If True the slot is located
      on the same device as `primary`.

  Returns:
    A `Variable` object.
  """
  # Scope the slot name in the namespace of the primary variable.
  # Set "primary.op.name + '/' + name" as default name, so the scope name of
  # optimizer can be shared when reuse is True. Meanwhile when reuse is False
  # and the same name has been previously used, the scope name will add '_N'
  # as suffix for unique identifications.
  validate_shape = val.get_shape().is_fully_defined()
  if context.executing_eagerly():
    prefix = primary._shared_name  # pylint: disable=protected-access
  else:
    prefix = primary.op.name
  with variable_scope.variable_scope(None, prefix + "/" + name):
    if colocate_with_primary:
      distribution_strategy = (
          distribution_strategy_context.get_distribution_strategy())
      with distribution_strategy.colocate_vars_with(primary):
        return _create_slot_var(primary, val, "", validate_shape, None, None)
    else:
      return _create_slot_var(primary, val, "", validate_shape, None, None)
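# A minimal usage sketch for create_slot, assuming an optimizer-style caller.
# `var` stands in for an existing trainable variable and "momentum" is an
# arbitrary slot name chosen for illustration.
def _example_create_momentum_slot(var):
  # The slot's dtype and shape are taken from the initial-value tensor.
  initial = array_ops.zeros(var.get_shape(), dtype=var.dtype.base_dtype)
  return create_slot(var, initial, "momentum", colocate_with_primary=True)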
def _create_non_slot_variable(self, initial_value, name, colocate_with):
  """Add an extra variable, not associated with a slot."""
  # Recommendation: Use OptimizerV2 if your optimizer uses non-slot variables.
  eager = context.executing_eagerly()
  graph = None if eager else colocate_with.graph

  key = (name, graph)
  v = self._non_slot_dict.get(key, None)
  if v is None:
    self._maybe_initialize_checkpointable()
    distribution_strategy = distribute_ctx.get_distribution_strategy()
    with distribution_strategy.colocate_vars_with(colocate_with):
      if eager:
        restored_initial_value = self._preload_simple_restoration(
            name=name, shape=None)
        if restored_initial_value is not None:
          initial_value = restored_initial_value
      v = variable_scope.variable(initial_value, name=name, trainable=False)
  # Restore this variable by name if necessary, but don't add a
  # Checkpointable dependency. Optimizers return the current graph's
  # non-slot variables from _checkpoint_dependencies explicitly rather
  # than unconditionally adding dependencies (since there may be multiple
  # non-slot variables with the same name in different graphs, trying to
  # save all of them would result in errors).
  self._handle_deferred_dependencies(name=name, checkpointable=v)
  self._non_slot_dict[key] = v
  return v
def _assign_func(self, *args, **kwargs):
  f = kwargs.pop("f")
  if distribution_strategy_context.get_cross_tower_context():
    update_device = distribute_lib.get_update_device()
    if update_device is not None:
      # We are calling an assign function in an update context.
      return f(self._v, *args, **kwargs)

    # We are calling an assign function in cross tower context, wrap it in
    # an update call.
    return distribution_strategy_context.get_distribution_strategy().update(
        self, f, *args, **kwargs)
  else:
    assert distribution_strategy_context.get_tower_context()
    # We are calling an assign function in tower context.
    # We reduce the value we want to assign/add/sub. More details about how
    # we handle the different use cases can be found in the _reduce method.
    # We call the function with the reduced value.
    if self._aggregation == vs.VariableAggregation.NONE:
      raise ValueError("You must specify an aggregation method to update a "
                       "variable in Tower Context.")

    def merge_fn(strategy, value, *other_args, **other_kwargs):
      return strategy.update(
          self, f,
          strategy.reduce(
              aggregation=self._aggregation, value=value, destinations=self),
          *other_args, **other_kwargs)

    return distribution_strategy_context.get_tower_context().merge_call(
        merge_fn, *args, **kwargs)
def _scale_loss(loss_value):
  if distribute_lib.get_loss_reduction() == ds_reduce_util.ReduceOp.MEAN:
    num_replicas = (
        distribute_ctx.get_distribution_strategy().num_replicas_in_sync)
    if num_replicas > 1:
      loss_value *= (1. / num_replicas)
  return loss_value
def _scale_loss(loss_value):
  if (distribute_lib.get_loss_reduction() ==
      variable_scope.VariableAggregation.MEAN):
    num_replicas = (
        distribute_ctx.get_distribution_strategy().num_replicas_in_sync)
    if num_replicas > 1:
      loss_value *= (1. / num_replicas)
  return loss_value
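# A worked example of the scaling above with illustrative numbers: under a
# MEAN loss reduction with 4 replicas in sync, each replica scales its local
# loss by 1/4, so summing the per-replica gradients reproduces the gradient
# of the mean loss.
def _example_mean_loss_scaling():
  local_loss = 8.0   # per-replica loss (made-up value)
  num_replicas = 4   # assumed replicas in sync
  scaled = local_loss * (1. / num_replicas)
  assert scaled == 2.0  # summing 4 such terms recovers the mean loss, 8.0
  return scaled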
def merge_fn(dist, s):
  self.assertIs(
      distribution_strategy_context._get_default_distribution_strategy(),
      dist)
  self.assertIs(None, distribution_strategy_context.get_tower_context())
  self.assertIs(dist,
                distribution_strategy_context.get_cross_tower_context())
  self.assertIs(dist,
                distribution_strategy_context.get_distribution_strategy())
  self.assertFalse(
      distribution_strategy_context.has_distribution_strategy())
  return "foo_" + s
def run_fn():
  tower_context = distribution_strategy_context.get_tower_context()
  self.assertTrue(tower_context is not None)
  self.assertIs(None,
                distribution_strategy_context.get_cross_tower_context())
  self.assertTrue(distribution_strategy_context.has_distribution_strategy())
  self.assertIs(dist,
                distribution_strategy_context.get_distribution_strategy())
  self.assertEqual("foo", tower_context.merge_call(None, test_arg="foo"))
  expected_value = _get_test_variable(
      "bar", variable_scope.VariableSynchronization.AUTO,
      variable_scope.VariableAggregation.NONE)
  self.assertDictEqual(expected_value,
                       variable_scope.variable(1.0, name="bar"))
def _assign_func(self, *args, **kwargs):
  f = kwargs.pop("f")
  if distribution_strategy_context.get_cross_tower_context():
    update_device = distribute_lib.get_update_device()
    if update_device is not None:
      # We are calling an assign function on the mirrored variable in an
      # update context.
      v = self.get(device=update_device)
      return f(v, *args, **kwargs)

    # We are calling assign on the mirrored variable in cross tower context,
    # use update to update the variable.
    strategy = distribution_strategy_context.get_distribution_strategy()
    updates = strategy.update(self, f, *args, **kwargs)
    grouped = strategy.group(updates)
    if isinstance(updates, DistributedValues) and updates.is_tensor_like:
      # Make sure we run all updates. Without this, something like
      # session.run(mirrored_var.assign*(...)) may only update one tower.
      index = {}
      for d in updates.devices:
        with ops.device(d), ops.control_dependencies([grouped]):
          index[d] = array_ops.identity(updates.get(d))
      return Mirrored(index)
    else:
      return grouped
  else:
    _assert_tower_context()
    # We are calling an assign function on the mirrored variable in tower
    # context.
    # We reduce the value we want to assign/add/sub. More details about how
    # we handle the different use cases can be found in the _reduce method.
    # We call the function on each of the mirrored variables with the
    # reduced value.
    if self._aggregation == vs.VariableAggregation.NONE:
      raise ValueError("You must specify an aggregation method to update a "
                       "MirroredVariable in Tower Context.")

    def merge_fn(strategy, value, *other_args, **other_kwargs):
      return strategy.update(
          self, f,
          strategy.reduce(
              aggregation=self._aggregation, value=value, destinations=self),
          *other_args, **other_kwargs)

    return distribution_strategy_context.get_tower_context().merge_call(
        merge_fn, *args, **kwargs)
def testScope(self):
  _assert_in_default_state(self)
  dist = _TestStrategy()
  with dist.scope():
    self.assertIs(None, distribution_strategy_context.get_tower_context())
    self.assertIs(dist,
                  distribution_strategy_context.get_cross_tower_context())
    self.assertTrue(distribution_strategy_context.has_distribution_strategy())
    self.assertIs(dist,
                  distribution_strategy_context.get_distribution_strategy())
    expected_value = _get_test_variable(
        "baz", variable_scope.VariableSynchronization.AUTO,
        variable_scope.VariableAggregation.NONE)
    self.assertDictEqual(expected_value,
                         variable_scope.variable(1.0, name="baz"))
  _assert_in_default_state(self)
def set_last_step_output(self, name, output,
                         aggregation=variables_lib.VariableAggregation.NONE):
  """Set `output` with `name` to be outputted from the last step.

  Args:
    name: String, name to identify the output. Doesn't need to match tensor
      name.
    output: The tensors that should be outputted with `name`. See below for
      actual types supported.
    aggregation: Aggregation method to use to aggregate outputs from
      multiple towers. Required if `set_last_step_output` is called in a
      tower context. Optional in cross_tower_context.
      When present, the outputs from all the towers are aggregated using the
      current distribution strategy's `reduce` method. Hence, the type of
      `output` must be what's supported by the corresponding `reduce`
      method. For example, if using MirroredStrategy and aggregation is set,
      output must be a `PerDevice` value.
      The aggregation method is also recorded in a dictionary
      `_last_step_outputs_aggregations` so the outputs can later be
      interpreted as already reduced or not.
  """
  if distribution_strategy_context.get_cross_tower_context():
    self._last_step_outputs_aggregations[name] = aggregation
    if aggregation is variables_lib.VariableAggregation.NONE:
      self._last_step_outputs[name] = output
    else:
      distribution = distribution_strategy_context.get_distribution_strategy()
      self._last_step_outputs[name] = distribution.reduce(
          aggregation, output, destinations="/device:CPU:0")
  else:
    assert aggregation is not variables_lib.VariableAggregation.NONE

    def merge_fn(distribution, value):
      self._last_step_outputs[name] = distribution.reduce(
          aggregation, value, destinations="/device:CPU:0")
      # Setting this inside the `merge_fn` because all towers share the same
      # context object, so it's more robust to set it only once (even if all
      # the towers are trying to set the same value).
      self._last_step_outputs_aggregations[name] = aggregation

    distribution_strategy_context.get_tower_context().merge_call(
        merge_fn, output)
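# A minimal usage sketch for set_last_step_output inside a step function,
# assuming `ctx` is the surrounding multi-step context object that defines
# the method above; `compute_loss` is a hypothetical per-tower loss function
# and the names are illustrative.
def step_fn(ctx, inputs):
  loss = compute_loss(inputs)  # hypothetical per-tower loss
  # In a tower context an aggregation method is required; MEAN reduces the
  # per-tower losses via the strategy's `reduce` before recording them.
  ctx.set_last_step_output(
      "mean_loss", loss,
      aggregation=variables_lib.VariableAggregation.MEAN)
  return loss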
def build(self, input_shape):
  input_shape = tensor_shape.TensorShape(input_shape)
  if not input_shape.ndims:
    raise ValueError('Input has undefined rank:', input_shape)
  ndims = len(input_shape)

  # Convert axis to list and resolve negatives
  if isinstance(self.axis, int):
    self.axis = [self.axis]
  if not isinstance(self.axis, list):
    raise TypeError('axis must be int or list, type given: %s' %
                    type(self.axis))
  for idx, x in enumerate(self.axis):
    if x < 0:
      self.axis[idx] = ndims + x

  # Validate axes
  for x in self.axis:
    if x < 0 or x >= ndims:
      raise ValueError('Invalid axis: %d' % x)
  if len(self.axis) != len(set(self.axis)):
    raise ValueError('Duplicate axis: %s' % self.axis)

  if self.virtual_batch_size is not None:
    if self.virtual_batch_size <= 0:
      raise ValueError('virtual_batch_size must be a positive integer that '
                       'divides the true batch size of the input Tensor')
    # If using virtual batches, the first dimension must be the batch
    # dimension and cannot be the batch norm axis
    if 0 in self.axis:
      raise ValueError('When using virtual_batch_size, the batch dimension '
                       'must be 0 and thus axis cannot include 0')
    if self.adjustment is not None:
      raise ValueError('When using virtual_batch_size, adjustment cannot '
                       'be specified')

  if self.fused:
    # Currently fused batch norm doesn't support renorm. It also only
    # supports an input tensor of rank 4 and a channel dimension on axis
    # 1 or 3.
    # TODO(yaozhang): if input is not 4D, reshape it to 4D and reshape the
    # output back to its original shape accordingly.
    self.fused = (not self.renorm and
                  ndims == 4 and
                  self.axis in [[1], [3]] and
                  self.virtual_batch_size is None and
                  self.adjustment is None)
    # TODO(chrisying): fused batch norm is currently not supported for
    # multi-axis batch norm and by extension virtual batches. In some cases,
    # it might be possible to use fused batch norm but would require
    # reshaping the Tensor to 4D with the axis in 1 or 3 (preferred 1) which
    # is particularly tricky. A compromise might be to just support the most
    # common use case (turning 5D w/ virtual batch to NCHW)

  if self.fused:
    if self.axis == [1]:
      self._data_format = 'NCHW'
    elif self.axis == [3]:
      self._data_format = 'NHWC'
    else:
      raise ValueError('Unsupported axis, fused batch norm only supports '
                       'axis == [1] or axis == [3]')

  # Raise parameters of fp16 batch norm to fp32
  if self.dtype == dtypes.float16 or self.dtype == dtypes.bfloat16:
    param_dtype = dtypes.float32
  else:
    param_dtype = self.dtype or dtypes.float32

  axis_to_dim = {x: input_shape.dims[x].value for x in self.axis}
  for x in axis_to_dim:
    if axis_to_dim[x] is None:
      raise ValueError('Input has undefined `axis` dimension. Input shape: ',
                       input_shape)
  self.input_spec = InputSpec(ndim=ndims, axes=axis_to_dim)

  if len(axis_to_dim) == 1 and self.virtual_batch_size is None:
    # Single axis batch norm (most common/default use-case)
    param_shape = (list(axis_to_dim.values())[0],)
  else:
    # Parameter shape is the original shape but with 1 in all non-axis dims
    param_shape = [axis_to_dim[i] if i in axis_to_dim
                   else 1 for i in range(ndims)]
    if self.virtual_batch_size is not None:
      # When using virtual batches, add an extra dim at index 1
      param_shape.insert(1, 1)
      for idx, x in enumerate(self.axis):
        self.axis[idx] = x + 1  # Account for added dimension

  if self.scale:
    self.gamma = self.add_weight(
        name='gamma',
        shape=param_shape,
        dtype=param_dtype,
        initializer=self.gamma_initializer,
        regularizer=self.gamma_regularizer,
        constraint=self.gamma_constraint,
        trainable=True)
  else:
    self.gamma = None
    if self.fused:
      self._gamma_const = array_ops.constant(
          1.0, dtype=param_dtype, shape=param_shape)

  if self.center:
    self.beta = self.add_weight(
        name='beta',
        shape=param_shape,
        dtype=param_dtype,
        initializer=self.beta_initializer,
        regularizer=self.beta_regularizer,
        constraint=self.beta_constraint,
        trainable=True)
  else:
    self.beta = None
    if self.fused:
      self._beta_const = array_ops.constant(
          0.0, dtype=param_dtype, shape=param_shape)

  try:
    # Disable variable partitioning when creating the moving mean and
    # variance
    if hasattr(self, '_scope') and self._scope:
      partitioner = self._scope.partitioner
      self._scope.set_partitioner(None)
    else:
      partitioner = None
    self.moving_mean = self.add_weight(
        name='moving_mean',
        shape=param_shape,
        dtype=param_dtype,
        initializer=self.moving_mean_initializer,
        synchronization=tf_variables.VariableSynchronization.ON_READ,
        trainable=False,
        aggregation=tf_variables.VariableAggregation.MEAN)

    self.moving_variance = self.add_weight(
        name='moving_variance',
        shape=param_shape,
        dtype=param_dtype,
        initializer=self.moving_variance_initializer,
        synchronization=tf_variables.VariableSynchronization.ON_READ,
        trainable=False,
        aggregation=tf_variables.VariableAggregation.MEAN)

    if self.renorm:
      # Create variables to maintain the moving mean and standard deviation.
      # These are used in training and thus are different from the moving
      # averages above. The renorm variables are colocated with moving_mean
      # and moving_variance.
      # NOTE: below, the outer `with device` block causes the current device
      # stack to be cleared. The nested ones use a `lambda` to set the
      # desired device and ignore any devices that may be set by the custom
      # getter.
      def _renorm_variable(name, shape):
        var = self.add_weight(
            name=name,
            shape=shape,
            dtype=param_dtype,
            initializer=init_ops.zeros_initializer(),
            synchronization=tf_variables.VariableSynchronization.ON_READ,
            trainable=False,
            aggregation=tf_variables.VariableAggregation.MEAN)
        return var

      with distribution_strategy_context.get_distribution_strategy(
      ).colocate_vars_with(self.moving_mean):
        self.renorm_mean = _renorm_variable('renorm_mean', param_shape)
        self.renorm_mean_weight = _renorm_variable('renorm_mean_weight', ())
      # We initialize renorm_stddev to 0, and maintain the (0-initialized)
      # renorm_stddev_weight. This allows us to (1) mix the average
      # stddev with the minibatch stddev early in training, and (2) compute
      # the unbiased average stddev by dividing renorm_stddev by the weight.
      with distribution_strategy_context.get_distribution_strategy(
      ).colocate_vars_with(self.moving_variance):
        self.renorm_stddev = _renorm_variable('renorm_stddev', param_shape)
        self.renorm_stddev_weight = _renorm_variable(
            'renorm_stddev_weight', ())
  finally:
    if partitioner:
      self._scope.set_partitioner(partitioner)
  self.built = True
def tensor():
  return distribution_strategy_context.get_distribution_strategy().read_var(
      tower_local_variable)
def read_value(self):
  return distribution_strategy_context.get_distribution_strategy().read_var(
      self)
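# A minimal usage sketch for the two read helpers above, assuming
# `tower_local_var` is a tower-local variable created under the active
# distribution strategy (the name is illustrative). Reading through the
# strategy returns the aggregated value rather than one device's copy,
# which is what read_value() dispatches to.
def read_aggregated(tower_local_var):
  strategy = distribution_strategy_context.get_distribution_strategy()
  return strategy.read_var(tower_local_var)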
def build(self, input_shape):
  input_shape = tensor_shape.TensorShape(input_shape)
  if not input_shape.ndims:
    raise ValueError('Input has undefined rank:', input_shape)
  ndims = len(input_shape)

  # Convert axis to list and resolve negatives
  if isinstance(self.axis, int):
    self.axis = [self.axis]
  if not isinstance(self.axis, list):
    raise TypeError('axis must be int or list, type given: %s' %
                    type(self.axis))
  for idx, x in enumerate(self.axis):
    if x < 0:
      self.axis[idx] = ndims + x

  # Validate axes
  for x in self.axis:
    if x < 0 or x >= ndims:
      raise ValueError('Invalid axis: %d' % x)
  if len(self.axis) != len(set(self.axis)):
    raise ValueError('Duplicate axis: %s' % self.axis)

  if self.virtual_batch_size is not None:
    if self.virtual_batch_size <= 0:
      raise ValueError('virtual_batch_size must be a positive integer that '
                       'divides the true batch size of the input Tensor')
    # If using virtual batches, the first dimension must be the batch
    # dimension and cannot be the batch norm axis
    if 0 in self.axis:
      raise ValueError('When using virtual_batch_size, the batch dimension '
                       'must be 0 and thus axis cannot include 0')
    if self.adjustment is not None:
      raise ValueError('When using virtual_batch_size, adjustment cannot '
                       'be specified')

  if self.fused:
    # Currently fused batch norm doesn't support renorm. It also only
    # supports an input tensor of rank 4 and a channel dimension on axis
    # 1 or 3.
    # TODO(yaozhang): if input is not 4D, reshape it to 4D and reshape the
    # output back to its original shape accordingly.
    self.fused = (not self.renorm and
                  ndims == 4 and
                  self.axis in [[1], [3]] and
                  self.virtual_batch_size is None and
                  self.adjustment is None)
    # TODO(chrisying): fused batch norm is currently not supported for
    # multi-axis batch norm and by extension virtual batches. In some cases,
    # it might be possible to use fused batch norm but would require
    # reshaping the Tensor to 4D with the axis in 1 or 3 (preferred 1) which
    # is particularly tricky. A compromise might be to just support the most
    # common use case (turning 5D w/ virtual batch to NCHW)

  if self.fused:
    if self.axis == [1]:
      self._data_format = 'NCHW'
    elif self.axis == [3]:
      self._data_format = 'NHWC'
    else:
      raise ValueError('Unsupported axis, fused batch norm only supports '
                       'axis == [1] or axis == [3]')

  # Raise parameters of fp16 batch norm to fp32
  if self.dtype == dtypes.float16 or self.dtype == dtypes.bfloat16:
    param_dtype = dtypes.float32
  else:
    param_dtype = self.dtype or dtypes.float32

  axis_to_dim = {x: input_shape[x].value for x in self.axis}
  for x in axis_to_dim:
    if axis_to_dim[x] is None:
      raise ValueError('Input has undefined `axis` dimension. Input shape: ',
                       input_shape)
  self.input_spec = InputSpec(ndim=ndims, axes=axis_to_dim)

  if len(axis_to_dim) == 1 and self.virtual_batch_size is None:
    # Single axis batch norm (most common/default use-case)
    param_shape = (list(axis_to_dim.values())[0],)
  else:
    # Parameter shape is the original shape but with 1 in all non-axis dims
    param_shape = [axis_to_dim[i] if i in axis_to_dim
                   else 1 for i in range(ndims)]
    if self.virtual_batch_size is not None:
      # When using virtual batches, add an extra dim at index 1
      param_shape.insert(1, 1)
      for idx, x in enumerate(self.axis):
        self.axis[idx] = x + 1  # Account for added dimension

  if self.scale:
    self.gamma = self.add_weight(
        name='gamma',
        shape=param_shape,
        dtype=param_dtype,
        initializer=self.gamma_initializer,
        regularizer=self.gamma_regularizer,
        constraint=self.gamma_constraint,
        trainable=True)
  else:
    self.gamma = None
    if self.fused:
      self._gamma_const = array_ops.constant(
          1.0, dtype=param_dtype, shape=param_shape)

  if self.center:
    self.beta = self.add_weight(
        name='beta',
        shape=param_shape,
        dtype=param_dtype,
        initializer=self.beta_initializer,
        regularizer=self.beta_regularizer,
        constraint=self.beta_constraint,
        trainable=True)
  else:
    self.beta = None
    if self.fused:
      self._beta_const = array_ops.constant(
          0.0, dtype=param_dtype, shape=param_shape)

  try:
    # Disable variable partitioning when creating the moving mean and
    # variance
    if hasattr(self, '_scope') and self._scope:
      partitioner = self._scope.partitioner
      self._scope.set_partitioner(None)
    else:
      partitioner = None
    self.moving_mean = self.add_weight(
        name='moving_mean',
        shape=param_shape,
        dtype=param_dtype,
        initializer=self.moving_mean_initializer,
        synchronization=tf_variables.VariableSynchronization.ON_READ,
        trainable=False,
        aggregation=tf_variables.VariableAggregation.MEAN)

    self.moving_variance = self.add_weight(
        name='moving_variance',
        shape=param_shape,
        dtype=param_dtype,
        initializer=self.moving_variance_initializer,
        synchronization=tf_variables.VariableSynchronization.ON_READ,
        trainable=False,
        aggregation=tf_variables.VariableAggregation.MEAN)

    if self.renorm:
      # Create variables to maintain the moving mean and standard deviation.
      # These are used in training and thus are different from the moving
      # averages above. The renorm variables are colocated with moving_mean
      # and moving_variance.
      # NOTE: below, the outer `with device` block causes the current device
      # stack to be cleared. The nested ones use a `lambda` to set the
      # desired device and ignore any devices that may be set by the custom
      # getter.
      def _renorm_variable(name, shape):
        var = self.add_weight(
            name=name,
            shape=shape,
            dtype=param_dtype,
            initializer=init_ops.zeros_initializer(),
            synchronization=tf_variables.VariableSynchronization.ON_READ,
            trainable=False,
            aggregation=tf_variables.VariableAggregation.MEAN)
        return var

      with distribution_strategy_context.get_distribution_strategy(
      ).colocate_vars_with(self.moving_mean):
        self.renorm_mean = _renorm_variable('renorm_mean', param_shape)
        self.renorm_mean_weight = _renorm_variable('renorm_mean_weight', ())
      # We initialize renorm_stddev to 0, and maintain the (0-initialized)
      # renorm_stddev_weight. This allows us to (1) mix the average
      # stddev with the minibatch stddev early in training, and (2) compute
      # the unbiased average stddev by dividing renorm_stddev by the weight.
      with distribution_strategy_context.get_distribution_strategy(
      ).colocate_vars_with(self.moving_variance):
        self.renorm_stddev = _renorm_variable('renorm_stddev', param_shape)
        self.renorm_stddev_weight = _renorm_variable(
            'renorm_stddev_weight', ())
  finally:
    if partitioner:
      self._scope.set_partitioner(partitioner)
  self.built = True
def compute_gradients(self, loss, var_list=None,
                      gate_gradients=GATE_OP,
                      aggregation_method=None,
                      colocate_gradients_with_ops=False,
                      grad_loss=None):
  """Compute gradients of `loss` for the variables in `var_list`.

  This is the first part of `minimize()`.  It returns a list
  of (gradient, variable) pairs where "gradient" is the gradient
  for "variable".  Note that "gradient" can be a `Tensor`, an
  `IndexedSlices`, or `None` if there is no gradient for the
  given variable.

  Args:
    loss: A Tensor containing the value to minimize or a callable taking
      no arguments which returns the value to minimize. When eager execution
      is enabled it must be a callable.
    var_list: Optional list or tuple of `tf.Variable` to update to minimize
      `loss`.  Defaults to the list of variables collected in the graph
      under the key `GraphKeys.TRAINABLE_VARIABLES`.
    gate_gradients: How to gate the computation of gradients.  Can be
      `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
    aggregation_method: Specifies the method used to combine gradient terms.
      Valid values are defined in the class `AggregationMethod`.
    colocate_gradients_with_ops: If True, try colocating gradients with
      the corresponding op.
    grad_loss: Optional. A `Tensor` holding the gradient computed for
      `loss`.

  Returns:
    A list of (gradient, variable) pairs. Variable is always present, but
    gradient can be `None`.

  Raises:
    TypeError: If `var_list` contains anything else than `Variable` objects.
    ValueError: If some arguments are invalid.
    RuntimeError: If called with eager execution enabled and `loss` is
      not callable.

  @compatibility(eager)
  When eager execution is enabled, `gate_gradients`, `aggregation_method`,
  and `colocate_gradients_with_ops` are ignored.
  @end_compatibility
  """
  if callable(loss):
    with backprop.GradientTape() as tape:
      if var_list is not None:
        tape.watch(var_list)
      loss_value = loss()

      # Scale loss if using a "mean" loss reduction and multiple towers.
      # Have to be careful to call distribute_lib.get_loss_reduction()
      # *after* loss() is evaluated, so we know what loss reduction it uses.
      # TODO(josh11b): Test that we handle weight decay in a reasonable way.
      if (distribute_lib.get_loss_reduction() ==
          variable_scope.VariableAggregation.MEAN):
        num_towers = distribution_strategy_context.get_distribution_strategy(
        ).num_towers
        if num_towers > 1:
          loss_value *= (1. / num_towers)

    if var_list is None:
      var_list = tape.watched_variables()
    grads = tape.gradient(loss_value, var_list, grad_loss)
    return list(zip(grads, var_list))

  # Non-callable/Tensor loss case
  if context.executing_eagerly():
    raise RuntimeError(
        "`loss` passed to Optimizer.compute_gradients should "
        "be a function when eager execution is enabled.")

  # Scale loss if using a "mean" loss reduction and multiple towers.
  if (distribute_lib.get_loss_reduction() ==
      variable_scope.VariableAggregation.MEAN):
    num_towers = distribution_strategy_context.get_distribution_strategy(
    ).num_towers
    if num_towers > 1:
      loss *= (1. / num_towers)

  if gate_gradients not in [Optimizer.GATE_NONE, Optimizer.GATE_OP,
                            Optimizer.GATE_GRAPH]:
    raise ValueError("gate_gradients must be one of: Optimizer.GATE_NONE, "
                     "Optimizer.GATE_OP, Optimizer.GATE_GRAPH.  Not %s" %
                     gate_gradients)
  self._assert_valid_dtypes([loss])
  if grad_loss is not None:
    self._assert_valid_dtypes([grad_loss])
  if var_list is None:
    var_list = (
        variables.trainable_variables() +
        ops.get_collection(ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
  else:
    var_list = nest.flatten(var_list)
  # pylint: disable=protected-access
  var_list += ops.get_collection(ops.GraphKeys._STREAMING_MODEL_PORTS)
  # pylint: enable=protected-access
  processors = [_get_processor(v) for v in var_list]
  if not var_list:
    raise ValueError("No variables to optimize.")
  var_refs = [p.target() for p in processors]
  grads = gradients.gradients(
      loss, var_refs, grad_ys=grad_loss,
      gate_gradients=(gate_gradients == Optimizer.GATE_OP),
      aggregation_method=aggregation_method,
      colocate_gradients_with_ops=colocate_gradients_with_ops)
  if gate_gradients == Optimizer.GATE_GRAPH:
    grads = control_flow_ops.tuple(grads)
  grads_and_vars = list(zip(grads, var_list))
  self._assert_valid_dtypes(
      [v for g, v in grads_and_vars
       if g is not None and v.dtype != dtypes.resource])
  return grads_and_vars
def _restore_checkpoint(self,
                        master,
                        saver=None,
                        checkpoint_dir=None,
                        checkpoint_filename_with_path=None,
                        wait_for_checkpoint=False,
                        max_wait_secs=7200,
                        config=None):
  """Creates a `Session`, and tries to restore a checkpoint.

  Args:
    master: `String` representation of the TensorFlow master to use.
    saver: A `Saver` object used to restore a model.
    checkpoint_dir: Path to the checkpoint files. The latest checkpoint in
      the dir will be used to restore.
    checkpoint_filename_with_path: Full file name path to the checkpoint
      file.
    wait_for_checkpoint: Whether to wait for checkpoint to become available.
    max_wait_secs: Maximum time to wait for checkpoints to become available.
    config: Optional `ConfigProto` proto used to configure the session.

  Returns:
    A pair (sess, is_restored) where 'is_restored' is `True` if
    the session could be restored, `False` otherwise.

  Raises:
    ValueError: If both checkpoint_dir and checkpoint_filename_with_path are
      set.
  """
  self._target = master

  # This is required so that we initialize the TPU device before restoring
  # from a checkpoint, since we'll be placing variables on the device and
  # TPUInitialize wipes out the memory of the device.
  strategy = distribution_strategy_context.get_distribution_strategy()
  if strategy and hasattr(strategy.extended,
                          "_experimental_initialize_system"):
    strategy.extended._experimental_initialize_system()  # pylint: disable=protected-access

  sess = session.Session(self._target, graph=self._graph, config=config)
  if checkpoint_dir and checkpoint_filename_with_path:
    raise ValueError("Can not provide both checkpoint_dir and "
                     "checkpoint_filename_with_path.")
  # If either saver or checkpoint_* is not specified, cannot restore. Just
  # return.
  if not saver or not (checkpoint_dir or checkpoint_filename_with_path):
    return sess, False

  if checkpoint_filename_with_path:
    saver.restore(sess, checkpoint_filename_with_path)
    return sess, True

  # Waits up until max_wait_secs for checkpoint to become available.
  wait_time = 0
  ckpt = checkpoint_management.get_checkpoint_state(checkpoint_dir)
  while not ckpt or not ckpt.model_checkpoint_path:
    if wait_for_checkpoint and wait_time < max_wait_secs:
      logging.info("Waiting for checkpoint to be available.")
      time.sleep(self._recovery_wait_secs)
      wait_time += self._recovery_wait_secs
      ckpt = checkpoint_management.get_checkpoint_state(checkpoint_dir)
    else:
      return sess, False

  # Loads the checkpoint.
  saver.restore(sess, ckpt.model_checkpoint_path)
  saver.recover_last_checkpoints(ckpt.all_model_checkpoint_paths)
  return sess, True
def build(self, input_shape):
  input_shape = tf.TensorShape(input_shape)
  if not input_shape.ndims:
    raise ValueError('Input has undefined rank:', input_shape)
  ndims = len(input_shape)

  # Convert axis to list and resolve negatives
  if isinstance(self.axis, int):
    self.axis = [self.axis]
  if not isinstance(self.axis, list):
    raise TypeError('axis must be int or list, type given: %s' %
                    type(self.axis))
  for idx, x in enumerate(self.axis):
    if x < 0:
      self.axis[idx] = ndims + x

  # Validate axes
  for x in self.axis:
    if x < 0 or x >= ndims:
      raise ValueError('Invalid axis: %d' % x)
  if len(self.axis) != len(set(self.axis)):
    raise ValueError('Duplicate axis: %s' % self.axis)

  # Raise parameters of fp16 batch norm to fp32
  if self.dtype == tf.float16 or self.dtype == tf.bfloat16:
    param_dtype = tf.float32
  else:
    param_dtype = self.dtype or tf.float32

  axis_to_dim = {x: input_shape[x] for x in self.axis}
  for x in axis_to_dim:
    if axis_to_dim[x] is None:
      raise ValueError('Input has undefined `axis` dimension. Input shape: ',
                       input_shape)
  self.input_spec = InputSpec(ndim=ndims, axes=axis_to_dim)

  if self.binary_dense:
    param_shape = [1]
  elif len(axis_to_dim) == 1:
    # Single axis batch norm (most common/default use-case)
    param_shape = (list(axis_to_dim.values())[0],)
  else:
    # Parameter shape is the original shape but with 1 in all non-axis dims
    param_shape = [
        axis_to_dim[i] if i in axis_to_dim else 1 for i in range(ndims)
    ]

  if self.scale:
    self.gamma = self.add_weight(
        name='gamma',
        shape=param_shape,
        dtype=param_dtype,
        initializer=self.gamma_initializer,
        regularizer=self.gamma_regularizer,
        constraint=self.gamma_constraint,
        trainable=True)
  else:
    self.gamma = None

  if self.center:
    self.beta = self.add_weight(
        name='beta',
        shape=param_shape,
        dtype=param_dtype,
        initializer=self.beta_initializer,
        regularizer=self.beta_regularizer,
        constraint=self.beta_constraint,
        trainable=True)
  else:
    self.beta = None

  try:
    # Disable variable partitioning when creating the moving mean and
    # variance
    if hasattr(self, '_scope') and self._scope:
      partitioner = self._scope.partitioner
      self._scope.set_partitioner(None)
    else:
      partitioner = None
    self.moving_mean = self.add_weight(
        name='moving_mean',
        shape=param_shape,
        dtype=param_dtype,
        initializer=self.moving_mean_initializer,
        synchronization=tf.VariableSynchronization.ON_READ,
        trainable=False,
        aggregation=tf.VariableAggregation.MEAN)

    self.moving_variance = self.add_weight(
        name='moving_variance',
        shape=param_shape,
        dtype=param_dtype,
        initializer=self.moving_variance_initializer,
        synchronization=tf.VariableSynchronization.ON_READ,
        trainable=False,
        aggregation=tf.VariableAggregation.MEAN)

    if self.renorm:
      # Create variables to maintain the moving mean and standard deviation.
      # These are used in training and thus are different from the moving
      # averages above. The renorm variables are colocated with moving_mean
      # and moving_variance.
      # NOTE: below, the outer `with device` block causes the current device
      # stack to be cleared. The nested ones use a `lambda` to set the
      # desired device and ignore any devices that may be set by the custom
      # getter.
      def _renorm_variable(name, shape):
        var = self.add_weight(
            name=name,
            shape=shape,
            dtype=param_dtype,
            initializer=tf.zeros_initializer(),
            synchronization=tf.VariableSynchronization.ON_READ,
            trainable=False,
            aggregation=tf.VariableAggregation.MEAN)
        return var

      with distribution_strategy_context.get_distribution_strategy(
      ).colocate_vars_with(self.moving_mean):
        self.renorm_mean = _renorm_variable('renorm_mean', param_shape)
        self.renorm_mean_weight = _renorm_variable('renorm_mean_weight', ())
      # We initialize renorm_stddev to 0, and maintain the (0-initialized)
      # renorm_stddev_weight. This allows us to (1) mix the average
      # stddev with the minibatch stddev early in training, and (2) compute
      # the unbiased average stddev by dividing renorm_stddev by the weight.
      with distribution_strategy_context.get_distribution_strategy(
      ).colocate_vars_with(self.moving_variance):
        self.renorm_stddev = _renorm_variable('renorm_stddev', param_shape)
        self.renorm_stddev_weight = _renorm_variable(
            'renorm_stddev_weight', ())
  finally:
    if partitioner:
      self._scope.set_partitioner(partitioner)
  self.built = True
def apply_gradients(self, grads_and_vars, global_step=None, name=None):
  """Apply gradients to variables.

  This contains most of the synchronization implementation and also wraps
  the apply_gradients() from the real optimizer.

  Args:
    grads_and_vars: List of (gradient, variable) pairs as returned by
      compute_gradients().
    global_step: Optional Variable to increment by one after the
      variables have been updated.
    name: Optional name for the returned operation.  Default to the
      name passed to the Optimizer constructor.

  Returns:
    train_op: The op to dequeue a token so the replicas can exit this batch
      and start the next one. This is executed by each replica.

  Raises:
    ValueError: If the grads_and_vars is empty.
    ValueError: If global step is not provided, the staleness cannot be
      checked.
  """
  if not grads_and_vars:
    raise ValueError("Must supply at least one variable")

  if global_step is None:
    raise ValueError("Global step is required to check staleness")

  self._global_step = global_step
  train_ops = []
  aggregated_grad = []
  var_list = []

  # local_anchor op will be placed on this worker task by default.
  local_anchor = control_flow_ops.no_op()
  # Colocating local_step variable prevents it being placed on the PS.
  distribution_strategy = (
      distribution_strategy_context.get_distribution_strategy())
  with distribution_strategy.colocate_vars_with(local_anchor):
    self._local_step = variable_scope.variable(
        initial_value=0,
        trainable=False,
        collections=[ops.GraphKeys.LOCAL_VARIABLES],
        dtype=global_step.dtype.base_dtype,
        name="sync_rep_local_step")

  self.local_step_init_op = state_ops.assign(self._local_step, global_step)
  chief_init_ops = [self.local_step_init_op]
  self.ready_for_local_init_op = variables.report_uninitialized_variables(
      variables.global_variables())

  with ops.name_scope(None, self._name):
    for grad, var in grads_and_vars:
      var_list.append(var)
      with ops.device(var.device):
        # Dense gradients.
        if grad is None:
          aggregated_grad.append(None)  # pass-through.
          continue
        elif isinstance(grad, ops.Tensor):
          grad_accum = data_flow_ops.ConditionalAccumulator(
              grad.dtype,
              shape=var.get_shape(),
              shared_name=var.name + "/grad_accum")
          train_ops.append(grad_accum.apply_grad(
              grad, local_step=self._local_step))
          aggregated_grad.append(grad_accum.take_grad(
              self._replicas_to_aggregate))
        else:
          if not isinstance(grad, ops.IndexedSlices):
            raise ValueError("Unknown grad type!")
          grad_accum = data_flow_ops.SparseConditionalAccumulator(
              grad.dtype, shape=(), shared_name=var.name + "/grad_accum")
          train_ops.append(grad_accum.apply_indexed_slices_grad(
              grad, local_step=self._local_step))
          aggregated_grad.append(grad_accum.take_indexed_slices_grad(
              self._replicas_to_aggregate))

        self._accumulator_list.append((grad_accum, var.device))

    aggregated_grads_and_vars = zip(aggregated_grad, var_list)

    # sync_op will be assigned to the same device as the global step.
    with ops.device(global_step.device), ops.name_scope(""):
      update_op = self._opt.apply_gradients(aggregated_grads_and_vars,
                                            global_step)

    # Create token queue.
    with ops.device(global_step.device), ops.name_scope(""):
      sync_token_queue = (
          data_flow_ops.FIFOQueue(-1,
                                  global_step.dtype.base_dtype,
                                  shapes=(),
                                  name="sync_token_q",
                                  shared_name="sync_token_q"))
      self._sync_token_queue = sync_token_queue

      # dummy_queue is passed to the queue runner. Don't use the real queues
      # because the queue runner doesn't automatically reopen it once it
      # closed queues in PS devices.
      dummy_queue = (
          data_flow_ops.FIFOQueue(1,
                                  types_pb2.DT_INT32,
                                  shapes=(),
                                  name="dummy_queue",
                                  shared_name="dummy_queue"))

    with ops.device(global_step.device), ops.name_scope(""):
      # Replicas have to wait until they can get a token from the token
      # queue.
      with ops.control_dependencies(train_ops):
        token = sync_token_queue.dequeue()
      train_op = state_ops.assign(self._local_step, token)

      with ops.control_dependencies([update_op]):
        # Sync_op needs to insert tokens to the token queue at the end of
        # the step so the replicas can fetch them to start the next step.
        tokens = array_ops.fill([self._tokens_per_step], global_step)
        sync_op = sync_token_queue.enqueue_many((tokens,))

      if self._variable_averages is not None:
        with ops.control_dependencies([sync_op]), ops.name_scope(""):
          sync_op = self._variable_averages.apply(
              self._variables_to_average)

      self._chief_queue_runner = queue_runner.QueueRunner(dummy_queue,
                                                          [sync_op])
    for accum, dev in self._accumulator_list:
      with ops.device(dev):
        chief_init_ops.append(
            accum.set_global_step(
                global_step, name="SetGlobalStep"))
    self.chief_init_op = control_flow_ops.group(*(chief_init_ops))
    self._gradients_applied = True
    return train_op
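# A minimal usage sketch for SyncReplicasOptimizer via the public tf.train
# API; `worker_count` and `is_chief` are placeholders normally supplied by
# the cluster setup.
import tensorflow as tf

def make_sync_train_op(loss, global_step, worker_count, is_chief):
  base_opt = tf.train.GradientDescentOptimizer(0.01)
  opt = tf.train.SyncReplicasOptimizer(
      base_opt,
      replicas_to_aggregate=worker_count,  # gradients aggregated per step
      total_num_replicas=worker_count)
  train_op = opt.minimize(loss, global_step=global_step)
  # The hook manages the token queue and init ops created in
  # apply_gradients above.
  hook = opt.make_session_run_hook(is_chief)
  return train_op, hook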
def _restore_checkpoint(self,
                        master,
                        saver=None,
                        checkpoint_dir=None,
                        checkpoint_filename_with_path=None,
                        wait_for_checkpoint=False,
                        max_wait_secs=7200,
                        config=None):
  """Creates a `Session`, and tries to restore a checkpoint.

  Args:
    master: `String` representation of the TensorFlow master to use.
    saver: A `Saver` object used to restore a model.
    checkpoint_dir: Path to the checkpoint files. The latest checkpoint in
      the dir will be used to restore.
    checkpoint_filename_with_path: Full file name path to the checkpoint
      file.
    wait_for_checkpoint: Whether to wait for checkpoint to become available.
    max_wait_secs: Maximum time to wait for checkpoints to become available.
    config: Optional `ConfigProto` proto used to configure the session.

  Returns:
    A pair (sess, is_restored) where 'is_restored' is `True` if
    the session could be restored, `False` otherwise.

  Raises:
    ValueError: If both checkpoint_dir and checkpoint_filename_with_path are
      set.
  """
  self._target = master
  sess = session.Session(self._target, graph=self._graph, config=config)

  # TODO(jhseu): Delete once tpu.initialize_system() goes away.
  initialize_ops = (
      distribution_strategy_context.get_distribution_strategy().initialize())
  if initialize_ops:
    sess.run(initialize_ops)

  if checkpoint_dir and checkpoint_filename_with_path:
    raise ValueError("Can not provide both checkpoint_dir and "
                     "checkpoint_filename_with_path.")
  # If either saver or checkpoint_* is not specified, cannot restore. Just
  # return.
  if not saver or not (checkpoint_dir or checkpoint_filename_with_path):
    return sess, False

  if checkpoint_filename_with_path:
    saver.restore(sess, checkpoint_filename_with_path)
    return sess, True

  # Waits up until max_wait_secs for checkpoint to become available.
  wait_time = 0
  ckpt = checkpoint_management.get_checkpoint_state(checkpoint_dir)
  while not ckpt or not ckpt.model_checkpoint_path:
    if wait_for_checkpoint and wait_time < max_wait_secs:
      logging.info("Waiting for checkpoint to be available.")
      time.sleep(self._recovery_wait_secs)
      wait_time += self._recovery_wait_secs
      ckpt = checkpoint_management.get_checkpoint_state(checkpoint_dir)
    else:
      return sess, False

  # Loads the checkpoint.
  saver.restore(sess, ckpt.model_checkpoint_path)
  saver.recover_last_checkpoints(ckpt.all_model_checkpoint_paths)
  return sess, True
def compute_gradients(optimizer, loss, var_list=None,
                      gate_gradients=Optimizer.GATE_OP,
                      aggregation_method=None,
                      colocate_gradients_with_ops=False,
                      grad_loss=None):
  if callable(loss):
    from tensorflow.python.eager import backprop
    with backprop.GradientTape() as tape:
      if var_list is not None:
        tape.watch(var_list)
      loss_value = loss()

      # Scale loss if using a "mean" loss reduction and multiple towers.
      # Have to be careful to call distribute_lib.get_loss_reduction()
      # *after* loss() is evaluated, so we know what loss reduction it uses.
      # TODO(josh11b): Test that we handle weight decay in a reasonable way.
      if (distribute_lib.get_loss_reduction() ==
          variable_scope.VariableAggregation.MEAN):
        num_towers = distribution_strategy_context.get_distribution_strategy(
        ).num_towers
        if num_towers > 1:
          loss_value *= (1. / num_towers)

    if var_list is None:
      var_list = tape.watched_variables()
    # TODO(jhseu): Figure out why GradientTape's gradients don't require
    # loss to be executed.
    with ops.control_dependencies([loss_value]):
      grads = tape.gradient(loss_value, var_list, grad_loss)
    return list(zip(grads, var_list))

  # Non-callable/Tensor loss case
  if context.executing_eagerly():
    raise RuntimeError(
        "`loss` passed to Optimizer.compute_gradients should "
        "be a function when eager execution is enabled.")

  # Scale loss if using a "mean" loss reduction and multiple towers.
  if (distribute_lib.get_loss_reduction() ==
      variable_scope.VariableAggregation.MEAN):
    num_towers = distribution_strategy_context.get_distribution_strategy(
    ).num_towers
    if num_towers > 1:
      loss *= (1. / num_towers)

  if gate_gradients not in [
      Optimizer.GATE_NONE, Optimizer.GATE_OP, Optimizer.GATE_GRAPH
  ]:
    raise ValueError(
        "gate_gradients must be one of: Optimizer.GATE_NONE, "
        "Optimizer.GATE_OP, Optimizer.GATE_GRAPH. Not %s" % gate_gradients)
  optimizer._assert_valid_dtypes([loss])
  if grad_loss is not None:
    optimizer._assert_valid_dtypes([grad_loss])
  if var_list is None:
    var_list = (variables.trainable_variables() + ops.get_collection(
        ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
  else:
    var_list = nest.flatten(var_list)
  # pylint: disable=protected-access
  var_list += ops.get_collection(ops.GraphKeys._STREAMING_MODEL_PORTS)
  # pylint: enable=protected-access
  from tensorflow.python.training.optimizer import _get_processor
  processors = [_get_processor(v) for v in var_list]
  if not var_list:
    raise ValueError("No variables to optimize.")
  var_refs = [p.target() for p in processors]

  # original gradients computation
  # grads = tf.gradients(
  #     loss, var_refs, grad_ys=grad_loss,
  #     gate_gradients=(gate_gradients == Optimizer.GATE_OP),
  #     aggregation_method=aggregation_method,
  #     colocate_gradients_with_ops=colocate_gradients_with_ops)

  # using gradient check-pointing
  from memory_saving_gradients import gradients
  # setting outputs of different networks; note this call referenced an
  # undefined `self` in this free function and its result was never passed
  # to gradients() below, so it is left disabled:
  # tensors_to_checkpoint = self.get_tensors_to_checkpoint()
  # just specifying memory as parameter fails
  grads = gradients(
      loss, var_refs,
      grad_ys=grad_loss,
      gate_gradients=(gate_gradients == Optimizer.GATE_OP),
      aggregation_method=aggregation_method,
      colocate_gradients_with_ops=colocate_gradients_with_ops,
      checkpoints='speed')
  if gate_gradients == Optimizer.GATE_GRAPH:
    grads = control_flow_ops.tuple(grads)
  grads_and_vars = list(zip(grads, var_list))
  optimizer._assert_valid_dtypes([
      v for g, v in grads_and_vars
      if g is not None and v.dtype != dtypes.resource
  ])
  return grads_and_vars
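# A minimal usage sketch for memory_saving_gradients as a drop-in for
# tf.gradients, assuming the openai/gradient-checkpointing package is on the
# path; `loss` and `params` are placeholders supplied by the caller.
import memory_saving_gradients

def checkpointed_grads(loss, params):
  # 'speed' checkpoints cheap-to-recompute ops; a list of tensors can be
  # passed instead to checkpoint specific activations.
  return memory_saving_gradients.gradients(loss, params, checkpoints='speed')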