コード例 #1
0
ファイル: _flags_test.py プロジェクト: stjordanis/fancyflags
    def test_update_shared_dict_multi(self):
        """Checks that a MultiItemFlag mirrors each new list value into shared_dict.

        Two paths are exercised: assigning through the FlagValues item setter,
        and parsing an argv that repeats the dotted flag twice.
        """
        shared = {'a': {'b': ['value']}}
        registry = flags.FlagValues()
        multi_flag = _flags.MultiItemFlag(
            shared,
            ('a', 'b'),
            parser=flags.ArgumentParser(),
            serializer=flags.ArgumentSerializer(),
            name='a.b',
            default=['foo', 'bar'],
            help_string='help string')
        flags.DEFINE_flag(multi_flag, flag_values=registry)

        # Direct assignment via the registry must be reflected in `shared`.
        registry['a.b'].value = ['new', 'value']
        with self.subTest(name='setter'):
            self.assertEqual(shared, {'a': {'b': ['new', 'value']}})

        # The test expects both command-line occurrences to end up in the list.
        registry(('./program', '--a.b=override1', '--a.b=override2'))
        expected = {'a': {'b': ['override1', 'override2']}}
        with self.subTest(name='override_parse'):
            self.assertEqual(shared, expected)
コード例 #2
0
    def testFindConfigSpecified(self):
        """Checks that _FindConfigSpecified returns the argv index of the config flag.

        A dotted override such as '--test_config.3=0' must not count as
        specifying the config itself (-1 is expected). Any of the four direct
        spellings placed at some position must be reported at that position.
        """

        config_flag = config_flags._ConfigFlag(
            parser=flags.ArgumentParser(),
            serializer=None,
            name='test_config',
            default='defaultconfig.py',
            help_string='')
        self.assertEqual(config_flag._FindConfigSpecified(['']), -1)

        num_args = 20
        baseline = ['--test_config.{}=0'.format(n) for n in range(num_args)]
        spellings = (
            '--test_config',
            '--test_config=config.py',
            '-test_config',
            '-test_config=config.py',
        )
        for position in range(num_args):
            # Start every round from an untouched copy of the dotted-only argv.
            argv = list(baseline)
            self.assertEqual(config_flag._FindConfigSpecified(argv), -1)
            for spelling in spellings:
                argv[position] = spelling
                self.assertEqual(config_flag._FindConfigSpecified(argv),
                                 position)
コード例 #3
0
    def testConfigSpecified(self, config_argument):
        """Checks _IsConfigSpecified on a supplied argument and on a blank argv.

        `config_argument` is presumably provided by a parameterization
        decorator outside this view.
        """

        cfg_flag = config_flags._ConfigFlag(
            parser=flags.ArgumentParser(),
            serializer=None,
            name='test_config',
            default='defaultconfig.py',
            help_string='')
        # The given argument must be detected; an empty entry must not be.
        self.assertTrue(cfg_flag._IsConfigSpecified([config_argument]))
        self.assertFalse(cfg_flag._IsConfigSpecified(['']))
コード例 #4
0
ファイル: config_flags.py プロジェクト: NeoTim/ml_collections
# }
# The possible breaking changes are:
# - A Python 3 int could be a Python 2 long, which was not previously supported,
#   so we add support for long.
# - Only Python 2 str was supported (not unicode). Python 3 keeps that behavior
#   under its changed str semantics.
_FIELD_TYPE_TO_PARSER = {
    float: flags.FloatParser(),
    bool: flags.BooleanParser(),
    # Implementing a custom parser to override `Tuple` arguments.
    tuple: tuple_parser.TupleParser(),
}
# Register every integer and string type reported by `six` (e.g. `long` on
# Python 2). Generator expressions are used instead of module-level `for` loops
# so that no loop variable leaks into this module's namespace. Each type still
# gets its own, independent parser instance.
_FIELD_TYPE_TO_PARSER.update(
    (int_type, flags.IntegerParser()) for int_type in six.integer_types)
_FIELD_TYPE_TO_PARSER.update(
    (str_type, flags.ArgumentParser()) for str_type in six.string_types)
# `six.string_types` is `(basestring,)` on Python 2, which does not contain
# `str` itself, so `str` is registered explicitly.
_FIELD_TYPE_TO_PARSER[str] = flags.ArgumentParser()
# One separate ArgumentSerializer per field type that has a parser.
_FIELD_TYPE_TO_SERIALIZER = {
    field_type: flags.ArgumentSerializer()
    for field_type in _FIELD_TYPE_TO_PARSER
}


class UnsupportedOperationError(flags.Error):
    """`flags.Error` subclass that adds no behavior of its own.

    Presumably raised elsewhere in this module for unsupported operations.
    """


class FlagOrderError(flags.Error):
    """`flags.Error` subclass that adds no behavior of its own.

    Presumably raised elsewhere in this module when flags arrive in an
    unexpected order.
    """

コード例 #5
0
 def __init__(self, *args, **kwargs):
     """Initializes the flag with a generic argument parser and serializer.

     The parser and serializer are passed first; every other positional and
     keyword argument is forwarded unchanged to the base flag initializer.
     """
     string_parser = flags.ArgumentParser()
     string_serializer = flags.ArgumentSerializer()
     base = super(_StderrthresholdFlag, self)
     base.__init__(string_parser, string_serializer, *args, **kwargs)
コード例 #6
0
ファイル: _definitions.py プロジェクト: stjordanis/fancyflags
 def __init__(self, default, help_string):
     """Builds the item with a generic `flags.ArgumentParser` and serializer.

     Args:
       default: Default value forwarded to the base initializer.
       help_string: Help text forwarded to the base initializer.
     """
     super().__init__(
         default,
         help_string,
         flags.ArgumentParser(),
         flags.ArgumentSerializer(),
     )
コード例 #7
0
ファイル: _definitions.py プロジェクト: stjordanis/fancyflags
def DEFINE_dict(*args, **kwargs):  # pylint: disable=invalid-name
    """Defines a flat or nested dictionary flag.

    Usage example:

    ```python
    import fancyflags as ff

    ff.DEFINE_dict(
        "image_settings",
        mode=ff.String("pad", "Mode string field."),
        sizes=dict(
            width=ff.Integer(5, "Width."),
            height=ff.Integer(7, "Height."),
            scale=ff.Float(0.5, "Scale.")
        )
    )
    ```

    This creates a flag `FLAGS.image_settings`, with a default value of

    ```python
    {
        "mode": "pad",
        "sizes": {
            "width": 5,
            "height": 7,
            "scale": 0.5,
        }
    }
    ```

    Each item in the definition (e.g. ff.Integer(...)) corresponds to a flag
    that can be overridden from the command line using "dot" notation. For
    example, the following command overrides the `height` item in the nested
    dictionary defined above:

    ```
    python script_name.py -- --image_settings.sizes.height=10
    ```

    Args:
      *args: One or two positional arguments are expected:
          1. A string containing the root name for this flag. This must be set.
          2. Optionally, a `flags.FlagValues` object that will hold the Flags.
             If not set, the usual global `flags.FLAGS` object will be used.
      **kwargs: One or more keyword arguments, where the value is either an
        `ff.Item` such as `ff.String(...)` or `ff.Integer(...)` or a dict with
        the same constraints.

    Returns:
      A `FlagHolder` instance.

    Raises:
      ValueError: If no positional argument or no keyword argument is
        supplied, if more than two positional arguments are supplied, if the
        first positional argument is not a `str`, or if a second positional
        argument is supplied that is not a `flags.FlagValues` instance.
    """
    # Validate the call shape before touching any flag registry. The checks run
    # in a fixed order, so the first applicable problem is the one reported.
    if not args:
        raise ValueError(
            "Please supply one positional argument containing the "
            "top-level flag name for the dict.")

    if not kwargs:
        raise ValueError(
            "Please supply at least one keyword argument defining a flag.")

    if len(args) > 2:
        raise ValueError(
            "Please supply at most two positional arguments, the "
            "first containing the top-level flag name for the dict "
            "and, optionally and unusually, a second positional "
            "argument to override the flags.FlagValues instance to "
            "use.")

    flag_name = args[0]
    if not isinstance(flag_name, str):
        raise ValueError(
            "The first positional argument must be a string "
            "containing top-level flag name for the dict. Got a {}.".format(
                type(flag_name).__name__))

    if len(args) == 2:
        flag_values = args[1]
        if not isinstance(flag_values, flags.FlagValues):
            raise ValueError(
                "If supplying a second positional argument, this must "
                "be a flags.FlagValues instance. Got a {}. If you meant "
                "to define a flag, note these must be supplied as "
                "keyword arguments. ".format(type(flag_values).__name__))
    else:
        flag_values = flags.FLAGS

    # `define_flags` (defined elsewhere in this module) is expected to register
    # the dotted per-item flags and return the dict those flags write into.
    shared_dict = define_flags(flag_name, kwargs, flag_values=flag_values)

    # usage_logging: dict

    # TODO(b/177672282): Can we persuade pytype to correctly infer the type of the
    #                    flagholder's .value attribute?
    # We register a dummy flag that returns `shared_dict` as a value.
    return flags.DEFINE_flag(
        _flags.DictFlag(
            shared_dict,
            name=flag_name,
            default=shared_dict,
            parser=flags.ArgumentParser(),
            serializer=None,
            help_string="Unused help string."),
        flag_values=flag_values)
コード例 #8
0
ファイル: _definitions.py プロジェクト: stjordanis/fancyflags
 def __init__(self, default, help_string):
     """Initializes the item with a generic `flags.ArgumentParser`.

     Only a parser is passed to the base initializer. Any serializer comes
     from the base class's own defaults, which are not visible here.
     """
     string_parser = flags.ArgumentParser()
     super().__init__(default, help_string, string_parser)
コード例 #9
0
ファイル: _define_auto.py プロジェクト: deepmind/fancyflags
def DEFINE_auto(  # pylint: disable=invalid-name
    name: str,
    fn: _T,
    help_string: Optional[str] = None,
    flag_values: flags.FlagValues = flags.FLAGS,
) -> _flags.TypedFlagHolder[_T]:
    """Defines a flag for an `ff.auto`-compatible constructor or callable.

    Automatically defines a set of dotted `ff.Item` flags corresponding to the
    constructor arguments and their default values.

    Overriding the value of a dotted flag will update the arguments used to
    invoke `fn`. This flag's value returns a callable `fn` with these values as
    bound arguments.

    Example usage:

    ```python
    # Defined in, e.g., datasets library.

    @dataclasses.dataclass
    class DataSettings:
      dataset_name: str = 'mnist'
      split: str = 'train'
      batch_size: int = 128

    # In main script.
    # Exposes flags: --data.dataset_name --data.split and --data.batch_size.
    DATA_SETTINGS = ff.DEFINE_auto('data', datasets.DataSettings, 'Data config')

    def main(argv):
      del argv  # Unused.
      dataset = datasets.load(DATA_SETTINGS.value())
      # ...
    ```

    Args:
      name: The name for the top-level flag.
      fn: An `ff.auto`-compatible `Callable`.
      help_string: Optional help string for this flag. If not provided (or
        empty), this defaults to '{fn's module}.{fn's name}'.
      flag_values: An optional `flags.FlagValues` instance.

    Returns:
      A `_flags.TypedFlagHolder` wrapping the holder returned by
      `flags.DEFINE_flag` for the top-level flag.
    """
    arguments = _auto.auto(fn)
    # Define the individual dotted flags; `defaults` is what `define_flags`
    # returns and is handed to the holder flag below.
    defaults = _definitions.define_flags(name,
                                         arguments,
                                         flag_values=flag_values)
    # Any falsy help_string (None or '') falls back to the qualified fn name.
    help_string = help_string or f'{fn.__module__}.{fn.__name__}'
    # Define a holder flag that ties `fn` to the dotted flags' values.
    holder = flags.DEFINE_flag(
        flag=_flags.AutoFlag(fn,
                             defaults,
                             name=name,
                             default=None,
                             parser=flags.ArgumentParser(),
                             serializer=None,
                             help_string=help_string),
        flag_values=flag_values,
    )

    return _flags.TypedFlagHolder(holder)
コード例 #10
0
ファイル: config_flags.py プロジェクト: google/ml_collections
FLAGS = flags.FLAGS

# Old CamelCase names kept as aliases for backwards compatibility.
GetValue, GetType, SetValue = (
    config_path.get_value,
    config_path.get_type,
    config_path.set_value,
)

# Exclude this module from `FLAGS.find_module_defining_flag` lookups.
flags._helpers.disclaim_module_ids.add(id(sys.modules[__name__]))  # pylint: disable=protected-access

# Parser used for each supported config field type.
_FIELD_TYPE_TO_PARSER = {
    float: flags.FloatParser(),
    bool: flags.BooleanParser(),
    tuple: tuple_parser.TupleParser(),
    int: flags.IntegerParser(),
    str: flags.ArgumentParser(),
}

# A separate ArgumentSerializer instance for every type that has a parser.
_FIELD_TYPE_TO_SERIALIZER = {
    field_type: flags.ArgumentSerializer()
    for field_type in _FIELD_TYPE_TO_PARSER
}


class UnsupportedOperationError(flags.Error):
  """`flags.Error` subclass that adds no behavior of its own.

  Presumably raised elsewhere in this module for unsupported operations.
  """


class FlagOrderError(flags.Error):
  """`flags.Error` subclass that adds no behavior of its own.

  Presumably raised elsewhere in this module when flags arrive in an
  unexpected order.
  """
コード例 #11
0
 def __init__(
     self,
     default: Optional[str],
     help_string: Optional[str] = None,
 ):
     """Initializes the item with a generic `flags.ArgumentParser`.

     Args:
       default: Default value forwarded to the base initializer; may be None.
       help_string: Optional help text forwarded to the base initializer.
     """
     string_parser = flags.ArgumentParser()
     super().__init__(default, help_string, string_parser)