コード例 #1
0
        def wrapper(*args, **kwargs):
            """Run ``func`` and record the call as a component run.

            ``func`` is executed under a one-shot ``sys.settrace`` hook that
            captures its frame, so its local variables can be read after it
            returns. Input and output pointers are then gathered from: the
            call arguments (when ``auto_log``), ``input_vars`` /
            ``output_vars``, ``input_kwargs`` / ``output_kwargs``, the
            directly specified ``inputs`` / ``outputs``, the
            ``_mltrace_loaded_artifacts`` / ``_mltrace_saved_artifacts``
            locals and, when ``auto_log``, the remaining locals. The run is
            committed to the store and ``func``'s return value is returned.
            Exceptions raised by ``func`` propagate unchanged.

            Raises:
                ValueError: if a requested variable is not a local of
                    ``func``, or a list-valued ``*_kwargs`` key variable and
                    its value variable differ in length.
            """
            # Construct component run object
            store = Store(_db_uri)
            component_run = store.initialize_empty_component_run(
                component_name)
            component_run.set_start_timestamp()

            # One-shot global tracer: remember the first Python frame entered
            # (func's own frame), then reinstall whatever tracer was active
            # before (debugger, coverage, ...).
            frame = None
            trace = sys.gettrace()

            def trace_helper(_frame, event, arg):
                nonlocal frame
                if frame is None and event == "call":
                    frame = _frame
                    sys.settrace(trace)
                # A global trace function must handle the "call" event itself
                # and return the local tracer for the new frame, so delegate
                # to the previous tracer instead of skipping its "call".
                return trace(_frame, event, arg) if trace is not None else None

            # Run function under the tracer; always restore the old tracer.
            sys.settrace(trace_helper)
            try:
                value = func(*args, **kwargs)
            finally:
                sys.settrace(trace)

            component_run.set_end_timestamp()

            if frame is None:
                # No Python frame was entered (e.g. func is a generator or
                # coroutine function, or is implemented in C), so there are
                # no locals to inspect.
                logging.warning(
                    f"Could not capture the stack frame of {func!r}; "
                    "no local variables will be logged.")
                local_vars = {}
            else:
                logging.info(f"Inspecting {frame.f_code.co_filename}")
                local_vars = frame.f_locals

            input_pointers = []
            output_pointers = []
            # Output pointers of an endpoint component are created with
            # pointer_type=PointerTypeEnum.ENDPOINT; otherwise no
            # pointer_type argument is passed at all.
            out_kw = ({"pointer_type": PointerTypeEnum.ENDPOINT}
                      if endpoint else {})

            def _lookup(var):
                """Return func's local ``var``; raise if it is undefined."""
                if var not in local_vars:
                    raise ValueError(
                        f"Variable {var} not in current stack frame.")
                return local_vars[var]

            def _kwarg_pair(key, val):
                """Resolve a ``*_kwargs`` entry naming two local variables.

                Returns ``(key_value, val_value)``, or None when the key
                variable holds None (the entry is skipped). Raises ValueError
                if either variable is missing or, for a list key, the value
                is not a list of the same length.
                """
                if key not in local_vars or val not in local_vars:
                    raise ValueError(
                        f"Variables ({key}, {val}) not in current stack frame."
                    )
                key_obj = local_vars[key]
                if key_obj is None:
                    logging.debug(f"Variable {key} has value {key_obj}.")
                    return None
                val_obj = local_vars[val]
                if isinstance(key_obj, list) and (
                        not isinstance(val_obj, list)
                        or len(key_obj) != len(val_obj)):
                    raise ValueError(
                        f'Value "{val}" does not have the same length as' +
                        f' the key "{key}."')
                return key_obj, val_obj

            # Auto log inputs
            if auto_log:
                # Defaults, overridden by positional args, then by kwargs.
                all_input_args = {
                    k: v.default
                    for k, v in inspect.signature(func).parameters.items()
                    if v.default is not inspect.Parameter.empty
                }
                all_input_args = {
                    **all_input_args,
                    **dict(zip(inspect.getfullargspec(func).args, args)),
                }
                all_input_args = {**all_input_args, **kwargs}
                input_pointers += store.get_io_pointers_from_args(
                    **all_input_args)

            # Add input_vars and output_vars as pointers
            for var in input_vars:
                val = _lookup(var)
                if val is None:
                    logging.debug(f"Variable {var} has value {val}.")
                    continue
                if isinstance(val, list):
                    input_pointers += store.get_io_pointers(val)
                else:
                    input_pointers.append(store.get_io_pointer(str(val)))
            for var in output_vars:
                val = _lookup(var)
                if val is None:
                    logging.debug(f"Variable {var} has value {val}.")
                    continue
                if isinstance(val, list):
                    output_pointers += store.get_io_pointers(val, **out_kw)
                else:
                    output_pointers.append(
                        store.get_io_pointer(str(val), **out_kw))

            # Add input_kwargs and output_kwargs as pointers
            for key, val in input_kwargs.items():
                pair = _kwarg_pair(key, val)
                if pair is None:
                    continue
                key_obj, val_obj = pair
                if isinstance(key_obj, list):
                    input_pointers += store.get_io_pointers(
                        key_obj, values=val_obj)
                else:
                    input_pointers.append(
                        store.get_io_pointer(str(key_obj), val_obj))
            for key, val in output_kwargs.items():
                pair = _kwarg_pair(key, val)
                if pair is None:
                    continue
                key_obj, val_obj = pair
                if isinstance(key_obj, list):
                    output_pointers += store.get_io_pointers(
                        key_obj, val_obj, **out_kw)
                else:
                    output_pointers.append(
                        store.get_io_pointer(str(key_obj), val_obj, **out_kw))

            # Directly specified I/O. Each input is added exactly once;
            # ``inputs`` is skipped when it is a callable rather than a
            # collection of names.
            if not callable(inputs):
                input_pointers += [store.get_io_pointer(inp) for inp in inputs]
            output_pointers += [
                store.get_io_pointer(out, **out_kw) for out in outputs
            ]

            # If there were calls to mltrace.load and mltrace.save, log them
            if "_mltrace_loaded_artifacts" in local_vars:
                input_pointers += [
                    store.get_io_pointer(name, val) for name, val in
                    local_vars["_mltrace_loaded_artifacts"].items()
                ]
            if "_mltrace_saved_artifacts" in local_vars:
                output_pointers += [
                    store.get_io_pointer(name, val) for name, val in
                    local_vars["_mltrace_saved_artifacts"].items()
                ]

            if auto_log:
                # Every local that was not an input argument is an output.
                all_output_args = {
                    k: v
                    for k, v in local_vars.items() if k not in all_input_args
                }
                output_pointers += store.get_io_pointers_from_args(
                    **all_output_args)

            component_run.add_inputs(input_pointers)
            component_run.add_outputs(output_pointers)

            # Add code versions (best effort: a missing repo is not an error)
            try:
                repo = git.Repo(search_parent_directories=True)
                component_run.set_git_hash(str(repo.head.object.hexsha))
            except Exception:
                logging.info("No git repo found.")

            # Add git tags
            git_tags = get_git_tags()
            if git_tags is not None:
                component_run.set_git_tags(git_tags)

            # Add source code (UTF-8 encoded) if it is under 2^16 bytes.
            # inspect.getsource raises OSError/TypeError when no source is
            # available (builtins, interactive definitions).
            try:
                func_source_code = inspect.getsource(func)
            except (OSError, TypeError):
                func_source_code = None
            if func_source_code is not None:
                code_bytes = func_source_code.encode("utf-8")
                if len(code_bytes) < 2**16:
                    component_run.set_code_snapshot(code_bytes)

            # Create component if it does not exist
            create_component(component_run.component_name, "", "")

            store.set_dependencies_from_inputs(component_run)

            # Commit component run object to the DB
            store.commit_component_run(component_run,
                                       staleness_threshold=staleness_threshold)

            return value
コード例 #2
0
            def wrapper(*args, **kwargs):
                """Run ``func`` as a run of this component and log the run.

                Steps, in order:
                  * reject calls where a name in ``key_names`` (per the error
                    message, ``skip_before``/``skip_after``) is also an
                    argument of ``func``;
                  * run ``self.beforeRun`` unless ``skip_before`` is set;
                  * auto-log the call arguments when ``auto_log`` is set;
                  * run ``func``, capturing its locals and the id of any
                    mlflow run it starts;
                  * log ``input_vars`` / ``output_vars`` (with optional label
                    variables), the ``_mltrace_loaded_artifacts`` /
                    ``_mltrace_saved_artifacts`` locals and, with
                    ``auto_log``, the remaining locals;
                  * check and propagate labels, and attach git and source
                    metadata;
                  * run ``self.afterRun`` unless ``skip_after`` is set;
                  * commit the run and return ``func``'s return value.

                Raises:
                    ValueError: if a ``key_names`` entry collides with
                        ``func``'s arguments, or a requested variable or label
                        variable is not a local of ``func``.
                """
                # Construct component run object
                store = Store(clientUtils.get_db_uri())
                component_run = store.initialize_empty_component_run(self.name)

                # Assert key names are not in args or kwargs
                if (set(key_names) & set(inspect.getfullargspec(func).args)
                    ) or (set(key_names) & set(kwargs.keys())):
                    raise ValueError(
                        "skip_before or skip_after cannot be in " +
                        f"the arguments of the function {func.__name__}")

                # Make Dictionary of test status
                status = {}

                # Run before tests
                if not user_kwargs.get("skip_before"):
                    all_args = dict(
                        zip(inspect.getfullargspec(func).args, args))
                    all_args = {
                        k if k not in inv_user_kwargs else inv_user_kwargs[k]:
                        v
                        for k, v in all_args.items()
                    }
                    # NOTE(review): names passed via **kwargs are not renamed
                    # through inv_user_kwargs, unlike positional args (and
                    # unlike the afterRun arguments below) -- confirm intended.
                    all_args = {**all_args, **kwargs}
                    status.update(self.beforeRun(**all_args))

                # Create input and output pointers
                input_pointers = []
                output_pointers = []

                # Auto log inputs
                if auto_log:
                    # Defaults, overridden by positional args, then by kwargs.
                    all_input_args = {
                        k: v.default
                        for k, v in inspect.signature(func).parameters.items()
                        if v.default is not inspect.Parameter.empty
                    }
                    all_input_args = {
                        **all_input_args,
                        **dict(zip(inspect.getfullargspec(func).args, args)),
                    }
                    all_input_args = {**all_input_args, **kwargs}
                    input_pointers += store.get_io_pointers_from_args(
                        should_filter=True, **all_input_args)

                # Temporarily wrap mlflow.start_run so the id of any run that
                # func starts can be attached to this component run.
                mlflow_run_id = None
                mlflow_start_run_copy = mlflow.start_run

                def mlflow_start_run_id(*start_args, **start_kwargs):
                    # Forward all arguments: func may call
                    # mlflow.start_run(run_name=..., nested=...) etc.
                    nonlocal mlflow_run_id
                    res = mlflow_start_run_copy(*start_args, **start_kwargs)
                    if mlflow.active_run():
                        mlflow_run_id = mlflow.active_run().info.run_id
                    return res

                mlflow.start_run = mlflow_start_run_id
                try:
                    component_run.set_start_timestamp()
                    # Run function
                    local_vars, value = utils.run_func_capture_locals(
                        func, *args, **kwargs)
                    component_run.set_end_timestamp()
                finally:
                    # Always undo the patch, even if func raises; otherwise
                    # it leaks globally and stacks on the next call.
                    mlflow.start_run = mlflow_start_run_copy

                if mlflow_run_id is not None:
                    try:
                        mlflow_run = mlflow.get_run(mlflow_run_id)
                        component_run.set_mlflow_run_id(mlflow_run_id)
                        metrics = mlflow_run.data.metrics
                        params = mlflow_run.data.params
                        component_run.set_mlflow_run_metrics(metrics)
                        component_run.set_mlflow_run_params(params)
                    except Exception:
                        # Best effort: mlflow problems must not fail the run.
                        logging.warning(
                            f"Mlflow.get_run {mlflow_run_id} failed.")

                if not callable(input_vars):
                    # Log input and output vars. input_vars may be a list of
                    # names or a dict mapping each name to its label var(s).
                    duplicate = input_vars
                    if not isinstance(duplicate, dict):
                        duplicate = {vname: None for vname in input_vars}

                    for var, label_vars in duplicate.items():
                        if var not in local_vars:
                            raise ValueError(
                                f"Variable {var} not in current stack frame.")
                        val = local_vars[var]
                        labels = None
                        if label_vars is not None:
                            try:
                                labels = ([
                                    local_vars[lv] for lv in label_vars
                                ] if isinstance(label_vars, list) else
                                          local_vars[label_vars])
                                if isinstance(labels, str):
                                    labels = [labels]
                            except KeyError as err:
                                raise ValueError(
                                    f"Variable {label_vars} not " +
                                    f"in current stack frame.") from err
                        if val is None:
                            logging.debug(f"Variable {var} has value {val}.")
                            continue
                        input_pointers += store.get_io_pointers_from_args(
                            should_filter=False, labels=labels, **{var: val})

                    for var in output_vars:
                        if var not in local_vars:
                            raise ValueError(
                                f"Variable {var} not in current stack frame.")
                        val = local_vars[var]
                        if val is None:
                            logging.debug(f"Variable {var} has value {val}.")
                            continue
                        output_pointers += store.get_io_pointers_from_args(
                            should_filter=False, **{var: val})

                # If there were calls to mltrace.load and mltrace.save, log
                # them
                if "_mltrace_loaded_artifacts" in local_vars:
                    input_pointers += [
                        store.get_io_pointer(name, val) for name, val in
                        local_vars["_mltrace_loaded_artifacts"].items()
                    ]
                if "_mltrace_saved_artifacts" in local_vars:
                    output_pointers += [
                        store.get_io_pointer(name, val) for name, val in
                        local_vars["_mltrace_saved_artifacts"].items()
                    ]

                # inspect.getsource raises OSError/TypeError when no source is
                # available (builtins, interactive definitions).
                try:
                    func_source_code = inspect.getsource(func)
                except (OSError, TypeError):
                    func_source_code = None

                if auto_log:
                    # Every local that was not an input argument is an output.
                    all_output_args = {
                        k: v
                        for k, v in local_vars.items()
                        if k not in all_input_args
                    }
                    output_pointers += store.get_io_pointers_from_args(
                        should_filter=True, **all_output_args)

                # Check that none of the labels in the inputs are deleted
                store.assert_not_deleted_labels(
                    input_pointers, staleness_threshold=staleness_threshold)
                # Propagate labels
                store.propagate_labels(input_pointers, output_pointers)

                component_run.add_inputs(input_pointers)
                component_run.add_outputs(output_pointers)

                # Add code versions (best effort: a missing repo is fine)
                try:
                    repo = git.Repo(search_parent_directories=True)
                    component_run.set_git_hash(str(repo.head.object.hexsha))
                except Exception:
                    logging.info("No git repo found.")

                # Add git tags
                git_tags = client.get_git_tags()
                if git_tags is not None:
                    component_run.set_git_tags(git_tags)

                # Add source code (UTF-8 encoded) if it is under 2^16 bytes;
                # ASCII encoding would crash on any non-ASCII character.
                if func_source_code is not None:
                    code_bytes = func_source_code.encode("utf-8")
                    if len(code_bytes) < 2**16:
                        component_run.set_code_snapshot(code_bytes)

                # Create component if it does not exist
                client.create_component(self.name, self.description,
                                        self.owner, self.tags)

                # Set dependencies
                store.set_dependencies_from_inputs(component_run)

                # Perform after run tests
                if not user_kwargs.get("skip_after"):
                    after_run_args = {
                        k if k not in inv_user_kwargs else inv_user_kwargs[k]:
                        v
                        for k, v in local_vars.items()
                    }
                    status.update(self.afterRun(**after_run_args))

                # update the component's testStatus, convert status to a json
                component_run.set_test_result(status)

                # Commit component run object to the DB
                store.commit_component_run(
                    component_run, staleness_threshold=staleness_threshold)

                return value