Example #1
0
  def testIsBroadcast(self):
    """Tests that `_is_broadcast` needs both a size-1 OpSlice and a group.

    Three cases are checked: an op whose slice size is not 1, a size-1 op that
    has no group, and a size-1 op that has been assigned a group.
    """
    handler = grouping_op_handler.GroupingOpHandler()
    # Start with no grouped OpSlices; presumably the mock manager consults
    # this dict, so only the op explicitly grouped below counts as grouped.
    self.op_group_dict = {}

    # Size is not 1.
    self.assertFalse(handler._is_broadcast(self.batch_norm_op,
                                           self.mock_op_reg_manager))

    # Size is 1 but op is not grouped.
    ungrouped_broadcast_input = tf.zeros([2, 4, 4, 1])
    # OpSlice wraps the tf.Operation (not its output tensor), matching the
    # op_slice_dict key and the grouped case below.
    ungrouped_broadcast_input_slice = orm.OpSlice(
        ungrouped_broadcast_input.op, orm.Slice(0, 1))
    self.op_slice_dict[ungrouped_broadcast_input.op] = [
        ungrouped_broadcast_input_slice]
    self.assertFalse(handler._is_broadcast(ungrouped_broadcast_input.op,
                                           self.mock_op_reg_manager))

    # Size is 1 and op is grouped.
    broadcast_input = tf.zeros([2, 4, 4, 1])
    broadcast_input_slice = orm.OpSlice(broadcast_input.op, orm.Slice(0, 1))
    self.op_slice_dict[broadcast_input.op] = [broadcast_input_slice]
    broadcast_input_group = orm.OpGroup(broadcast_input_slice)
    self.op_group_dict[broadcast_input_slice] = broadcast_input_group
    self.assertTrue(handler._is_broadcast(broadcast_input.op,
                                          self.mock_op_reg_manager))
Example #2
0
  def testAssignGrouping_AllInputsGrouped(self):
    """Batch norm groups with its inputs when all inputs already have groups.

    The inputs (conv, gamma, beta) and the ReLU output are grouped, but the
    mean/std outputs are not. The handler should group batch norm with its
    inputs and ask the manager to process mean_op and std_op.
    """
    self.op_group_dict = {
        self.batch_norm_op_slice: self.batch_norm_op_group,
        self.conv_op_slice: self.conv_op_group,
        self.relu_op_slice: self.relu_op_group,
        self.gamma_op_slice: self.conv_op_group,
        self.beta_op_slice: self.conv_op_group,
    }

    grouping_op_handler.GroupingOpHandler().assign_grouping(
        self.batch_norm_op, self.mock_op_reg_manager)

    # Expected OpSlice lookups, phase by phase.
    input_ops = [self.conv_op, self.gamma_op, self.beta_op]
    output_ops = [self.relu_op, self.mean_op, self.std_op]
    neighbor_ops = input_ops + output_ops
    expected_lookups = (
        neighbor_ops                                     # Ops to process.
        + [self.batch_norm_op] + neighbor_ops            # Initial slice data.
        + input_ops + [self.batch_norm_op] + output_ops  # Reslicing.
        + neighbor_ops                                   # Refreshed slice data.
        + [self.batch_norm_op])                          # Group batch norm.
    self.mock_op_reg_manager.get_op_slices.assert_has_calls(
        [mock.call(lookup_op) for lookup_op in expected_lookups])

    # Batch norm is grouped with its inputs only.
    self.mock_op_reg_manager.group_op_slices.assert_called_once_with(
        [self.batch_norm_op_slice, self.conv_op_slice, self.gamma_op_slice,
         self.beta_op_slice])

    # The ungrouped outputs are handed back to the manager for processing.
    manager = self.mock_op_reg_manager
    manager.process_ops.assert_called_once_with([self.mean_op, self.std_op])
    manager.process_ops_last.assert_not_called()
Example #3
0
def _get_base_op_hander_dicts():
    """Build the base op_handler_dict shared by all regularizers.

    Op types not listed explicitly fall back to a newly constructed
    `GroupingOpHandler`, supplied by the `defaultdict` factory on first
    lookup.

    Returns:
      A `collections.defaultdict` mapping op type names to OpHandler objects.
    """
    leaf = leaf_op_handler.LeafOpHandler
    non_passthrough = (
        output_non_passthrough_op_handler.OutputNonPassthroughOpHandler)
    depthwise = depthwise_convolution_op_handler.DepthwiseConvolutionOpHandler

    handlers = collections.defaultdict(grouping_op_handler.GroupingOpHandler)
    # Each entry gets its own handler instance.
    handlers.update({
        'ConcatV2': concat_op_handler.ConcatOpHandler(),
        'DepthToSpace': non_passthrough(),
        'DepthwiseConv2dNative': depthwise(),
        'ExpandDims': non_passthrough(),
        'RandomUniform': leaf(),
        'Reshape': leaf(),
        'Shape': leaf(),
        'SpaceToDepth': leaf(),
        'StridedSlice': leaf(),
        'TensorArrayGatherV3': leaf(),
        'Transpose': non_passthrough(),
    })

    # Resize* ops: the second input might be a tensor, which would result in
    # an error, so the handler is built with [0] (presumably limiting grouping
    # to input 0).
    handlers.update({
        resize_name: grouping_op_handler.GroupingOpHandler([0])
        for resize_name in RESIZE_OP_NAMES
    })
    return handlers
    def testAssignGrouping_AllNeighborsGroupedSameGroup(self):
        """No extra work is requested when every neighbor shares batch norm's group.

        All inputs and outputs already belong to the batch norm group, so the
        manager should neither group nor process any additional ops.
        """
        bn_group = self.batch_norm_op_group
        grouped_slices = [
            self.batch_norm_op_slice,
            self.conv_op_slice,
            self.relu_op_slice,
            self.gamma_op_slice,
            self.beta_op_slice,
            self.mean_op_slice,
            self.std_op_slice,
        ]
        self.op_group_dict = {op_slice: bn_group for op_slice in grouped_slices}

        grouping_op_handler.GroupingOpHandler().assign_grouping(
            self.batch_norm_op, self.mock_op_reg_manager)

        # Expected OpSlice lookups, phase by phase.
        input_ops = [self.conv_op, self.gamma_op, self.beta_op]
        output_ops = [self.relu_op, self.mean_op, self.std_op]
        neighbor_ops = input_ops + output_ops
        expected_lookups = (
            neighbor_ops                                     # Ops to process.
            + [self.batch_norm_op] + neighbor_ops            # Initial slice data.
            + input_ops + [self.batch_norm_op] + output_ops  # Reslicing.
            + neighbor_ops                                   # Refreshed slice data.
            + [self.batch_norm_op])                          # Group batch norm.
        self.mock_op_reg_manager.get_op_slices.assert_has_calls(
            [mock.call(lookup_op) for lookup_op in expected_lookups])

        # Nothing further to group or process.
        manager = self.mock_op_reg_manager
        manager.group_op_slices.assert_not_called()
        manager.process_ops.assert_not_called()
        manager.process_ops_last.assert_not_called()
    def testAssignGrouping_AllOutputsGrouped(self):
        """Batch norm defers grouping when an input (beta_op) has no group.

        Every neighbor except beta_op is grouped, so the handler must not
        group anything itself. Instead it asks the manager to process beta_op
        and to process batch norm last.
        """
        self.op_group_dict = {
            self.batch_norm_op_slice: self.batch_norm_op_group,
            self.conv_op_slice: self.conv_op_group,
            self.relu_op_slice: self.relu_op_group,
            self.gamma_op_slice: self.conv_op_group,
            self.mean_op_slice: self.relu_op_group,
            self.std_op_slice: self.relu_op_group,
        }

        grouping_op_handler.GroupingOpHandler().assign_grouping(
            self.batch_norm_op, self.mock_op_reg_manager)

        # Expected OpSlice lookups. There is no final batch norm lookup
        # because no grouping happens.
        input_ops = [self.conv_op, self.gamma_op, self.beta_op]
        output_ops = [self.relu_op, self.mean_op, self.std_op]
        neighbor_ops = input_ops + output_ops
        expected_lookups = (
            neighbor_ops                                     # Ops to process.
            + [self.batch_norm_op] + neighbor_ops            # Initial slice data.
            + input_ops + [self.batch_norm_op] + output_ops  # Reslicing.
            + neighbor_ops)                                  # Refreshed slice data.
        self.mock_op_reg_manager.get_op_slices.assert_has_calls(
            [mock.call(lookup_op) for lookup_op in expected_lookups])

        manager = self.mock_op_reg_manager
        # No grouping is performed.
        manager.group_op_slices.assert_not_called()
        # The ungrouped input is processed first, then batch norm is revisited.
        manager.process_ops.assert_called_once_with([self.beta_op])
        manager.process_ops_last.assert_called_once_with([self.batch_norm_op])
Example #6
0
  def testGetInputOutputOpSlices(self):
    """`_get_input_output_op_slices` returns per-op OpSlice lists in order."""
    input_pairs = [(self.conv_op, self.conv_op_slice),
                   (self.gamma_op, self.gamma_op_slice),
                   (self.beta_op, self.beta_op_slice)]
    output_pairs = [(self.mean_op, self.mean_op_slice),
                    (self.std_op, self.std_op_slice),
                    (self.relu_op, self.relu_op_slice)]

    input_ops = [op for op, _ in input_pairs]
    output_ops = [op for op, _ in output_pairs]
    # Each op is expected to map to a single-element list with its OpSlice.
    expected = ([[op_slice] for _, op_slice in input_pairs],
                [[op_slice] for _, op_slice in output_pairs])

    handler = grouping_op_handler.GroupingOpHandler()
    actual = handler._get_input_output_op_slices(
        input_ops, output_ops, self.mock_op_reg_manager)
    self.assertEqual(expected, actual)
    def testAssignGrouping_NoNeighborGroups(self):
        """With no groups anywhere, all neighbors are processed and batch norm last."""
        self.op_group_dict = {}

        grouping_op_handler.GroupingOpHandler().assign_grouping(
            self.batch_norm_op, self.mock_op_reg_manager)

        # Expected OpSlice lookups, phase by phase (only ReLU as output here).
        input_ops = [self.conv_op, self.gamma_op, self.beta_op]
        output_ops = [self.relu_op]
        neighbor_ops = input_ops + output_ops
        expected_lookups = (
            neighbor_ops                                     # Ops to process.
            + [self.batch_norm_op] + neighbor_ops            # Initial slice data.
            + input_ops + [self.batch_norm_op] + output_ops  # Reslicing.
            + neighbor_ops)                                  # Refreshed slice data.
        self.mock_op_reg_manager.get_op_slices.assert_has_calls(
            [mock.call(lookup_op) for lookup_op in expected_lookups])

        manager = self.mock_op_reg_manager
        # No grouping is performed.
        manager.group_op_slices.assert_not_called()
        # Conv2D, ReLU and the other neighbors are processed; batch norm last.
        manager.process_ops.assert_called_once_with(
            [self.relu_op, self.conv_op, self.gamma_op, self.beta_op])
        manager.process_ops_last.assert_called_once_with([self.batch_norm_op])
    def assign_grouping(self, op, op_reg_manager):
        """Assign grouping to the given concat op and update the manager.

        If the concat axis is not the last axis, grouping is delegated
        entirely to a `GroupingOpHandler`. Otherwise every input op and
        every same-size passthrough output op is aligned and grouped with the
        concat op. Ungrouped or size-mismatched neighbors are handed back to
        the manager for processing.

        Args:
          op: tf.Operation to assign grouping to.
          op_reg_manager: OpRegularizerManager to keep track of the grouping.
        """
        concat_axis = _get_concat_op_axis(op)
        # Need to figure out the rank to know if axis is last.
        # NOTE(review): len() of a TensorShape presumably fails for unknown
        # rank; confirm concat inputs always have a known rank here.
        rank = len(op.inputs[0].shape)  # Rank of the first input.

        # Non-last axis (either -1 or rank - 1 counts as last).
        if concat_axis != -1 and concat_axis != rank - 1:
            # Concat is actually grouping inputs!
            handler = grouping_op_handler.GroupingOpHandler()
            handler.assign_grouping(op, op_reg_manager)
            return

        # If concat is of the last dimension, this is a `standard` concat.
        # TODO(a1): Consider refactoring this duplicated logic.
        # Check if all input ops have groups, or tell the manager to process them.
        input_ops = op_handler_util.get_input_ops(op, op_reg_manager)
        input_ops_without_group = op_handler_util.get_ops_without_groups(
            input_ops, op_reg_manager)

        # Check if all output ops have groups, or tell the manager to process them.
        # Ungrouped outputs are collected before the non-passthrough filter
        # below, so ungrouped non-passthrough outputs are still processed.
        output_ops = op_handler_util.get_output_ops(op, op_reg_manager)
        output_ops_without_group = op_handler_util.get_ops_without_groups(
            output_ops, op_reg_manager)

        # Remove non-passthrough ops from outputs ops to group with.
        output_ops = op_handler_util.remove_non_passthrough_ops(
            output_ops, op_reg_manager)

        # Only group with output ops that have the same size.  Process the ops that
        # have mismatched size.
        # Inputs are all grouped with no same-size filter (presumably because
        # each input contributes only part of the concatenated output).
        input_ops_to_group = input_ops
        input_ops_to_process = input_ops_without_group
        output_ops_to_group, output_ops_to_process = (
            op_handler_util.separate_same_size_ops(op, output_ops))

        # Also process ungrouped ops.
        # NOTE(review): an op that is both size-mismatched and ungrouped may
        # appear twice in this list; confirm the manager tolerates duplicates.
        output_ops_to_process.extend(output_ops_without_group)

        # Populate OpSlice data for all relevant ops.
        concat_op_slices = op_reg_manager.get_op_slices(op)
        input_op_slices, output_op_slices = self._get_input_output_op_slices(
            input_ops_to_group, output_ops_to_group, op_reg_manager)

        # Align op slices sizes if needed.
        # Inputs go through the concat-specific reslicer. Outputs and the
        # concat op itself are resliced with the generic helper.
        aligned_op_slice_sizes = op_handler_util.get_aligned_op_slice_sizes(
            concat_op_slices, input_op_slices, output_op_slices)
        op_handler_util.reslice_concat_ops(input_ops_to_group,
                                           aligned_op_slice_sizes,
                                           op_reg_manager)
        op_handler_util.reslice_ops(output_ops_to_group + [op],
                                    aligned_op_slice_sizes, op_reg_manager)

        # Repopulate OpSlice data, as ops may have been resliced.
        input_op_slices, output_op_slices = self._get_input_output_op_slices(
            input_ops_to_group, output_ops_to_group, op_reg_manager)

        # Group aligned OpSlice.
        op_handler_util.group_aligned_input_output_slices(
            op, input_ops_to_process, output_ops_to_process, input_op_slices,
            output_op_slices, aligned_op_slice_sizes, op_reg_manager)
Example #9
0
  def testAssignGrouping_NonPassthroughOutputsSkipped(self):
    """Batch norm does not group with outputs marked non-passthrough.

    ReLU is treated as non-passthrough for this test. Batch norm should group
    with mean/std and with its inputs, but never with ReLU.
    """
    relu_op = self.relu_op

    def is_passthrough(op):
      # Every op except ReLU is reported as passthrough.
      return not (op == relu_op)

    self.mock_op_reg_manager.is_passthrough.side_effect = is_passthrough

    # All neighbor ops have groups.
    self.op_group_dict = {
        self.batch_norm_op_slice: self.batch_norm_op_group,
        self.conv_op_slice: self.conv_op_group,
        self.relu_op_slice: self.relu_op_group,
        self.gamma_op_slice: self.conv_op_group,
        self.beta_op_slice: self.conv_op_group,
        self.mean_op_slice: self.relu_op_group,
        self.std_op_slice: self.relu_op_group,
    }

    grouping_op_handler.GroupingOpHandler().assign_grouping(
        self.batch_norm_op, self.mock_op_reg_manager)

    # Expected OpSlice lookups. After the initial scan, ReLU drops out.
    input_ops = [self.conv_op, self.gamma_op, self.beta_op]
    passthrough_outputs = [self.mean_op, self.std_op]
    grouped_neighbors = input_ops + passthrough_outputs
    expected_lookups = (
        input_ops + [self.relu_op] + passthrough_outputs  # Ops to process.
        + [self.batch_norm_op] + grouped_neighbors        # Initial slice data.
        + input_ops + [self.batch_norm_op]                # Reslicing.
        + passthrough_outputs
        + grouped_neighbors                               # Refreshed slice data.
        + [self.batch_norm_op])                           # Group batch norm.
    self.mock_op_reg_manager.get_op_slices.assert_has_calls(
        [mock.call(lookup_op) for lookup_op in expected_lookups])

    # Batch norm is grouped with mean/std, then with its inputs; ReLU is
    # never part of a grouping.
    output_grouping = mock.call(
        [self.batch_norm_op_slice, self.mean_op_slice, self.std_op_slice])
    input_grouping = mock.call(
        [self.batch_norm_op_slice, self.conv_op_slice,
         self.gamma_op_slice, self.beta_op_slice])
    self.mock_op_reg_manager.group_op_slices.assert_has_calls(
        [output_grouping, input_grouping])

    # Nothing further to process.
    manager = self.mock_op_reg_manager
    manager.process_ops.assert_not_called()
    manager.process_ops_last.assert_not_called()