Example #1
def start(args, cfg):
    # suppress Python warnings in non-main processes (rank > 0) so duplicate warnings are not printed to the console
    if dist_utils.get_rank() > 0:
        warnings.filterwarnings("ignore")

    logger = create_logger(args)
    model_save_dir = os.path.join(ModelPaths.checkpoint_base_dir(), cfg.MODE,
                                  args.model_dir)

    if dist_utils.is_main_process():
        os.makedirs(model_save_dir, exist_ok=True)

    # check whether a checkpoint already exists in the model save directory. If it does and the 'no_resume' flag
    # is not set, training resumes from the most recent existing checkpoint.
    existing_ckpts = sorted(glob(os.path.join(model_save_dir, "*.pth")))
    if existing_ckpts and not args.no_resume:
        args.restore_session = existing_ckpts[-1]
        args.initial_ckpt = None  # this may be set when jobs auto-restart on the cluster,
        # but we want to resume from the latest checkpoint instead

    # backup config to model directory
    if dist_utils.is_main_process():
        with open(os.path.join(model_save_dir, 'config.yaml'),
                  'w') as writefile:
            yaml.dump(global_cfg.d(), writefile)

    trainer = Trainer(cfg, model_save_dir, args, logger)

    try:
        trainer.start(args)
    except InterruptException:
        if dist_utils.is_main_process():
            print("Interrupt signal received. Saving checkpoint...")
            trainer.backup_session()
            dist_utils.synchronize()
        exit(1)
    except Exception as err:
        if dist_utils.is_main_process():
            print("Exception occurred. Saving checkpoint...")
            print(err)
            trainer.backup_session()
            if dist_utils.is_distributed():
                dist.destroy_process_group()
        raise err
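
For context, here is a minimal sketch of how start() might be wired into an entry point. The argparse-based CLI below is hypothetical: only the attributes that start() and create_logger() actually read (model_dir, no_resume, restore_session, initial_ckpt, log_level, subprocess_log_level) are defined, and global_cfg is assumed to be the project-wide config object already referenced inside start().

import argparse

def main():
    # hypothetical CLI; only the attributes read by start() and create_logger() are defined here
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_dir", required=True)
    parser.add_argument("--no_resume", action="store_true")
    parser.add_argument("--restore_session", default=None)
    parser.add_argument("--initial_ckpt", default=None)
    parser.add_argument("--log_level", default="INFO")
    parser.add_argument("--subprocess_log_level", default="ERROR")
    args = parser.parse_args()

    # global_cfg: assumed project-wide config; which section of it is passed as `cfg`
    # (the object exposing MODE, MAX_ITERATIONS, ...) depends on the project
    start(args, global_cfg)

if __name__ == "__main__":
    main()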
Example #2
    def __init__(self, cfg, model_save_dir, args, logger):
        self.num_gpus = dist_utils.get_world_size()
        self.local_rank = dist_utils.get_rank()
        self.local_device = dist_utils.get_device()
        self.is_main_process = dist_utils.is_main_process()

        self.console_logger = logger

        self.model_save_dir = model_save_dir
        self.log_dir = os.path.join(self.model_save_dir, 'logs')

        if self.is_main_process:
            os.makedirs(self.log_dir, exist_ok=True)

        self.model = build_model(restore_pretrained_backbone_wts=True, logger=self.console_logger).to(self.local_device)

        # create optimizer
        self.optimizer = create_optimizer(self.model, cfg, self.console_logger.info)

        # wrap model and optimizer around apex if mixed precision training is enabled
        if cfg.MIXED_PRECISION:
            assert APEX_IMPORTED
            self.console_logger.info("Mixed precision training is enabled.")
            self.model, self.optimizer = amp.initialize(
                self.model, self.optimizer, opt_level=cfg.MIXED_PRECISION_OPT_LEVEL)

        if dist_utils.is_distributed():
            self.model = nn.parallel.DistributedDataParallel(
                self.model, device_ids=[self.local_rank], output_device=self.local_rank,
                find_unused_parameters=cfg.FREEZE_BACKBONE
            )

        self.total_iterations = cfg.MAX_ITERATIONS

        # create LR scheduler
        self.lr_scheduler = create_lr_scheduler(self.optimizer, cfg, self.console_logger.info)

        # create parameter logger
        self.logger = None
        if self.is_main_process:
            self.logger = TrainingLogger(self.log_dir)

        self.interrupt_detector = InterruptDetector()
        self.cfg = cfg

        self.elapsed_iterations = 0

        assert not (args.restore_session and args.initial_ckpt)

        if args.restore_session:
            self.console_logger.info("Restoring session from {}".format(args.restore_session))
            self.restore_session(torch.load(args.restore_session, map_location=self.local_device))
        elif args.initial_ckpt:
            self.console_logger.info("Loading model weights from checkpoint at: {}".format(args.initial_ckpt))
            self._model.load_state_dict(torch.load(args.initial_ckpt, map_location=self.local_device)['model'])
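
Note that the last line loads weights through self._model rather than self.model. The excerpt does not show that attribute, but a common pattern (and presumably what it resolves to here) is a small property that unwraps the DistributedDataParallel wrapper, so that state_dict keys match a plain checkpoint whether or not DDP is active. A sketch of such an accessor, offered as an assumption rather than the repository's actual code:

    @property
    def _model(self):
        # unwrap DistributedDataParallel so loading/saving operates on the underlying module
        if isinstance(self.model, nn.parallel.DistributedDataParallel):
            return self.model.module
        return self.model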
Example #3
def create_logger(args):
    logger = logging.getLogger("MaskTCNNTrainLogger")
    if dist_utils.is_main_process():
        logger.setLevel(args.log_level)
    else:
        logger.setLevel(args.subprocess_log_level)

    ch = logging.StreamHandler()
    formatter = logging.Formatter("[%(proc_id)d] %(asctime)s - %(levelname)s - %(message)s", "%H:%M:%S")
    extra = {"proc_id": dist_utils.get_rank()}
    ch.setFormatter(formatter)
    logger.addHandler(ch)
    logger.propagate = False

    logger = logging.LoggerAdapter(logger, extra)

    return logger
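
A minimal usage sketch of create_logger(). The Namespace below is a stand-in for the parsed CLI arguments; the attribute names are taken from the call sites above and the level values are only illustrative:

from argparse import Namespace

args = Namespace(log_level="INFO", subprocess_log_level="ERROR")
logger = create_logger(args)
logger.info("training started")
# on rank 0 this prints something like:
# [0] 12:34:56 - INFO - training started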