예제 #1
0
 def _get_base_part_mapping(cls) -> dict:
     """Return the validator mapping shared by segment and phase schemas.

     Every key is optional and defaults to ``None``. A polar is either a
     string or a ``{"CL": ..., "CD": ...}`` mapping whose coefficients are
     comma-separated floats or strings.
     """
     # TODO: this mapping covers all possible segments, but some options are relevant
     #  only for some segments. A better check could be done in second-pass validation.
     coefficients = CommaSeparated(Float()) | Str()
     polar = Map({"CL": coefficients, "CD": coefficients}) | Str()

     entries = (
         ("target", cls._get_target_schema()),
         ("engine_setting", cls._get_value_schema(Str(), False)),
         (POLAR_TAG, polar),
         ("thrust_rate", cls._get_value_schema(has_unit=False)),
         ("climb_thrust_rate", cls._get_value_schema(has_unit=False)),
         ("time_step", cls._get_value_schema()),
         ("maximum_flight_level", cls._get_value_schema(has_unit=False)),
         ("mass_ratio", cls._get_value_schema(has_unit=False)),
         ("reserve_mass_ratio", cls._get_value_schema(has_unit=False)),
         ("use_max_lift_drag_ratio", cls._get_value_schema(Bool(), False)),
     )
     return {Optional(name, default=None): validator for name, validator in entries}
예제 #2
0
 def schema(cls):
     """Return the parent schema extended with optional integer bounds.

     Adds optional ``default``, ``min`` and ``max`` keys, each validated as an
     integer, on top of every key of the parent class's schema.
     """
     validators = dict(super().schema()._validator)
     for bound in ("default", "min", "max"):
         validators[Optional(bound)] = Int()
     return Map(validators)
예제 #3
0
 def _get_mission_mapping(cls) -> dict:
     """Return the mission mapping: a ``name`` plus a sequence of steps.

     Each step may reference a route and/or a phase by name; both default to
     ``None``.
     """
     step = Map({
         Optional(ROUTE_TAG, default=None): Str(),
         Optional(PHASE_TAG, default=None): Str(),
     })
     return {"name": Str(), STEPS_TAG: Seq(step)}
예제 #4
0
 def _get_route_schema(cls) -> Map:
     """Return the schema of a route section.

     A route has an optional ``range`` value, optional climb and descent lists
     of phase references, and a mandatory cruise segment.
     """
     def phase_references():
         # A fresh validator for each key that holds a list of phase names.
         return Seq(Map({PHASE_TAG: Str()}))

     return Map({
         Optional("range", default=None): cls._get_value_schema(),
         Optional(CLIMB_PARTS_TAG, default=None): phase_references(),
         CRUISE_PART_TAG: cls._get_segment_schema(),
         Optional(DESCENT_PARTS_TAG, default=None): phase_references(),
     })
예제 #5
0
 def _get_target_schema(cls) -> Map:
     """Return the schema of a segment target.

     Every ``FlightPoint`` field becomes an optional key (default ``None``)
     whose value is a float, a string, or a ``{"value": ..., "unit": ...}``
     mapping where ``unit`` is optional.
     """
     def target_value():
         return Float() | Str() | Map({
             "value": Float() | Str(),
             Optional("unit", default=None): Str(),
         })

     return Map({
         Optional(point_field.name, default=None): target_value()
         for point_field in fields(FlightPoint)
     })
def get_type_schema_yaml_validator() -> Map:
    """Return the validator for the input/output requirement sections.

    Both sections are optional and share one validator: a sequence of
    ``{field, condition, value}`` entries, where ``field`` must be one of
    ``Fields`` (as a string) and ``value`` is a string or a list of strings.
    """
    requirement = Map({
        "field": Enum([str(member) for member in Fields]),
        "condition": Str(),
        "value": Str() | Seq(Str()),
    })
    requirement_list = Seq(requirement)
    sections = (
        RequirementTypes.INPUT_REQUIREMENTS,
        RequirementTypes.OUTPUT_REQUIREMENTS,
    )
    return Map({Optional(str(section)): requirement_list for section in sections})
예제 #7
0
    def schema(cls):
        """Return the schema keys common to every block.

        ``name`` is not part of the sub schema because it is the block's
        title, and ``default`` has to be supplied by each sub schema.
        """
        mandatory = {"label": Str(), "type": Str()}
        optional = {
            Optional("variable"): Str(),
            Optional("helptext"): Str(),
            Optional("visibility"): Int(),
            Optional("required"): Bool(),
            Optional("readonly"): Bool(),
        }
        return Map({**mandatory, **optional})
예제 #8
0
파일: schema.py 프로젝트: akva2/badger
def FileMapping(glob_allowed: bool):
    """Return a validator for a file mapping.

    A file mapping is either a plain string or a mapping with a ``source``
    and an optional ``target``. When ``glob_allowed`` is true, the mapping
    may also carry an optional ``mode`` of ``simple`` or ``glob``.
    """
    entry = {'source': Str(), Optional('target'): Str()}
    if glob_allowed:
        entry[Optional('mode')] = Choice('simple', 'glob')
    return Str() | Map(entry)
예제 #9
0
 def get_robot_part_schema():
     """
     Getter for robot schema

     A robot part has a name, type, brick index and x/y offsets, plus an
     optional port (``ev3-ports:in1``..``in4`` or ``outA``..``outD``), side
     (left/right/rear) and direction (bottom/front).

     :return: schema that is used to verify the robot yaml
     """
     # Patterns are anchored at both ends, with alternatives grouped, so the
     # whole value must match. Values like "leftover" or "ev3-ports:in12" are
     # rejected. NOTE(review): anchoring is harmless if strictyaml's Regex
     # already does a full match.
     return Map({
         'name': Str(),
         'type': Str(),
         'brick': Int(),
         'x_offset': Float(),
         'y_offset': Float(),
         Optional('port'): Regex(r'^ev3-ports:(in[1-4]|out[A-D])$'),
         Optional('side'): Regex(r'^(left|right|rear)$'),
         Optional('direction'): Regex(r'^(bottom|front)$'),
     })
예제 #10
0
 def __init__(self):
     """Build the recursive condition schema.

     The ALL_OF, NONE_OF and ONE_OF keys are validated by this very schema,
     so conditions can nest. The STARTS_*/ENDS_* keys take floats, and
     WEEKDAY/WEEKEND take booleans. Every key is optional.
     """
     entries = [
         (conditions.ALL_OF, self),
         (conditions.ENDS_EARLIER_THAN, Float()),
         (conditions.ENDS_LATER_THAN, Float()),
         (conditions.NONE_OF, self),
         (conditions.ONE_OF, self),
         (conditions.STARTS_EARLIER_THAN, Float()),
         (conditions.STARTS_LATER_THAN, Float()),
         (conditions.WEEKDAY, Bool()),
         (conditions.WEEKEND, Bool()),
     ]
     super().__init__({Optional(key): validator for key, validator in entries})
예제 #11
0
 def _get_phase_mapping(cls) -> dict:
     """Return the phase mapping: optional segment steps plus the base step keys."""
     steps = Seq(Map(cls._get_segment_mapping()))
     return {
         Optional(STEPS_TAG, default=None): steps,
         **cls._get_base_step_mapping(),
     }
예제 #12
0
 def _get_phase_schema(cls) -> Map:
     """Return the schema of a phase section.

     A phase holds an optional sequence of segment parts plus all keys of
     the base part mapping.
     """
     parts = Seq(cls._get_segment_schema())
     return Map({
         Optional(PARTS_TAG, default=None): parts,
         **cls._get_base_part_mapping(),
     })
예제 #13
0
 def _get_target_schema(cls) -> Map:
     """Return the segment-target schema.

     Each ``FlightPoint`` field is an optional key (default ``None``) holding a
     parameter value as defined by ``_get_value_schema``.
     """
     return Map({
         Optional(point_field.name, default=None): cls._get_value_schema()
         for point_field in fields(FlightPoint)
     })
예제 #14
0
파일: schema.py 프로젝트: akva2/badger
def Regex():
    """Return a validator for a regex entry.

    The entry is a mapping with a mandatory ``pattern`` string and an
    optional ``mode`` of ``first``, ``last`` or ``all``.
    """
    mode = Choice('first', 'last', 'all')
    return Map({'pattern': Str(), Optional('mode'): mode})
예제 #15
0
파일: schema.py 프로젝트: akva2/badger
def NumberCapture():
    """Return a validator for a predefined integer or float capture.

    Keys: ``type`` (``integer`` or ``float``), ``name``, ``prefix`` and an
    optional ``mode`` (``first``, ``last`` or ``all``).
    """
    capture = {
        'type': Choice('integer', 'float'),
        'name': Str(),
        'prefix': Str(),
    }
    capture[Optional('mode')] = Choice('first', 'last', 'all')
    return Map(capture)
예제 #16
0
 def _get_mission_schema(cls) -> Map:
     """Return the schema of a mission section.

     A mission is a sequence of parts. Each part may name a route, a phase
     and/or a reserve given as ``{"ref": ..., "multiplier": ...}``, where
     the multiplier is a float or a string.
     """
     reserve = Map({"ref": Str(), "multiplier": Float() | Str()})
     part = Map({
         Optional(ROUTE_TAG, default=None): Str(),
         Optional(PHASE_TAG, default=None): Str(),
         Optional(RESERVE_TAG, default=None): reserve,
     })
     return Map({PARTS_TAG: Seq(part)})
예제 #17
0
    def _get_value_schema(cls,
                          value_type: ScalarValidator = Float(),
                          has_unit=True) -> Validator:
        """Return the schema of a parameter value.

        A value is accepted either directly (validated by ``value_type``,
        falling back to a plain string) or as a mapping
        ``{"value": ..., "unit": ...}``.

        :param value_type: scalar validator for the value (float by default).
            The default instance is created once and shared between calls.
        :param has_unit: whether the mapping form accepts an optional ``unit``
            key (default ``None``).
        :return: ``value_type | Str() | Map(...)``
        """
        # The mapping's "value" must use the same validator as the scalar form.
        # A hard-coded Float() made e.g. a Bool() value written as
        # {value: true} fall through to Str() and come back as "true".
        map_dict = {"value": value_type | Str()}
        if has_unit:
            map_dict[Optional("unit", default=None)] = Str()

        return value_type | Str() | Map(map_dict)
예제 #18
0
    def __init__(self, filename):
        """Load config from YAML file.

        Each config entry pairs a ``request`` (optional path, method, headers
        and data) with a ``response`` (content, optional code and headers).
        A response ``content`` given as ``{"file": ...}`` is replaced by that
        file's text, resolved relative to the config file's directory.

        :param filename: path to the YAML config, or ``None`` for an empty
            config.

        Exits the process with status 1 if the config file cannot be read or
        does not match the schema.
        """
        if filename is None:
            # Checked before abspath(): abspath(None) raises TypeError, which
            # made this branch unreachable.
            self._config = []
            return

        filename = path.abspath(filename)

        try:
            with open(filename, 'r') as handle:
                self._yaml = handle.read()

            self._config = load(
                self._yaml,
                Seq(
                    Map({
                        # Must be a validator; a bare string is not one.
                        Optional("name"):
                        Str(),
                        "request":
                        Map({
                            Optional("path"):
                            Str(),
                            Optional("method"):
                            Enum([
                                "get",
                                "post",
                                "put",
                                "delete",
                                "GET",
                                "POST",
                                "PUT",
                                "DELETE",
                            ]),
                            Optional("headers"):
                            MapPattern(Str(), Str()),
                            Optional("data"):
                            Str(),
                        }),
                        "response":
                        Map({
                            "content": Str() | Map({"file": Str()}),
                            Optional("code"): Int(),
                            Optional("headers"): MapPattern(Str(), Str()),
                        }),
                    })))
        except Exception as e:
            sys.stderr.write(
                "Error reading YAML config file: {0}\n".format(str(e)))
            sys.exit(1)

        # Read and store all references to external content files
        for pair in self._config:
            content = pair.get('response', {}).get('content')
            if not isinstance(content, str) and "file" in content:
                content_path = path.join(path.dirname(filename), content['file'])
                with open(content_path, 'r') as content_file_handle:
                    pair['response']['content'] = content_file_handle.read()
예제 #19
0
def build_schema_for_cubes():
    """
    Build one validation schema per supported cube class.

    Returns
    -------
    dict
        each element is str -> strictyaml.Map
        where key is name of cube,
        value is a schema used for validation and type-coercion
    """
    schemas = {}
    for cube_class in SUPPORTED_CUBES:
        cube_name = cube_class.__name__
        fields = build_schema_from_signature(cube_class)

        # "selection" isn't used in __init__, but we will need it later
        fields["selection"] = Seq(Str())

        # shortcut for strategy initialization: accept the parameters of
        # every supported strategy under "strategy_params"
        if is_key_in_schema("strategy", fields):
            strategy_fields = {}
            for strategy_class in SUPPORTED_STRATEGIES:
                strategy_fields.update(build_schema_from_signature(strategy_class))
            fields[Optional("strategy_params")] = Map(strategy_fields)

        # we will deal with "values" later, but we can check at least some simple things already
        if cube_name == "CubeCreator":
            fields["parameters"] = Seq(Map({"name": Str(), "values": Seq(Any())}))
        elif cube_name == "RegularizersModifierCube":
            regularizer_entry = Map({
                Optional("name"): Str(),
                Optional("regularizer"): Any(),
                Optional("tau_grid"): Seq(Float())
            })
            fields["regularizer_parameters"] = regularizer_entry | Seq(regularizer_entry)

        schemas[cube_name] = Map({cube_name: Map(fields)})
    return schemas
예제 #20
0
def read_corpus_config(filename='corpus.yml'):
    """Load and validate a corpus configuration file.

    :param filename: path to the YAML config (``corpus.yml`` by default).
    :return: the validated config as plain Python data (``.data``).
    """
    fields = {
        key: Str() for key in ('dataset_path', 'batches_prefix', 'word', 'name')
    }
    fields[Optional("num_topics_interval")] = Int()
    fields[Optional("nums_topics")] = CommaSeparated(Int())
    fields.update({
        key: Int()
        for key in ('min_num_topics', 'max_num_topics',
                    'num_fit_iterations', 'num_restarts')
    })
    schema = Map(fields)

    with open(filename, 'r') as config_file:
        raw_yaml = config_file.read()

    return strictyaml.load(raw_yaml, schema=schema).data
예제 #21
0
def choose_key(param):
    """
    Pick the schema key for a signature parameter.

    Parameters
    ----------
    param : inspect.Parameter

    Returns
    -------
    str or strictyaml.Optional
        ``Optional(param.name)`` when the parameter has a default value,
        otherwise the bare name (a mandatory key).
    """
    has_default = param.default is not Parameter.empty
    return Optional(param.name) if has_default else param.name
예제 #22
0
def get_schema(snippet):
    """Convert a JSON-schema-like snippet into a strictyaml validator.

    Supported ``type`` values are ``integer``, ``string``, ``array`` (whose
    ``items`` snippet is converted recursively) and ``object`` (whose
    ``properties`` are converted recursively). Properties named in the
    snippet's ``required`` list become mandatory keys; all others are
    wrapped in ``Optional``.

    :raises ValueError: if ``type`` is not one of the supported values.
    """
    kind = snippet['type']
    if kind == "integer":
        return Int()
    if kind == "string":
        return Str()
    if kind == "array":
        return Seq(get_schema(snippet["items"]))
    if kind == "object":
        required = snippet.get('required', [])
        map_schema = {}
        for key, subschema in snippet.get('properties', {}).items():
            # JSON Schema semantics: a key listed in "required" is mandatory.
            schema_key = key if key in required else Optional(key)
            map_schema[schema_key] = get_schema(subschema)
        return Map(map_schema)
    raise ValueError("Unsupported schema type: {0!r}".format(kind))
예제 #23
0
 def _get_base_step_mapping(cls) -> dict:
     """Return the optional keys shared by every mission step.

     All keys default to ``None``. A polar is a string or a mapping of
     ``CL``/``CD`` coefficients (comma-separated floats or strings).
     """
     coefficients = CommaSeparated(Float()) | Str()
     polar = Map({"CL": coefficients, "CD": coefficients}) | Str()
     entries = [
         ("target", cls._get_target_schema()),
         ("engine_setting", Str()),
         (POLAR_TAG, polar),
         ("thrust_rate", Float() | Str()),
         ("climb_thrust_rate", Float() | Str()),
         ("time_step", Float()),
         ("maximum_flight_level", Float() | Str()),
     ]
     return {Optional(key, default=None): validator for key, validator in entries}
예제 #24
0
def build_schema_for_regs():
    """
    Build one validation schema per regularizer in ``artm.regularizers``.

    Only names in ``artm.regularizers.__all__`` that contain "Regularizer"
    are considered.

    Returns
    -------
    dict
        each element is str -> strictyaml.Map
        where key is the regularizer class name,
        value is a schema used for validation and type-coercion
    """
    # These regularizers additionally accept an optional "relative" flag.
    with_relative_flag = {
        "SmoothSparseThetaRegularizer",
        "SmoothSparsePhiRegularizer",
        "DecorrelatorPhiRegularizer",
    }
    schemas = {}
    for elem in artm.regularizers.__all__:
        if "Regularizer" not in elem:
            continue
        class_of_object = getattr(artm.regularizers, elem)
        res = build_schema_from_signature(class_of_object)
        if elem in with_relative_flag:
            res[Optional("relative", default=None)] = Bool()
        res = wrap_in_map(res)

        specific_schema = Map({class_of_object.__name__: res})
        schemas[class_of_object.__name__] = specific_schema

    return schemas
예제 #25
0
def is_pipelines_config_valid(strictyaml_pipelines: YAML) -> bool:
    """Check a parsed pipelines config against the pipelines schema.

    The document must hold a ``pipelines`` sequence. Each pipeline has a
    ``name`` and a ``type`` (``test`` or ``analyzer``), plus optional
    ``coverage``, ``commands`` (both ``partial-scope`` and ``full-scope``),
    ``dirs`` and ``files``. Each ``dirs``/``files`` entry has a ``path`` and
    an optional ``full-scope`` flag defaulting to ``False``.

    :param strictyaml_pipelines: parsed document; ``revalidate`` is called
        on it directly.
    :return: ``True`` if it matches the schema, ``False`` on a
        ``YAMLValidationError``.

    TODO: Refactor to test and analyzer specific config validation.
    """
    def scoped_paths():
        # A fresh validator per key: list of paths with an optional flag.
        return Seq(
            Map({
                "path": Str(),
                Optional("full-scope", default=False): Bool()
            }))

    pipeline = Map({
        "name": Str(),
        "type": Enum(["test", "analyzer"]),
        Optional("coverage"): Str(),
        Optional("commands"): Map({
            "partial-scope": Str(),
            "full-scope": Str()
        }),
        Optional("dirs"): scoped_paths(),
        Optional("files"): scoped_paths(),
    })
    pipelines_schema = Map({"pipelines": Seq(pipeline)})
    try:
        strictyaml_pipelines.revalidate(pipelines_schema)
    except YAMLValidationError:
        return False
    return True
예제 #26
0
from strictyaml import Map, MapPattern, Optional
from strictyaml import Str, Int, Seq, Enum, Any, as_document

# Keys allowed in a single JSON-schema-style type snippet. Nested
# "properties"/"items" snippets are accepted as Any() here and are not
# validated recursively by JSONSCHEMA_SCHEMA.
JSONSCHEMA_TYPE_SNIPPET = {"type": Enum(["object", "integer", "string", "array"])}
JSONSCHEMA_TYPE_SNIPPET[Optional("required")] = Seq(Str())
JSONSCHEMA_TYPE_SNIPPET[Optional("properties")] = MapPattern(Str(), Any())
JSONSCHEMA_TYPE_SNIPPET[Optional("items")] = Any()

JSONSCHEMA_SCHEMA = Map(JSONSCHEMA_TYPE_SNIPPET)


def get_schema(snippet):
    """Convert a JSON-schema-like snippet into a strictyaml validator.

    Supported ``type`` values are ``integer``, ``string``, ``array`` (whose
    ``items`` snippet is converted recursively) and ``object`` (whose
    ``properties`` are converted recursively). Properties named in the
    snippet's ``required`` list become mandatory keys; all others are
    wrapped in ``Optional``.

    :raises ValueError: if ``type`` is not one of the supported values.
    """
    kind = snippet['type']
    if kind == "integer":
        return Int()
    if kind == "string":
        return Str()
    if kind == "array":
        return Seq(get_schema(snippet["items"]))
    if kind == "object":
        required = snippet.get('required', [])
        map_schema = {}
        for key, subschema in snippet.get('properties', {}).items():
            # JSON Schema semantics: a key listed in "required" is mandatory.
            schema_key = key if key in required else Optional(key)
            map_schema[schema_key] = get_schema(subschema)
        return Map(map_schema)
    raise ValueError("Unsupported schema type: {0!r}".format(kind))


def load_schema(json_schema):
예제 #27
0
class Engine(BaseEngine):
    """Python engine for running tests."""

    given_definition = GivenDefinition(
        scripts=GivenProperty(MapPattern(Str(), Str())),
        python_version=GivenProperty(Str()),
        pexpect_version=GivenProperty(Str()),
        icommandlib_version=GivenProperty(Str()),
        setup=GivenProperty(Str()),
        files=GivenProperty(MapPattern(Str(), Str())),
        code=GivenProperty(Str()),
    )

    info_definition = InfoDefinition(
        importance=InfoProperty(schema=Int()),
        docs=InfoProperty(schema=Str()),
        fails_on_python_2=InfoProperty(schema=Bool()),
    )

    def __init__(self, keypath, rewrite=False):
        """Store the key paths.

        With ``rewrite`` set, mismatching expectations update the current
        story step instead of failing.
        """
        self.path = keypath
        self._rewrite = rewrite
        self._cprofile = False

    def set_up(self):
        """Set up your applications and the test environment."""
        # Start each story from a fresh, empty state directory.
        self.path.state = self.path.gen.joinpath("state")
        if self.path.state.exists():
            self.path.state.rmtree(ignore_errors=True)
        self.path.state.mkdir()

        for script in self.given.get("scripts", []):
            script_path = self.path.state.joinpath(script)

            if not script_path.dirname().exists():
                script_path.dirname().makedirs()

            script_path.write_text(self.given["scripts"][script])
            script_path.chmod("u+x")

        for filename, contents in self.given.get("files", {}).items():
            file_path = self.path.state.joinpath(filename)
            # Files may live in subdirectories too, just like scripts.
            if not file_path.dirname().exists():
                file_path.dirname().makedirs()
            file_path.write_text(contents)

        self.python = hitchpylibrarytoolkit.project_build(
            "commandlib", self.path, self.given["python version"]
        ).bin.python

        self.example_py_code = (
            ExamplePythonCode(self.python, self.path.state)
            .with_code(self.given.get("code", ""))
            .with_setup_code(self.given.get("setup", ""))
        )

    def _story_friendly_output(self, text):
        """Replace the absolute state directory in ``text`` with ``/path/to``."""
        return text.replace(self.path.state, "/path/to")

    def _for_python_version(self, variants):
        """Pick the "in python 2" or "in python 3" entry for the given python version."""
        if self.given["python version"].startswith("2"):
            return variants["in python 2"]
        return variants["in python 3"]

    @no_stacktrace_for(AssertionError)
    @no_stacktrace_for(HitchRunPyException)
    @validate(
        code=Str(),
        will_output=Str(),
        raises=Map(
            {
                Optional("type"): Map({"in python 2": Str(), "in python 3": Str()})
                | Str(),
                Optional("message"): Map({"in python 2": Str(), "in python 3": Str()})
                | Str(),
            }
        ),
    )
    def run(self, code, will_output=None, raises=None):
        """Run ``code`` and check its output and/or the exception it raises.

        ``will_output`` is matched (Templex) against the right-stripped output
        lines. ``raises`` gives the expected exception ``type`` and
        ``message``, each either a string or a per-python-version mapping. In
        rewrite mode a mismatch updates the story step instead of failing.
        A message is only rewritten when no per-version variants were given.
        """
        to_run = self.example_py_code.with_code(code)

        if self._cprofile:
            to_run = to_run.with_cprofile(
                self.path.profile.joinpath("{0}.dat".format(self.story.slug))
            )

        result = (
            to_run.expect_exceptions().run() if raises is not None else to_run.run()
        )

        if will_output is not None:
            actual_output = "\n".join(
                [line.rstrip() for line in result.output.split("\n")]
            )
            try:
                Templex(will_output).assert_match(actual_output)
            except AssertionError:
                if self._rewrite:
                    self.current_step.update(**{"will output": actual_output})
                else:
                    raise

        if raises is not None:
            differential = False  # Difference between python 2 and python 3 output?
            exception_type = raises.get("type")
            message = raises.get("message")

            if exception_type is not None and not isinstance(exception_type, str):
                differential = True
                exception_type = self._for_python_version(exception_type)

            if message is not None and not isinstance(message, str):
                differential = True
                message = self._for_python_version(message)

            try:
                # Inspect the run above, which already executed this step's
                # code (with cprofile if enabled) expecting an exception.
                # Re-running self.example_py_code would drop ``code``.
                result.exception_was_raised(exception_type)
                exception_message = self._story_friendly_output(
                    result.exception.message
                )
                Templex(exception_message).assert_match(message)
            except AssertionError:
                if self._rewrite and not differential:
                    new_raises = raises.copy()
                    new_raises["message"] = self._story_friendly_output(
                        result.exception.message
                    )
                    self.current_step.update(raises=new_raises)
                else:
                    raise

    def file_contents_will_be(self, filename, contents):
        """Assert a state file equals ``contents``.

        Comparison is exact after stripping the file and each line's trailing
        whitespace. In rewrite mode a mismatch updates the step instead.
        """
        file_contents = "\n".join(
            [
                line.rstrip()
                for line in self.path.state.joinpath(filename)
                .bytes()
                .decode("utf8")
                .strip()
                .split("\n")
            ]
        )
        try:
            # Templex(file_contents).assert_match(contents.strip())
            assert file_contents == contents.strip(), "{0} not {1}".format(
                file_contents, contents.strip()
            )
        except AssertionError:
            if self._rewrite:
                self.current_step.update(contents=file_contents)
            else:
                raise

    def pause(self, message="Pause"):
        """Drop into an interactive IPython shell."""
        import IPython

        IPython.embed()

    def on_success(self):
        """Print profiling stats when cprofile is enabled."""
        if self._cprofile:
            self.python(
                self.path.key.joinpath("printstats.py"),
                self.path.profile.joinpath("{0}.dat".format(self.story.slug)),
            ).run()
예제 #28
0
파일: key.py 프로젝트: tobbez/strictyaml
class Engine(BaseEngine):
    """Python engine for running tests."""
    schema = StorySchema(
        preconditions=Map({
            "files": MapPattern(Str(), Str()),
            "variables": MapPattern(Str(), Str()),
            "python version": Str(),
            "ruamel version": Str(),
        }),
        params=Map({
            "python version": Str(),
            "ruamel version": Str(),
        }),
        about={
            "description": Str(),
            Optional("importance"): Int(),
        },
    )

    def __init__(self, keypath, settings):
        self.path = keypath
        self.settings = settings

    def set_up(self):
        """Set up your applications and the test environment."""
        self.doc = hitchdoc.Recorder(
            hitchdoc.HitchStory(self),
            self.path.gen.joinpath('storydb.sqlite'),
        )

        # Start each story from a fresh, empty state directory.
        if self.path.gen.joinpath("state").exists():
            self.path.gen.joinpath("state").rmtree(ignore_errors=True)
        self.path.gen.joinpath("state").mkdir()
        self.path.state = self.path.gen.joinpath("state")

        for filename, text in self.preconditions.get("files", {}).items():
            filepath = self.path.state.joinpath(filename)
            if not filepath.dirname().exists():
                # makedirs (not mkdir) so nested subdirectories work too.
                filepath.dirname().makedirs()
            filepath.write_text(text)

        # The schema key is "python version" (with a space), like the
        # "ruamel version" lookup below. "python_version" never matched, so
        # the 3.5.0 fallback was always used.
        self.python_package = hitchpython.PythonPackage(
            self.preconditions.get('python version', '3.5.0'))
        self.python_package.build()

        self.pip = self.python_package.cmd.pip
        self.python = self.python_package.cmd.python

        # Install debugging packages
        with hitchtest.monitor(
            [self.path.key.joinpath("debugrequirements.txt")]) as changed:
            if changed:
                run(
                    self.pip("install", "-r",
                             "debugrequirements.txt").in_dir(self.path.key))

        # Uninstall and reinstall
        run(self.pip("uninstall", "strictyaml", "-y").ignore_errors())
        run(self.pip("install", ".").in_dir(self.path.project))
        run(
            self.pip(
                "install", "ruamel.yaml=={0}".format(
                    self.preconditions["ruamel version"])))

        self.services = hitchserve.ServiceBundle(str(self.path.project),
                                                 startup_timeout=8.0,
                                                 shutdown_timeout=1.0)

        self.services['IPython'] = hitchpython.IPythonKernelService(
            self.python_package)

        self.services.startup(interactive=False)
        self.ipython_kernel_filename = self.services[
            'IPython'].wait_and_get_ipykernel_filename()
        self.ipython_step_library = hitchpython.IPythonStepLibrary()
        self.ipython_step_library.startup_connection(
            self.ipython_kernel_filename)

        self.shutdown_connection = self.ipython_step_library.shutdown_connection
        self.ipython_step_library.run("import os")
        self.ipython_step_library.run("import sure")
        self.ipython_step_library.run("from path import Path")
        self.ipython_step_library.run("os.chdir('{}')".format(self.path.state))

        # Expose each precondition file's text as a kernel variable named
        # after the file (minus ".yaml").
        for filename, text in self.preconditions.get("files", {}).items():
            self.ipython_step_library.run(
                """{} = Path("{}").bytes().decode("utf8")""".format(
                    filename.replace(".yaml", ""), filename))

    def run_command(self, command):
        """Run ``command`` in the kernel and record it as a code step."""
        self.ipython_step_library.run(command)
        self.doc.step("code", command=command)

    def variable(self, name, value):
        """Write ``value`` to ``<name>.yaml`` and load its text into kernel variable ``name``."""
        self.path.state.joinpath("{}.yaml".format(name)).write_text(value)
        self.ipython_step_library.run(
            """{} = Path("{}").bytes().decode("utf8")""".format(
                name, "{}.yaml".format(name)))
        self.doc.step("variable", var_name=name, value=value)

    def code(self, command):
        """Run ``command`` in the kernel and record it as a code step."""
        self.ipython_step_library.run(command)
        self.doc.step("code", command=command)

    @validate(exception=Str())
    def raises_exception(self, command, exception, why=''):
        """
        Command raises exception.
        """
        import re
        self.error = self.ipython_step_library.run(
            command, swallow_exception=True).error
        if self.error is None:
            raise Exception("Expected exception, but got none")
        # Raw string: "\[" is an invalid escape in a regular string literal.
        # Keeps only the text after the last colour-coded traceback separator.
        full_exception = re.compile(r"(?:\x1bs?\[0m)+(?:\n+)+{0}".format(
            re.escape("\x1b[0;31m"))).split(self.error)[-1]
        # Split once: the exception text itself may contain the separator.
        exception_class_name, exception_text = full_exception.split(
            "\x1b[0m: ", 1)
        if self.settings.get("overwrite"):
            self.current_step.update(exception=str(exception_text))
        else:
            assert exception.strip(
            ) in exception_text, "UNEXPECTED:\n{0}".format(exception_text)
        self.doc.step(
            "exception",
            command=command,
            exception_class_name=exception_class_name,
            exception=exception_text,
            why=why,
        )

    def returns_true(self, command, why=''):
        """Assert ``command`` evaluates truthy in the kernel."""
        self.ipython_step_library.assert_true(command)
        self.doc.step("true", command=command, why=why)

    def should_be_equal(self, lhs='', rhs='', why=''):
        """Assert (via ``sure``) that ``lhs`` equals ``rhs`` in the kernel."""
        command = """({0}).should.be.equal({1})""".format(lhs, rhs)
        self.ipython_step_library.run(command)
        self.doc.step("true", command=command, why=why)

    def assert_true(self, command):
        """Assert ``command`` evaluates truthy in the kernel."""
        self.ipython_step_library.assert_true(command)
        self.doc.step("true", command=command)

    def assert_exception(self, command, exception):
        """Assert running ``command`` produces an error containing ``exception``."""
        error = self.ipython_step_library.run(command,
                                              swallow_exception=True).error
        assert exception.strip() in error
        self.doc.step("exception", command=command, exception=exception)

    def on_failure(self):
        """Optionally log the stacktrace and open a shell on failure."""
        if self.settings.get("pause_on_failure", True):
            if self.preconditions.get("launch_shell", False):
                self.services.log(message=self.stacktrace.to_template())
                self.shell()

    def shell(self):
        """Open an IPython console attached to the running kernel."""
        if hasattr(self, 'services'):
            self.services.start_interactive_mode()
            import sys
            import time
            time.sleep(0.5)
            kernel_file = path.join(path.expanduser("~"),
                                    ".ipython/profile_default/security/",
                                    self.ipython_kernel_filename)
            if path.exists(kernel_file):
                call([
                    sys.executable, "-m", "IPython", "console", "--existing",
                    "--no-confirm-exit",
                    kernel_file
                ])
            else:
                call([
                    sys.executable, "-m", "IPython", "console", "--existing",
                    self.ipython_kernel_filename
                ])
            self.services.stop_interactive_mode()

    def assert_file_contains(self, filename, contents):
        """Assert a state file's stripped text equals ``contents`` (stripped)."""
        assert self.path.state.joinpath(filename).bytes().decode(
            'utf8').strip() == contents.strip()
        self.doc.step("filename contains",
                      filename=filename,
                      contents=contents)

    def pause(self, message="Pause"):
        """Drop into an interactive IPython shell."""
        if hasattr(self, 'services'):
            self.services.start_interactive_mode()
        import IPython
        IPython.embed()
        if hasattr(self, 'services'):
            self.services.stop_interactive_mode()

    def on_success(self):
        """Save the rewritten story when overwrite mode is on."""
        if self.settings.get("overwrite"):
            self.new_story.save()

    def tear_down(self):
        """Close the kernel connection (best effort) and shut services down."""
        try:
            self.shutdown_connection()
        except Exception:
            # Best effort: set_up may have failed before a connection existed.
            pass
        if hasattr(self, 'services'):
            self.services.shutdown()
예제 #29
0
    TRAINING_MODEL = "trainingModel"
    HYPERPARAMETERS = "hyperparameters"
    VALIDATION_SCHEMA = "typeSchema"
    # customPredictor section is not used by DRUM,
    # it is a place holder if user wants to add some fields and read them on his own
    CUSTOM_PREDICTOR = "customPredictor"


MODEL_CONFIG_SCHEMA = Map({
    ModelMetadataKeys.NAME:
    Str(),
    ModelMetadataKeys.TYPE:
    Str(),
    ModelMetadataKeys.TARGET_TYPE:
    Str(),
    Optional(ModelMetadataKeys.ENVIRONMENT_ID):
    Str(),
    Optional(ModelMetadataKeys.VALIDATION):
    Map({
        "input": Str(),
        Optional("targetName"): Str()
    }),
    Optional(ModelMetadataKeys.MODEL_ID):
    Str(),
    Optional(ModelMetadataKeys.DESCRIPTION):
    Str(),
    Optional(ModelMetadataKeys.MAJOR_VERSION):
    Bool(),
    Optional(ModelMetadataKeys.INFERENCE_MODEL):
    Map({
        Optional("targetName"): Str(),
예제 #30
0
from strictyaml import Bool, Float, Int, load, Map, Optional, Str, YAMLError

from datarobot_drum.drum.common import RunMode
from datarobot_drum.drum.exceptions import DrumCommonException

CONFIG_FILENAME = "model-metadata.yaml"
DR_LINK_FORMAT = "{}/model-registry/custom-models/{}"
MODEL_LOGS_LINK_FORMAT = "{url}/projects/{project_id}/models/{model_id}/log"

# Schema for the model metadata YAML (presumably the CONFIG_FILENAME file —
# confirm against the loader).
schema = Map(
    {
        "name": Str(),
        "type": Str(),
        "environmentID": Str(),
        "targetType": Str(),
        "validation": Map({"input": Str(), Optional("targetName"): Str()}),
        Optional("modelID"): Str(),
        Optional("description"): Str(),
        Optional("majorVersion"): Bool(),
        Optional("inferenceModel"): Map(
            {
                "targetName": Str(),
                Optional("positiveClassLabel"): Str(),
                Optional("negativeClassLabel"): Str(),
                # A threshold is presumably fractional (e.g. 0.5), which
                # Int() rejects; Float() accepts both 0.5 and 1.
                Optional("predictionThreshold"): Float(),
            }
        ),
        Optional("trainingModel"): Map({Optional("trainOnProject"): Str()}),
    }
)