def test_overriding_top_level_dict_flag_fails(self):
    """Setting the dict flag itself on the command line must be rejected."""
    ff.DEFINE_dict(
        "top_level_dict",
        integer_field=ff.Integer(1, "integer field"),
    )
    argv = ("./program", "--top_level_dict=3")
    # Flag parsing converts the underlying error's type and message, so the
    # converted IllegalFlagValueError is what gets checked here.
    with self.assertRaisesRegex(
        flags.IllegalFlagValueError, "Can't override a dict flag directly"
    ):
        FLAGS(argv)
def test_define_flat(self):
    """A flat DEFINE_dict yields a plain dict plus one dotted flag per field."""
    holder = ff.DEFINE_dict(
        "flat_dict",
        integer_field=ff.Integer(1, "integer field"),
        string_field=ff.String(""),
        string_list_field=ff.StringList(["a", "b", "c"], "string list field"),
    )

    # Both the FLAGS attribute and the returned holder expose the defaults.
    want = dict(
        integer_field=1,
        string_field="",
        string_list_field=["a", "b", "c"],
    )
    self.assertEqual(FLAGS.flat_dict, want)
    self.assertEqual(holder.value, want)

    # Every field is also registered as its own dotted flag, even though
    # callers normally go through the dict value instead.
    integer_flag = FLAGS["flat_dict.integer_field"]
    string_flag = FLAGS["flat_dict.string_field"]
    self.assertEqual(integer_flag.value, 1)
    self.assertEqual(string_flag.value, "")

    # An explicit help string is kept as-is ...
    self.assertEqual(integer_flag.help, "integer field")
    # ... while a field defined without one falls back to its dotted name.
    self.assertEqual(string_flag.help, "flat_dict.string_field")
def test_flag_values_error(self):
    """An item in DEFINE_dict's second positional slot raises ValueError."""
    with self.assertRaisesRegex(ValueError, "FlagValues instance"):
        # The error message is expected to mention a FlagValues instance,
        # since a String item is passed where that argument goes.
        ff.DEFINE_dict(
            "name",
            ff.String("stringflag", "string"),
            integer_field=ff.Integer(1, "integer field"),
        )
def test_flag_name_error(self):
    """A non-string first positional argument (the name) raises ValueError."""
    with self.assertRaisesRegex(ValueError, "must be a string"):
        # An item object is passed where the flag name is expected.
        ff.DEFINE_dict(
            ff.String("name", "string flag"),
            ff.String("stringflag", "string"),
            integer_field=ff.Integer(1, "integer field"),
        )
def test_too_many_positional_args_error(self):
    """Three positional arguments to DEFINE_dict raise ValueError."""
    with self.assertRaisesRegex(ValueError, "at most two positional"):
        # Name plus two extra positionals: one more than the limit the error
        # message refers to.
        ff.DEFINE_dict(
            "name",
            ff.String("foo", "string"),
            ff.String("bar", "string"),
            integer_field=ff.Integer(1, "integer field"),
        )
def test_valid_flat(self):
    """_extract_defaults maps each top-level item to its default value."""
    items = dict(
        integer_field=ff.Integer(10, "Integer field"),
        string_field=ff.String("default", "String field"),
    )
    defaults = ff._extract_defaults(items)
    self.assertEqual(
        defaults,
        dict(integer_field=10, string_field="default"),
    )
def test_invalid_container(self):
    """A list among the items is rejected with the module's TypeError text."""
    message = ff._NOT_A_DICT_OR_ITEM.format("list")
    with self.assertRaisesWithLiteralMatch(TypeError, message):
        # Only dicts and items are accepted; "nested" holds a list instead.
        ff._extract_defaults(dict(
            integer_field=ff.Integer(10, "Integer field"),
            string_field=ff.String("default", "String field"),
            nested=[ff.Float(3.1, "Float field")],
        ))
def test_define_nested(self):
    """Nested dicts of items give a nested value and dotted leaf flags."""
    holder = ff.DEFINE_dict(
        "nested_dict",
        integer_field=ff.Integer(1, "integer field"),
        sub_dict={"string_field": ff.String("", "string field")},
    )

    # The value mirrors the nesting of the definition, with defaults filled in.
    want = dict(integer_field=1, sub_dict=dict(string_field=""))
    self.assertEqual(FLAGS.nested_dict, want)
    self.assertEqual(holder.value, want)

    # Each leaf is also registered under its full dotted path.
    self.assertEqual(FLAGS["nested_dict.integer_field"].value, 1)
    self.assertEqual(FLAGS["nested_dict.sub_dict.string_field"].value, "")
def test_basic_serialization(self):
    """Dotted sub-flags serialize individually and survive a round trip."""
    ff.DEFINE_dict(
        "to_serialize",
        integer_field=ff.Integer(1, "integer field"),
        boolean_field=ff.Boolean(False, "boolean field"),
        string_list_field=ff.StringList(["a", "b", "c"], "string list field"),
        enum_class_field=ff.EnumClass(MyEnum.A, MyEnum, "my enum field"),
    )

    def current_value():
        # Look the flag up afresh on every call instead of caching the object
        # across parse/unparse cycles.
        return FLAGS["to_serialize"].value

    defaults = current_value().copy()

    # Override every field; the integer uses the "--flag value" form.
    FLAGS([
        "./program",
        "--to_serialize.boolean_field=True",
        "--to_serialize.integer_field",
        "1337",
        "--to_serialize.string_list_field=d,e,f",
        "--to_serialize.enum_class_field=B",
    ])

    # The dict flag itself serializes to _flags._EMPTY; the sub-flags
    # serialize to their own "--name=value" form.
    expected_serializations = (
        ("to_serialize", _flags._EMPTY),
        ("to_serialize.boolean_field", "--to_serialize.boolean_field=True"),
        ("to_serialize.string_list_field",
         "--to_serialize.string_list_field=d,e,f"),
    )
    for flag_name, serialized in expected_serializations:
        self.assertEqual(FLAGS[flag_name].serialize(), serialized)

    parsed = current_value().copy()
    self.assertDictEqual(
        parsed,
        {
            "boolean_field": True,
            "integer_field": 1337,
            "string_list_field": ["d", "e", "f"],
            "enum_class_field": MyEnum.B,
        },
    )
    self.assertNotEqual(current_value(), defaults)

    # Round trip: serialize every sub-flag, reset everything to defaults,
    # then parse the serialized arguments back in.
    round_trip_args = [
        FLAGS[sub_name].serialize()
        for sub_name in FLAGS
        if sub_name.startswith("to_serialize.")
    ]
    FLAGS.unparse_flags()  # Reset to defaults
    self.assertDictEqual(current_value(), defaults)
    FLAGS(["./program", *round_trip_args])
    self.assertDictEqual(current_value(), parsed)
def test_valid_nested(self):
    """_extract_defaults recurses into nested dicts of items."""
    items = {
        "integer_field": ff.Integer(10, "Integer field"),
        "string_field": ff.String("default", "String field"),
        "nested": {"float_field": ff.Float(3.1, "Float field")},
    }
    defaults = ff._extract_defaults(items)
    self.assertEqual(
        defaults,
        dict(
            integer_field=10,
            string_field="default",
            nested=dict(float_field=3.1),
        ),
    )
def test_no_name_error(self):
    """Calling DEFINE_dict with no positional name raises ValueError."""
    with self.assertRaisesRegex(ValueError, "one positional argument"):
        # Only keyword items are passed; the required name is missing.
        ff.DEFINE_dict(integer_field=ff.Integer(1, "integer field"))