def mbv1_block_(inputs,
                filters,
                is_training,
                stride,
                width=1.,
                block_id=0,
                pruning_method='baseline',
                data_format='channels_first',
                weight_decay=0.):
    """Depthwise-separable building block for mobilenetv1 networks.

    Runs a 3x3 depthwise convolution and then a 1x1 convolution. The output of
    each convolution goes through `resnet_model.batch_norm_relu` with the relu
    enabled.

    Args:
      inputs: Input tensor, float32 or bfloat16 of size [batch, channels,
        height, width].
      filters: Int base channel count for the 1x1 convolution. It is
        multiplied by `width`, truncated to int and passed through
        `_make_divisible` to obtain the actual number of filters.
      is_training: Boolean specifying whether the model is training.
      stride: Int stride handed to the depthwise convolution. If stride > 1,
        the input is downsampled.
      width: Multiplier for channel dimensions.
      block_id: Which block this is. Used to build the layer names, and block
        0 uses a divisor of 1 instead of 8 for the channel count.
      pruning_method: String that specifies the pruning method used to
        identify which weights to remove. Only forwarded to the 1x1
        convolution.
      data_format: String that specifies either "channels_first" for [batch,
        channels, height, width] or "channels_last" for [batch, height, width,
        channels].
      weight_decay: Weight for the l2 regularization loss. Only forwarded to
        the 1x1 convolution.

    Returns:
      The output activation tensor.
    """
    # Stage 1: 3x3 depthwise convolution; this is where the stride is applied.
    depthwise_name = 'depthwise_nxn_%s' % block_id
    net = depthwise_conv2d_fixed_padding(
        inputs=inputs,
        kernel_size=3,
        stride=stride,
        data_format=data_format,
        name=depthwise_name)
    net = resnet_model.batch_norm_relu(
        net, is_training, relu=True, data_format=data_format)

    # Stage 2: 1x1 convolution that sets the output channel count.
    pointwise_name = 'contraction_1x1_%s' % block_id
    # Block 0 keeps the scaled channel count as is (divisor 1); all other
    # blocks use a divisor of 8.
    channel_divisor = 1 if block_id == 0 else 8
    out_filters = _make_divisible(
        int(width * filters), divisor=channel_divisor)
    net = conv2d_fixed_padding(
        inputs=net,
        filters=out_filters,
        kernel_size=1,
        strides=1,
        pruning_method=pruning_method,
        data_format=data_format,
        weight_decay=weight_decay,
        name=pointwise_name)
    return resnet_model.batch_norm_relu(
        net, is_training, relu=True, data_format=data_format)
def model(inputs, is_training):
    """Creation of the model graph.

    Builds, in order: a padded 3x3 stride-2 stem convolution, thirteen
    `mbv1_block_` blocks, an average pool whose window covers the whole
    spatial extent, a reshape to [-1, channels] and a final dense layer with
    `num_classes` units.

    The following names are not parameters or locals of this function and
    must be supplied by the surrounding scope: `name`, `data_format`, `width`,
    `weight_decay`, `pruning_method`, `prune_last_layer`, `num_classes`.

    Args:
      inputs: Input tensor, laid out according to `data_format`.
      is_training: Boolean specifying whether the model is training. Forwarded
        to batch norm and to every block.

    Returns:
      The output tensor of the final dense layer.

    Raises:
      ValueError: If `data_format` is neither 'channels_first' nor
        'channels_last' when the pooling window is computed.
    """
    # (filters, stride) of each mobilenetv1 block; the position in the tuple
    # is the block_id.
    block_specs = (
        (64, 1),
        (128, 2),
        (128, 1),
        (256, 2),
        (256, 1),
        (512, 2),
        (512, 1),
        (512, 1),
        (512, 1),
        (512, 1),
        (512, 1),
        (1024, 2),
        (1024, 1),
    )

    # NOTE(review): the source's indentation was lost; everything up to the
    # final identity op is assumed to live inside this variable scope --
    # confirm against the original file.
    with tf.variable_scope(name, 'resnet_model'):
        # Stem: explicit padding followed by a VALID 3x3 stride-2 convolution.
        inputs = resnet_model.fixed_padding(
            inputs, kernel_size=3, data_format=data_format)
        padding = 'VALID'
        kernel_initializer = tf.variance_scaling_initializer()
        kernel_regularizer = contrib_layers.l2_regularizer(weight_decay)
        inputs = tf.layers.conv2d(
            inputs=inputs,
            filters=_make_divisible(32 * width),
            kernel_size=3,
            strides=2,
            padding=padding,
            use_bias=False,
            kernel_initializer=kernel_initializer,
            kernel_regularizer=kernel_regularizer,
            data_format=data_format,
            name='initial_conv')
        inputs = tf.identity(inputs, 'initial_conv')
        inputs = resnet_model.batch_norm_relu(
            inputs, is_training, data_format=data_format)

        # Body: the blocks share every setting except filters/stride/id.
        mb_block = functools.partial(
            mbv1_block_,
            is_training=is_training,
            width=width,
            pruning_method=pruning_method,
            data_format=data_format,
            weight_decay=weight_decay)
        for block_id, (block_filters, block_stride) in enumerate(block_specs):
            inputs = mb_block(
                inputs,
                filters=block_filters,
                stride=block_stride,
                block_id=block_id)

        # Channel count used for the reshape after pooling.
        last_block_filters = _make_divisible(int(1024 * width), 8)

        # The pooling window spans both spatial axes, so the pool collapses
        # each feature map to a single value.
        if data_format == 'channels_last':
            pool_size = (inputs.shape[1], inputs.shape[2])
        elif data_format == 'channels_first':
            pool_size = (inputs.shape[2], inputs.shape[3])
        else:
            # Without this branch pool_size would be unbound below; raise the
            # same error inverted_res_block_ uses for a bad data_format.
            raise ValueError('Unknown data_format ' + str(data_format))
        inputs = tf.layers.average_pooling2d(
            inputs=inputs,
            pool_size=pool_size,
            strides=1,
            padding='VALID',
            data_format=data_format,
            name='final_avg_pool')
        inputs = tf.identity(inputs, 'final_avg_pool')
        inputs = tf.reshape(inputs, [-1, last_block_filters])

        # Head: dense layer, sparse or regular depending on prune_last_layer.
        kernel_initializer = tf.variance_scaling_initializer()
        kernel_regularizer = contrib_layers.l2_regularizer(weight_decay)
        if prune_last_layer:
            # prune_last_layer is known to be true here, so the pruning
            # method is passed through unconditionally.
            inputs = sparse_fully_connected(
                x=inputs,
                units=num_classes,
                sparsity_technique=pruning_method,
                kernel_initializer=kernel_initializer,
                kernel_regularizer=kernel_regularizer,
                name='final_dense')
        else:
            inputs = tf.layers.dense(
                inputs=inputs,
                units=num_classes,
                activation=None,
                use_bias=True,
                kernel_initializer=kernel_initializer,
                kernel_regularizer=kernel_regularizer,
                name='final_dense')
        inputs = tf.identity(inputs, 'final_dense')
    return inputs
def inverted_res_block_(
    inputs,
    filters,
    is_training,
    stride,
    width=1.,
    expansion_factor=6.,
    block_id=0,
    pruning_method='baseline',
    data_format='channels_first',
    weight_decay=0.,
):
    """Inverted residual building block for mobilenetv2 networks.

    Applies an optional 1x1 expansion convolution (skipped for block 0), a 3x3
    depthwise convolution and a contracting 1x1 convolution. The first two
    stages go through `resnet_model.batch_norm_relu` with relu enabled; the
    contraction uses it with relu disabled. The block input is added to the
    output when the channel count is unchanged and stride is 1.

    Args:
      inputs: Input tensor, float32 or bfloat16 of size [batch, channels,
        height, width].
      filters: Int base channel count for the contracting 1x1 convolution. It
        is multiplied by `width`, truncated to int and passed through
        `_make_divisible` to obtain the actual number of filters.
      is_training: Boolean specifying whether the model is training.
      stride: Int stride handed to the depthwise convolution. If stride > 1,
        the input is downsampled.
      width: Multiplier for channel dimensions.
      expansion_factor: How much to increase the filters before the depthwise
        conv.
      block_id: Which block this is. Used to build the layer names; block 0
        skips the expansion and uses a divisor of 1 instead of 8 for the
        output channel count.
      pruning_method: String that specifies the pruning method used to
        identify which weights to remove. Forwarded to the two 1x1
        convolutions, not to the depthwise convolution.
      data_format: String that specifies either "channels_first" for [batch,
        channels, height, width] or "channels_last" for [batch, height, width,
        channels].
      weight_decay: Weight for the l2 regularization loss. Forwarded to the
        two 1x1 convolutions.

    Returns:
      The output activation tensor.

    Raises:
      ValueError: If `data_format` is neither 'channels_first' nor
        'channels_last'.
    """
    # Validate data_format before touching `inputs`.
    if data_format == 'channels_first':
        channel_axis = 1
    elif data_format == 'channels_last':
        channel_axis = 3
    else:
        # str() keeps this a ValueError even when data_format is not a string
        # (e.g. None); plain concatenation would raise TypeError instead.
        raise ValueError('Unknown data_format ' + str(data_format))

    shortcut = inputs
    prev_depth = inputs.get_shape().as_list()[channel_axis]

    # Expand: skipped when this is the first block (block_id == 0).
    if block_id:
        multiplier = expansion_factor if block_id > 0 else 1
        end_point = 'expand_1x1_%s' % block_id
        inputs = conv2d_fixed_padding(
            inputs=inputs,
            filters=int(multiplier * prev_depth),
            kernel_size=1,
            strides=1,
            pruning_method=pruning_method,
            data_format=data_format,
            weight_decay=weight_decay,
            name=end_point)
        inputs = resnet_model.batch_norm_relu(
            inputs, is_training, relu=True, data_format=data_format)

    # Depthwise: this is where the stride is applied.
    end_point = 'depthwise_nxn_%s' % block_id
    depthwise_out = depthwise_conv2d_fixed_padding(
        inputs=inputs,
        kernel_size=3,
        stride=stride,
        data_format=data_format,
        name=end_point)
    depthwise_out = resnet_model.batch_norm_relu(
        depthwise_out, is_training, relu=True, data_format=data_format)

    # Contraction: 1x1 convolution followed by batch norm without relu.
    end_point = 'contraction_1x1_%s' % block_id
    # Block 0 uses a divisor of 1; every other block uses 8.
    divisible_by = 1 if block_id == 0 else 8
    out_filters = _make_divisible(int(width * filters), divisor=divisible_by)
    contraction_out = conv2d_fixed_padding(
        inputs=depthwise_out,
        filters=out_filters,
        kernel_size=1,
        strides=1,
        pruning_method=pruning_method,
        data_format=data_format,
        weight_decay=weight_decay,
        name=end_point)
    contraction_out = resnet_model.batch_norm_relu(
        contraction_out, is_training, relu=False, data_format=data_format)

    # Residual connection, only when the channel count is unchanged and the
    # block does not downsample.
    output = contraction_out
    if prev_depth == out_filters and stride == 1:
        output += shortcut
    return output