コード例 #1
0
    def test_shuffle_train_val(self):
        """Check shuffling of the entire dataset prior to the train/val split.

        With a fixed ``seed`` the validation batch must be identical (same
        images, same order) on every call. With ``seed=None`` repeated calls
        must eventually yield a different first validation image.
        Within-epoch shuffling is kept off (``shuffle=False``) in both phases
        so that only the train/val shuffle is exercised.
        """
        data_dir = "../test-assets/"
        batch_size = 2
        validation_split = 0.5
        classes = 2
        seed = 5               # random number seed
        max_attempts = 100     # bound on unseeded reshuffles before failing

        # there should be 2 images in the validation set, and we check that
        # the whole batch (both images, in order) is reproduced with a
        # fixed seed
        image_batches = []
        mask_batches = []
        for _ in range(10):
            _, _, val_generator, _ = dataset.create_generators(
                data_dir, batch_size, validation_split=validation_split,
                mask="inner", shuffle_train_val=True, shuffle=False,
                seed=seed, normalize_images=True)

            images, masks = next(val_generator)
            self.assertEqual(images.shape, (2, 216, 256, 1))
            self.assertEqual(masks.shape, (2, 216, 256, classes))

            # also check image normalization
            for image in images:
                self.assertAlmostEqual(np.mean(image), 0)
                self.assertAlmostEqual(np.std(image), 1, places=5)

            image_batches.append(images)
            mask_batches.append(masks)

        # every seeded run must produce exactly the same batch
        for images in image_batches[1:]:
            np.testing.assert_array_equal(image_batches[0], images)
        for masks in mask_batches[1:]:
            np.testing.assert_array_equal(mask_batches[0], masks)

        # now test that things get shuffled if we don't specify a seed
        def first_val_image():
            """Build unseeded generators and return the first validation image."""
            _, _, val_generator, _ = dataset.create_generators(
                data_dir, batch_size, validation_split=validation_split,
                mask="both", shuffle_train_val=True, shuffle=False,
                seed=None, normalize_images=True)
            images, _ = next(val_generator)
            return images[0]

        image0 = first_val_image()
        for _ in range(max_attempts):
            if not np.array_equal(image0, first_val_image()):
                break           # arrays differ (= success!)
        else:
            # bounded retry: fail instead of looping forever
            self.fail("validation set was never reshuffled in {} attempts "
                      "with seed=None".format(max_attempts))
コード例 #2
0
    def _test_no_validation(self, mask):
        """Check the generators when ``validation_split`` is 0.

        All 3 test images go to training. With a batch size of 2, the training
        generator yields batches of 2, then 1, then wraps around to 2 again.
        No validation steps are reported and no validation generator is
        returned.

        Args:
            mask: mask selection passed to ``dataset.create_generators``;
                ``'both'`` is expected to give 3 mask classes, anything
                else 2.
        """
        data_dir = "../test-assets/"
        batch_size = 2
        validation_split = 0.0

        (train_generator, train_steps_per_epoch,
         val_generator, val_steps_per_epoch) = dataset.create_generators(
             data_dir, batch_size,
             validation_split=validation_split,
             mask=mask)

        self.assertEqual(train_steps_per_epoch, 2)
        self.assertEqual(val_steps_per_epoch, 0)

        classes = 3 if mask == 'both' else 2

        # first 2 train images, the last train image (for a total of 3),
        # then the first 2 train images again
        for expected_batch in (2, 1, 2):
            images, masks = next(train_generator)
            self.assertEqual(images.shape, (expected_batch, 216, 256, 1))
            self.assertEqual(masks.shape,
                             (expected_batch, 216, 256, classes))

        # validation generator should be nothing
        self.assertIsNone(val_generator)
コード例 #3
0
    def _test_generator(self, mask):
        """Check train/val generators for a 0.5 validation split.

        The 3 test images are split into 1 training image and 2 validation
        images. With a batch size of 2, each generator needs exactly one step
        and yields a single batch holding all of its images.

        Args:
            mask: mask selection passed to ``dataset.create_generators``;
                ``'both'`` is expected to give 3 mask classes, anything
                else 2.
        """
        expected_classes = 3 if mask == 'both' else 2

        (train_gen, train_steps,
         val_gen, val_steps) = dataset.create_generators(
             "../test-assets/", 2,
             validation_split=0.5,
             mask=mask)

        self.assertEqual(train_steps, 1)
        self.assertEqual(val_steps, 1)

        # (generator, number of images expected in its single batch)
        for gen, n_images in ((train_gen, 1), (val_gen, 2)):
            batch_images, batch_masks = next(gen)
            self.assertEqual(batch_images.shape, (n_images, 216, 256, 1))
            self.assertEqual(batch_masks.shape,
                             (n_images, 216, 256, expected_classes))
コード例 #4
0
ファイル: train.py プロジェクト: ol-sen/cardiac-segmentation
def train():
    """Train a segmentation model on ACDC volumes.

    All settings come from ``opts.parse_arguments()``. Data are loaded with
    ``ACDC.load(args.datadir, "RV", "ES")`` and resized to 216x256.

    If ``args.cross_val_folds`` is set, ``KFold`` cross-validation with that
    many folds is run. Every fold trains a freshly built and compiled model,
    so no fold starts from weights that were already trained on its
    validation volumes. Otherwise a single training run is performed.

    Each run writes:

    * a training plot to ``args.outfile_plot`` + suffix + ``.png``;
    * the final model to ``args.outdir``/``args.outfile`` + suffix + ``.hdf5``;
    * with ``args.checkpoint``, the best-``val_dice`` weights to ``args.outdir``.

    The suffix is ``-fold<k>`` under cross-validation and empty otherwise.

    Raises:
        ValueError: if ``args.loss`` is unknown, if no images/masks are
            loaded, or if more folds are requested than there are volumes.
    """
    logging.basicConfig(level=logging.INFO)

    args = opts.parse_arguments()

    augmentation_args = {
        'rotation_range': args.rotation_range,
        'width_shift_range': args.width_shift_range,
        'height_shift_range': args.height_shift_range,
        'shear_range': args.shear_range,
        'zoom_range': args.zoom_range,
        'fill_mode': args.fill_mode,
        'alpha': args.alpha,
        'sigma': args.sigma,
    }

    # select loss function: pixel-wise crossentropy, soft dice or soft
    # jaccard coefficient (validated before any data is loaded)
    if args.loss == 'pixel':

        def lossfunc(y_true, y_pred):
            return loss.weighted_categorical_crossentropy(
                y_true, y_pred, args.loss_weights)
    elif args.loss == 'dice':

        def lossfunc(y_true, y_pred):
            return loss.sorensen_dice_loss(y_true, y_pred, args.loss_weights)
    elif args.loss == 'jaccard':

        def lossfunc(y_true, y_pred):
            return loss.jaccard_loss(y_true, y_pred, args.loss_weights)
    else:
        raise ValueError("Unknown loss ({})".format(args.loss))

    def dice(y_true, y_pred):
        batch_dice_coefs = loss.sorensen_dice(y_true, y_pred, axis=[1, 2])
        dice_coefs = K.mean(batch_dice_coefs, axis=0)
        return dice_coefs[1]  # HACK for 2-class case

    def jaccard(y_true, y_pred):
        batch_jaccard_coefs = loss.jaccard(y_true, y_pred, axis=[1, 2])
        jaccard_coefs = K.mean(batch_jaccard_coefs, axis=0)
        return jaccard_coefs[1]  # HACK for 2-class case

    metrics = ['accuracy', dice, jaccard]

    # only keep optimizer args that have been set
    # (not all optimizers have args like `momentum' or `decay')
    optimizer_args = {
        key: value
        for key, value in (('lr', args.learning_rate),
                           ('momentum', args.momentum),
                           ('decay', args.decay))
        if value is not None
    }

    logging.info("Loading dataset...")
    ACDC.load(args.datadir, "RV", "ES")
    ACDC.resize(216, 256)

    images = ACDC.get_images()
    masks = ACDC.get_masks()
    if len(images) == 0 or len(masks) == 0:
        raise ValueError("No images or masks loaded from {}".format(
            args.datadir))

    # Volumes are (height, width, depth) and masks (height, width, depth,
    # classes). Move the depth (slice) axis to the front so each slice is one
    # sample; a plain reshape() would reinterpret the element order and
    # scramble the pixels.
    for i, volume in enumerate(images):
        images[i] = np.moveaxis(volume, 2, 0)[..., np.newaxis]
    for i, volume_mask in enumerate(masks):
        masks[i] = np.moveaxis(volume_mask, 2, 0)

    _, height, width, channels = images[0].shape
    _, _, _, classes = masks[0].shape

    logging.info("Building model...")
    string_to_model = {
        "unet": models.unet,
        "dilated-unet": models.dilated_unet,
        "dilated-densenet": models.dilated_densenet,
        "dilated-densenet2": models.dilated_densenet2,
        "dilated-densenet3": models.dilated_densenet3,
    }
    model_builder = string_to_model[args.model]

    def build_model():
        """Instantiate the selected architecture and optionally load weights."""
        m = model_builder(height=height,
                          width=width,
                          channels=channels,
                          classes=classes,
                          features=args.features,
                          depth=args.depth,
                          padding=args.padding,
                          temperature=args.temperature,
                          batchnorm=args.batchnorm,
                          dropout=args.dropout)
        m.summary()
        if args.load_weights:
            logging.info("Loading saved weights from file: {}".format(
                args.load_weights))
            m.load_weights(args.load_weights)
        return m

    def build_and_compile():
        """Return (template model, model to fit), freshly built and compiled."""
        if args.multi_gpu:
            # build the template model under /cpu:0, then replicate it
            with tf.device('/cpu:0'):
                m = build_model()
            trainable = multi_gpu_model(m, gpus=2)
        else:
            m = build_model()
            trainable = m
        # a fresh optimizer per build so no optimizer state leaks between runs
        trainable.compile(optimizer=select_optimizer(args.optimizer,
                                                     optimizer_args),
                          loss=lossfunc,
                          metrics=metrics)
        return m, trainable

    def make_callbacks(m, suffix):
        """Return the checkpoint callbacks (if enabled) for one run."""
        if not args.checkpoint:
            return []
        # the Keras template fields stay unformatted; suffix is pre-formatted
        filepath = os.path.join(
            args.outdir, "weights-{epoch:02d}-{val_dice:.4f}" + suffix + ".hdf5")
        if args.multi_gpu:
            checkpoint = MultiGPUModelCheckpoint(filepath,
                                                 m,
                                                 monitor='val_dice',
                                                 verbose=1,
                                                 save_best_only=True,
                                                 mode='max')
        else:
            checkpoint = ModelCheckpoint(filepath,
                                         monitor='val_dice',
                                         verbose=1,
                                         save_best_only=True,
                                         mode='max')
        return [checkpoint]

    def run(train_indexes, val_indexes, suffix):
        """Train a fresh model on one train/val split; return its History."""
        train_generator, train_steps_per_epoch, \
            val_generator, val_steps_per_epoch = dataset.create_generators(
                images, masks,
                args.datadir, args.batch_size,
                train_indexes, val_indexes,
                validation_split=args.validation_split,
                shuffle_train_val=args.shuffle_train_val,
                shuffle=args.shuffle,
                seed=args.seed,
                normalize_images=args.normalize,
                augment_training=args.augment_training,
                augment_validation=args.augment_validation,
                augmentation_args=augmentation_args)

        m, trainable = build_and_compile()

        logging.info("Begin training.")
        history = trainable.fit_generator(
            train_generator,
            epochs=args.epochs,
            steps_per_epoch=train_steps_per_epoch,
            validation_data=val_generator,
            validation_steps=val_steps_per_epoch,
            callbacks=make_callbacks(m, suffix),
            verbose=2)

        save_plot(args.outfile_plot + suffix + ".png", history)
        m.save(os.path.join(args.outdir, args.outfile + suffix + ".hdf5"))
        return history

    if args.cross_val_folds is not None:
        if args.cross_val_folds > len(images):
            raise ValueError(
                "Number of cross validation folds must be not more than {}.".
                format(len(images)))

        kf = KFold(n_splits=args.cross_val_folds)
        val_dice_values = []
        for fold, (train_indexes, val_indexes) in enumerate(
                kf.split(images), start=1):
            print("fold #{}".format(fold))
            print("{} {}".format(train_indexes, val_indexes))
            history = run(train_indexes, val_indexes, "-fold{}".format(fold))
            val_dice_values.append(max(history.history['val_dice']))

        print("Maximum dice values on validation sets are {}.".format(
            val_dice_values))
        print("Mean dice value is {}.".format(np.mean(val_dice_values)))
    else:
        history = run([], [], "")
        print("Maximum dice value on validation set is {}.".format(
            max(history.history['val_dice'])))
コード例 #5
0
ファイル: eval.py プロジェクト: prateekiitr/D-N-N_HEALTH-CARE
def main():
    """Evaluate saved model weights on the training and validation sets.

    Settings come from ``opts.parse_arguments()``. Some options are
    repurposed here (sort of a hack):

    * ``args.outfile`` -- basename for the per-image dice/jaccard scores,
      written to ``<outfile>.train`` and ``<outfile>.val``;
    * ``args.checkpoint`` -- when set, save one PNG per image, named with its
      index and dice score;
    * ``args.load_weights`` -- required; the weights file to evaluate.

    Raises:
        ValueError: if ``args.load_weights`` is not set.
    """
    import itertools

    args = opts.parse_arguments()
    if not args.load_weights:
        # fail before loading data instead of inside m.load_weights(None)
        raise ValueError("args.load_weights must name a weights file to evaluate")

    print("Loading dataset...")
    augmentation_args = {
        'rotation_range': args.rotation_range,
        'width_shift_range': args.width_shift_range,
        'height_shift_range': args.height_shift_range,
        'shear_range': args.shear_range,
        'zoom_range': args.zoom_range,
        'fill_mode': args.fill_mode,
        'alpha': args.alpha,
        'sigma': args.sigma,
    }
    train_generator, train_steps_per_epoch, \
        val_generator, val_steps_per_epoch = dataset.create_generators(
            args.datadir, args.batch_size,
            validation_split=args.validation_split,
            mask=args.classes,
            shuffle_train_val=args.shuffle_train_val,
            shuffle=args.shuffle,
            seed=args.seed,
            normalize_images=args.normalize,
            augment_training=args.augment_training,
            augment_validation=args.augment_validation,
            augmentation_args=augmentation_args)

    # Get image dimensions from the first batch, then push that batch back
    # so compute_statistics starts from the first batch, not the second.
    images, masks = next(train_generator)
    _, height, width, channels = images.shape
    _, _, _, classes = masks.shape
    train_generator = itertools.chain([(images, masks)], train_generator)

    print("Building model...")
    string_to_model = {
        "unet": models.unet,
        "dilated-unet": models.dilated_unet,
        "dilated-densenet": models.dilated_densenet,
        "dilated-densenet2": models.dilated_densenet2,
        "dilated-densenet3": models.dilated_densenet3,
    }
    model = string_to_model[args.model]

    m = model(height=height,
              width=width,
              channels=channels,
              classes=classes,
              features=args.features,
              depth=args.depth,
              padding=args.padding,
              temperature=args.temperature,
              batchnorm=args.batchnorm,
              dropout=args.dropout)

    m.load_weights(args.load_weights)

    print("Training Set:")
    train_dice, train_jaccard, train_images = compute_statistics(
        m,
        train_generator,
        train_steps_per_epoch,
        return_images=args.checkpoint)
    print()
    print("Validation Set:")
    if val_generator is None:
        # create_generators returns no validation generator when there is
        # no validation split (e.g. validation_split == 0)
        print("(no validation data)")
        val_dice, val_jaccard, val_images = [], [], []
    else:
        val_dice, val_jaccard, val_images = compute_statistics(
            m, val_generator, val_steps_per_epoch,
            return_images=args.checkpoint)

    if args.outfile:
        train_data = np.asarray([train_dice, train_jaccard]).T
        val_data = np.asarray([val_dice, val_jaccard]).T
        np.savetxt(args.outfile + ".train", train_data)
        np.savetxt(args.outfile + ".val", val_data)

    if args.checkpoint:
        print("Saving images...")
        for prefix, scores, saved in (("train", train_dice, train_images),
                                      ("val", val_dice, val_images)):
            for i, score in enumerate(scores):
                image, mask_true, mask_pred = saved[i]
                figname = "{}-{:03d}-{:.3f}.png".format(prefix, i, score)
                save_image(figname, image, mask_true, np.round(mask_pred))
コード例 #6
0
def train():
    """Train a segmentation model using settings from ``opts.parse_arguments()``.

    Steps:

    1. Build train/validation generators.
    2. Infer input and mask dimensions from the first training batch.
    3. Build the model named by ``args.model`` (only ``'unet'`` is registered
       here) and optionally load ``args.load_weights``.
    4. Compile it with the selected loss and optimizer, then fit it.

    With ``args.checkpoint`` set, the best weights are saved to
    ``args.outdir`` during training, judged by ``val_acc`` for the
    ``'pixel'`` loss and ``val_dice`` for the ``'dice'`` loss. The final
    model is always saved to ``os.path.join(args.outdir, args.outfile)``.

    Raises:
        ValueError: if ``args.loss`` is neither ``'pixel'`` nor ``'dice'``.
    """
    logging.basicConfig(level=logging.INFO)

    args = opts.parse_arguments()

    # Select the loss function (pixel-wise crossentropy or soft dice) together
    # with the checkpoint metric and filename template that go with it. An
    # unsupported loss is rejected here; previously it fell through and
    # failed later with an UnboundLocalError on `lossfunc`.
    if args.loss == 'pixel':
        def lossfunc(y_true, y_pred):
            return loss.weighted_categorical_crossentropy(
                y_true, y_pred, args.loss_weights)
        monitor = 'val_acc'
        checkpoint_name = "weights-{epoch:02d}-{val_acc:.4f}.hdf5"
    elif args.loss == 'dice':
        def lossfunc(y_true, y_pred):
            return loss.sorensen_dice_loss(y_true, y_pred, args.loss_weights)
        monitor = 'val_dice'
        checkpoint_name = "weights-{epoch:02d}-{val_dice:.4f}.hdf5"
    else:
        raise ValueError("Unknown loss ({})".format(args.loss))

    logging.info("Loading dataset...")
    augmentation_args = {
        'rotation_range': args.rotation_range,
        'width_shift_range': args.width_shift_range,
        'height_shift_range': args.height_shift_range,
        'shear_range': args.shear_range,
        'zoom_range': args.zoom_range,
        'fill_mode': args.fill_mode,
        'alpha': args.alpha,
        'sigma': args.sigma,
    }
    train_generator, train_steps_per_epoch, \
        val_generator, val_steps_per_epoch = dataset.create_generators(
            args.datadir, args.batch_size,
            validation_split=args.validation_split,
            mask=args.classes,
            shuffle_train_val=args.shuffle_train_val,
            shuffle=args.shuffle,
            seed=args.seed,
            normalize_images=args.normalize,
            augment_training=args.augment_training,
            augment_validation=args.augment_validation,
            augmentation_args=augmentation_args)

    # get image dimensions from first batch
    images, masks = next(train_generator)
    _, height, width, channels = images.shape
    _, _, _, classes = masks.shape

    logging.info("Building model...")
    string_to_model = {
        "unet": models.unet,
    }
    model = string_to_model[args.model]
    m = model(height=height, width=width, channels=channels, classes=classes,
              features=args.features, depth=args.depth, padding=args.padding,
              temperature=args.temperature, batchnorm=args.batchnorm,
              dropout=args.dropout)

    m.summary()

    if args.load_weights:
        logging.info("Loading saved weights from file: {}".format(args.load_weights))
        m.load_weights(args.load_weights)

    # instantiate optimizer, and only keep args that have been set
    # (not all optimizers have args like `momentum' or `decay')
    optimizer_args = {
        key: value
        for key, value in (('lr', args.learning_rate),
                           ('momentum', args.momentum),
                           ('decay', args.decay))
        if value is not None
    }
    optimizer = select_optimizer(args.optimizer, optimizer_args)

    def dice(y_true, y_pred):
        batch_dice_coefs = loss.sorensen_dice(y_true, y_pred, axis=[1, 2])
        dice_coefs = K.mean(batch_dice_coefs, axis=0)
        return dice_coefs[1]    # HACK for 2-class case

    metrics = ['accuracy', dice]

    m.compile(optimizer=optimizer, loss=lossfunc, metrics=metrics)

    # automatic saving of model during training
    if args.checkpoint:
        checkpoint = ModelCheckpoint(
            os.path.join(args.outdir, checkpoint_name), monitor=monitor,
            verbose=1, save_best_only=True, mode='max')
        callbacks = [checkpoint]
    else:
        callbacks = []

    # train
    logging.info("Begin training.")
    m.fit_generator(train_generator,
                    epochs=args.epochs,
                    steps_per_epoch=train_steps_per_epoch,
                    validation_data=val_generator,
                    validation_steps=val_steps_per_epoch,
                    callbacks=callbacks,
                    verbose=2)

    m.save(os.path.join(args.outdir, args.outfile))