def process_plot_mri_images(paths, params):
    """
    Plot every slice of every dataset in an HDF5 group as a separate PNG.

    The datasets are read from group ``params['group_no_bg']`` of the HDF5 file
    ``paths['hdf5_folder']/params['hdf5_file']``. Each slice along the first axis
    of a dataset is saved, without axes or white space, to
    ``paths['plot_folder']/<group>/<last word of dataset name>/<dataset>/<slice>.png``.

    Parameters
    ----------
    paths : dict
        Project paths; needs 'hdf5_folder' and 'plot_folder'.
    params : dict
        Project parameters; needs 'hdf5_file' and 'group_no_bg'.
    """
    # full path of the HDF5 file that holds the datasets
    hdf5_file = os.path.join(paths['hdf5_folder'], params['hdf5_file'])

    # names of all datasets stored in the group
    datasets = get_datasets_from_group(group_name=params['group_no_bg'],
                                       hdf5_file=hdf5_file)

    for d_idx, d in enumerate(datasets):
        logging.info(f'Processing dataset : {d} {d_idx + 1}/{len(datasets)}')

        data = read_dataset_from_group(group_name=params['group_no_bg'],
                                       dataset=d,
                                       hdf5_file=hdf5_file)
        # read_dataset_from_group can return None when the dataset is missing
        if data is None:
            logging.warning(f'No data found for dataset : {d}, skipping...')
            continue

        # one folder per dataset, grouped by the last word of the dataset name
        image_plot_folder = os.path.join(paths['plot_folder'],
                                         params['group_no_bg'],
                                         d.split()[-1], d)
        create_directory(image_plot_folder)

        # one image for every slice along the first axis
        for slice_idx in range(data.shape[0]):
            fig, ax = plt.subplots(1, 1, figsize=(10, 10))
            ax.imshow(data[slice_idx], cmap='gray', vmax=1000)

            # remove all white space and axes so only the image itself remains
            ax.set_axis_off()
            plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
            plt.margins(0, 0)
            ax.xaxis.set_major_locator(plt.NullLocator())
            ax.yaxis.set_major_locator(plt.NullLocator())

            fig.savefig(os.path.join(image_plot_folder, f'{slice_idx}.png'), dpi=300)
            # close exactly this figure so figures do not accumulate in memory
            plt.close(fig)
def process_convert_segmentation_to_features(paths, params, verbose=True):
    """
    Convert NIfTI segmentation files into patch features (X) and labels (Y).

    For every ``.nii`` / ``.nii.gz`` file in ``paths['segmentation_folder']``:

    - the original MRI images of the matching patient are read from the HDF5
      group ``params['group_original_mri']``;
    - every ``params['feature_size']`` window that lies completely inside one
      segmentation class is cut from the original MRI slice and becomes a
      feature, with that class as its label;
    - the result is saved to ``paths['feature_folder']/<patient>/<patient>.npz``
      with arrays ``X`` (patches stacked on axis 0) and ``Y`` (one label per row).

    Parameters
    ----------
    paths : dict
        Needs 'segmentation_folder', 'hdf5_folder' and 'feature_folder'.
    params : dict
        Needs 'feature_size' (rows, columns), 'group_original_mri' and 'hdf5_file'.
    verbose : bool, optional
        Emit extra debug messages per slice and per segmentation class.

    Notes
    -----
    Calls ``exit(1)`` when no original MRI image is found for a patient.
    """
    nifti_extensions = ('.nii.gz', '.nii')

    # all segmentation files in the segmentation folder
    files = [x for x in read_directory(paths['segmentation_folder'])
             if x.endswith(nifti_extensions)]

    # height and width of a feature patch
    feature_rows, feature_cols = params['feature_size'][0], params['feature_size'][1]

    hdf5_file = os.path.join(paths['hdf5_folder'], params['hdf5_file'])

    for f_idx, file in enumerate(files):
        logging.info(f'Processing segmentation file : {file} {f_idx + 1}/{len(files)}')

        # patient name is the file name without its NIfTI extension
        # (a fixed [:-7] slice would be wrong for '.nii' files)
        file_name = os.path.basename(file)
        extension = next(ext for ext in nifti_extensions if file_name.endswith(ext))
        patient = file_name[:-len(extension)]

        original_images = read_dataset_from_group(
            group_name=params['group_original_mri'],
            dataset=patient,
            hdf5_file=hdf5_file)
        if original_images is None:
            logging.error(
                f'No original image found, please check patient name : {patient}'
            )
            exit(1)

        # segmentation volume; slices are indexed along the third axis
        images = nib.load(file)

        X = []
        Y = []

        for mri_slice in range(images.shape[2]):
            if verbose:
                logging.debug(f'Slice : {mri_slice}')

            img = images.dataobj[:, :, mri_slice]

            # skip slices without any segmentation
            if np.sum(img) == 0:
                if verbose:
                    logging.debug('No segmentations found, skipping...')
                continue

            # flip and rotate so the slice has the same orientation as the
            # original DICOM images read into python
            img = np.rot90(np.flip(img, 1))

            # segmentation classes in this slice, without the background class 0
            seg_classes = np.unique(img)
            seg_classes = seg_classes[seg_classes != 0]

            for seg_class in seg_classes:
                if verbose:
                    logging.debug(f'Processing segmentation class : {seg_class}')

                class_mask = img == seg_class

                # first/last row and column that contain this class
                rows = np.argwhere(np.any(class_mask, axis=1))
                cols = np.argwhere(np.any(class_mask, axis=0))
                min_rows, max_rows = rows[0][0], rows[-1][0]
                min_cols, max_cols = cols[0][0], cols[-1][0]
                logging.debug(f'Processing rows: {min_rows}-{max_rows}')
                logging.debug(f'Processing cols: {min_cols}-{max_cols}')

                # the last start index whose patch still ends inside the class
                # bounds is max - size + 1; range() excludes its stop, hence + 2
                for i in range(min_rows, max_rows - feature_rows + 2):
                    for j in range(min_cols, max_cols - feature_cols + 2):
                        # keep the patch only when every cell has this class
                        if np.all(class_mask[i:i + feature_rows, j:j + feature_cols]):
                            patch = original_images[mri_slice][i:i + feature_rows,
                                                               j:j + feature_cols]
                            X.append([patch])
                            Y.append([seg_class])

        # np.vstack fails on an empty list, so skip patients without patches
        if not X:
            logging.warning(f'No features extracted for patient : {patient}, nothing saved')
            continue

        X = np.vstack(X)
        Y = np.vstack(Y)

        save_folder = os.path.join(paths['feature_folder'], patient)
        create_directory(save_folder)
        np.savez(file=os.path.join(save_folder, f'{patient}.npz'), X=X, Y=Y)
def get_paths(env=None, create_folders=True):
    """
    Build the dictionary of project paths for a machine.

    Parameters
    ----------
    env : str, optional
        Host name of the machine. When None it is obtained from get_environment().
    create_folders : bool, optional
        When True, every path (except the project name and None entries) is
        passed to create_directory.

    Returns
    -------
    dict
        Keys 'project_name', 'base_path' and one entry per data/plot/model folder.
        'augmentation_folder' is None (augmentation output is disabled).

    Notes
    -----
    Logs an error and calls exit(1) for an unknown host name.
    """
    if env is None:
        env = get_environment()

    project_name = 'cod_supervised_classification'

    # (host names, path components below the filesystem root) per machine:
    # local MacBook, Nofima GPU workstation, UIT GPU workstation
    machines = (
        (('Shaheens-MacBook-Pro-2.local', 'shaheens-mbp-2.lan'),
         ('Users', 'shaheen.syed', 'data', 'projects')),
        (('shaheensyed-gpu',), ('home', 'shaheensyed', 'projects')),
        (('shaheengpu',), ('home', 'shaheen', 'projects')),
    )

    base_path = None
    for host_names, components in machines:
        if env in host_names:
            base_path = os.path.join(os.sep, *components, project_name)
            break
    else:
        logging.error(f'Environment {env} not implemented.')
        exit(1)

    data_path = os.path.join(base_path, 'data')
    plot_path = os.path.join(base_path, 'plots')

    paths = {
        'project_name': project_name,
        'base_path': base_path,
        # original MRI data in DICOM format
        'mri_folder': os.path.join(data_path, 'mri'),
        'hdf5_folder': os.path.join(data_path, 'hdf5'),
        # .dcm files with the new patient name
        'dcm_folder': os.path.join(data_path, 'dcm'),
        'segmentation_folder': os.path.join(data_path, 'segmentations'),
        'feature_folder': os.path.join(data_path, 'features'),
        # augmentation output disabled; would be os.path.join(data_path, 'augmentation')
        'augmentation_folder': None,
        'dataset_folder': os.path.join(data_path, 'datasets'),
        'plot_folder': plot_path,
        # paper-ready plots
        'paper_plot_folder': os.path.join(plot_path, 'paper_plots'),
        'table_folder': os.path.join(data_path, 'tables'),
        'model_folder': os.path.join(base_path, 'models'),
    }

    if create_folders:
        to_create = [folder for folder in paths.values()
                     if folder is not None and folder != paths['project_name']]
        for folder in to_create:
            create_directory(folder)

    return paths
def _read_original_and_damaged(dataset, params, hdf5_file):
    """
    Read one dataset's original MRI images and its damaged-tissue mask.

    Returns (original images, boolean mask that is True where the output of
    process_connected_tissue equals 1), or (None, None) when either the
    original or the segmented classification dataset is missing.
    """
    original = read_dataset_from_group(
        dataset=dataset,
        group_name=params['group_original_mri'],
        hdf5_file=hdf5_file)
    reconstructed = read_dataset_from_group(
        dataset=dataset,
        group_name=params['group_segmented_classification_mri'],
        hdf5_file=hdf5_file)
    if original is None or reconstructed is None:
        return None, None
    # pass a copy, presumably so process_connected_tissue cannot alter the data read
    damaged = process_connected_tissue(images=reconstructed.copy(), params=params) == 1
    return original, damaged


def _plot_original_and_damaged(ax_original, ax_overlay, original, damaged,
                               state_title, panel_labels, cmap, vmax_percentile):
    """Plot one slice on ax_original and the same slice with the damaged mask overlaid on ax_overlay."""
    # clip the grayscale range at a high percentile for better normalised images
    vmax = np.percentile(original, vmax_percentile)

    ax_original.imshow(original, cmap='gray', vmin=0, vmax=vmax)
    ax_original.set_title(rf'$\bf({panel_labels[0]})$ {state_title} - Original MRI')

    ax_overlay.imshow(original, cmap='gray', vmin=0, vmax=vmax)
    ax_overlay.imshow(damaged, cmap=cmap, alpha=.5, interpolation='none')
    ax_overlay.set_title(rf'$\bf({panel_labels[1]})$ {state_title} - Reconstructed')


def process_plot_mri_with_damaged(paths, params):
    """
    Plot original MRI images next to the same images with damaged tissue overlaid.

    For every patient with datasets '<patient> fersk' (fresh) and
    '<patient> Tint' (frozen/thawed), one 2x2 PDF per slice that is valid for
    both states is written to
    ``paths['paper_plot_folder']/original_vs_reconstructed/<patient>/slice_<n>.pdf``.
    The top row shows the fresh state and the bottom row the frozen/thawed state.
    Each row has the original image and the original with the damaged-tissue
    mask overlaid.

    Parameters
    ----------
    paths : dict
        Needs 'hdf5_folder' and 'paper_plot_folder'.
    params : dict
        Needs 'hdf5_file', 'group_original_mri' and
        'group_segmented_classification_mri'; also passed to process_connected_tissue.
    """
    hdf5_file = os.path.join(paths['hdf5_folder'], params['hdf5_file'])

    # plot settings shared by every patient and slice
    vmax_percentile = 99.9
    plot_colors = ['#250463', '#e34a33']
    # listed colormap so background / damaged tissue get exactly these colours
    cmap = colors.ListedColormap(plot_colors)
    class_labels = {0: 'background', 1: 'damaged tissue'}

    dataset_names = get_datasets_from_group(group_name=params['group_original_mri'],
                                            hdf5_file=hdf5_file)

    # patient names without the state suffix
    patient_names = set()
    for name in dataset_names:
        match = re.search(r'(.*) (fersk|Tint)', name)
        if match is None:
            logging.warning(f'No fersk/Tint state in dataset name, skipping : {name}')
            continue
        patient_names.add(match.group(1))
    # sorted so the processing order is deterministic
    patients = sorted(patient_names)

    for i, patient in enumerate(patients):
        logging.info(f'Processing patient: {patient} {i + 1}/{len(patients)}')

        treatment, _, _ = parse_patientname(patient_name=f'{patient} fersk')
        treatment_title = treatment_to_title(treatment)

        fresh_original_images, fresh_damaged_images = _read_original_and_damaged(
            dataset=f'{patient} fersk', params=params, hdf5_file=hdf5_file)
        frozen_original_images, frozen_damaged_images = _read_original_and_damaged(
            dataset=f'{patient} Tint', params=params, hdf5_file=hdf5_file)
        if fresh_original_images is None or frozen_original_images is None:
            logging.warning(f'Missing fresh or frozen/thawed data for patient : {patient}, skipping...')
            continue

        # NOTE(review): the frozen/thawed scan is indexed with the fresh slice
        # count; this assumes both scans have the same number of slices - confirm.
        total_num_slices = fresh_original_images.shape[0]
        subfolder = os.path.join(paths['paper_plot_folder'], 'original_vs_reconstructed', patient)

        for mri_slice in range(total_num_slices):
            # only plot slices that are valid for both the fresh and frozen/thawed state
            if not (check_mri_slice_validity(patient=f'{patient} fersk',
                                             mri_slice=mri_slice,
                                             total_num_slices=total_num_slices)
                    and check_mri_slice_validity(patient=f'{patient} Tint',
                                                 mri_slice=mri_slice,
                                                 total_num_slices=total_num_slices)):
                continue

            fig, axs = plt.subplots(2, 2, figsize=(8, 8))
            axs = axs.ravel()

            _plot_original_and_damaged(axs[0], axs[1],
                                       fresh_original_images[mri_slice],
                                       fresh_damaged_images[mri_slice],
                                       'Fresh', ('a', 'b'), cmap, vmax_percentile)
            _plot_original_and_damaged(axs[2], axs[3],
                                       frozen_original_images[mri_slice],
                                       frozen_damaged_images[mri_slice],
                                       treatment_title, ('c', 'd'), cmap, vmax_percentile)

            # custom legend: one coloured patch per class
            legend_patches = [mpatches.Patch(color=plot_colors[value], label=label)
                              for value, label in class_labels.items()]
            axs[1].legend(handles=legend_patches)

            for ax in axs:
                ax.axis('off')

            create_directory(subfolder)
            # crop white space
            fig.set_tight_layout(True)
            fig.savefig(os.path.join(subfolder, f'slice_{mri_slice}.pdf'))
            plt.close(fig)
def _trim_to_batch_multiple(X, Y, batch_size):
    """Drop trailing samples from X and Y so their length is a multiple of batch_size."""
    batch_trim = len(X) % batch_size
    if batch_trim != 0:
        X, Y = X[:-batch_trim], Y[:-batch_trim]
    return X, Y


def _build_1d_cnn(cnn_type, input_shape):
    """Return an uncompiled binary-output Sequential 1D CNN for architecture 'v1'-'v4'."""
    # cnn_type: (filters, kernel_size, conv layers, max-pool after each conv, hidden dense layers)
    architectures = {
        'v1': (10, 10, 1, False, 1),
        'v2': (10, 10, 3, False, 3),
        'v3': (20, 50, 3, False, 3),
        'v4': (10, 10, 3, True, 3),
    }
    if cnn_type not in architectures:
        raise ValueError(f'Unknown cnn_type {cnn_type!r}, expected one of {sorted(architectures)}')
    filters, kernel_size, num_conv, use_pooling, num_dense = architectures[cnn_type]

    model = keras.models.Sequential()
    for layer_idx in range(num_conv):
        # only the first layer declares the input shape
        shape_kwargs = {'input_shape': input_shape} if layer_idx == 0 else {}
        model.add(keras.layers.Conv1D(filters=filters, kernel_size=kernel_size,
                                      activation='relu', **shape_kwargs))
        if use_pooling:
            model.add(keras.layers.MaxPooling1D(pool_size=2))
    model.add(keras.layers.Flatten())
    for _ in range(num_dense):
        model.add(keras.layers.Dense(50, activation='relu'))
    model.add(keras.layers.Dense(1, activation='sigmoid'))
    return model


def create_1d_cnn_non_wear_episodes(X, Y, save_model_folder, model_name, cnn_type,
                                    epoch, train_split, dev_split, return_model=False):
    """
    Train, evaluate and save a binary 1D CNN classifier.

    X and Y are split in their given order: the first ``train_split`` fraction
    is used for training, the next ``dev_split`` fraction for validation, and
    the rest for testing. Each split is trimmed to a multiple of the batch size.

    Parameters
    ----------
    X, Y : array-like
        Features and binary labels; ``X.shape[1:]`` is the model input shape.
    save_model_folder : str
        Root folder; results go to ``save_model_folder/<cnn_type>/``.
    model_name : str
        File name of the saved model and prefix of the history CSV files.
    cnn_type : {'v1', 'v2', 'v3', 'v4'}
        Architecture to build.
    epoch : int
        Maximum number of training epochs (early stopping may stop sooner).
    train_split, dev_split : float
        Fractions of the data used for training and development.
    return_model : bool, optional
        When True return (model, history); otherwise return None.

    Raises
    ------
    ValueError
        For an unknown cnn_type.
    """
    buffer_size = 64
    batch_size = 32

    logging.info('Training split : {}, development split : {}'.format(train_split, dev_split))
    train_size, dev_size = int(len(X) * train_split), int(len(X) * dev_split)
    logging.info('Training size : {}, development size : {}'.format(train_size, dev_size))

    # consecutive train / dev / test slices (no shuffling before the split)
    dev_end = train_size + dev_size
    X_train, X_dev, X_test = X[:train_size], X[train_size:dev_end], X[dev_end:]
    Y_train, Y_dev, Y_test = Y[:train_size], Y[train_size:dev_end], Y[dev_end:]

    # equal-sized batches only
    X_train, Y_train = _trim_to_batch_multiple(X_train, Y_train, batch_size)
    X_dev, Y_dev = _trim_to_batch_multiple(X_dev, Y_dev, batch_size)
    X_test, Y_test = _trim_to_batch_multiple(X_test, Y_test, batch_size)

    logging.info(f'X_train : {X_train.shape}, Y_train: {Y_train.shape}')
    logging.info(f'X_dev : {X_dev.shape}, Y_dev: {Y_dev.shape}')
    logging.info(f'X_test : {X_test.shape}, Y_test: {Y_test.shape}')

    def to_dataset(features, labels):
        """Shuffled and batched tf.data dataset of (features, labels)."""
        return tf.data.Dataset.from_tensor_slices((features, labels)).shuffle(buffer_size).batch(batch_size)

    train_dataset = to_dataset(X_train, Y_train)
    dev_dataset = to_dataset(X_dev, Y_dev)
    test_dataset = to_dataset(X_test, Y_test)

    # use multiple GPUs
    mirrored_strategy = tf.distribute.MirroredStrategy()

    METRICS = [
        keras.metrics.TruePositives(name='tp'),
        keras.metrics.FalsePositives(name='fp'),
        keras.metrics.TrueNegatives(name='tn'),
        keras.metrics.FalseNegatives(name='fn'),
        keras.metrics.BinaryAccuracy(name='accuracy'),
        keras.metrics.Precision(name='precision'),
        keras.metrics.Recall(name='recall'),
        keras.metrics.AUC(name='auc'),
    ]

    callback_list = [
        EarlyStopping(monitor='val_loss',
                      restore_best_weights=True,
                      min_delta=0,
                      patience=25,
                      verbose=1,
                      mode='auto')
    ]

    with mirrored_strategy.scope():
        model = _build_1d_cnn(cnn_type, X.shape[1:])
        # `learning_rate` replaces the deprecated `lr` keyword rejected by newer Keras
        model.compile(optimizer=keras.optimizers.Adam(learning_rate=1e-3),
                      loss=keras.losses.BinaryCrossentropy(),
                      metrics=METRICS)
        history = model.fit(train_dataset,
                            epochs=epoch,
                            validation_data=dev_dataset,
                            callbacks=callback_list)

    history_test = model.evaluate(test_dataset)
    # row order follows the loss followed by METRICS as compiled above
    df_history_test = pd.DataFrame(history_test, index=[
        'test_loss', 'test_tp', 'test_fp', 'test_tn', 'test_fn',
        'test_accuracy', 'test_precision', 'test_recall', 'test_auc'
    ])

    model_folder = os.path.join(save_model_folder, cnn_type)
    create_directory(model_folder)
    model.save(os.path.join(model_folder, model_name))
    pd.DataFrame(history.history).to_csv(os.path.join(model_folder, f'{model_name}_history.csv'))
    df_history_test.to_csv(os.path.join(model_folder, f'{model_name}_history_test.csv'))

    if return_model:
        return model, history
def plot_segmented_images(paths, params):
    """
    Plot the valid slices of every segmented MRI dataset in one grid figure.

    Datasets come from group ``params['group_segmented_classification_mri']``.
    One PNG per patient is saved to
    ``paths['plot_folder']/segmentation/<params['cnn_model']>/<patient>.png``.
    The legend maps class values to ``params['class_labels']``.

    Parameters
    ----------
    paths : dict
        Needs 'hdf5_folder' and 'plot_folder'.
    params : dict
        Needs 'hdf5_file', 'group_segmented_classification_mri',
        'class_labels' (class value -> label) and 'cnn_model'.
    """
    hdf5_file = os.path.join(paths['hdf5_folder'], params['hdf5_file'])

    patients = get_datasets_from_group(group_name=params['group_segmented_classification_mri'],
                                       hdf5_file=hdf5_file)

    class_labels = params['class_labels']
    plot_sub_folder = os.path.join(paths['plot_folder'], 'segmentation', params['cnn_model'])
    num_cols = 9

    for i, patient in enumerate(patients):
        logging.info(f'Processing patient: {patient} {i + 1}/{len(patients)}')

        images = read_dataset_from_group(dataset=patient,
                                         group_name=params['group_segmented_classification_mri'],
                                         hdf5_file=hdf5_file)
        if images is None:
            logging.warning(f'No segmented images found for patient : {patient}, skipping...')
            continue

        num_slices = images.shape[0]
        # keep the 6 x 9 grid, but add rows when there are more than 54 slices
        num_rows = max(6, -(-num_slices // num_cols))
        fig, axs = plt.subplots(num_rows, num_cols, figsize=(20, 20))
        axs = axs.ravel()

        im = None
        for mri_slice in range(num_slices):
            logging.debug(f'Processing slice: {mri_slice}')
            if not check_mri_slice_validity(patient=patient,
                                            mri_slice=mri_slice,
                                            total_num_slices=num_slices):
                continue
            im = axs[mri_slice].imshow(images[mri_slice], vmin=0, vmax=5, interpolation='none')
            axs[mri_slice].set_title(f'{mri_slice}')

        # legend only when at least one slice was plotted (im is needed for the colours)
        if im is not None:
            # colours imshow assigns to each class value
            legend_colors = [im.cmap(im.norm(value)) for value in class_labels]
            patches = [mpatches.Patch(color=color, label=label)
                       for color, label in zip(legend_colors, class_labels.values())]
            # last subplot, which is the axes plt.legend() targeted
            axs[-1].legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)

        for ax in axs:
            ax.axis('off')

        create_directory(plot_sub_folder)
        # crop white space
        fig.set_tight_layout(True)
        fig.savefig(os.path.join(plot_sub_folder, f'{patient}.png'))
        plt.close(fig)
def create_image_data_generator(x,
                                y,
                                batch_size,
                                rescale=None,
                                rotation_range=None,
                                width_shift_range=None,
                                height_shift_range=None,
                                shear_range=None,
                                zoom_range=None,
                                horizontal_flip=None,
                                vertical_flip=None,
                                brightness_range=None,
                                save_to_dir=None,
                                seed=42):
    """
    Create an image data generator for tensorflow.

    Every augmentation argument that is not None is forwarded to
    ``ImageDataGenerator``; arguments left as None are not passed, so the
    ImageDataGenerator default applies.

    Parameters
    ------------
    x : np.ndarray or os.path
        X features. Either a numpy array, or a path that is loaded with np.load.
    y : np.ndarray or os.path
        Y labels. Either a numpy array, or a path that is loaded with np.load.
    batch_size : int
        Batch size passed to ``flow``.
    rescale, rotation_range, width_shift_range, height_shift_range, shear_range,
    zoom_range, horizontal_flip, vertical_flip, brightness_range : optional
        Forwarded to ImageDataGenerator when not None.
    save_to_dir : path or None, optional
        Folder passed to ``flow``; created first when not None.
    seed : int, optional
        Random seed passed to ``flow``.

    Returns
    -------
    The iterator returned by ``ImageDataGenerator.flow``.
    """
    augmentation_args = {
        'rescale': rescale,
        'rotation_range': rotation_range,
        'width_shift_range': width_shift_range,
        'height_shift_range': height_shift_range,
        'shear_range': shear_range,
        'zoom_range': zoom_range,
        'horizontal_flip': horizontal_flip,
        'vertical_flip': vertical_flip,
        'brightness_range': brightness_range,
    }
    # forward only the arguments that were actually given
    img_args = {name: value for name, value in augmentation_args.items() if value is not None}

    if save_to_dir is not None:
        create_directory(save_to_dir)

    image_data_generator = ImageDataGenerator(**img_args)

    # accept arrays (including ndarray subclasses) directly, load anything else from disk
    x = x if isinstance(x, np.ndarray) else np.load(x)
    y = y if isinstance(y, np.ndarray) else np.load(y)

    return image_data_generator.flow(x=x,
                                     y=y,
                                     batch_size=batch_size,
                                     seed=seed,
                                     save_to_dir=save_to_dir)
def train_cnn_classifier(paths, params):
    """
    Train, evaluate and save a CNN classifier for every architecture in the grid.

    For each architecture a timestamped folder is created in
    ``paths['model_folder']``. The folder receives:

    - optional checkpoints;
    - 'model.h5';
    - 'history_training.csv';
    - 'history_test.csv';
    - 'params.csv'.

    Parameters
    -----------
    paths : dict
        Needs 'dataset_folder', 'model_folder' and 'augmentation_folder'
        (None disables saving augmented images).
    params : dict
        Needs 'rescale_factor'. It is updated in place with the learning
        parameters of the architecture being trained, and saved with each model.
    """
    # grid search variables
    cnn_architectures = ['v6']

    # the datasets do not change between architectures, so inspect them once
    datasets = get_datasets_paths(paths['dataset_folder'])
    num_classes = len(np.unique(np.load(datasets['Y_train'])))
    # memory-map X_train: only its shape is needed here
    input_shape = np.load(datasets['X_train'], mmap_mode='r').shape

    for cnn_architecture in cnn_architectures:

        # checkpoint and final model folder
        model_save_folder = os.path.join(paths['model_folder'], get_current_timestamp())
        create_directory(model_save_folder)

        """ DEFINE LEARNING PARAMETERS """
        params.update({'ARCHITECTURE': cnn_architecture,
                       'NUM_CLASSES': num_classes,
                       'LR': .05,
                       'OPTIMIZER': 'sgd',
                       'TRAIN_SHAPE': input_shape,
                       'INPUT_SHAPE': input_shape[1:],
                       'BATCH_SIZE': 32,
                       'EPOCHS': 100,
                       'ES': True,
                       'ES_PATIENCE': 20,
                       'ES_RESTORE_WEIGHTS': True,
                       'SAVE_CHECKPOINTS': True,
                       'RESCALE': params['rescale_factor'],
                       'ROTATION_RANGE': None,
                       'WIDTH_SHIFT_RANGE': None,
                       'HEIGHT_SHIFT_RANGE': None,
                       'SHEAR_RANGE': None,
                       'ZOOM_RANGE': None,
                       'HORIZONTAL_FLIP': False,
                       'VERTICAL_FLIP': False,
                       'BRIGHTNESS_RANGE': None,
                       })

        """ DATAGENERATORS """
        # training data with the configured augmentation
        train_generator = create_image_data_generator(x=datasets['X_train'],
                                                      y=datasets['Y_train'],
                                                      batch_size=params['BATCH_SIZE'],
                                                      rescale=params['RESCALE'],
                                                      rotation_range=params['ROTATION_RANGE'],
                                                      width_shift_range=params['WIDTH_SHIFT_RANGE'],
                                                      height_shift_range=params['HEIGHT_SHIFT_RANGE'],
                                                      shear_range=params['SHEAR_RANGE'],
                                                      zoom_range=params['ZOOM_RANGE'],
                                                      horizontal_flip=params['HORIZONTAL_FLIP'],
                                                      vertical_flip=params['VERTICAL_FLIP'],
                                                      brightness_range=params['BRIGHTNESS_RANGE'],
                                                      save_to_dir=paths['augmentation_folder'])
        # validation and test data are only rescaled
        val_generator = create_image_data_generator(x=datasets['X_val'],
                                                    y=datasets['Y_val'],
                                                    batch_size=params['BATCH_SIZE'],
                                                    rescale=params['RESCALE'])
        test_generator = create_image_data_generator(x=datasets['X_test'],
                                                     y=datasets['Y_test'],
                                                     batch_size=params['BATCH_SIZE'],
                                                     rescale=params['RESCALE'])

        """ CALLBACKS """
        callback_list = []

        if params['ES']:
            callback_list.append(EarlyStopping(monitor='val_loss',
                                               min_delta=0,
                                               patience=params['ES_PATIENCE'],
                                               restore_best_weights=params['ES_RESTORE_WEIGHTS'],
                                               verbose=1,
                                               mode='auto'))

        if params['SAVE_CHECKPOINTS']:
            checkpoint_folder = os.path.join(model_save_folder, 'checkpoints')
            create_directory(checkpoint_folder)
            callback_list.append(ModelCheckpoint(
                filepath=os.path.join(checkpoint_folder,
                                      'checkpoint_model.{epoch:02d}_{val_loss:.3f}_{val_accuracy:.3f}.h5'),
                save_weights_only=False,
                monitor='val_loss',
                mode='auto',
                save_best_only=True))

        """ TRAIN CNN MODEL """
        # use multiple GPUs
        mirrored_strategy = distribute.MirroredStrategy()

        with mirrored_strategy.scope():
            model = get_cnn_model(cnn_type=params['ARCHITECTURE'],
                                  input_shape=params['INPUT_SHAPE'],
                                  num_classes=params['NUM_CLASSES'],
                                  learning_rate=params['LR'],
                                  optimizer_name=params['OPTIMIZER'])

            history = model.fit(train_generator,
                                epochs=params['EPOCHS'],
                                steps_per_epoch=len(train_generator),
                                validation_data=val_generator,
                                validation_steps=len(val_generator),
                                callbacks=callback_list)

        history_test = model.evaluate(test_generator)

        model.save(os.path.join(model_save_folder, 'model.h5'))
        pd.DataFrame(history.history).to_csv(os.path.join(model_save_folder, 'history_training.csv'))
        # NOTE(review): assumes get_cnn_model compiles with exactly one metric named accuracy
        pd.DataFrame(history_test, index=['loss', 'accuracy']).to_csv(
            os.path.join(model_save_folder, 'history_test.csv'))
        # model hyperparameters
        pd.DataFrame(pd.Series(params)).to_csv(os.path.join(model_save_folder, 'params.csv'))