Example #1
 def _get(self):
     """Returns the value for the current device or raises a ValueError."""
     replica_id = values_util.get_current_replica_id_as_int()
     if replica_id is None:
         return self._get_cross_replica()
     else:
         return self._values[replica_id]
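
The same dispatch can be observed from the public API. A minimal sketch, assuming two logical CPU devices (the device setup and the variable below are illustrative, not part of the snippet): reads outside strategy.run take the _get_cross_replica() branch, while reads inside strategy.run return self._values[replica_id].

    import tensorflow as tf

    # Split the single physical CPU into two logical devices so that
    # MirroredStrategy creates two replicas.
    cpu = tf.config.list_physical_devices("CPU")[0]
    tf.config.set_logical_device_configuration(
        cpu, [tf.config.LogicalDeviceConfiguration()] * 2)

    strategy = tf.distribute.MirroredStrategy(["CPU:0", "CPU:1"])
    with strategy.scope():
        v = tf.Variable(0.0)  # created as a MirroredVariable

    # Cross-replica context: replica_id is None, so the consolidated
    # cross-replica value is returned.
    print(v.read_value())

    # Replica context: each replica reads its own component variable.
    print(strategy.run(lambda: v.read_value()))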
Example #2
 def handle(self):
   replica_id = values_util.get_current_replica_id_as_int()
   if replica_id is None:
     raise ValueError("`handle` is not available outside the replica context"
                      " or a `tf.distribute.Strategy.update()` call.")
   else:
     return self._values[replica_id].handle
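
The error branch is easy to trip from user code. A hedged sketch (the strategy and variable are illustrative, and the exact error wording varies across TF versions): handle raises in cross-replica context but resolves to the current component's handle inside strategy.run.

    import tensorflow as tf

    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
        v = tf.Variable(1.0)

    try:
        v.handle  # cross-replica context: replica_id is None
    except ValueError as e:
        print(e)

    # Replica context: resolves to the current replica's component handle.
    strategy.run(lambda: tf.raw_ops.ReadVariableOp(resource=v.handle,
                                                   dtype=tf.float32))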
Example #3
 def _get_on_device_or_primary(self):
     """Returns value in same replica or device if possible, else the _primary."""
     replica_id = values_util.get_current_replica_id_as_int()
     if replica_id is None:
         # Try to find a value on the current device.
         current_device = device_util.canonicalize(device_util.current())
         for value in self._values:
             if device_util.canonicalize(value.device) == current_device:
                 return value
         return self._primary
     else:
         return self._values[replica_id]
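
The device comparison works because both sides are canonicalized first. device_util is internal, but the public tf.DeviceSpec API shows the idea (a sketch; canonicalize additionally fills in job/replica/task fields):

    import tensorflow as tf

    # Different spellings of one device normalize to the same canonical
    # string, which is what makes the equality check in the loop meaningful.
    a = tf.DeviceSpec.from_string("/cpu:0")
    b = tf.DeviceSpec.from_string("/device:CPU:0")
    print(a.to_string() == b.to_string())  # True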
Example #4
  def _replica_ctx_update(self, var, fn, args, kwargs):
    replica_context = distribution_strategy_context.get_replica_context()
    assert replica_context
    replica_id = values_util.get_current_replica_id_as_int()
    name = "update_%d" % replica_id

    if isinstance(var, values.DistributedVariable):
      var = var._get_replica(replica_id)  # pylint: disable=protected-access

    with ops.device(var.device), ops.name_scope(name):
      result = fn(var, *args, **kwargs)
    return result
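
One plausible way to reach this path from the public API (a sketch; whether extended.update accepts replica-context calls depends on the TF version): calling it inside strategy.run applies fn only to the current replica's component, on that component's device.

    import tensorflow as tf

    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
        v = tf.Variable(1.0)

    def step():
        ctx = tf.distribute.get_replica_context()
        # From a replica context, update() routes to _replica_ctx_update,
        # so fn sees only this replica's component variable.
        return ctx.strategy.extended.update(v, lambda var: var.assign_add(1.0))

    strategy.run(step)
    print(v.read_value())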
Example #5
    def call(self, inputs, training=None):
        model_obj = self.trainable_model if training else self.eval_model
        replica_context = None
        if tf.distribute.has_strategy():
            replica_context = tf.distribute.get_replica_context()
            if replica_context is not None:
                # Map each MirroredVariable component for the current replica
                # onto the replica's concrete function.
                replica_id = get_current_replica_id_as_int()
                new_variables = []
                new_captured = []
                # NB: zip_longest (from itertools) pads the shorter
                # sequences with None.
                for concrete_var_name, var, input_tensor in zip_longest(
                        model_obj.sorted_concrete_vars_names,
                        model_obj.mirrored_variables + self.op_vars,
                        model_obj.fn_train.inputs[1:]):
                    # Default to this replica's component; otherwise idx
                    # would be unbound (or stale) when concrete_var_name
                    # is None.
                    idx = replica_id
                    if concrete_var_name:
                        # Check whether a variable from another replica is
                        # needed for the concrete function call.
                        name, idx = name_without_replica_idx(concrete_var_name)
                        if name not in model_obj.bn_weights_names:
                            idx = replica_id

                    new_variables.append(var._get_replica(idx))
                    new_captured.append(
                        (var._get_replica(idx).handle, input_tensor))

        if not tf.distribute.has_strategy() or not replica_context:
            # No distribution strategy, or we are in cross-replica context
            # (e.g. at tracing time): keep the original variables unchanged.
            new_variables = model_obj.fn_train.graph.variables
            new_captured = model_obj.fn_train.graph.captures

        fn_train = make_new_func(model_obj.fn_train.graph.as_graph_def(),
                                 new_captured, new_variables,
                                 model_obj.fn_train.inputs,
                                 [model_obj.output_tensor])

        return fn_train(inputs)
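
The snippet relies on a helper, name_without_replica_idx, that is not shown. A hypothetical sketch of its contract, assuming component variable names carry a "/replica_<i>" suffix (the naming scheme is an assumption, not confirmed by the snippet):

    # Hypothetical sketch: split a per-replica variable name into its base
    # name and replica index. Assumes names like "conv/kernel/replica_2:0";
    # names without the suffix map to replica 0.
    def name_without_replica_idx(concrete_var_name):
        name = concrete_var_name.split(":")[0]
        if "/replica_" in name:
            base, idx = name.rsplit("/replica_", 1)
            return base, int(idx)
        return name, 0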