예제 #1
0
    def __get_core_snapshot(self):
        """Fetch the core snapshot object that corresponds to ``self.id``.

        Returns
        -------
        datmo.core.entity.snapshot.Snapshot
            object returned by ``SnapshotController.get`` for this id
        """
        return SnapshotController().get(self.id)
예제 #2
0
파일: run.py 프로젝트: stenpiren/datmo
    def __get_core_snapshot(self):
        """Fetch the core snapshot object associated with this run.

        The after snapshot id is used when it is set (truthy); otherwise the
        before snapshot id is used.

        Returns
        -------
        datmo.core.entity.snapshot.Snapshot
            object returned by ``SnapshotController.get`` for the chosen id
        """
        controller = SnapshotController()
        return controller.get(self.after_snapshot_id or self.before_snapshot_id)
예제 #3
0
class SnapshotCommand(ProjectCommand):
    def __init__(self, cli_helper):
        """Set up the command by delegating to the parent command class.

        Parameters
        ----------
        cli_helper : object
            helper forwarded unchanged to the parent constructor
        """
        super(SnapshotCommand, self).__init__(cli_helper)

    def usage(self):
        """Echo the snapshot usage message through the CLI helper."""
        usage_message = __("argparser", "cli.snapshot.usage")
        self.cli_helper.echo(usage_message)

    def snapshot(self):
        self.parse(["snapshot", "--help"])
        return True

    @Helper.notify_no_project_found
    def create(self, **kwargs):
        """Create a snapshot from a run or from explicitly given inputs.

        When ``run_id`` is given, the snapshot is created from that run and
        only ``message`` and ``label`` may accompany it. Otherwise a snapshot
        dictionary is assembled from the environment, file, config and stats
        arguments and passed to the snapshot controller.

        Parameters
        ----------
        **kwargs
            Recognized keys are ``run_id``, ``environment_id``,
            ``environment_paths``, ``paths``, ``config_filepath``,
            ``config_filename``, ``config``, ``stats_filepath``,
            ``stats_filename``, ``stats``, ``message`` and ``label``.
            ``config`` and ``stats`` are iterables whose items are converted
            to dictionaries by ``parse_cli_key_value``.

        Returns
        -------
        object
            snapshot object returned by the snapshot controller

        Raises
        ------
        SnapshotCreateFromTaskArgs
            if ``run_id`` is combined with an argument that only applies to
            snapshots created without a run
        """
        self.snapshot_controller = SnapshotController()
        self.cli_helper.echo(__("info", "cli.snapshot.create"))
        run_id = kwargs.get("run_id", None)
        # creating snapshot with task id if it exists
        if run_id is not None:
            excluded_args = [
                "environment_id", "environment_paths", "paths",
                "config_filepath", "config_filename", "stats_filepath",
                "stats_filename"
            ]
            for arg in excluded_args:
                if arg in kwargs and kwargs[arg] is not None:
                    raise SnapshotCreateFromTaskArgs(
                        "error", "cli.snapshot.create.run.args", arg)

            message = kwargs.get("message", None)
            label = kwargs.get("label", None)
            # Create a new core snapshot object
            snapshot_task_obj = self.snapshot_controller.create_from_task(
                message, run_id, label=label)
            self.cli_helper.echo(
                "Created snapshot id: %s" % snapshot_task_obj.id)
            return snapshot_task_obj
        else:
            # creating snapshot without task id
            snapshot_dict = {"visible": True}

            # Environment
            if kwargs.get("environment_id", None) or kwargs.get(
                    "environment_paths", None):
                mutually_exclusive_args = [
                    "environment_id", "environment_paths"
                ]
                mutually_exclusive(mutually_exclusive_args, kwargs,
                                   snapshot_dict)

            # File
            if kwargs.get("paths", None):
                snapshot_dict['paths'] = kwargs['paths']

            # Config
            if kwargs.get("config_filepath", None) or kwargs.get(
                    "config_filename", None) or kwargs.get("config", None):
                mutually_exclusive_args = [
                    "config_filepath", "config_filename", "config"
                ]
                mutually_exclusive(mutually_exclusive_args, kwargs,
                                   snapshot_dict)
            # parsing config
            if "config" in snapshot_dict:
                config = {}
                config_list = snapshot_dict["config"]
                for item in config_list:
                    item_parsed_dict = parse_cli_key_value(item, 'config')
                    config.update(item_parsed_dict)
                snapshot_dict["config"] = config

            # Stats
            # The third alternative must test "stats" (not "config"), otherwise
            # a lone `stats` argument never reaches the snapshot dictionary.
            if kwargs.get("stats_filepath", None) or kwargs.get(
                    "stats_filename", None) or kwargs.get("stats", None):
                mutually_exclusive_args = [
                    "stats_filepath", "stats_filename", "stats"
                ]
                mutually_exclusive(mutually_exclusive_args, kwargs,
                                   snapshot_dict)
            # parsing stats
            if "stats" in snapshot_dict:
                stats = {}
                stats_list = snapshot_dict["stats"]
                for item in stats_list:
                    item_parsed_dict = parse_cli_key_value(item, 'stats')
                    stats.update(item_parsed_dict)
                snapshot_dict["stats"] = stats

            optional_args = ["message", "label"]

            for arg in optional_args:
                if arg in kwargs and kwargs[arg] is not None:
                    snapshot_dict[arg] = kwargs[arg]

            snapshot_obj = self.snapshot_controller.create(snapshot_dict)
            # Because snapshots may be invisible to the user, this update ensures that by the end
            # the user can monitor the snapshot on the CLI, by making it visible
            snapshot_obj = self.snapshot_controller.update(
                snapshot_obj.id, visible=True)
            self.cli_helper.echo(
                __("info", "cli.snapshot.create.success", snapshot_obj.id))
            return snapshot_obj

    @Helper.notify_no_project_found
    def delete(self, **kwargs):
        """Delete the snapshot identified by the ``id`` keyword argument.

        Parameters
        ----------
        **kwargs
            ``id`` is the id of the snapshot to delete

        Returns
        -------
        object
            whatever ``SnapshotController.delete`` returns
        """
        self.snapshot_controller = SnapshotController()
        self.cli_helper.echo(__("info", "cli.snapshot.delete"))
        target_id = kwargs.get('id')
        outcome = self.snapshot_controller.delete(target_id)
        success_message = __("info", "cli.snapshot.delete.success", target_id)
        self.cli_helper.echo(success_message)
        return outcome

    @Helper.notify_no_project_found
    def update(self, **kwargs):
        """Update config, stats, message and label of an existing snapshot.

        The stored config and stats of the snapshot are merged in place with
        the key/value items given on the command line before being sent back
        to the snapshot controller.

        Parameters
        ----------
        **kwargs
            ``id`` selects the snapshot; ``config`` and ``stats`` are optional
            iterables parsed by ``parse_cli_key_value``; ``message`` and
            ``label`` are passed through as given (None when absent)

        Returns
        -------
        object
            whatever ``SnapshotController.update`` returns
        """

        def _parse_items(raw_items, kind):
            # Collapse all CLI items of one kind into a single dictionary
            collected = {}
            for raw_item in raw_items:
                collected.update(parse_cli_key_value(raw_item, kind))
            return collected

        self.snapshot_controller = SnapshotController()
        self.cli_helper.echo(__("info", "cli.snapshot.update"))
        snapshot_id = kwargs.get('id')

        # previously saved values are the base for the merge
        existing_snapshot = self.snapshot_controller.get(snapshot_id)
        config = existing_snapshot.config
        stats = existing_snapshot.stats

        config_items = kwargs.get('config', None)
        if config_items:
            parsed_config = _parse_items(config_items, 'config')
            config.update(parsed_config)

        stats_items = kwargs.get('stats', None)
        if stats_items:
            parsed_stats = _parse_items(stats_items, 'stats')
            stats.update(parsed_stats)

        updated = self.snapshot_controller.update(
            snapshot_id,
            config=config,
            stats=stats,
            message=kwargs.get('message', None),
            label=kwargs.get('label', None))
        success_message = __("info", "cli.snapshot.update.success", snapshot_id)
        self.cli_helper.echo(success_message)
        return updated

    @Helper.notify_no_project_found
    def ls(self, **kwargs):
        """List snapshots and print them through the CLI helper.

        Parameters
        ----------
        **kwargs
            details : bool, optional
                also show code, environment and file collection ids
            show_all : bool, optional
                when falsy only snapshots with ``visible=True`` are listed
            format : str, optional
                print format handed to ``cli_helper.print_items``
                (default is "table")
            download : bool, optional
                when truthy the listing is also written to ``download_path``
            download_path : str, optional
                output path; when missing, a timestamped ``snapshot_ls_*``
                path under the controller's home directory is used

        Returns
        -------
        list
            snapshot objects returned by ``SnapshotController.list``
        """
        self.snapshot_controller = SnapshotController()
        detailed_info = kwargs.get('details', None)
        show_all = kwargs.get('show_all', None)
        print_format = kwargs.get('format', "table")
        download = kwargs.get('download', None)
        download_path = kwargs.get('download_path', None)
        current_snapshot_obj = self.snapshot_controller.current_snapshot()
        current_snapshot_id = current_snapshot_obj.id if current_snapshot_obj else None

        list_kwargs = {"sort_key": "created_at", "sort_order": "descending"}
        if not show_all:
            list_kwargs["visible"] = True
        snapshot_objs = self.snapshot_controller.list(**list_kwargs)

        header_list = [
            "id", "created at", "config", "stats", "message", "label"
        ]
        if detailed_info:
            header_list += [
                "code id", "environment id", "file collection id"
            ]

        item_dict_list = []
        for snapshot_obj in snapshot_objs:
            snapshot_config_printable = printable_object(snapshot_obj.config)
            snapshot_stats_printable = printable_object(snapshot_obj.stats)
            snapshot_message = printable_object(snapshot_obj.message)
            snapshot_label = printable_object(snapshot_obj.label)
            # Only the snapshot matching the current one gets the marker; when
            # there is no current snapshot nothing is marked.
            if current_snapshot_id is not None and \
                    snapshot_obj.id == current_snapshot_id:
                printable_snapshot_id = "(current) " + snapshot_obj.id
            else:
                printable_snapshot_id = snapshot_obj.id
            item_dict = {
                "id": printable_snapshot_id,
                "created at": prettify_datetime(snapshot_obj.created_at),
                "config": snapshot_config_printable,
                "stats": snapshot_stats_printable,
                "message": snapshot_message,
                "label": snapshot_label,
            }
            if detailed_info:
                item_dict["code id"] = snapshot_obj.code_id
                item_dict["environment id"] = snapshot_obj.environment_id
                item_dict["file collection id"] = \
                    snapshot_obj.file_collection_id
            item_dict_list.append(item_dict)

        print_kwargs = {"print_format": print_format}
        if download:
            if not download_path:
                # download to the controller's home directory with a
                # millisecond timestamp in the file name
                current_time = datetime.utcnow()
                epoch_time = datetime.utcfromtimestamp(0)
                current_time_unix_time_ms = (
                    current_time - epoch_time).total_seconds() * 1000.0
                download_path = os.path.join(
                    self.snapshot_controller.home,
                    "snapshot_ls_" + str(current_time_unix_time_ms))
            print_kwargs["output_path"] = download_path
        self.cli_helper.print_items(header_list, item_dict_list,
                                    **print_kwargs)
        return snapshot_objs

    @Helper.notify_no_project_found
    def checkout(self, **kwargs):
        """Check out the snapshot identified by the ``id`` keyword argument.

        The checkout is performed exactly once; its result is both used to
        decide whether the success message is echoed and returned to the
        caller.

        Parameters
        ----------
        **kwargs
            ``id`` is the id of the snapshot to check out

        Returns
        -------
        object
            whatever ``SnapshotController.checkout`` returns (truthy on success)
        """
        self.snapshot_controller = SnapshotController()
        snapshot_id = kwargs.get('id')
        checkout_success = self.snapshot_controller.checkout(snapshot_id)
        if checkout_success:
            self.cli_helper.echo(
                __("info", "cli.snapshot.checkout.success", snapshot_id))
        return checkout_success

    @Helper.notify_no_project_found
    def diff(self, **kwargs):
        """Build and echo a side-by-side comparison table of two snapshots.

        Scalar attributes produce one row each. ``config`` and ``stats``
        produce one row per key found in either snapshot; a key that is
        missing on one side is shown as "N/A" on that side.

        Parameters
        ----------
        **kwargs
            ``id_1`` and ``id_2`` are the ids of the snapshots to compare

        Returns
        -------
        object
            the formatted table returned by ``format_table``
        """

        def _entry(value, key):
            # Presence is tested with `in` so that stored values such as 0,
            # False or "" are displayed instead of being reported as "N/A".
            if isinstance(value, dict) and key in value:
                return "%s: %s" % (key, value[key])
            return "N/A"

        self.snapshot_controller = SnapshotController()
        snapshot_id_1 = kwargs.get("id_1", None)
        snapshot_id_2 = kwargs.get("id_2", None)
        snapshot_obj_1 = self.snapshot_controller.get(snapshot_id_1)
        snapshot_obj_2 = self.snapshot_controller.get(snapshot_id_2)
        comparison_attributes = [
            "id", "created_at", "message", "label", "code_id",
            "environment_id", "file_collection_id", "config", "stats"
        ]
        table_data = [["Attributes", "Snapshot 1", "", "Snapshot 2"],
                      ["", "", "", ""]]
        for attribute in comparison_attributes:
            # Falsy attribute values (None, empty dict, ...) are shown as "N/A"
            value_1 = getattr(snapshot_obj_1, attribute) or "N/A"
            value_2 = getattr(snapshot_obj_2, attribute) or "N/A"
            if isinstance(value_1, datetime):
                value_1 = prettify_datetime(value_1)
            if isinstance(value_2, datetime):
                value_2 = prettify_datetime(value_2)
            if attribute in ["config", "stats"]:
                all_keys = set()
                for value in (value_1, value_2):
                    if isinstance(value, dict):
                        all_keys.update(value)
                # Sort so the row order is stable between invocations
                for key in sorted(all_keys, key=str):
                    table_data.append([
                        attribute,
                        _entry(value_1, key), "->",
                        _entry(value_2, key)
                    ])
            else:
                table_data.append([attribute, value_1, "->", value_2])
        output = format_table(table_data)
        self.cli_helper.echo(output)
        return output

    @Helper.notify_no_project_found
    def inspect(self, **kwargs):
        """Echo and return the string form of the snapshot given by ``id``.

        Parameters
        ----------
        **kwargs
            ``id`` is the id of the snapshot to inspect

        Returns
        -------
        str
            ``str()`` of the snapshot object
        """
        self.snapshot_controller = SnapshotController()
        target_id = kwargs.get("id", None)
        description = str(self.snapshot_controller.get(target_id))
        self.cli_helper.echo(description)
        return description
예제 #4
0
class TaskController(BaseController):
    """TaskController inherits from BaseController and manages business logic associated with tasks
    within the project.

    Parameters
    ----------
    home : str
        home path of the project

    Attributes
    ----------
    environment : datmo.core.controller.environment.environment.EnvironmentController
        used to create environment if new definition file
    snapshot : datmo.core.controller.snapshot.SnapshotController
        used to create snapshots before and after tasks

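    Methods
    -------
    create()
        creates a Task object tied to the current model and session
    _run_helper(environment_id, options, log_filepath)
        helper for run to start environment and run with the appropriate parameters
    run(task_id, snapshot_dict=None, task_dict=None)
        runs the task and tracks the run, logs, inputs and outputs
    list(session_id=None, sort_key=None, sort_order=None)
        lists all tasks within the project given filters
    get(task_id)
        returns the task object for the given id
    get_files(task_id, mode="r")
        returns the file objects associated with the task
    delete(task_id)
        stops the specified task and deletes it from the project
    stop(task_id=None, all=False, status="STOPPED")
        stops the run of one task or of all tasks and updates their status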
    """

    def __init__(self):
        """Create the helper controllers and verify the project is initialized.

        Raises
        ------
        ProjectNotInitialized
            if ``self.is_initialized`` is falsy after the controllers are built
        """
        super(TaskController, self).__init__()
        self.environment = EnvironmentController()
        self.snapshot = SnapshotController()
        self.spinner = Spinner()

        if not self.is_initialized:
            error_message = __("error", "controller.task.__init__")
            raise ProjectNotInitialized(error_message)

    def create(self):
        """Create a Task entity bound to the current model and session.

        The spinner is started before the task is stored and is always
        stopped afterwards, even when storing fails.

        Returns
        -------
        Task
            object entity for Task (datmo.core.entity.task.Task)
        """
        task_fields = {}
        task_fields["model_id"] = self.model.id
        task_fields["session_id"] = self.current_session.id

        try:
            self.spinner.start()
            created_task = self.dal.task.create(Task(task_fields))
        finally:
            self.spinner.stop()
        return created_task

    def _run_helper(self, environment_id, options, log_filepath):
        """Run environment with parameters

        Parameters
        ----------
        environment_id : str
            the environment id for definition
        options : dict
            can include the following values:

            command : list
            ports : list
                Here are some example ports used for common applications.
                   *  'jupyter notebook' - 8888
                   *  flask API - 5000
                   *  tensorboard - 6006
                An example input for the above would be ["8888:8888", "5000:5000", "6006:6006"]
                which maps the running host port (right) to that of the environment (left)
            name : str
            volumes : dict
            mem_limit : str
            workspace : str
            detach : bool
            stdin_open : bool
            tty : bool
        log_filepath : str
            absolute filepath to the log file

        Returns
        -------
        return_code : int
            system return code of the environment that was run
        run_id : str
            id of the environment run (different from environment id)
        logs : str
            output logs from the run
        """
        # Run container with options provided
        run_options = {
            "command": options.get('command', None),
            "ports": options.get('ports', None),
            "name": options.get('name', None),
            "volumes": options.get('volumes', None),
            "mem_limit": options.get('mem_limit', None),
            "gpu": options.get('gpu', False),
            "detach": options.get('detach', False),
            "stdin_open": options.get('stdin_open', False),
            "tty": options.get('tty', False),
            "api": False,
        }
        workspace = options.get('workspace', None)
        self.environment.build(environment_id, workspace)
        # Run container with environment
        return_code, run_id, logs = self.environment.run(
            environment_id, run_options, log_filepath)

        return return_code, run_id, logs

    def _parse_logs_for_results(self, logs):
        """Parse log string to extract results and return dictionary.

        The format of the log line must be "key:value", whitespace will not matter
        and if there are more than 2 items found when split on ":", it will not
        log this as a key/value result

        Note
        ----
        If the same key is found multiple times in the logs, the last occurring
        one will be the one that is saved.

        Parameters
        ----------
        logs : str
            raw string value of output logs

        Returns
        -------
        dict or None
            dictionary to represent results from task
        """
        results = {}
        for line in logs.split("\n"):
            split_line = line.split(":")
            if len(split_line) == 2:
                results[split_line[0].strip()] = split_line[1].strip()
        if results == {}:
            results = None
        return results

    def run(self, task_id, snapshot_dict=None, task_dict=None):
        """Run a task with parameters. If dictionary specified, create a new task with new run parameters.
        Snapshot objects are created before and after the task to keep track of the state. During the run,
        you can access task outputs using environment variable DATMO_TASK_DIR or `/task` which points to
        location for the task files. Create config.json, stats.json and any weights or any file such
        as graphs and visualizations within that directory for quick access

        Parameters
        ----------
        task_id : str
            id for the task you would like to run
        snapshot_dict : dict
            set of parameters to create a snapshot (see SnapshotController for details).
            Default is None, which means a dictionary with `visible` False will be used to
            hide the auto-generated snapshot. NOTE: `visible` is always forced to False
            regardless of the value the user provides.
        task_dict : dict
            set of parameters to characterize the task run
            (default is None, which translate to {}, see datmo.core.entity.task.Task for more details on inputs)

        Returns
        -------
        Task
            the Task object which completed its run with updated parameters

        Raises
        ------
        RequiredArgumentMissing
            If none of command, command_list or interactive is available
        TaskRunError
            If the task already has a status or the task directory cannot be created
        TaskNoCommandGiven
            If no command is available for a non-interactive task

        Notes
        -----
        An exception raised while running the environment does not propagate:
        the task is stored with status "FAILED" and the error text is recorded
        in its logs.
        """
        # Ensure visible=False is present in the snapshot dictionary
        if not snapshot_dict:
            snapshot_dict = {"visible": False}
        else:
            snapshot_dict['visible'] = False

        if not task_dict:
            task_dict = {}
        # Obtain Task to run
        task_obj = self.dal.task.get_by_id(task_id)

        # Ensure that at least 1 of command, command_list,  or interactive is present in task_dict
        important_task_args = ["command", "command_list", "interactive"]
        if not task_dict.get('command', task_obj.command) and \
            not task_dict.get('command_list', task_obj.command_list) and \
                not task_dict.get('interactive', task_obj.interactive):
            raise RequiredArgumentMissing(
                __("error", "controller.task.run.arg",
                   " or ".join(important_task_args)))

        # A task that already carries a status has been started before
        if task_obj.status is None:
            task_obj.status = "RUNNING"
        else:
            raise TaskRunError(
                __("error", "cli.run.run.already_running", task_obj.id))
        # Create Task directory for user during run
        task_dirpath = os.path.join(".datmo", "tasks", task_obj.id)
        try:
            _ = self.file_driver.create(task_dirpath, directory=True)
        except Exception:
            raise TaskRunError(
                __("error", "controller.task.run", task_dirpath))
        # Create the before snapshot prior to execution
        before_snapshot_dict = snapshot_dict.copy()
        before_snapshot_dict[
            'message'] = "autogenerated snapshot created before task %s is run" % task_obj.id
        before_snapshot_obj = self.snapshot.create(before_snapshot_dict)
        # Update the task with pre-execution parameters, prefer list first then look for string command
        # List command will overwrite a string command if given
        if task_dict.get('command_list', task_obj.command_list):
            task_dict['command'] = " ".join(
                task_dict.get('command_list', task_obj.command_list))
        else:
            if task_dict.get('command', task_obj.command):
                task_dict['command_list'] = shlex.split(
                    task_dict.get('command', task_obj.command))
            elif not task_dict.get('interactive', task_obj.interactive):
                # If it's not interactive then there is not expected task
                raise TaskNoCommandGiven()

        validate("create_task", task_dict)
        task_obj = self.dal.task.update({
            "id":
                task_obj.id,
            "before_snapshot_id":
                task_dict.get('before_snapshot_id', before_snapshot_obj.id),
            "command":
                task_dict.get('command', task_obj.command),
            "command_list":
                task_dict.get('command_list', task_obj.command_list),
            "gpu":
                task_dict.get('gpu', False),
            "mem_limit":
                task_dict.get('mem_limit', None),
            "workspace":
                task_dict.get('workspace', None),
            "interactive":
                task_dict.get('interactive', task_obj.interactive),
            "detach":
                task_dict.get('detach', task_obj.detach),
            "ports":
                task_dict.get('ports', task_obj.ports),
            "task_dirpath":
                task_dict.get('task_dirpath', task_dirpath),
            "log_filepath":
                task_dict.get('log_filepath',
                              os.path.join(task_dirpath, "task.log")),
            "start_time":
                task_dict.get('start_time', datetime.utcnow()),
            "status":
                task_obj.status
        })

        # Copy over files from the before_snapshot file collection to task dir
        file_collection_obj =  \
            self.dal.file_collection.get_by_id(before_snapshot_obj.file_collection_id)
        self.file_driver.copytree(
            os.path.join(self.home, file_collection_obj.path),
            os.path.join(self.home, task_obj.task_dirpath))

        return_code, run_id, logs = 0, None, None
        # Text describing a failure of the run itself, kept apart from the
        # run output so it is never parsed as a "key:value" result
        error_message = None

        try:
            # Set the parameters set in the task
            if task_obj.detach and task_obj.interactive:
                raise TaskInteractiveDetachError(
                    __("error", "controller.task.run.args.detach.interactive"))

            environment_run_options = {
                "command": task_obj.command_list,
                "ports": [] if task_obj.ports is None else task_obj.ports,
                "name": "datmo-task-" + self.model.id + "-" + task_obj.id,
                "volumes": {
                    os.path.join(self.home, task_obj.task_dirpath): {
                        'bind': '/task/',
                        'mode': 'rw'
                    },
                    self.home: {
                        'bind': '/home/',
                        'mode': 'rw'
                    }
                },
                "mem_limit": task_obj.mem_limit,
                "workspace": task_obj.workspace,
                "gpu": task_obj.gpu,
                "detach": task_obj.detach,
                "stdin_open": task_obj.interactive,
                "tty": task_obj.interactive,
                "api": False
            }
            # Run environment via the helper function
            return_code, run_id, logs =  \
                self._run_helper(before_snapshot_obj.environment_id,
                                 environment_run_options,
                                 os.path.join(self.home, task_obj.log_filepath))

        except Exception as e:
            return_code = 1
            # "%s" % e works on Python 2 and 3 (exceptions have no `.message`
            # attribute on Python 3)
            error_message = "Error running task: %s" % e
        finally:
            # Create the after snapshot after execution is completed with new paths
            after_snapshot_dict = snapshot_dict.copy()
            after_snapshot_dict[
                'message'] = "autogenerated snapshot created after task %s is run" % task_obj.id

            # Add in absolute paths from running task directory
            absolute_task_dir_path = os.path.join(self.home,
                                                  task_obj.task_dirpath)
            absolute_paths = []
            for item in os.listdir(absolute_task_dir_path):
                path = os.path.join(absolute_task_dir_path, item)
                if os.path.isfile(path) or os.path.isdir(path):
                    absolute_paths.append(path)
            after_snapshot_dict.update({
                "paths": absolute_paths,
                "environment_id": before_snapshot_obj.environment_id,
            })
            after_snapshot_obj = self.snapshot.create(after_snapshot_dict)

            # Results are parsed from the output of the run only
            run_logs = logs
            if error_message is not None:
                # logs is still None when the run never produced output
                logs = error_message if logs is None else logs + error_message

            # (optional) Remove temporary task directory path
            # Update the task with post-execution parameters
            end_time = datetime.utcnow()
            duration = (end_time - task_obj.start_time).total_seconds()
            update_task_dict = {
                "id": task_obj.id,
                "after_snapshot_id": after_snapshot_obj.id,
                "logs": logs,
                "status": "SUCCESS" if return_code == 0 else "FAILED",
                # "results": task_obj.results, # TODO: update during run
                "end_time": end_time,
                "duration": duration
            }
            if run_logs is not None:
                update_task_dict["results"] = self._parse_logs_for_results(
                    run_logs)
            if run_id is not None:
                update_task_dict["run_id"] = run_id
            # NOTE: returning from `finally` is kept on purpose so the updated
            # task is always returned, as callers of the original code expect;
            # it also discards any exception still in flight from the try block.
            return self.dal.task.update(update_task_dict)

    def list(self, session_id=None, sort_key=None, sort_order=None):
        query = {}
        if session_id:
            try:
                self.dal.session.get_by_id(session_id)
            except EntityNotFound:
                raise SessionDoesNotExist(
                    __("error", "controller.task.list", session_id))
            query['session_id'] = session_id
        return self.dal.task.query(query, sort_key, sort_order)

    def get(self, task_id):
        """Get task object and return

        Parameters
        ----------
        task_id : str
            id for the task you would like to get

        Returns
        -------
        datmo.core.entity.task.Task
            core task object

        Raises
        ------
        DoesNotExist
            task does not exist
        """
        try:
            return self.dal.task.get_by_id(task_id)
        except EntityNotFound:
            raise DoesNotExist()

    def get_files(self, task_id, mode="r"):
        """Get list of file objects for task id. It will look in the following areas in the following order

        1) look in the after snapshot for file collection
        2) look in the running task file collection
        3) look in the before snapshot for file collection

        Parameters
        ----------
        task_id : str
            id for the task you would like to get file objects for
        mode : str
            file open mode
            (default is "r" to open file for read)

        Returns
        -------
        list
            list of python file objects

        Raises
        ------
        DoesNotExist
            task object does not exist
        PathDoesNotExist
            no file objects exist for the task
        """
        try:
            task_obj = self.dal.task.get_by_id(task_id)
        except EntityNotFound:
            raise DoesNotExist()
        if task_obj.after_snapshot_id:
            # perform number 1) and return file list
            return self.snapshot.get_files(
                task_obj.after_snapshot_id, mode=mode)
        elif task_obj.task_dirpath:
            # perform number 2) and return file list
            return self.file_driver.get(
                task_obj.task_dirpath, mode=mode, directory=True)
        elif task_obj.before_snapshot_id:
            # perform number 3) and return file list
            return self.snapshot.get_files(
                task_obj.before_snapshot_id, mode=mode)
        else:
            # Error because the task does not have any files associated with it
            raise PathDoesNotExist()

    def delete(self, task_id):
        if not task_id:
            raise RequiredArgumentMissing(
                __("error", "controller.task.delete.arg", "id"))
        stopped_success = self.stop(task_id)
        delete_task_success = self.dal.task.delete(task_id)
        return stopped_success and delete_task_success

    def stop(self, task_id=None, all=False, status="STOPPED"):
        """Stop and remove run for the task and update task object statuses

        Parameters
        ----------
        task_id : str, optional
            id for the task you would like to stop
        all : bool, optional
            if specified, will stop all tasks within project

        Returns
        -------
        return_code : bool
            system return code of the stop

        Raises
        ------
        RequiredArgumentMissing
        TooManyArgumentsFound
        """
        if task_id is None and all is False:
            raise RequiredArgumentMissing(
                __("error", "controller.task.stop.arg.missing", "id"))
        if task_id and all:
            raise TooManyArgumentsFound()
        if task_id:
            try:
                task_obj = self.get(task_id)
            except DoesNotExist:
                time.sleep(1)
                task_obj = self.get(task_id)
            task_match_string = "datmo-task-" + self.model.id + "-" + task_id
            # Get the environment id associated with the task
            kwargs = {'match_string': task_match_string}
            # Get the environment from the task
            before_snapshot_id = task_obj.before_snapshot_id
            after_snapshot_id = task_obj.after_snapshot_id
            if not before_snapshot_id and not after_snapshot_id:
                # TODO: remove...for now database may not be in sync. no task that has run can have NO before_snapshot_id
                time.sleep(1)
                task_obj = self.get(task_id)
            if after_snapshot_id:
                after_snapshot_obj = self.snapshot.get(after_snapshot_id)
                kwargs['environment_id'] = after_snapshot_obj.environment_id
            if not after_snapshot_id and before_snapshot_id:
                before_snapshot_obj = self.snapshot.get(before_snapshot_id)
                kwargs['environment_id'] = before_snapshot_obj.environment_id
            return_code = self.environment.stop(**kwargs)
        if all:
            return_code = self.environment.stop(all=True)
        # Set stopped task statuses to STOPPED if return success
        if return_code:
            if task_id:
                self.dal.task.update({"id": task_id, "status": status})
            if all:
                task_objs = self.dal.task.query({})
                for task_obj in task_objs:
                    self.dal.task.update({"id": task_obj.id, "status": status})

        return return_code
예제 #5
0
class TestSnapshotController():
    def setup_method(self):
        """Prepare an isolated datmo home for the upcoming test.

        A fresh temporary directory is created under ``test_datmo_dir``,
        ``Config`` is pointed at it, and the list of environment ids that
        ``teardown_method`` deletes is reset.
        """
        scratch_dir = tempfile.mkdtemp(dir=test_datmo_dir)
        self.temp_dir = scratch_dir
        Config().set_home(scratch_dir)
        # Tests append the ids of environments they build to this list
        self.environment_ids = list()

    def teardown_method(self):
        """Delete every environment recorded in ``self.environment_ids``.

        Nothing is done when ``check_docker_inactive`` returns a truthy
        value. Otherwise the controllers are re-created and each distinct
        environment id is deleted.

        Raises
        ------
        Exception
            if deleting any of the environments returns a falsy value
        """
        docker_inactive = check_docker_inactive(test_datmo_dir,
                                                Config().datmo_directory_name)
        if docker_inactive:
            return
        self.__setup()
        self.environment_controller = EnvironmentController()
        # De-duplicate so each environment is deleted only once
        unique_environment_ids = list(set(self.environment_ids))
        for environment_id in unique_environment_ids:
            deleted = self.environment_controller.delete(environment_id)
            if not deleted:
                raise Exception

    def __setup(self):
        """Initialize a test project and the controllers the tests use.

        Points ``Config`` at the temporary home, initializes a project named
        "test" and stores project, task and snapshot controllers on ``self``.
        """
        Config().set_home(self.temp_dir)
        project_controller = ProjectController()
        self.project_controller = project_controller
        project_controller.init("test", "test description")
        # These controllers are built only after the project has been initialized
        self.task_controller = TaskController()
        self.snapshot_controller = SnapshotController()

    def test_init_fail_project_not_init(self):
        """SnapshotController() must raise ProjectNotInitialized when the
        home directory holds no initialized project."""
        Config().set_home(self.temp_dir)
        try:
            SnapshotController()
        except ProjectNotInitialized:
            raised = True
        else:
            raised = False
        assert raised

    def test_init_fail_invalid_path(self):
        """SnapshotController() must raise InvalidProjectPath when the
        configured home is "some_random_dir"."""
        Config().set_home("some_random_dir")
        try:
            SnapshotController()
        except InvalidProjectPath:
            raised = True
        else:
            raised = False
        assert raised

    def test_current_snapshot(self):
        """current_snapshot() raises UnstagedChanges before any snapshot is
        created and returns the created snapshot afterwards."""
        self.__setup()

        # Before a snapshot exists the call is expected to fail
        try:
            self.snapshot_controller.current_snapshot()
        except UnstagedChanges:
            raised = True
        else:
            raised = False
        assert raised

        # After the default snapshot is created it must be the current one
        created_snapshot = self.__default_create()
        assert self.snapshot_controller.current_snapshot() == created_snapshot

    def test_create_fail_no_message(self):
        """create() with an empty dictionary (no message) must raise
        RequiredArgumentMissing."""
        self.__setup()
        empty_input = {}
        try:
            self.snapshot_controller.create(empty_input)
        except RequiredArgumentMissing:
            raised = True
        else:
            raised = False
        assert raised

    def test_create_success_no_code(self):
        """create() returns a truthy result when only a message is given
        and no files have been added to the freshly initialized project."""
        self.__setup()
        create_input = {"message": "my test snapshot"}
        snapshot = self.snapshot_controller.create(create_input)
        assert snapshot

    def test_create_success_no_code_environment(self):
        """create() returns a truthy result when the only file written is a
        Dockerfile in the environment directory."""
        self.__setup()

        # Write the environment definition
        environment_driver = self.project_controller.environment_driver
        dockerfile_path = os.path.join(
            environment_driver.environment_directory_path, "Dockerfile")
        with open(dockerfile_path, "wb") as dockerfile:
            dockerfile.write(to_bytes("FROM python:3.5-alpine"))

        # Snapshot creation must succeed with only that definition present
        create_input = {"message": "my test snapshot"}
        snapshot = self.snapshot_controller.create(create_input)
        assert snapshot

    def test_create_success_no_code_environment_files(self):
        """create() returns a truthy result when a Dockerfile exists in the
        environment directory and a text file exists in the files directory."""
        self.__setup()

        # Write the environment definition
        environment_driver = self.project_controller.environment_driver
        dockerfile_path = os.path.join(
            environment_driver.environment_directory_path, "Dockerfile")
        with open(dockerfile_path, "wb") as dockerfile:
            dockerfile.write(to_bytes("FROM python:3.5-alpine"))

        # Write a file into the tracked files directory
        files_directory = self.project_controller.file_driver.files_directory
        tracked_file_path = os.path.join(files_directory, "test.txt")
        with open(tracked_file_path, "wb") as tracked_file:
            tracked_file.write(to_bytes(str("hello")))

        # Snapshot creation must succeed with both files present
        create_input = {"message": "my test snapshot"}
        snapshot = self.snapshot_controller.create(create_input)
        assert snapshot

    def test_create_no_environment_detected_in_file(self):
        """create() still yields a complete Snapshot when the project only
        contains an empty file and no environment definition was written."""
        self.__setup()

        # The only project content is an empty file created by the file driver
        self.snapshot_controller.file_driver.create("filepath1")
        create_input = {"message": "my test snapshot"}
        snapshot = self.snapshot_controller.create(create_input)

        assert isinstance(snapshot, Snapshot)
        # Every component id must be populated
        for attribute in ("code_id", "environment_id", "file_collection_id"):
            assert getattr(snapshot, attribute)
        # No config or stats were supplied, so both are empty
        assert snapshot.config == {}
        assert snapshot.stats == {}

    def test_create_success_default_detected_in_file(self):
        """create() yields a complete Snapshot when the project home holds a
        python script and no environment definition was written."""
        self.__setup()

        # Write a small script into the project home
        script_path = os.path.join(self.snapshot_controller.home, "script.py")
        script_lines = ("import os\n", "import sys\n", "print('hello')\n")
        with open(script_path, "wb") as script:
            for line in script_lines:
                script.write(to_bytes(line))

        create_input = {"message": "my test snapshot"}
        snapshot = self.snapshot_controller.create(create_input)

        assert isinstance(snapshot, Snapshot)
        # Every component id must be populated
        for attribute in ("code_id", "environment_id", "file_collection_id"):
            assert getattr(snapshot, attribute)
        # No config or stats were supplied, so both are empty
        assert snapshot.config == {}
        assert snapshot.stats == {}

    def test_create_success_default_env_def(self):
        """create() yields a complete Snapshot when a Dockerfile exists in
        the environment directory and a script exists in the project home."""
        self.__setup()

        # Write the environment definition
        environment_driver = self.project_controller.environment_driver
        dockerfile_path = os.path.join(
            environment_driver.environment_directory_path, "Dockerfile")
        with open(dockerfile_path, "wb") as dockerfile:
            dockerfile.write(to_bytes("FROM python:3.5-alpine"))

        # Write a script into the project home
        script_path = os.path.join(self.snapshot_controller.home, "script.py")
        script_lines = ("import numpy\n", "import sklearn\n",
                        "print('hello')\n")
        with open(script_path, "wb") as script:
            for line in script_lines:
                script.write(to_bytes(line))

        # Create the snapshot with only a message
        create_input = {"message": "my test snapshot"}
        snapshot = self.snapshot_controller.create(create_input)

        assert isinstance(snapshot, Snapshot)
        for attribute in ("code_id", "environment_id", "file_collection_id"):
            assert getattr(snapshot, attribute)
        assert snapshot.config == {}
        assert snapshot.stats == {}

    def test_create_success_with_environment(self):
        """create() yields a complete Snapshot when both an environment
        definition and a project script are present."""
        self.__setup()

        def write_lines(path, lines):
            # Write each text line to ``path`` as bytes
            with open(path, "wb") as handle:
                for line in lines:
                    handle.write(to_bytes(line))

        # Environment definition inside the environment directory
        environment_directory = self.project_controller.environment_driver.\
            environment_directory_path
        write_lines(
            os.path.join(environment_directory, "Dockerfile"),
            ["FROM python:3.5-alpine"])

        # Script inside the project home
        write_lines(
            os.path.join(self.snapshot_controller.home, "script.py"),
            ["import numpy\n", "import sklearn\n", "print('hello')\n"])

        # Create the snapshot with only a message
        created = self.snapshot_controller.create(
            {"message": "my test snapshot"})

        assert isinstance(created, Snapshot)
        assert created.code_id
        assert created.environment_id
        assert created.file_collection_id
        assert created.config == {}
        assert created.stats == {}

    def test_create_success_env_paths(self):
        """create() yields a complete Snapshot when the environment
        definition is passed through ``environment_paths`` using the
        "<source>>Dockerfile" form."""
        self.__setup()

        # Put the definition in an arbitrary directory under the project home
        custom_dir = os.path.join(self.snapshot_controller.home, "random_dir")
        os.makedirs(custom_dir)
        custom_dockerfile = os.path.join(custom_dir, "randomDockerfile")
        with open(custom_dockerfile, "wb") as dockerfile:
            dockerfile.write(to_bytes("FROM python:3.5-alpine"))

        create_input = {
            "message": "my test snapshot",
            "environment_paths": [custom_dockerfile + ">Dockerfile"],
        }
        snapshot = self.snapshot_controller.create(create_input)

        assert isinstance(snapshot, Snapshot)
        for attribute in ("code_id", "environment_id", "file_collection_id"):
            assert getattr(snapshot, attribute)
        assert snapshot.config == {}
        assert snapshot.stats == {}

    def test_create_success_default_env_def_duplicate(self):
        """Two create() calls with identical inputs and unchanged project
        content must return snapshots with the same id and components."""
        self.__setup()

        # Write the environment definition
        environment_driver = self.project_controller.environment_driver
        dockerfile_path = os.path.join(
            environment_driver.environment_directory_path, "Dockerfile")
        with open(dockerfile_path, "wb") as dockerfile:
            dockerfile.write(to_bytes("FROM python:3.5-alpine"))

        # Write a script into the project home
        script_path = os.path.join(self.snapshot_controller.home, "script.py")
        with open(script_path, "wb") as script:
            for line in ("import numpy\n", "import sklearn\n",
                         "print('hello')\n"):
                script.write(to_bytes(line))

        first_snapshot = self.snapshot_controller.create(
            {"message": "my test snapshot"})
        second_snapshot = self.snapshot_controller.create(
            {"message": "my test snapshot"})

        # The second call should hand back the same object
        compared_attributes = ("id", "code_id", "environment_id",
                               "file_collection_id", "config", "stats")
        for attribute in compared_attributes:
            assert getattr(second_snapshot, attribute) == \
                   getattr(first_snapshot, attribute)

    def test_create_success_given_files_env_def_config_file_stats_file(self):
        """create() reads config and stats from the files given by
        ``config_filepath`` / ``stats_filepath`` and produces a snapshot that
        differs from one created before those files were added."""
        self.__setup()

        def write_lines(path, lines):
            # Write each text line to ``path`` as bytes
            with open(path, "wb") as handle:
                for line in lines:
                    handle.write(to_bytes(line))

        project_home = self.snapshot_controller.home

        # Environment definition and a script, then a baseline snapshot
        environment_directory = self.project_controller.environment_driver.\
            environment_directory_path
        write_lines(
            os.path.join(environment_directory, "Dockerfile"),
            ["FROM python:3.5-alpine"])
        write_lines(
            os.path.join(project_home, "script.py"),
            ["import numpy\n", "import sklearn\n", "print('hello')\n"])
        baseline_snapshot = self.snapshot_controller.create(
            {"message": "my test snapshot"})

        # Add two directories and a file under the relative files directory
        file_driver = self.project_controller.file_driver
        files_directory_name = os.path.split(file_driver.files_directory)[1]
        relative_files_dir = os.path.join(file_driver.datmo_directory_name,
                                          files_directory_name)
        for directory_name in ("dirpath1", "dirpath2"):
            self.snapshot_controller.file_driver.create(
                os.path.join(relative_files_dir, directory_name),
                directory=True)
        self.snapshot_controller.file_driver.create(
            os.path.join(relative_files_dir, "filepath1"))

        # Config and stats files in the project home
        config_filepath = os.path.join(project_home, "config.json")
        write_lines(config_filepath, [str('{"foo":"bar"}')])
        stats_filepath = os.path.join(project_home, "stats.json")
        write_lines(stats_filepath, [str('{"foo":"bar"}')])

        # Create the snapshot pointing at those files
        snapshot_with_files = self.snapshot_controller.create({
            "message": "my test snapshot",
            "config_filepath": config_filepath,
            "stats_filepath": stats_filepath,
        })

        assert snapshot_with_files != baseline_snapshot
        assert snapshot_with_files.code_id != baseline_snapshot.code_id
        # The environment definition was not touched between the snapshots
        assert snapshot_with_files.environment_id == \
               baseline_snapshot.environment_id
        assert snapshot_with_files.file_collection_id != \
               baseline_snapshot.file_collection_id
        assert snapshot_with_files.config == {"foo": "bar"}
        assert snapshot_with_files.stats == {"foo": "bar"}

    def test_create_success_given_files_env_def_different_config_stats(self):
        """create() returns empty config and stats when ``config_filename`` /
        ``stats_filename`` are "different_name" rather than the names of the
        config.json and stats.json files written to the project home."""
        self.__setup()

        def write_lines(path, lines):
            # Write each text line to ``path`` as bytes
            with open(path, "wb") as handle:
                for line in lines:
                    handle.write(to_bytes(line))

        project_home = self.snapshot_controller.home

        # Environment definition and a script, then a baseline snapshot
        environment_directory = self.project_controller.environment_driver.\
            environment_directory_path
        write_lines(
            os.path.join(environment_directory, "Dockerfile"),
            ["FROM python:3.5-alpine"])
        write_lines(
            os.path.join(project_home, "script.py"),
            ["import numpy\n", "import sklearn\n", "print('hello')\n"])
        baseline_snapshot = self.snapshot_controller.create(
            {"message": "my test snapshot"})

        # Add two directories and a file under the relative files directory
        file_driver = self.project_controller.file_driver
        files_directory_name = os.path.split(file_driver.files_directory)[1]
        relative_files_dir = os.path.join(file_driver.datmo_directory_name,
                                          files_directory_name)
        for directory_name in ("dirpath1", "dirpath2"):
            self.snapshot_controller.file_driver.create(
                os.path.join(relative_files_dir, directory_name),
                directory=True)
        self.snapshot_controller.file_driver.create(
            os.path.join(relative_files_dir, "filepath1"))

        # Config and stats files exist in the project home ...
        write_lines(
            os.path.join(project_home, "config.json"),
            [str('{"foo":"bar"}')])
        write_lines(
            os.path.join(project_home, "stats.json"),
            [str('{"foo":"bar"}')])

        # ... but the snapshot is asked to look for differently named files
        second_snapshot = self.snapshot_controller.create({
            "message": "my test snapshot",
            "config_filename": "different_name",
            "stats_filename": "different_name",
        })

        assert second_snapshot != baseline_snapshot
        assert second_snapshot.config == {}
        assert second_snapshot.stats == {}

    def test_create_success_given_files_env_def_direct_config_stats(self):
        """create() stores ``config`` and ``stats`` dictionaries that are
        passed directly in the input dictionary."""
        self.__setup()

        # Write the environment definition
        environment_driver = self.project_controller.environment_driver
        dockerfile_path = os.path.join(
            environment_driver.environment_directory_path, "Dockerfile")
        with open(dockerfile_path, "wb") as dockerfile:
            dockerfile.write(to_bytes("FROM python:3.5-alpine"))

        # Add two directories and a file under the relative files directory
        file_driver = self.project_controller.file_driver
        files_directory_name = os.path.split(file_driver.files_directory)[1]
        relative_files_dir = os.path.join(file_driver.datmo_directory_name,
                                          files_directory_name)
        for directory_name in ("dirpath1", "dirpath2"):
            self.snapshot_controller.file_driver.create(
                os.path.join(relative_files_dir, directory_name),
                directory=True)
        self.snapshot_controller.file_driver.create(
            os.path.join(relative_files_dir, "filepath1"))

        # Write a script into the project home
        script_path = os.path.join(self.snapshot_controller.home, "script.py")
        with open(script_path, "wb") as script:
            for line in ("import numpy\n", "import sklearn\n",
                         "print('hello')\n"):
                script.write(to_bytes(line))

        # Pass config and stats inline rather than through files
        create_input = {
            "message": "my test snapshot",
            "config": {"foo": "bar"},
            "stats": {"foo": "bar"},
        }
        snapshot = self.snapshot_controller.create(create_input)

        assert snapshot.config == {"foo": "bar"}
        assert snapshot.stats == {"foo": "bar"}

    @pytest_docker_environment_failed_instantiation(test_datmo_dir)
    def test_create_from_task(self):
        """Exercise ``create_from_task`` across five scenarios.

        0) raises TaskNotComplete for a task that has not been run
        1) succeeds for a completed task with a plain echo command
        2) succeeds for a completed task whose command prints a result
        3) succeeds with message, label, config and stats given explicitly
        4) succeeds with stats merged from the snapshot_dict given to the run
           and the results of that same task

        Every environment built by a task run is recorded in
        ``self.environment_ids`` so ``teardown_method`` can delete it.
        """
        self.__setup()

        # Create task in the project
        task_obj = self.task_controller.create()

        # 0) Test option 0: the task has not been run yet
        failed = False
        try:
            _ = self.snapshot_controller.create_from_task(
                message="my test snapshot", task_id=task_obj.id)
        except TaskNotComplete:
            failed = True
        assert failed

        # 1) Test option 1

        # Create task_dict
        task_command = ["sh", "-c", "echo test"]
        task_dict = {"command_list": task_command}

        # Create environment definition
        env_def_path = os.path.join(self.project_controller.home, "Dockerfile")
        with open(env_def_path, "wb") as f:
            f.write(to_bytes("FROM python:3.5-alpine"))

        updated_task_obj = self.task_controller.run(task_obj.id,
                                                    task_dict=task_dict)
        # Record the environment of the after snapshot for teardown
        after_snapshot_obj = self.task_controller.dal.snapshot.get_by_id(
            updated_task_obj.after_snapshot_id)
        environment_obj = self.task_controller.dal.environment.get_by_id(
            after_snapshot_obj.environment_id)
        self.environment_ids.append(environment_obj.id)

        snapshot_obj = self.snapshot_controller.create_from_task(
            message="my test snapshot", task_id=updated_task_obj.id)

        assert isinstance(snapshot_obj, Snapshot)
        assert snapshot_obj.id == updated_task_obj.after_snapshot_id
        assert snapshot_obj.message == "my test snapshot"
        assert snapshot_obj.stats == updated_task_obj.results
        assert snapshot_obj.visible == True

        # Create new task and corresponding dict
        task_obj = self.task_controller.create()
        task_command = ["sh", "-c", "echo accuracy:0.45"]
        task_dict = {"command_list": task_command}

        # Create environment definition
        env_def_path = os.path.join(self.project_controller.home, "Dockerfile")
        with open(env_def_path, "wb") as f:
            f.write(to_bytes("FROM python:3.5-alpine"))

        # Test the default values
        updated_task_obj = self.task_controller.run(task_obj.id,
                                                    task_dict=task_dict)
        # Record the environment of the after snapshot for teardown
        after_snapshot_obj = self.task_controller.dal.snapshot.get_by_id(
            updated_task_obj.after_snapshot_id)
        environment_obj = self.task_controller.dal.environment.get_by_id(
            after_snapshot_obj.environment_id)
        self.environment_ids.append(environment_obj.id)

        # 2) Test option 2
        snapshot_obj = self.snapshot_controller.create_from_task(
            message="my test snapshot", task_id=updated_task_obj.id)

        assert isinstance(snapshot_obj, Snapshot)
        assert snapshot_obj.id == updated_task_obj.after_snapshot_id
        assert snapshot_obj.message == "my test snapshot"
        assert snapshot_obj.stats == updated_task_obj.results
        assert snapshot_obj.visible == True

        # 3) Test option 3: explicit label, config and stats
        test_config = {"algo": "regression"}
        test_stats = {"accuracy": 0.9}
        snapshot_obj = self.snapshot_controller.create_from_task(
            message="my test snapshot",
            task_id=updated_task_obj.id,
            label="best",
            config=test_config,
            stats=test_stats)

        assert isinstance(snapshot_obj, Snapshot)
        assert snapshot_obj.id == updated_task_obj.after_snapshot_id
        assert snapshot_obj.message == "my test snapshot"
        assert snapshot_obj.label == "best"
        assert snapshot_obj.config == test_config
        assert snapshot_obj.stats == test_stats
        assert snapshot_obj.visible == True

        # 4) Test option 4: stats supplied through the run's snapshot_dict
        test_config = {"algo": "regression"}
        test_stats = {"new_key": 0.9}
        task_obj_2 = self.task_controller.create()
        updated_task_obj_2 = self.task_controller.run(task_obj_2.id,
                                                      task_dict=task_dict,
                                                      snapshot_dict={
                                                          "config":
                                                          test_config,
                                                          "stats": test_stats
                                                      })
        # Record the environment of the after snapshot for teardown
        after_snapshot_obj = self.task_controller.dal.snapshot.get_by_id(
            updated_task_obj_2.after_snapshot_id)
        environment_obj = self.task_controller.dal.environment.get_by_id(
            after_snapshot_obj.environment_id)
        self.environment_ids.append(environment_obj.id)

        snapshot_obj = self.snapshot_controller.create_from_task(
            message="my test snapshot",
            task_id=updated_task_obj_2.id,
            label="best")
        # Expected stats: the stats given to the run, overlaid with the
        # results of the task that was actually snapshotted (task 2)
        updated_stats_dict = {}
        updated_stats_dict.update(test_stats)
        updated_stats_dict.update(updated_task_obj_2.results)

        assert isinstance(snapshot_obj, Snapshot)
        assert snapshot_obj.id == updated_task_obj_2.after_snapshot_id
        assert snapshot_obj.message == "my test snapshot"
        assert snapshot_obj.label == "best"
        assert snapshot_obj.stats == updated_stats_dict
        assert snapshot_obj.visible == True

    def __default_create(self):
        """Populate the project with default content and snapshot it.

        Creates two directories and a file under the relative files
        directory, a "filepath2" script, a Dockerfile in the environment
        directory and empty-JSON config/stats files in the project home.

        Returns
        -------
        the value returned by ``self.snapshot_controller.create`` for the
        message "my test snapshot" and the config/stats paths
        """
        project_home = self.snapshot_controller.home

        # Files and directories to add
        file_driver = self.project_controller.file_driver
        files_directory_name = os.path.split(file_driver.files_directory)[1]
        relative_files_dir = os.path.join(file_driver.datmo_directory_name,
                                          files_directory_name)
        for directory_name in ("dirpath1", "dirpath2"):
            self.snapshot_controller.file_driver.create(
                os.path.join(relative_files_dir, directory_name),
                directory=True)
        self.snapshot_controller.file_driver.create(
            os.path.join(relative_files_dir, "filepath1"))
        self.snapshot_controller.file_driver.create("filepath2")
        script_path = os.path.join(project_home, "filepath2")
        with open(script_path, "wb") as script:
            script.write(to_bytes(str("import sys\n")))

        # Environment definition
        environment_driver = self.project_controller.environment_driver
        dockerfile_path = os.path.join(
            environment_driver.environment_directory_path, "Dockerfile")
        with open(dockerfile_path, "wb") as dockerfile:
            dockerfile.write(to_bytes("FROM python:3.5-alpine"))

        # Config file holding an empty JSON object
        config_filepath = os.path.join(project_home, "config.json")
        with open(config_filepath, "wb") as config_file:
            config_file.write(to_bytes(str("{}")))

        # Stats file holding an empty JSON object
        stats_filepath = os.path.join(project_home, "stats.json")
        with open(stats_filepath, "wb") as stats_file:
            stats_file.write(to_bytes(str("{}")))

        # Create snapshot in the project
        return self.snapshot_controller.create({
            "message": "my test snapshot",
            "config_filename": config_filepath,
            "stats_filename": stats_filepath,
        })

    def test_check_unstaged_changes(self):
        """check_unstaged_changes() raises UnstagedChanges before any
        snapshot exists and returns False right after one is created."""
        self.__setup()

        # Nothing has been snapshotted yet, so the check must raise
        try:
            self.snapshot_controller.check_unstaged_changes()
        except UnstagedChanges:
            raised = True
        else:
            raised = False
        assert raised

        # Directly after a snapshot there are no unstaged changes
        self.__default_create()
        has_unstaged = self.snapshot_controller.check_unstaged_changes()
        assert has_unstaged == False

    def test_checkout(self):
        """checkout() returns True when given the id of an earlier snapshot
        after a second, different snapshot has been created."""
        self.__setup()

        first_snapshot = self.__default_create()

        # Add another file so the second snapshot differs from the first
        self.snapshot_controller.file_driver.create("test")
        second_snapshot = self.__default_create()
        assert second_snapshot != first_snapshot

        # Check out the first snapshot by id
        # TODO: Check for which snapshot we are on
        checkout_result = self.snapshot_controller.checkout(first_snapshot.id)
        assert checkout_result == True

    def test_list(self):
        """list() returns both created snapshots, honours ascending and
        descending ``created_at`` ordering, raises InvalidArgumentType for an
        invalid sort order, and filters on ``visible``."""
        self.__setup()

        def write_text(filename, text):
            # Add a file to the project home so the next snapshot differs
            path = os.path.join(self.snapshot_controller.home, filename)
            with open(path, "wb") as handle:
                handle.write(to_bytes(str(text)))

        def assert_both_listed(listing):
            # Exactly the two snapshots created below must be present
            assert len(listing) == 2
            assert first_snapshot in listing
            assert second_snapshot in listing

        # Create two snapshots, each preceded by a new file
        write_text("test.txt", "test")
        first_snapshot = self.__default_create()
        write_text("test2.txt", "test2")
        second_snapshot = self.__default_create()

        # List all snapshots and ensure they exist
        assert_both_listed(self.snapshot_controller.list())

        # List all snapshots in ascending order of creation
        ascending = self.snapshot_controller.list(sort_key='created_at',
                                                  sort_order='ascending')
        assert_both_listed(ascending)
        assert ascending[0].created_at <= ascending[-1].created_at

        # List all snapshots in descending order of creation
        descending = self.snapshot_controller.list(sort_key='created_at',
                                                   sort_order='descending')
        assert_both_listed(descending)
        assert descending[0].created_at >= descending[-1].created_at

        # A wrong order must be rejected with either a right or a wrong key
        for sort_key in ('created_at', 'wrong_key'):
            try:
                _ = self.snapshot_controller.list(sort_key=sort_key,
                                                  sort_order='wrong_order')
            except InvalidArgumentType:
                raised = True
            else:
                raised = False
            assert raised

        # A wrong key with a right order returns the same set of snapshots
        expected_listing = self.snapshot_controller.list(
            sort_key='created_at', sort_order='ascending')
        wrong_key_listing = self.snapshot_controller.list(
            sort_key='wrong_key', sort_order='ascending')
        expected_ids = {item.id for item in expected_listing}
        actual_ids = {item.id for item in wrong_key_listing}
        assert expected_ids == actual_ids

        # List snapshots with visible filter
        hidden = self.snapshot_controller.list(visible=False)
        assert len(hidden) == 0

        assert_both_listed(self.snapshot_controller.list(visible=True))

    def test_update(self):
        """update() persists config, stats, message and label, whether they
        are given all together, in pairs, or one at a time."""
        self.__setup()
        new_config = {"config_foo": "bar"}
        new_stats = {"stats_foo": "bar"}
        new_message = 'test_message'
        new_label = 'test_label'

        def reload(snapshot_id):
            # Fetch the stored snapshot straight from the data access layer
            return self.snapshot_controller.dal.snapshot.get_by_id(
                snapshot_id)

        # Updating all of config, stats, message and label
        snapshot = self.__default_create()
        self.snapshot_controller.update(snapshot.id,
                                        config=new_config,
                                        stats=new_stats,
                                        message=new_message,
                                        label=new_label)
        stored = reload(snapshot.id)
        assert stored.config == new_config
        assert stored.stats == new_stats
        assert stored.message == new_message
        assert stored.label == new_label

        # Updating config and stats
        snapshot = self.__default_create()
        self.snapshot_controller.update(snapshot.id,
                                        config=new_config,
                                        stats=new_stats)
        stored = reload(snapshot.id)
        assert stored.config == new_config
        assert stored.stats == new_stats

        # Updating both message and label
        snapshot = self.__default_create()
        self.snapshot_controller.update(snapshot.id,
                                        message=new_message,
                                        label=new_label)
        stored = reload(snapshot.id)
        assert stored.message == new_message
        assert stored.label == new_label

        # Updating only message
        message_snapshot = self.__default_create()
        self.snapshot_controller.update(message_snapshot.id,
                                        message=new_message)
        stored_message_snapshot = reload(message_snapshot.id)
        assert stored_message_snapshot.message == new_message

        # Updating only label
        label_snapshot = self.__default_create()
        self.snapshot_controller.update(label_snapshot.id, label=new_label)
        stored_label_snapshot = reload(label_snapshot.id)
        assert stored_label_snapshot.label == new_label

    def test_get(self):
        """Check `get` for both an unknown id and a freshly created snapshot.

        An id that was never created must raise `DoesNotExist`; a snapshot
        created through the default helper must compare equal to the object
        returned by `get` for its id.
        """
        self.__setup()

        # Failure path: looking up an id that does not exist must raise
        try:
            self.snapshot_controller.get("random")
        except DoesNotExist:
            raised = True
        else:
            raised = False
        assert raised

        # Success path: the created snapshot is returned for its own id
        created = self.__default_create()
        fetched = self.snapshot_controller.get(created.id)
        assert created == fetched

    def test_get_files(self):
        """Check `get_files` for an unknown id and for a created snapshot.

        The failure case expects `DoesNotExist` for an id that was never
        created. The success case calls `get_files` once without a mode
        (expecting handles opened with mode "r") and once with `mode="a"`,
        and for each call verifies the number of handles, their type, their
        mode and that the expected collection path is among the returned
        file names.
        """
        self.__setup()
        # Test failure case
        failed = False
        try:
            self.snapshot_controller.get_files("random")
        except DoesNotExist:
            failed = True
        assert failed

        # Test success case
        snapshot_obj = self.__default_create()
        file_collection_obj = self.task_controller.dal.file_collection.get_by_id(
            snapshot_obj.file_collection_id)
        expected_path = os.path.join(self.task_controller.home, ".datmo",
                                     "collections",
                                     file_collection_obj.filehash, "filepath1")

        # First call relies on the default mode, second one requests append
        for extra_kwargs, expected_mode in (({}, "r"), ({"mode": "a"}, "a")):
            result = self.snapshot_controller.get_files(
                snapshot_obj.id, **extra_kwargs)
            try:
                assert len(result) == 1
                for item in result:
                    assert isinstance(item, TextIOWrapper)
                    assert item.mode == expected_mode
                # Names are recomputed for every call so that each result is
                # checked against its own file list, not a stale one
                file_names = [item.name for item in result]
                assert expected_path in file_names
            finally:
                # Release the handles opened by get_files so they do not leak
                for item in result:
                    item.close()

    def test_delete(self):
        """Check that a deleted snapshot can no longer be fetched.

        Creates a snapshot, deletes it through the controller, and then
        expects the DAL lookup by id to raise `EntityNotFound`. The value
        returned by `delete` must compare equal to True.
        """
        self.__setup()

        # Create, then remove, a snapshot in the project
        created = self.__default_create()
        delete_result = self.snapshot_controller.delete(created.id)

        # Fetching the removed snapshot straight from the DAL must fail
        try:
            self.snapshot_controller.dal.snapshot.get_by_id(created.id)
        except EntityNotFound:
            lookup_failed = True
        else:
            lookup_failed = False

        assert delete_result == True
        assert lookup_failed