def test_custom_default_value(self):
    """A caller-supplied 'default' replaces learning_rate's built-in default."""
    arg_parser = argparse.ArgumentParser()
    net_args.add_training_settings_args(
        arg_parser,
        learning_rate={'default': 0.01},
        # num_epochs is excluded so an empty command line parses cleanly.
        excluded_args=['num_epochs'],
    )
    parsed = arg_parser.parse_args([])
    self.assertEqual(parsed.learning_rate, 0.01)
def test_custom_short_alias(self):
    """A caller-supplied 'short_alias' makes `-o` set the optimizer option."""
    arg_parser = argparse.ArgumentParser()
    net_args.add_training_settings_args(
        arg_parser,
        optimizer={'short_alias': 'o'},
        # Make num_epochs optional so only the aliased flag is supplied.
        num_epochs={'required': False},
    )
    parsed = arg_parser.parse_args(['-o', 'Adam'])
    self.assertEqual(parsed.optimizer, 'Adam')
def test_excluded_args(self):
    """Excluded training settings must not be registered on the parser.

    Every key of ``net_args.TRAIN_SETTINGS_ARGS`` except ``num_epochs`` is
    excluded. Passing ``--optimizer`` alongside ``--num_epochs`` must
    therefore be rejected. argparse rejects unknown arguments by exiting,
    i.e. raising ``SystemExit``.
    """
    excluded = set(net_args.TRAIN_SETTINGS_ARGS.keys())
    excluded.remove('num_epochs')
    parser = argparse.ArgumentParser()
    net_args.add_training_settings_args(parser, excluded_args=excluded)
    # msg= is reported only when SystemExit is NOT raised. This replaces an
    # unreachable-on-success self.fail() call inside the with-block.
    with self.assertRaises(
            SystemExit,
            msg='Failed to raise exception on unknown arg passed'):
        parser.parse_args(['--num_epochs', '10', '--optimizer', 'Adam'])
def test_custom_required_arg(self):
    """A caller-supplied ``'required': True`` makes loss_fn mandatory.

    Parsing an empty command line must therefore fail. argparse signals a
    missing required argument by exiting, i.e. raising ``SystemExit``.
    """
    parser = argparse.ArgumentParser()
    net_args.add_training_settings_args(parser, loss_fn={'required': True},
                                        excluded_args=['num_epochs'])
    # msg= is reported only when SystemExit is NOT raised. The parse result
    # is deliberately not bound, since it is never used.
    with self.assertRaises(
            SystemExit,
            msg='Failed to raise exception on required arg not passed'):
        parser.parse_args([])
def __init__(self):
    """Assemble the command-line parser used to train a network.

    Reads the default TensorBoard log directory from the "model_trainer"
    module config into ``self.default_logdir``. The finished parser and the
    dataset / item-type argument helpers are kept on the instance as
    ``self.parser``, ``self.dset_args`` and ``self.item_args``.
    """
    trainer_config = cutils.get_config_for_module("model_trainer")
    self.default_logdir = trainer_config['default']['logdir']
    arg_parser = argparse.ArgumentParser(
        description="Train network using provided dataset")

    # Input dataset: which dataset to read and how to carve out a test set.
    dataset_helper = dargs.DatasetArgs(input_aliases={
        'dataset name': 'name',
        'dataset directory': 'srcdir',
    })
    item_type_helper = dargs.ItemTypeArgs()
    input_kind = dargs.arg_type.INPUT
    dataset_group = arg_parser.add_argument_group(title="Input dataset")
    dataset_helper.add_dataset_arg_double(dataset_group, input_kind)
    item_type_helper.add_item_type_args(dataset_group, input_kind)
    dataset_group.add_argument(
        '--test_items_count',
        type=atypes.int_range(1),
        help=('Number of dataset items to include in the test set. '
              'Overrides test_items_fraction.'))
    dataset_group.add_argument(
        '--test_items_fraction',
        type=float,
        default=0.1,
        help=('Number of dataset items to include in the test set, '
              'expressed as a fraction.'))
    dataset_group.add_argument(
        '--split_mode',
        choices=net_cons.DATASET_SPLIT_MODES,
        required=True,
        help=('Method of splitting the test items subset from the '
              'input dataset.'))

    # Network to train and where its logs / saved model files go.
    network_group = arg_parser.add_argument_group(
        title="Network configuration")
    net_args.add_network_arg(network_group, short_alias='n')
    net_args.add_model_file_arg(network_group, short_alias='m')
    network_group.add_argument(
        '--tb_dir',
        default=self.default_logdir,
        help='directory to store training logs for tensorboard.')
    network_group.add_argument(
        '--save',
        action='store_true',
        help=('save the model after training. Model files are saved '
              'under tb_dir as net.network_name/net.network_name.tflearn.*'))

    # Training hyper-parameters; num_epochs is made optional with a default.
    training_group = arg_parser.add_argument_group(
        title="Training parameters")
    epochs_override = {'required': False, 'default': 11, 'short_alias': 'e'}
    net_args.add_training_settings_args(training_group,
                                        num_epochs=epochs_override)

    self.parser = arg_parser
    self.dset_args = dataset_helper
    self.item_args = item_type_helper
def __init__(self):
    """Build the command-line parser for k-fold cross-validation.

    Reads the default TensorBoard log directory from the "model_xvalidator"
    module config into ``self.default_logdir``. The finished parser and the
    dataset / item-type argument helpers are stored on the instance as
    ``self.parser``, ``self.dset_args`` and ``self.item_args``.
    """
    import logging  # function-local so this block does not depend on file imports

    self.default_logdir = cutils.get_config_for_module(
        "model_xvalidator")['default']['logdir']
    # Diagnostic only: emit at debug level instead of printing to stdout,
    # so constructing the parser does not add noise to normal CLI output.
    logging.getLogger(__name__).debug(
        "default logdir set to %s", self.default_logdir)
    parser = argparse.ArgumentParser(
        description="Perform Kfold cross-validation on a given neural "
                    "network with the given dataset.")

    # cross-validation settings
    group = parser.add_argument_group(title="Cross-validation parameters")
    group.add_argument('--num_crossvals', type=atypes.int_range(1),
                       required=True,
                       help='number of cross validations to perform')

    # network to train
    group = parser.add_argument_group(title="Network to use")
    net_args.add_network_arg(group, short_alias='n')
    net_args.add_model_file_arg(group, short_alias='m')

    # training parameters; num_epochs is made optional with a default
    group = parser.add_argument_group(title="Training parameters to use")
    net_args.add_training_settings_args(group, num_epochs={
        'required': False,
        'default': 11,
        'short_alias': 'e'
    })
    group.add_argument('--tb_dir', default=self.default_logdir,
                       help=('directory to store training logs for '
                             'tensorboard.'))

    # dataset input
    in_aliases = {'dataset name': 'name', 'dataset directory': 'srcdir'}
    dset_args = dargs.DatasetArgs(input_aliases=in_aliases)
    item_args = dargs.ItemTypeArgs()
    atype = dargs.arg_type.INPUT
    group = parser.add_argument_group(title="Input dataset")
    dset_args.add_dataset_arg_double(group, atype)
    item_args.add_item_type_args(group, atype)
    group.add_argument('--test_items_count', type=atypes.int_range(1),
                       help='number of dataset items to include in the '
                            'test set. Overrides test_items_fraction.')
    group.add_argument('--test_items_fraction', type=float, default=0.1,
                       help='number of dataset items to include in the '
                            'test set, expressed as a fraction.')

    self.parser = parser
    self.dset_args = dset_args
    self.item_args = item_args
def test_custom_type(self):
    """A caller-supplied 'type' changes how num_epochs is converted."""
    arg_parser = argparse.ArgumentParser()
    net_args.add_training_settings_args(arg_parser,
                                        num_epochs={'type': float})
    parsed = arg_parser.parse_args(['--num_epochs', '0.01'])
    # float('0.01') yields exactly the literal 0.01, so == is safe here.
    self.assertEqual(parsed.num_epochs, 0.01)