示例#1
0
    def test_update_shared_dict_multi(self):
        """Checks that a MultiItemFlag keeps the shared dict in sync.

        The nested entry addressed by the namespace tuple is expected to follow
        the flag both when its value is assigned directly and when it is parsed
        from repeated command-line occurrences.
        """
        target = {'a': {'b': ['value']}}
        path = ('a', 'b')
        fv = flags.FlagValues()

        multi_flag = _flags.MultiItemFlag(
            target,
            path,
            parser=flags.ArgumentParser(),
            serializer=flags.ArgumentSerializer(),
            name='a.b',
            default=['foo', 'bar'],
            help_string='help string')
        flags.DEFINE_flag(multi_flag, flag_values=fv)

        # Direct assignment through the flag object.
        fv['a.b'].value = ['new', 'value']
        with self.subTest(name='setter'):
            self.assertEqual(target, {'a': {'b': ['new', 'value']}})

        # Parsing the same flag twice on the command line.
        fv(('./program', '--a.b=override1', '--a.b=override2'))
        with self.subTest(name='override_parse'):
            expected = {'a': {'b': ['override1', 'override2']}}
            self.assertEqual(target, expected)
示例#2
0
def AddCmdFunc(command_name,
               cmd_func,
               command_aliases=None,
               all_commands_help=None,
               hidden=False):
    """Register a plain function as a command.

    The function is wrapped in a `_FunctionalCmd` that is given its own fresh
    `flags.FlagValues()`; the wrapper is then passed to `_AddCmdInstance`.

    Args:
      command_name: Name of the command, as used in argument parsing.
      cmd_func: The command function. It receives the remaining arguments as
        its only parameter, does the command's work, and returns the result
        that is used as the shell exit code.
      command_aliases: A list of aliases the command can also be run as.
      all_commands_help: Help message shown in place of func.__doc__ when all
        commands are displayed.
      hidden: Whether to hide this command from the help output.
    """
    functional_cmd = _FunctionalCmd(command_name,
                                    flags.FlagValues(),
                                    cmd_func,
                                    command_aliases=command_aliases,
                                    all_commands_help=all_commands_help,
                                    hidden=hidden)
    _AddCmdInstance(command_name,
                    functional_cmd,
                    command_aliases=command_aliases)
示例#3
0
def prepare_subprocess_cmd(subprocess_cmd):
    """Expand flagfiles in a command and drop flags the target does not list.

    The command is first run with `--helpfull` appended; its output is parsed
    into the set of flags considered valid. Flagfiles in the arguments are then
    expanded and every flag-looking argument that is not in that set is
    removed.

    Args:
        subprocess_cmd: List[str], what would be passed into subprocess.call()
            i.e. ['python', 'train.py', '--flagfile=flags']

    Returns:
        List[str], ['python', 'train.py', '--train_flag=blah', '--more_flags']
    """
    completed = subprocess.run(subprocess_cmd + ['--helpfull'],
                               stdout=subprocess.PIPE)
    help_text = completed.stdout.decode('ascii')

    # Commands whose first token does not contain 'python' are parsed with
    # the FLAG_HELP_RE_CC pattern instead of the parser's default.
    if 'python' not in subprocess_cmd[0]:
        valid_flags = parse_helpfull_output(help_text, regex=FLAG_HELP_RE_CC)
    else:
        valid_flags = parse_helpfull_output(help_text)

    parsed_flags = flags.FlagValues().read_flags_from_files(subprocess_cmd[1:])

    def is_allowed(arg):
        """Keeps non-flag arguments and flags present in `valid_flags`."""
        match = FLAG_RE.match(arg)
        return match is None or match.group() in valid_flags

    kept = [arg for arg in parsed_flags if is_allowed(arg)]
    return [subprocess_cmd[0]] + kept
示例#4
0
 def setUp(self):
     """Creates a private FlagValues holding one flag of each basic kind."""
     fv = flags.FlagValues()
     self._absl_flags = fv

     flags.DEFINE_bool(
         'absl_bool', None, 'help for --absl_bool.',
         short_name='b', flag_values=fv)
     # A boolean flag whose name itself starts with "no", to verify that the
     # "no" prefix handling of boolean flags copes with it.
     flags.DEFINE_bool(
         'notice', None, 'help for --notice.', flag_values=fv)
     flags.DEFINE_string(
         'absl_string', 'default', 'help for --absl_string=%.',
         short_name='s', flag_values=fv)
     flags.DEFINE_integer(
         'absl_integer', 1, 'help for --absl_integer.', flag_values=fv)
     # NOTE(review): the help text below names --absl_integer; it looks
     # copy-pasted but is kept as-is in case other tests match on it.
     flags.DEFINE_float(
         'absl_float', 1, 'help for --absl_integer.', flag_values=fv)
     flags.DEFINE_enum(
         'absl_enum', 'apple', ['apple', 'orange'], 'help for --absl_enum.',
         flag_values=fv)
示例#5
0
文件: utils.py 项目: lionffen/minigo
async def expand_cmd_str(cmd):
    """Render `cmd` as one string with flagfiles expanded and flags filtered.

    The leading tokens that identify the process (two when `is_python_cmd`
    says so, otherwise one) are kept as-is. The remaining arguments are
    expanded, de-duplicated per flag name (last occurrence wins), filtered
    against the flags valid for that process, and sorted. Positional
    arguments are appended after the flags. Pieces are joined by two spaces.
    """
    prefix_len = 2 if is_python_cmd(cmd) else 1
    cmd = list(cmd)
    process, args = cmd[:prefix_len], cmd[prefix_len:]
    key = ' '.join(process)

    # The valid-flag lookup is cached per process prefix, guarded by the lock.
    async with flag_cache_lock:
        valid_flags = flag_cache.get(key)
        if valid_flags is None:
            valid_flags = mask_flags.extract_valid_flags(cmd)
            flag_cache[key] = valid_flags

    flag_args = {}
    position_args = []
    for arg in flags.FlagValues().read_flags_from_files(args):
        if not arg.startswith('--'):
            position_args.append(arg)
            continue
        name, sep, value = arg.partition('=')
        # A flag without '=' is stored with a None value.
        flag_args[name] = value if sep else None

    flag_list = [
        name if value is None else '%s=%s' % (name, value)
        for name, value in flag_args.items()
    ]

    flag_list = sorted(mask_flags.filter_flags(flag_list, valid_flags))
    return '  '.join(process + flag_list + position_args)
示例#6
0
    def test_basic_serialization(self):
        """Serializes the flags created by DEFINE_auto and round-trips them."""
        fv = flags.FlagValues()
        _define_auto.DEFINE_auto('point', Point, flag_values=fv)

        # Values are read through `fv` for the whole test rather than through
        # the returned `FlagHolder`, whose value would raise an error before
        # flags are parsed.
        point_before_parse = copy.deepcopy(fv['point'].value())

        # Parse flags, then check how each one serializes.
        argv = ('./program', '--point.x=1.2', '--point.y=3.5',
                '--point.label=p')
        fv(argv)

        self.assertEqual(fv['point'].serialize(), _flags._EMPTY)
        self.assertEqual(fv['point.x'].serialize(), '--point.x=1.2')
        self.assertEqual(fv['point.label'].serialize(), '--point.label=p')

        point_after_parse = copy.deepcopy(fv['point'].value())
        self.assertEqual(point_after_parse, Point(x=1.2, y=3.5, label='p'))
        self.assertNotEqual(point_after_parse, point_before_parse)

        # Round trip: collect the serialized form of every 'point.*' flag.
        serialized_args = []
        for name in fv:
            if name.startswith('point.'):
                serialized_args.append(fv[name].serialize())

        # Reset to defaults and confirm the pre-parse value is back.
        fv.unparse_flags()
        self.assertEqual(fv['point'].value(), point_before_parse)

        # Re-parsing the serialized arguments must reproduce the parsed value.
        fv(['./program'] + serialized_args)
        self.assertEqual(fv['point'].value(), point_after_parse)
示例#7
0
 def setUp(self):
     """Swaps the module-level FLAGS for a fresh copy, keeping the old one.

     The previous FLAGS object is stored on `self.flags`. The replacement
     receives the previous object's flags via `append_flag_values` and is
     marked as parsed.
     """
     global FLAGS  # pylint: disable=global-statement
     original = FLAGS
     self.flags = original
     # pylint: disable=g-bad-name
     FLAGS = flags.FlagValues()
     FLAGS.append_flag_values(original)
     FLAGS.mark_as_parsed()
示例#8
0
 def test_dataclass(self):
     """Builds a Point from flags generated by DEFINE_auto."""
     fv = flags.FlagValues()
     holder = _define_auto.DEFINE_auto('point', Point, flag_values=fv)

     argv = ('./program', '--point.x=2.0', '--point.y=-1.5',
             '--point.label=p')
     fv(argv)

     self.assertEqual(Point(2.0, -1.5, 'p'), holder.value())
示例#9
0
 def test_override_kwargs(self):
     """Checks that a keyword passed to value() overrides the parsed flag."""
     fv = flags.FlagValues()
     holder = _define_auto.DEFINE_auto('point', Point, flag_values=fv)

     argv = ('./program', '--point.x=2.0', '--point.y=-1.5',
             '--point.label=p')
     fv(argv)

     # `x` was parsed as 2.0; the call below replaces it with 3.0.
     overridden = holder.value(x=3.0)
     self.assertEqual(Point(3.0, -1.5, 'p'), overridden)
示例#10
0
    def test_disclaimed_module(self):
        """Checks which module is recorded as defining a DEFINE_auto flag."""
        fv = flags.FlagValues()
        _define_auto.DEFINE_auto('greet', greet, 'help string', flag_values=fv)

        # The defining module should be the calling module, not the module
        # where the flag is defined. Otherwise the help for a module's flags
        # would only be printed with --helpfull.
        owner = fv.find_module_defining_flag('greet')
        self.assertIn('_define_auto_test', owner)
示例#11
0
 def test_function(self):
     """Calls `greet` with arguments taken from DEFINE_auto flags."""
     fv = flags.FlagValues()
     holder = _define_auto.DEFINE_auto('greet', greet, flag_values=fv)

     argv = (
         './program',
         '--greet.greeting=Hi there',
         "--greet.targets=('Alice', 'Bob')",
     )
     fv(argv)

     self.assertEqual('Hi there Alice, Bob', holder.value())
    def test_parse_assign_dataclass(self):
        """Assigns a pre-parsed config to the flag and reads it back.

        The flag is defined with a `parse_fn` that always raises, so the test
        only passes if assigning the value does not go through `parse_fn`.
        """
        fv = flags.FlagValues()

        def always_fail(v):
            raise ValueError()

        holder = config_flags.DEFINE_config_dataclass(
            'test_config', _CONFIG, flag_values=fv, parse_fn=always_fail)

        fv(['program'])
        fv['test_config'].value = parse_config_flag('12')

        self.assertEqual(holder.value.my_model.foo, 12)
示例#13
0
    def test_help_strings(self):
        """Checks the default and the custom help text of DEFINE_auto flags."""
        fv = flags.FlagValues()

        # No help string given: expected to fall back to module.name, since
        # the `greet` docstring is empty.
        _define_auto.DEFINE_auto('greet', greet, flag_values=fv)
        # Explicit help string: expected to be used verbatim.
        _define_auto.DEFINE_auto(
            'point', Point, help_string='custom', flag_values=fv)

        default_help = f'{greet.__module__}.greet'
        self.assertEqual(fv['greet'].help, default_help)
        self.assertEqual(fv['point'].help, 'custom')
示例#14
0
def expand_flags(cmd, *args):
  """Expand flagfile arguments and keep one entry per flag.

  `cmd` is accepted but not used by this function.

  Returns:
    The values view of an ordered mapping from flag name to the last
    argument seen for that name, in order of each name's first appearance.
  """
  # When one flagfile includes and overrides a base one, the expanded list
  # may contain the same flag several times with different values. Keying by
  # the text before the first '=' means the last occurrence wins.
  deduped = OrderedDict()
  for arg in flags.FlagValues().read_flags_from_files(args):
    name, _, _ = arg.partition('=')
    deduped[name] = arg
  return deduped.values()
示例#15
0
def prepare_subprocess_cmd(subprocess_cmd):
    """Expand flagfiles in a command and keep only flags the target accepts.

    Args:
        subprocess_cmd: List[str], what would be passed into subprocess.call()
          i.e. ['python', 'train.py', '--flagfile=flags']

    Returns:
        The first element of `subprocess_cmd` followed by the filtered,
        expanded arguments, e.g.
        ['python', 'train.py', '--train_flag=blah', '--more_flags']
    """
    allowed = extract_valid_flags(subprocess_cmd)
    expanded = flags.FlagValues().read_flags_from_files(subprocess_cmd[1:])
    kept = filter_flags(expanded, allowed)
    return [subprocess_cmd[0]] + kept
示例#16
0
    def test_serialize_roundtrip(self):
        """Round-trips FLAGS through flags_into_string() and re-parsing.

        A temporary flag is defined on the global FLAGS so that framework
        flags take part in the round trip too. It is removed in a `finally`
        block, so a failing assertion cannot leak it into later tests.
        """
        # Use the global 'FLAGS' as the source, to ensure all the framework
        # defined flags will go through the round trip process.
        flags.DEFINE_string('testflag', 'testval', 'help', flag_values=FLAGS)
        try:
            new_flag_values = flags.FlagValues()
            new_flag_values.append_flag_values(FLAGS)

            FLAGS.testflag = 'roundtrip_me'
            argv = ['binary_name'] + FLAGS.flags_into_string().splitlines()

            self.assertNotEqual(new_flag_values['testflag'], FLAGS.testflag)
            new_flag_values(argv)
            self.assertEqual(new_flag_values.testflag, FLAGS.testflag)
        finally:
            # Always undo the global registration, even on assertion failure.
            del FLAGS.testflag
def checked_run(name, *cmd):
    """Run `cmd` as a subprocess, raising if it exits with a non-zero code.

    Args:
        name: Label used in log messages and in the timer description.
        *cmd: The command and its arguments, passed to `subprocess.run`.

    Returns:
        The completed process; stderr is merged into its captured stdout.

    Raises:
        RuntimeError: if the process return code is non-zero.
    """
    # Expand flagfiles for logging only, so the log shows exactly which flags
    # are in effect. The subprocess itself receives the unexpanded `cmd`.
    expanded = flags.FlagValues().read_flags_from_files(cmd)
    logging.info('Running %s:\n  %s', name, '  '.join(expanded))

    timer_label = '%s finished' % name.capitalize()
    with utils.logged_timer(timer_label):
        result = subprocess.run(cmd,
                                stdout=subprocess.PIPE,
                                stderr=subprocess.STDOUT)
        if result.returncode:
            output = result.stdout.decode()
            logging.error('Error running %s: %s', name, output)
            raise RuntimeError('Non-zero return code executing %s' %
                               ' '.join(cmd))
    return result
示例#18
0
 def testFlagChangesAreNotReflectedInConfigDict(self):
     """Checks that editing a flag inside OverrideFlags leaves the dict alone."""
     fv = flags.FlagValues()
     flags.DEFINE_integer('test_flag', 0, 'Test flag.', flag_values=fv)
     fv([sys.argv[0]])

     overrides = {'test_flag': 1}
     self.assertFlagState(fv, 0, False)
     self.assertEqual(overrides['test_flag'], 1)

     with flag_util.OverrideFlags(fv, overrides):
         self.assertFlagState(fv, 1, True)
         # Change the flag while the override is active; the overrides
         # dict must still hold its original value afterwards.
         fv.test_flag = 2
         self.assertFlagState(fv, 2, True)
         self.assertEqual(overrides['test_flag'], 1)
示例#19
0
def duplicate_flags(flagnames=None):
  """Returns a new FlagValues object with the requested flagnames.

  Used to test DuplicateFlagError detection.

  Args:
    flagnames: Iterable of str, the flag names to create. `None` (the
      default) is treated as "no flags" instead of failing on iteration.

  Returns:
    A FlagValues object with one boolean flag for each name in flagnames.
  """
  flag_values = flags.FlagValues()
  # `or ()` lets the documented default of None produce an empty FlagValues.
  for name in flagnames or ():
    flags.DEFINE_boolean(name, False, 'Flag named %s' % (name,),
                         flag_values=flag_values)
  return flag_values
def test_flags(default, *flag_args, parse_fn=None):
    """Define a 'test_config' dataclass flag, parse overrides, return its value.

    Args:
        default: Default config passed to `DEFINE_config_dataclass`.
        *flag_args: Suffixes appended to '--test_config' to form the
            command-line arguments to parse.
        parse_fn: Forwarded to `DEFINE_config_dataclass`.

    Returns:
        The value of the defined flag after parsing.

    Raises:
        ValueError: if parsing leaves any argument unconsumed.
    """
    flag_values = flags.FlagValues()
    # DEFINE_config_dataclass accesses sys.argv to build its flag list, so
    # sys.argv is replaced in place for the duration and restored afterwards.
    saved_argv = list(sys.argv)
    sys.argv[:] = [''] + ['--test_config' + suffix for suffix in flag_args]
    try:
        holder = config_flags.DEFINE_config_dataclass(
            'test_config',
            default,
            flag_values=flag_values,
            parse_fn=parse_fn)
        _, *leftover = flag_values(sys.argv)
        if leftover:
            raise ValueError(f'{leftover}')
        return holder.value
    finally:
        sys.argv[:] = saved_argv
    def test_flag_overrides(self):
        """Checks that a command-line override reaches the config dataclass.

        `sys.argv` is replaced while the flag is defined and parsed, and is
        restored in a `finally` block so a failure in either step cannot
        leave the modified argv behind for other tests.
        """
        # Set up some flag overrides.
        old_argv = list(sys.argv)
        sys.argv = shlex.split(
            './program foo.py --test_config.baseline_model.foo=99')
        try:
            flag_values = flags.FlagValues()

            # Define a config dataclass flag.
            test_config = config_flags.DEFINE_config_dataclass(
                'test_config', _CONFIG, flag_values=flag_values)

            # Inject the flag overrides.
            flag_values(sys.argv)
        finally:
            sys.argv = old_argv

        # Did the value get overridden?
        self.assertEqual(test_config.value.baseline_model.foo, 99)
示例#22
0
async def expand_cmd_str(cmd):
    """Render `cmd` as one string with flagfiles expanded and flags filtered.

    The leading tokens that identify the process (two when `is_python_cmd`
    says so, otherwise one) are kept as-is. The remaining arguments are
    expanded; flags named in MULTI_VALUE_FLAGS keep every value, all other
    flags keep only their last value. The flags are then filtered against
    those valid for the process and sorted. Positional arguments come last.
    Pieces are joined by two spaces.
    """
    prefix_len = 2 if is_python_cmd(cmd) else 1
    cmd = list(cmd)
    process, args = cmd[:prefix_len], cmd[prefix_len:]
    key = ' '.join(process)

    # The valid-flag lookup is cached per process prefix, guarded by the lock.
    async with flag_cache_lock:
        valid_flags = flag_cache.get(key)
        if valid_flags is None:
            valid_flags = mask_flags.extract_valid_flags(cmd)
            flag_cache[key] = valid_flags

    flag_args = {}
    position_args = []
    for arg in flags.FlagValues().read_flags_from_files(args):
        if not arg.startswith('--'):
            position_args.append(arg)
            continue
        name, sep, value = arg.partition('=')
        if not sep:
            # Flag given without '=': stored with a None value.
            flag_args[arg] = None
        elif name in MULTI_VALUE_FLAGS:
            flag_args.setdefault(name, []).append(value)
        else:
            flag_args[name] = value

    flag_list = []
    for name, value in flag_args.items():
        if isinstance(value, list):
            flag_list.extend('%s=%s' % (name, item) for item in value)
        elif value is None:
            flag_list.append(name)
        else:
            flag_list.append('%s=%s' % (name, value))

    flag_list = sorted(mask_flags.filter_flags(flag_list, valid_flags))
    return '  '.join(process + flag_list + position_args)
示例#23
0
    def test_serialize_roundtrip(self):
        """Round-trips string and multi-enum flags through flags_into_string().

        Three temporary flags are defined on the global FLAGS so that
        framework flags take part in the round trip too. They are removed in
        a `finally` block, so a failing assertion cannot leak them into later
        tests.
        """
        # Use the global 'FLAGS' as the source, to ensure all the framework
        # defined flags will go through the round trip process.
        flags.DEFINE_string('testflag', 'testval', 'help', flag_values=FLAGS)

        flags.DEFINE_multi_enum('test_multi_enum_flag', ['x', 'y'],
                                ['x', 'y', 'z'],
                                'Multi enum help.',
                                flag_values=FLAGS)

        class Fruit(enum.Enum):
            APPLE = 1
            ORANGE = 2
            TOMATO = 3

        flags.DEFINE_multi_enum_class('test_multi_enum_class_flag',
                                      ['APPLE', 'TOMATO'],
                                      Fruit,
                                      'Fruit help.',
                                      flag_values=FLAGS)

        try:
            new_flag_values = flags.FlagValues()
            new_flag_values.append_flag_values(FLAGS)

            FLAGS.testflag = 'roundtrip_me'
            FLAGS.test_multi_enum_flag = ['y', 'z']
            FLAGS.test_multi_enum_class_flag = [Fruit.ORANGE, Fruit.APPLE]
            argv = ['binary_name'] + FLAGS.flags_into_string().splitlines()

            self.assertNotEqual(new_flag_values['testflag'], FLAGS.testflag)
            self.assertNotEqual(new_flag_values['test_multi_enum_flag'],
                                FLAGS.test_multi_enum_flag)
            self.assertNotEqual(new_flag_values['test_multi_enum_class_flag'],
                                FLAGS.test_multi_enum_class_flag)
            new_flag_values(argv)
            self.assertEqual(new_flag_values.testflag, FLAGS.testflag)
            self.assertEqual(new_flag_values.test_multi_enum_flag,
                             FLAGS.test_multi_enum_flag)
            self.assertEqual(new_flag_values.test_multi_enum_class_flag,
                             FLAGS.test_multi_enum_class_flag)
        finally:
            # Always undo the global registrations, even on assertion failure.
            del FLAGS.testflag
            del FLAGS.test_multi_enum_flag
            del FLAGS.test_multi_enum_class_flag
示例#24
0
 def testReadAndWrite(self):
     """Checks flag state before, during and after an OverrideFlags block."""
     fv = flags.FlagValues()
     flags.DEFINE_integer('test_flag', 0, 'Test flag.', flag_values=fv)
     fv([sys.argv[0]])

     overrides = {'test_flag': 1}
     self.assertFlagState(fv, 0, False)
     self.assertEqual(overrides['test_flag'], 1)

     # Inside the context the override value is visible.
     with flag_util.OverrideFlags(fv, overrides):
         self.assertFlagState(fv, 1, True)
         self.assertEqual(overrides['test_flag'], 1)

     # After leaving the context the earlier state is expected again.
     self.assertFlagState(fv, 0, False)
     self.assertEqual(overrides['test_flag'], 1)

     # A direct write to the flag must not touch the overrides dict.
     fv.test_flag = 3
     self.assertFlagState(fv, 3, False)
     self.assertEqual(overrides['test_flag'], 1)
示例#25
0
def AddCmd(command_name, cmd_factory, **kwargs):
    """Create a command through `cmd_factory` and register it.

    Args:
      command_name: Name of the command, as used in argument parsing.
      cmd_factory: A callable whose arguments match those of Cmd.__init__ and
        that returns a Cmd. In the simplest case this is a subclass of Cmd.
        It is called with the command name, a fresh `flags.FlagValues()` and
        `**kwargs`.
      **kwargs: Extra keyword arguments forwarded to `cmd_factory`. The same
        keywords are also passed to `_AddCmdInstance` (to catch
        command_aliases).

    Raises:
      AppCommandsError: if calling cmd_factory does not return an instance of
        Cmd.
    """
    command_flags = flags.FlagValues()
    command = cmd_factory(command_name, command_flags, **kwargs)

    is_valid_command = isinstance(command, Cmd)
    if not is_valid_command:
        raise AppCommandsError('Command must be an instance of commands.Cmd')

    _AddCmdInstance(command_name, command, **kwargs)
示例#26
0
def prepare_subprocess_cmd(subprocess_cmd):
    """Expand flagfiles in a command and drop flags the target does not list.

    The command is first run with `--helpfull` appended; its output is parsed
    into the set of flags considered valid. Flagfiles in the arguments are then
    expanded and the result is passed through `filter_flags`.

    Args:
        subprocess_cmd: List[str], what would be passed into subprocess.call()
            i.e. ['python', 'train.py', '--flagfile=flags']

    Returns:
        List[str], ['python', 'train.py', '--train_flag=blah', '--more_flags']
    """
    completed = subprocess.run(subprocess_cmd + ['--helpfull'],
                               stdout=subprocess.PIPE)
    help_text = completed.stdout.decode('ascii')

    # Commands whose first token does not contain 'python' are parsed with
    # the FLAG_HELP_RE_CC pattern instead of the parser's default.
    if 'python' not in subprocess_cmd[0]:
        valid_flags = parse_helpfull_output(help_text, regex=FLAG_HELP_RE_CC)
    else:
        valid_flags = parse_helpfull_output(help_text)

    expanded = flags.FlagValues().read_flags_from_files(subprocess_cmd[1:])
    kept = filter_flags(expanded, valid_flags)
    return [subprocess_cmd[0]] + kept
示例#27
0
def extract_multi_instance(cmd):
    """Pull the multi-instance settings out of a command line.

    Flagfiles in `cmd` are expanded first. `--multi_instance`, `--num_games`
    and `--parallel_games` are then read from the expanded arguments; when a
    flag appears more than once the last occurrence wins.

    Args:
        cmd: Sequence of command-line arguments, possibly containing
            flagfile references.

    Returns:
        A tuple `(multi_instance, num_instance, new_cmd_list)`:
          * multi_instance: True when `--multi_instance=True` was given, or
            when the flag was given bare (no `=value`).
          * num_instance: `num_games // parallel_games` when multi_instance
            is set, otherwise 0.
          * new_cmd_list: the expanded arguments without `--multi_instance`,
            and also without `--num_games` when multi_instance is set.

    Raises:
        RuntimeError: if multi_instance is set and `parallel_games` is not
            positive or does not evenly divide `num_games`.
    """
    cmd_list = flags.FlagValues().read_flags_from_files(cmd)
    multi_instance = False
    num_games = 0
    parallel_games = 0

    for arg in cmd_list:
        flag, has_value, value = arg.partition('=')
        if flag == '--multi_instance':
            # A bare boolean flag means "true"; previously this raised
            # IndexError because there was no value part to inspect.
            multi_instance = (value == 'True') if has_value else True
        elif flag == '--num_games':
            num_games = int(value)
        elif flag == '--parallel_games':
            parallel_games = int(value)

    num_instance = 0
    if multi_instance:
        # Check for a non-positive divisor first so the caller gets the same
        # RuntimeError as for any other bad combination, rather than a
        # ZeroDivisionError from the modulo below.
        if parallel_games <= 0 or num_games % parallel_games != 0:
            logging.error('Error num_games must be multiply of %d',
                          parallel_games)
            raise RuntimeError(
                'incompatible num_games/parallel_games combination')
        num_instance = num_games // parallel_games

    # Flags consumed here are not forwarded to the instances.
    dropped = {'--multi_instance'}
    if multi_instance:
        dropped.add('--num_games')
    new_cmd_list = [
        arg for arg in cmd_list if arg.split('=', 1)[0] not in dropped
    ]

    return multi_instance, num_instance, new_cmd_list
def _parse_flags(command,
                 default=None,
                 config=None,
                 lock_config=True,
                 required=False):
    """Parses arguments simulating sys.argv.

    A 'test_config' flag is defined on a fresh `flags.FlagValues()`, either
    as a config-file flag (when `config` is None) or as a config-dict flag,
    and `command` is parsed against it.

    Args:
        command: Command line as a single string; split with `shlex.split`.
        default: Default passed to `DEFINE_config_file`. Must be None when
            `config` is given.
        config: Config passed to `DEFINE_config_dict`, or None to define a
            config-file flag instead.
        lock_config: Forwarded to whichever DEFINE_* function is used.
        required: If true, 'test_config' is marked as a required flag.

    Returns:
        The `FlagValues` object after parsing.

    Raises:
        ValueError: if both `config` and `default` are supplied.
    """

    if config is not None and default is not None:
        raise ValueError('If config is supplied a default should not be.')

    # Storing copy of the old sys.argv.
    old_argv = list(sys.argv)

    # Overwriting sys.argv, as sys has a global state it gets propagated.
    # The module shlex is useful here because it splits the input similar to
    # sys.argv. For instance, string arguments are not split by space.
    sys.argv = shlex.split(command)

    try:
        # Actual parsing.
        values = flags.FlagValues()
        if config is None:
            config_flags.DEFINE_config_file('test_config',
                                            default=default,
                                            flag_values=values,
                                            lock_config=lock_config)
        else:
            config_flags.DEFINE_config_dict('test_config',
                                            config=config,
                                            flag_values=values,
                                            lock_config=lock_config)

        if required:
            flags.mark_flag_as_required('test_config', flag_values=values)
        values(sys.argv)
    finally:
        # Going back to original values, even when defining or parsing the
        # flags raised, so the simulated argv never outlives this call.
        sys.argv = old_argv

    return values
示例#29
0
def expand_cmd_str(cmd):
    """Return `cmd` as one space-joined string with flagfiles expanded.

    For `mpiexec` / `mpirun` commands every ' -host ' separator is moved onto
    its own backslash-continued line. The result is cut to 8192 characters.
    """
    expanded = flags.FlagValues().read_flags_from_files(cmd)
    text = ' '.join(expanded)
    if cmd[0] in ('mpiexec', 'mpirun'):
        text = text.replace(' -host ', ' \\\n-host ')
    # Truncate so an oversized buffer cannot block I/O.
    return text[:8192]
示例#30
0
文件: utils.py 项目: mtyka/minigo-1
def expand_cmd_str(cmd):
    """Return `cmd` with flagfiles expanded, joined by two spaces."""
    expanded = flags.FlagValues().read_flags_from_files(cmd)
    return '  '.join(expanded)