def forward(self, inputs, drop_connect_rate=None):
    """Run the MBConv block on ``inputs``.

    :param inputs: input tensor
    :param drop_connect_rate: drop connect rate (float, between 0 and 1);
        a falsy value (``None`` or ``0``) disables drop connect
    :return: output of block
    """
    args = self._block_args

    # Expansion phase (only when the expand ratio differs from 1),
    # followed by the depthwise convolution.
    out = inputs
    if args.expand_ratio != 1:
        expanded = self._expand_conv(inputs)
        out = relu_fn(self._bn0(expanded))
    depthwise = self._depthwise_conv(out)
    out = relu_fn(self._bn1(depthwise))

    # Squeeze and Excitation: gate the channels with a sigmoid computed
    # from the globally average-pooled features.
    if self.has_se:
        pooled = F.adaptive_avg_pool2d(out, 1)
        reduced = relu_fn(self._se_reduce(pooled))
        gate = torch.sigmoid(self._se_expand(reduced))
        out = gate * out

    # Pointwise projection back to the output channel count.
    projected = self._project_conv(out)
    out = self._bn2(projected)

    # The identity skip (optionally with drop connect) is only used when the
    # block keeps both resolution (stride 1) and channel count unchanged.
    in_ch, out_ch = args.input_filters, args.output_filters
    if not (self.id_skip and args.stride == 1 and in_ch == out_ch):
        return out
    if drop_connect_rate:
        out = drop_connect(out, p=drop_connect_rate, training=self.training)
    return out + inputs
def forward(self, inputs, drop_connect_rate=None):
    """MBConvBlock's forward function.

    Args:
        inputs (tensor): Input tensor.
        drop_connect_rate (float, optional): Drop connect rate, between 0 and 1.
            A falsy value (``None`` or ``0``) disables drop connect. It is only
            applied when the identity skip connection is used.

    Returns:
        Output of this block after processing.
    """
    # Expansion and Depthwise Convolution
    x = inputs
    if self._block_args.expand_ratio != 1:
        x = self._expand_conv(inputs)
        x = self._bn0(x)
        x = self._swish(x)

    x = self._depthwise_conv(x)
    x = self._bn1(x)
    x = self._swish(x)

    # Squeeze and Excitation: per-channel sigmoid gate computed from the
    # globally average-pooled features.
    if self.has_se:
        x_squeezed = F.adaptive_avg_pool2d(x, 1)
        x_squeezed = self._se_reduce(x_squeezed)
        x_squeezed = self._swish(x_squeezed)
        x_squeezed = self._se_expand(x_squeezed)
        x = torch.sigmoid(x_squeezed) * x

    # Pointwise Convolution
    x = self._project_conv(x)
    x = self._bn2(x)

    # Skip connection and drop connect
    input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters
    if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters:
        # The combination of skip connection and drop connect brings about stochastic depth.
        if drop_connect_rate:
            x = drop_connect(x, p=drop_connect_rate, training=self.training)
        x = x + inputs  # skip connection
    return x
def call(self, inputs, training=True, survival_prob=None):
    """Implementation of call().

    Args:
        inputs: the inputs tensor.
        training: boolean, whether the model is constructed for training.
        survival_prob: optional float between 0 and 1, forwarded to
            `utils.drop_connect` as its survival probability. Judging by the
            name this is the chance of keeping the residual branch rather than
            a drop rate (NOTE(review): confirm against utils.drop_connect).
            A falsy value disables drop connect, and it only takes effect when
            the identity skip connection is applied.

    Returns:
        An output tensor.
    """
    logging.info('Block input: %s shape: %s', inputs.name, inputs.shape)
    if self._block_args.expand_ratio != 1:
        x = self._relu_fn(
            self._bn0(self._expand_conv(inputs), training=training))
    else:
        x = inputs
    logging.info('Expand: %s shape: %s', x.name, x.shape)

    self.endpoints = {'expansion_output': x}

    # This block has no depthwise stage: the (possibly expanded) features go
    # straight into the projection conv.
    x = self._bn1(self._project_conv(x), training=training)
    # Add identity so that quantization-aware training can insert quantization
    # ops correctly.
    x = tf.identity(x)
    if self._clip_projection_output:
        x = tf.clip_by_value(x, -6, 6)

    if self._block_args.id_skip:
        if all(
            s == 1 for s in self._block_args.strides
        ) and self._block_args.input_filters == self._block_args.output_filters:
            # Apply only if skip connection presents.
            if survival_prob:
                x = utils.drop_connect(x, training, survival_prob)
            x = tf.add(x, inputs)
    logging.info('Project: %s shape: %s', x.name, x.shape)
    return x
def call(self, inputs, training=True, survival_prob=None):
    """Implementation of call().

    Args:
        inputs: the inputs tensor.
        training: boolean, whether the model is constructed for training.
        survival_prob: optional float between 0 and 1, forwarded to
            `utils.drop_connect` as its survival probability. Judging by the
            name this is the chance of keeping the residual branch rather than
            a drop rate (NOTE(review): confirm against utils.drop_connect).
            A falsy value disables drop connect, and it only takes effect when
            the identity skip connection is applied.

    Returns:
        An output tensor.
    """
    logging.info('Block input: %s shape: %s', inputs.name, inputs.shape)
    logging.info('Block input depth: %s output depth: %s',
                 self._block_args.input_filters,
                 self._block_args.output_filters)

    x = inputs

    fused_conv_fn = self._fused_conv
    expand_conv_fn = self._expand_conv
    depthwise_conv_fn = self._depthwise_conv
    project_conv_fn = self._project_conv

    if self._block_args.condconv:
        pooled_inputs = self._avg_pooling(inputs)
        routing_weights = self._routing_fn(pooled_inputs)
        # Capture routing weights as additional input to CondConv layers
        fused_conv_fn = functools.partial(
            self._fused_conv, routing_weights=routing_weights)
        expand_conv_fn = functools.partial(
            self._expand_conv, routing_weights=routing_weights)
        depthwise_conv_fn = functools.partial(
            self._depthwise_conv, routing_weights=routing_weights)
        project_conv_fn = functools.partial(
            self._project_conv, routing_weights=routing_weights)

    # creates conv 2x2 kernel
    if self._block_args.space2depth == 1:
        with tf.variable_scope('space2depth'):
            x = self._relu_fn(
                self._bnsp(self._space2depth(x), training=training))
        logging.info('Block start with space2depth: %s shape: %s',
                     x.name, x.shape)

    if self._block_args.fused_conv:
        # If use fused mbconv, skip expansion and use regular conv.
        x = self._relu_fn(self._bn1(fused_conv_fn(x), training=training))
        logging.info('Conv2D: %s shape: %s', x.name, x.shape)
    else:
        # Otherwise, first apply expansion and then apply depthwise conv.
        if self._block_args.expand_ratio != 1:
            x = self._relu_fn(
                self._bn0(expand_conv_fn(x), training=training))
            logging.info('Expand: %s shape: %s', x.name, x.shape)

        x = self._relu_fn(
            self._bn1(depthwise_conv_fn(x), training=training))
        logging.info('DWConv: %s shape: %s', x.name, x.shape)

    if self._has_se:
        with tf.variable_scope('se'):
            x = self._call_se(x)

    self.endpoints = {'expansion_output': x}

    x = self._bn2(project_conv_fn(x), training=training)
    # Add identity so that quantization-aware training can insert quantization
    # ops correctly.
    x = tf.identity(x)
    if self._clip_projection_output:
        x = tf.clip_by_value(x, -6, 6)
    if self._block_args.id_skip:
        # NOTE(review): the depth comparison reads axis -1 of the static
        # shapes, which presumes channels-last tensors; confirm this block is
        # never run with channels-first data.
        if all(s == 1 for s in self._block_args.strides) and inputs.get_shape(
        ).as_list()[-1] == x.get_shape().as_list()[-1]:
            # Apply only if skip connection presents.
            if survival_prob:
                x = utils.drop_connect(x, training, survival_prob)
            x = tf.add(x, inputs)
    logging.info('Project: %s shape: %s', x.name, x.shape)
    return x