# Assumed module-level imports for these snippets (TF 1.x with tf-slim):
# numpy as np, tensorflow as tf, slim = tf.contrib.slim, slim's resnet_utils,
# and the project-local flopsometer and resnet_act modules. SACT_KERNEL_SIZE
# and INIT_BIAS are module-level constants defined elsewhere in the project.


def resnet(inputs,
           model,
           num_classes,
           model_type='vanilla',
           base_channels=16,
           scope='resnet_residual'):
  """Builds a CIFAR-10 resnet model."""
  num_blocks = 3
  num_units = model
  if len(num_units) == 1:
    num_units *= num_blocks
  assert len(num_units) == num_blocks

  b = resnet_utils.Block
  bc = base_channels
  # Each args tuple is (depth, stride, activate_before_residual), consumed by
  # the residual() unit function.
  blocks = [
      b('block_1', residual,
        [(bc, 1, True)] + [(bc, 1, False)] * (num_units[0] - 1)),
      b('block_2', residual,
        [(2 * bc, 2, False)] + [(2 * bc, 1, False)] * (num_units[1] - 1)),
      b('block_3', residual,
        [(4 * bc, 2, False)] + [(4 * bc, 1, False)] * (num_units[2] - 1)),
  ]

  with tf.variable_scope(scope, 'resnet_residual', [inputs]):
    end_points = {'inputs': inputs}
    end_points['flops'] = 0

    net = inputs
    net, current_flops = flopsometer.conv2d(
        net, bc, 3, activation_fn=None, normalizer_fn=None)
    end_points['flops'] += current_flops

    net, end_points = resnet_act.stack_blocks(
        net, blocks, model_type=model_type, end_points=end_points)

    # Global average pooling, then a final batch norm and 1x1 conv that
    # produces the logits.
    net = tf.reduce_mean(net, [1, 2], keep_dims=True)
    net = slim.batch_norm(net)
    net, current_flops = flopsometer.conv2d(
        net,
        num_classes, [1, 1],
        activation_fn=None,
        normalizer_fn=None,
        scope='logits')
    end_points['flops'] += current_flops
    net = tf.squeeze(net, [1, 2], name='SpatialSqueeze')

    return net, end_points
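# Hypothetical usage sketch, not part of the original code: builds the
# CIFAR-10 model above with 18 residual units per block (a ResNet-110) and
# reads the accumulated FLOPs estimate back out of the end points. The batch
# shape and the 'sact' model_type value are assumptions based on the
# surrounding code.
def _example_build_cifar_resnet():
  images = tf.zeros([8, 32, 32, 3])  # hypothetical CIFAR-10-sized batch
  logits, end_points = resnet(
      images, model=[18], num_classes=10, model_type='sact')
  return logits, end_points['flops']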
def testConv2d(self):
  inputs = tf.zeros([2, 16, 16, 4])
  _, flops = flopsometer.conv2d(
      inputs, 8, [3, 3], stride=1, padding='SAME', output_mask=None)
  # Each multiply-add counts as two FLOPs:
  # 2 * H * W * kernel_h * kernel_w * out_channels * in_channels.
  expected_flops = 2 * 16 * 16 * 3 * 3 * 8 * 4
  with self.test_session() as sess:
    sess.run(tf.global_variables_initializer())
    flops_out = sess.run(flops)
    self.assertAllEqual(flops_out, [expected_flops, expected_flops])
def resnet_v2(inputs,
              blocks,
              num_classes=None,
              global_pool=True,
              model_type='vanilla',
              scope=None,
              reuse=None,
              end_points=None):
  with tf.variable_scope(scope, 'resnet_v2', [inputs], reuse=reuse):
    if end_points is None:
      end_points = {}
    end_points['inputs'] = inputs
    end_points['flops'] = end_points.get('flops', 0)
    net = inputs
    # We do not include batch normalization or activation functions in conv1
    # because the first ResNet unit will perform these. Cf. Appendix of [2].
    with slim.arg_scope([slim.conv2d], activation_fn=None, normalizer_fn=None):
      net, current_flops = flopsometer.conv2d_same(
          net, 64, 7, stride=2, scope='conv1')
      end_points['flops'] += current_flops
    net = slim.max_pool2d(net, [3, 3], stride=2, scope='pool1')

    # Early stopping is broken in distributed training.
    net, end_points = resnet_act.stack_blocks(
        net, blocks, model_type=model_type, end_points=end_points)

    if global_pool or num_classes is not None:
      # This is needed because the pre-activation variant does not have batch
      # normalization or activation functions in the residual unit output. See
      # Appendix of [2].
      net = slim.batch_norm(net, activation_fn=tf.nn.relu, scope='postnorm')

    if global_pool:
      # Global average pooling.
      net = tf.reduce_mean(net, [1, 2], name='pool5', keep_dims=True)

    if num_classes is not None:
      net, current_flops = flopsometer.conv2d(
          net,
          num_classes, [1, 1],
          activation_fn=None,
          normalizer_fn=None,
          scope='logits')
      end_points['flops'] += current_flops
      end_points['predictions'] = slim.softmax(net, scope='predictions')

    return net, end_points
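# Hypothetical usage sketch, not part of the original code: assembles the
# block description for a ResNet-50-style network, mirroring the layout of
# slim's resnet_v2_50. Each tuple passed to resnet_utils.Block here is
# (depth, depth_bottleneck, stride) and is consumed by bottleneck() below.
def _example_build_resnet_v2_50(images, num_classes=1001):
  blocks = [
      resnet_utils.Block('block1', bottleneck,
                         [(256, 64, 1)] * 2 + [(256, 64, 2)]),
      resnet_utils.Block('block2', bottleneck,
                         [(512, 128, 1)] * 3 + [(512, 128, 2)]),
      resnet_utils.Block('block3', bottleneck,
                         [(1024, 256, 1)] * 5 + [(1024, 256, 2)]),
      resnet_utils.Block('block4', bottleneck, [(2048, 512, 1)] * 3),
  ]
  return resnet_v2(
      images, blocks, num_classes=num_classes, model_type='vanilla')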
def testConv2dUnknownSize(self):
  inputs = np.zeros([2, 16, 16, 4], dtype=np.float32)
  inputs_tf = tf.placeholder(tf.float32, shape=(2, None, None, 4))
  _, flops = flopsometer.conv2d(
      inputs_tf, 8, [3, 3], stride=1, padding='SAME', output_mask=None)
  # The spatial size is unknown at graph-construction time, so the FLOPs
  # count must be derived from the runtime shape of the fed tensor.
  expected_flops = 2 * 16 * 16 * 3 * 3 * 8 * 4
  with self.test_session() as sess:
    sess.run(tf.global_variables_initializer())
    flops_out = sess.run(flops, feed_dict={inputs_tf: inputs})
    self.assertAllEqual(flops_out, [expected_flops, expected_flops])
def get_halting_proba_conv(outputs, residual_mask=None):
  with tf.variable_scope('halting_proba'):
    flops = 0
    x = outputs

    # Local (per-position) halting logits.
    local_feature = slim.batch_norm(x, scope='local_bn')
    halting_logit, current_flops = flopsometer.conv2d(
        local_feature,
        1,
        SACT_KERNEL_SIZE,
        activation_fn=None,
        normalizer_fn=None,
        biases_initializer=tf.constant_initializer(INIT_BIAS),
        output_mask=residual_mask,
        scope='local_conv')
    flops += current_flops

    # Add a global halting logit, computed from globally average-pooled
    # features.
    global_feature = tf.reduce_mean(x, [1, 2], keep_dims=True)
    global_feature = slim.batch_norm(global_feature, scope='global_bn')
    halting_logit_global, current_flops = flopsometer.conv2d(
        global_feature,
        1,
        1,
        activation_fn=None,
        normalizer_fn=None,
        biases_initializer=None,  # Biases are already present in local logits.
        scope='global_conv')
    flops += current_flops

    # Addition with broadcasting over spatial dimensions.
    halting_logit += halting_logit_global

    halting_proba = tf.sigmoid(halting_logit)

    return halting_proba, flops
def get_halting_proba(outputs):
  with tf.variable_scope('halting_proba'):
    x = outputs
    x = tf.reduce_mean(x, [1, 2], keep_dims=True)
    x = slim.batch_norm(x, scope='global_bn')
    halting_proba, flops = flopsometer.conv2d(
        x,
        1,
        1,
        activation_fn=tf.nn.sigmoid,
        normalizer_fn=None,
        biases_initializer=tf.constant_initializer(INIT_BIAS),
        scope='global_conv')
    halting_proba = tf.squeeze(halting_proba, [1, 2])
    return halting_proba, flops
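# Illustrative sketch, an assumption rather than repo code: shows how the
# halting probabilities computed above drive adaptive computation time.
# Units run until the cumulative halting probability exceeds 1 - eps; the
# halting unit receives the remainder, so the unit weights sum to one.
def _example_act_weights(halting_probas, eps=0.01):
  weights = []
  cumulative = 0.0
  for proba in halting_probas:
    if cumulative + proba > 1.0 - eps:
      weights.append(1.0 - cumulative)  # remainder goes to the halting unit
      break
    weights.append(proba)
    cumulative += proba
  return weights

# For example, _example_act_weights([0.1, 0.3, 0.7]) returns [0.1, 0.3, 0.6]:
# the third unit halts because 0.1 + 0.3 + 0.7 exceeds 0.99.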
def testConv2dOutputMask(self):
  inputs = tf.zeros([2, 16, 16, 4])
  mask = np.random.random([2, 16, 16]) <= 0.6
  mask_tf = tf.constant(np.float32(mask))
  _, flops = flopsometer.conv2d(
      inputs, 8, [3, 3], stride=1, padding='SAME', output_mask=mask_tf)
  # Only masked-in positions are counted, so the expected FLOPs scale with
  # the number of active spatial positions in each example.
  per_position_flops = 2 * 3 * 3 * 8 * 4
  num_positions = np.sum(np.sum(np.int32(mask), 2), 1)
  expected_flops = [
      per_position_flops * num_positions[0],
      per_position_flops * num_positions[1]
  ]
  with self.test_session() as sess:
    sess.run(tf.global_variables_initializer())
    flops_out = sess.run(flops)
    self.assertAllEqual(flops_out, expected_flops)
def residual(inputs,
             depth,
             stride,
             activate_before_residual,
             residual_mask=None,
             scope=None):
  with tf.variable_scope(scope, 'residual', [inputs]):
    depth_in = slim.utils.last_dimension(inputs.get_shape(), min_rank=4)
    preact = slim.batch_norm(inputs, scope='preact')
    if activate_before_residual:
      shortcut = preact
    else:
      shortcut = inputs

    if residual_mask is not None:
      # Max-pooling trick only works correctly when stride is 1.
      # We assume that stride=2 happens in the first layer where
      # residual_mask is None.
      assert stride == 1
      diluted_residual_mask = slim.max_pool2d(
          residual_mask, [3, 3], stride=1, padding='SAME')
    else:
      diluted_residual_mask = None

    flops = 0

    conv_output, current_flops = flopsometer.conv2d(
        preact,
        depth,
        3,
        stride=stride,
        padding='SAME',
        output_mask=diluted_residual_mask,
        scope='conv1')
    flops += current_flops

    conv_output, current_flops = flopsometer.conv2d(
        conv_output,
        depth,
        3,
        stride=1,
        padding='SAME',
        activation_fn=None,
        normalizer_fn=None,
        output_mask=residual_mask,
        scope='conv2')
    flops += current_flops

    if depth_in != depth:
      # Parameter-free shortcut: subsample spatially with average pooling and
      # zero-pad the channel dimension to match the new depth.
      shortcut = slim.avg_pool2d(shortcut, stride, stride, padding='VALID')
      value = (depth - depth_in) // 2
      shortcut = tf.pad(shortcut, [[0, 0], [0, 0], [0, 0], [value, value]])

    if residual_mask is not None:
      conv_output *= residual_mask

    outputs = shortcut + conv_output

    return outputs, flops
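# Illustrative sketch, an assumption rather than repo code: the 3x3 max-pool
# above "dilates" the residual mask because a masked-in output of the second
# 3x3 conv reads a 3x3 neighborhood of the first conv's output, so conv1 must
# be evaluated at every position adjacent to an active one. A NumPy version
# of that dilation for a single-example, single-channel mask:
def _example_dilate_mask(mask):
  height, width = mask.shape
  padded = np.pad(mask, 1, mode='constant')  # zero-pad, matching 'SAME'
  dilated = np.zeros_like(mask)
  for i in range(height):
    for j in range(width):
      dilated[i, j] = padded[i:i + 3, j:j + 3].max()
  return dilated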
def bottleneck(inputs,
               depth,
               depth_bottleneck,
               stride,
               rate=1,
               residual_mask=None,
               scope=None):
  with tf.variable_scope(scope, 'bottleneck_v2', [inputs]):
    flops = 0
    depth_in = slim.utils.last_dimension(inputs.get_shape(), min_rank=4)
    preact = slim.batch_norm(inputs, activation_fn=tf.nn.relu, scope='preact')
    if depth == depth_in:
      shortcut = resnet_utils.subsample(inputs, stride, 'shortcut')
    else:
      shortcut, current_flops = flopsometer.conv2d(
          preact,
          depth, [1, 1],
          stride=stride,
          normalizer_fn=None,
          activation_fn=None,
          scope='shortcut')
      flops += current_flops

    if residual_mask is not None:
      # Max-pooling trick only works correctly when stride is 1.
      # We assume that stride=2 happens in the first layer where
      # residual_mask is None.
      assert stride == 1
      diluted_residual_mask = slim.max_pool2d(
          residual_mask, [3, 3], stride=1, padding='SAME')
    else:
      diluted_residual_mask = None

    residual, current_flops = flopsometer.conv2d(
        preact,
        depth_bottleneck, [1, 1],
        stride=1,
        output_mask=diluted_residual_mask,
        scope='conv1')
    flops += current_flops

    residual, current_flops = flopsometer.conv2d_same(
        residual,
        depth_bottleneck,
        3,
        stride,
        rate=rate,
        output_mask=residual_mask,
        scope='conv2')
    flops += current_flops

    residual, current_flops = flopsometer.conv2d(
        residual,
        depth, [1, 1],
        stride=1,
        normalizer_fn=None,
        activation_fn=None,
        output_mask=residual_mask,
        scope='conv3')
    flops += current_flops

    if residual_mask is not None:
      residual *= residual_mask

    outputs = shortcut + residual

    return outputs, flops