Example #1
0
def test_ensemble(RGB_image, HSI_image):
    """Build one Hang sub-model per sensor and check the ensemble's output shape.

    ``HSI_image`` and ``RGB_image`` are presumably pytest fixtures; element
    ``[0]`` of each is unpacked as (batch, height, width, channels) and used
    both to size its sub-model and as that sub-model's input.
    """
    classes = 2
    # Order matters: the ensemble receives the models and the inputs in the
    # same (HSI, RGB) order.
    sensor_images = [HSI_image[0], RGB_image[0]]

    models = []
    for image in sensor_images:
        # One sub-model per sensor, sized to that sensor's input tensor.
        _, height, width, channels = image.shape
        inputs, outputs, _, _ = Hang.define_model(
            classes=classes, height=height, width=width, channels=channels)
        models.append(tf.keras.Model(inputs=inputs, outputs=outputs))

    ensemble = Hang.ensemble(models=models, classes=classes)
    prediction = ensemble.predict(sensor_images)

    # Expect one row of class scores per input sample rather than assuming
    # the fixture batch size is 1.
    batch_size = sensor_images[0].shape[0]
    assert prediction.shape == (batch_size, classes)
Example #2
0
 def ensemble(self, experiment, class_weight=None, freeze=True, train=True):
     """Build the HSI + RGB ensemble model and optionally train it.

     Reads data in ``"ensemble"`` mode, combines ``self.HSI_model`` and
     ``self.RGB_model`` via ``Hang.ensemble`` and stores the result on
     ``self.ensemble_model``. When ``train`` is true, the ensemble is compiled
     with categorical cross-entropy, Adam and a categorical-accuracy metric,
     then fit on ``self.train_split``.

     Args:
         experiment: Experiment object forwarded to ``callbacks.create``;
             when None, callbacks are skipped.
         class_weight: Optional class weighting forwarded to ``fit``.
         freeze: Forwarded to ``Hang.ensemble`` (presumably controls whether
             the sub-model weights are frozen — confirm in ``Hang``).
         train: If False, only build the ensemble; do not compile or fit.
     """
     # Use the ensemble-specific batch size. NOTE: this overwrites the shared
     # config entry in place, so later reads of config["train"]["batch_size"]
     # also see the ensemble value.
     self.config["train"]["batch_size"] = self.config["train"]["ensemble"]["batch_size"]
     self.read_data(mode="ensemble")
     self.ensemble_model = Hang.ensemble(
         [self.HSI_model, self.RGB_model], freeze=freeze, classes=self.classes)

     if not train:
         return

     self.ensemble_model.compile(
         loss="categorical_crossentropy",
         # ``learning_rate`` is the supported keyword; the legacy ``lr`` alias
         # is deprecated and rejected by newer Keras optimizers.
         optimizer=tf.keras.optimizers.Adam(
             learning_rate=float(self.config["train"]["learning_rate"])),
         metrics=[tf.keras.metrics.CategoricalAccuracy(name='acc')])

     if self.val_split is None:
         print("Cannot run callbacks without validation data, skipping...")
         callback_list = None
     elif experiment is None:
         print("Cannot run callbacks without comet experiment, skipping...")
         callback_list = None
     else:
         # Optional human-readable class names for the callbacks, read from
         # the taxonID column of the classes file.
         if self.classes_file is not None:
             labeldf = pd.read_csv(self.classes_file)
             label_names = list(labeldf.taxonID.values)
         else:
             label_names = None

         callback_list = callbacks.create(log_dir=self.log_dir,
                                          experiment=experiment,
                                          validation_data=self.val_split,
                                          train_data=self.train_split,
                                          label_names=label_names,
                                          submodel="ensemble")

     # Train the ensemble layer.
     self.ensemble_model.fit(
         self.train_split,
         epochs=self.config["train"]["epochs"],
         validation_data=self.val_split,
         callbacks=callback_list,
         class_weight=class_weight)