def __init__(self, num_args, function_info):
  """Creates a DummyOperation with the given arity and FunctionInfo.

  Args:
    num_args: The number of arguments this operation takes.
    function_info: The FunctionInfo to associate with this operation.
  """
  super(DummyOperation, self).__init__(
      num_args,
      weight=1,
      metadata=operation_base.OperationMetadata(docstring='dummy operation'))
  self.function_info = function_info
def __init__(self):
  """Configures arity, weight, and filters via subclass-provided hooks."""
  # Query the subclass hooks in a fixed order before calling the base ctor.
  arity = self._num_args()
  rank_floor = self._min_rank()
  doc = self._get_docstring()
  super(SlicingBaseOperation, self).__init__(
      num_args=arity,
      weight=SLICING_WEIGHT,
      metadata=operation_base.OperationMetadata(docstring=doc))
  # The first argument must be a tensor of sufficient rank; every remaining
  # argument must be an int or int tensor (slice endpoints).
  value_filters = [filtering.get_tensor_min_rank_filter(rank_floor)]
  value_filters.extend([filtering.INT_OR_INT_TENSOR_FILTER] * (arity - 1))
  self.add_value_filters(value_filters)
  self.set_apply_filter(self._get_apply_filter())
def __init__(self):
  """Sets up the singleton-tuple operation and its value filter."""
  super(SingletonTupleCreationOperation, self).__init__(
      num_args=1,
      weight=SINGLETON_TUPLE_CREATION_WEIGHT,
      metadata=operation_base.OperationMetadata(
          docstring=SINGLETON_TUPLE_CREATION_DOCSTRING))

  def _primitives_or_sequences_filter(arg_value):
    """Accepts primitives and tensor-like (possibly nested) sequences."""
    if arg_value.is_primitive:
      return True
    return filtering.TENSOR_LIKE_SEQUENCE_FILTER(arg_value)

  self.add_value_filters([_primitives_or_sequences_filter])
def __init__(self):
  """Sets up axis-1 indexing: a rank>=2 tensor indexed by an int."""
  super(IndexingAxis1Operation, self).__init__(
      num_args=2,
      weight=INDEXING_WEIGHT,
      metadata=operation_base.OperationMetadata(
          docstring=INDEXING_AXIS_1_DOCSTRING))
  # Rank >= 2 guarantees shape[1] exists for the apply filter below.
  self.add_value_filters([filtering.get_tensor_min_rank_filter(2),
                          filtering.INT_OR_INT_TENSOR_FILTER])

  def _index_in_bounds(arg_values):
    """Checks that the index is valid along axis 1 of the tensor."""
    tensor, index = arg_values
    axis_1_length = tensor.shape[1]
    # Negative indices (Python-style) are allowed down to -length.
    return -axis_1_length <= int(index.value) < axis_1_length

  self.set_apply_filter(_index_in_bounds)
def __init__(self):
  """Sets up divisibility value filters and an ordering apply filter."""
  super(StrangeAdditionOperation, self).__init__(
      2,
      weight=5,
      metadata=operation_base.OperationMetadata(docstring='test docstring'))

  def _divisible_by(divisor):
    """Returns a filter accepting values divisible by `divisor`."""
    return lambda arg_value: arg_value.value % divisor == 0

  # Two alternative filter groups: (even, div-by-3) or (div-by-4, div-by-5).
  self.add_value_filters([_divisible_by(2), _divisible_by(3)])
  self.add_value_filters([_divisible_by(4), _divisible_by(5)])
  # The first argument must be strictly smaller than the second.
  self.set_apply_filter(
      lambda arg_values: arg_values[0].value < arg_values[1].value)
def __init__(self, function_info):
  """Creates a FunctionOperation.

  Args:
    function_info: A tf_functions.FunctionInfo.
  """
  # Parse "name(arg1, arg2, kw=const)" into its components.
  function_name, arg_names, constant_kwargs = (
      tf_functions.parse_function_info_name(function_info))
  self._function_obj = tf_coder_utils.get_tf_function(function_name)
  # The function's own docstring seeds this operation's metadata docstring;
  # fall back to '' (with a warning) if the TF function has none.
  docstring = self._function_obj.__doc__
  if not docstring:
    print('Warning: could not get docstring for function {}'.format(
        function_name))
    docstring = ''
  # Make sure the function and argument names appear in the docstring. (Args
  # should already appear in the docstring "Args" section though.)
  docstring += '\n' + function_info.name
  # If 'reduce_max' is the function name, make sure 'reduce' and 'max' also
  # appear as separate words. Ditto for argument names as well.
  docstring += '\n' + function_info.name.replace('_', ' ')
  # Upweight the function name (moreso than the argument names).
  # NOTE(review): repetition counts (x4 and x2 below) appear to act as term
  # weights for whatever consumes the metadata docstring — confirm before
  # changing them.
  function_name_without_tf = re.sub(r'^tf\.', '', function_name)
  docstring += ('\n' + function_name_without_tf) * 4
  if '_' in function_name_without_tf:
    docstring += ('\n' + function_name_without_tf.replace('_', ' ')) * 2
  metadata = operation_base.OperationMetadata(docstring=docstring)
  super(FunctionOperation, self).__init__(num_args=len(arg_names),
                                          weight=function_info.weight,
                                          metadata=metadata)
  self.function_info = function_info
  self.function_name = function_name
  self.arg_names = arg_names
  self.constant_kwargs = constant_kwargs
  # Maps each argument name to whether it has a default in the TF function's
  # signature (param.empty is the sentinel for "no default").
  self._has_default = {}
  parameters = funcsigs.signature(self._function_obj).parameters
  for arg_name in arg_names:
    param = parameters[arg_name]
    has_default = param.default is not param.empty
    self._has_default[arg_name] = has_default
  operation_filtering.add_filters_to_function_operation(self)
def __init__(self):
  """Sets up triple creation with type/size value filters and an apply filter."""
  super(TripleCreationOperation, self).__init__(
      num_args=3,
      weight=TRIPLE_CREATION_WEIGHT,
      metadata=operation_base.OperationMetadata(
          docstring=TRIPLE_CREATION_DOCSTRING))
  # Alternative filter groups: three primitives of the same type, or three
  # small tensors, or three tensor-like sequences.
  for primitive_type in tf_coder_utils.PRIMITIVE_TYPES:
    self.add_value_filters([filtering.get_type_filter(primitive_type)] * 3)

  def _small_tensor_filter(arg_value):
    """Keeps tensors small enough that three of them fit within the limit."""
    if not arg_value.is_tensor:
      return False
    return arg_value.num_elements() * 3 <= limits.MAX_TENSOR_ELEMENTS

  self.add_value_filters([_small_tensor_filter] * 3)
  self.add_value_filters([filtering.TENSOR_LIKE_SEQUENCE_FILTER] * 3)

  def _compatible_triple(arg_values):
    """Ensures dtype and shape compatibility across the three values."""
    first, second, third = arg_values
    if first.is_tensor:
      # The value filters guarantee second and third are tensors too.
      if not first.dtype == second.dtype == third.dtype:
        return False
      shape_1, shape_2, shape_3 = first.shape, second.shape, third.shape
      if not len(shape_1) == len(shape_2) == len(shape_3):
        return False
      # At most one dimension may differ across the three shapes.
      mismatches = sum(not dim_1 == dim_2 == dim_3
                       for dim_1, dim_2, dim_3
                       in zip(shape_1, shape_2, shape_3))
      return mismatches <= 1
    if first.is_sequence:
      # The value filters guarantee second and third are sequences too.
      return (first.sequence_dtype == second.sequence_dtype ==
              third.sequence_dtype and
              first.sequence_shape == second.sequence_shape ==
              third.sequence_shape)
    # Otherwise all three are primitives of the same type.
    return True

  self.set_apply_filter(_compatible_triple)
def __init__(self):
  """Sets up indexing of a sequence or rank>=1 tensor by an int."""
  super(IndexingOperation, self).__init__(
      num_args=2,
      weight=INDEXING_WEIGHT,
      metadata=operation_base.OperationMetadata(docstring=INDEXING_DOCSTRING))

  def _indexable_filter(arg_value):
    """Accepts sequences and tensors with at least one dimension."""
    if arg_value.is_sequence:
      return True
    return arg_value.is_tensor and len(arg_value.shape)

  self.add_value_filters([_indexable_filter,
                          filtering.INT_OR_INT_TENSOR_FILTER])

  def _index_in_bounds(arg_values):
    """Checks that the index is in range for the sequence or tensor."""
    indexable, index = arg_values
    if indexable.is_sequence:
      length = len(indexable.value)
    else:
      # A tensor with at least one dimension; index the first axis.
      length = indexable.shape[0]
    # Negative indices (Python-style) are allowed down to -length.
    return -length <= int(index.value) < length

  self.set_apply_filter(_index_in_bounds)
def __init__(self):
  """Sets up pair creation with type/size value filters and an apply filter."""
  super(PairCreationOperation, self).__init__(
      num_args=2,
      weight=PAIR_CREATION_WEIGHT,
      metadata=operation_base.OperationMetadata(
          docstring=PAIR_CREATION_DOCSTRING))
  # Alternative filter groups: two primitives of the same type, or two small
  # tensors, or two tensor-like sequences.
  for primitive_type in tf_coder_utils.PRIMITIVE_TYPES:
    self.add_value_filters([filtering.get_type_filter(primitive_type)] * 2)

  def _small_tensor_filter(arg_value):
    """Keeps tensors small enough that a pair fits within the limit."""
    if not arg_value.is_tensor:
      return False
    return arg_value.num_elements() * 2 <= limits.MAX_TENSOR_ELEMENTS

  self.add_value_filters([_small_tensor_filter] * 2)
  self.add_value_filters([filtering.TENSOR_LIKE_SEQUENCE_FILTER] * 2)

  def _compatible_pair(arg_values):
    """Ensures dtype and shape compatibility across the two values."""
    left, right = arg_values
    if left.is_tensor:
      # The value filters guarantee right is also a tensor.
      if left.dtype != right.dtype:
        return False
      if left.shape == right.shape:
        return True
      if len(left.shape) != len(right.shape):
        return False
      # At most one dimension may differ between the two shapes.
      mismatches = sum(dim_left != dim_right
                       for dim_left, dim_right in zip(left.shape, right.shape))
      return mismatches <= 1
    if left.is_sequence:
      # The value filters guarantee right is also a sequence.
      return (left.sequence_dtype == right.sequence_dtype and
              left.sequence_shape == right.sequence_shape)
    # Otherwise both are primitives of the same type.
    return True

  self.set_apply_filter(_compatible_pair)