def test_greedy_oracle_state_hypermodel_is_graph():
    """A get_state/set_state round trip leaves ``hypermodel`` a ``Graph``."""
    graph = utils.build_graph()
    oracle = greedy.GreedyOracle(hypermodel=graph, objective='val_loss')

    # Feed the oracle its own serialized state back in.
    state = oracle.get_state()
    oracle.set_state(state)

    assert isinstance(oracle.hypermodel, graph_module.Graph)
def test_greedy_oracle_populate_doesnt_crash_with_init_hps(get_best_trials):
    """Run ten populate/build/update rounds on an oracle given initial hps.

    The test passes as long as none of the calls raises; it makes no
    assertion on the proposed values.

    Args:
        get_best_trials: Object whose ``return_value`` is set to a single
            fake trial (presumably a patched ``get_best_trials`` supplied
            by a decorator outside this view -- confirm).
    """
    hp = keras_tuner.HyperParameters()
    tf.keras.backend.clear_session()

    # Assemble the image-classification graph by hand, setting the
    # attributes on the input node and head explicitly.
    image_input = ak.ImageInput(shape=(32, 32, 3))
    image_input.batch_size = 32
    image_input.num_samples = 1000
    features = ak.ImageBlock()(image_input)
    head = ak.ClassificationHead(num_classes=10)
    head.shape = (10, )
    prediction = head(features)
    graph = ak.graph.Graph(inputs=image_input, outputs=prediction)
    graph.build(hp)

    oracle = greedy.GreedyOracle(
        initial_hps=task_specific.IMAGE_CLASSIFIER,
        objective="val_loss",
        seed=utils.SEED,
    )

    # The fake best trial keeps pointing at the initially built search space.
    best_trial = mock.Mock(hyperparameters=hp)
    get_best_trials.return_value = [best_trial]

    for _ in range(10):
        tf.keras.backend.clear_session()
        proposed = oracle.populate_space("a")["values"]
        candidate = oracle.hyperparameters.copy()
        candidate.values = proposed
        # Rebuild with the proposed values, then report the space back.
        graph.build(candidate)
        oracle.update_space(candidate)
def test_greedy_oracle_get_state_update_space_can_run():
    """``update_space`` runs without raising after a state round trip."""
    graph = utils.build_graph()
    oracle = greedy.GreedyOracle(hypermodel=graph, objective='val_loss')

    # Restore the oracle from its own state before touching the space.
    state = oracle.get_state()
    oracle.set_state(state)

    space = kerastuner.HyperParameters()
    space.Boolean('test')
    oracle.update_space(space)
def test_random_oracle_state():
    """Restoring the oracle's own state keeps the same graph object.

    Despite the "random" in the name, the oracle under test is
    ``greedy.GreedyOracle``.
    """
    hypermodel = utils.build_graph()
    oracle = greedy.GreedyOracle(hypermodel=hypermodel, objective='val_loss')
    # Explicitly (re)assign the hypermodel, as the original test does.
    oracle.hypermodel = hypermodel

    snapshot = oracle.get_state()
    oracle.set_state(snapshot)

    # Identity, not equality: the attribute must still be this exact object.
    assert oracle.hypermodel is hypermodel
def test_overwrite_search(fit_fn, base_tuner_search, tmp_path):
    """``AutoTuner.search`` leaves the tuner flagged as finished.

    Args:
        fit_fn: Unused in the body; presumably a patch injected by a
            decorator outside this view -- confirm.
        base_tuner_search: Unused in the body; presumably a patch injected
            by a decorator outside this view -- confirm.
        tmp_path: Directory handed to the tuner as its ``directory``.
    """
    hypermodel = utils.build_graph()
    oracle = greedy.GreedyOracle(hypermodel, objective='val_loss')
    tuner = tuner_module.AutoTuner(
        oracle=oracle,
        hypermodel=hypermodel,
        directory=tmp_path,
    )

    tuner.search(epochs=10)

    assert tuner._finished
def test_greedy_oracle_populate_space_with_no_hp(get_best_trials):
    """An empty search space yields an empty ``"values"`` proposal.

    Args:
        get_best_trials: Object whose ``return_value`` is set to a single
            fake trial (presumably a patch from a decorator outside this
            view -- confirm).
    """
    empty_space = keras_tuner.HyperParameters()
    oracle = greedy.GreedyOracle(objective="val_loss", seed=utils.SEED)

    # The fake best trial carries the same, empty, hyperparameter container.
    best_trial = mock.Mock(hyperparameters=empty_space)
    get_best_trials.return_value = [best_trial]

    oracle.update_space(empty_space)
    proposed = oracle.populate_space("a")["values"]

    assert len(proposed) == 0
def test_add_early_stopping(fit_fn, base_tuner_search, tmp_path):
    """Searching with explicit epochs passes an ``EarlyStopping`` callback.

    Args:
        fit_fn: Unused in the body; presumably a patch injected by a
            decorator outside this view -- confirm.
        base_tuner_search: Call-recording object; the ``callbacks`` keyword
            of its first recorded call is inspected.
        tmp_path: Directory handed to the tuner as its ``directory``.
    """
    graph = utils.build_graph()
    tuner = tuner_module.AutoTuner(
        oracle=greedy.GreedyOracle(graph, objective='val_loss'),
        hypermodel=graph,
        directory=tmp_path)

    tuner.search(x=None, epochs=10)

    # Index [0] is the first recorded call, [1] its keyword arguments.
    callbacks = base_tuner_search.call_args_list[0][1]['callbacks']
    # Generator expression: no throwaway list, stops at the first match.
    assert any(isinstance(callback, tf.keras.callbacks.EarlyStopping)
               for callback in callbacks)
def test_greedy_oracle_populate_different_values(get_best_trials):
    """Two consecutive proposals differ in at least one hyperparameter.

    Args:
        get_best_trials: Object whose ``return_value`` is set to a single
            fake trial (presumably a patch from a decorator outside this
            view -- confirm).
    """
    hp = keras_tuner.HyperParameters()
    utils.build_graph().build(hp)

    oracle = greedy.GreedyOracle(objective="val_loss", seed=utils.SEED)
    trial = mock.Mock()
    trial.hyperparameters = hp
    get_best_trials.return_value = [trial]

    oracle.update_space(hp)
    values_a = oracle.populate_space("a")["values"]
    values_b = oracle.populate_space("b")["values"]

    # Generator expression: no throwaway list, stops at the first mismatch.
    assert not all(values_a[key] == values_b[key] for key in values_a)
def test_no_epochs(best_epochs, fit_fn, base_tuner_search, tmp_path):
    """With ``epochs=None`` the first fit call gets no ``EarlyStopping``.

    Args:
        best_epochs: Object whose ``return_value`` is set to 2 (presumably
            a patch from a decorator outside this view -- confirm).
        fit_fn: Call-recording object; the ``callbacks`` keyword of its
            first recorded call is inspected.
        base_tuner_search: Unused in the body; presumably a patch injected
            by a decorator outside this view -- confirm.
        tmp_path: Directory handed to the tuner as its ``directory``.
    """
    best_epochs.return_value = 2
    graph = utils.build_graph()
    tuner = tuner_module.AutoTuner(
        oracle=greedy.GreedyOracle(graph, objective='val_loss'),
        hypermodel=graph,
        directory=tmp_path)

    tuner.search(x=mock.Mock(),
                 epochs=None,
                 fit_on_val_data=True,
                 validation_data=mock.Mock())

    # Index [0] is the first recorded call, [1] its keyword arguments.
    callbacks = fit_fn.call_args_list[0][1]['callbacks']
    # Generator expression: no throwaway list, stops at the first match.
    assert not any(isinstance(callback, tf.keras.callbacks.EarlyStopping)
                   for callback in callbacks)
def test_greedy_oracle_stop_reach_max_collision(get_best_trials,
                                                compute_values_hash):
    """With a constant values hash, the second proposal reports STOPPED.

    Args:
        get_best_trials: Object whose ``return_value`` is set to a single
            fake trial (presumably a patch from a decorator outside this
            view -- confirm).
        compute_values_hash: Object whose ``return_value`` is pinned to 1
            (presumably the patched hashing routine -- confirm).
    """
    space = keras_tuner.HyperParameters()
    utils.build_graph().build(space)

    oracle = greedy.GreedyOracle(objective="val_loss", seed=utils.SEED)
    best_trial = mock.Mock(hyperparameters=space)
    get_best_trials.return_value = [best_trial]
    compute_values_hash.return_value = 1

    oracle.update_space(space)

    # First proposal: the result is discarded, but "values" is still looked up.
    _ = oracle.populate_space("a")["values"]

    second_status = oracle.populate_space("b")["status"]
    assert second_status == keras_tuner.engine.trial.TrialStatus.STOPPED
def test_greedy_oracle(fn):
    """After many proposals, known hyperparameters sit in the expected groups.

    Args:
        fn: Object whose ``return_value`` is set to a single fake trial
            (presumably a patch from a decorator outside this view --
            confirm).
    """
    oracle = greedy.GreedyOracle(
        hypermodel=utils.build_graph(),
        objective='val_loss',
    )

    space = kerastuner.HyperParameters()
    best_trial = mock.Mock(hyperparameters=space)
    fn.return_value = [best_trial]

    oracle.update_space(space)
    # Request 2000 proposals, using the stringified index as the trial id.
    for trial_id in map(str, range(2000)):
        oracle._populate_space(trial_id)

    hp_names = oracle._hp_names
    assert 'optimizer' in hp_names[greedy.GreedyOracle.OPT]
    assert 'classification_head_1/dropout' in hp_names[
        greedy.GreedyOracle.ARCH]
    assert 'image_block_1/block_type' in hp_names[greedy.GreedyOracle.HYPER]
def test_image_classifier_oracle():
    """Initial image-classifier hps are all present in the oracle's space.

    After each of the first two ``_populate_space`` calls, every key of the
    matching ``task_specific.IMAGE_CLASSIFIER`` entry must appear among the
    keys of ``oracle.get_space().values``.
    """
    tf.keras.backend.clear_session()
    image_input = ak.ImageInput(shape=(32, 32, 3))
    features = ak.ImageBlock()(image_input)
    head = ak.ClassificationHead(loss='categorical_crossentropy',
                                 output_shape=(10, ))
    prediction = head(features)
    graph = graph_module.Graph(image_input, prediction)

    oracle = greedy.GreedyOracle(
        hypermodel=graph,
        initial_hps=task_specific.IMAGE_CLASSIFIER,
        objective='val_loss',
    )

    oracle._populate_space('0')
    # NOTE(review): whether assigning to the returned space reaches the
    # oracle depends on what get_space() returns -- confirm.
    space = oracle.get_space()
    space.values = task_specific.IMAGE_CLASSIFIER[0]
    # Subset check: no initial key may be missing from the space's values.
    assert set(task_specific.IMAGE_CLASSIFIER[0].keys()).issubset(
        set(oracle.get_space().values.keys()))

    oracle._populate_space('1')
    assert set(task_specific.IMAGE_CLASSIFIER[1].keys()).issubset(
        set(oracle.get_space().values.keys()))
def test_greedy_oracle_populate_doesnt_crash_with_init_hps(get_best_trials):
    """Run ten ``_populate_space``/build/update rounds with initial hps.

    The test passes as long as none of the calls raises; it makes no
    assertion on the proposed values.

    Args:
        get_best_trials: Object whose ``return_value`` is set to a single
            fake trial (presumably a patch from a decorator outside this
            view -- confirm).
    """
    hp = kerastuner.HyperParameters()
    graph = utils.build_graph()
    graph.build(hp)

    oracle = greedy.GreedyOracle(
        initial_hps=task_specific.IMAGE_CLASSIFIER,
        objective="val_loss",
        seed=utils.SEED,
    )

    # The fake best trial keeps pointing at the initially built search space.
    best_trial = mock.Mock(hyperparameters=hp)
    get_best_trials.return_value = [best_trial]

    for _ in range(10):
        tf.keras.backend.clear_session()
        proposed = oracle._populate_space("a")["values"]
        candidate = oracle.hyperparameters.copy()
        candidate.values = proposed
        # Rebuild with the proposed values, then report the space back.
        graph.build(candidate)
        oracle.update_space(candidate)
def test_greedy_oracle_get_state_update_space_can_run():
    """``update_space`` runs without raising after a state round trip."""
    oracle = greedy.GreedyOracle(objective="val_loss")

    # Restore the oracle from its own state before touching the space.
    state = oracle.get_state()
    oracle.set_state(state)

    space = keras_tuner.HyperParameters()
    space.Boolean("test")
    oracle.update_space(space)