def assert_models_same(self, model, model2):
    """Assert that two keras models are structurally equivalent.

    Checks the optimizer configuration (when ``model`` has an optimizer),
    the top-level layer count, and the recursively flattened layer count.
    """
    if model.optimizer:
        opt_cfg = model.optimizer.get_config()
        opt_cfg2 = model2.optimizer.get_config()
        assert opt_cfg == opt_cfg2
    # shallow comparison: top-level layers only
    assert len(model.layers) == len(model2.layers)
    # deep comparison: recursively flattened layers
    flattened = list(iterlayers(model))
    flattened2 = list(iterlayers(model2))
    assert len(flattened) == len(flattened2)
def test__freeze_sets_trainable_except(self, multinetwork):
    """Freezing all models except one leaves only that model trainable."""
    kept = 'forecaster'
    multinetwork._freeze_models_except(kept)
    for name in multinetwork.model_names:
        # the exempted model stays trainable; every other model is frozen
        expected = name == kept
        for layer in iterlayers(multinetwork.model[name]):
            assert layer.trainable is expected
def test_multimodel(self, multimodel, multimodel_num_layers):
    """Each model in the multimodel flattens to its expected layer count."""
    for name, net in multimodel.items():
        actual = sum(1 for _ in iterlayers(net))
        assert actual == multimodel_num_layers[name]
def test_model(self, model, n_layers):
    """A single model flattens to the expected number of layers."""
    assert sum(1 for _ in iterlayers(model)) == n_layers
def test__freeze_sets_trainable_all(self, multinetwork):
    """Exempting every model from freezing leaves all layers trainable."""
    names = multinetwork.model_names
    multinetwork._freeze_models_except(names)
    for name in names:
        for layer in iterlayers(multinetwork.model[name]):
            assert layer.trainable is True
def test__freeze_sets_trainable_none(self, multinetwork, model):
    """After ``_freeze_models_except(model)``, every network layer is frozen.

    NOTE(review): presumably the ``model`` fixture is not among
    ``multinetwork.model_names``, so no model is exempted — confirm
    against the fixture definitions.
    """
    multinetwork._freeze_models_except(model)
    for name in multinetwork.model_names:
        for layer in iterlayers(multinetwork.model[name]):
            assert layer.trainable is False