def test_custom_default_value(self):
    """A custom 'default' for learning_rate is used when the arg is absent."""
    parser = argparse.ArgumentParser()
    net_args.add_training_settings_args(
        parser, learning_rate={'default': 0.01},
        excluded_args=['num_epochs'])
    args = parser.parse_args([])
    self.assertEqual(args.learning_rate, 0.01)
 def test_custom_short_alias(self):
     """A custom 'short_alias' registers the corresponding -X option."""
     parser = argparse.ArgumentParser()
     net_args.add_training_settings_args(
         parser, optimizer={'short_alias': 'o'},
         num_epochs={'required': False})
     args = parser.parse_args(['-o', 'Adam'])
     self.assertEqual(args.optimizer, 'Adam')
 def test_excluded_args(self):
     """Args listed in excluded_args must not be added to the parser.

     Excludes every training-settings arg except num_epochs, then checks
     that passing one of the excluded options (--optimizer) makes argparse
     exit with an error (SystemExit).
     """
     excluded = set(net_args.TRAIN_SETTINGS_ARGS.keys())
     excluded.remove('num_epochs')
     parser = argparse.ArgumentParser()
     net_args.add_training_settings_args(parser, excluded_args=excluded)
     # assertRaises itself fails the test when SystemExit is not raised.
     # The previous explicit self.fail() inside the 'with' block was dead
     # code: it could only run if parse_args did NOT raise, in which case
     # the context manager already reports the failure.
     with self.assertRaises(SystemExit):
         parser.parse_args(['--num_epochs', '10', '--optimizer', 'Adam'])
 def test_custom_required_arg(self):
     """Marking an arg 'required' makes the parser reject an empty argv.

     loss_fn is flagged required; parsing no arguments must make argparse
     exit with an error (SystemExit).
     """
     parser = argparse.ArgumentParser()
     net_args.add_training_settings_args(parser, loss_fn={'required': True},
                                         excluded_args=['num_epochs'])
     # assertRaises itself fails the test when SystemExit is not raised.
     # The previous explicit self.fail() inside the 'with' block was dead
     # code: it could only run if parse_args did NOT raise, in which case
     # the context manager already reports the failure.
     with self.assertRaises(SystemExit):
         parser.parse_args([])
    def __init__(self):
        """Build the command-line parser for the model trainer script.

        Exposes the assembled argparse parser as ``self.parser`` and the
        dataset/item-type argument helpers as ``self.dset_args`` and
        ``self.item_args`` for later use by the caller.
        """
        self.default_logdir = cutils.get_config_for_module(
            "model_trainer")['default']['logdir']
        parser = argparse.ArgumentParser(
            description="Train network using provided dataset")

        # dataset input
        aliases = {'dataset name': 'name', 'dataset directory': 'srcdir'}
        dataset_args = dargs.DatasetArgs(input_aliases=aliases)
        itemtype_args = dargs.ItemTypeArgs()
        input_type = dargs.arg_type.INPUT
        in_group = parser.add_argument_group(title="Input dataset")
        dataset_args.add_dataset_arg_double(in_group, input_type)
        itemtype_args.add_item_type_args(in_group, input_type)
        in_group.add_argument('--test_items_count',
                              type=atypes.int_range(1),
                              help='Number of dataset items to include in the '
                              'test set. Overrides test_items_fraction.')
        in_group.add_argument('--test_items_fraction',
                              type=float,
                              default=0.1,
                              help='Number of dataset items to include in the '
                              'test set, expressed as a fraction.')
        in_group.add_argument('--split_mode',
                              choices=net_cons.DATASET_SPLIT_MODES,
                              required=True,
                              help='Method of splitting the test items subset '
                              'from the input dataset.')

        # network to train
        net_group = parser.add_argument_group(title="Network configuration")
        net_args.add_network_arg(net_group, short_alias='n')
        net_args.add_model_file_arg(net_group, short_alias='m')
        net_group.add_argument('--tb_dir',
                               default=self.default_logdir,
                               help=('directory to store training logs for '
                                     'tensorboard.'))
        net_group.add_argument('--save',
                               action='store_true',
                               help=('save the model after training. Model files '
                                     'are saved under tb_dir as net.network_name/'
                                     'net.network_name.tflearn.*'))

        # training settings
        train_group = parser.add_argument_group(title="Training parameters")
        epochs_settings = {'required': False, 'default': 11, 'short_alias': 'e'}
        net_args.add_training_settings_args(train_group,
                                            num_epochs=epochs_settings)

        self.parser = parser
        self.dset_args = dataset_args
        self.item_args = itemtype_args
# Example #6
    def __init__(self):
        """Build the command-line parser for the Kfold cross-validation script.

        Exposes the assembled argparse parser as ``self.parser`` and the
        dataset/item-type argument helpers as ``self.dset_args`` and
        ``self.item_args`` for later use by the caller.
        """
        self.default_logdir = cutils.get_config_for_module(
            "model_xvalidator")['default']['logdir']
        print("default logdir set to {}".format(self.default_logdir))
        parser = argparse.ArgumentParser(
            description="Perform Kfold cross-validation on a given neural "
            "network with the given dataset.")

        # cross-validation settings
        xval_group = parser.add_argument_group(
            title="Cross-validation parameters")
        xval_group.add_argument('--num_crossvals',
                                type=atypes.int_range(1),
                                required=True,
                                help='number of cross validations to perform')

        # network to train
        net_group = parser.add_argument_group(title="Network to use")
        net_args.add_network_arg(net_group, short_alias='n')
        net_args.add_model_file_arg(net_group, short_alias='m')

        # training_parameters
        train_group = parser.add_argument_group(
            title="Training parameters to use")
        epochs_settings = {'required': False, 'default': 11, 'short_alias': 'e'}
        net_args.add_training_settings_args(train_group,
                                            num_epochs=epochs_settings)
        train_group.add_argument('--tb_dir',
                                 default=self.default_logdir,
                                 help=('directory to store training logs for '
                                       'tensorboard.'))

        # dataset input
        aliases = {'dataset name': 'name', 'dataset directory': 'srcdir'}
        dataset_args = dargs.DatasetArgs(input_aliases=aliases)
        itemtype_args = dargs.ItemTypeArgs()
        input_type = dargs.arg_type.INPUT
        in_group = parser.add_argument_group(title="Input dataset")
        dataset_args.add_dataset_arg_double(in_group, input_type)
        itemtype_args.add_item_type_args(in_group, input_type)
        in_group.add_argument('--test_items_count',
                              type=atypes.int_range(1),
                              help='number of dataset items to include in the '
                              'test set. Overrides test_items_fraction.')
        in_group.add_argument('--test_items_fraction',
                              type=float,
                              default=0.1,
                              help='number of dataset items to include in the '
                              'test set, expressed as a fraction.')

        self.parser = parser
        self.dset_args = dataset_args
        self.item_args = itemtype_args
 def test_custom_type(self):
     """A custom 'type' converter is applied to the parsed value."""
     parser = argparse.ArgumentParser()
     net_args.add_training_settings_args(parser, num_epochs={'type': float})
     args = parser.parse_args(['--num_epochs', '0.01'])
     self.assertEqual(args.num_epochs, 0.01)