Ejemplo n.º 1
0
 def stop(self, **kwargs):
     """Stop a single run by id, or all runs, and report the outcome.

     One of ``id`` / ``all`` must be present in ``kwargs``;
     ``mutually_exclusive`` copies the supplied one into a local dict.

     Returns
     -------
     The value returned by ``TaskController.stop`` on the normal path, or
     False if stopping raised an exception.

     Raises
     ------
     RequiredArgumentMissing
         if neither ``id`` nor ``all`` was supplied.
     """
     self.task_controller = TaskController()
     selected = {}
     mutually_exclusive(["id", "all"], kwargs, selected)
     by_id = "id" in selected
     stop_all = "all" in selected

     # Announce what is about to happen before doing any work
     if by_id:
         self.cli_helper.echo(__("info", "cli.run.stop", selected['id']))
     elif stop_all:
         self.cli_helper.echo(__("info", "cli.run.stop.all"))
     else:
         raise RequiredArgumentMissing()

     try:
         if by_id:
             run_id = selected['id']
             result = self.task_controller.stop(task_id=run_id)
             if result:
                 self.cli_helper.echo(
                     __("info", "cli.run.stop.success", run_id))
             else:
                 self.cli_helper.echo(__("error", "cli.run.stop", run_id))
         if stop_all:
             result = self.task_controller.stop(all=selected['all'])
             if result:
                 self.cli_helper.echo(
                     __("info", "cli.run.stop.all.success"))
             else:
                 self.cli_helper.echo(__("error", "cli.run.stop.all"))
         return result
     except Exception:
         # Failures while stopping are reported to the user, not propagated
         if by_id:
             self.cli_helper.echo(
                 __("error", "cli.run.stop", selected['id']))
         if stop_all:
             self.cli_helper.echo(__("error", "cli.run.stop.all"))
         return False
Ejemplo n.º 2
0
    def __init__(self, home, cli_helper):
        """Register the ``task`` command and its ``run``, ``ls`` and ``stop``
        subcommands on the shared argument parser.

        Parameters
        ----------
        home : str
            project home directory, forwarded to the base command and to
            ``TaskController``
        cli_helper
            CLI helper object, forwarded to the base command
        """
        super(TaskCommand, self).__init__(home, cli_helper)

        task_parser = self.subparsers.add_parser("task", help="Task module")
        subcommands = task_parser.add_subparsers(
            title="subcommands", dest="subcommand")

        # `task run` arguments
        run_parser = subcommands.add_parser("run", help="Run task")
        run_parser.add_argument(
            "--gpu", dest="gpu", action="store_true",
            help="Boolean if you want to run using GPUs")
        run_parser.add_argument("--ports", nargs="*", dest="ports", type=str, help="""
            Network port mapping during task (e.g. 8888:8888). Left is the host machine port and right
            is the environment port available during a run.
        """)
        run_parser.add_argument(
            "--env-def", dest="environment_definition_filepath", default="",
            nargs="?", type=str,
            help="Pass in the Dockerfile with which you want to build the environment")
        run_parser.add_argument(
            "--interactive", dest="interactive", action="store_true",
            help="Run the environment in interactive mode (keeps STDIN open)")
        run_parser.add_argument("cmd", nargs="?", default=None)

        # `task ls` arguments
        ls_parser = subcommands.add_parser("ls", help="List tasks")
        ls_parser.add_argument(
            "--session-id", dest="session_id", default=None, nargs="?",
            type=str,
            help="Pass in the session id to list the tasks in that session")

        # `task stop` arguments
        stop_parser = subcommands.add_parser("stop", help="Stop tasks")
        stop_parser.add_argument(
            "--id", dest="id", default=None, type=str, help="Task ID to stop")

        self.task_controller = TaskController(home=home)
Ejemplo n.º 3
0
def ls(session_id=None, filter=None):
    """List tasks within a project

    The project must be initialized before this is called. You can do that
    by using the following command::

        $ datmo init


    Parameters
    ----------
    session_id : str, optional
        session whose tasks are listed
        (default is None, which means the current session is used)
    filter : str, optional
        a substring matched against each task's command; only tasks whose
        command contains it are returned
        (default is None, which returns every task in the session)

    Returns
    -------
    list
        returns a list of Task entities (as defined above), newest first

    Examples
    --------
    You can use this function within a project repository to list tasks.

    >>> import datmo
    >>> tasks = datmo.task.ls()
    """
    task_controller = TaskController()

    # Fall back to the current session when none is given
    if not session_id:
        session_id = task_controller.current_session.id

    core_task_objs = task_controller.list(session_id,
                                          sort_key='created_at',
                                          sort_order='descending')

    # Filtering Tasks
    # TODO: move to list function in TaskController
    # Single pass keeps the listed (newest-first) order. A task whose command
    # is unset (None) cannot match, and `filter in None` would raise TypeError.
    if filter:
        core_task_objs = [
            core_task_obj for core_task_obj in core_task_objs
            if core_task_obj.command and filter in core_task_obj.command
        ]

    # Return Task entities
    return [Task(core_task_obj) for core_task_obj in core_task_objs]
Ejemplo n.º 4
0
 def ls(self, **kwargs):
     """Print the runs of a session and return them.

     Keyword arguments read: ``session_id`` (the current session is used
     only when the key is absent), ``format`` (default "table"),
     ``download`` and ``download_path``.

     Returns
     -------
     list
         RunObject instances when printing to the console; the raw task
         objects from ``TaskController.list`` when ``download`` is set.
         NOTE(review): the two branches return different types — confirm
         callers rely on this.
     """
     # Create controllers
     self.task_controller = TaskController()
     self.snapshot_controller = SnapshotController()
     session_id = kwargs.get('session_id',
                             self.task_controller.current_session.id)
     print_format = kwargs.get('format', "table")
     download = kwargs.get('download', None)
     download_path = kwargs.get('download_path', None)

     # Newest runs first
     task_objs = self.task_controller.list(session_id,
                                           sort_key="created_at",
                                           sort_order="descending")
     header_list = [
         "id", "command", "status", "config", "results", "created at"
     ]
     item_dict_list = []
     run_obj_list = []
     for task in task_objs:
         run = RunObject(task)
         results_text = printable_object(str(run.results))
         config_text = printable_object(str(run.config))
         item_dict_list.append({
             "id": run.id,
             "command": run.command,
             "status": run.status,
             "config": config_text,
             "results": results_text,
             "created at": prettify_datetime(run.created_at),
         })
         run_obj_list.append(run)

     if not download:
         self.cli_helper.print_items(header_list,
                                     item_dict_list,
                                     print_format=print_format)
         return run_obj_list

     if not download_path:
         # No path given: write into the current working directory, named
         # with the current UTC time in ms since the Unix epoch
         elapsed = datetime.utcnow() - datetime.utcfromtimestamp(0)
         download_path = os.path.join(
             os.getcwd(), "run_ls_" + str(elapsed.total_seconds() * 1000.0))
     self.cli_helper.print_items(header_list,
                                 item_dict_list,
                                 print_format=print_format,
                                 output_path=download_path)
     return task_objs
Ejemplo n.º 5
0
    def __get_core_task(self):
        """Fetch the core task entity backing this object.

        A new ``TaskController`` is created on every call and asked for the
        task whose id is ``self.id``.

        Returns
        -------
        datmo.core.entity.task.Task
            core task object of the task
        """
        controller = TaskController()
        core_task = controller.get(self.id)
        return core_task
Ejemplo n.º 6
0
 def setup_method(self):
     """Create a fresh temp project plus environment and task controllers."""
     # Docker needs a mountable temp directory: force /tmp off Windows; on
     # Windows, None lets tempfile pick its default location
     if platform.system() == "Windows":
         tempfile.tempdir = None
     else:
         tempfile.tempdir = "/tmp"
     base_dir = os.environ.get('TEST_DATMO_DIR', tempfile.gettempdir())
     self.temp_dir = tempfile.mkdtemp(dir=base_dir)
     self.project = ProjectController(self.temp_dir)
     self.project.init("test", "test description")
     self.environment = EnvironmentController(self.temp_dir)
     self.task = TaskController(self.temp_dir)
Ejemplo n.º 7
0
    def test_status_basic(self):
        """Project status before any snapshot and after running one task."""
        self.project_controller.init("test3", "test description")

        def check_project_info(info):
            # Project metadata must be reported the same before and after
            assert info
            assert isinstance(info, dict)
            assert info['name'] == "test3"
            assert info['description'] == "test description"
            assert isinstance(info['config'], dict)

        (status_dict, latest_snapshot_user_generated,
         latest_snapshot_auto_generated, unstaged_code,
         unstaged_environment, unstaged_files) = \
            self.project_controller.status()
        check_project_info(status_dict)
        assert not latest_snapshot_user_generated
        assert not latest_snapshot_auto_generated
        # No files yet, but code counts as unstaged because the blank commit
        # id has not been created (there is no initial snapshot)
        assert unstaged_code
        assert not unstaged_environment
        assert not unstaged_files

        self.task_controller = TaskController()

        # Create and run a task, then check its snapshot shows up in status
        first_task = self.task_controller.create()
        task_dict = {"command_list": ["sh", "-c", "echo accuracy:0.45"]}

        # Write an environment definition so the task can be snapshotted
        env_def_path = os.path.join(self.task_controller.home, "Dockerfile")
        with open(env_def_path, "wb") as f:
            f.write(to_bytes("FROM python:3.5-alpine"))

        updated_first_task = self.task_controller.run(first_task.id,
                                                      task_dict=task_dict)
        after_snapshot_obj = self.task_controller.dal.snapshot.get_by_id(
            updated_first_task.after_snapshot_id)
        environment_obj = self.task_controller.dal.environment.get_by_id(
            after_snapshot_obj.environment_id)
        # Recorded so teardown can clean up the environment
        self.environment_ids.append(environment_obj.id)

        (status_dict, latest_snapshot_user_generated,
         latest_snapshot_auto_generated, unstaged_code,
         unstaged_environment, unstaged_files) = \
            self.project_controller.status()
        check_project_info(status_dict)
        assert not latest_snapshot_user_generated
        assert latest_snapshot_auto_generated
        # Once the task has completed all state is saved, so no work is lost
        assert not unstaged_code
        assert not unstaged_environment
        assert not unstaged_files
Ejemplo n.º 8
0
 def ls(self, **kwargs):
     """Print the tasks of a session and return the listed task objects.

     Keyword arguments read: ``session_id`` (the current session is used
     only when the key is absent), ``format`` (default "table"),
     ``download`` and ``download_path``.
     """
     self.task_controller = TaskController()
     session_id = kwargs.get('session_id',
                             self.task_controller.current_session.id)
     print_format = kwargs.get('format', "table")
     download = kwargs.get('download', None)
     download_path = kwargs.get('download_path', None)

     # Newest tasks first
     task_objs = self.task_controller.list(session_id,
                                           sort_key='created_at',
                                           sort_order='descending')
     header_list = [
         "id", "start time", "duration (s)", "command", "status", "results"
     ]
     item_dict_list = []
     for task in task_objs:
         results_text = printable_object(task.results)
         item_dict_list.append({
             "id": task.id,
             "command": printable_object(task.command),
             "status": printable_object(task.status),
             "results": results_text,
             "start time": prettify_datetime(task.start_time),
             "duration (s)": printable_object(task.duration),
         })

     print_options = {"print_format": print_format}
     if download:
         if not download_path:
             # No path given: write into the project home directory, named
             # with the current UTC time in ms since the Unix epoch
             elapsed = datetime.utcnow() - datetime.utcfromtimestamp(0)
             download_path = os.path.join(
                 self.task_controller.home,
                 "task_ls_" + str(elapsed.total_seconds() * 1000.0))
         print_options["output_path"] = download_path
     self.cli_helper.print_items(header_list, item_dict_list, **print_options)
     return task_objs
Ejemplo n.º 9
0
    def rerun(self, **kwargs):
        """Re-run an existing task as a new task.

        The workspace is first checked out to a snapshot of the original run
        (its ``before_snapshot_id`` when the run type is 'script', otherwise
        its ``core_snapshot_id``). A new task is then run with the original
        task's settings through ``task_run_helper``.

        Returns
        -------
        Run or False
            the Run wrapping the new task, or False if the new task failed.
            The process exits with status 1 if the checkout raises.
        """
        self.task_controller = TaskController()
        task_id = kwargs.get("id", None)
        self.cli_helper.echo(__("info", "cli.run.rerun", task_id))

        task_obj = self.task_controller.get(task_id)
        run_obj = Run(task_obj)
        # Scripts restart from their initial (before) snapshot; any other
        # run type restarts from its core snapshot
        from_initial_state = run_obj.type == 'script'
        environment_id = run_obj.environment_id
        command = task_obj.command_list
        if from_initial_state:
            snapshot_id = run_obj.before_snapshot_id
        else:
            snapshot_id = run_obj.core_snapshot_id

        # Check out the chosen snapshot before launching the new task
        self.snapshot_controller = SnapshotController()
        try:
            checkout_success = self.snapshot_controller.checkout(snapshot_id)
        except Exception:
            self.cli_helper.echo(__("error", "cli.snapshot.checkout.failure"))
            sys.exit(1)
        if checkout_success:
            self.cli_helper.echo(
                __("info", "cli.snapshot.checkout.success", snapshot_id))

        # Inputs for the new task: same environment, same task settings
        snapshot_dict = {"environment_id": environment_id}
        task_dict = {
            "ports": task_obj.ports,
            "interactive": task_obj.interactive,
            "mem_limit": task_obj.mem_limit,
            "command_list": command,
            "data_file_path_map": task_obj.data_file_path_map,
            "data_directory_path_map": task_obj.data_directory_path_map,
            "workspace": task_obj.workspace
        }
        new_task_obj = self.task_run_helper(task_dict, snapshot_dict,
                                            "cli.run.run")
        if not new_task_obj:
            return False
        return Run(new_task_obj)
Ejemplo n.º 10
0
    def files(self, mode="r"):
        """Get the file objects that belong to this task.

        Parameters
        ----------
        mode : str
            mode for the returned file objects
            (default is "r", read mode)

        Returns
        -------
        list
            file objects returned by ``TaskController.get_files`` for this
            task, using the controller rooted at ``self._home``
        """
        controller = TaskController(home=self._home)
        task_files = controller.get_files(self.id, mode=mode)
        return task_files
Ejemplo n.º 11
0
 def test_init_fail_project_not_init(self):
     """TaskController raises when the project has not been initialized."""
     caught = None
     try:
         TaskController(self.temp_dir)
     except ProjectNotInitializedException as exc:
         caught = exc
     assert caught is not None
Ejemplo n.º 12
0
    def test_create_from_task(self):
        """Create snapshots from a run id.

        1) With message, label and config given, stats come from the task
           results.
        2) With stats also given, the user-supplied stats are used.
        """
        # TODO: test for failure case where the task is not complete

        # Setup: write an environment definition, then create and run a task
        env_def_path = os.path.join(self.temp_dir, "Dockerfile")
        with open(env_def_path, "wb") as f:
            f.write(to_bytes("FROM python:3.5-alpine"))

        task_controller = TaskController()
        task_obj = task_controller.create()
        # NOTE(review): if this string is shell-evaluated or whitespace-split,
        # `sh -c` receives only `echo` as its script (accuracy:0.45 becomes
        # $0), so no metric is printed — confirm whether quoting was intended.
        task_obj = task_controller.run(
            task_obj.id, task_dict={"command": "sh -c echo accuracy:0.45"})

        # 1) Stats taken from the task results
        snapshot_obj = create(
            message="my test snapshot",
            run_id=task_obj.id,
            label="best",
            config={"foo": "bar"})

        assert isinstance(snapshot_obj, Snapshot)
        assert snapshot_obj.message == "my test snapshot"
        assert snapshot_obj.label == "best"
        assert len(snapshot_obj.files) == 1
        assert "task.log" in snapshot_obj.files[0].name
        assert snapshot_obj.config == {"foo": "bar"}
        assert snapshot_obj.stats == task_obj.results

        # 2) User-given stats are used instead
        snapshot_obj_2 = create(
            message="my test snapshot",
            run_id=task_obj.id,
            label="best",
            config={"foo": "bar"},
            stats={"foo": "bar"})

        # These previously re-checked snapshot_obj (copy-paste slip); they
        # must verify the second snapshot
        assert isinstance(snapshot_obj_2, Snapshot)
        assert snapshot_obj_2.message == "my test snapshot"
        assert snapshot_obj_2.label == "best"
        assert len(snapshot_obj_2.files) == 1
        assert "task.log" in snapshot_obj_2.files[0].name
        assert snapshot_obj_2.config == {"foo": "bar"}
        assert snapshot_obj_2.stats == {"foo": "bar"}
Ejemplo n.º 13
0
 def test_init_fail_project_not_init(self):
     """TaskController() raises when the configured home has no project."""
     Config().set_home(self.temp_dir)
     caught = None
     try:
         TaskController()
     except ProjectNotInitialized as exc:
         caught = exc
     assert caught is not None
Ejemplo n.º 14
0
 def test_init_fail_invalid_path(self):
     """TaskController rejects an invalid project path."""
     caught = None
     try:
         TaskController("some_random_dir")
     except InvalidProjectPathException as exc:
         caught = exc
     assert caught is not None
Ejemplo n.º 15
0
 def delete(self, **kwargs):
     """Delete the task that backs a run and report the outcome.

     Parameters
     ----------
     id : str
         (keyword argument) id of the task to delete; required

     Returns
     -------
     The value returned by ``TaskController.delete``, or False if the
     deletion raised an exception.

     Raises
     ------
     RequiredArgumentMissing
         if no ``id`` was supplied
     """
     self.task_controller = TaskController()
     task_id = kwargs.get("id", None)
     if not task_id:
         raise RequiredArgumentMissing()
     self.cli_helper.echo(__("info", "cli.run.delete", task_id))
     # Only the delete call is guarded, so a failure while printing the
     # success message is not misreported as a failed deletion
     try:
         # Delete the task for the run
         result = self.task_controller.delete(task_id)
     except Exception:
         self.cli_helper.echo(__("error", "cli.run.delete", task_id))
         return False
     if result:
         self.cli_helper.echo(__("info", "cli.run.delete.success", task_id))
     else:
         # Report a falsy result too (as `stop` does) instead of staying silent
         self.cli_helper.echo(__("error", "cli.run.delete", task_id))
     return result
Ejemplo n.º 16
0
 def test_init_fail_invalid_path(self):
     """TaskController() rejects a configured home that is not a valid path."""
     Config().set_home("some_random_dir")
     caught = None
     try:
         TaskController()
     except InvalidProjectPath as exc:
         caught = exc
     assert caught is not None
Ejemplo n.º 17
0
    def run(self, **kwargs):
        """Create a task from the CLI arguments, run it, and always stop it.

        Keyword arguments read: ``environment_id`` / ``environment_paths``
        (optional, copied into the snapshot inputs via
        ``mutually_exclusive``), ``ports``, ``interactive``, ``mem_limit``
        and ``cmd``.

        Returns
        -------
        Task or False
            the task object returned by ``TaskController.run``, or False if
            running it raised. The task is stopped in either case.
        """
        self.task_controller = TaskController()
        self.cli_helper.echo(__("info", "cli.task.run"))

        snapshot_dict = {}
        if kwargs.get("environment_id", None) or kwargs.get(
                "environment_paths", None):
            mutually_exclusive(["environment_id", "environment_paths"],
                               kwargs, snapshot_dict)

        task_dict = {
            "ports": kwargs['ports'],
            "interactive": kwargs['interactive'],
            "mem_limit": kwargs['mem_limit']
        }
        cmd = kwargs['cmd']
        if isinstance(cmd, list):
            task_dict['command_list'] = cmd
        elif platform.system() == "Windows":
            # On Windows the raw command is passed through unsplit
            task_dict['command'] = cmd
        elif isinstance(cmd, basestring):
            task_dict['command_list'] = shlex.split(cmd)

        task_obj = self.task_controller.create()
        # Read by `finally`, so it must be bound even if run() raises
        updated_task_obj = task_obj
        try:
            updated_task_obj = self.task_controller.run(
                task_obj.id, snapshot_dict=snapshot_dict, task_dict=task_dict)
        except Exception as e:
            self.logger.error("%s %s" % (e, task_dict))
            self.cli_helper.echo("%s" % e)
            self.cli_helper.echo(__("error", "cli.task.run", task_obj.id))
            return False
        finally:
            # Stop the task whether the run succeeded or failed
            self.cli_helper.echo(__("info", "cli.task.run.stop"))
            self.task_controller.stop(updated_task_obj.id)
            self.cli_helper.echo(
                __("info", "cli.task.run.complete", updated_task_obj.id))

        return updated_task_obj
Ejemplo n.º 18
0
Archivo: base.py Proyecto: dmh43/datmo
    def task_run_helper(self, task_dict, snapshot_dict, error_identifier):
        """
        Create a new task, run it with the given inputs, then stop it.

        Parameters
        ----------
        task_dict : dict
            task inputs passed to ``TaskController.run``
        snapshot_dict : dict
            snapshot inputs passed to ``TaskController.run``
        error_identifier : str
            message identifier echoed if the run raises

        Returns
        -------
        Task or False
            the Task object after its run, with updated parameters, or
            False if running raised. The task is always stopped, with status
            "SUCCESS" or "FAILED" (or "NOT STARTED" if neither was recorded).
        """
        self.task_controller = TaskController()
        task_obj = self.task_controller.create()

        # Both names are read in `finally`, so bind them before the run
        updated_task_obj = task_obj
        status = "NOT STARTED"
        try:
            updated_task_obj = self.task_controller.run(
                task_obj.id, snapshot_dict=snapshot_dict, task_dict=task_dict)
        except Exception as e:
            status = "FAILED"
            self.logger.error("%s %s" % (e, task_dict))
            self.cli_helper.echo("%s" % e)
            self.cli_helper.echo(__("error", error_identifier, task_obj.id))
            return False
        else:
            status = "SUCCESS"
        finally:
            self.cli_helper.echo(__("info", "cli.run.run.stop"))
            self.task_controller.stop(
                task_id=updated_task_obj.id, status=status)
            self.cli_helper.echo(
                __("info", "cli.run.run.complete", updated_task_obj.id))

        return updated_task_obj
Ejemplo n.º 19
0
    def rstudio(self, **kwargs):
        """Run an RStudio Server task on port 8787, stopping it afterwards.

        Keyword arguments read: ``environment_id`` / ``environment_paths``
        (optional, copied into the snapshot inputs via
        ``mutually_exclusive``) and ``mem_limit``.

        Returns
        -------
        Task or False
            the task object returned by ``TaskController.run``, or False if
            running it raised
        """
        self.task_controller = TaskController()
        self.cli_helper.echo(__("info", "cli.workspace.rstudio"))

        snapshot_dict = {}
        if kwargs.get("environment_id", None) or kwargs.get(
                "environment_paths", None):
            mutually_exclusive(["environment_id", "environment_paths"],
                               kwargs, snapshot_dict)

        # Run rserver without daemonizing so the task lasts as long as it does
        rserver_command = [
            "/usr/lib/rstudio-server/bin/rserver", "--server-daemonize=0",
            "--server-app-armor-enabled=0"
        ]
        task_dict = {
            "ports": ["8787:8787"],
            "command_list": rserver_command,
            "mem_limit": kwargs["mem_limit"],
        }

        task_obj = self.task_controller.create()
        # Read by `finally`, so it must be bound even if run() raises
        updated_task_obj = task_obj
        try:
            updated_task_obj = self.task_controller.run(
                task_obj.id, snapshot_dict=snapshot_dict, task_dict=task_dict)
        except Exception as e:
            self.logger.error("%s %s" % (e, task_dict))
            self.cli_helper.echo(
                __("error", "cli.workspace.rstudio", task_obj.id))
            return False
        finally:
            self.cli_helper.echo(__("info", "cli.task.run.stop"))
            self.task_controller.stop(updated_task_obj.id)
            self.cli_helper.echo(
                __("info", "cli.task.run.complete", updated_task_obj.id))

        return updated_task_obj
Ejemplo n.º 20
0
    def test_create_from_task_fail_user_inputs(self):
        """create() with a run_id rejects an environment_id or paths."""
        # Setup: environment definition plus a task that has been run
        env_def_path = os.path.join(self.temp_dir, "Dockerfile")
        with open(env_def_path, "wb") as f:
            f.write(to_bytes("FROM python:3.5-alpine"))

        task_controller = TaskController()
        task_obj = task_controller.create()
        task_obj = task_controller.run(
            task_obj.id, task_dict={"command": "sh -c echo accuracy:0.45"})

        # Each extra argument is invalid in combination with run_id
        conflicting_inputs = [
            {"environment_id": "test_id"},
            {"paths": ["mypath"]},
        ]
        for extra_kwargs in conflicting_inputs:
            failed = False
            try:
                create(
                    message="my test snapshot",
                    run_id=task_obj.id,
                    label="best",
                    config={"foo": "bar"},
                    stats={"foo": "bar"},
                    **extra_kwargs)
            except SnapshotCreateFromTaskArgs:
                failed = True
            assert failed
Ejemplo n.º 21
0
    def notebook(self, **kwargs):
        """Run a Jupyter notebook task on port 8888, stopping it afterwards.

        Keyword arguments read: ``environment_id`` / ``environment_paths``
        (optional, copied into the snapshot inputs via
        ``mutually_exclusive``) and ``mem_limit``.

        Returns
        -------
        Task or False
            the task object returned by ``TaskController.run``, or False if
            running it raised
        """
        self.task_controller = TaskController()
        self.cli_helper.echo(__("info", "cli.workspace.notebook"))

        snapshot_dict = {}
        if kwargs.get("environment_id", None) or kwargs.get(
                "environment_paths", None):
            mutually_exclusive(["environment_id", "environment_paths"],
                               kwargs, snapshot_dict)

        task_dict = {
            "ports": ["8888:8888"],
            "command_list": ["jupyter", "notebook"],
            "mem_limit": kwargs["mem_limit"],
        }

        task_obj = self.task_controller.create()
        # Read by `finally`, so it must be bound even if run() raises
        updated_task_obj = task_obj
        try:
            updated_task_obj = self.task_controller.run(
                task_obj.id, snapshot_dict=snapshot_dict, task_dict=task_dict)
        except Exception as e:
            self.logger.error("%s %s" % (e, task_dict))
            self.cli_helper.echo(
                __("error", "cli.workspace.notebook", task_obj.id))
            return False
        finally:
            self.cli_helper.echo(__("info", "cli.task.run.stop"))
            self.task_controller.stop(updated_task_obj.id)
            self.cli_helper.echo(
                __("info", "cli.task.run.complete", updated_task_obj.id))

        return updated_task_obj
Ejemplo n.º 22
0
class RunCommand(ProjectCommand):
    """CLI commands for running and listing runs (tasks)."""

    def __init__(self, cli_helper):
        super(RunCommand, self).__init__(cli_helper)

    @Helper.notify_environment_active(TaskController)
    @Helper.notify_no_project_found
    def run(self, **kwargs):
        """Create a task from the CLI arguments and run it.

        Keyword arguments read: ``environment_id`` / ``environment_paths``
        (optional, copied into the snapshot inputs via
        ``mutually_exclusive``), ``ports``, ``interactive``, ``mem_limit``
        and ``cmd``.

        Returns
        -------
        Task or False
            the task object returned by ``TaskController.run``, or False if
            running it raised
        """
        self.cli_helper.echo(__("info", "cli.task.run"))
        # Create controllers
        self.task_controller = TaskController()
        self.snapshot_controller = SnapshotController()

        snapshot_dict = {}
        if kwargs.get("environment_id", None) or kwargs.get(
                "environment_paths", None):
            mutually_exclusive(["environment_id", "environment_paths"],
                               kwargs, snapshot_dict)

        task_dict = {
            "ports": kwargs['ports'],
            "interactive": kwargs['interactive'],
            "mem_limit": kwargs['mem_limit']
        }
        cmd = kwargs['cmd']
        if isinstance(cmd, list):
            task_dict['command_list'] = cmd
        elif platform.system() == "Windows":
            # On Windows the raw command is passed through unsplit
            task_dict['command'] = cmd
        elif isinstance(cmd, basestring):
            task_dict['command_list'] = shlex.split(cmd)

        task_obj = self.task_controller.create()
        try:
            updated_task_obj = self.task_controller.run(
                task_obj.id, snapshot_dict=snapshot_dict, task_dict=task_dict)
        except Exception as e:
            self.logger.error("%s %s" % (e, task_dict))
            self.cli_helper.echo("%s" % e)
            self.cli_helper.echo(__("error", "cli.task.run", task_obj.id))
            return False
        else:
            self.cli_helper.echo(
                __("info", "cli.task.run.complete", updated_task_obj.id))
            return updated_task_obj

    @Helper.notify_no_project_found
    def ls(self, **kwargs):
        """Print the runs of a session and return them.

        Keyword arguments read: ``session_id`` (the current session is used
        only when the key is absent), ``format`` (default "table"),
        ``download`` and ``download_path``.

        Returns
        -------
        list
            RunObject instances when printing to the console; the raw task
            objects from ``TaskController.list`` when ``download`` is set.
            NOTE(review): the two branches return different types — confirm
            callers rely on this.
        """
        # Create controllers
        self.task_controller = TaskController()
        self.snapshot_controller = SnapshotController()
        session_id = kwargs.get('session_id',
                                self.task_controller.current_session.id)
        print_format = kwargs.get('format', "table")
        download = kwargs.get('download', None)
        download_path = kwargs.get('download_path', None)

        # Newest runs first
        task_objs = self.task_controller.list(session_id,
                                              sort_key="created_at",
                                              sort_order="descending")
        header_list = [
            "id", "command", "status", "config", "results", "created at"
        ]
        item_dict_list = []
        run_obj_list = []
        for task in task_objs:
            run = RunObject(task)
            results_text = printable_object(str(run.results))
            config_text = printable_object(str(run.config))
            item_dict_list.append({
                "id": run.id,
                "command": run.command,
                "status": run.status,
                "config": config_text,
                "results": results_text,
                "created at": prettify_datetime(run.created_at),
            })
            run_obj_list.append(run)

        if not download:
            self.cli_helper.print_items(header_list,
                                        item_dict_list,
                                        print_format=print_format)
            return run_obj_list

        if not download_path:
            # No path given: write into the current working directory, named
            # with the current UTC time in ms since the Unix epoch
            elapsed = datetime.utcnow() - datetime.utcfromtimestamp(0)
            download_path = os.path.join(
                os.getcwd(),
                "run_ls_" + str(elapsed.total_seconds() * 1000.0))
        self.cli_helper.print_items(header_list,
                                    item_dict_list,
                                    print_format=print_format,
                                    output_path=download_path)
        return task_objs
Ejemplo n.º 23
0
class TestProjectController():
    """Tests for ProjectController init, cleanup and status.

    Every test runs against a fresh temporary datmo home. Environment ids a
    test appends to ``self.environment_ids`` are deleted again in
    ``teardown_method`` (only when docker is usable).
    """

    def setup_method(self):
        self.temp_dir = tempfile.mkdtemp(dir=test_datmo_dir)
        Config().set_home(self.temp_dir)
        self.project_controller = ProjectController()
        # Environments created by a test; deleted in teardown_method
        self.environment_ids = []

    def teardown_method(self):
        # Environment deletion requires docker; skip it when
        # check_docker_inactive reports docker is unavailable
        if not check_docker_inactive(test_datmo_dir):
            self.project_controller = ProjectController()
            if self.project_controller.is_initialized:
                self.environment_controller = EnvironmentController()
                for env_id in list(set(self.environment_ids)):
                    if not self.environment_controller.delete(env_id):
                        raise Exception(
                            "Failed to delete environment %s" % env_id)

    def test_init_failure_none(self):
        """init(None, None) must raise ValidationFailed."""
        failed = False
        try:
            self.project_controller.init(None, None)
        except ValidationFailed:
            failed = True
        assert failed

    def test_init_failure_empty_str(self):
        """init("", "") must raise and leave both drivers uninitialized."""
        failed = False
        try:
            self.project_controller.init("", "")
        except ValidationFailed:
            failed = True
        assert failed
        assert not self.project_controller.code_driver.is_initialized
        assert not self.project_controller.file_driver.is_initialized

    def test_init_failure_git_code_driver(self):
        """A failing git init must roll back code and file drivers."""
        # Create a HEAD.lock file in .git to make GitCodeDriver.init() fail
        if self.project_controller.code_driver.type == "git":
            git_dir = os.path.join(
                self.project_controller.code_driver.filepath, ".git")
            os.makedirs(git_dir)
            # Binary mode: to_bytes() produces bytes, which a text-mode
            # handle rejects on Python 3
            with open(os.path.join(git_dir, "HEAD.lock"), "ab") as f:
                f.write(to_bytes("test"))
            failed = False
            try:
                self.project_controller.init("test1", "test description")
            except Exception:
                failed = True
            assert failed
            assert not self.project_controller.code_driver.is_initialized
            assert not self.project_controller.file_driver.is_initialized

    def test_init_success(self):
        """A valid init stores name/description and a default session."""
        result = self.project_controller.init("test1", "test description")

        # Tested with is_initialized
        assert self.project_controller.model.name == "test1"
        assert self.project_controller.model.description == "test description"
        assert result and self.project_controller.is_initialized

        # Changeable by user, not tested in is_initialized
        assert self.project_controller.current_session.name == "default"

    # TODO: Test lower level functions (DAL, JSONStore, etc for interruptions)
    # def test_init_with_interruption(self):
    #     # Reinitializing after timed interruption during init
    #     @timeout_decorator.timeout(0.001, use_signals=False)
    #     def timed_init_with_interruption():
    #         result = self.project_controller.init("test1", "test description")
    #         return result
    #
    #     failed = False
    #     try:
    #         timed_init_with_interruption()
    #     except timeout_decorator.timeout_decorator.TimeoutError:
    #         failed = True
    #     # Tested with is_initialized
    #     assert failed
    #
    #     # Reperforming init after a wait of 2 seconds
    #     time.sleep(2)
    #     result = self.project_controller.init("test2", "test description")
    #     # Tested with is_initialized
    #     assert self.project_controller.model.name == "test2"
    #     assert self.project_controller.model.description == "test description"
    #     assert result and self.project_controller.is_initialized
    #
    #     # Changeable by user, not tested in is_initialized
    #     assert self.project_controller.current_session.name == "default"

    def test_init_reinit_failure_empty_str(self):
        """Re-init with empty values must fail and keep the old project."""
        _ = self.project_controller.init("test1", "test description")
        # Must start False; otherwise `assert failed` below can never fail
        failed = False
        try:
            self.project_controller.init("", "")
        except Exception:
            failed = True
        assert failed
        assert self.project_controller.model.name == "test1"
        assert self.project_controller.model.description == "test description"
        assert self.project_controller.code_driver.is_initialized
        assert self.project_controller.file_driver.is_initialized

    def test_init_reinit_success(self):
        """Re-init with valid values overwrites name and description."""
        _ = self.project_controller.init("test1", "test description")
        # Test out functionality for re-initialize project
        result = self.project_controller.init("anything", "else")

        assert self.project_controller.model.name == "anything"
        assert self.project_controller.model.description == "else"
        assert result == True

    def test_cleanup_no_environment(self):
        """cleanup() de-initializes code and file drivers."""
        self.project_controller.init("test2", "test description")
        result = self.project_controller.cleanup()

        assert not self.project_controller.code_driver.is_initialized
        assert not self.project_controller.file_driver.is_initialized
        # Ensure that containers built with this image do not exist
        # assert not self.project_controller.environment_driver.list_containers(filters={
        #     "ancestor": image_id
        # })
        assert result == True

    @pytest_docker_environment_failed_instantiation(test_datmo_dir)
    def test_cleanup_with_environment(self):
        """cleanup() also leaves no "datmo-test2" image behind."""
        self.project_controller.init("test2", "test description")
        result = self.project_controller.cleanup()

        assert not self.project_controller.code_driver.is_initialized
        assert not self.project_controller.file_driver.is_initialized
        assert not self.project_controller.environment_driver.list_images(
            "datmo-test2")
        # Ensure that containers built with this image do not exist
        # assert not self.project_controller.environment_driver.list_containers(filters={
        #     "ancestor": image_id
        # })
        assert result == True

    def test_status_basic(self):
        """status() on a fresh project reports no snapshots."""
        self.project_controller.init("test3", "test description")
        status_dict, current_snapshot, latest_snapshot_user_generated, latest_snapshot_auto_generated, unstaged_code, unstaged_environment, unstaged_files = \
            self.project_controller.status()

        assert status_dict
        assert isinstance(status_dict, dict)
        assert status_dict['name'] == "test3"
        assert status_dict['description'] == "test description"
        assert isinstance(status_dict['config'], dict)
        assert not current_snapshot
        assert not latest_snapshot_user_generated
        assert not latest_snapshot_auto_generated
        assert unstaged_code  # no files, but unstaged because blank commit id has not yet been created (no initial snapshot)
        assert not unstaged_environment
        assert not unstaged_files

    @pytest_docker_environment_failed_instantiation(test_datmo_dir)
    def test_status_snapshot_task(self):
        """status() tracks a user snapshot, then a task's before/after snapshots."""
        self.project_controller.init("test4", "test description")
        self.snapshot_controller = SnapshotController()
        self.task_controller = TaskController()

        # Create files to add
        self.snapshot_controller.file_driver.create("dirpath1", directory=True)
        self.snapshot_controller.file_driver.create("dirpath2", directory=True)
        self.snapshot_controller.file_driver.create("filepath1")

        # Create environment definition
        env_def_path = os.path.join(self.snapshot_controller.home,
                                    "Dockerfile")
        with open(env_def_path, "wb") as f:
            f.write(to_bytes("FROM python:3.5-alpine"))

        environment_paths = [env_def_path]

        # Create config
        config_filepath = os.path.join(self.snapshot_controller.home,
                                       "config.json")
        with open(config_filepath, "wb") as f:
            f.write(to_bytes(str("{}")))

        # Create stats
        stats_filepath = os.path.join(self.snapshot_controller.home,
                                      "stats.json")
        with open(stats_filepath, "wb") as f:
            f.write(to_bytes(str("{}")))

        input_dict = {
            "message":
                "my test snapshot",
            "paths": [
                os.path.join(self.snapshot_controller.home, "dirpath1"),
                os.path.join(self.snapshot_controller.home, "dirpath2"),
                os.path.join(self.snapshot_controller.home, "filepath1")
            ],
            "environment_paths":
                environment_paths,
            "config_filename":
                config_filepath,
            "stats_filename":
                stats_filepath,
        }

        # Create snapshot in the project, then check status
        first_snapshot = self.snapshot_controller.create(input_dict)

        status_dict, current_snapshot, latest_snapshot_user_generated, latest_snapshot_auto_generated, unstaged_code, unstaged_environment, unstaged_files = \
            self.project_controller.status()

        assert status_dict
        assert isinstance(status_dict, dict)
        assert status_dict['name'] == "test4"
        assert status_dict['description'] == "test description"
        assert isinstance(status_dict['config'], dict)
        assert not current_snapshot  # snapshot was created from other environments and files (so user is not on any current snapshot)
        assert isinstance(latest_snapshot_user_generated, Snapshot)
        assert latest_snapshot_user_generated == first_snapshot
        assert not latest_snapshot_auto_generated
        assert not unstaged_code
        assert not unstaged_environment
        assert not unstaged_files

        # Create and run a task and test if task is shown
        first_task = self.task_controller.create()

        # Create task_dict
        task_command = ["sh", "-c", "echo accuracy:0.45"]
        task_dict = {"command_list": task_command}

        updated_first_task = self.task_controller.run(
            first_task.id, task_dict=task_dict)
        before_snapshot_obj = self.task_controller.dal.snapshot.get_by_id(
            updated_first_task.before_snapshot_id)
        after_snapshot_obj = self.task_controller.dal.snapshot.get_by_id(
            updated_first_task.after_snapshot_id)
        before_environment_obj = self.task_controller.dal.environment.get_by_id(
            before_snapshot_obj.environment_id)
        after_environment_obj = self.task_controller.dal.environment.get_by_id(
            after_snapshot_obj.environment_id)
        assert before_environment_obj == after_environment_obj
        # Register for deletion in teardown_method
        self.environment_ids.append(after_environment_obj.id)

        status_dict, current_snapshot, latest_snapshot_user_generated, latest_snapshot_auto_generated, unstaged_code, unstaged_environment, unstaged_files = \
            self.project_controller.status()

        assert status_dict
        assert isinstance(status_dict, dict)
        assert status_dict['name'] == "test4"
        assert status_dict['description'] == "test description"
        assert isinstance(status_dict['config'], dict)
        assert isinstance(current_snapshot, Snapshot)
        assert isinstance(latest_snapshot_user_generated, Snapshot)
        assert latest_snapshot_user_generated == first_snapshot
        assert isinstance(latest_snapshot_auto_generated, Snapshot)
        # current snapshot is the before snapshot for the run
        assert current_snapshot == before_snapshot_obj
        assert current_snapshot != latest_snapshot_auto_generated
        assert current_snapshot != latest_snapshot_user_generated
        # latest autogenerated snapshot is the after snapshot id
        assert latest_snapshot_auto_generated == after_snapshot_obj
        assert latest_snapshot_auto_generated != latest_snapshot_user_generated
        # user generated snapshot is not associated with any before or after snapshot
        assert latest_snapshot_user_generated != before_snapshot_obj
        assert latest_snapshot_user_generated != after_snapshot_obj
        assert not unstaged_code
        assert not unstaged_environment
        assert not unstaged_files
Ejemplo n.º 24
0
    def test_status_snapshot_task(self):
        """Project status follows an explicit snapshot and then a task run.

        Phase one creates a user snapshot from explicit files, a Dockerfile and
        empty config/stats, and expects it to be the latest user-generated
        snapshot with no current or auto-generated snapshot. Phase two runs a
        task and expects its before snapshot to become the current snapshot
        and its after snapshot the latest auto-generated one.
        """
        project = self.project_controller
        project.init("test4", "test description")
        self.snapshot_controller = SnapshotController()
        self.task_controller = TaskController()
        snapshots = self.snapshot_controller
        tasks = self.task_controller

        def home_path(name):
            # Absolute path of `name` inside the snapshot controller's home
            return os.path.join(snapshots.home, name)

        def write_bytes(name, text):
            # Write `text` (encoded with to_bytes) to `name` under home
            target = home_path(name)
            with open(target, "wb") as handle:
                handle.write(to_bytes(text))
            return target

        def check_project_header(status):
            # Fields that must be stable across both status() calls
            assert status
            assert isinstance(status, dict)
            assert status['name'] == "test4"
            assert status['description'] == "test description"
            assert isinstance(status['config'], dict)

        # Two directories and one file to include in the snapshot
        for dirname in ("dirpath1", "dirpath2"):
            snapshots.file_driver.create(dirname, directory=True)
        snapshots.file_driver.create("filepath1")

        # Environment definition, then empty config and stats files
        dockerfile = write_bytes("Dockerfile", "FROM python:3.5-alpine")
        config_path = write_bytes("config.json", str("{}"))
        stats_path = write_bytes("stats.json", str("{}"))

        first_snapshot = snapshots.create({
            "message": "my test snapshot",
            "paths": [
                home_path(name)
                for name in ("dirpath1", "dirpath2", "filepath1")
            ],
            "environment_paths": [dockerfile],
            "config_filename": config_path,
            "stats_filename": stats_path,
        })

        (status, current, latest_user, latest_auto, dirty_code, dirty_env,
         dirty_files) = project.status()

        check_project_header(status)
        # Built from explicitly given files/environment, so the user is not
        # sitting on any current snapshot yet
        assert not current
        assert isinstance(latest_user, Snapshot)
        assert latest_user == first_snapshot
        assert not latest_auto
        assert not dirty_code
        assert not dirty_env
        assert not dirty_files

        # Run a task; it records a before and an after snapshot
        first_task = tasks.create()
        run_spec = {"command_list": ["sh", "-c", "echo accuracy:0.45"]}
        finished_task = tasks.run(first_task.id, task_dict=run_spec)

        before_snap = tasks.dal.snapshot.get_by_id(
            finished_task.before_snapshot_id)
        after_snap = tasks.dal.snapshot.get_by_id(
            finished_task.after_snapshot_id)
        before_env = tasks.dal.environment.get_by_id(before_snap.environment_id)
        after_env = tasks.dal.environment.get_by_id(after_snap.environment_id)
        assert before_env == after_env
        self.environment_ids.append(after_env.id)

        (status, current, latest_user, latest_auto, dirty_code, dirty_env,
         dirty_files) = project.status()

        check_project_header(status)
        assert isinstance(current, Snapshot)
        assert isinstance(latest_user, Snapshot)
        assert latest_user == first_snapshot
        assert isinstance(latest_auto, Snapshot)
        # The task's before snapshot is where the user now is
        assert current == before_snap
        assert current != latest_auto
        assert current != latest_user
        # The task's after snapshot is the newest auto-generated one
        assert latest_auto == after_snap
        assert latest_auto != latest_user
        # The explicit user snapshot is neither of the task's snapshots
        assert latest_user != before_snap
        assert latest_user != after_snap
        assert not dirty_code
        assert not dirty_env
        assert not dirty_files
Ejemplo n.º 25
0
 def __setup(self):
     """Use ``self.temp_dir`` as the datmo home and build fresh controllers.

     A project named "test" is initialized before the task and snapshot
     controllers are constructed.
     """
     home_dir = self.temp_dir
     Config().set_home(home_dir)
     self.project_controller = project = ProjectController()
     project.init("test", "test description")
     self.task_controller = TaskController()
     self.snapshot_controller = SnapshotController()
Ejemplo n.º 26
0
class TestSnapshotController():
    """Tests for SnapshotController against a freshly initialized project.

    Each test gets its own temporary project directory. teardown_method
    removes that directory and restores the global ``tempfile.tempdir``
    which setup_method overrides.
    """

    def setup_method(self):
        # setup overrides the process-wide tempfile default; remember it so
        # teardown can restore it instead of leaking it into later tests
        self._saved_tempdir = tempfile.tempdir
        # provide mountable tmp directory for docker
        tempfile.tempdir = "/tmp" if platform.system(
        ) != "Windows" else None
        test_datmo_dir = os.environ.get('TEST_DATMO_DIR',
                                        tempfile.gettempdir())
        self.temp_dir = tempfile.mkdtemp(dir=test_datmo_dir)
        self.project = ProjectController(self.temp_dir)
        self.project.init("test", "test description")
        self.task = TaskController(self.temp_dir)
        self.snapshot = SnapshotController(self.temp_dir)

    def teardown_method(self):
        import shutil
        # Best-effort removal of the per-test directory created by mkdtemp;
        # ignore_errors so leftover locked/foreign-owned files never fail a test
        temp_dir = getattr(self, "temp_dir", None)
        if temp_dir:
            shutil.rmtree(temp_dir, ignore_errors=True)
        if hasattr(self, "_saved_tempdir"):
            tempfile.tempdir = self._saved_tempdir

    def test_create_fail_no_message(self):
        """create() without a message raises RequiredArgumentMissing."""
        failed = False
        try:
            self.snapshot.create({})
        except RequiredArgumentMissing:
            failed = True
        assert failed

    def test_create_fail_no_code(self):
        """create() with no files to commit raises GitCommitDoesNotExist."""
        failed = False
        try:
            self.snapshot.create({"message": "my test snapshot"})
        except GitCommitDoesNotExist:
            failed = True
        assert failed

    def test_create_fail_no_environment_with_language(self):
        """A non-default language with no environment raises EnvironmentDoesNotExist."""
        self.snapshot.file_driver.create("filepath1")
        failed = False
        try:
            self.snapshot.create({
                "message": "my test snapshot",
                "language": "java"
            })
        except EnvironmentDoesNotExist:
            failed = True
        assert failed

    def test_create_fail_no_environment_detected_in_file(self):
        """No environment detectable from files raises EnvironmentDoesNotExist."""
        self.snapshot.file_driver.create("filepath1")
        failed = False
        try:
            self.snapshot.create({
                "message": "my test snapshot",
            })
        except EnvironmentDoesNotExist:
            failed = True
        assert failed

    def test_create_success_default_detected_in_file(self):
        """An environment is created from a Python script's imports."""
        test_filepath = os.path.join(self.snapshot.home, "script.py")
        with open(test_filepath, "w") as f:
            f.write(to_unicode("import numpy\n"))
            f.write(to_unicode("import sklearn\n"))
            f.write(to_unicode("print('hello')\n"))

        snapshot_obj_1 = self.snapshot.create({"message": "my test snapshot"})

        assert snapshot_obj_1
        assert snapshot_obj_1.code_id
        assert snapshot_obj_1.environment_id
        assert snapshot_obj_1.file_collection_id
        assert snapshot_obj_1.config == {}
        assert snapshot_obj_1.stats == {}

    def test_create_success_default_env_def(self):
        """A Dockerfile in the project home is enough to create a snapshot."""
        env_def_path = os.path.join(self.snapshot.home, "Dockerfile")
        with open(env_def_path, "w") as f:
            f.write(to_unicode(str("FROM datmo/xgboost:cpu")))

        snapshot_obj = self.snapshot.create({"message": "my test snapshot"})

        assert snapshot_obj
        assert snapshot_obj.code_id
        assert snapshot_obj.environment_id
        assert snapshot_obj.file_collection_id
        assert snapshot_obj.config == {}
        assert snapshot_obj.stats == {}

    def test_create_success_default_env_def_duplicate(self):
        """Two identical create() calls return the same snapshot."""
        env_def_path = os.path.join(self.snapshot.home, "Dockerfile")
        with open(env_def_path, "w") as f:
            f.write(to_unicode(str("FROM datmo/xgboost:cpu")))

        snapshot_obj = self.snapshot.create({"message": "my test snapshot"})

        snapshot_obj_1 = self.snapshot.create({"message": "my test snapshot"})

        # Should return the same object back
        assert snapshot_obj_1 == snapshot_obj
        assert snapshot_obj_1.code_id == snapshot_obj.code_id
        assert snapshot_obj_1.environment_id == snapshot_obj.environment_id
        assert (snapshot_obj_1.file_collection_id ==
                snapshot_obj.file_collection_id)
        assert snapshot_obj_1.config == snapshot_obj.config
        assert snapshot_obj_1.stats == snapshot_obj.stats

    def test_create_success_given_files_env_def_config_file_stats_file(self):
        """Explicit files plus config/stats files change code and file ids."""
        env_def_path = os.path.join(self.snapshot.home, "Dockerfile")
        with open(env_def_path, "w") as f:
            f.write(to_unicode(str("FROM datmo/xgboost:cpu")))

        snapshot_obj = self.snapshot.create({"message": "my test snapshot"})

        # Create files to add
        self.snapshot.file_driver.create("dirpath1", directory=True)
        self.snapshot.file_driver.create("dirpath2", directory=True)
        self.snapshot.file_driver.create("filepath1")

        # Create config
        config_filepath = os.path.join(self.snapshot.home, "config.json")
        with open(config_filepath, "w") as f:
            f.write(to_unicode(str('{"foo":"bar"}')))

        # Create stats
        stats_filepath = os.path.join(self.snapshot.home, "stats.json")
        with open(stats_filepath, "w") as f:
            f.write(to_unicode(str('{"foo":"bar"}')))

        input_dict = {
            "message":
            "my test snapshot",
            "filepaths": [
                os.path.join(self.snapshot.home, "dirpath1"),
                os.path.join(self.snapshot.home, "dirpath2"),
                os.path.join(self.snapshot.home, "filepath1")
            ],
            "environment_definition_filepath":
            env_def_path,
            "config_filepath":
            config_filepath,
            "stats_filepath":
            stats_filepath,
        }
        # Create snapshot in the project
        snapshot_obj_4 = self.snapshot.create(input_dict)

        assert snapshot_obj_4 != snapshot_obj
        assert snapshot_obj_4.code_id != snapshot_obj.code_id
        assert snapshot_obj_4.environment_id == snapshot_obj.environment_id
        assert (snapshot_obj_4.file_collection_id !=
                snapshot_obj.file_collection_id)
        assert snapshot_obj_4.config == {"foo": "bar"}
        assert snapshot_obj_4.stats == {"foo": "bar"}

    def test_create_success_given_files_env_def_different_config_stats(self):
        """Nonexistent config/stats filenames yield empty config and stats."""
        env_def_path = os.path.join(self.snapshot.home, "Dockerfile")
        with open(env_def_path, "w") as f:
            f.write(to_unicode(str("FROM datmo/xgboost:cpu")))

        snapshot_obj = self.snapshot.create({"message": "my test snapshot"})

        # Create files to add
        self.snapshot.file_driver.create("dirpath1", directory=True)
        self.snapshot.file_driver.create("dirpath2", directory=True)
        self.snapshot.file_driver.create("filepath1")

        # Create config
        config_filepath = os.path.join(self.snapshot.home, "config.json")
        with open(config_filepath, "w") as f:
            f.write(to_unicode(str('{"foo":"bar"}')))

        # Create stats
        stats_filepath = os.path.join(self.snapshot.home, "stats.json")
        with open(stats_filepath, "w") as f:
            f.write(to_unicode(str('{"foo":"bar"}')))

        # Test different config and stats inputs
        input_dict = {
            "message":
            "my test snapshot",
            "filepaths": [
                os.path.join(self.snapshot.home, "dirpath1"),
                os.path.join(self.snapshot.home, "dirpath2"),
                os.path.join(self.snapshot.home, "filepath1")
            ],
            "environment_definition_filepath":
            env_def_path,
            "config_filename":
            "different_name",
            "stats_filename":
            "different_name",
        }

        # Create snapshot in the project
        snapshot_obj_1 = self.snapshot.create(input_dict)

        assert snapshot_obj_1 != snapshot_obj
        assert snapshot_obj_1.config == {}
        assert snapshot_obj_1.stats == {}

    def test_create_success_given_files_env_def_direct_config_stats(self):
        """Config and stats passed as dicts are stored verbatim."""
        env_def_path = os.path.join(self.snapshot.home, "Dockerfile")
        with open(env_def_path, "w") as f:
            f.write(to_unicode(str("FROM datmo/xgboost:cpu")))

        # Create files to add
        self.snapshot.file_driver.create("dirpath1", directory=True)
        self.snapshot.file_driver.create("dirpath2", directory=True)
        self.snapshot.file_driver.create("filepath1")

        # Test different config and stats inputs
        input_dict = {
            "message":
            "my test snapshot",
            "filepaths": [
                os.path.join(self.snapshot.home, "dirpath1"),
                os.path.join(self.snapshot.home, "dirpath2"),
                os.path.join(self.snapshot.home, "filepath1")
            ],
            "environment_definition_filepath":
            env_def_path,
            "config": {
                "foo": "bar"
            },
            "stats": {
                "foo": "bar"
            },
        }

        # Create snapshot in the project
        snapshot_obj_6 = self.snapshot.create(input_dict)

        assert snapshot_obj_6.config == {"foo": "bar"}
        assert snapshot_obj_6.stats == {"foo": "bar"}

    def test_create_from_task(self):
        """create_from_task fails before the task runs, succeeds after."""
        # 1) Test if fails with TaskNotComplete error
        # 2) Test if success with task files, results, and message

        # Setup task
        task_command = ["sh", "-c", "echo accuracy:0.45"]
        input_dict = {"command": task_command}

        # Create task in the project
        task_obj = self.task.create(input_dict)

        # 1) Test option 1
        failed = False
        try:
            _ = self.snapshot.create_from_task(message="my test snapshot",
                                               task_id=task_obj.id)
        except TaskNotComplete:
            failed = True
        assert failed

        # Create environment definition
        env_def_path = os.path.join(self.project.home, "Dockerfile")
        with open(env_def_path, "w") as f:
            f.write(to_unicode(str("FROM datmo/xgboost:cpu")))

        # Test the default values
        updated_task_obj = self.task.run(task_obj.id)

        # 2) Test option 2
        snapshot_obj = self.snapshot.create_from_task(
            message="my test snapshot", task_id=updated_task_obj.id)

        assert snapshot_obj.id == updated_task_obj.after_snapshot_id
        assert snapshot_obj.message == "my test snapshot"
        assert snapshot_obj.stats == updated_task_obj.results
        assert snapshot_obj.visible == True

    def __default_create(self):
        """Create files, a Dockerfile and empty config/stats; return a snapshot."""
        # Create files to add
        self.snapshot.file_driver.create("dirpath1", directory=True)
        self.snapshot.file_driver.create("dirpath2", directory=True)
        self.snapshot.file_driver.create("filepath1")

        # Create environment_driver definition
        env_def_path = os.path.join(self.snapshot.home, "Dockerfile")
        with open(env_def_path, "w") as f:
            f.write(to_unicode(str("FROM datmo/xgboost:cpu")))

        # Create config
        config_filepath = os.path.join(self.snapshot.home, "config.json")
        with open(config_filepath, "w") as f:
            f.write(to_unicode(str("{}")))

        # Create stats
        stats_filepath = os.path.join(self.snapshot.home, "stats.json")
        with open(stats_filepath, "w") as f:
            f.write(to_unicode(str("{}")))

        input_dict = {
            "message":
            "my test snapshot",
            "filepaths": [
                os.path.join(self.snapshot.home, "dirpath1"),
                os.path.join(self.snapshot.home, "dirpath2"),
                os.path.join(self.snapshot.home, "filepath1")
            ],
            "environment_definition_filepath":
            env_def_path,
            "config_filename":
            config_filepath,
            "stats_filename":
            stats_filepath,
        }

        # Create snapshot in the project
        return self.snapshot.create(input_dict)

    def test_checkout(self):
        """checkout() restores the snapshot's commit and its snapshot dir."""
        snapshot_obj_1 = self.__default_create()

        code_obj_1 = self.snapshot.dal.code.get_by_id(snapshot_obj_1.code_id)

        # Create a second, different snapshot in the project
        self.snapshot.file_driver.create("test")
        snapshot_obj_2 = self.__default_create()

        assert snapshot_obj_2 != snapshot_obj_1

        # Checkout to snapshot 1 using snapshot id
        result = self.snapshot.checkout(snapshot_obj_1.id)

        # Snapshot directory in user directory
        snapshot_obj_1_path = os.path.join(self.snapshot.home,
                                           "datmo_snapshots",
                                           snapshot_obj_1.id)

        assert result == True
        assert self.snapshot.code_driver.latest_commit() == code_obj_1.commit_id
        assert os.path.isdir(snapshot_obj_1_path)

    def test_list(self):
        """list() honours session and visibility filters."""
        # Check for error if incorrect session given
        failed = False
        try:
            self.snapshot.list(session_id="does_not_exist")
        except SessionDoesNotExistException:
            failed = True
        assert failed

        # Create file to add to snapshot
        test_filepath_1 = os.path.join(self.snapshot.home, "test.txt")
        with open(test_filepath_1, "w") as f:
            f.write(to_unicode(str("test")))

        # Create snapshot in the project
        snapshot_obj_1 = self.__default_create()

        # Create file to add to second snapshot
        test_filepath_2 = os.path.join(self.snapshot.home, "test2.txt")
        with open(test_filepath_2, "w") as f:
            f.write(to_unicode(str("test2")))

        # Create second snapshot in the project
        snapshot_obj_2 = self.__default_create()

        # List all snapshots and ensure they exist
        result = self.snapshot.list()

        assert len(result) == 2
        assert snapshot_obj_1 in result
        assert snapshot_obj_2 in result

        # List all snapshots with session filter
        result = self.snapshot.list(session_id=self.project.current_session.id)

        assert len(result) == 2
        assert snapshot_obj_1 in result
        assert snapshot_obj_2 in result

        # List snapshots with visible filter
        result = self.snapshot.list(visible=False)
        assert len(result) == 0

        result = self.snapshot.list(visible=True)
        assert len(result) == 2
        assert snapshot_obj_1 in result
        assert snapshot_obj_2 in result

    def test_delete(self):
        """delete() removes the snapshot so lookup raises EntityNotFound."""
        snapshot_obj = self.__default_create()

        # Delete snapshot in the project
        result = self.snapshot.delete(snapshot_obj.id)

        # Check if snapshot retrieval throws error
        thrown = False
        try:
            self.snapshot.dal.snapshot.get_by_id(snapshot_obj.id)
        except EntityNotFound:
            thrown = True

        assert result == True
        assert thrown == True
Ejemplo n.º 27
0
class RunCommand(ProjectCommand):
    """CLI command group for runs: create (run), list (ls), rerun, stop and
    delete. A "run" is presented to the user as a ``Run`` wrapping a Task."""

    def __init__(self, cli_helper):
        super(RunCommand, self).__init__(cli_helper)

    @Helper.notify_environment_active(TaskController)
    @Helper.notify_no_project_found
    def run(self, **kwargs):
        """Create and run a new task from parsed CLI arguments.

        Parameters
        ----------
        **kwargs
            parsed CLI arguments; reads ``environment_id`` /
            ``environment_paths`` (optional, mutually exclusive), ``ports``,
            ``interactive``, ``mem_limit``, ``cmd`` and ``data``

        Returns
        -------
        Run or bool
            Run wrapping the finished task, or False if the task run failed
        """
        self.cli_helper.echo(__("info", "cli.run.run"))
        # Create input dictionaries
        snapshot_dict = {}
        # Environment options are only validated when at least one is given
        if kwargs.get("environment_id", None) or kwargs.get(
                "environment_paths", None):
            mutually_exclusive_args = ["environment_id", "environment_paths"]
            mutually_exclusive(mutually_exclusive_args, kwargs, snapshot_dict)
        task_dict = {
            "ports": kwargs['ports'],
            "interactive": kwargs['interactive'],
            "mem_limit": kwargs['mem_limit']
        }
        if not isinstance(kwargs['cmd'], list):
            # On Windows the raw command is passed through unsplit; elsewhere
            # a string command is tokenized with shell-like rules.
            if platform.system() == "Windows":
                task_dict['command'] = kwargs['cmd']
            elif isinstance(kwargs['cmd'], basestring):
                task_dict['command_list'] = shlex.split(kwargs['cmd'])
        else:
            task_dict['command_list'] = kwargs['cmd']

        data_paths = kwargs['data']
        # Run task and return Task object result
        task_obj = self.task_run_helper(task_dict,
                                        snapshot_dict,
                                        "cli.run.run",
                                        data_paths=data_paths)
        if not task_obj:
            return False
        # Creating the run object
        run_obj = Run(task_obj)
        return run_obj

    @Helper.notify_no_project_found
    def ls(self, **kwargs):
        """Print all runs, newest first, optionally writing them to a file.

        Parameters
        ----------
        **kwargs
            ``format`` (default "table"), ``download`` (bool) and
            ``download_path`` (defaults to ``run_ls_<epoch ms>`` in the
            current working directory when downloading)

        Returns
        -------
        list
            Run objects when printing; the raw Task objects when downloading
        """
        print_format = kwargs.get('format', "table")
        download = kwargs.get('download', None)
        download_path = kwargs.get('download_path', None)
        # Get all task meta information
        self.task_controller = TaskController()
        task_objs = self.task_controller.list(sort_key="created_at",
                                              sort_order="descending")
        header_list = [
            "id", "command", "type", "status", "config", "results",
            "created at"
        ]
        item_dict_list = []
        run_obj_list = []
        for task_obj in task_objs:
            # Create a new Run Object from Task Object
            run_obj = Run(task_obj)
            task_results_printable = printable_object(run_obj.results)
            snapshot_config_printable = printable_object(run_obj.config)
            item_dict_list.append({
                "id": run_obj.id,
                "command": run_obj.command,
                "type": run_obj.type,
                "status": run_obj.status,
                "config": snapshot_config_printable,
                "results": task_results_printable,
                "created at": prettify_datetime(run_obj.created_at)
            })
            run_obj_list.append(run_obj)
        if download:
            if not download_path:
                # download to current working directory with timestamp
                current_time = datetime.utcnow()
                epoch_time = datetime.utcfromtimestamp(0)
                current_time_unix_time_ms = (
                    current_time - epoch_time).total_seconds() * 1000.0
                download_path = os.path.join(
                    os.getcwd(), "run_ls_" + str(current_time_unix_time_ms))
            self.cli_helper.print_items(header_list,
                                        item_dict_list,
                                        print_format=print_format,
                                        output_path=download_path)
            # NOTE(review): this branch returns Task objects while the
            # printing branch returns Run objects; kept for compatibility.
            return task_objs
        self.cli_helper.print_items(header_list,
                                    item_dict_list,
                                    print_format=print_format)
        return run_obj_list

    @Helper.notify_environment_active(TaskController)
    @Helper.notify_no_project_found
    def rerun(self, **kwargs):
        """Re-run an existing task on the snapshot it ran from.

        Checks out the task's "before" snapshot for script runs, otherwise
        its core snapshot, then launches a new task with the same command,
        ports, interactivity, memory limit, data mappings and workspace.
        Exits the process with status 1 if the checkout raises.

        Returns
        -------
        Run or bool
            Run wrapping the new task, or False if the task run failed
        """
        self.task_controller = TaskController()
        # Get task id
        task_id = kwargs.get("id", None)
        self.cli_helper.echo(__("info", "cli.run.rerun", task_id))
        # Create the task_obj
        task_obj = self.task_controller.get(task_id)
        # Create the run obj
        run_obj = Run(task_obj)
        # Select the initial snapshot if it's a script else the final snapshot
        initial = True if run_obj.type == 'script' else False
        environment_id = run_obj.environment_id
        command = task_obj.command_list
        snapshot_id = run_obj.core_snapshot_id if not initial else run_obj.before_snapshot_id

        # Checkout to the selected snapshot id before rerunning the task
        self.snapshot_controller = SnapshotController()
        try:
            checkout_success = self.snapshot_controller.checkout(snapshot_id)
        except Exception:
            self.cli_helper.echo(__("error", "cli.snapshot.checkout.failure"))
            sys.exit(1)

        if checkout_success:
            self.cli_helper.echo(
                __("info", "cli.snapshot.checkout.success", snapshot_id))

        # Rerunning the task
        # Create input dictionary for the new task
        snapshot_dict = {"environment_id": environment_id}
        task_dict = {
            "ports": task_obj.ports,
            "interactive": task_obj.interactive,
            "mem_limit": task_obj.mem_limit,
            "command_list": command,
            "data_file_path_map": task_obj.data_file_path_map,
            "data_directory_path_map": task_obj.data_directory_path_map,
            "workspace": task_obj.workspace
        }
        # Run task and return Task object result
        new_task_obj = self.task_run_helper(task_dict, snapshot_dict,
                                            "cli.run.run")
        if not new_task_obj:
            return False
        # Creating the run object
        new_run_obj = Run(new_task_obj)
        return new_run_obj

    @Helper.notify_environment_active(TaskController)
    @Helper.notify_no_project_found
    def stop(self, **kwargs):
        """Stop one run (``id``) or all runs (``all``).

        Returns
        -------
        The controller's stop result, or False if stopping raised.

        Raises
        ------
        RequiredArgumentMissing
            if neither ``id`` nor ``all`` is given
        """
        self.task_controller = TaskController()
        input_dict = {}
        mutually_exclusive(["id", "all"], kwargs, input_dict)
        stop_single = "id" in input_dict
        if stop_single:
            self.cli_helper.echo(__("info", "cli.run.stop", input_dict['id']))
        elif "all" in input_dict:
            self.cli_helper.echo(__("info", "cli.run.stop.all"))
        else:
            raise RequiredArgumentMissing()

        # Only the controller call is guarded; any failure there is reported
        # to the user and turned into a False result.
        try:
            if stop_single:
                result = self.task_controller.stop(task_id=input_dict['id'])
            else:
                result = self.task_controller.stop(all=input_dict['all'])
        except Exception:
            result = False

        if stop_single:
            if result:
                self.cli_helper.echo(
                    __("info", "cli.run.stop.success", input_dict['id']))
            else:
                self.cli_helper.echo(
                    __("error", "cli.run.stop", input_dict['id']))
        else:
            if result:
                self.cli_helper.echo(__("info", "cli.run.stop.all.success"))
            else:
                self.cli_helper.echo(__("error", "cli.run.stop.all"))
        return result

    @Helper.notify_environment_active(TaskController)
    @Helper.notify_no_project_found
    def delete(self, **kwargs):
        """Delete the task behind run ``id``.

        Returns
        -------
        The controller's delete result, or False if deletion raised.

        Raises
        ------
        RequiredArgumentMissing
            if no ``id`` is given
        """
        self.task_controller = TaskController()
        task_id = kwargs.get("id", None)
        if task_id:
            self.cli_helper.echo(__("info", "cli.run.delete", task_id))
        else:
            raise RequiredArgumentMissing()
        try:
            # Delete the task for the run
            result = self.task_controller.delete(task_id)
        except Exception:
            result = False
        # Report failure for a falsy result too (not only on exceptions),
        # matching how stop() reports its outcome.
        if result:
            self.cli_helper.echo(
                __("info", "cli.run.delete.success", task_id))
        else:
            self.cli_helper.echo(__("error", "cli.run.delete", task_id))
        return result
Ejemplo n.º 28
0
class TestTaskController():
    """Integration tests for TaskController (these need a working Docker)."""

    def setup_method(self):
        # provide mountable tmp directory for docker
        tempfile.tempdir = "/tmp" if not platform.system() == "Windows" else None
        test_datmo_dir = os.environ.get('TEST_DATMO_DIR',
                                        tempfile.gettempdir())
        self.temp_dir = tempfile.mkdtemp(dir=test_datmo_dir)
        self.project = ProjectController(self.temp_dir)
        self.project.init("test", "test description")
        self.environment = EnvironmentController(self.temp_dir)
        self.task = TaskController(self.temp_dir)

    def teardown_method(self):
        # Best-effort removal of the per-test project directory; files written
        # from inside containers may not be removable, so errors are ignored.
        import shutil
        shutil.rmtree(self.temp_dir, ignore_errors=True)

    def test_create(self):
        task_command = ["sh", "-c", "echo accuracy:0.45"]
        task_gpu = False
        input_dict = {
            "command": task_command,
            "gpu": task_gpu
        }

        # Create task in the project
        task_obj = self.task.create(input_dict)

        assert task_obj
        assert task_obj.command == task_command
        assert task_obj.gpu == task_gpu

    def test_run_helper(self):
        # TODO: Try out more options (see below)
        # Create environment_driver id
        env_def_path = os.path.join(self.project.home,
                                    "Dockerfile")
        with open(env_def_path, "w") as f:
            f.write(to_unicode(str("FROM datmo/xgboost:cpu")))

        environment_obj = self.environment.create({
            "definition_filepath": env_def_path
        })

        # Set log filepath
        log_filepath = os.path.join(self.task.home,
                                    "test.log")

        # create volume to mount
        temp_test_dirpath = os.path.join(self.temp_dir, "temp")
        os.makedirs(temp_test_dirpath)

        # Test option set 1
        random_name = ''.join([random.choice(string.ascii_letters + string.digits)
                               for _ in range(32)])
        options_dict = {
            "command": ["sh", "-c", "echo accuracy:0.45"],
            "ports": ["8888:8888"],
            "gpu": False,
            "name": random_name,
            "volumes": {
                temp_test_dirpath: {
                    'bind': '/task/',
                    'mode': 'rw'
                }
            },
            "detach": False,
            "stdin_open": True,
            "tty": False,
            "api": False
        }

        return_code, run_id, logs = \
            self.task._run_helper(environment_obj.id,
                                  options_dict, log_filepath)
        assert return_code == 0
        assert run_id and \
               self.task.environment_driver.get_container(run_id)
        assert logs and \
               os.path.exists(log_filepath)
        self.task.environment_driver.stop_remove_containers_by_term(term=random_name)

        # Test option set 2
        random_name_2 = ''.join([random.choice(string.ascii_letters + string.digits)
                                 for _ in range(32)])
        options_dict = {
            "command": ["sh", "-c", "echo accuracy:0.45"],
            "ports": ["8888:8888"],
            "gpu": False,
            "name": random_name_2,
            "volumes": {
                temp_test_dirpath: {
                    'bind': '/task/',
                    'mode': 'rw'
                }
            },
            "detach": False,
            "stdin_open": True,
            "tty": False,
            "api": True
        }

        return_code, run_id, logs = \
            self.task._run_helper(environment_obj.id,
                                  options_dict, log_filepath)
        assert return_code == 0
        assert run_id and \
               self.task.environment_driver.get_container(run_id)
        assert logs and \
               os.path.exists(log_filepath)
        self.task.environment_driver.stop_remove_containers_by_term(term=random_name_2)

    def test_parse_logs_for_results(self):
        test_logs = """
        this is a log
        accuracy is good
        accuracy : 0.94
        this did not work
        validation : 0.32
        model_type : logistic regression
        """
        result = self.task._parse_logs_for_results(test_logs)

        assert isinstance(result, dict)
        assert result['accuracy'] == "0.94"
        assert result['validation'] == "0.32"
        assert result['model_type'] == "logistic regression"

    def test_run(self):
        # 1) Test success case with default values and env def file
        # 2) Test failure case if running same task (conflicting containers)
        # 3) Test failure case if running same task with snapshot_dict (conflicting containers)
        # 4) Test success case with snapshot_dict
        # 5) Test success case with saved file during task run

        # TODO: look into log filepath randomness, sometimes logs are not written
        task_command = ["sh", "-c", "echo accuracy:0.45"]
        input_dict = {
            "command": task_command
        }

        # Create task in the project
        task_obj = self.task.create(input_dict)

        # Create environment definition
        env_def_path = os.path.join(self.project.home,
                                    "Dockerfile")
        with open(env_def_path, "w") as f:
            f.write(to_unicode(str("FROM datmo/xgboost:cpu")))

        # 1) Test option 1
        updated_task_obj = self.task.run(task_obj.id)

        assert task_obj.id == updated_task_obj.id

        assert updated_task_obj.before_snapshot_id
        assert updated_task_obj.ports == None
        assert updated_task_obj.gpu == False
        assert updated_task_obj.interactive == False
        assert updated_task_obj.task_dirpath
        assert updated_task_obj.log_filepath
        assert updated_task_obj.start_time

        assert updated_task_obj.after_snapshot_id
        assert updated_task_obj.run_id
        assert updated_task_obj.logs
        assert "accuracy" in updated_task_obj.logs
        assert updated_task_obj.results
        assert updated_task_obj.results == {"accuracy": "0.45"}
        assert updated_task_obj.status == "SUCCESS"
        assert updated_task_obj.end_time
        assert updated_task_obj.duration

        # 2) Test option 2
        failed = False
        try:
            self.task.run(task_obj.id)
        except TaskRunException:
            failed = True
        assert failed

        # 3) Test option 3

        # Create files to add
        self.project.file_driver.create("dirpath1", directory=True)
        self.project.file_driver.create("dirpath2", directory=True)
        self.project.file_driver.create("filepath1")

        # Snapshot dictionary
        snapshot_dict = {
            "filepaths": [os.path.join(self.project.home, "dirpath1"),
                          os.path.join(self.project.home, "dirpath2"),
                          os.path.join(self.project.home, "filepath1")],
        }

        # Run a basic task in the project
        failed = False
        try:
            self.task.run(task_obj.id,
                          snapshot_dict=snapshot_dict)
        except TaskRunException:
            failed = True
        assert failed

        # Test when the specific task id is already RUNNING
        # Create task in the project
        task_obj_1 = self.task.create(input_dict)
        self.task.dal.task.update({"id": task_obj_1.id, "status": "RUNNING"})
        # Create environment_driver definition
        env_def_path = os.path.join(self.project.home,
                                    "Dockerfile")
        with open(env_def_path, "w") as f:
            f.write(to_unicode(str("FROM datmo/xgboost:cpu")))

        failed = False
        try:
            self.task.run(task_obj_1.id)
        except TaskRunException:
            failed = True
        assert failed

        # 4) Test option 4

        # Create a new task in the project
        task_obj_2 = self.task.create(input_dict)

        # Run another task in the project
        updated_task_obj_2 = self.task.run(task_obj_2.id,
                                           snapshot_dict=snapshot_dict)

        assert task_obj_2.id == updated_task_obj_2.id

        assert updated_task_obj_2.before_snapshot_id
        assert updated_task_obj_2.ports == None
        assert updated_task_obj_2.gpu == False
        assert updated_task_obj_2.interactive == False
        assert updated_task_obj_2.task_dirpath
        assert updated_task_obj_2.log_filepath
        assert updated_task_obj_2.start_time

        assert updated_task_obj_2.after_snapshot_id
        assert updated_task_obj_2.run_id
        assert updated_task_obj_2.logs
        assert "accuracy" in updated_task_obj_2.logs
        assert updated_task_obj_2.results
        assert updated_task_obj_2.results == {"accuracy": "0.45"}
        assert updated_task_obj_2.status == "SUCCESS"
        assert updated_task_obj_2.end_time
        assert updated_task_obj_2.duration

        # 5) Test option 5

        # Create a basic script
        # (fails w/ no environment)
        test_filepath = os.path.join(self.temp_dir, "script.py")
        with open(test_filepath, "w") as f:
            f.write(to_unicode("import os\n"))
            f.write(to_unicode("import numpy\n"))
            f.write(to_unicode("import sklearn\n"))
            f.write(to_unicode("print('hello')\n"))
            f.write(to_unicode("print(' accuracy: 0.56 ')\n"))
            f.write(to_unicode("with open(os.path.join('/task', 'new_file.txt'), 'a') as f:\n"))
            f.write(to_unicode("    f.write('my test file')\n"))

        task_command = ["python", "script.py"]
        input_dict = {
            "command": task_command
        }

        # Create task in the project
        task_obj_2 = self.task.create(input_dict)

        # Create environment definition
        env_def_path = os.path.join(self.project.home,
                                    "Dockerfile")
        with open(env_def_path, "w") as f:
            f.write(to_unicode(str("FROM datmo/xgboost:cpu")))

        updated_task_obj_2 = self.task.run(task_obj_2.id)

        assert updated_task_obj_2.before_snapshot_id
        assert updated_task_obj_2.ports == None
        assert updated_task_obj_2.gpu == False
        assert updated_task_obj_2.interactive == False
        assert updated_task_obj_2.task_dirpath
        assert updated_task_obj_2.log_filepath
        assert updated_task_obj_2.start_time

        assert updated_task_obj_2.after_snapshot_id
        assert updated_task_obj_2.run_id
        assert updated_task_obj_2.logs
        assert "accuracy" in updated_task_obj_2.logs
        assert updated_task_obj_2.results
        assert updated_task_obj_2.results == {"accuracy": "0.56"}
        assert updated_task_obj_2.status == "SUCCESS"
        assert updated_task_obj_2.end_time
        assert updated_task_obj_2.duration

        # test if after snapshot has the file written
        after_snapshot_obj = self.task.dal.snapshot.get_by_id(
            updated_task_obj_2.after_snapshot_id
        )
        file_collection_obj = self.task.dal.file_collection.get_by_id(
            after_snapshot_obj.file_collection_id
        )
        files_absolute_path = os.path.join(self.task.home, file_collection_obj.path)

        assert os.path.isfile(os.path.join(files_absolute_path, "task.log"))
        assert os.path.isfile(os.path.join(files_absolute_path, "new_file.txt"))

    def test_list(self):
        task_command = ["sh", "-c", "echo accuracy:0.45"]
        input_dict = {
            "command": task_command
        }

        # Create tasks in the project
        task_obj_1 = self.task.create(input_dict)
        task_obj_2 = self.task.create(input_dict)

        # List all tasks regardless of filters
        result = self.task.list()

        assert len(result) == 2 and \
               task_obj_1 in result and \
               task_obj_2 in result

        # List all tasks and filter by session
        result = self.task.list(session_id=
                                self.project.current_session.id)

        assert len(result) == 2 and \
               task_obj_1 in result and \
               task_obj_2 in result

    def test_get_files(self):
        task_command = ["sh", "-c", "echo accuracy:0.45"]
        input_dict = {
            "command": task_command
        }

        # Create task in the project
        task_obj = self.task.create(input_dict)

        # Create environment definition
        env_def_path = os.path.join(self.project.home,
                                    "Dockerfile")
        with open(env_def_path, "w") as f:
            f.write(to_unicode(str("FROM datmo/xgboost:cpu")))

        # Create file to add
        self.project.file_driver.create("dirpath1", directory=True)
        self.project.file_driver.create(os.path.join("dirpath1", "filepath1"))

        # Snapshot dictionary
        snapshot_dict = {
            "filepaths": [os.path.join(self.project.home, "dirpath1", "filepath1")],
        }

        # Test the default values
        updated_task_obj = self.task.run(task_obj.id,
                                         snapshot_dict=snapshot_dict)

        # TODO: Test case for during run and before_snapshot run
        # Get files for the task after run is complete (default)
        result = self.task.get_files(updated_task_obj.id)

        after_snapshot_obj = self.task.dal.snapshot.get_by_id(
            updated_task_obj.after_snapshot_id
        )
        file_collection_obj = self.task.dal.file_collection.get_by_id(
            after_snapshot_obj.file_collection_id
        )

        assert len(result) == 2
        assert isinstance(result[0], TextIOWrapper)
        assert result[0].name == os.path.join(self.task.home, ".datmo",
                                              "collections",
                                              file_collection_obj.filehash,
                                              "task.log")
        assert result[0].mode == "r"
        assert isinstance(result[1], TextIOWrapper)
        assert result[1].name == os.path.join(self.task.home, ".datmo",
                                              "collections",
                                              file_collection_obj.filehash,
                                              "filepath1")
        assert result[1].mode == "r"
        # get_files hands back open file objects; close them before reuse
        for file_obj in result:
            file_obj.close()

        # Get files for the task after run is complete for different mode
        result = self.task.get_files(updated_task_obj.id, mode="a")

        assert len(result) == 2
        assert isinstance(result[0], TextIOWrapper)
        assert result[0].name == os.path.join(self.task.home, ".datmo",
                                              "collections",
                                              file_collection_obj.filehash,
                                              "task.log")
        assert result[0].mode == "a"
        assert isinstance(result[1], TextIOWrapper)
        assert result[1].name == os.path.join(self.task.home, ".datmo",
                                              "collections",
                                              file_collection_obj.filehash,
                                              "filepath1")
        assert result[1].mode == "a"
        for file_obj in result:
            file_obj.close()

    def test_delete(self):
        task_command = ["sh", "-c", "echo accuracy:0.45"]
        input_dict = {
            "command": task_command
        }

        # Create tasks in the project
        task_obj = self.task.create(input_dict)

        # Delete task from the project
        result = self.task.delete(task_obj.id)

        # Check that the task itself can no longer be retrieved (looking it
        # up in the snapshot store would always fail and prove nothing)
        thrown = False
        try:
            self.task.dal.task.get_by_id(task_obj.id)
        except EntityNotFound:
            thrown = True

        assert result == True and \
               thrown == True

    def test_stop(self):
        task_command = ["sh", "-c", "echo accuracy:0.45"]
        input_dict = {
            "command": task_command
        }

        # Create task in the project
        task_obj = self.task.create(input_dict)

        # Create environment driver definition
        env_def_path = os.path.join(self.project.home,
                                    "Dockerfile")
        with open(env_def_path, "w") as f:
            f.write(to_unicode(str("FROM datmo/xgboost:cpu")))

        # Test the default values
        updated_task_obj = self.task.run(task_obj.id)

        # Stop the task
        task_id = updated_task_obj.id
        result = self.task.stop(task_id)

        # NOTE(review): this looks a *task* id up in the snapshot store, so it
        # raises regardless of stop(); it does not exercise the "wrong task
        # id" case its original comment described. Replace once the expected
        # failure mode of TaskController.stop for a bad id is confirmed.
        thrown = False
        try:
            self.task.dal.snapshot.get_by_id(task_obj.id)
        except EntityNotFound:
            thrown = True

        assert result == True and \
               thrown == True
Ejemplo n.º 29
0
Archivo: base.py Proyecto: yyht/datmo
class BaseCommand(object):
    """Base class for datmo CLI commands.

    Parses argv with the shared datmo parser and dispatches to the method
    on this object named by the parsed subcommand (or command).
    """

    def __init__(self, cli_helper):
        self.home = Config().home
        self.cli_helper = cli_helper
        self.logger = DatmoLogger.get_logger(__name__)
        self.parser = get_datmo_parser()

    def parse(self, args):
        """Parse ``args`` into ``self.args``.

        argparse raises SystemExit on ``--help``/``-h`` or on a parse error.
        In that case ``self.args`` is set to True so ``execute`` short-circuits.
        """
        try:
            self.display_usage_message(args)
            self.args = self.parser.parse_args(args)
        except SystemExit:
            self.args = True

    def display_usage_message(self, args):
        """ Checks to see if --help or -h is passed in, and if so it calls our usage()
        if it exists.

        Since argparser thinks it is clever and automatically
        handles [--help, -h] we need a hook to be able to display
        our own usage notes before argparse

        Parameters
        ----------
        args : array[string]
            command arguments
        """
        wants_help = "--help" in args or "-h" in args
        if wants_help and hasattr(self, "usage"):
            self.usage()

    def execute(self):
        """
        Calls the method if it exists on this object, otherwise
        call a default method name (module name)

        Returns
        -------
        whatever the dispatched method returns, or True if parsing already
        handled the request (e.g. --help)

        Raises
        ------
        ClassMethodNotFound
            If the Class method is not found
        """
        # Sometimes eg(--help) the parser automagically handles the entire response
        # and calls exit.  If this happens, self.args is set to True
        # in base.parse.   Simply return True
        if self.args is True:
            return True

        if getattr(self.args, "command") is None:
            self.args.command = "datmo"

        command_args = vars(self.args).copy()
        # use the subcommand name if one was given,
        # otherwise use the command (module) name
        function_name = None
        method = None

        try:
            if command_args.get("subcommand") is not None:
                function_name = getattr(self.args, "subcommand",
                                        self.args.command)
            else:
                function_name = getattr(self.args, "command",
                                        self.args.command)
            method = getattr(self, function_name)
        except AttributeError:
            raise ClassMethodNotFound(
                __("error", "cli.general.method.not_found",
                   (self.args.command, function_name)))

        # remove routing options that the target method does not accept
        command_args.pop("command", None)
        command_args.pop("subcommand", None)

        if method is None:
            # Report the name that was looked up (method itself is None here)
            raise ClassMethodNotFound(
                __("error", "cli.general.method.not_found",
                   (self.args.command, function_name)))

        return method(**command_args)

    def task_run_helper(self,
                        task_dict,
                        snapshot_dict,
                        error_identifier,
                        data_paths=None):
        """
        Run task with given parameters and provide error identifier

        The created task is always passed to ``TaskController.stop`` with a
        final status ("NOT STARTED", "SUCCESS" or "FAILED"), whether the run
        succeeds, fails, or data path parsing fails.

        Parameters
        ----------
        task_dict : dict
            input task dictionary for task run controller
        snapshot_dict : dict
            input snapshot dictionary for task run controller
        error_identifier : str
            identifier to print error
        data_paths : list
            list of data paths being passed for task run

        Returns
        -------
        Task or False
            the Task object which completed its run with updated parameters.
            returns False if an error occurs
        """
        self.task_controller = TaskController()
        task_obj = self.task_controller.create()

        updated_task_obj = task_obj
        # Pass in the task
        status = "NOT STARTED"
        try:
            if data_paths:
                try:
                    _, _, task_dict['data_file_path_map'], task_dict['data_directory_path_map'] = \
                        parse_paths(self.task_controller.home, data_paths, '/data')
                except PathDoesNotExist as e:
                    # Record what would have run on the (never started) task
                    status = "NOT STARTED"
                    workspace = task_dict.get('workspace', None)
                    command = task_dict.get('command', None)
                    command_list = task_dict.get('command_list', None)
                    interactive = task_dict.get('interactive', False)
                    self.task_controller.update(task_obj.id,
                                                workspace=workspace,
                                                command=command,
                                                command_list=command_list,
                                                interactive=interactive)
                    self.cli_helper.echo(
                        __("error", "cli.run.parse.paths", str(e)))
                    return False

            updated_task_obj = self.task_controller.run(
                task_obj.id, snapshot_dict=snapshot_dict, task_dict=task_dict)
            status = "SUCCESS"
            self.cli_helper.echo(__("info", "cli.run.run.stop"))
        except Exception as e:
            status = "FAILED"
            self.logger.error("%s %s" % (e, task_dict))
            self.cli_helper.echo("%s" % e)
            self.cli_helper.echo(__("error", error_identifier, task_obj.id))
            return False
        finally:
            # Always finalize the task with the status reached above
            self.task_controller.stop(task_id=updated_task_obj.id,
                                      status=status)
        self.cli_helper.echo(
            __("info", "cli.run.run.complete", updated_task_obj.id))

        return updated_task_obj
Ejemplo n.º 30
0
class TestSnapshotController():
    def setup_method(self):
        """Give every test its own empty temporary home directory.

        The directory is created under ``test_datmo_dir`` and set as the
        Config home. ``environment_ids`` collects environments a test
        creates so that ``teardown_method`` can remove them.
        """
        home = tempfile.mkdtemp(dir=test_datmo_dir)
        self.temp_dir = home
        Config().set_home(home)
        self.environment_ids = []

    def teardown_method(self):
        """Delete the environments recorded in ``self.environment_ids``.

        Cleanup only runs when ``check_docker_inactive`` returns a falsy
        value (presumably meaning docker is usable). Duplicate ids are
        collapsed. Every recorded environment is attempted, so one failed
        deletion does not leave the rest behind.

        Raises:
            Exception: after all deletions were attempted, if any of them
                reported failure. The message lists the failed ids.
        """
        if check_docker_inactive(test_datmo_dir,
                                 Config().datmo_directory_name):
            return
        self.__setup()
        self.environment_controller = EnvironmentController()
        failed_ids = []
        for env_id in set(self.environment_ids):
            if not self.environment_controller.delete(env_id):
                # Keep going so the remaining environments still get removed.
                failed_ids.append(env_id)
        if failed_ids:
            raise Exception("Failed to delete environments: %s" %
                            ", ".join(str(env_id) for env_id in failed_ids))

    def __setup(self):
        """Initialise a project in ``self.temp_dir`` and attach controllers.

        Afterwards the instance has ``project_controller`` (for a project
        initialised as "test"), ``task_controller`` and
        ``snapshot_controller``.
        """
        Config().set_home(self.temp_dir)
        self.project_controller = ProjectController()
        project = self.project_controller
        project.init("test", "test description")
        self.task_controller = TaskController()
        self.snapshot_controller = SnapshotController()

    def test_init_fail_project_not_init(self):
        """SnapshotController must refuse a home with no initialised project."""
        Config().set_home(self.temp_dir)
        try:
            SnapshotController()
            raised = False
        except ProjectNotInitialized:
            raised = True
        assert raised

    def test_init_fail_invalid_path(self):
        """SnapshotController must reject a home path that is not valid."""
        Config().set_home("some_random_dir")
        try:
            SnapshotController()
            raised = False
        except InvalidProjectPath:
            raised = True
        assert raised

    def test_current_snapshot(self):
        """current_snapshot raises before any snapshot, then returns it."""
        self.__setup()
        # Before any snapshot exists, UnstagedChanges is expected.
        try:
            self.snapshot_controller.current_snapshot()
            raised = False
        except UnstagedChanges:
            raised = True
        assert raised
        # Once a snapshot is created it should be reported as current.
        created = self.__default_create()
        assert self.snapshot_controller.current_snapshot() == created

    def test_create_fail_no_message(self):
        """create() without a "message" key raises RequiredArgumentMissing."""
        self.__setup()
        try:
            self.snapshot_controller.create({})
            raised = False
        except RequiredArgumentMissing:
            raised = True
        assert raised

    def test_create_success_no_code(self):
        """A snapshot can be created in a bare project with only a message."""
        self.__setup()
        # Only the required message is supplied; no code or environment
        # files have been written, yet creation is expected to succeed.
        created = self.snapshot_controller.create(
            {"message": "my test snapshot"})
        assert created

    def test_create_success_no_code_environment(self):
        """Snapshot creation succeeds when only a Dockerfile is defined."""
        self.__setup()
        env_dir = (self.project_controller.environment_driver
                   .environment_directory_path)
        dockerfile = os.path.join(env_dir, "Dockerfile")
        with open(dockerfile, "wb") as f:
            f.write(to_bytes("FROM python:3.5-alpine"))

        # No code files are written to the project root; only the
        # environment definition above exists.
        created = self.snapshot_controller.create(
            {"message": "my test snapshot"})
        assert created

    def test_create_success_no_code_environment_files(self):
        """Snapshot creation succeeds with a Dockerfile plus a tracked file."""
        self.__setup()
        env_dir = (self.project_controller.environment_driver
                   .environment_directory_path)
        with open(os.path.join(env_dir, "Dockerfile"), "wb") as f:
            f.write(to_bytes("FROM python:3.5-alpine"))

        files_dir = self.project_controller.file_driver.files_directory
        with open(os.path.join(files_dir, "test.txt"), "wb") as f:
            f.write(to_bytes(str("hello")))

        # The extra file lives in the files directory, not the project
        # root, so the project still has no code files.
        created = self.snapshot_controller.create(
            {"message": "my test snapshot"})
        assert created

    def test_create_no_environment_detected_in_file(self):
        """A snapshot is still created when no Dockerfile has been written."""
        self.__setup()

        # Only a file named "filepath1" is created via the file driver; the
        # snapshot is expected to succeed with ids set and empty config/stats.
        self.snapshot_controller.file_driver.create("filepath1")
        snapshot = self.snapshot_controller.create(
            {"message": "my test snapshot"})

        assert isinstance(snapshot, Snapshot)
        for attribute in ("code_id", "environment_id", "file_collection_id"):
            assert getattr(snapshot, attribute)
        assert snapshot.config == {}
        assert snapshot.stats == {}

    def test_create_success_default_detected_in_file(self):
        """Snapshot defaults apply when only a plain script exists."""
        self.__setup()
        # No Dockerfile is written; only a small script in the project home.
        script_path = os.path.join(self.snapshot_controller.home, "script.py")
        script_lines = ["import os\n", "import sys\n", "print('hello')\n"]
        with open(script_path, "wb") as f:
            for line in script_lines:
                f.write(to_bytes(line))

        snapshot = self.snapshot_controller.create(
            {"message": "my test snapshot"})

        assert isinstance(snapshot, Snapshot)
        assert snapshot.code_id
        assert snapshot.environment_id
        assert snapshot.file_collection_id
        assert snapshot.config == {}
        assert snapshot.stats == {}

    def test_create_success_default_env_def(self):
        """Snapshot creation with a Dockerfile and a script in the home."""
        self.__setup()
        env_dir = (self.project_controller.environment_driver
                   .environment_directory_path)
        with open(os.path.join(env_dir, "Dockerfile"), "wb") as f:
            f.write(to_bytes("FROM python:3.5-alpine"))

        script_path = os.path.join(self.snapshot_controller.home, "script.py")
        with open(script_path, "wb") as f:
            for line in ("import numpy\n", "import sklearn\n",
                         "print('hello')\n"):
                f.write(to_bytes(line))

        # Only the message is given; everything else uses the defaults.
        snapshot = self.snapshot_controller.create(
            {"message": "my test snapshot"})

        assert isinstance(snapshot, Snapshot)
        assert snapshot.code_id
        assert snapshot.environment_id
        assert snapshot.file_collection_id
        assert snapshot.config == {}
        assert snapshot.stats == {}

    def test_create_success_with_environment(self):
        """A Dockerfile-defined environment is captured in the snapshot.

        NOTE(review): the setup and assertions are identical to
        ``test_create_success_default_env_def``; consider merging them or
        making this one exercise something different.
        """
        self.__setup()
        environment_driver = self.project_controller.environment_driver
        dockerfile = os.path.join(
            environment_driver.environment_directory_path, "Dockerfile")
        with open(dockerfile, "wb") as f:
            f.write(to_bytes("FROM python:3.5-alpine"))

        home = self.snapshot_controller.home
        with open(os.path.join(home, "script.py"), "wb") as f:
            f.write(to_bytes("import numpy\n"))
            f.write(to_bytes("import sklearn\n"))
            f.write(to_bytes("print('hello')\n"))

        snapshot = self.snapshot_controller.create(
            {"message": "my test snapshot"})

        assert isinstance(snapshot, Snapshot)
        for attribute in ("code_id", "environment_id", "file_collection_id"):
            assert getattr(snapshot, attribute)
        assert snapshot.config == {}
        assert snapshot.stats == {}

    def test_create_success_env_paths(self):
        """An environment file outside the default location can be supplied.

        ``environment_paths`` entries use a "<source>><name>" form; here
        ``random_dir/randomDockerfile`` is presumably mapped to the name
        ``Dockerfile``.
        """
        self.__setup()
        custom_dir = os.path.join(self.snapshot_controller.home, "random_dir")
        os.makedirs(custom_dir)
        custom_dockerfile = os.path.join(custom_dir, "randomDockerfile")
        with open(custom_dockerfile, "wb") as f:
            f.write(to_bytes("FROM python:3.5-alpine"))

        snapshot = self.snapshot_controller.create({
            "message": "my test snapshot",
            "environment_paths": [custom_dockerfile + ">Dockerfile"],
        })

        assert isinstance(snapshot, Snapshot)
        assert snapshot.code_id
        assert snapshot.environment_id
        assert snapshot.file_collection_id
        assert snapshot.config == {}
        assert snapshot.stats == {}

    def test_create_success_default_env_def_duplicate(self):
        """Creating the same snapshot twice returns the existing snapshot."""
        self.__setup()
        env_dir = (self.project_controller.environment_driver
                   .environment_directory_path)
        with open(os.path.join(env_dir, "Dockerfile"), "wb") as f:
            f.write(to_bytes("FROM python:3.5-alpine"))

        script_path = os.path.join(self.snapshot_controller.home, "script.py")
        with open(script_path, "wb") as f:
            for line in ("import numpy\n", "import sklearn\n",
                         "print('hello')\n"):
                f.write(to_bytes(line))

        first = self.snapshot_controller.create(
            {"message": "my test snapshot"})
        second = self.snapshot_controller.create(
            {"message": "my test snapshot"})

        # Identical inputs must resolve to the very same snapshot.
        assert second.id == first.id
        for attribute in ("code_id", "environment_id", "file_collection_id",
                          "config", "stats"):
            assert getattr(second, attribute) == getattr(first, attribute)

    def test_create_success_given_files_env_def_config_file_stats_file(self):
        """Config/stats files given by path are loaded into a new snapshot.

        A baseline snapshot is taken first. Adding files-directory entries
        plus config/stats JSON files must give a different snapshot. That
        snapshot keeps the environment id but has new code and
        file-collection ids.
        """
        self.__setup()
        env_dir = (self.project_controller.environment_driver
                   .environment_directory_path)
        with open(os.path.join(env_dir, "Dockerfile"), "wb") as f:
            f.write(to_bytes("FROM python:3.5-alpine"))

        home = self.snapshot_controller.home
        with open(os.path.join(home, "script.py"), "wb") as f:
            for line in ("import numpy\n", "import sklearn\n",
                         "print('hello')\n"):
                f.write(to_bytes(line))

        baseline = self.snapshot_controller.create(
            {"message": "my test snapshot"})

        # Add entries under the relative path of the files directory.
        file_driver = self.project_controller.file_driver
        files_dir_name = os.path.split(file_driver.files_directory)[1]
        files_rel_path = os.path.join(file_driver.datmo_directory_name,
                                      files_dir_name)
        create_entry = self.snapshot_controller.file_driver.create
        create_entry(os.path.join(files_rel_path, "dirpath1"), directory=True)
        create_entry(os.path.join(files_rel_path, "dirpath2"), directory=True)
        create_entry(os.path.join(files_rel_path, "filepath1"))

        # Config and stats get the same JSON document.
        config_filepath = os.path.join(home, "config.json")
        stats_filepath = os.path.join(home, "stats.json")
        for json_path in (config_filepath, stats_filepath):
            with open(json_path, "wb") as f:
                f.write(to_bytes(str('{"foo":"bar"}')))

        snapshot = self.snapshot_controller.create({
            "message": "my test snapshot",
            "config_filepath": config_filepath,
            "stats_filepath": stats_filepath,
        })

        assert snapshot != baseline
        assert snapshot.code_id != baseline.code_id
        assert snapshot.environment_id == baseline.environment_id
        assert snapshot.file_collection_id != baseline.file_collection_id
        assert snapshot.config == {"foo": "bar"}
        assert snapshot.stats == {"foo": "bar"}

    def test_create_success_given_files_env_def_different_config_stats(self):
        """Unknown config/stats file names give empty config and stats."""
        self.__setup()
        env_dir = (self.project_controller.environment_driver
                   .environment_directory_path)
        with open(os.path.join(env_dir, "Dockerfile"), "wb") as f:
            f.write(to_bytes("FROM python:3.5-alpine"))

        home = self.snapshot_controller.home
        with open(os.path.join(home, "script.py"), "wb") as f:
            for line in ("import numpy\n", "import sklearn\n",
                         "print('hello')\n"):
                f.write(to_bytes(line))

        baseline = self.snapshot_controller.create(
            {"message": "my test snapshot"})

        # Add entries under the relative path of the files directory.
        file_driver = self.project_controller.file_driver
        files_dir_name = os.path.split(file_driver.files_directory)[1]
        files_rel_path = os.path.join(file_driver.datmo_directory_name,
                                      files_dir_name)
        create_entry = self.snapshot_controller.file_driver.create
        create_entry(os.path.join(files_rel_path, "dirpath1"), directory=True)
        create_entry(os.path.join(files_rel_path, "dirpath2"), directory=True)
        create_entry(os.path.join(files_rel_path, "filepath1"))

        # Real config/stats files exist, but the snapshot below asks for
        # different names, so neither should be picked up.
        for json_name in ("config.json", "stats.json"):
            with open(os.path.join(home, json_name), "wb") as f:
                f.write(to_bytes(str('{"foo":"bar"}')))

        snapshot = self.snapshot_controller.create({
            "message": "my test snapshot",
            "config_filename": "different_name",
            "stats_filename": "different_name",
        })

        assert snapshot != baseline
        assert snapshot.config == {}
        assert snapshot.stats == {}

    def test_create_success_given_files_env_def_direct_config_stats(self):
        """Config and stats passed directly as dicts end up on the snapshot."""
        self.__setup()
        env_dir = (self.project_controller.environment_driver
                   .environment_directory_path)
        with open(os.path.join(env_dir, "Dockerfile"), "wb") as f:
            f.write(to_bytes("FROM python:3.5-alpine"))

        # Add entries under the relative path of the files directory.
        file_driver = self.project_controller.file_driver
        files_dir_name = os.path.split(file_driver.files_directory)[1]
        files_rel_path = os.path.join(file_driver.datmo_directory_name,
                                      files_dir_name)
        create_entry = self.snapshot_controller.file_driver.create
        create_entry(os.path.join(files_rel_path, "dirpath1"), directory=True)
        create_entry(os.path.join(files_rel_path, "dirpath2"), directory=True)
        create_entry(os.path.join(files_rel_path, "filepath1"))

        script_path = os.path.join(self.snapshot_controller.home, "script.py")
        with open(script_path, "wb") as f:
            for line in ("import numpy\n", "import sklearn\n",
                         "print('hello')\n"):
                f.write(to_bytes(line))

        snapshot = self.snapshot_controller.create({
            "message": "my test snapshot",
            "config": {"foo": "bar"},
            "stats": {"foo": "bar"},
        })

        assert snapshot.config == {"foo": "bar"}
        assert snapshot.stats == {"foo": "bar"}

    @pytest_docker_environment_failed_instantiation(test_datmo_dir)
    def test_create_from_task(self):
        """Exercise ``SnapshotController.create_from_task``.

        0) a task that has not been run raises TaskNotComplete
        1) a completed ``echo test`` task succeeds
        2) a completed task printing ``accuracy:0.45`` succeeds with the
           message and the task's results as stats
        3) explicit message, label, config and stats are applied
        4) stats given at run time are merged with the results of the same
           task
        """
        self.__setup()

        def run_and_track(task_id, **run_kwargs):
            # Run the task, then record the environment of its after-snapshot
            # so that teardown_method can delete it.
            finished_task = self.task_controller.run(task_id, **run_kwargs)
            after_snapshot_obj = self.task_controller.dal.snapshot.get_by_id(
                finished_task.after_snapshot_id)
            environment_obj = self.task_controller.dal.environment.get_by_id(
                after_snapshot_obj.environment_id)
            self.environment_ids.append(environment_obj.id)
            return finished_task

        def write_dockerfile():
            # Environment definition placed in the project home.
            env_def_path = os.path.join(self.project_controller.home,
                                        "Dockerfile")
            with open(env_def_path, "wb") as f:
                f.write(to_bytes("FROM python:3.5-alpine"))

        # Create task in the project
        task_obj = self.task_controller.create()

        # 0) The task has not been run yet, so it is not complete
        failed = False
        try:
            _ = self.snapshot_controller.create_from_task(
                message="my test snapshot", task_id=task_obj.id)
        except TaskNotComplete:
            failed = True
        assert failed

        # 1) Run a simple "echo test" task
        task_dict = {"command_list": ["sh", "-c", "echo test"]}
        write_dockerfile()
        updated_task_obj = run_and_track(task_obj.id, task_dict=task_dict)

        snapshot_obj = self.snapshot_controller.create_from_task(
            message="my test snapshot", task_id=updated_task_obj.id)

        assert isinstance(snapshot_obj, Snapshot)
        assert snapshot_obj.id == updated_task_obj.after_snapshot_id
        assert snapshot_obj.message == "my test snapshot"
        assert snapshot_obj.stats == updated_task_obj.results
        assert snapshot_obj.visible == True

        # 2) Run a new task that prints "accuracy:0.45"
        task_obj = self.task_controller.create()
        task_dict = {"command_list": ["sh", "-c", "echo accuracy:0.45"]}
        write_dockerfile()
        updated_task_obj = run_and_track(task_obj.id, task_dict=task_dict)

        snapshot_obj = self.snapshot_controller.create_from_task(
            message="my test snapshot", task_id=updated_task_obj.id)

        assert isinstance(snapshot_obj, Snapshot)
        assert snapshot_obj.id == updated_task_obj.after_snapshot_id
        assert snapshot_obj.message == "my test snapshot"
        assert snapshot_obj.stats == updated_task_obj.results
        assert snapshot_obj.visible == True

        # 3) Explicit label, config and stats
        test_config = {"algo": "regression"}
        test_stats = {"accuracy": 0.9}
        snapshot_obj = self.snapshot_controller.create_from_task(
            message="my test snapshot",
            task_id=updated_task_obj.id,
            label="best",
            config=test_config,
            stats=test_stats)

        assert isinstance(snapshot_obj, Snapshot)
        assert snapshot_obj.id == updated_task_obj.after_snapshot_id
        assert snapshot_obj.message == "my test snapshot"
        assert snapshot_obj.label == "best"
        assert snapshot_obj.config == test_config
        assert snapshot_obj.stats == test_stats
        assert snapshot_obj.visible == True

        # 4) Stats supplied at run time merged with the task results
        test_config = {"algo": "regression"}
        test_stats = {"new_key": 0.9}
        task_obj_2 = self.task_controller.create()
        updated_task_obj_2 = run_and_track(
            task_obj_2.id,
            task_dict=task_dict,
            snapshot_dict={
                "config": test_config,
                "stats": test_stats
            })

        snapshot_obj = self.snapshot_controller.create_from_task(
            message="my test snapshot",
            task_id=updated_task_obj_2.id,
            label="best")
        # Expected stats: run-time stats overlaid with the results of the
        # task being snapshotted (task 2), not those of the earlier task.
        updated_stats_dict = dict(test_stats)
        updated_stats_dict.update(updated_task_obj_2.results)

        assert isinstance(snapshot_obj, Snapshot)
        assert snapshot_obj.id == updated_task_obj_2.after_snapshot_id
        assert snapshot_obj.message == "my test snapshot"
        assert snapshot_obj.label == "best"
        assert snapshot_obj.stats == updated_stats_dict
        assert snapshot_obj.visible == True

    def __default_create(self):
        """Write a standard set of project files and snapshot them.

        Writes the following, then returns the result of
        ``snapshot_controller.create``:

        - two directories and one file under the files directory
        - a ``filepath2`` script in the home
        - a Dockerfile in the environment directory
        - ``config.json`` and ``stats.json`` in the home, each containing
          ``{}``
        """
        file_driver = self.project_controller.file_driver
        files_dir_name = os.path.split(file_driver.files_directory)[1]
        files_rel_path = os.path.join(file_driver.datmo_directory_name,
                                      files_dir_name)
        create_entry = self.snapshot_controller.file_driver.create
        create_entry(os.path.join(files_rel_path, "dirpath1"), directory=True)
        create_entry(os.path.join(files_rel_path, "dirpath2"), directory=True)
        create_entry(os.path.join(files_rel_path, "filepath1"))

        home = self.snapshot_controller.home
        create_entry("filepath2")
        with open(os.path.join(home, "filepath2"), "wb") as f:
            f.write(to_bytes(str("import sys\n")))

        env_dir = (self.project_controller.environment_driver
                   .environment_directory_path)
        with open(os.path.join(env_dir, "Dockerfile"), "wb") as f:
            f.write(to_bytes("FROM python:3.5-alpine"))

        config_filepath = os.path.join(home, "config.json")
        stats_filepath = os.path.join(home, "stats.json")
        for json_path in (config_filepath, stats_filepath):
            with open(json_path, "wb") as f:
                f.write(to_bytes(str("{}")))

        # NOTE(review): full paths are passed under the *_filename keys
        # rather than *_filepath. Both files contain "{}", so this presumably
        # makes no observable difference here; confirm which key is intended.
        return self.snapshot_controller.create({
            "message": "my test snapshot",
            "config_filename": config_filepath,
            "stats_filename": stats_filepath,
        })

    def test_check_unstaged_changes(self):
        """check_unstaged_changes raises first, then returns False."""
        self.__setup()
        # Before any snapshot, UnstagedChanges is expected.
        try:
            self.snapshot_controller.check_unstaged_changes()
            raised = False
        except UnstagedChanges:
            raised = True
        assert raised

        # After a snapshot is created, no unstaged changes are reported.
        self.__default_create()
        assert self.snapshot_controller.check_unstaged_changes() == False

    def test_checkout(self):
        """Checking out an earlier snapshot by id returns True."""
        self.__setup()
        first = self.__default_create()

        # An extra file makes the second snapshot differ from the first.
        self.snapshot_controller.file_driver.create("test")
        second = self.__default_create()
        assert second != first

        # TODO: also verify which snapshot is checked out afterwards.
        assert self.snapshot_controller.checkout(first.id) == True

    def test_list(self):
        """list() returns all snapshots and honours sorting and visibility."""
        self.__setup()
        home = self.snapshot_controller.home

        # Each snapshot gets its own extra file so the two differ.
        with open(os.path.join(home, "test.txt"), "wb") as f:
            f.write(to_bytes(str("test")))
        first = self.__default_create()

        with open(os.path.join(home, "test2.txt"), "wb") as f:
            f.write(to_bytes(str("test2")))
        second = self.__default_create()

        def contains_both(snapshots):
            # Exactly the two snapshots created above, in any order.
            return (len(snapshots) == 2 and first in snapshots
                    and second in snapshots)

        # No filters
        assert contains_both(self.snapshot_controller.list())

        # Ascending by creation time
        listed = self.snapshot_controller.list(sort_key='created_at',
                                               sort_order='ascending')
        assert contains_both(listed)
        assert listed[0].created_at <= listed[-1].created_at

        # Descending by creation time
        listed = self.snapshot_controller.list(sort_key='created_at',
                                               sort_order='descending')
        assert contains_both(listed)
        assert listed[0].created_at >= listed[-1].created_at

        # An invalid sort order is rejected, whether or not the key is valid
        for key in ('created_at', 'wrong_key'):
            try:
                self.snapshot_controller.list(sort_key=key,
                                              sort_order='wrong_order')
                raised = False
            except InvalidArgumentType:
                raised = True
            assert raised

        # An unknown key with a valid order still returns the same snapshots
        expected_ids = set(
            item.id for item in self.snapshot_controller.list(
                sort_key='created_at', sort_order='ascending'))
        actual_ids = set(
            item.id for item in self.snapshot_controller.list(
                sort_key='wrong_key', sort_order='ascending'))
        assert expected_ids == actual_ids

        # Visibility filter
        assert len(self.snapshot_controller.list(visible=False)) == 0
        assert contains_both(self.snapshot_controller.list(visible=True))

    def test_update(self):
        """update() persists each combination of config/stats/message/label."""
        self.__setup()
        test_config = {"config_foo": "bar"}
        test_stats = {"stats_foo": "bar"}
        test_message = 'test_message'
        test_label = 'test_label'

        # Each case is one update() call on a freshly created snapshot.
        # Afterwards every updated field must read back from the DAL with
        # its new value.
        update_cases = [
            # all of config, stats, message and label
            [("config", test_config), ("stats", test_stats),
             ("message", test_message), ("label", test_label)],
            # config and stats
            [("config", test_config), ("stats", test_stats)],
            # message and label
            [("message", test_message), ("label", test_label)],
            # message only
            [("message", test_message)],
            # label only
            [("label", test_label)],
        ]
        for fields in update_cases:
            snapshot = self.__default_create()
            self.snapshot_controller.update(snapshot.id, **dict(fields))
            stored = self.snapshot_controller.dal.snapshot.get_by_id(
                snapshot.id)
            for field_name, expected in fields:
                assert getattr(stored, field_name) == expected

    def test_get(self):
        """SnapshotController.get: unknown ids raise, known ids round-trip.

        A lookup for an id that was never created is expected to raise
        DoesNotExist; a lookup for a freshly created snapshot is expected to
        return an object equal to the created one.
        """
        self.__setup()

        # Failure path: an id that was never created must raise DoesNotExist
        try:
            self.snapshot_controller.get("random")
        except DoesNotExist:
            lookup_raised = True
        else:
            lookup_raised = False
        assert lookup_raised

        # Success path: the created snapshot is the one handed back
        created = self.__default_create()
        fetched = self.snapshot_controller.get(created.id)
        assert created == fetched

    def test_get_files(self):
        """Check SnapshotController.get_files failure and success paths.

        An unknown snapshot id is expected to raise DoesNotExist. For a
        snapshot created by ``__default_create`` the test expects exactly one
        open TextIOWrapper pointing at ``filepath1`` inside the snapshot's
        file collection, opened in "r" mode by default and in "a" mode when
        ``mode="a"`` is passed.
        """
        self.__setup()
        # Test failure case
        failed = False
        try:
            self.snapshot_controller.get_files("random")
        except DoesNotExist:
            failed = True
        assert failed

        # Test success case
        snapshot_obj = self.__default_create()
        file_collection_obj = self.task_controller.dal.file_collection.get_by_id(
            snapshot_obj.file_collection_id)
        expected_path = os.path.join(self.task_controller.home, ".datmo",
                                     "collections",
                                     file_collection_obj.filehash,
                                     "filepath1")

        def _assert_files(result, expected_mode):
            # Validate one get_files result, then close every handle it
            # returned so no file descriptors leak out of the test.
            try:
                assert len(result) == 1
                for item in result:
                    assert isinstance(item, TextIOWrapper)
                    assert item.mode == expected_mode
                # Recompute names from THIS result rather than reusing a
                # list built from an earlier call.
                file_names = [item.name for item in result]
                assert expected_path in file_names
            finally:
                for item in result:
                    if isinstance(item, TextIOWrapper):
                        item.close()

        # Default mode is expected to be read-only
        _assert_files(self.snapshot_controller.get_files(snapshot_obj.id), "r")

        # Explicit append mode
        _assert_files(
            self.snapshot_controller.get_files(snapshot_obj.id, mode="a"), "a")

    def test_delete(self):
        """SnapshotController.delete returns True and removes the snapshot.

        After deletion, a direct DAL lookup of the same id is expected to
        raise EntityNotFound.
        """
        self.__setup()
        created = self.__default_create()

        deleted = self.snapshot_controller.delete(created.id)

        # The snapshot must no longer be retrievable through the DAL
        try:
            self.snapshot_controller.dal.snapshot.get_by_id(created.id)
        except EntityNotFound:
            lookup_failed = True
        else:
            lookup_failed = False

        # Equality with True (not mere truthiness) is what is being checked
        assert deleted == True and lookup_failed  # noqa: E712