예제 #1
0
def test_trial_attachments():
    """Run fmin on a background thread while one worker serves its jobs.

    Inside a ``TempMongo`` context, a thread running ``fmin_thread_fn`` is
    started with a uniform search space and a budget of three evaluations.
    A single ``MongoWorker`` then makes that many ``run_one`` attempts;
    a failed attempt is printed with its traceback and still counts as an
    attempt.  After the thread is joined, the number of stored trials and
    the synced/unsynced counts of ``JOB_STATE_DONE`` trials must all equal
    the evaluation budget.
    """
    experiment_key = "A"
    with TempMongo() as mongo:
        jobs = mongo.mongo_jobs("foo")
        trials = MongoTrials(mongo.connection_string("foo"), exp_key=experiment_key)

        max_evals = 3
        search_space = hp.uniform("x", -10, 10)
        optimizer = threading.Thread(
            target=fmin_thread_fn,
            args=(search_space, trials, max_evals),
        )
        optimizer.start()

        worker = MongoWorker(mj=jobs, logfilename=None, workdir="mongoexp_test_dir")
        # One attempt per expected evaluation, whether or not it succeeds.
        for _ in range(max_evals):
            try:
                worker.run_one("hostname", 10.0, erase_created_workdir=True)
                print("worker: ran job")
            except Exception as exc:
                print(f"worker: encountered error : {exc!s}")
                traceback.print_exc()
        optimizer.join()

        # A fresh view without an exp_key, used to count everything stored.
        everything = MongoTrials(mongo.connection_string("foo"))

        assert len(everything) == max_evals
        assert trials.count_by_state_synced(JOB_STATE_DONE) == max_evals
        assert trials.count_by_state_unsynced(JOB_STATE_DONE) == max_evals
예제 #2
0
    def work(self):
        """
        Run a small experiment with several workers running in parallel
        using Python threads.

        ``self.n_threads`` threads are started, each running
        ``self.worker_thread_fn`` with a job budget of
        ``self.jobs_per_thread * len(self.exp_keys)``.  Then, for every key
        in ``self.exp_keys``, an ``Experiment`` is launched on a
        ``MongoTrials`` object inside a ``TempMongo`` context, using
        ``RandomStop`` when ``self.use_stop`` is true and ``Random``
        otherwise.  If the instance defines ``prep_trials`` it is called on
        each new trials object first.  After the worker threads are joined
        and every experiment has finished, the trial counts are verified.

        Raises:
            AssertionError: if a trials object is not empty at the start,
                or if any experiment (or the database as a whole) does not
                end up with the expected number of finished trials.
        """
        n_threads = self.n_threads
        jobs_per_thread = self.jobs_per_thread
        n_trials_per_exp = n_threads * jobs_per_thread
        n_trials_total = n_trials_per_exp * len(self.exp_keys)

        with TempMongo() as tm:
            # NOTE(review): `mj` is not used below; the call is kept in case
            # it has side effects this test relies on -- confirm.
            mj = tm.mongo_jobs('foodb')

            def newth(ii):
                # Each worker must be able to serve its share of every
                # experiment, hence the multiplication by len(exp_keys).
                n_jobs = jobs_per_thread * len(self.exp_keys)
                return threading.Thread(
                    target=self.worker_thread_fn,
                    args=(('hostname', ii), n_jobs, 30.0))

            # Build a real list: on Python 3 `map` returns a one-shot
            # iterator that start-up would exhaust, leaving nothing for the
            # join in the `finally` clause below.
            threads = [newth(ii) for ii in range(n_threads)]
            for th in threads:
                th.start()

            exp_list = []
            trials_list = []
            try:
                for key in self.exp_keys:
                    print('running experiment')
                    trials = MongoTrials(tm.connection_string('foodb'), key)
                    assert len(trials) == 0
                    if hasattr(self, 'prep_trials'):
                        self.prep_trials(trials)
                    if self.use_stop:
                        bandit_algo = RandomStop(n_trials_per_exp,
                                                 self.bandit, cmd=self.cmd)
                        print(bandit_algo)
                        exp = Experiment(trials, bandit_algo, max_queue_len=1)
                        # sys.maxsize exists on Python 2.6+ and 3, unlike
                        # sys.maxint; it is an effectively unbounded budget
                        # here.
                        exp.run(sys.maxsize, block_until_done=False)
                    else:
                        bandit_algo = Random(self.bandit, cmd=self.cmd)
                        exp = Experiment(trials, bandit_algo,
                                         max_queue_len=10000)
                        # Only block here when there is a single experiment;
                        # otherwise the remaining ones still need queueing.
                        exp.run(n_trials_per_exp,
                                block_until_done=(len(self.exp_keys) == 1))
                    exp_list.append(exp)
                    trials_list.append(trials)
            finally:
                # Join even when queueing failed so no worker thread is
                # left running behind the test.
                print('joining worker thread...')
                for th in threads:
                    th.join()

            for exp in exp_list:
                exp.block_until_done()

            for trials in trials_list:
                n_done_synced = trials.count_by_state_synced(JOB_STATE_DONE)
                assert n_done_synced == n_trials_per_exp, (
                    n_done_synced, n_trials_per_exp)
                assert trials.count_by_state_unsynced(JOB_STATE_DONE)\
                        == n_trials_per_exp
                assert len(trials) == n_trials_per_exp, (
                    'trials failure %d %d ' % (len(trials), n_trials_per_exp))
                assert len(trials.results) == n_trials_per_exp, (
                    'results failure %d %d ' % (len(trials.results),
                                                n_trials_per_exp))
            # A view without an exp_key, used to count trials across all
            # experiments.
            all_trials = MongoTrials(tm.connection_string('foodb'))
            assert len(all_trials) == n_trials_total