Ejemplo n.º 1
0
Archivo: base.py Proyecto: dmh43/datmo
class BaseCommand(object):
    """Base class for datmo CLI commands.

    Wires up the shared pieces every command needs: the datmo home
    directory, a CLI helper for user interaction, a logger, and the
    datmo argument parser.
    """

    def __init__(self, cli_helper):
        self.home = Config().home
        self.cli_helper = cli_helper
        self.logger = DatmoLogger.get_logger(__name__)
        self.parser = get_datmo_parser()

    def parse(self, args):
        """Parse raw CLI arguments into ``self.args``.

        argparse calls ``sys.exit()`` (raising SystemExit) when it fully
        handles the input itself, e.g. for --help; in that case
        ``self.args`` is set to the sentinel ``True`` so ``execute()``
        knows there is nothing left to do.

        Parameters
        ----------
        args : array[string]
            command arguments
        """
        try:
            self.display_usage_message(args)
            self.args = self.parser.parse_args(args)
        except SystemExit:
            # Parser already printed its own output and tried to exit.
            self.args = True

    def display_usage_message(self, args):
        """ Checks to see if --help or -h is passed in, and if so it calls our usage()
        if it exists.

        Since argparser thinks it is clever and automatically
        handles [--help, -h] we need a hook to be able to display
        our own usage notes before argparse

        Parameters
        ----------
        args : array[string]
            command arguments
        """
        # We only need to know whether either help flag is present; the
        # previous index bookkeeping added nothing.
        if ("--help" in args or "-h" in args) and hasattr(self, "usage"):
            self.usage()

    def execute(self):
        """
        Calls the method if it exists on this object, otherwise
        call a default method name (module name)

        Raises
        ------
        ClassMethodNotFound
            If the Class method is not found
        """
        # Sometimes eg(--help) the parser automagically handles the entire response
        # and calls exit.  If this happens, self.args is set to True
        # in base.parse.   Simply return True
        if self.args is True:
            return True

        # Fall back to the top-level "datmo" command when none was parsed.
        if getattr(self.args, "command", None) is None:
            self.args.command = "datmo"

        command_args = vars(self.args).copy()

        # Use the subcommand name if one was parsed, otherwise the command
        # name, to locate the handler method on this object.
        if command_args.get("subcommand") is not None:
            function_name = self.args.subcommand
        else:
            function_name = self.args.command

        try:
            method = getattr(self, function_name)
        except AttributeError:
            raise ClassMethodNotFound(
                __("error", "cli.general.method.not_found",
                   (self.args.command, function_name)))

        # Remove routing options the handler method does not care about.
        command_args.pop("command", None)
        command_args.pop("subcommand", None)

        # Guard against the attribute existing but being None.
        if method is None:
            raise ClassMethodNotFound(
                __("error", "cli.general.method.not_found",
                   (self.args.command, method)))

        return method(**command_args)
Ejemplo n.º 2
0
class TestSnapshotController():
    def setup_method(self):
        """Give every test a fresh datmo home inside a new temp directory."""
        self.environment_ids = []
        self.temp_dir = tempfile.mkdtemp(dir=test_datmo_dir)
        Config().set_home(self.temp_dir)

    def teardown_method(self):
        """Delete any environments created during the test.

        Skipped entirely when docker is unavailable, since environment
        deletion goes through the environment controller.

        Raises
        ------
        Exception
            If an environment could not be deleted.
        """
        if not check_docker_inactive(test_datmo_dir):
            self.__setup()
            self.environment_controller = EnvironmentController()
            # De-duplicate ids; a set iterates fine without the list() wrapper.
            for env_id in set(self.environment_ids):
                if not self.environment_controller.delete(env_id):
                    raise Exception(
                        "teardown failed to delete environment: %s" % env_id)

    def __setup(self):
        # Shared per-test initialization: set the datmo home, initialize a
        # project there, then build the controllers the tests exercise.
        # Order matters: the home must be set and the project initialized
        # before controllers are constructed.
        # NOTE(review): "controllers validate the project on construction"
        # is inferred from test_init_fail_project_not_init below — confirm
        # against the controller implementations.
        Config().set_home(self.temp_dir)
        self.project_controller = ProjectController()
        self.project_controller.init("test", "test description")
        self.task_controller = TaskController()
        self.snapshot_controller = SnapshotController()

    def test_init_fail_project_not_init(self):
        """Constructing SnapshotController without an initialized project fails."""
        Config().set_home(self.temp_dir)
        raised = False
        try:
            SnapshotController()
        except ProjectNotInitialized:
            raised = True
        assert raised

    def test_init_fail_invalid_path(self):
        """Constructing SnapshotController with a bogus home path fails."""
        Config().set_home("some_random_dir")
        raised = False
        try:
            SnapshotController()
        except InvalidProjectPath:
            raised = True
        assert raised

    def test_current_snapshot(self):
        """current_snapshot() raises on unstaged changes, then returns the latest."""
        self.__setup()
        # Nothing committed yet: the working tree counts as unstaged.
        raised = False
        try:
            self.snapshot_controller.current_snapshot()
        except UnstagedChanges:
            raised = True
        assert raised
        # After a snapshot exists it should be reported as current.
        created = self.__default_create()
        assert self.snapshot_controller.current_snapshot() == created

    def test_create_fail_no_message(self):
        """create() without a message must raise RequiredArgumentMissing."""
        self.__setup()
        raised = False
        try:
            self.snapshot_controller.create({})
        except RequiredArgumentMissing:
            raised = True
        assert raised

    def test_create_success_no_code(self):
        """create() succeeds with only a message, even without any code files."""
        self.__setup()
        snapshot = self.snapshot_controller.create(
            {"message": "my test snapshot"})
        assert snapshot

    def test_create_success_no_code_environment(self):
        """create() succeeds when only an environment definition exists."""
        self.__setup()
        # Drop a Dockerfile into the environment directory.
        dockerfile = os.path.join(
            self.project_controller.file_driver.environment_directory,
            "Dockerfile")
        with open(dockerfile, "wb") as handle:
            handle.write(to_bytes("FROM python:3.5-alpine"))

        snapshot = self.snapshot_controller.create(
            {"message": "my test snapshot"})
        assert snapshot

    def test_create_success_no_code_environment_files(self):
        """create() succeeds with an environment definition plus a tracked file."""
        self.__setup()
        dockerfile = os.path.join(
            self.project_controller.file_driver.environment_directory,
            "Dockerfile")
        with open(dockerfile, "wb") as handle:
            handle.write(to_bytes("FROM python:3.5-alpine"))

        tracked_file = os.path.join(
            self.project_controller.file_driver.files_directory, "test.txt")
        with open(tracked_file, "wb") as handle:
            handle.write(to_bytes("hello"))

        snapshot = self.snapshot_controller.create(
            {"message": "my test snapshot"})
        assert snapshot

    def test_create_no_environment_detected_in_file(self):
        """create() still succeeds when the only file carries no environment."""
        self.__setup()
        self.snapshot_controller.file_driver.create("filepath1")
        snapshot = self.snapshot_controller.create(
            {"message": "my test snapshot"})
        # Defaults are filled in for every component id.
        assert isinstance(snapshot, Snapshot)
        assert snapshot.code_id
        assert snapshot.environment_id
        assert snapshot.file_collection_id
        assert snapshot.config == {}
        assert snapshot.stats == {}

    def test_create_success_default_detected_in_file(self):
        """create() fills defaults when only a plain python script exists."""
        self.__setup()
        script_path = os.path.join(self.snapshot_controller.home, "script.py")
        with open(script_path, "wb") as handle:
            handle.write(to_bytes("import os\nimport sys\nprint('hello')\n"))

        snapshot = self.snapshot_controller.create(
            {"message": "my test snapshot"})

        assert isinstance(snapshot, Snapshot)
        assert snapshot.code_id
        assert snapshot.environment_id
        assert snapshot.file_collection_id
        assert snapshot.config == {}
        assert snapshot.stats == {}

    def test_create_success_default_env_def(self):
        """create() succeeds with a Dockerfile plus a script in the project."""
        self.__setup()
        dockerfile = os.path.join(
            self.project_controller.file_driver.environment_directory,
            "Dockerfile")
        with open(dockerfile, "wb") as handle:
            handle.write(to_bytes("FROM python:3.5-alpine"))

        script_path = os.path.join(self.snapshot_controller.home, "script.py")
        with open(script_path, "wb") as handle:
            handle.write(
                to_bytes("import numpy\nimport sklearn\nprint('hello')\n"))

        snapshot = self.snapshot_controller.create(
            {"message": "my test snapshot"})

        assert isinstance(snapshot, Snapshot)
        assert snapshot.code_id
        assert snapshot.environment_id
        assert snapshot.file_collection_id
        assert snapshot.config == {}
        assert snapshot.stats == {}

    def test_create_success_with_environment(self):
        """create() succeeds with an environment definition present.

        NOTE(review): this exercises the exact same scenario as
        test_create_success_default_env_def — consider differentiating or
        removing one of the two.
        """
        self.__setup()
        env_definition = os.path.join(
            self.project_controller.file_driver.environment_directory,
            "Dockerfile")
        with open(env_definition, "wb") as handle:
            handle.write(to_bytes("FROM python:3.5-alpine"))

        source_file = os.path.join(self.snapshot_controller.home, "script.py")
        with open(source_file, "wb") as handle:
            handle.write(
                to_bytes("import numpy\nimport sklearn\nprint('hello')\n"))

        created = self.snapshot_controller.create(
            {"message": "my test snapshot"})

        assert isinstance(created, Snapshot)
        assert created.code_id
        assert created.environment_id
        assert created.file_collection_id
        assert created.config == {}
        assert created.stats == {}

    def test_create_success_env_paths(self):
        """create() accepts an explicit "src>dest" environment_paths mapping."""
        self.__setup()
        # Environment definition lives outside the environment directory,
        # under an arbitrary name, and is mapped in via environment_paths.
        scratch_dir = os.path.join(self.snapshot_controller.home, "random_dir")
        os.makedirs(scratch_dir)
        definition_path = os.path.join(scratch_dir, "randomDockerfile")
        with open(definition_path, "wb") as handle:
            handle.write(to_bytes("FROM python:3.5-alpine"))

        snapshot = self.snapshot_controller.create({
            "message": "my test snapshot",
            "environment_paths": [definition_path + ">Dockerfile"],
        })

        assert isinstance(snapshot, Snapshot)
        assert snapshot.code_id
        assert snapshot.environment_id
        assert snapshot.file_collection_id
        assert snapshot.config == {}
        assert snapshot.stats == {}

    def test_create_success_default_env_def_duplicate(self):
        """Two create() calls with identical inputs return the same snapshot."""
        self.__setup()
        dockerfile = os.path.join(
            self.project_controller.file_driver.environment_directory,
            "Dockerfile")
        with open(dockerfile, "wb") as handle:
            handle.write(to_bytes("FROM python:3.5-alpine"))

        script_path = os.path.join(self.snapshot_controller.home, "script.py")
        with open(script_path, "wb") as handle:
            handle.write(
                to_bytes("import numpy\nimport sklearn\nprint('hello')\n"))

        first = self.snapshot_controller.create(
            {"message": "my test snapshot"})
        second = self.snapshot_controller.create(
            {"message": "my test snapshot"})

        # Nothing changed between the two calls, so every field must match.
        for attr in ("id", "code_id", "environment_id", "file_collection_id",
                     "config", "stats"):
            assert getattr(second, attr) == getattr(first, attr)

    def test_create_success_given_files_env_def_config_file_stats_file(self):
        """Snapshot created from explicit config/stats filepaths uses their contents.

        Builds a baseline snapshot, then adds project files plus config.json
        and stats.json, and verifies the second snapshot differs in code and
        file collection, shares the unchanged environment, and exposes the
        parsed config/stats dicts.
        """
        self.__setup()
        # Create environment definition
        env_def_path = os.path.join(
            self.project_controller.file_driver.environment_directory,
            "Dockerfile")
        with open(env_def_path, "wb") as f:
            f.write(to_bytes("FROM python:3.5-alpine"))

        # Creating a file in project folder
        test_filepath = os.path.join(self.snapshot_controller.home,
                                     "script.py")
        with open(test_filepath, "wb") as f:
            f.write(to_bytes("import numpy\n"))
            f.write(to_bytes("import sklearn\n"))
            f.write(to_bytes("print('hello')\n"))

        # Baseline snapshot taken before the extra files/config/stats exist.
        snapshot_obj = self.snapshot_controller.create(
            {"message": "my test snapshot"})

        # Create files to add
        _, project_directory_name = os.path.split(
            self.project_controller.file_driver.files_directory)
        self.snapshot_controller.file_driver.create(os.path.join(
            project_directory_name, "dirpath1"),
                                                    directory=True)
        self.snapshot_controller.file_driver.create(os.path.join(
            project_directory_name, "dirpath2"),
                                                    directory=True)
        self.snapshot_controller.file_driver.create(
            os.path.join(project_directory_name, "filepath1"))

        # Create config
        config_filepath = os.path.join(self.snapshot_controller.home,
                                       "config.json")
        with open(config_filepath, "wb") as f:
            f.write(to_bytes(str('{"foo":"bar"}')))

        # Create stats
        stats_filepath = os.path.join(self.snapshot_controller.home,
                                      "stats.json")
        with open(stats_filepath, "wb") as f:
            f.write(to_bytes(str('{"foo":"bar"}')))

        input_dict = {
            "message": "my test snapshot",
            "config_filepath": config_filepath,
            "stats_filepath": stats_filepath,
        }
        # Create snapshot in the project
        snapshot_obj_4 = self.snapshot_controller.create(input_dict)

        # Code and file collection changed; the environment did not.
        assert snapshot_obj_4 != snapshot_obj
        assert snapshot_obj_4.code_id != snapshot_obj.code_id
        assert snapshot_obj_4.environment_id == \
               snapshot_obj.environment_id
        assert snapshot_obj_4.file_collection_id != \
               snapshot_obj.file_collection_id
        assert snapshot_obj_4.config == {"foo": "bar"}
        assert snapshot_obj_4.stats == {"foo": "bar"}

    def test_create_success_given_files_env_def_different_config_stats(self):
        """Nonexistent config/stats filenames fall back to empty dicts.

        Same file/environment setup as the filepath-based variant, but the
        input dict names config/stats files that were never written under
        those names, so the resulting snapshot should carry empty config
        and stats.
        """
        self.__setup()
        # Create environment definition
        env_def_path = os.path.join(
            self.project_controller.file_driver.environment_directory,
            "Dockerfile")
        with open(env_def_path, "wb") as f:
            f.write(to_bytes("FROM python:3.5-alpine"))

        # Creating a file in project folder
        test_filepath = os.path.join(self.snapshot_controller.home,
                                     "script.py")
        with open(test_filepath, "wb") as f:
            f.write(to_bytes("import numpy\n"))
            f.write(to_bytes("import sklearn\n"))
            f.write(to_bytes("print('hello')\n"))

        # Baseline snapshot before the extra files are added.
        snapshot_obj = self.snapshot_controller.create(
            {"message": "my test snapshot"})

        # Create files to add
        _, project_directory_name = os.path.split(
            self.project_controller.file_driver.files_directory)
        self.snapshot_controller.file_driver.create(os.path.join(
            project_directory_name, "dirpath1"),
                                                    directory=True)
        self.snapshot_controller.file_driver.create(os.path.join(
            project_directory_name, "dirpath2"),
                                                    directory=True)
        self.snapshot_controller.file_driver.create(
            os.path.join(project_directory_name, "filepath1"))

        # Create config (written as config.json, NOT as "different_name")
        config_filepath = os.path.join(self.snapshot_controller.home,
                                       "config.json")
        with open(config_filepath, "wb") as f:
            f.write(to_bytes(str('{"foo":"bar"}')))

        # Create stats (written as stats.json, NOT as "different_name")
        stats_filepath = os.path.join(self.snapshot_controller.home,
                                      "stats.json")
        with open(stats_filepath, "wb") as f:
            f.write(to_bytes(str('{"foo":"bar"}')))

        # Test different config and stats inputs
        input_dict = {
            "message": "my test snapshot",
            "config_filename": "different_name",
            "stats_filename": "different_name",
        }

        # Create snapshot in the project
        snapshot_obj_1 = self.snapshot_controller.create(input_dict)

        # The named files do not exist, so config/stats stay empty.
        assert snapshot_obj_1 != snapshot_obj
        assert snapshot_obj_1.config == {}
        assert snapshot_obj_1.stats == {}

    def test_create_success_given_files_env_def_direct_config_stats(self):
        """Config/stats passed directly as dicts land on the snapshot unchanged."""
        self.__setup()
        dockerfile = os.path.join(
            self.project_controller.file_driver.environment_directory,
            "Dockerfile")
        with open(dockerfile, "wb") as handle:
            handle.write(to_bytes("FROM python:3.5-alpine"))

        # Register a couple of directories and a file with the file driver.
        _, files_dir_name = os.path.split(
            self.project_controller.file_driver.files_directory)
        file_driver = self.snapshot_controller.file_driver
        file_driver.create(
            os.path.join(files_dir_name, "dirpath1"), directory=True)
        file_driver.create(
            os.path.join(files_dir_name, "dirpath2"), directory=True)
        file_driver.create(os.path.join(files_dir_name, "filepath1"))

        script_path = os.path.join(self.snapshot_controller.home, "script.py")
        with open(script_path, "wb") as handle:
            handle.write(
                to_bytes("import numpy\nimport sklearn\nprint('hello')\n"))

        # Config and stats provided inline rather than via files.
        snapshot = self.snapshot_controller.create({
            "message": "my test snapshot",
            "config": {
                "foo": "bar"
            },
            "stats": {
                "foo": "bar"
            },
        })

        assert snapshot.config == {"foo": "bar"}
        assert snapshot.stats == {"foo": "bar"}

    @pytest_docker_environment_failed_instantiation(test_datmo_dir)
    def test_create_from_task(self):
        """End-to-end create_from_task() coverage against real docker task runs."""
        self.__setup()
        # 0) Test if fails with TaskNotComplete error
        # 1) Test if success with empty task files, results
        # 2) Test if success with task files, results, and message
        # 3) Test if success with message, label, config and stats
        # 4) Test if success with updated stats from after_snapshot_id and task_results

        # Create task in the project
        task_obj = self.task_controller.create()

        # 0) Test option 0 — the task has not run yet, so snapshotting fails.
        failed = False
        try:
            _ = self.snapshot_controller.create_from_task(
                message="my test snapshot", task_id=task_obj.id)
        except TaskNotComplete:
            failed = True
        assert failed

        # 1) Test option 1

        # Create task_dict
        task_command = ["sh", "-c", "echo test"]
        task_dict = {"command_list": task_command}

        # Create environment definition
        env_def_path = os.path.join(self.project_controller.home, "Dockerfile")
        with open(env_def_path, "wb") as f:
            f.write(to_bytes("FROM python:3.5-alpine"))

        updated_task_obj = self.task_controller.run(task_obj.id,
                                                    task_dict=task_dict)
        # Track the environment created by the run so teardown can delete it.
        after_snapshot_obj = self.task_controller.dal.snapshot.get_by_id(
            updated_task_obj.after_snapshot_id)
        environment_obj = self.task_controller.dal.environment.get_by_id(
            after_snapshot_obj.environment_id)
        self.environment_ids.append(environment_obj.id)

        snapshot_obj = self.snapshot_controller.create_from_task(
            message="my test snapshot", task_id=updated_task_obj.id)

        assert isinstance(snapshot_obj, Snapshot)
        assert snapshot_obj.id == updated_task_obj.after_snapshot_id
        assert snapshot_obj.message == "my test snapshot"
        assert snapshot_obj.stats == updated_task_obj.results
        assert snapshot_obj.visible == True

        # Create new task and corresponding dict
        task_obj = self.task_controller.create()
        task_command = ["sh", "-c", "echo accuracy:0.45"]
        task_dict = {"command_list": task_command}

        # Create environment definition
        env_def_path = os.path.join(self.project_controller.home, "Dockerfile")
        with open(env_def_path, "wb") as f:
            f.write(to_bytes("FROM python:3.5-alpine"))

        # Test the default values
        updated_task_obj = self.task_controller.run(task_obj.id,
                                                    task_dict=task_dict)
        after_snapshot_obj = self.task_controller.dal.snapshot.get_by_id(
            updated_task_obj.after_snapshot_id)
        environment_obj = self.task_controller.dal.environment.get_by_id(
            after_snapshot_obj.environment_id)
        self.environment_ids.append(environment_obj.id)

        # 2) Test option 2
        snapshot_obj = self.snapshot_controller.create_from_task(
            message="my test snapshot", task_id=updated_task_obj.id)

        assert isinstance(snapshot_obj, Snapshot)
        assert snapshot_obj.id == updated_task_obj.after_snapshot_id
        assert snapshot_obj.message == "my test snapshot"
        assert snapshot_obj.stats == updated_task_obj.results
        assert snapshot_obj.visible == True

        # 3) Test option 3 — explicit label, config and stats override.
        test_config = {"algo": "regression"}
        test_stats = {"accuracy": 0.9}
        snapshot_obj = self.snapshot_controller.create_from_task(
            message="my test snapshot",
            task_id=updated_task_obj.id,
            label="best",
            config=test_config,
            stats=test_stats)

        assert isinstance(snapshot_obj, Snapshot)
        assert snapshot_obj.id == updated_task_obj.after_snapshot_id
        assert snapshot_obj.message == "my test snapshot"
        assert snapshot_obj.label == "best"
        assert snapshot_obj.config == test_config
        assert snapshot_obj.stats == test_stats
        assert snapshot_obj.visible == True

        # 4) Test option 4 — stats merged from snapshot_dict and task results.
        test_config = {"algo": "regression"}
        test_stats = {"new_key": 0.9}
        task_obj_2 = self.task_controller.create()
        updated_task_obj_2 = self.task_controller.run(task_obj_2.id,
                                                      task_dict=task_dict,
                                                      snapshot_dict={
                                                          "config":
                                                          test_config,
                                                          "stats": test_stats
                                                      })
        after_snapshot_obj = self.task_controller.dal.snapshot.get_by_id(
            updated_task_obj_2.after_snapshot_id)
        environment_obj = self.task_controller.dal.environment.get_by_id(
            after_snapshot_obj.environment_id)
        self.environment_ids.append(environment_obj.id)

        snapshot_obj = self.snapshot_controller.create_from_task(
            message="my test snapshot",
            task_id=updated_task_obj_2.id,
            label="best")
        updated_stats_dict = {}
        updated_stats_dict.update(test_stats)
        # NOTE(review): this merges results from updated_task_obj (option 2's
        # task), not updated_task_obj_2. Both tasks run the same command here
        # so the values coincide — confirm this is intentional.
        updated_stats_dict.update(updated_task_obj.results)

        assert isinstance(snapshot_obj, Snapshot)
        assert snapshot_obj.id == updated_task_obj_2.after_snapshot_id
        assert snapshot_obj.message == "my test snapshot"
        assert snapshot_obj.label == "best"
        assert snapshot_obj.stats == updated_stats_dict
        assert snapshot_obj.visible == True

    def __default_create(self):
        """Create a snapshot from a standard set of files, env, config and stats."""
        # Register directories and files with the file driver.
        _, files_dir_name = os.path.split(
            self.project_controller.file_driver.files_directory)
        file_driver = self.snapshot_controller.file_driver
        file_driver.create(
            os.path.join(files_dir_name, "dirpath1"), directory=True)
        file_driver.create(
            os.path.join(files_dir_name, "dirpath2"), directory=True)
        file_driver.create(os.path.join(files_dir_name, "filepath1"))
        file_driver.create("filepath2")
        with open(os.path.join(self.snapshot_controller.home, "filepath2"),
                  "wb") as handle:
            handle.write(to_bytes("import sys\n"))

        # Environment definition.
        dockerfile = os.path.join(
            self.project_controller.file_driver.environment_directory,
            "Dockerfile")
        with open(dockerfile, "wb") as handle:
            handle.write(to_bytes("FROM python:3.5-alpine"))

        # Empty JSON config and stats files.
        config_filepath = os.path.join(self.snapshot_controller.home,
                                       "config.json")
        with open(config_filepath, "wb") as handle:
            handle.write(to_bytes("{}"))

        stats_filepath = os.path.join(self.snapshot_controller.home,
                                      "stats.json")
        with open(stats_filepath, "wb") as handle:
            handle.write(to_bytes("{}"))

        # Create snapshot in the project and hand it back to the caller.
        return self.snapshot_controller.create({
            "message": "my test snapshot",
            "config_filename": config_filepath,
            "stats_filename": stats_filepath,
        })

    def test_check_unstaged_changes(self):
        """check_unstaged_changes() raises before a snapshot, returns False after."""
        self.__setup()
        raised = False
        try:
            self.snapshot_controller.check_unstaged_changes()
        except UnstagedChanges:
            raised = True
        assert raised
        # Once everything is captured in a snapshot, nothing is unstaged.
        self.__default_create()
        assert self.snapshot_controller.check_unstaged_changes() == False

    def test_checkout(self):
        """checkout() to an earlier snapshot id succeeds."""
        self.__setup()
        first = self.__default_create()

        # Adding a file makes the next default create produce a new snapshot.
        self.snapshot_controller.file_driver.create("test")
        second = self.__default_create()
        assert second != first

        # Checkout back to the first snapshot by id.
        # TODO: Check for which snapshot we are on
        assert self.snapshot_controller.checkout(first.id) == True

    def test_list(self):
        """list() filtering, sorting and error behavior over two snapshots."""
        self.__setup()
        # Check for error if incorrect session given
        failed = False
        try:
            self.snapshot_controller.list(session_id="does_not_exist")
        except SessionDoesNotExist:
            failed = True
        assert failed

        # Create file to add to snapshot
        test_filepath_1 = os.path.join(self.snapshot_controller.home,
                                       "test.txt")
        with open(test_filepath_1, "wb") as f:
            f.write(to_bytes(str("test")))

        # Create snapshot in the project
        snapshot_obj_1 = self.__default_create()

        # Create file to add to second snapshot
        test_filepath_2 = os.path.join(self.snapshot_controller.home,
                                       "test2.txt")
        with open(test_filepath_2, "wb") as f:
            f.write(to_bytes(str("test2")))

        # Create second snapshot in the project
        snapshot_obj_2 = self.__default_create()

        # List all snapshots and ensure they exist
        result = self.snapshot_controller.list()

        assert len(result) == 2 and \
            snapshot_obj_1 in result and \
            snapshot_obj_2 in result

        # List all tasks regardless of filters in ascending
        result = self.snapshot_controller.list(sort_key='created_at',
                                               sort_order='ascending')

        assert len(result) == 2 and \
               snapshot_obj_1 in result and \
               snapshot_obj_2 in result
        assert result[0].created_at <= result[-1].created_at

        # List all tasks regardless of filters in descending
        result = self.snapshot_controller.list(sort_key='created_at',
                                               sort_order='descending')
        assert len(result) == 2 and \
               snapshot_obj_1 in result and \
               snapshot_obj_2 in result
        assert result[0].created_at >= result[-1].created_at

        # Wrong order being passed in
        failed = False
        try:
            _ = self.snapshot_controller.list(sort_key='created_at',
                                              sort_order='wrong_order')
        except InvalidArgumentType:
            failed = True
        assert failed

        # Wrong key and order being passed in
        failed = False
        try:
            _ = self.snapshot_controller.list(sort_key='wrong_key',
                                              sort_order='wrong_order')
        except InvalidArgumentType:
            failed = True
        assert failed

        # wrong key and right order being passed in — an unknown sort key is
        # tolerated and should produce the same set as the default ordering.
        expected_result = self.snapshot_controller.list(sort_key='created_at',
                                                        sort_order='ascending')
        result = self.snapshot_controller.list(sort_key='wrong_key',
                                               sort_order='ascending')
        expected_ids = [item.id for item in expected_result]
        ids = [item.id for item in result]
        assert set(expected_ids) == set(ids)

        # List all snapshots with session filter
        result = self.snapshot_controller.list(
            session_id=self.project_controller.current_session.id)

        assert len(result) == 2 and \
               snapshot_obj_1 in result and \
               snapshot_obj_2 in result

        # List snapshots with visible filter
        result = self.snapshot_controller.list(visible=False)
        assert len(result) == 0

        result = self.snapshot_controller.list(visible=True)
        assert len(result) == 2 and \
               snapshot_obj_1 in result and \
               snapshot_obj_2 in result

    def test_update(self):
        """Verify snapshot fields can be updated individually and in combination."""
        self.__setup()
        new_config = {"config_foo": "bar"}
        new_stats = {"stats_foo": "bar"}
        new_message = 'test_message'
        new_label = 'test_label'

        # Each entry is one set of keyword arguments to exercise. A fresh
        # snapshot is created per case, updated, re-fetched from the DAL,
        # and every field that was passed in must round-trip unchanged.
        update_cases = [
            # all of config, stats, message and label at once
            {
                "config": new_config,
                "stats": new_stats,
                "message": new_message,
                "label": new_label
            },
            # config and stats only
            {
                "config": new_config,
                "stats": new_stats
            },
            # message and label together
            {
                "message": new_message,
                "label": new_label
            },
            # message alone
            {
                "message": new_message
            },
            # label alone
            {
                "label": new_label
            },
        ]

        for case_kwargs in update_cases:
            created_obj = self.__default_create()

            # Apply the update for this case
            self.snapshot_controller.update(created_obj.id, **case_kwargs)

            # Re-read the persisted record and check each updated field
            refreshed_obj = self.snapshot_controller.dal.snapshot.get_by_id(
                created_obj.id)
            for field_name, expected_value in case_kwargs.items():
                assert getattr(refreshed_obj, field_name) == expected_value

    def test_get(self):
        """get() raises for unknown ids and returns existing snapshots intact."""
        self.__setup()

        # Looking up an id that was never created must raise DoesNotExist
        raised = False
        try:
            self.snapshot_controller.get("random")
        except DoesNotExist:
            raised = True
        assert raised

        # A freshly created snapshot is returned unchanged by get()
        created_obj = self.__default_create()
        fetched_obj = self.snapshot_controller.get(created_obj.id)
        assert created_obj == fetched_obj

    def test_get_files(self):
        """get_files() raises for unknown ids and returns open handles for
        the snapshot's file collection, honoring the requested mode.

        Raises are checked via the failed-flag pattern used throughout
        this file.
        """
        self.__setup()
        # Test failure case: unknown snapshot id must raise DoesNotExist
        failed = False
        try:
            self.snapshot_controller.get_files("random")
        except DoesNotExist:
            failed = True
        assert failed

        # Test success case (default mode "r")
        snapshot_obj = self.__default_create()
        result = self.snapshot_controller.get_files(snapshot_obj.id)
        file_collection_obj = self.task_controller.dal.file_collection.get_by_id(
            snapshot_obj.file_collection_id)

        file_names = [item.name for item in result]

        assert len(result) == 1
        for item in result:
            assert isinstance(item, TextIOWrapper)
            assert item.mode == "r"
        assert os.path.join(self.task_controller.home, ".datmo", "collections",
                            file_collection_obj.filehash,
                            "filepath1") in file_names

        # Success case with an explicit append mode
        result = self.snapshot_controller.get_files(snapshot_obj.id, mode="a")

        # BUG FIX: recompute file_names from the *new* result; previously the
        # stale list from the first call was re-checked, so the path assertion
        # below did not actually exercise the mode="a" result.
        file_names = [item.name for item in result]

        assert len(result) == 1
        for item in result:
            assert isinstance(item, TextIOWrapper)
            assert item.mode == "a"
        assert os.path.join(self.task_controller.home, ".datmo", "collections",
                            file_collection_obj.filehash,
                            "filepath1") in file_names

    def test_delete(self):
        """Deleting a snapshot removes it from the data access layer."""
        self.__setup()
        # Create a snapshot to delete
        snapshot_obj = self.__default_create()

        # Delete it through the controller
        delete_ok = self.snapshot_controller.delete(snapshot_obj.id)

        # Retrieval of the deleted id must now raise EntityNotFound
        lookup_raised = False
        try:
            self.snapshot_controller.dal.snapshot.get_by_id(snapshot_obj.id)
        except EntityNotFound:
            lookup_raised = True

        assert delete_ok == True
        assert lookup_raised == True
# Ejemplo n.º 3 (0)
class TestTaskController():
    """Unit tests for TaskController: init, create, run, list, get,
    get_files, delete and stop.

    Docker-dependent tests are guarded with
    pytest_docker_environment_failed_instantiation.
    """

    def setup_method(self):
        # Each test gets an isolated datmo home directory
        self.temp_dir = tempfile.mkdtemp(dir=test_datmo_dir)

    def teardown_method(self):
        pass

    def __setup(self):
        """Initialize a project in the temp dir and the controllers under test."""
        Config().set_home(self.temp_dir)
        self.project_controller = ProjectController()
        self.project_controller.init("test", "test description")
        self.environment_controller = EnvironmentController()
        self.task_controller = TaskController()

    def test_init_fail_project_not_init(self):
        """TaskController must refuse to start in an uninitialized project."""
        Config().set_home(self.temp_dir)
        failed = False
        try:
            TaskController()
        except ProjectNotInitialized:
            failed = True
        assert failed

    def test_init_fail_invalid_path(self):
        """TaskController must refuse to start with a nonexistent home path."""
        test_home = "some_random_dir"
        Config().set_home(test_home)
        failed = False
        try:
            TaskController()
        except InvalidProjectPath:
            failed = True
        assert failed

    def test_create(self):
        """create() returns a Task with timestamps populated."""
        self.__setup()
        # Create task in the project
        task_obj = self.task_controller.create()

        assert isinstance(task_obj, Task)
        assert task_obj.created_at
        assert task_obj.updated_at

    @pytest_docker_environment_failed_instantiation(test_datmo_dir)
    def test_run_helper(self):
        """_run_helper launches a container for two option sets and logs output."""
        self.__setup()
        # TODO: Try out more options (see below)
        # Create environment_driver id
        env_def_path = os.path.join(self.project_controller.home, "Dockerfile")
        with open(env_def_path, "wb") as f:
            f.write(to_bytes("FROM python:3.5-alpine"))

        paths = [env_def_path]
        environment_obj = self.environment_controller.create({"paths": paths})

        # Set log filepath
        log_filepath = os.path.join(self.task_controller.home, "test.log")

        # create volume to mount
        temp_test_dirpath = os.path.join(self.temp_dir, "temp")
        os.makedirs(temp_test_dirpath)

        # Test option set 1: foreground run, no api
        random_name = ''.join([
            random.choice(string.ascii_letters + string.digits)
            for _ in range(32)
        ])
        options_dict = {
            "command": ["sh", "-c", "echo accuracy:0.45"],
            "ports": ["8888:8888"],
            "name": random_name,
            "volumes": {
                temp_test_dirpath: {
                    'bind': '/task/',
                    'mode': 'rw'
                }
            },
            "mem_limit": "4g",
            "detach": False,
            "stdin_open": False,
            "tty": False,
            "api": False,
            "interactive": False
        }

        return_code, run_id, logs = \
            self.task_controller._run_helper(environment_obj.id,
                                  options_dict, log_filepath)
        assert return_code == 0
        assert run_id and \
               self.task_controller.environment_driver.get_container(run_id)
        assert logs and \
               os.path.exists(log_filepath)
        self.task_controller.environment_driver.stop_remove_containers_by_term(
            term=random_name)

        # Test option set 2: detached run with api enabled
        random_name_2 = ''.join([
            random.choice(string.ascii_letters + string.digits)
            for _ in range(32)
        ])
        options_dict = {
            "command": ["sh", "-c", "echo accuracy:0.45"],
            "ports": ["8888:8888"],
            "name": random_name_2,
            "volumes": {
                temp_test_dirpath: {
                    'bind': '/task/',
                    'mode': 'rw'
                }
            },
            "mem_limit": "4g",
            "detach": True,
            "stdin_open": False,
            "tty": False,
            "api": True,
            "interactive": False
        }

        return_code, run_id, logs = \
            self.task_controller._run_helper(environment_obj.id,
                                  options_dict, log_filepath)
        assert return_code == 0
        assert run_id and \
               self.task_controller.environment_driver.get_container(run_id)
        assert logs and \
               os.path.exists(log_filepath)
        self.task_controller.environment_driver.stop_remove_containers_by_term(
            term=random_name_2)

    def test_parse_logs_for_results(self):
        """_parse_logs_for_results extracts 'key : value' pairs from log text."""
        self.__setup()
        test_logs = """
        this is a log
        accuracy is good
        accuracy : 0.94
        this did not work
        validation : 0.32
        model_type : logistic regression
        """
        result = self.task_controller._parse_logs_for_results(test_logs)

        assert isinstance(result, dict)
        assert result['accuracy'] == "0.94"
        assert result['validation'] == "0.32"
        assert result['model_type'] == "logistic regression"

        # Logs without any key:value pairs yield no result
        test_logs = """test"""
        result = self.task_controller._parse_logs_for_results(test_logs)
        assert result is None

    @pytest_docker_environment_failed_instantiation(test_datmo_dir)
    def test_run(self):
        """End-to-end run() cases: missing command, success, conflicts, snapshots."""
        self.__setup()
        # 0) Test failure case without command and without interactive
        # 1) Test success case with default values and env def file
        # 2) Test failure case if running same task (conflicting containers)
        # 3) Test failure case if running same task with snapshot_dict (conflicting containers)
        # 4) Test success case with snapshot_dict
        # 5) Test success case with saved file during task run

        # TODO: look into log filepath randomness, sometimes logs are not written

        # Create task in the project
        task_obj = self.task_controller.create()

        # Create environment definition
        env_def_path = os.path.join(self.project_controller.home, "Dockerfile")
        with open(env_def_path, "wb") as f:
            f.write(to_bytes("FROM python:3.5-alpine"))

        # 0) Test option 0
        failed = False
        try:
            self.task_controller.run(task_obj.id)
        except RequiredArgumentMissing:
            failed = True
        assert failed

        failed = False
        try:
            self.task_controller.run(task_obj.id,
                                     task_dict={
                                         "command": None,
                                         "interactive": False,
                                         "ports": None
                                     })
        except RequiredArgumentMissing:
            failed = True
        assert failed

        # Create task_dict
        task_command = ["sh", "-c", "echo accuracy:0.45"]
        task_dict = {"command_list": task_command}

        # 1) Test option 1
        updated_task_obj = self.task_controller.run(task_obj.id,
                                                    task_dict=task_dict)

        assert isinstance(updated_task_obj, Task)
        assert task_obj.id == updated_task_obj.id

        assert updated_task_obj.before_snapshot_id
        assert updated_task_obj.ports == None
        assert updated_task_obj.interactive == False
        assert updated_task_obj.task_dirpath
        assert updated_task_obj.log_filepath
        assert updated_task_obj.start_time

        assert updated_task_obj.after_snapshot_id
        assert updated_task_obj.run_id
        assert updated_task_obj.logs
        assert "accuracy" in updated_task_obj.logs
        assert updated_task_obj.results
        assert updated_task_obj.results == {"accuracy": "0.45"}
        assert updated_task_obj.status == "SUCCESS"
        assert updated_task_obj.end_time
        assert updated_task_obj.duration

        self.task_controller.stop(task_obj.id)

        # 2) Test option 2
        failed = False
        try:
            self.task_controller.run(task_obj.id)
        except TaskRunError:
            failed = True
        assert failed

        # 3) Test option 3

        # Create files to add
        self.project_controller.file_driver.create("dirpath1", directory=True)
        self.project_controller.file_driver.create("dirpath2", directory=True)
        self.project_controller.file_driver.create("filepath1")

        # Snapshot dictionary
        snapshot_dict = {
            "paths": [
                os.path.join(self.project_controller.home, "dirpath1"),
                os.path.join(self.project_controller.home, "dirpath2"),
                os.path.join(self.project_controller.home, "filepath1")
            ],
        }

        # Run a basic task in the project
        failed = False
        try:
            self.task_controller.run(task_obj.id, snapshot_dict=snapshot_dict)
        except TaskRunError:
            failed = True
        assert failed

        # Test when the specific task id is already RUNNING
        # Create task in the project
        task_obj_1 = self.task_controller.create()
        self.task_controller.dal.task.update({
            "id": task_obj_1.id,
            "status": "RUNNING"
        })
        # Create environment_driver definition
        env_def_path = os.path.join(self.project_controller.home, "Dockerfile")
        with open(env_def_path, "wb") as f:
            f.write(to_bytes("FROM python:3.5-alpine"))

        failed = False
        try:
            self.task_controller.run(task_obj_1.id, task_dict=task_dict)
        except TaskRunError:
            failed = True
        assert failed

        # 4) Test option 4

        # Create a new task in the project
        task_obj_2 = self.task_controller.create()

        # Run another task in the project
        updated_task_obj_2 = self.task_controller.run(
            task_obj_2.id, task_dict=task_dict, snapshot_dict=snapshot_dict)

        assert isinstance(updated_task_obj_2, Task)
        assert task_obj_2.id == updated_task_obj_2.id

        assert updated_task_obj_2.before_snapshot_id
        assert updated_task_obj_2.ports == None
        assert updated_task_obj_2.interactive == False
        assert updated_task_obj_2.task_dirpath
        assert updated_task_obj_2.log_filepath
        assert updated_task_obj_2.start_time

        assert updated_task_obj_2.after_snapshot_id
        assert updated_task_obj_2.run_id
        assert updated_task_obj_2.logs
        assert "accuracy" in updated_task_obj_2.logs
        assert updated_task_obj_2.results
        assert updated_task_obj_2.results == {"accuracy": "0.45"}
        assert updated_task_obj_2.status == "SUCCESS"
        assert updated_task_obj_2.end_time
        assert updated_task_obj_2.duration

        self.task_controller.stop(task_obj_2.id)

        # 5) Test option 5

        # Create a basic script
        # (fails w/ no environment)
        test_filepath = os.path.join(self.temp_dir, "script.py")
        with open(test_filepath, "wb") as f:
            f.write(to_bytes("import os\n"))
            f.write(to_bytes("import shutil\n"))
            f.write(to_bytes("print('hello')\n"))
            f.write(to_bytes("print(' accuracy: 0.56 ')\n"))
            f.write(
                to_bytes(
                    "with open(os.path.join('/task', 'new_file.txt'), 'a') as f:\n"
                ))
            f.write(to_bytes("    f.write('my test file')\n"))

        # Create task in the project
        task_obj_2 = self.task_controller.create()

        # Create task_dict
        task_command = ["python", "script.py"]
        task_dict = {"command_list": task_command}

        # Create environment definition
        env_def_path = os.path.join(self.project_controller.home, "Dockerfile")
        with open(env_def_path, "wb") as f:
            f.write(to_bytes("FROM python:3.5-alpine"))

        updated_task_obj_2 = self.task_controller.run(task_obj_2.id,
                                                      task_dict=task_dict)

        assert isinstance(updated_task_obj_2, Task)
        assert updated_task_obj_2.before_snapshot_id
        assert updated_task_obj_2.ports == None
        assert updated_task_obj_2.interactive == False
        assert updated_task_obj_2.task_dirpath
        assert updated_task_obj_2.log_filepath
        assert updated_task_obj_2.start_time

        assert updated_task_obj_2.after_snapshot_id
        assert updated_task_obj_2.run_id
        assert updated_task_obj_2.logs
        assert "accuracy" in updated_task_obj_2.logs
        assert updated_task_obj_2.results
        assert updated_task_obj_2.results == {"accuracy": "0.56"}
        assert updated_task_obj_2.status == "SUCCESS"
        assert updated_task_obj_2.end_time
        assert updated_task_obj_2.duration

        self.task_controller.stop(task_obj_2.id)

        # test if after snapshot has the file written
        after_snapshot_obj = self.task_controller.dal.snapshot.get_by_id(
            updated_task_obj_2.after_snapshot_id)
        file_collection_obj = self.task_controller.dal.file_collection.get_by_id(
            after_snapshot_obj.file_collection_id)
        files_absolute_path = os.path.join(self.task_controller.home,
                                           file_collection_obj.path)

        assert os.path.isfile(os.path.join(files_absolute_path, "task.log"))
        assert os.path.isfile(os.path.join(files_absolute_path,
                                           "new_file.txt"))

    def test_list(self):
        """list() supports sorting, bad-argument errors, and session filtering."""
        self.__setup()
        # Create tasks in the project
        task_obj_1 = self.task_controller.create()
        task_obj_2 = self.task_controller.create()

        # List all tasks regardless of filters
        result = self.task_controller.list()

        assert len(result) == 2 and \
               task_obj_1 in result and \
               task_obj_2 in result

        # List all tasks regardless of filters in ascending
        result = self.task_controller.list(sort_key='created_at',
                                           sort_order='ascending')

        assert len(result) == 2 and \
               task_obj_1 in result and \
               task_obj_2 in result
        assert result[0].created_at <= result[-1].created_at

        # List all tasks regardless of filters in descending
        result = self.task_controller.list(sort_key='created_at',
                                           sort_order='descending')
        assert len(result) == 2 and \
               task_obj_1 in result and \
               task_obj_2 in result
        assert result[0].created_at >= result[-1].created_at

        # Wrong order being passed in
        failed = False
        try:
            _ = self.task_controller.list(sort_key='created_at',
                                          sort_order='wrong_order')
        except InvalidArgumentType:
            failed = True
        assert failed

        # Wrong key and order being passed in
        failed = False
        try:
            _ = self.task_controller.list(sort_key='wrong_key',
                                          sort_order='wrong_order')
        except InvalidArgumentType:
            failed = True
        assert failed

        # wrong key and right order being passed in
        expected_result = self.task_controller.list(sort_key='created_at',
                                                    sort_order='ascending')
        result = self.task_controller.list(sort_key='wrong_key',
                                           sort_order='ascending')
        expected_ids = [item.id for item in expected_result]
        ids = [item.id for item in result]
        assert set(expected_ids) == set(ids)

        # List all tasks and filter by session
        result = self.task_controller.list(
            session_id=self.project_controller.current_session.id)

        assert len(result) == 2 and \
               task_obj_1 in result and \
               task_obj_2 in result

    def test_get(self):
        """get() raises for unknown ids and returns existing tasks intact."""
        self.__setup()
        # Test failure for no task
        failed = False
        try:
            self.task_controller.get("random")
        except DoesNotExist:
            failed = True
        assert failed

        # Test success for task
        task_obj = self.task_controller.create()
        task_obj_returned = self.task_controller.get(task_obj.id)
        assert task_obj == task_obj_returned

    @pytest_docker_environment_failed_instantiation(test_datmo_dir)
    def test_get_files(self):
        """get_files() returns handles for a completed run's file collection."""
        self.__setup()
        # Test failure case
        failed = False
        try:
            self.task_controller.get_files("random")
        except DoesNotExist:
            failed = True
        assert failed

        # Create task in the project
        task_obj = self.task_controller.create()

        # Create environment definition
        env_def_path = os.path.join(self.project_controller.home, "Dockerfile")
        with open(env_def_path, "wb") as f:
            f.write(to_bytes("FROM python:3.5-alpine"))

        # Create file to add
        self.project_controller.file_driver.create("dirpath1", directory=True)
        self.project_controller.file_driver.create(
            os.path.join("dirpath1", "filepath1"))

        # Snapshot dictionary
        snapshot_dict = {
            "paths": [
                os.path.join(self.project_controller.home, "dirpath1",
                             "filepath1")
            ],
        }

        # Create task_dict
        task_command = ["sh", "-c", "echo accuracy:0.45"]
        task_dict = {"command_list": task_command}

        # Test the default values
        updated_task_obj = self.task_controller.run(
            task_obj.id, task_dict=task_dict, snapshot_dict=snapshot_dict)

        # TODO: Test case for during run and before_snapshot run
        # Get files for the task after run is complete (default)
        result = self.task_controller.get_files(updated_task_obj.id)

        after_snapshot_obj = self.task_controller.dal.snapshot.get_by_id(
            updated_task_obj.after_snapshot_id)
        file_collection_obj = self.task_controller.dal.file_collection.get_by_id(
            after_snapshot_obj.file_collection_id)

        file_names = [item.name for item in result]

        assert len(result) == 2
        for item in result:
            assert isinstance(item, TextIOWrapper)
            assert item.mode == "r"
        assert os.path.join(self.task_controller.home, ".datmo", "collections",
                            file_collection_obj.filehash,
                            "task.log") in file_names
        assert os.path.join(self.task_controller.home, ".datmo", "collections",
                            file_collection_obj.filehash,
                            "filepath1") in file_names

        # Get files for the task after run is complete for different mode
        result = self.task_controller.get_files(updated_task_obj.id, mode="a")

        # BUG FIX: recompute file_names from the new result; previously the
        # stale list from the default-mode call was re-checked below, so the
        # path assertions did not exercise the mode="a" result at all.
        file_names = [item.name for item in result]

        assert len(result) == 2
        for item in result:
            assert isinstance(item, TextIOWrapper)
            assert item.mode == "a"
        assert os.path.join(self.task_controller.home, ".datmo", "collections",
                            file_collection_obj.filehash,
                            "task.log") in file_names
        assert os.path.join(self.task_controller.home, ".datmo", "collections",
                            file_collection_obj.filehash,
                            "filepath1") in file_names

        self.task_controller.stop(task_obj.id)

    @pytest_docker_environment_failed_instantiation(test_datmo_dir)
    def test_delete(self):
        """Deleting a task removes it from the data access layer."""
        self.__setup()
        # Create tasks in the project
        task_obj = self.task_controller.create()

        # Delete task from the project
        result = self.task_controller.delete(task_obj.id)

        # Check if task retrieval throws error
        # BUG FIX: look the id up in the *task* DAL, not the snapshot DAL.
        # The previous dal.snapshot.get_by_id(task_obj.id) raised
        # EntityNotFound regardless of the delete, making the check vacuous.
        thrown = False
        try:
            self.task_controller.dal.task.get_by_id(task_obj.id)
        except EntityNotFound:
            thrown = True

        assert result == True and \
               thrown == True

    @pytest_docker_environment_failed_instantiation(test_datmo_dir)
    def test_stop_failure(self):
        """stop() validates its arguments and rejects unknown task ids."""
        self.__setup()
        # 1) Test required arguments not provided
        # 2) Test too many arguments found
        # 3) Test incorrect task id given

        # 1) Test option 1
        failed = False
        try:
            self.task_controller.stop()
        except RequiredArgumentMissing:
            failed = True
        assert failed

        # 2) Test option 2
        failed = False
        try:
            self.task_controller.stop(task_id="test_task_id", all=True)
        except TooManyArgumentsFound:
            failed = True
        assert failed

        # 3) Test option 3
        thrown = False
        try:
            self.task_controller.stop(task_id="incorrect_task_id")
        except EntityNotFound:
            thrown = True
        assert thrown

    @pytest_docker_environment_failed_instantiation(test_datmo_dir)
    def test_stop_success(self):
        """stop() works for a single task id and for all tasks at once."""
        self.__setup()
        # 1) Test stop with task_id
        # 2) Test stop with all given

        # Create task in the project
        task_obj = self.task_controller.create()

        # Create environment driver definition
        env_def_path = os.path.join(self.project_controller.home, "Dockerfile")
        with open(env_def_path, "wb") as f:
            f.write(to_bytes("FROM python:3.5-alpine"))

        # Create task_dict
        task_command = ["sh", "-c", "echo accuracy:0.45"]
        task_dict = {"command_list": task_command}

        # 1) Test option 1
        updated_task_obj = self.task_controller.run(task_obj.id,
                                                    task_dict=task_dict)
        task_id = updated_task_obj.id
        result = self.task_controller.stop(task_id=task_id)
        after_task_obj = self.task_controller.dal.task.get_by_id(task_id)

        assert result
        assert after_task_obj.status == "STOPPED"

        # 2) Test option 2
        task_obj_2 = self.task_controller.create()
        _ = self.task_controller.run(task_obj_2.id, task_dict=task_dict)
        result = self.task_controller.stop(all=True)
        all_task_objs = self.task_controller.dal.task.query({})

        assert result
        for task_obj in all_task_objs:
            assert task_obj.status == "STOPPED"
# Ejemplo n.º 4 (0)
class WorkspaceCommand(ProjectCommand):
    """CLI commands that launch interactive workspaces (notebook, rstudio)
    as datmo tasks.

    Both public commands only differ in their task_dict and message key;
    the shared launch/stop lifecycle lives in _run_workspace.
    """

    def __init__(self, cli_helper):
        super(WorkspaceCommand, self).__init__(cli_helper)

    @Helper.notify_environment_active(TaskController)
    @Helper.notify_no_project_found
    def notebook(self, **kwargs):
        """Launch a Jupyter notebook workspace as a task.

        Returns the updated task object, or False if the run failed.
        """
        self.task_controller = TaskController()
        self.cli_helper.echo(__("info", "cli.workspace.notebook"))
        task_dict = {
            "ports": ["8888:8888"],
            "command_list": ["jupyter", "notebook"],
            "mem_limit": kwargs["mem_limit"]
        }
        return self._run_workspace("cli.workspace.notebook", task_dict,
                                   kwargs)

    @Helper.notify_environment_active(TaskController)
    @Helper.notify_no_project_found
    def rstudio(self, **kwargs):
        """Launch an RStudio server workspace as a task.

        Returns the updated task object, or False if the run failed.
        """
        self.task_controller = TaskController()
        self.cli_helper.echo(__("info", "cli.workspace.rstudio"))
        task_dict = {
            "ports": ["8787:8787"],
            "command_list": [
                "/usr/lib/rstudio-server/bin/rserver", "--server-daemonize=0",
                "--server-app-armor-enabled=0"
            ],
            "mem_limit":
            kwargs["mem_limit"]
        }
        return self._run_workspace("cli.workspace.rstudio", task_dict, kwargs)

    def _run_workspace(self, error_key, task_dict, kwargs):
        """Create and run a workspace task, always stopping it afterwards.

        Parameters
        ----------
        error_key : str
            message key echoed on failure (e.g. "cli.workspace.notebook")
        task_dict : dict
            ports/command/mem_limit options for the task run
        kwargs : dict
            original CLI keyword arguments; environment selection is
            extracted from here

        Returns
        -------
        Task or bool
            the updated task object, or False when the run raised
        """
        # Build the snapshot input; environment id and paths are mutually
        # exclusive, so only one may be copied into snapshot_dict
        snapshot_dict = {}
        if kwargs.get("environment_id", None) or kwargs.get(
                "environment_paths", None):
            mutually_exclusive_args = ["environment_id", "environment_paths"]
            mutually_exclusive(mutually_exclusive_args, kwargs, snapshot_dict)

        # Create the task object
        task_obj = self.task_controller.create()

        updated_task_obj = task_obj
        # Pass in the task
        try:
            updated_task_obj = self.task_controller.run(
                task_obj.id, snapshot_dict=snapshot_dict, task_dict=task_dict)
        except Exception as e:
            self.logger.error("%s %s" % (e, task_dict))
            self.cli_helper.echo(__("error", error_key, task_obj.id))
            return False
        finally:
            # Always stop the workspace container, success or failure
            self.cli_helper.echo(__("info", "cli.task.run.stop"))
            self.task_controller.stop(updated_task_obj.id)
            self.cli_helper.echo(
                __("info", "cli.task.run.complete", updated_task_obj.id))

        return updated_task_obj
# Ejemplo n.º 5 (0)
class TaskCommand(ProjectCommand):
    """CLI command group for datmo tasks: ``run``, ``ls`` and ``stop``.

    Each handler instantiates its own TaskController lazily so the
    controller is only created once a task subcommand actually executes.
    """

    def __init__(self, cli_helper):
        super(TaskCommand, self).__init__(cli_helper)

    def task(self):
        # Bare `task` invocation: show the subcommand help text and succeed.
        self.parse(["task", "--help"])
        return True

    @Helper.notify_environment_active(TaskController)
    @Helper.notify_no_project_found
    def run(self, **kwargs):
        """Create a task and run the given command inside the environment.

        Parameters
        ----------
        **kwargs
            parsed CLI arguments; uses 'environment_id' /
            'environment_paths' (mutually exclusive), 'ports',
            'interactive', 'mem_limit' and 'cmd'

        Returns
        -------
        Task or bool
            the updated task entity, or False if the run raised
        """
        self.task_controller = TaskController()
        self.cli_helper.echo(__("info", "cli.task.run"))
        # Create input dictionaries
        snapshot_dict = {}

        # Environment: only one of environment_id / environment_paths may be
        # given; mutually_exclusive copies the winner into snapshot_dict
        if kwargs.get("environment_id", None) or kwargs.get(
                "environment_paths", None):
            mutually_exclusive_args = ["environment_id", "environment_paths"]
            mutually_exclusive(mutually_exclusive_args, kwargs, snapshot_dict)
        task_dict = {
            "ports": kwargs['ports'],
            "interactive": kwargs['interactive'],
            "mem_limit": kwargs['mem_limit']
        }
        # A string command is tokenized with shlex except on Windows, where
        # the raw string is passed through; a list is used as argv directly.
        # NOTE(review): if 'cmd' is None on a non-Windows host neither key is
        # set -- presumably the controller supplies a default; confirm.
        if not isinstance(kwargs['cmd'], list):
            if platform.system() == "Windows":
                task_dict['command'] = kwargs['cmd']
            elif isinstance(kwargs['cmd'], basestring):
                task_dict['command_list'] = shlex.split(kwargs['cmd'])
        else:
            task_dict['command_list'] = kwargs['cmd']

        # Create the task object
        task_obj = self.task_controller.create()

        updated_task_obj = task_obj
        try:
            # Pass in the task
            updated_task_obj = self.task_controller.run(
                task_obj.id, snapshot_dict=snapshot_dict, task_dict=task_dict)
        except Exception as e:
            self.logger.error("%s %s" % (e, task_dict))
            self.cli_helper.echo("%s" % e)
            self.cli_helper.echo(__("error", "cli.task.run", task_obj.id))
            return False
        finally:
            # Always stop the task so nothing is left running, on success
            # and on failure alike.
            self.cli_helper.echo(__("info", "cli.task.run.stop"))
            self.task_controller.stop(updated_task_obj.id)
            self.cli_helper.echo(
                __("info", "cli.task.run.complete", updated_task_obj.id))

        return updated_task_obj

    @Helper.notify_no_project_found
    def ls(self, **kwargs):
        """List tasks for a session, newest first; optionally write to a file.

        Parameters
        ----------
        **kwargs
            'session_id' (defaults to the current session), 'format'
            (defaults to "table"), 'download', 'download_path'

        Returns
        -------
        list
            the task entities that were listed
        """
        self.task_controller = TaskController()
        session_id = kwargs.get('session_id',
                                self.task_controller.current_session.id)
        print_format = kwargs.get('format', "table")
        download = kwargs.get('download', None)
        download_path = kwargs.get('download_path', None)
        # Get all task meta information
        task_objs = self.task_controller.list(session_id,
                                              sort_key='created_at',
                                              sort_order='descending')
        header_list = [
            "id", "start time", "duration (s)", "command", "status", "results"
        ]
        item_dict_list = []
        for task_obj in task_objs:
            task_results_printable = printable_object(task_obj.results)
            item_dict_list.append({
                "id":
                task_obj.id,
                "command":
                printable_object(task_obj.command),
                "status":
                printable_object(task_obj.status),
                "results":
                task_results_printable,
                "start time":
                prettify_datetime(task_obj.start_time),
                "duration (s)":
                printable_object(task_obj.duration)
            })
        if download:
            if not download_path:
                # download to current working directory with timestamp
                # (milliseconds since the Unix epoch, to keep names unique)
                current_time = datetime.utcnow()
                epoch_time = datetime.utcfromtimestamp(0)
                current_time_unix_time_ms = (
                    current_time - epoch_time).total_seconds() * 1000.0
                download_path = os.path.join(
                    self.task_controller.home,
                    "task_ls_" + str(current_time_unix_time_ms))
            self.cli_helper.print_items(header_list,
                                        item_dict_list,
                                        print_format=print_format,
                                        output_path=download_path)
            return task_objs
        self.cli_helper.print_items(header_list,
                                    item_dict_list,
                                    print_format=print_format)
        return task_objs

    @Helper.notify_environment_active(TaskController)
    @Helper.notify_no_project_found
    def stop(self, **kwargs):
        """Stop one task by id, or all tasks.

        Exactly one of 'id' / 'all' must be provided.

        Returns
        -------
        bool
            the controller's stop result, or False if stopping raised

        Raises
        ------
        RequiredArgumentMissing
            if neither 'id' nor 'all' was given
        """
        self.task_controller = TaskController()
        input_dict = {}
        mutually_exclusive(["id", "all"], kwargs, input_dict)
        if "id" in input_dict:
            self.cli_helper.echo(__("info", "cli.task.stop", input_dict['id']))
        elif "all" in input_dict:
            self.cli_helper.echo(__("info", "cli.task.stop.all"))
        else:
            raise RequiredArgumentMissing()
        try:
            if "id" in input_dict:
                result = self.task_controller.stop(task_id=input_dict['id'])
                if not result:
                    self.cli_helper.echo(
                        __("error", "cli.task.stop", input_dict['id']))
                else:
                    self.cli_helper.echo(
                        __("info", "cli.task.stop.success", input_dict['id']))
            if "all" in input_dict:
                result = self.task_controller.stop(all=input_dict['all'])
                if not result:
                    self.cli_helper.echo(__("error", "cli.task.stop.all"))
                else:
                    self.cli_helper.echo(
                        __("info", "cli.task.stop.all.success"))
            return result
        except Exception:
            if "id" in input_dict:
                self.cli_helper.echo(
                    __("error", "cli.task.stop", input_dict['id']))
            if "all" in input_dict:
                self.cli_helper.echo(__("error", "cli.task.stop.all"))
            return False
Ejemplo n.º 6
0
def run(command, env=None, gpu=False, mem_limit=None):
    """Run the code or script inside a datmo-managed environment.

    The project must already exist; create one with::

        $ datmo init

    Parameters
    ----------
    command : str or list
        the command to be run in environment. this can be either a string or list
    env : str or list, optional
        the absolute file path for the environment definition path. this can be either a string or list
        (default is None, which will defer to the environment to find a default environment,
        or will fail if not found)
    gpu: boolean
        try to run task on GPU (if available)
    mem_limit : string, optional
        maximum amount of memory the environment can use
        (these options take a positive integer, followed by a suffix of b, k, m, g, to indicate bytes, kilobytes,
        megabytes, or gigabytes. memory limit is contrained by total memory of the VM in which docker runs)

    Returns
    -------
    Task
        returns a Task entity as defined above

    Examples
    --------
    You can use this function within a project repository to run tasks
    in the following way.

    >>> import datmo
    >>> datmo.task.run(command="python script.py")
    >>> datmo.task.run(command="python script.py", env='Dockerfile')
    """
    task_controller = TaskController()

    # Environment definition: a list is passed through as-is, a single
    # path is wrapped in a list, and a falsy value is omitted entirely.
    snapshot_dict = {}
    if isinstance(env, list):
        snapshot_dict["environment_paths"] = env
    elif env:
        snapshot_dict["environment_paths"] = [env]

    # Command: a list is used as argv directly; a string is tokenized with
    # shlex except on Windows, where the raw string is passed through.
    task_dict = {}
    if isinstance(command, list):
        task_dict["command_list"] = command
    elif platform.system() == "Windows":
        task_dict["command"] = command
    elif isinstance(command, basestring):
        task_dict["command_list"] = shlex.split(command)

    task_dict["gpu"] = gpu
    task_dict["mem_limit"] = mem_limit

    # Create the core task, run it, and wrap the result in the client entity.
    core_task_obj = task_controller.create()
    updated_core_task_obj = task_controller.run(core_task_obj.id,
                                                snapshot_dict=snapshot_dict,
                                                task_dict=task_dict)
    return Task(updated_core_task_obj)
Ejemplo n.º 7
0
class TaskCommand(ProjectCommand):
    """CLI command group for datmo tasks: ``run``, ``ls`` and ``stop``.

    Registers the ``task`` argparse subparser and its subcommands, and
    holds a TaskController bound to the project home.
    """

    def __init__(self, home, cli_helper):
        super(TaskCommand, self).__init__(home, cli_helper)

        task_parser = self.subparsers.add_parser("task", help="Task module")
        subcommand_parsers = task_parser.add_subparsers(title="subcommands", dest="subcommand")

        # Task run arguments
        run = subcommand_parsers.add_parser("run", help="Run task")
        run.add_argument("--gpu", dest="gpu", action="store_true",
                         help="Boolean if you want to run using GPUs")
        run.add_argument("--ports", nargs="*", dest="ports", type=str, help="""
            Network port mapping during task (e.g. 8888:8888). Left is the host machine port and right
            is the environment port available during a run.
        """)
        # run.add_argument("--data", nargs="*", dest="data", type=str, help="Path for data to be used during the Task")
        run.add_argument("--env-def", dest="environment_definition_filepath", default="",
                         nargs="?", type=str,
                         help="Pass in the Dockerfile with which you want to build the environment")
        run.add_argument("--interactive", dest="interactive", action="store_true",
                         help="Run the environment in interactive mode (keeps STDIN open)")
        run.add_argument("cmd", nargs="?", default=None)

        # Task list arguments
        ls = subcommand_parsers.add_parser("ls", help="List tasks")
        ls.add_argument("--session-id", dest="session_id", default=None, nargs="?", type=str,
                         help="Pass in the session id to list the tasks in that session")

        # Task stop arguments
        stop = subcommand_parsers.add_parser("stop", help="Stop tasks")
        stop.add_argument("--id", dest="id", default=None, type=str, help="Task ID to stop")

        self.task_controller = TaskController(home=home)

    def run(self, **kwargs):
        """Create a task from the parsed CLI arguments and run it.

        Returns
        -------
        Task or bool
            the updated task entity, or False if the run raised
        """
        self.cli_helper.echo(__("info", "cli.task.run"))

        # Create input dictionaries
        snapshot_dict = {
            "environment_definition_filepath":
                kwargs['environment_definition_filepath']
        }

        # A string command is tokenized with shlex except on Windows, where
        # the raw string is passed through unchanged.  (The original code
        # had a no-op `kwargs['cmd'] = kwargs['cmd']` Windows branch.)
        if not isinstance(kwargs['cmd'], list):
            if platform.system() != "Windows":
                kwargs['cmd'] = shlex.split(kwargs['cmd'])

        task_dict = {
            "gpu": kwargs['gpu'],
            "ports": kwargs['ports'],
            "interactive": kwargs['interactive'],
            "command": kwargs['cmd']
        }

        # Create the task object
        task_obj = self.task_controller.create(task_dict)

        # Pass in the task
        try:
            updated_task_obj = self.task_controller.run(task_obj.id, snapshot_dict=snapshot_dict)
        # BUGFIX: was a bare `except:`, which also swallowed SystemExit and
        # KeyboardInterrupt; only unexpected runtime errors are reported here.
        except Exception:
            self.cli_helper.echo(__("error",
                                    "cli.task.run",
                                    task_obj.id))
            return False
        return updated_task_obj

    def ls(self, **kwargs):
        """Print a table of tasks for the given (or current) session.

        Returns
        -------
        bool
            True once the table has been echoed
        """
        session_id = kwargs.get('session_id',
                                self.task_controller.current_session.id)
        # Get all snapshot meta information
        header_list = ["id", "command", "status", "gpu", "created at"]
        t = prettytable.PrettyTable(header_list)
        task_objs = self.task_controller.list(session_id)
        for task_obj in task_objs:
            t.add_row([task_obj.id, task_obj.command, task_obj.status, task_obj.gpu,
                       task_obj.created_at.strftime("%Y-%m-%d %H:%M:%S")])
        self.cli_helper.echo(t)

        return True

    def stop(self, **kwargs):
        """Stop the task with the given id.

        Returns
        -------
        bool
            the controller's stop result, or False if stopping raised
        """
        task_id = kwargs.get('id', None)
        self.cli_helper.echo(__("info",
                                "cli.task.stop",
                                task_id))
        try:
            result = self.task_controller.stop(task_id)
            if not result:
                self.cli_helper.echo(__("error",
                                        "cli.task.stop",
                                        task_id))
            return result
        # BUGFIX: was a bare `except:`; catch only Exception so SystemExit
        # and KeyboardInterrupt still propagate.
        except Exception:
            self.cli_helper.echo(__("error",
                                    "cli.task.stop",
                                    task_id))
            return False