Example #1
def main():
    # Sort of a hack:
    # args.checkpoint = turns on saving of images
    args = opts.parse_arguments()
    args.checkpoint = False  # override for now

    glob_search = os.path.join(args.datadir, "patient*")
    patient_dirs = sorted(glob.glob(glob_search))
    if len(patient_dirs) == 0:
        raise Exception("No patient directors found in {}".format(data_dir))

    # get image dimensions from first patient
    images, _, _, _ = load_patient_images(patient_dirs[0], args.normalize)
    _, height, width, channels = images.shape
    classes = 2  # hard coded for now
    contour_type = {'inner': 'i', 'outer': 'o'}[args.classes]

    print("Building model...")
    string_to_model = {
        "unet": models.unet,
        "dilated-unet": models.dilated_unet,
        "dilated-densenet": models.dilated_densenet,
        "dilated-densenet2": models.dilated_densenet2,
        "dilated-densenet3": models.dilated_densenet3,
    }
    model = string_to_model[args.model]

    m = model(height=height,
              width=width,
              channels=channels,
              classes=classes,
              features=args.features,
              depth=args.depth,
              padding=args.padding,
              temperature=args.temperature,
              batchnorm=args.batchnorm,
              dropout=args.dropout)

    m.load_weights(args.load_weights)

    for path in patient_dirs:
        ret = load_patient_images(path, args.normalize)
        images, patient_number, frame_indices, rotated = ret

        predictions = []
        for image in images:
            mask_pred = m.predict(image[None, :, :, :])  # feed one at a time
            predictions.append((image[:, :, 0], mask_pred[0, :, :, 1]))

        for (image, mask), frame_index in zip(predictions, frame_indices):
            filename = "P{:02d}-{:04d}-{}contour-auto.txt".format(
                patient_number, frame_index, contour_type)
            outpath = os.path.join(args.outdir, filename)
            print(filename)

            contour = get_contours(mask)
            if rotated:
                height, width = image.shape
                x, y = contour.T
                x, y = height - y, x
                contour = np.vstack((x, y)).T

            np.savetxt(outpath, contour, fmt='%i', delimiter=' ')

            if args.checkpoint:
                filename = "P{:02d}-{:04d}-{}contour-auto.png".format(
                    patient_number, frame_index, contour_type)
                outpath = os.path.join(args.outdir, filename)
                save_image(outpath, image, np.round(mask))
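
The example above relies on a project helper, get_contours(), to turn the predicted probability mask into a list of pixel coordinates before saving it with np.savetxt. As a point of reference only, here is a minimal sketch of how such a helper could be written with scikit-image; the name, signature, and behaviour below are assumptions for illustration, not the project's actual implementation.

import numpy as np
from skimage import measure

def get_contours(mask, level=0.5):
    # Hypothetical sketch: extract the longest iso-contour of a predicted
    # mask and return it as integer (x, y) pixel coordinates.
    contours = measure.find_contours(mask, level)  # list of (N, 2) arrays in (row, col) order
    contour = max(contours, key=len)               # keep only the longest contour
    rows, cols = contour.T
    return np.stack((cols, rows), axis=1).round().astype(int)  # (x, y) order for np.savetxt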
Example #2
def main():
    # Sort of a hack:
    # args.outfile = file basename to store train / val dice scores
    # args.checkpoint = turns on saving of images
    args = opts.parse_arguments()

    print("Loading dataset...")
    augmentation_args = {
        'rotation_range': args.rotation_range,
        'width_shift_range': args.width_shift_range,
        'height_shift_range': args.height_shift_range,
        'shear_range': args.shear_range,
        'zoom_range': args.zoom_range,
        'fill_mode': args.fill_mode,
        'alpha': args.alpha,
        'sigma': args.sigma,
    }
    train_generator, train_steps_per_epoch, \
        val_generator, val_steps_per_epoch = dataset.create_generators(
            args.datadir, args.batch_size,
            validation_split=args.validation_split,
            mask=args.classes,
            shuffle_train_val=args.shuffle_train_val,
            shuffle=args.shuffle,
            seed=args.seed,
            normalize_images=args.normalize,
            augment_training=args.augment_training,
            augment_validation=args.augment_validation,
            augmentation_args=augmentation_args)

    # get image dimensions from first batch
    images, masks = next(train_generator)
    _, height, width, channels = images.shape
    _, _, _, classes = masks.shape

    print("Building model...")
    string_to_model = {
        "unet": models.unet,
        "dilated-unet": models.dilated_unet,
        "dilated-densenet": models.dilated_densenet,
        "dilated-densenet2": models.dilated_densenet2,
        "dilated-densenet3": models.dilated_densenet3,
    }
    model = string_to_model[args.model]

    m = model(height=height,
              width=width,
              channels=channels,
              classes=classes,
              features=args.features,
              depth=args.depth,
              padding=args.padding,
              temperature=args.temperature,
              batchnorm=args.batchnorm,
              dropout=args.dropout)

    m.load_weights(args.load_weights)

    print("Training Set:")
    train_dice, train_jaccard, train_images = compute_statistics(
        m,
        train_generator,
        train_steps_per_epoch,
        return_images=args.checkpoint)
    print()
    print("Validation Set:")
    val_dice, val_jaccard, val_images = compute_statistics(
        m, val_generator, val_steps_per_epoch, return_images=args.checkpoint)

    if args.outfile:
        train_data = np.asarray([train_dice, train_jaccard]).T
        val_data = np.asarray([val_dice, val_jaccard]).T
        np.savetxt(args.outfile + ".train", train_data)
        np.savetxt(args.outfile + ".val", val_data)

    if args.checkpoint:
        print("Saving images...")
        for i, dice in enumerate(train_dice):
            image, mask_true, mask_pred = train_images[i]
            figname = "train-{:03d}-{:.3f}.png".format(i, dice)
            save_image(figname, image, mask_true, np.round(mask_pred))
        for i, dice in enumerate(val_dice):
            image, mask_true, mask_pred = val_images[i]
            figname = "val-{:03d}-{:.3f}.png".format(i, dice)
            save_image(figname, image, mask_true, np.round(mask_pred))
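
compute_statistics() is not shown in this example; it evaluates the model over the generators and returns per-image Dice and Jaccard scores (plus the images when return_images=True). For orientation, the per-image scores it aggregates can be computed on binarized masks roughly as follows. This is only a hedged illustration of the metrics themselves, not the project's code.

import numpy as np

def hard_dice(mask_true, mask_pred, smooth=1e-8):
    # Dice coefficient of two binary masks: 2|A ∩ B| / (|A| + |B|).
    intersection = np.sum(mask_true * mask_pred)
    return (2.0 * intersection + smooth) / (np.sum(mask_true) + np.sum(mask_pred) + smooth)

def hard_jaccard(mask_true, mask_pred, smooth=1e-8):
    # Jaccard index (IoU) of two binary masks: |A ∩ B| / |A ∪ B|.
    intersection = np.sum(mask_true * mask_pred)
    union = np.sum(mask_true) + np.sum(mask_pred) - intersection
    return (intersection + smooth) / (union + smooth)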
Example #3
def train():
    logging.basicConfig(level=logging.INFO)

    args = opts.parse_arguments()

    #logging.info("Loading dataset...")
    augmentation_args = {
        'rotation_range': args.rotation_range,
        'width_shift_range': args.width_shift_range,
        'height_shift_range': args.height_shift_range,
        'shear_range': args.shear_range,
        'zoom_range': args.zoom_range,
        'fill_mode': args.fill_mode,
        'alpha': args.alpha,
        'sigma': args.sigma,
    }

    ACDC.load(args.datadir, "RV", "ES")
    ACDC.resize(216, 256)
    channels = 1

    images = ACDC.get_images()
    for i in range(len(images)):
        height, width, depth = images[i].shape
        #images[i] = images[i].reshape(depth, channels, height, width)
        images[i] = images[i].reshape(depth, height, width, channels)

    masks = ACDC.get_masks()
    for i in range(len(masks)):
        height, width, depth, classes = masks[i].shape
        masks[i] = masks[i].reshape(depth, height, width, classes)

    #images, masks = RVSC.load(args.datadir, args.classes)

    ## get image dimensions
    #_, height, width, channels = images[0].shape
    #_, _, _, classes = masks[0].shape

    logging.info("Building model...")
    string_to_model = {
        "unet": models.unet,
        "dilated-unet": models.dilated_unet,
        "dilated-densenet": models.dilated_densenet,
        "dilated-densenet2": models.dilated_densenet2,
        "dilated-densenet3": models.dilated_densenet3,
    }

    if args.multi_gpu:
        with tf.device('/cpu:0'):
            model = string_to_model[args.model]
            m = model(height=height,
                      width=width,
                      channels=channels,
                      classes=classes,
                      features=args.features,
                      depth=args.depth,
                      padding=args.padding,
                      temperature=args.temperature,
                      batchnorm=args.batchnorm,
                      dropout=args.dropout)

            m.summary()

            if args.load_weights:
                logging.info("Loading saved weights from file: {}".format(
                    args.load_weights))
                m.load_weights(args.load_weights)
    else:
        model = string_to_model[args.model]
        m = model(height=height,
                  width=width,
                  channels=channels,
                  classes=classes,
                  features=args.features,
                  depth=args.depth,
                  padding=args.padding,
                  temperature=args.temperature,
                  batchnorm=args.batchnorm,
                  dropout=args.dropout)

        m.summary()

        if args.load_weights:
            logging.info("Loading saved weights from file: {}".format(
                args.load_weights))
            m.load_weights(args.load_weights)

    # instantiate optimizer, and only keep args that have been set
    # (not all optimizers have args like `momentum' or `decay')
    optimizer_args = {
        'lr': args.learning_rate,
        'momentum': args.momentum,
        'decay': args.decay
    }
    for k in list(optimizer_args):
        if optimizer_args[k] is None:
            del optimizer_args[k]
    optimizer = select_optimizer(args.optimizer, optimizer_args)

    # select loss function: pixel-wise crossentropy, soft dice or soft
    # jaccard coefficient
    if args.loss == 'pixel':

        def lossfunc(y_true, y_pred):
            return loss.weighted_categorical_crossentropy(
                y_true, y_pred, args.loss_weights)
    elif args.loss == 'dice':

        def lossfunc(y_true, y_pred):
            return loss.sorensen_dice_loss(y_true, y_pred, args.loss_weights)
    elif args.loss == 'jaccard':

        def lossfunc(y_true, y_pred):
            return loss.jaccard_loss(y_true, y_pred, args.loss_weights)
    else:
        raise Exception("Unknown loss ({})".format(args.loss))

    def dice(y_true, y_pred):
        batch_dice_coefs = loss.sorensen_dice(y_true, y_pred, axis=[1, 2])
        dice_coefs = K.mean(batch_dice_coefs, axis=0)
        return dice_coefs[1]  # HACK for 2-class case

    def jaccard(y_true, y_pred):
        batch_jaccard_coefs = loss.jaccard(y_true, y_pred, axis=[1, 2])
        jaccard_coefs = K.mean(batch_jaccard_coefs, axis=0)
        return jaccard_coefs[1]  # HACK for 2-class case

    metrics = ['accuracy', dice, jaccard]

    if args.multi_gpu:
        parallel_model = multi_gpu_model(m, gpus=2)
        parallel_model.compile(optimizer=optimizer,
                               loss=lossfunc,
                               metrics=metrics)
    else:
        m.compile(optimizer=optimizer, loss=lossfunc, metrics=metrics)

    train_indexes = []
    val_indexes = []
    if args.cross_val_folds is not None:
        if args.cross_val_folds > len(images):
            raise Exception(
                "Number of cross-validation folds must not be more than {}."
                .format(len(images)))

        kf = KFold(n_splits=args.cross_val_folds)
        val_dice_values = []
        fold = 1
        for train_indexes, val_indexes in kf.split(images):
            print("fold #{}".format(fold))
            print("{} {}".format(train_indexes, val_indexes))
            train_generator, train_steps_per_epoch, \
                val_generator, val_steps_per_epoch = dataset.create_generators(
                    images, masks, args.datadir, args.batch_size,
                    train_indexes, val_indexes,
                    validation_split=args.validation_split,
                    shuffle_train_val=args.shuffle_train_val,
                    shuffle=args.shuffle,
                    seed=args.seed,
                    normalize_images=args.normalize,
                    augment_training=args.augment_training,
                    augment_validation=args.augment_validation,
                    augmentation_args=augmentation_args)

            # automatic saving of model during training
            if args.checkpoint:
                monitor = 'val_dice'
                mode = 'max'
                filepath = os.path.join(
                    args.outdir, "weights-{epoch:02d}-{val_dice:.4f}" +
                    "-fold{}.hdf5".format(fold))

                if args.multi_gpu:
                    checkpoint = MultiGPUModelCheckpoint(filepath,
                                                         m,
                                                         monitor=monitor,
                                                         verbose=1,
                                                         save_best_only=True,
                                                         mode=mode)
                else:
                    checkpoint = ModelCheckpoint(filepath,
                                                 monitor=monitor,
                                                 verbose=1,
                                                 save_best_only=True,
                                                 mode=mode)

                callbacks = [checkpoint]
            else:
                callbacks = []

            # train
            logging.info("Begin training.")
            if args.multi_gpu:
                history = parallel_model.fit_generator(
                    train_generator,
                    epochs=args.epochs,
                    steps_per_epoch=train_steps_per_epoch,
                    validation_data=val_generator,
                    validation_steps=val_steps_per_epoch,
                    callbacks=callbacks,
                    verbose=2)
            else:
                history = m.fit_generator(
                    train_generator,
                    epochs=args.epochs,
                    steps_per_epoch=train_steps_per_epoch,
                    validation_data=val_generator,
                    validation_steps=val_steps_per_epoch,
                    callbacks=callbacks,
                    verbose=2)

            save_plot(args.outfile_plot + "-fold{}.png".format(fold), history)
            m.save(
                os.path.join(args.outdir,
                             args.outfile + "-fold{}.hdf5".format(fold)))
            val_dice_values += [max(history.history['val_dice'])]
            fold += 1

        print("Maximum dice values on validation sets are {}.".format(
            val_dice_values))
        print("Mean dice value is {}.".format(np.mean(val_dice_values)))
    else:
        train_generator, train_steps_per_epoch, \
            val_generator, val_steps_per_epoch = dataset.create_generators(
                images, masks,
                args.datadir, args.batch_size,
                train_indexes, val_indexes,
                validation_split=args.validation_split,
                shuffle_train_val=args.shuffle_train_val,
                shuffle=args.shuffle,
                seed=args.seed,
                normalize_images=args.normalize,
                augment_training=args.augment_training,
                augment_validation=args.augment_validation,
                augmentation_args=augmentation_args)

        # automatic saving of model during training
        if args.checkpoint:
            monitor = 'val_dice'
            mode = 'max'
            filepath = os.path.join(args.outdir,
                                    "weights-{epoch:02d}-{val_dice:.4f}.hdf5")

            if args.multi_gpu:
                checkpoint = MultiGPUModelCheckpoint(filepath,
                                                     m,
                                                     monitor=monitor,
                                                     verbose=1,
                                                     save_best_only=True,
                                                     mode=mode)
            else:
                checkpoint = ModelCheckpoint(filepath,
                                             monitor=monitor,
                                             verbose=1,
                                             save_best_only=True,
                                             mode=mode)

            callbacks = [checkpoint]
        else:
            callbacks = []

        # train
        logging.info("Begin training.")
        if args.multi_gpu:
            history = parallel_model.fit_generator(
                train_generator,
                epochs=args.epochs,
                steps_per_epoch=train_steps_per_epoch,
                validation_data=val_generator,
                validation_steps=val_steps_per_epoch,
                callbacks=callbacks,
                verbose=2)
        else:
            history = m.fit_generator(train_generator,
                                      epochs=args.epochs,
                                      steps_per_epoch=train_steps_per_epoch,
                                      validation_data=val_generator,
                                      validation_steps=val_steps_per_epoch,
                                      callbacks=callbacks,
                                      verbose=2)

        save_plot(args.outfile_plot + ".png", history)
        m.save(os.path.join(args.outdir, args.outfile + ".hdf5"))

        print("Maximum dice value on validation set is {}.".format(
            max(history.history['val_dice'])))
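
The loss functions selected above (loss.weighted_categorical_crossentropy, loss.sorensen_dice_loss, loss.jaccard_loss) live in the project's loss module and are not reproduced in this listing. As a rough sketch of the soft Dice variant only, assuming standard Keras backend ops and matching how loss.sorensen_dice is called in the dice metric above (per-class coefficients over the spatial axes), it might look like the following; the actual implementation may differ.

from keras import backend as K

def soft_sorensen_dice(y_true, y_pred, axis=None, smooth=1e-8):
    # Soft Dice per class: 2*sum(t*p) / (sum(t) + sum(p)), reduced over the
    # spatial axes so one coefficient per (batch item, class) remains.
    intersection = K.sum(y_true * y_pred, axis=axis)
    area_true = K.sum(y_true, axis=axis)
    area_pred = K.sum(y_pred, axis=axis)
    return (2 * intersection + smooth) / (area_true + area_pred + smooth)

def sorensen_dice_loss(y_true, y_pred, weights):
    # Weighted (1 - Dice), averaged over the batch; `weights` lets the
    # foreground class dominate in the 2-class setting used above.
    dice_coefs = K.mean(soft_sorensen_dice(y_true, y_pred, axis=[1, 2]), axis=0)
    w = K.constant(weights) / sum(weights)
    return 1.0 - K.sum(w * dice_coefs)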
Example #4
def train():
    logging.basicConfig(level=logging.INFO)

    args = opts.parse_arguments()

    logging.info("Loading dataset...")
    augmentation_args = {
        'rotation_range': args.rotation_range,
        'width_shift_range': args.width_shift_range,
        'height_shift_range': args.height_shift_range,
        'shear_range': args.shear_range,
        'zoom_range': args.zoom_range,
        'fill_mode': args.fill_mode,
        'alpha': args.alpha,
        'sigma': args.sigma,
    }
    train_generator, train_steps_per_epoch, \
        val_generator, val_steps_per_epoch = dataset.create_generators(
            args.datadir, args.batch_size,
            validation_split=args.validation_split,
            mask=args.classes,
            shuffle_train_val=args.shuffle_train_val,
            shuffle=args.shuffle,
            seed=args.seed,
            normalize_images=args.normalize,
            augment_training=args.augment_training,
            augment_validation=args.augment_validation,
            augmentation_args=augmentation_args)

    # get image dimensions from first batch
    images, masks = next(train_generator)
    _, height, width, channels = images.shape
    _, _, _, classes = masks.shape

    logging.info("Building model...")
    string_to_model = {
        "unet": models.unet,
    }
    model = string_to_model[args.model]
    m = model(height=height, width=width, channels=channels, classes=classes,
              features=args.features, depth=args.depth, padding=args.padding,
              temperature=args.temperature, batchnorm=args.batchnorm,
              dropout=args.dropout)

    m.summary()

    if args.load_weights:
        logging.info("Loading saved weights from file: {}".format(args.load_weights))
        m.load_weights(args.load_weights)

    # instantiate optimizer, and only keep args that have been set
    # (not all optimizers have args like `momentum' or `decay')
    optimizer_args = {
        'lr':       args.learning_rate,
        'momentum': args.momentum,
        'decay':    args.decay
    }
    for k in list(optimizer_args):
        if optimizer_args[k] is None:
            del optimizer_args[k]
    optimizer = select_optimizer(args.optimizer, optimizer_args)

    # select loss function: pixel-wise crossentropy or soft dice
    if args.loss == 'pixel':
        def lossfunc(y_true, y_pred):
            return loss.weighted_categorical_crossentropy(
                y_true, y_pred, args.loss_weights)
    elif args.loss == 'dice':
        def lossfunc(y_true, y_pred):
            return loss.sorensen_dice_loss(y_true, y_pred, args.loss_weights)
    else:
        raise Exception("Unknown loss ({})".format(args.loss))

    def dice(y_true, y_pred):
        batch_dice_coefs = loss.sorensen_dice(y_true, y_pred, axis=[1, 2])
        dice_coefs = K.mean(batch_dice_coefs, axis=0)
        return dice_coefs[1]    # HACK for 2-class case

    metrics = ['accuracy', dice]

    m.compile(optimizer=optimizer, loss=lossfunc, metrics=metrics)

    # automatic saving of model during training
    if args.checkpoint:
        if args.loss == 'pixel':
            filepath = os.path.join(
                args.outdir, "weights-{epoch:02d}-{val_acc:.4f}.hdf5")
            monitor = 'val_acc'
            mode = 'max'
        elif args.loss == 'dice':
            filepath = os.path.join(
                args.outdir, "weights-{epoch:02d}-{val_dice:.4f}.hdf5")
            monitor = 'val_dice'
            mode = 'max'
        checkpoint = ModelCheckpoint(
            filepath, monitor=monitor, verbose=1,
            save_best_only=True, mode=mode)
        callbacks = [checkpoint]
    else:
        callbacks = []

    # train
    logging.info("Begin training.")
    m.fit_generator(train_generator,
                    epochs=args.epochs,
                    steps_per_epoch=train_steps_per_epoch,
                    validation_data=val_generator,
                    validation_steps=val_steps_per_epoch,
                    callbacks=callbacks,
                    verbose=2)

    m.save(os.path.join(args.outdir, args.outfile))
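
select_optimizer() is another project helper used but not defined in these examples: it builds a Keras optimizer from its name and the already-filtered keyword arguments. A plausible, minimal version, assuming the pre-2.3 Keras optimizer API (lr, momentum, decay keywords) that the rest of the code targets, could look like the sketch below; the lookup-table entries are assumptions, not the project's definition.

from keras import optimizers

def select_optimizer(optimizer_name, optimizer_args):
    # Hypothetical helper: map a name to a Keras optimizer class and forward
    # only the keyword arguments that were actually set on the command line.
    optimizer_classes = {
        'sgd': optimizers.SGD,
        'rmsprop': optimizers.RMSprop,
        'adam': optimizers.Adam,
    }
    if optimizer_name not in optimizer_classes:
        raise Exception("Unknown optimizer ({})".format(optimizer_name))
    return optimizer_classes[optimizer_name](**optimizer_args)

# e.g. select_optimizer('sgd', {'lr': 0.01, 'momentum': 0.9, 'decay': 1e-6})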