def test_shuffle_train_val(self):
    """Check dataset shuffling prior to the train/val split.

    With a fixed seed the validation set must be identical across
    repeated generator constructions; with seed=None it must eventually
    differ.  (Shuffling within each epoch is not tested here.)
    """
    data_dir = "../test-assets/"
    batch_size = 2
    split = 0.5
    n_classes = 2
    fixed_seed = 5  # random number seed

    # There should be 2 images in the validation set; build the
    # generators repeatedly with the same seed and record the first
    # image/mask of each run -- they should always appear in the same
    # order.
    first_images = []
    first_masks = []
    for _ in range(10):
        _, _, val_generator, _ = dataset.create_generators(
            data_dir, batch_size, validation_split=split, mask="inner",
            shuffle_train_val=True, shuffle=False, seed=fixed_seed,
            normalize_images=True)
        images, masks = next(val_generator)
        self.assertEqual(images.shape, (2, 216, 256, 1))
        self.assertEqual(masks.shape, (2, 216, 256, n_classes))
        # also verify normalization: zero mean, unit standard deviation
        for image in images:
            self.assertAlmostEqual(np.mean(image), 0)
            self.assertAlmostEqual(np.std(image), 1, places=5)
        first_images.append(images[0])
        first_masks.append(masks[0])

    # the first image/mask must be identical across all runs
    reference_image = first_images[0]
    for image in first_images[1:]:
        np.testing.assert_array_equal(reference_image, image)
    reference_mask = first_masks[0]
    for mask in first_masks[1:]:
        np.testing.assert_array_equal(reference_mask, mask)

    # Without a seed, ordering should change between runs: rebuild the
    # generators until the first validation image differs.
    _, _, val_generator, _ = dataset.create_generators(
        data_dir, batch_size, validation_split=split, mask="both",
        shuffle_train_val=True, shuffle=False, seed=None,
        normalize_images=True)
    images, masks = next(val_generator)
    reference_image = images[0]
    while True:
        _, _, val_generator, _ = dataset.create_generators(
            data_dir, batch_size, validation_split=split, mask="both",
            shuffle_train_val=True, shuffle=True, seed=None,
            normalize_images=True)
        images, masks = next(val_generator)
        try:
            np.testing.assert_array_equal(reference_image, images[0])
        except AssertionError:
            break  # arrays are different (= success!)
def _test_no_validation(self, mask):
    """With validation_split=0 all images go to training and no
    validation generator is produced."""
    generators = dataset.create_generators(
        "../test-assets/", 2, validation_split=0.0, mask=mask)
    train_generator, train_steps_per_epoch, \
        val_generator, val_steps_per_epoch = generators

    self.assertEqual(train_steps_per_epoch, 2)
    self.assertEqual(val_steps_per_epoch, 0)

    n_classes = 3 if mask == 'both' else 2

    # Batches cycle through the 3 training images: first 2 images, then
    # the last 1, then wrap around to 2 again.
    for expected_batch in (2, 1, 2):
        images, masks = next(train_generator)
        self.assertEqual(images.shape, (expected_batch, 216, 256, 1))
        self.assertEqual(masks.shape, (expected_batch, 216, 256, n_classes))

    # validation generator should be nothing
    self.assertEqual(val_generator, None)
def _test_generator(self, mask):
    """A 50% split of the 3 test images yields 1 training image and 2
    validation images."""
    generators = dataset.create_generators(
        "../test-assets/", 2, validation_split=0.5, mask=mask)
    train_generator, train_steps_per_epoch, \
        val_generator, val_steps_per_epoch = generators

    self.assertEqual(train_steps_per_epoch, 1)
    self.assertEqual(val_steps_per_epoch, 1)

    n_classes = 3 if mask == 'both' else 2

    # one training image per epoch
    images, masks = next(train_generator)
    self.assertEqual(images.shape, (1, 216, 256, 1))
    self.assertEqual(masks.shape, (1, 216, 256, n_classes))

    # both validation images arrive in a single batch
    images, masks = next(val_generator)
    self.assertEqual(images.shape, (2, 216, 256, 1))
    self.assertEqual(masks.shape, (2, 216, 256, n_classes))
def train():
    """Train a segmentation model on the ACDC dataset.

    Command-line arguments (via ``opts.parse_arguments``) select the
    architecture, loss, optimizer, augmentation, multi-GPU mode, and
    either a single train/validation split or k-fold cross validation
    (``--cross-val-folds``).

    BUGFIX: ``KFold`` previously hard-coded ``n_splits=4``, silently
    ignoring ``args.cross_val_folds``; it now uses the requested fold
    count.
    """
    logging.basicConfig(level=logging.INFO)
    args = opts.parse_arguments()

    augmentation_args = {
        'rotation_range': args.rotation_range,
        'width_shift_range': args.width_shift_range,
        'height_shift_range': args.height_shift_range,
        'shear_range': args.shear_range,
        'zoom_range': args.zoom_range,
        'fill_mode': args.fill_mode,
        'alpha': args.alpha,
        'sigma': args.sigma,
    }

    # Load ACDC right-ventricle end-systole data and reshape each volume
    # to the (slices, height, width, channels) layout Keras expects.
    ACDC.load(args.datadir, "RV", "ES")
    ACDC.resize(216, 256)
    channels = 1
    images = ACDC.get_images()
    for i in range(len(images)):
        height, width, depth = images[i].shape
        images[i] = images[i].reshape(depth, height, width, channels)
    masks = ACDC.get_masks()
    for i in range(len(masks)):
        height, width, depth, classes = masks[i].shape
        masks[i] = masks[i].reshape(depth, height, width, classes)
    # NOTE(review): height/width/classes below come from the *last*
    # volume in the loop; presumably all volumes share these dimensions
    # after ACDC.resize() -- confirm against the ACDC loader.

    logging.info("Building model...")
    string_to_model = {
        "unet": models.unet,
        "dilated-unet": models.dilated_unet,
        "dilated-densenet": models.dilated_densenet,
        "dilated-densenet2": models.dilated_densenet2,
        "dilated-densenet3": models.dilated_densenet3,
    }
    if args.multi_gpu:
        # build the weights on the CPU so multi_gpu_model can share them
        with tf.device('/cpu:0'):
            model = string_to_model[args.model]
            m = model(height=height, width=width, channels=channels,
                      classes=classes, features=args.features,
                      depth=args.depth, padding=args.padding,
                      temperature=args.temperature,
                      batchnorm=args.batchnorm, dropout=args.dropout)
            m.summary()
            if args.load_weights:
                logging.info("Loading saved weights from file: {}".format(
                    args.load_weights))
                m.load_weights(args.load_weights)
    else:
        model = string_to_model[args.model]
        m = model(height=height, width=width, channels=channels,
                  classes=classes, features=args.features, depth=args.depth,
                  padding=args.padding, temperature=args.temperature,
                  batchnorm=args.batchnorm, dropout=args.dropout)
        m.summary()
        if args.load_weights:
            logging.info("Loading saved weights from file: {}".format(
                args.load_weights))
            m.load_weights(args.load_weights)

    # instantiate optimizer, and only keep args that have been set
    # (not all optimizers have args like `momentum' or `decay')
    optimizer_args = {
        'lr': args.learning_rate,
        'momentum': args.momentum,
        'decay': args.decay
    }
    for k in list(optimizer_args):
        if optimizer_args[k] is None:
            del optimizer_args[k]
    optimizer = select_optimizer(args.optimizer, optimizer_args)

    # select loss function: pixel-wise crossentropy, soft dice or soft
    # jaccard coefficient
    if args.loss == 'pixel':
        def lossfunc(y_true, y_pred):
            return loss.weighted_categorical_crossentropy(
                y_true, y_pred, args.loss_weights)
    elif args.loss == 'dice':
        def lossfunc(y_true, y_pred):
            return loss.sorensen_dice_loss(y_true, y_pred, args.loss_weights)
    elif args.loss == 'jaccard':
        def lossfunc(y_true, y_pred):
            return loss.jaccard_loss(y_true, y_pred, args.loss_weights)
    else:
        raise Exception("Unknown loss ({})".format(args.loss))

    def dice(y_true, y_pred):
        batch_dice_coefs = loss.sorensen_dice(y_true, y_pred, axis=[1, 2])
        dice_coefs = K.mean(batch_dice_coefs, axis=0)
        return dice_coefs[1]  # HACK for 2-class case

    def jaccard(y_true, y_pred):
        batch_jaccard_coefs = loss.jaccard(y_true, y_pred, axis=[1, 2])
        jaccard_coefs = K.mean(batch_jaccard_coefs, axis=0)
        return jaccard_coefs[1]  # HACK for 2-class case

    metrics = ['accuracy', dice, jaccard]

    if args.multi_gpu:
        parallel_model = multi_gpu_model(m, gpus=2)
        parallel_model.compile(optimizer=optimizer, loss=lossfunc,
                               metrics=metrics)
    else:
        m.compile(optimizer=optimizer, loss=lossfunc, metrics=metrics)

    train_indexes = []
    val_indexes = []
    if args.cross_val_folds is not None:
        if args.cross_val_folds > len(images):
            raise Exception(
                "Number of cross validation folds must be not more than {}.".
                format(len(images)))
        # BUGFIX: honor the requested fold count (was hard-coded to 4)
        kf = KFold(n_splits=args.cross_val_folds)
        val_dice_values = []
        fold = 1
        for train_indexes, val_indexes in kf.split(images):
            print("fold #{}".format(fold))
            print("{} {}".format(train_indexes, val_indexes))
            train_generator, train_steps_per_epoch, \
                val_generator, val_steps_per_epoch = dataset.create_generators(
                    images, masks, args.datadir, args.batch_size,
                    train_indexes, val_indexes,
                    validation_split=args.validation_split,
                    shuffle_train_val=args.shuffle_train_val,
                    shuffle=args.shuffle,
                    seed=args.seed,
                    normalize_images=args.normalize,
                    augment_training=args.augment_training,
                    augment_validation=args.augment_validation,
                    augmentation_args=augmentation_args)

            # automatic saving of the best model during training
            if args.checkpoint:
                monitor = 'val_dice'
                mode = 'max'
                filepath = os.path.join(
                    args.outdir, "weights-{epoch:02d}-{val_dice:.4f}" +
                    "-fold{}.hdf5".format(fold))
                if args.multi_gpu:
                    # checkpoint the template model, not the parallel wrapper
                    checkpoint = MultiGPUModelCheckpoint(filepath, m,
                                                        monitor=monitor,
                                                        verbose=1,
                                                        save_best_only=True,
                                                        mode=mode)
                else:
                    checkpoint = ModelCheckpoint(filepath, monitor=monitor,
                                                 verbose=1,
                                                 save_best_only=True,
                                                 mode=mode)
                callbacks = [checkpoint]
            else:
                callbacks = []

            # train
            logging.info("Begin training.")
            if args.multi_gpu:
                history = parallel_model.fit_generator(
                    train_generator,
                    epochs=args.epochs,
                    steps_per_epoch=train_steps_per_epoch,
                    validation_data=val_generator,
                    validation_steps=val_steps_per_epoch,
                    callbacks=callbacks,
                    verbose=2)
            else:
                history = m.fit_generator(
                    train_generator,
                    epochs=args.epochs,
                    steps_per_epoch=train_steps_per_epoch,
                    validation_data=val_generator,
                    validation_steps=val_steps_per_epoch,
                    callbacks=callbacks,
                    verbose=2)

            save_plot(args.outfile_plot + "-fold{}.png".format(fold), history)
            m.save(os.path.join(args.outdir,
                                args.outfile + "-fold{}.hdf5".format(fold)))
            val_dice_values += [max(history.history['val_dice'])]
            fold += 1

        print("Maximum dice values on validation sets are {}.".format(
            val_dice_values))
        print("Mean dice value is {}.".format(np.mean(val_dice_values)))
    else:
        # single train/validation split (no cross validation)
        train_generator, train_steps_per_epoch, \
            val_generator, val_steps_per_epoch = dataset.create_generators(
                images, masks, args.datadir, args.batch_size,
                train_indexes, val_indexes,
                validation_split=args.validation_split,
                shuffle_train_val=args.shuffle_train_val,
                shuffle=args.shuffle,
                seed=args.seed,
                normalize_images=args.normalize,
                augment_training=args.augment_training,
                augment_validation=args.augment_validation,
                augmentation_args=augmentation_args)

        # automatic saving of the best model during training
        if args.checkpoint:
            monitor = 'val_dice'
            mode = 'max'
            filepath = os.path.join(
                args.outdir, "weights-{epoch:02d}-{val_dice:.4f}.hdf5")
            if args.multi_gpu:
                checkpoint = MultiGPUModelCheckpoint(filepath, m,
                                                    monitor=monitor,
                                                    verbose=1,
                                                    save_best_only=True,
                                                    mode=mode)
            else:
                checkpoint = ModelCheckpoint(filepath, monitor=monitor,
                                             verbose=1, save_best_only=True,
                                             mode=mode)
            callbacks = [checkpoint]
        else:
            callbacks = []

        # train
        logging.info("Begin training.")
        if args.multi_gpu:
            history = parallel_model.fit_generator(
                train_generator,
                epochs=args.epochs,
                steps_per_epoch=train_steps_per_epoch,
                validation_data=val_generator,
                validation_steps=val_steps_per_epoch,
                callbacks=callbacks,
                verbose=2)
        else:
            history = m.fit_generator(
                train_generator,
                epochs=args.epochs,
                steps_per_epoch=train_steps_per_epoch,
                validation_data=val_generator,
                validation_steps=val_steps_per_epoch,
                callbacks=callbacks,
                verbose=2)

        save_plot(args.outfile_plot + ".png", history)
        m.save(os.path.join(args.outdir, args.outfile + ".hdf5"))
        print("Maximum dice value on validation set is {}.".format(
            max(history.history['val_dice'])))
def main():
    """Evaluate a trained segmentation model.

    Loads the model weights given by ``args.load_weights``, computes
    dice/jaccard statistics over the training and validation generators,
    and optionally writes per-image score files and segmentation images.
    """
    # Sort of a hack:
    # args.outfile = file basename to store train / val dice scores
    # args.checkpoint = turns on saving of images
    args = opts.parse_arguments()

    print("Loading dataset...")

    # augmentation settings forwarded verbatim to the data generators
    augmentation_args = {
        'rotation_range': args.rotation_range,
        'width_shift_range': args.width_shift_range,
        'height_shift_range': args.height_shift_range,
        'shear_range': args.shear_range,
        'zoom_range': args.zoom_range,
        'fill_mode': args.fill_mode,
        'alpha': args.alpha,
        'sigma': args.sigma,
    }

    train_generator, train_steps_per_epoch, \
        val_generator, val_steps_per_epoch = dataset.create_generators(
            args.datadir, args.batch_size,
            validation_split=args.validation_split,
            mask=args.classes,
            shuffle_train_val=args.shuffle_train_val,
            shuffle=args.shuffle,
            seed=args.seed,
            normalize_images=args.normalize,
            augment_training=args.augment_training,
            augment_validation=args.augment_validation,
            augmentation_args=augmentation_args)

    # get image dimensions from first batch
    images, masks = next(train_generator)
    _, height, width, channels = images.shape
    _, _, _, classes = masks.shape

    print("Building model...")
    string_to_model = {
        "unet": models.unet,
        "dilated-unet": models.dilated_unet,
        "dilated-densenet": models.dilated_densenet,
        "dilated-densenet2": models.dilated_densenet2,
        "dilated-densenet3": models.dilated_densenet3,
    }
    model = string_to_model[args.model]
    m = model(height=height, width=width, channels=channels, classes=classes,
              features=args.features, depth=args.depth, padding=args.padding,
              temperature=args.temperature, batchnorm=args.batchnorm,
              dropout=args.dropout)

    # evaluation requires previously trained weights
    m.load_weights(args.load_weights)

    print("Training Set:")
    train_dice, train_jaccard, train_images = compute_statistics(
        m, train_generator, train_steps_per_epoch,
        return_images=args.checkpoint)
    print()
    print("Validation Set:")
    val_dice, val_jaccard, val_images = compute_statistics(
        m, val_generator, val_steps_per_epoch,
        return_images=args.checkpoint)

    # save per-image [dice, jaccard] rows as text files
    if args.outfile:
        train_data = np.asarray([train_dice, train_jaccard]).T
        val_data = np.asarray([val_dice, val_jaccard]).T
        np.savetxt(args.outfile + ".train", train_data)
        np.savetxt(args.outfile + ".val", val_data)

    # save an overlay image per example, named with its index and dice score
    if args.checkpoint:
        print("Saving images...")
        for i, dice in enumerate(train_dice):
            image, mask_true, mask_pred = train_images[i]
            figname = "train-{:03d}-{:.3f}.png".format(i, dice)
            save_image(figname, image, mask_true, np.round(mask_pred))
        for i, dice in enumerate(val_dice):
            image, mask_true, mask_pred = val_images[i]
            figname = "val-{:03d}-{:.3f}.png".format(i, dice)
            save_image(figname, image, mask_true, np.round(mask_pred))
def train():
    """Train a U-Net segmentation model on generators built from
    ``args.datadir``.

    BUGFIX: an unrecognized ``args.loss`` previously fell through the
    loss-selection ``if``/``elif`` chain, leaving ``lossfunc`` (and the
    checkpoint ``filepath``/``monitor``/``mode``) undefined and crashing
    later with a NameError; it now raises a clear exception up front,
    consistent with the other training script.
    """
    logging.basicConfig(level=logging.INFO)
    args = opts.parse_arguments()

    logging.info("Loading dataset...")
    augmentation_args = {
        'rotation_range': args.rotation_range,
        'width_shift_range': args.width_shift_range,
        'height_shift_range': args.height_shift_range,
        'shear_range': args.shear_range,
        'zoom_range': args.zoom_range,
        'fill_mode': args.fill_mode,
        'alpha': args.alpha,
        'sigma': args.sigma,
    }
    train_generator, train_steps_per_epoch, \
        val_generator, val_steps_per_epoch = dataset.create_generators(
            args.datadir, args.batch_size,
            validation_split=args.validation_split,
            mask=args.classes,
            shuffle_train_val=args.shuffle_train_val,
            shuffle=args.shuffle,
            seed=args.seed,
            normalize_images=args.normalize,
            augment_training=args.augment_training,
            augment_validation=args.augment_validation,
            augmentation_args=augmentation_args)

    # get image dimensions from first batch
    images, masks = next(train_generator)
    _, height, width, channels = images.shape
    _, _, _, classes = masks.shape

    logging.info("Building model...")
    string_to_model = {
        "unet": models.unet,
    }
    model = string_to_model[args.model]
    m = model(height=height, width=width, channels=channels, classes=classes,
              features=args.features, depth=args.depth, padding=args.padding,
              temperature=args.temperature, batchnorm=args.batchnorm,
              dropout=args.dropout)
    m.summary()

    if args.load_weights:
        logging.info("Loading saved weights from file: {}".format(
            args.load_weights))
        m.load_weights(args.load_weights)

    # instantiate optimizer, and only keep args that have been set
    # (not all optimizers have args like `momentum' or `decay')
    optimizer_args = {
        'lr': args.learning_rate,
        'momentum': args.momentum,
        'decay': args.decay
    }
    for k in list(optimizer_args):
        if optimizer_args[k] is None:
            del optimizer_args[k]
    optimizer = select_optimizer(args.optimizer, optimizer_args)

    # select loss function: pixel-wise crossentropy or soft dice
    if args.loss == 'pixel':
        def lossfunc(y_true, y_pred):
            return loss.weighted_categorical_crossentropy(
                y_true, y_pred, args.loss_weights)
    elif args.loss == 'dice':
        def lossfunc(y_true, y_pred):
            return loss.sorensen_dice_loss(y_true, y_pred, args.loss_weights)
    else:
        # BUGFIX: fail fast on an unknown loss instead of hitting a
        # NameError when compiling the model below
        raise Exception("Unknown loss ({})".format(args.loss))

    def dice(y_true, y_pred):
        batch_dice_coefs = loss.sorensen_dice(y_true, y_pred, axis=[1, 2])
        dice_coefs = K.mean(batch_dice_coefs, axis=0)
        return dice_coefs[1]  # HACK for 2-class case

    metrics = ['accuracy', dice]

    m.compile(optimizer=optimizer, loss=lossfunc, metrics=metrics)

    # automatic saving of the best model during training; the monitored
    # quantity depends on the chosen loss
    if args.checkpoint:
        if args.loss == 'pixel':
            filepath = os.path.join(
                args.outdir, "weights-{epoch:02d}-{val_acc:.4f}.hdf5")
            monitor = 'val_acc'
            mode = 'max'
        elif args.loss == 'dice':
            filepath = os.path.join(
                args.outdir, "weights-{epoch:02d}-{val_dice:.4f}.hdf5")
            monitor = 'val_dice'
            mode = 'max'
        checkpoint = ModelCheckpoint(
            filepath, monitor=monitor, verbose=1, save_best_only=True,
            mode=mode)
        callbacks = [checkpoint]
    else:
        callbacks = []

    # train
    logging.info("Begin training.")
    m.fit_generator(train_generator,
                    epochs=args.epochs,
                    steps_per_epoch=train_steps_per_epoch,
                    validation_data=val_generator,
                    validation_steps=val_steps_per_epoch,
                    callbacks=callbacks,
                    verbose=2)

    m.save(os.path.join(args.outdir, args.outfile))