def __init__(self, num_args, function_info):
    """Builds a placeholder operation of the given arity around a FunctionInfo.

    Args:
      num_args: Number of arguments the operation takes.
      function_info: FunctionInfo object stored on the instance as-is.
    """
    super(DummyOperation, self).__init__(
        num_args,
        weight=1,
        metadata=operation_base.OperationMetadata(docstring='dummy operation'))
    self.function_info = function_info
示例#2
0
 def __init__(self):
   """Configures a slicing operation from the subclass hook methods.

   Arity, minimum rank of the first argument and the docstring come from
   `_num_args`, `_min_rank` and `_get_docstring`; the apply filter comes
   from `_get_apply_filter`.
   """
   arity = self._num_args()
   rank_floor = self._min_rank()
   super(SlicingBaseOperation, self).__init__(
       num_args=arity,
       weight=SLICING_WEIGHT,
       metadata=operation_base.OperationMetadata(
           docstring=self._get_docstring()))
   # First argument: minimum-rank tensor filter. Every remaining argument
   # must pass the int-or-int-tensor filter.
   value_filters = [filtering.get_tensor_min_rank_filter(rank_floor)]
   value_filters.extend([filtering.INT_OR_INT_TENSOR_FILTER] * (arity - 1))
   self.add_value_filters(value_filters)
   self.set_apply_filter(self._get_apply_filter())
示例#3
0
  def __init__(self):
    """Initializes the one-argument singleton-tuple creation operation.

    The single argument must either be a primitive or pass
    `filtering.TENSOR_LIKE_SEQUENCE_FILTER`.
    """
    super(SingletonTupleCreationOperation, self).__init__(
        num_args=1,
        weight=SINGLETON_TUPLE_CREATION_WEIGHT,
        metadata=operation_base.OperationMetadata(
            docstring=SINGLETON_TUPLE_CREATION_DOCSTRING))

    def _is_primitive_or_sequence(value):
      """Accepts primitives and (possibly nested) tensor-like sequences."""
      return value.is_primitive or filtering.TENSOR_LIKE_SEQUENCE_FILTER(value)

    self.add_value_filters([_is_primitive_or_sequence])
示例#4
0
  def __init__(self):
    """Initializes the two-argument axis-1 indexing operation.

    The first argument must pass the rank >= 2 tensor filter and the second
    the int-or-int-tensor filter. The apply filter rejects indices outside
    the range of axis 1; Python-style negative indices are accepted.
    """
    super(IndexingAxis1Operation, self).__init__(
        num_args=2,
        weight=INDEXING_WEIGHT,
        metadata=operation_base.OperationMetadata(
            docstring=INDEXING_AXIS_1_DOCSTRING))

    rank_filter = filtering.get_tensor_min_rank_filter(2)
    self.add_value_filters([rank_filter, filtering.INT_OR_INT_TENSOR_FILTER])

    def _index_in_bounds(arg_values):
      """Returns whether the index lies within axis 1 of the tensor."""
      tensor_value, index_value = arg_values
      size = tensor_value.shape[1]
      position = int(index_value.value)
      return -size <= position < size

    self.set_apply_filter(_index_in_bounds)
示例#5
0
 def __init__(self):
     """Initializes a two-argument test operation with weight 5.

     Two groups of divisibility value filters are registered (by 2 and 3,
     then by 4 and 5). The apply filter accepts only argument pairs whose
     first value is smaller than the second.
     """
     super(StrangeAdditionOperation, self).__init__(
         2,
         weight=5,
         metadata=operation_base.OperationMetadata(docstring='test docstring'))

     def _divisible_by(divisor):
         # Each call creates a fresh closure, so every lambda keeps its own divisor.
         return lambda arg_value: arg_value.value % divisor == 0

     self.add_value_filters([_divisible_by(2), _divisible_by(3)])
     self.add_value_filters([_divisible_by(4), _divisible_by(5)])

     def _first_is_smaller(arg_values):
         return arg_values[0].value < arg_values[1].value

     self.set_apply_filter(_first_is_smaller)
    def __init__(self, function_info):
        """Creates a FunctionOperation.

        Args:
          function_info: A tf_functions.FunctionInfo. Its name is parsed into
            the function name, argument names and constant kwargs, and its
            weight becomes the operation's weight.
        """
        function_name, arg_names, constant_kwargs = (
            tf_functions.parse_function_info_name(function_info))
        self._function_obj = tf_coder_utils.get_tf_function(function_name)

        base_doc = self._function_obj.__doc__
        if not base_doc:
            print('Warning: could not get docstring for function {}'.format(
                function_name))
            base_doc = ''

        # Extend the docstring with newline-separated extra words:
        #  * function_info.name verbatim, and with underscores turned into
        #    spaces (so 'reduce_max' also yields 'reduce' and 'max');
        #  * the function name without its 'tf.' prefix, four times, plus its
        #    space-separated form twice when it contains underscores. This
        #    upweights the function name more than the argument names.
        bare_name = re.sub(r'^tf\.', '', function_name)
        doc_parts = [base_doc,
                     function_info.name,
                     function_info.name.replace('_', ' ')]
        doc_parts.extend([bare_name] * 4)
        if '_' in bare_name:
            doc_parts.extend([bare_name.replace('_', ' ')] * 2)
        metadata = operation_base.OperationMetadata(
            docstring='\n'.join(doc_parts))

        super(FunctionOperation, self).__init__(
            num_args=len(arg_names),
            weight=function_info.weight,
            metadata=metadata)

        self.function_info = function_info
        self.function_name = function_name
        self.arg_names = arg_names
        self.constant_kwargs = constant_kwargs

        # Record, per argument name, whether the underlying function's
        # signature declares a default value for it.
        signature_params = funcsigs.signature(self._function_obj).parameters
        selected = [signature_params[name] for name in arg_names]
        self._has_default = {
            name: param.default is not param.empty
            for name, param in zip(arg_names, selected)}

        operation_filtering.add_filters_to_function_operation(self)
示例#7
0
    def __init__(self):
        """Initializes the three-argument triple creation operation.

        Three kinds of value-filter groups are registered: one group per type
        in `tf_coder_utils.PRIMITIVE_TYPES` (all three arguments of that
        type), three "small" tensors, and three tensor-like sequences. The
        apply filter then checks dtype/shape compatibility.
        """
        super(TripleCreationOperation, self).__init__(
            num_args=3,
            weight=TRIPLE_CREATION_WEIGHT,
            metadata=operation_base.OperationMetadata(
                docstring=TRIPLE_CREATION_DOCSTRING))

        for kind in tf_coder_utils.PRIMITIVE_TYPES:
            type_filter = filtering.get_type_filter(kind)
            self.add_value_filters([type_filter, type_filter, type_filter])

        def _is_small_tensor(value):
            """Keeps tensors whose element count times 3 fits the limit."""
            return (value.is_tensor and
                    value.num_elements() * 3 <= limits.MAX_TENSOR_ELEMENTS)

        self.add_value_filters([_is_small_tensor] * 3)
        self.add_value_filters([filtering.TENSOR_LIKE_SEQUENCE_FILTER] * 3)

        def _compatible(arg_values):
            """Checks that the three values agree in dtype and shape."""
            a, b, c = arg_values
            if a.is_tensor:
                # The value filters imply b and c are tensors too.
                if not a.dtype == b.dtype == c.dtype:
                    return False
                shapes = (a.shape, b.shape, c.shape)
                if not len(shapes[0]) == len(shapes[1]) == len(shapes[2]):
                    return False
                # Tolerate a disagreement in at most one dimension.
                mismatched_dims = sum(
                    not dim_a == dim_b == dim_c
                    for dim_a, dim_b, dim_c in zip(*shapes))
                return mismatched_dims <= 1
            if a.is_sequence:
                # The value filters imply b and c are sequences too.
                return (a.sequence_dtype == b.sequence_dtype ==
                        c.sequence_dtype and a.sequence_shape ==
                        b.sequence_shape == c.sequence_shape)
            # Otherwise all three are primitives of the same type.
            return True

        self.set_apply_filter(_compatible)
示例#8
0
  def __init__(self):
    """Initializes the two-argument indexing operation.

    The first argument must be a sequence or a tensor with at least one
    dimension; the second must pass the int-or-int-tensor filter. The apply
    filter rejects indices outside the first dimension's range; Python-style
    negative indices are accepted.
    """
    super(IndexingOperation, self).__init__(
        num_args=2,
        weight=INDEXING_WEIGHT,
        metadata=operation_base.OperationMetadata(docstring=INDEXING_DOCSTRING))

    def _is_indexable(arg_value):
      # A tensor only qualifies if its shape is non-empty (rank >= 1).
      return (arg_value.is_sequence or
              (arg_value.is_tensor and len(arg_value.shape)))

    self.add_value_filters([_is_indexable, filtering.INT_OR_INT_TENSOR_FILTER])

    def _index_in_bounds(arg_values):
      """Returns whether the index is valid for the sequence or tensor."""
      container, index = arg_values
      size = (len(container.value) if container.is_sequence
              else container.shape[0])
      return -size <= int(index.value) < size

    self.set_apply_filter(_index_in_bounds)
示例#9
0
    def __init__(self):
        """Initializes the two-argument pair creation operation.

        Three kinds of value-filter groups are registered: one group per type
        in `tf_coder_utils.PRIMITIVE_TYPES` (both arguments of that type), two
        "small" tensors, and two tensor-like sequences. The apply filter then
        checks dtype/shape compatibility.
        """
        super(PairCreationOperation, self).__init__(
            num_args=2,
            weight=PAIR_CREATION_WEIGHT,
            metadata=operation_base.OperationMetadata(
                docstring=PAIR_CREATION_DOCSTRING))

        for kind in tf_coder_utils.PRIMITIVE_TYPES:
            type_filter = filtering.get_type_filter(kind)
            self.add_value_filters([type_filter, type_filter])

        def _is_small_tensor(value):
            """Keeps tensors whose element count times 2 fits the limit."""
            return (value.is_tensor and
                    value.num_elements() * 2 <= limits.MAX_TENSOR_ELEMENTS)

        self.add_value_filters([_is_small_tensor] * 2)
        self.add_value_filters([filtering.TENSOR_LIKE_SEQUENCE_FILTER] * 2)

        def _compatible(arg_values):
            """Checks that the two values agree in dtype and shape."""
            left, right = arg_values
            if left.is_tensor:
                # The value filters imply right is a tensor too.
                if left.dtype != right.dtype:
                    return False
                left_shape, right_shape = left.shape, right.shape
                if left_shape == right_shape:
                    return True
                if len(left_shape) != len(right_shape):
                    return False
                # Tolerate a disagreement in at most one dimension.
                mismatched_dims = sum(
                    dim_left != dim_right
                    for dim_left, dim_right in zip(left_shape, right_shape))
                return mismatched_dims <= 1
            if left.is_sequence:
                # The value filters imply right is a sequence too.
                return (left.sequence_dtype == right.sequence_dtype
                        and left.sequence_shape == right.sequence_shape)
            # Otherwise both are primitives of the same type.
            return True

        self.set_apply_filter(_compatible)