def apply(self, arg_values, settings):
  """Packs the values of the first two arguments into a 2-tuple.

  Args:
    arg_values: Sequence of argument Values; only indices 0 and 1 are read.
    settings: Settings object; not used by this operation.

  Returns:
    A value.OperationValue wrapping `(arg_values[0].value,
    arg_values[1].value)`, or None if building it raises ValueError.
  """
  try:
    pair = (arg_values[0].value, arg_values[1].value)
    return value.OperationValue(pair, self, arg_values)
  except ValueError:
    return None
def extract_values_with_collapsed_subtrees(
    value: value_module.Value) -> List[value_module.Value]:
  """Collapses subtrees, see docstring of extract_examples_from_value.

  Non-OperationValue leaves are returned unchanged (as a one-element list).

  For an OperationValue, one of its operation_applications is picked at
  random. The function then produces one OperationValue per combination of
  the children's own collapsed variants (via itertools.product). The final
  entry is an InputValue named NEW_INPUT_NAME that replaces the entire tree.

  Args:
    value: The Value whose expression tree should be collapsed.

  Returns:
    A list of Values; the whole-tree InputValue (if any) is always last.
  """
  if not isinstance(value, value_module.OperationValue):
    # Leaves are never collapsed; they are their only variant.
    return [value]
  value = typing.cast(value_module.OperationValue, value)

  # Only a single random derivation is followed. Enumerating every
  # operation_application would over-represent Values that have many
  # equivalent expressions (e.g. ones built from commutative ops).
  # The choice is made before recursing, so the random draws happen in
  # pre-order.
  application = random.choice(value.operation_applications)
  per_child_options = [extract_values_with_collapsed_subtrees(arg)
                       for arg in application.arg_values]

  results = []  # type: List[value_module.Value]
  results.extend(
      value_module.OperationValue(
          value=value.value,
          operation=application.operation,
          arg_values=combination)
      for combination in itertools.product(*per_child_options))

  # Finally, collapse the whole tree into one new input. skip_tensor_conversion
  # keeps the object as-is, e.g. a tuple of tensors is not stacked into one
  # higher-rank tensor.
  results.append(value_module.InputValue(value=value.value,
                                         name=NEW_INPUT_NAME,
                                         skip_tensor_conversion=True))
  return results
def apply(self, arg_values, settings):
  """Indexes the first argument along axis 1 (e.g. picks a matrix column).

  Args:
    arg_values: Sequence of argument Values. Index 0 is the object to index.
      Index 1 holds the position, which is converted with int().
    settings: Settings object; not used by this operation.

  Returns:
    A value.OperationValue wrapping `arg_values[0].value[:, index]`, or None
    if any step (conversion, indexing, wrapping) raises.
  """
  try:
    source = arg_values[0].value
    index = int(arg_values[1].value)
    return value.OperationValue(source[:, index], self, arg_values)
  except Exception:  # pylint: disable=broad-except
    return None
def apply(self, arg_values, settings):
  """Evaluates a slice through self._evaluate_slice.

  The first argument's value is forwarded unchanged. Every subsequent
  argument's value is converted with int() before being passed along.

  Args:
    arg_values: Sequence of argument Values.
    settings: Settings object; not used by this operation.

  Returns:
    A value.OperationValue wrapping the slice result, or None if any step
    raises.
  """
  try:
    slice_args = [arg_values[0].value]
    slice_args.extend(int(bound.value) for bound in arg_values[1:])
    sliced = self._evaluate_slice(slice_args)
    return value.OperationValue(sliced, self, arg_values)
  except Exception:  # pylint: disable=broad-except
    return None
def apply(self, arg_values, settings):
  """Calls self._function_obj with keyword arguments built from arg_values.

  The values of arg_values are paired positionally with self.arg_names (zip
  drops any surplus on either side). self.constant_kwargs is then merged in,
  winning on name clashes.

  Args:
    arg_values: Sequence of argument Values.
    settings: Settings object. settings.printing.tensor_size_warnings controls
      whether warnings are printed when the result cannot be wrapped.

  Returns:
    A value.OperationValue wrapping the function's result. Returns None if the
    function call raises any Exception, or if wrapping raises ValueError.
  """
  kwargs = dict(zip(self.arg_names,
                    [arg_value.value for arg_value in arg_values]))
  kwargs.update(self.constant_kwargs)

  try:
    result_value = self._function_obj(**kwargs)
  except Exception:  # pylint: disable=broad-except
    return None

  try:
    wrapped = value.OperationValue(result_value, self, arg_values)
  except ValueError:
    if settings.printing.tensor_size_warnings:
      self._print_warnings(arg_values, result_value)
    return None
  return wrapped
def test_reconstruct_expression(self):
  """reconstruct_expression yields the same text for manual and applied values.

  Also calls reconstruct_expression a second time on the manually built value
  to exercise its cached path.
  """
  my_input = value.InputValue([[1, 3, 2], [-3, 0, 4]], 'my_input')
  axis = value.ConstantValue(1)
  manual_value = value.OperationValue(tf.constant([3, 4]), self.operation,
                                      [my_input, axis])
  expected = 'tf.reduce_max(my_input, axis=1)'

  # First call computes the expression; the second should hit the cache.
  for _ in range(2):
    self.assertEqual(manual_value.reconstruct_expression(), expected)

  applied_value = self.operation.apply([my_input, axis], self.settings)
  self.assertEqual(manual_value, applied_value)
  self.assertEqual(applied_value.reconstruct_expression(), expected)
def apply(self, arg_values, settings):
  """Adds the values of the first two arguments with `+`.

  Args:
    arg_values: Sequence of argument Values; only indices 0 and 1 are read.
    settings: Settings object; not used by this operation.

  Returns:
    A value.OperationValue wrapping `arg_values[0].value + arg_values[1].value`.
    Returns None if the addition or the wrapping raises. This is the same
    None-on-failure signal used by the other apply() implementations, so an
    incompatible argument pair does not surface as an exception.
  """
  try:
    total = arg_values[0].value + arg_values[1].value
    return value.OperationValue(total, self, arg_values)
  except Exception:  # pylint: disable=broad-except
    return None