예제 #1
0
    def test_maxpool(self, _model_maxpool2D_1_last, _config_last):
        """Test maxpooling.

        Saves ``_model_maxpool2D_1_last`` as ``<path_wd>/<filename_ann>.h5``,
        disables ANN evaluation and normalization, and runs the pipeline.

        Requirements:

        * The first pipeline accuracy is at least 90 % of ``get_ann_acc``.
        * All correlations except the last exceed 0.99.
        * The last correlation exceeds 0.90.

        The correlations come from ``get_correlations``. They are presumably
        computed from the logged ``activations_n_b_l`` and
        ``spiketrains_n_b_l_t`` variables; TODO confirm.
        """
        cfg = _config_last
        h5_path = os.path.join(cfg.get('paths', 'path_wd'),
                               cfg.get('paths', 'filename_ann') + '.h5')
        models.save_model(_model_maxpool2D_1_last, h5_path)

        # Section/key order mirrors the original literal; read_dict applies
        # sections in iteration order.
        tools_opts = {'evaluate_ann': False, 'normalize': False}
        sim_opts = {'duration': 100, 'num_to_test': 100, 'batch_size': 50}
        output_opts = {
            'log_vars': {'activations_n_b_l', 'spiketrains_n_b_l_t'}}
        cfg.read_dict({'tools': tools_opts,
                       'simulation': sim_opts,
                       'output': output_opts})

        pipeline_acc = run_pipeline(cfg)

        reference_acc = get_ann_acc(cfg)
        assert pipeline_acc[0] >= 0.9 * reference_acc

        correlations = get_correlations(cfg)
        # Every entry but the last must be near-perfect; the last one is
        # held to a looser bound.
        assert np.all(correlations[:-1] > 0.99)
        assert correlations[-1] > 0.90
예제 #2
0
    def test_maxpool_fallback(self, _model_maxpool2D_1, _config):
        """Test that maxpooling falls back on average pooling.

        Saves ``_model_maxpool2D_1`` as ``<path_wd>/<filename_ann>.h5`` and
        enables ``conversion.max2avg_pool``. It then runs the pipeline and
        requires three things:

        * The first returned accuracy is at least 0.8.
        * All correlations except the last exceed 0.99.
        * The last correlation exceeds 0.90.

        NOTE(review): this mutates the ``_config`` fixture in place. If that
        fixture is shared beyond function scope, ``max2avg_pool`` may leak
        into later tests. Confirm the fixture scope.
        """
        workdir = _config.get('paths', 'path_wd')
        filename = _config.get('paths', 'filename_ann') + '.h5'
        models.save_model(_model_maxpool2D_1, os.path.join(workdir, filename))

        # Keyword dict() preserves insertion order, so sections and keys are
        # applied in the same order as a literal mapping would be.
        _config.read_dict(dict(
            tools=dict(evaluate_ann=False, normalize=False),
            conversion=dict(max2avg_pool=True),
            simulation=dict(duration=100, num_to_test=100, batch_size=50),
            output=dict(log_vars={'activations_n_b_l', 'spiketrains_n_b_l_t'}),
        ))

        accuracies = run_pipeline(_config)
        assert accuracies[0] >= 0.8

        corr = get_correlations(_config)
        assert np.all(corr[:-1] > 0.99)
        assert corr[-1] > 0.90
예제 #3
0
    def test_pipeline(self, _model_4_first):
        """Run the pipeline end to end with ``model_lib`` set to 'pytorch'.

        Hands ``_model_4_first`` to ``self.prepare_model``, updates
        ``self.config`` and calls ``initialize_simulator`` before running the
        pipeline. It then requires three things:

        * The first returned accuracy is at least 0.8.
        * All correlations except the last exceed 0.97.
        * The last correlation exceeds 0.5.
        """
        self.prepare_model(_model_4_first)

        # Bind after prepare_model so we see whatever config it leaves on self.
        settings = self.config
        settings.read_dict({
            'tools': {'evaluate_ann': False},
            'input': {'model_lib': 'pytorch'},
            'simulation': {'duration': 100,
                           'num_to_test': 100,
                           'batch_size': 50},
            'output': {'log_vars': {'activations_n_b_l',
                                    'spiketrains_n_b_l_t'}},
        })

        initialize_simulator(settings)

        scores = run_pipeline(settings)
        assert scores[0] >= 0.8

        corr = get_correlations(settings)
        assert np.all(corr[:-1] > 0.97)
        assert corr[-1] > 0.5