Example #1
    def _establish_initial_weights(self):
        location = self.desc.run_path(self.replicate, 0)
        if models.registry.exists(location, self.desc.train_start_step):
            if get_platform().is_primary_process:
                print('Initial weights loaded from {}'.format(
                    paths.model(location, self.desc.train_start_step)))
            return

        new_model = models.registry.get(self.desc.model_hparams,
                                        outputs=self.desc.train_outputs)

        # If there was a pretrained model, retrieve its final weights and adapt them for training.
        if self.desc.pretrain_training_hparams is not None:
            pretrain_loc = self.desc.run_path(self.replicate, 'pretrain')
            if get_platform().is_primary_process:
                print('Initial weights loaded from pretrained checkpoint {}'.
                      format(
                          paths.model(pretrain_loc,
                                      self.desc.pretrain_end_step)))
            old = models.registry.load(pretrain_loc,
                                       self.desc.pretrain_end_step,
                                       self.desc.model_hparams,
                                       self.desc.pretrain_outputs)
            state_dict = {k: v for k, v in old.state_dict().items()}

            # Select a new output layer if the number of classes differs.
            if self.desc.train_outputs != self.desc.pretrain_outputs:
                state_dict.update({
                    k: new_model.state_dict()[k]
                    for k in new_model.output_layer_names
                })

            new_model.load_state_dict(state_dict)

        new_model.save(location, self.desc.train_start_step)
Example #2
    def test_save(self):
        state1 = self.get_state(self.model)

        mask = Mask.ones_like(self.model)
        pruned_model = PrunedModel(self.model, mask)
        pruned_model.save(self.root, Step.zero(20))

        self.assertTrue(os.path.exists(paths.model(self.root, Step.zero(20))))

        self.model.load_state_dict(torch.load(paths.model(self.root, Step.zero(20))))
        self.assertStateEqual(state1, self.get_state(self.model))
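
The tests here build checkpoint positions with Step.zero, Step.from_epoch, and Step.from_iteration, but the Step class itself does not appear in these examples. Below is a minimal sketch of such a value object, assuming it only needs to track a global iteration count together with the number of iterations per epoch; only the three constructors are taken from the calls in the examples, while the field names and the ep/it properties are assumptions (included so that a path helper can format filenames by epoch and iteration).

class Step:
    """Hypothetical sketch of a training position measured in iterations."""

    def __init__(self, iteration, iterations_per_epoch):
        self._iteration = iteration
        self._iterations_per_epoch = iterations_per_epoch

    @classmethod
    def zero(cls, iterations_per_epoch):
        # Step.zero(20): the very first step of a run with 20 iterations per epoch.
        return cls(0, iterations_per_epoch)

    @classmethod
    def from_iteration(cls, iteration, iterations_per_epoch):
        return cls(iteration, iterations_per_epoch)

    @classmethod
    def from_epoch(cls, epoch, iteration, iterations_per_epoch):
        return cls(epoch * iterations_per_epoch + iteration, iterations_per_epoch)

    @property
    def ep(self):
        return self._iteration // self._iterations_per_epoch

    @property
    def it(self):
        return self._iteration % self._iterations_per_epoch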
Example #3
    def test_level3_4it_pretrain2it(self):
        self.desc.pretrain_dataset_hparams = copy.deepcopy(
            self.desc.dataset_hparams)
        self.desc.pretrain_training_hparams = copy.deepcopy(
            self.desc.training_hparams)
        self.desc.pretrain_training_hparams.training_steps = '2it'
        self.desc.training_hparams.training_steps = '4it'
        LotteryRunner(replicate=2, levels=3, desc=self.desc,
                      verbose=False).run()

        # Check that the pretrain weights are present.
        pretrain_root = self.desc.run_path(2, 'pretrain')
        self.assertLevelFilesPresent(pretrain_root,
                                     self.to_step('0it'),
                                     self.to_step('2it'),
                                     masks=False)

        # Load the pretrain and level0 start weights to ensure they're the same.
        pretrain_end_weights = paths.model(self.desc.run_path(2, 'pretrain'),
                                           self.desc.pretrain_end_step)
        pretrain_end_weights = {
            k: v.numpy()
            for k, v in torch.load(pretrain_end_weights).items()
        }

        level0_weights = paths.model(self.desc.run_path(2, 0),
                                     self.desc.train_start_step)
        level0_weights = {
            k: v.numpy()
            for k, v in torch.load(level0_weights).items()
        }

        self.assertStateEqual(pretrain_end_weights, level0_weights)

        # Evaluate each of the pruning levels.
        for level in range(0, 4):
            level_root = self.desc.run_path(2, level)
            self.assertLevelFilesPresent(level_root, self.to_step('2it'),
                                         self.to_step('4it'))

            # Ensure that the initial weights are a masked version of the level 0 weights
            # (which are identical to the weights at the end of pretraining).
            mask = Mask.load(level_root).numpy()
            level_weights = paths.model(level_root, self.desc.train_start_step)
            level_weights = {
                k: v.numpy()
                for k, v in torch.load(level_weights).items()
            }
            self.assertStateEqual(
                level_weights,
                {k: v * mask.get(k, 1)
                 for k, v in level0_weights.items()})
Example #4
    def assertLevelFilesPresent(self,
                                level_root,
                                start_step,
                                end_step,
                                masks=True):
        with self.subTest(level_root=level_root):
            self.assertTrue(
                os.path.exists(paths.model(level_root, start_step)))
            self.assertTrue(os.path.exists(paths.model(level_root, end_step)))
            self.assertTrue(os.path.exists(paths.logger(level_root)))
            if masks:
                self.assertTrue(os.path.exists(paths.mask(level_root)))
                self.assertTrue(
                    os.path.exists(paths.sparsity_report(level_root)))
Example #5
    def test_last_step(self):
        train.train(self.hparams.training_hparams,
                    self.model,
                    self.train_loader,
                    self.root,
                    callbacks=self.callbacks,
                    start_step=Step.from_epoch(2, 11, len(self.train_loader)),
                    end_step=Step.from_epoch(3, 0, len(self.train_loader)))

        end_state = TestStandardCallbacks.get_state(self.model)

        # Check that the final state has been saved.
        end_loc = paths.model(self.root,
                              Step.from_epoch(3, 0, len(self.train_loader)))
        self.assertTrue(os.path.exists(end_loc))

        # Check that the saved final state matches the final state of the network.
        self.model.load_state_dict(torch.load(end_loc))
        saved_state = TestStandardCallbacks.get_state(self.model)
        self.assertStateEqual(end_state, saved_state)

        # Check that the logger has the right number of entries.
        self.assertTrue(os.path.exists(paths.logger(self.root)))
        logger = MetricLogger.create_from_file(self.root)
        self.assertEqual(len(logger.get_data('train_loss')), 1)
        self.assertEqual(len(logger.get_data('test_loss')), 1)
        self.assertEqual(len(logger.get_data('train_accuracy')), 1)
        self.assertEqual(len(logger.get_data('test_accuracy')), 1)

        # Check that the checkpoint file exists.
        self.assertTrue(os.path.exists(paths.checkpoint(self.root)))
Example #6
    def test_first_step(self):
        init_state = TestStandardCallbacks.get_state(self.model)

        train.train(self.hparams.training_hparams,
                    self.model,
                    self.train_loader,
                    self.root,
                    callbacks=self.callbacks,
                    end_step=Step.from_epoch(0, 1, len(self.train_loader)))

        # Check that the initial state has been saved.
        model_state_loc = paths.model(self.root,
                                      Step.zero(len(self.train_loader)))
        self.assertTrue(os.path.exists(model_state_loc))

        # Check that the model state at init reflects the saved state.
        self.model.load_state_dict(torch.load(model_state_loc))
        saved_state = TestStandardCallbacks.get_state(self.model)
        self.assertStateEqual(init_state, saved_state)

        # Check that the checkpoint file exists.
        self.assertTrue(os.path.exists(paths.checkpoint(self.root)))

        # Check that the logger file doesn't exist.
        self.assertFalse(os.path.exists(paths.logger(self.root)))
Example #7
    def save(self, save_location: str, save_step: Step, suffix: str = ""):
        if not get_platform().is_primary_process:
            return
        if not get_platform().exists(save_location):
            get_platform().makedirs(save_location)
        get_platform().save_model(
            self.state_dict(),
            paths.model(save_location, save_step, suffix=suffix))
Example #8
    def test_level3_2it(self):
        self.desc.training_hparams.training_steps = '2it'
        LotteryRunner(replicate=2, levels=3, desc=self.desc,
                      verbose=False).run()

        level0_weights = paths.model(self.desc.run_path(2, 0),
                                     self.to_step('0it'))
        level0_weights = {
            k: v.numpy()
            for k, v in torch.load(level0_weights).items()
        }

        for level in range(0, 4):
            level_root = self.desc.run_path(2, level)
            self.assertLevelFilesPresent(level_root, self.to_step('0it'),
                                         self.to_step('2it'))

            # Expected unpruned fraction after `level` rounds of 20% pruning.
            pct = 0.8**level
            mask = Mask.load(level_root).numpy()

            # Check the mask itself.
            total, total_present = 0.0, 0.0
            for v in mask.values():
                total += v.size
                total_present += np.sum(v)
            self.assertTrue(np.allclose(pct, total_present / total, atol=0.01))

            # Check the sparsity report.
            with open(paths.sparsity_report(level_root)) as fp:
                sparsity_report = json.loads(fp.read())
            self.assertTrue(
                np.allclose(pct,
                            sparsity_report['unpruned'] /
                            sparsity_report['total'],
                            atol=0.01))

            # Ensure that the initial weights are a masked version of the level 0 weights.
            level_weights = paths.model(level_root, self.to_step('0it'))
            level_weights = {
                k: v.numpy()
                for k, v in torch.load(level_weights).items()
            }
            self.assertStateEqual(
                level_weights,
                {k: v * mask.get(k, 1)
                 for k, v in level0_weights.items()})
Example #9
def load(save_location: str,
         save_step: Step,
         model_hparams: ModelHparams,
         outputs=None):
    state_dict = get_platform().load_model(
        paths.model(save_location, save_step))
    model = get(model_hparams, outputs)
    model.load_state_dict(state_dict)
    return model
Example #10
    def test_end_to_end(self):
        init_loc = paths.model(self.root, Step.zero(len(self.train_loader)))
        end_loc = paths.model(self.root,
                              Step.from_epoch(3, 0, len(self.train_loader)))

        init_state = TestStandardCallbacks.get_state(self.model)

        train.train(self.hparams.training_hparams,
                    self.model,
                    self.train_loader,
                    self.root,
                    callbacks=self.callbacks,
                    start_step=Step.from_epoch(0, 0, len(self.train_loader)),
                    end_step=Step.from_epoch(3, 0, len(self.train_loader)))

        end_state = TestStandardCallbacks.get_state(self.model)

        # Check that the final state has been saved.
        self.assertTrue(os.path.exists(init_loc))
        self.assertTrue(os.path.exists(end_loc))

        # Check that the checkpoint file still exists.
        self.assertTrue(os.path.exists(paths.checkpoint(self.root)))

        # Check that the initial and final states match those that were saved.
        self.model.load_state_dict(torch.load(init_loc))
        saved_state = TestStandardCallbacks.get_state(self.model)
        self.assertStateEqual(init_state, saved_state)

        self.model.load_state_dict(torch.load(end_loc))
        saved_state = TestStandardCallbacks.get_state(self.model)
        self.assertStateEqual(end_state, saved_state)

        # Check that the logger has the right number of entries.
        self.assertTrue(os.path.exists(paths.logger(self.root)))
        logger = MetricLogger.create_from_file(self.root)
        self.assertEqual(len(logger.get_data('train_loss')), 4)
        self.assertEqual(len(logger.get_data('test_loss')), 4)
        self.assertEqual(len(logger.get_data('train_accuracy')), 4)
        self.assertEqual(len(logger.get_data('test_accuracy')), 4)
Example #11
    def run(self):
        if get_platform().exists(
                paths.model(self.desc.run_path(self.replicate),
                            self.desc.end_step)):
            return
        if self.verbose and get_platform().is_primary_process:
            print('=' * 82 +
                  f'\nTraining a Model (Replicate {self.replicate})\n' +
                  '-' * 82)
            print(self.desc.display)
            print(f'Output Location: {self.desc.run_path(self.replicate)}' +
                  '\n' + '=' * 82 + '\n')
        self.desc.save(self.desc.run_path(self.replicate))
        train.standard_train(models.registry.get(self.desc.model_hparams),
                             self.desc.run_path(self.replicate),
                             self.desc.dataset_hparams,
                             self.desc.training_hparams,
                             evaluate_every_epoch=self.evaluate_every_epoch,
                             verbose=self.verbose,
                             weight_save_steps=self.weight_save_steps)
Example #12
    def test_save_load_exists(self):
        hp = registry.get_default_hparams('cifar_resnet_20')
        model = registry.get(hp.model_hparams)
        step = Step.from_iteration(27, 17)
        model_location = paths.model(self.root, step)
        model_state = TestSaveLoadExists.get_state(model)

        self.assertFalse(registry.exists(self.root, step))
        self.assertFalse(os.path.exists(model_location))

        # Test saving.
        model.save(self.root, step)
        self.assertTrue(registry.exists(self.root, step))
        self.assertTrue(os.path.exists(model_location))

        # Test loading.
        model = registry.get(hp.model_hparams)
        model.load_state_dict(torch.load(model_location))
        self.assertStateEqual(model_state, TestSaveLoadExists.get_state(model))

        model = registry.load(self.root, step, hp.model_hparams)
        self.assertStateEqual(model_state, TestSaveLoadExists.get_state(model))
Example #13
def exists(save_location, save_step, suffix=""):
    return get_platform().exists(paths.model(save_location, save_step, suffix))
Example #14
def exists(save_location, save_step):
    return get_platform().exists(paths.model(save_location, save_step))
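
Every example resolves checkpoint and log filenames through the paths module (paths.model, paths.logger, paths.mask, paths.sparsity_report, paths.checkpoint), which is not shown above. Below is a minimal sketch of what such helpers could look like, assuming one fixed filename per run directory for each artifact and a model filename keyed by the step's epoch and iteration counters; the exact filenames are assumptions, not taken from the examples.

import os


def model(root, step, suffix=''):
    # Assumed naming scheme: <root>/model_ep<epoch>_it<iteration><suffix>.pth,
    # using the ep/it counters sketched on the Step class earlier.
    return os.path.join(
        root, 'model_ep{}_it{}{}.pth'.format(step.ep, step.it, suffix))


def logger(root):
    # Assumed fixed per-run filenames for the metric log, pruning mask,
    # sparsity report, and training checkpoint referenced by the tests above.
    return os.path.join(root, 'logger')


def mask(root):
    return os.path.join(root, 'mask.pth')


def sparsity_report(root):
    return os.path.join(root, 'sparsity_report.json')


def checkpoint(root):
    return os.path.join(root, 'checkpoint.pth')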