Exemplo n.º 1
0
def test_default_resources():
    """Check that a task declared without resource kwargs picks up config values.

    ``good.config`` is active for the duration of the test; the requests and
    limits attached to the task's container must equal the values below.
    """
    config_file = _os.path.join(
        _os.path.dirname(_os.path.realpath(__file__)),
        "../../configuration/configs/good.config",
    )
    with _configuration.TemporaryConfiguration(config_file):

        @inputs(in1=Types.Integer)
        @outputs(out1=Types.String)
        @python_task()
        def default_task2(wf_params, in1, out1):
            pass

        requests = {
            item.name: item.value
            for item in default_task2.container.resources.requests
        }
        limits = {
            item.name: item.value
            for item in default_task2.container.resources.limits
        }

        resource = _task_models.Resources.ResourceName
        expected_requests = {
            resource.CPU: "500m",
            resource.MEMORY: "500Gi",
            resource.GPU: "1",
            resource.STORAGE: "500Gi",
        }
        expected_limits = {
            resource.CPU: "501m",
            resource.MEMORY: "501Gi",
            resource.GPU: "2",
            resource.STORAGE: "501Gi",
        }

        # Requests are checked first, then limits, each in the order listed.
        for name, value in expected_requests.items():
            assert requests[name] == value
        for name, value in expected_limits.items():
            assert limits[name] == value
Exemplo n.º 2
0
def mock_clirunner(monkeypatch):
    """Yield a callable that runs the ``pyflyte`` CLI with the standard test flags.

    While the fixture is active, ``common/configs/local.config`` (relative to
    the tests root three directories up) is the temporary configuration, the
    tests root is prepended to ``sys.path`` via ``monkeypatch``, and
    ``_module_loader.iterate_modules`` is replaced by ``_fake_module_load``.

    The yielded callable appends its positional arguments to the fixed
    project/domain/version/packages flags, forwards keyword arguments to
    ``CliRunner.invoke``, re-raises any exception the runner captured, and
    otherwise returns the invocation result.
    """

    def invoke(*args, **kwargs):
        common_flags = [
            "-p", "tests",
            "-d", "unit",
            "-v", "version",
            "--pkgs", "common.workflows",
        ]
        outcome = CliRunner().invoke(_pyflyte.main, [*common_flags, *args], **kwargs)
        if not outcome.exception:
            return outcome
        raise outcome.exception

    tests_dir_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../../..")
    with _config.TemporaryConfiguration(
        os.path.join(tests_dir_path, "common/configs/local.config")
    ):
        monkeypatch.syspath_prepend(tests_dir_path)
        monkeypatch.setattr(_module_loader, "iterate_modules", _fake_module_load)
        yield invoke
Exemplo n.º 3
0
def test_serialize():
    """Serializing a launch plan preserves workflow id, IAM role and input defaults.

    The launch plan fixes ``required_input`` and sets an IAM role; serialization
    runs under ``local.config`` with image/project/domain overridden.
    """

    def _workflow_id():
        # A fresh identifier each call, so the one attached to the workflow and
        # the one used for comparison are distinct objects.
        return _identifier.Identifier(
            _identifier.ResourceType.WORKFLOW, "p", "d", "n", "v")

    workflow_inputs = {
        "required_input": _workflow.Input(_types.Types.Integer),
        "default_input": _workflow.Input(_types.Types.Integer, default=5),
    }
    workflow_to_test = _workflow.workflow({}, inputs=workflow_inputs)
    workflow_to_test.id = _workflow_id()
    lp = workflow_to_test.create_launch_plan(
        fixed_inputs={"required_input": 5}, role="iam_role")

    config_path = _os.path.join(
        _os.path.dirname(_os.path.realpath(__file__)),
        "../../common/configs/local.config",
    )
    overrides = {
        "image": "myflyteimage:v123",
        "project": "myflyteproject",
        "domain": "development"
    }
    with _configuration.TemporaryConfiguration(config_path, internal_overrides=overrides):
        serialized = lp.serialize()

    assert serialized.workflow_id == _workflow_id().to_flyte_idl()
    assert serialized.auth_role.assumable_iam_role == "iam_role"
    default_param = serialized.default_inputs.parameters["default_input"]
    assert default_param.default.scalar.primitive.integer == 5
Exemplo n.º 4
0
def test_serialize():
    """Serialized launch plan keeps the workflow id, the IAM role and input defaults.

    This variant targets the API where the workflow id lives on ``_id`` and the
    serialized role is exposed as ``auth``.
    """
    make_id = lambda: _identifier.Identifier(  # noqa: E731 - tiny local factory
        _identifier.ResourceType.WORKFLOW, "p", "d", "n", "v")

    workflow_to_test = _workflow.workflow(
        {},
        inputs=dict(
            required_input=_workflow.Input(_types.Types.Integer),
            default_input=_workflow.Input(_types.Types.Integer, default=5),
        ),
    )
    workflow_to_test._id = make_id()
    lp = workflow_to_test.create_launch_plan(
        fixed_inputs=dict(required_input=5),
        role='iam_role',
    )

    here = _os.path.dirname(_os.path.realpath(__file__))
    local_config = _os.path.join(here, '../../common/configs/local.config')
    with _configuration.TemporaryConfiguration(
            local_config,
            internal_overrides=dict(
                image='myflyteimage:v123',
                project='myflyteproject',
                domain='development',
            )):
        serialized_lp = lp.serialize()

    assert serialized_lp.workflow_id == make_id().to_flyte_idl()
    assert serialized_lp.auth.assumable_iam_role == 'iam_role'
    default_literal = serialized_lp.default_inputs.parameters['default_input'].default
    assert default_literal.scalar.primitive.integer == 5
Exemplo n.º 5
0
def test_default_deprecated_role():
    """A launch plan created without a role takes it from the deprecated config.

    ``deprecated_local.config`` is active while the workflow and launch plan
    are built; the launch plan's ``auth`` must carry the configured IAM role.
    """
    deprecated_config = _os.path.join(
        _os.path.dirname(_os.path.realpath(__file__)),
        '../../common/configs/deprecated_local.config',
    )
    with _configuration.TemporaryConfiguration(deprecated_config):
        workflow_inputs = {
            'required_input': _workflow.Input(_types.Types.Integer),
            'default_input': _workflow.Input(_types.Types.Integer, default=5),
        }
        launch_plan = _workflow.workflow({}, inputs=workflow_inputs).create_launch_plan()
        expected_role = 'arn:aws:iam::ABC123:role/my-flyte-role'
        assert launch_plan.auth.assumable_iam_role == expected_role
Exemplo n.º 6
0
def mock_ctx(request):
    """Yield a mocked click context under the configuration in ``request.param``.

    ``request.param`` is passed straight to ``_config.TemporaryConfiguration``
    (presumably a config-file path supplied via fixture parametrization —
    confirm). While the context is in use:

    * the tests root (three directories above this file) is on ``sys.path``;
    * ``flytekit.tools.module_loader.iterate_modules`` is patched with
      ``_fake_module_load`` as its side effect.

    The yielded ``MagicMock`` has ``obj`` populated with the ``_constants.CTX_*``
    entries for packages, project, domain and version.
    """
    tests_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../../..")
    with _config.TemporaryConfiguration(request.param):
        sys.path.append(tests_dir)
        try:
            with _mock.patch("flytekit.tools.module_loader.iterate_modules") as mock_module_load:
                mock_module_load.side_effect = _fake_module_load
                ctx = _mock.MagicMock()
                ctx.obj = {
                    _constants.CTX_PACKAGES: ("common.workflows",),
                    _constants.CTX_PROJECT: "tests",
                    _constants.CTX_DOMAIN: "unit",
                    _constants.CTX_VERSION: "version",
                }
                yield ctx
        finally:
            # Remove exactly the entry appended above. A blind sys.path.pop()
            # would drop whatever is last, which is wrong if anything appended
            # to sys.path while the context was in use. Search from the end so
            # an identical pre-existing entry earlier in the list is untouched.
            for index in range(len(sys.path) - 1, -1, -1):
                if sys.path[index] == tests_dir:
                    del sys.path[index]
                    break
Exemplo n.º 7
0
def mock_ctx():
    """Yield a mocked click context with ``local.config`` as the active config.

    While the context is in use:

    * ``common/configs/local.config`` (under the tests root) is the temporary
      configuration;
    * the tests root (three directories above this file) is on ``sys.path``;
    * ``flytekit.tools.module_loader.iterate_modules`` is patched with
      ``_fake_module_load`` as its side effect.

    The yielded ``MagicMock`` has ``obj`` populated with the ``_constants.CTX_*``
    entries for packages, project, domain and version.
    """
    here = os.path.dirname(os.path.realpath(__file__))
    tests_dir = os.path.join(here, '../../..')
    with _config.TemporaryConfiguration(
            os.path.join(here, '../../../common/configs/local.config')):
        sys.path.append(tests_dir)
        try:
            with _mock.patch('flytekit.tools.module_loader.iterate_modules'
                             ) as mock_module_load:
                mock_module_load.side_effect = _fake_module_load
                ctx = _mock.MagicMock()
                ctx.obj = {
                    _constants.CTX_PACKAGES: 'common.workflows',
                    _constants.CTX_PROJECT: 'tests',
                    _constants.CTX_DOMAIN: 'unit',
                    _constants.CTX_VERSION: 'version'
                }
                yield ctx
        finally:
            # Remove exactly the entry appended above. A blind sys.path.pop()
            # would drop whatever is last, which is wrong if anything appended
            # to sys.path while the context was in use. Search from the end so
            # an identical pre-existing entry earlier in the list is untouched.
            for index in range(len(sys.path) - 1, -1, -1):
                if sys.path[index] == tests_dir:
                    del sys.path[index]
                    break
Exemplo n.º 8
0
def test_overriden_resources():
    """Explicit resource kwargs on ``@python_task`` take precedence over config.

    ``good.config`` is active, but the decorator supplies its own CPU, memory,
    GPU and storage requests/limits; those explicit values must be the ones
    attached to the task's container.
    """
    config_file = _os.path.join(
        _os.path.dirname(_os.path.realpath(__file__)),
        '../../configuration/configs/good.config',
    )
    with _configuration.TemporaryConfiguration(config_file):

        @inputs(in1=Types.Integer)
        @outputs(out1=Types.String)
        @python_task(memory_limit="100Gi",
                     memory_request="50Gi",
                     cpu_limit="1000m",
                     cpu_request="500m",
                     gpu_limit="1",
                     gpu_request="0",
                     storage_request="100Gi",
                     storage_limit="200Gi")
        def default_task2(wf_params, in1, out1):
            pass

        requests = {
            item.name: item.value
            for item in default_task2.container.resources.requests
        }
        limits = {
            item.name: item.value
            for item in default_task2.container.resources.limits
        }

        resource = _task_models.Resources.ResourceName
        expected_requests = {
            resource.CPU: "500m",
            resource.MEMORY: "50Gi",
            resource.GPU: "0",
            resource.STORAGE: "100Gi",
        }
        expected_limits = {
            resource.CPU: "1000m",
            resource.MEMORY: "100Gi",
            resource.GPU: "1",
            resource.STORAGE: "200Gi",
        }

        # Requests are checked first, then limits, each in the order listed.
        for name, value in expected_requests.items():
            assert requests[name] == value
        for name, value in expected_limits.items():
            assert limits[name] == value
Exemplo n.º 9
0
def mock_clirunner(monkeypatch):
    """Yield a callable that invokes ``pyflyte`` with fixed test flags prepended.

    While active: ``common/configs/local.config`` under the tests root is the
    temporary configuration, the tests root is prepended to ``sys.path`` via
    ``monkeypatch``, and ``_module_loader.iterate_modules`` is replaced by
    ``_fake_module_load``.

    The yielded callable takes extra CLI arguments positionally and passes
    keyword arguments through to ``CliRunner.invoke``. Any exception recorded
    on the result is re-raised; otherwise the result is returned.
    """
    tests_dir_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../..')
    config_path = os.path.join(tests_dir_path, 'common/configs/local.config')

    def run_cli(*args, **kwargs):
        cli_args = ['-p', 'tests', '-d', 'unit', '-v', 'version', '--pkgs', 'common.workflows']
        cli_args.extend(args)
        result = CliRunner().invoke(_pyflyte.main, cli_args, **kwargs)
        if result.exception:
            raise result.exception
        return result

    with _config.TemporaryConfiguration(config_path):
        monkeypatch.syspath_prepend(tests_dir_path)
        monkeypatch.setattr(_module_loader, 'iterate_modules', _fake_module_load)
        yield run_cli
Exemplo n.º 10
0
            scaling_type=HyperparameterScalingType.LINEAR),
        max_depth=IntegerParameterRange(
            min_value=5,
            max_value=7,
            scaling_type=HyperparameterScalingType.LINEAR),
        gamma=ContinuousParameterRange(
            min_value=0.0,
            max_value=0.3,
            scaling_type=HyperparameterScalingType.LINEAR),
    )


# Create a launch plan from SageMakerHPO (defined above this view; presumably
# the HPO workflow). No fixed inputs or role are passed.
sagemaker_hpo_lp = SageMakerHPO.create_launch_plan()

# Print the workflow and launch-plan objects while common/configs/local.config
# is the temporary configuration, with image/project/domain overridden.
# NOTE(review): presumably the config is needed because rendering these
# objects reads configuration values — confirm.
with _configuration.TemporaryConfiguration(
        _os.path.join(
            _os.path.dirname(_os.path.realpath(__file__)),
            "../../common/configs/local.config",
        ),
        internal_overrides={
            "image": "myflyteimage:v123",
            "project": "myflyteproject",
            "domain": "development"
        },
):
    print("Printing WF definition")
    print(SageMakerHPO)

    print("Printing LP definition")
    print(sagemaker_hpo_lp)