def test_pack_flat_sequence_to_spec_structure(self):
    """Exercises pack_flat_sequence_to_spec_structure on three spec variants.

    1. Subset spec packed from its own flattened placeholders.
    2. Subset spec packed from the flattened placeholders of the full spec.
    3. Optional spec packed from the flattened placeholders of the full spec.
    """
    # Case 1: flatten the subset placeholders and pack them straight back.
    subset_tensors = utils.make_placeholders(mock_nested_subset_spec)
    repacked_subset = utils.pack_flat_sequence_to_spec_structure(
        mock_nested_subset_spec,
        utils.flatten_spec_structure(subset_tensors))
    utils.assert_equal(subset_tensors, repacked_subset)
    utils.assert_equal(
        mock_nested_subset_spec, repacked_subset, ignore_batch=True)

    # Case 2: pack the subset spec out of the full spec's placeholders.
    full_tensors = utils.make_placeholders(mock_nested_spec)
    flat_full_tensors = utils.flatten_spec_structure(full_tensors)
    subset_from_full = utils.pack_flat_sequence_to_spec_structure(
        mock_nested_subset_spec, flat_full_tensors)
    # pack_flat_sequence_to_spec_structure only subselects what the spec
    # needs, hence we should recover exactly what we asked for.
    utils.assert_equal(
        mock_nested_subset_spec, subset_from_full, ignore_batch=True)
    utils.assert_equal(subset_tensors, subset_from_full)

    # Case 3: pack the optional spec out of the full spec's placeholders.
    optional_from_full = utils.pack_flat_sequence_to_spec_structure(
        mock_nested_optional_spec, flat_full_tensors)
    # mock_nested_optional_spec would like more tensors than
    # flat_full_tensors can provide; fortunately those are optional.
    utils.assert_required(optional_from_full, full_tensors)
    utils.assert_required(
        mock_nested_spec, optional_from_full, ignore_batch=True)
def test_pack_flat_sequence_to_spec_structure_ensure_order(self):
    """Checks that the packed result lists its entries as 'a', 'b', 'c'.

    The specs are deliberately added to the struct as b, a, c so that the
    expected order differs from the insertion order.
    """
    spec_struct = utils.TensorSpecStruct()
    for spec_name in ('b', 'a', 'c'):
        setattr(
            spec_struct,
            spec_name,
            utils.ExtendedTensorSpec(
                shape=(1,), dtype=tf.float32, name=spec_name))

    packed = utils.pack_flat_sequence_to_spec_structure(
        spec_struct, utils.make_placeholders(spec_struct))

    for index, expected_name in enumerate(['a', 'b', 'c']):
        # Both the key and the op name of the tensor at this position must
        # match the expected name.
        key_at_index = list(packed.keys())[index]
        self.assertEqual(key_at_index, expected_name)
        tensor_at_index = list(packed.values())[index]
        self.assertEqual(tensor_at_index.op.name, expected_name)