Beispiel #1
0
def DEFINE_database(name: str,
                    database_class,
                    default: Optional[str],
                    help: str,
                    must_exist: bool = False,
                    validator: Optional[Callable[[Any], bool]] = None):
    """Registers a flag whose value is a sqlutil.Database class.

    Unlike other DEFINE_* functions, the value produced by this flag is not an
    instance of the value, but a lambda that will instantiate a database of the
    requested type. This flag value must be called (with no arguments) in order
    to instantiate a database.

    Args:
      name: The name of the flag.
      database_class: The subclass of sqlutil.Database which is to be
        instantiated when this value is called, using the URL declared in
        'default'.
      default: The default URL of the database. This is a required value.
      help: The help string.
      must_exist: If True, require that the database exists. Else, the database
        is created if it does not exist.
      validator: An optional callable that receives the parsed flag value and
        returns True if it is valid. When given, it is registered with
        RegisterFlagValidator() after the flag has been defined.
    """
    parser = flags_parsers.DatabaseParser(database_class,
                                          must_exist=must_exist)
    serializer = absl_flags.ArgumentSerializer()
    # module_name attributes the flag to the module that called this helper,
    # rather than to this helper's own module.
    absl_flags.DEFINE(parser,
                      name,
                      default,
                      help,
                      absl_flags.FLAGS,
                      serializer,
                      module_name=get_calling_module_name())
    if validator:
        RegisterFlagValidator(name, validator)
Beispiel #2
0
def DEFINE_input_path(name: str,
                      default: Union[None, str, pathlib.Path],
                      help: str,
                      required: bool = False,
                      is_dir: bool = False,
                      validator: Optional[Callable[[pathlib.Path], bool]] = None):
    """Registers a flag whose value is an input path.

    An "input path" is a path to a file or directory that exists. The parsed
    value is a pathlib.Path instance. Flag parsing will fail if the value of
    this flag is not a path to an existing file or directory.

    Args:
      name: The name of the flag.
      default: The default value for the flag. While None is a legal value, it
        will fail during parsing - input paths are required flags.
      help: The help string.
      required: If True, mark the flag as required via
        absl.flags.mark_flag_as_required().
      is_dir: If True, require that the value be a directory. Else, require
        that the value be a file. Parsing will fail if this is not the case.
      validator: An optional callable that receives the parsed pathlib.Path and
        returns True if it is valid. When given, it is registered with
        RegisterFlagValidator() after the flag has been defined.
    """
    # Input paths must already exist on disk.
    parser = flags_parsers.PathParser(must_exist=True, is_dir=is_dir)
    serializer = absl_flags.ArgumentSerializer()
    absl_flags.DEFINE(parser,
                      name,
                      default,
                      help,
                      absl_flags.FLAGS,
                      serializer,
                      module_name=get_calling_module_name())
    if required:
        absl_flags.mark_flag_as_required(name)
    if validator:
        RegisterFlagValidator(name, validator)
Beispiel #3
0
def DEFINE_integerlist(name, default, help, on_nonincreasing=None,
                       flag_values=FLAGS, **kwargs):
  """Register a flag whose value must be an integer list.

  Args:
    name: string. The name of the flag.
    default: The default value of the flag.
    help: string. A help message for the user.
    on_nonincreasing: Forwarded to IntegerListParser; presumably selects how a
      non-increasing list is handled (NOTE(review): confirm in the parser).
    flag_values: the absl.flags.FlagValues object to define the flag in.
    **kwargs: extra arguments to pass to absl.flags.DEFINE().

  Returns:
    The value returned by flags.DEFINE() (a FlagHolder in recent absl-py
    releases), so callers can keep a handle to the flag.
  """
  parser = IntegerListParser(on_nonincreasing=on_nonincreasing)
  serializer = IntegerListSerializer()

  return flags.DEFINE(parser, name, default, help, flag_values, serializer,
                      **kwargs)
Beispiel #4
0
def DEFINE_sequence(  # pylint: disable=invalid-name,redefined-builtin
        name,
        default,
        help,
        flag_values=flags.FLAGS,
        **args):
    """Defines a flag for a list or tuple of simple types. See `Sequence` docs.

    Args:
      name: The name of the flag.
      default: The default value of the flag.
      help: The help string.
      flag_values: The absl.flags.FlagValues object to define the flag in.
      **args: Extra keyword arguments passed through to flags.DEFINE().

    Returns:
      The value returned by flags.DEFINE() (a FlagHolder in recent absl-py
      releases), so callers can keep a handle to the flag.
    """
    parser = _argument_parsers.SequenceParser()
    serializer = flags.ArgumentSerializer()
    # usage_logging: sequence
    return flags.DEFINE(parser, name, default, help, flag_values, serializer,
                        **args)
Beispiel #5
0
def DEFINE_multi_enum(  # pylint: disable=invalid-name,redefined-builtin
        name,
        default,
        enum_values,
        help,
        flag_values=flags.FLAGS,
        **args):
    """Defines flag for MultiEnum.

    Args:
      name: The name of the flag.
      default: The default value of the flag.
      enum_values: The allowed values, forwarded to MultiEnumParser.
      help: The help string.
      flag_values: The absl.flags.FlagValues object to define the flag in.
      **args: Extra keyword arguments passed through to flags.DEFINE().

    Returns:
      The value returned by flags.DEFINE() (a FlagHolder in recent absl-py
      releases), so callers can keep a handle to the flag.
    """
    parser = _argument_parsers.MultiEnumParser(enum_values)
    serializer = flags.ArgumentSerializer()
    # usage_logging: multi_enum
    return flags.DEFINE(parser, name, default, help, flag_values, serializer,
                        **args)
Beispiel #6
0
def DEFINE_yaml(name, default, help, flag_values=flags.FLAGS, **kwargs):
  """Register a flag whose value is a YAML expression.

  Args:
    name: string. The name of the flag.
    default: object. The default value of the flag.
    help: string. A help message for the user.
    flag_values: the absl.flags.FlagValues object to define the flag in.
    kwargs: extra arguments to pass to absl.flags.DEFINE().

  Returns:
    The value returned by flags.DEFINE() (a FlagHolder in recent absl-py
    releases), so callers can keep a handle to the flag.
  """

  parser = YAMLParser()
  serializer = YAMLSerializer()

  return flags.DEFINE(parser, name, default, help, flag_values, serializer,
                      **kwargs)
Beispiel #7
0
def DEFINE_output_path(
    name: str,
    default: Union[None, str, pathlib.Path],
    help: str,
    required: bool = False,
    is_dir: bool = False,
    exist_ok: bool = True,
    must_exist: bool = False,
    validator: Optional[Callable[[pathlib.Path], bool]] = None,
):
    """Registers a flag whose value is an output path.

    An "output path" is a path to a file or directory that may or may not
    already exist. The parsed value is a pathlib.Path instance. The idea is
    that this flag can be used to specify paths to files or directories that
    will be created during program execution. However, note that specifying an
    output path does not guarantee that the file will be produced.

    Args:
      name: The name of the flag.
      default: The default value for the flag. While None is a legal value, it
        will fail during parsing - output paths are required flags.
      help: The help string.
      required: If True, mark the flag as required via
        absl.flags.mark_flag_as_required().
      is_dir: If True, require that the value be a directory. Else, require
        that the value be a file. Parsing will fail if the path already exists
        and is of the incorrect type.
      exist_ok: If False, require that the path not exist, else parsing will
        fail.
      must_exist: If True, require that the path exists, else parsing will
        fail.
      validator: An optional callable that receives the parsed pathlib.Path and
        returns True if it is valid. When given, it is registered with
        RegisterFlagValidator() after the flag has been defined.
    """
    parser = flags_parsers.PathParser(
        must_exist=must_exist,
        exist_ok=exist_ok,
        is_dir=is_dir,
    )
    serializer = absl_flags.ArgumentSerializer()
    absl_flags.DEFINE(
        parser,
        name,
        default,
        help,
        absl_flags.FLAGS,
        serializer,
        module_name=get_calling_module_name(),
    )
    if required:
        absl_flags.mark_flag_as_required(name)
    if validator:
        RegisterFlagValidator(name, validator)
Beispiel #8
0
def DEFINE_units(name, default, help, convertible_to,
                 flag_values=flags.FLAGS, **kwargs):
  """Register a flag whose value is a units expression.

  Args:
    name: string. The name of the flag.
    default: units.Quantity. The default value.
    help: string. A help message for the user.
    convertible_to: Either an individual unit specification or a series of unit
        specifications, where each unit specification is either a string (e.g.
        'byte') or a units.Unit. The flag value must be convertible to at least
        one of the specified Units to be considered valid.
    flag_values: the absl.flags.FlagValues object to define the flag in.
    **kwargs: extra arguments to pass to absl.flags.DEFINE().

  Returns:
    The value returned by flags.DEFINE() (a FlagHolder in recent absl-py
    releases), so callers can keep a handle to the flag.
  """
  parser = UnitsParser(convertible_to=convertible_to)
  serializer = UnitsSerializer()
  return flags.DEFINE(parser, name, default, help, flag_values, serializer,
                      **kwargs)
Beispiel #9
0
def DEFINE_sequence(  # pylint: disable=invalid-name,redefined-builtin
    name: str,
    default: Optional[Iterable[_T]],
    help: str,
    flag_values=flags.FLAGS,
    **args,
) -> flags.FlagHolder[Iterable[_T]]:
    """Defines a flag whose value is a list or tuple of simple types.

    See the `Sequence` docs for the accepted command-line syntax.

    Args:
      name: The name of the flag.
      default: The default sequence, or None.
      help: The help string.
      flag_values: The absl.flags.FlagValues object to define the flag in.
      **args: Extra keyword arguments forwarded to flags.DEFINE().

    Returns:
      The FlagHolder created by flags.DEFINE().
    """
    # usage_logging: sequence
    # The parser is built before the serializer (left-to-right argument
    # evaluation), matching the order in which they were always constructed.
    return flags.DEFINE(
        _argument_parsers.SequenceParser(),
        name,
        default,
        help,
        flag_values,
        flags.ArgumentSerializer(),
        **args,
    )
Beispiel #10
0
def DEFINE_multi_enum(  # pylint: disable=invalid-name,redefined-builtin
    name: str,
    default: Optional[Iterable[_T]],
    enum_values: Iterable[_T],
    help: str,
    flag_values=flags.FLAGS,
    **args,
) -> flags.FlagHolder[Iterable[_T]]:
    """Defines flag for MultiEnum.

    Args:
      name: The name of the flag.
      default: The default collection of values, or None.
      enum_values: The allowed values, forwarded to MultiEnumParser.
      help: The help string.
      flag_values: The absl.flags.FlagValues object to define the flag in.
      **args: Extra keyword arguments forwarded to flags.DEFINE().

    Returns:
      The FlagHolder created by flags.DEFINE(). The flag holds a collection of
      values (the same type as `default`), so the holder is typed
      `FlagHolder[Iterable[_T]]` rather than `FlagHolder[_T]`.
    """
    parser = _argument_parsers.MultiEnumParser(enum_values)
    serializer = flags.ArgumentSerializer()
    # usage_logging: multi_enum
    return flags.DEFINE(
        parser,
        name,
        default,
        help,
        flag_values,
        serializer,
        **args,
    )
Beispiel #11
0
    def _parse(self, argument):
        """Parses the config and defines one override flag per overridden field.

        For every override found on sys.argv, a `<flag name>.<field path>` flag
        is defined on self.flag_values, backed by a _ConfigFieldParser that
        writes into self._override_values.

        Raises:
          UnsupportedOperationError: if an overridden field's type has no entry
            in _FIELD_TYPE_TO_PARSER.
        """
        config = super(_ConfigFlag, self)._parse(argument)

        # Pair each override path found on the command line with its type in
        # the parsed config.
        overrides = self._GetOverrides(sys.argv)
        overrides_types = GetTypes(overrides, config)

        self._override_values = {}
        for field_path, field_type in zip(overrides, overrides_types):
            field_name = '{}.{}'.format(self.name, field_path)

            # Guard clause: reject field types that cannot be overridden.
            if field_type not in _FIELD_TYPE_TO_PARSER:
                raise UnsupportedOperationError(
                    "Type {} of field {} is not supported for overriding. "
                    "Currently supported types are: {}. (Note that tuples should "
                    "be passed as a string on the command line: flag='(a, b, c)', "
                    "rather than flag=(a, b, c).)".format(
                        field_type, field_name, _FIELD_TYPE_TO_PARSER.keys()))

            field_help = 'An override of {}\'s field {}'.format(
                self.name, field_path)
            field_parser = _ConfigFieldParser(_FIELD_TYPE_TO_PARSER[field_type],
                                              field_path, config,
                                              self._override_values)
            flags.DEFINE(field_parser,
                         field_name,
                         GetValue(field_path, config),
                         field_help,
                         flag_values=self.flag_values,
                         serializer=_FIELD_TYPE_TO_SERIALIZER[field_type])

            # Mark bool-typed override flags as boolean on the Flag object.
            defined_flag = self.flag_values._flags().get(field_name)  # pylint: disable=protected-access
            defined_flag.boolean = field_type is bool

        self._config_filename = argument
        return config
Beispiel #12
0
def DEFINE_enum(
    name: str,
    enum_class,
    default,
    help: str,
    validator: Optional[Callable[[Any], bool]] = None,
):
    """Registers a flag whose value is parsed against an enum.Enum class.

    Parsing is delegated to flags_parsers.EnumParser(enum_class).
    NOTE(review): presumably the parsed value is a member of `enum_class`;
    confirm against EnumParser.

    Args:
      name: The name of the flag.
      enum_class: The enum.Enum subclass whose values this flag accepts.
      default: The default value of the enum. Either the string name or an enum
        value.
      help: The help string.
      validator: An optional callable that receives the parsed flag value and
        returns True if it is valid. When given, it is registered with
        RegisterFlagValidator() after the flag has been defined.
    """
    parser = flags_parsers.EnumParser(enum_class)
    serializer = absl_flags.ArgumentSerializer()
    absl_flags.DEFINE(
        parser,
        name,
        default,
        help,
        absl_flags.FLAGS,
        serializer,
        module_name=get_calling_module_name(),
    )
    if validator:
        RegisterFlagValidator(name, validator)
Beispiel #13
0
def _DefineMemorySizeFlag(name, default, help, flag_values=FLAGS, **kwargs):
  """Register a flag whose value is a memory size.

  Parsing uses the module-level _MEMORY_SIZE_PARSER and serialization uses the
  module-level _UNITS_SERIALIZER.

  Args:
    name: string. The name of the flag.
    default: The default value of the flag.
    help: string. A help message for the user.
    flag_values: the absl.flags.FlagValues object to define the flag in.
    **kwargs: extra arguments to pass to absl.flags.DEFINE().

  Returns:
    The value returned by flags.DEFINE() (a FlagHolder in recent absl-py
    releases), so callers can keep a handle to the flag.
  """
  return flags.DEFINE(_MEMORY_SIZE_PARSER, name, default, help, flag_values,
                      _UNITS_SERIALIZER, **kwargs)
Beispiel #14
0
def DEFINE_point(name, default, help, flag_values=flags.FLAGS, **args):  # pylint: disable=invalid-name,redefined-builtin
  """Registers a flag whose value parses as a point.

  Args:
    name: string. The name of the flag.
    default: The default value of the flag.
    help: string. A help message for the user.
    flag_values: the absl.flags.FlagValues object to define the flag in.
      Defaults to flags.FLAGS, which is what the flag was always defined in.
    **args: extra keyword arguments passed through to flags.DEFINE().

  Returns:
    The value returned by flags.DEFINE() (a FlagHolder in recent absl-py
    releases), so callers can keep a handle to the flag.
  """
  # No serializer is passed, so flags.DEFINE uses its default (None).
  return flags.DEFINE(PointParser(), name, default, help, flag_values, **args)
Beispiel #15
0
# Switch TensorFlow to 2.x behavior at import time, before any flags or
# experiments are set up below.
tf.enable_v2_behavior()

# Module-level flag definitions for this experiment script.
flags.DEFINE_string("neutra_log_dir", "/tmp/neutra",
                    "Output directory for experiment artifacts.")
# When left as None, Train() falls back to --neutra_log_dir.
flags.DEFINE_string("checkpoint_log_dir", None,
                    "Output directory for checkpoints, if specified.")
# NOTE(review): the help text below mentions a "Standard" mode, but the allowed
# values are eval/benchmark/train/objective; presumably "Standard" means
# "train" -- confirm and align the wording.
flags.DEFINE_enum(
    "mode", "train", ["eval", "benchmark", "train", "objective"],
    "Mode for this run. Standard trains bijector, tunes the "
    "chain parameters and does the evals. Benchmark uses "
    "the tuned parameters and benchmarks the chain.")
flags.DEFINE_boolean(
    "restore_from_config", False,
    "Whether to restore the hyperparameters from the "
    "previous run.")
# Parsed by utils.YAMLDictParser; presumably a YAML mapping of hyperparameter
# overrides (TODO confirm against the parser).
flags.DEFINE(utils.YAMLDictParser(), "hparams", "",
             "Hyperparameters to override.")
flags.DEFINE_string("tune_outputs_name", "tune_outputs",
                    "Name of the tune_outputs file.")
flags.DEFINE_string("eval_suffix", "", "Suffix for the eval outputs.")

# Short alias for the global flag registry used throughout this module.
FLAGS = flags.FLAGS


def Train(exp):
    log_dir = (FLAGS.checkpoint_log_dir
               if FLAGS.checkpoint_log_dir else FLAGS.neutra_log_dir)
    logging.info("Training")
    q_stats, secs_per_step = exp.Train()
    tf.io.gfile.makedirs(log_dir)
    utils.save_json(q_stats, os.path.join(log_dir, "q_stats"))
    utils.save_json(secs_per_step, os.path.join(log_dir,
Beispiel #16
0
                'Unable to parse {0}. Unrecognized stages were found: {1}'.
                format(repr(argument), ', '.join(sorted(invalid_items))))

        if _ALL in stage_list:
            if len(stage_list) > 1:
                raise ValueError(
                    "Unable to parse {0}. If 'all' stages are specified, individual "
                    "stages cannot also be specified.".format(repr(argument)))
            return list(STAGES)

        previous_stage = stage_list[0]
        for stage in itertools.islice(stage_list, 1, None):
            expected_stage = _NEXT_STAGE.get(previous_stage)
            if not expected_stage:
                raise ValueError(
                    "Unable to parse {0}. '{1}' should be the last "
                    "stage.".format(repr(argument), previous_stage))
            if stage != expected_stage:
                raise ValueError(
                    "Unable to parse {0}. The stage after '{1}' should be '{2}', not "
                    "'{3}'.".format(repr(argument), previous_stage,
                                    expected_stage, stage))
            previous_stage = stage

        return stage_list


# Register --run_stage: parsed by RunStageParser, defaulting to STAGES, and
# serialized back to the command line as a comma-separated list.
flags.DEFINE(
    RunStageParser(),
    'run_stage',
    STAGES,
    "The stage or stages of perfkitbenchmarker to run.",
    flag_values=flags.FLAGS,
    serializer=flags.ListSerializer(','),
)
Beispiel #17
0
def DEFINE_point(name, default, help_string, flag_values=flags.FLAGS, **args):  # pylint: disable=invalid-name,redefined-builtin
    """Registers a flag whose value parses as a point.

    Args:
      name: The name of the flag.
      default: The default value of the flag.
      help_string: The help string.
      flag_values: The absl.flags.FlagValues object to define the flag in.
      **args: Extra keyword arguments passed through to flags.DEFINE().

    Returns:
      The value returned by flags.DEFINE() (a FlagHolder in recent absl-py
      releases), so callers can keep a handle to the flag.
    """
    return flags.DEFINE(PointParser(), name, default, help_string, flag_values,
                        PointSerializer(), **args)