示例#1
0
    def on_metric_call(self, epoch, iter, logs=None):
        """
        Compute every metric in self.metrics on several predictions.

        Predictions are drawn from _generate_predictions(...). For each
        (vol_true, vol_pred) pair it yields, one row of a
        (nb_samples, nb_labels, nb_metrics) array is filled with
        metric(vol_true, vol_pred). Rows that are never filled stay at 0.

        If self.filepath is not None, one csv file per metric is written; the
        filename is self.filepath.format(epoch=..., iter=..., metric=<name>).
        Otherwise the nan-mean over the sample axis is stored in `logs` under
        the key '<metric name>_label_<label id>'.

        Parameters:
            epoch, iter: only used to format the output filename
            logs: dict that receives the mean metric values when no filepath
                is set. A fresh dict is used when omitted.
        """
        # a `{}` default would be one dict shared (and mutated) across calls
        if logs is None:
            logs = {}

        with timer.Timer('predict metrics callback', self.verbose):

            # prepare metric
            met = np.zeros(
                (self.nb_samples, self.nb_labels, len(self.metrics)))

            # generate predictions
            # the idea is to predict either a full volume or just a slice,
            # depending on what we need
            gen = _generate_predictions(self.model, self.data_generator,
                                        self.batch_size, self.nb_samples,
                                        self.vol_params)
            for batch_idx, (vol_true, vol_pred) in enumerate(gen):
                for idx, metric in enumerate(self.metrics):
                    met[batch_idx, :, idx] = metric(vol_true, vol_pred)

            # write metric to csv file
            if self.filepath is not None:
                for idx, metric in enumerate(self.metrics):
                    filen = self.filepath.format(epoch=epoch,
                                                 iter=iter,
                                                 metric=metric.__name__)
                    np.savetxt(filen, met[:, :, idx], fmt='%f', delimiter=',')
            else:
                # average over samples, ignoring nan entries
                meanmet = np.nanmean(met, axis=0)
                for midx, metric in enumerate(self.metrics):
                    for idx in range(self.nb_labels):
                        varname = '%s_label_%d' % (metric.__name__,
                                                   self.label_ids[idx])
                        logs[varname] = meanmet[idx, midx]
示例#2
0
def next_pred_label(model, data_generator, verbose=False):
    """
    Draw the next batch from the generator, run the model on its inputs and
    compute labels through pred_to_label.

    returns (sample, prediction) followed by the entries of the tuple
    returned by pred_to_label
    """
    batch = next(data_generator)
    inputs = batch[0]

    with timer.Timer('prediction', verbose):
        prediction = model.predict(inputs)

    # with several model inputs, only the first one is handed to pred_to_label
    if isinstance(inputs, (list, tuple)):
        first_input = inputs[0]
    else:
        first_input = inputs

    labels = pred_to_label(first_input, prediction)
    return (batch, prediction) + labels
示例#3
0
    def on_model_save(self, epoch, iter, logs=None):
        """
        Save the model (or only its weights) to file. Code mostly from keras core.

        Nothing is written until self.period calls have accumulated since the
        last save. With self.save_best_only set, a save only happens when
        self.monitor_op(current, self.best) holds for the monitored value in
        `logs`; a RuntimeWarning is issued if that value is missing.

        The object saved is self.model.layers[-(len(self.model.outputs) + 1)].
        """

        with timer.Timer('model save callback', self.verbose):
            logs = logs or {}
            num_outputs = len(self.model.outputs)

            # not enough calls since the last save yet
            self.epochs_since_last_save += 1
            if self.epochs_since_last_save < self.period:
                return
            self.epochs_since_last_save = 0

            filepath = self.filepath.format(epoch=epoch, iter=iter, **logs)

            def _write():
                # the layer is looked up only when a save actually happens
                weights_only = self.save_weights_only
                layer = self.model.layers[-(num_outputs + 1)]
                if weights_only:
                    layer.save_weights(filepath, overwrite=True)
                else:
                    layer.save(filepath, overwrite=True)

            # unconditional save
            if not self.save_best_only:
                if self.verbose > 0:
                    print('Epoch %05d: saving model to %s' %
                          (epoch, filepath))
                _write()
                return

            # save-best-only: the monitored quantity is required
            current = logs.get(self.monitor)
            if current is None:
                warnings.warn(
                    'Can save best model only with %s available, '
                    'skipping.' % (self.monitor), RuntimeWarning)
                return

            if not self.monitor_op(current, self.best):
                if self.verbose > 0:
                    print(
                        'Epoch %05d Iter%05d: %s did not improve' %
                        (epoch, iter, self.monitor))
                return

            if self.verbose > 0:
                print(
                    'Epoch %05d: Iter%05d: %s improved from %0.5f to %0.5f,'
                    ' saving model to %s' %
                    (epoch, iter, self.monitor, self.best,
                     current, filepath))
            self.best = current
            _write()
示例#4
0
def next_vol_pred(model, data_generator, verbose=False):
    """
    Pull the next batch from the generator and run the model on its inputs.

    returns (input_vol, y_true, y_pred) or, when the batch input is a
    list/tuple (e.g. a prior was supplied as second input),
    (input_vol, y_true, y_pred, prior)
    """

    batch = next(data_generator)
    inputs = batch[0]

    with timer.Timer('prediction', verbose):
        y_pred = model.predict(inputs)
    y_true = batch[1]

    # several inputs: the first is the volume, the second is returned last
    if isinstance(inputs, (list, tuple)):
        return (inputs[0], y_true, y_pred, inputs[1])

    return (inputs, y_true, y_pred)
示例#5
0
def _load_medical_volume(filename, ext, verbose=False):
    """
    load a medical volume from one of a number of file types

    Parameters:
        filename: path handed to np.load or nib.load
        ext: extension selecting the loader:
            '.npz'                    -> np.load, array stored under 'vol_data'
            '.npy'                    -> np.load
            '.mgz', '.nii', '.nii.gz' -> nib.load
            (the bare spelling 'npy' is still accepted for backward
            compatibility)
        verbose: the load timer only reports when verbose >= 2

    Returns:
        the volume data array

    Raises:
        ValueError: if ext is none of the extensions above
    """
    with timer.Timer('load_vol', verbose >= 2):
        if ext == '.npz':
            # context manager closes the underlying zip file once the array
            # has been read out
            with np.load(filename) as vol_file:
                vol_data = vol_file['vol_data']
        elif ext in ('.npy', 'npy'):
            vol_data = np.load(filename)
        elif ext in ('.mgz', '.nii', '.nii.gz'):
            vol_med = nib.load(filename)
            # nibabel's documented replacement for the deprecated get_data()
            vol_data = np.asanyarray(vol_med.dataobj)
        else:
            raise ValueError("Unexpected extension %s" % ext)

    return vol_data
示例#6
0
    def on_plot_save(self, epoch, iter, logs=None):
        """
        Render example-prediction figures and save them to disk.

        Calls show_example_prediction_result(...) and treats every entry after
        the first of its return value as a figure. Figure number idx is saved
        to self.savefilepath.format(epoch=..., iter=..., axis='dirn_<idx>',
        slice_nr=0) and then closed.

        Parameters:
            epoch, iter: only used to format the output filename
            logs: unused
        """
        # import neuron sandbox
        # has to be here, can't be at the top, due to cyclical imports (??)
        # TODO: should just pass the function to compute the figures given the model and generator

        with timer.Timer('plot callback', self.verbose):
            if len(self.run.grid_size) == 3:
                collapse_2d = [0, 1, 2]
            else:
                collapse_2d = [2]

            # TODO: show_example_prediction_result is actually in neuron_sandbox for now
            exampl = show_example_prediction_result(
                self.model,
                self.generator,
                self.run,
                self.data,
                test_batch_size=1,
                test_model_names=None,
                test_grid_size=self.run.grid_size,
                ccmap=None,
                collapse_2d=collapse_2d,
                slice_nr=None,
                plt_width=17,
                verbose=self.verbose)

            # save, then close
            figs = exampl[1:]
            for idx, fig in enumerate(figs):
                dirn = "dirn_%d" % idx
                slice_nr = 0
                filename = self.savefilepath.format(epoch=epoch,
                                                    iter=iter,
                                                    axis=dirn,
                                                    slice_nr=slice_nr)
                fig.savefig(filename)
                # close each figure explicitly; a bare plt.close() only closes
                # the current figure, leaving the others alive
                plt.close(fig)
            # also close whatever figure is still current, as before
            plt.close()
示例#7
0
def predict_volumes(
        models,
        data_generator,
        batch_size,
        patch_size,
        patch_stride,
        grid_size,
        nan_func=np.nanmedian,
        do_extra_vol=False,  # should compute vols beyond label
        do_prob_of_true=False,  # should compute prob_of_true vols
        verbose=False):
    """
    Predict patch stacks for one or several models and quilt them into volumes.

    `models` may be a single model or a list/tuple of models. Looping outside
    of this function is not an option here, since all models have to see the
    same output of the (stateful) data generator.

    Returns:
    for more than one model:
        a tuple with one entry per model, each entry being a tuple of
        true_label, pred_label, <vol>, <prior_label>, <pred_prob_of_true>, <prior_prob_of_true>
    for exactly one model:
        that single tuple
        (true_label, pred_label, <vol>, <prior_label>, <pred_prob_of_true>, <prior_prob_of_true>)

    The bracketed entries are only present when do_extra_vol is set (the
    prob_of_true entries additionally need do_prob_of_true, the prior entries
    a 4-entry prediction stack).

    TODO: could add prior
    """

    if not isinstance(models, (list, tuple)):
        models = (models, )
    is_single = len(models) == 1

    # get the input and prediction stack
    with timer.Timer('predict_volume_stack', verbose):
        vol_stack = predict_volume_stack(models, data_generator, batch_size,
                                         grid_size, verbose)

    # uniform view: one stack per model
    per_model_stacks = (vol_stack, ) if is_single else vol_stack

    # a fourth stack entry means a prior is available
    do_prior = len(per_model_stacks[0]) == 4

    # shared quilting arguments
    quilt_args = (patch_size, grid_size, patch_stride)
    quilt_kwargs = dict(nan_func_layers=nan_func,
                        nan_func_K=nan_func,
                        verbose=verbose)

    outputs = []
    for midx in range(len(models)):
        stack = per_model_stacks[midx]

        all_prior = None
        if do_prior:
            all_true, all_pred, all_vol, all_prior = stack
        else:
            all_true, all_pred, all_vol = stack

        # get max labels
        all_true_label, all_pred_label = pred_to_label(all_true, all_pred)

        # quilt volumes and aggregate overlapping patches, if any
        results = [
            _quilt(all_true_label, *quilt_args, **quilt_kwargs).astype('int'),
            _quilt(all_pred_label, *quilt_args, **quilt_kwargs).astype('int'),
        ]

        if do_extra_vol:
            results.append(_quilt(all_vol, *quilt_args))

            if do_prior:
                (all_prior_label, ) = pred_to_label(all_prior)
                results.append(
                    _quilt(all_prior_label, *quilt_args,
                           **quilt_kwargs).astype('int'))

            # probability of the true label under prediction and prior.
            # computing the per-patch probability first and quilting afterwards
            # is faster than quilting the probabilistic volumes, at the price
            # of needing median votes
            if do_prob_of_true:
                pred_pp = prob_of_label(all_pred, all_true_label)
                results.append(_quilt(pred_pp, *quilt_args, **quilt_kwargs))

                if do_prior:
                    prior_pp = prob_of_label(all_prior, all_true_label)
                    results.append(
                        _quilt(prior_pp, *quilt_args, **quilt_kwargs))

        outputs.append(tuple(results))

    ret = tuple(outputs)
    if is_single:
        ret = ret[0]

    return ret
示例#8
0
def vol_prior_hack(
        *args,
        proc_vol_fn=None,
        proc_seg_fn=None,
        prior_type='location',  # file-static, file-gen, location
        prior_file=None,  # prior filename
        prior_feed='input',  # input or output
        patch_stride=1,
        patch_size=None,
        batch_size=1,
        collapse_2d=None,
        extract_slice=None,
        force_binary=False,
        nb_input_feats=1,
        verbose=False,
        vol_rand_seed=None,
        **kwargs):
    """
    Generator pairing batches from vol_seg_hack with patches of a prior volume.

    Parameters:
        *args, **kwargs: forwarded to vol_seg_hack
        proc_vol_fn, proc_seg_fn: accepted, but None is forwarded to
            vol_seg_hack in their place
        prior_type: 'location' builds the prior from nd.volsize2ndgrid,
            'file' loads the 'prior' entry of np.load(prior_file), anything
            else treats prior_file as the prior array itself
        prior_file: filename or array, see prior_type
        prior_feed: 'input' or 'output', selects the layout of yielded tuples
        patch_stride: stride per spatial dimension; a scalar is applied to
            every dimension of patch_size
        patch_size: spatial patch size; defaults to prior_vol.shape[0:3]
        force_binary: collapse channels 1.. of the prior into one channel
        extract_slice: int or slice(s) taken along the third axis of the prior
        batch_size, collapse_2d, nb_input_feats, verbose, vol_rand_seed:
            forwarded to vol_seg_hack and/or patch

    Yields:
        ([input_vol, prior_batch], input_vol) if prior_feed == 'input',
        (input_vol, [input_vol, prior_batch]) if prior_feed == 'output'
    """
    # prepare the vol_seg
    # NOTE(review): proc_vol_fn / proc_seg_fn passed by the caller are dropped
    # here (None is forwarded) -- confirm this is intended
    gen = vol_seg_hack(*args,
                       **kwargs,
                       proc_vol_fn=None,
                       proc_seg_fn=None,
                       collapse_2d=collapse_2d,
                       extract_slice=extract_slice,
                       force_binary=force_binary,
                       verbose=verbose,
                       patch_size=patch_size,
                       patch_stride=patch_stride,
                       batch_size=batch_size,
                       vol_rand_seed=vol_rand_seed,
                       nb_input_feats=nb_input_feats)

    # get prior
    if prior_type == 'location':
        # NOTE(review): vol_size is not defined in this function; unless it
        # exists at module level this branch raises NameError -- TODO confirm
        prior_vol = nd.volsize2ndgrid(vol_size)
        prior_vol = np.transpose(prior_vol, [1, 2, 3, 0])
        prior_vol = np.expand_dims(prior_vol, axis=0)  # reshape for model

    elif prior_type == 'file':  # assumes a npz filename passed in prior_file
        with timer.Timer('loading prior', True):
            data = np.load(prior_file)
            prior_vol = data['prior'].astype('float16')
    else:  # assumes a volume
        with timer.Timer('astyping prior', verbose):
            prior_vol = prior_file
            if not (prior_vol.dtype == 'float16'):
                prior_vol = prior_vol.astype('float16')

    if force_binary:
        nb_labels = prior_vol.shape[-1]
        # build a new [channel 0, sum of channels 1..] array rather than
        # writing into prior_vol: it can still be the caller's own array
        # (a float16 prior_file is not copied above)
        foreground = np.sum(prior_vol[:, :, :, 1:nb_labels], 3, keepdims=True)
        foreground = foreground.astype(prior_vol.dtype, copy=False)
        prior_vol = np.concatenate((prior_vol[:, :, :, 0:1], foreground),
                                   axis=3)

    nb_channels = prior_vol.shape[-1]

    if extract_slice is not None:
        if isinstance(extract_slice, int):
            # keep the sliced axis as a singleton dimension
            prior_vol = prior_vol[:, :, extract_slice, np.newaxis, :]
        else:  # assume slices
            prior_vol = prior_vol[:, :, extract_slice, :]

    # get the prior to have the right volume [x, y, z, nb_channels]
    assert np.ndim(prior_vol) == 4 or np.ndim(
        prior_vol) == 3, "prior is the wrong size"

    # prior generator
    if patch_size is None:
        patch_size = prior_vol.shape[0:3]
    # the scalar default stride (1) has no len(); expand it per dimension
    if np.isscalar(patch_stride):
        patch_stride = [patch_stride] * len(patch_size)
    assert len(patch_size) == len(patch_stride)
    prior_gen = patch(
        prior_vol,
        [*patch_size, nb_channels],
        patch_stride=[*patch_stride, nb_channels],
        batch_size=batch_size,
        collapse_2d=collapse_2d,
        keep_vol_size=True,
        infinite=True,
        #variable_batch_size=True,  # this
        nb_labels_reshape=0)
    # assert next(prior_gen) is None, "bad prior gen setup"

    # generator loop
    while True:

        # generate input and output volumes
        input_vol = next(gen)

        if verbose and np.all(input_vol.flat == 0):
            print("all entries are 0")

        # generate prior batch
        # with timer.Timer("with send?"):
        # prior_batch = prior_gen.send(input_vol.shape[0])
        prior_batch = next(prior_gen)

        if prior_feed == 'input':
            yield ([input_vol, prior_batch], input_vol)
        else:
            assert prior_feed == 'output'
            yield (input_vol, [input_vol, prior_batch])
示例#9
0
def add_prior(
        gen,
        proc_vol_fn=None,
        proc_seg_fn=None,
        prior_type='location',  # file-static, file-gen, location
        prior_file=None,  # prior filename
        prior_feed='input',  # input or output
        patch_stride=1,
        patch_size=None,
        batch_size=1,
        collapse_2d=None,
        extract_slice=None,
        force_binary=False,
        verbose=False,
        patch_rand=False,
        patch_rand_seed=None):
    """
    Add a prior generator to a given generator, with the number of patches in
    each prior batch matching the output of gen.

    Parameters:
        gen: generator whose samples are paired with prior batches
        proc_vol_fn, proc_seg_fn, prior_feed, verbose: accepted but not used
            in this function
        prior_type: 'location' builds the prior from nd.volsize2ndgrid,
            'file' loads the 'prior' entry of np.load(prior_file), anything
            else treats prior_file as the prior array itself
        prior_file: filename or array, see prior_type
        patch_stride: stride per spatial dimension; a scalar is applied to
            every dimension of patch_size
        patch_size: spatial patch size; defaults to prior_vol.shape[0:3]
        force_binary: collapse channels 1.. of the prior into channel 1
        extract_slice: int or slice(s) taken along the third axis of the prior
        batch_size, collapse_2d, patch_rand, patch_rand_seed: forwarded to
            patch

    Yields:
        (gen_sample, prior_batch)
    """

    # get prior
    if prior_type == 'location':
        # NOTE(review): vol_size is not defined in this function; unless it
        # exists at module level this branch raises NameError -- TODO confirm
        prior_vol = nd.volsize2ndgrid(vol_size)
        prior_vol = np.transpose(prior_vol, [1, 2, 3, 0])
        prior_vol = np.expand_dims(prior_vol, axis=0)  # reshape for model

    elif prior_type == 'file':  # assumes a npz filename passed in prior_file
        with timer.Timer('loading prior', True):
            data = np.load(prior_file)
            prior_vol = data['prior'].astype('float16')

    else:  # assumes a volume
        with timer.Timer('loading prior', True):
            prior_vol = prior_file.astype('float16')

    if force_binary:
        nb_labels = prior_vol.shape[-1]
        # channel 1 becomes the sum of channels 1.., the rest is dropped
        prior_vol[:, :, :, 1] = np.sum(prior_vol[:, :, :, 1:nb_labels], 3)
        prior_vol = np.delete(prior_vol, range(2, nb_labels), 3)

    nb_channels = prior_vol.shape[-1]

    if extract_slice is not None:
        if isinstance(extract_slice, int):
            # keep the sliced axis as a singleton dimension
            prior_vol = prior_vol[:, :, extract_slice, np.newaxis, :]
        else:  # assume slices
            prior_vol = prior_vol[:, :, extract_slice, :]

    # get the prior to have the right volume [x, y, z, nb_channels]
    assert np.ndim(prior_vol) == 4 or np.ndim(
        prior_vol) == 3, "prior is the wrong size"

    # prior generator
    if patch_size is None:
        patch_size = prior_vol.shape[0:3]
    # the scalar default stride (1) has no len(); expand it per dimension
    if np.isscalar(patch_stride):
        patch_stride = [patch_stride] * len(patch_size)
    assert len(patch_size) == len(patch_stride)
    prior_gen = patch(prior_vol, [*patch_size, nb_channels],
                      patch_stride=[*patch_stride, nb_channels],
                      batch_size=batch_size,
                      collapse_2d=collapse_2d,
                      keep_vol_size=True,
                      infinite=True,
                      patch_rand=patch_rand,
                      patch_rand_seed=patch_rand_seed,
                      variable_batch_size=True,
                      nb_labels_reshape=0)
    # prime the generator so that send() can be used below. This must happen
    # outside the assert: under `python -O` asserts are stripped, and send()
    # on a generator that was never started raises TypeError
    primed = next(prior_gen)
    assert primed is None, "bad prior gen setup"

    # generator loop
    while True:

        # generate input and output volumes
        gen_sample = next(gen)

        # generate prior batch
        # the value from _get_shape is sent to the prior generator
        # (presumably the number of patches to produce -- verify against patch)
        gs_sample = _get_shape(gen_sample)
        prior_batch = prior_gen.send(gs_sample)

        yield (gen_sample, prior_batch)