import os
from datetime import datetime

import numpy as np
import keras

# Project-local helpers referenced below (assumed defined elsewhere in this
# repo): RAND_SEED, load_filenames_2nd, keep_t2, load_data, save_code,
# dice_loss, dice_metric, image_callback_val, training_threshold, test_set_3D.


def load_train_volumes(only_t2=False, adaptive_hist=False):
    # Set up image path
    image_base_path = '/media/matt/Seagate Expansion Drive/MR Data/MR_Images_Sarcoma'

    # Set up data constants
    block_size = [18, 142, 142]
    oversamp_test = 1.0
    lab_trun = 2
    test_split = 0.1
    val_split = 0.2

    # Get filenames
    filenames = load_filenames_2nd(base_path=image_base_path)
    nfiles = len(filenames)

    # Yield the number of sets in the generator
    yield round((1 - (val_split + test_split)) * nfiles)

    # Keep only T2 images (note: nfiles is computed above, so keep_t2() is
    # assumed to preserve the list's length and indexing)
    if only_t2:
        filenames = keep_t2(filenames)

    # Shuffle indices reproducibly
    inds = np.array(range(nfiles), dtype=int)
    np.random.seed(RAND_SEED)
    np.random.shuffle(inds)
    mask = np.ones(inds.shape, dtype=bool)

    # Mask out validation (first val_split) and test (last test_split) data
    mask[:round(val_split * nfiles)] = 0
    mask[-round(test_split * nfiles):] = 0
    train_inds = inds[mask]  # apply the mask to select training indices
    train_files = [filenames[i] for i in train_inds]

    # Yield one training volume at a time, looping forever
    while True:
        for train_file in train_files:
            X_train, Y_train, _ = load_data([train_file], block_size,
                                            oversamp_test, lab_trun,
                                            adaptive_hist)
            yield [X_train, Y_train]
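# A minimal usage sketch for the generator above. This helper is hypothetical
# (not part of the original file): it assumes a compiled Keras `model` is
# passed in, and relies on the fact that the generator's first yield reports
# the training-set size, which Keras needs as steps_per_epoch.
def _demo_train_from_generator(model, epochs=600):
    gen = load_train_volumes(only_t2=True)
    steps_per_epoch = next(gen)  # first yield: number of training volumes
    model.fit_generator(gen, steps_per_epoch=steps_per_epoch, epochs=epochs)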
def statistical_metrics(spaths, only_t2=False, adaptive_hist=False):
    epochs = 600
    batch_size = 20
    block_size = [18, 142, 142]
    oversamp = 1.0
    oversamp_test = 1.0
    lab_trun = 2
    im_freq = 50
    val_split = 0.2
    test_split = 0.1
    lr = 2e-4

    # Set up image path
    image_base_path = '/media/matt/Seagate Expansion Drive/MR Data/MR_Images_Sarcoma'

    # Load training data
    filenames = load_filenames_2nd(base_path=image_base_path)
    nfiles = len(filenames)

    # Remove all but T2 images
    if only_t2:
        filenames = keep_t2(filenames)

    # Remove validation and test set
    inds = np.array(range(nfiles), dtype=int)
    np.random.seed(RAND_SEED)
    np.random.shuffle(inds)

    # Validation data
    val_inds = inds[:round(val_split * nfiles)]
    val_file = [filenames[i] for i in val_inds]

    # Test data
    test_inds = inds[-round(test_split * nfiles):]
    test_file = [filenames[i] for i in test_inds]

    # Remove validation and test files from the training list
    filenames = [
        filename for i, filename in enumerate(filenames)
        if i not in list(val_inds) + list(test_inds)
    ]

    # Load data
    x_test, y_test, orig_size_test = load_data(test_file, block_size,
                                               oversamp_test, lab_trun,
                                               adaptive_hist)
    print('Size of test set: \t\t', x_test.shape)

    for spath in spaths:
        # Display which network is being tested
        _, net_name = os.path.split(spath)
        print('\n\n\n')
        print('Testing: %s' % net_name)
        print('-' * 80 + '\n')

        # Load trained model
        model_path = os.path.join(spath, 'Trained_model.h5')
        model = keras.models.load_model(model_path,
                                        custom_objects={
                                            'dice_loss': dice_loss,
                                            'dice_metric': dice_metric
                                        })

        # Load best threshold
        file = os.path.join(spath, 'metrics2.txt')

        # Read metrics file
        with open(file, 'r') as f:
            dat = f.readlines()

        # Extract the numeric threshold from the seventh-from-last line
        ind = -7
        tmp = [i for i in dat[ind] if i.isdigit() or i == '.']
        threshold = float(''.join(tmp))
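# Illustration of the threshold parsing above, under the assumption (not
# confirmed by this file) that the relevant line of metrics2.txt looks like
# 'Best threshold: 0.42' and contains no other digits; the character filter
# would garble a line holding more than one number.
def _demo_parse_threshold(line='Best threshold: 0.42'):
    tmp = [c for c in line if c.isdigit() or c == '.']
    return float(''.join(tmp))  # -> 0.42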
def train_model(networks, spaths, only_t2):
    """
    Train the specified models.

    Args:
        networks (list): list of Keras network-builder functions to train
        spaths (list): list of output directories, each containing one '%s'
            placeholder that is filled with a timestamp
        only_t2 (bool): whether or not to only use T2 data

    Returns:
        None
    """
    epochs = 600
    batch_size = 20
    block_size = [18, 142, 142]
    oversamp = 1.0
    oversamp_test = 1.0
    lab_trun = 2
    im_freq = 50
    val_split = 0.2
    test_split = 0.1
    lr = 1e-4
    adaptive_hist = False

    # Set up image path
    image_base_path = '/media/matt/Seagate Expansion Drive/MR Data/MR_Images_Sarcoma'
    # image_base_path = 'E:/MR Data/MR_Images_Sarcoma'

    # Load training data
    filenames = load_filenames_2nd(base_path=image_base_path)
    nfiles = len(filenames)

    # Remove all but T2 images
    if only_t2:
        filenames = keep_t2(filenames)

    # Remove validation and test set
    inds = np.array(range(nfiles), dtype=int)
    np.random.seed(RAND_SEED)
    np.random.shuffle(inds)

    # Validation data
    val_inds = inds[:round(val_split * nfiles)]
    val_file = [filenames[i] for i in val_inds]

    # Test data
    test_inds = inds[-round(test_split * nfiles):]
    test_file = [filenames[i] for i in test_inds]

    # Remove validation and test files from the training list
    filenames = [
        filename for i, filename in enumerate(filenames)
        if i not in list(val_inds) + list(test_inds)
    ]

    # Load data
    print('Loading data')
    x, y, orig_size = load_data(filenames, block_size, oversamp, lab_trun,
                                adaptive_hist)
    # val_file = val_file[:1]
    x_val, Y_val, orig_size_val = load_data(val_file, block_size, oversamp,
                                            lab_trun, adaptive_hist)
    # test_file = test_file[:1]
    x_test, y_test, orig_size_test = load_data(test_file, block_size,
                                               oversamp_test, lab_trun,
                                               adaptive_hist)

    # Shuffle training data
    inds = np.arange(0, x.shape[0])
    np.random.seed(5)
    np.random.shuffle(inds)
    x = x[inds]
    y = y[inds]

    print('Size of training set:\t\t', x.shape)
    print('Size of validation set: \t', x_val.shape)
    print('Size of test set: \t\t', x_test.shape)
    sz_patch = x.shape

    for network, spath in zip(networks, spaths):
        # Set up save path
        spath = spath % datetime.strftime(datetime.now(), '%Y_%m_%d_%H-%M-%S')
        if not os.path.exists(spath):
            os.mkdir(spath)

        # Display which network is training
        _, net_name = os.path.split(spath)
        if only_t2:
            net_name += '_t2'
        print('\n\n\n')
        print('Training: %s' % net_name)
        print('-' * 80 + '\n')

        # Save a copy of this code
        save_code(spath, os.path.realpath(__file__))

        # Load model
        model, opt = network(pretrained_weights=None,
                             input_size=(sz_patch[1], sz_patch[2],
                                         sz_patch[3], sz_patch[4]),
                             lr=lr)

        # Set up callbacks
        tensorboard = keras.callbacks.TensorBoard(log_dir='logs/%s' % net_name,
                                                  write_graph=False,
                                                  write_grads=False,
                                                  write_images=False,
                                                  histogram_freq=0)
        ckpoint_weights = keras.callbacks.ModelCheckpoint(
            os.path.join(spath, 'ModelCheckpoint.h5'),
            monitor='val_loss',
            verbose=1,
            save_best_only=True,
            save_weights_only=True,
            mode='auto',
            period=10)
        image_recon_callback = image_callback_val(x_val, Y_val, spath,
                                                  orig_size_val,
                                                  block_size=block_size,
                                                  oversamp=oversamp_test,
                                                  lab_trun=lab_trun,
                                                  im_freq=im_freq,
                                                  batch_size=batch_size)

        # Train model
        model.fit(x=x,
                  y=y,
                  epochs=epochs,
                  batch_size=batch_size,
                  validation_data=(x_val, Y_val),
                  callbacks=[tensorboard, ckpoint_weights,
                             image_recon_callback])
        model.save(os.path.join(spath, 'Trained_model.h5'))

        # Calculate best threshold from training data
        threshold = training_threshold(model, spath, X=x, Y=y)
        # threshold = 0.5

        # Evaluate test data
        test_set_3D(model, x_test, y_test, spath, orig_size_test, block_size,
                    oversamp_test, lab_trun, batch_size=1, threshold=threshold)
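# Hypothetical driver for train_model(). `unet_3d` stands in for a model
# builder with the (pretrained_weights, input_size, lr) signature expected
# above, returning (model, optimizer); each save path must carry exactly one
# '%s' placeholder, which train_model() fills with a timestamp.
def _demo_train():
    networks = [unet_3d]
    spaths = ['/path/to/output/unet_3d_%s']
    train_model(networks, spaths, only_t2=True)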
def run_model_test_best_weights(spaths, only_t2):
    epochs = 600
    batch_size = 20
    block_size = [18, 142, 142]
    oversamp = 1.0
    oversamp_test = 1.0
    lab_trun = 2
    im_freq = 50
    val_split = 0.2
    test_split = 0.1
    lr = 2e-4
    adaptive_hist = False

    # Set up image path
    image_base_path = '/media/matt/Seagate Expansion Drive/MR Data/MR_Images_Sarcoma'

    # Load training data
    filenames = load_filenames_2nd(base_path=image_base_path)
    nfiles = len(filenames)

    # Remove all but T2 images
    if only_t2:
        filenames = keep_t2(filenames)

    # Remove validation and test set
    inds = np.array(range(nfiles), dtype=int)
    np.random.seed(RAND_SEED)
    np.random.shuffle(inds)

    # Validation data
    val_inds = inds[:round(val_split * nfiles)]
    val_file = [filenames[i] for i in val_inds]

    # Test data
    test_inds = inds[-round(test_split * nfiles):]
    test_file = [filenames[i] for i in test_inds]

    # Remove validation and test files from the training list
    filenames = [
        filename for i, filename in enumerate(filenames)
        if i not in list(val_inds) + list(test_inds)
    ]

    # Load data
    x, y, orig_size = load_data(filenames, block_size, oversamp, lab_trun,
                                adaptive_hist)
    x_test, y_test, orig_size_test = load_data(test_file, block_size,
                                               oversamp_test, lab_trun,
                                               adaptive_hist)
    print('Size of training set:\t\t', x.shape)
    print('Size of test set: \t\t', x_test.shape)

    for spath in spaths:
        # Display which network is being tested
        _, net_name = os.path.split(spath)
        print('\n\n\n')
        print('Testing: %s' % net_name)
        print('-' * 80 + '\n')

        # Load trained model
        model_path = os.path.join(spath, 'Trained_model.h5')
        model = keras.models.load_model(model_path,
                                        custom_objects={
                                            'dice_loss': dice_loss,
                                            'dice_metric': dice_metric
                                        })

        # Load best weights (selected on the validation set)
        weights_file = os.path.join(spath, 'ModelCheckpoint.h5')
        model.load_weights(weights_file)

        # Calculate best threshold from training data
        threshold = training_threshold(model, spath, X=x, Y=y)
        # threshold = 0.5

        # Evaluate test data
        test_set_3D(model, x_test, y_test, spath, orig_size_test, block_size,
                    oversamp_test, lab_trun, batch_size=20,
                    threshold=threshold, vols=len(test_file), continuous=True)
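# Hypothetical re-evaluation sketch, assuming each path in `spaths` is a
# completed training directory (created by train_model() above) that contains
# Trained_model.h5 and ModelCheckpoint.h5; the directory name shown is a
# placeholder, not a real run.
def _demo_rerun_best_weights():
    spaths = ['/path/to/output/unet_3d_2019_01_01_12-00-00']
    run_model_test_best_weights(spaths, only_t2=True)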