Ejemplo n.º 1
0
def DEFINE_output_path(
    name: str,
    default: Union[None, str, pathlib.Path],
    help: str,
    required: bool = False,
    is_dir: bool = False,
    exist_ok: bool = True,
    must_exist: bool = False,
    validator: Union[None, Callable[[pathlib.Path], bool]] = None,
):
    """Registers a flag whose value is an output path.

    An "output path" is a path to a file or directory that may or may not already
    exist. The parsed value is a pathlib.Path instance. The idea is that this flag
    can be used to specify paths to files or directories that will be created
    during program execution. However, note that specifying an output path does
    not guarantee that the file will be produced.

    Args:
      name: The name of the flag.
      default: The default value for the flag. May be None; pass required=True
        if the user must always supply a value on the command line.
      help: The help string.
      required: If True, the flag is marked as required via
        absl_flags.mark_flag_as_required, so flag parsing fails if the flag's
        value is still None afterwards.
      is_dir: If true, require that the value be a directory. Else, require
        that the value be a file. Parsing will fail if the path already exists
        and is of the incorrect type.
      exist_ok: If False, require that the path not exist, else parsing will fail.
      must_exist: If True, require that the path exists, else parsing will fail.
      validator: Optional callable registered with RegisterFlagValidator when
        given. Presumably it is called with the parsed pathlib.Path and should
        return a bool; confirm against RegisterFlagValidator.
    """
    parser = flags_parsers.PathParser(
        must_exist=must_exist,
        exist_ok=exist_ok,
        is_dir=is_dir,
    )
    serializer = absl_flags.ArgumentSerializer()
    absl_flags.DEFINE(
        parser,
        name,
        default,
        help,
        absl_flags.FLAGS,
        serializer,
        # Presumably the caller's module, so the flag is attributed to it rather
        # than to this helper's module.
        module_name=get_calling_module_name(),
    )
    if required:
        absl_flags.mark_flag_as_required(name)
    if validator:
        RegisterFlagValidator(name, validator)
Ejemplo n.º 2
0
def DEFINE_sequence(  # pylint: disable=invalid-name,redefined-builtin
    name: str,
    default: Optional[Iterable[_T]],
    help: str,
    flag_values=flags.FLAGS,
    **args,
) -> flags.FlagHolder[Iterable[_T]]:
    """Defines a flag holding a list or tuple of simple values.

    Thin wrapper around `flags.DEFINE` that pairs a `SequenceParser` with the
    default `flags.ArgumentSerializer`. See the `Sequence` docs for details.

    Args:
      name: Flag name.
      default: Default sequence value, or None.
      help: Help string shown for the flag.
      flag_values: The `FlagValues` registry the flag is defined in.
      **args: Extra keyword arguments forwarded unchanged to `flags.DEFINE`.

    Returns:
      The `FlagHolder` returned by `flags.DEFINE`.
    """
    # usage_logging: sequence
    return flags.DEFINE(
        _argument_parsers.SequenceParser(),
        name,
        default,
        help,
        flag_values,
        flags.ArgumentSerializer(),
        **args,
    )
Ejemplo n.º 3
0
def DEFINE_config_dataclass(  # pylint: disable=invalid-name
    name: str,
    config: _T,
    help_string: str = 'Configuration object. Must be a dataclass.',
    flag_values: flags.FlagValues = FLAGS,
    parse_fn: Optional[Callable[[Any], _T]] = None,
    **kwargs,
) -> _TypedFlagHolder[_T]:
  """Defines a typed (dataclass) flag-overrideable configuration.

  Similar to `DEFINE_config_dict` except `config` should be a `dataclass`
  instance.

  Args:
    name: Flag name.
    config: A user-defined configuration object. Must be an *instance* of a
      class built via `dataclass`. It becomes the flag's default, and its type
      is handed to the parser as `dataclass_type`.
    help_string: Help string to display when --helpfull is called.
    flag_values: FlagValues instance used for parsing.
    parse_fn: Function that can parse provided flag value, when assigned
      via flag.value, or passed on command line. Default is to only allow
      to assign instances of this class.
    **kwargs: Optional keyword arguments passed to Flag constructor.

  Returns:
    A handle to the defined flag.

  Raises:
    ValueError: If `config` is not a dataclass instance. A dataclass *type*
      (the class itself) is rejected as well.
  """
  # `dataclasses.is_dataclass` is also True for the dataclass type itself.
  # Accepting the class would make `type(config)` below evaluate to `type`,
  # so the parser would be keyed on `type` instead of the config's own class.
  if not dataclasses.is_dataclass(config) or isinstance(config, type):
    raise ValueError('Configuration object must be a `dataclass`.')

  # Define the flag.
  parser = _DataclassParser(name=name, dataclass_type=type(config),
                            parse_fn=parse_fn)
  flag = _ConfigFlag(
      flag_values=flag_values,
      parser=parser,
      serializer=flags.ArgumentSerializer(),
      name=name,
      default=config,
      help_string=help_string,
      **kwargs)

  return _TypedFlagHolder(flag=flags.DEFINE_flag(flag, flag_values))
Ejemplo n.º 4
0
def DEFINE_multi_enum(  # pylint: disable=invalid-name,redefined-builtin
    name: str,
    default: Optional[Iterable[_T]],
    enum_values: Iterable[_T],
    help: str,
    flag_values=flags.FLAGS,
    **args,
) -> flags.FlagHolder[Iterable[_T]]:
    """Defines a flag whose value is several values drawn from `enum_values`.

    Args:
      name: Flag name.
      default: Default values, or None.
      enum_values: The allowed values; passed to `MultiEnumParser`.
      help: Help string shown for the flag.
      flag_values: The `FlagValues` registry the flag is defined in.
      **args: Extra keyword arguments forwarded unchanged to `flags.DEFINE`.

    Returns:
      The `FlagHolder` returned by `flags.DEFINE`. The holder's value has the
      same shape as `default` (an iterable of `_T`), so it is annotated
      `FlagHolder[Iterable[_T]]` rather than `FlagHolder[_T]`.
    """
    parser = _argument_parsers.MultiEnumParser(enum_values)
    serializer = flags.ArgumentSerializer()
    # usage_logging: multi_enum
    return flags.DEFINE(
        parser,
        name,
        default,
        help,
        flag_values,
        serializer,
        **args,
    )
Ejemplo n.º 5
0
def DEFINE_database(
    name: str,
    database_class,
    default: Optional[str],
    help: str,
    must_exist: bool = False,
    validator: Optional[Callable[[Any], bool]] = None,
):
    """Registers a flag whose value is a sqlutil.Database class.

    Unlike other DEFINE_* functions, the value produced by this flag is not an
    instance of the value, but a lambda that will instantiate a database of the
    requested type. This flag value must be called (with no arguments) in order
    to instantiate a database.

    Args:
      name: The name of the flag.
      database_class: The subclass of sqlutil.Database which is to be
        instantiated when this value is called, using the URL declared in
        'default'.
      default: The default URL of the database.
      help: The help string.
      must_exist: If True, require that the database exists. Else, the database
        is created if it does not exist.
      validator: Optional callable registered with RegisterFlagValidator when
        given; presumably it receives the parsed flag value and returns a bool.
    """
    parser = flags_parsers.DatabaseParser(database_class,
                                          must_exist=must_exist)
    serializer = absl_flags.ArgumentSerializer()
    absl_flags.DEFINE(
        parser,
        name,
        default,
        help,
        absl_flags.FLAGS,
        serializer,
        # Presumably the caller's module, so the flag is attributed to it.
        module_name=get_calling_module_name(),
    )
    if validator:
        RegisterFlagValidator(name, validator)
Ejemplo n.º 6
0
    def __init__(
        self,
        default: Optional[_T],
        help_string: str,
        parser: flags.ArgumentParser,
        serializer: Optional[flags.ArgumentSerializer] = None,
    ):
        """Initializes a new `Item`.

        Args:
          default: Default value for the flag this instance will create. A
            non-None default is run through `parser.parse` immediately.
          help_string: Help string for the flag this instance will create. If
            `None`, then the dotted flag name will be used as the help string.
          parser: A `flags.ArgumentParser` used to parse command line input.
          serializer: An optional custom `flags.ArgumentSerializer`; a fresh
            base `flags.ArgumentSerializer()` is used when omitted.
        """
        # absl's Flag._set_default (absl/flags/_flag.py) parses the default the
        # same way. Doing it here as well surfaces a bad default when the Item is
        # built rather than later, when define() is called. Flag._set_default
        # goes through Flag._parse, which additionally rewraps the exception type.
        self.default = None if default is None else parser.parse(default)  # pytype: disable=wrong-arg-types
        self._help_string = help_string
        self._parser = parser
        self._serializer = (
            flags.ArgumentSerializer() if serializer is None else serializer)
Ejemplo n.º 7
0
    def __init__(self, default, help_string, parser, serializer=None):
        """Stores a list-valued default, parsing each element with `parser`.

        Args:
          default: None, a single value, or an iterable of values. str and
            bytes count as single values; any other iterable is expanded into
            a new list.
          help_string: Help text for the flag.
          parser: Parser applied to every element of the default.
          serializer: Optional serializer; a fresh `flags.ArgumentSerializer()`
            is used when omitted.
        """
        if default is not None:
            treat_as_single = (
                isinstance(default, (str, bytes))
                or not isinstance(default, collections.abc.Iterable))
            items = [default] if treat_as_single else list(default)
            # Validate every element eagerly so bad defaults fail at construction.
            self.default = [parser.parse(item) for item in items]
        else:
            self.default = None

        self._help_string = help_string
        self._parser = parser
        self._serializer = (
            flags.ArgumentSerializer() if serializer is None else serializer)
Ejemplo n.º 8
0
    def test_update_shared_dict(self):
        # Both assigning to the flag and parsing argv must write the new value
        # through to the shared dict at the ('a', 'b') namespace.
        shared = {'a': {'b': 'value'}}
        path = ('a', 'b')
        registry = flags.FlagValues()

        item_flag = _flags.ItemFlag(
            shared,
            path,
            parser=flags.ArgumentParser(),
            serializer=flags.ArgumentSerializer(),
            name='a.b',
            default='bar',
            help_string='help string')
        flags.DEFINE_flag(item_flag, flag_values=registry)

        registry['a.b'].value = 'new_value'
        with self.subTest(name='setter'):
            self.assertEqual(shared, {'a': {'b': 'new_value'}})

        registry(('./program', '--a.b=override'))
        with self.subTest(name='override_parse'):
            self.assertEqual(shared, {'a': {'b': 'override'}})
Ejemplo n.º 9
0
def DEFINE_enum(
    name: str,
    enum_class,
    default,
    help: str,
    validator: Callable[[Any], bool] = None,
):
    """Registers a flag whose value is parsed with an EnumParser for enum_class.

    The parsed value is whatever flags_parsers.EnumParser(enum_class) produces.
    NOTE(review): presumably a member of `enum_class`; confirm against
    EnumParser.

    Args:
      name: The name of the flag.
      enum_class: The enum.Enum subclass whose members the flag accepts.
      default: The default value of the enum. Either the string name or an enum
        value.
      help: The help string.
      validator: Optional callable registered with RegisterFlagValidator when
        given; presumably it receives the parsed flag value and returns a bool.
    """
    parser = flags_parsers.EnumParser(enum_class)
    serializer = absl_flags.ArgumentSerializer()
    absl_flags.DEFINE(
        parser,
        name,
        default,
        help,
        absl_flags.FLAGS,
        serializer,
        # Presumably the caller's module, so the flag is attributed to it.
        module_name=get_calling_module_name(),
    )
    if validator:
        RegisterFlagValidator(name, validator)
Ejemplo n.º 10
0
def DEFINE_input_path(
    name: str,
    default: Union[None, str, pathlib.Path],
    help: str,
    required: bool = False,
    is_dir: bool = False,
    validator: Union[None, Callable[[pathlib.Path], bool]] = None,
):
    """Registers a flag whose value is an input path.

    An "input path" is a path to a file or directory that exists. The parsed
    value is a pathlib.Path instance. The parser is always built with
    must_exist=True, so flag parsing fails if the value of this flag is not a
    path to an existing file or directory.

    Args:
      name: The name of the flag.
      default: The default value for the flag. May be None; pass required=True
        if the user must always supply a value on the command line.
      help: The help string.
      required: If True, the flag is marked as required via
        absl_flags.mark_flag_as_required, so flag parsing fails if the flag's
        value is still None afterwards.
      is_dir: If true, require that the value be a directory. Else, require
        that the value be a file. Parsing will fail if this is not the case.
      validator: Optional callable registered with RegisterFlagValidator when
        given. Presumably it is called with the parsed pathlib.Path and should
        return a bool; confirm against RegisterFlagValidator.
    """
    parser = flags_parsers.PathParser(must_exist=True, is_dir=is_dir)
    serializer = absl_flags.ArgumentSerializer()
    absl_flags.DEFINE(
        parser,
        name,
        default,
        help,
        absl_flags.FLAGS,
        serializer,
        # Presumably the caller's module, so the flag is attributed to it.
        module_name=get_calling_module_name(),
    )
    if required:
        absl_flags.mark_flag_as_required(name)
    if validator:
        RegisterFlagValidator(name, validator)
Ejemplo n.º 11
0
#   We then add support for long.
# - Only Python 2 str were supported (not unicode). Python 3 will behave the
#   same with the str semantic change.
# Maps a config field's Python type to the absl parser used for overrides of
# that field.
_FIELD_TYPE_TO_PARSER = {
    float: flags.FloatParser(),
    bool: flags.BooleanParser(),
    # Implementing a custom parser to override `Tuple` arguments.
    tuple: tuple_parser.TupleParser(),
}
# six.integer_types is (int, long) on Python 2 and (int,) on Python 3.
for t in six.integer_types:
    _FIELD_TYPE_TO_PARSER[t] = flags.IntegerParser()
# six.string_types is (basestring,) on Python 2 and (str,) on Python 3.
for t in six.string_types:
    _FIELD_TYPE_TO_PARSER[t] = flags.ArgumentParser()
# Register `str` explicitly: on Python 2 it is not itself in six.string_types;
# on Python 3 this just replaces the entry added by the loop above.
# NOTE: the loop variable `t` is left bound at module scope.
_FIELD_TYPE_TO_PARSER[str] = flags.ArgumentParser()
# Every supported field type gets its own plain absl ArgumentSerializer.
_FIELD_TYPE_TO_SERIALIZER = {
    t: flags.ArgumentSerializer()
    for t in _FIELD_TYPE_TO_PARSER
}


class UnsupportedOperationError(flags.Error):
    """`flags.Error` subclass signalling an unsupported operation."""


class FlagOrderError(flags.Error):
    """`flags.Error` subclass signalling a flag-ordering problem."""


class UnparsedFlagError(flags.Error):
    """`flags.Error` subclass signalling use of a flag that is not yet parsed."""
 def __init__(self, *args, **kwargs):
     # Always use absl's base ArgumentParser/ArgumentSerializer pair; every
     # other argument is forwarded unchanged to the parent flag class.
     base_parser = flags.ArgumentParser()
     base_serializer = flags.ArgumentSerializer()
     super(_StderrthresholdFlag, self).__init__(base_parser, base_serializer,
                                                *args, **kwargs)
Ejemplo n.º 13
0
 def __init__(self, default, help_string):
     # Plain-string item: absl's base parser and serializer.
     super().__init__(default, help_string, flags.ArgumentParser(),
                      flags.ArgumentSerializer())
Ejemplo n.º 14
0
 def __init__(self, default, enum_values, help_string):
     enum_parser = _argument_parsers.MultiEnumParser(enum_values)
     # Run enum_values through the parser once; the result is discarded, so
     # this presumably only checks that the parser accepts its own values.
     enum_parser.parse(enum_values)
     super().__init__(default, help_string, enum_parser,
                      flags.ArgumentSerializer())
Ejemplo n.º 15
0
def DEFINE_config_dict(  # pylint: disable=g-bad-name
        name,
        config,
        help_string='ConfigDict instance.',
        flag_values=FLAGS,
        lock_config=True,
        **kwargs):
    """Defines flag for inline `ConfigDict's` compatible with absl flags.

    Similar to `DEFINE_config_file` except the flag's value should be a
    `ConfigDict` instead of a path to a file containing a `ConfigDict`. After the
    flag is parsed, `FLAGS.name` will contain a reference to the `ConfigDict`,
    optionally with some values overridden.

    Typical usage example:

    `script.py`::

      from absl import flags

      import ml_collections
      from ml_collections.config_flags import config_flags


      config = ml_collections.ConfigDict({
          'field1': 1,
          'field2': 'tom',
          'nested': {
              'field': 2.23,
          }
      })


      FLAGS = flags.FLAGS
      config_flags.DEFINE_config_dict('my_config', config)
      ...

      print(FLAGS.my_config)

    The following command::

      python script.py -- --my_config.field1 8
                          --my_config.nested.field=2.1

    will print::

      field1: 8
      field2: tom
      nested: {field: 2.1}

    Args:
      name: Flag name.
      config: `ConfigDict` object.
      help_string: Help string to display when --helpfull is called.
          (default: "ConfigDict instance.")
      flag_values: FlagValues instance used for parsing.
          (default: absl.flags.FLAGS)
      lock_config: If set to True, loaded config will be locked through calling
          .lock() method on its instance (if it exists). (default: True)
      **kwargs: Optional keyword arguments passed to Flag constructor.

    Returns:
      Whatever `flags.DEFINE_flag` returns for the new flag (a `FlagHolder` in
      absl-py versions that provide one). Callers that ignore the return value
      are unaffected.

    Raises:
      TypeError: If `config` is not an `ml_collections.ConfigDict`.
    """
    if not isinstance(config, ml_collections.ConfigDict):
        raise TypeError('config should be a ConfigDict')
    parser = _InlineConfigParser(name=name, lock_config=lock_config)
    flag = _ConfigFlag(parser=parser,
                       serializer=flags.ArgumentSerializer(),
                       name=name,
                       default=config,
                       help_string=help_string,
                       flag_values=flag_values,
                       **kwargs)

    # Get the module name for the frame at depth 1 in the call stack, i.e. the
    # caller of this function. This must stay directly in this function body so
    # that depth 1 is the caller.
    module_name = sys._getframe(1).f_globals.get('__name__', None)  # pylint: disable=protected-access
    # A flag defined from the entry-point script is attributed to the script path.
    module_name = sys.argv[0] if module_name == '__main__' else module_name
    return flags.DEFINE_flag(flag, flag_values, module_name=module_name)
Ejemplo n.º 16
0
def DEFINE_config_file(  # pylint: disable=g-bad-name
        name,
        default=None,
        help_string='path to config file.',
        flag_values=FLAGS,
        lock_config=True,
        **kwargs):
    r"""Defines flag for `ConfigDict` files compatible with absl flags.

    The flag's value should be a path to a valid python file which contains a
    function called `get_config()` that returns a python object specifying
    a configuration. After the flag is parsed, `FLAGS.name` will contain
    a reference to this object, optionally with some values overridden.

    During flags parsing, every flag of form `--name.([a-zA-Z0-9]+\.?)+=value`
    and `-name.([a-zA-Z0-9]+\.?)+ value` will be treated as an override of a
    specific field in the config object returned by this flag. Field is
    essentially a dot delimited path inside the object where each path element
    has to be either an attribute or a key existing in the config object.
    For example `--my_config.field1.field2=val` means "assign value val
    to the attribute (or key) `field2` inside value of the attribute (or key)
    `field1` inside the value of `my_config` object". If there are both
    attribute and key-based access with the same name, attribute is preferred.

    Typical usage example:

    `script.py`::

      from absl import flags
      from ml_collections.config_flags import config_flags

      FLAGS = flags.FLAGS
      config_flags.DEFINE_config_file('my_config')

      print(FLAGS.my_config)

    `config.py`::

      def get_config():
        return {
            'field1': 1,
            'field2': 'tom',
            'nested': {
                'field': 2.23,
            },
        }

    The following command::

      python script.py -- --my_config=config.py
                          --my_config.field1 8
                          --my_config.nested.field=2.1

    will print::

      {'field1': 8, 'field2': 'tom', 'nested': {'field': 2.1}}

    It is possible to parameterise the get_config function, allowing it to
    return a differently structured result for different occasions. This is
    particularly useful when setting up hyperparameter sweeps across various
    network architectures.

    `parameterised_config.py`::

      def get_config(config_string):
        possible_configs = {
            'mlp': {
                'constructor': 'snt.nets.MLP',
                'config': {
                    'output_sizes': (128, 128, 1),
                }
            },
            'lstm': {
                'constructor': 'snt.LSTM',
                'config': {
                    'hidden_size': 128,
                    'forget_bias': 1.0,
                }
            }
        }
        return possible_configs[config_string]

    If a colon is present in the command line override for the config file,
    everything to the right of the colon is passed into the get_config function.
    The following command lines will both function correctly::

      python script.py -- --my_config=parameterised_config.py:mlp
                          --my_config.config.output_sizes="(256,256,1)"


      python script.py -- --my_config=parameterised_config.py:lstm
                          --my_config.config.hidden_size=256

    The following will produce an error, as the hidden_size flag does not
    exist when the "mlp" config_string is provided::

      python script.py -- --my_config=parameterised_config.py:mlp
                          --my_config.config.hidden_size=256

    Args:
      name: Flag name. The optional `:config_string` suffix described above
          belongs in the flag's value on the command line, not in `name`.
      default: Default value of the flag, i.e. a config file path
          (default: None).
      help_string: Help string to display when --helpfull is called.
          (default: "path to config file.")
      flag_values: FlagValues instance used for parsing.
          (default: absl.flags.FLAGS)
      lock_config: If set to True, loaded config will be locked through calling
          .lock() method on its instance (if it exists). (default: True)
      **kwargs: Optional keyword arguments passed to Flag constructor.

    Returns:
      Whatever `flags.DEFINE_flag` returns for the new flag (a `FlagHolder` in
      absl-py versions that provide one). Callers that ignore the return value
      are unaffected.
    """
    parser = _ConfigFileParser(name=name, lock_config=lock_config)
    flag = _ConfigFlag(parser=parser,
                       serializer=flags.ArgumentSerializer(),
                       name=name,
                       default=default,
                       help_string=help_string,
                       flag_values=flag_values,
                       **kwargs)

    # Get the module name for the frame at depth 1 in the call stack, i.e. the
    # caller of this function. This must stay directly in this function body so
    # that depth 1 is the caller.
    module_name = sys._getframe(1).f_globals.get('__name__', None)  # pylint: disable=protected-access
    # A flag defined from the entry-point script is attributed to the script path.
    module_name = sys.argv[0] if module_name == '__main__' else module_name
    return flags.DEFINE_flag(flag, flag_values, module_name=module_name)
 def __init__(self, *args, **kwargs):
     # Verbosity values are integers; every other argument is forwarded
     # unchanged to the parent flag class.
     int_parser = flags.IntegerParser()
     base_serializer = flags.ArgumentSerializer()
     super(_VerbosityFlag, self).__init__(int_parser, base_serializer,
                                          *args, **kwargs)