def get_tpu_strategy(enable_packed_var=False):
  resolver = get_tpu_cluster_resolver()
  remote.connect_to_cluster(resolver)
  tpu_strategy_util.initialize_tpu_system(resolver)
  strategy = tpu_lib.TPUStrategyV2(resolver)
  strategy._enable_packed_variable_in_eager_mode = enable_packed_var
  return strategy
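# A minimal usage sketch for the helper above (hypothetical function name,
# reusing this module's imports and assuming a reachable TPU). Packed
# variables only change how per-replica variable handles are passed around
# in eager mode, so the result is identical with enable_packed_var=False.
def _example_packed_variable_usage():
  strategy = get_tpu_strategy(enable_packed_var=True)
  with strategy.scope():
    v = variables.Variable(1.0)

  @def_function.function
  def add_one():
    return v + 1.0

  # Each replica reads its local copy of `v` through the packed handle.
  return strategy.run(add_one)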
def test_model_parallelism_checkpointing(self):

  class PartitionedModel(module.Module):

    def __init__(self, v, w):
      super(PartitionedModel, self).__init__()

      assert distribution_strategy_context.has_strategy()
      strategy = distribution_strategy_context.get_strategy()

      with strategy.extended.experimental_logical_device(0):
        self.v = variables.Variable(v)
      with strategy.extended.experimental_logical_device(1):
        self.w = variables.Variable(w)

    def __call__(self, x):
      replica_ctx = distribution_strategy_context.get_replica_context()
      with replica_ctx.experimental_logical_device(0):
        y = self.v * x
      with replica_ctx.experimental_logical_device(1):
        z = self.w * y
      return z

    def change_weights_op(self, v_new, w_new):
      return control_flow_ops.group(
          [self.v.assign(v_new), self.w.assign(w_new)])

  resolver = get_tpu_cluster_resolver()
  remote.connect_to_cluster(resolver)
  topology = tpu_strategy_util.initialize_tpu_system(resolver)
  device_assignment = device_assignment_lib.DeviceAssignment(
      topology, core_assignment=[[[0, 0, 0, 0], [0, 0, 0, 1]]])
  strategy = tpu_lib.TPUStrategyV2(
      resolver, experimental_device_assignment=device_assignment)

  with strategy.scope():
    model = PartitionedModel(2., 3.)

  checkpoint_dir = self.get_temp_dir()
  checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
  checkpoint = util.Checkpoint(model=model)

  with self.cached_session() as sess:
    self.evaluate(variables.global_variables_initializer())
    checkpoint.save(file_prefix=checkpoint_prefix)

    self.evaluate(model.change_weights_op(1., 4.))
    result = strategy.run(def_function.function(model), args=(5.0,))
    self.assertEqual(20., self.evaluate(result))

    status = checkpoint.restore(
        checkpoint_management.latest_checkpoint(checkpoint_dir))
    status.run_restore_ops(sess)  # must run restore op in non-eager mode.
    status.assert_consumed()
    status.assert_existing_objects_matched()
    result = strategy.run(def_function.function(model), args=(5.0,))
    self.assertEqual(30., self.evaluate(result))
def get_tpu_strategy(enable_spmd=False):
  resolver = get_tpu_cluster_resolver()
  remote.connect_to_cluster(resolver)
  topology = tpu_strategy_util.initialize_tpu_system(resolver)
  device_assignment = device_assignment_lib.DeviceAssignment(
      topology, core_assignment=[[[0, 0, 0, 0], [0, 0, 0, 1]]])
  return tpu_lib.TPUStrategyV2(
      resolver,
      experimental_device_assignment=device_assignment,
      experimental_spmd_xla_partitioning=enable_spmd)
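# Hypothetical sketch of what the core_assignment above encodes, assuming the
# usual [replica][logical_device][coordinate] layout: one replica whose
# logical device 0 is the core at topology coordinates [0, 0, 0, 0] and whose
# logical device 1 is the core at [0, 0, 0, 1]. Per-device placement then
# looks like this:
def _example_two_logical_devices():
  strategy = get_tpu_strategy(enable_spmd=False)

  def step_fn(x):
    replica_ctx = distribution_strategy_context.get_replica_context()
    # Stage the two halves of the computation on the two logical devices.
    with replica_ctx.experimental_logical_device(0):
      y = x * 2.0
    with replica_ctx.experimental_logical_device(1):
      return y + 1.0

  return strategy.run(def_function.function(step_fn), args=(5.0,))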
def test_device_assignment_strategy_properties(self):
  resolver = get_tpu_cluster_resolver()
  remote.connect_to_cluster(resolver)
  topology = tpu_strategy_util.initialize_tpu_system(resolver)
  device_assignment = device_assignment_lib.DeviceAssignment(
      topology, core_assignment=[[[0, 0, 0, 0]]])
  strategy = tpu_lib.TPUStrategyV2(
      resolver, experimental_device_assignment=device_assignment)
  self.assertEqual(strategy.extended.num_hosts, 1)
  self.assertEqual(strategy.num_replicas_in_sync, 1)
  self.assertEqual(strategy.extended.num_replicas_per_host, 1)  # pylint: disable=protected-access
def get_tpu_strategy(enable_spmd=False):
  resolver = get_tpu_cluster_resolver()
  remote.connect_to_cluster(resolver)
  topology = tpu_strategy_util.initialize_tpu_system(resolver)
  num_replicas = resolver.get_tpu_system_metadata().num_cores // 2
  device_assignment = device_assignment_lib.DeviceAssignment.build(
      topology, num_replicas=num_replicas, computation_shape=[1, 1, 1, 2])
  return tpu_lib.TPUStrategyV2(
      resolver,
      experimental_device_assignment=device_assignment,
      experimental_spmd_xla_partitioning=enable_spmd)
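# Why num_cores // 2 above: computation_shape=[1, 1, 1, 2] carves the
# topology into 1x1x1x2 slices, so every replica owns two cores and only
# num_cores / 2 model-parallel replicas fit. A hypothetical sanity check,
# reusing this module's helpers:
def _example_replica_count():
  resolver = get_tpu_cluster_resolver()
  num_cores = resolver.get_tpu_system_metadata().num_cores
  strategy = get_tpu_strategy(enable_spmd=True)
  # Each replica spans 2 cores, so replicas * 2 == total cores.
  assert strategy.num_replicas_in_sync == num_cores // 2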
def test_computation_on_subset_cores(self, enable_packed_var):
  resolver = get_tpu_cluster_resolver()
  remote.connect_to_cluster(resolver)
  topology = tpu_strategy_util.initialize_tpu_system(resolver)
  all_core_strategy = tpu_lib.TPUStrategyV2(resolver)
  all_core_strategy._enable_packed_variable_in_eager_mode = enable_packed_var

  with all_core_strategy.scope():
    v = variables.Variable(0.0,
                           aggregation=variables.VariableAggregation.MEAN)

  # Computation on the 1st core.
  device_assignment = device_assignment_lib.DeviceAssignment.build(
      topology, num_replicas=1)
  first_core_strategy = tpu_lib.TPUStrategyV2(
      resolver, experimental_device_assignment=device_assignment)
  first_core_strategy._enable_packed_variable_in_eager_mode = (
      enable_packed_var)

  # Computation on the 2nd core.
  device_assignment2 = device_assignment_lib.DeviceAssignment(
      topology, [[[0, 0, 0, 1]]])
  second_core_strategy = tpu_lib.TPUStrategyV2(
      resolver, experimental_device_assignment=device_assignment2)
  second_core_strategy._enable_packed_variable_in_eager_mode = (
      enable_packed_var)

  @def_function.function
  def train_step():

    def step_fn():
      return v + 1.0

    all_core_strategy.run(step_fn)
    r1 = first_core_strategy.run(step_fn)
    r2 = second_core_strategy.run(step_fn)
    return r1 + r2

  train_step()
  self.assertAllEqual(2., train_step())
def test_variables_mismatched_device_assignment(self):
  resolver = get_tpu_cluster_resolver()
  remote.connect_to_cluster(resolver)
  topology = tpu_strategy_util.initialize_tpu_system(resolver)

  strategy0 = tpu_lib.TPUStrategyV2(resolver)
  self.assertEqual(
      ("/job:localhost/replica:0/task:0/device:TPU:0",
       "/job:localhost/replica:0/task:0/device:TPU:1"),
      strategy0.extended.worker_devices)

  with strategy0.scope():
    v = variables.Variable(1.)

  v1_assign_op = strategy0.experimental_local_results(v)[1].assign(42.)

  with self.cached_session():
    self.evaluate(variables.global_variables_initializer())
    self.evaluate(v1_assign_op)
    self.assertAllEqual([1., 42.],
                        self.evaluate(
                            strategy0.experimental_local_results(v)))

  # Second strategy has devices reversed relative to the first.
  device_assignment = device_assignment_lib.DeviceAssignment(
      topology, core_assignment=[[[0, 0, 0, 1]], [[0, 0, 0, 0]]])
  strategy1 = tpu_lib.TPUStrategyV2(
      resolver, experimental_device_assignment=device_assignment)
  self.assertEqual(
      ("/job:localhost/replica:0/task:0/device:TPU:1",
       "/job:localhost/replica:0/task:0/device:TPU:0"),
      strategy1.extended.worker_devices)

  v_read = strategy1.run(def_function.function(v.read_value))

  with self.cached_session():
    self.assertAllEqual([42., 1.],
                        self.evaluate(
                            strategy0.experimental_local_results(v_read)))
def _create_tpu_strategy():
  FLAGS = flags.FLAGS  # pylint: disable=invalid-name
  global _did_connect_to_cluster
  global _topology

  try:
    # Attempt to locally discover the TPU. This will fail for Cloud TPU, in
    # which case we fall back to the values passed as flags.
    resolver = tpu_cluster_resolver.TPUClusterResolver()
    did_automatically_resolve = True
  except ValueError:
    did_automatically_resolve = False

    # These flags will be defined by tpu_test_wrapper.py.
    resolver = tpu_cluster_resolver.TPUClusterResolver(
        tpu=hasattr(FLAGS, "tpu") and FLAGS.tpu or "",
        zone=hasattr(FLAGS, "zone") and FLAGS.zone or None,
        project=hasattr(FLAGS, "project") and FLAGS.project or None,
    )

  # Only connect once per process, rather than per test method.
  if not _did_connect_to_cluster:
    if getattr(FLAGS, "tpu", "") or did_automatically_resolve:
      remote.connect_to_cluster(resolver)
      _did_connect_to_cluster = True
    _topology = tpu_strategy_util.initialize_tpu_system(resolver)

  device_assignment = None
  if use_single_core:
    device_assignment = device_assignment_lib.DeviceAssignment(
        _topology,
        core_assignment=device_assignment_lib.SINGLE_CORE_ASSIGNMENT)

  # Steps per run is only supported in TF 1.x.
  if tf2.enabled():
    strategy = tpu_lib.TPUStrategyV2(
        resolver,
        device_assignment,
        experimental_spmd_xla_partitioning=enable_spmd_xla_paritioning,
        **kwargs)
  else:
    strategy = tpu_lib.TPUStrategyV1(resolver, steps_per_run,
                                     device_assignment, **kwargs)
  if enable_packed_variable and enable_spmd_xla_paritioning:
    raise ValueError("Packed Variable is not compatible with SPMD mode.")
  strategy._enable_packed_variable_in_eager_mode = enable_packed_variable  # pylint: disable=protected-access

  return strategy
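# A hedged sketch of how the factory above is presumably bound (hypothetical
# wrapper name): use_single_core, steps_per_run, enable_packed_variable,
# enable_spmd_xla_paritioning, and kwargs are free variables, so they must be
# parameters of an enclosing creator function that returns
# _create_tpu_strategy as a closure, along the lines of:
def _get_tpu_strategy_creator(steps_per_run,
                              use_single_core=False,
                              enable_packed_variable=False,
                              enable_spmd_xla_paritioning=False,
                              **kwargs):

  def _create_tpu_strategy():
    ...  # Body as above, closing over the parameters.

  return _create_tpu_strategy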
def test_update_config_proto(self):
  resolver = get_tpu_cluster_resolver()
  remote.connect_to_cluster(resolver)
  tpu_strategy_util.initialize_tpu_system(resolver)
  strategy = tpu_lib.TPUStrategyV2(resolver)

  config_proto = config_pb2.ConfigProto()
  cluster_spec = server_lib.ClusterSpec({"worker": ["fake1", "fake2"]})
  with test.mock.patch.object(
      resolver, "cluster_spec", return_value=cluster_spec):
    new_config = strategy.update_config_proto(config_proto)

  # Verify cluster_def.
  self.assertProtoEquals(cluster_spec.as_cluster_def(),
                         new_config.cluster_def)

  # Verify isolate_session_state.
  self.assertTrue(new_config.isolate_session_state)
def test_model_parallelism(self):
  resolver = get_tpu_cluster_resolver()
  remote.connect_to_cluster(resolver)
  topology = tpu_strategy_util.initialize_tpu_system(resolver)
  device_assignment = device_assignment_lib.DeviceAssignment(
      topology, core_assignment=[[[0, 0, 0, 0], [0, 0, 0, 1]]])
  strategy = tpu_lib.TPUStrategyV2(
      resolver, experimental_device_assignment=device_assignment)

  with strategy.scope():
    v = variables.Variable(2.)
    with strategy.extended.experimental_logical_device(1):
      w = variables.Variable(3.)

  self.assertLen(strategy.experimental_local_results(v), 1)
  self.assertLen(strategy.experimental_local_results(w), 1)
  self.assertEqual("/job:localhost/replica:0/task:0/device:TPU:0",
                   strategy.experimental_local_results(v)[0].device)
  self.assertEqual("/job:localhost/replica:0/task:0/device:TPU:1",
                   strategy.experimental_local_results(w)[0].device)

  logical_devices = []

  @def_function.function
  def f(x):
    replica_ctx = distribution_strategy_context.get_replica_context()
    with replica_ctx.experimental_logical_device(0):
      y = v * x
    with replica_ctx.experimental_logical_device(1):
      z = w * y
    logical_devices.append((y.device, z.device))
    return z

  result = strategy.run(f, args=(5.,))

  self.assertEqual(
      [("/device:TPU_REPLICATED_CORE:0", "/device:TPU_REPLICATED_CORE:1")],
      logical_devices)

  with self.cached_session():
    self.evaluate(variables.global_variables_initializer())
    self.assertEqual(30., self.evaluate(result))
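# Note on the logical_devices assertion above: inside the compiled TPU
# computation, placement is recorded against the symbolic per-replica
# devices /device:TPU_REPLICATED_CORE:0 and :1 rather than the physical
# TPU:0/TPU:1 devices seen outside; the device assignment maps each logical
# core to a physical core when the computation is compiled.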
def get_tpu_strategy():
  resolver = get_tpu_cluster_resolver()
  remote.connect_to_cluster(resolver)
  tpu_strategy_util.initialize_tpu_system(resolver)
  return tpu_lib.TPUStrategyV2(resolver)