def res3_block(self, x, ct):
    """Create the ``res3`` stage layers on ``self`` and apply them to ``x``.

    Structure (every convolution: kernel 3x3x3, 128 filters, stride 1,
    'same' padding)::

        shortcut = res3a_2(x)
        residual = res3b_2(relu(res3b_1_bn(res3b_1(relu(res3a_bn(shortcut))))))
        out      = relu(res3b_bn(shortcut + residual))

    The layers are constructed on every call and stored as attributes of
    ``self``, replacing any layers stored by a previous call.

    Args:
        x: Input tensor. ``BatchNormalization(axis=1)`` normalizes axis 1,
            which presumes ``self.DATA_FORMAT == 'channels_first'`` --
            TODO confirm against the class configuration.
        ct: Forwarded as ``trainable`` to every layer of the stage.

    Returns:
        The ReLU-activated output tensor of the stage.
    """
    self.res3a_2 = Conv3D(kernel_size=3, filters=128, strides=1, padding='same',
                          data_format=self.DATA_FORMAT, trainable=ct, name='res3a_2')
    self.res3a_bn = BatchNormalization(axis=1, trainable=ct, name='res3a_bn')
    self.res3b_1 = Conv3D(kernel_size=3, filters=128, strides=1, padding='same',
                          data_format=self.DATA_FORMAT, trainable=ct, name='res3b_1')
    self.res3b_1_bn = BatchNormalization(axis=1, trainable=ct, name='res3b_1_bn')
    self.res3b_2 = Conv3D(kernel_size=3, filters=128, strides=1, padding='same',
                          data_format=self.DATA_FORMAT, trainable=ct, name='res3b_2')
    self.res3b_bn = BatchNormalization(axis=1, trainable=ct, name='res3b_bn')

    # The shortcut and the residual branch both start from res3a_2(x);
    # apply the layer once and reuse the tensor instead of computing the
    # identical convolution twice.
    shortcut = self.res3a_2(x)

    residual = self.res3a_bn(shortcut)
    residual = tf.nn.relu(residual, name='res3a_relu')
    residual = self.res3b_1(residual)
    residual = self.res3b_1_bn(residual)
    residual = tf.nn.relu(residual, name='res3b_1_relu')
    residual = self.res3b_2(residual)

    # Post-activation residual sum: BN and ReLU are applied after the add.
    x = shortcut + residual
    x = self.res3b_bn(x)
    return tf.nn.relu(x, name='res3b')
# Example #2
# 0
def SAAM(x, up_scale, N_svd):
    """Spatial-angular attention module (SAAM) operating on EPI slices.

    ``x`` is unpacked as ``(batch, ang, hei, wid, chn)``. For every
    ``(batch, hei)`` row, the ``ang * wid`` positions attend to each other:

    * ``attn_epi_1`` / ``attn_epi_2`` are bias-free 1x1x1 convolutions to
      ``chn // 8`` channels (query / key). ``attn_epi_3`` projects to
      ``(chn // 8) * up_scale`` channels (value).
    * The attention map is ``softmax(Q @ K^T)`` over the last axis. When
      ``N_svd > 0``, ``Q @ K^T`` is instead evaluated from the rank-``N_svd``
      truncated SVDs of Q and K.
    * The attended values, scaled by the ``attn_sigma`` variable returned by
      ``beta_variable``, are added to a 1x1x1 projection of ``x``
      (``attn_epi_4``).
    * The remaining steps are ``pixel_shuffle`` by ``up_scale``, then
      ``Cropping3D`` of the trailing ``up_scale - 1`` slices of the first
      spatial axis. Finally, two ReLU convolutions (``epi_5``, ``epi_6``)
      produce ``chn`` channels.

    Args:
        x: Rank-5 tensor with a fully static shape.
        up_scale: Channel multiplier of the value projection and the factor
            passed to ``pixel_shuffle``.
        N_svd: Number of singular components kept; ``<= 0`` disables the
            SVD path. NOTE(review): values larger than
            ``min(ang * wid, chn // 8)`` presumably make ``tf.slice`` fail.

    Returns:
        Tensor with ``chn`` channels. Its axis-1 size depends on
        ``pixel_shuffle`` (defined elsewhere) -- presumably
        ``ang * up_scale`` before the crop; confirm.
    """
    _, ang, hei, wid, chn = x.get_shape().as_list()
    qk_chn = chn // 8
    v_chn = qk_chn * up_scale
    use_svd = N_svd > 0

    def _epi_rows(features, channels):
        # (batch, ang, hei, wid, c) -> (batch, hei, ang, wid, c)
        #                           -> (batch * hei, ang * wid, c)
        swapped = tf.transpose(features, [0, 2, 1, 3, 4])
        return tf.reshape(swapped, [-1, ang * wid, channels])

    def _truncated_svd(mat):
        # tf.svd returns (s, u, v); keep the leading N_svd components.
        s, u, v = tf.svd(mat)
        s = tf.slice(s, [0, 0], [-1, N_svd])
        u = tf.slice(u, [0, 0, 0], [-1, -1, N_svd])
        v = tf.slice(v, [0, 0, 0], [-1, -1, N_svd])
        return s, u, v

    # \phi'_q
    query = _epi_rows(Conv3D(qk_chn, (1, 1, 1), use_bias=False, padding='SAME',
                             name='attn_epi_1')(x), qk_chn)
    if use_svd:
        s_q, u_q, v_q = _truncated_svd(query)

    # \phi'_k
    key = _epi_rows(Conv3D(qk_chn, (1, 1, 1), use_bias=False, padding='SAME',
                           name='attn_epi_2')(x), qk_chn)
    if use_svd:
        s_k, u_k, v_k = _truncated_svd(key)

    # \phi'_v
    value = _epi_rows(Conv3D(v_chn, (1, 1, 1), use_bias=False, padding='SAME',
                             name='attn_epi_3')(x), v_chn)

    # Map: query @ key^T, or its low-rank form
    #   u_q diag(s_q) v_q^T v_k diag(s_k) u_k^T
    if use_svd:
        diag_q = tf.matrix_diag(s_q)
        core = tf.matmul(diag_q, tf.matmul(v_q, v_k, transpose_a=True))
        core = tf.matmul(core, tf.matrix_diag(s_k), transpose_b=True)
        scores = tf.matmul(tf.matmul(u_q, core), u_k, transpose_b=True)
    else:
        scores = tf.matmul(query, key, transpose_b=True)
    attn_EPI = tf.nn.softmax(scores)

    # \phi'_a
    attended = tf.matmul(attn_EPI, value)

    # \phi'_b: undo the row flattening -> (batch, ang, hei, wid, v_chn)
    attended = tf.reshape(attended, [-1, hei, ang, wid, v_chn])
    attended = tf.transpose(attended, [0, 2, 1, 3, 4])

    skip = Conv3D(v_chn, (1, 1, 1), use_bias=False, padding='SAME', name='attn_epi_4')(x)
    sigma = beta_variable([1], name='attn_sigma')
    fused = skip + sigma * attended

    # \phi_c
    fused = pixel_shuffle(fused, up_scale)
    fused = Cropping3D(cropping=((0, up_scale - 1), (0, 0), (0, 0)))(fused)
    fused = Conv3D(filters=chn, kernel_size=(1, 1, 7), activation='relu',
                   padding='SAME', name='epi_5')(fused)
    return Conv3D(filters=chn, kernel_size=(7, 1, 1), activation='relu',
                  padding='SAME', name='epi_6')(fused)
def conv_block_simple_3d_no_bn(prevlayer,
                               num_filters,
                               prefix,
                               kernel_size=(3, 3, 3),
                               initializer="glorot_normal",
                               strides=(1, 1, 1)):
    """Apply a channels-first, 'same'-padded Conv3D followed by ReLU (no batch norm).

    Args:
        prevlayer: Input tensor fed to the convolution.
        num_filters: Number of output channels of the convolution.
        prefix: Name prefix; the ops are named ``<prefix>_conv`` and
            ``<prefix>_activation``.
        kernel_size: Convolution kernel size.
        initializer: Kernel initializer passed to ``Conv3D``.
        strides: Convolution strides.

    Returns:
        The ReLU-activated convolution output.
    """
    conv_layer = Conv3D(filters=num_filters,
                        kernel_size=kernel_size,
                        strides=strides,
                        padding="same",
                        data_format='channels_first',
                        kernel_initializer=initializer,
                        name=prefix + "_conv")
    return tf.nn.relu(conv_layer(prevlayer), name=prefix + "_activation")
def unet_7_layers_3D(input_tensor):
    """Build a channels-first 3D U-Net on ``input_tensor``; return the prediction tensor.

    Layout (``filt = (32, 64, 128, 256)``; every layer is channels-first):

    * Encoder: three stages. Each stage runs two ``conv_block_simple_3d``
      calls with kernel ``(2, 3, 3)``, then ``MaxPooling3D`` with pool size
      and stride ``(1, 2, 2)``, so the first spatial axis is never pooled.
    * Bottleneck: three ``conv_block_simple_3d`` calls with 256 filters.
    * Decoder: three stages. Each stage applies ``Conv3DTranspose`` with
      stride ``(1, 2, 2)``, concatenates the matching encoder output on
      axis 1 (channels), then runs two ``conv_block_simple_3d`` calls.
    * Head: a 1x1x1 ``Conv3D`` named ``"prediction"`` with one filter and a
      sigmoid activation.

    Args:
        input_tensor: Rank-5 channels-first tensor. NOTE(review): pooling
            with ``padding="same"`` rounds odd sizes up, so the last two
            spatial sizes presumably need to be divisible by 8 for the skip
            concatenations to line up -- confirm.

    Returns:
        The sigmoid prediction tensor (one channel).
    """
    d_format = "channels_first"
    pad = "same"
    mp_param = (1, 2, 2)      # pool only the last two spatial axes
    stride_param = (1, 2, 2)
    us_param = (1, 2, 2)      # decoder upsampling factor, mirrors pooling
    kern = (2, 3, 3)
    filt = (32, 64, 128, 256)

    def _encoder_stage(prev, num_filters, prefixes, pool_name):
        # Conv blocks named by ``prefixes``, then max-pool.
        # Returns (pre-pool features for the skip connection, pooled tensor).
        for prefix in prefixes:
            prev = conv_block_simple_3d(prevlayer=prev,
                                        num_filters=num_filters,
                                        prefix=prefix,
                                        kernel_size=kern)
        pooled = MaxPooling3D(pool_size=mp_param,
                              strides=stride_param,
                              padding=pad,
                              data_format=d_format,
                              name=pool_name)(prev)
        return prev, pooled

    def _decoder_stage(prev, skip, num_filters, prefixes):
        # Upsample with a transposed conv, concat the skip on the channel
        # axis, then run the conv blocks named by ``prefixes``.
        up = Conv3DTranspose(filters=num_filters,
                             kernel_size=kern,
                             strides=us_param,
                             padding=pad,
                             data_format=d_format)(prev)
        merged = concatenate([up, skip], axis=1)
        # NOTE(review): unlike the encoder, these blocks do not pass
        # kernel_size=kern and therefore use conv_block_simple_3d's default
        # kernel size. Kept as-is to preserve the architecture and any saved
        # weights; confirm whether this is intentional.
        for prefix in prefixes:
            merged = conv_block_simple_3d(prevlayer=merged,
                                          num_filters=num_filters,
                                          prefix=prefix)
        return merged

    conv1, pool1 = _encoder_stage(input_tensor, filt[0], ("conv1", "conv1_1"), "pool1")
    conv2, pool2 = _encoder_stage(pool1, filt[1], ("conv2", "conv2_1"), "pool2")
    conv3, pool3 = _encoder_stage(pool2, filt[2], ("conv3", "conv3_1"), "pool3")

    # Bottleneck.
    conv4 = pool3
    for prefix in ("conv_4", "conv_4_1", "conv_4_2"):
        conv4 = conv_block_simple_3d(prevlayer=conv4,
                                     num_filters=filt[3],
                                     prefix=prefix,
                                     kernel_size=kern)

    conv5 = _decoder_stage(conv4, conv3, filt[2], ("conv5_1", "conv5_2"))
    conv6 = _decoder_stage(conv5, conv2, filt[1], ("conv6_1", "conv6_2"))
    conv7 = _decoder_stage(conv6, conv1, filt[0], ("conv7_1", "conv7_2"))

    prediction = Conv3D(filters=1,
                        kernel_size=(1, 1, 1),
                        activation="sigmoid",
                        name="prediction",
                        data_format=d_format)(conv7)
    return prediction
def unet_9_layers_3D(input_shape):
    """Build and return a Keras ``Model`` for a channels-first 3D U-Net.

    Layout (``filt = (32, 64, 128, 256, 512)``; every layer is channels-first):

    * Encoder: four stages. Each stage runs two ``conv_block_simple_3d``
      calls, then ``MaxPooling3D`` with pool size and stride ``(1, 2, 2)``,
      so the first spatial axis is never pooled.
    * Bottleneck: three ``conv_block_simple_3d`` calls with 512 filters.
    * Decoder: four stages. Each stage applies ``UpSampling3D`` with size
      ``(1, 2, 2)``, concatenates the matching encoder output on axis 1
      (channels), then runs two ``conv_block_simple_3d`` calls.
    * Head: a 1x1x1 ``Conv3D`` named ``"prediction"`` with one filter and a
      sigmoid activation.

    Args:
        input_shape: Shape passed to ``Input`` (without the batch axis).
            NOTE(review): pooling with ``padding="same"`` rounds odd sizes
            up, so the last two spatial sizes presumably need to be
            divisible by 16 for the skip concatenations to line up --
            confirm.

    Returns:
        A ``Model`` mapping the input to the sigmoid prediction.
    """
    img_input = Input(input_shape)

    d_format = "channels_first"
    pad = "same"
    mp_param = (1, 2, 2)      # pool only the last two spatial axes
    stride_param = (1, 2, 2)
    us_param = (1, 2, 2)      # decoder upsampling factor, mirrors pooling
    filt = (32, 64, 128, 256, 512)

    def _encoder_stage(prev, num_filters, prefixes, pool_name):
        # Conv blocks named by ``prefixes``, then max-pool.
        # Returns (pre-pool features for the skip connection, pooled tensor).
        for prefix in prefixes:
            prev = conv_block_simple_3d(prevlayer=prev,
                                        num_filters=num_filters,
                                        prefix=prefix)
        pooled = MaxPooling3D(pool_size=mp_param,
                              strides=stride_param,
                              padding=pad,
                              data_format=d_format,
                              name=pool_name)(prev)
        return prev, pooled

    def _decoder_stage(prev, skip, num_filters, prefixes):
        # Upsample, concat the skip on the channel axis, then run the conv
        # blocks named by ``prefixes``.
        up = UpSampling3D(size=us_param, data_format=d_format)(prev)
        merged = concatenate([up, skip], axis=1)
        for prefix in prefixes:
            merged = conv_block_simple_3d(prevlayer=merged,
                                          num_filters=num_filters,
                                          prefix=prefix)
        return merged

    conv1, pool1 = _encoder_stage(img_input, filt[0], ("conv1", "conv1_1"), "pool1")
    conv2, pool2 = _encoder_stage(pool1, filt[1], ("conv2", "conv2_1"), "pool2")
    conv3, pool3 = _encoder_stage(pool2, filt[2], ("conv3", "conv3_1"), "pool3")
    conv4, pool4 = _encoder_stage(pool3, filt[3], ("conv4", "conv4_1"), "pool4")

    # Bottleneck.
    conv5 = pool4
    for prefix in ("conv_5", "conv_5_1", "conv_5_2"):
        conv5 = conv_block_simple_3d(prevlayer=conv5,
                                     num_filters=filt[4],
                                     prefix=prefix)

    conv6 = _decoder_stage(conv5, conv4, filt[3], ("conv6_1", "conv6_2"))
    conv7 = _decoder_stage(conv6, conv3, filt[2], ("conv7_1", "conv7_2"))
    conv8 = _decoder_stage(conv7, conv2, filt[1], ("conv8_1", "conv8_2"))
    conv9 = _decoder_stage(conv8, conv1, filt[0], ("conv9_1", "conv9_2"))

    prediction = Conv3D(filters=1,
                        kernel_size=(1, 1, 1),
                        activation="sigmoid",
                        name="prediction",
                        data_format=d_format)(conv9)
    return Model(img_input, prediction)
# Example #6
# 0
def model(x, N_svd=0):
    """Build the ``'ASR'`` network on ``x`` and return its output tensor.

    Axis 1 of ``x`` is upsampled by ``up_scale`` (4) with transposed
    convolutions. The trailing ``up_scale - 1`` slices are then cropped, so
    an axis-1 size ``A`` becomes ``A * up_scale - (up_scale - 1)``
    (e.g. 6 -> 21). Axes 2 and 3 are downsampled in the encoder and
    upsampled back in the decoder. The output has as many channels as
    ``x``.

    Args:
        x: Rank-5 tensor with a static shape. ``x.get_shape()[4]`` is read
            as the channel count; the layers use no explicit data_format, so
            this presumes the Keras default channels-last layout -- confirm.
            The shape comments below use the example input
            ``[batch, 6, 24, 64, 1]``.
        N_svd: Forwarded to ``SAAM``. When positive, the attention map is
            built from rank-``N_svd`` truncated SVDs.

    Returns:
        Output of the final linear ``Conv3D`` (``chn_in`` filters, no
        activation).
    """
    up_scale = 4
    # Trailing axis-1 slices removed after each angular upsampling; derived
    # from up_scale, as in SAAM's Cropping3D, instead of a hard-coded 3.
    ang_crop = up_scale - 1

    def _conv(inputs, filters, kernel_size, name, **kwargs):
        # 'SAME'-padded Conv3D with ReLU unless ``activation`` is overridden.
        kwargs.setdefault('activation', 'relu')
        return Conv3D(filters=filters, kernel_size=kernel_size, padding='SAME',
                      name=name, **kwargs)(inputs)

    def _conv_pair(inputs, filters, name_a, name_b):
        # (3, 1, 3) ReLU conv followed by a (3, 3, 1) ReLU conv.
        out = _conv(inputs, filters, (3, 1, 3), name_a)
        return _conv(out, filters, (3, 3, 1), name_b)

    def _angular_upsample(inputs, deconv_filters, out_filters, deconv_name, conv_name):
        # Transposed conv with stride up_scale on axis 1, crop the trailing
        # ang_crop slices of axis 1, then a 1x1x1 ReLU conv.
        out = Conv3DTranspose(deconv_filters, (7, 1, 3), (up_scale, 1, 1),
                              'SAME',
                              activation='relu',
                              name=deconv_name)(inputs)
        out = Cropping3D(cropping=((0, ang_crop), (0, 0), (0, 0)))(out)
        return _conv(out, out_filters, (1, 1, 1), conv_name)

    with tf.variable_scope('ASR'):
        input_shape = x.get_shape().as_list()
        chn_in = input_shape[4]
        chn_base = 6 * up_scale
        # Example input: [batch, 6, 24, 64, 1]; upsampled axis 1 = 6 * 4 - 3 = 21.

        # Group 1
        h = _conv_pair(x, chn_base, 'conv1_1', 'conv1_2')
        # h: [batch, 6, 24, 64, chn_base]
        h1 = _angular_upsample(h, chn_base, chn_base, 'deconv1', 'conv1_3')
        # h1: [batch, 21, 24, 64, chn_base]
        h = _conv(h, chn_base * 2, (1, 3, 3), 'conv1_4', strides=(1, 2, 2))
        # h: [batch, 6, 12, 32, chn_base * 2]

        # Group 2
        h = _conv_pair(h, chn_base * 2, 'conv2_1', 'conv2_2')
        h2 = _angular_upsample(h, chn_base, chn_base * 2, 'deconv2', 'conv2_3')
        # h2: [batch, 21, 12, 32, chn_base * 2]
        h = _conv(h, chn_base * 4, (1, 1, 3), 'conv2_4', strides=(1, 1, 2))
        # h: [batch, 6, 12, 16, chn_base * 4]

        # Layer 3, shrinking
        h = _conv(h, chn_base * 2, (1, 1, 1), 'conv3')
        # h: [batch, 6, 12, 16, chn_base * 2]

        # Group 4, Mapping
        for i in range(2):
            h = _conv_pair(h, chn_base * 2, 'conv4_1_' + str(i), 'conv4_2_' + str(i))
        # h: [batch, 6, 12, 16, chn_base * 2]

        # Group 5, Attention
        h = SAAM(h, up_scale=up_scale, N_svd=N_svd)
        # h: presumably [batch, 21, 12, 16, chn_base * 2] (pixel_shuffle is
        # defined elsewhere); axis 1 must match h2 for the Group 7 concat.

        # Layer 6, Expanding
        h = _conv(h, chn_base * 4, (1, 1, 1), 'conv6')
        # h: [batch, 21, 12, 16, chn_base * 4]

        # Group 7
        h = Conv3DTranspose(chn_base * 2, (1, 1, 4), (1, 1, 2),
                            activation='relu',
                            padding='SAME',
                            name='conv7_1')(h)
        # h: [batch, 21, 12, 32, chn_base * 2]
        h = tf.concat([h, h2], axis=-1)
        h = _conv_pair(h, chn_base * 2, 'conv7_2', 'conv7_3')
        # h: [batch, 21, 12, 32, chn_base * 2]

        # Group 8
        h = Conv3DTranspose(chn_base, (1, 4, 4), (1, 2, 2),
                            activation='relu',
                            padding='SAME',
                            name='conv8_1')(h)
        # h: [batch, 21, 24, 64, chn_base]
        h = tf.concat([h, h1], axis=-1)
        h = _conv_pair(h, chn_base, 'conv8_2', 'conv8_3')
        # h: [batch, 21, 24, 64, chn_base]

        # Group 9: linear projection back to the input channel count.
        h = Conv3D(filters=chn_in,
                   kernel_size=(3, 3, 3),
                   padding='SAME',
                   name='conv9')(h)
        # h: [batch, 21, 24, 64, chn_in]
    return h