def layer_op(self, inputs, deformation, **kwargs):
    """Resample ``inputs`` at the positions given by ``deformation`` with
    the NiftyReg image-resampling kernel.

    :param inputs: image batch; first axis is batch, last axis channels.
    :param deformation: sampling field; when its first (batch) axis differs
        from that of ``inputs`` it is tiled ``batch_size`` times along it.
    :param kwargs: accepted for interface compatibility, not used.
    :return: tensor of shape
        ``[batch_size] + deformation spatial shape + [channels of inputs]``
        (an unknown or zero channel count is replaced by ``-1``).
    """
    in_rank = infer_spatial_rank(inputs)
    def_rank = infer_spatial_rank(deformation)
    n_batch = inputs.shape.as_list()[0]

    # Replicate the field over the batch axis only.
    if deformation.shape.as_list()[0] != n_batch:
        multiples = [n_batch] + [1] * (def_rank + 1)
        deformation = tf.tile(deformation, multiples)

    out_spatial_shape = deformation.shape.as_list()[1:-1]
    n_channels = inputs.shape.as_list()[-1] or -1

    # With fewer deformation spatial dims than input spatial dims, insert
    # singleton axes just before the last two axes until ranks agree.
    resample_def = deformation
    if len(out_spatial_shape) != in_rank:
        while len(resample_def.shape) < len(inputs.shape):
            resample_def = tf.expand_dims(
                resample_def, axis=len(resample_def.shape) - 2)
    assert infer_spatial_rank(resample_def) == in_rank

    resampled = get_niftyreg_module().niftyreg_image_resampling(
        _transpose(inputs),
        _transpose(resample_def),
        interpolation=self._interpolation,
        boundary=__BOUNDARY_CODES__[self._boundary])

    target_shape = [n_batch] + out_spatial_shape + [n_channels]
    return tf.reshape(_transpose(resampled), target_shape)
def _resample_linear(self, inputs, sample_coords):
    """Sample ``inputs`` at ``sample_coords`` by (multi-)linear interpolation.

    Each sample point reads its ``2 ** rank`` integer neighbours (floor and
    floor + 1 on every axis, each mapped through ``self.boundary_func``) and
    blends them by recursive pairwise linear combination.

    :param inputs: tensor ``[batch, *in_spatial, channels]``
    :param sample_coords: tensor ``[batch, *out_spatial, n_coords]`` of
        floating-point voxel coordinates into ``inputs``
    :return: tensor ``[batch, *out_spatial, channels]``
    """
    input_shape = inputs.shape.as_list()
    grid_size = input_shape[1:-1]
    rank_in = infer_spatial_rank(inputs)
    n_batch = input_shape[0]
    rank_out = infer_spatial_rank(sample_coords)
    grid_out = sample_coords.shape.as_list()[1:-1]

    if rank_in == 2 and self.boundary == 'ZERO':
        # 2-D zero-padded sampling is delegated to tf.contrib; axes 1 and 2
        # are swapped first (presumably to match that op's (x, y) order).
        return tf.contrib.resampler.resampler(
            tf.transpose(inputs, [0, 2, 1, 3]), sample_coords)

    components = tf.unstack(sample_coords, axis=-1)
    lower = [tf.floor(c) for c in components]
    idx_lo = [
        tf.cast(self.boundary_func(c, grid_size[axis]), COORDINATES_TYPE)
        for axis, c in enumerate(lower)]
    idx_hi = [
        tf.cast(self.boundary_func(c + 1.0, grid_size[axis]),
                COORDINATES_TYPE)
        for axis, c in enumerate(lower)]

    # w_hi multiplies the upper neighbour, w_lo the lower one.
    if self.boundary != 'ZERO':
        w_hi = [tf.expand_dims(c - f, -1) for c, f in zip(components, lower)]
        w_lo = [1.0 - w for w in w_hi]
    else:
        # Distances measured against the boundary-mapped indices.
        w_hi = [tf.expand_dims(c - tf.cast(i, tf.float32), -1)
                for c, i in zip(components, idx_lo)]
        w_lo = [tf.expand_dims(tf.cast(i, tf.float32) - c, -1)
                for c, i in zip(components, idx_hi)]

    batch_index = tf.tile(
        tf.reshape(tf.range(n_batch), [n_batch] + [1] * rank_out),
        [1] + grid_out)
    neighbours = (idx_lo, idx_hi)

    def gather_corner(bits):
        # bits[k] picks the lower (0) or upper (1) neighbour on axis k.
        picked = [neighbours[b][k] for k, b in enumerate(bits)]
        return tf.gather_nd(inputs, tf.stack([batch_index] + picked, -1))

    def blend(corners, hi_weights, lo_weights):
        # corners[::2] / corners[1::2] differ only in the last axis' bit,
        # so each level combines along the last remaining axis.
        if len(hi_weights) == 1:
            return corners[0] * lo_weights[0] + corners[1] * hi_weights[0]
        lower_half = blend(corners[::2], hi_weights[:-1], lo_weights[:-1])
        upper_half = blend(corners[1::2], hi_weights[:-1], lo_weights[:-1])
        return lower_half * lo_weights[-1] + upper_half * hi_weights[-1]

    corner_codes = [[int(ch) for ch in format(n, '0%ib' % rank_in)]
                    for n in range(2 ** rank_in)]
    corners = [gather_corner(code) for code in corner_codes]
    return blend(corners, w_hi, w_lo)
def _resample_linear(self, inputs, sample_coords):
    """Linear (bi-/tri-linear) interpolation of ``inputs`` at
    ``sample_coords``.

    For every point, the floor and floor + 1 voxel indices on each axis are
    passed through ``self.boundary_func``; the ``2 ** rank`` corner values
    are gathered and merged pairwise, weighted by the fractional offsets.

    :param inputs: tensor ``[batch, *in_spatial, channels]``
    :param sample_coords: tensor ``[batch, *out_spatial, n_coords]``
    :return: tensor ``[batch, *out_spatial, channels]``
    """
    shape_in = inputs.get_shape().as_list()
    spatial_in = shape_in[1:-1]
    ndim_in = infer_spatial_rank(inputs)
    batch = shape_in[0]
    ndim_out = infer_spatial_rank(sample_coords)
    spatial_out = sample_coords.get_shape().as_list()[1:-1]

    if ndim_in == 2 and self.boundary == 'ZERO':
        # Delegated to tf.contrib with axes 1 and 2 swapped beforehand.
        transposed = tf.transpose(inputs, [0, 2, 1, 3])
        return tf.contrib.resampler.resampler(transposed, sample_coords)

    per_axis = tf.unstack(sample_coords, axis=-1)
    floors = [tf.floor(v) for v in per_axis]

    below = []
    for axis, f in enumerate(floors):
        below.append(tf.cast(self.boundary_func(f, spatial_in[axis]),
                             COORDINATES_TYPE))
    above = []
    for axis, f in enumerate(floors):
        above.append(tf.cast(self.boundary_func(f + 1.0, spatial_in[axis]),
                             COORDINATES_TYPE))

    # frac_up weights the upper corner, frac_down the lower corner.
    if self.boundary == 'ZERO':
        frac_up = [tf.expand_dims(v - tf.cast(b, tf.float32), -1)
                   for v, b in zip(per_axis, below)]
        frac_down = [tf.expand_dims(tf.cast(a, tf.float32) - v, -1)
                     for v, a in zip(per_axis, above)]
    else:
        frac_up = [tf.expand_dims(v - f, -1)
                   for v, f in zip(per_axis, floors)]
        frac_down = [1.0 - u for u in frac_up]

    batch_grid = tf.reshape(tf.range(batch), [batch] + [1] * ndim_out)
    batch_grid = tf.tile(batch_grid, [1] + spatial_out)
    choices = (below, above)

    def read_corner(code):
        # code[k] == 0 -> lower corner on axis k, 1 -> upper corner.
        index_parts = [batch_grid]
        for k, bit in enumerate(code):
            index_parts.append(choices[bit][k])
        return tf.gather_nd(inputs, tf.stack(index_parts, -1))

    def merge(values, up, down):
        # Recurse over all but the last axis, then combine along it.
        if len(up) == 1:
            return values[0] * down[0] + values[1] * up[0]
        even = merge(values[::2], up[:-1], down[:-1])
        odd = merge(values[1::2], up[:-1], down[:-1])
        return even * down[-1] + odd * up[-1]

    codes = []
    for n in range(2 ** ndim_in):
        codes.append([int(ch) for ch in format(n, '0%ib' % ndim_in)])
    corner_values = [read_corner(code) for code in codes]
    return merge(corner_values, frac_up, frac_down)
def _resample_bspline(self, inputs, sample_coords):
    """Sample ``inputs`` at ``sample_coords`` with cubic B-spline
    interpolation (3-D inputs only).

    For each sample point the 4x4x4 neighbourhood at offsets -1..2 from
    ``floor(coord)`` is mapped through ``self.boundary_func``, gathered, and
    blended with separable cubic B-spline weights.

    :param inputs: tensor ``[batch, X, Y, Z, channels]`` whose static shape
        is fully defined
    :param sample_coords: tensor ``[batch, *out_spatial, 3]`` of voxel
        coordinates into ``inputs``; same batch size as ``inputs``
    :return: tensor ``[batch, *out_spatial, channels]``
    :raises NotImplementedError: if ``inputs`` does not have exactly three
        spatial dimensions
    """
    assert inputs.shape.is_fully_defined(), \
        "input shape should be fully defined for bspline interpolation"
    in_size = inputs.shape.as_list()
    batch_size = in_size[0]
    in_spatial_size = in_size[1:-1]
    in_spatial_rank = infer_spatial_rank(inputs)
    out_spatial_rank = infer_spatial_rank(sample_coords)
    if in_spatial_rank == 2:
        raise NotImplementedError(
            'bspline interpolation not implemented for 2d yet')
    if in_spatial_rank != 3:
        # The neighbourhood grid and the weight reshapes below are
        # hard-coded for three spatial axes; other ranks would be silently
        # mis-reshaped or fail with an obscure shape error.
        raise NotImplementedError(
            'bspline interpolation only implemented for 3d, '
            'got spatial rank {}'.format(in_spatial_rank))
    assert batch_size == int(sample_coords.get_shape()[0])

    floor_coords = tf.floor(sample_coords)

    # Compute voxels to use for interpolation: offsets -1..2 on each axis.
    grid = tf.meshgrid([-1., 0., 1., 2.], [-1., 0., 1., 2.],
                       [-1., 0., 1., 2.], indexing='ij')
    offset_shape = [1, -1] + [1] * out_spatial_rank + [in_spatial_rank]
    offsets = tf.reshape(tf.stack(grid, 3), offset_shape)
    spatial_coords = offsets + tf.expand_dims(floor_coords, 1)
    spatial_coords = self.boundary_func(spatial_coords, in_spatial_size)
    spatial_coords = tf.cast(spatial_coords, COORDINATES_TYPE)
    knot_size = spatial_coords.shape.as_list()

    # Compute weights for each voxel (uniform cubic B-spline basis).
    def build_coef(u, d):
        coeff_list = [
            tf.pow(1 - u, 3),
            3 * tf.pow(u, 3) - 6 * tf.pow(u, 2) + 4,
            -3 * tf.pow(u, 3) + 3 * tf.pow(u, 2) + 3 * u + 1,
            tf.pow(u, 3)]
        return tf.concat(coeff_list, d) / 6

    weight = tf.reshape(sample_coords - floor_coords, [batch_size, -1, 3])
    coef_shape = [batch_size, 1, 1, 1, -1]
    Bu = build_coef(tf.reshape(weight[:, :, 0], coef_shape), 1)
    Bv = build_coef(tf.reshape(weight[:, :, 1], coef_shape), 2)
    Bw = build_coef(tf.reshape(weight[:, :, 2], coef_shape), 3)
    # Outer product [batch, 4, 4, 4, n_points], flattened in the same
    # 'ij' order as the offsets above.
    all_weights = tf.reshape(Bu * Bv * Bw,
                             [batch_size] + knot_size[1:-1] + [1])

    # Gather voxel values and compute weighted sum.
    batch_coords = tf.reshape(
        tf.range(batch_size), [batch_size] + [1] * (len(knot_size) - 1))
    batch_coords = tf.tile(batch_coords, [1] + knot_size[1:-1] + [1])
    raw_samples = tf.gather_nd(
        inputs, tf.concat([batch_coords, spatial_coords], -1))
    # ``axis`` replaces the deprecated ``reduction_indices`` alias.
    return tf.reduce_sum(all_weights * raw_samples, axis=1)
def _resample_bspline(self, inputs, sample_coords):
    """Cubic B-spline interpolation of ``inputs`` at ``sample_coords``
    (3-D inputs only).

    Each sample reads the 4x4x4 voxels at offsets -1..2 around
    ``floor(coord)`` (mapped through ``self.boundary_func``) and sums them
    weighted by separable cubic B-spline coefficients.

    :param inputs: tensor ``[batch, X, Y, Z, channels]`` with a fully
        defined static shape
    :param sample_coords: tensor ``[batch, *out_spatial, 3]``
    :return: tensor ``[batch, *out_spatial, channels]``
    :raises NotImplementedError: unless ``inputs`` has three spatial dims
    """
    assert inputs.shape.is_fully_defined(), \
        "input shape should be fully defined for bspline interpolation"
    in_size = inputs.shape.as_list()
    batch_size = in_size[0]
    in_spatial_size = in_size[1:-1]
    in_spatial_rank = infer_spatial_rank(inputs)
    out_spatial_rank = infer_spatial_rank(sample_coords)
    if in_spatial_rank == 2:
        raise NotImplementedError(
            'bspline interpolation not implemented for 2d yet')
    if in_spatial_rank != 3:
        # Grid, weight reshape and coef_shape below assume 3 spatial axes.
        raise NotImplementedError(
            'bspline interpolation only implemented for 3d, '
            'got spatial rank {}'.format(in_spatial_rank))
    assert batch_size == int(sample_coords.get_shape()[0])

    floor_coords = tf.floor(sample_coords)

    # Compute voxels to use for interpolation.
    grid = tf.meshgrid([-1., 0., 1., 2.], [-1., 0., 1., 2.],
                       [-1., 0., 1., 2.], indexing='ij')
    offset_shape = [1, -1] + [1] * out_spatial_rank + [in_spatial_rank]
    offsets = tf.reshape(tf.stack(grid, 3), offset_shape)
    spatial_coords = offsets + tf.expand_dims(floor_coords, 1)
    spatial_coords = self.boundary_func(spatial_coords, in_spatial_size)
    spatial_coords = tf.cast(spatial_coords, COORDINATES_TYPE)
    knot_size = spatial_coords.shape.as_list()

    # Compute weights for each voxel (uniform cubic B-spline basis).
    def build_coef(u, d):
        coeff_list = [tf.pow(1 - u, 3),
                      3 * tf.pow(u, 3) - 6 * tf.pow(u, 2) + 4,
                      -3 * tf.pow(u, 3) + 3 * tf.pow(u, 2) + 3 * u + 1,
                      tf.pow(u, 3)]
        return tf.concat(coeff_list, d) / 6

    weight = tf.reshape(sample_coords - floor_coords, [batch_size, -1, 3])
    coef_shape = [batch_size, 1, 1, 1, -1]
    Bu = build_coef(tf.reshape(weight[:, :, 0], coef_shape), 1)
    Bv = build_coef(tf.reshape(weight[:, :, 1], coef_shape), 2)
    Bw = build_coef(tf.reshape(weight[:, :, 2], coef_shape), 3)
    all_weights = tf.reshape(Bu * Bv * Bw,
                             [batch_size] + knot_size[1:-1] + [1])

    # Gather voxel values and compute weighted sum.
    batch_coords = tf.reshape(
        tf.range(batch_size), [batch_size] + [1] * (len(knot_size) - 1))
    batch_coords = tf.tile(batch_coords, [1] + knot_size[1:-1] + [1])
    raw_samples = tf.gather_nd(
        inputs, tf.concat([batch_coords, spatial_coords], -1))
    # ``axis`` replaces the deprecated ``reduction_indices`` alias.
    return tf.reduce_sum(all_weights * raw_samples, axis=1)
def layer_op(self, inputs):
    """Remove ``self.border`` voxels from both ends of every spatial axis.

    Batch and channel axes are kept whole. ``int()`` is applied to each
    spatial dimension, so they presumably need to be statically known.

    :param inputs: tensor ``[batch, *spatial, channels]``
    :return: the cropped tensor
    """
    n_spatial = layer_util.infer_spatial_rank(inputs)
    border = int(self.border)
    begin = [0] + [border] * n_spatial + [0]
    cropped_spatial = [int(dim) - 2 * border
                       for dim in list(inputs.shape)[1:-1]]
    # -1 keeps the full extent of the batch and channel axes.
    size = [-1] + cropped_spatial + [-1]
    return tf.slice(inputs, begin, size)
def resample_bspline(self, inputs, sample_coords):
    """Cubic B-spline interpolation of ``inputs`` at ``sample_coords``
    (3-D inputs only).

    Each sample reads the 4x4x4 voxels at integer offsets -1..2 around
    ``floor(coord)``, maps them through ``self.boundary_func_`` and sums
    them weighted by separable cubic B-spline coefficients.

    :param inputs: tensor ``[batch, X, Y, Z, channels]``
    :param sample_coords: tensor ``[batch, *grid, 3]`` of voxel coordinates
    :return: tensor ``[batch, *grid, channels]``
    :raises NotImplementedError: unless ``inputs`` has three spatial dims
    """
    input_size = tf.reshape(
        inputs.get_shape().as_list()[1:-1],
        [1] * (len(sample_coords.get_shape().as_list()) - 1) + [-1])
    spatial_rank = layer_util.infer_spatial_rank(inputs)
    batch_size = sample_coords.get_shape().as_list()[0]
    grid_shape = sample_coords.get_shape().as_list()[1:-1]
    if spatial_rank == 2:
        raise NotImplementedError(
            'bspline interpolation not implemented for 2d yet')
    if spatial_rank != 3:
        # The 3-way meshgrid and the [batch, -1, 3] weight reshape below
        # only make sense for three spatial axes.
        raise NotImplementedError(
            'bspline interpolation only implemented for 3d, '
            'got spatial rank {}'.format(spatial_rank))

    index_voxel_coords = tf.floor(sample_coords)

    # Compute voxels to use for interpolation.
    grid = tf.meshgrid(list(range(-1, 3)), list(range(-1, 3)),
                       list(range(-1, 3)), indexing='ij')
    offsets = tf.reshape(
        tf.stack(grid, 3),
        [1, 4 ** spatial_rank] + [1] * len(grid_shape) + [spatial_rank])
    preboundary_spatial_coords = offsets + tf.expand_dims(
        tf.cast(index_voxel_coords, tf.int32), 1)
    spatial_coords = self.boundary_func_(preboundary_spatial_coords,
                                         input_size)
    sz = spatial_coords.get_shape().as_list()

    # Compute weights for each voxel (uniform cubic B-spline basis).
    def build_coefficient(u, d):
        return tf.concat([tf.pow(1 - u, 3),
                          3 * tf.pow(u, 3) - 6 * tf.pow(u, 2) + 4,
                          -3 * tf.pow(u, 3) + 3 * tf.pow(u, 2) + 3 * u + 1,
                          tf.pow(u, 3)], d) / 6

    weight = tf.reshape(sample_coords - index_voxel_coords,
                        [batch_size, -1, 3])
    Bu = build_coefficient(
        tf.reshape(weight[:, :, 0], [batch_size, 1, 1, 1, -1]), 1)
    Bv = build_coefficient(
        tf.reshape(weight[:, :, 1], [batch_size, 1, 1, 1, -1]), 2)
    Bw = build_coefficient(
        tf.reshape(weight[:, :, 2], [batch_size, 1, 1, 1, -1]), 3)
    all_weights = tf.reshape(Bu * Bv * Bw, [batch_size] + sz[1:-1] + [1])

    # Gather voxel values and compute weighted sum.
    batch_coords = tf.tile(
        tf.reshape(tf.range(sz[0]), [sz[0]] + [1] * (len(sz) - 1)),
        [1] + sz[1:-1] + [1])
    raw_samples = tf.gather_nd(
        inputs, tf.concat([batch_coords, spatial_coords], -1))
    # ``axis`` replaces the deprecated ``reduction_indices`` alias.
    return tf.reduce_sum(all_weights * raw_samples, axis=1)
def resample_linear(self, inputs, sample_coords):
    """Linear interpolation of ``inputs`` at ``sample_coords``.

    The floor and floor + 1 integer neighbours on every axis are mapped
    through ``self.boundary_func_``; the ``2 ** rank`` corner values are
    gathered and combined pairwise with the fractional weights.

    :param inputs: tensor ``[batch, *in_spatial, channels]``
    :param sample_coords: tensor ``[batch, *out_spatial, n_coords]``
    :return: tensor ``[batch, *out_spatial, channels]``
    """
    input_size = inputs.get_shape().as_list()[1:-1]
    spatial_rank = layer_util.infer_spatial_rank(inputs)
    xy = tf.unstack(sample_coords, axis=len(sample_coords.get_shape()) - 1)
    index_voxel_coords = [tf.floor(x) for x in xy]
    spatial_coords = [
        self.boundary_func_(tf.cast(x, tf.int32), input_size[idx])
        for idx, x in enumerate(index_voxel_coords)]
    spatial_coords_plus1 = [
        self.boundary_func_(tf.cast(x + 1., tf.int32), input_size[idx])
        for idx, x in enumerate(index_voxel_coords)]

    # ``weight`` multiplies the upper (floor + 1) corner and ``weight_c``
    # the lower one; both must be defined on every branch because
    # ``pyramid_combination`` consumes them below.
    if self.boundary == 'ZERO':
        weight = [tf.expand_dims(x - tf.cast(i, tf.float32), -1)
                  for x, i in zip(xy, spatial_coords)]
        weight_c = [tf.expand_dims(tf.cast(i, tf.float32) - x, -1)
                    for x, i in zip(xy, spatial_coords_plus1)]
    else:
        weight = [tf.expand_dims(x - i, -1)
                  for x, i in zip(xy, index_voxel_coords)]
        weight_c = [1. - w for w in weight]

    sz = spatial_coords[0].get_shape().as_list()
    batch_coords = tf.tile(
        tf.reshape(tf.range(sz[0]), [sz[0]] + [1] * (len(sz) - 1)),
        [1] + sz[1:])
    sc = (spatial_coords, spatial_coords_plus1)
    binary_codes = [[int(c) for c in format(i, '0%ib' % spatial_rank)]
                    for i in range(2 ** spatial_rank)]

    def make_sample(bc):
        # bc[i] == 0 picks the floor neighbour on axis i, 1 the upper one.
        index = [batch_coords] + [sc[c][i] for i, c in enumerate(bc)]
        return tf.gather_nd(inputs, tf.stack(index, -1))

    samples = [make_sample(bc) for bc in binary_codes]

    def pyramid_combination(samples, weight, weight_c):
        if len(weight) == 1:
            return samples[0] * weight_c[0] + samples[1] * weight[0]
        return (pyramid_combination(samples[::2], weight[:-1],
                                    weight_c[:-1]) * weight_c[-1]
                + pyramid_combination(samples[1::2], weight[:-1],
                                      weight_c[:-1]) * weight[-1])

    return pyramid_combination(samples, weight, weight_c)
def do_conv(input_tensor, dim):
    """Gaussian-smooth ``input_tensor`` along spatial axes ``0..dim``.

    The smoothing is separable. Axes ``0..dim-1`` are handled first by
    recursion. Each channel is then convolved on its own with a 1-D
    ``gaussian_1d`` kernel (sigma 1.5, truncation parameter 3.0) laid
    along spatial axis ``dim``. ``VALID`` padding is used, so the
    smoothed axes shrink.

    :param input_tensor: tensor ``[batch, *spatial, channels]``
    :param dim: index of the last spatial axis to smooth; a negative value
        returns ``input_tensor`` unchanged
    :return: the smoothed tensor
    """
    n_spatial = infer_spatial_rank(input_tensor)
    assert dim < n_spatial
    if dim < 0:
        return input_tensor

    sigmas = expand_spatial_params(
        input_param=1.5, spatial_rank=n_spatial, param_type=float)
    truncations = expand_spatial_params(
        input_param=3.0, spatial_rank=n_spatial, param_type=float)

    # Filter layout is [*spatial, in_chns, out_chns]; only entry ``dim`` is
    # non-singleton, giving a 1-D kernel along that spatial axis.
    kernel_shape = [1] * (n_spatial + 2)
    kernel_shape[dim] = -1
    kernel = tf.reshape(
        gaussian_1d(sigma=sigmas[dim], truncated=truncations[dim]),
        kernel_shape)

    partially_smoothed = do_conv(input_tensor, dim - 1)
    smoothed_channels = []
    for channel in tf.unstack(partially_smoothed, axis=-1):
        smoothed_channels.append(
            tf.nn.convolution(input=tf.expand_dims(channel, axis=-1),
                              filter=kernel,
                              padding='VALID',
                              strides=[1] * n_spatial))
    return tf.concat(smoothed_channels, axis=-1)
def ftheta(U, H1, permutohedrals, mu, kernel_weights, aspect_ratio, name):
    """One mean-field update of the CRF-as-RNN layer.

    :param U: unary potentials (logits) ``[batch, *spatial, n_classes]``
    :param H1: current logits; softmax is applied before filtering
    :param permutohedrals: prepared lattices, one per filtering kernel
    :param mu: filter for the compatibility transform (applied with
        ``'SAME'`` padding and unit strides)
    :param kernel_weights: one multiplicative weight per lattice output
    :param aspect_ratio: accepted for interface compatibility, not used
    :param name: prefix for the per-lattice filtering op names
    :return: ``U`` minus the compatibility-transformed, weighted filter
        responses; these are logits, not softmax probabilities
    :raises NotImplementedError: for spatial ranks other than 2 and 3
    """
    n_classes = U.shape.as_list()[-1]
    n_batch = int(U.shape[0])

    # Message passing: filter the current soft labels through each lattice.
    probabilities = tf.reshape(tf.nn.softmax(H1), [n_batch, -1, n_classes])
    with tf.device('/cpu:0'):
        filtered = [
            tf.reshape(
                permutohedral_gen(lattice, probabilities, name + str(idx)),
                U.get_shape())
            for idx, lattice in enumerate(permutohedrals)]

    # Weight and sum the filter outputs.
    weighted_sum = tf.add_n(
        [response * w for response, w in zip(filtered, kernel_weights)])

    # Compatibility transform.
    rank = infer_spatial_rank(U)
    if rank == 2:
        compat = tf.nn.conv2d(weighted_sum, mu,
                              strides=[1, 1, 1, 1], padding='SAME')
    elif rank == 3:
        compat = tf.nn.conv3d(weighted_sum, mu,
                              strides=[1, 1, 1, 1, 1], padding='SAME')
    else:
        raise NotImplementedError(
            'CRFAsRNNLayer is only implemented for 2d and 3d images.')

    # Add unary potentials (output logits, not the softmax).
    return U - compat
def layer_op(self, param_flow, bypass_flow):
    """Merge a parametric path with a bypass path.

    ``self.func == 'SUM'``: element-wise sum. A bypass with fewer channels
    is zero-padded on the channel axis (split as evenly as possible, the
    extra channel at the end). A bypass with more channels is projected
    by a kernel-size-1 ``ConvLayer``.
    ``self.func == 'CONCAT'``: concatenation along the channel axis.
    Any other value returns ``param_flow`` unchanged.

    :param param_flow: tensor ``[batch, *spatial, channels]``
    :param bypass_flow: tensor with matching batch and spatial shape
    :return: the merged tensor
    """
    n_param_flow = param_flow.shape[-1]
    n_bypass_flow = bypass_flow.shape[-1]
    spatial_rank = layer_util.infer_spatial_rank(param_flow)

    output_tensor = param_flow
    if self.func == 'SUM':
        if n_param_flow > n_bypass_flow:
            # Pad the channel dim only. The builtin ``int`` is used because
            # the ``np.int`` alias was removed in NumPy 1.24.
            n_missing = int(n_param_flow - n_bypass_flow)
            pad_1 = n_missing // 2
            pad_2 = n_missing - pad_1
            padding_dims = ([[0, 0] for _ in range(spatial_rank + 1)]
                            + [[pad_1, pad_2]])
            bypass_flow = tf.pad(tensor=bypass_flow,
                                 paddings=padding_dims,
                                 mode='CONSTANT')
        elif n_param_flow < n_bypass_flow:
            # make a projection
            projector = ConvLayer(n_output_chns=n_param_flow,
                                  kernel_size=1,
                                  stride=1,
                                  padding='SAME',
                                  w_initializer=self.initializers['w'],
                                  w_regularizer=self.regularizers['w'],
                                  name='proj')
            bypass_flow = projector(bypass_flow)

        # element-wise sum of both paths
        output_tensor = param_flow + bypass_flow

    elif self.func == 'CONCAT':
        output_tensor = tf.concat([param_flow, bypass_flow], axis=-1)

    return output_tensor
def layer_op(self, param_flow, bypass_flow):
    """Combine ``param_flow`` and ``bypass_flow`` according to
    ``self.func``.

    - ``'SUM'``: element-wise sum. The bypass is zero-padded on the channel
      axis when it has fewer channels (extra padding channel at the end).
      When it has more channels it is projected by a kernel-size-1
      ``ConvLayer``.
    - ``'CONCAT'``: concatenation along the last (channel) axis.
    - anything else: ``param_flow`` is returned unchanged.

    :param param_flow: tensor ``[batch, *spatial, channels]``
    :param bypass_flow: tensor with matching batch and spatial shape
    :return: the merged tensor
    """
    n_param_flow = param_flow.get_shape()[-1]
    n_bypass_flow = bypass_flow.get_shape()[-1]
    spatial_rank = layer_util.infer_spatial_rank(param_flow)

    output_tensor = param_flow
    if self.func == 'SUM':
        if n_param_flow > n_bypass_flow:
            # Pad the channel dim. ``np.int`` no longer exists in
            # NumPy >= 1.24, so plain ``int`` is used.
            n_missing = int(n_param_flow - n_bypass_flow)
            pad_1 = n_missing // 2
            pad_2 = n_missing - pad_1
            padding_dims = ([[0, 0] for _ in range(spatial_rank + 1)]
                            + [[pad_1, pad_2]])
            bypass_flow = tf.pad(tensor=bypass_flow,
                                 paddings=padding_dims,
                                 mode='CONSTANT')
        elif n_param_flow < n_bypass_flow:
            # make a projection
            projector = ConvLayer(n_output_chns=n_param_flow,
                                  kernel_size=1,
                                  stride=1,
                                  padding='SAME',
                                  w_initializer=self.initializers['w'],
                                  w_regularizer=self.regularizers['w'],
                                  name='proj')
            bypass_flow = projector(bypass_flow)

        # element-wise sum of both paths
        output_tensor = param_flow + bypass_flow

    elif self.func == 'CONCAT':
        output_tensor = tf.concat([param_flow, bypass_flow], axis=-1)

    return output_tensor
def layer_op(self, I, U):
    """
    Parameters:
        I: feature maps defining the non-spatial dimensions within which
            smoothing is performed. For example, to smooth U within
            regions of similar intensity this would be the image intensity
        U: activation maps to smooth
    Returns:
        the logits after ``self._T`` mean-field iterations (``U`` itself
        when ``self._T`` is 0)
    """
    spatial_dim = infer_spatial_rank(U)
    if self._aspect_ratio is None:
        self._aspect_ratio = [1.] * spatial_dim

    batch_size = int(U.shape[0])
    H1 = [U]

    # Build permutohedral structures for smoothing: a voxel-index grid,
    # scaled per axis by the aspect ratio, replicated over the batch.
    axes = [numpy.array(range(int(i)), dtype=numpy.float32) * a
            for i, a in zip(U.shape[1:spatial_dim + 1], self._aspect_ratio)]
    coords = tf.tile(
        tf.expand_dims(
            tf.stack(tf.meshgrid(*axes, indexing='ij'), spatial_dim), 0),
        [batch_size] + [1] * spatial_dim + [1])
    bilateralCoords = tf.reshape(
        tf.concat([coords / self._alpha, I / self._beta], -1),
        [batch_size, -1, int(I.shape[-1]) + spatial_dim])
    spatialCoords = tf.reshape(coords / self._gamma,
                               [batch_size, -1, spatial_dim])
    kernel_coords = [bilateralCoords, spatialCoords]
    permutohedrals = [permutohedral_prepare(c) for c in kernel_coords]

    nCh = int(U.shape[-1])
    # Compatibility transform initialised to the identity.
    mu = tf.get_variable(
        'Compatibility',
        initializer=tf.constant(
            numpy.reshape(numpy.eye(nCh), [1] * spatial_dim + [nCh, nCh]),
            dtype=tf.float32))
    kernel_weights = [
        tf.get_variable("FilterWeights" + str(idx),
                        shape=[1] * spatial_dim + [1, nCh],
                        initializer=tf.zeros_initializer())
        for idx in range(len(permutohedrals))]

    for t in range(self._T):
        H1.append(
            ftheta(U, H1[-1], permutohedrals, mu, kernel_weights,
                   aspect_ratio=self._aspect_ratio,
                   name=self._name + str(t)))
    return H1[-1]
def layer_op(self, input_tensor):
    """Downsample ``input_tensor`` with pooling or a constant kernel.

    ``self.func == 'CONSTANT'`` convolves every channel separately with
    ``layer_util.trivial_kernel``. Any other op is passed to
    ``tf.nn.pool`` as ``pooling_type``. ``self.func`` is first checked by
    ``look_up_operations`` against ``SUPPORTED_OP``.

    :param input_tensor: tensor ``[batch, *spatial, channels]``
    :return: the downsampled tensor
    """
    rank = layer_util.infer_spatial_rank(input_tensor)
    look_up_operations(self.func, SUPPORTED_OP)
    window = layer_util.expand_spatial_params(self.kernel_size, rank)
    steps = layer_util.expand_spatial_params(self.stride, rank)

    if self.func != 'CONSTANT':
        return tf.nn.pool(input=input_tensor,
                          window_shape=window,
                          pooling_type=self.func,
                          padding=self.padding,
                          dilation_rate=[1] * rank,
                          strides=steps,
                          name=self.layer_name)

    # One single-in/single-out filter shared by every channel.
    kernel = tf.constant(layer_util.trivial_kernel(window + (1, 1)),
                         dtype=tf.float32)
    channels = [tf.expand_dims(ch, -1)
                for ch in tf.unstack(input_tensor, axis=-1)]
    filtered = [tf.nn.convolution(input=ch,
                                  filter=kernel,
                                  strides=steps,
                                  padding=self.padding,
                                  name='conv')
                for ch in channels]
    return tf.concat(filtered, axis=-1)
def layer_op(self, input_tensor, is_training):
    """Pre-activation residual unit.

    Runs ``len(self.kernels)`` repetitions of batch normalisation ->
    ``self.prelu`` -> convolution, then adds the result to the input with
    ``ElementwiseLayer('SUM')``. That layer pads or projects the bypass
    channels when they differ from the conv path.

    Batch-norm moving-average updates are placed in
    ``tf.GraphKeys.UPDATE_OPS`` by ``tf.layers.batch_normalization`` and
    must be run by the caller during training.

    :param input_tensor: tensor ``[batch, *spatial, channels]``
    :param is_training: forwarded to batch normalisation as ``training``
    :return: the residual sum
    """
    spatial_rank = layer_util.infer_spatial_rank(input_tensor)
    output_tensor = input_tensor
    for i in range(len(self.kernels)):
        # Channels of the *current* activation: after the first convolution
        # this is ``self.n_output_chns``, not the block input's count.
        n_input_chns = output_tensor.shape.as_list()[-1]
        w_full_size = layer_util.expand_spatial_params(
            self.kernel_size, spatial_rank)
        w_full_size = w_full_size + (n_input_chns, self.n_output_chns)
        # Unique per-iteration names; reusing 'w' in one scope makes
        # tf.get_variable raise on the second iteration.
        conv_kernel = tf.get_variable('w_{}'.format(i),
                                      shape=w_full_size,
                                      initializer=self.initializers['w'],
                                      regularizer=self.regularizers['w'])

        output_tensor = tf.layers.batch_normalization(
            inputs=output_tensor,
            training=is_training,
            name='bn_{}'.format(i))
        # Activate the normalised tensor, not the raw block input.
        # NOTE(review): no separate alpha variable is created here;
        # ``self.prelu`` is assumed to own its parameters - TODO confirm.
        output_tensor = self.prelu(output_tensor, name='acti_{}'.format(i))
        output_tensor = tf.nn.convolution(input=output_tensor,
                                          filter=conv_kernel,
                                          strides=self.strides,
                                          dilation_rate=self.dilation_rates,
                                          padding=self.padding,
                                          name='conv_{}'.format(i))

    output_tensor = ElementwiseLayer('SUM')(output_tensor, input_tensor)
    return output_tensor
def layer_op(self, input_tensor):
    """Spatial downsampling layer.

    With ``self.func == 'CONSTANT'`` each channel is filtered separately
    with the ``layer_util.trivial_kernel`` of size ``self.kernel_size``.
    Otherwise ``tf.nn.pool`` is called with ``pooling_type=self.func``.

    :param input_tensor: tensor ``[batch, *spatial, channels]``
    :return: the downsampled tensor
    """
    n_dims = layer_util.infer_spatial_rank(input_tensor)
    look_up_operations(self.func, SUPPORTED_OP)
    kernel_dims = layer_util.expand_spatial_params(self.kernel_size, n_dims)
    stride_dims = layer_util.expand_spatial_params(self.stride, n_dims)

    if self.func == 'CONSTANT':
        weights = tf.constant(
            layer_util.trivial_kernel(kernel_dims + (1, 1)),
            dtype=tf.float32)
        single_channels = [tf.expand_dims(x, -1)
                           for x in tf.unstack(input_tensor, axis=-1)]

        def _filter(x):
            return tf.nn.convolution(input=x,
                                     filter=weights,
                                     strides=stride_dims,
                                     padding=self.padding,
                                     name='conv')

        result = tf.concat([_filter(x) for x in single_channels], axis=-1)
    else:
        result = tf.nn.pool(input=input_tensor,
                            window_shape=kernel_dims,
                            pooling_type=self.func,
                            padding=self.padding,
                            dilation_rate=[1] * n_dims,
                            strides=stride_dims,
                            name=self.layer_name)
    return result
def layer_op(self, I, U):
    """
    Parameters:
        I: feature maps defining the non-spatial dimensions within which
            smoothing is performed. For example, to smooth U within
            regions of similar intensity this would be the image intensity
        U: activation maps to smooth
    Returns:
        the logits after ``self._T`` mean-field iterations (``U`` itself
        when ``self._T`` is 0)
    """
    spatial_dim = infer_spatial_rank(U)
    if self._aspect_ratio is None:
        self._aspect_ratio = [1.] * spatial_dim

    batch_size = int(U.shape[0])
    H1 = [U]

    # Build permutohedral structures for smoothing: voxel-index grid scaled
    # per axis by the aspect ratio, tiled over the batch.
    coords = tf.tile(
        tf.expand_dims(
            tf.stack(
                tf.meshgrid(*[
                    numpy.array(range(int(i)), dtype=numpy.float32) * a
                    for i, a in zip(U.shape[1:spatial_dim + 1],
                                    self._aspect_ratio)
                ], indexing='ij'), spatial_dim), 0),
        [batch_size] + [1] * spatial_dim + [1])
    bilateralCoords = tf.reshape(
        tf.concat([coords / self._alpha, I / self._beta], -1),
        [batch_size, -1, int(I.shape[-1]) + spatial_dim])
    spatialCoords = tf.reshape(coords / self._gamma,
                               [batch_size, -1, spatial_dim])
    kernel_coords = [bilateralCoords, spatialCoords]
    permutohedrals = [
        permutohedral_prepare(kernel_coord) for kernel_coord in kernel_coords
    ]

    nCh = int(U.shape[-1])
    # Compatibility transform initialised to the identity.
    mu = tf.get_variable('Compatibility',
                         initializer=tf.constant(numpy.reshape(
                             numpy.eye(nCh), [1] * spatial_dim + [nCh, nCh]),
                                                 dtype=tf.float32))
    kernel_weights = [
        tf.get_variable("FilterWeights" + str(idx),
                        shape=[1] * spatial_dim + [1, nCh],
                        initializer=tf.zeros_initializer())
        for idx in range(len(permutohedrals))
    ]

    for t in range(self._T):
        H1.append(
            ftheta(U,
                   H1[-1],
                   permutohedrals,
                   mu,
                   kernel_weights,
                   aspect_ratio=self._aspect_ratio,
                   name=self._name + str(t)))
    return H1[-1]
def _computing_bending_energy(displacement):
    """Dispatch the bending-energy computation on the spatial rank.

    :param displacement: displacement-field tensor
    :return: result of ``_computing_bending_energy_2d`` (rank 2) or
        ``_computing_bending_energy_3d`` (rank 3)
    :raises NotImplementedError: for any other spatial rank
    """
    rank = infer_spatial_rank(displacement)
    if rank == 2:
        energy = _computing_bending_energy_2d(displacement)
    elif rank == 3:
        energy = _computing_bending_energy_3d(displacement)
    else:
        raise NotImplementedError(
            "Not implmented: bending energy for {}-d input".format(rank))
    return energy
def _computing_bending_energy(displacement):
    """Bending energy of ``displacement``, for 2-D or 3-D fields only.

    :param displacement: displacement-field tensor
    :return: output of the rank-specific helper
    :raises NotImplementedError: if the spatial rank is not 2 or 3
    """
    n_dims = infer_spatial_rank(displacement)
    if n_dims not in (2, 3):
        raise NotImplementedError(
            "Not implmented: bending energy for {}-d input".format(n_dims))
    if n_dims == 2:
        return _computing_bending_energy_2d(displacement)
    return _computing_bending_energy_3d(displacement)
def _computing_gradient_norm(displacement, flag_L1=False):
    """Mean penalty on the spatial gradients of ``displacement``.

    One ``ImgGrad`` is computed per spatial axis. Each is penalised by
    ``|g|`` (``flag_L1``) or ``g * g``, and everything is averaged to one
    scalar. ``tf.reduce_mean`` on the list presumably needs all gradients
    to share a shape - TODO confirm ``ImgGrad``'s default cropping.

    :param displacement: displacement-field tensor
    :param flag_L1: use the absolute value instead of the square
    :return: scalar tensor
    """
    def _penalty(grad):
        return tf.abs(grad) if flag_L1 else grad * grad

    norms = [_penalty(ImgGrad(spatial_axis=axis)(displacement))
             for axis in range(infer_spatial_rank(displacement))]
    return tf.reduce_mean(norms)
def layer_op(self, input_tensor):
    """Apply a 2-D or 3-D transposed convolution, plus an optional bias.

    :param input_tensor: tensor ``[batch, *spatial, channels]``; batch and
        spatial sizes unknown at graph-construction time are read with
        ``tf.shape``
    :return: tensor ``[batch, *output_spatial, self.n_output_chns]``
    :raises ValueError: if the input has neither 2 nor 3 spatial dims
    """
    input_shape = input_tensor.shape.as_list()
    n_input_chns = input_shape[-1]
    spatial_rank = layer_util.infer_spatial_rank(input_tensor)

    # initialize conv kernels/strides and then apply
    kernel_size_all_dim = layer_util.expand_spatial_params(
        self.kernel_size, spatial_rank)
    # transposed-conv filters are laid out [*spatial, out_chns, in_chns]
    w_full_size = kernel_size_all_dim + (self.n_output_chns, n_input_chns)
    stride_all_dim = layer_util.expand_spatial_params(
        self.stride, spatial_rank)
    full_stride = (1,) + stride_all_dim + (1,)

    deconv_kernel = tf.get_variable(
        'w', shape=w_full_size,
        initializer=self.initializers['w'],
        regularizer=self.regularizers['w'])
    if spatial_rank == 2:
        op_ = SUPPORTED_OP['2D']
    elif spatial_rank == 3:
        op_ = SUPPORTED_OP['3D']
    else:
        raise ValueError(
            "Only 2D and 3D spatial deconvolutions are supported")

    def _size(axis):
        # static size when known, otherwise the dynamic one
        if input_shape[axis] is None:
            return tf.shape(input_tensor)[axis]
        return input_shape[axis]

    spatial_shape = [_size(i) for i in range(1, len(input_shape) - 1)]
    output_dims = infer_output_dims(spatial_shape,
                                    stride_all_dim,
                                    kernel_size_all_dim,
                                    self.padding)
    # The batch size gets the same static/dynamic treatment: a None entry
    # cannot be converted into ``output_shape``.
    full_output_size = [_size(0)] + output_dims + [self.n_output_chns]
    output_tensor = op_(value=input_tensor,
                        filter=deconv_kernel,
                        output_shape=full_output_size,
                        strides=full_stride,
                        padding=self.padding,
                        name='deconv')
    if not self.with_bias:
        return output_tensor

    # adding the bias term
    bias_full_size = (self.n_output_chns,)
    bias_term = tf.get_variable(
        'b', shape=bias_full_size,
        initializer=self.initializers['b'],
        regularizer=self.regularizers['b'])
    output_tensor = tf.nn.bias_add(output_tensor,
                                   bias_term,
                                   name='add_bias')
    return output_tensor
def _computing_gradient_norm(displacement, flag_L1=False):
    """Average L1 or squared-L2 penalty over per-axis spatial gradients.

    :param displacement: displacement-field tensor
    :param flag_L1: if truthy penalise ``|g|``, else ``g * g``
    :return: scalar mean over every gradient component of every axis
    """
    n_axes = infer_spatial_rank(displacement)
    gradients = (ImgGrad(spatial_axis=axis)(displacement)
                 for axis in range(n_axes))
    if flag_L1:
        norms = [tf.abs(g) for g in gradients]
    else:
        norms = [g * g for g in gradients]
    return tf.reduce_mean(norms)
def __init__(self, input_tensor, dilation_factor):
    """Store ``input_tensor`` and derive its dilation parameters.

    :param input_tensor: tensor ``[batch, *spatial, channels]``; its
        spatial sizes are checked (via ``layer_util.check_spatial_dims``)
        to be divisible by ``dilation_factor``
    :param dilation_factor: dilation applied on every spatial axis
    """
    assert (layer_util.check_spatial_dims(
        input_tensor, lambda x: x % dilation_factor == 0))
    self._tensor = input_tensor
    self.dilation_factor = dilation_factor
    # parameters to transform input tensor
    self.spatial_rank = layer_util.infer_spatial_rank(self._tensor)
    # One independent [before, after] pair per spatial axis. ``[[0, 0]] * n``
    # would alias a single inner list n times, so mutating one pair would
    # silently change them all.
    self.zero_paddings = [[0, 0] for _ in range(self.spatial_rank)]
    self.block_shape = [dilation_factor] * self.spatial_rank
def layer_op(self, input_tensor):
    """Transposed convolution (2-D or 3-D) with optional bias.

    :param input_tensor: tensor ``[batch, *spatial, channels]``; any
        batch or spatial size that is ``None`` statically is taken from
        ``tf.shape`` at run time
    :return: tensor ``[batch, *output_spatial, self.n_output_chns]``
    :raises ValueError: if the spatial rank is neither 2 nor 3
    """
    input_shape = input_tensor.shape.as_list()
    n_input_chns = input_shape[-1]
    spatial_rank = layer_util.infer_spatial_rank(input_tensor)

    # initialize conv kernels/strides and then apply
    kernel_size_all_dim = layer_util.expand_spatial_params(
        self.kernel_size, spatial_rank)
    # filter layout for conv*_transpose: [*spatial, out_chns, in_chns]
    w_full_size = kernel_size_all_dim + (self.n_output_chns, n_input_chns)
    stride_all_dim = layer_util.expand_spatial_params(
        self.stride, spatial_rank)
    full_stride = (1, ) + stride_all_dim + (1, )

    deconv_kernel = tf.get_variable('w',
                                    shape=w_full_size,
                                    initializer=self.initializers['w'],
                                    regularizer=self.regularizers['w'])
    if spatial_rank == 2:
        op_ = SUPPORTED_OP['2D']
    elif spatial_rank == 3:
        op_ = SUPPORTED_OP['3D']
    else:
        raise ValueError(
            "Only 2D and 3D spatial deconvolutions are supported")

    def _size(axis):
        # static size if known, else the dynamic size from the graph
        static = input_shape[axis]
        return tf.shape(input_tensor)[axis] if static is None else static

    spatial_shape = [_size(i) for i in range(1, len(input_shape) - 1)]
    output_dims = infer_output_dims(spatial_shape, stride_all_dim,
                                    kernel_size_all_dim, self.padding)
    # An unknown batch (None) cannot be placed in ``output_shape``, so it
    # falls back to the dynamic size like the spatial dims do.
    full_output_size = [_size(0)] + output_dims + [self.n_output_chns]
    output_tensor = op_(value=input_tensor,
                        filter=deconv_kernel,
                        output_shape=full_output_size,
                        strides=full_stride,
                        padding=self.padding,
                        name='deconv')
    if not self.with_bias:
        return output_tensor

    # adding the bias term
    bias_full_size = (self.n_output_chns, )
    bias_term = tf.get_variable('b',
                                shape=bias_full_size,
                                initializer=self.initializers['b'],
                                regularizer=self.regularizers['b'])
    output_tensor = tf.nn.bias_add(output_tensor,
                                   bias_term,
                                   name='add_bias')
    return output_tensor
def resample_nearest(self, inputs, sample_coords):
    """Nearest-neighbour sampling of ``inputs`` at ``sample_coords``.

    Coordinates are rounded, cast to int32, mapped through
    ``self.boundary_func_`` and gathered. With ``self.boundary == 'ZERO'``
    a sample is zeroed unless *every* one of its (unrounded) coordinates
    lies in the open range ``(0, size - 1)`` of its axis.

    :param inputs: tensor ``[batch, *in_spatial, channels]``
    :param sample_coords: tensor ``[batch, *out_spatial, in_spatial_rank]``
    :return: tensor ``[batch, *out_spatial, channels]``
    """
    input_size = tf.reshape(
        inputs.get_shape().as_list()[1:-1],
        [1] * (len(sample_coords.get_shape().as_list()) - 1) + [-1])
    spatial_coords = self.boundary_func_(
        tf.cast(tf.round(sample_coords), tf.int32), input_size)
    sz = spatial_coords.get_shape().as_list()
    batch_coords = tf.tile(
        tf.reshape(tf.range(sz[0]), [sz[0]] + [1] * (len(sz) - 1)),
        [1] + sz[1:-1] + [1])
    output = tf.gather_nd(inputs,
                          tf.concat([batch_coords, spatial_coords], -1))
    if self.boundary == 'ZERO':
        scale = 1. / tf.to_float(input_size - 1)
        # A point is inside only if all of its coordinates are in range,
        # hence reduce_all; reduce_any kept points lying outside the
        # volume along some of the axes.
        inside = tf.logical_and(
            tf.reduce_all(sample_coords > 0, axis=-1, keep_dims=True),
            tf.reduce_all(scale * sample_coords < 1, axis=-1,
                          keep_dims=True))
        output = output * tf.to_float(inside)
    return output
def _computing_bending_energy(displacement):
    """
    Compute the bending energy of a 2-D or 3-D displacement field.

    :param displacement: tensor, displacement field
    :return: bending energy, as returned by the rank-specific helper
    :raises NotImplementedError: for spatial ranks other than 2 and 3
    """
    spatial_rank = infer_spatial_rank(displacement)
    if spatial_rank in (2, 3):
        compute = (_computing_bending_energy_2d if spatial_rank == 2
                   else _computing_bending_energy_3d)
        return compute(displacement)
    raise NotImplementedError(
        "Not implmented: bending energy for {}-d input".format(spatial_rank))
def layer_op(self,
             fixed_image,
             moving_image,
             is_training=True,
             **unused_kwargs):
    """
    Estimate an affine transform from the image pair and return the
    resulting sampling grid on the fixed image's spatial shape.

    :param fixed_image: tensor, fixed image for registration (defines
        reference space)
    :param moving_image: tensor, moving image to be registered to fixed
    :param is_training: boolean, True if network is in training mode
    :return: displacement fields transformed by estimating affine
    """
    spatial_rank = infer_spatial_rank(moving_image)
    spatial_shape = fixed_image.get_shape().as_list()[1:-1]

    # Bring the moving image onto the fixed image's grid, then stack the
    # pair along the channel axis as the network input.
    resized_moving = Resize(spatial_shape)(moving_image)
    features = tf.concat([resized_moving, fixed_image], axis=-1)

    # Four DownRes stages (kernel_size=7 in the first); only element 0 of
    # each stage's output is carried forward.
    features = DownRes(self.fea[0], kernel_size=7,
                       **self.res_param)(features, is_training)[0]
    for stage in (1, 2, 3):
        features = DownRes(self.fea[stage],
                           **self.res_param)(features, is_training)[0]

    conv_5 = Conv(n_output_chns=self.fea[4],
                  kernel_size=self.k_conv,
                  with_bias=False,
                  feature_normalization='batch',
                  **self.res_param)(features, is_training)

    # 6 / 12 outputs: presumably the 2x3 / 3x4 affine matrix entries.
    n_affine_params = {2: 6, 3: 12}.get(spatial_rank)
    if n_affine_params is None:
        tf.logging.fatal('Not supported spatial rank')
        raise NotImplementedError

    if self.affine_w_initializer is None:
        self.affine_w_initializer = init_affine_w()
    if self.affine_b_initializer is None:
        self.affine_b_initializer = init_affine_b(spatial_rank)

    affine = FC(n_output_chns=n_affine_params,
                feature_normalization=None,
                w_initializer=self.affine_w_initializer,
                b_initializer=self.affine_b_initializer,
                **self.affine_param)(conv_5)
    return Grid(source_shape=spatial_shape,
                output_shape=spatial_shape)(affine)
def layer_op(self, input_tensor):
    """
    Spatial gradient of ``input_tensor`` along ``self.spatial_axis``.

    The output is equivalent to convolving along ``spatial_axis`` with the
    kernel ``[-1, 0, 1]``, i.e. ``x[i + 2] - x[i]`` along that axis.

    The first and last axes of the input are treated as batch and feature
    channels. Therefore ``spatial_axis=1`` differentiates along the third
    axis of the tensor, i.e. ``input_tensor[:, :, y, ...]``.

    For an input of shape ``[B, X, Y, Z, C]`` and ``spatial_axis=1`` the
    output shape is::

        [B, X-2, Y-2, Z-2, C] if do_cropping is True
        [B, X, Y-2, Z, C]     otherwise

    Setting ``do_cropping`` to True makes the output shape the same for
    every ``spatial_axis``.

    :param input_tensor: a batch of images with a shape of
        ``[Batch, x[, y, z, ... ], Channel]``
    :return: spatial gradients of ``input_tensor``
    """
    rank = infer_spatial_rank(input_tensor)
    size = input_tensor.get_shape().as_list()[1:-1]
    axis = self.spatial_axis

    if self.do_cropping:
        # drop one voxel at each end of every spatial axis
        size = [extent - 2 for extent in size]
        start = [1] * rank
    else:
        # shrink only the differentiated axis
        size[axis] = size[axis] - 2
        start = [0] * rank

    # The "forward" window starts two voxels further along ``axis``.
    start_forward = list(start)
    start_forward[axis] = 2
    start_backward = list(start)
    start_backward[axis] = 0
    window = [-1] + size + [-1]

    forward = tf.slice(input_tensor, [0] + start_forward + [0], window)
    backward = tf.slice(input_tensor, [0] + start_backward + [0], window)
    return forward - backward
def layer_op(self, input_tensor):
    """
    Convolve ``input_tensor`` with a trainable kernel ``w`` and, when
    ``self.with_bias`` is set, add a trainable bias ``b``.

    The spatial rank (and thus the convolution dimensionality) is inferred
    from ``input_tensor``; its last axis is taken as the channel axis.

    :param input_tensor: input tensor, channels on the last axis
    :return: output tensor with ``self.n_output_chns`` channels
    """
    input_shape = input_tensor.get_shape().as_list()
    n_input_chns = input_shape[-1]
    spatial_rank = layer_util.infer_spatial_rank(input_tensor)

    # initialize conv kernels/strides and then apply
    w_full_size = layer_util.expand_spatial_params(
        self.kernel_size, spatial_rank)
    # expand kernel size to include number of features
    w_full_size = w_full_size + (n_input_chns, self.n_output_chns)
    full_stride = layer_util.expand_spatial_params(
        self.stride, spatial_rank)
    full_dilation = layer_util.expand_spatial_params(
        self.dilation, spatial_rank)

    conv_kernel = tf.get_variable(
        'w', shape=w_full_size,
        initializer=self.initializers['w'],
        regularizer=self.regularizers['w'])
    # a single convolution op (the former debug print and the duplicate,
    # never-used 'CONVV' convolution have been removed)
    output_tensor = tf.nn.convolution(input=input_tensor,
                                      filter=conv_kernel,
                                      strides=full_stride,
                                      dilation_rate=full_dilation,
                                      padding=self.padding,
                                      name='conv')
    if not self.with_bias:
        return output_tensor

    # adding the bias term
    bias_term = tf.get_variable(
        'b', shape=self.n_output_chns,
        initializer=self.initializers['b'],
        regularizer=self.regularizers['b'])
    output_tensor = tf.nn.bias_add(output_tensor, bias_term,
                                   name='add_bias')
    return output_tensor
def layer_op(self, input_tensor):
    """
    Resize the image by linearly interpolating the input using TF
    ``resize_bilinear`` function.

    3D volumes are resized separably: first every (y, z) plane, then every
    (y, x) plane of the volume reoriented to ``[batch, z, y, x, channels]``.
    The batch dimension may be undefined (``None``).

    :param input_tensor: 2D/3D image tensor, with shape:
        ``batch, X, Y, [Z,] Channels``
    :return: interpolated volume
    """
    input_spatial_rank = infer_spatial_rank(input_tensor)
    assert input_spatial_rank in (2, 3), \
        "linearly interpolation layer can only be applied to " \
        "2D/3D images (4D or 5D tensor)."
    self.new_size = expand_spatial_params(self.new_size, input_spatial_rank)

    if input_spatial_rank == 2:
        return tf.image.resize_bilinear(input_tensor, self.new_size)

    # the batch size is never used below (-1 is used instead) so that
    # tensors with an undefined batch dimension are supported as well
    _, x_size, y_size, z_size, c_size = input_tensor.shape.as_list()
    x_size_new, y_size_new, z_size_new = self.new_size

    if (x_size == x_size_new) and (y_size == y_size_new) and (
            z_size == z_size_new):
        # already in the target shape
        return input_tensor

    # resize y-z: fold batch and x into a single leading axis
    squeeze_b_x = tf.reshape(
        input_tensor, [-1, y_size, z_size, c_size])
    resize_b_x = tf.image.resize_bilinear(
        squeeze_b_x, [y_size_new, z_size_new])
    resume_b_x = tf.reshape(
        resize_b_x, [-1, x_size, y_size_new, z_size_new, c_size])

    # resize x
    # first reorient to [batch, z, y, x, channels]
    reoriented = tf.transpose(resume_b_x, [0, 3, 2, 1, 4])
    # squeeze and 2d resize
    squeeze_b_z = tf.reshape(
        reoriented, [-1, y_size_new, x_size, c_size])
    resize_b_z = tf.image.resize_bilinear(
        squeeze_b_z, [y_size_new, x_size_new])
    resume_b_z = tf.reshape(
        resize_b_z, [-1, z_size_new, y_size_new, x_size_new, c_size])

    # back to [batch, x, y, z, channels]
    output_tensor = tf.transpose(resume_b_z, [0, 3, 2, 1, 4])
    return output_tensor
def _computing_gradient_norm(displacement, flag_L1=False):
    """
    Mean penalty on the spatial gradients of a displacement field.

    :param displacement: tensor, displacement field
    :param flag_L1: boolean, True if the L1 norm should be used (absolute
        gradients); otherwise squared gradients (L2) are used
    :return: mean of the L2 (or L1) gradient penalties over all spatial
        axes
    """
    def _penalty(grad):
        return tf.abs(grad) if flag_L1 else grad * grad

    per_axis = [
        _penalty(ImgGrad(spatial_axis=axis)(displacement))
        for axis in range(infer_spatial_rank(displacement))]
    return tf.reduce_mean(per_axis)
def layer_op(self, inputs):
    """
    Remove ``self.border`` elements from both ends of every spatial
    dimension of ``inputs``.

    ``self.border`` may be a single integer (the same border for every
    spatial dimension) or a list/tuple with one entry per spatial
    dimension. It is normalised to a list of ints on ``self``.

    :param inputs: tensor ``[batch, spatial..., channels]`` whose spatial
        dimensions are statically known
    :return: the cropped tensor
    """
    spatial_rank = layer_util.infer_spatial_rank(inputs)
    if isinstance(self.border, (list, tuple)):
        # tuples used to fail in int(); accept any per-dimension sequence
        self.border = [int(b) for b in self.border]
    else:
        self.border = [int(self.border)] * spatial_rank
    offsets = [0] + self.border + [0]

    # inferring the shape of the output by subtracting the border
    # at both ends of each spatial dimension
    out_shape = [
        int(d) - 2 * b
        for (d, b) in zip(list(inputs.shape)[1:-1], self.border)]
    out_shape = [-1] + out_shape + [-1]
    return tf.slice(inputs, offsets, out_shape)
def layer_op(self, input_tensor):
    """
    Convolve ``input_tensor`` with a trainable kernel ``w`` (plus an
    optional trainable bias ``b``).

    ``'VALID'`` and ``'SAME'`` padding are handled by
    ``tf.nn.convolution``; any other ``self.padding`` value is delegated to
    ``_extended_convolution`` together with ``self.padding_constant``.

    :param input_tensor: input tensor, channels on the last axis
    :return: convolved (and optionally biased) tensor
    """
    shape = input_tensor.shape.as_list()
    rank = layer_util.infer_spatial_rank(input_tensor)

    kernel_shape = layer_util.expand_spatial_params(self.kernel_size, rank) \
        + (shape[-1], self.n_output_chns)
    strides = layer_util.expand_spatial_params(self.stride, rank)
    dilations = layer_util.expand_spatial_params(self.dilation, rank)

    kernel = tf.get_variable('w',
                             shape=kernel_shape,
                             initializer=self.initializers['w'],
                             regularizer=self.regularizers['w'])

    if self.padding not in ('VALID', 'SAME'):
        result = _extended_convolution(input_tensor,
                                       kernel,
                                       strides,
                                       dilations,
                                       self.padding,
                                       constant=self.padding_constant)
    else:
        result = tf.nn.convolution(input=input_tensor,
                                   filter=kernel,
                                   strides=strides,
                                   dilation_rate=dilations,
                                   padding=self.padding,
                                   name='conv')

    if self.with_bias:
        bias = tf.get_variable('b',
                               shape=self.n_output_chns,
                               initializer=self.initializers['b'],
                               regularizer=self.regularizers['b'])
        result = tf.nn.bias_add(result, bias, name='add_bias')
    return result
def layer_op(self, images, is_training=True, **unused_kwargs):
    """
    Run the network: initial convolution, the stack of blocks, the ``bn``
    layer followed by ReLU, a mean over all spatial axes, and the final
    fully connected layer.

    :param images: tensor, input to the network
    :param is_training: boolean, True if network is in training mode
    :param unused_kwargs: not in use
    :return: tensor, output of the final fully connected layer
    """
    net = self.create()
    features = net.conv1(images, is_training)
    for block in net.blocks:
        features = block(features, is_training)

    # every axis between batch (0) and channels (-1)
    spatial_axes = list(
        range(1, layer_util.infer_spatial_rank(features) + 1))
    activated = tf.nn.relu(net.bn(features, is_training))
    pooled = tf.reduce_mean(activated, axis=spatial_axes)
    return net.fc(pooled)
def layer_op(self, inputs):
    """
    Crop ``self.border`` elements from each side of every spatial
    dimension. This is done with a ``'VALID'`` convolution whose kernel is
    all zeros except for a central 1, applied to each channel separately.

    :param inputs: tensor with channels on the last axis
    :return: cropped tensor with the same number of channels
    """
    rank = layer_util.infer_spatial_rank(inputs)
    width = self.border * 2 + 1
    kernel_shape = np.hstack(([width] * rank, 1, 1)).flatten()
    # all zeros except the central element, which is 1
    kernel = tf.constant(layer_util.trivial_kernel(kernel_shape),
                         dtype=inputs.dtype)

    # split channel dim: one single-channel tensor per input channel
    single_channels = []
    for channel in tf.unstack(inputs, axis=-1):
        single_channels.append(tf.expand_dims(channel, -1))

    cropped = []
    for channel in single_channels:
        cropped.append(tf.nn.convolution(input=channel,
                                         filter=kernel,
                                         strides=[1] * rank,
                                         padding='VALID',
                                         name='conv'))
    return tf.concat(cropped, axis=-1)
def layer_op(self, fixed_image, moving_image, is_training=True):
    """
    Estimate an affine transform that registers ``moving_image`` to
    ``fixed_image`` and return the corresponding sampling grid.

    :param fixed_image: tensor, fixed image; its spatial shape defines the
        reference space
    :param moving_image: tensor, moving image; resized to the fixed image's
        spatial shape before the two are concatenated
    :param is_training: boolean, True if network is in training mode
    :return: output of ``Grid`` applied to the estimated affine parameters
    :raises NotImplementedError: if the spatial rank is neither 2 nor 3
    """
    spatial_rank = infer_spatial_rank(moving_image)
    spatial_shape = fixed_image.get_shape().as_list()[1:-1]

    resized = Resize(spatial_shape)(moving_image)
    net = tf.concat([resized, fixed_image], axis=-1)

    net = DownRes(self.fea[0], kernel_size=7,
                  **self.res_param)(net, is_training)[0]
    for stage in range(1, 4):
        net = DownRes(self.fea[stage],
                      **self.res_param)(net, is_training)[0]
    net = Conv(n_output_chns=self.fea[4],
               kernel_size=self.k_conv,
               with_bias=False,
               with_bn=True,
               **self.res_param)(net, is_training)

    if spatial_rank not in (2, 3):
        tf.logging.fatal('Not supported spatial rank')
        raise NotImplementedError
    # rank * (rank + 1): 6 parameters in 2D, 12 in 3D
    affine_size = spatial_rank * (spatial_rank + 1)

    if self.affine_w_initializer is None:
        self.affine_w_initializer = init_affine_w()
    if self.affine_b_initializer is None:
        self.affine_b_initializer = init_affine_b(spatial_rank)

    affine = FC(n_output_chns=affine_size,
                with_bn=False,
                w_initializer=self.affine_w_initializer,
                b_initializer=self.affine_b_initializer,
                **self.affine_param)(net)
    grid = Grid(source_shape=spatial_shape, output_shape=spatial_shape)
    return grid(affine)
def layer_op(self, inputs):
    """
    Remove ``self.border`` elements from both ends of every spatial
    dimension by convolving each channel (``'VALID'`` padding) with a
    kernel whose only non-zero entry is a 1 at its centre.

    :param inputs: tensor with channels on the last axis
    :return: cropped tensor, channels concatenated back together
    """
    spatial_rank = layer_util.infer_spatial_rank(inputs)
    side = 2 * self.border + 1
    np_kernel = layer_util.trivial_kernel(
        np.hstack(([side] * spatial_rank, 1, 1)).flatten())
    crop_kernel = tf.constant(np_kernel, dtype=inputs.dtype)

    def _crop_single_channel(chn):
        return tf.nn.convolution(input=chn,
                                 filter=crop_kernel,
                                 strides=[1] * spatial_rank,
                                 padding='VALID',
                                 name='conv')

    channels = [tf.expand_dims(c, -1) for c in tf.unstack(inputs, axis=-1)]
    return tf.concat([_crop_single_channel(c) for c in channels], axis=-1)
def layer_op(self, input_tensor):
    """
    Upsample ``input_tensor`` according to ``self.func``.

    - ``'REPLICATE'``: copy every input element into a block of
      ``kernel_size`` elements in each spatial dimension (only supported
      when ``kernel_size == stride``).
    - ``'CHANNELWISE_DECONV'``: upsample each channel independently with
      its own single-channel ``DeconvLayer`` (``'SAME'`` padding).
    - any other value: ``input_tensor`` is returned unchanged.

    :param input_tensor: tensor with channels on the last axis
    :return: upsampled tensor
    :raises ValueError: in ``'REPLICATE'`` mode if ``kernel_size != stride``
    """
    spatial_rank = layer_util.infer_spatial_rank(input_tensor)
    output_tensor = input_tensor
    if self.func == 'REPLICATE':
        if self.kernel_size != self.stride:
            # trailing spaces: adjacent literals are concatenated verbatim
            raise ValueError(
                "`kernel_size` != `stride` currently not "
                "supported in `REPLICATE` mode. Please "
                "consider using `CHANNELWISE_DECONV` operation.")
        # simply replicate input values to
        # local regions of (kernel_size ** spatial_rank) elements:
        # tile along the batch axis, then let batch_to_space_nd move the
        # copies into the spatial blocks
        kernel_size_all_dims = layer_util.expand_spatial_params(
            self.kernel_size, spatial_rank)
        pixel_num = np.prod(kernel_size_all_dims)
        repmat = np.hstack((pixel_num, [1] * spatial_rank, 1)).flatten()
        output_tensor = tf.tile(input=input_tensor, multiples=repmat)
        output_tensor = tf.batch_to_space_nd(
            input=output_tensor,
            block_shape=kernel_size_all_dims,
            crops=[[0, 0]] * spatial_rank)
    elif self.func == 'CHANNELWISE_DECONV':
        output_tensor = [
            tf.expand_dims(x, -1)
            for x in tf.unstack(input_tensor, axis=-1)]
        output_tensor = [
            DeconvLayer(n_output_chns=1,
                        kernel_size=self.kernel_size,
                        stride=self.stride,
                        padding='SAME',
                        with_bias=self.with_bias,
                        w_initializer=self.initializers['w'],
                        w_regularizer=self.regularizers['w'],
                        b_initializer=self.initializers['b'],
                        b_regularizer=self.regularizers['b'],
                        name='deconv_{}'.format(i))(x)
            for (i, x) in enumerate(output_tensor)]
        output_tensor = tf.concat(output_tensor, axis=-1)
    return output_tensor
def layer_op(self, input_tensor):
    """
    Warp ``input_tensor`` with the transform from
    ``self._random_transform``. The transform is generated on the first
    call, stored in ``self._transform``, and reused on later calls.

    :param input_tensor: tensor ``[batch, spatial..., channels]``
    :return: the resampled tensor
    """
    shape = input_tensor.shape.as_list()
    batch_size, spatial_shape = shape[0], shape[1:-1]
    rank = infer_spatial_rank(input_tensor)

    if self._transform is None:
        self._transform = self._random_transform(batch_size, rank)
    transform = self._transform

    warper = AffineGridWarperLayer(spatial_shape, spatial_shape)
    resampler = ResamplerLayer(interpolation=self.interpolation,
                               boundary=self.boundary)
    # keep the first `rank` rows of each matrix, flattened per batch item
    params = tf.reshape(transform[:, :rank, :], [batch_size, -1])
    return resampler(input_tensor, warper(params))
def layer_op(self, image):
    """
    Separable smoothing: filter ``image`` along each spatial dimension in
    turn with a kernel from ``self.kernel_func``. Each channel is filtered
    independently, with ``'SAME'`` padding.

    :param image: in shape `(batch, x[, y, z], feature_channels)`
    :return: spatially smoothed image; ``image`` itself is returned
        unchanged if any of the sigmas is zero
    """
    spatial_rank = infer_spatial_rank(image)
    sigmas = expand_spatial_params(input_param=self.sigma,
                                   spatial_rank=spatial_rank,
                                   param_type=float)
    truncates = expand_spatial_params(input_param=self.truncate,
                                      spatial_rank=spatial_rank,
                                      param_type=float)
    if not all(sigmas):
        return image

    # Build one kernel per dimension, laid out along that dimension. Kernels
    # are created from the last dimension down to the first, and the
    # filtering below runs from the first dimension up.
    kernels = {}
    for dim in range(spatial_rank - 1, -1, -1):
        kernel_shape = [1] * (spatial_rank + 2)
        kernel_shape[dim] = -1
        raw_kernel = self.kernel_func(sigma=sigmas[dim],
                                      truncated=truncates[dim])
        kernels[dim] = tf.reshape(raw_kernel, kernel_shape)

    smoothed = image
    for dim in range(spatial_rank):
        per_channel = [
            tf.nn.convolution(input=tf.expand_dims(chn, axis=-1),
                              filter=kernels[dim],
                              padding='SAME',
                              strides=[1] * spatial_rank)
            for chn in tf.unstack(smoothed, axis=-1)]
        smoothed = tf.concat(per_channel, axis=-1)
    return smoothed
def ftheta(U, H1, permutohedrals, mu, kernel_weights, aspect_ratio, name):
    """
    One update step of the CRF-as-RNN layer.

    1. Filter the softmax of ``H1`` with every permutohedral lattice (on
       the CPU).
    2. Combine the filter outputs weighted by ``kernel_weights``.
    3. Apply the compatibility transform (a ``'SAME'`` convolution with
       ``mu``).
    4. Subtract the result from the unary term ``U``.

    ``aspect_ratio`` is accepted but not used here.

    :return: updated logits (not the softmax)
    :raises NotImplementedError: if ``U`` is neither 2D nor 3D
    """
    n_chns = U.shape.as_list()[-1]
    batch_size = int(U.shape[0])

    # message passing
    soft = tf.reshape(tf.nn.softmax(H1), [batch_size, -1, n_chns])
    filtered = [None] * len(permutohedrals)
    with tf.device('/cpu:0'):
        for idx, lattice in enumerate(permutohedrals):
            filtered[idx] = tf.reshape(
                permutohedral_gen(lattice, soft, name + str(idx)),
                U.get_shape())

    # weighting filter outputs
    weighted = tf.add_n([q * w for q, w in zip(filtered, kernel_weights)])

    # compatibility transform
    spatial_dim = infer_spatial_rank(U)
    if spatial_dim == 2:
        compat = tf.nn.conv2d(weighted, mu,
                              strides=[1, 1, 1, 1], padding='SAME')
    elif spatial_dim == 3:
        compat = tf.nn.conv3d(weighted, mu,
                              strides=[1, 1, 1, 1, 1], padding='SAME')
    else:
        raise NotImplementedError(
            'CRFAsRNNLayer is only implemented for 2d and 3d images.')

    # adding unary potentials; output logits, not the softmax
    return U - compat
def layer_op(self, input_tensor):
    """
    Upsample ``input_tensor`` according to ``self.func``.

    - ``'REPLICATE'``: copy every input element into a block of
      ``kernel_size`` elements in each spatial dimension (only supported
      when ``kernel_size == stride``).
    - ``'CHANNELWISE_DECONV'``: upsample each channel independently with
      its own single-channel ``DeconvLayer`` (``'SAME'`` padding).
    - any other value: ``input_tensor`` is returned unchanged.

    :param input_tensor: tensor with channels on the last axis
    :return: upsampled tensor
    :raises ValueError: in ``'REPLICATE'`` mode if ``kernel_size != stride``
    """
    spatial_rank = layer_util.infer_spatial_rank(input_tensor)
    output_tensor = input_tensor
    if self.func == 'REPLICATE':
        if self.kernel_size != self.stride:
            # trailing spaces: adjacent literals are concatenated verbatim
            raise ValueError(
                "`kernel_size` != `stride` currently not "
                "supported in `REPLICATE` mode. Please "
                "consider using `CHANNELWISE_DECONV` operation.")
        # simply replicate input values to
        # local regions of (kernel_size ** spatial_rank) elements:
        # tile along the batch axis, then let batch_to_space_nd move the
        # copies into the spatial blocks
        kernel_size_all_dims = layer_util.expand_spatial_params(
            self.kernel_size, spatial_rank)
        pixel_num = np.prod(kernel_size_all_dims)
        repmat = np.hstack((pixel_num, [1] * spatial_rank, 1)).flatten()
        output_tensor = tf.tile(input=input_tensor, multiples=repmat)
        output_tensor = tf.batch_to_space_nd(
            input=output_tensor,
            block_shape=kernel_size_all_dims,
            crops=[[0, 0]] * spatial_rank)
    elif self.func == 'CHANNELWISE_DECONV':
        output_tensor = [tf.expand_dims(x, -1)
                         for x in tf.unstack(input_tensor, axis=-1)]
        output_tensor = [DeconvLayer(n_output_chns=1,
                                     kernel_size=self.kernel_size,
                                     stride=self.stride,
                                     padding='SAME',
                                     with_bias=self.with_bias,
                                     w_initializer=self.initializers['w'],
                                     w_regularizer=self.regularizers['w'],
                                     b_initializer=self.initializers['b'],
                                     b_regularizer=self.regularizers['b'],
                                     name='deconv_{}'.format(i))(x)
                         for (i, x) in enumerate(output_tensor)]
        output_tensor = tf.concat(output_tensor, axis=-1)
    return output_tensor
def layer_op(self, input_tensor):
    """
    Apply an N-D convolution with a trainable kernel ``w``. A trainable
    bias ``b`` is added only when ``self.with_bias`` is set.

    :param input_tensor: input tensor, channels on the last axis
    :return: output tensor with ``self.n_output_chns`` channels
    """
    in_chns = input_tensor.shape.as_list()[-1]
    rank = layer_util.infer_spatial_rank(input_tensor)

    def _per_dim(param):
        # broadcast a scalar/sequence parameter to one value per spatial dim
        return layer_util.expand_spatial_params(param, rank)

    kernel_shape = _per_dim(self.kernel_size) + (in_chns, self.n_output_chns)
    strides = _per_dim(self.stride)
    dilations = _per_dim(self.dilation)

    kernel = tf.get_variable('w',
                             shape=kernel_shape,
                             initializer=self.initializers['w'],
                             regularizer=self.regularizers['w'])
    result = tf.nn.convolution(input=input_tensor,
                               filter=kernel,
                               strides=strides,
                               dilation_rate=dilations,
                               padding=self.padding,
                               name='conv')
    if not self.with_bias:
        return result

    bias = tf.get_variable('b',
                           shape=self.n_output_chns,
                           initializer=self.initializers['b'],
                           regularizer=self.regularizers['b'])
    return tf.nn.bias_add(result, bias, name='add_bias')
def SSIM(x1, x2, max_val=1.0, axes=None):
    """
    Structural similarity (SSIM) between ``x1`` and ``x2``.

    Local statistics (means, variances, covariance) are obtained with
    ``do_conv(tensor, spatial_rank - 1)``.
    NOTE(review): ``do_conv`` is not a parameter of this function and must
    be resolvable from its enclosing scope — confirm where it is defined.

    :param x1: reference images (cast to float32)
    :param x2: images under test (cast to float32)
    :param max_val: dynamic range of the values, used for C1 and C2
    :param axes: axes over which the SSIM map is averaged. If None, they
        default to ``[-4, -3, -2, -1, 0]`` for spatial rank 3 and to
        ``[-3, -2, -1, 0]`` otherwise.
    :return: tuple ``(ssim, mean_denominator, ssim_map)``
    """
    C1 = (0.01 * max_val) ** 2
    C2 = (0.03 * max_val) ** 2
    spatial_rank = infer_spatial_rank(x1)

    frameReference = tf.cast(x1, tf.float32)
    frameUnderTest = tf.cast(x2, tf.float32)
    frameReference_square = tf.square(frameReference)
    frameUnderTest_square = tf.square(frameUnderTest)
    frameReference_frameUnderTest = frameReference * frameUnderTest

    mu1 = do_conv(frameReference, spatial_rank - 1)
    mu2 = do_conv(frameUnderTest, spatial_rank - 1)
    mu1_square = tf.square(mu1)
    mu2_square = tf.square(mu2)
    mu1_mu2 = mu1 * mu2

    # E[x^2] - E[x]^2 and E[xy] - E[x]E[y]
    sigma1_square = do_conv(frameReference_square, spatial_rank - 1)
    sigma1_square = sigma1_square - mu1_square
    sigma2_square = do_conv(frameUnderTest_square, spatial_rank - 1)
    sigma2_square = sigma2_square - mu2_square
    sigma12 = do_conv(frameReference_frameUnderTest, spatial_rank - 1)
    sigma12 = sigma12 - mu1_mu2

    numerator = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2))
    denominator = ((mu1_square + mu2_square + C1) *
                   (sigma1_square + sigma2_square + C2))
    ssim_map = numerator / denominator

    # `is None` rather than `== None`: `==` on an array/tensor `axes` is
    # elementwise and makes the condition ambiguous
    if axes is None:
        if spatial_rank == 3:
            axes = tf.constant([-4, -3, -2, -1, 0], dtype=tf.int32)
        else:
            axes = tf.constant([-3, -2, -1, 0], dtype=tf.int32)
    ssim = tf.reduce_mean(ssim_map, axis=axes)
    return ssim, tf.reduce_mean(denominator, axis=axes), ssim_map
def _resample_inv_dst_weighting(self, inputs, sample_coords):
    """
    Inverse-distance-weighted (power 2) resampling of ``inputs`` at
    ``sample_coords``. Each sample uses the ``2 ** spatial_rank`` grid
    points obtained by taking the floor or ceil of each coordinate.

    If ``sample_coords`` has a batch size of 1 while ``inputs`` has a
    larger batch, the same coordinates are used for every batch item.
    Previously ``zip`` silently truncated the output to one item.

    :param inputs: tensor ``[batch, spatial..., channels]``; the spatial
        shape must be statically known (it feeds ``self.boundary_func``)
    :param sample_coords: coordinates with one component per spatial
        dimension on the last axis (assumed ``[batch or 1, d0, ..., dK,
        spatial_rank]``)
    :return: resampled tensor (assumed ``[batch, d0, ..., dK, channels]``)
    """
    in_size = inputs.shape.as_list()
    in_spatial_size = in_size[1:-1]
    in_spatial_rank = infer_spatial_rank(inputs)
    out_rank = len(sample_coords.shape.as_list())

    self.N = 2 ** in_spatial_rank
    binary_neighbour_ids = [
        [int(c) for c in format(i, '0%ib' % in_spatial_rank)]
        for i in range(self.N)]
    # for neighbour `bc`, spatial dim `i` reads entry `bc[i]`
    # (0: floor, 1: ceil) of the stacked floor/ceil tensors below
    weight_id = [[[c, i] for i, c in enumerate(bc)]
                 for bc in binary_neighbour_ids]

    # move the coordinate components to the front
    sample_coords = tf.transpose(
        sample_coords, [out_rank - 1, 0] + list(range(1, out_rank - 1)))
    # broadcasting input spatial size for boundary functions
    b_size = tf.reshape(in_spatial_size,
                        [len(in_spatial_size)] + [1] * (out_rank - 1))

    # find floor and ceil coordinates
    all_coords_f = tf.stack([
        self.boundary_func(tf.floor(sample_coords), b_size),
        self.boundary_func(tf.ceil(sample_coords), b_size)])

    # squared distances to the floor/ceil points; the EPS offsets keep a
    # sample lying exactly on a grid point from producing a zero distance
    diff = tf.stack(
        [tf.squared_difference(sample_coords - EPS, all_coords_f[0]),
         tf.squared_difference(sample_coords + EPS, all_coords_f[1])])

    # gather_nd for both matrices, the same as:
    # point_weights = tf.gather_nd(diff, weight_id)
    # knots_id = tf.gather_nd(all_coords_f, weight_id)
    n_val = tf.gather_nd(tf.stack([diff, all_coords_f], axis=-1), weight_id)
    n_val = tf.unstack(n_val, axis=-1)
    point_weights, knots_id = n_val[0], n_val[1]

    # inverse distance weighting with power 2:
    # sum_i (w_i*p_i/(sum_j w_j)) where w_i = 1/((p-p_i)^2)
    # before the sum, point_weights has the shape
    # `[N, input_rank, b, sp_dim_0, ..., sp_dim_K]` with
    #   `N` = 2**source data spatial rank, `b` = coordinates batch size
    point_weights = tf.reduce_sum(point_weights, axis=1)
    point_weights = tf.reciprocal(point_weights)
    point_weights = point_weights / tf.reduce_sum(point_weights, axis=0)

    # neighbour indices as `[N, b, sp_dim_0, ..., sp_dim_K, input_rank]`
    knots_id = tf.transpose(tf.cast(knots_id, COORDINATES_TYPE),
                            [0] + list(range(2, out_rank + 1)) + [1])

    # get values of N neighbours
    batch_inputs = tf.unstack(inputs, axis=0)
    batch_knots = tf.unstack(knots_id, axis=1)
    if len(batch_knots) == 1 and len(batch_inputs) > 1:
        # one set of coordinates shared by every batch item
        batch_knots = batch_knots * len(batch_inputs)
    samples = [tf.gather_nd(img, knots)
               for (img, knots) in zip(batch_inputs, batch_knots)]
    samples = tf.stack(samples, axis=1)

    # weighted average over N neighbours (point_weights broadcasts over a
    # coordinate batch of size 1)
    return tf.reduce_sum(
        samples * tf.expand_dims(point_weights, axis=-1), axis=0)
def layer_op(self, input_tensor, input_mask=None, output_mask=None):
    """
    Transposed convolution restricted to a subset of the input and output
    channels of a full-size trainable kernel.

    Parameters:
    input_tensor: image to convolve with kernel (its channels are the
        input channels selected by ``input_mask``)
    input_mask: 1-Tensor with a binary mask of input channels to use.
        If this is None, all channels are used.
    output_mask: 1-Tensor with a binary mask of output channels to
        generate. If this is None, all channels are used and the number of
        output channels is set at graph-creation time.
    """
    input_shape = input_tensor.shape.as_list()
    if input_mask is None:
        _input_mask = tf.ones([input_shape[-1]]) > 0
    else:
        _input_mask = input_mask
    if output_mask is None:
        n_sparse_output_chns = self.n_output_chns
        _output_mask = tf.ones([self.n_output_chns]) > 0
    else:
        # output_shape of a transposed convolution must be integer-valued
        n_sparse_output_chns = tf.reduce_sum(
            tf.cast(output_mask, tf.int32))
        _output_mask = output_mask
    n_full_input_chns = _input_mask.shape.as_list()[0]
    spatial_rank = layer_util.infer_spatial_rank(input_tensor)

    # kernel layout [spatial..., n_output, n_full_input]; hstack (not
    # vstack) so a list and scalars concatenate into one flat vector
    w_full_size = np.hstack((
        [self.kernel_size] * spatial_rank,
        self.n_output_chns, n_full_input_chns)).flatten()
    full_stride = np.hstack((
        1, [self.stride] * spatial_rank, 1)).flatten()
    deconv_kernel = tf.get_variable(
        'w', shape=w_full_size.tolist(),
        initializer=self.initializers['w'],
        regularizer=self.regularizers['w'])

    # keep only the selected output/input channels: bring (out, in) to the
    # front, mask each on axis 0, then restore [spatial..., out, in]
    to_front = [spatial_rank, spatial_rank + 1] + \
        list(range(spatial_rank - 1, -1, -1))
    swap = [1, 0] + list(range(2, spatial_rank + 2))
    to_back = list(range(spatial_rank + 1, -1, -1))
    sparse_kernel = tf.transpose(tf.boolean_mask(
        tf.transpose(tf.boolean_mask(
            tf.transpose(deconv_kernel, to_front), _output_mask), swap),
        _input_mask), to_back)

    if spatial_rank == 2:
        op_ = SUPPORTED_OP['2D']
    elif spatial_rank == 3:
        op_ = SUPPORTED_OP['3D']
    else:
        raise ValueError(
            "Only 2D and 3D spatial deconvolutions are supported")

    output_dim = infer_output_dims(input_shape[1],
                                   self.stride,
                                   self.kernel_size,
                                   self.padding)
    # flat [batch, spatial..., channels]; tf.stack needs scalar entries
    sparse_output_size = tf.stack(
        [input_shape[0]] + [output_dim] * spatial_rank +
        [n_sparse_output_chns], 0)
    output_tensor = op_(value=input_tensor,
                        filter=sparse_kernel,
                        output_shape=sparse_output_size,
                        strides=full_stride.tolist(),
                        padding=self.padding,
                        name='deconv')
    if output_mask is None:
        # If all output channels are used, we can specify
        # the number of output channels which is useful for later layers
        old_shape = output_tensor.shape.as_list()
        old_shape[-1] = self.n_output_chns
        output_tensor.set_shape(old_shape)
    if not self.with_bias:
        return output_tensor

    # adding the bias term (restricted to the selected output channels)
    bias_full_size = (self.n_output_chns,)
    bias_term = tf.get_variable(
        'b', shape=bias_full_size,
        initializer=self.initializers['b'],
        regularizer=self.regularizers['b'])
    sparse_bias = tf.boolean_mask(bias_term, _output_mask)
    output_tensor = tf.nn.bias_add(output_tensor, sparse_bias,
                                   name='add_bias')
    return output_tensor
def layer_op(self, input_tensor, input_mask=None, output_mask=None):
    """
    Convolution restricted to a subset of the input and output channels
    of a full-size trainable kernel.

    Parameters:
    input_tensor: image to convolve with kernel (its channels are the
        input channels selected by ``input_mask``)
    input_mask: 1-Tensor with a binary mask of input channels to use.
        If this is None, all channels are used.
    output_mask: 1-Tensor with a binary mask of output channels to
        generate. If this is None, all channels are used and the number of
        output channels is set at graph-creation time.
    """
    sparse_input_shape = input_tensor.shape.as_list()
    if input_mask is None:
        _input_mask = tf.ones([sparse_input_shape[-1]]) > 0
    else:
        _input_mask = input_mask
    if output_mask is None:
        _output_mask = tf.ones([self.n_output_chns]) > 0
    else:
        _output_mask = output_mask
    n_full_input_chns = _input_mask.shape.as_list()[0]
    spatial_rank = layer_util.infer_spatial_rank(input_tensor)

    # initialize conv kernels/strides and then apply
    w_full_size = layer_util.expand_spatial_params(
        self.kernel_size, spatial_rank)
    # expand kernel size to include number of features
    w_full_size = w_full_size + (n_full_input_chns, self.n_output_chns)
    full_stride = layer_util.expand_spatial_params(
        self.stride, spatial_rank)
    full_dilation = layer_util.expand_spatial_params(
        self.dilation, spatial_rank)

    conv_kernel = tf.get_variable(
        'w', shape=w_full_size,
        initializer=self.initializers['w'],
        regularizer=self.regularizers['w'])

    # keep only the selected channels: reverse all axes to get
    # [out, in, spatial(reversed)...], mask out on axis 0, swap, mask in on
    # axis 0, then restore [spatial..., in, out]. The permutations are
    # derived from the rank so that 2D kernels work as well as 3D ones.
    to_front = list(range(spatial_rank + 1, -1, -1))
    swap = [1, 0] + list(range(2, spatial_rank + 2))
    to_back = list(range(spatial_rank + 1, 1, -1)) + [0, 1]
    sparse_kernel = tf.transpose(tf.boolean_mask(
        tf.transpose(tf.boolean_mask(
            tf.transpose(conv_kernel, to_front), _output_mask), swap),
        _input_mask), to_back)

    output_tensor = tf.nn.convolution(input=input_tensor,
                                      filter=sparse_kernel,
                                      strides=full_stride,
                                      dilation_rate=full_dilation,
                                      padding=self.padding,
                                      name='conv')
    if output_mask is None:
        # If all output channels are used, we can specify
        # the number of output channels which is useful for later layers
        old_shape = output_tensor.shape.as_list()
        old_shape[-1] = self.n_output_chns
        output_tensor.set_shape(old_shape)
    if not self.with_bias:
        return output_tensor

    # adding the bias term; use the resolved mask (output_mask may be None)
    bias_term = tf.get_variable(
        'b', shape=self.n_output_chns,
        initializer=self.initializers['b'],
        regularizer=self.regularizers['b'])
    sparse_bias = tf.boolean_mask(bias_term, _output_mask)
    output_tensor = tf.nn.bias_add(output_tensor, sparse_bias,
                                   name='add_bias')
    return output_tensor
def layer_op(self, inputs, sample_coords):
    """
    Resample 2D or 3D ``inputs`` at the locations given by
    ``sample_coords``.

    For 3D inputs of shape ``[batch, x, y, z, num_channels]``,
    ``sample_coords`` may have shape ``[1, d0, d1, ..., 3]`` or
    ``[batch, d0, d1, ..., 3]``. The output then has shape
    ``[batch, d0, d1, ..., num_channels]``. 2D inputs
    ``[batch, x, y, num_channels]`` behave the same way, with a trailing
    coordinate dimension of 2. A coordinate batch of 1 is applied to every
    batch item of ``inputs``.

    If the shape of ``inputs`` is not fully specified, the caller must
    check ``sample_coords`` so that all coordinates point inside
    ``inputs``.

    Coordinates are interpreted differently from
    ``tf.contrib.resampler.resampler``. For 2D data, and apart from the
    batch-size broadcasting, ``self.layer_op(inputs, sample_coords)`` is
    equivalent to::

        tf.contrib.resampler.resampler(
            tf.transpose(inputs, [0, 2, 1, 3]), sample_coords)

    (No gradient is computed for the ``NEAREST`` method or for some of the
    padding modes.)

    :raises TypeError/ValueError: if batch sizes or the coordinate
        dimension are unknown or inconsistent with ``inputs``
    :raises NotImplementedError: for an unknown ``self.interpolation``
    """
    # both batch sizes must be statically known
    try:
        n_batch = int(inputs.shape[0])
        n_coord_batch = int(sample_coords.shape[0])
    except (TypeError, ValueError):
        tf.logging.fatal('Unknown input shape, at least batch size '
                         'needs to be specified.')
        raise

    if n_coord_batch > 1 and n_coord_batch != n_batch:
        tf.logging.fatal(
            '\nOnly the following two cases are currently supported:\n'
            ' - batch size of inputs == batch size of sample_coords\n'
            ' - batch size of sample_coords == 1\n'
            'In the second case, sample_coords will be applied to each of '
            'the batch component of the inputs.')
        raise ValueError

    try:
        coords_n_dim = int(sample_coords.shape[-1])
    except (TypeError, ValueError):
        tf.logging.fatal(
            'The last dim of the coordinates must have 2 or 3 elements.')
        raise
    if coords_n_dim != infer_spatial_rank(inputs):
        tf.logging.fatal(
            'sample_coords.shape[-1] must be the same as the spatial rank '
            'of the inputs.')
        raise ValueError

    # currently converting everything to floats
    if inputs.dtype not in SUPPORTED_INPUT_DTYPE:
        inputs = tf.to_float(inputs)
    if sample_coords.dtype not in SUPPORTED_INPUT_DTYPE:
        sample_coords = tf.to_float(sample_coords)

    # interpolation mode -> name of the implementing method
    method_names = {
        'LINEAR': '_resample_linear',
        'NEAREST': '_resample_nearest',
        'BSPLINE': '_resample_bspline',
        'IDW': '_resample_inv_dst_weighting',
    }
    method_name = method_names.get(self.interpolation)
    if method_name is None:
        tf.logging.fatal('interpolation method not implmented')
        raise NotImplementedError
    return getattr(self, method_name)(inputs, sample_coords)
def _resample_inv_dst_weighting(self, inputs, sample_coords):
    """
    Inverse distance weighting (power 2) over the ``2 ** spatial_rank``
    floor/ceil neighbours of each sample location.

    :param inputs: tensor with the batch on the first axis and channels on
        the last axis; the batch size and the spatial rank must be known
    :param sample_coords: coordinates, one component per spatial
        dimension on the last axis. Its first axis is either the batch size
        of ``inputs`` or 1; with 1, the same coordinates are used for every
        batch item.
    :return: the weighted average of the neighbouring input values
    :raises NotImplementedError: if the coordinate batch size is neither
        the input batch size nor 1 (with an input batch larger than 1)
    """
    # inverse distance weighting using 2^(spatial_rank) neighbours
    in_size = inputs.shape
    partial_shape = not in_size.is_fully_defined()
    try:
        batch_size = int(in_size[0])
        n_coords = int(sample_coords.shape[0])
        in_spatial_rank = infer_spatial_rank(inputs)
        in_spatial_size = \
            None if partial_shape else in_size.as_list()[1:-1]
    except (TypeError, AssertionError, ValueError):
        tf.logging.fatal('Unknown input shape, at least batch size '
                         'and rank of the inputs are required.')
        raise

    out_rank = len(sample_coords.get_shape())
    binary_neighbour_ids = _binary_neighbour_ids(in_spatial_rank)
    # for neighbour `bc`, spatial dim `i` reads entry `bc[i]`
    # (0: floor, 1: ceil) of the stacked floor/ceil tensors below
    weight_id = [[[c, i] for i, c in enumerate(bc)]
                 for bc in binary_neighbour_ids]
    # move the coordinate components (last axis) to the front
    sample_coords_shape = [out_rank - 1, 0] + list(range(1, out_rank - 1))
    sample_coords = tf.transpose(sample_coords, sample_coords_shape)

    if partial_shape or in_spatial_size is None:
        # spatial size unknown: boundary_func cannot be applied, so the
        # raw floor/ceil coordinates are used unchecked
        all_coords_f = tf.stack(
            [tf.floor(sample_coords), tf.ceil(sample_coords)])
    else:
        # broadcasting input spatial size for boundary functions
        expanded_spatial_size = \
            [len(in_spatial_size)] + [1] * (out_rank - 1)
        b_size = tf.reshape(in_spatial_size, expanded_spatial_size)
        # find floor and ceil coordinates
        all_coords_f = tf.stack([
            self.boundary_func(tf.floor(sample_coords), b_size),
            self.boundary_func(tf.ceil(sample_coords), b_size)])

    # find N weights associated to each output point: squared distances to
    # the floor/ceil points, offset by EPS so that a sample lying exactly
    # on a grid point does not give a zero distance
    diff = tf.stack(
        [tf.squared_difference(sample_coords - EPS, all_coords_f[0]),
         tf.squared_difference(sample_coords + EPS, all_coords_f[1])])

    # gather_nd for both matrices, the same as:
    # point_weights = tf.gather_nd(diff, weight_id)
    # knots_id = tf.gather_nd(all_coords_f, weight_id)
    n_val = tf.gather_nd(
        tf.stack([diff, all_coords_f], axis=-1), weight_id)
    n_val = tf.unstack(n_val, axis=-1)
    point_weights, knots_id = n_val[0], n_val[1]

    # inverse distance weighting
    # sum_i (w_i*p_i/(sum_j w_j)) where w_i = 1/((p-p_i)^2)
    # point_weights has the shape:
    # `[N, input_rank, b, sp_dim_0, ..., sp_dim_K]`
    # where:
    #  `N` is 2**source data spatial rank
    #  `b` is batch size,
    #  `sp_dim_0` is the output spatial output 0,
    #
    # `point_weights` represents (p - p_i)^2
    #      with i= 0...2**source_rank neighbours
    # (to do: these operations could be refactored as a resampling kernel)
    point_weights = tf.reduce_sum(point_weights, axis=1)
    # skip this as power = 2.0:
    # self.power = 2.0
    # point_weights = tf.pow(point_weights, self.power / 2.0)
    point_weights = tf.reciprocal(point_weights)
    # normalise over the N neighbours
    point_weights = point_weights / tf.reduce_sum(point_weights, axis=0)

    # find N neighbours associated to each output point
    # (reordered to `[N, b, sp_dim_0, ..., sp_dim_K, input_rank]`)
    # knots_shape = tf.concat([[0], tf.range(2, out_rank + 1), [1]], 0)
    knots_shape = [0] + list(range(2, out_rank + 1)) + [1]
    knots_id = tf.cast(knots_id, COORDINATES_TYPE)
    knots_id = tf.transpose(knots_id, knots_shape)

    # get values of N neighbours
    batch_inputs = tf.unstack(inputs, axis=0)
    batch_knots = tf.unstack(knots_id, axis=1)
    if batch_size == n_coords:
        samples = [tf.gather_nd(img, knot)
                   for (img, knot) in zip(batch_inputs, batch_knots)]
    elif n_coords == 1 and batch_size > 1:
        # one set of coordinates shared by every batch item
        samples = [tf.gather_nd(img, batch_knots[0])
                   for img in batch_inputs]
    else:
        raise NotImplementedError
    samples = tf.stack(samples, axis=1)
    # weighted average over N neighbours
    return tf.reduce_sum(
        samples * tf.expand_dims(point_weights, axis=-1), axis=0)
def _resample_linear(self, inputs, sample_coords):
    """
    Bilinear or trilinear resampling.

    Every output value is interpolated from the 2**spatial_rank integer
    neighbours (floor/ceil along each spatial axis) of its sampling
    location.

    :param inputs: input tensor; dim 0 is read as the batch size (must be
        statically known) and dims 1..-2 as the spatial size
    :param sample_coords: sampling locations; dim 0 must equal the batch
        size of ``inputs`` or be 1 (one set of locations shared by the
        whole batch); the last dim holds one coordinate per input spatial
        axis, component ``i`` indexing spatial axis ``i`` of ``inputs``
    :return: tensor of values interpolated at ``sample_coords``
    :raises NotImplementedError: if dim 0 of ``sample_coords`` is neither
        the batch size of ``inputs`` nor 1
    :raises TypeError, AssertionError, ValueError: re-raised (after a
        fatal log) when the batch size or rank of the inputs cannot be read
    """
    # read input shape
    in_size = inputs.shape
    partial_shape = not in_size.is_fully_defined()

    try:
        batch_size = int(in_size[0])
        n_coords = int(sample_coords.shape[0])
        in_spatial_rank = infer_spatial_rank(inputs)
        in_spatial_size = \
            None if partial_shape else in_size.as_list()[1:-1]
    except (TypeError, AssertionError, ValueError):
        tf.logging.fatal('Unknown input shape, at least batch size '
                         'and rank of the inputs are required.')
        raise

    # read output shape
    out_spatial_rank = infer_spatial_rank(sample_coords)
    out_spatial_size = sample_coords.shape.as_list()[1:-1]

    if in_spatial_rank == 2 and self.boundary == 'ZERO':
        # calling TF's resampler. It treats the first coordinate component
        # as the width (axis 2) index, so the two spatial axes are swapped
        # to make component 0 index axis 1 of the original inputs.
        inputs = tf.transpose(inputs, [0, 2, 1, 3])
        if batch_size == n_coords:
            return tf.contrib.resampler.resampler(inputs, sample_coords)
        if n_coords != 1:
            # same restriction as the general path below
            raise NotImplementedError
        # a single set of coordinates shared by the whole batch: resample
        # each image separately with a leading batch axis of size 1
        # (tf.expand_dims needs an explicit axis, otherwise it raises)
        outputs = [
            tf.contrib.resampler.resampler(
                tf.expand_dims(img, axis=0), sample_coords)
            for img in tf.unstack(inputs)]
        return tf.concat(outputs, axis=0)

    xy = tf.unstack(sample_coords, axis=-1)
    base_coords = [tf.floor(coords) for coords in xy]
    if partial_shape:
        # if input shape is not defined, unable to compute
        # boundary elements
        floor_coords = list(base_coords)
        ceil_coords = [coord + 1.0 for coord in base_coords]
    else:
        floor_coords = [self.boundary_func(x, in_spatial_size[idx])
                        for (idx, x) in enumerate(base_coords)]
        ceil_coords = [self.boundary_func(x + 1.0, in_spatial_size[idx])
                       for (idx, x) in enumerate(base_coords)]

    if self.boundary == 'ZERO':
        # weights measured against the boundary-mapped floor/ceil
        weight_0 = [tf.expand_dims(x - i, -1)
                    for (x, i) in zip(xy, floor_coords)]
        weight_1 = [tf.expand_dims(i - x, -1)
                    for (x, i) in zip(xy, ceil_coords)]
    else:
        # weights measured against the unmapped floor; the pair sums to 1
        weight_0 = [tf.expand_dims(x - i, -1)
                    for (x, i) in zip(xy, base_coords)]
        weight_1 = [1.0 - w for w in weight_0]

    # integer neighbour coordinates, indexed as sc[floor_or_ceil][axis]
    sc = (tf.cast(floor_coords, COORDINATES_TYPE),
          tf.cast(ceil_coords, COORDINATES_TYPE))

    if n_coords == 1 and batch_size > 1:
        # fetch neighbours with the same coordinates across the input batch
        inputs = tf.unstack(inputs)

        def _get_knot(bc):
            coord = [sc[c][i] for i, c in enumerate(bc)]
            coord = tf.stack(coord, axis=-1)
            batch_samples = [tf.gather_nd(img, coord) for img in inputs]
            batch_samples = tf.concat(batch_samples, axis=0)
            return batch_samples

    elif n_coords == batch_size:
        # prepend the batch index so gather_nd picks from the right image
        batch_ids = tf.reshape(
            tf.range(batch_size), [batch_size] + [1] * out_spatial_rank)
        batch_ids = tf.tile(batch_ids, [1] + out_spatial_size)

        def _get_knot(bc):
            coord = [batch_ids] + [sc[c][i] for i, c in enumerate(bc)]
            coord = tf.stack(coord, axis=-1)
            return tf.gather_nd(inputs, coord)
    else:
        raise NotImplementedError

    def _pyramid_combination(samples, w_0, w_1):
        # the case where n_coords = 1 and batch_size > 1 is handled by
        # shape broadcasting
        if len(w_0) == 1:
            return samples[0] * w_1[0] + samples[1] * w_0[0]
        f_0 = _pyramid_combination(samples[::2], w_0[:-1], w_1[:-1])
        f_1 = _pyramid_combination(samples[1::2], w_0[:-1], w_1[:-1])
        return f_0 * w_1[-1] + f_1 * w_0[-1]

    binary_neighbour_ids = _binary_neighbour_ids(in_spatial_rank)
    samples = [_get_knot(bc) for bc in binary_neighbour_ids]

    return _pyramid_combination(samples, weight_0, weight_1)
def layer_op(self,
             fixed_image,
             moving_image,
             base_grid=None,
             is_training=True):
    """
    Predict a dense field for registering ``moving_image`` to ``fixed_image``.

    The two images are concatenated along the channel axis and passed
    through a four-level encoder/decoder built from ``DownRes``/``UpRes``
    blocks. Each selected output is projected to ``spatial_rank`` channels,
    resized to the fixed image's spatial shape, optionally summed across
    scales and smoothed, and finally offset by ``base_grid``. The bending
    energy and gradient norm of the field (before the offset) are recorded
    in the 'bending_energy' / 'gradient_norm' graph collections.

    :param fixed_image: fixed image tensor; its spatial dims are checked to
        be divisible by 16 via ``check_spatial_dims``
    :param moving_image: moving image tensor; resized to the fixed image's
        spatial shape before concatenation
    :param base_grid: reference grid added to the predicted field; when
        ``None`` one is built from ``_create_affine_features``
    :param is_training: forwarded to the encoder, bottleneck and decoder
        blocks
    :return: estimated dense displacement fields (with ``base_grid`` added)
    """
    n_dims = infer_spatial_rank(fixed_image)
    fixed_shape = fixed_image.get_shape().as_list()[1:-1]
    check_spatial_dims(fixed_image, lambda x: x % 16 == 0)

    # resize the moving image to match the fixed one, then stack the pair
    # along the channel axis
    moving_image = Resize(fixed_shape)(moving_image)
    net = tf.concat([moving_image, fixed_image], axis=-1)

    # encoder; only the first level overrides the kernel size
    net, skip, _ = DownRes(
        self.fea[0], kernel_size=7, **self.down_res_param)(net, is_training)
    skips = [skip]
    for level in (1, 2, 3):
        net, skip, _ = DownRes(
            self.fea[level], **self.down_res_param)(net, is_training)
        skips.append(skip)

    # bottleneck
    net = Conv(n_output_chns=self.fea[4],
               kernel_size=self.k_conv,
               **self.down_res_param)(net, is_training)
    # collected coarsest-first: [bottleneck, up_0, up_1, up_2, up_3]
    scales = [net]

    # decoder; each level consumes the matching encoder skip connection
    for level in (3, 2, 1, 0):
        net = UpRes(self.fea[level], **self.up_res_param)(
            net, skips[level], is_training)
        scales.append(net)

    # finest-first when fusing all scales, otherwise only the finest output
    if self.multi_scale_fusion:
        selected = scales[::-1]
    else:
        selected = scales[-1:]

    # project each selected output to an n_dims-channel field on the
    # fixed image's spatial grid
    dense_fields = []
    for feature_map in selected:
        field = Conv(n_output_chns=n_dims,
                     kernel_size=self.k_conv,
                     with_bias=True,
                     with_bn=False,
                     acti_func=None,
                     **self.disp_param)(feature_map)
        dense_fields.append(Resize(new_size=fixed_shape)(field))

    if base_grid is None:
        # no reference grid supplied: build one from the affine features
        grid = _create_affine_features(output_shape=fixed_shape,
                                       source_shape=[None] * n_dims)
        grid = np.asarray(grid[:-1])
        grid = np.reshape(grid.T, [-1] + fixed_shape + [n_dims])
        base_grid = tf.constant(grid, dtype=dense_fields[-1].dtype)

    fuse = self.multi_scale_fusion and len(dense_fields) > 1
    if fuse:
        dense_field = tf.reduce_sum(dense_fields, axis=0)
    else:
        dense_field = dense_fields[0]

    # TODO filtering
    if self.smoothing_func is not None:
        dense_field = self.smoothing_func(dense_field, n_dims)

    tf.add_to_collection('bending_energy',
                         _computing_bending_energy(dense_field))
    tf.add_to_collection('gradient_norm',
                         _computing_gradient_norm(dense_field))

    return dense_field + base_grid