Exemplo n.º 1
0
 def apply(self, arg_values, settings):
   """Builds an OperationValue whose value pairs the first two arguments.

   Args:
     arg_values: Argument Values; only the first two are read.
     settings: Not used by this implementation.

   Returns:
     An OperationValue wrapping the 2-tuple of raw argument values, or None if
     building it raises ValueError.
   """
   try:
     pair = (arg_values[0].value, arg_values[1].value)
     return value.OperationValue(pair, self, arg_values)
   except ValueError:
     return None
def extract_values_with_collapsed_subtrees(
    value: value_module.Value) -> List[value_module.Value]:
  """Returns variants of `value` in which subtrees are collapsed into inputs.

  See the docstring of extract_examples_from_value for the full description.

  A leaf (anything that is not an OperationValue) yields only itself. For an
  OperationValue, one operation application is sampled at random. Every
  combination of the children's own collapsed variants is rebuilt into a new
  OperationValue. Finally, the whole value is also offered as a fresh input
  named NEW_INPUT_NAME.

  Args:
    value: The Value to process.

  Returns:
    `[value]` for a leaf. Otherwise, the rebuilt OperationValues followed by a
    single InputValue holding `value.value`.
  """
  if not isinstance(value, value_module.OperationValue):
    return [value]  # Leaf: there is nothing below it to collapse.

  op_value = typing.cast(value_module.OperationValue, value)

  # Sample a single way of producing this Value instead of exploring all of
  # them. Exhaustive recursion would over-represent Values that are reachable
  # through many expressions, e.g. ones built from commutative operations.
  application = random.choice(op_value.operation_applications)

  child_options = [
      extract_values_with_collapsed_subtrees(arg)
      for arg in application.arg_values
  ]
  collapsed = [
      value_module.OperationValue(
          value=op_value.value,
          operation=application.operation,
          arg_values=combination)
      for combination in itertools.product(*child_options)
  ]

  # The entire tree may also be replaced by a brand-new input. Tensor
  # conversion is skipped so the object is kept as-is; e.g. a tuple of tensors
  # is not turned into one tensor of one greater rank.
  collapsed.append(value_module.InputValue(value=op_value.value,
                                           name=NEW_INPUT_NAME,
                                           skip_tensor_conversion=True))
  return collapsed
Exemplo n.º 3
0
 def apply(self, arg_values, settings):
   """Indexes axis 1 of the first argument with the second argument.

   The second argument's value is converted with int(). The first argument is
   then indexed as `x[:, index]`.

   Returns:
     An OperationValue for the indexed result, or None if any step raises.
   """
   try:
     source = arg_values[0].value
     index = int(arg_values[1].value)
     return value.OperationValue(source[:, index], self, arg_values)
   except Exception:  # pylint: disable=broad-except
     return None
Exemplo n.º 4
0
 def apply(self, arg_values, settings):
   """Delegates to self._evaluate_slice and wraps the result.

   `_evaluate_slice` receives a list containing:
   - the first argument's raw value, then
   - every remaining argument's value converted with int().

   Returns:
     An OperationValue for the slice result, or None if any step raises.
   """
   try:
     slice_args = [arg_values[0].value]
     slice_args.extend(int(arg.value) for arg in arg_values[1:])
     sliced = self._evaluate_slice(slice_args)
     return value.OperationValue(sliced, self, arg_values)
   except Exception:  # pylint: disable=broad-except
     return None
 def apply(self, arg_values, settings):
     """Calls the wrapped function on the argument values.

     Argument values are bound positionally to self.arg_names. Then
     self.constant_kwargs is merged on top (overriding on clashes), and the
     result is passed as keyword arguments to self._function_obj.

     Returns:
       An OperationValue for the function's result, or None in two cases:
       - the function raises;
       - building the OperationValue raises ValueError. In this case, warnings
         are printed via self._print_warnings when
         settings.printing.tensor_size_warnings is truthy.
     """
     raw_values = [arg.value for arg in arg_values]
     kwargs = dict(zip(self.arg_names, raw_values))
     kwargs.update(self.constant_kwargs)

     try:
         outcome = self._function_obj(**kwargs)
     except Exception:  # pylint: disable=broad-except
         return None

     try:
         return value.OperationValue(outcome, self, arg_values)
     except ValueError:
         if settings.printing.tensor_size_warnings:
             self._print_warnings(arg_values, outcome)
         return None
Exemplo n.º 6
0
  def test_reconstruct_expression(self):
    expected = 'tf.reduce_max(my_input, axis=1)'
    my_input = value.InputValue([[1, 3, 2], [-3, 0, 4]], 'my_input')
    axis = value.ConstantValue(1)

    # Build the OperationValue by hand. The expected expression suggests
    # self.operation is tf.reduce_max; presumably set up in setUp.
    manual = value.OperationValue(tf.constant([3, 4]), self.operation,
                                  [my_input, axis])
    # The first call computes the expression; the second checks the cached one.
    for _ in range(2):
      self.assertEqual(manual.reconstruct_expression(), expected)

    # Applying the operation directly must give an equal value that
    # reconstructs to the same expression.
    applied = self.operation.apply([my_input, axis], self.settings)
    self.assertEqual(manual, applied)
    self.assertEqual(applied.reconstruct_expression(), expected)
Exemplo n.º 7
0
 def apply(self, arg_values, settings):
     """Combines the first two argument values with `+`.

     Returns:
       An OperationValue wrapping the sum. Any exception raised by `+` or by
       the OperationValue constructor propagates to the caller.
     """
     total = arg_values[0].value + arg_values[1].value
     return value.OperationValue(total, self, arg_values)