Example #1
    def test_periodic_checkpointer_max_to_keep(self) -> None:
        """
        Test parameter: max_to_keep
        """
        _period = 10
        _max_iter = 100
        _max_to_keep = 3
        for trained_model in [
                self._create_model(),
                nn.DataParallel(self._create_model()),
        ]:
            with TemporaryDirectory() as f:
                checkpointer = Checkpointer(trained_model,
                                            save_dir=f,
                                            save_to_disk=True)
                periodic_checkpointer = PeriodicCheckpointer(
                    checkpointer, _period, 99, max_to_keep=_max_to_keep)
                for _ in range(2):
                    checkpoint_paths = []

                    for iteration in range(_max_iter):
                        periodic_checkpointer.step(iteration)
                        if (iteration + 1) % _period == 0:
                            path = os.path.join(
                                f, "model_{:07d}.pth".format(iteration))
                            checkpoint_paths.append(path)

                    for path in checkpoint_paths[:-_max_to_keep]:
                        self.assertFalse(os.path.exists(path))

                    for path in checkpoint_paths[-_max_to_keep:]:
                        self.assertTrue(os.path.exists(path))
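
For reference, here is a minimal sketch of the same `max_to_keep` behaviour outside a test harness, assuming fvcore's `Checkpointer`/`PeriodicCheckpointer` API; the model, save directory, and training loop below are illustrative only:

import torch.nn as nn
from fvcore.common.checkpoint import Checkpointer, PeriodicCheckpointer

model = nn.Linear(4, 2)  # hypothetical stand-in for a real model
checkpointer = Checkpointer(model, save_dir="/tmp/ckpts", save_to_disk=True)
periodic_checkpointer = PeriodicCheckpointer(checkpointer, period=10, max_iter=100, max_to_keep=3)

for iteration in range(100):
    # ... one training step ...
    periodic_checkpointer.step(iteration)  # only the 3 most recent periodic model_*.pth files are kept on disk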
Example #2
    def _setup_checkpointers(self, resume_from="", search=True, period=1, **add_checkpointables):
        """
        Sets up a periodic checkpointer which can be used to save checkpoints
        at every epoch. The optimizer's `get_checkpointables()` are used as the
        objects to store.

        Args:
            resume_from (str): A checkpoint file to resume the search or evaluation from.
            search (bool): Whether the search or the evaluation phase is being checkpointed.
                This is required because the two phases are saved to different folders so
                that they do not overwrite each other.
            period (int): Interval (in epochs) at which checkpoints are written.
            add_checkpointables (object): Additional objects to checkpoint together with the
                optimizer's checkpointables.
        """
        checkpointables = self.optimizer.get_checkpointables()
        checkpointables.update(add_checkpointables)

        checkpointer = utils.Checkpointer(
            model=checkpointables.pop('model'),
            save_dir=self.config.save + ("/search" if search else "/eval"),
            # **checkpointables  # NOTE: this is throwing an error
        )

        self.periodic_checkpointer = PeriodicCheckpointer(
            checkpointer,
            period=period,
            max_iter=self.config.search.epochs if search else self.config.evaluation.epochs
        )

        if resume_from:
            logger.info("loading model from file {}".format(resume_from))
            checkpoint = checkpointer.resume_or_load(resume_from, resume=True)
            if checkpointer.has_checkpoint():
                return checkpoint.get("iteration", -1) + 1
        return 0
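
A hedged sketch of how this helper might be driven from a training loop; the `_train_one_epoch` method and the surrounding trainer attributes are hypothetical, not part of the example above:

# Hypothetical caller: set up checkpointing for the search phase; pass a checkpoint path to resume.
start_epoch = self._setup_checkpointers(resume_from="", search=True, period=1)
for epoch in range(start_epoch, self.config.search.epochs):
    self._train_one_epoch(epoch)            # hypothetical per-epoch training step
    self.periodic_checkpointer.step(epoch)  # writes a checkpoint every `period` epochs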
Example #3
    def test_periodic_checkpointer(self) -> None:
        """
        Test that checkpoints are written only at every `_period` iterations.
        """
        _period = 10
        _max_iter = 100
        for trained_model in [
            self._create_model(),
            nn.DataParallel(self._create_model()),
        ]:
            with TemporaryDirectory() as f:
                checkpointer = Checkpointer(
                    trained_model, save_dir=f, save_to_disk=True
                )
                periodic_checkpointer = PeriodicCheckpointer(checkpointer, _period, 99)
                for iteration in range(_max_iter):
                    periodic_checkpointer.step(iteration)
                    path = os.path.join(f, "model_{:07d}.pth".format(iteration))
                    if (iteration + 1) % _period == 0:
                        self.assertTrue(os.path.exists(path))
                    else:
                        self.assertFalse(os.path.exists(path))
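
For completeness, a hedged sketch of loading the latest checkpoint written by such a loop back into a fresh model, assuming fvcore's `has_checkpoint()`/`resume_or_load()` API; `new_model` is illustrative:

new_model = self._create_model()  # fresh, untrained copy (illustrative)
loader = Checkpointer(new_model, save_dir=f)
if loader.has_checkpoint():
    # With resume=True, the path recorded in the `last_checkpoint` file is used and the first argument is ignored.
    checkpoint = loader.resume_or_load("", resume=True)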
Example #4
def do_train(cfg, model):
    model.train()
    optimizer = build_optimizer(cfg, model)
    scheduler = build_lr_scheduler(cfg, optimizer)

    checkpointer = Checkpointer(model, './', optimizer=optimizer, scheduler=scheduler)
    max_iter = cfg.SOLVER.MAX_ITER

    periodic_checkpointer = PeriodicCheckpointer(checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD, max_iter=max_iter)

    writers = [CommonMetricPrinter(max_iter)] if d2_comm.is_main_process() else []

    train_mapper = get_dataset_mapper(cfg, is_train=True)
    dataloader, dataset_dicts = build_train_dataloader(cfg, mapper=train_mapper)
    LOG.info("Length of train dataset: {:d}".format(len(dataset_dicts)))
    LOG.info("Starting training")
    storage = get_event_storage()

    if cfg.EVAL_ON_START:
        do_test(cfg, model)
        d2_comm.synchronize()

    # In mixed-precision training, gradients are scaled up to keep them from vanishing (underflowing) in half precision.
    # They are scaled back down again before the optimizer uses them to compute updates.
    scaler = amp.GradScaler(enabled=cfg.SOLVER.MIXED_PRECISION_ENABLED)

    # Accumulate gradients for multiple batches (as returned by dataloader) before calling optimizer.step().
    accumulate_grad_batches = cfg.SOLVER.ACCUMULATE_GRAD_BATCHES

    num_images_seen = 0
    # For logging, this stores losses aggregated from all workers in distributed training.
    batch_loss_dict = defaultdict(float)
    optimizer.zero_grad()
    for data, iteration in zip(dataloader, range(max_iter * accumulate_grad_batches)):
        iteration += 1
        # This assumes drop_last=True, so all workers see the same batch size.
        num_images_seen += len(data) * d2_comm.get_world_size()
        if iteration % accumulate_grad_batches == 0:
            storage.step()

        with amp.autocast(enabled=cfg.SOLVER.MIXED_PRECISION_ENABLED):
            loss_dict = model(data)
        # Account for accumulated gradients.
        loss_dict = {name: loss / accumulate_grad_batches for name, loss in loss_dict.items()}
        losses = sum(loss_dict.values())
        # FIXME: First few iterations might give Inf/NaN losses when using mixed precision. What should be done?
        if not torch.isfinite(losses):
            LOG.critical(f"The loss DIVERGED: {loss_dict}")

        # Track total loss for logging.
        loss_dict_reduced = {k: v.item() for k, v in d2_comm.reduce_dict(loss_dict).items()}
        assert torch.isfinite(torch.as_tensor(list(loss_dict_reduced.values()))).all(), loss_dict_reduced
        for k, v in loss_dict_reduced.items():
            batch_loss_dict[k] += v

        # Non-AMP version, kept here for reference:
        # losses.backward()
        scaler.scale(losses).backward()

        if iteration % accumulate_grad_batches > 0:
            # Just accumulate gradients and move on to next batch.
            continue

        # Non-AMP version, kept here for reference:
        # optimizer.step()
        # scheduler.step()
        # optimizer.zero_grad()

        scaler.step(optimizer)
        storage.put_scalar("lr", optimizer.param_groups[0]["lr"], smoothing_hint=False)
        scheduler.step()
        scaler.update()

        losses_reduced = sum(loss for loss in batch_loss_dict.values())
        storage.put_scalars(total_loss=losses_reduced, **batch_loss_dict)

        # Reset states.
        batch_loss_dict = defaultdict(float)
        optimizer.zero_grad()

        batch_iter = iteration // accumulate_grad_batches

        # TODO: probably check if the gradients contain any inf or nan, and only proceed if not.
        if batch_iter > 5 and (batch_iter % 20 == 0 or batch_iter == max_iter):
            # if batch_iter > -1 and (batch_iter % 1 == 0 or batch_iter == max_iter):
            for writer in writers:
                writer.write()
            # log epoch / # images seen
            if d2_comm.is_main_process() and cfg.WANDB.ENABLED:
                wandb.log({"epoch": 1 + num_images_seen // len(dataset_dicts)}, step=batch_iter)
                wandb.log({"num_images_seen": num_images_seen}, step=batch_iter)

        if cfg.VIS.DATALOADER_ENABLED and batch_iter % cfg.VIS.DATALOADER_PERIOD == 0 and d2_comm.is_main_process():
            dataset_name = cfg.DATASETS.TRAIN.NAME
            visualizer_names = MetadataCatalog.get(dataset_name).loader_visualizers
            viz_images = defaultdict(dict)
            for viz_name in visualizer_names:
                viz = get_dataloader_visualizer(cfg, viz_name, dataset_name)
                for idx, x in enumerate(data):
                    viz_images[idx].update(viz.visualize(x))

            if cfg.WANDB.ENABLED:
                # per_image_vis = [coalece_viz_images(viz_images[idx])[0] for idx in range(len(data))]
                per_image_vis = [mosaic(list(viz_images[idx].values())) for idx in range(len(data))]
                wandb.log(
                    {"dataloader": [wandb.Image(vis, caption=f"idx={idx}") for idx, vis in enumerate(per_image_vis)]},
                    step=batch_iter,
                )
            save_vis(viz_images, os.path.join(os.getcwd(), "visualization"), "dataloader", step=batch_iter)

        if d2_comm.is_main_process():  # TODO (dennis.park): is this necessary?
            periodic_checkpointer.step(batch_iter - 1)  # (fvcore) model_0004999.pth checkpoints 5000-th iteration

        if batch_iter > 0 and batch_iter % cfg.SYNC_OUTPUT_DIR_S3.PERIOD == 0:
            sync_output_dir_s3(cfg)

        if (cfg.TEST.EVAL_PERIOD > 0 and batch_iter % cfg.TEST.EVAL_PERIOD == 0 and batch_iter != max_iter) or \
            batch_iter in cfg.TEST.ADDITIONAL_EVAL_STEPS:
            do_test(cfg, model)
            d2_comm.synchronize()
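
Stripped of logging, visualization, and distributed bookkeeping, the AMP-plus-gradient-accumulation core of do_train reduces to the pattern below. This is a minimal, self-contained sketch in plain PyTorch; the model, optimizer, and `dataloader` are placeholders, not objects from the code above:

import torch
from torch import nn
from torch.cuda import amp

model = nn.Linear(8, 1).cuda()                            # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = amp.GradScaler(enabled=True)
accumulate_grad_batches = 4

optimizer.zero_grad()
for iteration, (x, y) in enumerate(dataloader, start=1):  # `dataloader` is assumed to yield (input, target) batches
    with amp.autocast(enabled=True):
        loss = nn.functional.mse_loss(model(x.cuda()), y.cuda())
    # Divide by the accumulation window so the accumulated gradient is an average, as in do_train above.
    scaler.scale(loss / accumulate_grad_batches).backward()
    if iteration % accumulate_grad_batches != 0:
        continue                                          # keep accumulating gradients
    scaler.step(optimizer)                                # unscales gradients, then calls optimizer.step()
    scaler.update()
    optimizer.zero_grad()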