def model_fn():
  with ops.name_scope(None, "foo"):
    a = constant_op.constant(1.0, name="a")
    distribution_strategy_context.get_tower_context().merge_call(lambda _: _)
    b = constant_op.constant(2.0, name="b")
  return a, b

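Most of the `model_fn` snippets in this listing are test bodies that a distribution strategy runs once per tower; the no-op `merge_call(lambda _: _)` acts as a synchronization barrier between the tower threads. A minimal driver sketch, assuming the contrib-era TF 1.x entry points `tf.contrib.distribute.MirroredStrategy` and `call_for_each_tower`:

import tensorflow as tf

strategy = tf.contrib.distribute.MirroredStrategy(num_gpus=2)
with strategy.scope():
  # Each tower runs model_fn on its own thread; merge_call pauses a tower
  # until all towers reach the same point, then runs the merge function once
  # in cross-tower context.
  a, b = strategy.call_for_each_tower(model_fn)
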
def _assign_func(self, *args, **kwargs):
  f = kwargs.pop("f")
  if distribution_strategy_context.get_cross_tower_context():
    update_device = distribute_lib.get_update_device()
    if update_device is not None:
      # We are calling an assign function in an update context.
      return f(self._v, *args, **kwargs)

    # We are calling an assign function in cross tower context, wrap it in an
    # update call.
    return distribution_strategy_context.get_distribution_strategy().update(
        self, f, *args, **kwargs)
  else:
    assert distribution_strategy_context.get_tower_context()
    # We are calling an assign function in tower context.
    # We reduce the value we want to assign/add/sub. More details about how we
    # handle the different use cases can be found in the _reduce method.
    # We call the function with the reduced value.
    if self._aggregation == vs.VariableAggregation.NONE:
      raise ValueError("You must specify an aggregation method to update "
                       "a variable in Tower Context.")

    def merge_fn(strategy, value, *other_args, **other_kwargs):
      return strategy.update(
          self, f,
          strategy.reduce(
              aggregation=self._aggregation, value=value, destinations=self),
          *other_args, **other_kwargs)

    return distribution_strategy_context.get_tower_context().merge_call(
        merge_fn, *args, **kwargs)

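The tower-context branch above only fires for variables created with an aggregation method. A hedged sketch of triggering it, again assuming the contrib-era `MirroredStrategy` / `call_for_each_tower` names:

strategy = tf.contrib.distribute.MirroredStrategy(num_gpus=2)
with strategy.scope():
  acc = tf.get_variable(
      "acc", shape=[], initializer=tf.zeros_initializer(),
      aggregation=tf.VariableAggregation.MEAN)

  def tower_fn():
    # Without aggregation=... on the variable, this assign raises the
    # ValueError from _assign_func; with it, merge_call MEAN-reduces the
    # per-tower values and applies a single update.
    return acc.assign_add(1.0)

  update_op = strategy.call_for_each_tower(tower_fn)
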
def model_fn():
  vs = []
  vs.append(variable_scope.variable(1.0, name="foo/bar"))
  vs.append(variable_scope.variable(1.0, name="foo_1/bar"))
  vs.append(variable_scope.variable(1.0, name="foo_1/bar_1"))
  vs.append(variable_scope.variable(1.0, name="foo/bar_1"))
  distribution_strategy_context.get_tower_context().merge_call(lambda _: _)
  return vs

def set_non_tensor_output(self, name, output):
  """Set `output` with `name` to be captured as a non tensor output."""
  if distribution_strategy_context.get_cross_tower_context():
    self._non_tensor_outputs[name] = output
  else:
    def merge_fn(distribution, value):
      # NOTE(priyag): For non tensor outputs, we simply return all the values
      # in a list as aggregation doesn't make sense on non tensors.
      self._non_tensor_outputs[name] = distribution.unwrap(value)

    distribution_strategy_context.get_tower_context().merge_call(
        merge_fn, output)

def model_fn(device_id):
  assert isinstance(device_id, int)

  def thread_creator_fn(next_creator, *args, **kwargs):
    return next_creator(*args, **kwargs) + ":thread_" + str(device_id)

  with variable_scope.variable_creator_scope(thread_creator_fn):
    # Create a variable in this scope.
    v = variable_scope.variable(1.0)

    # This will pause the current thread, and execute the other thread.
    distribution_strategy_context.get_tower_context().merge_call(lambda _: _)
  return v

def model_fn(features):
  with variable_scope.variable_scope("common"):
    layer1 = core.Dense(1)
    layer1(features)
    layer2 = core.Dense(1)
    layer2(features)
    # This will pause the current thread, and execute the other thread.
    distribution_strategy_context.get_tower_context().merge_call(lambda _: _)
    layer3 = core.Dense(1)
    layer3(features)

  return [(layer1.kernel, layer1.bias),
          (layer2.kernel, layer2.bias),
          (layer3.kernel, layer3.bias)]

def _aggregate_across_towers(metrics_collections, metric_value_fn, *args):
  """Aggregate metric value across towers."""
  def fn(distribution, *a):
    """Call `metric_value_fn` in the correct control flow context."""
    if hasattr(distribution, '_outer_control_flow_context'):
      # If there was an outer context captured before this method was called,
      # then we enter that context to create the metric value op. If the
      # captured context is `None`, ops.control_dependencies(None) gives the
      # desired behavior. Else we use `Enter` and `Exit` to enter and exit the
      # captured context.
      # This special handling is needed because sometimes the metric is created
      # inside a while_loop (and perhaps a TPU rewrite context). But we don't
      # want the value op to be evaluated every step or on the TPU. So we
      # create it outside so that it can be evaluated at the end on the host,
      # once the update ops have been evaluated.
      # pylint: disable=protected-access
      if distribution._outer_control_flow_context is None:
        with ops.control_dependencies(None):
          metric_value = metric_value_fn(distribution, *a)
      else:
        distribution._outer_control_flow_context.Enter()
        metric_value = metric_value_fn(distribution, *a)
        distribution._outer_control_flow_context.Exit()
      # pylint: enable=protected-access
    else:
      metric_value = metric_value_fn(distribution, *a)
    if metrics_collections:
      ops.add_to_collections(metrics_collections, metric_value)
    return metric_value

  return distribution_strategy_context.get_tower_context().merge_call(
      fn, *args)

def _assert_in_default_state(t):
  t.assertIs(distribution_strategy_context._get_default_tower_context(),
             distribution_strategy_context.get_tower_context())
  t.assertIs(None, distribution_strategy_context.get_cross_tower_context())
  t.assertIs(
      distribution_strategy_context._get_default_distribution_strategy(),
      distribution_strategy_context.get_distribution_strategy())
  t.assertFalse(distribution_strategy_context.has_distribution_strategy())

def set_last_step_output(self, name, output,
                         aggregation=variables_lib.VariableAggregation.NONE):
  """Set `output` with `name` to be outputted from the last step.

  Args:
    name: String, name to identify the output. Doesn't need to match tensor
      name.
    output: The tensors that should be outputted with `name`. See below for
      actual types supported.
    aggregation: Aggregation method to use to aggregate outputs from multiple
      towers. Required if `set_last_step_output` is called in a tower context.
      Optional in cross_tower_context.
      When present, the outputs from all the towers are aggregated using the
      current distribution strategy's `reduce` method. Hence, the type of
      `output` must be what's supported by the corresponding `reduce` method.
      For e.g. if using MirroredStrategy and aggregation is set, output must
      be a `PerDevice` value.
      The aggregation method is also recorded in a dictionary
      `_last_step_outputs_aggregations` for later interpreting of the outputs
      as already reduced or not.
  """
  if distribution_strategy_context.get_cross_tower_context():
    self._last_step_outputs_aggregations[name] = aggregation
    if aggregation is variables_lib.VariableAggregation.NONE:
      self._last_step_outputs[name] = output
    else:
      distribution = distribution_strategy_context.get_distribution_strategy()
      self._last_step_outputs[name] = distribution.reduce(
          aggregation, output, destinations="/device:CPU:0")
  else:
    assert aggregation is not variables_lib.VariableAggregation.NONE

    def merge_fn(distribution, value):
      self._last_step_outputs[name] = distribution.reduce(
          aggregation, value, destinations="/device:CPU:0")
      # Setting this inside the `merge_fn` because all towers share the same
      # context object, so it's more robust to set it only once (even if all
      # the towers are trying to set the same value).
      self._last_step_outputs_aggregations[name] = aggregation

    distribution_strategy_context.get_tower_context().merge_call(
        merge_fn, output)

def skip_summary():
  # If using multiple towers in distributed strategy, skip summaries on all
  # towers except the first one (tower_id=0).
  # TODO(priyag): Add a new optional argument that will provide multiple
  # alternatives to override default behavior. (e.g. run on last tower,
  # compute sum or mean across towers).
  tower_context = distribution_strategy_context.get_tower_context()
  return tower_context and tower_context.tower_id > 0

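A small usage sketch for skip_summary(): wrap summary creation so that only tower 0 emits summary ops. `tf.no_op` and `tf.summary.scalar` are standard TF 1.x calls; the wrapper name is illustrative:

def maybe_scalar_summary(name, tensor):
  if skip_summary():
    # Towers other than 0 contribute nothing instead of duplicate summaries.
    return tf.no_op()
  return tf.summary.scalar(name, tensor)
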
def model_fn():
  v0 = variable_scope.get_variable("var0", [1])
  with variable_scope.variable_scope("common"):
    v1 = variable_scope.get_variable("var1", [1])
    # This will pause the current thread, and execute the other thread.
    distribution_strategy_context.get_tower_context().merge_call(lambda _: _)
    v2 = variable_scope.get_variable(
        "var2", [1],
        synchronization=variable_scope.VariableSynchronization.ON_READ,
        aggregation=variable_scope.VariableAggregation.SUM)
    v3 = variable_scope.get_variable(
        "var3", [1],
        synchronization=variable_scope.VariableSynchronization.ON_WRITE,
        aggregation=variable_scope.VariableAggregation.MEAN)

  return v0, v1, v2, v3

def vq_discrete_bottleneck(x, hparams):
  """Simple vector quantized discrete bottleneck."""
  bottleneck_size = 2**hparams.bottleneck_bits
  x_shape = commons.shape_list(x)
  x = tf.reshape(x, [-1, hparams.hidden_size])
  x_means_hot, e_loss = vq_nearest_neighbor(x, hparams)
  if hparams.bottleneck_kind == "mog":
    loss = hparams.beta * e_loss
  else:
    tf.logging.info("Using EMA with beta = {}".format(hparams.beta))
    means, ema_means, ema_count = (hparams.means, hparams.ema_means,
                                   hparams.ema_count)
    # Update the ema variables.
    updated_ema_count = commons.assign_moving_average(
        ema_count,
        tf.reduce_sum(x_means_hot, axis=0),
        hparams.decay,
        zero_debias=False)
    dw = tf.matmul(x_means_hot, x, transpose_a=True)
    updated_ema_means = commons.assign_moving_average(
        ema_means, dw, hparams.decay, zero_debias=False)
    n = tf.reduce_sum(updated_ema_count, axis=-1, keepdims=True)
    updated_ema_count = ((updated_ema_count + hparams.epsilon) /
                         (n + bottleneck_size * hparams.epsilon) * n)
    # pylint: disable=g-no-augmented-assignment
    updated_ema_means = updated_ema_means / tf.expand_dims(
        updated_ema_count, axis=-1)
    # pylint: enable=g-no-augmented-assignment
    with tf.control_dependencies([e_loss]):
      # distribution_strategy
      def update_fn(v, value):
        return tf.assign(v, value)

      tower_context = distribution_strategy_context.get_tower_context()
      if tower_context:
        def merge_fn(strategy, v, value):
          value = strategy.reduce(tf.VariableAggregation.MEAN, value, v)
          return strategy.update(v, update_fn, value)

        update_means = tower_context.merge_call(merge_fn, means,
                                                updated_ema_means)
      else:
        strategy = distribution_strategy_context.get_cross_tower_context()
        update_means = strategy.update(means, update_fn, updated_ema_means)
      with tf.control_dependencies([update_means]):
        loss = hparams.beta * e_loss
  discrete = tf.reshape(x_means_hot, x_shape[:-1] + [bottleneck_size])
  return discrete, loss

def merge_grads(grads_and_vars):
  """Merge gradients from different replicas."""
  def merge_grad_fn(strategy, grads_and_vars):
    reduced_grads = strategy.batch_reduce(
        variable_scope.VariableAggregation.MEAN, grads_and_vars)
    return reduced_grads

  return distribution_strategy_context.get_tower_context().merge_call(
      merge_grad_fn, grads_and_vars)

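A hedged sketch of the calling pattern for merge_grads, assuming the contrib-era behavior that `batch_reduce` returns one reduced gradient per (gradient, variable) pair, in order; the model and optimizer here are illustrative:

def tower_step(features, labels):
  predictions = tf.layers.dense(features, 1)
  loss = tf.losses.mean_squared_error(labels, predictions)
  optimizer = tf.train.GradientDescentOptimizer(0.1)
  grads_and_vars = optimizer.compute_gradients(loss)
  # Hop to cross-tower context once to MEAN-reduce every gradient, then pair
  # the reduced gradients back up with their variables.
  reduced_grads = merge_grads(grads_and_vars)
  variables = [v for _, v in grads_and_vars]
  return list(zip(reduced_grads, variables))
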
def increment_var(v, amount=1):
  """`v += amount`, distributed-aware version."""
  def update(vu):
    return vu.assign_add(amount, read_value=False)

  def merge_fn(dist, vm):
    return dist.update(vm, update)

  tower_context = distribution_strategy_context.get_tower_context()
  return tower_context.merge_call(merge_fn, v)

def merge_update_step(update_ops, local_step):
  """Merge local step counter update from different replicas."""
  def merge_update_step_fn(strategy, update_ops, local_step):
    merged_ops = []
    for update_op in update_ops:
      merged_ops.append(strategy.group(update_op))
    with ops.control_dependencies(merged_ops):
      incre_op = local_step.assign_add(1).op
    return incre_op

  return distribution_strategy_context.get_tower_context().merge_call(
      merge_update_step_fn, update_ops, local_step)

def merge_fn(dist, s):
  self.assertIs(
      distribution_strategy_context._get_default_distribution_strategy(),
      dist)
  self.assertIs(None, distribution_strategy_context.get_tower_context())
  self.assertIs(dist,
                distribution_strategy_context.get_cross_tower_context())
  self.assertIs(dist,
                distribution_strategy_context.get_distribution_strategy())
  self.assertFalse(
      distribution_strategy_context.has_distribution_strategy())
  return "foo_" + s

def increment_var(v, amount=1):
  """`v += amount`, distributed-aware version."""
  def update(vu):
    if isinstance(vu, resource_variable_ops.ResourceVariable):
      return vu.assign_add(amount, read_value=False)
    else:
      return state_ops.assign_add(vu, amount)

  def merge_fn(dist, vm):
    return dist.group(dist.update(vm, update))

  tower_context = distribution_strategy_context.get_tower_context()
  return tower_context.merge_call(merge_fn, v)

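A hedged usage sketch for increment_var: bump a shared counter from tower code. The strategy driver names are the same contrib-era TF 1.x assumptions as above:

with strategy.scope():
  step = tf.get_variable(
      "step", shape=[], dtype=tf.int32,
      initializer=tf.zeros_initializer(), trainable=False)

def tower_fn():
  # merge_call hops to cross-tower context, so the counter receives one
  # logical increment per step rather than one per tower.
  return increment_var(step)

increment_op = strategy.call_for_each_tower(tower_fn)
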
def _assign_func(self, *args, **kwargs):
  f = kwargs.pop("f")
  if distribution_strategy_context.get_cross_tower_context():
    update_device = distribute_lib.get_update_device()
    if update_device is not None:
      # We are calling an assign function on the mirrored variable in an
      # update context.
      v = self.get(device=update_device)
      return f(v, *args, **kwargs)

    # We are calling assign on the mirrored variable in cross tower context,
    # use update to update the variable.
    strategy = distribution_strategy_context.get_distribution_strategy()
    updates = strategy.update(self, f, *args, **kwargs)
    grouped = strategy.group(updates)
    if isinstance(updates, DistributedValues) and updates.is_tensor_like:
      # Make sure we run all updates. Without this, something like
      # session.run(mirrored_var.assign*(...)) may only update one tower.
      index = {}
      for d in updates.devices:
        with ops.device(d), ops.control_dependencies([grouped]):
          index[d] = array_ops.identity(updates.get(d))
      return Mirrored(index)
    else:
      return grouped
  else:
    _assert_tower_context()
    # We are calling an assign function on the mirrored variable in tower
    # context.
    # We reduce the value we want to assign/add/sub. More details about how we
    # handle the different use cases can be found in the _reduce method.
    # We call the function on each of the mirrored variables with the reduced
    # value.
    if self._aggregation == vs.VariableAggregation.NONE:
      raise ValueError("You must specify an aggregation method to update a "
                       "MirroredVariable in Tower Context.")

    def merge_fn(strategy, value, *other_args, **other_kwargs):
      return strategy.update(
          self, f,
          strategy.reduce(
              aggregation=self._aggregation, value=value, destinations=self),
          *other_args, **other_kwargs)

    return distribution_strategy_context.get_tower_context().merge_call(
        merge_fn, *args, **kwargs)

def run_fn():
  tower_context = distribution_strategy_context.get_tower_context()
  self.assertTrue(tower_context is not None)
  self.assertIs(None,
                distribution_strategy_context.get_cross_tower_context())
  self.assertTrue(distribution_strategy_context.has_distribution_strategy())
  self.assertIs(dist,
                distribution_strategy_context.get_distribution_strategy())
  self.assertEqual("foo", tower_context.merge_call(None, test_arg="foo"))
  expected_value = _get_test_variable(
      "bar", variable_scope.VariableSynchronization.AUTO,
      variable_scope.VariableAggregation.NONE)
  self.assertDictEqual(expected_value,
                       variable_scope.variable(1.0, name="bar"))

def testScope(self):
  _assert_in_default_state(self)
  dist = _TestStrategy()
  with dist.scope():
    self.assertIs(None, distribution_strategy_context.get_tower_context())
    self.assertIs(dist,
                  distribution_strategy_context.get_cross_tower_context())
    self.assertTrue(distribution_strategy_context.has_distribution_strategy())
    self.assertIs(dist,
                  distribution_strategy_context.get_distribution_strategy())
    expected_value = _get_test_variable(
        "baz", variable_scope.VariableSynchronization.AUTO,
        variable_scope.VariableAggregation.NONE)
    self.assertDictEqual(expected_value,
                         variable_scope.variable(1.0, name="baz"))
  _assert_in_default_state(self)

def get(self, device=None):
  """Returns the value for the current device or raises a ValueError."""
  if device is None:
    tower_context = distribution_strategy_context.get_tower_context()
    if tower_context:
      device = tower_context.device
    else:
      device = distribute_lib.get_update_device()
      if device is None:
        return self._get_cross_tower()
  device = device_util.canonicalize(device)
  try:
    return self._index[device]
  except KeyError as e:
    six.raise_from(
        ValueError("Device %s not found in %s (current device %s)" %
                   (device, self._index.keys(), device_util.current())), e)

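The lookup in `get` is just a device-keyed dictionary probe. A toy, fully standalone illustration of that mapping (not the real DistributedValues class):

class ToyDistributedValues(object):
  """Mimics the `_index` mapping used by `get` above."""

  def __init__(self, index):
    self._index = dict(index)  # device name -> per-device value

  def get(self, device):
    return self._index[device]

values = ToyDistributedValues({
    "/job:localhost/replica:0/task:0/device:GPU:0": 1.0,
    "/job:localhost/replica:0/task:0/device:GPU:1": 2.0,
})
assert values.get("/job:localhost/replica:0/task:0/device:GPU:1") == 2.0
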
def testMergeCall(self):
  _assert_in_default_state(self)

  def merge_fn(dist, s):
    self.assertIs(
        distribution_strategy_context._get_default_distribution_strategy(),
        dist)
    self.assertIs(None, distribution_strategy_context.get_tower_context())
    self.assertIs(dist,
                  distribution_strategy_context.get_cross_tower_context())
    self.assertIs(dist,
                  distribution_strategy_context.get_distribution_strategy())
    self.assertFalse(
        distribution_strategy_context.has_distribution_strategy())
    return "foo_" + s

  tower_ctx = distribution_strategy_context.get_tower_context()
  self.assertIs(distribution_strategy_context._get_default_tower_context(),
                tower_ctx)
  self.assertEqual("foo_bar", tower_ctx.merge_call(merge_fn, "bar"))
  _assert_in_default_state(self)

def decorated(metric_obj, *args):
  """Decorated function with merge_call."""
  tower_context = distribution_strategy_context.get_tower_context()
  if tower_context is None:  # if in cross tower context already
    result_t = result_fn(*args)
  else:
    # TODO(psv): Test distribution of metrics using different distribution
    # strategies.

    # Creating a wrapper for merge_fn. merge_call invokes the given merge_fn
    # with distribution object as the first parameter. We create a wrapper
    # here so that the result function need not have that parameter.
    def merge_fn_wrapper(distribution, merge_fn, *args):
      # We will get `PerDevice` merge function. Taking the first one as all
      # are identical copies of the function that we had passed below.
      return distribution.unwrap(merge_fn)[0](*args)

    # Wrapping result in merge_call. merge_call is used when we want to leave
    # tower mode and compute a value in cross tower mode.
    result_t = tower_context.merge_call(merge_fn_wrapper, result_fn, *args)
  check_is_tensor_or_operation(result_t,
                               'Metric {0}\'s result'.format(metric_obj.name))
  return result_t

def model_fn():
  if num_gpus == 0:
    last_part_device = 'device:CPU:0'
  else:
    last_part_device = (
        'device:GPU:%d' %
        distribution_strategy_context.get_tower_context().tower_id)

  a = constant_op.constant(1.0)
  b = constant_op.constant(2.0)
  c = a + b
  self.assertEqual(a.device, worker_device + '/' + last_part_device)
  self.assertEqual(b.device, worker_device + '/' + last_part_device)
  self.assertEqual(c.device, worker_device + '/' + last_part_device)

  # The device scope is ignored for variables but not for normal ops.
  with ops.device('/job:worker/task:0'):
    x = variable_scope.get_variable(
        'x', initializer=10.0,
        aggregation=variable_scope.VariableAggregation.SUM)
    x_add = x.assign_add(c)
    e = a + c
    # The variable x is on the task 1 since the device_function has been
    # called once before the model_fn.
    self.assertEqual(x.device, '/job:ps/task:1')
    self.assertEqual(x_add.device, x.device)
    self.assertEqual(e.device,
                     '/job:worker/replica:0/task:0/%s' % last_part_device)

  # The colocate_vars_with can override the distribution's device.
  with d.colocate_vars_with(x):
    y = variable_scope.get_variable(
        'y', initializer=20.0,
        aggregation=variable_scope.VariableAggregation.SUM)
  # We add an identity here to avoid complaints about summing
  # non-distributed values.
  y_add = y.assign_add(array_ops.identity(x_add))
  self.assertEqual(y.device, '/job:ps/task:1')
  self.assertEqual(y_add.device, y.device)
  self.assertEqual(y.device, x.device)

  z = variable_scope.get_variable(
      'z', initializer=10.0,
      aggregation=variable_scope.VariableAggregation.SUM)
  self.assertEqual(z.device, '/job:ps/task:0')
  self.assertNotEqual(z.device, x.device)

  with ops.control_dependencies([y_add]):
    # We add an identity here to avoid complaints about summing
    # non-distributed values.
    z_add = z.assign_add(array_ops.identity(y))
  with ops.control_dependencies([z_add]):
    f = z + c
  self.assertEqual(f.device, worker_device + '/' + last_part_device)

  # The device scope would merge with the default worker device.
  with ops.device('/CPU:1'):
    g = e + 1.0
  self.assertEqual(g.device, worker_device + '/device:CPU:1')

  # The ops.colocate_with will be ignored when defining a variable but not
  # for a normal tensor.
  with ops.colocate_with(x):
    u = variable_scope.get_variable('u', initializer=30.0)
    v = variable_scope.get_variable('v', initializer=30.0)
    h = f + 1.0
  self.assertIn('/job:ps/', u.device)
  self.assertIn('/job:ps/', v.device)
  # u and v are on different parameter servers.
  self.assertTrue(u.device != x.device or v.device != x.device)
  self.assertTrue(u.device == x.device or v.device == x.device)
  # Here h is not on one worker. Note h.device is canonical while x.device
  # is not.
  self.assertIn('/job:ps/', h.device)
  return y_add, z_add, f

def assign_moving_average(variable, value, decay, zero_debias=True, name=None):
  """Compute the moving average of a variable.

  https://github.com/tensorflow/tensorflow/blob/c966b5eed60a570f2121cb84ddb4ece84c413719/tensorflow/python/training/moving_averages.py
  """

  def _zero_debias(unbiased_var, value, decay):
    """Compute the delta required for a debiased Variable."""
    with tf.variable_scope(
        unbiased_var.op.name, values=[unbiased_var, value, decay]) as scope:
      with tf.init_scope():
        biased_initializer = tf.zeros_initializer(
            dtype=unbiased_var.dtype)(unbiased_var.get_shape())
        local_step_initializer = tf.zeros_initializer()

      def _maybe_get_unique(name):
        """Get name for a unique variable, if not `reuse=True`."""
        if tf.get_variable_scope().reuse:
          return name
        vs_vars = [
            x.op.name for x in tf.get_variable_scope().global_variables()
        ]
        full_name = tf.get_variable_scope().name + "/" + name
        if full_name not in vs_vars:
          return name
        idx = 1
        while full_name + ("_%d" % idx) in vs_vars:
          idx += 1
        return name + ("_%d" % idx)

      biased_var = tf.get_variable(
          _maybe_get_unique("biased"),
          initializer=biased_initializer,
          trainable=False)
      local_step = tf.get_variable(
          _maybe_get_unique("local_step"),
          shape=[],
          dtype=unbiased_var.dtype,
          initializer=local_step_initializer,
          trainable=False)

      # Get an update ops for both shadow variables.
      update_biased = tf.assign_sub(
          biased_var, (biased_var - value) * decay, name=scope.name)
      update_local_step = local_step.assign_add(1)

      # Compute the value of the delta to update the unbiased EMA. Make sure
      # to use the new values of the biased variable and the local step.
      with tf.control_dependencies([update_biased, update_local_step]):
        # This function gets `1 - decay`, so use `1.0 - decay` in the
        # exponent.
        unbiased_ema_delta = (
            unbiased_var - biased_var.read_value() /
            (1 - tf.pow(1.0 - decay, local_step.read_value())))

      return unbiased_ema_delta

  def update_fn(v, value, decay=decay):
    decay = tf.convert_to_tensor(1.0 - decay, name="decay")
    if decay.dtype != v.dtype.base_dtype:
      decay = tf.cast(decay, v.dtype.base_dtype)
    if zero_debias:
      update_delta = _zero_debias(v, value, decay)
    else:
      update_delta = (v - value) * decay
    return tf.assign_sub(v, update_delta, name=scope)

  with tf.name_scope(name, "AssignMovingAvg",
                     [variable, value, decay]) as scope:
    tower_context = distribution_strategy_context.get_tower_context()
    if tower_context:
      # In a tower context, we update variable using the mean of value across
      # towers.
      def merge_fn(strategy, v, value):
        try:
          value = strategy.reduce(tf.VariableAggregation.MEAN, value, v)
        except:
          pass  # Mirrored variables are loaded
        return strategy.update(v, update_fn, value)

      return tower_context.merge_call(merge_fn, variable, value)
    else:
      strategy = distribution_strategy_context.get_cross_tower_context()
      return strategy.update(variable, update_fn, value)

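A hedged usage sketch for the helper above: keep an exponential moving average of a batch statistic in plain TF 1.x; `ema_var` and `batch_mean` are illustrative names:

ema_var = tf.get_variable(
    "ema_mean", shape=[], initializer=tf.zeros_initializer(),
    trainable=False)
batch_mean = tf.reduce_mean(tf.random_normal([32]))
# Inside a tower context this MEAN-reduces `batch_mean` across towers first;
# with no strategy in scope it degenerates to a plain moving-average update.
update_ema = assign_moving_average(
    ema_var, batch_mean, decay=0.99, zero_debias=False)
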
def f1_score(labels, predictions, weights=None, num_thresholds=200,
             metrics_collections=None, updates_collections=None, name=None):
  """Computes the approximately best F1-score across different thresholds.

  The f1_score function applies a range of thresholds to the predictions to
  convert them from [0, 1] to bool. Precision and recall are computed by
  comparing them to the labels. The F1-Score is then defined as
  2 * precision * recall / (precision + recall). The best one across the
  thresholds is returned.

  Disclaimer: In practice it may be desirable to choose the best threshold on
  the validation set and evaluate the F1 score with this threshold on a
  separate test set. Or it may be desirable to use a fixed threshold (e.g.
  0.5).

  This function internally creates four local variables, `true_positives`,
  `true_negatives`, `false_positives` and `false_negatives` that are used to
  compute the pairs of recall and precision values for a linearly spaced set
  of thresholds from which the best f1-score is derived.

  This value is ultimately returned as `f1-score`, an idempotent operation
  that computes the F1-score (computed using the aforementioned variables).
  The `num_thresholds` variable controls the degree of discretization with
  larger numbers of thresholds more closely approximating the true best
  F1-score.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  F1-score.

  Example usage with a custom estimator:
  def model_fn(features, labels, mode):
    predictions = make_predictions(features)
    loss = make_loss(predictions, labels)
    train_op = tf.contrib.training.create_train_op(
        total_loss=loss,
        optimizer='Adam')
    eval_metric_ops = {'f1': f1_score(labels, predictions)}
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops,
        export_outputs=export_outputs)
  estimator = tf.estimator.Estimator(model_fn=model_fn)

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask
  values.

  Args:
    labels: A `Tensor` whose shape matches `predictions`. Will be cast to
      `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions
      must be either `1`, or the same as the corresponding `labels`
      dimension).
    num_thresholds: The number of thresholds to use when discretizing the roc
      curve.
    metrics_collections: An optional list of collections that `f1_score`
      should be added to.
    updates_collections: An optional list of collections that `update_op`
      should be added to.
    name: An optional variable_scope name.

  Returns:
    f1_score: A scalar `Tensor` representing the current best f1-score across
      different thresholds.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables
      appropriately and whose value matches the `f1_score`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or
      if either `metrics_collections` or `updates_collections` are not a list
      or tuple.
  """
  with variable_scope.variable_scope(
      name, 'f1', (labels, predictions, weights)):
    predictions, labels, weights = metrics_impl._remove_squeezable_dimensions(  # pylint: disable=protected-access
        predictions=predictions, labels=labels, weights=weights)
    # To account for floating point imprecisions / avoid division by zero.
    epsilon = 1e-7
    thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
                  for i in range(num_thresholds - 2)]
    thresholds = [0.0 - epsilon] + thresholds + [1.0 + epsilon]

    # Confusion matrix.
    values, update_ops = metrics_impl._confusion_matrix_at_thresholds(  # pylint: disable=protected-access
        labels, predictions, thresholds, weights, includes=('tp', 'fp', 'fn'))

    # Compute precision and recall at various thresholds.
    def compute_best_f1_score(tp, fp, fn, name):
      precision_at_t = math_ops.div(tp, epsilon + tp + fp,
                                    name='precision_' + name)
      recall_at_t = math_ops.div(tp, epsilon + tp + fn, name='recall_' + name)
      # Compute F1 score.
      f1_at_thresholds = (
          2.0 * precision_at_t * recall_at_t /
          (precision_at_t + recall_at_t + epsilon))
      return math_ops.reduce_max(f1_at_thresholds)

    def f1_across_towers(_, values):
      best_f1 = compute_best_f1_score(tp=values['tp'], fp=values['fp'],
                                      fn=values['fn'], name='value')
      if metrics_collections:
        ops.add_to_collections(metrics_collections, best_f1)
      return best_f1

    best_f1 = distribution_strategy_context.get_tower_context().merge_call(
        f1_across_towers, values)

    update_op = compute_best_f1_score(tp=update_ops['tp'],
                                      fp=update_ops['fp'],
                                      fn=update_ops['fn'], name='update')
    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return best_f1, update_op

def model_fn():
  if 'CPU' in compute_device:
    tower_compute_device = '/device:CPU:0'
  else:
    tower_compute_device = (
        '/device:GPU:%d' %
        distribution_strategy_context.get_tower_context().tower_id)
  tower_compute_device = device_util.canonicalize(tower_compute_device)

  if 'CPU' in variable_device:
    tower_variable_device = '/device:CPU:0'
  else:
    tower_variable_device = (
        '/device:GPU:%d' %
        distribution_strategy_context.get_tower_context().tower_id)
  tower_variable_device = device_util.canonicalize(tower_variable_device)

  a = constant_op.constant(1.0)
  b = constant_op.constant(2.0)
  c = a + b
  self.assertEqual(a.device, tower_compute_device)
  self.assertEqual(b.device, tower_compute_device)
  self.assertEqual(c.device, tower_compute_device)

  # The device scope is ignored for variables but not for normal ops.
  with ops.device('/device:GPU:2'):
    x = variable_scope.get_variable(
        'x', initializer=10.0,
        aggregation=variable_scope.VariableAggregation.SUM)
    x_add = x.assign_add(c)
    e = a + c
    self.assertEqual(
        device_util.canonicalize(x.device), tower_variable_device)
    self.assertEqual(x_add.device, x.device)
    self.assertEqual(e.device, device_util.canonicalize('/device:GPU:2'))

  # The colocate_vars_with can override the distribution's device.
  with d.colocate_vars_with(x):
    y = variable_scope.get_variable(
        'y', initializer=20.0,
        aggregation=variable_scope.VariableAggregation.SUM)
  # We add an identity here to avoid complaints about summing
  # non-distributed values.
  y_add = y.assign_add(array_ops.identity(x_add))
  self.assertEqual(
      device_util.canonicalize(y.device), tower_variable_device)
  self.assertEqual(y_add.device, y.device)
  self.assertEqual(y.device, x.device)

  z = variable_scope.get_variable(
      'z', initializer=10.0,
      aggregation=variable_scope.VariableAggregation.SUM)
  self.assertEqual(
      device_util.canonicalize(z.device), tower_variable_device)

  with ops.control_dependencies([y_add]):
    # We add an identity here to avoid complaints about summing
    # non-distributed values.
    z_add = z.assign_add(array_ops.identity(y))
  with ops.control_dependencies([z_add]):
    f = z + c
  self.assertEqual(f.device, tower_compute_device)

  # The device scope would merge with the default worker device.
  with ops.device('/CPU:1'):
    g = e + 1.0
  self.assertEqual(g.device, device_util.canonicalize('/device:CPU:1'))

  # The ops.colocate_with will be ignored when defining a variable but not
  # for a normal tensor.
  with ops.colocate_with(x):
    u = variable_scope.get_variable('u', initializer=30.0)
    h = f + 1.0
  self.assertEqual(
      device_util.canonicalize(u.device), tower_variable_device)
  self.assertEqual(device_util.canonicalize(x.device), h.device)
  return y_add, z_add, f

def mark_devices_fn():
  tower_id = distribution_strategy_context.get_tower_context().tower_id
  self.assertLess(tower_id, len(d.worker_devices))
  self.assertFalse(expected_devices[tower_id])
  expected_devices[tower_id] = True

def _merge_call_merge_raises_fn():
  distribution_strategy_context.get_tower_context().merge_call(
      _call_merge_raises_fn)

def model_fn(name):
  v = variable_scope.variable(1.0, name=name)
  distribution_strategy_context.get_tower_context().merge_call(lambda _: _)
  return v

def model_fn():
  b = variable_scope.get_variable("b", [1])
  with ops.name_scope("foo"):
    c = distribution_strategy_context.get_tower_context().merge_call(
        in_cross_tower)
  return b, c

def _assert_tower_context():
  if not distribution_strategy_context.get_tower_context():
    raise RuntimeError(
        "Tower-local variables may only be assigned in a tower context.")

def apply_gradients(self, loss, grads_and_vars, global_step=None, name=None):
  """Apply gradients to variables.

  This is the second part of `minimize()`. It returns an `Operation` that
  applies gradients.

  Args:
    grads_and_vars: List of (gradient, variable) pairs as returned by
      `compute_gradients()`.
    global_step: Optional `Variable` to increment by one after the variables
      have been updated.
    name: Optional name for the returned operation. Default to the name
      passed to the `Optimizer` constructor.

  Returns:
    An `Operation` that applies the specified gradients. If `global_step`
    was not None, that operation also increments `global_step`.

  Raises:
    TypeError: If `grads_and_vars` is malformed.
    ValueError: If none of the variables have gradients.
    RuntimeError: If you should use `_distributed_apply()` instead.
  """
  # This is a default implementation of apply_gradients() that can be shared
  # by most optimizers. It relies on the subclass implementing the following
  # methods: _create_slots(), _prepare(), _apply_dense(), and _apply_sparse().

  # Handle DistributionStrategy case.
  if distribution_strategy_context.get_cross_tower_context():
    raise RuntimeError("Use `_distributed_apply()` instead of "
                       "`apply_gradients()` in a cross-tower context.")
  # TODO(isaprykin): Get rid of `has_distribution_strategy()` check by
  # always calling _distributed_apply(), using the default distribution
  # as needed.
  if distribution_strategy_context.has_distribution_strategy():
    grads_and_vars = optimizer.get_filtered_grad_fn(lambda: grads_and_vars)()
    return distribution_strategy_context.get_tower_context().merge_call(
        self._distributed_apply, grads_and_vars, global_step, name)

  # No DistributionStrategy case.
  grads_and_vars = tuple(grads_and_vars)  # Make sure repeat iteration works.
  if not grads_and_vars:
    raise ValueError("No variables provided.")
  converted_grads_and_vars = []
  for grad, var in grads_and_vars:
    if grad is not None:
      try:
        # Convert the grad to Tensor or IndexedSlices if necessary.
        grad = ops.convert_to_tensor_or_indexed_slices(grad)
      except TypeError:
        raise TypeError("Gradient must be convertible to a Tensor or "
                        "IndexedSlices, or None: %s" % grad)
      if not isinstance(grad, (ops.Tensor, ops.IndexedSlices)):
        raise TypeError(
            "Gradient must be a Tensor, IndexedSlices, or None: %s" % grad)
    processor = _get_processor(var)
    converted_grads_and_vars.append((grad, var, processor))

  converted_grads_and_vars = tuple(converted_grads_and_vars)
  var_list = [
      var for grad, var, _ in converted_grads_and_vars if grad is not None
  ]
  if not var_list:
    raise ValueError("No gradients provided for any variable: %s." %
                     ([str(var) for _, var, _ in converted_grads_and_vars],))
  with ops.init_scope():
    self._create_slots(var_list)
  update_ops = []
  with ops.name_scope(name, self._name) as name:
    self._prepare()
    for grad, var, processor in converted_grads_and_vars:
      if grad is None:
        continue
      # We colocate all ops created in _apply_dense or _apply_sparse
      # on the same device as the variable.
      # TODO(apassos): figure out how to get the variable name here.
      if (context.executing_eagerly() or
          isinstance(var, resource_variable_ops.ResourceVariable) and
          not var._in_graph_mode):  # pylint: disable=protected-access
        scope_name = ""
      else:
        scope_name = var.op.name
      with ops.name_scope("update_" + scope_name), ops.colocate_with(var):
        update_ops.append(processor.update_op(self, loss, grad, global_step))
    if global_step is None:
      apply_updates = self._finish(update_ops, loss, name)
    else:
      with ops.control_dependencies(
          [self._finish(update_ops, loss, "update")]):
        with ops.colocate_with(global_step):
          if isinstance(global_step, resource_variable_ops.ResourceVariable):
            # TODO(apassos): the implicit read in assign_add is slow; consider
            # making it less so.
            apply_updates = resource_variable_ops.assign_add_variable_op(
                resource=global_step.handle,
                value=ops.convert_to_tensor(value=1, dtype=global_step.dtype),
                name=name)
          else:
            apply_updates = state_ops.assign_add(
                ref=global_step, value=1, name=name)

    if not context.executing_eagerly():
      if isinstance(apply_updates, ops.Tensor):
        apply_updates = apply_updates.op
      train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
      if apply_updates not in train_op:
        train_op.append(apply_updates)

    return apply_updates

def apply_gradients(self, grads_and_vars, global_step=None, name=None):
  """Apply gradients to variables.

  This is the second part of `minimize()`. It returns an `Operation` that
  applies gradients.

  Args:
    grads_and_vars: List of (gradient, variable) pairs as returned by
      `compute_gradients()`.
    global_step: Optional `Variable` to increment by one after the variables
      have been updated.
    name: Optional name for the returned operation. Default to the name
      passed to the `Optimizer` constructor.

  Returns:
    An `Operation` that applies the specified gradients. If `global_step`
    was not None, that operation also increments `global_step`.

  Raises:
    TypeError: If `grads_and_vars` is malformed.
    ValueError: If none of the variables have gradients.
    RuntimeError: If you should use `_distributed_apply()` instead.
  """
  # This is a default implementation of apply_gradients() that can be shared
  # by most optimizers. It relies on the subclass implementing the following
  # methods: _create_slots(), _prepare(), _apply_dense(), and _apply_sparse().

  # Handle DistributionStrategy case.
  if distribution_strategy_context.get_cross_tower_context():
    raise RuntimeError("Use `_distributed_apply()` instead of "
                       "`apply_gradients()` in a cross-tower context.")
  # TODO(isaprykin): Get rid of `has_distribution_strategy()` check by
  # always calling _distributed_apply(), using the default distribution
  # as needed.
  if distribution_strategy_context.has_distribution_strategy():
    grads_and_vars = get_filtered_grad_fn(lambda: grads_and_vars)()
    return distribution_strategy_context.get_tower_context().merge_call(
        self._distributed_apply, grads_and_vars, global_step, name)

  # No DistributionStrategy case.
  grads_and_vars = tuple(grads_and_vars)  # Make sure repeat iteration works.
  if not grads_and_vars:
    raise ValueError("No variables provided.")
  converted_grads_and_vars = []
  for g, v in grads_and_vars:
    if g is not None:
      try:
        # Convert the grad to Tensor or IndexedSlices if necessary.
        g = ops.convert_to_tensor_or_indexed_slices(g)
      except TypeError:
        raise TypeError("Gradient must be convertible to a Tensor"
                        " or IndexedSlices, or None: %s" % g)
      if not isinstance(g, (ops.Tensor, ops.IndexedSlices)):
        raise TypeError(
            "Gradient must be a Tensor, IndexedSlices, or None: %s" % g)
    p = _get_processor(v)
    converted_grads_and_vars.append((g, v, p))

  converted_grads_and_vars = tuple(converted_grads_and_vars)
  var_list = [v for g, v, _ in converted_grads_and_vars if g is not None]
  if not var_list:
    raise ValueError("No gradients provided for any variable: %s." %
                     ([str(v) for _, v, _ in converted_grads_and_vars],))
  with ops.init_scope():
    self._create_slots(var_list)
  update_ops = []
  with ops.name_scope(name, self._name) as name:
    self._prepare()
    for grad, var, processor in converted_grads_and_vars:
      if grad is None:
        continue
      # We colocate all ops created in _apply_dense or _apply_sparse
      # on the same device as the variable.
      # TODO(apassos): figure out how to get the variable name here.
      if (context.executing_eagerly() or
          isinstance(var, resource_variable_ops.ResourceVariable) and
          not var._in_graph_mode):  # pylint: disable=protected-access
        scope_name = ""
      else:
        scope_name = var.op.name
      with ops.name_scope("update_" + scope_name), ops.colocate_with(var):
        update_ops.append(processor.update_op(self, grad))
    if global_step is None:
      apply_updates = self._finish(update_ops, name)
    else:
      with ops.control_dependencies([self._finish(update_ops, "update")]):
        with ops.colocate_with(global_step):
          if isinstance(global_step, resource_variable_ops.ResourceVariable):
            # TODO(apassos): the implicit read in assign_add is slow; consider
            # making it less so.
            apply_updates = resource_variable_ops.assign_add_variable_op(
                global_step.handle,
                ops.convert_to_tensor(1, dtype=global_step.dtype),
                name=name)
          else:
            apply_updates = state_ops.assign_add(global_step, 1, name=name)

    if not context.executing_eagerly():
      if isinstance(apply_updates, ops.Tensor):
        apply_updates = apply_updates.op
      train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
      if apply_updates not in train_op:
        train_op.append(apply_updates)

    return apply_updates

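A minimal sketch of the non-distributed path through apply_gradients above (plain TF 1.x, no strategy in scope, so the merge_call branch is skipped):

x = tf.get_variable("x", initializer=1.0)
loss = tf.square(x)
opt = tf.train.GradientDescentOptimizer(learning_rate=0.1)
global_step = tf.train.get_or_create_global_step()
grads_and_vars = opt.compute_gradients(loss)
# Converts gradients, creates slot variables, then colocates each update
# with its variable and finally increments global_step.
train_op = opt.apply_gradients(grads_and_vars, global_step=global_step)
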
def model_fn(device_id):
  v = variable_scope.variable(1.0, name="foo_" + str(device_id))
  distribution_strategy_context.get_tower_context().merge_call(lambda _: _)
  return v

def model_fn():
  # This variable should be created only once across the threads because of
  # special variable_creator functions used by `dist.call_for_each_tower`.
  v = variable_scope.variable(1.0, name="foo")
  distribution_strategy_context.get_tower_context().merge_call(lambda _: _)
  return v

def model_fn():
  value = math_ops.cast(
      distribution_strategy_context.get_tower_context().tower_id,
      mirrored_var.dtype)
  return mirrored_var.assign_sub(value)

def model_fn():
  vs = []
  for i in range(5):
    vs.append(variable_scope.variable(1.0, name="foo" + str(i)))
  distribution_strategy_context.get_tower_context().merge_call(lambda _: _)
  return vs