Example #1
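This train() method creates a callback list from the log directory and fits either the full model or a spatial/spectral submodel, training the submodels for half of the configured epochs.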
    def train(self, class_weight=None, submodel=None):
        """Train a model"""

        callback_list = callbacks.create(self.log_dir)

        if submodel == "spatial":
            # The spatial model is very shallow compared to spectral; train it for half of the configured epochs
            self.spatial_model.fit(self.train_split,
                                   epochs=int(self.config["train"]["epochs"] / 2),
                                   validation_data=self.val_split,
                                   callbacks=callback_list,
                                   class_weight=class_weight)

        elif submodel == "spectral":
            # one for each loss layer
            self.spectral_model.fit(self.train_split,
                                    epochs=int(self.config["train"]["epochs"] / 2),
                                    validation_data=self.val_split,
                                    callbacks=callback_list,
                                    class_weight=class_weight)
        else:
            self.model.fit(self.train_split,
                           epochs=self.config["train"]["epochs"],
                           validation_data=self.val_split,
                           callbacks=callback_list,
                           class_weight=class_weight)
Example #2
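This ensemble() method only creates callbacks when both validation data and a Comet experiment are available, reads label names from the classes file when one is set, and then trains the ensemble layer for the ensemble-specific number of epochs.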
    def ensemble(self, experiment, class_weight=None):
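        """Train the ensemble layer; callbacks are only created when validation data and a Comet experiment are available."""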

        if self.val_split is None:
            print("Cannot run callbacks without validation data, skipping...")
            callback_list = None
            label_names = None
        elif experiment is None:
            print("Cannot run callbacks without comet experiment, skipping...")
            callback_list = None
            label_names = None
        else:
            if self.classes_file is not None:
                labeldf = pd.read_csv(self.classes_file)
                label_names = list(labeldf.taxonID.values)
            else:
                label_names = None

            callback_list = callbacks.create(log_dir=self.log_dir,
                                             experiment=experiment,
                                             validation_data=self.val_split,
                                             train_data=self.train_split,
                                             label_names=label_names,
                                             train_shp=self.train_shp,
                                             submodel="ensemble")

        #Train ensemble layer
        self.ensemble_model.fit(
            self.train_split,
            epochs=self.config["train"]["ensemble"]["epochs"],
            validation_data=self.val_split,
            callbacks=callback_list,
            class_weight=class_weight)
Example #3
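A minimal train() method: callbacks are created with no arguments and the full model is fit for the configured number of epochs.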
    def train(self, class_weight=None):
        """Train a model"""

        callback_list = callbacks.create()

        self.model.fit(
            self.train_split,
            epochs=self.config["train"]["epochs"],
            validation_data=self.val_split,
            callbacks=callback_list,
            class_weight=class_weight
        )
Example #4
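This ensemble() method overrides the batch size with the ensemble-specific value, builds an HSI + RGB ensemble with Hang.ensemble, compiles it, and, when training is requested, fits it with callbacks if validation data and a Comet experiment exist.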
    def ensemble(self, experiment, class_weight=None, freeze=True, train=True):
        """Build the HSI + RGB ensemble model and optionally train the ensemble layer."""
        # Manually override batch size with the ensemble-specific value
        self.config["train"]["batch_size"] = self.config["train"]["ensemble"]["batch_size"]
        self.read_data(mode="ensemble")
        self.ensemble_model = Hang.ensemble(
            [self.HSI_model, self.RGB_model], freeze=freeze, classes=self.classes)

        if train:
            self.ensemble_model.compile(
                loss="categorical_crossentropy",
                optimizer=tf.keras.optimizers.Adam(
                    lr=float(self.config["train"]["learning_rate"])),
                metrics=[tf.keras.metrics.CategoricalAccuracy(name='acc')])

            if self.val_split is None:
                print("Cannot run callbacks without validation data, skipping...")
                callback_list = None
                label_names = None
            elif experiment is None:
                print("Cannot run callbacks without comet experiment, skipping...")
                callback_list = None
                label_names = None
            else:
                if self.classes_file is not None:
                    labeldf = pd.read_csv(self.classes_file)
                    label_names = list(labeldf.taxonID.values)
                else:
                    label_names = None

                callback_list = callbacks.create(
                    log_dir=self.log_dir,
                    experiment=experiment,
                    validation_data=self.val_split,
                    train_data=self.train_split,
                    label_names=label_names,
                    submodel="ensemble")

            # Train the ensemble layer
            self.ensemble_model.fit(
                self.train_split,
                epochs=self.config["train"]["epochs"],
                validation_data=self.val_split,
                callbacks=callback_list,
                class_weight=class_weight)
Example #5
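This train() method reads label names from the classes file for the callbacks and then dispatches to the spatial submodel, the spectral submodel, or the full model.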
    def train(self, experiment=None, class_weight=None, submodel=None):
        """Train a model with callbacks"""

        if self.val_split is None:
            print("Cannot run callbacks without validation data, skipping...")
            callback_list = None
        elif experiment is None:
            print("Cannot run callbacks without comet experiment, skipping...")
            callback_list = None
        else:
            labeldf = pd.read_csv(self.classes_file)
            callback_list = callbacks.create(log_dir=self.log_dir,
                                             experiment=experiment,
                                             validation_data=self.val_split,
                                             label_names=list(labeldf.taxonID.values))

        if submodel == "spatial":
            # The spatial model is very shallow compared to spectral; train it for half of the configured epochs
            self.spatial_model.fit(self.train_split,
                                   epochs=int(self.config["train"]["epochs"] / 2),
                                   validation_data=self.val_split,
                                   callbacks=callback_list,
                                   class_weight=class_weight)

        elif submodel == "spectral":
            # one for each loss layer
            self.spectral_model.fit(self.train_split,
                                    epochs=int(self.config["train"]["epochs"] / 2),
                                    validation_data=self.val_split,
                                    callbacks=callback_list,
                                    class_weight=class_weight)
        else:
            self.model.fit(self.train_split,
                           epochs=self.config["train"]["epochs"],
                           validation_data=self.val_split,
                           callbacks=callback_list,
                           class_weight=class_weight)
Example #6
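This train() method additionally selects between the hyperspectral and RGB models (and their spatial/spectral submodels) and forwards the submodel name to callbacks.create.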
    def train(self, experiment=None, class_weight=None, submodel=None, sensor="hyperspectral"):
        """Train a model with callbacks"""

        if self.val_split is None:
            print("Cannot run callbacks without validation data, skipping...")
            callback_list = None
        elif experiment is None:
            print("Cannot run callbacks without comet experiment, skipping...")
            callback_list = None
        else:            
            if self.classes_file is not None:
                labeldf = pd.read_csv(self.classes_file)                
                label_names = list(labeldf.taxonID.values)
            else:
                label_names = None
                
            callback_list = callbacks.create(log_dir=self.log_dir,
                                             experiment=experiment,
                                             validation_data=self.val_split,
                                             train_data=self.train_split,
                                             label_names=label_names,
                                             submodel=submodel)
        
        if submodel == "spatial":
            if sensor == "hyperspectral":
                self.HSI_spatial.fit(self.train_split,
                                       epochs=int(self.config["train"]["epochs"]),
                                       validation_data=self.val_split,
                                       callbacks=callback_list,
                                       class_weight=class_weight)
            
            elif sensor == "RGB":
                self.RGB_spatial.fit(self.train_split,
                                     epochs=int(self.config["train"]["epochs"]),
                                     validation_data=self.val_split,
                                     callbacks=callback_list,
                                     class_weight=class_weight)

        elif submodel == "spectral":
            if sensor == "hyperspectral":
                self.HSI_spectral.fit(self.train_split,
                                       epochs=int(self.config["train"]["epochs"]),
                                       validation_data=self.val_split,
                                       callbacks=callback_list,
                                       class_weight=class_weight)
            elif sensor == "RGB":
                self.RGB_spectral.fit(self.train_split,
                                      epochs=int(self.config["train"]["epochs"]),
                                      validation_data=self.val_split,
                                      callbacks=callback_list,
                                      class_weight=class_weight)
        else:
            if sensor == "hyperspectral":
                self.HSI_model.fit(self.train_split,
                               epochs=self.config["train"]["epochs"],
                               validation_data=self.val_split,
                               callbacks=callback_list,
                               class_weight=class_weight)
            
            elif sensor == "RGB":
                self.RGB_model.fit(
                    self.train_split,
                    epochs=self.config["train"]["epochs"],
                    validation_data=self.val_split,
                    callbacks=callback_list,
                    class_weight=class_weight)
Example #7
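This ensemble() method builds a learned HSI + metadata ensemble, optionally inside a multi-GPU distribution strategy scope, and trains the ensemble layer for the ensemble-specific number of epochs.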
    def ensemble(self, experiment, class_weight=None, freeze=True, train=True):
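        """Build the learned HSI + metadata ensemble model and train the ensemble layer when train=True."""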
        self.classes = pd.read_csv(self.classes_file).shape[0]

        self.read_data(mode="ensemble")

        if self.val_split is None:
            print("Cannot run callbacks without validation data, skipping...")
            callback_list = None
            label_names = None
        elif experiment is None:
            print("Cannot run callbacks without comet experiment, skipping...")
            callback_list = None
            label_names = None
        else:
            if self.classes_file is not None:
                labeldf = pd.read_csv(self.classes_file)
                label_names = list(labeldf.taxonID.values)
            else:
                label_names = None

            callback_list = callbacks.create(log_dir=self.log_dir,
                                             experiment=experiment,
                                             validation_data=self.val_split,
                                             train_data=self.train_split,
                                             label_names=label_names,
                                             train_shp=self.train_shp,
                                             submodel="ensemble")

            print("callback list is {}".format(callback_list))

        if self.config["train"]["gpus"] > 1:
            with self.strategy.scope():
                self.ensemble_model = Hang.learned_ensemble(
                    HSI_model=self.HSI_model,
                    metadata_model=self.metadata_model,
                    freeze=freeze,
                    classes=self.classes)

                if train:
                    self.ensemble_model.compile(
                        loss="categorical_crossentropy",
                        optimizer=tf.keras.optimizers.Adam(
                            lr=float(self.config["train"]["learning_rate"])),
                        metrics=[
                            tf.keras.metrics.CategoricalAccuracy(name='acc')
                        ])
                    #Train ensemble layer
                    self.ensemble_model.fit(
                        self.train_split,
                        epochs=self.config["train"]["ensemble"]["epochs"],
                        validation_data=self.val_split,
                        callbacks=callback_list,
                        class_weight=class_weight)
        else:
            self.ensemble_model = Hang.learned_ensemble(
                HSI_model=self.HSI_model,
                metadata_model=self.metadata_model,
                freeze=freeze,
                classes=self.classes)
            if train:
                self.ensemble_model.compile(
                    loss="categorical_crossentropy",
                    optimizer=tf.keras.optimizers.Adam(
                        lr=float(self.config["train"]["learning_rate"])),
                    metrics=[tf.keras.metrics.CategoricalAccuracy(name='acc')])

                #Train ensemble layer
                self.ensemble_model.fit(
                    self.train_split,
                    epochs=self.config["train"]["ensemble"]["epochs"],
                    validation_data=self.val_split,
                    callbacks=callback_list,
                    class_weight=class_weight)
Example #8
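A script-level example that compiles a baseline vanilla model, builds callbacks from an existing model's train/validation splits, and fits the baseline. The opening of the snippet is truncated in the original, so the constructor call completed below for model is an assumption; the surrounding script is also expected to define experiment and save_dir and the usual imports (pandas, tensorflow, callbacks, vanilla).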
    config="/home/b.weinstein/DeepTreeAttention/conf/tree_config.yml",
    log_dir=save_dir)
model.read_data("HSI")
model.create()

baseline = vanilla.create(
    height=model.config["train"]["HSI"]["crop_size"],
    width=model.config["train"]["HSI"]["crop_size"],
    channels=model.config["train"]["HSI"]["sensor_channels"],
    classes=model.classes)
baseline.compile(loss="categorical_crossentropy",
                 optimizer=tf.keras.optimizers.Adam(
                     lr=float(model.config["train"]["learning_rate"])),
                 metrics=[tf.keras.metrics.CategoricalAccuracy(name='acc')])

labeldf = pd.read_csv(model.classes_file)
label_names = list(labeldf.taxonID.values)

callback_list = callbacks.create(experiment=experiment,
                                 train_data=model.train_split,
                                 validation_data=model.val_split,
                                 train_shp=model.train_shp,
                                 log_dir=None,
                                 label_names=label_names,
                                 submodel=False)

baseline.fit(model.train_split,
             epochs=model.config["train"]["ensemble"]["epochs"],
             validation_data=model.val_split,
             callbacks=callback_list)