Example #1
    def predict(actions, state):
        state = state.copy()
        # break down the inputs along the batch dimension to form equal sized
        # tensors in each replica.
        num_replicas = strategy.num_replicas_in_sync
        actions = tf.split(actions, num_replicas)
        state = {
            key: tf.split(value, num_replicas)
            for key, value in state.items()
        }
        devices = values.ReplicaDeviceMap(strategy.extended.worker_devices)
        dist_actions = values.PerReplica(devices, tuple(actions))
        dist_state = []
        for i in range(num_replicas):
            dist_state.append({key: value[i] for key, value in state.items()})
        dist_state = values.PerReplica(devices, tuple(dist_state))

        dist_predictions = strategy.experimental_run_v2(model.predict,
                                                        args=(dist_actions,
                                                              dist_state))
        dist_predictions = {
            key: strategy.experimental_local_results(value)
            for key, value in dist_predictions.items()
        }
        predictions = {
            key: tf.concat(value, axis=0)
            for key, value in dist_predictions.items()
        }
        return predictions
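
Note: this example (and Example #2 below) uses the pre-2.2 TensorFlow API (values.ReplicaDeviceMap, strategy.experimental_run_v2). A rough, untested sketch of the same pattern on current TF 2.x, where Strategy.run replaces experimental_run_v2 and PerReplica takes only the tuple of per-replica values, might look like this (strategy and model are assumed to exist as above):

import tensorflow as tf
from tensorflow.python.distribute import values as values_lib

def predict_v2(actions, state):
    num_replicas = strategy.num_replicas_in_sync
    # Split the batch by hand and wrap each shard in a PerReplica value.
    dist_actions = values_lib.PerReplica(
        tuple(tf.split(actions, num_replicas)))
    split_state = {k: tf.split(v, num_replicas) for k, v in state.items()}
    dist_state = values_lib.PerReplica(
        tuple({k: v[i] for k, v in split_state.items()}
              for i in range(num_replicas)))
    # Strategy.run is the TF 2.x name for experimental_run_v2.
    dist_predictions = strategy.run(model.predict,
                                    args=(dist_actions, dist_state))
    # Unwrap the per-replica outputs and reassemble the batch dimension.
    return {
        k: tf.concat(strategy.experimental_local_results(v), axis=0)
        for k, v in dist_predictions.items()
    }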
Example #2
 def observe(images, actions, rewards, state):
     images = tf.to_float(images) / 255.0 - 0.5
     # break down the inputs along the batch dimension to form equal sized
     # tensors in each replica.
     num_replicas = strategy.num_replicas_in_sync
     images = tf.split(images, num_replicas)
     actions = tf.split(actions, num_replicas)
     state = {
         key: tf.split(value, num_replicas)
         for key, value in state.items()
     }
     devices = values.ReplicaDeviceMap(strategy.extended.worker_devices)
     dist_images = values.PerReplica(devices, tuple(images))
     dist_actions = values.PerReplica(devices, tuple(actions))
     dist_state = []
     for i in range(num_replicas):
         dist_state.append({key: value[i] for key, value in state.items()})
     dist_state = values.PerReplica(devices, tuple(dist_state))
     _, dist_posteriors = strategy.experimental_run_v2(model.observe,
                                                       args=(dist_actions,
                                                             dist_images,
                                                             dist_state))
     dist_posteriors = {
         key: strategy.experimental_local_results(value)
         for key, value in dist_posteriors.items()
     }
     posteriors = {
         key: tf.concat(value, axis=0)
         for key, value in dist_posteriors.items()
     }
     posteriors = {key: value[:, -1] for key, value in posteriors.items()}
     posteriors['rewards'] = rewards[:, -1]
     return posteriors
Example #3
  def testCondWithValuesNotConvertibleToTensor(self):
    device_map = values.SingleDeviceMap("CPU")
    per_replica_1 = values.PerReplica(device_map, (set(["a"]),))
    per_replica_2 = values.PerReplica(device_map, (set(["b", "c"]),))
    condition = array_ops.placeholder(dtypes.bool, [])

    with self.assertRaisesRegex(TypeError, "Could not build a TypeSpec for"):
      control_flow_ops.cond(
          condition, lambda: per_replica_1, lambda: per_replica_2)
Example #4
    def testCondWithValuesNotConvertibleToTensor(self):
        per_replica_1 = values_lib.PerReplica(({"a"}, ))
        per_replica_2 = values_lib.PerReplica(({"b", "c"}, ))
        condition = array_ops.placeholder(dtypes.bool, [])

        with self.assertRaisesRegex(TypeError,
                                    "Could not build a TypeSpec for"):
            control_flow_ops.cond(condition, lambda: per_replica_1,
                                  lambda: per_replica_2)
Example #5
  def testPassPerReplica(self, distribution):
    @function.defun
    def fn1(mock_model, factor):
      return mock_model(factor)

    device_map = values.ReplicaDeviceMap(("/device:CPU:0", "/device:GPU:0"))
    factors = values.PerReplica(device_map, (5.0, 3.0))
    expected_result = values.PerReplica(device_map, (5.0 * 1.25, 3.0 * 1.25))
    self._call_and_check(distribution, fn1, [factors], expected_result, [fn1])
Example #6
    def testCondWithValuesConvertibleToTensor(self):
        per_replica_1 = values_lib.PerReplica(("a", ))
        per_replica_2 = values_lib.PerReplica(("b", ))
        condition = array_ops.placeholder_with_default(True, [])

        result = control_flow_ops.cond(condition, lambda: per_replica_1,
                                       lambda: per_replica_2)

        self.assertLen(result.values, 1)
        self.assertAllEqual(result.values[0], "a")
Example #7
def _gather(strategy, value):
  """Gathers a single value."""
  # pylint: disable=protected-access
  if not isinstance(value, values.DistributedValues):
    value = values.PerReplica([ops.convert_to_tensor(value)])
  if not isinstance(strategy.extended,
                    collective_all_reduce_strategy.CollectiveAllReduceExtended):
    return array_ops.stack(value._values)
  assert len(strategy.extended.worker_devices) == len(value._values)
  inputs = [array_ops.expand_dims_v2(v, axis=0) for v in value._values]
  return strategy.gather(values.PerReplica(inputs), axis=0)
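
The helper above reaches into private attributes (value._values); newer TF 2.x releases expose the public tf.distribute.Strategy.gather(value, axis), which covers the common case directly. A minimal sketch, assuming two local devices are available (adjust the device list to your machine):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])

@tf.function
def step():
    # Each replica returns a length-1 tensor holding its replica id.
    replica_id = tf.distribute.get_replica_context().replica_id_in_sync_group
    return tf.reshape(tf.cast(replica_id, tf.int32), [1])

per_replica = strategy.run(step)
gathered = strategy.gather(per_replica, axis=0)  # -> [0, 1]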
Example #8
    def testCondWithTensorValues(self):
        per_replica_1 = values_lib.PerReplica((constant_op.constant("a"), ))
        per_replica_2 = values_lib.PerReplica((constant_op.constant(["b",
                                                                     "c"]), ))
        condition = array_ops.placeholder_with_default(True, [])

        result = control_flow_ops.cond(condition, lambda: per_replica_1,
                                       lambda: per_replica_2)

        self.assertLen(result.values, 1)
        self.assertAllEqual(result.values[0], "a")
Example #9
  def testCondWithValuesConvertibleToTensor(self):
    device_map = values.SingleDeviceMap("CPU")
    per_replica_1 = values.PerReplica(device_map, ("a",))
    per_replica_2 = values.PerReplica(device_map, ("b",))
    condition = array_ops.placeholder_with_default(True, [])

    result = control_flow_ops.cond(
        condition, lambda: per_replica_1, lambda: per_replica_2)

    self.assertEqual(per_replica_1.device_map, result.device_map)
    self.assertEqual(per_replica_1.logical_device, result.logical_device)
    self.assertLen(result.values, 1)
    self.assertAllEqual(result.values[0], "a")
Example #10
    def testTimeoutReduceSparse(self, communication, required_gpus):
        hints = collective_util.Hints(timeout_seconds=1)
        collective, devices, _ = self._get_test_objects(
            "worker",
            0,
            num_gpus=required_gpus,
            communication=communication,
            use_strategy_object=False)
        remote.connect_to_cluster(multi_worker_util.normalize_cluster_spec(
            self._cluster_spec),
                                  protocol="grpc")
        devices = [device_util.canonicalize(d) for d in devices]
        v = value_lib.PerReplica([
            _make_indexed_slices([[4., 6.], [5., 6.]], [1, 3], [5, 2],
                                 devices[0])
        ])

        @def_function.function
        def reduce_sparse():
            collective.reduce(reduce_util.ReduceOp.SUM, v, v, hints)

        # The collective should time out because we only launch it on worker-0,
        # while there are three workers in total.
        with self.assertRaises(errors.DeadlineExceededError):
            reduce_sparse()

        # Reset since collective failures poison the context.
        context._reset_context()  # pylint: disable=protected-access
Example #11
    def testTypeSpec(self):
        vals = (constant_op.constant(1.), )
        per_replica = values_lib.PerReplica(vals)

        spec = per_replica._type_spec
        self.assertEqual(spec._value_specs,
                         (tensor_spec.TensorSpec([], dtypes.float32), ))
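
Since PerReplica is a composite tensor, the same spec is also reachable through the public tf.type_spec_from_value on recent TF 2.x releases. A small sketch, reusing the per_replica value built above:

import tensorflow as tf

spec = tf.type_spec_from_value(per_replica)  # equal to per_replica._type_spec
assert spec.is_compatible_with(per_replica)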
Example #12
 def testFunctionCanReturnPerReplica(self):
     f = def_function.function(lambda x: x)
     x = values_lib.PerReplica((constant_op.constant(1.), ))
     y = f(x)
     self.assertIsNot(x, y)
     nest.map_structure(self.assertAllEqual, x, y, expand_composites=True)
     self.assertEqual(x._type_spec, y._type_spec)
Example #13
 def testShapeInvariantToComponentsExplicitShape(self):
   v1 = constant_op.constant([1., 1., 1.])
   v2 = constant_op.constant([2., 2., 2.])
   per_replica = values.PerReplica(values.SingleDeviceMap("CPU"), (v1, v2))
   shape = [None]
   self.assertEqual(per_replica._shape_invariant_to_components(shape=shape),
                    (shape, shape))
Example #14
 def testShapeInvariantToComponents(self):
     v1 = constant_op.constant(1.)
     v2 = constant_op.constant(2.)
     per_replica = values.PerReplica(values.SingleDeviceMap("CPU"),
                                     (v1, v2))
     self.assertEqual(per_replica._shape_invariant_to_components(),
                      (v1.shape, v2.shape))
Example #15
def nccl_gradient_sync(input, layer_id, construct_log):
    with tf.name_scope("gradient_sync_"+str(layer_id)):
        from tensorflow.python.distribute import values as value_lib

        nccl = tf.contrib.distribute.AllReduceCrossDeviceOps(all_reduce_alg='hierarchical_copy')

        tower_gradients = construct_log["tower_gradients"]
        destinations = construct_log["tower_devices"]

        grad_var_towers = list(zip(*tower_gradients))

        synchronized_grad_vars = []

        batch_reduce_vals = []
        valid_grad_var_towers = []
        for tgv in grad_var_towers:
            if tgv[0][0] is not None:
                per_replica = value_lib.PerReplica({ device: gv[0] for device, gv in zip(destinations, tgv)})
                batch_reduce_vals.append((per_replica, destinations))
                valid_grad_var_towers.append(tgv)
            else:
                for gv in tgv:
                    synchronized_grad_vars.append(gv)

        batch_mirrored = nccl.batch_reduce(tf.distribute.ReduceOp.MEAN, batch_reduce_vals)
        for tgv, mirrored in zip(valid_grad_var_towers, batch_mirrored):
            for device, gv in zip(destinations, tgv):
                with tf.device(device):
                    synchronized_grad_vars.append((mirrored.get(device), gv[1]))

        construct_log["gradients"] = synchronized_grad_vars
        return input
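
Example #15 targets TF 1.x contrib (tf.contrib.distribute.AllReduceCrossDeviceOps and the dict-style PerReplica constructor). On TF 2.x the same batch reduction can be written against a public tf.distribute.CrossDeviceOps implementation. A minimal CPU-only sketch (untested), using ReductionToOneDevice as a stand-in for the NCCL/hierarchical-copy variants and the per-replica value itself as the mirror destination:

import tensorflow as tf
from tensorflow.python.distribute import values as value_lib

# One gradient tensor per replica, wrapped in a PerReplica value.
per_replica_grad = value_lib.PerReplica(
    (tf.constant([1.0, 2.0]), tf.constant([3.0, 4.0])))

cross_device_ops = tf.distribute.ReductionToOneDevice()
# batch_reduce takes (per_replica_value, destinations) pairs; passing the
# per-replica value as its own destination mirrors the mean back onto the
# devices the inputs live on.
(mirrored,) = cross_device_ops.batch_reduce(
    tf.distribute.ReduceOp.MEAN, [(per_replica_grad, per_replica_grad)])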
Example #16
  def testDoesNotTriggerFunctionTracing(self):
    traces = []

    @def_function.function
    def f(x):
      traces.append(None)  # Only happens on trace.
      return x

    per_replica = values.PerReplica(
        values.SingleDeviceMap("CPU"), (constant_op.constant(1.),))

    # Trace once.
    f(per_replica)
    self.assertNotEmpty(traces)
    del traces[:]

    per_replica_spec = per_replica._type_spec
    for _ in range(5):
      vals = per_replica_spec._to_components(per_replica)
      vals = [v * 2 for v in vals]
      per_replica = per_replica_spec._from_components(vals)

      output = f(per_replica)
      self.assertIsInstance(output, values.PerReplica)
      self.assertAllEqual(output._values, per_replica._values)
      self.assertAllEqual(output._device_map, per_replica._device_map)
      self.assertAllEqual(output._logical_device, per_replica._logical_device)
      self.assertEmpty(traces)  # Make sure we're not re-tracing `f`.
Example #17
    def testStrategyExtendedUpdate(self, distribution, synchronization,
                                   aggregation):
        if len(distribution.extended.parameter_devices) != 2:
            self.skipTest("n/a: needs exactly two parameter devices")
        if (synchronization == variables_lib.VariableSynchronization.ON_WRITE
                and aggregation != variables_lib.VariableAggregation.NONE):
            self.skipTest(
                "n/a: doesn't apply to ON_WRITE variable with aggregation")
        with distribution.scope():
            v = variables_lib.Variable(0.,
                                       synchronization=synchronization,
                                       aggregation=aggregation)
        value = values_lib.PerReplica([1., 2.])

        assign_fn = lambda var, value: var.assign(value)
        self.evaluate(
            distribution.extended.update(v, assign_fn, args=(value, )))
        self.assertAllEqual(self.evaluate(v.values), [1., 2.])

        assign_add_fn = lambda var, value: var.assign_add(value)
        self.evaluate(
            distribution.extended.update(v, assign_add_fn, args=(value, )))
        self.assertAllEqual(self.evaluate(v.values), [2., 4.])

        assign_sub_fn = lambda var, value: var.assign_sub(value)
        self.evaluate(
            distribution.extended.update(v, assign_sub_fn, args=(value, )))
        self.assertAllEqual(self.evaluate(v.values), [1., 2.])

        read_assign_fn = lambda var, value: var.assign_add(
            var.value() + var.read_value())
        self.evaluate(
            distribution.extended.update(v, read_assign_fn, args=(value, )))
        self.assertAllEqual(self.evaluate(v.values), [3., 6.])
Example #18
  def testMirroredStratParaAsync(self):
    """Tests RNG/MirrorStrategy interaction #3.

    The user can create n independent RNGs outside strategy.scope(), where n
    is the number of replicas, and give one to each replica. The replicas can
    thus get different random-number streams.
    """
    shape = [3, 4]
    dtype = dtypes.int32
    gens = random.get_global_generator().split(count=2)
    devices = ["/cpu:0", test_util.gpu_device_name()]
    strat = MirroredStrategy(devices=devices)
    # Use `PerReplica` to specify which `gen` is sent to which replica
    gens = dist_values.PerReplica(
        device_map=dist_values.ReplicaDeviceMap(devices),
        values=[[g] for g in gens])
    with strat.scope():
      def f(gen):
        t1 = gen.uniform_full_int(shape=shape, dtype=dtype)
        t2 = gen.uniform_full_int(shape=shape, dtype=dtype)
        t = array_ops.stack([t1, t2])
        return t
      results = strat.extended.call_for_each_replica(
          fn=f, args=gens)
      values = results.values
      self.assertAllEqual(2, len(values))
      self.assertAllDifferent(values)
Example #19
  def _get_indexed_slices(self,
                          devices,
                          start_i,
                          variable_length,
                          as_per_replica=True):
    dense_shape = [10, 2]
    values = ([[1., 2.]], [[3., 4.]], [[2., 1.]], [[0., 0.]], [[3., 1.]],
              [[2., 1.]])
    indices = ([1], [2], [3], [4], [5], [6])

    # values and indices that have variable lengths.
    vl_values = ([[1., 2.], [3., 4.]], [[3., 4.]], [[2., 1.]], [[0., 0.]],
                 [[3., 1.], [2., 1.]], [[2., 1.]])
    vl_indices = ([1, 2], [2], [3], [4], [5, 6], [6])

    indexed_slices = []
    for i, d in enumerate(devices):
      idx = i + start_i
      indexed_slices.append(
          _make_indexed_slices(
              vl_values[idx] if variable_length else values[idx],
              vl_indices[idx] if variable_length else indices[idx], dense_shape,
              d))
    if as_per_replica:
      per_replica = value_lib.PerReplica(indexed_slices)
      return per_replica
    else:
      return indexed_slices
Example #20
    def __init__(self,
                 container_strategy,
                 tpu_cluster_resolver,
                 steps_per_run,
                 num_cores=None):
        super(TPUExtended, self).__init__(container_strategy)
        self._tpu_cluster_resolver = tpu_cluster_resolver
        self._tpu_metadata = get_tpu_system_metadata(
            self._tpu_cluster_resolver)
        # TODO(sourabhbajaj): Change this from num_cores to metadata_override
        self._num_cores_override = num_cores

        # TODO(jhseu): Switch to DeviceAssignment to support pods and model
        # parallelism.
        device_map = {
            d.name: i
            for i, d in enumerate(self._tpu_metadata.devices)
            if "device:TPU:" in d.name
        }
        self._device_index = values.PerReplica(device_map)
        self._host_device = self.get_host_cpu_device(0)
        self._tpu_devices = sorted(device_map.keys())
        # Only create variables for the number of replicas we're running.
        self._tpu_devices = self._tpu_devices[:self._num_replicas_in_sync]

        # TODO(sourabhbajaj): Remove this once performance of running one step
        # at a time is comparable to multiple steps.
        self.steps_per_run = steps_per_run

        self._require_static_shapes = True
Example #21
def _gather(strategy, value):
    """Gathers a single value."""
    # pylint: disable=protected-access
    if not isinstance(value, values.DistributedValues):
        assert isinstance(value, core.Tensor)
        value = values.PerReplica([value])
    if not isinstance(
            strategy.extended,
            collective_all_reduce_strategy.CollectiveAllReduceExtended):
        return array_ops.stack(value._values)
    assert len(strategy.extended.worker_devices) == len(value._values)
    inputs = [array_ops.expand_dims_v2(v, axis=0) for v in value._values]
    collective_keys = strategy.extended._collective_keys
    devices = strategy.extended.worker_devices
    group_size = strategy.num_replicas_in_sync

    @def_function.function
    def gather_fn():
        gathered = cross_device_utils.build_collective_gather(
            inputs, devices, group_size, collective_keys)
        return distribute_utils.update_regroup(strategy.extended,
                                               gathered,
                                               group=True)

    return gather_fn()
Example #22
def make_per_replica_value(value_fn, devices):
  """Creates a `PerReplica` object whose values reside in `devices`.

  Args:
    value_fn: a callable that takes one argument (`device_idx`) and should
      return the value that is going to be created on devices[device_idx].
    devices: a list of device strings to create `PerReplica` values on.

  Returns:
    A `PerReplica` object.
  """
  values = []
  for device_idx, device in enumerate(devices):
    v = value_fn(device_idx)
    if isinstance(v, indexed_slices.IndexedSlicesValue):
      with ops.device(device):
        values.append(
            indexed_slices.IndexedSlices(
                values=array_ops.identity(v.values),
                indices=array_ops.identity(v.indices),
                dense_shape=array_ops.identity(v.dense_shape)))
    else:
      with ops.device(device):
        values.append(array_ops.identity(v))
  return value_lib.PerReplica(values)
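
A hypothetical call to the helper above, assuming the surrounding test module's imports (constant_op) and two local CPU device strings:

devices = ["/device:CPU:0", "/device:CPU:0"]
# value_fn receives the device index and returns the value to place there.
per_replica = make_per_replica_value(
    lambda device_idx: constant_op.constant(float(device_idx)), devices)
# per_replica.values is (0.0, 1.0): one component per entry in `devices`.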
Example #23
    def _initialize_local(self, num_gpus, devices):
        """Initializes the object for local training."""
        self._cluster_spec = None
        # Convert `num_gpus` into `devices`, shouldn't specify both.
        if devices is None:
            if num_gpus is None:
                num_gpus = context.num_gpus()
            if num_gpus == 0:
                devices = ["/device:CPU:0"]
            else:
                devices = ["/device:GPU:%d" % d for d in range(num_gpus)]
        elif num_gpus is not None:
            raise ValueError(
                "Must only specify one of `devices` and `num_gpus`.")
        self._num_gpus = num_gpus
        # TODO(yuefengz): consider setting the default device.

        assert devices, "Must specify at least one device."
        assert len(set(devices)) == len(devices), (
            "No duplicates allowed in `devices` argument.")
        # TODO(josh11b): Require at least 2 devices?
        self._devices = [device_util.resolve(d) for d in devices]
        self._canonical_device_set = set(self._devices)
        self._device_index = values.PerReplica(
            {d: i
             for i, d in enumerate(devices)})
Example #24
    def _initialize_multi_worker(self, devices):
        """Initializes the object for multi-worker training."""
        self._local_mode = False

        assert devices, "Must specify at least one device."
        assert len(set(devices)) == len(devices), (
            "No duplicates allowed in `devices` argument.")
        # TODO(josh11b): Require at least 2 devices?
        self._devices = tuple(device_util.resolve(d) for d in devices)
        self._canonical_device_set = set(self._devices)
        self._device_index = values.PerReplica(
            {d: i
             for i, d in enumerate(devices)})

        device_dict = _group_device_list(devices)
        self._workers = []
        self._worker_devices = []
        for job in ["chief", "worker"]:
            for task in range(len(device_dict.get(job, []))):
                worker = "/job:%s/task:%d" % (job, task)
                self._workers.append(worker)
                self._worker_devices.append((worker, device_dict[job][task]))

        # Setting `_default_device` will add a device scope in the
        # distribution.scope. We set the default device to the first worker. When
        # users specify device under distribution.scope by
        #   with tf.device("/cpu:0"):
        #     ...
        # their ops will end up on the cpu device of its first worker, e.g.
        # "/job:worker/task:0/device:CPU:0". Note this is not used in replica mode.
        self._default_device = self._workers[0]

        self._inferred_cross_device_ops = cross_device_ops_lib.MultiWorkerAllReduce(
            self._workers, _infer_num_gpus_per_worker(self._devices))
Example #25
 def testInconsistentShape(self):
   per_replica_values = [
       value_lib.PerReplica([
           array_ops.ones([10, 10], dtype=dtypes.float32),
           array_ops.ones([10, 10], dtype=dtypes.float32),
       ]),
       value_lib.PerReplica([
           array_ops.ones([10, 10], dtype=dtypes.float32),
           input_layer.Input(
               shape=(10), batch_size=None, dtype=dtypes.float32),
       ]),
   ]
   packs = cross_device_utils.pack_by_size(
       per_replica_values, bytes_per_pack=1)
   self.assertLen(packs, 1)
   self.assertEqual(packs[0], per_replica_values)
Example #26
 def testContainsIndexedSlices_PerReplica(self):
   t0 = math_ops._as_indexed_slices(
       constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
   t1 = math_ops._as_indexed_slices(
       constant_op.constant([[0., 0.], [5, 6], [7., 8.]]))
   per_replica = value_lib.PerReplica((t0, t1))
   self.assertTrue(cross_device_utils.contains_indexed_slices(per_replica))
Example #27
 def testToComponents(self):
   device_map = values.SingleDeviceMap("CPU")
   vals = (constant_op.constant(1.),)
   per_replica = values.PerReplica(device_map, vals)
   logical_device = 0
   self.assertEqual(per_replica._to_components(), vals)
   self.assertEqual(per_replica._component_metadata(), (device_map,
                                                        logical_device))
Example #28
    def testTypeSpecRoundTrip(self):
        vals = (constant_op.constant(1.), )
        per_replica = values_lib.PerReplica(vals)

        spec = per_replica._type_spec
        tensor_list = spec._to_components(per_replica)
        reconstructed = spec._from_components(tensor_list)

        self.assertAllEqual(per_replica.values, reconstructed.values)
Example #29
 def testFunctionCanReturnPerReplica(self):
   f = def_function.function(lambda x: x)
   x = values.PerReplica(
       values.SingleDeviceMap("CPU"), (constant_op.constant(1.),))
   y = f(x)
   self.assertIsNot(x, y)
   for a, b in zip(x._to_components(), y._to_components()):
     self.assertAllEqual(a, b)
   self.assertEqual(x._component_metadata(), y._component_metadata())
Example #30
  def testTypeSpec(self):
    device_map = values.SingleDeviceMap("CPU")
    vals = (constant_op.constant(1.),)
    per_replica = values.PerReplica(device_map, vals)

    spec = per_replica._type_spec
    self.assertEqual(spec._value_specs,
                     (tensor_spec.TensorSpec([], dtypes.float32),))
    self.assertEqual(spec._device_map, per_replica.device_map)
    self.assertEqual(spec._logical_device, per_replica.logical_device)