def test_hash_model_nochange(self):
    """Two identically built uniform models must hash to a single entry.

    Both models come from ``uniform_model(10, ActivationLayer)``. After each
    is passed through ``hash_model`` with shared registries, the model list
    and the genome list must each hold exactly one item.
    """
    twins = [uniform_model(10, ActivationLayer) for _ in range(2)]
    seen_models = []
    seen_genomes = []
    for twin in twins:
        # hash_model records into the two lists passed in.
        EEGNAS.utilities.NAS_utils.hash_model(twin, seen_models, seen_genomes)
    assert len(seen_models) == 1
    assert len(seen_genomes) == 1
def test_breed(self):
    """Crossover at ``cut_point=4`` splices two 10-layer parents.

    The child's layers at indices 0-3 must have the same class as the
    BatchNorm parent. The layers at indices 4-9 must match the Dropout
    parent. The child must then finalize without raising.
    """
    first_parent = uniform_model(10, BatchNormLayer)
    second_parent = uniform_model(10, DropoutLayer)
    child, _, _ = breed_layers(0, first_parent, second_parent, cut_point=4)
    for idx in range(10):
        donor = first_parent if idx < 4 else second_parent
        # Compare class names of the layer objects at the same position.
        assert type(child[idx]).__name__ == type(donor[idx]).__name__
    finalize_model(child)
def test_hash_model(self):
    """Hash two identical parents and their bred child; expect 1 or 2 entries.

    ``model1`` and ``model2`` are built identically, so they should collapse
    into one registry entry (see ``test_hash_model_nochange``). The child is
    bred with ``mutation_rate`` set to 1, so it may or may not differ from
    its parents. Either way, the model and genome registries must hold one
    or two entries.
    """
    model1 = uniform_model(10, ActivationLayer)
    model2 = uniform_model(10, ActivationLayer)
    global_vars.set('mutation_rate', 1)
    # Every other call to breed_layers in this file unpacks a 3-tuple.
    # Unpacking into two names would raise ValueError.
    model3, _, _ = breed_layers(1, model1, model2)
    model_set = []
    genome_set = []
    for model in (model1, model2, model3):
        EEGNAS.utilities.NAS_utils.hash_model(model, model_set, genome_set)
    # `x == 1 or 2` parses as `(x == 1) or 2`, which is always truthy.
    # Use an explicit membership test so the assertion can actually fail.
    assert len(model_set) in (1, 2)
    assert len(genome_set) in (1, 2)
def test_state_inheritance_breeding(self):
    """With ``inherit_breeding_weights`` enabled, the child reuses parent weights.

    Two finalized 4-layer ConvLayer models are bred with a cut point of 2.
    The first four state-dict tensors of the child must equal those of the
    first parent. Tensors 6-7 of the child must equal those of the second
    parent.
    """
    for key, value in (('inherit_breeding_weights', True),
                       ('num_layers', 4),
                       ('mutation_rate', 0)):
        global_vars.set(key, value)
    parent_a = uniform_model(4, ConvLayer)
    state_a = finalize_model(parent_a).state_dict()
    parent_b = uniform_model(4, ConvLayer)
    state_b = finalize_model(parent_b).state_dict()
    _, child_state, _ = breed_layers(0, parent_a, parent_b, state_a, state_b, 2)
    tensors_a = list(state_a.values())
    tensors_b = list(state_b.values())
    tensors_child = list(child_state.values())
    for ours, theirs in zip(tensors_a[:4], tensors_child[:4]):
        assert (ours == theirs).all()
    # NOTE(review): only entries 6-7 are checked for the second parent.
    # Entries 4-5 are not checked; confirm whether that is intentional.
    for ours, theirs in zip(tensors_b[6:8], tensors_child[6:8]):
        assert (ours == theirs).all()
def test_fix_model(self):
    """A uniform 3-layer ConvLayer model fails to build unless ``applyFix=True``.

    The first call, without the fix, must raise. The second call, with
    ``applyFix=True``, must succeed.
    """
    model1 = uniform_model(3, ConvLayer)
    # The failure check must sit outside the `try`: AssertionError subclasses
    # Exception, so an `assert False` inside the `try` body would be swallowed
    # by the handler and the test could never fail.
    try:
        models_generation.new_model_from_structure_pytorch(model1)
    except Exception:
        pass  # expected: the unfixed structure is rejected
    else:
        raise AssertionError('building the model without applyFix should have raised')
    # Call directly so that any failure surfaces with its original traceback.
    models_generation.new_model_from_structure_pytorch(model1, applyFix=True)