示例#1
0
    def forward(self, inputs, drop_connect_rate=None):
        """Run the MBConv block on ``inputs``.

        Pipeline: optional expansion conv (+BN, swish) -> depthwise conv
        (+BN, swish) -> optional squeeze-and-excitation gating -> projection
        conv (+BN, no activation), followed by an identity residual (with
        optional drop connect) when the block keeps its shape.

        :param inputs: input tensor
        :param drop_connect_rate: drop connect rate (float, between 0 and 1);
            a falsy value (None or 0) disables drop connect
        :return: output of block
        """
        cfg = self._block_args

        # Expansion phase is skipped entirely when expand_ratio == 1.
        if cfg.expand_ratio == 1:
            h = inputs
        else:
            h = self._swish(self._bn0(self._expand_conv(inputs)))

        # Depthwise convolution.
        h = self._swish(self._bn1(self._depthwise_conv(h)))

        # Squeeze-and-excitation: global average pool -> reduce -> swish ->
        # expand, then gate the feature map with a sigmoid.
        if self.has_se:
            pooled = F.adaptive_avg_pool2d(h, 1)
            gate = self._se_expand(self._swish(self._se_reduce(pooled)))
            h = torch.sigmoid(gate) * h

        # Projection; note there is no activation after the final BN.
        h = self._bn2(self._project_conv(h))

        # Residual only when enabled and the block preserves stride and width.
        n_in, n_out = cfg.input_filters, cfg.output_filters
        if not (self.id_skip and cfg.stride == 1 and n_in == n_out):
            return h
        if drop_connect_rate:
            h = drop_connect(h, p=drop_connect_rate, training=self.training)
        return h + inputs  # skip connection
示例#2
0
    def forward(self, inputs, drop_connect_rate=None):
        """Run the block with the squeeze-and-excitation stage removed.

        Stages: optional expansion conv + BN (only when ``expand_ratio != 1``),
        depthwise conv + BN, projection conv + BN, then an identity residual
        (with optional drop connect) when the block keeps its shape.

        :param inputs: input tensor
        :param drop_connect_rate: drop connect rate (float, between 0 and 1);
            a falsy value (None or 0) disables drop connect
        :return: output of block
        """
        cfg = self._block_args

        # NOTE(review): no activation follows any BN in this variant (the
        # swish used by the standard block is absent). Confirm this is
        # intentional.
        out = inputs if cfg.expand_ratio == 1 else self._bn0(self._expand_conv(inputs))
        out = self._bn1(self._depthwise_conv(out))

        # Squeeze-and-excitation deliberately removed here; ``self.has_se`` is
        # not consulted, so any SE sub-layers on this module go unused.
        out = self._bn2(self._project_conv(out))

        # Residual only when enabled and the block preserves stride and width.
        n_in, n_out = cfg.input_filters, cfg.output_filters
        keep_residual = self.id_skip and cfg.stride == 1 and n_in == n_out
        if keep_residual:
            if drop_connect_rate:
                out = drop_connect(out, p=drop_connect_rate, training=self.training)
            out = out + inputs  # skip connection
        return out
示例#3
0
  def call(self, inputs, training=True, drop_connect_rate=None):
    """Runs the block: expand -> depthwise -> (SE) -> project -> (skip).

    Args:
      inputs: the inputs tensor.
      training: boolean, whether the model is constructed for training.
      drop_connect_rate: float, between 0 to 1, drop connect rate. A falsy
        value (None or 0) disables drop connect.

    Returns:
      The output tensor.
    """
    # Pass format args to the logger instead of pre-formatting with `%`, so
    # the message is only built when INFO logging is enabled.
    tf.logging.info('Block input: %s shape: %s', inputs.name, inputs.shape)
    if self._block_args.expand_ratio != 1:
      x = relu_fn(self._bn0(self._expand_conv(inputs), training=training))
    else:
      x = inputs
    tf.logging.info('Expand: %s shape: %s', x.name, x.shape)

    x = relu_fn(self._bn1(self._depthwise_conv(x), training=training))
    tf.logging.info('DWConv: %s shape: %s', x.name, x.shape)

    if self.has_se:
      with tf.variable_scope('se'):
        x = self._call_se(x)

    # Record the activation after expansion/depthwise (and SE, if enabled).
    self.endpoints = {'expansion_output': x}

    x = self._bn2(self._project_conv(x), training=training)

    # Residual only when enabled and the block preserves spatial size (all
    # strides 1) and channel count. `and` short-circuits in the same order as
    # the nested checks it replaces.
    block_args = self._block_args
    if (block_args.id_skip and
        all(s == 1 for s in block_args.strides) and
        block_args.input_filters == block_args.output_filters):
      # Only apply drop_connect when the skip connection is present.
      if drop_connect_rate:
        x = utils.drop_connect(x, training, drop_connect_rate)
      x = tf.add(x, inputs)
    tf.logging.info('Project: %s shape: %s', x.name, x.shape)
    return x
示例#4
0
import torch
示例#5
0
    def forward(self, inputs, drop_connect_rate=None):
        """
        Forward run of the block that also tallies its arithmetic cost,
        see comments to EfficientNet.forward for clarification.

        Every counted stage is invoked with ``is_not_quantized=False``. The
        per-stage costs are added, in call order, into two running totals.

        :param inputs: input tensor
        :param drop_connect_rate: drop connect rate; a falsy value disables it
        :return: tuple ``(output, ops, total_ops)``
        """
        tally = [0., 0.]  # running [ops, total_ops]

        def charge(cost):
            # Add one (ops, total_ops) pair to the running totals.
            step_ops, step_total = cost
            tally[0] = tally[0] + step_ops
            tally[1] = tally[1] + step_total

        def counted_layer(layer, tensor):
            # Counted layers return (output, ops, total_ops).
            out, step_ops, step_total = layer(tensor, is_not_quantized=False)
            charge((step_ops, step_total))
            return out

        def counted_bn(bn, tensor):
            # BN cost is measured on the BN input, before the BN is applied.
            charge(ops_bn(tensor, is_not_quantized=False))
            return bn(tensor)

        def counted_relu(tensor):
            # Activation cost is measured before relu_fn is applied.
            charge(ops_non_linearity(tensor, is_not_quantized=False))
            return relu_fn(tensor)

        h = inputs
        if self._block_args.expand_ratio != 1:
            h = counted_layer(self._expand_conv, inputs)
            h = counted_relu(counted_bn(self._bn0, h))

        h = counted_layer(self._depthwise_conv, h)
        h = counted_relu(counted_bn(self._bn1, h))

        if self.has_se:
            charge(ops_adaptive_avg_pool(h, is_not_quantized=False))
            gate = F.adaptive_avg_pool2d(h, 1)
            gate = counted_relu(counted_layer(self._se_reduce, gate))
            gate = counted_layer(self._se_expand, gate)
            # NOTE(review): the sigmoid is charged using the full-size feature
            # map ``h`` rather than ``gate``; presumably this stands in for the
            # element-wise gating multiply below -- confirm.
            charge(ops_non_linearity(h, is_not_quantized=False))
            h = torch.sigmoid(gate) * h

        h = counted_layer(self._project_conv, h)
        h = counted_bn(self._bn2, h)

        n_in, n_out = self._block_args.input_filters, self._block_args.output_filters
        if self.id_skip and self._block_args.stride == 1 and n_in == n_out:
            if drop_connect_rate:
                h = drop_connect(h, p=drop_connect_rate, training=self.training)
            h = h + inputs  # skip connection

        # NOTE(review): no activation is applied here, yet a non-linearity
        # cost is charged on every call, with or without the skip connection.
        # Possibly intended to cover the residual add -- confirm.
        charge(ops_non_linearity(h, is_not_quantized=False))
        return h, tally[0], tally[1]