Example #1
    def __init__(self, cfg):
        self._logger = setup_logger(__name__, all_rank=True)
        
        # set on every rank, not just the main process, so later visualization calls can use it
        self.visualize_dir = cfg.VISUALIZE_DIR
        if dist.is_main_process():
            self._logger.debug(f'Config File : \n{cfg}')
            if cfg.VISUALIZE_DIR and not os.path.isdir(cfg.VISUALIZE_DIR):
                os.makedirs(cfg.VISUALIZE_DIR)
        dist.synchronize()
        
        self.test_loader = build_test_loader(cfg)

        self.model = build_model(cfg)
        self.model.eval()
        if dist.is_main_process():
            self._logger.debug(f"Model Structure\n{self.model}")
                
        if dist.get_world_size() > 1:
            self.model = DistributedDataParallel(self.model, device_ids=[dist.get_local_rank()], broadcast_buffers=False)

        self.checkpointer = Checkpointer(
            self.model,
            cfg.OUTPUT_DIR,
        )
        self.checkpointer.load(cfg.WEIGHTS)

        self.meta_data = MetadataCatalog.get(cfg.LOADER.TEST_DATASET)
        self.class_color = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)]
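
Example #7 calls self.save_visualize(inputs, outputs), which is not part of this excerpt. The following is a minimal sketch of such a hook, assuming detectron2-style per-image dicts; the keys "file_name" and "instances" (with pred_boxes / pred_classes) are assumptions, not taken from the source.

import os
import cv2

def save_visualize_sketch(visualize_dir, class_color, inputs, outputs):
    # Hypothetical drawing hook; visualize_dir / class_color mirror the
    # attributes set in __init__ above. Everything about the inputs/outputs
    # schema here is an assumption.
    for inp, out in zip(inputs, outputs):
        img = cv2.imread(inp["file_name"])
        instances = out["instances"].to("cpu")
        for box, cls in zip(instances.pred_boxes.tensor.tolist(),
                            instances.pred_classes.tolist()):
            x0, y0, x1, y1 = (int(v) for v in box)
            cv2.rectangle(img, (x0, y0), (x1, y1), class_color[cls % len(class_color)], 2)
        cv2.imwrite(os.path.join(visualize_dir, os.path.basename(inp["file_name"])), img)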
Example #2
def build_train_loader(cfg):
    images_per_batch = cfg.SOLVER.IMG_PER_BATCH
    assert images_per_batch >= dist.get_world_size()
    assert images_per_batch % dist.get_world_size() == 0

    data = [
        DatasetCatalog.get(train_dataset)
        for train_dataset in cfg.LOADER.TRAIN_DATASET
    ]
    data = list(itertools.chain.from_iterable(data))

    dataset = ListDataset(cfg, data)
    mapper = DatasetMapper(cfg, is_train=True)
    dataset = MapDataset(dataset, mapper)

    sampler = IterSampler(cfg, dataset)

    if cfg.LOADER.ASPECT_GROUPING:
        data_loader = torch.utils.data.DataLoader(
            dataset,
            sampler=sampler,
            batch_sampler=None,
            num_workers=cfg.LOADER.NUM_WORKERS,
            collate_fn=operator.itemgetter(0),
            worker_init_fn=worker_init_reset_seed,
        )
        return AspectRatioGroupedDataset(
            data_loader, images_per_batch // dist.get_world_size())

    else:
        batch_sampler = torch.utils.data.sampler.BatchSampler(
            sampler, images_per_batch // dist.get_world_size(),
            drop_last=True)  # drop the last partial batch so every batch has the same size

        data_loader = torch.utils.data.DataLoader(
            dataset,
            batch_sampler=batch_sampler,
            num_workers=cfg.LOADER.NUM_WORKERS,
            collate_fn=trivial_batch_collator,
            worker_init_fn=worker_init_reset_seed,
        )
        return data_loader
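
trivial_batch_collator and worker_init_reset_seed are referenced above but not defined in this excerpt. The sketches below follow the common detectron2-style implementations and are assumptions, not code from this repository; seed_all_rng is the seeding helper used in Example #6.

import numpy as np

def trivial_batch_collator(batch):
    # Return the list of dataset dicts untouched; the model does its own batching.
    return batch

def worker_init_reset_seed(worker_id):
    # Give every DataLoader worker its own RNG stream so augmentations differ
    # between workers instead of repeating the parent process seed.
    seed_all_rng(np.random.randint(2 ** 31) + worker_id)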
Example #3
    def __init__(self, cfg, dataset):
        self._size = len(dataset)
        assert self._size > 0
        self._shuffle = cfg.LOADER.TRAIN_SHUFFLE
        if cfg.SEED < 0:
            self._seed = int(dist.shared_random_seed())
        else:
            self._seed = int(cfg.SEED)

        self._rank = dist.get_rank()
        self._world_size = dist.get_world_size()
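
The sampler's __iter__ is not included in the excerpt. Below is one common way an iteration-based sampler turns the state stored above into an endless, rank-sliced index stream (assumes torch and itertools are imported in this module); it is a sketch, not necessarily this repository's implementation.

    def __iter__(self):
        g = torch.Generator()
        g.manual_seed(self._seed)

        def indices():
            while True:
                if self._shuffle:
                    yield from torch.randperm(self._size, generator=g).tolist()
                else:
                    yield from range(self._size)

        # Each rank takes every world_size-th index, so shards never overlap.
        yield from itertools.islice(indices(), self._rank, None, self._world_size)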
Example #4
    def __init__(self, size: int):
        self._size = size
        assert size > 0
        self._rank = dist.get_rank()
        self._world_size = dist.get_world_size()

        shard_size = (self._size - 1) // self._world_size + 1
        begin = shard_size * self._rank
        end = min(shard_size * (self._rank + 1), self._size)
        self._local_indices = range(begin, end)
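
The same ceil-division sharding, written as a self-contained sampler with rank and world size passed in explicitly so it runs without a process group; the class name and constructor arguments are illustrative, not from the source.

import torch.utils.data

class ShardedInferenceSampler(torch.utils.data.Sampler):
    def __init__(self, size: int, rank: int, world_size: int):
        assert size > 0
        shard_size = (size - 1) // world_size + 1   # ceil(size / world_size)
        begin = shard_size * rank
        end = min(shard_size * (rank + 1), size)
        self._local_indices = range(begin, end)

    def __iter__(self):
        yield from self._local_indices

    def __len__(self):
        return len(self._local_indices)

# e.g. size=10, world_size=3 -> rank 0 iterates 0..3, rank 1 iterates 4..7, rank 2 iterates 8..9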
Example #5
def evaluator(model, data_loader, evaluators):
    num_devices = dist.get_world_size()
    _logger = setup_logger(__name__, all_rank=True)

    total = len(data_loader)  # inference data loader must have a fixed length
    _logger.info(f"Start inference on {total} images")

    if evaluators is None: evaluators = Evaluators([])
    evaluators.reset()

    timer = Timer(warmup=5, pause=True)
    total_compute_time = 0
    total_time = 0

    with inference_context(model), torch.no_grad():
        for idx, inputs in enumerate(data_loader):
            timer.resume()
            outputs = model(inputs)
            if torch.cuda.is_available(): torch.cuda.synchronize()
            timer.pause()
            evaluators.process(inputs, outputs)

            if timer.total_seconds() > 10:
                total_compute_time += timer.seconds()
                total_time += timer.total_seconds()
                timer.reset(pause=True)

                total_seconds_per_img = total_time / (idx + 1)
                seconds_per_img = total_compute_time / (idx + 1)
                eta = datetime.timedelta(seconds=int(total_seconds_per_img *
                                                     (total - idx - 1)))
                _logger.info(
                    f"Inference done {idx + 1}/{total}. {seconds_per_img:.4f} s / img. ETA={eta}"
                )

    total_compute_time += timer.seconds()
    total_time += timer.total_seconds()

    total_time_str = str(datetime.timedelta(seconds=total_time))
    _logger.info(
        f"Total inference time: {total_time_str} ({total_time / total:.6f} s / img per device, on {num_devices} devices)"
    )

    total_compute_time_str = str(
        datetime.timedelta(seconds=int(total_compute_time)))
    _logger.info(
        f"Total inference pure compute time: {total_compute_time_str} ({total_compute_time / total:.6f} s / img per device, on {num_devices} devices)"
    )

    results = evaluators.evaluate()
    if results is None: results = {}
    return results
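
inference_context is used above but not defined in this excerpt. A sketch of the usual helper, assuming it only toggles eval mode and restores the previous training flag:

from contextlib import contextmanager

@contextmanager
def inference_context(model):
    # Temporarily switch the model to eval mode and restore the previous
    # training flag when the block exits, even if inference raises.
    training_mode = model.training
    model.eval()
    try:
        yield
    finally:
        model.train(training_mode)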
Example #6
    def __init__(self, cfg):
        super().__init__(cfg)

        if cfg.SEED < 0:
            cfg.SEED = dist.shared_random_seed()
        self._seed = cfg.SEED
        seed_all_rng(self._seed)
        
        self._logger.debug(f'Config File : \n{cfg}')
        if dist.is_main_process():
            if cfg.OUTPUT_DIR and not os.path.isdir(cfg.OUTPUT_DIR):
                os.makedirs(cfg.OUTPUT_DIR)
            with open(os.path.join(cfg.OUTPUT_DIR, 'config'), 'w') as f:
                f.write(cfg.dump())
        dist.synchronize()
        
        self.train_loader = build_train_loader(cfg)
        self.test_loader = build_test_loader(cfg)
        self.train_iter = iter(self.train_loader)

        self.model = build_model(cfg)
        self.model.train()
        if dist.is_main_process():
            self._logger.debug(f"Model Structure\n{self.model}")
        
        self.optimizer = build_optimizer(cfg, self.model)
        self.optimizer.zero_grad()
        self.scheduler = build_lr_scheduler(cfg, self.optimizer)
        self.accumulate = cfg.SOLVER.ACCUMULATE
        
        if dist.get_world_size() > 1:
            self.model = DistributedDataParallel(self.model, device_ids=[dist.get_local_rank()], broadcast_buffers=False)

        self.weight_path = cfg.WEIGHTS
        self.checkpointer = Checkpointer(
            self.model,
            cfg.OUTPUT_DIR,
            optimizer=self.optimizer,
            scheduler=self.scheduler,
        )

        self.evaluator = build_evaluator(cfg)

        hooks = build_hooks(cfg, self.model, self.optimizer, self.scheduler, self.checkpointer)
        self.register_hooks(hooks)
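
The trainer's step function is not shown. The sketch below illustrates how the state set up above (train_iter, optimizer, scheduler, cfg.SOLVER.ACCUMULATE) could drive one gradient-accumulation step; the loss-dict interface and the step logic are assumptions, not the repository's actual run_step.

def run_step_sketch(model, optimizer, scheduler, train_iter, accumulate, iteration):
    data = next(train_iter)
    loss_dict = model(data)                        # model returns a dict of losses in train mode
    losses = sum(loss_dict.values()) / accumulate  # scale so accumulated grads match a full batch
    losses.backward()
    if (iteration + 1) % accumulate == 0:          # step the optimizer only every `accumulate` iters
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()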
Example #7
    def __call__(self):
        num_devices = dist.get_world_size()

        total = len(self.test_loader)  # inference data loader must have a fixed length
        self._logger.info(f"Start visualize on {total} images")

        timer = Timer(warmup=5, pause=True)
        total_compute_time = 0
        total_time = 0

        with inference_context(self.model), torch.no_grad():
            for idx, inputs in enumerate(self.test_loader):
                timer.resume()
                outputs = self.model(inputs)
                if torch.cuda.is_available(): torch.cuda.synchronize()
                timer.pause()

                self.save_visualize(inputs, outputs)

                if timer.total_seconds() > 10:
                    total_compute_time += timer.seconds()
                    total_time += timer.total_seconds()
                    timer.reset(pause=True)

                    total_seconds_per_img = total_time / (idx + 1)
                    seconds_per_img = total_compute_time / (idx + 1)
                    eta = datetime.timedelta(seconds=int(total_seconds_per_img * (total - idx - 1)))
                    self._logger.info(f"Visualize done {idx + 1}/{total}. {seconds_per_img:.4f} s / img. ETA={eta}")

        total_compute_time += timer.seconds()
        total_time += timer.total_seconds()

        total_time_str = str(datetime.timedelta(seconds=total_time))
        self._logger.info(
            f"Total Visualize time: {total_time_str} ({total_time / total:.6f} s / img per device, on {num_devices} devices)"
        )

        total_compute_time_str = str(datetime.timedelta(seconds=int(total_compute_time)))
        self._logger.info(
            f"Total visualize pure compute time: {total_compute_time_str} ({total_compute_time / total:.6f} s / img per device, on {num_devices} devices)"
        )
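
Examples #1 and #7 together form a visualization runner. A minimal driver sketch; the Visualizer class name is an assumption about how these two methods are packaged, not a name taken from the source.

def main(cfg):
    visualizer = Visualizer(cfg)   # __init__ from Example #1
    visualizer()                   # __call__ loop from Example #7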