def main(fs_exp_config, slices, test):
    """Predict selected slices with the fully-supervised model and save them coloured.

    Restores the latest ``model_best_dice.ckpt`` from ``fs_model_path``, runs
    ``model.predict`` on the requested slices and writes one coloured prediction
    per slice to ``OUTPUT_FOLDER`` via ``print_coloured``.

    Args:
        fs_exp_config: experiment config providing ``data_mode``, ``image_size``,
            ``target_resolution``, ``model_handle`` and ``nlabels``.
        slices: indices of the slices to predict (a list or an integer NumPy
            array). Indices not smaller than the size of the chosen split are
            dropped.
        test: if truthy, use the test split and the ``'test'`` filename prefix;
            otherwise the training split and the ``'train'`` prefix.

    Raises:
        ValueError: if none of the requested slices exist in the chosen split.
    """
    # Load data
    data = load_and_maybe_process_data(
        input_folder=sys_config.data_root,
        preprocessing_folder=sys_config.preproc_folder,
        mode=fs_exp_config.data_mode,
        size=fs_exp_config.image_size,
        target_resolution=fs_exp_config.target_resolution,
        force_overwrite=False
    )

    # Accept plain lists as well as arrays; the comparison below needs an array.
    slices = np.asarray(slices)
    try:
        if test:
            split = data['images_test']
            prefix = 'test'
        else:
            split = data['images_train']
            prefix = 'train'

        # Drop indices that lie beyond the end of the chosen split.
        slices = slices[slices < len(split)]
        if slices.size == 0:
            raise ValueError(
                'None of the requested slices exist in the {} split'.format(prefix))
        images = split[slices, ...]
    finally:
        data.close()

    # The placeholder has a fixed batch dimension, so the batch size must be the
    # number of slices left *after* filtering (it used to be taken before).
    batch_size = len(slices)
    image_tensor_shape = [batch_size] + list(fs_exp_config.image_size) + [1]
    images_pl = tf.placeholder(tf.float32, shape=image_tensor_shape, name='images')
    feed_dict = {
        images_pl: np.expand_dims(images, -1),
    }

    # Get full supervision prediction; only the mask is used below.
    mask_pl, _ = model.predict(images_pl, fs_exp_config.model_handle, fs_exp_config.nlabels)
    saver = tf.train.Saver()
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        checkpoint_path = utils.get_latest_model_checkpoint_path(fs_model_path,
                                                                 'model_best_dice.ckpt')
        saver.restore(sess, checkpoint_path)
        fs_predictions = sess.run(mask_pl, feed_dict=feed_dict)

    for slice_index, prediction in zip(slices, fs_predictions):
        print_coloured(prediction, filepath=OUTPUT_FOLDER,
                       filename='{}{}_fs_pred'.format(prefix, slice_index))
Example #2
0
def run_training(continue_run):
    """Train the segmentation model configured by the module-level ``exp_config``.

    Builds a TF1 graph around ``model.inference``, ``model.loss``,
    ``model.training_step`` and ``model.evaluation``, then runs minibatch
    training for ``exp_config.max_epochs`` epochs:

    * every 10 steps the merged summaries are written to ``log_dir``;
    * every ``exp_config.train_eval_frequency`` steps the training set is
      evaluated, which may lower the learning rate when ``exp_config.schedule_lr``
      is set;
    * every ``exp_config.val_eval_frequency`` steps a ``model.ckpt`` checkpoint
      is saved. Unless training on all data, the validation split is also
      evaluated, and the best-Dice and best-loss checkpoints are updated.

    Args:
        continue_run: if truthy, restore the latest ``model.ckpt`` from
            ``log_dir`` and resume at the step after it. If no checkpoint can be
            loaded, a warning is logged and training starts from step 0.
    """
    logging.info('EXPERIMENT NAME: %s' % exp_config.experiment_name)

    init_step = 0
    init_checkpoint_path = None

    if continue_run:
        logging.info(
            '!!!!!!!!!!!!!!!!!!!!!!!!!!!! Continuing previous run !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
        )
        try:
            init_checkpoint_path = utils.get_latest_model_checkpoint_path(
                log_dir, 'model.ckpt')
            logging.info('Checkpoint path: %s' % init_checkpoint_path)
            # saver.save(..., global_step=step) names checkpoints '...-<step>'.
            # Resume one step later, otherwise the run starts with an evaluation.
            init_step = int(
                init_checkpoint_path.split('/')[-1].split('-')[-1]) + 1
            logging.info('Latest step was: %d' % init_step)
        except Exception as err:
            # Was a bare ``except:``. That also swallowed KeyboardInterrupt and
            # SystemExit, and it discarded the reason for the failure.
            logging.warning(
                '!!! Didnt find init checkpoint. Maybe first run failed. Disabling continue mode...'
            )
            logging.warning('Reason: %s' % err)
            continue_run = False
            init_step = 0

        logging.info(
            '!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
        )

    train_on_all_data = getattr(exp_config, 'train_on_all_data', False)

    # Load data
    data = acdc_data.load_and_maybe_process_data(
        input_folder=sys_config.data_root,
        preprocessing_folder=sys_config.preproc_folder,
        mode=exp_config.data_mode,
        size=exp_config.image_size,
        target_resolution=exp_config.target_resolution,
        force_overwrite=False,
        split_test_train=(not train_on_all_data))

    # Make sure the data file is closed even if training raises.
    try:
        # the following are HDF5 datasets, not numpy arrays
        images_train = data['images_train']
        labels_train = data['masks_train']

        if not train_on_all_data:
            images_val = data['images_test']
            labels_val = data['masks_test']

        if exp_config.use_data_fraction:
            num_images = images_train.shape[0]
            new_last_index = int(float(num_images) * exp_config.use_data_fraction)

            logging.warning('USING ONLY FRACTION OF DATA!')
            logging.warning(' - Number of imgs orig: %d, Number of imgs new: %d' %
                            (num_images, new_last_index))
            images_train = images_train[0:new_last_index, ...]
            labels_train = labels_train[0:new_last_index, ...]

        logging.info('Data summary:')
        logging.info(' - Images:')
        logging.info(images_train.shape)
        logging.info(images_train.dtype)
        logging.info(' - Labels:')
        logging.info(labels_train.shape)
        logging.info(labels_train.dtype)

        # Build the model into a fresh default Graph.
        with tf.Graph().as_default():

            # Placeholders with a fixed batch dimension of exp_config.batch_size.
            image_tensor_shape = [exp_config.batch_size] + list(
                exp_config.image_size) + [1]
            mask_tensor_shape = [exp_config.batch_size] + list(
                exp_config.image_size)

            images_pl = tf.placeholder(tf.float32,
                                       shape=image_tensor_shape,
                                       name='images')
            labels_pl = tf.placeholder(tf.uint8,
                                       shape=mask_tensor_shape,
                                       name='labels')

            learning_rate_pl = tf.placeholder(tf.float32, shape=[])
            training_pl = tf.placeholder(tf.bool, shape=[])

            tf.summary.scalar('learning_rate', learning_rate_pl)

            # Forward pass.
            logits = model.inference(images_pl, exp_config, training=training_pl)

            # Loss; the second (unused) output is the unregularised loss.
            [loss, _,
             weights_norm] = model.loss(logits,
                                        labels_pl,
                                        nlabels=exp_config.nlabels,
                                        loss_type=exp_config.loss_type,
                                        weight_decay=exp_config.weight_decay)

            tf.summary.scalar('loss', loss)
            tf.summary.scalar('weights_norm_term', weights_norm)

            # Ops that calculate and apply gradients.
            if exp_config.momentum is not None:
                train_op = model.training_step(loss,
                                               exp_config.optimizer_handle,
                                               learning_rate_pl,
                                               momentum=exp_config.momentum)
            else:
                train_op = model.training_step(loss, exp_config.optimizer_handle,
                                               learning_rate_pl)

            # Op returning (loss, dice) for evaluation.
            eval_loss = model.evaluation(logits,
                                         labels_pl,
                                         images_pl,
                                         nlabels=exp_config.nlabels,
                                         loss_type=exp_config.loss_type)

            # Merges the summaries defined so far (lr, loss, weights norm).
            summary = tf.summary.merge_all()

            init = tf.global_variables_initializer()

            # Keep every periodic checkpoint when there is no validation split.
            max_to_keep = None if train_on_all_data else 5

            saver = tf.train.Saver(max_to_keep=max_to_keep)
            saver_best_dice = tf.train.Saver()
            saver_best_xent = tf.train.Saver()

            config = tf.ConfigProto()
            config.gpu_options.allow_growth = True  # Do not assign whole gpu memory, just use it on the go
            config.allow_soft_placement = True  # If an op cannot run on its assigned device, let it execute on another.

            # The context manager closes the session even if training raises.
            with tf.Session(config=config) as sess:

                summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

                val_error_ = tf.placeholder(tf.float32, shape=[], name='val_error')
                val_error_summary = tf.summary.scalar('validation_loss', val_error_)

                val_dice_ = tf.placeholder(tf.float32, shape=[], name='val_dice')
                val_dice_summary = tf.summary.scalar('validation_dice', val_dice_)

                val_summary = tf.summary.merge([val_error_summary, val_dice_summary])

                train_error_ = tf.placeholder(tf.float32, shape=[], name='train_error')
                train_error_summary = tf.summary.scalar('training_loss', train_error_)

                train_dice_ = tf.placeholder(tf.float32, shape=[], name='train_dice')
                train_dice_summary = tf.summary.scalar('training_dice', train_dice_)

                train_summary = tf.summary.merge(
                    [train_error_summary, train_dice_summary])

                sess.run(init)

                if continue_run:
                    saver.restore(sess, init_checkpoint_path)

                step = init_step
                curr_lr = exp_config.learning_rate

                # Consecutive training evaluations without a decrease in loss.
                no_improvement_counter = 0
                best_val = np.inf
                last_train = np.inf
                loss_history = []
                loss_gradient = np.inf
                best_dice = 0

                for epoch in range(exp_config.max_epochs):

                    logging.info('EPOCH %d' % epoch)

                    # Wrapping iterate_minibatches in a BackgroundGenerator can
                    # speed this loop up. However, an exception inside the loop may
                    # then not be caught: the generator may keep running silently
                    # after the code has crashed.
                    for batch in iterate_minibatches(
                            images_train,
                            labels_train,
                            batch_size=exp_config.batch_size,
                            augment_batch=exp_config.augment_batch):

                        if exp_config.warmup_training:
                            if step < 50:
                                curr_lr = exp_config.learning_rate / 10.0
                            elif step == 50:
                                curr_lr = exp_config.learning_rate

                        start_time = time.time()

                        x, y = batch

                        # TEMPORARY HACK: skip incomplete batches, because the
                        # placeholders have a fixed batch dimension.
                        if y.shape[0] < exp_config.batch_size:
                            step += 1
                            continue

                        feed_dict = {
                            images_pl: x,
                            labels_pl: y,
                            learning_rate_pl: curr_lr,
                            training_pl: True
                        }

                        _, loss_value = sess.run([train_op, loss], feed_dict=feed_dict)

                        duration = time.time() - start_time

                        # Write the summaries and print an overview fairly often.
                        if step % 10 == 0:
                            logging.info('Step %d: loss = %.2f (%.3f sec)' %
                                         (step, loss_value, duration))
                            summary_str = sess.run(summary, feed_dict=feed_dict)
                            summary_writer.add_summary(summary_str, step)
                            summary_writer.flush()

                        if (step + 1) % exp_config.train_eval_frequency == 0:

                            logging.info('Training Data Eval:')
                            [train_loss,
                             train_dice] = do_eval(sess, eval_loss, images_pl,
                                                   labels_pl, training_pl,
                                                   images_train, labels_train,
                                                   exp_config.batch_size)

                            train_summary_msg = sess.run(train_summary,
                                                         feed_dict={
                                                             train_error_: train_loss,
                                                             train_dice_: train_dice
                                                         })
                            summary_writer.add_summary(train_summary_msg, step)

                            # Keep the last five training losses. Once the window
                            # is full, the gradient is half the difference between
                            # the oldest and the newest loss.
                            loss_history.append(train_loss)
                            if len(loss_history) > 5:
                                loss_history.pop(0)
                                loss_gradient = (loss_history[-5] -
                                                 loss_history[-1]) / 2

                            logging.info('loss gradient is currently %f' %
                                         loss_gradient)

                            if exp_config.schedule_lr and loss_gradient < exp_config.schedule_gradient_threshold:
                                logging.warning('Reducing learning rate!')
                                curr_lr /= 10.0
                                logging.info('Learning rate changed to: %f' % curr_lr)

                                # reset loss history to give the optimisation some time to start decreasing again
                                loss_gradient = np.inf
                                loss_history = []

                            if train_loss <= last_train:  # best_train:
                                logging.info('Decrease in training error!')
                                no_improvement_counter = 0
                            else:
                                # This counter was never updated before, so the
                                # message below always reported 0.
                                no_improvement_counter += 1
                                logging.info(
                                    'No improvement in training error for %d training evaluations'
                                    % no_improvement_counter)

                            last_train = train_loss

                        # Save a checkpoint and evaluate the model periodically.
                        if (step + 1) % exp_config.val_eval_frequency == 0:

                            checkpoint_file = os.path.join(log_dir, 'model.ckpt')
                            saver.save(sess, checkpoint_file, global_step=step)

                            if not train_on_all_data:

                                # Evaluate against the validation set.
                                logging.info('Validation Data Eval:')
                                [val_loss,
                                 val_dice] = do_eval(sess, eval_loss, images_pl,
                                                     labels_pl, training_pl,
                                                     images_val, labels_val,
                                                     exp_config.batch_size)

                                val_summary_msg = sess.run(val_summary,
                                                           feed_dict={
                                                               val_error_: val_loss,
                                                               val_dice_: val_dice
                                                           })
                                summary_writer.add_summary(val_summary_msg, step)

                                if val_dice > best_dice:
                                    best_dice = val_dice
                                    best_file = os.path.join(log_dir,
                                                             'model_best_dice.ckpt')
                                    saver_best_dice.save(sess,
                                                         best_file,
                                                         global_step=step)
                                    logging.info(
                                        'Found new best dice on validation set! - %f -  Saving model_best_dice.ckpt'
                                        % val_dice)

                                if val_loss < best_val:
                                    best_val = val_loss
                                    best_file = os.path.join(log_dir,
                                                             'model_best_xent.ckpt')
                                    saver_best_xent.save(sess,
                                                         best_file,
                                                         global_step=step)
                                    logging.info(
                                        'Found new best crossentropy on validation set! - %f -  Saving model_best_xent.ckpt'
                                        % val_loss)

                        step += 1

                summary_writer.close()
    finally:
        data.close()
Example #3
0
def main(exp_config):
    """Plot test-set predictions with a constant background channel.

    Restores the latest ``recursion_1_model_best_dice.ckpt`` from ``model_path``.
    For test images 10..19 (clipped to the size of the test split) it shows:

    * the input, the ground-truth mask and the predicted mask (top row);
    * every softmax channel (bottom row), with channel 0 replaced by the
      constant 0.95.

    Args:
        exp_config: provides ``data_mode``, ``image_size``,
            ``target_resolution``, ``model_handle`` and ``nlabels``.
    """
    # Load data
    data = acdc_data.load_and_maybe_process_data(
        input_folder=sys_config.data_root,
        preprocessing_folder=sys_config.preproc_folder,
        mode=exp_config.data_mode,
        size=exp_config.image_size,
        target_resolution=exp_config.target_resolution,
        force_overwrite=False)

    try:
        batch_size = 1

        image_tensor_shape = [batch_size] + list(exp_config.image_size) + [1]
        images_pl = tf.placeholder(tf.float32,
                                   shape=image_tensor_shape,
                                   name='images')
        training_time_placeholder = tf.placeholder(tf.bool, shape=[])
        logits = model.inference(images_pl,
                                 exp_config.model_handle,
                                 training=training_time_placeholder,
                                 nlabels=exp_config.nlabels)
        softmax_pl = tf.nn.softmax(logits)
        threshold = tf.constant(0.95, dtype=tf.float32)
        # Constant map replacing softmax channel 0. Its shape is taken from the
        # softmax itself instead of the previously hard-coded [1, 212, 212, 1],
        # so any exp_config.image_size works.
        background = tf.multiply(tf.ones_like(softmax_pl[..., :1]), threshold)
        softmax_pl = tf.concat([background, softmax_pl[..., 1:]], axis=-1)
        # NOTE(review): the mask is the argmax of the raw logits, so it does not
        # reflect the thresholded channel above — confirm this is intended.
        mask_pl = tf.argmax(logits, axis=-1)

        saver = tf.train.Saver()
        init = tf.global_variables_initializer()

        with tf.Session() as sess:

            sess.run(init)

            checkpoint_path = utils.get_latest_model_checkpoint_path(
                model_path, 'recursion_1_model_best_dice.ckpt')
            saver.restore(sess, checkpoint_path)

            # Fixed sample range, clipped so small test sets do not IndexError.
            last_index = min(20, data['images_test'].shape[0])
            for ind in range(10, last_index):

                x = data['images_test'][ind, ...]
                y = data['masks_test'][ind, ...]

                x = image_utils.reshape_2Dimage_to_tensor(x)
                y = image_utils.reshape_2Dimage_to_tensor(y)

                feed_dict = {images_pl: x, training_time_placeholder: False}

                mask_out, softmax_out = sess.run([mask_pl, softmax_pl],
                                                 feed_dict=feed_dict)

                # One column per softmax channel (at least 3 for the top row).
                # With 5 channels this is the original 2x5 layout.
                n_channels = softmax_out.shape[-1]
                n_cols = max(3, n_channels)

                fig = plt.figure()
                fig.add_subplot(2, n_cols, 1).imshow(np.squeeze(x), cmap='gray')
                fig.add_subplot(2, n_cols, 2).imshow(np.squeeze(y))
                fig.add_subplot(2, n_cols, 3).imshow(np.squeeze(mask_out))
                for channel in range(n_channels):
                    fig.add_subplot(2, n_cols, n_cols + 1 + channel).imshow(
                        np.squeeze(softmax_out[..., channel]))

                plt.show()
                plt.close(fig)
    finally:
        data.close()
def generate_adversarial_examples(input_folder,
                                  output_path,
                                  model_path,
                                  attack,
                                  attack_args,
                                  exp_config,
                                  add_gaussian=False,
                                  only_batch=9):
    """Attack a pretrained segmentation model and record metrics and figures.

    Restores ``model_best_dice.ckpt`` from ``model_path`` and takes the first 20
    test images. For the selected batch(es) it runs the chosen ``adv_attack``
    attack, optionally adds Gaussian noise, and compares loss and Dice (from
    ``model.evaluation``) on the clean and the perturbed input.

    Outputs, all under ``eval_results/``:

    * per-example metrics as JSON and overlay figures as PDF;
    * averaged metrics in ``<attack>-results.json``.

    Args:
        input_folder: unused; data locations come from ``sys_config``.
        output_path: unused; outputs are written to ``eval_results/``.
        model_path: directory containing the checkpoint.
        attack: one of ``'fgsm'``, ``'pgd'`` or ``'spgd'``.
        attack_args: mapping of attack parameters. ``'eps'``, ``'sizes'`` and
            ``'weights'`` are required when ``add_gaussian`` is set. ``'eps'``,
            ``'step_alpha'`` and ``'epochs'`` are copied into the reports when
            present.
        exp_config: experiment config (image size, nlabels, loss type, data).
        add_gaussian: if True, the adversarial example is passed through
            ``adv_attack.add_gaussian_noise``.
        only_batch: 1-based index of the only batch to attack. The default 9 is
            the value previously hard-coded. ``None`` attacks every batch.

    Raises:
        NotImplementedError: for an unknown ``attack``.
        ValueError: if no batch was evaluated (e.g. fewer than ``only_batch``
            batches exist).
    """
    import os

    batch_size = 1

    image_tensor_shape = [batch_size] + list(exp_config.image_size) + [1]
    mask_tensor_shape = [batch_size] + list(exp_config.image_size)
    images_pl = tf.placeholder(tf.float32,
                               shape=image_tensor_shape,
                               name='images')
    labels_pl = tf.placeholder(tf.uint8,
                               shape=mask_tensor_shape,
                               name='labels')
    logits_pl = model.inference(images_pl,
                                exp_config=exp_config,
                                training=tf.constant(False, dtype=tf.bool))
    eval_loss = model.evaluation(logits_pl,
                                 labels_pl,
                                 images_pl,
                                 nlabels=exp_config.nlabels,
                                 loss_type=exp_config.loss_type)
    # Build the prediction op once. It used to be created inside the loop, which
    # added new nodes to the graph on every iteration.
    pred_mask_op = tf.argmax(tf.nn.softmax(logits_pl), axis=-1)

    data = acdc_data.load_and_maybe_process_data(
        input_folder=sys_config.data_root,
        preprocessing_folder=sys_config.preproc_folder,
        mode=exp_config.data_mode,
        size=exp_config.image_size,
        target_resolution=exp_config.target_resolution,
        force_overwrite=False,
        split_test_train=True)
    try:
        images = data['images_test'][:20]
        labels = data['masks_test'][:20]
        print("Num images train {} test {}".format(len(data['images_train']),
                                                   len(images)))
    finally:
        data.close()

    saver = tf.train.Saver()
    init = tf.global_variables_initializer()

    os.makedirs('eval_results', exist_ok=True)

    reported_args = {
        k: attack_args[k]
        for k in ('eps', 'step_alpha', 'epochs') if k in attack_args
    }

    def save_figure(path, image, mask=None):
        """Save *image* in grey, optionally overlaid with *mask*, as a PDF."""
        plt.imshow(np.squeeze(image), cmap='gray')
        if mask is not None:
            plt.imshow(np.squeeze(mask), cmap='viridis', alpha=0.7)
        plt.axis('off')
        plt.tight_layout()
        plt.savefig(path, format='pdf')
        plt.clf()

    baseline_closs = 0.0
    baseline_cdice = 0.0
    attack_closs = 0.0
    attack_cdice = 0.0
    l2_diff_sum = 0.0
    ln_diff_sum = 0.0
    batches = 0
    # Number of (clean, adversarial) pairs actually evaluated; used for averages.
    num_evaluated = 0

    with tf.Session() as sess:
        sess.run(init)
        checkpoint_path = utils.get_latest_model_checkpoint_path(
            model_path, 'model_best_dice.ckpt')
        saver.restore(sess, checkpoint_path)

        for x, y in BackgroundGenerator(
                train.iterate_minibatches(images, labels, batch_size)):
            batches += 1

            if only_batch is not None and batches != only_batch:
                continue

            non_adv_mask_out = sess.run(pred_mask_op, feed_dict={images_pl: x})

            if attack == 'fgsm':
                adv_x = adv_attack.fgsm_run(x, y, images_pl, labels_pl,
                                            logits_pl, exp_config, sess,
                                            attack_args)
            elif attack == 'pgd':
                adv_x = adv_attack.pgd(x, y, images_pl, labels_pl, logits_pl,
                                       exp_config, sess, attack_args)
            elif attack == 'spgd':
                adv_x = adv_attack.pgd_conv(x, y, images_pl, labels_pl,
                                            logits_pl, exp_config, sess,
                                            **attack_args)
            else:
                raise NotImplementedError
            adv_x = [adv_x]

            if add_gaussian:
                print('adding gaussian noise')
                adv_x = adv_attack.add_gaussian_noise(
                    x,
                    adv_x[0],
                    sess,
                    eps=attack_args['eps'],
                    sizes=attack_args['sizes'],
                    weights=attack_args['weights'])

            # Clean-input metrics do not depend on the perturbation: compute once.
            closs, cdice = sess.run(eval_loss,
                                    feed_dict={
                                        images_pl: x,
                                        labels_pl: y
                                    })

            for i, adv_image in enumerate(adv_x):
                perturbation = adv_image - x
                l2_diff = float(np.average(
                    np.squeeze(np.linalg.norm(perturbation, axis=(1, 2)))))
                # Element-wise L-inf (max |delta|). np.linalg.norm(..., axis=(1, 2),
                # ord=np.inf) computes the induced matrix norm (max row sum).
                ln_diff = float(np.average(
                    np.squeeze(np.max(np.abs(perturbation), axis=(1, 2)))))

                l2_diff_sum += l2_diff
                ln_diff_sum += ln_diff

                print(l2_diff, ln_diff)

                adv_mask_out = sess.run(pred_mask_op,
                                        feed_dict={images_pl: adv_image})

                adv_closs, adv_cdice = sess.run(eval_loss,
                                                feed_dict={
                                                    images_pl: adv_image,
                                                    labels_pl: y
                                                })

                baseline_closs += float(closs)
                baseline_cdice += float(cdice)
                attack_closs += float(adv_closs)
                attack_cdice += float(adv_cdice)
                num_evaluated += 1

                partial_result = {
                    'attack': attack,
                    'attack_args': reported_args,
                    'baseline_closs': float(closs),
                    'baseline_cdice': float(cdice),
                    'attack_closs': float(adv_closs),
                    'attack_cdice': float(adv_cdice),
                    'attack_l2_diff': l2_diff,
                    'attack_ln_diff': ln_diff
                }

                tag = '{}-{}-{}-{}'.format(attack, add_gaussian, batches, i)

                # Write a real JSON object. json.dumps(str(...)) used to produce
                # a quoted Python repr instead.
                with open('eval_results/{}-metrics.json'.format(tag),
                          'w') as json_file:
                    json.dump(partial_result, json_file)

                save_figure('eval_results/ground-truth-{}.pdf'.format(tag), x, y)
                save_figure('eval_results/benign-{}.pdf'.format(tag), x,
                            non_adv_mask_out)
                save_figure('eval_results/adversarial-{}.pdf'.format(tag),
                            adv_image, adv_mask_out)
                save_figure('eval_results/adv-input-{}.pdf'.format(tag),
                            adv_image)
                save_figure('eval_results/benign-input-{}.pdf'.format(tag), x)

                print(attack_closs, attack_cdice, l2_diff, ln_diff)

    if num_evaluated == 0:
        raise ValueError(
            'No batch was evaluated (only_batch={}, batches available: {})'.format(
                only_batch, batches))

    # Average over evaluated pairs only. Dividing by all batches, as before,
    # diluted the averages with batches that had been skipped.
    print("Evaluation results")
    print("{} Attack Params {}".format(attack, attack_args))
    print("Baseline metrics: Avg loss {}, Avg DICE Score {} ".format(
        baseline_closs / num_evaluated, baseline_cdice / num_evaluated))
    print("{} Attack effectiveness: Avg loss {}, Avg DICE Score {} ".format(
        attack, attack_closs / num_evaluated, attack_cdice / num_evaluated))
    print("{} Attack visibility: Avg l2-norm diff {} Avg l-inf-norm diff {}".
          format(attack, l2_diff_sum / num_evaluated,
                 ln_diff_sum / num_evaluated))

    result_dict = {
        'attack': attack,
        'attack_args': reported_args,
        'baseline_closs_avg': baseline_closs / num_evaluated,
        'baseline_cdice_avg': baseline_cdice / num_evaluated,
        'attack_closs_avg': attack_closs / num_evaluated,
        'attack_cdice_avg': attack_cdice / num_evaluated,
        'attack_l2_diff': l2_diff_sum / num_evaluated,
        'attack_ln_diff': ln_diff_sum / num_evaluated
    }

    results = [result_dict]
    print(results)

    with open("eval_results/{}-results.json".format(attack), "w") as json_file:
        json.dump(results, json_file)
Example #5
0
def main(exp_config):
    """Show predictions on random test images until interrupted.

    Restores the latest ``model_best_dice.ckpt`` from ``model_path``, then loops
    forever. Each iteration picks a random test image, runs ``model.predict``,
    and plots:

    * the input, the ground-truth mask and the predicted mask (top row);
    * each softmax channel (bottom row).

    The loop only ends through an exception (e.g. Ctrl-C); the data file is
    closed on the way out.

    Args:
        exp_config: provides ``data_mode``, ``image_size`` and
            ``target_resolution``, and is passed to ``model.predict``.
    """
    # Load data
    data = acdc_data.load_and_maybe_process_data(
        input_folder=sys_config.data_root,
        preprocessing_folder=sys_config.preproc_folder,
        mode=exp_config.data_mode,
        size=exp_config.image_size,
        target_resolution=exp_config.target_resolution,
        force_overwrite=False
    )

    # The loop never exits normally, so data.close() placed after it was
    # unreachable. A finally clause runs it on any exit.
    try:
        batch_size = 1

        image_tensor_shape = [batch_size] + list(exp_config.image_size) + [1]
        images_pl = tf.placeholder(tf.float32, shape=image_tensor_shape, name='images')

        mask_pl, softmax_pl = model.predict(images_pl, exp_config)
        saver = tf.train.Saver()
        init = tf.global_variables_initializer()

        with tf.Session() as sess:

            sess.run(init)

            checkpoint_path = utils.get_latest_model_checkpoint_path(model_path, 'model_best_dice.ckpt')
            saver.restore(sess, checkpoint_path)

            num_test = data['images_test'].shape[0]

            while True:

                ind = np.random.randint(num_test)

                x = data['images_test'][ind, ...]
                y = data['masks_test'][ind, ...]

                x = image_utils.reshape_2Dimage_to_tensor(x)
                y = image_utils.reshape_2Dimage_to_tensor(y)

                feed_dict = {
                    images_pl: x,
                }

                mask_out, softmax_out = sess.run([mask_pl, softmax_pl], feed_dict=feed_dict)

                # One column per softmax channel (at least 3 for the top row).
                # With 4 channels this is the original 2x4 layout.
                n_channels = softmax_out.shape[-1]
                n_cols = max(3, n_channels)

                fig = plt.figure()
                fig.add_subplot(2, n_cols, 1).imshow(np.squeeze(x), cmap='gray')
                fig.add_subplot(2, n_cols, 2).imshow(np.squeeze(y))
                fig.add_subplot(2, n_cols, 3).imshow(np.squeeze(mask_out))
                for channel in range(n_channels):
                    fig.add_subplot(2, n_cols, n_cols + 1 + channel).imshow(
                        np.squeeze(softmax_out[..., channel]))

                plt.show()
                plt.close(fig)
    finally:
        data.close()
Example #6
0
def main(config):
    """Interactively visualise random test-set predictions of a trained model.

    Restores ``model_best_dice.ckpt`` from the module-level ``model_path`` and
    then loops forever: pick a random test slice, run the network on it and
    show the input, the ground-truth mask, the predicted mask and softmax
    channels 0-3 in a 2x4 grid (each ``plt.show()`` blocks until the window
    is closed). The loop only ends through an exception such as
    KeyboardInterrupt; the data file is closed in that case as well.

    Args:
        config: experiment configuration; this function reads ``data_root``,
            ``preprocessing_folder``, ``data_mode``, ``image_size``,
            ``target_resolution``, ``standardize``, ``normalize``, ``min`` and
            ``max``, and passes it on to ``model.predict``.
    """
    # Load data
    data = acdc_data.load_and_maybe_process_data(
        input_folder=config.data_root,  # alternatively: config.test_data_root
        preprocessing_folder=config.preprocessing_folder,
        mode=config.data_mode,
        size=config.image_size,
        target_resolution=config.target_resolution,
        force_overwrite=False
    )

    # `while True` below only exits via an exception (e.g. Ctrl+C), so a
    # data.close() placed after it is never reached; close it in `finally`.
    try:
        batch_size = 1

        image_tensor_shape = [batch_size] + list(config.image_size) + [1]
        images_pl = tf.placeholder(tf.float32, shape=image_tensor_shape, name='images')

        mask_pl, softmax_pl = model.predict(images_pl, config)
        saver = tf.train.Saver()
        init = tf.global_variables_initializer()

        with tf.Session() as sess:
            sess.run(init)

            checkpoint_path = utils.get_latest_model_checkpoint_path(
                model_path, 'model_best_dice.ckpt')
            saver.restore(sess, checkpoint_path)

            while True:
                # Pick one random test slice per iteration.
                ind = np.random.randint(data['images_test'].shape[0])

                x = data['images_test'][ind, ...]
                y = data['masks_test'][ind, ...]

                # NOTE(review): this loop has no effect. It iterates over the
                # rows of `x` and only rebinds the loop variable, so nothing is
                # written back (unless these helpers mutate their argument in
                # place, which is not visible here). It is left as-is because
                # the matching preprocessing loops in run_training are no-ops
                # too; applying the transforms only at inference time would
                # feed the network differently-distributed inputs than it was
                # trained on.
                for img in x:
                    if config.standardize:
                        img = image_utils.standardize_image(img)
                    if config.normalize:
                        img = cv2.normalize(img, dst=None, alpha=config.min, beta=config.max,
                                            norm_type=cv2.NORM_MINMAX)

                x = image_utils.reshape_2Dimage_to_tensor(x)
                y = image_utils.reshape_2Dimage_to_tensor(y)
                logging.info('x')
                logging.info(x.shape)
                logging.info(x.dtype)
                logging.info(x.min())
                logging.info(x.max())
                # Preview the network input on its own before predicting.
                plt.imshow(np.squeeze(x))
                plt.gray()
                plt.axis('off')
                plt.show()
                logging.info('y')
                logging.info(y.shape)
                logging.info(y.dtype)

                feed_dict = {
                    images_pl: x,
                }

                mask_out, softmax_out = sess.run([mask_pl, softmax_pl], feed_dict=feed_dict)
                logging.info('mask_out')
                logging.info(mask_out.shape)
                logging.info('softmax_out')
                logging.info(softmax_out.shape)

                # 2x4 grid: top row = input (a), ground truth (b), prediction (c),
                # slot 244 unused; bottom row = softmax channels 0-3 (d-g).
                fig = plt.figure(1)
                panels = [
                    (241, np.squeeze(x), 'a', {'cmap': 'gray'}),
                    (242, np.squeeze(y), 'b', {}),
                    (243, np.squeeze(mask_out), 'c', {}),
                ]
                panels += [
                    (245 + channel, np.squeeze(softmax_out[..., channel]), title, {})
                    for channel, title in enumerate('defg')
                ]
                for position, image, title, imshow_kwargs in panels:
                    ax = fig.add_subplot(position)
                    ax.set_axis_off()
                    ax.imshow(image, **imshow_kwargs)
                    ax.title.set_text(title)
                plt.gray()
                plt.show()

                logging.info('mask_out type')
                logging.info(mask_out.dtype)
    finally:
        data.close()
Example #7
0
def run_training(continue_run):
    """Train (or resume training) the segmentation network described by ``config``.

    Relies on module-level names: ``config``, ``log_dir``, ``acdc_data``,
    ``utils``, ``image_utils``, ``aug``, ``model``, ``model_structure``,
    ``slim``, ``iterate_minibatches`` and ``do_eval``.

    Loads the data, optionally keeps only a fraction of the training set,
    optionally appends augmented copies (``config.prob``), builds the TF1
    graph and trains for ``config.max_epochs`` epochs. After every epoch the
    training set is evaluated and ``model.ckpt-<step>`` is saved in
    ``log_dir``. Unless ``config.train_on_all_data`` is set, the held-out
    split is evaluated too, and ``model_best_dice.ckpt`` /
    ``model_best_xent.ckpt`` are saved whenever validation Dice increases /
    validation loss decreases.

    Args:
        continue_run: if True, restore the latest ``model.ckpt`` from
            ``log_dir`` and continue from the step after it. If that
            checkpoint cannot be located, training starts from step 0.

    Raises:
        ValueError: if ``config.experiment_name`` does not select a known
            architecture.
    """
    logging.info('EXPERIMENT NAME: %s' % config.experiment_name)

    init_step = 0
    init_checkpoint_path = None

    if continue_run:
        logging.info(
            '!!!!!!!!!!!!!!!!!!!!!!!!!!!! Continuing previous run !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
        )
        try:
            init_checkpoint_path = utils.get_latest_model_checkpoint_path(
                log_dir, 'model.ckpt')
            logging.info('Checkpoint path: %s' % init_checkpoint_path)
            # saver.save(..., global_step=step) below writes "model.ckpt-<step>",
            # so the step number is the suffix after the last '-'.
            init_step = int(
                init_checkpoint_path.split('/')[-1].split('-')
                [-1]) + 1  # plus 1 b/c otherwise starts with eval
            logging.info('Latest step was: %d' % init_step)
        except Exception as err:
            # Narrowed from a bare `except:` so Ctrl+C / SystemExit are not
            # swallowed here; the underlying cause is logged as well.
            logging.warning(
                '!!! Didnt find init checkpoint. Maybe first run failed. Disabling continue mode...'
            )
            logging.warning('Reason: %s' % err)
            continue_run = False
            init_step = 0

        logging.info(
            '!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
        )

    train_on_all_data = config.train_on_all_data

    # Load data
    data = acdc_data.load_and_maybe_process_data(
        input_folder=config.input_folder,
        preprocessing_folder=config.preprocessing_folder,
        mode=config.data_mode,
        size=config.image_size,
        target_resolution=config.target_resolution,
        force_overwrite=False,
        split_test_train=config.split_test_train)

    # Everything after loading runs under try/finally so the data file is
    # closed even when preprocessing, graph construction or training raises.
    try:
        # the following are HDF5 datasets, not numpy arrays
        images_train = data['images_train']
        labels_train = data['masks_train']

        if not train_on_all_data:
            images_val = data['images_test']
            labels_val = data['masks_test']

        if config.use_data_fraction:
            num_images = images_train.shape[0]
            new_last_index = int(float(num_images) * config.use_data_fraction)

            logging.warning('USING ONLY FRACTION OF DATA!')
            logging.warning(' - Number of imgs orig: %d, Number of imgs new: %d' %
                            (num_images, new_last_index))
            images_train = images_train[0:new_last_index, ...]
            labels_train = labels_train[0:new_last_index, ...]

        logging.info('Data summary:')
        logging.info(' - Images:')
        logging.info(images_train.shape)
        logging.info(images_train.dtype)
        logging.info(' - Labels:')
        logging.info(labels_train.shape)
        logging.info(labels_train.dtype)

        # NOTE(review): the two preprocessing loops below have no effect.
        # Rebinding the loop variable `img` never writes back into the arrays
        # being iterated (unless the image_utils helpers mutate their argument
        # in place, which is not visible here), so training currently sees
        # un-preprocessed images. Existing checkpoints and the inference code
        # in this file match that behaviour. A real fix must materialise the
        # transformed arrays AND apply the same transforms at inference time.
        for img in images_train:
            if config.equalize:
                img = image_utils.equalization_image(img)
            if config.clahe:
                img = image_utils.CLAHE(img)
            if config.standardize:
                img = image_utils.standardize_image(img)
            if config.normalize:
                img = image_utils.normalize_image(img)

        if not train_on_all_data:
            for img in images_val:
                if config.equalize:
                    img = image_utils.equalization_image(img)
                if config.clahe:
                    img = image_utils.CLAHE(img)
                if config.standardize:
                    img = image_utils.standardize_image(img)
                if config.normalize:
                    img = image_utils.normalize_image(img)

        if config.prob:  # if prob is not 0
            logging.info(
                'Before data_augmentation the number of training images is:')
            logging.info(images_train.shape[0])
            # Append the augmented copies to the original training set.
            image_aug, label_aug = aug.augmentation_function(
                images_train, labels_train)
            images_train = np.concatenate((images_train, image_aug))
            labels_train = np.concatenate((labels_train, label_aug))

            logging.info(
                'After data_augmentation the number of training images is:')
            logging.info(images_train.shape[0])
        else:
            logging.info('No data_augmentation. Number of training images is:')
            logging.info(images_train.shape[0])

        # Tell TensorFlow that the model will be built into the default Graph.
        with tf.Graph().as_default():

            # Generate placeholders for the images and labels.
            image_tensor_shape = [config.batch_size] + list(
                config.image_size) + [1]
            mask_tensor_shape = [config.batch_size] + list(config.image_size)

            images_pl = tf.placeholder(tf.float32,
                                       shape=image_tensor_shape,
                                       name='images')
            labels_pl = tf.placeholder(tf.uint8,
                                       shape=mask_tensor_shape,
                                       name='labels')

            learning_rate_pl = tf.placeholder(tf.float32, shape=[])
            training_pl = tf.placeholder(tf.bool, shape=[])

            tf.summary.scalar('learning_rate', learning_rate_pl)

            # Build a Graph that computes predictions from the inference model.
            if config.experiment_name in ('unet2D_valid', 'unet2D_same',
                                          'unet2D_same_mod'):
                logits = model.inference(images_pl, config, training=training_pl)
            elif config.experiment_name == 'ENet':
                # NOTE(review): is_training is hard-coded to True, so whatever
                # do_eval feeds into training_pl cannot switch ENet into
                # inference mode. Left unchanged; confirm whether intended.
                with slim.arg_scope(
                        model_structure.ENet_arg_scope(weight_decay=2e-4)):
                    logits = model_structure.ENet(
                        images_pl,
                        num_classes=config.nlabels,
                        batch_size=config.batch_size,
                        is_training=True,
                        reuse=None,
                        num_initial_blocks=1,
                        stage_two_repeat=2,
                        skip_connections=config.skip_connections)
            else:
                # Previously only a warning was logged and the function then
                # crashed on the unbound `logits`; fail with a clear error.
                logging.warning('invalid experiment_name!')
                raise ValueError(
                    'invalid experiment_name: %r' % config.experiment_name)

            logging.info('images_pl shape')
            logging.info(images_pl.shape)
            logging.info('labels_pl shape')
            logging.info(labels_pl.shape)
            logging.info('logits shape:')
            logging.info(logits.shape)

            # Add to the Graph the Ops for loss calculation.
            # The second output (unregularised loss) is not needed here.
            loss, _, weights_norm = model.loss(logits,
                                               labels_pl,
                                               nlabels=config.nlabels,
                                               loss_type=config.loss_type,
                                               weight_decay=config.weight_decay)

            # record how Total loss and weight decay change over time
            tf.summary.scalar('loss', loss)
            tf.summary.scalar('weights_norm_term', weights_norm)

            # Add to the Graph the Ops that calculate and apply gradients.
            if config.momentum is not None:
                train_op = model.training_step(loss,
                                               config.optimizer_handle,
                                               learning_rate_pl,
                                               momentum=config.momentum)
            else:
                train_op = model.training_step(loss, config.optimizer_handle,
                                               learning_rate_pl)

            # Add the Op to compare the logits to the labels during evaluation.
            # loss and dice on a minibatch
            eval_loss = model.evaluation(logits,
                                         labels_pl,
                                         images_pl,
                                         nlabels=config.nlabels,
                                         loss_type=config.loss_type)

            # Merge the summaries defined so far. This must happen before the
            # monitoring placeholders below are created, otherwise every run of
            # `summary` would also require feeding those placeholders.
            summary = tf.summary.merge_all()

            # Add the variable initializer Op.
            init = tf.global_variables_initializer()

            # Create savers for writing training checkpoints. When training on
            # all data every periodic checkpoint is kept (max_to_keep=None).
            max_to_keep = None if train_on_all_data else 5
            saver = tf.train.Saver(max_to_keep=max_to_keep)
            saver_best_dice = tf.train.Saver()
            saver_best_xent = tf.train.Saver()

            # Create a session for running Ops on the Graph.
            configP = tf.ConfigProto()
            configP.gpu_options.allow_growth = True  # Do not assign whole gpu memory, just use it on the go
            configP.allow_soft_placement = True  # If a operation is not define it the default device, let it execute in another.

            # `with` guarantees the session is closed even if training raises.
            with tf.Session(config=configP) as sess:

                # Instantiate a SummaryWriter to output summaries and the Graph.
                summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

                val_error_ = tf.placeholder(tf.float32, shape=[], name='val_error')
                val_error_summary = tf.summary.scalar('validation_loss', val_error_)

                val_dice_ = tf.placeholder(tf.float32, shape=[], name='val_dice')
                val_dice_summary = tf.summary.scalar('validation_dice', val_dice_)

                val_summary = tf.summary.merge([val_error_summary, val_dice_summary])

                train_error_ = tf.placeholder(tf.float32, shape=[], name='train_error')
                train_error_summary = tf.summary.scalar('training_loss', train_error_)

                train_dice_ = tf.placeholder(tf.float32, shape=[], name='train_dice')
                train_dice_summary = tf.summary.scalar('training_dice', train_dice_)

                train_summary = tf.summary.merge(
                    [train_error_summary, train_dice_summary])

                # Run the Op to initialize the variables.
                sess.run(init)

                if continue_run:
                    # Restore session
                    saver.restore(sess, init_checkpoint_path)

                step = init_step
                curr_lr = config.learning_rate

                no_improvement_counter = 0
                best_val = np.inf
                last_train = np.inf
                loss_history = []
                loss_gradient = np.inf
                best_dice = 0

                for epoch in range(config.max_epochs):

                    logging.info('EPOCH %d' % epoch)

                    for batch in iterate_minibatches(images_train,
                                                     labels_train,
                                                     batch_size=config.batch_size):

                        start_time = time.time()

                        x, y = batch

                        # TEMPORARY HACK (to avoid incomplete batches): the
                        # placeholders have a fixed batch dimension.
                        if y.shape[0] < config.batch_size:
                            step += 1
                            continue

                        feed_dict = {
                            images_pl: x,
                            labels_pl: y,
                            learning_rate_pl: curr_lr,
                            training_pl: True
                        }

                        _, loss_value = sess.run([train_op, loss], feed_dict=feed_dict)

                        duration = time.time() - start_time

                        # Write the summaries and print an overview fairly often.
                        if step % 10 == 0:
                            logging.info('Step %d: loss = %.2f (%.3f sec)' %
                                         (step, loss_value, duration))
                            # Update the events file.
                            summary_str = sess.run(summary, feed_dict=feed_dict)
                            summary_writer.add_summary(summary_str, step)
                            summary_writer.flush()

                        step += 1

                    # end epoch
                    logging.info('Training Data Eval:')
                    train_loss, train_dice = do_eval(sess, eval_loss, images_pl,
                                                     labels_pl, training_pl,
                                                     images_train, labels_train,
                                                     config.batch_size)

                    train_summary_msg = sess.run(train_summary,
                                                 feed_dict={
                                                     train_error_: train_loss,
                                                     train_dice_: train_dice
                                                 })
                    summary_writer.add_summary(train_summary_msg, step)

                    # Keep the last five epoch losses; once the window is full,
                    # log half the difference between the oldest and newest.
                    loss_history.append(train_loss)
                    if len(loss_history) > 5:
                        loss_history.pop(0)
                        loss_gradient = (loss_history[-5] - loss_history[-1]) / 2

                    logging.info('loss gradient is currently %f' % loss_gradient)

                    if train_loss <= last_train:  # best_train:
                        no_improvement_counter = 0
                        logging.info('Decrease in training error!')
                    else:
                        no_improvement_counter = no_improvement_counter + 1
                        logging.info('No improvment in training error for %d steps' %
                                     no_improvement_counter)

                    last_train = train_loss

                    # Save a checkpoint and evaluate the model periodically.
                    checkpoint_file = os.path.join(log_dir, 'model.ckpt')
                    saver.save(sess, checkpoint_file, global_step=step)

                    if not train_on_all_data:

                        # Evaluate against the validation set.
                        logging.info('Validation Data Eval:')
                        val_loss, val_dice = do_eval(sess, eval_loss, images_pl,
                                                     labels_pl, training_pl,
                                                     images_val, labels_val,
                                                     config.batch_size)

                        val_summary_msg = sess.run(val_summary,
                                                   feed_dict={
                                                       val_error_: val_loss,
                                                       val_dice_: val_dice
                                                   })
                        summary_writer.add_summary(val_summary_msg, step)

                        if val_dice > best_dice:
                            best_dice = val_dice
                            best_file = os.path.join(log_dir, 'model_best_dice.ckpt')
                            saver_best_dice.save(sess, best_file, global_step=step)
                            logging.info(
                                'Found new best dice on validation set! - %f -  Saving model_best_dice.ckpt'
                                % val_dice)

                        if val_loss < best_val:
                            best_val = val_loss
                            best_file = os.path.join(log_dir, 'model_best_xent.ckpt')
                            saver_best_xent.save(sess, best_file, global_step=step)
                            logging.info(
                                'Found new best crossentropy on validation set! - %f -  Saving model_best_xent.ckpt'
                                % val_loss)

                # Flush and release the events file; summaries added after the
                # last explicit flush would otherwise risk being lost.
                summary_writer.close()
    finally:
        data.close()