def testAssignGrouping_NoDepthMultiplier(self):
  """Depthwise conv with depth_multiplier=1 is grouped with its input.

  Rebuilds the graph as Conv2D -> SeparableConv2D(depth_multiplier=1) ->
  Conv2D and checks three things about DepthwiseConvolutionOpHandler:
  it groups the depthwise op's slice with the slice of its input
  (conv1/Relu); it hands the ungrouped pointwise op (conv2/separable_conv2d)
  to the manager's process_ops; and it never calls process_ops_last.
  """
  # Repeat setUp, but with depth_multiplier=1. Unfortunately, this involves
  # rebuilding the graph from scratch.
  tf.reset_default_graph()

  # This tests a Conv2D -> SeparableConv2D -> Conv2D chain of ops.
  with framework.arg_scope(self._batch_norm_scope()):
    inputs = tf.zeros([2, 4, 4, 3])
    c1 = layers.conv2d(inputs, num_outputs=5, kernel_size=3, scope='conv1')
    c2 = layers.separable_conv2d(c1, num_outputs=8, kernel_size=3,
                                 depth_multiplier=1, scope='conv2')
    layers.conv2d(c2, num_outputs=6, kernel_size=3, scope='conv3')

  g = tf.get_default_graph()

  # Declare OpSlice and OpGroup for ops of interest.
  self.dwise_conv2_op = g.get_operation_by_name(
      'conv2/separable_conv2d/depthwise')
  # With depth_multiplier=1 the depthwise op keeps conv1's 5 channels.
  self.dwise_conv2_op_slice = orm.OpSlice(self.dwise_conv2_op, orm.Slice(0, 5))

  self.conv2_op = g.get_operation_by_name('conv2/separable_conv2d')
  self.conv2_op_slice = orm.OpSlice(self.conv2_op, orm.Slice(0, 8))

  self.relu1_op = g.get_operation_by_name('conv1/Relu')
  self.relu1_op_slice = orm.OpSlice(self.relu1_op, orm.Slice(0, 5))
  self.relu1_op_group = orm.OpGroup(self.relu1_op_slice)

  self.conv3_op = g.get_operation_by_name('conv3/Conv2D')
  self.conv3_op_slice = orm.OpSlice(self.conv3_op, orm.Slice(0, 6))

  # Create mock OpRegularizerManager with custom mapping of OpSlice and
  # OpGroup.
  self.mock_op_reg_manager = mock.create_autospec(orm.OpRegularizerManager)

  self.op_slice_dict = {
      self.dwise_conv2_op: [self.dwise_conv2_op_slice],
      self.conv2_op: [self.conv2_op_slice],
      self.relu1_op: [self.relu1_op_slice],
      self.conv3_op: [self.conv3_op_slice],
  }

  def get_op_slices(op):
    # Unmapped ops have no slices: return [] (like the other fixtures) rather
    # than None, so a caller iterating the result sees an empty list.
    return self.op_slice_dict.get(op, [])

  def get_op_group(op_slice):
    return self.op_group_dict.get(op_slice)

  self.mock_op_reg_manager.get_op_slices.side_effect = get_op_slices
  self.mock_op_reg_manager.get_op_group.side_effect = get_op_group
  self.mock_op_reg_manager.is_source_op.return_value = False
  self.mock_op_reg_manager.ops = [
      self.relu1_op, self.dwise_conv2_op, self.conv2_op, self.conv3_op
  ]

  # Only the input (relu1) is grouped. The pointwise output op (conv2) has no
  # group yet, which is why the handler is expected to pass it to
  # process_ops below.
  self.op_group_dict = {
      self.relu1_op_slice: self.relu1_op_group,
  }

  # Call handler to assign grouping.
  handler = depthwise_convolution_op_handler.DepthwiseConvolutionOpHandler()
  handler.assign_grouping(self.dwise_conv2_op, self.mock_op_reg_manager)

  # Verify manager looks up OpSlice for ops of interest.
  self.mock_op_reg_manager.get_op_slices.assert_has_calls(
      # Checking for ops to process.
      [mock.call(self.relu1_op),
       mock.call(self.conv2_op),
       # Initial slice data.
       mock.call(self.dwise_conv2_op),
       mock.call(self.relu1_op),
       # Reslicing.
       mock.call(self.relu1_op),
       mock.call(self.dwise_conv2_op),
       # Refreshing slice data.
       mock.call(self.relu1_op),
       # Group depthwise convolution.
       mock.call(self.dwise_conv2_op)])

  # Verify manager groups the depthwise convolution with its input.
  self.mock_op_reg_manager.group_op_slices.assert_called_once_with(
      [self.dwise_conv2_op_slice, self.relu1_op_slice])

  # Verify manager is asked to process only the ungrouped output op (conv2)
  # and that process_ops_last is never called.
  self.mock_op_reg_manager.process_ops.assert_called_once_with(
      [self.conv2_op])
  self.mock_op_reg_manager.process_ops_last.assert_not_called()
def setUp(self):
  """Builds a Conv2D -> BatchNorm -> ReLU graph and a mocked manager.

  Declares a full [0, 5) OpSlice and an OpGroup for the conv, the batch norm,
  the ReLU, and the batch norm's gamma/beta/moving-average ops. The mock
  manager answers get_op_slices / get_op_group from self.op_slice_dict and
  self.op_group_dict. This method does not create those dicts; each test
  must assign them before the handler runs.
  """
  tf.reset_default_graph()

  # This tests a Conv2D -> BatchNorm -> ReLU chain of ops.
  with tf.contrib.framework.arg_scope(self._batch_norm_scope()):
    inputs = tf.zeros([2, 4, 4, 3])
    layers.conv2d(inputs, num_outputs=5, kernel_size=3, scope='conv1')

  g = tf.get_default_graph()

  # Declare OpSlice and OpGroup for ops that are created in the test network.
  # The fused batch norm op is looked up as FusedBatchNormV3, the same op name
  # every other batch-norm fixture here resolves; the older 'FusedBatchNorm'
  # name is not present in graphs that emit the V3 op.
  self.batch_norm_op = g.get_operation_by_name(
      'conv1/BatchNorm/FusedBatchNormV3')
  self.batch_norm_op_slice = orm.OpSlice(self.batch_norm_op, orm.Slice(0, 5))
  self.batch_norm_op_group = orm.OpGroup(self.batch_norm_op_slice)

  self.conv_op = g.get_operation_by_name('conv1/Conv2D')
  self.conv_op_slice = orm.OpSlice(self.conv_op, orm.Slice(0, 5))
  self.conv_op_group = orm.OpGroup(
      self.conv_op_slice, omit_source_op_slices=[self.conv_op_slice])

  self.relu_op = g.get_operation_by_name('conv1/Relu')
  self.relu_op_slice = orm.OpSlice(self.relu_op, orm.Slice(0, 5))
  self.relu_op_group = orm.OpGroup(self.relu_op_slice)

  self.gamma_op = g.get_operation_by_name('conv1/BatchNorm/gamma/read')
  self.gamma_op_slice = orm.OpSlice(self.gamma_op, orm.Slice(0, 5))
  self.gamma_op_group = orm.OpGroup(
      self.gamma_op_slice, omit_source_op_slices=[self.gamma_op_slice])

  self.beta_op = g.get_operation_by_name('conv1/BatchNorm/beta/read')
  self.beta_op_slice = orm.OpSlice(self.beta_op, orm.Slice(0, 5))
  self.beta_op_group = orm.OpGroup(
      self.beta_op_slice, omit_source_op_slices=[self.beta_op_slice])

  self.mean_op = g.get_operation_by_name(
      'conv1/BatchNorm/AssignMovingAvg/sub_1')
  self.mean_op_slice = orm.OpSlice(self.mean_op, orm.Slice(0, 5))
  self.mean_op_group = orm.OpGroup(
      self.mean_op_slice, omit_source_op_slices=[self.mean_op_slice])

  self.std_op = g.get_operation_by_name(
      'conv1/BatchNorm/AssignMovingAvg_1/sub_1')
  self.std_op_slice = orm.OpSlice(self.std_op, orm.Slice(0, 5))
  self.std_op_group = orm.OpGroup(
      self.std_op_slice, omit_source_op_slices=[self.std_op_slice])

  # Create custom mapping of OpSlice and OpGroup in manager.
  self.mock_op_reg_manager = mock.create_autospec(orm.OpRegularizerManager)

  def get_op_slices(op):
    return self.op_slice_dict.get(op, [])

  def get_op_group(op_slice):
    return self.op_group_dict.get(op_slice)

  self.mock_op_reg_manager.get_op_slices.side_effect = get_op_slices
  self.mock_op_reg_manager.get_op_group.side_effect = get_op_group
  self.mock_op_reg_manager.is_source_op.return_value = False
  self.mock_op_reg_manager.ops = [
      self.batch_norm_op, self.conv_op, self.relu_op, self.gamma_op,
      self.beta_op, self.mean_op, self.std_op
  ]
def setUp(self):
  """Builds conv1 -> conv2 (with batch norm on conv2) and a mocked manager.

  conv1 has 5 output channels and no normalizer; conv2 has 6 output channels
  and gets its batch norm (conv2/BatchNorm) from the arg scope. The mock
  manager answers slice/group lookups from self.op_slice_dict and
  self.op_group_dict, which each test is expected to populate.
  """
  tf.reset_default_graph()

  # This tests 2 Conv2D ops with batch norm at the top.
  with tf.contrib.framework.arg_scope(self._batch_norm_scope()):
    first = layers.conv2d(tf.zeros([2, 4, 4, 3]), num_outputs=5,
                          kernel_size=3, scope='conv1', normalizer_fn=None)
    layers.conv2d(first, num_outputs=6, kernel_size=3, scope='conv2')

  graph = tf.get_default_graph()

  def sliced(op, size):
    # Full-width slice of `op` plus a group that omits that slice as a source.
    op_slice = orm.OpSlice(op, orm.Slice(0, size))
    return op_slice, orm.OpGroup(op_slice, omit_source_op_slices=[op_slice])

  # Declare OpSlice and OpGroup for ops of interest.
  self.conv1_op = graph.get_operation_by_name('conv1/Conv2D')
  self.conv1_op_slice, self.conv1_op_group = sliced(self.conv1_op, 5)

  self.relu1_op = graph.get_operation_by_name('conv1/Relu')
  self.relu1_op_slice, self.relu1_op_group = sliced(self.relu1_op, 5)

  self.conv2_op = graph.get_operation_by_name('conv2/Conv2D')
  self.conv2_op_slice, self.conv2_op_group = sliced(self.conv2_op, 6)

  self.relu2_op = graph.get_operation_by_name('conv2/Relu')
  self.relu2_op_slice, self.relu2_op_group = sliced(self.relu2_op, 6)

  self.batch_norm_op = graph.get_operation_by_name(
      'conv2/BatchNorm/FusedBatchNormV3')
  self.batch_norm_op_slice = orm.OpSlice(self.batch_norm_op, orm.Slice(0, 6))
  # Unlike the ops above, the batch norm group does not omit its own slice.
  self.batch_norm_op_group = orm.OpGroup(self.batch_norm_op_slice)

  # Create mock OpRegularizerManager with custom mapping of OpSlice and
  # OpGroup.
  manager = mock.create_autospec(orm.OpRegularizerManager)
  self.mock_op_reg_manager = manager

  def is_passthrough(op):
    # Convs defer to a freshly built OutputNonPassthroughOpHandler; the batch
    # norm is passthrough; everything else is not.
    if op in (self.conv1_op, self.conv2_op):
      return (output_non_passthrough_op_handler
              .OutputNonPassthroughOpHandler().is_passthrough)
    return bool(op == self.batch_norm_op)

  manager.get_op_slices.side_effect = (
      lambda op: self.op_slice_dict.get(op, []))
  manager.get_op_group.side_effect = (
      lambda op_slice: self.op_group_dict.get(op_slice))
  manager.is_source_op.return_value = False
  manager.is_passthrough.side_effect = is_passthrough
  manager.ops = [
      self.conv1_op, self.relu1_op, self.conv2_op, self.relu2_op,
      self.batch_norm_op
  ]
def _build(self, conv_type):
  """Builds conv1 -> conv2 (no normalizers) and a mocked manager.

  Args:
    conv_type: 'Conv2D' or 'Conv3D'. Selects layers.conv2d on a 4-D zeros
      input or layers.conv3d on a 5-D zeros input. It is also the op-name
      suffix used to look up the two convolution ops.
  """
  assert conv_type in ('Conv2D', 'Conv3D')
  use_2d = conv_type == 'Conv2D'
  inputs = tf.zeros([2, 4, 4, 3] if use_2d else [2, 4, 4, 4, 3])
  conv_fn = layers.conv2d if use_2d else layers.conv3d

  first = conv_fn(
      inputs, num_outputs=5, kernel_size=3, scope='conv1', normalizer_fn=None)
  conv_fn(first, num_outputs=6, kernel_size=3, scope='conv2',
          normalizer_fn=None)

  graph = tf.get_default_graph()

  def sliced(op, size):
    # Full-width slice of `op` plus a group that omits that slice as a source.
    op_slice = orm.OpSlice(op, orm.Slice(0, size))
    return op_slice, orm.OpGroup(op_slice, omit_source_op_slices=[op_slice])

  # Declare OpSlice and OpGroup for ops of interest.
  self.conv1_op = graph.get_operation_by_name('conv1/' + conv_type)
  self.conv1_op_slice, self.conv1_op_group = sliced(self.conv1_op, 5)

  self.relu1_op = graph.get_operation_by_name('conv1/Relu')
  self.relu1_op_slice, self.relu1_op_group = sliced(self.relu1_op, 5)

  self.conv2_op = graph.get_operation_by_name('conv2/' + conv_type)
  self.conv2_op_slice, self.conv2_op_group = sliced(self.conv2_op, 6)

  self.conv2_weights_op = graph.get_operation_by_name('conv2/weights/read')
  self.conv2_weights_op_slice, self.conv2_weights_op_group = sliced(
      self.conv2_weights_op, 6)

  self.relu2_op = graph.get_operation_by_name('conv2/Relu')
  self.relu2_op_slice, self.relu2_op_group = sliced(self.relu2_op, 6)

  # Create mock OpRegularizerManager with custom mapping of OpSlice and
  # OpGroup.
  manager = mock.create_autospec(orm.OpRegularizerManager)
  self.mock_op_reg_manager = manager

  def is_passthrough(op):
    # Only the two convs consult a (freshly built) ConvSourceOpHandler.
    if op in (self.conv1_op, self.conv2_op):
      return conv_source_op_handler.ConvSourceOpHandler(
          _DEFAULT_THRESHOLD).is_passthrough
    return False

  manager.get_op_slices.side_effect = (
      lambda op: self.op_slice_dict.get(op, []))
  manager.get_op_group.side_effect = (
      lambda op_slice: self.op_group_dict.get(op_slice))
  manager.is_source_op.return_value = False
  manager.is_passthrough.side_effect = is_passthrough
  manager.ops = [
      self.conv1_op, self.relu1_op, self.conv2_op, self.relu2_op,
      self.conv2_weights_op
  ]
def testAssignGrouping_ProcessNeighborGroupsWithSlices(self):
  """A batch norm split into two slices is grouped slice-by-slice.

  Every op around the conv1 batch norm is split into the same two slices,
  [0, 2) and [2, 3), and every slice already has a group. For each slice the
  handler should make two group_op_slices calls. The first joins the batch
  norm with its relu/mean/std slices. The second joins it with its
  conv/gamma/beta slices and passes omit_source_op_slices=[]. The handler
  should never call process_ops or process_ops_last.
  """
  def two_slices(op, omit_source):
    # Returns ((slice_0_2, slice_2_3), (group1, group2)) for `op`.
    op_slices = (orm.OpSlice(op, orm.Slice(0, 2)),
                 orm.OpSlice(op, orm.Slice(2, 1)))
    if omit_source:
      op_groups = tuple(
          orm.OpGroup(s, omit_source_op_slices=[s]) for s in op_slices)
    else:
      op_groups = tuple(orm.OpGroup(s) for s in op_slices)
    return op_slices, op_groups

  bn_slices, bn_groups = two_slices(self.batch_norm_op, omit_source=False)
  conv_slices, conv_groups = two_slices(self.conv_op, omit_source=True)
  relu_slices, relu_groups = two_slices(self.relu_op, omit_source=False)
  gamma_slices, gamma_groups = two_slices(self.gamma_op, omit_source=True)
  beta_slices, beta_groups = two_slices(self.beta_op, omit_source=True)
  mean_slices, mean_groups = two_slices(self.mean_op, omit_source=True)
  std_slices, std_groups = two_slices(self.std_op, omit_source=True)

  per_op = [
      (self.batch_norm_op, bn_slices, bn_groups),
      (self.conv_op, conv_slices, conv_groups),
      (self.relu_op, relu_slices, relu_groups),
      (self.gamma_op, gamma_slices, gamma_groups),
      (self.beta_op, beta_slices, beta_groups),
      (self.mean_op, mean_slices, mean_groups),
      (self.std_op, std_slices, std_groups),
  ]
  self.op_slice_dict = {op: list(op_slices) for op, op_slices, _ in per_op}

  # All OpSlice have groups.
  self.op_group_dict = {}
  for _, op_slices, op_groups in per_op:
    self.op_group_dict.update(zip(op_slices, op_groups))

  # Call handler to assign grouping.
  handler = batch_norm_source_op_handler.BatchNormSourceOpHandler(
      _GAMMA_THRESHOLD)
  handler.assign_grouping(self.batch_norm_op, self.mock_op_reg_manager)

  # Verify manager looks up op slice for ops of interest.
  for op in (self.batch_norm_op, self.conv_op, self.relu_op):
    self.mock_op_reg_manager.get_op_slices.assert_any_call(op)

  # Verify manager groups batch norm with outputs and inputs by slice.
  expected_calls = []
  for bn, relu, mean, std, conv, gamma, beta in zip(
      bn_slices, relu_slices, mean_slices, std_slices, conv_slices,
      gamma_slices, beta_slices):
    expected_calls.append(mock.call([bn, relu, mean, std]))
    expected_calls.append(
        mock.call([bn, conv, gamma, beta], omit_source_op_slices=[]))
  self.mock_op_reg_manager.group_op_slices.assert_has_calls(expected_calls)

  # Verify manager does not reprocess any ops.
  self.mock_op_reg_manager.process_ops.assert_not_called()
  self.mock_op_reg_manager.process_ops_last.assert_not_called()
def setUp(self):
  """Builds three 6-channel convs concatenated on axis 2, then a batch norm.

  The concat is taken on axis 2 rather than the last axis. That is why the
  concat, the three ReLUs and the batch norm are all declared with a single
  [0, 6) slice. self.concat_group bundles all five of their groups.
  """
  tf.reset_default_graph()

  # This tests 3 Conv2D ops being concatenated.
  inputs = tf.zeros([2, 4, 4, 3])
  with tf.contrib.framework.arg_scope(self._get_scope()):
    convs = [
        layers.conv2d(inputs, num_outputs=6, kernel_size=3, scope=scope)
        for scope in ('conv1', 'conv2', 'conv3')
    ]
    layers.batch_norm(tf.concat(convs, axis=2))

  graph = tf.get_default_graph()

  def sliced(op):
    # [0, 6) slice of `op` plus a group that omits that slice as a source.
    op_slice = orm.OpSlice(op, orm.Slice(0, 6))
    return op_slice, orm.OpGroup(op_slice, omit_source_op_slices=[op_slice])

  # Declare OpSlice and OpGroup for ops of interest.
  self.concat_op = graph.get_operation_by_name('concat')
  self.concat_op_slice, self.concat_op_group = sliced(self.concat_op)

  self.relu1_op = graph.get_operation_by_name('conv1/Relu')
  self.relu1_op_slice, self.relu1_op_group = sliced(self.relu1_op)

  self.relu2_op = graph.get_operation_by_name('conv2/Relu')
  self.relu2_op_slice, self.relu2_op_group = sliced(self.relu2_op)

  self.relu3_op = graph.get_operation_by_name('conv3/Relu')
  self.relu3_op_slice, self.relu3_op_group = sliced(self.relu3_op)

  self.batch_norm_op = graph.get_operation_by_name(
      'BatchNorm/FusedBatchNormV3')
  self.batch_norm_op_slice, self.batch_norm_op_group = sliced(
      self.batch_norm_op)

  member_groups = [
      self.batch_norm_op_group, self.concat_op_group, self.relu1_op_group,
      self.relu2_op_group, self.relu3_op_group
  ]
  self.concat_group = orm.OpGroup(op_slice=None, op_groups=member_groups)

  # Create mock OpRegularizerManager with custom mapping of OpSlice and
  # OpGroup.
  manager = mock.create_autospec(orm.OpRegularizerManager)
  self.mock_op_reg_manager = manager
  manager.get_op_slices.side_effect = (
      lambda op: self.op_slice_dict.get(op, []))
  manager.get_op_group.side_effect = (
      lambda op_slice: self.op_group_dict.get(op_slice))
  manager.is_source_op.return_value = False
  manager.is_passthrough.return_value = True
  manager.ops = [
      self.concat_op, self.relu1_op, self.relu2_op, self.relu3_op,
      self.batch_norm_op
  ]
def testAssignGrouping_AllNeighborsGrouped_InputSlicesNotAligned(self):
  """Concat and batch norm are resliced to match a multi-slice input.

  The op c2 has size 6 but is split into 2 slices of size 3. The concat op
  (and its output, the batch norm) both have size 18. The concat and batch
  norm should be sliced to [5, 3, 3, 7] -- the sizes of c1, the two c2
  slices, and c3 -- and each aligned triple of slices grouped together.
  """
  # (start, size) pairs shared by the resliced concat and batch norm.
  spans = ((0, 5), (5, 3), (8, 3), (11, 7))
  sizes = [size for _, size in spans]

  def own_group(op_slice):
    return orm.OpGroup(op_slice, omit_source_op_slices=[op_slice])

  concat_slices = [
      orm.OpSlice(self.concat_op, orm.Slice(start, size))
      for start, size in spans
  ]

  relu2_op_slice_0_3 = orm.OpSlice(self.relu2_op, orm.Slice(0, 3))
  relu2_op_slice_3_6 = orm.OpSlice(self.relu2_op, orm.Slice(3, 3))
  relu2_op_group1 = own_group(relu2_op_slice_0_3)
  relu2_op_group2 = own_group(relu2_op_slice_3_6)

  # The batch norm starts as one [0, 18) slice ...
  batch_norm_op_slice = orm.OpSlice(self.batch_norm_op, orm.Slice(0, 18))
  batch_norm_op_group = own_group(batch_norm_op_slice)
  # ... and is replaced by these slices once the handler slices it.
  bn_slices = [
      orm.OpSlice(self.batch_norm_op, orm.Slice(start, size))
      for start, size in spans
  ]
  bn_groups = [own_group(bn_slice) for bn_slice in bn_slices]

  # Map ops to slices. The op c2 is composed of multiple slices.
  self.op_slice_dict = {
      self.relu1_op: [self.relu1_op_slice],
      self.relu2_op: [relu2_op_slice_0_3, relu2_op_slice_3_6],
      self.relu3_op: [self.relu3_op_slice],
      self.concat_op: [self.concat_op_slice],
      self.batch_norm_op: [batch_norm_op_slice],
  }

  # Map each slice to a group.
  self.op_group_dict = {
      self.relu1_op_slice: self.relu1_op_group,
      relu2_op_slice_0_3: relu2_op_group1,
      relu2_op_slice_3_6: relu2_op_group2,
      self.relu3_op_slice: self.relu3_op_group,
      batch_norm_op_slice: batch_norm_op_group,
  }
  self.op_group_dict.update(zip(bn_slices, bn_groups))

  def slice_op(op, _):
    # Update op_slice_dict when an op is sliced.
    if op == self.batch_norm_op:
      self.op_slice_dict[self.batch_norm_op] = list(bn_slices)
    if op == self.concat_op:
      self.op_slice_dict[self.concat_op] = list(concat_slices)

  self.mock_op_reg_manager.slice_op.side_effect = slice_op

  # Call handler to assign grouping.
  handler = concat_op_handler.ConcatOpHandler()
  handler.assign_grouping(self.concat_op, self.mock_op_reg_manager)

  # Verify manager looks up OpSlice for ops of interest.
  neighbor_ops = [
      self.relu1_op, self.relu2_op, self.relu3_op, self.batch_norm_op
  ]
  expected_lookups = (
      neighbor_ops +                       # Checking for ops to process.
      [self.concat_op] + neighbor_ops +    # Initial slice data.
      neighbor_ops + [self.concat_op] +    # Reslicing.
      neighbor_ops +                       # Refreshing slice data.
      [self.concat_op])                    # Group concat op.
  self.mock_op_reg_manager.get_op_slices.assert_has_calls(
      [mock.call(op) for op in expected_lookups])

  # Verify manager slices ops that do not have aligned OpSlice sizes.
  self.mock_op_reg_manager.slice_op.assert_has_calls(
      [mock.call(self.batch_norm_op, sizes),
       mock.call(self.concat_op, sizes)])

  # Verify manager groups the new slices.
  input_slices = [
      self.relu1_op_slice, relu2_op_slice_0_3, relu2_op_slice_3_6,
      self.relu3_op_slice
  ]
  self.mock_op_reg_manager.group_op_slices.assert_has_calls([
      mock.call([concat_slice, input_slice, bn_slice])
      for concat_slice, input_slice, bn_slice in zip(
          concat_slices, input_slices, bn_slices)
  ])
def testAssignGrouping_AllNeighborsGrouped_OutputSlicesNotAligned(self):
  """Concat, batch norm and input ReLUs are resliced to common boundaries.

  The inputs relu1/relu2/relu3 cover [0, 5), [5, 11) and [11, 18) of the
  concat. The output (batch norm) starts with sizes [9, 4, 5], which are not
  aligned with the inputs. The handler should slice relu2 into [4, 2] and
  relu3 into [2, 5], and slice the batch norm and concat into the common
  refinement [5, 4, 2, 2, 5]. It should then group each aligned triple.
  """
  concat_op_slice_0_5 = orm.OpSlice(self.concat_op, orm.Slice(0, 5))
  concat_op_slice_5_9 = orm.OpSlice(self.concat_op, orm.Slice(5, 4))
  concat_op_slice_9_11 = orm.OpSlice(self.concat_op, orm.Slice(9, 2))
  concat_op_slice_11_13 = orm.OpSlice(self.concat_op, orm.Slice(11, 2))
  concat_op_slice_13_18 = orm.OpSlice(self.concat_op, orm.Slice(13, 5))

  relu2_op_slice_0_4 = orm.OpSlice(self.relu2_op, orm.Slice(0, 4))
  relu2_op_slice_4_6 = orm.OpSlice(self.relu2_op, orm.Slice(4, 2))

  relu3_op_slice_0_2 = orm.OpSlice(self.relu3_op, orm.Slice(0, 2))
  relu3_op_slice_2_7 = orm.OpSlice(self.relu3_op, orm.Slice(2, 5))

  # Initial batch norm slicing: sizes [9, 4, 5].
  batch_norm_op_slice_0_9 = orm.OpSlice(self.batch_norm_op, orm.Slice(0, 9))
  batch_norm_op_group1 = orm.OpGroup(
      batch_norm_op_slice_0_9, omit_source_op_slices=[batch_norm_op_slice_0_9])
  batch_norm_op_slice_9_13 = orm.OpSlice(self.batch_norm_op, orm.Slice(9, 4))
  batch_norm_op_group2 = orm.OpGroup(
      batch_norm_op_slice_9_13,
      omit_source_op_slices=[batch_norm_op_slice_9_13])
  batch_norm_op_slice_13_18 = orm.OpSlice(
      self.batch_norm_op, orm.Slice(13, 5))
  batch_norm_op_group3 = orm.OpGroup(
      batch_norm_op_slice_13_18,
      omit_source_op_slices=[batch_norm_op_slice_13_18])

  # Batch norm slicing after alignment: sizes [5, 4, 2, 2, 5]. The trailing
  # [13, 18) slice is identical in both slicings, so batch_norm_op_slice_13_18
  # (and its group) is reused. Previously it was redeclared here with a wrong
  # start index, Slice(11, 5), which also shadowed the initial slice.
  batch_norm_op_slice_0_5 = orm.OpSlice(self.batch_norm_op, orm.Slice(0, 5))
  batch_norm_op_group4 = orm.OpGroup(
      batch_norm_op_slice_0_5, omit_source_op_slices=[batch_norm_op_slice_0_5])
  batch_norm_op_slice_5_9 = orm.OpSlice(self.batch_norm_op, orm.Slice(5, 4))
  batch_norm_op_group5 = orm.OpGroup(
      batch_norm_op_slice_5_9, omit_source_op_slices=[batch_norm_op_slice_5_9])
  batch_norm_op_slice_9_11 = orm.OpSlice(self.batch_norm_op, orm.Slice(9, 2))
  batch_norm_op_group6 = orm.OpGroup(
      batch_norm_op_slice_9_11,
      omit_source_op_slices=[batch_norm_op_slice_9_11])
  batch_norm_op_slice_11_13 = orm.OpSlice(
      self.batch_norm_op, orm.Slice(11, 2))
  batch_norm_op_group7 = orm.OpGroup(
      batch_norm_op_slice_11_13,
      omit_source_op_slices=[batch_norm_op_slice_11_13])

  # Map ops to slices. Batch norm op is composed of multiple slices.
  self.op_slice_dict = {
      self.relu1_op: [self.relu1_op_slice],
      self.relu2_op: [self.relu2_op_slice],
      self.relu3_op: [self.relu3_op_slice],
      self.concat_op: [self.concat_op_slice],
      self.batch_norm_op: [batch_norm_op_slice_0_9, batch_norm_op_slice_9_13,
                           batch_norm_op_slice_13_18],
  }

  # Map each slice to a group (one entry per distinct slice).
  self.op_group_dict = {
      self.relu1_op_slice: self.relu1_op_group,
      self.relu2_op_slice: self.relu2_op_group,
      self.relu3_op_slice: self.relu3_op_group,
      batch_norm_op_slice_0_9: batch_norm_op_group1,
      batch_norm_op_slice_9_13: batch_norm_op_group2,
      batch_norm_op_slice_13_18: batch_norm_op_group3,
      batch_norm_op_slice_0_5: batch_norm_op_group4,
      batch_norm_op_slice_5_9: batch_norm_op_group5,
      batch_norm_op_slice_9_11: batch_norm_op_group6,
      batch_norm_op_slice_11_13: batch_norm_op_group7,
  }

  # Update op_slice_dict when an op is sliced.
  def slice_op(op, _):
    if op == self.batch_norm_op:
      self.op_slice_dict[self.batch_norm_op] = [
          batch_norm_op_slice_0_5,
          batch_norm_op_slice_5_9,
          batch_norm_op_slice_9_11,
          batch_norm_op_slice_11_13,
          batch_norm_op_slice_13_18]
    if op == self.concat_op:
      self.op_slice_dict[self.concat_op] = [
          concat_op_slice_0_5,
          concat_op_slice_5_9,
          concat_op_slice_9_11,
          concat_op_slice_11_13,
          concat_op_slice_13_18]
    if op == self.relu2_op:
      self.op_slice_dict[self.relu2_op] = [
          relu2_op_slice_0_4, relu2_op_slice_4_6]
    if op == self.relu3_op:
      self.op_slice_dict[self.relu3_op] = [
          relu3_op_slice_0_2, relu3_op_slice_2_7]

  self.mock_op_reg_manager.slice_op.side_effect = slice_op

  # Call handler to assign grouping.
  handler = concat_op_handler.ConcatOpHandler()
  handler.assign_grouping(self.concat_op, self.mock_op_reg_manager)

  # Verify manager looks up OpSlice for ops of interest.
  self.mock_op_reg_manager.get_op_slices.assert_has_calls(
      # Checking for ops to process.
      [mock.call(self.relu1_op),
       mock.call(self.relu2_op),
       mock.call(self.relu3_op),
       mock.call(self.batch_norm_op),
       # Initial slice data.
       mock.call(self.concat_op),
       mock.call(self.relu1_op),
       mock.call(self.relu2_op),
       mock.call(self.relu3_op),
       mock.call(self.batch_norm_op),
       # Reslicing.
       mock.call(self.relu1_op),
       mock.call(self.relu2_op),
       mock.call(self.relu3_op),
       mock.call(self.batch_norm_op),
       mock.call(self.concat_op),
       # Refreshing slice data.
       mock.call(self.relu1_op),
       mock.call(self.relu2_op),
       mock.call(self.relu3_op),
       mock.call(self.batch_norm_op),
       # Group concat op.
       mock.call(self.concat_op)])

  # Verify manager slices ops that do not have aligned OpSlice sizes.
  self.mock_op_reg_manager.slice_op.assert_has_calls(
      [mock.call(self.relu2_op, [4, 2]),
       mock.call(self.relu3_op, [2, 5]),
       mock.call(self.batch_norm_op, [5, 4, 2, 2, 5]),
       mock.call(self.concat_op, [5, 4, 2, 2, 5])])

  # Verify manager groups the new slices.
  self.mock_op_reg_manager.group_op_slices.assert_has_calls(
      [mock.call([concat_op_slice_0_5, self.relu1_op_slice,
                  batch_norm_op_slice_0_5]),
       mock.call([concat_op_slice_5_9, relu2_op_slice_0_4,
                  batch_norm_op_slice_5_9]),
       mock.call([concat_op_slice_9_11, relu2_op_slice_4_6,
                  batch_norm_op_slice_9_11]),
       mock.call([concat_op_slice_11_13, relu3_op_slice_0_2,
                  batch_norm_op_slice_11_13]),
       mock.call([concat_op_slice_13_18, relu3_op_slice_2_7,
                  batch_norm_op_slice_13_18])])
def setUp(self):
  """Builds conv1/conv2/conv3 (5/6/7 channels) concatenated on axis 3.

  The concat and the batch norm that follows it have 18 channels. Besides
  the full [0, 18) slices, this declares the [5, 6, 7] sub-slices (with their
  groups). The mock manager's slice_op swaps those sub-slices into
  self.op_slice_dict when the handler slices the concat or the batch norm.
  """
  tf.reset_default_graph()

  # This tests 3 Conv2D ops being concatenated.
  inputs = tf.zeros([2, 4, 4, 3])
  convs = [
      layers.conv2d(inputs, num_outputs=channels, kernel_size=3, scope=scope)
      for scope, channels in (('conv1', 5), ('conv2', 6), ('conv3', 7))
  ]
  layers.batch_norm(tf.concat(convs, axis=3))

  graph = tf.get_default_graph()

  def own_group(op_slice):
    return orm.OpGroup(op_slice, omit_source_op_slices=[op_slice])

  # Declare OpSlice and OpGroup for ops of interest.
  self.concat_op = graph.get_operation_by_name('concat')
  self.concat_op_slice = orm.OpSlice(self.concat_op, orm.Slice(0, 18))
  self.concat_op_slice_0_5 = orm.OpSlice(self.concat_op, orm.Slice(0, 5))
  self.concat_op_slice_5_11 = orm.OpSlice(self.concat_op, orm.Slice(5, 6))
  self.concat_op_slice_11_18 = orm.OpSlice(self.concat_op, orm.Slice(11, 7))
  self.concat_op_group1 = own_group(self.concat_op_slice_0_5)
  self.concat_op_group2 = own_group(self.concat_op_slice_5_11)
  self.concat_op_group3 = own_group(self.concat_op_slice_11_18)

  self.relu1_op = graph.get_operation_by_name('conv1/Relu')
  self.relu1_op_slice = orm.OpSlice(self.relu1_op, orm.Slice(0, 5))
  self.relu1_op_group = own_group(self.relu1_op_slice)

  self.relu2_op = graph.get_operation_by_name('conv2/Relu')
  self.relu2_op_slice = orm.OpSlice(self.relu2_op, orm.Slice(0, 6))
  self.relu2_op_group = own_group(self.relu2_op_slice)

  self.relu3_op = graph.get_operation_by_name('conv3/Relu')
  self.relu3_op_slice = orm.OpSlice(self.relu3_op, orm.Slice(0, 7))
  self.relu3_op_group = own_group(self.relu3_op_slice)

  self.axis_op = graph.get_operation_by_name('concat/axis')

  self.batch_norm_op = graph.get_operation_by_name(
      'BatchNorm/FusedBatchNormV3')
  self.batch_norm_op_slice = orm.OpSlice(self.batch_norm_op, orm.Slice(0, 18))
  self.batch_norm_op_group = own_group(self.batch_norm_op_slice)
  self.batch_norm_op_slice_0_5 = orm.OpSlice(
      self.batch_norm_op, orm.Slice(0, 5))
  self.batch_norm_op_slice_5_11 = orm.OpSlice(
      self.batch_norm_op, orm.Slice(5, 6))
  self.batch_norm_op_slice_11_18 = orm.OpSlice(
      self.batch_norm_op, orm.Slice(11, 7))
  self.batch_norm_op_group1 = own_group(self.batch_norm_op_slice_0_5)
  self.batch_norm_op_group2 = own_group(self.batch_norm_op_slice_5_11)
  self.batch_norm_op_group3 = own_group(self.batch_norm_op_slice_11_18)

  # Create mock OpRegularizerManager with custom mapping of OpSlice and
  # OpGroup.
  manager = mock.create_autospec(orm.OpRegularizerManager)
  self.mock_op_reg_manager = manager

  def slice_op(op, _):
    # Update op_slice_dict when an op is sliced: install the pre-declared
    # [5, 6, 7] sub-slices for the batch norm or the concat.
    resliced = (
        (self.batch_norm_op, (self.batch_norm_op_slice_0_5,
                              self.batch_norm_op_slice_5_11,
                              self.batch_norm_op_slice_11_18)),
        (self.concat_op, (self.concat_op_slice_0_5,
                          self.concat_op_slice_5_11,
                          self.concat_op_slice_11_18)),
    )
    for target, pieces in resliced:
      if op == target:
        self.op_slice_dict[target] = list(pieces)

  manager.get_op_slices.side_effect = (
      lambda op: self.op_slice_dict.get(op, []))
  manager.get_op_group.side_effect = (
      lambda op_slice: self.op_group_dict.get(op_slice))
  manager.is_source_op.return_value = False
  manager.slice_op.side_effect = slice_op
  manager.is_passthrough.return_value = True
  manager.ops = [
      self.concat_op, self.relu1_op, self.relu2_op, self.relu3_op,
      self.batch_norm_op
  ]
def setUp(self):
  """Builds two small networks in one graph and a mocked manager.

  Network 1 is conv1 -> BatchNorm -> ReLU, with its slices declared with a
  None slice. Network 2 is conv2/conv3/conv4 (5/6/7 channels) concatenated on
  axis 3 and then batch-normalized (18 channels). The mock manager answers
  get_op_slices / get_op_group from self.op_slice_dict / self.op_group_dict.
  It answers is_passthrough from self._passthrough_ops. This method sets none
  of those three attributes -- presumably each test (or the class) does;
  confirm.
  """
  tf.reset_default_graph()

  with tf.contrib.framework.arg_scope(self._batch_norm_scope()):
    inputs = tf.zeros([2, 4, 4, 3])
    # This tests a Conv2D -> BatchNorm -> ReLU chain of ops.
    layers.conv2d(inputs, num_outputs=5, kernel_size=3, scope='conv1')

    # This tests 3 Conv2D ops being concatenated before a batch normalization.
    branches = [
        layers.conv2d(inputs, num_outputs=channels, kernel_size=3,
                      scope=scope)
        for scope, channels in (('conv2', 5), ('conv3', 6), ('conv4', 7))
    ]
    layers.batch_norm(tf.concat(branches, axis=3))

  by_name = tf.get_default_graph().get_operation_by_name

  def own_group(op_slice):
    return orm.OpGroup(op_slice, omit_source_op_slices=[op_slice])

  # Declare OpSlice and OpGroup for ops in the first test network.
  self.batch_norm_op = by_name('conv1/BatchNorm/FusedBatchNormV3')
  self.batch_norm_op_slice = orm.OpSlice(self.batch_norm_op, None)
  self.batch_norm_op_group = orm.OpGroup(self.batch_norm_op_slice)

  self.conv_op = by_name('conv1/Conv2D')
  self.conv_op_slice = orm.OpSlice(self.conv_op, None)
  self.conv_op_group = own_group(self.conv_op_slice)

  self.gamma_op = by_name('conv1/BatchNorm/gamma/read')
  self.beta_op = by_name('conv1/BatchNorm/beta/read')
  self.decay_op = by_name('conv1/BatchNorm/Const')
  self.epsilon_op = by_name('conv1/BatchNorm/Const_1')
  self.mean_op = by_name('conv1/BatchNorm/AssignMovingAvg/sub_1')
  self.std_op = by_name('conv1/BatchNorm/AssignMovingAvg_1/sub_1')

  self.relu_op = by_name('conv1/Relu')
  self.relu_op_slice = orm.OpSlice(self.relu_op, None)
  self.relu_op_group = own_group(self.relu_op_slice)

  # Declare OpSlice and OpGroup for ops in the second test network.
  self.relu2_op = by_name('conv2/Relu')
  self.relu2_op_slice = orm.OpSlice(self.relu2_op, orm.Slice(0, 5))
  self.relu2_op_group = own_group(self.relu2_op_slice)

  self.relu3_op = by_name('conv3/Relu')
  self.relu3_op_slice = orm.OpSlice(self.relu3_op, orm.Slice(0, 6))
  self.relu3_op_group = own_group(self.relu3_op_slice)

  self.relu4_op = by_name('conv4/Relu')
  self.relu4_op_slice = orm.OpSlice(self.relu4_op, orm.Slice(0, 7))
  self.relu4_op_group = own_group(self.relu4_op_slice)

  # NOTE(review): despite the attribute name, this is the FusedBatchNormV3 op
  # applied to the concat output; no group is declared for it here.
  self.unfused_batch_norm_op = by_name('BatchNorm/FusedBatchNormV3')
  self.unfused_batch_norm_op_slice = orm.OpSlice(
      self.unfused_batch_norm_op, orm.Slice(0, 18))

  self.concat_op = by_name('concat')
  self.concat_op_slice = orm.OpSlice(self.concat_op, orm.Slice(0, 18))
  self.concat_op_group = own_group(self.concat_op_slice)

  # Create mock OpRegularizerManager with custom mapping of OpSlice and
  # OpGroup.
  manager = mock.create_autospec(orm.OpRegularizerManager)
  self.mock_op_reg_manager = manager
  manager.get_op_slices.side_effect = (
      lambda op: self.op_slice_dict.get(op, []))
  manager.get_op_group.side_effect = (
      lambda op_slice: self.op_group_dict.get(op_slice))
  manager.is_passthrough.side_effect = (
      lambda op: op in self._passthrough_ops)
  manager.ops = [
      self.batch_norm_op, self.gamma_op, self.beta_op, self.decay_op,
      self.epsilon_op, self.mean_op, self.std_op, self.conv_op, self.relu_op,
      self.relu2_op, self.relu3_op, self.relu4_op, self.unfused_batch_norm_op,
      self.concat_op
  ]