Ejemplo n.º 1
0
def infer_language_pair(path):
    """Infer language pair from filename: <split>.<lang1>-<lang2>.(...).idx

    Args:
        path: directory to scan for data files.

    Returns:
        Tuple ``(src, dst)`` of language codes, or ``(None, None)`` if no
        filename in ``path`` matches the expected pattern.
    """
    src, dst = None, None
    for filename in PathManager.ls(path):
        parts = filename.split(".")
        if len(parts) >= 3 and len(parts[1].split("-")) == 2:
            # Unpack into a tuple for a consistent return type: the original
            # returned a list here but a tuple on the fall-through path.
            src, dst = parts[1].split("-")
            return src, dst
    return src, dst
Ejemplo n.º 2
0
def _find_extra_valid_paths(dataset_path: str) -> set:
    """Collect file roots of extra validation sets found under ``dataset_path``.

    Every sub-directory in ``dataset_path`` is listed and entries matching
    the ``valid*[0-9].*`` pattern are gathered; the extension (.bin, .idx,
    etc.) is stripped so each validation set appears once.
    """
    basenames = set()
    for sub_dir in utils.split_paths(dataset_path):
        for entry in PathManager.ls(sub_dir):
            if re.match("valid*[0-9].*", entry) is not None:
                basenames.add(os.path.basename(entry))
    # Drop the trailing extension (.bin, .idx, ...) to keep only file roots.
    return {os.path.splitext(name)[0] for name in basenames}
Ejemplo n.º 3
0
 def _get_shard_num_dict(cls, split, paths):
     """Count data shards per translation direction across ``paths``.

     Each shard of a direction contributes two ``.idx`` files (one for the
     source language, one for the target), so the raw file count per
     direction is halved before returning.
     """
     counts = defaultdict(int)
     for directory in paths:
         for fname in PathManager.ls(directory):
             # idx files of the form "{split}.{src}-{tgt}.{lang}.idx"
             if fname.startswith(split) and fname.endswith('.idx'):
                 counts[fname.split('.')[-3]] += 1
     # Two '.idx' files (source + target language) per shard, hence // 2.
     return {direction: total // 2 for direction, total in counts.items()}
Ejemplo n.º 4
0
 def _get_shard_num_dict(cls, split, paths):
     """Count, per translation direction, how many directories in ``paths``
     contain at least one matching ``.idx`` file for ``split``.
     """
     counts = defaultdict(int)
     for directory in paths:
         seen = set()
         for fname in PathManager.ls(directory):
             if fname.startswith(split) and fname.endswith(".idx"):
                 # idx files of the form "{split}.{src}-{tgt}.{lang}.idx"
                 seen.add(fname.split(".")[-3])
         # Each directory counts once per direction it contains.
         for direction in seen:
             counts[direction] += 1
     return counts
Ejemplo n.º 5
0
def last_n_checkpoints(paths, n, update_based, upper_bound=None):
    """Return the ``n`` most recent checkpoint paths in a directory.

    Args:
        paths: single-element list containing the checkpoint directory.
        n: number of checkpoints required.
        update_based: if True, match "checkpoint_<epoch>_<update>.pt" and
            sort by update count; otherwise match "checkpoint<epoch>.pt".
        upper_bound: optional inclusive cap on the sort key.

    Returns:
        List of ``n`` absolute checkpoint paths, newest first.

    Raises:
        Exception: if fewer than ``n`` matching checkpoints are found.
    """
    assert len(paths) == 1
    path = paths[0]
    if update_based:
        pt_regexp = re.compile(r'checkpoint_\d+_(\d+)\.pt')
    else:
        pt_regexp = re.compile(r'checkpoint(\d+)\.pt')
    files = PathManager.ls(path)

    entries = []
    for f in files:
        m = pt_regexp.fullmatch(f)
        if m is not None:
            sort_key = int(m.group(1))
            if upper_bound is None or sort_key <= upper_bound:
                entries.append((sort_key, m.group(0)))
    if len(entries) < n:
        # Bug fix: the original passed the template and values as separate
        # Exception args, so the message was never formatted.
        raise Exception(
            'Found {} checkpoint files but need at least {}'.format(len(entries), n)
        )
    return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)[:n]]
Ejemplo n.º 6
0
def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt", keep_match=False):
    """Retrieves all checkpoints found in `path` directory.

    Checkpoints are identified by matching filename to the specified pattern. If
    the pattern contains groups, the result will be sorted by the first group in
    descending order.
    """
    matcher = re.compile(pattern)
    entries = []
    for position, fname in enumerate(PathManager.ls(path)):
        match = matcher.fullmatch(fname)
        if match is None:
            continue
        # Sort by the first capture group when one exists, otherwise fall
        # back to the file's position in the directory listing.
        key = float(match.group(1)) if match.groups() else position
        entries.append((key, match.group(0)))
    ordered = sorted(entries, reverse=True)
    if keep_match:
        return [(os.path.join(path, name), key) for key, name in ordered]
    return [os.path.join(path, name) for key, name in ordered]