示例#1
0
def _create_gcs_file(path, mode, content=None, size=None, cleanup=None):
  """Write generated content to `path` via gcs and optionally schedule removal.

  Args:
    path: destination passed to gcs.open (and to gcs.remove on cleanup).
    mode: open mode; also forwarded to _create_content.
    content: optional explicit content, forwarded to _create_content.
    size: optional size, forwarded to _create_content.
    cleanup: optional list; when given, a zero-argument callable that
      removes `path` is appended to it.

  Returns:
    The content produced by _create_content and written to the file.
  """
  data = _create_content(mode, content=content, size=size)
  with gcs.open(path, mode=mode) as handle:
    handle.write(data)

  if cleanup is None:
    return data

  def _remove_created_file():
    return gcs.remove(path)

  cleanup.append(_remove_created_file)
  return data
示例#2
0
def load_checkpoint_to_cpu(path, arg_overrides=None):
    """Loads a checkpoint to CPU (with upgrading for backward compatibility).

    The checkpoint is read through ``fairseq.fb_pathmgr`` when that module can
    be imported. Otherwise, paths starting with ``gcsfs.CLOUD_STORAGE_PREFIX``
    are opened with ``gcsfs.open``, and anything else is handed directly to
    ``torch.load`` as a local path. Every storage is restored onto the CPU
    regardless of the device it was saved from.

    Args:
        path: checkpoint location.
        arg_overrides: optional mapping; each ``name -> value`` item is set as
            an attribute on ``state['args']`` before upgrading.

    Returns:
        The loaded state after ``_upgrade_state_dict`` has been applied.

    Raises:
        KeyError: presumably, when the loaded state has no ``'args'`` entry
            (assumes the saved state is a dict -- confirm against the
            checkpoint format).
    """
    def _to_cpu(storage, location):
        # Ignore the saved device location; restore every storage on CPU.
        return default_restore_location(storage, 'cpu')

    # Guard only the import. Previously the whole load sat inside the try, so
    # an ImportError raised while unpickling the checkpoint was mistaken for
    # "path manager not available" and silently retried as a local file,
    # hiding the real error. (ModuleNotFoundError subclasses ImportError.)
    try:
        from fairseq.fb_pathmgr import fb_pathmgr
    except ImportError:
        fb_pathmgr = None

    if fb_pathmgr is not None:
        with fb_pathmgr.open(path, "rb") as f:
            state = torch.load(f, map_location=_to_cpu)
    elif path.startswith(gcsfs.CLOUD_STORAGE_PREFIX):
        # Path manager not found: cloud-storage paths go through gcsfs.
        with gcsfs.open(path, 'rb') as fd:
            state = torch.load(fd, map_location=_to_cpu)
    else:
        # Path manager not found: treat as a local file.
        state = torch.load(path, map_location=_to_cpu)

    # Overrides mutate the stored args object in place before upgrading.
    args = state['args']
    if arg_overrides is not None:
        for arg_name, arg_val in arg_overrides.items():
            setattr(args, arg_name, arg_val)
    return _upgrade_state_dict(state)
示例#3
0
  metrics, sp_return_code = _run_subprocess(FLAGS.positional)
  if sp_return_code:
    raise AssertionError(
      'Child process had non-zero exit code: {}'.format(sp_return_code))

  # Retrieve any config files that affect this test.
  # NOTE: these are ordered in increasing specificity. For example, if there
  # was a base config that affects all tests and a specific config for a
  # particular test, then the base config will be the first element in the
  # list and the most specific config will be the last element.
  ordered_config_dicts = []

  path_to_search = FLAGS.test_folder_name
  while True:
    try:
      f = gcsfs.open(os.path.join(FLAGS.root, path_to_search, 'config.json'))
      ordered_config_dicts.insert(0, json.load(f))
    except google.api_core.exceptions.NotFound:
      pass
    if not path_to_search:
      break
    path_to_search = os.path.split(path_to_search)[0]
  if not ordered_config_dicts:
    raise ValueError('No config files found in {} or parent directories. '
        'See example usage at top of metrics_test_wrapper.py'.format(
            os.path.join(FLAGS.root, FLAGS.test_folder_name)))

  # Consolidate configs into 1 dict by overwriting the least-specific configs
  # with the increasingly more-specific configs.
  config = ordered_config_dicts[0]
  for c in ordered_config_dicts:
示例#4
0
 def thread_fn():
     """Read ``args.gsfile`` ``args.test_count`` times; return elapsed seconds.

     Each pass opens the file with ``gs.open``, reads it fully, and checks
     that the number of items read equals ``gblob.size``. ``args``, ``gs``
     and ``gblob`` are free names resolved outside this function.

     Raises:
         AssertionError: if a read's length differs from ``gblob.size``.
     """
     # perf_counter is monotonic, unlike time.time(), so wall-clock
     # adjustments cannot distort the measured interval.
     start = time.perf_counter()
     for _ in range(args.test_count):
         with gs.open(args.gsfile) as fd:
             # Keep the read out of any assert statement: under `python -O`
             # asserts are stripped, which would skip the read being timed.
             data = fd.read()
         if len(data) != gblob.size:
             raise AssertionError(
                 'read {} from {}, expected {}'.format(
                     len(data), args.gsfile, gblob.size))
     return time.perf_counter() - start