def test_idw_no_broadcasting(self):
    """IDW resampling with one coordinate set per batch item (no broadcast).

    Both batch items hold a unit impulse at the spatial origin of channel 0.
    Item 0 is sampled everywhere at coordinate 0.2 and must equal a
    hand-computed value; item 1 is sampled at 1.2 and must read 0.
    The check is repeated for 3D, 2D and 1D inputs.
    """
    # (input shape, per-item coordinate shape, expected value of item 0,
    #  expected output shape)
    cases = [
        ((2, 3, 3, 3, 2), (1, 5, 5, 5, 3),
         1.0 / (1. + 1. / 2. + 36. / 132. + 12. / 192.), (2, 5, 5, 5, 2)),
        ((2, 3, 3, 2), (1, 5, 5, 2),
         1.0 / (1.0 + 1.0 / 16.0 + 16. / 68.), (2, 5, 5, 2)),
        ((2, 3, 2), (1, 5, 1),
         1.0 / (1.0 + 1 / 16.0), (2, 5, 2)),
    ]
    for in_shape, coord_shape, expected_first, expected_shape in cases:
        volume = np.zeros(in_shape)
        # impulse at index 0 of every spatial axis and channel 0, all items
        volume[(slice(None),) + (0,) * (len(in_shape) - 1)] = 1.0
        volume = tf.constant(volume)
        coords = tf.concat(
            [tf.ones(coord_shape) * 0.2, tf.ones(coord_shape) * 1.2],
            axis=0)
        resampled = ResamplerLayer("IDW")(volume, coords)
        with self.cached_session() as sess:
            result = sess.run(resampled)
        self.assertTrue(
            np.allclose(result[0, ..., 0], expected_first, atol=1e-5))
        self.assertTrue(np.allclose(result[1, ..., 0], 0.0, atol=1e-5))
        self.assertEqual(result.shape, expected_shape)
def test_idw_shape(self):
    """IDW resampling with one coordinate set broadcast over the batch.

    Only batch item 0 has a unit impulse (origin, channel 0); every output
    location samples coordinate 0.1.  Item 0 must match a hand-computed
    value, item 1 must be exactly zero, for 3D, 2D and 1D inputs.
    """
    # (input shape, coordinate shape, expected value of item 0,
    #  expected output shape)
    cases = [
        ((2, 8, 8, 8, 2), (1, 5, 5, 5, 3),
         1.0 / (1. + 9. / 83 + 9. / 163 + 3. / 243), (2, 5, 5, 5, 2)),
        ((2, 8, 8, 2), (1, 5, 5, 2),
         1. / (2. / 41. + 1. / 81.0 + 1.0), (2, 5, 5, 2)),
        ((2, 8, 2), (1, 5, 1),
         100.0 / (100.0 + 1 / 0.81), (2, 5, 2)),
    ]
    for in_shape, coord_shape, expected_first, expected_shape in cases:
        volume = np.zeros(in_shape)
        volume[(0,) * len(in_shape)] = 1.0  # batch 0, origin, channel 0
        volume = tf.constant(volume)
        coords = tf.ones(coord_shape) * 0.1
        resampled = ResamplerLayer("IDW")(volume, coords)
        with self.cached_session() as sess:
            result = sess.run(resampled)
        self.assertTrue(
            np.allclose(result[0, ..., 0], expected_first, atol=1e-5))
        self.assertTrue(np.all(result[1, ...] == 0))
        self.assertEqual(result.shape, expected_shape)
def _test_partial_shape_correctness(self, input, rank, batch_size, grid,
                                    interpolation, boundary,
                                    expected_value=None):
    """Run the resampler on an input whose static shape is partly unknown.

    The input is hidden behind ``tf.placeholder_with_default`` so that the
    resampler sees:

    * ``rank > 0`` and ``batch_size > 0``: a known batch size followed by
      ``rank + 1`` unknown dimensions;
    * ``rank > 0`` and ``batch_size <= 0``: ``rank + 2`` unknown dimensions;
    * ``rank <= 0``: a fully unknown shape (unknown rank), whatever
      ``batch_size`` is.

    Args:
        input: numpy array fed to the placeholder at run time; its shape is
            also used for the random default value.
        rank: spatial rank exposed statically; ``<= 0`` hides the rank.
        batch_size: static batch size; ``<= 0`` hides it.
        grid: sampling coordinates passed to the resampler.
        interpolation: interpolation type for ``ResamplerLayer``.
        boundary: boundary type for ``ResamplerLayer``.
        expected_value: if not None, the output must be close to it.
    """
    # NOTE: `input` shadows the builtin but is kept for keyword callers.
    resampler = ResamplerLayer(interpolation=interpolation, boundary=boundary)
    input_default = tf.random_uniform(input.shape)
    if rank > 0:
        if batch_size > 0:
            static_shape = [batch_size] + [None] * (rank + 1)
        else:
            static_shape = [None] * (rank + 2)
    else:
        # A partially known shape needs a known rank, so an unknown rank
        # means a fully unknown shape.  This branch also covers
        # batch_size > 0, which previously left the placeholder unbound.
        static_shape = None
    input_placeholder = tf.placeholder_with_default(input_default,
                                                    shape=static_shape)
    out = resampler(input_placeholder, grid)
    with self.test_session() as sess:
        out_value = sess.run(out, feed_dict={input_placeholder: input})
    if expected_value is not None:
        self.assertAllClose(expected_value, out_value)
def test_linear_no_broadcasting(self):
    """Linear resampling with one coordinate set per batch item.

    Both batch items hold a unit impulse at the origin of channel 0.
    Item 0 is sampled at 0.1 on every axis, item 1 at 0.2, so the
    multilinear weights of the origin are 0.9**rank and 0.8**rank.
    Repeated for 3D, 2D and 1D inputs.
    """
    for rank in (3, 2, 1):
        volume = np.zeros((2,) + (8,) * rank + (2,))
        volume[(slice(None),) + (0,) * (rank + 1)] = 1.0
        volume = tf.constant(volume)
        coord_shape = (1,) + (5,) * rank + (rank,)
        coords = tf.concat(
            [tf.ones(coord_shape) * 0.1, tf.ones(coord_shape) * 0.2],
            axis=0)
        resampled = ResamplerLayer("LINEAR")(volume, coords)
        with self.cached_session() as sess:
            result = sess.run(resampled)
        self.assertTrue(
            np.allclose(result[0, ..., 0], 0.9**rank, atol=1e-5))
        self.assertTrue(
            np.allclose(result[1, ..., 0], 0.8**rank, atol=1e-5))
        self.assertEqual(result.shape, (2,) + (5,) * rank + (2,))
def _test_correctness(self, inputs, grid, interpolation, boundary,
                      expected_value):
    """Resample ``inputs`` on ``grid`` and assert closeness to
    ``expected_value``."""
    layer = ResamplerLayer(interpolation=interpolation, boundary=boundary)
    resampled = layer(inputs, grid)
    with self.test_session() as session:
        actual = session.run(resampled)
        self.assertAllClose(expected_value, actual)
def _test_correctness(self, input, grid, interpolation, boundary,
                      expected_value):
    """Resample ``input`` on ``grid`` and assert closeness to
    ``expected_value``.

    All global variables are initialised first, so ``grid`` may depend on
    TF variables.
    """
    layer = ResamplerLayer(interpolation=interpolation, boundary=boundary)
    resampled = layer(input, grid)
    with self.test_session() as session:
        session.run(tf.global_variables_initializer())
        actual = session.run(resampled)
        self.assertAllClose(expected_value, actual)
def test_shape_interface(self):
    """ResamplerLayer must raise ValueError for incompatible input shapes."""
    # bad batch sizes: 2 images but 3 coordinate sets
    test_input = tf.zeros((2, 10, 10, 10, 3))
    test_coords = tf.zeros((3, 5, 5, 5, 3))
    with self.assertRaises(ValueError):
        ResamplerLayer()(test_input, test_coords)

    # coordinates lack the batch dimension (rank 4 vs. rank-5 input)
    test_input = tf.zeros((2, 10, 10, 10, 3))
    test_coords = tf.zeros((5, 5, 5, 3))
    with self.assertRaises(ValueError):
        ResamplerLayer()(test_input, test_coords)

    # bad n coordinates: 2 spatial dims but 3 coordinate components
    test_input = tf.zeros((1, 10, 10, 3))
    test_coords = tf.zeros((1, 5, 5, 3))
    with self.assertRaises(ValueError):
        ResamplerLayer()(test_input, test_coords)
def _test_simple_2d_images(self, interpolation='linear',
                           boundary='replicate'):
    """Fit a 2D affine warp to rotated target images by gradient descent.

    Starting from the identity, 500 Adagrad steps must reduce both the
    image MSE and the summed absolute error between the fitted parameters
    and the known rotation.
    """
    # Known transform: rotation by 15 degrees around the centre (8, 8),
    # flattened row by row to match the 6-parameter affine vector.
    target_params = np.asarray(
        [[0.96592583, -0.25881905, 2.34314575],
         [0.25881905, 0.96592583, -1.79795897]]).flatten()
    test_image, input_shape = get_multiple_2d_images()
    test_target, target_shape = get_multiple_2d_rotated_targets()
    identity_affine = [[1., 0., 0., 0., 1., 0.]] * 4
    affine_var = tf.get_variable('affine', initializer=identity_affine)
    warper = AffineGridWarperLayer(source_shape=input_shape[1:-1],
                                   output_shape=target_shape[1:-1],
                                   constraints=None)
    warp_coords = warper(affine_var)
    resampler = ResamplerLayer(interpolation, boundary=boundary)
    new_image = resampler(tf.constant(test_image, dtype=tf.float32),
                          warp_coords)
    diff = tf.reduce_mean(
        tf.squared_difference(new_image,
                              tf.constant(test_target, dtype=tf.float32)))

    # linear interpolation with a zero boundary uses a much smaller step
    if (interpolation, boundary) == ('linear', 'zero'):
        learning_rate = 0.0003
    else:
        learning_rate = 0.05
    optimiser = tf.train.AdagradOptimizer(learning_rate)
    opt = optimiser.apply_gradients(optimiser.compute_gradients(diff))

    with self.cached_session() as sess:
        sess.run(tf.global_variables_initializer())
        init_val, affine_val = sess.run([diff, affine_var])
        # sum of absolute parameter errors before optimisation
        init_param_err = np.sum(np.abs(affine_val[0] - target_params))
        for _ in range(500):
            _, diff_val, affine_val = sess.run([opt, diff, affine_var])
        self.assertGreater(init_val, diff_val)
        # sum of absolute parameter errors after optimisation
        final_param_err = np.sum(np.abs(affine_val[0] - target_params))
        self.assertGreater(init_param_err, final_param_err)
        print('{} {} -- diff {}'.format(interpolation, boundary,
                                        final_param_err))
        print('{}'.format(affine_val[0]))
def test_nearest_shape(self):
    """Nearest-neighbour resampling with coordinates broadcast over the batch.

    Only batch item 0 has a unit impulse at the origin of channel 0.
    Sampling every location at 0.1 must snap to the origin, so item 0
    reads 1 and item 1 reads exactly 0.  Repeated for 3D, 2D and 1D inputs.
    """
    for rank in (3, 2, 1):
        volume = np.zeros((2,) + (8,) * rank + (2,))
        volume[(0,) * (rank + 2)] = 1.0
        volume = tf.constant(volume)
        coords = tf.ones((1,) + (5,) * rank + (rank,)) * 0.1
        resampled = ResamplerLayer("NEAREST")(volume, coords)
        with self.cached_session() as sess:
            result = sess.run(resampled)
        self.assertTrue(np.allclose(result[0, ..., 0], 1.0, atol=1e-5))
        self.assertTrue(np.all(result[1, ...] == 0))
        self.assertEqual(result.shape, (2,) + (5,) * rank + (2,))
def __init__(self,
             source_shape,
             output_shape,
             coeff_shape,
             field_transform=None,
             resampler=None,
             name='resampling_interpolated_spline_grid_warper'):
    """Constructs a ResampledFieldGridWarperLayer.

    Args:
        source_shape: Iterable of integers determining the size of the
            source signal domain.
        output_shape: Iterable of integers determining the size of the
            destination resampled signal domain.
        coeff_shape: Shape of the displacement field.
        field_transform: an object defining the spatial relationship
            between the output grid and the field.
            batch_size x 4 x 4 tensor: per-image transform matrix from
            output coords to field coords.
            None (default): corners of the output map to corners of the
            field with an allowance for interpolation (1 for bspline,
            0 for linear).
        resampler: a ResamplerLayer used to interpolate the deformation
            field. If None, a LINEAR / REPLICATE ResamplerLayer is created.
        name: Name of the module.

    Raises:
        TypeError: If output_shape and source_shape are not both iterable.
    """
    if resampler is None:
        self._resampler = ResamplerLayer(interpolation='LINEAR',
                                         boundary='REPLICATE')
        self._interpolation = 'LINEAR'
    else:
        self._resampler = resampler
        self._interpolation = self._resampler.interpolation
    self._field_transform = field_transform

    super(ResampledFieldGridWarperLayer, self).__init__(
        source_shape=source_shape,
        output_shape=output_shape,
        coeff_shape=coeff_shape,
        name=name)
def layer_op(self, input_tensor):
    """Warp ``input_tensor`` with a cached random affine transform.

    On the first call the transform comes from ``self._random_transform``
    and is stored in ``self._transform``; later calls reuse it.

    Args:
        input_tensor: tensor whose first dim is used as the batch size and
            whose dims ``1:-1`` are used as the spatial shape.

    Returns:
        ``input_tensor`` resampled on the affine-warped grid.
    """
    static_shape = input_tensor.shape.as_list()
    n_batch = static_shape[0]
    spatial_dims = static_shape[1:-1]
    n_spatial = infer_spatial_rank(input_tensor)

    transform = self._transform
    if transform is None:
        transform = self._random_transform(n_batch, n_spatial)
        self._transform = transform

    warper = AffineGridWarperLayer(spatial_dims, spatial_dims)
    sampler = ResamplerLayer(interpolation=self.interpolation,
                             boundary=self.boundary)
    # keep the first `n_spatial` rows of each per-item matrix, flattened
    params = tf.reshape(transform[:, :n_spatial, :], [n_batch, -1])
    return sampler(input_tensor, warper(params))
def _test_grads_images(self,
                       interpolation='linear',
                       boundary='replicate',
                       ndim=2):
    """A few Adagrad steps on an affine warp must reduce the image MSE.

    The warp starts from the identity affine (one per batch item); after
    5 optimisation steps the mean squared difference to the target images
    must be lower than at initialisation.

    Args:
        interpolation: ResamplerLayer interpolation type.
        boundary: ResamplerLayer boundary type.
        ndim: 2 for 2D test images, anything else for 3D.
    """
    if ndim == 2:
        test_image, input_shape = get_multiple_2d_images()
        test_target, target_shape = get_multiple_2d_targets()
        # identity: rows [1, 0, 0] and [0, 1, 0], flattened
        identity_affine = [[1., 0., 0., 0., 1., 0.]] * 4
    else:
        test_image, input_shape = get_multiple_3d_images()
        test_target, target_shape = get_multiple_3d_targets()
        # identity: rows [1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], flattened.
        # The previous value [1,0,0,0, 1,0,1,0, 0,0,1,0] was not the
        # identity under this row-major layout.
        identity_affine = [[
            1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0.
        ]] * 4
    affine_var = tf.get_variable('affine', initializer=identity_affine)
    grid = AffineGridWarperLayer(source_shape=input_shape[1:-1],
                                 output_shape=target_shape[1:-1],
                                 constraints=None)
    warp_coords = grid(affine_var)
    resampler = ResamplerLayer(interpolation, boundary=boundary)
    new_image = resampler(tf.constant(test_image, dtype=tf.float32),
                          warp_coords)
    diff = tf.reduce_mean(
        tf.squared_difference(new_image,
                              tf.constant(test_target, dtype=tf.float32)))
    optimiser = tf.train.AdagradOptimizer(0.01)
    grads = optimiser.compute_gradients(diff)
    opt = optimiser.apply_gradients(grads)
    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        init_val, affine_val = sess.run([diff, affine_var])
        for _ in range(5):
            _, diff_val, affine_val = sess.run([opt, diff, affine_var])
            print('{}, {}'.format(diff_val, affine_val[0]))
        self.assertGreater(init_val, diff_val)
def connect_data_and_network(self,
                             outputs_collector=None,
                             gradients_collector=None):
    """Build the label-driven registration graph.

    Training: windows are drawn from the training sampler, or from the
    last sampler when validation is enabled and ``self.is_validation`` is
    true.  They are split into fixed/moving image and label volumes.
    The network predicts a dense displacement field, and the moving label
    is warped with it.  The foreground label loss, plus
    ``net_param.decay`` times the mean of the ``'bending_energy'``
    collection when that collection is non-empty, is the total loss.
    Gradients go to ``gradients_collector``; console/TensorBoard metrics
    go to ``outputs_collector``.

    Inference: the fixed/moving volumes, the warped moving image and label,
    and the window locations are added as NETWORK_OUTPUT.  A
    ``ResizeSamplesAggregator`` is stored in ``self.output_decoder``.
    """

    def switch_samplers(for_training):
        with tf.name_scope('train' if for_training else 'validation'):
            sampler = self.get_sampler()[0 if for_training else -1]
            return sampler()  # returns (image windows, locations)

    def split_windows(image_windows):
        # The last axis stacks (fixed image, fixed label, moving image,
        # moving label); give each its own singleton channel axis.
        return [
            tf.expand_dims(img, axis=-1)
            for img in tf.unstack(image_windows, axis=-1)
        ]

    if self.is_training:
        self.patience = self.action_param.patience
        self.mode = self.action_param.early_stopping_mode
        if self.action_param.validation_every_n > 0:
            sampler_window = \
                tf.cond(tf.logical_not(self.is_validation),
                        lambda: switch_samplers(True),
                        lambda: switch_samplers(False))
        else:
            sampler_window = switch_samplers(True)

        image_windows, _ = sampler_window
        fixed_image, fixed_label, moving_image, moving_label = \
            split_windows(image_windows)

        # estimate ddf
        dense_field = self.net(fixed_image, moving_image)
        if isinstance(dense_field, tuple):
            dense_field = dense_field[0]

        # transform the moving labels
        resampler = ResamplerLayer(interpolation='linear',
                                   boundary='replicate')
        resampled_moving_label = resampler(moving_label, dense_field)

        # compute label loss (foreground only)
        loss_func = LossFunction(n_class=1,
                                 loss_type=self.action_param.loss_type,
                                 softmax=False)
        label_loss = loss_func(prediction=resampled_moving_label,
                               ground_truth=fixed_label)
        dice_fg = 1.0 - label_loss

        # appending regularisation loss
        total_loss = label_loss
        reg_loss = tf.get_collection('bending_energy')
        if reg_loss:
            total_loss = total_loss + \
                self.net_param.decay * tf.reduce_mean(reg_loss)
        self.total_loss = total_loss

        # compute training gradients
        with tf.name_scope('Optimiser'):
            optimiser_class = OptimiserFactory.create(
                name=self.action_param.optimiser)
            self.optimiser = optimiser_class.get_instance(
                learning_rate=self.action_param.lr)
        grads = self.optimiser.compute_gradients(
            total_loss, colocate_gradients_with_ops=True)
        gradients_collector.add_to_collection(grads)

        # Dice on binarised (>= 0.5) label volumes
        metrics_dice = loss_func(
            prediction=tf.to_float(resampled_moving_label >= 0.5),
            ground_truth=tf.to_float(fixed_label >= 0.5))
        metrics_dice = 1.0 - metrics_dice

        # command line output
        outputs_collector.add_to_collection(var=dice_fg,
                                            name='one_minus_data_loss',
                                            collection=CONSOLE)
        if reg_loss:
            # reduce_mean over an empty collection would report NaN
            outputs_collector.add_to_collection(
                var=tf.reduce_mean(reg_loss),
                name='bending_energy',
                collection=CONSOLE)
        outputs_collector.add_to_collection(var=total_loss,
                                            name='total_loss',
                                            collection=CONSOLE)
        outputs_collector.add_to_collection(var=metrics_dice,
                                            name='ave_fg_dice',
                                            collection=CONSOLE)

        # for tensorboard
        outputs_collector.add_to_collection(var=dice_fg,
                                            name='data_loss',
                                            average_over_devices=True,
                                            summary_type='scalar',
                                            collection=TF_SUMMARIES)
        outputs_collector.add_to_collection(var=total_loss,
                                            name='total_loss',
                                            average_over_devices=True,
                                            summary_type='scalar',
                                            collection=TF_SUMMARIES)
        outputs_collector.add_to_collection(
            var=metrics_dice,
            name='averaged_foreground_Dice',
            average_over_devices=True,
            summary_type='scalar',
            collection=TF_SUMMARIES)
    else:
        image_windows, locations = self.sampler()
        fixed_image, fixed_label, moving_image, moving_label = \
            split_windows(image_windows)

        dense_field = self.net(fixed_image, moving_image)
        if isinstance(dense_field, tuple):
            dense_field = dense_field[0]

        # transform the moving image and labels
        resampler = ResamplerLayer(interpolation='linear',
                                   boundary='replicate')
        resampled_moving_image = resampler(moving_image, dense_field)
        resampled_moving_label = resampler(moving_label, dense_field)

        network_outputs = (
            (fixed_image, 'fixed_image'),
            (moving_image, 'moving_image'),
            (resampled_moving_image, 'resampled_moving_image'),
            (resampled_moving_label, 'resampled_moving_label'),
            (fixed_label, 'fixed_label'),
            (moving_label, 'moving_label'),
            (locations, 'locations'),
        )
        for var, var_name in network_outputs:
            outputs_collector.add_to_collection(var=var,
                                                name=var_name,
                                                collection=NETWORK_OUTPUT)

        self.output_decoder = ResizeSamplesAggregator(
            image_reader=self.readers[0],  # fixed image reader
            name='fixed_image',
            output_path=self.action_param.save_seg_dir,
            interp_order=self.action_param.output_interp_order)
def layer_op(self):
    """Sample a batch of paired fixed/moving windows.

    Reads ``(image_id, fixed, moving, fixed_shape, moving_shape)`` from
    ``self.iterator`` and resizes the moving volumes to the fixed spatial
    shape of batch item 0.  Fixed and moving are concatenated on the
    channel axis; each carries one image and one label channel.  A window
    of ``self.window_size`` is then extracted at a random per-item
    translation, using a translation-only affine grid and linear
    resampling with a replicate boundary.

    Returns:
        windows: tensor of shape
            ``[batch_size] + window_size[:spatial_rank] + [window_channels]``.
        locations: tensor of shape ``(batch_size, 1 + 2 * spatial_rank)``
            holding the image id, zero start coordinates and the random
            shift along each spatial axis.

    Raises:
        NotImplementedError: if ``self.spatial_rank`` is not 3 (the affine
            constraints below are 3D-only).
    """
    if self.spatial_rank != 3:
        # Non-3D ranks previously failed later with an UnboundLocalError;
        # report the unsupported configuration clearly instead.
        raise NotImplementedError(
            'Pairwise window sampling only supports 3D volumes, '
            'got spatial_rank={}'.format(self.spatial_rank))

    image_id, fixed_inputs, moving_inputs, fixed_shape, moving_shape = \
        self.iterator.get_next()
    # TODO preprocessing layer modifying
    #      image shapes will not be supported
    # assuming the same shape across modalities, using the first
    image_id.set_shape((self.batch_size, ))
    image_id = tf.to_float(image_id)

    # last dim is 1 image + 1 label
    fixed_inputs.set_shape((self.batch_size, ) +
                           (None, ) * self.spatial_rank + (2, ))
    moving_inputs.set_shape((self.batch_size, ) +
                            self.moving_image_shape + (2, ))
    fixed_shape.set_shape((self.batch_size, self.spatial_rank + 1))
    moving_shape.set_shape((self.batch_size, self.spatial_rank + 1))

    # resizing the moving_inputs to match the target
    # assumes the same shape across the batch (only item 0 is read)
    target_spatial_shape = \
        tf.unstack(fixed_shape[0], axis=0)[:self.spatial_rank]
    moving_inputs = Resize(new_size=target_spatial_shape)(moving_inputs)
    combined_volume = tf.concat([fixed_inputs, moving_inputs], axis=-1)

    # TODO affine data augmentation here
    # 4 channels (2 fixed + 2 moving) per trailing window element
    window_channels = int(np.prod(self.window_size[self.spatial_rank:])) * 4

    # TODO if no affine augmentation:
    img_spatial_shape = target_spatial_shape
    win_spatial_shape = [
        tf.constant(dim) for dim in self.window_size[:self.spatial_rank]
    ]

    # when img==win make sure shift => 0.0
    # otherwise interpolation is out of bound
    batch_shift = [
        tf.random_uniform(shape=(self.batch_size, 1),
                          minval=0,
                          maxval=tf.maximum(tf.to_float(img - win - 1),
                                            0.01))
        for (win, img) in zip(win_spatial_shape, img_spatial_shape)
    ]
    batch_shift = tf.concat(batch_shift, axis=1)

    # translation-only warp: identity linear part, free translations
    affine_constraints = ((1.0, 0.0, 0.0, None),
                          (0.0, 1.0, 0.0, None),
                          (0.0, 0.0, 1.0, None))
    computed_grid = AffineGridWarperLayer(
        source_shape=(None, None, None),
        output_shape=self.window_size[:self.spatial_rank],
        constraints=affine_constraints)(batch_shift)
    computed_grid.set_shape((self.batch_size, ) +
                            self.window_size[:self.spatial_rank] +
                            (self.spatial_rank, ))
    resampler = ResamplerLayer(interpolation='linear', boundary='replicate')
    windows = resampler(combined_volume, computed_grid)
    out_shape = [self.batch_size] + \
        list(self.window_size[:self.spatial_rank]) + \
        [window_channels]
    windows.set_shape(out_shape)

    image_id = tf.reshape(image_id, (self.batch_size, 1))
    start_location = tf.zeros((self.batch_size, self.spatial_rank))
    locations = tf.concat([image_id, start_location, batch_shift], axis=1)
    return windows, locations