def train(solver, test_net, data_arrays, train_data_arrays, options):
    caffe.select_device(options.train_device, False)
    net = solver.net

    test_eval = None
    if (options.test_net != None):
        test_eval = TestNetEvaluator(test_net, net, train_data_arrays, options)

    input_dims, output_dims, input_padding = get_spatial_io_dims(net)
    fmaps_in, fmaps_out = get_fmap_io_dims(net)

    dims = len(output_dims)
    losses = []

    shapes = []
    # Raw data slice input (n = 1, f = 1, spatial dims)
    shapes += [[1, fmaps_in] + input_dims]
    # Label data slice input (n = 1, f = #edges, spatial dims)
    shapes += [[1, fmaps_out] + output_dims]
    if (options.loss_function == 'malis'):
        # Connected components input (n = 1, f = 1, spatial dims)
        shapes += [[1, 1] + output_dims]
    if (options.loss_function == 'euclid'):
        # Error scale input (n = 1, f = #edges, spatial dims)
        shapes += [[1, fmaps_out] + output_dims]
    # Nhood specifications (n = #edges, f = 3)
    if (('nhood' in data_arrays[0]) and (options.loss_function == 'malis')):
        shapes += [[1, 1] + list(np.shape(data_arrays[0]['nhood']))]
    net_io = NetInputWrapper(net, shapes)

    make_dataset_offset = MakeDatasetOffset(input_dims, output_dims)
    if data_io.data_loader_should_be_used_with(data_arrays):
        using_data_loader = True
        # and initialize queue!
        loader_size = 20
        n_workers = 10
        make_dataset_offset = MakeDatasetOffset(dims, output_dims, input_padding)
        loader_kwargs = dict(
            size=loader_size,
            datasets=data_arrays,
            input_shape=tuple(input_dims),
            output_shape=tuple(output_dims),
            n_workers=n_workers,
            dataset_offset_func=make_dataset_offset
        )
        print("creating queue with kwargs {}".format(loader_kwargs))
        training_data_loader = data_io.DataLoader(**loader_kwargs)
        # start populating the queue
        for i in range(loader_size):
            if DEBUG:
                print("Pre-populating data loader's dataset #{i}/{size}"
                      .format(i=i, size=training_data_loader.size))
            shared_dataset_index, async_result = \
                training_data_loader.start_refreshing_shared_dataset(i)
    else:
        using_data_loader = False

    # Loop from current iteration to last iteration
    for i in range(solver.iter, solver.max_iter):
        start = time.time()
        if (options.test_net != None and i % options.test_interval == 1):
            test_eval.evaluate(i)
            if USE_ONE_THREAD:
                # after testing finishes, switch back to the training device
                caffe.select_device(options.train_device, False)
        if not using_data_loader:
            dataset_index, offsets = make_dataset_offset(data_arrays)
            dataset = data_arrays[dataset_index]
            # These are the raw data elements
            data_slice = data_io.util.get_zero_padded_slice_from_array_by_offset(
                array=dataset['data'],
                origin=[0] + offsets,
                shape=[fmaps_in] + input_dims)
            label_slice = slice_data(
                dataset['label'],
                [0] + [offsets[di] + int(math.ceil(input_padding[di] / float(2))) for di in range(0, dims)],
                [fmaps_out] + output_dims)
            if 'transform' in dataset:
                # transform the input
                # assumes that the original input pixel values are scaled between (0,1)
                if DEBUG:
                    print("data_slice stats, pre-transform: min", data_slice.min(),
                          "mean", data_slice.mean(), "max", data_slice.max())
                lo, hi = dataset['transform']['scale']
                data_slice = 0.5 + (data_slice - 0.5) * np.random.uniform(low=lo, high=hi)
                lo, hi = dataset['transform']['shift']
                data_slice = data_slice + np.random.uniform(low=lo, high=hi)
        else:
            dataset, index_of_shared_dataset = training_data_loader.get_dataset()
            data_slice = dataset['data']
            assert data_slice.shape == (fmaps_in,) + tuple(input_dims)
            label_slice = dataset['label']
            assert label_slice.shape == (fmaps_out,) + tuple(output_dims)
            if DEBUG:
                print("Training with next dataset in data loader, which has offset",
                      dataset['offset'])

        mask_slice = None
        if 'mask' in dataset:
            mask_slice = dataset['mask']

        if DEBUG:
            print("data_slice stats: min", data_slice.min(),
                  "mean", data_slice.mean(), "max", data_slice.max())

        if options.loss_function == 'malis':
            components_slice, ccSizes = malis.connected_components_affgraph(
                label_slice.astype(np.int32), dataset['nhood'])
            # Also recomputing the corresponding labels (connected components)
            net_io.setInputs([data_slice, label_slice, components_slice, data_arrays[0]['nhood']])
        elif options.loss_function == 'euclid':
            label_slice_mean = label_slice.mean()
            if 'mask' in dataset:
                label_slice = label_slice * mask_slice
                label_slice_mean = label_slice.mean() / mask_slice.mean()
            w_pos = 1.0
            w_neg = 1.0
            if options.scale_error:
                frac_pos = np.clip(label_slice_mean, 0.05, 0.95)
                w_pos = w_pos / (2.0 * frac_pos)
                w_neg = w_neg / (2.0 * (1.0 - frac_pos))
            error_scale_slice = scale_errors(label_slice, w_neg, w_pos)
            net_io.setInputs([data_slice, label_slice, error_scale_slice])
        elif options.loss_function == 'softmax':
            # These are the affinity edge values
            net_io.setInputs([data_slice, label_slice])

        loss = solver.step(1)  # Single step
        if using_data_loader:
            training_data_loader.start_refreshing_shared_dataset(index_of_shared_dataset)

        while gc.collect():
            pass

        time_of_iteration = time.time() - start
        if options.loss_function == 'euclid' or options.loss_function == 'euclid_aniso':
            print("[Iter %i] Time: %05.2fs Loss: %f, frac_pos=%f, w_pos=%f" %
                  (i, time_of_iteration, loss, frac_pos, w_pos))
        else:
            print("[Iter %i] Time: %05.2fs Loss: %f" % (i, time_of_iteration, loss))
        losses += [loss]

        if hasattr(options, 'loss_snapshot') and ((i % options.loss_snapshot) == 0):
            io.savemat('loss.mat', {'loss': losses})

    if using_data_loader:
        training_data_loader.destroy()
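

# Hedged usage sketch (not part of the original pipeline): how train() above might be
# driven. The option fields listed are inferred from the attributes this function reads
# (train_device, test_net, test_interval, loss_function, scale_error, loss_snapshot);
# the solver path, array names, and nhood source are placeholders, not the project's API.
def _example_train_call(raw_volume, affinity_volume, nhood):
    from types import SimpleNamespace
    options = SimpleNamespace(
        train_device=0,           # GPU id handed to caffe.select_device()
        test_net=None,            # set to a test network to enable periodic evaluation
        test_interval=2000,
        loss_function='euclid',   # one of 'malis', 'euclid', 'softmax'
        scale_error=True,
        loss_snapshot=100,
    )
    solver = caffe.get_solver('solver.prototxt')  # standard pycaffe entry point
    data_arrays = [{'name': 'volume0',
                    'data': raw_volume,           # shape (fmaps_in, z, y, x), values in (0, 1)
                    'label': affinity_volume,     # shape (fmaps_out, z, y, x)
                    'nhood': nhood}]              # (n_edges, 3) array, only needed for MALIS
    train(solver, None, data_arrays, data_arrays, options)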
def train(solver, test_net, data_arrays, train_data_arrays, options):
    if DEBUG:
        data_io.logger.setLevel(logging.DEBUG)
    else:
        data_io.logger.setLevel(logging.INFO)
    caffe.select_device(options.train_device, False)
    net = solver.net

    test_eval = None
    if (options.test_net != None):
        test_eval = TestNetEvaluator(test_net, net, train_data_arrays, options)

    input_dims, output_dims, input_padding = get_spatial_io_dims(net)
    fmaps_in, fmaps_out = get_fmap_io_dims(net)

    dims = len(output_dims)
    losses = []

    shapes = []
    # Raw data slice input (n = 1, f = 1, spatial dims)
    shapes += [[1, fmaps_in] + input_dims]
    # Label data slice input (n = 1, f = #edges, spatial dims)
    shapes += [[1, fmaps_out] + output_dims]
    if (options.loss_function == 'malis'):
        # Connected components input. 2 channels, one for each phase of computation
        shapes += [[1, 2] + output_dims]
    if (options.loss_function == 'euclid'):
        # Error scale input (n = 1, f = #edges, spatial dims)
        shapes += [[1, fmaps_out] + output_dims]
    # Nhood specifications (n = #edges, f = 3)
    if (('nhood' in data_arrays[0]) and (options.loss_function == 'malis')):
        shapes += [[1, 1] + list(np.shape(data_arrays[0]['nhood']))]
    net_io = NetInputWrapper(net, shapes)

    if DEBUG:
        for key in net.blobs.keys():
            print(key, net.blobs[key].data.shape)

    make_dataset_offset = MakeDatasetOffset(input_dims, output_dims)
    if data_io.data_loader_should_be_used_with(data_arrays):
        using_data_loader = True
        # and initialize queue!
        loader_size = 20
        n_workers = 10
        loader_kwargs = dict(
            size=loader_size,
            datasets=data_arrays,
            input_shape=tuple(input_dims),
            output_shape=tuple(output_dims),
            n_workers=n_workers,
            dataset_offset_func=make_dataset_offset
        )
        print("creating queue with kwargs {}".format(loader_kwargs))
        training_data_loader = data_io.DataLoader(**loader_kwargs)
        # start populating the queue
        for i in range(loader_size):
            if DEBUG:
                print("Pre-populating data loader's dataset #{i}/{size}"
                      .format(i=i, size=training_data_loader.size))
            shared_dataset_index, async_result = \
                training_data_loader.start_refreshing_shared_dataset(i, wait=True)
    else:
        using_data_loader = False

    # Loop from current iteration to last iteration
    for i in range(solver.iter, solver.max_iter):
        start = time.time()
        if (options.test_net != None and i % options.test_interval == 1):
            if not SAVE_IMAGES:
                test_eval.evaluate(i)
            if USE_ONE_THREAD:
                # after testing finishes, switch back to the training device
                caffe.select_device(options.train_device, False)
        if not using_data_loader:
            dataset_index, offsets = make_dataset_offset(data_arrays)
            dataset = data_arrays[dataset_index]
            # These are the raw data elements
            data_slice = data_io.util.get_zero_padded_slice_from_array_by_offset(
                array=dataset['data'],
                origin=[0] + offsets,
                shape=[fmaps_in] + input_dims)
            label_slice = slice_data(
                dataset['label'],
                [0] + [offsets[di] + int(math.ceil(input_padding[di] / float(2))) for di in range(0, dims)],
                [fmaps_out] + output_dims)
            components_slice, ccSizes = malis.connected_components_affgraph(
                label_slice.astype(np.int32), dataset['nhood'])
            components_shape = (1,) + tuple(output_dims)
            components_slice = components_slice.reshape(components_shape)
            mask_slice = np.ones_like(components_slice, dtype=np.uint8)
            mask_mean = 1
            if 'transform' in dataset:
                # transform the input
                # assumes that the original input pixel values are scaled between (0,1)
                if DEBUG:
                    print("data_slice stats, pre-transform: min", data_slice.min(),
                          "mean", data_slice.mean(), "max", data_slice.max())
                lo, hi = dataset['transform']['scale']
                data_slice = 0.5 + (data_slice - 0.5) * np.random.uniform(low=lo, high=hi)
                lo, hi = dataset['transform']['shift']
                data_slice = data_slice + np.random.uniform(low=lo, high=hi)
        else:
            dataset, index_of_shared_dataset = training_data_loader.get_dataset()
            data_slice = dataset['data']
            label_slice = dataset['label']
            if DEBUG:
                print("Training with next dataset in data loader, which has offset",
                      dataset['offset'])
            mask_slice = dataset['mask']
            mask_mean = np.mean(mask_slice)
            components_slice = dataset['components']

        assert data_slice.shape == (fmaps_in,) + tuple(input_dims)
        assert label_slice.shape == (fmaps_out,) + tuple(output_dims)
        assert mask_slice.shape == (1,) + tuple(output_dims)
        assert components_slice.shape == (1,) + tuple(output_dims)

        if DEBUG:
            print("data_slice stats: min", data_slice.min(),
                  "mean", data_slice.mean(), "max", data_slice.max())
            print("mask_mean: ", mask_mean)

        if options.loss_function == 'malis':
            try:
                use_simple_malis_components = \
                    mask_mean == 1 or not options.malis_split_component_phases
            except AttributeError:
                use_simple_malis_components = mask_mean == 1
            if use_simple_malis_components:
                components_negative_slice = components_slice
            else:
                '''
                assumes that...
                * mask_slice is 1 at voxels containing good components, with a small dilation
                * components_slice does not contain component values equal to 1.
                  (Original values were incremented.)
                '''
                assert 1 not in components_slice, \
                    "components_slice can't contain a component value of 1. " \
                    "That's used for masked voxels in 'negative' MALIS training."
                mask_inverse = np.ones_like(mask_slice) - mask_slice
                # assert mask_inverse.shape == components_slice.shape
                mask_inverse = mask_inverse.astype(components_slice.dtype)
                # assert mask_inverse.dtype == components_slice.dtype
                components_negative_slice = components_slice + mask_inverse
            components_positive_slice = components_slice
            components_malis_slice = np.concatenate(
                (components_negative_slice, components_positive_slice),
                axis=0)
            net_io.setInputs([data_slice, label_slice, components_malis_slice, data_arrays[0]['nhood']])
        elif options.loss_function == 'euclid':
            if mask_mean < 1:
                label_slice = label_slice * mask_slice
                label_slice_mean = label_slice.mean() / mask_mean
            else:
                label_slice_mean = label_slice.mean()
            w_pos = 1.0
            w_neg = 1.0
            if options.scale_error:
                frac_pos = np.clip(label_slice_mean, 0.05, 0.95)
                w_pos = w_pos / (2.0 * frac_pos)
                w_neg = w_neg / (2.0 * (1.0 - frac_pos))
            error_scale_slice = scale_errors(label_slice, w_neg, w_pos)
            if mask_mean < 1:
                error_scale_slice *= mask_slice
            net_io.setInputs([data_slice, label_slice, error_scale_slice])
        elif options.loss_function == 'softmax':
            # These are the affinity edge values
            net_io.setInputs([data_slice, label_slice])

        loss = solver.step(1)  # Single step

        try:
            save_image = SAVE_IMAGES or i % options.save_image_snapshot_period == 0
        except AttributeError:
            save_image = SAVE_IMAGES
        if save_image:
            dataset_to_show = dict()
            net_prediction_shape = (1, fmaps_out,) + tuple(output_dims)
            prediction = np.zeros(net_prediction_shape, np.float32)
            for blob_key in reversed(solver.net.blobs.keys()):
                try:
                    blob_shape = solver.net.blobs[blob_key].data.shape
                except:
                    blob_shape = None
                if blob_shape == net_prediction_shape:
                    prediction = solver.net.blobs[blob_key].data.copy()
                    break  # stop checking blobs
            dataset_to_show['pred'] = prediction.reshape((fmaps_out,) + tuple(output_dims))
            try:
                import zwatershed
                (s, V) = zwatershed.zwatershed_and_metrics(
                    components_slice.reshape(output_dims).astype(np.uint32),
                    dataset_to_show['pred'],
                    [50000],
                    [50000])
                components_prediction = s[0]
            except:
                components_prediction = np.zeros(shape=output_dims, dtype=np.int32)
            dataset_to_show['predseg'] = components_prediction
            dataset_to_show['data'] = data_slice.reshape(input_dims)
            if components_slice is None:
                components_slice, _ = malis.connected_components_affgraph(
                    label_slice.astype(np.int32), dataset['nhood'])
            dataset_to_show['components'] = components_slice.reshape(output_dims)
            dataset_to_show['label'] = label_slice
            dataset_to_show['mask'] = mask_slice.reshape(output_dims)
            try:
                dataset_to_show['components_negative'] = components_negative_slice.reshape(output_dims)
                dataset_to_show['components_positive'] = components_positive_slice.reshape(output_dims)
            except UnboundLocalError:
                # variables weren't declared, because malis isn't being computed
                pass
            try:
                dataset_to_show['error_scale_slice'] = error_scale_slice
            except UnboundLocalError:
                pass
            assert dataset_to_show['data'].shape == tuple(input_dims), dataset_to_show['data'].shape
            assert dataset_to_show['label'].shape == (3,) + tuple(output_dims), dataset_to_show['label'].shape
            assert dataset_to_show['components'].shape == tuple(output_dims), dataset_to_show['components'].shape
            assert dataset_to_show['pred'].shape == (3,) + tuple(output_dims), dataset_to_show['pred'].shape
            assert dataset_to_show['predseg'].shape == tuple(output_dims), dataset_to_show['predseg'].shape
            f = visualization.showme(dataset_to_show, int(input_dims[0] / 2))
            if not os.path.exists('snapshots'):
                os.mkdir('snapshots')
            f.savefig('snapshots/%08d.png' % i)
            plt.close()

        while gc.collect():
            pass

        time_of_iteration = time.time() - start
        if 'mask' in dataset:
            mask_fraction_str = "%07.5f" % np.mean(dataset['mask'])
        else:
            mask_fraction_str = "no mask"
        if options.loss_function == 'euclid' or options.loss_function == 'euclid_aniso':
            print("[Iter %06i]" % i,
                  "Time: %05.2fs" % time_of_iteration,
                  "Loss: %08.6f" % loss,
                  "Mask:", mask_fraction_str,
                  "frac_pos=%08.6f" % frac_pos,
                  "w_pos=%08.6f" % w_pos,
                  )
        else:
            print("[Iter %06i]" % i,
                  "Time: %05.2fs" % time_of_iteration,
                  "Loss: %08.6f" % loss,
                  "Mask:", mask_fraction_str,
                  )
        losses += [loss]
        if hasattr(options, 'loss_snapshot') and ((i % options.loss_snapshot) == 0):
            io.savemat('loss.mat', {'loss': losses})

        if using_data_loader:
            training_data_loader.start_refreshing_shared_dataset(index_of_shared_dataset)

    if using_data_loader:
        training_data_loader.destroy()
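

# Hedged illustration (not used by the pipeline): a toy version of the masked "negative
# phase" component construction in the MALIS branch above. It assumes, as the code does,
# that components_slice never uses label 1 and is 0 at unlabeled voxels, so adding the
# inverted mask relabels masked-out voxels to 1 for the negative pass while the positive
# pass keeps the original components. Array values here are made up for demonstration.
def _example_malis_component_phases():
    # 1-channel toy component map over a (1, 4) spatial block, plus its mask
    components_slice = np.array([[[2, 2, 0, 3]]], dtype=np.int32)  # label 1 is never used
    mask_slice = np.array([[[1, 1, 0, 1]]], dtype=np.int32)        # 0 marks the unlabeled voxel
    mask_inverse = np.ones_like(mask_slice) - mask_slice
    components_negative = components_slice + mask_inverse.astype(components_slice.dtype)
    components_positive = components_slice
    # Channel 0 drives the "negative" MALIS phase, channel 1 the "positive" phase,
    # matching the [1, 2] + output_dims input shape declared for the malis loss above.
    components_malis = np.concatenate((components_negative, components_positive), axis=0)
    assert components_negative[0, 0, 2] == 1    # masked-out voxel becomes component 1
    assert components_malis.shape == (2, 1, 4)
    return components_malis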
def process(nets, data_arrays, shapes=None, net_io=None, zero_pad_source_data=True, target_arrays=None):
    net = None
    thread_pool = None
    device_locks = None
    if isinstance(nets, list):
        # Grab one network to figure out parameters
        net = nets[0]
    else:
        net = nets
    input_dims, output_dims, input_padding = get_spatial_io_dims(net)
    fmaps_in, fmaps_out = get_fmap_io_dims(net)
    dims = len(output_dims)

    if target_arrays is not None:
        assert len(data_arrays) == len(target_arrays)
        for data_array, target in zip(data_arrays, target_arrays):
            prediction_shape = (fmaps_out,) + data_array['data'].shape[-dims:]
            assert prediction_shape == target.shape, \
                "Target array for dname {} is the wrong shape. {} should be {}"\
                .format(data_array['name'], target.shape, prediction_shape)

    pred_arrays = []

    if shapes is None:
        # Raw data slice input (n = 1, f = 1, spatial dims)
        shapes = [[1, fmaps_in] + input_dims]
    if net_io is None:
        if isinstance(nets, list):
            net_io = []
            for net_inst in nets:
                net_io += [NetInputWrapper(net_inst, shapes)]
        else:
            net_io = NetInputWrapper(net, shapes)

    using_data_loader = data_io.data_loader_should_be_used_with(data_arrays)
    processing_data_loader = None
    if using_data_loader:
        processing_data_loader = data_io.DataLoader(
            size=5,
            datasets=data_arrays,
            input_shape=tuple(input_dims),
            output_shape=None,  # ignore labels
            n_workers=3
        )

    dataset_offsets_to_process = generate_dataset_offsets_for_processing(
        net, data_arrays, process_borders=zero_pad_source_data)
    for source_dataset_index in dataset_offsets_to_process:
        # Launch
        if isinstance(nets, list):
            thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=len(nets))
            device_locks = []
            for device_list_id in range(0, len(nets)):
                device_locks += [threading.Lock()]
        list_of_offsets_to_process = dataset_offsets_to_process[source_dataset_index]
        if DEBUG:
            print("source_dataset_index = ", source_dataset_index)
            print("Processing source volume #{i} with offsets list {o}"
                  .format(i=source_dataset_index, o=list_of_offsets_to_process))
        # make a copy of that list for enqueueing purposes
        offsets_to_enqueue = list(list_of_offsets_to_process)
        data_array = data_arrays[source_dataset_index]['data']
        if target_arrays is not None:
            pred_array = target_arrays[source_dataset_index]
        else:
            prediction_shape = (fmaps_out,) + data_array.shape[-dims:]
            pred_array = np.zeros(shape=prediction_shape, dtype=np.float32)
        if using_data_loader:
            # start pre-populating queue
            for shared_dataset_index in range(min(processing_data_loader.size,
                                                  len(list_of_offsets_to_process))):
                # fill shared-memory datasets with an offset
                offsets = offsets_to_enqueue.pop(0)
                offsets = tuple([int(o) for o in offsets])
                # print("Pre-populating processing data loader with data at offset {}".format(offsets))
                print("Pre-populating data loader's dataset #{i}/{size} with dataset #{d} and offset {o}"
                      .format(i=shared_dataset_index, size=processing_data_loader.size,
                              d=source_dataset_index, o=offsets))
                shared_dataset_index, async_result = processing_data_loader.start_refreshing_shared_dataset(
                    shared_dataset_index,
                    offsets,
                    source_dataset_index,
                    transform=False,
                    wait=True
                )
        # process each offset
        for i_offsets in range(len(list_of_offsets_to_process)):
            index_of_shared_dataset = None
            if using_data_loader:
                dataset, index_of_shared_dataset = processing_data_loader.get_dataset()
                offsets = list(dataset['offset'])  # convert tuple to list
                data_slice = dataset['data']
                if DEBUG:
                    print("Processing next dataset in processing data loader, which has offset {o}"
                          .format(o=dataset['offset']))
            else:
                offsets = list_of_offsets_to_process[i_offsets]
                if zero_pad_source_data:
                    data_slice = data_io.util.get_zero_padded_slice_from_array_by_offset(
                        array=data_array,
                        origin=[0] + offsets,
                        shape=[fmaps_in] + [output_dims[di] + input_padding[di] for di in range(dims)]
                    )
                else:
                    data_slice = slice_data(
                        data_array,
                        [0] + offsets,
                        [fmaps_in] + [output_dims[di] + input_padding[di] for di in range(dims)]
                    )
            # process the chunk
            if isinstance(net_io, list):
                thread_pool.submit(process_core_multithreaded, device_locks, net_io, data_slice,
                                   offsets, pred_array, input_padding, fmaps_out, output_dims,
                                   using_data_loader, offsets_to_enqueue, processing_data_loader,
                                   index_of_shared_dataset, source_dataset_index)
            else:
                process_core(net_io, data_slice, offsets, pred_array, input_padding, fmaps_out,
                             output_dims, using_data_loader, offsets_to_enqueue,
                             processing_data_loader, index_of_shared_dataset, source_dataset_index)
        if not (thread_pool is None):
            thread_pool.shutdown(True)
        pred_arrays.append(pred_array)

    if using_data_loader:
        processing_data_loader.destroy()

    return pred_arrays
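

# Hedged sketch (not part of the original code): one way the per-device locks created in
# process() above could be used by a worker such as process_core_multithreaded. Devices
# are assigned round-robin here for simplicity, and the chunk processing itself is a
# placeholder; the real implementation lives in process_core().
def _example_device_dispatch(net_ios, chunks):
    device_locks = [threading.Lock() for _ in net_ios]

    def run_chunk(chunk_index, chunk):
        device_id = chunk_index % len(net_ios)  # simple round-robin device assignment
        with device_locks[device_id]:           # serialize access to that device's net
            # net_ios[device_id] would be handed to process_core(...) here
            return device_id, chunk

    with concurrent.futures.ThreadPoolExecutor(max_workers=len(net_ios)) as pool:
        futures = [pool.submit(run_chunk, i, c) for i, c in enumerate(chunks)]
        return [f.result() for f in futures]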
def process(net, data_arrays, shapes=None, net_io=None, zero_pad_source_data=True, target_arrays=None):
    if DEBUG:
        data_io.logger.setLevel(logging.DEBUG)
    else:
        data_io.logger.setLevel(logging.INFO)
    input_dims, output_dims, input_padding = get_spatial_io_dims(net)
    fmaps_in, fmaps_out = get_fmap_io_dims(net)
    dims = len(output_dims)

    if target_arrays is not None:
        assert len(data_arrays) == len(target_arrays)
        for data_array, target in zip(data_arrays, target_arrays):
            prediction_shape = (fmaps_out,) + data_array['data'].shape[-dims:]
            assert prediction_shape == target.shape, \
                "Target array for dname {} is the wrong shape. {} should be {}"\
                .format(data_array['name'], target.shape, prediction_shape)

    pred_arrays = []

    if shapes is None:
        # Raw data slice input (n = 1, f = 1, spatial dims)
        shapes = [[1, fmaps_in] + input_dims]
    if net_io is None:
        net_io = NetInputWrapper(net, shapes)

    using_data_loader = data_io.data_loader_should_be_used_with(data_arrays)
    if using_data_loader:
        processing_data_loader = data_io.DataLoader(
            size=5,
            datasets=data_arrays,
            input_shape=tuple(input_dims),
            output_shape=None,  # ignore labels
            n_workers=3
        )

    dataset_offsets_to_process = generate_dataset_offsets_for_processing(
        net, data_arrays, process_borders=zero_pad_source_data)
    for source_dataset_index in dataset_offsets_to_process:
        list_of_offsets_to_process = dataset_offsets_to_process[source_dataset_index]
        offsets_that_have_been_processed = []
        offsets_that_have_been_requested = []
        if DEBUG:
            print("source_dataset_index = ", source_dataset_index)
            print("Processing source volume #{i} with offsets list {o}"
                  .format(i=source_dataset_index, o=list_of_offsets_to_process))
        # make a copy of that list for enqueueing purposes
        offsets_to_enqueue = list(list_of_offsets_to_process)
        data_array = data_arrays[source_dataset_index]['data']
        if target_arrays is not None:
            pred_array = target_arrays[source_dataset_index]
        else:
            prediction_shape = (fmaps_out,) + data_array.shape[-dims:]
            pred_array = np.zeros(shape=prediction_shape, dtype=np.float32)
        if using_data_loader:
            # start pre-populating queue
            for shared_dataset_index in range(min(processing_data_loader.size,
                                                  len(list_of_offsets_to_process))):
                # fill shared-memory datasets with an offset
                offsets = offsets_to_enqueue.pop(0)
                offsets = tuple([int(o) for o in offsets])
                # print("Pre-populating processing data loader with data at offset {}".format(offsets))
                print("Pre-populating data loader's dataset #{i}/{size} with dataset #{d} and offset {o}"
                      .format(i=shared_dataset_index, size=processing_data_loader.size,
                              d=source_dataset_index, o=offsets))
                shared_dataset_index, async_result = processing_data_loader.start_refreshing_shared_dataset(
                    shared_dataset_index,
                    offsets,
                    source_dataset_index,
                    transform=False,
                    wait=True
                )
                offsets_that_have_been_requested.append(offsets)
        # process each offset
        for i_offsets in range(len(list_of_offsets_to_process)):
            if DEBUG:
                print("offsets that have been requested but which haven't been processed:",
                      sorted(list(
                          set([tuple(o) for o in offsets_that_have_been_requested]) -
                          set([tuple(o) for o in offsets_that_have_been_processed])
                      )))
            if using_data_loader:
                dataset, index_of_shared_dataset = processing_data_loader.get_dataset()
                offsets = list(dataset['offset'])  # convert tuple to list
                data_slice = dataset['data']
                if DEBUG:
                    print("Processing next dataset in processing data loader, which has offset {o}"
                          .format(o=dataset['offset']))
            else:
                offsets = list_of_offsets_to_process[i_offsets]
                if zero_pad_source_data:
                    data_slice = data_io.util.get_zero_padded_slice_from_array_by_offset(
                        array=data_array,
                        origin=[0] + offsets,
                        shape=[fmaps_in] + [output_dims[di] + input_padding[di] for di in range(dims)]
                    )
                else:
                    data_slice = slice_data(
                        data_array,
                        [0] + offsets,
                        [fmaps_in] + [output_dims[di] + input_padding[di] for di in range(dims)]
                    )
            # process the chunk
            output = process_input_data(net_io, data_slice)
            print(offsets)
            print(output.mean())
            offsets_that_have_been_processed.append(offsets)
            pads = [int(math.ceil(pad / float(2))) for pad in input_padding]
            offsets_for_pred_array = [0] + [offset + pad for offset, pad in zip(offsets, pads)]
            set_slice_data(pred_array, output, offsets_for_pred_array, [fmaps_out] + output_dims)
            if using_data_loader and len(offsets_to_enqueue) > 0:
                # start adding the next slice to the loader with index_of_shared_dataset
                new_offsets = offsets_to_enqueue.pop(0)
                processing_data_loader.start_refreshing_shared_dataset(
                    index_of_shared_dataset,
                    new_offsets,
                    source_dataset_index,
                    transform=False
                )
                offsets_that_have_been_requested.append(new_offsets)
        pred_arrays.append(pred_array)

    if using_data_loader:
        processing_data_loader.destroy()

    return pred_arrays
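

# Hedged illustration (not used by the pipeline): where an output chunk lands in the
# prediction volume in process() above. The network output is smaller than its input by
# `input_padding`, so each chunk is written back shifted by half the padding (rounded up),
# mirroring the offsets_for_pred_array computation. The explicit array write below shows
# what set_slice_data() is assumed to do; all sizes here are made-up example values.
def _example_output_placement():
    input_padding = [4, 4, 4]   # input is 4 voxels larger than output along each axis
    output_dims = [8, 8, 8]
    offsets = [0, 8, 16]        # chunk origin in input (padded) coordinates
    pads = [int(math.ceil(pad / float(2))) for pad in input_padding]
    offsets_for_pred_array = [0] + [offset + pad for offset, pad in zip(offsets, pads)]
    assert offsets_for_pred_array == [0, 2, 10, 18]

    pred_array = np.zeros((3, 32, 32, 32), dtype=np.float32)  # (fmaps_out,) + volume dims
    output = np.ones([3] + output_dims, dtype=np.float32)     # one processed chunk
    z, y, x = offsets_for_pred_array[1:]
    pred_array[:, z:z + 8, y:y + 8, x:x + 8] = output
    return pred_array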