experiment.log_parameters(model.config["evaluation"]) experiment.log_parameters(model.config["predict"]) experiment.add_tag("HSI") ##Train #Train see config.yml for tfrecords path with weighted classes in cross entropy model.read_data() class_weight = model.calc_class_weight() ##Train subnetwork experiment.log_parameter("Train subnetworks", True) with experiment.context_manager("HSI_spatial_subnetwork"): print("Train HSI spatial subnetwork") model.read_data(mode="HSI_submodel") model.train(submodel="spatial", sensor="hyperspectral",class_weight=[class_weight, class_weight, class_weight], experiment=experiment) with experiment.context_manager("HSI_spectral_subnetwork"): print("Train HSI spectral subnetwork") model.read_data(mode="HSI_submodel") model.train(submodel="spectral", sensor="hyperspectral", class_weight=[class_weight, class_weight, class_weight], experiment=experiment) #Train full model with experiment.context_manager("HSI_model"): experiment.log_parameter("Class Weighted", True) model.read_data(mode="HSI_train") model.train(class_weight=class_weight, sensor="hyperspectral", experiment=experiment) model.HSI_model.save("{}/HSI_model.h5".format(save_dir)) #Get Alpha score for the weighted spectral/spatial average. Higher alpha favors spatial network. if model.config["train"]["HSI"]["weighted_sum"]:
model.HSI_model = load_model("{}/HSI_model.h5".format(dirname), custom_objects={"WeightedSum": WeightedSum}, compile=False) model.metadata_model = load_model("{}/metadata_model.h5".format(dirname), compile=False) else: #model.RGB_model = load_model("{}/RGB_model.h5".format(dirname), custom_objects={"WeightedSum": WeightedSum}) model.HSI_model = load_model("{}/HSI_model.h5".format(dirname), custom_objects={"WeightedSum": WeightedSum}) model.metadata_model = load_model("{}/metadata_model.h5".format(dirname), compile=False) else: if model.config["train"]["pretrain"]: #metadata network with experiment.context_manager("metadata"): print("Train metadata") model.read_data(mode="metadata") print(model.metadata_model.summary()) model.train(submodel="metadata", experiment=experiment) model.metadata_model.save("{}/metadata_model.h5".format(save_dir)) ##Train subnetwork experiment.log_parameter("Train subnetworks", True) with experiment.context_manager("HSI_spatial_subnetwork"): print("Train HSI spatial subnetwork") model.read_data(mode="HSI_submodel") model.train(submodel="spatial", sensor="hyperspectral", experiment=experiment) with experiment.context_manager("HSI_spectral_subnetwork"): print("Train HSI spectral subnetwork") model.train(submodel="spectral", sensor="hyperspectral", experiment=experiment) #Train full model with experiment.context_manager("HSI_model"):
# Linear metadata model for testing purposes: build the AttentionModel, open a
# Comet experiment, and train only the metadata submodel.
# comet_ml stays the first import; Comet's docs ask for it to be imported
# before TensorFlow so its auto-logging can attach.
from comet_ml import Experiment
import tensorflow as tf
from DeepTreeAttention.trees import AttentionModel
from DeepTreeAttention.models import metadata
from DeepTreeAttention.callbacks import callbacks
import pandas as pd

CONFIG_PATH = "/home/b.weinstein/DeepTreeAttention/conf/tree_config.yml"

model = AttentionModel(config=CONFIG_PATH)
model.create()

# Record the config sections that drive this run.
experiment = Experiment(project_name="neontrees", workspace="bw4sz")
for config_section in ("train", "evaluation", "predict"):
    experiment.log_parameters(model.config[config_section])
experiment.add_tag("metadata")

# Train. The tfrecords path comes from config.yml; class weights computed from
# the loaded metadata data feed the weighted cross-entropy.
with experiment.context_manager("metadata"):
    model.read_data(mode="metadata")
    class_weight = model.calc_class_weight()
    model.train(
        submodel="metadata",
        experiment=experiment,
        class_weight=class_weight,
    )
# Log config
for config_section in ("train", "predict"):
    experiment.log_parameters(model.config[config_section])

# Train (see config.yml for the tfrecords path). Class weights for the weighted
# cross-entropy are computed once from this read and reused by every stage.
model.read_data(validation_split=True)
class_weight = model.calc_class_weight()

# Train the spatial and spectral subnetworks, each in its own Comet context.
# Each subnetwork gets a three-element class_weight list (presumably one entry
# per submodel output -- TODO confirm against AttentionModel.train).
# NOTE(review): unlike the full model below, these train() calls are not given
# experiment=experiment; confirm whether that omission is intentional.
experiment.log_parameter("Train subnetworks", True)
for submodel_name, context_name, banner in (
    ("spatial", "spatial_subnetwork", "Train spatial subnetwork"),
    ("spectral", "spectral_subnetwork", "Train spectral subnetwork"),
):
    with experiment.context_manager(context_name):
        print(banner)
        model.read_data(mode="submodel", validation_split=True)
        model.train(submodel=submodel_name, class_weight=[class_weight] * 3)

# Train the full model with the same class weights.
experiment.log_parameter("Class Weighted", True)
model.read_data(validation_split=True)
model.train(class_weight=class_weight, experiment=experiment)

# Alpha for the weighted spectral/spatial average; a higher alpha favors the
# spatial network. This reads the weights of the model's last layer, which is
# assumed to be the weighted-sum layer.
if model.config["train"]["weighted_sum"]:
    estimate_a = model.model.layers[-1].get_weights()
    experiment.log_metric(name="spatial-spectral weight", value=estimate_a[0][0])
experiment = Experiment(project_name="neontrees", workspace="bw4sz")
for config_section in ("train", "evaluation", "predict"):
    experiment.log_parameters(model.config[config_section])
experiment.add_tag("RGB")
experiment.log_parameter("timestamp", timestamp)

# Train (see config.yml for the tfrecords path).

# RGB subnetworks. The spectral subnetwork reuses the data read for the
# spatial one rather than reading again.
experiment.log_parameter("Train subnetworks", True)
with experiment.context_manager("RGB_spatial_subnetwork"):
    print("Train RGB spatial subnetwork")
    # NOTE(review): submodel=False is passed while a subnetwork is being
    # trained; confirm this is intended.
    model.read_data(HSI=False, RGB=True, metadata=False, submodel=False)
    model.train(submodel="spatial", sensor="RGB", experiment=experiment)
with experiment.context_manager("RGB_spectral_subnetwork"):
    print("Train RGB spectral subnetwork")
    model.train(submodel="spectral", sensor="RGB", experiment=experiment)

# Full RGB model.
with experiment.context_manager("RGB_model"):
    model.read_data(HSI=False, RGB=True, metadata=False)
    model.train(sensor="RGB", experiment=experiment)

    # Alpha for the weighted spectral/spatial average; a higher alpha favors
    # the spatial network. Read from the layer named "weighted_sum".
    if model.config["train"]["RGB"]["weighted_sum"]:
        weighted_sum_layer = model.RGB_model.get_layer("weighted_sum")
        estimate_a = weighted_sum_layer.get_weights()
        experiment.log_metric(name="spatial-spectral weight", value=estimate_a[0][0])