def res3_block(self, x, ct):
    """Build and apply the 'res3' residual block (three 3x3x3 convs, 128 filters).

    The six layers are created on every call and stored on ``self`` under the
    same attribute names as their Keras layer names.

    Computation::

        a   = res3a_2(x)                                  # shortcut
        b   = relu(res3a_bn(a))
        b   = relu(res3b_1_bn(res3b_1(b)))
        b   = res3b_2(b)
        out = relu(res3b_bn(a + b))

    Args:
        x: input tensor; its layout must match ``self.DATA_FORMAT``.
        ct: passed as ``trainable`` to every conv and batch-norm layer.

    Returns:
        Output tensor of the block (128 channels).
    """
    self.res3a_2 = Conv3D(kernel_size=3, filters=128, strides=1, padding='same',
                          data_format=self.DATA_FORMAT, trainable=ct, name='res3a_2')
    # NOTE(review): BN normalizes axis 1, which is the channel axis only if
    # self.DATA_FORMAT == 'channels_first' -- confirm against the class.
    self.res3a_bn = BatchNormalization(axis=1, trainable=ct, name='res3a_bn')
    self.res3b_1 = Conv3D(kernel_size=3, filters=128, strides=1, padding='same',
                          data_format=self.DATA_FORMAT, trainable=ct, name='res3b_1')
    self.res3b_1_bn = BatchNormalization(axis=1, trainable=ct, name='res3b_1_bn')
    self.res3b_2 = Conv3D(kernel_size=3, filters=128, strides=1, padding='same',
                          data_format=self.DATA_FORMAT, trainable=ct, name='res3b_2')
    self.res3b_bn = BatchNormalization(axis=1, trainable=ct, name='res3b_bn')

    # The res3a_2 output feeds both the shortcut and the residual branch.
    # Apply the (weight-shared) conv once instead of twice on the same input;
    # the two calls produced identical tensors.
    shortcut = self.res3a_2(x)

    branch = self.res3a_bn(shortcut)
    branch = tf.nn.relu(branch, name='res3a_relu')
    branch = self.res3b_1(branch)
    branch = self.res3b_1_bn(branch)
    branch = tf.nn.relu(branch, name='res3b_1_relu')
    branch = self.res3b_2(branch)

    # Post-addition BN + ReLU (normalization is applied after the residual sum).
    x = shortcut + branch
    x = self.res3b_bn(x)
    x = tf.nn.relu(x, name='res3b')
    return x
def SAAM(x, up_scale, N_svd):
    """Self-attention over (angular, width) positions per row, then angular upsampling.

    ``x`` is unpacked as ``(batch, ang, hei, wid, chn)``, i.e. a rank-5 tensor
    with channels on the last axis. For each of the ``batch * hei`` rows, the
    ``ang * wid`` positions attend to each other. (The ``attn_epi_*`` layer names
    suggest these rows are epipolar-plane images -- presumably; not verifiable here.)

    Steps:
      1. Three bias-free 1x1x1 projections ('attn_epi_1/2/3') give query/key
         (``chn // 8`` channels) and value (``chn // 8 * up_scale`` channels).
         Each is regrouped to ``[batch * hei, ang * wid, depth]``.
      2. Attention logits are ``query @ key^T``. If ``N_svd > 0``, query and key
         are each replaced by their rank-``N_svd`` truncated SVD before the
         product. ``tf.svd`` returns singular values in descending order, so the
         first ``N_svd`` are the largest. ``N_svd`` must not exceed
         ``min(ang * wid, chn // 8)``.
      3. Row-wise softmax, weighted sum of values, and regroup back to
         ``[batch, ang, hei, wid, chn // 8 * up_scale]``.
      4. ``out = attn_epi_4(x) + sigma * attention``, with ``sigma`` from
         ``beta_variable([1], name='attn_sigma')``.
      5. ``pixel_shuffle(out, up_scale)``, then crop the trailing
         ``up_scale - 1`` entries of axis 1. Finally two ReLU convs ('epi_5'
         with kernel (1, 1, 7) and 'epi_6' with kernel (7, 1, 1)) map back to
         ``chn`` channels.

    Args:
        x: rank-5 input tensor with static shape.
        up_scale: upsampling factor passed to ``pixel_shuffle``.
        N_svd: rank of the SVD approximation; ``<= 0`` disables it.

    Returns:
        Tensor with ``chn`` channels.
    """
    _, ang, hei, wid, chn = x.get_shape().as_list()
    key_depth = chn // 8
    value_depth = chn // 8 * up_scale

    def _project_rows(source, depth, layer_name):
        # 1x1x1 projection, swap (ang, hei) and flatten (ang, wid) into one token axis.
        projected = Conv3D(depth, (1, 1, 1), use_bias=False, padding='SAME', name=layer_name)(source)
        projected = tf.transpose(projected, [0, 2, 1, 3, 4])
        return tf.reshape(projected, [-1, ang * wid, depth])

    def _top_svd(matrix):
        # Keep only the leading N_svd singular triplets.
        sing, left, right = tf.svd(matrix)
        return (tf.slice(sing, [0, 0], [-1, N_svd]),
                tf.slice(left, [0, 0, 0], [-1, -1, N_svd]),
                tf.slice(right, [0, 0, 0], [-1, -1, N_svd]))

    # \phi'_q, \phi'_k, \phi'_v
    query = _project_rows(x, key_depth, 'attn_epi_1')
    key = _project_rows(x, key_depth, 'attn_epi_2')
    value = _project_rows(x, value_depth, 'attn_epi_3')

    # Map
    if N_svd <= 0:
        logits = tf.matmul(query, key, transpose_b=True)
    else:
        s_q, u_q, v_q = _top_svd(query)
        s_k, u_k, v_k = _top_svd(key)
        # U_q S_q V_q^T V_k S_k^T U_k^T, associated exactly as before.
        core = tf.matmul(v_q, v_k, transpose_a=True)
        core = tf.matmul(tf.matrix_diag(s_q), core)
        core = tf.matmul(core, tf.matrix_diag(s_k), transpose_b=True)
        logits = tf.matmul(tf.matmul(u_q, core), u_k, transpose_b=True)
    weights = tf.nn.softmax(logits)

    # \phi'_a
    attended = tf.matmul(weights, value)

    # \phi'_b
    attended = tf.reshape(attended, [-1, hei, ang, wid, value_depth])
    attended = tf.transpose(attended, [0, 2, 1, 3, 4])
    residual = Conv3D(value_depth, (1, 1, 1), use_bias=False, padding='SAME', name='attn_epi_4')(x)
    sigma = beta_variable([1], name='attn_sigma')
    out = residual + sigma * attended

    # \phi_c
    out = pixel_shuffle(out, up_scale)
    out = Cropping3D(cropping=((0, up_scale - 1), (0, 0), (0, 0)))(out)
    out = Conv3D(filters=chn, kernel_size=(1, 1, 7), activation='relu', padding='SAME', name='epi_5')(out)
    out = Conv3D(filters=chn, kernel_size=(7, 1, 1), activation='relu', padding='SAME', name='epi_6')(out)
    return out
def conv_block_simple_3d_no_bn(prevlayer, num_filters, prefix, kernel_size=(3, 3, 3),
                               initializer="glorot_normal", strides=(1, 1, 1)):
    """Apply a channels-first Conv3D followed by ReLU, with no batch normalization.

    Args:
        prevlayer: input tensor (channels-first layout).
        num_filters: number of output channels.
        prefix: name prefix; the conv layer is named ``prefix + "_conv"`` and the
            ReLU op ``prefix + "_activation"``.
        kernel_size: convolution kernel size.
        initializer: kernel initializer passed to Conv3D.
        strides: convolution strides.

    Returns:
        The ReLU-activated convolution output ("same" padding).
    """
    conv_layer = Conv3D(
        filters=num_filters,
        kernel_size=kernel_size,
        padding="same",
        kernel_initializer=initializer,
        strides=strides,
        name=prefix + "_conv",
        data_format='channels_first',
    )
    return tf.nn.relu(conv_layer(prevlayer), name=prefix + "_activation")
def unet_7_layers_3D(input_tensor):
    """Apply a 3-level channels-first 3D U-Net to ``input_tensor``.

    Encoder: three stages, each of two ``conv_block_simple_3d`` calls
    (kernel ``(2, 3, 3)``) followed by ``(1, 2, 2)`` max-pooling. Only the last
    two spatial axes are downsampled.

    Bottleneck: three conv blocks.

    Decoder: three stages, each consisting of
      * a ``(1, 2, 2)``-strided ``Conv3DTranspose``;
      * concatenation with the matching encoder output along axis 1 (the
        channel axis for ``channels_first``);
      * two conv blocks using ``conv_block_simple_3d``'s default kernel size.

    Layer names and creation order are identical to the original explicit
    version (conv1, conv1_1, pool1, ..., conv_4, conv_4_1, conv_4_2,
    conv5_1, ..., conv7_2, prediction).

    Args:
        input_tensor: channels-first rank-5 tensor -- presumably
            (batch, channels, depth, height, width); TODO confirm.

    Returns:
        Single-channel sigmoid prediction tensor (layer ``"prediction"``).

    NOTE(review): with "same" pooling an odd height/width is rounded up and the
    stride-2 transposed conv then doubles it. The skip concatenations therefore
    presumably require height/width divisible by 8.
    """
    d_format = "channels_first"
    pool_size = (1, 2, 2)
    scale_strides = (1, 2, 2)  # pooling strides and transposed-conv strides
    kern = (2, 3, 3)
    filt = (32, 64, 128, 256)
    n_levels = len(filt) - 1  # filt[-1] is the bottleneck width

    # Encoder: stage k uses prefixes "conv<k>", "conv<k>_1" and pool "pool<k>".
    skips = []
    h = input_tensor
    for level in range(n_levels):
        stage = str(level + 1)
        h = conv_block_simple_3d(prevlayer=h, num_filters=filt[level], prefix="conv" + stage,
                                 kernel_size=kern)
        h = conv_block_simple_3d(prevlayer=h, num_filters=filt[level], prefix="conv" + stage + "_1",
                                 kernel_size=kern)
        skips.append(h)
        h = MaxPooling3D(pool_size=pool_size, strides=scale_strides, padding="same",
                         data_format=d_format, name="pool" + stage)(h)

    # Bottleneck: prefixes "conv_4", "conv_4_1", "conv_4_2" (note the underscore).
    bottleneck_prefix = "conv_" + str(n_levels + 1)
    for suffix in ("", "_1", "_2"):
        h = conv_block_simple_3d(prevlayer=h, num_filters=filt[n_levels],
                                 prefix=bottleneck_prefix + suffix, kernel_size=kern)

    # Decoder: deepest level first, stages numbered conv5_*, conv6_*, conv7_*.
    for level in reversed(range(n_levels)):
        stage = str(2 * n_levels + 1 - level)
        h = Conv3DTranspose(filters=filt[level], kernel_size=kern, strides=scale_strides,
                            padding="same", data_format=d_format)(h)
        h = concatenate([h, skips[level]], axis=1)
        h = conv_block_simple_3d(prevlayer=h, num_filters=filt[level], prefix="conv" + stage + "_1")
        h = conv_block_simple_3d(prevlayer=h, num_filters=filt[level], prefix="conv" + stage + "_2")

    prediction = Conv3D(filters=1, kernel_size=(1, 1, 1), activation="sigmoid",
                        name="prediction", data_format=d_format)(h)
    return prediction
def unet_9_layers_3D(input_shape):
    """Build a 4-level channels-first 3D U-Net Keras ``Model``.

    Encoder: four stages, each of two ``conv_block_simple_3d`` calls (default
    kernel size) followed by ``(1, 2, 2)`` max-pooling. Only the last two
    spatial axes are downsampled.

    Bottleneck: three conv blocks (512 filters).

    Decoder: four stages, each consisting of
      * ``UpSampling3D((1, 2, 2))``;
      * concatenation with the matching encoder output along axis 1 (the
        channel axis for ``channels_first``);
      * two conv blocks.

    Layer names and creation order are identical to the original explicit
    version (conv1, conv1_1, pool1, ..., conv_5, conv_5_1, conv_5_2,
    conv6_1, ..., conv9_2, prediction).

    Args:
        input_shape: shape passed to ``Input`` (excluding batch) -- presumably
            (channels, depth, height, width); TODO confirm.

    Returns:
        ``Model(img_input, prediction)`` with a single-channel sigmoid output.

    NOTE(review): with "same" pooling, odd height/width sizes are rounded up.
    The skip concatenations therefore presumably require height/width
    divisible by 16.
    """
    d_format = "channels_first"
    pool_size = (1, 2, 2)
    pool_strides = (1, 2, 2)
    up_size = (1, 2, 2)
    filt = (32, 64, 128, 256, 512)
    n_levels = len(filt) - 1  # filt[-1] is the bottleneck width

    img_input = Input(input_shape)

    # Encoder: stage k uses prefixes "conv<k>", "conv<k>_1" and pool "pool<k>".
    skips = []
    h = img_input
    for level in range(n_levels):
        stage = str(level + 1)
        h = conv_block_simple_3d(prevlayer=h, num_filters=filt[level], prefix="conv" + stage)
        h = conv_block_simple_3d(prevlayer=h, num_filters=filt[level], prefix="conv" + stage + "_1")
        skips.append(h)
        h = MaxPooling3D(pool_size=pool_size, strides=pool_strides, padding="same",
                         data_format=d_format, name="pool" + stage)(h)

    # Bottleneck: prefixes "conv_5", "conv_5_1", "conv_5_2" (note the underscore).
    bottleneck_prefix = "conv_" + str(n_levels + 1)
    for suffix in ("", "_1", "_2"):
        h = conv_block_simple_3d(prevlayer=h, num_filters=filt[n_levels],
                                 prefix=bottleneck_prefix + suffix)

    # Decoder: deepest level first, stages numbered conv6_* .. conv9_*.
    for level in reversed(range(n_levels)):
        stage = str(2 * n_levels + 1 - level)
        h = UpSampling3D(size=up_size, data_format=d_format)(h)
        h = concatenate([h, skips[level]], axis=1)
        h = conv_block_simple_3d(prevlayer=h, num_filters=filt[level], prefix="conv" + stage + "_1")
        h = conv_block_simple_3d(prevlayer=h, num_filters=filt[level], prefix="conv" + stage + "_2")

    prediction = Conv3D(filters=1, kernel_size=(1, 1, 1), activation="sigmoid",
                        name="prediction", data_format=d_format)(h)
    return Model(img_input, prediction)
def model(x, N_svd=0, up_scale=4):
    """Build the 'ASR' network under ``tf.variable_scope('ASR')`` and return its output.

    Args:
        x: rank-5 tensor with channels on the last axis (``x.get_shape()[4]``).
            Axis 1 is the axis upsampled by the ``(up_scale, 1, 1)``-strided
            deconvolutions and cropped by ``Cropping3D``; presumably the angular
            axis of a (batch, ang, hei, wid, chn) input, matching ``SAAM``'s
            unpacking.
        N_svd: forwarded to ``SAAM``; when > 0 the attention map is built from
            rank-``N_svd`` truncated SVDs.
        up_scale: upsampling factor along axis 1 (default 4, the previously
            hard-coded value). The deconvolution skip branches produce
            ``A * up_scale - (up_scale - 1) == (A - 1) * up_scale + 1`` samples
            on that axis for an input of size ``A``.

    Returns:
        Tensor with the same channel count as ``x`` (final conv has no activation).
    """
    # Trim the trailing (up_scale - 1) samples of axis 1 after an
    # up_scale-strided upsampling. SAAM crops the same amount, which keeps the
    # concatenations with h2/h1 below shape-compatible. Previously this was a
    # literal 3, silently tied to up_scale == 4.
    ang_crop = ((0, up_scale - 1), (0, 0), (0, 0))

    with tf.variable_scope('ASR'):
        chn_in = x.get_shape().as_list()[4]
        chn_base = 6 * up_scale
        # Shape notation: x is [batch, A, H, W, chn_in]; A' = A * up_scale - (up_scale - 1).
        # Spatial halving uses SAME padding (ceil); the later x2 transposed
        # convs match it exactly only for even H and W.

        # Group 1 -> h: [batch, A, H, W, chn_base]
        h = Conv3D(filters=chn_base, kernel_size=(3, 1, 3), activation='relu', padding='SAME', name='conv1_1')(x)
        h = Conv3D(filters=chn_base, kernel_size=(3, 3, 1), activation='relu', padding='SAME', name='conv1_2')(h)
        # Skip branch h1: [batch, A', H, W, chn_base]
        h1 = Conv3DTranspose(chn_base, (7, 1, 3), (up_scale, 1, 1), 'SAME', activation='relu', name='deconv1')(h)
        h1 = Cropping3D(cropping=ang_crop)(h1)
        h1 = Conv3D(filters=chn_base, kernel_size=(1, 1, 1), activation='relu', padding='SAME', name='conv1_3')(h1)
        # -> h: [batch, A, H/2, W/2, chn_base * 2]
        h = Conv3D(filters=chn_base * 2, kernel_size=(1, 3, 3), strides=(1, 2, 2), activation='relu',
                   padding='SAME', name='conv1_4')(h)

        # Group 2 -> h: [batch, A, H/2, W/2, chn_base * 2]
        h = Conv3D(filters=chn_base * 2, kernel_size=(3, 1, 3), activation='relu', padding='SAME', name='conv2_1')(h)
        h = Conv3D(filters=chn_base * 2, kernel_size=(3, 3, 1), activation='relu', padding='SAME', name='conv2_2')(h)
        # Skip branch h2: [batch, A', H/2, W/2, chn_base * 2]
        h2 = Conv3DTranspose(chn_base, (7, 1, 3), (up_scale, 1, 1), 'SAME', activation='relu', name='deconv2')(h)
        h2 = Cropping3D(cropping=ang_crop)(h2)
        h2 = Conv3D(filters=chn_base * 2, kernel_size=(1, 1, 1), activation='relu', padding='SAME', name='conv2_3')(h2)
        # -> h: [batch, A, H/2, W/4, chn_base * 4]
        h = Conv3D(filters=chn_base * 4, kernel_size=(1, 1, 3), strides=(1, 1, 2), activation='relu',
                   padding='SAME', name='conv2_4')(h)

        # Layer 3, shrinking -> [batch, A, H/2, W/4, chn_base * 2]
        h = Conv3D(filters=chn_base * 2, kernel_size=(1, 1, 1), activation='relu', padding='SAME', name='conv3')(h)

        # Group 4, mapping (shape unchanged)
        for i in range(2):
            h = Conv3D(filters=chn_base * 2, kernel_size=(3, 1, 3), activation='relu', padding='SAME',
                       name='conv4_1_' + str(i))(h)
            h = Conv3D(filters=chn_base * 2, kernel_size=(3, 3, 1), activation='relu', padding='SAME',
                       name='conv4_2_' + str(i))(h)

        # Group 5, attention + upsampling along axis 1. Channels stay chn_base * 2;
        # axis 1 is presumably A' (depends on pixel_shuffle) -- it must be for the concat below.
        h = SAAM(h, up_scale=up_scale, N_svd=N_svd)

        # Layer 6, expanding -> [batch, A', H/2, W/4, chn_base * 4]
        h = Conv3D(filters=chn_base * 4, kernel_size=(1, 1, 1), activation='relu', padding='SAME', name='conv6')(h)

        # Group 7 -> [batch, A', H/2, W/2, chn_base * 2], concatenated with h2
        h = Conv3DTranspose(chn_base * 2, (1, 1, 4), (1, 1, 2), activation='relu', padding='SAME', name='conv7_1')(h)
        h = tf.concat([h, h2], axis=-1)
        h = Conv3D(filters=chn_base * 2, kernel_size=(3, 1, 3), activation='relu', padding='SAME', name='conv7_2')(h)
        h = Conv3D(filters=chn_base * 2, kernel_size=(3, 3, 1), activation='relu', padding='SAME', name='conv7_3')(h)

        # Group 8 -> [batch, A', H, W, chn_base], concatenated with h1
        h = Conv3DTranspose(chn_base, (1, 4, 4), (1, 2, 2), activation='relu', padding='SAME', name='conv8_1')(h)
        h = tf.concat([h, h1], axis=-1)
        h = Conv3D(filters=chn_base, kernel_size=(3, 1, 3), activation='relu', padding='SAME', name='conv8_2')(h)
        h = Conv3D(filters=chn_base, kernel_size=(3, 3, 1), activation='relu', padding='SAME', name='conv8_3')(h)

        # Group 9 -> [batch, A', H, W, chn_in]
        h = Conv3D(filters=chn_in, kernel_size=(3, 3, 3), padding='SAME', name='conv9')(h)
        return h