def main():
    """
    Driver function for the training script.

    Parses ``TrainingArguments`` from the command line, resolves the save
    directory and loggers, builds the train/validation data loaders and the
    model, then hands everything to ``train``.
    """
    _parser = NmArgumentParser(dataclass_types=TrainingArguments)
    # parse_args_into_dataclasses returns a tuple; with a single dataclass
    # type it is a one-element tuple, so unpacking into two names would
    # raise "not enough values to unpack". Indexing [0] is correct whether
    # or not additional entries follow.
    training_args = _parser.parse_args_into_dataclasses()[0]

    save_dir, loggers = helpers.get_save_dir_and_loggers(
        training_args, task=CURRENT_TASK
    )

    input_shape = ModelRegistry.input_shape(training_args.arch_key)
    # assume shape [C, S, S] where S is the image size
    image_size = input_shape[1]

    (
        train_dataset,
        train_loader,
        val_dataset,
        val_loader,
    ) = helpers.get_train_and_validation_loaders(
        training_args, image_size, task=CURRENT_TASK
    )

    num_classes = helpers.infer_num_classes(
        training_args, train_dataset, val_dataset
    )

    # model creation
    model = helpers.create_model(training_args, num_classes)

    train(
        training_args,
        model,
        train_loader,
        val_loader,
        input_shape,
        save_dir,
        loggers,
    )
def main():
    """
    Driver function for the learning-rate sensitivity analysis script.

    Parses ``LRAnalysisArguments`` from the command line, resolves the save
    directory, builds the data loaders and the model, then runs
    ``lr_sensitivity`` on the training loader.
    """
    _parser = NmArgumentParser(
        dataclass_types=LRAnalysisArguments,
        description="Utility script to Run a "
        "learning rate sensitivity analysis "
        "for a desired image classification architecture",
    )
    # parse_args_into_dataclasses returns a tuple; with a single dataclass
    # type it is a one-element tuple, so unpacking into two names would
    # raise "not enough values to unpack". Indexing [0] is correct whether
    # or not additional entries follow.
    args_ = _parser.parse_args_into_dataclasses()[0]

    # The loggers are not passed to lr_sensitivity, so they are discarded.
    save_dir, _ = helpers.get_save_dir_and_loggers(
        args_,
        task=CURRENT_TASK,
    )

    input_shape = ModelRegistry.input_shape(args_.arch_key)
    # assume shape [C, S, S] where S is the image size
    image_size = input_shape[1]

    (
        train_dataset,
        train_loader,
        val_dataset,
        val_loader,
    ) = helpers.get_train_and_validation_loaders(
        args_,
        image_size,
        task=CURRENT_TASK,
    )

    num_classes = helpers.infer_num_classes(args_, train_dataset, val_dataset)
    model = helpers.create_model(args_, num_classes)

    lr_sensitivity(args_, model, train_loader, save_dir)
def test_with_optional(self):
    """
    OptionalExample maps to optional arguments.

    Scalar fields default to None and list fields (nargs="+") default to [].
    Explicit values are converted to each field's type.
    """
    parser = NmArgumentParser(OptionalExample)

    reference = argparse.ArgumentParser()
    reference.add_argument("--foo", default=None, type=int)
    reference.add_argument("--bar", default=None, type=float, help="help message")
    reference.add_argument("--baz", default=None, type=str)
    reference.add_argument("--ces", nargs="+", default=[], type=str)
    reference.add_argument("--des", nargs="+", default=[], type=int)
    self.argparsersEqual(parser, reference)

    # Nothing supplied: every field keeps its default.
    with_defaults = parser.parse_args([])
    self.assertEqual(
        with_defaults, Namespace(foo=None, bar=None, baz=None, ces=[], des=[])
    )

    # Everything supplied: "42" stays a string because --baz is typed str.
    command_line = (
        "--foo 12 --bar 3.14 --baz 42 --ces parser expected_parser "
        "c --des 1 2 3"
    )
    populated = parser.parse_args(command_line.split())
    self.assertEqual(
        populated,
        Namespace(
            foo=12,
            bar=3.14,
            baz="42",
            ces=["parser", "expected_parser", "c"],
            des=[1, 2, 3],
        ),
    )
def main():
    """
    Entry point for the export script.

    Parses ``ExportArgs`` from the command line and prepares the model, save
    directory and validation loader via ``export_setup``. Then calls
    ``export``, which (per the parser description) writes the model to ONNX
    and stores sample inputs/outputs.
    """
    description = (
        "Utility script to export a model to onnx "
        "and also store sample inputs/outputs"
    )
    arg_parser = NmArgumentParser(dataclass_types=ExportArgs, description=description)

    # Exactly one dataclass type was registered, so exactly one result is expected.
    [args_] = arg_parser.parse_args_into_dataclasses()

    model, save_dir, val_loader = export_setup(args_)
    export(args_, model, val_loader, save_dir)
def test_parse_dict(self):
    """parse_dict must yield the same BasicExample as constructing it directly."""
    parser = NmArgumentParser(BasicExample)
    raw_values = dict(foo=12, bar=3.14, baz="42", flag=True)

    from_parser = parser.parse_dict(raw_values)[0]
    constructed = BasicExample(**raw_values)

    self.assertEqual(from_parser, constructed)
def test_basic(self):
    """
    BasicExample maps to required typed arguments plus an optional bool flag.

    Parsing a command line that omits --flag must convert foo/bar/baz to
    their declared types and leave flag at its False default.
    """
    parser = NmArgumentParser(BasicExample)

    expected = argparse.ArgumentParser()
    expected.add_argument("--foo", type=int, required=True)
    expected.add_argument("--bar", type=float, required=True)
    expected.add_argument("--baz", type=str, required=True)
    expected.add_argument(
        "--flag", type=string_to_bool, default=False, const=True, nargs="?"
    )
    self.argparsersEqual(parser, expected)

    args = ["--foo", "1", "--baz", "quux", "--bar", "0.5"]
    (example,) = parser.parse_args_into_dataclasses(args, look_for_args_file=False)

    # Check every parsed field, not just the flag default, so a broken
    # type conversion for foo/bar/baz is caught as well.
    self.assertEqual(example.foo, 1)
    self.assertEqual(example.bar, 0.5)
    self.assertEqual(example.baz, "quux")
    self.assertFalse(example.flag)
def test_with_default(self):
    """WithDefaultExample's defaults (and help text) carry over to the parser."""
    actual = NmArgumentParser(WithDefaultExample)

    reference = argparse.ArgumentParser()
    argument_specs = (
        ("--foo", dict(default=42, type=int)),
        ("--baz", dict(default="toto", type=str, help="help message")),
    )
    for option, options_kwargs in argument_specs:
        reference.add_argument(option, **options_kwargs)

    self.argparsersEqual(actual, reference)
def test_with_default_bool(self):
    """
    Bool handling for WithDefaultBoolExample.

    Covers a False-default flag (--foo), a True-default flag with a
    ``--no-baz`` negation, and an optional bool (--opt) that defaults to None.
    Each accepts a bare flag or an explicit "True"/"False" value.
    """
    parser = NmArgumentParser(WithDefaultBoolExample)

    expected = argparse.ArgumentParser()
    expected.add_argument(
        "--foo", type=string_to_bool, default=False, const=True, nargs="?"
    )
    expected.add_argument("--no-baz", action="store_false", dest="baz")
    expected.add_argument(
        "--baz", type=string_to_bool, default=True, const=True, nargs="?"
    )
    expected.add_argument("--opt", type=string_to_bool, default=None)
    self.argparsersEqual(parser, expected)

    # (argv, expected namespace fields), checked in order.
    scenarios = [
        ([], dict(foo=False, baz=True, opt=None)),
        (["--foo", "--no-baz"], dict(foo=True, baz=False, opt=None)),
        (["--foo", "--baz"], dict(foo=True, baz=True, opt=None)),
        (
            ["--foo", "True", "--baz", "True", "--opt", "True"],
            dict(foo=True, baz=True, opt=True),
        ),
        (
            ["--foo", "False", "--baz", "False", "--opt", "False"],
            dict(foo=False, baz=False, opt=False),
        ),
    ]
    for argv, fields in scenarios:
        self.assertEqual(parser.parse_args(argv), Namespace(**fields))
def test_with_required(self):
    """RequiredExample's list, str and enum-choice fields all become required arguments."""
    parser = NmArgumentParser(RequiredExample)

    expected = argparse.ArgumentParser()
    add = expected.add_argument
    add("--required-list", nargs="+", type=int, required=True)
    add("--required-str", type=str, required=True)
    add("--required-enum", type=str, choices=["titi", "toto"], required=True)

    self.argparsersEqual(parser, expected)
def test_with_list(self):
    """
    ListExample's list fields become nargs="+" arguments.

    Each field keeps its list default (possibly empty), and explicit values
    replace those defaults element-wise with the declared item type.
    """
    parser = NmArgumentParser(ListExample)

    expected = argparse.ArgumentParser()
    # (option, default, item type)
    list_options = [
        ("--foo-int", [], int),
        ("--bar-int", [1, 2, 3], int),
        ("--foo-str", ["Hallo", "Bonjour", "Hello"], str),
        ("--foo-float", [0.1, 0.2, 0.3], float),
    ]
    for option, default_value, item_type in list_options:
        expected.add_argument(option, nargs="+", default=default_value, type=item_type)
    self.argparsersEqual(parser, expected)

    untouched = parser.parse_args([])
    self.assertEqual(
        untouched,
        Namespace(
            foo_int=[],
            bar_int=[1, 2, 3],
            foo_str=["Hallo", "Bonjour", "Hello"],
            foo_float=[0.1, 0.2, 0.3],
        ),
    )

    command_line = (
        "--foo-int 1 --bar-int 2 3 --foo-str parser expected_parser "
        "c --foo-float 0.1 0.7"
    )
    overridden = parser.parse_args(command_line.split())
    self.assertEqual(
        overridden,
        Namespace(
            foo_int=[1],
            bar_int=[2, 3],
            foo_str=["parser", "expected_parser", "c"],
            foo_float=[0.1, 0.7],
        ),
    )
def test_with_enum(self):
    """
    EnumExample's enum field is a string argument limited to the choices.

    ``parse_args`` yields the raw string; ``parse_args_into_dataclasses``
    yields the matching BasicEnum member.
    """
    parser = NmArgumentParser(EnumExample)

    expected = argparse.ArgumentParser()
    expected.add_argument(
        "--foo", default="toto", choices=["titi", "toto"], type=str
    )
    self.argparsersEqual(parser, expected)

    # (argv, raw string in the namespace, enum member in the dataclass)
    scenarios = (
        ([], "toto", BasicEnum.toto),
        (["--foo", "titi"], "titi", BasicEnum.titi),
    )
    for argv, raw_value, member in scenarios:
        self.assertEqual(parser.parse_args(argv).foo, raw_value)
        self.assertEqual(parser.parse_args_into_dataclasses(argv)[0].foo, member)
def test_non_hyphenated(self):
    """A field name containing an underscore is accepted verbatim as ``--non_hyphenated``."""
    parser = NmArgumentParser(dataclass_types=NonHyphenatedExample)
    parsed = parser.parse_args(["--non_hyphenated", "1"])
    self.assertEqual(parsed, Namespace(non_hyphenated=1))