Example #1
    def train(self):
        self.writer.write("===== Model =====")
        self.writer.write(self.model)

        print_model_parameters(self.model)

        if "train" not in self.run_type:
            self.inference()
            return

        should_break = False

        if self.max_epochs is None:
            self.max_epochs = math.inf
        else:
            self.max_updates = math.inf

        self.model.train()
        self.train_timer = Timer()
        self.snapshot_timer = Timer()

        self.profile("Setup Time")

        torch.autograd.set_detect_anomaly(True)
        self.writer.write("Starting training...")

        while self.num_updates < self.max_updates and not should_break:
            self.current_epoch += 1
            registry.register("current_epoch", self.current_epoch)

            # Seed the sampler in case it is distributed
            self.dataset_loader.seed_sampler("train", self.current_epoch)

            if self.current_epoch > self.max_epochs:
                break

            for batch in self.train_loader:
                self.profile("Batch load time")
                self.current_iteration += 1
                self.writer.write(self.num_updates + 1, "debug")

                report = self._forward_pass(batch)
                loss = self._extract_loss(report)
                self._backward(loss)
                should_break = self._logistics(report)

                if self.num_updates > self.max_updates:
                    should_break = True

                if should_break:
                    break

            # In distributed training, each worker will have completed one epoch when
            # we reach this point, as each worker is an individual instance
            self.current_epoch += get_world_size() - 1
        self.finalize()
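Note: the `self.current_epoch += get_world_size() - 1` line above credits the epochs run in parallel by the other workers. A rough sketch of that bookkeeping, using a made-up world size:

world_size = 4                        # hypothetical number of distributed workers

current_epoch = 0
for _ in range(2):                    # two passes over the local train loader
    current_epoch += 1                # this worker finished one local epoch
    current_epoch += world_size - 1   # credit the epochs the other workers ran in parallel

print(current_epoch)                  # 8 epochs counted globally after 2 local passes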
Example #2
def get_batch_size():
    from mmf.utils.configuration import get_global_config

    batch_size = get_global_config("training.batch_size")

    world_size = get_world_size()

    if batch_size % world_size != 0:
        raise RuntimeError("Batch size {} must be divisible by number "
                           "of GPUs {} used.".format(batch_size, world_size))

    return batch_size // world_size
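Note: a quick sketch (not from MMF) of the arithmetic above, with made-up numbers:

global_batch_size = 128   # training.batch_size from the config
world_size = 4            # number of GPUs / processes

assert global_batch_size % world_size == 0   # otherwise get_batch_size raises RuntimeError
per_gpu_batch_size = global_batch_size // world_size
print(per_gpu_batch_size)                    # 32 samples per GPU, 128 per optimizer step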
Example #3
    def run_training_epoch(self) -> None:
        should_break = False
        while self.num_updates < self.max_updates and not should_break:
            self.current_epoch += 1
            registry.register("current_epoch", self.current_epoch)

            # Seed the sampler in case it is distributed
            self.dataset_loader.seed_sampler("train", self.current_epoch)

            if self.current_epoch > self.max_epochs:
                break

            for batch in self.train_loader:
                self.profile("Batch load time")
                self.current_iteration += 1
                self.writer.write(self.num_updates + 1, "debug")

                self.run_training_batch(batch)

                # Check if training should be stopped
                should_break = False

                if self.num_updates % self.training_config.evaluation_interval == 0:
                    # Validation begin callbacks
                    self.on_validation_start()

                    self.writer.write(
                        "Evaluation time. Running on full validation set...")
                    # Validation and Early stopping
                    # Create a new meter for this case
                    report, meter = self.evaluation_loop(self.val_loader)

                    # Validation end callbacks
                    stop = self.early_stop_callback.on_validation_end(
                        report=report, meter=meter)
                    self.on_validation_end(report=report, meter=meter)

                    gc.collect()

                    if "cuda" in str(self.device):
                        torch.cuda.empty_cache()

                    if stop is True:
                        self.writer.write("Early stopping activated")
                        should_break = True
                if self.num_updates > self.max_updates:
                    should_break = True
                if should_break:
                    break

            # In distributed training, each worker will have completed one epoch when
            # we reach this point, as each worker is an individual instance
            self.current_epoch += get_world_size() - 1
Example #4
    def forward(self, outputs, targets):
        """This performs the loss computation.
        Parameters:
             outputs: dict of tensors, see the output specification of the model for
                      the format
             targets: list of dicts, such that len(targets) == batch_size.
                      The expected keys in each dict depends on the losses applied, see
                      each loss' doc
        """
        outputs_without_aux = {
            k: v
            for k, v in outputs.items() if k != "aux_outputs"
        }

        # Retrieve the matching between the outputs of the last layer and the targets
        indices = self.matcher(outputs_without_aux, targets)

        # Compute the average number of target boxes across all nodes, for
        # normalization purposes
        num_boxes = sum(len(t["labels"]) for t in targets)
        num_boxes = torch.as_tensor([num_boxes],
                                    dtype=torch.float,
                                    device=next(iter(outputs.values())).device)
        if is_dist_initialized():
            torch.distributed.all_reduce(num_boxes)
        num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()

        # Compute all the requested losses
        losses = {}
        for loss in self.losses:
            losses.update(
                self.get_loss(loss, outputs, targets, indices, num_boxes))

        # In case of auxiliary losses, we repeat this process with the output of each
        # intermediate layer.
        if "aux_outputs" in outputs:
            for i, aux_outputs in enumerate(outputs["aux_outputs"]):
                indices = self.matcher(aux_outputs, targets)
                for loss in self.losses:
                    kwargs = {}
                    if loss in ("labels", "labels_balanced"):
                        # Logging is enabled only for the last layer
                        kwargs = {"log": False}
                    l_dict = self.get_loss(loss, aux_outputs, targets, indices,
                                           num_boxes, **kwargs)
                    l_dict = {k + f"_{i}": v for k, v in l_dict.items()}
                    losses.update(l_dict)

        return losses
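Note: the num_boxes normalization above averages the target-box count over all workers via all_reduce and clamps it to at least 1. A rough sketch of the same arithmetic with hypothetical per-worker counts:

per_worker_box_counts = [12, 9, 15, 0]         # hypothetical sums of len(t["labels"]) on 4 workers
world_size = len(per_worker_box_counts)

# all_reduce(SUM) leaves every worker holding the same global total
global_boxes = sum(per_worker_box_counts)      # 36

# each worker then normalizes its losses by the average count, clamped to at least 1
num_boxes = max(global_boxes / world_size, 1)  # 9.0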
Example #5
    def __len__(self) -> int:
        # Since this is an iterator, we need to return the total length, i.e. the
        # number of batches, and as get_batch_size returns the per-GPU batch size,
        # it needs to be multiplied by the world size
        batch_size = get_batch_size() * get_world_size()
        # Changed the length to accommodate drop_last == True
        # drop_last is required if the batch is split across multiple cores;
        # some of the cores may not have enough examples.
        if is_xla():
            logging.info(
                "drop_last is set to True to avoid uneven dimension shapes "
                "across cores.")
            return self._total_length // batch_size
        else:
            # This assumes drop_last=False for all loaders. See also
            # build_dataloader_and_sampler().
            return (self._total_length + batch_size - 1) // batch_size
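Note: made-up numbers to illustrate the two branches above:

total_length = 1000                       # dataset size across all workers
batch_size = 32 * 4                       # per-GPU batch size x world size = 128

# XLA branch, drop_last=True: the incomplete final batch is dropped
print(total_length // batch_size)         # 7 batches

# default branch, drop_last=False: the incomplete final batch is kept (ceiling division)
print((total_length + batch_size - 1) // batch_size)   # 8 batches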
Example #6
def get_batch_size():
    from mmf.utils.configuration import get_global_config

    batch_size = get_global_config("training.batch_size")
    world_size = get_world_size()

    batch_size_per_device = get_global_config("training.batch_size_per_device")

    if batch_size_per_device is not None:
        logger.info(
            f"training.batch_size_per_device has been used as {batch_size_per_device}. "
            "This will override training.batch_size and set the global batch size to "
            f"{batch_size_per_device} x {world_size} = "
            f"{batch_size_per_device * world_size}")
        batch_size = batch_size_per_device * world_size

    if batch_size % world_size != 0:
        raise RuntimeError("Batch size {} must be divisible by number "
                           "of GPUs {} used.".format(batch_size, world_size))

    return batch_size // world_size
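Note: with training.batch_size_per_device set, the function derives the global batch size from the per-device value instead of splitting training.batch_size. Hypothetical values:

world_size = 8
batch_size_per_device = 16

global_batch_size = batch_size_per_device * world_size   # 128, overrides training.batch_size
print(global_batch_size // world_size)                    # 16, the per-device size returned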
Example #7
    def parallelize_model(self) -> None:
        registry.register("data_parallel", False)
        registry.register("distributed", False)
        if ("cuda" in str(self.device) and torch.cuda.device_count() > 1
                and not self.distributed):
            registry.register("data_parallel", True)
            self.model = torch.nn.DataParallel(self.model)

        if "cuda" in str(self.device) and self.distributed:
            registry.register("distributed", True)
            set_torch_ddp = True
            try:
                from fairscale.nn.data_parallel import ShardedDataParallel
                from fairscale.optim.oss import OSS

                if isinstance(self.optimizer, OSS):
                    self.model = ShardedDataParallel(self.model,
                                                     self.optimizer)
                    set_torch_ddp = False
                    logger.info("Using FairScale ShardedDataParallel")
            except ImportError:
                logger.info("Using PyTorch DistributedDataParallel")
                warnings.warn(
                    "You can enable ZeRO and Sharded DDP by installing fairscale "
                    "and setting optimizer.enable_state_sharding=True.")

            if set_torch_ddp:
                self.model = torch.nn.parallel.DistributedDataParallel(
                    self.model,
                    device_ids=[self.local_rank],
                    output_device=self.local_rank,
                    find_unused_parameters=self.config.training.find_unused_parameters,
                )

        if is_xla() and get_world_size() > 1:
            broadcast_xla_master_model_param(self.model)
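Note: a minimal, self-contained sketch (not MMF code) of the single-node DataParallel branch above; it wraps the model only when more than one GPU is visible and distributed mode is off:

import torch

model = torch.nn.Linear(16, 4)
distributed = False   # assumed flag, mirrors self.distributed

if torch.cuda.is_available() and torch.cuda.device_count() > 1 and not distributed:
    model = torch.nn.DataParallel(model)   # replicates the model across visible GPUs

out = model(torch.randn(8, 16))
print(out.shape)      # torch.Size([8, 4])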