Esempio n. 1
0
    def testMultiStepRun(self):
        # Two "__fake" trials, each stopping at 5 training iterations and each
        # requesting 1 CPU + 1 GPU, on a cluster started with 4 CPUs / 2 GPUs.
        ray.init(num_cpus=4, num_gpus=2)
        trial_kwargs = dict(
            stopping_criterion=dict(training_iteration=5),
            resources=Resources(cpu=1, gpu=1),
        )
        # Both trials are built from the very same kwargs mapping.
        trials = [Trial("__fake", **trial_kwargs) for _ in range(2)]
        status_snapshot = TrialStatusSnapshot()
        snapshot_taker = TrialStatusSnapshotTaker(status_snapshot)
        runner = TrialRunner(callbacks=[snapshot_taker])
        for trial in trials:
            runner.add_trial(trial)

        # Keep stepping until the runner reports that it is finished.
        while True:
            if runner.is_finished():
                break
            runner.step()

        self.assertTrue(status_snapshot.all_trials_are_terminated())
Esempio n. 2
0
    def testExtraCustomResources(self):
        # The cluster exposes 2 units of custom resource "a" and every trial
        # asks for both of them, so no two trials should ever run at once.
        ray.init(num_cpus=4, num_gpus=2, resources={"a": 2})
        status_snapshot = TrialStatusSnapshot()
        snapshot_taker = TrialStatusSnapshotTaker(status_snapshot)
        runner = TrialRunner(callbacks=[snapshot_taker])
        bundles = [{"CPU": 1}, {"a": 2}]
        trial_kwargs = dict(
            stopping_criterion=dict(training_iteration=1),
            placement_group_factory=PlacementGroupFactory(bundles),
        )
        # Both trials share the same kwargs (and thus the same factory object).
        trials = [Trial("__fake", **trial_kwargs) for _ in range(2)]
        for trial in trials:
            runner.add_trial(trial)

        # Keep stepping until the runner reports that it is finished.
        while True:
            if runner.is_finished():
                break
            runner.step()

        peak_running = status_snapshot.max_running_trials()
        self.assertLess(peak_running, 2)
        self.assertTrue(status_snapshot.all_trials_are_terminated())
Esempio n. 3
0
    def testExtraResources(self):
        # Each trial requests two bundles totalling 4 CPUs (1 + 3) and 1 GPU,
        # which equals the 4 CPUs the cluster is started with; the assertions
        # below therefore expect fewer than 2 trials running at any time.
        ray.init(num_cpus=4, num_gpus=2)
        status_snapshot = TrialStatusSnapshot()
        snapshot_taker = TrialStatusSnapshotTaker(status_snapshot)
        runner = TrialRunner(callbacks=[snapshot_taker])
        bundles = [{"CPU": 1}, {"CPU": 3, "GPU": 1}]
        trial_kwargs = dict(
            stopping_criterion=dict(training_iteration=1),
            placement_group_factory=PlacementGroupFactory(bundles),
        )
        # Both trials share the same kwargs (and thus the same factory object).
        trials = [Trial("__fake", **trial_kwargs) for _ in range(2)]
        for trial in trials:
            runner.add_trial(trial)

        # Keep stepping until the runner reports that it is finished.
        while True:
            if runner.is_finished():
                break
            runner.step()

        peak_running = status_snapshot.max_running_trials()
        self.assertLess(peak_running, 2)
        self.assertTrue(status_snapshot.all_trials_are_terminated())