# NOTE: assumes the usual maskrcnn_benchmark-style helpers are importable in this
# module (build_detection_model, make_optimizer, make_lr_scheduler, make_data_loader,
# DetectronCheckpointer, get_rank, do_train, last_checkpoint), along with torch
# and a TensorBoard SummaryWriter.
def train(cfg, local_rank, distributed):
    model = build_detection_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
            # broadcast_buffers=False should be removed if BatchNorm statistics
            # are meant to be updated and synchronized across processes
            broadcast_buffers=False,
        )

    # mutable training state; checkpoints restore the iteration counter here
    arguments = {"iteration": 0}

    output_dir = cfg.OUTPUT_DIR

    summary_writer = SummaryWriter(log_dir=output_dir)
    save_to_disk = get_rank() == 0
    checkpointer = DetectronCheckpointer(cfg, model, optimizer, scheduler,
                                         output_dir, save_to_disk)

    # 'CONTINUE' resumes from the most recent checkpoint in the output directory;
    # otherwise cfg.MODEL.WEIGHT is treated as an explicit weights file
    if cfg.MODEL.WEIGHT.upper() == 'CONTINUE':
        model_weight = last_checkpoint(output_dir)
    else:
        model_weight = cfg.MODEL.WEIGHT
    extra_checkpoint_data = checkpointer.load(model_weight)

    arguments.update(extra_checkpoint_data)

    data_loader = make_data_loader(
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )

    # make_data_loader returns one loader per test dataset; validate on the first
    data_loader_val = make_data_loader(cfg,
                                       is_train=False,
                                       is_distributed=distributed)[0]

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    do_train(model=model,
             data_loader=data_loader,
             data_loader_val=data_loader_val,
             optimizer=optimizer,
             scheduler=scheduler,
             checkpointer=checkpointer,
             device=device,
             checkpoint_period=checkpoint_period,
             arguments=arguments,
             summary_writer=summary_writer)

    return model
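
A minimal launcher sketch for the train() function above. The config module path,
the flag names, and the reliance on torch.distributed.launch-style environment
variables are assumptions for illustration, not part of the original example.

import argparse
import os

import torch

# assumed yacs-style config node, as in maskrcnn_benchmark forks
from maskrcnn_benchmark.config import cfg


def main():
    parser = argparse.ArgumentParser(description="illustrative training launcher")
    parser.add_argument("--config-file", default="", metavar="FILE")
    parser.add_argument("--local_rank", type=int, default=0)
    args = parser.parse_args()

    # WORLD_SIZE is set by torch.distributed.launch / torchrun
    num_gpus = int(os.environ.get("WORLD_SIZE", "1"))
    distributed = num_gpus > 1

    if distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl", init_method="env://")

    cfg.merge_from_file(args.config_file)
    cfg.freeze()

    train(cfg, args.local_rank, distributed)


if __name__ == "__main__":
    main()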
Example #2
    def load(self, f=None):
        if self.has_checkpoint():
            # override argument with existing checkpoint
            f = self.get_checkpoint_file()
        if not f:
            # no checkpoint could be found
            self.logger.info("No checkpoint found. Initializing model from scratch")
            log_optimizer_scheduler_info(self.logger, self.optimizer, self.scheduler)

            return {}

        self.logger.info("Loading checkpoint from {}".format(f))
        checkpoint = self._load_file(f)
        self._load_model(checkpoint)

        if self.cfg.PRIORITY_CONFIG:
            # The current config takes priority over the saved optimizer/scheduler
            # state: rebuild both from cfg and reuse only the checkpointed
            # iteration counter to fast-forward the scheduler.
            temp_optimizer = make_optimizer(self.cfg, self.model)
            self.optimizer.load_state_dict(temp_optimizer.state_dict())

            for group in self.optimizer.param_groups:
                # schedulers resumed with last_epoch != -1 expect 'initial_lr'
                group.setdefault('initial_lr', group['lr'])

            iteration = checkpoint.get('iteration', 0)
            last_epoch = iteration - 1
            temp_scheduler = make_lr_scheduler(self.cfg, self.optimizer, last_epoch=last_epoch)
            self.scheduler.load_state_dict(temp_scheduler.state_dict())

            # drop the optimizer/scheduler entries now that they have been handled
            for stat_name in ("optimizer", "scheduler"):
                checkpoint.pop(stat_name, None)
        else:
            if "optimizer" in checkpoint and self.optimizer:
                self.logger.info("Loading optimizer from {}".format(f))
                self.optimizer.load_state_dict(checkpoint.pop("optimizer"))
            if "scheduler" in checkpoint and self.scheduler:
                self.logger.info("Loading scheduler from {}".format(f))
                self.scheduler.load_state_dict(checkpoint.pop("scheduler"))

        if self.optimizer is not None and self.scheduler is not None:
            log_optimizer_scheduler_info(self.logger, self.optimizer, self.scheduler)

        # return any further checkpoint data
        return checkpoint
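
For contrast with load(), a save() counterpart consistent with the keys consumed
above might look like the sketch below. The save_dir / save_to_disk attributes,
the tag_last_checkpoint helper, and the module-level os/torch imports follow the
common Detectron-style checkpointer layout and are assumptions here, not code
taken from the original example.

    def save(self, name, **kwargs):
        # Illustrative sketch: writes the same keys that load() reads above.
        if not self.save_dir or not self.save_to_disk:
            return

        data = {"model": self.model.state_dict()}
        if self.optimizer is not None:
            data["optimizer"] = self.optimizer.state_dict()
        if self.scheduler is not None:
            data["scheduler"] = self.scheduler.state_dict()
        # kwargs typically carries {"iteration": ...} from the training loop
        data.update(kwargs)

        save_file = os.path.join(self.save_dir, "{}.pth".format(name))
        self.logger.info("Saving checkpoint to {}".format(save_file))
        torch.save(data, save_file)
        # record this file so has_checkpoint()/get_checkpoint_file() can find it
        self.tag_last_checkpoint(save_file)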