def test_retry_run_on_yarn(nb_retries, nb_failures):
    """Check that run_on_yarn retries _run_on_cluster up to nb_retries times.

    The mocked cluster run fails nb_failures times before succeeding; the
    expected total number of attempts is therefore
    min(nb_retries, nb_failures) + 1 (retries stop early on success, or when
    the retry budget is exhausted).

    NOTE(review): this test takes parameters, so it presumably carries a
    @pytest.mark.parametrize decorator outside this view — confirm.
    """
    cpt = 0

    def fail(*args, **kwargs):
        # BUGFIX: the counter must advance on each failing attempt; without
        # nonlocal/increment, fail raised forever and the asserted call
        # counts of min(nb_retries, nb_failures) + 1 could never hold.
        nonlocal cpt
        if cpt < nb_failures:
            cpt += 1
            raise Exception("")
        else:
            pass

    with mock.patch('tf_yarn.client._setup_pyenvs'), \
            mock.patch('tf_yarn.client._setup_skein_cluster') as mock_setup_skein_cluster, \
            mock.patch('tf_yarn.client._run_on_cluster') as mock_run_on_cluster:
        mock_run_on_cluster.side_effect = fail
        gb = 2**10
        try:
            run_on_yarn(
                "path/to/env",
                lambda: Experiment(None, None, None),
                task_specs={
                    "chief": TaskSpec(memory=16 * gb, vcores=16),
                    "worker": TaskSpec(memory=16 * gb, vcores=16, instances=1),
                    "ps": TaskSpec(memory=16 * gb, vcores=16, instances=1)
                },
                nb_retries=nb_retries
            )
        except Exception:
            # When nb_failures > nb_retries, the final attempt still raises;
            # swallow it so the call-count assertions below run either way.
            pass

        nb_calls = min(nb_retries, nb_failures) + 1
        assert mock_run_on_cluster.call_count == nb_calls
        assert mock_setup_skein_cluster.call_count == nb_calls
def run_on_yarn(experiment_fn: Union[ExperimentFn, KerasExperimentFn],
                task_specs: Dict[str, topologies.TaskSpec] = DEFAULT_TASK_SPEC,
                *args, **kwargs) -> Optional[Metrics]:
    """Run an experiment on YARN with monitoring attached.

    Wraps the user-supplied experiment factory so that the experiment it
    produces passes through _add_monitor_to_experiment before being handed
    to the underlying client, then delegates to client.run_on_yarn with all
    remaining arguments untouched.
    """
    def _monitored_experiment_fn():
        # Build the experiment lazily and attach monitoring to it.
        return _add_monitor_to_experiment(experiment_fn())

    return client.run_on_yarn(_monitored_experiment_fn, task_specs, *args, **kwargs)
def test_kill_skein_on_exception():
    """Verify the skein application is shut down as FAILED when the
    experiment function cannot be serialized.

    cloudpickle.dumps is patched to raise, and the test asserts that the
    application client's shutdown was invoked exactly once with
    FinalStatus.FAILED.
    """
    def _raise_pickle_error(*args, **kwargs):
        raise Exception("Cannot serialize your method!")

    # Patch order matters only for readability here: cluster setup returns a
    # SkeinCluster holding our mock app, and serialization always fails.
    with mock.patch('tf_yarn.client._setup_skein_cluster') as mock_setup_skein_cluster, \
            mock.patch('tf_yarn.client._setup_pyenvs'), \
            mock.patch('tf_yarn.client.cloudpickle.dumps') as mock_cloudpickle:
        mock_cloudpickle.side_effect = _raise_pickle_error
        mock_app = mock.MagicMock(skein.ApplicationClient)
        mock_setup_skein_cluster.return_value = SkeinCluster(
            client=None, app=mock_app, event_listener=None, events=None,
            tasks=[])
        try:
            run_on_yarn(None, None, {})
        except Exception:
            # The raised serialization error is expected; log it for
            # debugging and keep going to the shutdown assertion.
            print(traceback.format_exc())

    mock_app.shutdown.assert_called_once_with(
        skein.model.FinalStatus.FAILED)
def run_on_yarn(
    experiment_fn: ExperimentFn,
    task_specs: Dict[str, topologies.TaskSpec],
    **kwargs
) -> Optional[Metrics]:
    """Run a pytorch experiment on YARN.

    Delegates to client.run_on_yarn, defaulting custom_task_module to the
    pytorch worker task module when the caller did not supply one.
    """
    # setdefault only writes when the key is absent — identical to the
    # explicit membership check, and the caller's value always wins.
    kwargs.setdefault("custom_task_module", "tf_yarn.pytorch.tasks.worker")
    return client.run_on_yarn(experiment_fn, task_specs, **kwargs)