Example #1
0
def load_data(dataset_path, contour_type, crop_size):
    """Load the shuffled contours under *dataset_path* and export them.

    Expects ``contours/`` and ``images/`` subdirectories inside
    ``dataset_path``; returns whatever ``export_all_contours`` produces
    (presumably image/mask arrays cropped to ``crop_size`` — confirm
    against its definition).
    """
    contour_dir = os.path.join(dataset_path, 'contours')
    image_dir = os.path.join(dataset_path, 'images')

    all_contours = load_all_contours(contour_dir, contour_type, shuffle=True)
    return export_all_contours(all_contours,
                               image_dir,
                               crop_size=crop_size,
                               sax_series=get_SAX_SERIES())
Example #2
0
#!/usr/bin/env python3

import re, sys, os
import shutil, cv2
import numpy as np

from train_sunnybrook_unetres import read_contour, map_all_contours, export_all_contours
from helpers import reshape, get_SAX_SERIES, draw_result

from unet_model_inv import unet_model_inv, dice_coef_each
SAX_SERIES = get_SAX_SERIES()
# Raw strings for the Windows root path: in a plain string '\c' and '\S'
# are invalid escape sequences (SyntaxWarning since Python 3.12, slated to
# become an error). The runtime value is unchanged.
SUNNYBROOK_ROOT_PATH = r'D:\cardiac_data\Sunnybrook'

# Validation split: contours, DICOM images, and overlay output.
VAL_CONTOUR_PATH = os.path.join(
    SUNNYBROOK_ROOT_PATH, 'Sunnybrook Cardiac MR Database ContoursPart2',
    'ValidationDataContours')
VAL_IMG_PATH = os.path.join(SUNNYBROOK_ROOT_PATH,
                            'Sunnybrook Cardiac MR Database DICOMPart2',
                            'ValidationDataDICOM')
VAL_OVERLAY_PATH = os.path.join(SUNNYBROOK_ROOT_PATH,
                                'Sunnybrook Cardiac MR Database OverlayPart2',
                                'ValidationDataOverlay')

# Online (test) split: contours, DICOM images, and overlay output.
ONLINE_CONTOUR_PATH = os.path.join(
    SUNNYBROOK_ROOT_PATH, 'Sunnybrook Cardiac MR Database ContoursPart1',
    'OnlineDataContours')
ONLINE_IMG_PATH = os.path.join(SUNNYBROOK_ROOT_PATH,
                               'Sunnybrook Cardiac MR Database DICOMPart1',
                               'OnlineDataDICOM')
ONLINE_OVERLAY_PATH = os.path.join(
    SUNNYBROOK_ROOT_PATH, 'Sunnybrook Cardiac MR Database OverlayPart1',
    'OnlineDataOverlay')

# Destination directory for validation-set submissions.
SAVE_VAL_PATH = os.path.join(SUNNYBROOK_ROOT_PATH, 'Sunnybrook_val_submission')
def train(image_path, contour_path, *, contour_type, crop_size, batch_size,
          seed, epoch_count):
    """Train a segmentation model on the Sunnybrook contour data.

    Loads contours from ``contour_path`` and images from ``image_path``,
    trains for ``epoch_count`` epochs with a poly-decayed learning rate,
    evaluates on the dev set each epoch, and saves model weights per epoch
    under ``DIR_DATA/model_logs``.

    Args:
        image_path: Directory containing the training DICOM images.
        contour_path: Directory containing the training contour files.
        contour_type: Contour label (e.g. endo/epi) used in filenames.
        crop_size: Side length of the square crop fed to the network.
        batch_size: Mini-batch size for training.
        seed: Shared RNG seed keeping image/mask augmentation in lockstep.
        epoch_count: Number of training epochs.

    Returns:
        None. Side effects: saved predictions and per-epoch weight files.
    """
    contours = load_all_contours(contour_path, contour_type, shuffle=True)

    loaded_train_x, loaded_train_y = export_all_contours(
        contours,
        image_path,
        crop_size=crop_size,
        sax_series=get_SAX_SERIES(),
    )

    input_shape = (crop_size, crop_size, 1)
    num_classes = 2

    model = fcn_model(input_shape, num_classes)

    # Identical augmentation config for images and masks; the shared `seed`
    # below keeps the two generators synchronized so every augmented mask
    # still matches its augmented image.
    kwargs = dict(
        rotation_range=180,
        zoom_range=0.0,
        width_shift_range=0.0,
        height_shift_range=0.0,
        horizontal_flip=True,
        vertical_flip=True,
    )
    image_datagen = ImageDataGenerator(**kwargs)
    mask_datagen = ImageDataGenerator(**kwargs)

    image_generator = image_datagen.flow(
        loaded_train_x,
        shuffle=False,
        batch_size=batch_size,
        seed=seed,
    )
    mask_generator = mask_datagen.flow(
        loaded_train_y,
        shuffle=False,
        batch_size=batch_size,
        seed=seed,
    )
    train_generator = zip(image_generator, mask_generator)

    # BUG FIX: use integer division — max_iter is a whole-iteration count
    # driving the poly decay schedule; '/' produced a float.
    max_iter = (len(contours) // batch_size) * epoch_count
    curr_iter = 0
    base_lr = K.eval(model.optimizer.lr)
    learning_rate = lr_poly_decay(model,
                                  base_lr,
                                  curr_iter,
                                  max_iter,
                                  power=0.5)

    for epoch in range(1, epoch_count + 1):
        print()
        print('Main Epoch {:d}'.format(epoch))
        print('Learning rate: {:6f}'.format(learning_rate))
        train_result = []

        # BUG FIX: the original read `len(img_train)`, but no `img_train`
        # exists in this function — the training images are `loaded_train_x`.
        iter_count = len(loaded_train_x) // batch_size
        for img, mask in islice(train_generator, iter_count):
            res = model.train_on_batch(img, mask)
            curr_iter += 1
            # Decay the learning rate after every batch.
            learning_rate = lr_poly_decay(model,
                                          base_lr,
                                          curr_iter,
                                          max_iter,
                                          power=0.5)
            train_result.append(res)

        train_result = np.asarray(train_result)
        train_result = np.mean(train_result, axis=0).round(decimals=10)

        print('Train result {}:'.format(str(model.metrics_names)))
        print('{}'.format(str(train_result)))
        print()
        print('Evaluating dev set ...')

        # NOTE(review): `img_dev`/`mask_dev` are not defined in this chunk —
        # presumably module-level globals; confirm against the full file.
        result = model.evaluate(img_dev, mask_dev, batch_size=32)
        result = np.round(result, decimals=10)

        print()
        print('Dev set result {}:'.format(str(model.metrics_names)))
        print('{}'.format(str(result)))

        # Dump qualitative predictions for the first 10 dev images.
        predict_and_save(
            model,
            img_dev[:10],
            output_dir=os.path.join(DIR_DATA, 'predictions'),
        )

        save_file = '_'.join([
            'sunnybrook',
            contour_type,
            'epoch',
            str(epoch),
        ]) + '.h5'

        logs_dir = os.path.join(DIR_DATA, 'model_logs')
        os.makedirs(logs_dir, exist_ok=True)
        save_path = os.path.join(logs_dir, save_file)

        print()
        print('Saving model weights to {}'.format(save_path))

        model.save_weights(save_path)