def test_multiple_workers_preempted_consecutively(self, grace_period,
                                                  input_arg):
  has_chief = False
  cluster_spec = multi_worker_test_base.create_cluster_spec(
      has_chief=has_chief, num_workers=CLUSTER_SIZE)
  maintenance_event = multi_process_runner.manager().Event()
  training_finished = multi_process_runner.manager().Event()
  training_restarted = multi_process_runner.manager().Event()

  checkpoint_dir = os.path.join(self.get_temp_dir(), 'fh_ckpt/')

  if _is_oss():
    rpc_layer = 'grpc'
  else:
    rpc_layer = 'grpc+loas'

  termination_config = failure_handling.TerminationConfig(
      grace_period=grace_period)
  mpr = multi_process_runner.MultiProcessRunner(
      self.worker_fn,
      cluster_spec,
      args=(checkpoint_dir, cluster_spec, input_arg, maintenance_event,
            training_finished, True, training_restarted, termination_config),
      rpc_layer=rpc_layer,
      return_output=True,
      dependence_on_chief=has_chief)

  logging.info('Cluster starting.')
  mpr.start()

  # Wait for all workers in the cluster to exit, with a timeout. The buffer
  # beyond the grace period mitigates the fact that the per-step time in this
  # test is very short.
  waiting_time = 0
  exit_process_count = 0
  while exit_process_count != CLUSTER_SIZE and waiting_time < max(
      grace_period + 15, 40):
    exit_process_count = 0
    for worker_id in range(CLUSTER_SIZE):
      if not mpr.process_exists('worker', worker_id):
        exit_process_count += 1
    waiting_time += 1
    time.sleep(1)

  if waiting_time == max(grace_period + 15, 40):
    raise RuntimeError('Waited a long time, but at least one worker still '
                       'exists. Considering the size of our model, this '
                       'should not happen.')

  maintenance_event.set()
  logging.info('restarting workers')
  training_restarted.set()
  for worker_id in range(CLUSTER_SIZE):
    mpr.start_single_process('worker', worker_id, cluster_spec)
  logging.info('workers restarted')

  mpr.join(timeout=250)

  self.assertTrue(training_finished.is_set())
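# NOTE: _is_oss is called throughout these snippets but never defined in
# them. A minimal sketch of a plausible implementation follows; the
# bazel-based check is an assumption, not the confirmed implementation.
import sys


def _is_oss():
  """Returns whether the test is run under OSS (open-source build)."""
  return len(sys.argv) >= 1 and 'bazel' in sys.argv[0]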
def test_grace_period_continue_training(self, input_arg):
  grace_period = 5
  has_chief = False
  cluster_spec = multi_worker_test_base.create_cluster_spec(
      has_chief=has_chief, num_workers=CLUSTER_SIZE)
  training_started_event = multi_process_runner.manager().Event()
  training_restarted = multi_process_runner.manager().Event()
  training_finished = multi_process_runner.manager().Event()

  checkpoint_dir = os.path.join(self.get_temp_dir(), 'fh_ckpt')

  if _is_oss():
    rpc_layer = 'grpc'
  else:
    rpc_layer = 'grpc+loas'

  termination_config = failure_handling.TerminationConfig(
      grace_period=grace_period)
  mpr = multi_process_runner.MultiProcessRunner(
      self.worker_fn,
      cluster_spec,
      args=(checkpoint_dir, cluster_spec, input_arg, [training_started_event],
            None, training_restarted, training_finished, termination_config),
      rpc_layer=rpc_layer,
      return_output=True,
      dependence_on_chief=has_chief)

  logging.info('Cluster starting.')
  mpr.start()
  while not training_started_event.is_set():
    time.sleep(1)

  killed_worker = random.randrange(0, CLUSTER_SIZE)
  logging.info('sending SIGTERM')
  os.kill(mpr.get_process_id('worker', killed_worker), signal.SIGTERM)
  logging.info('SIGTERM sent')

  # Wait for all workers to exit within the given grace period (plus a
  # buffer, since the per-step time here is very short).
  waiting_time = 0
  exit_process_count = 0
  while exit_process_count != CLUSTER_SIZE and waiting_time < grace_period + 10:
    exit_process_count = 0
    for worker_id in range(CLUSTER_SIZE):
      if not mpr.process_exists('worker', worker_id):
        exit_process_count += 1
    waiting_time += 1
    time.sleep(1)

  if waiting_time == grace_period + 10:
    raise RuntimeError('Waiting time exceeded the grace period.')

  logging.info('restarting workers')
  training_restarted.set()
  for worker_id in range(CLUSTER_SIZE):
    mpr.start_single_process('worker', worker_id, cluster_spec)
  logging.info('workers restarted')

  mpr.join(timeout=250)
def test_grace_period_continue_training(self, input_arg):
  grace_period = 7
  has_chief = False
  cluster_spec = multi_worker_test_base.create_cluster_spec(
      has_chief=has_chief, num_workers=CLUSTER_SIZE)

  checkpoint_dir = os.path.join(self.get_temp_dir(), 'fh_ckpt/')
  maintenance_event = multi_process_runner.manager().Event()
  training_finished = multi_process_runner.manager().Event()
  training_restarted = multi_process_runner.manager().Event()

  if _is_oss():
    rpc_layer = 'grpc'
  else:
    rpc_layer = 'grpc+loas'

  termination_config = failure_handling.TerminationConfig(
      grace_period=grace_period)
  mpr = multi_process_runner.MultiProcessRunner(
      self.worker_fn,
      cluster_spec,
      args=(checkpoint_dir, cluster_spec, input_arg, maintenance_event,
            training_finished, False, training_restarted, termination_config),
      rpc_layer=rpc_layer,
      return_output=True,
      dependence_on_chief=has_chief)

  logging.info('Cluster starting.')
  mpr.start()

  while (not maintenance_event.is_set()) and (not training_finished.is_set()):
    time.sleep(1)

  # The extra sleep mitigates the fact that the per-step time in this test is
  # very short.
  time.sleep(grace_period + 10)

  if not training_finished.is_set():
    logging.info('restarting workers')
    training_restarted.set()
    for worker_id in range(CLUSTER_SIZE):
      mpr.start_single_process('worker', worker_id, cluster_spec)
    logging.info('workers restarted')

  mpr.join(timeout=250)
  self.assertTrue(training_finished.is_set())
def test_grace_period_continue_training(self, input_arg):
  grace_period = 5
  has_chief = False
  cluster_spec = multi_worker_test_base.create_cluster_spec(
      has_chief=has_chief, num_workers=CLUSTER_SIZE)
  training_started_event = multi_process_runner.manager().Event()
  training_restarted = multi_process_runner.manager().Event()
  training_finished = multi_process_runner.manager().Event()

  checkpoint_dir = os.path.join(self.get_temp_dir(), 'fh_ckpt')

  if _is_oss():
    rpc_layer = 'grpc'
  else:
    rpc_layer = 'grpc+loas'

  termination_config = failure_handling.TerminationConfig(
      grace_period=grace_period)
  mpr = multi_process_runner.MultiProcessRunner(
      self.worker_fn,
      cluster_spec,
      args=(checkpoint_dir, cluster_spec, input_arg, [training_started_event],
            None, training_restarted, training_finished, termination_config),
      rpc_layer=rpc_layer,
      return_output=True,
      dependence_on_chief=has_chief)

  logging.info('Cluster starting.')
  mpr.start()
  while not training_started_event.is_set():
    time.sleep(1)

  killed_worker = random.randrange(0, CLUSTER_SIZE)
  logging.info('sending SIGTERM')
  os.kill(mpr.get_process_id('worker', killed_worker), signal.SIGTERM)
  logging.info('SIGTERM sent')

  raise_if_not_all_exit(grace_period, mpr)

  logging.info('restarting workers')
  training_restarted.set()
  for worker_id in range(CLUSTER_SIZE):
    mpr.start_single_process('worker', worker_id, cluster_spec)
  logging.info('workers restarted')
  mpr.join(timeout=250)
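# raise_if_not_all_exit is used from here on but not defined in these
# snippets. A sketch, assuming it factors out the polling loop that the
# earlier versions of these tests inline (the timeout arithmetic below is an
# assumption modeled on those inline loops):
def raise_if_not_all_exit(grace_period, mpr):
  """Waits for all workers to exit within the grace period plus a buffer."""
  waiting_time = 0
  exit_process_count = 0
  # The buffer beyond the grace period mitigates the very short per-step
  # time in these tests.
  while exit_process_count != CLUSTER_SIZE and waiting_time < max(
      grace_period + 15, 40):
    exit_process_count = 0
    for worker_id in range(CLUSTER_SIZE):
      if not mpr.process_exists('worker', worker_id):
        exit_process_count += 1
    waiting_time += 1
    time.sleep(1)

  if waiting_time == max(grace_period + 15, 40):
    raise RuntimeError('Waited a long time, but at least one worker still '
                       'exists.')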
def test_multiple_workers_preempted_consecutively(self, grace_period,
                                                  input_arg):
  has_chief = False
  cluster_spec = multi_worker_test_base.create_cluster_spec(
      has_chief=has_chief, num_workers=CLUSTER_SIZE)
  maintenance_event = multi_process_runner.manager().Event()
  training_finished = multi_process_runner.manager().Event()
  training_restarted = multi_process_runner.manager().Event()

  checkpoint_dir = os.path.join(self.get_temp_dir(), 'fh_ckpt/')

  if _is_oss():
    rpc_layer = 'grpc'
  else:
    rpc_layer = 'grpc+loas'

  termination_config = failure_handling.TerminationConfig(
      grace_period=grace_period)
  mpr = multi_process_runner.MultiProcessRunner(
      self.worker_fn,
      cluster_spec,
      args=(checkpoint_dir, cluster_spec, input_arg, maintenance_event,
            training_finished, True, training_restarted, termination_config),
      rpc_layer=rpc_layer,
      return_output=True,
      dependence_on_chief=has_chief)

  logging.info('Cluster starting.')
  mpr.start()

  raise_if_not_all_exit(grace_period, mpr)

  maintenance_event.set()
  logging.info('restarting workers')
  training_restarted.set()
  for worker_id in range(CLUSTER_SIZE):
    mpr.start_single_process('worker', worker_id, cluster_spec)
  logging.info('workers restarted')

  mpr.join(timeout=250)
  self.assertTrue(training_finished.is_set())
def worker_fn(
    self,
    checkpoint_dir,
    cluster_spec,
    input_arg,
    maintenance_event=None,
    training_finished=None,
    frequent_send=False,
    training_restarted=None,
    termination_config=failure_handling.TerminationConfig(grace_period=0)):
  _enable_coordination_service(cluster_spec)
  strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()

  def mock_termination_watcher_function_gce(*args, **kwargs):
    del args, kwargs
    if not frequent_send:
      time.sleep(1)
      if (not maintenance_event.is_set()) and (random.randrange(0, 7) == 5):
        maintenance_event.set()
        logging.info('Termination notice available.')
        return True
    elif frequent_send and not maintenance_event.is_set():
      logging.info('Termination notice available.')
      return True

    return False

  with mock.patch.object(
      gce_util, 'termination_watcher_function_gce',
      mock_termination_watcher_function_gce), mock.patch.object(
          gce_util, 'detect_platform',
          lambda: gce_util.PlatformDevice.GCE_GPU):

    class Model(module.Module):

      def __init__(self):
        self.v = variables_lib.Variable(
            0.,
            synchronization=variables_lib.VariableSynchronization.ON_WRITE,
            aggregation=variables_lib.VariableAggregation.SUM)

      @def_function.function(input_signature=[])
      def __call__(self):
        return self.v.read_value()

    with strategy.scope():
      model = Model()
      fh_ckpt = tracking_util.Checkpoint(model=model)
      if input_arg == 'checkpoint':
        checkpoint_or_manager = fh_ckpt
      else:
        checkpoint_or_manager = _make_checkpoint_manager(
            fh_ckpt, checkpoint_dir, strategy.cluster_resolver)
      preemption_handler = (
          failure_handling.PreemptionCheckpointHandler(
              strategy.cluster_resolver, checkpoint_or_manager,
              checkpoint_dir, termination_config))

    def distributed_train_step(current_epoch, current_step):

      @def_function.function
      def train_step():
        model.v.assign_add(constant_op.constant(1.))

      strategy.run(train_step)

      if current_step == STEPS_PER_EPOCH - 1:
        logging.info('epoch %d finished', current_epoch)

    logging.info('Start training at %d', preemption_handler.total_run_calls)

    # If the training process has been restarted, verify that the expected
    # number of checkpoints have been written. We also check
    # training_finished, because there is a corner case where the signal is
    # sent quite late and training finishes before the grace period ends.
    if training_restarted.is_set() and not training_finished.is_set():
      match_group = [
          re.search(r'.*ckpt-(\d+).index', a_file)
          for a_file in gfile.ListDirectory(checkpoint_dir)
      ]
      checkpoint_index = [
          a_match.group(1) for a_match in match_group if a_match
      ]
      if termination_config.grace_period > 0:
        # Two checkpoints were saved for the extended grace period.
        self.assertEqual(
            max([int(ckpt_index) for ckpt_index in checkpoint_index]), 2)
      else:
        self.assertEqual(
            max([int(ckpt_index) for ckpt_index in checkpoint_index]), 1)

    for epoch in range(
        preemption_handler.total_run_calls // STEPS_PER_EPOCH,
        EPOCHS_TO_RUN):
      for step in range(
          preemption_handler.total_run_calls % STEPS_PER_EPOCH,
          STEPS_PER_EPOCH):
        preemption_handler.run(distributed_train_step, epoch, step)

    logging.info('Training finished.')
    training_finished.set()

    self.assertEqual(
        model.v.numpy(),
        strategy.num_replicas_in_sync * EPOCHS_TO_RUN * STEPS_PER_EPOCH)

    running_threads = test_util.get_running_threads()
    if test_util.has_thread(_PEER_WATCHER_THREAD_PREFIX,
                            running_threads) and test_util.has_thread(
                                _LOCAL_WATCHER_THREAD_PREFIX, running_threads):
      try:
        # Explicitly call __del__ since setting the handler to None and
        # calling gc.collect does not invoke __del__ here.
        preemption_handler.__del__()

        time.sleep(2)

        running_threads = test_util.get_running_threads()
        self.assertFalse(
            test_util.has_thread(_LOCAL_WATCHER_THREAD_PREFIX,
                                 running_threads))
        self.assertFalse(
            test_util.has_thread(_PEER_WATCHER_THREAD_PREFIX,
                                 running_threads))
      except urllib.error.URLError as e:
        if 'Temporary failure in name resolution' in str(e):
          # This is caused by a flakiness where mock.patch does not
          # correctly patch gce_util.request_compute_metadata: a real
          # request is attempted, and an error is hit in
          # gce_util.request_compute_metadata.
          logging.warning('Hit a mock issue.')
          return
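# _make_checkpoint_manager is referenced by the worker_fn variants above but
# not defined in these snippets. A sketch, assuming the chief saves to the
# real checkpoint directory while non-chief workers save to per-task
# subdirectories so only one CheckpointManager is active per directory; the
# 'workertemp_%d' naming is an assumption:
def _make_checkpoint_manager(checkpoint, checkpoint_dir, cluster_resolver):
  if not cluster_resolver or not cluster_resolver.cluster_spec().as_dict() or (
      multi_worker_util.is_chief(
          cluster_spec=cluster_resolver.cluster_spec(),
          task_type=cluster_resolver.task_type,
          task_id=cluster_resolver.task_id)):
    return checkpoint_management.CheckpointManager(
        checkpoint, directory=checkpoint_dir, max_to_keep=1)
  else:
    # Non-chief workers write to a task-specific directory.
    return checkpoint_management.CheckpointManager(
        checkpoint,
        directory=os.path.join(checkpoint_dir,
                               'workertemp_%d' % cluster_resolver.task_id),
        max_to_keep=1)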
def worker_fn(self,
              checkpoint_dir,
              cluster_spec,
              input_arg='checkpoint',
              training_started_event=None,
              raise_app_error_on_worker=None,
              training_restarted=None,
              training_finished=None,
              termination_config=failure_handling.TerminationConfig()):
  _enable_coordination_service(cluster_spec)
  strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()

  class Model(module.Module):

    def __init__(self):
      self.v = variables_lib.Variable(
          0.,
          synchronization=variables_lib.VariableSynchronization.ON_WRITE,
          aggregation=variables_lib.VariableAggregation.SUM)

    @def_function.function(input_signature=[])
    def __call__(self):
      return self.v.read_value()

  with mock.patch.object(gce_util, 'on_gcp', lambda: False):
    with strategy.scope():
      model = Model()
      # Named fh_ckpt because it is better for the user to keep their
      # regular checkpoint separate from the checkpoint for
      # PreemptionCheckpointHandler, since we will create a
      # CheckpointManager to manage the checkpoint and only one
      # CheckpointManager should be active in a particular directory at a
      # time.
      fh_ckpt = tracking_util.Checkpoint(model=model)
      if input_arg == 'checkpoint':
        checkpoint_or_manager = fh_ckpt
      else:
        checkpoint_or_manager = _make_checkpoint_manager(
            fh_ckpt, checkpoint_dir, strategy.cluster_resolver)
      preemption_handler = (
          failure_handling.PreemptionCheckpointHandler(
              strategy.cluster_resolver, checkpoint_or_manager,
              checkpoint_dir, termination_config))

    def distributed_train_step(current_epoch, current_step):

      @def_function.function
      def train_step():
        if distribution_strategy_context.get_distribution_strategy(
        ).cluster_resolver.task_id == raise_app_error_on_worker:
          raise errors_impl.ResourceExhaustedError(
              node_def=None, op=None, message='Running out of resources')

        model.v.assign_add(constant_op.constant(1.))

      strategy.run(train_step)

      if current_step == STEPS_PER_EPOCH - 1:
        logging.info('epoch %d finished', current_epoch)

    logging.info('Start training at %d', preemption_handler.total_run_calls)

    # If the training process has been restarted, verify that the expected
    # number of checkpoints have been written. We also check
    # training_finished, because there is a corner case where the signal is
    # sent quite late and training finishes before the grace period ends.
    if training_restarted and training_restarted.is_set(
    ) and not training_finished.is_set():
      logging.info('training restarted')
      match_group = [
          re.search(r'.*ckpt-(\d+).index', a_file)
          for a_file in gfile.ListDirectory(checkpoint_dir)
      ]
      checkpoint_index = [
          a_match.group(1) for a_match in match_group if a_match
      ]
      if getattr(termination_config, 'grace_period', 0):
        # Two checkpoints were saved for the extended grace period.
        self.assertEqual(int(checkpoint_index[0]), 2)
      else:
        self.assertEqual(int(checkpoint_index[0]), 1)

    for epoch in range(
        preemption_handler.total_run_calls // STEPS_PER_EPOCH,
        EPOCHS_TO_RUN):
      for step in range(
          preemption_handler.total_run_calls % STEPS_PER_EPOCH,
          STEPS_PER_EPOCH):
        preemption_handler.run(distributed_train_step, epoch, step)

      # Add some randomness to when preemption actually happens. We should
      # trigger it for sure if the training is coming to an end and it
      # hasn't been triggered yet.
      if epoch >= EPOCHS_TO_RUN - 2:
        trigger_it = True
      else:
        trigger_it = False

      self._maybe_trigger_a_preemption(training_started_event, trigger_it)

    training_finished.set()
    logging.info('Training finished.')

    self.assertEqual(
        model.v.numpy(),
        strategy.num_replicas_in_sync * EPOCHS_TO_RUN * STEPS_PER_EPOCH)
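# _maybe_trigger_a_preemption is called by the worker_fn variants above but
# not defined in these snippets. A plausible sketch: it sets the first
# pending event in the passed-in list, which a watcher (the main test
# process, or a sending_sigterm-style thread) turns into a SIGTERM. Both the
# event-list protocol and the 0.5 probability are assumptions:
def _maybe_trigger_a_preemption(self, training_started_event,
                                trigger_it=False):
  if not training_started_event:
    return
  clear_events = [
      event for event in training_started_event if not event.is_set()
  ]
  if clear_events:
    if trigger_it:
      logging.info('Set preemption signal')
      clear_events[0].set()
    elif random.uniform(0, 1) < 0.5:
      clear_events[0].set()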
def test_grace_period_continue_training(self, input_arg, mwms_mode):
  if _is_oss():
    rpc_layer = 'grpc'
  else:
    rpc_layer = 'grpc+loas'

  checkpoint_dir = os.path.join(self.get_temp_dir(), 'fh_ckpt')

  if mwms_mode == 'multi_worker':
    grace_period = 5
    termination_config = failure_handling.TerminationConfig(
        grace_period=grace_period)
    has_chief = False
    cluster_spec = multi_worker_test_base.create_cluster_spec(
        has_chief=has_chief, num_workers=CLUSTER_SIZE)
    training_started_event = multi_process_runner.manager().Event()
    training_restarted = multi_process_runner.manager().Event()
    training_finished = multi_process_runner.manager().Event()

    mpr = multi_process_runner.MultiProcessRunner(
        self.worker_fn,
        cluster_spec,
        args=(checkpoint_dir, cluster_spec, input_arg,
              [training_started_event], None, training_restarted,
              training_finished, termination_config),
        rpc_layer=rpc_layer,
        return_output=True,
        dependence_on_chief=has_chief)

    logging.info('Cluster starting.')
    mpr.start()
    while not training_started_event.is_set():
      time.sleep(1)

    killed_worker = random.randrange(0, CLUSTER_SIZE)
    logging.info('sending SIGTERM')
    os.kill(mpr.get_process_id('worker', killed_worker), signal.SIGTERM)
    logging.info('SIGTERM sent')

    raise_if_not_all_exit(grace_period, mpr)

    logging.info('restarting workers')
    training_restarted.set()
    for worker_id in range(CLUSTER_SIZE):
      mpr.start_single_process('worker', worker_id, cluster_spec)
    logging.info('workers restarted')
    mpr.join(timeout=250)

  else:
    # A single worker trains very fast relative to the size of the "model"
    # here. With a longer grace period, training simply finishes within the
    # grace period, so we could not verify the exit behavior.
    grace_period = 1
    termination_config = failure_handling.TerminationConfig(
        grace_period=grace_period)
    cluster_spec = server_lib.ClusterSpec({})
    training_started_event = threading.Event()
    training_restarted = threading.Event()
    training_finished = threading.Event()

    def sending_sigterm(training_started_event):
      while not training_started_event.is_set():
        time.sleep(1)
      logging.info('sending sigterm')
      training_started_event.set()
      os.kill(os.getpid(), signal.SIGTERM)

    preemption_sender_thread = threading.Thread(
        target=sending_sigterm, args=(training_started_event,))
    preemption_sender_thread.start()

    caught_exit = False
    try:
      self.worker_fn(checkpoint_dir, cluster_spec, input_arg,
                     [training_started_event], None, training_restarted,
                     training_finished, termination_config)
    except SystemExit as exit_error:
      caught_exit = True
      # We cannot use assertRaises instead, since termination is not always
      # triggered.
      self.assertEqual(exit_error.code, 42)  # pylint: disable=g-assert-in-except

    preemption_sender_thread.join(10)
    if not training_finished.is_set():
      self.assertTrue(caught_exit)

      logging.info('restarting workers')
      training_restarted.set()
      self.worker_fn(checkpoint_dir, cluster_spec, input_arg,
                     [training_started_event], None, training_restarted,
                     training_finished, termination_config)
def worker_fn(self,
              checkpoint_dir,
              cluster_spec,
              training_started_event=None,
              raise_app_error_on_worker=None,
              termination_config=failure_handling.TerminationConfig()):
  _enable_coordination_service(cluster_spec)
  strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()

  class Model(module.Module):

    def __init__(self):
      self.v = variables_lib.Variable(
          0.,
          synchronization=variables_lib.VariableSynchronization.ON_WRITE,
          aggregation=variables_lib.VariableAggregation.SUM)

    @def_function.function(input_signature=[])
    def __call__(self):
      return self.v.read_value()

  with mock.patch.object(gce_util, 'on_gcp', lambda: False):
    with strategy.scope():
      model = Model()
      # Named fh_ckpt because it is better for the user to keep their
      # regular checkpoint separate from the checkpoint for
      # WorkerPreemptionHandler, since we will create a CheckpointManager
      # to manage the checkpoint and only one CheckpointManager should be
      # active in a particular directory at a time.
      fh_ckpt = tracking_util.Checkpoint(model=model)

      worker_preemption_watcher = failure_handling.WorkerPreemptionHandler(
          strategy.cluster_resolver, fh_ckpt, checkpoint_dir,
          termination_config)

    def distributed_train_step(current_epoch, current_step):

      @def_function.function
      def train_step():
        if distribution_strategy_context.get_distribution_strategy(
        ).cluster_resolver.task_id == raise_app_error_on_worker:
          raise errors_impl.ResourceExhaustedError(
              node_def=None, op=None, message='Running out of resources')

        model.v.assign_add(constant_op.constant(1.))

      strategy.run(train_step)

      if current_step == STEPS_PER_EPOCH - 1:
        logging.info('epoch %d finished', current_epoch)

    logging.info('Restored training at %d',
                 worker_preemption_watcher.total_runs)
    for epoch in range(
        worker_preemption_watcher.total_runs // STEPS_PER_EPOCH,
        EPOCHS_TO_RUN):
      for step in range(
          worker_preemption_watcher.total_runs % STEPS_PER_EPOCH,
          STEPS_PER_EPOCH):
        worker_preemption_watcher.run(distributed_train_step, epoch, step)

      # Add some randomness to when preemption actually happens. We should
      # trigger it for sure if the training is coming to an end and it
      # hasn't been triggered yet.
      if epoch >= EPOCHS_TO_RUN - 2:
        trigger_it = True
      else:
        trigger_it = False

      self._maybe_trigger_a_preemption(training_started_event, trigger_it)

    logging.info('Training finished.')

    self.assertEqual(
        model.v.numpy(),
        strategy.num_replicas_in_sync * EPOCHS_TO_RUN * STEPS_PER_EPOCH)
def test_grace_period_continue_training(self):
  grace_period = 5
  has_chief = False
  cluster_spec = multi_worker_test_base.create_cluster_spec(
      has_chief=has_chief, num_workers=CLUSTER_SIZE)
  training_started_event = multi_process_runner.manager().Event()

  checkpoint_dir = os.path.join(self.get_temp_dir(), 'fh_ckpt')

  if _is_oss():
    rpc_layer = 'grpc'
  else:
    rpc_layer = 'grpc+loas'

  termination_config = failure_handling.TerminationConfig(
      time_till_termination=grace_period)
  mpr = multi_process_runner.MultiProcessRunner(
      self.worker_fn,
      cluster_spec,
      args=(checkpoint_dir, cluster_spec, [training_started_event], None,
            termination_config),
      rpc_layer=rpc_layer,
      return_output=True,
      dependence_on_chief=has_chief)

  logging.info('Cluster starting.')
  mpr.start()
  while not training_started_event.is_set():
    time.sleep(1)

  killed_worker = random.randrange(0, CLUSTER_SIZE)
  logging.info('sending SIGTERM')
  os.kill(mpr.get_process_id('worker', killed_worker), signal.SIGTERM)
  logging.info('SIGTERM sent')

  # Wait for all workers to exit within the given grace period (plus a
  # buffer, since the per-step time here is very short).
  waiting_time = 0
  exit_process_count = 0
  while exit_process_count != CLUSTER_SIZE and waiting_time < grace_period + 10:
    exit_process_count = 0
    for worker_id in range(CLUSTER_SIZE):
      if not mpr.process_exists('worker', worker_id):
        exit_process_count += 1
    waiting_time += 1
    time.sleep(1)

  if waiting_time == grace_period + 10:
    raise RuntimeError('Waiting time exceeded the grace period.')

  logging.info('restarting workers')
  for worker_id in range(CLUSTER_SIZE):
    mpr.start_single_process('worker', worker_id, cluster_spec)
  logging.info('workers restarted')

  stdout = mpr.join(timeout=250).stdout
  all_start_point = []
  checkpoint_count = []
  for msg in stdout:
    # TODO(wxinyi): remove the string matching and assert checkpoint number.
    matched_group = re.search(r'.*Restored training at (\d+)', msg)
    checkpoint_group = re.search(r'.*RUN_TO_CHECKPOINT set to (\d+)', msg)

    if matched_group:
      all_start_point.append(int(matched_group.group(1)))

    if checkpoint_group:
      checkpoint_count.append(int(checkpoint_group.group(1)))

  # Remove duplicate logs created due to the presence of multiple workers.
  start_points = all_start_point[::CLUSTER_SIZE]

  # Assert that after restarting, we don't repeat previous training steps.
  self.assertNotEqual(start_points[-1], 0)

  # One for timing, another for the final call.
  self.assertLen(set(checkpoint_count), 2)
def worker_fn(self,
              checkpoint_dir,
              cluster_spec,
              maintenance_event=None,
              training_finished=None,
              frequent_send=False,
              termination_config=failure_handling.TerminationConfig()):
  _enable_coordination_service(cluster_spec)
  strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()

  def mock_termination_watcher_function_gce(*args, **kwargs):
    del args, kwargs
    if not frequent_send:
      time.sleep(1)
      if (not maintenance_event.is_set()) and (random.randrange(0, 20) > 18):
        maintenance_event.set()
        logging.info('Termination notice available.')
        return True
    elif frequent_send and not maintenance_event.is_set():
      logging.info('Termination notice available.')
      return True

    return False

  with mock.patch.object(
      gce_util, 'termination_watcher_function_gce',
      mock_termination_watcher_function_gce), mock.patch.object(
          gce_util, 'detect_platform',
          lambda: gce_util.PlatformDevice.GCE_GPU):

    class Model(module.Module):

      def __init__(self):
        self.v = variables_lib.Variable(
            0.,
            synchronization=variables_lib.VariableSynchronization.ON_WRITE,
            aggregation=variables_lib.VariableAggregation.SUM)

      @def_function.function(input_signature=[])
      def __call__(self):
        return self.v.read_value()

    with strategy.scope():
      model = Model()
      fh_ckpt = tracking_util.Checkpoint(model=model)

      worker_preemption_watcher = failure_handling.WorkerPreemptionHandler(
          strategy.cluster_resolver, fh_ckpt, checkpoint_dir,
          termination_config)

    def distributed_train_step(current_epoch, current_step):

      @def_function.function
      def train_step():
        model.v.assign_add(constant_op.constant(1.))

      strategy.run(train_step)

      if current_step == STEPS_PER_EPOCH - 1:
        logging.info('epoch %d finished', current_epoch)

    logging.info('Start training at %d',
                 worker_preemption_watcher.total_runs)
    for epoch in range(
        worker_preemption_watcher.total_runs // STEPS_PER_EPOCH,
        EPOCHS_TO_RUN):
      for step in range(
          worker_preemption_watcher.total_runs % STEPS_PER_EPOCH,
          STEPS_PER_EPOCH):
        worker_preemption_watcher.run(distributed_train_step, epoch, step)

    logging.info('Training finished.')
    training_finished.set()

    self.assertEqual(
        model.v.numpy(),
        strategy.num_replicas_in_sync * EPOCHS_TO_RUN * STEPS_PER_EPOCH)

    running_threads = test_util.get_running_threads()
    if test_util.has_thread(_PEER_WATCHER_THREAD_PREFIX,
                            running_threads) and test_util.has_thread(
                                _LOCAL_WATCHER_THREAD_PREFIX, running_threads):
      try:
        # Explicitly call __del__ since setting the watcher to None and
        # calling gc.collect does not invoke __del__ here.
        worker_preemption_watcher.__del__()

        time.sleep(2)

        running_threads = test_util.get_running_threads()
        self.assertFalse(
            test_util.has_thread(_LOCAL_WATCHER_THREAD_PREFIX,
                                 running_threads))
        self.assertFalse(
            test_util.has_thread(_PEER_WATCHER_THREAD_PREFIX,
                                 running_threads))
      except urllib.error.URLError as e:
        if 'Temporary failure in name resolution' in str(e):
          # This is caused by a flakiness where mock.patch does not
          # correctly patch gce_util.request_compute_metadata: a real
          # request is attempted, and an error is hit in
          # gce_util.request_compute_metadata.
          logging.warning('Hit a mock issue.')
          return
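# _enable_coordination_service is called at the top of every worker_fn
# variant but not defined in these snippets. A sketch, assuming it turns on
# TensorFlow's standalone coordination service for the given cluster; the
# exact configuration call and the use of multi_worker_util.coordination_leader
# are assumptions:
def _enable_coordination_service(cluster_spec):
  if cluster_spec:
    coordination_service = 'standalone'
    coordinated_jobs = ['chief', 'worker']
    context.context().configure_coordination_service(
        service_type=coordination_service,
        service_leader=multi_worker_util.coordination_leader(cluster_spec),
        coordinated_jobs=coordinated_jobs)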
def test_grace_period_continue_training(self):
  grace_period = 7
  has_chief = False
  cluster_spec = multi_worker_test_base.create_cluster_spec(
      has_chief=has_chief, num_workers=CLUSTER_SIZE)

  checkpoint_dir = os.path.join(self.get_temp_dir(), 'fh_ckpt')
  maintenance_event = multi_process_runner.manager().Event()
  training_finished = multi_process_runner.manager().Event()

  if _is_oss():
    rpc_layer = 'grpc'
  else:
    rpc_layer = 'grpc+loas'

  termination_config = failure_handling.TerminationConfig(
      time_till_termination=grace_period)
  mpr = multi_process_runner.MultiProcessRunner(
      self.worker_fn,
      cluster_spec,
      args=(checkpoint_dir, cluster_spec, maintenance_event,
            training_finished, False, termination_config),
      rpc_layer=rpc_layer,
      return_output=True,
      dependence_on_chief=has_chief)

  logging.info('Cluster starting.')
  mpr.start()

  while (not maintenance_event.is_set()) and (not training_finished.is_set()):
    time.sleep(1)

  # The extra sleep mitigates the fact that the per-step time in this test is
  # very short.
  time.sleep(grace_period + 10)

  if not training_finished.is_set():
    logging.info('restarting workers')
    for worker_id in range(CLUSTER_SIZE):
      mpr.start_single_process('worker', worker_id, cluster_spec)
    logging.info('workers restarted')

  if maintenance_event.is_set():
    stdout = mpr.join(timeout=250).stdout
    all_start_point = []
    checkpoint_count = []
    for msg in stdout:
      matched_group = re.search(r'.*Start training at (\d+)', msg)
      checkpoint_group = re.search(r'.*RUN_TO_CHECKPOINT set to (\d+)', msg)

      if matched_group:
        all_start_point.append(int(matched_group.group(1)))

      if checkpoint_group:
        checkpoint_count.append(int(checkpoint_group.group(1)))

    # Remove duplicate logs created due to the presence of multiple workers.
    start_points = all_start_point[::CLUSTER_SIZE]

    # If maintenance_event is set at the very end of training and training
    # completes, there won't be a restart.
    if len(start_points) > 1:
      # Assert that after restarting, we don't repeat previous training
      # steps.
      self.assertNotEqual(start_points[-1], 0)

      # One for timing, another for the final call.
      self.assertLen(set(checkpoint_count), 2)
def test_multiple_workers_preempted_consecutively(self, grace_period):
  has_chief = False
  cluster_spec = multi_worker_test_base.create_cluster_spec(
      has_chief=has_chief, num_workers=CLUSTER_SIZE)
  maintenance_event = multi_process_runner.manager().Event()
  training_finished = multi_process_runner.manager().Event()

  checkpoint_dir = os.path.join(self.get_temp_dir(), 'fh_ckpt')

  if _is_oss():
    rpc_layer = 'grpc'
  else:
    rpc_layer = 'grpc+loas'

  termination_config = failure_handling.TerminationConfig(
      time_till_termination=grace_period)
  mpr = multi_process_runner.MultiProcessRunner(
      self.worker_fn,
      cluster_spec,
      args=(checkpoint_dir, cluster_spec, maintenance_event,
            training_finished, True, termination_config),
      rpc_layer=rpc_layer,
      return_output=True,
      dependence_on_chief=has_chief)

  logging.info('Cluster starting.')
  mpr.start()

  # Wait for all workers in the cluster to exit, with a timeout. The buffer
  # beyond the grace period mitigates the fact that the per-step time in this
  # test is very short.
  waiting_time = 0
  exit_process_count = 0
  while exit_process_count != CLUSTER_SIZE and waiting_time < max(
      grace_period + 15, 40):
    exit_process_count = 0
    for worker_id in range(CLUSTER_SIZE):
      if not mpr.process_exists('worker', worker_id):
        exit_process_count += 1
    waiting_time += 1
    time.sleep(1)

  if waiting_time == max(grace_period + 15, 40):
    raise RuntimeError('Waited a long time, but at least one worker still '
                       'exists. Considering the size of our model, this '
                       'should not happen.')

  maintenance_event.set()

  logging.info('restarting workers')
  for worker_id in range(CLUSTER_SIZE):
    mpr.start_single_process('worker', worker_id, cluster_spec)
  logging.info('workers restarted')

  stdout = mpr.join(timeout=250).stdout
  found_message = 0
  checkpoint_count = []
  for msg in stdout:
    matched_group = re.search(r'.*has received termination notice*', msg)
    checkpoint_group = re.search(r'.*RUN_TO_CHECKPOINT set to (\d+)', msg)
    if matched_group:
      found_message += 1
    if checkpoint_group:
      checkpoint_count.append(int(checkpoint_group.group(1)))

  self.assertGreaterEqual(found_message, 1)
  if grace_period > 0:
    self.assertLen(set(checkpoint_count), 2)
def test_grace_period_continue_training(self, input_arg, mwms_mode):
  checkpoint_dir = os.path.join(self.get_temp_dir(), 'fh_ckpt/')
  grace_period = 7

  if _is_oss():
    rpc_layer = 'grpc'
  else:
    rpc_layer = 'grpc+loas'

  termination_config = failure_handling.TerminationConfig(
      grace_period=grace_period)

  if mwms_mode == 'multi_worker':
    has_chief = False
    cluster_spec = multi_worker_test_base.create_cluster_spec(
        has_chief=has_chief, num_workers=CLUSTER_SIZE)

    maintenance_event = multi_process_runner.manager().Event()
    training_finished = multi_process_runner.manager().Event()
    training_restarted = multi_process_runner.manager().Event()

    mpr = multi_process_runner.MultiProcessRunner(
        self.worker_fn,
        cluster_spec,
        args=(checkpoint_dir, cluster_spec, input_arg, maintenance_event,
              training_finished, False, training_restarted,
              termination_config),
        rpc_layer=rpc_layer,
        return_output=True,
        dependence_on_chief=has_chief)

    logging.info('Cluster starting.')
    mpr.start()

    while (not maintenance_event.is_set()) and (
        not training_finished.is_set()):
      time.sleep(1)

    raise_if_not_all_exit(grace_period, mpr)

    if not training_finished.is_set():
      logging.info('restarting workers')
      training_restarted.set()
      for worker_id in range(CLUSTER_SIZE):
        mpr.start_single_process('worker', worker_id, cluster_spec)
      logging.info('workers restarted')

    mpr.join(timeout=250)
    self.assertTrue(training_finished.is_set())

  else:
    maintenance_event = threading.Event()
    training_finished = threading.Event()
    training_restarted = threading.Event()

    cluster_spec = server_lib.ClusterSpec({})
    caught_exit = False
    try:
      self.worker_fn(checkpoint_dir, cluster_spec, input_arg,
                     maintenance_event, training_finished, False,
                     training_restarted, termination_config)
    except SystemExit as exit_error:
      caught_exit = True
      # We cannot use assertRaises instead, since termination is not always
      # triggered.
      self.assertEqual(exit_error.code, 143)  # pylint: disable=g-assert-in-except

    if maintenance_event.is_set() and not training_finished.is_set():
      self.assertTrue(caught_exit)

      logging.info('restarting workers')
      training_restarted.set()
      self.worker_fn(checkpoint_dir, cluster_spec, input_arg,
                     maintenance_event, training_finished, False,
                     training_restarted, termination_config)
      self.assertTrue(training_finished.is_set())