Пример #1
0
    def testExperimentTagTruncation(self):
        """Trial logdir basenames must not exceed 200 characters.

        The config uses 50-character keys and a 160-character string value,
        which presumably feed into the generated experiment tag. Each trial
        is started, its logdir basename length is checked, and then the
        trial is stopped.
        """
        ray.init(num_cpus=2)

        def train(config, reporter):
            reporter(timesteps_total=1)

        executor = RayTrialExecutor()
        register_trainable("f1", train)

        oversized_config = {
            "a" * 50: tune.sample_from(lambda spec: 5.0 / 7),
            "b" * 50: tune.sample_from(lambda spec: "long" * 40),
        }
        experiments = {"foo": {"run": "f1", "config": oversized_config}}

        def iter_trials(exp_name, exp_spec):
            # Lazily drain the generator: it stops when it reports it is
            # finished or hands back a falsy trial.
            generator = BasicVariantGenerator()
            generator.add_configurations({exp_name: exp_spec})
            while not generator.is_finished():
                candidate = generator.next_trial()
                if not candidate:
                    return
                yield candidate

        for exp_name, exp_spec in experiments.items():
            for trial in iter_trials(exp_name, exp_spec):
                executor.start_trial(trial)
                tag_dir = os.path.basename(trial.logdir)
                self.assertLessEqual(len(tag_dir), 200)
                executor.stop_trial(trial)
Пример #2
0
    def testExperimentTagTruncation(self):
        """Trial logdir basenames stay <= 200 chars with an oversized config.

        Also checks the trial lifecycle: RUNNING after ``start_trial`` and
        TERMINATED after ``stop_trial``.
        """
        ray.init(num_cpus=2)
        trainable_cls = AdaptDLTrainableCreator(_train_simple, num_workers=1)
        executor = RayTrialExecutor()
        oversized_config = {
            "a" * 50: tune.sample_from(lambda spec: 5.0 / 7),
            "b" * 50: tune.sample_from(lambda spec: "long" * 40),
        }
        experiments = {
            "foo": {"run": trainable_cls.__name__, "config": oversized_config},
        }

        def iter_trials(exp_name, exp_spec):
            # Lazily drain the generator: it stops when it reports it is
            # finished or hands back a falsy trial.
            generator = BasicVariantGenerator()
            generator.add_configurations({exp_name: exp_spec})
            while not generator.is_finished():
                candidate = generator.next_trial()
                if not candidate:
                    return
                yield candidate

        for exp_name, exp_spec in experiments.items():
            for trial in iter_trials(exp_name, exp_spec):
                executor.start_trial(trial)
                assert trial.status == Trial.RUNNING
                assert len(os.path.basename(trial.logdir)) <= 200
                executor.stop_trial(trial)
                assert trial.status == Trial.TERMINATED
Пример #3
0
 def generate_trials(spec, name):
     """Return the trials a BasicVariantGenerator yields for one experiment.

     Args:
         spec: The experiment specification registered under ``name``.
         name: The experiment name used as the configuration key.

     Returns:
         A list of the trials produced, in order. Collection stops when the
         generator reports it is finished or returns a falsy trial.
     """
     generator = BasicVariantGenerator()
     generator.add_configurations({name: spec})
     collected = []
     while not generator.is_finished():
         trial = generator.next_trial()
         if not trial:
             break
         collected.append(trial)
     return collected
Пример #4
0
 def _add_trials(self, name, spec):
     """Add trial by invoking TrialRunner.

     Each trial generated for ``spec`` is handed to ``runner.add_trial`` and
     summarized with ``self._trial_info``. ``runner`` is not a parameter; it
     is resolved from an enclosing scope not visible here (presumably the
     TrialRunner named above — confirm).

     Args:
         name: Experiment name used as the configuration key.
         spec: Experiment specification to expand into trials.

     Returns:
         A dict ``{"trials": [...]}`` holding one ``_trial_info`` entry per
         trial that was added.
     """
     generator = BasicVariantGenerator()
     generator.add_configurations({name: spec})
     trial_infos = []
     while not generator.is_finished():
         trial = generator.next_trial()
         if not trial:
             break
         runner.add_trial(trial)
         trial_infos.append(self._trial_info(trial))
     return {"trials": trial_infos}
Пример #5
0
    def testBasicVariantLimiter(self):
        """With ``max_concurrent=2``, at most two trials are outstanding.

        ``next_trial`` returns a falsy value while both slots are busy, and
        also once all five samples have been handed out.
        """
        generator = BasicVariantGenerator(max_concurrent=2)
        generator.add_configurations({
            "test": {
                "run": "__fake",
                "num_samples": 5,
                "stop": {"training_iteration": 1},
            }
        })

        def complete(trial):
            generator.on_trial_complete(trial.trial_id, None, False)

        # Two free slots: the first two trials are handed out.
        first = generator.next_trial()
        self.assertTrue(first)
        second = generator.next_trial()
        self.assertTrue(second)

        # Both slots busy, so the limiter withholds a third trial.
        self.assertFalse(generator.next_trial())

        # Freeing one slot lets exactly one more trial through.
        complete(first)
        third = generator.next_trial()
        self.assertTrue(third)
        self.assertFalse(generator.next_trial())

        # Freeing both slots admits the remaining two samples.
        complete(second)
        complete(third)
        fourth = generator.next_trial()
        self.assertTrue(fourth)
        fifth = generator.next_trial()
        self.assertTrue(fifth)

        complete(fourth)
        # Should also be None because search is finished.
        self.assertFalse(generator.next_trial())