Ejemplo n.º 1
0
    def test_pack_flat_sequence_to_spec_structure(self):
        """Packing a flat sequence back into a spec recovers the spec structure.

        Three scenarios are exercised:
          1. Round trip: placeholders built from the subset spec are flattened
             and packed back into that same subset spec.
          2. Sub-selection: placeholders built from the full spec are flattened
             and packed into the subset spec; per the original intent, only the
             entries the subset spec needs are picked out.
          3. Optional entries: the full spec's flattened placeholders are
             packed into the optional spec, which asks for more tensors than
             the flat sequence holds; this is expected to be acceptable because
             those extra entries are optional.
        """
        # Scenario 1: round trip of the subset spec.
        subset_tensors = utils.make_placeholders(mock_nested_subset_spec)
        flat_subset_tensors = utils.flatten_spec_structure(subset_tensors)
        repacked_subset = utils.pack_flat_sequence_to_spec_structure(
            mock_nested_subset_spec, flat_subset_tensors)
        utils.assert_equal(subset_tensors, repacked_subset)
        utils.assert_equal(
            mock_nested_subset_spec, repacked_subset, ignore_batch=True)

        # Scenario 2: pack the full set of placeholders into the subset spec.
        # pack_flat_sequence_to_spec_structure sub-selects only what the target
        # spec asks for, so the result should match the subset placeholders.
        full_tensors = utils.make_placeholders(mock_nested_spec)
        flat_full_tensors = utils.flatten_spec_structure(full_tensors)
        subselected = utils.pack_flat_sequence_to_spec_structure(
            mock_nested_subset_spec, flat_full_tensors)
        utils.assert_equal(
            mock_nested_subset_spec, subselected, ignore_batch=True)
        utils.assert_equal(subset_tensors, subselected)

        # Scenario 3: the optional spec wants more tensors than the flat
        # sequence can provide; that is fine because the extra ones are optional.
        with_optionals = utils.pack_flat_sequence_to_spec_structure(
            mock_nested_optional_spec, flat_full_tensors)
        utils.assert_required(with_optionals, full_tensors)
        utils.assert_required(
            mock_nested_spec, with_optionals, ignore_batch=True)
Ejemplo n.º 2
0
 def test_pack_flat_sequence_to_spec_structure_ensure_order(self):
   """Packed output lists its entries in a, b, c order.

   The spec attributes are deliberately assigned out of that order (b, a, c).
   The test expects the packed result's keys to be exactly ['a', 'b', 'c'],
   and each packed placeholder's op name to equal its key.
   """
   expected_order = ['a', 'b', 'c']
   test_spec = utils.TensorSpecStruct()
   # Assign attributes out of the expected order so that a result which
   # simply followed assignment order would fail the checks below.
   test_spec.b = utils.ExtendedTensorSpec(
       shape=(1,), dtype=tf.float32, name='b')
   test_spec.a = utils.ExtendedTensorSpec(
       shape=(1,), dtype=tf.float32, name='a')
   test_spec.c = utils.ExtendedTensorSpec(
       shape=(1,), dtype=tf.float32, name='c')
   placeholders = utils.make_placeholders(test_spec)
   packed_placeholders = utils.pack_flat_sequence_to_spec_structure(
       test_spec, placeholders)
   # Compare whole sequences instead of indexing position by position.
   # Missing or extra entries then produce a readable assertion failure
   # rather than an IndexError or a silent pass.
   self.assertEqual(list(packed_placeholders.keys()), expected_order)
   self.assertEqual(
       [tensor.op.name for tensor in packed_placeholders.values()],
       expected_order)