def _get(self):
    """Returns the value for the current device or raises a ValueError."""
    replica_id = values_util.get_current_replica_id_as_int()
    if replica_id is None:
        return self._get_cross_replica()
    else:
        return self._values[replica_id]
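# The branch above hinges on whether a replica id can be resolved. Below is a
# minimal, hypothetical sketch of that probe using only public tf.distribute
# APIs; the real values_util.get_current_replica_id_as_int reads internal
# state, so treat this as an approximation, not the library's implementation.
import tensorflow as tf

def current_replica_id_or_none():
    ctx = tf.distribute.get_replica_context()
    if ctx is None:
        # Cross-replica context (e.g. directly under strategy.scope()).
        return None
    replica_id = ctx.replica_id_in_sync_group
    if tf.is_tensor(replica_id):
        # Inside a traced function the id may be symbolic; only a statically
        # known id can be returned as a Python int.
        replica_id = tf.get_static_value(replica_id)
    return int(replica_id) if replica_id is not None else None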
def handle(self):
    """The variable handle on the current replica; unavailable cross-replica."""
    replica_id = values_util.get_current_replica_id_as_int()
    if replica_id is None:
        raise ValueError("`handle` is not available outside the replica "
                         "context or a `tf.distribute.Strategy.update()` call.")
    else:
        return self._values[replica_id].handle
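# A hedged usage sketch of the contract above: `handle` resolves per replica
# inside strategy.run, and raising in cross-replica context is the expected
# behavior. The strategy and variable here are illustrative assumptions.
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    v = tf.Variable(1.0)  # created as a distributed (mirrored) variable

def read_handle():
    return v.handle  # inside strategy.run each replica gets its own handle

strategy.run(read_handle)  # OK: replica context
# v.handle  # would raise ValueError here: cross-replica context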
def _get_on_device_or_primary(self):
    """Returns the value on the current replica or device, else `_primary`."""
    replica_id = values_util.get_current_replica_id_as_int()
    if replica_id is None:
        # Try to find a value on the current device.
        current_device = device_util.canonicalize(device_util.current())
        for value in self._values:
            if device_util.canonicalize(value.device) == current_device:
                return value
        return self._primary
    else:
        return self._values[replica_id]
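# The device comparison above relies on canonicalized device strings. A
# hypothetical stand-in for device_util.canonicalize, built on the public
# tf.DeviceSpec API: merge a possibly partial device string with defaults so
# that equality becomes an exact string match.
import tensorflow as tf

def canonicalize(name, default="/job:localhost/replica:0/task:0/device:CPU:0"):
    base = tf.DeviceSpec.from_string(default)
    return base.make_merged_spec(tf.DeviceSpec.from_string(name)).to_string()

# "/gpu:0" and its fully qualified form now compare equal:
assert (canonicalize("/gpu:0") ==
        canonicalize("/job:localhost/replica:0/task:0/device:GPU:0"))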
def _replica_ctx_update(self, var, fn, args, kwargs):
    """Applies `fn` to the current replica's component of `var`."""
    replica_context = distribution_strategy_context.get_replica_context()
    assert replica_context
    replica_id = values_util.get_current_replica_id_as_int()
    name = "update_%d" % replica_id
    if isinstance(var, values.DistributedVariable):
        var = var._get_replica(replica_id)  # pylint: disable=protected-access
    with ops.device(var.device), ops.name_scope(name):
        result = fn(var, *args, **kwargs)
    return result
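# Sketch of a call path that ends in an in-replica update like the one above:
# an assignment to a distributed variable from inside strategy.run. The
# aggregation argument makes the per-replica writes mergeable; devices are
# whatever the default strategy picks up. Illustrative, not the exact TF
# routing.
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    v = tf.Variable(0.0, aggregation=tf.VariableAggregation.MEAN)

def step():
    v.assign_add(1.0)  # per-replica update, applied to this replica's copy

strategy.run(step)
print(v.numpy())  # 1.0: each replica added 1.0, then MEAN aggregation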
def call(self, inputs, training=None):
    model_obj = self.trainable_model if training else self.eval_model

    replica_context = None
    if tf.distribute.has_strategy():
        replica_context = tf.distribute.get_replica_context()
        if replica_context is not None:
            # Map each MirroredVariable to the corresponding replica's
            # component before calling the concrete function.
            replica_id = get_current_replica_id_as_int()
            new_variables = []
            new_captured = []
            for concrete_var_name, var, input_tensor in zip_longest(
                    model_obj.sorted_concrete_vars_names,
                    model_obj.mirrored_variables + self.op_vars,
                    model_obj.fn_train.inputs[1:]):
                if concrete_var_name:
                    # Check whether variables from other replicas are needed
                    # for the concrete function call.
                    name, idx = name_without_replica_idx(concrete_var_name)
                    if name not in model_obj.bn_weights_names:
                        idx = replica_id
                    new_variables.append(var._get_replica(idx))
                    new_captured.append(
                        (var._get_replica(idx).handle, input_tensor))

    if not tf.distribute.has_strategy() or not replica_context:
        # No distribution strategy, or tracing time (cross-replica context):
        # keep the function's original variables and captures unchanged.
        new_variables = model_obj.fn_train.graph.variables
        new_captured = model_obj.fn_train.graph.captures

    fn_train = make_new_func(model_obj.fn_train.graph.as_graph_def(),
                             new_captured,
                             new_variables,
                             model_obj.fn_train.inputs,
                             [model_obj.output_tensor])
    return fn_train(inputs)
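# A hypothetical sketch of the name_without_replica_idx helper used above,
# assuming captured-variable names carry a trailing "replica_<idx>" component;
# the real helper's naming scheme may differ, so this illustrates the idea
# rather than reproducing the project's implementation.
def name_without_replica_idx(concrete_var_name):
    name = concrete_var_name.split(":")[0]  # drop the ":0" output suffix
    if "replica" in name:
        idx = int(name.split("_")[-1])
        name = "/".join(name.split("/")[:-1])  # drop the "replica_<idx>" part
    else:
        idx = 0
    return name, idx

# e.g. name_without_replica_idx("conv1/kernel/replica_2:0") -> ("conv1/kernel", 2)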