Example No. 1
    def test_wrong_metric(self):
        # An unsupported metric name should be rejected with a ValueError.
        with pytest.raises(ValueError):
            compute_metrics.evaluate(metric='wrong_metric',
                                     log_dir=self.log_dir,
                                     netG=self.netG,
                                     dataset_name=self.dataset_name,
                                     evaluate_step=self.evaluate_step,
                                     device=self.device)
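This test only asserts that an unsupported metric name is rejected. A minimal sketch of the validation it implies (a hypothetical reconstruction, not the library's actual code):

    def evaluate(metric, **kwargs):
        # Hypothetical validation sketch inferred from the test above.
        if metric not in ('fid', 'kid', 'inception_score'):
            raise ValueError(
                "Invalid metric '{}'. Choose one of 'fid', 'kid' or "
                "'inception_score'.".format(metric))
        ...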
Example No. 2
    def test_arguments(self):
        # Each metric needs extra, metric-specific arguments; calling
        # evaluate() without them should raise a ValueError.
        for metric in ['fid', 'kid', 'inception_score']:
            with pytest.raises(ValueError):
                compute_metrics.evaluate(metric=metric,
                                         log_dir=self.log_dir,
                                         netG=self.netG,
                                         dataset_name=self.dataset_name,
                                         evaluate_step=self.evaluate_step,
                                         device=self.device)
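Here every supported metric is called without its metric-specific arguments, so each call should fail. A hedged sketch of the per-metric requirement check this implies; the required-argument sets are inferred from Examples No. 3 to No. 5 and are assumptions about the library, not its actual internals:

    # Required extra arguments per metric, inferred from the passing tests
    # below; these sets are assumptions, not the library's documented API.
    REQUIRED_KWARGS = {
        'inception_score': ('num_samples',),
        'kid': ('num_subsets', 'subset_size', 'dataset_name'),
        'fid': ('dataset', 'num_real_samples', 'num_fake_samples'),
    }

    def check_required_kwargs(metric, kwargs):
        missing = [k for k in REQUIRED_KWARGS[metric] if k not in kwargs]
        if missing:
            raise ValueError(
                '{} requires arguments: {}'.format(metric, missing))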
Example No. 3
    def test_evaluate_is(self):
        # Inception Score test: passes num_samples and num_runs in
        # addition to the common arguments.
        kwargs = {
            'metric': 'inception_score',
            'log_dir': self.log_dir,
            'netG': self.netG,
            'num_samples': 10,
            'num_runs': 3,
            'evaluate_step': self.evaluate_step,
            'device': self.device,
            'start_seed': self.start_seed,
        }

        scores = compute_metrics.evaluate(**kwargs)[self.evaluate_step]
        assert isinstance(scores, list)
        assert all(isinstance(x, float) for x in scores)
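Indexing the result by self.evaluate_step suggests evaluate() returns a mapping from checkpoint step to a list of per-run float scores. A usage sketch under that assumption, with illustrative paths and values:

    # Illustrative usage; log_dir, netG and the step value are assumptions.
    scores_by_step = compute_metrics.evaluate(
        metric='inception_score',
        log_dir='./log/example',
        netG=netG,                  # a generator with a saved checkpoint
        num_samples=10,
        num_runs=3,
        evaluate_step=100,
        device='cpu')
    print(scores_by_step[100])      # a list with one float per run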
Example No. 4
    def test_evaluate_kid(self):
        # KID test: passes num_subsets, subset_size and dataset_name in
        # addition to the common arguments.
        kwargs = {
            'metric': 'kid',
            'log_dir': self.log_dir,
            'evaluate_step': self.evaluate_step,
            'num_subsets': 10,
            'subset_size': 10,
            'netG': self.netG,
            'device': self.device,
            'start_seed': self.start_seed,
            'dataset_name': self.dataset_name
        }

        scores = compute_metrics.evaluate(**kwargs)[self.evaluate_step]
        assert isinstance(scores, list)
        assert all(isinstance(x, float) for x in scores)
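Since the test only guarantees a list of floats, aggregating the per-run KID values into a mean and standard deviation needs nothing beyond the standard library. A small follow-up sketch, reusing the kwargs dict from the test above:

    import statistics

    scores = compute_metrics.evaluate(**kwargs)[kwargs['evaluate_step']]
    # Summarize the per-run KID values returned for this checkpoint step.
    print('KID: {:.5f} +/- {:.5f}'.format(
        statistics.mean(scores), statistics.stdev(scores)))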
Example No. 5
    def test_evaluate_fid(self):
        # FID test: passes dataset, num_real_samples and num_fake_samples
        # in addition to the common arguments.
        kwargs = {
            'metric': 'fid',
            'log_dir': self.log_dir,
            'netG': self.netG,
            'dataset': self.dataset,
            'num_real_samples': 10,
            'num_fake_samples': 10,
            'evaluate_step': self.evaluate_step,
            'start_seed': self.start_seed,
            'device': self.device
        }

        scores = compute_metrics.evaluate(**kwargs)[self.evaluate_step]
        assert isinstance(scores, list)
        assert all(isinstance(x, float) for x in scores)
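Examples No. 3 to No. 5 end with the same two assertions; if this suite were refactored, a shared helper could capture them. A sketch, not part of the original tests:

    def assert_float_list(scores):
        # Shared assertion for the three metric tests above.
        assert isinstance(scores, list)
        assert all(isinstance(x, float) for x in scores)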
Example No. 6
    def test_arguments(self):
        # Missing metric-specific arguments should raise a ValueError for
        # every supported metric.
        for metric in ['fid', 'kid', 'inception_score']:
            with pytest.raises(ValueError):
                compute_metrics.evaluate(metric=metric,
                                         log_dir=self.log_dir,
                                         netG=self.netG,
                                         dataset=self.dataset,
                                         evaluate_step=self.evaluate_step,
                                         device=self.device)

        # Both evaluate_step and evaluate_range defined; note that `metric`
        # below still holds the last value from the loop above.
        with pytest.raises(ValueError):
            compute_metrics.evaluate(metric=metric,
                                     log_dir=self.log_dir,
                                     netG=self.netG,
                                     dataset=self.dataset,
                                     evaluate_range=(1000, 100000, 1000),
                                     evaluate_step=self.evaluate_step,
                                     device=self.device)

        # Faulty evaluate_range values: (1000) is just an int, not a
        # tuple; then a tuple of non-ints, a tuple of the wrong length,
        # and None follow.
        with pytest.raises(ValueError):
            compute_metrics.evaluate(metric=metric,
                                     log_dir=self.log_dir,
                                     netG=self.netG,
                                     dataset=self.dataset,
                                     evaluate_range=(1000),
                                     device=self.device)

        with pytest.raises(ValueError):
            compute_metrics.evaluate(metric=metric,
                                     log_dir=self.log_dir,
                                     netG=self.netG,
                                     dataset=self.dataset,
                                     evaluate_range=('a', 'b', 'c'),
                                     device=self.device)

        with pytest.raises(ValueError):
            compute_metrics.evaluate(metric=metric,
                                     log_dir=self.log_dir,
                                     netG=self.netG,
                                     dataset=self.dataset,
                                     evaluate_range=(100, 100, 100, 100),
                                     device=self.device)

        with pytest.raises(ValueError):
            compute_metrics.evaluate(metric=metric,
                                     log_dir=self.log_dir,
                                     netG=self.netG,
                                     dataset=self.dataset,
                                     evaluate_range=None,
                                     device=self.device)

        # Invalid checkpoint directory.
        with pytest.raises(ValueError):
            compute_metrics.evaluate(metric=metric,
                                     log_dir='does_not_exist',
                                     netG=self.netG,
                                     dataset=self.dataset,
                                     evaluate_step=self.evaluate_step,
                                     device=self.device)
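Taken together, the range-related cases pin down how evaluate() must validate its step arguments: exactly one of evaluate_step and evaluate_range may be given, and evaluate_range must be a 3-tuple of ints; the final case additionally requires log_dir to be an existing checkpoint directory. A hypothetical validation sketch consistent with the range-related failures above (not the library's actual implementation):

    def check_evaluate_args(evaluate_step=None, evaluate_range=None):
        # Hypothetical sketch inferred from the test cases above.
        if evaluate_step is not None and evaluate_range is not None:
            raise ValueError(
                'Only one of evaluate_step or evaluate_range may be set.')
        if evaluate_step is None:
            if (not isinstance(evaluate_range, tuple)
                    or len(evaluate_range) != 3
                    or not all(isinstance(x, int) for x in evaluate_range)):
                raise ValueError(
                    'evaluate_range must be a 3-tuple of ints, '
                    '(start, end, step) as for range().')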