def train(self):
    self.writer.write("===== Model =====")
    self.writer.write(self.model)
    print_model_parameters(self.model)

    if "train" not in self.run_type:
        self.inference()
        return

    should_break = False

    if self.max_epochs is None:
        self.max_epochs = math.inf
    else:
        self.max_updates = math.inf

    self.model.train()
    self.train_timer = Timer()
    self.snapshot_timer = Timer()

    self.profile("Setup Time")

    torch.autograd.set_detect_anomaly(True)
    self.writer.write("Starting training...")

    while self.num_updates < self.max_updates and not should_break:
        self.current_epoch += 1
        registry.register("current_epoch", self.current_epoch)

        # Seed the sampler in case it is distributed
        self.dataset_loader.seed_sampler("train", self.current_epoch)

        if self.current_epoch > self.max_epochs:
            break

        for batch in self.train_loader:
            self.profile("Batch load time")
            self.current_iteration += 1
            self.writer.write(self.num_updates + 1, "debug")

            report = self._forward_pass(batch)
            loss = self._extract_loss(report)
            self._backward(loss)
            should_break = self._logistics(report)

            if self.num_updates > self.max_updates:
                should_break = True

            if should_break:
                break

        # In distributed, each worker will complete one epoch when we reach this
        # as each worker is an individual instance
        self.current_epoch += get_world_size() - 1

    self.finalize()
def get_batch_size():
    from mmf.utils.configuration import get_global_config

    batch_size = get_global_config("training.batch_size")
    world_size = get_world_size()

    if batch_size % world_size != 0:
        raise RuntimeError(
            "Batch size {} must be divisible by number "
            "of GPUs {} used.".format(batch_size, world_size)
        )

    return batch_size // world_size
def run_training_epoch(self) -> None:
    should_break = False
    while self.num_updates < self.max_updates and not should_break:
        self.current_epoch += 1
        registry.register("current_epoch", self.current_epoch)

        # Seed the sampler in case it is distributed
        self.dataset_loader.seed_sampler("train", self.current_epoch)

        if self.current_epoch > self.max_epochs:
            break

        for batch in self.train_loader:
            self.profile("Batch load time")
            self.current_iteration += 1
            self.writer.write(self.num_updates + 1, "debug")

            self.run_training_batch(batch)

            # Check if training should be stopped
            should_break = False

            if self.num_updates % self.training_config.evaluation_interval == 0:
                # Validation begin callbacks
                self.on_validation_start()

                self.writer.write(
                    "Evaluation time. Running on full validation set..."
                )
                # Validation and Early stopping
                # Create a new meter for this case
                report, meter = self.evaluation_loop(self.val_loader)

                # Validation end callbacks
                stop = self.early_stop_callback.on_validation_end(
                    report=report, meter=meter
                )
                self.on_validation_end(report=report, meter=meter)

                gc.collect()

                if "cuda" in str(self.device):
                    torch.cuda.empty_cache()

                if stop is True:
                    self.writer.write("Early stopping activated")
                    should_break = True

            if self.num_updates > self.max_updates:
                should_break = True

            if should_break:
                break

        # In distributed, each worker will complete one epoch when we reach this
        # as each worker is an individual instance
        self.current_epoch += get_world_size() - 1
def forward(self, outputs, targets):
    """This performs the loss computation.

    Parameters:
        outputs: dict of tensors, see the output specification of the model
            for the format
        targets: list of dicts, such that len(targets) == batch_size. The
            expected keys in each dict depend on the losses applied, see each
            loss' doc
    """
    outputs_without_aux = {k: v for k, v in outputs.items() if k != "aux_outputs"}

    # Retrieve the matching between the outputs of the last layer and the targets
    indices = self.matcher(outputs_without_aux, targets)

    # Compute the average number of target boxes across all nodes, for
    # normalization purposes
    num_boxes = sum(len(t["labels"]) for t in targets)
    num_boxes = torch.as_tensor(
        [num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device
    )
    if is_dist_initialized():
        torch.distributed.all_reduce(num_boxes)
    num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()

    # Compute all the requested losses
    losses = {}
    for loss in self.losses:
        losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))

    # In case of auxiliary losses, we repeat this process with the output of each
    # intermediate layer.
    if "aux_outputs" in outputs:
        for i, aux_outputs in enumerate(outputs["aux_outputs"]):
            indices = self.matcher(aux_outputs, targets)
            for loss in self.losses:
                kwargs = {}
                if loss in ("labels", "labels_balanced"):
                    # Logging is enabled only for the last layer
                    kwargs = {"log": False}
                l_dict = self.get_loss(
                    loss, aux_outputs, targets, indices, num_boxes, **kwargs
                )
                l_dict = {k + f"_{i}": v for k, v in l_dict.items()}
                losses.update(l_dict)

    return losses
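# A toy, self-contained sketch (hypothetical values, not part of the criterion
# above) of how the auxiliary-loss dictionaries are merged: each intermediate
# decoder layer i contributes the same loss keys, suffixed with "_{i}" so they
# do not collide with the last layer's losses.
aux_layer_losses = [
    {"loss_ce": 0.9, "loss_bbox": 0.5},  # layer 0 (made-up numbers)
    {"loss_ce": 0.7, "loss_bbox": 0.4},  # layer 1 (made-up numbers)
]

losses = {"loss_ce": 0.6, "loss_bbox": 0.3}  # losses from the last layer
for i, l_dict in enumerate(aux_layer_losses):
    losses.update({k + f"_{i}": v for k, v in l_dict.items()})

# losses now holds loss_ce, loss_bbox, loss_ce_0, loss_bbox_0, loss_ce_1, loss_bbox_1
assert set(losses) == {
    "loss_ce", "loss_bbox",
    "loss_ce_0", "loss_bbox_0",
    "loss_ce_1", "loss_bbox_1",
}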
def __len__(self) -> int:
    # Since this is an iterator, we need to return the total length == number
    # of batches, and as get_batch_size returns the per-GPU batch size, it
    # needs to be multiplied by world size
    batch_size = get_batch_size() * get_world_size()

    # Changed the length to accommodate drop_last == True
    # drop_last is required if the batch is split across multiple cores;
    # some of the cores may not have enough examples.
    if is_xla():
        logging.info(
            "drop_last is set to True to avoid uneven dimension shapes "
            "across cores."
        )
        return self._total_length // batch_size
    else:
        # This assumes drop_last=False for all loaders. See also
        # build_dataloader_and_sampler().
        return (self._total_length + batch_size - 1) // batch_size
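# A minimal sketch (separate from the class above, with made-up numbers)
# contrasting the two length formulas used in __len__: floor division when
# drop_last is in effect (the XLA branch) versus ceiling division when the
# final, smaller batch is kept.
import math

total_length = 1000  # hypothetical dataset size
batch_size = 96      # hypothetical global batch size (per-GPU size x world size)

with_drop_last = total_length // batch_size                        # 10 batches
without_drop_last = (total_length + batch_size - 1) // batch_size  # 11 batches

assert without_drop_last == math.ceil(total_length / batch_size)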
def get_batch_size():
    from mmf.utils.configuration import get_global_config

    batch_size = get_global_config("training.batch_size")
    world_size = get_world_size()

    batch_size_per_device = get_global_config("training.batch_size_per_device")

    if batch_size_per_device is not None:
        logger.info(
            f"training.batch_size_per_device has been used as {batch_size_per_device}. "
            + "This will override training.batch_size and set the global batch size to "
            + f"{batch_size_per_device} x {world_size} = "
            + f"{batch_size_per_device * world_size}"
        )
        batch_size = batch_size_per_device * world_size

    if batch_size % world_size != 0:
        raise RuntimeError(
            "Batch size {} must be divisible by number "
            "of GPUs {} used.".format(batch_size, world_size)
        )

    return batch_size // world_size
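# A small hedged sketch of the arithmetic get_batch_size performs, decoupled
# from the MMF config objects. The function name and values below are
# hypothetical, for illustration only.
def per_device_batch_size(global_batch_size, world_size, batch_size_per_device=None):
    # When a per-device size is given, it overrides the global size.
    if batch_size_per_device is not None:
        global_batch_size = batch_size_per_device * world_size
    if global_batch_size % world_size != 0:
        raise RuntimeError(
            f"Batch size {global_batch_size} must be divisible by number "
            f"of GPUs {world_size} used."
        )
    return global_batch_size // world_size

assert per_device_batch_size(512, 8) == 64
assert per_device_batch_size(512, 8, batch_size_per_device=32) == 32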
def parallelize_model(self) -> None:
    registry.register("data_parallel", False)
    registry.register("distributed", False)

    if (
        "cuda" in str(self.device)
        and torch.cuda.device_count() > 1
        and not self.distributed
    ):
        registry.register("data_parallel", True)
        self.model = torch.nn.DataParallel(self.model)

    if "cuda" in str(self.device) and self.distributed:
        registry.register("distributed", True)
        set_torch_ddp = True
        try:
            from fairscale.nn.data_parallel import ShardedDataParallel
            from fairscale.optim.oss import OSS

            if isinstance(self.optimizer, OSS):
                self.model = ShardedDataParallel(self.model, self.optimizer)
                set_torch_ddp = False
                logger.info("Using FairScale ShardedDataParallel")
        except ImportError:
            logger.info("Using PyTorch DistributedDataParallel")
            warnings.warn(
                "You can enable ZeRO and Sharded DDP by installing fairscale "
                + "and setting optimizer.enable_state_sharding=True."
            )

        if set_torch_ddp:
            self.model = torch.nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[self.local_rank],
                output_device=self.local_rank,
                find_unused_parameters=self.config.training.find_unused_parameters,
            )

    if is_xla() and get_world_size() > 1:
        broadcast_xla_master_model_param(self.model)