def setUp(self):
  """Loads one batch of targets and builds two single-channel target views.

  Reads a deterministic, unpadded batch (batch size 2, crop size 768) and
  keeps only the target tensor. It then builds:
    * `target_lt_0`: the 'DAPI_CONFOCAL' channel of the `mask=False` slice.
    * `target_lt_1`: the 'NEURITE_CONFOCAL' channel of the `mask=False` slice.
  In each view, every axis from position 3 onward is folded into a single
  'channel' axis.
  """
  super(ErrorBase, self).setUp()

  rtp = data_provider.ReadTableParameters(
      shard_paths=[self.recordio_path()],
      is_recordio=True,
      bp=util.BatchParameters(2, num_threads=1, capacity=2),
      is_deterministic=True,
      pad_size=0,
      crop_size=768)
  dp = data_provider.DataParameters(
      rtp,
      input_z_values=self.input_z_values,
      input_channel_values=self.input_channel_values,
      target_z_values=self.target_z_values,
      target_channel_values=self.target_channel_values)
  _, self.target_lt = data_provider.cropped_input_and_target(dp)

  self.target_lt_0 = lt.select(self.target_lt, {
      'channel': 'DAPI_CONFOCAL',
      'mask': False
  })
  self.target_lt_1 = lt.select(self.target_lt, {
      'channel': 'NEURITE_CONFOCAL',
      'mask': False
  })

  # `axes.keys()` may be a Python 3 keys view, which does not support
  # slicing, so materialize it as a list first. Axes 0-2 are left alone
  # (presumably batch/row/column -- confirm); the remaining axes are merged
  # into 'channel'.
  self.target_lt_0 = lt.reshape(self.target_lt_0,
                                list(self.target_lt_0.axes.keys())[3:],
                                ['channel'])
  self.target_lt_1 = lt.reshape(self.target_lt_1,
                                list(self.target_lt_1.axes.keys())[3:],
                                ['channel'])
def setUp(self):
  """Defines the z/channel value lists and loads two cropped batches.

  `input_lt` / `target_lt` come from an unpadded read with crop size 256.
  `minception_input_lt` / `minception_target_lt` come from a read with pad
  size 51 and crop size 110. Both reads are deterministic, with batch size 2.
  """
  super(Base, self).setUp()

  self.input_z_values = [round(z, 4) for z in np.linspace(0.0, 1.0, 13)]
  self.input_channel_values = ['BRIGHTFIELD', 'PHASE_CONTRAST', 'DIC']
  self.target_z_values = ['MAXPROJECT']
  self.target_channel_values = [
      'DAPI_CONFOCAL',
      'DAPI_WIDEFIELD',
      'CELLMASK_CONFOCAL',
      'TUJ1_WIDEFIELD',
      'NFH_CONFOCAL',
      'MAP2_CONFOCAL',
      'ISLET_WIDEFIELD',
      'DEAD_CONFOCAL',
  ]

  def load_cropped(pad_size, crop_size):
    """Runs the shared read pipeline with the given padding and crop size."""
    read_params = data_provider.ReadTableParameters(
        shard_paths=[self.recordio_path()],
        is_recordio=True,
        bp=util.BatchParameters(2, num_threads=1, capacity=2),
        is_deterministic=True,
        pad_size=pad_size,
        crop_size=crop_size)
    data_params = data_provider.DataParameters(
        read_params,
        input_z_values=self.input_z_values,
        input_channel_values=self.input_channel_values,
        target_z_values=self.target_z_values,
        target_channel_values=self.target_channel_values)
    return data_provider.cropped_input_and_target(data_params)

  self.input_lt, self.target_lt = load_cropped(pad_size=0, crop_size=256)
  minception_pair = load_cropped(pad_size=51, crop_size=110)
  self.minception_input_lt, self.minception_target_lt = minception_pair
def setUp(self):
  """Reads one deterministic, unpadded batch of two examples at crop size 512.

  Stores the resulting input and target tensors on `self.input_lt` and
  `self.target_lt`.
  """
  super(CroppedInputAndTargetTest, self).setUp()

  batch_params = util.BatchParameters(2, num_threads=1, capacity=2)
  read_params = data_provider.ReadTableParameters(
      shard_paths=[self.recordio_path()],
      is_recordio=True,
      bp=batch_params,
      is_deterministic=True,
      pad_size=0,
      crop_size=512)
  data_params = data_provider.DataParameters(
      read_params,
      input_z_values=self.input_z_values,
      input_channel_values=self.input_channel_values,
      target_z_values=self.target_z_values,
      target_channel_values=self.target_channel_values)

  inputs_and_targets = data_provider.cropped_input_and_target(data_params)
  self.input_lt, self.target_lt = inputs_and_targets
def setUp(self):
  """Builds a prediction/target pair from one batch and their error panel.

  Both examples come from a single deterministic, unpadded read (batch size 2,
  crop size 768):
    * `prediction_lt`: example 0 of the `mask=False` slice, passed through
      `util.onehot(16, ...)`. The 16 is presumably the number of one-hot
      bins (confirm against util.onehot).
    * `target_lt`: example 1 of the full target tensor.
  `error_panel_lt` is the visualization produced from the two.
  """
  super(ErrorPanelTest, self).setUp()

  read_params = data_provider.ReadTableParameters(
      shard_paths=[self.recordio_path()],
      is_recordio=True,
      bp=util.BatchParameters(2, num_threads=1, capacity=2),
      is_deterministic=True,
      pad_size=0,
      crop_size=768)
  data_params = data_provider.DataParameters(
      read_params,
      input_z_values=self.input_z_values,
      input_channel_values=self.input_channel_values,
      target_z_values=self.target_z_values,
      target_channel_values=self.target_channel_values)
  _, batch_target_lt = data_provider.cropped_input_and_target(data_params)

  unmasked_lt = lt.select(batch_target_lt, {'mask': False})
  first_example_lt = lt.slice(unmasked_lt, {'batch': slice(0, 1)})
  self.prediction_lt = util.onehot(16, first_example_lt)

  self.target_lt = lt.slice(batch_target_lt, {'batch': slice(1, 2)})

  self.error_panel_lt = visualize.error_panel(self.target_lt,
                                              self.prediction_lt)
def _extract_flat_patches(
    labeled_tensor: lt.LabeledTensor,
    extract_patch_size: int,
    stride: int,
    name: str,
) -> Tuple[np.ndarray, lt.LabeledTensor]:
  """Extracts patches and folds the patch grid into the batch axis.

  Args:
    labeled_tensor: Tensor with 'z', 'channel' and 'mask' axes, among others.
    extract_patch_size: Patch size for patch extraction.
    stride: Stride for patch extraction.
    name: Op name for the final reshape.

  Returns:
    The patch centers reported by `ops.extract_patches_single_scale`, and the
    patches with 'batch', 'patch_row' and 'patch_column' merged into 'batch'.
  """
  # The 'z', 'channel' and 'mask' axes are merged into one 'channel' axis for
  # extraction, then decoded back with the same coder. A new coder is built
  # per tensor instead of being shared between input and target (presumably
  # encode() records state that decode() relies on -- confirm).
  rc = lt.ReshapeCoder(['z', 'channel', 'mask'], ['channel'])
  patch_centers, patches_lt = ops.extract_patches_single_scale(
      extract_patch_size, stride, rc.encode(labeled_tensor))
  patches_lt = rc.decode(patches_lt)
  patches_lt = lt.reshape(patches_lt, ['batch', 'patch_row', 'patch_column'],
                          ['batch'],
                          name=name)
  return patch_centers, patches_lt


def provide_preprocessed_data(
    dp: data_provider.DataParameters,
    ap: Optional[augment.AugmentParameters],
    extract_patch_size: int,
    stride: int,
    name: Optional[str] = None,
) -> Tuple[np.ndarray, lt.LabeledTensor, lt.LabeledTensor]:
  """Provide preprocessed input and target patches.

  Image summaries are emitted for the cropped input/target and, when `ap` is
  given, for the augmented input/target.

  Args:
    dp: DataParameters.
    ap: Optional AugmentParameters. If None, augmentation is skipped.
    extract_patch_size: Patch size for patch extraction.
    stride: Stride for patch extraction.
    name: Optional op name.

  Returns:
    An array containing patch center locations, taken from the input
      extraction. The patches are extracted with padding, so that the
      stitched model outputs will form an image the same size as the input.
    A tensor with model inputs, possibly jittered and corrupted for data
      enrichment.
    A tensor with model targets.
  """
  with tf.name_scope(name, 'provide_preprocessed_data', []) as scope:
    input_lt, target_lt = data_provider.cropped_input_and_target(dp)
    visualize.summarize_image(
        visualize.canonical_image(input_lt, name=scope + 'input'))
    visualize.summarize_image(
        visualize.canonical_image(target_lt, name=scope + 'target'))

    if ap is not None:
      input_lt, target_lt = augment.augment(ap, input_lt, target_lt)
      visualize.summarize_image(
          visualize.canonical_image(input_lt, name=scope + 'input_jitter'))
      visualize.summarize_image(
          visualize.canonical_image(target_lt, name=scope + 'target_jitter'))

    patch_centers, input_lt = _extract_flat_patches(
        input_lt, extract_patch_size, stride, name=scope + 'input')
    # Only the input's patch centers are returned. The target's centers are
    # discarded because both tensors use the same patch size and stride.
    _, target_lt = _extract_flat_patches(
        target_lt, extract_patch_size, stride, name=scope + 'target')

    return patch_centers, input_lt, target_lt