def testAssignGrouping_GroupWithOutputOnly(self):
        """Conv2D source handler grouping when only conv2's slice has a group.

        Only `conv2_op_slice` is mapped to a group. The handler is expected to
        look up conv2's slices, slice nothing, and enqueue exactly
        `[self.relu1_op]` for further processing.
        """
        # Every op of interest owns exactly one slice.
        slice_pairs = (
            (self.conv1_op, self.conv1_op_slice),
            (self.relu1_op, self.relu1_op_slice),
            (self.conv2_op, self.conv2_op_slice),
            (self.relu2_op, self.relu2_op_slice),
        )
        self.op_slice_dict = {op: [op_slice] for op, op_slice in slice_pairs}

        # Only conv2's slice is assigned a group.
        self.op_group_dict = {self.conv2_op_slice: self.conv2_op_group}

        conv_handler = conv2d_source_op_handler.Conv2DSourceOpHandler(
            _DEFAULT_THRESHOLD)
        manager = self.mock_op_reg_manager
        conv_handler.assign_grouping(self.conv2_op, manager)

        # The handler must query the slices of the op being grouped...
        manager.get_op_slices.assert_any_call(self.conv2_op)
        # ...must not request any slicing...
        manager.slice_op.assert_not_called()
        # ...and must enqueue the input op exactly once.
        manager.process_ops.assert_called_once_with([self.relu1_op])
 def is_passthrough(op):
     """Reports the Conv2D handler's `is_passthrough` for conv ops, else False.

     NOTE(review): reads `self` from an enclosing scope that is not visible
     here; presumably this is a closure defined inside a test method.
     """
     conv_ops = [self.conv1_op, self.conv2_op]
     if op not in conv_ops:
         return False
     conv_handler = conv2d_source_op_handler.Conv2DSourceOpHandler(
         _DEFAULT_THRESHOLD)
     return conv_handler.is_passthrough
  def __init__(
      self,
      ops,
      threshold,
      l1_fraction=0,
      regularizer_decorator: Type[generic_regularizers.OpRegularizer] = None,
      decorator_parameters=None,
      force_group=None,
      regularizer_blacklist=None):
    """Creates a GroupLassoActivationRegularizer object.

    Builds group-lasso source handlers for Conv2D, Conv2DBackpropInput and
    MatMul, optionally wraps each one in an OpHandlerDecorator, and layers
    them over the default group-lasso handler dict. The resulting manager is
    costed with an activation-count function.

    Args:
      ops: A list of tf.Operation. An OpRegularizer is created for every op in
        `ops` and, recursively, for every op they depend on through data
        dependencies. Typically `ops` holds a single tf.Operation: the output
        of the network.
      threshold: A float scalar used as the 'threshold' of every regularizer
        instance created by this class.
      l1_fraction: Relative weight of L1 in the L1 + L2 regularization.
      regularizer_decorator: An OpRegularizer decorator class to wrap each
        source handler with. A falsy value means no decoration.
      decorator_parameters: A dictionary of parameters forwarded to the
        decorator factory. Use only with decorators that require parameters;
        otherwise pass None.
      force_group: List of regexes for ops that should be force-grouped. Each
        regex forms a separate group; use the '|' operator to put several
        patterns in one regex. See op_regularizer_manager for details.
      regularizer_blacklist: List of regexes for ops that should not be
        regularized. See op_regularizer_manager for details.
    """
    # Construction order (Conv2D, transpose, MatMul) matches insertion order.
    source_handlers = {
        'Conv2D':
            conv2d_source_op_handler.Conv2DSourceOpHandler(
                threshold, l1_fraction),
        'Conv2DBackpropInput':
            conv2d_transpose_source_op_handler.Conv2DTransposeSourceOpHandler(
                threshold, l1_fraction),
        'MatMul':
            matmul_source_op_handler.MatMulSourceOpHandler(
                threshold, l1_fraction),
    }
    if regularizer_decorator:
      # Every source handler is wrapped with the same decorator settings.
      source_handlers = {
          op_type: op_handler_decorator.OpHandlerDecorator(
              source_handler, regularizer_decorator, decorator_parameters)
          for op_type, source_handler in source_handlers.items()
      }

    # The source handlers override the defaults for their op types.
    op_handler_dict = op_handlers.get_group_lasso_op_handler_dict()
    op_handler_dict.update(source_handlers)

    self._manager = orm.OpRegularizerManager(
        ops,
        op_handler_dict,
        force_group=force_group,
        regularizer_blacklist=regularizer_blacklist)
    self._calculator = cost_calculator.CostCalculator(
        self._manager, resource_function.activation_count_function)
    def testCreateRegularizer(self):
        """Regularizer vector length matches the last dim of conv2's inputs[1].

        Detailed regularizer behaviour is covered in
        group_lasso_regularizer_test.py; only the output shape is checked here.
        """
        conv_handler = conv2d_source_op_handler.Conv2DSourceOpHandler(
            _DEFAULT_THRESHOLD)
        reg = conv_handler.create_regularizer(self.conv2_op_slice)

        kernel_dims = self.conv2_op.inputs[1].shape.as_list()
        vector_dims = reg.regularization_vector.shape.as_list()
        self.assertEqual(kernel_dims[-1], vector_dims[0])
Example #5
0
    def testOpHandlerDecorator(self):
        """A DummyDecorator-wrapped Conv2D handler yields the decorated vectors.

        For a 3-channel conv slice, the regularization vector is expected to
        be 0.5 everywhere and the alive vector all ones.
        """
        image = tf.constant(0.0, shape=[1, 17, 19, 3])
        kernel = tf.ones([5, 5, 3, 3])
        conv_out = tf.nn.conv2d(
            image, kernel, strides=[1, 1, 1, 1], padding='SAME')

        base_handler = conv2d_source_op_handler.Conv2DSourceOpHandler(1e-3, 0)
        wrapped_handler = op_handler_decorator.OpHandlerDecorator(
            base_handler, DummyDecorator)
        channel_slice = orm.OpSlice(conv_out.op, orm.Slice(0, 3))
        reg = wrapped_handler.create_regularizer(channel_slice)

        self.assertAllClose(0.5 * np.ones(3), reg.regularization_vector)
        self.assertAllClose(np.ones(3), reg.alive_vector)
    def __init__(self,
                 output_boundary,
                 threshold,
                 l1_fraction=0,
                 regularizer_decorator=None,
                 decorator_parameters=None,
                 input_boundary=None,
                 force_group=None,
                 regularizer_blacklist=None):
        """Creates a GroupLassoFlopsRegularizer object.

        Builds group-lasso source handlers for Conv2D, Conv2DBackpropInput and
        MatMul, optionally wraps each one in an OpHandlerDecorator, and layers
        them over the default group-lasso handler dict. The resulting manager
        is costed with a FLOP-count function.

        Args:
          output_boundary: An OpRegularizer is created for all these ops and,
            recursively, for every op they depend on through data dependencies
            that do not involve ops from `input_boundary`.
          threshold: A float scalar used as the 'threshold' of every
            regularizer instance created by this class.
          l1_fraction: Relative weight of L1 in the L1 + L2 regularization.
          regularizer_decorator: An OpRegularizer decorator class to wrap each
            source handler with. A falsy value means no decoration.
          decorator_parameters: A dictionary of parameters forwarded to the
            decorator factory. Use only with decorators that require
            parameters; otherwise pass None.
          input_boundary: A list of ops forming the input boundary of the
            regularized subgraph (the boundary itself is not regularized).
          force_group: List of regexes for ops that should be force-grouped.
            Each regex forms a separate group; use the '|' operator to put
            several patterns in one regex. See op_regularizer_manager for
            details.
          regularizer_blacklist: List of regexes for ops that should not be
            regularized. See op_regularizer_manager for details.
        """
        # Construction order (Conv2D, transpose, MatMul) matches insertion order.
        source_handlers = {
            'Conv2D':
                conv2d_source_op_handler.Conv2DSourceOpHandler(
                    threshold, l1_fraction),
            'Conv2DBackpropInput':
                conv2d_transpose_source_op_handler
                .Conv2DTransposeSourceOpHandler(threshold, l1_fraction),
            'MatMul':
                matmul_source_op_handler.MatMulSourceOpHandler(
                    threshold, l1_fraction),
        }
        if regularizer_decorator:
            # Every source handler is wrapped with the same decorator settings.
            source_handlers = {
                op_type: op_handler_decorator.OpHandlerDecorator(
                    source_handler, regularizer_decorator,
                    decorator_parameters)
                for op_type, source_handler in source_handlers.items()
            }

        # The source handlers override the defaults for their op types.
        op_handler_dict = op_handlers.get_group_lasso_op_handler_dict()
        op_handler_dict.update(source_handlers)

        self._manager = orm.OpRegularizerManager(
            output_boundary,
            op_handler_dict,
            input_boundary=input_boundary,
            force_group=force_group,
            regularizer_blacklist=regularizer_blacklist)
        self._calculator = cost_calculator.CostCalculator(
            self._manager, resource_function.flop_function)
Example #7
0
    def __init__(self,
                 output_boundary,
                 threshold,
                 hardware,
                 batch_size=1,
                 l1_fraction=0,
                 regularizer_decorator=None,
                 decorator_parameters=None,
                 input_boundary=None,
                 force_group=None,
                 regularizer_blacklist=None,
                 convert_to_variable=True):
        """Creates a group-lasso latency regularizer object.

        Unlike the FLOP-based variant, the cost of the regularized subgraph is
        computed with `resource_function.latency_function_factory(hardware,
        batch_size)`, i.e. an estimated latency on the given hardware.

        Args:
          output_boundary: An OpRegularizer will be created for all these
            operations, and recursively for all ops they depend on via data
            dependency that does not involve ops from input_boundary.
          threshold: A float scalar, will be used as a 'threshold' for all
            regularizer instances created by this class.
          hardware: String name of hardware platform to target. Must be a key
            from resource_function.PEAK_COMPUTE. Also stored as
            `self._hardware`.
          batch_size: Integer batch size to calculate cost/loss for.
          l1_fraction: Relative weight of L1 in L1 + L2 regularization.
          regularizer_decorator: A class of OpRegularizer decorator to use. A
            falsy value means the source handlers are not decorated.
          decorator_parameters: A dictionary of parameters to pass to the
            decorator factory. To be used only with decorators that require
            parameters, otherwise use None.
          input_boundary: A list of ops that represent the input boundary of
            the subgraph being regularized (input boundary is not
            regularized).
          force_group: List of regex for ops that should be force-grouped.
            Each regex corresponds to a separate group. Use '|' operator to
            specify multiple patterns in a single regex. See
            op_regularizer_manager for more detail.
          regularizer_blacklist: List of regex for ops that should not be
            regularized. See op_regularizer_manager for more detail.
          convert_to_variable: If `True` convert to variable in the
            `GroupLassoBaseOpHandler`. If your graph creates variables outside
            of `tf.get_variable`, set to `False`.
        """
        conv2d_handler = conv2d_source_op_handler.Conv2DSourceOpHandler(
            threshold, l1_fraction, convert_to_variable)
        conv2d_transpose_handler = (
            conv2d_transpose_source_op_handler.Conv2DTransposeSourceOpHandler(
                threshold, l1_fraction, convert_to_variable))
        matmul_handler = matmul_source_op_handler.MatMulSourceOpHandler(
            threshold, l1_fraction, convert_to_variable)
        if regularizer_decorator:
            # Wrap every source handler with the same decorator settings.
            conv2d_handler = op_handler_decorator.OpHandlerDecorator(
                conv2d_handler, regularizer_decorator, decorator_parameters)
            conv2d_transpose_handler = op_handler_decorator.OpHandlerDecorator(
                conv2d_transpose_handler, regularizer_decorator,
                decorator_parameters)
            matmul_handler = op_handler_decorator.OpHandlerDecorator(
                matmul_handler, regularizer_decorator, decorator_parameters)

        # Source handlers override the default group-lasso handlers for these
        # op types.
        op_handler_dict = op_handlers.get_group_lasso_op_handler_dict()
        op_handler_dict.update({
            'Conv2D': conv2d_handler,
            'Conv2DBackpropInput': conv2d_transpose_handler,
            'MatMul': matmul_handler,
        })

        self._manager = orm.OpRegularizerManager(
            output_boundary,
            op_handler_dict,
            input_boundary=input_boundary,
            force_group=force_group,
            regularizer_blacklist=regularizer_blacklist)
        # Cost is estimated latency for `hardware` at `batch_size`.
        self._calculator = cost_calculator.CostCalculator(
            self._manager,
            resource_function.latency_function_factory(hardware, batch_size))
        self._hardware = hardware
Example #8
0
    def __init__(self,
                 ops,
                 threshold,
                 l1_fraction=0,
                 regularizer_decorator: Type[
                     generic_regularizers.OpRegularizer] = None,
                 decorator_parameters=None,
                 force_group=None,
                 regularizer_blacklist=None,
                 convert_to_variable=True):
        """Creates a GroupLassoFlopsRegularizer object.

        Uses an explicit op-type -> handler map instead of the default
        group-lasso handler dict; any op type not listed falls back to
        `grouping_op_handler.GroupingOpHandler` via a `defaultdict`. The cost
        of the regularized graph is computed with a FLOP-count function.

        Args:
          ops: A list of tf.Operation. An OpRegularizer will be created for all
            the ops in `ops`, and recursively for all ops they depend on via
            data dependency. Typically `ops` would contain a single
            tf.Operation, which is the output of the network.
          threshold: A float scalar, will be used as a 'threshold' for all
            regularizer instances created by this class.
          l1_fraction: Relative weight of L1 in L1 + L2 regularization.
          regularizer_decorator: A class of OpRegularizer decorator to use. A
            falsy value means the source handlers are not decorated.
          decorator_parameters: A dictionary of parameters to pass to the
            decorator factory. To be used only with decorators that require
            parameters, otherwise use None.
          force_group: List of regex for ops that should be force-grouped.
            Each regex corresponds to a separate group. Use '|' operator to
            specify multiple patterns in a single regex. See
            op_regularizer_manager for more detail.
          regularizer_blacklist: List of regex for ops that should not be
            regularized. See op_regularizer_manager for more detail.
          convert_to_variable: If `True` convert to variable in the
            `GroupLassoBaseOpHandler`. If your graph creates variables outside
            of `tf.get_variable`, set to `False`.
        """
        conv2d_handler = conv2d_source_op_handler.Conv2DSourceOpHandler(
            threshold, l1_fraction, convert_to_variable)
        conv2d_transpose_handler = (
            conv2d_transpose_source_op_handler.Conv2DTransposeSourceOpHandler(
                threshold, l1_fraction, convert_to_variable))
        matmul_handler = matmul_source_op_handler.MatMulSourceOpHandler(
            threshold, l1_fraction, convert_to_variable)
        if regularizer_decorator:
            # Wrap every source handler with the same decorator settings.
            conv2d_handler = op_handler_decorator.OpHandlerDecorator(
                conv2d_handler, regularizer_decorator, decorator_parameters)
            conv2d_transpose_handler = op_handler_decorator.OpHandlerDecorator(
                conv2d_transpose_handler, regularizer_decorator,
                decorator_parameters)
            matmul_handler = op_handler_decorator.OpHandlerDecorator(
                matmul_handler, regularizer_decorator, decorator_parameters)

        # Op types missing from the map below get a fresh GroupingOpHandler.
        op_handler_dict = collections.defaultdict(
            grouping_op_handler.GroupingOpHandler)
        op_handler_dict.update({
            'Conv2D':
            conv2d_handler,
            'Conv2DBackpropInput':
            conv2d_transpose_handler,
            'ConcatV2':
            concat_op_handler.ConcatOpHandler(),
            'DepthToSpace':
            output_non_passthrough_op_handler.OutputNonPassthroughOpHandler(),
            'DepthwiseConv2dNative':
            depthwise_convolution_op_handler.DepthwiseConvolutionOpHandler(),
            'MatMul':
            matmul_handler,
            'RandomUniform':
            leaf_op_handler.LeafOpHandler(),
            'Reshape':
            leaf_op_handler.LeafOpHandler(),
            'Shape':
            leaf_op_handler.LeafOpHandler(),
            'TensorArrayGatherV3':
            leaf_op_handler.LeafOpHandler(),
            'Transpose':
            output_non_passthrough_op_handler.OutputNonPassthroughOpHandler(),
            'StridedSlice':
            leaf_op_handler.LeafOpHandler(),
        })

        self._manager = orm.OpRegularizerManager(
            ops,
            op_handler_dict,
            force_group=force_group,
            regularizer_blacklist=regularizer_blacklist)
        self._calculator = cost_calculator.CostCalculator(
            self._manager, resource_function.flop_function)