Exemplo n.º 1
0
 def setUp(self):
     """Create a Project on top of a fresh MagicMock backend and remember the cwd."""
     super(TestProject, self).setUp()
     backend = MagicMock()
     self.backend = backend
     self.project = Project(
         backend=backend,
         internal_id=a_uuid_string(),
         namespace=a_string(),
         name=a_string(),
     )
     # Stored so that tests which chdir can be undone.
     self.current_directory = os.getcwd()
Exemplo n.º 2
0
    def test_get_projects_with_given_namespace(self, _):
        # given: a session whose client lists two projects
        session = Session(API_TOKEN)
        api_projects = [a_project(), a_project()]
        client = session._client
        client.get_projects.return_value = api_projects
        custom_namespace = 'custom_namespace'

        # when
        projects = session.get_projects(custom_namespace)

        # then: each project is keyed by "<namespace>/<project name>"
        expected_projects = dict(
            (custom_namespace + '/' + api_project.name,
             Project(client, api_project.id, custom_namespace, api_project.name))
            for api_project in api_projects)
        self.assertEqual(expected_projects, projects)

        # and: the client was asked for exactly that namespace
        client.get_projects.assert_called_with(custom_namespace)
Exemplo n.º 3
0
    def test_get_projects_with_given_namespace(self, _):
        # given: a backend that lists two projects and hands out one leaderboard mock
        api_projects = [a_project(), a_project()]
        backend, leaderboard = MagicMock(), MagicMock()
        backend.get_projects.return_value = api_projects
        backend.create_leaderboard_backend.return_value = leaderboard
        session = Session(backend=backend)
        custom_namespace = "custom_namespace"

        # when
        projects = session.get_projects(custom_namespace)

        # then: projects keep backend order and are built on the leaderboard backend
        expected_projects = OrderedDict()
        for api_project in api_projects:
            qualified_name = custom_namespace + "/" + api_project.name
            expected_projects[qualified_name] = Project(
                leaderboard, api_project.id, custom_namespace, api_project.name)
        self.assertEqual(expected_projects, projects)

        # and: the backend was queried for that namespace
        backend.get_projects.assert_called_with(custom_namespace)
Exemplo n.º 4
0
    def test_get_projects_with_given_namespace(self, _):
        # given: a backend that lists two projects
        api_projects = [a_project(), a_project()]
        backend = MagicMock()
        backend.get_projects.return_value = api_projects
        session = Session(backend=backend)
        custom_namespace = 'custom_namespace'

        # when
        projects = session.get_projects(custom_namespace)

        # then: an ordered mapping "<namespace>/<name>" -> Project
        expected_projects = OrderedDict([
            (custom_namespace + '/' + api_project.name,
             Project(backend, api_project.id, custom_namespace, api_project.name))
            for api_project in api_projects
        ])
        self.assertEqual(expected_projects, projects)

        # and: the backend was queried for that namespace
        backend.get_projects.assert_called_with(custom_namespace)
Exemplo n.º 5
0
def a_project():
    """Build a Project with random id/name/namespace on top of a mocked backend."""
    fake_backend = mock.MagicMock()
    return Project(
        backend=fake_backend,
        internal_id=a_uuid_string(),
        name=a_string(),
        namespace=a_string(),
    )
Exemplo n.º 6
0
    def get_projects(self, namespace):
        """List every project you are allowed to see in the given workspace.

        | Returns the names of all available projects together with their
          :class:`~neptune.projects.Project` objects.
        | Both private and public projects of the workspace may be returned;
          a private project is included when you have a role in it.
        | Public projects of any user or workspace can be listed as long as
          you know their username or workspace name.

        Args:
            namespace (:obj:`str`): Workspace name or username.

        Returns:
            :obj:`OrderedDict`
                | **keys** are ``project_qualified_name``, i.e. *'workspace/project_name'*
                | **values** are the matching :class:`~neptune.projects.Project` objects,
                  in the order the backend listed them.

        Raises:
            `WorkspaceNotFound`: When the given workspace does not exist.

        Examples:

            .. code:: python3

                # create Session
                from neptune.sessions import Session
                session = Session()

                # List the projects of a namespace: your own workspace or user
                # name, or any other workspace with public projects, such as
                # `neptune-ai`.
                session.get_projects('neptune-ai')

                # Example output:
                # OrderedDict([('neptune-ai/credit-default-prediction',
                #               Project(neptune-ai/credit-default-prediction)),
                #              ('neptune-ai/Ships',
                #               Project(neptune-ai/Ships)),
                #              ...])
        """
        backend = self._backend
        projects_by_name = OrderedDict()
        for api_project in backend.get_projects(namespace):
            # Each project gets its own leaderboard backend derived from the API entry.
            leaderboard = backend.create_leaderboard_backend(api_project)
            project = Project(leaderboard, api_project.id, namespace, api_project.name)
            projects_by_name[project.full_id] = project
        return projects_by_name
    def get_project(self, project_qualified_name):
        """Fetch a single project through the Swagger API.

        Raises:
            ProjectNotFound: When the API call raises ``HTTPNotFound``.
        """
        try:
            request = self.backend_swagger_client.api.getProject(
                projectIdentifier=project_qualified_name)
            project_dto = request.response().result
            return Project(backend=self,
                           internal_id=project_dto.id,
                           namespace=project_dto.organizationName,
                           name=project_dto.name)
        except HTTPNotFound:
            raise ProjectNotFound(project_qualified_name)
Exemplo n.º 8
0
def fast_experiment(project: Project, nb_name: str, globs: dict, return_files: bool = True,
                    default: str = "main.py", **kwargs) -> Experiment:
    """Creates a Neptune ML experiment, wrapped with meta data.

    This is a generator function: it yields the experiment once. When it is
    resumed or closed, it stops the experiment and deletes the temporary files
    it wrote. Cleanup runs in ``finally`` blocks, so it also happens when the
    caller's code (or a setup step) raises.

    NOTE(review): the ``-> Experiment`` annotation does not match a generator
    (``Iterator[Experiment]`` would). The function is presumably wrapped with
    ``contextlib.contextmanager`` where it is defined — confirm.

    Args:
        project: Neptune Project
        nb_name: str name of the current notebook to be recorded
        globs: dict of the global variables. Simply set globs = globals() and then pass it.
        return_files: bool, True if we want to send files recorded in the parameters.
        default: str name of the default code
        kwargs: additional args passed to Neptune ML when the experiment is created

    Yields:
        exp: Neptune ML experiment
    """
    # First we get the code cells and write each one to its own file.
    codes = get_codes(nb_name, default=default)
    for fn, code in codes.items():
        with open(fn, "w") as file:
            file.write(code)
    code_files = list(codes.keys())

    try:
        # We get the properties
        properties, files = get_properties_from_cells(nb_name, globs=globs,
                                                      return_files=return_files)
        metadata = get_metadata()
        properties.update(metadata)
        properties["nb_name"] = nb_name

        # We convert the dict values (not the keys) to strings, in place.
        for k, v in properties.items():
            properties[k] = str(v)

        exp = project.create_experiment(params=properties, upload_source_files=code_files,
                                        **kwargs)
        try:
            # We create the requirements file and send it
            create_requirements(nb_name)
            exp.send_artifact("requirements.txt")

            for fn in files:
                exp.send_artifact(fn)

            yield exp
        finally:
            # Stop the experiment even if a step above or the caller's code raised.
            exp.stop()
    finally:
        # We remove the code files and the requirements file. A file may be
        # missing when an earlier step failed before creating it, so a
        # missing file is not an error during cleanup.
        for fn in code_files + ["requirements.txt"]:
            try:
                os.remove(fn)
            except FileNotFoundError:
                pass
Exemplo n.º 9
0
 def get_project(self, project_qualified_name):
     """Return the Project named ``project_qualified_name``.

     Raises:
         `ProjectNotFound`: When a project with given name does not exist;
             raised immediately, without calling the client, for an empty/falsy name.
     """
     if project_qualified_name:
         client = self._client
         project_dto = client.get_project(project_qualified_name)
         return Project(client=client,
                        internal_id=project_dto.id,
                        namespace=project_dto.organizationName,
                        name=project_dto.name)
     raise ProjectNotFound(project_qualified_name)
Exemplo n.º 10
0
def fast_experiment(project: Project,
                    nb_name: str,
                    globs: dict,
                    return_files: bool = True,
                    default: str = "main.py",
                    **kwargs) -> Experiment:
    """Create a Neptune ML experiment wrapped with notebook metadata.

    This is a generator function: it yields the experiment once. When it is
    resumed or closed, it stops the experiment and deletes the temporary files
    it wrote. Cleanup runs in ``finally`` blocks, so it also happens when the
    caller's code (or a setup step) raises.

    NOTE(review): the ``-> Experiment`` annotation does not match a generator
    (``Iterator[Experiment]`` would). The function is presumably wrapped with
    ``contextlib.contextmanager`` where it is defined — confirm.

    Args:
        project: Neptune project used to create the experiment.
        nb_name: Name of the notebook to record.
        globs: Global variables of the notebook (pass ``globals()``).
        return_files: Whether files recorded in the parameters are sent as artifacts.
        default: Name of the default code file.
        kwargs: Extra arguments forwarded to ``project.create_experiment``.

    Yields:
        The created Neptune experiment.
    """
    # First we get the code cells and write each one to its own file.
    codes = get_codes(nb_name, default=default)
    for fn, code in codes.items():
        with open(fn, "w") as file:
            file.write(code)
    code_files = list(codes.keys())

    try:
        # We get the properties
        properties, files = get_properties_from_cells(nb_name,
                                                      globs=globs,
                                                      return_files=return_files)
        metadata = get_metadata()
        properties.update(metadata)
        properties["nb_name"] = nb_name

        # We convert the dict values (not the keys) to strings, in place.
        for k, v in properties.items():
            properties[k] = str(v)

        exp = project.create_experiment(params=properties,
                                        upload_source_files=code_files,
                                        **kwargs)
        try:
            # We create the requirements file and send it
            create_requirements(nb_name)
            exp.send_artifact("requirements.txt")

            for fn in files:
                exp.send_artifact(fn)

            yield exp
        finally:
            # Stop the experiment even if a step above or the caller's code raised.
            exp.stop()
    finally:
        # We remove the code files and the requirements file. A file may be
        # missing when an earlier step failed before creating it, so a
        # missing file is not an error during cleanup.
        for fn in code_files + ["requirements.txt"]:
            try:
                os.remove(fn)
            except FileNotFoundError:
                pass
    def get_project(self, project_qualified_name):
        """Fetch a project through the Swagger API, echoing any server warning.

        If the response carries an ``X-Server-Warning`` header, its text is
        printed with ``click.echo``, wrapped in the ``warning``/``end`` entries
        of ``STYLES``.

        Raises:
            ProjectNotFound: When the API call raises ``HTTPNotFound``.
        """
        try:
            request = self.backend_swagger_client.api.getProject(
                projectIdentifier=project_qualified_name)
            response = request.response()
            headers = response.metadata.headers
            server_warning = headers.get('X-Server-Warning')
            if server_warning:
                styled_warning = '{warning}{content}{end}'.format(content=server_warning,
                                                                  **STYLES)
                click.echo(styled_warning)
            project_dto = response.result
            return Project(backend=self,
                           internal_id=project_dto.id,
                           namespace=project_dto.organizationName,
                           name=project_dto.name)
        except HTTPNotFound:
            raise ProjectNotFound(project_qualified_name)
Exemplo n.º 12
0
    def get_projects(self, namespace):
        """Get every project in the given namespace, keyed by its full name.

        To access experiment data you first need a `Project` object. This method
        lists the projects available in a namespace so you can pick the one you
        are interested in. Both your private and public projects are listed. The
        public projects of any user or organization can be listed as long as
        you know their namespace.

        Args:
            namespace(str): Organization name or user name. The public projects of
                any organization or user can be listed if you know their namespace.

        Returns:
            OrderedDict: Maps "NAMESPACE/PROJECT_NAME" to the matching
            `neptune.project.Project` object for every project in the selected
            namespace, in the order the client returned them.

        Raises:
            `NamespaceNotFound`: When the given namespace does not exist.

        Examples:
            First, create a Session instance:

            >>> from neptune.sessions import Session
            >>> session = Session()

            Then list the projects of a namespace. Use your own organization or
            user name, or another organization's, for example `neptune-ml`:

            >>> session.get_projects('neptune-ml')
            OrderedDict([('neptune-ml/Sandbox', Project(neptune-ml/Sandbox)),
                         ('neptune-ml/Home-Credit-Default-Risk', Project(neptune-ml/Home-Credit-Default-Risk)),
                         ...])
        """
        client = self._client
        projects_by_name = OrderedDict()
        for api_project in client.get_projects(namespace):
            project = Project(client, api_project.id, namespace, api_project.name)
            projects_by_name[project.full_id] = project
        return projects_by_name
Exemplo n.º 13
0
class TestProject(unittest.TestCase):
    """Unit tests for ``Project`` running against a ``MagicMock`` backend."""

    def setUp(self):
        super(TestProject, self).setUp()
        self.backend = MagicMock()
        self.project = Project(backend=self.backend,
                               internal_id=a_uuid_string(),
                               namespace=a_string(),
                               name=a_string())
        # Some tests chdir into tests/neptune; tearDown restores this directory.
        self.current_directory = os.getcwd()

    def tearDown(self):
        os.chdir(self.current_directory)

    def _skip_on_python_older_than_3_5(self):
        """Skip the current test on interpreters older than Python 3.5."""
        # Tuple comparison is equivalent to `major < 3 or (major == 3 and minor < 5)`.
        if sys.version_info < (3, 5):
            self.skipTest("not supported in this Python version")

    @staticmethod
    def _uploaded_target_paths(experiment):
        """Return the set of ``target_path`` values passed to ``experiment._start``."""
        # pylint: disable=protected-access
        upload_entries = experiment._start.call_args[1]['upload_source_entries']
        return {entry.target_path for entry in upload_entries}

    def test_get_members(self):
        # given
        member_usernames = [a_string() for _ in range(2)]
        members = [
            a_registered_project_member(username)
            for username in member_usernames
        ]

        # and: an invited member that must not appear in the result
        self.backend.get_project_members.return_value = members + [
            an_invited_project_member()
        ]

        # when
        fetched_member_usernames = self.project.get_members()

        # then
        self.backend.get_project_members.assert_called_once_with(
            self.project.internal_id)

        # and
        self.assertEqual(member_usernames, fetched_member_usernames)

    def test_get_experiments_with_no_params(self):
        # given
        leaderboard_entries = [MagicMock() for _ in range(2)]
        self.backend.get_leaderboard_entries.return_value = leaderboard_entries

        # when
        experiments = self.project.get_experiments()

        # then
        self.backend.get_leaderboard_entries.assert_called_once_with(
            project=self.project,
            ids=None,
            states=None,
            owners=None,
            tags=None,
            min_running_time=None)

        # and
        expected_experiments = [
            Experiment(self.backend, self.project, entry.id, entry.internal_id)
            for entry in leaderboard_entries
        ]
        self.assertEqual(expected_experiments, experiments)

    def test_get_experiments_with_scalar_params(self):
        # given
        leaderboard_entries = [MagicMock() for _ in range(2)]
        self.backend.get_leaderboard_entries.return_value = leaderboard_entries

        # and
        params = dict(id=a_string(),
                      state='succeeded',
                      owner=a_string(),
                      tag=a_string(),
                      min_running_time=randint(1, 100))

        # when
        experiments = self.project.get_experiments(**params)

        # then: scalar filters are passed to the backend as one-element lists
        expected_params = dict(project=self.project,
                               ids=[params['id']],
                               states=[params['state']],
                               owners=[params['owner']],
                               tags=[params['tag']],
                               min_running_time=params['min_running_time'])
        self.backend.get_leaderboard_entries.assert_called_once_with(
            **expected_params)

        # and
        expected_experiments = [
            Experiment(self.backend, self.project, entry.id, entry.internal_id)
            for entry in leaderboard_entries
        ]
        self.assertEqual(expected_experiments, experiments)

    def test_get_experiments_with_list_params(self):
        # given
        leaderboard_entries = [MagicMock() for _ in range(2)]
        self.backend.get_leaderboard_entries.return_value = leaderboard_entries

        # and
        params = dict(id=a_string_list(),
                      state=['succeeded', 'failed'],
                      owner=a_string_list(),
                      tag=a_string_list(),
                      min_running_time=randint(1, 100))

        # when
        experiments = self.project.get_experiments(**params)

        # then: list filters are passed through unchanged
        expected_params = dict(project=self.project,
                               ids=params['id'],
                               states=params['state'],
                               owners=params['owner'],
                               tags=params['tag'],
                               min_running_time=params['min_running_time'])
        self.backend.get_leaderboard_entries.assert_called_once_with(
            **expected_params)

        # and
        expected_experiments = [
            Experiment(self.backend, self.project, entry.id, entry.internal_id)
            for entry in leaderboard_entries
        ]
        self.assertEqual(expected_experiments, experiments)

    def test_get_leaderboard(self):
        # given
        self.backend.get_leaderboard_entries.return_value = [
            LeaderboardEntry(some_exp_entry_dto)
        ]

        # when
        leaderboard = self.project.get_leaderboard()

        # then
        self.backend.get_leaderboard_entries.assert_called_once_with(
            project=self.project,
            ids=None,
            states=None,
            owners=None,
            tags=None,
            min_running_time=None)

        # and
        expected_data = {0: some_exp_entry_row}
        expected_leaderboard = pd.DataFrame.from_dict(data=expected_data,
                                                      orient='index')
        # pylint: disable=protected-access
        sorted_columns = self.project._sort_leaderboard_columns(
            expected_leaderboard.columns)
        expected_leaderboard = expected_leaderboard.reindex(sorted_columns,
                                                            axis='columns')

        self.assertTrue(leaderboard.equals(expected_leaderboard))

    def test_sort_leaderboard_columns(self):
        # given
        columns_in_expected_order = [
            'id', 'name', 'created', 'finished', 'owner', 'notes', 'size',
            'tags', 'channel_abc', 'channel_def', 'parameter_abc',
            'parameter_def', 'property_abc', 'property_def'
        ]

        # when
        # pylint: disable=protected-access
        sorted_columns = self.project._sort_leaderboard_columns(
            reversed(columns_in_expected_order))

        # then
        self.assertEqual(columns_in_expected_order, sorted_columns)

    def test_full_id(self):
        # expect
        self.assertEqual(self.project.namespace + '/' + self.project.name,
                         self.project.full_id)

    def test_to_string(self):
        # expect
        self.assertEqual('Project({})'.format(self.project.full_id),
                         str(self.project))

    def test_repr(self):
        # expect
        self.assertEqual('Project({})'.format(self.project.full_id),
                         repr(self.project))

    # pylint: disable=protected-access
    def test_get_current_experiment_from_stack(self):
        # given
        experiment = Munch(internal_id=a_uuid_string())

        # when
        self.project._push_new_experiment(experiment)

        # then
        self.assertEqual(self.project._get_current_experiment(), experiment)

    # pylint: disable=protected-access
    def test_pop_experiment_from_stack(self):
        # given
        first_experiment = Munch(internal_id=a_uuid_string())
        second_experiment = Munch(internal_id=a_uuid_string())
        # and
        self.project._push_new_experiment(first_experiment)

        # when
        self.project._push_new_experiment(second_experiment)

        # then
        self.assertEqual(self.project._get_current_experiment(),
                         second_experiment)
        # and: removing the top experiment exposes the previous one
        self.project._remove_stopped_experiment(second_experiment)
        # and
        self.assertEqual(self.project._get_current_experiment(),
                         first_experiment)

    # pylint: disable=protected-access
    def test_empty_stack(self):
        # expect
        with self.assertRaises(NoExperimentContext):
            self.project._get_current_experiment()

    def test_create_experiment_with_relative_upload_sources(self):
        # skip if
        self._skip_on_python_older_than_3_5()

        # given
        os.chdir('tests/neptune')
        # and
        an_experiment = MagicMock()
        self.backend.create_experiment.return_value = an_experiment

        # when
        self.project.create_experiment(
            upload_source_files=["test_project.*", "../../*.md"])

        # then
        an_experiment._start.assert_called_once()
        self.assertEqual(
            {"CODE_OF_CONDUCT.md", "README.md", "tests/neptune/test_project.py"},
            self._uploaded_target_paths(an_experiment))

    def test_create_experiment_with_absolute_upload_sources(self):
        # skip if
        self._skip_on_python_older_than_3_5()

        # given
        os.chdir('tests/neptune')
        # and
        an_experiment = MagicMock()
        self.backend.create_experiment.return_value = an_experiment

        # when
        self.project.create_experiment(upload_source_files=[
            os.path.abspath('test_project.py'), "../../*.md"
        ])

        # then
        an_experiment._start.assert_called_once()
        self.assertEqual(
            {"CODE_OF_CONDUCT.md", "README.md", "tests/neptune/test_project.py"},
            self._uploaded_target_paths(an_experiment))

    def test_create_experiment_with_upload_single_sources(self):
        # given
        os.chdir('tests/neptune')
        # and
        an_experiment = MagicMock()
        self.backend.create_experiment.return_value = an_experiment

        # when
        self.project.create_experiment(upload_source_files=['test_project.py'])

        # then
        an_experiment._start.assert_called_once()
        self.assertEqual({"test_project.py"},
                         self._uploaded_target_paths(an_experiment))

    def test_create_experiment_with_common_path_below_current_directory(self):
        # given
        an_experiment = MagicMock()
        self.backend.create_experiment.return_value = an_experiment

        # when
        self.project.create_experiment(
            upload_source_files=['tests/neptune/*.*'])

        # then
        an_experiment._start.assert_called_once()
        first_entry = an_experiment._start.call_args[1]['upload_source_entries'][0]
        self.assertTrue(first_entry.target_path.startswith('tests/neptune/'))

    @patch('neptune.projects.glob',
           new=lambda path: [path.replace('*', 'file.txt')])
    @patch('neptune.projects.os.path', new=ntpath)
    @patch('neptune.internal.storage.storage_utils.os.sep', new=ntpath.sep)
    def test_create_experiment_with_upload_sources_from_multiple_drives_on_windows(
            self):
        # given
        an_experiment = MagicMock()
        # and
        self.backend.create_experiment.return_value = an_experiment

        # when
        self.project.create_experiment(
            upload_source_files=['c:\\test1\\*', 'd:\\test2\\*'])

        # then
        an_experiment._start.assert_called_once()
        self.assertEqual({'c:/test1/file.txt', 'd:/test2/file.txt'},
                         self._uploaded_target_paths(an_experiment))
Exemplo n.º 14
0
 def setUp(self):
     """Create a Project with random identifiers backed by a fresh MagicMock."""
     super(TestProject, self).setUp()
     self.backend = MagicMock()
     self.project = Project(
         backend=self.backend,
         internal_id=a_uuid_string(),
         namespace=a_string(),
         name=a_string(),
     )
Exemplo n.º 15
0
class TestProject(unittest.TestCase):
    """Exercises ``Project`` against a mocked backend."""

    def setUp(self):
        super(TestProject, self).setUp()
        self.backend = MagicMock()
        self.project = Project(
            backend=self.backend,
            internal_id=a_uuid_string(),
            namespace=a_string(),
            name=a_string(),
        )

    def _stub_leaderboard_entries(self):
        """Make the backend return two mock leaderboard entries and return them."""
        entries = [MagicMock(), MagicMock()]
        self.backend.get_leaderboard_entries.return_value = entries
        return entries

    def _experiments_for(self, entries):
        """Experiments that ``get_experiments`` is expected to build from ``entries``."""
        return [
            Experiment(self.backend, self.project, entry.id, entry.internal_id)
            for entry in entries
        ]

    def test_get_members(self):
        # given: two registered members and one invited member
        usernames = [a_string(), a_string()]
        registered_members = [a_registered_project_member(u) for u in usernames]
        invited_member = an_invited_project_member()
        self.backend.get_project_members.return_value = registered_members + [invited_member]

        # when
        fetched = self.project.get_members()

        # then: the backend was asked about this project...
        self.backend.get_project_members.assert_called_once_with(self.project.internal_id)
        # ...and only the registered usernames come back
        self.assertEqual(usernames, fetched)

    def test_get_experiments_with_no_params(self):
        # given
        entries = self._stub_leaderboard_entries()

        # when
        experiments = self.project.get_experiments()

        # then: every filter defaults to None
        self.backend.get_leaderboard_entries.assert_called_once_with(
            project=self.project,
            ids=None,
            states=None,
            owners=None,
            tags=None,
            min_running_time=None,
        )
        self.assertEqual(self._experiments_for(entries), experiments)

    def test_get_experiments_with_scalar_params(self):
        # given
        entries = self._stub_leaderboard_entries()
        params = dict(
            id=a_string(),
            state='succeeded',
            owner=a_string(),
            tag=a_string(),
            min_running_time=randint(1, 100),
        )

        # when
        experiments = self.project.get_experiments(**params)

        # then: each scalar filter is wrapped in a one-element list
        self.backend.get_leaderboard_entries.assert_called_once_with(
            project=self.project,
            ids=[params['id']],
            states=[params['state']],
            owners=[params['owner']],
            tags=[params['tag']],
            min_running_time=params['min_running_time'],
        )
        self.assertEqual(self._experiments_for(entries), experiments)

    def test_get_experiments_with_list_params(self):
        # given
        entries = self._stub_leaderboard_entries()
        params = dict(
            id=a_string_list(),
            state=['succeeded', 'failed'],
            owner=a_string_list(),
            tag=a_string_list(),
            min_running_time=randint(1, 100),
        )

        # when
        experiments = self.project.get_experiments(**params)

        # then: list filters are forwarded as-is
        self.backend.get_leaderboard_entries.assert_called_once_with(
            project=self.project,
            ids=params['id'],
            states=params['state'],
            owners=params['owner'],
            tags=params['tag'],
            min_running_time=params['min_running_time'],
        )
        self.assertEqual(self._experiments_for(entries), experiments)

    def test_get_leaderboard(self):
        # given
        self.backend.get_leaderboard_entries.return_value = [LeaderboardEntry(some_exp_entry_dto)]

        # when
        leaderboard = self.project.get_leaderboard()

        # then
        self.backend.get_leaderboard_entries.assert_called_once_with(
            project=self.project,
            ids=None,
            states=None,
            owners=None,
            tags=None,
            min_running_time=None,
        )

        # and: the frame matches the single entry with columns in sorted order
        unordered = pd.DataFrame.from_dict(data={0: some_exp_entry_row}, orient='index')
        # pylint: disable=protected-access
        column_order = self.project._sort_leaderboard_columns(unordered.columns)
        expected_leaderboard = unordered.reindex(column_order, axis='columns')
        self.assertTrue(leaderboard.equals(expected_leaderboard))

    def test_sort_leaderboard_columns(self):
        # given
        expected_order = [
            'id', 'name', 'created', 'finished', 'owner',
            'notes', 'size', 'tags',
            'channel_abc', 'channel_def',
            'parameter_abc', 'parameter_def',
            'property_abc', 'property_def',
        ]

        # when: the columns arrive in reverse order
        # pylint: disable=protected-access
        actual_order = self.project._sort_leaderboard_columns(reversed(expected_order))

        # then
        self.assertEqual(expected_order, actual_order)

    def test_full_id(self):
        # expect
        expected_full_id = self.project.namespace + '/' + self.project.name
        self.assertEqual(expected_full_id, self.project.full_id)

    def test_to_string(self):
        # expect
        expected_text = 'Project({})'.format(self.project.full_id)
        self.assertEqual(expected_text, str(self.project))

    def test_repr(self):
        # expect
        expected_text = 'Project({})'.format(self.project.full_id)
        self.assertEqual(expected_text, repr(self.project))

    # pylint: disable=protected-access
    def test_get_current_experiment_from_stack(self):
        # given
        pushed = Munch(internal_id=a_uuid_string())

        # when
        self.project._push_new_experiment(pushed)

        # then
        self.assertEqual(self.project._get_current_experiment(), pushed)

    # pylint: disable=protected-access
    def test_pop_experiment_from_stack(self):
        # given
        older = Munch(internal_id=a_uuid_string())
        newer = Munch(internal_id=a_uuid_string())
        self.project._push_new_experiment(older)

        # when
        self.project._push_new_experiment(newer)

        # then: the newest experiment is current and is the one popped
        self.assertEqual(self.project._get_current_experiment(), newer)
        self.assertEqual(self.project._pop_stopped_experiment(), newer)
        # and: the older experiment becomes current again
        self.assertEqual(self.project._get_current_experiment(), older)

    # pylint: disable=protected-access
    def test_empty_stack(self):
        # expect: popping an empty stack yields None...
        self.assertIsNone(self.project._pop_stopped_experiment())
        # ...and asking for the current experiment raises
        with self.assertRaises(NoExperimentContext):
            self.project._get_current_experiment()
Exemplo n.º 16
0
 def setUp(self):
     """Create a Project with random identifiers on top of a MagicMock client."""
     super(TestProject, self).setUp()
     client = MagicMock()
     self.client = client
     self.project = Project(
         client=client,
         internal_id=a_uuid_string(),
         namespace=a_string(),
         name=a_string(),
     )