def Dense3(inputs, growth_rate, training):
    """Three-layer densely connected block followed by a 1x1x1 projection.

    Each of the three layers applies ``BN_ReLU`` and then a 3x3x3 ``Conv3D``
    that produces ``growth_rate`` channels. It concatenates that output onto
    its own input along axis 4. A final ``BN_ReLU`` plus a 1x1x1 ``Conv3D``
    projects the concatenation to ``in_channels + 3 * growth_rate`` channels,
    which is the width of the concatenation itself.

    Args:
        inputs: A 5-D tensor. The original code notes its layout as
            [b, w, h, d, c], so axis 4 / the last axis is the channel axis
            (TODO confirm against callers). The channel dimension must be
            statically known.
        growth_rate: Number of filters produced by each 3x3x3 convolution.
        training: Forwarded to ``BN_ReLU`` (training vs. inference mode).

    Returns:
        The output tensor of the final 1x1x1 convolution.

    Raises:
        ValueError: If the channel dimension of ``inputs`` is not static.
    """
    in_channels = inputs.get_shape().as_list()[-1]
    if in_channels is None:
        raise ValueError(
            'Dense3 requires inputs with a statically known channel dimension.')

    features = inputs
    for _ in range(3):
        # BN_ReLU is evaluated before Conv3D, so the layer creation order
        # matches the original unrolled code: BN, conv, concat, repeated.
        new_features = Conv3D(BN_ReLU(features, training), growth_rate,
                              kernel_size=3, strides=1)
        # Dense connectivity: each layer sees all previous feature maps.
        features = tf.concat((features, new_features), axis=4)

    # Project back to the width of the final concatenation. The previous
    # expression referenced an undefined name ``c`` and failed at build time.
    bn_relu = BN_ReLU(features, training)
    return Conv3D(bn_relu, in_channels + 3 * growth_rate,
                  kernel_size=1, strides=1)
def _attention_block(self, inputs, filters, training,
                     projection_shortcut, strides):
    """Pre-activation residual block whose body is a multi-head attention layer.

    The inputs are passed through ``BN_ReLU`` and then through
    ``multihead_attention_3d``. The result is added to a shortcut branch.

    Args:
        inputs: Input tensor of size
            [batch, depth_in, height_in, width_in, channels].
        filters: Passed as each of the three positional arguments that follow
            the input tensor in the ``multihead_attention_3d`` call.
        training: Whether the model is in training mode. Forwarded to
            ``BN_ReLU`` and ``multihead_attention_3d``.
        projection_shortcut: Optional callable applied to the BN+ReLU-activated
            inputs to build the shortcut. If ``None``, the raw ``inputs`` are
            used as the shortcut.
        strides: If not 1, the attention layer is built with layer type
            ``'UP'``; otherwise with ``'SAME'``. The shortcut must then
            produce a tensor that can be added to the attention output.

    Returns:
        The attention output plus the shortcut.
    """
    activated = BN_ReLU(inputs, training)

    # The projection shortcut is applied after BN+ReLU because it is itself a
    # convolution. The identity shortcut keeps the un-normalized inputs.
    if projection_shortcut is None:
        shortcut = inputs
    else:
        shortcut = projection_shortcut(activated)

    layer_type = 'SAME' if strides == 1 else 'UP'
    attended = multihead_attention_3d(
        activated, filters, filters, filters, 1, training, layer_type)
    return attended + shortcut
def _build_network(self, inputs, training):
    """Build the encoder, attention bottleneck, and decoder, then the output head.

    Stages:
      1. A 3x3x3 ``Conv3D`` stem named ``'initial_conv'``.
      2. One encoding block layer per entry of ``self.block_sizes``. Level
         ``i`` uses ``self.num_filters * 2**i`` filters and
         ``self.block_strides[i]``. Each level's output is kept as a skip input.
      3. ``BN_ReLU`` and then ``multihead_attention_3d`` (``layer_type='SAME'``)
         at the deepest width, added back onto the deepest skip input.
      4. Decoding block layers from the second-deepest level up to level 0,
         each with one block. The first of these (the highest level index)
         uses ``self._att_decoding_block_layer``; the rest use
         ``self._decoding_block_layer``.
      5. ``self._output_block_layer``.

    Args:
        inputs: The network input tensor.
        training: Whether the model is in training mode.

    Returns:
        The tensor produced by ``self._output_block_layer``.
    """
    net = Conv3D(inputs=inputs, filters=self.num_filters,
                 kernel_size=3, strides=1)
    net = tf.identity(net, 'initial_conv')

    # Encoder: keep every level's output for the decoder's skip connections.
    skip_inputs = []
    for level, num_blocks in enumerate(self.block_sizes):
        net = self._encoding_block_layer(
            inputs=net,
            filters=self.num_filters * (2**level),
            block_fn=self._residual_block,
            blocks=num_blocks,
            strides=self.block_strides[level],
            training=training,
            name='encode_block_layer{}'.format(level+1))
        skip_inputs.append(net)

    # Bottleneck: residual attention at the deepest encoder width.
    num_levels = len(self.block_sizes)
    bottleneck_filters = self.num_filters * (2**(num_levels-1))
    net = BN_ReLU(net, training)
    net = multihead_attention_3d(
        net, bottleneck_filters, bottleneck_filters, bottleneck_filters,
        2, training, layer_type='SAME')
    net = net + skip_inputs[-1]

    # Decoder: walk levels num_levels-2 .. 0. Level i upsamples with the
    # stride of level i+1 and merges skip_inputs[i].
    for level in reversed(range(num_levels - 1)):
        if level == num_levels - 2:
            decode_layer = self._att_decoding_block_layer
        else:
            decode_layer = self._decoding_block_layer
        net = decode_layer(
            inputs=net,
            skip_inputs=skip_inputs[level],
            filters=self.num_filters * (2**level),
            block_fn=self._residual_block,
            blocks=1,
            strides=self.block_strides[level+1],
            training=training,
            name='decode_block_layer{}'.format(num_levels-level-1))

    return self._output_block_layer(inputs=net, training=training)
def _output_block_layer(self, inputs, training):
    """Final head: BN+ReLU, dropout, then a 1x1x1 conv to ``self.num_classes``.

    Args:
        inputs: Feature tensor from the last decoder stage.
        training: Forwarded to ``BN_ReLU`` and used as the ``training`` flag
            of the rate-0.5 dropout.

    Returns:
        The projected tensor, wrapped in ``tf.identity`` with the name
        ``'output'``.
    """
    activated = BN_ReLU(inputs, training)
    regularized = tf.layers.dropout(activated, rate=0.5, training=training)
    projected = Conv3D(
        inputs=regularized,
        filters=self.num_classes,
        kernel_size=1,
        strides=1,
        use_bias=True,
    )
    return tf.identity(projected, 'output')
def res_inc_deconv(inputs, training):
    """Residual inception-style block built from stride-1 transposed convolutions.

    Three parallel branches are concatenated along axis 4:
      * branch A: Deconv3D(32)                                  -> 32 channels
      * branch B: Deconv3D(32) -> Deconv3D(64)                  -> 64 channels
      * branch C: Deconv3D(32) -> Dilated_Conv3D(32, rate 2)    -> 32 channels
    Every convolution is 3x3x3, has no bias, and is followed by ``BN_ReLU``.
    The 128-channel concatenation goes through another Deconv3D(128) +
    ``BN_ReLU`` and is then added to ``inputs``. For that residual add to
    work, ``inputs`` presumably needs 128 channels and a matching spatial
    shape (TODO confirm against callers).

    Args:
        inputs: 5-D feature tensor; axis 4 is the concatenation axis.
        training: Forwarded to every ``BN_ReLU``.

    Returns:
        ``inputs`` plus the fused branch output.
    """
    def deconv_bn_relu(x, filters):
        # Stride-1 3x3x3 transposed conv (no bias) followed by BN + ReLU.
        return BN_ReLU(
            Deconv3D(x, filters, kernel_size=3, strides=1, use_bias=False),
            training)

    # Branches are built in the same order as before (A, then B, then C), so
    # the layer creation order is unchanged.
    branch_a = deconv_bn_relu(inputs, 32)
    branch_b = deconv_bn_relu(deconv_bn_relu(inputs, 32), 64)
    branch_c = BN_ReLU(
        Dilated_Conv3D(deconv_bn_relu(inputs, 32), 32, kernel_size=3,
                       dilation_rate=2, use_bias=False),
        training)

    merged = tf.concat((branch_a, branch_b, branch_c), axis=4)
    fused = deconv_bn_relu(merged, 128)
    return tf.add(inputs, fused)
def unpool(inputs, training):
    """Two-branch block ending in stride-2 transposed convolutions.

    Branch 1: Conv3D(176) -> Deconv3D(176, stride 1) -> Deconv3D(88, stride 2).
    Branch 2: Deconv3D(176, stride 1) -> Dilated_Conv3D(176, rate 2)
              -> Deconv3D(88, stride 2).
    Every convolution is 3x3x3 and is followed by ``BN_ReLU``. The transposed
    and dilated convolutions are built without bias. The two 88-channel
    branch outputs are concatenated along axis 4, giving 176 channels. The
    stride-2 transposed convs presumably double the spatial size (depends on
    ``Deconv3D`` — TODO confirm).

    Args:
        inputs: 5-D feature tensor.
        training: Forwarded to every ``BN_ReLU``.

    Returns:
        The concatenation of the two branch outputs.
    """
    def deconv_bn_relu(x, filters, stride):
        # 3x3x3 transposed conv (no bias) followed by BN + ReLU.
        return BN_ReLU(
            Deconv3D(x, filters, kernel_size=3, strides=stride, use_bias=False),
            training)

    # Branch 1 starts with a regular convolution on the inputs.
    stem = BN_ReLU(Conv3D(inputs, 176, kernel_size=3, strides=1), training)
    branch_one = deconv_bn_relu(deconv_bn_relu(stem, 176, 1), 88, 2)

    # Branch 2 inserts a dilated convolution before the stride-2 deconv.
    dilated = BN_ReLU(
        Dilated_Conv3D(deconv_bn_relu(inputs, 176, 1), 176, kernel_size=3,
                       dilation_rate=2, use_bias=False),
        training)
    branch_two = deconv_bn_relu(dilated, 88, 2)

    return tf.concat((branch_one, branch_two), axis=4)