Example #1
0
    def _build(self):
        """Builds block according to the arguments.

        Fused MBConv variant (no depthwise conv):
          * ``expand_ratio != 1``: a ``kernel_size`` Conv2D that carries the
            block stride expands the channels. A 1x1 Conv2D then projects
            them to ``output_filters``.
          * ``expand_ratio == 1``: the projection Conv2D itself uses
            ``kernel_size`` and the block stride.
        An optional SE layer sits between the two phases. Every conv is
        followed by a normalization layer.

        Layers are named ``conv2d``, ``conv2d_1``, ... and
        ``tpu_batch_normalization``, ``tpu_batch_normalization_1``, ...
        in creation order.
        """

        def sequential_namer(base):
            # Returns a callable that produces base, base_1, base_2, ... on
            # successive calls. This is the same sequence the historical
            # double-next() itertools.count idiom produced.
            counter = itertools.count(0)

            def namer():
                index = next(counter)
                return base if index == 0 else base + '_' + str(index)

            return namer

        get_norm_name = sequential_namer('tpu_batch_normalization')
        get_conv_name = sequential_namer('conv2d')

        mconfig = self._mconfig
        block_args = self._block_args
        # Pass the configured data format explicitly, as the stem, head and
        # MBConv block do. Without it, Conv2D uses the global Keras
        # image_data_format. That can disagree with self._channel_axis used by
        # the norm layers when the model is configured as 'channels_first'.
        # NOTE(review): assumes self._channel_axis is derived from this same
        # data_format -- confirm.
        data_format = mconfig.data_format
        filters = block_args.input_filters * block_args.expand_ratio
        kernel_size = block_args.kernel_size
        has_expansion = block_args.expand_ratio != 1

        if has_expansion:
            # Expansion phase: full-size conv that also applies the stride.
            self._expand_conv = tf.keras.layers.Conv2D(
                filters,
                kernel_size=kernel_size,
                strides=block_args.strides,
                kernel_initializer=conv_kernel_initializer,
                padding='same',
                data_format=data_format,
                use_bias=False,
                name=get_conv_name())
            self._norm0 = utils.normalization(
                mconfig.bn_type,
                axis=self._channel_axis,
                momentum=mconfig.bn_momentum,
                epsilon=mconfig.bn_epsilon,
                groups=mconfig.gn_groups,
                name=get_norm_name())

        if self._has_se:
            # SE bottleneck width is relative to the block's input filters
            # (at least 1).
            num_reduced_filters = max(
                1, int(block_args.input_filters * block_args.se_ratio))
            self._se = SE(mconfig, num_reduced_filters, filters, name='se')
        else:
            self._se = None

        # Output phase: a 1x1 projection after an expansion. Otherwise this is
        # the block's only conv, so it carries kernel_size and the stride.
        self._project_conv = tf.keras.layers.Conv2D(
            block_args.output_filters,
            kernel_size=1 if has_expansion else kernel_size,
            strides=1 if has_expansion else block_args.strides,
            kernel_initializer=conv_kernel_initializer,
            padding='same',
            data_format=data_format,
            use_bias=False,
            name=get_conv_name())
        self._norm1 = utils.normalization(
            mconfig.bn_type,
            axis=self._channel_axis,
            momentum=mconfig.bn_momentum,
            epsilon=mconfig.bn_epsilon,
            groups=mconfig.gn_groups,
            name=get_norm_name())
Example #2
0
 def __init__(self, mconfig, stem_filters, name=None):
     """Stem layer: stride-2 3x3 convolution, normalization, activation.

     Args:
       mconfig: model config; reads data_format, bn_type, bn_momentum,
         bn_epsilon, gn_groups and act_fn.
       stem_filters: base number of output filters, scaled through
         round_filters together with mconfig.
       name: optional name for this layer.
     """
     super().__init__(name=name)
     data_format = mconfig.data_format
     # Normalize over the channel dimension implied by the data format.
     channel_axis = -1
     if data_format == 'channels_first':
         channel_axis = 1
     out_filters = round_filters(stem_filters, mconfig)

     self._conv_stem = tf.keras.layers.Conv2D(
         filters=out_filters,
         kernel_size=3,
         strides=2,
         padding='same',
         use_bias=False,
         data_format=data_format,
         kernel_initializer=conv_kernel_initializer,
         name='conv2d')
     self._norm = utils.normalization(
         mconfig.bn_type,
         axis=channel_axis,
         momentum=mconfig.bn_momentum,
         epsilon=mconfig.bn_epsilon,
         groups=mconfig.gn_groups)
     self._act = utils.get_act_fn(mconfig.act_fn)
Example #3
0
    def __init__(self, mconfig, name=None):
        """Head layer: 1x1 conv, norm, activation, pooling, optional FC/dropout.

        Args:
          mconfig: model config; reads data_format, feature_size, bn_type,
            bn_momentum, bn_epsilon, gn_groups, act_fn, num_classes, headbias
            and dropout_rate.
          name: optional name for this layer.
        """
        super().__init__(name=name)

        self.endpoints = {}
        self._mconfig = mconfig

        data_format = mconfig.data_format
        channels_first = data_format == 'channels_first'
        # Use 1280 base features when feature_size is unset (falsy).
        head_filters = round_filters(mconfig.feature_size or 1280, mconfig)

        self._conv_head = tf.keras.layers.Conv2D(
            filters=head_filters,
            kernel_size=1,
            strides=1,
            padding='same',
            use_bias=False,
            data_format=data_format,
            kernel_initializer=conv_kernel_initializer,
            name='conv2d')
        self._norm = utils.normalization(
            mconfig.bn_type,
            axis=1 if channels_first else -1,
            momentum=mconfig.bn_momentum,
            epsilon=mconfig.bn_epsilon,
            groups=mconfig.gn_groups)
        self._act = utils.get_act_fn(mconfig.act_fn)

        self._avg_pooling = tf.keras.layers.GlobalAveragePooling2D(
            data_format=data_format)

        # Classifier only when a class count is configured; bias starts at
        # headbias (0 when unset).
        if not mconfig.num_classes:
            self._fc = None
        else:
            bias_init = tf.constant_initializer(mconfig.headbias or 0)
            self._fc = tf.keras.layers.Dense(
                mconfig.num_classes,
                kernel_initializer=dense_kernel_initializer,
                bias_initializer=bias_init)

        self._dropout = (tf.keras.layers.Dropout(mconfig.dropout_rate)
                         if mconfig.dropout_rate > 0 else None)

        # Spatial (height, width) axes for 4-D feature maps in each layout.
        if channels_first:
            self.h_axis, self.w_axis = 2, 3
        else:
            self.h_axis, self.w_axis = 1, 2
Example #4
0
    def _build(self):
        """Builds the MBConv block's layers from ``self._block_args``.

        Creates, in order:
          * a 1x1 expansion conv + norm, only when ``expand_ratio != 1``;
          * a depthwise conv + norm;
          * an optional squeeze-and-excitation layer;
          * a 1x1 projection conv + norm.
        Convs are named ``conv2d``, ``conv2d_1``, ... and norms
        ``tpu_batch_normalization``, ``tpu_batch_normalization_1``, ...
        """

        def sequential_namer(base):
            # Callable yielding base, base_1, base_2, ... on successive calls.
            # This is the same sequence as the double-next() counter idiom.
            counter = itertools.count(0)

            def namer():
                index = next(counter)
                return base if index == 0 else base + '_' + str(index)

            return namer

        get_norm_name = sequential_namer('tpu_batch_normalization')
        get_conv_name = sequential_namer('conv2d')

        mconfig = self._mconfig
        block_args = self._block_args
        data_format = self._data_format
        expanded_filters = block_args.input_filters * block_args.expand_ratio

        def make_norm():
            # Every norm layer shares the config; only the generated name
            # differs between them.
            return utils.normalization(mconfig.bn_type,
                                       axis=self._channel_axis,
                                       momentum=mconfig.bn_momentum,
                                       epsilon=mconfig.bn_epsilon,
                                       groups=mconfig.gn_groups,
                                       name=get_norm_name())

        # Expansion phase: only needed when the channel count actually grows.
        if block_args.expand_ratio != 1:
            self._expand_conv = tf.keras.layers.Conv2D(
                filters=expanded_filters,
                kernel_size=1,
                strides=1,
                kernel_initializer=conv_kernel_initializer,
                padding='same',
                data_format=data_format,
                use_bias=False,
                name=get_conv_name())
            self._norm0 = make_norm()

        # Depthwise phase: carries the block's kernel size and stride.
        self._depthwise_conv = tf.keras.layers.DepthwiseConv2D(
            kernel_size=block_args.kernel_size,
            strides=block_args.strides,
            depthwise_initializer=conv_kernel_initializer,
            padding='same',
            data_format=data_format,
            use_bias=False,
            name='depthwise_conv2d')
        self._norm1 = make_norm()

        if not self._has_se:
            self._se = None
        else:
            # SE bottleneck width is relative to the block's input filters.
            reduced = max(1, int(block_args.input_filters *
                                 block_args.se_ratio))
            self._se = SE(mconfig, reduced, expanded_filters, name='se')

        # Output phase: 1x1 projection to the block's output filters.
        self._project_conv = tf.keras.layers.Conv2D(
            filters=block_args.output_filters,
            kernel_size=1,
            strides=1,
            kernel_initializer=conv_kernel_initializer,
            padding='same',
            data_format=data_format,
            use_bias=False,
            name=get_conv_name())
        self._norm2 = make_norm()