Ejemplo n.º 1
0
    def test_from_last_checkpoint_model(self) -> None:
        """
        Check that a model saved by one ``Checkpointer`` can be restored by a
        second one looking at the same directory, including when only one of
        the two models is wrapped in ``nn.DataParallel`` (i.e. they differ by
        a prefix).
        """
        model_pairs = [
            (self._create_model(), self._create_model()),
            (nn.DataParallel(self._create_model()), self._create_model()),
            (self._create_model(), nn.DataParallel(self._create_model())),
            (
                nn.DataParallel(self._create_model()),
                nn.DataParallel(self._create_model()),
            ),
        ]
        for source_model, target_model in model_pairs:
            with TemporaryDirectory() as tmp_dir:
                saver = Checkpointer(source_model, save_dir=tmp_dir)
                saver.save("checkpoint_file")

                # A second checkpointer on the same folder must discover the
                # file that was just written.
                loader = Checkpointer(target_model, save_dir=tmp_dir)
                self.assertTrue(loader.has_checkpoint())
                expected_file = os.path.join(tmp_dir, "checkpoint_file.pth")
                self.assertEqual(loader.get_checkpoint_file(), expected_file)
                loader.load(loader.get_checkpoint_file())

            paired_params = zip(source_model.parameters(),
                                target_model.parameters())
            for source_param, target_param in paired_params:
                # The two models must not share tensor objects ...
                self.assertFalse(id(source_param) == id(target_param))
                # ... yet must hold the same values after loading.
                self.assertTrue(source_param.cpu().equal(target_param.cpu()))
Ejemplo n.º 2
0
    def test_loading_objects_with_expected_shape_mismatches(self) -> None:
        """
        Check that a checkpoint from a calibrated QAT-prepared model loads
        into an uncalibrated one without errors, even though the shapes of
        the per-channel observer buffers do not match.
        """
        def _build_qat_model() -> torch.nn.Module:
            net = nn.Sequential(nn.Conv2d(2, 2, 1))
            net.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
            return torch.quantization.prepare_qat(net)

        calibrated = _build_qat_model()
        uncalibrated = _build_qat_model()
        # Feed data through the first model only, so that its observer stats
        # get populated.
        calibrated(torch.randn(4, 2, 4, 4))

        with TemporaryDirectory() as tmp_dir:
            Checkpointer(calibrated, save_dir=tmp_dir).save("checkpoint_file")

            # Restore from the same folder into the uncalibrated model.
            loader = Checkpointer(uncalibrated, save_dir=tmp_dir)
            self.assertTrue(loader.has_checkpoint())
            expected_file = os.path.join(tmp_dir, "checkpoint_file.pth")
            self.assertEqual(loader.get_checkpoint_file(), expected_file)
            loader.load(loader.get_checkpoint_file())

            # With observers disabled and fake_quant enabled, the forward
            # pass below will not crash if the buffers were loaded correctly
            # into the per-channel observers.
            for toggle in (
                torch.quantization.disable_observer,
                torch.quantization.enable_fake_quant,
            ):
                uncalibrated.apply(toggle)
            uncalibrated(torch.randn(4, 2, 4, 4))
Ejemplo n.º 3
0
    def test_from_name_file_model(self) -> None:
        """
        Test that a checkpoint can be loaded through an explicit file path
        when the loading ``Checkpointer`` uses a different (empty) save
        directory, including when only one of the two models is wrapped in
        ``nn.DataParallel`` (i.e. they differ by a prefix).
        """
        for trained_model, fresh_model in [
            (self._create_model(), self._create_model()),
            (nn.DataParallel(self._create_model()), self._create_model()),
            (self._create_model(), nn.DataParallel(self._create_model())),
            (
                nn.DataParallel(self._create_model()),
                nn.DataParallel(self._create_model()),
            ),
        ]:
            with TemporaryDirectory() as f:
                checkpointer = Checkpointer(trained_model,
                                            save_dir=f,
                                            save_to_disk=True)
                checkpointer.save("checkpoint_file")

                # on different folders: the fresh checkpointer has nothing to
                # discover in its own directory, so the path is given
                # explicitly.
                with TemporaryDirectory() as g:
                    fresh_checkpointer = Checkpointer(fresh_model, save_dir=g)
                    self.assertFalse(fresh_checkpointer.has_checkpoint())
                    self.assertEqual(fresh_checkpointer.get_checkpoint_file(),
                                     "")
                    fresh_checkpointer.load(
                        os.path.join(f, "checkpoint_file.pth"))

            for trained_p, loaded_p in zip(trained_model.parameters(),
                                           fresh_model.parameters()):
                # different tensor references.
                self.assertFalse(id(trained_p) == id(loaded_p))
                # same content. Compare on CPU: a DataParallel-wrapped model
                # may live on a different device than the plain one, and
                # Tensor.equal refuses tensors on different devices.
                self.assertTrue(trained_p.cpu().equal(loaded_p.cpu()))
Ejemplo n.º 4
0
    def test_load_lazy_module(self) -> None:
        """
        Check that a checkpoint saved from an initialized ``nn.LazyLinear``
        model can be loaded into a model that has not been initialized yet.
        """
        def _make_lazy_net() -> nn.Sequential:  # pyre-fixme[11]
            return nn.Sequential(nn.LazyLinear(10))

        initialized_net = _make_lazy_net()
        lazy_net = _make_lazy_net()
        # Only the first network sees data, so only it gets initialized.
        initialized_net(torch.randn(4, 2, 4, 4))

        with TemporaryDirectory() as tmp_dir:
            saver = Checkpointer(initialized_net, save_dir=tmp_dir)
            saver.save("checkpoint_file")

            # Restore the saved state into the still-uninitialized network.
            loader = Checkpointer(lazy_net, save_dir=tmp_dir)
            self.assertTrue(loader.has_checkpoint())
            expected_file = os.path.join(tmp_dir, "checkpoint_file.pth")
            self.assertEqual(loader.get_checkpoint_file(), expected_file)
            loader.load(loader.get_checkpoint_file())

            weights_match = torch.equal(initialized_net[0].weight,
                                        lazy_net[0].weight)
            self.assertTrue(weights_match)
Ejemplo n.º 5
0
    def test_checkpointables(self) -> None:
        """
        Test saving and loading checkpointables.

        An extra object exposing ``state_dict`` / ``load_state_dict`` is
        handed to the saving ``Checkpointer``; its state must then be found
        under the ``"checkpointables"`` key of what a fresh ``Checkpointer``
        returns from ``load``.
        """
        class CheckpointableObj:
            """
            A dummy checkpointableObj class with state_dict and load_state_dict
            methods.
            """
            def __init__(self) -> None:
                # Ten random key/value pairs stand in for real state.
                self.state = {
                    self.random_handle(): self.random_handle()
                    for _ in range(10)
                }

            def random_handle(self, str_len: int = 100) -> str:
                """
                Generate a random string of fixed length.
                Args:
                    str_len (int): length of the output string.
                Returns:
                    (str): random generated handle.
                """
                letters = string.ascii_uppercase
                return "".join(random.choice(letters) for _ in range(str_len))

            def state_dict(self):
                """
                Return the state.
                Returns:
                    (dict): return the state.
                """
                return self.state

            def load_state_dict(self, state) -> None:
                """
                Load the state from a given state.
                Args:
                    state (dict): a key value dictionary.
                """
                # Deep-copy so later changes to the caller's dict do not leak
                # into this object.
                self.state = copy.deepcopy(state)

        trained_model, fresh_model = self._create_model(), self._create_model()
        with TemporaryDirectory() as f:
            checkpointables = CheckpointableObj()
            checkpointer = Checkpointer(
                trained_model,
                save_dir=f,
                save_to_disk=True,
                checkpointables=checkpointables,
            )
            checkpointer.save("checkpoint_file")
            # in the same folder
            fresh_checkpointer = Checkpointer(fresh_model, save_dir=f)
            self.assertTrue(fresh_checkpointer.has_checkpoint())
            self.assertEqual(
                fresh_checkpointer.get_checkpoint_file(),
                os.path.join(f, "checkpoint_file.pth"),
            )
            checkpoint = fresh_checkpointer.load(
                fresh_checkpointer.get_checkpoint_file())
            # The fresh checkpointer was built without `checkpointables`, so
            # the saved state is inspected in the dict returned by `load`.
            loaded_state = checkpoint["checkpointables"]
            for key, expected in checkpointables.state_dict().items():
                self.assertIsNotNone(loaded_state.get(key))
                self.assertEqual(loaded_state[key], expected)
Ejemplo n.º 6
0
def main(cfg: DictConfig) -> None:
    """
    Train the CNN-LSTM classifier described by ``cfg`` and evaluate it on the
    test split after every epoch.

    Steps, in order: merge the optional ``experiments`` / ``debug`` config
    groups into ``cfg``; build the train/test datasets and data loaders, the
    model, an Adam optimizer and a ``Checkpointer``; optionally restore state
    (``cfg.train.resume`` or ``cfg.train.checkpoint_model``); run the epoch
    loop, logging metrics through ``logger`` and a TensorBoard
    ``SummaryWriter`` and saving a checkpoint every
    ``cfg.train.checkpoint_interval`` epochs.

    Args:
        cfg (DictConfig): run configuration; the ``train`` and ``dataset``
            groups are read here.
    """
    # A logger for this file. It has to be bound before its first use: the
    # name is local to this function, so logging from the "debug" branch
    # below before this assignment would raise UnboundLocalError.
    logger = logging.getLogger(__name__)

    if "experiments" in cfg.keys():
        cfg = OmegaConf.merge(cfg, cfg.experiments)

    if "debug" in cfg.keys():
        logger.info("Run script in debug")
        cfg = OmegaConf.merge(cfg, cfg.debug)

    # NOTE: hydra causes the python file to run in hydra.run.dir by default
    logger.info(f"Run script in {HydraConfig.get().run.dir}")

    writer = SummaryWriter(log_dir=cfg.train.tensorboard_dir)

    # exist_ok avoids the race between checking for the directory and
    # creating it.
    checkpoints_dir = Path(cfg.train.checkpoints_dir)
    checkpoints_dir.mkdir(parents=True, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    image_shape = (cfg.train.channels, cfg.train.image_height,
                   cfg.train.image_width)

    # NOTE: With hydra, the python file runs in hydra.run.dir by default, so set the dataset path to a full path or an appropriate relative path
    dataset_path = Path(cfg.dataset.root) / cfg.dataset.frames
    split_path = Path(cfg.dataset.root) / cfg.dataset.split_file
    assert dataset_path.exists(), "Video image folder not found"
    assert (split_path.exists()
            ), "The file that describes the split of train/test not found."

    # Define training set
    train_dataset = Dataset(
        dataset_path=dataset_path,
        split_path=split_path,
        split_number=cfg.dataset.split_number,
        input_shape=image_shape,
        sequence_length=cfg.train.sequence_length,
        training=True,
    )

    # Define train dataloader
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=cfg.train.batch_size,
        shuffle=True,
        num_workers=cfg.train.num_workers,
    )

    # Define test set
    test_dataset = Dataset(
        dataset_path=dataset_path,
        split_path=split_path,
        split_number=cfg.dataset.split_number,
        input_shape=image_shape,
        sequence_length=cfg.train.sequence_length,
        training=False,
    )

    # Define test dataloader
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=cfg.train.batch_size,
        shuffle=False,
        num_workers=cfg.train.num_workers,
    )

    # Classification criterion
    criterion = nn.CrossEntropyLoss().to(device)

    # Define network
    model = CNNLSTM(
        num_classes=train_dataset.num_classes,
        latent_dim=cfg.train.latent_dim,
        lstm_layers=cfg.train.lstm_layers,
        hidden_dim=cfg.train.hidden_dim,
        bidirectional=cfg.train.bidirectional,
        attention=cfg.train.attention,
    )
    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

    checkpointer = Checkpointer(
        model,
        optimizer=optimizer,
        # scheduler=scheduler,
        save_dir=cfg.train.checkpoints_dir,
        save_to_disk=True,
    )

    # Training starts from scratch unless a resumable checkpoint is found.
    start_epoch = 0
    if cfg.train.resume:
        if checkpointer.has_checkpoint():
            ckpt = checkpointer.resume_or_load("", resume=True)
            # "epoch" is the value passed to checkpointer.save() below, i.e.
            # the number of epochs already completed.
            start_epoch = ckpt["epoch"]
            model.to(device)
            # Move any tensors held in the optimizer state to the training
            # device.
            for state in optimizer.state.values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        state[k] = v.to(device)
    elif cfg.train.checkpoint_model != "":
        # NOTE(review): torch.load unpickles the file; only point
        # cfg.train.checkpoint_model at checkpoints from a trusted source.
        ckpt = torch.load(cfg.train.checkpoint_model, map_location="cpu")
        model.load_state_dict(ckpt["model"])
        model.to(device)

    def test_model(epoch):
        """ Evaluate the model on the test set """
        model.eval()
        test_metrics = {"loss": [], "acc": []}
        timer = Timer()
        for batch_i, (X, y) in enumerate(test_dataloader, start=1):
            image_sequences = Variable(X.to(device), requires_grad=False)
            labels = Variable(y, requires_grad=False).to(device)

            with torch.no_grad():
                # Reset LSTM hidden state
                model.lstm.reset_hidden_state()
                # Get sequence predictions
                predictions = model(image_sequences)

            # Compute metrics
            loss = criterion(predictions, labels)
            acc = (predictions.detach().argmax(1) == labels
                   ).cpu().numpy().mean()

            # Keep track of loss and accuracy
            test_metrics["loss"].append(loss.item())
            test_metrics["acc"].append(acc)

            # Determine approximate time left
            batches_done = batch_i - 1
            batches_left = len(test_dataloader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left *
                                           timer.seconds())
            time_iter = round(timer.seconds(), 3)
            timer.reset()

            # Log test performance
            logger.info(
                f'Testing - [Epoch: {epoch}/{cfg.train.num_epochs}] [Batch: {batch_i}/{len(test_dataloader)}] [Loss: {np.mean(test_metrics["loss"]):.3f}] [Acc: {np.mean(test_metrics["acc"]):.3f}] [ETA: {time_left}] [Iter time: {time_iter}s/it]'
            )

        writer.add_scalar("test/loss", np.mean(test_metrics["loss"]),
                          epoch)
        writer.add_scalar("test/acc", np.mean(test_metrics["acc"]), epoch)

        # Switch back to training mode for the next epoch.
        model.train()

    try:
        # Epochs are numbered from 1 in logs, scalars and checkpoint names.
        for epoch in range(start_epoch + 1, cfg.train.num_epochs + 1):
            epoch_metrics = {"loss": [], "acc": []}
            timer = Timer()
            for batch_i, (X, y) in enumerate(train_dataloader, start=1):
                # Skip batches that hold a single sample.
                # NOTE(review): presumably because some layer cannot train
                # on a batch of one - confirm against the model definition.
                if X.size(0) == 1:
                    continue

                image_sequences = Variable(X.to(device), requires_grad=True)
                labels = Variable(y.to(device), requires_grad=False)

                optimizer.zero_grad()

                # Reset LSTM hidden state
                model.lstm.reset_hidden_state()

                # Get sequence predictions
                predictions = model(image_sequences)

                # Compute metrics
                loss = criterion(predictions, labels)
                acc = (
                    predictions.detach().argmax(1) == labels).cpu().numpy().mean()

                loss.backward()
                optimizer.step()

                # Keep track of epoch metrics
                epoch_metrics["loss"].append(loss.item())
                epoch_metrics["acc"].append(acc)

                # Determine approximate time left
                batches_done = (epoch - 1) * len(train_dataloader) + (batch_i - 1)
                batches_left = cfg.train.num_epochs * len(
                    train_dataloader) - batches_done
                time_left = datetime.timedelta(seconds=batches_left *
                                               timer.seconds())
                time_iter = round(timer.seconds(), 3)
                timer.reset()

                logger.info(
                    f'Training - [Epoch: {epoch}/{cfg.train.num_epochs}] [Batch: {batch_i}/{len(train_dataloader)}] [Loss: {np.mean(epoch_metrics["loss"]):.3f}] [Acc: {np.mean(epoch_metrics["acc"]):.3f}] [ETA: {time_left}] [Iter time: {time_iter}s/it]'
                )

                # Empty cache
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

            writer.add_scalar("train/loss", np.mean(epoch_metrics["loss"]), epoch)
            writer.add_scalar("train/acc", np.mean(epoch_metrics["acc"]), epoch)

            # Evaluate the model on the test set
            test_model(epoch)

            # Save model checkpoint
            if epoch % cfg.train.checkpoint_interval == 0:
                checkpointer.save(f"checkpoint_{epoch:04}", epoch=epoch)
    finally:
        # Close the writer even when training is interrupted or fails.
        writer.close()