def test_template(self, strategy_cls, file_format):
  num_workers = 2
  num_epoch = 2

  cluster_spec = test_base.create_cluster_spec(
      num_workers=num_workers, test_obj=self)
  self._barrier = dc._Barrier(2)

  def _independent_worker_fn(*args, **kwargs):  # pylint: disable=unused-argument
    """Simulates an Independent Worker inside of a thread."""
    with test.mock.patch.object(dc, '_run_std_server',
                                self._make_mock_run_std_server()):
      strategy = get_strategy_object(strategy_cls)
      batch_size = 64
      steps = 2
      train_ds, _ = multi_worker_testing_utils.mnist_synthetic_dataset(
          batch_size, steps)
      with strategy.scope():
        model = multi_worker_testing_utils.get_mnist_model((28, 28, 1))
      custom_callable(
          model,
          self,
          train_ds,
          num_epoch,
          steps,
          strategy,
          saving_filepath=kwargs['saving_filepath'],
          barrier=kwargs['barrier'],
          threading_local=kwargs['threading_local'])

  # Pass saving_filepath from the parent thread to ensure every worker has
  # the same filepath to save.
  saving_filepath = os.path.join(self.get_temp_dir(),
                                 'checkpoint.' + file_format)
  barrier = dc._Barrier(2)
  threading_local = threading.local()
  threads = self.run_multiple_tasks_in_threads(
      _independent_worker_fn,
      cluster_spec,
      saving_filepath=saving_filepath,
      barrier=barrier,
      threading_local=threading_local)
  self.assertFalse(training_state.checkpoint_exists(saving_filepath))

  threads_to_join = []
  strategy = get_strategy_object(strategy_cls)
  if strategy.extended.experimental_between_graph:
    for ts in threads.values():
      threads_to_join.extend(ts)
  else:
    threads_to_join = [threads['worker'][0]]
  self.join_independent_workers(threads_to_join)
def run_independent_workers(self,
                            worker_fn,
                            strategy_cls,
                            num_workers,
                            num_ps=None,
                            **kwargs):
  cluster_spec = multi_worker_test_base.create_cluster_spec(
      num_workers=num_workers, num_ps=num_ps)
  self._barrier = dc._Barrier(num_workers + (num_ps or 0))  # pylint: disable=protected-access

  def _worker_fn(**kwargs):
    """Runs the worker function in a thread."""
    with test.mock.patch.object(dc, '_run_std_server',
                                self._make_mock_run_std_server()):
      strategy = get_strategy_object(strategy_cls)
      with strategy.scope():
        return worker_fn(**kwargs)

  threads = self.run_multiple_tasks_in_threads(_worker_fn, cluster_spec,
                                               **kwargs)
  strategy = get_strategy_object(strategy_cls)
  if strategy.extended.experimental_between_graph:
    threads_to_join = threads.get('chief', []) + threads.get('worker', [])
  else:
    threads_to_join = [
        threads['chief'][0] if 'chief' in threads else threads['worker'][0]
    ]
  self.join_independent_workers(threads_to_join)
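# `get_strategy_object` is used throughout these tests but is not defined in
# this excerpt. Below is a minimal sketch of one plausible implementation;
# the branch condition, the `mirrored_strategy` module reference, and the
# constructor arguments are assumptions for illustration, not the canonical
# helper:
def get_strategy_object(strategy_cls):
  if strategy_cls == mirrored_strategy.MirroredStrategy:
    # MirroredStrategy replicates in-graph and takes an explicit device list.
    return strategy_cls(mirrored_strategy.all_local_devices())
  # Multi-worker strategies in this era of the API took the number of GPUs
  # per worker at construction time.
  return strategy_cls(num_gpus_per_worker=context.num_gpus())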
def test_complete_flow_independent_worker_between_graph(
    self, train_distribute_cls, eval_distribute_cls):
  train_distribute = train_distribute_cls(
      num_gpus_per_worker=context.num_gpus())

  if eval_distribute_cls:
    eval_distribute = eval_distribute_cls()
  else:
    eval_distribute = None

  cluster_spec = multi_worker_test_base.create_cluster_spec(
      num_workers=3, num_ps=2, has_eval=True)
  # 3 workers, 2 ps and 1 evaluator.
  self._barrier = dc._Barrier(6)

  threads = self._run_multiple_tasks_in_threads(cluster_spec,
                                                train_distribute,
                                                eval_distribute)
  for task_type, ts in threads.items():
    if task_type == PS:
      continue
    for t in ts:
      t.join()

  estimator = self._get_estimator(train_distribute, eval_distribute)
  self._inspect_train_and_eval_events(estimator)
def setUp(self):
  self._result_correct = 0
  self._lock = threading.Lock()
  self._worker_context = {}
  self._strategy_property = {}
  self._std_servers = {}
  self._barrier = distribute_coordinator._Barrier(NUM_WORKERS)
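# `_Barrier` comes from `distribute_coordinator` and is reset by hand in the
# fault-tolerance test further below (`self._barrier._counter = 0`,
# `self._barrier._flag = False`). A minimal sketch consistent with those two
# attributes follows; this is an illustration of the idea, not the library's
# actual code:
class _Barrier(object):
  """Blocks each caller until `num_participants` threads have arrived."""

  def __init__(self, num_participants):
    self._num_participants = num_participants
    self._counter = 0
    self._flag = False
    self._condition = threading.Condition()

  def wait(self):
    with self._condition:
      self._counter += 1
      if self._counter == self._num_participants:
        # Last arrival releases everyone; `_flag` stays set until a test
        # resets it manually, matching the usage in the excerpt.
        self._flag = True
        self._condition.notify_all()
      while not self._flag:
        self._condition.wait()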
def run_independent_workers(self,
                            worker_fn,
                            strategy_cls,
                            num_workers,
                            num_ps=None,
                            **kwargs):
  cluster_spec = multi_worker_test_base.create_cluster_spec(
      num_workers=num_workers, num_ps=num_ps)
  self._barrier = dc._Barrier(num_workers + (num_ps or 0))  # pylint: disable=protected-access

  def _worker_fn(**kwargs):
    """Runs the worker function in a thread."""
    with test.mock.patch.object(dc, '_run_std_server',
                                self._make_mock_run_std_server()):
      strategy = get_strategy_object(strategy_cls)
      with strategy.scope():
        return worker_fn(**kwargs)

  threads = self.run_multiple_tasks_in_threads(_worker_fn, cluster_spec,
                                               **kwargs)
  strategy = get_strategy_object(strategy_cls)
  if strategy.extended.experimental_between_graph:
    # Join all non-ps threads: the ps threads run std servers that never
    # terminate, so joining them would hang the test.
    threads_to_join = []
    for task_type, ts in threads.items():
      if task_type == 'ps':
        continue
      threads_to_join.extend(ts)
  else:
    threads_to_join = [threads['worker'][0]]
  self.join_independent_workers(threads_to_join)
def testSimpleModelIndependentWorkerAsync(self, strategy_cls):
  num_workers = 2
  num_epoch = 2
  cluster_spec = test_base.create_cluster_spec(
      num_workers=num_workers, num_ps=2, test_obj=self)
  self._barrier = dc._Barrier(4)

  # The verification callback will be shared by multiple threads.
  verification_callback = MultiWorkerVerificationCallback(
      num_epoch=num_epoch, num_worker=num_workers)

  def _independent_worker_fn(*args, **kwargs):  # pylint: disable=unused-argument
    """Simulates an Independent Worker inside of a thread."""
    # TODO(rchao/yuefengz): The following is run by both worker and ps
    # threads. The distribute coordinator should run std server immediately
    # without configuring the session (or building the graph) on PS.
    with test.mock.patch.object(dc, '_run_std_server',
                                self._make_mock_run_std_server()):
      batch_size = 64
      steps = 2
      strategy = strategy_cls()
      verification_callback.is_between_graph = \
          strategy.extended.experimental_between_graph

      train_ds, _ = multi_worker_testing_utils.mnist_synthetic_dataset(
          batch_size, steps)
      val_ds, _ = multi_worker_testing_utils.mnist_synthetic_dataset(
          batch_size, steps)
      with strategy.scope():
        model = multi_worker_testing_utils.get_mnist_model((28, 28, 1))

      # TODO(b/123868066): Verify callback for model.evaluate().
      callbacks_for_fit = nest.flatten(
          kwargs.get('verification_callback', []))
      history = model.fit(
          x=train_ds,
          epochs=num_epoch,
          steps_per_epoch=steps,
          validation_data=val_ds,
          validation_steps=steps,
          callbacks=callbacks_for_fit)
      self.assertIsInstance(history, keras.callbacks.History)

  threads = self.run_multiple_tasks_in_threads(
      _independent_worker_fn,
      cluster_spec,
      verification_callback=verification_callback)

  threads_to_join = []
  for task_type, ts in threads.items():
    # This test can finish once the worker threads complete, and thus
    # the ps threads don't need to be joined.
    if task_type == 'ps':
      continue
    threads_to_join.extend(ts)
  self.join_independent_workers(threads_to_join)
  verification_callback.verify(self)
def testSimpleModelIndependentWorkerAsync(self, strategy_cls):
  num_workers = 2
  num_epoch = 2
  cluster_spec = test_base.create_cluster_spec(
      num_workers=num_workers, num_ps=2)
  self._barrier = dc._Barrier(4)

  # The verification callback will be shared by multiple threads.
  verification_callback = MultiWorkerVerificationCallback(
      num_epoch=num_epoch, num_worker=num_workers)

  def _independent_worker_fn(*args, **kwargs):  # pylint: disable=unused-argument
    """Simulates an Independent Worker inside of a thread."""
    # TODO(rchao/yuefengz): The following is run by both worker and ps
    # threads. The distribute coordinator should run std server immediately
    # without configuring the session (or building the graph) on PS.
    with test.mock.patch.object(dc, '_run_std_server',
                                self._make_mock_run_std_server()):
      batch_size = 64
      steps = 2
      strategy = strategy_cls()
      verification_callback.is_between_graph = \
          strategy.extended.experimental_between_graph

      train_ds, _ = _mnist_synthetic_dataset(batch_size, steps)
      val_ds, _ = _mnist_synthetic_dataset(batch_size, steps)
      with strategy.scope():
        model = _get_model((28, 28, 1))

      # TODO(b/123868066): Verify callback for model.evaluate().
      callbacks_for_fit = nest.flatten(
          kwargs.get('verification_callback', []))
      history = model.fit(
          x=train_ds,
          epochs=num_epoch,
          steps_per_epoch=steps,
          validation_data=val_ds,
          validation_steps=steps,
          callbacks=callbacks_for_fit)
      self.assertIsInstance(history, keras.callbacks.History)

  threads = self.run_multiple_tasks_in_threads(
      _independent_worker_fn,
      cluster_spec,
      verification_callback=verification_callback)

  threads_to_join = []
  for task_type, ts in threads.items():
    # This test can finish once the worker threads complete, and thus
    # the ps threads don't need to be joined.
    if task_type == 'ps':
      continue
    threads_to_join.extend(ts)
  self.join_independent_workers(threads_to_join)
  verification_callback.verify(self)
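# Every test above patches `dc._run_std_server` with
# `self._make_mock_run_std_server()`. A hedged sketch of what such a factory
# could look like: the barrier rendezvous is inferred from `self._barrier` in
# `setUp`, and `original_run_std_server` is assumed to be a module-level alias
# saved before patching. Details are illustrative, not the real test base:
def _make_mock_run_std_server(self):
  thread_local = threading.local()

  def _mock_run_std_server(*args, **kwargs):
    # Run the real std server, then rendezvous with the other tasks so all
    # servers are up before any session tries to connect. Only wait on the
    # barrier the first time this thread starts a server.
    ret = original_run_std_server(*args, **kwargs)
    if not getattr(thread_local, 'server_started', False):
      self._barrier.wait()
    thread_local.server_started = True
    return ret

  return _mock_run_std_server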
def test_complete_flow_independent_worker_between_graph(
    self, train_distribute_cls, eval_distribute_cls):
  if (context.num_gpus() < 2 and eval_distribute_cls ==
      tf.distribute.experimental.MultiWorkerMirroredStrategy):
    self.skipTest("`CollectiveAllReduceStrategy` needs at least two towers.")

  if (train_distribute_cls ==
      tf.distribute.experimental.ParameterServerStrategy):
    cluster_spec = multi_worker_test_base.create_cluster_spec(
        num_workers=3, num_ps=2, has_eval=True)
    # 3 workers, 2 ps and 1 evaluator.
    self._barrier = dc._Barrier(6)
  else:
    cluster_spec = multi_worker_test_base.create_cluster_spec(
        num_workers=3, num_ps=0, has_eval=True)
    # 3 workers and 1 evaluator.
    self._barrier = dc._Barrier(4)

  train_distribute = self._get_strategy_object(
      train_distribute_cls, cluster_spec=cluster_spec)

  if eval_distribute_cls:
    eval_distribute = self._get_strategy_object(
        eval_distribute_cls, eval_strategy=True)
  else:
    eval_distribute = None

  threads = self.run_multiple_tasks_in_threads(self._independent_worker_fn,
                                               cluster_spec, train_distribute,
                                               eval_distribute)
  threads_to_join = []
  for task_type, ts in threads.items():
    if task_type == PS:
      continue
    for t in ts:
      threads_to_join.append(t)
  self.join_independent_workers(threads_to_join)

  estimator = self._get_estimator(train_distribute, eval_distribute)
  self._inspect_train_and_eval_events(estimator)
def test_complete_flow_independent_worker_between_graph(
    self, train_distribute_cls, eval_distribute_cls):
  train_distribute = train_distribute_cls(
      num_gpus_per_worker=context.num_gpus())

  if (context.num_gpus() < 2 and eval_distribute_cls ==
      collective_all_reduce_strategy.CollectiveAllReduceStrategy):
    self.skipTest("`CollectiveAllReduceStrategy` needs at least two towers.")

  if eval_distribute_cls:
    eval_distribute = eval_distribute_cls(
        num_gpus_per_worker=context.num_gpus())
  else:
    eval_distribute = None

  if (train_distribute_cls ==
      parameter_server_strategy.ParameterServerStrategy):
    cluster_spec = multi_worker_test_base.create_cluster_spec(
        num_workers=3, num_ps=2, has_eval=True)
    # 3 workers, 2 ps and 1 evaluator.
    self._barrier = dc._Barrier(6)
  else:
    cluster_spec = multi_worker_test_base.create_cluster_spec(
        num_workers=3, num_ps=0, has_eval=True)
    # 3 workers and 1 evaluator.
    self._barrier = dc._Barrier(4)

  threads = self._run_multiple_tasks_in_threads(cluster_spec,
                                                train_distribute,
                                                eval_distribute)
  for task_type, ts in threads.items():
    if task_type == PS:
      continue
    for t in ts:
      t.join()

  estimator = self._get_estimator(train_distribute, eval_distribute)
  self._inspect_train_and_eval_events(estimator)
def test_complete_flow_independent_worker_between_graph(
    self, train_distribute_cls, eval_distribute_cls):
  if (context.num_gpus() < 2 and eval_distribute_cls ==
      collective_all_reduce_strategy.CollectiveAllReduceStrategy):
    self.skipTest("`CollectiveAllReduceStrategy` needs at least two towers.")
  train_distribute = self._get_strategy_object(train_distribute_cls)

  if eval_distribute_cls:
    eval_distribute = self._get_strategy_object(
        eval_distribute_cls, eval_strategy=True)
  else:
    eval_distribute = None

  if (train_distribute_cls ==
      parameter_server_strategy.ParameterServerStrategy):
    cluster_spec = multi_worker_test_base.create_cluster_spec(
        num_workers=3, num_ps=2, has_eval=True)
    # 3 workers, 2 ps and 1 evaluator.
    self._barrier = dc._Barrier(6)
  else:
    cluster_spec = multi_worker_test_base.create_cluster_spec(
        num_workers=3, num_ps=0, has_eval=True)
    # 3 workers and 1 evaluator.
    self._barrier = dc._Barrier(4)

  threads = self.run_multiple_tasks_in_threads(self._independent_worker_fn,
                                               cluster_spec, train_distribute,
                                               eval_distribute)
  threads_to_join = []
  for task_type, ts in threads.items():
    if task_type == PS:
      continue
    for t in ts:
      threads_to_join.append(t)
  self.join_independent_workers(threads_to_join)
  estimator = self._get_estimator(train_distribute, eval_distribute)
  self._inspect_train_and_eval_events(estimator)
def testSimpleModelIndependentWorkerSync(self, strategy_cls):
  num_workers = 2
  num_epoch = 2

  cluster_spec = tf.__internal__.distribute.multi_process_runner.create_cluster_spec(
      num_workers=num_workers)
  self._barrier = dc._Barrier(2)

  # The verification callback will be shared by multiple threads.
  verification_callback = MultiWorkerVerificationCallback(
      num_epoch=num_epoch, num_worker=num_workers)

  def _independent_worker_fn(*args, **kwargs):  # pylint: disable=unused-argument
    """Simulates an Independent Worker inside of a thread."""
    with tf.compat.v1.test.mock.patch.object(
        dc, '_run_std_server', self._make_mock_run_std_server()):
      strategy = strategy_cls()
      verification_callback.is_between_graph = \
          strategy.extended.experimental_between_graph
      batch_size = 64
      steps = 2
      train_ds, _ = multi_worker_testing_utils.mnist_synthetic_dataset(
          batch_size, steps)
      with strategy.scope():
        model = multi_worker_testing_utils.get_mnist_model((28, 28, 1))
      orig_loss, _ = model.evaluate(train_ds, steps=steps)
      callbacks_for_fit = tf.nest.flatten(
          kwargs.get('verification_callback', []))
      history = model.fit(
          x=train_ds,
          epochs=num_epoch,
          steps_per_epoch=steps,
          callbacks=callbacks_for_fit)
      self.assertIsInstance(history, keras.callbacks.History)
      trained_loss, _ = model.evaluate(train_ds, steps=steps)
      self.assertLess(trained_loss, orig_loss)

  threads = self.run_multiple_tasks_in_threads(
      _independent_worker_fn,
      cluster_spec,
      verification_callback=verification_callback)

  threads_to_join = []
  strategy = strategy_cls()
  if strategy.extended.experimental_between_graph:
    for ts in threads.values():
      threads_to_join.extend(ts)
  else:
    threads_to_join = [threads['worker'][0]]
  self.join_independent_workers(threads_to_join)
  verification_callback.verify(self)
def testSimpleModelIndependentWorkerSync(self, strategy_cls):
  num_workers = 2
  num_epoch = 2

  cluster_spec = test_base.create_cluster_spec(num_workers=num_workers)
  self._barrier = dc._Barrier(2)

  # The verification callback will be shared by multiple threads.
  verification_callback = MultiWorkerVerificationCallback(
      num_epoch=num_epoch, num_worker=num_workers)

  def _independent_worker_fn(*args, **kwargs):  # pylint: disable=unused-argument
    """Simulates an Independent Worker inside of a thread."""
    with test.mock.patch.object(dc, '_run_std_server',
                                self._make_mock_run_std_server()):
      strategy = strategy_cls()
      verification_callback.is_between_graph = \
          strategy.extended.experimental_between_graph
      batch_size = 64
      steps = 2
      train_ds, _ = _mnist_synthetic_dataset(batch_size, steps)
      with strategy.scope():
        model = _get_model((28, 28, 1))
      orig_loss, _ = model.evaluate(train_ds, steps=steps)
      callbacks_for_fit = nest.flatten(
          kwargs.get('verification_callback', []))
      history = model.fit(
          x=train_ds,
          epochs=num_epoch,
          steps_per_epoch=steps,
          callbacks=callbacks_for_fit)
      self.assertIsInstance(history, keras.callbacks.History)
      trained_loss, _ = model.evaluate(train_ds, steps=steps)
      self.assertLess(trained_loss, orig_loss)

  threads = self.run_multiple_tasks_in_threads(
      _independent_worker_fn,
      cluster_spec,
      verification_callback=verification_callback)

  threads_to_join = []
  strategy = strategy_cls()
  if strategy.extended.experimental_between_graph:
    for ts in threads.values():
      threads_to_join.extend(ts)
  else:
    threads_to_join = [threads['worker'][0]]
  self.join_independent_workers(threads_to_join)
  verification_callback.verify(self)
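# `MultiWorkerVerificationCallback` is shared across the worker threads above.
# From its usage we only know the constructor arguments, the settable
# `is_between_graph` attribute, and the final `verify()` call. A hedged sketch
# of a callback with that shape — the counting logic is an illustrative
# assumption, not the real class:
class MultiWorkerVerificationCallback(keras.callbacks.Callback):

  def __init__(self, num_epoch, num_worker):
    super(MultiWorkerVerificationCallback, self).__init__()
    self._num_epoch = num_epoch
    self._num_worker = num_worker
    self._lock = threading.Lock()
    self._epoch_begin_count = 0
    self.is_between_graph = False

  def on_epoch_begin(self, epoch, logs=None):
    # The callback instance is shared by all worker threads, so guard the
    # counter with a lock.
    with self._lock:
      self._epoch_begin_count += 1

  def verify(self, test_obj):
    # Between-graph replication runs `fit()` once per worker, so each epoch
    # begin is observed `num_worker` times; in-graph replication runs once.
    expected = self._num_epoch * (
        self._num_worker if self.is_between_graph else 1)
    test_obj.assertEqual(self._epoch_begin_count, expected)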
def test_template(self, strategy_cls):
  num_workers = 2
  num_epoch = 2

  cluster_spec = test_base.create_cluster_spec(num_workers=num_workers)
  self._barrier = dc._Barrier(2)

  def _independent_worker_fn(*args, **kwargs):  # pylint: disable=unused-argument
    """Simulates an Independent Worker inside of a thread."""
    with test.mock.patch.object(dc, '_run_std_server',
                                self._make_mock_run_std_server()):
      strategy = get_strategy_object(strategy_cls)
      batch_size = 64
      steps = 2
      train_ds, _ = _mnist_synthetic_dataset(batch_size, steps)
      with strategy.scope():
        model = _get_model((28, 28, 1))
      custom_callable(
          model,
          self,
          train_ds,
          num_epoch,
          steps,
          strategy,
          saving_filepath=kwargs['saving_filepath'])

  # Pass saving_filepath from the parent thread to ensure every worker has
  # the same filepath to save.
  saving_filepath = os.path.join(self.get_temp_dir(), 'checkpoint.h5')
  threads = self.run_multiple_tasks_in_threads(
      _independent_worker_fn, cluster_spec, saving_filepath=saving_filepath)
  if os.path.exists(saving_filepath):
    os.remove(saving_filepath)

  threads_to_join = []
  strategy = get_strategy_object(strategy_cls)
  if strategy.extended.experimental_between_graph:
    for ts in threads.values():
      threads_to_join.extend(ts)
  else:
    threads_to_join = [threads['worker'][0]]
  self.join_independent_workers(threads_to_join)
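# `custom_callable` is the per-test body that both `test_template` variants
# parameterize over. A hedged sketch of one such callable — the name, the
# checkpoint assertion, and the `**kwargs` handling are illustrative
# assumptions (the model helpers above are assumed to return an
# already-compiled model):
def callableForTestCheckpointExists(model, test_obj, train_ds, num_epoch,
                                    steps, strategy, saving_filepath,
                                    **kwargs):
  # Train with checkpointing enabled and verify a checkpoint was written.
  model.fit(
      x=train_ds,
      epochs=num_epoch,
      steps_per_epoch=steps,
      callbacks=[callbacks.ModelCheckpoint(filepath=saving_filepath)])
  test_obj.assertTrue(os.path.exists(saving_filepath))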
def test_complete_flow_independent_worker_in_graph(self, train_distribute_cls,
                                                   eval_distribute_cls):
  train_distribute = self._get_strategy_object(train_distribute_cls)

  if eval_distribute_cls:
    eval_distribute = self._get_strategy_object(eval_distribute_cls)
  else:
    eval_distribute = None

  cluster_spec = multi_worker_test_base.create_cluster_spec(
      num_workers=3, num_ps=0, has_eval=True)
  # 3 workers and 1 evaluator.
  self._barrier = dc._Barrier(4)
  threads = self.run_multiple_tasks_in_threads(self._independent_worker_fn,
                                               cluster_spec, train_distribute,
                                               eval_distribute)
  self.join_independent_workers([threads[WORKER][0], threads[EVALUATOR][0]])

  estimator = self._get_estimator(train_distribute, eval_distribute)
  self._inspect_train_and_eval_events(estimator)
def test_complete_flow_independent_worker_in_graph(self, train_distribute_cls,
                                                   eval_distribute_cls):
  train_distribute = train_distribute_cls(
      num_gpus_per_worker=context.num_gpus())

  if eval_distribute_cls:
    eval_distribute = eval_distribute_cls()
  else:
    eval_distribute = None

  cluster_spec = multi_worker_test_base.create_cluster_spec(
      num_workers=3, num_ps=0, has_eval=True)
  # 3 workers and 1 evaluator.
  self._barrier = dc._Barrier(4)
  threads = self._run_multiple_tasks_in_threads(cluster_spec,
                                                train_distribute,
                                                eval_distribute)
  threads[WORKER][0].join()
  threads[EVALUATOR][0].join()

  estimator = self._get_estimator(train_distribute, eval_distribute)
  self._inspect_train_and_eval_events(estimator)
def test_complete_flow_independent_worker_in_graph(self, train_distribute_cls,
                                                   eval_distribute_cls):
  train_distribute = train_distribute_cls(
      num_gpus_per_worker=context.num_gpus())

  if eval_distribute_cls:
    eval_distribute = eval_distribute_cls(
        num_gpus_per_worker=context.num_gpus())
  else:
    eval_distribute = None

  cluster_spec = multi_worker_test_base.create_cluster_spec(
      num_workers=3, num_ps=0, has_eval=True)
  # 3 workers and 1 evaluator.
  self._barrier = dc._Barrier(4)
  threads = self._run_multiple_tasks_in_threads(cluster_spec,
                                                train_distribute,
                                                eval_distribute)
  threads[WORKER][0].join()
  threads[EVALUATOR][0].join()

  estimator = self._get_estimator(train_distribute, eval_distribute)
  self._inspect_train_and_eval_events(estimator)
def testFaultToleranceInSyncStrategy(self, strategy_cls, file_format,
                                     preemption_callback):
  """Test fault-tolerance with multi-threading using sync dist-strat.

  This test simulates multi-worker training that is interrupted by a
  preemption, by having two threads, each of which represents a chief and a
  non-chief worker, where the non-chief raises an error in the middle of the
  training loop. Upon catching the error, a new thread with a new cluster
  spec is created to simulate the recovered non-chief worker. Meanwhile, the
  chief worker cannot proceed and hangs since the non-chief worker has
  crashed. To simulate a restart of the chief, a new thread has been prepared
  to run to take over chief with the help of a condition variable. It is
  expected that after the restart of both chief and non-chief workers, the
  training continues from the epoch they previously failed at. The test
  concludes by verifying the preemption-interrupted training can finish with
  the same loss and accuracy as if the preemption had not occurred.

  Arguments:
    strategy_cls: The strategy class to use.
    file_format: `h5` or `tf`.
    preemption_callback: The callback to simulate preemption.
  """

  def _independent_worker_fn(*args, **kwargs):  # pylint: disable=unused-argument
    with test.mock.patch.object(dc, '_run_std_server',
                                self._make_mock_run_std_server()):
      # Condition variable that blocks the thread that represents the
      # restarted chief.
      cv = kwargs.get('cv', None)
      # `before_restart` is True for the threads that represent the original
      # chief and non-chief worker, and False for threads that represent the
      # restarted chief and non-chief workers.
      before_restart = kwargs['before_restart']
      if kwargs['new_chief']:
        # `new_chief` is only True for the restarted chief thread. It waits
        # until non-chief is preempted and restarted to simulate the
        # causality where chief's restart results from non-chief's failure.
        cv.acquire()
        while not hasattr(cv, 'preempted'):
          cv.wait()
        cv.release()

      # Model building under strategy scope. Following is the code we expect
      # the user runs on every worker.
      strategy = get_strategy_object(strategy_cls)
      batch_size = 64
      steps = 3
      train_ds, _ = _mnist_synthetic_dataset(batch_size, steps)
      with strategy.scope():
        model = _get_model((28, 28, 1))

      # Function to start a new thread. This will be called twice in the
      # following code: one represents the restart of the non-chief, and one
      # represents the restart of the chief as a result of the restart of the
      # non-chief (so the training can continue in sync).
      def start_new_thread(new_chief=False):
        new_thread_tf_config = json.loads(os.environ['TF_CONFIG'])
        new_thread_tf_config['cluster']['worker'] = kwargs['reserved_ports']
        return self._run_task_in_thread(
            task_fn=_independent_worker_fn,
            cluster_spec=None,
            task_type=None,
            task_id=None,
            tf_config=new_thread_tf_config,
            before_restart=False,
            cv=cv,
            new_chief=new_chief)

      if test_base.is_chief() and before_restart:
        # Chief to start a new thread (that will be blocked by a condition
        # variable until the non-chief's new thread is started). The thread
        # for (recovered) chief is started before entering `fit()` because
        # the original chief thread will eventually hang and be ignored.
        start_new_thread(new_chief=True)

      try:

        class CkptSavedEpochAssertingCallback(callbacks.Callback):

          def __init__(self, test_obj):
            super(CkptSavedEpochAssertingCallback, self).__init__()
            self.test_obj = test_obj

          def on_epoch_begin(self, epoch, logs=None):
            # The `_ckpt_saved_epoch` attribute is set at the end of every
            # epoch, so it is only None before the first epoch completes.
            self.test_obj.assertEqual(self.model._ckpt_saved_epoch is None,
                                      epoch == 0)

        callbacks_list = [
            callbacks.ModelCheckpoint(
                filepath=saving_filepath,
                save_weights_only=True,
                load_weights_on_restart=True),
            CkptSavedEpochAssertingCallback(self)
        ]
        if before_restart:
          callbacks_list.append(preemption_callback())

        self.assertIsNone(model._ckpt_saved_epoch)
        history = model.fit(
            x=train_ds,
            epochs=num_epoch,
            steps_per_epoch=steps,
            callbacks=callbacks_list)
        self.assertIsNone(model._ckpt_saved_epoch)

        # `history` of the training result is collected to be compared
        # against each other. It is expected that the training results (loss
        # and accuracy) are the same with or without preemption.
        self._histories.append(history.history)

      except RuntimeError:  # pylint: disable=g-assert-in-except
        self.assertTrue(before_restart)
        # Reset the barrier so the new threads simulating recovery can
        # continue.
        self._barrier._counter = 0
        self._barrier._flag = False
        # Now that the non-chief has been preempted, it notifies the thread
        # that simulates the restarted chief to start so they can be back in
        # sync.
        cv.acquire()
        cv.preempted = True
        cv.notify()
        cv.release()
        # At this point we should discard the original non-chief thread, and
        # start the new thread that simulates the restarted non-chief, hence
        # we join that thread and return.
        self.join_independent_workers([start_new_thread()])
        return

      # Successful end of a `fit()` call.
      self._successful_thread_ends += 1
      self.assertFalse(before_restart)

  # Common parameters.
  num_workers = 2
  num_epoch = 3
  # History list storing the results for preemption and no preemption cases.
  self._histories = []
  # Pass `saving_filepath` from the parent thread to ensure every worker has
  # the same filepath to save.
  saving_filepath = os.path.join(self.get_temp_dir(),
                                 'checkpoint.' + file_format)
  strategy = get_strategy_object(strategy_cls)

  # Case 1: Training for `num_epoch` epochs without preemptions.
  cluster_spec = test_base.create_cluster_spec(num_workers=num_workers)
  self._barrier = dc._Barrier(2)
  self._successful_thread_ends = 0
  threads = self.run_multiple_tasks_in_threads(
      _independent_worker_fn,
      cluster_spec,
      saving_filepath=saving_filepath,
      before_restart=False,
      new_chief=False)
  if os.path.exists(saving_filepath):
    os.remove(saving_filepath)
  threads_to_join = []
  if strategy.extended.experimental_between_graph:
    for ts in threads.values():
      threads_to_join.extend(ts)
  else:
    threads_to_join = [threads['worker'][0]]
  self.join_independent_workers(threads_to_join)
  self.assertEqual(self._successful_thread_ends, 2)

  # Case 2: Training for `num_epoch` epochs with preemptions.
  # The preemption is simulated at both epoch boundary and batch boundary.
  cluster_spec = test_base.create_cluster_spec(num_workers=num_workers)
  cv = threading.Condition()
  self._barrier = dc._Barrier(2)
  # Ports reserved for new threads simulating recovery.
  reserved_ports = [
      'localhost:%s' % test_base.pick_unused_port()
      for _ in range(num_workers)
  ]
  self._successful_thread_ends = 0
  threads = self.run_multiple_tasks_in_threads(
      _independent_worker_fn,
      cluster_spec,
      saving_filepath=saving_filepath,
      reserved_ports=reserved_ports,
      before_restart=True,
      cv=cv,
      new_chief=False)
  if os.path.exists(saving_filepath):
    os.remove(saving_filepath)
  threads_to_join = []
  if strategy.extended.experimental_between_graph:
    # Only join the non-chief thread since the first thread for chief will
    # eventually hang and be ignored.
    threads_to_join = [threads['worker'][1]]
  else:
    threads_to_join = [threads['worker'][0]]
  self.join_independent_workers(threads_to_join)
  self.assertEqual(self._successful_thread_ends, 2)

  def assert_all_elements_are_identical(list_to_check):
    first_item = list_to_check[0]
    for item in list_to_check[1:]:
      self.assertAllClose(first_item, item, rtol=1e-5, atol=1e-5)

  # Important: the results from preemption interrupted and non-interrupted
  # cases should give the same final results.
  assert_all_elements_are_identical(
      [history['acc'][-1] for history in self._histories])
  assert_all_elements_are_identical(
      [history['loss'][-1] for history in self._histories])
  # `self._histories` has num_workers * num_runs = 2 * 2 = 4 entries.
  self.assertLen(self._histories, 4)
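# `preemption_callback` above is a factory for a callback that simulates a
# preemption by raising `RuntimeError` mid-training (the worker fn explicitly
# catches `RuntimeError`). Hedged sketches of the two variants the comments
# describe — epoch-boundary and batch-boundary preemption; the class names
# and exact interruption points are illustrative assumptions:
class PreemptionAtEpochBoundary(callbacks.Callback):

  def on_epoch_begin(self, epoch, logs=None):
    # Fail once training has progressed past the first epoch, so there is a
    # checkpoint to restore from.
    if epoch == 1:
      raise RuntimeError('Preemption simulated at epoch boundary.')


class PreemptionAtBatchBoundary(callbacks.Callback):

  def on_epoch_begin(self, epoch, logs=None):
    self._current_epoch = epoch

  def on_batch_begin(self, batch, logs=None):
    # Fail at the start of the second batch of the second epoch.
    if self._current_epoch == 1 and batch == 1:
      raise RuntimeError('Preemption simulated at batch boundary.')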
def testFaultToleranceInSyncStrategy(self, strategy_cls, file_format,
                                     preemption_callback, save_weights_only,
                                     load_weights_on_restart):
  """Test fault-tolerance with multi-threading using sync dist-strat.

  This test simulates multi-worker training that is interrupted by a
  preemption, by having two threads, each of which represents a chief and a
  non-chief worker, where the non-chief raises an error in the middle of the
  training loop. Upon catching the error, a new thread with a new cluster
  spec is created to simulate the recovered non-chief worker. Meanwhile, the
  chief worker cannot proceed and hangs since the non-chief worker has
  crashed. To simulate a restart of the chief, a new thread has been prepared
  to run to take over chief with the help of a condition variable. It is
  expected that after the restart of both chief and non-chief workers, the
  training continues from the epoch they previously failed at. The test
  concludes by verifying the preemption-interrupted training can finish with
  the same loss and accuracy as if the preemption had not occurred.

  TODO(rchao): Add test to check preemption on chief (possibly using multi
  processes).

  TODO(rchao): Add test to check fault-tolerance with multiple `model.fit()`.

  Arguments:
    strategy_cls: The strategy class to use.
    file_format: `h5` or `tf`.
    preemption_callback: The callback to simulate preemption.
    save_weights_only: The `save_weights_only` argument for
      `callbacks.ModelCheckpoint`.
    load_weights_on_restart: The `load_weights_on_restart` argument for
      `callbacks.ModelCheckpoint`.
  """

  def _independent_worker_fn(*args, **kwargs):  # pylint: disable=unused-argument
    with test.mock.patch.object(dc, '_run_std_server',
                                self._make_mock_run_std_server()):
      # `before_restart` is True for the threads that represent the original
      # chief and non-chief worker, and False for threads that represent the
      # restarted chief and non-chief workers.
      before_restart = kwargs['before_restart']

      # Model building under strategy scope. Following is the code we expect
      # the user runs on every worker.
      strategy = get_strategy_object(strategy_cls)
      batch_size = 64
      steps = 3
      train_ds, _ = multi_worker_testing_utils.mnist_synthetic_dataset(
          batch_size, steps)
      with strategy.scope():
        model = multi_worker_testing_utils.get_mnist_model((28, 28, 1))

      # Function to start a new thread. This will be called twice in the
      # following code: one represents the restart of the non-chief, and one
      # represents the restart of the chief as a result of the restart of the
      # non-chief (so the training can continue in sync).
      def start_new_thread(new_chief):
        new_thread_tf_config = json.loads(os.environ['TF_CONFIG'])
        # Update the ports in new chief and new worker threads.
        new_thread_tf_config['cluster']['worker'] = kwargs['reserved_ports']
        # Since both new chief and new worker threads are started from the
        # worker thread, we need to overwrite the tf config task index.
        new_thread_tf_config['task']['index'] = 0 if new_chief else 1
        return self._run_task_in_thread(
            task_fn=_independent_worker_fn,
            cluster_spec=None,
            task_type=None,
            task_id=None,
            tf_config=new_thread_tf_config,
            before_restart=False,
            new_chief=new_chief)

      try:

        class CkptSavedEpochAssertingCallback(callbacks.Callback):

          def __init__(self, test_obj):
            super(CkptSavedEpochAssertingCallback, self).__init__()
            self.test_obj = test_obj

          def on_epoch_begin(self, epoch, logs=None):
            # The `_ckpt_saved_epoch` attribute is set at the end of every
            # epoch, so it only holds the unused sentinel value before the
            # first epoch completes.
            self.test_obj.assertEqual(
                K.eval(self.model._ckpt_saved_epoch) ==
                training_state.CKPT_SAVED_EPOCH_UNUSED_VALUE, epoch == 0)

        callbacks_list = [
            callbacks.ModelCheckpoint(
                filepath=saving_filepath,
                save_weights_only=save_weights_only,
                load_weights_on_restart=load_weights_on_restart),
            CkptSavedEpochAssertingCallback(self)
        ]
        if before_restart:
          callbacks_list.append(preemption_callback())

        self.assertFalse(hasattr(model, training_state.CKPT_SAVED_EPOCH))
        history = model.fit(
            x=train_ds,
            epochs=num_epoch,
            steps_per_epoch=steps,
            callbacks=callbacks_list)
        self.assertFalse(hasattr(model, training_state.CKPT_SAVED_EPOCH))

        # `history` of the training result is collected to be compared
        # against each other. It is expected that the training results (loss
        # and accuracy) are the same with or without preemption.
        self._histories.append(history.history)

      except RuntimeError:  # pylint: disable=g-assert-in-except
        self.assertTrue(before_restart)
        # Reset the barrier so the new threads simulating recovery can
        # continue.
        self._barrier._counter = 0
        self._barrier._flag = False

        # At this point we block the original non-chief thread, start the new
        # threads that simulate the restarted chief and non-chief, join them,
        # and return.
        new_chief_thread = start_new_thread(new_chief=True)
        new_worker_thread = start_new_thread(new_chief=False)
        self.join_independent_workers([new_chief_thread, new_worker_thread])
        return

      # Successful end of a `fit()` call.
      with self._lock:
        self._successful_thread_ends += 1
      self.assertFalse(before_restart)

  # Common parameters.
  num_workers = 2
  num_epoch = 3
  # History list storing the results for preemption and no preemption cases.
  self._histories = []
  # Lock required to prevent race condition between two threads.
  self._lock = threading.Lock()
  strategy = get_strategy_object(strategy_cls)

  def handler(signum, frame):
    del signum, frame
    # `session.run()` within `model.fit()` can time out. Skipping it as it
    # doesn't represent the failure of this test.
    self.skipTest('Skipping test due to `session.run()` timeout.')

  signal.signal(signal.SIGALRM, handler)
  # Alarm at 4 minutes, within the 5-minute window before the test times out
  # and fails.
  signal.alarm(240)

  def get_saving_dir_and_filepath():
    saving_dir = tempfile.mkdtemp(prefix=self.get_temp_dir())
    saving_filepath = os.path.join(saving_dir, 'checkpoint.' + file_format)
    return saving_dir, saving_filepath

  # Case 1: Training for `num_epoch` epochs without preemptions.
  cluster_spec = test_base.create_cluster_spec(num_workers=num_workers)
  self._barrier = dc._Barrier(2)
  self._successful_thread_ends = 0
  # Get a new temporary filepath to save the checkpoint to.
  saving_dir, saving_filepath = get_saving_dir_and_filepath()
  threads = self.run_multiple_tasks_in_threads(
      _independent_worker_fn,
      cluster_spec,
      # Pass `saving_filepath` from the parent thread to ensure every worker
      # has the same filepath to save.
      saving_filepath=saving_filepath,
      before_restart=False,
      new_chief=False)
  threads_to_join = []
  if strategy.extended.experimental_between_graph:
    for ts in threads.values():
      threads_to_join.extend(ts)
  else:
    threads_to_join = [threads['worker'][0]]
  self.join_independent_workers(threads_to_join)

  # `self.test_skipped_reason` could be set when a non-main thread attempts
  # to skip the test.
  # `multi_worker_test_base.skip_if_grpc_server_cant_be_started()` is an
  # example of where this can be set. Since raising `SkipTest` in a non-main
  # thread doesn't actually skip the test, we check if the test should be
  # skipped here once we have joined the threads.
  if getattr(self, 'test_skipped_reason', None) is not None:
    self.skipTest(self.test_skipped_reason)

  self.assertTrue(
      training_state.remove_checkpoint_if_exists(saving_dir, saving_filepath))
  self.assertEqual(self._successful_thread_ends, 2)

  # Case 2: Training for `num_epoch` epochs with preemptions.
  # The preemption is simulated at both epoch boundary and batch boundary.
  cluster_spec = test_base.create_cluster_spec(num_workers=num_workers)
  self._barrier = dc._Barrier(2)
  # Ports reserved for new threads simulating recovery.
  reserved_ports = [
      'localhost:%s' % test_base.pick_unused_port()
      for _ in range(num_workers)
  ]
  self._successful_thread_ends = 0
  # Get a new temporary filepath to save the checkpoint to.
  saving_dir, saving_filepath = get_saving_dir_and_filepath()
  threads = self.run_multiple_tasks_in_threads(
      _independent_worker_fn,
      cluster_spec,
      # Pass `saving_filepath` from the parent thread to ensure every worker
      # has the same filepath to save.
      saving_filepath=saving_filepath,
      reserved_ports=reserved_ports,
      before_restart=True,
      new_chief=False)
  threads_to_join = []
  if strategy.extended.experimental_between_graph:
    # Only join the non-chief thread since the first thread for chief will
    # eventually hang and be ignored.
    threads_to_join = [threads['worker'][1]]
  else:
    threads_to_join = [threads['worker'][0]]
  self.join_independent_workers(threads_to_join)
  if getattr(self, 'test_skipped_reason', None) is not None:
    self.skipTest(self.test_skipped_reason)

  self.assertTrue(
      training_state.remove_checkpoint_if_exists(saving_dir, saving_filepath))
  self.assertEqual(self._successful_thread_ends, 2)

  def assert_all_elements_are_identical(list_to_check):
    first_item = list_to_check[0]
    for item in list_to_check[1:]:
      self.assertAllClose(first_item, item, rtol=2e-5, atol=1e-5)

  # Important: the results from preemption interrupted and non-interrupted
  # cases should give the same final results.
  assert_all_elements_are_identical(
      [history['acc'][-1] for history in self._histories])
  assert_all_elements_are_identical(
      [history['loss'][-1] for history in self._histories])

  # `self._histories` has num_workers * num_runs = 2 * 2 = 4 entries.
  self.assertLen(self._histories, 4)
  # Results from case 1 should have 3 full epochs.
  self.assertLen(self._histories[0]['acc'], 3)
  # Results from case 2 should only have 2 full epochs because it restarted
  # at epoch 1.
  self.assertLen(self._histories[-1]['acc'], 2)
def run_optimizer_comparison_with_simple_bias_model(self, strategy_cls,
                                                    optimizer_class_1,
                                                    optimizer_class_2):

  def get_input_datasets():
    # Simple training input.
    train_input = [[1]] * 16
    train_label = [[0]] * 16
    ds = dataset_ops.Dataset.from_tensor_slices((train_input, train_label))
    ds = maybe_shard_dataset(ds)
    # TODO(rchao): Investigate to figure out the reason for having 8 workers
    # instead of 2 as expected.
    return ds.batch(8, drop_remainder=True)

  def get_simple_bias_model():

    class Bias(base_layer.Layer):

      def build(self, input_shape):
        self.bias = self.add_variable('bias', (1,), initializer='zeros')

      def call(self, inputs):
        return inputs + self.bias

    model = sequential.Sequential()
    model.add(Bias(input_shape=(1,)))
    return model

  self._lock = threading.Lock()
  cluster_spec = test_base.create_cluster_spec(num_workers=2)
  self._barrier = dc._Barrier(2)

  def _independent_worker_fn(*args, **kwargs):  # pylint: disable=unused-argument
    """Simulates an Independent Worker inside a thread."""
    # TODO(rchao): Refactor to abstract the common boilerplate out.
    with test.mock.patch.object(dc, '_run_std_server',
                                self._make_mock_run_std_server()):
      model = get_simple_bias_model()
      initial_weights = model.get_weights()

      def _get_model_results(optimizer, initial_weights):
        # Clear Keras session to reset device assignment.
        keras.backend._SESSION.session = None
        strategy = strategy_cls()
        with strategy.scope():
          train_ds = get_input_datasets()
          model = get_simple_bias_model()
          model.set_weights(initial_weights)
          model.compile(loss='mae', optimizer=optimizer, metrics=['mae'])
        return {
            'trained_loss_and_accuracy':
                model.fit(x=train_ds, epochs=20).history,
            'trained_weights':
                model.get_weights(),
        }

      results1 = _get_model_results(optimizer_class_1(0.01), initial_weights)
      results2 = _get_model_results(optimizer_class_2(0.01), initial_weights)

      for key in results1:
        self.assertAllClose(
            results1[key],
            results2[key],
            atol=1e-5,
            rtol=1e-5,
            msg='Fail to assert {}'.format(key))

  threads = self.run_multiple_tasks_in_threads(_independent_worker_fn,
                                               cluster_spec)

  threads_to_join = []
  strategy = strategy_cls()
  if strategy.extended.experimental_between_graph:
    for ts in threads.values():
      threads_to_join.extend(ts)
  else:
    threads_to_join = [threads['worker'][0]]
  self.join_independent_workers(threads_to_join)
def run_optimizer_comparison_with_simple_bias_model(self, strategy_cls,
                                                    optimizer_class_1,
                                                    optimizer_class_2):

  def get_input_datasets():
    # Simple training input.
    train_input = [[1]] * 16
    train_label = [[0]] * 16
    ds = dataset_ops.Dataset.from_tensor_slices((train_input, train_label))
    # TODO(rchao): Investigate to figure out the reason for having 8 workers
    # instead of 2 as expected.
    return ds.batch(8, drop_remainder=True)

  def get_simple_bias_model():

    class Bias(base_layer.Layer):

      def build(self, input_shape):
        self.bias = self.add_variable('bias', (1,), initializer='zeros')

      def call(self, inputs):
        return inputs + self.bias

    model = sequential.Sequential()
    model.add(Bias(input_shape=(1,)))
    return model

  self._lock = threading.Lock()
  cluster_spec = test_base.create_cluster_spec(num_workers=2)
  self._barrier = dc._Barrier(2)

  def _independent_worker_fn(*args, **kwargs):  # pylint: disable=unused-argument
    """Simulates an Independent Worker inside a thread."""
    # TODO(rchao): Refactor to abstract the common boilerplate out.
    with test.mock.patch.object(dc, '_run_std_server',
                                self._make_mock_run_std_server()):
      model = get_simple_bias_model()
      initial_weights = model.get_weights()

      def _get_model_results(optimizer, initial_weights):
        # Clear Keras session to reset device assignment.
        keras.backend._SESSION.session = None
        strategy = strategy_cls()
        with strategy.scope():
          train_ds = get_input_datasets()
          model = get_simple_bias_model()
          model.set_weights(initial_weights)
          model.compile(loss='mae', optimizer=optimizer, metrics=['mae'])
        return {
            'trained_loss_and_accuracy':
                model.fit(x=train_ds, epochs=20).history,
            'trained_weights':
                model.get_weights(),
        }

      results1 = _get_model_results(optimizer_class_1(0.01), initial_weights)
      results2 = _get_model_results(optimizer_class_2(0.01), initial_weights)

      for key in results1:
        self.assertAllClose(
            results1[key],
            results2[key],
            atol=1e-5,
            rtol=1e-5,
            msg='Fail to assert {}'.format(key))

  threads = self.run_multiple_tasks_in_threads(_independent_worker_fn,
                                               cluster_spec)

  threads_to_join = []
  strategy = strategy_cls()
  if strategy.extended.experimental_between_graph:
    for ts in threads.values():
      threads_to_join.extend(ts)
  else:
    threads_to_join = [threads['worker'][0]]
  self.join_independent_workers(threads_to_join)
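# The first variant above pipes its dataset through `maybe_shard_dataset`,
# which is not defined in this excerpt. A hedged sketch of one plausible
# implementation, sharding by the task index parsed from TF_CONFIG; the real
# helper may well do something different:
def maybe_shard_dataset(dataset):
  """Gives each worker a disjoint shard if running under a cluster spec."""
  tf_config = json.loads(os.environ.get('TF_CONFIG', '{}'))
  workers = tf_config.get('cluster', {}).get('worker', [])
  if len(workers) > 1:
    dataset = dataset.shard(len(workers), tf_config['task']['index'])
  return dataset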