Example #1
    def createAugmentor(self):
        """ Creates ImageAugmentor object with the parameters for image augmentation

        Rotation range was set in -15 to 15 degrees
        Shear Range was set in between -0.3 and 0.3 radians
        Zoom range between 0.8 and 2
        Shift range was set in +/- 5 pixels

        Returns:
            ImageAugmentor object

        """
        rotation_range = [-15, 15]
        # convert the shear range from radians to degrees
        shear_range = [-0.3 * 180 / math.pi, 0.3 * 180 / math.pi]
        zoom_range = [0.8, 2]
        shift_range = [5, 5]

        return ImageAugmentor(0.5, shear_range, rotation_range, shift_range, zoom_range)
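For reference, the shear conversion above works out to roughly ±17.19 degrees. A standalone check, assuming only the standard library:

import math

# shear range of ±0.3 rad expressed in degrees, as in createAugmentor above
shear_range = [-0.3 * 180 / math.pi, 0.3 * 180 / math.pi]
print(shear_range)  # approximately [-17.19, 17.19]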
Example #2
from __future__ import division
import torch
from torch.autograd import Variable
from torch.utils import data
from norm import FCN
from datasets import CSDataSet
from loss import CrossEntropy2d, CrossEntropyLoss2d
from transform import ReLabel, ToLabel, ToSP, Scale, Augment
from torchvision.transforms import Compose, CenterCrop, Normalize, ToTensor
from PIL import Image
import numpy as np

import utils
from image_augmentor import ImageAugmentor

image_augmentor = ImageAugmentor()

NUM_CLASSES = 6
MODEL_NAME = "seg-norm"

input_transform = Compose([
    Scale((512, 256), Image.BILINEAR),
    Augment(0, image_augmentor),
    ToTensor(),
    Normalize([.485, .456, .406], [.229, .224, .225]),
])
target_transform = Compose([
    Scale((512, 256), Image.NEAREST),
    ToLabel(),
    ReLabel(),
])
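A minimal sketch of how these pipelines might be applied to an (image, label) pair, e.g. inside a dataset's __getitem__; the file names are hypothetical, and Scale, Augment, ToLabel, and ReLabel are the project-local transforms imported above:

from PIL import Image

def load_pair(image_path, label_path):
    # run the matching input/target pipelines on one sample
    image = input_transform(Image.open(image_path).convert('RGB'))
    label = target_transform(Image.open(label_path))
    return image, label

# hypothetical file names, for illustration only
x, y = load_pair('frame_0001.png', 'frame_0001_label.png')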
Example #3
import math
import warnings
from copy import copy

import numpy as np

# project-local helpers assumed available in this repo:
# UNet, Logger, get_optimizer, visualize_results

class Trainer():
    def __init__(self,
                 dataset,
                 dirs,
                 network_architecture_params,
                 augmentation_params,
                 batch_size=32,
                 epochs=50,
                 verbose=1,
                 patience=0,
                 loss_function='binary_crossentropy',
                 optimizer='Adam',
                 initial_lr=1e-4,
                 **kwargs):

        # for train & validation data
        self.dataset = dataset
        self.X_train = self.dataset.X_train
        self.y_train = self.dataset.y_train
        self.X_valid = self.dataset.X_valid
        self.y_valid = self.dataset.y_valid

        self.batch_size = batch_size
        self.epochs = epochs
        self.verbose = verbose
        self.patience = patience

        self.fetch_sample_valid_data(size=5)
        self.n_batch = math.ceil(self.X_train.shape[0] / self.batch_size)
        self.n_class = np.max(self.y_train) + 1

        # for model compile
        self.optimizer = optimizer
        self.initial_lr = initial_lr
        self.loss_function = loss_function

        # for model construction
        self.network_architecture_params = network_architecture_params

        # for augmentation
        self.augmentation_params = augmentation_params
        if self.augmentation_params['augmentation']:
            self.augmentor = ImageAugmentor(**self.augmentation_params)
        else:
            self.augmentor = None

        # for save results
        self.result_dir = dirs['result_dir']
        self.log_dir = dirs['log_dir']
        self.train_batch_dir = dirs['train_batch_dir']
        self.valid_batch_dir = dirs['valid_batch_dir']

        # for results logger
        self.logger = Logger()
        self.train_loss_epoch = []
        self.train_loss_batch = []
        self.valid_loss_epoch = []
        self.best_loss = float('inf')
        self.wait = 0

    def update(self):

        # update loop
        for idx in range(self.n_batch):
            batch_X = copy(self.X_train[idx * self.batch_size:(idx + 1) *
                                        self.batch_size])
            batch_y = copy(self.y_train[idx * self.batch_size:(idx + 1) *
                                        self.batch_size])

            # print(batch_X.shape)
            # print(batch_y.shape)

            # apply transform as data augmentation
            if self.augmentor is not None:
                batch_X, batch_y = self.augmentor.augment(batch_X,
                                                          batch_y,
                                                          borderMode='reflect')

            # TODO: save batch_X and batch_y to self.train_batch_dir
            if idx < 5 and self.current_epoch < 5:
                for i in range(batch_X.shape[0]):
                    visualize_results(
                        save_path=self.train_batch_dir +
                        '/sample{:02d}_epoch{:04d}.png'.format(
                            (idx * batch_X.shape[0]) + i, self.current_epoch),
                        image=batch_X[i],
                        gt=batch_y[i],
                        pred=batch_y[i])  # no prediction at this stage; reuse the ground truth

            # update weight
            train_loss = self.model.train_on_batch(batch_X, batch_y)

            # store loss
            self.loss_tmp.append(train_loss)
            self.train_loss_batch.append(train_loss)
            print("\rbatch: {}/{} loss: {:.5f} ".format(
                idx + 1, self.n_batch, train_loss),
                  end="")

        self.logger.plot_history(histories=[self.train_loss_batch],
                                 save_path=self.log_dir +
                                 '/loss_hist_iter.png',
                                 legends=['train loss'],
                                 x_label='iteration')

    def define_model(self):

        # create model
        u_net = UNet(input_shape=(self.X_train.shape[1], self.X_train.shape[2],
                                  3),
                     n_class=self.n_class,
                     **self.network_architecture_params)
        self.model = u_net.get_model()

        # save architecture
        self.logger.save_model_structure(
            self.model,
            #  fig_path=self.result_dir+'/model.png',
            txt_path=self.result_dir + '/model.txt')

        # define optimizer and loss function
        self.model.compile(optimizer=get_optimizer(name=self.optimizer,
                                                   lr=self.initial_lr),
                           loss=self.loss_function)

    def train(self):

        self.define_model()

        for epoch in range(1, self.epochs + 1):
            print("-" * 80)
            print("epoch: {}/{}".format(epoch, self.epochs))
            self.current_epoch = epoch
            self.loss_tmp = []

            # shuffle
            p = np.random.permutation(self.X_train.shape[0])
            self.X_train, self.y_train = self.X_train[p], self.y_train[p]

            # update loop
            self.update()

            # summarize loss
            train_loss = np.mean(self.loss_tmp)
            val_loss = self.model.evaluate(self.X_valid,
                                           self.y_valid,
                                           batch_size=1,
                                           verbose=0)
            self.train_loss_epoch.append(train_loss)
            self.valid_loss_epoch.append(val_loss)
            print("training loss: {:.5f}, validation loss: {:.5f}".format(
                train_loss, val_loss))

            # plot loss history
            self.logger.plot_history(
                histories=[self.train_loss_epoch, self.valid_loss_epoch],
                save_path=self.log_dir + '/loss_hist_epoch.png')

            # test using sampled validation data
            pred_sample = self.model.predict(self.X_valid_sample,
                                             batch_size=2,
                                             verbose=0)
            pred_sample = np.argmax(pred_sample, axis=3).astype(np.uint8)

            for i in range(pred_sample.shape[0]):
                visualize_results(
                    save_path=self.valid_batch_dir +
                    '/sample{:02d}_epoch{:04d}.png'.format(i, epoch),
                    image=self.X_valid_sample[i],
                    gt=self.y_valid_sample[i],
                    pred=pred_sample[i])

            # model checkpoint
            self.current_loss = val_loss
            self.checkpoint()
            if (self.patience > 0) and (self.wait >= self.patience):
                break

            self.logger.save_history_as_json(
                {
                    'train_loss_epoch':
                    [str(r) for r in self.train_loss_epoch],
                    'train_loss_batch':
                    [str(r) for r in self.train_loss_batch],
                    'valid_loss_epoch':
                    [str(r) for r in self.valid_loss_epoch]
                }, self.log_dir + '/loss.json')

    def fetch_sample_valid_data(self, size=5):

        if size > self.X_valid.shape[0]:
            warnings.warn('requested sample size is larger than the whole '
                          'validation set; all validation data will be used '
                          'as sample data.')
            size = self.X_valid.shape[0]

        # sample without replacement so the preview samples are distinct
        sample_idx = np.random.choice(self.X_valid.shape[0], size=size, replace=False)
        self.X_valid_sample = copy(self.X_valid[sample_idx])
        self.y_valid_sample = copy(self.y_valid[sample_idx])

    def checkpoint(self):

        if self.current_loss < self.best_loss:
            print('validation loss was improved from {:.5f} to {:.5f}.'.format(
                self.best_loss, self.current_loss))
            self.model.save(self.result_dir + '/trained_model.h5')
            self.wait = 0
            self.best_loss = self.current_loss
        else:
            print('validation loss was not improved.')
            self.wait += 1
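A hedged usage sketch for this Trainer; the dataset object and parameter values are hypothetical, but the dirs keys and the 'augmentation' flag match what the constructor reads:

dirs = {
    'result_dir': 'results',
    'log_dir': 'results/logs',
    'train_batch_dir': 'results/train_batches',
    'valid_batch_dir': 'results/valid_batches',
}

trainer = Trainer(
    dataset=dataset,  # hypothetical: any object exposing X_train, y_train, X_valid, y_valid
    dirs=dirs,
    network_architecture_params={},  # forwarded to UNet(**...)
    augmentation_params={'augmentation': False},  # True plus ImageAugmentor kwargs enables augmentation
    batch_size=16,
    epochs=30,
    patience=5,
)
trainer.train()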
Example #4
import os

import cv2
import numpy as np
import tensorflow as tf

# ImageAugmentor, IMAGE_WIDTH, and IMAGE_HEIGHT are assumed to be
# project-local definitions available at import time.

class DataManager(object):
    def __init__(self, param, shuffle=True, valid=False, extension='.bmp'):

        self.shuffle = shuffle
        self.extension = extension
        if valid:
            self.image_root = param["image_root_valid"]
            self.mask_root = param["mask_root_valid"]
            # self.batch_size = 1
        else:
            self.image_root = param["image_root_train"]
            self.mask_root = param["mask_root_train"]

        self.mode = param["mode"]
        if self.mode == "train_segmentation" or self.mode == "train_decision":
            self.batch_size = param["batch_size"]
        elif self.mode == "savePb" or self.mode == "testPb":
            self.batch_size = param["batch_size_inference"]
        else:
            self.batch_size = 1
        self.epoch_num = param["epochs"]
        self.augmentor = ImageAugmentor(param)
        # self.augmentation = param["augmentation"]
        self.next_batch = self.get_next()
        # file names directly under each root directory
        self.image_files = next(os.walk(self.image_root))[2]
        self.mask_files = next(os.walk(self.mask_root))[2]
        if param["balanced_mode"]:
            self.image_files = [x for x in self.image_files if "p_" in x]
            self.mask_files = [x for x in self.mask_files if "p_" in x]

        self.num_batch = len(self.image_files) // self.batch_size

    def get_next(self):
        """ Encapsulate generator into TensorFlow DataSet"""
        dataset = tf.data.Dataset.from_generator(
            self.generator, (tf.float32, tf.float32, tf.float32, tf.string))
        dataset = dataset.repeat(self.epoch_num + self.epoch_num // 10 + 1)
        dataset = dataset.batch(self.batch_size)
        iterator = dataset.make_one_shot_iterator()
        out_batch = iterator.get_next()
        return out_batch

    def generator(self):
        """
        Generator yielding (image, mask, label, image_path) tuples.
        Revise according to how the data is stored on disk.
        """
        rand_index = np.arange(len(self.image_files))
        if self.shuffle:
            np.random.shuffle(rand_index)
        for index in rand_index:
            image_filename = self.image_files[index]
            # training data
            image_path = self.image_root + image_filename
            if self.check_mask_name():
                mask_path = self.mask_root + image_filename
            else:
                mask_path = self.mask_root + image_filename.split(
                    ".")[0] + "_label" + self.extension

            " Generate label from image filename. Shall be revised according to specific situation"
            if image_path.split('/')[-1].split('_')[0] == 'n':
                label = np.array([0.0])
            else:
                label = np.array([1.0])

            image, mask = self.read_data(image_path, mask_path)

            image = image / 255.0  # scale image to [0, 1]
            mask = mask // 255  # binarize mask to {0, 1}

            if self.mode == "train_segmentation" or self.mode == "train_decision":
                aug_random = np.random.uniform()
                if aug_random > 0.9:
                    image, mask = self.augmentor.transform_seg(image, mask)
                    # # adjust_gamma
                    # if np.random.uniform() > 0.7 and "adjust_gamma" in self.augmentation:
                    #     expo = np.random.choice([0.7, 0.8, 0.9, 1.1, 1.2, 1.3])
                    #     image = exposure.adjust_gamma(image, expo)
                    #
                    # # flip
                    # if np.random.uniform() > 0.7 and "flip" in self.augmentation:
                    #     aug_seed = np.random.randint(-1, 2)
                    #     image = cv2.flip(image, aug_seed)
                    #     mask = cv2.flip(mask, aug_seed)
                    #
                    # # rotate
                    # if np.random.uniform() > 0.7 and "rotate" in self.augmentation:
                    #     angle = np.random.randint(-5, 5)
                    #     image = self.rotate(image, angle)
                    #     mask = self.rotate(mask, angle)
                    #
                    # # GaussianBlur
                    # if np.random.uniform() > 0.7 and "GaussianBlur" in self.augmentation:
                    #     image = cv2.GaussianBlur(image, (5, 5), 0)
                    #
                    # # shift
                    # if np.random.uniform() > 0.7 and "shift" in self.augmentation:
                    #     dx = np.random.randint(-5, 5)  # width*5%
                    #     dy = np.random.randint(-5, 5)  # Height*10%
                    #     rows, cols = image.shape[:2]
                    #     M = np.float32([[1, 0, dx], [0, 1, dy]])  # (x,y) -> (dx,dy)
                    #     image = cv2.warpAffine(image, M, (cols, rows))
                    #     mask = cv2.warpAffine(mask, M, (cols, rows))

            if len(image.shape) == 2:
                image = (np.array(image[:, :, np.newaxis]))
            if len(mask.shape) == 2:
                mask = (np.array(mask[:, :, np.newaxis]))

            yield image, mask, label, image_path

    @staticmethod
    def read_data(image_path, mask_path):
        """Read an image and its mask as grayscale arrays resized to the network input size."""
        img = cv2.imread(image_path, 0)  # read as grayscale
        img = cv2.resize(img, (IMAGE_WIDTH, IMAGE_HEIGHT))

        try:
            msk = cv2.imread(mask_path, 0)  # read as grayscale
            msk = cv2.resize(msk, (IMAGE_WIDTH, IMAGE_HEIGHT))
            _, msk = cv2.threshold(msk, 0, 255, cv2.THRESH_BINARY)
        except Exception:
            raise ValueError("Cannot find mask {}".format(mask_path))

        return img, msk

    @staticmethod
    def rotate(image, angle, center=None, scale=1.0):
        """Rotate and scale image around given center at given scale"""
        (h, w) = image.shape[:2]
        if center is None:
            center = (w // 2, h // 2)

        M = cv2.getRotationMatrix2D(center, angle, scale)

        rotated = cv2.warpAffine(image, M, (w, h))
        return rotated

    def check_mask_name(self):
        """Return True if mask files reuse the image file names (no '_label' suffix)."""
        return "label" not in self.mask_files[0]
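A hedged usage sketch (TF1-style graph execution, since get_next relies on make_one_shot_iterator); every path and value below is hypothetical, but the keys mirror the ones the constructor reads:

param = {
    "image_root_train": "data/train/images/",
    "mask_root_train": "data/train/masks/",
    "image_root_valid": "data/valid/images/",
    "mask_root_valid": "data/valid/masks/",
    "mode": "train_segmentation",
    "batch_size": 4,
    "batch_size_inference": 1,
    "epochs": 50,
    "balanced_mode": False,
    # plus whatever ImageAugmentor(param) expects
}

manager = DataManager(param, shuffle=True, valid=False)

with tf.Session() as sess:
    for _ in range(manager.num_batch):
        image, mask, label, path = sess.run(manager.next_batch)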