def test_returns_computation(self, py_fn, type_signature, arg,
                             expected_result):
    """Checks `create_computation_for_py_fn` builds a runnable computation.

    Args:
      py_fn: The Python callable handed to the factory.
      type_signature: The type signature handed to the factory alongside
        `py_fn`.
      arg: The argument passed to `run_tensorflow` with the built proto.
      expected_result: The value the run is expected to produce.
    """
    # The factory returns a pair; only the proto is used by this test.
    computation_proto, _ = (
        tensorflow_computation_factory.create_computation_for_py_fn(
            py_fn, type_signature))
    self.assertIsInstance(computation_proto, pb.Computation)

    # Run the serialized computation and compare with the expected output.
    run_output = tensorflow_computation_test_utils.run_tensorflow(
        computation_proto, arg)
    self.assertEqual(run_output, expected_result)
 def test_transform_compiled_computation_semantic_equivalence(self):
     """Checks `TensorFlowOptimizer` preserves the behavior of an identity.

     Builds a compiled int32 identity, optimizes it with a default
     `ConfigProto`, and asserts that the transform reports a mutation and
     yields a `CompiledComputation`. Both protos are then run on several
     inputs, and their outputs must agree. Probing more than the single
     value 0 guards against a transform that collapses the computation to a
     constant that happens to match at 0.
     """
     # A plain int32 tensor type (not a tuple), used as the identity's type.
     tensor_type = computation_types.TensorType(tf.int32)
     compiled_computation = building_block_factory.create_compiled_identity(
         tensor_type)
     config = tf.compat.v1.ConfigProto()
     tf_optimizer = compiled_computation_transformations.TensorFlowOptimizer(
         config)
     transformed_comp, mutated = tf_optimizer.transform(compiled_computation)
     self.assertTrue(mutated)
     self.assertIsInstance(transformed_comp,
                           building_blocks.CompiledComputation)
     # Zero, positive and negative int32 inputs; each must give the same
     # output before and after the transform.
     for value in (0, 1, -7, 42):
         with self.subTest(value=value):
             result_before = tensorflow_computation_test_utils.run_tensorflow(
                 compiled_computation.proto, value)
             result_after = tensorflow_computation_test_utils.run_tensorflow(
                 transformed_comp.proto, value)
             self.assertEqual(result_before, result_after)
  def test_returns_computation(self, type_signature, value):
    """Checks `create_identity` yields a unary op that echoes its input.

    Args:
      type_signature: The type handed to `create_identity`.
      value: The input run through the computation; the output must compare
        equal to it.
    """
    identity_proto, _ = tensorflow_computation_factory.create_identity(
        type_signature)
    self.assertIsInstance(identity_proto, pb.Computation)

    # The serialized type must equal `type_factory.unary_op(type_signature)`.
    deserialized = type_serialization.deserialize_type(identity_proto.type)
    self.assertEqual(deserialized, type_factory.unary_op(type_signature))

    # Running the identity must hand back a value equal to its input.
    run_output = tensorflow_computation_test_utils.run_tensorflow(
        identity_proto, value)
    self.assertEqual(run_output, value)
  def test_returns_computation(self):
    """Checks `create_empty_tuple` builds a no-arg, empty-struct computation."""
    empty_tuple_proto, _ = tensorflow_computation_factory.create_empty_tuple()
    self.assertIsInstance(empty_tuple_proto, pb.Computation)

    # Relies on `check_assignable_from` to fail the test when the deserialized
    # type does not fit a parameterless function returning `[]`.
    deserialized = type_serialization.deserialize_type(empty_tuple_proto.type)
    computation_types.FunctionType(None, []).check_assignable_from(deserialized)

    run_output = tensorflow_computation_test_utils.run_tensorflow(
        empty_tuple_proto)
    self.assertEqual(run_output, structure.Struct([]))
  def test_returns_computation(self, value, type_signature, expected_result):
    """Checks `create_constant` builds a no-arg computation for `value`.

    Args:
      value: The constant handed to `create_constant`.
      type_signature: The type handed to `create_constant`. The deserialized
        proto type is checked against a no-arg function returning it.
      expected_result: What running the computation should yield. When it is
        a `list`, elements are compared without regard to order.
    """
    constant_proto, _ = tensorflow_computation_factory.create_constant(
        value, type_signature)
    self.assertIsInstance(constant_proto, pb.Computation)

    deserialized = type_serialization.deserialize_type(constant_proto.type)
    computation_types.FunctionType(
        None, type_signature).check_assignable_from(deserialized)

    run_output = tensorflow_computation_test_utils.run_tensorflow(
        constant_proto)
    # Lists are compared order-insensitively; everything else by equality.
    compare = (
        self.assertCountEqual
        if isinstance(expected_result, list) else self.assertEqual)
    compare(run_output, expected_result)
  def test_returns_computation(self, type_signature, count, value):
    """Checks `create_replicate_input` fans one input out `count` times.

    Args:
      type_signature: The type of the single input.
      count: How many copies the computation should produce.
      value: The input run through the computation.
    """
    replicate_proto, _ = tensorflow_computation_factory.create_replicate_input(
        type_signature, count)
    self.assertIsInstance(replicate_proto, pb.Computation)

    deserialized = type_serialization.deserialize_type(replicate_proto.type)
    # Expected: a function from `type_signature` to a `count`-element list of
    # `type_signature`.
    replicated_types = [type_signature for _ in range(count)]
    computation_types.FunctionType(
        type_signature, replicated_types).check_assignable_from(deserialized)

    run_output = tensorflow_computation_test_utils.run_tensorflow(
        replicate_proto, value)
    # Each output element is an unnamed copy of the input.
    self.assertEqual(
        run_output, structure.Struct([(None, value) for _ in range(count)]))
  def test_returns_computation(self, operator, type_signature, operands,
                               expected_result):
    """Checks `create_binary_operator_with_upcast` builds a runnable op.

    Args:
      operator: The binary operator handed to the factory.
      type_signature: The operand type handed to the factory. The
        computation's parameter type must equal `_StructType(type_signature)`.
      operands: The argument passed to `run_tensorflow`.
      expected_result: The value the run is expected to produce.
    """
    proto, _ = tensorflow_computation_factory.create_binary_operator_with_upcast(
        type_signature, operator)

    self.assertIsInstance(proto, pb.Computation)
    actual_type = type_serialization.deserialize_type(proto.type)
    self.assertIsInstance(actual_type, computation_types.FunctionType)
    # Note: Only the parameter type is worth checking here. The result type is
    # determined by the `operator` under test, not by the implementation of
    # `create_binary_operator_with_upcast`.
    expected_parameter_type = _StructType(type_signature)
    self.assertEqual(actual_type.parameter, expected_parameter_type)
    actual_result = tensorflow_computation_test_utils.run_tensorflow(
        proto, operands)
    self.assertEqual(actual_result, expected_result)