def setUp(self):
  """Builds a Conv2D -> ReLU -> LogisticSigmoidGating graph and a mock manager.

  Every declared OpSlice covers channels [0, 5) of its op. The mock manager's
  get_op_slices/get_op_group side effects read `self.op_slice_dict` and
  `self.op_group_dict`; this method does not create those dictionaries, so
  each test is expected to assign them.
  """
  super(LogisticSigmoidSourceOpHandlerTest, self).setUp()
  tf.reset_default_graph()

  # Network under test: conv1 (5 output channels) -> ReLU -> sigmoid gating.
  conv1_out = layers.conv2d(
      tf.zeros([2, 4, 4, 3]), num_outputs=5, kernel_size=3, scope='conv1')
  activation_gating.logistic_sigmoid_gating(
      conv1_out, axis=3, is_training=True)

  op_by_name = tf.get_default_graph().get_operation_by_name

  def lookup_with_slice(name):
    """Returns the named op together with an OpSlice over its 5 channels."""
    op = op_by_name(name)
    return op, orm.OpSlice(op, orm.Slice(0, 5))

  # Conv2D: its group omits the conv slice as a source.
  self.conv_op, self.conv_op_slice = lookup_with_slice('conv1/Conv2D')
  self.conv_op_group = orm.OpGroup(
      self.conv_op_slice, omit_source_op_slices=[self.conv_op_slice])

  # ReLU: plain group.
  self.relu_op, self.relu_op_slice = lookup_with_slice('conv1/Relu')
  self.relu_op_group = orm.OpGroup(self.relu_op_slice)

  # The gating op itself: plain group.
  self.activation_gating_op, self.activation_gating_op_slice = (
      lookup_with_slice('logistic_sigmoid_gating/LogisticSigmoidGating'))
  self.activation_gating_op_group = orm.OpGroup(
      self.activation_gating_op_slice)

  # mask_logits read op: its group omits the slice as a source.
  self.mask_logits_op, self.mask_logits_op_slice = lookup_with_slice(
      'logistic_sigmoid_gating/mask_logits/Read/ReadVariableOp')
  self.mask_logits_op_group = orm.OpGroup(
      self.mask_logits_op_slice,
      omit_source_op_slices=[self.mask_logits_op_slice])

  # Op named 'Identity' under the gating scope (kept as multiply_op): its
  # group omits the slice as a source.
  self.multiply_op, self.multiply_op_slice = lookup_with_slice(
      'logistic_sigmoid_gating/Identity')
  self.multiply_op_group = orm.OpGroup(
      self.multiply_op_slice, omit_source_op_slices=[self.multiply_op_slice])

  # Mock manager whose lookups are answered from the per-test dictionaries.
  manager = mock.create_autospec(orm.OpRegularizerManager)
  manager.get_op_slices.side_effect = (
      lambda op: self.op_slice_dict.get(op, []))
  manager.get_op_group.side_effect = (
      lambda op_slice: self.op_group_dict.get(op_slice))
  manager.is_source_op.return_value = False
  # Note: mask_logits_op is not listed in `ops`.
  manager.ops = [
      self.conv_op, self.relu_op, self.activation_gating_op, self.multiply_op
  ]
  self.mock_op_reg_manager = manager
def setUp(self):
  """Builds a Conv2D -> BatchNorm -> ReLU graph (scope 'conv1') and a mock.

  Each op of interest gets one OpSlice spanning channels [0, 5), and
  `self.op_slice_dict` maps each op to that single slice. The mock's
  get_op_group side effect reads `self.op_group_dict`, which this method
  does not create; tests assign it.
  """
  super(GroupingOpHandlerTest, self).setUp()
  tf.reset_default_graph()

  # Network under test: conv1 with batch norm and ReLU from the arg scope.
  with framework.arg_scope(self._batch_norm_scope()):
    layers.conv2d(
        tf.zeros([2, 4, 4, 3]), num_outputs=5, kernel_size=3, scope='conv1')

  lookup = tf.get_default_graph().get_operation_by_name

  def full_width_slice(op):
    """Returns an OpSlice covering channels [0, 5) of `op`."""
    return orm.OpSlice(op, orm.Slice(0, 5))

  self.batch_norm_op = lookup('conv1/BatchNorm/FusedBatchNormV3')
  self.batch_norm_op_slice = full_width_slice(self.batch_norm_op)
  self.batch_norm_op_group = orm.OpGroup(self.batch_norm_op_slice)

  self.conv_op = lookup('conv1/Conv2D')
  self.conv_op_slice = full_width_slice(self.conv_op)
  self.conv_op_group = orm.OpGroup(
      self.conv_op_slice, omit_source_op_slices=[self.conv_op_slice])

  self.relu_op = lookup('conv1/Relu')
  self.relu_op_slice = full_width_slice(self.relu_op)
  self.relu_op_group = orm.OpGroup(self.relu_op_slice)

  # gamma and beta only need slices; no groups are declared for them.
  self.gamma_op = lookup('conv1/BatchNorm/gamma/read')
  self.gamma_op_slice = full_width_slice(self.gamma_op)
  self.beta_op = lookup('conv1/BatchNorm/beta/read')
  self.beta_op_slice = full_width_slice(self.beta_op)

  tracked = [
      (self.batch_norm_op, self.batch_norm_op_slice),
      (self.conv_op, self.conv_op_slice),
      (self.relu_op, self.relu_op_slice),
      (self.gamma_op, self.gamma_op_slice),
      (self.beta_op, self.beta_op_slice),
  ]

  # Mock OpRegularizerManager backed by the test's slice/group dictionaries.
  self.mock_op_reg_manager = mock.create_autospec(orm.OpRegularizerManager)
  self.op_slice_dict = {op: [op_slice] for op, op_slice in tracked}

  manager = self.mock_op_reg_manager
  manager.get_op_slices.side_effect = lambda op: self.op_slice_dict.get(op)
  manager.get_op_group.side_effect = (
      lambda op_slice: self.op_group_dict.get(op_slice))
  manager.is_source_op.return_value = False
  manager.ops = [op for op, _ in tracked]
def setUp(self):
  """Builds an Identity -> DepthToSpace -> Identity graph and a mock manager.

  The input has 4 channels; depth_to_space with block size 2 leaves a single
  channel, so the DepthToSpace and second Identity slices have size 1. The
  mock's get_op_group side effect reads `self.op_group_dict`, which tests
  assign themselves.
  """
  tf.reset_default_graph()

  # Ops are created in the order zeros -> Identity -> DepthToSpace ->
  # Identity_1; the op names looked up below rely on that order.
  tf.identity(tf.depth_to_space(tf.identity(tf.zeros([2, 4, 4, 4])), 2))

  op_by_name = tf.get_default_graph().get_operation_by_name

  def declare(name, size):
    """Returns (op, OpSlice [0, size), OpGroup omitting that slice)."""
    op = op_by_name(name)
    whole = orm.OpSlice(op, orm.Slice(0, size))
    return op, whole, orm.OpGroup(whole, omit_source_op_slices=[whole])

  self.id1_op, self.id1_op_slice, self.id1_op_group = declare('Identity', 4)
  # One single-channel OpSlice per input channel of the first Identity.
  (self.id1_op_slice0, self.id1_op_slice1, self.id1_op_slice2,
   self.id1_op_slice3) = [
       orm.OpSlice(self.id1_op, orm.Slice(channel, 1))
       for channel in range(4)
   ]
  self.dts_op, self.dts_op_slice, self.dts_op_group = declare(
      'DepthToSpace', 1)
  self.id2_op, self.id2_op_slice, self.id2_op_group = declare(
      'Identity_1', 1)

  tracked = [
      (self.id1_op, self.id1_op_slice),
      (self.dts_op, self.dts_op_slice),
      (self.id2_op, self.id2_op_slice),
  ]

  # Mock OpRegularizerManager backed by the test's slice/group dictionaries.
  self.mock_op_reg_manager = mock.create_autospec(orm.OpRegularizerManager)
  self.op_slice_dict = {op: [op_slice] for op, op_slice in tracked}

  manager = self.mock_op_reg_manager
  manager.get_op_slices.side_effect = lambda op: self.op_slice_dict.get(op)
  manager.get_op_group.side_effect = (
      lambda op_slice: self.op_group_dict.get(op_slice))
  manager.is_source_op.return_value = False
  manager.ops = [op for op, _ in tracked]
def testIsBroadcast(self):
  """Checks GroupingOpHandler._is_broadcast on three kinds of ops.

  Cases: the batch norm op (size != 1, expected False), a 1-channel op with
  no OpGroup (expected False) and a 1-channel op that has an OpGroup
  (expected True).
  """
  handler = grouping_op_handler.GroupingOpHandler()
  self.op_group_dict = {}

  # Size is not 1.
  self.assertFalse(
      handler._is_broadcast(self.batch_norm_op, self.mock_op_reg_manager))

  # Size is 1 but op is not grouped.
  ungrouped_broadcast_input = tf.zeros([2, 4, 4, 1])
  # An OpSlice wraps the tf.Operation (`.op`), not the tf.Tensor; this
  # matches the op_slice_dict key below and the grouped case.
  ungrouped_broadcast_input_slice = orm.OpSlice(
      ungrouped_broadcast_input.op, orm.Slice(0, 1))
  self.op_slice_dict[ungrouped_broadcast_input.op] = [
      ungrouped_broadcast_input_slice
  ]
  self.assertFalse(
      handler._is_broadcast(ungrouped_broadcast_input.op,
                            self.mock_op_reg_manager))

  # Size is 1 and op is grouped.
  broadcast_input = tf.zeros([2, 4, 4, 1])
  broadcast_input_slice = orm.OpSlice(broadcast_input.op, orm.Slice(0, 1))
  self.op_slice_dict[broadcast_input.op] = [broadcast_input_slice]
  broadcast_input_group = orm.OpGroup(broadcast_input_slice)
  self.op_group_dict[broadcast_input_slice] = broadcast_input_group
  self.assertTrue(
      handler._is_broadcast(broadcast_input.op, self.mock_op_reg_manager))
def testGetInputSourceOpsToOmit_IsSource(self):
  """Only the input whose group is a source group is reported for omission.

  ReLU3 and concat are given fresh groups built without omitted sources,
  i.e. they act as source groups. ReLU2/ReLU4 keep their setUp groups. The
  expected result is just ReLU3's slice.
  """
  # ReLU3 and concat are source ops now.
  relu3_source_group = orm.OpGroup(self.relu3_op_slice)
  concat_source_group = orm.OpGroup(self.concat_op_slice)
  self.op_group_dict = {
      self.relu2_op_slice: self.relu2_op_group,
      self.relu3_op_slice: relu3_source_group,
      self.relu4_op_slice: self.relu4_op_group,
      self.concat_op_slice: concat_source_group,
  }

  omitted = op_handler_util._get_input_source_ops_to_omit(
      [self.relu2_op_slice, self.relu3_op_slice, self.relu4_op_slice],
      self.concat_op_slice, self.mock_op_reg_manager)

  self.assertEqual([self.relu3_op_slice], omitted)
def testGroupOpWithInputsAndOutputs_MultipleSlices(self):
  """Batch norm slices are grouped slice-by-slice with outputs and inputs.

  Batch norm, Conv2D and ReLU are each split into a [0, 2) slice and a
  [2, 5) slice (aligned sizes [2, 3]). Every slice already has a group. The
  test expects each batch norm slice to be grouped first with the matching
  ReLU (output) slice and then with the matching Conv2D (input) slice.
  """
  # For the multiple slice case, verify that batch norm slices are grouped
  # with output slices (ReLU) and input slices (Conv2D).
  # The second OpSlice argument must be an orm.Slice(start, size).
  batch_norm_op_slice_0_2 = orm.OpSlice(self.batch_norm_op, orm.Slice(0, 2))
  batch_norm_op_slice_2_5 = orm.OpSlice(self.batch_norm_op, orm.Slice(2, 3))
  batch_norm_op_group1 = orm.OpGroup(batch_norm_op_slice_0_2)
  batch_norm_op_group2 = orm.OpGroup(batch_norm_op_slice_2_5)

  conv_op_slice_0_2 = orm.OpSlice(self.conv_op, orm.Slice(0, 2))
  conv_op_slice_2_5 = orm.OpSlice(self.conv_op, orm.Slice(2, 3))
  conv_op_group1 = orm.OpGroup(
      conv_op_slice_0_2, omit_source_op_slices=[conv_op_slice_0_2])
  conv_op_group2 = orm.OpGroup(
      conv_op_slice_2_5, omit_source_op_slices=[conv_op_slice_2_5])

  relu_op_slice_0_2 = orm.OpSlice(self.relu_op, orm.Slice(0, 2))
  relu_op_slice_2_5 = orm.OpSlice(self.relu_op, orm.Slice(2, 3))
  relu_op_group1 = orm.OpGroup(relu_op_slice_0_2)
  relu_op_group2 = orm.OpGroup(relu_op_slice_2_5)

  aligned_op_slice_sizes = [2, 3]

  self.op_slice_dict = {
      self.batch_norm_op: [batch_norm_op_slice_0_2, batch_norm_op_slice_2_5],
      self.conv_op: [conv_op_slice_0_2, conv_op_slice_2_5],
      self.relu_op: [relu_op_slice_0_2, relu_op_slice_2_5],
  }

  # All ops have groups.
  self.op_group_dict = {
      batch_norm_op_slice_0_2: batch_norm_op_group1,
      batch_norm_op_slice_2_5: batch_norm_op_group2,
      conv_op_slice_0_2: conv_op_group1,
      conv_op_slice_2_5: conv_op_group2,
      relu_op_slice_0_2: relu_op_group1,
      relu_op_slice_2_5: relu_op_group2,
  }

  ops_grouped = op_handler_util.group_op_with_inputs_and_outputs(
      self.batch_norm_op, [[conv_op_slice_0_2, conv_op_slice_2_5]],
      [[relu_op_slice_0_2, relu_op_slice_2_5]], aligned_op_slice_sizes,
      self.mock_op_reg_manager)

  # Verify manager looks up op slice for ops of interest.
  self.mock_op_reg_manager.get_op_slices.assert_any_call(self.batch_norm_op)

  # Verify manager groups batch norm with Conv2D and ReLU ops.
  self.assertTrue(ops_grouped)
  self.mock_op_reg_manager.group_op_slices.assert_has_calls([
      mock.call([batch_norm_op_slice_0_2, relu_op_slice_0_2]),
      mock.call([batch_norm_op_slice_0_2, conv_op_slice_0_2],
                omit_source_op_slices=[]),
      mock.call([batch_norm_op_slice_2_5, relu_op_slice_2_5]),
      mock.call([batch_norm_op_slice_2_5, conv_op_slice_2_5],
                omit_source_op_slices=[])
  ])
def testAssignGrouping_NeighborsHaveSameGroup_ReprocessSources(self):
  """Batch norm is regrouped with its Conv2D input when that input is a source.

  ReLU, gamma and beta already share the batch norm's group. The Conv2D
  slice sits in its own freshly built group, and is_source_op reports only
  Conv2D as a source. The handler is expected to:
  - group batch norm with conv/gamma/beta while omitting the conv slice as a
    source;
  - queue the ungrouped mean/std ops via process_ops.
  """
  conv_source_group = orm.OpGroup(self.conv_op_slice)

  # Each op of interest is represented by a single OpSlice.
  self.op_slice_dict = {
      op: [op_slice] for op, op_slice in (
          (self.batch_norm_op, self.batch_norm_op_slice),
          (self.conv_op, self.conv_op_slice),
          (self.relu_op, self.relu_op_slice),
          (self.gamma_op, self.gamma_op_slice),
          (self.beta_op, self.beta_op_slice),
          (self.mean_op, self.mean_op_slice),
          (self.std_op, self.std_op_slice),
      )
  }

  # Everything except Conv2D shares the batch norm group; mean/std have none.
  shared_group = self.batch_norm_op_group
  self.op_group_dict = {
      self.batch_norm_op_slice: shared_group,
      self.conv_op_slice: conv_source_group,
      self.relu_op_slice: shared_group,
      self.gamma_op_slice: shared_group,
      self.beta_op_slice: shared_group,
  }

  # Only Conv2D is reported as a source op.
  source_ops = (self.conv_op,)
  manager = self.mock_op_reg_manager
  manager.is_source_op.side_effect = lambda op: op in source_ops

  # Call handler to assign grouping.
  batch_norm_source_op_handler.BatchNormSourceOpHandler(
      _GAMMA_THRESHOLD).assign_grouping(self.batch_norm_op, manager)

  # The manager must have looked up slices for the ops of interest.
  for looked_up_op in (self.batch_norm_op, self.conv_op, self.relu_op):
    manager.get_op_slices.assert_any_call(looked_up_op)

  # Batch norm is grouped with its inputs, overriding the conv source.
  manager.group_op_slices.assert_called_once_with(
      [
          self.batch_norm_op_slice, self.conv_op_slice, self.gamma_op_slice,
          self.beta_op_slice
      ],
      omit_source_op_slices=[self.conv_op_slice])

  # Ungrouped output ops are queued; nothing is deferred.
  manager.process_ops.assert_called_once_with([self.mean_op, self.std_op])
  manager.process_ops_last.assert_not_called()
def setUp(self):
  """Builds Conv2D -> SeparableConv2D -> Conv2D and a mock manager.

  conv1 has 5 channels; the separable conv uses depth_multiplier=2, so its
  depthwise op has 10 channels and its pointwise op 8; conv3 has 6. Per-
  channel OpSlices are pre-built for the depthwise op and for conv1's ReLU.
  The mock's slice_op side effect swaps each op's entry in
  `self.op_slice_dict` to those per-channel slices.
  """
  tf.reset_default_graph()

  # This tests a Conv2D -> SeparableConv2D -> Conv2D chain of ops.
  with tf.contrib.framework.arg_scope(self._batch_norm_scope()):
    net = layers.conv2d(
        tf.zeros([2, 4, 4, 3]), num_outputs=5, kernel_size=3, scope='conv1')
    net = layers.separable_conv2d(
        net, num_outputs=8, kernel_size=3, depth_multiplier=2, scope='conv2')
    layers.conv2d(net, num_outputs=6, kernel_size=3, scope='conv3')

  find_op = tf.get_default_graph().get_operation_by_name

  # Depthwise op: one full slice plus ten single-channel slices.
  self.dwise_conv2_op = find_op('conv2/separable_conv2d/depthwise')
  self.dwise_conv2_op_slice = orm.OpSlice(self.dwise_conv2_op,
                                          orm.Slice(0, 10))
  (self.dwise_conv2_op_slice_0_1, self.dwise_conv2_op_slice_1_2,
   self.dwise_conv2_op_slice_2_3, self.dwise_conv2_op_slice_3_4,
   self.dwise_conv2_op_slice_4_5, self.dwise_conv2_op_slice_5_6,
   self.dwise_conv2_op_slice_6_7, self.dwise_conv2_op_slice_7_8,
   self.dwise_conv2_op_slice_8_9, self.dwise_conv2_op_slice_9_10) = [
       orm.OpSlice(self.dwise_conv2_op, orm.Slice(channel, 1))
       for channel in range(10)
   ]

  # Pointwise op of the separable conv.
  self.conv2_op = find_op('conv2/separable_conv2d')
  self.conv2_op_slice = orm.OpSlice(self.conv2_op, orm.Slice(0, 8))

  # conv1 ReLU: one full slice plus five single-channel slices, and a group.
  self.relu1_op = find_op('conv1/Relu')
  self.relu1_op_slice = orm.OpSlice(self.relu1_op, orm.Slice(0, 5))
  (self.relu1_op_slice_0_1, self.relu1_op_slice_1_2, self.relu1_op_slice_2_3,
   self.relu1_op_slice_3_4, self.relu1_op_slice_4_5) = [
       orm.OpSlice(self.relu1_op, orm.Slice(channel, 1))
       for channel in range(5)
   ]
  self.relu1_op_group = orm.OpGroup(self.relu1_op_slice)

  self.conv3_op = find_op('conv3/Conv2D')
  self.conv3_op_slice = orm.OpSlice(self.conv3_op, orm.Slice(0, 6))

  # Mock OpRegularizerManager backed by the test's slice/group dictionaries.
  self.mock_op_reg_manager = mock.create_autospec(orm.OpRegularizerManager)
  self.op_slice_dict = {
      self.dwise_conv2_op: [self.dwise_conv2_op_slice],
      self.conv2_op: [self.conv2_op_slice],
      self.relu1_op: [self.relu1_op_slice],
      self.conv3_op: [self.conv3_op_slice],
  }

  def slice_op(op, _):
    """Replaces the sliced op's entry with its per-channel OpSlices."""
    if op == self.dwise_conv2_op:
      self.op_slice_dict[self.dwise_conv2_op] = [
          self.dwise_conv2_op_slice_0_1, self.dwise_conv2_op_slice_1_2,
          self.dwise_conv2_op_slice_2_3, self.dwise_conv2_op_slice_3_4,
          self.dwise_conv2_op_slice_4_5, self.dwise_conv2_op_slice_5_6,
          self.dwise_conv2_op_slice_6_7, self.dwise_conv2_op_slice_7_8,
          self.dwise_conv2_op_slice_8_9, self.dwise_conv2_op_slice_9_10
      ]
    if op == self.relu1_op:
      self.op_slice_dict[self.relu1_op] = [
          self.relu1_op_slice_0_1, self.relu1_op_slice_1_2,
          self.relu1_op_slice_2_3, self.relu1_op_slice_3_4,
          self.relu1_op_slice_4_5
      ]

  manager = self.mock_op_reg_manager
  manager.get_op_slices.side_effect = lambda op: self.op_slice_dict.get(op)
  manager.get_op_group.side_effect = (
      lambda op_slice: self.op_group_dict.get(op_slice))
  manager.is_source_op.return_value = False
  manager.slice_op.side_effect = slice_op
  manager.ops = [
      self.relu1_op, self.dwise_conv2_op, self.conv2_op, self.conv3_op
  ]
def testAssignGrouping_NoDepthMultiplier(self):
  """Depthwise conv with depth_multiplier=1 groups directly with its input.

  With depth_multiplier=1 the depthwise op has the same 5 channels as the
  conv1 ReLU feeding it. The test expects the handler to group the two
  slices and to queue the ungrouped output op (conv2) for processing.
  """
  # Repeat setUp, but with depth_multiplier=1. Unfortunately, this involves
  # rebuilding the graph from scratch.
  tf.reset_default_graph()

  # This tests a Conv2D -> SeparableConv2D -> Conv2D chain of ops.
  with tf.contrib.framework.arg_scope(self._batch_norm_scope()):
    inputs = tf.zeros([2, 4, 4, 3])
    c1 = layers.conv2d(inputs, num_outputs=5, kernel_size=3, scope='conv1')
    c2 = layers.separable_conv2d(
        c1, num_outputs=8, kernel_size=3, depth_multiplier=1, scope='conv2')
    layers.conv2d(c2, num_outputs=6, kernel_size=3, scope='conv3')

  g = tf.get_default_graph()

  # Declare OpSlice and OpGroup for ops of interest.
  self.dwise_conv2_op = g.get_operation_by_name(
      'conv2/separable_conv2d/depthwise')
  self.dwise_conv2_op_slice = orm.OpSlice(self.dwise_conv2_op,
                                          orm.Slice(0, 5))

  self.conv2_op = g.get_operation_by_name('conv2/separable_conv2d')
  self.conv2_op_slice = orm.OpSlice(self.conv2_op, orm.Slice(0, 8))

  self.relu1_op = g.get_operation_by_name('conv1/Relu')
  self.relu1_op_slice = orm.OpSlice(self.relu1_op, orm.Slice(0, 5))
  self.relu1_op_group = orm.OpGroup(self.relu1_op_slice)

  self.conv3_op = g.get_operation_by_name('conv3/Conv2D')
  self.conv3_op_slice = orm.OpSlice(self.conv3_op, orm.Slice(0, 6))

  # Create mock OpRegularizerManager with custom mapping of OpSlice and
  # OpGroup.
  self.mock_op_reg_manager = mock.create_autospec(orm.OpRegularizerManager)

  self.op_slice_dict = {
      self.dwise_conv2_op: [self.dwise_conv2_op_slice],
      self.conv2_op: [self.conv2_op_slice],
      self.relu1_op: [self.relu1_op_slice],
      self.conv3_op: [self.conv3_op_slice],
  }

  def get_op_slices(op):
    return self.op_slice_dict.get(op)

  def get_op_group(op_slice):
    return self.op_group_dict.get(op_slice)

  self.mock_op_reg_manager.get_op_slices.side_effect = get_op_slices
  self.mock_op_reg_manager.get_op_group.side_effect = get_op_group
  self.mock_op_reg_manager.is_source_op.return_value = False
  self.mock_op_reg_manager.ops = [
      self.relu1_op, self.dwise_conv2_op, self.conv2_op, self.conv3_op
  ]

  # Only the input neighbor (conv1 ReLU) has a group; the output neighbor
  # (conv2 pointwise op) is deliberately left ungrouped.
  self.op_group_dict = {
      self.relu1_op_slice: self.relu1_op_group,
  }

  # Call handler to assign grouping.
  handler = depthwise_convolution_op_handler.DepthwiseConvolutionOpHandler()
  handler.assign_grouping(self.dwise_conv2_op, self.mock_op_reg_manager)

  # Verify manager looks up OpSlice for ops of interest.
  self.mock_op_reg_manager.get_op_slices.assert_has_calls(
      # Checking for ops to process.
      [
          mock.call(self.relu1_op),
          mock.call(self.conv2_op),
          # Initial slice data.
          mock.call(self.dwise_conv2_op),
          mock.call(self.relu1_op),
          # Reslicing.
          mock.call(self.relu1_op),
          mock.call(self.dwise_conv2_op),
          # Refreshing slice data.
          mock.call(self.relu1_op),
          # Group depthwise convolution.
          mock.call(self.dwise_conv2_op)
      ])

  # Verify manager groups the depthwise convolution with its input (ReLU).
  self.mock_op_reg_manager.group_op_slices.assert_called_once_with(
      [self.dwise_conv2_op_slice, self.relu1_op_slice])

  # Verify manager queues only the ungrouped output op (conv2) for
  # processing, and defers nothing to process_ops_last.
  self.mock_op_reg_manager.process_ops.assert_called_once_with(
      [self.conv2_op])
  self.mock_op_reg_manager.process_ops_last.assert_not_called()
def setUp(self):
  """Builds two test networks and a mock OpRegularizerManager.

  Network 1: conv1 -> BatchNorm -> ReLU; its OpSlices are declared with a
  `None` slice. Network 2: conv2/conv3/conv4 (5, 6, 7 channels) concatenated
  into 18 channels and fed to a batch norm. The mock reads
  `self.op_slice_dict`, `self.op_group_dict` and, optionally,
  `self._passthrough_ops`, all of which tests assign.
  """
  super(OpHandlerUtilTest, self).setUp()
  tf.reset_default_graph()

  # This tests a Conv2D -> BatchNorm -> ReLU chain of ops.
  with framework.arg_scope(self._batch_norm_scope()):
    inputs = tf.zeros([2, 4, 4, 3])
    layers.conv2d(inputs, num_outputs=5, kernel_size=3, scope='conv1')

    # This tests 3 Conv2D ops being concatenated before a batch normalization.
    c2 = layers.conv2d(inputs, num_outputs=5, kernel_size=3, scope='conv2')
    c3 = layers.conv2d(inputs, num_outputs=6, kernel_size=3, scope='conv3')
    c4 = layers.conv2d(inputs, num_outputs=7, kernel_size=3, scope='conv4')
    net = tf.concat([c2, c3, c4], axis=3)
    layers.batch_norm(net)

  g = tf.get_default_graph()

  # Declare OpSlice and OpGroup for ops in the first test network.
  self.batch_norm_op = g.get_operation_by_name(
      'conv1/BatchNorm/FusedBatchNormV3')
  self.batch_norm_op_slice = orm.OpSlice(self.batch_norm_op, None)
  self.batch_norm_op_group = orm.OpGroup(self.batch_norm_op_slice)

  self.conv_op = g.get_operation_by_name('conv1/Conv2D')
  self.conv_op_slice = orm.OpSlice(self.conv_op, None)
  self.conv_op_group = orm.OpGroup(
      self.conv_op_slice, omit_source_op_slices=[self.conv_op_slice])

  self.gamma_op = g.get_operation_by_name('conv1/BatchNorm/gamma/read')
  self.beta_op = g.get_operation_by_name('conv1/BatchNorm/beta/read')
  self.decay_op = g.get_operation_by_name('conv1/BatchNorm/Const')
  self.epsilon_op = g.get_operation_by_name('conv1/BatchNorm/Const_1')
  self.mean_op = g.get_operation_by_name(
      'conv1/BatchNorm/AssignMovingAvg/sub_1')
  self.std_op = g.get_operation_by_name(
      'conv1/BatchNorm/AssignMovingAvg_1/sub_1')

  self.relu_op = g.get_operation_by_name('conv1/Relu')
  self.relu_op_slice = orm.OpSlice(self.relu_op, None)
  self.relu_op_group = orm.OpGroup(
      self.relu_op_slice, omit_source_op_slices=[self.relu_op_slice])

  # Declare OpSlice and OpGroup for ops in the second test network.
  self.relu2_op = g.get_operation_by_name('conv2/Relu')
  self.relu2_op_slice = orm.OpSlice(self.relu2_op, orm.Slice(0, 5))
  self.relu2_op_group = orm.OpGroup(
      self.relu2_op_slice, omit_source_op_slices=[self.relu2_op_slice])

  self.relu3_op = g.get_operation_by_name('conv3/Relu')
  self.relu3_op_slice = orm.OpSlice(self.relu3_op, orm.Slice(0, 6))
  self.relu3_op_group = orm.OpGroup(
      self.relu3_op_slice, omit_source_op_slices=[self.relu3_op_slice])

  self.relu4_op = g.get_operation_by_name('conv4/Relu')
  self.relu4_op_slice = orm.OpSlice(self.relu4_op, orm.Slice(0, 7))
  self.relu4_op_group = orm.OpGroup(
      self.relu4_op_slice, omit_source_op_slices=[self.relu4_op_slice])

  self.unfused_batch_norm_op = g.get_operation_by_name(
      'BatchNorm/FusedBatchNormV3')
  self.unfused_batch_norm_op_slice = orm.OpSlice(
      self.unfused_batch_norm_op, orm.Slice(0, 18))

  self.concat_op = g.get_operation_by_name('concat')
  self.concat_op_slice = orm.OpSlice(self.concat_op, orm.Slice(0, 18))
  self.concat_op_group = orm.OpGroup(
      self.concat_op_slice, omit_source_op_slices=[self.concat_op_slice])

  # Create mock OpRegularizerManager with custom mapping of OpSlice and
  # OpGroup.
  self.mock_op_reg_manager = mock.create_autospec(orm.OpRegularizerManager)

  def get_op_slices(op):
    return self.op_slice_dict.get(op, [])

  def get_op_group(op_slice):
    return self.op_group_dict.get(op_slice)

  def is_passthrough(op):
    # Tests opt in by assigning `self._passthrough_ops`; setUp never does.
    # Fall back to "nothing is passthrough" rather than raising
    # AttributeError from inside the mock.
    return op in getattr(self, '_passthrough_ops', ())

  self.mock_op_reg_manager.get_op_slices.side_effect = get_op_slices
  self.mock_op_reg_manager.get_op_group.side_effect = get_op_group
  self.mock_op_reg_manager.is_passthrough.side_effect = is_passthrough
  self.mock_op_reg_manager.ops = [
      self.batch_norm_op, self.gamma_op, self.beta_op, self.decay_op,
      self.epsilon_op, self.mean_op, self.std_op, self.conv_op, self.relu_op,
      self.relu2_op, self.relu3_op, self.relu4_op, self.unfused_batch_norm_op,
      self.concat_op
  ]
def testAssignGrouping_ProcessNeighborGroupsWithSlices(self):
  """Batch norm split into [0, 2) and [2, 3) is grouped per slice.

  Every neighbor (conv, relu, gamma, beta, mean, std) is split the same way
  and each piece already has a group. The handler is expected to group each
  batch norm piece with its outputs, then with its inputs, and to queue
  nothing for reprocessing.
  """

  def split_in_two(op, omit_sources):
    """Returns ([0, 2) slice, [2, 3) slice, group for each, in that order)."""
    head = orm.OpSlice(op, orm.Slice(0, 2))
    tail = orm.OpSlice(op, orm.Slice(2, 1))
    if omit_sources:
      return (head, tail,
              orm.OpGroup(head, omit_source_op_slices=[head]),
              orm.OpGroup(tail, omit_source_op_slices=[tail]))
    return head, tail, orm.OpGroup(head), orm.OpGroup(tail)

  (batch_norm_op_slice_0_2, batch_norm_op_slice_2_3, batch_norm_op_group1,
   batch_norm_op_group2) = split_in_two(self.batch_norm_op, False)
  (conv_op_slice_0_2, conv_op_slice_2_3, conv_op_group1,
   conv_op_group2) = split_in_two(self.conv_op, True)
  (relu_op_slice_0_2, relu_op_slice_2_3, relu_op_group1,
   relu_op_group2) = split_in_two(self.relu_op, False)
  (gamma_op_slice_0_2, gamma_op_slice_2_3, gamma_op_group1,
   gamma_op_group2) = split_in_two(self.gamma_op, True)
  (beta_op_slice_0_2, beta_op_slice_2_3, beta_op_group1,
   beta_op_group2) = split_in_two(self.beta_op, True)
  (mean_op_slice_0_2, mean_op_slice_2_3, mean_op_group1,
   mean_op_group2) = split_in_two(self.mean_op, True)
  (std_op_slice_0_2, std_op_slice_2_3, std_op_group1,
   std_op_group2) = split_in_two(self.std_op, True)

  self.op_slice_dict = {
      self.batch_norm_op: [batch_norm_op_slice_0_2, batch_norm_op_slice_2_3],
      self.conv_op: [conv_op_slice_0_2, conv_op_slice_2_3],
      self.relu_op: [relu_op_slice_0_2, relu_op_slice_2_3],
      self.gamma_op: [gamma_op_slice_0_2, gamma_op_slice_2_3],
      self.beta_op: [beta_op_slice_0_2, beta_op_slice_2_3],
      self.mean_op: [mean_op_slice_0_2, mean_op_slice_2_3],
      self.std_op: [std_op_slice_0_2, std_op_slice_2_3],
  }

  # Every OpSlice already belongs to a group.
  self.op_group_dict = {
      batch_norm_op_slice_0_2: batch_norm_op_group1,
      batch_norm_op_slice_2_3: batch_norm_op_group2,
      conv_op_slice_0_2: conv_op_group1,
      conv_op_slice_2_3: conv_op_group2,
      relu_op_slice_0_2: relu_op_group1,
      relu_op_slice_2_3: relu_op_group2,
      gamma_op_slice_0_2: gamma_op_group1,
      gamma_op_slice_2_3: gamma_op_group2,
      beta_op_slice_0_2: beta_op_group1,
      beta_op_slice_2_3: beta_op_group2,
      mean_op_slice_0_2: mean_op_group1,
      mean_op_slice_2_3: mean_op_group2,
      std_op_slice_0_2: std_op_group1,
      std_op_slice_2_3: std_op_group2,
  }

  manager = self.mock_op_reg_manager
  batch_norm_source_op_handler.BatchNormSourceOpHandler(
      _GAMMA_THRESHOLD).assign_grouping(self.batch_norm_op, manager)

  # The manager must have looked up slices for the ops of interest.
  for looked_up_op in (self.batch_norm_op, self.conv_op, self.relu_op):
    manager.get_op_slices.assert_any_call(looked_up_op)

  # For each piece: outputs are grouped first, then inputs.
  per_piece = (
      (batch_norm_op_slice_0_2, relu_op_slice_0_2, mean_op_slice_0_2,
       std_op_slice_0_2, conv_op_slice_0_2, gamma_op_slice_0_2,
       beta_op_slice_0_2),
      (batch_norm_op_slice_2_3, relu_op_slice_2_3, mean_op_slice_2_3,
       std_op_slice_2_3, conv_op_slice_2_3, gamma_op_slice_2_3,
       beta_op_slice_2_3),
  )
  expected_calls = []
  for bn, relu, mean, std, conv, gamma, beta in per_piece:
    expected_calls.append(mock.call([bn, relu, mean, std]))
    expected_calls.append(
        mock.call([bn, conv, gamma, beta], omit_source_op_slices=[]))
  manager.group_op_slices.assert_has_calls(expected_calls)

  # Nothing is queued for reprocessing.
  manager.process_ops.assert_not_called()
  manager.process_ops_last.assert_not_called()
def setUp(self):
  """Builds conv1 (no batch norm) -> conv2 (with batch norm) and a mock.

  conv1 has 5 channels and conv2 has 6. The mock reports conv1/conv2 with
  OutputNonPassthroughOpHandler's is_passthrough value and the batch norm as
  passthrough. It reads `self.op_slice_dict`/`self.op_group_dict`, which
  tests assign.
  """
  tf.reset_default_graph()

  # This tests 2 Conv2D ops with batch norm at the top.
  with tf.contrib.framework.arg_scope(self._batch_norm_scope()):
    inputs = tf.zeros([2, 4, 4, 3])
    c1 = layers.conv2d(
        inputs, num_outputs=5, kernel_size=3, scope='conv1',
        normalizer_fn=None)
    layers.conv2d(c1, num_outputs=6, kernel_size=3, scope='conv2')

  g = tf.get_default_graph()

  # Declare OpSlice and OpGroup for ops of interest.
  self.conv1_op = g.get_operation_by_name('conv1/Conv2D')
  self.conv1_op_slice = orm.OpSlice(self.conv1_op, orm.Slice(0, 5))
  self.conv1_op_group = orm.OpGroup(
      self.conv1_op_slice, omit_source_op_slices=[self.conv1_op_slice])

  self.relu1_op = g.get_operation_by_name('conv1/Relu')
  self.relu1_op_slice = orm.OpSlice(self.relu1_op, orm.Slice(0, 5))
  self.relu1_op_group = orm.OpGroup(
      self.relu1_op_slice, omit_source_op_slices=[self.relu1_op_slice])

  self.conv2_op = g.get_operation_by_name('conv2/Conv2D')
  self.conv2_op_slice = orm.OpSlice(self.conv2_op, orm.Slice(0, 6))
  self.conv2_op_group = orm.OpGroup(
      self.conv2_op_slice, omit_source_op_slices=[self.conv2_op_slice])

  self.relu2_op = g.get_operation_by_name('conv2/Relu')
  self.relu2_op_slice = orm.OpSlice(self.relu2_op, orm.Slice(0, 6))
  self.relu2_op_group = orm.OpGroup(
      self.relu2_op_slice, omit_source_op_slices=[self.relu2_op_slice])

  # Depending on the TensorFlow version, the fused batch norm op is named
  # FusedBatchNorm or FusedBatchNormV3 (the other handler tests expect V3).
  # get_operation_by_name raises KeyError for a missing name, so try the
  # original name first and fall back to V3.
  try:
    self.batch_norm_op = g.get_operation_by_name(
        'conv2/BatchNorm/FusedBatchNorm')
  except KeyError:
    self.batch_norm_op = g.get_operation_by_name(
        'conv2/BatchNorm/FusedBatchNormV3')
  self.batch_norm_op_slice = orm.OpSlice(self.batch_norm_op, orm.Slice(0, 6))
  self.batch_norm_op_group = orm.OpGroup(self.batch_norm_op_slice)

  # Create mock OpRegularizerManager with custom mapping of OpSlice and
  # OpGroup.
  self.mock_op_reg_manager = mock.create_autospec(orm.OpRegularizerManager)

  def get_op_slices(op):
    return self.op_slice_dict.get(op, [])

  def get_op_group(op_slice):
    return self.op_group_dict.get(op_slice)

  def is_passthrough(op):
    if op in [self.conv1_op, self.conv2_op]:
      h = output_non_passthrough_op_handler.OutputNonPassthroughOpHandler()
      return h.is_passthrough
    if op == self.batch_norm_op:
      return True
    else:
      return False

  self.mock_op_reg_manager.get_op_slices.side_effect = get_op_slices
  self.mock_op_reg_manager.get_op_group.side_effect = get_op_group
  self.mock_op_reg_manager.is_source_op.return_value = False
  self.mock_op_reg_manager.is_passthrough.side_effect = is_passthrough
  self.mock_op_reg_manager.ops = [
      self.conv1_op, self.relu1_op, self.conv2_op, self.relu2_op,
      self.batch_norm_op]
def testAssignGrouping_AllNeighborsGrouped_InputSlicesNotAligned(self):
  """Concat and batch norm are resliced to match multi-slice inputs.

  relu2 (6 channels) arrives as two 3-channel slices, so the 18-channel
  concat and its batch norm output must be cut into [5, 3, 3, 7]. This lines
  them up with relu1 | relu2[0:3] | relu2[3:6] | relu3, and each aligned
  piece is then grouped with its concat and batch norm counterparts.
  """

  def concat_piece(start, size):
    """Returns an OpSlice of the concat op."""
    return orm.OpSlice(self.concat_op, orm.Slice(start, size))

  def grouped_bn_piece(start, size):
    """Returns a batch norm OpSlice and a group that omits it as a source."""
    piece = orm.OpSlice(self.batch_norm_op, orm.Slice(start, size))
    return piece, orm.OpGroup(piece, omit_source_op_slices=[piece])

  concat_op_slice_0_5 = concat_piece(0, 5)
  concat_op_slice_5_8 = concat_piece(5, 3)
  concat_op_slice_8_11 = concat_piece(8, 3)
  concat_op_slice_11_18 = concat_piece(11, 7)

  # relu2 is pre-split into two halves, each in its own group.
  relu2_op_slice_0_3 = orm.OpSlice(self.relu2_op, orm.Slice(0, 3))
  relu2_op_slice_3_6 = orm.OpSlice(self.relu2_op, orm.Slice(3, 3))
  relu2_op_group1 = orm.OpGroup(
      relu2_op_slice_0_3, omit_source_op_slices=[relu2_op_slice_0_3])
  relu2_op_group2 = orm.OpGroup(
      relu2_op_slice_3_6, omit_source_op_slices=[relu2_op_slice_3_6])

  # The whole batch norm slice plus the pieces it will be resliced into.
  batch_norm_op_slice, batch_norm_op_group = grouped_bn_piece(0, 18)
  batch_norm_op_slice_0_5, batch_norm_op_group1 = grouped_bn_piece(0, 5)
  batch_norm_op_slice_5_8, batch_norm_op_group2 = grouped_bn_piece(5, 3)
  batch_norm_op_slice_8_11, batch_norm_op_group3 = grouped_bn_piece(8, 3)
  batch_norm_op_slice_11_18, batch_norm_op_group4 = grouped_bn_piece(11, 7)

  concat_pieces = [
      concat_op_slice_0_5, concat_op_slice_5_8, concat_op_slice_8_11,
      concat_op_slice_11_18
  ]
  input_pieces = [
      self.relu1_op_slice, relu2_op_slice_0_3, relu2_op_slice_3_6,
      self.relu3_op_slice
  ]
  bn_pieces = [
      batch_norm_op_slice_0_5, batch_norm_op_slice_5_8,
      batch_norm_op_slice_8_11, batch_norm_op_slice_11_18
  ]

  # Map ops to slices; relu2 is the only input made of multiple slices.
  self.op_slice_dict = {
      self.relu1_op: [self.relu1_op_slice],
      self.relu2_op: [relu2_op_slice_0_3, relu2_op_slice_3_6],
      self.relu3_op: [self.relu3_op_slice],
      self.concat_op: [self.concat_op_slice],
      self.batch_norm_op: [batch_norm_op_slice],
  }

  # Map each slice to a group.
  self.op_group_dict = {
      self.relu1_op_slice: self.relu1_op_group,
      relu2_op_slice_0_3: relu2_op_group1,
      relu2_op_slice_3_6: relu2_op_group2,
      self.relu3_op_slice: self.relu3_op_group,
      batch_norm_op_slice: batch_norm_op_group,
      batch_norm_op_slice_0_5: batch_norm_op_group1,
      batch_norm_op_slice_5_8: batch_norm_op_group2,
      batch_norm_op_slice_8_11: batch_norm_op_group3,
      batch_norm_op_slice_11_18: batch_norm_op_group4,
  }

  def slice_op(op, _):
    """Replaces a resliced op's entry with its new pieces."""
    if op == self.batch_norm_op:
      self.op_slice_dict[self.batch_norm_op] = list(bn_pieces)
    if op == self.concat_op:
      self.op_slice_dict[self.concat_op] = list(concat_pieces)

  manager = self.mock_op_reg_manager
  manager.slice_op.side_effect = slice_op

  concat_op_handler.ConcatOpHandler().assign_grouping(self.concat_op, manager)

  # Expected OpSlice lookups, phase by phase.
  neighbors = [
      self.relu1_op, self.relu2_op, self.relu3_op, self.batch_norm_op
  ]
  expected_lookups = (
      neighbors  # Checking for ops to process.
      + [self.concat_op] + neighbors  # Initial slice data.
      + neighbors + [self.concat_op]  # Reslicing.
      + neighbors  # Refreshing slice data.
      + [self.concat_op])  # Group concat op.
  manager.get_op_slices.assert_has_calls(
      [mock.call(op) for op in expected_lookups])

  # Ops whose OpSlice sizes were not aligned are resliced.
  manager.slice_op.assert_has_calls([
      mock.call(self.batch_norm_op, [5, 3, 3, 7]),
      mock.call(self.concat_op, [5, 3, 3, 7])
  ])

  # Each aligned concat piece is grouped with its input and batch norm piece.
  manager.group_op_slices.assert_has_calls([
      mock.call(list(trio))
      for trio in zip(concat_pieces, input_pieces, bn_pieces)
  ])
def testAssignGrouping_AllNeighborsGrouped_OutputSlicesNotAligned(self):
  """Checks aligned reslicing when the output's slices do not match inputs.

  The concat inputs have sizes [5, 6, 7] while the output (batch norm) is
  initially sliced as [9, 4, 5]. The handler is expected to reslice relu2
  into [4, 2], relu3 into [2, 5], and both batch norm and concat into
  [5, 4, 2, 2, 5], then group the aligned slices of concat, inputs, and
  output together.
  """
  # Concat slices after reslicing into [5, 4, 2, 2, 5].
  concat_op_slice_0_5 = orm.OpSlice(self.concat_op, orm.Slice(0, 5))
  concat_op_slice_5_9 = orm.OpSlice(self.concat_op, orm.Slice(5, 4))
  concat_op_slice_9_11 = orm.OpSlice(self.concat_op, orm.Slice(9, 2))
  concat_op_slice_11_13 = orm.OpSlice(self.concat_op, orm.Slice(11, 2))
  concat_op_slice_13_18 = orm.OpSlice(self.concat_op, orm.Slice(13, 5))

  # Input slices after reslicing relu2 into [4, 2] and relu3 into [2, 5].
  relu2_op_slice_0_4 = orm.OpSlice(self.relu2_op, orm.Slice(0, 4))
  relu2_op_slice_4_6 = orm.OpSlice(self.relu2_op, orm.Slice(4, 2))
  relu3_op_slice_0_2 = orm.OpSlice(self.relu3_op, orm.Slice(0, 2))
  relu3_op_slice_2_7 = orm.OpSlice(self.relu3_op, orm.Slice(2, 5))

  # Initial batch norm slices: [9, 4, 5].
  batch_norm_op_slice_0_9 = orm.OpSlice(self.batch_norm_op, orm.Slice(0, 9))
  batch_norm_op_group1 = orm.OpGroup(
      batch_norm_op_slice_0_9, omit_source_op_slices=[batch_norm_op_slice_0_9])
  batch_norm_op_slice_9_13 = orm.OpSlice(self.batch_norm_op, orm.Slice(9, 4))
  batch_norm_op_group2 = orm.OpGroup(
      batch_norm_op_slice_9_13,
      omit_source_op_slices=[batch_norm_op_slice_9_13])
  # Channels 13-18 form the last slice both before and after reslicing, so a
  # single OpSlice (starting at index 13) and its group serve both layouts.
  batch_norm_op_slice_13_18 = orm.OpSlice(self.batch_norm_op,
                                          orm.Slice(13, 5))
  batch_norm_op_group3 = orm.OpGroup(
      batch_norm_op_slice_13_18,
      omit_source_op_slices=[batch_norm_op_slice_13_18])

  # Remaining batch norm slices after reslicing into [5, 4, 2, 2, 5].
  batch_norm_op_slice_0_5 = orm.OpSlice(self.batch_norm_op, orm.Slice(0, 5))
  batch_norm_op_group4 = orm.OpGroup(
      batch_norm_op_slice_0_5, omit_source_op_slices=[batch_norm_op_slice_0_5])
  batch_norm_op_slice_5_9 = orm.OpSlice(self.batch_norm_op, orm.Slice(5, 4))
  batch_norm_op_group5 = orm.OpGroup(
      batch_norm_op_slice_5_9, omit_source_op_slices=[batch_norm_op_slice_5_9])
  batch_norm_op_slice_9_11 = orm.OpSlice(self.batch_norm_op, orm.Slice(9, 2))
  batch_norm_op_group6 = orm.OpGroup(
      batch_norm_op_slice_9_11,
      omit_source_op_slices=[batch_norm_op_slice_9_11])
  batch_norm_op_slice_11_13 = orm.OpSlice(self.batch_norm_op,
                                          orm.Slice(11, 2))
  batch_norm_op_group7 = orm.OpGroup(
      batch_norm_op_slice_11_13,
      omit_source_op_slices=[batch_norm_op_slice_11_13])

  # Map ops to slices. Batch norm op is composed of multiple slices.
  self.op_slice_dict = {
      self.relu1_op: [self.relu1_op_slice],
      self.relu2_op: [self.relu2_op_slice],
      self.relu3_op: [self.relu3_op_slice],
      self.concat_op: [self.concat_op_slice],
      self.batch_norm_op: [
          batch_norm_op_slice_0_9, batch_norm_op_slice_9_13,
          batch_norm_op_slice_13_18
      ],
  }

  # Map each slice to a group (each key appears exactly once).
  self.op_group_dict = {
      self.relu1_op_slice: self.relu1_op_group,
      self.relu2_op_slice: self.relu2_op_group,
      self.relu3_op_slice: self.relu3_op_group,
      batch_norm_op_slice_0_9: batch_norm_op_group1,
      batch_norm_op_slice_9_13: batch_norm_op_group2,
      batch_norm_op_slice_13_18: batch_norm_op_group3,
      batch_norm_op_slice_0_5: batch_norm_op_group4,
      batch_norm_op_slice_5_9: batch_norm_op_group5,
      batch_norm_op_slice_9_11: batch_norm_op_group6,
      batch_norm_op_slice_11_13: batch_norm_op_group7,
  }

  # Update op_slice_dict when an op is sliced.
  def slice_op(op, _):
    if op == self.batch_norm_op:
      self.op_slice_dict[self.batch_norm_op] = [
          batch_norm_op_slice_0_5, batch_norm_op_slice_5_9,
          batch_norm_op_slice_9_11, batch_norm_op_slice_11_13,
          batch_norm_op_slice_13_18
      ]
    if op == self.concat_op:
      self.op_slice_dict[self.concat_op] = [
          concat_op_slice_0_5, concat_op_slice_5_9, concat_op_slice_9_11,
          concat_op_slice_11_13, concat_op_slice_13_18
      ]
    if op == self.relu2_op:
      self.op_slice_dict[self.relu2_op] = [
          relu2_op_slice_0_4, relu2_op_slice_4_6
      ]
    if op == self.relu3_op:
      self.op_slice_dict[self.relu3_op] = [
          relu3_op_slice_0_2, relu3_op_slice_2_7
      ]

  self.mock_op_reg_manager.slice_op.side_effect = slice_op

  # Call handler to assign grouping.
  handler = concat_op_handler.ConcatOpHandler()
  handler.assign_grouping(self.concat_op, self.mock_op_reg_manager)

  # Verify manager looks up OpSlice for ops of interest.
  self.mock_op_reg_manager.get_op_slices.assert_has_calls(
      # Checking for ops to process.
      [
          mock.call(self.relu1_op),
          mock.call(self.relu2_op),
          mock.call(self.relu3_op),
          mock.call(self.batch_norm_op),
          # Initial slice data.
          mock.call(self.concat_op),
          mock.call(self.relu1_op),
          mock.call(self.relu2_op),
          mock.call(self.relu3_op),
          mock.call(self.batch_norm_op),
          # Reslicing.
          mock.call(self.relu1_op),
          mock.call(self.relu2_op),
          mock.call(self.relu3_op),
          mock.call(self.batch_norm_op),
          mock.call(self.concat_op),
          # Refreshing slice data.
          mock.call(self.relu1_op),
          mock.call(self.relu2_op),
          mock.call(self.relu3_op),
          mock.call(self.batch_norm_op),
          # Group concat op.
          mock.call(self.concat_op)
      ])

  # Verify manager slices ops that do not have aligned OpSlice sizes.
  self.mock_op_reg_manager.slice_op.assert_has_calls([
      mock.call(self.relu2_op, [4, 2]),
      mock.call(self.relu3_op, [2, 5]),
      mock.call(self.batch_norm_op, [5, 4, 2, 2, 5]),
      mock.call(self.concat_op, [5, 4, 2, 2, 5])
  ])

  # Verify manager groups the new slices.
  self.mock_op_reg_manager.group_op_slices.assert_has_calls([
      mock.call([
          concat_op_slice_0_5, self.relu1_op_slice, batch_norm_op_slice_0_5
      ]),
      mock.call([
          concat_op_slice_5_9, relu2_op_slice_0_4, batch_norm_op_slice_5_9
      ]),
      mock.call([
          concat_op_slice_9_11, relu2_op_slice_4_6, batch_norm_op_slice_9_11
      ]),
      mock.call([
          concat_op_slice_11_13, relu3_op_slice_0_2, batch_norm_op_slice_11_13
      ]),
      mock.call([
          concat_op_slice_13_18, relu3_op_slice_2_7, batch_norm_op_slice_13_18
      ])
  ])
def setUp(self):
  """Builds a channel-axis concat of three Conv2D branches and a mock manager.

  Graph: conv1 (5 channels), conv2 (6) and conv3 (7) are concatenated along
  axis 3 into 18 channels and fed to BatchNorm. Fixtures cover the whole
  18-channel concat/batch norm outputs as well as their per-branch [5, 6, 7]
  slices.

  The mock OpRegularizerManager answers get_op_slices/get_op_group from
  self.op_slice_dict/self.op_group_dict, which are read at call time and are
  not assigned here, so each test must set them. Its slice_op replaces the
  batch norm or concat slice list with the per-branch slices.
  """
  tf.reset_default_graph()

  # Network: three Conv2D ops concatenated on the channel axis, then BatchNorm.
  inputs = tf.zeros([2, 4, 4, 3])
  branch_outputs = [
      layers.conv2d(inputs, num_outputs=5, kernel_size=3, scope='conv1'),
      layers.conv2d(inputs, num_outputs=6, kernel_size=3, scope='conv2'),
      layers.conv2d(inputs, num_outputs=7, kernel_size=3, scope='conv3'),
  ]
  layers.batch_norm(tf.concat(branch_outputs, axis=3))
  graph = tf.get_default_graph()

  def non_source_group(op_slice):
    # Single-slice OpGroup with op_slice listed in omit_source_op_slices.
    return orm.OpGroup(op_slice, omit_source_op_slices=[op_slice])

  # Concat op: the full 18-channel slice plus one slice per input branch.
  self.concat_op = graph.get_operation_by_name('concat')
  self.concat_op_slice = orm.OpSlice(self.concat_op, orm.Slice(0, 18))
  self.concat_op_slice_0_5 = orm.OpSlice(self.concat_op, orm.Slice(0, 5))
  self.concat_op_slice_5_11 = orm.OpSlice(self.concat_op, orm.Slice(5, 6))
  self.concat_op_slice_11_18 = orm.OpSlice(self.concat_op, orm.Slice(11, 7))
  self.concat_op_group1 = non_source_group(self.concat_op_slice_0_5)
  self.concat_op_group2 = non_source_group(self.concat_op_slice_5_11)
  self.concat_op_group3 = non_source_group(self.concat_op_slice_11_18)

  # ReLU output of each branch, each covered by a single slice.
  self.relu1_op = graph.get_operation_by_name('conv1/Relu')
  self.relu1_op_slice = orm.OpSlice(self.relu1_op, orm.Slice(0, 5))
  self.relu1_op_group = non_source_group(self.relu1_op_slice)

  self.relu2_op = graph.get_operation_by_name('conv2/Relu')
  self.relu2_op_slice = orm.OpSlice(self.relu2_op, orm.Slice(0, 6))
  self.relu2_op_group = non_source_group(self.relu2_op_slice)

  self.relu3_op = graph.get_operation_by_name('conv3/Relu')
  self.relu3_op_slice = orm.OpSlice(self.relu3_op, orm.Slice(0, 7))
  self.relu3_op_group = non_source_group(self.relu3_op_slice)

  self.axis_op = graph.get_operation_by_name('concat/axis')

  # Batch norm op: the full 18-channel slice plus per-branch slices.
  self.batch_norm_op = graph.get_operation_by_name('BatchNorm/FusedBatchNorm')
  self.batch_norm_op_slice = orm.OpSlice(self.batch_norm_op, orm.Slice(0, 18))
  self.batch_norm_op_group = non_source_group(self.batch_norm_op_slice)
  self.batch_norm_op_slice_0_5 = orm.OpSlice(self.batch_norm_op,
                                             orm.Slice(0, 5))
  self.batch_norm_op_slice_5_11 = orm.OpSlice(self.batch_norm_op,
                                              orm.Slice(5, 6))
  self.batch_norm_op_slice_11_18 = orm.OpSlice(self.batch_norm_op,
                                               orm.Slice(11, 7))
  self.batch_norm_op_group1 = non_source_group(self.batch_norm_op_slice_0_5)
  self.batch_norm_op_group2 = non_source_group(self.batch_norm_op_slice_5_11)
  self.batch_norm_op_group3 = non_source_group(self.batch_norm_op_slice_11_18)

  # Mock OpRegularizerManager backed by the per-test slice/group mappings.
  self.mock_op_reg_manager = mock.create_autospec(orm.OpRegularizerManager)

  def get_op_slices(op):
    return self.op_slice_dict.get(op, [])

  def get_op_group(op_slice):
    return self.op_group_dict.get(op_slice)

  def slice_op(op, _):
    # Reslicing swaps the op's slice list for its per-branch slices.
    if op == self.batch_norm_op:
      self.op_slice_dict[self.batch_norm_op] = [
          self.batch_norm_op_slice_0_5,
          self.batch_norm_op_slice_5_11,
          self.batch_norm_op_slice_11_18,
      ]
    elif op == self.concat_op:
      self.op_slice_dict[self.concat_op] = [
          self.concat_op_slice_0_5,
          self.concat_op_slice_5_11,
          self.concat_op_slice_11_18,
      ]

  manager = self.mock_op_reg_manager
  manager.get_op_slices.side_effect = get_op_slices
  manager.get_op_group.side_effect = get_op_group
  manager.is_source_op.return_value = False
  manager.slice_op.side_effect = slice_op
  manager.is_passthrough.return_value = True
  manager.ops = [
      self.concat_op,
      self.relu1_op,
      self.relu2_op,
      self.relu3_op,
      self.batch_norm_op,
  ]
def setUp(self):
  """Builds a Conv2D -> BatchNorm -> ReLU network and a mock manager.

  The conv1 layer is created under self._batch_norm_scope(). The fixtures are
  the FusedBatchNorm op and its neighbors: Conv2D, Relu, the gamma/beta read
  ops and the two moving-average sub_1 ops. Each fixture has a single OpSlice
  over channels [0, 5).

  The mock manager's get_op_slices/get_op_group read self.op_slice_dict and
  self.op_group_dict at call time; those are not assigned here.
  """
  tf.reset_default_graph()

  # Network: Conv2D -> BatchNorm -> ReLU, with batch norm supplied by the scope.
  with tf.contrib.framework.arg_scope(self._batch_norm_scope()):
    inputs = tf.zeros([2, 4, 4, 3])
    layers.conv2d(inputs, num_outputs=5, kernel_size=3, scope='conv1')
  graph = tf.get_default_graph()

  def channel_slice(op):
    # OpSlice over channels [0, 5) of op.
    return orm.OpSlice(op, orm.Slice(0, 5))

  def non_source_group(op_slice):
    # Single-slice OpGroup with op_slice listed in omit_source_op_slices.
    return orm.OpGroup(op_slice, omit_source_op_slices=[op_slice])

  self.batch_norm_op = graph.get_operation_by_name(
      'conv1/BatchNorm/FusedBatchNorm')
  self.batch_norm_op_slice = channel_slice(self.batch_norm_op)
  self.batch_norm_op_group = orm.OpGroup(self.batch_norm_op_slice)

  self.conv_op = graph.get_operation_by_name('conv1/Conv2D')
  self.conv_op_slice = channel_slice(self.conv_op)
  self.conv_op_group = non_source_group(self.conv_op_slice)

  self.relu_op = graph.get_operation_by_name('conv1/Relu')
  self.relu_op_slice = channel_slice(self.relu_op)
  self.relu_op_group = orm.OpGroup(self.relu_op_slice)

  self.gamma_op = graph.get_operation_by_name('conv1/BatchNorm/gamma/read')
  self.gamma_op_slice = channel_slice(self.gamma_op)
  self.gamma_op_group = non_source_group(self.gamma_op_slice)

  self.beta_op = graph.get_operation_by_name('conv1/BatchNorm/beta/read')
  self.beta_op_slice = channel_slice(self.beta_op)
  self.beta_op_group = non_source_group(self.beta_op_slice)

  self.mean_op = graph.get_operation_by_name(
      'conv1/BatchNorm/AssignMovingAvg/sub_1')
  self.mean_op_slice = channel_slice(self.mean_op)
  self.mean_op_group = non_source_group(self.mean_op_slice)

  self.std_op = graph.get_operation_by_name(
      'conv1/BatchNorm/AssignMovingAvg_1/sub_1')
  self.std_op_slice = channel_slice(self.std_op)
  self.std_op_group = non_source_group(self.std_op_slice)

  # Mock OpRegularizerManager backed by the per-test slice/group mappings.
  self.mock_op_reg_manager = mock.create_autospec(orm.OpRegularizerManager)

  def get_op_slices(op):
    return self.op_slice_dict.get(op, [])

  def get_op_group(op_slice):
    return self.op_group_dict.get(op_slice)

  manager = self.mock_op_reg_manager
  manager.get_op_slices.side_effect = get_op_slices
  manager.get_op_group.side_effect = get_op_group
  manager.is_source_op.return_value = False
  manager.ops = [
      self.batch_norm_op,
      self.conv_op,
      self.relu_op,
      self.gamma_op,
      self.beta_op,
      self.mean_op,
      self.std_op,
  ]
def setUp(self):
  """Builds three Conv2D branches concatenated on a non-channel axis.

  conv1, conv2 and conv3 (6 channels each) are created under
  self._get_scope(). They are concatenated along axis 2 and fed to BatchNorm,
  so every op of interest is covered by one 6-channel slice.
  self.concat_group combines the per-op groups into one OpGroup.

  The mock manager's get_op_slices/get_op_group read self.op_slice_dict and
  self.op_group_dict at call time; those are not assigned here.
  """
  tf.reset_default_graph()

  # Network: three Conv2D ops concatenated along axis 2, then BatchNorm.
  inputs = tf.zeros([2, 4, 4, 3])
  with tf.contrib.framework.arg_scope(self._get_scope()):
    branch_outputs = [
        layers.conv2d(inputs, num_outputs=6, kernel_size=3, scope=scope_name)
        for scope_name in ('conv1', 'conv2', 'conv3')
    ]
    layers.batch_norm(tf.concat(branch_outputs, axis=2))
  graph = tf.get_default_graph()

  def full_slice(op):
    # OpSlice over channels [0, 6) of op.
    return orm.OpSlice(op, orm.Slice(0, 6))

  def non_source_group(op_slice):
    # Single-slice OpGroup with op_slice listed in omit_source_op_slices.
    return orm.OpGroup(op_slice, omit_source_op_slices=[op_slice])

  self.concat_op = graph.get_operation_by_name('concat')
  self.concat_op_slice = full_slice(self.concat_op)
  self.concat_op_group = non_source_group(self.concat_op_slice)

  self.relu1_op = graph.get_operation_by_name('conv1/Relu')
  self.relu1_op_slice = full_slice(self.relu1_op)
  self.relu1_op_group = non_source_group(self.relu1_op_slice)

  self.relu2_op = graph.get_operation_by_name('conv2/Relu')
  self.relu2_op_slice = full_slice(self.relu2_op)
  self.relu2_op_group = non_source_group(self.relu2_op_slice)

  self.relu3_op = graph.get_operation_by_name('conv3/Relu')
  self.relu3_op_slice = full_slice(self.relu3_op)
  self.relu3_op_group = non_source_group(self.relu3_op_slice)

  self.batch_norm_op = graph.get_operation_by_name('BatchNorm/FusedBatchNorm')
  self.batch_norm_op_slice = full_slice(self.batch_norm_op)
  self.batch_norm_op_group = non_source_group(self.batch_norm_op_slice)

  # Composite group built from the individual groups above.
  self.concat_group = orm.OpGroup(
      op_slice=None,
      op_groups=[
          self.batch_norm_op_group,
          self.concat_op_group,
          self.relu1_op_group,
          self.relu2_op_group,
          self.relu3_op_group,
      ])

  # Mock OpRegularizerManager backed by the per-test slice/group mappings.
  self.mock_op_reg_manager = mock.create_autospec(orm.OpRegularizerManager)

  def get_op_slices(op):
    return self.op_slice_dict.get(op, [])

  def get_op_group(op_slice):
    return self.op_group_dict.get(op_slice)

  manager = self.mock_op_reg_manager
  manager.get_op_slices.side_effect = get_op_slices
  manager.get_op_group.side_effect = get_op_group
  manager.is_source_op.return_value = False
  manager.is_passthrough.return_value = True
  manager.ops = [
      self.concat_op,
      self.relu1_op,
      self.relu2_op,
      self.relu3_op,
      self.batch_norm_op,
  ]
def setUp(self):
  """Builds two stacked Conv2D layers without normalization.

  conv1 (5 channels) feeds conv2 (6 channels); both use normalizer_fn=None.
  The fixtures are the Conv2D and Relu ops of both layers plus conv2's
  weights read op.

  For the conv1/conv2 Conv2D ops, the mock manager's is_passthrough returns
  Conv2DSourceOpHandler(_DEFAULT_THRESHOLD).is_passthrough; for every other op
  it returns False. get_op_slices/get_op_group read self.op_slice_dict and
  self.op_group_dict at call time; those are not assigned here.
  """
  tf.reset_default_graph()

  # Network: two Conv2D layers in sequence.
  inputs = tf.zeros([2, 4, 4, 3])
  conv1_output = layers.conv2d(
      inputs, num_outputs=5, kernel_size=3, scope='conv1', normalizer_fn=None)
  layers.conv2d(
      conv1_output, num_outputs=6, kernel_size=3, scope='conv2',
      normalizer_fn=None)
  graph = tf.get_default_graph()

  def non_source_group(op_slice):
    # Single-slice OpGroup with op_slice listed in omit_source_op_slices.
    return orm.OpGroup(op_slice, omit_source_op_slices=[op_slice])

  self.conv1_op = graph.get_operation_by_name('conv1/Conv2D')
  self.conv1_op_slice = orm.OpSlice(self.conv1_op, orm.Slice(0, 5))
  self.conv1_op_group = non_source_group(self.conv1_op_slice)

  self.relu1_op = graph.get_operation_by_name('conv1/Relu')
  self.relu1_op_slice = orm.OpSlice(self.relu1_op, orm.Slice(0, 5))
  self.relu1_op_group = non_source_group(self.relu1_op_slice)

  self.conv2_op = graph.get_operation_by_name('conv2/Conv2D')
  self.conv2_op_slice = orm.OpSlice(self.conv2_op, orm.Slice(0, 6))
  self.conv2_op_group = non_source_group(self.conv2_op_slice)

  self.conv2_weights_op = graph.get_operation_by_name('conv2/weights/read')
  self.conv2_weights_op_slice = orm.OpSlice(self.conv2_weights_op,
                                            orm.Slice(0, 6))
  self.conv2_weights_op_group = non_source_group(self.conv2_weights_op_slice)

  self.relu2_op = graph.get_operation_by_name('conv2/Relu')
  self.relu2_op_slice = orm.OpSlice(self.relu2_op, orm.Slice(0, 6))
  self.relu2_op_group = non_source_group(self.relu2_op_slice)

  # Mock OpRegularizerManager backed by the per-test slice/group mappings.
  self.mock_op_reg_manager = mock.create_autospec(orm.OpRegularizerManager)

  def get_op_slices(op):
    return self.op_slice_dict.get(op, [])

  def get_op_group(op_slice):
    return self.op_group_dict.get(op_slice)

  def is_passthrough(op):
    # Only the Conv2D ops defer to the real Conv2D source handler.
    if op not in [self.conv1_op, self.conv2_op]:
      return False
    source_handler = conv2d_source_op_handler.Conv2DSourceOpHandler(
        _DEFAULT_THRESHOLD)
    return source_handler.is_passthrough

  manager = self.mock_op_reg_manager
  manager.get_op_slices.side_effect = get_op_slices
  manager.get_op_group.side_effect = get_op_group
  manager.is_source_op.return_value = False
  manager.is_passthrough.side_effect = is_passthrough
  manager.ops = [
      self.conv1_op,
      self.relu1_op,
      self.conv2_op,
      self.relu2_op,
      self.conv2_weights_op,
  ]