Ejemplo n.º 1
0
 def test_raises_arg_does_not_match_param(self):
   """Checks that a type mismatch between parameter and argument is rejected.

   The parameter is declared as `tf.int32`, but the concrete argument is
   `tf.float32` data, so building the called lambda must raise `TypeError`.
   """
   mismatched_arg = building_blocks.Data('y', tf.float32)
   int_param = building_blocks.Reference('x', tf.int32)
   tuple_body = building_blocks.Tuple(
       [building_blocks.Reference('x', tf.int32)])
   construct = (
       compiler_transformations.construct_tensorflow_calling_lambda_on_concrete_arg)
   with self.assertRaises(TypeError):
     construct(int_param, tuple_body, mismatched_arg)
Ejemplo n.º 2
0
 def test_constructs_called_tf_block_of_correct_type_signature(self):
   """Checks the result is a Call of a CompiledComputation typed like the body.

   An int32 parameter is wrapped in a one-element Tuple body and applied to an
   int32 reference; the resulting block's type must equal the body's type.
   """
   concrete_arg = building_blocks.Reference('y', tf.int32)
   x_param = building_blocks.Reference('x', tf.int32)
   one_tuple = building_blocks.Tuple([building_blocks.Reference('x', tf.int32)])
   builder = (
       compiler_transformations.construct_tensorflow_calling_lambda_on_concrete_arg)
   called_block = builder(x_param, one_tuple, concrete_arg)

   self.assertIsInstance(called_block, building_blocks.Call)
   compiled_fn = called_block.function
   self.assertIsInstance(compiled_fn, building_blocks.CompiledComputation)
   self.assertEqual(called_block.type_signature, one_tuple.type_signature)
Ejemplo n.º 3
0
 def test_generated_tensorflow_executes_correctly_sequence_parameter(self):
   """Runs the generated TF on a sequence input and checks the echoed output.

   The body is a one-element Tuple containing the sequence parameter, so
   feeding `[0, 1, 2, 3, 4]` should yield a single result equal to that input.
   """
   x_ref = building_blocks.Reference(
       'x', computation_types.SequenceType(tf.int32))
   wrapped = building_blocks.Tuple([x_ref])
   y_ref = building_blocks.Reference(
       'y', computation_types.SequenceType(tf.int32))
   called = (
       compiler_transformations.construct_tensorflow_calling_lambda_on_concrete_arg(
           x_ref, wrapped, y_ref))

   sequence_values = list(range(5))
   outputs = test_utils.run_tensorflow(called.function.proto, sequence_values)
   self.assertLen(outputs, 1)
   self.assertAllEqual(outputs[0], list(range(5)))
Ejemplo n.º 4
0
 def test_generated_tensorflow_executes_correctly_tuple_parameter(self):
   """Runs the generated TF on an (int32, float32) constant and checks the swap.

   The body selects element 1 and then element 0 of the parameter, so a
   constant filled with 1 should come back as `(1.0, 1)`.
   """
   pair_param = building_blocks.Reference('x', [tf.int32, tf.float32])
   # Selection order (1, 0) reverses the two tuple elements.
   swapped_body = building_blocks.Tuple(
       [building_blocks.Selection(pair_param, index=i) for i in (1, 0)])
   ones_constant = building_block_factory.create_tensorflow_constant(
       [tf.int32, tf.float32], 1)
   called = (
       compiler_transformations.construct_tensorflow_calling_lambda_on_concrete_arg(
           pair_param, swapped_body, ones_constant))

   outputs = test_utils.run_tensorflow(called.function.proto)
   self.assertLen(outputs, 2)
   self.assertEqual(outputs[0], 1.)
   self.assertEqual(outputs[1], 1)
 def test_generated_tensorflow_executes_correctly_int_parameter(self):
   """Runs the generated TF on an int32 constant 0 and checks both outputs.

   The body returns the parameter twice, so the result should be `(0, 0)`.
   """
   param = building_blocks.Reference('x', tf.int32)
   body = building_blocks.Tuple([
       building_blocks.Reference('x', tf.int32),
       building_blocks.Reference('x', tf.int32)
   ])
   int_constant = building_block_factory.create_tensorflow_constant(
       tf.int32, 0)
   # Use the same module alias (`compiler_transformations`) as the other tests
   # in this file; the bare `transformations` name is not used anywhere else.
   tf_block = compiler_transformations.construct_tensorflow_calling_lambda_on_concrete_arg(
       param, body, int_constant)
   result = test_utils.run_tensorflow(tf_block.function.proto)
   self.assertLen(result, 2)
   self.assertEqual(result[0], 0)
   self.assertEqual(result[1], 0)
Ejemplo n.º 6
0
 def test_raises_wrong_arguments(self):
   """Checks `TypeError` for each argument slot given the wrong kind of value.

   Cases, in order: a Tuple where the parameter Reference belongs, a Python
   list where the body belongs, and a Python list where the argument belongs.
   """
   good_param = building_blocks.Reference('x', tf.int32)
   good_body = building_blocks.Tuple(
       [building_blocks.Reference('x', tf.int32)])
   good_arg = building_blocks.Data('y', tf.int32)

   bad_call_args = [
       (good_body, good_body, good_arg),
       (good_param, [good_param], good_arg),
       (good_param, good_body, [good_arg]),
   ]
   for call_args in bad_call_args:
     with self.assertRaises(TypeError):
       compiler_transformations.construct_tensorflow_calling_lambda_on_concrete_arg(
           *call_args)