Example #1
    def __init__(self, cfg):
        """
        Args:
            cfg (CfgNode):
        """
        super().__init__()
        logger = logging.getLogger("fastreid")
        if not logger.isEnabledFor(
                logging.INFO):  # setup_logger is not called for fastreid
            setup_logger()

        # Assume these objects must be constructed in this order.
        data_loader = self.build_train_loader(cfg)
        cfg = self.auto_scale_hyperparams(cfg, data_loader.dataset.num_classes)
        model = self.build_model(cfg)
        optimizer = self.build_optimizer(cfg, model)

        optimizer_ckpt = dict(optimizer=optimizer)
        if cfg.SOLVER.FP16_ENABLED:
            model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
            optimizer_ckpt.update(dict(amp=amp))

        # For training, wrap with DDP; this isn't needed for inference.
        if comm.get_world_size() > 1:
            # See https://github.com/pytorch/pytorch/issues/22049: set
            # `find_unused_parameters=True` when some parameters receive no update.
            # model = DistributedDataParallel(
            #     model, device_ids=[comm.get_local_rank()], broadcast_buffers=False
            # )
            model = DistributedDataParallel(model, delay_allreduce=True)

        self._trainer = (AMPTrainer if cfg.SOLVER.FP16_ENABLED else
                         SimpleTrainer)(model, data_loader, optimizer)

        self.iters_per_epoch = len(
            data_loader.dataset) // cfg.SOLVER.IMS_PER_BATCH
        self.scheduler = self.build_lr_scheduler(cfg, optimizer,
                                                 self.iters_per_epoch)

        # Assume no other objects need to be checkpointed.
        # We can later make it checkpoint the stateful hooks
        self.checkpointer = Checkpointer(
            # Assume you want to save checkpoints together with logs/statistics
            model,
            cfg.OUTPUT_DIR,
            save_to_disk=comm.is_main_process(),
            **optimizer_ckpt,
            **self.scheduler,
        )

        self.start_epoch = 0
        self.max_epoch = cfg.SOLVER.MAX_EPOCH
        self.max_iter = self.max_epoch * self.iters_per_epoch
        self.warmup_iters = cfg.SOLVER.WARMUP_ITERS
        self.delay_epochs = cfg.SOLVER.DELAY_EPOCHS
        self.cfg = cfg

        self.register_hooks(self.build_hooks())
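
How a trainer like this is driven is not shown in the excerpt. Below is a minimal sketch assuming fastreid's detectron2-style entry points (DefaultTrainer, default_argument_parser, launch); the names follow fastreid's tools/train_net.py and may differ in your checkout.

from fastreid.config import get_cfg
from fastreid.engine import DefaultTrainer, default_argument_parser, launch


def main(args):
    cfg = get_cfg()
    cfg.merge_from_file(args.config_file)  # YAML base config
    cfg.merge_from_list(args.opts)         # "KEY VALUE" overrides
    cfg.freeze()
    trainer = DefaultTrainer(cfg)          # runs an __init__ like the one above
    trainer.resume_or_load(resume=args.resume)
    return trainer.train()


if __name__ == "__main__":
    args = default_argument_parser().parse_args()
    launch(main, args.num_gpus, num_machines=args.num_machines,
           machine_rank=args.machine_rank, dist_url=args.dist_url,
           args=(args,))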
Example #2
    def __init__(self, cfg):
        """
        Args:
            cfg (CfgNode):
        """
        logger = logging.getLogger("fastreid")
        if not logger.isEnabledFor(
                logging.INFO):  # setup_logger is not called for fastreid
            setup_logger()

        # Assume these objects must be constructed in this order.
        data_loader = self.build_train_loader(cfg)
        cfg = self.auto_scale_hyperparams(cfg, data_loader)
        model = self.build_model(cfg)
        optimizer = self.build_optimizer(cfg, model)

        # For training, wrap with DDP; this isn't needed for inference.
        if comm.get_world_size() > 1:
            # See https://github.com/pytorch/pytorch/issues/22049: set
            # `find_unused_parameters=True` when some parameters receive no update.
            model = DistributedDataParallel(model,
                                            device_ids=[comm.get_local_rank()],
                                            broadcast_buffers=False)

        super().__init__(model, data_loader, optimizer, cfg.SOLVER.BASE_LR,
                         cfg.MODEL.LOSSES.CENTER.LR,
                         cfg.MODEL.LOSSES.CENTER.SCALE, cfg.SOLVER.AMP_ENABLED)

        self.scheduler = self.build_lr_scheduler(cfg, optimizer)
        # Assume no other objects need to be checkpointed.
        # We can later make it checkpoint the stateful hooks
        self.checkpointer = Checkpointer(
            # Assume you want to save checkpoints together with logs/statistics
            model,
            cfg.OUTPUT_DIR,
            save_to_disk=comm.is_main_process(),
            optimizer=optimizer,
            scheduler=self.scheduler,
        )
        self.start_iter = 0
        if cfg.SOLVER.SWA.ENABLED:
            self.max_iter = cfg.SOLVER.MAX_ITER + cfg.SOLVER.SWA.ITER
        else:
            self.max_iter = cfg.SOLVER.MAX_ITER

        self.cfg = cfg

        self.register_hooks(self.build_hooks())
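
The DDP comment above is worth unpacking. A self-contained, generic PyTorch sketch of why find_unused_parameters=True matters when only part of a model is exercised per step (assumes torch.distributed is already initialized before wrap() is called):

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel


class TwoHead(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 8)
        self.head_a = nn.Linear(8, 2)
        self.head_b = nn.Linear(8, 2)  # sometimes skipped

    def forward(self, x, use_b=False):
        feat = self.backbone(x)
        return self.head_b(feat) if use_b else self.head_a(feat)


def wrap(model, local_rank):
    # head_b receives no gradient on most iterations; without the flag the
    # default reducer would wait forever for its gradients.
    return DistributedDataParallel(model.cuda(local_rank),
                                   device_ids=[local_rank],
                                   broadcast_buffers=False,
                                   find_unused_parameters=True)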
Example #3
def default_setup(cfg, args):
    """
    Perform some basic common setups at the beginning of a job, including:
    1. Set up the detectron2 logger
    2. Log basic information about environment, cmdline arguments, and config
    3. Backup the config to the output directory
    Args:
        cfg (CfgNode): the full config to be used
        args (argparse.Namespace): the command line arguments to be logged
    """
    output_dir = cfg.OUTPUT_DIR
    if comm.is_main_process() and output_dir:
        PathManager.mkdirs(output_dir)

    rank = comm.get_rank()
    setup_logger(output_dir, distributed_rank=rank, name="fvcore")
    logger = setup_logger(output_dir, distributed_rank=rank)

    logger.info("Rank of current process: {}. World size: {}".format(rank, comm.get_world_size()))
    logger.info("Environment info:\n" + collect_env_info())

    logger.info("Command line arguments: " + str(args))
    if hasattr(args, "config_file") and args.config_file != "":
        logger.info(
            "Contents of args.config_file={}:\n{}".format(
                args.config_file, PathManager.open(args.config_file, "r").read()
            )
        )

    logger.info("Running with full config:\n{}".format(cfg))
    if comm.is_main_process() and output_dir:
        # Note: some of our scripts may expect the existence of
        # config.yaml in output directory
        path = os.path.join(output_dir, "config.yaml")
        with PathManager.open(path, "w") as f:
            f.write(cfg.dump())
        logger.info("Full config saved to {}".format(os.path.abspath(path)))

    # make sure each worker has a different, yet deterministic seed if specified
    seed_all_rng()

    # cudnn benchmark has a large overhead; it shouldn't be used given the
    # small size of a typical validation set.
    if not (hasattr(args, "eval_only") and args.eval_only):
        torch.backends.cudnn.benchmark = cfg.CUDNN_BENCHMARK
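
For context, a minimal sketch of how default_setup is typically called from a training script, assuming the standard detectron2 helpers it ships alongside:

from detectron2.config import get_cfg
from detectron2.engine import default_argument_parser, default_setup


def setup(args):
    cfg = get_cfg()
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    default_setup(cfg, args)  # the helper shown above
    return cfg


if __name__ == "__main__":
    cfg = setup(default_argument_parser().parse_args())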
Example #4
    def __init__(self, cfg):
        """
        Args:
            cfg (CfgNode):
        """
        self.cfg = cfg
        logger = logging.getLogger(__name__)
        if not logger.isEnabledFor(
                logging.INFO):  # setup_logger is not called for fastreid
            setup_logger()
        # Assume these objects must be constructed in this order.
        model = self.build_model(cfg)
        optimizer = self.build_optimizer(cfg, model)
        logger.info('Prepare training set')
        data_loader = self.build_train_loader(cfg)
        # For training, wrap with DataParallel; this isn't needed for inference.
        model = DataParallel(model)
        if cfg.MODEL.BACKBONE.NORM == "syncBN":
            # Monkey-patching with syncBN
            patch_replication_callback(model)
        model = model.cuda()
        super().__init__(model, data_loader, optimizer)

        self.scheduler = self.build_lr_scheduler(cfg, optimizer)
        # Assume no other objects need to be checkpointed.
        # We can later make it checkpoint the stateful hooks
        self.checkpointer = Checkpointer(
            # Assume you want to save checkpoints together with logs/statistics
            model,
            self.data_loader.loader.dataset,
            cfg.OUTPUT_DIR,
            optimizer=optimizer,
            scheduler=self.scheduler,
        )
        self.start_iter = 0
        if cfg.SOLVER.SWA.ENABLED:
            self.max_iter = cfg.SOLVER.MAX_ITER + cfg.SOLVER.SWA.ITER
        else:
            self.max_iter = cfg.SOLVER.MAX_ITER
        self.cfg = cfg

        self.register_hooks(self.build_hooks())
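
A generic PyTorch sketch of the single-process multi-GPU wrapping used above: nn.DataParallel splits each batch across visible GPUs and gathers outputs (plain BatchNorm stays per-replica unless monkey-patched into syncBN as in the snippet).

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)  # replicates the module per GPU
model = model.cuda() if torch.cuda.is_available() else model

x = torch.randn(8, 16)
out = model(x.cuda() if torch.cuda.is_available() else x)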
Example #5
        self.predictors = []

        for cfg in cfgs:
            self.predictors.append(DefaultPredictor(cfg))

    def run_on_loader(self, data_loader):
        for batch in data_loader:
            predictions = []
            for predictor in self.predictors:
                predictions.append(predictor(batch["images"]))
            yield torch.cat(predictions, dim=-1), batch


if __name__ == "__main__":
    args = get_parser().parse_args()
    logger = setup_logger()
    cfgs = []
    for config_file in args.config_file:
        cfg = setup_cfg(config_file, args.opts)
        cfgs.append(cfg)
    results = OrderedDict()
    for dataset_name in cfgs[0].DATASETS.TESTS:
        test_loader, num_query = build_reid_test_loader(cfgs[0], dataset_name)
        evaluator = ReidEvaluator(cfgs[0], num_query)
        feat_extract = FeatureExtraction(cfgs)
        for (feat, batch) in tqdm.tqdm(feat_extract.run_on_loader(test_loader),
                                       total=len(test_loader)):
            evaluator.process(batch, feat)
        result = evaluator.evaluate()
        results[dataset_name] = result
    print_csv_format(results)
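
The ensembling idea in run_on_loader reduces to concatenating per-model embeddings channel-wise. A self-contained toy sketch:

import torch
import torch.nn as nn

models = [nn.Linear(32, 64).eval(), nn.Linear(32, 128).eval()]

@torch.no_grad()
def extract(batch):
    feats = [m(batch) for m in models]
    return torch.cat(feats, dim=-1)  # (N, 64 + 128)

print(extract(torch.randn(4, 32)).shape)  # torch.Size([4, 192])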
Example #6
import argparse
import logging
import sys

import torch

sys.path.append('../../')

import pytorch_to_caffe
from fastreid.config import get_cfg
from fastreid.modeling.meta_arch import build_model
from fastreid.utils.file_io import PathManager
from fastreid.utils.checkpoint import Checkpointer
from fastreid.utils.logger import setup_logger

# import some modules added in project like this below
# sys.path.append('../projects/FastCls')
# from fastcls import *

setup_logger(name='fastreid')
logger = logging.getLogger("fastreid.caffe_export")


def setup_cfg(args):
    cfg = get_cfg()
    # add_cls_config(cfg)
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    return cfg


def get_parser():
    parser = argparse.ArgumentParser(
        description="Convert Pytorch to Caffe model")
Example #7
"""

import argparse

import torch
import sys
sys.path.append('../../')

import pytorch_to_caffe
from fastreid.config import get_cfg
from fastreid.modeling.meta_arch import build_model
from fastreid.utils.file_io import PathManager
from fastreid.utils.checkpoint import Checkpointer
from fastreid.utils.logger import setup_logger

logger = setup_logger(name='caffe_export')


def setup_cfg(args):
    cfg = get_cfg()
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    return cfg


def get_parser():
    parser = argparse.ArgumentParser(description="Convert Pytorch to Caffe model")

    parser.add_argument(
        "--config-file",
Example #8
    def __init__(self, cfg):
        TrainerBase.__init__(self)

        logger = logging.getLogger('fastreid.partial-fc.trainer')
        if not logger.isEnabledFor(
                logging.INFO):  # setup_logger is not called for fastreid
            setup_logger()

        # Assume these objects must be constructed in this order.
        data_loader = self.build_train_loader(cfg)
        cfg = self.auto_scale_hyperparams(cfg, data_loader.dataset.num_classes)
        model = self.build_model(cfg)
        optimizer = self.build_optimizer(cfg, model)

        if cfg.MODEL.HEADS.PFC.ENABLED:
            # fmt: off
            feat_dim = cfg.MODEL.BACKBONE.FEAT_DIM
            embedding_dim = cfg.MODEL.HEADS.EMBEDDING_DIM
            num_classes = cfg.MODEL.HEADS.NUM_CLASSES
            sample_rate = cfg.MODEL.HEADS.PFC.SAMPLE_RATE
            cls_type = cfg.MODEL.HEADS.CLS_LAYER
            scale = cfg.MODEL.HEADS.SCALE
            margin = cfg.MODEL.HEADS.MARGIN
            # fmt: on
            # Partial-FC module
            embedding_size = embedding_dim if embedding_dim > 0 else feat_dim
            self.pfc_module = PartialFC(embedding_size, num_classes,
                                        sample_rate, cls_type, scale, margin)
            self.pfc_optimizer = self.build_optimizer(cfg, self.pfc_module)

        # For training, wrap with DDP; this isn't needed for inference.
        if comm.get_world_size() > 1:
            # See https://github.com/pytorch/pytorch/issues/22049: set
            # `find_unused_parameters=True` when some parameters receive no update.
            model = DistributedDataParallel(model,
                                            device_ids=[comm.get_local_rank()],
                                            broadcast_buffers=False,
                                            find_unused_parameters=True)

        self._trainer = PFCTrainer(model, data_loader, optimizer, self.pfc_module, self.pfc_optimizer) \
            if cfg.MODEL.HEADS.PFC.ENABLED else SimpleTrainer(model, data_loader, optimizer)

        self.iters_per_epoch = len(
            data_loader.dataset) // cfg.SOLVER.IMS_PER_BATCH
        self.scheduler = self.build_lr_scheduler(cfg, optimizer,
                                                 self.iters_per_epoch)
        if cfg.MODEL.HEADS.PFC.ENABLED:
            self.pfc_scheduler = self.build_lr_scheduler(
                cfg, self.pfc_optimizer, self.iters_per_epoch)

        self.checkpointer = Checkpointer(
            # Assume you want to save checkpoints together with logs/statistics
            model,
            cfg.OUTPUT_DIR,
            save_to_disk=comm.is_main_process(),
            optimizer=optimizer,
            **self.scheduler,
        )

        if cfg.MODEL.HEADS.PFC.ENABLED:
            self.pfc_checkpointer = Checkpointer(
                self.pfc_module,
                cfg.OUTPUT_DIR,
                optimizer=self.pfc_optimizer,
                **self.pfc_scheduler,
            )

        self.start_epoch = 0
        self.max_epoch = cfg.SOLVER.MAX_EPOCH
        self.max_iter = self.max_epoch * self.iters_per_epoch
        self.warmup_iters = cfg.SOLVER.WARMUP_ITERS
        self.delay_epochs = cfg.SOLVER.DELAY_EPOCHS
        self.cfg = cfg

        self.register_hooks(self.build_hooks())
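
A hedged, toy sketch of what a Partial-FC head does (this is not the PartialFC class used above): each step, only the positive class centers plus a random sample of negatives enter the softmax, which caps memory and compute for very large class counts. Assumes torch >= 1.10 for torch.isin; the scale constant is illustrative.

import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyPartialFC(nn.Module):
    """Keeps the positive class centers plus random negatives each step."""

    def __init__(self, embedding_size, num_classes, sample_rate, scale=30.0):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(num_classes, embedding_size))
        self.num_sample = max(1, int(num_classes * sample_rate))
        self.scale = scale

    def forward(self, feat, labels):
        pos = labels.unique()
        n_neg = max(self.num_sample - pos.numel(), 0)
        perm = torch.randperm(self.weight.size(0), device=feat.device)
        neg = perm[~torch.isin(perm, pos)][:n_neg]
        idx = torch.cat([pos, neg])  # sampled class indices
        logits = F.linear(F.normalize(feat),
                          F.normalize(self.weight[idx]))
        # remap original labels into the sampled index space
        remap = torch.full((self.weight.size(0),), -1,
                           dtype=torch.long, device=feat.device)
        remap[idx] = torch.arange(idx.numel(), device=feat.device)
        return F.cross_entropy(logits * self.scale, remap[labels])


loss = ToyPartialFC(128, 1000, 0.1)(torch.randn(8, 128),
                                    torch.randint(0, 1000, (8,)))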
"""

import argparse
import os
import sys

import tensorrt as trt

from trt_calibrator import FeatEntropyCalibrator

sys.path.append('../../')

from fastreid.utils.logger import setup_logger
from fastreid.utils.file_io import PathManager

logger = setup_logger(name='trt_export')


def get_parser():
    parser = argparse.ArgumentParser(description="Convert ONNX to TRT model")

    parser.add_argument('--name',
                        default='baseline',
                        help="name for converted model")
    parser.add_argument('--output',
                        default='outputs/trt_model',
                        help="path to save converted trt model")
    parser.add_argument(
        '--mode',
        default='fp32',
        help=
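
For orientation, a hedged sketch of the first steps of an ONNX-to-TensorRT conversion; the builder API is version-sensitive, so treat this as a TensorRT 7/8-era sketch rather than a drop-in, and the file name is illustrative.

import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)


def parse_onnx(onnx_path):
    builder = trt.Builder(TRT_LOGGER)
    flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    network = builder.create_network(flags)
    parser = trt.OnnxParser(network, TRT_LOGGER)
    with open(onnx_path, "rb") as f:
        if not parser.parse(f.read()):
            for i in range(parser.num_errors):
                print(parser.get_error(i))
            raise RuntimeError("failed to parse " + onnx_path)
    return builder, network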
Example #10

import argparse
import sys

from onnxsim import simplify
from torch.onnx import OperatorExportTypes

sys.path.append('../../')

from fastreid.config import get_cfg
from fastreid.modeling.meta_arch import build_model
from fastreid.utils.file_io import PathManager
from fastreid.utils.checkpoint import Checkpointer
from fastreid.utils.logger import setup_logger

# import some modules added in project like this below
# sys.path.append('../../projects/FastDistill')
# from fastdistill import *

logger = setup_logger(name='onnx_export')


def setup_cfg(args):
    cfg = get_cfg()
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    return cfg


def get_parser():
    parser = argparse.ArgumentParser(
        description="Convert Pytorch to ONNX model")

    parser.add_argument(
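
A small, generic sketch of the export flow this script builds toward, using only public torch.onnx / onnx / onnxsim APIs (the model and paths are illustrative):

import onnx
import torch
import torch.nn as nn
from onnxsim import simplify

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU()).eval()
dummy = torch.randn(1, 3, 64, 64)
torch.onnx.export(model, dummy, "baseline.onnx",
                  input_names=["images"], output_names=["features"],
                  opset_version=11)

onnx_model = onnx.load("baseline.onnx")
model_simp, ok = simplify(onnx_model)  # onnxsim returns (model, check)
assert ok, "simplified model failed the checker"
onnx.save(model_simp, "baseline_simp.onnx")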
Example #11
    def __init__(self, cfg):
        super().__init__()
        logger = logging.getLogger("fastreid")
        if not logger.isEnabledFor(
                logging.INFO):  # if setup_logger is not called for fastreid
            setup_logger()

        logger.info("==> Load target-domain dataset")
        self.tgt = tgt = self.load_dataset(cfg.DATASETS.TGT)
        self.tgt_nums = len(tgt.train)

        cfg = self.auto_scale_hyperparams(cfg, self.tgt_nums)

        # Create model
        self.model = self.build_model(cfg,
                                      load_model=cfg.MODEL.PRETRAIN,
                                      show_model=True,
                                      use_dsbn=False)

        # Optimizer
        self.optimizer, self.param_wrapper = self.build_optimizer(
            cfg, self.model)

        # For training, wrap with DDP; this isn't needed for inference.
        if comm.get_world_size() > 1:
            # See https://github.com/pytorch/pytorch/issues/22049: set
            # `find_unused_parameters=True` when some parameters receive no update.
            self.model = DistributedDataParallel(
                self.model,
                device_ids=[comm.get_local_rank()],
                broadcast_buffers=False,
                find_unused_parameters=True)

        # Learning rate scheduler
        self.iters_per_epoch = cfg.SOLVER.ITERS
        self.scheduler = self.build_lr_scheduler(cfg, self.optimizer,
                                                 self.iters_per_epoch)

        # Assume no other objects need to be checkpointed.
        # We can later make it checkpoint the stateful hooks
        self.checkpointer = Checkpointer(
            # Assume you want to save checkpoints together with logs/statistics
            self.model,
            cfg.OUTPUT_DIR,
            save_to_disk=comm.is_main_process(),
            optimizer=self.optimizer,
            **self.scheduler,
        )

        self.start_epoch = 0
        self.max_epoch = cfg.SOLVER.MAX_EPOCH
        self.max_iter = self.max_epoch * self.iters_per_epoch
        self.warmup_iters = cfg.SOLVER.WARMUP_ITERS
        self.delay_epochs = cfg.SOLVER.DELAY_EPOCHS
        self.cfg = cfg

        self.register_hooks(self.build_hooks())

        if cfg.SOLVER.AMP.ENABLED:
            unsupported = "AMPTrainer does not support single-process multi-device training!"
            if isinstance(self.model, DistributedDataParallel):
                assert not (self.model.device_ids
                            and len(self.model.device_ids) > 1), unsupported

            from torch.cuda.amp.grad_scaler import GradScaler
            self.grad_scaler = GradScaler()
        else:
            self.grad_scaler = None
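
The GradScaler created above is consumed in the usual torch.cuda.amp recipe; a generic sketch of one training step (not fastreid-specific):

import torch
from torch.cuda.amp import GradScaler, autocast


def amp_step(model, batch, targets, optimizer, scaler, loss_fn):
    optimizer.zero_grad()
    with autocast():                  # fp16 forward where safe
        loss = loss_fn(model(batch), targets)
    scaler.scale(loss).backward()     # scaled backward to avoid underflow
    scaler.step(optimizer)            # unscales, skips step on inf/nan
    scaler.update()
    return loss.detach()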
Example #12
@contact: [email protected]
"""

import argparse
import os
import sys

import tensorrt as trt

from trt_calibrator import FeatEntropyCalibrator

sys.path.append('.')

from fastreid.utils.file_io import PathManager
from fastreid.utils.logger import setup_logger

logger = setup_logger(name="trt_export")


def get_parser():
    parser = argparse.ArgumentParser(description="Convert ONNX to TRT model")

    parser.add_argument(
        '--name',
        default='baseline',
        help="name for converted model"
    )
    parser.add_argument(
        '--output',
        default='outputs/trt_model',
        help="path to save converted trt model"
    )
Example #13

import logging
import sys

from torch.backends import cudnn

sys.path.append('.')

from fastreid.evaluation import evaluate_rank
from fastreid.config import get_cfg
from fastreid.utils.logger import setup_logger
from fastreid.data import build_reid_test_loader
from predictor import FeatureExtractionDemo
from fastreid.utils.visualizer import Visualizer

# import some modules added in project
# for example, add partial reid like this below
# sys.path.append("projects/PartialReID")
# from partialreid import *

cudnn.benchmark = True
setup_logger(name="fastreid")

logger = logging.getLogger('fastreid.visualize_result')


def setup_cfg(args):
    # load config from file and command-line arguments
    cfg = get_cfg()
    # add_partialreid_config(cfg)
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    return cfg


def get_parser():
Example #14
    def __init__(self, cfg):
        super().__init__()
        logger = logging.getLogger("fastreid")
        if not logger.isEnabledFor(
                logging.INFO):  # if setup_logger is not called for fastreid
            setup_logger()

        # Create datasets
        logger.info("==> Load source-domain dataset")
        self.src = src = self.load_dataset(cfg.DATASETS.SRC)
        self.src_pid_nums = src.get_num_pids(src.train)

        logger.info("==> Load target-domain dataset")
        self.tgt = tgt = self.load_dataset(cfg.DATASETS.TGT)
        self.tgt_nums = len(tgt.train)

        # Create model
        self.model = self.build_model(cfg,
                                      load_model=False,
                                      show_model=False,
                                      use_dsbn=True)

        # Create hybrid memory
        self.hm = HybridMemory(num_features=cfg.MODEL.BACKBONE.FEAT_DIM,
                               num_samples=self.src_pid_nums + self.tgt_nums,
                               temp=cfg.MEMORY.TEMP,
                               momentum=cfg.MEMORY.MOMENTUM,
                               use_half=cfg.SOLVER.AMP.ENABLED).cuda()

        # Initialize source-domain class centroids
        logger.info(
            "==> Initialize source-domain class centroids in the hybrid memory"
        )
        with inference_context(self.model), torch.no_grad():
            src_train = self.build_dataset(cfg,
                                           src.train,
                                           is_train=False,
                                           relabel=False,
                                           with_mem_idx=False)
            src_init_feat_loader = self.build_test_loader(cfg, src_train)
            src_fname_feat_dict, _ = extract_features(self.model,
                                                      src_init_feat_loader)
            src_feat_dict = collections.defaultdict(list)
            for f, pid, _ in sorted(src.train):
                src_feat_dict[pid].append(src_fname_feat_dict[f].unsqueeze(0))
            src_centers = [
                torch.cat(src_feat_dict[pid], 0).mean(0)
                for pid in sorted(src_feat_dict.keys())
            ]
            src_centers = torch.stack(src_centers, 0)
            src_centers = F.normalize(src_centers, dim=1)

        # Initialize target-domain instance features
        logger.info(
            "==> Initialize target-domain instance features in the hybrid memory"
        )
        with inference_context(self.model), torch.no_grad():
            tgt_train = self.build_dataset(cfg,
                                           tgt.train,
                                           is_train=False,
                                           relabel=False,
                                           with_mem_idx=False)
            tgt_init_feat_loader = self.build_test_loader(cfg, tgt_train)
            tgt_fname_feat_dict, _ = extract_features(self.model,
                                                      tgt_init_feat_loader)
            tgt_features = torch.cat([
                tgt_fname_feat_dict[f].unsqueeze(0)
                for f, _, _ in sorted(self.tgt.train)
            ], 0)
            tgt_features = F.normalize(tgt_features, dim=1)

        self.hm.features = torch.cat((src_centers, tgt_features), dim=0).cuda()

        del (src_train, src_init_feat_loader, src_fname_feat_dict,
             src_feat_dict, src_centers, tgt_train, tgt_init_feat_loader,
             tgt_fname_feat_dict, tgt_features)

        # Optimizer
        self.optimizer, self.param_wrapper = self.build_optimizer(
            cfg, self.model)

        # For training, wrap with DDP; this isn't needed for inference.
        if comm.get_world_size() > 1:
            # See https://github.com/pytorch/pytorch/issues/22049: set
            # `find_unused_parameters=True` when some parameters receive no update.
            self.model = DistributedDataParallel(
                self.model,
                device_ids=[comm.get_local_rank()],
                broadcast_buffers=False,
                find_unused_parameters=True)

        # Learning rate scheduler
        self.iters_per_epoch = cfg.SOLVER.ITERS
        self.scheduler = self.build_lr_scheduler(cfg, self.optimizer,
                                                 self.iters_per_epoch)

        # Assume no other objects need to be checkpointed.
        # We can later make it checkpoint the stateful hooks
        self.checkpointer = Checkpointer(
            # Assume you want to save checkpoints together with logs/statistics
            self.model,
            cfg.OUTPUT_DIR,
            save_to_disk=comm.is_main_process(),
            optimizer=self.optimizer,
            **self.scheduler,
        )

        self.start_epoch = 0
        self.max_epoch = cfg.SOLVER.MAX_EPOCH
        self.max_iter = self.max_epoch * self.iters_per_epoch
        self.warmup_iters = cfg.SOLVER.WARMUP_ITERS
        self.delay_epochs = cfg.SOLVER.DELAY_EPOCHS
        self.cfg = cfg

        self.register_hooks(self.build_hooks())

        if cfg.SOLVER.AMP.ENABLED:
            unsupported = "AMPTrainer does not support single-process multi-device training!"
            if isinstance(self.model, DistributedDataParallel):
                assert not (self.model.device_ids
                            and len(self.model.device_ids) > 1), unsupported

            from torch.cuda.amp.grad_scaler import GradScaler
            self.grad_scaler = GradScaler()
        else:
            self.grad_scaler = None
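
The centroid initialization above reduces to averaging features per identity, then L2-normalizing, to seed the memory bank. A self-contained toy version:

import collections
import torch
import torch.nn.functional as F

feats = torch.randn(10, 128)  # 10 sample features
pids = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2, 2, 0])

buckets = collections.defaultdict(list)
for f, pid in zip(feats, pids.tolist()):
    buckets[pid].append(f)
centers = torch.stack([torch.stack(buckets[p]).mean(0)
                       for p in sorted(buckets)])
centers = F.normalize(centers, dim=1)  # one unit vector per identity
print(centers.shape)                   # torch.Size([3, 128])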
Example #15
    def __init__(self, cfg):
        """
        Args:
            cfg (CfgNode):
        """
        super().__init__()
        logger = logging.getLogger("fastreid")
        if not logger.isEnabledFor(
                logging.INFO):  # setup_logger is not called for fastreid
            setup_logger()

        # Assume these objects must be constructed in this order.
        data_loader = self.build_train_loader(cfg)
        cfg = self.auto_scale_hyperparams(cfg, data_loader.dataset.num_classes)

        self.model = self.build_model(cfg)

        self.optimizer, self.param_wrapper = self.build_optimizer(
            cfg, self.model)

        # For training, wrap with DDP; this isn't needed for inference.
        if comm.get_world_size() > 1:
            # See https://github.com/pytorch/pytorch/issues/22049: set
            # `find_unused_parameters=True` when some parameters receive no update.
            self.model = DistributedDataParallel(
                self.model,
                device_ids=[comm.get_local_rank()],
                broadcast_buffers=False,
            )

        self._data_loader_iter = iter(data_loader)
        self.iters_per_epoch = len(
            data_loader.dataset) // cfg.SOLVER.IMS_PER_BATCH
        self.scheduler = self.build_lr_scheduler(cfg, self.optimizer,
                                                 self.iters_per_epoch)

        # Assume no other objects need to be checkpointed.
        # We can later make it checkpoint the stateful hooks
        self.checkpointer = Checkpointer(
            # Assume you want to save checkpoints together with logs/statistics
            self.model,
            cfg.OUTPUT_DIR,
            save_to_disk=comm.is_main_process(),
            optimizer=self.optimizer,
            **self.scheduler,
        )

        self.start_epoch = 0
        self.max_epoch = cfg.SOLVER.MAX_EPOCH
        self.max_iter = self.max_epoch * self.iters_per_epoch
        self.warmup_iters = cfg.SOLVER.WARMUP_ITERS
        self.delay_epochs = cfg.SOLVER.DELAY_EPOCHS
        self.cfg = cfg

        self.register_hooks(self.build_hooks())

        if cfg.SOLVER.AMP.ENABLED:
            unsupported = f"[{self.__class__.__name__}] does not support single-process multi-device training!"
            if isinstance(self.model, DistributedDataParallel):
                assert not (self.model.device_ids
                            and len(self.model.device_ids) > 1), unsupported

            from torch.cuda.amp.grad_scaler import GradScaler
            self.grad_scaler = GradScaler()
        else:
            self.grad_scaler = None
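
Finally, the epoch-to-iteration bookkeeping shared by these trainers, as a worked example (values are illustrative; 12936 matches Market-1501's train split):

dataset_len = 12936          # number of training images
ims_per_batch = 64           # cfg.SOLVER.IMS_PER_BATCH
max_epoch = 120              # cfg.SOLVER.MAX_EPOCH

iters_per_epoch = dataset_len // ims_per_batch  # 202
max_iter = max_epoch * iters_per_epoch          # 24240
print(iters_per_epoch, max_iter)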