예제 #1
0
def test_underscore_execute_fall_back_remote_attributes(mock_wf_exec):
    """Check that options given to ``_execute`` reach the execution spec.

    The raw output data config and the IAM-role security context supplied via
    ``Options`` must show up on the execution spec that is handed to the
    client's ``create_execution``.

    :param mock_wf_exec: mock object; presumably injected by a ``patch``
        decorator that is not visible in this excerpt (TODO confirm).
    """
    mock_wf_exec.return_value = True
    mock_client = MagicMock()

    remote = FlyteRemote(config=Config.auto(),
                         default_project="p1",
                         default_domain="d1")
    remote._client = mock_client

    options = Options(
        raw_output_data_config=common_models.RawOutputDataConfig(
            output_location_prefix="raw_output"),
        security_context=security.SecurityContext(run_as=security.Identity(
            iam_role="iam:some:role")),
    )

    def local_assertions(*args, **kwargs):
        # The execution spec is read from the fourth positional argument of
        # the create_execution call.
        execution_spec = args[3]
        assert execution_spec.security_context.run_as.iam_role == "iam:some:role"
        assert execution_spec.raw_output_data_config.output_location_prefix == "raw_output"

    mock_client.create_execution.side_effect = local_assertions

    mock_entity = MagicMock()

    remote._execute(
        mock_entity,
        inputs={},
        project="proj",
        domain="dev",
        options=options,
    )

    # The assertions above only run inside the side effect; without this
    # check the test would pass vacuously if create_execution were never
    # invoked.
    mock_client.create_execution.assert_called_once()
예제 #2
0
def test_underscore_execute_uses_launch_plan_attributes(mock_wf_exec):
    """Check that labels, annotations and service account reach the spec.

    The labels, annotations and Kubernetes service account supplied via
    ``Options`` must show up on the execution spec that is handed to the
    client's ``create_execution``.

    :param mock_wf_exec: mock object; presumably injected by a ``patch``
        decorator that is not visible in this excerpt (TODO confirm).
    """
    mock_wf_exec.return_value = True
    mock_client = MagicMock()

    remote = FlyteRemote(config=Config.auto(),
                         default_project="p1",
                         default_domain="d1")
    remote._client = mock_client

    def local_assertions(*args, **kwargs):
        # The execution spec is read from the fourth positional argument of
        # the create_execution call.
        execution_spec = args[3]
        assert execution_spec.security_context.run_as.k8s_service_account == "svc"
        assert execution_spec.labels == common_models.Labels(
            {"a": "my_label_value"})
        assert execution_spec.annotations == common_models.Annotations(
            {"b": "my_annotation_value"})

    mock_client.create_execution.side_effect = local_assertions

    mock_entity = MagicMock()
    options = Options(
        labels=common_models.Labels({"a": "my_label_value"}),
        annotations=common_models.Annotations({"b": "my_annotation_value"}),
        security_context=security.SecurityContext(run_as=security.Identity(
            k8s_service_account="svc")),
    )

    remote._execute(
        mock_entity,
        inputs={},
        project="proj",
        domain="dev",
        options=options,
    )

    # The assertions above only run inside the side effect; without this
    # check the test would pass vacuously if create_execution were never
    # invoked.
    mock_client.create_execution.assert_called_once()
예제 #3
0
 def default_from(
         cls,
         k8s_service_account: typing.Optional[str] = None,
         raw_data_prefix: typing.Optional[str] = None) -> "Options":
     """Build an instance from two optional shortcut values.

     :param k8s_service_account: when truthy, wrapped in a security context
         whose ``run_as`` identity carries this service account; otherwise
         the security context is ``None``.
     :param raw_data_prefix: when truthy, used as the output location prefix
         of a raw output data config; otherwise that config is ``None``.
     :return: ``cls`` called with the resulting ``security_context`` and
         ``raw_output_data_config`` keyword arguments.
     """
     security_ctx = None
     if k8s_service_account:
         identity = security.Identity(k8s_service_account=k8s_service_account)
         security_ctx = security.SecurityContext(run_as=identity)

     raw_output_cfg = None
     if raw_data_prefix:
         raw_output_cfg = common_models.RawOutputDataConfig(
             output_location_prefix=raw_data_prefix)

     return cls(
         security_context=security_ctx,
         raw_output_data_config=raw_output_cfg,
     )
예제 #4
0
    (literals.LiteralCollection(literals=[l, l, l]), [v, v, v])
    for l, v in LIST_OF_SCALAR_LITERALS_AND_PYTHON_VALUE
]

# Scalar (literal, value) pairs followed by the collection pairs.
LIST_OF_ALL_LITERALS_AND_VALUES = (
    LIST_OF_SCALAR_LITERALS_AND_PYTHON_VALUE
    + LIST_OF_LITERAL_COLLECTIONS_AND_PYTHON_VALUE
)

# Secret fixtures: no secret, a bare group/key secret, and secrets with an
# explicit mount requirement (one of them also pinning a group version).
LIST_OF_SECRETS = [
    None,
    security.Secret(group="x", key="g"),
    security.Secret(
        group="x",
        key="y",
        mount_requirement=security.Secret.MountType.ANY,
    ),
    security.Secret(
        group="x",
        key="y",
        group_version="1",
        mount_requirement=security.Secret.MountType.FILE,
    ),
]

# Identity fixtures: no identity, an IAM role, and a k8s service account.
LIST_RUN_AS = [
    None,
    security.Identity(iam_role="role"),
    security.Identity(k8s_service_account="service_account"),
]

# Every identity/secret pairing (identity varies slowest), followed by a
# trailing None entry.
LIST_OF_SECURITY_CONTEXT = [
    security.SecurityContext(run_as=identity, secrets=secret, tokens=None)
    for identity in LIST_RUN_AS
    for secret in LIST_OF_SECRETS
]
LIST_OF_SECURITY_CONTEXT.append(None)
예제 #5
0
    def get_or_create(
        cls,
        workflow: _annotated_workflow.WorkflowBase,
        name: Optional[str] = None,
        default_inputs: Optional[Dict[str, Any]] = None,
        fixed_inputs: Optional[Dict[str, Any]] = None,
        schedule: Optional[_schedule_model.Schedule] = None,
        notifications: Optional[List[_common_models.Notification]] = None,
        labels: Optional[_common_models.Labels] = None,
        annotations: Optional[_common_models.Annotations] = None,
        raw_output_data_config: Optional[_common_models.RawOutputDataConfig] = None,
        max_parallelism: Optional[int] = None,
        security_context: typing.Optional[security.SecurityContext] = None,
        auth_role: Optional[_common_models.AuthRole] = None,
    ) -> LaunchPlan:
        """
        This function offers a friendlier interface for creating launch plans. If the name for the launch plan is not
        supplied, this assumes you are looking for the default launch plan for the workflow. If it is specified, it
        will be used. If creating the default launch plan, none of the other arguments may be specified.

        The resulting launch plan is also cached and if called again with the same name, the
        cached version is returned. The dictionaries passed as ``default_inputs`` and ``fixed_inputs`` are not
        modified when the cached entry is compared against the call arguments.

        :param security_context: Security context for the execution
        :param workflow: The Workflow to create a launch plan for.
        :param name: If you supply a name, keep in mind it needs to be unique. That is, project, domain, version, and
          this name form a primary key. If you do not supply a name, this function will assume you want the default
          launch plan for the given workflow.
        :param default_inputs: Default inputs, expressed as Python values.
        :param fixed_inputs: Fixed inputs, expressed as Python values. At call time, these cannot be changed.
        :param schedule: Optional schedule to run on.
        :param notifications: Notifications to send.
        :param labels: Optional labels to attach to executions created by this launch plan.
        :param annotations: Optional annotations to attach to executions created by this launch plan.
        :param raw_output_data_config: Optional location of offloaded data for things like S3, etc.
        :param auth_role: Add an auth role if necessary.
        :param max_parallelism: Controls the maximum number of tasknodes that can be run in parallel for the entire
            workflow. This is useful to achieve fairness. Note: MapTasks are regarded as one unit, and
            parallelism/concurrency of MapTasks is independent from this.
        :return: The cached launch plan if one exists under the requested name (or under the workflow name when no
            name is given), otherwise a newly created one.
        :raises ValueError: If no name is given but any other property is specified.
        :raises AssertionError: If a launch plan is cached under ``name`` but was created with different arguments.
        """
        extra_properties = (
            default_inputs,
            fixed_inputs,
            schedule,
            notifications,
            labels,
            annotations,
            raw_output_data_config,
            auth_role,
            max_parallelism,
            security_context,
        )
        if name is None and any(prop is not None for prop in extra_properties):
            raise ValueError(
                "Only named launchplans can be created that have other properties. Drop the name if you want to create a default launchplan. Default launchplans cannot have any other associations"
            )

        if name is not None and name in LaunchPlan.CACHE:
            cached_outputs = vars(LaunchPlan.CACHE[name])

            notifications = notifications or []
            # Merge into a fresh dict so the caller's default_inputs is never mutated; fixed inputs take
            # precedence over defaults of the same name.
            merged_inputs = {**(default_inputs or {}), **(fixed_inputs or {})}

            # create() stores an auth_role as a security context, so convert it the same way before comparing.
            if auth_role and not security_context:
                security_context = security.SecurityContext(
                    run_as=security.Identity(
                        iam_role=auth_role.assumable_iam_role,
                        k8s_service_account=auth_role.kubernetes_service_account,
                    ),
                )

            if (
                workflow != cached_outputs["_workflow"]
                or schedule != cached_outputs["_schedule"]
                or notifications != cached_outputs["_notifications"]
                or merged_inputs != cached_outputs["_saved_inputs"]
                or labels != cached_outputs["_labels"]
                or annotations != cached_outputs["_annotations"]
                or raw_output_data_config != cached_outputs["_raw_output_data_config"]
                or max_parallelism != cached_outputs["_max_parallelism"]
                or security_context != cached_outputs["_security_context"]
            ):
                raise AssertionError("The cached values aren't the same as the current call arguments")

            return LaunchPlan.CACHE[name]
        elif name is None and workflow.name in LaunchPlan.CACHE:
            return LaunchPlan.CACHE[workflow.name]

        # Otherwise, handle the default launch plan case
        if name is None:
            ctx = FlyteContext.current_context()
            lp = cls.get_default_launch_plan(ctx, workflow)
        else:
            lp = cls.create(
                name,
                workflow,
                default_inputs,
                fixed_inputs,
                schedule,
                notifications,
                labels,
                annotations,
                raw_output_data_config,
                max_parallelism,
                auth_role=auth_role,
                security_context=security_context,
            )
        LaunchPlan.CACHE[name or workflow.name] = lp
        return lp
예제 #6
0
    def create(
        cls,
        name: str,
        workflow: _annotated_workflow.WorkflowBase,
        default_inputs: Optional[Dict[str, Any]] = None,
        fixed_inputs: Optional[Dict[str, Any]] = None,
        schedule: Optional[_schedule_model.Schedule] = None,
        notifications: Optional[List[_common_models.Notification]] = None,
        labels: Optional[_common_models.Labels] = None,
        annotations: Optional[_common_models.Annotations] = None,
        raw_output_data_config: Optional[_common_models.RawOutputDataConfig] = None,
        max_parallelism: Optional[int] = None,
        security_context: typing.Optional[security.SecurityContext] = None,
        auth_role: Optional[_common_models.AuthRole] = None,
    ) -> LaunchPlan:
        """
        Build a new named launch plan for ``workflow`` and register it in the class-level cache.

        The dictionaries passed as ``default_inputs`` and ``fixed_inputs`` are not modified; the launch plan keeps
        its own merged copy of the native Python values.

        :param name: Name of the launch plan; must not already be present in the cache.
        :param workflow: The workflow to create a launch plan for.
        :param default_inputs: Default inputs, expressed as Python values. These take precedence over the defaults
            coming from the workflow's own signature.
        :param fixed_inputs: Fixed inputs, expressed as Python values.
        :param schedule: Passed through to the launch plan constructor.
        :param notifications: Passed through to the launch plan constructor.
        :param labels: Passed through to the launch plan constructor.
        :param annotations: Passed through to the launch plan constructor.
        :param raw_output_data_config: Passed through to the launch plan constructor.
        :param max_parallelism: Passed through to the launch plan constructor.
        :param security_context: Security context for the launch plan. Cannot be combined with ``auth_role``.
        :param auth_role: Deprecated. When given, it is converted into a security context.
        :return: The newly created launch plan.
        :raises ValueError: If both ``auth_role`` and ``security_context`` are specified.
        :raises AssertionError: If a launch plan named ``name`` is already cached.
        """
        ctx = FlyteContextManager.current_context()
        default_inputs = default_inputs or {}
        fixed_inputs = fixed_inputs or {}
        # Default inputs come from two places, the original signature of the workflow function, and the default_inputs
        # argument to this function. We'll take the latter as having higher precedence.
        wf_signature_parameters = transform_inputs_to_parameters(ctx, workflow.python_interface)

        # Construct a new Interface object with just the default inputs given to get Parameters, maybe there's an
        # easier way to do this, think about it later.
        temp_inputs = {k: (workflow.python_interface.inputs[k], v) for k, v in default_inputs.items()}
        temp_interface = Interface(inputs=temp_inputs, outputs={})
        temp_signature = transform_inputs_to_parameters(ctx, temp_interface)
        wf_signature_parameters._parameters.update(temp_signature.parameters)

        # These are fixed inputs that cannot change at launch time. If the same argument is also in default inputs,
        # it'll be taken out from defaults in the LaunchPlan constructor
        fixed_literals = translate_inputs_to_literals(
            ctx,
            incoming_values=fixed_inputs,
            flyte_interface_types=workflow.interface.inputs,
            native_types=workflow.python_interface.inputs,
        )
        fixed_lm = _literal_models.LiteralMap(literals=fixed_literals)

        if auth_role:
            if security_context:
                raise ValueError("Use of AuthRole is deprecated. You cannot specify both AuthRole and SecurityContext")

            security_context = security.SecurityContext(
                run_as=security.Identity(
                    iam_role=auth_role.assumable_iam_role,
                    k8s_service_account=auth_role.kubernetes_service_account,
                ),
            )

        lp = cls(
            name=name,
            workflow=workflow,
            parameters=wf_signature_parameters,
            fixed_inputs=fixed_lm,
            schedule=schedule,
            notifications=notifications,
            labels=labels,
            annotations=annotations,
            raw_output_data_config=raw_output_data_config,
            max_parallelism=max_parallelism,
            security_context=security_context,
        )

        # This is just a convenience - we'll need the fixed inputs LiteralMap for when serializing the Launch Plan out
        # to protobuf, but for local execution and such, why not save the original Python native values as well so
        # we don't have to reverse it back every time.
        # A fresh dict is built (fixed inputs overriding defaults) so that the caller's default_inputs dictionary is
        # neither mutated nor aliased by the launch plan.
        lp._saved_inputs = {**default_inputs, **fixed_inputs}

        if name in cls.CACHE:
            raise AssertionError(f"Launch plan named {name} was already created! Make sure your names are unique.")
        cls.CACHE[name] = lp
        return lp