Example #1
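# Imports assumed by this example. The reader, modules and tps helpers, as
# well as ResUNET and eval_mrbrains_metrics, are project-local and are not
# shown on this page.
import os

import numpy as np
import pandas as pd
import tensorflow as tf
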
def train(args):
    """
        Complete training and validation script. Additionally, saves inference model, trained weights and summaries. 

        Parameters
        ----------
        args : argparse.parser object
            contains all necessary command line arguments

        Returns
        -------
    """

    if not args.resume:
        os.system("rm -rf %s" % args.save_path)
        os.system("mkdir -p %s" % args.save_path)
    else:
        print('Resuming training')

    num_classes = 9
    num_channels = 3
    batch_size = 4

    g = tf.Graph()
    with g.as_default():

        # Set a seed
        np.random.seed(1337)
        tf.set_random_seed(1337)

        # Build the network graph
        net = ResUNET(num_classes,
                      num_residual_units=3,
                      filters=[32, 64, 128, 256],
                      strides=[[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2]])

        # Parse the csv files
        print('Loading training file names from %s' % args.train_csv)
        train_files = pd.read_csv(args.train_csv, dtype=object).values

        # I/O ops for training and validation via a custom mrbrains_reader
        x_train, y_train = reader.MRBrainsReader(
            [tf.float32, tf.int32], [[24, 64, 64, 3], [24, 64, 64]],
            name='train_queue')(train_files,
                                batch_size=batch_size,
                                n_examples=18,
                                min_queue_examples=batch_size * 2,
                                capacity=batch_size * 4)

        # Training metrics and optimisation ops
        train_net = net(x_train)
        train_logits_ = train_net['logits']
        train_pred_ = train_net['y_']
        train_truth_ = y_train

        # Add image summaries and a summary op
        modules.image_summary(x_train, 'train_img', ['training'])
        modules.image_summary(
            tf.expand_dims(tf.to_float(train_pred_) / num_classes, axis=-1),
            'train_pred', ['training'])
        modules.image_summary(
            tf.expand_dims(tf.to_float(y_train) / num_classes, axis=-1),
            'train_lbl', ['training'])

        train_summaries = tf.summary.merge([
            tf.summary.merge_all('training'),
        ] + [
            tf.summary.histogram(var.name, var)
            for var in net.get_variables(tf.GraphKeys.MOVING_AVERAGE_VARIABLES)
        ])

        # Add crossentropy loss and regularisation
        ce = modules.sparse_crossentropy(train_logits_,
                                         train_truth_,
                                         name='train/loss',
                                         collections=['losses', 'training'])
        l1 = modules.l1_regularization(
            net.get_variables(tf.GraphKeys.BIASES),
            1e-5,
            name='train/l1',
            collections=['training', 'regularization'])
        l2 = modules.l2_regularization(
            net.get_variables(tf.GraphKeys.WEIGHTS),
            1e-4,
            name='train/l2',
            collections=['training', 'regularization'])

        # Note: only the l2 term enters the total loss; the l1 term above is
        # tracked via its summary collections but is not minimised.
        train_loss_ = ce + l2

        # Create a learning rate placeholder for scheduling and choose an optimiser
        lr_placeholder = tf.placeholder(tf.float32)
        train_op_ = tf.train.MomentumOptimizer(lr_placeholder,
                                               0.9).minimize(train_loss_)

        # Set up ops for validation
        if args.run_validation:
            print('Loading validation file names from %s' % args.val_csv)
            val_files = pd.read_csv(args.val_csv, dtype=str).values
            val_reader = reader.MRBrainsReader(
                [tf.float32, tf.int32], [[48, 240, 240, 3], [48, 240, 240]],
                name='val_queue')
            # Read validation samples directly (bypassing the input queue)
            # via the reader's internal _read_sample method
            val_read_func = lambda x: val_reader._read_sample(
                x, is_training=False)

            # Reuse the training model for validation inference and replace inputs with placeholders
            x_placeholder = tf.placeholder(
                tf.float32,
                shape=[None, None, None, None, num_channels],
                name='x_placeholder')
            y_placeholder = tf.placeholder(tf.int32,
                                           shape=[None, None, None, None],
                                           name='y_placeholder')

            val_net = net(x_placeholder, is_training=False)
            val_logits_ = val_net['logits']
            val_pred_ = val_net['y_']
            val_truth_ = y_placeholder

            val_loss_ = modules.sparse_crossentropy(
                val_logits_, val_truth_, collections=['losses', 'validation'])

        # Define and set up a training supervisor, handling queues and logging for TensorBoard
        global_step = tf.Variable(0, name='global_step', trainable=False)
        sv = tf.train.Supervisor(logdir=args.save_path,
                                 is_chief=True,
                                 summary_op=None,
                                 save_summaries_secs=tps.save_summary_sec,
                                 save_model_secs=tps.save_model_sec,
                                 global_step=global_step)
        s = sv.prepare_or_wait_for_session(config=tf.ConfigProto())

        # Main training loop
        step = s.run(global_step) if args.resume else 0
        while not sv.should_stop():

            # Each step, run the training op with the scheduled learning rate
            lr = 0.001 if step < 40000 else 0.0001
            _ = s.run(train_op_, feed_dict={lr_placeholder: lr})

            # Evaluation of training and validation data
            if step % tps.steps_eval == 0:
                (train_loss, train_pred, train_truth, t_sum) = s.run(
                    [train_loss_, train_pred_, train_truth_, train_summaries])
                dscs, avds = eval_mrbrains_metrics(train_pred, train_truth)

                # Build custom metric summaries and add them to TensorBoard
                sv.summary_computed(s, t_sum, global_step=step)
                sv.summary_computed(
                    s,
                    modules.scalar_summary(
                        {
                            n: val
                            for n, val in zip(['CSF', 'GM', 'WM'], dscs[1:])
                        }, 'train/dsc'),
                    global_step=step)
                sv.summary_computed(
                    s,
                    modules.scalar_summary(
                        {
                            n: val
                            for n, val in zip(['CSF', 'GM', 'WM'], avds[1:])
                        }, 'train/avd'),
                    global_step=step)

                print("\nEval step= {:d}".format(step))
                print(
                    "Train: Loss= {:.6f}; DSC= {:.4f} {:.4f} {:.4f}, AVD= {:.4f} {:.4f} {:.4f} "
                    .format(train_loss, dscs[1], dscs[2], dscs[3], avds[1],
                            avds[2], avds[3]))

                # Run inference on all validation data (here, just one dataset) and compute mean performance metrics
                if args.run_validation:
                    all_loss = []
                    all_dscs = []
                    all_avds = []
                    for f in val_files:
                        val_x, val_y = val_read_func([f])
                        val_x = val_x[np.newaxis, :]
                        val_y = val_y[np.newaxis, :]
                        (val_loss, val_pred,
                         val_truth) = s.run([val_loss_, val_pred_, val_truth_],
                                            feed_dict={
                                                x_placeholder: val_x,
                                                y_placeholder: val_y
                                            })

                        dscs, avds = eval_mrbrains_metrics(val_pred, val_truth)

                        all_loss.append(val_loss)
                        all_dscs.append(dscs)
                        all_avds.append(avds)

                    mean_loss = np.mean(all_loss, axis=0)
                    mean_dscs = np.mean(all_dscs, axis=0)
                    mean_avds = np.mean(all_avds, axis=0)

                    # Add them to TensorBoard as image and metrics summaries
                    sv.summary_computed(s,
                                        modules.image_summary(
                                            val_x[0], 'val_img'),
                                        global_step=step)
                    sv.summary_computed(s,
                                        modules.image_summary(
                                            val_y[0, :, :, :, np.newaxis] /
                                            num_classes, 'val_lbl'),
                                        global_step=step)
                    sv.summary_computed(s,
                                        modules.image_summary(
                                            val_pred[0, :, :, :, np.newaxis] /
                                            num_classes, 'val_pred'),
                                        global_step=step)

                    sv.summary_computed(s,
                                        modules.scalar_summary(
                                            mean_loss, 'val/loss'),
                                        global_step=step)
                    sv.summary_computed(
                        s,
                        modules.scalar_summary(
                            {
                                n: val
                                for n, val in zip(['CSF', 'GM', 'WM'],
                                                  mean_dscs[1:])
                            }, 'val/dsc'),
                        global_step=step)
                    sv.summary_computed(
                        s,
                        modules.scalar_summary(
                            {
                                n: val
                                for n, val in zip(['CSF', 'GM', 'WM'],
                                                  mean_avds[1:])
                            }, 'val/avd'),
                        global_step=step)

                    print(
                        "Valid: Loss= {:.6f}; DSC= {:.4f} {:.4f} {:.4f}, AVD= {:.4f} {:.4f} {:.4f} "
                        .format(mean_loss, mean_dscs[1], mean_dscs[2],
                                mean_dscs[3], mean_avds[1], mean_avds[2],
                                mean_avds[3]))

            # Stopping condition
            if step >= tps.max_steps and tps.max_steps > 0:
                print('Ran %d of %d steps - stopping now' %
                      (step, tps.max_steps))
                break

            step += 1
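
The eval_mrbrains_metrics helper used above is project-local and not shown on this page. For reference, here is a minimal sketch of what it plausibly computes: per-class Dice similarity coefficients (DSC) and absolute volume differences (AVD), the two quantities logged to the 'dsc' and 'avd' summaries above. The function body, the percent convention for AVD and the edge-case handling are assumptions, not taken from the original project.

import numpy as np

def eval_mrbrains_metrics_sketch(pred, truth, num_classes=9):
    """Hypothetical stand-in for the project-local eval_mrbrains_metrics."""
    dscs = np.zeros(num_classes)
    avds = np.zeros(num_classes)
    for c in range(num_classes):
        p = (pred == c)
        t = (truth == c)
        # Dice: 2 * |P intersect T| / (|P| + |T|)
        denom = p.sum() + t.sum()
        dscs[c] = 2.0 * np.logical_and(p, t).sum() / denom if denom > 0 else 1.0
        # AVD: absolute volume difference as a percentage of the reference volume
        avds[c] = 100.0 * abs(p.sum() - t.sum()) / t.sum() if t.sum() > 0 else 0.0
    return dscs, avds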
Example #2
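# Imports assumed by this example; reader, modules, tps and ResNet are
# project-local and are not shown on this page.
import os

import numpy as np
import tensorflow as tf
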
def train(args):
    """
        Complete training and validation script. Additionally, saves inference model, trained weights and summaries.

        Parameters
        ----------
        args : argparse.parser object
            contains all necessary command line arguments

        Returns
        -------
    """

    if not args.resume:
        os.system("rm -rf %s" % args.save_path)
        os.system("mkdir -p %s" % args.save_path)
    else:
        print('Resuming training')

    num_classes = 10
    batch_size = 128

    # Define and build the network graph
    net = ResNet(num_classes, strides=[[1, 1], [1, 1], [2, 2], [2, 2]])

    # Parse the csv files and define input ops for training and validation I/O
    print('Loading data from {}'.format(args.data_dir))
    x_train, y_train = reader.build_input(
        'cifar10', os.path.join(args.data_dir, 'data_batch*'), batch_size,
        'train')

    # Define training metrics and optimisation ops
    train_net = net(x_train)
    train_logits_ = train_net['logits']
    train_pred_ = train_net['y_']
    train_truth_ = y_train

    train_acc_ = tf.reduce_mean(
        tf.cast(
            tf.equal(tf.cast(train_truth_, tf.int32),
                     tf.cast(train_pred_, tf.int32)), tf.float32))
    modules.scalar_summary(train_acc_,
                           'train/acc',
                           collections=['losses', 'metrics'])

    ce = modules.sparse_crossentropy(train_logits_,
                                     train_truth_,
                                     name='train/loss',
                                     collections=['losses', 'training'])
    l2 = modules.l2_regularization(net.get_variables(tf.GraphKeys.WEIGHTS),
                                   0.0002,
                                   name='train/l2',
                                   collections=['training', 'regularization'])
    train_loss_ = ce + l2

    lr_placeholder = tf.placeholder(tf.float32)
    train_op_ = tf.train.MomentumOptimizer(lr_placeholder,
                                           0.9).minimize(train_loss_)

    train_summaries = tf.summary.merge([
        tf.summary.merge_all('training'),
    ] + [
        tf.summary.histogram(var.name, var)
        for var in net.get_variables(tf.GraphKeys.MOVING_AVERAGE_VARIABLES)
    ])

    if args.run_validation:
        X_test, Y_test = reader.build_input(
            'cifar10', os.path.join(args.data_dir, 'test_batch*'), 100, 'eval')
        # Define validation outputs
        val_net = net(X_test, is_training=False)
        val_logits_ = val_net['logits']
        val_pred_ = val_net['y_']
        val_truth_ = Y_test

        val_loss_ = modules.sparse_crossentropy(
            val_logits_, val_truth_, collections=['losses', 'validation'])
        val_acc_ = tf.reduce_mean(
            tf.cast(
                tf.equal(tf.cast(val_truth_, tf.int32),
                         tf.cast(val_pred_, tf.int32)), tf.float32))

    # Define and set up a training supervisor
    global_step = tf.Variable(0, name='global_step', trainable=False)
    sv = tf.train.Supervisor(logdir=args.save_path,
                             is_chief=True,
                             summary_op=None,
                             save_summaries_secs=tps.save_summary_sec,
                             save_model_secs=tps.save_model_sec,
                             global_step=global_step)

    s = sv.prepare_or_wait_for_session(config=tf.ConfigProto())

    # Main training loop
    step = s.run(global_step) if args.resume else 0
    while not sv.should_stop():

        # Piecewise-constant learning rate decay schedule
        if step < 40000:
            lr = 0.1
        elif step < 60000:
            lr = 0.01
        elif step < 80000:
            lr = 0.001
        else:
            lr = 0.0001

        # Run the training op
        _ = s.run(train_op_, feed_dict={lr_placeholder: lr})

        # Evaluation of training and validation data
        if step % tps.steps_eval == 0:
            (train_loss, train_acc, train_pred, train_truth, t_sum) = s.run([
                train_loss_, train_acc_, train_pred_, train_truth_,
                train_summaries
            ])
            sv.summary_computed(s, t_sum, global_step=step)

            print("\nEval step= {:d}".format(step))
            print("Train: Loss= {:.6f}; Acc {:.6f}".format(
                train_loss, train_acc))

            # Evaluate all validation data
            if args.run_validation:
                all_loss = []
                all_acc = []
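                # 50 batches of 100 images cover the full 10,000-image CIFAR-10 test set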
                for _ in range(50):
                    (val_loss, val_pred, val_truth, val_acc) = s.run(
                        [val_loss_, val_pred_, val_truth_, val_acc_])

                    all_loss.append(val_loss)
                    all_acc.append(val_acc)

                mean_loss = np.mean(all_loss, axis=0)
                mean_acc = np.mean(all_acc, axis=0)

                sv.summary_computed(s,
                                    modules.scalar_summary(
                                        mean_loss, 'val/loss'),
                                    global_step=step)
                sv.summary_computed(s,
                                    modules.scalar_summary(
                                        mean_acc, 'val/acc'),
                                    global_step=step)

                print("Valid: Loss= {:.6f}; Acc {:.6f}".format(
                    mean_loss, mean_acc))

        # Stopping condition
        if step >= tps.max_steps and tps.max_steps > 0:
            print('Ran %d of %d steps - stopping now' %
                  (step, tps.max_steps))
            break

        step += 1
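
Both examples read their loop hyperparameters from a module-level tps object that is imported elsewhere in the project. Below is a minimal stand-in assuming SimpleNamespace semantics; only the attribute names (save_summary_sec, save_model_sec, steps_eval, max_steps) come from the code above, while the values are illustrative placeholders.

from types import SimpleNamespace

# Hypothetical training-parameter namespace; the values are placeholders.
tps = SimpleNamespace(
    save_summary_sec=120,   # Supervisor: seconds between summary writes
    save_model_sec=600,     # Supervisor: seconds between checkpoints
    steps_eval=100,         # evaluate every steps_eval training steps
    max_steps=100000,       # stopping condition; values <= 0 disable it
)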