Ejemplo n.º 1
0
def test_retry_run_on_yarn(nb_retries, nb_failures):
    """Check that ``run_on_yarn`` retries a failing cluster run.

    ``tf_yarn.client._run_on_cluster`` is mocked so that its first
    ``nb_failures`` calls raise and every later call succeeds. The test then
    expects both ``_run_on_cluster`` and ``_setup_skein_cluster`` to have been
    called ``min(nb_retries, nb_failures) + 1`` times.

    Args:
        nb_retries: value passed as ``nb_retries`` to ``run_on_yarn``.
        nb_failures: number of leading ``_run_on_cluster`` calls that fail.
    """
    cpt = 0

    def fail(*args, **kwargs):
        # Count every attempt so that only the first ``nb_failures`` ones
        # raise. Without updating the shared counter (hence ``nonlocal``),
        # every call would fail and the retry-then-succeed path would never
        # be exercised.
        nonlocal cpt
        cpt += 1
        if cpt <= nb_failures:
            raise Exception("")

    with mock.patch('tf_yarn.client._setup_pyenvs'), \
            mock.patch('tf_yarn.client._setup_skein_cluster') as mock_setup_skein_cluster, \
            mock.patch('tf_yarn.client._run_on_cluster') as mock_run_on_cluster:
        mock_run_on_cluster.side_effect = fail

        gb = 2**10

        try:
            run_on_yarn(
                "path/to/env", lambda: Experiment(None, None, None),
                task_specs={
                    "chief": TaskSpec(memory=16 * gb, vcores=16),
                    "worker": TaskSpec(memory=16 * gb, vcores=16, instances=1),
                    "ps": TaskSpec(memory=16 * gb, vcores=16, instances=1)
                },
                nb_retries=nb_retries
            )
        except Exception:
            # When failures outnumber retries the last failure is presumably
            # re-raised by run_on_yarn; only the call counts matter here.
            pass

        nb_calls = min(nb_retries, nb_failures) + 1
        assert mock_run_on_cluster.call_count == nb_calls
        assert mock_setup_skein_cluster.call_count == nb_calls
Ejemplo n.º 2
0
def run_on_yarn(experiment_fn: Union[ExperimentFn, KerasExperimentFn],
                task_specs: Dict[str, topologies.TaskSpec] = DEFAULT_TASK_SPEC,
                *args,
                **kwargs) -> Optional[Metrics]:
    """Run an experiment on YARN with a monitor attached to it.

    ``experiment_fn`` is wrapped so that the experiment it builds is passed
    through ``_add_monitor_to_experiment``. The wrapper, ``task_specs`` and
    all remaining arguments are then forwarded unchanged to
    ``client.run_on_yarn``.

    Args:
        experiment_fn: zero-argument callable that builds the experiment.
        task_specs: mapping of task name to ``TaskSpec``; defaults to
            ``DEFAULT_TASK_SPEC``.
        *args, **kwargs: forwarded as-is to ``client.run_on_yarn``.

    Returns:
        Whatever ``client.run_on_yarn`` returns.
    """
    def _monitored_experiment_fn():
        # experiment_fn is not called here; it runs only when the client
        # invokes this wrapper.
        experiment = experiment_fn()
        return _add_monitor_to_experiment(experiment)

    return client.run_on_yarn(
        _monitored_experiment_fn, task_specs, *args, **kwargs
    )
Ejemplo n.º 3
0
def test_kill_skein_on_exception():
    """Check that the YARN app is shut down as FAILED when serialization fails.

    ``tf_yarn.client.cloudpickle.dumps`` is patched to always raise, and
    ``_setup_skein_cluster`` returns a ``SkeinCluster`` wrapping a mocked
    ``skein.ApplicationClient``. After ``run_on_yarn`` fails, that
    application's ``shutdown`` must have been called exactly once with
    ``skein.model.FinalStatus.FAILED``.
    """
    def raising_dumps(*args, **kwargs):
        raise Exception("Cannot serialize your method!")

    with mock.patch('tf_yarn.client._setup_skein_cluster') as setup_cluster, \
            mock.patch('tf_yarn.client._setup_pyenvs'), \
            mock.patch('tf_yarn.client.cloudpickle.dumps') as patched_dumps:
        patched_dumps.side_effect = raising_dumps

        fake_app = mock.MagicMock(skein.ApplicationClient)
        setup_cluster.return_value = SkeinCluster(
            client=None, app=fake_app,
            event_listener=None, events=None,
            tasks=[])

        try:
            run_on_yarn(None, None, {})
        except Exception:
            # The failure is expected; show it for debugging, then verify
            # the shutdown below.
            print(traceback.format_exc())

        fake_app.shutdown.assert_called_once_with(
            skein.model.FinalStatus.FAILED)
Ejemplo n.º 4
0
def run_on_yarn(
    experiment_fn: ExperimentFn,
    task_specs: Dict[str, topologies.TaskSpec],
    **kwargs
) -> Optional[Metrics]:
    """Forward to ``client.run_on_yarn`` with a default task module.

    ``custom_task_module`` defaults to ``"tf_yarn.pytorch.tasks.worker"``.
    A value supplied by the caller, including ``None``, is left untouched.

    Args:
        experiment_fn: the experiment function to run.
        task_specs: mapping of task name to ``TaskSpec``.
        **kwargs: forwarded to ``client.run_on_yarn``.

    Returns:
        Whatever ``client.run_on_yarn`` returns.
    """
    kwargs.setdefault("custom_task_module", "tf_yarn.pytorch.tasks.worker")
    return client.run_on_yarn(experiment_fn, task_specs, **kwargs)