コード例 #1
0
ファイル: app.py プロジェクト: zhangheyu518/clgen
def DEFINE_database(name: str,
                    database_class,
                    default: Optional[str],
                    help: str,
                    must_exist: bool = False,
                    validator: Optional[Callable[[Any], bool]] = None):
    """Registers a flag whose value is a sqlutil.Database class.

    Unlike other DEFINE_* functions, the value produced by this flag is not an
    instance of the value, but a lambda that will instantiate a database of the
    requested type. This flag value must be called (with no arguments) in order
    to instantiate a database.

    Args:
      name: The name of the flag.
      database_class: The subclass of sqlutil.Database which is to be
        instantiated when this value is called, using the URL declared in
        'default'. It is handed to flags_parsers.DatabaseParser.
      default: The default URL of the database. Documented as a required value,
        although the annotation also admits None.
      help: The help string.
      must_exist: If True, require that the database exists. Else, the database
        is created if it does not exist. Forwarded to DatabaseParser.
      validator: Optional callable taking the flag value and returning a bool.
        When given, it is registered for this flag via RegisterFlagValidator.
    """
    parser = flags_parsers.DatabaseParser(database_class,
                                          must_exist=must_exist)
    serializer = absl_flags.ArgumentSerializer()
    absl_flags.DEFINE(parser,
                      name,
                      default,
                      help,
                      absl_flags.FLAGS,
                      serializer,
                      # Attribute the flag to the module that called this
                      # helper rather than to this file.
                      module_name=get_calling_module_name())
    if validator:
        RegisterFlagValidator(name, validator)
コード例 #2
0
ファイル: app.py プロジェクト: zhangheyu518/clgen
def DEFINE_input_path(name: str,
                      default: Union[None, str, pathlib.Path],
                      help: str,
                      required: bool = False,
                      is_dir: bool = False,
                      validator: Union[None,
                                       Callable[[pathlib.Path], bool]] = None):
    """Registers a flag whose value is an input path.

    An "input path" is a path to a file or directory that exists. The parsed
    value is a pathlib.Path instance. Flag parsing will fail if the value of
    this flag is not a path to an existing file or directory.

    Args:
      name: The name of the flag.
      default: The default value for the flag, or None.
        NOTE(review): earlier docs claimed a None default "will fail during
        parsing"; absl does not run the parser on a None default, so
        `required=True` is the explicit way to demand a value — confirm.
      help: The help string.
      required: If True, the flag is marked with
        absl_flags.mark_flag_as_required, so flag parsing fails while its
        value is None.
      is_dir: If true, require that the value be a directory. Else, require
        that the value be a file. Parsing will fail if this is not the case.
      validator: Optional callable taking the parsed pathlib.Path and returning
        a bool. When given, it is registered via RegisterFlagValidator.
    """
    # must_exist=True is what makes this an *input* path.
    parser = flags_parsers.PathParser(must_exist=True, is_dir=is_dir)
    serializer = absl_flags.ArgumentSerializer()
    absl_flags.DEFINE(parser,
                      name,
                      default,
                      help,
                      absl_flags.FLAGS,
                      serializer,
                      module_name=get_calling_module_name())
    if required:
        absl_flags.mark_flag_as_required(name)
    if validator:
        RegisterFlagValidator(name, validator)
コード例 #3
0
def DEFINE_integerlist(name, default, help, on_nonincreasing=None,
                       flag_values=FLAGS, **kwargs):
  """Register a flag whose value must be an integer list.

  Args:
    name: string. The name of the flag.
    default: The default value of the flag.
    help: string. A help message for the user.
    on_nonincreasing: Forwarded to IntegerListParser; presumably selects what
      happens when the list is not increasing — see that parser.
    flag_values: the absl.flags.FlagValues object to define the flag in.
    kwargs: extra arguments to pass to absl.flags.DEFINE().

  Returns:
    Whatever absl.flags.DEFINE() returns (a FlagHolder in recent absl-py
    releases), so callers can keep a handle to the new flag.
  """

  parser = IntegerListParser(on_nonincreasing=on_nonincreasing)
  serializer = IntegerListSerializer()

  return flags.DEFINE(parser, name, default, help, flag_values, serializer,
                      **kwargs)
コード例 #4
0
ファイル: _definitions.py プロジェクト: stjordanis/fancyflags
def DEFINE_sequence(  # pylint: disable=invalid-name,redefined-builtin
        name,
        default,
        help,
        flag_values=flags.FLAGS,
        **args):
    """Defines a flag for a list or tuple of simple types. See `Sequence` docs.

    Args:
      name: The name of the flag.
      default: The default value of the flag.
      help: The help string.
      flag_values: The `flags.FlagValues` object to define the flag in.
      **args: Extra keyword arguments forwarded to `flags.DEFINE`.

    Returns:
      Whatever `flags.DEFINE` returns (a `FlagHolder` in recent absl-py
      releases).
    """
    parser = _argument_parsers.SequenceParser()
    serializer = flags.ArgumentSerializer()
    # usage_logging: sequence
    return flags.DEFINE(parser, name, default, help, flag_values, serializer,
                        **args)
コード例 #5
0
ファイル: _definitions.py プロジェクト: stjordanis/fancyflags
def DEFINE_multi_enum(  # pylint: disable=invalid-name,redefined-builtin
        name,
        default,
        enum_values,
        help,
        flag_values=flags.FLAGS,
        **args):
    """Defines flag for MultiEnum.

    Args:
      name: The name of the flag.
      default: The default value of the flag.
      enum_values: The allowed values; passed to `MultiEnumParser`.
      help: The help string.
      flag_values: The `flags.FlagValues` object to define the flag in.
      **args: Extra keyword arguments forwarded to `flags.DEFINE`.

    Returns:
      Whatever `flags.DEFINE` returns (a `FlagHolder` in recent absl-py
      releases).
    """
    parser = _argument_parsers.MultiEnumParser(enum_values)
    serializer = flags.ArgumentSerializer()
    # usage_logging: multi_enum
    return flags.DEFINE(parser, name, default, help, flag_values, serializer,
                        **args)
コード例 #6
0
def DEFINE_yaml(name, default, help, flag_values=flags.FLAGS, **kwargs):
  """Register a flag whose value is a YAML expression.

  Args:
    name: string. The name of the flag.
    default: object. The default value of the flag.
    help: string. A help message for the user.
    flag_values: the absl.flags.FlagValues object to define the flag in.
    kwargs: extra arguments to pass to absl.flags.DEFINE().

  Returns:
    Whatever absl.flags.DEFINE() returns (a FlagHolder in recent absl-py
    releases), so callers can keep a handle to the new flag.
  """

  parser = YAMLParser()
  serializer = YAMLSerializer()

  return flags.DEFINE(parser, name, default, help, flag_values, serializer,
                      **kwargs)
コード例 #7
0
def DEFINE_output_path(
    name: str,
    default: Union[None, str, pathlib.Path],
    help: str,
    required: bool = False,
    is_dir: bool = False,
    exist_ok: bool = True,
    must_exist: bool = False,
    validator: Union[None, Callable[[pathlib.Path], bool]] = None,
):
    """Registers a flag whose value is an output path.

    An "output path" is a path to a file or directory that may or may not
    already exist. The parsed value is a pathlib.Path instance. The idea is
    that this flag can be used to specify paths to files or directories that
    will be created during program execution. However, note that specifying an
    output path does not guarantee that the file will be produced.

    Args:
      name: The name of the flag.
      default: The default value for the flag, or None.
        NOTE(review): earlier docs claimed a None default "will fail during
        parsing"; absl does not run the parser on a None default, so
        `required=True` is the explicit way to demand a value — confirm.
      help: The help string.
      required: If True, the flag is marked with
        absl_flags.mark_flag_as_required, so flag parsing fails while its
        value is None.
      is_dir: If true, require that the value be a directory. Else, require
        that the value be a file. Parsing will fail if the path already exists
        and is of the incorrect type.
      exist_ok: If False, require that the path not exist, else parsing will
        fail.
      must_exist: If True, require that the path exists, else parsing will
        fail.
      validator: Optional callable taking the parsed pathlib.Path and returning
        a bool. When given, it is registered via RegisterFlagValidator.
    """
    parser = flags_parsers.PathParser(
        must_exist=must_exist,
        exist_ok=exist_ok,
        is_dir=is_dir,
    )
    serializer = absl_flags.ArgumentSerializer()
    absl_flags.DEFINE(
        parser,
        name,
        default,
        help,
        absl_flags.FLAGS,
        serializer,
        module_name=get_calling_module_name(),
    )
    if required:
        absl_flags.mark_flag_as_required(name)
    if validator:
        RegisterFlagValidator(name, validator)
コード例 #8
0
def DEFINE_units(name, default, help, convertible_to,
                 flag_values=flags.FLAGS, **kwargs):
  """Register a flag whose value is a units expression.

  Args:
    name: string. The name of the flag.
    default: units.Quantity. The default value.
    help: string. A help message for the user.
    convertible_to: Either an individual unit specification or a series of unit
        specifications, where each unit specification is either a string (e.g.
        'byte') or a units.Unit. The flag value must be convertible to at least
        one of the specified Units to be considered valid.
    flag_values: the absl.flags.FlagValues object to define the flag in.
    kwargs: extra arguments to pass to absl.flags.DEFINE().

  Returns:
    Whatever absl.flags.DEFINE() returns (a FlagHolder in recent absl-py
    releases), so callers can keep a handle to the new flag.
  """
  parser = UnitsParser(convertible_to=convertible_to)
  serializer = UnitsSerializer()
  return flags.DEFINE(parser, name, default, help, flag_values, serializer,
                      **kwargs)
コード例 #9
0
def DEFINE_sequence(  # pylint: disable=invalid-name,redefined-builtin
    name: str,
    default: Optional[Iterable[_T]],
    help: str,
    flag_values=flags.FLAGS,
    **args,
) -> flags.FlagHolder[Iterable[_T]]:
    """Define a flag holding a list or tuple of simple-typed values.

    The accepted command-line syntax is described in the `Sequence` docs.

    Args:
      name: Name of the new flag.
      default: Default sequence, or None.
      help: Help text shown for the flag.
      flag_values: `flags.FlagValues` registry that receives the flag.
      **args: Additional keyword arguments handed on to `flags.DEFINE`.

    Returns:
      The `FlagHolder` created by `flags.DEFINE`.
    """
    # usage_logging: sequence
    holder = flags.DEFINE(
        _argument_parsers.SequenceParser(),
        name,
        default,
        help,
        flag_values,
        flags.ArgumentSerializer(),
        **args,
    )
    return holder
コード例 #10
0
def DEFINE_multi_enum(  # pylint: disable=invalid-name,redefined-builtin
    name: str,
    default: Optional[Iterable[_T]],
    enum_values: Iterable[_T],
    help: str,
    flag_values=flags.FLAGS,
    **args,
) -> flags.FlagHolder[Iterable[_T]]:
    """Defines flag for MultiEnum.

    Args:
      name: The name of the flag.
      default: The default collection of values, or None.
      enum_values: The allowed values; passed to `MultiEnumParser`.
      help: The help string.
      flag_values: The `flags.FlagValues` object to define the flag in.
      **args: Extra keyword arguments forwarded to `flags.DEFINE`.

    Returns:
      The `FlagHolder` created by `flags.DEFINE`. Its value is a collection of
      values (matching the `Iterable[_T]` type of `default`), not a single
      `_T`, so the holder is annotated accordingly.
    """
    parser = _argument_parsers.MultiEnumParser(enum_values)
    serializer = flags.ArgumentSerializer()
    # usage_logging: multi_enum
    return flags.DEFINE(
        parser,
        name,
        default,
        help,
        flag_values,
        serializer,
        **args,
    )
コード例 #11
0
ファイル: config_flags.py プロジェクト: NeoTim/ml_collections
    def _parse(self, argument):
        """Parses the config flag value and defines one flag per override.

        The config itself comes from the parent class's `_parse`. Then, for
        every override that `self._GetOverrides` finds in `sys.argv`, a new
        flag named '<this flag's name>.<field path>' is defined on
        `self.flag_values`. Each such flag uses a parser chosen from
        `_FIELD_TYPE_TO_PARSER` by the field's type.

        Args:
          argument: The raw flag value. It is also stored as
            `self._config_filename` (presumably a config file path).

        Returns:
          The config object returned by the parent `_parse`.

        Raises:
          UnsupportedOperationError: If an overridden field's type has no entry
            in `_FIELD_TYPE_TO_PARSER`.
        """
        # Parse config
        config = super(_ConfigFlag, self)._parse(argument)

        # Get list of overrides. Note this reads the process-wide sys.argv,
        # not the argv list that was passed to the FlagValues object.
        overrides = self._GetOverrides(sys.argv)
        # Attach types definitions
        overrides_types = GetTypes(overrides, config)

        # Iterate over overridden fields and create valid parsers.
        # The dict is reset on every call and handed to each
        # _ConfigFieldParser below (presumably so the parsers can record the
        # override values in it — TODO confirm).
        self._override_values = {}
        for field_path, field_type in zip(overrides, overrides_types):
            field_help = 'An override of {}\'s field {}'.format(
                self.name, field_path)
            field_name = '{}.{}'.format(self.name, field_path)

            if field_type in _FIELD_TYPE_TO_PARSER:
                parser = _ConfigFieldParser(_FIELD_TYPE_TO_PARSER[field_type],
                                            field_path, config,
                                            self._override_values)
                # The override flag's default is the field's current value
                # in the freshly parsed config.
                flags.DEFINE(parser,
                             field_name,
                             GetValue(field_path, config),
                             field_help,
                             flag_values=self.flag_values,
                             serializer=_FIELD_TYPE_TO_SERIALIZER[field_type])
                # Fetch the Flag object DEFINE just registered so its `boolean`
                # attribute can be set for bool-typed fields (presumably so
                # they behave like absl boolean flags — TODO confirm).
                flag = self.flag_values._flags().get(field_name)  # pylint: disable=protected-access
                flag.boolean = field_type is bool
            else:
                raise UnsupportedOperationError(
                    "Type {} of field {} is not supported for overriding. "
                    "Currently supported types are: {}. (Note that tuples should "
                    "be passed as a string on the command line: flag='(a, b, c)', "
                    "rather than flag=(a, b, c).)".format(
                        field_type, field_name, _FIELD_TYPE_TO_PARSER.keys()))

        self._config_filename = argument
        return config
コード例 #12
0
def DEFINE_enum(
    name: str,
    enum_class,
    default,
    help: str,
    validator: Callable[[Any], bool] = None,
):
    """Registers a flag whose value is parsed by flags_parsers.EnumParser.

    NOTE(review): the type of the parsed value is determined by
    flags_parsers.EnumParser(enum_class) (presumably a member of enum_class)
    — confirm against that parser.

    Args:
      name: The name of the flag.
      enum_class: The subclass of enum.Enum describing the legal values of
        this flag. It is passed to flags_parsers.EnumParser.
      default: The default value of the enum. Either the string name or an enum
        value.
      help: The help string.
      validator: Optional callable taking the flag value and returning a bool.
        When given, it is registered for this flag via RegisterFlagValidator.
    """
    parser = flags_parsers.EnumParser(enum_class)
    serializer = absl_flags.ArgumentSerializer()
    absl_flags.DEFINE(
        parser,
        name,
        default,
        help,
        absl_flags.FLAGS,
        serializer,
        # Attribute the flag to the module that called this helper.
        module_name=get_calling_module_name(),
    )
    if validator:
        RegisterFlagValidator(name, validator)
コード例 #13
0
def _DefineMemorySizeFlag(name, default, help, flag_values=FLAGS, **kwargs):
  """Defines a flag parsed by the shared _MEMORY_SIZE_PARSER.

  Presumably the value is a memory-size quantity (per the function name);
  it is serialized with _UNITS_SERIALIZER.

  Args:
    name: string. The name of the flag.
    default: The default value of the flag.
    help: string. A help message for the user.
    flag_values: the absl.flags.FlagValues object to define the flag in.
    kwargs: extra arguments to pass to absl.flags.DEFINE().

  Returns:
    Whatever absl.flags.DEFINE() returns (a FlagHolder in recent absl-py
    releases), so callers can keep a handle to the new flag.
  """
  return flags.DEFINE(_MEMORY_SIZE_PARSER, name, default, help, flag_values,
                      _UNITS_SERIALIZER, **kwargs)
コード例 #14
0
ファイル: point_flag.py プロジェクト: mattffc/pysc2
def DEFINE_point(name, default, help, flag_values=flags.FLAGS, **args):  # pylint: disable=invalid-name,redefined-builtin
  """Registers a flag whose value parses as a point.

  Args:
    name: The name of the flag.
    default: The default value of the flag.
    help: The help message for the flag.
    flag_values: The absl FlagValues object to define the flag in. Defaults
      to the global flags.FLAGS, which is what the 3-argument form always used.
    **args: Extra keyword arguments forwarded to flags.DEFINE().

  Returns:
    Whatever flags.DEFINE() returns (a FlagHolder in recent absl-py releases).
  """
  return flags.DEFINE(PointParser(), name, default, help, flag_values, **args)
コード例 #15
0
# Opt in to TensorFlow 2 behavior before any TF code in this module runs.
tf.enable_v2_behavior()

flags.DEFINE_string("neutra_log_dir", "/tmp/neutra",
                    "Output directory for experiment artifacts.")
flags.DEFINE_string("checkpoint_log_dir", None,
                    "Output directory for checkpoints, if specified.")
# The help text previously referred to a "Standard" mode, which is not one of
# the accepted choices; it now names the actual 'train' and 'benchmark' modes.
flags.DEFINE_enum(
    "mode", "train", ["eval", "benchmark", "train", "objective"],
    "Mode for this run. 'train' trains the bijector, tunes the "
    "chain parameters and does the evals. 'benchmark' uses "
    "the tuned parameters and benchmarks the chain.")
flags.DEFINE_boolean(
    "restore_from_config", False,
    "Whether to restore the hyperparameters from the "
    "previous run.")
# Parsed by utils.YAMLDictParser; the empty-string default presumably means
# "no overrides" — confirm against that parser.
flags.DEFINE(utils.YAMLDictParser(), "hparams", "",
             "Hyperparameters to override.")
flags.DEFINE_string("tune_outputs_name", "tune_outputs",
                    "Name of the tune_outputs file.")
flags.DEFINE_string("eval_suffix", "", "Suffix for the eval outputs.")

FLAGS = flags.FLAGS


def Train(exp):
    log_dir = (FLAGS.checkpoint_log_dir
               if FLAGS.checkpoint_log_dir else FLAGS.neutra_log_dir)
    logging.info("Training")
    q_stats, secs_per_step = exp.Train()
    tf.io.gfile.makedirs(log_dir)
    utils.save_json(q_stats, os.path.join(log_dir, "q_stats"))
    utils.save_json(secs_per_step, os.path.join(log_dir,
コード例 #16
0
                'Unable to parse {0}. Unrecognized stages were found: {1}'.
                format(repr(argument), ', '.join(sorted(invalid_items))))

        if _ALL in stage_list:
            if len(stage_list) > 1:
                raise ValueError(
                    "Unable to parse {0}. If 'all' stages are specified, individual "
                    "stages cannot also be specified.".format(repr(argument)))
            return list(STAGES)

        previous_stage = stage_list[0]
        for stage in itertools.islice(stage_list, 1, None):
            expected_stage = _NEXT_STAGE.get(previous_stage)
            if not expected_stage:
                raise ValueError(
                    "Unable to parse {0}. '{1}' should be the last "
                    "stage.".format(repr(argument), previous_stage))
            if stage != expected_stage:
                raise ValueError(
                    "Unable to parse {0}. The stage after '{1}' should be '{2}', not "
                    "'{3}'.".format(repr(argument), previous_stage,
                                    expected_stage, stage))
            previous_stage = stage

        return stage_list


# Register --run_stage: parsed by RunStageParser, serialized as a
# comma-separated list, and defaulting to STAGES.
flags.DEFINE(
    RunStageParser(),
    'run_stage',
    STAGES,
    "The stage or stages of perfkitbenchmarker to run.",
    flag_values=flags.FLAGS,
    serializer=flags.ListSerializer(','),
)
コード例 #17
0
ファイル: point_flag.py プロジェクト: yeclairer/soowa
def DEFINE_point(name, default, help_string, flag_values=flags.FLAGS, **args):  # pylint: disable=invalid-name,redefined-builtin
    """Registers a flag whose value parses as a point.

    Args:
      name: The name of the flag.
      default: The default value of the flag.
      help_string: The help message for the flag.
      flag_values: The absl FlagValues object to define the flag in.
      **args: Extra keyword arguments forwarded to flags.DEFINE().

    Returns:
      Whatever flags.DEFINE() returns (a FlagHolder in recent absl-py
      releases), so callers can keep a handle to the new flag.
    """
    return flags.DEFINE(PointParser(), name, default, help_string, flag_values,
                        PointSerializer(), **args)