Beispiel #1
0
    def on_metric_call(self, epoch, iter, logs=None):
        """ compute metrics on several predictions

        Runs ``_generate_predictions`` on ``self.data_generator`` and
        evaluates every function in ``self.metrics`` on each yielded
        ``(vol_true, vol_pred)`` pair.

        Args:
            epoch: epoch number, substituted into ``self.filepath``.
            iter: iteration number, substituted into ``self.filepath``.
            logs: dict that receives the per-label mean of each metric
                (keys ``'<metric name>_label_<label id>'``) when
                ``self.filepath`` is None. When omitted, a fresh dict is
                used, so separate calls never share state via the default.

        When ``self.filepath`` is set, one CSV per metric is written instead
        (rows = samples, columns = labels) and ``logs`` is not modified.
        """
        # avoid a shared mutable default argument: it is written to below
        if logs is None:
            logs = {}

        with timer.Timer('predict metrics callback', self.verbose):

            # one row per sample, one column per label, one slice per metric
            met = np.zeros(
                (self.nb_samples, self.nb_labels, len(self.metrics)))

            # generate predictions
            # the idea is to predict either a full volume or just a slice,
            # depending on what we need
            gen = _generate_predictions(self.model, self.data_generator,
                                        self.batch_size, self.nb_samples,
                                        self.vol_params)
            # each metric is assumed to return one value per label (or a
            # scalar that broadcasts) -- required by the slice assignment
            for sample_idx, (vol_true, vol_pred) in enumerate(gen):
                for metric_idx, metric in enumerate(self.metrics):
                    met[sample_idx, :, metric_idx] = metric(vol_true, vol_pred)

            if self.filepath is not None:
                # write each metric's (samples x labels) matrix to its own csv
                for metric_idx, metric in enumerate(self.metrics):
                    filen = self.filepath.format(epoch=epoch,
                                                 iter=iter,
                                                 metric=metric.__name__)
                    np.savetxt(filen, met[:, :, metric_idx], fmt='%f',
                               delimiter=',')
            else:
                # average over samples (ignoring NaNs) and log per label
                meanmet = np.nanmean(met, axis=0)
                for metric_idx, metric in enumerate(self.metrics):
                    for label_idx in range(self.nb_labels):
                        varname = '%s_label_%d' % (metric.__name__,
                                                   self.label_ids[label_idx])
                        logs[varname] = meanmet[label_idx, metric_idx]
Beispiel #2
0
def next_pred_label(model, data_generator, verbose=False):
    """
    Pull one batch from ``data_generator``, predict it, and derive max labels.

    Returns the tuple ``(sample, pred) + pred_to_label(sample_input, pred)``,
    where ``sample_input`` is ``sample[0]``, or its first element when
    ``sample[0]`` is a list/tuple of model inputs.
    """
    batch = next(data_generator)
    with timer.Timer('prediction', verbose):
        prediction = model.predict(batch[0])

    # multi-input models receive a list/tuple; labels come from the first input
    if isinstance(batch[0], (list, tuple)):
        first_input = batch[0][0]
    else:
        first_input = batch[0]

    label_maps = pred_to_label(first_input, prediction)
    return (batch, prediction) + label_maps
Beispiel #3
0
    def on_model_save(self, epoch, iter, logs=None):
        """ save the model to hdf5 every ``self.period`` calls.

        Code mostly from keras core (ModelCheckpoint). The object saved is
        ``self.model.layers[-(num_outputs + 1)]``, i.e. the layer preceding
        the last ``len(self.model.outputs)`` layers (presumably a wrapped
        template model -- confirm). With ``self.save_best_only`` the save
        happens only when ``self.monitor_op(logs[self.monitor], self.best)``
        holds, in which case ``self.best`` is updated.
        """
        with timer.Timer('model save callback', self.verbose):
            logs = logs or {}
            num_outputs = len(self.model.outputs)

            def _write(path):
                # shared by both save paths: weights-only or the full model
                if self.save_weights_only:
                    self.model.layers[-(num_outputs + 1)].save_weights(
                        path, overwrite=True)
                else:
                    self.model.layers[-(num_outputs + 1)].save(
                        path, overwrite=True)

            self.epochs_since_last_save += 1
            if self.epochs_since_last_save < self.period:
                return
            self.epochs_since_last_save = 0

            filepath = self.filepath.format(epoch=epoch, iter=iter, **logs)

            if not self.save_best_only:
                if self.verbose > 0:
                    print('Epoch %05d: saving model to %s' %
                          (epoch, filepath))
                _write(filepath)
                return

            current = logs.get(self.monitor)
            if current is None:
                warnings.warn(
                    'Can save best model only with %s available, '
                    'skipping.' % (self.monitor), RuntimeWarning)
                return

            if not self.monitor_op(current, self.best):
                if self.verbose > 0:
                    print(
                        'Epoch %05d Iter%05d: %s did not improve' %
                        (epoch, iter, self.monitor))
                return

            if self.verbose > 0:
                print(
                    'Epoch %05d: Iter%05d: %s improved from %0.5f to %0.5f,'
                    ' saving model to %s' %
                    (epoch, iter, self.monitor, self.best,
                     current, filepath))
            self.best = current
            _write(filepath)
Beispiel #4
0
def next_vol_pred(model, data_generator, verbose=False):
    """
    Fetch the next batch and run the model on its inputs.

    Returns ``(input_vol, y_true, y_pred)``; when the batch's input entry is
    a list/tuple (e.g. input volume plus prior), its first element is used as
    ``input_vol`` and its second is appended: ``(input_vol, y_true, y_pred,
    prior)``.
    """
    batch = next(data_generator)
    model_in = batch[0]
    with timer.Timer('prediction', verbose):
        y_pred = model.predict(model_in)

    if not isinstance(model_in, (list, tuple)):
        return (model_in, batch[1], y_pred)

    # list/tuple input: second entry is presumably the prior -- confirm
    return (model_in[0], batch[1], y_pred, model_in[1])
Beispiel #5
0
    def on_plot_save(self, epoch, iter, logs=None):
        """ render example predictions and save each resulting figure.

        Calls ``neuron.sandbox.show_example_prediction_result`` and writes
        every figure it returns (all elements after the first) to
        ``self.savefilepath`` formatted with ``epoch``, ``iter``,
        ``axis='dirn_<idx>'`` and ``slice_nr=0``. Every saved figure is then
        closed so repeated callbacks do not accumulate open figures.

        ``logs`` is accepted for callback-signature compatibility and unused.
        """
        # import neuron sandbox
        # has to be here, can't be at the top, due to cyclical imports (??)
        # TODO: should just pass the function to compute the figures given the model and generator
        import neuron.sandbox as nrn_sandbox
        reload(nrn_sandbox)

        with timer.Timer('plot callback', self.verbose):
            # grid_size with 3 entries: collapse along each axis; else only axis 2
            if len(self.run.grid_size) == 3:
                collapse_2d = [0, 1, 2]
            else:
                collapse_2d = [2]

            exampl = nrn_sandbox.show_example_prediction_result(
                self.model,
                self.generator,
                self.run,
                self.data,
                test_batch_size=1,
                test_model_names=None,
                test_grid_size=self.run.grid_size,
                ccmap=None,
                collapse_2d=collapse_2d,
                slice_nr=None,
                plt_width=17,
                verbose=self.verbose)

            # save, then close: the first returned element is skipped and the
            # remaining ones are treated as matplotlib figures
            slice_nr = 0
            for idx, fig in enumerate(exampl[1:]):
                filename = self.savefilepath.format(epoch=epoch,
                                                    iter=iter,
                                                    axis="dirn_%d" % idx,
                                                    slice_nr=slice_nr)
                fig.savefig(filename)
                # close each figure explicitly; a bare plt.close() only
                # closes the current figure and leaks the others
                plt.close(fig)
Beispiel #6
0
def predict_volumes(
        models,
        data_generator,
        batch_size,
        patch_size,
        patch_stride,
        grid_size,
        nan_func=np.nanmedian,
        do_extra_vol=False,  # should compute vols beyond label
        do_prob_of_true=False,  # should compute prob_of_true vols
        verbose=False):
    """
    Predict patch stacks with one or more models and quilt them into volumes.

    Note: we allow models to be a list or a single model.
    Normally, if you'd like to run a function over a list for some param,
    you can simply loop outside of the function. Here, however, we are
    dealing with a generator, and want the output of that generator to be
    consistent for each model.

    Args:
        models: a single model, or a list/tuple of models.
        data_generator: generator passed to ``predict_volume_stack``.
        batch_size: batch size passed to ``predict_volume_stack``.
        patch_size, grid_size, patch_stride: passed to ``_quilt`` to
            reassemble patches into volumes (``grid_size`` is also passed to
            ``predict_volume_stack``).
        nan_func: passed to ``_quilt`` as both ``nan_func_layers`` and
            ``nan_func_K`` when quilting label and probability volumes.
        do_extra_vol: also return the quilted input volume and, if the stack
            carries a prior, the quilted prior label volume.
        do_prob_of_true: also return probability-of-true-label volumes.
            Only honored when ``do_extra_vol`` is True.
        verbose: forwarded to the timer, ``predict_volume_stack`` and the
            label/probability ``_quilt`` calls.

    Returns:
        Per model, a tuple
            (true_label, pred_label, <vol>, <prior_label>,
             <pred_prob_of_true>, <prior_prob_of_true>)
        where bracketed entries are present only when requested via the
        flags above (prior entries additionally require the stack to carry a
        prior). Label volumes are cast to int.
        If ``models`` holds more than one model, a tuple of such per-model
        tuples is returned; with a single model, that model's tuple itself.
    """

    if not isinstance(models, (list, tuple)):
        models = (models, )
    nb_models = len(models)

    # get the input and prediction stack
    with timer.Timer('predict_volume_stack', verbose):
        vol_stack = predict_volume_stack(models, data_generator, batch_size,
                                         grid_size, verbose)

    # with one model the stack comes back directly; otherwise one per model
    stacks = (vol_stack, ) if nb_models == 1 else vol_stack
    # a 4-element stack carries a prior after (true, pred, vol)
    do_prior = len(stacks[0]) == 4

    # quilting arguments do not depend on the model: build them once
    args = (patch_size, grid_size, patch_stride)
    label_kwargs = {
        'nan_func_layers': nan_func,
        'nan_func_K': nan_func,
        'verbose': verbose
    }

    # go through models and volumes
    ret = ()
    for midx in range(nb_models):
        stack = stacks[midx]

        if do_prior:
            all_true, all_pred, all_vol, all_prior = stack
        else:
            all_true, all_pred, all_vol = stack

        # get max labels
        all_true_label, all_pred_label = pred_to_label(all_true, all_pred)

        # quilt volumes and aggregate overlapping patches, if any
        vol_true_label = _quilt(all_true_label, *args,
                                **label_kwargs).astype('int')
        vol_pred_label = _quilt(all_pred_label, *args,
                                **label_kwargs).astype('int')

        ret_set = (vol_true_label, vol_pred_label)

        if do_extra_vol:
            vol_input = _quilt(all_vol, *args)
            ret_set += (vol_input, )

            if do_prior:
                all_prior_label, = pred_to_label(all_prior)
                vol_prior_label = _quilt(all_prior_label, *args,
                                         **label_kwargs).astype('int')
                ret_set += (vol_prior_label, )

        # compute the probability of prediction and prior
        # instead of quilting the probabilistic volumes and then computing the probability
        # of true label, which takes a long time, we'll first compute the probability of label,
        # and then quilt. This is faster, but we'll need to take median votes
        if do_extra_vol and do_prob_of_true:
            all_pp = prob_of_label(all_pred, all_true_label)
            pred_prob_of_true = _quilt(all_pp, *args, **label_kwargs)
            ret_set += (pred_prob_of_true, )

            if do_prior:
                all_pp = prob_of_label(all_prior, all_true_label)
                prior_prob_of_true = _quilt(all_pp, *args, **label_kwargs)
                ret_set += (prior_prob_of_true, )

        ret += (ret_set, )

    # a single model returns its own tuple rather than a 1-tuple of tuples
    return ret[0] if nb_models == 1 else ret