def test_basic_run(self, input_arg):
  has_chief = False
  cluster_spec = multi_worker_test_base.create_cluster_spec(
      has_chief=has_chief,
      num_workers=CLUSTER_SIZE)
  maintenance_event = multi_process_runner.manager().Event()
  training_finished = multi_process_runner.manager().Event()
  training_restarted = multi_process_runner.manager().Event()
  checkpoint_dir = os.path.join(self.get_temp_dir(), 'fh_ckpt/')

  if _is_oss():
    rpc_layer = 'grpc'
  else:
    rpc_layer = 'grpc+loas'

  mpr = multi_process_runner.MultiProcessRunner(
      self.worker_fn,
      cluster_spec,
      args=(checkpoint_dir, cluster_spec, input_arg, maintenance_event,
            training_finished, False, training_restarted),
      rpc_layer=rpc_layer,
      return_output=True,
      dependence_on_chief=has_chief)

  logging.info('Cluster starting.')
  mpr.start()
  while (not maintenance_event.is_set()) and (
      not training_finished.is_set()):
    time.sleep(1)

  # Wait for the whole cluster to exit, with a timeout. The extra buffer
  # mitigates the fact that the per-step time in this test is very short.
  waiting_time = 0
  exit_process_count = 0
  while exit_process_count != CLUSTER_SIZE and waiting_time < 15:
    exit_process_count = 0
    for worker_id in range(CLUSTER_SIZE):
      if not mpr.process_exists('worker', worker_id):
        exit_process_count += 1
    waiting_time += 1
    time.sleep(1)

  if waiting_time >= 15:
    raise RuntimeError('Waited long but at least one worker still exists. '
                       'Considering the size of our model, this should not '
                       'happen.')

  if not training_finished.is_set():
    logging.info('restarting workers')
    training_restarted.set()
    for worker_id in range(CLUSTER_SIZE):
      mpr.start_single_process('worker', worker_id, cluster_spec)
    logging.info('workers restarted')

  mpr.join(timeout=250)

  self.assertTrue(training_finished.is_set())
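# Note: the tests in this section branch on _is_oss() to pick an RPC layer,
# but the helper itself is defined elsewhere in the test file. The following
# is only a hypothetical placeholder sketch of its shape, not the actual
# implementation; the real detection logic (build flag, environment check,
# etc.) may differ.
def _is_oss():
  """Returns whether the test is running in the open-source environment."""
  # Placeholder: the real check lives in the actual test file.
  return True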
def test_grace_period_continue_training(self):
  grace_period = 5
  has_chief = False
  cluster_spec = multi_worker_test_base.create_cluster_spec(
      has_chief=has_chief,
      num_workers=CLUSTER_SIZE)
  training_started_event = multi_process_runner.manager().Event()
  training_restarted = multi_process_runner.manager().Event()
  training_finished = multi_process_runner.manager().Event()
  checkpoint_dir = os.path.join(self.get_temp_dir(), 'fh_ckpt')

  if _is_oss():
    rpc_layer = 'grpc'
  else:
    rpc_layer = 'grpc+loas'

  termination_config = failure_handling.TerminationConfig(
      time_till_termination=grace_period)

  mpr = multi_process_runner.MultiProcessRunner(
      self.worker_fn,
      cluster_spec,
      args=(checkpoint_dir, cluster_spec, [training_started_event], None,
            training_restarted, training_finished, termination_config),
      rpc_layer=rpc_layer,
      return_output=True,
      dependence_on_chief=has_chief)

  logging.info('Cluster starting.')
  mpr.start()
  while not training_started_event.is_set():
    time.sleep(1)

  killed_worker = random.randrange(0, CLUSTER_SIZE)
  logging.info('sending SIGTERM')
  os.kill(mpr.get_process_id('worker', killed_worker), signal.SIGTERM)
  logging.info('SIGTERM sent')

  # Wait for the whole cluster to exit within the given grace period (plus a
  # buffer, since the per-step time here is very small).
  waiting_time = 0
  exit_process_count = 0
  while exit_process_count != CLUSTER_SIZE and waiting_time < grace_period + 10:
    exit_process_count = 0
    for worker_id in range(CLUSTER_SIZE):
      if not mpr.process_exists('worker', worker_id):
        exit_process_count += 1
    waiting_time += 1
    time.sleep(1)

  if waiting_time >= grace_period + 10:
    raise RuntimeError('Waited beyond the grace period but at least one '
                       'worker has not exited.')

  logging.info('restarting workers')
  training_restarted.set()
  for worker_id in range(CLUSTER_SIZE):
    mpr.start_single_process('worker', worker_id, cluster_spec)
  logging.info('workers restarted')

  mpr.join(timeout=250)
def test_basic_run(self):
  has_chief = False
  cluster_spec = multi_worker_test_base.create_cluster_spec(
      has_chief=has_chief,
      num_workers=CLUSTER_SIZE)
  maintenance_event = multi_process_runner.manager().Event()
  training_finished = multi_process_runner.manager().Event()
  checkpoint_dir = os.path.join(self.get_temp_dir(), 'fh_ckpt')

  if _is_oss():
    rpc_layer = 'grpc'
  else:
    rpc_layer = 'grpc+loas'

  mpr = multi_process_runner.MultiProcessRunner(
      self.worker_fn,
      cluster_spec,
      args=(checkpoint_dir, cluster_spec, maintenance_event,
            training_finished),
      rpc_layer=rpc_layer,
      return_output=True,
      dependence_on_chief=has_chief)

  logging.info('Cluster starting.')
  mpr.start()
  while (not maintenance_event.is_set()) and (not training_finished.is_set()):
    time.sleep(1)

  time.sleep(5)

  if not training_finished.is_set():
    logging.info('restarting workers')
    for worker_id in range(CLUSTER_SIZE):
      mpr.start_single_process('worker', worker_id, cluster_spec)
    logging.info('workers restarted')

  stdout = mpr.join().stdout
  if maintenance_event.is_set():
    all_start_point = []
    for msg in stdout:
      matched_group = re.search(r'.*Start training at (\d+)', msg)
      if matched_group:
        all_start_point.append(int(matched_group.group(1)))

    # Remove duplicate logs created due to the presence of multiple workers.
    start_points = all_start_point[::CLUSTER_SIZE]

    if len(start_points) > 1:
      # Assert that after restarting, we don't repeat previous training steps.
      self.assertNotEqual(start_points[-1], 0)
def test_creating_variable(self):
  # See PeerFailureTest.test_creating_variable

  def worker_fn(attempts):
    context.context().enable_coordination_service(COORDINATION_SERVICE)
    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
    task_id, attempt = get_attempt(strategy, attempts)
    with strategy.scope():
      tf.Variable(1.)
      # worker-1 dies here.
      if attempt == 1 and task_id == 1:
        quick_exit(1)
      v = tf.Variable(tf.random.uniform(()))
    return v.read_value().numpy()

  cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
  attempts = multi_process_runner.manager().dict()
  mpr = multi_process_runner.MultiProcessRunner(
      worker_fn,
      cluster_spec,
      rpc_layer=RPC_PROTOCOL,
      args=(attempts,),
      auto_restart=True)
  mpr.start()
  results = mpr.join(timeout=90).return_value
  self.assertEqual(results[0], results[1])
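# The peer-failure tests above and below call two helpers, get_attempt and
# quick_exit, that are defined elsewhere in the test file. Below is a hedged,
# minimal sketch of what they are assumed to do: record the per-task restart
# attempt in the shared manager dict, and kill the process without running
# atexit handlers. The real implementations may differ in detail.
def get_attempt(strategy, attempts):
  """Returns (task_id, attempt) and records this attempt in the shared dict."""
  task_type = strategy.cluster_resolver.task_type
  task_id = strategy.cluster_resolver.task_id
  attempts[(task_type, task_id)] = attempts.get((task_type, task_id), 0) + 1
  return task_id, attempts[(task_type, task_id)]


def quick_exit(code):
  """Exits the process immediately, skipping exit handlers."""
  os._exit(code)  # pylint: disable=protected-access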
def test_reduce_small_tensor_broken(self):
  # This test simulates the case when a worker fails before or during reducing
  # a small tensor, e.g. reading a metric.
  #
  # Note that this is a rather corner case and only happens when all of the
  # following conditions are met:
  #   - There are two workers.
  #   - They're reducing a small tensor. The definition of small varies
  #     per platform.
  #   - They're reducing a single tensor. Batched all-reduces are not affected.
  #   - It must be worker-1 that fails.

  def worker_fn(attempts):
    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
    task_id, attempt = get_attempt(strategy, attempts)
    value = tf.identity([1.])
    strategy.reduce("sum", value, axis=None)
    # worker-1 dies here.
    if attempt == 1 and task_id == 1:
      quick_exit(1)
    strategy.reduce("sum", value, axis=None)

  cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
  attempts = multi_process_runner.manager().dict()
  mpr = multi_process_runner.MultiProcessRunner(
      worker_fn,
      cluster_spec,
      args=(attempts,),
      auto_restart=True)
  mpr.start()
  # TODO(b/151232436): worker-0 should raise Unavailable instead of hanging.
  # Now after worker-1 fails, worker-0 waits on the second reduce; after
  # worker-1 recovers, worker-1 waits on the first reduce.
  with self.assertRaises(multi_process_runner.SubprocessTimeoutError):
    mpr.join(timeout=30)
def test_reduce_small_tensor(self):
  # See PeerFailureTest.test_reduce_small_tensor

  def worker_fn(attempts):
    context.context().enable_coordination_service(COORDINATION_SERVICE)
    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
    task_id, attempt = get_attempt(strategy, attempts)
    value = tf.identity([1.])
    strategy.reduce("sum", value, axis=None)
    # worker-1 dies here.
    if attempt == 1 and task_id == 1:
      quick_exit(1)
    return strategy.reduce("sum", value, axis=None).numpy()

  cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
  attempts = multi_process_runner.manager().dict()
  mpr = multi_process_runner.MultiProcessRunner(
      worker_fn,
      cluster_spec,
      rpc_layer=RPC_PROTOCOL,
      args=(attempts,),
      auto_restart=True)
  mpr.start()
  results = mpr.join(timeout=90).return_value
  self.assertAllEqual(results, [[2.], [2.]])
def test_creating_variable_broken(self):
  # This test simulates the case when a worker fails before or during creating
  # a variable. Creating a variable involves broadcasting the initial value
  # from the first replica to all replicas.

  def worker_fn(attempts):
    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
    task_id, attempt = get_attempt(strategy, attempts)
    with strategy.scope():
      tf.Variable(1.)
      # worker-1 dies here.
      if attempt == 1 and task_id == 1:
        quick_exit(1)
      v = tf.Variable(tf.random.uniform(()))
    return v.read_value().numpy()

  cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
  attempts = multi_process_runner.manager().dict()
  mpr = multi_process_runner.MultiProcessRunner(
      worker_fn,
      cluster_spec,
      args=(attempts,),
      auto_restart=True)
  mpr.start()
  # TODO(b/151232436): worker-0 should raise Unavailable instead of hanging.
  # Now after worker-1 fails, worker-0 waits on the second variable creation;
  # after worker-1 recovers, worker-1 waits on the first variable creation.
  with self.assertRaises(multi_process_runner.SubprocessTimeoutError):
    mpr.join(timeout=30)
def test_grace_period_continue_training(self, input_arg):
  grace_period = 7
  has_chief = False
  cluster_spec = multi_worker_test_base.create_cluster_spec(
      has_chief=has_chief,
      num_workers=CLUSTER_SIZE)
  checkpoint_dir = os.path.join(self.get_temp_dir(), 'fh_ckpt/')
  maintenance_event = multi_process_runner.manager().Event()
  training_finished = multi_process_runner.manager().Event()
  training_restarted = multi_process_runner.manager().Event()

  if _is_oss():
    rpc_layer = 'grpc'
  else:
    rpc_layer = 'grpc+loas'

  termination_config = failure_handling.TerminationConfig(
      grace_period=grace_period)

  mpr = multi_process_runner.MultiProcessRunner(
      self.worker_fn,
      cluster_spec,
      args=(checkpoint_dir, cluster_spec, input_arg, maintenance_event,
            training_finished, False, training_restarted, termination_config),
      rpc_layer=rpc_layer,
      return_output=True,
      dependence_on_chief=has_chief)

  logging.info('Cluster starting.')
  mpr.start()
  while (not maintenance_event.is_set()) and (
      not training_finished.is_set()):
    time.sleep(1)

  # This sleep mitigates the fact that the per-step time in this test is very
  # short.
  time.sleep(grace_period + 10)

  if not training_finished.is_set():
    logging.info('restarting workers')
    training_restarted.set()
    for worker_id in range(CLUSTER_SIZE):
      mpr.start_single_process('worker', worker_id, cluster_spec)
    logging.info('workers restarted')

  mpr.join(timeout=250)

  self.assertTrue(training_finished.is_set())
def test_grace_period_continue_training(self, input_arg):
  grace_period = 5
  has_chief = False
  cluster_spec = multi_worker_test_base.create_cluster_spec(
      has_chief=has_chief,
      num_workers=CLUSTER_SIZE)
  training_started_event = multi_process_runner.manager().Event()
  training_restarted = multi_process_runner.manager().Event()
  training_finished = multi_process_runner.manager().Event()
  checkpoint_dir = os.path.join(self.get_temp_dir(), 'fh_ckpt')

  if _is_oss():
    rpc_layer = 'grpc'
  else:
    rpc_layer = 'grpc+loas'

  termination_config = failure_handling.TerminationConfig(
      grace_period=grace_period)

  mpr = multi_process_runner.MultiProcessRunner(
      self.worker_fn,
      cluster_spec,
      args=(checkpoint_dir, cluster_spec, input_arg, [training_started_event],
            None, training_restarted, training_finished, termination_config),
      rpc_layer=rpc_layer,
      return_output=True,
      dependence_on_chief=has_chief)

  logging.info('Cluster starting.')
  mpr.start()
  while not training_started_event.is_set():
    time.sleep(1)

  killed_worker = random.randrange(0, CLUSTER_SIZE)
  logging.info('sending SIGTERM')
  os.kill(mpr.get_process_id('worker', killed_worker), signal.SIGTERM)
  logging.info('SIGTERM sent')

  raise_if_not_all_exit(grace_period, mpr)

  logging.info('restarting workers')
  training_restarted.set()
  for worker_id in range(CLUSTER_SIZE):
    mpr.start_single_process('worker', worker_id, cluster_spec)
  logging.info('workers restarted')
  mpr.join(timeout=250)
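# Several of the newer tests above and below call raise_if_not_all_exit, which
# replaces the inline wait loops of the older tests but is not shown in this
# section. The following is a hedged sketch of what it is assumed to do,
# modeled on those inline loops; the exact signature and timeout policy in the
# real test file may differ.
def raise_if_not_all_exit(grace_period, mpr):
  """Waits for every worker to exit within the grace period plus a buffer."""
  timeout = max(grace_period + 15, 40)  # Assumed buffer, mirroring the loops.
  waiting_time = 0
  exit_process_count = 0
  while exit_process_count != CLUSTER_SIZE and waiting_time < timeout:
    exit_process_count = 0
    for worker_id in range(CLUSTER_SIZE):
      if not mpr.process_exists('worker', worker_id):
        exit_process_count += 1
    waiting_time += 1
    time.sleep(1)

  if waiting_time >= timeout:
    raise RuntimeError('Waited long but at least one worker still exists. '
                       'Considering the size of our model, this should not '
                       'happen.')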
def test_preemption_checkpointing(self, input_arg):
  has_chief = False
  cluster_spec = multi_worker_test_base.create_cluster_spec(
      has_chief=has_chief,
      num_workers=CLUSTER_SIZE)
  training_started_event = multi_process_runner.manager().Event()
  training_restarted = multi_process_runner.manager().Event()
  training_finished = multi_process_runner.manager().Event()
  checkpoint_dir = os.path.join(self.get_temp_dir(), 'fh_ckpt')

  if _is_oss():
    rpc_layer = 'grpc'
  else:
    rpc_layer = 'grpc+loas'

  mpr = multi_process_runner.MultiProcessRunner(
      self.worker_fn,
      cluster_spec,
      args=(checkpoint_dir, cluster_spec, input_arg, [training_started_event],
            None, training_restarted, training_finished),
      rpc_layer=rpc_layer,
      return_output=True,
      dependence_on_chief=has_chief)

  logging.info('Cluster starting.')
  mpr.start()
  while not training_started_event.is_set():
    time.sleep(1)

  logging.info('sending sigterm')
  killed_worker = random.randrange(0, CLUSTER_SIZE)
  os.kill(mpr.get_process_id('worker', killed_worker), signal.SIGTERM)

  logging.info('sigterm sent')
  time.sleep(5)

  logging.info('restarting workers')
  training_restarted.set()
  for worker_id in range(CLUSTER_SIZE):
    mpr.start_single_process('worker', worker_id, cluster_spec)
  logging.info('workers restarted')

  mpr.join(timeout=270)
def __init__(self, cluster_resolver, stream_output=False,
             collective_leader=None):
  self._cluster_resolver = cluster_resolver
  self._cluster_spec = cluster_resolver.cluster_spec().as_dict()
  self._rpc_layer = cluster_resolver.rpc_layer
  self._stream_output = stream_output
  self._start_events = {}
  self._finish_events = {}
  self._mpr_manager = multi_process_runner.manager()

  def task_function(start_events, finish_events):
    cluster_resolver = TFConfigClusterResolver()
    cluster_spec = cluster_resolver.cluster_spec()
    task_type = cluster_resolver.task_type
    task_id = cluster_resolver.task_id
    rpc_layer = cluster_resolver.rpc_layer

    # TODO(yuefengz): support GPU clusters.
    server_config = config_pb2.ConfigProto()
    server_config.device_count['GPU'] = 0

    if collective_leader:
      server_config.experimental.collective_group_leader = collective_leader
      server_config.experimental.collective_nccl = False

      logging.info(
          'Enabling collective ops with cluster_spec = %r, task_type = %r, '
          'task_id = %r, rpc_layer = %r, collective_leader = %s',
          cluster_spec, task_type, task_id, rpc_layer, collective_leader)
    else:
      logging.info(
          'Starting server with cluster_spec = %r, task_type = %r, '
          'task_id = %r, rpc_layer = %r', cluster_spec, task_type, task_id,
          rpc_layer)

    server_lib.Server(
        cluster_spec,
        job_name=task_type,
        protocol=rpc_layer,
        task_index=task_id,
        config=server_config,
        start=True)

    start_event = start_events[task_type][task_id]
    start_event.set()

    finish_event = finish_events[task_type][task_id]
    finish_event.wait()

    os._exit(0)  # pylint: disable=protected-access

  self._task_function = task_function
  self._mpr = None
def test_basic_run(self, input_arg):
  has_chief = False
  cluster_spec = multi_worker_test_base.create_cluster_spec(
      has_chief=has_chief,
      num_workers=CLUSTER_SIZE)
  maintenance_event = multi_process_runner.manager().Event()
  training_finished = multi_process_runner.manager().Event()
  training_restarted = multi_process_runner.manager().Event()
  checkpoint_dir = os.path.join(self.get_temp_dir(), 'fh_ckpt/')

  if _is_oss():
    rpc_layer = 'grpc'
  else:
    rpc_layer = 'grpc+loas'

  mpr = multi_process_runner.MultiProcessRunner(
      self.worker_fn,
      cluster_spec,
      args=(checkpoint_dir, cluster_spec, input_arg, maintenance_event,
            training_finished, False, training_restarted),
      rpc_layer=rpc_layer,
      return_output=True,
      dependence_on_chief=has_chief)

  logging.info('Cluster starting.')
  mpr.start()
  while (not maintenance_event.is_set()) and (
      not training_finished.is_set()):
    time.sleep(1)

  raise_if_not_all_exit(0, mpr)

  if not training_finished.is_set():
    logging.info('restarting workers')
    training_restarted.set()
    for worker_id in range(CLUSTER_SIZE):
      mpr.start_single_process('worker', worker_id, cluster_spec)
    logging.info('workers restarted')

  mpr.join(timeout=250)

  self.assertTrue(training_finished.is_set())
def test_multiple_workers_preempted_consecutively(self, grace_period,
                                                  input_arg):
  has_chief = False
  cluster_spec = multi_worker_test_base.create_cluster_spec(
      has_chief=has_chief,
      num_workers=CLUSTER_SIZE)
  maintenance_event = multi_process_runner.manager().Event()
  training_finished = multi_process_runner.manager().Event()
  training_restarted = multi_process_runner.manager().Event()
  checkpoint_dir = os.path.join(self.get_temp_dir(), 'fh_ckpt/')

  if _is_oss():
    rpc_layer = 'grpc'
  else:
    rpc_layer = 'grpc+loas'

  termination_config = failure_handling.TerminationConfig(
      grace_period=grace_period)

  mpr = multi_process_runner.MultiProcessRunner(
      self.worker_fn,
      cluster_spec,
      args=(checkpoint_dir, cluster_spec, input_arg, maintenance_event,
            training_finished, True, training_restarted, termination_config),
      rpc_layer=rpc_layer,
      return_output=True,
      dependence_on_chief=has_chief)

  logging.info('Cluster starting.')
  mpr.start()

  raise_if_not_all_exit(grace_period, mpr)

  maintenance_event.set()
  logging.info('restarting workers')
  training_restarted.set()
  for worker_id in range(CLUSTER_SIZE):
    mpr.start_single_process('worker', worker_id, cluster_spec)
  logging.info('workers restarted')

  mpr.join(timeout=250)
  self.assertTrue(training_finished.is_set())
def test_preemption_checkpointing(self):
  has_chief = False
  cluster_spec = multi_worker_test_base.create_cluster_spec(
      has_chief=has_chief,
      num_workers=CLUSTER_SIZE)
  training_started_event = multi_process_runner.manager().Event()
  checkpoint_dir = os.path.join(self.get_temp_dir(), 'fh_ckpt')

  if _is_oss():
    rpc_layer = 'grpc'
  else:
    rpc_layer = 'grpc+loas'

  mpr = multi_process_runner.MultiProcessRunner(
      self.worker_fn,
      cluster_spec,
      args=(checkpoint_dir, cluster_spec, [training_started_event]),
      rpc_layer=rpc_layer,
      return_output=True,
      dependence_on_chief=has_chief)

  logging.info('Cluster starting.')
  mpr.start()
  while not training_started_event.is_set():
    time.sleep(1)

  logging.info('sending sigterm')
  killed_worker = random.randrange(0, CLUSTER_SIZE)
  os.kill(mpr.get_process_id('worker', killed_worker), signal.SIGTERM)

  logging.info('sigterm sent')
  time.sleep(5)

  logging.info('restarting workers')
  for worker_id in range(CLUSTER_SIZE):
    mpr.start_single_process('worker', worker_id, cluster_spec)
  logging.info('workers restarted')

  stdout = mpr.join().stdout
  all_start_point = []
  for msg in stdout:
    matched_group = re.search(r'.*Restored training at (\d+)', msg)
    if matched_group:
      all_start_point.append(int(matched_group.group(1)))

  # Remove duplicate logs created due to the presence of multiple workers.
  start_points = all_start_point[::CLUSTER_SIZE]

  # Assert that after restarting, we don't repeat previous training steps.
  self.assertNotEqual(start_points[-1], 0)
def test_auto_restart(self):

  def proc_func(counter):
    counter.value += 1
    if counter.value == 1:
      raise ValueError

  manager = multi_process_runner.manager()
  counter = manager.Value(int, 0)
  mpr = multi_process_runner.MultiProcessRunner(
      proc_func,
      multi_worker_test_base.create_cluster_spec(num_workers=1),
      args=(counter,),
      auto_restart=True)
  mpr.start()
  mpr.join()
  self.assertEqual(counter.value, 2)
def __init__(self, cluster_resolver):
  self._cluster_resolver = cluster_resolver
  self._cluster_spec = cluster_resolver.cluster_spec().as_dict()
  self._rpc_layer = cluster_resolver.rpc_layer
  self._start_events = {}
  self._finish_events = {}
  self._mpr_manager = multi_process_runner.manager()

  def task_function(start_events, finish_events):
    cluster_resolver = TFConfigClusterResolver()
    cluster_spec = cluster_resolver.cluster_spec()
    task_type = cluster_resolver.task_type
    task_id = cluster_resolver.task_id
    rpc_layer = cluster_resolver.rpc_layer

    logging.info(
        'Starting server with cluster_spec = %r, task_type = %r, '
        'task_id = %r, rpc_layer = %r', cluster_spec, task_type, task_id,
        rpc_layer)

    # TODO(yuefengz): support GPU clusters.
    server_config = config_pb2.ConfigProto()
    server_config.device_count['GPU'] = 0

    # Set the environment variable to prevent hanging upon job failure and
    # restart. Note that it defaults to 'use_caller' at Google, but defaults
    # to False in OSS.
    os.environ['GRPC_FAIL_FAST'] = 'use_caller'

    server_lib.Server(
        cluster_spec,
        job_name=task_type,
        protocol=rpc_layer,
        task_index=task_id,
        config=server_config,
        start=True)

    start_event = start_events[task_type][task_id]
    start_event.set()

    finish_event = finish_events[task_type][task_id]
    finish_event.wait()

    os._exit(0)  # pylint: disable=protected-access

  self._task_function = task_function
  self._mpr = None
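# The constructor above only stores the task function; a start method
# elsewhere in the class presumably creates the per-task events and launches a
# MultiProcessRunner with that function. The following is only a hedged sketch
# of what such a method might look like under that assumption; the method name
# and details are not taken from the source.
def start(self):
  """Starts one in-process server per task in the cluster spec."""
  for task_type, addresses in self._cluster_spec.items():
    self._start_events[task_type] = []
    self._finish_events[task_type] = []
    for _ in addresses:
      self._start_events[task_type].append(self._mpr_manager.Event())
      self._finish_events[task_type].append(self._mpr_manager.Event())

  self._mpr = multi_process_runner.MultiProcessRunner(
      self._task_function,
      self._cluster_spec,
      args=(self._start_events, self._finish_events),
      rpc_layer=self._rpc_layer)
  self._mpr.start()

  # Block until every task has brought up its server.
  for task_type, events in self._start_events.items():
    for event in events:
      event.wait()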
def test_auto_restart_and_chief(self):
  # If the chief has exited with zero exit code, auto restart should stop
  # restarting other tasks even if they fail.

  def proc_func():
    time.sleep(1)
    if multi_worker_test_base.get_task_type() != 'chief':
      raise ValueError

  manager = multi_process_runner.manager()
  mpr = multi_process_runner.MultiProcessRunner(
      proc_func,
      multi_worker_test_base.create_cluster_spec(
          has_chief=True, num_workers=1),
      auto_restart=True)
  mpr.start()
  with self.assertRaises(ValueError):
    mpr.join(timeout=10)
def test_numpy_fetched_after_worker_failure(self):

  def fn(first_fetch_occurred_event, worker_terminated_event):
    os.environ["GRPC_FAIL_FAST"] = "use_caller"

    cluster_resolver = TFConfigClusterResolver()
    if cluster_resolver.task_type != "chief":
      utils.start_server(cluster_resolver, "grpc")
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        cluster_resolver)
    ps_coordinator = coordinator_lib.ClusterCoordinator(strategy)

    with strategy.scope():
      v = variables.Variable(initial_value=0, dtype=dtypes.int32)

    @def_function.function
    def worker_fn():
      return v + 1, v - 1

    remote_value = ps_coordinator.schedule(worker_fn)
    logging.info("result (1st fetch): %r", remote_value.fetch())
    first_fetch_occurred_event.set()
    worker_terminated_event.wait()
    logging.info("result (2nd fetch): %r", remote_value.fetch())

  manager = multi_process_runner.manager()
  first_fetch_occurred_event = manager.Event()
  worker_terminated_event = manager.Event()
  mpr = multi_process_runner.MultiProcessRunner(
      fn,
      multi_worker_test_base.create_cluster_spec(
          has_chief=True, num_workers=1, num_ps=1, has_eval=False),
      args=(first_fetch_occurred_event, worker_terminated_event),
      rpc_layer="grpc",
      return_output=True,
      use_dill_for_args=False)

  mpr.start()
  first_fetch_occurred_event.wait()
  mpr.terminate("worker", 0)
  worker_terminated_event.set()
  self.assertTrue(
      any("result (2nd fetch)" in msg for msg in mpr.join().stdout))
def test_quick_recover(self):
  # This test simulates the case when a worker fails but recovers quickly
  # before the next collective.
  #
  # It's not guaranteed that the cluster only restarts once when one worker
  # fails. The external job management system is expected to keep restarting
  # failed workers.

  def worker_fn(attempts):
    # Set a long check alive interval to better simulate the case when a
    # worker fails and recovers during a check alive interval.
    mwms_lib.CollectiveAllReduceExtended._check_alive_interval = 30
    mwms_lib.CollectiveAllReduceExtended._check_alive_initial_timeout = 30

    context.context().configure_coordination_service(COORDINATION_SERVICE)
    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
    task_id, attempt = get_attempt(strategy, attempts)

    @tf.function
    def replica_fn():
      ctx = tf.distribute.get_replica_context()
      # Use a large tensor because a small one may hang regardless of whether
      # the worker recovers.
      value = tf.ones((64, 64))
      ctx.all_reduce(tf.distribute.ReduceOp.SUM, [value, value])

    strategy.run(replica_fn)
    # worker-1 dies here.
    if attempt == 1 and task_id == 1:
      quick_exit(1)
    strategy.run(replica_fn)

  cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
  attempts = multi_process_runner.manager().dict()
  mpr = multi_process_runner.MultiProcessRunner(
      worker_fn,
      cluster_spec,
      rpc_layer=RPC_PROTOCOL,
      args=(attempts,),
      auto_restart=True)
  mpr.start()
  mpr.join(timeout=90)
def test_quick_recover(self):
  # This test simulates the case when a worker fails but recovers quickly
  # before the next collective.
  #
  # It's not guaranteed that the cluster only restarts once when one worker
  # fails. The external job management system is expected to keep restarting
  # failed workers.

  def worker_fn(attempts):
    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
    task_id, attempt = get_attempt(strategy, attempts)

    if attempt == 2 and task_id == 1:
      multi_process_runner.barrier().wait()

    @tf.function
    def replica_fn():
      ctx = tf.distribute.get_replica_context()
      # Use a large tensor because a small one may hang regardless of whether
      # the worker recovers.
      value = tf.ones((64, 64))
      ctx.all_reduce(tf.distribute.ReduceOp.SUM, [value, value])

    strategy.run(replica_fn)
    # worker-1 dies here.
    if attempt == 1 and task_id == 1:
      quick_exit(1)
    # Make worker-0 wait for worker-1 to restart before entering the next
    # collective to simulate a quick recovery of worker-1.
    if attempt == 1 and task_id == 0:
      multi_process_runner.barrier().wait()
    strategy.run(replica_fn)

  cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
  attempts = multi_process_runner.manager().dict()
  mpr = multi_process_runner.MultiProcessRunner(
      worker_fn,
      cluster_spec,
      args=(attempts,),
      auto_restart=True)
  mpr.start()
  mpr.join(timeout=90)
def test_auto_restart_terminate(self):
  # Tasks terminated by the user should also be restarted.

  def proc_func(counter):
    counter.value += 1
    if counter.value == 1:
      time.sleep(100)

  manager = multi_process_runner.manager()
  counter = manager.Value(int, 0)

  mpr = multi_process_runner.MultiProcessRunner(
      proc_func,
      multi_worker_test_base.create_cluster_spec(
          has_chief=False, num_workers=1),
      args=(counter,),
      auto_restart=True)
  mpr.start()
  time.sleep(3)
  mpr.terminate('worker', 0)
  mpr.join(timeout=20)
  self.assertEqual(counter.value, 2)
def test_grace_period_continue_training(self, input_arg, mwms_mode):
  checkpoint_dir = os.path.join(self.get_temp_dir(), 'fh_ckpt/')
  grace_period = 7

  if _is_oss():
    rpc_layer = 'grpc'
  else:
    rpc_layer = 'grpc+loas'

  termination_config = failure_handling.TerminationConfig(
      grace_period=grace_period)

  if mwms_mode == 'multi_worker':
    has_chief = False
    cluster_spec = multi_worker_test_base.create_cluster_spec(
        has_chief=has_chief,
        num_workers=CLUSTER_SIZE)
    maintenance_event = multi_process_runner.manager().Event()
    training_finished = multi_process_runner.manager().Event()
    training_restarted = multi_process_runner.manager().Event()

    mpr = multi_process_runner.MultiProcessRunner(
        self.worker_fn,
        cluster_spec,
        args=(checkpoint_dir, cluster_spec, input_arg, maintenance_event,
              training_finished, False, training_restarted,
              termination_config),
        rpc_layer=rpc_layer,
        return_output=True,
        dependence_on_chief=has_chief)

    logging.info('Cluster starting.')
    mpr.start()
    while (not maintenance_event.is_set()) and (
        not training_finished.is_set()):
      time.sleep(1)

    raise_if_not_all_exit(grace_period, mpr)

    if not training_finished.is_set():
      logging.info('restarting workers')
      training_restarted.set()
      for worker_id in range(CLUSTER_SIZE):
        mpr.start_single_process('worker', worker_id, cluster_spec)
      logging.info('workers restarted')

    mpr.join(timeout=250)

    self.assertTrue(training_finished.is_set())

  else:
    maintenance_event = threading.Event()
    training_finished = threading.Event()
    training_restarted = threading.Event()

    cluster_spec = server_lib.ClusterSpec({})
    caught_exit = False
    try:
      self.worker_fn(checkpoint_dir, cluster_spec, input_arg,
                     maintenance_event, training_finished, False,
                     training_restarted, termination_config)
    except SystemExit as exit_error:
      caught_exit = True
      # We cannot use assertRaises instead, since termination is not always
      # triggered. Exit code 143 corresponds to termination by SIGTERM
      # (128 + 15).
      self.assertEqual(exit_error.code, 143)  # pylint:disable=g-assert-in-except

    if maintenance_event.is_set() and not training_finished.is_set():
      self.assertTrue(caught_exit)

      logging.info('restarting workers')
      training_restarted.set()
      self.worker_fn(checkpoint_dir, cluster_spec, input_arg,
                     maintenance_event, training_finished, False,
                     training_restarted, termination_config)
      self.assertTrue(training_finished.is_set())
def test_multiple_workers_preempted_consecutively(self, grace_period):
  has_chief = False
  cluster_spec = multi_worker_test_base.create_cluster_spec(
      has_chief=has_chief,
      num_workers=CLUSTER_SIZE)
  maintenance_event = multi_process_runner.manager().Event()
  training_finished = multi_process_runner.manager().Event()
  checkpoint_dir = os.path.join(self.get_temp_dir(), 'fh_ckpt')

  if _is_oss():
    rpc_layer = 'grpc'
  else:
    rpc_layer = 'grpc+loas'

  termination_config = failure_handling.TerminationConfig(
      time_till_termination=grace_period)

  mpr = multi_process_runner.MultiProcessRunner(
      self.worker_fn,
      cluster_spec,
      args=(checkpoint_dir, cluster_spec, maintenance_event,
            training_finished, True, termination_config),
      rpc_layer=rpc_layer,
      return_output=True,
      dependence_on_chief=has_chief)

  logging.info('Cluster starting.')
  mpr.start()

  # Wait for the whole cluster to exit, with a timeout. The extra buffer
  # mitigates the fact that the per-step time in this test is very short.
  waiting_time = 0
  exit_process_count = 0
  while exit_process_count != CLUSTER_SIZE and waiting_time < max(
      grace_period + 15, 40):
    exit_process_count = 0
    for worker_id in range(CLUSTER_SIZE):
      if not mpr.process_exists('worker', worker_id):
        exit_process_count += 1
    waiting_time += 1
    time.sleep(1)

  if waiting_time >= max(grace_period + 15, 40):
    raise RuntimeError('Waited long but at least one worker still exists. '
                       'Considering the size of our model, this should not '
                       'happen.')

  maintenance_event.set()
  logging.info('restarting workers')
  for worker_id in range(CLUSTER_SIZE):
    mpr.start_single_process('worker', worker_id, cluster_spec)
  logging.info('workers restarted')

  stdout = mpr.join(timeout=250).stdout
  found_message = 0
  checkpoint_count = []
  for msg in stdout:
    matched_group = re.search(r'.*has received termination notice*', msg)
    checkpoint_group = re.search(r'.*RUN_TO_CHECKPOINT set to (\d+)', msg)
    if matched_group:
      found_message += 1
    if checkpoint_group:
      checkpoint_count.append(int(checkpoint_group.group(1)))

  self.assertGreaterEqual(found_message, 1)
  if grace_period > 0:
    self.assertLen(set(checkpoint_count), 2)
def _testStrategyRun(self, failure_task_type):

  def fn(functions_scheduled_event):
    # TODO(b/170664373): This is needed for TF2 parameter server training in
    # OSS. Remove this when resolved.
    os.environ["GRPC_FAIL_FAST"] = "use_caller"

    cluster_resolver = TFConfigClusterResolver()
    if cluster_resolver.task_type != "chief":
      utils.start_server(cluster_resolver, "grpc")
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        cluster_resolver)
    ps_client = coordinator_lib.ClusterCoordinator(strategy)

    with strategy.scope():
      v = variables.Variable(initial_value=1)

    @def_function.function
    def worker_fn(input_tensor):

      def replica_fn(input_tensor):
        return input_tensor + v

      run_result = strategy.run(replica_fn, args=(input_tensor,))
      check_ops.assert_equal_v2(run_result, 4)
      return run_result

    for i in range(5000):
      if i % 500 == 0:
        logging.info("Scheduling function-{}...".format(i))
      result = ps_client.schedule(worker_fn, args=(constant_op.constant(3),))
    functions_scheduled_event.set()

    logging.info("Joining...")
    ps_client.join()
    logging.info("Finished joining.")

    if result.fetch() != 4:
      raise AssertionError(
          "Unexpected RemoteValue result: {}".format(result.fetch()))
    logging.info("testStrategyRun succeeded")

  manager = multi_process_runner.manager()
  functions_scheduled_event = manager.Event()
  mpr = multi_process_runner.MultiProcessRunner(
      fn,
      multi_worker_test_base.create_cluster_spec(
          has_chief=True, num_workers=1, num_ps=1, has_eval=False),
      args=(functions_scheduled_event,),
      rpc_layer="grpc",
      return_output=True)

  mpr.start()

  if failure_task_type is not None:
    functions_scheduled_event.wait()
    logging.info("Before interrupting {}-0.".format(failure_task_type))
    mpr.terminate(failure_task_type, 0)

    if failure_task_type == "ps":
      with self.assertRaises(errors.UnavailableError):
        mpr.join()
      return

    time.sleep(10)
    logging.info("Before restarting {}-0.".format(failure_task_type))
    mpr.start_single_process(task_type="worker", task_id=0)

  self.assertTrue(
      any(["testStrategyRun succeeded" in msg for msg in mpr.join().stdout]))
def test_grace_period_continue_training(self):
  grace_period = 5
  has_chief = False
  cluster_spec = multi_worker_test_base.create_cluster_spec(
      has_chief=has_chief,
      num_workers=CLUSTER_SIZE)
  training_started_event = multi_process_runner.manager().Event()
  checkpoint_dir = os.path.join(self.get_temp_dir(), 'fh_ckpt')

  if _is_oss():
    rpc_layer = 'grpc'
  else:
    rpc_layer = 'grpc+loas'

  termination_config = failure_handling.TerminationConfig(
      time_till_termination=grace_period)

  mpr = multi_process_runner.MultiProcessRunner(
      self.worker_fn,
      cluster_spec,
      args=(checkpoint_dir, cluster_spec, [training_started_event], None,
            termination_config),
      rpc_layer=rpc_layer,
      return_output=True,
      dependence_on_chief=has_chief)

  logging.info('Cluster starting.')
  mpr.start()
  while not training_started_event.is_set():
    time.sleep(1)

  killed_worker = random.randrange(0, CLUSTER_SIZE)
  logging.info('sending SIGTERM')
  os.kill(mpr.get_process_id('worker', killed_worker), signal.SIGTERM)
  logging.info('SIGTERM sent')

  # Wait for the whole cluster to exit within the given grace period (plus a
  # buffer, since the per-step time here is very small).
  waiting_time = 0
  exit_process_count = 0
  while exit_process_count != CLUSTER_SIZE and waiting_time < grace_period + 10:
    exit_process_count = 0
    for worker_id in range(CLUSTER_SIZE):
      if not mpr.process_exists('worker', worker_id):
        exit_process_count += 1
    waiting_time += 1
    time.sleep(1)

  if waiting_time >= grace_period + 10:
    raise RuntimeError('Waited beyond the grace period but at least one '
                       'worker has not exited.')

  logging.info('restarting workers')
  for worker_id in range(CLUSTER_SIZE):
    mpr.start_single_process('worker', worker_id, cluster_spec)
  logging.info('workers restarted')

  stdout = mpr.join(timeout=250).stdout
  all_start_point = []
  checkpoint_count = []
  for msg in stdout:
    # TODO(wxinyi): remove the string matching and assert checkpoint number.
    matched_group = re.search(r'.*Restored training at (\d+)', msg)
    checkpoint_group = re.search(r'.*RUN_TO_CHECKPOINT set to (\d+)', msg)
    if matched_group:
      all_start_point.append(int(matched_group.group(1)))
    if checkpoint_group:
      checkpoint_count.append(int(checkpoint_group.group(1)))

  # Remove duplicate logs created due to the presence of multiple workers.
  start_points = all_start_point[::CLUSTER_SIZE]

  # Assert that after restarting, we don't repeat previous training steps.
  self.assertNotEqual(start_points[-1], 0)

  # One for timing, another for the final call.
  self.assertLen(set(checkpoint_count), 2)
def test_grace_period_continue_training(self):
  grace_period = 7
  has_chief = False
  cluster_spec = multi_worker_test_base.create_cluster_spec(
      has_chief=has_chief,
      num_workers=CLUSTER_SIZE)
  checkpoint_dir = os.path.join(self.get_temp_dir(), 'fh_ckpt')
  maintenance_event = multi_process_runner.manager().Event()
  training_finished = multi_process_runner.manager().Event()

  if _is_oss():
    rpc_layer = 'grpc'
  else:
    rpc_layer = 'grpc+loas'

  termination_config = failure_handling.TerminationConfig(
      time_till_termination=grace_period)

  mpr = multi_process_runner.MultiProcessRunner(
      self.worker_fn,
      cluster_spec,
      args=(checkpoint_dir, cluster_spec, maintenance_event,
            training_finished, False, termination_config),
      rpc_layer=rpc_layer,
      return_output=True,
      dependence_on_chief=has_chief)

  logging.info('Cluster starting.')
  mpr.start()
  while (not maintenance_event.is_set()) and (
      not training_finished.is_set()):
    time.sleep(1)

  # This sleep mitigates the fact that the per-step time in this test is very
  # short.
  time.sleep(grace_period + 10)

  if not training_finished.is_set():
    logging.info('restarting workers')
    for worker_id in range(CLUSTER_SIZE):
      mpr.start_single_process('worker', worker_id, cluster_spec)
    logging.info('workers restarted')

  if maintenance_event.is_set():
    stdout = mpr.join(timeout=250).stdout
    all_start_point = []
    checkpoint_count = []
    for msg in stdout:
      matched_group = re.search(r'.*Start training at (\d+)', msg)
      checkpoint_group = re.search(r'.*RUN_TO_CHECKPOINT set to (\d+)', msg)
      if matched_group:
        all_start_point.append(int(matched_group.group(1)))
      if checkpoint_group:
        checkpoint_count.append(int(checkpoint_group.group(1)))

    # Remove duplicate logs created due to the presence of multiple workers.
    start_points = all_start_point[::CLUSTER_SIZE]

    # If maintenance_event is set at the very end of training and training
    # completes, there won't be a restart.
    if len(start_points) > 1:
      # Assert that after restarting, we don't repeat previous training steps.
      self.assertNotEqual(start_points[-1], 0)

      # One for timing, another for the final call.
      self.assertLen(set(checkpoint_count), 2)
def test_two_workers_preempted_consecutively(self):
  has_chief = False
  cluster_spec = multi_worker_test_base.create_cluster_spec(
      has_chief=has_chief,
      num_workers=CLUSTER_SIZE)
  maintenance_event = multi_process_runner.manager().Event()
  training_finished = multi_process_runner.manager().Event()
  checkpoint_dir = os.path.join(self.get_temp_dir(), 'fh_ckpt')

  if _is_oss():
    rpc_layer = 'grpc'
  else:
    rpc_layer = 'grpc+loas'

  mpr = multi_process_runner.MultiProcessRunner(
      self.worker_fn,
      cluster_spec,
      args=(checkpoint_dir, cluster_spec, maintenance_event,
            training_finished, True),
      rpc_layer=rpc_layer,
      return_output=True,
      dependence_on_chief=has_chief)

  logging.info('Cluster starting.')
  mpr.start()
  time.sleep(5)

  # Wait for the whole cluster to exit, with a timeout.
  waiting_time = 0
  exit_process_count = 0
  while exit_process_count != CLUSTER_SIZE and waiting_time < 40:
    exit_process_count = 0
    for worker_id in range(CLUSTER_SIZE):
      if not mpr.process_exists('worker', worker_id):
        exit_process_count += 1
    waiting_time += 1
    time.sleep(1)

  if waiting_time >= 40:
    raise RuntimeError('Waited long but at least one worker still exists. '
                       'Considering the size of our model, this should not '
                       'happen.')

  maintenance_event.set()
  logging.info('restarting workers')
  for worker_id in range(CLUSTER_SIZE):
    mpr.start_single_process('worker', worker_id, cluster_spec)
  logging.info('workers restarted')

  stdout = mpr.join().stdout
  found_message = 0
  for msg in stdout:
    matched_group = re.search(r'.*has received termination notice*', msg)
    if matched_group:
      found_message += 1

  self.assertGreaterEqual(found_message, 1)
def test_grace_period_continue_training(self, input_arg, mwms_mode):
  if _is_oss():
    rpc_layer = 'grpc'
  else:
    rpc_layer = 'grpc+loas'

  checkpoint_dir = os.path.join(self.get_temp_dir(), 'fh_ckpt')

  if mwms_mode == 'multi_worker':
    grace_period = 5
    termination_config = failure_handling.TerminationConfig(
        grace_period=grace_period)
    has_chief = False
    cluster_spec = multi_worker_test_base.create_cluster_spec(
        has_chief=has_chief,
        num_workers=CLUSTER_SIZE)
    training_started_event = multi_process_runner.manager().Event()
    training_restarted = multi_process_runner.manager().Event()
    training_finished = multi_process_runner.manager().Event()

    mpr = multi_process_runner.MultiProcessRunner(
        self.worker_fn,
        cluster_spec,
        args=(checkpoint_dir, cluster_spec, input_arg,
              [training_started_event], None, training_restarted,
              training_finished, termination_config),
        rpc_layer=rpc_layer,
        return_output=True,
        dependence_on_chief=has_chief)

    logging.info('Cluster starting.')
    mpr.start()
    while not training_started_event.is_set():
      time.sleep(1)

    killed_worker = random.randrange(0, CLUSTER_SIZE)
    logging.info('sending SIGTERM')
    os.kill(mpr.get_process_id('worker', killed_worker), signal.SIGTERM)
    logging.info('SIGTERM sent')

    raise_if_not_all_exit(grace_period, mpr)

    logging.info('restarting workers')
    training_restarted.set()
    for worker_id in range(CLUSTER_SIZE):
      mpr.start_single_process('worker', worker_id, cluster_spec)
    logging.info('workers restarted')
    mpr.join(timeout=250)

  else:
    # A single worker trains very fast relative to the size of the "model"
    # here. With a longer grace period, training would simply finish within
    # the grace period and we couldn't verify the exit behavior.
    grace_period = 1
    termination_config = failure_handling.TerminationConfig(
        grace_period=grace_period)

    cluster_spec = server_lib.ClusterSpec({})
    training_started_event = threading.Event()
    training_restarted = threading.Event()
    training_finished = threading.Event()

    def sending_sigterm(training_started_event):
      while not training_started_event.is_set():
        time.sleep(1)
      logging.info('sending sigterm')
      training_started_event.set()
      os.kill(os.getpid(), signal.SIGTERM)

    preemption_sender_thread = threading.Thread(
        target=sending_sigterm, args=(training_started_event,))
    preemption_sender_thread.start()

    caught_exit = False
    try:
      self.worker_fn(checkpoint_dir, cluster_spec, input_arg,
                     [training_started_event], None, training_restarted,
                     training_finished, termination_config)
    except SystemExit as exit_error:
      caught_exit = True
      # We cannot use assertRaises instead, since termination is not always
      # triggered.
      self.assertEqual(exit_error.code, 42)  # pylint: disable=g-assert-in-except

    preemption_sender_thread.join(10)
    if not training_finished.is_set():
      self.assertTrue(caught_exit)

      logging.info('restarting workers')
      training_restarted.set()
      self.worker_fn(checkpoint_dir, cluster_spec, input_arg,
                     [training_started_event], None, training_restarted,
                     training_finished, termination_config)
def _test_translate_ps_failure_error(self, test_schedule=False,
                                     test_join=False):

  def proc_func(functions_scheduled_event, test_finished_event):
    cluster_resolver = TFConfigClusterResolver()
    if cluster_resolver.task_type != "chief":
      utils.start_server(cluster_resolver, "grpc")
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        cluster_resolver)
    ps_client = client_lib.Client(strategy)

    with strategy.scope():
      v = variables.Variable(initial_value=0, dtype=dtypes.int32)

    @def_function.function
    def worker_fn():
      # An ever-running function.
      for _ in math_ops.range(100000):
        v.assign_add(1)

    # Keep the two workers occupied.
    ps_client.schedule(worker_fn)
    ps_client.schedule(worker_fn)
    # Now the main process can terminate.
    functions_scheduled_event.set()

    # Verify that join and schedule indeed raise ParameterServerFailureError.
    try:
      if test_join:
        ps_client.join()
      if test_schedule:
        while ps_client.cluster._closure_queue._error is None:
          time.sleep(1)
        ps_client.schedule(worker_fn)
    except client_lib.ParameterServerFailureError:
      # The following verifies that after the PS fails, continuing to execute
      # functions on workers should fail and indicate it's a PS failure.
      for worker_id in range(3):
        with ops.device("/job:worker/replica:0/task:{}".format(worker_id)):
          try:
            # Executing a function after the PS fails should result in a PS
            # failure.
            worker_fn()
          except Exception as e:  # pylint: disable=broad-except
            if client_lib._is_ps_failure(e):
              if worker_id < 2:
                continue
              logging.info("_test_translate_ps_failure_error ends properly.")
              # Now we can safely exit the test.
              test_finished_event.set()
              return
            raise RuntimeError("Executing a function after the PS fails "
                               "should result in a PS failure.")

    raise RuntimeError("ParameterServerFailureError supposed to be raised.")

  manager = multi_process_runner.manager()
  functions_scheduled_event = manager.Event()
  test_finished_event = manager.Event()
  mpr = multi_process_runner.MultiProcessRunner(
      proc_func,
      multi_worker_test_base.create_cluster_spec(
          has_chief=True, num_workers=3, num_ps=1, has_eval=False),
      args=(functions_scheduled_event, test_finished_event),
      rpc_layer="grpc",
      list_stdout=True,
      use_dill_for_args=False)

  mpr.start()
  functions_scheduled_event.wait()
  mpr.terminate("ps", 0)
  while mpr.process_exists("ps", 0):
    time.sleep(0.01)
  test_finished_event.wait()
  self.assertTrue(
      any("_test_translate_ps_failure_error ends properly" in msg
          for msg in mpr.join().stdout))
def test_preemption_checkpointing(self, input_arg, mwms_mode):
  has_chief = False

  if _is_oss():
    rpc_layer = 'grpc'
  else:
    rpc_layer = 'grpc+loas'

  checkpoint_dir = os.path.join(self.get_temp_dir(), 'fh_ckpt')

  if mwms_mode == 'multi_worker':
    cluster_spec = multi_worker_test_base.create_cluster_spec(
        has_chief=has_chief,
        num_workers=CLUSTER_SIZE)
    training_started_event = multi_process_runner.manager().Event()
    training_restarted = multi_process_runner.manager().Event()
    training_finished = multi_process_runner.manager().Event()

    mpr = multi_process_runner.MultiProcessRunner(
        self.worker_fn,
        cluster_spec,
        args=(checkpoint_dir, cluster_spec, input_arg,
              [training_started_event], None, training_restarted,
              training_finished),
        rpc_layer=rpc_layer,
        return_output=True,
        dependence_on_chief=has_chief)

    logging.info('Cluster starting.')
    mpr.start()
    while not training_started_event.is_set():
      time.sleep(1)

    logging.info('sending sigterm')
    killed_worker = random.randrange(0, CLUSTER_SIZE)
    os.kill(mpr.get_process_id('worker', killed_worker), signal.SIGTERM)

    logging.info('sigterm sent')
    raise_if_not_all_exit(0, mpr)

    logging.info('restarting workers')
    training_restarted.set()
    for worker_id in range(CLUSTER_SIZE):
      mpr.start_single_process('worker', worker_id, cluster_spec)
    logging.info('workers restarted')

    mpr.join(timeout=270)

  else:
    cluster_spec = server_lib.ClusterSpec({})
    training_started_event = threading.Event()
    training_restarted = threading.Event()
    training_finished = threading.Event()

    def sending_sigterm(training_started_event):
      while not training_started_event.is_set():
        time.sleep(1)
      logging.info('sending sigterm')
      training_started_event.set()
      os.kill(os.getpid(), signal.SIGTERM)

    preemption_sender_thread = threading.Thread(
        target=sending_sigterm, args=(training_started_event,))
    preemption_sender_thread.start()

    caught_exit = False
    try:
      self.worker_fn(checkpoint_dir, cluster_spec, input_arg,
                     [training_started_event], None, training_restarted,
                     training_finished)
    except SystemExit as exit_error:
      caught_exit = True
      # We cannot use assertRaises instead, since termination is not always
      # triggered.
      self.assertEqual(exit_error.code, 42)  # pylint: disable=g-assert-in-except

    preemption_sender_thread.join(10)
    if not training_finished.is_set():
      self.assertTrue(caught_exit)

      logging.info('restarting workers')
      training_restarted.set()
      self.worker_fn(checkpoint_dir, cluster_spec, input_arg,
                     [training_started_event], None, training_restarted,
                     training_finished)