def _assert_in_default_state(t):
  """Asserts that the distribution-strategy context is in its default state.

  Args:
    t: Test-case object whose `assertIs` / `assertFalse` methods are used.
  """
  # The active replica context must be the default replica context.
  default_replica_ctx = ds_context._get_default_replica_context()  # pylint: disable=protected-access
  active_replica_ctx = ds_context.get_replica_context()
  t.assertIs(default_replica_ctx, active_replica_ctx)

  # No cross-replica context may be active.
  t.assertIs(None, ds_context.get_cross_replica_context())
  t.assertFalse(ds_context.in_cross_replica_context())

  # The active strategy is the default one, and `has_strategy()` reports
  # False for it.
  default_strategy = ds_context._get_default_strategy()  # pylint: disable=protected-access
  active_strategy = ds_context.get_strategy()
  t.assertIs(default_strategy, active_strategy)
  t.assertFalse(ds_context.has_strategy())
def _assert_in_default_state(t):
  """Asserts that the distribution-strategy context is in its default state.

  Args:
    t: Test-case object whose `assertIs` / `assertFalse` methods are used.
  """
  # The active replica context must be the default replica context.
  default_replica_ctx = ds_context._get_default_replica_context()  # pylint: disable=protected-access
  active_replica_ctx = ds_context.get_replica_context()
  t.assertIs(default_replica_ctx, active_replica_ctx)

  # No cross-replica context may be active.
  t.assertIs(None, ds_context.get_cross_replica_context())
  t.assertFalse(ds_context.in_cross_replica_context())

  # The active strategy is the default one, and `has_strategy()` reports
  # False for it.
  default_strategy = ds_context._get_default_strategy()  # pylint: disable=protected-access
  active_strategy = ds_context.get_strategy()
  t.assertIs(default_strategy, active_strategy)
  t.assertFalse(ds_context.has_strategy())
def merge_fn(dist, s):
  """Verifies the cross-replica state under the default strategy.

  Args:
    dist: Strategy passed in by the caller; asserted to be the default
      strategy.
    s: Value appended to the `"foo_"` prefix.

  Returns:
    `"foo_" + s`.
  """
  default_strategy = ds_context._get_default_strategy()  # pylint: disable=protected-access
  self.assertIs(default_strategy, dist)

  # The checks below expect no replica context here; `dist` is expected to be
  # the cross-replica context instead.
  self.assertIs(None, ds_context.get_replica_context())
  cross_ctx = ds_context.get_cross_replica_context()
  self.assertIs(dist, cross_ctx)
  self.assertTrue(ds_context.in_cross_replica_context())

  # `dist` is the current strategy, yet `has_strategy()` must report False.
  active_strategy = ds_context.get_strategy()
  self.assertIs(dist, active_strategy)
  self.assertFalse(ds_context.has_strategy())
  return "foo_" + s
def merge_fn(dist, s):
  """Verifies the cross-replica state under the default strategy.

  Args:
    dist: Strategy passed in by the caller; asserted to be the default
      strategy.
    s: Value appended to the `"foo_"` prefix.

  Returns:
    `"foo_" + s`.
  """
  default_strategy = ds_context._get_default_strategy()  # pylint: disable=protected-access
  self.assertIs(default_strategy, dist)

  # The checks below expect no replica context here; `dist` is expected to be
  # the cross-replica context instead.
  self.assertIs(None, ds_context.get_replica_context())
  cross_ctx = ds_context.get_cross_replica_context()
  self.assertIs(dist, cross_ctx)
  self.assertTrue(ds_context.in_cross_replica_context())

  # `dist` is the current strategy, yet `has_strategy()` must report False.
  active_strategy = ds_context.get_strategy()
  self.assertIs(dist, active_strategy)
  self.assertFalse(ds_context.has_strategy())
  return "foo_" + s
def init_mht_saveable_from_checkpoint(ckpt_dir_or_file, assignment_map):
  """Calls `_init_mht_saveable_from_checkpoint` via the current context.

  When a cross-replica context is active the helper is invoked directly;
  otherwise it is dispatched through the replica context's `merge_call`.

  Args:
    ckpt_dir_or_file: Forwarded unchanged to
      `_init_mht_saveable_from_checkpoint`.
    assignment_map: Forwarded unchanged to
      `_init_mht_saveable_from_checkpoint`.
  """

  def _restore(_):
    # Accepts and ignores one positional argument so the same callable can be
    # handed to `merge_call` or invoked directly.
    return _init_mht_saveable_from_checkpoint(ckpt_dir_or_file, assignment_map)

  if not distribution_strategy_context.get_cross_replica_context():
    replica_ctx = distribution_strategy_context.get_replica_context()
    replica_ctx.merge_call(_restore)
    return
  _restore(None)
def update(self, grads):
  """Updates loss scale based on if gradients are finite in current step.

  Args:
    grads: Nested structure of gradients; flattened with `tf.nest.flatten`
      before the finiteness check.

  Returns:
    A tuple `(update_op, should_apply_gradients)`: the result of the
    `control_flow_ops.cond` that updates the loss-scale state, and the
    finiteness value that selected its branch.
  """
  # The state lives under different attribute names depending on
  # IS_PREV_TF_2_4_0.
  if IS_PREV_TF_2_4_0:
    good_steps = self._num_good_steps
    period = self._increment_period
    loss_scale = self._current_loss_scale
    factor = self._multiplier
  else:
    good_steps = self.counter
    period = self.growth_steps
    loss_scale = self.current_loss_scale
    factor = self.multiplier

  flat_grads = tf.nest.flatten(grads)

  if not distribution_strategy_context.has_strategy():
    all_finite = _refactor_is_all_finite(flat_grads)
  else:
    if IS_PREV_TF_2_4_0:
      strategy = distribution_strategy_context.get_cross_replica_context()
    else:
      strategy = distribution_strategy_context.get_strategy()

    def get_is_finite(replica_grads):
      # We cast to float, because we cannot reduce booleans with
      # DistributionStrategy.
      return tf.cast(_refactor_is_all_finite(replica_grads), tf.float32)

    per_replica_finite = strategy.extended.call_for_each_replica(
        get_is_finite, args=(flat_grads,))
    # The summed flags equal the replica count only when every replica
    # reported finite gradients.
    finite_total = strategy.reduce(
        reduce_util.ReduceOp.SUM, per_replica_finite, axis=None)
    all_finite = tf.equal(finite_total, strategy.num_replicas_in_sync)

  def update_if_finite_grads():
    """Update assuming the gradients are finite."""

    def incr_loss_scale():
      # Multiply the loss scale and restart the step counter.
      grown = loss_scale * factor
      assign_scale = _assign_if_finite(loss_scale, grown)
      reset_steps = good_steps.assign(0)
      return control_flow_ops.group(assign_scale, reset_steps)

    def bump_good_steps():
      return _op_in_graph_mode(good_steps.assign_add(1))

    period_reached = good_steps + 1 >= period
    return control_flow_ops.cond(
        period_reached, incr_loss_scale, bump_good_steps)

  def update_if_not_finite_grads():
    """Update assuming the gradients are nonfinite."""
    # Divide the loss scale (never below 1) and restart the step counter.
    shrunk = tf.math.maximum(loss_scale / factor, 1)
    reset_steps = good_steps.assign(0)
    assign_scale = loss_scale.assign(shrunk)
    return control_flow_ops.group(reset_steps, assign_scale)

  update_op = control_flow_ops.cond(
      all_finite, update_if_finite_grads, update_if_not_finite_grads)
  return update_op, all_finite
def run_fn():
  """Checks the replica-context state, `merge_call`, and variable creation."""
  ctx = ds_context.get_replica_context()
  self.assertTrue(ctx is not None)
  self.assertIs(None, ds_context.get_cross_replica_context())
  self.assertFalse(ds_context.in_cross_replica_context())
  self.assertTrue(ds_context.has_strategy())
  self.assertIs(dist, ds_context.get_strategy())

  # `merge_call` is expected to hand back the `test_arg` keyword value.
  merged = ctx.merge_call(None, test_arg="foo")
  self.assertEqual("foo", merged)

  # Creating a variable here is expected to yield the test strategy's dict
  # description of it.
  expected = _get_test_variable(
      "bar", variable_scope.VariableSynchronization.AUTO,
      variable_scope.VariableAggregation.NONE)
  created = variable_scope.variable(1.0, name="bar")
  self.assertDictEqual(expected, created)
def run_fn():
  """Checks the replica-context state, `merge_call`, and variable creation."""
  ctx = ds_context.get_replica_context()
  self.assertTrue(ctx is not None)
  self.assertIs(None, ds_context.get_cross_replica_context())
  self.assertFalse(ds_context.in_cross_replica_context())
  self.assertTrue(ds_context.has_strategy())
  self.assertIs(dist, ds_context.get_strategy())

  # `merge_call` is expected to hand back the `test_arg` keyword value.
  merged = ctx.merge_call(None, test_arg="foo")
  self.assertEqual("foo", merged)

  # Creating a variable here is expected to yield the test strategy's dict
  # description of it.
  expected = _get_test_variable(
      "bar", variable_scope.VariableSynchronization.AUTO,
      variable_scope.VariableAggregation.NONE)
  created = variable_scope.variable(1.0, name="bar")
  self.assertDictEqual(expected, created)
def update(self, grads):
  """Updates loss scale based on if gradients are finite in current step.

  Args:
    grads: Nested structure of gradients; flattened with `nest.flatten`
      before the finiteness check.

  Returns:
    A tuple `(update_op, should_apply_gradients)`: the result of the
    `control_flow_ops.cond` that updates the loss-scale state, and the
    finiteness value that selected its branch.
  """
  flat_grads = nest.flatten(grads)

  if not distribution_strategy_context.has_strategy():
    all_finite = _is_all_finite(flat_grads)
  else:
    strategy = distribution_strategy_context.get_cross_replica_context()

    def get_is_finite(replica_grads):
      # We cast to float, because we cannot reduce booleans with
      # DistributionStrategy.
      return math_ops.cast(_is_all_finite(replica_grads), dtypes.float32)

    per_replica_finite = strategy.extended.call_for_each_replica(
        get_is_finite, args=(flat_grads,))
    # The summed flags equal the replica count only when every replica
    # reported finite gradients.
    finite_total = strategy.reduce(
        reduce_util.ReduceOp.SUM, per_replica_finite, axis=None)
    all_finite = math_ops.equal(finite_total, strategy.num_replicas_in_sync)

  def update_if_finite_grads():
    """Update assuming the gradients are finite."""

    def incr_loss_scale():
      # Multiply the loss scale, capped at 2**32, and restart the step
      # counter.
      grown = math_ops.minimum(
          self._current_loss_scale * self._multiplier, 2**32)
      assign_scale = _assign_if_finite(self._current_loss_scale, grown)
      reset_steps = self._num_good_steps.assign(0)
      return control_flow_ops.group(assign_scale, reset_steps)

    def bump_good_steps():
      return _op_in_graph_mode(self._num_good_steps.assign_add(1))

    period_reached = self._num_good_steps + 1 >= self._increment_period
    return control_flow_ops.cond(
        period_reached, incr_loss_scale, bump_good_steps)

  def update_if_not_finite_grads():
    """Update assuming the gradients are nonfinite."""
    # Divide the loss scale (never below 1) and restart the step counter.
    shrunk = math_ops.maximum(
        self._current_loss_scale / self._multiplier, 1)
    reset_steps = self._num_good_steps.assign(0)
    assign_scale = self._current_loss_scale.assign(shrunk)
    return control_flow_ops.group(reset_steps, assign_scale)

  update_op = control_flow_ops.cond(
      all_finite, update_if_finite_grads, update_if_not_finite_grads)
  return update_op, all_finite
def testScope(self):
  """Checks the context state before, inside and after `dist.scope()`."""
  _assert_in_default_state(self)

  dist = _TestStrategy()
  with dist.scope():
    # Inside the scope there is no replica context; `dist` is both the
    # cross-replica context and the current strategy.
    self.assertIs(None, ds_context.get_replica_context())
    cross_ctx = ds_context.get_cross_replica_context()
    self.assertIs(dist, cross_ctx)
    self.assertTrue(ds_context.in_cross_replica_context())
    self.assertTrue(ds_context.has_strategy())
    self.assertIs(dist, ds_context.get_strategy())

    expected = _get_test_variable(
        "baz", variable_scope.VariableSynchronization.AUTO,
        variable_scope.VariableAggregation.NONE)
    created = variable_scope.variable(1.0, name="baz")
    self.assertDictEqual(expected, created)

  # Leaving the scope must restore the default state.
  _assert_in_default_state(self)
def testScope(self):
  """Checks the context state before, inside and after `dist.scope()`."""
  _assert_in_default_state(self)

  dist = _TestStrategy()
  with dist.scope():
    # Inside the scope there is no replica context; `dist` is both the
    # cross-replica context and the current strategy.
    self.assertIs(None, ds_context.get_replica_context())
    cross_ctx = ds_context.get_cross_replica_context()
    self.assertIs(dist, cross_ctx)
    self.assertTrue(ds_context.in_cross_replica_context())
    self.assertTrue(ds_context.has_strategy())
    self.assertIs(dist, ds_context.get_strategy())

    expected = _get_test_variable(
        "baz", variable_scope.VariableSynchronization.AUTO,
        variable_scope.VariableAggregation.NONE)
    created = variable_scope.variable(1.0, name="baz")
    self.assertDictEqual(expected, created)

  # Leaving the scope must restore the default state.
  _assert_in_default_state(self)
def update(self, grads):
  """Updates loss scale based on if gradients are finite in current step.

  Args:
    grads: Gradients passed unchanged to the finiteness check.

  Returns:
    A tuple `(update_op, should_apply_gradients)`: the result of the
    `control_flow_ops.cond` that updates the loss-scale state, and the
    finiteness value that selected its branch.
  """
  if not distribution_strategy_context.has_strategy():
    all_finite = _is_all_finite(grads)
  else:
    strategy = distribution_strategy_context.get_cross_replica_context()

    def get_is_finite(replica_grads):
      # We cast to float, because we cannot reduce booleans with
      # DistributionStrategy.
      return math_ops.cast(_is_all_finite(replica_grads), dtypes.float32)

    per_replica_finite = strategy.extended.call_for_each_replica(
        get_is_finite, args=(grads,))
    # The summed flags equal the replica count only when every replica
    # reported finite gradients.
    finite_total = strategy.reduce(
        reduce_util.ReduceOp.SUM, per_replica_finite, axis=None)
    all_finite = math_ops.equal(finite_total, strategy.num_replicas_in_sync)

  def update_if_finite_grads():
    """Update assuming the gradients are finite."""

    def incr_loss_scale():
      # Multiply the loss scale and restart the step counter.
      grown = self._current_loss_scale * self._multiplier
      assign_scale = _assign_if_finite(self._current_loss_scale, grown)
      reset_steps = self._num_good_steps.assign(0)
      return control_flow_ops.group(assign_scale, reset_steps)

    def bump_good_steps():
      return _op_in_graph_mode(self._num_good_steps.assign_add(1))

    period_reached = self._num_good_steps + 1 >= self._increment_period
    return control_flow_ops.cond(
        period_reached, incr_loss_scale, bump_good_steps)

  def update_if_not_finite_grads():
    """Update assuming the gradients are nonfinite."""
    # Divide the loss scale (never below 1) and restart the step counter.
    shrunk = math_ops.maximum(
        self._current_loss_scale / self._multiplier, 1)
    reset_steps = self._num_good_steps.assign(0)
    assign_scale = self._current_loss_scale.assign(shrunk)
    return control_flow_ops.group(reset_steps, assign_scale)

  update_op = control_flow_ops.cond(
      all_finite, update_if_finite_grads, update_if_not_finite_grads)
  return update_op, all_finite
def init_from_checkpoint(ckpt_dir_or_file, assignment_map):
  """Replaces `tf.Variable` initializers so they load from a checkpoint file.

  Values are not loaded immediately, but when the initializer is run
  (typically by running a `tf.compat.v1.global_variables_initializer` op).

  Note: This overrides default initialization ops of specified variables and
  redefines dtype.

  Assignment map supports following syntax:

  * `'checkpoint_scope_name/': 'scope_name/'` - will load all variables in
    current `scope_name` from `checkpoint_scope_name` with matching tensor
    names.
  * `'checkpoint_scope_name/some_other_variable': 'scope_name/variable_name'` -
    will initialize `scope_name/variable_name` variable
    from `checkpoint_scope_name/some_other_variable`.
  * `'scope_variable_name': variable` - will initialize given `tf.Variable`
    object with tensor 'scope_variable_name' from the checkpoint.
  * `'scope_variable_name': list(variable)` - will initialize list of
    partitioned variables with tensor 'scope_variable_name' from the checkpoint.
  * `'/': 'scope_name/'` - will load all variables in current `scope_name` from
    checkpoint's root (e.g. no scope).

  Supports loading into partitioned variables, which are represented as
  `'<variable>/part_<part #>'`.

  Example:

  ```python

  # Say, '/tmp/model.ckpt' has the following tensors:
  #  -- name='old_scope_1/var1', shape=[20, 2]
  #  -- name='old_scope_1/var2', shape=[50, 4]
  #  -- name='old_scope_2/var3', shape=[100, 100]

  # Create new model's variables
  with tf.compat.v1.variable_scope('new_scope_1'):
    var1 = tf.compat.v1.get_variable('var1', shape=[20, 2],
                           initializer=tf.compat.v1.zeros_initializer())
  with tf.compat.v1.variable_scope('new_scope_2'):
    var2 = tf.compat.v1.get_variable('var2', shape=[50, 4],
                           initializer=tf.compat.v1.zeros_initializer())
    # Partition into 5 variables along the first axis.
    var3 = tf.compat.v1.get_variable(name='var3', shape=[100, 100],
                           initializer=tf.compat.v1.zeros_initializer(),
                           partitioner=lambda shape, dtype: [5, 1])

  # Initialize all variables in `new_scope_1` from `old_scope_1`.
  init_from_checkpoint('/tmp/model.ckpt', {'old_scope_1/': 'new_scope_1/'})

  # Use names to specify which variables to initialize from checkpoint.
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_1/var1': 'new_scope_1/var1',
                        'old_scope_1/var2': 'new_scope_2/var2'})

  # Or use tf.Variable objects to identify what to initialize.
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_1/var1': var1,
                        'old_scope_1/var2': var2})

  # Initialize partitioned variables using variable's name
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_2/var3': 'new_scope_2/var3'})

  # Or specify the list of tf.Variable objects.
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_2/var3': var3._get_variable_list()})

  ```

  Args:
    ckpt_dir_or_file: Directory with checkpoints file or path to checkpoint.
    assignment_map: Dict, where keys are names of the variables in the
      checkpoint and values are current variables or names of current variables
      (in default graph).

  Raises:
    ValueError: If missing variables in current graph, or if missing
      checkpoints or tensors in checkpoints.
  """

  def _restore(_):
    # Accepts and ignores one positional argument so the same callable can be
    # handed to `merge_call` or invoked directly.
    return _init_from_checkpoint(ckpt_dir_or_file, assignment_map)

  if distribution_strategy_context.get_cross_replica_context():
    # Already in a cross-replica context: run the restore directly.
    _restore(None)
  else:
    # Otherwise dispatch it through the replica context's `merge_call`.
    distribution_strategy_context.get_replica_context().merge_call(_restore)
def assign_moving_average(variable, value, decay, zero_debias=True, name=None):
  """Compute the moving average of a variable.

  The moving average of 'variable' updated with 'value' is:
    variable * decay + value * (1 - decay)

  The returned Operation sets 'variable' to the newly computed moving average,
  by performing this subtraction:
     variable -= (1 - decay) * (variable - value)

  Since variables that are initialized to a `0` value will be `0` biased,
  `zero_debias` optionally enables scaling by the mathematically correct
  debiasing factor of
    1 - decay ** num_updates
  See Section 3 of (Kingma et al., 2015) for more details.

  The names of the debias shadow variables, by default, include both the scope
  they were created in and the scope of the variables they debias. They are also
  given a uniquifying-suffix.

  E.g.:

  ```
    with tf.compat.v1.variable_scope('scope1'):
      with tf.compat.v1.variable_scope('scope2'):
        var = tf.compat.v1.get_variable('foo')
        update_1 = tf.assign_moving_average(var, 0.0, 1.0)
        update_2 = tf.assign_moving_average(var, 0.0, 0.9)

    # var.name: 'scope1/scope2/foo'
    # shadow var names: 'scope1/scope2/scope1/scope2/foo/biased'
    #                   'scope1/scope2/scope1/scope2/foo/biased_1'
  ```

  Args:
    variable: A Variable.
    value: A tensor with the same shape as 'variable'.
    decay: A float Tensor or float value.  The moving average decay.
    zero_debias: A python bool. If true, assume the variable is 0-initialized
      and unbias it, as in (Kingma et al., 2015). See docstring in
      `_zero_debias` for more details.
    name: Optional name of the returned operation.

  Returns:
    A tensor which if evaluated will compute and return the new moving average.

  References:
    Adam - A Method for Stochastic Optimization:
      [Kingma et al., 2015](https://arxiv.org/abs/1412.6980)
      ([pdf](https://arxiv.org/pdf/1412.6980.pdf))
  """
  with ops.name_scope(name, "AssignMovingAvg",
                      [variable, value, decay]) as scope:
    # From here on the computation uses `1 - decay`, cast to the variable's
    # base dtype when the dtypes differ.
    one_minus_decay = ops.convert_to_tensor(1.0 - decay, name="decay")
    target_dtype = variable.dtype.base_dtype
    if one_minus_decay.dtype != target_dtype:
      one_minus_decay = math_ops.cast(one_minus_decay, target_dtype)

    def update_fn(v, value):
      delta = (v - value) * one_minus_decay
      return state_ops.assign_sub(v, delta, name=scope)

    def _apply(strategy, v, value):
      # Debiased path versus the plain `variable -= delta` update.
      if not zero_debias:
        return _update(strategy, v, update_fn, args=(value,))
      return _zero_debias(strategy, v, value, one_minus_decay)

    def merge_fn(strategy, v, value):
      # In a replica context, we update variable using the mean of value across
      # replicas.
      mean_value = strategy.extended.reduce_to(
          ds_reduce_util.ReduceOp.MEAN, value, v)
      return _apply(strategy, v, mean_value)

    replica_context = distribution_strategy_context.get_replica_context()
    if not replica_context:
      cross_strategy = distribution_strategy_context.get_cross_replica_context()
      return _apply(cross_strategy, variable, value)
    return replica_context.merge_call(merge_fn, args=(variable, value))
def init_from_checkpoint(ckpt_dir_or_file, assignment_map):
  """Replaces `tf.Variable` initializers so they load from a checkpoint file.

  Values are not loaded immediately, but when the initializer is run
  (typically by running a `tf.compat.v1.global_variables_initializer` op).

  Note: This overrides default initialization ops of specified variables and
  redefines dtype.

  Assignment map supports following syntax:

  * `'checkpoint_scope_name/': 'scope_name/'` - will load all variables in
    current `scope_name` from `checkpoint_scope_name` with matching tensor
    names.
  * `'checkpoint_scope_name/some_other_variable': 'scope_name/variable_name'` -
    will initialize `scope_name/variable_name` variable
    from `checkpoint_scope_name/some_other_variable`.
  * `'scope_variable_name': variable` - will initialize given `tf.Variable`
    object with tensor 'scope_variable_name' from the checkpoint.
  * `'scope_variable_name': list(variable)` - will initialize list of
    partitioned variables with tensor 'scope_variable_name' from the checkpoint.
  * `'/': 'scope_name/'` - will load all variables in current `scope_name` from
    checkpoint's root (e.g. no scope).

  Supports loading into partitioned variables, which are represented as
  `'<variable>/part_<part #>'`.

  Example:

  ```python

  # Say, '/tmp/model.ckpt' has the following tensors:
  #  -- name='old_scope_1/var1', shape=[20, 2]
  #  -- name='old_scope_1/var2', shape=[50, 4]
  #  -- name='old_scope_2/var3', shape=[100, 100]

  # Create new model's variables
  with tf.compat.v1.variable_scope('new_scope_1'):
    var1 = tf.compat.v1.get_variable('var1', shape=[20, 2],
                           initializer=tf.compat.v1.zeros_initializer())
  with tf.compat.v1.variable_scope('new_scope_2'):
    var2 = tf.compat.v1.get_variable('var2', shape=[50, 4],
                           initializer=tf.compat.v1.zeros_initializer())
    # Partition into 5 variables along the first axis.
    var3 = tf.compat.v1.get_variable(name='var3', shape=[100, 100],
                           initializer=tf.compat.v1.zeros_initializer(),
                           partitioner=lambda shape, dtype: [5, 1])

  # Initialize all variables in `new_scope_1` from `old_scope_1`.
  init_from_checkpoint('/tmp/model.ckpt', {'old_scope_1/': 'new_scope_1/'})

  # Use names to specify which variables to initialize from checkpoint.
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_1/var1': 'new_scope_1/var1',
                        'old_scope_1/var2': 'new_scope_2/var2'})

  # Or use tf.Variable objects to identify what to initialize.
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_1/var1': var1,
                        'old_scope_1/var2': var2})

  # Initialize partitioned variables using variable's name
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_2/var3': 'new_scope_2/var3'})

  # Or specify the list of tf.Variable objects.
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_2/var3': var3._get_variable_list()})

  ```

  Args:
    ckpt_dir_or_file: Directory with checkpoints file or path to checkpoint.
    assignment_map: Dict, where keys are names of the variables in the
      checkpoint and values are current variables or names of current variables
      (in default graph).

  Raises:
    ValueError: If missing variables in current graph, or if missing
      checkpoints or tensors in checkpoints.
  """

  def _restore(_):
    # Accepts and ignores one positional argument so the same callable can be
    # handed to `merge_call` or invoked directly.
    return _init_from_checkpoint(ckpt_dir_or_file, assignment_map)

  if distribution_strategy_context.get_cross_replica_context():
    # Already in a cross-replica context: run the restore directly.
    _restore(None)
  else:
    # Otherwise dispatch it through the replica context's `merge_call`.
    distribution_strategy_context.get_replica_context().merge_call(_restore)
def assign_moving_average(variable, value, decay, zero_debias=True, name=None):
  """Compute the moving average of a variable.

  The moving average of 'variable' updated with 'value' is:
    variable * decay + value * (1 - decay)

  The returned Operation sets 'variable' to the newly computed moving average,
  by performing this subtraction:
     variable -= (1 - decay) * (variable - value)

  Since variables that are initialized to a `0` value will be `0` biased,
  `zero_debias` optionally enables scaling by the mathematically correct
  debiasing factor of
    1 - decay ** num_updates
  See `ADAM: A Method for Stochastic Optimization` Section 3 for more details
  (https://arxiv.org/abs/1412.6980).

  The names of the debias shadow variables, by default, include both the scope
  they were created in and the scope of the variables they debias. They are also
  given a uniquifying-suffix.

  E.g.:

  ```
    with tf.compat.v1.variable_scope('scope1'):
      with tf.compat.v1.variable_scope('scope2'):
        var = tf.compat.v1.get_variable('foo')
        update_1 = tf.assign_moving_average(var, 0.0, 1.0)
        update_2 = tf.assign_moving_average(var, 0.0, 0.9)

    # var.name: 'scope1/scope2/foo'
    # shadow var names: 'scope1/scope2/scope1/scope2/foo/biased'
    #                   'scope1/scope2/scope1/scope2/foo/biased_1'
  ```

  Args:
    variable: A Variable.
    value: A tensor with the same shape as 'variable'.
    decay: A float Tensor or float value.  The moving average decay.
    zero_debias: A python bool. If true, assume the variable is 0-initialized
      and unbias it, as in https://arxiv.org/abs/1412.6980. See docstring in
      `_zero_debias` for more details.
    name: Optional name of the returned operation.

  Returns:
    A tensor which if evaluated will compute and return the new moving average.
  """
  with ops.name_scope(name, "AssignMovingAvg",
                      [variable, value, decay]) as scope:
    # From here on the computation uses `1 - decay`, cast to the variable's
    # base dtype when the dtypes differ.
    one_minus_decay = ops.convert_to_tensor(1.0 - decay, name="decay")
    target_dtype = variable.dtype.base_dtype
    if one_minus_decay.dtype != target_dtype:
      one_minus_decay = math_ops.cast(one_minus_decay, target_dtype)

    def update_fn(v, value):
      delta = (v - value) * one_minus_decay
      return state_ops.assign_sub(v, delta, name=scope)

    def _apply(strategy, v, value):
      # Debiased path versus the plain `variable -= delta` update.
      if not zero_debias:
        return strategy.extended.update(v, update_fn, args=(value,))
      return _zero_debias(strategy, v, value, one_minus_decay)

    def merge_fn(strategy, v, value):
      # In a replica context, we update variable using the mean of value across
      # replicas.
      mean_value = strategy.extended.reduce_to(
          ds_reduce_util.ReduceOp.MEAN, value, v)
      return _apply(strategy, v, mean_value)

    replica_context = distribution_strategy_context.get_replica_context()
    if not replica_context:
      cross_strategy = distribution_strategy_context.get_cross_replica_context()
      return _apply(cross_strategy, variable, value)
    return replica_context.merge_call(merge_fn, args=(variable, value))
def apply_gradients(self, grads_and_vars, global_step=None, name=None):
  """Apply gradients to variables.

  This is the second part of `minimize()`. It returns an `Operation` that
  applies gradients.

  Args:
    grads_and_vars: List of (gradient, variable) pairs as returned by
      `compute_gradients()`.
    global_step: Optional `Variable` to increment by one after the
      variables have been updated.
    name: Optional name for the returned operation.  Default to the
      name passed to the `Optimizer` constructor.

  Returns:
    An `Operation` that applies the specified gradients. If `global_step`
    was not None, that operation also increments `global_step`.

  Raises:
    TypeError: If `grads_and_vars` is malformed.
    ValueError: If `grads_and_vars` is empty, or if none of the variables
      have gradients.
    RuntimeError: If called in a cross-replica context; use
      `_distributed_apply()` instead.
  """
  # This is a default implementation of apply_gradients() that can be shared
  # by most optimizers.  It relies on the subclass implementing the following
  # methods: _create_slots(), _prepare(), _apply_dense(), and _apply_sparse().

  # Handle DistributionStrategy case.
  if distribute_ctx.get_cross_replica_context():
    raise RuntimeError("Use `_distributed_apply()` instead of "
                       "`apply_gradients()` in a cross-replica context.")
  # TODO(isaprykin): Get rid of `has_distribution_strategy()` check by
  # always calling _distributed_apply(), using the default distribution
  # as needed.
  if distribute_ctx.has_distribution_strategy():
    grads_and_vars = get_filtered_grad_fn(lambda: grads_and_vars)()
    return distribute_ctx.get_replica_context().merge_call(
        self._distributed_apply, args=(grads_and_vars, global_step, name))

  # No DistributionStrategy case.
  grads_and_vars = tuple(grads_and_vars)  # Make sure repeat iteration works.
  if not grads_and_vars:
    raise ValueError("No variables provided.")
  converted_grads_and_vars = []
  for g, v in grads_and_vars:
    if g is not None:
      try:
        # Convert the grad to Tensor or IndexedSlices if necessary.
        g = ops.convert_to_tensor_or_indexed_slices(g)
      except TypeError:
        # `g` is wrapped in a 1-tuple so that a tuple-valued gradient is
        # formatted as a single value instead of being unpacked by `%`.
        raise TypeError(
            "Gradient must be convertible to a Tensor"
            " or IndexedSlices, or None: %s" % (g,))
      if not isinstance(g, (ops.Tensor, ops.IndexedSlices)):
        raise TypeError(
            "Gradient must be a Tensor, IndexedSlices, or None: %s" % (g,))
    p = _get_processor(v)
    converted_grads_and_vars.append((g, v, p))

  converted_grads_and_vars = tuple(converted_grads_and_vars)
  var_list = [v for g, v, _ in converted_grads_and_vars if g is not None]
  if not var_list:
    raise ValueError("No gradients provided for any variable: %s." %
                     ([str(v) for _, v, _ in converted_grads_and_vars],))
  with ops.init_scope():
    self._create_slots(var_list)
  update_ops = []
  with ops.name_scope(name, self._name) as name:
    self._prepare()
    for grad, var, processor in converted_grads_and_vars:
      if grad is None:
        continue
      # We colocate all ops created in _apply_dense or _apply_sparse
      # on the same device as the variable.
      # TODO(apassos): figure out how to get the variable name here.
      # `and` binds tighter than `or`; the parentheses only make the
      # existing grouping explicit.
      if context.executing_eagerly() or (
          isinstance(var, resource_variable_ops.ResourceVariable)
          and not var._in_graph_mode):  # pylint: disable=protected-access
        scope_name = ""
      else:
        scope_name = var.op.name
      with ops.name_scope("update_" + scope_name), ops.colocate_with(var):
        update_ops.append(processor.update_op(self, grad))
    if global_step is None:
      apply_updates = self._finish(update_ops, name)
    else:
      # Increment `global_step` only after the variable updates have run.
      with ops.control_dependencies([self._finish(update_ops, "update")]):
        with ops.colocate_with(global_step):
          if isinstance(global_step, resource_variable_ops.ResourceVariable):
            # TODO(apassos): the implicit read in assign_add is slow; consider
            # making it less so.
            apply_updates = resource_variable_ops.assign_add_variable_op(
                global_step.handle,
                ops.convert_to_tensor(1, dtype=global_step.dtype),
                name=name)
          else:
            apply_updates = state_ops.assign_add(global_step, 1, name=name)

    if not context.executing_eagerly():
      # Record the op (not the tensor) in the TRAIN_OP collection, once.
      if isinstance(apply_updates, ops.Tensor):
        apply_updates = apply_updates.op
      train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
      if apply_updates not in train_op:
        train_op.append(apply_updates)

    return apply_updates
def init_from_checkpoint(ckpt_dir_or_file, assignment_map):
  """Replaces `tf.Variable` initializers so they load from a checkpoint file.

  @compatibility(TF2)
  `tf.compat.v1.train.init_from_checkpoint` is not recommended for restoring
  variable values in TF2.

  To restore checkpoints in TF2, please use
  `tf.keras.Model.load_weights` or `tf.train.Checkpoint.restore`. These APIs
  use an
  [object-based method of checkpointing](https://www.tensorflow.org/guide/checkpoint#loading_mechanics),
  while `tf.compat.v1.init_from_checkpoint` relies on a more-fragile
  variable-name based method of checkpointing. There is no object-based
  equivalent of `init_from_checkpoint` in TF2.

  Please re-write your checkpoints immediately using the object-based APIs,
  see
  [migration guide](https://www.tensorflow.org/guide/migrate#checkpoint_compatibility)
  for more details.

  You can load a name-based checkpoint written by `tf.compat.v1.train.Saver`
  using `tf.train.Checkpoint.restore` or `tf.keras.Model.load_weights`. However,
  you may have to change the names of the variables in your model to match the
  variable names in the name-based checkpoint, which can be viewed with
  `tf.train.list_variables(path)`.

  Another option is to create an `assignment_map` that maps the name of the
  variables in the name-based checkpoint to the variables in your model, eg:
  ```
  {
      'sequential/dense/bias': model.variables[0],
      'sequential/dense/kernel': model.variables[1]
  }
  ```
  and use `tf.compat.v1.train.init_from_checkpoint(path, assignment_map)` to
  restore the name-based checkpoint.

  After restoring, re-encode your checkpoint using `tf.train.Checkpoint.save`
  or `tf.keras.Model.save_weights`.
  @end_compatibility

  Values are not loaded immediately, but when the initializer is run
  (typically by running a `tf.compat.v1.global_variables_initializer` op).

  Note: This overrides default initialization ops of specified variables and
  redefines dtype.

  Assignment map supports following syntax:

  * `'checkpoint_scope_name/': 'scope_name/'` - will load all variables in
    current `scope_name` from `checkpoint_scope_name` with matching tensor
    names.
  * `'checkpoint_scope_name/some_other_variable': 'scope_name/variable_name'` -
    will initialize `scope_name/variable_name` variable
    from `checkpoint_scope_name/some_other_variable`.
  * `'scope_variable_name': variable` - will initialize given `tf.Variable`
    object with tensor 'scope_variable_name' from the checkpoint.
  * `'scope_variable_name': list(variable)` - will initialize list of
    partitioned variables with tensor 'scope_variable_name' from the checkpoint.
  * `'/': 'scope_name/'` - will load all variables in current `scope_name` from
    checkpoint's root (e.g. no scope).

  Supports loading into partitioned variables, which are represented as
  `'<variable>/part_<part #>'`.

  Assignment map can be a dict, or a list of pairs.  The latter is
  necessary to initialize multiple variables in the current graph from
  the same variable in the checkpoint.

  Example:

  ```python

  # Say, '/tmp/model.ckpt' has the following tensors:
  #  -- name='old_scope_1/var1', shape=[20, 2]
  #  -- name='old_scope_1/var2', shape=[50, 4]
  #  -- name='old_scope_2/var3', shape=[100, 100]

  # Create new model's variables
  with tf.compat.v1.variable_scope('new_scope_1'):
    var1 = tf.compat.v1.get_variable('var1', shape=[20, 2],
                           initializer=tf.compat.v1.zeros_initializer())
  with tf.compat.v1.variable_scope('new_scope_2'):
    var2 = tf.compat.v1.get_variable('var2', shape=[50, 4],
                           initializer=tf.compat.v1.zeros_initializer())
    # Partition into 5 variables along the first axis.
    var3 = tf.compat.v1.get_variable(name='var3', shape=[100, 100],
                           initializer=tf.compat.v1.zeros_initializer(),
                           partitioner=lambda shape, dtype: [5, 1])

  # Initialize all variables in `new_scope_1` from `old_scope_1`.
  init_from_checkpoint('/tmp/model.ckpt', {'old_scope_1/': 'new_scope_1/'})

  # Use names to specify which variables to initialize from checkpoint.
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_1/var1': 'new_scope_1/var1',
                        'old_scope_1/var2': 'new_scope_2/var2'})

  # Or use tf.Variable objects to identify what to initialize.
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_1/var1': var1,
                        'old_scope_1/var2': var2})

  # Initialize partitioned variables using variable's name
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_2/var3': 'new_scope_2/var3'})

  # Or specify the list of tf.Variable objects.
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_2/var3': var3._get_variable_list()})

  ```

  Args:
    ckpt_dir_or_file: Directory with checkpoints file or path to checkpoint.
    assignment_map: Dict, or a list of key-value pairs, where keys are names
      of the variables in the checkpoint and values are current variables or
      names of current variables (in default graph).

  Raises:
    ValueError: If missing variables in current graph, or if missing
      checkpoints or tensors in checkpoints.
  """

  def _restore(_):
    # Accepts and ignores one positional argument so the same callable can be
    # handed to `merge_call` or invoked directly.
    return _init_from_checkpoint(ckpt_dir_or_file, assignment_map)

  if not distribution_strategy_context.get_cross_replica_context():
    replica_ctx = distribution_strategy_context.get_replica_context()
    replica_ctx.merge_call(_restore)
    return
  _restore(None)