Esempio n. 1
0
    def load_state_dict(self, state):
        """Restore this RoBO optimizer from a (possibly compressed) state dict.

        After delegating the shared state to the parent class, this restores
        the iteration counter, the global NumPy RNG, the RNGs of the
        acquisition maximizer, the surrogate model and its prior, and the
        model parameters (MCMC walkers for ``gp_mcmc``, kernel parameter
        vector and noise otherwise).

        Args:
            state: State dictionary; it is passed through ``decompress_dict``
                before use (presumably the output of the matching
                ``state_dict`` -- TODO confirm).

        Returns:
            ``self``, consistent with the other HPO ``load_state_dict``
            implementations.
        """
        state = decompress_dict(state)

        super(RoBO, self).load_state_dict(state)

        self.count = state['count']

        numpy.random.set_state(
            decode_rng_state(state['global_oh_my_god_numpy_rng_state']))
        self.robo.maximize_func.rng.set_state(
            decode_rng_state(state['maximizer_rng_state']))

        model = self.robo.model
        model.rng.set_state(decode_rng_state(state['model_rng_state']))
        model.prior.rng.set_state(decode_rng_state(state['prior_rng_state']))

        if self.model_type == 'gp_mcmc':
            p0 = state.get('model_p0')
            if p0 is not None:
                # Saved walker positions exist, so mark the model as burned
                # (presumably: burn-in already done for these walkers).
                model.p0 = numpy.array(p0)
                model.burned = True
            else:
                # No saved walkers: drop any stale positions and always clear
                # `burned`; leaving a stale `burned=True` without a `p0` would
                # leave the model flagged as burned with no walkers to resume.
                if hasattr(model, 'p0'):
                    delattr(model, 'p0')
                model.burned = False
        else:
            model.kernel.set_parameter_vector(
                state['model_kernel_parameter_vector'])
            model.noise = state['noise']

        return self
Esempio n. 2
0
    def load_state_dict(self, state):
        """Restore Hyperband's bracket bookkeeping from ``state``.

        The shared state is handed to the parent class first; afterwards the
        offset is restored and every serialized bracket is rebuilt against
        ``self.trials`` as it stands after the parent call.

        Returns:
            ``self``.
        """
        decompressed = decompress_dict(state)
        super(Hyperband, self).load_state_dict(decompressed)

        self.offset = decompressed['offset']

        restored_brackets = []
        for bracket_state in decompressed['brackets']:
            restored_brackets.append(
                _Bracket().load_state_dict(bracket_state, self.trials))
        self.brackets = restored_brackets

        return self
Esempio n. 3
0
    def load_state_dict(self, state):
        """Restore the common optimizer attributes from ``state``.

        Rebuilds the search space, fidelity and trial objects from their
        serialized forms and copies the plain seed/manual-sampling fields.
        ``state['trials']`` is iterated as ``(key, trial_dict)`` pairs.

        Returns:
            ``self``.
        """
        restored = decompress_dict(state)

        self.space = Space.from_dict(restored['space'])

        # Plain values copied verbatim, in the same order as before.
        for attr_name in ('seed', 'manual_samples', 'manual_fidelity',
                          'manual_insert', 'seed_time'):
            setattr(self, attr_name, restored[attr_name])

        self.fidelity = Fidelity.from_dict(restored['fidelity'])

        trials = OrderedDict()
        for trial_key, trial_state in restored['trials']:
            trials[trial_key] = Trial.from_dict(trial_state)
        self.trials = trials

        return self
Esempio n. 4
0
def test_register_hpo_is_actionable(client):
    """Registered HPO work items are valid and can be executed by a worker.

    Registers a random-search HPO over four uniform dimensions, runs a
    single ``TrialWorker`` to completion, then checks the queue counts and
    the objective recorded in the HPO state published on the result queue.
    """
    namespace = 'test-hpo'
    hpo_config = dict(
        name='random_search',
        seed=1,
        count=1,
        fidelity=Fidelity(1, 10, name='d').to_dict(),
        space={dim: 'uniform(-1, 1)' for dim in 'abcd'},
    )

    register_hpo(client, namespace, foo, hpo_config, {})

    worker = TrialWorker(URI, DATABASE, 0, namespace)
    worker.max_retry = 0
    worker.run()

    # One executed work item, two HPO items on the work queue.
    expected_counts = ((WORK_ITEM, 1), (HPO_ITEM, 2))
    for message_type, expected in expected_counts:
        observed = client.monitor().read_count(
            WORK_QUEUE, namespace, mtype=message_type)
        assert observed == expected

    unread = client.monitor().unread_messages(
        RESULT_QUEUE, namespace, mtype=HPO_ITEM)

    compressed_state = unread[0].message.get('hpo_state')
    assert compressed_state is not None

    hpo_state = decompress_dict(compressed_state)
    trials = hpo_state['trials']
    assert len(trials) == 1

    first_trial = trials[0][1]
    assert first_trial['objectives'] == [10.715799430116764]
Esempio n. 5
0
    def from_dict(state):
        """Build a ``Hyperband`` instance from a serialized state.

        The constructor receives the fidelity, space and seed stored in the
        state; everything else is then restored through ``load_state_dict``.
        """
        unpacked = decompress_dict(state)
        instance = Hyperband(
            unpacked['fidelity'], unpacked['space'], unpacked['seed'])
        instance.load_state_dict(unpacked)
        return instance
Esempio n. 6
0
    def load_state_dict(self, state):
        """Restore GridSearch state from a (possibly compressed) state dict.

        Decompresses ``state``, delegates the shared state to the parent
        class, then restores ``self.count`` from ``state['count']``.

        NOTE(review): only the lines shown are documented here; sibling
        ``load_state_dict`` implementations return ``self`` -- confirm
        whether this one does too.
        """
        state = decompress_dict(state)

        super(GridSearch, self).load_state_dict(state)
        self.count = state['count']