Exemple #1
0
 def __init__(self, minibatch_size, image_moving, image_fixed):
     """Keep the image pair and build the concatenated network input.

     Args:
         minibatch_size: stored unchanged as ``self.minibatch_size``.
         image_moving: moving-image tensor; it is resized to
             ``self.image_size`` before being concatenated.
         image_fixed: fixed-image tensor; dimensions 1..3 of its static
             shape become ``self.image_size``.
     """
     self.minibatch_size = minibatch_size

     # Dimensions 1..3 of the fixed image's static shape.
     fixed_static_shape = image_fixed.shape.as_list()
     self.image_size = fixed_static_shape[1:4]

     self.grid_ref = util.get_reference_grid(self.image_size)
     # Zero tensor shaped like the reference grid; zeros are a safer
     # starting value when debugging.
     self.grid_warped = tf.zeros_like(self.grid_ref)

     self.image_moving = image_moving
     self.image_fixed = image_fixed

     # Resize the moving image to the fixed image's size, then join the
     # two volumes along axis 4.
     moving_resized = layer.resize_volume(image_moving, self.image_size)
     self.input_layer = tf.concat([moving_resized, image_fixed], axis=4)
Exemple #2
0
def warp_volumes_by_ddf(input_, ddf):
    """Resample ``input_`` on the reference grid displaced by ``ddf``.

    The sampling grid is the reference grid for ``ddf.shape[1:4]`` plus
    ``ddf``. ``input_`` is converted to a float32 tensor, resampled with
    ``util.resample_linear``, and the result is evaluated in a temporary
    ``tf.Session``.

    Returns:
        Whatever ``Session.run`` yields for the resampled tensor.
    """
    reference_grid = util.get_reference_grid(ddf.shape[1:4])
    sample_grid = reference_grid + ddf
    volume = tf.convert_to_tensor(input_, dtype=tf.float32)
    resampled = util.resample_linear(volume, sample_grid)
    # NOTE(review): every call adds new ops to the current default graph and
    # opens a fresh session -- worth confirming this is only used offline.
    with tf.Session() as session:
        result = session.run(resampled)
    return result
Exemple #3
0
    def build_network(self):
        """Build the bidirectional global registration graph.

        Steps, in order:

        1. Create the step counter and a staircase exponentially decayed
           learning rate.
        2. Create placeholders for the moving/fixed images and labels, the
           two 12-parameter augmentation transforms and a random DDF.
        3. Augment both image/label pairs with ``util.augment_3Ddata_by_affine``.
        4. Run a four-stage downsampling encoder plus one deep conv block on
           the concatenated image pair and regress two 12-parameter
           transforms: ``theta_fw`` and ``theta_bw``.
        5. Warp images and labels with both transforms, warp them back with
           the opposite transform, and build the label loss and the
           restore (consistency) loss.
        6. Create ``train_op``.

        Reads ``self.args`` (``lr``, ``decay_freq``, ``batch_size``,
        ``num_channel_initial``, ``lambda_consis``), ``self.image_size``
        (must be a list) and ``self.is_train``. All results are stored as
        attributes on ``self``; nothing is returned.

        The optimizer uses ``self.learning_rate`` and increments
        ``self.global_step``, so the decay schedule built in step 1 actually
        takes effect (previously the raw ``args.lr`` was used and the step
        counter was never advanced).
        """
        self.global_step = tf.Variable(0, trainable=False)
        self.learning_rate = tf.train.exponential_decay(self.args.lr,
                                                        self.global_step,
                                                        self.args.decay_freq,
                                                        0.96,
                                                        staircase=True)

        self.grid_ref = util.get_reference_grid(self.image_size)
        self.grid_warped_MV_FIX = tf.zeros_like(
            self.grid_ref)  # initial zeros are safer for debug
        self.grid_warped_FIX_MV = tf.zeros_like(
            self.grid_ref)  # initial zeros are safer for debug

        self.ph_MV_image = tf.placeholder(tf.float32, [self.args.batch_size] +
                                          self.image_size + [1])
        self.ph_FIX_image = tf.placeholder(tf.float32, [self.args.batch_size] +
                                           self.image_size + [1])
        # Augmentation transform: a 4x4 matrix whose last row is fixed to
        # 0 0 0 1, so only 12 parameters are fed.
        self.ph_moving_affine = tf.placeholder(
            tf.float32, [self.args.batch_size] + [1, 12])
        self.ph_fixed_affine = tf.placeholder(tf.float32,
                                              [self.args.batch_size] + [1, 12])
        # Not consumed anywhere in this method.
        self.ph_random_ddf = tf.placeholder(
            tf.float32, [self.args.batch_size] + self.image_size + [3])

        self.ph_MV_label = tf.placeholder(tf.float32, [self.args.batch_size] +
                                          self.image_size + [1])
        self.ph_FIX_label = tf.placeholder(tf.float32, [self.args.batch_size] +
                                           self.image_size + [1])

        self.input_MV_image, self.input_MV_label = util.augment_3Ddata_by_affine(
            self.ph_MV_image, self.ph_MV_label, self.ph_moving_affine)
        self.input_FIX_image, self.input_FIX_label = util.augment_3Ddata_by_affine(
            self.ph_FIX_image, self.ph_FIX_label, self.ph_fixed_affine)

        # Network input: moving image (resized to image_size) and fixed image
        # joined along axis 4.
        self.input_layer = tf.concat([
            layer.resize_volume(self.input_MV_image, self.image_size),
            self.input_FIX_image
        ],
                                     axis=4)
        # Initial value handed to both projection layers; presumably the
        # row-major 3x4 identity transform -- confirm in layer.fully_connected.
        self.transform_initial = [
            1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0.
        ]

        # Channel widths double at each of the five levels.
        nc = [int(self.args.num_channel_initial * (2**i)) for i in range(5)]
        # The second output of each down block is not used by this network.
        h0, _ = layer.downsample_resnet_block(self.is_train,
                                              self.input_layer,
                                              2,
                                              nc[0],
                                              k_conv0=[7, 7, 7],
                                              name='global_down_0')
        h1, _ = layer.downsample_resnet_block(self.is_train,
                                              h0,
                                              nc[0],
                                              nc[1],
                                              name='global_down_1')
        h2, _ = layer.downsample_resnet_block(self.is_train,
                                              h1,
                                              nc[1],
                                              nc[2],
                                              name='global_down_2')
        h3, _ = layer.downsample_resnet_block(self.is_train,
                                              h2,
                                              nc[2],
                                              nc[3],
                                              name='global_down_3')
        h4 = layer.conv3_block(self.is_train,
                               h3,
                               nc[3],
                               nc[4],
                               name='global_deep_4')
        # Two separate 12-parameter heads on the same deep features.
        self.theta_fw = layer.fully_connected(h4,
                                              12,
                                              self.transform_initial,
                                              name='global_project_0')
        self.theta_bw = layer.fully_connected(h4,
                                              12,
                                              self.transform_initial,
                                              name='global_project_1')

        # The displacement field is the warped grid minus the reference grid.
        self.grid_warped_fw = util.warp_grid(self.grid_ref, self.theta_fw)
        self.ddf_fw = self.grid_warped_fw - self.grid_ref

        self.grid_warped_bw = util.warp_grid(self.grid_ref, self.theta_bw)
        self.ddf_bw = self.grid_warped_bw - self.grid_ref

        # Warp the moving image with the forward grid and the fixed image
        # with the backward grid.
        self.warped_MV_image = self.warp_image(self.input_MV_image,
                                               self.grid_warped_fw)
        self.warped_FIX_image = self.warp_image(self.input_FIX_image,
                                                self.grid_warped_bw)

        # Warp each result back with the opposite grid.
        self.restore_MV_image = self.warp_image(self.warped_MV_image,
                                                self.grid_warped_bw)
        self.restore_FIX_image = self.warp_image(self.warped_FIX_image,
                                                 self.grid_warped_fw)

        # Restore loss between each input image and its round-trip version.
        # NOTE: the restored image could additionally be masked by the
        # restored label, because warping tends to leave blank regions around
        # the image border.
        self.ddf_regularisation1 = self.args.lambda_consis * restore_loss2(
            self.input_FIX_image, self.restore_FIX_image)
        self.ddf_regularisation2 = self.args.lambda_consis * restore_loss2(
            self.input_MV_image, self.restore_MV_image)
        self.ddf_regularisation = self.ddf_regularisation1 + self.ddf_regularisation2

        # Warp the labels with the same grids as the images.
        self.warped_MV_label = self.warp_image(self.input_MV_label,
                                               self.grid_warped_fw)
        self.warped_FIX_label = self.warp_image(self.input_FIX_label,
                                                self.grid_warped_bw)
        # Multi-scale 'dice' label loss in both directions.
        self.grad_loss_fw = tf.reduce_mean(
            loss.multi_scale_loss(self.input_FIX_label, self.warped_MV_label,
                                  'dice', [0, 1, 2, 4, 8]))
        self.grad_loss_bw = tf.reduce_mean(
            loss.multi_scale_loss(self.input_MV_label, self.warped_FIX_label,
                                  'dice', [0, 1, 2, 4, 8]))
        self.grad_loss = self.grad_loss_fw + self.grad_loss_bw

        # Use the decayed learning rate and advance global_step on every
        # update; without global_step the schedule above would stay at its
        # initial value forever.
        self.train_op = tf.train.AdamOptimizer(self.learning_rate).minimize(
            self.grad_loss + self.ddf_regularisation,
            global_step=self.global_step)
    def build_network(self):
        """Assemble the local registration graph that predicts a DDF.

        Creates the step counter and decayed learning rate, the input
        placeholders, the affine-augmented inputs, a four-stage downsampling
        encoder with a matching attention-gated upsampling path, a summed
        multi-level displacement field (``self.ddf_MV_FIX``), the warped
        moving image/label, the multi-scale 'dice' label loss, the bending
        regulariser and ``self.train_op``.

        Reads ``self.args`` (``lr``, ``decay_freq``, ``batch_size``,
        ``lambda_ben``, ``lambda_consis``, ``num_channel_initial``),
        ``self.image_size`` (must be a list), ``self.is_train`` and
        ``self.logger``. Everything is stored on ``self``; nothing is
        returned.
        """
        args = self.args

        # Step counter plus staircase exponential decay of the base rate.
        self.global_step = tf.Variable(0, trainable=False)
        self.learning_rate = tf.train.exponential_decay(
            args.lr, self.global_step, args.decay_freq, 0.96, staircase=True)

        self.grid_ref = util.get_reference_grid(self.image_size)
        # Zero-initialised grids are safer while debugging.
        self.grid_warped_MV_FIX = tf.zeros_like(self.grid_ref)
        self.grid_warped_FIX_MV = tf.zeros_like(self.grid_ref)

        # Shared leading part of every volumetric placeholder shape.
        volume_shape = [args.batch_size] + self.image_size
        # Augmentation transform: a 4x4 matrix whose last row is fixed to
        # 0 0 0 1, so only 12 parameters are fed.
        affine_shape = [args.batch_size, 1, 12]

        self.ph_MV_image = tf.placeholder(tf.float32, volume_shape + [1])
        self.ph_FIX_image = tf.placeholder(tf.float32, volume_shape + [1])
        self.ph_moving_affine = tf.placeholder(tf.float32, affine_shape)
        self.ph_fixed_affine = tf.placeholder(tf.float32, affine_shape)
        self.ph_random_ddf = tf.placeholder(tf.float32, volume_shape + [3])
        self.ph_MV_label = tf.placeholder(tf.float32, volume_shape + [1])
        self.ph_FIX_label = tf.placeholder(tf.float32, volume_shape + [1])

        augmented_moving = util.augment_3Ddata_by_affine(
            self.ph_MV_image, self.ph_MV_label, self.ph_moving_affine)
        self.input_MV_image, self.input_MV_label = augmented_moving
        augmented_fixed = util.augment_3Ddata_by_affine(
            self.ph_FIX_image, self.ph_FIX_label, self.ph_fixed_affine)
        self.input_FIX_image, self.input_FIX_label = augmented_fixed

        # Network input: resized moving image and fixed image joined along
        # axis 4.
        moving_resized = layer.resize_volume(self.input_MV_image,
                                             self.image_size)
        self.input_layer = tf.concat([moving_resized, self.input_FIX_image],
                                     axis=4)

        self.lambda_bend = args.lambda_ben
        self.lambda_consis = args.lambda_consis
        self.ddf_levels = [0, 1, 2, 3, 4]
        self.num_channel_initial = args.num_channel_initial

        # Channel widths double per level, e.g. 32, 64, 128, 256, 512 when
        # num_channel_initial is 32.
        channels = [
            int(self.num_channel_initial * (2**level)) for level in range(5)
        ]

        # Encoder: each stage returns its output and a skip tensor that the
        # matching upsampling stage consumes.
        down0, skip0 = layer.downsample_resnet_block(
            self.is_train, self.input_layer, 2, channels[0],
            k_conv0=[7, 7, 7], name='local_down_0')
        down1, skip1 = layer.downsample_resnet_block(
            self.is_train, down0, channels[0], channels[1],
            name='local_down_1')
        down2, skip2 = layer.downsample_resnet_block(
            self.is_train, down1, channels[1], channels[2],
            name='local_down_2')
        down3, skip3 = layer.downsample_resnet_block(
            self.is_train, down2, channels[2], channels[3],
            name='local_down_3')

        # Open question from the original author: which convolution block of
        # Fig. 4 in the paper does this one correspond to?
        deepest = layer.conv3_block(self.is_train, down3, channels[3],
                                    channels[4], name='local_deep_4')

        # features[k] holds the output k stages above the deepest block.
        # Every upsampling stage is built unconditionally, regardless of
        # self.ddf_levels.
        features = [deepest]

        up3, self.gated1 = layer.att_upsample_resnet_block(
            self.is_train, features[0], skip3, channels[4], channels[3],
            name='local_up_3')
        features.append(up3)

        up2, self.gated2 = layer.att_upsample_resnet_block(
            self.is_train, features[1], skip2, channels[3], channels[2],
            name='local_up_2')
        features.append(up2)

        up1, self.gated3 = layer.att_upsample_resnet_block(
            self.is_train, features[2], skip1, channels[2], channels[1],
            name='local_up_1')
        features.append(up1)

        up0, self.gated4 = layer.att_upsample_resnet_block(
            self.is_train, features[3], skip0, channels[1], channels[0],
            name='local_up_0')
        features.append(up0)

        # One DDF contribution per level; level idx reads features[4 - idx],
        # which has channels[idx] channels.
        ddf_summands = []
        for idx in self.ddf_levels:
            summand = layer.ddf_summand(features[4 - idx],
                                        channels[idx],
                                        self.image_size,
                                        name='ddf1_sum_%d' % idx)
            ddf_summands.append(summand)
        stacked_ddf = tf.stack(ddf_summands, axis=5)
        self.ddf_MV_FIX = tf.reduce_sum(stacked_ddf, axis=5)

        # Sampling grid = reference grid + predicted displacement.
        self.grid_warped_MV_FIX = self.grid_ref + self.ddf_MV_FIX

        # Warp the moving image and the moving label via self.warp_MV_image.
        self.warped_MV_image = self.warp_MV_image(self.input_MV_image)
        self.warped_MV_label = self.warp_MV_image(self.input_MV_label)

        # Multi-scale 'dice' loss between the fixed and warped moving labels.
        label_loss_map = loss.multi_scale_loss(self.input_FIX_label,
                                               self.warped_MV_label,
                                               'dice', [0, 1, 2, 4])
        self.loss_warp_mv_fix = tf.reduce_mean(label_loss_map)

        # Bending regulariser on the predicted DDF, weighted by lambda_ben.
        bending_energy = loss.local_displacement_energy(
            self.ddf_MV_FIX, 'bending', 1)
        self.ddf_regu_MV = args.lambda_ben * tf.reduce_mean(bending_energy)

        # Adam on the decayed rate; minimize() also advances global_step.
        optimizer = tf.train.AdamOptimizer(self.learning_rate)
        objective = self.loss_warp_mv_fix + self.ddf_regu_MV
        self.train_op = optimizer.minimize(objective,
                                           global_step=self.global_step)
        self.logger.debug("build network finish")