def load(self, checkpoint_path: str):
    r"""
    Load a serialized checkpoint from a path. This method will try to find
    each of :attr:`checkpointables` in the file and load its state dict.
    Since our checkpointables are held as references, this method does not
    return them.

    Parameters
    ----------
    checkpoint_path: str
        Path to a checkpoint serialized by :meth:`step`.

    Returns
    -------
    int
        Iteration corresponding to the loaded checkpoint. Useful for
        resuming training. This will be -1 in case of best checkpoint, or
        if info does not exist.
    """

    # Each process will log a message after loading checkpoint.
    rank = dist.get_rank()
    logger.info(f"Rank {rank}: Loading checkpoint from {checkpoint_path}")

    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    iteration = checkpoint.pop("iteration", -1)

    # Keep flags for all checkpointables to log which ones were not loaded.
    is_loaded = {key: False for key in self.checkpointables}

    # Load each checkpointable from checkpoint.
    for key in checkpoint:
        if key in self.checkpointables:
            logger.info(f"Rank {rank}: Loading {key} from {checkpoint_path}")
            if isinstance(
                self.checkpointables[key], nn.parallel.DistributedDataParallel
            ):
                self.checkpointables[key].module.load_state_dict(checkpoint[key])
            else:
                self.checkpointables[key].load_state_dict(checkpoint[key])

            is_loaded[key] = True
        else:
            logger.info(f"Rank {rank}: {key} not found in `checkpointables`.")

    not_loaded: List[str] = [key for key in is_loaded if not is_loaded[key]]
    if len(not_loaded) > 0:
        logger.info(f"Rank {rank}: Checkpointables not found in file: {not_loaded}")

    return iteration
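
# A minimal usage sketch for resuming training (the path here is hypothetical;
# it assumes the manager was built with the same checkpointables that `step`
# serialized):
#
#     ckpt_manager = CheckpointManager(
#         serialization_dir, model=model, optimizer=optimizer, scheduler=scheduler
#     )
#     start_iteration = ckpt_manager.load("checkpoints/checkpoint_5000.pth") + 1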
def __init__(self, cfg, weights: Union[str, Dict[str, Any]]):
    self.start_iter = 0
    self.max_iter = cfg.SOLVER.MAX_ITER
    self.cfg = cfg

    # We do not make any super call here and implement `__init__` from
    # `DefaultTrainer`: we need to initialize the mixed precision model
    # before wrapping it in DDP, so we need to do it this way.
    model = self.build_model(cfg)
    optimizer = self.build_optimizer(cfg, model)
    data_loader = self.build_train_loader(cfg)
    scheduler = self.build_lr_scheduler(cfg, optimizer)

    # Load pre-trained weights before wrapping in DDP because `ApexDDP` has
    # some weird issue with `DetectionCheckpointer`.
    # fmt: off
    if isinstance(weights, str):
        # Weights given as a ``str`` path mean ImageNet init or resumed training.
        self.start_iter = (
            DetectionCheckpointer(
                model, optimizer=optimizer, scheduler=scheduler
            ).resume_or_load(weights, resume=True).get("iteration", -1) + 1
        )
    elif isinstance(weights, dict):
        # Weights given as a state dict mean init from our pretraining.
        DetectionCheckpointer(model)._load_model(weights)
    # fmt: on

    # Enable distributed training if we have multiple GPUs. Use Apex DDP for
    # non-FPN backbones because its `delay_allreduce` functionality helps
    # with gradient checkpointing.
    if dist.get_world_size() > 1:
        if global_cfg.get("GRADIENT_CHECKPOINT", False):
            model = ApexDDP(model, delay_allreduce=True)
        else:
            model = nn.parallel.DistributedDataParallel(
                model, device_ids=[dist.get_rank()], broadcast_buffers=False
            )

    # Call `__init__` from grandparent class: `SimpleTrainer`.
    SimpleTrainer.__init__(self, model, data_loader, optimizer)

    self.scheduler = scheduler
    self.checkpointer = DetectionCheckpointer(
        model, cfg.OUTPUT_DIR, optimizer=optimizer, scheduler=self.scheduler
    )
    self.register_hooks(self.build_hooks())
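
# Illustrative instantiation (the trainer class name and path are hypothetical;
# the `train(start_iter, max_iter)` call follows detectron2's trainer convention):
#
#     # Resume or initialize from a serialized checkpoint path:
#     trainer = Trainer(cfg, weights="/path/to/checkpoint.pth")
#
#     # Or initialize the backbone from an in-memory pretrained state dict:
#     trainer = Trainer(cfg, weights=pretrained_state_dict)
#     trainer.train(trainer.start_iter, trainer.max_iter)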
def common_setup(_C: Config, _A: argparse.Namespace, job_type: str = "pretrain"):
    r"""
    Setup common stuff at the start of every pretraining or downstream
    evaluation job, all listed here to avoid code duplication. Basic steps:

    1. Fix random seeds and other PyTorch flags.
    2. Set up a serialization directory and loggers.
    3. Log important stuff such as config, process info (useful during
       distributed training).
    4. Save a copy of config to serialization directory.

    .. note::

        It is assumed that multiple processes for distributed training have
        already been launched from outside. Functions from
        :mod:`virtex.utils.distributed` module are used to get process info.

    Args:
        _C: Config object with all the parameters.
        _A: Argparse command line arguments.
        job_type: Type of job for which setup is to be done; one of
            ``{"pretrain", "downstream"}``.
    """

    # Get process rank and world size (assuming distributed is initialized).
    RANK = dist.get_rank()
    WORLD_SIZE = dist.get_world_size()

    # For reproducibility, refer:
    # https://pytorch.org/docs/stable/notes/randomness.html
    torch.manual_seed(_C.RANDOM_SEED)
    torch.backends.cudnn.deterministic = _C.CUDNN_DETERMINISTIC
    torch.backends.cudnn.benchmark = _C.CUDNN_BENCHMARK
    random.seed(_C.RANDOM_SEED)
    np.random.seed(_C.RANDOM_SEED)

    # Create serialization directory and save config in it.
    os.makedirs(_A.serialization_dir, exist_ok=True)
    _C.dump(os.path.join(_A.serialization_dir, f"{job_type}_config.yaml"))

    # Remove default logger, create a logger for each process which writes to
    # a separate log-file. This makes changes in global scope.
    logger.remove(0)
    if dist.get_world_size() > 1:
        logger.add(
            os.path.join(_A.serialization_dir, f"log-rank{RANK}.txt"),
            format="{time} {level} {message}",
        )

    # Add a logger for stdout only for the master process.
    if dist.is_master_process():
        logger.add(
            sys.stdout, format="<g>{time}</g>: <lvl>{message}</lvl>", colorize=True
        )

    # Print process info, config and args.
    logger.info(f"Rank of current process: {RANK}. World size: {WORLD_SIZE}")
    logger.info(str(_C))

    logger.info("Command line args:")
    for arg in vars(_A):
        logger.info("{:<20}: {}".format(arg, getattr(_A, arg)))
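
# Typical call site (a sketch; the argparse wiring is assumed), executed after
# the distributed processes have been launched from outside:
#
#     _A = parser.parse_args()
#     _C = Config(_A.config, _A.config_override)
#     common_setup(_C, _A, job_type="pretrain")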
def main(_A: argparse.Namespace):

    if _A.num_gpus_per_machine == 0:
        # Set device as CPU if num_gpus_per_machine = 0.
        device = torch.device("cpu")
    else:
        # Get the current device as set for current distributed process.
        # Check `launch` function in `virtex.utils.distributed` module.
        device = torch.cuda.current_device()

    # Create a downstream config object (this will be immutable) and perform
    # common setup such as logging and setting up serialization directory.
    _DOWNC = Config(_A.down_config, _A.down_config_override)
    common_setup(_DOWNC, _A, job_type="downstream")

    # Create a (pretraining) config object and back it up in the
    # serialization directory.
    _C = Config(_A.config, _A.config_override)
    _C.dump(os.path.join(_A.serialization_dir, "pretrain_config.yaml"))

    # Get dataset name for tensorboard logging.
    DATASET = _DOWNC.DATA.ROOT.split("/")[-1]

    # Set number of output classes according to dataset:
    NUM_CLASSES_MAPPING = {"imagenet": 1000, "inaturalist": 8142}
    NUM_CLASSES = NUM_CLASSES_MAPPING[DATASET]

    # -------------------------------------------------------------------------
    #   INSTANTIATE DATALOADER, MODEL, OPTIMIZER, SCHEDULER
    # -------------------------------------------------------------------------
    train_dataset = DownstreamDatasetFactory.from_config(_DOWNC, split="train")
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=_DOWNC.OPTIM.BATCH_SIZE // dist.get_world_size(),
        num_workers=_A.cpu_workers,
        sampler=DistributedSampler(
            train_dataset,
            num_replicas=dist.get_world_size(),
            rank=dist.get_rank(),
            shuffle=True,
        ),
        drop_last=False,
        pin_memory=True,
        collate_fn=train_dataset.collate_fn,
    )
    val_dataset = DownstreamDatasetFactory.from_config(_DOWNC, split="val")
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=_DOWNC.OPTIM.BATCH_SIZE // dist.get_world_size(),
        num_workers=_A.cpu_workers,
        sampler=DistributedSampler(
            val_dataset,
            num_replicas=dist.get_world_size(),
            rank=dist.get_rank(),
            shuffle=False,
        ),
        pin_memory=True,
        drop_last=False,
        collate_fn=val_dataset.collate_fn,
    )

    # Initialize model using pretraining config.
    pretrained_model = PretrainingModelFactory.from_config(_C)

    # Load weights according to the init method; do nothing for `random`, and
    # `imagenet` is already taken care of.
    if _A.weight_init == "virtex":
        CheckpointManager(model=pretrained_model).load(_A.checkpoint_path)
    elif _A.weight_init == "torchvision":
        # Keep strict=False because this state dict may have weights for the
        # last fc layer.
        pretrained_model.visual.cnn.load_state_dict(
            torch.load(_A.checkpoint_path, map_location="cpu")["state_dict"],
            strict=False,
        )

    # Pull out the CNN (torchvision-like) from our pretrained model and add
    # back the FC layer - this exists in torchvision models, and is set to
    # `nn.Identity()` during pretraining.
    model = pretrained_model.visual.cnn  # type: ignore
    model.fc = nn.Linear(_DOWNC.MODEL.VISUAL.FEATURE_SIZE, NUM_CLASSES).to(device)
    model = model.to(device)

    # Re-initialize the FC layer.
    torch.nn.init.normal_(model.fc.weight.data, mean=0.0, std=0.01)
    torch.nn.init.constant_(model.fc.bias.data, 0.0)

    # Freeze all layers except FC as per config param.
    if _DOWNC.MODEL.VISUAL.FROZEN:
        for name, param in model.named_parameters():
            if "fc" not in name:
                param.requires_grad = False
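
    # Sanity-check sketch (illustrative, not in the original script): with
    # `MODEL.VISUAL.FROZEN = True`, only the re-initialized fc layer should
    # remain trainable:
    #
    #     trainable = [n for n, p in model.named_parameters() if p.requires_grad]
    #     assert all("fc" in n for n in trainable)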
    # Cross entropy loss and accuracy meter.
    criterion = nn.CrossEntropyLoss()
    top1 = TopkAccuracy(top_k=1)

    optimizer = OptimizerFactory.from_config(_DOWNC, model.named_parameters())
    scheduler = LRSchedulerFactory.from_config(_DOWNC, optimizer)
    del pretrained_model

    # -------------------------------------------------------------------------
    #   BEFORE TRAINING STARTS
    # -------------------------------------------------------------------------

    # Create an iterator from dataloader to sample batches perpetually.
    train_dataloader_iter = cycle(train_dataloader, device)

    # Wrap model and optimizer using NVIDIA Apex for mixed precision training.
    # NOTE: Always do this before wrapping model with DistributedDataParallel.
    if _DOWNC.FP16_OPT > 0:
        from apex import amp

        model, optimizer = amp.initialize(
            model, optimizer, opt_level=f"O{_DOWNC.FP16_OPT}"
        )

    if dist.get_world_size() > 1:
        dist.synchronize()
        model = nn.parallel.DistributedDataParallel(
            model, device_ids=[device], find_unused_parameters=True
        )

    if dist.is_master_process():
        checkpoint_manager = CheckpointManager(
            _A.serialization_dir,
            model=model,
            optimizer=optimizer,
            scheduler=scheduler,
        )
        tensorboard_writer = SummaryWriter(log_dir=_A.serialization_dir)

    # Keep track of time per iteration and ETA.
    timer = Timer(start_from=1, total_iterations=_DOWNC.OPTIM.NUM_ITERATIONS)

    # -------------------------------------------------------------------------
    #   TRAINING LOOP
    # -------------------------------------------------------------------------
    for iteration in range(1, _DOWNC.OPTIM.NUM_ITERATIONS + 1):
        timer.tic()
        optimizer.zero_grad()
        batch = next(train_dataloader_iter)

        logits = model(batch["image"])
        loss = criterion(logits, batch["label"])

        # Perform dynamic scaling of loss to adjust for mixed precision.
        if _DOWNC.FP16_OPT > 0:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        optimizer.step()
        scheduler.step(iteration)
        timer.toc()

        if iteration % _A.log_every == 0 and dist.is_master_process():
            logger.info(
                f"{timer.stats} | Loss: {loss:.3f} | GPU: {dist.gpu_mem_usage()} MB"
            )
            tensorboard_writer.add_scalar(f"{DATASET}/train_loss", loss, iteration)
            tensorboard_writer.add_scalar(
                f"{DATASET}/learning_rate",
                optimizer.param_groups[0]["lr"],
                iteration,
            )

        # ---------------------------------------------------------------------
        #   VALIDATION
        # ---------------------------------------------------------------------
        if iteration % _A.checkpoint_every == 0:
            torch.set_grad_enabled(False)
            model.eval()

            total_val_loss = torch.tensor(0.0).to(device)

            for val_iteration, batch in enumerate(val_dataloader, start=1):
                for key in batch:
                    batch[key] = batch[key].to(device)

                logits = model(batch["image"])
                loss = criterion(logits, batch["label"])
                top1(logits, batch["label"])
                total_val_loss += loss

            # Divide each loss component by number of val batches per GPU.
            total_val_loss = total_val_loss / val_iteration
            dist.average_across_processes(total_val_loss)

            # Get accumulated Top-1 accuracy for logging across GPUs.
            acc = top1.get_metric(reset=True)
            dist.average_across_processes(acc)

            torch.set_grad_enabled(True)
            model.train()

            # Save recent checkpoint and best checkpoint based on accuracy.
            if dist.is_master_process():
                checkpoint_manager.step(iteration)

        if iteration % _A.checkpoint_every == 0 and dist.is_master_process():
            logger.info(f"Iter: {iteration} | Top-1 accuracy: {acc}")
            tensorboard_writer.add_scalar(
                f"{DATASET}/val_loss", total_val_loss, iteration
            )
            # This name scoping will result in Tensorboard displaying all
            # metrics (VOC07, caption, etc.) together.
            tensorboard_writer.add_scalars(
                f"metrics/{DATASET}", {"top1": acc}, iteration
            )

        # All processes will wait till master process is done logging.
        dist.synchronize()
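
# For reference, `cycle` (used above to sample batches perpetually) can be a
# small generator along these lines - a minimal sketch, assuming batches are
# dicts of tensors; the actual helper may also handle sampler epochs:
#
#     def cycle(dataloader, device):
#         while True:
#             for batch in dataloader:
#                 yield {key: value.to(device) for key, value in batch.items()}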