Code Example #1
@classmethod
def build_optimizer(cls, cfg, model):
    """
    Returns:
        torch.optim.Optimizer:

    It now calls :func:`detectron2.solver.build_optimizer`.
    Overwrite it if you'd like a different optimizer.
    """
    return build_optimizer(cfg, model)
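This is the optimizer hook on detectron2-style trainers. As a minimal sketch, assuming detectron2's DefaultTrainer API, overriding the hook in a subclass is enough to swap in a different optimizer (the subclass name and the AdamW choice are illustrative assumptions, not part of the original):

import torch
from detectron2.engine import DefaultTrainer

class AdamWTrainer(DefaultTrainer):  # hypothetical subclass for illustration
    @classmethod
    def build_optimizer(cls, cfg, model):
        # Replace the optimizer built by detectron2.solver.build_optimizer
        # with AdamW, reusing the configured base learning rate.
        return torch.optim.AdamW(model.parameters(), lr=cfg.SOLVER.BASE_LR)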
Code Example #2
File: plain_train_net.py Project: zl930216/fast-reid

# Imports needed by this snippet (module paths assumed from the fastreid package layout):
import logging
import os

import torch

from fastreid.data import build_reid_train_loader
from fastreid.solver import build_lr_scheduler, build_optimizer
from fastreid.utils import comm
from fastreid.utils.checkpoint import Checkpointer, PeriodicCheckpointer
from fastreid.utils.events import CommonMetricPrinter, EventStorage, JSONWriter, TensorboardXWriter

logger = logging.getLogger("fastreid")
def do_train(cfg, model, resume=False):
    data_loader = build_reid_train_loader(cfg)

    model.train()
    optimizer = build_optimizer(cfg, model)

    iters_per_epoch = len(data_loader.dataset) // cfg.SOLVER.IMS_PER_BATCH
    scheduler = build_lr_scheduler(cfg, optimizer, iters_per_epoch)

    checkpointer = Checkpointer(model,
                                cfg.OUTPUT_DIR,
                                save_to_disk=comm.is_main_process(),
                                optimizer=optimizer,
                                # build_lr_scheduler returns a dict of schedulers,
                                # so unpack it into the checkpointables
                                **scheduler)

    start_epoch = (checkpointer.resume_or_load(
        cfg.MODEL.WEIGHTS, resume=resume).get("epoch", -1) + 1)
    iteration = start_iter = start_epoch * iters_per_epoch

    max_epoch = cfg.SOLVER.MAX_EPOCH
    max_iter = max_epoch * iters_per_epoch
    warmup_iters = cfg.SOLVER.WARMUP_ITERS
    delay_epochs = cfg.SOLVER.DELAY_EPOCHS

    periodic_checkpointer = PeriodicCheckpointer(checkpointer,
                                                 cfg.SOLVER.CHECKPOINT_PERIOD,
                                                 max_epoch)

    writers = ([
        CommonMetricPrinter(max_iter),
        JSONWriter(os.path.join(cfg.OUTPUT_DIR, "metrics.json")),
        TensorboardXWriter(cfg.OUTPUT_DIR)
    ] if comm.is_main_process() else [])

    # compared to "train_net.py", we do not support some hooks, such as
    # accurate timing, FP16 training and precise BN here,
    # because they are not trivial to implement in a small training loop
    logger.info("Start training from epoch {}".format(start_epoch))
    with EventStorage(start_iter) as storage:
        for epoch in range(start_epoch, max_epoch):
            storage.epoch = epoch
            for data, _ in zip(data_loader, range(iters_per_epoch)):
                storage.iter = iteration

                loss_dict = model(data)
                losses = sum(loss_dict.values())
                assert torch.isfinite(losses).all(), loss_dict

                loss_dict_reduced = {
                    k: v.item()
                    for k, v in comm.reduce_dict(loss_dict).items()
                }
                losses_reduced = sum(loss
                                     for loss in loss_dict_reduced.values())
                if comm.is_main_process():
                    storage.put_scalars(total_loss=losses_reduced,
                                        **loss_dict_reduced)

                optimizer.zero_grad()
                losses.backward()
                optimizer.step()
                storage.put_scalar("lr",
                                   optimizer.param_groups[0]["lr"],
                                   smoothing_hint=False)

                if iteration - start_iter > 5 and (
                    (iteration + 1) % 200 == 0 or iteration == max_iter - 1):
                    for writer in writers:
                        writer.write()

                iteration += 1

                # During warmup, the warmup scheduler steps once per iteration;
                # after warmup, the main scheduler steps once per epoch (below).
                if iteration <= warmup_iters:
                    scheduler["warmup_sched"].step()

            # Write metrics after each epoch
            for writer in writers:
                writer.write()

            # Step the main LR scheduler once warmup is over and the
            # configured delay (in epochs) has elapsed.
            if iteration > warmup_iters and (epoch + 1) >= delay_epochs:
                scheduler["lr_sched"].step()

            if (cfg.TEST.EVAL_PERIOD > 0
                    and (epoch + 1) % cfg.TEST.EVAL_PERIOD == 0
                    and epoch != max_epoch - 1):
                do_test(cfg, model)
                # Compared to "train_net.py", the test results are not dumped to EventStorage

            periodic_checkpointer.step(epoch)
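A hedged sketch of how do_train is typically driven, following the plain_train_net.py pattern (the argument names and config wiring below are assumptions, not part of the snippet):

from fastreid.config import get_cfg
from fastreid.modeling import build_model

def main(args):
    cfg = get_cfg()
    cfg.merge_from_file(args.config_file)  # assumed CLI argument
    cfg.freeze()
    model = build_model(cfg)
    do_train(cfg, model, resume=args.resume)
    do_test(cfg, model)  # do_test is defined elsewhere in the same file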
Code Example #3
    def __init__(self, cfg):
        TrainerBase.__init__(self)

        logger = logging.getLogger('fastreid.partial-fc.trainer')
        if not logger.isEnabledFor(logging.INFO):
            # setup_logger is not called for fastreid
            setup_logger()

        # Assume these objects must be constructed in this order.
        data_loader = self.build_train_loader(cfg)
        cfg = self.auto_scale_hyperparams(cfg, data_loader.dataset.num_classes)
        model = self.build_model(cfg)
        optimizer, param_wrapper = self.build_optimizer(cfg, model)

        if cfg.MODEL.HEADS.PFC.ENABLED:
            # fmt: off
            feat_dim = cfg.MODEL.BACKBONE.FEAT_DIM
            embedding_dim = cfg.MODEL.HEADS.EMBEDDING_DIM
            num_classes = cfg.MODEL.HEADS.NUM_CLASSES
            sample_rate = cfg.MODEL.HEADS.PFC.SAMPLE_RATE
            cls_type = cfg.MODEL.HEADS.CLS_LAYER
            scale = cfg.MODEL.HEADS.SCALE
            margin = cfg.MODEL.HEADS.MARGIN
            # fmt: on
            # Partial-FC module
            embedding_size = embedding_dim if embedding_dim > 0 else feat_dim
            self.pfc_module = PartialFC(embedding_size, num_classes,
                                        sample_rate, cls_type, scale, margin)
            self.pfc_optimizer, _ = build_optimizer(cfg, self.pfc_module,
                                                    False)

        # For training, wrap the model with DDP; this is not needed for inference.
        if comm.get_world_size() > 1:
            # See https://github.com/pytorch/pytorch/issues/22049 for setting
            # `find_unused_parameters=True` when some parameters are not updated.
            model = DistributedDataParallel(
                model,
                device_ids=[comm.get_local_rank()],
                broadcast_buffers=False,
            )

        if cfg.MODEL.HEADS.PFC.ENABLED:
            mini_batch_size = cfg.SOLVER.IMS_PER_BATCH // comm.get_world_size()
            grad_scaler = MaxClipGradScaler(mini_batch_size,
                                            128 * mini_batch_size,
                                            growth_interval=100)
            self._trainer = PFCTrainer(model, data_loader, optimizer,
                                       param_wrapper, self.pfc_module,
                                       self.pfc_optimizer,
                                       cfg.SOLVER.AMP.ENABLED, grad_scaler)
        else:
            self._trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else
                             SimpleTrainer)(model, data_loader, optimizer,
                                            param_wrapper)

        self.iters_per_epoch = len(
            data_loader.dataset) // cfg.SOLVER.IMS_PER_BATCH
        self.scheduler = self.build_lr_scheduler(cfg, optimizer,
                                                 self.iters_per_epoch)
        if cfg.MODEL.HEADS.PFC.ENABLED:
            self.pfc_scheduler = self.build_lr_scheduler(
                cfg, self.pfc_optimizer, self.iters_per_epoch)

        self.checkpointer = Checkpointer(
            # Assume you want to save checkpoints together with logs/statistics
            model,
            cfg.OUTPUT_DIR,
            save_to_disk=comm.is_main_process(),
            optimizer=optimizer,
            **self.scheduler,
        )

        if cfg.MODEL.HEADS.PFC.ENABLED:
            self.pfc_checkpointer = PfcCheckpointer(
                self.pfc_module,
                cfg.OUTPUT_DIR,
                optimizer=self.pfc_optimizer,
                **self.pfc_scheduler,
            )

        self.start_epoch = 0
        self.max_epoch = cfg.SOLVER.MAX_EPOCH
        self.max_iter = self.max_epoch * self.iters_per_epoch
        self.warmup_iters = cfg.SOLVER.WARMUP_ITERS
        self.delay_epochs = cfg.SOLVER.DELAY_EPOCHS
        self.cfg = cfg

        self.register_hooks(self.build_hooks())
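A hedged usage sketch for a trainer constructed this way, assuming the surrounding class follows fastreid's DefaultTrainer API, where resume_or_load and train come from the base classes (Trainer and setup are placeholder names):

def main(args):
    cfg = setup(args)  # assumed helper that loads and freezes the config
    trainer = Trainer(cfg)  # the class whose __init__ is shown above
    trainer.resume_or_load(resume=args.resume)
    return trainer.train()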