Example #1
def get_tpu_strategy(enable_packed_var=False):
    resolver = get_tpu_cluster_resolver()
    remote.connect_to_cluster(resolver)
    tpu_strategy_util.initialize_tpu_system(resolver)
    strategy = tpu_lib.TPUStrategyV2(resolver)
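    # Packed variables in eager mode are controlled via a private attribute;
    # there is no public constructor argument for this.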
    strategy._enable_packed_variable_in_eager_mode = enable_packed_var
    return strategy
Example #2
  def test_model_parallelism_checkpointing(self):

    class PartitionedModel(module.Module):

      def __init__(self, v, w):
        super(PartitionedModel, self).__init__()

        assert distribution_strategy_context.has_strategy()
        strategy = distribution_strategy_context.get_strategy()

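        # Pin each variable to a different logical device within the replica.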
        with strategy.extended.experimental_logical_device(0):
          self.v = variables.Variable(v)
        with strategy.extended.experimental_logical_device(1):
          self.w = variables.Variable(w)

      def __call__(self, x):
        replica_ctx = distribution_strategy_context.get_replica_context()
        with replica_ctx.experimental_logical_device(0):
          y = self.v * x
        with replica_ctx.experimental_logical_device(1):
          z = self.w * y
        return z

      def change_weights_op(self, v_new, w_new):
        return control_flow_ops.group([self.v.assign(v_new),
                                       self.w.assign(w_new)])

    resolver = get_tpu_cluster_resolver()
    remote.connect_to_cluster(resolver)
    topology = tpu_strategy_util.initialize_tpu_system(resolver)
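    # One replica spanning two TPU cores, exposed as logical devices 0 and 1.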
    device_assignment = device_assignment_lib.DeviceAssignment(
        topology, core_assignment=[[[0, 0, 0, 0], [0, 0, 0, 1]]])
    strategy = tpu_lib.TPUStrategyV2(
        resolver,
        experimental_device_assignment=device_assignment)

    with strategy.scope():
      model = PartitionedModel(2., 3.)

    checkpoint_dir = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
    checkpoint = util.Checkpoint(model=model)

    with self.cached_session() as sess:
      self.evaluate(variables.global_variables_initializer())
      checkpoint.save(file_prefix=checkpoint_prefix)

      self.evaluate(model.change_weights_op(1., 4.))
      result = strategy.run(def_function.function(model), args=(5.0,))
      self.assertEqual(20., self.evaluate(result))

      status = checkpoint.restore(
          checkpoint_management.latest_checkpoint(checkpoint_dir))
      status.run_restore_ops(sess)  # Restore ops must be run explicitly in graph (non-eager) mode.
      status.assert_consumed()
      status.assert_existing_objects_matched()
      result = strategy.run(def_function.function(model), args=(5.0,))
      self.assertEqual(30., self.evaluate(result))
Example #3
def get_tpu_strategy(enable_spmd=False):
    resolver = get_tpu_cluster_resolver()
    remote.connect_to_cluster(resolver)
    topology = tpu_strategy_util.initialize_tpu_system(resolver)
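    # A single replica spanning two cores; the second core serves as
    # logical device 1 for model parallelism.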
    device_assignment = device_assignment_lib.DeviceAssignment(
        topology, core_assignment=[[[0, 0, 0, 0], [0, 0, 0, 1]]])
    return tpu_lib.TPUStrategyV2(
        resolver,
        experimental_device_assignment=device_assignment,
        experimental_spmd_xla_partitioning=enable_spmd)
Example #4
  def test_device_assignment_strategy_properties(self):
    resolver = get_tpu_cluster_resolver()
    remote.connect_to_cluster(resolver)
    topology = tpu_strategy_util.initialize_tpu_system(resolver)
    device_assignment = device_assignment_lib.DeviceAssignment(
        topology, core_assignment=[[[0, 0, 0, 0]]])
    strategy = tpu_lib.TPUStrategyV2(
        resolver, experimental_device_assignment=device_assignment)
    self.assertEqual(strategy.extended.num_hosts, 1)
    self.assertEqual(strategy.num_replicas_in_sync, 1)
    self.assertEqual(strategy.extended.num_replicas_per_host, 1)
Example #5
def get_tpu_strategy(enable_spmd=False):
  resolver = get_tpu_cluster_resolver()
  remote.connect_to_cluster(resolver)
  topology = tpu_strategy_util.initialize_tpu_system(resolver)
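  # Each replica spans two adjacent cores (computation_shape=[1, 1, 1, 2]),
  # so the replica count is half the core count.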
  num_replicas = resolver.get_tpu_system_metadata().num_cores // 2
  device_assignment = device_assignment_lib.DeviceAssignment.build(
      topology, num_replicas=num_replicas, computation_shape=[1, 1, 1, 2])
  return tpu_lib.TPUStrategyV2(
      resolver,
      experimental_device_assignment=device_assignment,
      experimental_spmd_xla_partitioning=enable_spmd)
Example #6
  def test_computation_on_subset_cores(self, enable_packed_var):
    resolver = get_tpu_cluster_resolver()
    remote.connect_to_cluster(resolver)
    topology = tpu_strategy_util.initialize_tpu_system(resolver)
    all_core_strategy = tpu_lib.TPUStrategyV2(resolver)
    all_core_strategy._enable_packed_variable_in_eager_mode = enable_packed_var

    with all_core_strategy.scope():
      v = variables.Variable(0.0,
                             aggregation=variables.VariableAggregation.MEAN)

    # Computation on the 1st core.
    device_assignment = device_assignment_lib.DeviceAssignment.build(
        topology, num_replicas=1)
    first_core_strategy = tpu_lib.TPUStrategyV2(
        resolver, experimental_device_assignment=device_assignment)
    first_core_strategy._enable_packed_variable_in_eager_mode = (
        enable_packed_var)

    # Computation on the 2nd core.
    device_assignment2 = device_assignment_lib.DeviceAssignment(
        topology, [[[0, 0, 0, 1]]])
    second_core_strategy = tpu_lib.TPUStrategyV2(
        resolver, experimental_device_assignment=device_assignment2)
    second_core_strategy._enable_packed_variable_in_eager_mode = (
        enable_packed_var)

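    # One traced function drives all three strategies; only the two
    # single-core results are summed and returned.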
    @def_function.function
    def train_step():

      def step_fn():
        return v + 1.0

      all_core_strategy.run(step_fn)
      r1 = first_core_strategy.run(step_fn)
      r2 = second_core_strategy.run(step_fn)
      return r1 + r2

    train_step()
    self.assertAllEqual(2., train_step())
Example #7
  def test_variables_mismatched_device_assignment(self):
    resolver = get_tpu_cluster_resolver()
    remote.connect_to_cluster(resolver)
    topology = tpu_strategy_util.initialize_tpu_system(resolver)

    strategy0 = tpu_lib.TPUStrategyV2(resolver)
    self.assertEqual(
        ("/job:localhost/replica:0/task:0/device:TPU:0",
         "/job:localhost/replica:0/task:0/device:TPU:1"),
        strategy0.extended.worker_devices)

    with strategy0.scope():
      v = variables.Variable(1.)

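    # Write directly to the replica-local copy of the variable on the
    # second device.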
    v1_assign_op = strategy0.experimental_local_results(v)[1].assign(42.)

    with self.cached_session():
      self.evaluate(variables.global_variables_initializer())
      self.evaluate(v1_assign_op)
      self.assertAllEqual([1., 42.],
                          self.evaluate(
                              strategy0.experimental_local_results(v)))

    # Second strategy has devices reversed relative to the first.
    device_assignment = device_assignment_lib.DeviceAssignment(
        topology, core_assignment=[[[0, 0, 0, 1]], [[0, 0, 0, 0]]])
    strategy1 = tpu_lib.TPUStrategyV2(
        resolver,
        experimental_device_assignment=device_assignment)
    self.assertEqual(
        ("/job:localhost/replica:0/task:0/device:TPU:1",
         "/job:localhost/replica:0/task:0/device:TPU:0"),
        strategy1.extended.worker_devices)

    v_read = strategy1.run(def_function.function(v.read_value))

    with self.cached_session():
      self.assertAllEqual([42., 1.],
                          self.evaluate(
                              strategy0.experimental_local_results(v_read)))
Example #8
    def _create_tpu_strategy():
        FLAGS = flags.FLAGS  # pylint: disable=invalid-name
        global _did_connect_to_cluster
        global _topology

        try:
            # Attempt to locally discover the TPU. This will fail for Cloud TPU, in
            # which case we fall back to the values passed as flags.
            resolver = tpu_cluster_resolver.TPUClusterResolver()
            did_automatically_resolve = True
        except ValueError:
            did_automatically_resolve = False

            # These flags will be defined by tpu_test_wrapper.py.
            resolver = tpu_cluster_resolver.TPUClusterResolver(
                tpu=hasattr(FLAGS, "tpu") and FLAGS.tpu or "",
                zone=hasattr(FLAGS, "zone") and FLAGS.zone or None,
                project=hasattr(FLAGS, "project") and FLAGS.project or None,
            )

        # Only connect once per process, rather than per test method.
        if not _did_connect_to_cluster:
            if getattr(FLAGS, "tpu", "") or did_automatically_resolve:
                remote.connect_to_cluster(resolver)
                _did_connect_to_cluster = True
            _topology = tpu_strategy_util.initialize_tpu_system(resolver)

        device_assignment = None
        if use_single_core:
            device_assignment = device_assignment_lib.DeviceAssignment(
                _topology,
                core_assignment=device_assignment_lib.SINGLE_CORE_ASSIGNMENT)

        # Steps per run is only supported in TF 1.x
        if tf2.enabled():
            strategy = tpu_lib.TPUStrategyV2(
                resolver,
                device_assignment,
                experimental_spmd_xla_partitioning=enable_spmd_xla_partitioning,
                **kwargs)
        else:
            strategy = tpu_lib.TPUStrategyV1(resolver, steps_per_run,
                                             device_assignment, **kwargs)
        if enable_packed_variable and enable_spmd_xla_partitioning:
            raise ValueError(
                "Packed variables are not compatible with SPMD mode")
        strategy._enable_packed_variable_in_eager_mode = enable_packed_variable  # pylint: disable=protected-access
        return strategy
Example #9
  def test_update_config_proto(self):
    resolver = get_tpu_cluster_resolver()
    remote.connect_to_cluster(resolver)
    tpu_strategy_util.initialize_tpu_system(resolver)
    strategy = tpu_lib.TPUStrategyV2(resolver)

    config_proto = config_pb2.ConfigProto()
    cluster_spec = server_lib.ClusterSpec({"worker": ["fake1", "fake2"]})
    with test.mock.patch.object(
        resolver, "cluster_spec", return_value=cluster_spec):
      new_config = strategy.update_config_proto(config_proto)

    # Verify cluster_def.
    self.assertProtoEquals(cluster_spec.as_cluster_def(),
                           new_config.cluster_def)

    # Verify isolate_session_state.
    self.assertTrue(new_config.isolate_session_state)
Example #10
  def test_model_parallelism(self):
    resolver = get_tpu_cluster_resolver()
    remote.connect_to_cluster(resolver)
    topology = tpu_strategy_util.initialize_tpu_system(resolver)
    device_assignment = device_assignment_lib.DeviceAssignment(
        topology, core_assignment=[[[0, 0, 0, 0], [0, 0, 0, 1]]])
    strategy = tpu_lib.TPUStrategyV2(
        resolver,
        experimental_device_assignment=device_assignment)

    with strategy.scope():
      v = variables.Variable(2.)
      with strategy.extended.experimental_logical_device(1):
        w = variables.Variable(3.)

    self.assertLen(strategy.experimental_local_results(v), 1)
    self.assertLen(strategy.experimental_local_results(w), 1)
    self.assertEqual("/job:localhost/replica:0/task:0/device:TPU:0",
                     strategy.experimental_local_results(v)[0].device)
    self.assertEqual("/job:localhost/replica:0/task:0/device:TPU:1",
                     strategy.experimental_local_results(w)[0].device)

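    # Record the device placements observed while tracing the replicated
    # function.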
    logical_devices = []
    @def_function.function
    def f(x):
      replica_ctx = distribution_strategy_context.get_replica_context()
      with replica_ctx.experimental_logical_device(0):
        y = v * x
      with replica_ctx.experimental_logical_device(1):
        z = w * y
      logical_devices.append((y.device, z.device))
      return z

    result = strategy.run(f, args=(5.,))

    self.assertEqual(
        [("/device:TPU_REPLICATED_CORE:0", "/device:TPU_REPLICATED_CORE:1")],
        logical_devices)

    with self.cached_session():
      self.evaluate(variables.global_variables_initializer())
      self.assertEqual(30., self.evaluate(result))
Example #11
def get_tpu_strategy():
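    # Minimal setup: resolve the TPU cluster, connect, initialize the TPU
    # system, and return a default TPUStrategyV2.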
    resolver = get_tpu_cluster_resolver()
    remote.connect_to_cluster(resolver)
    tpu_strategy_util.initialize_tpu_system(resolver)
    return tpu_lib.TPUStrategyV2(resolver)