Esempio n. 1
0
    def __init__(self, mconfig, se_filters, output_filters, name=None):
        """Creates the two 1x1 convolutions of the squeeze-and-excitation layer.

        Args:
          mconfig: configuration object; this method reads its
            `local_pooling`, `data_format` and `act_fn` attributes.
          se_filters: number of output channels of the reduce conv
            (`self._se_reduce`).
          output_filters: number of output channels of the expand conv
            (`self._se_expand`).
          name: optional layer name forwarded to the base class.
        """
        super().__init__(name=name)

        self._local_pooling = mconfig.local_pooling
        self._data_format = mconfig.data_format
        self._act = utils.get_act_fn(mconfig.act_fn)

        def pointwise(filters, layer_name):
            # Both SE convolutions are biased 1x1, stride-1, 'same'-padded
            # convs; they differ only in width and layer name.
            return tf.keras.layers.Conv2D(
                filters,
                kernel_size=1,
                strides=1,
                kernel_initializer=conv_kernel_initializer,
                padding='same',
                data_format=self._data_format,
                use_bias=True,
                name=layer_name)

        # Squeeze and Excitation layer: reduce conv first, then expand conv.
        self._se_reduce = pointwise(se_filters, 'conv2d')
        self._se_expand = pointwise(output_filters, 'conv2d_1')
Esempio n. 2
0
    def __init__(self, block_args, mconfig, name=None):
        """Initializes a MBConv block.

        Args:
          block_args: BlockArgs, arguments to create a Block.
          mconfig: GlobalParams, a set of global parameters.
          name: layer name.
        """
        super().__init__(name=name)

        # Store private deep copies of both argument objects so the block
        # keeps its own snapshot of them.
        self._block_args = copy.deepcopy(block_args)
        self._mconfig = copy.deepcopy(mconfig)
        self._local_pooling = mconfig.local_pooling

        data_format = mconfig.data_format
        self._data_format = data_format
        if data_format == 'channels_first':
            self._channel_axis = 1
        else:
            self._channel_axis = -1

        self._act = utils.get_act_fn(mconfig.act_fn)

        # SE is enabled only when se_ratio is set and lies in (0, 1].
        se_ratio = self._block_args.se_ratio
        self._has_se = se_ratio is not None and 0 < se_ratio <= 1

        # Not populated by this constructor.
        self.endpoints = None

        # Build the block's sub-layers according to the stored arguments.
        self._build()
Esempio n. 3
0
 def __init__(self, mconfig, stem_filters, name=None):
     """Creates the stem: a 3x3 stride-2 conv, a normalization layer and an activation.

     Args:
       mconfig: configuration object; its data_format, act_fn, bn_type,
         bn_momentum, bn_epsilon and gn_groups attributes are read here.
       stem_filters: filter count passed through round_filters(..., mconfig)
         to obtain the conv width.
       name: optional layer name forwarded to the base class.
     """
     super().__init__(name=name)
     data_format = mconfig.data_format
     channels_first = data_format == 'channels_first'

     self._conv_stem = tf.keras.layers.Conv2D(
         filters=round_filters(stem_filters, mconfig),
         kernel_size=3,
         strides=2,
         kernel_initializer=conv_kernel_initializer,
         padding='same',
         data_format=data_format,
         use_bias=False,
         name='conv2d')
     # Normalize over the channel axis implied by data_format.
     self._norm = utils.normalization(
         mconfig.bn_type,
         axis=1 if channels_first else -1,
         momentum=mconfig.bn_momentum,
         epsilon=mconfig.bn_epsilon,
         groups=mconfig.gn_groups)
     self._act = utils.get_act_fn(mconfig.act_fn)
Esempio n. 4
0
    def __init__(self, mconfig, name=None):
        """Creates the head layers: 1x1 conv, norm, pooling and optional FC/dropout.

        Args:
          mconfig: configuration object. Attributes read here: feature_size,
            data_format, bn_type, bn_momentum, bn_epsilon, gn_groups, act_fn,
            num_classes, headbias and dropout_rate. It is also kept as
            `self._mconfig`.
          name: optional layer name forwarded to the base class.
        """
        super().__init__(name=name)

        # Left empty by this constructor.
        self.endpoints = {}
        self._mconfig = mconfig

        data_format = mconfig.data_format
        channels_first = data_format == 'channels_first'

        # 1x1 conv; 1280 is used as the filter base when feature_size is falsy.
        self._conv_head = tf.keras.layers.Conv2D(
            filters=round_filters(mconfig.feature_size or 1280, mconfig),
            kernel_size=1,
            strides=1,
            kernel_initializer=conv_kernel_initializer,
            padding='same',
            data_format=data_format,
            use_bias=False,
            name='conv2d')
        self._norm = utils.normalization(
            mconfig.bn_type,
            axis=1 if channels_first else -1,
            momentum=mconfig.bn_momentum,
            epsilon=mconfig.bn_epsilon,
            groups=mconfig.gn_groups)
        self._act = utils.get_act_fn(mconfig.act_fn)

        self._avg_pooling = tf.keras.layers.GlobalAveragePooling2D(
            data_format=data_format)

        # Dense classifier only when num_classes is truthy; bias starts at
        # headbias (or 0 when headbias is falsy).
        self._fc = (
            tf.keras.layers.Dense(
                mconfig.num_classes,
                kernel_initializer=dense_kernel_initializer,
                bias_initializer=tf.constant_initializer(
                    mconfig.headbias or 0))
            if mconfig.num_classes else None)

        # Dropout layer only for a strictly positive rate.
        self._dropout = (
            tf.keras.layers.Dropout(mconfig.dropout_rate)
            if mconfig.dropout_rate > 0 else None)

        # Height/width axis indices for the configured data_format.
        if channels_first:
            self.h_axis, self.w_axis = 2, 3
        else:
            self.h_axis, self.w_axis = 1, 2