Beispiel #1
0
    def forward_network(self, input_img, reuse=False):
        """Build the evaluation-mode graph of the ResNet-style attribute classifier.

        Layout: a 64-filter 7x7 stride-2 stem convolution, a 3x3 stride-2
        max-pool, four bottleneck block groups (64/128/256/512 filters, with
        repeat counts taken from ``self.layers``), a ReLU, flatten, two
        fully-connected layers (512 and 256 units) each followed by a ReLU, and
        a final linear layer with ``self.num_attribute`` outputs. Every block
        group is built with ``train_mode=False``.

        Args:
            input_img: Image tensor fed to the stem convolution (presumably
                NHWC given the ``[1, 3, 3, 1]`` pooling kernel -- confirm).
            reuse: Forwarded to ``tf.compat.v1.variable_scope`` so a repeated
                call can share the variables created by the first one.

        Returns:
            The raw (un-activated) output of the final ``'Out'`` linear layer.
        """
        # (filters, stride, scope name) per bottleneck group; group i repeats
        # its block self.layers[i] times.
        group_specs = (
            (64, 1, 'block_layer1'),
            (128, 2, 'block_layer2'),
            (256, 2, 'block_layer3'),
            (512, 2, 'block_layer4'),
        )
        # (units, linear scope name, relu op name) for the hidden FC layers.
        hidden_fc_specs = (
            (512, 'FC1', 'FC1_relu'),
            (256, 'FC2', 'FC2_relu'),
        )

        with tf.compat.v1.variable_scope(self.name, reuse=reuse):
            tf_utils.print_activations(input_img, logger=None)

            # Stem: strided 7x7 convolution followed by a strided 3x3 max-pool.
            net = self.conv2d_fixed_padding(inputs=input_img, filters=64, kernel_size=7, strides=2, name='conv1')
            net = tf_utils.max_pool(net, name='3x3_maxpool', ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1],
                                    logger=None)

            for group_idx, (num_filters, stride, scope_name) in enumerate(group_specs):
                net = self.block_layer(inputs=net, filters=num_filters, block_fn=self.bottleneck_block,
                                       blocks=self.layers[group_idx], strides=stride, train_mode=False,
                                       name=scope_name)

            net = tf_utils.relu(net, name='before_flatten_relu', logger=None)

            # Classifier head: the feature map is flattened directly (no global pooling).
            net = tf_utils.flatten(net, name='flatten', logger=None)
            for units, fc_scope, relu_name in hidden_fc_specs:
                net = tf_utils.linear(net, units, name=fc_scope)
                net = tf_utils.relu(net, name=relu_name, logger=None)

            return tf_utils.linear(net, self.num_attribute, name='Out')
Beispiel #2
0
    def forward_network(self, img, padding='SAME', reuse=False):
        """Build the evaluation-mode graph of the U-Net style segmentation network.

        Encoder (stages 0-5): every stage has two 3x3 stride-1 convolutions.
        Stages 1-5 begin with a max-pool, and stages 4 and 5 end with dropout.
        Decoder (stages 6-10): every stage begins with a 2x2 ``tf_utils.deconv2d``
        layer. Its output is concatenated on axis 3 with an encoder feature map
        (stage 6 with stage 4, 7 with 3, 8 with 2, 9 with 1, 10 with 0), and two
        3x3 convolutions follow. A final 1x1 convolution produces
        ``self.num_classes`` channels.

        Every convolution/deconvolution except ``s0_conv1`` and the final
        ``output`` layer is followed by batch norm and ReLU. ``s0_conv1`` gets
        ReLU only, and ``output`` gets neither. All batch-norm layers are built
        with ``is_train=False``, and both dropouts with ``keep_prob=0.``.

        Args:
            img: Input tensor. It is presumably NHWC, since channels are
                concatenated on axis 3 and width is cropped on axis 2. Confirm
                against the caller.
            padding: Padding mode passed to every 3x3 convolution and to the
                final 1x1 convolution. The deconvolutions do not receive it.
            reuse: Forwarded to ``tf.compat.v1.variable_scope``.

        Returns:
            The output of the final 1x1 ``'output'`` convolution. No activation
            is applied to it here.
        """

        def conv3x3(x, out_dim, conv_name):
            # All 3x3 convolutions share stride 1 and He initialisation.
            return tf_utils.conv2d(x=x, output_dim=out_dim, k_h=3, k_w=3, d_h=1, d_w=1,
                                   padding=padding, initializer='He', name=conv_name)

        def batch_norm(x, norm_name):
            return tf_utils.norm(x, name=norm_name, _type='batch', _ops=self._ops, is_train=False)

        def conv_bn_relu(x, out_dim, conv_name, norm_name, relu_name):
            x = conv3x3(x, out_dim, conv_name)
            x = batch_norm(x, norm_name)
            return tf_utils.relu(x, name=relu_name)

        def up_bn_relu(x, out_dim, deconv_name, norm_name, relu_name):
            # 2x2 kernel; every other deconv2d argument is left at tf_utils' default.
            x = tf_utils.deconv2d(x=x, output_dim=out_dim, k_h=2, k_w=2, initializer='He', name=deconv_name)
            x = batch_norm(x, norm_name)
            return tf_utils.relu(x, name=relu_name)

        def skip_concat(decoder_feat, encoder_feat, concat_name):
            # Channel-axis concatenation of a decoder map with its encoder skip.
            return tf_utils.concat(values=[decoder_feat, encoder_feat], axis=3, name=concat_name)

        with tf.compat.v1.variable_scope(self.name, reuse=reuse):
            dims = self.conv_dims

            # ---------------- Encoder ----------------
            # Stage 0: its first convolution has no batch norm.
            enc0 = conv3x3(img, dims[0], 's0_conv1')
            enc0 = tf_utils.relu(enc0, name='relu_s0_conv1')
            enc0 = conv_bn_relu(enc0, dims[0], 's0_conv2', 's0_norm1', 'relu_s0_conv2')

            # Stage 1
            enc1 = tf_utils.max_pool(x=enc0, name='s1_maxpool2d')
            enc1 = conv_bn_relu(enc1, dims[0], 's1_conv1', 's1_norm0', 'relu_s1_conv1')
            enc1 = conv_bn_relu(enc1, dims[1], 's1_conv2', 's1_norm1', 'relu_s1_conv2')

            # Stage 2
            enc2 = tf_utils.max_pool(x=enc1, name='s2_maxpool2d')
            enc2 = conv_bn_relu(enc2, dims[2], 's2_conv1', 's2_norm0', 'relu_s2_conv1')
            enc2 = conv_bn_relu(enc2, dims[3], 's2_conv2', 's2_norm1', 'relu_s2_conv2')

            # Stage 3
            enc3 = tf_utils.max_pool(x=enc2, name='s3_maxpool2d')
            enc3 = conv_bn_relu(enc3, dims[4], 's3_conv1', 's3_norm0', 'relu_s3_conv1')
            enc3 = conv_bn_relu(enc3, dims[5], 's3_conv2', 's3_norm1', 'relu_s3_conv2')

            # Stage 4: ends with dropout; the dropped output is the stage-6 skip.
            enc4 = tf_utils.max_pool(x=enc3, name='s4_maxpool2d')
            enc4 = conv_bn_relu(enc4, dims[6], 's4_conv1', 's4_norm0', 'relu_s4_conv1')
            enc4 = conv_bn_relu(enc4, dims[7], 's4_conv2', 's4_norm1', 'relu_s4_conv2')
            # NOTE(review): keep_prob=0. is only sensible if tf_utils.dropout
            # treats this argument as a drop rate (0 => pass-through). Confirm.
            enc4 = tf_utils.dropout(x=enc4, keep_prob=0., name='s4_dropout')

            # Stage 5: also ends with dropout.
            enc5 = tf_utils.max_pool(x=enc4, name='s5_maxpool2d')
            enc5 = conv_bn_relu(enc5, dims[8], 's5_conv1', 's5_norm0', 'relu_s5_conv1')
            enc5 = conv_bn_relu(enc5, dims[9], 's5_conv2', 's5_norm1', 'relu_s5_conv2')
            enc5 = tf_utils.dropout(x=enc5, keep_prob=0., name='s5_dropout')

            # ---------------- Decoder ----------------
            # Stage 6: the deconvolved map may be wider than the stage-4 skip.
            # Keep only its first `skip_w` columns (axis 2) and drop the rest.
            dec6 = up_bn_relu(enc5, dims[10], 's6_deconv1', 's6_norm0', 'relu_s6_deconv1')
            skip_w = enc4.get_shape().as_list()[2]
            extra_w = dec6.get_shape().as_list()[2] - enc4.get_shape().as_list()[2]
            dec6, _ = tf.split(dec6, num_or_size_splits=[skip_w, extra_w], axis=2, name='axis2_split')
            tf_utils.print_activations(dec6)
            dec6 = skip_concat(dec6, enc4, 's6_axis3_concat')
            dec6 = conv_bn_relu(dec6, dims[11], 's6_conv2', 's6_norm1', 'relu_s6_conv2')
            dec6 = conv_bn_relu(dec6, dims[12], 's6_conv3', 's6_norm2', 'relu_s6_conv3')

            # Stage 7
            dec7 = up_bn_relu(dec6, dims[13], 's7_deconv1', 's7_norm0', 'relu_s7_deconv1')
            dec7 = skip_concat(dec7, enc3, 's7_axis3_concat')
            dec7 = conv_bn_relu(dec7, dims[14], 's7_conv2', 's7_norm1', 'relu_s7_conv2')
            dec7 = conv_bn_relu(dec7, dims[15], 's7_conv3', 's7_norm2', 'relu_s7_conv3')

            # Stage 8
            dec8 = up_bn_relu(dec7, dims[16], 's8_deconv1', 's8_norm0', 'relu_s8_deconv1')
            dec8 = skip_concat(dec8, enc2, 's8_axis3_concat')
            dec8 = conv_bn_relu(dec8, dims[17], 's8_conv2', 's8_norm1', 'relu_s8_conv2')
            # NOTE(review): this ReLU is named 'relu_conv3', not 'relu_s8_conv3' as
            # the pattern suggests. It is kept because the op name is part of the graph.
            dec8 = conv_bn_relu(dec8, dims[18], 's8_conv3', 's8_norm2', 'relu_conv3')

            # Stage 9
            dec9 = up_bn_relu(dec8, dims[19], 's9_deconv1', 's9_norm0', 'relu_s9_deconv1')
            dec9 = skip_concat(dec9, enc1, 's9_axis3_concat')
            dec9 = conv_bn_relu(dec9, dims[20], 's9_conv2', 's9_norm1', 'relu_s9_conv2')
            dec9 = conv_bn_relu(dec9, dims[21], 's9_conv3', 's9_norm2', 'relu_s9_conv3')

            # Stage 10: concatenated with the stage-0 features. All three layers use conv_dims[-1].
            dec10 = up_bn_relu(dec9, dims[-1], 's10_deconv1', 's10_norm0', 'relu_s10_deconv1')
            dec10 = skip_concat(dec10, enc0, 's10_axis3_concat')
            dec10 = conv_bn_relu(dec10, dims[-1], 's10_conv2', 's10_norm1', 'relu_s10_conv2')
            dec10 = conv_bn_relu(dec10, dims[-1], 's10_conv3', 's10_norm2', 'relu_s10_conv3')

            # 1x1 convolution to self.num_classes channels, with no norm or activation.
            return tf_utils.conv2d(dec10, output_dim=self.num_classes, k_h=1, k_w=1, d_h=1, d_w=1,
                                   padding=padding, initializer='He', name='output')
    def forward_network(self, inputImg, padding='SAME', reuse=False):
        with tf.compat.v1.variable_scope(self.name, reuse=reuse):
            # This part is for compatible between input size [640, 400] and [320, 200]
            if self.resize_factor == 1.0:
                # Stage 0
                tf_utils.print_activations(inputImg, logger=self.logger)
                s0_conv1 = tf_utils.conv2d(x=inputImg, output_dim=self.conv_dims[0], k_h=3, k_w=3, d_h=1, d_w=1,
                                           padding=padding, initializer='He', name='s0_conv1', logger=self.logger)
                s0_conv1 = tf_utils.relu(s0_conv1, name='relu_s0_conv1', logger=self.logger)

                s0_conv2 = tf_utils.conv2d(x=s0_conv1, output_dim=self.conv_dims[0], k_h=3, k_w=3, d_h=1, d_w=1,
                                           padding=padding, initializer='He', name='s0_conv2', logger=self.logger)
                if self.use_batch_norm:
                    s0_conv2 = tf_utils.norm(s0_conv2, name='s0_norm1', _type='batch', _ops=self._ops,
                                             is_train=self.trainMode, logger=self.logger)
                s0_conv2 = tf_utils.relu(s0_conv2, name='relu_s0_conv2', logger=self.logger)

                # Stage 1
                s1_maxpool = tf_utils.max_pool(x=s0_conv2, name='s1_maxpool2d', logger=self.logger)

                s1_conv1 = tf_utils.conv2d(x=s1_maxpool, output_dim=self.conv_dims[0], k_h=3, k_w=3, d_h=1, d_w=1,
                                           padding=padding, initializer='He', name='s1_conv1', logger=self.logger)
                if self.use_batch_norm:
                    s1_conv1 = tf_utils.norm(s1_conv1, name='s1_norm0', _type='batch', _ops=self._ops,
                                             is_train=self.trainMode, logger=self.logger)
                s1_conv1 = tf_utils.relu(s1_conv1, name='relu_s1_conv1', logger=self.logger)

                s1_conv2 = tf_utils.conv2d(x=s1_conv1, output_dim=self.conv_dims[1], k_h=3, k_w=3, d_h=1, d_w=1,
                                           padding=padding, initializer='He', name='s1_conv2', logger=self.logger)
                if self.use_batch_norm:
                    s1_conv2 = tf_utils.norm(s1_conv2, name='s1_norm1', _type='batch', _ops=self._ops,
                                             is_train=self.trainMode, logger=self.logger)
                s1_conv2 = tf_utils.relu(s1_conv2, name='relu_s1_conv2', logger=self.logger)
            else:
                # Stage 1
                tf_utils.print_activations(inputImg, logger=self.logger)
                s1_conv1 = tf_utils.conv2d(x=inputImg, output_dim=self.conv_dims[0], k_h=3, k_w=3, d_h=1, d_w=1,
                                           padding=padding, initializer='He', name='s1_conv1', logger=self.logger)
                s1_conv1 = tf_utils.relu(s1_conv1, name='relu_s1_conv1', logger=self.logger)

                s1_conv2 = tf_utils.conv2d(x=s1_conv1, output_dim=self.conv_dims[1], k_h=3, k_w=3, d_h=1, d_w=1,
                                           padding=padding, initializer='He', name='s1_conv2', logger=self.logger)
                if self.use_batch_norm:
                    s1_conv2 = tf_utils.norm(s1_conv2, name='s1_norm1', _type='batch', _ops=self._ops,
                                             is_train=self.trainMode, logger=self.logger)
                s1_conv2 = tf_utils.relu(s1_conv2, name='relu_s1_conv2', logger=self.logger)

            # Stage 2
            s2_maxpool = tf_utils.max_pool(x=s1_conv2, name='s2_maxpool2d', logger=self.logger)
            s2_conv1 = tf_utils.conv2d(x=s2_maxpool, output_dim=self.conv_dims[2], k_h=3, k_w=3, d_h=1, d_w=1,
                                       padding=padding, initializer='He', name='s2_conv1', logger=self.logger)
            if self.use_batch_norm:
                s2_conv1 = tf_utils.norm(s2_conv1, name='s2_norm0', _type='batch', _ops=self._ops,
                                         is_train=self.trainMode, logger=self.logger)
            s2_conv1 = tf_utils.relu(s2_conv1, name='relu_s2_conv1', logger=self.logger)

            s2_conv2 = tf_utils.conv2d(x=s2_conv1, output_dim=self.conv_dims[3], k_h=3, k_w=3, d_h=1, d_w=1,
                                       padding=padding, initializer='He', name='s2_conv2', logger=self.logger)
            if self.use_batch_norm:
                s2_conv2 = tf_utils.norm(s2_conv2, name='s2_norm1', _type='batch', _ops=self._ops,
                                         is_train=self.trainMode, logger=self.logger)
            s2_conv2 = tf_utils.relu(s2_conv2, name='relu_s2_conv2', logger=self.logger)

            # Stage 3
            s3_maxpool = tf_utils.max_pool(x=s2_conv2, name='s3_maxpool2d', logger=self.logger)
            s3_conv1 = tf_utils.conv2d(x=s3_maxpool, output_dim=self.conv_dims[4], k_h=3, k_w=3, d_h=1, d_w=1,
                                       padding=padding, initializer='He', name='s3_conv1', logger=self.logger)
            if self.use_batch_norm:
                s3_conv1 = tf_utils.norm(s3_conv1, name='s3_norm0', _type='batch', _ops=self._ops,
                                         is_train=self.trainMode, logger=self.logger)
            s3_conv1 = tf_utils.relu(s3_conv1, name='relu_s3_conv1', logger=self.logger)

            s3_conv2 = tf_utils.conv2d(x=s3_conv1, output_dim=self.conv_dims[5], k_h=3, k_w=3, d_h=1, d_w=1,
                                       padding=padding, initializer='He', name='s3_conv2', logger=self.logger)
            if self.use_batch_norm:
                s3_conv2 = tf_utils.norm(s3_conv2, name='s3_norm1', _type='batch', _ops=self._ops,
                                         is_train=self.trainMode, logger=self.logger)
            s3_conv2 = tf_utils.relu(s3_conv2, name='relu_s3_conv2', logger=self.logger)

            # Stage 4
            s4_maxpool = tf_utils.max_pool(x=s3_conv2, name='s4_maxpool2d', logger=self.logger)
            s4_conv1 = tf_utils.conv2d(x=s4_maxpool, output_dim=self.conv_dims[6], k_h=3, k_w=3, d_h=1, d_w=1,
                                       padding=padding, initializer='He', name='s4_conv1', logger=self.logger)
            if self.use_batch_norm:
                s4_conv1 = tf_utils.norm(s4_conv1, name='s4_norm0', _type='batch', _ops=self._ops,
                                         is_train=self.trainMode, logger=self.logger)
            s4_conv1 = tf_utils.relu(s4_conv1, name='relu_s4_conv1', logger=self.logger)

            s4_conv2 = tf_utils.conv2d(x=s4_conv1, output_dim=self.conv_dims[7], k_h=3, k_w=3, d_h=1, d_w=1,
                                       padding=padding, initializer='He', name='s4_conv2', logger=self.logger)
            if self.use_batch_norm:
                s4_conv2 = tf_utils.norm(s4_conv2, name='s4_norm1', _type='batch', _ops=self._ops,
                                         is_train=self.trainMode, logger=self.logger)
            s4_conv2 = tf_utils.relu(s4_conv2, name='relu_s4_conv2', logger=self.logger)
            s4_conv2_drop = tf_utils.dropout(x=s4_conv2, keep_prob=self.ratePh, name='s4_dropout',
                                             logger=self.logger)

            # Stage 5
            s5_maxpool = tf_utils.max_pool(x=s4_conv2_drop, name='s5_maxpool2d', logger=self.logger)
            s5_conv1 = tf_utils.conv2d(x=s5_maxpool, output_dim=self.conv_dims[8], k_h=3, k_w=3, d_h=1, d_w=1,
                                       padding=padding, initializer='He', name='s5_conv1', logger=self.logger)
            if self.use_batch_norm:
                s5_conv1 = tf_utils.norm(s5_conv1, name='s5_norm0', _type='batch', _ops=self._ops,
                                         is_train=self.trainMode, logger=self.logger)
            s5_conv1 = tf_utils.relu(s5_conv1, name='relu_s5_conv1', logger=self.logger)

            s5_conv2 = tf_utils.conv2d(x=s5_conv1, output_dim=self.conv_dims[9], k_h=3, k_w=3, d_h=1, d_w=1,
                                       padding=padding, initializer='He', name='s5_conv2', logger=self.logger)
            if self.use_batch_norm:
                s5_conv2 = tf_utils.norm(s5_conv2, name='s5_norm1', _type='batch', _ops=self._ops,
                                         is_train=self.trainMode, logger=self.logger)
            s5_conv2 = tf_utils.relu(s5_conv2, name='relu_s5_conv2', logger=self.logger)
            s5_conv2_drop = tf_utils.dropout(x=s5_conv2, keep_prob=self.ratePh, name='s5_dropout',
                                             logger=self.logger)

            # Stage 6
            s6_deconv1 = tf_utils.deconv2d(x=s5_conv2_drop, output_dim=self.conv_dims[10], k_h=2, k_w=2,
                                           initializer='He', name='s6_deconv1', logger=self.logger)
            if self.use_batch_norm:
                s6_deconv1 = tf_utils.norm(s6_deconv1, name='s6_norm0', _type='batch', _ops=self._ops,
                                         is_train=self.trainMode, logger=self.logger)
            s6_deconv1 = tf_utils.relu(s6_deconv1, name='relu_s6_deconv1', logger=self.logger)
            # Cropping
            w1 = s4_conv2_drop.get_shape().as_list()[2]
            w2 = s6_deconv1.get_shape().as_list()[2] - s4_conv2_drop.get_shape().as_list()[2]
            s6_deconv1_split, _ = tf.split(s6_deconv1, num_or_size_splits=[w1, w2], axis=2, name='axis2_split')
            tf_utils.print_activations(s6_deconv1_split, logger=self.logger)
            # Concat
            s6_concat = tf_utils.concat(values=[s6_deconv1_split, s4_conv2_drop], axis=3, name='s6_axis3_concat',
                                        logger=self.logger)

            s6_conv2 = tf_utils.conv2d(x=s6_concat, output_dim=self.conv_dims[11], k_h=3, k_w=3, d_h=1, d_w=1,
                                       padding=padding, initializer='He', name='s6_conv2', logger=self.logger)
            if self.use_batch_norm:
                s6_conv2 = tf_utils.norm(s6_conv2, name='s6_norm1', _type='batch', _ops=self._ops,
                                         is_train=self.trainMode, logger=self.logger)
            s6_conv2 = tf_utils.relu(s6_conv2, name='relu_s6_conv2', logger=self.logger)

            s6_conv3 = tf_utils.conv2d(x=s6_conv2, output_dim=self.conv_dims[12], k_h=3, k_w=3, d_h=1, d_w=1,
                                       padding=padding, initializer='He', name='s6_conv3', logger=self.logger)
            if self.use_batch_norm:
                s6_conv3 = tf_utils.norm(s6_conv3, name='s6_norm2', _type='batch', _ops=self._ops,
                                         is_train=self.trainMode, logger=self.logger)
            s6_conv3 = tf_utils.relu(s6_conv3, name='relu_s6_conv3', logger=self.logger)

            # Stage 7
            s7_deconv1 = tf_utils.deconv2d(x=s6_conv3, output_dim=self.conv_dims[13], k_h=2, k_w=2, initializer='He',
                                           name='s7_deconv1', logger=self.logger)
            if self.use_batch_norm:
                s7_deconv1 = tf_utils.norm(s7_deconv1, name='s7_norm0', _type='batch', _ops=self._ops,
                                         is_train=self.trainMode, logger=self.logger)
            s7_deconv1 = tf_utils.relu(s7_deconv1, name='relu_s7_deconv1', logger=self.logger)
            # Concat
            s7_concat = tf_utils.concat(values=[s7_deconv1, s3_conv2], axis=3, name='s7_axis3_concat',
                                        logger=self.logger)

            s7_conv2 = tf_utils.conv2d(x=s7_concat, output_dim=self.conv_dims[14], k_h=3, k_w=3, d_h=1, d_w=1,
                                       padding=padding, initializer='He', name='s7_conv2', logger=self.logger)
            if self.use_batch_norm:
                s7_conv2 = tf_utils.norm(s7_conv2, name='s7_norm1', _type='batch', _ops=self._ops,
                                         is_train=self.trainMode, logger=self.logger)
            s7_conv2 = tf_utils.relu(s7_conv2, name='relu_s7_conv2', logger=self.logger)

            s7_conv3 = tf_utils.conv2d(x=s7_conv2, output_dim=self.conv_dims[15], k_h=3, k_w=3, d_h=1, d_w=1,
                                       padding=padding, initializer='He', name='s7_conv3', logger=self.logger)
            if self.use_batch_norm:
                s7_conv3 = tf_utils.norm(s7_conv3, name='s7_norm2', _type='batch', _ops=self._ops,
                                         is_train=self.trainMode, logger=self.logger)
            s7_conv3 = tf_utils.relu(s7_conv3, name='relu_s7_conv3', logger=self.logger)

            # Stage 8
            s8_deconv1 = tf_utils.deconv2d(x=s7_conv3, output_dim=self.conv_dims[16], k_h=2, k_w=2, initializer='He',
                                           name='s8_deconv1', logger=self.logger)
            if self.use_batch_norm:
                s8_deconv1 = tf_utils.norm(s8_deconv1, name='s8_norm0', _type='batch', _ops=self._ops,
                                         is_train=self.trainMode, logger=self.logger)
            s8_deconv1 = tf_utils.relu(s8_deconv1, name='relu_s8_deconv1', logger=self.logger)
            # Concat
            s8_concat = tf_utils.concat(values=[s8_deconv1,s2_conv2], axis=3, name='s8_axis3_concat',
                                        logger=self.logger)

            s8_conv2 = tf_utils.conv2d(x=s8_concat, output_dim=self.conv_dims[17], k_h=3, k_w=3, d_h=1, d_w=1,
                                       padding=padding, initializer='He', name='s8_conv2', logger=self.logger)
            if self.use_batch_norm:
                s8_conv2 = tf_utils.norm(s8_conv2, name='s8_norm1', _type='batch', _ops=self._ops,
                                         is_train=self.trainMode, logger=self.logger)
            s8_conv2 = tf_utils.relu(s8_conv2, name='relu_s8_conv2', logger=self.logger)

            s8_conv3 = tf_utils.conv2d(x=s8_conv2, output_dim=self.conv_dims[18], k_h=3, k_w=3, d_h=1, d_w=1,
                                       padding=padding, initializer='He', name='s8_conv3', logger=self.logger)
            if self.use_batch_norm:
                s8_conv3 = tf_utils.norm(s8_conv3, name='s8_norm2', _type='batch', _ops=self._ops,
                                         is_train=self.trainMode, logger=self.logger)
            s8_conv3 = tf_utils.relu(s8_conv3, name='relu_conv3', logger=self.logger)

            # Stage 9
            s9_deconv1 = tf_utils.deconv2d(x=s8_conv3, output_dim=self.conv_dims[19], k_h=2, k_w=2,
                                           initializer='He', name='s9_deconv1', logger=self.logger)
            if self.use_batch_norm:
                s9_deconv1 = tf_utils.norm(s9_deconv1, name='s9_norm0', _type='batch', _ops=self._ops,
                                         is_train=self.trainMode, logger=self.logger)
            s9_deconv1 = tf_utils.relu(s9_deconv1, name='relu_s9_deconv1', logger=self.logger)
            # Concat
            s9_concat = tf_utils.concat(values=[s9_deconv1, s1_conv2], axis=3, name='s9_axis3_concat',
                                        logger=self.logger)

            s9_conv2 = tf_utils.conv2d(x=s9_concat, output_dim=self.conv_dims[20], k_h=3, k_w=3, d_h=1, d_w=1,
                                       padding=padding, initializer='He', name='s9_conv2', logger=self.logger)
            if self.use_batch_norm:
                s9_conv2 = tf_utils.norm(s9_conv2, name='s9_norm1', _type='batch', _ops=self._ops,
                                         is_train=self.trainMode, logger=self.logger)
            s9_conv2 = tf_utils.relu(s9_conv2, name='relu_s9_conv2', logger=self.logger)

            s9_conv3 = tf_utils.conv2d(x=s9_conv2, output_dim=self.conv_dims[21], k_h=3, k_w=3, d_h=1, d_w=1,
                                       padding=padding, initializer='He', name='s9_conv3', logger=self.logger)
            if self.use_batch_norm:
                s9_conv3 = tf_utils.norm(s9_conv3, name='s9_norm2', _type='batch', _ops=self._ops,
                                         is_train=self.trainMode, logger=self.logger)
            s9_conv3 = tf_utils.relu(s9_conv3, name='relu_s9_conv3', logger=self.logger)

            if self.resize_factor == 1.0:
                s10_deconv1 = tf_utils.deconv2d(x=s9_conv3, output_dim=self.conv_dims[-1], k_h=2, k_w=2,
                                                initializer='He', name='s10_deconv1', logger=self.logger)
                if self.use_batch_norm:
                    s10_deconv1 = tf_utils.norm(s10_deconv1, name='s10_norm0', _type='batch', _ops=self._ops,
                                             is_train=self.trainMode, logger=self.logger)
                s10_deconv1 = tf_utils.relu(s10_deconv1, name='relu_s10_deconv1', logger=self.logger)
                # Concat
                s10_concat = tf_utils.concat(values=[s10_deconv1, s0_conv2], axis=3, name='s10_axis3_concat',
                                             logger=self.logger)

                s10_conv2 = tf_utils.conv2d(s10_concat, output_dim=self.conv_dims[-1], k_h=3, k_w=3, d_h=1, d_w=1,
                                            padding=padding, initializer='He', name='s10_conv2', logger=self.logger)
                if self.use_batch_norm:
                    s10_conv2 = tf_utils.norm(s10_conv2, name='s10_norm1', _type='batch', _ops=self._ops,
                                             is_train=self.trainMode, logger=self.logger)
                s10_conv2 = tf_utils.relu(s10_conv2, name='relu_s10_conv2', logger=self.logger)

                s10_conv3 = tf_utils.conv2d(x=s10_conv2, output_dim=self.conv_dims[-1], k_h=3, k_w=3, d_h=1, d_w=1,
                                            padding=padding, initializer='He', name='s10_conv3', logger=self.logger)
                if self.use_batch_norm:
                    s10_conv3 = tf_utils.norm(s10_conv3, name='s10_norm2', _type='batch', _ops=self._ops,
                                             is_train=self.trainMode, logger=self.logger)
                s10_conv3 = tf_utils.relu(s10_conv3, name='relu_s10_conv3', logger=self.logger)

                output = tf_utils.conv2d(s10_conv3, output_dim=self.numClasses, k_h=1, k_w=1, d_h=1, d_w=1,
                                         padding=padding, initializer='He', name='output', logger=self.logger)
            else:
                output = tf_utils.conv2d(s9_conv3, output_dim=self.numClasses, k_h=1, k_w=1, d_h=1, d_w=1,
                                         padding=padding, initializer='He', name='output', logger=self.logger)

            return output
Beispiel #4
0
    def u_net(self):
        """Build the unpadded U-Net graph on ``self.inp_img`` and store the logits in ``self.pred``.

        Contracting path (stages 1-5): max-pool (except stage 1) followed by two
        3x3 'VALID' convolutions, each with ReLU; dropout with
        ``self.keep_prob`` is applied after stages 4 and 5.

        Expanding path (stages 6-9): a 2x2 ``tf_utils.deconv2d`` with ReLU.
        Then the matching encoder feature map is center-cropped to the
        deconvolved height/width and concatenated on axis 3. Two 3x3 'VALID'
        convolutions with ReLU follow. Because 'VALID' padding shrinks every
        feature map, the crop is needed before concatenation.

        A final 1x1 'SAME' convolution with no activation produces
        ``self.conv_dims[22]`` output channels.

        Reads ``self.inp_img``, ``self.conv_dims`` (indices 0-22),
        ``self.keep_prob`` and ``self.logger``.  Sets ``self.pred``; returns None.
        Tensors are handled as (batch, height, width, channels), as implied by
        ``tf.image.crop_to_bounding_box`` and the axis-3 concatenation.
        """

        def conv_relu(x, dim_idx, name, relu_name=None):
            # 3x3 stride-1 'VALID' convolution (He init) followed by a ReLU.
            x = tf_utils.conv2d(x=x,
                                output_dim=self.conv_dims[dim_idx],
                                k_h=3,
                                k_w=3,
                                d_h=1,
                                d_w=1,
                                padding='VALID',
                                initializer='He',
                                name=name,
                                logger=self.logger)
            return tf_utils.relu(x,
                                 name=relu_name if relu_name is not None else 'relu_' + name,
                                 logger=self.logger)

        def deconv_relu(x, dim_idx, name):
            # 2x2 tf_utils.deconv2d (He init) followed by a ReLU.
            x = tf_utils.deconv2d(x=x,
                                  output_dim=self.conv_dims[dim_idx],
                                  k_h=2,
                                  k_w=2,
                                  initializer='He',
                                  name=name,
                                  logger=self.logger)
            return tf_utils.relu(x,
                                 name='relu_' + name,
                                 logger=self.logger)

        def crop_concat(skip, up, concat_name):
            # Center-crop encoder map `skip` to the spatial size of `up`, then
            # concatenate [cropped skip, up] along the channel axis.
            h1, w1 = skip.get_shape().as_list()[1:3]
            h2, w2 = up.get_shape().as_list()[1:3]
            skip_crop = tf.image.crop_to_bounding_box(
                image=skip,
                offset_height=(h1 - h2) // 2,
                offset_width=(w1 - w2) // 2,
                target_height=h2,
                target_width=w2)
            tf_utils.print_activations(skip_crop, logger=self.logger)
            return tf_utils.concat(values=[skip_crop, up],
                                   axis=3,
                                   name=concat_name,
                                   logger=self.logger)

        # Stage 1
        tf_utils.print_activations(self.inp_img, logger=self.logger)
        s1_conv1 = conv_relu(self.inp_img, 0, 's1_conv1')
        s1_conv2 = conv_relu(s1_conv1, 1, 's1_conv2')

        # Stage 2
        s2_maxpool = tf_utils.max_pool(x=s1_conv2, name='s2_maxpool', logger=self.logger)
        s2_conv1 = conv_relu(s2_maxpool, 2, 's2_conv1')
        s2_conv2 = conv_relu(s2_conv1, 3, 's2_conv2')

        # Stage 3
        s3_maxpool = tf_utils.max_pool(x=s2_conv2, name='s3_maxpool', logger=self.logger)
        s3_conv1 = conv_relu(s3_maxpool, 4, 's3_conv1')
        s3_conv2 = conv_relu(s3_conv1, 5, 's3_conv2')

        # Stage 4 (output passed through dropout; the dropped tensor is also the stage-6 skip input)
        s4_maxpool = tf_utils.max_pool(x=s3_conv2, name='s4_maxpool', logger=self.logger)
        s4_conv1 = conv_relu(s4_maxpool, 6, 's4_conv1')
        s4_conv2 = conv_relu(s4_conv1, 7, 's4_conv2')
        s4_conv2_drop = tf_utils.dropout(x=s4_conv2,
                                         keep_prob=self.keep_prob,
                                         name='s4_conv2_dropout',
                                         logger=self.logger)

        # Stage 5 (bottom of the "U")
        s5_maxpool = tf_utils.max_pool(x=s4_conv2_drop, name='s5_maxpool', logger=self.logger)
        s5_conv1 = conv_relu(s5_maxpool, 8, 's5_conv1')
        s5_conv2 = conv_relu(s5_conv1, 9, 's5_conv2')
        s5_conv2_drop = tf_utils.dropout(x=s5_conv2,
                                         keep_prob=self.keep_prob,
                                         name='s5_conv2_dropout',
                                         logger=self.logger)

        # Stage 6
        s6_deconv1 = deconv_relu(s5_conv2_drop, 10, 's6_deconv1')
        s6_concat = crop_concat(s4_conv2_drop, s6_deconv1, 's6_concat')
        s6_conv2 = conv_relu(s6_concat, 11, 's6_conv2')
        s6_conv3 = conv_relu(s6_conv2, 12, 's6_conv3')

        # Stage 7
        s7_deconv1 = deconv_relu(s6_conv3, 13, 's7_deconv1')
        s7_concat = crop_concat(s3_conv2, s7_deconv1, 's7_concat')
        s7_conv2 = conv_relu(s7_concat, 14, 's7_conv2')
        s7_conv3 = conv_relu(s7_conv2, 15, 's7_conv3')

        # Stage 8
        s8_deconv1 = deconv_relu(s7_conv3, 16, 's8_deconv1')
        s8_concat = crop_concat(s2_conv2, s8_deconv1, 's8_concat')
        s8_conv2 = conv_relu(s8_concat, 17, 's8_conv2')
        # NOTE(review): this ReLU has historically been named 'relu_conv3' rather than
        # 'relu_s8_conv3'; kept unchanged so existing graph/tensor names stay stable.
        s8_conv3 = conv_relu(s8_conv2, 18, 's8_conv3', relu_name='relu_conv3')

        # Stage 9
        s9_deconv1 = deconv_relu(s8_conv3, 19, 's9_deconv1')
        s9_concat = crop_concat(s1_conv2, s9_deconv1, 's9_concat')
        s9_conv2 = conv_relu(s9_concat, 20, 's9_conv2')
        s9_conv3 = conv_relu(s9_conv2, 21, 's9_conv3')

        # 1x1 output convolution; no activation is applied, so self.pred holds raw logits.
        self.pred = tf_utils.conv2d(x=s9_conv3,
                                    output_dim=self.conv_dims[22],
                                    k_h=1,
                                    k_w=1,
                                    d_h=1,
                                    d_w=1,
                                    padding='SAME',
                                    initializer='He',
                                    name='output',
                                    logger=self.logger)
    def __init__(self, input_dim=(32, 32, 3), output_dim=(128, 256, 512, 1000, 10), optimizer=None, use_dropout=True,
                 lr=0.001, weight_decay=1e-4, random_seed=123, is_train=True, log_dir=None, name=None):
        """Build the TF1 graph for a small CNN classifier.

        Architecture: optional input dropout, three stages of
        conv2d(5x5, stride 1) -> max_pool(3x3, stride 2) -> ReLU, a flattened
        fully-connected hidden layer (optional dropout, then ReLU) and a final
        linear layer producing ``self.y_pred``.

        Args:
            input_dim: shape of a single input sample without the batch axis; the
                input placeholder is ``[None, *input_dim]``. Any sequence works
                (the default is a tuple so no mutable default is shared).
            output_dim: layer widths. ``output_dim[0]``..``output_dim[2]`` are the
                filter counts of the three conv stages, ``output_dim[-2]`` is the
                width of the hidden fully-connected layer and ``output_dim[-1]`` is
                the number of output units.
            optimizer: forwarded to ``optimizer_fn`` to build ``self.train_op``.
            use_dropout: if True, dropout driven by the ``keep_prob`` placeholder is
                applied to the input and after the hidden fully-connected layer.
            lr: learning rate forwarded to ``optimizer_fn``.
            weight_decay: coefficient of the L2 penalty (sum of ``tf.nn.l2_loss``)
                over the trainable variables created under this model's scope.
            random_seed: seed passed to the dropout layers.
            is_train: passed to ``utils.init_logger``; per-layer shape logging is
                only enabled when True.
            log_dir: passed to ``utils.init_logger``.
            name: variable-scope and logger name of this model.
        """
        self.name = name
        self.is_train = is_train
        self.log_dir = log_dir
        self.cur_lr = None
        self.logger, self.file_handler, self.stream_handler = utils.init_logger(log_dir=self.log_dir,
                                                                                name=self.name,
                                                                                is_train=self.is_train)
        # Layer shapes are only logged when the model is built for training.
        layer_logger = self.logger if self.is_train else None

        with tf.variable_scope(self.name):
            # Placeholders for inputs
            self.X = tf.placeholder(dtype=tf.float32, shape=[None, *input_dim], name='X')
            # NOTE(review): labels look one-hot (class index recovered via argmax) - confirm with the data feeder.
            self.y = tf.placeholder(dtype=tf.float32, shape=[None, output_dim[-1]], name='y')
            self.y_cls = tf.math.argmax(input=self.y, axis=1)
            self.keep_prob = tf.placeholder(tf.float32, name='keep_prob')
            tf_utils.print_activations(self.X, logger=layer_logger)

            # Placeholders for TensorBoard
            self.train_acc = tf.placeholder(tf.float32, name='train_acc')
            self.val_acc = tf.placeholder(tf.float32, name='val_acc')

            # Convolutional layers
            net = self.X
            if use_dropout:
                net = tf_utils.dropout(x=net,
                                       keep_prob=self.keep_prob,
                                       seed=random_seed,
                                       name='dropout_input',
                                       logger=layer_logger)

            for idx in range(3):
                net = tf_utils.conv2d(x=net,
                                      output_dim=output_dim[idx],
                                      k_h=5,
                                      k_w=5,
                                      d_h=1,
                                      d_w=1,
                                      name='conv2d'+str(idx),
                                      logger=layer_logger)
                net = tf_utils.max_pool(x=net,
                                        ksize=[1, 3, 3, 1],
                                        strides=[1, 2, 2, 1],
                                        name='maxpool'+str(idx),
                                        logger=layer_logger)
                net = tf_utils.relu(x=net,
                                    name='relu'+str(idx),
                                    is_print=True,
                                    logger=layer_logger)

            # Fully connected layers, flatten first
            net = tf_utils.flatten(x=net,
                                   name='fc2_flatten',
                                   logger=layer_logger)

            net = tf_utils.linear(x=net,
                                  output_size=output_dim[-2],
                                  name='fc3',
                                  logger=layer_logger)

            if use_dropout:
                net = tf_utils.dropout(x=net,
                                       keep_prob=self.keep_prob,
                                       seed=random_seed,
                                       name='dropout3',
                                       logger=layer_logger)

            net = tf_utils.relu(x=net,
                                name='relu3',
                                logger=layer_logger)

            # Last predict layer
            self.y_pred = tf_utils.linear(x=net,
                                          output_size=output_dim[-1],
                                          name='last_fc',
                                          logger=layer_logger)

            # Loss = data loss + regularization term
            # NOTE(review): sigmoid cross-entropy scores every output independently, while the
            # accuracy below treats the task as single-label (argmax). Confirm this pairing is
            # intended rather than softmax cross-entropy.
            self.data_loss = tf.math.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(logits=self.y_pred, labels=self.y))
            # Penalise only variables created under this model's own variable scope. An unscoped
            # collection also contains the trainable variables of any other model already built in
            # the same graph, silently adding their L2 norm to this loss. Using the live scope name
            # keeps this correct when the model is nested in an outer scope. The trailing '/' stops
            # scope 'net' from also matching a sibling scope such as 'net2' (get_collection uses re.match).
            scope_prefix = tf.get_variable_scope().name + '/'
            model_weights = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope_prefix)
            self.reg_term = weight_decay * tf.reduce_sum([tf.nn.l2_loss(weight) for weight in model_weights])
            self.loss = self.data_loss + self.reg_term

            # Optimizer
            self.train_op, self.cur_lr = optimizer_fn(optimizer, lr=lr, loss=self.loss, name=self.name)

            # Accuracy (in percent): fraction of samples whose argmax prediction matches the argmax label
            self.y_pred_cls = tf.math.argmax(input=self.y_pred, axis=1)
            correct_prediction = tf.math.equal(self.y_pred_cls, self.y_cls)
            self.accuracy = tf.reduce_mean(tf.cast(correct_prediction, dtype=tf.float32)) * 100.

        self._tensorboard()
        tf_utils.show_all_variables(logger=layer_logger)
    def forward_network(self, input_img, padding='SAME', reuse=False):
        """Build the bottleneck residual network on ``input_img`` and return its logits.

        Args:
            input_img: image tensor fed to the first convolution.
            padding: kept for interface compatibility; this body never reads it.
            reuse: forwarded to ``tf.compat.v1.variable_scope`` for ``self.name``.

        Returns:
            Output of the final ``tf_utils.linear`` layer with ``self.num_classes`` units.
        """
        # One entry per residual stage: (layer name, filters, stride passed to block_layer).
        # Stage i receives self.layers[i] as its block count.
        stage_specs = (
            ('block_layer1', 64, 1),
            ('block_layer2', 128, 2),
            ('block_layer3', 256, 2),
            ('block_layer4', 512, 2),
        )

        with tf.compat.v1.variable_scope(self.name, reuse=reuse):
            tf_utils.print_activations(input_img, logger=self.logger)

            # Stem: 7x7 convolution with stride 2, then 3x3 max pooling with stride 2.
            net = self.conv2d_fixed_padding(inputs=input_img, filters=64, kernel_size=7, strides=2, name='conv1')
            net = tf_utils.max_pool(net, name='3x3_maxpool', ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1],
                                    logger=self.logger)

            for stage_idx, (stage_name, n_filters, stage_stride) in enumerate(stage_specs):
                net = self.block_layer(inputs=net,
                                       filters=n_filters,
                                       block_fn=self.bottleneck_block,
                                       blocks=self.layers[stage_idx],
                                       strides=stage_stride,
                                       train_mode=self.train_mode,
                                       name=stage_name)

            # Final batch norm + ReLU before pooling.
            net = tf_utils.norm(net, name='before_gap_batch_norm', _type='batch', _ops=self._ops,
                                is_train=self.train_mode, logger=self.logger)
            net = tf_utils.relu(net, name='before_gap_relu', logger=self.logger)

            # Pool with a window as large as the (static) spatial size of the feature map.
            # NOTE(review): this is a global average pool only if tf_utils.avg_pool uses VALID
            # padding - confirm.
            _, height, width, _ = net.get_shape().as_list()
            net = tf_utils.avg_pool(net, name='gap', ksize=[1, height, width, 1], strides=[1, 1, 1, 1],
                                    logger=self.logger)

            net = tf_utils.flatten(net, name='flatten', logger=self.logger)
            return tf_utils.linear(net, self.num_classes, name='logits')