# Example #1
# 0
    def eval_step(self,
                  num_classes,
                  cm,
                  trained_model,
                  aug_strategy,
                  multi_as_binary,
                  multi_class,
                  which_epoch,
                  which_slice,
                  images_gif,
                  gif_dir,
                  pred,
                  target,
                  sample_pred,
                  sample_y,
                  visual_file,
                  name,
                  which_volume,
                  dataset):
        """Run one evaluation pass over ``dataset`` with ``trained_model``.

        For every batch it predicts, updates the confusion matrix and
        multi-class visualisation, and logs IoU/Dice scalars to TensorBoard.
        When the batch index reaches ``which_volume`` it additionally refreshes
        the GIF slice/volume visuals and accumulates the per-volume
        prediction/label arrays.

        NOTE(review): ``eval_metric_writer`` (used below) is neither a
        parameter nor visibly defined in this scope — presumably an attribute
        or enclosing-scope variable; verify before use.  The ``pred``
        parameter is overwritten on the first loop iteration, so its incoming
        value is never read.

        Args:
            num_classes: number of segmentation classes for the confusion matrix.
            cm: running confusion matrix, updated in place each step.
            trained_model: model whose ``predict`` is evaluated per batch.
            aug_strategy: augmentation strategy forwarded to the GIF helpers.
            multi_as_binary: forwarded to ``update_gif_slice``.
            multi_class: if True, use the multi-class IoU/Dice variants.
            which_epoch, which_slice: epoch/slice selectors for GIF updates.
            images_gif: accumulated GIF frames, replaced by the update helpers.
            gif_dir: output directory for GIF frames.
            pred: placeholder; overwritten by the per-batch prediction.
            target: expected number of slices per volume (passed to npy saver).
            sample_pred, sample_y: per-volume prediction/label accumulators.
            visual_file, name: identifiers forwarded to ``update_volume_npy``.
            which_volume: 1-based batch index at which visuals are refreshed.
            dataset: iterable of ``(x, label)`` batches.
        """

        for step, (x, label) in enumerate(dataset):
            print('step ', step)
            pred = trained_model.predict(x)

            # Update visuals
            cm = update_cm(cm, num_classes)
            visualise_multi_class(label, pred)

            # Refresh GIF/volume visuals only for the selected volume
            # (step is 0-based, which_volume appears to be 1-based).
            if step+1 == which_volume:
                update_gif_slice(x, label, trained_model,
                                 aug_strategy,
                                 multi_as_binary, multi_class,
                                 which_epoch, which_slice)

                images_gif = update_volume_comp_gif(x, label, images_gif, trained_model,
                                                    multi_class,
                                                    which_epoch,
                                                    gif_dir=gif_dir)

                images_gif = update_epoch_gif(x, trained_model, aug_strategy,
                                              multi_class, which_slice, 
                                              gif_dir=gif_dir)
            
                sample_pred, sample_y = update_volume_npy(label, pred, target, 
                                                          sample_pred, sample_y, 
                                                          visual_file, name, 
                                                          which_volume, multi_class)
                    
            # Multi-class and binary segmentation use different metric variants.
            iou = iou_loss_eval(label, pred) if multi_class else iou_loss(label, pred)
            dice = dice_coef_eval(label, pred) if multi_class else dice_coef(label, pred)

            with eval_metric_writer.as_default():
                tf.summary.scalar('iou eval validation', iou, step=step)
                tf.summary.scalar('dice eval validation', dice, step=step)
# Example #2
# 0
def main(argv):
    """Train and validate a 2D segmentation model selected by FLAGS.

    Builds the model named by ``FLAGS.model_architecture``, trains it with a
    custom loop (Tversky loss, Adam + LR schedule), validates each epoch,
    writes metric summaries to TensorBoard and checkpoints weights per epoch.

    Raises:
        ValueError: if ``FLAGS.model_architecture`` is not supported.
    """

    del argv  # unused

    summary_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.logdir, 'summaries'))

    # select model architecture
    if FLAGS.model_architecture == 'unet':
        model = UNet(FLAGS.num_filters, FLAGS.num_classes, FLAGS.num_conv,
                     FLAGS.kernel_size, FLAGS.activation, FLAGS.batchnorm,
                     FLAGS.dropout_rate, FLAGS.use_spatial,
                     FLAGS.channel_order)

    elif FLAGS.model_architecture == 'multires_unet':
        model = MultiResUnet(FLAGS.num_filters,
                             FLAGS.num_classes,
                             FLAGS.res_path_length,
                             FLAGS.num_conv,
                             FLAGS.kernel_size,
                             use_bias=False,
                             padding='same',
                             activation=FLAGS.activation,
                             use_batchnorm=FLAGS.batchnorm,
                             use_transpose=True,
                             data_format=FLAGS.channel_order)

    elif FLAGS.model_architecture == 'attention_unet_v1':
        model = AttentionUNet_v1(FLAGS.num_filters,
                                 FLAGS.num_classes,
                                 FLAGS.num_conv,
                                 FLAGS.kernel_size,
                                 use_bias=False,
                                 padding='same',
                                 nonlinearity=FLAGS.activation,
                                 use_batchnorm=FLAGS.batchnorm,
                                 use_transpose=True,
                                 data_format=FLAGS.channel_order)

    else:
        # Fail fast: previously this only printed, and execution continued
        # until `model` raised UnboundLocalError further down.
        raise ValueError("%s is not a valid or supported model architecture." %
                         FLAGS.model_architecture)

    lr_schedule = LearningRateSchedule(19200 / FLAGS.batch_size,
                                       FLAGS.base_learning_rate, 0.5, 1)
    optimiser = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

    metrics = {
        'train/tversky_loss': tf.keras.metrics.Mean(),
        'train/accuracy': tf.keras.metrics.CategoricalAccuracy(),
        'train/dice_coef': tf.keras.metrics.Mean(),
        'train/cce': tf.keras.metrics.CategoricalCrossentropy(),
        'valid/tversky_loss': tf.keras.metrics.Mean(),
        'valid/accuracy': tf.keras.metrics.CategoricalAccuracy(),
        'valid/dice_coef': tf.keras.metrics.Mean(),
        'valid/cce': tf.keras.metrics.CategoricalCrossentropy()
    }

    train_ds = dataset_generator('./Data/train_2d/',
                                 batch_size=FLAGS.batch_size,
                                 shuffle=True)
    valid_ds = dataset_generator('./Data/valid_2d/',
                                 batch_size=FLAGS.batch_size,
                                 shuffle=False)

    @tf.function
    def train_step(images, labels):
        """One optimization step; updates the train/* metrics."""
        with tf.GradientTape() as tape:
            predictions = model(images, training=True)
            tversky = tf.reduce_mean(tversky_loss(labels, predictions))
        # dice is only a metric: compute it outside the tape so the tape
        # does not record ops we never differentiate.
        dice = tf.reduce_mean(dice_coef(labels, predictions))
        gradients = tape.gradient(tversky, model.trainable_variables)
        optimiser.apply_gradients(zip(gradients, model.trainable_variables))

        metrics['train/tversky_loss'].update_state(tversky)
        metrics['train/dice_coef'].update_state(dice)
        metrics['train/cce'].update_state(labels, predictions)
        metrics['train/accuracy'].update_state(labels, predictions)

    @tf.function
    def test_step(images, labels):
        """One forward-only validation step; updates the valid/* metrics."""
        predictions = model(images, training=False)
        tversky = tf.reduce_mean(tversky_loss(labels, predictions))
        dice = tf.reduce_mean(dice_coef(labels, predictions))

        metrics['valid/tversky_loss'].update_state(tversky)
        metrics['valid/dice_coef'].update_state(dice)
        metrics['valid/cce'].update_state(labels, predictions)
        metrics['valid/accuracy'].update_state(labels, predictions)

    for epoch in range(FLAGS.train_epochs):

        start = time.process_time()
        for step, (images, labels) in enumerate(train_ds):
            # Single-class training collapses the label channels;
            # multi-class expands them to one-hot.
            if FLAGS.num_classes != 1:
                labels = get_multiclass(labels)
            else:
                labels = np.sum(labels, axis=3)

            train_step(images, labels)
            template = 'Epoch {}, Step {}, Elapsed Time: {:3f}, Tversky Loss: {:.3f}, Dice Coefficient:{:.3f}, Categorical CE: {:.3f}, Accuracy: {:.2f}'
            print(
                template.format(epoch + 1, step + 1,
                                time.process_time() - start,
                                metrics['train/tversky_loss'].result(),
                                metrics['train/dice_coef'].result(),
                                metrics['train/cce'].result(),
                                metrics['train/accuracy'].result() * 100))

            # 19200 = number of training samples; stop after one epoch's worth
            # of steps in case the dataset repeats.
            if step == (19200 / FLAGS.batch_size):
                break

        for valid_step, (valid_images, valid_labels) in enumerate(valid_ds):
            if FLAGS.num_classes != 1:
                valid_labels = get_multiclass(valid_labels)
            else:
                valid_labels = np.sum(valid_labels, axis=3)

            test_step(valid_images, valid_labels)
            # Periodically visualise a validation batch.
            if (valid_step + 1) % 100 == 0:
                pred = model(valid_images, training=False)
                visualise_multi_class(valid_labels, pred)

            # 4480 = number of validation samples.
            if valid_step == (4480 / FLAGS.batch_size):
                break

        valid_template = 'Validation results: Epoch {}, Tversky Loss: {:.3f}, Dice Coefficient:{:.3f}, Categorical CE: {:.3f}, Accuracy: {:.2f}'
        print(
            valid_template.format(epoch + 1,
                                  metrics['valid/tversky_loss'].result(),
                                  metrics['valid/dice_coef'].result(),
                                  metrics['valid/cce'].result(),
                                  metrics['valid/accuracy'].result() * 100))

        total_results = {
            name: metric.result()
            for name, metric in metrics.items()
        }
        with summary_writer.as_default():
            for name, result in total_results.items():
                tf.summary.scalar(name, result, step=epoch + 1)

        # Metrics accumulate across an epoch; reset before the next one.
        for metric in metrics.values():
            metric.reset_states()

        model.save_weights(FLAGS.logdir + '/' + FLAGS.model_architecture +
                           str(epoch + 1) + '.ckpt')
def eval_loop(dataset,
              validation_steps,
              aug_strategy,
              bucket_name,
              logdir,
              tpu_name,
              visual_file,
              weights_dir,
              fig_dir,
              which_volume,
              which_epoch,
              which_slice,
              multi_as_binary,
              trained_model,
              model_architecture,
              callbacks,
              num_classes=7):
    """Evaluate a model over every checkpoint in a bucket and visualize.

    For each checkpoint found under the bucket/log directory, loads the
    weights, runs the full dataset, updates the confusion matrix, GIF and
    per-volume visuals, and logs IoU/Dice to TensorBoard.  ``evaluate`` with
    ``callbacks`` is additionally run on the final checkpoint.

    Args:
        dataset: iterable of ``(x, label)`` evaluation batches.
        validation_steps: steps passed to ``trained_model.evaluate``.
        aug_strategy: augmentation strategy forwarded to GIF helpers.
        bucket_name, logdir, tpu_name, visual_file, weights_dir: locate the
            checkpoints in GCS.
        fig_dir: directory where the confusion-matrix figure is saved.
        which_volume: 1-based batch index at which visuals are refreshed.
        which_epoch, which_slice: epoch/slice selectors for GIF updates.
        multi_as_binary: forwarded to ``update_gif_slice``.
        trained_model: compiled model whose weights are (re)loaded per epoch.
        model_architecture: label used when saving the confusion matrix.
        callbacks: callbacks for the final-epoch ``evaluate`` call.
        num_classes: number of segmentation classes (default 7).
    """

    multi_class = num_classes > 1
    gif_dir = ''

    # load the checkpoints in the specified log directory
    session_weights = get_bucket_weights(bucket_name, logdir, tpu_name,
                                         visual_file, weights_dir)
    last_epoch = len(session_weights)

    # Derive the summary-writer directory: strip the trailing path component
    # (and its separator) from weights_dir when present.  The fallback keeps
    # writer_dir bound even if the guard ever fails.
    f = weights_dir.split('/')[-1]
    writer_dir = weights_dir
    if weights_dir.endswith(f):
        writer_dir = weights_dir[:-(len(f) + 1)]
    writer_dir = os.path.join(writer_dir, 'eval')
    eval_metric_writer = tf.summary.create_file_writer(writer_dir)

    # Init visuals
    cm, classes = initialize_cm(multi_class, num_classes)
    fig, axes, images_gif = initialize_gif()
    target = 160  # how many slices in 1 vol
    sample_pred = []  # prediction for current 160,288,288 vol
    sample_y = []  # y for current 160,288,288 vol

    for chkpt in session_weights:
        # Checkpoint names look like '<arch>_weights.<epoch:03d>.ckpt.index';
        # recover the basename and the epoch component.
        name = chkpt.split('/')[-1]
        name = name.split('.inde')[0]
        epoch = name.split('.')[1]

        #########################
        print(
            "\n\n\n\n+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++"
        )
        print(f"\t\tLoading weights from {epoch} epoch")
        print(f"\t\t  {name}")
        print(
            "+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
        )
        #########################

        trained_model.load_weights(
            'gs://' +
            os.path.join(bucket_name, weights_dir, tpu_name, visual_file, name)
        ).expect_partial()
        # BUG FIX: `epoch` is a string (e.g. '003') while last_epoch is an
        # int, so the original `epoch == last_epoch` was never True and the
        # final evaluate() never ran.  Compare numerically.
        if int(epoch) == last_epoch:
            trained_model.evaluate(dataset,
                                   steps=validation_steps,
                                   callbacks=callbacks)

        # Re-initialize the per-volume accumulators for this checkpoint.
        sample_pred = []  # prediction for current 160,288,288 vol
        sample_y = []  # y for current 160,288,288 vol

        for step, (x, label) in enumerate(dataset):
            print('step', step)
            pred = trained_model.predict(x)

            # Update visuals
            cm = update_cm(cm, num_classes)
            visualise_multi_class(label, pred)

            if step + 1 == which_volume:
                update_gif_slice(x, label, trained_model, aug_strategy,
                                 multi_as_binary, multi_class, which_epoch,
                                 which_slice)

                images_gif = update_volume_comp_gif(x,
                                                    label,
                                                    images_gif,
                                                    trained_model,
                                                    multi_class,
                                                    which_epoch,
                                                    gif_dir=gif_dir)

                images_gif = update_epoch_gif(x,
                                              trained_model,
                                              aug_strategy,
                                              multi_class,
                                              which_slice,
                                              gif_dir=gif_dir)

                sample_pred, sample_y = update_volume_npy(
                    label, pred, target, sample_pred, sample_y, visual_file,
                    name, which_volume, multi_class)

            # Multi-class and binary segmentation use different metric variants.
            iou = iou_loss_eval(label, pred) if multi_class else iou_loss(
                label, pred)
            dice = dice_coef_eval(label, pred) if multi_class else dice_coef(
                label, pred)

            with eval_metric_writer.as_default():
                tf.summary.scalar('iou eval validation', iou, step=step)
                tf.summary.scalar('dice eval validation', dice, step=step)

        # Save visuals.  NOTE(review): cm is never reset between checkpoints,
        # so each saved matrix accumulates all previous epochs — confirm this
        # is intended.
        save_cm(cm, model_architecture, fig_dir, classes)
    pred_evolution_gif(fig,
                       images_gif,
                       save_dir=gif_dir,
                       save=True,
                       no_margins=False)
# Example #4
# 0
def main(argv):
    """Train or evaluate a 2D segmentation model via Keras ``compile``/``fit``.

    Builds the model named by ``FLAGS.model_architecture``, compiles it with
    dice loss (binary) or Tversky loss (multi-class), then either trains with
    ``fit_generator`` and saves a timestamped checkpoint, or loads the latest
    checkpoint and previews predictions on one validation batch.

    Raises:
        ValueError: if ``FLAGS.model_architecture`` is not supported.
    """

    del argv  # unused arg

    # select model architecture
    if FLAGS.model_architecture == 'unet':
        model = UNet(FLAGS.num_filters, FLAGS.num_classes, FLAGS.num_conv,
                     FLAGS.kernel_size, FLAGS.activation, FLAGS.batchnorm,
                     FLAGS.dropout_rate, FLAGS.use_spatial,
                     FLAGS.channel_order)

    elif FLAGS.model_architecture == 'multires_unet':
        model = MultiResUnet(FLAGS.num_filters,
                             FLAGS.num_classes,
                             FLAGS.res_path_length,
                             FLAGS.num_conv,
                             FLAGS.kernel_size,
                             use_bias=False,
                             padding='same',
                             activation=FLAGS.activation,
                             use_batchnorm=FLAGS.batchnorm,
                             use_transpose=True,
                             data_format=FLAGS.channel_order)

    elif FLAGS.model_architecture == 'attention_unet_v1':
        model = AttentionUNet_v1(FLAGS.num_filters,
                                 FLAGS.num_classes,
                                 FLAGS.num_conv,
                                 FLAGS.kernel_size,
                                 use_bias=False,
                                 padding='same',
                                 nonlinearity=FLAGS.activation,
                                 use_batchnorm=FLAGS.batchnorm,
                                 use_transpose=True,
                                 data_format=FLAGS.channel_order)

    else:
        # Fail fast: previously this only printed, and execution continued
        # until `model` raised UnboundLocalError further down.
        raise ValueError("%s is not a valid or supported model architecture." %
                         FLAGS.model_architecture)

    optimiser = tf.keras.optimizers.Adam(
        learning_rate=FLAGS.base_learning_rate)

    # NOTE(review): generator_train/generator_valid are only bound when
    # FLAGS.dataset == 'oai_challenge'; any other dataset value raises
    # NameError at fit time — confirm that is acceptable upstream.
    if FLAGS.num_classes == 1:
        if FLAGS.dataset == 'oai_challenge':
            generator_train = DataGenerator("./Data/train_2d/samples/",
                                            "./Data/train_2d/labels/",
                                            batch_size=FLAGS.batch_size,
                                            shuffle=True,
                                            multi_class=False)
            generator_valid = DataGenerator("./Data/valid_2d/samples/",
                                            "./Data/valid_2d/labels/",
                                            batch_size=FLAGS.batch_size,
                                            shuffle=True,
                                            multi_class=False)
        model.compile(optimizer=optimiser,
                      loss=dice_loss,
                      metrics=[dice_coef_loss, 'binary_crossentropy', 'acc'])

    else:
        if FLAGS.dataset == 'oai_challenge':
            generator_train = DataGenerator("./Data/train_2d/samples/",
                                            "./Data/train_2d/labels/",
                                            batch_size=FLAGS.batch_size,
                                            shuffle=True,
                                            multi_class=True)
            generator_valid = DataGenerator("./Data/valid_2d/samples/",
                                            "./Data/valid_2d/labels/",
                                            batch_size=FLAGS.batch_size,
                                            shuffle=True,
                                            multi_class=True)

        model.compile(optimizer=optimiser,
                      loss=tversky_loss,
                      metrics=['categorical_crossentropy', 'acc'])

    #Note that fit_generator will be deprecated in future Tensorflow version.
    #Use model.fit instead but ensure that your Tensorflow version is >= 2.1.0 or else it won't work with tf.keras.utils.Sequence object

    if FLAGS.train:
        history = model.fit_generator(generator=generator_train,
                                      epochs=FLAGS.train_epochs,
                                      validation_data=generator_valid,
                                      use_multiprocessing=True,
                                      workers=8,
                                      max_queue_size=16)

        # Save a timestamped checkpoint so repeated runs do not overwrite.
        t = time.localtime()
        current_time = time.strftime("%H%M%S", t)
        model_path = FLAGS.model_architecture + '_' + current_time + '.ckpt'
        save_path = os.path.join(FLAGS.logdir, model_path)
        model.save_weights(save_path)

        plot_train_history_loss(history, multi_class=FLAGS.num_classes != 1)

    else:
        #load the latest checkpoint in the FLAGS.logdir file
        latest = tf.train.latest_checkpoint(FLAGS.logdir)
        model.load_weights(latest).expect_partial()

        #this is just to roughly preview the results, we need to build a proper pipeline for visualising & saving output segmentation
        # Idiomatic indexing instead of calling __getitem__ directly.
        x_val, y_val = generator_valid[100]

        y_pred = model.predict(x_val)
        visualise_multi_class(y_val, y_pred)
# Example #5
# 0
def main(argv):
    """Train or evaluate a segmentation model on TPU/GPU from FLAGS.

    Sets up the accelerator strategy, datasets, loss, learning-rate schedule
    and optimizer, builds and compiles the model inside the strategy scope,
    then either trains (optionally saving weights, flags and TensorBoard
    logs) or evaluates a saved checkpoint and visualises one batch.

    Raises:
        ValueError: if ``FLAGS.loss`` is not a supported loss name.
    """

    if FLAGS.visual_file:
        assert FLAGS.train is False, "Train must be set to False if you are doing a visual."
    del argv  # unused arg

    if FLAGS.seed is not None:
        logging.info('Setting seed {}'.format(FLAGS.seed))
        tf.random.set_seed(FLAGS.seed)  # set seed

    # set whether to train on GPU or TPU
    strategy = setup_accelerator(use_gpu=FLAGS.use_gpu,
                                 num_cores=FLAGS.num_cores,
                                 device_name=FLAGS.tpu)

    # Global batch size is per-core batch size times number of cores.
    batch_size = (FLAGS.batch_size * FLAGS.num_cores)

    # set dataset configuration
    train_ds, validation_ds = load_dataset(
        batch_size=batch_size,
        dataset_dir=FLAGS.tfrec_dir,
        augmentation=FLAGS.aug_strategy,
        use_2d=FLAGS.use_2d,
        multi_class=FLAGS.multi_class,
        crop_size=FLAGS.crop_size,
        buffer_size=FLAGS.buffer_size,
        use_bfloat16=FLAGS.use_bfloat16,
        use_RGB=FLAGS.backbone_architecture != 'default')

    num_classes = 7 if FLAGS.multi_class else 1
    # 19200 train / 4480 validation samples in the dataset.
    steps_per_epoch = 19200 // batch_size
    validation_steps = 4480 // batch_size

    if FLAGS.loss == 'tversky':
        loss_fn = tversky_loss
    elif FLAGS.loss == 'dice':
        if FLAGS.multi_class:
            loss_fn = multi_class_dice_coef_loss
        else:
            loss_fn = dice_coef_loss
    elif FLAGS.loss == 'focal_tversky':
        # NOTE(review): falls back to plain tversky_loss — confirm a focal
        # variant is not intended here.
        loss_fn = tversky_loss
    else:
        # Fail fast: previously an unknown loss left loss_fn unbound and
        # crashed later at compile() with UnboundLocalError.
        raise ValueError('{} is not a supported loss'.format(FLAGS.loss))

    if FLAGS.use_bfloat16:
        policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16')
        tf.keras.mixed_precision.experimental.set_policy(policy)

    # set model architecture
    model_fn, model_args = select_model(FLAGS, num_classes)

    with strategy.scope():
        model = model_fn(*model_args)

        if FLAGS.custom_decay_lr:
            lr_decay_epochs = FLAGS.lr_decay_epochs
        else:
            # Decay every epoch after warmup by default.
            lr_decay_epochs = list(
                range(FLAGS.lr_warmup_epochs + 1, FLAGS.train_epochs))

        lr_rate = LearningRateSchedule(steps_per_epoch,
                                       FLAGS.base_learning_rate,
                                       FLAGS.min_learning_rate,
                                       FLAGS.lr_drop_ratio, lr_decay_epochs,
                                       FLAGS.lr_warmup_epochs)

        if FLAGS.optimizer == 'adam':
            optimizer = tf.keras.optimizers.Adam(learning_rate=lr_rate)
        elif FLAGS.optimizer == 'rmsprop':
            optimizer = tf.keras.optimizers.RMSprop(learning_rate=lr_rate)
        elif FLAGS.optimizer == 'sgd':
            optimizer = tf.keras.optimizers.SGD(learning_rate=lr_rate)
        elif FLAGS.optimizer == 'adamw':
            optimizer = tfa.optimizers.AdamW(weight_decay=1e-04,
                                             learning_rate=lr_rate)

        if FLAGS.train:
            # 'default' backbone takes single-channel input, others RGB.
            if FLAGS.use_2d:
                if FLAGS.backbone_architecture == 'default':
                    model.build((None, FLAGS.crop_size, FLAGS.crop_size, 1))
                else:
                    model.build((None, FLAGS.crop_size, FLAGS.crop_size, 3))
            else:
                model.build((None, FLAGS.depth_crop_size, FLAGS.crop_size,
                             FLAGS.crop_size, 1))
            model.summary()

        if FLAGS.multi_class:
            # One dice metric per class plus the aggregate eval metric.
            dice_metrics = [DiceMetrics(idx=idx) for idx in range(num_classes)]
            metrics = [dice_metrics, dice_coef_eval]
        else:
            metrics = [dice_coef]

        model.compile(optimizer=optimizer, loss=loss_fn, metrics=metrics)

    # FLAGS.train will be outside train()
    if FLAGS.train:
        callbacks = []

        # get the timestamp for saved flags and checkpoints
        # (named `timestamp`, not `time`, to avoid shadowing the time module)
        timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")

        # define checkpoints
        logdir = os.path.join(FLAGS.logdir, FLAGS.tpu)
        logdir = os.path.join(logdir, timestamp)

        if FLAGS.save_weights:

            logdir_arch = os.path.join(logdir, FLAGS.model)
            ckpt_cb = tf.keras.callbacks.ModelCheckpoint(
                logdir_arch + '_weights.{epoch:03d}.ckpt',
                save_best_only=False,
                save_weights_only=True)
            logging.info(
                'Saving weights into the following directory: {}'.format(
                    logdir_arch))
            callbacks.append(ckpt_cb)

            # save flags settings to a directory
            training_history_dir = os.path.join(FLAGS.fig_dir, FLAGS.tpu)
            training_history_dir = os.path.join(training_history_dir,
                                                timestamp)
            Path(training_history_dir).mkdir(parents=True, exist_ok=True)
            local_flag_name = os.path.join(training_history_dir,
                                           'train_flags.cfg')
            flag_name = os.path.join(logdir, 'train_flags.cfg')
            logging.info(
                'Saving flags into the following directory: {}'.format(
                    local_flag_name))
            FLAGS.append_flags_into_file(local_flag_name)
            if FLAGS.use_cloud:
                from Segmentation.utils.cloud_utils import upload_blob
                upload_blob(FLAGS.bucket, local_flag_name, flag_name)

        if FLAGS.save_tb:
            logging.info(
                'Saving training logs in Tensorboard save at the following directory: {}'
                .format(logdir))
            tb = tf.keras.callbacks.TensorBoard(logdir, update_freq='epoch')
            callbacks.append(tb)

        model.fit(train_ds,
                  steps_per_epoch=steps_per_epoch,
                  epochs=FLAGS.train_epochs,
                  validation_data=validation_ds,
                  validation_steps=validation_steps,
                  callbacks=callbacks,
                  verbose=1)

    else:
        logging.info('Evaluating {}...'.format(FLAGS.model))
        model.load_weights(FLAGS.weights_dir).expect_partial()
        model.evaluate(validation_ds, steps=validation_steps)

        # Visualise a single batch (step 80) then stop: the original loop
        # kept iterating — forever if the dataset repeats.
        for step, (x, y_true) in enumerate(validation_ds):
            if step == 80:
                y_pred = model(x, training=False)
                visualise_multi_class(y_true,
                                      y_pred,
                                      savefig='sample_output.png')
                break