Example #1
    def test_batch_size_per_device(self):
        # Patch mmf.utils.general's get_world_size, not mmf.utils.distributed's,
        # since the former is the one actually used here
        with patch("mmf.utils.general.get_world_size", return_value=2):
            trainer = TrainerTrainingLoopMock(100, 2, None, batch_size=4)
            registry.register("config", trainer.config)
            batch_size = get_batch_size()
            trainer.config.training.batch_size = batch_size
            trainer.load_datasets()
            # The train loader uses the batch size per device: for a global
            # batch size of 4 with world size 2, it should be 4 // 2 = 2
            self.assertEqual(trainer.train_loader.current_loader.batch_size, 2)
            # batch_size_per_device is already per device, so it should stay the same
            trainer = TrainerTrainingLoopMock(100,
                                              2,
                                              None,
                                              batch_size_per_device=4)
            registry.register("config", trainer.config)
            batch_size = get_batch_size()
            trainer.config.training.batch_size = batch_size
            trainer.load_datasets()
            self.assertEqual(trainer.train_loader.current_loader.batch_size, 4)

        max_updates = trainer._calculate_max_updates()
        self.assertEqual(max_updates, 2)

        self.check_values(trainer, 0, 0, 0)
        trainer.training_loop()
        self.check_values(trainer, 2, 1, 2)
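
A minimal sketch of what get_batch_size() presumably computes, based on the assertions in this test and the training.batch_size / training.batch_size_per_device keys used throughout these examples; the real MMF implementation reads these values from the registered config and may differ in detail:

def get_batch_size_sketch(training_config, world_size):
    # Hypothetical re-implementation for illustration only, not MMF's actual code.
    # batch_size_per_device is already per device and is returned as-is.
    if training_config.get("batch_size_per_device") is not None:
        return training_config["batch_size_per_device"]
    # Otherwise training.batch_size is treated as the global batch size and is
    # split evenly across devices, e.g. 4 // 2 = 2 for world size 2.
    return training_config["batch_size"] // world_size

With the values used above this yields 2 for batch_size=4 and 4 for batch_size_per_device=4, matching the assertions on the train loaders.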
Example #2
def _add_extra_args_for_dataloader(
    dataset_instance: mmf_typings.DatasetType,
    other_args: mmf_typings.DataLoaderArgsType = None,
) -> mmf_typings.DataLoaderArgsType:
    from mmf.utils.general import get_batch_size, get_class_weight

    if other_args is None:
        other_args = {}
    dataset_type = dataset_instance.dataset_type

    if dataset_type != "test":
        other_args["shuffle"] = True
    else:
        other_args["shuffle"] = False
        # For the test split, sample according to class weights instead of shuffling
        other_args["sampler"] = WeightedRandomSampler(
            torch.from_numpy(np.array(get_class_weight())), get_batch_size()
        )
        # Shuffle is mutually exclusive with a custom sampler
        other_args.pop("shuffle")

    # In distributed mode, we use DistributedSampler from PyTorch
    if is_dist_initialized():
        # "shuffle" may already have been popped above, so read it defensively
        other_args["sampler"] = torch.utils.data.DistributedSampler(
            dataset_instance, shuffle=other_args.get("shuffle", False)
        )
        # Shuffle is mutually exclusive with sampler, let DistributedSampler
        # take care of shuffle and pop from main args
        other_args.pop("shuffle", None)

    other_args["batch_size"] = get_batch_size()
    return other_args
Example #3
def _add_extra_args_for_dataloader(
    dataset_instance: torch.utils.data.Dataset,
    other_args: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    from mmf.utils.general import get_batch_size

    if other_args is None:
        # Callers normally pass a dict that already contains these keys
        other_args = {"shuffle": None, "batch_size": None}
    dataset_type = dataset_instance.dataset_type
    if other_args["shuffle"] is None:
        other_args["shuffle"] = False
        if dataset_type != "test":
            other_args["shuffle"] = True

    # In distributed mode, we use DistributedSampler from PyTorch
    if is_dist_initialized():
        other_args["sampler"] = torch.utils.data.DistributedSampler(
            dataset_instance, shuffle=other_args["shuffle"]
        )
        # Shuffle is mutually exclusive with sampler, let DistributedSampler
        # take care of shuffle and pop from main args
        other_args.pop("shuffle")

    if is_xla():
        other_args["sampler"] = torch.utils.data.DistributedSampler(
            dataset_instance,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=other_args["shuffle"],
            drop_last=True,
        )
        other_args.pop("shuffle")

    if other_args["batch_size"] is None:
        other_args["batch_size"] = get_batch_size()

    return other_args
Example #4
def _add_extra_args_for_dataloader(
    dataset_instance: torch.utils.data.Dataset,
    other_args: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    from mmf.utils.general import get_batch_size

    if other_args is None:
        other_args = {}
    dataset_type = dataset_instance.dataset_type

    other_args["shuffle"] = False
    if dataset_type != "test":
        other_args["shuffle"] = True

    # In distributed mode, we use DistributedSampler from PyTorch
    if is_dist_initialized():
        other_args["sampler"] = torch.utils.data.DistributedSampler(
            dataset_instance, shuffle=other_args["shuffle"])
        # Shuffle is mutually exclusive with sampler, let DistributedSampler
        # take care of shuffle and pop from main args
        other_args.pop("shuffle")

    other_args["batch_size"] = get_batch_size()

    return other_args
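
For context, a sketch of how the returned dict is typically consumed; in MMF this happens inside build_dataloader_and_sampler() (referenced by the comments in Examples #9 and #14), and the call site below is a simplified assumption rather than the actual implementation:

# Assumed, simplified call site: the extra args are unpacked straight into a
# torch DataLoader. Because "shuffle" is popped whenever a sampler is set,
# the two mutually exclusive arguments never collide.
other_args = _add_extra_args_for_dataloader(dataset_instance)
loader = torch.utils.data.DataLoader(dataset_instance, **other_args)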
Example #5
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.batch_size = get_batch_size()

        self.train_loader = MultiDatasetLoader("train")
        self.val_loader = MultiDatasetLoader("val")
        self.test_loader = MultiDatasetLoader("test")

        self.train_loader.load(self.config)
        self.val_loader.load(self.config)
        self.test_loader.load(self.config)
Example #6
    def __init__(self, config: DictConfig):
        super().__init__()
        self.config = config
        self.batch_size = get_batch_size()

        self.dataset_list: List[str] = dataset_list_from_config(self.config)
        self.datamodules: List[pl.LightningDataModule] = build_multiple_datamodules(
            self.dataset_list, self.config.dataset_config
        )
        self.train_loader: Optional[MultiDataLoader] = None
        self.val_loader: Optional[MultiDataLoader] = None
        self.test_loader: Optional[MultiDataLoader] = None
Example #7
    def _add_extra_args_for_dataloader(self, other_args=None):
        if other_args is None:
            other_args = {}

        if is_dist_initialized():
            other_args["sampler"] = DistributedSampler(self.current_dataset,
                                                       shuffle=False)
        else:
            other_args["shuffle"] = False

        other_args["batch_size"] = get_batch_size()

        return other_args
Example #8
    def __init__(self, config, dataset_type, imdb_file_index, *args, **kwargs):
        super().__init__("airstore", config, dataset_type)

        self.pathmanager = create_path_manager()
        self.config = config
        self.batch_size = get_batch_size()
        self.airstore_uri = config.annotations.get(
            dataset_type)[imdb_file_index]
        self.split = dataset_type
        self.epoch = 0
        self.start_iter = 0
        self.global_rank = torch.distributed.get_rank()
        self.global_world_size = torch.distributed.get_world_size()
        self._iterator = None
Example #9
    def __len__(self):
        # Since this is an iterator, the length is the number of batches,
        # not the number of examples
        batch_size = get_batch_size()
        # The length accommodates drop_last == True on XLA: drop_last is
        # required when a batch is split across multiple cores, as some
        # cores may otherwise not have enough examples
        if is_xla():
            logging.info(
                "drop_last is set to True to avoid uneven dimension shapes "
                "across cores."
            )
            return self._total_length // batch_size
        else:
            # This assumes drop_last=False for all loaders. See also
            # build_dataloader_and_sampler().
            return (self._total_length + batch_size - 1) // batch_size
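
A quick check of the two branches above: with _total_length = 10 and batch_size = 4, the XLA (drop_last=True) branch reports 10 // 4 = 2 full batches, while the default (drop_last=False) branch reports (10 + 4 - 1) // 4 = 3, counting the final partial batch.

# Plain-Python sanity check of both length formulas
total_length, batch_size = 10, 4
assert total_length // batch_size == 2                     # drop_last=True: partial batch dropped
assert (total_length + batch_size - 1) // batch_size == 3  # drop_last=False: ceiling division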
Example #10
    def test_exit_on_nan_losses(self, a):
        config = self._get_config(max_updates=2, max_epochs=None, batch_size=4)
        trainer = TrainerTrainingLoopMock(config=config)
        add_model(trainer, SimpleNaNLossModel({"in_dim": 1}))
        add_optimizer(trainer, config)
        registry.register("config", trainer.config)
        batch_size = get_batch_size()
        trainer.config.training.batch_size = batch_size
        trainer.load_datasets()

        exception_raised = False
        try:
            trainer.training_loop()
        except RuntimeError:
            exception_raised = True
        self.assertTrue(exception_raised)
Example #11
    def _add_extra_args_for_dataloader(self, dataset, opts, other_args=None):
        if other_args is None:
            other_args = {}
        dataset_type = self._dataset_type

        other_args["shuffle"] = False
        if dataset_type != "test":
            other_args["shuffle"] = True

        # In distributed mode, we use DistributedSampler from PyTorch
        if torch.distributed.is_initialized():
            other_args["sampler"] = DistributedSampler(
                dataset, shuffle=other_args["shuffle"])
            # Shuffle is mutually exclusive with sampler, let DistributedSampler
            # take care of shuffle and pop from main args
            other_args.pop("shuffle")

        other_args["batch_size"] = get_batch_size()

        return other_args
Example #12
    def __init__(
        self,
        num_train_data,
        max_updates,
        max_epochs,
        config=None,
        optimizer=None,
        update_frequency=1,
        batch_size=1,
        batch_size_per_device=None,
        fp16=False,
        on_update_end_fn=None,
        scheduler_config=None,
        grad_clipping_config=None,
    ):
        if config is None:
            self.config = OmegaConf.create(
                {
                    "training": {
                        "detect_anomaly": False,
                        "evaluation_interval": 10000,
                        "update_frequency": update_frequency,
                        "fp16": fp16,
                        "batch_size": batch_size,
                        "batch_size_per_device": batch_size_per_device,
                    }
                }
            )
            self.training_config = self.config.training
        else:
            self.training_config = config.training
            self.config = config

        # Load batch size with custom config and cleanup
        original_config = registry.get("config")
        registry.register("config", self.config)
        batch_size = get_batch_size()
        registry.register("config", original_config)

        if max_updates is not None:
            self.training_config["max_updates"] = max_updates
        if max_epochs is not None:
            self.training_config["max_epochs"] = max_epochs

        self.model = SimpleModel({"in_dim": 1})
        self.model.build()
        if torch.cuda.is_available():
            self.model = self.model.cuda()
            self.device = "cuda"
        else:
            self.device = "cpu"
        self.distributed = False

        self.dataset_loader = MagicMock()
        self.dataset_loader.seed_sampler = MagicMock(return_value=None)
        self.dataset_loader.prepare_batch = lambda x: SampleList(x)
        if optimizer is None:
            self.optimizer = MagicMock()
            self.optimizer.step = MagicMock(return_value=None)
            self.optimizer.zero_grad = MagicMock(return_value=None)
        else:
            self.optimizer = optimizer

        if scheduler_config:
            # Guard against the mock not having a callbacks list defined elsewhere
            if not hasattr(self, "callbacks"):
                self.callbacks = []
            config.training.lr_scheduler = True
            config.scheduler = scheduler_config
            self.lr_scheduler_callback = LRSchedulerCallback(config, self)
            self.callbacks.append(self.lr_scheduler_callback)
            on_update_end_fn = (
                on_update_end_fn
                if on_update_end_fn
                else self.lr_scheduler_callback.on_update_end
            )

        if grad_clipping_config:
            self.training_config.clip_gradients = True
            self.training_config.max_grad_l2_norm = grad_clipping_config[
                "max_grad_l2_norm"
            ]
            self.training_config.clip_norm_mode = grad_clipping_config["clip_norm_mode"]

        dataset = NumbersDataset(num_train_data)
        self.train_loader = torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=1,
            drop_last=False,
        )
        self.train_loader.current_dataset = dataset
        self.on_batch_start = MagicMock(return_value=None)
        self.on_update_start = MagicMock(return_value=None)
        self.logistics_callback = MagicMock(return_value=None)
        self.logistics_callback.log_interval = MagicMock(return_value=None)
        self.on_batch_end = MagicMock(return_value=None)
        self.on_update_end = (
            on_update_end_fn if on_update_end_fn else MagicMock(return_value=None)
        )
        self.meter = Meter()
        self.after_training_loop = MagicMock(return_value=None)
        self.on_validation_start = MagicMock(return_value=None)
        self.evaluation_loop = MagicMock(return_value=(None, None))
        self.scaler = torch.cuda.amp.GradScaler(enabled=False)
        self.val_loader = MagicMock(return_value=None)
        self.early_stop_callback = MagicMock(return_value=None)
        self.on_validation_end = MagicMock(return_value=None)
        self.metrics = MagicMock(return_value=None)
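
The save/register/restore pattern around get_batch_size() above (and the bare registry.register("config", ...) calls in Examples #1 and #10) could be factored into a small helper; the following is a hypothetical sketch, not part of MMF:

from contextlib import contextmanager

@contextmanager
def temporarily_registered_config(config):
    # Hypothetical helper: register a config, yield, then restore the previous one
    original_config = registry.get("config")
    registry.register("config", config)
    try:
        yield
    finally:
        registry.register("config", original_config)

Inside the mock's __init__, the three lines around get_batch_size() would then reduce to a single with temporarily_registered_config(self.config): block.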
Example #13
    def __len__(self):
        # Since this is an iterator, the length is the number of batches,
        # not the number of examples
        return self._total_length // get_batch_size()
Example #14
    def __len__(self):
        # Since this is an iterator, the length is the number of batches,
        # not the number of examples
        batch_size = get_batch_size()
        # This assumes drop_last=False for all loaders. See also
        # build_dataloader_and_sampler().
        return (self._total_length + batch_size - 1) // batch_size