def train(self, class_weight=None, submodel=None):
    """Fit one of the networks held by this object.

    Args:
        class_weight: forwarded unchanged to ``fit``.
        submodel: ``"spatial"`` fits ``self.spatial_model`` and
            ``"spectral"`` fits ``self.spectral_model``; any other value
            fits ``self.model``.
    """
    fit_callbacks = callbacks.create(self.log_dir)

    # Both sub-networks are fit for half of the configured epochs; only the
    # full model uses the configured epoch count as-is.
    if submodel == "spatial":
        network = self.spatial_model
        n_epochs = int(self.config["train"]["epochs"] / 2)
    elif submodel == "spectral":
        network = self.spectral_model
        n_epochs = int(self.config["train"]["epochs"] / 2)
    else:
        network = self.model
        n_epochs = self.config["train"]["epochs"]

    network.fit(self.train_split,
                epochs=n_epochs,
                validation_data=self.val_split,
                callbacks=fit_callbacks,
                class_weight=class_weight)
def ensemble(self, experiment, class_weight=None):
    """Fit ``self.ensemble_model``, attaching callbacks when possible.

    Callbacks are only built when both a validation split and an
    ``experiment`` are available; otherwise a notice is printed and the
    model is fit with ``callbacks=None``.

    Args:
        experiment: object forwarded to ``callbacks.create``.
        class_weight: forwarded unchanged to ``fit``.
    """
    fit_callbacks = None

    if self.val_split is None:
        print("Cannot run callbacks without validation data, skipping...")
    elif experiment is None:
        print("Cannot run callbacks without comet experiment, skipping...")
    else:
        # Label names come from the taxonID column of the classes file,
        # when such a file is configured.
        names = None
        if self.classes_file is not None:
            names = list(pd.read_csv(self.classes_file).taxonID.values)

        fit_callbacks = callbacks.create(log_dir=self.log_dir,
                                         experiment=experiment,
                                         validation_data=self.val_split,
                                         train_data=self.train_split,
                                         label_names=names,
                                         train_shp=self.train_shp,
                                         submodel="ensemble")

    # Train ensemble layer
    n_epochs = self.config["train"]["ensemble"]["epochs"]
    self.ensemble_model.fit(self.train_split,
                            epochs=n_epochs,
                            validation_data=self.val_split,
                            callbacks=fit_callbacks,
                            class_weight=class_weight)
def train(self, class_weight=None):
    """Fit ``self.model`` on the training split.

    The epoch count is read from ``self.config["train"]["epochs"]`` and the
    validation split is passed as validation data.

    Args:
        class_weight: forwarded unchanged to ``fit``.
    """
    fit_callbacks = callbacks.create()
    network = self.model
    training_data = self.train_split
    n_epochs = self.config["train"]["epochs"]

    network.fit(training_data,
                epochs=n_epochs,
                validation_data=self.val_split,
                callbacks=fit_callbacks,
                class_weight=class_weight)
def ensemble(self, experiment, class_weight=None, freeze=True, train=True):
    """Build the HSI + RGB ensemble model and optionally train it.

    The ensemble batch size replaces the general training batch size in
    ``self.config`` before the data are re-read in ``"ensemble"`` mode.
    ``self.ensemble_model`` is always (re)built; compiling, callback setup
    and fitting only happen when ``train`` is true.

    Args:
        experiment: object forwarded to ``callbacks.create``; when ``None``
            no callbacks are built.
        class_weight: forwarded unchanged to ``fit``.
        freeze: forwarded to ``Hang.ensemble``.
        train: when false, stop after building the (uncompiled) model.
    """
    # Manually override batch size
    self.config["train"]["batch_size"] = self.config["train"]["ensemble"]["batch_size"]
    self.read_data(mode="ensemble")
    self.ensemble_model = Hang.ensemble([self.HSI_model, self.RGB_model],
                                        freeze=freeze,
                                        classes=self.classes)

    if not train:
        return

    # `learning_rate` is the supported keyword; `lr` is a deprecated alias.
    self.ensemble_model.compile(
        loss="categorical_crossentropy",
        optimizer=tf.keras.optimizers.Adam(
            learning_rate=float(self.config["train"]["learning_rate"])),
        metrics=[tf.keras.metrics.CategoricalAccuracy(name='acc')])

    callback_list = None
    if self.val_split is None:
        print("Cannot run callbacks without validation data, skipping...")
    elif experiment is None:
        print("Cannot run callbacks without comet experiment, skipping...")
    else:
        # Label names come from the taxonID column of the classes file,
        # when such a file is configured.
        label_names = None
        if self.classes_file is not None:
            labeldf = pd.read_csv(self.classes_file)
            label_names = list(labeldf.taxonID.values)

        callback_list = callbacks.create(log_dir=self.log_dir,
                                         experiment=experiment,
                                         validation_data=self.val_split,
                                         train_data=self.train_split,
                                         label_names=label_names,
                                         submodel="ensemble")

    # Train ensemble layer
    self.ensemble_model.fit(self.train_split,
                            epochs=self.config["train"]["epochs"],
                            validation_data=self.val_split,
                            callbacks=callback_list,
                            class_weight=class_weight)
def train(self, experiment=None, class_weight=None, submodel=None):
    """Fit one of the networks, attaching callbacks when possible.

    Callbacks are only built when both a validation split and an
    ``experiment`` are available; otherwise a notice is printed and the
    model is fit with ``callbacks=None``.

    Args:
        experiment: object forwarded to ``callbacks.create``.
        class_weight: forwarded unchanged to ``fit``.
        submodel: ``"spatial"`` fits ``self.spatial_model`` and
            ``"spectral"`` fits ``self.spectral_model``; any other value
            fits ``self.model``.
    """
    fit_callbacks = None
    if self.val_split is None:
        print("Cannot run callbacks without validation data, skipping...")
    elif experiment is None:
        print("Cannot run callbacks without comet experiment, skipping...")
    else:
        # Label names come from the taxonID column of the classes file.
        names = list(pd.read_csv(self.classes_file).taxonID.values)
        fit_callbacks = callbacks.create(log_dir=self.log_dir,
                                         experiment=experiment,
                                         validation_data=self.val_split,
                                         label_names=names)

    # Both sub-networks are fit for half of the configured epochs; only the
    # full model uses the configured epoch count as-is.
    if submodel == "spatial":
        network = self.spatial_model
        n_epochs = int(self.config["train"]["epochs"] / 2)
    elif submodel == "spectral":
        network = self.spectral_model
        n_epochs = int(self.config["train"]["epochs"] / 2)
    else:
        network = self.model
        n_epochs = self.config["train"]["epochs"]

    network.fit(self.train_split,
                epochs=n_epochs,
                validation_data=self.val_split,
                callbacks=fit_callbacks,
                class_weight=class_weight)
def train(self, experiment=None, class_weight=None, submodel=None, sensor="hyperspectral"):
    """Train one sensor-specific network, attaching callbacks when possible.

    Callbacks are only built when both a validation split and an
    ``experiment`` are available; otherwise a notice is printed and the
    model is fit with ``callbacks=None``.

    Args:
        experiment: object forwarded to ``callbacks.create``.
        class_weight: forwarded unchanged to ``fit``.
        submodel: ``"spatial"`` or ``"spectral"`` selects that sub-network
            (``self.<prefix>_spatial`` / ``self.<prefix>_spectral``); any
            other value selects the full ``self.<prefix>_model``.
        sensor: ``"hyperspectral"`` (prefix ``HSI``) or ``"RGB"``
            (prefix ``RGB``).

    Raises:
        ValueError: if ``sensor`` is not one of the two supported names.
            Previously such a call returned silently without training.
    """
    if sensor == "hyperspectral":
        prefix = "HSI"
    elif sensor == "RGB":
        prefix = "RGB"
    else:
        raise ValueError(
            "Unknown sensor {!r}; expected 'hyperspectral' or 'RGB'".format(sensor))

    callback_list = None
    if self.val_split is None:
        print("Cannot run callbacks without validation data, skipping...")
    elif experiment is None:
        print("Cannot run callbacks without comet experiment, skipping...")
    else:
        # Label names come from the taxonID column of the classes file,
        # when such a file is configured.
        label_names = None
        if self.classes_file is not None:
            labeldf = pd.read_csv(self.classes_file)
            label_names = list(labeldf.taxonID.values)

        callback_list = callbacks.create(log_dir=self.log_dir,
                                         experiment=experiment,
                                         validation_data=self.val_split,
                                         train_data=self.train_split,
                                         label_names=label_names,
                                         submodel=submodel)

    epochs = self.config["train"]["epochs"]
    if submodel == "spatial" or submodel == "spectral":
        suffix = submodel
        # The sub-network branches coerce the epoch count to int; the full
        # model receives the configured value untouched.
        epochs = int(epochs)
    else:
        suffix = "model"

    network = getattr(self, "{}_{}".format(prefix, suffix))
    network.fit(self.train_split,
                epochs=epochs,
                validation_data=self.val_split,
                callbacks=callback_list,
                class_weight=class_weight)
def ensemble(self, experiment, class_weight=None, freeze=True, train=True):
    """Build the learned HSI + metadata ensemble and optionally train it.

    The number of classes is taken from the row count of the classes file,
    the data are re-read in ``"ensemble"`` mode, and callbacks are built
    when both a validation split and an ``experiment`` are available.
    With more than one configured GPU, the model is built (and trained)
    inside ``self.strategy.scope()``.

    Args:
        experiment: object forwarded to ``callbacks.create``; when ``None``
            no callbacks are built.
        class_weight: forwarded unchanged to ``fit``.
        freeze: forwarded to ``Hang.learned_ensemble``.
        train: when false, the model is built but neither compiled nor fit.
    """
    self.classes = pd.read_csv(self.classes_file).shape[0]
    self.read_data(mode="ensemble")

    callback_list = None
    if self.val_split is None:
        print("Cannot run callbacks without validation data, skipping...")
    elif experiment is None:
        print("Cannot run callbacks without comet experiment, skipping...")
    else:
        # Label names come from the taxonID column of the classes file,
        # when such a file is configured.
        label_names = None
        if self.classes_file is not None:
            labeldf = pd.read_csv(self.classes_file)
            label_names = list(labeldf.taxonID.values)

        callback_list = callbacks.create(log_dir=self.log_dir,
                                         experiment=experiment,
                                         validation_data=self.val_split,
                                         train_data=self.train_split,
                                         label_names=label_names,
                                         train_shp=self.train_shp,
                                         submodel="ensemble")
        print("callback list is {}".format(callback_list))

    def _build_and_fit():
        """Create the ensemble model, then compile and fit it if requested."""
        self.ensemble_model = Hang.learned_ensemble(
            HSI_model=self.HSI_model,
            metadata_model=self.metadata_model,
            freeze=freeze,
            classes=self.classes)

        if not train:
            return

        # `learning_rate` is the supported keyword; `lr` is a deprecated alias.
        self.ensemble_model.compile(
            loss="categorical_crossentropy",
            optimizer=tf.keras.optimizers.Adam(
                learning_rate=float(self.config["train"]["learning_rate"])),
            metrics=[tf.keras.metrics.CategoricalAccuracy(name='acc')])

        # Train ensemble layer
        self.ensemble_model.fit(
            self.train_split,
            epochs=self.config["train"]["ensemble"]["epochs"],
            validation_data=self.val_split,
            callbacks=callback_list,
            class_weight=class_weight)

    # The same build/compile/fit sequence runs on every path; only the
    # multi-GPU path wraps it in the distribution strategy scope.
    if self.config["train"]["gpus"] > 1:
        with self.strategy.scope():
            _build_and_fit()
    else:
        _build_and_fit()
config="/home/b.weinstein/DeepTreeAttention/conf/tree_config.yml", log_dir=save_dir) model.read_data("HSI") model.create() baseline = vanilla.create( height=model.config["train"]["HSI"]["crop_size"], width=model.config["train"]["HSI"]["crop_size"], channels=model.config["train"]["HSI"]["sensor_channels"], classes=model.classes) baseline.compile(loss="categorical_crossentropy", optimizer=tf.keras.optimizers.Adam( lr=float(model.config["train"]["learning_rate"])), metrics=[tf.keras.metrics.CategoricalAccuracy(name='acc')]) labeldf = pd.read_csv(model.classes_file) label_names = list(labeldf.taxonID.values) callback_list = callbacks.create(experiment=experiment, train_data=model.train_split, validation_data=model.val_split, train_shp=model.train_shp, log_dir=None, label_names=label_names, submodel=False) baseline.fit(model.train_split, epochs=model.config["train"]["ensemble"]["epochs"], validation_data=model.val_split, callbacks=callback_list)