def _create_non_slot_variable(self, initial_value, name, colocate_with):
  """Return the non-slot variable called `name`, creating it on first use.

  Variables are cached in `self._non_slot_dict` under `(name, graph)`, where
  `graph` is `colocate_with.graph` when building a graph and `None` when
  executing eagerly.

  Args:
    initial_value: Initial value for a newly created variable. When executing
      eagerly it is replaced by the value returned from
      `self._preload_simple_restoration` if that value is not `None`.
    name: Name of the variable; also part of the cache key.
    colocate_with: Object passed to the distribution strategy's
      `colocate_vars_with`; its `.graph` is read when not executing eagerly.

  Returns:
    The cached or newly created variable.
  """
  # Recommendation: Use OptimizerV2 if your optimizer uses non-slot variables.
  in_eager_mode = context.executing_eagerly()
  owning_graph = None if in_eager_mode else colocate_with.graph
  cache_key = (name, owning_graph)

  cached = self._non_slot_dict.get(cache_key, None)
  if cached is not None:
    return cached

  self._maybe_initialize_checkpointable()
  strategy = distribute_lib.get_distribution_strategy()
  with strategy.colocate_vars_with(colocate_with):
    if in_eager_mode:
      restored = self._preload_simple_restoration(name=name, shape=None)
      if restored is not None:
        initial_value = restored
    created = variable_scope.variable(
        initial_value, name=name, trainable=False)

  # Restore this variable by name if necessary, but don't add a
  # Checkpointable dependency. Optimizers return the current graph's
  # non-slot variables from _checkpoint_dependencies explicitly rather
  # than unconditionally adding dependencies (since there may be multiple
  # non-slot variables with the same name in different graphs, trying to
  # save all of them would result in errors).
  self._handle_deferred_dependencies(name=name, checkpointable=created)
  self._non_slot_dict[cache_key] = created
  return created
def _create_non_slot_variable(self, initial_value, name, colocate_with):
  """Look up, or lazily build, an extra variable not tied to any slot.

  The cache `self._non_slot_dict` is keyed by `(name, graph)`; `graph` is
  `None` under eager execution and `colocate_with.graph` otherwise.

  Args:
    initial_value: Value used to initialize a new variable, unless eager
      execution is active and `self._preload_simple_restoration` returns a
      non-`None` replacement.
    name: Variable name, also used in the cache key.
    colocate_with: Passed to the current distribution strategy's
      `colocate_vars_with`.

  Returns:
    The existing variable for the key, or the one just created.
  """
  # Recommendation: Use OptimizerV2 if your optimizer uses non-slot variables.
  is_eager = context.executing_eagerly()
  graph_for_key = None if is_eager else colocate_with.graph
  lookup_key = (name, graph_for_key)

  existing = self._non_slot_dict.get(lookup_key, None)
  if existing is not None:
    return existing

  self._maybe_initialize_checkpointable()
  current_strategy = distribute_lib.get_distribution_strategy()
  with current_strategy.colocate_vars_with(colocate_with):
    if is_eager:
      preloaded = self._preload_simple_restoration(name=name, shape=None)
      if preloaded is not None:
        initial_value = preloaded
    fresh_var = variable_scope.variable(
        initial_value, name=name, trainable=False)

  # Restore this variable by name if necessary, but don't add a
  # Checkpointable dependency. Optimizers return the current graph's
  # non-slot variables from _checkpoint_dependencies explicitly rather
  # than unconditionally adding dependencies (since there may be multiple
  # non-slot variables with the same name in different graphs, trying to
  # save all of them would result in errors).
  self._handle_deferred_dependencies(name=name, checkpointable=fresh_var)
  self._non_slot_dict[lookup_key] = fresh_var
  return fresh_var
def _assert_in_default_state(t):
  """Assert through test case `t` that `distribute` reports its defaults.

  Checks that the tower context and distribution strategy are the module's
  default objects, that there is no cross-tower context, and that
  `has_distribution_strategy()` is false.

  Args:
    t: Test case object providing `assertIs` and `assertFalse`.
  """
  # pylint: disable=protected-access
  default_tower_context = distribute._default_tower_context
  t.assertIs(default_tower_context, distribute.get_tower_context())

  t.assertIs(None, distribute.get_cross_tower_context())

  default_strategy = distribute._default_distribution_strategy
  t.assertIs(default_strategy, distribute.get_distribution_strategy())
  # pylint: enable=protected-access

  strategy_active = distribute.has_distribution_strategy()
  t.assertFalse(strategy_active)
def merge_fn(dist, s):
  """Check the context seen inside a merge call, then return "foo_" + `s`.

  Asserts that `dist` is the default distribution strategy, that there is no
  tower context, and that `dist` is both the cross-tower context and the
  current strategy while `has_distribution_strategy()` is still false.
  """
  # pylint: disable=protected-access
  default_strategy = distribute._default_distribution_strategy
  # pylint: enable=protected-access
  self.assertIs(default_strategy, dist)
  self.assertIs(None, distribute.get_tower_context())
  self.assertIs(dist, distribute.get_cross_tower_context())
  self.assertIs(dist, distribute.get_distribution_strategy())
  self.assertFalse(distribute.has_distribution_strategy())
  merged = "foo_" + s
  return merged
def _assert_in_default_state(t):
  """Verify, using test case `t`, that the `distribute` module is at defaults.

  The default tower context and default distribution strategy must be the
  ones currently returned, no cross-tower context may be set, and
  `has_distribution_strategy()` must be false.

  Args:
    t: Test case object whose `assertIs` / `assertFalse` are used.
  """
  # pylint: disable=protected-access
  expected_tower_context = distribute._default_tower_context
  t.assertIs(expected_tower_context, distribute.get_tower_context())

  t.assertIs(None, distribute.get_cross_tower_context())

  expected_strategy = distribute._default_distribution_strategy
  t.assertIs(expected_strategy, distribute.get_distribution_strategy())
  # pylint: enable=protected-access

  has_strategy = distribute.has_distribution_strategy()
  t.assertFalse(has_strategy)
def create_slot(primary, val, name, colocate_with_primary=True):
  """Create a slot initialized to the given value.

  The type of the slot is determined by the given value.

  Args:
    primary: The primary `Variable` or `Tensor`.
    val: A `Tensor` specifying the initial value of the slot.
    name: Name to use for the slot variable.
    colocate_with_primary: Boolean. If True the slot is located
      on the same device as `primary`.

  Returns:
    A `Variable` object.
  """
  # Shape validation is only requested when the value's shape is fully known.
  validate_shape = val.get_shape().is_fully_defined()

  # Scope the slot name in the namespace of the primary variable.
  # Set "primary.op.name + '/' + name" as default name, so the scope name of
  # optimizer can be shared when reuse is True. Meanwhile when reuse is False
  # and the same name has been previously used, the scope name will add '_N'
  # as suffix for unique identifications.
  # pylint: disable=protected-access
  prefix = (
      primary._shared_name if context.executing_eagerly()
      else primary.op.name)
  # pylint: enable=protected-access

  with variable_scope.variable_scope(None, prefix + "/" + name):
    if not colocate_with_primary:
      return _create_slot_var(primary, val, "", validate_shape, None, None)

    strategy = distribute_lib.get_distribution_strategy()
    with strategy.colocate_vars_with(primary):
      return _create_slot_var(primary, val, "", validate_shape, None, None)
def merge_fn(dist, s):
  """Assert the expected merge-call context and return `s` prefixed by "foo_".

  `dist` must be the default distribution strategy, the cross-tower context
  and the current strategy; no tower context may be set, and
  `has_distribution_strategy()` must be false.
  """
  # pylint: disable=protected-access
  expected_strategy = distribute._default_distribution_strategy
  # pylint: enable=protected-access
  self.assertIs(expected_strategy, dist)
  self.assertIs(None, distribute.get_tower_context())
  self.assertIs(dist, distribute.get_cross_tower_context())
  self.assertIs(dist, distribute.get_distribution_strategy())
  self.assertFalse(distribute.has_distribution_strategy())
  result = "foo_" + s
  return result
def _assign_func(self, *args, **kwargs):
  """Apply the assign-style function passed as keyword `f` to this variable.

  Args:
    *args: Positional arguments forwarded to `f` (or to `merge_call` in tower
      context).
    **kwargs: Must contain `f`, the function to apply; it is removed before
      the remaining keyword arguments are forwarded.

  Returns:
    In cross-tower context: `f` applied to the component on the update device
    when one is set, otherwise the result of the strategy's `update`. In tower
    context: the result of the tower context's `merge_call`.

  Raises:
    ValueError: In tower context, if `self._aggregation` equals
      `VariableAggregation.NONE`.
  """
  assign_fn = kwargs.pop("f")

  if not distribute_lib.get_cross_tower_context():
    _assert_tower_context()
    # We are calling an assign function on the mirrored variable in tower
    # context.
    # We reduce the value we want to assign/add/sub. More details about how we
    # handle the different use cases can be found in the _reduce method.
    # We call the function on each of the mirrored variables with the reduced
    # value.
    if self._aggregation == vs.VariableAggregation.NONE:
      raise ValueError("You must specify an aggregation method to update a "
                       "MirroredVariable in Tower Context.")

    def merge_fn(strategy, value, *other_args, **other_kwargs):
      reduced_value = strategy.reduce(
          aggregation=self._aggregation, value=value, destinations=self)
      return strategy.update(
          self, assign_fn, reduced_value, *other_args, **other_kwargs)

    tower_context = distribute_lib.get_tower_context()
    return tower_context.merge_call(merge_fn, *args, **kwargs)

  # We are calling update on the mirrored variable in cross tower context.
  target_device = distribute_lib.get_update_device()
  if target_device is None:
    strategy = distribute_lib.get_distribution_strategy()
    return strategy.update(self, assign_fn, *args, **kwargs)

  # We are calling an assign function on the mirrored variable in cross
  # tower context.
  component = self.get(device=target_device)
  return assign_fn(component, *args, **kwargs)
def run_fn():
  """Check the distribution state seen from inside a tower.

  Asserts that a tower context exists, that there is no cross-tower context,
  that `dist` is the active distribution strategy, and checks the values
  returned by `merge_call` and `variable_scope.variable`.
  """
  tower_context = distribute.get_tower_context()
  # assertIsNotNone reports the offending value on failure, unlike
  # assertTrue(x is not None) which only says "False is not true".
  self.assertIsNotNone(tower_context)
  self.assertIs(None, distribute.get_cross_tower_context())
  self.assertTrue(distribute.has_distribution_strategy())
  self.assertIs(dist, distribute.get_distribution_strategy())
  self.assertEqual("foo", tower_context.merge_call(None, test_arg="foo"))
  self.assertEqual("bar", variable_scope.variable(1.0, name="bar"))
def run_fn():
  """Verify the per-tower view of the distribution state.

  A tower context must be present, the cross-tower context must be `None`,
  and `dist` must be the current distribution strategy. Also checks the
  values returned by `merge_call` and `variable_scope.variable`.
  """
  tower_context = distribute.get_tower_context()
  # Use the dedicated assertion so a failure message names the value instead
  # of the uninformative "False is not true".
  self.assertIsNotNone(tower_context)
  self.assertIs(None, distribute.get_cross_tower_context())
  self.assertTrue(distribute.has_distribution_strategy())
  self.assertIs(dist, distribute.get_distribution_strategy())
  self.assertEqual("foo", tower_context.merge_call(None, test_arg="foo"))
  self.assertEqual("bar", variable_scope.variable(1.0, name="bar"))
def testScope(self):
  """Entering a strategy scope installs it; leaving restores the defaults."""
  _assert_in_default_state(self)

  strategy = _TestStrategy()
  with strategy.scope():
    # Inside the scope we are in cross-tower context, not in a tower.
    self.assertIs(None, distribute.get_tower_context())
    self.assertIs(strategy, distribute.get_cross_tower_context())
    self.assertTrue(distribute.has_distribution_strategy())
    self.assertIs(strategy, distribute.get_distribution_strategy())
    created = variable_scope.variable(1.0, name="baz")
    self.assertEqual("baz", created)

  _assert_in_default_state(self)
def testScope(self):
  """Check state before, inside and after a `_TestStrategy` scope."""
  _assert_in_default_state(self)

  test_strategy = _TestStrategy()
  with test_strategy.scope():
    # The scope puts us in cross-tower context with the strategy active.
    self.assertIs(None, distribute.get_tower_context())
    self.assertIs(test_strategy, distribute.get_cross_tower_context())
    self.assertTrue(distribute.has_distribution_strategy())
    self.assertIs(test_strategy, distribute.get_distribution_strategy())
    baz_variable = variable_scope.variable(1.0, name="baz")
    self.assertEqual("baz", baz_variable)

  _assert_in_default_state(self)
def run_fn():
  """Check the distribution state seen from inside a tower.

  Asserts that a tower context exists, that there is no cross-tower context,
  and that `dist` is the active distribution strategy. Also checks the value
  returned by `merge_call` and compares the created variable against the
  dictionary built by `_get_test_variable`.
  """
  tower_context = distribute.get_tower_context()
  # assertIsNotNone reports the offending value on failure, unlike
  # assertTrue(x is not None) which only says "False is not true".
  self.assertIsNotNone(tower_context)
  self.assertIs(None, distribute.get_cross_tower_context())
  self.assertTrue(distribute.has_distribution_strategy())
  self.assertIs(dist, distribute.get_distribution_strategy())
  self.assertEqual("foo", tower_context.merge_call(None, test_arg="foo"))
  expected_value = _get_test_variable(
      "bar", variable_scope.VariableSynchronization.AUTO,
      variable_scope.VariableAggregation.NONE)
  self.assertDictEqual(expected_value,
                       variable_scope.variable(1.0, name="bar"))
def run_fn():
  """Verify the per-tower view of the distribution state.

  A tower context must be present, the cross-tower context must be `None`,
  and `dist` must be the current distribution strategy. The variable created
  here is compared against the dictionary built by `_get_test_variable`.
  """
  tower_context = distribute.get_tower_context()
  # Use the dedicated assertion so a failure message names the value instead
  # of the uninformative "False is not true".
  self.assertIsNotNone(tower_context)
  self.assertIs(None, distribute.get_cross_tower_context())
  self.assertTrue(distribute.has_distribution_strategy())
  self.assertIs(dist, distribute.get_distribution_strategy())
  self.assertEqual("foo", tower_context.merge_call(None, test_arg="foo"))
  expected_value = _get_test_variable(
      "bar", variable_scope.VariableSynchronization.AUTO,
      variable_scope.VariableAggregation.NONE)
  self.assertDictEqual(expected_value,
                       variable_scope.variable(1.0, name="bar"))
def testScope(self):
  """Entering a strategy scope installs it; leaving restores the defaults."""
  _assert_in_default_state(self)

  strategy = _TestStrategy()
  with strategy.scope():
    # Inside the scope we are in cross-tower context, not in a tower.
    self.assertIs(None, distribute.get_tower_context())
    self.assertIs(strategy, distribute.get_cross_tower_context())
    self.assertTrue(distribute.has_distribution_strategy())
    self.assertIs(strategy, distribute.get_distribution_strategy())

    sync_mode = variable_scope.VariableSynchronization.AUTO
    aggregation_mode = variable_scope.VariableAggregation.NONE
    expected = _get_test_variable("baz", sync_mode, aggregation_mode)
    created = variable_scope.variable(1.0, name="baz")
    self.assertDictEqual(expected, created)

  _assert_in_default_state(self)
def testScope(self):
  """Check state before, inside and after a `_TestStrategy` scope."""
  _assert_in_default_state(self)

  test_strategy = _TestStrategy()
  with test_strategy.scope():
    # The scope puts us in cross-tower context with the strategy active.
    self.assertIs(None, distribute.get_tower_context())
    self.assertIs(test_strategy, distribute.get_cross_tower_context())
    self.assertTrue(distribute.has_distribution_strategy())
    self.assertIs(test_strategy, distribute.get_distribution_strategy())

    synchronization = variable_scope.VariableSynchronization.AUTO
    aggregation = variable_scope.VariableAggregation.NONE
    expected_dict = _get_test_variable("baz", synchronization, aggregation)
    baz_variable = variable_scope.variable(1.0, name="baz")
    self.assertDictEqual(expected_dict, baz_variable)

  _assert_in_default_state(self)
def set_last_step_output(self, name, output,
                         aggregation=variables_lib.VariableAggregation.NONE):
  """Set `output` with `name` to be outputted from the last step.

  Args:
    name: String, name to identify the output. Doesn't need to match tensor
      name.
    output: The tensors that should be outputted with `name`. See below for
      actual types supported.
    aggregation: Aggregation method to use to aggregate outputs from multiple
      towers. Required if `set_last_step_output` is called in a tower context.
      Optional in cross_tower_context.
      When present, the outputs from all the towers are aggregated using the
      current distribution strategy's `reduce` method. Hence, the type of
      `output` must be what's supported by the corresponding `reduce` method.
      For e.g. if using MirroredStrategy and aggregation is set, output
      must be a `PerDevice` value.
      The aggregation method is also recorded in a dictionary
      `_last_step_outputs_aggregations` for later interpreting of the
      outputs as already reduced or not.

  Raises:
    AssertionError: If called in a tower context with `aggregation` left as
      `VariableAggregation.NONE`.
  """
  if distribute_lib.get_cross_tower_context():
    self._last_step_outputs_aggregations[name] = aggregation
    if aggregation is variables_lib.VariableAggregation.NONE:
      # Nothing to reduce: store the output as given.
      self._last_step_outputs[name] = output
    else:
      distribution = distribute_lib.get_distribution_strategy()
      self._last_step_outputs[name] = distribution.reduce(
          aggregation, output, destinations="/device:CPU:0")
  else:
    # Raise explicitly instead of using an `assert` statement so the check
    # still runs under `python -O`; the exception type is unchanged.
    if aggregation is variables_lib.VariableAggregation.NONE:
      raise AssertionError(
          "`aggregation` must not be VariableAggregation.NONE when "
          "`set_last_step_output` is called in a tower context.")

    def merge_fn(distribution, value):
      self._last_step_outputs[name] = distribution.reduce(
          aggregation, value, destinations="/device:CPU:0")
      # Setting this inside the `merge_fn` because all towers share the same
      # context object, so it's more robust to set it only once (even if all
      # the towers are trying to set the same value).
      self._last_step_outputs_aggregations[name] = aggregation

    distribute_lib.get_tower_context().merge_call(merge_fn, output)
def _set_checkpoint_initializer(variable,
                                ckpt_file,
                                tensor_name,
                                slice_spec,
                                name="checkpoint_initializer"):
  """Overrides given variable's initialization op.

  Sets variable initializer to assign op that initializes variable from tensor's
  value in the checkpoint.

  The function returns nothing; it mutates `variable` by overwriting its
  private `_initializer_op` and `_initial_value` attributes.

  Args:
    variable: `tf.Variable` object.
    ckpt_file: string, full path of the checkpoint.
    tensor_name: Name of the tensor to load from the checkpoint.
    slice_spec: Slice specification for loading partitioned tensors.
    name: Name of the operation.
  """
  # The restore op is requested with the variable's base dtype.
  base_type = variable.dtype.base_dtype
  # Do not colocate with variable since RestoreV2 op only runs on CPU and
  # colocation will force variable (and other ops that colocate with variable)
  # to be on CPU as well. It is okay to place the variable's initializer op on
  # CPU since it will only be run once at the start.
  with ops.device(variable.device), ops.device("/cpu:0"):
    # restore_v2 is called with one-element lists; [0] picks the single
    # restored tensor.
    restore_op = io_ops.restore_v2(
        ckpt_file, [tensor_name], [slice_spec], [base_type], name=name)[0]
    # TODO(priyag, allenl): Use `SaveableObject.restore` instead here.
    if resource_variable_ops.is_resource_variable(variable):
      init_op = variable.assign(restore_op, read_value=False)
      # TODO(priyag): Remove this when using `SaveableObject.restore` instead.
      # NOTE(review): presumably `_index` marks a per-device (distributed)
      # result that must be grouped into a single op -- confirm.
      if hasattr(init_op, "_index"):
        init_op = distribute_lib.get_distribution_strategy().group(init_op)
    else:
      init_op = state_ops.assign(variable, restore_op)
    # pylint:disable=protected-access
    variable._initializer_op = init_op
    # Give the restored tensor the variable's static shape before exposing it
    # as the variable's initial value.
    restore_op.set_shape(variable.shape)
    variable._initial_value = restore_op
def read_value(self):
  """Return the current distribution strategy's `read_var` result for `self`."""
  strategy = distribute_lib.get_distribution_strategy()
  return strategy.read_var(self)
def build(self, input_shape):
  """Validate the configuration and create the layer's variables.

  Normalizes `self.axis` to a list of non-negative indices, decides whether
  the fused implementation can be used, and creates `gamma`, `beta`,
  `moving_mean`, `moving_variance` and (when `self.renorm` is set) the renorm
  variables. Sets `self.built = True` at the end.

  Args:
    input_shape: Shape of the input; converted with
      `tensor_shape.TensorShape`.

  Raises:
    ValueError: If the input rank is undefined, an axis is out of range or
      duplicated, `virtual_batch_size` is not positive or is combined with
      axis 0 or with `adjustment`, the fused path is selected with an axis
      other than `[1]` or `[3]`, or an `axis` dimension of the input is
      undefined.
    TypeError: If `self.axis` is neither an int nor a list.
  """
  input_shape = tensor_shape.TensorShape(input_shape)
  if not input_shape.ndims:
    raise ValueError('Input has undefined rank:', input_shape)
  ndims = len(input_shape)

  # Convert axis to list and resolve negatives
  if isinstance(self.axis, int):
    self.axis = [self.axis]

  if not isinstance(self.axis, list):
    raise TypeError('axis must be int or list, type given: %s'
                    % type(self.axis))

  for idx, x in enumerate(self.axis):
    if x < 0:
      self.axis[idx] = ndims + x

  # Validate axes
  for x in self.axis:
    if x < 0 or x >= ndims:
      raise ValueError('Invalid axis: %d' % x)
  if len(self.axis) != len(set(self.axis)):
    raise ValueError('Duplicate axis: %s' % self.axis)

  if self.virtual_batch_size is not None:
    if self.virtual_batch_size <= 0:
      raise ValueError('virtual_batch_size must be a positive integer that '
                       'divides the true batch size of the input Tensor')
    # If using virtual batches, the first dimension must be the batch
    # dimension and cannot be the batch norm axis
    if 0 in self.axis:
      raise ValueError('When using virtual_batch_size, the batch dimension '
                       'must be 0 and thus axis cannot include 0')
    if self.adjustment is not None:
      raise ValueError('When using virtual_batch_size, adjustment cannot '
                       'be specified')

  if self.fused:
    # Currently fused batch norm doesn't support renorm. It also only supports
    # an input tensor of rank 4 and a channel dimension on axis 1 or 3.
    # TODO(yaozhang): if input is not 4D, reshape it to 4D and reshape the
    # output back to its original shape accordingly.
    self.fused = (not self.renorm and
                  ndims == 4 and
                  self.axis in [[1], [3]] and
                  self.virtual_batch_size is None and
                  self.adjustment is None)
    # TODO(chrisying): fused batch norm is currently not supported for
    # multi-axis batch norm and by extension virtual batches. In some cases,
    # it might be possible to use fused batch norm but would require reshaping
    # the Tensor to 4D with the axis in 1 or 3 (preferred 1) which is
    # particularly tricky. A compromise might be to just support the most
    # common use case (turning 5D w/ virtual batch to NCHW)

  # Re-checked on purpose: the block above may have switched `fused` off.
  if self.fused:
    if self.axis == [1]:
      self._data_format = 'NCHW'
    elif self.axis == [3]:
      self._data_format = 'NHWC'
    else:
      raise ValueError('Unsupported axis, fused batch norm only supports '
                       'axis == [1] or axis == [3]')

  # Raise parameters of fp16 batch norm to fp32
  if self.dtype == dtypes.float16 or self.dtype == dtypes.bfloat16:
    param_dtype = dtypes.float32
  else:
    param_dtype = self.dtype or dtypes.float32

  # Map each normalized axis to its static size; every size must be known.
  axis_to_dim = {x: input_shape[x].value for x in self.axis}
  for x in axis_to_dim:
    if axis_to_dim[x] is None:
      raise ValueError('Input has undefined `axis` dimension. Input shape: ',
                       input_shape)
  self.input_spec = InputSpec(ndim=ndims, axes=axis_to_dim)

  if len(axis_to_dim) == 1 and self.virtual_batch_size is None:
    # Single axis batch norm (most common/default use-case)
    param_shape = (list(axis_to_dim.values())[0],)
  else:
    # Parameter shape is the original shape but with 1 in all non-axis dims
    param_shape = [axis_to_dim[i] if i in axis_to_dim
                   else 1 for i in range(ndims)]
    if self.virtual_batch_size is not None:
      # When using virtual batches, add an extra dim at index 1
      param_shape.insert(1, 1)
      for idx, x in enumerate(self.axis):
        self.axis[idx] = x + 1  # Account for added dimension

  if self.scale:
    self.gamma = self.add_weight(
        name='gamma',
        shape=param_shape,
        dtype=param_dtype,
        initializer=self.gamma_initializer,
        regularizer=self.gamma_regularizer,
        constraint=self.gamma_constraint,
        trainable=True)
  else:
    self.gamma = None
    # Without a learned scale the fused path gets a constant of ones.
    if self.fused:
      self._gamma_const = array_ops.constant(
          1.0, dtype=param_dtype, shape=param_shape)

  if self.center:
    self.beta = self.add_weight(
        name='beta',
        shape=param_shape,
        dtype=param_dtype,
        initializer=self.beta_initializer,
        regularizer=self.beta_regularizer,
        constraint=self.beta_constraint,
        trainable=True)
  else:
    self.beta = None
    # Without a learned offset the fused path gets a constant of zeros.
    if self.fused:
      self._beta_const = array_ops.constant(
          0.0, dtype=param_dtype, shape=param_shape)

  try:
    # Disable variable partitioning when creating the moving mean and variance
    if hasattr(self, '_scope') and self._scope:
      partitioner = self._scope.partitioner
      self._scope.set_partitioner(None)
    else:
      partitioner = None
    self.moving_mean = self._add_tower_local_variable(
        name='moving_mean',
        shape=param_shape,
        dtype=param_dtype,
        initializer=self.moving_mean_initializer,
        trainable=False)

    self.moving_variance = self._add_tower_local_variable(
        name='moving_variance',
        shape=param_shape,
        dtype=param_dtype,
        initializer=self.moving_variance_initializer,
        trainable=False)

    if self.renorm:
      # Create variables to maintain the moving mean and standard deviation.
      # These are used in training and thus are different from the moving
      # averages above. The renorm variables are colocated with moving_mean
      # and moving_variance.
      # NOTE: below, the outer `with device` block causes the current device
      # stack to be cleared. The nested ones use a `lambda` to set the desired
      # device and ignore any devices that may be set by the custom getter.
      # NOTE(review): no `with device` block or `lambda` appears in the code
      # below -- the NOTE above looks stale; confirm and remove.
      def _renorm_variable(name, shape):
        var = self._add_tower_local_variable(
            name=name,
            shape=shape,
            dtype=param_dtype,
            initializer=init_ops.zeros_initializer(),
            trainable=False)
        return var

      with distribute_lib.get_distribution_strategy().colocate_vars_with(
          self.moving_mean):
        self.renorm_mean = _renorm_variable('renorm_mean', param_shape)
        self.renorm_mean_weight = _renorm_variable('renorm_mean_weight', ())
      # We initialize renorm_stddev to 0, and maintain the (0-initialized)
      # renorm_stddev_weight. This allows us to (1) mix the average
      # stddev with the minibatch stddev early in training, and (2) compute
      # the unbiased average stddev by dividing renorm_stddev by the weight.
      with distribute_lib.get_distribution_strategy().colocate_vars_with(
          self.moving_variance):
        self.renorm_stddev = _renorm_variable('renorm_stddev', param_shape)
        self.renorm_stddev_weight = _renorm_variable('renorm_stddev_weight',
                                                     ())
  finally:
    # Restore the partitioner that was cleared above, even if variable
    # creation raised.
    if partitioner:
      self._scope.set_partitioner(partitioner)
  self.built = True
def tensor():
  """Return the current distribution strategy's `fetch` of the variable."""
  strategy = distribute_lib.get_distribution_strategy()
  return strategy.fetch(tower_local_variable)
def compute_gradients(self, loss, var_list=None,
                      gate_gradients=GATE_OP,
                      aggregation_method=None,
                      colocate_gradients_with_ops=False,
                      grad_loss=None):
  """Compute gradients of `loss` for the variables in `var_list`.

  This is the first part of `minimize()`.  It returns a list
  of (gradient, variable) pairs where "gradient" is the gradient
  for "variable".  Note that "gradient" can be a `Tensor`, an
  `IndexedSlices`, or `None` if there is no gradient for the
  given variable.

  When the loss reduction reported by `distribute_lib.get_loss_reduction()`
  is "mean" and the current distribution strategy has more than one tower,
  the loss is multiplied by `1 / num_towers` before differentiation.

  Args:
    loss: A Tensor containing the value to minimize or a callable taking
      no arguments which returns the value to minimize. When eager execution
      is enabled it must be a callable.
    var_list: Optional list or tuple of `tf.Variable` to update to minimize
      `loss`.  Defaults to the list of variables collected in the graph
      under the key `GraphKeys.TRAINABLE_VARIABLES`.
    gate_gradients: How to gate the computation of gradients.  Can be
      `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
    aggregation_method: Specifies the method used to combine gradient terms.
      Valid values are defined in the class `AggregationMethod`.
    colocate_gradients_with_ops: If True, try colocating gradients with
      the corresponding op.
    grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.

  Returns:
    A list of (gradient, variable) pairs. Variable is always present, but
    gradient can be `None`.

  Raises:
    TypeError: If `var_list` contains anything else than `Variable` objects.
    ValueError: If some arguments are invalid.
    RuntimeError: If called with eager execution enabled and `loss` is
      not callable.

  @compatibility(eager)
  When eager execution is enabled, `gate_gradients`, `aggregation_method`,
  and `colocate_gradients_with_ops` are ignored.
  @end_compatibility
  """
  if callable(loss):
    # Callable loss: record the computation on a tape and differentiate it.
    with backprop.GradientTape() as tape:
      if var_list is not None:
        tape.watch(var_list)
      loss_value = loss()

      # Scale loss if using a "mean" loss reduction and multiple towers.
      # Have to be careful to call distribute_lib.get_loss_reduction()
      # *after* loss() is evaluated, so we know what loss reduction it uses.
      # TODO(josh11b): Test that we handle weight decay in a reasonable way.
      if distribute_lib.get_loss_reduction() == "mean":
        num_towers = distribute_lib.get_distribution_strategy().num_towers
        if num_towers > 1:
          loss_value *= (1. / num_towers)

    # With no explicit var_list, differentiate w.r.t. whatever the tape
    # watched while `loss()` ran.
    if var_list is None:
      var_list = tape.watched_variables()
    grads = tape.gradient(loss_value, var_list, grad_loss)
    return list(zip(grads, var_list))

  # Non-callable/Tensor loss case
  if context.executing_eagerly():
    raise RuntimeError(
        "`loss` passed to Optimizer.compute_gradients should "
        "be a function when eager execution is enabled.")

  # Scale loss if using a "mean" loss reduction and multiple towers.
  if distribute_lib.get_loss_reduction() == "mean":
    num_towers = distribute_lib.get_distribution_strategy().num_towers
    if num_towers > 1:
      loss *= (1. / num_towers)

  if gate_gradients not in [Optimizer.GATE_NONE, Optimizer.GATE_OP,
                            Optimizer.GATE_GRAPH]:
    raise ValueError("gate_gradients must be one of: Optimizer.GATE_NONE, "
                     "Optimizer.GATE_OP, Optimizer.GATE_GRAPH. Not %s" %
                     gate_gradients)
  self._assert_valid_dtypes([loss])
  if grad_loss is not None:
    self._assert_valid_dtypes([grad_loss])
  if var_list is None:
    var_list = (
        variables.trainable_variables() +
        ops.get_collection(ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
  else:
    var_list = nest.flatten(var_list)
  # pylint: disable=protected-access
  var_list += ops.get_collection(ops.GraphKeys._STREAMING_MODEL_PORTS)
  # pylint: enable=protected-access
  # Processors are built before the emptiness check; for an empty list this
  # is simply an empty comprehension.
  processors = [_get_processor(v) for v in var_list]
  if not var_list:
    raise ValueError("No variables to optimize.")
  var_refs = [p.target() for p in processors]
  grads = gradients.gradients(
      loss, var_refs, grad_ys=grad_loss,
      gate_gradients=(gate_gradients == Optimizer.GATE_OP),
      aggregation_method=aggregation_method,
      colocate_gradients_with_ops=colocate_gradients_with_ops)
  if gate_gradients == Optimizer.GATE_GRAPH:
    grads = control_flow_ops.tuple(grads)
  grads_and_vars = list(zip(grads, var_list))
  # Only variables that received a gradient and are not resource-typed are
  # dtype-checked.
  self._assert_valid_dtypes(
      [v for g, v in grads_and_vars
       if g is not None and v.dtype != dtypes.resource])
  return grads_and_vars
def build(self, input_shape):
  """Check the layer configuration and create its variables.

  Turns `self.axis` into a list of non-negative indices, determines whether
  the fused implementation applies, and creates `gamma`, `beta`,
  `moving_mean`, `moving_variance` and, if `self.renorm` is set, the renorm
  variables. `self.built` is set to True on completion.

  Args:
    input_shape: Shape of the input; converted with
      `tensor_shape.TensorShape`.

  Raises:
    ValueError: If the input rank is undefined, an axis is out of range or
      duplicated, `virtual_batch_size` is not positive or is combined with
      axis 0 or with `adjustment`, the fused path is selected with an axis
      other than `[1]` or `[3]`, or an `axis` dimension of the input is
      undefined.
    TypeError: If `self.axis` is neither an int nor a list.
  """
  input_shape = tensor_shape.TensorShape(input_shape)
  if not input_shape.ndims:
    raise ValueError('Input has undefined rank:', input_shape)
  ndims = len(input_shape)

  # Convert axis to list and resolve negatives
  if isinstance(self.axis, int):
    self.axis = [self.axis]

  if not isinstance(self.axis, list):
    raise TypeError('axis must be int or list, type given: %s' %
                    type(self.axis))

  for idx, x in enumerate(self.axis):
    if x < 0:
      self.axis[idx] = ndims + x

  # Validate axes
  for x in self.axis:
    if x < 0 or x >= ndims:
      raise ValueError('Invalid axis: %d' % x)
  if len(self.axis) != len(set(self.axis)):
    raise ValueError('Duplicate axis: %s' % self.axis)

  if self.virtual_batch_size is not None:
    if self.virtual_batch_size <= 0:
      raise ValueError(
          'virtual_batch_size must be a positive integer that '
          'divides the true batch size of the input Tensor')
    # If using virtual batches, the first dimension must be the batch
    # dimension and cannot be the batch norm axis
    if 0 in self.axis:
      raise ValueError(
          'When using virtual_batch_size, the batch dimension '
          'must be 0 and thus axis cannot include 0')
    if self.adjustment is not None:
      raise ValueError(
          'When using virtual_batch_size, adjustment cannot '
          'be specified')

  if self.fused:
    # Currently fused batch norm doesn't support renorm. It also only supports
    # an input tensor of rank 4 and a channel dimension on axis 1 or 3.
    # TODO(yaozhang): if input is not 4D, reshape it to 4D and reshape the
    # output back to its original shape accordingly.
    self.fused = (not self.renorm and ndims == 4 and
                  self.axis in [[1], [3]] and
                  self.virtual_batch_size is None and
                  self.adjustment is None)
    # TODO(chrisying): fused batch norm is currently not supported for
    # multi-axis batch norm and by extension virtual batches. In some cases,
    # it might be possible to use fused batch norm but would require reshaping
    # the Tensor to 4D with the axis in 1 or 3 (preferred 1) which is
    # particularly tricky. A compromise might be to just support the most
    # common use case (turning 5D w/ virtual batch to NCHW)

  # Checked a second time because the block above may have cleared `fused`.
  if self.fused:
    if self.axis == [1]:
      self._data_format = 'NCHW'
    elif self.axis == [3]:
      self._data_format = 'NHWC'
    else:
      raise ValueError(
          'Unsupported axis, fused batch norm only supports '
          'axis == [1] or axis == [3]')

  # Raise parameters of fp16 batch norm to fp32
  if self.dtype == dtypes.float16 or self.dtype == dtypes.bfloat16:
    param_dtype = dtypes.float32
  else:
    param_dtype = self.dtype or dtypes.float32

  # Static size of every normalized axis; all of them must be defined.
  axis_to_dim = {x: input_shape[x].value for x in self.axis}
  for x in axis_to_dim:
    if axis_to_dim[x] is None:
      raise ValueError(
          'Input has undefined `axis` dimension. Input shape: ',
          input_shape)
  self.input_spec = InputSpec(ndim=ndims, axes=axis_to_dim)

  if len(axis_to_dim) == 1 and self.virtual_batch_size is None:
    # Single axis batch norm (most common/default use-case)
    param_shape = (list(axis_to_dim.values())[0], )
  else:
    # Parameter shape is the original shape but with 1 in all non-axis dims
    param_shape = [
        axis_to_dim[i] if i in axis_to_dim else 1 for i in range(ndims)
    ]
    if self.virtual_batch_size is not None:
      # When using virtual batches, add an extra dim at index 1
      param_shape.insert(1, 1)
      for idx, x in enumerate(self.axis):
        self.axis[idx] = x + 1  # Account for added dimension

  if self.scale:
    self.gamma = self.add_weight(name='gamma',
                                 shape=param_shape,
                                 dtype=param_dtype,
                                 initializer=self.gamma_initializer,
                                 regularizer=self.gamma_regularizer,
                                 constraint=self.gamma_constraint,
                                 trainable=True)
  else:
    self.gamma = None
    # With no learned scale, the fused path is given a constant of ones.
    if self.fused:
      self._gamma_const = array_ops.constant(1.0,
                                             dtype=param_dtype,
                                             shape=param_shape)

  if self.center:
    self.beta = self.add_weight(name='beta',
                                shape=param_shape,
                                dtype=param_dtype,
                                initializer=self.beta_initializer,
                                regularizer=self.beta_regularizer,
                                constraint=self.beta_constraint,
                                trainable=True)
  else:
    self.beta = None
    # With no learned offset, the fused path is given a constant of zeros.
    if self.fused:
      self._beta_const = array_ops.constant(0.0,
                                            dtype=param_dtype,
                                            shape=param_shape)

  try:
    # Disable variable partitioning when creating the moving mean and variance
    if hasattr(self, '_scope') and self._scope:
      partitioner = self._scope.partitioner
      self._scope.set_partitioner(None)
    else:
      partitioner = None
    self.moving_mean = self._add_tower_local_variable(
        name='moving_mean',
        shape=param_shape,
        dtype=param_dtype,
        initializer=self.moving_mean_initializer,
        trainable=False)

    self.moving_variance = self._add_tower_local_variable(
        name='moving_variance',
        shape=param_shape,
        dtype=param_dtype,
        initializer=self.moving_variance_initializer,
        trainable=False)

    if self.renorm:
      # Create variables to maintain the moving mean and standard deviation.
      # These are used in training and thus are different from the moving
      # averages above. The renorm variables are colocated with moving_mean
      # and moving_variance.
      # NOTE: below, the outer `with device` block causes the current device
      # stack to be cleared. The nested ones use a `lambda` to set the desired
      # device and ignore any devices that may be set by the custom getter.
      # NOTE(review): the code below contains neither a `with device` block
      # nor a `lambda` -- the NOTE above appears outdated; confirm.
      def _renorm_variable(name, shape):
        var = self._add_tower_local_variable(
            name=name,
            shape=shape,
            dtype=param_dtype,
            initializer=init_ops.zeros_initializer(),
            trainable=False)
        return var

      with distribute_lib.get_distribution_strategy(
      ).colocate_vars_with(self.moving_mean):
        self.renorm_mean = _renorm_variable(
            'renorm_mean', param_shape)
        self.renorm_mean_weight = _renorm_variable(
            'renorm_mean_weight', ())
      # We initialize renorm_stddev to 0, and maintain the (0-initialized)
      # renorm_stddev_weight. This allows us to (1) mix the average
      # stddev with the minibatch stddev early in training, and (2) compute
      # the unbiased average stddev by dividing renorm_stddev by the weight.
      with distribute_lib.get_distribution_strategy(
      ).colocate_vars_with(self.moving_variance):
        self.renorm_stddev = _renorm_variable(
            'renorm_stddev', param_shape)
        self.renorm_stddev_weight = _renorm_variable(
            'renorm_stddev_weight', ())
  finally:
    # Put back the partitioner cleared above, even when variable creation
    # raised.
    if partitioner:
      self._scope.set_partitioner(partitioner)
  self.built = True
def compute_gradients(self, loss, var_list=None,
                      gate_gradients=GATE_OP,
                      aggregation_method=None,
                      colocate_gradients_with_ops=False,
                      grad_loss=None):
  """Compute gradients of `loss` for the variables in `var_list`.

  This is the first part of `minimize()`.  It returns a list
  of (gradient, variable) pairs where "gradient" is the gradient
  for "variable".  Note that "gradient" can be a `Tensor`, an
  `IndexedSlices`, or `None` if there is no gradient for the
  given variable.

  When `distribute_lib.get_loss_reduction()` equals
  `variable_scope.VariableAggregation.MEAN` and the current distribution
  strategy has more than one tower, the loss is multiplied by
  `1 / num_towers` before differentiation.

  Args:
    loss: A Tensor containing the value to minimize or a callable taking
      no arguments which returns the value to minimize. When eager execution
      is enabled it must be a callable.
    var_list: Optional list or tuple of `tf.Variable` to update to minimize
      `loss`.  Defaults to the list of variables collected in the graph
      under the key `GraphKeys.TRAINABLE_VARIABLES`.
    gate_gradients: How to gate the computation of gradients.  Can be
      `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
    aggregation_method: Specifies the method used to combine gradient terms.
      Valid values are defined in the class `AggregationMethod`.
    colocate_gradients_with_ops: If True, try colocating gradients with
      the corresponding op.
    grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.

  Returns:
    A list of (gradient, variable) pairs. Variable is always present, but
    gradient can be `None`.

  Raises:
    TypeError: If `var_list` contains anything else than `Variable` objects.
    ValueError: If some arguments are invalid.
    RuntimeError: If called with eager execution enabled and `loss` is
      not callable.

  @compatibility(eager)
  When eager execution is enabled, `gate_gradients`, `aggregation_method`,
  and `colocate_gradients_with_ops` are ignored.
  @end_compatibility
  """
  if callable(loss):
    # Callable loss: run it under a tape and differentiate the recording.
    with backprop.GradientTape() as tape:
      if var_list is not None:
        tape.watch(var_list)
      loss_value = loss()

      # Scale loss if using a "mean" loss reduction and multiple towers.
      # Have to be careful to call distribute_lib.get_loss_reduction()
      # *after* loss() is evaluated, so we know what loss reduction it uses.
      # TODO(josh11b): Test that we handle weight decay in a reasonable way.
      if (distribute_lib.get_loss_reduction() ==
          variable_scope.VariableAggregation.MEAN):
        num_towers = distribute_lib.get_distribution_strategy().num_towers
        if num_towers > 1:
          loss_value *= (1. / num_towers)

    # Without an explicit var_list, use the variables the tape watched while
    # `loss()` ran.
    if var_list is None:
      var_list = tape.watched_variables()
    grads = tape.gradient(loss_value, var_list, grad_loss)
    return list(zip(grads, var_list))

  # Non-callable/Tensor loss case
  if context.executing_eagerly():
    raise RuntimeError(
        "`loss` passed to Optimizer.compute_gradients should "
        "be a function when eager execution is enabled.")

  # Scale loss if using a "mean" loss reduction and multiple towers.
  if (distribute_lib.get_loss_reduction() ==
      variable_scope.VariableAggregation.MEAN):
    num_towers = distribute_lib.get_distribution_strategy().num_towers
    if num_towers > 1:
      loss *= (1. / num_towers)

  if gate_gradients not in [Optimizer.GATE_NONE, Optimizer.GATE_OP,
                            Optimizer.GATE_GRAPH]:
    raise ValueError("gate_gradients must be one of: Optimizer.GATE_NONE, "
                     "Optimizer.GATE_OP, Optimizer.GATE_GRAPH. Not %s" %
                     gate_gradients)
  self._assert_valid_dtypes([loss])
  if grad_loss is not None:
    self._assert_valid_dtypes([grad_loss])
  if var_list is None:
    var_list = (
        variables.trainable_variables() +
        ops.get_collection(ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
  else:
    var_list = nest.flatten(var_list)
  # pylint: disable=protected-access
  var_list += ops.get_collection(ops.GraphKeys._STREAMING_MODEL_PORTS)
  # pylint: enable=protected-access
  # Processors are created ahead of the emptiness check; with an empty list
  # the comprehension just yields an empty list.
  processors = [_get_processor(v) for v in var_list]
  if not var_list:
    raise ValueError("No variables to optimize.")
  var_refs = [p.target() for p in processors]
  grads = gradients.gradients(
      loss, var_refs, grad_ys=grad_loss,
      gate_gradients=(gate_gradients == Optimizer.GATE_OP),
      aggregation_method=aggregation_method,
      colocate_gradients_with_ops=colocate_gradients_with_ops)
  if gate_gradients == Optimizer.GATE_GRAPH:
    grads = control_flow_ops.tuple(grads)
  grads_and_vars = list(zip(grads, var_list))
  # Dtype validation covers only variables that got a gradient and are not
  # resource-typed.
  self._assert_valid_dtypes(
      [v for g, v in grads_and_vars
       if g is not None and v.dtype != dtypes.resource])
  return grads_and_vars