Example #1
class TestProjectController():
    def setup_method(self):
        self.temp_dir = tempfile.mkdtemp(dir=test_datmo_dir)
        Config().set_home(self.temp_dir)
        self.project_controller = ProjectController()
        self.environment_ids = []

    def teardown_method(self):
        if not check_docker_inactive(test_datmo_dir):
            self.project_controller = ProjectController()
            if self.project_controller.is_initialized:
                self.environment_controller = EnvironmentController()
                for env_id in list(set(self.environment_ids)):
                    if not self.environment_controller.delete(env_id):
                        raise Exception

    def test_init_failure_none(self):
        # Test failed case
        failed = False
        try:
            self.project_controller.init(None, None)
        except ValidationFailed:
            failed = True
        assert failed

    def test_init_failure_empty_str(self):
        # Test failed case
        failed = False
        try:
            self.project_controller.init("", "")
        except ValidationFailed:
            failed = True
        assert failed
        assert not self.project_controller.code_driver.is_initialized
        assert not self.project_controller.file_driver.is_initialized

    def test_init_failure_git_code_driver(self):
        # Create a HEAD.lock file in .git to make GitCodeDriver.init() fail
        if self.project_controller.code_driver.type == "git":
            git_dir = os.path.join(
                self.project_controller.code_driver.filepath, ".git")
            os.makedirs(git_dir)
            with open(os.path.join(git_dir, "HEAD.lock"), "ab+") as f:
                f.write(to_bytes("test"))
            failed = False
            try:
                self.project_controller.init("test1", "test description")
            except Exception:
                failed = True
            assert failed
            assert not self.project_controller.code_driver.is_initialized
            assert not self.project_controller.file_driver.is_initialized

    def test_init_success(self):
        result = self.project_controller.init("test1", "test description")

        # Tested with is_initialized
        assert self.project_controller.model.name == "test1"
        assert self.project_controller.model.description == "test description"
        assert result and self.project_controller.is_initialized

        # Changeable by user, not tested in is_initialized
        assert self.project_controller.current_session.name == "default"

    # TODO: Test lower level functions (DAL, JSONStore, etc for interruptions)
    # def test_init_with_interruption(self):
    #     # Reinitializing after timed interruption during init
    #     @timeout_decorator.timeout(0.001, use_signals=False)
    #     def timed_init_with_interruption():
    #         result = self.project_controller.init("test1", "test description")
    #         return result
    #
    #     failed = False
    #     try:
    #         timed_init_with_interruption()
    #     except timeout_decorator.timeout_decorator.TimeoutError:
    #         failed = True
    #     # Tested with is_initialized
    #     assert failed
    #
    #     # Reperforming init after a wait of 2 seconds
    #     time.sleep(2)
    #     result = self.project_controller.init("test2", "test description")
    #     # Tested with is_initialized
    #     assert self.project_controller.model.name == "test2"
    #     assert self.project_controller.model.description == "test description"
    #     assert result and self.project_controller.is_initialized
    #
    #     # Changeable by user, not tested in is_initialized
    #     assert self.project_controller.current_session.name == "default"

    def test_init_reinit_failure_empty_str(self):
        _ = self.project_controller.init("test1", "test description")
        failed = False
        try:
            self.project_controller.init("", "")
        except Exception:
            failed = True
        assert failed
        assert self.project_controller.model.name == "test1"
        assert self.project_controller.model.description == "test description"
        assert self.project_controller.code_driver.is_initialized
        assert self.project_controller.file_driver.is_initialized

    def test_init_reinit_success(self):
        _ = self.project_controller.init("test1", "test description")
        # Test functionality for re-initializing the project
        result = self.project_controller.init("anything", "else")

        assert self.project_controller.model.name == "anything"
        assert self.project_controller.model.description == "else"
        assert result == True

    def test_cleanup_no_environment(self):
        self.project_controller.init("test2", "test description")
        result = self.project_controller.cleanup()

        assert not self.project_controller.code_driver.is_initialized
        assert not self.project_controller.file_driver.is_initialized
        # Ensure that containers built with this image do not exist
        # assert not self.project_controller.environment_driver.list_containers(filters={
        #     "ancestor": image_id
        # })
        assert result == True

    @pytest_docker_environment_failed_instantiation(test_datmo_dir)
    def test_cleanup_with_environment(self):
        self.project_controller.init("test2", "test description")
        result = self.project_controller.cleanup()

        assert not self.project_controller.code_driver.is_initialized
        assert not self.project_controller.file_driver.is_initialized
        assert not self.project_controller.environment_driver.list_images(
            "datmo-test2")
        # Ensure that containers built with this image do not exist
        # assert not self.project_controller.environment_driver.list_containers(filters={
        #     "ancestor": image_id
        # })
        assert result == True

    def test_status_basic(self):
        self.project_controller.init("test3", "test description")
        status_dict, current_snapshot, latest_snapshot_user_generated, latest_snapshot_auto_generated, unstaged_code, unstaged_environment, unstaged_files = \
            self.project_controller.status()

        assert status_dict
        assert isinstance(status_dict, dict)
        assert status_dict['name'] == "test3"
        assert status_dict['description'] == "test description"
        assert isinstance(status_dict['config'], dict)
        assert not current_snapshot
        assert not latest_snapshot_user_generated
        assert not latest_snapshot_auto_generated
        assert unstaged_code  # no files, but unstaged because blank commit id has not yet been created (no initial snapshot)
        assert not unstaged_environment
        assert not unstaged_files

    @pytest_docker_environment_failed_instantiation(test_datmo_dir)
    def test_status_snapshot_task(self):
        self.project_controller.init("test4", "test description")
        self.snapshot_controller = SnapshotController()
        self.task_controller = TaskController()

        # Create files to add
        self.snapshot_controller.file_driver.create("dirpath1", directory=True)
        self.snapshot_controller.file_driver.create("dirpath2", directory=True)
        self.snapshot_controller.file_driver.create("filepath1")

        # Create environment definition
        env_def_path = os.path.join(self.snapshot_controller.home,
                                    "Dockerfile")
        with open(env_def_path, "wb") as f:
            f.write(to_bytes("FROM python:3.5-alpine"))

        environment_paths = [env_def_path]

        # Create config
        config_filepath = os.path.join(self.snapshot_controller.home,
                                       "config.json")
        with open(config_filepath, "wb") as f:
            f.write(to_bytes(str("{}")))

        # Create stats
        stats_filepath = os.path.join(self.snapshot_controller.home,
                                      "stats.json")
        with open(stats_filepath, "wb") as f:
            f.write(to_bytes(str("{}")))

        input_dict = {
            "message":
                "my test snapshot",
            "paths": [
                os.path.join(self.snapshot_controller.home, "dirpath1"),
                os.path.join(self.snapshot_controller.home, "dirpath2"),
                os.path.join(self.snapshot_controller.home, "filepath1")
            ],
            "environment_paths":
                environment_paths,
            "config_filename":
                config_filepath,
            "stats_filename":
                stats_filepath,
        }

        # Create snapshot in the project, then wait, and try status
        first_snapshot = self.snapshot_controller.create(input_dict)

        status_dict, current_snapshot, latest_snapshot_user_generated, latest_snapshot_auto_generated, unstaged_code, unstaged_environment, unstaged_files = \
            self.project_controller.status()

        assert status_dict
        assert isinstance(status_dict, dict)
        assert status_dict['name'] == "test4"
        assert status_dict['description'] == "test description"
        assert isinstance(status_dict['config'], dict)
        assert not current_snapshot  # snapshot was created from other environments and files (so user is not on any current snapshot)
        assert isinstance(latest_snapshot_user_generated, Snapshot)
        assert latest_snapshot_user_generated == first_snapshot
        assert not latest_snapshot_auto_generated
        assert not unstaged_code
        assert not unstaged_environment
        assert not unstaged_files

        # Create and run a task and test if task is shown
        first_task = self.task_controller.create()

        # Create task_dict
        task_command = ["sh", "-c", "echo accuracy:0.45"]
        task_dict = {"command_list": task_command}

        updated_first_task = self.task_controller.run(
            first_task.id, task_dict=task_dict)
        before_snapshot_obj = self.task_controller.dal.snapshot.get_by_id(
            updated_first_task.before_snapshot_id)
        after_snapshot_obj = self.task_controller.dal.snapshot.get_by_id(
            updated_first_task.after_snapshot_id)
        before_environment_obj = self.task_controller.dal.environment.get_by_id(
            before_snapshot_obj.environment_id)
        after_environment_obj = self.task_controller.dal.environment.get_by_id(
            after_snapshot_obj.environment_id)
        assert before_environment_obj == after_environment_obj
        self.environment_ids.append(after_environment_obj.id)

        status_dict, current_snapshot, latest_snapshot_user_generated, latest_snapshot_auto_generated, unstaged_code, unstaged_environment, unstaged_files = \
            self.project_controller.status()

        assert status_dict
        assert isinstance(status_dict, dict)
        assert status_dict['name'] == "test4"
        assert status_dict['description'] == "test description"
        assert isinstance(status_dict['config'], dict)
        assert isinstance(current_snapshot, Snapshot)
        assert isinstance(latest_snapshot_user_generated, Snapshot)
        assert latest_snapshot_user_generated == first_snapshot
        assert isinstance(latest_snapshot_auto_generated, Snapshot)
        # current snapshot is the before snapshot for the run
        assert current_snapshot == before_snapshot_obj
        assert current_snapshot != latest_snapshot_auto_generated
        assert current_snapshot != latest_snapshot_user_generated
        # latest autogenerated snapshot is the after snapshot id
        assert latest_snapshot_auto_generated == after_snapshot_obj
        assert latest_snapshot_auto_generated != latest_snapshot_user_generated
        # user generated snapshot is not associated with any before or after snapshot
        assert latest_snapshot_user_generated != before_snapshot_obj
        assert latest_snapshot_user_generated != after_snapshot_obj
        assert not unstaged_code
        assert not unstaged_environment
        assert not unstaged_files
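
A minimal usage sketch, not taken from the source, showing how the seven-value tuple returned by ProjectController.status() is typically unpacked (mirroring the assertions above); project_controller is assumed to be an already-initialized ProjectController:

status_dict, current_snapshot, latest_snapshot_user_generated, \
    latest_snapshot_auto_generated, unstaged_code, unstaged_environment, \
    unstaged_files = project_controller.status()

# status_dict carries project metadata ("name", "description", "config"),
# the snapshot slots hold Snapshot objects or None, and the unstaged_* flags
# report pending code, environment, and file changes.
if unstaged_code or unstaged_environment or unstaged_files:
    print("project has unstaged changes; consider creating a snapshot")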
Example #2
class TaskController(BaseController):
    """TaskController inherits from BaseController and manages business logic associated with tasks
    within the project.

    Parameters
    ----------
    home : str
        home path of the project

    Attributes
    ----------
    environment : datmo.core.controller.environment.environment.EnvironmentController
        used to create the environment if a new definition file is given
    snapshot : datmo.core.controller.snapshot.SnapshotController
        used to create snapshots before and after tasks

    Methods
    -------
    create()
        creates a Task object with the permanent parameters
    _run_helper(environment_id, options, log_filepath)
        helper for run to start the environment and run it with the appropriate parameters
    run(task_id, snapshot_dict=None, task_dict=None)
        runs the task and tracks the run, logs, inputs, and outputs
    list(session_id=None, sort_key=None, sort_order=None)
        lists all tasks within the project given filters
    delete(task_id)
        deletes the specified task from the project
    """

    def __init__(self):
        super(TaskController, self).__init__()
        self.environment = EnvironmentController()
        self.snapshot = SnapshotController()
        self.spinner = Spinner()

        if not self.is_initialized:
            raise ProjectNotInitialized(
                __("error", "controller.task.__init__"))

    def create(self):
        """Create Task object

        Returns
        -------
        Task
            object entity for Task (datmo.core.entity.task.Task)
        """

        # Validate Inputs
        create_dict = {
            "model_id": self.model.id,
            "session_id": self.current_session.id
        }

        try:
            # Create Task
            self.spinner.start()
            task_obj = self.dal.task.create(Task(create_dict))
        finally:
            self.spinner.stop()
        return task_obj

    def _run_helper(self, environment_id, options, log_filepath):
        """Run environment with parameters

        Parameters
        ----------
        environment_id : str
            the environment id for definition
        options : dict
            can include the following values:

            command : list
            ports : list
                Here are some example ports used for common applications.
                   *  'jupyter notebook' - 8888
                   *  flask API - 5000
                   *  tensorboard - 6006
                An example input for the above would be ["8888:8888", "5000:5000", "6006:6006"]
                which maps the running host port (right) to that of the environment (left)
            name : str
            volumes : dict
            mem_limit : str
            workspace : str
            detach : bool
            stdin_open : bool
            tty : bool
        log_filepath : str
            absolute filepath to the log file

        Returns
        -------
        return_code : int
            system return code of the environment that was run
        run_id : str
            id of the environment run (different from environment id)
        logs : str
            output logs from the run
        """
        # Run container with options provided
        run_options = {
            "command": options.get('command', None),
            "ports": options.get('ports', None),
            "name": options.get('name', None),
            "volumes": options.get('volumes', None),
            "mem_limit": options.get('mem_limit', None),
            "gpu": options.get('gpu', False),
            "detach": options.get('detach', False),
            "stdin_open": options.get('stdin_open', False),
            "tty": options.get('tty', False),
            "api": False,
        }
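        # Illustrative options (hypothetical values), matching the docstring above:
        #   {"command": ["jupyter", "notebook"], "ports": ["8888:8888"],
        #    "name": "my-notebook-task", "detach": True}
        # Any key missing from `options` falls back to the default chosen here.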
        workspace = options.get('workspace', None)
        self.environment.build(environment_id, workspace)
        # Run container with environment
        return_code, run_id, logs = self.environment.run(
            environment_id, run_options, log_filepath)

        return return_code, run_id, logs

    def _parse_logs_for_results(self, logs):
        """Parse log string to extract results and return dictionary.

        Each log line must be of the form "key:value"; surrounding whitespace is ignored.
        If splitting a line on ":" yields anything other than exactly two items, the line
        is not recorded as a key/value result.

        Note
        ----
        If the same key is found multiple times in the logs, the last occurring
        one will be the one that is saved.

        Parameters
        ----------
        logs : str
            raw string value of output logs

        Returns
        -------
        dict or None
            dictionary to represent results from task
        """
        results = {}
        for line in logs.split("\n"):
            split_line = line.split(":")
            if len(split_line) == 2:
                results[split_line[0].strip()] = split_line[1].strip()
        if results == {}:
            results = None
        return results
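    # Illustrative parse of hypothetical log text: the lines "accuracy:0.45" and
    # "epoch: 3" yield {"accuracy": "0.45", "epoch": "3"} (values stay strings),
    # a line such as "10:45:02 starting run" splits into three parts and is
    # skipped, and logs with no "key:value" lines return None.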

    def run(self, task_id, snapshot_dict=None, task_dict=None):
        """Run a task with parameters. If dictionary specified, create a new task with new run parameters.
        Snapshot objects are created before and after the task to keep track of the state. During the run,
        you can access task outputs using environment variable DATMO_TASK_DIR or `/task` which points to
        location for the task files. Create config.json, stats.json and any weights or any file such
        as graphs and visualizations within that directory for quick access

        Parameters
        ----------
        task_id : str
            id for the task you would like to run
        snapshot_dict : dict
            set of parameters used to create a snapshot (see SnapshotController for details).
            The default is None, in which case a dictionary with `visible` set to False is used
            so that the auto-generated snapshot stays hidden. NOTE: `visible` is always forced to
            False regardless of any value the user provides.
        task_dict : dict
            set of parameters to characterize the task run
            (default is None, which translates to {}; see datmo.core.entity.task.Task for more details on inputs)

        Returns
        -------
        Task
            the Task object which completed its run with updated parameters

        Raises
        ------
        TaskRunError
            If there is any error in creating files for the task or downstream errors
        """
        # Ensure visible=False is present in the snapshot dictionary
        if not snapshot_dict:
            snapshot_dict = {"visible": False}
        else:
            snapshot_dict['visible'] = False

        if not task_dict:
            task_dict = {}
        # Obtain Task to run
        task_obj = self.dal.task.get_by_id(task_id)

        # Ensure that at least one of command, command_list, or interactive is present in task_dict
        important_task_args = ["command", "command_list", "interactive"]
        if not task_dict.get('command', task_obj.command) and \
            not task_dict.get('command_list', task_obj.command_list) and \
                not task_dict.get('interactive', task_obj.interactive):
            raise RequiredArgumentMissing(
                __("error", "controller.task.run.arg",
                   " or ".join(important_task_args)))

        if task_obj.status is None:
            task_obj.status = "RUNNING"
        else:
            raise TaskRunError(
                __("error", "cli.run.run.already_running", task_obj.id))
        # Create Task directory for user during run
        task_dirpath = os.path.join(".datmo", "tasks", task_obj.id)
        try:
            _ = self.file_driver.create(task_dirpath, directory=True)
        except Exception:
            raise TaskRunError(
                __("error", "controller.task.run", task_dirpath))
        # Create the before snapshot prior to execution
        before_snapshot_dict = snapshot_dict.copy()
        before_snapshot_dict[
            'message'] = "autogenerated snapshot created before task %s is run" % task_obj.id
        before_snapshot_obj = self.snapshot.create(before_snapshot_dict)
        # Update the task with pre-execution parameters; prefer command_list and fall back to the string command
        # A given command_list overrides a string command
        if task_dict.get('command_list', task_obj.command_list):
            task_dict['command'] = " ".join(
                task_dict.get('command_list', task_obj.command_list))
        else:
            if task_dict.get('command', task_obj.command):
                task_dict['command_list'] = shlex.split(
                    task_dict.get('command', task_obj.command))
            elif not task_dict.get('interactive', task_obj.interactive):
                # If it is not interactive and no command is given, there is no task to run
                raise TaskNoCommandGiven()
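        # e.g. a string command "python train.py --lr 0.1" (hypothetical) becomes
        # command_list ["python", "train.py", "--lr", "0.1"] via shlex.split, while
        # a provided command_list is joined back into the single command string above.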

        validate("create_task", task_dict)
        task_obj = self.dal.task.update({
            "id":
                task_obj.id,
            "before_snapshot_id":
                task_dict.get('before_snapshot_id', before_snapshot_obj.id),
            "command":
                task_dict.get('command', task_obj.command),
            "command_list":
                task_dict.get('command_list', task_obj.command_list),
            "gpu":
                task_dict.get('gpu', False),
            "mem_limit":
                task_dict.get('mem_limit', None),
            "workspace":
                task_dict.get('workspace', None),
            "interactive":
                task_dict.get('interactive', task_obj.interactive),
            "detach":
                task_dict.get('detach', task_obj.detach),
            "ports":
                task_dict.get('ports', task_obj.ports),
            "task_dirpath":
                task_dict.get('task_dirpath', task_dirpath),
            "log_filepath":
                task_dict.get('log_filepath',
                              os.path.join(task_dirpath, "task.log")),
            "start_time":
                task_dict.get('start_time', datetime.utcnow()),
            "status":
                task_obj.status
        })

        # Copy over files from the before_snapshot file collection to task dir
        file_collection_obj =  \
            self.dal.file_collection.get_by_id(before_snapshot_obj.file_collection_id)
        self.file_driver.copytree(
            os.path.join(self.home, file_collection_obj.path),
            os.path.join(self.home, task_obj.task_dirpath))

        return_code, run_id, logs = 0, None, None

        try:
            # Set the parameters set in the task
            if task_obj.detach and task_obj.interactive:
                raise TaskInteractiveDetachError(
                    __("error", "controller.task.run.args.detach.interactive"))

            environment_run_options = {
                "command": task_obj.command_list,
                "ports": [] if task_obj.ports is None else task_obj.ports,
                "name": "datmo-task-" + self.model.id + "-" + task_obj.id,
                "volumes": {
                    os.path.join(self.home, task_obj.task_dirpath): {
                        'bind': '/task/',
                        'mode': 'rw'
                    },
                    self.home: {
                        'bind': '/home/',
                        'mode': 'rw'
                    }
                },
                "mem_limit": task_obj.mem_limit,
                "workspace": task_obj.workspace,
                "gpu": task_obj.gpu,
                "detach": task_obj.detach,
                "stdin_open": task_obj.interactive,
                "tty": task_obj.interactive,
                "api": False
            }
            # Run environment via the helper function
            return_code, run_id, logs =  \
                self._run_helper(before_snapshot_obj.environment_id,
                                 environment_run_options,
                                 os.path.join(self.home, task_obj.log_filepath))

        except Exception as e:
            return_code = 1
            logs += "Error running task: %" % e.message
        finally:
            # Create the after snapshot after execution is completed with new paths
            after_snapshot_dict = snapshot_dict.copy()
            after_snapshot_dict[
                'message'] = "autogenerated snapshot created after task %s is run" % task_obj.id

            # Add in absolute paths from running task directory
            absolute_task_dir_path = os.path.join(self.home,
                                                  task_obj.task_dirpath)
            absolute_paths = []
            for item in os.listdir(absolute_task_dir_path):
                path = os.path.join(absolute_task_dir_path, item)
                if os.path.isfile(path) or os.path.isdir(path):
                    absolute_paths.append(path)
            after_snapshot_dict.update({
                "paths": absolute_paths,
                "environment_id": before_snapshot_obj.environment_id,
            })
            after_snapshot_obj = self.snapshot.create(after_snapshot_dict)

            # (optional) Remove temporary task directory path
            # Update the task with post-execution parameters
            end_time = datetime.utcnow()
            duration = (end_time - task_obj.start_time).total_seconds()
            update_task_dict = {
                "id": task_obj.id,
                "after_snapshot_id": after_snapshot_obj.id,
                "logs": logs,
                "status": "SUCCESS" if return_code == 0 else "FAILED",
                # "results": task_obj.results, # TODO: update during run
                "end_time": end_time,
                "duration": duration
            }
            if logs is not None:
                update_task_dict["results"] = self._parse_logs_for_results(
                    logs)
            if run_id is not None:
                update_task_dict["run_id"] = run_id
            return self.dal.task.update(update_task_dict)

    def list(self, session_id=None, sort_key=None, sort_order=None):
        query = {}
        if session_id:
            try:
                self.dal.session.get_by_id(session_id)
            except EntityNotFound:
                raise SessionDoesNotExist(
                    __("error", "controller.task.list", session_id))
            query['session_id'] = session_id
        return self.dal.task.query(query, sort_key, sort_order)

    def get(self, task_id):
        """Get task object and return

        Parameters
        ----------
        task_id : str
            id for the task you would like to get

        Returns
        -------
        datmo.core.entity.task.Task
            core task object

        Raises
        ------
        DoesNotExist
            task does not exist
        """
        try:
            return self.dal.task.get_by_id(task_id)
        except EntityNotFound:
            raise DoesNotExist()

    def get_files(self, task_id, mode="r"):
        """Get list of file objects for task id. It will look in the following areas in the following order

        1) look in the after snapshot for file collection
        2) look in the running task file collection
        3) look in the before snapshot for file collection

        Parameters
        ----------
        task_id : str
            id for the task you would like to get file objects for
        mode : str
            file open mode
            (default is "r" to open file for read)

        Returns
        -------
        list
            list of python file objects

        Raises
        ------
        DoesNotExist
            task object does not exist
        PathDoesNotExist
            no file objects exist for the task
        """
        try:
            task_obj = self.dal.task.get_by_id(task_id)
        except EntityNotFound:
            raise DoesNotExist()
        if task_obj.after_snapshot_id:
            # perform number 1) and return file list
            return self.snapshot.get_files(
                task_obj.after_snapshot_id, mode=mode)
        elif task_obj.task_dirpath:
            # perform number 2) and return file list
            return self.file_driver.get(
                task_obj.task_dirpath, mode=mode, directory=True)
        elif task_obj.before_snapshot_id:
            # perform number 3) and return file list
            return self.snapshot.get_files(
                task_obj.before_snapshot_id, mode=mode)
        else:
            # Error because the task does not have any files associated with it
            raise PathDoesNotExist()

    def delete(self, task_id):
        if not task_id:
            raise RequiredArgumentMissing(
                __("error", "controller.task.delete.arg", "id"))
        stopped_success = self.stop(task_id)
        delete_task_success = self.dal.task.delete(task_id)
        return stopped_success and delete_task_success

    def stop(self, task_id=None, all=False, status="STOPPED"):
        """Stop and remove run for the task and update task object statuses

        Parameters
        ----------
        task_id : str, optional
            id for the task you would like to stop
        all : bool, optional
            if specified, will stop all tasks within project

        Returns
        -------
        return_code : bool
            system return code of the stop

        Raises
        ------
        RequiredArgumentMissing
        TooManyArgumentsFound
        """
        if task_id is None and all is False:
            raise RequiredArgumentMissing(
                __("error", "controller.task.stop.arg.missing", "id"))
        if task_id and all:
            raise TooManyArgumentsFound()
        if task_id:
            try:
                task_obj = self.get(task_id)
            except DoesNotExist:
                time.sleep(1)
                task_obj = self.get(task_id)
            task_match_string = "datmo-task-" + self.model.id + "-" + task_id
            # Get the environment id associated with the task
            kwargs = {'match_string': task_match_string}
            # Get the environment from the task
            before_snapshot_id = task_obj.before_snapshot_id
            after_snapshot_id = task_obj.after_snapshot_id
            if not before_snapshot_id and not after_snapshot_id:
                # TODO: remove...for now database may not be in sync. no task that has run can have NO before_snapshot_id
                time.sleep(1)
                task_obj = self.get(task_id)
            if after_snapshot_id:
                after_snapshot_obj = self.snapshot.get(after_snapshot_id)
                kwargs['environment_id'] = after_snapshot_obj.environment_id
            if not after_snapshot_id and before_snapshot_id:
                before_snapshot_obj = self.snapshot.get(before_snapshot_id)
                kwargs['environment_id'] = before_snapshot_obj.environment_id
            return_code = self.environment.stop(**kwargs)
        if all:
            return_code = self.environment.stop(all=True)
        # Set stopped task statuses to STOPPED if return success
        if return_code:
            if task_id:
                self.dal.task.update({"id": task_id, "status": status})
            if all:
                task_objs = self.dal.task.query({})
                for task_obj in task_objs:
                    self.dal.task.update({"id": task_obj.id, "status": status})

        return return_code
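
A short usage sketch, not taken from the source, of the create/run flow that TaskController exposes; it assumes an initialized datmo project and a working environment backend (e.g. Docker):

task_controller = TaskController()
task_obj = task_controller.create()
updated_task_obj = task_controller.run(
    task_obj.id, task_dict={"command_list": ["sh", "-c", "echo accuracy:0.45"]})
print(updated_task_obj.status)   # "SUCCESS" or "FAILED"
print(updated_task_obj.results)  # e.g. {"accuracy": "0.45"}, parsed from the logs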
Example #3
class SnapshotCommand(ProjectCommand):
    def __init__(self, cli_helper):
        super(SnapshotCommand, self).__init__(cli_helper)

    def usage(self):
        self.cli_helper.echo(__("argparser", "cli.snapshot.usage"))

    def snapshot(self):
        self.parse(["snapshot", "--help"])
        return True

    @Helper.notify_no_project_found
    def create(self, **kwargs):
        self.snapshot_controller = SnapshotController()
        self.cli_helper.echo(__("info", "cli.snapshot.create"))
        run_id = kwargs.get("run_id", None)
        # creating snapshot with task id if it exists
        if run_id is not None:
            excluded_args = [
                "environment_id", "environment_paths", "paths",
                "config_filepath", "config_filename", "stats_filepath",
                "stats_filename"
            ]
            for arg in excluded_args:
                if arg in kwargs and kwargs[arg] is not None:
                    raise SnapshotCreateFromTaskArgs(
                        __("error", "cli.snapshot.create.run.args", arg))

            message = kwargs.get("message", None)
            label = kwargs.get("label", None)
            # Create a new core snapshot object
            snapshot_task_obj = self.snapshot_controller.create_from_task(
                message, run_id, label=label)
            self.cli_helper.echo(
                "Created snapshot id: %s" % snapshot_task_obj.id)
            return snapshot_task_obj
        else:
            # creating snapshot without task id
            snapshot_dict = {"visible": True}

            # Environment
            if kwargs.get("environment_id", None) or kwargs.get(
                    "environment_paths", None):
                mutually_exclusive_args = [
                    "environment_id", "environment_paths"
                ]
                mutually_exclusive(mutually_exclusive_args, kwargs,
                                   snapshot_dict)

            # File
            if kwargs.get("paths", None):
                snapshot_dict['paths'] = kwargs['paths']

            # Config
            if kwargs.get("config_filepath", None) or kwargs.get(
                    "config_filename", None) or kwargs.get("config", None):
                mutually_exclusive_args = [
                    "config_filepath", "config_filename", "config"
                ]
                mutually_exclusive(mutually_exclusive_args, kwargs,
                                   snapshot_dict)
            # parsing config
            if "config" in snapshot_dict:
                config = {}
                config_list = snapshot_dict["config"]
                for item in config_list:
                    item_parsed_dict = parse_cli_key_value(item, 'config')
                    config.update(item_parsed_dict)
                snapshot_dict["config"] = config

            # Stats
            if kwargs.get("stats_filepath", None) or kwargs.get(
                    "stats_filename", None) or kwargs.get("config", None):
                mutually_exclusive_args = [
                    "stats_filepath", "stats_filename", "stats"
                ]
                mutually_exclusive(mutually_exclusive_args, kwargs,
                                   snapshot_dict)
            # parsing stats
            if "stats" in snapshot_dict:
                stats = {}
                stats_list = snapshot_dict["stats"]
                for item in stats_list:
                    item_parsed_dict = parse_cli_key_value(item, 'stats')
                    stats.update(item_parsed_dict)
                snapshot_dict["stats"] = stats

            optional_args = ["message", "label"]

            for arg in optional_args:
                if arg in kwargs and kwargs[arg] is not None:
                    snapshot_dict[arg] = kwargs[arg]

            snapshot_obj = self.snapshot_controller.create(snapshot_dict)
            # Because snapshots may be created invisible to the user, this command ensures
            # the user can monitor the snapshot from the CLI by making it visible at the end
            snapshot_obj = self.snapshot_controller.update(
                snapshot_obj.id, visible=True)
            self.cli_helper.echo(
                __("info", "cli.snapshot.create.success", snapshot_obj.id))
            return snapshot_obj

    @Helper.notify_no_project_found
    def delete(self, **kwargs):
        self.snapshot_controller = SnapshotController()
        self.cli_helper.echo(__("info", "cli.snapshot.delete"))
        snapshot_id = kwargs.get('id')
        result = self.snapshot_controller.delete(snapshot_id)
        self.cli_helper.echo(
            __("info", "cli.snapshot.delete.success", snapshot_id))
        return result

    @Helper.notify_no_project_found
    def update(self, **kwargs):
        self.snapshot_controller = SnapshotController()
        self.cli_helper.echo(__("info", "cli.snapshot.update"))
        snapshot_id = kwargs.get('id')
        # getting previous saved config and stats
        snapshot_obj = self.snapshot_controller.get(snapshot_id)
        config = snapshot_obj.config
        stats = snapshot_obj.stats

        # extracting config
        update_config_list = kwargs.get('config', None)
        if update_config_list:
            update_config = {}
            for item in update_config_list:
                item_parsed_dict = parse_cli_key_value(item, 'config')
                update_config.update(item_parsed_dict)
            # updating config
            config.update(update_config)

        # extracting stats
        update_stats_list = kwargs.get('stats', None)
        if update_stats_list:
            update_stats = {}
            for item in update_stats_list:
                item_parsed_dict = parse_cli_key_value(item, 'stats')
                update_stats.update(item_parsed_dict)
            # updating stats
            stats.update(update_stats)

        # extracting message
        message = kwargs.get('message', None)
        # extracting label
        label = kwargs.get('label', None)

        result = self.snapshot_controller.update(
            snapshot_id,
            config=config,
            stats=stats,
            message=message,
            label=label)
        self.cli_helper.echo(
            __("info", "cli.snapshot.update.success", snapshot_id))
        return result

    @Helper.notify_no_project_found
    def ls(self, **kwargs):
        self.snapshot_controller = SnapshotController()
        detailed_info = kwargs.get('details', None)
        show_all = kwargs.get('show_all', None)
        print_format = kwargs.get('format', "table")
        download = kwargs.get('download', None)
        download_path = kwargs.get('download_path', None)
        current_snapshot_obj = self.snapshot_controller.current_snapshot()
        current_snapshot_id = current_snapshot_obj.id if current_snapshot_obj else None
        if show_all:
            snapshot_objs = self.snapshot_controller.list(
                sort_key="created_at", sort_order="descending")
        else:
            snapshot_objs = self.snapshot_controller.list(
                visible=True, sort_key="created_at", sort_order="descending")
        item_dict_list = []
        if detailed_info:
            header_list = [
                "id", "created at", "config", "stats", "message", "label",
                "code id", "environment id", "file collection id"
            ]
            for snapshot_obj in snapshot_objs:
                snapshot_config_printable = printable_object(
                    snapshot_obj.config)
                snapshot_stats_printable = printable_object(snapshot_obj.stats)
                snapshot_message = printable_object(snapshot_obj.message)
                snapshot_label = printable_object(snapshot_obj.label)
                printable_snapshot_id = "(current) " + snapshot_obj.id \
                    if snapshot_obj.id == current_snapshot_id \
                    else snapshot_obj.id
                item_dict_list.append({
                    "id": printable_snapshot_id,
                    "created at": prettify_datetime(snapshot_obj.created_at),
                    "config": snapshot_config_printable,
                    "stats": snapshot_stats_printable,
                    "message": snapshot_message,
                    "label": snapshot_label,
                    "code id": snapshot_obj.code_id,
                    "environment id": snapshot_obj.environment_id,
                    "file collection id": snapshot_obj.file_collection_id
                })
        else:
            header_list = [
                "id", "created at", "config", "stats", "message", "label"
            ]
            for snapshot_obj in snapshot_objs:
                snapshot_config_printable = printable_object(
                    snapshot_obj.config)
                snapshot_stats_printable = printable_object(snapshot_obj.stats)
                snapshot_message = printable_object(snapshot_obj.message)
                snapshot_label = printable_object(snapshot_obj.label)
                printable_snapshot_id = "(current) " + snapshot_obj.id \
                    if snapshot_obj.id == current_snapshot_id \
                    else snapshot_obj.id
                item_dict_list.append({
                    "id": printable_snapshot_id,
                    "created at": prettify_datetime(snapshot_obj.created_at),
                    "config": snapshot_config_printable,
                    "stats": snapshot_stats_printable,
                    "message": snapshot_message,
                    "label": snapshot_label,
                })
        if download:
            if not download_path:
                # download to the project home directory with a timestamped filename
                current_time = datetime.utcnow()
                epoch_time = datetime.utcfromtimestamp(0)
                current_time_unix_time_ms = (
                    current_time - epoch_time).total_seconds() * 1000.0
                download_path = os.path.join(
                    self.snapshot_controller.home,
                    "snapshot_ls_" + str(current_time_unix_time_ms))
            self.cli_helper.print_items(
                header_list,
                item_dict_list,
                print_format=print_format,
                output_path=download_path)
            return snapshot_objs
        self.cli_helper.print_items(
            header_list, item_dict_list, print_format=print_format)
        return snapshot_objs

    @Helper.notify_no_project_found
    def checkout(self, **kwargs):
        self.snapshot_controller = SnapshotController()
        snapshot_id = kwargs.get('id')
        checkout_success = self.snapshot_controller.checkout(snapshot_id)
        if checkout_success:
            self.cli_helper.echo(
                __("info", "cli.snapshot.checkout.success", snapshot_id))
        return checkout_success

    @Helper.notify_no_project_found
    def diff(self, **kwargs):
        self.snapshot_controller = SnapshotController()
        snapshot_id_1 = kwargs.get("id_1", None)
        snapshot_id_2 = kwargs.get("id_2", None)
        snapshot_obj_1 = self.snapshot_controller.get(snapshot_id_1)
        snapshot_obj_2 = self.snapshot_controller.get(snapshot_id_2)
        comparison_attributes = [
            "id", "created_at", "message", "label", "code_id",
            "environment_id", "file_collection_id", "config", "stats"
        ]
        table_data = [["Attributes", "Snapshot 1", "", "Snapshot 2"],
                      ["", "", "", ""]]
        for attribute in comparison_attributes:
            value_1 = getattr(snapshot_obj_1, attribute) if getattr(
                snapshot_obj_1, attribute) else "N/A"
            value_2 = getattr(snapshot_obj_2, attribute) if getattr(
                snapshot_obj_2, attribute) else "N/A"
            if isinstance(value_1, datetime):
                value_1 = prettify_datetime(value_1)
            if isinstance(value_2, datetime):
                value_2 = prettify_datetime(value_2)
            if attribute in ["config", "stats"]:
                alldict = []
                if isinstance(value_1, dict): alldict.append(value_1)
                if isinstance(value_2, dict): alldict.append(value_2)
                allkey = set().union(*alldict)
                for key in allkey:
                    key_value_1 = "%s: %s" % (key, value_1[key]) if value_1 != "N/A" and value_1.get(key, None) \
                        else "N/A"
                    key_value_2 = "%s: %s" % (key, value_2[key]) if value_2 != "N/A" and value_2.get(key, None) \
                        else "N/A"
                    table_data.append(
                        [attribute, key_value_1, "->", key_value_2])
            else:
                table_data.append([attribute, value_1, "->", value_2])
        output = format_table(table_data)
        self.cli_helper.echo(output)
        return output
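    # Hypothetical rendering of the table built above (after format_table), for two
    # snapshots that differ in message and config:
    #
    #   Attributes   Snapshot 1          Snapshot 2
    #
    #   id           snapshot-1-id   ->  snapshot-2-id
    #   message      first run       ->  second run
    #   config       foo: bar        ->  N/A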

    @Helper.notify_no_project_found
    def inspect(self, **kwargs):
        self.snapshot_controller = SnapshotController()
        snapshot_id = kwargs.get("id", None)
        snapshot_obj = self.snapshot_controller.get(snapshot_id)
        output = str(snapshot_obj)
        self.cli_helper.echo(output)
        return output
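
A hedged sketch, not taken from the source, of driving SnapshotCommand programmatically; cli_helper stands in for whatever helper object the datmo CLI normally injects, and the ids are placeholders:

snapshot_command = SnapshotCommand(cli_helper)
snapshot_obj = snapshot_command.create(message="baseline model", label="v0")
snapshot_command.ls(details=True)
snapshot_command.diff(id_1=snapshot_obj.id, id_2="<other-snapshot-id>")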
Example #4
class TestSnapshotController():
    def setup_method(self):
        # provide mountable tmp directory for docker
        tempfile.tempdir = "/tmp" if not platform.system(
        ) == "Windows" else None
        test_datmo_dir = os.environ.get('TEST_DATMO_DIR',
                                        tempfile.gettempdir())
        self.temp_dir = tempfile.mkdtemp(dir=test_datmo_dir)
        self.project = ProjectController(self.temp_dir)
        self.project.init("test", "test description")
        self.task = TaskController(self.temp_dir)
        self.snapshot = SnapshotController(self.temp_dir)

    def teardown_method(self):
        pass

    def test_create_fail_no_message(self):
        # Test no message
        failed = False
        try:
            self.snapshot.create({})
        except RequiredArgumentMissing:
            failed = True
        assert failed

    def test_create_fail_no_code(self):
        # Test default values for snapshot, fail due to code
        failed = False
        try:
            self.snapshot.create({"message": "my test snapshot"})
        except GitCommitDoesNotExist:
            failed = True
        assert failed

    def test_create_fail_no_environment_with_language(self):
        # Test default values for snapshot; fail due to an environment language other than the default
        self.snapshot.file_driver.create("filepath1")
        failed = False
        try:
            self.snapshot.create({
                "message": "my test snapshot",
                "language": "java"
            })
        except EnvironmentDoesNotExist:
            failed = True
        assert failed

    def test_create_fail_no_environment_detected_in_file(self):
        # Test default values for snapshot, fail due to no environment from file
        self.snapshot.file_driver.create("filepath1")
        failed = False
        try:
            self.snapshot.create({
                "message": "my test snapshot",
            })
        except EnvironmentDoesNotExist:
            failed = True
        assert failed

    def test_create_success_default_detected_in_file(self):
        # Test default values for snapshot when the environment is detected from the script file
        test_filepath = os.path.join(self.snapshot.home, "script.py")
        with open(test_filepath, "w") as f:
            f.write(to_unicode("import numpy\n"))
            f.write(to_unicode("import sklearn\n"))
            f.write(to_unicode("print('hello')\n"))

        snapshot_obj_1 = self.snapshot.create({"message": "my test snapshot"})

        assert snapshot_obj_1
        assert snapshot_obj_1.code_id
        assert snapshot_obj_1.environment_id
        assert snapshot_obj_1.file_collection_id
        assert snapshot_obj_1.config == {}
        assert snapshot_obj_1.stats == {}

    def test_create_success_default_env_def(self):
        # Create environment definition
        env_def_path = os.path.join(self.snapshot.home, "Dockerfile")
        with open(env_def_path, "w") as f:
            f.write(to_unicode(str("FROM datmo/xgboost:cpu")))

        # Test default values for snapshot, success
        snapshot_obj = self.snapshot.create({"message": "my test snapshot"})

        assert snapshot_obj
        assert snapshot_obj.code_id
        assert snapshot_obj.environment_id
        assert snapshot_obj.file_collection_id
        assert snapshot_obj.config == {}
        assert snapshot_obj.stats == {}

    def test_create_success_default_env_def_duplicate(self):
        # Test 2 snapshots with same parameters
        # Create environment definition
        env_def_path = os.path.join(self.snapshot.home, "Dockerfile")
        with open(env_def_path, "w") as f:
            f.write(to_unicode(str("FROM datmo/xgboost:cpu")))

        snapshot_obj = self.snapshot.create({"message": "my test snapshot"})

        snapshot_obj_1 = self.snapshot.create({"message": "my test snapshot"})

        # Should return the same object back
        assert snapshot_obj_1 == snapshot_obj
        assert snapshot_obj_1.code_id == snapshot_obj.code_id
        assert snapshot_obj_1.environment_id == \
               snapshot_obj.environment_id
        assert snapshot_obj_1.file_collection_id == \
               snapshot_obj.file_collection_id
        assert snapshot_obj_1.config == \
               snapshot_obj.config
        assert snapshot_obj_1.stats == \
               snapshot_obj.stats

    def test_create_success_given_files_env_def_config_file_stats_file(self):
        # Create environment definition
        env_def_path = os.path.join(self.snapshot.home, "Dockerfile")
        with open(env_def_path, "w") as f:
            f.write(to_unicode(str("FROM datmo/xgboost:cpu")))

        snapshot_obj = self.snapshot.create({"message": "my test snapshot"})

        # Create files to add
        self.snapshot.file_driver.create("dirpath1", directory=True)
        self.snapshot.file_driver.create("dirpath2", directory=True)
        self.snapshot.file_driver.create("filepath1")

        # Create config
        config_filepath = os.path.join(self.snapshot.home, "config.json")
        with open(config_filepath, "w") as f:
            f.write(to_unicode(str('{"foo":"bar"}')))

        # Create stats
        stats_filepath = os.path.join(self.snapshot.home, "stats.json")
        with open(stats_filepath, "w") as f:
            f.write(to_unicode(str('{"foo":"bar"}')))

        input_dict = {
            "message":
            "my test snapshot",
            "filepaths": [
                os.path.join(self.snapshot.home, "dirpath1"),
                os.path.join(self.snapshot.home, "dirpath2"),
                os.path.join(self.snapshot.home, "filepath1")
            ],
            "environment_definition_filepath":
            env_def_path,
            "config_filepath":
            config_filepath,
            "stats_filepath":
            stats_filepath,
        }
        # Create snapshot in the project
        snapshot_obj_4 = self.snapshot.create(input_dict)

        assert snapshot_obj_4 != snapshot_obj
        assert snapshot_obj_4.code_id != snapshot_obj.code_id
        assert snapshot_obj_4.environment_id == \
               snapshot_obj.environment_id
        assert snapshot_obj_4.file_collection_id != \
               snapshot_obj.file_collection_id
        assert snapshot_obj_4.config == {"foo": "bar"}
        assert snapshot_obj_4.stats == {"foo": "bar"}

    def test_create_success_given_files_env_def_different_config_stats(self):
        # Create environment definition
        env_def_path = os.path.join(self.snapshot.home, "Dockerfile")
        with open(env_def_path, "w") as f:
            f.write(to_unicode(str("FROM datmo/xgboost:cpu")))

        snapshot_obj = self.snapshot.create({"message": "my test snapshot"})

        # Create files to add
        self.snapshot.file_driver.create("dirpath1", directory=True)
        self.snapshot.file_driver.create("dirpath2", directory=True)
        self.snapshot.file_driver.create("filepath1")

        # Create config
        config_filepath = os.path.join(self.snapshot.home, "config.json")
        with open(config_filepath, "w") as f:
            f.write(to_unicode(str('{"foo":"bar"}')))

        # Create stats
        stats_filepath = os.path.join(self.snapshot.home, "stats.json")
        with open(stats_filepath, "w") as f:
            f.write(to_unicode(str('{"foo":"bar"}')))

        # Test different config and stats inputs
        input_dict = {
            "message":
            "my test snapshot",
            "filepaths": [
                os.path.join(self.snapshot.home, "dirpath1"),
                os.path.join(self.snapshot.home, "dirpath2"),
                os.path.join(self.snapshot.home, "filepath1")
            ],
            "environment_definition_filepath":
            env_def_path,
            "config_filename":
            "different_name",
            "stats_filename":
            "different_name",
        }

        # Create snapshot in the project
        snapshot_obj_1 = self.snapshot.create(input_dict)

        assert snapshot_obj_1 != snapshot_obj
        assert snapshot_obj_1.config == {}
        assert snapshot_obj_1.stats == {}

    def test_create_success_given_files_env_def_direct_config_stats(self):
        # Create environment definition
        env_def_path = os.path.join(self.snapshot.home, "Dockerfile")
        with open(env_def_path, "w") as f:
            f.write(to_unicode(str("FROM datmo/xgboost:cpu")))

        # Create files to add
        self.snapshot.file_driver.create("dirpath1", directory=True)
        self.snapshot.file_driver.create("dirpath2", directory=True)
        self.snapshot.file_driver.create("filepath1")

        # Test different config and stats inputs
        input_dict = {
            "message":
            "my test snapshot",
            "filepaths": [
                os.path.join(self.snapshot.home, "dirpath1"),
                os.path.join(self.snapshot.home, "dirpath2"),
                os.path.join(self.snapshot.home, "filepath1")
            ],
            "environment_definition_filepath":
            env_def_path,
            "config": {
                "foo": "bar"
            },
            "stats": {
                "foo": "bar"
            },
        }

        # Create snapshot in the project
        snapshot_obj_6 = self.snapshot.create(input_dict)

        assert snapshot_obj_6.config == {"foo": "bar"}
        assert snapshot_obj_6.stats == {"foo": "bar"}

    def test_create_from_task(self):
        # 1) Test if fails with TaskNotComplete error
        # 2) Test if success with task files, results, and message

        # Setup task
        task_command = ["sh", "-c", "echo accuracy:0.45"]
        input_dict = {"command": task_command}

        # Create task in the project
        task_obj = self.task.create(input_dict)

        # 1) Test option 1
        failed = False
        try:
            _ = self.snapshot.create_from_task(message="my test snapshot",
                                               task_id=task_obj.id)
        except TaskNotComplete:
            failed = True
        assert failed

        # Create environment definition
        env_def_path = os.path.join(self.project.home, "Dockerfile")
        with open(env_def_path, "w") as f:
            f.write(to_unicode(str("FROM datmo/xgboost:cpu")))

        # Test the default values
        updated_task_obj = self.task.run(task_obj.id)

        # 2) Test option 2
        snapshot_obj = self.snapshot.create_from_task(
            message="my test snapshot", task_id=updated_task_obj.id)

        assert snapshot_obj.id == updated_task_obj.after_snapshot_id
        assert snapshot_obj.message == "my test snapshot"
        assert snapshot_obj.stats == updated_task_obj.results
        assert snapshot_obj.visible == True

    def __default_create(self):
        # Create files to add
        self.snapshot.file_driver.create("dirpath1", directory=True)
        self.snapshot.file_driver.create("dirpath2", directory=True)
        self.snapshot.file_driver.create("filepath1")

        # Create environment_driver definition
        env_def_path = os.path.join(self.snapshot.home, "Dockerfile")
        with open(env_def_path, "w") as f:
            f.write(to_unicode(str("FROM datmo/xgboost:cpu")))

        # Create config
        config_filepath = os.path.join(self.snapshot.home, "config.json")
        with open(config_filepath, "w") as f:
            f.write(to_unicode(str("{}")))

        # Create stats
        stats_filepath = os.path.join(self.snapshot.home, "stats.json")
        with open(stats_filepath, "w") as f:
            f.write(to_unicode(str("{}")))

        input_dict = {
            "message":
            "my test snapshot",
            "filepaths": [
                os.path.join(self.snapshot.home, "dirpath1"),
                os.path.join(self.snapshot.home, "dirpath2"),
                os.path.join(self.snapshot.home, "filepath1")
            ],
            "environment_definition_filepath":
            env_def_path,
            "config_filename":
            config_filepath,
            "stats_filename":
            stats_filepath,
        }

        # Create snapshot in the project
        return self.snapshot.create(input_dict)

    def test_checkout(self):
        # Create snapshot
        snapshot_obj_1 = self.__default_create()

        code_obj_1 = self.snapshot.dal.code.get_by_id(snapshot_obj_1.code_id)

        # Create duplicate snapshot in project
        self.snapshot.file_driver.create("test")
        snapshot_obj_2 = self.__default_create()

        assert snapshot_obj_2 != snapshot_obj_1

        # Checkout to snapshot 1 using snapshot id
        result = self.snapshot.checkout(snapshot_obj_1.id)

        # Snapshot directory in user directory
        snapshot_obj_1_path = os.path.join(self.snapshot.home,
                                           "datmo_snapshots",
                                           snapshot_obj_1.id)

        assert result == True and \
               self.snapshot.code_driver.latest_commit() == code_obj_1.commit_id and \
               os.path.isdir(snapshot_obj_1_path)

    def test_list(self):
        # Check for error if incorrect session given
        failed = False
        try:
            self.snapshot.list(session_id="does_not_exist")
        except SessionDoesNotExistException:
            failed = True
        assert failed

        # Create file to add to snapshot
        test_filepath_1 = os.path.join(self.snapshot.home, "test.txt")
        with open(test_filepath_1, "w") as f:
            f.write(to_unicode(str("test")))

        # Create snapshot in the project
        snapshot_obj_1 = self.__default_create()

        # Create file to add to second snapshot
        test_filepath_2 = os.path.join(self.snapshot.home, "test2.txt")
        with open(test_filepath_2, "w") as f:
            f.write(to_unicode(str("test2")))

        # Create second snapshot in the project
        snapshot_obj_2 = self.__default_create()

        # List all snapshots and ensure they exist
        result = self.snapshot.list()

        assert len(result) == 2 and \
            snapshot_obj_1 in result and \
            snapshot_obj_2 in result

        # List all snapshots with session filter
        result = self.snapshot.list(session_id=self.project.current_session.id)

        assert len(result) == 2 and \
               snapshot_obj_1 in result and \
               snapshot_obj_2 in result

        # List snapshots with visible filter
        result = self.snapshot.list(visible=False)
        assert len(result) == 0

        result = self.snapshot.list(visible=True)
        assert len(result) == 2 and \
               snapshot_obj_1 in result and \
               snapshot_obj_2 in result

    def test_delete(self):
        # Create snapshot in the project
        snapshot_obj = self.__default_create()

        # Delete snapshot in the project
        result = self.snapshot.delete(snapshot_obj.id)

        # Check if snapshot retrieval throws error
        thrown = False
        try:
            self.snapshot.dal.snapshot.get_by_id(snapshot_obj.id)
        except EntityNotFound:
            thrown = True

        assert result == True and \
            thrown == True
Example #5
0
class TestSnapshotController():
    def setup_method(self):
        self.temp_dir = tempfile.mkdtemp(dir=test_datmo_dir)
        Config().set_home(self.temp_dir)
        self.environment_ids = []

    def teardown_method(self):
        if not check_docker_inactive(test_datmo_dir,
                                     Config().datmo_directory_name):
            self.__setup()
            self.environment_controller = EnvironmentController()
            for env_id in list(set(self.environment_ids)):
                if not self.environment_controller.delete(env_id):
                    raise Exception

    def __setup(self):
        Config().set_home(self.temp_dir)
        self.project_controller = ProjectController()
        self.project_controller.init("test", "test description")
        self.task_controller = TaskController()
        self.snapshot_controller = SnapshotController()

    def test_init_fail_project_not_init(self):
        Config().set_home(self.temp_dir)
        failed = False
        try:
            SnapshotController()
        except ProjectNotInitialized:
            failed = True
        assert failed

    def test_init_fail_invalid_path(self):
        test_home = "some_random_dir"
        Config().set_home(test_home)
        failed = False
        try:
            SnapshotController()
        except InvalidProjectPath:
            failed = True
        assert failed

    def test_current_snapshot(self):
        self.__setup()
        # Test failure for unstaged changes
        failed = False
        try:
            self.snapshot_controller.current_snapshot()
        except UnstagedChanges:
            failed = True
        assert failed
        # Test success after snapshot created
        snapshot_obj = self.__default_create()
        current_snapshot_obj = self.snapshot_controller.current_snapshot()
        assert current_snapshot_obj == snapshot_obj

    def test_create_fail_no_message(self):
        self.__setup()
        # Test no message
        failed = False
        try:
            self.snapshot_controller.create({})
        except RequiredArgumentMissing:
            failed = True
        assert failed

    def test_create_success_no_code(self):
        self.__setup()
        # Test default values for snapshot; should succeed even with no code files present
        result = self.snapshot_controller.create(
            {"message": "my test snapshot"})
        assert result

    def test_create_success_no_code_environment(self):
        self.__setup()
        # Create environment definition
        env_def_path = os.path.join(
            self.project_controller.environment_driver.
            environment_directory_path, "Dockerfile")
        with open(env_def_path, "wb") as f:
            f.write(to_bytes("FROM python:3.5-alpine"))

        # Test must pass when there is a file present in the root project folder
        result = self.snapshot_controller.create(
            {"message": "my test snapshot"})

        assert result

    def test_create_success_no_code_environment_files(self):
        self.__setup()
        # Create environment definition
        env_def_path = os.path.join(
            self.project_controller.environment_driver.
            environment_directory_path, "Dockerfile")
        with open(env_def_path, "wb") as f:
            f.write(to_bytes("FROM python:3.5-alpine"))

        test_file = os.path.join(
            self.project_controller.file_driver.files_directory, "test.txt")
        with open(test_file, "wb") as f:
            f.write(to_bytes(str("hello")))

        # Test must pass when there is a file present in the root project folder
        result = self.snapshot_controller.create(
            {"message": "my test snapshot"})

        assert result

    def test_create_no_environment_detected_in_file(self):
        self.__setup()

        # Test default values for snapshot when no environment is detected from the file
        self.snapshot_controller.file_driver.create("filepath1")
        snapshot_obj_0 = self.snapshot_controller.create(
            {"message": "my test snapshot"})
        assert isinstance(snapshot_obj_0, Snapshot)
        assert snapshot_obj_0.code_id
        assert snapshot_obj_0.environment_id
        assert snapshot_obj_0.file_collection_id
        assert snapshot_obj_0.config == {}
        assert snapshot_obj_0.stats == {}

    def test_create_success_default_detected_in_file(self):
        self.__setup()
        # Test default values for snapshot when there is no environment
        test_filepath = os.path.join(self.snapshot_controller.home,
                                     "script.py")
        with open(test_filepath, "wb") as f:
            f.write(to_bytes("import os\n"))
            f.write(to_bytes("import sys\n"))
            f.write(to_bytes("print('hello')\n"))

        snapshot_obj_1 = self.snapshot_controller.create(
            {"message": "my test snapshot"})

        assert isinstance(snapshot_obj_1, Snapshot)
        assert snapshot_obj_1.code_id
        assert snapshot_obj_1.environment_id
        assert snapshot_obj_1.file_collection_id
        assert snapshot_obj_1.config == {}
        assert snapshot_obj_1.stats == {}

    def test_create_success_default_env_def(self):
        self.__setup()
        # Create environment definition
        env_def_path = os.path.join(
            self.project_controller.environment_driver.
            environment_directory_path, "Dockerfile")
        with open(env_def_path, "wb") as f:
            f.write(to_bytes("FROM python:3.5-alpine"))

        # Creating a file in project folder
        test_filepath = os.path.join(self.snapshot_controller.home,
                                     "script.py")
        with open(test_filepath, "wb") as f:
            f.write(to_bytes("import numpy\n"))
            f.write(to_bytes("import sklearn\n"))
            f.write(to_bytes("print('hello')\n"))

        # Test default values for snapshot, success
        snapshot_obj = self.snapshot_controller.create(
            {"message": "my test snapshot"})

        assert isinstance(snapshot_obj, Snapshot)
        assert snapshot_obj.code_id
        assert snapshot_obj.environment_id
        assert snapshot_obj.file_collection_id
        assert snapshot_obj.config == {}
        assert snapshot_obj.stats == {}

    def test_create_success_with_environment(self):
        self.__setup()
        # Create environment definition
        env_def_path = os.path.join(
            self.project_controller.environment_driver.
            environment_directory_path, "Dockerfile")
        with open(env_def_path, "wb") as f:
            f.write(to_bytes("FROM python:3.5-alpine"))

        # creating a file in project folder
        test_filepath = os.path.join(self.snapshot_controller.home,
                                     "script.py")
        with open(test_filepath, "wb") as f:
            f.write(to_bytes("import numpy\n"))
            f.write(to_bytes("import sklearn\n"))
            f.write(to_bytes("print('hello')\n"))

        # Test default values for snapshot, success
        snapshot_obj = self.snapshot_controller.create(
            {"message": "my test snapshot"})

        assert isinstance(snapshot_obj, Snapshot)
        assert snapshot_obj.code_id
        assert snapshot_obj.environment_id
        assert snapshot_obj.file_collection_id
        assert snapshot_obj.config == {}
        assert snapshot_obj.stats == {}

    def test_create_success_env_paths(self):
        self.__setup()
        # Create environment definition
        random_dir = os.path.join(self.snapshot_controller.home, "random_dir")
        os.makedirs(random_dir)
        env_def_path = os.path.join(random_dir, "randomDockerfile")
        with open(env_def_path, "wb") as f:
            f.write(to_bytes("FROM python:3.5-alpine"))
        environment_paths = [env_def_path + ">Dockerfile"]

        # Test default values for snapshot, success
        snapshot_obj = self.snapshot_controller.create({
            "message":
            "my test snapshot",
            "environment_paths":
            environment_paths
        })

        assert isinstance(snapshot_obj, Snapshot)
        assert snapshot_obj.code_id
        assert snapshot_obj.environment_id
        assert snapshot_obj.file_collection_id
        assert snapshot_obj.config == {}
        assert snapshot_obj.stats == {}

    def test_create_success_default_env_def_duplicate(self):
        self.__setup()
        # Test 2 snapshots with same parameters
        # Create environment definition
        env_def_path = os.path.join(
            self.project_controller.environment_driver.
            environment_directory_path, "Dockerfile")
        with open(env_def_path, "wb") as f:
            f.write(to_bytes("FROM python:3.5-alpine"))

        # Creating a file in project folder
        test_filepath = os.path.join(self.snapshot_controller.home,
                                     "script.py")
        with open(test_filepath, "wb") as f:
            f.write(to_bytes("import numpy\n"))
            f.write(to_bytes("import sklearn\n"))
            f.write(to_bytes("print('hello')\n"))

        snapshot_obj = self.snapshot_controller.create(
            {"message": "my test snapshot"})

        snapshot_obj_1 = self.snapshot_controller.create(
            {"message": "my test snapshot"})

        # Should return the same object back
        assert snapshot_obj_1.id == snapshot_obj.id
        assert snapshot_obj_1.code_id == snapshot_obj.code_id
        assert snapshot_obj_1.environment_id == \
               snapshot_obj.environment_id
        assert snapshot_obj_1.file_collection_id == \
               snapshot_obj.file_collection_id
        assert snapshot_obj_1.config == \
               snapshot_obj.config
        assert snapshot_obj_1.stats == \
               snapshot_obj.stats

    def test_create_success_given_files_env_def_config_file_stats_file(self):
        self.__setup()
        # Create environment definition
        env_def_path = os.path.join(
            self.project_controller.environment_driver.
            environment_directory_path, "Dockerfile")
        with open(env_def_path, "wb") as f:
            f.write(to_bytes("FROM python:3.5-alpine"))

        # Creating a file in project folder
        test_filepath = os.path.join(self.snapshot_controller.home,
                                     "script.py")
        with open(test_filepath, "wb") as f:
            f.write(to_bytes("import numpy\n"))
            f.write(to_bytes("import sklearn\n"))
            f.write(to_bytes("print('hello')\n"))

        snapshot_obj = self.snapshot_controller.create(
            {"message": "my test snapshot"})

        # Create files to add
        _, files_directory_name = os.path.split(
            self.project_controller.file_driver.files_directory)
        files_directory_relative_path = os.path.join(
            self.project_controller.file_driver.datmo_directory_name,
            files_directory_name)
        self.snapshot_controller.file_driver.create(os.path.join(
            files_directory_relative_path, "dirpath1"),
                                                    directory=True)
        self.snapshot_controller.file_driver.create(os.path.join(
            files_directory_relative_path, "dirpath2"),
                                                    directory=True)
        self.snapshot_controller.file_driver.create(
            os.path.join(files_directory_relative_path, "filepath1"))

        # Create config
        config_filepath = os.path.join(self.snapshot_controller.home,
                                       "config.json")
        with open(config_filepath, "wb") as f:
            f.write(to_bytes(str('{"foo":"bar"}')))

        # Create stats
        stats_filepath = os.path.join(self.snapshot_controller.home,
                                      "stats.json")
        with open(stats_filepath, "wb") as f:
            f.write(to_bytes(str('{"foo":"bar"}')))

        input_dict = {
            "message": "my test snapshot",
            "config_filepath": config_filepath,
            "stats_filepath": stats_filepath,
        }
        # Create snapshot in the project
        snapshot_obj_4 = self.snapshot_controller.create(input_dict)

        assert snapshot_obj_4 != snapshot_obj
        assert snapshot_obj_4.code_id != snapshot_obj.code_id
        assert snapshot_obj_4.environment_id == \
               snapshot_obj.environment_id
        assert snapshot_obj_4.file_collection_id != \
               snapshot_obj.file_collection_id
        assert snapshot_obj_4.config == {"foo": "bar"}
        assert snapshot_obj_4.stats == {"foo": "bar"}

    def test_create_success_given_files_env_def_different_config_stats(self):
        self.__setup()
        # Create environment definition
        env_def_path = os.path.join(
            self.project_controller.environment_driver.
            environment_directory_path, "Dockerfile")
        with open(env_def_path, "wb") as f:
            f.write(to_bytes("FROM python:3.5-alpine"))

        # Creating a file in project folder
        test_filepath = os.path.join(self.snapshot_controller.home,
                                     "script.py")
        with open(test_filepath, "wb") as f:
            f.write(to_bytes("import numpy\n"))
            f.write(to_bytes("import sklearn\n"))
            f.write(to_bytes("print('hello')\n"))

        snapshot_obj = self.snapshot_controller.create(
            {"message": "my test snapshot"})

        # Create files to add
        _, files_directory_name = os.path.split(
            self.project_controller.file_driver.files_directory)
        files_directory_relative_path = os.path.join(
            self.project_controller.file_driver.datmo_directory_name,
            files_directory_name)
        self.snapshot_controller.file_driver.create(os.path.join(
            files_directory_relative_path, "dirpath1"),
                                                    directory=True)
        self.snapshot_controller.file_driver.create(os.path.join(
            files_directory_relative_path, "dirpath2"),
                                                    directory=True)
        self.snapshot_controller.file_driver.create(
            os.path.join(files_directory_relative_path, "filepath1"))

        # Create config
        config_filepath = os.path.join(self.snapshot_controller.home,
                                       "config.json")
        with open(config_filepath, "wb") as f:
            f.write(to_bytes(str('{"foo":"bar"}')))

        # Create stats
        stats_filepath = os.path.join(self.snapshot_controller.home,
                                      "stats.json")
        with open(stats_filepath, "wb") as f:
            f.write(to_bytes(str('{"foo":"bar"}')))

        # Test different config and stats inputs
        input_dict = {
            "message": "my test snapshot",
            "config_filename": "different_name",
            "stats_filename": "different_name",
        }

        # Create snapshot in the project
        snapshot_obj_1 = self.snapshot_controller.create(input_dict)

        assert snapshot_obj_1 != snapshot_obj
        assert snapshot_obj_1.config == {}
        assert snapshot_obj_1.stats == {}

    def test_create_success_given_files_env_def_direct_config_stats(self):
        self.__setup()
        # Create environment definition
        env_def_path = os.path.join(
            self.project_controller.environment_driver.
            environment_directory_path, "Dockerfile")
        with open(env_def_path, "wb") as f:
            f.write(to_bytes("FROM python:3.5-alpine"))

        # Create files to add
        _, files_directory_name = os.path.split(
            self.project_controller.file_driver.files_directory)
        files_directory_relative_path = os.path.join(
            self.project_controller.file_driver.datmo_directory_name,
            files_directory_name)
        self.snapshot_controller.file_driver.create(os.path.join(
            files_directory_relative_path, "dirpath1"),
                                                    directory=True)
        self.snapshot_controller.file_driver.create(os.path.join(
            files_directory_relative_path, "dirpath2"),
                                                    directory=True)
        self.snapshot_controller.file_driver.create(
            os.path.join(files_directory_relative_path, "filepath1"))

        # Creating a file in project folder
        test_filepath = os.path.join(self.snapshot_controller.home,
                                     "script.py")
        with open(test_filepath, "wb") as f:
            f.write(to_bytes("import numpy\n"))
            f.write(to_bytes("import sklearn\n"))
            f.write(to_bytes("print('hello')\n"))

        # Test different config and stats inputs
        input_dict = {
            "message": "my test snapshot",
            "config": {
                "foo": "bar"
            },
            "stats": {
                "foo": "bar"
            },
        }

        # Create snapshot in the project
        snapshot_obj_6 = self.snapshot_controller.create(input_dict)

        assert snapshot_obj_6.config == {"foo": "bar"}
        assert snapshot_obj_6.stats == {"foo": "bar"}

    @pytest_docker_environment_failed_instantiation(test_datmo_dir)
    def test_create_from_task(self):
        self.__setup()
        # 0) Test if fails with TaskNotComplete error
        # 1) Test if success with empty task files, results
        # 2) Test if success with task files, results, and message
        # 3) Test if success with message, label, config and stats
        # 4) Test if success with updated stats from after_snapshot_id and task_results

        # Create task in the project
        task_obj = self.task_controller.create()

        # 0) Test option 0
        failed = False
        try:
            _ = self.snapshot_controller.create_from_task(
                message="my test snapshot", task_id=task_obj.id)
        except TaskNotComplete:
            failed = True
        assert failed

        # 1) Test option 1

        # Create task_dict
        task_command = ["sh", "-c", "echo test"]
        task_dict = {"command_list": task_command}

        # Create environment definition
        env_def_path = os.path.join(self.project_controller.home, "Dockerfile")
        with open(env_def_path, "wb") as f:
            f.write(to_bytes("FROM python:3.5-alpine"))

        updated_task_obj = self.task_controller.run(task_obj.id,
                                                    task_dict=task_dict)
        after_snapshot_obj = self.task_controller.dal.snapshot.get_by_id(
            updated_task_obj.after_snapshot_id)
        environment_obj = self.task_controller.dal.environment.get_by_id(
            after_snapshot_obj.environment_id)
        self.environment_ids.append(environment_obj.id)

        snapshot_obj = self.snapshot_controller.create_from_task(
            message="my test snapshot", task_id=updated_task_obj.id)

        assert isinstance(snapshot_obj, Snapshot)
        assert snapshot_obj.id == updated_task_obj.after_snapshot_id
        assert snapshot_obj.message == "my test snapshot"
        assert snapshot_obj.stats == updated_task_obj.results
        assert snapshot_obj.visible == True

        # Create new task and corresponding dict
        task_obj = self.task_controller.create()
        task_command = ["sh", "-c", "echo accuracy:0.45"]
        task_dict = {"command_list": task_command}

        # Create environment definition
        env_def_path = os.path.join(self.project_controller.home, "Dockerfile")
        with open(env_def_path, "wb") as f:
            f.write(to_bytes("FROM python:3.5-alpine"))

        # Test the default values
        updated_task_obj = self.task_controller.run(task_obj.id,
                                                    task_dict=task_dict)
        after_snapshot_obj = self.task_controller.dal.snapshot.get_by_id(
            updated_task_obj.after_snapshot_id)
        environment_obj = self.task_controller.dal.environment.get_by_id(
            after_snapshot_obj.environment_id)
        self.environment_ids.append(environment_obj.id)

        # 2) Test option 2
        snapshot_obj = self.snapshot_controller.create_from_task(
            message="my test snapshot", task_id=updated_task_obj.id)

        assert isinstance(snapshot_obj, Snapshot)
        assert snapshot_obj.id == updated_task_obj.after_snapshot_id
        assert snapshot_obj.message == "my test snapshot"
        assert snapshot_obj.stats == updated_task_obj.results
        assert snapshot_obj.visible == True

        # 3) Test option 3
        test_config = {"algo": "regression"}
        test_stats = {"accuracy": 0.9}
        snapshot_obj = self.snapshot_controller.create_from_task(
            message="my test snapshot",
            task_id=updated_task_obj.id,
            label="best",
            config=test_config,
            stats=test_stats)

        assert isinstance(snapshot_obj, Snapshot)
        assert snapshot_obj.id == updated_task_obj.after_snapshot_id
        assert snapshot_obj.message == "my test snapshot"
        assert snapshot_obj.label == "best"
        assert snapshot_obj.config == test_config
        assert snapshot_obj.stats == test_stats
        assert snapshot_obj.visible == True

        # 4) Test option 4
        test_config = {"algo": "regression"}
        test_stats = {"new_key": 0.9}
        task_obj_2 = self.task_controller.create()
        updated_task_obj_2 = self.task_controller.run(task_obj_2.id,
                                                      task_dict=task_dict,
                                                      snapshot_dict={
                                                          "config":
                                                          test_config,
                                                          "stats": test_stats
                                                      })
        after_snapshot_obj = self.task_controller.dal.snapshot.get_by_id(
            updated_task_obj_2.after_snapshot_id)
        environment_obj = self.task_controller.dal.environment.get_by_id(
            after_snapshot_obj.environment_id)
        self.environment_ids.append(environment_obj.id)

        snapshot_obj = self.snapshot_controller.create_from_task(
            message="my test snapshot",
            task_id=updated_task_obj_2.id,
            label="best")
        updated_stats_dict = {}
        updated_stats_dict.update(test_stats)
        updated_stats_dict.update(updated_task_obj.results)

        assert isinstance(snapshot_obj, Snapshot)
        assert snapshot_obj.id == updated_task_obj_2.after_snapshot_id
        assert snapshot_obj.message == "my test snapshot"
        assert snapshot_obj.label == "best"
        assert snapshot_obj.stats == updated_stats_dict
        assert snapshot_obj.visible == True

    def __default_create(self):
        # Create files to add
        _, files_directory_name = os.path.split(
            self.project_controller.file_driver.files_directory)
        files_directory_relative_path = os.path.join(
            self.project_controller.file_driver.datmo_directory_name,
            files_directory_name)
        self.snapshot_controller.file_driver.create(os.path.join(
            files_directory_relative_path, "dirpath1"),
                                                    directory=True)
        self.snapshot_controller.file_driver.create(os.path.join(
            files_directory_relative_path, "dirpath2"),
                                                    directory=True)
        self.snapshot_controller.file_driver.create(
            os.path.join(files_directory_relative_path, "filepath1"))
        self.snapshot_controller.file_driver.create("filepath2")
        with open(os.path.join(self.snapshot_controller.home, "filepath2"),
                  "wb") as f:
            f.write(to_bytes(str("import sys\n")))
        # Create environment_driver definition
        env_def_path = os.path.join(
            self.project_controller.environment_driver.
            environment_directory_path, "Dockerfile")
        with open(env_def_path, "wb") as f:
            f.write(to_bytes("FROM python:3.5-alpine"))

        # Create config
        config_filepath = os.path.join(self.snapshot_controller.home,
                                       "config.json")
        with open(config_filepath, "wb") as f:
            f.write(to_bytes(str("{}")))

        # Create stats
        stats_filepath = os.path.join(self.snapshot_controller.home,
                                      "stats.json")
        with open(stats_filepath, "wb") as f:
            f.write(to_bytes(str("{}")))

        input_dict = {
            "message": "my test snapshot",
            "config_filename": config_filepath,
            "stats_filename": stats_filepath,
        }

        # Create snapshot in the project
        return self.snapshot_controller.create(input_dict)

    def test_check_unstaged_changes(self):
        self.__setup()
        # Check unstaged changes
        failed = False
        try:
            self.snapshot_controller.check_unstaged_changes()
        except UnstagedChanges:
            failed = True
        assert failed
        # Check no unstaged changes
        _ = self.__default_create()
        result = self.snapshot_controller.check_unstaged_changes()
        assert result == False

    def test_checkout(self):
        self.__setup()
        # Create snapshot
        snapshot_obj_1 = self.__default_create()

        # Create duplicate snapshot in project
        self.snapshot_controller.file_driver.create("test")
        snapshot_obj_2 = self.__default_create()

        assert snapshot_obj_2 != snapshot_obj_1

        # Checkout to snapshot 1 using snapshot id
        result = self.snapshot_controller.checkout(snapshot_obj_1.id)
        # TODO: Check for which snapshot we are on

        assert result == True

    def test_list(self):
        self.__setup()
        # Create file to add to snapshot
        test_filepath_1 = os.path.join(self.snapshot_controller.home,
                                       "test.txt")
        with open(test_filepath_1, "wb") as f:
            f.write(to_bytes(str("test")))

        # Create snapshot in the project
        snapshot_obj_1 = self.__default_create()

        # Create file to add to second snapshot
        test_filepath_2 = os.path.join(self.snapshot_controller.home,
                                       "test2.txt")
        with open(test_filepath_2, "wb") as f:
            f.write(to_bytes(str("test2")))

        # Create second snapshot in the project
        snapshot_obj_2 = self.__default_create()

        # List all snapshots and ensure they exist
        result = self.snapshot_controller.list()

        assert len(result) == 2 and \
            snapshot_obj_1 in result and \
            snapshot_obj_2 in result

        # List all snapshots regardless of filters in ascending order
        result = self.snapshot_controller.list(sort_key='created_at',
                                               sort_order='ascending')

        assert len(result) == 2 and \
               snapshot_obj_1 in result and \
               snapshot_obj_2 in result
        assert result[0].created_at <= result[-1].created_at

        # List all snapshots regardless of filters in descending order
        result = self.snapshot_controller.list(sort_key='created_at',
                                               sort_order='descending')
        assert len(result) == 2 and \
               snapshot_obj_1 in result and \
               snapshot_obj_2 in result
        assert result[0].created_at >= result[-1].created_at

        # Wrong order being passed in
        failed = False
        try:
            _ = self.snapshot_controller.list(sort_key='created_at',
                                              sort_order='wrong_order')
        except InvalidArgumentType:
            failed = True
        assert failed

        # Wrong key and order being passed in
        failed = False
        try:
            _ = self.snapshot_controller.list(sort_key='wrong_key',
                                              sort_order='wrong_order')
        except InvalidArgumentType:
            failed = True
        assert failed

        # Wrong key and right order being passed in
        expected_result = self.snapshot_controller.list(sort_key='created_at',
                                                        sort_order='ascending')
        result = self.snapshot_controller.list(sort_key='wrong_key',
                                               sort_order='ascending')
        expected_ids = [item.id for item in expected_result]
        ids = [item.id for item in result]
        assert set(expected_ids) == set(ids)

        # List snapshots with visible filter
        result = self.snapshot_controller.list(visible=False)
        assert len(result) == 0

        result = self.snapshot_controller.list(visible=True)
        assert len(result) == 2 and \
               snapshot_obj_1 in result and \
               snapshot_obj_2 in result

    def test_update(self):
        self.__setup()
        test_config = {"config_foo": "bar"}
        test_stats = {"stats_foo": "bar"}
        test_message = 'test_message'
        test_label = 'test_label'

        # Updating all config, stats, message and label
        # Create snapshot in the project
        snapshot_obj = self.__default_create()

        # Update snapshot in the project
        self.snapshot_controller.update(snapshot_obj.id,
                                        config=test_config,
                                        stats=test_stats,
                                        message=test_message,
                                        label=test_label)

        # Get the updated snapshot obj
        updated_snapshot_obj = self.snapshot_controller.dal.snapshot.get_by_id(
            snapshot_obj.id)
        assert updated_snapshot_obj.config == test_config
        assert updated_snapshot_obj.stats == test_stats
        assert updated_snapshot_obj.message == test_message
        assert updated_snapshot_obj.label == test_label

        # Updating config, stats
        # Create snapshot in the project
        snapshot_obj = self.__default_create()

        # Update snapshot in the project
        self.snapshot_controller.update(snapshot_obj.id,
                                        config=test_config,
                                        stats=test_stats)

        # Get the updated snapshot obj
        updated_snapshot_obj = self.snapshot_controller.dal.snapshot.get_by_id(
            snapshot_obj.id)
        assert updated_snapshot_obj.config == test_config
        assert updated_snapshot_obj.stats == test_stats

        # Updating both message and label
        # Create snapshot in the project
        snapshot_obj = self.__default_create()

        # Update snapshot in the project
        self.snapshot_controller.update(snapshot_obj.id,
                                        message=test_message,
                                        label=test_label)

        # Get the updated snapshot obj
        updated_snapshot_obj = self.snapshot_controller.dal.snapshot.get_by_id(
            snapshot_obj.id)

        assert updated_snapshot_obj.message == test_message
        assert updated_snapshot_obj.label == test_label

        # Updating only message
        # Create snapshot in the project
        snapshot_obj_1 = self.__default_create()

        # Update snapshot in the project
        self.snapshot_controller.update(snapshot_obj_1.id,
                                        message=test_message)

        # Get the updated snapshot obj
        updated_snapshot_obj_1 = self.snapshot_controller.dal.snapshot.get_by_id(
            snapshot_obj_1.id)

        assert updated_snapshot_obj_1.message == test_message

        # Updating only label
        # Create snapshot in the project
        snapshot_obj_2 = self.__default_create()

        # Update snapshot in the project
        self.snapshot_controller.update(snapshot_obj_2.id, label=test_label)

        # Get the updated snapshot obj
        updated_snapshot_obj_2 = self.snapshot_controller.dal.snapshot.get_by_id(
            snapshot_obj_2.id)

        assert updated_snapshot_obj_2.label == test_label

    def test_get(self):
        self.__setup()
        # Test failure for no snapshot
        failed = False
        try:
            self.snapshot_controller.get("random")
        except DoesNotExist:
            failed = True
        assert failed

        # Test success for snapshot
        snapshot_obj = self.__default_create()
        snapshot_obj_returned = self.snapshot_controller.get(snapshot_obj.id)
        assert snapshot_obj == snapshot_obj_returned

    def test_get_files(self):
        self.__setup()
        # Test failure case
        failed = False
        try:
            self.snapshot_controller.get_files("random")
        except DoesNotExist:
            failed = True
        assert failed

        # Test success case
        snapshot_obj = self.__default_create()
        result = self.snapshot_controller.get_files(snapshot_obj.id)
        file_collection_obj = self.task_controller.dal.file_collection.get_by_id(
            snapshot_obj.file_collection_id)

        file_names = [item.name for item in result]

        assert len(result) == 1
        for item in result:
            assert isinstance(item, TextIOWrapper)
            assert item.mode == "r"
        assert os.path.join(self.task_controller.home, ".datmo", "collections",
                            file_collection_obj.filehash,
                            "filepath1") in file_names

        result = self.snapshot_controller.get_files(snapshot_obj.id, mode="a")

        assert len(result) == 1
        for item in result:
            assert isinstance(item, TextIOWrapper)
            assert item.mode == "a"
        assert os.path.join(self.task_controller.home, ".datmo", "collections",
                            file_collection_obj.filehash,
                            "filepath1") in file_names

    def test_delete(self):
        self.__setup()
        # Create snapshot in the project
        snapshot_obj = self.__default_create()

        # Delete snapshot in the project
        result = self.snapshot_controller.delete(snapshot_obj.id)

        # Check if snapshot retrieval throws error
        thrown = False
        try:
            self.snapshot_controller.dal.snapshot.get_by_id(snapshot_obj.id)
        except EntityNotFound:
            thrown = True

        assert result == True and \
            thrown == True
Example #6
0
def create(message,
           label=None,
           task_id=None,
           environment_id=None,
           env=None,
           paths=None,
           config=None,
           stats=None):
    """Create a snapshot within a project

    The project must be created before this can be used. You can do that by using
    the following command::

        $ datmo init


    Parameters
    ----------

    message : str
        a description of the snapshot for later reference
    label : str, optional
        a short description of the snapshot for later reference
        (default is None, which means a blank label is stored)
    task_id : str, optional
        task object id to use to create snapshot
        if a task id is passed then the subsequent parameters are ignored;
        the task will overwrite the following inputs

        *environment_id*: used to run the task,

        *paths*: this is the set of all files saved during the task

        *config*: nothing is passed into this variable. the user may add
        something to the config by passing in a dict for the config

        *stats*:  the task.results are added into the stats variable of the
        snapshot.

    environment_id : str, optional
        provide the environment object id to use with this snapshot
        (default is None, which means it creates a default environment)
    env : str or list, optional
        the absolute file path for the environment definition; env is not used if environment_id is also passed.
        this can be either a string or a list
        (default is None; if environment_id is also not passed, it will defer to the environment to find a
        default environment or will fail if none is found)
    paths : list, optional
        list of absolute or relative filepaths and/or dirpaths to collect with destination names
        (e.g. "/path/to/file>hello", "/path/to/file2", "/path/to/dir>newdir")
    config : dict, optional
        provide the dictionary of configurations
        (default is None, which means it is empty)
    stats : dict, optional
        provide the dictionary of relevant statistics or metrics
        (default is None, which means it is empty)

    Returns
    -------
    Snapshot
        returns a Snapshot entity as defined above

    Examples
    --------
    You can use this function within a project repository to save snapshots
    for later use. Once you have created this, you will be able to view the
    snapshot with the `datmo snapshot ls` cli command

    >>> import datmo
    >>> datmo.snapshot.create(message="my first snapshot", paths=["/path/to/a/large/file"], config={"test": 0.4, "test2": "string"}, stats={"accuracy": 0.94})

    You can also use the result of a task run in order to create a snapshot

    >>> datmo.snapshot.create(message="my first snapshot from task", task_id="1jfkshg049")
    """

    snapshot_controller = SnapshotController()

    if task_id is not None:
        excluded_args = ["environment_id", "paths"]
        for arg in excluded_args:
            if eval(arg) is not None:
                raise SnapshotCreateFromTaskArgs(
                    "error", "sdk.snapshot.create.task.args", arg)

        # Create a new core snapshot object
        core_snapshot_obj = snapshot_controller.create_from_task(message,
                                                                 task_id,
                                                                 label=label,
                                                                 config=config,
                                                                 stats=stats)

        # Create a new snapshot object
        client_snapshot_obj = Snapshot(core_snapshot_obj)

        return client_snapshot_obj
    else:
        snapshot_create_dict = {"message": message}

        # add arguments if they are not None
        if label:
            snapshot_create_dict['label'] = label
        if environment_id:
            snapshot_create_dict['environment_id'] = environment_id
        elif isinstance(env, list):
            snapshot_create_dict['environment_paths'] = env
        elif env:
            snapshot_create_dict['environment_paths'] = [env]
        if paths:
            snapshot_create_dict['paths'] = paths
        if config:
            snapshot_create_dict['config'] = config
        if stats:
            snapshot_create_dict['stats'] = stats

        # Create a new core snapshot object
        core_snapshot_obj = snapshot_controller.create(snapshot_create_dict)
        core_snapshot_obj = snapshot_controller.update(core_snapshot_obj.id,
                                                       visible=True)

        # Create a new snapshot object
        client_snapshot_obj = Snapshot(core_snapshot_obj)

        return client_snapshot_obj
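
A minimal usage sketch (not from the original listing; the paths and values are hypothetical) combining the `env` string form and the `paths` "src>dest" destination syntax described in the docstring above:

    import datmo

    snapshot = datmo.snapshot.create(
        message="snapshot with renamed paths",
        env="/path/to/Dockerfile",                    # single environment definition file
        paths=["data/raw.csv>input.csv", "plots/"],   # "src>dest" renames the file on collection
        config={"learning_rate": 0.01},
        stats={"val_accuracy": 0.91})

Since the controller marks the snapshot visible after creation, it will then show up in `datmo snapshot ls`.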
Example #7
0
def create(message,
           label=None,
           home=None,
           task_id=None,
           commit_id=None,
           environment_id=None,
           filepaths=None,
           config=None,
           stats=None):
    """Create a snapshot within a project

    The project must be created before this can be used. You can do that by using
    the following command::

        $ datmo init


    Parameters
    ----------

    message : str
        a description of the snapshot for later reference
    label : str, optional
        a short description of the snapshot for later reference
        (default is None, which means a blank label is stored)
    home : str, optional
        absolute home path of the project
        (default is None, which will use the CWD as the project path)
    task_id : str, optional
        task object id to use to create snapshot
        if a task id is passed then the subsequent parameters are ignored;
        the task will overwrite the following inputs

        *commit_id*: taken from the source code after the task is run

        *environment_id*: used to run the task,

        *filepaths*: this is the set of all files saved during the task

        *config*: nothing is passed into this variable. the user may add
        something to the config by passing in a dict for the config

        *stats*:  the task.results are added into the stats variable of the
        snapshot.

    commit_id : str, optional
        provide the exact commit hash associated with the snapshot
        (default is None, which means it automatically creates a commit)
    environment_id : str, optional
        provide the environment object id to use with this snapshot
        (default is None, which means it creates a default environment)
    filepaths : list, optional
        provides a list of absolute filepaths to files or directories
        that are relevant (default is None, which means it is empty)
    config : dict, optional
        provide the dictionary of configurations
        (default is None, which means it is empty)
    stats : dict, optional
        provide the dictionary of relevant statistics or metrics
        (default is None, which means it is empty)

    Returns
    -------
    Snapshot
        returns a Snapshot entity as defined above

    Examples
    --------
    You can use this function within a project repository to save snapshots
    for later use. Once you have created this, you will be able to view the
    snapshot with the `datmo snapshot ls` cli command

    >>> import datmo
    >>> datmo.snapshot.create(message="my first snapshot", filepaths=["/path/to/a/large/file"], config={"test": 0.4, "test2": "string"}, stats={"accuracy": 0.94})

    You can also use the result of a task run in order to create a snapshot

    >>> datmo.snapshot.create(message="my first snapshot from task", task_id="1jfkshg049")
    """
    if not home:
        home = os.getcwd()
    snapshot_controller = SnapshotController(home=home)

    if task_id is not None:
        excluded_args = ["commit_id", "environment_id", "filepaths"]
        for arg in excluded_args:
            if eval(arg) is not None:
                raise SnapshotCreateFromTaskArgs(
                    "error", "sdk.snapshot.create.task.args", arg)

        # Create a new core snapshot object
        core_snapshot_obj = snapshot_controller.create_from_task(
            message, task_id, label=label, config=config, stats=stats)

        # Create a new snapshot object
        client_snapshot_obj = Snapshot(core_snapshot_obj, home=home)

        return client_snapshot_obj
    else:
        snapshot_create_dict = {"message": message}

        # add arguments if they are not None
        if label:
            snapshot_create_dict['label'] = label
        if commit_id:
            snapshot_create_dict['commit_id'] = commit_id
        if environment_id:
            snapshot_create_dict['environment_id'] = environment_id
        if filepaths:
            snapshot_create_dict['filepaths'] = filepaths
        if config:
            snapshot_create_dict['config'] = config
        if stats:
            snapshot_create_dict['stats'] = stats

        # Create a new core snapshot object
        core_snapshot_obj = snapshot_controller.create(snapshot_create_dict)

        # Create a new snapshot object
        client_snapshot_obj = Snapshot(core_snapshot_obj, home=home)

        return client_snapshot_obj
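
A minimal usage sketch (not from the original listing; paths and values are hypothetical) of this older signature, which takes an explicit project `home` and `filepaths` instead of `paths`:

    import os
    import datmo

    project_home = os.getcwd()
    snapshot = datmo.snapshot.create(
        message="snapshot from an explicit project home",
        home=project_home,                                      # defaults to the CWD when omitted
        filepaths=[os.path.join(project_home, "weights.h5")],   # hypothetical output file
        config={"test": 0.4},
        stats={"accuracy": 0.94})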
Example #8
0
class TaskController(BaseController):
    """TaskController inherits from BaseController and manages business logic associated with tasks
    within the project.

    Parameters
    ----------
    home : str
        home path of the project

    Attributes
    ----------
    environment : EnvironmentController
        used to create environment if new definition file
    snapshot : SnapshotController
        used to create snapshots before and after tasks

    Methods
    -------
    create(dictionary)
        creates a Task object with the permanent parameters
    _run_helper(environment_id, options, log_filepath)
        helper for run to start environment and run with the appropriate parameters
    run(task_id, snapshot_dict=None, task_dict=None)
        runs the task and tracks the run, logs, inputs and outputs
    list(session_id=None)
        lists all tasks within the project given filters
    delete(id)
        deletes the specified task from the project

    """
    def __init__(self, home):
        super(TaskController, self).__init__(home)
        self.environment = EnvironmentController(home)
        self.snapshot = SnapshotController(home)
        if not self.is_initialized:
            raise ProjectNotInitializedException(
                __("error", "controller.task.__init__"))

    def create(self, dictionary):
        """Create Task object

        Parameters
        ----------
        dictionary : dict
            command : str
                full command used

        Returns
        -------
        Task
            object entity for Task (datmo.core.entity.task.Task)
        """

        # Validate Inputs

        create_dict = {
            "model_id": self.model.id,
            "session_id": self.current_session.id
        }

        ## Required args
        required_args = ["command"]
        for required_arg in required_args:
            # Add in any values that are not None
            if required_arg in dictionary and dictionary[
                    required_arg] is not None:
                create_dict[required_arg] = dictionary[required_arg]
            else:
                raise RequiredArgumentMissing(
                    __("error", "controller.task.create.arg", required_arg))

        # Create Task
        return self.dal.task.create(Task(create_dict))
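    # Example sketch (assumed values): a task entity is created before it is run;
    # omitting "command" from the dictionary raises RequiredArgumentMissing.
    #     task_obj = task_controller.create(
    #         {"command": "python train.py"})  # hypothetical command string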

    def _run_helper(self, environment_id, options, log_filepath):
        """Run environment with parameters

        Parameters
        ----------
        environment_id : str
            the environment id for definition
        options : dict
            can include the following values:

            command : list
            ports : list
                Here are some example ports used for common applications.
                   *  'jupyter notebook' - 8888
                   *  flask API - 5000
                   *  tensorboard - 6006
                An example input for the above would be ["8888:8888", "5000:5000", "6006:6006"]
                which maps the running host port (right) to that of the environment (left)
            name : str
            volumes : dict
            detach : bool
            stdin_open : bool
            tty : bool
            gpu : bool
        log_filepath : str
            absolute filepath to the log file

        Returns
        -------
        return_code : int
            system return code of the environment that was run
        run_id : str
            id of the environment run (different from environment id)
        logs : str
            output logs from the run
        """
        # Run container with options provided
        run_options = {
            "command": options.get('command', None),
            "ports": options.get('ports', None),
            "name": options.get('name', None),
            "volumes": options.get('volumes', None),
            "detach": options.get('detach', False),
            "stdin_open": options.get('stdin_open', False),
            "tty": options.get('tty', False),
            "gpu": options.get('gpu', False),
            "api": False
        }

        # Build image for environment
        self.environment.build(environment_id)

        # Run container with environment
        return_code, run_id, logs = \
            self.environment.run(environment_id, run_options, log_filepath)

        return return_code, run_id, logs
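    # Options sketch (assumed values): exposing a Jupyter notebook as described in
    # the ports note above, mapping host port 8888 to the environment:
    #     options = {"command": ["jupyter", "notebook"], "ports": ["8888:8888"],
    #                "name": "datmo-task-example", "detach": True,
    #                "stdin_open": True, "tty": False, "gpu": False}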

    def _parse_logs_for_results(self, logs):
        """Parse log string to extract results and return dictionary.

        Note
        ----
        If the same key is found multiple times in the logs, the last occurring
        one will be the one that is saved.

        Parameters
        ----------
        logs : str
            raw string value of output logs

        Returns
        -------
        dict
            dictionary to represent results from task
        """
        results = {}
        for line in logs.split("\n"):
            split_line = line.split(":")
            if len(split_line) == 2:
                results[split_line[0].strip()] = split_line[1].strip()
        return results
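    # Parsing sketch: for logs such as "epoch:3\naccuracy:0.45\naccuracy:0.47\n",
    # the method above returns {"epoch": "3", "accuracy": "0.47"}; values remain
    # strings and, per the docstring note, the last occurrence of a repeated key wins.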

    def run(self, task_id, snapshot_dict=None, task_dict=None):
        """Run a task with parameters. If dictionary specified, create a new task with new run parameters.
        Snapshot objects are created before and after the task to keep track of the state. During the run,
        you can access task outputs using the environment variable DATMO_TASK_DIR or `/task`, which points to
        the location of datmo_tasks/[task-id]. Create config.json, stats.json and any other files, such
        as weights, graphs and visualizations, within that directory for quick access

        Parameters
        ----------
        task_id : str
            id for the task you would like to run
        snapshot_dict : dict
            set of parameters to create a snapshot (see SnapshotController for details;
            default is None, which means a dictionary with `visible` set to False is added to
            hide the auto-generated snapshot). NOTE: `visible` is always forced to False regardless
            of whether the user provides another value for `visible`.
        task_dict : dict
            set of parameters to characterize the task run
            (default is None, which translates to {}, see datmo.core.entity.task.Task for more details on inputs)

        Returns
        -------
        Task
            the Task object which completed its run with updated parameters

        Raises
        ------
        TaskRunException
            If there is any error in creating files for the task or downstream errors
        """
        # Ensure visible=False is present in the snapshot dictionary
        if not snapshot_dict:
            snapshot_dict = {"visible": False}
        else:
            snapshot_dict['visible'] = False

        if not task_dict:
            task_dict = {}

        # Obtain Task to run
        task_obj = self.dal.task.get_by_id(task_id)

        if task_obj.status is None:
            task_obj.status = 'RUNNING'
        else:
            raise TaskRunException(
                __("error", "cli.task.run.already_running", task_obj.id))

        # Create Task directory for user during run
        task_dirpath = os.path.join("datmo_tasks", task_obj.id)
        try:
            _ = self.file_driver.create(os.path.join("datmo_tasks",
                                                     task_obj.id),
                                        directory=True)
        except:
            raise TaskRunException(
                __("error", "controller.task.run", task_dirpath))

        # Create the before snapshot prior to execution
        before_snapshot_dict = snapshot_dict.copy()
        before_snapshot_dict[
            'message'] = "autogenerated snapshot created before task %s is run" % task_obj.id
        before_snapshot_obj = self.snapshot.create(before_snapshot_dict)

        # Update the task with pre-execution parameters
        task_obj = self.dal.task.update({
            "id":
            task_obj.id,
            "before_snapshot_id":
            task_dict.get('before_snapshot_id', before_snapshot_obj.id),
            "ports":
            task_dict.get('ports', task_obj.ports),
            "gpu":
            task_dict.get('gpu', task_obj.gpu),
            "interactive":
            task_dict.get('interactive', task_obj.interactive),
            "task_dirpath":
            task_dict.get('task_dirpath', task_dirpath),
            "log_filepath":
            task_dict.get('log_filepath', os.path.join(task_dirpath,
                                                       "task.log")),
            "start_time":
            task_dict.get('start_time', datetime.utcnow())
        })

        # Copy over files from the before_snapshot file collection to task dir
        file_collection_obj = \
            self.dal.file_collection.get_by_id(before_snapshot_obj.file_collection_id)
        self.file_driver.copytree(
            os.path.join(self.home, file_collection_obj.path),
            os.path.join(self.home, task_obj.task_dirpath))

        # Set the environment run options from the task parameters
        environment_run_options = {
            "command": task_obj.command,
            "ports": [] if task_obj.ports is None else task_obj.ports,
            "gpu": task_obj.gpu,
            "name": "datmo-task-" + task_obj.id,
            "volumes": {
                os.path.join(self.home, task_obj.task_dirpath): {
                    'bind': '/task/',
                    'mode': 'rw'
                },
                self.home: {
                    'bind': '/home/',
                    'mode': 'rw'
                }
            },
            "detach": task_obj.interactive,
            "stdin_open": task_obj.interactive,
            "tty": False,
            "api": not task_obj.interactive
        }

        # Run environment via the helper function
        return_code, run_id, logs = \
            self._run_helper(before_snapshot_obj.environment_id,
                             environment_run_options,
                             os.path.join(self.home, task_obj.log_filepath))

        # Create the after snapshot after execution is completed with new filepaths
        after_snapshot_dict = snapshot_dict.copy()
        after_snapshot_dict['message'] = (
            "autogenerated snapshot created after task %s is run" % task_obj.id)

        # Add in absolute filepaths from running task directory
        absolute_task_dir_path = os.path.join(self.home, task_obj.task_dirpath)
        absolute_filepaths = []
        for item in os.listdir(absolute_task_dir_path):
            path = os.path.join(absolute_task_dir_path, item)
            if os.path.isfile(path) or os.path.isdir(path):
                absolute_filepaths.append(path)
        after_snapshot_dict.update({
            "filepaths": absolute_filepaths,
            "environment_id": before_snapshot_obj.environment_id,
        })
        after_snapshot_obj = self.snapshot.create(after_snapshot_dict)

        # (optional) Remove temporary task directory path
        # Update the task with post-execution parameters
        end_time = datetime.utcnow()
        duration = (end_time - task_obj.start_time).total_seconds()
        return self.dal.task.update({
            "id": task_obj.id,
            "after_snapshot_id": after_snapshot_obj.id,
            "run_id": run_id,
            "logs": logs,
            "results": self._parse_logs_for_results(logs),
            # "results": task_obj.results, # TODO: update during run
            "status": "SUCCESS" if return_code == 0 else "FAILED",
            "end_time": end_time,
            "duration": duration
        })

    def list(self, session_id=None):
        query = {}
        if session_id:
            query['session_id'] = session_id
        return self.dal.task.query(query)

    def get_files(self, task_id, mode="r"):
        """Get list of file objects for task id. It will look in the following areas in the following order

        1) look in the after snapshot for file collection
        2) look in the running task file collection
        3) look in the before snapshot for file collection

        Parameters
        ----------
        task_id : str
            id for the task you would like to get file objects for
        mode : str
            file open mode
            (default is "r" to open file for read)

        Returns
        -------
        list
            list of python file objects

        Raises
        ------
        PathDoesNotExist
            no file objects exist for the task
        """
        task_obj = self.dal.task.get_by_id(task_id)
        if task_obj.after_snapshot_id:
            # perform number 1) and return file list
            after_snapshot_obj = \
                self.dal.snapshot.get_by_id(task_obj.after_snapshot_id)
            file_collection_obj = \
                self.dal.file_collection.get_by_id(after_snapshot_obj.file_collection_id)
            return self.file_driver.\
                get_collection_files(file_collection_obj.filehash, mode=mode)
        elif task_obj.task_dirpath:
            # perform number 2) and return file list
            return self.file_driver.get(task_obj.task_dirpath,
                                        mode=mode,
                                        directory=True)
        elif task_obj.before_snapshot_id:
            # perform number 3) and return file list
            before_snapshot_obj = \
                self.dal.snapshot.get_by_id(task_obj.before_snapshot_id)
            file_collection_obj = \
                self.dal.file_collection.get_by_id(before_snapshot_obj.file_collection_id)
            return self.file_driver. \
                get_collection_files(file_collection_obj.filehash, mode=mode)
        else:
            # Error because the task does not have any files associated with it
            raise PathDoesNotExist()

    def delete(self, task_id):
        if not task_id:
            raise RequiredArgumentMissing(
                __("error", "controller.task.delete.arg", "id"))
        return self.dal.task.delete(task_id)

    def stop(self, task_id):
        """Stop and remove run for the task

        Parameters
        ----------
        task_id : str
            id for the task you would like to stop

        Returns
        -------
        return_code : bool
            system return code of the stop
        """
        if not task_id:
            raise RequiredArgumentMissing(
                __("error", "controller.task.stop.arg", "id"))
        task_obj = self.dal.task.get_by_id(task_id)
        run_id = task_obj.run_id
        return_code = self.environment.stop(run_id)
        return return_code
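
A minimal usage sketch of the controller above: run() brackets a task execution
between two auto-generated snapshots, copies the before-snapshot files into
datmo_tasks/[task-id], and parses the run logs into results. The import path and
the no-argument create() call below are assumptions inferred from the listings in
this document, not a verified API.

import os

# assumed import path for the TaskController shown above
from datmo.core.controller.task import TaskController

task_controller = TaskController(os.getcwd())

# Create a task whose command prints a metric; run() snapshots the task directory
# before and after execution and parses the logs into `results` (the
# "echo accuracy:0.45" pattern mirrors the tests later in this document).
task_obj = task_controller.create()
updated_task_obj = task_controller.run(
    task_obj.id, task_dict={"command": ["sh", "-c", "echo accuracy:0.45"]})

print(updated_task_obj.status)   # "SUCCESS" if the return code was 0, else "FAILED"
print(updated_task_obj.results)  # results parsed from the task logs

# Files written under datmo_tasks/[task-id] are retrievable afterwards
for f in task_controller.get_files(updated_task_obj.id):
    print(f.name)
    f.close()
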
Example #9
class SnapshotCommand(ProjectCommand):
    def __init__(self, home, cli_helper):
        super(SnapshotCommand, self).__init__(home, cli_helper)
        # dest="subcommand" argument will populate a "subcommand" property with the subparsers name
        # example  "subcommand"="create"  or "subcommand"="ls"
        self.snapshot_controller = SnapshotController(home=home)

    def usage(self):
        self.cli_helper.echo(__("argparser", "cli.snapshot.usage"))

    def snapshot(self):
        self.parse(["snapshot", "--help"])
        return True

    def create(self, **kwargs):
        self.cli_helper.echo(__("info", "cli.snapshot.create"))
        task_id = kwargs.get("task_id", None)
        # creating snapshot with task id if it exists
        if task_id is not None:
            excluded_args = [
                "code_id", "commit_id", "environment_id",
                "environment_definition_filepath", "file_collection_id",
                "filepaths", "config_filepath", "config_filename",
                "stats_filepath", "stats_filename"
            ]
            for arg in excluded_args:
                if arg in kwargs and kwargs[arg] is not None:
                    raise SnapshotCreateFromTaskArgs(
                        __("error", "cli.snapshot.create.task.args", arg))

            message = kwargs.get("message", None)
            label = kwargs.get("label", None)
            # Create a new core snapshot object
            snapshot_task_obj = self.snapshot_controller.create_from_task(
                message, task_id, label=label)
            self.cli_helper.echo("Created snapshot id: %s" %
                                 snapshot_task_obj.id)
            return snapshot_task_obj.id
        else:
            # creating snapshot without task id
            snapshot_dict = {"visible": True}

            # Code
            if kwargs.get("code_id", None) or kwargs.get("commit_id", None):
                mutually_exclusive_args = ["code_id", "commit_id"]
                mutually_exclusive(mutually_exclusive_args, kwargs,
                                   snapshot_dict)

            # Environment
            if kwargs.get("environment_id", None) or kwargs.get(
                    "environment_definition_filepath", None):
                mutually_exclusive_args = [
                    "environment_id", "environment_definition_filepath"
                ]
                mutually_exclusive(mutually_exclusive_args, kwargs,
                                   snapshot_dict)

            # File
            if kwargs.get("file_collection_id", None) or kwargs.get(
                    "filepaths", None):
                mutually_exclusive_args = ["file_collection_id", "filepaths"]
                mutually_exclusive(mutually_exclusive_args, kwargs,
                                   snapshot_dict)

            # Config
            if kwargs.get("config_filepath", None) or kwargs.get(
                    "config_filename", None):
                mutually_exclusive_args = [
                    "config_filepath", "config_filename"
                ]
                mutually_exclusive(mutually_exclusive_args, kwargs,
                                   snapshot_dict)

            # Stats
            if kwargs.get("stats_filepath", None) or kwargs.get(
                    "stats_filename", None):
                mutually_exclusive_args = ["stats_filepath", "stats_filename"]
                mutually_exclusive(mutually_exclusive_args, kwargs,
                                   snapshot_dict)

            optional_args = ["session_id", "message", "label"]

            for arg in optional_args:
                if arg in kwargs and kwargs[arg] is not None:
                    snapshot_dict[arg] = kwargs[arg]

            snapshot_obj = self.snapshot_controller.create(snapshot_dict)
            self.cli_helper.echo(
                __("info", "cli.snapshot.create.success", snapshot_obj.id))
            return snapshot_obj.id

    def delete(self, **kwargs):
        self.cli_helper.echo(__("info", "cli.snapshot.delete"))
        snapshot_id = kwargs.get("id", None)
        self.cli_helper.echo(
            __("info", "cli.snapshot.delete.success", snapshot_id))
        return self.snapshot_controller.delete(snapshot_id)

    def ls(self, **kwargs):
        session_id = kwargs.get('session_id',
                                self.snapshot_controller.current_session.id)
        # Get all snapshot meta information
        detailed_info = kwargs.get('details', None)
        # List of ids shown
        listed_snapshot_ids = []
        snapshot_objs = self.snapshot_controller.list(session_id=session_id,
                                                      visible=True,
                                                      sort_key='created_at',
                                                      sort_order='descending')
        if detailed_info:
            header_list = [
                "id", "created at", "config", "stats", "message", "label",
                "code id", "environment id", "file collection id"
            ]
            t = prettytable.PrettyTable(header_list)
            for snapshot_obj in snapshot_objs:
                t.add_row([
                    snapshot_obj.id,
                    snapshot_obj.created_at.strftime("%Y-%m-%d %H:%M:%S"),
                    snapshot_obj.config, snapshot_obj.stats,
                    snapshot_obj.message, snapshot_obj.label,
                    snapshot_obj.code_id, snapshot_obj.environment_id,
                    snapshot_obj.file_collection_id
                ])
                listed_snapshot_ids.append(snapshot_obj.id)
        else:
            header_list = [
                "id", "created at", "config", "stats", "message", "label"
            ]
            t = prettytable.PrettyTable(header_list)
            for snapshot_obj in snapshot_objs:
                t.add_row([
                    snapshot_obj.id,
                    snapshot_obj.created_at.strftime("%Y-%m-%d %H:%M:%S"),
                    snapshot_obj.config, snapshot_obj.stats,
                    snapshot_obj.message, snapshot_obj.label
                ])
                listed_snapshot_ids.append(snapshot_obj.id)

        self.cli_helper.echo(t)
        return listed_snapshot_ids

    def checkout(self, **kwargs):
        snapshot_id = kwargs.get("id", None)
        checkout_success = self.snapshot_controller.checkout(snapshot_id)
        if checkout_success:
            self.cli_helper.echo(
                __("info", "cli.snapshot.checkout.success", snapshot_id))
        return checkout_success
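
The create flow above relies on a mutually_exclusive helper that is not part of this
listing. Below is a minimal sketch of what such a helper could look like, assuming it
copies at most one of the competing kwargs into the snapshot dictionary and rejects
conflicting combinations; the real datmo helper may differ in its exception type and
exact behavior.

def mutually_exclusive(arg_names, kwargs, target_dict):
    # Collect the mutually exclusive arguments that were actually provided
    provided = [name for name in arg_names if kwargs.get(name) is not None]
    if len(provided) > 1:
        # the real helper presumably raises a datmo-specific exception here
        raise ValueError("only one of %s may be given" % ", ".join(arg_names))
    if provided:
        # copy the single provided argument into the target dictionary
        target_dict[provided[0]] = kwargs[provided[0]]


# Example: only commit_id is provided, so only it lands in snapshot_dict
snapshot_dict = {"visible": True}
mutually_exclusive(["code_id", "commit_id"],
                   {"code_id": None, "commit_id": "abc123"}, snapshot_dict)
print(snapshot_dict)  # {'visible': True, 'commit_id': 'abc123'}
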
Example #10
class TestProjectController():
    def setup_method(self):
        # provide mountable tmp directory for docker
        tempfile.tempdir = "/tmp" if not platform.system(
        ) == "Windows" else None
        test_datmo_dir = os.environ.get('TEST_DATMO_DIR',
                                        tempfile.gettempdir())
        self.temp_dir = tempfile.mkdtemp(dir=test_datmo_dir)
        self.project = ProjectController(self.temp_dir)

    def teardown_method(self):
        pass

    def test_init_none(self):
        # Test failed case
        failed = False
        try:
            self.project.init(None, None)
        except ValidationFailed:
            failed = True
        assert failed

    def test_init_empty_str(self):
        # Test failed case
        failed = False
        try:
            self.project.init("", "")
        except ValidationFailed:
            failed = True
        assert failed

    def test_init(self):

        result = self.project.init("test1", "test description")

        # Tested with is_initialized
        assert self.project.model.name == "test1"
        assert self.project.model.description == "test description"
        assert self.project.code_driver.is_initialized
        assert self.project.file_driver.is_initialized
        assert self.project.environment_driver.is_initialized
        assert result and self.project.is_initialized

        # Changeable by user, not tested in is_initialized
        assert self.project.current_session.name == "default"

        # Check Project template if user specified template
        # TODO: Add in Project template if user specifies

        # Test out functionality for re-initialize project
        result = self.project.init("anything", "else")

        assert self.project.model.name == "anything"
        assert self.project.model.description == "else"
        assert result == True

    def test_cleanup(self):
        self.project.init("test2", "test description")
        result = self.project.cleanup()

        assert not self.project.code_driver.exists_code_refs_dir()
        assert not self.project.file_driver.exists_datmo_file_structure()
        assert not self.project.environment_driver.list_images("datmo-test2")
        # Ensure that containers built with this image do not exist
        # assert not self.project.environment_driver.list_containers(filters={
        #     "ancestor": image_id
        # })
        assert result == True

    def test_status_basic(self):
        self.project.init("test3", "test description")
        status_dict, latest_snapshot, ascending_unstaged_task_list = \
            self.project.status()

        assert status_dict
        assert isinstance(status_dict, dict)
        assert status_dict['name'] == "test3"
        assert status_dict['description'] == "test description"
        assert isinstance(status_dict['config'], dict)
        assert not latest_snapshot
        assert not ascending_unstaged_task_list

    def test_status_snapshot_task(self):
        self.project.init("test4", "test description")
        self.snapshot = SnapshotController(self.temp_dir)
        self.task = TaskController(self.temp_dir)

        # Create files to add
        self.snapshot.file_driver.create("dirpath1", directory=True)
        self.snapshot.file_driver.create("dirpath2", directory=True)
        self.snapshot.file_driver.create("filepath1")

        # Create environment_driver definition
        env_def_path = os.path.join(self.snapshot.home, "Dockerfile")
        with open(env_def_path, "w") as f:
            f.write(to_unicode(str("FROM datmo/xgboost:cpu")))

        # Create config
        config_filepath = os.path.join(self.snapshot.home, "config.json")
        with open(config_filepath, "w") as f:
            f.write(to_unicode(str("{}")))

        # Create stats
        stats_filepath = os.path.join(self.snapshot.home, "stats.json")
        with open(stats_filepath, "w") as f:
            f.write(to_unicode(str("{}")))

        input_dict = {
            "message": "my test snapshot",
            "filepaths": [
                os.path.join(self.snapshot.home, "dirpath1"),
                os.path.join(self.snapshot.home, "dirpath2"),
                os.path.join(self.snapshot.home, "filepath1")
            ],
            "environment_definition_filepath": env_def_path,
            "config_filename": config_filepath,
            "stats_filename": stats_filepath,
        }

        # Create snapshot in the project
        first_snapshot = self.snapshot.create(input_dict)

        status_dict, latest_snapshot, ascending_unstaged_task_list = \
            self.project.status()

        assert status_dict
        assert isinstance(status_dict, dict)
        assert status_dict['name'] == "test4"
        assert status_dict['description'] == "test description"
        assert isinstance(status_dict['config'], dict)
        assert isinstance(latest_snapshot, Snapshot)
        assert latest_snapshot.id == first_snapshot.id
        assert not ascending_unstaged_task_list

        # Create and run a task and test if unstaged task is shown
        first_task = self.task.create()

        # Create task_dict
        task_command = ["sh", "-c", "echo accuracy:0.45"]
        task_dict = {"command": task_command}

        updated_first_task = self.task.run(first_task.id, task_dict=task_dict)

        status_dict, latest_snapshot, ascending_unstaged_task_list = \
            self.project.status()

        assert status_dict
        assert isinstance(status_dict, dict)
        assert status_dict['name'] == "test4"
        assert status_dict['description'] == "test description"
        assert isinstance(status_dict['config'], dict)
        assert isinstance(latest_snapshot, Snapshot)
        assert latest_snapshot.id == first_snapshot.id
        assert isinstance(ascending_unstaged_task_list[0], Task)
        assert ascending_unstaged_task_list[0].id == updated_first_task.id
Example #11
def create(message,
           label=None,
           commit_id=None,
           environment_id=None,
           filepaths=None,
           config=None,
           stats=None,
           home=None):
    """Create a snapshot within a project

    The project must be initialized before this function can be used. You can do that by running
    the following command::

        $ datmo init


    Parameters
    ----------
    message : str
        a description of the snapshot for later reference
    label : str, optional
        a short description of the snapshot for later reference
        (default is None, which means a blank label is stored)
    commit_id : str, optional
        provide the exact commit hash associated with the snapshot
        (default is None, which means it automatically creates a commit)
    environment_id : str, optional
        provide the environment object id to use with this snapshot
        (default is None, which means it creates a default environment)
    filepaths : list, optional
        a list of absolute filepaths to files or directories that are relevant
        (default is None, which means an empty list of filepaths is stored)
    config : dict, optional
        provide the dictionary of configurations
        (default is None, which means it is empty)
    stats : dict, optional
        provide the dictionary of relevant statistics or metrics
        (default is None, which means it is empty)
    home : str, optional
        absolute home path of the project
        (default is None, which will use the CWD as the project path)

    Returns
    -------
    Snapshot
        returns a Snapshot entity as defined above

    Examples
    --------
    You can use this function within a project repository to save snapshots
    for later use. Once created, you can view the snapshot with the
    `datmo snapshot ls` CLI command.

    >>> import datmo
    >>> datmo.snapshot.create(message="my first snapshot", filepaths=["/path/to/a/large/file"], config={"test": 0.4, "test2": "string"}, stats={"accuracy": 0.94})
    """
    if not home:
        home = os.getcwd()
    snapshot_controller = SnapshotController(home=home)

    snapshot_create_dict = {"message": message}

    # add arguments if they are not None
    if label:
        snapshot_create_dict['label'] = label
    if commit_id:
        snapshot_create_dict['commit_id'] = commit_id
    if environment_id:
        snapshot_create_dict['environment_id'] = environment_id
    if filepaths:
        snapshot_create_dict['filepaths'] = filepaths
    if config:
        snapshot_create_dict['config'] = config
    if stats:
        snapshot_create_dict['stats'] = stats

    # Create a new core snapshot object
    core_snapshot_obj = snapshot_controller.create(snapshot_create_dict)

    # Create a new snapshot object
    client_snapshot_obj = Snapshot(core_snapshot_obj, home=home)

    return client_snapshot_obj
Example #12
class SnapshotCommand(ProjectCommand):
    def __init__(self, home, cli_helper):
        super(SnapshotCommand, self).__init__(home, cli_helper)
        # dest="subcommand" argument will populate a "subcommand" property with the subparsers name
        # example  "subcommand"="create"  or "subcommand"="ls"
        snapshot_parser = self.subparsers.add_parser("snapshot",
                                                     help="Snapshot module")
        subcommand_parsers = snapshot_parser.add_subparsers(
            title="subcommands", dest="subcommand")

        create = subcommand_parsers.add_parser("create",
                                               help="create snapshot")
        create.add_argument("--message",
                            "-m",
                            dest="message",
                            default=None,
                            help="message to describe snapshot")
        create.add_argument("--label",
                            "-l",
                            dest="label",
                            default=None,
                            help="Label snapshots with a category (e.g. best)")
        create.add_argument("--session-id",
                            dest="session_id",
                            default=None,
                            help="user given session id")

        create.add_argument("--task-id",
                            dest="task_id",
                            default=None,
                            help="Specify task id to pull information from")

        create.add_argument("--code-id",
                            dest="code_id",
                            default=None,
                            help="code id from code object")
        create.add_argument("--commit-id",
                            dest="commit_id",
                            default=None,
                            help="commit id from source control")

        create.add_argument("--environment-id",
                            dest="environment_id",
                            default=None,
                            help="environment id from environment object")
        create.add_argument(
            "--environment-def-path",
            dest="environment_def_path",
            default=None,
            help=
            "absolute filepath to environment definition file (e.g. /path/to/Dockerfile)"
        )

        create.add_argument(
            "--file-collection-id",
            dest="file_collection_id",
            default=None,
            help="file collection id for file collection object")
        create.add_argument(
            "--filepaths",
            dest="filepaths",
            default=None,
            nargs="*",
            help=
            "absolute paths to files or folders to include within the files of the snapshot"
        )

        create.add_argument(
            "--config-filename",
            dest="config_filename",
            default=None,
            help="filename to use to search for configuration JSON")
        create.add_argument(
            "--config-filepath",
            dest="config_filepath",
            default=None,
            help="absolute filepath to use to search for configuration JSON")

        create.add_argument("--stats-filename",
                            dest="stats_filename",
                            default=None,
                            help="filename to use to search for metrics JSON")
        create.add_argument(
            "--stats-filepath",
            dest="stats_filepath",
            default=None,
            help="absolute filepath to use to search for metrics JSON")

        delete = subcommand_parsers.add_parser("delete",
                                               help="Delete a snapshot by id")
        delete.add_argument("--id", dest="id", help="snapshot id to delete")

        ls = subcommand_parsers.add_parser("ls", help="List snapshots")
        ls.add_argument("--session-id",
                        dest="session_id",
                        default=None,
                        help="Session ID to filter")
        ls.add_argument("--all",
                        "-a",
                        dest="details",
                        action="store_true",
                        help="Show detailed snapshot information")

        checkout = subcommand_parsers.add_parser(
            "checkout", help="Checkout a snapshot by id")
        checkout.add_argument("--id",
                              dest="id",
                              default=None,
                              help="Snapshot ID")

        self.snapshot_controller = SnapshotController(home=home)

    def create(self, **kwargs):
        self.cli_helper.echo(__("info", "cli.snapshot.create"))

        snapshot_dict = {}

        # Code
        mutually_exclusive_args = ["code_id", "commit_id"]
        mutually_exclusive(mutually_exclusive_args, kwargs, snapshot_dict)

        # Environment
        mutually_exclusive_args = ["environment_id", "environment_def_path"]
        mutually_exclusive(mutually_exclusive_args, kwargs, snapshot_dict)

        # File
        mutually_exclusive_args = ["file_collection_id", "filepaths"]
        mutually_exclusive(mutually_exclusive_args, kwargs, snapshot_dict)

        # Config
        mutually_exclusive_args = ["config_filepath", "config_filename"]
        mutually_exclusive(mutually_exclusive_args, kwargs, snapshot_dict)

        # Stats
        mutually_exclusive_args = ["stats_filepath", "stats_filename"]
        mutually_exclusive(mutually_exclusive_args, kwargs, snapshot_dict)

        optional_args = ["session_id", "task_id", "message", "label"]

        for arg in optional_args:
            if arg in kwargs and kwargs[arg] is not None:
                snapshot_dict[arg] = kwargs[arg]

        snapshot_obj = self.snapshot_controller.create(snapshot_dict)

        return snapshot_obj.id

    def delete(self, **kwargs):
        self.cli_helper.echo(__("info", "cli.snapshot.delete"))
        snapshot_id = kwargs.get("id", None)
        return self.snapshot_controller.delete(snapshot_id)

    def ls(self, **kwargs):
        session_id = kwargs.get('session_id',
                                self.snapshot_controller.current_session.id)
        # Get all snapshot meta information
        detailed_info = kwargs.get('details', None)
        if detailed_info:
            header_list = [
                "id", "created at", "config", "stats", "message", "label",
                "code id", "environment id", "file collection id"
            ]
            t = prettytable.PrettyTable(header_list)
            snapshot_objs = self.snapshot_controller.list(
                session_id=session_id, visible=True)
            for snapshot_obj in snapshot_objs:
                t.add_row([
                    snapshot_obj.id,
                    snapshot_obj.created_at.strftime("%Y-%m-%d %H:%M:%S"),
                    snapshot_obj.config, snapshot_obj.stats,
                    snapshot_obj.message, snapshot_obj.label,
                    snapshot_obj.code_id, snapshot_obj.environment_id,
                    snapshot_obj.file_collection_id
                ])
        else:
            header_list = [
                "id", "created at", "config", "stats", "message", "label"
            ]
            t = prettytable.PrettyTable(header_list)
            snapshot_objs = self.snapshot_controller.list(
                session_id=session_id, visible=True)
            for snapshot_obj in snapshot_objs:
                t.add_row([
                    snapshot_obj.id,
                    snapshot_obj.created_at.strftime("%Y-%m-%d %H:%M:%S"),
                    snapshot_obj.config, snapshot_obj.stats,
                    snapshot_obj.message, snapshot_obj.label
                ])

        self.cli_helper.echo(t)
        return True

    def checkout(self, **kwargs):
        snapshot_id = kwargs.get("id", None)
        return self.snapshot_controller.checkout(snapshot_id)
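
For context, the argparse wiring in __init__ above hands each subcommand's parsed
options to the matching SnapshotCommand method as keyword arguments. The standalone
sketch below reproduces just enough of that wiring to show how a command line turns
into the kwargs consumed by create(); the dispatch through the base ProjectCommand
class is assumed and not shown here.

import argparse

parser = argparse.ArgumentParser(prog="datmo")
subparsers = parser.add_subparsers(dest="command")

# Mirror the nested "snapshot create" parser built in __init__ above
snapshot_parser = subparsers.add_parser("snapshot", help="Snapshot module")
snapshot_sub = snapshot_parser.add_subparsers(title="subcommands",
                                              dest="subcommand")
create = snapshot_sub.add_parser("create", help="create snapshot")
create.add_argument("--message", "-m", dest="message", default=None)
create.add_argument("--label", "-l", dest="label", default=None)

args = parser.parse_args(
    ["snapshot", "create", "-m", "first snapshot", "-l", "best"])

# args.subcommand selects the method; the remaining attributes become the **kwargs
# that a call like SnapshotCommand.create(**vars(args)) would receive.
print(args.subcommand, args.message, args.label)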