Exemple #1
0
def train():
    """Train a 3-level, 16-filter U-Net on the synthetic circles dataset.

    The model is compiled via ``unet.finalize_model`` and fitted with a
    warm-up / linear-decay learning-rate schedule for 25 epochs.

    Returns:
        The fitted Keras model.
    """
    model = unet.build_model(
        channels=circles.channels,
        num_classes=circles.classes,
        layer_depth=3,
        filters_root=16,
    )
    unet.finalize_model(model, learning_rate=LEARNING_RATE)

    # warmup_proportion=0.1: a tenth of training is spent warming up the
    # learning rate before the linear decay (per the scheduler's name).
    trainer = unet.Trainer(
        name="circles",
        learning_rate_scheduler=unet.SchedulerType.WARMUP_LINEAR_DECAY,
        warmup_proportion=0.1,
        learning_rate=LEARNING_RATE,
    )

    # Presumably 100 samples of 272x272 pixels with circles of radius <= 20;
    # see the circles module for the exact meaning of these arguments.
    splits = circles.load_data(100, nx=272, ny=272, r_max=20)
    train_ds, validation_ds, test_ds = splits

    trainer.fit(
        model,
        train_ds,
        validation_ds,
        test_ds,
        epochs=25,
        batch_size=5,
    )
    return model
Exemple #2
0
def train():
    """Train a 4-level, 64-filter U-Net on the Oxford-IIIT Pet dataset.

    The network uses ``"same"`` padding and is compiled with sparse
    categorical cross-entropy / sparse categorical accuracy, with AUC
    tracking switched off.

    Returns:
        The fitted Keras model.
    """
    image_size = oxford_iiit_pet.IMAGE_SIZE
    model = unet.build_model(*image_size,
                             channels=oxford_iiit_pet.channels,
                             num_classes=oxford_iiit_pet.classes,
                             layer_depth=4,
                             filters_root=64,
                             padding="same")

    compile_options = {
        "loss": losses.SparseCategoricalCrossentropy(),
        "metrics": [metrics.SparseCategoricalAccuracy()],
        "auc": False,
        "learning_rate": LEARNING_RATE,
    }
    unet.finalize_model(model, **compile_options)

    trainer = unet.Trainer(name="oxford_iiit_pet")
    train_ds, validation_ds = oxford_iiit_pet.load_data()
    trainer.fit(model, train_ds, validation_ds, epochs=25, batch_size=1)
    return model
Exemple #3
0
def train(dataset_root='/home/stephen/Downloads/TF2UNET/unet/src/unet/datasets/Neuron/',
          checkpoint_path='/home/stephen/Downloads/TF2UNET/unet/scripts/circles/2021-04-17T00-12_02/'):
    """Fine-tune a previously saved U-Net on the neuron segmentation images.

    Training images / masks are read from ``Training/Training1/`` and
    ``Ground/``; validation images / masks from ``Testing/`` and
    ``TestGround/`` (all relative to *dataset_root*). The model stored at
    *checkpoint_path* is restored, re-compiled with binary cross-entropy and
    trained for 70 epochs.

    :param dataset_root: Directory holding the Neuron dataset sub-folders.
    :param checkpoint_path: Saved Keras model directory to resume from.
    :return: The fine-tuned Keras model.
    """
    # Normalise to exactly one trailing slash so sub-folder paths concatenate
    # the same way the hard-coded paths did.
    root = dataset_root.rstrip('/') + '/'

    def read_images(subdir):
        # Only the first element of the reader's return value is used.
        return readimagesintonumpyArray(root + subdir)[0]

    train_masks = makegroundtruth2(read_images('Ground/'))
    train_images = maketrainingdata(read_images('Training/Training1/'))
    validation_images = maketrainingdata(read_images('Testing/'))
    validation_masks = makegroundtruth2(read_images('TestGround/'))

    validation_dataset = tf.data.Dataset.from_tensor_slices(
        (validation_images, validation_masks))
    train_dataset = tf.data.Dataset.from_tensor_slices(
        (train_images, train_masks))

    # All weights come from the checkpoint, so no fresh model is built here
    # (a freshly built one would be thrown away immediately).
    unet_model = tf.keras.models.load_model(checkpoint_path,
                                            custom_objects=custom_objects)

    unet.finalize_model(unet_model,
                        loss="binary_crossentropy",
                        learning_rate=LEARNING_RATE)

    # `scheduler` is handed to the Trainer directly; no separate Keras
    # LearningRateScheduler callback is needed.
    trainer = unet.Trainer(name="circles",
                           learning_rate=LEARNING_RATE,
                           tensorboard_callback=True,
                           learning_rate_scheduler=scheduler)

    trainer.fit(unet_model,
                train_dataset,
                validation_dataset,
                epochs=70,
                batch_size=10)

    return unet_model
Exemple #4
0
width = 160
batch_size = 10
train_path = '/DB/rhome/qyzheng/Desktop/qyzheng/source/renji_data/from_senior/0_cv_train.csv'
val_path = '/DB/rhome/qyzheng/Desktop/qyzheng/source/renji_data/from_senior/0_cv_val.csv'

# NOTE(review): `height` is not defined in this snippet; it must be set
# (presumably next to `width`) before this point — TODO confirm its value.
dataset, iters = image_gen.GetDataset(train_path, batch_size)
generator = image_gen.BladderDataProvider(height, width, dataset)

net = unet.Unet(channels=generator.channels, n_class=generator.n_class,
                layers=4, features_root=64)
# NOTE(review): `iters` comes from GetDataset(..., batch_size=10) while the
# trainer below uses batch_size=4. If GetDataset derives iters from the batch
# size, an epoch covers only part of the data — confirm this is intended.
trainer = unet.Trainer(net, batch_size=4, optimizer="momentum",
                       opt_kwargs=dict(momentum=0.2))
path = trainer.train(generator, "../unet_trained", training_iters=iters,
                     epochs=100, display_step=4,
                     prediction_path='/DATA/data/sxfeng/data/IVDM3Seg/result/result_2/prediction')
Exemple #5
0
def main(start_index=0,
         last_index=99,
         filename=None,
         plot_validation=False,
         plot_test=True,
         calculate_train_metric=False):
    """Leave-one-out U-Net training and evaluation on the BBBC004 images.

    For every image index in ``[start_index, last_index]`` (bounded by the
    number of images) that image is held out as the test point. A fresh
    U-Net is trained on 80 of the remaining (randomly permuted) images. A
    binarisation threshold is tuned on the rest. Metrics for the held-out
    image are appended to a ``;``-separated results file.

    :param start_index: First held-out image index to evaluate.
    :param last_index: Last held-out image index to evaluate (inclusive).
    :param filename: Results file path; defaults to
        ``results/<timestamp>.csv``. A header row is written if the file
        does not exist yet.
    :param plot_validation: Plots 3 samples from the validation set each fold.
    :param plot_test: Plots the held-out test image for each fold.
    :param calculate_train_metric: If True, compile with
        ``unet.finalize_model``'s default metrics; otherwise pass
        ``dice_coefficient=False, auc=False, mean_iou=False``.
    :return: None
    :raises ValueError: If ``file_extension`` is neither "tif" nor "png".
    """
    if filename is None:
        now = datetime.now()
        current_dt = now.strftime("%y_%m_%d_%H_%M_%S")
        filename = "results/" + current_dt + ".csv"
    results_file = Path(filename)
    if not results_file.is_file():
        results_file.write_text(
            'index;jaccard;Dice;Adj;Warp;jaccard_to;Dice_to;Adj_to;Warp_to\n')

    # ------------------------------------------------------------ load data
    image_path = "../datasets/BBBC004/images/all/"
    label_path = "../datasets/BBBC004/masks/all/"

    file_extension = "tif"

    inp_dim = 950

    file_names = sorted(glob.glob(image_path + "*." + file_extension))
    file_names_labels = sorted(glob.glob(label_path + "*." + file_extension))

    print(file_names)
    print(file_names_labels)

    # Largest and smallest pixel values over the whole dataset, used for
    # dataset-wide min-max normalisation. A dedicated loop variable keeps the
    # `filename` argument from being shadowed.
    min_val = float('inf')
    max_val = float('-inf')
    for image_file in file_names:
        img = plt.imread(image_file)
        if np.min(img) < min_val:
            min_val = np.min(img)
        if np.max(img) > max_val:
            max_val = np.max(img)

    images = []
    for file in file_names:
        if file_extension == "tif":
            image = tf.convert_to_tensor(
                np.expand_dims(plt.imread(file), axis=2))  # For .tif
            image = (image - min_val) / (max_val - min_val)  # Normalize
            image = tf.image.resize(image, [inp_dim, inp_dim],
                                    preserve_aspect_ratio=True,
                                    method='bilinear')
        elif file_extension == "png":
            image = tf.convert_to_tensor(
                plt.imread(file)[:, :, :3])  # For .png
            image = tf.image.resize(image, [inp_dim, inp_dim],
                                    preserve_aspect_ratio=True,
                                    method='bilinear')
            image = tf.image.rgb_to_grayscale(image)
        else:
            raise ValueError("Unsupported file extension: " + file_extension)

        images.append(mirror_pad_image(image, pixels=21))

    labels = []
    for file in file_names_labels:
        label = plt.imread(file)[:, :, :3]
        label = np.expand_dims(np.sum(label, axis=2), axis=2)

        # Two-channel encoding: any non-zero pixel -> [0, 1], else [1, 0].
        label = np.where(label > 0, [0, 1], [1, 0])
        label = tf.image.resize(tf.convert_to_tensor(label),
                                [inp_dim, inp_dim],
                                preserve_aspect_ratio=True,
                                method='bilinear')
        # Re-binarise after bilinear interpolation.
        label = np.where(label > 0.5, 1, 0)

        labels.append(mirror_pad_image(label, pixels=21))

    print("num images: " + str(len(images)))
    print("num labels: " + str(len(labels)))

    num_data_points = len(images)
    last_fold = min(num_data_points, last_index + 1)

    for test_data_point_index in range(start_index, last_fold):
        print("\nStarted for data_point_index: " + str(test_data_point_index))

        images_temp = images.copy()
        labels_temp = labels.copy()

        test_image = images_temp.pop(test_data_point_index)
        test_label = labels_temp.pop(test_data_point_index)

        test_dataset = tf.data.Dataset.from_tensor_slices(
            ([test_image], [test_label]))

        print("num images: " + str(len(images_temp)))
        print("num labels: " + str(len(labels_temp)))

        random_permutation = np.random.permutation(len(images_temp))
        images_temp = np.array(images_temp)[random_permutation]
        labels_temp = np.array(labels_temp)[random_permutation]

        image_dataset = tf.data.Dataset.from_tensor_slices(
            (images_temp, labels_temp))

        # -------------------------------------------------- data splits
        train_dataset = image_dataset.take(80)
        validation_dataset = image_dataset.skip(80)

        # Dataset.shuffle returns a new dataset; the result must be kept or
        # no shuffling happens.
        train_dataset = train_dataset.shuffle(80, reshuffle_each_iteration=True)

        train_dataset = train_dataset.map(
            augment_image)  # Apply transformations to training data

        # -------------------------------------------------------- model
        print(circles.channels)
        print(circles.classes)

        unet_model = unet.build_model(channels=circles.channels,
                                      num_classes=circles.classes,
                                      layer_depth=3,
                                      filters_root=16)
        if calculate_train_metric:
            unet.finalize_model(unet_model)
        else:
            unet.finalize_model(unet_model,
                                dice_coefficient=False,
                                auc=False,
                                mean_iou=False)  # Don't track so many metrics

        # -------------------------------------------------------- train
        trainer = unet.Trainer(
            checkpoint_callback=False,
            tensorboard_callback=False,
            tensorboard_images_callback=False,
        )
        trainer.fit(
            unet_model,
            train_dataset,
            epochs=40,
            batch_size=2)

        # ------------------------- tune threshold on the validation split
        prediction = unet_model.predict(validation_dataset.batch(batch_size=1))

        original_images = []
        metric_labels = []
        metric_predictions_unprocessed = []

        dataset = validation_dataset.map(
            utils.crop_image_and_label_to_shape((inp_dim, inp_dim, 2)))
        prediction = remove_border(prediction, inp_dim, inp_dim)

        for i, (image, label) in enumerate(dataset):
            original_images.append(image[..., -1])
            metric_labels.append(np.argmax(label, axis=-1))
            metric_predictions_unprocessed.append(
                normalize_output(prediction[i, ...]))

        best_tau, best_score = get_best_threshold(
            metric_predictions_unprocessed,
            metric_labels,
            min=0,
            max=1,
            num_steps=50,
            use_metric=1)

        print("Best tau: " + str(best_tau))
        print("Best avg score: " + str(best_score))

        metric_predictions = [
            (pred >= best_tau).astype(int)
            for pred in metric_predictions_unprocessed
        ]

        if plot_validation:
            fig, ax = plt.subplots(3,
                                   3,
                                   sharex=True,
                                   sharey=True,
                                   figsize=(8, 8))

            # Guard against validation splits with fewer than 3 samples.
            for row in range(min(3, len(original_images))):
                ax[row][0].matshow(original_images[row])
                ax[row][1].matshow(metric_labels[row], cmap=plt.cm.gray)
                ax[row][2].matshow(metric_predictions[row], cmap=plt.cm.gray)

            plt.tight_layout()
            plt.show()

        # ------------------------- evaluate held-out image, write results
        prediction = unet_model.predict(test_dataset.batch(batch_size=1))
        dataset = test_dataset.map(
            utils.crop_image_and_label_to_shape((inp_dim, inp_dim, 2)))
        prediction = remove_border(prediction, inp_dim, inp_dim)
        print("Test shape shape: ", prediction.shape)

        original_images = []
        metric_labels_test = []
        metric_predictions_unprocessed_test = []
        for i, (image, label) in enumerate(dataset):
            original_images.append(image[..., -1])
            metric_labels_test.append(np.argmax(label, axis=-1))
            metric_predictions_unprocessed_test.append(prediction[i, ...])

        # Thresholded at the tuned tau ("_to" columns) and at a fixed 0.5.
        metric_predictions = []
        metric_predictions_unthresholded = []
        for raw_prediction in metric_predictions_unprocessed_test:
            normalized = normalize_output(raw_prediction)
            metric_predictions.append((normalized >= best_tau).astype(int))
            metric_predictions_unthresholded.append(
                (normalized >= 0.5).astype(int))

        # Calculate thresholded and unthresholded metrics in parallel
        parallel_metrics = [
            Metrics(metric_labels_test,
                    metric_predictions_unthresholded,
                    safe=False,
                    parallel=False),
            Metrics(metric_labels_test,
                    metric_predictions,
                    safe=False,
                    parallel=False)
        ]

        def compute_metrics(m):
            return (m.jaccard()[0], m.dice()[0], m.adj_rand()[0],
                    m.warping_error()[0])

        # NOTE(review): the pool is never closed, so each fold may leave
        # workers behind. The safe shutdown depends on which `Pool` this is
        # (multiprocessing: close()/join(); pathos: also clear() before a new
        # pool is created) — confirm. `compute_metrics` is a local function,
        # which the stdlib pickler cannot ship to worker processes, so Pool is
        # presumably a dill-based implementation.
        pool = Pool(2)
        metric_result = pool.map(compute_metrics, parallel_metrics)

        jaccard_index, dice, adj, warping_error = metric_result[0]
        jaccard_index_to, dice_to, adj_to, warping_error_to = metric_result[1]

        row_values = [
            test_data_point_index, jaccard_index, dice, adj, warping_error,
            jaccard_index_to, dice_to, adj_to, warping_error_to
        ]
        with results_file.open("a") as results_handle:
            results_handle.write(";".join(str(v) for v in row_values) + "\n")

        print("test_data_point_index: " + str(test_data_point_index))
        print("Jaccard index: " + str(jaccard_index) +
              " with threshold optimization: " + str(jaccard_index_to))
        print("Dice: " + str(dice) + " with threshold optimization: " +
              str(dice_to))
        print("Adj: " + str(adj) + " with threshold optimization: " +
              str(adj_to))
        print("Warping Error: " + str(warping_error) +
              " with threshold optimization: " + str(warping_error_to))

        # ---------------------------------------------- plot predictions
        if plot_test:
            # The test dataset is built from exactly one held-out image.
            test_idx = 0
            fig, ax = plt.subplots(1, 3, figsize=(8, 4))
            fig.suptitle("Test point: " + str(test_data_point_index),
                         fontsize=14)

            ax[0].matshow(original_images[test_idx])
            ax[0].set_title("Input data")
            ax[0].set_axis_off()

            # Ground truth of the held-out test image (not a validation mask).
            ax[1].matshow(metric_labels_test[test_idx], cmap=plt.cm.gray)
            ax[1].set_title("True mask")
            ax[1].set_axis_off()

            ax[2].matshow(metric_predictions[test_idx], cmap=plt.cm.gray)
            ax[2].set_title("Predicted mask")
            ax[2].set_axis_off()

            fig.tight_layout()
            plt.show()
Exemple #6
0
    def test_fit(self, tmp_path):
        """Trainer.fit batches the datasets, wires callbacks and evaluates.

        A Mock stands in for the Keras model. The test inspects the
        arguments it received from ``fit`` and ``evaluate``.
        """
        epochs, shuffle, batch_size = 5, True, 10
        output_shape = (8, 8, 2)
        image_shape = (10, 10, 3)
        expected_shapes = ((None, *image_shape), (None, *output_shape))

        model = Mock(name="model")
        # Presumably the trainer reads the label crop shape from
        # model.predict(...).shape — assumption, confirm in unet.Trainer.
        model.predict().shape = (None, *output_shape)

        extra_callback = Mock()
        trainer = unet.Trainer(
            name="test",
            log_dir_path=str(tmp_path),
            checkpoint_callback=True,
            tensorboard_callback=True,
            tensorboard_images_callback=True,
            callbacks=[extra_callback],
            learning_rate_scheduler=unet.SchedulerType.WARMUP_LINEAR_DECAY,
            warmup_proportion=0.1,
            learning_rate=1.0)

        splits = {
            name: _build_dataset(image_shape=image_shape)
            for name in ("train", "validation", "test")
        }

        trainer.fit(model,
                    train_dataset=splits["train"],
                    validation_dataset=splits["validation"],
                    test_dataset=splits["test"],
                    epochs=epochs,
                    batch_size=batch_size,
                    shuffle=shuffle)

        def spec_shapes(dataset):
            # (image shape, label shape) of a dataset's element spec.
            image_spec, label_spec = dataset.element_spec[0], dataset.element_spec[1]
            return tuple(image_spec.shape), tuple(label_spec.shape)

        fit_args, fit_kwargs = model.fit.call_args
        fitted_train = fit_args[0]
        fitted_validation = fit_kwargs["validation_data"]

        assert spec_shapes(fitted_train) == expected_shapes
        assert fitted_train._batch_size.numpy() == batch_size
        assert fitted_validation._batch_size.numpy() == batch_size
        assert spec_shapes(fitted_validation) == expected_shapes

        passed_callbacks = fit_kwargs["callbacks"]
        assert extra_callback in passed_callbacks
        passed_types = [type(cb) for cb in passed_callbacks]
        for required in (ModelCheckpoint, TensorBoardWithLearningRate,
                         TensorBoardImageSummary, LearningRateScheduler):
            assert required in passed_types

        assert fit_kwargs["epochs"] == epochs
        assert fit_kwargs["shuffle"] == shuffle

        evaluate_args, _ = model.evaluate.call_args
        assert spec_shapes(evaluate_args[0]) == expected_shapes
def main(start_index=0,
         last_index=199,
         filename=None,
         plot=True,
         store_masks=False):
    """Leave-one-out U-Net training and evaluation on the data from read_data().

    Images are min-max normalised over the whole dataset and mirror-padded.
    For every index in ``[start_index, last_index]`` (bounded by the number
    of images) that image is held out as the test point. A fresh U-Net is
    trained on 160 of the remaining (randomly permuted) images. A
    binarisation threshold is tuned on the rest. Metrics for the held-out
    image are appended to a ``;``-separated results file.

    :param start_index: First held-out image index to evaluate.
    :param last_index: Last held-out image index to evaluate (inclusive).
    :param filename: Results file path; defaults to
        ``results/<timestamp>.csv``. A header row is written if the file
        does not exist yet.
    :param plot: Plot 3 validation samples and the held-out test image(s)
        for each fold.
    :param store_masks: Save per-fold validation / test predictions and the
        matching ground-truth masks as ``results/BBBC039_*_fold_<i>.npy``.
    :return: None
    :raises ValueError: If read_data() returns different numbers of images
        and labels.
    """
    if filename is None:
        now = datetime.now()
        current_dt = now.strftime("%y_%m_%d_%H_%M_%S")
        filename = "results/" + current_dt + ".csv"
    results_file = Path(filename)
    if not results_file.is_file():
        results_file.write_text(
            'index;jaccard;Dice;Adj;Warp;jaccard_to;Dice_to;Adj_to;Warp_to\n')

    # ------------------------------------------------------------ load data
    print("Start read")

    images, labels = read_data()
    print("Done read")

    min_val = float('inf')
    max_val = float('-inf')
    for img in images:
        if np.min(img) < min_val:
            min_val = np.min(img)
        if np.max(img) > max_val:
            max_val = np.max(img)
    print(min_val, max_val)
    # Normalize relative to entire dataset
    images = [(np.expand_dims(image, axis=2) - min_val) / (max_val - min_val)
              for image in images]
    labels = [split_into_classes(label[:, :, :2]) for label in labels]

    print(np.array(images).shape)
    print(np.array(labels).shape)

    if len(images) != len(labels):
        raise ValueError("read_data() returned %d images but %d labels" %
                         (len(images), len(labels)))

    images = [mirror_pad_image(image) for image in images]
    labels = [mirror_pad_image(label) for label in labels]

    print("num images: " + str(len(images)))
    print("num labels: " + str(len(labels)))

    num_data_points = len(images)
    last_fold = min(num_data_points, last_index + 1)

    for test_data_point_index in range(start_index, last_fold):
        print("\nStarted for data_point_index: " + str(test_data_point_index))

        images_temp = images.copy()
        labels_temp = labels.copy()

        test_image = images_temp.pop(test_data_point_index)
        test_label = labels_temp.pop(test_data_point_index)

        test_dataset = tf.data.Dataset.from_tensor_slices(
            ([test_image], [test_label]))

        print("num images: " + str(len(images_temp)))
        print("num labels: " + str(len(labels_temp)))

        random_permutation = np.random.permutation(len(images_temp))
        images_temp = np.array(images_temp)[random_permutation]
        labels_temp = np.array(labels_temp)[random_permutation]

        image_dataset = tf.data.Dataset.from_tensor_slices(
            (images_temp, labels_temp))

        # -------------------------------------------------- data splits
        train_dataset = image_dataset.take(160)
        validation_dataset = image_dataset.skip(160)

        # Dataset.shuffle returns a new dataset; the result must be kept or
        # no shuffling happens.
        train_dataset = train_dataset.shuffle(160, reshuffle_each_iteration=True)

        train_dataset = train_dataset.map(
            augment_image)  # Apply transformations to training data

        # -------------------------------------------------------- model
        print(circles.channels)
        print(circles.classes)

        unet_model = unet.build_model(channels=circles.channels,
                                      num_classes=circles.classes,
                                      layer_depth=3,
                                      filters_root=16)
        unet.finalize_model(unet_model,
                            dice_coefficient=False,
                            auc=False,
                            mean_iou=False)  # Don't track so many metrics

        # -------------------------------------------------------- train
        trainer = unet.Trainer(
            checkpoint_callback=False,
            tensorboard_callback=False,
            tensorboard_images_callback=False,
        )
        trainer.fit(
            unet_model,
            train_dataset,
            epochs=40,
            batch_size=2)

        # ------------------------- tune threshold on the validation split
        prediction = unet_model.predict(validation_dataset.batch(batch_size=1))

        original_images = []
        metric_labels = []
        metric_predictions_unprocessed = []

        dataset = validation_dataset.map(
            utils.crop_image_and_label_to_shape(prediction.shape[1:]))

        for i, (image, label) in enumerate(dataset):
            original_images.append(image[..., -1])
            metric_labels.append(np.argmax(label, axis=-1))
            metric_predictions_unprocessed.append(
                normalize_output(prediction[i, ...]))

        best_tau, best_score = get_best_threshold(
            metric_predictions_unprocessed,
            metric_labels,
            min=0,
            max=1,
            num_steps=50,
            use_metric=1)

        print("Best tau: " + str(best_tau))
        print("Best avg score: " + str(best_score))

        metric_predictions = [
            (pred >= best_tau).astype(int)
            for pred in metric_predictions_unprocessed
        ]

        if plot:
            fig, ax = plt.subplots(3,
                                   3,
                                   sharex=True,
                                   sharey=True,
                                   figsize=(8, 8))

            # Guard against validation splits with fewer than 3 samples.
            for row in range(min(3, len(original_images))):
                ax[row][0].matshow(original_images[row])
                ax[row][1].matshow(metric_labels[row], cmap=plt.cm.gray)
                ax[row][2].matshow(metric_predictions[row], cmap=plt.cm.gray)

            plt.tight_layout()
            plt.show()

        # ------------------------- evaluate held-out image, write results
        prediction = unet_model.predict(test_dataset.batch(batch_size=1))
        dataset = test_dataset.map(
            utils.crop_image_and_label_to_shape(prediction.shape[1:]))

        original_images = []
        metric_labels_test = []
        metric_predictions_unprocessed_test = []
        for i, (image, label) in enumerate(dataset):
            original_images.append(image[..., -1])
            metric_labels_test.append(np.argmax(label, axis=-1))
            metric_predictions_unprocessed_test.append(prediction[i, ...])

        # Thresholded at the tuned tau ("_to" columns) and at a fixed 0.5.
        metric_predictions = []
        metric_predictions_unthresholded = []
        for raw_prediction in metric_predictions_unprocessed_test:
            normalized = normalize_output(raw_prediction)
            metric_predictions.append((normalized >= best_tau).astype(int))
            metric_predictions_unthresholded.append(
                (normalized >= 0.5).astype(int))

        # Calculate thresholded and unthresholded metrics in parallel
        parallel_metrics = [
            Metrics(metric_labels_test,
                    metric_predictions_unthresholded,
                    safe=False,
                    parallel=False),
            Metrics(metric_labels_test,
                    metric_predictions,
                    safe=False,
                    parallel=False)
        ]

        def compute_metrics(m):
            return (m.jaccard()[0], m.dice()[0], m.adj_rand()[0],
                    m.warping_error()[0])

        # NOTE(review): the pool is never closed, so each fold may leave
        # workers behind. The safe shutdown depends on which `Pool` this is
        # (multiprocessing: close()/join(); pathos: also clear() before a new
        # pool is created) — confirm. `compute_metrics` is a local function,
        # which the stdlib pickler cannot ship to worker processes, so Pool is
        # presumably a dill-based implementation.
        pool = Pool(2)
        metric_result = pool.map(compute_metrics, parallel_metrics)

        jaccard_index, dice, adj, warping_error = metric_result[0]
        jaccard_index_to, dice_to, adj_to, warping_error_to = metric_result[1]

        row_values = [
            test_data_point_index, jaccard_index, dice, adj, warping_error,
            jaccard_index_to, dice_to, adj_to, warping_error_to
        ]
        with results_file.open("a") as results_handle:
            results_handle.write(";".join(str(v) for v in row_values) + "\n")

        print("test_data_point_index: " + str(test_data_point_index))
        print("Jaccard index: " + str(jaccard_index) +
              " with threshold optimization: " + str(jaccard_index_to))
        print("Dice: " + str(dice) + " with threshold optimization: " +
              str(dice_to))
        print("Adj: " + str(adj) + " with threshold optimization: " +
              str(adj_to))
        print("Warping Error: " + str(warping_error) +
              " with threshold optimization: " + str(warping_error_to))

        # ---------------------------------------------- plot predictions
        if plot:
            fig, ax = plt.subplots(3,
                                   3,
                                   sharex=True,
                                   sharey=True,
                                   figsize=(8, 8))

            for row in range(len(metric_labels_test)):
                ax[row][0].matshow(original_images[row])
                ax[row][1].matshow(metric_labels_test[row], cmap=plt.cm.gray)
                ax[row][2].matshow(metric_predictions[row], cmap=plt.cm.gray)

            plt.tight_layout()
            plt.show()

        if store_masks:
            fold = str(test_data_point_index)
            np.save("results/BBBC039_val_fold_" + fold + ".npy",
                    metric_predictions_unprocessed)
            np.save("results/BBBC039_val_true_fold_" + fold + ".npy",
                    metric_labels)
            np.save("results/BBBC039_test_fold_" + fold + ".npy",
                    metric_predictions_unprocessed_test)
            # Ground truth of the held-out test image (not the validation
            # labels), so it pairs with the test predictions saved above.
            np.save("results/BBBC039_test_true_fold_" + fold + ".npy",
                    metric_labels_test)
Exemple #8
0
#with h5py.File('../dataset_impl/patches4/train.h5', 'r') as hf:
#    data_train = np.array(hf.get('data'))
#    label_train = np.array(hf.get('label'))

## Split into training and test sets.
# The positional arguments (20, 6, 15) are interpreted by minidataset.extract;
# the trailing "0,0" note is the author's — TODO document both.
data_train, label_train, data_test, label_test = minidataset.extract(
    '../../dataset_impl/patches4', 20, 6, 15)  #0,0

# Channel / class layout shared by the data provider and the network.
channel_config = dict(channels_in=5, channels_out=4, n_class=16)
data_provider = image_util.SimpleDataProvider(data_train, label_train,
                                              **channel_config)

## Setup & training
net = unet.Unet(**channel_config)
trainer = unet.Trainer(net, batch_size=1, optimizer="momentum")  # author note: 10
path = trainer.train(data_provider, "prediction",
                     training_iters=20, epochs=6)  # author note: 51-100

#verification

#prediction = net.predict(path, data_test) #data=testset

#unet.error_rate(prediction, util.crop_to_shape(label_test, prediction.shape))

#modified through reshape

#true_y=util.to_rgb(util.crop_to_shape(label_test, prediction.shape))
#est_y=util.to_rgb(prediction)
#util.save_image(true_y, 'true_y_fin.jpg')
#util.save_image(est_y, 'est_y_fin.jpg')
Exemple #9
0
# --- Train -------------------------------------------------------------

net = unet.Unet(
    channels=generator.channels,
    n_class=generator.n_class,
    cost=para.cost,
    cost_kwargs={"regularizer": para.regularizer},
    layers=para.layers,
    features_root=para.features_root,
    training=True,
)

# Adam optimizer; learning_rate and decay_rate come from `para`.
# (An alternative momentum setup used optimizer="momentum" with
#  opt_kwargs=dict(momentum=para.momentum, learning_rate=para.learning_rate).)
optimizer_kwargs = {
    "learning_rate": para.learning_rate,
    "decay_rate": para.decay_rate,
}
trainer = unet.Trainer(net,
                       batch_size=para.batch_size,
                       optimizer="adam",
                       opt_kwargs=optimizer_kwargs)

train_options = {
    "training_iters": para.training_iters,
    "epochs": para.epochs,
    "dropout": para.dropout,
    "display_step": para.display_step,
    "restore": para.restore,
    "prediction_path": prediction_address,
}
path = trainer.train(generator, unet_trained_path, **train_options)

# --- Predict on a single freshly generated sample -----------------------
x_test, y_test = generator(1)
checkpoint_file = os.path.join(unet_trained_path, 'model.ckpt')
prediction = net.predict(checkpoint_file, x_test)
Exemple #10
0
def _load_synthetic_images(file_names, file_extension, inp_dim, pad_pixels=20):
    """Load input images, resize them to fit ``inp_dim`` and mirror-pad them.

    :param file_names: paths of the image files, in the order to return them.
    :param file_extension: ``"tif"`` or ``"png"``; selects the decoding path.
    :param inp_dim: target side length passed to ``tf.image.resize``
        (aspect ratio is preserved).
    :param pad_pixels: border width passed to ``mirror_pad_image``.
    :return: list of image tensors with one channel each (assuming 2-D TIFFs).
    :raises ValueError: for an unsupported ``file_extension``.
    """
    images = []
    for file in file_names:
        if file_extension == "tif":
            # Add a trailing channel axis and scale intensities by 1/255.
            # NOTE(review): assumes 8-bit, single-channel TIFFs — confirm.
            image = tf.convert_to_tensor(np.expand_dims(plt.imread(file), axis=2))
            image = image / 255
            image = tf.image.resize(image, [inp_dim, inp_dim],
                                    preserve_aspect_ratio=True,
                                    method='bilinear')
        elif file_extension == "png":
            # Keep only the RGB channels (drop alpha), resize, then convert to
            # a single grayscale channel.
            image = tf.convert_to_tensor(plt.imread(file)[:, :, :3])
            image = tf.image.resize(image, [inp_dim, inp_dim],
                                    preserve_aspect_ratio=True,
                                    method='bilinear')
            image = tf.image.rgb_to_grayscale(image)
        else:
            raise ValueError(
                "Unsupported image extension: {!r}".format(file_extension))
        images.append(mirror_pad_image(image, pixels=pad_pixels))
    return images


def _load_synthetic_labels(file_names, inp_dim, pad_pixels=20):
    """Load label masks as two-class one-hot arrays, resized and mirror-padded.

    Pixels with a value > 0 become foreground ``[0, 1]``; all others become
    background ``[1, 0]``.

    :param file_names: paths of the label files, in the order to return them.
    :param inp_dim: target side length passed to ``tf.image.resize``
        (aspect ratio is preserved).
    :param pad_pixels: border width passed to ``mirror_pad_image``.
    :return: list of padded 0/1 one-hot label arrays, one per file.
    """
    labels = []
    for file in file_names:
        # NOTE(review): assumes single-channel (2-D) masks; an RGB mask would
        # gain an extra axis here.
        label = np.expand_dims(plt.imread(file), axis=2)
        label = np.where(label > 0, [0, 1], [1, 0])
        label = tf.image.resize(tf.convert_to_tensor(label), [inp_dim, inp_dim],
                                preserve_aspect_ratio=True,
                                method='bilinear')
        # Bilinear interpolation yields fractional values along mask edges;
        # threshold back to a hard 0/1 encoding.
        label = np.where(label > 0.5, 1, 0)
        labels.append(mirror_pad_image(label, pixels=pad_pixels))
    return labels


def main(filename=None, calculate_train_metric=False):
    """Train a U-Net on the synthetic data set and segment the SciLifeLab images.

    Loads images and labels from ``data/synthetic``, trains for 25 epochs on
    the first 100 samples of a random permutation (with augmentation), predicts
    on the data returned by ``scilife_data()`` and writes the first (up to)
    five predictions to ``results/`` as PNG files, plus one overview figure.

    :param filename: path of the results CSV. Defaults to
        ``results/<yy_mm_dd_HH_MM_SS>.csv``. If the file does not exist it is
        created containing only a header row.
    :param calculate_train_metric: if True, finalize the model with
        ``unet.finalize_model``'s default metrics; otherwise the Dice
        coefficient, AUC and mean-IoU metrics are disabled.
    :return: None
    :raises ValueError: if no images are found or the number of images and
        labels differ.
    """
    current_dt = datetime.now().strftime("%y_%m_%d_%H_%M_%S")
    if filename is None:
        filename = "results/" + current_dt + ".csv"
    results_file = Path(filename)
    # write_text, imsave and savefig all fail if their directory is missing.
    results_file.parent.mkdir(parents=True, exist_ok=True)
    Path("results").mkdir(exist_ok=True)
    if not results_file.is_file():
        # NOTE(review): only the header is written; this function appends no rows.
        results_file.write_text('index; jaccard; Dice; Adj; Warp\n')

    # --- Load data ---
    image_path = "data/synthetic/images/"
    label_path = "data/synthetic/labels/"
    file_extension = "tif"
    # Other input sizes used previously: 572, 200, 710.
    inp_dim = 1024

    # Sorting pairs images with labels positionally; this relies on image and
    # label files having matching names.
    file_names = sorted(glob.glob(image_path + "*." + file_extension))
    file_names_labels = sorted(glob.glob(label_path + "*." + file_extension))
    print(file_names)
    print(file_names_labels)

    images = _load_synthetic_images(file_names, file_extension, inp_dim)
    labels = _load_synthetic_labels(file_names_labels, inp_dim)
    print("num images: " + str(len(images)))
    print("num labels: " + str(len(labels)))
    if not images:
        raise ValueError("No images found in " + image_path)
    if len(images) != len(labels):
        # A count mismatch would silently pair images with the wrong labels.
        raise ValueError("Found {} images but {} labels".format(
            len(images), len(labels)))

    scilife_images, scilife_labels = scilife_data()

    # Permute once so that the training split is a random subset.
    random_permutation = np.random.permutation(len(images))
    image_dataset = tf.data.Dataset.from_tensor_slices(
        (np.array(images)[random_permutation],
         np.array(labels)[random_permutation]))

    # --- Create data splits ---
    # The first 100 permuted samples are used for training. The remainder
    # (image_dataset.skip(100)) is held out but not currently passed to fit.
    train_dataset = image_dataset.take(100)
    # Dataset.shuffle returns a new dataset; it must be reassigned or the
    # per-epoch reshuffling never happens.
    train_dataset = train_dataset.shuffle(100, reshuffle_each_iteration=True)
    # Apply augmentation to the training data only.
    train_dataset = train_dataset.map(augment_image)

    # --- Load model ---
    print(circles.channels)
    print(circles.classes)

    unet_model = unet.build_model(channels=circles.channels,
                                  num_classes=circles.classes,
                                  layer_depth=3,
                                  filters_root=16)
    if calculate_train_metric:
        unet.finalize_model(unet_model)
    else:
        unet.finalize_model(unet_model,
                            dice_coefficient=False,
                            auc=False,
                            mean_iou=False)

    # --- Train ---
    # Early-stopping alternative:
    # callback = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)
    # trainer = unet.Trainer(checkpoint_callback=False, callbacks=[callback])
    trainer = unet.Trainer(checkpoint_callback=False)
    trainer.fit(unet_model, train_dataset, epochs=25, batch_size=1)

    # --- SciLifeLab data prediction ---
    scilife_dataset = tf.data.Dataset.from_tensor_slices(
        (scilife_images, scilife_labels))
    prediction = unet_model.predict(scilife_dataset.batch(batch_size=1))

    # Presumably crops inputs and strips the padding from predictions so both
    # end up at (inp_dim, inp_dim) — confirm against the helpers.
    dataset = scilife_dataset.map(
        utils.crop_image_and_label_to_shape((inp_dim, inp_dim, 2)))
    prediction = remove_border(prediction, inp_dim, inp_dim)

    # Per-class weights applied to the model output before argmax; all ones
    # is equivalent to a plain argmax.
    class_weights = np.array([[[1, 1]]])
    original_images = []
    metric_predictions = []
    for i, (image, _) in enumerate(dataset):
        original_images.append(image[..., -1])
        metric_predictions.append(
            np.argmax(prediction[i, ...] * class_weights, axis=-1))

    num_plots = min(5, len(metric_predictions))
    if num_plots == 0:
        print("No SciLifeLab predictions to plot.")
        return

    # squeeze=False keeps `ax` 2-D even when only one row is plotted.
    _, ax = plt.subplots(num_plots, 2, sharex=True, sharey=True,
                         figsize=(25, 60), squeeze=False)
    for i in range(num_plots):
        ax[i][0].matshow(original_images[i])
        ax[i][1].matshow(metric_predictions[i], cmap=plt.cm.gray)
        plt.imsave("results/scilifelab_" + current_dt + "_index_" + str(i) +
                   ".png",
                   metric_predictions[i],
                   cmap=plt.cm.gray)

    plt.tight_layout()
    plt.savefig("results/scilifelab_" + current_dt + ".png")
    plt.show()
Exemple #11
0
    epochs = 100
    dropout = 0.75  # Dropout, probability to keep units
    display_step = 2
    restore = True

    generator = image_gen.RgbDataProvider(nx, ny, cnt=20, rectangles=False)

    net = unet.Unet(channels=generator.channels,
                    n_class=generator.n_class,
                    layers=3,
                    features_root=4,
                    cost="IoU")

    trainer = unet.Trainer(net,
                           optimizer="momentum",
                           opt_kwargs=dict(momentum=0.2,
                                           learning_rate=0.1,
                                           decay_rate=0.9))
    path = trainer.train(generator,
                         "./unet_trained",
                         training_iters=training_iters,
                         epochs=epochs,
                         dropout=dropout,
                         display_step=display_step,
                         restore=restore)

    x_test, y_test = generator(4)
    prediction = net.predict(path, x_test)

    print("Testing error rate: {:.2f}%".format(
        unet.error_rate(prediction, util.crop_to_shape(y_test,