# --- Experiment tracking -------------------------------------------------
# Open a Comet experiment and record the train/evaluation/predict sections
# of the model config as run parameters.
experiment = Experiment(project_name="neontrees", workspace="bw4sz")
for config_section in ("train", "evaluation", "predict"):
    experiment.log_parameters(model.config[config_section])
experiment.add_tag("HSI")

# --- Class weights -------------------------------------------------------
# tfrecords paths come from config.yml. The weights are used for class
# weighting in the cross-entropy loss. The data is read with the default mode
# first; presumably calc_class_weight uses the data just loaded (confirm).
model.read_data()
class_weight = model.calc_class_weight()

# --- Subnetworks ---------------------------------------------------------
# Train the spatial and then the spectral HSI subnetwork. Each one runs in its
# own Comet context and receives the same class weights repeated three times
# (presumably one entry per subnetwork output; confirm against model.train).
experiment.log_parameter("Train subnetworks", True)
for subnetwork in ("spatial", "spectral"):
    with experiment.context_manager("HSI_{}_subnetwork".format(subnetwork)):
        print("Train HSI {} subnetwork".format(subnetwork))
        model.read_data(mode="HSI_submodel")
        model.train(
            submodel=subnetwork,
            sensor="hyperspectral",
            class_weight=[class_weight] * 3,
            experiment=experiment,
        )

# --- Full HSI model ------------------------------------------------------
# Train the combined hyperspectral model with class weighting, then write it
# to save_dir as HDF5.
with experiment.context_manager("HSI_model"):
    experiment.log_parameter("Class Weighted", True)
    model.read_data(mode="HSI_train")
    model.train(class_weight=class_weight, sensor="hyperspectral", experiment=experiment)
    hsi_model_path = "{}/HSI_model.h5".format(save_dir)
    model.HSI_model.save(hsi_model_path)
dirname = model.config["train"]["checkpoint_dir"] if model.config["train"]["gpus"] > 1: with model.strategy.scope(): print("Running in parallel on {} GPUs".format(model.strategy.num_replicas_in_sync)) #model.RGB_model = load_model("{}/RGB_model.h5".format(dirname), custom_objects={"WeightedSum": WeightedSum}, compile=False) model.HSI_model = load_model("{}/HSI_model.h5".format(dirname), custom_objects={"WeightedSum": WeightedSum}, compile=False) model.metadata_model = load_model("{}/metadata_model.h5".format(dirname), compile=False) else: #model.RGB_model = load_model("{}/RGB_model.h5".format(dirname), custom_objects={"WeightedSum": WeightedSum}) model.HSI_model = load_model("{}/HSI_model.h5".format(dirname), custom_objects={"WeightedSum": WeightedSum}) model.metadata_model = load_model("{}/metadata_model.h5".format(dirname), compile=False) else: if model.config["train"]["pretrain"]: #metadata network with experiment.context_manager("metadata"): print("Train metadata") model.read_data(mode="metadata") print(model.metadata_model.summary()) model.train(submodel="metadata", experiment=experiment) model.metadata_model.save("{}/metadata_model.h5".format(save_dir)) ##Train subnetwork experiment.log_parameter("Train subnetworks", True) with experiment.context_manager("HSI_spatial_subnetwork"): print("Train HSI spatial subnetwork") model.read_data(mode="HSI_submodel") model.train(submodel="spatial", sensor="hyperspectral", experiment=experiment) with experiment.context_manager("HSI_spectral_subnetwork"):
# Linear metadata model for testing purposes.
#
# Trains only the metadata submodel and logs the run to Comet.
#
# Import order: comet_ml stays ahead of tensorflow. Comet's documentation asks
# for it to be imported before ML frameworks so its auto-logging can hook in.
import os

from comet_ml import Experiment
import tensorflow as tf
from DeepTreeAttention.trees import AttentionModel
# NOTE(review): tf, metadata, callbacks and pd are not referenced in this
# script as shown. They are kept in case other code relies on their import
# side effects.
from DeepTreeAttention.models import metadata
from DeepTreeAttention.callbacks import callbacks
import pandas as pd

# The default config is a user-specific absolute path. Set the TREE_CONFIG
# environment variable to point at a different tree_config.yml without
# editing this file. When TREE_CONFIG is unset, behavior is unchanged.
DEFAULT_CONFIG = "/home/b.weinstein/DeepTreeAttention/conf/tree_config.yml"
config_path = os.environ.get("TREE_CONFIG", DEFAULT_CONFIG)

model = AttentionModel(config=config_path)
model.create()

# Log config: record the train/evaluation/predict sections as run parameters.
experiment = Experiment(project_name="neontrees", workspace="bw4sz")
experiment.log_parameters(model.config["train"])
experiment.log_parameters(model.config["evaluation"])
experiment.log_parameters(model.config["predict"])
experiment.add_tag("metadata")

# Train: tfrecords paths come from config.yml. Class weights are computed from
# the metadata data just read and passed to model.train for class weighting
# in the cross-entropy loss.
with experiment.context_manager("metadata"):
    model.read_data(mode="metadata")
    class_weight = model.calc_class_weight()
    model.train(submodel="metadata", experiment=experiment, class_weight=class_weight)
#Create a class and run model = AttentionModel(config="/home/b.weinstein/DeepTreeAttention/conf/tree_config.yml") model.create() #Log config experiment.log_parameters(model.config["train"]) experiment.log_parameters(model.config["predict"]) ##Train #Train see config.yml for tfrecords path with weighted classes in cross entropy model.read_data(validation_split=True) class_weight = model.calc_class_weight() ## Train subnetwork experiment.log_parameter("Train subnetworks", True) with experiment.context_manager("spatial_subnetwork"): print("Train spatial subnetwork") model.read_data(mode="submodel",validation_split=True) model.train(submodel="spatial", class_weight=[class_weight, class_weight, class_weight]) with experiment.context_manager("spectral_subnetwork"): print("Train spectral subnetwork") model.read_data(mode="submodel",validation_split=True) model.train(submodel="spectral", class_weight=[class_weight, class_weight, class_weight]) #Train full model experiment.log_parameter("Class Weighted", True) model.read_data(validation_split=True) model.train(class_weight=class_weight, experiment=experiment) #Get Alpha score for the weighted spectral/spatial average. Higher alpha favors spatial network. if model.config["train"]["weighted_sum"]:
os.mkdir(save_dir) #Log config experiment = Experiment(project_name="neontrees", workspace="bw4sz") experiment.log_parameters(model.config["train"]) experiment.log_parameters(model.config["evaluation"]) experiment.log_parameters(model.config["predict"]) experiment.add_tag("RGB") experiment.log_parameter("timestamp", timestamp) ##Train #Train see config.yml for tfrecords path with weighted classes in cross entropy ##Train subnetwork experiment.log_parameter("Train subnetworks", True) with experiment.context_manager("RGB_spatial_subnetwork"): print("Train RGB spatial subnetwork") model.read_data(HSI=False, RGB=True, metadata=False, submodel=False) model.train(submodel="spatial", sensor="RGB", experiment=experiment) with experiment.context_manager("RGB_spectral_subnetwork"): print("Train RGB spectral subnetwork") model.train(submodel="spectral", sensor="RGB", experiment=experiment) #Train full model with experiment.context_manager("RGB_model"): model.read_data(HSI=False, RGB=True, metadata=False) model.train(sensor="RGB", experiment=experiment) #Get Alpha score for the weighted spectral/spatial average. Higher alpha favors spatial network. if model.config["train"]["RGB"]["weighted_sum"]: