def __call__(self, input_chunk: np.ndarray):
    """Run patch-wise ConvNet inference over a whole chunk.

    The chunk is cut into the patches listed in ``self.patch_slices_list``.
    They are fed to ``self.patch_inferencer`` ``self.batch_size`` patches at
    a time, and every output patch is blended into an output buffer.

    args:
        input_chunk (Chunk): input chunk with global offset. Integer dtypes
            are rescaled to the [0, 1] range by dividing by the maximum value
            of the dtype before inference.

    returns:
        In dry-run mode, a placeholder chunk with the output shape, dtype and
        offset. If the input is all zero, the untouched output buffer.
        Otherwise, the blended output buffer. When
        ``self.mask_myelin_threshold`` is set, the buffer is asserted to have
        4 channels. The last channel is handed to
        ``mask_using_last_channel``, which presumably drops it, so one fewer
        channel is returned (matching the dry-run and all-zero paths).
    """
    assert isinstance(input_chunk, Chunk)

    self._update_parameters_for_input_chunk(input_chunk)
    output_buffer = self._get_output_buffer(input_chunk)

    if not self.mask_output_chunk:
        self._check_alignment()

    if self.dry_run:
        print('dry run, return a special artifical chunk.')
        size = output_buffer.shape
        if self.mask_myelin_threshold:
            # eliminate the myelin channel
            size = (size[0] - 1, *size[1:])
        return Chunk.create(size=size,
                            dtype=output_buffer.dtype,
                            voxel_offset=output_buffer.global_offset)

    # np.all() keeps this test well defined whether ``==`` returns a single
    # bool or an element-wise boolean array (bare ``if array`` raises for
    # arrays with more than one element).
    if np.all(input_chunk == 0):
        print('input is all zero, return zero buffer directly')
        if self.mask_myelin_threshold:
            assert output_buffer.shape[0] == 4
            return output_buffer[:-1, ...]
        else:
            return output_buffer

    if np.issubdtype(input_chunk.dtype, np.integer):
        # normalize to 0-1 value range
        dtype_max = np.iinfo(input_chunk.dtype).max
        input_chunk = input_chunk.astype(self.dtype) / dtype_max

    if self.verbose:
        chunk_time_start = time.time()

    # iterate the offset list, one batch of patches at a time
    for i in tqdm(range(0, len(self.patch_slices_list), self.batch_size),
                  disable=not self.verbose,
                  desc='ConvNet inference for patches: '):
        if self.verbose:
            start = time.time()

        # the last batch may be shorter than batch_size; the remaining
        # buffer slots keep stale data, but only len(batch_slices) outputs
        # are blended below
        batch_slices = self.patch_slices_list[i:i + self.batch_size]
        for batch_idx, slices in enumerate(batch_slices):
            # slices[0] is the input patch slice
            patch_input = input_chunk.cutout(slices[0]).array
            self.input_patch_buffer[batch_idx, 0, :, :, :] = patch_input

        if self.verbose > 1:
            end = time.time()
            print('prepare %d input patches takes %3f sec' %
                  (self.batch_size, end - start))
            start = end

        # the input and output patch is a 5d numpy array with
        # datatype of float32, the dimensions are batch/channel/z/y/x.
        # the input image should be normalized to [0,1]
        output_patch = self.patch_inferencer(self.input_patch_buffer)

        if self.verbose > 1:
            assert output_patch.ndim == 5
            end = time.time()
            print('run inference for %d patch takes %3f sec' %
                  (self.batch_size, end - start))
            start = end

        for batch_idx, slices in enumerate(batch_slices):
            # slices[1] is the output patch slice; prepend 0 for the
            # channel axis
            offset = (0, ) + tuple(s.start for s in slices[1])
            output_chunk = Chunk(output_patch[batch_idx, :, :, :, :],
                                 global_offset=offset)
            output_buffer.blend(output_chunk)

        if self.verbose > 1:
            end = time.time()
            print('blend patch takes %3f sec' % (end - start))

    if self.verbose:
        print("Inference of whole chunk takes %3f sec" %
              (time.time() - chunk_time_start))

    if self.mask_output_chunk:
        output_buffer *= self.output_chunk_mask
        # theoretically, all the value of output_buffer should not be
        # greater than 1; we use a slightly higher value here to accommodate
        # numerical precision issues
        np.testing.assert_array_less(
            output_buffer, 1.0001,
            err_msg='output buffer should not be greater than 1')

    if self.mask_myelin_threshold:
        # currently only for masking out affinity map
        assert output_buffer.shape[0] == 4
        output_chunk = output_buffer.mask_using_last_channel(
            threshold=self.mask_myelin_threshold)

        # currently neuroglancer only support float32, not float16
        if output_chunk.dtype == np.dtype('float16'):
            output_chunk = output_chunk.astype('float32')

        return output_chunk
    else:
        return output_buffer
def predict_chunk(self, input_chunk: np.ndarray):
    """Run patch-wise inference on a chunk with the PyTorch model ``self.model``.

    The model is put in eval mode, moved to the GPU with ``.cuda()`` and run
    under ``torch.no_grad()``. It processes ``self.batch_size`` patches at a
    time. Channel 0 of each network output patch is discarded before it is
    blended into the output buffer.

    args:
        input_chunk (Chunk): input chunk with global offset. Integer dtypes
            are rescaled to the [0, 1] range by dividing by the maximum value
            of the dtype before inference.

    returns:
        In dry-run mode, a placeholder chunk with the output shape, dtype and
        offset. If the input is all zero, the untouched output buffer
        (without its last channel when ``self.mask_myelin_threshold`` is
        set). Otherwise, the blended output buffer, multiplied by
        ``self.output_chunk_mask`` when ``self.mask_output_chunk`` is set.
    """
    assert isinstance(input_chunk, Chunk)

    self._update_parameters_for_input_chunk(input_chunk)
    output_buffer = self._get_output_buffer(input_chunk)

    if not self.mask_output_chunk:
        self._check_alignment()

    if self.dry_run:
        print('dry run, return a special artificial chunk.')
        size = output_buffer.shape
        if self.mask_myelin_threshold:
            # eliminate myelin channel
            size = (size[0] - 1, *size[1:])
        return Chunk.create(size=size,
                            dtype=output_buffer.dtype,
                            voxel_offset=output_buffer.global_offset)

    # np.all() keeps this test well defined whether ``==`` returns a single
    # bool or an element-wise boolean array.
    if np.all(input_chunk == 0):
        print('input is all zero, return zero buffer directly')
        if self.mask_myelin_threshold:
            assert output_buffer.shape[0] == 4
            return output_buffer[:-1, ...]
        else:
            return output_buffer

    if np.issubdtype(input_chunk.dtype, np.integer):
        # normalize to 0-1 value range; np.iinfo gives the integer limits
        # (np.info only prints documentation and returns None)
        dtype_max = np.iinfo(input_chunk.dtype).max
        input_chunk = input_chunk.astype(self.dtype) / dtype_max

    if self.verbose:
        chunk_time_start = time.time()

    # set model to evaluation mode
    self.model.eval()
    # send model to device
    self.model.cuda()

    with torch.no_grad():
        for i in range(0, len(self.patch_slices_list), self.batch_size):
            if self.verbose:
                start = time.time()

            batch_slices = self.patch_slices_list[i:i + self.batch_size]
            for batch_idx, slices in enumerate(batch_slices):
                # slices[0] is the input patch slice
                patch_input = input_chunk.cutout(slices[0]).array
                self.input_patch_buffer[batch_idx, 0, :, :, :] = patch_input

            if self.verbose > 1:
                end = time.time()
                # both values go in one tuple so ``%`` fills both fields
                print('preparing %d input patches takes %3f sec' %
                      (self.batch_size, end - start))
                start = end

            # the input and output patch is a 5d numpy array with
            # datatype of float32, the dimensions are (batch, channel, z, y, x)
            # the input image should be normalized to [0,1]
            patch = torch.from_numpy(self.input_patch_buffer).float().cuda()
            output_patch = self.model(patch)
            assert output_patch.ndim == 5
            net_out = output_patch.cpu().numpy()

            for batch_idx, slices in enumerate(batch_slices):
                # slices[1] is the output patch slice; prepend 0 for the
                # channel axis
                offset = (0, ) + tuple(s.start for s in slices[1])
                if self.verbose > 1:
                    print(offset)
                # NOTE(review): channel 0 of the network output is dropped
                # here; confirm this matches the model's channel layout
                output_chunk = Chunk(net_out[batch_idx, 1:, :, :, :],
                                     global_offset=offset)
                output_buffer.blend(output_chunk)

    if self.verbose:
        print("Inference of whole chunk takes %3f sec" %
              (time.time() - chunk_time_start))

    if self.mask_output_chunk:
        # alignment was not checked in this mode, so the chunk-level mask is
        # applied after blending
        output_buffer *= self.output_chunk_mask

    return output_buffer