def ssim(self, y_true, y_pred):
    """Return the mean structural similarity (SSIM) of two 5-D tensors.

    Both inputs are cut into patches with ``tf.extract_volume_patches``
    (window ``self.__kernel_shape``, stride ``self.__stride``, VALID padding).
    Per patch, the mean, (population) variance and covariance are taken over
    the flattened patch contents on the last axis, and the SSIM formula is
    applied with the stabilising constants ``self.__c1`` and ``self.__c2``.

    Returns:
        Scalar tensor: SSIM averaged over every patch position.
    """
    true_p = tf.extract_volume_patches(y_true, self.__kernel_shape,
                                       self.__stride, 'VALID', 'patches_true')
    pred_p = tf.extract_volume_patches(y_pred, self.__kernel_shape,
                                       self.__stride, 'VALID', 'patches_pred')

    # First and second moments of every patch.
    mu_t = tf.reduce_mean(true_p, axis=-1)
    mu_p = tf.reduce_mean(pred_p, axis=-1)
    var_t = tf.math.reduce_variance(true_p, axis=-1)
    var_p = tf.math.reduce_variance(pred_p, axis=-1)
    # Covariance as E[xy] - E[x]E[y], consistent with the population variance.
    cov = tf.reduce_mean(true_p * pred_p, axis=-1) - mu_t * mu_p

    top = (2 * mu_t * mu_p + self.__c1) * (2 * cov + self.__c2)
    bottom = ((tf.square(mu_t) + tf.square(mu_p) + self.__c1) *
              (var_p + var_t + self.__c2))
    return tf.reduce_mean(top / bottom)
def _sliding_window(I, M):
    '''
    Tile a 3D image and its mask into non-overlapping [128, 128, 32] patches.

    Used for training images, where artefacts at patch borders are of no
    concern.  Windows move by their own size (stride == kernel), VALID
    padding, so any remainder at the edges is discarded.

    Returns:
        (image_patches, mask_patches) reshaped to [-1, 128, 128, 32, 1] and
        [-1, 128, 128, 32, 2] respectively (the mask is expected to carry
        2 channels).
    '''
    window = [1, 128, 128, 32, 1]
    image_patches = tf.extract_volume_patches(
        I, ksizes=window, strides=window, padding='VALID')
    mask_patches = tf.extract_volume_patches(
        M, ksizes=window, strides=window, padding='VALID')
    return (tf.reshape(image_patches, [-1, 128, 128, 32, 1]),
            tf.reshape(mask_patches, [-1, 128, 128, 32, 2]))
def _sliding_window(self, I):
    """Cut a 5-D image into [128, 128, 32] patches with 50% overlap.

    The three spatial axes are zero-padded by half a window on each side
    (64, 64, 16), then windows of [128, 128, 32] are taken with stride
    [64, 64, 16] and VALID padding.

    Returns:
        Tensor of shape [-1, 128, 128, 32, 1].
    """
    half_window_pad = ((0, 0), (64, 64), (64, 64), (16, 16), (0, 0))
    padded = tf.pad(I, half_window_pad)
    flat_patches = tf.extract_volume_patches(padded,
                                             ksizes=[1, 128, 128, 32, 1],
                                             strides=[1, 64, 64, 16, 1],
                                             padding='VALID')
    return tf.reshape(flat_patches, [-1, 128, 128, 32, 1])
def extract_patch(image, *args, **kwargs):
    """Extract 3-D patches from ``image`` using SAME padding.

    The window and stride are read from the required keyword arguments
    ``ksizes`` and ``strides`` (a missing one raises ``KeyError``).
    Any extra positional arguments or keyword arguments are ignored.
    """
    ksizes = kwargs['ksizes']
    strides = kwargs['strides']
    return tf.extract_volume_patches(image, ksizes=ksizes, strides=strides,
                                     padding="SAME")
def _sliding_window_overlap(I, M):
    '''
    Tile a 3D image and its mask into [128, 128, 32] patches with 50% overlap.

    Used for evaluation images, so that artefacts at patch borders can be
    removed.  Each spatial axis is zero-padded by half a window (64, 64, 16)
    and windows move by half their size.

    Returns:
        (image_patches, mask_patches) of shapes [-1, 128, 128, 32, 1] and
        [-1, 128, 128, 32, 2].
    '''
    def _tile(volume, channels):
        # Pad, take overlapping windows, then unflatten each window.
        padded = tf.pad(volume,
                        ((0, 0), (64, 64), (64, 64), (16, 16), (0, 0)))
        flat = tf.extract_volume_patches(padded,
                                         ksizes=[1, 128, 128, 32, 1],
                                         strides=[1, 64, 64, 16, 1],
                                         padding='VALID')
        return tf.reshape(flat, [-1, 128, 128, 32, channels])

    return _tile(I, 1), _tile(M, 2)
def gen_patches(data, patch_slices, patch_rows, patch_cols, stride_slices,
                stride_rows, stride_cols, input_dim_order='XYZ',
                padding='VALID'):
    """Cut a single-channel 3-D numpy volume into a grid of 3-D patches.

    Args:
        data: 3-D array in ``input_dim_order`` ('XYZ' is transposed to ZYX
            first; anything else is taken as already ZYX).
        patch_slices, patch_rows, patch_cols: patch size along Z, Y, X.
        stride_slices, stride_rows, stride_cols: step along Z, Y, X.
        input_dim_order: 'XYZ' or 'ZYX'.
        padding: 'VALID' (no padding) or 'SAME' (zero-padding).

    Returns:
        numpy array of shape (n_z, n_y, n_x, patch_slices, patch_rows,
        patch_cols), or None (after printing a warning) if ``data`` is not
        3-D.
    """
    data = np.asarray(data)
    # Validate the rank BEFORE transposing: np.transpose with a 3-axis
    # permutation raises on non-3-D input, which used to make this warning
    # unreachable for 'XYZ' data.
    if data.ndim != 3:
        print(
            'WARNING! Function is only meant to be used for data with one channel'
        )
        return None

    # Reorder the dimensions to ZYX.
    if input_dim_order == 'XYZ':
        data = np.transpose(data, axes=(2, 1, 0))

    # Add batch and channel axes -> (1, Z, Y, X, 1), as extract_volume_patches
    # requires a 5-D input.
    volume = data[np.newaxis, :, :, :, np.newaxis]

    # Build the op in a private graph so repeated calls do not keep growing
    # the process-wide default graph (and so it works even if eager mode is on).
    with tf.Graph().as_default(), tf.Session() as sess:
        t = tf.extract_volume_patches(
            volume,
            ksizes=[1, patch_slices, patch_rows, patch_cols, 1],
            strides=[1, stride_slices, stride_rows, stride_cols, 1],
            padding=padding).eval(session=sess)

    # t: (1, n_z, n_y, n_x, patch_slices * patch_rows * patch_cols).  Each
    # patch is flattened in (z, y, x) order, so a plain reshape restores it.
    t = np.reshape(t, (1, t.shape[1], t.shape[2], t.shape[3],
                       patch_slices, patch_rows, patch_cols))
    # Remove the batch dimension.
    return t[0]
def _split(self, input, path=None, input_is_latent=False,
           filter_clipped_blocks=False, full_block_latent=None,
           empty_block_latent=None, empty_block_detection_threshold=1e-5):
    """Cut a grid (voxel values or latent codes) into cubic blocks.

    The input receives a leading batch axis and is windowed with
    ``tf.extract_volume_patches`` (VALID padding):
      * voxel input: window ``self.input_size()``, stride
        ``self.focus_size()``, 1 channel;
      * latent input: window ``self.total_blocks()``, stride
        ``self.focused_blocks``, ``self.latent_channel_size`` channels.
    ``self.complete_latent_mask`` is windowed the same way as a latent grid.
    Both results are reshaped to ``self.number_of_blocks_per_voxelgrid()``
    cubic blocks.

    If ``filter_clipped_blocks`` is set, blocks that are entirely "empty" or
    entirely "filled" are removed.
      * Empty means all values equal ``self.truncation_threshold``, or are
        within ``empty_block_detection_threshold`` of ``empty_block_latent``
        when that is given.
      * Filled means all values equal ``-self.truncation_threshold``, or are
        within the same threshold of ``full_block_latent``.

    Returns:
        ``(input_patches, latent_patches)`` when not filtering, otherwise
        ``(input_patches, latent_patches, types, path)``.  ``types`` has
        shape (1, num_blocks) and holds 1 for empty, -1 for filled and 0 for
        kept blocks.  Empty wins if a block is both.
    """
    batched = tf.expand_dims(input, 0)
    total = self.total_blocks()
    if input_is_latent:
        window, step, channels = (total, self.focused_blocks,
                                  self.latent_channel_size)
    else:
        window, step, channels = self.input_size(), self.focus_size(), 1

    input_patches = tf.extract_volume_patches(
        batched, [1, window, window, window, 1], [1, step, step, step, 1],
        "VALID")
    # NOTE(review): assumes complete_latent_mask is a 4-D numpy array, so the
    # expand_dims below yields the 5-D input extract_volume_patches needs.
    mask = tf.constant(np.expand_dims(self.complete_latent_mask, 0))
    focused = self.focused_blocks
    latent_patches = tf.extract_volume_patches(
        mask, [1, total, total, total, 1], [1, focused, focused, focused, 1],
        "VALID")

    num_blocks = self.number_of_blocks_per_voxelgrid()
    input_patches = tf.reshape(input_patches,
                               [num_blocks, window, window, window, channels])
    latent_patches = tf.reshape(latent_patches,
                                [num_blocks, total, total, total, 1])

    if not filter_clipped_blocks:
        return input_patches, latent_patches

    # Per-block flags: every value of the block matches the empty/full pattern.
    if empty_block_latent is None:
        empty_hits = tf.equal(input_patches, self.truncation_threshold)
    else:
        empty_hits = tf.less_equal(
            tf.abs(input_patches - empty_block_latent),
            empty_block_detection_threshold)
    empty_blocks = tf.reduce_all(empty_hits, [1, 2, 3, 4])

    if full_block_latent is None:
        full_hits = tf.equal(input_patches, -self.truncation_threshold)
    else:
        full_hits = tf.less_equal(
            tf.abs(input_patches - full_block_latent),
            empty_block_detection_threshold)
    filled_blocks = tf.reduce_all(full_hits, [1, 2, 3, 4])

    # 1 = empty, -1 = filled, 0 = everything else.
    ones = tf.ones((num_blocks,))
    types = tf.where(empty_blocks, ones,
                     tf.where(filled_blocks, ones * -1, ones * 0))
    types = tf.expand_dims(types, 0)

    # Keep only blocks that are neither fully empty nor fully filled.
    keep = tf.logical_not(tf.logical_or(empty_blocks, filled_blocks))
    return (tf.boolean_mask(input_patches, keep),
            tf.boolean_mask(latent_patches, keep),
            types,
            path)
def extract_patches(self, data_4d, stride):
    """Cut a 4-D tensor into cubic patches of edge ``self.CUBE``.

    A trailing channel axis is appended to ``data_4d`` (presumably
    (batch, D, H, W) — confirm with callers).  Cubes are then taken with
    the given stride on all three spatial axes (VALID padding).

    Returns:
        Tensor of shape (num_patches, CUBE, CUBE, CUBE, 1).
    """
    cube = self.CUBE
    volume = tf.expand_dims(data_4d, -1)
    flat = tf.extract_volume_patches(
        input=volume,
        ksizes=[1] + [cube] * 3 + [1],
        strides=[1] + [stride] * 3 + [1],
        padding='VALID',
    )
    # Unflatten each patch, then restore the single channel axis.
    return tf.expand_dims(tf.reshape(flat, [-1, cube, cube, cube]), -1)
def spatially_constrained_loss(data_dict, kernal_size=3, sigma=0.5):
    """Per-voxel spatial-consistency loss for 2-D or 3-D segmentation logits.

    Each voxel i is compared with every neighbour j in a ``kernal_size``
    window (SAME padding; out-of-image neighbours are masked out):

        F_ij = exp(-(o_i - o_j)^2 / (2 sigma^2))   # intensity affinity
               * (+1 if argmax labels agree else -1)
               * conf_i * conf_j                    # max softmax probs

    The self term is removed from both the sum of F_ij and the sum of
    affinities (exp(0) = 1), giving a weighted mean F_i.  The loss is 1 - F_i.

    Args:
        data_dict: mapping with 'orgs' (original images) and 'logits'.
            'logits' must be rank 4 (N, H, W, C) or rank 5 (N, D, H, W, C).
            NOTE(review): 'orgs' is assumed to have a single channel so it
            broadcasts against its own patches — confirm.
        kernal_size: int, or list/tuple of per-spatial-axis window sizes.
        sigma: bandwidth of the Gaussian intensity affinity.

    Returns:
        Tensor with the shape of ``logits`` minus its last axis.

    Raises:
        ValueError: if ``logits`` is not rank 4 or 5.
    """
    orgs = tf.cast(data_dict['orgs'], tf.float32)
    logits = tf.cast(data_dict['logits'], tf.float32)
    ndim = len(logits.shape)
    if ndim not in (4, 5):
        raise ValueError('only allow 2d or 3d images without RGB channel.')

    # Expand to the full [1, *spatial, 1] window spec TF expects.
    if isinstance(kernal_size, int):
        kernal_size = [1] + [kernal_size] * (ndim - 2) + [1]
    elif isinstance(kernal_size, (list, tuple)):
        kernal_size = [1] + list(kernal_size) + [1]
    strides = [1] * ndim
    rates = [1] * ndim

    probs = tf.nn.softmax(logits)
    confs = tf.reduce_max(probs, -1, keepdims=True)
    arg_preds = tf.cast(tf.expand_dims(tf.argmax(probs, -1), -1), tf.float32)

    if ndim == 4:
        def extract(x):
            return tf.image.extract_patches(x, kernal_size, strides, rates,
                                            padding='SAME')
    else:
        def extract(x):
            return tf.extract_volume_patches(x, kernal_size, strides,
                                             padding='SAME')

    # ones_like (not tf.ones(confs.shape)) keeps working with an unknown
    # batch dimension.  The zero-padded patches of this mask flag
    # out-of-image neighbours.
    p_zmask = extract(tf.ones_like(confs))
    p_confs = extract(confs)
    p_orgs = extract(orgs)
    p_preds = extract(arg_preds)

    p_exp = tf.exp(-tf.square(orgs - p_orgs) / (2 * sigma**2))
    p_exp = p_zmask * p_exp
    # tf.equal is element-wise in both graph and eager mode.  In TF1 graph
    # mode, `==` on tensors is Python identity and returns a plain bool.
    p_mask = 2 * tf.cast(tf.equal(arg_preds, p_preds), tf.float32) - 1
    u_ij = p_exp * p_mask
    P_ij = confs * p_confs
    F_ij = u_ij * P_ij
    # Subtract the self term (u_ii = 1, P_ii = conf_i^2) from the numerator
    # and the self affinity exp(0) = 1 from the denominator.
    F_i = ((tf.reduce_sum(F_ij, -1) - tf.squeeze(confs**2, -1)) /
           (tf.reduce_sum(p_exp, -1) - 1 + 1e-9))
    sc_loss_map = 1 - F_i
    return sc_loss_map
def gen_patches(session, data, patch_slices, patch_rows, patch_cols,
                stride_slices, stride_rows, stride_cols, input_dim_order='XYZ',
                padding='VALID'):
    """
    Generates patches of the Numpy Array data of the size
    patch_slices x patch_rows x patch_cols with stride
    stride_slices x stride_rows x stride_cols.

    Parameters:
        session (tf.Session): Session used to evaluate the patch extraction.
            The op is built inside ``session.graph``.
        data (Numpy Array): 3-D single-channel array out of which the
            patches are generated.
        patch_slices (int): Number of slices (z-size) one patch should have.
        patch_rows (int): Number of rows (y-size) one patch should have.
        patch_cols (int): Number of columns (x-size) one patch should have.
        stride_slices (int): Stride in slice direction (z-direction).
        stride_rows (int): Stride in row direction (y-direction).
        stride_cols (int): Stride in column direction (x-direction).
        input_dim_order (String): Dimension order of data. Can be 'XYZ'
            (transposed to ZYX first) or 'ZYX'.
        padding (String): Padding to use. Can be 'VALID' (no padding) or
            'SAME' (with zero-padding).

    Returns:
        patches (Numpy Array): Generated patches of shape
        (n_slice_patches, n_row_patches, n_col_patches,
         patch_slices, patch_rows, patch_cols),
        or None (after printing a warning) if data is not 3-D.
    """
    data = np.asarray(data)
    # Check the rank BEFORE transposing: a 3-axis transpose raises on
    # non-3-D input, which used to make this warning unreachable for 'XYZ'.
    if data.ndim != 3:
        print('WARNING! Function is only meant to be used for data with one channel')
        return None

    # Reorder the dimensions to ZYX.
    if input_dim_order == 'XYZ':
        data = np.transpose(data, axes=(2, 1, 0))

    # Add batch and channel axes -> (1, Z, Y, X, 1).
    volume = data[np.newaxis, :, :, :, np.newaxis]

    # Create the op in the session's graph so session.run can evaluate it
    # even when that graph is not the current default graph.
    with session.graph.as_default():
        t = tf.extract_volume_patches(
            volume,
            ksizes=[1, patch_slices, patch_rows, patch_cols, 1],
            strides=[1, stride_slices, stride_rows, stride_cols, 1],
            padding=padding)
    t = session.run(t)

    # t: (1, n_z, n_y, n_x, patch_slices * patch_rows * patch_cols).  Each
    # patch is flattened in (z, y, x) order, so a plain reshape restores it.
    t = np.reshape(t, (1, t.shape[1], t.shape[2], t.shape[3],
                       patch_slices, patch_rows, patch_cols))
    # Remove the batch dimension.
    return t[0]
# Generate 3D Volume with ZYX-Dimensions of 6x6x6 #vol_in = gen_volume(6,6,6) vol_in = gen_volume2(6, 6, 6) #vol_in = gen_volume3(6,6,6,3) # Expand the dimension for batches vol_in_exp = np.expand_dims(vol_in, axis=0) # Expand dimension for depth vol_in_exp = np.expand_dims(vol_in_exp, axis=4) # Extract 4 patches of 6x3x3 (ZYX) with tf.Session() as sess: t = tf.extract_volume_patches(vol_in_exp, ksizes=[1, 6, 3, 3, 1], strides=[1, 6, 3, 3, 1], padding='VALID').eval() print(t) # Reduce the dimensions of batches and planes patches = t[0, 0, :, :, :] # Extract the patches y1x1 = patches[0, 0, :] y1x2 = patches[0, 1, :] y2x1 = patches[1, 0, :] y2x2 = patches[1, 1, :] # Reshape the patches to its volumes (ZYX) y1x1 = np.reshape(y1x1, (6, 3, 3)) y1x2 = np.reshape(y1x2, (6, 3, 3))
def _train(self, data, log_images):
    """Run one training step for the world model, the A2C actor and the critic.

    Three gradient tapes are recorded and each is handed to its optimizer:
      * model_tape: encoder / RSSM dynamics / decoders, trained with a KL
        term (clipped below at ``free_nats``) minus the mean log-likelihoods
        of the observed image, reward and (optionally) discount.
      * actor_tape: actor trained by weighting -log pi(argmax action) with
        an advantage computed on imagined trajectories.
      * value_tape: critic regressed onto ``td_n``-step sums of imagined
        rewards.

    Args:
        data: batch mapping with 'action', 'reward', 'image' and, when
            ``self._c.pcont`` is set, 'discount'.  Original comments give
            per-replica example shapes such as image (25, 50, 64, 64, 3)
            and action (25, 50, 4).
        log_images: boolean (tensor) flag; when true, replica 0 writes image
            summaries.
    """
    td_n = 10  # TD-N window: number of imagined rewards summed per target

    with tf.GradientTape() as model_tape:
        # World model (RSSM dynamics): filter the observed sequence.
        embed = self._encode(data)
        post, prior = self._dynamics.observe(embed, data['action'])
        feat = self._dynamics.get_feat(post)
        image_pred = self._decode(feat)
        reward_pred = self._reward(feat)
        # Mean log-likelihoods of the observed data under the model heads.
        likes = tools.AttrDict()
        likes.image = tf.reduce_mean(image_pred.log_prob(data['image']))
        likes.reward = tf.reduce_mean(reward_pred.log_prob(data['reward']))
        if self._c.pcont:
            # Optionally learn the continuation (discount) head as well.
            pcont_pred = self._pcont(feat)
            pcont_target = self._c.discount * data['discount']
            likes.pcont = tf.reduce_mean(pcont_pred.log_prob(pcont_target))
            likes.pcont *= self._c.pcont_scale
        prior_dist = self._dynamics.get_dist(prior)
        post_dist = self._dynamics.get_dist(post)
        div = tf.reduce_mean(tfd.kl_divergence(post_dist, prior_dist))
        # Free nats: no KL penalty below this threshold.
        div = tf.maximum(div, self._c.free_nats)
        # VAE-style loss: scaled KL minus the summed log-likelihoods.
        model_loss = self._c.kl_scale * div - sum(likes.values())
        model_loss /= float(self._strategy.num_replicas_in_sync)

    # A2C actor (replaces Dreamer's value-gradient actor).
    with tf.GradientTape() as actor_tape:
        # Imagine `horizon` steps from every posterior state, also returning
        # the actions that were taken.
        imag_feat, imagine_action = self._imagine_ahead_and_get_action(post)
        if self._c.pcont:
            # With pcont the last sequence step is not used ("do not take the
            # last one").
            reduce_batch_length = self._c.batch_length - 1
        else:
            reduce_batch_length = self._c.batch_length
        reduce_horizon = self._c.horizon - 1
        # Start positions along the sequence kept for the TD-N targets.
        n_starts = reduce_batch_length - td_n

        imagine_action = tf.reshape(
            imagine_action,
            [self._c.horizon, -1, reduce_batch_length, self._actdim])
        imagine_action = imagine_action[:, :, :n_starts, :]
        argmax_imagine_action = tf.argmax(imagine_action, -1)

        # td_n-step reward sums via a sliding window along the sequence axis.
        reward = self._reward(imag_feat).mode()
        reward = tf.reshape(reward,
                            [self._c.horizon, -1, reduce_batch_length, 1])
        # extract_volume_patches needs 5-D input: (horizon, batch, seq, 1, 1).
        # The [1, 1, td_n, 1, 1] window slides over the sequence axis only.
        sum_reward = tf.extract_volume_patches(
            tf.expand_dims(reward, -1), [1, 1, td_n, 1, 1], [1, 1, 1, 1, 1],
            "VALID")
        # VALID leaves reduce_batch_length - td_n + 1 windows.  Trim to
        # n_starts so sum_reward lines up with st_value below.  This used to
        # be done only when pcont was set, and the non-pcont path then failed
        # with a shape mismatch.
        sum_reward = tf.reduce_sum(sum_reward, -1)[:, :, :n_starts, :]

        if self._c.pcont:
            pcont = self._pcont(imag_feat).mean()
        else:
            pcont = self._c.discount * tf.ones_like(reward)
        # Cumulative discount weights over horizon-1 steps.  The gradient is
        # deliberately NOT stopped so a learned pcont receives it.
        discount = tf.math.cumprod(
            tf.concat([tf.ones_like(pcont[:1]), pcont[:-2]], 0), 0)
        discount = tf.reshape(discount,
                              [reduce_horizon, -1, reduce_batch_length, 1])
        discount = discount[:, :, :n_starts, :]

        value = self._value(imag_feat).mode()
        value = tf.reshape(value, [self._c.horizon, -1, reduce_batch_length, 1])
        st_value = value[:, :, :n_starts]
        # NOTE(review): bootstraps from position t+1, while the original
        # comments describe st+1 as positions [td_n:] (t+td_n); confirm.
        stp1_value = value[:, :, 1:1 + n_starts]
        # NOTE(review): the usual advantage is R + V(s') - V(s); the value
        # terms here have the opposite sign.  Kept as-is pending confirmation.
        advantage = tf.stop_gradient(sum_reward + st_value - stp1_value)
        # Cross-entropy of the action distribution against its own argmax,
        # i.e. -log pi(argmax a).  Weighted by the (stopped) advantage.
        policy_gradient = tf.keras.losses.sparse_categorical_crossentropy(
            argmax_imagine_action, imagine_action, from_logits=False)
        policy_gradient = tf.expand_dims(policy_gradient, -1) * advantage
        # Drop the last horizon step to match discount's horizon-1 length.
        policy_gradient = policy_gradient[:-1] * discount
        actor_loss = tf.reduce_mean(policy_gradient)
        actor_loss /= float(self._strategy.num_replicas_in_sync)

    with tf.GradientTape() as value_tape:
        # Recompute values under this tape so the critic receives gradients.
        value = self._value(imag_feat).mode()
        value = tf.reshape(value, [self._c.horizon, -1, reduce_batch_length, 1])
        st_value = value[:, :, :n_starts]
        # Regress V(s_t) onto the td_n-step reward sum.
        value_MSE = tf.keras.losses.MSE(tf.stop_gradient(sum_reward), st_value)
        value_MSE = tf.expand_dims(value_MSE[:-1], -1) * discount
        value_loss = tf.reduce_mean(value_MSE)
        value_loss /= float(self._strategy.num_replicas_in_sync)

    model_norm = self._model_opt(model_tape, model_loss)
    actor_norm = self._actor_opt(actor_tape, actor_loss)
    value_norm = self._value_opt(value_tape, value_loss)

    # Only the first replica writes summaries.
    if tf.distribute.get_replica_context().replica_id_in_sync_group == 0:
        if self._c.log_scalars:
            self._scalar_summaries(data, feat, prior_dist, post_dist, likes,
                                   div, model_loss, value_loss, actor_loss,
                                   model_norm, value_norm, actor_norm)
        if tf.equal(log_images, True):
            self._image_summaries(data, embed, image_pred)
vol_in_exp = np.expand_dims(vol_in_exp, axis=4) # Specify the kernel-size -> volume-size of the extracted patches k_size_z = 4 # k_size_planes k_size_y = 2 # ksize_rows k_size_x = 5 # ksize_cols # Specify the strides stride_z = k_size_z # stride_planes stride_y = k_size_y # stride_rows stride_x = k_size_x # stride_cols # Extract patches of k_size_z x k_size_y x k_size_x with tf.Session() as sess: t = tf.extract_volume_patches(vol_in_exp, ksizes=[1, k_size_z, k_size_y, k_size_x, 1], strides=[1, stride_z, stride_y, stride_x, 1], padding='VALID').eval() print(t) # Reshape the patches to 3D # t.shape[1] -> number of extracted patches in z-direction # t.shape[2] -> number of extracted patches in y-direction # t.shape[3] -> number of extracted patches in x-direction t = tf.reshape( t, [1, t.shape[1], t.shape[2], t.shape[3], k_size_z, k_size_y, k_size_x ]).eval() # Reduce the dimensions of batches patches = t[0, :, :, :, :]
# Experiment: build integer coordinate grids, gather each voxel's 3x3x3
# neighbourhood of coordinates, and create learnable offset variables.
# NOTE(review): `n` and `initializer` come from outside this fragment;
# feat_h / feat_w are not used in the visible code.
feat_h, feat_w = [int(i) for i in [n, n]]
# Coordinate grids of shape (10, 10, 5) (default 'xy' meshgrid indexing),
# transposed to (5, 10, 10).
x, y, z = tf.meshgrid(tf.range(10), tf.range(10), tf.range(5))
x = tf.transpose(x, [2, 0, 1])
y = tf.transpose(y, [2, 0, 1])
z = tf.transpose(z, [2, 0, 1])
print(z.numpy())
# _, _,z = tf.meshgrid(tf.range(1), tf.range(1),tf.range(64))
# Add batch and channel axes: 5-D [1, *grid_shape, 1] (the original comment
# said [1, h, w, 1], but the grids here are 3-D).
x, y, z = [tf.reshape(i, [1, *i.get_shape(), 1]) for i in [x, y, z]]
# _, _,_,z = tf.meshgrid(tf.range(1), tf.range(1),tf.range(9),tf.range(64))
# Gather each voxel's 3x3x3 neighbourhood (SAME padding, stride 1).  The last
# axis becomes 27 coordinate values.
x, y, z = [
    tf.extract_volume_patches(i, [1, 3, 3, 3, 1], [1, 1, 1, 1, 1], 'SAME')
    for i in [x, y, z]
]
print(x.numpy())
#x= tf.expand_dims(x,axis=-1);
#y= tf.expand_dims(y,axis=-1);
# NOTE(review): stacks x with z; y is computed but not used — confirm intent.
pix = tf.stack([x, z], axis=-1)
print(pix.numpy())
# NOTE(review): offsets have 9 entries (a 3x3 kernel), but the patches above
# have 27 entries (3x3x3); confirm the intended kernel size.
xOffset = tf.Variable(lambda: initializer(shape=[9], dtype=tf.float32), dtype=tf.float32, name="xOffset")
yOffset = tf.Variable(lambda: initializer(shape=[9], dtype=tf.float32), dtype=tf.float32, name="yOffset")
x = tf.cast(x, tf.float32)