def test_periodic_checkpointer_max_to_keep(self) -> None:
    """
    Test parameter: max_to_keep
    """
    _period = 10
    _max_iter = 100
    _max_to_keep = 3
    for trained_model in [
        self._create_model(),
        nn.DataParallel(self._create_model()),
    ]:
        with TemporaryDirectory() as f:
            checkpointer = Checkpointer(trained_model, save_dir=f, save_to_disk=True)
            periodic_checkpointer = PeriodicCheckpointer(
                checkpointer, _period, 99, max_to_keep=_max_to_keep
            )
            for _ in range(2):
                checkpoint_paths = []
                for iteration in range(_max_iter):
                    periodic_checkpointer.step(iteration)
                    if (iteration + 1) % _period == 0:
                        path = os.path.join(f, "model_{:07d}.pth".format(iteration))
                        checkpoint_paths.append(path)
                # Older periodic checkpoints are pruned; only the newest ones remain.
                for path in checkpoint_paths[:-_max_to_keep]:
                    self.assertFalse(os.path.exists(path))
                for path in checkpoint_paths[-_max_to_keep:]:
                    self.assertTrue(os.path.exists(path))
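# A minimal sketch (not part of the test class) of how `max_to_keep` is used in
# practice, assuming the same fvcore `Checkpointer`/`PeriodicCheckpointer` API
# exercised by the test above; the toy model and directory are illustrative only.
import os
from tempfile import TemporaryDirectory

import torch.nn as nn
from fvcore.common.checkpoint import Checkpointer, PeriodicCheckpointer


def demo_max_to_keep() -> None:
    model = nn.Linear(4, 2)  # toy model standing in for a real network
    with TemporaryDirectory() as save_dir:
        checkpointer = Checkpointer(model, save_dir=save_dir, save_to_disk=True)
        # Save every 10 iterations, but keep only the 2 most recent periodic files.
        periodic = PeriodicCheckpointer(
            checkpointer, period=10, max_iter=100, max_to_keep=2
        )
        for iteration in range(100):
            periodic.step(iteration)
        # Expect only the two newest model_XXXXXXX.pth files (a final checkpoint
        # may also be written at max_iter, depending on the fvcore version).
        print(sorted(p for p in os.listdir(save_dir) if p.endswith(".pth")))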
def _setup_checkpointers(
    self, resume_from="", search=True, period=1, **add_checkpointables
):
    """
    Sets up a periodic checkpointer which can be used to save checkpoints
    at every epoch. The objects returned by the optimizer's
    `get_checkpointables()` are used as the objects to store.

    Args:
        resume_from (str): A checkpoint file to resume the search or
            evaluation from.
        search (bool): Whether the search or the evaluation phase is being
            checkpointed. This is required because the two phases write to
            different folders so they do not override each other.
        period (int): Number of epochs between checkpoints.
        add_checkpointables (object): Additional objects to checkpoint
            together with the optimizer's checkpointables.

    Returns:
        int: The iteration (epoch) to start or resume from.
    """
    checkpointables = self.optimizer.get_checkpointables()
    checkpointables.update(add_checkpointables)

    checkpointer = utils.Checkpointer(
        model=checkpointables.pop("model"),
        save_dir=self.config.save + "/search"
        if search
        else self.config.save + "/eval",
        # **checkpointables  # NOTE: this is throwing an Error
    )

    self.periodic_checkpointer = PeriodicCheckpointer(
        checkpointer,
        period=period,
        max_iter=self.config.search.epochs
        if search
        else self.config.evaluation.epochs,
    )

    if resume_from:
        logger.info("loading model from file {}".format(resume_from))
        checkpoint = checkpointer.resume_or_load(resume_from, resume=True)
        if checkpointer.has_checkpoint():
            return checkpoint.get("iteration", -1) + 1
    return 0
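# A hedged usage sketch of `_setup_checkpointers` inside a trainer's search loop.
# The attribute names `self.config.search.checkpoint_freq` and
# `self.optimizer.new_epoch` are assumptions for illustration; only the call
# signature and the returned start epoch come from the method above.
def search(self, resume_from=""):
    start_epoch = self._setup_checkpointers(
        resume_from, search=True, period=self.config.search.checkpoint_freq
    )
    for epoch in range(start_epoch, self.config.search.epochs):
        self.optimizer.new_epoch(epoch)  # hypothetical per-epoch hook
        # ... run one epoch of architecture search ...
        # The periodic checkpointer decides whether this epoch gets saved.
        self.periodic_checkpointer.step(epoch)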
def test_periodic_checkpointer(self) -> None:
    """
    Test that checkpoints are written exactly every `period` iterations.
    """
    _period = 10
    _max_iter = 100
    for trained_model in [
        self._create_model(),
        nn.DataParallel(self._create_model()),
    ]:
        with TemporaryDirectory() as f:
            checkpointer = Checkpointer(
                trained_model, save_dir=f, save_to_disk=True
            )
            periodic_checkpointer = PeriodicCheckpointer(checkpointer, _period, 99)
            for iteration in range(_max_iter):
                periodic_checkpointer.step(iteration)
                path = os.path.join(f, "model_{:07d}.pth".format(iteration))
                if (iteration + 1) % _period == 0:
                    self.assertTrue(os.path.exists(path))
                else:
                    self.assertFalse(os.path.exists(path))
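# A minimal sketch of resuming from the newest checkpoint that a
# `PeriodicCheckpointer` wrote, assuming the same fvcore API as the test above.
# `Checkpointer.save()` records the latest file name in a `last_checkpoint`
# marker, which `resume_or_load(..., resume=True)` consults when
# `has_checkpoint()` is True.
def resume_from_latest(model, save_dir: str) -> int:
    checkpointer = Checkpointer(model, save_dir=save_dir, save_to_disk=True)
    # With resume=True the path argument is only a fallback; the last
    # checkpoint in `save_dir` takes precedence.
    checkpoint = checkpointer.resume_or_load("", resume=True)
    if checkpointer.has_checkpoint():
        # `PeriodicCheckpointer.step()` stored the iteration it was called with.
        return checkpoint.get("iteration", -1) + 1
    return 0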
def do_train(cfg, model):
    model.train()

    optimizer = build_optimizer(cfg, model)
    scheduler = build_lr_scheduler(cfg, optimizer)

    checkpointer = Checkpointer(model, "./", optimizer=optimizer, scheduler=scheduler)

    max_iter = cfg.SOLVER.MAX_ITER
    periodic_checkpointer = PeriodicCheckpointer(
        checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD, max_iter=max_iter
    )

    writers = [CommonMetricPrinter(max_iter)] if d2_comm.is_main_process() else []

    train_mapper = get_dataset_mapper(cfg, is_train=True)
    dataloader, dataset_dicts = build_train_dataloader(cfg, mapper=train_mapper)
    LOG.info("Length of train dataset: {:d}".format(len(dataset_dicts)))
    LOG.info("Starting training")
    storage = get_event_storage()

    if cfg.EVAL_ON_START:
        do_test(cfg, model)
        comm.synchronize()

    # In mixed-precision training, gradients are scaled up to keep them from
    # vanishing due to half precision. They are scaled down again before the
    # optimizer uses them to compute updates.
    scaler = amp.GradScaler(enabled=cfg.SOLVER.MIXED_PRECISION_ENABLED)

    # Accumulate gradients over multiple batches (as returned by the dataloader)
    # before calling optimizer.step().
    accumulate_grad_batches = cfg.SOLVER.ACCUMULATE_GRAD_BATCHES

    num_images_seen = 0
    # For logging: stores losses aggregated from all workers in distributed training.
    batch_loss_dict = defaultdict(float)
    optimizer.zero_grad()
    for data, iteration in zip(dataloader, range(max_iter * accumulate_grad_batches)):
        iteration += 1
        # This assumes drop_last=True, so all workers see the same batch size.
        num_images_seen += len(data) * d2_comm.get_world_size()
        if iteration % accumulate_grad_batches == 0:
            storage.step()

        with amp.autocast(enabled=cfg.SOLVER.MIXED_PRECISION_ENABLED):
            loss_dict = model(data)

        # Account for accumulated gradients.
        loss_dict = {name: loss / accumulate_grad_batches for name, loss in loss_dict.items()}
        losses = sum(loss_dict.values())
        # FIXME: First few iterations might give Inf/NaN losses when using mixed precision. What should be done?
        if not torch.isfinite(losses):
            LOG.critical(f"The loss DIVERGED: {loss_dict}")

        # Track total loss for logging.
        loss_dict_reduced = {k: v.item() for k, v in d2_comm.reduce_dict(loss_dict).items()}
        assert torch.isfinite(torch.as_tensor(list(loss_dict_reduced.values()))).all(), loss_dict_reduced
        for k, v in loss_dict_reduced.items():
            batch_loss_dict[k] += v

        # No-AMP version, kept for reference:
        #   losses.backward()
        scaler.scale(losses).backward()

        if iteration % accumulate_grad_batches > 0:
            # Just accumulate gradients and move on to the next batch.
            continue

        # No-AMP version, kept for reference:
        #   optimizer.step()
        #   scheduler.step()
        #   optimizer.zero_grad()
        scaler.step(optimizer)
        storage.put_scalar("lr", optimizer.param_groups[0]["lr"], smoothing_hint=False)
        scheduler.step()
        scaler.update()

        losses_reduced = sum(loss for loss in batch_loss_dict.values())
        storage.put_scalars(total_loss=losses_reduced, **batch_loss_dict)

        # Reset accumulation state.
        batch_loss_dict = defaultdict(float)
        optimizer.zero_grad()

        batch_iter = iteration // accumulate_grad_batches

        # TODO: probably check if the gradients contain any inf or nan, and only proceed if not.
        if batch_iter > 5 and (batch_iter % 20 == 0 or batch_iter == max_iter):
            # if batch_iter > -1 and (batch_iter % 1 == 0 or batch_iter == max_iter):
            for writer in writers:
                writer.write()
            # Log epoch / number of images seen.
            if d2_comm.is_main_process() and cfg.WANDB.ENABLED:
                wandb.log({"epoch": 1 + num_images_seen // len(dataset_dicts)}, step=batch_iter)
                wandb.log({"num_images_seen": num_images_seen}, step=batch_iter)

        if cfg.VIS.DATALOADER_ENABLED and batch_iter % cfg.VIS.DATALOADER_PERIOD == 0 and d2_comm.is_main_process():
            dataset_name = cfg.DATASETS.TRAIN.NAME
            visualizer_names = MetadataCatalog.get(dataset_name).loader_visualizers
            viz_images = defaultdict(dict)
            for viz_name in visualizer_names:
                viz = get_dataloader_visualizer(cfg, viz_name, dataset_name)
                for idx, x in enumerate(data):
                    viz_images[idx].update(viz.visualize(x))
            if cfg.WANDB.ENABLED:
                # per_image_vis = [coalece_viz_images(viz_images[idx])[0] for idx in range(len(data))]
                per_image_vis = [mosaic(list(viz_images[idx].values())) for idx in range(len(data))]
                wandb.log({
                    "dataloader": [
                        wandb.Image(vis, caption=f"idx={idx}")
                        for idx, vis in enumerate(per_image_vis)
                    ]
                }, step=batch_iter)
            save_vis(viz_images, os.path.join(os.getcwd(), "visualization"), "dataloader", step=batch_iter)

        if d2_comm.is_main_process():
            # TODO (dennis.park): is this necessary?
            periodic_checkpointer.step(batch_iter - 1)  # (fvcore) model_0004999.pth checkpoints the 5000-th iteration

            if batch_iter > 0 and batch_iter % cfg.SYNC_OUTPUT_DIR_S3.PERIOD == 0:
                sync_output_dir_s3(cfg)

        if (cfg.TEST.EVAL_PERIOD > 0 and batch_iter % cfg.TEST.EVAL_PERIOD == 0 and batch_iter != max_iter) or \
                batch_iter in cfg.TEST.ADDITIONAL_EVAL_STEPS:
            do_test(cfg, model)
            d2_comm.synchronize()
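# A distilled sketch of the mixed-precision + gradient-accumulation pattern used
# in `do_train` above, assuming `torch.cuda.amp` and a generic model that returns
# a scalar loss; the config flags, detectron2 event storage, and distributed
# pieces are dropped for clarity.
import torch
from torch.cuda import amp


def train_with_accumulation(model, optimizer, dataloader, accumulate_grad_batches: int = 4):
    scaler = amp.GradScaler()
    optimizer.zero_grad()
    for i, batch in enumerate(dataloader, start=1):
        with amp.autocast():
            loss = model(batch)  # assumed to return a scalar loss
        # Divide the loss so the accumulated gradient matches one large batch.
        scaler.scale(loss / accumulate_grad_batches).backward()
        if i % accumulate_grad_batches == 0:
            # Step only once per accumulation window, then reset gradients.
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()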