def __init__(self, **kwargs):
    """Build the global (affine) registration network.

    Encodes ``self.input_layer`` with four down-sampling residual blocks and
    a deep convolution block, then regresses 12 affine parameters with a
    fully connected layer initialised from ``self.transform_initial``. The
    resulting warped grid is stored in ``self.grid_warped`` and the matching
    displacement field (warped grid minus reference grid) in ``self.ddf``.
    """
    BaseNet.__init__(self, **kwargs)
    # Defaults.
    self.num_channel_initial_global = 8
    # 12 affine parameters forming the identity transform; presumably the top
    # 3x4 of a 4x4 matrix in row-major order -- TODO confirm in util.warp_grid.
    self.transform_initial = [
        1., 0., 0., 0.,
        0., 1., 0., 0.,
        0., 0., 1., 0.,
    ]
    # Channel width doubles at every level.
    channels = [int(self.num_channel_initial_global * (2 ** level))
                for level in range(5)]

    # The first stage takes the 2-channel input and a 7x7x7 first kernel.
    # The skip outputs are discarded: the global net has no decoder.
    features, _ = layer.downsample_resnet_block(
        self.input_layer, 2, channels[0], k_conv0=[7, 7, 7],
        name='global_down_0')
    for level in range(1, 4):
        features, _ = layer.downsample_resnet_block(
            features, channels[level - 1], channels[level],
            name='global_down_%d' % level)

    deepest = layer.conv3_block(features, channels[3], channels[4],
                                name='global_deep_4')
    theta = layer.fully_connected(deepest, 12, self.transform_initial,
                                  name='global_project_0')
    self.grid_warped = util.warp_grid(self.grid_ref, theta)
    # Express the affine warp as a dense displacement field.
    self.ddf = self.grid_warped - self.grid_ref
def __init__(self, ddf_levels=None, **kwargs):
    """Build the local (dense displacement) registration network.

    An encoder-decoder: four down-sampling residual blocks and a deep
    convolution block, followed by up-sampling residual blocks that stop at
    the finest level requested in ``ddf_levels``. A displacement summand is
    predicted from each requested level and the summands are added to form
    ``self.ddf``; ``self.grid_warped`` is the reference grid plus that DDF.

    Args:
        ddf_levels: levels whose features contribute to the DDF. Level 4 is
            the deep (bottleneck) block output and level 0 the last decoder
            stage. Defaults to ``[0, 1, 2, 3, 4]``.
    """
    BaseNet.__init__(self, **kwargs)
    # Defaults.
    if ddf_levels is None:
        ddf_levels = [0, 1, 2, 3, 4]
    self.ddf_levels = ddf_levels
    self.num_channel_initial = 32
    channels = [int(self.num_channel_initial * (2 ** level))
                for level in range(5)]

    # Encoder: keep every stage's skip output for the decoder.
    skips = []
    features, skip = layer.downsample_resnet_block(
        self.input_layer, 2, channels[0], k_conv0=[7, 7, 7],
        name='local_down_0')
    skips.append(skip)
    for level in range(1, 4):
        features, skip = layer.downsample_resnet_block(
            features, channels[level - 1], channels[level],
            name='local_down_%d' % level)
        skips.append(skip)

    # decoder[k] holds the feature map for level 4 - k.
    decoder = [layer.conv3_block(features, channels[3], channels[4],
                                 name='local_deep_4')]
    finest_level = min(self.ddf_levels)
    # Up-sample only as far as the finest level that feeds the DDF.
    for level in (3, 2, 1, 0):
        if not finest_level < level + 1:
            break
        decoder.append(layer.upsample_resnet_block(
            decoder[-1], skips[level], channels[level + 1], channels[level],
            name='local_up_%d' % level))

    summands = [
        layer.ddf_summand(decoder[4 - idx], channels[idx], self.image_size,
                          name='sum_%d' % idx)
        for idx in self.ddf_levels
    ]
    self.ddf = tf.reduce_sum(tf.stack(summands, axis=5), axis=5)
    self.grid_warped = self.grid_ref + self.ddf
def build_network(self):
    """Build the affine, bidirectional (cycle-consistent) registration graph.

    A shared encoder sees the concatenated moving and fixed images. It
    regresses two 12-parameter affine transforms:
      * ``theta_fw`` warps the moving image/label towards the fixed one;
      * ``theta_bw`` warps the fixed image/label towards the moving one.

    The training objective is the sum of:
      * a multi-scale Dice loss between each input label and the other
        label warped onto it (both directions);
      * a consistency term, weighted by ``args.lambda_consis``, comparing each
        input image with the same image warped forward and then back.

    ``self.train_op`` minimises this objective with Adam. It uses the
    exponentially decayed ``self.learning_rate`` and increments
    ``self.global_step`` on every update.
    """
    self.global_step = tf.Variable(0, trainable=False)
    self.learning_rate = tf.train.exponential_decay(
        self.args.lr, self.global_step, self.args.decay_freq, 0.96,
        staircase=True)
    self.grid_ref = util.get_reference_grid(self.image_size)
    # Initial zeros are safer for debugging; not overwritten in this method.
    self.grid_warped_MV_FIX = tf.zeros_like(self.grid_ref)
    self.grid_warped_FIX_MV = tf.zeros_like(self.grid_ref)

    # Inputs (created in the same order as before, so graph op names match).
    batch_shape = [self.args.batch_size]
    volume_shape = batch_shape + self.image_size
    self.ph_MV_image = tf.placeholder(tf.float32, volume_shape + [1])
    self.ph_FIX_image = tf.placeholder(tf.float32, volume_shape + [1])
    # Augmentation affine: a 4x4 matrix whose last row is always
    # (0, 0, 0, 1), so only its 12 free entries are fed.
    self.ph_moving_affine = tf.placeholder(tf.float32, batch_shape + [1, 12])
    self.ph_fixed_affine = tf.placeholder(tf.float32, batch_shape + [1, 12])
    self.ph_random_ddf = tf.placeholder(tf.float32, volume_shape + [3])
    self.ph_MV_label = tf.placeholder(tf.float32, volume_shape + [1])
    self.ph_FIX_label = tf.placeholder(tf.float32, volume_shape + [1])

    self.input_MV_image, self.input_MV_label = util.augment_3Ddata_by_affine(
        self.ph_MV_image, self.ph_MV_label, self.ph_moving_affine)
    self.input_FIX_image, self.input_FIX_label = util.augment_3Ddata_by_affine(
        self.ph_FIX_image, self.ph_FIX_label, self.ph_fixed_affine)
    self.input_layer = tf.concat([
        layer.resize_volume(self.input_MV_image, self.image_size),
        self.input_FIX_image,
    ], axis=4)

    # Both affine heads start from the identity transform.
    self.transform_initial = [
        1., 0., 0., 0.,
        0., 1., 0., 0.,
        0., 0., 1., 0.,
    ]
    nc = [int(self.args.num_channel_initial * (2 ** i)) for i in range(5)]
    # Skip outputs are discarded: this affine network has no decoder.
    h0, _ = layer.downsample_resnet_block(self.is_train, self.input_layer, 2,
                                          nc[0], k_conv0=[7, 7, 7],
                                          name='global_down_0')
    h1, _ = layer.downsample_resnet_block(self.is_train, h0, nc[0], nc[1],
                                          name='global_down_1')
    h2, _ = layer.downsample_resnet_block(self.is_train, h1, nc[1], nc[2],
                                          name='global_down_2')
    h3, _ = layer.downsample_resnet_block(self.is_train, h2, nc[2], nc[3],
                                          name='global_down_3')
    h4 = layer.conv3_block(self.is_train, h3, nc[3], nc[4],
                           name='global_deep_4')
    self.theta_fw = layer.fully_connected(h4, 12, self.transform_initial,
                                          name='global_project_0')
    self.theta_bw = layer.fully_connected(h4, 12, self.transform_initial,
                                          name='global_project_1')

    # A displacement field is the warped sampling grid minus the reference
    # grid (equivalently, warped grid = reference grid + DDF).
    self.grid_warped_fw = util.warp_grid(self.grid_ref, self.theta_fw)
    self.ddf_fw = self.grid_warped_fw - self.grid_ref
    self.grid_warped_bw = util.warp_grid(self.grid_ref, self.theta_bw)
    self.ddf_bw = self.grid_warped_bw - self.grid_ref

    # Warp each image with its forward transform ...
    self.warped_MV_image = self.warp_image(self.input_MV_image,
                                           self.grid_warped_fw)
    self.warped_FIX_image = self.warp_image(self.input_FIX_image,
                                            self.grid_warped_bw)
    # ... then back with the opposite transform, for the consistency term.
    self.restore_MV_image = self.warp_image(self.warped_MV_image,
                                            self.grid_warped_bw)
    self.restore_FIX_image = self.warp_image(self.warped_FIX_image,
                                             self.grid_warped_fw)
    # NOTE: warping tends to leave blank regions around the volume borders.
    # The restored images could be multiplied by the restored labels here
    # before comparison.
    self.ddf_regularisation1 = self.args.lambda_consis * restore_loss2(
        self.input_FIX_image, self.restore_FIX_image)
    self.ddf_regularisation2 = self.args.lambda_consis * restore_loss2(
        self.input_MV_image, self.restore_MV_image)
    self.ddf_regularisation = (self.ddf_regularisation1
                               + self.ddf_regularisation2)

    # Label alignment loss in both directions.
    self.warped_MV_label = self.warp_image(self.input_MV_label,
                                           self.grid_warped_fw)
    self.warped_FIX_label = self.warp_image(self.input_FIX_label,
                                            self.grid_warped_bw)
    self.grad_loss_fw = tf.reduce_mean(loss.multi_scale_loss(
        self.input_FIX_label, self.warped_MV_label, 'dice', [0, 1, 2, 4, 8]))
    self.grad_loss_bw = tf.reduce_mean(loss.multi_scale_loss(
        self.input_MV_label, self.warped_FIX_label, 'dice', [0, 1, 2, 4, 8]))
    self.grad_loss = self.grad_loss_fw + self.grad_loss_bw

    # Use the decayed learning rate. Passing global_step makes minimize()
    # increment it, which is what advances the exponential decay schedule.
    self.train_op = tf.train.AdamOptimizer(self.learning_rate).minimize(
        self.grad_loss + self.ddf_regularisation,
        global_step=self.global_step)
def build_network(self):
    """Build the dense (non-rigid) moving-to-fixed registration graph.

    An attention-gated encoder-decoder sees the concatenated moving and
    fixed images. It predicts a displacement summand at every level in
    ``self.ddf_levels``, and their sum is ``self.ddf_MV_FIX``. Training
    minimises a multi-scale Dice loss between the fixed label and the warped
    moving label, plus a bending-energy penalty on the DDF weighted by
    ``args.lambda_ben``. Adam uses the decayed ``self.learning_rate`` and
    advances ``self.global_step``.
    """
    self.global_step = tf.Variable(0, trainable=False)
    self.learning_rate = tf.train.exponential_decay(
        self.args.lr, self.global_step, self.args.decay_freq, 0.96,
        staircase=True)
    self.grid_ref = util.get_reference_grid(self.image_size)
    # Start both warped grids at zero -- safer when debugging.
    self.grid_warped_MV_FIX = tf.zeros_like(self.grid_ref)
    self.grid_warped_FIX_MV = tf.zeros_like(self.grid_ref)

    batch_shape = [self.args.batch_size]
    volume_shape = batch_shape + self.image_size
    self.ph_MV_image = tf.placeholder(tf.float32, volume_shape + [1])
    self.ph_FIX_image = tf.placeholder(tf.float32, volume_shape + [1])
    # Augmentation affine: a 4x4 matrix whose last row is always
    # (0, 0, 0, 1), so only its 12 free entries are fed.
    self.ph_moving_affine = tf.placeholder(tf.float32, batch_shape + [1, 12])
    self.ph_fixed_affine = tf.placeholder(tf.float32, batch_shape + [1, 12])
    self.ph_random_ddf = tf.placeholder(tf.float32, volume_shape + [3])
    self.ph_MV_label = tf.placeholder(tf.float32, volume_shape + [1])
    self.ph_FIX_label = tf.placeholder(tf.float32, volume_shape + [1])

    self.input_MV_image, self.input_MV_label = util.augment_3Ddata_by_affine(
        self.ph_MV_image, self.ph_MV_label, self.ph_moving_affine)
    self.input_FIX_image, self.input_FIX_label = util.augment_3Ddata_by_affine(
        self.ph_FIX_image, self.ph_FIX_label, self.ph_fixed_affine)
    resized_moving = layer.resize_volume(self.input_MV_image, self.image_size)
    self.input_layer = tf.concat([resized_moving, self.input_FIX_image],
                                 axis=4)

    self.lambda_bend = self.args.lambda_ben
    self.lambda_consis = self.args.lambda_consis
    self.ddf_levels = [0, 1, 2, 3, 4]
    self.num_channel_initial = self.args.num_channel_initial
    # Channel width doubles per level (e.g. 32, 64, 128, 256, 512).
    channels = [int(self.num_channel_initial * (2 ** level))
                for level in range(5)]

    # Encoder: keep every stage's skip output for the decoder.
    skips = []
    features, skip = layer.downsample_resnet_block(
        self.is_train, self.input_layer, 2, channels[0], k_conv0=[7, 7, 7],
        name='local_down_0')
    skips.append(skip)
    for level in range(1, 4):
        features, skip = layer.downsample_resnet_block(
            self.is_train, features, channels[level - 1], channels[level],
            name='local_down_%d' % level)
        skips.append(skip)

    # Decoder: decoder[k] is the feature map for level 4 - k.
    # NOTE(review): the original author asked which conv block of the
    # paper's Fig. 4 'local_deep_4' corresponds to -- still open.
    decoder = [layer.conv3_block(self.is_train, features, channels[3],
                                 channels[4], name='local_deep_4')]
    gates = []
    for level in (3, 2, 1, 0):
        upsampled, gate = layer.att_upsample_resnet_block(
            self.is_train, decoder[-1], skips[level], channels[level + 1],
            channels[level], name='local_up_%d' % level)
        decoder.append(upsampled)
        gates.append(gate)
    self.gated1, self.gated2, self.gated3, self.gated4 = gates

    summands = [
        layer.ddf_summand(decoder[4 - idx], channels[idx], self.image_size,
                          name='ddf1_sum_%d' % idx)
        for idx in self.ddf_levels
    ]
    self.ddf_MV_FIX = tf.reduce_sum(tf.stack(summands, axis=5), axis=5)
    self.grid_warped_MV_FIX = self.grid_ref + self.ddf_MV_FIX

    # Losses: label overlap after warping, plus bending-energy smoothness.
    self.warped_MV_image = self.warp_MV_image(self.input_MV_image)
    self.warped_MV_label = self.warp_MV_image(self.input_MV_label)
    self.loss_warp_mv_fix = tf.reduce_mean(loss.multi_scale_loss(
        self.input_FIX_label, self.warped_MV_label, 'dice', [0, 1, 2, 4]))
    self.ddf_regu_MV = self.args.lambda_ben * tf.reduce_mean(
        loss.local_displacement_energy(self.ddf_MV_FIX, 'bending', 1))
    optimizer = tf.train.AdamOptimizer(self.learning_rate)
    self.train_op = optimizer.minimize(
        self.loss_warp_mv_fix + self.ddf_regu_MV,
        global_step=self.global_step)
    self.logger.debug("build network finish")