Exemplo n.º 1
0
 def test_hash_model_nochange(self):
     """Hashing two models built the same way must record only one entry.

     Both models come from ``uniform_model(10, ActivationLayer)`` and are
     expected to be identical, so after passing each through ``hash_model``
     with the same accumulators, the model list and the genome list must
     each hold exactly one element.
     """
     models = [uniform_model(10, ActivationLayer) for _ in range(2)]
     seen_models, seen_genomes = [], []
     for candidate in models:
         EEGNAS.utilities.NAS_utils.hash_model(candidate, seen_models, seen_genomes)
     assert len(seen_models) == 1
     assert len(seen_genomes) == 1
Exemplo n.º 2
0
 def test_breed(self):
     """Check the per-layer layout of a child produced by ``breed_layers``.

     With ``cut_point=4``, each child layer before index 4 must have the
     same type name as the corresponding ``model1`` layer (BatchNormLayer).
     Every later layer must match ``model2`` (DropoutLayer). The child must
     then pass through ``finalize_model`` without raising.
     """
     num_layers, cut = 10, 4
     first_parent = uniform_model(num_layers, BatchNormLayer)
     second_parent = uniform_model(num_layers, DropoutLayer)
     child, _, _ = breed_layers(0, first_parent, second_parent, cut_point=cut)
     for idx in range(num_layers):
         # Layers left of the cut come from the first parent, the rest from the second.
         donor = first_parent if idx < cut else second_parent
         assert type(child[idx]).__name__ == type(donor[idx]).__name__
     finalize_model(child)
Exemplo n.º 3
0
 def test_hash_model(self):
     """Hash two identical parents and their bred child into shared sets.

     ``model1`` and ``model2`` are both built by
     ``uniform_model(10, ActivationLayer)`` and are expected to collapse to a
     single entry. The child is bred from them with ``mutation_rate`` set
     to 1, so it may or may not hash differently. Each accumulator must
     therefore end up with either one or two entries.
     """
     model1 = uniform_model(10, ActivationLayer)
     model2 = uniform_model(10, ActivationLayer)
     global_vars.set('mutation_rate', 1)
     # breed_layers returns three values (see its other call sites);
     # unpacking only two raised ValueError.
     model3, _, _ = breed_layers(1, model1, model2)
     model_set = []
     genome_set = []
     for model in (model1, model2, model3):
         EEGNAS.utilities.NAS_utils.hash_model(model, model_set, genome_set)
     # `len(x) == 1 or 2` is always truthy; test membership explicitly.
     assert len(model_set) in (1, 2)
     assert len(genome_set) in (1, 2)
Exemplo n.º 4
0
 def test_state_inheritance_breeding(self):
     """Check that a bred child copies state tensors from both parents.

     Two 4-layer ConvLayer models are finalized and their ``state_dict``s
     are passed to ``breed_layers`` with a cut point of 2. The test then
     checks two things:

     - The child's first four state entries equal parent A's.
     - The child's state entries at positions 6-7 equal parent B's.

     NOTE(review): the mapping of these positions to specific layers depends
     on how ``finalize_model`` lays out parameters; confirm against it.
     """
     for key, value in (('inherit_breeding_weights', True),
                        ('num_layers', 4),
                        ('mutation_rate', 0)):
         global_vars.set(key, value)
     parent_a = uniform_model(4, ConvLayer)
     parent_a_state = finalize_model(parent_a).state_dict()
     parent_b = uniform_model(4, ConvLayer)
     parent_b_state = finalize_model(parent_b).state_dict()
     child, child_state, _ = breed_layers(0, parent_a, parent_b, parent_a_state, parent_b_state, 2)
     a_tensors = list(parent_a_state.values())
     b_tensors = list(parent_b_state.values())
     child_tensors = list(child_state.values())
     for expected, actual in zip(a_tensors[:4], child_tensors[:4]):
         assert (expected == actual).all()
     for expected, actual in zip(b_tensors[6:8], child_tensors[6:8]):
         assert (expected == actual).all()
Exemplo n.º 5
0
 def test_fix_model(self):
     """Check that ``applyFix`` makes an otherwise-rejected structure buildable.

     Building a 3-layer ConvLayer structure without ``applyFix`` must raise.
     Building the same structure with ``applyFix=True`` must succeed.
     """
     model1 = uniform_model(3, ConvLayer)
     try:
         models_generation.new_model_from_structure_pytorch(model1)
     except Exception:
         pass  # expected: the unfixed structure is rejected
     else:
         # Raised outside the try so the handler above cannot swallow it
         # (AssertionError is itself a subclass of Exception).
         raise AssertionError(
             'new_model_from_structure_pytorch accepted model1 without applyFix')
     # With the fix applied the build must succeed. Any exception propagates
     # so the test failure reports the real traceback.
     models_generation.new_model_from_structure_pytorch(model1, applyFix=True)