def ValidateMeasurementsFlag(options_list):
    """Verifies correct usage of the measurements configuration flag.

    The user of the flag must provide at least one option. All provided
    options must be valid. The NONE option cannot be combined with other
    options.

    Args:
      options_list: A list of strings parsed from the provided value for the
        flag.

    Returns:
      True if the list of options provided as the value for the flag meets all
      the documented requirements.

    Raises:
      flags.ValidationError: If the list of options provided as the value for
        the flag does not meet the documented requirements.
    """
    # The "at least one option" requirement is not covered by the loop below:
    # an empty list would skip the loop entirely and be accepted.
    if not options_list:
        raise flags.ValidationError(
            '--%s: At least one option must be provided' %
            MEASUREMENTS_FLAG_NAME)
    for option in options_list:
        if option not in MEASUREMENTS_ALL:
            raise flags.ValidationError('%s: Invalid value for --%s' %
                                        (option, MEASUREMENTS_FLAG_NAME))
        # NONE is only meaningful on its own.
        if option == MEASUREMENTS_NONE and len(options_list) != 1:
            raise flags.ValidationError(
                '%s: Cannot combine with other --%s options' %
                (option, MEASUREMENTS_FLAG_NAME))
    return True
# Example #2
# 0
def validate_multiparam_str(param_str: str):
    """Checks a linspace/logspace parameter specification.

    The expected form is "NAME,START,STOP,NUM": START and STOP must parse as
    floats with START <= STOP, and NUM must parse as an int of at least 1.

    Args:
        param_str: The `str` value of the parameter to test.

    Returns:
        True if the parameter setting is valid.

    Raises:
        flags.ValidationError: The parameter setting is invalid.
    """
    fields = param_str.split(",")
    if len(fields) != 4:
        raise flags.ValidationError(
            'Array param value must be "NAME,START,STOP,NUM"')

    # NAME is not inspected; only the numeric fields are validated.
    _, start_text, stop_text, num_text = fields
    try:
        start, stop, num = float(start_text), float(stop_text), int(num_text)
    except ValueError as err:
        raise flags.ValidationError(
            'Array param "NAME,START,STOP,NUM" must have float'
            " START and STOP and int NUM") from err

    if stop < start:
        raise flags.ValidationError(
            'Array param "NAME,START,STOP,NUM" must have START <= STOP')
    if num <= 0:
        raise flags.ValidationError(
            'Array param "NAME,START,STOP,NUM" must have NUM > 0')
    return True
# Example #3
# 0
 def _check_fp16_implementation(flags_dict):
     """Validator to check fp16_implementation flag is valid."""
     if (flags_dict["fp16_implementation"] == "graph_rewrite" and
             flags_dict["dtype"] != "fp16"):
         raise flags.ValidationError("--fp16_implementation should not be "
                                     "specified unless --dtype=fp16")
     if (flags_dict["fp16_implementation"] != "graph_rewrite" and
             flags_dict["loss_scale"] == "dynamic"):
         raise flags.ValidationError("--loss_scale=dynamic is only supported "
                                     "when "
                                     "--fp16_implementation=graph_rewrite")
     return True
# Example #4
# 0
 def _check_fp16_implementation(flags_dict):
   """Validator to check fp16_implementation flag is valid."""
   if (flags_dict['fp16_implementation'] == 'graph_rewrite' and
       flags_dict['dtype'] != 'fp16'):
     raise flags.ValidationError('--fp16_implementation should not be '
                                 'specified unless --dtype=fp16')
   if (flags_dict['fp16_implementation'] != 'graph_rewrite' and
       flags_dict['loss_scale'] == 'dynamic'):
     raise flags.ValidationError('--loss_scale=dynamic is only supported '
                                 'when '
                                 '--fp16_implementation=graph_rewrite')
   return True
# Example #5
# 0
 def _check_fp16_implementation(flags_dict):
   """Validator to check fp16_implementation flag is valid."""
   if (flags_dict["fp16_implementation"] == "graph_rewrite" and
       flags_dict["dtype"] != "fp16"):
     raise flags.ValidationError("--fp16_implementation should not be "
                                 "specified unless --dtype=fp16")
   return True
# Example #6
# 0
def _check_block_size(flag_value):
    """Validates the block size flag value.

    Args:
      flag_value: The raw flag value, handed to `_parse_block_size_flag`.

    Returns:
      True if `_parse_block_size_flag` accepts the value without raising.

    Raises:
      flags.ValidationError: If `_parse_block_size_flag` raises an exception.
    """
    try:
        _parse_block_size_flag(flag_value)
    except Exception as e:
        # Catch Exception rather than using a bare `except:` so that
        # KeyboardInterrupt / SystemExit propagate instead of being reported
        # as an invalid flag. The catch stays broad because the parser's
        # failure modes are not visible here.
        raise flags.ValidationError('Invalid block size value "%s".' %
                                    flag_value) from e
    return True
def _config_file_validator(config_file_path):
    """Checks that the config yaml file path is usable.

    Args:
      config_file_path: str, the name or the full path of the config file.

    Returns:
      True when the config_file_path is considered valid.

    Raises:
      flags.ValidationError: if config_file_path does not end in '.yaml', or
        if the path produced by `_get_config_file_path` is not an existing
        file.
    """
    if config_file_path.endswith('.yaml'):
        resolved_path = _get_config_file_path(config_file_path)
        if os.path.isfile(resolved_path):
            return True
        raise flags.ValidationError('the config file specified is not found')
    raise flags.ValidationError('the config file must end in .yaml')
# Example #8
# 0
 def validate_mutual_exclusion(flags_dict):
     """Checks the root_dir / experiment_name / train_eval_dir combination.

     Accepted combinations are: root_dir and experiment_name both set with
     train_eval_dir unset, or train_eval_dir set with root_dir and
     experiment_name both unset ("set" meaning not None).

     Returns:
       True for an accepted combination.

     Raises:
       flags.ValidationError: for any other combination.
     """
     if flags_dict['root_dir'] is not None:
         accepted = (flags_dict['experiment_name'] is not None
                     and flags_dict['train_eval_dir'] is None)
     else:
         accepted = (flags_dict['experiment_name'] is None
                     and flags_dict['train_eval_dir'] is not None)
     if accepted:
         return True
     raise flags.ValidationError(
         'Exactly both root_dir and experiment_name or only '
         'train_eval_dir must be specified.')
  def _validate_schedule(flag_values):
    """Validates the --schedule flag and the flags it interacts with."""
    schedule = flag_values["schedule"]
    save_checkpoints_steps = flag_values["save_checkpoints_steps"]
    save_checkpoints_secs = flag_values["save_checkpoints_secs"]

    if schedule in ["train", "train_and_eval"]:
      if not (save_checkpoints_steps or save_checkpoints_secs):
        raise flags.ValidationError(
            "--schedule='%s' requires --save_checkpoints_steps or "
            "--save_checkpoints_secs." % schedule)

    return True
# Example #10
# 0
def validate_param_list_str(param_str: str):
    """Checks an explicit parameter list specification.

    The expected form is "NAME,VAL1,VAL2[,VAL3,...]". Only the number of
    comma-separated fields is checked (at least three); field contents are
    not inspected.

    Args:
        param_str: The `str` value of the parameter to test.

    Returns:
        True if the parameter setting is valid.

    Raises:
        flags.ValidationError: The parameter setting is invalid.
    """
    # N commas always yield N + 1 fields, so two commas means three fields.
    if param_str.count(",") >= 2:
        return True
    raise flags.ValidationError(
        'param_list value must be "NAME,VAL1,VAL2[,VAL3,...]"')
# Example #11
# 0
def _AppServerValidator(app_servers):
  """Validate the App Engine servers are of the right format.

  Args:
    app_servers: list|str, a list of strings defining the Google Cloud Project
        IDs by friendly name.

  Returns:
    True if the app_servers are in the correct format.

  Raises:
    flags.ValidationError: if `_ParseAppServers` raises ValueError for
        app_servers.
  """
  try:
    _ParseAppServers(app_servers)
  except ValueError as e:
    # Chain explicitly so the parser's ValueError is reported as the direct
    # cause of the validation failure.
    raise flags.ValidationError(_APP_SERVER_ERROR.format(app_servers)) from e
  return True
def ValidateVmMetadataFlag(options_list):
  """Checks that every vm metadata option is written as key:value.

  An option is accepted only if it contains a ':' at some position other
  than its first or last character.

  Args:
    options_list: A list of strings parsed from the provided value for the
      flag.

  Returns:
    True if every option in options_list passes the check.

  Raises:
    flags.ValidationError: For the first option that fails the check.
  """
  for entry in options_list:
    # Dropping the first and last characters means a leading or trailing
    # colon alone does not count as a separator.
    inner = entry[1:-1]
    if ':' in inner:
      continue
    raise flags.ValidationError(
        '"%s" not in expected key:value format' % entry)
  return True
# Example #13
# 0
def _CheckWindowScale(window_scale):
    if window_scale < 10 or window_scale > 300:
        raise flags.ValidationError('Scale factor outside range: 10-300')
    return True
# Example #14
# 0
def main(argv):
    """Builds sprite atlases from the image sources listed by --sourcelist.

    Checks a few flags, converts every listed image in parallel with the
    configured conversion settings, and assembles the converted images into
    atlases. It then writes the atlases and their manifests to --output_dir,
    or to the current working directory when that flag is unset.

    Args:
      argv: Unused command-line arguments.

    Raises:
      NotImplementedError: If --max_failures is a positive number.
      flags.ValidationError: If --sourcelist is not set.
      Exception: Whatever `atlasmaker_io.save_image` raises when the test
        image in the requested format cannot be written (logged, then
        re-raised).
    """
    del argv  # Unused.
    # TODO: Add more flag validations.
    if FLAGS.max_failures is not None and FLAGS.max_failures > 0:
        raise NotImplementedError(
            'Does not yet handle image retrieval/conversion '
            'failures')

    if FLAGS.atlas_width is not None or FLAGS.atlas_height is not None:
        logging.info('Requested atlas size (width, height): (%s, %s)',
                     FLAGS.atlas_width, FLAGS.atlas_height)

    if FLAGS.sourcelist is None:
        raise flags.ValidationError(
            'You must specify a list of image sources.')

    bg_color_rgb = _determine_bg_rgb()

    # Resolve the output directory once and use it for every write below.
    # Using FLAGS.output_dir directly would pass None to the I/O helpers
    # whenever the flag is unset.
    outputdir = FLAGS.output_dir
    if outputdir is None:
        outputdir = os.getcwd()

    image_source_list = atlasmaker_io.read_src_list_csvfile(
        FLAGS.sourcelist, FLAGS.sourcelist_dups_handling)

    # Provide some useful confirmation info about settings to user.
    logging.info(
        'Desired output size in pixels width, height for each image is: '
        '(%d, %d)', FLAGS.image_width, FLAGS.image_height)
    logging.info('Image format for Atlas is: %s', FLAGS.image_format)
    logging.info('Background RGB is set to %s', bg_color_rgb)
    logging.info('Background opacity is set to %d', FLAGS.image_opacity)
    logging.info(
        'Should we preserve image aspect ratio during conversion? %s',
        FLAGS.keep_aspect_ratio)

    image_convert_settings = convert.ImageConvertSettings(
        img_format=FLAGS.image_format,
        width=FLAGS.image_width,
        height=FLAGS.image_height,
        bg_color_rgb=bg_color_rgb,
        opacity=FLAGS.image_opacity,
        preserve_aspect_ratio=FLAGS.keep_aspect_ratio,
        resize_if_larger=FLAGS.resize_if_larger)

    # Ensure we can write to the output dir or fail fast.
    atlasmaker_io.create_output_dir_if_not_exist(outputdir)

    # Create default image to be used for images that we can't get or convert.
    if FLAGS.default_image_path is not None:
        logging.info('Using image %s as default image when a specified image '
                     'can\'t be fetched or converted',
                     FLAGS.default_image_path)
        default_img = parallelize.convert_default_image(
            FLAGS.default_image_path, image_convert_settings)
    else:
        logging.info(
            'No default image for failures specified by user, so just '
            'using the background as the default image.')
        default_img = convert.create_default_image(image_convert_settings)

    # Verify we can write the specified output format, or fail fast.
    testimage_file_name = '{}.{}'.format('testimage',
                                         str(FLAGS.image_format).lower())
    try:
        atlasmaker_io.save_image(default_img,
                                 os.path.join(outputdir, testimage_file_name),
                                 delete_after_write=True)
    except Exception:
        # Log a hint about the likely cause, then let the original error
        # propagate unchanged.
        logging.error('Unable to write test image in desired output format. '
                      'Please confirm that \'%s\' is a supported PIL output '
                      'format.', FLAGS.image_format)
        raise
    logging.info('Confirmed we can output images in %s format',
                 FLAGS.image_format)

    # Convert images in parallel.
    logging.info('Scheduling %d tasks.', len(image_source_list))
    converted_images_with_statuses = parallelize.get_and_convert_images_parallel(
        image_source_list,
        image_convert_settings,
        n_jobs=FLAGS.num_parallel_jobs,
        verbose=FLAGS.parallelization_verbosity,
        allow_truncated_images=FLAGS.use_truncated_images,
        request_timeout=FLAGS.http_request_timeout,
        http_max_retries=FLAGS.http_max_retries)

    sprite_atlas_settings = montage.SpriteAtlasSettings(
        img_format=FLAGS.image_format,
        height=FLAGS.atlas_height,
        width=FLAGS.atlas_width,
        filename=FLAGS.filename)

    # Generate the atlas from converted images.
    sprite_atlas_generator = montage.SpriteAtlasGenerator(
        images_with_statuses=converted_images_with_statuses,
        img_src_paths=image_source_list,
        atlas_settings=sprite_atlas_settings,
        default_img=default_img)

    atlases, manifests = sprite_atlas_generator.create_atlas()

    atlasmaker_io.save_atlas_and_manifests(
        outdir=outputdir,
        atlases=atlases,
        manifests=manifests,
        sprite_atlas_settings=sprite_atlas_settings)