def testDistributeDatasetNotUsedWithClusterCoordinator(self):
  strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
      self.cluster_resolver)
  dataset = dataset_ops.DatasetV2.range(3)
  with self._assertRaisesUsageError():
    def_function.function(
        lambda: strategy.experimental_distribute_dataset(dataset))()
def testInModelAndCapture(self, source):
  file_path = os.path.join(self.get_temp_dir(), "text_file_initializer")

  model = self.Model(source, file_path)
  func_captures = model.use_table.get_concrete_function(
  ).graph.external_captures
  self.assertLen(func_captures, 2)
  self.assertTrue(
      any(model.table.resource_handle is t for t in func_captures))
  deferred_captures = model.use_table.get_concrete_function(
  ).graph.deferred_external_captures
  self.assertEmpty(deferred_captures)

  strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
      self.cluster_resolver)
  coordinator = coordinator_lib.ClusterCoordinator(strategy)
  with strategy.scope():
    distributed_model = self.Model("value", file_path)
  func_captures = distributed_model.use_table.get_concrete_function(
  ).graph.external_captures
  # One fewer external capture, since the table handle becomes a closure in
  # the deferred external captures.
  self.assertLen(func_captures, 1)
  self.assertFalse(
      any(model.table.resource_handle is t for t in func_captures))
  deferred_captures = distributed_model.use_table.get_concrete_function(
  ).graph.deferred_external_captures
  self.assertNotEmpty(deferred_captures)

  self.verifyWorkerLocalInstance(coordinator, distributed_model)
def testDistributeTableSaveAndServe(self, load, source):
  strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
      self.cluster_resolver)
  file_path = os.path.join(self.get_temp_dir(), "text_file_initializer")

  with strategy.scope():
    model = self.Model(source, file_path)

  model_dir = self.get_temp_dir()
  tf_save.save(model, model_dir)

  if load == "tf_load":
    load_fn = tf_load.load
  else:
    load_fn = keras_save.load_model

  loaded_without_strategy = load_fn(model_dir)
  loaded_func_captures_without_strategy = (
      loaded_without_strategy.use_table.get_concrete_function().graph
      .external_captures)
  loaded_func_deferred_captures_without_strategy = (
      loaded_without_strategy.use_table.get_concrete_function().graph
      .deferred_external_captures)
  self.assertLen(loaded_func_captures_without_strategy, 2)
  self.assertEmpty(loaded_func_deferred_captures_without_strategy)

  self.assertAllEqual(
      loaded_without_strategy.use_table(
          constant_op.constant([0, 1, 3], dtype=dtypes.int64)), [0, 1, -2])
def _model_compile(self, steps_per_execution=1, run_eagerly=False):

  class ResultAssertingCallback(callbacks_lib.Callback):

    def __init__(self):
      self._prev_epoch = -1

    def on_epoch_end(self, epoch, logs=None):
      logging.info("testModelFit: epoch=%r, logs=%r", epoch, logs)
      if epoch <= self._prev_epoch:
        raise RuntimeError("Epoch is supposed to be larger than the previous "
                           "one.")
      self._prev_epoch = epoch
      if (logs.get("loss", None) is None or
          not isinstance(logs["loss"], np.floating)):
        raise RuntimeError("loss is supposed to be in the logs and a float.")

  strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
      make_cluster(3, 2))
  with strategy.scope():
    model = sequential.Sequential([core_layers.Dense(10)])
  model.compile(
      gradient_descent.SGD(),
      loss="mse",
      steps_per_execution=steps_per_execution,
      run_eagerly=run_eagerly)
  return model, [ResultAssertingCallback()]
def setUp(self, num_workers, num_ps):
  super(BaseFaultToleranceTest, self).setUp()

  # Set the environment variable to prevent hanging upon job failure and
  # restart. Note that it defaults to 'use_caller' at Google, but defaults
  # to False in OSS.
  os.environ["GRPC_FAIL_FAST"] = "use_caller"

  self._cluster = multi_worker_test_base.create_multi_process_cluster(
      num_workers=num_workers, num_ps=num_ps, rpc_layer="grpc")
  self._cluster_def = self._cluster.cluster_resolver.cluster_spec().as_dict()
  self._cluster_def["chief"] = [
      "localhost:%d" % multi_worker_test_base.pick_unused_port()
  ]
  cluster_resolver = SimpleClusterResolver(
      server_lib.ClusterSpec(self._cluster_def), rpc_layer="grpc")

  # The strategy's constructor connects to the cluster.
  self.strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
      cluster_resolver)
  self.cluster_coord = cluster_coordinator.ClusterCoordinator(self.strategy)

  self.thread_coord = thread_coordinator.Coordinator(
      clean_stop_exception_types=[])
  self.num_workers = num_workers
def testBasicVariableWithAggregation(self):
  strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
      self.cluster_resolver)
  strategy.extended._allow_run_without_coordinator = True
  with strategy.scope():
    v = variables.Variable(
        initial_value=[0, 0, 0, 0, 0, 0, 0, 0],
        dtype=dtypes.float32,
        aggregation=variable_scope.VariableAggregation.SUM)

  if strategy.num_replicas_in_sync > 1:
    self.assertIsInstance(v, ps_values.AggregatingVariable)
  else:
    self.assertIsInstance(v, variables.Variable)

  def replica_fn():
    replica_id = distribution_strategy_context.get_replica_context(
    ).replica_id_in_sync_group
    val = array_ops.reshape(
        math_ops.cast(replica_id + 10, dtype=v.dtype), [1])
    v.assign(
        array_ops.concat(
            [val, constant_op.constant([1., 2., 3., 4., 5., 6., 7.])], 0))

  strategy.run(replica_fn)

  expected_result = np.arange(8.) * strategy.num_replicas_in_sync
  for i in range(strategy.num_replicas_in_sync):
    expected_result[0] = expected_result[0] + i + 10
  self.assertAllEqual(v, expected_result)
def testInteractionWithDeviceScope(self):
  strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
      self.cluster_resolver)

  # The strategy scope always wins.
  with strategy.scope():
    with ops.device("/job:ps/replica:0/task:1"):
      v0 = variables.Variable(initial_value=0.0)
    self.assertEqual(v0.device, "/job:ps/replica:0/task:0/device:CPU:0")

    with ops.device("/job:ps/replica:0/task:0"):
      v1 = variables.Variable(initial_value=0.0)
    self.assertEqual(v1.device, "/job:ps/replica:0/task:1/device:CPU:0")

  with ops.device("/job:ps/replica:0/task:1"):
    with strategy.scope():
      v2 = variables.Variable(initial_value=0.0)
      self.assertEqual(v2.device, "/job:ps/replica:0/task:2/device:CPU:0")

      v3 = variables.Variable(initial_value=0.0)
      self.assertEqual(v3.device, "/job:ps/replica:0/task:0/device:CPU:0")
def testDefaultNoPartition(self):
  strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
      self.cluster_resolver)
  with strategy.scope():
    v = variables.Variable([0, 1, 2, 3])

  self.assertIsInstance(v, variables.Variable)
def testBasic(self):
  strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
      self.cluster_resolver, sharded_variable.FixedShardsPartitioner(2))
  with strategy.scope():
    init1 = init_ops_v2.Constant([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
    v1 = variables.Variable(
        initial_value=lambda: init1(shape=(5, 2), dtype=dtypes.int64),
        shape=(5, 2),
        dtype=dtypes.int64)

    init2 = init_ops_v2.Constant([0, 1, 2, 3, 4, 5])
    v2 = variables.Variable(
        initial_value=lambda: init2(shape=(6, 1), dtype=dtypes.int64),
        shape=(6, 1),
        dtype=dtypes.int64)

  self.assertIsInstance(v1, sharded_variable.ShardedVariable)
  self.assertLen(v1.variables, 2)
  self.assertRegex(v1.variables[0].device, "/job:ps/replica:0/task:0")
  self.assertRegex(v1.variables[1].device, "/job:ps/replica:0/task:1")
  self.assertAllEqual(v1.variables[0], [[0, 1], [2, 3], [4, 5]])
  self.assertAllEqual(v1.variables[1], [[6, 7], [8, 9]])

  self.assertIsInstance(v2, sharded_variable.ShardedVariable)
  self.assertLen(v2.variables, 2)
  self.assertRegex(v2.variables[0].device, "/job:ps/replica:0/task:0")
  self.assertRegex(v2.variables[1].device, "/job:ps/replica:0/task:1")
  self.assertAllEqual(v2.variables[0], [[0], [1], [2]])
  self.assertAllEqual(v2.variables[1], [[3], [4], [5]])
def testCreateInsideTFFunction(self):
  strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
      self.cluster_resolver, sharded_variable.FixedShardsPartitioner(2))

  collection = []

  @def_function.function
  def create_vars():
    if not collection:
      identity = init_ops_v2.Identity()
      v1 = variables.Variable([[1., 0.], [0., 1.]], dtype=dtypes.float32)
      v2 = variables.Variable(lambda: identity((2, 2), dtypes.float32))
      v3 = variables.Variable(
          lambda: identity((2, 2), dtypes.float32),
          dtype=dtypes.float32,
          shape=(2, 2))
      collection.extend([v1, v2, v3])

  with strategy.scope():
    create_vars()
    for v in collection:
      self.assertIsInstance(v, sharded_variable.ShardedVariable)
      self.assertLen(v.variables, 2)
      self.assertRegex(v.variables[0].device, "/job:ps/replica:0/task:0")
      self.assertRegex(v.variables[1].device, "/job:ps/replica:0/task:1")
      self.assertAllEqual(v.variables[0], [[1., 0.]])
      self.assertAllEqual(v.variables[1], [[0., 1.]])
def testArbitraryCurrentTaskType(self):
  cluster_def = multi_worker_test_base.create_cluster_spec(
      num_workers=1, num_ps=1, has_chief=True)
  cluster_resolver = SimpleClusterResolver(
      ClusterSpec(cluster_def), rpc_layer="grpc", task_type="foobar")
  with self.assertRaisesRegexp(ValueError, "Unrecognized task_type: foobar"):
    parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)
def fn(functions_scheduled_event, test_finished_event):
  # TODO(b/170664373): This is needed for TF2 parameter server training in
  # OSS. Remove this when resolved.
  os.environ["GRPC_FAIL_FAST"] = "use_caller"

  cluster_resolver = TFConfigClusterResolver()
  if cluster_resolver.task_type != "chief":
    utils.start_server(cluster_resolver, "grpc")
  strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
      cluster_resolver)
  ps_coordinator = coordinator_lib.ClusterCoordinator(strategy)

  with strategy.scope():
    v = variables.Variable(initial_value=0, dtype=dtypes.int32)

  @def_function.function
  def worker_fn():
    # An ever-running function.
    for _ in math_ops.range(100000):
      v.assign_add(1)

  # Keep the two workers occupied.
  ps_coordinator.schedule(worker_fn)
  ps_coordinator.schedule(worker_fn)
  # Now the main process can terminate.
  functions_scheduled_event.set()

  # Verify that join and schedule indeed raise UnavailableError.
  try:
    if test_join:
      ps_coordinator.join()
    if test_schedule:
      while ps_coordinator.cluster._closure_queue._error is None:
        time.sleep(1)
      ps_coordinator.schedule(worker_fn)
  except errors.UnavailableError:
    # The following verifies that after the PS fails, continuing to execute
    # functions on the workers fails and indicates a PS failure.
    for worker_id in range(3):
      with ops.device("/job:worker/replica:0/task:{}".format(worker_id)):
        try:
          # Executing a function after the PS fails should result in a PS
          # failure.
          worker_fn()
        except Exception as e:  # pylint: disable=broad-except
          if coordinator_lib._is_ps_failure(e):
            if worker_id < 2:
              continue
            logging.info("_test_translate_ps_failure_error ends properly.")
            # Now we can safely exit the test.
            test_finished_event.set()
            return
          raise RuntimeError("Executing a function after the PS fails should "
                             "result in a PS failure.")

  raise RuntimeError("UnavailableError is supposed to be raised.")
def _model_compile(self,
                   steps_per_execution=1,
                   run_eagerly=False,
                   with_normalization_layer=False):

  class ResultAssertingCallback(callbacks_lib.Callback):

    def __init__(self):
      self._prev_epoch = -1

    def on_epoch_end(self, epoch, logs=None):
      logging.info("testModelFit: epoch=%r, logs=%r", epoch, logs)
      if epoch <= self._prev_epoch:
        raise RuntimeError("Epoch is supposed to be larger than the previous "
                           "one.")
      self._prev_epoch = epoch
      if (logs.get("loss", None) is None or
          not isinstance(logs["loss"], np.floating)):
        raise RuntimeError("loss is supposed to be in the logs and a float.")

  strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
      make_cluster(3, 2),
      variable_partitioner=sharded_variable.FixedShardsPartitioner(2))
  with strategy.scope():
    model = sequential.Sequential([core_layers.Dense(10)])
    if with_normalization_layer:
      norm = keras.layers.BatchNormalization(
          axis=-1, input_shape=(4, 4, 3), momentum=0.8)
      model.add(norm)
  model.compile(
      gradient_descent.SGD(),
      loss="mse",
      steps_per_execution=steps_per_execution,
      run_eagerly=run_eagerly)
  return model, [ResultAssertingCallback()]
def testPartitionWhenLackOfInfo(self):
  strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
      self.cluster_resolver, partitioned_variables.fixed_size_partitioner(2))
  with strategy.scope():
    initializer = init_ops_v2.Constant([0, 1, 2, 3])
    # Shape is not explicitly specified.
    v1 = variables.Variable(
        initial_value=lambda: initializer(shape=(4,), dtype=dtypes.int64),
        dtype=dtypes.int64)
    # Dtype is not explicitly specified.
    v2 = variables.Variable(
        initial_value=lambda: initializer(shape=(4,), dtype=dtypes.int64),
        shape=(4,))
    # Neither shape nor dtype is explicitly specified.
    v3 = variables.Variable(
        initial_value=lambda: initializer(shape=(4,), dtype=dtypes.int64))

  for v in [v1, v2, v3]:
    self.assertIsInstance(v, sharded_variable.ShardedVariable)
    self.assertLen(v.variables, 2)
    self.assertRegex(v.variables[0].device, "/job:ps/replica:0/task:0")
    self.assertRegex(v.variables[1].device, "/job:ps/replica:0/task:1")
    self.assertAllEqual(v.variables[0], [0, 1])
    self.assertAllEqual(v.variables[1], [2, 3])
def _model_compile(self,
                   strategy,
                   steps_per_execution=1,
                   run_eagerly=False,
                   with_normalization_layer=False):

  class ResultAssertingCallback(callbacks_lib.Callback):

    def __init__(self):
      self._prev_epoch = -1
      self._loss_to_compare_against = 2  # Empirical initial value

    def on_epoch_end(self, epoch, logs=None):
      logging.info("testModelFit: epoch=%r, logs=%r", epoch, logs)
      if epoch <= self._prev_epoch:
        raise RuntimeError("Epoch is supposed to be larger than the previous "
                           "one.")
      self._prev_epoch = epoch
      is_loss_float = (
          logs.get("loss", None) is not None and
          isinstance(logs["loss"], (float, np.floating)))
      if not is_loss_float:
        raise RuntimeError("loss is supposed to be in the logs and a float.")
      if epoch == 0 or epoch == 9:
        # Make sure the loss of the first epoch is below the empirical initial
        # value, and that the loss of the last epoch is smaller than that of
        # the first epoch.
        if logs["loss"] > self._loss_to_compare_against:
          raise RuntimeError(
              "loss at epoch {} is larger than expected.".format(epoch))
        self._loss_to_compare_against = logs["loss"]

    def on_train_end(self, logs=None):
      if self._prev_epoch != 9:
        raise RuntimeError("Unexpected last epoch: {}".format(
            self._prev_epoch))

  # TODO(b/182193218): Use ParameterServerStrategy as a proper strategy
  # combination.
  if strategy == "ParameterServerStrategy":
    gpu_devices = config.list_physical_devices("GPU")
    if len(gpu_devices) > 1:
      self.skipTest("b/178452835: Multi-GPUs not supported in "
                    "ParameterServerStrategy.")
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        multi_worker_testing_utils.make_parameter_server_cluster(3, 2),
        variable_partitioner=sharded_variable.FixedShardsPartitioner(2))

  with strategy.scope():
    model = sequential.Sequential([core_layers.Dense(10)])
    if with_normalization_layer:
      norm = keras.layers.BatchNormalization(
          axis=-1, input_shape=(4, 4, 3), momentum=0.8)
      model.add(norm)
  model.compile(
      gradient_descent.SGD(),
      loss="mse",
      steps_per_execution=steps_per_execution,
      run_eagerly=run_eagerly)
  return model, [ResultAssertingCallback()]
def testSparselyReadForEmbeddingLookup(self):
  strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
      self.cluster_resolver)

  class FakeModel(module.Module):

    def __init__(self):
      self._var0 = variables.Variable([1.0, 2.0, 3.0, 4.0])
      self._var1 = variables.Variable([5.0, 6.0, 7.0, 8.0])

    @def_function.function(input_signature=[
        tensor_spec.TensorSpec(shape=[2], dtype=dtypes.int32, name="inputs")
    ])
    def func(self, x):
      return embedding_ops.embedding_lookup([self._var0, self._var1], x)

  with strategy.scope():
    model = FakeModel()

  # Assert that the ResourceGather op, rather than Gather, appears in the
  # training function.
  found_resource_gather = False
  found_gather = False

  for n in model.func.get_concrete_function().graph.as_graph_def().node:
    if n.op == "ResourceGather":
      found_resource_gather = True
    elif n.op == "Gather":
      found_gather = True
  self.assertTrue(found_resource_gather)
  self.assertFalse(found_gather)

  # Assert that the ResourceGather op, rather than Gather, appears in the
  # saved_model.
  found_resource_gather = False
  found_gather = False

  tmp_dir = self.get_temp_dir()
  save.save(model, tmp_dir, signatures=model.func)

  with gfile.Open("%s/saved_model.pb" % tmp_dir, "rb") as f:
    saved_model_proto = saved_model_pb2.SavedModel().FromString(f.read())

  for function in saved_model_proto.meta_graphs[0].graph_def.library.function:
    for n in function.node_def:
      if n.op == "ResourceGather":
        found_resource_gather = True
        resource_gather_device = n.device
      elif n.op == "Gather":
        found_gather = True
  self.assertTrue(found_resource_gather)
  self.assertFalse(found_gather)

  # We also assert that the colocate_with in embedding_ops does not result in
  # a hard-coded device string.
  self.assertEmpty(resource_gather_device)
def testLessThanOneWorker(self):
  cluster_def = multi_worker_test_base.create_cluster_spec(
      num_workers=0, num_ps=1, has_chief=True)
  cluster_resolver = SimpleClusterResolver(
      ClusterSpec(cluster_def), rpc_layer="grpc", task_type="ps", task_id=0)
  with self.assertRaisesRegexp(ValueError,
                               "There must be at least one worker."):
    parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)
def _create_strategy(self, num_shards):
  if num_shards > 1:
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver,
        variable_partitioner=sharded_variable.FixedShardsPartitioner(
            num_shards))
  else:
    strategy = ds_context._get_default_strategy()
  return strategy
def testArbitraryCurrentTaskType(self):
  cluster_def = multi_worker_test_base._create_cluster(
      num_workers=1, num_ps=1)
  cluster_def["chief"] = [
      "localhost:%d" % multi_worker_test_base.pick_unused_port()
  ]
  cluster_resolver = SimpleClusterResolver(
      ClusterSpec(cluster_def), rpc_layer="grpc", task_type="foobar")
  with self.assertRaisesRegexp(ValueError, "Unrecognized task_type: foobar"):
    parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)
def testArbitraryJobName(self):
  cluster_def = multi_worker_test_base.create_cluster_spec(
      num_workers=1, num_ps=1, has_chief=True)
  cluster_def["some_arbitrary_name"] = [
      "localhost:%d" % multi_worker_test_base.pick_unused_port()
  ]
  cluster_resolver = SimpleClusterResolver(
      ClusterSpec(cluster_def), rpc_layer="grpc")
  with self.assertRaisesRegexp(ValueError, "Disallowed task type found in"):
    parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)
def test_load_with_partitioner_raises_error(self):
  model = self.Model()
  model_dir = self.get_temp_dir()
  tf.saved_model.save(model, model_dir)

  strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
      self.cluster_resolver, tf1.fixed_size_partitioner(2))
  with self.assertRaisesRegex(ValueError, "`variable_partitioner`"):
    with strategy.scope():
      tf.saved_model.load(model_dir)
def test_load_with_partitioner_raises_error(self):
  model = self.Model()
  model_dir = self.get_temp_dir()
  tf.saved_model.save(model, model_dir)

  strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
      self.cluster_resolver, tf1.fixed_size_partitioner(2))
  with self.assertRaises(errors_impl.InvalidArgumentError):
    with strategy.scope():
      tf.saved_model.load(model_dir)
def testDistributeDatasetFromFunctionNotUsedWithClusterCoordinator(self):
  strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
      self.cluster_resolver)

  def dataset_fn(_):
    return dataset_ops.DatasetV2.range(3)

  with self._assertRaisesUsageError():
    def_function.function(
        lambda: strategy.distribute_datasets_from_function(dataset_fn))()
def make_client(num_workers, num_ps):
  cluster_def = multi_worker_test_base.create_in_process_cluster(
      num_workers=num_workers, num_ps=num_ps, rpc_layer="grpc")
  cluster_def["chief"] = [
      "localhost:%d" % multi_worker_test_base.pick_unused_port()
  ]
  cluster_resolver = SimpleClusterResolver(
      ClusterSpec(cluster_def), rpc_layer="grpc")
  return client_lib.Client(
      parameter_server_strategy_v2.ParameterServerStrategyV2(
          cluster_resolver))
def test_sharded_variable(self):
  strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
      self.cluster_resolver, tf1.fixed_size_partitioner(2))
  model_dir = self.get_temp_dir()
  with strategy.scope():
    m = self.Model()
    self.assertIsInstance(m.v1, sharded_variable.ShardedVariable)
    m.train()
    tf.saved_model.save(m, model_dir)

  self.assertAllEqual(
      self.load_and_run_v1(model_dir, {"x": 1}), [6, 6, 6, 6])
def testLessThanOneWorker(self):
  cluster_def = multi_worker_test_base._create_cluster(
      num_workers=0, num_ps=1)
  cluster_def["chief"] = [
      "localhost:%d" % multi_worker_test_base.pick_unused_port()
  ]
  cluster_resolver = SimpleClusterResolver(
      ClusterSpec(cluster_def), rpc_layer="grpc", task_type="ps", task_id=0)
  with self.assertRaisesRegexp(ValueError,
                               "There must be at least one worker."):
    parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)
def make_coordinator(num_workers, num_ps):
  # TODO(rchao): Test the internal rpc_layer version.
  cluster_def = multi_worker_test_base.create_in_process_cluster(
      num_workers=num_workers, num_ps=num_ps, rpc_layer='grpc')
  cluster_def['chief'] = [
      'localhost:%d' % multi_worker_test_base.pick_unused_port()
  ]
  cluster_resolver = SimpleClusterResolver(
      ClusterSpec(cluster_def), rpc_layer='grpc')
  strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
      cluster_resolver)
  return coordinator_lib.ClusterCoordinator(strategy)
def testDistributeDatasetUsedDirectly(self):
  strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
      self.cluster_resolver)
  dataset = dataset_ops.DatasetV2.range(3)
  distributed_dataset = strategy.experimental_distribute_dataset(dataset)
  with self.assertRaises(ValueError):
    iter(distributed_dataset)

  distributed_dataset = strategy.distribute_datasets_from_function(
      lambda: dataset)
  with self.assertRaises(ValueError):
    iter(distributed_dataset)
def testRunUsedWithTestOnlyMode(self):
  strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
      self.cluster_resolver)
  strategy.extended._allow_run_without_coordinator = True
  dataset = dataset_ops.DatasetV2.range(15)
  with strategy.scope():
    v = variables.Variable(1, dtype=dtypes.int64)

  def step_fn(iterator):
    return next(iterator) + v

  strategy.run(step_fn, args=(iter(dataset),))
def testRunNotUsedWithClusterCoordinator(self):
  strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
      self.cluster_resolver)
  dataset = dataset_ops.DatasetV2.range(8)
  with strategy.scope():
    v = variables.Variable(1, dtype=dtypes.int64)

  def step_fn(iterator):
    return next(iterator) + v

  with self._assertRaisesUsageWarningWithSchedule():
    strategy.run(step_fn, args=(iter(dataset),))