Example 1
  def _build(self):
    """Builds a model."""
    self._blocks = []
    batch_norm_momentum = self._global_params.batch_norm_momentum
    batch_norm_epsilon = self._global_params.batch_norm_epsilon
    if self._global_params.data_format == 'channels_first':
      channel_axis = 1
      self._spatial_dims = [2, 3]
    else:
      channel_axis = -1
      self._spatial_dims = [1, 2]

    # Stem part.
    self._conv_stem = utils.Conv2D(
        filters=round_filters(32, self._global_params, self._fix_head_stem),
        kernel_size=[3, 3],
        strides=[2, 2],
        kernel_initializer=conv_kernel_initializer,
        padding='same',
        data_format=self._global_params.data_format,
        use_bias=False)
    self._bn0 = self._batch_norm(
        axis=channel_axis,
        momentum=batch_norm_momentum,
        epsilon=batch_norm_epsilon)

    # Builds blocks.
    for i, block_args in enumerate(self._blocks_args):
      assert block_args.num_repeat > 0
      assert block_args.super_pixel in [0, 1, 2]
      # Update block input and output filters based on depth multiplier.
      input_filters = round_filters(block_args.input_filters,
                                    self._global_params)

      output_filters = round_filters(block_args.output_filters,
                                     self._global_params)
      kernel_size = block_args.kernel_size
      if self._fix_head_stem and (i == 0 or i == len(self._blocks_args) - 1):
        repeats = block_args.num_repeat
      else:
        repeats = round_repeats(block_args.num_repeat, self._global_params)
      block_args = block_args._replace(
          input_filters=input_filters,
          output_filters=output_filters,
          num_repeat=repeats)

      # The first block needs to take care of stride and filter size increase.
      conv_block = self._get_conv_block(block_args.conv_type)
      if not block_args.super_pixel:  # No super_pixel transform at all.
        self._blocks.append(conv_block(block_args, self._global_params))
      else:
        # If super_pixel, adjust filters, kernel size, and strides.
        depth_factor = int(4 / block_args.strides[0] / block_args.strides[1])
        block_args = block_args._replace(
            input_filters=block_args.input_filters * depth_factor,
            output_filters=block_args.output_filters * depth_factor,
            kernel_size=((block_args.kernel_size + 1) // 2 if depth_factor > 1
                         else block_args.kernel_size))
        # If the first block has stride-2 and a super_pixel transformation.
        if (block_args.strides[0] == 2 and block_args.strides[1] == 2):
          block_args = block_args._replace(strides=[1, 1])
          self._blocks.append(conv_block(block_args, self._global_params))
          block_args = block_args._replace(  # super_pixel stops at stride-2
              super_pixel=0,
              input_filters=input_filters,
              output_filters=output_filters,
              kernel_size=kernel_size)
        elif block_args.super_pixel == 1:
          self._blocks.append(conv_block(block_args, self._global_params))
          block_args = block_args._replace(super_pixel=2)
        else:
          self._blocks.append(conv_block(block_args, self._global_params))
      if block_args.num_repeat > 1:  # Remaining blocks reuse the same block_args.
        # pylint: disable=protected-access
        block_args = block_args._replace(
            input_filters=block_args.output_filters, strides=[1, 1])
        # pylint: enable=protected-access
      for _ in range(block_args.num_repeat - 1):
        self._blocks.append(conv_block(block_args, self._global_params))

    # Head part.
    self._conv_head = utils.Conv2D(
        filters=round_filters(1280, self._global_params, self._fix_head_stem),
        kernel_size=[1, 1],
        strides=[1, 1],
        kernel_initializer=conv_kernel_initializer,
        padding='same',
        data_format=self._global_params.data_format,
        use_bias=False)
    self._bn1 = self._batch_norm(
        axis=channel_axis,
        momentum=batch_norm_momentum,
        epsilon=batch_norm_epsilon)

    self._avg_pooling = tf.keras.layers.GlobalAveragePooling2D(
        data_format=self._global_params.data_format)
    if self._global_params.num_classes:
      self._fc = tf.layers.Dense(
          self._global_params.num_classes,
          kernel_initializer=dense_kernel_initializer)
    else:
      self._fc = None

    if self._global_params.dropout_rate > 0:
      self._dropout = tf.keras.layers.Dropout(self._global_params.dropout_rate)
    else:
      self._dropout = None
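
The block loop above depends on two scaling helpers, round_filters and round_repeats, that are not part of this excerpt. Below is a minimal sketch of what they typically do in EfficientNet-style code: scale per-stage filter counts by a width coefficient (snapping to a multiple of depth_divisor and never shrinking by more than 10%) and scale per-stage repeat counts by a depth coefficient. The GlobalParams namedtuple and its field names (width_coefficient, depth_coefficient, depth_divisor, min_depth) are assumptions made for this sketch, not definitions taken from the excerpt.

import collections
import math

# Hypothetical stand-in for the global_params object used above; the real one
# carries many more fields (batch_norm_momentum, data_format, num_classes, ...).
GlobalParams = collections.namedtuple(
    'GlobalParams',
    ['width_coefficient', 'depth_coefficient', 'depth_divisor', 'min_depth'])


def round_filters(filters, global_params, skip=False):
  """Scales a filter count by the width coefficient, keeping it a multiple of
  depth_divisor and within 10% of the scaled value."""
  multiplier = global_params.width_coefficient
  divisor = global_params.depth_divisor
  min_depth = global_params.min_depth
  if skip or not multiplier:
    return filters
  filters *= multiplier
  min_depth = min_depth or divisor
  new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor)
  if new_filters < 0.9 * filters:  # Avoid shrinking by more than 10%.
    new_filters += divisor
  return int(new_filters)


def round_repeats(repeats, global_params):
  """Scales a repeat count by the depth coefficient, rounding up."""
  multiplier = global_params.depth_coefficient
  if not multiplier:
    return repeats
  return int(math.ceil(multiplier * repeats))


# Example with B2-like coefficients: width 1.1, depth 1.2.
params = GlobalParams(width_coefficient=1.1, depth_coefficient=1.2,
                      depth_divisor=8, min_depth=None)
print(round_filters(32, params))    # 32 (already a multiple of 8, within 10%)
print(round_filters(1280, params))  # 1408
print(round_repeats(4, params))     # 5
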
Example 2
  def _build(self):
    """Builds block according to the arguments."""
    if self._block_args.super_pixel == 1:
      self._superpixel = tf.layers.Conv2D(
          self._block_args.input_filters,
          kernel_size=[2, 2],
          strides=[2, 2],
          kernel_initializer=conv_kernel_initializer,
          padding='same',
          data_format=self._data_format,
          use_bias=False)
      self._bnsp = self._batch_norm(
          axis=self._channel_axis,
          momentum=self._batch_norm_momentum,
          epsilon=self._batch_norm_epsilon)

    if self._block_args.condconv:
      # Add the example-dependent routing function
      self._avg_pooling = tf.keras.layers.GlobalAveragePooling2D(
          data_format=self._data_format)
      self._routing_fn = tf.layers.Dense(
          self._condconv_num_experts, activation=tf.nn.sigmoid)

    filters = self._block_args.input_filters * self._block_args.expand_ratio
    kernel_size = self._block_args.kernel_size

    # Fused expansion phase. Called if using fused convolutions.
    self._fused_conv = self.conv_cls(
        filters=filters,
        kernel_size=[kernel_size, kernel_size],
        strides=self._block_args.strides,
        kernel_initializer=conv_kernel_initializer,
        padding='same',
        data_format=self._data_format,
        use_bias=False)

    # Expansion phase. Called if not using fused convolutions and expansion
    # phase is necessary.
    self._expand_conv = self.conv_cls(
        filters=filters,
        kernel_size=[1, 1],
        strides=[1, 1],
        kernel_initializer=conv_kernel_initializer,
        padding='same',
        data_format=self._data_format,
        use_bias=False)
    self._bn0 = self._batch_norm(
        axis=self._channel_axis,
        momentum=self._batch_norm_momentum,
        epsilon=self._batch_norm_epsilon)

    # Depth-wise convolution phase. Called if not using fused convolutions.
    self._depthwise_conv = self.depthwise_conv_cls(
        kernel_size=[kernel_size, kernel_size],
        strides=self._block_args.strides,
        depthwise_initializer=conv_kernel_initializer,
        padding='same',
        data_format=self._data_format,
        use_bias=False)

    self._bn1 = self._batch_norm(
        axis=self._channel_axis,
        momentum=self._batch_norm_momentum,
        epsilon=self._batch_norm_epsilon)

    if self._has_se:
      num_reduced_filters = max(
          1, int(
              self._block_args.input_filters * (self._block_args.se_ratio * (
                  self._se_coefficient if self._se_coefficient else 1))))
      # Squeeze and Excitation layer.
      self._se_reduce = utils.Conv2D(
          num_reduced_filters,
          kernel_size=[1, 1],
          strides=[1, 1],
          kernel_initializer=conv_kernel_initializer,
          padding='same',
          data_format=self._data_format,
          use_bias=True)
      self._se_expand = utils.Conv2D(
          filters,
          kernel_size=[1, 1],
          strides=[1, 1],
          kernel_initializer=conv_kernel_initializer,
          padding='same',
          data_format=self._data_format,
          use_bias=True)

    # Output phase.
    filters = self._block_args.output_filters
    self._project_conv = self.conv_cls(
        filters=filters,
        kernel_size=[1, 1],
        strides=[1, 1],
        kernel_initializer=conv_kernel_initializer,
        padding='same',
        data_format=self._data_format,
        use_bias=False)
    self._bn2 = self._batch_norm(
        axis=self._channel_axis,
        momentum=self._batch_norm_momentum,
        epsilon=self._batch_norm_epsilon)
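
For reference on the squeeze-and-excitation pair built above: with input_filters=16, se_ratio=0.25, and no se_coefficient, num_reduced_filters = max(1, int(16 * 0.25)) = 4, so the SE branch bottlenecks 16 channels down to 4 before expanding back to the block's expanded filter count. The sketch below shows how such a pair is typically applied in the forward pass; the helper name _call_se and the attributes self._spatial_dims ([1, 2] for channels_last, [2, 3] for channels_first) and self._relu_fn are assumptions for this sketch, not definitions taken from the excerpt.

import tensorflow as tf


def _call_se(self, input_tensor):
  """Applies squeeze-and-excitation to a feature map.

  Sketch only: pools spatially, runs the 1x1 reduce/expand convs built in
  _build, and rescales the input channel-wise with a sigmoid gate.
  """
  # Squeeze: global average over the spatial dimensions, keeping rank 4.
  se_tensor = tf.reduce_mean(input_tensor, self._spatial_dims, keepdims=True)
  # Excite: bottleneck, nonlinearity, then expand back to the full channel count.
  se_tensor = self._se_expand(self._relu_fn(self._se_reduce(se_tensor)))
  # Gate each channel of the original activation with a value in (0, 1).
  return tf.sigmoid(se_tensor) * input_tensor
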