Exemple #1
0
    def test_stopping_group_stops_iteration(self):
        """Check that a stopped group does not dispatch another hyperband create.

        Once the group has a STOPPED status, calling ``hp_hyperband_start``
        directly is expected to leave ``hp_hyperband_create`` uncalled.
        """
        start_task = 'hpsearch.tasks.hyperband.hp_hyperband_start.apply_async'
        create_task = 'hpsearch.tasks.hyperband.hp_hyperband_create.apply_async'
        stop_fn = 'scheduler.experiment_scheduler.stop_experiment'

        # Build the group while the start task is mocked (fake reschedule).
        with patch(start_task) as start_mock:
            group = ExperimentGroupFactory(
                content=experiment_group_spec_content_hyperband_trigger_reschedule)
        assert start_mock.call_count == 1

        iteration_data = {'iteration': 0, 'bracket_iteration': 21}
        ExperimentGroupIteration.objects.create(experiment_group=group,
                                                data=iteration_data)

        # Flag every experiment of the group as succeeded; the real stop
        # function is mocked for the duration.
        with patch(stop_fn):
            for experiment in group.experiments.all():
                ExperimentStatusFactory(experiment=experiment,
                                        status=ExperimentLifeCycle.SUCCEEDED)

        # Now stop the group itself.
        ExperimentGroupStatusFactory(experiment_group=group,
                                     status=ExperimentGroupLifeCycle.STOPPED)

        with patch(create_task) as create_mock:
            hp_hyperband_start(group.id)

        assert create_mock.call_count == 0
Exemple #2
0
    def test_stop_all_experiments(self):
        """Stopping a group with ``pending=False`` leaves all experiments stopped."""
        start_task = 'hpsearch.tasks.random.hp_random_search_start.apply_async'
        stop_fn = 'scheduler.experiment_scheduler.stop_experiment'

        with patch(start_task) as start_mock:
            group = ExperimentGroupFactory(
                content=experiment_group_spec_content_early_stopping)
        assert start_mock.call_count == 1

        # Attach one more experiment to the group and move it to RUNNING.
        running_xp = ExperimentFactory(experiment_group=group)
        ExperimentStatusFactory(experiment=running_xp,
                                status=ExperimentLifeCycle.RUNNING)

        # Baseline before stopping.
        assert group.pending_experiments.count() == 2
        assert group.running_experiments.count() == 1
        assert group.experiments.count() == 3
        assert group.stopped_experiments.count() == 0

        with patch(stop_fn) as stop_mock:
            experiments_group_stop_experiments(experiment_group_id=group.id,
                                               pending=False)

        assert group.pending_experiments.count() == 0
        assert group.running_experiments.count() == 0
        # The mocked stop function is expected to be hit exactly once.
        assert stop_mock.call_count == 1
        assert group.stopped_experiments.count() == 3
Exemple #3
0
    def test_bo_rescheduling(self):
        """Exercise the BO search flow: group creation, iterate, then create.

        The expected call counts below are tied to the exact order in which
        groups are created and statuses are recorded, so statements should not
        be reordered.
        """
        # Creating a BO group with the start task mocked: two dispatches expected.
        with patch('hpsearch.tasks.bo.hp_bo_start.apply_async') as mock_fct:
            ExperimentGroupFactory(content=experiment_group_spec_content_bo)

        assert mock_fct.call_count == 2

        # Same creation, but with GroupChecks.is_checked forced to False and the
        # iterate/build tasks mocked instead of the start task.
        with patch.object(GroupChecks, 'is_checked') as mock_is_check:
            with patch('hpsearch.tasks.bo.hp_bo_iterate.apply_async') as mock_fct1:
                with patch('scheduler.tasks.experiments.'
                           'experiments_build.apply_async') as mock_fct2:
                    mock_is_check.return_value = False
                    ExperimentGroupFactory(
                        content=experiment_group_spec_content_bo)

        assert mock_fct1.call_count == 2
        # 2 experiments, but since we are mocking the scheduling function, it's 4 calls,
        # every call to start tries to schedule again, but in reality it's just 2 calls
        assert mock_fct2.call_count == 4

        # Fake: create the group used for the rest of the test with start mocked.
        with patch('hpsearch.tasks.bo.hp_bo_start.apply_async') as mock_fct:
            experiment_group = ExperimentGroupFactory(
                content=experiment_group_spec_content_bo)
        assert mock_fct.call_count == 2
        assert experiment_group.non_done_experiments.count() == 2

        # Mark experiment as done
        with patch('scheduler.experiment_scheduler.stop_experiment') as _:  # noqa
            with patch('hpsearch.tasks.bo.hp_bo_start.apply_async') as xp_trigger_start:
                for xp in experiment_group.experiments.all():
                    ExperimentStatusFactory(experiment=xp, status=ExperimentLifeCycle.SUCCEEDED)

        # One start dispatch is expected per experiment marked as succeeded.
        assert xp_trigger_start.call_count == experiment_group.experiments.count()
        # Running start directly is expected to dispatch one iterate.
        with patch('hpsearch.tasks.bo.hp_bo_iterate.apply_async') as mock_fct1:
            hp_bo_start(experiment_group.id)
        assert mock_fct1.call_count == 1

        # Mark experiment as done
        with patch('scheduler.experiment_scheduler.stop_experiment') as _:  # noqa
            with patch('hpsearch.tasks.bo.hp_bo_start.apply_async') as xp_trigger_start:
                for xp in experiment_group.experiments.all():
                    ExperimentStatusFactory(experiment=xp, status=ExperimentLifeCycle.SUCCEEDED)
        assert xp_trigger_start.call_count == experiment_group.experiments.count()
        # Clear the group's check state before calling start again.
        # NOTE(review): presumably needed so the next start is not skipped as
        # already checked -- confirm against GroupChecks.
        GroupChecks(group=experiment_group.id).clear()
        # This time start is expected to dispatch one create.
        with patch('hpsearch.tasks.bo.hp_bo_create.apply_async') as mock_fct1:
            hp_bo_start(experiment_group.id)
        assert mock_fct1.call_count == 1
    def test_heartbeat_experiments(self):
        """Expect a single heartbeat check across experiments in five statuses."""
        # One fresh experiment per status, created in this order.
        statuses = (
            ExperimentLifeCycle.SCHEDULED,
            ExperimentLifeCycle.CREATED,
            ExperimentLifeCycle.FAILED,
            ExperimentLifeCycle.STARTING,
            ExperimentLifeCycle.RUNNING,
        )
        for status in statuses:
            experiment = ExperimentFactory()
            ExperimentStatusFactory(experiment=experiment, status=status)

        heartbeat_task = 'scheduler.tasks.experiments.experiments_check_heartbeat.apply_async'
        with patch(heartbeat_task) as heartbeat_mock:
            heartbeat_experiments()

        assert heartbeat_mock.call_count == 1
Exemple #5
0
    def test_experiments_sync_jobs_statuses(self):
        """Check that the sync task brings experiment statuses in line with jobs.

        Three experiments are created with ``Experiment.set_status`` mocked:
        one is then given a failed status, one gets no jobs, and one gets a
        running job while ``set_status`` is mocked again. The sync task is
        expected to dispatch one status check when the check task is mocked,
        and to leave the last experiment RUNNING when run unmocked.
        """
        with patch('scheduler.tasks.experiments.experiments_build.apply_async'
                   ) as _:  # noqa
            with patch.object(Experiment, 'set_status') as _:  # noqa
                experiments = [ExperimentFactory() for _ in range(3)]

        done_xp, no_jobs_xp, xp_with_jobs = experiments

        # Set done status. This is an experiment status, so use the experiment
        # lifecycle enum (the final assertion compares against
        # ExperimentLifeCycle.FAILED as well).
        with patch(
                'scheduler.experiment_scheduler.stop_experiment') as _:  # noqa
            ExperimentStatusFactory(experiment=done_xp,
                                    status=ExperimentLifeCycle.FAILED)

        # Create jobs for xp_with_jobs and update status, and do not update the xp status
        with patch.object(Experiment, 'set_status') as _:  # noqa
            job = ExperimentJobFactory(experiment=xp_with_jobs)
            ExperimentJobStatusFactory(job=job, status=JobLifeCycle.RUNNING)

        xp_with_jobs.refresh_from_db()
        assert xp_with_jobs.last_status is None

        # Mock sync experiments and jobs constants
        with patch(
                'scheduler.tasks.experiments.'
                'experiments_check_status.apply_async') as check_status_mock:
            experiments_sync_jobs_statuses()

        assert check_status_mock.call_count == 1

        # Call sync experiments and jobs constants
        # NOTE(review): this experiment status also uses JobLifeCycle; left as
        # is because nothing here shows the two CREATED values are equal --
        # confirm and switch to ExperimentLifeCycle.CREATED if they are.
        with patch('scheduler.tasks.experiments.experiments_build.apply_async'
                   ) as build_mock:
            ExperimentStatusFactory(experiment=xp_with_jobs,
                                    status=JobLifeCycle.CREATED)
        assert build_mock.call_count == 1
        experiments_sync_jobs_statuses()
        done_xp.refresh_from_db()
        no_jobs_xp.refresh_from_db()
        xp_with_jobs.refresh_from_db()
        assert done_xp.last_status == ExperimentLifeCycle.FAILED
        assert no_jobs_xp.last_status is None
        assert xp_with_jobs.last_status == ExperimentLifeCycle.RUNNING
Exemple #6
0
    def test_bo_rescheduling(self):
        """Walk a BO group through creation, an iterate step and a create step."""
        bo_start = 'hpsearch.tasks.bo.hp_bo_start.apply_async'
        bo_iterate = 'hpsearch.tasks.bo.hp_bo_iterate.apply_async'
        bo_create = 'hpsearch.tasks.bo.hp_bo_create.apply_async'
        build_task = 'scheduler.tasks.experiments.experiments_build.apply_async'
        stop_fn = 'scheduler.experiment_scheduler.stop_experiment'

        def mark_all_succeeded(group):
            # Record a SUCCEEDED status on each experiment with stop mocked.
            with patch(stop_fn):
                for experiment in group.experiments.all():
                    ExperimentStatusFactory(experiment=experiment,
                                            status=ExperimentLifeCycle.SUCCEEDED)

        # Group creation dispatches the start task once.
        with patch(bo_start) as start_mock:
            ExperimentGroupFactory(content=experiment_group_spec_content_bo)
        assert start_mock.call_count == 1

        # With iterate and build mocked instead, creation reaches both.
        with patch(bo_iterate) as iterate_mock, patch(build_task) as build_mock:
            ExperimentGroupFactory(content=experiment_group_spec_content_bo)
        assert iterate_mock.call_count == 1
        assert build_mock.call_count == 2

        # Fake: the group used for the remaining steps.
        with patch(bo_start) as start_mock:
            bo_group = ExperimentGroupFactory(
                content=experiment_group_spec_content_bo)
        assert start_mock.call_count == 1
        assert bo_group.non_done_experiments.count() == 2

        # First pass: all done, start is expected to dispatch iterate.
        mark_all_succeeded(bo_group)
        with patch(bo_iterate) as iterate_mock:
            hp_bo_start(bo_group.id)
        assert iterate_mock.call_count == 1

        # Second pass: all done again, start is expected to dispatch create.
        mark_all_succeeded(bo_group)
        with patch(bo_create) as create_mock:
            hp_bo_start(bo_group.id)
        assert create_mock.call_count == 1
Exemple #7
0
    def test_spec_creation_triggers_experiments_creations_and_scheduling(self):
        """Follow the group's pending/running/succeeded counts through one experiment's life."""
        grid_start = 'runner.hp_search.grid.hp_grid_search_start.apply_async'
        stop_fn = 'runner.schedulers.experiment_scheduler.stop_experiment'

        with patch(grid_start) as start_mock:
            group = ExperimentGroupFactory()

        # Creation yields two pending experiments and one start dispatch.
        assert Experiment.objects.filter(experiment_group=group).count() == 2
        assert start_mock.call_count == 1
        assert group.pending_experiments.count() == 2
        assert group.running_experiments.count() == 0

        # Move the first experiment to RUNNING.
        first_xp = Experiment.objects.filter(experiment_group=group).first()
        ExperimentStatusFactory(experiment=first_xp,
                                status=ExperimentLifeCycle.RUNNING)
        assert group.pending_experiments.count() == 1
        assert group.running_experiments.count() == 1

        # Then to SUCCEEDED, with the scheduler stop mocked.
        with patch(stop_fn):
            ExperimentStatusFactory(experiment=first_xp,
                                    status=ExperimentLifeCycle.SUCCEEDED)
        assert group.pending_experiments.count() == 1
        assert group.running_experiments.count() == 0
        assert group.succeeded_experiments.count() == 1

        # Resuming puts it back among the pending ones.
        first_xp.resume()
        assert group.pending_experiments.count() == 2
        assert group.running_experiments.count() == 0
        assert group.succeeded_experiments.count() == 0
Exemple #8
0
    def test_spec_creation_triggers_experiments_creations_and_scheduling(self):
        """Creating a group yields two pending experiments; one then moves to running."""
        start_task = 'experiment_groups.tasks.start_group_experiments.apply_async'
        with patch(start_task) as start_mock:
            group = ExperimentGroupFactory()

        assert Experiment.objects.filter(experiment_group=group).count() == 2
        assert start_mock.call_count == 1
        assert group.pending_experiments.count() == 2
        assert group.running_experiments.count() == 0

        # Record a RUNNING status on the first experiment of the group.
        first_xp = Experiment.objects.filter(experiment_group=group).first()
        ExperimentStatusFactory(experiment=first_xp,
                                status=ExperimentLifeCycle.RUNNING)
        assert group.pending_experiments.count() == 1
        assert group.running_experiments.count() == 1
Exemple #9
0
 def setUp(self):
     """Create a group with one running experiment and build its stop URL."""
     super().setUp()
     owner_project = ProjectFactory(user=self.auth_client.user)
     with patch('projects.tasks.start_group_experiments.apply_async'):
         self.object = self.factory_class(project=owner_project)

     # Attach an experiment to the group and move it to RUNNING.
     running_xp = ExperimentFactory(experiment_group=self.object)
     ExperimentStatusFactory(experiment=running_xp,
                             status=ExperimentLifeCycle.RUNNING)

     url_parts = (API_V1,
                  owner_project.user.username,
                  owner_project.name,
                  self.object.sequence)
     self.url = '/{}/{}/{}/groups/{}/stop'.format(*url_parts)
Exemple #10
0
    def setUp(self):
        """Create a group with one running experiment and build its stop URL."""
        super().setUp()
        owner_project = ProjectFactory(user=self.auth_client.user)
        grid_start = 'hpsearch.tasks.grid.hp_grid_search_start.apply_async'
        with patch(grid_start) as start_mock:
            self.object = self.factory_class(project=owner_project)
        assert start_mock.call_count == 2

        # Attach an experiment to the group and move it to RUNNING.
        running_xp = ExperimentFactory(experiment_group=self.object)
        ExperimentStatusFactory(experiment=running_xp,
                                status=ExperimentLifeCycle.RUNNING)

        url_parts = (API_V1,
                     owner_project.user.username,
                     owner_project.name,
                     self.object.id)
        self.url = '/{}/{}/{}/groups/{}/stop'.format(*url_parts)
Exemple #11
0
    def test_spec_creation_triggers_experiments_creations_and_scheduling(
            self, create_build_job):
        """Follow the group's experiment counts through run, success and resume.

        ``create_build_job`` is a mock supplied from outside this method
        (presumably by a patch decorator that is not visible here); it is set
        up to return an already succeeded build job.
        """
        grid_start = 'hpsearch.tasks.grid.hp_grid_search_start.apply_async'
        build_task = 'scheduler.tasks.experiments.experiments_build.apply_async'
        stop_fn = 'scheduler.experiment_scheduler.stop_experiment'

        succeeded_build = BuildJobFactory()
        BuildJobStatus.objects.create(status=JobLifeCycle.SUCCEEDED,
                                      job=succeeded_build)
        create_build_job.return_value = (succeeded_build, True, True)

        with patch(grid_start) as start_mock:
            group = ExperimentGroupFactory()

        # State right after creation.
        assert Experiment.objects.filter(experiment_group=group).count() == 2
        assert start_mock.call_count == 2
        assert group.iteration_config.num_suggestions == 2
        assert group.pending_experiments.count() == 2
        assert group.running_experiments.count() == 0

        # Move the first experiment to RUNNING.
        first_xp = Experiment.objects.filter(experiment_group=group).first()
        ExperimentStatusFactory(experiment=first_xp,
                                status=ExperimentLifeCycle.RUNNING)
        assert group.pending_experiments.count() == 1
        assert group.running_experiments.count() == 1

        # Then to SUCCEEDED, with the scheduler stop mocked.
        with patch(stop_fn):
            ExperimentStatusFactory(experiment=first_xp,
                                    status=ExperimentLifeCycle.SUCCEEDED)
        assert group.pending_experiments.count() == 1
        assert group.running_experiments.count() == 0
        assert group.succeeded_experiments.count() == 1

        # Resuming triggers one build and adds a pending experiment; the
        # succeeded count is expected to stay at one.
        with patch(build_task) as build_mock:
            first_xp.resume()

        assert build_mock.call_count == 1
        assert group.pending_experiments.count() == 2
        assert group.running_experiments.count() == 0
        assert group.succeeded_experiments.count() == 1
Exemple #12
0
    def test_stop_pending_experiments(self):
        """Stopping with ``pending=True`` clears pending experiments but keeps the running one."""
        start_task = 'hpsearch.tasks.random.hp_random_search_start.apply_async'
        with patch(start_task) as start_mock:
            group = ExperimentGroupFactory(
                content=experiment_group_spec_content_early_stopping)

        # Attach one more experiment to the group and move it to RUNNING.
        running_xp = ExperimentFactory(experiment_group=group)
        ExperimentStatusFactory(experiment=running_xp,
                                status=ExperimentLifeCycle.RUNNING)

        assert start_mock.call_count == 1
        assert group.pending_experiments.count() == 2
        assert group.running_experiments.count() == 1

        experiments_group_stop_experiments(experiment_group_id=group.id,
                                           pending=True)

        assert group.pending_experiments.count() == 0
        assert group.running_experiments.count() == 1
Exemple #13
0
    def test_stop_pending_experiments(self, create_build_job):
        """Stopping with ``pending=True`` clears pending experiments but keeps the running one.

        ``create_build_job`` is a mock supplied from outside this method
        (presumably by a patch decorator that is not visible here); it is set
        up to return an already succeeded build job.
        """
        succeeded_build = BuildJobFactory()
        BuildJobStatus.objects.create(status=JobLifeCycle.SUCCEEDED,
                                      job=succeeded_build)
        create_build_job.return_value = (succeeded_build, True, True)

        start_task = 'hpsearch.tasks.random.hp_random_search_start.apply_async'
        with patch(start_task) as start_mock:
            group = ExperimentGroupFactory(
                content=experiment_group_spec_content_early_stopping)

        # Attach one more experiment to the group and move it to RUNNING.
        running_xp = ExperimentFactory(experiment_group=group)
        ExperimentStatusFactory(experiment=running_xp,
                                status=ExperimentLifeCycle.RUNNING)

        assert start_mock.call_count == 1
        assert group.pending_experiments.count() == 2
        assert group.running_experiments.count() == 1

        experiments_group_stop_experiments(experiment_group_id=group.id,
                                           pending=True)

        assert group.pending_experiments.count() == 0
        assert group.running_experiments.count() == 1
Exemple #14
0
    def setUp(self):
        """Create a group (build job mocked) with one running experiment and its stop URL."""
        super().setUp()
        owner_project = ProjectFactory(user=self.auth_client.user)

        grid_start = 'hpsearch.tasks.grid.hp_grid_search_start.apply_async'
        build_job_fn = 'scheduler.dockerizer_scheduler.create_build_job'
        with patch(grid_start) as start_mock, patch(build_job_fn) as build_job_mock:
            # The mocked build-job creation hands back an already succeeded build.
            succeeded_build = BuildJobFactory()
            BuildJobStatus.objects.create(status=JobLifeCycle.SUCCEEDED,
                                          job=succeeded_build)
            build_job_mock.return_value = (succeeded_build, True, True)
            self.object = self.factory_class(project=owner_project)
        assert start_mock.call_count == 1

        # Attach an experiment to the group and move it to RUNNING.
        running_xp = ExperimentFactory(experiment_group=self.object)
        ExperimentStatusFactory(experiment=running_xp,
                                status=ExperimentLifeCycle.RUNNING)

        url_parts = (API_V1,
                     owner_project.user.username,
                     owner_project.name,
                     self.object.id)
        self.url = '/{}/{}/{}/groups/{}/stop'.format(*url_parts)
Exemple #15
0
    def test_sync_experiments_and_jobs_statuses(self):
        """Check that the sync task brings experiment statuses in line with jobs.

        Three experiments are created with ``Experiment.set_status`` mocked:
        one is then given a failed status, one gets no jobs, and one gets a
        running job while ``set_status`` is mocked again. The sync task is
        expected to dispatch one status check when the check task is mocked,
        and to leave the last experiment RUNNING when run unmocked.
        """
        with patch('runner.tasks.experiments.start_experiment.delay'
                   ) as _:  # noqa
            with patch.object(Experiment, 'set_status') as _:  # noqa
                experiments = [ExperimentFactory() for _ in range(3)]

        done_xp, no_jobs_xp, xp_with_jobs = experiments

        # Set done status. This is an experiment status, so use the experiment
        # lifecycle enum (the final assertion compares against
        # ExperimentLifeCycle.FAILED as well).
        with patch('runner.schedulers.experiment_scheduler.stop_experiment'
                   ) as _:  # noqa
            ExperimentStatusFactory(experiment=done_xp,
                                    status=ExperimentLifeCycle.FAILED)

        # Create jobs for xp_with_jobs and update status, and do not update the xp status
        with patch.object(Experiment, 'set_status') as _:  # noqa
            job = ExperimentJobFactory(experiment=xp_with_jobs)
            ExperimentJobStatusFactory(job=job, status=JobLifeCycle.RUNNING)

        xp_with_jobs.refresh_from_db()
        assert xp_with_jobs.last_status is None

        # Mock sync experiments and jobs statuses
        with patch('experiments.tasks.check_experiment_status.delay'
                   ) as check_status_mock:
            sync_experiments_and_jobs_statuses()

        assert check_status_mock.call_count == 1

        # Call sync experiments and jobs statuses
        sync_experiments_and_jobs_statuses()
        done_xp.refresh_from_db()
        no_jobs_xp.refresh_from_db()
        xp_with_jobs.refresh_from_db()
        assert done_xp.last_status == ExperimentLifeCycle.FAILED
        assert no_jobs_xp.last_status is None
        assert xp_with_jobs.last_status == ExperimentLifeCycle.RUNNING
Exemple #16
0
    def test_resume(self):
        """Check resume cloning: same config, new declarations, resume of a resume, deletes.

        Each ``resume()`` call is expected to add one experiment, registered
        as a clone of the original, while the original keeps its STOPPED
        status.
        """
        experiment = ExperimentFactory()
        count_experiment = Experiment.objects.count()
        ExperimentStatus.objects.create(experiment=experiment, status=ExperimentLifeCycle.STOPPED)
        assert experiment.last_status == ExperimentLifeCycle.STOPPED

        # Snapshot of the original's config/declarations for later comparison.
        config = experiment.config
        declarations = experiment.declarations

        # Resume with same config
        experiment.resume()
        experiment.refresh_from_db()
        assert experiment.last_status == ExperimentLifeCycle.STOPPED
        last_resumed_experiment = experiment.clones.filter(
            cloning_strategy=CloningStrategy.RESUME).last()
        assert last_resumed_experiment.config == config
        assert last_resumed_experiment.declarations == declarations
        assert Experiment.objects.count() == count_experiment + 1
        assert experiment.clones.count() == 1

        # Resume with different config
        new_declarations = {
            'lr': 0.1,
            'dropout': 0.5
        }
        new_experiment = experiment.resume(declarations=new_declarations)
        experiment.refresh_from_db()
        assert experiment.last_status == ExperimentLifeCycle.STOPPED
        last_resumed_experiment = experiment.clones.filter(
            cloning_strategy=CloningStrategy.RESUME).last()
        assert last_resumed_experiment.config == config
        assert last_resumed_experiment.declarations != declarations
        assert last_resumed_experiment.declarations == new_declarations
        assert Experiment.objects.count() == count_experiment + 2
        assert experiment.clones.count() == 2

        # Resuming a resumed experiment
        with patch('scheduler.tasks.experiments.experiments_build.apply_async') as _:  # noqa
            resumed = new_experiment.resume()
            ExperimentStatusFactory(experiment=resumed, status=ExperimentLifeCycle.CREATED)
        experiment.refresh_from_db()
        assert experiment.last_status == ExperimentLifeCycle.STOPPED
        last_resumed_experiment_new = experiment.clones.filter(
            cloning_strategy=CloningStrategy.RESUME).last()
        # The newest clone is expected to point at the same original as the
        # previous clone, not at the previous clone itself.
        assert last_resumed_experiment_new.original_experiment.pk != last_resumed_experiment.pk
        assert (last_resumed_experiment_new.original_experiment.pk ==
                last_resumed_experiment.original_experiment.pk)
        # NOTE(review): the next three assertions re-check the previous clone
        # (last_resumed_experiment), repeating the checks made above; possibly
        # last_resumed_experiment_new was intended -- confirm.
        assert last_resumed_experiment.config == config
        assert last_resumed_experiment.declarations != declarations
        assert last_resumed_experiment.declarations == new_declarations
        assert Experiment.objects.count() == count_experiment + 3
        assert experiment.clones.count() == 3

        # Deleting a resumed experiment does not delete other experiments
        last_resumed_experiment_new.set_status(ExperimentLifeCycle.SCHEDULED)
        ExperimentJobFactory(experiment=last_resumed_experiment_new)
        with patch('scheduler.experiment_scheduler.stop_experiment') as mock_stop:
            last_resumed_experiment_new.delete()
        assert experiment.clones.count() == 2
        assert mock_stop.call_count == 1

        # Deleting original experiment deletes all
        with patch('scheduler.experiment_scheduler.stop_experiment') as mock_stop:
            experiment.delete()
        assert Experiment.objects.count() == 0
        assert mock_stop.call_count == 0  # No running experiment
Exemple #17
0
    def test_hyperband_rescheduling(self, create_build_job):
        """Exercise the hyperband flow: creation, reschedule (create) and reduce.

        ``create_build_job`` is a mock supplied from outside this method
        (presumably by a patch decorator that is not visible here); it is set
        up to return an already succeeded build job. The expected call counts
        depend on the exact order of the steps below.
        """
        build = BuildJobFactory()
        BuildJobStatus.objects.create(status=JobLifeCycle.SUCCEEDED, job=build)
        create_build_job.return_value = build, True, True

        # Creating a hyperband group with the start task mocked: two dispatches expected.
        with patch('hpsearch.tasks.hyperband.hp_hyperband_start.apply_async'
                   ) as mock_fct:
            ExperimentGroupFactory(
                content=experiment_group_spec_content_hyperband)

        assert mock_fct.call_count == 2

        # Create a group with GroupChecks.is_checked forced to False and the
        # iterate/build tasks mocked instead of the start task.
        with patch.object(GroupChecks, 'is_checked') as mock_is_check:
            with patch(
                    'hpsearch.tasks.hyperband.hp_hyperband_iterate.apply_async'
            ) as mock_fct1:
                with patch('scheduler.tasks.experiments.'
                           'experiments_build.apply_async') as mock_fct2:
                    mock_is_check.return_value = False
                    experiment_group = ExperimentGroupFactory(
                        content=
                        experiment_group_spec_content_hyperband_trigger_reschedule
                    )

        assert experiment_group.iteration_config.num_suggestions == 9
        assert mock_fct1.call_count == 2
        # 9 experiments, but since we are mocking the scheduling function, it's ~ 3 x calls,
        # every call to start tries to schedule again, but in reality it's just 9 calls
        assert mock_fct2.call_count >= 9 * 2

        # Fake reschedule
        with patch('hpsearch.tasks.hyperband.hp_hyperband_start.apply_async'
                   ) as mock_fct:
            experiment_group = ExperimentGroupFactory(
                content=
                experiment_group_spec_content_hyperband_trigger_reschedule)
        # Expected dispatches: one per chunk of GROUP_CHUNKS experiments, plus one.
        self.assertEqual(
            mock_fct.call_count,
            math.ceil(experiment_group.experiments.count() /
                      conf.get('GROUP_CHUNKS')) + 1)
        ExperimentGroupIteration.objects.create(
            experiment_group=experiment_group,
            data={
                'iteration': 0,
                'bracket_iteration': 21,
                'num_suggestions': 9
            })

        # Attach all of the group's experiments to the iteration just created.
        experiment_group.iteration.experiments.set(
            experiment_group.experiments.values_list('id', flat=True))

        # Mark experiments as done
        with patch(
                'scheduler.experiment_scheduler.stop_experiment') as _:  # noqa
            with patch('hpsearch.tasks.hyperband.'
                       'hp_hyperband_start.apply_async') as xp_trigger_start:
                for xp in experiment_group.experiments.all():
                    ExperimentStatusFactory(
                        experiment=xp, status=ExperimentLifeCycle.SUCCEEDED)

        # One start dispatch is expected per experiment marked as succeeded.
        assert xp_trigger_start.call_count == experiment_group.experiments.count(
        )
        # With the iteration data above, start is expected to dispatch one create.
        with patch('hpsearch.tasks.hyperband.hp_hyperband_create.apply_async'
                   ) as mock_fct1:
            hp_hyperband_start(experiment_group.id)

        assert mock_fct1.call_count == 1

        # Fake reduce: a fresh group, this time without a manually created iteration.
        with patch('hpsearch.tasks.hyperband.hp_hyperband_start.apply_async'
                   ) as mock_fct:
            experiment_group = ExperimentGroupFactory(
                content=
                experiment_group_spec_content_hyperband_trigger_reschedule)
        self.assertEqual(
            mock_fct.call_count,
            math.ceil(experiment_group.experiments.count() /
                      conf.get('GROUP_CHUNKS')) + 1)
        assert experiment_group.non_done_experiments.count() == 9

        # Mark experiment as done
        with patch(
                'scheduler.experiment_scheduler.stop_experiment') as _:  # noqa
            with patch('hpsearch.tasks.hyperband.'
                       'hp_hyperband_start.apply_async') as xp_trigger_start:
                for xp in experiment_group.experiments.all():
                    ExperimentStatusFactory(
                        experiment=xp, status=ExperimentLifeCycle.SUCCEEDED)

        assert xp_trigger_start.call_count == experiment_group.experiments.count(
        )
        # Start is expected to call reduce_configs once and re-dispatch itself once.
        with patch('hpsearch.tasks.hyperband.hp_hyperband_start.apply_async'
                   ) as mock_fct2:
            with patch.object(HyperbandIterationManager,
                              'reduce_configs') as mock_fct3:
                hp_hyperband_start(experiment_group.id)
        assert mock_fct2.call_count == 1
        assert mock_fct3.call_count == 1
Exemple #18
0
    def test_hyperband_rescheduling(self):
        """Walk a hyperband group through creation, a reschedule and a reduce."""
        hb_start = 'hpsearch.tasks.hyperband.hp_hyperband_start.apply_async'
        hb_iterate = 'hpsearch.tasks.hyperband.hp_hyperband_iterate.apply_async'
        hb_create = 'hpsearch.tasks.hyperband.hp_hyperband_create.apply_async'
        build_task = 'scheduler.tasks.experiments.experiments_build.apply_async'
        stop_fn = 'scheduler.experiment_scheduler.stop_experiment'

        def mark_all_succeeded(group):
            # Record a SUCCEEDED status on each experiment with stop mocked.
            with patch(stop_fn):
                for experiment in group.experiments.all():
                    ExperimentStatusFactory(experiment=experiment,
                                            status=ExperimentLifeCycle.SUCCEEDED)

        # Group creation dispatches the start task once.
        with patch(hb_start) as start_mock:
            ExperimentGroupFactory(
                content=experiment_group_spec_content_hyperband)
        assert start_mock.call_count == 1

        # With iterate and build mocked instead, creation reaches both.
        with patch(hb_iterate) as iterate_mock, patch(build_task) as build_mock:
            ExperimentGroupFactory(
                content=experiment_group_spec_content_hyperband_trigger_reschedule)
        assert iterate_mock.call_count == 1
        assert build_mock.call_count == 9

        # Fake reschedule
        with patch(hb_start) as start_mock:
            reschedule_group = ExperimentGroupFactory(
                content=experiment_group_spec_content_hyperband_trigger_reschedule)
        assert start_mock.call_count == 1
        iteration_data = {'iteration': 0, 'bracket_iteration': 21}
        ExperimentGroupIteration.objects.create(
            experiment_group=reschedule_group, data=iteration_data)

        mark_all_succeeded(reschedule_group)
        with patch(hb_create) as create_mock:
            hp_hyperband_start(reschedule_group.id)
        assert create_mock.call_count == 1

        # Fake reduce
        with patch(hb_start) as start_mock:
            reduce_group = ExperimentGroupFactory(
                content=experiment_group_spec_content_hyperband_trigger_reschedule)
        assert start_mock.call_count == 1
        assert reduce_group.non_done_experiments.count() == 9

        mark_all_succeeded(reduce_group)
        with patch(hb_start) as restart_mock, \
                patch.object(HyperbandIterationManager, 'reduce_configs') as reduce_mock:
            hp_hyperband_start(reduce_group.id)
        assert restart_mock.call_count == 1
        assert reduce_mock.call_count == 1