def run(gParameters):
    """Train the U-Net example with hyper-parameters taken from ``gParameters``.

    Keys read from ``gParameters``: ``'activation'``, ``'kernel_initializer'``,
    ``'optimizer'``, ``'batch_size'`` and ``'epochs'``. 30% of the data is held
    out for validation (``validation_split=0.3``), and the model with the best
    training ``loss`` seen so far is checkpointed to ``unet.hdf5``.

    :param gParameters: mapping with the hyper-parameters listed above.
    :return: the object returned by ``model.fit``.
    """
    x_data, y_data = unet.load_data()

    # The example images are 420 x 580.
    rows, cols = 420, 580
    model = unet.build_model(rows, cols,
                             gParameters['activation'],
                             gParameters['kernel_initializer'])
    model.summary()
    model.compile(optimizer=gParameters['optimizer'],
                  loss='binary_crossentropy',
                  metrics=['accuracy'])

    best_checkpoint = ModelCheckpoint('unet.hdf5',
                                      monitor='loss',
                                      verbose=1,
                                      save_best_only=True)

    return model.fit(x_data, y_data,
                     batch_size=gParameters['batch_size'],
                     epochs=gParameters['epochs'],
                     verbose=1,
                     validation_split=0.3,
                     shuffle=True,
                     callbacks=[best_checkpoint])
def train():
    """Build, finalize and train a U-Net on the Oxford-IIIT Pet dataset.

    The network has depth 4 with 64 root filters and "same" padding. It is
    finalized with sparse categorical cross-entropy/accuracy (AUC disabled) and
    trained for 25 epochs with batch size 1.

    :return: the trained U-Net model.
    """
    image_size = oxford_iiit_pet.IMAGE_SIZE
    model = unet.build_model(*image_size,
                             channels=oxford_iiit_pet.channels,
                             num_classes=oxford_iiit_pet.classes,
                             layer_depth=4,
                             filters_root=64,
                             padding="same")

    unet.finalize_model(model,
                        loss=losses.SparseCategoricalCrossentropy(),
                        metrics=[metrics.SparseCategoricalAccuracy()],
                        auc=False,
                        learning_rate=LEARNING_RATE)

    pet_trainer = unet.Trainer(name="oxford_iiit_pet")
    train_ds, val_ds = oxford_iiit_pet.load_data()
    pet_trainer.fit(model, train_ds, val_ds, epochs=25, batch_size=1)
    return model
def train():
    """Build, finalize and train a small U-Net on the synthetic circles data.

    The trainer uses the ``WARMUP_LINEAR_DECAY`` scheduler type with a warm-up
    proportion of 0.1. The train, validation and test splits returned by
    ``circles.load_data`` are all handed to ``trainer.fit``.

    :return: the trained U-Net model.
    """
    model = unet.build_model(channels=circles.channels,
                             num_classes=circles.classes,
                             layer_depth=3,
                             filters_root=16)
    unet.finalize_model(model, learning_rate=LEARNING_RATE)

    circle_trainer = unet.Trainer(
        name="circles",
        learning_rate_scheduler=unet.SchedulerType.WARMUP_LINEAR_DECAY,
        warmup_proportion=0.1,
        learning_rate=LEARNING_RATE,
    )

    # 100 generated samples, nx = ny = 272, r_max = 20.
    splits = circles.load_data(100, nx=272, ny=272, r_max=20)
    train_ds, val_ds, test_ds = splits
    circle_trainer.fit(model, train_ds, val_ds, test_ds,
                       epochs=25, batch_size=5)
    return model
def run(
        model_name,
        model_type=None,
        loss=unet.IOU_inverse,
        flip=False, translate=False, rotate=False, brightness=False,  # augmentation
        batchnorm=False, droprate=None, regularizer=None,  # model
):
    """Train one U-Net variant with early stopping and checkpointing, then save it.

    The best model seen during training (``save_best_only=True``) goes to
    ``OUTDIR/<model_name>-checkpoint.h5``. The model at the end of training
    goes to ``OUTDIR/<model_name>.h5``.

    :param model_name: name given to the model and used for the output files.
    :param model_type: accepted for compatibility; not used by this function.
    :param loss: loss passed to ``unet.build_model``.
    :param flip, translate, rotate, brightness: augmentation switches passed
        to the training ``Generator``.
    :param batchnorm, droprate, regularizer: architecture options passed to
        ``unet.build_model``.
    """
    print(model_name)

    checkpoint_file = os.path.join(
        OUTDIR, '{}-checkpoint.h5'.format(model_name))
    final_file = os.path.join(OUTDIR, '{}.h5'.format(model_name))

    # Data generators: augmentation is applied only to the training one.
    training = Generator(scale=SCALE, batch_size=BATCHSIZE, train=True,
                         flip=flip, translate=translate,
                         rotate=rotate, brightness=brightness)
    validation = Generator(scale=SCALE, batch_size=BATCHSIZE, train=False)

    net = unet.build_model(scale_factor=SCALE,
                           verbose=1,
                           loss=loss,
                           batchnorm=batchnorm,
                           droprate=droprate,
                           regularizer=regularizer)
    net.name = model_name

    callbacks = [
        EarlyStopping(patience=PATIENCE, verbose=1),
        ModelCheckpoint(checkpoint_file, verbose=0, save_best_only=True),
    ]
    train_steps = len(training.image_IDs) // training.batch_size
    val_steps = len(validation.image_IDs) // validation.batch_size

    history = net.fit_generator(generator=training,
                                steps_per_epoch=train_steps,
                                validation_data=validation,
                                validation_steps=val_steps,
                                epochs=EPOCHS,
                                callbacks=callbacks,
                                verbose=2)

    # Persist the final model.
    net.save(final_file)
    print("{} saved to disk".format(net.name))
    print(history.history, '\n')
def train(
        data_dir='/home/stephen/Downloads/TF2UNET/unet/src/unet/datasets/Neuron/',
        model_dir='/home/stephen/Downloads/TF2UNET/unet/scripts/circles/2021-04-17T00-12_02/'):
    """Fine-tune a previously saved U-Net on the Neuron dataset.

    Training images come from ``<data_dir>/Training/Training1/`` and their
    masks from ``<data_dir>/Ground/``. Validation images come from
    ``<data_dir>/Testing/`` and their masks from ``<data_dir>/TestGround/``.
    The Keras model stored in ``model_dir`` is loaded, re-finalized with binary
    cross-entropy and trained for 70 epochs with batch size 10.

    :param data_dir: root directory of the Neuron dataset (defaults to the
        previously hard-coded location).
    :param model_dir: directory of the saved Keras model to fine-tune.
    :return: the trained U-Net model.
    """
    import os

    def read_images(*parts):
        # The trailing '' keeps a path separator at the end, matching the
        # directory-string format used before these paths were parameterised.
        return readimagesintonumpyArray(os.path.join(data_dir, *parts, ''))[0]

    ground_truth = makegroundtruth2(read_images('Ground'))
    train_data = maketrainingdata(read_images('Training', 'Training1'))
    validation_imgz = maketrainingdata(read_images('Testing'))
    validation_groundtruth = makegroundtruth2(read_images('TestGround'))

    validation_dataset = tf.data.Dataset.from_tensor_slices(
        (validation_imgz, validation_groundtruth))
    train_dataset = tf.data.Dataset.from_tensor_slices(
        (train_data, ground_truth))

    # Load the saved model directly. Building a fresh U-Net first would be
    # thrown away immediately by this assignment.
    unet_model = tf.keras.models.load_model(model_dir,
                                            custom_objects=custom_objects)
    unet.finalize_model(unet_model,
                        loss="binary_crossentropy",
                        learning_rate=LEARNING_RATE)

    # ``scheduler`` is passed to the trainer as its learning-rate scheduler.
    trainer = unet.Trainer(name="circles",
                           learning_rate=LEARNING_RATE,
                           tensorboard_callback=True,
                           learning_rate_scheduler=scheduler)
    trainer.fit(unet_model,
                train_dataset,
                validation_dataset,
                epochs=70,
                batch_size=10)

    return unet_model
def main(start_index=0, last_index=99, filename=None, plot_validation=False,
         plot_test=True, calculate_train_metric=False):
    """Run leave-one-out cross-validation of a U-Net on the BBBC004 images.

    For every held-out image index in ``[start_index, last_index]``:

    * a fresh U-Net is trained on the remaining images;
    * a binarisation threshold is tuned on a validation split;
    * metrics for the held-out image are appended as one ``;``-separated row
      to the results CSV, once with a fixed 0.5 threshold and once with the
      tuned threshold (``*_to`` columns).

    :param start_index: first image index to hold out as the test image.
    :param last_index: last image index (inclusive) to hold out.
    :param filename: results CSV path; defaults to a timestamped file in
        ``results/``. The header row is written only if the file is new.
    :param plot_validation: Plots 3 samples from the validation set each fold
    :param plot_test: Plots the test image for each fold
    :param calculate_train_metric: if True, finalize the model with the default
        metric set; otherwise dice coefficient, AUC and mean IoU are disabled.
    :return: None
    """
    if filename is None:
        now = datetime.now()
        current_dt = now.strftime("%y_%m_%d_%H_%M_%S")
        filename = "results/" + current_dt + ".csv"

    results_file = Path(filename)
    if not results_file.is_file():
        results_file.write_text(
            'index;jaccard;Dice;Adj;Warp;jaccard_to;Dice_to;Adj_to;Warp_to\n')

    # --- Load data ---
    # image_path = "data/BBBC004_v1_images/*/"
    # label_path = "data/BBBC004_v1_foreground/*/"
    image_path = "../datasets/BBBC004/images/all/"
    label_path = "../datasets/BBBC004/masks/all/"

    file_extension = "tif"

    inp_dim = 950

    file_names = sorted(glob.glob(image_path + "*." + file_extension))
    file_names_labels = sorted(glob.glob(label_path + "*." + file_extension))

    print(file_names)
    print(file_names_labels)

    # Determine largest and smallest pixel values in the dataset. The loop
    # variable must not be called ``filename``: that would shadow the
    # parameter of the same name.
    min_val = float('inf')
    max_val = float('-inf')
    for image_file in file_names:
        img = plt.imread(image_file)
        if np.min(img) < min_val:
            min_val = np.min(img)
        if np.max(img) > max_val:
            max_val = np.max(img)

    images = []
    for file in file_names:
        if file_extension == "tif":
            images.append(
                tf.convert_to_tensor(np.expand_dims(plt.imread(file), axis=2)))  # For .tif
            # Normalize relative to the pixel range of the entire dataset.
            images[-1] = (images[-1] - min_val) / (max_val - min_val)
            images[-1] = tf.image.resize(images[-1], [inp_dim, inp_dim],
                                         preserve_aspect_ratio=True,
                                         method='bilinear')
        elif file_extension == "png":
            images.append(tf.convert_to_tensor(
                plt.imread(file)[:, :, :3]))  # For .png
            images[-1] = tf.image.resize(images[-1], [inp_dim, inp_dim],
                                         preserve_aspect_ratio=True,
                                         method='bilinear')
            images[-1] = tf.image.rgb_to_grayscale(images[-1])

        images[-1] = mirror_pad_image(images[-1], pixels=21)

    labels = []
    for file in file_names_labels:
        label = plt.imread(file)[:, :, :3]
        # Collapse the colour channels; any non-zero pixel counts as foreground.
        label = np.expand_dims(np.sum(label, axis=2), axis=2)
        # One-hot encode: background -> [1, 0], foreground -> [0, 1].
        label = np.where(label > 0, [0, 1], [1, 0])
        labels.append(tf.convert_to_tensor(label))

        labels[-1] = tf.image.resize(labels[-1], [inp_dim, inp_dim],
                                     preserve_aspect_ratio=True,
                                     method='bilinear')
        # Re-binarise after bilinear interpolation.
        labels[-1] = np.where(labels[-1] > 0.5, 1, 0)

        labels[-1] = mirror_pad_image(labels[-1], pixels=21)

    print("num images: " + str(len(images)))
    print("num labels: " + str(len(labels)))

    num_data_points = len(images)

    def compute_metrics(m):
        """Return (jaccard, dice, adjusted rand, warping error) of ``m``."""
        return (m.jaccard()[0], m.dice()[0], m.adj_rand()[0],
                m.warping_error()[0])

    # NOTE(review): ``compute_metrics`` is a nested function, which a standard
    # ``multiprocessing.Pool`` cannot pickle; this presumably relies on a
    # dill-based or thread-backed ``Pool`` -- confirm the import.
    # One pool is shared by all folds and released when the run ends, instead
    # of creating a new pool per fold and never closing it.
    pool = Pool(2)
    try:
        last_fold = min(num_data_points, last_index + 1)
        for test_data_point_index in range(start_index, last_fold):
            print("\nStarted for data_point_index: " + str(test_data_point_index))

            images_temp = images.copy()
            labels_temp = labels.copy()

            # Hold out one image (and its mask) as this fold's test set.
            test_image = images_temp.pop(test_data_point_index)
            test_label = labels_temp.pop(test_data_point_index)
            test_dataset = tf.data.Dataset.from_tensor_slices(
                ([test_image], [test_label]))

            print("num images: " + str(len(images_temp)))
            print("num labels: " + str(len(labels_temp)))

            random_permutation = np.random.permutation(len(images_temp))
            images_temp = np.array(images_temp)[random_permutation]
            labels_temp = np.array(labels_temp)[random_permutation]

            image_dataset = tf.data.Dataset.from_tensor_slices(
                (images_temp, labels_temp))

            # --- Create data splits ---
            train_dataset = image_dataset.take(80)
            validation_dataset = image_dataset.skip(80)

            # ``Dataset.shuffle`` returns a new dataset. The result must be
            # kept, or no per-epoch reshuffling happens.
            train_dataset = train_dataset.shuffle(
                80, reshuffle_each_iteration=True)

            train_dataset = train_dataset.map(
                augment_image)  # Apply transformations to training data

            # --- Load model ---
            print(circles.channels)
            print(circles.classes)

            unet_model = unet.build_model(channels=circles.channels,
                                          num_classes=circles.classes,
                                          layer_depth=3,
                                          filters_root=16)
            if calculate_train_metric:
                unet.finalize_model(unet_model)
            else:
                # Don't track so many metrics
                unet.finalize_model(unet_model,
                                    dice_coefficient=False,
                                    auc=False,
                                    mean_iou=False)

            # --- Train ---
            # Use early stopping or not?
            # es_callback = tf.keras.callbacks.EarlyStopping(
            #     monitor='val_loss', patience=6, restore_best_weights=True)
            trainer = unet.Trainer(
                checkpoint_callback=False,
                tensorboard_callback=False,
                tensorboard_images_callback=False,
                # callbacks=[es_callback]
            )
            trainer.fit(unet_model,
                        train_dataset,
                        # validation_dataset,
                        epochs=40,
                        batch_size=2)

            # --- Tune the binarisation threshold on the validation split ---
            prediction = unet_model.predict(
                validation_dataset.batch(batch_size=1))

            original_images = []
            metric_labels = []
            metric_predictions_unprocessed = []

            dataset = validation_dataset.map(
                utils.crop_image_and_label_to_shape((inp_dim, inp_dim, 2)))
            prediction = remove_border(prediction, inp_dim, inp_dim)

            for i, (image, label) in enumerate(dataset):
                original_images.append(image[..., -1])
                metric_labels.append(np.argmax(label, axis=-1))
                metric_predictions_unprocessed.append(
                    normalize_output(prediction[i, ...]))

            best_tau, best_score = get_best_threshold(
                metric_predictions_unprocessed,
                metric_labels,
                min=0,
                max=1,
                num_steps=50,
                use_metric=1)

            # best_tau = 0.5  # Use this to not threshold at all, also comment above
            print("Best tau: " + str(best_tau))
            print("Best avg score: " + str(best_score))

            metric_predictions = [(p >= best_tau).astype(int)
                                  for p in metric_predictions_unprocessed]

            if plot_validation:
                fig, ax = plt.subplots(3, 3, sharex=True, sharey=True,
                                       figsize=(8, 8))
                for i in range(3):
                    ax[i][0].matshow(original_images[i])
                    ax[i][1].matshow(metric_labels[i], cmap=plt.cm.gray)
                    ax[i][2].matshow(metric_predictions[i], cmap=plt.cm.gray)
                plt.tight_layout()
                plt.show()

            # --- Evaluate on the held-out image and write to file ---
            original_images = []
            metric_labels_test = []
            metric_predictions_unprocessed_test = []
            metric_predictions = []
            metric_predictions_unthresholded = []

            prediction = unet_model.predict(test_dataset.batch(batch_size=1))
            dataset = test_dataset.map(
                utils.crop_image_and_label_to_shape((inp_dim, inp_dim, 2)))
            prediction = remove_border(prediction, inp_dim, inp_dim)
            print("Test shape shape: ", prediction.shape)

            for i, (image, label) in enumerate(dataset):
                original_images.append(image[..., -1])
                metric_labels_test.append(np.argmax(label, axis=-1))
                metric_predictions_unprocessed_test.append(prediction[i, ...])

            for raw in metric_predictions_unprocessed_test:
                metric_predictions.append(
                    (normalize_output(raw) >= best_tau).astype(int))
                metric_predictions_unthresholded.append(
                    (normalize_output(raw) >= 0.5).astype(int))

            # Calculate unthresholded (index 0) and threshold-optimised
            # (index 1) metrics in parallel.
            parallel_metrics = [
                Metrics(metric_labels_test, metric_predictions_unthresholded,
                        safe=False, parallel=False),
                Metrics(metric_labels_test, metric_predictions,
                        safe=False, parallel=False),
            ]
            metric_result = pool.map(compute_metrics, parallel_metrics)

            jaccard_index, dice, adj, warping_error = metric_result[0]
            (jaccard_index_to, dice_to, adj_to,
             warping_error_to) = metric_result[1]

            row = [test_data_point_index, jaccard_index, dice, adj,
                   warping_error, jaccard_index_to, dice_to, adj_to,
                   warping_error_to]
            with results_file.open("a") as results:
                results.write(";".join(str(value) for value in row) + "\n")

            print("test_data_point_index: " + str(test_data_point_index))
            print("Jaccard index: " + str(jaccard_index) +
                  " with threshold optimization: " + str(jaccard_index_to))
            print("Dice: " + str(dice) + " with threshold optimization: " +
                  str(dice_to))
            print("Adj: " + str(adj) + " with threshold optimization: " +
                  str(adj_to))
            print("Warping Error: " + str(warping_error) +
                  " with threshold optimization: " + str(warping_error_to))

            # --- Plot predictions ---
            if plot_test:
                # The test set of a fold holds exactly one image (index 0).
                fig, ax = plt.subplots(1, 3, figsize=(8, 4))
                fig.suptitle("Test point: " + str(test_data_point_index),
                             fontsize=14)
                ax[0].matshow(original_images[0])
                ax[0].set_title("Input data")
                ax[0].set_axis_off()

                # Show the test mask, not a validation mask.
                ax[1].matshow(metric_labels_test[0], cmap=plt.cm.gray)
                ax[1].set_title("True mask")
                ax[1].set_axis_off()

                ax[2].matshow(metric_predictions[0], cmap=plt.cm.gray)
                ax[2].set_title("Predicted mask")
                ax[2].set_axis_off()
                fig.tight_layout()
                plt.show()
    finally:
        pool.close()
        pool.join()
def main(start_index=0, last_index=199, filename=None, plot=True,
         store_masks=False):
    """Run leave-one-out cross-validation of a U-Net on the ``read_data`` set.

    The output file names suggest this is BBBC039. For every held-out image
    index in ``[start_index, last_index]``:

    * a fresh U-Net is trained on the remaining images;
    * a binarisation threshold is tuned on a validation split;
    * metrics for the held-out image are appended as one ``;``-separated row
      to the results CSV, once with a fixed 0.5 threshold and once with the
      tuned threshold (``*_to`` columns).

    :param start_index: first image index to hold out as the test image.
    :param last_index: last image index (inclusive) to hold out.
    :param filename: results CSV path; defaults to a timestamped file in
        ``results/``. The header row is written only if the file is new.
    :param plot: plot 3 validation samples and the test image for each fold.
    :param store_masks: save each fold's validation/test predictions and the
        matching ground-truth masks as ``.npy`` files in ``results/``.
    :return: None
    """
    if filename is None:
        now = datetime.now()
        current_dt = now.strftime("%y_%m_%d_%H_%M_%S")
        filename = "results/" + current_dt + ".csv"

    results_file = Path(filename)
    if not results_file.is_file():
        results_file.write_text(
            'index;jaccard;Dice;Adj;Warp;jaccard_to;Dice_to;Adj_to;Warp_to\n')

    # --- Load data ---
    print("Start read")
    images, labels = read_data()
    print("Done read")

    min_val = float('inf')
    max_val = float('-inf')
    for img in images:
        if np.min(img) < min_val:
            min_val = np.min(img)
        if np.max(img) > max_val:
            max_val = np.max(img)
    print(min_val, max_val)

    # Normalize relative to entire dataset
    images = [(np.expand_dims(image, axis=2) - min_val) / (max_val - min_val)
              for image in images]
    labels = [split_into_classes(label[:, :, :2]) for label in labels]

    print(np.array(images).shape)
    print(np.array(labels).shape)

    images = [mirror_pad_image(image) for image in images]
    labels = [mirror_pad_image(label) for label in labels]

    print("num images: " + str(len(images)))
    print("num labels: " + str(len(labels)))

    num_data_points = len(images)

    def compute_metrics(m):
        """Return (jaccard, dice, adjusted rand, warping error) of ``m``."""
        return (m.jaccard()[0], m.dice()[0], m.adj_rand()[0],
                m.warping_error()[0])

    # NOTE(review): ``compute_metrics`` is a nested function, which a standard
    # ``multiprocessing.Pool`` cannot pickle; this presumably relies on a
    # dill-based or thread-backed ``Pool`` -- confirm the import.
    # One pool is shared by all folds and released when the run ends, instead
    # of creating a new pool per fold and never closing it.
    pool = Pool(2)
    try:
        last_fold = min(num_data_points, last_index + 1)
        for test_data_point_index in range(start_index, last_fold):
            print("\nStarted for data_point_index: " + str(test_data_point_index))

            images_temp = images.copy()
            labels_temp = labels.copy()

            # Hold out one image (and its mask) as this fold's test set.
            test_image = images_temp.pop(test_data_point_index)
            test_label = labels_temp.pop(test_data_point_index)
            test_dataset = tf.data.Dataset.from_tensor_slices(
                ([test_image], [test_label]))

            print("num images: " + str(len(images_temp)))
            print("num labels: " + str(len(labels_temp)))

            random_permutation = np.random.permutation(len(images_temp))
            images_temp = np.array(images_temp)[random_permutation]
            labels_temp = np.array(labels_temp)[random_permutation]

            image_dataset = tf.data.Dataset.from_tensor_slices(
                (images_temp, labels_temp))

            # --- Create data splits ---
            train_dataset = image_dataset.take(160)
            validation_dataset = image_dataset.skip(160)

            # ``Dataset.shuffle`` returns a new dataset. The result must be
            # kept, or no per-epoch reshuffling happens.
            train_dataset = train_dataset.shuffle(
                160, reshuffle_each_iteration=True)

            train_dataset = train_dataset.map(
                augment_image)  # Apply transformations to training data

            # --- Load model ---
            print(circles.channels)
            print(circles.classes)

            unet_model = unet.build_model(channels=circles.channels,
                                          num_classes=circles.classes,
                                          layer_depth=3,
                                          filters_root=16)
            # Don't track so many metrics
            unet.finalize_model(unet_model,
                                dice_coefficient=False,
                                auc=False,
                                mean_iou=False)

            # --- Train ---
            # Use early stopping or not?
            # es_callback = tf.keras.callbacks.EarlyStopping(
            #     monitor='val_loss', patience=6, restore_best_weights=True)
            trainer = unet.Trainer(
                checkpoint_callback=False,
                tensorboard_callback=False,
                tensorboard_images_callback=False,
                # callbacks=[es_callback]
            )
            trainer.fit(unet_model,
                        train_dataset,
                        # validation_dataset,
                        epochs=40,
                        batch_size=2)

            # --- Tune the binarisation threshold on the validation split ---
            prediction = unet_model.predict(
                validation_dataset.batch(batch_size=1))

            original_images = []
            metric_labels = []
            metric_predictions_unprocessed = []

            dataset = validation_dataset.map(
                utils.crop_image_and_label_to_shape(prediction.shape[1:]))

            for i, (image, label) in enumerate(dataset):
                original_images.append(image[..., -1])
                metric_labels.append(np.argmax(label, axis=-1))
                metric_predictions_unprocessed.append(
                    normalize_output(prediction[i, ...]))

            best_tau, best_score = get_best_threshold(
                metric_predictions_unprocessed,
                metric_labels,
                min=0,
                max=1,
                num_steps=50,
                use_metric=1)

            # best_tau = 0.5  # Use this to not threshold at all, also comment above
            print("Best tau: " + str(best_tau))
            print("Best avg score: " + str(best_score))

            metric_predictions = [(p >= best_tau).astype(int)
                                  for p in metric_predictions_unprocessed]

            if plot:
                fig, ax = plt.subplots(3, 3, sharex=True, sharey=True,
                                       figsize=(8, 8))
                for i in range(3):
                    ax[i][0].matshow(original_images[i])
                    ax[i][1].matshow(metric_labels[i], cmap=plt.cm.gray)
                    ax[i][2].matshow(metric_predictions[i], cmap=plt.cm.gray)
                plt.tight_layout()
                plt.show()

            # --- Evaluate on the held-out image and write to file ---
            original_images = []
            metric_labels_test = []
            metric_predictions_unprocessed_test = []
            metric_predictions = []
            metric_predictions_unthresholded = []

            prediction = unet_model.predict(test_dataset.batch(batch_size=1))
            dataset = test_dataset.map(
                utils.crop_image_and_label_to_shape(prediction.shape[1:]))

            for i, (image, label) in enumerate(dataset):
                original_images.append(image[..., -1])
                metric_labels_test.append(np.argmax(label, axis=-1))
                metric_predictions_unprocessed_test.append(prediction[i, ...])

            for raw in metric_predictions_unprocessed_test:
                metric_predictions.append(
                    (normalize_output(raw) >= best_tau).astype(int))
                metric_predictions_unthresholded.append(
                    (normalize_output(raw) >= 0.5).astype(int))

            # Calculate unthresholded (index 0) and threshold-optimised
            # (index 1) metrics in parallel.
            parallel_metrics = [
                Metrics(metric_labels_test, metric_predictions_unthresholded,
                        safe=False, parallel=False),
                Metrics(metric_labels_test, metric_predictions,
                        safe=False, parallel=False),
            ]
            metric_result = pool.map(compute_metrics, parallel_metrics)

            jaccard_index, dice, adj, warping_error = metric_result[0]
            (jaccard_index_to, dice_to, adj_to,
             warping_error_to) = metric_result[1]

            row = [test_data_point_index, jaccard_index, dice, adj,
                   warping_error, jaccard_index_to, dice_to, adj_to,
                   warping_error_to]
            with results_file.open("a") as results:
                results.write(";".join(str(value) for value in row) + "\n")

            print("test_data_point_index: " + str(test_data_point_index))
            print("Jaccard index: " + str(jaccard_index) +
                  " with threshold optimization: " + str(jaccard_index_to))
            print("Dice: " + str(dice) + " with threshold optimization: " +
                  str(dice_to))
            print("Adj: " + str(adj) + " with threshold optimization: " +
                  str(adj_to))
            print("Warping Error: " + str(warping_error) +
                  " with threshold optimization: " + str(warping_error_to))

            # --- Plot predictions ---
            if plot:
                # One row per test image. squeeze=False keeps ``ax`` 2-D even
                # for a single row.
                fig, ax = plt.subplots(len(metric_labels_test), 3,
                                       squeeze=False, sharex=True,
                                       sharey=True, figsize=(8, 8))
                for i in range(len(metric_labels_test)):
                    ax[i][0].matshow(original_images[i])
                    ax[i][1].matshow(metric_labels_test[i], cmap=plt.cm.gray)
                    ax[i][2].matshow(metric_predictions[i], cmap=plt.cm.gray)
                plt.tight_layout()
                plt.show()

            if store_masks:
                np.save(
                    "results/BBBC039_val_fold_" + str(test_data_point_index) +
                    ".npy", metric_predictions_unprocessed)
                np.save(
                    "results/BBBC039_val_true_fold_" +
                    str(test_data_point_index) + ".npy", metric_labels)
                np.save(
                    "results/BBBC039_test_fold_" + str(test_data_point_index) +
                    ".npy", metric_predictions_unprocessed_test)
                # Ground truth of the held-out test image, not the validation
                # masks.
                np.save(
                    "results/BBBC039_test_true_fold_" +
                    str(test_data_point_index) + ".npy", metric_labels_test)
    finally:
        pool.close()
        pool.join()
def main(filename=None, calculate_train_metric=False):
    """Train a U-Net on the synthetic dataset and segment the SciLifeLab images.

    The synthetic images and masks are read from ``data/synthetic/``.
    The first 100 images of a random permutation are used for training, with
    per-epoch reshuffling and ``augment_image`` applied. The trained model
    then predicts masks for the images returned by ``scilife_data()``. Up to 5
    of them are plotted, and saved as PNG files under ``results/``.

    :param filename: results CSV path; defaults to a timestamped file in
        ``results/``. Only the header row is written to it by this function.
    :param calculate_train_metric: if True, finalize the model with the default
        metric set; otherwise dice coefficient, AUC and mean IoU are disabled.
    :return: None
    """
    now = datetime.now()
    current_dt = now.strftime("%y_%m_%d_%H_%M_%S")
    if filename is None:
        filename = "results/" + current_dt + ".csv"

    results_file = Path(filename)
    if not results_file.is_file():
        results_file.write_text('index; jaccard; Dice; Adj; Warp\n')

    # --- Load data ---
    image_path = "data/synthetic/images/"
    label_path = "data/synthetic/labels/"
    file_extension = "tif"

    # Earlier experiments used inp_dim = 572, 200 and 710.
    inp_dim = 1024

    file_names = sorted(glob.glob(image_path + "*." + file_extension))
    file_names_labels = sorted(glob.glob(label_path + "*." + file_extension))

    print(file_names)
    print(file_names_labels)

    images = []
    for file in file_names:
        if file_extension == "tif":
            images.append(
                tf.convert_to_tensor(np.expand_dims(plt.imread(file), axis=2)))  # For .tif
            images[-1] = images[-1] / 255  # Normalize
            images[-1] = tf.image.resize(images[-1], [inp_dim, inp_dim],
                                         preserve_aspect_ratio=True,
                                         method='bilinear')
        elif file_extension == "png":
            images.append(tf.convert_to_tensor(
                plt.imread(file)[:, :, :3]))  # For .png
            images[-1] = tf.image.resize(images[-1], [inp_dim, inp_dim],
                                         preserve_aspect_ratio=True,
                                         method='bilinear')
            images[-1] = tf.image.rgb_to_grayscale(images[-1])

        images[-1] = mirror_pad_image(images[-1], pixels=20)

    labels = []
    for file in file_names_labels:
        label = np.expand_dims(plt.imread(file), axis=2)
        # One-hot encode: background -> [1, 0], foreground -> [0, 1].
        label = np.where(label > 0, [0, 1], [1, 0])
        labels.append(tf.convert_to_tensor(label))

        labels[-1] = tf.image.resize(labels[-1], [inp_dim, inp_dim],
                                     preserve_aspect_ratio=True,
                                     method='bilinear')
        # Re-binarise after bilinear interpolation.
        labels[-1] = np.where(labels[-1] > 0.5, 1, 0)

        labels[-1] = mirror_pad_image(labels[-1], pixels=20)

    print("num images: " + str(len(images)))
    print("num labels: " + str(len(labels)))

    scilife_images, scilife_labels = scilife_data()

    print("num images: " + str(len(images)))
    print("num labels: " + str(len(labels)))

    random_permutation = np.random.permutation(len(images))
    images_shuffled = np.array(images)[random_permutation]
    labels_shuffled = np.array(labels)[random_permutation]

    image_dataset = tf.data.Dataset.from_tensor_slices(
        (images_shuffled, labels_shuffled))

    # --- Create data split: only the first 100 images are used ---
    train_dataset = image_dataset.take(100)

    # ``Dataset.shuffle`` returns a new dataset. The result must be kept, or
    # no per-epoch reshuffling happens.
    train_dataset = train_dataset.shuffle(100, reshuffle_each_iteration=True)

    train_dataset = train_dataset.map(
        augment_image)  # Apply transformations to training data

    # --- Load model ---
    print(circles.channels)
    print(circles.classes)

    unet_model = unet.build_model(channels=circles.channels,
                                  num_classes=circles.classes,
                                  layer_depth=3,
                                  filters_root=16)
    if calculate_train_metric:
        unet.finalize_model(unet_model)
    else:
        unet.finalize_model(unet_model,
                            dice_coefficient=False,
                            auc=False,
                            mean_iou=False)

    # --- Train ---
    trainer = unet.Trainer(checkpoint_callback=False)
    trainer.fit(unet_model, train_dataset, epochs=25, batch_size=1)

    # --- SciLifeLab data prediction ---
    scilife_dataset = tf.data.Dataset.from_tensor_slices(
        (scilife_images, scilife_labels))
    prediction = unet_model.predict(scilife_dataset.batch(batch_size=1))

    dataset = scilife_dataset.map(
        utils.crop_image_and_label_to_shape((inp_dim, inp_dim, 2)))
    prediction = remove_border(prediction, inp_dim, inp_dim)

    original_images = []
    metric_predictions_unprocessed = []
    for i, (image, _) in enumerate(dataset):
        original_images.append(image[..., -1])
        metric_predictions_unprocessed.append(prediction[i, ...])

    # Most likely class per pixel.
    metric_predictions = [np.argmax(p, axis=-1)
                          for p in metric_predictions_unprocessed]

    # Show (and save) at most 5 predictions, so fewer images don't cause an
    # IndexError.
    n_shown = min(5, len(metric_predictions))
    fig, ax = plt.subplots(5, 2, sharex=True, sharey=True, figsize=(25, 60))
    for i in range(n_shown):
        ax[i][0].matshow(original_images[i])
        ax[i][1].matshow(metric_predictions[i], cmap=plt.cm.gray)
        plt.imsave("results/scilifelab_" + str(current_dt) + "_index_" +
                   str(i) + ".png", metric_predictions[i], cmap=plt.cm.gray)

    plt.tight_layout()
    plt.savefig("results/scilifelab_" + str(current_dt) + ".png")
    plt.show()