def create_ops(net, mode='train'):
    is_training = (mode == 'train')

    ops = {}

    ops['x'] = tf.placeholder(tf.float32, shape=[None, 64, 64, 64, num_channels], name='x_placeholder')
    ops['y'] = tf.placeholder(tf.int32, shape=[None, 64, 64, 64], name='y_placeholder')

    ops['net'] = net(ops['x'], is_training=is_training)

    ops['y_'] = ops['net']['y_']
    ops['y_prob'] = ops['net']['y_prob']

    # Define and add losses
    ops['loss_dice'] = modules.dice_loss(ops['net']['logits'], ops['y'], num_classes, include_background=True, smooth=0.,
                                         name='{}/loss_dice'.format(mode), collections=['losses', '{}'.format(mode)])
    ops['loss_ce'] = modules.sparse_crossentropy(ops['net']['logits'], ops['y'],
                                         name='{}/loss_ce'.format(mode), collections=['losses', '{}'.format(mode)])
    ops['loss_balce'] = modules.sparse_balanced_crossentropy(ops['net']['logits'], ops['y'],
                                                 name='{}/loss_balce'.format(mode),
                                                 collections=['losses', '{}'.format(mode)])
    ops['loss_all'] = ops['loss_dice']

    if is_training:
        # Add image summaries for x, y, y_
        modules.image_summary(ops['x'], 'train_img', ['{}'.format(mode)])
        modules.image_summary(tf.expand_dims(tf.to_float(ops['y_']) / num_classes, axis=-1), 'train_pred',
                              ['{}'.format(mode)])
        modules.image_summary(tf.expand_dims(tf.to_float(ops['y']) / num_classes, axis=-1), 'train_lbl',
                              ['{}'.format(mode)])

        # Add scalar summaries of the loss
        modules.scalar_summary(ops['loss_all'], '{}/loss_all'.format(mode), collections=['losses', '{}'.format(mode)])

        output_hist = tf.summary.histogram('out/soft', ops['y_prob'])

        # Merge all tf summaries from the mode collection and histograms of all MOVING_AVERAGE_VARIABLES (i.e. batch norm)
        ops['summaries'] = tf.summary.merge(
            [tf.summary.merge_all('{}'.format(mode)), ] + [tf.summary.histogram(var.name, var) for var in
                                                           net.get_variables(tf.GraphKeys.MOVING_AVERAGE_VARIABLES)]
            + [output_hist, ])

        # Optionally create a learning rate placeholder for scheduling and choose an optimiser
        # ops['lr'] = tf.placeholder(tf.float32)
        ops['global_step'] = tf.Variable(0, name='global_step', trainable=False)
        ops['optimiser'] = tf.train.AdamOptimizer(0.001, epsilon=1e-5).minimize(ops['loss_all'],
                                                                               global_step=ops['global_step'])

    return ops
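
# A hedged usage sketch (not part of the original example): create_ops assumes
# module-level `num_classes`/`num_channels` and a callable DLTK-style network
# such as the ResUNET constructed in Example #3. The helper below is purely
# illustrative (its name and arguments are hypothetical) and runs a single
# optimisation step on an in-memory batch.
def _example_run_one_training_step(net, x_batch, y_batch):
    # x_batch: [N, 64, 64, 64, num_channels] floats; y_batch: [N, 64, 64, 64] ints
    train_ops = create_ops(net, mode='train')
    with tf.Session() as s:
        s.run(tf.global_variables_initializer())
        _, loss = s.run([train_ops['optimiser'], train_ops['loss_all']],
                        feed_dict={train_ops['x']: x_batch,
                                   train_ops['y']: y_batch})
    return loss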
Example #2
def validate(ops, session, supervisor, name, v_all=True):
    """
        Run an inference on a validation dataset

        Parameters
        ----------
        ops : dict
            a dictionary containing all validation ops

        session : tf.session object

        supervisor : tf.Supervisor object

        Returns
        -------
    """

    # Pick num_validation_examples datasets to validate on
    if v_all:
        num_validation_examples = len(ops['filenames'])
    else:
        num_validation_examples = 4

    val_idx = range(num_validation_examples)

    # Track loss and Dice similarity coefficients as validation metrics
    val_loss = []
    val_dscs = []
#    val_orig_dscs = []

    # Iterate through the datasets and perform a sliding window inference
    for f in ops['filenames'][val_idx]:

        # Read a validation image and label of arbitrary dimensions
        val_x, val_y = ops['read_func']([f])

#        pid = os.path.basename(f[-1]).split('_')[0]
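        # Derive a short subject ID from the file path (this assumes the path
        # contains a 'p/' segment immediately followed by a two-character ID)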
        pid = 'Subj.' + f[0].split('p/')[1][:2]

        y_prob = sliding_window_segmentation_inference(session, [ops['y_prob']], {ops['x']: val_x}, batch_size=16)[0]

        y_ = np.argmax(y_prob, axis=-1)

        # Compute the performance metrics for the dataset
        dscs = metrics.dice(y_, val_y, num_classes)
        loss = metrics.crossentropy(y_prob, np.eye(num_classes)[val_y.astype(int)], logits=False)
#        print(pid + '; CE= {:.6f}; DSC: l1 = {:.3f}'.format(loss, dscs[1]))
        print(pid + '; CE= {:.6f}; DSC: LVR = {:.3f}, SPL = {:.3f}, RKDN = {:.3f}, LKDN = {:.3f}'.format(loss, dscs[1], dscs[2], dscs[3], dscs[4]))
        # Collect the metrics over the validation data
        val_loss = val_loss + [loss]
        val_dscs = val_dscs + [dscs]

    np.save(args.save_dscs, val_dscs) 
    mean_dscs = np.mean(val_dscs, axis=0)
    mean_loss = np.mean(val_loss, axis=0)

    print('Mean; CE= {:.6f}; DSC: LVR = {:.3f}, SPL = {:.3f}, RKDN = {:.3f}, LKDN = {:.3f}'.format(
        mean_loss, mean_dscs[1], mean_dscs[2], mean_dscs[3], mean_dscs[4]))

    # Add the last computed dataset as an numpy image summary
    img_summaries = [modules.image_summary(val_x[0], name + '_img'),
                     modules.image_summary(val_y[0, :, :, :, np.newaxis] / num_classes, name + '_lbl'),
                     modules.image_summary(y_[0, :, :, :, np.newaxis] / num_classes, name + '_pred')]

    metrics_summaries = [modules.scalar_summary(mean_loss, name + '/ce'),
                         modules.scalar_summary(mean_dscs.mean(), name + '/dsc'),
                         modules.scalar_summary(mean_dscs[1:].mean(), name + '/dsc_wo_bg'),
                        ] + [modules.scalar_summary(mean_dscs[i + 1], name + '/dsc_lbl{}'.format(i + 1))
                             for i in range(num_classes - 1)]

    val_summaries = img_summaries + metrics_summaries
    return val_summaries
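
# A hedged sketch of a possible call site (an assumption, not from the original
# source): the ops dict is expected to carry the placeholders from create_ops
# plus 'filenames' and a 'read_func', and the returned summaries are passed to
# the supervisor, mirroring the pattern in Example #3. The helper name and
# arguments below are hypothetical.
def _example_run_validation(session, supervisor, val_ops, step):
    summaries = validate(val_ops, session, supervisor, name='val', v_all=True)
    for summary in summaries:
        supervisor.summary_computed(session, summary, global_step=step)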
Example #3
def train(args):
    """
        Complete training and validation script. Additionally, saves inference model, trained weights and summaries. 

        Parameters
        ----------
        args : argparse.Namespace
            contains all necessary command line arguments

        Returns
        -------
    """

    if not args.resume:
        os.system("rm -rf %s" % args.save_path)
        os.system("mkdir -p %s" % args.save_path)
    else:
        print('Resuming training')

    num_classes = 9
    num_channels = 3
    batch_size = 4

    g = tf.Graph()
    with g.as_default():

        # Set a seed
        np.random.seed(1337)
        tf.set_random_seed(1337)

        # Build the network graph
        net = ResUNET(num_classes,
                      num_residual_units=3,
                      filters=[32, 64, 128, 256],
                      strides=[[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2]])

        # Parse the csv files
        print('Loading training file names from %s' % args.train_csv)
        train_files = pd.read_csv(args.train_csv, dtype=object).values

        # I/O ops for training and validation via a custom mrbrains_reader
        x_train, y_train = reader.MRBrainsReader(
            [tf.float32, tf.int32], [[24, 64, 64, 3], [24, 64, 64]],
            name='train_queue')(train_files,
                                batch_size=batch_size,
                                n_examples=18,
                                min_queue_examples=batch_size * 2,
                                capacity=batch_size * 4)

        # Training metrics and optimisation ops
        train_net = net(x_train)
        train_logits_ = train_net['logits']
        train_pred_ = train_net['y_']
        train_truth_ = y_train

        # Add image summaries and a summary op
        modules.image_summary(x_train, 'train_img', ['training'])
        modules.image_summary(
            tf.expand_dims(tf.to_float(train_pred_) / num_classes, axis=-1),
            'train_pred', ['training'])
        modules.image_summary(
            tf.expand_dims(tf.to_float(y_train) / num_classes, axis=-1),
            'train_lbl', ['training'])

        train_summaries = tf.summary.merge([
            tf.summary.merge_all('training'),
        ] + [
            tf.summary.histogram(var.name, var)
            for var in net.get_variables(tf.GraphKeys.MOVING_AVERAGE_VARIABLES)
        ])

        # Add crossentropy loss and regularisation
        ce = modules.sparse_crossentropy(train_logits_,
                                         train_truth_,
                                         name='train/loss',
                                         collections=['losses', 'training'])
        l1 = modules.l1_regularization(
            net.get_variables(tf.GraphKeys.BIASES),
            1e-5,
            name='train/l1',
            collections=['training', 'regularization'])
        l2 = modules.l2_regularization(
            net.get_variables(tf.GraphKeys.WEIGHTS),
            1e-4,
            name='train/l2',
            collections=['training', 'regularization'])
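        # Note: only the l2 penalty is added to the total loss below; the l1
        # term above is computed but not included in the optimised objective.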

        train_loss_ = ce + l2

        # Create a learning rate placeholder for scheduling, a global step for
        # checkpointing/resuming, and choose an optimiser
        lr_placeholder = tf.placeholder(tf.float32)
        global_step = tf.Variable(0, name='global_step', trainable=False)
        train_op_ = tf.train.MomentumOptimizer(lr_placeholder, 0.9).minimize(
            train_loss_, global_step=global_step)

        # Set up ops for validation
        if args.run_validation:
            print('Loading validation file names from %s' % args.val_csv)
            val_files = pd.read_csv(args.val_csv, dtype=str).values
            val_reader = reader.MRBrainsReader(
                [tf.float32, tf.int32], [[48, 240, 240, 3], [48, 240, 240]],
                name='val_queue')
            val_read_func = lambda x: val_reader._read_sample(
                x, is_training=False)

            # Reuse the training model for validation inference and replace inputs with placeholders
            x_placeholder = tf.placeholder(
                tf.float32,
                shape=[None, None, None, None, num_channels],
                name='x_placeholder')
            y_placeholder = tf.placeholder(tf.int32,
                                           shape=[None, None, None, None],
                                           name='y_placeholder')

            val_net = net(x_placeholder, is_training=False)
            val_logits_ = val_net['logits']
            val_pred_ = val_net['y_']
            val_truth_ = y_placeholder

            val_loss_ = modules.sparse_crossentropy(
                val_logits_, val_truth_, collections=['losses', 'validation'])

        # Define and set up a training supervisor, handling queues and logging for tensorboard
        sv = tf.train.Supervisor(logdir=args.save_path,
                                 is_chief=True,
                                 summary_op=None,
                                 save_summaries_secs=tps.save_summary_sec,
                                 save_model_secs=tps.save_model_sec,
                                 global_step=global_step)
        s = sv.prepare_or_wait_for_session(config=tf.ConfigProto())

        # Main training loop
        step = s.run(global_step) if args.resume else 0
        while not sv.should_stop():

            # Each step run the training op with a learning rate schedule
            lr = 0.001 if step < 40000 else 0.0001
            _ = s.run(train_op_, feed_dict={lr_placeholder: lr})

            # Evaluation of training and validation data
            if step % tps.steps_eval == 0:
                (train_loss, train_pred, train_truth, t_sum) = s.run(
                    [train_loss_, train_pred_, train_truth_, train_summaries])
                dscs, avds = eval_mrbrains_metrics(train_pred, train_truth)

                # Build custom metric summaries and add them to tensorboard
                sv.summary_computed(s, t_sum, global_step=step)
                sv.summary_computed(
                    s,
                    modules.scalar_summary(
                        {
                            n: val
                            for n, val in zip(['CSF', 'GM', 'WM'], dscs[1:])
                        }, 'train/dsc'),
                    global_step=step)
                sv.summary_computed(
                    s,
                    modules.scalar_summary(
                        {
                            n: val
                            for n, val in zip(['CSF', 'GM', 'WM'], avds[1:])
                        }, 'train/avd'),
                    global_step=step)

                print("\nEval step= {:d}".format(step))
                print(
                    "Train: Loss= {:.6f}; DSC= {:.4f} {:.4f} {:.4f}, AVD= {:.4f} {:.4f} {:.4f} "
                    .format(train_loss, dscs[1], dscs[2], dscs[3], avds[1],
                            avds[2], avds[3]))

                # Run inference on all validation data (here, just one dataset) and compute mean performance metrics
                if args.run_validation:
                    all_loss = []
                    all_dscs = []
                    all_avds = []
                    for f in val_files:
                        val_x, val_y = val_read_func([f])
                        val_x = val_x[np.newaxis, :]
                        val_y = val_y[np.newaxis, :]
                        (val_loss, val_pred,
                         val_truth) = s.run([val_loss_, val_pred_, val_truth_],
                                            feed_dict={
                                                x_placeholder: val_x,
                                                y_placeholder: val_y
                                            })

                        dscs, avds = eval_mrbrains_metrics(val_pred, val_truth)

                        all_loss.append(val_loss)
                        all_dscs.append(dscs)
                        all_avds.append(avds)

                    mean_loss = np.mean(all_loss, axis=0)
                    mean_dscs = np.mean(all_dscs, axis=0)
                    mean_avds = np.mean(all_avds, axis=0)

                    # Add them to tensorboard as image and metrics summaries
                    sv.summary_computed(s,
                                        modules.image_summary(
                                            val_x[0], 'val_img'),
                                        global_step=step)
                    sv.summary_computed(s,
                                        modules.image_summary(
                                            val_y[0, :, :, :, np.newaxis] /
                                            num_classes, 'val_lbl'),
                                        global_step=step)
                    sv.summary_computed(s,
                                        modules.image_summary(
                                            val_pred[0, :, :, :, np.newaxis] /
                                            num_classes, 'val_pred'),
                                        global_step=step)

                    sv.summary_computed(s,
                                        modules.scalar_summary(
                                            mean_loss, 'val/loss'),
                                        global_step=step)
                    sv.summary_computed(
                        s,
                        modules.scalar_summary(
                            {
                                n: val
                                for n, val in zip(['CSF', 'GM', 'WM'],
                                                  mean_dscs[1:])
                            }, 'val/dsc'),
                        global_step=step)
                    sv.summary_computed(
                        s,
                        modules.scalar_summary(
                            {
                                n: val
                                for n, val in zip(['CSF', 'GM', 'WM'],
                                                  mean_avds[1:])
                            }, 'val/avd'),
                        global_step=step)

                    print(
                        "Valid: Loss= {:.6f}; DSC= {:.4f} {:.4f} {:.4f}, AVD= {:.4f} {:.4f} {:.4f} "
                        .format(mean_loss, mean_dscs[1], mean_dscs[2],
                                mean_dscs[3], mean_avds[1], mean_avds[2],
                                mean_avds[3]))

            # Stopping condition
            if step >= tps.max_steps and tps.max_steps > 0:
                print('Ran %d of %d steps - stopping now' %
                      (step, tps.max_steps))
                break

            step += 1
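
# A hedged sketch of a possible command-line entry point (an assumption; the
# original argument parsing is not shown). The flag names mirror the `args.*`
# attributes referenced in the examples above; the defaults are illustrative.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='MRBrains tissue segmentation training')
    parser.add_argument('--train_csv', default='train.csv')
    parser.add_argument('--val_csv', default='val.csv')
    parser.add_argument('--save_path', default='/tmp/mrbrains_segmentation')
    parser.add_argument('--save_dscs', default='val_dscs.npy')
    parser.add_argument('--resume', action='store_true')
    parser.add_argument('--run_validation', action='store_true')

    train(parser.parse_args())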