Ejemplo n.º 1
0
# Record the run configuration on the Comet experiment and tag it as an HSI run.
experiment = Experiment(project_name="neontrees", workspace="bw4sz")
for _config_section in ("train", "evaluation", "predict"):
    experiment.log_parameters(model.config[_config_section])
experiment.add_tag("HSI")

## Train
# Read the default dataset first so class weights can be computed for the
# weighted cross-entropy loss (tfrecords paths are taken from the model config).
model.read_data()
class_weight = model.calc_class_weight()

## Train the spatial subnetwork, then the spectral subnetwork.
experiment.log_parameter("Train subnetworks", True)
for _context_name, _banner, _submodel in (
    ("HSI_spatial_subnetwork", "Train HSI spatial subnetwork", "spatial"),
    ("HSI_spectral_subnetwork", "Train HSI spectral subnetwork", "spectral"),
):
    with experiment.context_manager(_context_name):
        print(_banner)
        model.read_data(mode="HSI_submodel")
        # Subnetwork training is given three copies of the same class-weight
        # object (presumably one per model output -- confirm against train()).
        model.train(
            submodel=_submodel,
            sensor="hyperspectral",
            class_weight=[class_weight] * 3,
            experiment=experiment,
        )

# Train the full HSI model with the weighted loss and save it to save_dir.
with experiment.context_manager("HSI_model"):
    experiment.log_parameter("Class Weighted", True)
    model.read_data(mode="HSI_train")
    model.train(class_weight=class_weight, sensor="hyperspectral", experiment=experiment)
    _hsi_model_path = "{}/HSI_model.h5".format(save_dir)
    model.HSI_model.save(_hsi_model_path)
Ejemplo n.º 2
0
     # Restore the HSI and metadata models from the checkpoint directory named in
     # the "train" config section (the condition selecting this branch is outside
     # this view).
     dirname = model.config["train"]["checkpoint_dir"]
     if model.config["train"]["gpus"] > 1:
         # More than one GPU configured: load inside the model's distribution
         # strategy scope (presumably so restored variables are placed under that
         # strategy -- confirm against the AttentionModel setup).
         with model.strategy.scope():
             print("Running in parallel on {} GPUs".format(model.strategy.num_replicas_in_sync))
             # RGB model restore is currently disabled (kept commented out).
             #model.RGB_model = load_model("{}/RGB_model.h5".format(dirname), custom_objects={"WeightedSum": WeightedSum}, compile=False)
             # WeightedSum must be passed via custom_objects so the saved HSI model
             # can be deserialized; compile=False skips restoring the compiled state.
             model.HSI_model = load_model("{}/HSI_model.h5".format(dirname), custom_objects={"WeightedSum": WeightedSum}, compile=False)
             model.metadata_model = load_model("{}/metadata_model.h5".format(dirname), compile=False)
     else:
         # Single-GPU / CPU path.
         # NOTE(review): unlike the multi-GPU branch, the HSI model here is loaded
         # without compile=False, so it is restored compiled -- confirm intended.
         #model.RGB_model = load_model("{}/RGB_model.h5".format(dirname), custom_objects={"WeightedSum": WeightedSum})
         model.HSI_model = load_model("{}/HSI_model.h5".format(dirname), custom_objects={"WeightedSum": WeightedSum})
         model.metadata_model = load_model("{}/metadata_model.h5".format(dirname), compile=False)
             
 else:
     if model.config["train"]["pretrain"]:
         # Train the metadata network on its own and save it to save_dir.
         with experiment.context_manager("metadata"):
             print("Train metadata")
             model.read_data(mode="metadata")
             # NOTE(review): if summary() prints and returns None (as Keras'
             # Model.summary does), this line also prints "None".
             print(model.metadata_model.summary())

             model.train(submodel="metadata", experiment=experiment)
             model.metadata_model.save("{}/metadata_model.h5".format(save_dir))

         ##Train subnetwork
         # Train the HSI spatial subnetwork on the submodel dataset (no class
         # weights are passed here).
         experiment.log_parameter("Train subnetworks", True)
         with experiment.context_manager("HSI_spatial_subnetwork"):
             print("Train HSI spatial subnetwork")
             model.read_data(mode="HSI_submodel")
             model.train(submodel="spatial", sensor="hyperspectral", experiment=experiment)
         
         with experiment.context_manager("HSI_spectral_subnetwork"):
Ejemplo n.º 3
0
#Linear metadata model for testing purposes
from comet_ml import Experiment
import tensorflow as tf
from DeepTreeAttention.trees import AttentionModel
from DeepTreeAttention.models import metadata
from DeepTreeAttention.callbacks import callbacks
import pandas as pd

# Build the attention model from the project's YAML configuration.
_config_path = "/home/b.weinstein/DeepTreeAttention/conf/tree_config.yml"
model = AttentionModel(config=_config_path)
model.create()

# Record the train / evaluation / predict settings on Comet and tag the run.
experiment = Experiment(project_name="neontrees", workspace="bw4sz")
for _section in ("train", "evaluation", "predict"):
    experiment.log_parameters(model.config[_section])
experiment.add_tag("metadata")

## Train
# Train only the metadata submodel. Class weights for the weighted
# cross-entropy loss are computed after the metadata dataset has been read
# (tfrecords paths come from the model config).
with experiment.context_manager("metadata"):
    model.read_data(mode="metadata")
    class_weight = model.calc_class_weight()
    model.train(
        submodel="metadata",
        experiment=experiment,
        class_weight=class_weight,
    )
Ejemplo n.º 4
0
 # Build the attention model from the project's YAML configuration.
 model = AttentionModel(config="/home/b.weinstein/DeepTreeAttention/conf/tree_config.yml")
 model.create()

 # Record the train and predict settings.
 # NOTE(review): `experiment` is not created in this snippet; it must already
 # exist before this point.
 for _section in ("train", "predict"):
     experiment.log_parameters(model.config[_section])

 ## Train
 # Read data with a validation split so class weights can be computed for the
 # weighted cross-entropy loss (tfrecords paths come from config.yml).
 model.read_data(validation_split=True)
 class_weight = model.calc_class_weight()

 ## Train the spatial subnetwork, then the spectral subnetwork.
 # NOTE(review): these train() calls do not pass experiment=experiment, while
 # the full-model call below does -- confirm this is intended.
 experiment.log_parameter("Train subnetworks", True)
 for _context_name, _banner, _submodel in (
     ("spatial_subnetwork", "Train spatial subnetwork", "spatial"),
     ("spectral_subnetwork", "Train spectral subnetwork", "spectral"),
 ):
     with experiment.context_manager(_context_name):
         print(_banner)
         model.read_data(mode="submodel", validation_split=True)
         model.train(submodel=_submodel, class_weight=[class_weight] * 3)

 # Train the full model with the weighted loss.
 experiment.log_parameter("Class Weighted", True)
 model.read_data(validation_split=True)
 model.train(class_weight=class_weight, experiment=experiment)
 #Get Alpha score for the weighted spectral/spatial average. Higher alpha favors spatial network.
 if model.config["train"]["weighted_sum"]:
Ejemplo n.º 5
0
os.mkdir(save_dir)

# Record the configuration, an "RGB" tag and the run timestamp on Comet.
experiment = Experiment(project_name="neontrees", workspace="bw4sz")
for _section in ("train", "evaluation", "predict"):
    experiment.log_parameters(model.config[_section])
experiment.add_tag("RGB")
experiment.log_parameter("timestamp", timestamp)

## Train the RGB spatial subnetwork, then the RGB spectral subnetwork.
experiment.log_parameter("Train subnetworks", True)
with experiment.context_manager("RGB_spatial_subnetwork"):
    print("Train RGB spatial subnetwork")
    # Only RGB inputs are read; HSI and metadata inputs are switched off.
    model.read_data(HSI=False, RGB=True, metadata=False, submodel=False)
    model.train(submodel="spatial", sensor="RGB", experiment=experiment)

# NOTE(review): no read_data() call precedes this stage, so the spectral
# subnetwork trains on whatever the spatial stage loaded, and that data was read
# with submodel=False. Confirm this is intended.
with experiment.context_manager("RGB_spectral_subnetwork"):
    print("Train RGB spectral subnetwork")
    model.train(submodel="spectral", sensor="RGB", experiment=experiment)
#Train full model
with experiment.context_manager("RGB_model"):
    model.read_data(HSI=False, RGB=True, metadata=False)
    model.train(sensor="RGB", experiment=experiment)

    #Get Alpha score for the weighted spectral/spatial average. Higher alpha favors spatial network.
    if model.config["train"]["RGB"]["weighted_sum"]: