def Dense3(inputs, growth_rate, training):
    # Dense block with three 3x3x3 convolutions; each layer adds `growth_rate`
    # channels and is concatenated with all previous feature maps.
    [b, w, h, d, c] = inputs.get_shape().as_list()
    bn_relu1 = BN_ReLU(inputs, training)
    conv1 = Conv3D(bn_relu1, growth_rate, kernel_size=3, strides=1)
    concat1 = tf.concat((inputs, conv1), axis=4)
    bn_relu2 = BN_ReLU(concat1, training)
    conv2 = Conv3D(bn_relu2, growth_rate, kernel_size=3, strides=1)
    concat2 = tf.concat((concat1, conv2), axis=4)
    bn_relu3 = BN_ReLU(concat2, training)
    conv3 = Conv3D(bn_relu3, growth_rate, kernel_size=3, strides=1)
    concat3 = tf.concat((concat2, conv3), axis=4)
    bn_relu4 = BN_ReLU(concat3, training)
    # 1x1x1 convolution keeping all c + 3 * growth_rate channels of concat3.
    conv4 = Conv3D(bn_relu4, c + 3 * growth_rate, kernel_size=1, strides=1)
    return conv4
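These examples call helper wrappers (BN_ReLU, Conv3D, Deconv3D, Dilated_Conv3D) that are defined elsewhere in the repository. A minimal sketch of what they might look like with the TF1 tf.layers API is shown below; the real implementations may differ in padding, initializers, or batch-norm settings.

# Assumed, minimal versions of the helpers used throughout these examples.
import tensorflow as tf

def BN_ReLU(inputs, training):
    # Batch normalization over the channel axis followed by ReLU.
    outputs = tf.layers.batch_normalization(inputs, axis=-1, training=training)
    return tf.nn.relu(outputs)

def Conv3D(inputs, filters, kernel_size, strides, use_bias=False):
    # 'same' padding preserves the spatial size when strides == 1.
    return tf.layers.conv3d(inputs, filters, kernel_size, strides=strides,
                            padding='same', use_bias=use_bias)

def Deconv3D(inputs, filters, kernel_size, strides, use_bias=False):
    # Transposed convolution; strides == 2 doubles the spatial resolution.
    return tf.layers.conv3d_transpose(inputs, filters, kernel_size,
                                      strides=strides, padding='same',
                                      use_bias=use_bias)

def Dilated_Conv3D(inputs, filters, kernel_size, dilation_rate, use_bias=False):
    # Dilated (atrous) convolution with stride 1 and a larger receptive field.
    return tf.layers.conv3d(inputs, filters, kernel_size, strides=1,
                            padding='same', dilation_rate=dilation_rate,
                            use_bias=use_bias)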
Example No. 2
	def _attention_block(self, inputs, filters, training,
							projection_shortcut, strides):
		"""Attentional building block for residual networks with BN before convolutions.

		Args:
			inputs: A tensor of size [batch, depth_in, height_in, width_in, channels].
			filters: The number of filters for the convolutions.
			training: A Boolean for whether the model is in training or inference
				mode. Needed for batch normalization.
			projection_shortcut: The function to use for projection shortcuts
				(typically a 1x1x1 convolution when downsampling the input).
			strides: The block's stride. If greater than 1, this block will ultimately
				downsample the input.

		Returns:
			The output tensor of the block.
		"""

		shortcut = inputs
		inputs = BN_ReLU(inputs, training)

		# The projection shortcut should come after the first batch norm and ReLU
		# since it performs a 1x1 convolution.
		if projection_shortcut is not None:
			shortcut = projection_shortcut(inputs)

		# Use an upsampling attention layer when the block changes resolution;
		# otherwise the attention output keeps the input's spatial size.
		if strides != 1:
			layer_type = 'UP'
		else:
			layer_type = 'SAME'

		# Self-attention over the 3D feature map with a single head; all
		# filter arguments are set to `filters`.
		inputs = multihead_attention_3d(
					inputs, filters, filters, filters, 1, training, layer_type)

		return inputs + shortcut
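An illustrative way to call _attention_block (not from the original repository) is with a 1x1x1 convolution as the projection shortcut, so the residual addition sees matching channel counts:

# Hypothetical usage sketch; `net` stands for the model object that defines
# _attention_block, and 64 is an arbitrary channel count.
def projection_shortcut(x):
    return Conv3D(inputs=x, filters=64, kernel_size=1, strides=1)

# out = net._attention_block(inputs, filters=64, training=True,
#                            projection_shortcut=projection_shortcut,
#                            strides=1)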
Example No. 3
	def _build_network(self, inputs, training):
		"""Build the network: an initial convolution, residual encoding blocks,
		a self-attention bottleneck, and decoding blocks that merge the
		encoder's skip connections.
		"""

		# Initial 3x3x3 convolution lifting the input to `num_filters` channels.
		inputs = Conv3D(
					inputs=inputs,
					filters=self.num_filters,
					kernel_size=3,
					strides=1)
		inputs = tf.identity(inputs, 'initial_conv')

		# Encoder: residual block layers, doubling the channel count at each
		# stage; every stage's output is stored for the decoder's skip paths.
		skip_inputs = []
		for i, num_blocks in enumerate(self.block_sizes):
			num_filters = self.num_filters * (2**i)
			inputs = self._encoding_block_layer(
						inputs=inputs, filters=num_filters,
						block_fn=self._residual_block, blocks=num_blocks,
						strides=self.block_strides[i], training=training,
						name='encode_block_layer{}'.format(i+1))
			skip_inputs.append(inputs)
		
		# Bottleneck: self-attention at the deepest resolution, added
		# residually to the last encoder output.
		inputs = BN_ReLU(inputs, training)
		num_filters = self.num_filters * (2**(len(self.block_sizes)-1))
		inputs = multihead_attention_3d(
					inputs, num_filters, num_filters, num_filters, 2, training, layer_type='SAME')
		inputs += skip_inputs[-1]

		# Decoder: walk the encoder stages in reverse, upsampling and merging
		# each stage with its skip connection. The deepest decoder stage uses
		# the attention-gated variant; the others use plain decoding blocks.
		for i, num_blocks in reversed(list(enumerate(self.block_sizes[1:]))):
			num_filters = self.num_filters * (2**i)
			if i == len(self.block_sizes) - 2:
				inputs = self._att_decoding_block_layer(
						inputs=inputs, skip_inputs=skip_inputs[i],
						filters=num_filters, block_fn=self._residual_block,
						blocks=1, strides=self.block_strides[i+1],
						training=training,
						name='decode_block_layer{}'.format(len(self.block_sizes)-i-1))
			else:
				inputs = self._decoding_block_layer(
						inputs=inputs, skip_inputs=skip_inputs[i],
						filters=num_filters, block_fn=self._residual_block,
						blocks=1, strides=self.block_strides[i+1],
						training=training,
						name='decode_block_layer{}'.format(len(self.block_sizes)-i-1))

		# Final projection to per-voxel class logits.
		inputs = self._output_block_layer(inputs=inputs, training=training)

		return inputs
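The attributes used above (num_filters, block_sizes, block_strides) come from the repository's configuration. Purely as an illustration of how the stages relate, one plausible setting is sketched below; the real values may differ.

# Illustrative values only; not taken from the original repository.
# self.num_filters   = 16             # channels after the initial convolution
# self.block_sizes   = [2, 2, 2, 2]   # residual blocks per encoder stage
# self.block_strides = [1, 2, 2, 2]   # stride-2 stages halve the resolution
#
# With these values the encoder stages produce 16, 32, 64 and 128 channels,
# the attention bottleneck runs on the 128-channel feature map, and three
# decoder stages upsample back toward the input resolution.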
Example No. 4
    def _output_block_layer(self, inputs, training):

        inputs = BN_ReLU(inputs, training)

        # Dropout is only active in training mode.
        inputs = tf.layers.dropout(inputs, rate=0.5, training=training)

        # 1x1x1 convolution projecting to one logit per class.
        inputs = Conv3D(inputs=inputs,
                        filters=self.num_classes,
                        kernel_size=1,
                        strides=1,
                        use_bias=True)

        return tf.identity(inputs, 'output')
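At inference time the logits produced by this block can be turned into per-voxel probabilities and labels in the usual way; the snippet below is illustrative and not part of the original code (the placeholder shape and class count are arbitrary).

logits = tf.placeholder(tf.float32, [1, 16, 16, 16, 4])  # e.g. 4 classes
probabilities = tf.nn.softmax(logits, axis=-1)  # per-voxel class probabilities
predictions = tf.argmax(logits, axis=-1)        # per-voxel class labels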
def res_inc_deconv(inputs, training):
    # Inception-style residual block built from transposed/dilated convolutions.
    # Branch 1: a single 3x3x3 transposed convolution, 32 channels.
    deconv1_1_1 = Deconv3D(inputs, 32, kernel_size=3, strides=1, use_bias=False)
    deconv1_1 = BN_ReLU(deconv1_1_1, training)
    # Branch 2: two stacked 3x3x3 transposed convolutions, 32 then 64 channels.
    deconv2_1_1 = Deconv3D(inputs, 32, kernel_size=3, strides=1, use_bias=False)
    deconv2_1 = BN_ReLU(deconv2_1_1, training)
    deconv2_3 = Deconv3D(deconv2_1, 64, kernel_size=3, strides=1, use_bias=False)
    deconv2_2 = BN_ReLU(deconv2_3, training)
    # Branch 3: a transposed convolution followed by a dilated convolution
    # (dilation rate 2) for a larger receptive field, 32 channels.
    deconv3_1_1 = Deconv3D(inputs, 32, kernel_size=3, strides=1, use_bias=False)
    deconv3_1 = BN_ReLU(deconv3_1_1, training)
    deconv3_2_1 = Dilated_Conv3D(deconv3_1, 32, kernel_size=3, dilation_rate=2,
                                 use_bias=False)
    deconv3_2 = BN_ReLU(deconv3_2_1, training)
    # Concatenate the branches (32 + 64 + 32 = 128 channels), project with one
    # more transposed convolution, and add the block input as a residual.
    concat = tf.concat((deconv1_1, deconv2_2, deconv3_2), axis=4)
    deconv1 = Deconv3D(concat, 128, kernel_size=3, strides=1, use_bias=False)
    deconv = BN_ReLU(deconv1, training)
    fuse = tf.add(inputs, deconv)
    return fuse
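The final tf.add only works if the block's input already has 128 channels, since the three branches concatenate to 128 channels before the last projection. A quick shape check, assuming the Deconv3D wrapper uses 'same' padding so strides=1 keeps the spatial size:

x = tf.placeholder(tf.float32, [1, 16, 16, 16, 128])
y = res_inc_deconv(x, training=True)
print(y.shape)  # expected: (1, 16, 16, 16, 128)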
def unpool(inputs, training):
    # Upsampling block with two parallel branches, each ending in a stride-2
    # transposed convolution that doubles the spatial resolution.
    # Branch 1: 3x3x3 convolution, transposed convolution, then upsampling.
    conv31 = Conv3D(inputs, 176, kernel_size=3, strides=1)
    deconv31 = BN_ReLU(conv31, training)
    deconv1_1 = Deconv3D(deconv31, 176, kernel_size=3, strides=1, use_bias=False)
    deconv1 = BN_ReLU(deconv1_1, training)
    deconv1_2 = Deconv3D(deconv1, 88, kernel_size=3, strides=2, use_bias=False)
    deconv2 = BN_ReLU(deconv1_2, training)
    # Branch 2: transposed convolution, dilated convolution (rate 2), then
    # upsampling.
    deconv2_1 = Deconv3D(inputs, 176, kernel_size=3, strides=1, use_bias=False)
    deconv3 = BN_ReLU(deconv2_1, training)
    deconv2_2 = Dilated_Conv3D(deconv3, 176, kernel_size=3, dilation_rate=2,
                               use_bias=False)
    deconv4 = BN_ReLU(deconv2_2, training)
    deconv2_3 = Deconv3D(deconv4, 88, kernel_size=3, strides=2, use_bias=False)
    deconv5 = BN_ReLU(deconv2_3, training)
    # Concatenate the two 88-channel branches into a 176-channel output.
    concat = tf.concat((deconv2, deconv5), axis=4)
    return concat
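Because each branch ends in a strides=2 transposed convolution, unpool doubles the spatial resolution and concatenates two 88-channel outputs into 176 channels. A quick check under the same 'same'-padding assumption as above:

x = tf.placeholder(tf.float32, [1, 8, 8, 8, 176])
y = unpool(x, training=True)
print(y.shape)  # expected: (1, 16, 16, 16, 176)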