def on_metric_call(self, epoch, iter, logs=None):
    """
    Compute metrics on several predictions.

    Runs ``_generate_predictions`` for ``self.nb_samples`` samples and evaluates
    every function in ``self.metrics`` on each ``(vol_true, vol_pred)`` pair.

    If ``self.filepath`` is set, one CSV file per metric is written (rows are
    samples, columns are labels). Otherwise the per-label mean over samples
    (NaNs ignored) is stored in ``logs`` under ``'<metric>_label_<label_id>'``.

    Args:
        epoch: current epoch, used to format ``self.filepath``.
        iter: current iteration, used to format ``self.filepath``.
        logs: dict updated in place with the mean metrics (only when
            ``self.filepath`` is None). A fresh dict is used when omitted.
    """
    # A ``{}`` default would be shared by every call and silently accumulate
    # the metric entries written below; create a fresh dict instead.
    if logs is None:
        logs = {}

    with timer.Timer('predict metrics callback', self.verbose):
        # metric array indexed as (sample, label, metric)
        met = np.zeros((self.nb_samples, self.nb_labels, len(self.metrics)))

        # generate predictions
        # the idea is to predict either a full volume or just a slice,
        # depending on what we need
        gen = _generate_predictions(self.model,
                                    self.data_generator,
                                    self.batch_size,
                                    self.nb_samples,
                                    self.vol_params)
        for batch_idx, (vol_true, vol_pred) in enumerate(gen):
            for idx, metric in enumerate(self.metrics):
                met[batch_idx, :, idx] = metric(vol_true, vol_pred)

        if self.filepath is not None:
            # write each metric to its own csv file
            for idx, metric in enumerate(self.metrics):
                filen = self.filepath.format(epoch=epoch, iter=iter, metric=metric.__name__)
                np.savetxt(filen, met[:, :, idx], fmt='%f', delimiter=',')
        else:
            # report the mean over samples for each (metric, label) pair
            meanmet = np.nanmean(met, axis=0)
            for midx, metric in enumerate(self.metrics):
                for idx in range(self.nb_labels):
                    varname = '%s_label_%d' % (metric.__name__, self.label_ids[idx])
                    logs[varname] = meanmet[idx, midx]
def next_pred_label(model, data_generator, verbose=False):
    """
    Pull one batch from ``data_generator``, predict on it, and compute max labels.

    Returns:
        ``(sample, prediction) + max_labels``, where ``max_labels`` is the tuple
        returned by ``pred_to_label`` for the batch input and the prediction.
    """
    batch = next(data_generator)
    model_input = batch[0]

    with timer.Timer('prediction', verbose):
        prediction = model.predict(model_input)

    # a list/tuple input contributes only its first element to pred_to_label
    if isinstance(model_input, (list, tuple)):
        first_input = model_input[0]
    else:
        first_input = model_input

    # NOTE(review): pred_to_label receives the model *input* here, not the
    # target batch[1]; confirm this is intended.
    return (batch, prediction) + pred_to_label(first_input, prediction)
def on_model_save(self, epoch, iter, logs=None):
    """
    Save the model to hdf5 once every ``self.period`` calls. Code mostly from keras core.

    The object saved is ``self.model.layers[-(len(self.model.outputs) + 1)]``
    (only its weights when ``self.save_weights_only``). With
    ``self.save_best_only``, saving happens only when
    ``self.monitor_op(logs[self.monitor], self.best)`` holds, and ``self.best``
    is then updated; a RuntimeWarning is issued if the monitored value is missing.
    """
    with timer.Timer('model save callback', self.verbose):
        logs = logs or {}
        num_outputs = len(self.model.outputs)

        self.epochs_since_last_save += 1
        if self.epochs_since_last_save < self.period:
            return
        self.epochs_since_last_save = 0

        filepath = self.filepath.format(epoch=epoch, iter=iter, **logs)

        def _save():
            # layer lookup deferred until a save actually happens
            target = self.model.layers[-(num_outputs + 1)]
            if self.save_weights_only:
                target.save_weights(filepath, overwrite=True)
            else:
                target.save(filepath, overwrite=True)

        if not self.save_best_only:
            if self.verbose > 0:
                print('Epoch %05d: saving model to %s' % (epoch, filepath))
            _save()
            return

        current = logs.get(self.monitor)
        if current is None:
            warnings.warn('Can save best model only with %s available, '
                          'skipping.' % (self.monitor), RuntimeWarning)
        elif self.monitor_op(current, self.best):
            if self.verbose > 0:
                print('Epoch %05d: Iter%05d: %s improved from %0.5f to %0.5f,'
                      ' saving model to %s'
                      % (epoch, iter, self.monitor, self.best, current, filepath))
            self.best = current
            _save()
        elif self.verbose > 0:
            print('Epoch %05d Iter%05d: %s did not improve' % (epoch, iter, self.monitor))
def next_vol_pred(model, data_generator, verbose=False):
    """
    Fetch the next batch from ``data_generator`` and predict the model output on it.

    Returns:
        ``(input_vol, y_true, y_pred)``; when the batch input is a list/tuple,
        ``(input[0], y_true, y_pred, input[1])``, the second input entry being
        treated as a prior.
    """
    batch = next(data_generator)
    inputs = batch[0]

    with timer.Timer('prediction', verbose):
        y_pred = model.predict(inputs)

    if not isinstance(inputs, (list, tuple)):
        return (inputs, batch[1], y_pred)

    # given a prior, the input might be a list: [volume, prior]
    return (inputs[0], batch[1], y_pred, inputs[1])
def _load_medical_volume(filename, ext, verbose=False):
    """
    Load a medical volume from one of a number of file types.

    Args:
        filename: path of the file to load.
        ext: extension selecting the loader:
            '.npz'  -> array stored under the key 'vol_data';
            '.npy'  -> plain numpy array ('npy' without the dot is also
                       accepted for backward compatibility);
            '.mgz', '.nii', '.nii.gz' -> loaded with nibabel.
        verbose: passed to ``timer.Timer`` as ``verbose >= 2``.

    Returns:
        the loaded volume data.

    Raises:
        ValueError: if ``ext`` is not a supported extension.
    """
    with timer.Timer('load_vol', verbose >= 2):
        if ext == '.npz':
            # NpzFile keeps the file open; close it once the array is read
            with np.load(filename) as vol_file:
                vol_data = vol_file['vol_data']
        elif ext in ('.npy', 'npy'):
            # only 'npy' (no dot) used to match, so '.npy' fell through to the
            # ValueError unlike every other extension; accept both
            vol_data = np.load(filename)
        elif ext in ('.mgz', '.nii', '.nii.gz'):
            vol_med = nib.load(filename)
            # documented replacement for the removed img.get_data()
            vol_data = np.asanyarray(vol_med.dataobj)
        else:
            raise ValueError("Unexpected extension %s" % ext)

    return vol_data
def on_plot_save(self, epoch, iter, logs=None):
    """
    Plot example predictions and save every resulting figure to disk.

    Figures are all elements after the first returned by
    ``show_example_prediction_result``. Figure ``idx`` is written to
    ``self.savefilepath.format(epoch=epoch, iter=iter, axis='dirn_<idx>', slice_nr=0)``
    and then closed.

    Args:
        epoch: current epoch (used in the file name).
        iter: current iteration (used in the file name).
        logs: unused; accepted for callback-signature compatibility.
    """
    # TODO: should just pass the function to compute the figures given the model and generator
    with timer.Timer('plot callback', self.verbose):
        # three-element grid sizes use collapse_2d=[0, 1, 2], anything else [2]
        collapse_2d = [0, 1, 2] if len(self.run.grid_size) == 3 else [2]

        # TODO: show_example_prediction_result is actually in neuron_sandbox for now
        exampl = show_example_prediction_result(self.model,
                                                self.generator,
                                                self.run,
                                                self.data,
                                                test_batch_size=1,
                                                test_model_names=None,
                                                test_grid_size=self.run.grid_size,
                                                ccmap=None,
                                                collapse_2d=collapse_2d,
                                                slice_nr=None,
                                                plt_width=17,
                                                verbose=self.verbose)

        # save, then close each figure: a bare plt.close() only closes the
        # *current* figure, leaving the others open (and leaking memory)
        slice_nr = 0
        for idx, fig in enumerate(exampl[1:]):
            dirn = "dirn_%d" % idx
            filename = self.savefilepath.format(epoch=epoch, iter=iter, axis=dirn, slice_nr=slice_nr)
            fig.savefig(filename)
            plt.close(fig)
def predict_volumes(models,
                    data_generator,
                    batch_size,
                    patch_size,
                    patch_stride,
                    grid_size,
                    nan_func=np.nanmedian,
                    do_extra_vol=False,  # should compute vols beyond label
                    do_prob_of_true=False,  # should compute prob_of_true vols
                    verbose=False):
    """
    Predict patches from a generator and quilt them back into full volumes.

    ``models`` may be a single model or a list/tuple of models. Looping over
    models outside this function would not work, because every model must see
    the same generator output; hence all stacks are produced together.

    Returns:
        with more than one model: a tuple holding one entry per model;
        with a single model: that single entry.
        Each entry is a tuple of
            (true_label, pred_label, <vol>, <prior_label>,
             <pred_prob_of_true>, <prior_prob_of_true>)
        where ``<vol>`` requires ``do_extra_vol``; ``<prior_label>`` requires
        ``do_extra_vol`` and a prior in the stack; the prob_of_true entries
        require ``do_extra_vol`` and ``do_prob_of_true`` (the prior one also
        requiring a prior).

    TODO: could add prior
    """
    if not isinstance(models, (list, tuple)):
        models = (models, )
    single_model = len(models) == 1

    # get the input and prediction stack
    with timer.Timer('predict_volume_stack', verbose):
        vol_stack = predict_volume_stack(models, data_generator, batch_size, grid_size, verbose)

    # a 4-element stack carries a prior next to (true, pred, vol)
    first_stack = vol_stack if single_model else vol_stack[0]
    do_prior = len(first_stack) == 4

    quilt_args = [patch_size, grid_size, patch_stride]
    label_kwargs = {'nan_func_layers': nan_func, 'nan_func_K': nan_func, 'verbose': verbose}

    def quilt_labels(patches):
        # quilt label patches (aggregating overlaps) into an int volume
        return _quilt(patches, *quilt_args, **label_kwargs).astype('int')

    results = []
    for midx in range(len(models)):
        stack = vol_stack if single_model else vol_stack[midx]
        if do_prior:
            all_true, all_pred, all_vol, all_prior = stack
        else:
            all_true, all_pred, all_vol = stack

        # get max labels
        all_true_label, all_pred_label = pred_to_label(all_true, all_pred)

        # quilt volumes and aggregate overlapping patches, if any
        entry = [quilt_labels(all_true_label), quilt_labels(all_pred_label)]

        if do_extra_vol:
            entry.append(_quilt(all_vol, *quilt_args))
            if do_prior:
                all_prior_label, = pred_to_label(all_prior)
                entry.append(quilt_labels(all_prior_label))

        # Probability of the true label: computed per patch and then quilted,
        # which is faster than quilting the probabilistic volumes first, at
        # the cost of taking median votes.
        if do_extra_vol and do_prob_of_true:
            pred_pp = prob_of_label(all_pred, all_true_label)
            entry.append(_quilt(pred_pp, *quilt_args, **label_kwargs))
            if do_prior:
                prior_pp = prob_of_label(all_prior, all_true_label)
                entry.append(_quilt(prior_pp, *quilt_args, **label_kwargs))

        results.append(tuple(entry))

    return results[0] if single_model else tuple(results)
def vol_prior_hack(*args,
                   proc_vol_fn=None,
                   proc_seg_fn=None,
                   prior_type='location',  # 'location', 'file', or anything else (prior_file is the volume)
                   prior_file=None,  # prior filename (or prior volume)
                   prior_feed='input',  # input or output
                   patch_stride=1,
                   patch_size=None,
                   batch_size=1,
                   collapse_2d=None,
                   extract_slice=None,
                   force_binary=False,
                   nb_input_feats=1,
                   verbose=False,
                   vol_rand_seed=None,
                   **kwargs):
    """
    Generator pairing volumes from ``vol_seg_hack`` with patches of a prior volume.

    ``*args`` and extra ``**kwargs`` are forwarded to ``vol_seg_hack``.

    The prior is chosen by ``prior_type``:
        'location': built with ``nd.volsize2ndgrid(vol_size)`` (see NOTE in body);
        'file': read from the npz file ``prior_file`` (key 'prior'), cast to float16;
        otherwise: ``prior_file`` is the prior array itself, cast to float16 if needed.
    With ``force_binary``, channels 1.. are summed into channel 1 and the rest
    dropped, leaving two channels.

    Args:
        patch_stride: per-dimension stride, or a single int applied to every
            patch dimension.

    Yields:
        ``([input_vol, prior_batch], input_vol)`` when ``prior_feed == 'input'``,
        ``(input_vol, [input_vol, prior_batch])`` when ``prior_feed == 'output'``.

    Note:
        ``proc_vol_fn`` and ``proc_seg_fn`` are accepted but not forwarded:
        ``vol_seg_hack`` is always called with both set to None.
    """
    # prepare the vol_seg
    # NOTE(review): the caller's proc_vol_fn / proc_seg_fn are dropped here;
    # confirm whether that is intended.
    gen = vol_seg_hack(*args, **kwargs,
                       proc_vol_fn=None,
                       proc_seg_fn=None,
                       collapse_2d=collapse_2d,
                       extract_slice=extract_slice,
                       force_binary=force_binary,
                       verbose=verbose,
                       patch_size=patch_size,
                       patch_stride=patch_stride,
                       batch_size=batch_size,
                       vol_rand_seed=vol_rand_seed,
                       nb_input_feats=nb_input_feats)

    # get prior
    if prior_type == 'location':
        # NOTE(review): vol_size is not defined in this function; unless a
        # module-level vol_size exists, this branch raises NameError.
        prior_vol = nd.volsize2ndgrid(vol_size)
        prior_vol = np.transpose(prior_vol, [1, 2, 3, 0])
        prior_vol = np.expand_dims(prior_vol, axis=0)  # reshape for model

    elif prior_type == 'file':  # assumes a npz filename passed in prior_file
        with timer.Timer('loading prior', True):
            # close the NpzFile handle once the array has been read
            with np.load(prior_file) as data:
                prior_vol = data['prior'].astype('float16')

    else:  # assumes a volume
        with timer.Timer('astyping prior', verbose):
            prior_vol = prior_file
            if not (prior_vol.dtype == 'float16'):
                prior_vol = prior_vol.astype('float16')

    if force_binary:
        # the in-place write below must not modify the caller's array, which
        # prior_vol still is when prior_file was already float16
        if prior_vol is prior_file:
            prior_vol = prior_vol.copy()
        nb_labels = prior_vol.shape[-1]
        prior_vol[:, :, :, 1] = np.sum(prior_vol[:, :, :, 1:nb_labels], 3)
        prior_vol = np.delete(prior_vol, range(2, nb_labels), 3)

    nb_channels = prior_vol.shape[-1]

    if extract_slice is not None:
        if isinstance(extract_slice, int):
            prior_vol = prior_vol[:, :, extract_slice, np.newaxis, :]
        else:  # assume slices
            prior_vol = prior_vol[:, :, extract_slice, :]

    # get the prior to have the right volume [x, y, z, nb_channels]
    assert np.ndim(prior_vol) == 4 or np.ndim(prior_vol) == 3, "prior is the wrong size"

    # prior generator
    if patch_size is None:
        patch_size = prior_vol.shape[0:3]
    if np.ndim(patch_stride) == 0:
        # a scalar stride (the default 1) would make len() and [*patch_stride]
        # below raise TypeError; apply it to every patch dimension instead
        patch_stride = [patch_stride] * len(patch_size)
    assert len(patch_size) == len(patch_stride)

    # Unlike add_prior, this generator is not created with
    # variable_batch_size=True and is consumed with next() rather than send().
    prior_gen = patch(prior_vol, [*patch_size, nb_channels],
                      patch_stride=[*patch_stride, nb_channels],
                      batch_size=batch_size,
                      collapse_2d=collapse_2d,
                      keep_vol_size=True,
                      infinite=True,
                      nb_labels_reshape=0)

    # generator loop
    while True:
        # generate input and output volumes
        input_vol = next(gen)

        if verbose and np.all(input_vol.flat == 0):
            print("all entries are 0")

        # generate prior batch
        prior_batch = next(prior_gen)

        if prior_feed == 'input':
            yield ([input_vol, prior_batch], input_vol)
        else:
            assert prior_feed == 'output'
            yield (input_vol, [input_vol, prior_batch])
def add_prior(gen,
              proc_vol_fn=None,
              proc_seg_fn=None,
              prior_type='location',  # 'location', 'file', or anything else (prior_file is the volume)
              prior_file=None,  # prior filename (or prior volume)
              prior_feed='input',  # input or output
              patch_stride=1,
              patch_size=None,
              batch_size=1,
              collapse_2d=None,
              extract_slice=None,
              force_binary=False,
              verbose=False,
              patch_rand=False,
              patch_rand_seed=None):
    """
    Add a prior generator to a given generator, with the number of patches in
    each prior batch matching the output of ``gen``.

    The prior is chosen by ``prior_type``:
        'location': built with ``nd.volsize2ndgrid(vol_size)`` (see NOTE in body);
        'file': read from the npz file ``prior_file`` (key 'prior'), cast to float16;
        otherwise: ``prior_file`` is the prior array itself, cast to float16.
    With ``force_binary``, channels 1.. are summed into channel 1 and the rest
    dropped, leaving two channels.

    Args:
        gen: generator whose samples are paired with prior batches.
        patch_stride: per-dimension stride, or a single int applied to every
            patch dimension.

    Yields:
        ``(gen_sample, prior_batch)`` where the prior batch is requested by
        sending ``_get_shape(gen_sample)`` to the prior patch generator.

    Note:
        ``proc_vol_fn``, ``proc_seg_fn``, ``prior_feed`` and ``verbose`` are
        accepted but not used.
    """
    # get prior
    if prior_type == 'location':
        # NOTE(review): vol_size is not defined in this function; unless a
        # module-level vol_size exists, this branch raises NameError.
        prior_vol = nd.volsize2ndgrid(vol_size)
        prior_vol = np.transpose(prior_vol, [1, 2, 3, 0])
        prior_vol = np.expand_dims(prior_vol, axis=0)  # reshape for model

    elif prior_type == 'file':  # assumes a npz filename passed in prior_file
        with timer.Timer('loading prior', True):
            # close the NpzFile handle once the array has been read
            with np.load(prior_file) as data:
                prior_vol = data['prior'].astype('float16')

    else:  # assumes a volume
        with timer.Timer('loading prior', True):
            # astype copies, so the in-place force_binary edit below never
            # touches the caller's array
            prior_vol = prior_file.astype('float16')

    if force_binary:
        nb_labels = prior_vol.shape[-1]
        prior_vol[:, :, :, 1] = np.sum(prior_vol[:, :, :, 1:nb_labels], 3)
        prior_vol = np.delete(prior_vol, range(2, nb_labels), 3)

    nb_channels = prior_vol.shape[-1]

    if extract_slice is not None:
        if isinstance(extract_slice, int):
            prior_vol = prior_vol[:, :, extract_slice, np.newaxis, :]
        else:  # assume slices
            prior_vol = prior_vol[:, :, extract_slice, :]

    # get the prior to have the right volume [x, y, z, nb_channels]
    assert np.ndim(prior_vol) == 4 or np.ndim(prior_vol) == 3, "prior is the wrong size"

    # prior generator
    if patch_size is None:
        patch_size = prior_vol.shape[0:3]
    if np.ndim(patch_stride) == 0:
        # a scalar stride (the default 1) would make len() and [*patch_stride]
        # below raise TypeError; apply it to every patch dimension instead
        patch_stride = [patch_stride] * len(patch_size)
    assert len(patch_size) == len(patch_stride)

    prior_gen = patch(prior_vol, [*patch_size, nb_channels],
                      patch_stride=[*patch_stride, nb_channels],
                      batch_size=batch_size,
                      collapse_2d=collapse_2d,
                      keep_vol_size=True,
                      infinite=True,
                      patch_rand=patch_rand,
                      patch_rand_seed=patch_rand_seed,
                      variable_batch_size=True,
                      nb_labels_reshape=0)

    # Prime the generator so it can accept send(). This must not live inside
    # the assert: `python -O` strips assert statements, which would skip the
    # priming and make the first send() fail.
    first = next(prior_gen)
    assert first is None, "bad prior gen setup"

    # generator loop
    while True:
        # generate input and output volumes
        gen_sample = next(gen)

        # generate prior batch, requested with the sample's shape
        gs_sample = _get_shape(gen_sample)
        prior_batch = prior_gen.send(gs_sample)

        yield (gen_sample, prior_batch)