Example #1
    def test_get_dataset_from_png(self, _config):
        try:
            import matplotlib.pyplot as plt
        except ImportError:
            # matplotlib is an optional dependency; pass trivially without it.
            return

        datapath = _config.get('paths', 'dataset_path')
        classpath = os.path.join(datapath, 'class_0')
        os.mkdir(classpath)
        data = np.random.random_sample((10, 10, 3))
        plt.imsave(os.path.join(classpath, 'image_0.png'), data)

        _config.read_dict({
            'input': {
                'dataset_format': 'png',
                'dataflow_kwargs': "{'target_size': (11, 12)}",
                'datagen_kwargs': "{'rescale': 0.003922,"
                                  " 'featurewise_center': True,"
                                  " 'featurewise_std_normalization': True}"
            }
        })

        normset, testset = get_dataset(_config)
        assert all([normset, testset])
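These tests rely on a shared `_config` fixture. Below is a minimal sketch of what such a fixture could provide, assuming a plain `configparser.ConfigParser` with the sections the tests read ('paths', 'input', 'simulation'); the real snntoolbox fixture is more elaborate, so every name and value here is illustrative:

import configparser
import tempfile

def make_test_config():
    # Hypothetical stand-in for the `_config` fixture used above.
    config = configparser.ConfigParser()
    config.read_dict({
        'paths': {'dataset_path': tempfile.mkdtemp()},
        'input': {'model_lib': 'keras', 'dataset_format': 'npz'},
        'simulation': {'batch_size': '1', 'num_to_test': '1'},
    })
    return config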
Example #2
    def test_normalizing(self, _model_2, _config):

        # Parsing removes BatchNorm layers, so we make a copy of the model.
        input_model = keras.models.clone_model(_model_2)
        input_model.set_weights(_model_2.get_weights())
        input_model.compile(_model_2.optimizer.__class__.__name__,
                            _model_2.loss, _model_2.metrics)

        num_to_test = 10000
        batch_size = 100
        _config.set('simulation', 'batch_size', str(batch_size))
        _config.set('simulation', 'num_to_test', str(num_to_test))

        normset, testset = get_dataset(_config)
        x_test = testset['x_test']
        y_test = testset['y_test']
        x_norm = normset['x_norm']

        model_lib = import_module('snntoolbox.parsing.model_libs.' +
                                  _config.get('input', 'model_lib') +
                                  '_input_lib')
        model_parser = model_lib.ModelParser(input_model, _config)
        model_parser.parse()
        parsed_model = model_parser.build_parsed_model()

        normalize_parameters(parsed_model, _config, x_norm=x_norm)

        _, acc, _ = model_parser.evaluate(batch_size, num_to_test, x_test,
                                          y_test)
        _, target_acc = _model_2.evaluate(x_test, y_test, batch_size)
        assert acc == target_acc
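The `import_module` call resolves the parser backend dynamically from the config. For `model_lib = 'keras'` it is equivalent to the static import below (`keras_input_lib` is referenced directly in Example #8, so the module name is not an assumption):

from snntoolbox.parsing.model_libs import keras_input_lib as model_lib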
Example #3
    def test_parsing(self, _model_2, _config):

        # Parsing removes BatchNorm layers, so we make a copy of the model.
        input_model = models.clone_model(_model_2)
        input_model.set_weights(_model_2.get_weights())
        input_model.compile(_model_2.optimizer.__class__.__name__,
                            _model_2.loss, _model_2.metrics)

        num_to_test = 10000
        batch_size = 100
        _config.set('simulation', 'batch_size', str(batch_size))
        _config.set('simulation', 'num_to_test', str(num_to_test))

        _, testset = get_dataset(_config)
        dataflow = testset['dataflow']

        model_lib = import_module('snntoolbox.parsing.model_libs.' +
                                  _config.get('input', 'model_lib') +
                                  '_input_lib')
        model_parser = model_lib.ModelParser(input_model, _config)
        model_parser.parse()
        model_parser.build_parsed_model()
        _, acc, _ = model_parser.evaluate(batch_size,
                                          num_to_test,
                                          dataflow=dataflow)
        _, target_acc = _model_2.evaluate(dataflow,
                                          steps=int(num_to_test / batch_size))
        assert acc == target_acc
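Evaluation here goes through `testset['dataflow']` rather than arrays. A sketch of the kind of generator this plausibly is, assuming a Keras `ImageDataGenerator` over the class-subfolder layout created in Example #1 (the path, rescale factor and sizes are illustrative):

from keras.preprocessing.image import ImageDataGenerator

datapath = 'path/to/dataset'  # one subfolder per class, e.g. class_0/
datagen = ImageDataGenerator(rescale=1 / 255.)
# Yields (images, labels) batches from image files under datapath/class_*/.
dataflow = datagen.flow_from_directory(datapath, target_size=(11, 12),
                                       batch_size=100)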
Example #4
    def test_get_dataset_from_npz(self, _datapath, _config):
        data = np.random.random_sample((1, 1, 1, 1))
        np.savez_compressed(str(_datapath.join('x_norm')), data)
        np.savez_compressed(str(_datapath.join('x_test')), data)
        np.savez_compressed(str(_datapath.join('y_test')), data)
        _config.set('paths', 'dataset_path', str(_datapath))
        normset, testset = get_dataset(_config)
        assert all([normset, testset])
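The archives are evidently expected under the fixed names `x_norm`, `x_test` and `y_test`; `get_dataset` then returns them in two dictionaries keyed accordingly, which is how Example #2 reads them:

normset, testset = get_dataset(_config)
x_norm = normset['x_norm']                       # samples for normalization
x_test, y_test = testset['x_test'], testset['y_test']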
Example #5
    def test_loading(self, _model_4, _config):

        import keras
        assert keras.backend.image_data_format() == 'channels_first', \
            "Pytorch to Keras parser needs image_data_format == channel_first."

        self.prepare_model(_model_4, _config)

        updates = {
            'tools': {
                'evaluate_ann': True,
                'parse': False,
                'normalize': False,
                'convert': False,
                'simulate': False
            },
            'input': {
                'model_lib': 'pytorch'
            },
            'simulation': {
                'num_to_test': 100,
                'batch_size': 50
            }
        }

        _config.read_dict(updates)

        initialize_simulator(_config)

        normset, testset = get_dataset(_config)

        model_lib = import_module('snntoolbox.parsing.model_libs.' +
                                  _config.get('input', 'model_lib') +
                                  '_input_lib')
        input_model = model_lib.load(_config.get('paths', 'path_wd'),
                                     _config.get('paths', 'filename_ann'))

        # Evaluate input model.
        acc = model_lib.evaluate(input_model['val_fn'],
                                 _config.getint('simulation', 'batch_size'),
                                 _config.getint('simulation', 'num_to_test'),
                                 **testset)

        assert acc >= 0.8
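The assertion at the top of this test requires the `channels_first` image data format. A sketch of one way to satisfy it programmatically before running (Keras also reads this setting from `~/.keras/keras.json`):

import keras

keras.backend.set_image_data_format('channels_first')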
Example #6
    def test_get_dataset_from_npz(self, _config):
        normset, testset = get_dataset(_config)
        assert all([normset, testset])
Example #7
def _dataset(_config):
    from snntoolbox.datasets.utils import get_dataset
    return get_dataset(_config)
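The leading underscore and the `_config` argument suggest this is a pytest fixture. If so, registering it would look roughly like the sketch below; the decorator is an assumption, not shown in the source:

import pytest

@pytest.fixture
def _dataset(_config):
    # Hypothetical fixture registration wrapping the function above.
    from snntoolbox.datasets.utils import get_dataset
    return get_dataset(_config)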
Example #8
def run_pipeline(config, queue=None):
    """Convert an analog network to a spiking network and simulate it.

    Complete pipeline of
        1. loading and testing a pretrained ANN,
        2. normalizing parameters,
        3. converting it to SNN,
        4. running it on a simulator,
        5. given a specified hyperparameter range ``params``,
           repeating simulations with modified parameters.

    Parameters
    ----------

    config: configparser.ConfigParser
        ConfigParser containing the user settings.

    queue: Optional[Queue.Queue]
        Results are added to the queue to be displayed in the GUI.

    Returns
    -------

    results: list
        List of the accuracies obtained after simulating with each parameter
        value in config.get('parameter_sweep', 'param_values').
    """

    from snntoolbox.datasets.utils import get_dataset
    from snntoolbox.conversion.utils import normalize_parameters

    num_to_test = config.getint('simulation', 'num_to_test')

    # Instantiate an empty spiking network
    target_sim = import_target_sim(config)
    spiking_model = target_sim.SNN(config, queue)

    # ___________________________ LOAD DATASET ______________________________ #

    normset, testset = get_dataset(config)

    parsed_model = None
    if config.getboolean('tools', 'parse') and not is_stop(queue):

        # __________________________ LOAD MODEL _____________________________ #

        model_lib = import_module('snntoolbox.parsing.model_libs.' +
                                  config.get('input', 'model_lib') +
                                  '_input_lib')
        input_model = model_lib.load(config.get('paths', 'path_wd'),
                                     config.get('paths', 'filename_ann'))

        # Evaluate input model.
        if config.getboolean('tools', 'evaluate_ann') and not is_stop(queue):
            print(
                "Evaluating input model on {} samples...".format(num_to_test))
            model_lib.evaluate(input_model['val_fn'],
                               config.getint('simulation', 'batch_size'),
                               num_to_test, **testset)

        # ____________________________ PARSE ________________________________ #

        print("Parsing input model...")
        model_parser = model_lib.ModelParser(input_model['model'], config)
        model_parser.parse()
        parsed_model = model_parser.build_parsed_model()

        # ___________________________ NORMALIZE _____________________________ #

        if config.getboolean('tools', 'normalize') and not is_stop(queue):
            normalize_parameters(parsed_model, config, **normset)

        # Evaluate parsed model.
        if config.getboolean('tools', 'evaluate_ann') and not is_stop(queue):
            print(
                "Evaluating parsed model on {} samples...".format(num_to_test))
            model_parser.evaluate(config.getint('simulation', 'batch_size'),
                                  num_to_test, **testset)

        # Write parsed model to disk
        parsed_model.save(
            str(
                os.path.join(
                    config.get('paths', 'path_wd'),
                    config.get('paths', 'filename_parsed_model') + '.h5')))

    # _____________________________ CONVERT _________________________________ #

    if config.getboolean('tools', 'convert') and not is_stop(queue):
        if parsed_model is None:
            from snntoolbox.parsing.model_libs.keras_input_lib import load
            try:
                parsed_model = load(config.get('paths', 'path_wd'),
                                    config.get('paths',
                                               'filename_parsed_model'),
                                    filepath_custom_objects=config.get(
                                        'paths',
                                        'filepath_custom_objects'))['model']
            except FileNotFoundError:
                print("Could not find parsed model {} in path {}. Consider "
                      "setting `parse = True` in your config file.".format(
                          config.get('paths', 'path_wd'),
                          config.get('paths', 'filename_parsed_model')))

        spiking_model.build(parsed_model, **testset)

        # Export network in a format specific to the simulator with which it
        # will be tested later.
        spiking_model.save(config.get('paths', 'path_wd'),
                           config.get('paths', 'filename_snn'))

    # ______________________________ SIMULATE _______________________________ #

    if config.getboolean('tools', 'simulate') and not is_stop(queue):

        # Decorate the 'run' function of the spiking model with a parameter
        # sweep function.
        @run_parameter_sweep(config, queue)
        def run(snn, **test_set):
            return snn.run(**test_set)

        # Simulate network
        results = run(spiking_model, **testset)

        # Clean up
        spiking_model.end_sim()

        # Add results to queue to be displayed in GUI.
        if queue:
            queue.put(results)

        return results
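A hedged usage sketch for `run_pipeline`, assuming snntoolbox's config loader `update_setup` from `snntoolbox.bin.utils` (the config file path is illustrative):

from snntoolbox.bin.utils import update_setup

config = update_setup('path/to/config.ini')  # load and validate user settings
results = run_pipeline(config)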