def test_raises_arg_does_not_match_param(self):
  """An argument whose type differs from the parameter raises TypeError.

  The parameter and body are int32-typed, while the concrete argument is a
  float32 `Data` block.
  """
  construct = (
      compiler_transformations.construct_tensorflow_calling_lambda_on_concrete_arg)
  int_param = building_blocks.Reference('x', tf.int32)
  int_body = building_blocks.Tuple([building_blocks.Reference('x', tf.int32)])
  float_arg = building_blocks.Data('y', tf.float32)
  with self.assertRaises(TypeError):
    construct(int_param, int_body, float_arg)
def test_constructs_called_tf_block_of_correct_type_signature(self):
  """The constructed block calls a CompiledComputation and is typed like body."""
  param = building_blocks.Reference('x', tf.int32)
  body = building_blocks.Tuple([building_blocks.Reference('x', tf.int32)])
  concrete_arg = building_blocks.Reference('y', tf.int32)

  called = (
      compiler_transformations.construct_tensorflow_calling_lambda_on_concrete_arg(
          param, body, concrete_arg))

  # The result must be a Call whose callee is compiled TensorFlow.
  self.assertIsInstance(called, building_blocks.Call)
  callee = called.function
  self.assertIsInstance(callee, building_blocks.CompiledComputation)
  # Calling the lambda yields a value of the body's type.
  self.assertEqual(called.type_signature, body.type_signature)
def test_generated_tensorflow_executes_correctly_sequence_parameter(self):
  """A body that wraps a sequence parameter returns the fed sequence.

  The compiled proto is run with the values 0..4 and must produce a single
  output equal to those values.
  """
  seq_param = building_blocks.Reference(
      'x', computation_types.SequenceType(tf.int32))
  wrapping_body = building_blocks.Tuple([seq_param])
  seq_arg = building_blocks.Reference(
      'y', computation_types.SequenceType(tf.int32))

  construct = (
      compiler_transformations.construct_tensorflow_calling_lambda_on_concrete_arg)
  called = construct(seq_param, wrapping_body, seq_arg)

  outputs = test_utils.run_tensorflow(called.function.proto, list(range(5)))
  self.assertLen(outputs, 1)
  self.assertAllEqual(outputs[0], list(range(5)))
def test_generated_tensorflow_executes_correctly_tuple_parameter(self):
  """Selecting a tuple parameter's elements in reverse order swaps them.

  The parameter is <int32, float32>; the body selects index 1 then index 0,
  so the outputs come back as <float32, int32>. The argument is a tensorflow
  constant of the parameter's type built from the value 1.
  """
  param = building_blocks.Reference('x', [tf.int32, tf.float32])
  # Selections are built in the order index=1, index=0 to reverse the tuple.
  swapped_body = building_blocks.Tuple(
      [building_blocks.Selection(param, index=i) for i in (1, 0)])
  # Named for what it is: a tuple-typed constant, not an int.
  tuple_constant = building_block_factory.create_tensorflow_constant(
      [tf.int32, tf.float32], 1)

  called = (
      compiler_transformations.construct_tensorflow_calling_lambda_on_concrete_arg(
          param, swapped_body, tuple_constant))
  outputs = test_utils.run_tensorflow(called.function.proto)

  self.assertLen(outputs, 2)
  self.assertEqual(outputs[0], 1.)
  self.assertEqual(outputs[1], 1)
def test_generated_tensorflow_executes_correctly_int_parameter(self):
  """Executing the block on the int32 constant 0 yields two zeros.

  The body references the int32 parameter twice, so the compiled TensorFlow
  should emit the argument value in both output positions.
  """
  param = building_blocks.Reference('x', tf.int32)
  body = building_blocks.Tuple([
      building_blocks.Reference('x', tf.int32),
      building_blocks.Reference('x', tf.int32),
  ])
  int_constant = building_block_factory.create_tensorflow_constant(
      tf.int32, 0)
  # Call through `compiler_transformations`, the module the other tests of
  # this function use; the previous bare `transformations` reference was
  # inconsistent and may not resolve to this function.
  tf_block = (
      compiler_transformations.construct_tensorflow_calling_lambda_on_concrete_arg(
          param, body, int_constant))
  result = test_utils.run_tensorflow(tf_block.function.proto)
  self.assertLen(result, 2)
  self.assertEqual(result[0], 0)
  self.assertEqual(result[1], 0)
def test_raises_wrong_arguments(self):
  """Each argument slot rejects a value of the wrong kind with TypeError."""
  param = building_blocks.Reference('x', tf.int32)
  body = building_blocks.Tuple([building_blocks.Reference('x', tf.int32)])
  arg = building_blocks.Data('y', tf.int32)

  invalid_calls = (
      # A Tuple passed where the parameter Reference is expected.
      (body, body, arg),
      # A plain Python list passed as the body.
      (param, [param], arg),
      # A plain Python list passed as the concrete argument.
      (param, body, [arg]),
  )
  for call_args in invalid_calls:
    with self.assertRaises(TypeError):
      compiler_transformations.construct_tensorflow_calling_lambda_on_concrete_arg(
          *call_args)