def infer_language_pair(path):
    """Guess the language pair from the file names listed under ``path``.

    File names are expected to look like
    ``<split>.<lang1>-<lang2>.(...).idx``. The first listed name that has at
    least three dot-separated fields, and whose second field splits on ``-``
    into exactly two pieces, decides the result.

    Args:
        path: location handed to ``PathManager.ls``.

    Returns:
        The list ``[lang1, lang2]`` taken from the first matching name, or
        the tuple ``(None, None)`` when no listed name matches.
    """
    for name in PathManager.ls(path):
        fields = name.split(".")
        if len(fields) < 3:
            continue
        langs = fields[1].split("-")
        if len(langs) == 2:
            return langs
    # Nothing looked like "<split>.<lang1>-<lang2>.<...>".
    return None, None
def _find_extra_valid_paths(dataset_path: str) -> set:
    """Collect the root names of extra validation files.

    ``dataset_path`` is split into directories with ``utils.split_paths``;
    every directory is listed, entries matching the validation pattern are
    kept, reduced to their basename and stripped of their last extension.

    Args:
        dataset_path: path specification understood by ``utils.split_paths``.

    Returns:
        A set of basenames without their final extension.
    """
    # NOTE(review): this is applied with re.match (a regex anchored at the
    # start of the entry), not as a glob -- "d*" means "zero or more 'd'".
    # Confirm that this is the intended set of names.
    pattern = "valid*[0-9].*"
    roots = set()
    for directory in utils.split_paths(dataset_path):
        for entry in PathManager.ls(directory):
            if re.match(pattern, entry) is None:
                continue
            # Only the last extension (.bin, .idx, ...) is removed.
            root, _ext = os.path.splitext(os.path.basename(entry))
            roots.add(root)
    return roots
def _get_shard_num_dict(cls, split, paths):
    """Count shards per direction from the ``.idx`` files under ``paths``.

    Index files are expected to be named
    ``{split}.{src}-{tgt}.{lang}.idx``; the direction is the third field
    from the end once the name is split on ``.``.

    Args:
        split: prefix a file name must start with to be counted.
        paths: iterable of directories, each listed with ``PathManager.ls``.

    Returns:
        A plain dict mapping each direction to half the number of matching
        ``.idx`` files found across all of ``paths`` (floor division).

    Raises:
        IndexError: if a matching file name has fewer than three
            dot-separated fields.
    """
    idx_counts = defaultdict(int)
    for directory in paths:
        for name in PathManager.ls(directory):
            if not (name.startswith(split) and name.endswith('.idx')):
                continue
            idx_counts[name.split('.')[-3]] += 1
    # Every direction contributes two '.idx' files -- one for the source
    # language and one for the target language -- hence the halving.
    return {direction: count // 2 for direction, count in idx_counts.items()}
def _get_shard_num_dict(cls, split, paths):
    """Count, per direction, how many of ``paths`` contain index files.

    Index files are expected to be named
    ``{split}.{src}-{tgt}.{lang}.idx``; the direction is the third field
    from the end once the name is split on ``.``. Each directory adds at
    most one to a direction's count, however many matching files it holds.

    Args:
        split: prefix a file name must start with to be considered.
        paths: iterable of directories, each listed with ``PathManager.ls``.

    Returns:
        A ``defaultdict(int)`` mapping each direction to the number of
        directories in which it was seen.

    Raises:
        IndexError: if a matching file name has fewer than three
            dot-separated fields.
    """
    shard_counts = defaultdict(int)
    for directory in paths:
        # De-duplicate within one directory so it counts once per direction.
        seen_directions = {
            name.split(".")[-3]
            for name in PathManager.ls(directory)
            if name.startswith(split) and name.endswith(".idx")
        }
        for direction in seen_directions:
            shard_counts[direction] += 1
    return shard_counts
def last_n_checkpoints(paths, n, update_based, upper_bound=None):
    """Return the paths of the ``n`` highest-numbered checkpoint files.

    Args:
        paths: sequence holding exactly one directory.
        n: number of checkpoint files to return.
        update_based: if true, match ``checkpoint_<a>_<b>.pt`` and order by
            ``<b>``; otherwise match ``checkpoint<k>.pt`` and order by
            ``<k>``.
        upper_bound: when given, checkpoints whose number exceeds this
            value are ignored.

    Returns:
        A list of ``n`` paths (directory joined with file name), ordered by
        checkpoint number, highest first.

    Raises:
        AssertionError: if ``paths`` does not hold exactly one entry.
        Exception: if fewer than ``n`` eligible checkpoint files exist.
    """
    assert len(paths) == 1
    path = paths[0]
    if update_based:
        pt_regexp = re.compile(r'checkpoint_\d+_(\d+)\.pt')
    else:
        pt_regexp = re.compile(r'checkpoint(\d+)\.pt')

    entries = []
    for f in PathManager.ls(path):
        m = pt_regexp.fullmatch(f)
        if m is None:
            continue
        sort_key = int(m.group(1))
        if upper_bound is None or sort_key <= upper_bound:
            entries.append((sort_key, m.group(0)))

    if len(entries) < n:
        # The counts must be formatted into the message; passing them as
        # extra Exception arguments leaves the '{}' placeholders unfilled.
        raise Exception(
            'Found {} checkpoint files but need at least {}'.format(len(entries), n)
        )
    return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)[:n]]
def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt", keep_match=False):
    """Find every checkpoint file in the ``path`` directory.

    A file counts as a checkpoint when its whole name matches ``pattern``.
    If the pattern has at least one group, results are ordered by the first
    group (converted to float), highest first; otherwise they are ordered by
    their position in the directory listing, last first.

    Args:
        path: directory to list with ``PathManager.ls``.
        pattern: regular expression applied to each file name.
        keep_match: when true, return ``(full_path, sort_key)`` tuples
            instead of bare paths.

    Returns:
        A list of paths, or of ``(path, sort_key)`` tuples when
        ``keep_match`` is true, in descending sort-key order.
    """
    matcher = re.compile(pattern)
    found = []
    for position, name in enumerate(PathManager.ls(path)):
        hit = matcher.fullmatch(name)
        if hit is None:
            continue
        if len(hit.groups()) > 0:
            key = float(hit.group(1))
        else:
            # No capture group: fall back to the listing position.
            key = position
        found.append((key, hit.group(0)))

    found.sort(reverse=True)
    if keep_match:
        return [(os.path.join(path, name), key) for key, name in found]
    return [os.path.join(path, name) for key, name in found]