def model_fn():
  """Creates four variables with deliberately overlapping names.

  Returns:
    A list of the created variables, in creation order.
  """
  requested_names = ("foo/bar", "foo_1/bar", "foo_1/bar_1", "foo/bar_1")
  vs = [variable_scope.variable(1.0, name=n) for n in requested_names]
  # Identity merge_call: leaves tower mode for cross-tower mode and back;
  # its return value is intentionally ignored.
  distribute_lib.get_tower_context().merge_call(lambda _: _)
  return vs
def model_fn():
  """Creates variables before and after a merge_call inside a "common" scope.

  Returns:
    A tuple `(v0, v1, v2)` of the variables created via `get_variable`.
  """
  # Created outside any explicit variable_scope.
  v0 = variable_scope.get_variable("var-thread0", [1])
  with variable_scope.variable_scope("common"):
    # Created under "common" before the merge_call.
    v1 = variable_scope.get_variable("var-thread1", [1])
    # This will pause the current thread, and execute the other thread.
    distribute_lib.get_tower_context().merge_call(lambda _: _)
    # Created under "common" after the merge_call.
    v2 = variable_scope.get_variable("var-thread2", [1])
  return v0, v1, v2
def set_non_tensor_output(self, name, output):
  """Records `output` under `name` as a captured non-tensor output.

  In cross-tower context the value is stored as-is. In a tower context the
  value is routed through `merge_call`, and what gets stored is the result of
  `distribution.unwrap` on the per-tower values.
  """
  if distribute_lib.get_cross_tower_context():
    self._non_tensor_outputs[name] = output
    return

  def _store_all_tower_values(strategy, per_tower_value):
    # Aggregating non-tensor values makes no sense, so every tower's value is
    # kept (as unwrapped by the strategy) instead of being reduced.
    self._non_tensor_outputs[name] = strategy.unwrap(per_tower_value)

  distribute_lib.get_tower_context().merge_call(_store_all_tower_values, output)
def model_fn(device_id):
  """Creates one variable under a creator that tags it with `device_id`.

  Args:
    device_id: Integer identifying the calling tower thread.

  Returns:
    Whatever `variable_scope.variable` produced under the custom creator.
  """
  assert isinstance(device_id, int)

  def _tag_with_thread(next_creator, *args, **kwargs):
    # Delegate creation, then append this thread's suffix to the result.
    created = next_creator(*args, **kwargs)
    return created + ":thread_" + str(device_id)

  with variable_scope.variable_creator_scope(_tag_with_thread):
    v = variable_scope.variable(1.0)
    # Pauses this thread so the other tower thread gets to run.
    distribute_lib.get_tower_context().merge_call(lambda _: _)
  return v
def model_fn(features):
  """Builds three Dense(1) layers under "common", with a merge_call between
  the second and the third.

  Args:
    features: Input passed to each layer's call.

  Returns:
    A list of `(kernel, bias)` pairs, one per layer, in creation order.
  """
  with variable_scope.variable_scope("common"):
    # Calling each layer on `features` is what creates its kernel/bias.
    layer1 = core.Dense(1)
    layer1(features)
    layer2 = core.Dense(1)
    layer2(features)
    # This will pause the current thread, and execute the other thread.
    distribute_lib.get_tower_context().merge_call(lambda _: _)
    layer3 = core.Dense(1)
    layer3(features)
    return [(layer1.kernel, layer1.bias),
            (layer2.kernel, layer2.bias),
            (layer3.kernel, layer3.bias)]
def merge_fn(dist, s):
  # Runs inside merge_call under the default strategy. That strategy must be
  # both the cross-tower context and the current strategy, there must be no
  # tower context, and has_distribution_strategy() must still report False.
  self.assertIs(distribute._default_distribution_strategy, dist)
  checks = (
      (None, distribute.get_tower_context),
      (dist, distribute.get_cross_tower_context),
      (dist, distribute.get_distribution_strategy),
  )
  for expected, current in checks:
    self.assertIs(expected, current())
  self.assertFalse(distribute.has_distribution_strategy())
  return "foo_" + s
def model_fn():
  """Creates four variables around a merge_call, the last two with explicit
  synchronization/aggregation settings.

  Returns:
    A tuple `(v0, v1, v2, v3)` of the variables created via `get_variable`.
  """
  # Created outside any explicit variable_scope.
  v0 = variable_scope.get_variable("var0", [1])
  with variable_scope.variable_scope("common"):
    v1 = variable_scope.get_variable("var1", [1])
    # This will pause the current thread, and execute the other thread.
    distribute_lib.get_tower_context().merge_call(lambda _: _)
    # ON_READ synchronization with SUM aggregation.
    v2 = variable_scope.get_variable(
        "var2", [1],
        synchronization=variable_scope.VariableSynchronization.ON_READ,
        aggregation=variable_scope.VariableAggregation.SUM)
    # ON_WRITE synchronization with MEAN aggregation.
    v3 = variable_scope.get_variable(
        "var3", [1],
        synchronization=variable_scope.VariableSynchronization.ON_WRITE,
        aggregation=variable_scope.VariableAggregation.MEAN)
  return v0, v1, v2, v3
def merge_fn(dist, s):
  # Runs inside merge_call under the default strategy. That strategy must be
  # both the cross-tower context and the current strategy, there must be no
  # tower context, and has_distribution_strategy() must still report False.
  self.assertIs(distribute._default_distribution_strategy, dist)
  checks = (
      (None, distribute.get_tower_context),
      (dist, distribute.get_cross_tower_context),
      (dist, distribute.get_distribution_strategy),
  )
  for expected, current in checks:
    self.assertIs(expected, current())
  self.assertFalse(distribute.has_distribution_strategy())
  return "foo_" + s
def _assign_func(self, *args, **kwargs):
  """Applies the assign-style function `kwargs["f"]` to this mirrored variable.

  Dispatch depends on the current context:
    * Cross-tower context with an update device set: `f` is applied directly
      to this variable's value for that device.
    * Cross-tower context without an update device: the current strategy's
      `update` is called with `f`.
    * Tower context: the first positional argument is reduced with
      `self._aggregation` (destination `self`) inside `merge_call`, and then
      `strategy.update` applies `f` with the reduced value.

  Args:
    *args: Arguments forwarded to `f`; in tower context the first one is the
      value to be reduced.
    **kwargs: Must contain `f`, the function to apply; the rest are forwarded.

  Raises:
    ValueError: In tower context when `self._aggregation` is
      `VariableAggregation.NONE`.
  """
  f = kwargs.pop("f")
  if distribute_lib.get_cross_tower_context():
    update_device = distribute_lib.get_update_device()
    # We are calling update on the mirrored variable in cross tower context.
    if update_device is not None:
      # We are calling an assign function on the mirrored variable in cross
      # tower context.
      # NOTE(review): presumably the update device is set while a strategy
      # `update` is in progress, so only that device's component is touched.
      v = self.get(device=update_device)
      return f(v, *args, **kwargs)

    # No specific device: let the active strategy perform the update.
    return distribute_lib.get_distribution_strategy().update(
        self, f, *args, **kwargs)
  else:
    _assert_tower_context()
    # We are calling an assign function on the mirrored variable in tower
    # context.
    # We reduce the value we want to assign/add/sub. More details about how we
    # handle the different use cases can be found in the _reduce method.
    # We call the function on each of the mirrored variables with the reduced
    # value.
    if self._aggregation == vs.VariableAggregation.NONE:
      raise ValueError("You must specify an aggregation method to update a "
                       "MirroredVariable in Tower Context.")

    def merge_fn(strategy, value, *other_args, **other_kwargs):
      # Reduce the per-tower `value` onto this variable, then apply `f` with
      # the reduced value via the strategy's update.
      return strategy.update(
          self, f,
          strategy.reduce(
              aggregation=self._aggregation, value=value, destinations=self),
          *other_args, **other_kwargs)

    return distribute_lib.get_tower_context().merge_call(merge_fn, *args,
                                                         **kwargs)
def model_fn():
  """Creates four variables around a merge_call, the last two with explicit
  synchronization/aggregation settings.

  Returns:
    A tuple `(v0, v1, v2, v3)` of the variables created via `get_variable`.
  """
  # Created outside any explicit variable_scope.
  v0 = variable_scope.get_variable("var0", [1])
  with variable_scope.variable_scope("common"):
    v1 = variable_scope.get_variable("var1", [1])
    # This will pause the current thread, and execute the other thread.
    distribute_lib.get_tower_context().merge_call(lambda _: _)
    # ON_READ synchronization with SUM aggregation.
    v2 = variable_scope.get_variable(
        "var2", [1],
        synchronization=variable_scope.VariableSynchronization.ON_READ,
        aggregation=variable_scope.VariableAggregation.SUM)
    # ON_WRITE synchronization with MEAN aggregation.
    v3 = variable_scope.get_variable(
        "var3", [1],
        synchronization=variable_scope.VariableSynchronization.ON_WRITE,
        aggregation=variable_scope.VariableAggregation.MEAN)
  return v0, v1, v2, v3
def _assert_in_default_state(t):
  """Asserts that no non-default distribution strategy is active.

  Args:
    t: The test case used to run the assertions.
  """
  expectations = (
      (distribute._default_tower_context, distribute.get_tower_context),
      (None, distribute.get_cross_tower_context),
      (distribute._default_distribution_strategy,
       distribute.get_distribution_strategy),
  )
  for expected, current in expectations:
    t.assertIs(expected, current())
  t.assertFalse(distribute.has_distribution_strategy())
def _aggregate_across_towers(metrics_collections, metric_value_fn, *args):
  """Aggregate metric value across towers.

  Args:
    metrics_collections: Collections (presumably collection names) to which
      the resulting metric value is added; skipped when falsy.
    metric_value_fn: Called as `metric_value_fn(distribution, *args)` inside
      `merge_call` to build the metric value.
    *args: Extra arguments forwarded through `merge_call` to `metric_value_fn`.

  Returns:
    The metric value produced inside `merge_call`.
  """
  def fn(distribution, *a):
    """Call `metric_value_fn` in the correct control flow context."""
    if hasattr(distribution, '_outer_control_flow_context'):
      # If there was an outer context captured before this method was called,
      # then we enter that context to create the metric value op. If the
      # captured context is `None`, ops.control_dependencies(None) gives the
      # desired behavior. Else we use `Enter` and `Exit` to enter and exit the
      # captured context.
      # This special handling is needed because sometimes the metric is created
      # inside a while_loop (and perhaps a TPU rewrite context). But we don't
      # want the value op to be evaluated every step or on the TPU. So we
      # create it outside so that it can be evaluated at the end on the host,
      # once the update ops have been evaluated.
      # pylint: disable=protected-access
      outer_context = distribution._outer_control_flow_context
      # pylint: enable=protected-access
      if outer_context is None:
        with ops.control_dependencies(None):
          metric_value = metric_value_fn(distribution, *a)
      else:
        outer_context.Enter()
        try:
          metric_value = metric_value_fn(distribution, *a)
        finally:
          # Always leave the captured context, even if building the value op
          # raised, so it is never left entered.
          outer_context.Exit()
    else:
      metric_value = metric_value_fn(distribution, *a)
    if metrics_collections:
      ops.add_to_collections(metrics_collections, metric_value)
    return metric_value

  return distribute_lib.get_tower_context().merge_call(fn, *args)
def _assert_in_default_state(t):
  """Asserts that no non-default distribution strategy is active.

  Args:
    t: The test case used to run the assertions.
  """
  expectations = (
      (distribute._default_tower_context, distribute.get_tower_context),
      (None, distribute.get_cross_tower_context),
      (distribute._default_distribution_strategy,
       distribute.get_distribution_strategy),
  )
  for expected, current in expectations:
    t.assertIs(expected, current())
  t.assertFalse(distribute.has_distribution_strategy())
def model_fn():
  """Creates one SUM-aggregated tower-local variable and returns it."""
  ctx = distribute_lib.get_tower_context()
  sum_aggregation = variable_scope.VariableAggregation.SUM
  with ctx.tower_local_var_scope(sum_aggregation):
    v_sum = variable_scope.variable(1.0)
  self.assertTrue(isinstance(v_sum, values.TowerLocalVariable))
  return v_sum
def skip_summary():
  """Returns whether summaries should be skipped for the current tower.

  With multiple towers under a distribution strategy, summaries are only
  emitted from the first tower (tower_id == 0).

  Returns:
    True if there is a tower context whose `tower_id` is greater than 0, and
    False otherwise (including when there is no tower context at all).
  """
  # TODO(priyag): Add a new optional argument that will provide multiple
  # alternatives to override default behavior. (e.g. run on last tower,
  # compute sum or mean across towers).
  tower_context = distribute.get_tower_context()
  # Explicit `is not None` so the result is a real bool rather than None
  # when there is no tower context.
  return tower_context is not None and tower_context.tower_id > 0
def run_fn():
  # Inside the strategy's per-tower function: a tower context exists, there is
  # no cross-tower context, and `dist` is the active strategy.
  ctx = distribute.get_tower_context()
  self.assertIsNotNone(ctx)
  self.assertIs(None, distribute.get_cross_tower_context())
  self.assertTrue(distribute.has_distribution_strategy())
  self.assertIs(dist, distribute.get_distribution_strategy())
  # The test strategy is expected to echo back the merge_call argument and
  # the requested variable name.
  merged = ctx.merge_call(None, test_arg="foo")
  self.assertEqual("foo", merged)
  created = variable_scope.variable(1.0, name="bar")
  self.assertEqual("bar", created)
def skip_summary():
  """Returns whether summaries should be skipped for the current tower.

  With multiple towers under a distribution strategy, summaries are only
  emitted from the first tower (tower_id == 0).

  Returns:
    True if there is a tower context whose `tower_id` is greater than 0, and
    False otherwise (including when there is no tower context at all).
  """
  # TODO(priyag): Add a new optional argument that will provide multiple
  # alternatives to override default behavior. (e.g. run on last tower,
  # compute sum or mean across towers).
  tower_context = distribute.get_tower_context()
  # Explicit `is not None` so the result is a real bool rather than None
  # when there is no tower context.
  return tower_context is not None and tower_context.tower_id > 0
def run_fn():
  # Inside the strategy's per-tower function: a tower context exists, there is
  # no cross-tower context, and `dist` is the active strategy.
  ctx = distribute.get_tower_context()
  self.assertIsNotNone(ctx)
  self.assertIs(None, distribute.get_cross_tower_context())
  self.assertTrue(distribute.has_distribution_strategy())
  self.assertIs(dist, distribute.get_distribution_strategy())
  # The test strategy is expected to echo back the merge_call argument and
  # the requested variable name.
  merged = ctx.merge_call(None, test_arg="foo")
  self.assertEqual("foo", merged)
  created = variable_scope.variable(1.0, name="bar")
  self.assertEqual("bar", created)
def testScope(self):
  _assert_in_default_state(self)
  strategy = _TestStrategy()
  with strategy.scope():
    # Entering scope() puts us in cross-tower context for `strategy`.
    self.assertIs(None, distribute.get_tower_context())
    self.assertIs(strategy, distribute.get_cross_tower_context())
    self.assertTrue(distribute.has_distribution_strategy())
    self.assertIs(strategy, distribute.get_distribution_strategy())
    created = variable_scope.variable(1.0, name="baz")
    self.assertEqual("baz", created)
  # Leaving the scope must restore the default state.
  _assert_in_default_state(self)
def testScope(self):
  _assert_in_default_state(self)
  strategy = _TestStrategy()
  with strategy.scope():
    # Entering scope() puts us in cross-tower context for `strategy`.
    self.assertIs(None, distribute.get_tower_context())
    self.assertIs(strategy, distribute.get_cross_tower_context())
    self.assertTrue(distribute.has_distribution_strategy())
    self.assertIs(strategy, distribute.get_distribution_strategy())
    created = variable_scope.variable(1.0, name="baz")
    self.assertEqual("baz", created)
  # Leaving the scope must restore the default state.
  _assert_in_default_state(self)
def run_fn():
  # Inside the strategy's per-tower function: a tower context exists, there is
  # no cross-tower context, and `dist` is the active strategy.
  ctx = distribute.get_tower_context()
  self.assertIsNotNone(ctx)
  self.assertIs(None, distribute.get_cross_tower_context())
  self.assertTrue(distribute.has_distribution_strategy())
  self.assertIs(dist, distribute.get_distribution_strategy())
  self.assertEqual("foo", ctx.merge_call(None, test_arg="foo"))
  # The test strategy's creator returns a dict describing the request; a
  # plain `variable()` call should report AUTO sync and NONE aggregation.
  expected = _get_test_variable(
      "bar",
      variable_scope.VariableSynchronization.AUTO,
      variable_scope.VariableAggregation.NONE)
  created = variable_scope.variable(1.0, name="bar")
  self.assertDictEqual(expected, created)
def set_last_step_output(self, name, output,
                         aggregation=variables_lib.VariableAggregation.NONE):
  """Set `output` with `name` to be outputted from the last step.

  Args:
    name: String, name to identify the output. Doesn't need to match tensor
      name.
    output: The tensors that should be outputted with `name`. See below for
      actual types supported.
    aggregation: Aggregation method to use to aggregate outputs from multiple
      towers. Required if `set_last_step_output` is called in a tower context.
      Optional in cross_tower_context.
      When present, the outputs from all the towers are aggregated using the
      current distribution strategy's `reduce` method. Hence, the type of
      `output` must be what's supported by the corresponding `reduce` method.
      For example, if using MirroredStrategy and aggregation is set, output
      must be a `PerDevice` value.
      The aggregation method is also recorded in a dictionary
      `_last_step_outputs_aggregations` for later interpreting of the
      outputs as already reduced or not.

  Raises:
    ValueError: If called in a tower context with `aggregation` left as
      `VariableAggregation.NONE`.
  """
  if distribute_lib.get_cross_tower_context():
    self._last_step_outputs_aggregations[name] = aggregation
    if aggregation is variables_lib.VariableAggregation.NONE:
      self._last_step_outputs[name] = output
    else:
      distribution = distribute_lib.get_distribution_strategy()
      self._last_step_outputs[name] = distribution.reduce(
          aggregation, output, destinations="/device:CPU:0")
  else:
    # A real exception rather than `assert`: asserts are stripped under -O,
    # which would silently let the documented requirement go unchecked.
    if aggregation is variables_lib.VariableAggregation.NONE:
      raise ValueError(
          "`aggregation` must be specified when `set_last_step_output` is "
          "called in a tower context.")

    def merge_fn(distribution, value):
      self._last_step_outputs[name] = distribution.reduce(
          aggregation, value, destinations="/device:CPU:0")
      # Setting this inside the `merge_fn` because all towers share the same
      # context object, so it's more robust to set it only once (even if all
      # the towers are trying to set the same value).
      self._last_step_outputs_aggregations[name] = aggregation

    distribute_lib.get_tower_context().merge_call(merge_fn, output)
def run_fn():
  # Inside the strategy's per-tower function: a tower context exists, there is
  # no cross-tower context, and `dist` is the active strategy.
  ctx = distribute.get_tower_context()
  self.assertIsNotNone(ctx)
  self.assertIs(None, distribute.get_cross_tower_context())
  self.assertTrue(distribute.has_distribution_strategy())
  self.assertIs(dist, distribute.get_distribution_strategy())
  self.assertEqual("foo", ctx.merge_call(None, test_arg="foo"))
  # The test strategy's creator returns a dict describing the request; a
  # plain `variable()` call should report AUTO sync and NONE aggregation.
  expected = _get_test_variable(
      "bar",
      variable_scope.VariableSynchronization.AUTO,
      variable_scope.VariableAggregation.NONE)
  created = variable_scope.variable(1.0, name="bar")
  self.assertDictEqual(expected, created)
def model_fn(device_id):
  """Creates "sum" and "mean" tower-local variables and their update ops.

  Args:
    device_id: Index of the calling tower; used in the update values and as
      the key into the enclosing `all_v_sum` / `all_v_mean` containers.

  Returns:
    `(updates, v_sum, v_mean)`.
  """
  ctx = distribute_lib.get_tower_context()
  # NOTE(review): other call sites alongside this one pass
  # `variable_scope.VariableAggregation.SUM/MEAN` instead of strings; confirm
  # which form this version of `tower_local_var_scope` expects.
  with ctx.tower_local_var_scope("sum"):
    v_sum = variable_scope.variable(1.0)
  with ctx.tower_local_var_scope("mean"):
    v_mean = variable_scope.variable(4.0)
  for tower_local in (v_sum, v_mean):
    self.assertTrue(isinstance(tower_local, values.TowerLocalVariable))
  updates = [v_sum.assign_add(2.0 + device_id),
             v_mean.assign(6.0 * device_id)]
  all_v_sum[device_id] = v_sum
  all_v_mean[device_id] = v_mean
  return updates, v_sum, v_mean
def testScope(self):
  _assert_in_default_state(self)
  strategy = _TestStrategy()
  with strategy.scope():
    # Entering scope() puts us in cross-tower context for `strategy`.
    self.assertIs(None, distribute.get_tower_context())
    self.assertIs(strategy, distribute.get_cross_tower_context())
    self.assertTrue(distribute.has_distribution_strategy())
    self.assertIs(strategy, distribute.get_distribution_strategy())
    expected = _get_test_variable(
        "baz",
        variable_scope.VariableSynchronization.AUTO,
        variable_scope.VariableAggregation.NONE)
    created = variable_scope.variable(1.0, name="baz")
    self.assertDictEqual(expected, created)
  # Leaving the scope must restore the default state.
  _assert_in_default_state(self)
def testScope(self):
  _assert_in_default_state(self)
  strategy = _TestStrategy()
  with strategy.scope():
    # Entering scope() puts us in cross-tower context for `strategy`.
    self.assertIs(None, distribute.get_tower_context())
    self.assertIs(strategy, distribute.get_cross_tower_context())
    self.assertTrue(distribute.has_distribution_strategy())
    self.assertIs(strategy, distribute.get_distribution_strategy())
    expected = _get_test_variable(
        "baz",
        variable_scope.VariableSynchronization.AUTO,
        variable_scope.VariableAggregation.NONE)
    created = variable_scope.variable(1.0, name="baz")
    self.assertDictEqual(expected, created)
  # Leaving the scope must restore the default state.
  _assert_in_default_state(self)
def _decay_weights_op(self, var):
  """Returns an op that decays `var` by the weight decay, or a no-op.

  `var` is decayed unless `self._decay_var_list` is non-empty and does not
  contain it. In tower context the decay value goes through `merge_call`: it
  is MEAN-reduced (destination `var`) and then applied with `strategy.update`.
  """
  if self._decay_var_list and var not in self._decay_var_list:
    return control_flow_ops.no_op()

  def apply_decay_fn(v, decay):
    # v <- v - decay * v
    return state_ops.assign_sub(v, decay * v, self._use_locking)

  def merge_fn(strategy, v, value):
    value = strategy.reduce(variable_scope.VariableAggregation.MEAN, value, v)
    return strategy.update(v, apply_decay_fn, value)

  tower_ctx = distribute_lib.get_tower_context()
  return tower_ctx.merge_call(merge_fn, var, self._weight_decay)
def testMergeCall(self):
  _assert_in_default_state(self)

  def merge_fn(dist, s):
    # Inside merge_call on the default strategy.
    self.assertIs(distribute._default_distribution_strategy, dist)
    checks = (
        (None, distribute.get_tower_context),
        (dist, distribute.get_cross_tower_context),
        (dist, distribute.get_distribution_strategy),
    )
    for expected, current in checks:
      self.assertIs(expected, current())
    self.assertFalse(distribute.has_distribution_strategy())
    return "foo_" + s

  default_ctx = distribute.get_tower_context()
  self.assertIs(distribute._default_tower_context, default_ctx)
  merged = default_ctx.merge_call(merge_fn, "bar")
  self.assertEqual("foo_bar", merged)
  _assert_in_default_state(self)
def testMergeCall(self):
  _assert_in_default_state(self)

  def merge_fn(dist, s):
    # Inside merge_call on the default strategy.
    self.assertIs(distribute._default_distribution_strategy, dist)
    checks = (
        (None, distribute.get_tower_context),
        (dist, distribute.get_cross_tower_context),
        (dist, distribute.get_distribution_strategy),
    )
    for expected, current in checks:
      self.assertIs(expected, current())
    self.assertFalse(distribute.has_distribution_strategy())
    return "foo_" + s

  default_ctx = distribute.get_tower_context()
  self.assertIs(distribute._default_tower_context, default_ctx)
  merged = default_ctx.merge_call(merge_fn, "bar")
  self.assertEqual("foo_bar", merged)
  _assert_in_default_state(self)
def get(self, device=None):
  """Returns the stored value for `device`, resolving it if not given.

  When `device` is None it is taken from the tower context's device. Outside
  a tower context it is taken from `distribute_lib.get_update_device()`,
  falling back to `device_util.current()`. The device string is canonicalized
  before lookup.

  Raises:
    ValueError: If no value is stored for the resolved device.
  """
  if device is None:
    tower_context = distribute_lib.get_tower_context()
    if tower_context:
      device = tower_context.device
    else:
      update_device = distribute_lib.get_update_device()
      device = (device_util.current() if update_device is None
                else update_device)
  canonical = device_util.canonicalize(device)
  try:
    found = self._index[canonical]
  except KeyError:
    raise ValueError("Device %s not found in %s (current device %s)" %
                     (canonical, self._index.keys(), device_util.current()))
  return found
def decorated(*args):
  """Calls `result_fn`, routing through `merge_call` in a tower context."""
  tower_context = distribute_lib.get_tower_context()
  if tower_context is not None:
    # merge_call passes the distribution as the first argument. The function
    # handed to it arrives wrapped per device; all copies are the same
    # `result_fn`, so the first one is called. The wrapper keeps `result_fn`
    # free of a distribution parameter.
    def merge_fn_wrapper(distribution, merge_fn, *merge_args):
      return distribution.unwrap(merge_fn)[0](*merge_args)

    # merge_call leaves tower mode and computes the result in cross tower mode.
    return tower_context.merge_call(merge_fn_wrapper, result_fn, *args)

  # Already in cross tower context.
  # NOTE(review): this path calls result_fn() without `*args`, while the
  # merge_call path above forwards them; confirm that is intended.
  return result_fn()
def decorated(*args):
  """Calls `result_fn`, routing through `merge_call` in a tower context."""
  tower_context = distribute_lib.get_tower_context()
  if tower_context is not None:
    # merge_call passes the distribution as the first argument. The function
    # handed to it arrives wrapped per device; all copies are the same
    # `result_fn`, so the first one is called. The wrapper keeps `result_fn`
    # free of a distribution parameter.
    def merge_fn_wrapper(distribution, merge_fn, *merge_args):
      return distribution.unwrap(merge_fn)[0](*merge_args)

    # merge_call leaves tower mode and computes the result in cross tower mode.
    return tower_context.merge_call(merge_fn_wrapper, result_fn, *args)

  # Already in cross tower context.
  # NOTE(review): this path calls result_fn() without `*args`, while the
  # merge_call path above forwards them; confirm that is intended.
  return result_fn()
def model_fn(device_id):
  """Creates SUM/MEAN tower-local variables and records them and their
  `get()` components in the enclosing test's containers.

  Args:
    device_id: Index of the calling tower; used in the update values and as
      the key into `all_v_sum`, `all_v_mean`, `components_sum` and
      `components_mean` from the enclosing scope.

  Returns:
    `(updates, v_sum, v_mean, c_sum, c_mean)`, where `updates` holds the two
    assign ops and `c_*` are the values returned by `v_*.get()`.
  """
  tower_context = distribute_lib.get_tower_context()
  with tower_context.tower_local_var_scope(
      variable_scope.VariableAggregation.SUM):
    v_sum = variable_scope.variable(1.0)
  with tower_context.tower_local_var_scope(
      variable_scope.VariableAggregation.MEAN):
    v_mean = variable_scope.variable(4.0)
  self.assertTrue(isinstance(v_sum, values.TowerLocalVariable))
  self.assertTrue(isinstance(v_mean, values.TowerLocalVariable))
  # Tower-specific updates so the aggregated result depends on every tower.
  updates = [v_sum.assign_add(2.0 + device_id),
             v_mean.assign(6.0 * device_id)]
  all_v_sum[device_id] = v_sum
  all_v_mean[device_id] = v_mean
  # Presumably this tower's component of each tower-local variable; it must
  # be a different object from the wrapper itself.
  c_sum = v_sum.get()
  c_mean = v_mean.get()
  components_sum[device_id] = c_sum
  components_mean[device_id] = c_mean
  self.assertIsNot(v_sum, c_sum)
  self.assertIsNot(v_mean, c_mean)
  return updates, v_sum, v_mean, c_sum, c_mean
def model_fn(name):
  """Creates a variable named `name`, then hits a no-op merge_call."""
  created = variable_scope.variable(1.0, name=name)
  ctx = distribute_lib.get_tower_context()
  # Identity merge_call: result ignored; only the context switch matters.
  ctx.merge_call(lambda _: _)
  return created
def model_fn(device_id):
  """Creates a per-tower variable named "foo_<device_id>"."""
  per_tower_name = "foo_" + str(device_id)
  created = variable_scope.variable(1.0, name=per_tower_name)
  ctx = distribute_lib.get_tower_context()
  # Identity merge_call: result ignored; only the context switch matters.
  ctx.merge_call(lambda _: _)
  return created
def model_fn():
  """Creates variable "b", then calls `in_cross_tower` via merge_call under
  name scope "foo".

  Returns:
    `(b, result_of_merge_call)`.
  """
  b = variable_scope.get_variable("b", [1])
  with ops.name_scope("foo"):
    return b, distribute_lib.get_tower_context().merge_call(in_cross_tower)
def model_fn():
  """Creates a variable named "foo" and then a no-op merge_call.

  The enclosing test relies on the variable creators installed by
  `dist.call_for_each_tower` so that this variable is created only once
  across the tower threads.
  """
  foo = variable_scope.variable(1.0, name="foo")
  ctx = distribute_lib.get_tower_context()
  ctx.merge_call(lambda _: _)
  return foo
def model_fn():
  """Subtracts this tower's id (cast to the variable's dtype) from
  `mirrored_var` and returns the resulting assign op."""
  tower_id = distribute_lib.get_tower_context().tower_id
  delta = math_ops.cast(tower_id, mirrored_var.dtype)
  return mirrored_var.assign_sub(delta)
def model_fn():
  """Creates one "sum" tower-local variable and returns it."""
  ctx = distribute_lib.get_tower_context()
  # NOTE(review): other call sites alongside this one pass
  # `variable_scope.VariableAggregation.SUM` instead of the string "sum";
  # confirm which form this version of `tower_local_var_scope` expects.
  with ctx.tower_local_var_scope("sum"):
    v_sum = variable_scope.variable(1.0)
  self.assertTrue(isinstance(v_sum, values.TowerLocalVariable))
  return v_sum
def model_fn():
  """Creates variable "b", then calls `in_cross_tower` via merge_call under
  name scope "foo".

  Returns:
    `(b, result_of_merge_call)`.
  """
  b = variable_scope.get_variable("b", [1])
  with ops.name_scope("foo"):
    return b, distribute_lib.get_tower_context().merge_call(in_cross_tower)
def mark_devices_fn():
  """Marks the current tower id in `expected_devices`, checking that it is
  in range and has not already been marked."""
  tid = distribute_lib.get_tower_context().tower_id
  self.assertLess(tid, len(d.worker_devices))
  already_marked = expected_devices[tid]
  self.assertFalse(already_marked)
  expected_devices[tid] = True
def model_fn():
  """Creates two constants inside a "foo" name scope, with a merge_call
  between them.

  Returns:
    The constants `(a, b)`.
  """
  # No explicit name is given; "foo" is passed as what is presumably the
  # default scope name.
  with ops.name_scope(None, "foo"):
    a = constant_op.constant(1.0, name="a")
    # Identity merge_call between the two constants; result ignored.
    distribute_lib.get_tower_context().merge_call(lambda _: _)
    b = constant_op.constant(2.0, name="b")
  return a, b
def init_from_checkpoint(ckpt_dir_or_file, assignment_map):
  """Initializes current variables with tensors loaded from given checkpoint.

  Note: This overrides default initialization ops of specified variables and
  redefines dtype.

  Assignment map supports following syntax:

  * `'checkpoint_scope_name/': 'scope_name/'` - will load all variables in
    current `scope_name` from `checkpoint_scope_name` with matching tensor
    names.
  * `'checkpoint_scope_name/some_other_variable': 'scope_name/variable_name'` -
    will initialize `scope_name/variable_name` variable
    from `checkpoint_scope_name/some_other_variable`.
  * `'scope_variable_name': variable` - will initialize given `tf.Variable`
    object with tensor 'scope_variable_name' from the checkpoint.
  * `'scope_variable_name': list(variable)` - will initialize list of
    partitioned variables with tensor 'scope_variable_name' from the checkpoint.
  * `'/': 'scope_name/'` - will load all variables in current `scope_name` from
    checkpoint's root (i.e. no scope).

  Supports loading into partitioned variables, which are represented as
  `'<variable>/part_<part #>'`.

  Example:

  ```python

  # Say, '/tmp/model.ckpt' has the following tensors:
  #  -- name='old_scope_1/var1', shape=[20, 2]
  #  -- name='old_scope_1/var2', shape=[50, 4]
  #  -- name='old_scope_2/var3', shape=[100, 100]

  # Create new model's variables
  with tf.variable_scope('new_scope_1'):
    var1 = tf.get_variable('var1', shape=[20, 2],
                           initializer=tf.zeros_initializer())
  with tf.variable_scope('new_scope_2'):
    var2 = tf.get_variable('var2', shape=[50, 4],
                           initializer=tf.zeros_initializer())
    # Partition into 5 variables along the first axis.
    var3 = tf.get_variable(name='var3', shape=[100, 100],
                           initializer=tf.zeros_initializer(),
                           partitioner=lambda shape, dtype: [5, 1])

  # Initialize all variables in `new_scope_1` from `old_scope_1`.
  init_from_checkpoint('/tmp/model.ckpt', {'old_scope_1/': 'new_scope_1/'})

  # Use names to specify which variables to initialize from checkpoint.
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_1/var1': 'new_scope_1/var1',
                        'old_scope_1/var2': 'new_scope_2/var2'})

  # Or use tf.Variable objects to identify what to initialize.
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_1/var1': var1,
                        'old_scope_1/var2': var2})

  # Initialize partitioned variables using variable's name
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_2/var3': 'new_scope_2/var3'})

  # Or specify the list of tf.Variable objects.
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_2/var3': var3._get_variable_list()})

  ```

  Args:
    ckpt_dir_or_file: Directory with checkpoints file or path to checkpoint.
    assignment_map: Dict, where keys are names of the variables in the
      checkpoint and values are current variables or names of current variables
      (in default graph).

  Raises:
    tf.errors.OpError: If missing checkpoints or tensors in checkpoints.
    ValueError: If missing variables in current graph.
  """
  if distribute_lib.get_cross_tower_context():
    # Already in cross-tower context: call directly, passing None in the
    # distribution-strategy slot.
    _init_from_checkpoint(None, ckpt_dir_or_file, assignment_map)
  else:
    # In a tower context, switch to cross-tower context via merge_call, which
    # supplies the distribution strategy as the first argument.
    distribute_lib.get_tower_context().merge_call(
        _init_from_checkpoint, ckpt_dir_or_file, assignment_map)
def model_fn():
  """Creates variables "foo0".."foo4", then a single no-op merge_call.

  Returns:
    The list of created variables, in creation order.
  """
  vs = [variable_scope.variable(1.0, name="foo" + str(i)) for i in range(5)]
  ctx = distribute_lib.get_tower_context()
  ctx.merge_call(lambda _: _)
  return vs
def _merge_call_merge_raises_fn():
  """Calls `_call_merge_raises_fn` via merge_call from the tower context."""
  tower_context = distribute_lib.get_tower_context()
  tower_context.merge_call(_call_merge_raises_fn)
def _add_tower_local_variable(self, *args, **kwargs):
  """Returns `self.add_weight(*args, **kwargs)` created inside a 'mean'
  tower-local variable scope."""
  # NOTE(review): other call sites alongside this one pass
  # `VariableAggregation.MEAN` rather than the string 'mean'; confirm which
  # form this version of `tower_local_var_scope` expects.
  mean_scope = distribute_lib.get_tower_context().tower_local_var_scope('mean')
  with mean_scope:
    return self.add_weight(*args, **kwargs)