def train():
    """Train a 3-level U-Net on the synthetic circles dataset.

    The model uses 16 root filters. Training uses a warmup + linear-decay
    learning-rate schedule (10% warmup) and runs for 25 epochs with batch
    size 5. The data comes from ``circles.load_data(100, nx=272, ny=272,
    r_max=20)``, and the test split is passed on to ``Trainer.fit``.

    Returns:
        The trained Keras model.
    """
    architecture = dict(channels=circles.channels,
                        num_classes=circles.classes,
                        layer_depth=3,
                        filters_root=16)
    model = unet.build_model(**architecture)
    unet.finalize_model(model, learning_rate=LEARNING_RATE)

    trainer = unet.Trainer(
        name="circles",
        learning_rate_scheduler=unet.SchedulerType.WARMUP_LINEAR_DECAY,
        warmup_proportion=0.1,
        learning_rate=LEARNING_RATE,
    )

    train_ds, validation_ds, test_ds = circles.load_data(
        100, nx=272, ny=272, r_max=20)
    trainer.fit(model, train_ds, validation_ds, test_ds,
                epochs=25, batch_size=5)
    return model
def train():
    """Train a 4-level U-Net on the Oxford-IIIT Pet dataset.

    The network is sized to ``oxford_iiit_pet.IMAGE_SIZE`` and uses "same"
    padding and 64 root filters. It is compiled with sparse categorical
    cross-entropy and sparse categorical accuracy, with AUC disabled, and is
    trained for 25 epochs with batch size 1.

    Returns:
        The trained Keras model.
    """
    architecture = dict(channels=oxford_iiit_pet.channels,
                        num_classes=oxford_iiit_pet.classes,
                        layer_depth=4,
                        filters_root=64,
                        padding="same")
    model = unet.build_model(*oxford_iiit_pet.IMAGE_SIZE, **architecture)

    loss_fn = losses.SparseCategoricalCrossentropy()
    tracked_metrics = [metrics.SparseCategoricalAccuracy()]
    unet.finalize_model(model,
                        loss=loss_fn,
                        metrics=tracked_metrics,
                        auc=False,
                        learning_rate=LEARNING_RATE)

    trainer = unet.Trainer(name="oxford_iiit_pet")
    train_ds, validation_ds = oxford_iiit_pet.load_data()
    trainer.fit(model, train_ds, validation_ds, epochs=25, batch_size=1)
    return model
def train(data_root='/home/stephen/Downloads/TF2UNET/unet/src/unet/datasets/Neuron/',
          checkpoint_path='/home/stephen/Downloads/TF2UNET/unet/scripts/circles/2021-04-17T00-12_02/'):
    """Fine-tune a previously saved U-Net on the Neuron dataset.

    Training pairs come from ``Training/Training1`` (images) and ``Ground``
    (masks). Validation pairs come from ``Testing`` (images) and
    ``TestGround`` (masks). The model is loaded from ``checkpoint_path``,
    re-compiled with binary cross-entropy, and trained for 70 epochs with
    batch size 10 under the module-level ``scheduler`` learning-rate function.

    :param data_root: Directory holding the ``Ground``, ``Training/Training1``,
        ``Testing`` and ``TestGround`` image folders.
    :param checkpoint_path: Model directory passed to
        ``tf.keras.models.load_model`` together with ``custom_objects``.
    :return: The fine-tuned Keras model.
    """
    import os

    def read_folder(*parts):
        # Pass the folder with a trailing separator; only the first element
        # of the loader's result is used.
        return readimagesintonumpyArray(os.path.join(data_root, *parts, ''))[0]

    ground_truth = makegroundtruth2(read_folder('Ground'))
    train_data = maketrainingdata(read_folder('Training', 'Training1'))
    validation_images = maketrainingdata(read_folder('Testing'))
    validation_ground_truth = makegroundtruth2(read_folder('TestGround'))

    validation_dataset = tf.data.Dataset.from_tensor_slices(
        (validation_images, validation_ground_truth))
    train_dataset = tf.data.Dataset.from_tensor_slices(
        (train_data, ground_truth))

    # Start from the saved model. Building a fresh model first was dead work
    # because it was immediately replaced by the loaded one.
    unet_model = tf.keras.models.load_model(checkpoint_path,
                                            custom_objects=custom_objects)
    unet.finalize_model(unet_model,
                        loss="binary_crossentropy",
                        learning_rate=LEARNING_RATE)

    # Wrap the schedule function in a Keras callback and hand THAT to the
    # trainer. Previously the callback was built but unused, and the bare
    # function was passed instead.
    # NOTE(review): this assumes unet.Trainer forwards a non-SchedulerType
    # scheduler to model.fit's callbacks unchanged, so it must be a Keras
    # Callback. Confirm against unet.Trainer.
    lr_callback = tf.keras.callbacks.LearningRateScheduler(scheduler)
    trainer = unet.Trainer(name="circles",
                           learning_rate=LEARNING_RATE,
                           tensorboard_callback=True,
                           learning_rate_scheduler=lr_callback)
    trainer.fit(unet_model,
                train_dataset,
                validation_dataset,
                epochs=70,
                batch_size=10)
    return unet_model
width = 160 batch_size = 10 train_path = '/DB/rhome/qyzheng/Desktop/qyzheng/source/renji_data/from_senior/0_cv_train.csv' val_path = '/DB/rhome/qyzheng/Desktop/qyzheng/source/renji_data/from_senior/0_cv_val.csv' dataset, iters = image_gen.GetDataset(train_path, batch_size) generator = image_gen.BladderDataProvider(height, width, dataset) """ x_test, y_test = generator(1) fig, ax = plt.subplots(1, 2, sharey=True, figsize=(8, 4)) ax[0].imshow(x_test[0, ..., 0], aspect="auto") ax[1].imshow(y_test[0, ..., 1], aspect="auto") #plt.show() """ net = unet.Unet(channels=generator.channels, n_class=generator.n_class, layers=4, features_root=64) trainer = unet.Trainer(net, batch_size=4, optimizer="momentum", opt_kwargs=dict(momentum=0.2)) path = trainer.train(generator, "../unet_trained", training_iters=iters, epochs=100, display_step=4, prediction_path='/DATA/data/sxfeng/data/IVDM3Seg/result/result_2/prediction') ''' x_test, y_test = generator(1) print(x_test.shape) print(y_test.shape) prediction = net.predict("../unet_trained/model.ckpt", x_test) print(prediction.shape) ''' """ fig, ax = plt.subplots(1, 3, sharex=True, sharey=True, figsize=(12, 5)) ax[0].imshow(x_test[0,...,0], aspect="auto") ax[1].imshow(y_test[0,...,1], aspect="auto") mask = prediction[0,...,1] > 0.3
def main(start_index=0,
         last_index=99,
         filename=None,
         plot_validation=False,
         plot_test=True,
         calculate_train_metric=False):
    """Leave-one-out evaluation of a U-Net on the BBBC004 images.

    Each fold works as follows:

    - One image is held out as the test point.
    - The remaining images are randomly permuted.
    - A fresh model is trained on the first 80 of them, with augmentation.
    - A binarisation threshold is tuned on the rest.
    - The test metrics are appended to a ``;``-separated results file, once
      at a fixed 0.5 threshold and once at the tuned threshold.

    :param start_index: First data-point index to hold out as test image.
    :param last_index: Last data-point index (inclusive) to hold out.
    :param filename: Results file path; defaults to
        ``results/<yy_mm_dd_HH_MM_SS>.csv``. A header row is written if the
        file does not exist yet.
    :param plot_validation: Plots 3 samples from the validation set each fold
    :param plot_test: Plots the test image for each fold
    :param calculate_train_metric: If True, finalize the model with the
        library's default metrics. Otherwise dice coefficient, AUC and
        mean IoU tracking are disabled.
    :return: None
    """
    if filename is None:
        current_dt = datetime.now().strftime("%y_%m_%d_%H_%M_%S")
        filename = "results/" + current_dt + ".csv"
    results_file = Path(filename)
    if not results_file.is_file():
        results_file.write_text(
            'index;jaccard;Dice;Adj;Warp;jaccard_to;Dice_to;Adj_to;Warp_to\n')

    # ---- Load data -------------------------------------------------------
    image_path = "../datasets/BBBC004/images/all/"
    label_path = "../datasets/BBBC004/masks/all/"
    file_extension = "tif"
    inp_dim = 950

    file_names = sorted(glob.glob(image_path + "*." + file_extension))
    file_names_labels = sorted(glob.glob(label_path + "*." + file_extension))
    print(file_names)
    print(file_names_labels)

    # Read every image once. The raw arrays serve both for the dataset-wide
    # intensity range and for building the input tensors.
    raw_images = [plt.imread(image_file) for image_file in file_names]

    # Determine largest and smallest pixel values in the dataset
    min_val = float('inf')
    max_val = float('-inf')
    for img in raw_images:
        if np.min(img) < min_val:
            min_val = np.min(img)
        if np.max(img) > max_val:
            max_val = np.max(img)

    images = []
    for raw in raw_images:
        if file_extension == "tif":
            image = tf.convert_to_tensor(np.expand_dims(raw, axis=2))
            # Normalize relative to the intensity range of the whole dataset
            image = (image - min_val) / (max_val - min_val)
            image = tf.image.resize(image, [inp_dim, inp_dim],
                                    preserve_aspect_ratio=True,
                                    method='bilinear')
        elif file_extension == "png":
            image = tf.convert_to_tensor(raw[:, :, :3])
            image = tf.image.resize(image, [inp_dim, inp_dim],
                                    preserve_aspect_ratio=True,
                                    method='bilinear')
            image = tf.image.rgb_to_grayscale(image)
        else:
            raise ValueError("Unsupported file extension: " + file_extension)
        images.append(mirror_pad_image(image, pixels=21))
    del raw_images  # no longer needed; release before training

    labels = []
    for label_file in file_names_labels:
        label = plt.imread(label_file)[:, :, :3]
        # Any non-zero RGB pixel is foreground; one-hot as [bg, fg]
        label = np.expand_dims(np.sum(label, axis=2), axis=2)
        label = np.where(label > 0, [0, 1], [1, 0])
        label = tf.image.resize(tf.convert_to_tensor(label),
                                [inp_dim, inp_dim],
                                preserve_aspect_ratio=True,
                                method='bilinear')
        # Re-binarise after bilinear interpolation
        label = np.where(label > 0.5, 1, 0)
        labels.append(mirror_pad_image(label, pixels=21))

    print("num images: " + str(len(images)))
    print("num labels: " + str(len(labels)))

    num_data_points = len(images)
    for test_data_point_index in range(start_index,
                                       min(num_data_points, last_index + 1)):
        print("\nStarted for data_point_index: " + str(test_data_point_index))

        images_temp = images.copy()
        labels_temp = labels.copy()
        test_image = images_temp.pop(test_data_point_index)
        test_label = labels_temp.pop(test_data_point_index)
        test_dataset = tf.data.Dataset.from_tensor_slices(
            ([test_image], [test_label]))

        print("num images: " + str(len(images_temp)))
        print("num labels: " + str(len(labels_temp)))

        random_permutation = np.random.permutation(len(images_temp))
        images_temp = np.array(images_temp)[random_permutation]
        labels_temp = np.array(labels_temp)[random_permutation]
        image_dataset = tf.data.Dataset.from_tensor_slices(
            (images_temp, labels_temp))

        # ---- Create data splits ------------------------------------------
        train_dataset = image_dataset.take(80)
        validation_dataset = image_dataset.skip(80)
        # Dataset.shuffle returns a NEW dataset; without the assignment the
        # shuffle was silently discarded.
        train_dataset = train_dataset.shuffle(80, reshuffle_each_iteration=True)
        # Apply transformations to training data
        train_dataset = train_dataset.map(augment_image)

        # ---- Load model --------------------------------------------------
        print(circles.channels)
        print(circles.classes)
        unet_model = unet.build_model(channels=circles.channels,
                                      num_classes=circles.classes,
                                      layer_depth=3,
                                      filters_root=16)
        if calculate_train_metric:
            unet.finalize_model(unet_model)
        else:
            # Don't track so many metrics
            unet.finalize_model(unet_model,
                                dice_coefficient=False,
                                auc=False,
                                mean_iou=False)

        # ---- Train -------------------------------------------------------
        # Validation data is deliberately not passed to fit (no early
        # stopping); it is used only for threshold tuning below.
        trainer = unet.Trainer(checkpoint_callback=False,
                               tensorboard_callback=False,
                               tensorboard_images_callback=False)
        trainer.fit(unet_model, train_dataset, epochs=40, batch_size=2)

        # ---- Calculate best threshold on the validation split -------------
        prediction = unet_model.predict(validation_dataset.batch(batch_size=1))
        original_images = []
        metric_labels = []
        metric_predictions_unprocessed = []
        dataset = validation_dataset.map(
            utils.crop_image_and_label_to_shape((inp_dim, inp_dim, 2)))
        prediction = remove_border(prediction, inp_dim, inp_dim)
        for i, (image, label) in enumerate(dataset):
            original_images.append(image[..., -1])
            metric_labels.append(np.argmax(label, axis=-1))
            metric_predictions_unprocessed.append(
                normalize_output(prediction[i, ...]))

        best_tau, best_score = get_best_threshold(
            metric_predictions_unprocessed,
            metric_labels,
            min=0,
            max=1,
            num_steps=50,
            use_metric=1)
        # best_tau = 0.5  # Use this to not threshold at all, also comment above
        print("Best tau: " + str(best_tau))
        print("Best avg score: " + str(best_score))

        metric_predictions = [(pred >= best_tau).astype(int)
                              for pred in metric_predictions_unprocessed]

        if plot_validation:
            fig, ax = plt.subplots(3, 3, sharex=True, sharey=True,
                                   figsize=(8, 8))
            for i in range(3):
                ax[i][0].matshow(original_images[i])
                ax[i][1].matshow(metric_labels[i], cmap=plt.cm.gray)
                ax[i][2].matshow(metric_predictions[i], cmap=plt.cm.gray)
            plt.tight_layout()
            plt.show()

        # ---- Evaluate on the held-out test image and print to file --------
        prediction = unet_model.predict(test_dataset.batch(batch_size=1))
        dataset = test_dataset.map(
            utils.crop_image_and_label_to_shape((inp_dim, inp_dim, 2)))
        prediction = remove_border(prediction, inp_dim, inp_dim)
        print("Test shape shape: ", prediction.shape)

        original_images = []
        metric_labels_test = []
        metric_predictions_unprocessed_test = []
        for i, (image, label) in enumerate(dataset):
            original_images.append(image[..., -1])
            metric_labels_test.append(np.argmax(label, axis=-1))
            metric_predictions_unprocessed_test.append(prediction[i, ...])

        metric_predictions = []
        metric_predictions_unthresholded = []
        for pred in metric_predictions_unprocessed_test:
            metric_predictions.append(
                (normalize_output(pred) >= best_tau).astype(int))
            metric_predictions_unthresholded.append(
                (normalize_output(pred) >= 0.5).astype(int))

        # Calculate unthresholded (0.5) and threshold-optimized metrics in parallel
        parallel_metrics = [
            Metrics(metric_labels_test,
                    metric_predictions_unthresholded,
                    safe=False,
                    parallel=False),
            Metrics(metric_labels_test,
                    metric_predictions,
                    safe=False,
                    parallel=False)
        ]

        def compute_metrics(m):
            return (m.jaccard()[0], m.dice()[0], m.adj_rand()[0],
                    m.warping_error()[0])

        # NOTE(review): this pool is never closed, so workers may accumulate
        # across folds. Mapping a locally defined function also needs a Pool
        # that can serialise closures. Confirm which Pool is imported before
        # adding close()/join().
        pool = Pool(2)
        metric_result = pool.map(compute_metrics, parallel_metrics)
        jaccard_index, dice, adj, warping_error = metric_result[0]
        jaccard_index_to, dice_to, adj_to, warping_error_to = metric_result[1]

        row = [test_data_point_index, jaccard_index, dice, adj, warping_error,
               jaccard_index_to, dice_to, adj_to, warping_error_to]
        with results_file.open("a") as results_handle:
            results_handle.write(";".join(str(value) for value in row) + "\n")

        print("test_data_point_index: " + str(test_data_point_index))
        print("Jaccard index: " + str(jaccard_index) +
              " with threshold optimization: " + str(jaccard_index_to))
        print("Dice: " + str(dice) + " with threshold optimization: " +
              str(dice_to))
        print("Adj: " + str(adj) + " with threshold optimization: " +
              str(adj_to))
        print("Warping Error: " + str(warping_error) +
              " with threshold optimization: " + str(warping_error_to))

        # ---- Plot predictions ---------------------------------------------
        if plot_test:
            # The test split holds exactly one image, so plot index 0. The
            # true mask must come from the TEST labels; the validation labels
            # were shown before.
            fig, ax = plt.subplots(1, 3, figsize=(8, 4))
            fig.suptitle("Test point: " + str(test_data_point_index),
                         fontsize=14)
            ax[0].matshow(original_images[0])
            ax[0].set_title("Input data")
            ax[0].set_axis_off()
            ax[1].matshow(metric_labels_test[0], cmap=plt.cm.gray)
            ax[1].set_title("True mask")
            ax[1].set_axis_off()
            ax[2].matshow(metric_predictions[0], cmap=plt.cm.gray)
            ax[2].set_title("Predicted mask")
            ax[2].set_axis_off()
            fig.tight_layout()
            plt.show()
def test_fit(self, tmp_path):
    """Trainer.fit should batch the datasets and pass the expected callbacks.

    A mocked model records the ``fit``/``evaluate`` calls. The test then
    checks three things:

    - the element shapes and batch size of the datasets handed to
      ``model.fit`` and ``model.evaluate``;
    - that every requested callback type is present;
    - that ``epochs``/``shuffle`` are forwarded.
    """
    output_shape = (8, 8, 2)
    image_shape = (10, 10, 3)
    epochs = 5
    shuffle = True
    batch_size = 10

    model = Mock(name="model")
    # The mock's predict() result reports the network's output shape; the
    # label shapes asserted below are expected to match it.
    model.predict().shape = (None, *output_shape)

    user_callback = Mock()
    trainer = unet.Trainer(
        name="test",
        log_dir_path=str(tmp_path),
        checkpoint_callback=True,
        tensorboard_callback=True,
        tensorboard_images_callback=True,
        callbacks=[user_callback],
        learning_rate_scheduler=unet.SchedulerType.WARMUP_LINEAR_DECAY,
        warmup_proportion=0.1,
        learning_rate=1.0,
    )

    train_ds, validation_ds, test_ds = (
        _build_dataset(image_shape=image_shape) for _ in range(3))

    trainer.fit(model,
                train_dataset=train_ds,
                validation_dataset=validation_ds,
                test_dataset=test_ds,
                epochs=epochs,
                batch_size=batch_size,
                shuffle=shuffle)

    expected_images = (None, *image_shape)
    expected_labels = (None, *output_shape)

    fit_args, fit_kwargs = model.fit.call_args
    for fitted in (fit_args[0], fit_kwargs["validation_data"]):
        assert tuple(fitted.element_spec[0].shape) == expected_images
        assert tuple(fitted.element_spec[1].shape) == expected_labels
        assert fitted._batch_size.numpy() == batch_size

    passed_callbacks = fit_kwargs["callbacks"]
    assert user_callback in passed_callbacks
    passed_types = {type(cb) for cb in passed_callbacks}
    for required in (ModelCheckpoint, TensorBoardWithLearningRate,
                     TensorBoardImageSummary, LearningRateScheduler):
        assert required in passed_types

    assert fit_kwargs["epochs"] == epochs
    assert fit_kwargs["shuffle"] == shuffle

    eval_args, _ = model.evaluate.call_args
    evaluated = eval_args[0]
    assert tuple(evaluated.element_spec[0].shape) == expected_images
    assert tuple(evaluated.element_spec[1].shape) == expected_labels
def main(start_index=0, last_index=199, filename=None, plot=True,
         store_masks=False):
    """Leave-one-out evaluation of a U-Net on the images from ``read_data()``.

    Each fold works as follows:

    - One image is held out as the test point.
    - The rest are randomly permuted.
    - A fresh model is trained on the first 160 of them, with augmentation.
    - A binarisation threshold is tuned on the remaining images.
    - The test metrics are appended to a ``;``-separated results file, once
      at a fixed 0.5 threshold and once at the tuned threshold.

    :param start_index: First data-point index to hold out as test image.
    :param last_index: Last data-point index (inclusive) to hold out.
    :param filename: Results file path; defaults to
        ``results/<yy_mm_dd_HH_MM_SS>.csv``. A header row is written if the
        file does not exist yet.
    :param plot: If True, show 3 validation samples and the test prediction
        every fold.
    :param store_masks: If True, save the validation/test predictions and
        their ground-truth masks as ``results/BBBC039_*_fold_<i>.npy``.
    :return: None
    """
    if filename is None:
        current_dt = datetime.now().strftime("%y_%m_%d_%H_%M_%S")
        filename = "results/" + current_dt + ".csv"
    results_file = Path(filename)
    if not results_file.is_file():
        results_file.write_text(
            'index;jaccard;Dice;Adj;Warp;jaccard_to;Dice_to;Adj_to;Warp_to\n')

    # ---- Load data -------------------------------------------------------
    print("Start read")
    images, labels = read_data()
    print("Done read")

    # Determine largest and smallest pixel values in the dataset
    min_val = float('inf')
    max_val = float('-inf')
    for img in images:
        if np.min(img) < min_val:
            min_val = np.min(img)
        if np.max(img) > max_val:
            max_val = np.max(img)
    print(min_val, max_val)

    # Normalize relative to entire dataset
    images = [(np.expand_dims(image, axis=2) - min_val) / (max_val - min_val)
              for image in images]
    labels = [split_into_classes(label[:, :, :2]) for label in labels]
    print(np.array(images).shape)
    print(np.array(labels).shape)

    for i in range(len(images)):
        images[i] = mirror_pad_image(images[i])
        labels[i] = mirror_pad_image(labels[i])

    print("num images: " + str(len(images)))
    print("num labels: " + str(len(labels)))

    num_data_points = len(images)
    for test_data_point_index in range(start_index,
                                       min(num_data_points, last_index + 1)):
        print("\nStarted for data_point_index: " + str(test_data_point_index))

        images_temp = images.copy()
        labels_temp = labels.copy()
        test_image = images_temp.pop(test_data_point_index)
        test_label = labels_temp.pop(test_data_point_index)
        test_dataset = tf.data.Dataset.from_tensor_slices(
            ([test_image], [test_label]))

        print("num images: " + str(len(images_temp)))
        print("num labels: " + str(len(labels_temp)))

        random_permutation = np.random.permutation(len(images_temp))
        images_temp = np.array(images_temp)[random_permutation]
        labels_temp = np.array(labels_temp)[random_permutation]
        image_dataset = tf.data.Dataset.from_tensor_slices(
            (images_temp, labels_temp))

        # ---- Create data splits ------------------------------------------
        train_dataset = image_dataset.take(160)
        validation_dataset = image_dataset.skip(160)
        # Dataset.shuffle returns a NEW dataset; without the assignment the
        # shuffle was silently discarded.
        train_dataset = train_dataset.shuffle(160, reshuffle_each_iteration=True)
        # Apply transformations to training data
        train_dataset = train_dataset.map(augment_image)

        # ---- Load model --------------------------------------------------
        print(circles.channels)
        print(circles.classes)
        unet_model = unet.build_model(channels=circles.channels,
                                      num_classes=circles.classes,
                                      layer_depth=3,
                                      filters_root=16)
        # Don't track so many metrics
        unet.finalize_model(unet_model,
                            dice_coefficient=False,
                            auc=False,
                            mean_iou=False)

        # ---- Train -------------------------------------------------------
        # Validation data is deliberately not passed to fit (no early
        # stopping); it is used only for threshold tuning below.
        trainer = unet.Trainer(checkpoint_callback=False,
                               tensorboard_callback=False,
                               tensorboard_images_callback=False)
        trainer.fit(unet_model, train_dataset, epochs=40, batch_size=2)

        # ---- Calculate best threshold on the validation split -------------
        prediction = unet_model.predict(validation_dataset.batch(batch_size=1))
        original_images = []
        metric_labels = []
        metric_predictions_unprocessed = []
        dataset = validation_dataset.map(
            utils.crop_image_and_label_to_shape(prediction.shape[1:]))
        for i, (image, label) in enumerate(dataset):
            original_images.append(image[..., -1])
            metric_labels.append(np.argmax(label, axis=-1))
            metric_predictions_unprocessed.append(
                normalize_output(prediction[i, ...]))

        best_tau, best_score = get_best_threshold(
            metric_predictions_unprocessed,
            metric_labels,
            min=0,
            max=1,
            num_steps=50,
            use_metric=1)
        # best_tau = 0.5  # Use this to not threshold at all, also comment above
        print("Best tau: " + str(best_tau))
        print("Best avg score: " + str(best_score))

        metric_predictions = [(pred >= best_tau).astype(int)
                              for pred in metric_predictions_unprocessed]

        if plot:
            fig, ax = plt.subplots(3, 3, sharex=True, sharey=True,
                                   figsize=(8, 8))
            for i in range(3):
                ax[i][0].matshow(original_images[i])
                ax[i][1].matshow(metric_labels[i], cmap=plt.cm.gray)
                ax[i][2].matshow(metric_predictions[i], cmap=plt.cm.gray)
            plt.tight_layout()
            plt.show()

        # ---- Evaluate on the held-out test image and print to file --------
        prediction = unet_model.predict(test_dataset.batch(batch_size=1))
        dataset = test_dataset.map(
            utils.crop_image_and_label_to_shape(prediction.shape[1:]))

        original_images = []
        metric_labels_test = []
        metric_predictions_unprocessed_test = []
        for i, (image, label) in enumerate(dataset):
            original_images.append(image[..., -1])
            metric_labels_test.append(np.argmax(label, axis=-1))
            metric_predictions_unprocessed_test.append(prediction[i, ...])

        metric_predictions = []
        metric_predictions_unthresholded = []
        for pred in metric_predictions_unprocessed_test:
            metric_predictions.append(
                (normalize_output(pred) >= best_tau).astype(int))
            metric_predictions_unthresholded.append(
                (normalize_output(pred) >= 0.5).astype(int))

        # Calculate unthresholded (0.5) and threshold-optimized metrics in parallel
        parallel_metrics = [
            Metrics(metric_labels_test,
                    metric_predictions_unthresholded,
                    safe=False,
                    parallel=False),
            Metrics(metric_labels_test,
                    metric_predictions,
                    safe=False,
                    parallel=False)
        ]

        def compute_metrics(m):
            return (m.jaccard()[0], m.dice()[0], m.adj_rand()[0],
                    m.warping_error()[0])

        # NOTE(review): this pool is never closed, so workers may accumulate
        # across folds. Mapping a locally defined function also needs a Pool
        # that can serialise closures. Confirm which Pool is imported before
        # adding close()/join().
        pool = Pool(2)
        metric_result = pool.map(compute_metrics, parallel_metrics)
        jaccard_index, dice, adj, warping_error = metric_result[0]
        jaccard_index_to, dice_to, adj_to, warping_error_to = metric_result[1]

        row = [test_data_point_index, jaccard_index, dice, adj, warping_error,
               jaccard_index_to, dice_to, adj_to, warping_error_to]
        with results_file.open("a") as results_handle:
            results_handle.write(";".join(str(value) for value in row) + "\n")

        print("test_data_point_index: " + str(test_data_point_index))
        print("Jaccard index: " + str(jaccard_index) +
              " with threshold optimization: " + str(jaccard_index_to))
        print("Dice: " + str(dice) + " with threshold optimization: " +
              str(dice_to))
        print("Adj: " + str(adj) + " with threshold optimization: " +
              str(adj_to))
        print("Warping Error: " + str(warping_error) +
              " with threshold optimization: " + str(warping_error_to))

        # ---- Plot predictions ---------------------------------------------
        if plot:
            fig, ax = plt.subplots(3, 3, sharex=True, sharey=True,
                                   figsize=(8, 8))
            for i in range(len(metric_labels_test)):
                ax[i][0].matshow(original_images[i])
                ax[i][1].matshow(metric_labels_test[i], cmap=plt.cm.gray)
                ax[i][2].matshow(metric_predictions[i], cmap=plt.cm.gray)
            plt.tight_layout()
            plt.show()

        if store_masks:
            fold = str(test_data_point_index)
            np.save("results/BBBC039_val_fold_" + fold + ".npy",
                    metric_predictions_unprocessed)
            np.save("results/BBBC039_val_true_fold_" + fold + ".npy",
                    metric_labels)
            np.save("results/BBBC039_test_fold_" + fold + ".npy",
                    metric_predictions_unprocessed_test)
            # Test ground truth must come from the TEST labels. The
            # validation labels were stored here before.
            np.save("results/BBBC039_test_true_fold_" + fold + ".npy",
                    metric_labels_test)
# Alternative data source (disabled): load the patches from HDF5.
# with h5py.File('../dataset_impl/patches4/train.h5', 'r') as hf:
#     data_train = np.array(hf.get('data'))
#     label_train = np.array(hf.get('label'))

# Split into training and test sets. The numeric arguments are forwarded
# unchanged to minidataset.extract.
(data_train, label_train,
 data_test, label_test) = minidataset.extract(
    '../../dataset_impl/patches4', 20, 6, 15)  # 0,0

# Only the training split feeds the network: 5 input channels,
# 4 output channels, 16 classes.
data_provider = image_util.SimpleDataProvider(
    data_train,
    label_train,
    channels_in=5,
    channels_out=4,
    n_class=16,
)

# Set up the network and train it
net = unet.Unet(channels_in=5, channels_out=4, n_class=16)
trainer = unet.Trainer(net, batch_size=1, optimizer="momentum")  # 10
path = trainer.train(
    data_provider,
    "prediction",
    training_iters=20,
    epochs=6,
)  # 51-100

# Verification on the test split (disabled):
# prediction = net.predict(path, data_test)
# unet.error_rate(prediction, util.crop_to_shape(label_test, prediction.shape))
# true_y = util.to_rgb(util.crop_to_shape(label_test, prediction.shape))
# est_y = util.to_rgb(prediction)
# util.save_image(true_y, 'true_y_fin.jpg')
# util.save_image(est_y, 'est_y_fin.jpg')
# ---- Train ---------------------------------------------------------------
net = unet.Unet(
    channels=generator.channels,
    n_class=generator.n_class,
    cost=para.cost,
    cost_kwargs={"regularizer": para.regularizer},
    layers=para.layers,
    features_root=para.features_root,
    training=True,
)

# Momentum-optimizer alternative (disabled):
# trainer = unet.Trainer(net, batch_size=para.batch_size, optimizer="momentum",
#                        opt_kwargs=dict(momentum=para.momentum,
#                                        learning_rate=para.learning_rate))

# NOTE(review): with optimizer="adam", confirm that this unet.Trainer accepts
# "decay_rate". If the leftover opt_kwargs are forwarded to the Adam
# optimizer constructor, an unknown keyword would raise there.
trainer = unet.Trainer(
    net,
    batch_size=para.batch_size,
    optimizer="adam",
    opt_kwargs={
        "learning_rate": para.learning_rate,
        "decay_rate": para.decay_rate,
    },
)
path = trainer.train(
    generator,
    unet_trained_path,
    training_iters=para.training_iters,
    epochs=para.epochs,
    dropout=para.dropout,
    display_step=para.display_step,
    restore=para.restore,
    prediction_path=prediction_address,
)

# ---- Test on a single generated sample ------------------------------------
x_test, y_test = generator(1)
prediction = net.predict(
    os.path.join(unet_trained_path, 'model.ckpt'),
    x_test,
)
def main(filename=None, calculate_train_metric=False):
    """Train a U-Net on synthetic images, then segment the SciLifeLab images.

    The images in ``data/synthetic`` are randomly permuted. The first 100
    are augmented and used for training; there is no validation step. The
    trained model then predicts masks for the images returned by
    ``scilife_data()``. Up to the first five predictions are saved as
    individual PNGs, plus one combined figure, under ``results/``.

    :param filename: Results file path; defaults to
        ``results/<yy_mm_dd_HH_MM_SS>.csv``. Only the header row is written
        here, and only if the file does not already exist.
    :param calculate_train_metric: If True, finalize the model with the
        library's default metrics. Otherwise dice coefficient, AUC and
        mean IoU tracking are disabled.
    :return: None
    """
    current_dt = datetime.now().strftime("%y_%m_%d_%H_%M_%S")
    if filename is None:
        filename = "results/" + current_dt + ".csv"
    results_file = Path(filename)
    if not results_file.is_file():
        results_file.write_text('index; jaccard; Dice; Adj; Warp\n')

    # ---- Load data -------------------------------------------------------
    image_path = "data/synthetic/images/"
    label_path = "data/synthetic/labels/"
    file_extension = "tif"
    inp_dim = 1024

    file_names = sorted(glob.glob(image_path + "*." + file_extension))
    file_names_labels = sorted(glob.glob(label_path + "*." + file_extension))
    print(file_names)
    print(file_names_labels)

    images = []
    for image_file in file_names:
        if file_extension == "tif":
            image = tf.convert_to_tensor(
                np.expand_dims(plt.imread(image_file), axis=2))
            image = image / 255  # Normalize
            image = tf.image.resize(image, [inp_dim, inp_dim],
                                    preserve_aspect_ratio=True,
                                    method='bilinear')
        elif file_extension == "png":
            image = tf.convert_to_tensor(plt.imread(image_file)[:, :, :3])
            image = tf.image.resize(image, [inp_dim, inp_dim],
                                    preserve_aspect_ratio=True,
                                    method='bilinear')
            image = tf.image.rgb_to_grayscale(image)
        else:
            raise ValueError("Unsupported file extension: " + file_extension)
        images.append(mirror_pad_image(image, pixels=20))

    labels = []
    for label_file in file_names_labels:
        label = np.expand_dims(plt.imread(label_file), axis=2)
        # One-hot encode: non-zero pixels -> [0, 1], zero pixels -> [1, 0]
        label = np.where(label > 0, [0, 1], [1, 0])
        label = tf.image.resize(tf.convert_to_tensor(label),
                                [inp_dim, inp_dim],
                                preserve_aspect_ratio=True,
                                method='bilinear')
        # Re-binarise after bilinear interpolation
        label = np.where(label > 0.5, 1, 0)
        labels.append(mirror_pad_image(label, pixels=20))

    print("num images: " + str(len(images)))
    print("num labels: " + str(len(labels)))

    scilife_images, scilife_labels = scilife_data()

    print("num images: " + str(len(images)))
    print("num labels: " + str(len(labels)))

    random_permutation = np.random.permutation(len(images))
    images_temp = np.array(images)[random_permutation]
    labels_temp = np.array(labels)[random_permutation]
    image_dataset = tf.data.Dataset.from_tensor_slices(
        (images_temp, labels_temp))

    # ---- Training split (images beyond the first 100 are unused) ----------
    train_dataset = image_dataset.take(100)
    # Dataset.shuffle returns a NEW dataset; without the assignment the
    # shuffle was silently discarded.
    train_dataset = train_dataset.shuffle(100, reshuffle_each_iteration=True)
    # Apply transformations to training data
    train_dataset = train_dataset.map(augment_image)

    # ---- Load model ------------------------------------------------------
    print(circles.channels)
    print(circles.classes)
    unet_model = unet.build_model(channels=circles.channels,
                                  num_classes=circles.classes,
                                  layer_depth=3,
                                  filters_root=16)
    if calculate_train_metric:
        unet.finalize_model(unet_model)
    else:
        unet.finalize_model(unet_model,
                            dice_coefficient=False,
                            auc=False,
                            mean_iou=False)

    # ---- Train -----------------------------------------------------------
    trainer = unet.Trainer(checkpoint_callback=False)
    trainer.fit(unet_model, train_dataset, epochs=25, batch_size=1)

    # ---- SciLifeLab data prediction --------------------------------------
    scilife_dataset = tf.data.Dataset.from_tensor_slices(
        (scilife_images, scilife_labels))
    prediction = unet_model.predict(scilife_dataset.batch(batch_size=1))
    dataset = scilife_dataset.map(
        utils.crop_image_and_label_to_shape((inp_dim, inp_dim, 2)))
    prediction = remove_border(prediction, inp_dim, inp_dim)

    original_images = []
    metric_predictions = []
    for i, (image, _) in enumerate(dataset):
        original_images.append(image[..., -1])
        # Hard class assignment per pixel
        metric_predictions.append(np.argmax(prediction[i, ...], axis=-1))

    # Plot at most five predictions; fewer images no longer raise IndexError.
    num_plots = min(5, len(metric_predictions))
    if num_plots == 0:
        return
    fig, ax = plt.subplots(num_plots, 2, sharex=True, sharey=True,
                           figsize=(25, 60), squeeze=False)
    for i in range(num_plots):
        ax[i][0].matshow(original_images[i])
        ax[i][1].matshow(metric_predictions[i], cmap=plt.cm.gray)
        plt.imsave("results/scilifelab_" + str(current_dt) + "_index_" +
                   str(i) + ".png",
                   metric_predictions[i],
                   cmap=plt.cm.gray)
    plt.tight_layout()
    plt.savefig("results/scilifelab_" + str(current_dt) + ".png")
    plt.show()
epochs = 100 dropout = 0.75 # Dropout, probability to keep units display_step = 2 restore = True generator = image_gen.RgbDataProvider(nx, ny, cnt=20, rectangles=False) net = unet.Unet(channels=generator.channels, n_class=generator.n_class, layers=3, features_root=4, cost="IoU") trainer = unet.Trainer(net, optimizer="momentum", opt_kwargs=dict(momentum=0.2, learning_rate=0.1, decay_rate=0.9)) path = trainer.train(generator, "./unet_trained", training_iters=training_iters, epochs=epochs, dropout=dropout, display_step=display_step, restore=restore) x_test, y_test = generator(4) prediction = net.predict(path, x_test) print("Testing error rate: {:.2f}%".format( unet.error_rate(prediction, util.crop_to_shape(y_test,