def make_train_data_loader(is_distributed=False, start_iter=0, image_size_divisibility=32,
                           num_workers=1, batch_size=2, num_iters=90000, shuffle=True,
                           single_block=False):
    """Build the training DataLoader for the ``'train'`` split of Scarlet300MaskDataset.

    The global ``batch_size`` is split evenly across ``get_world_size()``
    processes, so each process draws ``batch_size // num_gpus`` images per
    batch. The per-process ``BatchSampler`` is wrapped in
    ``samplers.IterationBasedBatchSampler(..., num_iters, start_iter)``.

    Args:
        is_distributed: if True, sample with ``samplers.DistributedSampler``;
            otherwise use a local torch sampler.
        start_iter: iteration to start from, forwarded to
            IterationBasedBatchSampler.
        image_size_divisibility: not used by this function; kept for
            interface compatibility. NOTE(review): presumably meant to be passed
            to the batch collator; confirm against TTT.BatchCollator.
        num_workers: number of DataLoader worker processes.
        batch_size: total images per iteration across all processes.
        num_iters: iteration count forwarded to IterationBasedBatchSampler.
        shuffle: whether to shuffle sample order. It is honoured on both the
            distributed and the non-distributed path.
        single_block: not used by this function; kept for interface
            compatibility.

    Returns:
        A ``torch.utils.data.DataLoader`` driven by the iteration-based batch
        sampler and collated with ``TTT.BatchCollator()``.

    Raises:
        ValueError: if ``batch_size`` is not divisible by the world size.
    """
    num_gpus = get_world_size()
    # Validate explicitly rather than with `assert`, which is stripped under -O.
    if batch_size % num_gpus != 0:
        raise ValueError(
            "batch_size ({}) must be divisible by the number "
            "of GPUs ({}) used.".format(batch_size, num_gpus)
        )
    images_per_gpu = batch_size // num_gpus
    if images_per_gpu > 1:
        logger = logging.getLogger(__name__)
        logger.warning(
            "When using more than one image per GPU you may encounter "
            "an out-of-memory (OOM) error if your GPU does not have "
            "sufficient memory. If this happens, you can reduce "
            "SOLVER.IMS_PER_BATCH (for training) or "
            "TEST.IMS_PER_BATCH (for inference). For training, you must "
            "also adjust the learning rate and schedule length according "
            "to the linear scaling rule. See for example: "
            "https://github.com/facebookresearch/Detectron/blob/master/configs/getting_started/tutorial_1gpu_e2e_faster_rcnn_R-50-FPN.yaml#L14"
        )

    transforms = build_transforms(is_train=True)
    dataset = Scarlet300MaskDataset('train', transforms=transforms)

    # Previously the non-distributed path ignored `shuffle` and always used
    # RandomSampler; respect the flag on both paths.
    if is_distributed:
        sampler = samplers.DistributedSampler(dataset, shuffle=shuffle)
    elif shuffle:
        sampler = torch.utils.data.sampler.RandomSampler(dataset)
    else:
        sampler = torch.utils.data.sampler.SequentialSampler(dataset)

    batch_sampler = samplers.IterationBasedBatchSampler(
        torch.utils.data.sampler.BatchSampler(sampler, images_per_gpu, drop_last=False),
        num_iters,
        start_iter,
    )
    return torch.utils.data.DataLoader(
        dataset,
        num_workers=num_workers,
        batch_sampler=batch_sampler,
        collate_fn=TTT.BatchCollator(),
    )
def make_batch_data_sampler(
    dataset, sampler, aspect_grouping, images_per_batch, num_iters=None, start_iter=0
):
    """Turn an index ``sampler`` into a batch sampler.

    When ``aspect_grouping`` is truthy, the dataset's aspect ratios (from
    ``_compute_aspect_ratios``) are bucketed with ``_quantize``. A single
    non-list/tuple value is treated as a one-element list of bins. Batches are
    then formed with ``samplers.GroupedBatchSampler(..., drop_uneven=False)``.
    Otherwise a plain ``torch.utils.data.sampler.BatchSampler`` with
    ``drop_last=False`` is used.

    If ``num_iters`` is not None, the result is wrapped in
    ``samplers.IterationBasedBatchSampler(batch_sampler, num_iters, start_iter)``.

    Returns:
        The (possibly wrapped) batch sampler.
    """
    if not aspect_grouping:
        batcher = torch.utils.data.sampler.BatchSampler(
            sampler, images_per_batch, drop_last=False
        )
    else:
        bins = (
            aspect_grouping
            if isinstance(aspect_grouping, (list, tuple))
            else [aspect_grouping]
        )
        ratios = _compute_aspect_ratios(dataset)
        bucket_ids = _quantize(ratios, bins)
        batcher = samplers.GroupedBatchSampler(
            sampler, bucket_ids, images_per_batch, drop_uneven=False
        )

    # Without an iteration budget the batch sampler is returned as-is.
    if num_iters is None:
        return batcher
    return samplers.IterationBasedBatchSampler(batcher, num_iters, start_iter)