Example #1
def build_dataloader_and_sampler(
    dataset_instance: torch.utils.data.Dataset, datamodule_config: DictConfig
) -> Tuple[torch.utils.data.DataLoader, Optional[torch.utils.data.Sampler]]:
    """Builds and returns a dataloader along with its sample

    Args:
        dataset_instance (torch.utils.data.Dataset): Instance of dataset for which
            dataloader has to be created
        datamodule_config (omegaconf.DictConfig): Datamodule configuration; required
            for inferring dataloader params

    Returns:
        Tuple[torch.utils.data.DataLoader, Optional[torch.utils.data.Sampler]]:
            Tuple of Dataloader and Sampler instance
    """
    from mmf.common.batch_collator import BatchCollator

    training_config = get_global_config("training")
    # Pick up dataloader params from the datamodule config, falling back
    # to the training config
    other_args = {
        "num_workers": datamodule_config.get(
            "num_workers", training_config.get("num_workers", 4)
        ),
        "pin_memory": datamodule_config.get(
            "pin_memory", training_config.get("pin_memory", False)
        ),
        "shuffle": datamodule_config.get("shuffle", None),
        "batch_size": datamodule_config.get("batch_size", None),
    }

    # An IterableDataset returns batches directly, so there is no need to add a
    # Sampler or a batch size; the user is expected to control those. For now we
    # don't support single-item IterableDatasets, as that would add unnecessary
    # complexity and config parameters to the codebase
    if not isinstance(dataset_instance, torch.utils.data.IterableDataset):
        other_args = _add_extra_args_for_dataloader(dataset_instance, other_args)
    else:
        other_args.pop("shuffle")

    loader = torch.utils.data.DataLoader(
        dataset=dataset_instance,
        collate_fn=BatchCollator(
            dataset_instance.dataset_name, dataset_instance.dataset_type
        ),
        drop_last=is_xla(),  # see also MultiDatasetLoader.__len__
        **other_args,
    )

    if is_xla():
        device = xm.xla_device()
        loader = xla_pl.MpDeviceLoader(loader, device)

    if other_args["num_workers"] >= 0:
        # Suppress leaking semaphore warning
        os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning"

    loader.dataset_type = dataset_instance.dataset_type

    return loader, other_args.get("sampler", None)
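One common reason to return the sampler separately from the loader is so a training loop can reseed a DistributedSampler's shuffling each epoch. Below is a minimal, self-contained sketch of that pattern (not MMF code); it uses a plain TensorDataset and a local DataLoader so it runs without a distributed setup.

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(16).float())
# In MMF the loader/sampler pair would come from build_dataloader_and_sampler();
# here we build a plain loader so the sketch runs anywhere.
loader = DataLoader(dataset, batch_size=4, shuffle=True)
sampler = getattr(loader, "sampler", None)

for epoch in range(3):
    # DistributedSampler exposes set_epoch() to vary its shuffling seed per
    # epoch; plain samplers simply skip this branch.
    if hasattr(sampler, "set_epoch"):
        sampler.set_epoch(epoch)
    for (batch,) in loader:
        pass  # training step would go here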
Example #2
def summarize_report(
    current_iteration,
    num_updates,
    max_updates,
    meter,
    should_print=True,
    extra=None,
    tb_writer=None,
):
    if extra is None:
        extra = {}
    if not is_master() and not is_xla():
        return

    if tb_writer:
        scalar_dict = meter.get_scalar_dict()
        tb_writer.add_scalars(scalar_dict, current_iteration)

    if not should_print:
        return
    log_dict = {}
    if num_updates is not None and max_updates is not None:
        log_dict.update({"progress": f"{num_updates}/{max_updates}"})

    log_dict.update(meter.get_log_dict())
    log_dict.update(extra)

    log_progress(log_dict)
Example #3
def _add_extra_args_for_dataloader(
    dataset_instance: torch.utils.data.Dataset,
    other_args: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    from mmf.utils.general import get_batch_size

    dataset_type = dataset_instance.dataset_type

    if other_args["shuffle"] is None:
        other_args["shuffle"] = False
        if dataset_type != "test":
            other_args["shuffle"] = True

    # In distributed mode, we use DistributedSampler from PyTorch
    if is_dist_initialized():
        other_args["sampler"] = torch.utils.data.DistributedSampler(
            dataset_instance, shuffle=other_args["shuffle"]
        )
        # Shuffle is mutually exclusive with sampler, let DistributedSampler
        # take care of shuffle and pop from main args
        other_args.pop("shuffle")

    if is_xla():
        other_args["sampler"] = torch.utils.data.DistributedSampler(
            dataset_instance,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=other_args["shuffle"],
            drop_last=True,
        )
        other_args.pop("shuffle")

    if other_args["batch_size"] is None:
        other_args["batch_size"] = get_batch_size()

    return other_args
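The shuffle flag is popped above because PyTorch's DataLoader treats sampler and shuffle=True as mutually exclusive. A tiny standalone check (not MMF code) of that behaviour:

import torch
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset

ds = TensorDataset(torch.arange(8).float())
sampler = SequentialSampler(ds)

try:
    DataLoader(ds, batch_size=2, sampler=sampler, shuffle=True)
except ValueError as err:
    print("DataLoader rejected sampler + shuffle:", err)

# Passing only the sampler (shuffle removed, as done above) is accepted.
loader = DataLoader(ds, batch_size=2, sampler=sampler)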
Example #4
def get_current_device():
    if is_xla():
        import torch_xla.core.xla_model as xm

        return xm.xla_device()
    if torch.cuda.is_available():
        return f"cuda:{torch.cuda.current_device()}"
    else:
        return torch.device("cpu")
Example #5
    def __len__(self):
        # Since this is an iterator, we need to return the total length, i.e.
        # the number of batches
        batch_size = get_batch_size()
        # The length is changed to accommodate drop_last == True.
        # drop_last is required because, when a batch is split across multiple
        # cores, some of the cores may not have enough examples.
        if is_xla():
            logging.info(
                "drop_last is set to True to avoid uneven dimension shapes "
                "across cores."
            )
            return self._total_length // batch_size
        else:
            # This assumes drop_last=False for all loaders. See also
            # build_dataloader_and_sampler().
            return (self._total_length + batch_size - 1) // batch_size
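The two return branches above differ only in how the trailing partial batch is counted: with drop_last=True (the XLA path) it is dropped, with drop_last=False it still counts as a batch. A standalone numeric check of both formulas:

total_length, batch_size = 10, 4

num_batches_drop_last = total_length // batch_size                     # 2 full batches
num_batches_keep_last = (total_length + batch_size - 1) // batch_size  # 3 batches

assert num_batches_drop_last == 2
assert num_batches_keep_last == 3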
Example #6
    def __call__(self, update, iteration, meter):
        """
        Method to be called everytime you need to check whether to
        early stop or not
        Arguments:
            update {number}: Current update number
            iteration {number}: Current iteration number
        Returns:
            bool -- Tells whether early stopping occurred or not
        """
        # There are operations involving synchronization downstream.
        # For XLA those calls must be executed from all cores,
        # therefore we do not return early here in the XLA case.
        if not is_master() and not is_xla():
            return False

        value = meter.meters.get(self.early_stop_criteria, None)
        if value is None:
            raise ValueError("Criteria used for early stopping ({}) is not "
                             "present in meter.".format(
                                 self.early_stop_criteria))

        value = value.global_avg

        if isinstance(value, torch.Tensor):
            value = value.item()

        if (self.minimize and value < self.best_monitored_value) or (
                not self.minimize and value > self.best_monitored_value):
            self.best_monitored_value = value
            self.best_monitored_iteration = iteration
            self.best_monitored_update = update
            self.checkpoint.save(update, iteration, update_best=True)

        elif self.best_monitored_update + self.patience < update:
            self.activated = True
            if self.should_stop is True:
                self.checkpoint.restore()
                self.checkpoint.finalize()
                return True
            else:
                return False
        else:
            self.checkpoint.save(update, iteration, update_best=False)

        return False
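The early-stopping logic above boils down to tracking the best monitored value and stopping once `patience` updates have passed without improvement. A stripped-down, standalone sketch of just that logic (omitting the checkpointing and distributed handling of the MMF class):

def should_early_stop(history, patience, minimize=True):
    best_value, best_update = None, 0
    for update, value in enumerate(history):
        improved = best_value is None or (
            value < best_value if minimize else value > best_value
        )
        if improved:
            best_value, best_update = value, update
        elif best_update + patience < update:
            return True
    return False

# Loss stops improving after update 1; with patience=3 we stop at update 5.
assert should_early_stop([0.9, 0.8, 0.8, 0.8, 0.8, 0.8], patience=3)
# Loss keeps improving, so no early stop.
assert not should_early_stop([0.9, 0.8, 0.7, 0.6], patience=3)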
Example #7
    def _check_nan_losses(self, report):
        # skip this check in XLA mode as calling .item() in forward pass
        # greatly slows down the training
        if not is_xla():
            # check whether NaN has occurred in the losses, and exit the training
            # when NaN happens
            loss_dict = report.losses
            nan_loss_keys = []
            for key, value in loss_dict.items():
                if torch.any(torch.isnan(value)).item():
                    nan_loss_keys.append(key)
            if len(nan_loss_keys) > 0:
                keys_str = ", ".join(nan_loss_keys)
                error_msg = (
                    f"NaN occurred in the following loss(es): {keys_str}; "
                    f"exiting the training"
                )
                logger.info(error_msg)
                raise RuntimeError(error_msg)
Example #8
    def _finish_update(self):
        if self.training_config.clip_gradients:
            clip_gradients(
                self.model,
                self.num_updates,
                self.logistics_callback.tb_writer,
                self.config,
                scale=self.scaler.get_scale(),
            )
        if is_xla():
            import torch_xla.core.xla_model as xm

            # Assumes no model parallel
            xm.reduce_gradients(self.optimizer)

        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.num_updates += 1
        self.profile("Finished update")
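The scaler.step / scaler.update pair above is the standard torch.cuda.amp pattern. A self-contained sketch of that pattern (not MMF code; on CPU it degrades to plain FP32 because the scaler and autocast are disabled there):

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

x, y = torch.randn(8, 4), torch.randn(8, 2)
with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
    loss = torch.nn.functional.mse_loss(model(x), y)

scaler.scale(loss).backward()
scaler.step(optimizer)  # unscales gradients, then steps if they are finite
scaler.update()
optimizer.zero_grad()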
Example #9
def validate_batch_sizes(my_batch_size: int) -> bool:
    """
    Validates that all workers got the same batch size.
    """

    # skip batch size validation on XLA (as there's too much overhead
    # and data loader automatically drops the last batch in XLA mode)
    if is_xla():
        return True

    batch_size_tensor = torch.IntTensor([my_batch_size])
    if torch.cuda.is_available():
        batch_size_tensor = batch_size_tensor.cuda()
    all_batch_sizes = gather_tensor(batch_size_tensor)
    for j, oth_batch_size in enumerate(all_batch_sizes.data):
        if oth_batch_size != my_batch_size:
            logger.error(f"Node {j} batch {oth_batch_size} != {my_batch_size}")
            return False
    return True
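The comparison loop above simply checks every gathered batch size against our own. A single-process sketch of that check, using a plain tensor to stand in for what gather_tensor would return across ranks (the values are made up for illustration):

import torch

my_batch_size = 32
all_batch_sizes = torch.tensor([32, 32, 16, 32])  # pretend rank 2 got a short batch

mismatched = [
    (rank, int(size))
    for rank, size in enumerate(all_batch_sizes)
    if int(size) != my_batch_size
]
print(mismatched)  # [(2, 16)] -> validate_batch_sizes would log an error and return False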
Example #10
def summarize_report(
    current_iteration,
    num_updates,
    max_updates,
    meter,
    should_print=True,
    extra=None,
    tb_writer=None,
    wandb_logger=None,
):
    if extra is None:
        extra = {}
    if not is_main() and not is_xla():
        return

    # Log the learning rate if available
    if wandb_logger and "lr" in extra:
        wandb_logger.log_metrics({"train/learning_rate": float(extra["lr"])},
                                 commit=False)

    if tb_writer:
        scalar_dict = meter.get_scalar_dict()
        tb_writer.add_scalars(scalar_dict, current_iteration)

    if wandb_logger:
        metrics = meter.get_scalar_dict()
        wandb_logger.log_metrics(
            {**metrics, "trainer/global_step": current_iteration}
        )

    if not should_print:
        return
    log_dict = {}
    if num_updates is not None and max_updates is not None:
        log_dict.update({"progress": f"{num_updates}/{max_updates}"})

    log_dict.update(meter.get_log_dict())
    log_dict.update(extra)

    log_progress(log_dict)
Example #11
    def _summarize_report(self, meter, should_print=True, extra=None):
        if extra is None:
            extra = {}
        if not is_master() and not is_xla():
            return

        if self.training_config.tensorboard:
            scalar_dict = meter.get_scalar_dict()
            self.tb_writer.add_scalars(scalar_dict, self.trainer.current_iteration)

        if not should_print:
            return
        log_dict = {}
        if hasattr(self.trainer, "num_updates") and hasattr(
            self.trainer, "max_updates"
        ):
            log_dict.update(
                {"progress": f"{self.trainer.num_updates}/{self.trainer.max_updates}"}
            )
        log_dict.update(meter.get_log_dict())
        log_dict.update(extra)

        log_progress(log_dict)
Example #12
    def parallelize_model(self) -> None:
        registry.register("data_parallel", False)
        registry.register("distributed", False)
        if ("cuda" in str(self.device) and torch.cuda.device_count() > 1
                and not self.distributed):
            registry.register("data_parallel", True)
            self.model = torch.nn.DataParallel(self.model)

        if "cuda" in str(self.device) and self.distributed:
            registry.register("distributed", True)
            set_torch_ddp = True
            try:
                from fairscale.nn.data_parallel import ShardedDataParallel
                from fairscale.optim.oss import OSS

                if isinstance(self.optimizer, OSS):
                    self.model = ShardedDataParallel(self.model,
                                                     self.optimizer)
                    set_torch_ddp = False
                    logger.info("Using FairScale ShardedDataParallel")
            except ImportError:
                logger.info("Using PyTorch DistributedDataParallel")
                warnings.warn(
                    "You can enable ZeRO and Sharded DDP by installing fairscale "
                    "and setting optimizer.enable_state_sharding=True."
                )

            if set_torch_ddp:
                self.model = torch.nn.parallel.DistributedDataParallel(
                    self.model,
                    device_ids=[self.local_rank],
                    output_device=self.local_rank,
                    find_unused_parameters=self.config.training.find_unused_parameters,
                )

        if is_xla() and get_world_size() > 1:
            broadcast_xla_master_model_param(self.model)
Example #13
def run(opts: typing.Optional[typing.List[str]] = None, predict: bool = False):
    """Run starts a job based on the command passed from the command line.
    You can optionally run the mmf job programmatically by passing an optlist as opts.

    Args:
        opts (typing.Optional[typing.List[str]], optional): Optlist which can be used
            to override opts programmatically. For example, if you pass
            opts = ["training.batch_size=64", "checkpoint.resume=True"], this will
            set the batch size to 64 and resume from the checkpoint if present.
            Defaults to None.
        predict (bool, optional): If predict is passed True, then the program runs in
            prediction mode. Defaults to False.
    """
    setup_imports()

    if opts is None:
        parser = flags.get_parser()
        args = parser.parse_args()
    else:
        args = argparse.Namespace(config_override=None)
        args.opts = opts

    configuration = Configuration(args)
    # Set runtime args which can be changed by MMF
    configuration.args = args
    config = configuration.get_config()
    config.start_rank = 0
    if config.distributed.init_method is None:
        infer_init_method(config)

    if config.distributed.init_method is not None:
        if torch.cuda.device_count() > 1 and not config.distributed.no_spawn:
            config.start_rank = config.distributed.rank
            config.distributed.rank = None
            torch.multiprocessing.spawn(
                fn=distributed_main,
                args=(configuration, predict),
                nprocs=torch.cuda.device_count(),
            )
        else:
            distributed_main(0, configuration, predict)
    elif config.distributed.world_size > 1:
        if is_xla():
            import torch_xla.distributed.xla_multiprocessing as xmp

            torch.multiprocessing.set_sharing_strategy("file_system")
            xmp.spawn(
                fn=distributed_main,
                args=(configuration, predict),
                nprocs=8,  # use all 8 TPU cores
                start_method="fork",
            )
        else:
            assert config.distributed.world_size <= torch.cuda.device_count()
            port = random.randint(10000, 20000)
            config.distributed.init_method = f"tcp://localhost:{port}"
            config.distributed.rank = None
            torch.multiprocessing.spawn(
                fn=distributed_main,
                args=(configuration, predict),
                nprocs=config.distributed.world_size,
            )
    else:
        config.device_id = 0
        main(configuration, predict=predict)
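As the docstring notes, run() can also be invoked programmatically with an optlist instead of CLI arguments. A hedged usage sketch using only the two overrides from the docstring example; a real job would additionally need the usual config/model/dataset options for a concrete project:

run(opts=["training.batch_size=64", "checkpoint.resume=True"])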
Example #14
def build_dataloader_and_sampler(
    dataset_instance: torch.utils.data.Dataset, datamodule_config: DictConfig
) -> Tuple[torch.utils.data.DataLoader, Optional[torch.utils.data.Sampler]]:
    """Builds and returns a dataloader along with its sample

    Args:
        dataset_instance (torch.utils.data.Dataset): Instance of dataset for which
            dataloader has to be created
        datamodule_config (omegaconf.DictConfig): Datamodule configuration; required
            for inferring dataloader params

    Returns:
        Tuple[torch.utils.data.DataLoader, Optional[torch.utils.data.Sampler]]:
            Tuple of Dataloader and Sampler instance
    """
    from mmf.common.batch_collator import BatchCollator

    training_config = get_global_config("training")
    # Pick up dataloader params from the datamodule config, falling back
    # to the training config
    other_args = {
        "num_workers": datamodule_config.get(
            "num_workers", training_config.get("num_workers", 4)
        ),
        "pin_memory": datamodule_config.get(
            "pin_memory", training_config.get("pin_memory", False)
        ),
        "shuffle": datamodule_config.get("shuffle", None),
        "batch_size": datamodule_config.get("batch_size", None),
    }
    if version.parse(torch.__version__) >= version.parse("1.8"):
        # only use persistent workers in PyTorch 1.8 or higher
        # (PyTorch 1.7 also has this option but doesn't support it correctly due to
        # https://github.com/pytorch/pytorch/issues/48370)
        other_args["persistent_workers"] = (
            datamodule_config.get(
                "persistent_workers", training_config.get("persistent_workers", True)
            ),
        )
        if other_args["persistent_workers"] and other_args["num_workers"] == 0:
            logger.warning(
                "persistent_workers cannot be used together with num_workers == 0; "
                "setting persistent_workers to False"
            )
            other_args["persistent_workers"] = False

    # An IterableDataset returns batches directly, so there is no need to add a
    # Sampler or a batch size; the user is expected to control those. For now we
    # don't support single-item IterableDatasets, as that would add unnecessary
    # complexity and config parameters to the codebase
    if not isinstance(dataset_instance, torch.utils.data.IterableDataset):
        other_args = _add_extra_args_for_dataloader(dataset_instance, other_args)
    else:
        other_args.pop("shuffle")

    # Set drop_last=True when using XLA to have constant batch size.
    # In this case we also need to set drop_last=True in DistributedSampler.
    loader = torch.utils.data.DataLoader(
        dataset=dataset_instance,
        collate_fn=BatchCollator(
            dataset_instance.dataset_name, dataset_instance.dataset_type
        ),
        drop_last=is_xla(),  # see also MultiDatasetLoader.__len__
        **other_args,
    )

    if is_xla():
        device = xm.xla_device()
        loader = xla_pl.MpDeviceLoader(loader, device)

    if other_args["num_workers"] >= 0:
        # Suppress leaking semaphore warning
        os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning"

    loader.dataset_type = dataset_instance.dataset_type

    return loader, other_args.get("sampler", None)
Example #15
    def save_func(self, *args):
        return save_xla_ckpt(*args) if is_xla() else torch.save(*args)
Example #16
    def finalize(self):
        if is_master() or is_xla():
            with PathManager.open(self.pth_filepath, "wb") as f:
                self.save_func(self.trainer.model.state_dict(), f)
Example #17
    def save(self, update, iteration=None, update_best=False):
        # Only save in the main process.
        # For XLA we use the xm.save method, which ensures that the actual
        # checkpoint saving happens only on the master node and also takes
        # care of all the necessary synchronization.
        if not is_master() and not is_xla():
            return

        logger.info("Checkpoint save operation started!")
        if not iteration:
            iteration = update

        ckpt_filepath = os.path.join(self.models_foldername, "model_%d.ckpt" % update)
        best_ckpt_filepath = os.path.join(
            self.ckpt_foldername, self.ckpt_prefix + "best.ckpt"
        )
        current_ckpt_filepath = os.path.join(
            self.ckpt_foldername, self.ckpt_prefix + "current.ckpt"
        )

        best_iteration = (
            self.trainer.early_stop_callback.early_stopping.best_monitored_iteration
        )
        best_update = (
            self.trainer.early_stop_callback.early_stopping.best_monitored_update
        )
        best_metric = (
            self.trainer.early_stop_callback.early_stopping.best_monitored_value
        )
        model = self.trainer.model
        data_parallel = registry.get("data_parallel") or registry.get("distributed")
        fp16_scaler = getattr(self.trainer, "scaler", None)
        fp16_scaler_dict = None

        if fp16_scaler is not None:
            fp16_scaler_dict = fp16_scaler.state_dict()

        if data_parallel is True:
            model = model.module

        ckpt = {
            "model": model.state_dict(),
            "optimizer": self.trainer.optimizer.state_dict(),
            "best_iteration": best_iteration,
            "current_iteration": iteration,
            "current_epoch": self.trainer.current_epoch,
            "num_updates": update,
            "best_update": best_update,
            "best_metric_value": best_metric,
            "fp16_scaler": fp16_scaler_dict,
            # Convert to container to avoid any dependencies
            "config": OmegaConf.to_container(self.config, resolve=True),
        }

        lr_scheduler = self.trainer.lr_scheduler_callback._scheduler
        if lr_scheduler is not None:
            ckpt["lr_scheduler"] = lr_scheduler.state_dict()

        if self.git_repo:
            git_metadata_dict = self._get_vcs_fields()
            ckpt.update(git_metadata_dict)

        with PathManager.open(ckpt_filepath, "wb") as f:
            self.save_func(ckpt, f)

        if update_best:
            logger.info("Saving best checkpoint")
            with PathManager.open(best_ckpt_filepath, "wb") as f:
                self.save_func(ckpt, f)

        # Always save the current checkpoint
        logger.info("Saving current checkpoint")
        with PathManager.open(current_ckpt_filepath, "wb") as f:
            self.save_func(ckpt, f)

        # Remove old checkpoints if max_to_keep is set
        if self.max_to_keep > 0:
            if len(self.saved_iterations) == self.max_to_keep:
                self.remove(self.saved_iterations.pop(0))
            self.saved_iterations.append(update)

        logger.info("Checkpoint save operation finished!")
Example #18
    def save_func(self, *args):
        return xm.save(*args) if is_xla() else torch.save(*args)
Example #19
    def finalize(self):
        if is_main() or is_xla():
            with open_if_main(self.pth_filepath, "wb") as f:
                self.save_func(self.trainer.model.state_dict(), f)
Example #20
def build_dataloader_and_sampler(
    dataset_instance: mmf_typings.DatasetType,
    training_config: mmf_typings.DictConfig
) -> mmf_typings.DataLoaderAndSampler:
    """Builds and returns a dataloader along with its sample

    Args:
        dataset_instance (mmf_typings.DatasetType): Instance of dataset for which
            dataloader has to be created
        training_config (mmf_typings.DictConfig): Training configuration; required
            for inferring dataloader params

    Returns:
        mmf_typings.DataLoaderAndSampler: Tuple of Dataloader and Sampler instance
    """
    from mmf.common.batch_collator import BatchCollator

    num_workers = training_config.num_workers
    pin_memory = training_config.pin_memory

    other_args = {}

    # An IterableDataset returns batches directly, so there is no need to add a
    # Sampler or a batch size; the user is expected to control those. For now we
    # don't support single-item IterableDatasets, as that would add unnecessary
    # complexity and config parameters to the codebase
    if not isinstance(dataset_instance, torch.utils.data.IterableDataset):
        other_args = _add_extra_args_for_dataloader(dataset_instance,
                                                    other_args)

    if is_xla():
        other_args["sampler"] = torch.utils.data.DistributedSampler(
            dataset_instance,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=other_args["shuffle"],
        )
        other_args.pop("shuffle")

    loader = torch.utils.data.DataLoader(
        dataset=dataset_instance,
        pin_memory=pin_memory,
        collate_fn=BatchCollator(dataset_instance.dataset_name,
                                 dataset_instance.dataset_type),
        num_workers=num_workers,
        drop_last=False,  # see also MultiDatasetLoader.__len__
        **other_args,
    )

    if is_xla():
        device = xm.xla_device()
        loader = pl.MpDeviceLoader(loader, device)

    if num_workers >= 0:
        # Suppress leaking semaphore warning
        os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning"

    loader.dataset_type = dataset_instance.dataset_type

    return loader, other_args.get("sampler", None)