Esempio n. 1
0
 def test_pad_graph_inputs_to_match_type_raises_on_mismatched_graph_type_and_requested_type(
         self):
     """Rejects a requested type whose prefix disagrees with the graph's type.

     The graph parameter is `<float32>`, but the requested type starts with an
     `int32` element. `pad_graph_inputs_to_match_type` is therefore expected to
     raise a `TypeError`.
     """
     comp = _create_compiled_computation(
         lambda x: x, computation_types.to_type([tf.float32]))
     # `assertRaisesRegexp` is a deprecated alias that was removed in Python
     # 3.12; `assertRaisesRegex` is the supported spelling.
     with self.assertRaisesRegex(TypeError, r'must match the beginning'):
         compiled_computation_transforms.pad_graph_inputs_to_match_type(
             comp, computation_types.to_type([tf.int32] * 2))
Esempio n. 2
0
 def test_pad_graph_inputs_to_match_type_raises_on_requested_type_too_short(
         self):
     """Rejects a requested type with fewer elements than the graph parameter.

     The graph takes three `int32` elements but only two are requested, so
     padding is impossible and a `ValueError` is expected.
     """
     comp = _create_compiled_computation(
         lambda x: x, computation_types.to_type([tf.int32] * 3))
     # Use the non-deprecated `assertRaisesRegex` (the `...Regexp` alias was
     # removed in Python 3.12).
     with self.assertRaisesRegex(ValueError, r'must have more elements'):
         compiled_computation_transforms.pad_graph_inputs_to_match_type(
             comp, computation_types.to_type([tf.int32] * 2))
Esempio n. 3
0
 def test_pad_graph_inputs_to_match_type_raises_on_wrong_requested_type(
         self):
     """Rejects a requested type that is a bare tensor type, not a tuple.

     The error message is expected to mention that a `NamedTupleType` was
     required.
     """
     comp = _create_compiled_computation(
         lambda x: x, computation_types.to_type([tf.int32]))
     tensor_type = computation_types.to_type(tf.int32)
     # Use the non-deprecated `assertRaisesRegex` (the `...Regexp` alias was
     # removed in Python 3.12).
     with self.assertRaisesRegex(TypeError, r'Expected.*NamedTupleType'):
         compiled_computation_transforms.pad_graph_inputs_to_match_type(
             comp, tensor_type)
Esempio n. 4
0
 def test_pad_graph_inputs_to_match_type_raises_on_wrong_graph_parameter_type(
         self):
     """Rejects padding a computation whose parameter is not a tuple.

     The compiled computation takes a bare `int32` parameter, so there is no
     tuple to extend and a `TypeError` is expected.
     """
     comp = _create_compiled_computation(
         lambda x: x, computation_types.to_type(tf.int32))
     # Use the non-deprecated `assertRaisesRegex` (the `...Regexp` alias was
     # removed in Python 3.12).
     with self.assertRaisesRegex(
             TypeError,
             r'Can only pad inputs of a CompiledComputation with parameter type tuple'
     ):
         compiled_computation_transforms.pad_graph_inputs_to_match_type(
             comp, computation_types.to_type([tf.int32]))
Esempio n. 5
0
    def test_pad_graph_inputs_to_match_type_preserves_unnamed_type_signature(
            self):
        """Padding an unnamed tuple parameter yields the expected signature.

        An identity computation over `<int32>` is padded to accept
        `<int32,float32>`. The resulting function type is expected to take the
        padded tuple while still returning `<int32>`.
        """
        computation_arg_type = computation_types.to_type([tf.int32])
        foo = _create_compiled_computation(lambda x: x, computation_arg_type)

        padded_inputs = (
            compiled_computation_transforms.pad_graph_inputs_to_match_type(
                foo, computation_types.NamedTupleType([tf.int32, tf.float32])))
        expected_type_signature = computation_types.FunctionType(
            [tf.int32, tf.float32], [tf.int32])

        self.assertEqual(padded_inputs.type_signature, expected_type_signature)
Esempio n. 6
0
    def test_pad_graph_inputs_to_match_type_add_single_int_executes_correctly(
            self):
        """The padded identity computation ignores the appended float input.

        Only the leading `int32` element reaches the identity graph. Running the
        padded computation with different trailing floats must therefore
        produce the same single-element result.
        """
        identity_comp = _create_compiled_computation(
            lambda x: x, computation_types.to_type([tf.int32]))

        requested_type = computation_types.NamedTupleType(
            [tf.int32, tf.float32])
        padded = compiled_computation_transforms.pad_graph_inputs_to_match_type(
            identity_comp, requested_type)
        run_padded = _to_computation_impl(padded)

        expected = anonymous_tuple.AnonymousTuple([(None, 1)])

        # The second element is padding; its value must not leak into output.
        for extra_float in (0., 10.):
            self.assertEqual(run_padded([1, extra_float]), expected)
Esempio n. 7
0
 def test_pad_graph_inputs_to_match_type_raises_on_none(self):
     """Rejects `None` in place of a `CompiledComputation`."""
     # Use the non-deprecated `assertRaisesRegex` (the `...Regexp` alias was
     # removed in Python 3.12).
     with self.assertRaisesRegex(TypeError,
                                 r'Expected.*CompiledComputation'):
         compiled_computation_transforms.pad_graph_inputs_to_match_type(
             None, computation_types.to_type([tf.int32]))