示例#1
0
    def inception_block(
            x: TensorType,
            specs: Tuple,
            channel_axis: int,
            name: str,
            weight_suffix: Optional[str] = "weights",
            conv_suffix: Optional[str] = "",
            norm_suffix: Optional[str] = "/BatchNorm") -> TensorType:
        """Build one Inception mixed block from four parallel branches.

        The branches computed from ``x`` are:

        * Branch_0: 1x1 conv.
        * Branch_1: 1x1 conv followed by a 3x3 conv.
        * Branch_2: 1x1 conv followed by a 3x3 conv.
        * Branch_3: 3x3 stride-1 'same' max pool followed by a 1x1 conv.

        Their outputs are concatenated along ``channel_axis``.

        Args:
            x: Input tensor.
            specs: Four per-branch filter tuples ``(br0, br1, br2, br3)``.
                ``br0`` and ``br3`` hold one filter count each; ``br1`` and
                ``br2`` hold two (the 1x1 conv, then the 3x3 conv), e.g.
                ``((64,), (96, 128), (16, 32), (32,))``.
            channel_axis: Axis along which branch outputs are concatenated.
            name: Name prefix for every layer in this block.
            weight_suffix: Forwarded to ``conv_norm_relu``; name of the
                learnable parameters in each conv.
            conv_suffix: Forwarded to ``conv_norm_relu``; suffix for the
                conv layers.
            norm_suffix: Forwarded to ``conv_norm_relu``; suffix for the
                batch-norm layers.

        Returns:
            Concatenated output of the four branches, named
            ``name + "_Concatenated"``.
        """
        br0, br1, br2, br3 = specs

        def unit(inp, filters, kernel, layer):
            # Square conv + norm + relu sharing this block's suffix settings;
            # `layer` is appended to the block-level name prefix.
            return conv_norm_relu(inp,
                                  filters,
                                  kernel,
                                  kernel,
                                  name=name + layer,
                                  weight_suffix=weight_suffix,
                                  conv_suffix=conv_suffix,
                                  norm_suffix=norm_suffix)

        tower_0 = unit(x, br0[0], 1, "/Branch_0/Conv2d_0a_1x1")

        tower_1 = unit(x, br1[0], 1, "/Branch_1/Conv2d_0a_1x1")
        tower_1 = unit(tower_1, br1[1], 3, "/Branch_1/Conv2d_0b_3x3")

        tower_2 = unit(x, br2[0], 1, "/Branch_2/Conv2d_0a_1x1")
        # Mixed_5b names its second Branch_2 conv "Conv2d_0a_3x3" instead of
        # "Conv2d_0b_3x3" - presumably to match an existing checkpoint's
        # variable names; verify before changing.
        if name == "InceptionV1/Mixed_5b":
            second_layer = "/Branch_2/Conv2d_0a_3x3"
        else:
            second_layer = "/Branch_2/Conv2d_0b_3x3"
        tower_2 = unit(tower_2, br2[1], 3, second_layer)

        tower_3 = max_pool(x,
                           3,
                           strides=1,
                           padding='same',
                           name=name + "/Branch_3/Conv2d_0a_max")
        tower_3 = unit(tower_3, br3[0], 1, "/Branch_3/Conv2d_0b_1x1")

        return concat([tower_0, tower_1, tower_2, tower_3],
                      axis=channel_axis,
                      name=name + "_Concatenated")
示例#2
0
    def build_model(self, img_input: TensorType) -> TensorType:
        """Build an Inception-v3 style classification graph on ``img_input``.

        The graph consists of:

        * a convolutional stem;
        * three 35x35 mixed blocks (``mixed0``-``mixed2``);
        * a grid-reduction block (``mixed3``);
        * four 17x17 blocks (``mixed4``-``mixed7``);
        * a second grid-reduction block (``mixed8``);
        * two 8x8 blocks (``mixed9``, ``mixed10``);
        * a head of average pooling, squeeze, a fully-connected layer and
          softmax.

        The spatial sizes quoted in the comments presumably assume a 299x299
        input. The head's fixed 8x8 average pool relies on the final feature
        map being 8x8.

        Args:
            img_input: 4D image input tensor of shape
                (batch, height, width, channels).

        Returns:
            `Tensor` holding output probabilities per class, shape
            (batch, num_classes).
        """

        def stack(inp, *layer_args):
            # Chain conv_norm_relu calls; each tuple holds the positional
            # arguments that follow the input tensor, passed through verbatim.
            for conv_args in layer_args:
                inp = conv_norm_relu(inp, *conv_args)
            return inp

        def pool_branch(inp, *conv_args):
            # 3x3 stride-1 'SAME' average pool, then a conv_norm_relu
            # projection with the given positional arguments.
            pooled = avg_pool(inp, 3, strides=1, padding='SAME')
            return conv_norm_relu(pooled, *conv_args)

        # Stem.
        net = conv_norm_relu(img_input, 32, 3, strides=2, padding='VALID')
        net = conv_norm_relu(net, 32, 3, padding='VALID')
        net = conv_norm_relu(net, 64, 3)
        net = max_pool(net, 3, strides=2)

        net = conv_norm_relu(net, 80, 1, padding='VALID')
        net = conv_norm_relu(net, 192, 3, padding='VALID')
        net = max_pool(net, 3, strides=2)

        # mixed 0, 1, 2: 35 x 35 x 256, then 35 x 35 x 288 twice.
        # They differ only in the width of the pool branch.
        for idx, pool_filters in enumerate((32, 64, 64)):
            tower_1x1 = conv_norm_relu(net, 64, 1)
            tower_5x5 = stack(net, (48, 1), (64, 5))
            tower_3x3dbl = stack(net, (64, 1), (96, 3), (96, 3))
            tower_pool = pool_branch(net, pool_filters, 1)
            net = concat([tower_1x1, tower_5x5, tower_3x3dbl, tower_pool],
                         axis=-1,
                         name='mixed' + str(idx))

        # mixed 3: 17 x 17 x 768 (grid reduction).
        tower_3x3 = conv_norm_relu(net, 384, 3, strides=2, padding='VALID')
        tower_3x3dbl = stack(net, (64, 1), (96, 3))
        tower_3x3dbl = conv_norm_relu(tower_3x3dbl,
                                      96,
                                      3,
                                      strides=2,
                                      padding='VALID')
        tower_pool = max_pool(net, 3, 2)
        net = concat([tower_3x3, tower_3x3dbl, tower_pool],
                     axis=-1,
                     name='mixed3')

        # mixed 4: 17 x 17 x 768.
        tower_1x1 = conv_norm_relu(net, 192, 1)
        tower_7x7 = stack(net, (128, 1), (128, 1, 7), (192, 7, 1))
        tower_7x7dbl = stack(net, (128, 1, 1), (128, 7, 1), (128, 1, 7),
                             (128, 7, 1), (192, 1, 7))
        tower_pool = pool_branch(net, 192, 1, 1)
        net = concat([tower_1x1, tower_7x7, tower_7x7dbl, tower_pool],
                     axis=-1,
                     name='mixed4')

        # mixed 5, 6, 7: 17 x 17 x 768. `width` is the bottleneck width of
        # the factorized 7x7 towers.
        for idx, width in enumerate((160, 160, 192), start=5):
            tower_1x1 = conv_norm_relu(net, 192, 1, 1)
            tower_7x7 = stack(net, (width, 1, 1), (width, 1, 7), (192, 7, 1))
            tower_7x7dbl = stack(net, (width, 1, 1), (width, 7, 1),
                                 (width, 1, 7), (width, 7, 1), (192, 1, 7))
            tower_pool = pool_branch(net, 192, 1, 1)
            net = concat([tower_1x1, tower_7x7, tower_7x7dbl, tower_pool],
                         axis=-1,
                         name='mixed' + str(idx))

        # mixed 8: 8 x 8 x 1280 (grid reduction).
        tower_3x3 = conv_norm_relu(net, 192, 1)
        tower_3x3 = conv_norm_relu(tower_3x3,
                                   320,
                                   3,
                                   strides=2,
                                   padding='VALID')
        tower_7x7x3 = stack(net, (192, 1, 1), (192, 1, 7), (192, 7, 1))
        tower_7x7x3 = conv_norm_relu(tower_7x7x3,
                                     192,
                                     3,
                                     3,
                                     strides=2,
                                     padding='VALID')
        tower_pool = max_pool(net, 3, strides=2)
        net = concat([tower_3x3, tower_7x7x3, tower_pool],
                     axis=-1,
                     name='mixed8')

        # mixed 9, 10: 8 x 8 x 2048.
        for idx in range(2):
            tower_1x1 = conv_norm_relu(net, 320, 1, 1)

            tower_3x3 = conv_norm_relu(net, 384, 1, 1)
            tower_3x3 = concat([conv_norm_relu(tower_3x3, 384, 1, 3),
                                conv_norm_relu(tower_3x3, 384, 3, 1)],
                               axis=-1,
                               name='mixed9_' + str(idx))

            tower_3x3dbl = stack(net, (448, 1, 1), (384, 3, 3))
            tower_3x3dbl = concat([conv_norm_relu(tower_3x3dbl, 384, 1, 3),
                                   conv_norm_relu(tower_3x3dbl, 384, 3, 1)],
                                  axis=-1)

            tower_pool = pool_branch(net, 192, 1, 1)
            net = concat([tower_1x1, tower_3x3, tower_3x3dbl, tower_pool],
                         axis=-1,
                         name='mixed' + str(9 + idx))

        # Classification head.
        net = avg_pool(net, kernel_size=8, strides=1, name='avg_pool')
        net = squeeze(net, axis=[1, 2], name='squeeze')
        net = fully_connected(net, self.num_classes, name='predictions')
        return softmax(net, name='output-prob')
    def normal_a_cell(ip, p, filters, block_id=None):
        """Add a Normal cell for NASNet-A (Fig. 4 in the NASNet paper).

        Construction:

        * ``p`` is first reconciled with ``ip`` via
          ``NASNetMobile.adjust_block``.
        * ``ip`` goes through ReLU -> 1x1 conv -> batch norm to give ``h``.
        * Five blocks each sum two branches built from ``h`` and/or ``p``.
        * ``p`` and the five block outputs are concatenated along the last
          axis.

        Args:
            ip: Input tensor `x`.
            p: Input tensor `p` (previous path); adjusted to match `ip`.
            filters: Number of output filters per branch.
            block_id: String block_id, embedded in every layer name.

        Returns:
            A tuple ``(x, ip)``: the concatenated cell output and the
            unmodified ``ip``.
        """
        channel_dim = -1

        def sep_branch(inp, tag, *args, **kwargs):
            # NASNetMobile.separable_conv_block with `filters` and a block_id
            # of `tag % block_id`; extra arguments are forwarded verbatim.
            return NASNetMobile.separable_conv_block(
                inp, filters, *args, block_id=tag % block_id, **kwargs)

        def avg_branch(inp, tag):
            # 3x3 stride-1 'same' average pool named `tag % block_id`.
            return layers.avg_pool(inp,
                                   3,
                                   strides=1,
                                   padding='same',
                                   name=tag % block_id)

        with tf.name_scope('normal_A_block_%s' % block_id):
            p = NASNetMobile.adjust_block(p, ip, filters, block_id)

            # 1x1 projection of the current input: ReLU -> conv -> BN.
            h = layers.conv(layers.relu(ip),
                            filters_out=filters,
                            kernel_size=(1, 1),
                            stride=1,
                            padding='same',
                            name='normal_conv_1_%s' % block_id,
                            add_bias=False)
            h = layers.norm(h,
                            axis=channel_dim,
                            momentum=0.9997,
                            epsilon=1e-3,
                            name='normal_bn_1_%s' % block_id)

            with tf.name_scope('block_1'):
                left = sep_branch(h, 'normal_left1_%s', kernel_size=5)
                right = sep_branch(p, 'normal_right1_%s')
                x1 = left + right

            with tf.name_scope('block_2'):
                left = sep_branch(p, 'normal_left2_%s', 5)
                right = sep_branch(p, 'normal_right2_%s', 3)
                x2 = left + right

            with tf.name_scope('block_3'):
                x3 = avg_branch(h, 'normal_left3_%s') + p

            with tf.name_scope('block_4'):
                x4 = (avg_branch(p, 'normal_left4_%s') +
                      avg_branch(p, 'normal_right4_%s'))

            with tf.name_scope('block_5'):
                x5 = sep_branch(h, 'normal_left5_%s') + h

            x = layers.concat([p, x1, x2, x3, x4, x5],
                              axis=channel_dim,
                              name='normal_concat_%s' % block_id)

        return x, ip
    def reduction_a_cell(ip, p, filters, block_id=None):
        """Add a Reduction cell for NASNet-A (Fig. 4 in the NASNet paper).

        Construction:

        * ``p`` is first reconciled with ``ip`` via
          ``NASNetMobile.adjust_block``.
        * ``ip`` goes through ReLU -> 1x1 conv -> batch norm to give ``h``.
        * ``h3`` is a zero-padded copy of ``h``, used by the 'valid' stride-2
          pools.
        * Five blocks each sum two branches (most with stride 2).
        * Blocks 2-5 are concatenated along the last axis.

        Args:
            ip: Input tensor `x`.
            p: Input tensor `p` (previous path); adjusted to match `ip`.
            filters: Number of output filters per branch.
            block_id: String block_id, embedded in every layer name.

        Returns:
            A tuple ``(x, ip)``: the concatenated cell output and the
            unmodified ``ip``.
        """
        channel_dim = -1

        def sep_branch(inp, tag, *args, **kwargs):
            # NASNetMobile.separable_conv_block with block_id `tag % block_id`;
            # all other arguments are forwarded exactly as given.
            return NASNetMobile.separable_conv_block(
                inp, *args, block_id=tag % block_id, **kwargs)

        def pool_branch(pool_fn, inp, stride, pad, tag):
            # 3x3 pooling layer (`pool_fn`) named `tag % block_id`.
            return pool_fn(inp,
                           3,
                           strides=stride,
                           padding=pad,
                           name=tag % block_id)

        with tf.name_scope('reduction_A_block_%s' % block_id):
            p = NASNetMobile.adjust_block(p, ip, filters, block_id)

            # 1x1 projection of the current input: ReLU -> conv -> BN.
            h = layers.conv(layers.relu(ip),
                            filters_out=filters,
                            kernel_size=(1, 1),
                            stride=1,
                            padding='same',
                            name='reduction_conv_1_%s' % block_id,
                            add_bias=False)
            h = layers.norm(h,
                            axis=channel_dim,
                            momentum=0.9997,
                            epsilon=1e-3,
                            name='reduction_bn_1_%s' % block_id)

            # Padded copy of h consumed by the 'valid' stride-2 pools below.
            h3 = layers.zero_padding(h,
                                     padding=NASNetMobile.correct_pad(
                                         h, (3, 3)),
                                     name='reduction_pad_1_%s' % block_id)

            with tf.name_scope('block_1'):
                left = sep_branch(h,
                                  'reduction_left1_%s',
                                  filters=filters,
                                  kernel_size=5,
                                  strides=2)
                right = sep_branch(p,
                                   'reduction_right1_%s',
                                   filters=filters,
                                   kernel_size=7,
                                   strides=2)
                x1 = left + right

            with tf.name_scope('block_2'):
                left = pool_branch(layers.max_pool, h3, 2, 'valid',
                                   'reduction_left2_%s')
                right = sep_branch(p,
                                   'reduction_right2_%s',
                                   filters=filters,
                                   kernel_size=7,
                                   strides=2)
                x2 = left + right

            with tf.name_scope('block_3'):
                left = pool_branch(layers.avg_pool, h3, 2, 'valid',
                                   'reduction_left3_%s')
                right = sep_branch(p, 'reduction_right3_%s', filters, 5,
                                   strides=2)
                x3 = left + right

            with tf.name_scope('block_4'):
                x4 = pool_branch(layers.avg_pool, x1, 1, 'same',
                                 'reduction_left4_%s')
                x4 += x2

            with tf.name_scope('block_5'):
                # NOTE(review): this separable block reuses the
                # 'reduction_left4' tag rather than 'reduction_left5'. Kept
                # unchanged because layer names presumably have to match
                # pretrained weights - verify before renaming.
                left = sep_branch(x1, 'reduction_left4_%s', filters, 3)
                right = pool_branch(layers.max_pool, h3, 2, 'valid',
                                    'reduction_right5_%s')
                x5 = left + right

            x = layers.concat([x2, x3, x4, x5],
                              axis=channel_dim,
                              name='reduction_concat_%s' % block_id)
            return x, ip
    def adjust_block(p, ip, filters, block_id=None):
        """Adjust the `previous path` tensor ``p`` to match the shape of ``ip``.

        Used in situations where the output number of filters (or the spatial
        size) needs to be changed. The first matching case applies:

        * ``p`` is None: ``ip`` is used as ``p``.
        * ``p`` and ``ip`` differ along axis -2 (a spatial axis, assuming
          channels-last 4D tensors - TODO confirm): factorized reduction.

          - ReLU, then two half-width paths.
          - Each path is a stride-2 1x1 average pool plus a 1x1 conv to
            ``filters // 2`` channels.
          - The second path first pads ``((0, 1), (0, 1))`` and crops
            ``((1, 0), (1, 0))``, presumably a one-pixel spatial shift.
          - The paths are concatenated on the last axis and batch-normalized.

          NOTE: for odd ``filters`` this yields ``filters - 1`` channels.
        * ``p``'s last-axis size differs from ``filters``: ReLU, 1x1 conv to
          ``filters`` channels, batch norm.
        * Otherwise ``p`` is returned unchanged.

        ``tf.name_scope('adjust_block')`` is entered in every case, including
        ``p is None``.

        Args:
            p: Input tensor which needs to be modified (may be None).
            ip: Input tensor whose shape needs to be matched.
            filters: Number of output filters to be matched.
            block_id: String block_id, embedded in layer names.

        Returns:
            Adjusted tf tensor.
        """
        channel_dim = -1
        img_dim = -2

        ip_shape = ip.get_shape().as_list()
        p_shape = ip_shape if p is None else p.get_shape().as_list()

        def half_path(inp, pool_name, conv_name):
            # Stride-2 1x1 average pool, then a 1x1 conv to filters // 2.
            pooled = layers.avg_pool(inp,
                                     1,
                                     strides=2,
                                     padding='valid',
                                     name=pool_name)
            return layers.conv(pooled,
                               filters_out=filters // 2,
                               kernel_size=(1, 1),
                               padding='same',
                               add_bias=False,
                               name=conv_name)

        def batch_norm(inp):
            # Shared batch-norm configuration for both adjustment paths.
            return layers.norm(inp,
                               axis=channel_dim,
                               momentum=0.9997,
                               epsilon=1e-3,
                               name='adjust_bn_%s' % block_id)

        with tf.name_scope('adjust_block'):
            if p is None:
                p = ip
            elif p_shape[img_dim] != ip_shape[img_dim]:
                with tf.name_scope('adjust_reduction_block_%s' % block_id):
                    p = layers.relu(p, name='adjust_relu_1_%s' % block_id)
                    path_1 = half_path(p,
                                       'adjust_avg_pool_1_%s' % block_id,
                                       'adjust_conv_1_%s' % block_id)
                    shifted = layers.zero_padding(p, padding=((0, 1), (0, 1)))
                    shifted = layers.crop(shifted, cropping=((1, 0), (1, 0)))
                    path_2 = half_path(shifted,
                                       'adjust_avg_pool_2_%s' % block_id,
                                       'adjust_conv_2_%s' % block_id)
                    p = batch_norm(
                        layers.concat([path_1, path_2], axis=channel_dim))
            elif p_shape[channel_dim] != filters:
                with tf.name_scope('adjust_projection_block_%s' % block_id):
                    p = layers.conv(layers.relu(p),
                                    filters_out=filters,
                                    kernel_size=(1, 1),
                                    stride=1,
                                    padding='same',
                                    name='adjust_conv_projection_%s' %
                                    block_id,
                                    add_bias=False)
                    p = batch_norm(p)
        return p