def default_config():
    """Define the default training configuration.

    The function returns nothing: every local assigned here (``train_data``,
    ``eval_data``, ``model_params``, ...) is presumably captured as a config
    entry by a config framework such as a Sacred ``@ex.config`` scope -- TODO
    confirm against the (not visible) decorator. Local names and assignment
    order are therefore part of the interface and must not be renamed.

    ``model_params['n_classes']`` is derived from ``prediction_type``: read
    from ``classes_file`` for CLASSIFICATION and MULTILABEL, fixed to 1 for
    REGRESSION. With the defaults below (CLASSIFICATION, ``classes_file =
    None``) the check fails unless ``classes_file`` is overridden by the caller.

    Raises:
        ValueError: if ``classes_file`` is None for CLASSIFICATION or
            MULTILABEL, or if ``prediction_type`` is not one of the three
            supported values.
    """
    train_data = None  # Directory with training data
    eval_data = None  # Directory with validation data
    model_output_dir = None  # Directory to output tf model
    restore_model = False  # Set to true to continue training
    classes_file = None  # txt file with classes values (unused for REGRESSION)
    gpu = ''  # GPU to be used for training
    prediction_type = utils.PredictionType.CLASSIFICATION  # One of CLASSIFICATION, REGRESSION or MULTILABEL
    pretrained_model_name = 'resnet50'  # Passed to utils.ModelParams as pretrained_model_name
    model_params = utils.ModelParams(pretrained_model_name=pretrained_model_name).to_dict()  # Model parameters
    training_params = utils.TrainingParams().to_dict()  # Training parameters

    # Set model_params['n_classes'] according to prediction_type. Explicit
    # raises (instead of ``assert``) keep these checks active under ``python -O``.
    if prediction_type == utils.PredictionType.CLASSIFICATION:
        if classes_file is None:
            raise ValueError('classes_file must be set when prediction_type is CLASSIFICATION')
        model_params['n_classes'] = utils.get_n_classes_from_file(classes_file)
    elif prediction_type == utils.PredictionType.REGRESSION:
        model_params['n_classes'] = 1
    elif prediction_type == utils.PredictionType.MULTILABEL:
        if classes_file is None:
            raise ValueError('classes_file must be set when prediction_type is MULTILABEL')
        model_params['n_classes'] = utils.get_n_classes_from_file_multilabel(classes_file)
    else:
        # Previously an unknown value silently left 'n_classes' unset.
        raise ValueError('Unknown prediction_type {!r}; expected one of CLASSIFICATION, '
                         'REGRESSION or MULTILABEL'.format(prediction_type))
def default_config():
    """Configuration for a U-Net (``pretrained_model_name = 'unet'``) run on the zistovi data.

    The function returns nothing: every local assigned here is presumably
    captured as a config entry by a config framework such as a Sacred
    ``@ex.config`` scope -- TODO confirm against the (not visible) decorator.
    Local names and assignment order are therefore part of the interface.

    NOTE(review): this file defines ``default_config`` more than once; as a
    plain module-level definition this one shadows the earlier one -- confirm
    which configuration is intended to be active.

    ``model_params['n_classes']`` is read from ``classes_file`` for
    CLASSIFICATION and MULTILABEL, and fixed to 1 for REGRESSION.

    Raises:
        ValueError: if ``classes_file`` is None for CLASSIFICATION or
            MULTILABEL, or if ``prediction_type`` is not one of the three
            supported values.
    """
    # NOTE(review): absolute, user-specific paths; override them per environment.
    train_data = "/scratch/users/ycan/dhSegment/zistovi/train/"  # Directory with training data
    # NOTE(review): eval_data is the same directory as train_data, so validation
    # metrics are computed on the training set -- confirm this is intended.
    eval_data = "/scratch/users/ycan/dhSegment/zistovi/train/"  # Directory with validation data
    model_output_dir = "model_zistovi_unet_all_100/"  # Directory to output tf model
    restore_model = False  # Set to true to continue training
    classes_file = "classes.txt"  # txt file with classes values (unused for REGRESSION)
    gpu = '0'  # GPU to be used for training
    prediction_type = utils.PredictionType.CLASSIFICATION  # One of CLASSIFICATION, REGRESSION or MULTILABEL
    pretrained_model_name = 'unet'  # Passed to utils.ModelParams as pretrained_model_name
    model_params = utils.ModelParams(pretrained_model_name=pretrained_model_name).to_dict()  # Model parameters
    training_params = utils.TrainingParams().to_dict()  # Training parameters

    # Set model_params['n_classes'] according to prediction_type. Explicit
    # raises (instead of ``assert``) keep these checks active under ``python -O``.
    if prediction_type == utils.PredictionType.CLASSIFICATION:
        if classes_file is None:
            raise ValueError('classes_file must be set when prediction_type is CLASSIFICATION')
        model_params['n_classes'] = utils.get_n_classes_from_file(classes_file)
    elif prediction_type == utils.PredictionType.REGRESSION:
        model_params['n_classes'] = 1
    elif prediction_type == utils.PredictionType.MULTILABEL:
        if classes_file is None:
            raise ValueError('classes_file must be set when prediction_type is MULTILABEL')
        model_params['n_classes'] = utils.get_n_classes_from_file_multilabel(classes_file)
    else:
        # Previously an unknown value silently left 'n_classes' unset.
        raise ValueError('Unknown prediction_type {!r}; expected one of CLASSIFICATION, '
                         'REGRESSION or MULTILABEL'.format(prediction_type))
import os