コード例 #1
0
ファイル: fb_train.py プロジェクト: stas00/deep-shallow
def fb_main(device_id, args, start_rank, log_path=None):
    """[FB] entry point for each worker process.

    Assigns this worker its global rank and sends ``fairseq`` / ``fairseq_cli``
    log records to stdout (and, on rank 0 only, also to ``log_path``). It then
    registers the Manifold path handler for checkpoint I/O and runs training.

    Args:
        device_id: Index of this worker. It is added to ``start_rank`` to form
            ``args.distributed_rank`` and is forwarded to ``distributed_main``.
        args: Parsed training arguments. ``args.distributed_rank`` is set in
            place, and ``args.distributed_world_size`` selects distributed
            vs. single-process training.
        start_rank: Rank offset added to ``device_id``.
        log_path: Optional log file path. Only rank 0 writes to it. Missing
            parent directories are created, and a bare file name (no
            directory part) is accepted.
    """
    args.distributed_rank = start_rank + device_id

    # One formatter shared by every handler installed below.
    formatter = logging.Formatter(
        fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    def add_handler(handler):
        """Set ``handler`` to INFO level with the shared format, then attach
        it to the ``fairseq`` and ``fairseq_cli`` loggers."""
        handler.setLevel(logging.INFO)
        handler.setFormatter(formatter)
        for root in ("fairseq", "fairseq_cli"):
            logger = logging.getLogger(root)
            logger.propagate = False  # don't propagate to parent loggers
            logger.addHandler(handler)

    # write fairseq logs to stdout
    add_handler(logging.StreamHandler(sys.stdout))

    # support Manifold for checkpoints
    # For latte_training use case, we have separate NMTManifoldPathHandler registered in
    # https://fburl.com/wurd7t70. So if parameters need to be updated the right place
    # is ~/fbsource/fbcode/fblearner/flow/projects/fairseq/latte_training/manifold_file_io.py
    PathManager.register_handler(ManifoldPathHandler(max_parallel=16, timeout_sec=1800))

    if log_path is not None and args.distributed_rank == 0:
        # Only worker 0 mirrors its logs to log_path.
        log_file = Path(log_path)
        # Path.parent is "." for a bare file name. The previous
        # os.makedirs(os.path.dirname(log_path)) raised on dirname == "".
        log_file.parent.mkdir(parents=True, exist_ok=True)
        # The mode only applies if the file is newly created, and it is
        # combined with the process umask.
        log_file.touch(0o777, exist_ok=True)
        add_handler(logging.FileHandler(log_path))

    if args.distributed_world_size > 1:
        distributed_main(device_id, args)
    else:
        main(args)
コード例 #2
0
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging

from fairseq.average_checkpoints import main
from fairseq.file_io import PathManager

# Register the FB-specific Manifold path handler when it is available.
# This is best-effort: the script must still run without it.
try:
    from fvcore.fb.manifold import ManifoldPathHandler
    PathManager.register_handler(
        ManifoldPathHandler(max_parallel=16, timeout_sec=1800))
except ImportError:
    # Presumably fvcore.fb exists only in internal builds, so a missing
    # module is the expected case elsewhere and is not reported.
    pass
except Exception:
    # Any other failure is unexpected. Keep running, but surface it instead
    # of silently discarding it.
    logging.getLogger(__name__).warning(
        "Failed to register ManifoldPathHandler; continuing without it.",
        exc_info=True,
    )

if __name__ == '__main__':
    main()