Example #1
    def test_from_name_file_model(self) -> None:
        """
        test that loading works even if they differ by a prefix.
        """
        for trained_model, fresh_model in [
            (self._create_model(), self._create_model()),
            (nn.DataParallel(self._create_model()), self._create_model()),
            (self._create_model(), nn.DataParallel(self._create_model())),
            (
                nn.DataParallel(self._create_model()),
                nn.DataParallel(self._create_model()),
            ),
        ]:
            with TemporaryDirectory() as f:
                checkpointer = Checkpointer(trained_model,
                                            save_dir=f,
                                            save_to_disk=True)
                checkpointer.save("checkpoint_file")

                # in a different folder.
                with TemporaryDirectory() as g:
                    fresh_checkpointer = Checkpointer(fresh_model, save_dir=g)
                    self.assertFalse(fresh_checkpointer.has_checkpoint())
                    self.assertEqual(fresh_checkpointer.get_checkpoint_file(),
                                     "")
                    fresh_checkpointer.load(
                        os.path.join(f, "checkpoint_file.pth"))

            for trained_p, loaded_p in zip(trained_model.parameters(),
                                           fresh_model.parameters()):
                # different tensor references.
                self.assertFalse(id(trained_p) == id(loaded_p))
                # same content.
                self.assertTrue(trained_p.cpu().equal(loaded_p.cpu()))
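A minimal sketch of the round-trip this test exercises, assuming fvcore's Checkpointer (the model and shapes below are illustrative):

import os
from tempfile import TemporaryDirectory

import torch.nn as nn
from fvcore.common.checkpoint import Checkpointer

trained = nn.DataParallel(nn.Linear(4, 4))
fresh = nn.Linear(4, 4)

with TemporaryDirectory() as d:
    # Checkpointer unwraps (Distributed)DataParallel before saving, so the stored
    # keys carry no "module." prefix and load cleanly into the bare module.
    Checkpointer(trained, save_dir=d, save_to_disk=True).save("ckpt")
    Checkpointer(fresh, save_dir=d).load(os.path.join(d, "ckpt.pth"))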
Example #2
def main():
    config = load_config()

    if config.test.output_dir is None:
        output_dir = pathlib.Path(config.test.checkpoint).parent
    else:
        output_dir = pathlib.Path(config.test.output_dir)
        output_dir.mkdir(exist_ok=True, parents=True)

    logger = create_logger(name=__name__, distributed_rank=get_rank())

    model = create_model(config)
    model = apply_data_parallel_wrapper(config, model)
    checkpointer = Checkpointer(model,
                                checkpoint_dir=output_dir,
                                logger=logger,
                                distributed_rank=get_rank())
    checkpointer.load(config.test.checkpoint)

    test_loader = create_dataloader(config, is_train=False)
    _, test_loss = create_loss(config)

    preds, probs, labels, loss, acc = evaluate(config, model, test_loader,
                                               test_loss, logger)

    output_path = output_dir / 'predictions.npz'
    np.savez(output_path,
             preds=preds,
             probs=probs,
             labels=labels,
             loss=loss,
             acc=acc)
Example #3
    def __init__(self, cfg, dataset_name, distributed, output_dir=None):
        """
        Args:
            dataset_name (str): name of the dataset to be evaluated.
            distributed (bool): if True, will collect results from all ranks for evaluation.
                Otherwise, will evaluate the results in the current process.
            output_dir (str): an output directory to dump results.
        """
        self._cpu_device = torch.device("cpu")
        self._logger = logging.getLogger(__name__)
        self._dataset_name = dataset_name
        self._distributed = distributed
        self._output_dir = output_dir

        vq_cfg = get_cfg()
        vq_cfg.merge_from_file(cfg.TEST.VT_SAMPLER.VQ_VAE.CFG)
        self.vqvae = build_model(vq_cfg)
        Checkpointer(self.vqvae.encoder).resume_or_load(
            cfg.TEST.VT_SAMPLER.VQ_VAE.ENCODER_WEIGHTS, resume=False)
        Checkpointer(self.vqvae.generator).resume_or_load(
            cfg.TEST.VT_SAMPLER.VQ_VAE.GENERATOR_WEIGHTS, resume=False)
        Checkpointer(self.vqvae.codebook).resume_or_load(
            cfg.TEST.VT_SAMPLER.VQ_VAE.CODEBOOK_WEIGHTS, resume=False)
        self.vqvae.set_generator_requires_grad(False)
        self.vqvae.eval()
        self.scale_to_zeroone = vq_cfg.INPUT.SCALE_TO_ZEROONE
Example #4
def sample_videos(args):
    # load config
    cfg = setup(args)
    cfg.TEST.EVALUATORS = "VTSampler"
    cfg.TEST.NUM_SAMPLES = 1

    # load videotransformer
    vt = build_model(cfg)
    Checkpointer(vt.model).resume_or_load(cfg.MODEL.GENERATOR.WEIGHTS,
                                          resume=False)
    vt.eval()

    # load vqvae
    vq_cfg = get_cfg()
    vq_cfg.merge_from_file(cfg.TEST.VT_SAMPLER.VQ_VAE.CFG)
    vqvae = build_model(vq_cfg)
    Checkpointer(vqvae.encoder).resume_or_load(
        cfg.TEST.VT_SAMPLER.VQ_VAE.ENCODER_WEIGHTS, resume=False)
    Checkpointer(vqvae.generator).resume_or_load(
        cfg.TEST.VT_SAMPLER.VQ_VAE.GENERATOR_WEIGHTS, resume=False)
    Checkpointer(vqvae.codebook).resume_or_load(
        cfg.TEST.VT_SAMPLER.VQ_VAE.CODEBOOK_WEIGHTS, resume=False)
    vqvae.eval()

    # load data
    scale_to_zeroone = vq_cfg.INPUT.SCALE_TO_ZEROONE
    n_prime = cfg.TEST.VT_SAMPLER.N_PRIME
    images = load_video(
        args.video_dir)[:n_prime]  # (T, C, H, W) = (n_prime, 3, 64, 64)
    assert images.shape == (n_prime, 3, 64, 64)
    print(f"Loaded {n_prime} priming frames")

    # sample
    latent_sequence = vqvae([{
        'image_sequence': images
    }])[0]['latent']  # (n_prime, nc, h, w) = (n_prime, nc, 16, 16)
    print(f"Transferred to latent codes.")
    _, nc, h, w = latent_sequence.shape
    new_sequence = latent_sequence.new_zeros(16, nc, h, w)
    new_sequence[:n_prime] = latent_sequence
    samples = vt([{
        'image_sequence': new_sequence
    }])[0]['samples']  # list of samples
    print(f"Sampled new video.")
    sample = samples[0].squeeze(0)  # T, h, w if nc == 1 or nc, T, h, w
    if sample.dim() == 4:
        sample = sample.transpose(0, 1)  # T, nc, h, w
    sample = vqvae.decode(sample)  # T, 3, H, W
    sample = vqvae.back_normalizer(sample)  # T, 3, H, W
    if scale_to_zeroone:
        sample = sample * 255
    sample.clamp_(0.0, 255.0)
    sample = sample.permute(0, 2, 3, 1).contiguous()  # T, H, W, 3
    sample = sample.detach().cpu().numpy().astype(np.uint8)
    save_video(sample, cfg.OUTPUT_DIR)
    print(f"Saved new video.")
Example #5
def load_checkpoint_from_http(
    model,
    filename,
    map_location=None,
):
    checkpointer = Checkpointer(model)
    checkpoint = load_from_http(filename, map_location=map_location)
    
    checkpointer.logger.info("[Checkpointer] Loading from {} ...".format(filename))
    incompatible = checkpointer._load_model(checkpoint={"model": checkpoint})
    
    # handle some existing subclasses that return None
    if incompatible is not None:
        checkpointer._log_incompatible_keys(incompatible)
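The helper above calls the private _load_model/_log_incompatible_keys methods directly; a rough equivalent (a sketch under the assumption that the weights file is a plain state dict fetched with torch.hub) could look like this:

import torch
from fvcore.common.checkpoint import Checkpointer

def load_checkpoint_from_url(model, url, map_location="cpu"):
    # Download (and cache) the state dict, then let the checkpointer match keys.
    state_dict = torch.hub.load_state_dict_from_url(url, map_location=map_location)
    incompatible = Checkpointer(model)._load_model({"model": state_dict})
    if incompatible is not None:  # some subclasses return None
        print("missing keys:", incompatible.missing_keys)
        print("unexpected keys:", incompatible.unexpected_keys)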
Example #6
 def test_best_checkpointer(self):
     model = _SimpleModel()
     dataloader = self._data_loader("cpu")
     opt = torch.optim.SGD(model.parameters(), 0.1)
     metric_name = "metric"
     total_iter = 40
     test_period = 10
     test_cases = [
         ("max", iter([0.3, 0.4, 0.35, 0.5]), 3),
         ("min", iter([1.0, 0.8, 0.9, 0.9]), 2),
         ("min", iter([math.nan, 0.8, 0.9, 0.9]), 1),
     ]
     for mode, metrics, call_count in test_cases:
         trainer = SimpleTrainer(model, dataloader, opt)
         with tempfile.TemporaryDirectory(prefix="detectron2_test") as d:
             checkpointer = Checkpointer(model, d, opt=opt, trainer=trainer)
             trainer.register_hooks([
                 hooks.EvalHook(test_period,
                                lambda: {metric_name: next(metrics)}),
                 hooks.BestCheckpointer(test_period,
                                        checkpointer,
                                        metric_name,
                                        mode=mode),
             ])
             with mock.patch.object(checkpointer,
                                    "save") as mock_save_method:
                 trainer.train(0, total_iter)
                 self.assertEqual(mock_save_method.call_count, call_count)
Example #7
    def test_periodic_checkpointer_max_to_keep(self) -> None:
        """
        Test parameter: max_to_keep
        """
        _period = 10
        _max_iter = 100
        _max_to_keep = 3
        for trained_model in [
                self._create_model(),
                nn.DataParallel(self._create_model()),
        ]:
            with TemporaryDirectory() as f:
                checkpointer = Checkpointer(trained_model,
                                            save_dir=f,
                                            save_to_disk=True)
                periodic_checkpointer = PeriodicCheckpointer(
                    checkpointer, _period, 99, max_to_keep=_max_to_keep)
                for _ in range(2):
                    checkpoint_paths = []

                    for iteration in range(_max_iter):
                        periodic_checkpointer.step(iteration)
                        if (iteration + 1) % _period == 0:
                            path = os.path.join(
                                f, "model_{:07d}.pth".format(iteration))
                            checkpoint_paths.append(path)

                    for path in checkpoint_paths[:-_max_to_keep]:
                        self.assertFalse(os.path.exists(path))

                    for path in checkpoint_paths[-_max_to_keep:]:
                        self.assertTrue(os.path.exists(path))
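A condensed sketch of the behaviour under test, assuming fvcore's PeriodicCheckpointer: with max_to_keep=3 only the three most recent periodic checkpoints remain on disk, older ones being deleted as new ones are written.

from tempfile import TemporaryDirectory

import torch.nn as nn
from fvcore.common.checkpoint import Checkpointer, PeriodicCheckpointer

model = nn.Linear(4, 4)
with TemporaryDirectory() as d:
    checkpointer = Checkpointer(model, save_dir=d, save_to_disk=True)
    periodic = PeriodicCheckpointer(checkpointer, period=10, max_iter=99, max_to_keep=3)
    for iteration in range(100):
        # Writes model_0000009.pth, model_0000019.pth, ... pruning older files.
        periodic.step(iteration)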
Example #8
def main():
    config = load_config()

    set_seeds(config.train.seed)
    setup_cudnn(config)

    output_dir = create_train_output_dir(config)
    save_config(config, output_dir)
    logger = create_logger(name=__name__,
                           output_dir=output_dir,
                           filename='log.txt')
    logger.info(config)

    train_loader, val_loader = create_dataloader(config, is_train=True)
    model = create_model(config)
    loss_function = create_loss(config)
    optimizer = create_optimizer(config, model)
    scheduler = create_scheduler(config, optimizer)
    checkpointer = Checkpointer(model,
                                optimizer=optimizer,
                                scheduler=scheduler,
                                save_dir=output_dir.as_posix(),
                                save_to_disk=True)
    tensorboard_writer = create_tensorboard_writer(config, output_dir)

    if config.train.val_first:
        validate(0, model, loss_function, val_loader, config,
                 tensorboard_writer, logger)

    for epoch in range(1, config.scheduler.epochs + 1):
        train(epoch, model, optimizer, scheduler, loss_function, train_loader,
              config, tensorboard_writer, logger)
        scheduler.step()

        if epoch % config.train.val_period == 0:
            validate(epoch, model, loss_function, val_loader, config,
                     tensorboard_writer, logger)

        if (epoch % config.train.checkpoint_period == 0
                or epoch == config.scheduler.epochs):
            checkpoint_config = {'epoch': epoch, 'config': config.as_dict()}
            checkpointer.save(f'checkpoint_{epoch:04d}', **checkpoint_config)

    tensorboard_writer.close()
Example #9
    def test_load_reused_params(self) -> None:
        class Model(nn.Module):
            def __init__(self, has_y: bool) -> None:
                super().__init__()
                self.x = nn.Linear(10, 10)
                if has_y:
                    self.y = self.x

        model = Model(has_y=False)
        model.x.bias.data.fill_(5.0)  # pyre-ignore
        data = {"model": model.state_dict()}
        new_model = Model(has_y=True)
        chkpt = Checkpointer(new_model)
        chkpt.logger = logger = MagicMock()
        chkpt._load_model(data)
        self.assertTrue(
            torch.allclose(new_model.y.bias - 5.0,
                           torch.zeros_like(new_model.y.bias)))
        logger.info.assert_not_called()
Example #10
    def test_checkpoint_resume(self):
        model = _SimpleModel()
        dataloader = self._data_loader("cpu")
        opt = torch.optim.SGD(model.parameters(), 0.1)
        scheduler = torch.optim.lr_scheduler.StepLR(opt, 3)

        with tempfile.TemporaryDirectory(prefix="detectron2_test") as d:
            trainer = SimpleTrainer(model, dataloader, opt)
            checkpointer = Checkpointer(model, d, opt=opt, trainer=trainer)

            trainer.register_hooks([
                hooks.LRScheduler(scheduler=scheduler),
                # checkpoint after scheduler to properly save the state of scheduler
                hooks.PeriodicCheckpointer(checkpointer, 10),
            ])

            trainer.train(0, 12)
            self.assertAlmostEqual(opt.param_groups[0]["lr"], 1e-5)
            self.assertEqual(scheduler.last_epoch, 12)
            del trainer

            opt = torch.optim.SGD(model.parameters(), 999)  # lr will be loaded
            trainer = SimpleTrainer(model, dataloader, opt)
            scheduler = torch.optim.lr_scheduler.StepLR(opt, 3)
            trainer.register_hooks([
                hooks.LRScheduler(scheduler=scheduler),
            ])
            checkpointer = Checkpointer(model, d, opt=opt, trainer=trainer)
            checkpointer.resume_or_load("non_exist.pth")
            self.assertEqual(
                trainer.iter,
                11)  # last finished iter number (0-based in Trainer)
            # number of times `scheduler.step()` was called (1-based)
            self.assertEqual(scheduler.last_epoch, 12)
            self.assertAlmostEqual(opt.param_groups[0]["lr"], 1e-5)
Example #11
    def test_checkpoint_resume(self):
        model = _SimpleModel()
        dataloader = self._data_loader("cpu")
        opt = torch.optim.SGD(model.parameters(), 0.1)
        scheduler = torch.optim.lr_scheduler.StepLR(opt, 3)

        with tempfile.TemporaryDirectory(prefix="detectron2_test") as d:
            trainer = SimpleTrainer(model, dataloader, opt)
            checkpointer = Checkpointer(model, d, opt=opt, trainer=trainer)

            trainer.register_hooks(
                [
                    hooks.PeriodicCheckpointer(checkpointer, 10),
                    hooks.LRScheduler(scheduler=scheduler),
                ]
            )

            trainer.train(0, 12)
            del trainer

            trainer = SimpleTrainer(model, dataloader, opt)
            scheduler = torch.optim.lr_scheduler.StepLR(opt, 3)
            trainer.register_hooks(
                [
                    hooks.LRScheduler(scheduler=scheduler),
                ]
            )
            checkpointer = Checkpointer(model, d, opt=opt, trainer=trainer)
            checkpointer.resume_or_load("non_exist.pth")
            self.assertEqual(trainer.iter, 11)  # last finished iter
            self.assertEqual(scheduler.last_epoch, 11)
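Both resume tests rely on the same pattern, sketched here under the assumption of fvcore's Checkpointer API: objects passed as extra keyword arguments (optimizer, trainer, ...) are saved as checkpointables alongside the model, and resume_or_load(..., resume=True) restores the latest checkpoint recorded in save_dir.

from tempfile import TemporaryDirectory

import torch
import torch.nn as nn
from fvcore.common.checkpoint import Checkpointer

model = nn.Linear(4, 4)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

with TemporaryDirectory() as d:
    checkpointer = Checkpointer(model, d, optimizer=opt)
    checkpointer.save("model_0000009")  # stores model and optimizer state_dicts
    # resume=True loads the checkpoint pointed to by the "last_checkpoint" file
    # in save_dir; the explicit path is only used when no such record exists.
    checkpointer.resume_or_load("non_exist.pth", resume=True)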
Example #12
    def configure_optimizers_and_checkpointers(self):
        optimizer_g = build_optimizer(self.model, self.cfg, suffix="_G")
        scheduler_g = build_lr_scheduler(self.cfg, optimizer_g)

        PathManager.mkdirs(os.path.join(self.cfg.OUTPUT_DIR, 'netG'))
        c = [
            {"checkpointer": Checkpointer(self.model, os.path.join(self.cfg.OUTPUT_DIR, 'netG')),
             "pretrained": self.cfg.MODEL.GENERATOR.WEIGHTS, },
        ]
        o = [
            {"optimizer": optimizer_g, "scheduler": scheduler_g, "type": "generator"},
        ]
        return o, c
Example #13
def evaluate_on_dataset(
        config_file="../../configs/COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml",
        override_cfg=(),
        test_datasets=(),
):
    if override_cfg is None:
        override_cfg = []
    cfg = get_cfg()
    cfg.merge_from_file(config_file)
    cfg.merge_from_list(override_cfg)
    cfg.DATASETS.TEST = test_datasets
    model = build_model(cfg)

    checkpointer = Checkpointer(model)
    checkpointer.load(cfg.MODEL.WEIGHTS)

    evaluator = [
        COCOEvaluator(test_set, cfg, False) for test_set in test_datasets
    ]

    metrics = DefaultTrainer.test(cfg, model, evaluator)

    return metrics
Example #14
    def configure_optimizers_and_checkpointers(self):
        o, c = super().configure_optimizers_and_checkpointers()

        if not self.use_codebook_ema:
            optimizer_c = build_optimizer(self.codebook, self.cfg, suffix="_G")
            scheduler_c = build_lr_scheduler(self.cfg, optimizer_c)
            o += [
                {"optimizer": optimizer_c, "scheduler": scheduler_c, "type": "generator"},
            ]

        PathManager.mkdirs(os.path.join(self.cfg.OUTPUT_DIR, 'netC'))
        c += [
            {"checkpointer": Checkpointer(self.codebook, os.path.join(self.cfg.OUTPUT_DIR, 'netC')),
             "pretrained": self.cfg.MODEL.CODEBOOK.WEIGHTS, },
        ]

        return o, c
Example #15
 def test_periodic_checkpointer(self) -> None:
     """
     Test that PeriodicCheckpointer saves checkpoints at the expected iterations.
     """
     _period = 10
     _max_iter = 100
     for trained_model in [
         self._create_model(),
         nn.DataParallel(self._create_model()),
     ]:
         with TemporaryDirectory() as f:
             checkpointer = Checkpointer(
                 trained_model, save_dir=f, save_to_disk=True
             )
             periodic_checkpointer = PeriodicCheckpointer(checkpointer, _period, 99)
             for iteration in range(_max_iter):
                 periodic_checkpointer.step(iteration)
                 path = os.path.join(f, "model_{:07d}.pth".format(iteration))
                 if (iteration + 1) % _period == 0:
                     self.assertTrue(os.path.exists(path))
                 else:
                     self.assertFalse(os.path.exists(path))
Example #16
    def test_load_lazy_module(self) -> None:
        def _get_model() -> nn.Sequential:  # pyre-fixme[11]
            return nn.Sequential(nn.LazyLinear(10))

        m1, m2 = _get_model(), _get_model()
        m1(torch.randn(4, 2, 4, 4))  # initialize m1, but not m2
        # Load m1's checkpoint into m2.
        with TemporaryDirectory() as f:
            checkpointer = Checkpointer(m1, save_dir=f)
            checkpointer.save("checkpoint_file")

            fresh_checkpointer = Checkpointer(m2, save_dir=f)
            self.assertTrue(fresh_checkpointer.has_checkpoint())
            self.assertEqual(
                fresh_checkpointer.get_checkpoint_file(),
                os.path.join(f, "checkpoint_file.pth"),
            )
            fresh_checkpointer.load(fresh_checkpointer.get_checkpoint_file())
            self.assertTrue(torch.equal(m1[0].weight, m2[0].weight))
Example #17
    def test_checkpointables(self) -> None:
        """
        Test saving and loading checkpointables.
        """
        class CheckpointableObj:
            """
            A dummy CheckpointableObj class with state_dict and load_state_dict
            methods.
            """
            def __init__(self):
                self.state = {
                    self.random_handle(): self.random_handle()
                    for i in range(10)
                }

            def random_handle(self, str_len=100) -> str:
                """
                Generate a random string of fixed length.
                Args:
                    str_len (int): length of the output string.
                Returns:
                    (str): randomly generated handle.
                """
                letters = string.ascii_uppercase
                return "".join(random.choice(letters) for i in range(str_len))

            def state_dict(self):
                """
                Return the state.
                Returns:
                    (dict): return the state.
                """
                return self.state

            def load_state_dict(self, state) -> None:
                """
                Load the state from a given state.
                Args:
                    state (dict): a key value dictionary.
                """
                self.state = copy.deepcopy(state)

        trained_model, fresh_model = self._create_model(), self._create_model()
        with TemporaryDirectory() as f:
            checkpointables = CheckpointableObj()
            checkpointer = Checkpointer(
                trained_model,
                save_dir=f,
                save_to_disk=True,
                checkpointables=checkpointables,
            )
            checkpointer.save("checkpoint_file")
            # in the same folder
            fresh_checkpointer = Checkpointer(fresh_model, save_dir=f)
            self.assertTrue(fresh_checkpointer.has_checkpoint())
            self.assertEqual(
                fresh_checkpointer.get_checkpoint_file(),
                os.path.join(f, "checkpoint_file.pth"),
            )
            checkpoint = fresh_checkpointer.load(
                fresh_checkpointer.get_checkpoint_file())
            state_dict = checkpointables.state_dict()
            for key, _ in state_dict.items():
                self.assertTrue(
                    checkpoint["checkpointables"].get(key) is not None)
                self.assertTrue(
                    checkpoint["checkpointables"][key] == state_dict[key])
Example #18
}

# search_space = SimpleCellSearchSpace()
search_space = NasBench201SeachSpace()
# search_space = HierarchicalSearchSpace()
# search_space = DartsSearchSpace()

assert search_space.QUERYABLE

optimizer = supported_optimizers[config.optimizer]

optimizer.adapt_search_space(search_space)

checkpoint_dir = '/home/moa/dev/python_projects/NASLib/naslib/benchmarks/nasbench201/run/cifar10/{}/4/search/'.format(
    config.optimizer)
checkpointables = optimizer.get_checkpointables()

checkpointer = Checkpointer(model=checkpointables.pop('model'),
                            save_dir="/tmp/",
                            **checkpointables)

for checkpoint in sorted(
        glob.glob(os.path.join(checkpoint_dir, 'model_0*.pth'))):

    checkpoint = checkpointer.resume_or_load(checkpoint, resume=False)
    epoch = checkpoint.get("iteration", -1)

    print(optimizer.test_statistics())

trainer.evaluate(resume_from=checkpoint)
Example #19
def main(cfg):
    setup(cfg)
    dataset_names = register_datasets(cfg)
    if cfg.ONLY_REGISTER_DATASETS:
        return {}, cfg
    LOG.info(f"Registered {len(dataset_names)} datasets:" + '\n\t' + '\n\t'.join(dataset_names))

    model = build_model(cfg)

    checkpoint_file = cfg.MODEL.CKPT
    if checkpoint_file:
        if cfg.MODEL.CKPT_REMAPPER:
            if cfg.EVAL_ONLY:
                LOG.warning("Running with 'EVAL_ONLY', but the checkpoint is remapped.")
            checkpoint_file = CHECKPOINT_REMAPPERS[cfg.MODEL.CKPT_REMAPPER](checkpoint_file, model)

        # BatchNorm2d submodules to convert to FrozenBatchNorm2d.
        modules_to_convert_frozenbb = [
            (name, module) for name, module in model.named_modules() if name in cfg.MODEL.CONVERT_TO_FROZEN_BN_MODULES
        ]
        if len(modules_to_convert_frozenbb) > 0:
            module_names, modules = list(zip(*modules_to_convert_frozenbb))
            LOG.info(
                f"Converting BatchNorm2d -> FrozenBatchNorm2d {len(modules)} submodule(s):" + '\n\t' +
                '\n\t'.join(module_names)
            )
            for module in modules:
                FrozenBatchNorm2d.convert_frozen_batchnorm(module)

        # Some checkpoints contain BatchNorm layers with negative values for 'running_var'.
        model = HotFixFrozenBatchNorm2d.convert_frozenbn_to_hotfix_ver(model)
        Checkpointer(model).load(checkpoint_file)

    if cfg.EVAL_ONLY:
        assert cfg.TEST.ENABLED, "'eval-only' mode is not compatible with 'cfg.TEST.ENABLED = False'."
        test_results = do_test(cfg, model, is_last=True)
        if cfg.TEST.AUG.ENABLED:
            test_results.update(do_test(cfg, model, is_last=True, use_tta=True))
        return test_results, cfg

    modules_to_freeze = cfg.MODEL.FREEZE_MODULES
    if modules_to_freeze:
        LOG.info(f"Freezing {len(modules_to_freeze)} submodule(s):" + '\n\t' + '\n\t'.join(modules_to_freeze))
        # `requires_grad=False` must be set *before* wrapping the model with `DistributedDataParallel`
        # modules_to_freeze = [x.strip() for x in cfg.MODEL.FREEZE_MODULES.split(',')]
        # for module_name in cfg.MODEL.FREEZE_MODULES:
        for module_name in modules_to_freeze:
            freeze_submodule(model, module_name)

    if comm.is_distributed():
        assert d2_comm._LOCAL_PROCESS_GROUP is not None
        # Convert all Batchnorm*D to nn.SyncBatchNorm.
        # For faster training, the batch stats are computed over only the GPUs of the same machines (usually 8).
        sync_bn_pg = d2_comm._LOCAL_PROCESS_GROUP if cfg.SOLVER.SYNCBN_USE_LOCAL_WORKERS else None
        model = SyncBatchNorm.convert_sync_batchnorm(model, process_group=sync_bn_pg)
        model = DistributedDataParallel(
            model,
            device_ids=[d2_comm.get_local_rank()],
            broadcast_buffers=False,
            find_unused_parameters=cfg.SOLVER.DDP_FIND_UNUSED_PARAMETERS
        )

    do_train(cfg, model)
    test_results = do_test(cfg, model, is_last=True)
    if cfg.TEST.AUG.ENABLED:
        test_results.update(do_test(cfg, model, is_last=True, use_tta=True))
    return test_results, cfg
Example #20
def do_train(cfg, model):
    model.train()
    optimizer = build_optimizer(cfg, model)
    scheduler = build_lr_scheduler(cfg, optimizer)

    checkpointer = Checkpointer(model, './', optimizer=optimizer, scheduler=scheduler)
    max_iter = cfg.SOLVER.MAX_ITER

    periodic_checkpointer = PeriodicCheckpointer(checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD, max_iter=max_iter)

    writers = [CommonMetricPrinter(max_iter)] if d2_comm.is_main_process() else []

    train_mapper = get_dataset_mapper(cfg, is_train=True)
    dataloader, dataset_dicts = build_train_dataloader(cfg, mapper=train_mapper)
    LOG.info("Length of train dataset: {:d}".format(len(dataset_dicts)))
    LOG.info("Starting training")
    storage = get_event_storage()

    if cfg.EVAL_ON_START:
        do_test(cfg, model)
        comm.synchronize()

    # In mixed-precision training, gradients are scaled up to keep them from vanishing (underflowing) in half precision.
    # They are scaled back down before the optimizer uses them to compute updates.
    scaler = amp.GradScaler(enabled=cfg.SOLVER.MIXED_PRECISION_ENABLED)

    # Accumulate gradients for multiple batches (as returned by dataloader) before calling optimizer.step().
    accumulate_grad_batches = cfg.SOLVER.ACCUMULATE_GRAD_BATCHES

    num_images_seen = 0
    # For logging, this stores losses aggregated from all workers in distributed training.
    batch_loss_dict = defaultdict(float)
    optimizer.zero_grad()
    for data, iteration in zip(dataloader, range(max_iter * accumulate_grad_batches)):
        iteration += 1
        # This assumes drop_last=True, so all workers have the same batch size.
        num_images_seen += len(data) * d2_comm.get_world_size()
        if iteration % accumulate_grad_batches == 0:
            storage.step()

        with amp.autocast(enabled=cfg.SOLVER.MIXED_PRECISION_ENABLED):
            loss_dict = model(data)
        # Account for accumulated gradients.
        loss_dict = {name: loss / accumulate_grad_batches for name, loss in loss_dict.items()}
        losses = sum(loss_dict.values())
        # FIXME: First few iterations might give Inf/NaN losses when using mixed precision. What should be done?
        if not torch.isfinite(losses):
            LOG.critical(f"The loss DIVERGED: {loss_dict}")

        # Track total loss for logging.
        loss_dict_reduced = {k: v.item() for k, v in d2_comm.reduce_dict(loss_dict).items()}
        assert torch.isfinite(torch.as_tensor(list(loss_dict_reduced.values()))).all(), loss_dict_reduced
        for k, v in loss_dict_reduced.items():
            batch_loss_dict[k] += v

        # No amp version: leaving this here for legacy:
        # losses.backward()
        scaler.scale(losses).backward()

        if iteration % accumulate_grad_batches > 0:
            # Just accumulate gradients and move on to next batch.
            continue

        # No amp version: leaving this here for legacy:
        # optimizer.step()
        # scheduler.step()
        # optimizer.zero_grad()

        scaler.step(optimizer)
        storage.put_scalar("lr", optimizer.param_groups[0]["lr"], smoothing_hint=False)
        scheduler.step()
        scaler.update()

        losses_reduced = sum(loss for loss in batch_loss_dict.values())
        storage.put_scalars(total_loss=losses_reduced, **batch_loss_dict)

        # Reset states.
        batch_loss_dict = defaultdict(float)
        optimizer.zero_grad()

        batch_iter = iteration // accumulate_grad_batches

        # TODO: probably check if the gradients contain any inf or nan, and only proceed if not.
        if batch_iter > 5 and (batch_iter % 20 == 0 or batch_iter == max_iter):
            # if batch_iter > -1 and (batch_iter % 1 == 0 or batch_iter == max_iter):
            for writer in writers:
                writer.write()
            # log epoch / # images seen
            if d2_comm.is_main_process() and cfg.WANDB.ENABLED:
                wandb.log({"epoch": 1 + num_images_seen // len(dataset_dicts)}, step=batch_iter)
                wandb.log({"num_images_seen": num_images_seen}, step=batch_iter)

        if cfg.VIS.DATALOADER_ENABLED and batch_iter % cfg.VIS.DATALOADER_PERIOD == 0 and d2_comm.is_main_process():
            dataset_name = cfg.DATASETS.TRAIN.NAME
            visualizer_names = MetadataCatalog.get(dataset_name).loader_visualizers
            viz_images = defaultdict(dict)
            for viz_name in visualizer_names:
                viz = get_dataloader_visualizer(cfg, viz_name, dataset_name)
                for idx, x in enumerate(data):
                    viz_images[idx].update(viz.visualize(x))

            if cfg.WANDB.ENABLED:
                # per_image_vis = [coalece_viz_images(viz_images[idx])[0] for idx in range(len(data))]
                per_image_vis = [mosaic(list(viz_images[idx].values())) for idx in range(len(data))]
                wandb.log({
                    "dataloader": [wandb.Image(vis, caption=f"idx={idx}") for idx, vis in enumerate(per_image_vis)]
                },
                          step=batch_iter)
            save_vis(viz_images, os.path.join(os.getcwd(), "visualization"), "dataloader", step=batch_iter)

        if d2_comm.is_main_process():  # TODO (dennis.park): is this necessary?
            periodic_checkpointer.step(batch_iter - 1)  # (fvcore) model_0004999.pth checkpoints 5000-th iteration

        if batch_iter > 0 and batch_iter % cfg.SYNC_OUTPUT_DIR_S3.PERIOD == 0:
            sync_output_dir_s3(cfg)

        if (cfg.TEST.EVAL_PERIOD > 0 and batch_iter % cfg.TEST.EVAL_PERIOD == 0 and batch_iter != max_iter) or \
            batch_iter in cfg.TEST.ADDITIONAL_EVAL_STEPS:
            do_test(cfg, model)
            d2_comm.synchronize()
Example #21
    def test_from_last_checkpoint_model(self):
        """
        test that loading works even if they differ by a prefix.
        """
        for trained_model, fresh_model in [
            (self._create_model(), self._create_model()),
            (nn.DataParallel(self._create_model()), self._create_model()),
            (self._create_model(), nn.DataParallel(self._create_model())),
            (
                nn.DataParallel(self._create_model()),
                nn.DataParallel(self._create_model()),
            ),
        ]:

            with TemporaryDirectory() as f:
                checkpointer = Checkpointer(trained_model, save_dir=f)
                checkpointer.save("checkpoint_file")

                # in the same folder
                fresh_checkpointer = Checkpointer(fresh_model, save_dir=f)
                self.assertTrue(fresh_checkpointer.has_checkpoint())
                self.assertEqual(
                    fresh_checkpointer.get_checkpoint_file(),
                    os.path.join(f, "checkpoint_file.pth"),
                )
                fresh_checkpointer.load(
                    fresh_checkpointer.get_checkpoint_file())

            for trained_p, loaded_p in zip(trained_model.parameters(),
                                           fresh_model.parameters()):
                # different tensor references
                self.assertFalse(id(trained_p) == id(loaded_p))
                # same content
                self.assertTrue(trained_p.equal(loaded_p))
Example #22
def main(cfg: DictConfig) -> None:

    if "experiments" in cfg.keys():
        cfg = OmegaConf.merge(cfg, cfg.experiments)

    if "debug" in cfg.keys():
        logger.info(f"Run script in debug")
        cfg = OmegaConf.merge(cfg, cfg.debug)

    # A logger for this file
    logger = logging.getLogger(__name__)

    # NOTE: hydra causes the python file to run in hydra.run.dir by default
    logger.info(f"Run script in {HydraConfig.get().run.dir}")

    writer = SummaryWriter(log_dir=cfg.train.tensorboard_dir)

    checkpoints_dir = Path(cfg.train.checkpoints_dir)
    if not checkpoints_dir.exists():
        checkpoints_dir.mkdir(parents=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    image_shape = (cfg.train.channels, cfg.train.image_height,
                   cfg.train.image_width)

    # NOTE: With hydra, the python file runs in hydra.run.dir by default, so set the dataset path to a full path or an appropriate relative path
    dataset_path = Path(cfg.dataset.root) / cfg.dataset.frames
    split_path = Path(cfg.dataset.root) / cfg.dataset.split_file
    assert dataset_path.exists(), "Video image folder not found"
    assert (split_path.exists()
            ), "The file that describes the split of train/test not found."

    # Define training set
    train_dataset = Dataset(
        dataset_path=dataset_path,
        split_path=split_path,
        split_number=cfg.dataset.split_number,
        input_shape=image_shape,
        sequence_length=cfg.train.sequence_length,
        training=True,
    )

    # Define train dataloader
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=cfg.train.batch_size,
        shuffle=True,
        num_workers=cfg.train.num_workers,
    )

    # Define test set
    test_dataset = Dataset(
        dataset_path=dataset_path,
        split_path=split_path,
        split_number=cfg.dataset.split_number,
        input_shape=image_shape,
        sequence_length=cfg.train.sequence_length,
        training=False,
    )

    # Define test dataloader
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=cfg.train.batch_size,
        shuffle=False,
        num_workers=cfg.train.num_workers,
    )

    # Classification criterion
    criterion = nn.CrossEntropyLoss().to(device)

    # Define network
    model = CNNLSTM(
        num_classes=train_dataset.num_classes,
        latent_dim=cfg.train.latent_dim,
        lstm_layers=cfg.train.lstm_layers,
        hidden_dim=cfg.train.hidden_dim,
        bidirectional=cfg.train.bidirectional,
        attention=cfg.train.attention,
    )
    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

    checkpointer = Checkpointer(
        model,
        optimizer=optimizer,
        # scheduler=scheduler,
        save_dir=cfg.train.checkpoints_dir,
        save_to_disk=True,
    )

    if cfg.train.resume:
        if not checkpointer.has_checkpoint():
            start_epoch = 0
        else:
            ckpt = checkpointer.resume_or_load("", resume=True)
            start_epoch = ckpt["epoch"]
            model.to(device)
            for state in optimizer.state.values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        state[k] = v.to(device)
    elif cfg.train.checkpoint_model != "":
        ckpt = torch.load(cfg.train.checkpoint_model, map_location="cpu")
        model.load_state_dict(ckpt["model"])
        model.to(device)
        start_epoch = 0
    else:
        start_epoch = 0

    for epoch in range(start_epoch, cfg.train.num_epochs):
        epoch += 1
        epoch_metrics = {"loss": [], "acc": []}
        timer = Timer()
        for batch_i, (X, y) in enumerate(train_dataloader):
            batch_i += 1
            if X.size(0) == 1:
                continue

            image_sequences = Variable(X.to(device), requires_grad=True)
            labels = Variable(y.to(device), requires_grad=False)

            optimizer.zero_grad()

            # Reset LSTM hidden state
            model.lstm.reset_hidden_state()

            # Get sequence predictions
            predictions = model(image_sequences)

            # Compute metrics
            loss = criterion(predictions, labels)
            acc = (
                predictions.detach().argmax(1) == labels).cpu().numpy().mean()

            loss.backward()
            optimizer.step()

            # Keep track of epoch metrics
            epoch_metrics["loss"].append(loss.item())
            epoch_metrics["acc"].append(acc)

            # Determine approximate time left
            batches_done = (epoch - 1) * len(train_dataloader) + (batch_i - 1)
            batches_left = cfg.train.num_epochs * len(
                train_dataloader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left *
                                           timer.seconds())
            time_iter = round(timer.seconds(), 3)
            timer.reset()

            logger.info(
                f'Training - [Epoch: {epoch}/{cfg.train.num_epochs}] [Batch: {batch_i}/{len(train_dataloader)}] [Loss: {np.mean(epoch_metrics["loss"]):.3f}] [Acc: {np.mean(epoch_metrics["acc"]):.3f}] [ETA: {time_left}] [Iter time: {time_iter}s/it]'
            )

            # Empty cache
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        writer.add_scalar("train/loss", np.mean(epoch_metrics["loss"]), epoch)
        writer.add_scalar("train/acc", np.mean(epoch_metrics["acc"]), epoch)

        def test_model(epoch):
            """ Evaluate the model on the test set """
            model.eval()
            test_metrics = {"loss": [], "acc": []}
            timer = Timer()
            for batch_i, (X, y) in enumerate(test_dataloader):
                batch_i += 1
                image_sequences = Variable(X.to(device), requires_grad=False)
                labels = Variable(y, requires_grad=False).to(device)

                with torch.no_grad():
                    # Reset LSTM hidden state
                    model.lstm.reset_hidden_state()
                    # Get sequence predictions
                    predictions = model(image_sequences)

                # Compute metrics
                loss = criterion(predictions, labels)
                acc = (predictions.detach().argmax(1) == labels
                       ).cpu().numpy().mean()

                # Keep track of loss and accuracy
                test_metrics["loss"].append(loss.item())
                test_metrics["acc"].append(acc)

                # Determine approximate time left
                batches_done = batch_i - 1
                batches_left = len(test_dataloader) - batches_done
                time_left = datetime.timedelta(seconds=batches_left *
                                               timer.seconds())
                time_iter = round(timer.seconds(), 3)
                timer.reset()

                # Log test performance
                logger.info(
                    f'Testing - [Epoch: {epoch}/{cfg.train.num_epochs}] [Batch: {batch_i}/{len(test_dataloader)}] [Loss: {np.mean(test_metrics["loss"]):.3f}] [Acc: {np.mean(test_metrics["acc"]):.3f}] [ETA: {time_left}] [Iter time: {time_iter}s/it]'
                )

            writer.add_scalar("test/loss", np.mean(test_metrics["loss"]),
                              epoch)
            writer.add_scalar("test/acc", np.mean(test_metrics["acc"]), epoch)

            model.train()

        # Evaluate the model on the test set
        test_model(epoch)

        # Save model checkpoint
        if epoch % cfg.train.checkpoint_interval == 0:
            checkpointer.save(f"checkpoint_{epoch:04}", epoch=epoch)

    writer.close()
Example #23
def train(cfg):
    torch.manual_seed(cfg.exp.seed)
    np.random.seed(cfg.exp.seed)

    # Init model, optimizer, loss, video stream
    class_groups = lvs_dataset.sequence_to_class_groups_stable[cfg.dataset.sequence]
    class_groups = [ [lvs_dataset.detectron_classes.index(c) for c in g] \
                     for g in class_groups]
    num_classes = len(class_groups) + 1
    log.info(f'Number of class {num_classes}')

    dataset = lvs_dataset.LVSDataset(cfg.dataset.data_dir, cfg.dataset.sequence,
                         str(cfg.dataset.sequence_id).zfill(3),
                         start_frame=cfg.dataset.start_frame,
                         max_frames=cfg.dataset.max_frames,
                         stride=cfg.online_train.training_stride)

    device = torch.device('cuda')
    model, _ = load_model(cfg.model, num_classes)
    model = model[0]
    model.to(device)
    optimizer = configure_optimizer(cfg.online_train.optimizer, model)
    scheduler = None
    if cfg.online_train.scheduler.name == 'multi_step':
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                         cfg.online_train.scheduler.milestones,
                                                         cfg.online_train.scheduler.gamma)
    elif cfg.online_train.scheduler.name == 'poly':
        scheduler = Poly(optimizer, cfg.online_train.epoch,
                         len(dataset) // cfg.online_train.batch_size)
    cls_weight = None
    if cfg.online_train.cls_weight:
        #assert len(cfg.online_train.cls_weight) == num_classes
        cls_weight = cfg.online_train.cls_weight[:num_classes]
        cls_weight = torch.tensor(cls_weight).float()
    criterion = torch.nn.CrossEntropyLoss(weight=cls_weight, reduction='none')
    criterion.to(device)

    start_epoch = 0
    checkpointer = Checkpointer(model, save_dir='./', optimizer=optimizer)
    #states = checkpointer.resume_or_load(None, resume=True)
    #if 'model' in states:
    #    model.load_state_dict(states['model'])
    #if 'optimizer' in states:
    #    optimizer.load_state_dict(states['optimizer'])
    #if 'epoch' in states:
    #    start_epoch = states['epoch'] + 1

    train_cfg = cfg.online_train
    sampler = None
    if cfg.model.perf_stats:
        sampler = get_sampler(cfg.model.perf_stats, num_classes, cfg.dataset.max_frames)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=train_cfg.batch_size,
                                             shuffle=sampler is None,
                                             num_workers=4,
                                             sampler=sampler)
    writer = SummaryWriter(log_dir='./', flush_secs=30)

    for epoch in range(start_epoch, train_cfg.epoch):
        ds_total = len(dataset) if not sampler else len(sampler)
        pbar = tqdm(total=ds_total // train_cfg.batch_size + 1)
        for batch_idx, (frames, labels, label_weights) in enumerate(dataloader):
            optimizer.zero_grad()

            frames = frames.to(device)
            labels = labels.to(device)
            label_weights = label_weights.to(device)

            logits = model(frames)
            logpt = criterion(logits, labels)
            fg_weights = torch.ones_like(label_weights) * train_cfg.fg_weight
            bg_mask = label_weights == 0
            fg_weights.masked_fill_(bg_mask, train_cfg.bg_weight)
            if train_cfg.focal_gamma > 0:
                pt = torch.exp(-logpt)
                loss = (((1. - pt) ** train_cfg.focal_gamma) * logpt * fg_weights).mean()
            else:
                loss = (logpt * fg_weights).mean()

            loss.backward()
            optimizer.step()

            with torch.no_grad():
                _, preds = torch.max(logits, dim=1)
                tp, fp, fn, cls_scores = \
                    calculate_class_iou(preds, labels, num_classes)

            step = epoch * len(dataset) + batch_idx * train_cfg.batch_size
            if batch_idx % 10 == 0:
                writer.add_scalar('train/loss', loss, step)
                writer.add_scalar('train/bg_iou', cls_scores[0], step)
                writer.add_scalar('train/fg_iou', cls_scores[1:].mean(), step)

            pbar.update(1)
            pbar.set_description(f'loss: {loss:.3f} mIoU: {cls_scores[1:].mean():.3f}')

        checkpointer.save(f'epoch_{epoch}', epoch=epoch)  # Checkpointer.save appends '.pth'

        if scheduler:
            scheduler.step()
Example #24
def main():
    global global_step

    config = load_config()

    set_seed(config)
    setup_cudnn(config)

    # np.iinfo(np_type).max: machine limit (upper bound) of this integer type.
    # Every epoch gets its own seed.
    epoch_seeds = np.random.randint(np.iinfo(np.int32).max // 2,
                                    size=config.scheduler.epochs)

    if config.train.distributed:
        dist.init_process_group(backend=config.train.dist.backend,
                                init_method=config.train.dist.init_method,
                                rank=config.train.dist.node_rank,
                                world_size=config.train.dist.world_size)
        torch.cuda.set_device(config.train.dist.local_rank)

    output_dir = pathlib.Path(config.train.output_dir)
    if get_rank() == 0:
        if not config.train.resume and output_dir.exists():
            raise RuntimeError(
                f'Output directory `{output_dir.as_posix()}` already exists')
        output_dir.mkdir(exist_ok=True, parents=True)
        if not config.train.resume:
            # When starting a new run (i.e. not resuming), save the current config,
            # environment info, and the difference between the current and default config.
            save_config(config, output_dir / 'config.yaml')
            save_config(get_env_info(config), output_dir / 'env.yaml')
            diff = find_config_diff(config)
            if diff is not None:
                save_config(diff, output_dir / 'config_min.yaml')

    logger = create_logger(name=__name__,
                           distributed_rank=get_rank(),
                           output_dir=output_dir,
                           filename='log.txt')
    logger.info(config)
    logger.info(get_env_info(config))

    train_loader, val_loader = create_dataloader(config, is_train=True)

    model = create_model(config)
    # Multiply-ACcumulate (MAC) operation count
    macs, n_params = count_op(config, model)
    logger.info(f'MACs   : {macs}')
    logger.info(f'#params: {n_params}')
    # Create the optimizer: SGD with Nesterov momentum, Adam, AMSGrad, AdaBound, AdaBoundW, or LARS.
    optimizer = create_optimizer(config, model)
    # AMP (automatic mixed precision) setup via apex
    if config.device != 'cpu':
        model, optimizer = apex.amp.initialize(
            model, optimizer, opt_level=config.train.precision)
    # Wrap the model in DataParallel or DistributedDataParallel
    model = apply_data_parallel_wrapper(config, model)

    # Set up the scheduler (with warmup).
    # steps_per_epoch: the number of batches in an epoch.
    scheduler = create_scheduler(config,
                                 optimizer,
                                 steps_per_epoch=len(train_loader))
    # Create the checkpointer; don't use torch's default checkpoint saving because it can't save the scheduler.
    checkpointer = Checkpointer(model,
                                optimizer=optimizer,
                                scheduler=scheduler,
                                save_dir=output_dir,
                                save_to_disk=get_rank() == 0)

    start_epoch = config.train.start_epoch
    # last_epoch is used when resuming training; normally we start from config.train.start_epoch.
    scheduler.last_epoch = start_epoch
    # Resuming supports two modes:
    # 1. resume = True: load the last training checkpoint and restore the global step and config.
    # 2. resume = False with a checkpoint specified: load that checkpoint (to CPU) for the model weights only.
    if config.train.resume:
        checkpoint_config = checkpointer.resume_or_load('', resume=True)
        global_step = checkpoint_config['global_step']
        start_epoch = checkpoint_config['epoch']
        config.defrost()
        config.merge_from_other_cfg(ConfigNode(checkpoint_config['config']))
        config.freeze()
    elif config.train.checkpoint != '':
        checkpoint = torch.load(config.train.checkpoint, map_location='cpu')
        if isinstance(model,
                      (nn.DataParallel, nn.parallel.DistributedDataParallel)):
            model.module.load_state_dict(checkpoint['model'])
        else:
            model.load_state_dict(checkpoint['model'])
    # Two TensorBoard writers:
    # the first is for this run of training (which may be a resumed run),
    # the second follows the global step and records the run as a whole.
    if get_rank() == 0 and config.train.use_tensorboard:
        tensorboard_writer = create_tensorboard_writer(
            config, output_dir, purge_step=config.train.start_epoch + 1)
        tensorboard_writer2 = create_tensorboard_writer(
            config, output_dir / 'running', purge_step=global_step + 1)
    else:
        tensorboard_writer = DummyWriter()
        tensorboard_writer2 = DummyWriter()

    train_loss, val_loss = create_loss(config)

    if (config.train.val_period > 0 and start_epoch == 0
            and config.train.val_first):
        # validate the model from epoch 0
        validate(0, config, model, val_loss, val_loader, logger,
                 tensorboard_writer)

    for epoch, seed in enumerate(epoch_seeds[start_epoch:], start_epoch):
        epoch += 1

        np.random.seed(seed)
        train(epoch, config, model, optimizer, scheduler, train_loss,
              train_loader, logger, tensorboard_writer, tensorboard_writer2)

        if config.train.val_period > 0 and (epoch % config.train.val_period
                                            == 0):
            validate(epoch, config, model, val_loss, val_loader, logger,
                     tensorboard_writer)

        tensorboard_writer.flush()
        tensorboard_writer2.flush()

        if (epoch % config.train.checkpoint_period
                == 0) or (epoch == config.scheduler.epochs):
            checkpoint_config = {
                'epoch': epoch,
                'global_step': global_step,
                'config': config.as_dict(),
            }
            checkpointer.save(f'checkpoint_{epoch:05d}', **checkpoint_config)

    tensorboard_writer.close()
    tensorboard_writer2.close()
Example #25
    def test_loading_objects_with_expected_shape_mismatches(self) -> None:
        def _get_model() -> torch.nn.Module:
            m = nn.Sequential(nn.Conv2d(2, 2, 1))
            m.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
            m = torch.quantization.prepare_qat(m)
            return m

        m1, m2 = _get_model(), _get_model()
        # Calibrate m1 with data to populate the observer stats
        m1(torch.randn(4, 2, 4, 4))
        # Load m1's checkpoint into m2. This should work without errors even
        # though the shapes of per-channel observer buffers do not match.
        with TemporaryDirectory() as f:
            checkpointer = Checkpointer(m1, save_dir=f)
            checkpointer.save("checkpoint_file")

            # in the same folder
            fresh_checkpointer = Checkpointer(m2, save_dir=f)
            self.assertTrue(fresh_checkpointer.has_checkpoint())
            self.assertEqual(
                fresh_checkpointer.get_checkpoint_file(),
                os.path.join(f, "checkpoint_file.pth"),
            )
            fresh_checkpointer.load(fresh_checkpointer.get_checkpoint_file())
            # Run the expected input through the network with observers
            # disabled and fake_quant enabled. If buffers were loaded correctly
            # into per-channel observers, this line will not crash.
            m2.apply(torch.quantization.disable_observer)
            m2.apply(torch.quantization.enable_fake_quant)
            m2(torch.randn(4, 2, 4, 4))
Example #26
def main():
    global global_step

    config = load_config()

    set_seed(config)
    setup_cudnn(config)

    epoch_seeds = np.random.randint(np.iinfo(np.int32).max // 2,
                                    size=config.scheduler.epochs)

    if config.train.distributed:
        dist.init_process_group(backend=config.train.dist.backend,
                                init_method=config.train.dist.init_method,
                                rank=config.train.dist.node_rank,
                                world_size=config.train.dist.world_size)
        torch.cuda.set_device(config.train.dist.local_rank)

    output_dir = pathlib.Path(config.train.output_dir)
    if get_rank() == 0:
        if not config.train.resume and output_dir.exists():
            raise RuntimeError(
                f'Output directory `{output_dir.as_posix()}` already exists')
        output_dir.mkdir(exist_ok=True, parents=True)
        if not config.train.resume:
            save_config(config, output_dir / 'config.yaml')
            save_config(get_env_info(config), output_dir / 'env.yaml')
            diff = find_config_diff(config)
            if diff is not None:
                save_config(diff, output_dir / 'config_min.yaml')

    logger = create_logger(name=__name__,
                           distributed_rank=get_rank(),
                           output_dir=output_dir,
                           filename='log.txt')
    logger.info(config)
    logger.info(get_env_info(config))

    train_loader, val_loader = create_dataloader(config, is_train=True)

    model = create_model(config)
    macs, n_params = count_op(config, model)
    logger.info(f'MACs  : {macs}')
    logger.info(f'#params: {n_params}')

    optimizer = create_optimizer(config, model)
    model, optimizer = apex.amp.initialize(model,
                                           optimizer,
                                           opt_level=config.train.precision)
    model = apply_data_parallel_wrapper(config, model)

    scheduler = create_scheduler(config,
                                 optimizer,
                                 steps_per_epoch=len(train_loader))
    checkpointer = Checkpointer(model,
                                optimizer=optimizer,
                                scheduler=scheduler,
                                save_dir=output_dir,
                                save_to_disk=get_rank() == 0)

    start_epoch = config.train.start_epoch
    scheduler.last_epoch = start_epoch
    if config.train.resume:
        checkpoint_config = checkpointer.resume_or_load('', resume=True)
        global_step = checkpoint_config['global_step']
        start_epoch = checkpoint_config['epoch']
        config.defrost()
        config.merge_from_other_cfg(ConfigNode(checkpoint_config['config']))
        config.freeze()
    elif config.train.checkpoint != '':
        checkpoint = torch.load(config.train.checkpoint, map_location='cpu')
        if isinstance(model,
                      (nn.DataParallel, nn.parallel.DistributedDataParallel)):
            model.module.load_state_dict(checkpoint['model'])
        else:
            model.load_state_dict(checkpoint['model'])

    if get_rank() == 0 and config.train.use_tensorboard:
        tensorboard_writer = create_tensorboard_writer(
            config, output_dir, purge_step=config.train.start_epoch + 1)
        tensorboard_writer2 = create_tensorboard_writer(
            config, output_dir / 'running', purge_step=global_step + 1)
    else:
        tensorboard_writer = DummyWriter()
        tensorboard_writer2 = DummyWriter()

    train_loss, val_loss = create_loss(config)

    if (config.train.val_period > 0 and start_epoch == 0
            and config.train.val_first):
        validate(0, config, model, val_loss, val_loader, logger,
                 tensorboard_writer)

    for epoch, seed in enumerate(epoch_seeds[start_epoch:], start_epoch):
        epoch += 1

        np.random.seed(seed)
        train(epoch, config, model, optimizer, scheduler, train_loss,
              train_loader, logger, tensorboard_writer, tensorboard_writer2)

        if config.train.val_period > 0 and (epoch %
                                            config.train.val_period == 0):
            validate(epoch, config, model, val_loss, val_loader, logger,
                     tensorboard_writer)

        tensorboard_writer.flush()
        tensorboard_writer2.flush()

        if (epoch % config.train.checkpoint_period == 0) or (
                epoch == config.scheduler.epochs):
            checkpoint_config = {
                'epoch': epoch,
                'global_step': global_step,
                'config': config.as_dict(),
            }
            checkpointer.save(f'checkpoint_{epoch:05d}', **checkpoint_config)

    tensorboard_writer.close()
    tensorboard_writer2.close()