Code Example #1
File: uda_base.py Project: X-funbean/fast-reid
    def build_test_loader(cls, cfg, test_set):
        logger = logging.getLogger('fastreid')
        logger.info("Prepare testing loader")

        # test_loader = DataLoader(
        #     # Preprocessor(test_set),
        #     test_set,
        #     batch_size=cfg.TEST.IMS_PER_BATCH,
        #     num_workers=cfg.DATALOADER.NUM_WORKERS,
        #     shuffle=False,
        #     pin_memory=True,
        # )

        test_batch_size = cfg.TEST.IMS_PER_BATCH
        mini_batch_size = test_batch_size // comm.get_world_size()
        num_workers = cfg.DATALOADER.NUM_WORKERS
        data_sampler = samplers.InferenceSampler(len(test_set))
        batch_sampler = BatchSampler(data_sampler, mini_batch_size, drop_last=False)  # keep the final partial batch
        test_loader = DataLoaderX(
            comm.get_local_rank(),
            dataset=test_set,
            batch_sampler=batch_sampler,
            num_workers=num_workers,  # save some memory
            collate_fn=fast_batch_collator,
            pin_memory=True,
        )

        return test_loader
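For intuition about the sampler pair above: `samplers.InferenceSampler` hands each rank a contiguous slice of the test-set indices, and `BatchSampler` then cuts that slice into per-GPU mini-batches with `drop_last=False`, so every sample is evaluated exactly once. A minimal, self-contained sketch of that sharding idea (fastreid mirrors detectron2 here; the formula below is an illustration, not the library's exact code):

def shard_indices(size: int, rank: int, world_size: int) -> range:
    # Contiguous per-rank shard in the spirit of samplers.InferenceSampler.
    begin = size * rank // world_size
    end = size * (rank + 1) // world_size
    return range(begin, end)

world_size = 4  # hypothetical comm.get_world_size()
shards = [list(shard_indices(10, r, world_size)) for r in range(world_size)]
print(shards)  # [[0, 1], [2, 3, 4], [5, 6], [7, 8, 9]]: each sample appears exactly once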
Code Example #2
File: build.py Project: sixiping/fast-reid
def build_reid_train_loader(
    train_set,
    *,
    sampler=None,
    total_batch_size,
    num_workers=0,
):
    """
    Build a dataloader for object re-identification with some default features.
    This interface is experimental.

    Returns:
        torch.utils.data.DataLoader: a dataloader.
    """

    mini_batch_size = total_batch_size // comm.get_world_size()

    batch_sampler = torch.utils.data.sampler.BatchSampler(
        sampler, mini_batch_size, drop_last=True)  # fixed-size training batches

    train_loader = DataLoaderX(
        comm.get_local_rank(),
        dataset=train_set,
        num_workers=num_workers,
        batch_sampler=batch_sampler,
        collate_fn=fast_batch_collator,
        pin_memory=True,
    )

    return train_loader
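The `drop_last=True` passed to `BatchSampler` above means the train loader drops the final partial batch, so every training step sees a full mini-batch, whereas the test loaders in the other examples pass `False` to keep it. A standalone demonstration with hypothetical sizes:

from torch.utils.data import BatchSampler, SequentialSampler

sampler = SequentialSampler(range(100))  # pretend dataset of 100 samples
train_batches = list(BatchSampler(sampler, 32, drop_last=True))
test_batches = list(BatchSampler(sampler, 32, drop_last=False))
print([len(b) for b in train_batches])  # [32, 32, 32]: 4 trailing samples dropped
print([len(b) for b in test_batches])   # [32, 32, 32, 4]: every sample kept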
Code Example #3
File: defaults.py Project: xhuljl/fast-reid
    def __init__(self, cfg):
        """
        Args:
            cfg (CfgNode):
        """
        super().__init__()
        logger = logging.getLogger("fastreid")
        if not logger.isEnabledFor(
                logging.INFO):  # setup_logger is not called for fastreid
            setup_logger()

        # Assume these objects must be constructed in this order.
        data_loader = self.build_train_loader(cfg)
        cfg = self.auto_scale_hyperparams(cfg, data_loader.dataset.num_classes)
        model = self.build_model(cfg)
        optimizer, param_wrapper = self.build_optimizer(cfg, model)

        # For training, wrap with DDP; this is not needed for inference.
        if comm.get_world_size() > 1:
            # See https://github.com/pytorch/pytorch/issues/22049 for why
            # `find_unused_parameters=True` should be set when some of the
            # parameters are not updated.
            model = DistributedDataParallel(
                model,
                device_ids=[comm.get_local_rank()],
                broadcast_buffers=False,
            )

        self._trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else
                         SimpleTrainer)(model, data_loader, optimizer,
                                        param_wrapper)

        self.iters_per_epoch = len(
            data_loader.dataset) // cfg.SOLVER.IMS_PER_BATCH
        self.scheduler = self.build_lr_scheduler(cfg, optimizer,
                                                 self.iters_per_epoch)

        # Assume no other objects need to be checkpointed.
        # We can later make it checkpoint the stateful hooks
        self.checkpointer = Checkpointer(
            # Assume you want to save checkpoints together with logs/statistics
            model,
            cfg.OUTPUT_DIR,
            save_to_disk=comm.is_main_process(),
            optimizer=optimizer,
            **self.scheduler,
        )

        self.start_epoch = 0
        self.max_epoch = cfg.SOLVER.MAX_EPOCH
        self.max_iter = self.max_epoch * self.iters_per_epoch
        self.warmup_iters = cfg.SOLVER.WARMUP_ITERS
        self.delay_epochs = cfg.SOLVER.DELAY_EPOCHS
        self.cfg = cfg

        self.register_hooks(self.build_hooks())
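A worked example of the schedule arithmetic above, with hypothetical config values; note the integer division, so up to IMS_PER_BATCH - 1 trailing samples per epoch do not contribute an extra iteration:

dataset_len = 12936   # hypothetical len(data_loader.dataset)
ims_per_batch = 64    # hypothetical cfg.SOLVER.IMS_PER_BATCH
max_epoch = 60        # hypothetical cfg.SOLVER.MAX_EPOCH

iters_per_epoch = dataset_len // ims_per_batch  # 202 (the 8 leftover samples add no iteration)
max_iter = max_epoch * iters_per_epoch          # 12120
print(iters_per_epoch, max_iter)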
Code Example #4
    def __init__(self, embedding_size, num_classes, sample_rate, cls_type,
                 scale, margin):
        super().__init__()

        self.embedding_size = embedding_size
        self.num_classes = num_classes
        self.sample_rate = sample_rate

        self.world_size = comm.get_world_size()
        self.rank = comm.get_rank()
        self.local_rank = comm.get_local_rank()
        self.device = torch.device(f'cuda:{self.local_rank}')

        self.num_local: int = self.num_classes // self.world_size + int(
            self.rank < self.num_classes % self.world_size)
        self.class_start: int = self.num_classes // self.world_size * self.rank + \
                                min(self.rank, self.num_classes % self.world_size)
        self.num_sample: int = int(self.sample_rate * self.num_local)

        self.cls_layer = getattr(any_softmax, cls_type)(num_classes, scale,
                                                        margin)
        """ TODO: consider resume training
        if resume:
            try:
                self.weight: torch.Tensor = torch.load(self.weight_name)
                logging.info("softmax weight resume successfully!")
            except (FileNotFoundError, KeyError, IndexError):
                self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device)
                logging.info("softmax weight resume fail!")

            try:
                self.weight_mom: torch.Tensor = torch.load(self.weight_mom_name)
                logging.info("softmax weight mom resume successfully!")
            except (FileNotFoundError, KeyError, IndexError):
                self.weight_mom: torch.Tensor = torch.zeros_like(self.weight)
                logging.info("softmax weight mom resume fail!")
        else:
        """
        self.weight = torch.normal(
            0, 0.01, (self.num_local, self.embedding_size), device=self.device)
        self.weight_mom: torch.Tensor = torch.zeros_like(self.weight)
        logging.info("softmax weight initialized successfully!")
        logging.info("softmax weight mom initialized successfully!")
        self.stream: torch.cuda.Stream = torch.cuda.Stream(self.local_rank)

        self.index = None
        if int(self.sample_rate) == 1:
            self.update = lambda: 0
            self.sub_weight = nn.Parameter(self.weight)
            self.sub_weight_mom = self.weight_mom
        else:
            self.sub_weight = nn.Parameter(
                torch.empty((0, 0), device=self.device))
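The `num_local` and `class_start` formulas above shard the classifier weights across ranks. A pure-Python check, with hypothetical sizes, that the shards are disjoint and cover every class:

def partition(num_classes: int, rank: int, world_size: int):
    # Same arithmetic as the Partial-FC constructor above.
    num_local = num_classes // world_size + int(rank < num_classes % world_size)
    class_start = num_classes // world_size * rank + min(rank, num_classes % world_size)
    return class_start, num_local

num_classes, world_size = 10, 3  # hypothetical values
covered = []
for rank in range(world_size):
    start, n = partition(num_classes, rank, world_size)
    covered.extend(range(start, start + n))
print(covered)  # [0, 1, ..., 9]: a disjoint, complete split of all classes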
Code Example #5
File: defaults.py Project: garyliu0816/fast-reid
    def __init__(self, cfg):
        """
        Args:
            cfg (CfgNode):
        """
        logger = logging.getLogger("fastreid")
        if not logger.isEnabledFor(
                logging.INFO):  # setup_logger is not called for fastreid
            setup_logger()

        # Assume these objects must be constructed in this order.
        data_loader = self.build_train_loader(cfg)
        cfg = self.auto_scale_hyperparams(cfg, data_loader)
        model = self.build_model(cfg)
        optimizer = self.build_optimizer(cfg, model)

        # For training, wrap with DDP; this is not needed for inference.
        if comm.get_world_size() > 1:
            # See https://github.com/pytorch/pytorch/issues/22049 for why
            # `find_unused_parameters=True` should be set when some of the
            # parameters are not updated.
            model = DistributedDataParallel(model,
                                            device_ids=[comm.get_local_rank()],
                                            broadcast_buffers=False)

        super().__init__(model, data_loader, optimizer, cfg.SOLVER.BASE_LR,
                         cfg.MODEL.LOSSES.CENTER.LR,
                         cfg.MODEL.LOSSES.CENTER.SCALE, cfg.SOLVER.AMP_ENABLED)

        self.scheduler = self.build_lr_scheduler(cfg, optimizer)
        # Assume no other objects need to be checkpointed.
        # We can later make it checkpoint the stateful hooks
        self.checkpointer = Checkpointer(
            # Assume you want to save checkpoints together with logs/statistics
            model,
            cfg.OUTPUT_DIR,
            save_to_disk=comm.is_main_process(),
            optimizer=optimizer,
            scheduler=self.scheduler,
        )
        self.start_iter = 0
        if cfg.SOLVER.SWA.ENABLED:
            self.max_iter = cfg.SOLVER.MAX_ITER + cfg.SOLVER.SWA.ITER
        else:
            self.max_iter = cfg.SOLVER.MAX_ITER

        self.cfg = cfg

        self.register_hooks(self.build_hooks())
Code Example #6
def main(args):
    cfg = setup(args)

    model = build_model(cfg)
    logger.info("Model:\n{}".format(model))
    if args.eval_only:
        cfg.defrost()
        cfg.MODEL.BACKBONE.PRETRAIN = False

        Checkpointer(model).load(cfg.MODEL.WEIGHTS)  # load trained model

        return do_test(cfg, model)

    distributed = comm.get_world_size() > 1
    if distributed:
        model = DistributedDataParallel(model,
                                        device_ids=[comm.get_local_rank()],
                                        broadcast_buffers=False)

    do_train(cfg, model, resume=args.resume)
    return do_test(cfg, model)
Code Example #7
File: build.py Project: sixiping/fast-reid
def build_reid_test_loader(test_set,
                           test_batch_size,
                           num_query,
                           num_workers=4):
    """
    Similar to `build_reid_train_loader`. This sampler coordinates all workers
    to produce the exact set of all samples.
    This interface is experimental.

    Args:
        test_set: the test dataset to load.
        test_batch_size: the total test batch size across all GPUs.
        num_query: the number of query samples in the test set.
        num_workers: the number of dataloader worker processes.

    Returns:
        DataLoader: a torch DataLoader, that loads the given reid dataset, with
        the test-time transformation.

    Examples:
    ::
        data_loader = build_reid_test_loader(test_set, test_batch_size, num_query)
        # or, instantiate with a CfgNode:
        data_loader = build_reid_test_loader(cfg, "my_test")
    """

    mini_batch_size = test_batch_size // comm.get_world_size()
    data_sampler = samplers.InferenceSampler(len(test_set))
    batch_sampler = torch.utils.data.BatchSampler(
        data_sampler, mini_batch_size, drop_last=False)  # evaluate every sample
    test_loader = DataLoaderX(
        comm.get_local_rank(),
        dataset=test_set,
        batch_sampler=batch_sampler,
        num_workers=num_workers,  # save some memory
        collate_fn=fast_batch_collator,
        pin_memory=True,
    )
    return test_loader, num_query
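Downstream, the `num_query` returned alongside the loader is conventionally used to split the extracted test features into query and gallery sets (the first `num_query` test samples are the queries). A minimal sketch with stand-in tensors, not fastreid's evaluator code:

import torch

num_query = 3
feats = torch.randn(10, 4)         # stand-in for features extracted from test_loader
query_feats = feats[:num_query]    # shape (3, 4)
gallery_feats = feats[num_query:]  # shape (7, 4)
dist = torch.cdist(query_feats, gallery_feats)  # query-to-gallery distance matrix
print(dist.shape)                  # torch.Size([3, 7])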
Code Example #8
    def __init__(self, cfg):
        TrainerBase.__init__(self)

        logger = logging.getLogger('fastreid.partial-fc.trainer')
        if not logger.isEnabledFor(
                logging.INFO):  # setup_logger is not called for fastreid
            setup_logger()

        # Assume these objects must be constructed in this order.
        data_loader = self.build_train_loader(cfg)
        cfg = self.auto_scale_hyperparams(cfg, data_loader.dataset.num_classes)
        model = self.build_model(cfg)
        optimizer = self.build_optimizer(cfg, model)

        if cfg.MODEL.HEADS.PFC.ENABLED:
            # fmt: off
            feat_dim = cfg.MODEL.BACKBONE.FEAT_DIM
            embedding_dim = cfg.MODEL.HEADS.EMBEDDING_DIM
            num_classes = cfg.MODEL.HEADS.NUM_CLASSES
            sample_rate = cfg.MODEL.HEADS.PFC.SAMPLE_RATE
            cls_type = cfg.MODEL.HEADS.CLS_LAYER
            scale = cfg.MODEL.HEADS.SCALE
            margin = cfg.MODEL.HEADS.MARGIN
            # fmt: on
            # Partial-FC module
            embedding_size = embedding_dim if embedding_dim > 0 else feat_dim
            self.pfc_module = PartialFC(embedding_size, num_classes,
                                        sample_rate, cls_type, scale, margin)
            self.pfc_optimizer = self.build_optimizer(cfg, self.pfc_module)

        # For training, wrap with DDP; this is not needed for inference.
        if comm.get_world_size() > 1:
            # See https://github.com/pytorch/pytorch/issues/22049 for why
            # `find_unused_parameters=True` should be set when some of the
            # parameters are not updated.
            model = DistributedDataParallel(model,
                                            device_ids=[comm.get_local_rank()],
                                            broadcast_buffers=False,
                                            find_unused_parameters=True)

        self._trainer = PFCTrainer(model, data_loader, optimizer, self.pfc_module, self.pfc_optimizer) \
            if cfg.MODEL.HEADS.PFC.ENABLED else SimpleTrainer(model, data_loader, optimizer)

        self.iters_per_epoch = len(
            data_loader.dataset) // cfg.SOLVER.IMS_PER_BATCH
        self.scheduler = self.build_lr_scheduler(cfg, optimizer,
                                                 self.iters_per_epoch)
        if cfg.MODEL.HEADS.PFC.ENABLED:
            self.pfc_scheduler = self.build_lr_scheduler(
                cfg, self.pfc_optimizer, self.iters_per_epoch)

        self.checkpointer = Checkpointer(
            # Assume you want to save checkpoints together with logs/statistics
            model,
            cfg.OUTPUT_DIR,
            save_to_disk=comm.is_main_process(),
            optimizer=optimizer,
            **self.scheduler,
        )

        if cfg.MODEL.HEADS.PFC.ENABLED:
            self.pfc_checkpointer = Checkpointer(
                self.pfc_module,
                cfg.OUTPUT_DIR,
                optimizer=self.pfc_optimizer,
                **self.pfc_scheduler,
            )

        self.start_epoch = 0
        self.max_epoch = cfg.SOLVER.MAX_EPOCH
        self.max_iter = self.max_epoch * self.iters_per_epoch
        self.warmup_iters = cfg.SOLVER.WARMUP_ITERS
        self.delay_epochs = cfg.SOLVER.DELAY_EPOCHS
        self.cfg = cfg

        self.register_hooks(self.build_hooks())
Code Example #9
File: uda_base.py Project: X-funbean/fast-reid
    def __init__(self, cfg):
        super().__init__()
        logger = logging.getLogger("fastreid")
        if not logger.isEnabledFor(
                logging.INFO):  # if setup_logger is not called for fastreid
            setup_logger()

        logger.info("==> Load target-domain dataset")
        self.tgt = tgt = self.load_dataset(cfg.DATASETS.TGT)
        self.tgt_nums = len(tgt.train)

        cfg = self.auto_scale_hyperparams(cfg, self.tgt_nums)

        # Create model
        self.model = self.build_model(cfg,
                                      load_model=cfg.MODEL.PRETRAIN,
                                      show_model=True,
                                      use_dsbn=False)

        # Optimizer
        self.optimizer, self.param_wrapper = self.build_optimizer(
            cfg, self.model)

        # For training, wrap with DDP; this is not needed for inference.
        if comm.get_world_size() > 1:
            # See https://github.com/pytorch/pytorch/issues/22049 for why
            # `find_unused_parameters=True` should be set when some of the
            # parameters are not updated.
            self.model = DistributedDataParallel(
                self.model,
                device_ids=[comm.get_local_rank()],
                broadcast_buffers=False,
                find_unused_parameters=True)

        # Learning rate scheduler
        self.iters_per_epoch = cfg.SOLVER.ITERS
        self.scheduler = self.build_lr_scheduler(cfg, self.optimizer,
                                                 self.iters_per_epoch)

        # Assume no other objects need to be checkpointed.
        # We can later make it checkpoint the stateful hooks
        self.checkpointer = Checkpointer(
            # Assume you want to save checkpoints together with logs/statistics
            self.model,
            cfg.OUTPUT_DIR,
            save_to_disk=comm.is_main_process(),
            optimizer=self.optimizer,
            **self.scheduler,
        )

        self.start_epoch = 0
        self.max_epoch = cfg.SOLVER.MAX_EPOCH
        self.max_iter = self.max_epoch * self.iters_per_epoch
        self.warmup_iters = cfg.SOLVER.WARMUP_ITERS
        self.delay_epochs = cfg.SOLVER.DELAY_EPOCHS
        self.cfg = cfg

        self.register_hooks(self.build_hooks())

        if cfg.SOLVER.AMP.ENABLED:
            unsupported = "AMPTrainer does not support single-process multi-device training!"
            if isinstance(self.model, DistributedDataParallel):
                assert not (self.model.device_ids
                            and len(self.model.device_ids) > 1), unsupported

            from torch.cuda.amp.grad_scaler import GradScaler
            self.grad_scaler = GradScaler()
        else:
            self.grad_scaler = None
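For reference, a scaler like the `self.grad_scaler` created above is typically driven as follows inside a training step. This is standard `torch.cuda.amp` usage with a placeholder model, optimizer, and loss, not this trainer's actual loop; it requires a CUDA device:

import torch
from torch.cuda.amp import GradScaler, autocast

model = torch.nn.Linear(8, 2).cuda()                      # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # placeholder optimizer
scaler = GradScaler()

inputs = torch.randn(4, 8, device="cuda")
targets = torch.randint(0, 2, (4,), device="cuda")

optimizer.zero_grad()
with autocast():  # run the forward pass in mixed precision
    loss = torch.nn.functional.cross_entropy(model(inputs), targets)
scaler.scale(loss).backward()  # scale the loss to avoid fp16 gradient underflow
scaler.step(optimizer)         # unscales gradients; skips the step on inf/nan
scaler.update()                # adjust the scale factor for the next iteration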
Code Example #10
    def __init__(self, cfg):
        super().__init__()
        logger = logging.getLogger("fastreid")
        if not logger.isEnabledFor(
                logging.INFO):  # if setup_logger is not called for fastreid
            setup_logger()

        # Create datasets
        logger.info("==> Load source-domain dataset")
        self.src = src = self.load_dataset(cfg.DATASETS.SRC)
        self.src_pid_nums = src.get_num_pids(src.train)

        logger.info("==> Load target-domain dataset")
        self.tgt = tgt = self.load_dataset(cfg.DATASETS.TGT)
        self.tgt_nums = len(tgt.train)

        # Create model
        self.model = self.build_model(cfg,
                                      load_model=False,
                                      show_model=False,
                                      use_dsbn=True)

        # Create hybrid memory
        self.hm = HybridMemory(num_features=cfg.MODEL.BACKBONE.FEAT_DIM,
                               num_samples=self.src_pid_nums + self.tgt_nums,
                               temp=cfg.MEMORY.TEMP,
                               momentum=cfg.MEMORY.MOMENTUM,
                               use_half=cfg.SOLVER.AMP.ENABLED).cuda()

        # Initialize source-domain class centroids
        logger.info(
            "==> Initialize source-domain class centroids in the hybrid memory"
        )
        with inference_context(self.model), torch.no_grad():
            src_train = self.build_dataset(cfg,
                                           src.train,
                                           is_train=False,
                                           relabel=False,
                                           with_mem_idx=False)
            src_init_feat_loader = self.build_test_loader(cfg, src_train)
            src_fname_feat_dict, _ = extract_features(self.model,
                                                      src_init_feat_loader)
            src_feat_dict = collections.defaultdict(list)
            for f, pid, _ in sorted(src.train):
                src_feat_dict[pid].append(src_fname_feat_dict[f].unsqueeze(0))
            src_centers = [
                torch.cat(src_feat_dict[pid], 0).mean(0)
                for pid in sorted(src_feat_dict.keys())
            ]
            src_centers = torch.stack(src_centers, 0)
            src_centers = F.normalize(src_centers, dim=1)

        # Initialize target-domain instance features
        logger.info(
            "==> Initialize target-domain instance features in the hybrid memory"
        )
        with inference_context(self.model), torch.no_grad():
            tgt_train = self.build_dataset(cfg,
                                           tgt.train,
                                           is_train=False,
                                           relabel=False,
                                           with_mem_idx=False)
            tgt_init_feat_loader = self.build_test_loader(cfg, tgt_train)
            tgt_fname_feat_dict, _ = extract_features(self.model,
                                                      tgt_init_feat_loader)
            tgt_features = torch.cat([
                tgt_fname_feat_dict[f].unsqueeze(0)
                for f, _, _ in sorted(self.tgt.train)
            ], 0)
            tgt_features = F.normalize(tgt_features, dim=1)

        self.hm.features = torch.cat((src_centers, tgt_features), dim=0).cuda()

        del (src_train, src_init_feat_loader, src_fname_feat_dict,
             src_feat_dict, src_centers, tgt_train, tgt_init_feat_loader,
             tgt_fname_feat_dict, tgt_features)

        # Optimizer
        self.optimizer, self.param_wrapper = self.build_optimizer(
            cfg, self.model)

        # For training, wrap with DDP; this is not needed for inference.
        if comm.get_world_size() > 1:
            # See https://github.com/pytorch/pytorch/issues/22049 for why
            # `find_unused_parameters=True` should be set when some of the
            # parameters are not updated.
            self.model = DistributedDataParallel(
                self.model,
                device_ids=[comm.get_local_rank()],
                broadcast_buffers=False,
                find_unused_parameters=True)

        # Learning rate scheduler
        self.iters_per_epoch = cfg.SOLVER.ITERS
        self.scheduler = self.build_lr_scheduler(cfg, self.optimizer,
                                                 self.iters_per_epoch)

        # Assume no other objects need to be checkpointed.
        # We can later make it checkpoint the stateful hooks
        self.checkpointer = Checkpointer(
            # Assume you want to save checkpoints together with logs/statistics
            self.model,
            cfg.OUTPUT_DIR,
            save_to_disk=comm.is_main_process(),
            optimizer=self.optimizer,
            **self.scheduler,
        )

        self.start_epoch = 0
        self.max_epoch = cfg.SOLVER.MAX_EPOCH
        self.max_iter = self.max_epoch * self.iters_per_epoch
        self.warmup_iters = cfg.SOLVER.WARMUP_ITERS
        self.delay_epochs = cfg.SOLVER.DELAY_EPOCHS
        self.cfg = cfg

        self.register_hooks(self.build_hooks())

        if cfg.SOLVER.AMP.ENABLED:
            unsupported = "AMPTrainer does not support single-process multi-device training!"
            if isinstance(self.model, DistributedDataParallel):
                assert not (self.model.device_ids
                            and len(self.model.device_ids) > 1), unsupported

            from torch.cuda.amp.grad_scaler import GradScaler
            self.grad_scaler = GradScaler()
        else:
            self.grad_scaler = None
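The source-centroid initialization above averages all features of each identity and L2-normalizes the result. The same computation with random stand-in features, self-contained and without fastreid imports:

import collections
import torch
import torch.nn.functional as F

feat_dim = 4
fake_train = [("img_%d.jpg" % i, i % 3, 0) for i in range(9)]   # (fname, pid, camid)
fname_feat = {f: torch.randn(feat_dim) for f, _, _ in fake_train}

feat_dict = collections.defaultdict(list)
for f, pid, _ in sorted(fake_train):
    feat_dict[pid].append(fname_feat[f].unsqueeze(0))

centers = torch.stack(
    [torch.cat(feat_dict[pid], 0).mean(0) for pid in sorted(feat_dict)], 0)
centers = F.normalize(centers, dim=1)  # one unit-norm centroid per identity
print(centers.shape)  # torch.Size([3, 4])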