Example #1
def run(gParameters):

    # load data
    x_train, y_train = unet.load_data()

    # example has 420 x 580
    model = unet.build_model(420, 580, gParameters['activation'],
                             gParameters['kernel_initializer'])

    model.summary()
    model.compile(optimizer=gParameters['optimizer'],
                  loss='binary_crossentropy',
                  metrics=['accuracy'])

    model_chkpoint = ModelCheckpoint('unet.hdf5',
                                     monitor='loss',
                                     verbose=1,
                                     save_best_only=True)
    history = model.fit(x_train,
                        y_train,
                        batch_size=gParameters['batch_size'],
                        epochs=gParameters['epochs'],
                        verbose=1,
                        validation_split=0.3,
                        shuffle=True,
                        callbacks=[model_chkpoint])

    return history
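
A minimal usage sketch for run(): the gParameters keys below are the ones read inside the function, while the concrete values are placeholders rather than the original defaults.

if __name__ == '__main__':
    # Hypothetical hyperparameter dictionary; keys match those used in run().
    gParameters = {
        'activation': 'relu',
        'kernel_initializer': 'he_normal',
        'optimizer': 'adam',
        'batch_size': 32,
        'epochs': 20,
    }
    history = run(gParameters)
    print(history.history.keys())
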
Example #2
def train():
    unet_model = unet.build_model(*oxford_iiit_pet.IMAGE_SIZE,
                                  channels=oxford_iiit_pet.channels,
                                  num_classes=oxford_iiit_pet.classes,
                                  layer_depth=4,
                                  filters_root=64,
                                  padding="same")

    unet.finalize_model(unet_model,
                        loss=losses.SparseCategoricalCrossentropy(),
                        metrics=[metrics.SparseCategoricalAccuracy()],
                        auc=False,
                        learning_rate=LEARNING_RATE)

    trainer = unet.Trainer(name="oxford_iiit_pet")

    train_dataset, validation_dataset = oxford_iiit_pet.load_data()

    trainer.fit(unet_model,
                train_dataset,
                validation_dataset,
                epochs=25,
                batch_size=1)

    return unet_model
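
Once train() returns, the model can be used for inference in the usual Keras way; a small sketch, assuming the same oxford_iiit_pet validation split loaded above:

def predict_validation():
    # Train the model and run it over the batched validation set.
    unet_model = train()
    _, validation_dataset = oxford_iiit_pet.load_data()
    predictions = unet_model.predict(validation_dataset.batch(1))
    print(predictions.shape)  # (num_samples, height, width, num_classes) with padding="same"
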
Example #3
def train():
    unet_model = unet.build_model(channels=circles.channels,
                                  num_classes=circles.classes,
                                  layer_depth=3,
                                  filters_root=16)

    unet.finalize_model(unet_model, learning_rate=LEARNING_RATE)

    trainer = unet.Trainer(
        name="circles",
        learning_rate_scheduler=unet.SchedulerType.WARMUP_LINEAR_DECAY,
        warmup_proportion=0.1,
        learning_rate=LEARNING_RATE)

    train_dataset, validation_dataset, test_dataset = circles.load_data(
        100, nx=272, ny=272, r_max=20)

    trainer.fit(unet_model,
                train_dataset,
                validation_dataset,
                test_dataset,
                epochs=25,
                batch_size=5)

    return unet_model
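
LEARNING_RATE is a module-level constant in the original script; the value below is an assumption. A small sketch of running the returned model on the held-out circles test split:

LEARNING_RATE = 1e-3  # placeholder; the original script defines its own value

def predict_test():
    unet_model = train()
    _, _, test_dataset = circles.load_data(100, nx=272, ny=272, r_max=20)
    # With the library's default "valid" padding the prediction can be smaller than
    # the 272x272 input, so only the output shape is inspected here.
    prediction = unet_model.predict(test_dataset.batch(5))
    print(prediction.shape)
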
Example #4
def run(
        model_name,
        model_type=None,
        loss=unet.IOU_inverse,
        flip=False,
        translate=False,
        rotate=False,
        brightness=False,  # augmentation options
        batchnorm=False,
        droprate=None,
        regularizer=None,  # model options
):

    print(model_name)

    cp_path = os.path.join(
        OUTDIR, '{}-checkpoint.h5'.format(model_name))  #checkpoint path
    model_path = os.path.join(OUTDIR,
                              '{}.h5'.format(model_name))  #final model path

    #create generators
    train_gen = Generator(scale=SCALE,
                          batch_size=BATCHSIZE,
                          train=True,
                          flip=flip,
                          translate=translate,
                          rotate=rotate,
                          brightness=brightness)
    val_gen = Generator(scale=SCALE, batch_size=BATCHSIZE, train=False)

    # build model
    model = unet.build_model(scale_factor=SCALE,
                             verbose=1,
                             loss=loss,
                             batchnorm=batchnorm,
                             droprate=droprate,
                             regularizer=regularizer)
    model.name = model_name

    # train!
    earlystop = EarlyStopping(patience=PATIENCE, verbose=1)
    checkpoint = ModelCheckpoint(cp_path, verbose=0, save_best_only=True)

    model_hist = model.fit_generator(
        generator=train_gen,
        steps_per_epoch=len(train_gen.image_IDs) // train_gen.batch_size,
        validation_data=val_gen,
        validation_steps=len(val_gen.image_IDs) // val_gen.batch_size,
        epochs=EPOCHS,
        callbacks=[earlystop, checkpoint],
        verbose=2)

    #save weights
    model.save(model_path)
    print("{} saved to disk".format(model.name))

    print(model_hist.history, '\n')
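
Example #4 relies on module-level names (OUTDIR, SCALE, BATCHSIZE, PATIENCE, EPOCHS) and callback imports defined elsewhere in the original script; a hypothetical setup block with placeholder values might look like this:

import os
from keras.callbacks import EarlyStopping, ModelCheckpoint  # or tensorflow.keras.callbacks

OUTDIR = 'models'   # placeholder: where checkpoints and final weights are written
SCALE = 4           # placeholder: scale factor shared by the Generator and build_model
BATCHSIZE = 8       # placeholder batch size
PATIENCE = 10       # placeholder early-stopping patience
EPOCHS = 100        # placeholder epoch budget
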
Example #5
def train():

    callback = tf.keras.callbacks.LearningRateScheduler(scheduler)
    Ground_truth = readimagesintonumpyArray(
        '/home/stephen/Downloads/TF2UNET/unet/src/unet/datasets/Neuron/Ground/'
    )[0]
    Ground_truth = makegroundtruth2(Ground_truth)
    train_data = readimagesintonumpyArray(
        '/home/stephen/Downloads/TF2UNET/unet/src/unet/datasets/Neuron/Training/Training1/'
    )[0]
    train_data = maketrainingdata(train_data)
    validation_imgz = readimagesintonumpyArray(
        '/home/stephen/Downloads/TF2UNET/unet/src/unet/datasets/Neuron/Testing/'
    )[0]
    validation_imgz = maketrainingdata(validation_imgz)
    validation_groundtruth = readimagesintonumpyArray(
        '/home/stephen/Downloads/TF2UNET/unet/src/unet/datasets/Neuron/TestGround/'
    )[0]
    validation_groundtruth = makegroundtruth2(validation_groundtruth)

    validation_dataset = tf.data.Dataset.from_tensor_slices(
        (validation_imgz, validation_groundtruth))
    train_dataset = tf.data.Dataset.from_tensor_slices(
        (train_data, Ground_truth))

    unet_model = unet.build_model(channels=1,
                                  num_classes=2,
                                  layer_depth=3,
                                  filters_root=16)
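    # Note: the model built above is immediately replaced by the checkpoint that
    # load_model() restores below, so this build_model() call is effectively unused.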

    unet_model = tf.keras.models.load_model(
        '/home/stephen/Downloads/TF2UNET/unet/scripts/circles/2021-04-17T00-12_02/',
        custom_objects=custom_objects)

    unet.finalize_model(unet_model,
                        loss="binary_crossentropy",
                        learning_rate=LEARNING_RATE)

    trainer = unet.Trainer(name="circles",
                           learning_rate=LEARNING_RATE,
                           tensorboard_callback=True,
                           learning_rate_scheduler=scheduler)

    trainer.fit(unet_model,
                train_dataset,
                validation_dataset,
                epochs=70,
                batch_size=10)

    return unet_model
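
The scheduler function passed to LearningRateScheduler (and to the Trainer) above is defined elsewhere in the original script; a minimal step-decay sketch in the shape Keras expects, with an assumed schedule:

def scheduler(epoch, lr):
    # Hypothetical schedule: hold the initial learning rate for 10 epochs,
    # then decay it exponentially; Keras calls this once per epoch.
    if epoch < 10:
        return lr
    return lr * tf.math.exp(-0.1)
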
Example #6
def main(start_index=0,
         last_index=99,
         filename=None,
         plot_validation=False,
         plot_test=True,
         calculate_train_metric=False):
    """

    :param start_index:
    :param filename:
    :param plot_validation: Plots 3 samples from the validation set each fold
    :param plot_test:  Plots the test test image for each fold
    :return:
    """
    if filename is None:
        now = datetime.now()
        current_dt = now.strftime("%y_%m_%d_%H_%M_%S")
        filename = "results/" + current_dt + ".csv"
    results_file = Path(filename)
    if not results_file.is_file():
        results_file.write_text(
            'index;jaccard;Dice;Adj;Warp;jaccard_to;Dice_to;Adj_to;Warp_to\n')
    """ Load data """
    #image_path = "data/BBBC004_v1_images/*/"
    #label_path = "data/BBBC004_v1_foreground/*/"
    image_path = "../datasets/BBBC004/images/all/"
    label_path = "../datasets/BBBC004/masks/all/"

    file_extension = "tif"

    inp_dim = 950

    file_names = sorted(glob.glob(image_path + "*." + file_extension))
    file_names_labels = sorted(glob.glob(label_path + "*." + file_extension))

    print(file_names)
    print(file_names_labels)

    # Determine largest and smallest pixel values in the dataset
    min_val = float('inf')
    max_val = float('-inf')
    for image_file in file_names:  # avoid shadowing the `filename` results path
        img = plt.imread(image_file)
        if np.min(img) < min_val:
            min_val = np.min(img)
        if np.max(img) > max_val:
            max_val = np.max(img)

    images = []
    for file in file_names:
        if file_extension == "tif":
            images.append(
                tf.convert_to_tensor(np.expand_dims(plt.imread(file),
                                                    axis=2)))  # For .tif
            #images[-1] = images[-1] / 255  # Normalize
            images[-1] = (images[-1] - min_val) / (max_val - min_val)
            images[-1] = tf.image.resize(images[-1], [inp_dim, inp_dim],
                                         preserve_aspect_ratio=True,
                                         method='bilinear')
            #print(np.min(images[-1]), np.max(images[-1]))
        elif file_extension == "png":
            images.append(tf.convert_to_tensor(
                plt.imread(file)[:, :, :3]))  # For .png
            images[-1] = tf.image.resize(images[-1], [inp_dim, inp_dim],
                                         preserve_aspect_ratio=True,
                                         method='bilinear')
            images[-1] = tf.image.rgb_to_grayscale(images[-1])

        images[-1] = mirror_pad_image(images[-1], pixels=21)

    labels = []
    for file in file_names_labels:
        label = plt.imread(file)[:, :, :3]
        label = (np.expand_dims(np.sum(label, axis=2), axis=2))

        label = np.where(label > 0, [0, 1], [1, 0])
        labels.append(tf.convert_to_tensor(label))

        labels[-1] = tf.image.resize(labels[-1], [inp_dim, inp_dim],
                                     preserve_aspect_ratio=True,
                                     method='bilinear')
        labels[-1] = np.where(labels[-1] > 0.5, 1, 0)

        labels[-1] = mirror_pad_image(labels[-1], pixels=21)

    print("num images: " + str(len(images)))
    print("num labels: " + str(len(labels)))

    num_data_points = len(images)

    for test_data_point_index in range(start_index, num_data_points):
        if test_data_point_index > last_index:
            break
        print("\nStarted for data_point_index: " + str(test_data_point_index))

        images_temp = images.copy()
        labels_temp = labels.copy()
        """for i in range((5)):
            plt.matshow(images_temp[i][..., -1])
            plt.show()
            plt.matshow(np.argmax(labels_temp[i], axis=-1), cmap=plt.cm.gray)
            plt.show()"""

        test_image = images_temp.pop(test_data_point_index)
        test_label = labels_temp.pop(test_data_point_index)

        test_dataset = tf.data.Dataset.from_tensor_slices(
            ([test_image], [test_label]))

        print("num images: " + str(len(images_temp)))
        print("num labels: " + str(len(labels_temp)))

        random_permutation = np.random.permutation(len(images_temp))
        images_temp = np.array(images_temp)[random_permutation]
        labels_temp = np.array(labels_temp)[random_permutation]

        image_dataset = tf.data.Dataset.from_tensor_slices(
            (images_temp, labels_temp))
        """Crate data splits"""
        data_augmentation = tf.keras.Sequential([
            tf.keras.layers.experimental.preprocessing.RandomFlip(
                "horizontal_and_vertical"),
            tf.keras.layers.experimental.preprocessing.RandomRotation(0.2),
        ])
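        # Note: this augmentation pipeline is defined but never applied; the training
        # data is augmented via augment_image() below instead.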

        # image_dataset.shuffle(100, reshuffle_each_iteration=False)

        train_dataset = image_dataset.take(80)
        validation_dataset = image_dataset.skip(80)

        train_dataset = train_dataset.shuffle(80, reshuffle_each_iteration=True)

        train_dataset = train_dataset.map(
            augment_image)  # Apply transformations to training data
        """Load model"""
        print(circles.channels)
        print(circles.classes)

        unet_model = unet.build_model(channels=circles.channels,
                                      num_classes=circles.classes,
                                      layer_depth=3,
                                      filters_root=16)
        if calculate_train_metric:
            unet.finalize_model(unet_model)
        else:
            unet.finalize_model(unet_model,
                                dice_coefficient=False,
                                auc=False,
                                mean_iou=False)  # Don't track so many metrics
        """Train"""
        # Use early stopping or not?
        # es_callback = tf.keras.callbacks.EarlyStopping(
        #     monitor='val_loss',
        #     patience=6,
        #     restore_best_weights=True)
        trainer = unet.Trainer(
            checkpoint_callback=False,
            tensorboard_callback=False,
            tensorboard_images_callback=False,
            #callbacks=[es_callback]
        )
        trainer.fit(
            unet_model,
            train_dataset,
            #validation_dataset,
            epochs=40,
            batch_size=2)
        """Calculate best amplification"""
        prediction = unet_model.predict(validation_dataset.batch(batch_size=1))

        original_images = []
        metric_labels = []
        metric_predictions_unprocessed = []
        metric_predictions = []

        dataset = validation_dataset.map(
            utils.crop_image_and_label_to_shape((inp_dim, inp_dim, 2)))
        prediction = remove_border(prediction, inp_dim, inp_dim)

        for i, (image, label) in enumerate(dataset):
            original_images.append(image[..., -1])
            metric_labels.append(np.argmax(label, axis=-1))
            metric_predictions_unprocessed.append(
                normalize_output(prediction[i, ...]))

        best_tau, best_score = get_best_threshold(
            metric_predictions_unprocessed,
            metric_labels,
            min=0,
            max=1,
            num_steps=50,
            use_metric=1)

        #best_tau = 0.5 # Use this to not threshold at all, also comment above
        print("Best tau: " + str(best_tau))
        print("Best avg score: " + str(best_score))

        for i in range(len(metric_predictions_unprocessed)):
            metric_predictions.append(
                (metric_predictions_unprocessed[i] >= best_tau).astype(int))

        if plot_validation:
            fig, ax = plt.subplots(3,
                                   3,
                                   sharex=True,
                                   sharey=True,
                                   figsize=(8, 8))

            for i in range(3):
                ax[i][0].matshow(original_images[i])
                ax[i][1].matshow(metric_labels[i], cmap=plt.cm.gray)
                ax[i][2].matshow(metric_predictions[i], cmap=plt.cm.gray)

            plt.tight_layout()
            plt.show()

        original_images = []
        metric_labels_test = []
        metric_predictions_unprocessed_test = []
        metric_predictions = []
        metric_predictions_unthresholded = []
        """Evaluate and print to file"""
        prediction = unet_model.predict(test_dataset.batch(batch_size=1))
        dataset = test_dataset.map(
            utils.crop_image_and_label_to_shape((inp_dim, inp_dim, 2)))
        prediction = remove_border(prediction, inp_dim, inp_dim)
        print("Test shape shape: ", prediction.shape)

        for i, (image, label) in enumerate(dataset):
            original_images.append(image[..., -1])
            metric_labels_test.append(np.argmax(label, axis=-1))
            metric_predictions_unprocessed_test.append(prediction[i, ...])

        for i in range(len(metric_predictions_unprocessed_test)):
            metric_predictions.append(
                (normalize_output(metric_predictions_unprocessed_test[i]) >=
                 best_tau).astype(int))
            metric_predictions_unthresholded.append((normalize_output(
                metric_predictions_unprocessed_test[i]) >= 0.5).astype(int))

        # Calculate thresholded and unthresholded metrics in parallel
        parallel_metrics = [
            Metrics(metric_labels_test,
                    metric_predictions_unthresholded,
                    safe=False,
                    parallel=False),
            Metrics(metric_labels_test,
                    metric_predictions,
                    safe=False,
                    parallel=False)
        ]

        def f(m):
            return (m.jaccard()[0], m.dice()[0], m.adj_rand()[0],
                    m.warping_error()[0])

        pool = Pool(2)
        metric_result = pool.map(f, parallel_metrics)

        jaccard_index = metric_result[0][0]
        dice = metric_result[0][1]
        adj = metric_result[0][2]
        warping_error = metric_result[0][3]

        jaccard_index_to = metric_result[1][0]
        dice_to = metric_result[1][1]
        adj_to = metric_result[1][2]
        warping_error_to = metric_result[1][3]

        with results_file.open("a") as f:
            f.write(
                str(test_data_point_index) + ";" + str(jaccard_index) + ";" +
                str(dice) + ";" + str(adj) + ";" + str(warping_error) + ";" +
                str(jaccard_index_to) + ";" + str(dice_to) + ";" +
                str(adj_to) + ";" + str(warping_error_to) + "\n")

        print("test_data_point_index: " + str(test_data_point_index))
        print("Jaccard index: " + str(jaccard_index) +
              " with threshold optimization: " + str(jaccard_index_to))
        print("Dice: " + str(dice) + " with threshold optimization: " +
              str(dice_to))
        print("Adj: " + str(adj) + " with threshold optimization: " +
              str(adj_to))
        print("Warping Error: " + str(warping_error) +
              " with threshold optimization: " + str(warping_error_to))
        """Plot predictions"""
        if plot_test:
            fig, ax = plt.subplots(1, 3, figsize=(8, 4))
            fig.suptitle("Test point: " + str(test_data_point_index),
                         fontsize=14)

            ax[0].matshow(original_images[i])
            ax[0].set_title("Input data")
            ax[0].set_axis_off()

            ax[1].matshow(metric_labels_test[i], cmap=plt.cm.gray)  # test label, not the validation one
            ax[1].set_title("True mask")
            ax[1].set_axis_off()

            ax[2].matshow(metric_predictions[i], cmap=plt.cm.gray)
            ax[2].set_title("Predicted mask")
            ax[2].set_axis_off()

            fig.tight_layout()
            plt.show()
Example #7
def main(start_index=0,
         last_index=199,
         filename=None,
         plot=True,
         store_masks=False):
    if filename is None:
        now = datetime.now()
        current_dt = now.strftime("%y_%m_%d_%H_%M_%S")
        filename = "results/" + current_dt + ".csv"
    results_file = Path(filename)
    if not results_file.is_file():
        results_file.write_text(
            'index;jaccard;Dice;Adj;Warp;jaccard_to;Dice_to;Adj_to;Warp_to\n')
    """ Load data """
    print("Start read")

    images, labels = read_data()
    print("Done read")

    min_val = float('inf')
    max_val = float('-inf')
    for img in images:
        if np.min(img) < min_val:
            min_val = np.min(img)
        if np.max(img) > max_val:
            max_val = np.max(img)
    print(min_val, max_val)
    #images = [np.expand_dims(image, axis=2)/ max(np.max(image), 255) for image in images]
    # Normalize relative to entire dataset
    images = [(np.expand_dims(image, axis=2) - min_val) / (max_val - min_val)
              for image in images]
    labels = [split_into_classes(label[:, :, :2]) for label in labels]

    print(np.array(images).shape)
    print(np.array(labels).shape)

    for i in range(len(images)):
        images[i] = mirror_pad_image(images[i])
        labels[i] = mirror_pad_image(labels[i])

    print("num images: " + str(len(images)))
    print("num labels: " + str(len(labels)))

    num_data_points = len(images)

    for test_data_point_index in range(start_index, num_data_points):
        if test_data_point_index > last_index:
            break
        print("\nStarted for data_point_index: " + str(test_data_point_index))

        images_temp = images.copy()
        labels_temp = labels.copy()
        """for i in range((5)):
            plt.matshow(images_temp[i][..., -1])
            plt.show()
            plt.matshow(np.argmax(labels_temp[i], axis=-1), cmap=plt.cm.gray)
            plt.show()"""

        test_image = images_temp.pop(test_data_point_index)
        test_label = labels_temp.pop(test_data_point_index)

        test_dataset = tf.data.Dataset.from_tensor_slices(
            ([test_image], [test_label]))

        print("num images: " + str(len(images_temp)))
        print("num labels: " + str(len(labels_temp)))

        random_permutation = np.random.permutation(len(images_temp))
        images_temp = np.array(images_temp)[random_permutation]
        labels_temp = np.array(labels_temp)[random_permutation]

        image_dataset = tf.data.Dataset.from_tensor_slices(
            (images_temp, labels_temp))
        """Crate data splits"""
        data_augmentation = tf.keras.Sequential([
            tf.keras.layers.experimental.preprocessing.RandomFlip(
                "horizontal_and_vertical"),
            tf.keras.layers.experimental.preprocessing.RandomRotation(0.2),
        ])
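        # Note: this augmentation pipeline is defined but never applied; the training
        # data is augmented via augment_image() below instead.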

        train_dataset = image_dataset.take(160)
        validation_dataset = image_dataset.skip(160)

        train_dataset = train_dataset.shuffle(160, reshuffle_each_iteration=True)

        train_dataset = train_dataset.map(
            augment_image)  # Apply transformations to training data
        """Load model"""
        print(circles.channels)
        print(circles.classes)

        unet_model = unet.build_model(channels=circles.channels,
                                      num_classes=circles.classes,
                                      layer_depth=3,
                                      filters_root=16)
        unet.finalize_model(unet_model,
                            dice_coefficient=False,
                            auc=False,
                            mean_iou=False)  # Don't track so many metrics
        """Train"""
        # Use early stopping or not?
        # es_callback = tf.keras.callbacks.EarlyStopping(
        #     monitor='val_loss',
        #     patience=6,
        #     restore_best_weights=True)
        trainer = unet.Trainer(
            checkpoint_callback=False,
            tensorboard_callback=False,
            tensorboard_images_callback=False,
            #callbacks=[es_callback]
        )
        trainer.fit(
            unet_model,
            train_dataset,
            #validation_dataset,
            epochs=40,
            batch_size=2)
        """Calculate best amplification"""
        prediction = unet_model.predict(validation_dataset.batch(batch_size=1))

        original_images = []
        metric_labels = []
        metric_predictions_unprocessed = []
        metric_predictions = []

        dataset = validation_dataset.map(
            utils.crop_image_and_label_to_shape(prediction.shape[1:]))

        for i, (image, label) in enumerate(dataset):
            original_images.append(image[..., -1])
            metric_labels.append(np.argmax(label, axis=-1))
            metric_predictions_unprocessed.append(
                normalize_output(prediction[i, ...]))

        best_tau, best_score = get_best_threshold(
            metric_predictions_unprocessed,
            metric_labels,
            min=0,
            max=1,
            num_steps=50,
            use_metric=1)

        #best_tau = 0.5 # Use this to not threshold at all, also comment above
        print("Best tau: " + str(best_tau))
        print("Best avg score: " + str(best_score))

        for i in range(len(metric_predictions_unprocessed)):
            metric_predictions.append(
                (metric_predictions_unprocessed[i] >= best_tau).astype(int))

        if plot:
            fig, ax = plt.subplots(3,
                                   3,
                                   sharex=True,
                                   sharey=True,
                                   figsize=(8, 8))

            for i in range(3):
                ax[i][0].matshow(original_images[i])
                ax[i][1].matshow(metric_labels[i], cmap=plt.cm.gray)
                ax[i][2].matshow(metric_predictions[i], cmap=plt.cm.gray)

            plt.tight_layout()
            plt.show()

        original_images = []
        metric_labels_test = []
        metric_predictions_unprocessed_test = []
        metric_predictions = []
        metric_predictions_unthresholded = []
        """Evaluate and print to file"""
        prediction = unet_model.predict(test_dataset.batch(batch_size=1))
        dataset = test_dataset.map(
            utils.crop_image_and_label_to_shape(prediction.shape[1:]))

        for i, (image, label) in enumerate(dataset):
            original_images.append(image[..., -1])
            metric_labels_test.append(np.argmax(label, axis=-1))
            metric_predictions_unprocessed_test.append(prediction[i, ...])

        for i in range(len(metric_predictions_unprocessed_test)):
            metric_predictions.append(
                (normalize_output(metric_predictions_unprocessed_test[i]) >=
                 best_tau).astype(int))
            metric_predictions_unthresholded.append((normalize_output(
                metric_predictions_unprocessed_test[i]) >= 0.5).astype(int))

        # Calculate thresholded and unthresholded metrics in parallel
        parallel_metrics = [
            Metrics(metric_labels_test,
                    metric_predictions_unthresholded,
                    safe=False,
                    parallel=False),
            Metrics(metric_labels_test,
                    metric_predictions,
                    safe=False,
                    parallel=False)
        ]

        def f(m):
            return (m.jaccard()[0], m.dice()[0], m.adj_rand()[0],
                    m.warping_error()[0])

        pool = Pool(2)
        metric_result = pool.map(f, parallel_metrics)

        jaccard_index = metric_result[0][0]
        dice = metric_result[0][1]
        adj = metric_result[0][2]
        warping_error = metric_result[0][3]

        jaccard_index_to = metric_result[1][0]
        dice_to = metric_result[1][1]
        adj_to = metric_result[1][2]
        warping_error_to = metric_result[1][3]

        with results_file.open("a") as f:
            f.write(
                str(test_data_point_index) + ";" + str(jaccard_index) + ";" +
                str(dice) + ";" + str(adj) + ";" + str(warping_error) + ";" +
                str(jaccard_index_to) + ";" + str(dice_to) + ";" +
                str(adj_to) + ";" + str(warping_error_to) + "\n")

        print("test_data_point_index: " + str(test_data_point_index))
        print("Jaccard index: " + str(jaccard_index) +
              " with threshold optimization: " + str(jaccard_index_to))
        print("Dice: " + str(dice) + " with threshold optimization: " +
              str(dice_to))
        print("Adj: " + str(adj) + " with threshold optimization: " +
              str(adj_to))
        print("Warping Error: " + str(warping_error) +
              " with threshold optimization: " + str(warping_error_to))
        """Plot predictions"""
        if plot:
            fig, ax = plt.subplots(3,
                                   3,
                                   sharex=True,
                                   sharey=True,
                                   figsize=(8, 8))

            for i in range(len(metric_labels_test)):
                ax[i][0].matshow(original_images[i])
                ax[i][1].matshow(metric_labels_test[i], cmap=plt.cm.gray)
                ax[i][2].matshow(metric_predictions[i], cmap=plt.cm.gray)

            plt.tight_layout()
            plt.show()

        if store_masks:
            np.save(
                "results/BBBC039_val_fold_" + str(test_data_point_index) +
                ".npy", metric_predictions_unprocessed)
            np.save(
                "results/BBBC039_val_true_fold_" + str(test_data_point_index) +
                ".npy", metric_labels)
            np.save(
                "results/BBBC039_test_fold_" + str(test_data_point_index) +
                ".npy", metric_predictions_unprocessed_test)
            np.save(
                "results/BBBC039_test_true_fold_" +
                str(test_data_point_index) + ".npy", metric_labels_test)
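
get_best_threshold, used by both folds above, is defined outside these excerpts; judging from its call site it sweeps num_steps candidate thresholds between min and max and returns the best-scoring one. A rough, hypothetical re-implementation (named differently so it is not mistaken for the original), assuming each prediction is a 2-D foreground-probability map and each label a binary mask of the same shape, with mean Dice standing in for the original scoring metric:

def get_best_threshold_sketch(predictions, labels, min=0, max=1, num_steps=50):
    # Threshold every prediction at each candidate tau and keep the tau with the
    # highest mean Dice overlap against the reference labels.
    best_tau, best_score = min, -1.0
    for tau in np.linspace(min, max, num_steps):
        scores = []
        for pred, label in zip(predictions, labels):
            binary = (pred >= tau).astype(int)
            intersection = np.sum(binary * label)
            scores.append(2 * intersection / (np.sum(binary) + np.sum(label) + 1e-9))
        if np.mean(scores) > best_score:
            best_tau, best_score = tau, np.mean(scores)
    return best_tau, best_score
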
Example #8
def main(filename=None, calculate_train_metric=False):
    """

    :param start_index:
    :param filename:
    :param plot_validation: Plots 3 samples from the validation set each fold
    :param plot_test:  Plots the test test image for each fold
    :return:
    """
    now = datetime.now()
    current_dt = now.strftime("%y_%m_%d_%H_%M_%S")
    if filename is None:
        filename = "results/" + current_dt + ".csv"
    results_file = Path(filename)
    if not results_file.is_file():
        results_file.write_text('index; jaccard; Dice; Adj; Warp\n')
    """ Load data """
    image_path = "data/synthetic/images/"
    label_path = "data/synthetic/labels/"

    file_extension = "tif"

    # inp_dim = 572
    # inp_dim = 200
    # inp_dim = 710
    inp_dim = 1024

    file_names = sorted(glob.glob(image_path + "*." + file_extension))
    file_names_labels = sorted(glob.glob(label_path + "*." + file_extension))

    print(file_names)
    print(file_names_labels)

    images = []
    for file in file_names:
        if file_extension == "tif":
            images.append(
                tf.convert_to_tensor(np.expand_dims(plt.imread(file),
                                                    axis=2)))  # For .tif
            images[-1] = images[-1] / 255  # Normalize
            images[-1] = tf.image.resize(images[-1], [inp_dim, inp_dim],
                                         preserve_aspect_ratio=True,
                                         method='bilinear')
        elif file_extension == "png":
            images.append(tf.convert_to_tensor(
                plt.imread(file)[:, :, :3]))  # For .png
            images[-1] = tf.image.resize(images[-1], [inp_dim, inp_dim],
                                         preserve_aspect_ratio=True,
                                         method='bilinear')
            images[-1] = tf.image.rgb_to_grayscale(images[-1])

        images[-1] = mirror_pad_image(images[-1], pixels=20)

    labels = []
    for file in file_names_labels:
        label = plt.imread(file)
        # label = plt.imread(file)[:, :, :3]
        label = (np.expand_dims(label, axis=2))

        label = np.where(label > 0, [0, 1], [1, 0])
        labels.append(tf.convert_to_tensor(label))

        labels[-1] = tf.image.resize(labels[-1], [inp_dim, inp_dim],
                                     preserve_aspect_ratio=True,
                                     method='bilinear')
        labels[-1] = np.where(labels[-1] > 0.5, 1, 0)

        labels[-1] = mirror_pad_image(labels[-1], pixels=20)

    print("num images: " + str(len(images)))
    print("num labels: " + str(len(labels)))

    num_data_points = len(images)

    scilife_images, scilife_labels = scilife_data()

    # plt.matshow(scilife_images[1][..., -1])
    # plt.show()
    #
    # for i in range(len(scilife_images)):
    #     print(np.max(scilife_images[i]))

    images_temp = images.copy()
    labels_temp = labels.copy()
    """for i in range((5)):
        plt.matshow(images_temp[i][..., -1])
        plt.show()
        plt.matshow(np.argmax(labels_temp[i], axis=-1), cmap=plt.cm.gray)
        plt.show()"""

    print("num images: " + str(len(images_temp)))
    print("num labels: " + str(len(labels_temp)))

    random_permutation = np.random.permutation(len(images_temp))
    images_temp = np.array(images_temp)[random_permutation]
    labels_temp = np.array(labels_temp)[random_permutation]

    image_dataset = tf.data.Dataset.from_tensor_slices(
        (images_temp, labels_temp))
    """Crate data splits"""
    train_dataset = image_dataset.take(100)
    validation_dataset = image_dataset.skip(100)

    train_dataset = train_dataset.shuffle(100, reshuffle_each_iteration=True)

    train_dataset = train_dataset.map(
        augment_image)  # Apply transformations to training data
    """Load model"""
    print(circles.channels)
    print(circles.classes)

    unet_model = unet.build_model(channels=circles.channels,
                                  num_classes=circles.classes,
                                  layer_depth=3,
                                  filters_root=16)
    if calculate_train_metric:
        unet.finalize_model(unet_model)
    else:
        unet.finalize_model(unet_model,
                            dice_coefficient=False,
                            auc=False,
                            mean_iou=False)
    """Train"""
    # callback = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)
    # trainer = unet.Trainer(checkpoint_callback=False, callbacks=[callback])
    trainer = unet.Trainer(checkpoint_callback=False)

    trainer.fit(unet_model, train_dataset, epochs=25, batch_size=1)
    """Sci Life data prediction"""
    scilife_dataset = tf.data.Dataset.from_tensor_slices(
        (scilife_images, scilife_labels))
    prediction = unet_model.predict(scilife_dataset.batch(batch_size=1))

    original_images = []
    metric_labels = []
    metric_predictions_unprocessed = []
    metric_predictions = []

    dataset = scilife_dataset.map(
        utils.crop_image_and_label_to_shape((inp_dim, inp_dim, 2)))
    prediction = remove_border(prediction, inp_dim, inp_dim)
    # print("Validation shape after: ", prediction.shape)

    for i, (image, _) in enumerate(dataset):
        original_images.append(image[..., -1])
        metric_predictions_unprocessed.append(prediction[i, ...])

    for i in range(len(metric_predictions_unprocessed)):
        metric_predictions.append(
            np.argmax(metric_predictions_unprocessed[i] * np.array([[[1, 1]]]),
                      axis=-1))

    fig, ax = plt.subplots(5, 2, sharex=True, sharey=True, figsize=(25, 60))

    for i in range(5):
        ax[i][0].matshow(original_images[i])
        ax[i][1].matshow(metric_predictions[i], cmap=plt.cm.gray)
        plt.imsave("results/scilifelab_" + str(current_dt) + "_index_" +
                   str(i) + ".png",
                   metric_predictions[i],
                   cmap=plt.cm.gray)

    plt.tight_layout()
    plt.savefig("results/scilifelab_" + str(current_dt) + ".png")
    plt.show()
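
mirror_pad_image, used throughout Examples #6 to #8, is likewise defined outside these excerpts; a minimal sketch of reflection padding with tf.pad, assuming a (height, width, channels) input and guessing the default border width:

def mirror_pad_image(image, pixels=20):
    # Reflect-pad the two spatial dimensions so structures at the border keep some
    # context; the channel dimension is left unpadded.
    return tf.pad(image, [[pixels, pixels], [pixels, pixels], [0, 0]], mode='REFLECT')
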