def testIsBroadcast(self):
    """Checks GroupingOpHandler._is_broadcast on three candidate ops.

    The cases below expect an op to be reported as a broadcast only when its
    OpSlice has size 1 AND that slice is mapped to an OpGroup.
    """
    handler = grouping_op_handler.GroupingOpHandler()
    self.op_group_dict = {}

    # Size is not 1.
    self.assertFalse(handler._is_broadcast(self.batch_norm_op,
                                           self.mock_op_reg_manager))

    # Size is 1 but op is not grouped.
    ungrouped_broadcast_input = tf.zeros([2, 4, 4, 1])
    # Wrap the tf.Operation (not its output tf.Tensor) so this OpSlice is built
    # the same way as the grouped case below and every other OpSlice here.
    ungrouped_broadcast_input_slice = orm.OpSlice(
        ungrouped_broadcast_input.op, orm.Slice(0, 1))
    self.op_slice_dict[ungrouped_broadcast_input.op] = [
        ungrouped_broadcast_input_slice]
    self.assertFalse(handler._is_broadcast(ungrouped_broadcast_input.op,
                                           self.mock_op_reg_manager))

    # Size is 1 and op is grouped.
    broadcast_input = tf.zeros([2, 4, 4, 1])
    broadcast_input_slice = orm.OpSlice(broadcast_input.op, orm.Slice(0, 1))
    self.op_slice_dict[broadcast_input.op] = [broadcast_input_slice]
    broadcast_input_group = orm.OpGroup(broadcast_input_slice)
    self.op_group_dict[broadcast_input_slice] = broadcast_input_group
    self.assertTrue(handler._is_broadcast(broadcast_input.op,
                                          self.mock_op_reg_manager))
def setUp(self):
    """Builds a Conv2D -> ReLU -> LogisticSigmoidGating graph and a mock manager.

    ``self.op_slice_dict`` and ``self.op_group_dict`` are not set here; each
    test assigns them before the mock manager's lookups are exercised.
    """
    super(LogisticSigmoidSourceOpHandlerTest, self).setUp()
    tf.reset_default_graph()

    image = tf.zeros([2, 4, 4, 3])
    conv1_out = layers.conv2d(image, num_outputs=5, kernel_size=3,
                              scope='conv1')
    activation_gating.logistic_sigmoid_gating(conv1_out, axis=3,
                                              is_training=True)
    op_named = tf.get_default_graph().get_operation_by_name

    # Every fixture below spans all 5 channels produced by conv1.
    self.conv_op = op_named('conv1/Conv2D')
    self.conv_op_slice = orm.OpSlice(self.conv_op, orm.Slice(0, 5))
    self.conv_op_group = orm.OpGroup(
        self.conv_op_slice, omit_source_op_slices=[self.conv_op_slice])

    self.relu_op = op_named('conv1/Relu')
    self.relu_op_slice = orm.OpSlice(self.relu_op, orm.Slice(0, 5))
    self.relu_op_group = orm.OpGroup(self.relu_op_slice)

    self.activation_gating_op = op_named(
        'logistic_sigmoid_gating/LogisticSigmoidGating')
    self.activation_gating_op_slice = orm.OpSlice(
        self.activation_gating_op, orm.Slice(0, 5))
    self.activation_gating_op_group = orm.OpGroup(
        self.activation_gating_op_slice)

    self.mask_logits_op = op_named(
        'logistic_sigmoid_gating/mask_logits/Read/ReadVariableOp')
    self.mask_logits_op_slice = orm.OpSlice(
        self.mask_logits_op, orm.Slice(0, 5))
    self.mask_logits_op_group = orm.OpGroup(
        self.mask_logits_op_slice,
        omit_source_op_slices=[self.mask_logits_op_slice])

    self.multiply_op = op_named('logistic_sigmoid_gating/Identity')
    self.multiply_op_slice = orm.OpSlice(self.multiply_op, orm.Slice(0, 5))
    self.multiply_op_group = orm.OpGroup(
        self.multiply_op_slice, omit_source_op_slices=[self.multiply_op_slice])

    # Mock manager whose slice/group lookups read the per-test dictionaries.
    manager = mock.create_autospec(orm.OpRegularizerManager)
    manager.get_op_slices.side_effect = (
        lambda op: self.op_slice_dict.get(op, []))
    manager.get_op_group.side_effect = (
        lambda op_slice: self.op_group_dict.get(op_slice))
    manager.is_source_op.return_value = False
    manager.ops = [
        self.conv_op, self.relu_op, self.activation_gating_op,
        self.multiply_op
    ]
    self.mock_op_reg_manager = manager
def testGetOpsWithoutGroups(self):
    """get_ops_without_groups returns only ops whose slices have no OpGroup."""
    ungrouped_ops = [
        self.gamma_op, self.beta_op, self.decay_op, self.epsilon_op
    ]

    self.op_slice_dict = {
        self.batch_norm_op: [self.batch_norm_op_slice],
        self.conv_op: [self.conv_op_slice],
    }
    for op in ungrouped_ops:
        self.op_slice_dict[op] = [orm.OpSlice(op, None)]

    # Only the batch norm and conv slices are mapped to groups.
    self.op_group_dict = {
        self.batch_norm_op_slice: self.batch_norm_op_group,
        self.conv_op_slice: self.conv_op_group,
    }

    candidates = [self.batch_norm_op, self.conv_op] + ungrouped_ops
    without_groups = op_handler_util.get_ops_without_groups(
        candidates, self.mock_op_reg_manager)
    self.assertEqual(ungrouped_ops, without_groups)
def testGetOpSlices(self):
    """get_op_slices returns, per output op, that op's list of OpSlices."""
    bn_op = self.unfused_batch_norm_op
    # The batch norm is split into three consecutive slices: [0:5], [5:11],
    # [11:18].
    bn_slices = [
        orm.OpSlice(bn_op, orm.Slice(0, 5)),
        orm.OpSlice(bn_op, orm.Slice(5, 6)),
        orm.OpSlice(bn_op, orm.Slice(11, 7)),
    ]
    self.op_slice_dict = {bn_op: list(bn_slices)}

    output_ops = op_handler_util.get_output_ops(self.concat_op,
                                                self.mock_op_reg_manager)
    # One inner list per output op; the concat's only output is the batch norm.
    self.assertEqual(
        [list(bn_slices)],
        op_handler_util.get_op_slices(output_ops, self.mock_op_reg_manager))
def setUp(self):
    """Builds a Conv2D -> BatchNorm -> ReLU graph and a mock manager.

    ``self.op_slice_dict`` is pre-populated with one 5-channel slice per op of
    interest. ``self.op_group_dict`` is not set here; tests assign it.
    """
    super(GroupingOpHandlerTest, self).setUp()
    tf.reset_default_graph()

    with framework.arg_scope(self._batch_norm_scope()):
        image = tf.zeros([2, 4, 4, 3])
        layers.conv2d(image, num_outputs=5, kernel_size=3, scope='conv1')
    op_named = tf.get_default_graph().get_operation_by_name

    self.batch_norm_op = op_named('conv1/BatchNorm/FusedBatchNormV3')
    self.batch_norm_op_slice = orm.OpSlice(self.batch_norm_op, orm.Slice(0, 5))
    self.batch_norm_op_group = orm.OpGroup(self.batch_norm_op_slice)

    self.conv_op = op_named('conv1/Conv2D')
    self.conv_op_slice = orm.OpSlice(self.conv_op, orm.Slice(0, 5))
    self.conv_op_group = orm.OpGroup(
        self.conv_op_slice, omit_source_op_slices=[self.conv_op_slice])

    self.relu_op = op_named('conv1/Relu')
    self.relu_op_slice = orm.OpSlice(self.relu_op, orm.Slice(0, 5))
    self.relu_op_group = orm.OpGroup(self.relu_op_slice)

    self.gamma_op = op_named('conv1/BatchNorm/gamma/read')
    self.gamma_op_slice = orm.OpSlice(self.gamma_op, orm.Slice(0, 5))

    self.beta_op = op_named('conv1/BatchNorm/beta/read')
    self.beta_op_slice = orm.OpSlice(self.beta_op, orm.Slice(0, 5))

    # Mock manager backed by the op -> slices / slice -> group dictionaries.
    manager = mock.create_autospec(orm.OpRegularizerManager)
    tracked = [
        (self.batch_norm_op, self.batch_norm_op_slice),
        (self.conv_op, self.conv_op_slice),
        (self.relu_op, self.relu_op_slice),
        (self.gamma_op, self.gamma_op_slice),
        (self.beta_op, self.beta_op_slice),
    ]
    self.op_slice_dict = {op: [op_slice] for op, op_slice in tracked}
    manager.get_op_slices.side_effect = lambda op: self.op_slice_dict.get(op)
    manager.get_op_group.side_effect = (
        lambda op_slice: self.op_group_dict.get(op_slice))
    manager.is_source_op.return_value = False
    manager.ops = [op for op, _ in tracked]
    self.mock_op_reg_manager = manager
def testMatMul2D(self, size):
    """MatMul with kernel K matches MatMul with K^T and transpose_b=True.

    Both formulations compute the same product, so their regularization
    vectors over the first ``size`` output channels must agree.
    """
    lhs = tf.zeros((13, 2))
    handler = matmul_source_op_handler.MatMulSourceOpHandler(0.1)

    kernel = tf.constant([[1, 2, 3], [4, 5, 6]], dtype=tf.float32)
    plain = tf.matmul(lhs, kernel, transpose_b=False, name='MatMul')
    plain_slice = orm.OpSlice(plain.op, orm.Slice(0, size))

    kernel_t = tf.constant([[1, 4], [2, 5], [3, 6]], dtype=tf.float32)
    transposed = tf.matmul(
        lhs, kernel_t, transpose_b=True, name='MatMulTransposedKernel')
    transposed_slice = orm.OpSlice(transposed.op, orm.Slice(0, size))

    expected = handler.create_regularizer(plain_slice).regularization_vector
    actual = handler.create_regularizer(transposed_slice).regularization_vector
    self.assertAllClose(expected, actual)
def testResliceConcatOps_NotAligned(self):
    """Concat inputs whose slices do not match the aligned sizes are resliced."""
    # relu3 is already split into two slices of size 3.
    relu3_halves = [
        orm.OpSlice(self.relu3_op, orm.Slice(0, 3)),
        orm.OpSlice(self.relu3_op, orm.Slice(3, 3)),
    ]
    self.op_slice_dict = {
        self.relu2_op: [self.relu2_op_slice],
        self.relu3_op: relu3_halves,
        self.relu4_op: [self.relu4_op_slice],
    }

    concat_inputs = [self.relu2_op, self.relu3_op, self.relu4_op]
    aligned_sizes = [5, 4, 2, 2, 5]
    op_handler_util.reslice_concat_ops(concat_inputs, aligned_sizes,
                                       self.mock_op_reg_manager)

    # relu3 is expected to be resliced to [4, 2] and relu4 to [2, 5].
    expected_calls = [
        mock.call(self.relu3_op, [4, 2]),
        mock.call(self.relu4_op, [2, 5]),
    ]
    self.mock_op_reg_manager.slice_op.assert_has_calls(expected_calls)
def testCreateRegularizer_Sliced(self):
    """A Conv2D regularizer built from a 3-channel slice has 3 entries."""
    handler = conv2d_source_op_handler.Conv2DSourceOpHandler(
        _DEFAULT_THRESHOLD)
    first_three_channels = orm.OpSlice(self.conv2_op, orm.Slice(0, 3))
    regularizer = handler.create_regularizer(first_three_channels)

    # Only the shape is checked here; most of the regularizer testing is in
    # group_lasso_regularizer_test.py.
    vector_length = regularizer.regularization_vector.shape.as_list()[0]
    self.assertEqual(3, vector_length)
def setUp(self):
    """Builds an Identity -> DepthToSpace -> Identity graph and a mock manager.

    ``self.op_slice_dict`` is pre-populated here. ``self.op_group_dict`` is not
    set; tests assign it.
    """
    tf.reset_default_graph()

    image = tf.zeros([2, 4, 4, 4])
    first_identity = tf.identity(image)
    dts_out = tf.depth_to_space(first_identity, 2)
    tf.identity(dts_out)
    op_named = tf.get_default_graph().get_operation_by_name

    # First Identity: a 4-channel slice plus one single-channel slice per
    # channel (id1_op_slice0 .. id1_op_slice3).
    self.id1_op = op_named('Identity')
    self.id1_op_slice = orm.OpSlice(self.id1_op, orm.Slice(0, 4))
    self.id1_op_group = orm.OpGroup(
        self.id1_op_slice, omit_source_op_slices=[self.id1_op_slice])
    for channel in range(4):
        setattr(self, 'id1_op_slice%d' % channel,
                orm.OpSlice(self.id1_op, orm.Slice(channel, 1)))

    # DepthToSpace with block size 2 leaves a single channel.
    self.dts_op = op_named('DepthToSpace')
    self.dts_op_slice = orm.OpSlice(self.dts_op, orm.Slice(0, 1))
    self.dts_op_group = orm.OpGroup(
        self.dts_op_slice, omit_source_op_slices=[self.dts_op_slice])

    self.id2_op = op_named('Identity_1')
    self.id2_op_slice = orm.OpSlice(self.id2_op, orm.Slice(0, 1))
    self.id2_op_group = orm.OpGroup(
        self.id2_op_slice, omit_source_op_slices=[self.id2_op_slice])

    # Mock manager backed by the op -> slices / slice -> group dictionaries.
    manager = mock.create_autospec(orm.OpRegularizerManager)
    self.op_slice_dict = {
        self.id1_op: [self.id1_op_slice],
        self.dts_op: [self.dts_op_slice],
        self.id2_op: [self.id2_op_slice],
    }
    manager.get_op_slices.side_effect = lambda op: self.op_slice_dict.get(op)
    manager.get_op_group.side_effect = (
        lambda op_slice: self.op_group_dict.get(op_slice))
    manager.is_source_op.return_value = False
    manager.ops = [self.id1_op, self.dts_op, self.id2_op]
    self.mock_op_reg_manager = manager
def testCreateRegularizer_Sliced(self):
    """A batch-norm regularizer for a 3-channel slice holds gamma[0:3]."""
    handler = batch_norm_source_op_handler.BatchNormSourceOpHandler(
        _GAMMA_THRESHOLD)
    first_three_channels = orm.OpSlice(self.batch_norm_op, orm.Slice(0, 3))
    regularizer = handler.create_regularizer(first_three_channels)

    with self.cached_session():
        # Re-fetch the existing gamma variable so it can be initialized and read.
        with tf.variable_scope('', reuse=tf.AUTO_REUSE):
            gamma = tf.get_variable('conv1/BatchNorm/gamma')
        tf.variables_initializer([gamma]).run()
        self.assertAllEqual(gamma.eval()[0:3], regularizer._gamma.eval())
def testOpHandlerDecorator(self):
    """Checks the regularizer produced through OpHandlerDecorator(DummyDecorator).

    The expected vectors depend on DummyDecorator, which is defined elsewhere
    in this file.
    """
    image = tf.constant(0.0, shape=[1, 17, 19, 3])
    kernel = tf.ones([5, 5, 3, 3])
    conv_out = tf.nn.conv2d(image, kernel, strides=[1, 1, 1, 1],
                            padding='SAME')

    wrapped_handler = op_handler_decorator.OpHandlerDecorator(
        conv_source_op_handler.ConvSourceOpHandler(1e-3, 0), DummyDecorator)
    regularizer = wrapped_handler.create_regularizer(
        orm.OpSlice(conv_out.op, orm.Slice(0, 3)))

    self.assertAllClose(0.5 * np.ones(3), regularizer.regularization_vector)
    self.assertAllClose(np.ones(3), regularizer.alive_vector)
def testCreateRegularizer_OnMask_Sliced(self):
    """With regularize_on_mask=True the regularizer equals mask[0:3]."""
    handler = ls_source_op_handler.LogisticSigmoidSourceOpHandler(
        regularize_on_mask=True)
    gating_slice = orm.OpSlice(self.activation_gating_op, orm.Slice(0, 3))
    regularizer = handler.create_regularizer(gating_slice)
    mask_tensor = tf.get_default_graph().get_tensor_by_name(
        'logistic_sigmoid_gating/LogisticSigmoidGating:0')

    with self.cached_session() as sess:
        sess.run(tf.global_variables_initializer())
        # Fetched in a single run so both values come from one graph evaluation.
        mask, reg_vec = sess.run(
            [mask_tensor, regularizer._regularization_vector])

    self.assertAllEqual(mask[0:3], reg_vec)
def testCreateRegularizer_OnLogits_Sliced(self):
    """With regularize_on_mask=False the regularizer is sigmoid(logits[0:3])."""
    # Call handler to create regularizer.
    handler = ls_source_op_handler.LogisticSigmoidSourceOpHandler(
        regularize_on_mask=False)
    activation_gating_op_slice = orm.OpSlice(self.activation_gating_op,
                                             orm.Slice(0, 3))
    regularizer = handler.create_regularizer(activation_gating_op_slice)
    g = tf.get_default_graph()
    logits_tensor = g.get_tensor_by_name(
        'logistic_sigmoid_gating/mask_logits/Read/ReadVariableOp:0')

    with self.cached_session() as sess:
        sess.run(tf.global_variables_initializer())
        logits, reg_vec = sess.run(
            [logits_tensor, regularizer._regularization_vector])
        # Verify regularizer is the sliced probability tensor. The reference
        # is computed with SciPy and the regularizer with TensorFlow, so the
        # two sigmoid implementations are compared within a tolerance rather
        # than bit-for-bit.
        self.assertAllClose(sp.special.expit(logits[0:3]), reg_vec)
def testGetOpSliceSizes(self):
    """get_op_slice_sizes maps nested OpSlice lists to nested size lists."""
    # relu3 (size 6) split into two slices of size 3.
    relu3_op_slice_0_3 = orm.OpSlice(self.relu3_op, orm.Slice(0, 3))
    relu3_op_slice_3_6 = orm.OpSlice(self.relu3_op, orm.Slice(3, 3))
    batch_norm_op_slice_0_5 = orm.OpSlice(self.unfused_batch_norm_op,
                                          orm.Slice(0, 5))
    batch_norm_op_slice_5_8 = orm.OpSlice(self.unfused_batch_norm_op,
                                          orm.Slice(5, 3))
    batch_norm_op_slice_8_11 = orm.OpSlice(self.unfused_batch_norm_op,
                                           orm.Slice(8, 3))
    batch_norm_op_slice_11_18 = orm.OpSlice(self.unfused_batch_norm_op,
                                            orm.Slice(11, 7))

    # Map ops to slices.
    self.op_slice_dict = {
        self.relu2_op: [self.relu2_op_slice],
        self.relu3_op: [relu3_op_slice_0_3, relu3_op_slice_3_6],
        self.relu4_op: [self.relu4_op_slice],
        self.unfused_batch_norm_op: [
            batch_norm_op_slice_0_5, batch_norm_op_slice_5_8,
            batch_norm_op_slice_8_11, batch_norm_op_slice_11_18
        ],
    }

    expected_op_slice_sizes = [
        [5],  # c2 has size 5.
        [3, 3],  # c3 has size 6, but in 2 slices of size 3.
        [7],  # c4 has size 7.
        # Batch norm has size 18, sliced as c2, c3 (two slices), c4.
        [5, 3, 3, 7],
    ]
    self.assertEqual(
        expected_op_slice_sizes,
        op_handler_util.get_op_slice_sizes(
            [[self.relu2_op_slice],
             [relu3_op_slice_0_3, relu3_op_slice_3_6],
             [self.relu4_op_slice],
             [
                 batch_norm_op_slice_0_5, batch_norm_op_slice_5_8,
                 batch_norm_op_slice_8_11, batch_norm_op_slice_11_18
             ]]))
def setUp(self):
    """Builds a Conv2D -> BatchNorm -> ReLU graph and a mock manager.

    ``self.op_slice_dict`` and ``self.op_group_dict`` are not set here; each
    test assigns them before the mock manager's lookups are exercised.
    """
    tf.reset_default_graph()

    with tf.contrib.framework.arg_scope(self._batch_norm_scope()):
        image = tf.zeros([2, 4, 4, 3])
        layers.conv2d(image, num_outputs=5, kernel_size=3, scope='conv1')
    op_named = tf.get_default_graph().get_operation_by_name

    def _full_slice(op):
        # Every op of interest here spans all 5 channels of conv1.
        return orm.OpSlice(op, orm.Slice(0, 5))

    def _omit_self_group(op_slice):
        # Group whose only slice is excluded from the group's source slices.
        return orm.OpGroup(op_slice, omit_source_op_slices=[op_slice])

    self.batch_norm_op = op_named('conv1/BatchNorm/FusedBatchNorm')
    self.batch_norm_op_slice = _full_slice(self.batch_norm_op)
    self.batch_norm_op_group = orm.OpGroup(self.batch_norm_op_slice)

    self.conv_op = op_named('conv1/Conv2D')
    self.conv_op_slice = _full_slice(self.conv_op)
    self.conv_op_group = _omit_self_group(self.conv_op_slice)

    self.relu_op = op_named('conv1/Relu')
    self.relu_op_slice = _full_slice(self.relu_op)
    self.relu_op_group = orm.OpGroup(self.relu_op_slice)

    self.gamma_op = op_named('conv1/BatchNorm/gamma/read')
    self.gamma_op_slice = _full_slice(self.gamma_op)
    self.gamma_op_group = _omit_self_group(self.gamma_op_slice)

    self.beta_op = op_named('conv1/BatchNorm/beta/read')
    self.beta_op_slice = _full_slice(self.beta_op)
    self.beta_op_group = _omit_self_group(self.beta_op_slice)

    self.mean_op = op_named('conv1/BatchNorm/AssignMovingAvg/sub_1')
    self.mean_op_slice = _full_slice(self.mean_op)
    self.mean_op_group = _omit_self_group(self.mean_op_slice)

    self.std_op = op_named('conv1/BatchNorm/AssignMovingAvg_1/sub_1')
    self.std_op_slice = _full_slice(self.std_op)
    self.std_op_group = _omit_self_group(self.std_op_slice)

    # Mock manager whose slice/group lookups read the per-test dictionaries.
    manager = mock.create_autospec(orm.OpRegularizerManager)
    manager.get_op_slices.side_effect = (
        lambda op: self.op_slice_dict.get(op, []))
    manager.get_op_group.side_effect = (
        lambda op_slice: self.op_group_dict.get(op_slice))
    manager.is_source_op.return_value = False
    manager.ops = [
        self.batch_norm_op, self.conv_op, self.relu_op, self.gamma_op,
        self.beta_op, self.mean_op, self.std_op
    ]
    self.mock_op_reg_manager = manager
def setUp(self):
    """Builds two small test networks and a mock OpRegularizerManager.

    Network 1: Conv2D -> BatchNorm -> ReLU under scope ``conv1``.
    Network 2: conv2 (5), conv3 (6) and conv4 (7) concatenated on axis 3 and
    fed into a batch norm whose fixture slice spans 18 channels.

    ``self.op_slice_dict``, ``self.op_group_dict`` and
    ``self._passthrough_ops`` are not set here; tests assign them.
    """
    super(OpHandlerUtilTest, self).setUp()
    tf.reset_default_graph()

    # Network 1: a Conv2D -> BatchNorm -> ReLU chain.
    with framework.arg_scope(self._batch_norm_scope()):
        image = tf.zeros([2, 4, 4, 3])
        layers.conv2d(image, num_outputs=5, kernel_size=3, scope='conv1')

    # Network 2: three Conv2D ops concatenated before a batch normalization.
    branches = [
        layers.conv2d(image, num_outputs=5, kernel_size=3, scope='conv2'),
        layers.conv2d(image, num_outputs=6, kernel_size=3, scope='conv3'),
        layers.conv2d(image, num_outputs=7, kernel_size=3, scope='conv4'),
    ]
    layers.batch_norm(tf.concat(branches, axis=3))
    op_named = tf.get_default_graph().get_operation_by_name

    # Network 1 fixtures; these OpSlices carry no explicit Slice (None).
    self.batch_norm_op = op_named('conv1/BatchNorm/FusedBatchNormV3')
    self.batch_norm_op_slice = orm.OpSlice(self.batch_norm_op, None)
    self.batch_norm_op_group = orm.OpGroup(self.batch_norm_op_slice)

    self.conv_op = op_named('conv1/Conv2D')
    self.conv_op_slice = orm.OpSlice(self.conv_op, None)
    self.conv_op_group = orm.OpGroup(
        self.conv_op_slice, omit_source_op_slices=[self.conv_op_slice])

    (self.gamma_op, self.beta_op, self.decay_op, self.epsilon_op,
     self.mean_op, self.std_op) = [
         op_named(name) for name in (
             'conv1/BatchNorm/gamma/read',
             'conv1/BatchNorm/beta/read',
             'conv1/BatchNorm/Const',
             'conv1/BatchNorm/Const_1',
             'conv1/BatchNorm/AssignMovingAvg/sub_1',
             'conv1/BatchNorm/AssignMovingAvg_1/sub_1',
         )
     ]

    self.relu_op = op_named('conv1/Relu')
    self.relu_op_slice = orm.OpSlice(self.relu_op, None)
    self.relu_op_group = orm.OpGroup(
        self.relu_op_slice, omit_source_op_slices=[self.relu_op_slice])

    # Network 2 fixtures.
    def _relu_fixture(name, size):
        # Full-width OpSlice plus a group that omits it as a source slice.
        op = op_named(name)
        op_slice = orm.OpSlice(op, orm.Slice(0, size))
        return op, op_slice, orm.OpGroup(
            op_slice, omit_source_op_slices=[op_slice])

    self.relu2_op, self.relu2_op_slice, self.relu2_op_group = (
        _relu_fixture('conv2/Relu', 5))
    self.relu3_op, self.relu3_op_slice, self.relu3_op_group = (
        _relu_fixture('conv3/Relu', 6))
    self.relu4_op, self.relu4_op_slice, self.relu4_op_group = (
        _relu_fixture('conv4/Relu', 7))

    self.unfused_batch_norm_op = op_named('BatchNorm/FusedBatchNormV3')
    self.unfused_batch_norm_op_slice = orm.OpSlice(
        self.unfused_batch_norm_op, orm.Slice(0, 18))

    self.concat_op = op_named('concat')
    self.concat_op_slice = orm.OpSlice(self.concat_op, orm.Slice(0, 18))
    self.concat_op_group = orm.OpGroup(
        self.concat_op_slice, omit_source_op_slices=[self.concat_op_slice])

    # Mock manager backed by per-test dictionaries and a passthrough set.
    manager = mock.create_autospec(orm.OpRegularizerManager)
    manager.get_op_slices.side_effect = (
        lambda op: self.op_slice_dict.get(op, []))
    manager.get_op_group.side_effect = (
        lambda op_slice: self.op_group_dict.get(op_slice))
    manager.is_passthrough.side_effect = (
        lambda op: op in self._passthrough_ops)
    manager.ops = [
        self.batch_norm_op, self.gamma_op, self.beta_op, self.decay_op,
        self.epsilon_op, self.mean_op, self.std_op, self.conv_op,
        self.relu_op, self.relu2_op, self.relu3_op, self.relu4_op,
        self.unfused_batch_norm_op, self.concat_op
    ]
    self.mock_op_reg_manager = manager
def testGroupOpWithInputsAndOutputs_MultipleSlices(self):
    """Each batch norm slice is grouped with the matching ReLU and Conv2D slice.

    For the multiple-slice case, verify that batch norm slices are grouped
    with output slices (ReLU) and input slices (Conv2D).
    """
    # Each OpSlice pairs an op with an orm.Slice(start, size) channel range.
    batch_norm_op_slice_0_2 = orm.OpSlice(self.batch_norm_op, orm.Slice(0, 2))
    batch_norm_op_slice_2_5 = orm.OpSlice(self.batch_norm_op, orm.Slice(2, 3))
    batch_norm_op_group1 = orm.OpGroup(batch_norm_op_slice_0_2)
    batch_norm_op_group2 = orm.OpGroup(batch_norm_op_slice_2_5)

    conv_op_slice_0_2 = orm.OpSlice(self.conv_op, orm.Slice(0, 2))
    conv_op_slice_2_5 = orm.OpSlice(self.conv_op, orm.Slice(2, 3))
    conv_op_group1 = orm.OpGroup(
        conv_op_slice_0_2, omit_source_op_slices=[conv_op_slice_0_2])
    conv_op_group2 = orm.OpGroup(
        conv_op_slice_2_5, omit_source_op_slices=[conv_op_slice_2_5])

    relu_op_slice_0_2 = orm.OpSlice(self.relu_op, orm.Slice(0, 2))
    relu_op_slice_2_5 = orm.OpSlice(self.relu_op, orm.Slice(2, 3))
    relu_op_group1 = orm.OpGroup(relu_op_slice_0_2)
    relu_op_group2 = orm.OpGroup(relu_op_slice_2_5)

    aligned_op_slice_sizes = [2, 3]

    self.op_slice_dict = {
        self.batch_norm_op: [batch_norm_op_slice_0_2, batch_norm_op_slice_2_5],
        self.conv_op: [conv_op_slice_0_2, conv_op_slice_2_5],
        self.relu_op: [relu_op_slice_0_2, relu_op_slice_2_5],
    }

    # All ops have groups.
    self.op_group_dict = {
        batch_norm_op_slice_0_2: batch_norm_op_group1,
        batch_norm_op_slice_2_5: batch_norm_op_group2,
        conv_op_slice_0_2: conv_op_group1,
        conv_op_slice_2_5: conv_op_group2,
        relu_op_slice_0_2: relu_op_group1,
        relu_op_slice_2_5: relu_op_group2,
    }

    ops_grouped = op_handler_util.group_op_with_inputs_and_outputs(
        self.batch_norm_op, [[conv_op_slice_0_2, conv_op_slice_2_5]],
        [[relu_op_slice_0_2, relu_op_slice_2_5]], aligned_op_slice_sizes,
        self.mock_op_reg_manager)

    # Verify manager looks up op slice for ops of interest.
    self.mock_op_reg_manager.get_op_slices.assert_any_call(self.batch_norm_op)

    # Verify manager groups batch norm with Conv2D and ReLU ops.
    self.assertTrue(ops_grouped)
    self.mock_op_reg_manager.group_op_slices.assert_has_calls([
        mock.call([batch_norm_op_slice_0_2, relu_op_slice_0_2]),
        mock.call([batch_norm_op_slice_0_2, conv_op_slice_0_2]),
        mock.call([batch_norm_op_slice_2_5, relu_op_slice_2_5]),
        mock.call([batch_norm_op_slice_2_5, conv_op_slice_2_5])
    ])
def setUp(self):
    """Builds conv1 (no normalizer) -> conv2 (with batch norm) and a mock manager.

    ``self.op_slice_dict`` and ``self.op_group_dict`` are not set here; each
    test assigns them before the mock manager's lookups are exercised.
    """
    tf.reset_default_graph()

    with tf.contrib.framework.arg_scope(self._batch_norm_scope()):
        image = tf.zeros([2, 4, 4, 3])
        conv1_out = layers.conv2d(image, num_outputs=5, kernel_size=3,
                                  scope='conv1', normalizer_fn=None)
        layers.conv2d(conv1_out, num_outputs=6, kernel_size=3, scope='conv2')
    op_named = tf.get_default_graph().get_operation_by_name

    def _fixture(name, size):
        # Full-width OpSlice plus a group that omits it as a source slice.
        op = op_named(name)
        op_slice = orm.OpSlice(op, orm.Slice(0, size))
        return op, op_slice, orm.OpGroup(
            op_slice, omit_source_op_slices=[op_slice])

    self.conv1_op, self.conv1_op_slice, self.conv1_op_group = (
        _fixture('conv1/Conv2D', 5))
    self.relu1_op, self.relu1_op_slice, self.relu1_op_group = (
        _fixture('conv1/Relu', 5))
    self.conv2_op, self.conv2_op_slice, self.conv2_op_group = (
        _fixture('conv2/Conv2D', 6))
    self.relu2_op, self.relu2_op_slice, self.relu2_op_group = (
        _fixture('conv2/Relu', 6))

    self.batch_norm_op = op_named('conv2/BatchNorm/FusedBatchNorm')
    self.batch_norm_op_slice = orm.OpSlice(self.batch_norm_op, orm.Slice(0, 6))
    self.batch_norm_op_group = orm.OpGroup(self.batch_norm_op_slice)

    manager = mock.create_autospec(orm.OpRegularizerManager)

    def _is_passthrough(op):
        # Conv ops defer to the handler's is_passthrough; batch norm is
        # passthrough; everything else is not.
        if op in [self.conv1_op, self.conv2_op]:
            handler = (
                output_non_passthrough_op_handler.OutputNonPassthroughOpHandler())
            return handler.is_passthrough
        return True if op == self.batch_norm_op else False

    manager.get_op_slices.side_effect = (
        lambda op: self.op_slice_dict.get(op, []))
    manager.get_op_group.side_effect = (
        lambda op_slice: self.op_group_dict.get(op_slice))
    manager.is_source_op.return_value = False
    manager.is_passthrough.side_effect = _is_passthrough
    manager.ops = [
        self.conv1_op, self.relu1_op, self.conv2_op, self.relu2_op,
        self.batch_norm_op
    ]
    self.mock_op_reg_manager = manager
def testAssignGrouping_AllNeighborsGrouped_InputSlicesNotAligned(self):
    """ConcatOpHandler reslices concat and batch norm to match its input slices.

    The mocked ``slice_op`` below swaps in pre-built slices, so the test can
    assert the exact lookup, slicing and grouping calls the handler makes.
    """
    # In this test, the op c2 has size 6 but is split into 2 slices of size 3.
    # The concat op (and its output, the batch norm) both have size 18. This
    # test verifies that the concat and batch norm are sliced according to the
    # sizes of c1, c2, and c3, and takes into account that c2 is also composed
    # of multiple slices.
    concat_op_slice_0_5 = orm.OpSlice(self.concat_op, orm.Slice(0, 5))
    concat_op_slice_5_8 = orm.OpSlice(self.concat_op, orm.Slice(5, 3))
    concat_op_slice_8_11 = orm.OpSlice(self.concat_op, orm.Slice(8, 3))
    concat_op_slice_11_18 = orm.OpSlice(self.concat_op, orm.Slice(11, 7))

    relu2_op_slice_0_3 = orm.OpSlice(self.relu2_op, orm.Slice(0, 3))
    relu2_op_slice_3_6 = orm.OpSlice(self.relu2_op, orm.Slice(3, 3))
    relu2_op_group1 = orm.OpGroup(
        relu2_op_slice_0_3, omit_source_op_slices=[relu2_op_slice_0_3])
    relu2_op_group2 = orm.OpGroup(
        relu2_op_slice_3_6, omit_source_op_slices=[relu2_op_slice_3_6])

    # Unsliced batch norm (size 18) before the handler runs.
    batch_norm_op_slice = orm.OpSlice(self.batch_norm_op, orm.Slice(0, 18))
    batch_norm_op_group = orm.OpGroup(
        batch_norm_op_slice, omit_source_op_slices=[batch_norm_op_slice])
    # Batch norm slices expected after reslicing to [5, 3, 3, 7].
    batch_norm_op_slice_0_5 = orm.OpSlice(self.batch_norm_op, orm.Slice(0, 5))
    batch_norm_op_group1 = orm.OpGroup(
        batch_norm_op_slice_0_5, omit_source_op_slices=[batch_norm_op_slice_0_5])
    batch_norm_op_slice_5_8 = orm.OpSlice(self.batch_norm_op, orm.Slice(5, 3))
    batch_norm_op_group2 = orm.OpGroup(
        batch_norm_op_slice_5_8, omit_source_op_slices=[batch_norm_op_slice_5_8])
    batch_norm_op_slice_8_11 = orm.OpSlice(self.batch_norm_op, orm.Slice(8, 3))
    batch_norm_op_group3 = orm.OpGroup(
        batch_norm_op_slice_8_11,
        omit_source_op_slices=[batch_norm_op_slice_8_11])
    batch_norm_op_slice_11_18 = orm.OpSlice(self.batch_norm_op,
                                            orm.Slice(11, 7))
    batch_norm_op_group4 = orm.OpGroup(
        batch_norm_op_slice_11_18,
        omit_source_op_slices=[batch_norm_op_slice_11_18])

    # Map ops to slices. The op c2 is composed of multiple slices.
    self.op_slice_dict = {
        self.relu1_op: [self.relu1_op_slice],
        self.relu2_op: [relu2_op_slice_0_3, relu2_op_slice_3_6],
        self.relu3_op: [self.relu3_op_slice],
        self.concat_op: [self.concat_op_slice],
        self.batch_norm_op: [batch_norm_op_slice],
    }

    # Map each slice to a group.
    self.op_group_dict = {
        self.relu1_op_slice: self.relu1_op_group,
        relu2_op_slice_0_3: relu2_op_group1,
        relu2_op_slice_3_6: relu2_op_group2,
        self.relu3_op_slice: self.relu3_op_group,
        batch_norm_op_slice: batch_norm_op_group,
        batch_norm_op_slice_0_5: batch_norm_op_group1,
        batch_norm_op_slice_5_8: batch_norm_op_group2,
        batch_norm_op_slice_8_11: batch_norm_op_group3,
        batch_norm_op_slice_11_18: batch_norm_op_group4,
    }

    # Update op_slice_dict when an op is sliced. The requested sizes are
    # ignored; the pre-built [5, 3, 3, 7] slices are substituted instead.
    def slice_op(op, _):
        if op == self.batch_norm_op:
            self.op_slice_dict[self.batch_norm_op] = [
                batch_norm_op_slice_0_5, batch_norm_op_slice_5_8,
                batch_norm_op_slice_8_11, batch_norm_op_slice_11_18
            ]
        if op == self.concat_op:
            self.op_slice_dict[self.concat_op] = [
                concat_op_slice_0_5, concat_op_slice_5_8, concat_op_slice_8_11,
                concat_op_slice_11_18
            ]

    self.mock_op_reg_manager.slice_op.side_effect = slice_op

    # Call handler to assign grouping.
    handler = concat_op_handler.ConcatOpHandler()
    handler.assign_grouping(self.concat_op, self.mock_op_reg_manager)

    # Verify manager looks up OpSlice for ops of interest. The order below
    # mirrors the handler's phases (annotated inline).
    self.mock_op_reg_manager.get_op_slices.assert_has_calls(
        # Checking for ops to process.
        [
            mock.call(self.relu1_op),
            mock.call(self.relu2_op),
            mock.call(self.relu3_op),
            mock.call(self.batch_norm_op),
            # Initial slice data.
            mock.call(self.concat_op),
            mock.call(self.relu1_op),
            mock.call(self.relu2_op),
            mock.call(self.relu3_op),
            mock.call(self.batch_norm_op),
            # Reslicing.
            mock.call(self.relu1_op),
            mock.call(self.relu2_op),
            mock.call(self.relu3_op),
            mock.call(self.batch_norm_op),
            mock.call(self.concat_op),
            # Refreshing slice data.
            mock.call(self.relu1_op),
            mock.call(self.relu2_op),
            mock.call(self.relu3_op),
            mock.call(self.batch_norm_op),
            # Group concat op.
            mock.call(self.concat_op)
        ])

    # Verify manager slices ops that do not have aligned OpSlice sizes.
    self.mock_op_reg_manager.slice_op.assert_has_calls([
        mock.call(self.batch_norm_op, [5, 3, 3, 7]),
        mock.call(self.concat_op, [5, 3, 3, 7])
    ])

    # Verify manager groups the new slices: each concat slice with the
    # matching input slice and the matching batch norm slice.
    self.mock_op_reg_manager.group_op_slices.assert_has_calls([
        mock.call([
            concat_op_slice_0_5, self.relu1_op_slice, batch_norm_op_slice_0_5
        ]),
        mock.call([
            concat_op_slice_5_8, relu2_op_slice_0_3, batch_norm_op_slice_5_8
        ]),
        mock.call([
            concat_op_slice_8_11, relu2_op_slice_3_6, batch_norm_op_slice_8_11
        ]),
        mock.call([
            concat_op_slice_11_18, self.relu3_op_slice,
            batch_norm_op_slice_11_18
        ])
    ])
def setUp(self):
    """Builds two chained Conv2D -> ReLU layers (no normalizer) and a mock manager.

    ``self.op_slice_dict`` and ``self.op_group_dict`` are not set here; each
    test assigns them before the mock manager's lookups are exercised.
    """
    tf.reset_default_graph()

    image = tf.zeros([2, 4, 4, 3])
    conv1_out = layers.conv2d(image, num_outputs=5, kernel_size=3,
                              scope='conv1', normalizer_fn=None)
    layers.conv2d(conv1_out, num_outputs=6, kernel_size=3, scope='conv2',
                  normalizer_fn=None)
    op_named = tf.get_default_graph().get_operation_by_name

    def _fixture(name, size):
        # Full-width OpSlice plus a group that omits it as a source slice.
        op = op_named(name)
        op_slice = orm.OpSlice(op, orm.Slice(0, size))
        return op, op_slice, orm.OpGroup(
            op_slice, omit_source_op_slices=[op_slice])

    self.conv1_op, self.conv1_op_slice, self.conv1_op_group = (
        _fixture('conv1/Conv2D', 5))
    self.relu1_op, self.relu1_op_slice, self.relu1_op_group = (
        _fixture('conv1/Relu', 5))
    self.conv2_op, self.conv2_op_slice, self.conv2_op_group = (
        _fixture('conv2/Conv2D', 6))
    (self.conv2_weights_op, self.conv2_weights_op_slice,
     self.conv2_weights_op_group) = _fixture('conv2/weights/read', 6)
    self.relu2_op, self.relu2_op_slice, self.relu2_op_group = (
        _fixture('conv2/Relu', 6))

    manager = mock.create_autospec(orm.OpRegularizerManager)

    def _is_passthrough(op):
        # Conv ops defer to Conv2DSourceOpHandler.is_passthrough; others are not
        # passthrough.
        if op not in [self.conv1_op, self.conv2_op]:
            return False
        handler = conv2d_source_op_handler.Conv2DSourceOpHandler(
            _DEFAULT_THRESHOLD)
        return handler.is_passthrough

    manager.get_op_slices.side_effect = (
        lambda op: self.op_slice_dict.get(op, []))
    manager.get_op_group.side_effect = (
        lambda op_slice: self.op_group_dict.get(op_slice))
    manager.is_source_op.return_value = False
    manager.is_passthrough.side_effect = _is_passthrough
    manager.ops = [
        self.conv1_op, self.relu1_op, self.conv2_op, self.relu2_op,
        self.conv2_weights_op
    ]
    self.mock_op_reg_manager = manager
def testAssignGrouping_NoDepthMultiplier(self):
  """Depthwise conv with depth_multiplier=1 is grouped with its input.

  With a depth multiplier of 1 the depthwise op has the same 5 channels as its
  input (conv1/Relu), so the two are grouped directly. The output op (the
  pointwise conv2 op) has no group and is expected to be queued for
  processing.
  """
  # Repeat setUp, but with depth_multiplier=1. Unfortunately, this involves
  # rebuilding the graph from scratch.
  tf.reset_default_graph()

  # This tests a Conv2D -> SeparableConv2D -> Conv2D chain of ops.
  with tf.contrib.framework.arg_scope(self._batch_norm_scope()):
    inputs = tf.zeros([2, 4, 4, 3])
    c1 = layers.conv2d(inputs, num_outputs=5, kernel_size=3, scope='conv1')
    c2 = layers.separable_conv2d(c1, num_outputs=8, kernel_size=3,
                                 depth_multiplier=1, scope='conv2')
    layers.conv2d(c2, num_outputs=6, kernel_size=3, scope='conv3')

  g = tf.get_default_graph()

  # Declare OpSlice and OpGroup for ops of interest.
  self.dwise_conv2_op = g.get_operation_by_name(
      'conv2/separable_conv2d/depthwise')
  self.dwise_conv2_op_slice = orm.OpSlice(self.dwise_conv2_op,
                                          orm.Slice(0, 5))

  self.conv2_op = g.get_operation_by_name('conv2/separable_conv2d')
  self.conv2_op_slice = orm.OpSlice(self.conv2_op, orm.Slice(0, 8))

  self.relu1_op = g.get_operation_by_name('conv1/Relu')
  self.relu1_op_slice = orm.OpSlice(self.relu1_op, orm.Slice(0, 5))
  self.relu1_op_group = orm.OpGroup(self.relu1_op_slice)

  self.conv3_op = g.get_operation_by_name('conv3/Conv2D')
  self.conv3_op_slice = orm.OpSlice(self.conv3_op, orm.Slice(0, 6))

  # Create mock OpRegularizerManager with custom mapping of OpSlice and
  # OpGroup. This replaces the manager built in setUp, so it has no
  # slice_op side effect.
  self.mock_op_reg_manager = mock.create_autospec(
      orm.OpRegularizerManager)

  self.op_slice_dict = {
      self.dwise_conv2_op: [self.dwise_conv2_op_slice],
      self.conv2_op: [self.conv2_op_slice],
      self.relu1_op: [self.relu1_op_slice],
      self.conv3_op: [self.conv3_op_slice],
  }

  def get_op_slices(op):
    return self.op_slice_dict.get(op)

  def get_op_group(op_slice):
    return self.op_group_dict.get(op_slice)

  self.mock_op_reg_manager.get_op_slices.side_effect = get_op_slices
  self.mock_op_reg_manager.get_op_group.side_effect = get_op_group
  self.mock_op_reg_manager.is_source_op.return_value = False
  self.mock_op_reg_manager.ops = [
      self.relu1_op, self.dwise_conv2_op, self.conv2_op, self.conv3_op
  ]

  # Only the input op (relu1) has a group; the output op (conv2) is
  # ungrouped.
  self.op_group_dict = {
      self.relu1_op_slice: self.relu1_op_group,
  }

  # Call handler to assign grouping.
  handler = depthwise_convolution_op_handler.DepthwiseConvolutionOpHandler()
  handler.assign_grouping(self.dwise_conv2_op, self.mock_op_reg_manager)

  # Verify manager looks up OpSlice for ops of interest.
  self.mock_op_reg_manager.get_op_slices.assert_has_calls(
      # Checking for ops to process.
      [
          mock.call(self.relu1_op),
          mock.call(self.conv2_op),
          # Initial slice data.
          mock.call(self.dwise_conv2_op),
          mock.call(self.relu1_op),
          # Reslicing.
          mock.call(self.relu1_op),
          mock.call(self.dwise_conv2_op),
          # Refreshing slice data.
          mock.call(self.relu1_op),
          # Group depthwise convolution.
          mock.call(self.dwise_conv2_op)
      ])

  # Verify manager groups the depthwise convolution with its input (relu1).
  self.mock_op_reg_manager.group_op_slices.assert_called_once_with(
      [self.dwise_conv2_op_slice, self.relu1_op_slice])

  # Verify the ungrouped output op (conv2) is queued for processing exactly
  # once, and nothing is deferred to process_ops_last.
  self.mock_op_reg_manager.process_ops.assert_called_once_with(
      [self.conv2_op])
  self.mock_op_reg_manager.process_ops_last.assert_not_called()
def setUp(self):
  tf.reset_default_graph()

  # Three 6-channel Conv2D outputs concatenated along axis 2 rather than the
  # last (channel) axis, so the concat and batch norm also have 6 channels.
  inputs = tf.zeros([2, 4, 4, 3])
  with tf.contrib.framework.arg_scope(self._get_scope()):
    conv_outputs = [
        layers.conv2d(inputs, num_outputs=6, kernel_size=3, scope=name)
        for name in ('conv1', 'conv2', 'conv3')
    ]
    net = tf.concat(conv_outputs, axis=2)
    layers.batch_norm(net)

  graph = tf.get_default_graph()

  def full_width(op_name):
    """Returns (op, 6-channel OpSlice, OpGroup omitting that slice)."""
    op = graph.get_operation_by_name(op_name)
    op_slice = orm.OpSlice(op, orm.Slice(0, 6))
    return op, op_slice, orm.OpGroup(op_slice,
                                     omit_source_op_slices=[op_slice])

  # OpSlice and OpGroup fixtures for the ops of interest.
  self.concat_op, self.concat_op_slice, self.concat_op_group = full_width(
      'concat')
  self.relu1_op, self.relu1_op_slice, self.relu1_op_group = full_width(
      'conv1/Relu')
  self.relu2_op, self.relu2_op_slice, self.relu2_op_group = full_width(
      'conv2/Relu')
  self.relu3_op, self.relu3_op_slice, self.relu3_op_group = full_width(
      'conv3/Relu')
  (self.batch_norm_op, self.batch_norm_op_slice,
   self.batch_norm_op_group) = full_width('BatchNorm/FusedBatchNorm')

  # A single group merging all of the per-op groups above.
  member_groups = [
      self.batch_norm_op_group, self.concat_op_group, self.relu1_op_group,
      self.relu2_op_group, self.relu3_op_group
  ]
  self.concat_group = orm.OpGroup(op_slice=None, op_groups=member_groups)

  # Autospec'd manager answering lookups from self.op_slice_dict and
  # self.op_group_dict, which a test must assign before querying.
  manager = mock.create_autospec(orm.OpRegularizerManager)
  self.mock_op_reg_manager = manager

  def get_op_slices(op):
    return self.op_slice_dict.get(op, [])

  def get_op_group(op_slice):
    return self.op_group_dict.get(op_slice)

  manager.get_op_slices.side_effect = get_op_slices
  manager.get_op_group.side_effect = get_op_group
  manager.is_source_op.return_value = False
  manager.is_passthrough.return_value = True
  manager.ops = [
      self.concat_op, self.relu1_op, self.relu2_op, self.relu3_op,
      self.batch_norm_op
  ]
def setUp(self):
  tf.reset_default_graph()

  # Conv2D ops with 5, 6 and 7 output channels are concatenated along axis 3
  # into an 18-channel tensor, which then feeds a batch norm.
  inputs = tf.zeros([2, 4, 4, 3])
  conv_outputs = [
      layers.conv2d(inputs, num_outputs=width, kernel_size=3, scope=scope)
      for width, scope in ((5, 'conv1'), (6, 'conv2'), (7, 'conv3'))
  ]
  net = tf.concat(conv_outputs, axis=3)
  layers.batch_norm(net)

  graph = tf.get_default_graph()

  def with_group(op_slice):
    """Returns an OpGroup over op_slice that omits it as a source slice."""
    return orm.OpGroup(op_slice, omit_source_op_slices=[op_slice])

  # Concat: the full 18-channel slice plus one slice per concatenated input.
  concat = graph.get_operation_by_name('concat')
  self.concat_op = concat
  self.concat_op_slice = orm.OpSlice(concat, orm.Slice(0, 18))
  self.concat_op_slice_0_5 = orm.OpSlice(concat, orm.Slice(0, 5))
  self.concat_op_slice_5_11 = orm.OpSlice(concat, orm.Slice(5, 6))
  self.concat_op_slice_11_18 = orm.OpSlice(concat, orm.Slice(11, 7))
  self.concat_op_group1 = with_group(self.concat_op_slice_0_5)
  self.concat_op_group2 = with_group(self.concat_op_slice_5_11)
  self.concat_op_group3 = with_group(self.concat_op_slice_11_18)

  # ReLU outputs of the three convolutions.
  self.relu1_op = graph.get_operation_by_name('conv1/Relu')
  self.relu1_op_slice = orm.OpSlice(self.relu1_op, orm.Slice(0, 5))
  self.relu1_op_group = with_group(self.relu1_op_slice)

  self.relu2_op = graph.get_operation_by_name('conv2/Relu')
  self.relu2_op_slice = orm.OpSlice(self.relu2_op, orm.Slice(0, 6))
  self.relu2_op_group = with_group(self.relu2_op_slice)

  self.relu3_op = graph.get_operation_by_name('conv3/Relu')
  self.relu3_op_slice = orm.OpSlice(self.relu3_op, orm.Slice(0, 7))
  self.relu3_op_group = with_group(self.relu3_op_slice)

  self.axis_op = graph.get_operation_by_name('concat/axis')

  # Batch norm: the full 18-channel slice/group, plus slices aligned with the
  # concat inputs.
  batch_norm = graph.get_operation_by_name('BatchNorm/FusedBatchNorm')
  self.batch_norm_op = batch_norm
  self.batch_norm_op_slice = orm.OpSlice(batch_norm, orm.Slice(0, 18))
  self.batch_norm_op_group = with_group(self.batch_norm_op_slice)
  self.batch_norm_op_slice_0_5 = orm.OpSlice(batch_norm, orm.Slice(0, 5))
  self.batch_norm_op_slice_5_11 = orm.OpSlice(batch_norm, orm.Slice(5, 6))
  self.batch_norm_op_slice_11_18 = orm.OpSlice(batch_norm, orm.Slice(11, 7))
  self.batch_norm_op_group1 = with_group(self.batch_norm_op_slice_0_5)
  self.batch_norm_op_group2 = with_group(self.batch_norm_op_slice_5_11)
  self.batch_norm_op_group3 = with_group(self.batch_norm_op_slice_11_18)

  # Autospec'd manager answering lookups from self.op_slice_dict and
  # self.op_group_dict, which a test must assign before querying.
  manager = mock.create_autospec(orm.OpRegularizerManager)
  self.mock_op_reg_manager = manager

  def get_op_slices(op):
    return self.op_slice_dict.get(op, [])

  def get_op_group(op_slice):
    return self.op_group_dict.get(op_slice)

  def slice_op(op, _):
    # Reslicing batch norm or concat replaces its entry with the three
    # input-aligned slices. The requested sizes are ignored.
    if op == self.batch_norm_op:
      self.op_slice_dict[self.batch_norm_op] = [
          self.batch_norm_op_slice_0_5, self.batch_norm_op_slice_5_11,
          self.batch_norm_op_slice_11_18
      ]
    if op == self.concat_op:
      self.op_slice_dict[self.concat_op] = [
          self.concat_op_slice_0_5, self.concat_op_slice_5_11,
          self.concat_op_slice_11_18
      ]

  manager.get_op_slices.side_effect = get_op_slices
  manager.get_op_group.side_effect = get_op_group
  manager.is_source_op.return_value = False
  manager.slice_op.side_effect = slice_op
  manager.is_passthrough.return_value = True
  manager.ops = [
      self.concat_op, self.relu1_op, self.relu2_op, self.relu3_op,
      self.batch_norm_op
  ]
def testAssignGrouping_AllNeighborsGrouped_OutputSlicesNotAligned(self):
  """Concat, batch norm and ReLU ops are resliced into a common alignment.

  The output (batch norm) is initially sliced as sizes [9, 4, 5], which do
  not line up with the concat inputs [5, 6, 7]. The handler is expected to
  slice every op into sizes [5, 4, 2, 2, 5] and group the resulting slices
  position by position.
  """
  # Concat and ReLU slices produced by reslicing.
  concat_op_slice_0_5 = orm.OpSlice(self.concat_op, orm.Slice(0, 5))
  concat_op_slice_5_9 = orm.OpSlice(self.concat_op, orm.Slice(5, 4))
  concat_op_slice_9_11 = orm.OpSlice(self.concat_op, orm.Slice(9, 2))
  concat_op_slice_11_13 = orm.OpSlice(self.concat_op, orm.Slice(11, 2))
  concat_op_slice_13_18 = orm.OpSlice(self.concat_op, orm.Slice(13, 5))

  relu2_op_slice_0_4 = orm.OpSlice(self.relu2_op, orm.Slice(0, 4))
  relu2_op_slice_4_6 = orm.OpSlice(self.relu2_op, orm.Slice(4, 2))

  relu3_op_slice_0_2 = orm.OpSlice(self.relu3_op, orm.Slice(0, 2))
  relu3_op_slice_2_7 = orm.OpSlice(self.relu3_op, orm.Slice(2, 5))

  # Initial (misaligned) batch norm layout: [0, 9), [9, 13), [13, 18).
  batch_norm_op_slice_0_9 = orm.OpSlice(self.batch_norm_op, orm.Slice(0, 9))
  batch_norm_op_group1 = orm.OpGroup(
      batch_norm_op_slice_0_9,
      omit_source_op_slices=[batch_norm_op_slice_0_9])
  batch_norm_op_slice_9_13 = orm.OpSlice(self.batch_norm_op,
                                         orm.Slice(9, 4))
  batch_norm_op_group2 = orm.OpGroup(
      batch_norm_op_slice_9_13,
      omit_source_op_slices=[batch_norm_op_slice_9_13])
  batch_norm_op_slice_13_18 = orm.OpSlice(self.batch_norm_op,
                                          orm.Slice(13, 5))
  batch_norm_op_group3 = orm.OpGroup(
      batch_norm_op_slice_13_18,
      omit_source_op_slices=[batch_norm_op_slice_13_18])

  # Resliced batch norm layout: [0, 5), [5, 9), [9, 11), [11, 13), [13, 18).
  # The trailing [13, 18) slice is identical in both layouts, so it and its
  # group are reused. Previously it was redefined here with the wrong start
  # index (11), which also produced a duplicate dict key below.
  batch_norm_op_slice_0_5 = orm.OpSlice(self.batch_norm_op, orm.Slice(0, 5))
  batch_norm_op_group4 = orm.OpGroup(
      batch_norm_op_slice_0_5,
      omit_source_op_slices=[batch_norm_op_slice_0_5])
  batch_norm_op_slice_5_9 = orm.OpSlice(self.batch_norm_op, orm.Slice(5, 4))
  batch_norm_op_group5 = orm.OpGroup(
      batch_norm_op_slice_5_9,
      omit_source_op_slices=[batch_norm_op_slice_5_9])
  batch_norm_op_slice_9_11 = orm.OpSlice(self.batch_norm_op,
                                         orm.Slice(9, 2))
  batch_norm_op_group6 = orm.OpGroup(
      batch_norm_op_slice_9_11,
      omit_source_op_slices=[batch_norm_op_slice_9_11])
  batch_norm_op_slice_11_13 = orm.OpSlice(self.batch_norm_op,
                                          orm.Slice(11, 2))
  batch_norm_op_group7 = orm.OpGroup(
      batch_norm_op_slice_11_13,
      omit_source_op_slices=[batch_norm_op_slice_11_13])

  # Map ops to slices. Batch norm op is composed of multiple slices.
  self.op_slice_dict = {
      self.relu1_op: [self.relu1_op_slice],
      self.relu2_op: [self.relu2_op_slice],
      self.relu3_op: [self.relu3_op_slice],
      self.concat_op: [self.concat_op_slice],
      self.batch_norm_op: [
          batch_norm_op_slice_0_9, batch_norm_op_slice_9_13,
          batch_norm_op_slice_13_18
      ],
  }

  # Map each slice to a group (one entry per distinct slice).
  self.op_group_dict = {
      self.relu1_op_slice: self.relu1_op_group,
      self.relu2_op_slice: self.relu2_op_group,
      self.relu3_op_slice: self.relu3_op_group,
      batch_norm_op_slice_0_9: batch_norm_op_group1,
      batch_norm_op_slice_9_13: batch_norm_op_group2,
      batch_norm_op_slice_13_18: batch_norm_op_group3,
      batch_norm_op_slice_0_5: batch_norm_op_group4,
      batch_norm_op_slice_5_9: batch_norm_op_group5,
      batch_norm_op_slice_9_11: batch_norm_op_group6,
      batch_norm_op_slice_11_13: batch_norm_op_group7,
  }

  # Update op_slice_dict when an op is sliced.
  def slice_op(op, _):
    if op == self.batch_norm_op:
      self.op_slice_dict[self.batch_norm_op] = [
          batch_norm_op_slice_0_5, batch_norm_op_slice_5_9,
          batch_norm_op_slice_9_11, batch_norm_op_slice_11_13,
          batch_norm_op_slice_13_18
      ]
    if op == self.concat_op:
      self.op_slice_dict[self.concat_op] = [
          concat_op_slice_0_5, concat_op_slice_5_9, concat_op_slice_9_11,
          concat_op_slice_11_13, concat_op_slice_13_18
      ]
    if op == self.relu2_op:
      self.op_slice_dict[self.relu2_op] = [
          relu2_op_slice_0_4, relu2_op_slice_4_6
      ]
    if op == self.relu3_op:
      self.op_slice_dict[self.relu3_op] = [
          relu3_op_slice_0_2, relu3_op_slice_2_7
      ]

  self.mock_op_reg_manager.slice_op.side_effect = slice_op

  # Call handler to assign grouping.
  handler = concat_op_handler.ConcatOpHandler()
  handler.assign_grouping(self.concat_op, self.mock_op_reg_manager)

  # Verify manager looks up OpSlice for ops of interest.
  self.mock_op_reg_manager.get_op_slices.assert_has_calls(
      # Checking for ops to process.
      [
          mock.call(self.relu1_op),
          mock.call(self.relu2_op),
          mock.call(self.relu3_op),
          mock.call(self.batch_norm_op),
          # Initial slice data.
          mock.call(self.concat_op),
          mock.call(self.relu1_op),
          mock.call(self.relu2_op),
          mock.call(self.relu3_op),
          mock.call(self.batch_norm_op),
          # Reslicing.
          mock.call(self.relu1_op),
          mock.call(self.relu2_op),
          mock.call(self.relu3_op),
          mock.call(self.batch_norm_op),
          mock.call(self.concat_op),
          # Refreshing slice data.
          mock.call(self.relu1_op),
          mock.call(self.relu2_op),
          mock.call(self.relu3_op),
          mock.call(self.batch_norm_op),
          # Group concat op.
          mock.call(self.concat_op)
      ])

  # Verify manager slices ops that do not have aligned OpSlice sizes.
  self.mock_op_reg_manager.slice_op.assert_has_calls([
      mock.call(self.relu2_op, [4, 2]),
      mock.call(self.relu3_op, [2, 5]),
      mock.call(self.batch_norm_op, [5, 4, 2, 2, 5]),
      mock.call(self.concat_op, [5, 4, 2, 2, 5])
  ])

  # Verify manager groups the new slices.
  self.mock_op_reg_manager.group_op_slices.assert_has_calls([
      mock.call([
          concat_op_slice_0_5, self.relu1_op_slice, batch_norm_op_slice_0_5
      ]),
      mock.call([
          concat_op_slice_5_9, relu2_op_slice_0_4, batch_norm_op_slice_5_9
      ]),
      mock.call([
          concat_op_slice_9_11, relu2_op_slice_4_6, batch_norm_op_slice_9_11
      ]),
      mock.call([
          concat_op_slice_11_13, relu3_op_slice_0_2, batch_norm_op_slice_11_13
      ]),
      mock.call([
          concat_op_slice_13_18, relu3_op_slice_2_7, batch_norm_op_slice_13_18
      ])
  ])
def setUp(self):
  tf.reset_default_graph()

  # Conv2D -> SeparableConv2D (depth_multiplier=2) -> Conv2D chain of ops.
  with tf.contrib.framework.arg_scope(self._batch_norm_scope()):
    zeros = tf.zeros([2, 4, 4, 3])
    conv1_out = layers.conv2d(zeros, num_outputs=5, kernel_size=3,
                              scope='conv1')
    conv2_out = layers.separable_conv2d(conv1_out, num_outputs=8,
                                        kernel_size=3, depth_multiplier=2,
                                        scope='conv2')
    layers.conv2d(conv2_out, num_outputs=6, kernel_size=3, scope='conv3')

  graph = tf.get_default_graph()

  def unit_slices(op, count):
    """Returns OpSlices of width 1 for channels 0 .. count - 1 of op."""
    return [orm.OpSlice(op, orm.Slice(index, 1)) for index in range(count)]

  # Depthwise op: 5 input channels times depth multiplier 2 gives 10.
  self.dwise_conv2_op = graph.get_operation_by_name(
      'conv2/separable_conv2d/depthwise')
  self.dwise_conv2_op_slice = orm.OpSlice(self.dwise_conv2_op,
                                          orm.Slice(0, 10))
  (self.dwise_conv2_op_slice_0_1, self.dwise_conv2_op_slice_1_2,
   self.dwise_conv2_op_slice_2_3, self.dwise_conv2_op_slice_3_4,
   self.dwise_conv2_op_slice_4_5, self.dwise_conv2_op_slice_5_6,
   self.dwise_conv2_op_slice_6_7, self.dwise_conv2_op_slice_7_8,
   self.dwise_conv2_op_slice_8_9,
   self.dwise_conv2_op_slice_9_10) = unit_slices(self.dwise_conv2_op, 10)

  self.conv2_op = graph.get_operation_by_name('conv2/separable_conv2d')
  self.conv2_op_slice = orm.OpSlice(self.conv2_op, orm.Slice(0, 8))

  self.relu1_op = graph.get_operation_by_name('conv1/Relu')
  self.relu1_op_slice = orm.OpSlice(self.relu1_op, orm.Slice(0, 5))
  (self.relu1_op_slice_0_1, self.relu1_op_slice_1_2,
   self.relu1_op_slice_2_3, self.relu1_op_slice_3_4,
   self.relu1_op_slice_4_5) = unit_slices(self.relu1_op, 5)
  self.relu1_op_group = orm.OpGroup(self.relu1_op_slice)

  self.conv3_op = graph.get_operation_by_name('conv3/Conv2D')
  self.conv3_op_slice = orm.OpSlice(self.conv3_op, orm.Slice(0, 6))

  # Autospec'd manager. Slices come from self.op_slice_dict (seeded below);
  # groups come from self.op_group_dict, which a test must assign.
  manager = mock.create_autospec(orm.OpRegularizerManager)
  self.mock_op_reg_manager = manager

  self.op_slice_dict = {
      self.dwise_conv2_op: [self.dwise_conv2_op_slice],
      self.conv2_op: [self.conv2_op_slice],
      self.relu1_op: [self.relu1_op_slice],
      self.conv3_op: [self.conv3_op_slice],
  }

  def get_op_slices(op):
    return self.op_slice_dict.get(op)

  def get_op_group(op_slice):
    return self.op_group_dict.get(op_slice)

  def slice_op(op, _):
    # Reslicing the depthwise op or relu1 splits it into single-channel
    # slices. The requested sizes are ignored.
    if op == self.dwise_conv2_op:
      self.op_slice_dict[self.dwise_conv2_op] = [
          self.dwise_conv2_op_slice_0_1, self.dwise_conv2_op_slice_1_2,
          self.dwise_conv2_op_slice_2_3, self.dwise_conv2_op_slice_3_4,
          self.dwise_conv2_op_slice_4_5, self.dwise_conv2_op_slice_5_6,
          self.dwise_conv2_op_slice_6_7, self.dwise_conv2_op_slice_7_8,
          self.dwise_conv2_op_slice_8_9, self.dwise_conv2_op_slice_9_10
      ]
    if op == self.relu1_op:
      self.op_slice_dict[self.relu1_op] = [
          self.relu1_op_slice_0_1, self.relu1_op_slice_1_2,
          self.relu1_op_slice_2_3, self.relu1_op_slice_3_4,
          self.relu1_op_slice_4_5
      ]

  manager.get_op_slices.side_effect = get_op_slices
  manager.get_op_group.side_effect = get_op_group
  manager.is_source_op.return_value = False
  manager.slice_op.side_effect = slice_op
  manager.ops = [
      self.relu1_op, self.dwise_conv2_op, self.conv2_op, self.conv3_op
  ]
def testAssignGrouping_ProcessNeighborGroupsWithSlices(self):
  """Batch norm and neighbors already split at [0, 2)/[2, 3) group per slice."""

  def split(op, omit_source):
    """Returns (slice_0_2, slice_2_3, group_0_2, group_2_3) for op.

    When omit_source is True, each group omits its own slice from its source
    slices.
    """
    head = orm.OpSlice(op, orm.Slice(0, 2))
    tail = orm.OpSlice(op, orm.Slice(2, 1))
    if omit_source:
      head_group = orm.OpGroup(head, omit_source_op_slices=[head])
      tail_group = orm.OpGroup(tail, omit_source_op_slices=[tail])
    else:
      head_group = orm.OpGroup(head)
      tail_group = orm.OpGroup(tail)
    return head, tail, head_group, tail_group

  (batch_norm_op_slice_0_2, batch_norm_op_slice_2_3, batch_norm_op_group1,
   batch_norm_op_group2) = split(self.batch_norm_op, omit_source=False)
  (conv_op_slice_0_2, conv_op_slice_2_3, conv_op_group1,
   conv_op_group2) = split(self.conv_op, omit_source=True)
  (relu_op_slice_0_2, relu_op_slice_2_3, relu_op_group1,
   relu_op_group2) = split(self.relu_op, omit_source=False)
  (gamma_op_slice_0_2, gamma_op_slice_2_3, gamma_op_group1,
   gamma_op_group2) = split(self.gamma_op, omit_source=True)
  (beta_op_slice_0_2, beta_op_slice_2_3, beta_op_group1,
   beta_op_group2) = split(self.beta_op, omit_source=True)
  (mean_op_slice_0_2, mean_op_slice_2_3, mean_op_group1,
   mean_op_group2) = split(self.mean_op, omit_source=True)
  (std_op_slice_0_2, std_op_slice_2_3, std_op_group1,
   std_op_group2) = split(self.std_op, omit_source=True)

  self.op_slice_dict = {
      self.batch_norm_op: [batch_norm_op_slice_0_2, batch_norm_op_slice_2_3],
      self.conv_op: [conv_op_slice_0_2, conv_op_slice_2_3],
      self.relu_op: [relu_op_slice_0_2, relu_op_slice_2_3],
      self.gamma_op: [gamma_op_slice_0_2, gamma_op_slice_2_3],
      self.beta_op: [beta_op_slice_0_2, beta_op_slice_2_3],
      self.mean_op: [mean_op_slice_0_2, mean_op_slice_2_3],
      self.std_op: [std_op_slice_0_2, std_op_slice_2_3],
  }

  # Every OpSlice already belongs to a group.
  self.op_group_dict = {
      batch_norm_op_slice_0_2: batch_norm_op_group1,
      batch_norm_op_slice_2_3: batch_norm_op_group2,
      conv_op_slice_0_2: conv_op_group1,
      conv_op_slice_2_3: conv_op_group2,
      relu_op_slice_0_2: relu_op_group1,
      relu_op_slice_2_3: relu_op_group2,
      gamma_op_slice_0_2: gamma_op_group1,
      gamma_op_slice_2_3: gamma_op_group2,
      beta_op_slice_0_2: beta_op_group1,
      beta_op_slice_2_3: beta_op_group2,
      mean_op_slice_0_2: mean_op_group1,
      mean_op_slice_2_3: mean_op_group2,
      std_op_slice_0_2: std_op_group1,
      std_op_slice_2_3: std_op_group2,
  }

  # Run the handler on the batch norm op.
  handler = batch_norm_source_op_handler.BatchNormSourceOpHandler(
      _GAMMA_THRESHOLD)
  handler.assign_grouping(self.batch_norm_op, self.mock_op_reg_manager)
  manager = self.mock_op_reg_manager

  # The handler must have looked up slices for the batch norm and both of
  # its neighbors.
  for looked_up_op in (self.batch_norm_op, self.conv_op, self.relu_op):
    manager.get_op_slices.assert_any_call(looked_up_op)

  # Output-side and input-side groupings are issued slice by slice.
  manager.group_op_slices.assert_has_calls([
      mock.call([
          batch_norm_op_slice_0_2, relu_op_slice_0_2, mean_op_slice_0_2,
          std_op_slice_0_2
      ]),
      mock.call([
          batch_norm_op_slice_0_2, conv_op_slice_0_2, gamma_op_slice_0_2,
          beta_op_slice_0_2
      ],
                omit_source_op_slices=[]),
      mock.call([
          batch_norm_op_slice_2_3, relu_op_slice_2_3, mean_op_slice_2_3,
          std_op_slice_2_3
      ]),
      mock.call([
          batch_norm_op_slice_2_3, conv_op_slice_2_3, gamma_op_slice_2_3,
          beta_op_slice_2_3
      ],
                omit_source_op_slices=[])
  ])

  # No op is queued for (re)processing.
  manager.process_ops.assert_not_called()
  manager.process_ops_last.assert_not_called()