コード例 #1
0
def define_help_flags():
    """Registers the help-related flags exactly once.

    Defines HelpFlag, HelpshortFlag (an alias for --help), HelpfullFlag and
    HelpXMLFlag on the default flag registry. Repeated calls are no-ops: the
    module-level `_define_help_flags_called` marker records that registration
    has already happened.
    """
    global _define_help_flags_called  # Module-level idempotence marker.

    if _define_help_flags_called:
        return

    # Instantiate and register one class at a time so each flag is defined
    # immediately after it is constructed, in the original order.
    for help_flag_cls in (HelpFlag, HelpshortFlag, HelpfullFlag, HelpXMLFlag):
        flags.DEFINE_flag(help_flag_cls())
    _define_help_flags_called = True
コード例 #2
0
ファイル: _flags_test.py プロジェクト: stjordanis/fancyflags
    def test_update_shared_dict_multi(self):
        """Checks that a MultiItemFlag keeps its shared dict in sync with its value."""
        shared = {'a': {'b': ['value']}}
        fv = flags.FlagValues()

        item_flag = _flags.MultiItemFlag(
            shared,
            ('a', 'b'),
            parser=flags.ArgumentParser(),
            serializer=flags.ArgumentSerializer(),
            name='a.b',
            default=['foo', 'bar'],
            help_string='help string',
        )
        flags.DEFINE_flag(item_flag, flag_values=fv)

        # Assigning through the value setter must be reflected in `shared`.
        fv['a.b'].value = ['new', 'value']
        with self.subTest(name='setter'):
            self.assertEqual(shared, {'a': {'b': ['new', 'value']}})

        # Parsing two occurrences of --a.b must leave both values in `shared`.
        fv(('./program', '--a.b=override1', '--a.b=override2'))
        with self.subTest(name='override_parse'):
            self.assertEqual(shared, {'a': {'b': ['override1', 'override2']}})
コード例 #3
0
    def define(
        self,
        namespace: str,
        shared_dict,
        flag_values: flags.FlagValues,
    ) -> flags.FlagHolder[_T]:
        """Defines an item flag that also writes its parsed value into a shared dict.

        Args:
          namespace: A sequence of strings naming this flag; it is joined with
            `SEPARATOR` to form the flag name. For example, `("foo", "bar")`
            corresponds to a flag named `foo.bar`.
            NOTE(review): annotated as `str`, but a sequence of strings is what
            the join and the example expect -- confirm and fix the annotation.
          shared_dict: The dictionary shared with the top-level dict flag. When
            the flag created here is parsed, the parsed value is also stored in
            `shared_dict` under the (flat or nested) key given by `namespace`.
          flag_values: The `flags.FlagValues` instance to register the flag on.

        Returns:
          The `flags.FlagHolder` for the newly defined flag.
        """
        flag_name = SEPARATOR.join(namespace)

        # Without explicit help text, fall back to the dotted flag name.
        if self._help_string is None:
            help_text = flag_name
        else:
            help_text = self._help_string

        item = _flags.ItemFlag(
            shared_dict,
            namespace,
            parser=self._parser,
            serializer=self._serializer,
            name=flag_name,
            default=self.default,
            help_string=help_text,
        )
        return flags.DEFINE_flag(item, flag_values=flag_values)
コード例 #4
0
ファイル: _definitions.py プロジェクト: stjordanis/fancyflags
 def define(self, namespace, shared_dict, flag_values):
     """Defines a multi-valued item flag that also updates `shared_dict`.

     Args:
       namespace: A sequence of strings; joined with `SEPARATOR` it forms the
         flag name, and it also determines the (flat or nested) key under which
         the parsed value is written into `shared_dict`.
       shared_dict: The dictionary shared with the top-level dict flag.
       flag_values: The `flags.FlagValues` instance to register the flag on.

     Returns:
       The value returned by `flags.DEFINE_flag` (the new flag's holder).
       Previously this was discarded, so callers could not get the holder.
     """
     name = SEPARATOR.join(namespace)
     # Fall back to the dotted flag name as help text when none was supplied,
     # rather than registering the flag with `help_string=None`.
     help_string = name if self._help_string is None else self._help_string
     return flags.DEFINE_flag(_flags.MultiItemFlag(shared_dict,
                                                   namespace,
                                                   parser=self._parser,
                                                   serializer=self._serializer,
                                                   name=name,
                                                   default=self.default,
                                                   help_string=help_string),
                              flag_values=flag_values)
コード例 #5
0
 def define(
     self,
     namespace: str,
     shared_dict,
     flag_values,
 ) -> flags.FlagHolder[Iterable[_T]]:
     """Defines a multi-valued item flag that also writes into `shared_dict`.

     Args:
       namespace: Sequence of strings joined with `SEPARATOR` to build the
         flag name; it also selects the key written in `shared_dict`.
         NOTE(review): annotated as `str` though a sequence is expected.
       shared_dict: The dictionary shared with the top-level dict flag.
       flag_values: The `flags.FlagValues` instance to register the flag on.

     Returns:
       The `flags.FlagHolder` for the newly defined flag.
     """
     flag_name = SEPARATOR.join(namespace)
     help_text = self._help_string
     if help_text is None:
         # No explicit help text: use the dotted flag name instead.
         help_text = flag_name

     multi_item = _flags.MultiItemFlag(
         shared_dict,
         namespace,
         parser=self._parser,
         serializer=self._serializer,
         name=flag_name,
         default=self.default,
         help_string=help_text,
     )
     return flags.DEFINE_flag(multi_item, flag_values=flag_values)
コード例 #6
0
ファイル: config_flags.py プロジェクト: google/ml_collections
def DEFINE_config_dataclass(  # pylint: disable=invalid-name
    name: str,
    config: _T,
    help_string: str = 'Configuration object. Must be a dataclass.',
    flag_values: flags.FlagValues = FLAGS,
    parse_fn: Optional[Callable[[Any], _T]] = None,
    **kwargs,
) -> _TypedFlagHolder[_T]:
  """Defines a typed (dataclass) flag-overrideable configuration.

  Similar to `DEFINE_config_dict` except `config` should be an instance of a
  `dataclass`.

  Args:
    name: Flag name.
    config: A user-defined configuration object. Must be an *instance* of a
      class built via `dataclass` (not the class itself).
    help_string: Help string to display when --helpfull is called.
    flag_values: FlagValues instance used for parsing.
    parse_fn: Function that can parse provided flag value, when assigned
      via flag.value, or passed on command line. Default is to only allow
      to assign instances of this class.
    **kwargs: Optional keyword arguments passed to Flag constructor.

  Returns:
    A handle to the defined flag.

  Raises:
    ValueError: If `config` is not a dataclass instance.
  """
  if not dataclasses.is_dataclass(config):
    raise ValueError('Configuration object must be a `dataclass`.')
  # `dataclasses.is_dataclass` is also True for dataclass *classes*. Passing a
  # class would make `type(config)` below evaluate to `type`, so the parser
  # would not be built for the user's dataclass. Reject it explicitly.
  if isinstance(config, type):
    raise ValueError(
        'Configuration object must be a `dataclass` instance, not a class.')

  # Define the flag.
  parser = _DataclassParser(name=name, dataclass_type=type(config),
                            parse_fn=parse_fn)
  flag = _ConfigFlag(
      flag_values=flag_values,
      parser=parser,
      serializer=flags.ArgumentSerializer(),
      name=name,
      default=config,
      help_string=help_string,
      **kwargs)

  return _TypedFlagHolder(flag=flags.DEFINE_flag(flag, flag_values))
コード例 #7
0
ファイル: config_flags.py プロジェクト: NeoTim/ml_collections
def DEFINE_config_file(  # pylint: disable=g-bad-name
        name,
        default=None,
        help_string='path to config file.',
        flag_values=FLAGS,
        lock_config=True,
        **kwargs):
    r"""Defines flag for `ConfigDict` files compatible with absl flags.

    The flag's value should be a path to a valid python file which contains a
    function called `get_config()` that returns a python object specifying
    a configuration. After the flag is parsed, `FLAGS.name` will contain
    a reference to this object, optionally with some values overridden.

    During flags parsing, every flag of form `--name.([a-zA-Z0-9]+\.?)+=value`
    and `-name.([a-zA-Z0-9]+\.?)+ value` will be treated as an override of a
    specific field in the config object returned by this flag. Field is
    essentially a dot delimited path inside the object where each path element
    has to be either an attribute or a key existing in the config object.
    For example `--my_config.field1.field2=val` means "assign value val
    to the attribute (or key) `field2` inside value of the attribute (or key)
    `field1` inside the value of `my_config` object". If there are both
    attribute and key-based access with the same name, attribute is preferred.

    Typical usage example:

    `script.py`::

      from absl import flags
      from ml_collections.config_flags import config_flags

      FLAGS = flags.FLAGS
      config_flags.DEFINE_config_file('my_config')

      print(FLAGS.my_config)

    `config.py`::

      def get_config():
        return {
            'field1': 1,
            'field2': 'tom',
            'nested': {
                'field': 2.23,
            },
        }

    The following command::

      python script.py -- --my_config=config.py
                          --my_config.field1 8
                          --my_config.nested.field=2.1

    will print::

      {'field1': 8, 'field2': 'tom', 'nested': {'field': 2.1}}

    It is possible to parameterise the get_config function, allowing it to
    return a differently structured result for different occasions. This is
    particularly useful when setting up hyperparameter sweeps across various
    network architectures.

    `parameterised_config.py`::

      def get_config(config_string):
        possible_configs = {
            'mlp': {
                'constructor': 'snt.nets.MLP',
                'config': {
                    'output_sizes': (128, 128, 1),
                }
            },
            'lstm': {
                'constructor': 'snt.LSTM',
                'config': {
                    'hidden_size': 128,
                    'forget_bias': 1.0,
                }
            }
        }
        return possible_configs[config_string]

    If a colon is present in the command line override for the config file,
    everything to the right of the colon is passed into the get_config function.
    The following command lines will both function correctly::

      python script.py -- --my_config=parameterised_config.py:mlp
                          --my_config.config.output_sizes="(256,256,1)"


      python script.py -- --my_config=parameterised_config.py:lstm
                          --my_config.config.hidden_size=256

    The following will produce an error, as the hidden_size flag does not
    exist when the "mlp" config_string is provided::

      python script.py -- --my_config=parameterised_config.py:mlp
                          --my_config.config.hidden_size=256

    Args:
      name: Flag name. (It is the flag's *value*, the config file path, that
          may carry an extra argument for `get_config` after a colon.)
      default: Default value of the flag (default: None).
      help_string: Help string to display when --helpfull is called.
          (default: "path to config file.")
      flag_values: FlagValues instance used for parsing.
          (default: absl.flags.FLAGS)
      lock_config: If set to True, loaded config will be locked through calling
          .lock() method on its instance (if it exists). (default: True)
      **kwargs: Optional keyword arguments passed to Flag constructor.

    Returns:
      Whatever `flags.DEFINE_flag` returns for the new flag (a `FlagHolder`
      in recent absl versions). It was previously discarded.
    """
    parser = _ConfigFileParser(name=name, lock_config=lock_config)
    flag = _ConfigFlag(parser=parser,
                       serializer=flags.ArgumentSerializer(),
                       name=name,
                       default=default,
                       help_string=help_string,
                       flag_values=flag_values,
                       **kwargs)

    # Get the module name for the frame at depth 1 in the call stack, i.e. the
    # caller of this function. '__main__' is replaced with the script path
    # (sys.argv[0]). This must stay directly inside this function so that
    # depth 1 is the caller.
    module_name = sys._getframe(1).f_globals.get('__name__', None)  # pylint: disable=protected-access
    module_name = sys.argv[0] if module_name == '__main__' else module_name
    return flags.DEFINE_flag(flag, flag_values, module_name=module_name)
コード例 #8
0
ファイル: config_flags.py プロジェクト: NeoTim/ml_collections
def DEFINE_config_dict(  # pylint: disable=g-bad-name
        name,
        config,
        help_string='ConfigDict instance.',
        flag_values=FLAGS,
        lock_config=True,
        **kwargs):
    """Defines flag for inline `ConfigDict`s compatible with absl flags.

    Similar to `DEFINE_config_file` except the flag's value should be a
    `ConfigDict` instead of a path to a file containing a `ConfigDict`. After the
    flag is parsed, `FLAGS.name` will contain a reference to the `ConfigDict`,
    optionally with some values overridden.

    Typical usage example:

    `script.py`::

      from absl import flags

      import ml_collections
      from ml_collections.config_flags import config_flags


      config = ml_collections.ConfigDict({
          'field1': 1,
          'field2': 'tom',
          'nested': {
              'field': 2.23,
          }
      })


      FLAGS = flags.FLAGS
      config_flags.DEFINE_config_dict('my_config', config)
      ...

      print(FLAGS.my_config)

    The following command::

      python script.py -- --my_config.field1 8
                          --my_config.nested.field=2.1

    will print::

      field1: 8
      field2: tom
      nested: {field: 2.1}

    Args:
      name: Flag name.
      config: `ConfigDict` object.
      help_string: Help string to display when --helpfull is called.
          (default: "ConfigDict instance.")
      flag_values: FlagValues instance used for parsing.
          (default: absl.flags.FLAGS)
      lock_config: If set to True, loaded config will be locked through calling
          .lock() method on its instance (if it exists). (default: True)
      **kwargs: Optional keyword arguments passed to Flag constructor.

    Returns:
      Whatever `flags.DEFINE_flag` returns for the new flag (a `FlagHolder`
      in recent absl versions). It was previously discarded.

    Raises:
      TypeError: If `config` is not an `ml_collections.ConfigDict`.
    """
    if not isinstance(config, ml_collections.ConfigDict):
        raise TypeError('config should be a ConfigDict')
    parser = _InlineConfigParser(name=name, lock_config=lock_config)
    flag = _ConfigFlag(parser=parser,
                       serializer=flags.ArgumentSerializer(),
                       name=name,
                       default=config,
                       help_string=help_string,
                       flag_values=flag_values,
                       **kwargs)

    # Get the module name for the frame at depth 1 in the call stack, i.e. the
    # caller of this function. '__main__' is replaced with the script path
    # (sys.argv[0]). This must stay directly inside this function so that
    # depth 1 is the caller.
    module_name = sys._getframe(1).f_globals.get('__name__', None)  # pylint: disable=protected-access
    module_name = sys.argv[0] if module_name == '__main__' else module_name
    return flags.DEFINE_flag(flag, flag_values, module_name=module_name)
コード例 #9
0
                     False,
                     'Should only log to stderr?',
                     allow_override_cpp=True)
# Logging-related flags, registered at import time in this order.
flags.DEFINE_boolean('alsologtostderr',
                     False,
                     'also log to stderr?',
                     allow_override_cpp=True)
# Default log directory comes from the TEST_TMPDIR environment variable when
# set (presumably provided by a test runner), otherwise the empty string.
flags.DEFINE_string('log_dir',
                    os.getenv('TEST_TMPDIR', ''),
                    'directory to write logfiles into',
                    allow_override_cpp=True)
# Custom flag class for verbosity (short form -v). Per its help text, the -1
# default is changed to 0 after parsing if the flag was not supplied.
flags.DEFINE_flag(
    _VerbosityFlag(
        'verbosity',
        -1,
        'Logging verbosity level. Messages logged at this level or lower will '
        'be included. Set to 1 for debug logging. If the flag was not set or '
        'supplied, the value will be changed from the default of -1 (warning) to '
        '0 (info) after flags are parsed.',
        short_name='v',
        allow_hide_cpp=True))
# Per-logger levels as a CSV of `name:level` pairs; defaults to an empty dict.
flags.DEFINE_flag(
    _LoggerLevelsFlag(
        'logger_levels', {},
        'Specify log level of loggers. The format is a CSV list of '
        '`name:level`. Where `name` is the logger name used with '
        '`logging.getLogger()`, and `level` is a level name  (INFO, DEBUG, '
        'etc). e.g. `myapp.foo:INFO,other.logger:DEBUG`'))
flags.DEFINE_flag(
    _StderrthresholdFlag(
        'stderrthreshold',
        'fatal', 'log messages at this level, or more severe, to stderr in '
コード例 #10
0
                     False,
                     'Should only log to stderr?',
                     allow_override_cpp=True)
# Logging-related flags, registered at import time in this order.
flags.DEFINE_boolean('alsologtostderr',
                     False,
                     'also log to stderr?',
                     allow_override_cpp=True)
# Default log directory comes from the TEST_TMPDIR environment variable when
# set (presumably provided by a test runner), otherwise the empty string.
flags.DEFINE_string('log_dir',
                    os.getenv('TEST_TMPDIR', ''),
                    'directory to write logfiles into',
                    allow_override_cpp=True)
# Custom flag class for verbosity (short form -v). Per its help text, the -1
# default is changed to 0 after parsing if the flag was not supplied.
flags.DEFINE_flag(
    _VerbosityFlag(
        'verbosity',
        -1,
        'Logging verbosity level. Messages logged at this level or lower will '
        'be included. Set to 1 for debug logging. If the flag was not set or '
        'supplied, the value will be changed from the default of -1 (warning) to '
        '0 (info) after flags are parsed.',
        short_name='v',
        allow_hide_cpp=True))
# Minimum severity also copied to stderr; defaults to 'fatal'. The help text
# documents its interaction with --alsologtostderr and --verbosity.
flags.DEFINE_flag(
    _StderrthresholdFlag(
        'stderrthreshold',
        'fatal', 'log messages at this level, or more severe, to stderr in '
        'addition to the logfile.  Possible values are '
        "'debug', 'info', 'warning', 'error', and 'fatal'.  "
        'Obsoletes --alsologtostderr. Using --alsologtostderr '
        'cancels the effect of this flag. Please also note that '
        'this flag is subject to --verbosity and requires logfile '
        'not be stderr.',
        allow_hide_cpp=True))
コード例 #11
0
ファイル: _definitions.py プロジェクト: stjordanis/fancyflags
def DEFINE_dict(*args, **kwargs):  # pylint: disable=invalid-name
    """Defines a flat or nested dictionary flag.

    Usage example:

    ```python
    import fancyflags as ff

    ff.DEFINE_dict(
        "image_settings",
        mode=ff.String("pad", "Mode string field."),
        sizes=dict(
            width=ff.Integer(5, "Width."),
            height=ff.Integer(7, "Height."),
            scale=ff.Float(0.5, "Scale.")
        )
    )
    ```

    This creates a flag `FLAGS.image_settings`, with a default value of:

    ```python
    {
        "mode": "pad",
        "sizes": {
            "width": 5,
            "height": 7,
            "scale": 0.5,
        }
    }
    ```

    Each item in the definition (e.g. ff.Integer(...)) corresponds to a flag that
    can be overridden from the command line using "dot" notation. For example, the
    following command overrides the `height` item in the nested dictionary defined
    above:

    ```
    python script_name.py -- --image_settings.sizes.height=10
    ```

    Args:
      *args: One or two positional arguments are expected:
          1. A string containing the root name for this flag. This must be set.
          2. Optionally, a `flags.FlagValues` object that will hold the Flags.
             If not set, the usual global `flags.FLAGS` object will be used.
      **kwargs: One or more keyword arguments, where the value is either an
        `ff.Item` such as `ff.String(...)` or `ff.Integer(...)` or a dict with the
        same constraints.

    Returns:
      A `FlagHolder` instance.

    Raises:
      ValueError: If no positional argument, no keyword argument, or more than
        two positional arguments are given; if the first positional argument is
        not a string; or if a second positional argument is not a
        `flags.FlagValues` instance.
    """
    # The checks run in a fixed order, which decides which message a caller
    # sees when several problems are present at once.
    if not args:
        raise ValueError(
            "Please supply one positional argument containing the "
            "top-level flag name for the dict.")

    if not kwargs:
        raise ValueError(
            "Please supply at least one keyword argument defining a "
            "flag.")
    if len(args) > 2:
        raise ValueError(
            "Please supply at most two positional arguments, the "
            "first containing the top-level flag name for the dict "
            "and, optionally and unusually, a second positional "
            "argument to override the flags.FlagValues instance to "
            "use.")

    if not isinstance(args[0], str):
        raise ValueError(
            "The first positional argument must be a string "
            "containing top-level flag name for the dict. Got a {}.".format(
                type(args[0]).__name__))

    if len(args) == 2:
        if not isinstance(args[1], flags.FlagValues):
            raise ValueError(
                "If supplying a second positional argument, this must "
                "be a flags.FlagValues instance. Got a {}. If you meant "
                "to define a flag, note these must be supplied as "
                "keyword arguments. ".format(type(args[1]).__name__))
        flag_values = args[1]
    else:
        flag_values = flags.FLAGS

    flag_name = args[0]

    # Define one dotted flag per item; they all write into `shared_dict`.
    shared_dict = define_flags(flag_name, kwargs, flag_values=flag_values)

    # usage_logging: dict

    # TODO(b/177672282): Can we persuade pytype to correctly infer the type of the
    #                    flagholder's .value attribute?
    # We register a dummy flag that returns `shared_dict` as a value.
    return flags.DEFINE_flag(_flags.DictFlag(
        shared_dict,
        name=flag_name,
        default=shared_dict,
        parser=flags.ArgumentParser(),
        serializer=None,
        help_string="Unused help string."),
                             flag_values=flag_values)
コード例 #12
0
    def __init__(self):
        """Creates the 'helpfull' flag: default False, help text 'show full help'."""
        super().__init__(
            'helpfull',
            False,
            'show full help',
            allow_hide_cpp=True,
        )

    def parse(self, arg):
        """Prints the full usage text and exits when the flag parses as true.

        Args:
          arg: The raw command-line argument supplied for this flag.
        """
        # Convert the raw argument with the flag's own parser instead of
        # testing its truthiness directly: a non-empty token such as 'false'
        # is truthy as a string. NOTE(review): assumes the base class is an
        # absl flag providing `_parse` (the constructor call matches
        # BooleanFlag) -- confirm.
        if self._parse(arg):
            usage(writeto_stdout=True)
            # Exit only when help was requested. Previously sys.exit(1) ran
            # unconditionally, so a false value silently terminated the program.
            sys.exit(1)


# Command-line flags for this script, registered at import time.
flags.DEFINE_string('save_dir', 'scripts/templogs', 'Output directory')
flags.DEFINE_boolean('clear_save', False, 'Remove anything previously in the output directory')
# Multi-value flags: may be given several times; each defaults to [].
flags.DEFINE_multi_string('extra_wl', [], 'Path to the CSV containing extra workload info')
flags.DEFINE_multi_string('extra_mem', [], 'Path to the CSV containing extra mem info')
flags.DEFINE_string('force_preset', None, 'Force to use specific server config preset')
# Register the custom help flags defined in this file.
flags.DEFINE_flag(HelpFlag())
flags.DEFINE_flag(HelpfullFlag())


def usage(shorthelp=False, writeto_stdout=False, detailed_error=None,
          exitcode=None):
    """Writes __main__'s docstring to stderr with some help text.

    Args:
      shorthelp: bool, if True, prints only flags from the main module,
          rather than all flags.
      writeto_stdout: bool, if True, writes help message to stdout,
          rather than to stderr.
      detailed_error: str, additional detail about why usage info was presented.
      exitcode: optional integer, if set, exits with this status code after
          writing help.
    """
コード例 #13
0
ファイル: _define_auto.py プロジェクト: deepmind/fancyflags
def DEFINE_auto(  # pylint: disable=invalid-name
    name: str,
    fn: _T,
    help_string: Optional[str] = None,
    flag_values: flags.FlagValues = flags.FLAGS,
) -> _flags.TypedFlagHolder[_T]:
    """Defines a flag for an `ff.auto`-compatible constructor or callable.

    Automatically defines a set of dotted `ff.Item` flags corresponding to the
    constructor arguments and their default values.

    Overriding the value of a dotted flag will update the arguments used to
    invoke `fn`. This flag's value returns a callable `fn` with these values as
    bound arguments.

    Example usage:

    ```python
    # Defined in, e.g., datasets library.

    @dataclasses.dataclass
    class DataSettings:
      dataset_name: str = 'mnist'
      split: str = 'train'
      batch_size: int = 128

    # In main script.
    # Exposes flags: --data.dataset_name --data.split and --data.batch_size.
    DATA_SETTINGS = ff.DEFINE_auto('data', datasets.DataSettings, 'Data config')

    def main(argv):
      del argv  # Unused.
      dataset = datasets.load(DATA_SETTINGS.value())
      # ...
    ```

    Args:
      name: The name for the top-level flag.
      fn: An `ff.auto`-compatible `Callable`.
      help_string: Optional help string for this flag. If not provided (or
        empty), this defaults to '{fn's module}.{fn's name}'.
      flag_values: An optional `flags.FlagValues` instance.

    Returns:
      A `_flags.TypedFlagHolder` wrapping the holder of the top-level flag.
    """
    # Step 1: derive the per-argument items from `fn` and define one dotted
    # flag for each of them.
    auto_arguments = _auto.auto(fn)
    bound_defaults = _definitions.define_flags(
        name, auto_arguments, flag_values=flag_values)

    if not help_string:
        help_string = f'{fn.__module__}.{fn.__name__}'

    # Step 2: define the top-level holder flag for `fn`, passing it the
    # result of the per-argument definitions.
    auto_flag = _flags.AutoFlag(
        fn,
        bound_defaults,
        name=name,
        default=None,
        parser=flags.ArgumentParser(),
        serializer=None,
        help_string=help_string,
    )
    holder = flags.DEFINE_flag(flag=auto_flag, flag_values=flag_values)
    return _flags.TypedFlagHolder(holder)