Example #1
0
def DEBUG_core(sess):
    """Dump augmented training images to disk for visual inspection.

    Builds an (unshuffled, augmented) HDF5 input pipeline from ``opts_train``,
    optionally resizes/crops the batches according to ``opts_train.resize``
    and ``opts_train.resize_method`` ("scale", "center_crop", anything else =
    random crop), then, for each of ``opts_test.n_samples`` iterations,
    fetches 8 batches, folds each batch into one tall image and writes the 8
    results side by side to ``<args.test_dir>/img_aug_<i>.jpg``.

    Relies on the module-level globals ``args``, ``opts``, ``opts_train`` and
    ``opts_test``.

    :param sess: TensorFlow session used to evaluate the input pipeline.
    """
    logging.info('#-----------------------------------------------#')
    logging.info('#                Start Debugging                #')
    logging.info('#-----------------------------------------------#')

    args.test_dir = filesys.find_or_create_test_dir(args.test_dir,
                                                    args.train_dir,
                                                    opts=opts_test)

    # NOTE(review): batch size is read from ``opts`` while every other setting
    # comes from ``opts_train`` -- confirm this is intended.
    batch_img, batch_label, batch_weights = data_layers.data_HDF5(
        opts_train.dataset,
        opts_train.shape_img,
        opts_train.shape_label,
        opts_train.shape_weights,
        shuffle=False,
        batch_size=opts.batch_size,
        prefetch_threads=12,
        prefetch_n=40,
        resample_n=40,
        augment=True)

    resize = opts_train.resize
    if resize is not None:
        if opts_train.resize_method == "scale":
            batch_img = tf.image.resize_images(batch_img, resize)
            # Labels hold class ids: nearest-neighbour keeps them discrete.
            # Bilinear interpolation followed by the uint8 cast would truncate
            # interpolated border values and erode the label masks.
            batch_label = tf.image.resize_images(
                batch_label,
                resize,
                method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
            batch_label = tf.cast(batch_label, tf.uint8)
            batch_weights = tf.image.resize_images(batch_weights, resize)
        elif opts_train.resize_method == "center_crop":

            def _center_crop(batch):
                # crop every image of the batch to its central 50 % region
                return tf.map_fn(lambda img: tf.image.central_crop(img, 0.5),
                                 batch,
                                 parallel_iterations=8,
                                 name="center_crop")

            batch_img = _center_crop(batch_img)
            batch_label = _center_crop(batch_label)
            batch_weights = _center_crop(batch_weights)
        else:
            # The same op-level seed is used for all three crops, presumably so
            # image, label and weights are cut at the same window.
            # NOTE(review): confirm the offsets stay aligned across runs.
            random_seed = 42  # tf.random_uniform(1, minval=0, maxval=65536, dtype=tf.int16)

            def _random_crop(batch):
                # crop x/y to ``resize`` while keeping batch and channel dims
                shape = batch.get_shape().as_list()
                return tf.random_crop(
                    batch, [shape[0], resize[0], resize[1], shape[3]],
                    seed=random_seed)

            batch_img = _random_crop(batch_img)
            batch_label = _random_crop(batch_label)
            batch_weights = _random_crop(batch_weights)

    for bb in range(opts_test.n_samples):
        r_batch_img = []
        for _ in range(8):
            start = timer()
            # label and weights are fetched too so all three pipeline outputs
            # advance together; only the image is visualised
            e_batch_img, _, _ = sess.run(
                [batch_img, batch_label, batch_weights])
            end = timer()

            print("got batch in %.4f s : img %s %s" % (
                (end - start), str(e_batch_img.shape), str(e_batch_img.dtype)))

            # merge the batch axis into the x axis (images stacked along x)
            r_batch_img.append(
                np.reshape(e_batch_img,
                           [-1, e_batch_img.shape[2], e_batch_img.shape[3]]))

        print("stitching and creating file")
        out_img = np.concatenate(
            [img_util.to_rgb(batch) for batch in r_batch_img], axis=1)
        img_util.save_image(out_img,
                            "%s/img_aug_%s.jpg" % (args.test_dir, str(bb)))
Example #2
0
def test_debug_sampling(sess, net_test):
    """Predict test batches (optionally resampled) and write .mat + summary images.

    For each of ``opts_test.n_samples`` batches the input image, label,
    softmax, prediction and an uncertainty map are written to
    ``<args.test_dir>/tile_<b>.mat`` and shown side by side in
    ``<args.test_dir>/tile_<b>_summary.png``.

    * If ``opts_test.aleatoric_sample_n`` is a positive int, softmax,
      prediction and uncertainty come from ``uncertainty.aleatoric_entropy``.
      Otherwise the plain network softmax/prediction is used and the
      uncertainty is an all-zero dummy.
    * If ``opts_test.resample_n`` is set, every batch is predicted
      ``resample_n`` times. The saved prediction is the mean over the samples
      and the uncertainty is derived from the samples. Samples 1..n-1 are also
      written as images to ``<args.test_dir>/samples``.

    Stops early when the input pipeline raises ``tf.errors.OutOfRangeError``.
    Relies on the module-level globals ``args`` and ``opts_test``.

    :param sess: TensorFlow session.
    :param net_test: test network exposing ``batch_img``, ``batch_label``,
        ``output_mask``, ``prediction`` and, for aleatoric sampling,
        ``sigma_activations``.
    """
    logging.info('#-----------------------------------------------#')
    logging.info('#        Starting Testing with sampling         #')
    logging.info('#-----------------------------------------------#')

    # load model for testing (if None provided, searches in train_dir, if not found doesn't load)
    trainer = SimpleTrainer(session=sess, train_dir=args.train_dir)
    chkpt_loaded = trainer.load_checkpoint(args.checkpoint)
    # init variables if no checkpoint was loaded (as test_debug / test_metrics
    # do); otherwise the network variables would be left uninitialized
    if not chkpt_loaded:
        sess.run(tf.group(tf.global_variables_initializer()))
    logging.info("Loaded variables from checkpoint"
                 if chkpt_loaded else "Randomly initialized (!) variables")
    args.test_dir = filesys.find_or_create_test_dir(args.test_dir,
                                                    args.train_dir)

    # One flag for graph construction AND post-processing, so that
    # aleatoric_sample_n == 0 is treated as "no aleatoric output" everywhere.
    use_aleatoric = (opts_test.aleatoric_sample_n is not None
                     and opts_test.aleatoric_sample_n > 0)
    resampling = opts_test.resample_n is not None

    # add uncertainty op if supported (network trained with aleatoric loss)
    if use_aleatoric:
        # get uncertainty (entropy) and sampled, softmaxed output mask
        net_batch_uncertainty, net_batch_softmax = uncertainty.aleatoric_entropy(
            net_test.output_mask, net_test.sigma_activations,
            opts_test.aleatoric_sample_n)
        # softmax is monotonic, so argmax of the softmaxed output is the same
        # as argmax after a second softmax -- take it directly
        net_batch_prediction = tf.argmax(net_batch_softmax,
                                         axis=-1,
                                         output_type=tf.int32)
    else:
        # add softmax op, default net prediction and dummy uncertainty
        net_batch_softmax = tf.nn.softmax(net_test.output_mask)
        net_batch_prediction = net_test.prediction
        # for default net w/o aleatoric loss no uncertainty is defined. Can be generated with dropout (resample_n)
        net_batch_uncertainty = tf.constant(0,
                                            shape=list(
                                                net_batch_softmax.shape)[0:3])

    # ###########################################################################
    # RUN UNET
    # ###########################################################################
    logging.debug("predicting: n_samples %s, resample_n %s, batch_size %s" %
                  (str(opts_test.n_samples), str(
                      opts_test.resample_n), str(opts_test.batch_size)))

    # create folder to store samples in
    sample_dir = None
    if resampling:
        sample_dir = args.test_dir + os.sep + 'samples'
        # find_or_create_test_dir may hand back an existing directory
        # (presumably); don't crash if 'samples' is already there
        if not os.path.isdir(sample_dir):
            os.mkdir(sample_dir)

    for b in range(opts_test.n_samples):
        # GENERATE RESULTS
        # ----------------
        try:
            # standard run (or first run if resampling, to get batch_img and batch_label once)
            logging.debug("run ...")
            batch_img, batch_label, batch_softmax, batch_prediction, batch_uncertainty \
                = sess.run([net_test.batch_img, net_test.batch_label,
                            net_batch_softmax, net_batch_prediction, net_batch_uncertainty])

            # if resampling, create store to sample into
            if resampling:
                r_batch_pred = np.reshape(batch_prediction,
                                          [-1, batch_prediction.shape[2]])
                # one slot per resample run
                prediction_samples = np.zeros([opts_test.resample_n] +
                                              list(r_batch_pred.shape),
                                              dtype=np.uint8)
                prediction_samples[0, ...] = r_batch_pred
                if use_aleatoric:
                    r_batch_uncertainty = np.reshape(
                        batch_uncertainty, [-1, batch_uncertainty.shape[2]])
                    # Indexed by the resample run ``s`` below, so it needs
                    # resample_n slots (sizing it by aleatoric_sample_n raised
                    # IndexError whenever resample_n > aleatoric_sample_n).
                    uncertainty_samples = np.zeros(
                        [opts_test.resample_n] +
                        list(r_batch_uncertainty.shape),
                        dtype=np.float32)
                    uncertainty_samples[0, ...] = r_batch_uncertainty

                # start sampling (-1 because one run was already done)
                for s in range(1, opts_test.resample_n):
                    # use the same prediction op as the first run so every
                    # sample is produced the same way
                    batch_softmax, batch_prediction, batch_uncertainty \
                        = sess.run([net_batch_softmax, net_batch_prediction, net_batch_uncertainty])

                    # store prediction
                    r_batch_pred = np.reshape(batch_prediction,
                                              [-1, batch_prediction.shape[2]])
                    prediction_samples[s, ...] = r_batch_pred
                    # save sample images to sample folder
                    out_sample = img_util.to_rgb(prediction_samples[s, ...])
                    img_util.save_image(
                        out_sample,
                        "%s/sample_%s_%s_pred.png" % (sample_dir, b, s))

                    if use_aleatoric:
                        # store uncertainty
                        r_batch_uncertainty = np.reshape(
                            batch_uncertainty,
                            [-1, batch_uncertainty.shape[2]])
                        uncertainty_samples[s, ...] = r_batch_uncertainty
                        # save sample images to sample folder
                        out_sample = img_util.to_rgb(uncertainty_samples[s,
                                                                         ...])
                        img_util.save_image(
                            out_sample,
                            "%s/sample_%s_%s_unc.png" % (sample_dir, b, s))

        except tf.errors.OutOfRangeError:
            break

        # PROCESS RESULTS
        # ---------------
        # reshape so that batch_size is merged into x dimension (images are concatenated along x dim)
        r_batch_img = np.reshape(batch_img,
                                 [-1, batch_img.shape[2], batch_img.shape[3]])
        r_batch_label = np.reshape(
            batch_label, [-1, batch_label.shape[2], batch_label.shape[3]])
        r_batch_softmax = np.reshape(
            batch_softmax,
            [-1, batch_softmax.shape[2], batch_softmax.shape[3]])

        # if resampling, create pred and uncertainty from samples
        if resampling:
            logging.info('resampled pred (%s), averaging for pred' %
                         (str(opts_test.resample_n)))
            r_batch_prediction = np.mean(prediction_samples, axis=0)

            if use_aleatoric:
                logging.info(
                    'calculating combined entropy from %s aleatoric samples' %
                    (str(opts_test.aleatoric_sample_n)))
                # -sum(p * log p) over the samples; nan_to_num turns log(0) into
                # a finite value so 0 * log(0) contributes 0
                r_batch_uncertainty = -np.sum(
                    uncertainty_samples *
                    np.nan_to_num(np.log(uncertainty_samples)),
                    axis=0)
                max_uncertainty = np.max(r_batch_uncertainty)
                # normalize to max 1; skip when the max is 0 to avoid 0/0 -> NaN
                if max_uncertainty != 0:
                    r_batch_uncertainty /= max_uncertainty
            else:
                logging.info(
                    'calculating epistemic entropy from %s pred samples' %
                    (str(opts_test.resample_n)))
                # compute epistemic uncertainty and overwrite (empty) network output uncertainty
                # TODO entropy with softmax? but which class?
                r_batch_uncertainty = calc.entropy_bin_array(
                    prediction_samples)
        else:
            r_batch_prediction = np.reshape(batch_prediction,
                                            [-1, batch_prediction.shape[2]])
            # append axis for easier processing in MATLAB (same rank as softmax)
            batch_uncertainty = batch_uncertainty[..., np.newaxis]
            r_batch_uncertainty = np.reshape(
                batch_uncertainty,
                [-1, batch_uncertainty.shape[2], batch_uncertainty.shape[3]])

        # WRITE RESULTS
        # -------------
        logging.debug('writing tile-file: %s/tile_%02d.mat' % (args.test_dir, b))
        logging.debug('  softmax_activations: %s %s' %
                      (str(r_batch_softmax.shape), str(r_batch_softmax.dtype)))
        logging.debug(
            '  prediction: %s %s' %
            (str(r_batch_prediction.shape), str(r_batch_prediction.dtype)))
        logging.debug('  img: %s %s' %
                      (str(r_batch_img.shape), str(r_batch_img.dtype)))
        logging.debug('  label: %s %s' %
                      (str(r_batch_label.shape), str(r_batch_label.dtype)))
        logging.debug(
            '  uncertainty: %s %s' %
            (str(r_batch_uncertainty.shape), str(r_batch_uncertainty.dtype)))

        # write matlab arrays
        scipy.io.savemat("%s/tile_%02d.mat" % (args.test_dir, b),
                         mdict={
                             'tile': r_batch_img,
                             'label': r_batch_label,
                             'pred': r_batch_prediction,
                             'softmax': r_batch_softmax,
                             'uncertainty': r_batch_uncertainty
                         })

        # summary: image | label | softmax channel 1 | prediction | uncertainty
        out_img = np.concatenate(
            (
                np.squeeze(img_util.to_rgb(r_batch_img)),
                np.squeeze(img_util.to_rgb(r_batch_label)),
                np.squeeze(
                    img_util.to_rgb(r_batch_softmax[..., 1, np.newaxis],
                                    normalize=True)),
                np.squeeze(img_util.to_rgb(r_batch_prediction[...,
                                                              np.newaxis])),
                np.squeeze(
                    img_util.to_rgb(r_batch_uncertainty[..., np.newaxis]))),
            axis=1)
        img_util.save_image(out_img,
                            "%s/tile_%02d_summary.png" % (args.test_dir, b))
Example #3
0
def test_metrics(sess, net_test):
    """Evaluate ``net_test`` on the test set and log batch-averaged metrics.

    Restores a checkpoint via ``SimpleTrainer`` (random init if none is
    found), runs ``net_test.test_op()`` for up to ``opts_test.n_samples``
    batches and writes one comparison image per batch to
    ``<args.test_dir>/img_<b>.png``. The image shows: input, label,
    activation channels 0 and 1, argmax of the activations, and the network
    prediction. Finally it logs accuracy, precision, recall, per-class
    accuracy and mean IoU, averaged over the batches that were actually
    evaluated. Evaluation stops early on ``tf.errors.OutOfRangeError``.

    Relies on the module-level globals ``args`` and ``opts_test``.

    :param sess: TensorFlow session.
    :param net_test: test network providing ``test_op()`` and ``global_step``.
    """
    logging.info('#-----------------------------------------------#')
    logging.info('#               Starting Testing (metrics)      #')
    logging.info('#-----------------------------------------------#')

    # load model for testing (if None provided, searches in train_dir, if not found doesn't load)
    logging.info("Attempt restore with SimpleTrainer load from: %s" %
                 args.checkpoint)
    trainer = SimpleTrainer(session=sess, train_dir=args.train_dir)
    chkpt_loaded = trainer.load_checkpoint(args.checkpoint)

    # init variables if no checkpoint was loaded
    if not chkpt_loaded:
        sess.run(tf.group(tf.global_variables_initializer()))
    logging.info("Loaded variables from checkpoint"
                 if chkpt_loaded else "Randomly initialized (!) variables")

    # create test op and initialize associated variables
    test_op = net_test.test_op()
    tf_helpers.initialize_uninitialized(sess, vars=tf.global_variables())
    tf_helpers.initialize_uninitialized(sess, vars=tf.local_variables())

    # ###########################################################################
    # RUN UNET
    # ###########################################################################
    logging.debug("predicting, sampling %s times, batch_size %s" %
                  (opts_test.n_samples, opts_test.batch_size))

    # accuracy / precision / recall are accumulated element-wise against a
    # (0, 0) start value; per-class accuracy and mean IoU use element [0]
    c_accuracy, c_precision, c_recall = (0, 0), (0, 0), (0, 0)
    c_accuracy_per_class, c_mean_iou = 0, 0
    # number of batches actually evaluated (the loop may stop early)
    n_evaluated = 0

    [global_step] = sess.run([net_test.global_step])
    args.test_dir = filesys.find_or_create_test_dir(args.test_dir,
                                                    args.train_dir,
                                                    global_step=global_step)

    for b in range(opts_test.n_samples):
        try:
            test_op_results = sess.run(test_op)
        except tf.errors.OutOfRangeError:
            break

        (batch_img, batch_label, batch_weights, batch_activations,
         batch_softmax, batch_prediction, accuracy, precision, recall,
         accuracy_per_class, mean_iou) = test_op_results

        c_accuracy = [sum(x) for x in zip(accuracy, c_accuracy)]
        c_precision = [sum(x) for x in zip(precision, c_precision)]
        c_recall = [sum(x) for x in zip(recall, c_recall)]
        c_accuracy_per_class = c_accuracy_per_class + accuracy_per_class[0]
        c_mean_iou = c_mean_iou + mean_iou[0]
        n_evaluated += 1

        # merge the batch axis into x so the batch becomes one tall image
        r_batch_img = np.reshape(batch_img,
                                 [-1, batch_img.shape[2], batch_img.shape[3]])
        r_batch_label = np.reshape(
            batch_label, [-1, batch_label.shape[2], batch_label.shape[3]])
        r_batch_activations = np.reshape(
            batch_activations,
            [-1, batch_activations.shape[2], batch_activations.shape[3]])
        r_batch_prediction = np.reshape(batch_prediction,
                                        [-1, batch_prediction.shape[2]])

        # argmax over raw activations (a full softmax first would be slower
        # and gives the same argmax)
        argmax = np.argmax(r_batch_activations, axis=-1)

        out_img = np.concatenate(
            (np.squeeze(img_util.to_rgb(r_batch_img)),
             np.squeeze(img_util.to_rgb(r_batch_label)),
             np.squeeze(
                 img_util.to_rgb(r_batch_activations[..., 0, np.newaxis],
                                 normalize=False)),
             np.squeeze(
                 img_util.to_rgb(r_batch_activations[..., 1, np.newaxis],
                                 normalize=False)),
             np.squeeze(img_util.to_rgb(argmax[..., np.newaxis])),
             np.squeeze(img_util.to_rgb(r_batch_prediction[..., np.newaxis]))),
            axis=1)

        img_util.save_image(out_img, "%s/img_%s.png" % (args.test_dir, b))

    if n_evaluated == 0:
        logging.warning('no test batch could be evaluated, no metrics to report')
        return

    # Average over the batches really evaluated, not over opts_test.n_samples:
    # the input pipeline may run out (OutOfRangeError) before n_samples.
    c_accuracy = tuple(x / n_evaluated for x in c_accuracy)
    c_precision = tuple(x / n_evaluated for x in c_precision)
    c_recall = tuple(x / n_evaluated for x in c_recall)
    c_accuracy_per_class = c_accuracy_per_class / n_evaluated
    c_mean_iou = c_mean_iou / n_evaluated
    logging.debug(
        '\naccuracy: %s, prec: %s, rec: %s \naccuracy_per_class %s, \nmean_iou %s'
        % (str(c_accuracy), str(c_precision), str(c_recall),
           str(c_accuracy_per_class), str(c_mean_iou)))
Example #4
0
    def val_sample_fn(self):
        """Sample the validation set, writing summaries and (optionally) images.

        Only active when ``self.opts_val`` is set and defines
        ``val_sample_intervall``. Runs ``opts_val.n_samples`` forward passes
        with ``self.use_train_data`` fed as ``False``. Each merged summary is
        written to ``self.val_summary_writer`` at the current global step. If
        ``self.val_dir`` is set, a side-by-side image is also saved as
        ``<val_dir>/img_<b>.png`` (input, label, activation channels 0 and 1,
        prediction). Stops early on ``tf.errors.OutOfRangeError``.
        """
        options = self.opts_val
        out_dir = self.val_dir
        session = self.sess

        # guard clause: nothing to do without validation sampling options
        if options is None or options.val_sample_intervall is None:
            return

        logging.info('#----------------------------#')
        logging.info('#    Sampling over Valset    #')
        logging.info("# %s samples, to %s" % (options.n_samples, out_dir))

        step = self.sess.run(self.global_step)

        def fold_batch(arr):
            # merge the two leading axes, keep the trailing two
            return np.reshape(arr, [-1, arr.shape[2], arr.shape[3]])

        for idx in range(options.n_samples):
            try:
                # feeding use_train_data=False switches the data layer
                fetched = session.run(
                    [self.batch_img, self.batch_label, self.output_mask,
                     self.prediction, self.merged_summary],
                    feed_dict={self.use_train_data: False})
                imgs, labels, activations, preds, summary = fetched
            except tf.errors.OutOfRangeError:
                break

            writer = self.val_summary_writer
            writer.add_summary(summary, step)
            writer.flush()

            if out_dir is None:
                continue

            flat_img = fold_batch(imgs)
            flat_label = fold_batch(labels)
            flat_act = fold_batch(activations)
            flat_pred = np.reshape(preds, [-1, preds.shape[2]])

            panels = (
                img_util.to_rgb(flat_img),
                img_util.to_rgb(flat_label),
                img_util.to_rgb(flat_act[..., 0, np.newaxis], normalize=True),
                img_util.to_rgb(flat_act[..., 1, np.newaxis], normalize=True),
                img_util.to_rgb(flat_pred[..., np.newaxis]),
            )
            out_img = np.concatenate([np.squeeze(panel) for panel in panels],
                                     axis=1)
            img_util.save_image(out_img, "%s/img_%s.png" % (out_dir, idx))

        logging.info('#                            #')
        logging.info('#-X--------------------------#')
Example #5
0
def test_debug(sess, net_test):
    """Predict ``opts_test.n_samples`` test batches and write .mat + summary images.

    Restores a checkpoint via ``SimpleTrainer`` (random init if none is
    found). For each batch it writes image, label, prediction, softmax and an
    uncertainty map to ``<args.test_dir>/tile_<b>.mat``, plus a side-by-side
    ``tile_<b>_summary.png``. The uncertainty is the aleatoric entropy when
    ``opts_test.aleatoric_sample_n`` is a positive int, otherwise an all-zero
    dummy. Stops early on ``tf.errors.OutOfRangeError``.

    Relies on the module-level globals ``args`` and ``opts_test``.

    :param sess: TensorFlow session.
    :param net_test: test network exposing ``batch_img``, ``batch_label``,
        ``output_mask``, ``prediction`` and, for aleatoric uncertainty,
        ``sigma_activations``.
    """
    # imported once up front instead of on every loop iteration
    import scipy.io

    logging.info('#-----------------------------------------------#')
    logging.info('#               Starting Testing (debug)        #')
    logging.info('#-----------------------------------------------#')

    # load model for testing (if None provided, searches in train_dir, if not found doesn't load)
    trainer = SimpleTrainer(session=sess, train_dir=args.train_dir)
    chkpt_loaded = trainer.load_checkpoint(args.checkpoint)
    # init variables if no checkpoint was loaded
    if not chkpt_loaded:
        sess.run(tf.group(tf.global_variables_initializer()))
    logging.info("Loaded variables from checkpoint"
                 if chkpt_loaded else "Randomly initialized (!) variables")

    # add softmax op
    net_batch_softmax = tf.nn.softmax(net_test.output_mask)
    # add uncertainty op if supported (network trained with aleatoric loss)
    if opts_test.aleatoric_sample_n is not None and opts_test.aleatoric_sample_n > 0:
        aleatoric = uncertainty.aleatoric_entropy(
            net_test.output_mask, net_test.sigma_activations,
            opts_test.aleatoric_sample_n)
        # Elsewhere in this file the call is unpacked as (entropy, sampled
        # softmax). Keep only the entropy so batch_uncertainty below is an
        # array, not a tuple; a single-tensor return is used as is.
        if isinstance(aleatoric, (tuple, list)):
            net_batch_uncertainty = aleatoric[0]
        else:
            net_batch_uncertainty = aleatoric
    else:
        # no aleatoric output: dummy all-zero uncertainty of label size
        net_batch_uncertainty = tf.constant(
            0,
            shape=[opts_test.batch_size, opts_test.shape_label[0],
                   opts_test.shape_label[1]])

    # ###########################################################################
    # RUN UNET
    # ###########################################################################
    logging.debug("predicting, sampling %s times, batch_size %s" %
                  (opts_test.n_samples, opts_test.batch_size))

    for b in range(opts_test.n_samples):
        try:
            logging.debug("run ...")
            batch_img, batch_label, batch_softmax, batch_prediction, batch_uncertainty \
                = sess.run([net_test.batch_img, net_test.batch_label,
                            net_batch_softmax, net_test.prediction, net_batch_uncertainty])
        except tf.errors.OutOfRangeError:
            break
        logging.debug("... success")

        # append a trailing axis so uncertainty has the same rank as softmax
        batch_uncertainty = batch_uncertainty[..., np.newaxis]
        # reshape so that batch_size is merged into x dimension (images are concatenated along x dim)
        r_batch_img = np.reshape(batch_img,
                                 [-1, batch_img.shape[2], batch_img.shape[3]])
        r_batch_label = np.reshape(
            batch_label, [-1, batch_label.shape[2], batch_label.shape[3]])
        r_batch_softmax = np.reshape(
            batch_softmax,
            [-1, batch_softmax.shape[2], batch_softmax.shape[3]])
        r_batch_prediction = np.reshape(batch_prediction,
                                        [-1, batch_prediction.shape[2]])
        r_batch_uncertainty = np.reshape(
            batch_uncertainty,
            [-1, batch_uncertainty.shape[2], batch_uncertainty.shape[3]])

        # matlab arrays
        scipy.io.savemat("%s/tile_%02d.mat" % (args.test_dir, b),
                         mdict={
                             'tile': r_batch_img,
                             'label': r_batch_label,
                             'pred': r_batch_prediction,
                             'softmax': r_batch_softmax,
                             'uncertainty': r_batch_uncertainty
                         })

        logging.debug('writing tile-file: %s/tile_%02d.mat' % (args.test_dir, b))
        logging.debug('  softmax_activations: %s %s' %
                      (str(r_batch_softmax.shape), str(r_batch_softmax.dtype)))
        logging.debug(
            '  prediction: %s %s' %
            (str(r_batch_prediction.shape), str(r_batch_prediction.dtype)))
        logging.debug('  img: %s %s' %
                      (str(r_batch_img.shape), str(r_batch_img.dtype)))
        logging.debug('  label: %s %s' %
                      (str(r_batch_label.shape), str(r_batch_label.dtype)))
        logging.debug(
            '  uncertainty: %s %s' %
            (str(r_batch_uncertainty.shape), str(r_batch_uncertainty.dtype)))

        # summary: image | label | softmax channel 1 | prediction | uncertainty
        out_img = np.concatenate(
            (
                np.squeeze(img_util.to_rgb(r_batch_img)),
                np.squeeze(img_util.to_rgb(r_batch_label)),
                np.squeeze(
                    img_util.to_rgb(r_batch_softmax[..., 1, np.newaxis],
                                    normalize=True)),
                np.squeeze(img_util.to_rgb(r_batch_prediction[...,
                                                              np.newaxis])),
                np.squeeze(
                    img_util.to_rgb(r_batch_uncertainty[..., np.newaxis]))),
            axis=1)
        img_util.save_image(out_img,
                            "%s/tile_%02d_summary.png" % (args.test_dir, b))
Example #6
0
def DEBUG_core(sess):
    """Visual sanity check of the augmented HDF5 input pipeline.

    Phase 1 samples an unshuffled pipeline (``resample_n=40``)
    ``opts.n_samples`` times. Each time it fetches 8 batches and writes their
    images side by side to ``<args.predict_dir>/img_aug_<i>.jpg``.

    Phase 2 samples a shuffled pipeline (no resampling) 50 times and writes
    image, label and weights of every batch next to each other to
    ``<args.predict_dir>/img_<i>.png``.

    Fetch times and array shapes/dtypes are printed to stdout. Relies on the
    module-level globals ``args`` and ``opts``.

    :param sess: TensorFlow session used to evaluate the input pipelines.
    """
    for banner in ('#-----------------------------------------------#',
                   '#                Start Debugging                #',
                   '#-----------------------------------------------#'):
        logging.info(banner)

    def merge_batch(arr):
        # fold the batch axis into x so a batch becomes one tall image
        return np.reshape(arr, [-1, arr.shape[2], arr.shape[3]])

    # --- phase 1: unshuffled, resampled pipeline -> augmentation mosaics ---
    aug_img, aug_label, aug_weights = data_layers.data_HDF5(
        args.dataset, opts.shape_img, opts.shape_label, opts.shape_weights,
        shuffle=False, batch_size=opts.batch_size, prefetch_threads=12,
        prefetch_n=40, resample_n=40, augment=True)

    for sample_idx in range(opts.n_samples):
        stacks = []
        for _ in range(8):
            t_begin = timer()
            imgs, _labels, _weights = sess.run(
                [aug_img, aug_label, aug_weights])
            t_end = timer()
            print("got batch in %.4f s : img %s %s" %
                  ((t_end - t_begin), str(imgs.shape), str(imgs.dtype)))
            stacks.append(merge_batch(imgs))

        print("stitching and creating file")
        mosaic = np.concatenate([img_util.to_rgb(stack) for stack in stacks],
                                axis=1)
        target = "%s/img_aug_%s.jpg" % (args.predict_dir, str(sample_idx))
        img_util.save_image(mosaic, target)

    # --- phase 2: shuffled pipeline -> image | label | weights per batch ---
    shuf_img, shuf_label, shuf_weights = data_layers.data_HDF5(
        args.dataset, opts.shape_img, opts.shape_label, opts.shape_weights,
        shuffle=True, batch_size=opts.batch_size, prefetch_threads=12,
        prefetch_n=10, resample_n=None, augment=True)

    for step in range(50):
        t_begin = timer()
        imgs, labels, weights = sess.run([shuf_img, shuf_label, shuf_weights])
        t_end = timer()

        print("got batch in %.4f s : img %s %s, label %s %s, weights %s %s" %
              ((t_end - t_begin), str(imgs.shape), str(imgs.dtype),
               str(labels.shape), str(labels.dtype),
               str(weights.shape), str(weights.dtype)))

        flat_img = merge_batch(imgs)
        flat_label = merge_batch(labels)
        flat_weights = merge_batch(weights)

        columns = [
            np.squeeze(img_util.to_rgb(flat_img)),
            np.squeeze(img_util.to_rgb(flat_label)),
            np.squeeze(img_util.to_rgb(flat_weights, normalize=True)),
        ]
        combined = np.concatenate(columns, axis=1)

        img_util.save_image(combined,
                            "%s/img_%s.png" % (args.predict_dir, str(step)))
Example #7
0
def test_sampling(sess, net_test):
    """Evaluate the test network repeatedly and save sampling statistics.

    Restores weights via ``SimpleTrainer.load_checkpoint(args.checkpoint)``.
    For each of ``opts.n_samples`` outer iterations the network is run
    ``opts.resample_n`` times:

    - Every individual prediction is written to
      ``<args.predict_dir>/samples/sample_<b>_<s>.png``.
    - A summary image is then written to ``<args.predict_dir>/img_<b>.png``.
      It concatenates, along axis 1: input, label, per-pixel mean, entropy
      heatmap and std heatmap over the samples.

    NOTE(review): every ``sess.run`` pulls from the graph's input pipeline.
    The input/label shown in the summary come from the first run (s == 0).
    Whether later runs see the same batch depends on the data layer
    (presumably it repeats each batch ``resample_n`` times) — confirm.

    Args:
        sess: TensorFlow session used to run the graph.
        net_test: network object exposing ``batch_img``, ``batch_label`` and
            ``prediction`` tensors.

    Raises:
        ValueError: if ``opts.resample_n`` is smaller than 1. No sample would
            be drawn and the statistics below would be undefined.
    """
    logging.info('#-----------------------------------------------#')
    logging.info('#        Starting Testing with sampling         #')
    logging.info('#-----------------------------------------------#')

    if opts.resample_n < 1:
        raise ValueError('opts.resample_n must be >= 1, got %s' %
                         str(opts.resample_n))

    trainer = SimpleTrainer(session=sess, train_dir=train_dir)
    # load model (if None provided, gets latest from train_dir/checkpoints, if none found doesn't load)
    trainer.load_checkpoint(args.checkpoint)

    # ###########################################################################
    # RUN UNET
    # ###########################################################################
    logging.debug(
        "predicting, sampling %s x %s times, batch_size %s" %
        (str(opts.n_samples), str(opts.resample_n), str(opts.batch_size)))

    sample_dir = os.path.join(args.predict_dir, 'samples')
    # exist_ok: re-running into an existing predict_dir must not crash.
    os.makedirs(sample_dir, exist_ok=True)
    for b in range(opts.n_samples):

        for s in range(opts.resample_n):
            # Fetch all three in one run so they belong to the same evaluation.
            img, label, batch_pred = sess.run([
                net_test.batch_img, net_test.batch_label, net_test.prediction
            ])
            if s == 0:
                # Only the first run's input/label are kept for the summary.
                batch_img, batch_label = img, label
                logging.debug('batch_img: %s %s' %
                              (str(batch_img.shape), str(batch_img.dtype)))
                logging.debug('batch_label: %s %s' %
                              (str(batch_label.shape), str(batch_label.dtype)))
            # Log the freshly computed prediction (not a not-yet-filled row).
            logging.debug('prediction: %s %s' %
                          (str(batch_pred.shape), str(batch_pred.dtype)))

            # Collapse the leading axes so all batch items stack vertically.
            r_batch_pred = np.reshape(batch_pred, [-1, batch_pred.shape[2]])
            if s == 0:
                prediction_samples = np.zeros([opts.resample_n] +
                                              list(r_batch_pred.shape),
                                              dtype=np.uint8)
            prediction_samples[s, ...] = r_batch_pred

            out_sample = img_util.to_rgb(prediction_samples[s, ...])
            img_util.save_image(out_sample,
                                "%s/sample_%s_%s.png" % (sample_dir, b, s))

        logging.info('finished resampling (%s), calculating entropy' %
                     (str(opts.resample_n)))

        entropy = calc.entropy_bin_array(prediction_samples)
        mean = np.mean(prediction_samples, axis=0)
        std = np.std(prediction_samples, axis=0)

        r_batch_img = np.reshape(batch_img,
                                 [-1, batch_img.shape[2], batch_img.shape[3]])
        r_batch_label = np.reshape(
            batch_label, [-1, batch_label.shape[2], batch_label.shape[3]])

        out_img = np.concatenate(
            (np.squeeze(img_util.to_rgb(r_batch_img)),
             np.squeeze(img_util.to_rgb(r_batch_label)),
             np.squeeze(img_util.to_rgb(mean)),
             np.squeeze(img_util.to_rgb_heatmap(entropy, rgb_256=True)),
             np.squeeze(img_util.to_rgb_heatmap(std, rgb_256=True))),
            axis=1)

        img_util.save_image(out_img, "%s/img_%s.png" % (args.predict_dir, b))
예제 #8
0
def test_debug(sess, net_test):
    """Run the test network on ``opts.n_samples`` batches and dump debug images.

    Weights are restored with ``SimpleTrainer.load_checkpoint(args.checkpoint)``.
    If nothing is restored, the global variables are initialized randomly
    instead.

    For every batch, one composite image is saved to
    ``<args.predict_dir>/img_<b>.png``. It holds these panels side by side
    (axis 1):

    1. input
    2. label
    3. activation channel 0 (unnormalized)
    4. activation channel 1 (unnormalized)
    5. per-pixel argmax over the activation channels
    6. the network's ``prediction``

    Args:
        sess: TensorFlow session used to run the graph.
        net_test: network object exposing ``batch_img``, ``batch_label``,
            ``output_mask`` and ``prediction`` tensors.
    """
    separator = '#-----------------------------------------------#'
    logging.info(separator)
    logging.info('#               Starting Testing (debug)        #')
    logging.info(separator)

    logging.info("Attempt restore with SimpleTrainer load from: %s" %
                 args.checkpoint)
    trainer = SimpleTrainer(session=sess, train_dir=train_dir)
    restored = trainer.load_checkpoint(args.checkpoint)
    if restored:
        logging.info("Loaded variables from checkpoint")
    else:
        # Nothing restored: fall back to freshly initialized variables.
        sess.run(tf.group(tf.global_variables_initializer()))
        logging.info("Randomly initialized (!) variables")

    logging.debug("predicting, sampling %s times, batch_size %s" %
                  (opts.n_samples, opts.batch_size))

    def _stack_rows(arr):
        # For a 4-D array, merge the first two axes so batch items stack
        # vertically (assumes rank-4 input — TODO confirm).
        return np.reshape(arr, [-1, arr.shape[2], arr.shape[3]])

    fetches = [
        net_test.batch_img, net_test.batch_label, net_test.output_mask,
        net_test.prediction
    ]
    for b in range(opts.n_samples):
        img, label, activations, prediction = sess.run(fetches)

        logging.debug('batch_activations: %s %s' %
                      (str(activations.shape), str(activations.dtype)))
        logging.debug('batch_prediction: %s %s' %
                      (str(prediction.shape), str(prediction.dtype)))
        logging.debug('batch_img: %s %s' % (str(img.shape), str(img.dtype)))
        logging.debug('batch_label: %s %s' %
                      (str(label.shape), str(label.dtype)))

        rows_img = _stack_rows(img)
        rows_label = _stack_rows(label)
        rows_act = _stack_rows(activations)
        rows_pred = np.reshape(prediction, [-1, prediction.shape[2]])

        # Hard decision straight from the raw activations (no softmax needed
        # for argmax).
        winner = np.argmax(rows_act, axis=-1)

        panels = [
            img_util.to_rgb(rows_img),
            img_util.to_rgb(rows_label),
            img_util.to_rgb(rows_act[..., 0, np.newaxis], normalize=False),
            img_util.to_rgb(rows_act[..., 1, np.newaxis], normalize=False),
            img_util.to_rgb(winner[..., np.newaxis]),
            img_util.to_rgb(rows_pred[..., np.newaxis]),
        ]
        composite = np.concatenate([np.squeeze(p) for p in panels], axis=1)

        img_util.save_image(composite,
                            "%s/img_%s.png" % (args.predict_dir, b))
예제 #9
0
def test_core(sess, net_test):
    """Run the test network on ``opts.n_samples`` batches and save result images.

    Weights are restored with ``SimpleTrainer.load_checkpoint(args.checkpoint)``.
    If nothing is restored, the global variables are initialized randomly
    instead.

    For every batch, one composite image is saved to
    ``<args.predict_dir>/img_<b>.png``. It holds these panels side by side
    (axis 1):

    1. input
    2. label
    3. activation channel 0 (normalized)
    4. activation channel 1 (normalized)
    5. the network's ``prediction``

    Args:
        sess: TensorFlow session used to run the graph.
        net_test: network object exposing ``batch_img``, ``batch_label``,
            ``output_mask`` and ``prediction`` tensors.
    """
    separator = '#-----------------------------------------------#'
    logging.info(separator)
    logging.info('#               Starting Testing                #')
    logging.info(separator)

    # if args.checkpoint is None the trainer searches train_dir; if nothing is
    # found, no checkpoint is loaded
    trainer = SimpleTrainer(session=sess, train_dir=train_dir)
    restored = trainer.load_checkpoint(args.checkpoint)
    if restored:
        logging.info("Loaded variables from checkpoint")
    else:
        # Nothing restored: fall back to freshly initialized variables.
        sess.run(tf.group(tf.global_variables_initializer()))
        logging.info("Randomly initialized (!) variables")

    logging.debug("predicting, sampling %s times, batch_size %s" %
                  (opts.n_samples, opts.batch_size))

    def _stack_rows(arr):
        # For a 4-D array, merge the first two axes so batch items stack
        # vertically (assumes rank-4 input — TODO confirm).
        return np.reshape(arr, [-1, arr.shape[2], arr.shape[3]])

    fetches = [
        net_test.batch_img, net_test.batch_label, net_test.output_mask,
        net_test.prediction
    ]
    for b in range(opts.n_samples):
        img, label, activations, prediction = sess.run(fetches)

        logging.debug('batch_activations: %s %s' %
                      (str(activations.shape), str(activations.dtype)))
        logging.debug('batch_prediction: %s %s' %
                      (str(prediction.shape), str(prediction.dtype)))
        logging.debug('batch_img: %s %s' % (str(img.shape), str(img.dtype)))
        logging.debug('batch_label: %s %s' %
                      (str(label.shape), str(label.dtype)))

        rows_img = _stack_rows(img)
        rows_label = _stack_rows(label)
        rows_act = _stack_rows(activations)
        rows_pred = np.reshape(prediction, [-1, prediction.shape[2]])

        panels = [
            img_util.to_rgb(rows_img),
            img_util.to_rgb(rows_label),
            img_util.to_rgb(rows_act[..., 0, np.newaxis], normalize=True),
            img_util.to_rgb(rows_act[..., 1, np.newaxis], normalize=True),
            img_util.to_rgb(rows_pred[..., np.newaxis]),
        ]
        composite = np.concatenate([np.squeeze(p) for p in panels], axis=1)

        img_util.save_image(composite,
                            "%s/img_%s.png" % (args.predict_dir, b))
예제 #10
0
def test_feed_sampling(sess, net_test):
    """Feed image batches from disk into the test network and save sampling statistics.

    Restores weights via ``SimpleTrainer.load_checkpoint(args.checkpoint)``.
    Each of ``opts.n_samples`` iterations does the following:

    - Read one batch from ``data_dir`` with
      ``data_import.create_img_batch_provider``. When
      ``opts.resize_method == "center_crop"``, crop it with
      ``img_util.crop_to_shape`` to ``opts.resize``.
    - Feed that batch to ``net_test.batch_img`` ``opts.resample_n`` times.
      Each prediction is saved to
      ``<args.predict_dir>/samples/sample_<b>_<s>.png``.
    - Save a summary image to ``<args.predict_dir>/img_<b>.png``. It
      concatenates, along axis 1: input (at most 3 channels), per-pixel mean,
      entropy heatmap and std heatmap over the samples.

    Args:
        sess: TensorFlow session used to run the graph.
        net_test: network object exposing a feedable ``batch_img`` tensor and
            a ``prediction`` tensor.

    Raises:
        ValueError: if ``opts.resample_n`` is smaller than 1. No sample would
            be drawn and the statistics below would be undefined.
    """
    logging.info('#-----------------------------------------------#')
    logging.info('#        Starting Testing with sampling         #')
    logging.info('#-----------------------------------------------#')

    if opts.resample_n < 1:
        raise ValueError('opts.resample_n must be >= 1, got %s' %
                         str(opts.resample_n))

    trainer = SimpleTrainer(session=sess, train_dir=train_dir)
    # load model (if None provided, gets latest from train_dir/checkpoints, if none found doesn't load)
    trainer.load_checkpoint(args.checkpoint)

    # ###########################################################################
    # RUN UNET
    # ###########################################################################
    logging.debug(
        "predicting, sampling %s x %s times, batch_size %s" %
        (str(opts.n_samples), str(opts.resample_n), str(opts.batch_size)))

    sample_dir = os.path.join(args.predict_dir, 'samples')
    # exist_ok: re-running into an existing predict_dir must not crash.
    os.makedirs(sample_dir, exist_ok=True)
    for b in range(opts.n_samples):
        # NOTE(review): a new shuffled provider is created on every iteration.
        # Presumably each b therefore gets an independently drawn batch — confirm.
        data_provider = data_import.create_img_batch_provider(
            data_dir, opts.batch_size, shuffle_data=True)
        batch_img = next(data_provider)
        logging.debug('batch_img: %s %s' %
                      (str(batch_img.shape), str(batch_img.dtype)))
        if opts.resize_method == "center_crop":
            batch_img = img_util.crop_to_shape(batch_img, [
                batch_img.shape[0], opts.resize[0], opts.resize[1],
                batch_img.shape[3]
            ])
            logging.debug(' resized to: %s %s' %
                          (str(batch_img.shape), str(batch_img.dtype)))

        for s in range(opts.resample_n):
            # The same fed batch is reused for every sample. Only the
            # prediction is needed, so only it is fetched.
            batch_pred = sess.run(net_test.prediction,
                                  feed_dict={net_test.batch_img: batch_img})
            logging.debug('prediction: %s %s' %
                          (str(batch_pred.shape), str(batch_pred.dtype)))

            # Collapse the leading axes so all batch items stack vertically.
            r_batch_pred = np.reshape(batch_pred, [-1, batch_pred.shape[2]])
            if s == 0:
                prediction_samples = np.zeros([opts.resample_n] +
                                              list(r_batch_pred.shape),
                                              dtype=np.uint8)
            prediction_samples[s, ...] = r_batch_pred

            out_sample = img_util.to_rgb(prediction_samples[s, ...])
            img_util.save_image(out_sample,
                                "%s/sample_%s_%s.png" % (sample_dir, b, s))

        logging.info('finished resampling (%s), calculating entropy' %
                     (str(opts.resample_n)))

        entropy = calc.entropy_bin_array(prediction_samples)
        mean = np.mean(prediction_samples, axis=0)
        std = np.std(prediction_samples, axis=0)

        r_batch_img = np.reshape(batch_img,
                                 [-1, batch_img.shape[2], batch_img.shape[3]])
        # Keep at most the first three channels, which drops a possible fourth
        # channel. 1- or 3-channel inputs pass through this slice unchanged.
        # Previously this name was only defined in a commented-out line, so
        # the concatenate below raised NameError.
        r_batch_img_rgb = r_batch_img[..., 0:3]

        out_img = np.concatenate(
            (np.squeeze(img_util.to_rgb(r_batch_img_rgb)),
             np.squeeze(img_util.to_rgb(mean)),
             np.squeeze(img_util.to_rgb_heatmap(entropy, rgb_256=True)),
             np.squeeze(img_util.to_rgb_heatmap(std, rgb_256=True))),
            axis=1)

        img_util.save_image(out_img, "%s/img_%s.png" % (args.predict_dir, b))