Example #1
# Typical detectron2-style imports for this benchmark (assumed here; the original
# script defines setup() and its own logger elsewhere):
import itertools
import logging

from torch.nn.parallel import DistributedDataParallel

from detectron2.checkpoint import DetectionCheckpointer
from detectron2.data import DatasetFromList, build_detection_train_loader
from detectron2.engine import SimpleTrainer, hooks
from detectron2.modeling import build_model
from detectron2.solver import build_optimizer
from detectron2.utils import comm
from detectron2.utils.events import CommonMetricPrinter

logger = logging.getLogger(__name__)


def benchmark_train(args):
    cfg = setup(args)
    model = build_model(cfg)
    logger.info("Model:\n{}".format(model))
    if comm.get_world_size() > 1:
        model = DistributedDataParallel(model,
                                        device_ids=[comm.get_local_rank()],
                                        broadcast_buffers=False)
    optimizer = build_optimizer(cfg, model)
    checkpointer = DetectionCheckpointer(model, optimizer=optimizer)
    checkpointer.load(cfg.MODEL.WEIGHTS)

    cfg.defrost()
    cfg.DATALOADER.NUM_WORKERS = 0
    data_loader = build_detection_train_loader(cfg)
    # Cache a fixed number of batches in memory so data loading is excluded from the benchmark.
    dummy_data = list(itertools.islice(data_loader, 100))

    def f():
        # Yield the cached batches in an endless loop so the timed iterations never wait on I/O.
        while True:
            yield from DatasetFromList(dummy_data, copy=False)

    max_iter = 400
    trainer = SimpleTrainer(model, f(), optimizer)
    trainer.register_hooks([
        hooks.IterationTimer(),
        hooks.PeriodicWriter([CommonMetricPrinter(max_iter)])
    ])
    trainer.train(1, max_iter)
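
A benchmark function like this is normally run through detectron2's multi-process launcher. The entry point below is a minimal sketch of that wiring, assuming detectron2's default_argument_parser and launch utilities; how benchmark_train is actually invoked is not shown in the excerpt.

from detectron2.engine import default_argument_parser, launch

if __name__ == "__main__":
    args = default_argument_parser().parse_args()
    launch(
        benchmark_train,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,),
    )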
Example #2
    def __init__(self, cfg, model_build_func):
        """
        Args:
            cfg (BaseConfig):
        """
        logger = logging.getLogger("cvpods")
        if not logger.isEnabledFor(logging.INFO):  # setup_logger() has not been called yet
            setup_logger()

        self.start_iter = 0

        data_loader = self.build_train_loader(cfg)
        epoch_iters = adjust_epoch_and_iter(cfg, data_loader)
        self.max_iter = cfg.SOLVER.LR_SCHEDULER.MAX_ITER
        self.max_epoch = cfg.SOLVER.LR_SCHEDULER.MAX_EPOCH

        model = model_build_func(cfg)
        model = maybe_convert_module(model)
        logger.info(f"Model structure: {model}")

        # Assume these objects must be constructed in this order.
        optimizer = self.build_optimizer(cfg, model)

        # For training, wrap the model with DDP; this is not needed for inference.
        if comm.get_world_size() > 1:
            model = DistributedDataParallel(model,
                                            device_ids=[comm.get_local_rank()],
                                            broadcast_buffers=False,
                                            find_unused_parameters=True)
        # TODO: @wangfeng02, `batch_subdivisions`
        super().__init__(model, data_loader, optimizer,
                         cfg.SOLVER.BATCH_SUBDIVISIONS)

        if not cfg.SOLVER.LR_SCHEDULER.get("EPOCH_WISE", False):
            epoch_iters = -1

        self.scheduler = self.build_lr_scheduler(cfg,
                                                 optimizer,
                                                 epoch_iters=epoch_iters)
        # Assume no other objects need to be checkpointed.
        # We can later make it checkpoint the stateful hooks
        self.checkpointer = DetectionCheckpointer(
            # Assume you want to save checkpoints together with logs/statistics
            model,
            cfg.OUTPUT_DIR,
            optimizer=optimizer,
            scheduler=self.scheduler,
        )

        self.cfg = cfg
        self.register_hooks(self.build_hooks())
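
For context, a driver for a runner built around this constructor might look like the sketch below. The names DefaultRunner and setup_config, and the resume_or_load/train calls, are illustrative assumptions rather than cvpods' actual entry point.

def main(args):
    cfg = setup_config(args)                     # hypothetical project-specific config setup
    runner = DefaultRunner(cfg, build_model)     # the constructor shown above
    runner.resume_or_load(resume=args.resume)    # assumption: restores weights via self.checkpointer
    runner.train()                               # assumption: runs from start_iter to max_iter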
Example #3
def main(args):
    config.merge_from_list(args.opts)
    cfg = setup(args)
    model = build_model(cfg)

    logger.info("Model:\n{}".format(model))
    if args.eval_only:
        DefaultCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume)
        return do_test(cfg, model)

    distributed = comm.get_world_size() > 1
    if distributed:
        model = DistributedDataParallel(model,
                                        device_ids=[comm.get_local_rank()],
                                        broadcast_buffers=False)

    do_train(cfg, model)
    return do_test(cfg, model)
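
This main() relies on do_train and do_test helpers defined elsewhere in the script. The sketch below shows a bare-bones do_train using detectron2-style builders; it is illustrative only and omits checkpointing, event writers, and periodic evaluation.

def do_train(cfg, model):
    model.train()
    optimizer = build_optimizer(cfg, model)
    scheduler = build_lr_scheduler(cfg, optimizer)
    data_loader = build_detection_train_loader(cfg)
    # build_detection_train_loader yields an endless stream, so bound it by MAX_ITER.
    for iteration, data in zip(range(cfg.SOLVER.MAX_ITER), data_loader):
        loss_dict = model(data)
        losses = sum(loss_dict.values())
        optimizer.zero_grad()
        losses.backward()
        optimizer.step()
        scheduler.step()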
Example #4
    def __init__(self, cfg, model_build_func):
        """
        Args:
            cfg (BaseConfig):
        """
        logger = logging.getLogger("cvpods")
        if not logger.isEnabledFor(logging.INFO):  # setup_logger() has not been called yet
            setup_logger()

        # To simulate large-batch training, each batch can be split into several subdivisions.
        images_per_batch = cfg.SOLVER.IMS_PER_BATCH
        batch_subdivisions = cfg.SOLVER.BATCH_SUBDIVISIONS
        assert (
            batch_subdivisions > 0
        ), "cfg.SOLVER.BATCH_SUBDIVISIONS ({}) must be greater than or equal to 1.".format(
            batch_subdivisions
        )
        if batch_subdivisions > 1:
            # when batch_subdivisions == 1, the following checks are trivially satisfied
            assert (
                images_per_batch % batch_subdivisions == 0
            ), "cfg.SOLVER.IMS_PER_BATCH ({}) must be divisible by " \
                "cfg.SOLVER.BATCH_SUBDIVISIONS ({}).".format(images_per_batch, batch_subdivisions)
            images_per_mini_batch = images_per_batch // batch_subdivisions

            num_workers = comm.get_world_size()
            assert (
                images_per_mini_batch % num_workers == 0
            ), "images per mini-batch ({}, computed as cfg.SOLVER.IMS_PER_BATCH // " \
                "cfg.SOLVER.BATCH_SUBDIVISIONS) must be divisible by the number of workers " \
                "({}).".format(images_per_mini_batch, num_workers)

            assert (
                images_per_mini_batch >= num_workers
            ), "images per mini-batch ({}, computed as cfg.SOLVER.IMS_PER_BATCH // " \
                "cfg.SOLVER.BATCH_SUBDIVISIONS) must be at least the number of workers " \
                "({}).".format(images_per_mini_batch, num_workers)
        self.batch_subdivisions = batch_subdivisions

        data_loader = self.build_train_loader(cfg)

        self.start_iter = 0
        self.max_iter = cfg.SOLVER.LR_SCHEDULER.MAX_ITER
        self.max_epoch = cfg.SOLVER.LR_SCHEDULER.MAX_EPOCH
        self.cfg = cfg

        if self.max_epoch is not None:
            epoch_iter = len(data_loader.dataset) // (
                comm.get_world_size() * data_loader.batch_size * cfg.SOLVER.BATCH_SUBDIVISIONS
            ) + 1
            cfg.SOLVER.LR_SCHEDULER.MAX_ITER = self.max_iter = self.max_epoch * epoch_iter
            cfg.SOLVER.LR_SCHEDULER.STEPS = [
                x * epoch_iter for x in cfg.SOLVER.LR_SCHEDULER.STEPS
            ]
            cfg.SOLVER.LR_SCHEDULER.WARMUP_ITERS = int(
                cfg.SOLVER.LR_SCHEDULER.WARMUP_ITERS * epoch_iter)
            cfg.SOLVER.CHECKPOINT_PERIOD = epoch_iter * cfg.SOLVER.CHECKPOINT_PERIOD
            cfg.TEST.EVAL_PERIOD = epoch_iter * cfg.TEST.EVAL_PERIOD
        else:
            epoch_iter = -1

        model = model_build_func(cfg)
        logger.info(f"Model structure: {model}")

        # Assume these objects must be constructed in this order.
        optimizer = self.build_optimizer(cfg, model)

        # For training, wrap the model with DDP; this is not needed for inference.
        if comm.get_world_size() > 1:
            model = DistributedDataParallel(model,
                                            device_ids=[comm.get_local_rank()],
                                            broadcast_buffers=False,
                                            find_unused_parameters=True)
        super().__init__(model, data_loader, optimizer)

        self.scheduler = self.build_lr_scheduler(cfg, optimizer, epoch_iters=epoch_iter)
        # Assume no other objects need to be checkpointed.
        # We can later make it checkpoint the stateful hooks
        self.checkpointer = DetectionCheckpointer(
            # Assume you want to save checkpoints together with logs/statistics
            model,
            cfg.OUTPUT_DIR,
            optimizer=optimizer,
            scheduler=self.scheduler,
        )

        self.register_hooks(self.build_hooks())
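
To make the epoch-to-iteration conversion above concrete, here is a worked example with made-up numbers: 118,000 training images, 8 workers, 2 images per worker per step, BATCH_SUBDIVISIONS = 1, and MAX_EPOCH = 12.

epoch_iter   = 118000 // (8 * 2 * 1) + 1           # 7376 iterations per epoch
max_iter     = 12 * epoch_iter                     # MAX_ITER becomes 88512
steps        = [8 * epoch_iter, 11 * epoch_iter]   # epoch-based STEPS rescaled to iterations
warmup_iters = int(0.1 * epoch_iter)               # a WARMUP_ITERS of 0.1 epochs -> 737 iterations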
Example #5
    def __init__(self, cfg, build_model):
        """
        Args:
            cfg (config dict):
        """

        self.data_loader = self.build_train_loader(cfg)
        # Assume these objects must be constructed in this order.
        model = build_model(cfg)
        self.model = maybe_convert_module(model)
        logger.info(f"Model: \n{self.model}")

        # Assume these objects must be constructed in this order.
        self.optimizer = self.build_optimizer(cfg, self.model)

        if cfg.TRAINER.FP16.ENABLED:
            self.mixed_precision = True
            if cfg.TRAINER.FP16.TYPE == "APEX":
                from apex import amp
                self.model, self.optimizer = amp.initialize(
                    self.model,
                    self.optimizer,
                    opt_level=cfg.TRAINER.FP16.OPTS.OPT_LEVEL)
        else:
            self.mixed_precision = False

        # For training, wrap the model with DDP; this is not needed for inference.
        if comm.get_world_size() > 1:
            torch.cuda.set_device(comm.get_local_rank())
            if cfg.MODEL.DDP_BACKEND == "torch":
                self.model = DistributedDataParallel(
                    self.model,
                    device_ids=[comm.get_local_rank()],
                    broadcast_buffers=False,
                    find_unused_parameters=True)
            elif cfg.MODEL.DDP_BACKEND == "apex":
                from apex.parallel import DistributedDataParallel as ApexDistributedDataParallel
                self.model = ApexDistributedDataParallel(self.model)
            else:
                raise ValueError("Unsupported DDP backend: {}".format(
                    cfg.MODEL.DDP_BACKEND))

        super().__init__(
            self.model,
            self.data_loader,
            self.optimizer,
        )

        if not cfg.SOLVER.LR_SCHEDULER.get("EPOCH_WISE", False):
            epoch_iters = -1
        else:
            epoch_iters = cfg.SOLVER.LR_SCHEDULER.get("EPOCH_ITERS")
            logger.warning(f"Setting up LR scheduler in epoch-wise mode: {epoch_iters} iterations per epoch")

        auto_scale_config(cfg, self.data_loader)
        self.scheduler = self.build_lr_scheduler(cfg,
                                                 self.optimizer,
                                                 epoch_iters=epoch_iters)
        # Assume no other objects need to be checkpointed.
        # We can later make it checkpoint the stateful hooks
        self.checkpointer = DefaultCheckpointer(
            # Assume you want to save checkpoints together with logs/statistics
            self.model,
            cfg.OUTPUT_DIR,
            optimizer=self.optimizer,
            scheduler=self.scheduler,
        )

        self.start_iter = 0
        self.start_epoch = 0
        self.max_iter = cfg.SOLVER.LR_SCHEDULER.MAX_ITER
        self.max_epoch = cfg.SOLVER.LR_SCHEDULER.MAX_EPOCH
        self.window_size = cfg.TRAINER.WINDOW_SIZE

        self.cfg = cfg

        self.register_hooks(self.build_hooks())
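
The APEX branch above only calls amp.initialize; the matching backward pass has to go through amp.scale_loss. The method below is a rough sketch of that step following the standard apex.amp pattern, not cvpods' actual run_step (the _data_loader_iter attribute is an assumption).

    def run_step(self):
        data = next(self._data_loader_iter)   # assumption: iterator over self.data_loader
        loss_dict = self.model(data)
        losses = sum(loss_dict.values())
        self.optimizer.zero_grad()
        if self.mixed_precision:
            from apex import amp
            # Scale the loss so FP16 gradients do not underflow; apex unscales before the step.
            with amp.scale_loss(losses, self.optimizer) as scaled_losses:
                scaled_losses.backward()
        else:
            losses.backward()
        self.optimizer.step()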