def _build(self):
    """Builds a model: the stem, the stack of conv blocks, and the head.

    Populates ``self._blocks`` (one entry per built block, repeats included),
    ``self._spatial_dims``, the stem layers (``self._conv_stem``,
    ``self._bn0``) and the head layers (``self._conv_head``, ``self._bn1``,
    ``self._avg_pooling``, ``self._fc``, ``self._dropout``).

    ``self._fc`` is None when ``num_classes`` is falsy, and ``self._dropout``
    is None when ``dropout_rate`` is not positive.

    NOTE(review): layers are constructed in a fixed order. Auto-generated layer
    names (and hence checkpoint variable names) presumably depend on that
    order, so do not reorder the constructor calls below.

    Raises:
      AssertionError: if a block's ``num_repeat`` is not positive or its
        ``super_pixel`` is not 0, 1 or 2 (not checked when Python runs with -O).
    """
    self._blocks = []
    batch_norm_momentum = self._global_params.batch_norm_momentum
    batch_norm_epsilon = self._global_params.batch_norm_epsilon
    if self._global_params.data_format == 'channels_first':
        channel_axis = 1
        self._spatial_dims = [2, 3]
    else:
        channel_axis = -1
        self._spatial_dims = [1, 2]

    # Stem part: 3x3 stride-2 convolution followed by batch norm.
    self._conv_stem = utils.Conv2D(
        filters=round_filters(32, self._global_params, self._fix_head_stem),
        kernel_size=[3, 3],
        strides=[2, 2],
        kernel_initializer=conv_kernel_initializer,
        padding='same',
        data_format=self._global_params.data_format,
        use_bias=False)
    self._bn0 = self._batch_norm(
        axis=channel_axis,
        momentum=batch_norm_momentum,
        epsilon=batch_norm_epsilon)

    # Builds blocks.
    for i, block_args in enumerate(self._blocks_args):
        assert block_args.num_repeat > 0
        assert block_args.super_pixel in [0, 1, 2]
        # Update block input and output filters based on depth multiplier.
        input_filters = round_filters(block_args.input_filters,
                                      self._global_params)
        output_filters = round_filters(block_args.output_filters,
                                       self._global_params)
        kernel_size = block_args.kernel_size
        # With fix_head_stem, the first and last stages keep their configured
        # repeat count instead of being scaled by round_repeats.
        if self._fix_head_stem and (i == 0 or
                                    i == len(self._blocks_args) - 1):
            repeats = block_args.num_repeat
        else:
            repeats = round_repeats(block_args.num_repeat, self._global_params)
        block_args = block_args._replace(
            input_filters=input_filters,
            output_filters=output_filters,
            num_repeat=repeats)

        # The first block needs to take care of stride and filter size increase.
        conv_block = self._get_conv_block(block_args.conv_type)
        if not block_args.super_pixel:  # no super_pixel at all
            self._blocks.append(conv_block(block_args, self._global_params))
        else:
            # If super-pixel, adjust filters, kernels, and strides.
            # Channels are multiplied by 4 / (stride_h * stride_w): 4 for
            # stride 1x1, 1 for stride 2x2. The kernel is halved (rounded up)
            # whenever the channel count grows.
            depth_factor = int(4 / block_args.strides[0] / block_args.strides[1])
            block_args = block_args._replace(
                input_filters=block_args.input_filters * depth_factor,
                output_filters=block_args.output_filters * depth_factor,
                kernel_size=((block_args.kernel_size + 1) // 2
                             if depth_factor > 1 else block_args.kernel_size))
            # If the first block has stride-2 and a super-pixel transformation:
            # build it at stride 1, then turn super-pixel off for its repeats
            # and restore the rounded filters and the original kernel size.
            if block_args.strides[0] == 2 and block_args.strides[1] == 2:
                block_args = block_args._replace(strides=[1, 1])
                self._blocks.append(conv_block(block_args, self._global_params))
                block_args = block_args._replace(  # sp stops at stride-2
                    super_pixel=0,
                    input_filters=input_filters,
                    output_filters=output_filters,
                    kernel_size=kernel_size)
            elif block_args.super_pixel == 1:
                # Only the stage's first block gets super_pixel=1; its repeats
                # are built with super_pixel=2.
                self._blocks.append(conv_block(block_args, self._global_params))
                block_args = block_args._replace(super_pixel=2)
            else:
                self._blocks.append(conv_block(block_args, self._global_params))

        if block_args.num_repeat > 1:  # rest of blocks with the same block_arg
            # Repeats consume the previous block's output width at stride 1.
            block_args = block_args._replace(
                input_filters=block_args.output_filters, strides=[1, 1])
        # Builtin `range` (not the Python 2 `xrange`) so no py2 shim is needed.
        for _ in range(block_args.num_repeat - 1):
            self._blocks.append(conv_block(block_args, self._global_params))

    # Head part: 1x1 conv + batch norm, global average pooling, then an
    # optional dense layer with num_classes units and optional dropout.
    self._conv_head = utils.Conv2D(
        filters=round_filters(1280, self._global_params, self._fix_head_stem),
        kernel_size=[1, 1],
        strides=[1, 1],
        kernel_initializer=conv_kernel_initializer,
        padding='same',
        data_format=self._global_params.data_format,
        use_bias=False)
    self._bn1 = self._batch_norm(
        axis=channel_axis,
        momentum=batch_norm_momentum,
        epsilon=batch_norm_epsilon)

    self._avg_pooling = tf.keras.layers.GlobalAveragePooling2D(
        data_format=self._global_params.data_format)
    if self._global_params.num_classes:
        self._fc = tf.layers.Dense(
            self._global_params.num_classes,
            kernel_initializer=dense_kernel_initializer)
    else:
        self._fc = None

    if self._global_params.dropout_rate > 0:
        self._dropout = tf.keras.layers.Dropout(self._global_params.dropout_rate)
    else:
        self._dropout = None
def _build(self):
    """Instantiates this block's layers from ``self._block_args``.

    Always created:
      - the fused conv;
      - the expand conv with ``_bn0``;
      - the depthwise conv with ``_bn1``;
      - the project conv with ``_bn2``.

    Created conditionally:
      - the 2x2 stride-2 super-pixel conv and ``_bnsp``, when
        ``super_pixel == 1``;
      - the CondConv pooling/routing layers, when ``condconv`` is set;
      - the squeeze-and-excitation convs, when ``self._has_se``.

    This method only constructs layers; deciding which of them run happens
    outside it.

    NOTE(review): constructor calls keep their original order. Auto-generated
    layer names presumably depend on that order, so keep it.
    """
    args = self._block_args
    expanded_filters = args.input_filters * args.expand_ratio

    def make_batch_norm():
        # Batch norm sharing this block's axis / momentum / epsilon settings.
        return self._batch_norm(
            axis=self._channel_axis,
            momentum=self._batch_norm_momentum,
            epsilon=self._batch_norm_epsilon)

    def pointwise_options(use_bias):
        # Fresh keyword arguments for a 1x1, stride-1, 'same'-padded conv.
        return dict(
            kernel_size=[1, 1],
            strides=[1, 1],
            kernel_initializer=conv_kernel_initializer,
            padding='same',
            data_format=self._data_format,
            use_bias=use_bias)

    if args.super_pixel == 1:
        # 2x2, stride-2 convolution that keeps the input channel count.
        self._superpixel = tf.layers.Conv2D(
            args.input_filters,
            kernel_size=[2, 2],
            strides=[2, 2],
            kernel_initializer=conv_kernel_initializer,
            padding='same',
            data_format=self._data_format,
            use_bias=False)
        self._bnsp = make_batch_norm()

    if args.condconv:
        # Example-dependent routing: a dense layer with one sigmoid output
        # per expert.
        self._avg_pooling = tf.keras.layers.GlobalAveragePooling2D(
            data_format=self._data_format)
        self._routing_fn = tf.layers.Dense(
            self._condconv_num_experts, activation=tf.nn.sigmoid)

    # Fused expansion phase (used when fused convolutions are enabled).
    self._fused_conv = self.conv_cls(
        filters=expanded_filters,
        kernel_size=[args.kernel_size] * 2,
        strides=args.strides,
        kernel_initializer=conv_kernel_initializer,
        padding='same',
        data_format=self._data_format,
        use_bias=False)

    # Expansion phase (used without fused convolutions when an expansion is
    # needed).
    self._expand_conv = self.conv_cls(
        filters=expanded_filters, **pointwise_options(False))
    self._bn0 = make_batch_norm()

    # Depth-wise convolution phase (used without fused convolutions).
    self._depthwise_conv = self.depthwise_conv_cls(
        kernel_size=[args.kernel_size] * 2,
        strides=args.strides,
        depthwise_initializer=conv_kernel_initializer,
        padding='same',
        data_format=self._data_format,
        use_bias=False)
    self._bn1 = make_batch_norm()

    if self._has_se:
        # Squeeze-and-excitation: reduce to at least one channel, then expand
        # back to the expanded width. Both convs carry a bias.
        se_scale = args.se_ratio * (self._se_coefficient or 1)
        num_reduced_filters = max(1, int(args.input_filters * se_scale))
        self._se_reduce = utils.Conv2D(
            num_reduced_filters, **pointwise_options(True))
        self._se_expand = utils.Conv2D(
            expanded_filters, **pointwise_options(True))

    # Output phase: project to the block's output width.
    self._project_conv = self.conv_cls(
        filters=args.output_filters, **pointwise_options(False))
    self._bn2 = make_batch_norm()