Example #1
    def test_create_restore_delete(self):
        # Create the hyperparameters and objects to save.
        hp = models.registry.get_default_hparams('cifar_resnet_20')
        model = models.registry.get(hp.model_hparams)
        optimizer = optimizers.get_optimizer(hp.training_hparams, model)
        dataloader = datasets.registry.get(hp.dataset_hparams)
        step = Step.from_epoch(13, 27, 400)

        # Run one step of SGD.
        examples, labels = next(iter(dataloader))
        optimizer.zero_grad()
        model.train()
        model.loss_criterion(model(examples), labels).backward()
        optimizer.step()

        # Create a fake logger.
        logger = MetricLogger()
        logger.add('test_accuracy', Step.from_epoch(0, 0, 400), 0.1)
        logger.add('test_accuracy', Step.from_epoch(10, 0, 400), 0.5)
        logger.add('test_accuracy', Step.from_epoch(100, 0, 400), 0.8)

        # Save a checkpoint.
        checkpointing.save_checkpoint_callback(self.root, step, model, optimizer, logger)
        self.assertTrue(os.path.exists(paths.checkpoint(self.root)))

        # Create a fresh model and optimizer to restore into.
        model2 = models.registry.get(hp.model_hparams)
        optimizer2 = optimizers.get_optimizer(hp.training_hparams, model2)

        # Ensure the new model has different weights.
        sd1, sd2 = model.state_dict(), model2.state_dict()
        for k in model.prunable_layer_names:
            self.assertFalse(np.array_equal(sd1[k].numpy(), sd2[k].numpy()))

        self.assertIn('momentum_buffer', optimizer.state[optimizer.param_groups[0]['params'][0]])
        self.assertNotIn('momentum_buffer', optimizer2.state[optimizer2.param_groups[0]['params'][0]])

        # Restore the checkpoint.
        step2, logger2 = checkpointing.restore_checkpoint(self.root, model2, optimizer2, 400)

        self.assertTrue(os.path.exists(paths.checkpoint(self.root)))
        self.assertEqual(step, step2)
        self.assertEqual(str(logger), str(logger2))

        # Ensure the new model is now the same.
        sd1, sd2 = model.state_dict(), model2.state_dict()
        self.assertEqual(set(sd1.keys()), set(sd2.keys()))
        for k in sd1:
            self.assertTrue(np.array_equal(sd1[k].numpy(), sd2[k].numpy()))

        # Ensure the new optimizer is now the same.
        mom1 = optimizer.state[optimizer.param_groups[0]['params'][0]]['momentum_buffer']
        mom2 = optimizer2.state[optimizer2.param_groups[0]['params'][0]]['momentum_buffer']
        self.assertTrue(np.array_equal(mom1.numpy(), mom2.numpy()))
Example #2
def restore_checkpoint(output_location, model, optimizer,
                       iterations_per_epoch):
    checkpoint_location = paths.checkpoint(output_location)
    if not get_platform().exists(checkpoint_location):
        return None, None
    checkpoint = get_platform().load_model(checkpoint_location,
                                           map_location=torch.device('cpu'))

    # Handle DataParallel: add or strip the 'module.' prefix so the checkpoint's
    # keys match the current model, whether or not it is wrapped in DataParallel.
    module_in_name = get_platform().is_parallel
    if module_in_name and not all(
            k.startswith('module.') for k in checkpoint['model_state_dict']):
        checkpoint['model_state_dict'] = {
            'module.' + k: v
            for k, v in checkpoint['model_state_dict'].items()
        }
    elif all(k.startswith('module.')
             for k in checkpoint['model_state_dict']) and not module_in_name:
        checkpoint['model_state_dict'] = {
            k[len('module.'):]: v
            for k, v in checkpoint['model_state_dict'].items()
        }

    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    step = Step.from_epoch(checkpoint['ep'], checkpoint['it'],
                           iterations_per_epoch)
    logger = MetricLogger.create_from_string(checkpoint['logger'])

    return step, logger
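
Paired with the save callback in Example #5, restore_checkpoint makes a run resumable. A minimal resume sketch, assuming a model, optimizer, and train_loader already built as in Example #1; only the fallback branch is new:

    # Try to pick up from an existing checkpoint; fall back to a fresh start.
    step, logger = restore_checkpoint(output_location, model, optimizer,
                                      len(train_loader))
    if step is None:  # no checkpoint on disk yet
        step = Step.zero(len(train_loader))
        logger = MetricLogger()
    # Training then continues from `step` rather than from iteration zero.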
Example #3
    def test_last_step(self):
        train.train(self.hparams.training_hparams,
                    self.model,
                    self.train_loader,
                    self.root,
                    callbacks=self.callbacks,
                    start_step=Step.from_epoch(2, 11, len(self.train_loader)),
                    end_step=Step.from_epoch(3, 0, len(self.train_loader)))

        end_state = TestStandardCallbacks.get_state(self.model)

        # Check that the final state has been saved.
        end_loc = paths.model(self.root,
                              Step.from_epoch(3, 0, len(self.train_loader)))
        self.assertTrue(os.path.exists(end_loc))

        # Check that the saved final state matches the final state of the network.
        self.model.load_state_dict(torch.load(end_loc))
        saved_state = TestStandardCallbacks.get_state(self.model)
        self.assertStateEqual(end_state, saved_state)

        # Check that the logger has the right number of entries.
        self.assertTrue(os.path.exists(paths.logger(self.root)))
        logger = MetricLogger.create_from_file(self.root)
        self.assertEqual(len(logger.get_data('train_loss')), 1)
        self.assertEqual(len(logger.get_data('test_loss')), 1)
        self.assertEqual(len(logger.get_data('train_accuracy')), 1)
        self.assertEqual(len(logger.get_data('test_accuracy')), 1)

        # Check that the checkpoint file exists.
        self.assertTrue(os.path.exists(paths.checkpoint(self.root)))
Example #4
    def test_first_step(self):
        init_state = TestStandardCallbacks.get_state(self.model)

        train.train(self.hparams.training_hparams,
                    self.model,
                    self.train_loader,
                    self.root,
                    callbacks=self.callbacks,
                    end_step=Step.from_epoch(0, 1, len(self.train_loader)))

        # Check that the initial state has been saved.
        model_state_loc = paths.model(self.root,
                                      Step.zero(len(self.train_loader)))
        self.assertTrue(os.path.exists(model_state_loc))

        # Check that the model state at init reflects the saved state.
        self.model.load_state_dict(torch.load(model_state_loc))
        saved_state = TestStandardCallbacks.get_state(self.model)
        self.assertStateEqual(init_state, saved_state)

        # Check that the checkpoint file exists.
        self.assertTrue(os.path.exists(paths.checkpoint(self.root)))

        # Check that the logger file doesn't exist.
        self.assertFalse(os.path.exists(paths.logger(self.root)))
Example #5
def save_checkpoint_callback(output_location, step, model, optimizer, logger):
    if get_platform().is_primary_process:
        get_platform().save_model({
            'ep': step.ep,
            'it': step.it,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'logger': str(logger),
        }, paths.checkpoint(output_location))
    get_platform().barrier()
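
The is_primary_process guard and the barrier make this safe for distributed training: only one process writes the file, and every process waits for it before continuing. As a hedged sketch, here is the minimal save/restore round trip that Example #1 tests in full, assuming the hp, model, optimizer, and logger objects from that example's setup:

    # Save at an arbitrary step, then restore into a freshly built model/optimizer.
    step = Step.from_epoch(1, 0, 400)
    save_checkpoint_callback(output_location, step, model, optimizer, logger)

    model2 = models.registry.get(hp.model_hparams)
    optimizer2 = optimizers.get_optimizer(hp.training_hparams, model2)
    step2, logger2 = restore_checkpoint(output_location, model2, optimizer2, 400)
    assert step2 == step  # Step equality compares (ep, it); see Example #1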
Example #6
    def test_end_to_end(self):
        init_loc = paths.model(self.root, Step.zero(len(self.train_loader)))
        end_loc = paths.model(self.root,
                              Step.from_epoch(3, 0, len(self.train_loader)))

        init_state = TestStandardCallbacks.get_state(self.model)

        train.train(self.hparams.training_hparams,
                    self.model,
                    self.train_loader,
                    self.root,
                    callbacks=self.callbacks,
                    start_step=Step.from_epoch(0, 0, len(self.train_loader)),
                    end_step=Step.from_epoch(3, 0, len(self.train_loader)))

        end_state = TestStandardCallbacks.get_state(self.model)

        # Check that the initial and final states have been saved.
        self.assertTrue(os.path.exists(init_loc))
        self.assertTrue(os.path.exists(end_loc))

        # Check that the checkpoint file still exists.
        self.assertTrue(os.path.exists(paths.checkpoint(self.root)))

        # Check that the initial and final states match those that were saved.
        self.model.load_state_dict(torch.load(init_loc))
        saved_state = TestStandardCallbacks.get_state(self.model)
        self.assertStateEqual(init_state, saved_state)

        self.model.load_state_dict(torch.load(end_loc))
        saved_state = TestStandardCallbacks.get_state(self.model)
        self.assertStateEqual(end_state, saved_state)

        # Check that the logger has the right number of entries.
        self.assertTrue(os.path.exists(paths.logger(self.root)))
        logger = MetricLogger.create_from_file(self.root)
        self.assertEqual(len(logger.get_data('train_loss')), 4)
        self.assertEqual(len(logger.get_data('test_loss')), 4)
        self.assertEqual(len(logger.get_data('train_accuracy')), 4)
        self.assertEqual(len(logger.get_data('test_accuracy')), 4)