Example No. 1
def load(path, filename, **kwargs):
    """Load network from file.

    Parameters
    ----------

    path: str
        Path to the directory from which to load the model.

    filename: str
        Name of the file to load the model from.

    filepath_custom_objects: str, optional
        Keyword argument giving the path to a file with custom object
        definitions, forwarded to ``get_custom_activations_dict``.

    Returns
    -------

    : dict[str, Union[keras.models.Sequential, function]]
        A dictionary of objects that constitute the input model. It must
        contain the following two keys:

        - 'model': keras.models.Sequential
            Keras model instance of the network.
        - 'val_fn': function
            Function that allows evaluating the original model.
    """

    import os
    from keras import models, metrics
    from snntoolbox.parsing.utils import get_custom_activations_dict, \
        assemble_custom_dict, get_custom_layers_dict

    filepath = str(os.path.join(path, filename))

    if os.path.exists(filepath + '.json'):
        with open(filepath + '.json') as f:
            model = models.model_from_json(f.read())
        try:
            model.load_weights(filepath + '.h5')
        except OSError:
            # Allows h5 files without a .h5 extension to be loaded.
            model.load_weights(filepath)
        # With this loading method, optimizer and loss cannot be recovered.
        # Could be specified by user, but since they are not really needed
        # at inference time, set them to the most common choice.
        # TODO: Proper reinstantiation should be doable since Keras2
        model.compile('sgd', 'categorical_crossentropy',
                      ['accuracy', metrics.top_k_categorical_accuracy])
    else:
        filepath_custom_objects = kwargs.get('filepath_custom_objects', None)
        if filepath_custom_objects is not None:
            filepath_custom_objects = str(filepath_custom_objects)  # python 2

        custom_dicts = assemble_custom_dict(
            get_custom_activations_dict(filepath_custom_objects),
            get_custom_layers_dict())
        try:
            model = models.load_model(filepath + '.h5', custom_dicts)
        except OSError as e:
            print(e)
            print("Trying to load without '.h5' extension.")
            model = models.load_model(filepath, custom_dicts)
        model.compile(model.optimizer, model.loss,
                      ['accuracy', metrics.top_k_categorical_accuracy])

    model.summary()
    return {'model': model, 'val_fn': model.evaluate}
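
A minimal usage sketch of this loader (paths and data here are hypothetical; it assumes the model was saved either as 'mnist_cnn.json' plus 'mnist_cnn.h5' or as a single 'mnist_cnn.h5'):

loaded = load('/home/user/models', 'mnist_cnn')
ann = loaded['model']
# x_test, y_test: your own test data, e.g. loaded from .npz files.
score = loaded['val_fn'](x_test, y_test)
print('Test accuracy:', score[1])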
Example No. 2
def load(path, filename):
    """Load network from file.

    Parameters
    ----------

    path: str
        Path to the directory from which to load the model.

    filename: str
        Name of the file to load the model from.

    Returns
    -------

    : dict[str, Union[keras.models.Sequential, function]]
        A dictionary of objects that constitute the input model. It must
        contain the following two keys:

        - 'model': keras.models.Sequential
            Keras model instance of the network.
        - 'val_fn': function
            Function that allows evaluating the original model.
    """

    import os
    from keras import models, metrics

    filepath = os.path.join(path, filename)

    if os.path.exists(filepath + '.json'):
        with open(filepath + '.json') as f:
            model = models.model_from_json(f.read())
        model.load_weights(filepath + '.h5')
        # With this loading method, optimizer and loss cannot be recovered.
        # Could be specified by user, but since they are not really needed
        # at inference time, set them to the most common choice.
        # TODO: Proper reinstantiation should be doable since Keras2
        model.compile('sgd', 'categorical_crossentropy',
                      ['accuracy', metrics.top_k_categorical_accuracy])
    else:
        from snntoolbox.parsing.utils import get_custom_activations_dict
        model = models.load_model(filepath + '.h5',
                                  get_custom_activations_dict())
        model.compile(model.optimizer, model.loss,
                      ['accuracy', metrics.top_k_categorical_accuracy])

    return {'model': model, 'val_fn': model.evaluate}
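
For reference, a sketch of the two on-disk formats the branches above can read (standard Keras API; 'net' is a hypothetical filename):

from keras import models, layers

model = models.Sequential(
    [layers.Dense(10, input_shape=(784,), activation='softmax')])

# Taken by the '.json' branch: architecture and weights stored separately.
with open('net.json', 'w') as f:
    f.write(model.to_json())
model.save_weights('net.h5')

# Taken by the 'models.load_model' branch: a single HDF5 file that also
# stores the compiled optimizer and loss.
model.compile('sgd', 'categorical_crossentropy')
model.save('net.h5')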
Example No. 3
def update_setup(config_filepath):
    """Update default settings with user settings and check they are valid.

    Load settings from the configuration file at ``config_filepath`` and check
    that the parameter choices are valid. Settings not specified by the user
    are filled in with defaults.
    """

    import os
    from textwrap import dedent

    # Load defaults.
    config = load_config(
        os.path.abspath(
            os.path.join(os.path.dirname(__file__), '..', 'config_defaults')))

    # Overwrite with user settings.
    config.read(config_filepath)

    keras_backend = config.get('simulation', 'keras_backend')
    keras_backends = config_string_to_set_of_strings(
        config.get('restrictions', 'keras_backends'))
    assert keras_backend in keras_backends, \
        "Keras backend {} not supported. Choose from {}.".format(keras_backend,
                                                                 keras_backends)
    os.environ['KERAS_BACKEND'] = keras_backend
    # The keras import has to happen after setting the backend environment
    # variable!
    import keras.backend as k
    assert k.backend() == keras_backend, \
        "Keras backend set to {} in snntoolbox config file, but has already " \
        "been set to {} by a previous keras import. Set backend " \
        "appropriately in the keras config file.".format(keras_backend,
                                                         k.backend())
    if keras_backend == 'tensorflow':
        # Limit GPU usage of tensorflow.
        tf_config = k.tf.ConfigProto()
        tf_config.gpu_options.allow_growth = True
        k.tensorflow_backend.set_session(k.tf.Session(config=tf_config))

    # Name of input file must be given.
    filename_ann = config.get('paths', 'filename_ann')
    assert filename_ann != '', "Filename of input model not specified."

    # Check that simulator choice is valid.
    simulator = config.get('simulation', 'simulator')
    simulators = config_string_to_set_of_strings(
        config.get('restrictions', 'simulators'))
    assert simulator in simulators, \
        "Simulator '{}' not supported. Choose from {}".format(simulator,
                                                              simulators)

    # Warn user that it is not possible to use Brian2 simulator by loading a
    # pre-converted network from disk.
    if simulator == 'brian2' and not config.getboolean('tools', 'convert'):
        print(
            dedent("""\ \n
            SNN toolbox Warning: When using Brian 2 simulator, you need to
            convert the network each time you start a new session. (No
            saving/reloading methods implemented.) Setting convert = True.
            \n"""))
        config.set('tools', 'convert', str(True))

    # Set default path if user did not specify it.
    if config.get('paths', 'path_wd') == '':
        config.set('paths', 'path_wd', os.path.dirname(config_filepath))

    # Check specified working directory exists.
    path_wd = config.get('paths', 'path_wd')
    assert os.path.exists(path_wd), \
        "Working directory {} does not exist.".format(path_wd)

    # Check that choice of input model library is valid.
    model_lib = config.get('input', 'model_lib')
    model_libs = config_string_to_set_of_strings(
        config.get('restrictions', 'model_libs'))
    assert model_lib in model_libs, \
        "Input model library '{}' not supported yet. " \
        "Possible values: {}".format(model_lib, model_libs)

    # Check input model is found and has the right format for the specified
    # model library.
    if model_lib == 'caffe':
        caffemodel_filepath = os.path.join(path_wd,
                                           filename_ann + '.caffemodel')
        caffemodel_h5_filepath = os.path.join(path_wd,
                                              filename_ann + '.caffemodel.h5')
        assert os.path.isfile(caffemodel_filepath) or os.path.isfile(
            caffemodel_h5_filepath), "File {} or {} not found.".format(
                caffemodel_filepath, caffemodel_h5_filepath)
        prototxt_filepath = os.path.join(path_wd, filename_ann + '.prototxt')
        assert os.path.isfile(prototxt_filepath), \
            "File {} not found.".format(prototxt_filepath)
    elif model_lib == 'keras':
        h5_filepath = os.path.join(path_wd, filename_ann + '.h5')
        assert os.path.isfile(h5_filepath), \
            "File {} not found.".format(h5_filepath)
        json_file = filename_ann + '.json'
        if not os.path.isfile(os.path.join(path_wd, json_file)):
            import keras
            import h5py
            from snntoolbox.parsing.utils import get_custom_activations_dict
            # Remove optimizer_weights here, because they may cause the
            # load_model method to fail if the network was trained on a
            # different platform or keras version
            # (see https://github.com/fchollet/keras/issues/4044).
            with h5py.File(h5_filepath, 'a') as f:
                if 'optimizer_weights' in f.keys():
                    del f['optimizer_weights']
            # Try loading the model.
            keras.models.load_model(h5_filepath, get_custom_activations_dict())
    elif model_lib == 'lasagne':
        h5_filepath = os.path.join(path_wd, filename_ann + '.h5')
        pkl_filepath = os.path.join(path_wd, filename_ann + '.pkl')
        assert os.path.isfile(h5_filepath) or os.path.isfile(pkl_filepath), \
            "File {} or {} not found.".format(h5_filepath, pkl_filepath)
        py_filepath = os.path.join(path_wd, filename_ann + '.py')
        assert os.path.isfile(py_filepath), \
            "File {} not found.".format(py_filepath)
    else:
        print("For the specified input model library {}, ".format(model_lib) +
              "no test is implemented to check if input model files exist in "
              "the specified working directory!")

    # Set default path if user did not specify it.
    if config.get('paths', 'dataset_path') == '':
        config.set('paths', 'dataset_path', os.path.dirname(__file__))

    # Check that the data set path is valid.
    dataset_path = os.path.abspath(config.get('paths', 'dataset_path'))
    config.set('paths', 'dataset_path', dataset_path)
    assert os.path.exists(dataset_path), "Path to data set does not exist: " \
                                         "{}".format(dataset_path)

    # Check that data set path contains the data in the specified format.
    assert os.listdir(dataset_path), "Data set directory is empty."
    normalize = config.getboolean('tools', 'normalize')
    dataset_format = config.get('input', 'dataset_format')
    if dataset_format == 'npz' and normalize and not os.path.exists(
            os.path.join(dataset_path, 'x_norm.npz')):
        raise RuntimeWarning(
            "No data set file 'x_norm.npz' found in specified data set path " +
            "{}. Add it, or disable normalization.".format(dataset_path))
    if dataset_format == 'npz' and not (
            os.path.exists(os.path.join(dataset_path, 'x_test.npz'))
            and os.path.exists(os.path.join(dataset_path, 'y_test.npz'))):
        raise RuntimeWarning(
            "Data set file 'x_test.npz' or 'y_test.npz' was not found in "
            "specified data set path {}.".format(dataset_path))

    sample_idxs_to_test = eval(config.get('simulation', 'sample_idxs_to_test'))
    num_to_test = config.getint('simulation', 'num_to_test')
    if sample_idxs_to_test != []:
        if len(sample_idxs_to_test) != num_to_test:
            print(
                dedent("""
            SNN toolbox warning: Settings mismatch. Adjusting 'num_to_test' to 
            equal the number of 'sample_idxs_to_test'."""))
            config.set('simulation', 'num_to_test',
                       str(len(sample_idxs_to_test)))

    # Create log directory if it does not exist.
    if config.get('paths', 'log_dir_of_current_run') == '':
        config.set(
            'paths', 'log_dir_of_current_run',
            os.path.join(path_wd, 'log', 'gui',
                         config.get('paths', 'runlabel')))
    log_dir_of_current_run = config.get('paths', 'log_dir_of_current_run')
    if not os.path.isdir(log_dir_of_current_run):
        os.makedirs(log_dir_of_current_run)

    # Specify filenames for models at different stages of the conversion.
    if config.get('paths', 'filename_parsed_model') == '':
        config.set('paths', 'filename_parsed_model', filename_ann + '_parsed')
    if config.get('paths', 'filename_snn') == '':
        config.set('paths', 'filename_snn',
                   '{}_{}'.format(filename_ann, simulator))

    if simulator != 'INI' and not config.getboolean('input', 'poisson_input'):
        config.set('input', 'poisson_input', str(True))
        print(
            dedent("""\
            SNN toolbox Warning: Currently, turning off Poisson input is
            only possible in INI simulator. Falling back on Poisson input."""))

    # Make sure the number of samples to test is not lower than the batch size.
    batch_size = config.getint('simulation', 'batch_size')
    if config.getint('simulation', 'num_to_test') < batch_size:
        print(
            dedent("""\
            SNN toolbox Warning: 'num_to_test' set lower than 'batch_size'.
            In simulators that test samples batch-wise (e.g. INIsim), this
            can lead to undesired behavior. Setting 'num_to_test' equal to
            'batch_size'."""))
        config.set('simulation', 'num_to_test', str(batch_size))

    plot_var = get_plot_keys(config)
    plot_vars = config_string_to_set_of_strings(
        config.get('restrictions', 'plot_vars'))
    assert all([v in plot_vars for v in plot_var]), \
        "Plot variable(s) {} not understood.".format(
            [v for v in plot_var if v not in plot_vars])
    if 'all' in plot_var:
        plot_vars_all = plot_vars.copy()
        plot_vars_all.remove('all')
        config.set('output', 'plot_vars', str(plot_vars_all))

    log_var = get_log_keys(config)
    log_vars = config_string_to_set_of_strings(
        config.get('restrictions', 'log_vars'))
    assert all([v in log_vars for v in log_var]), \
        "Log variable(s) {} not understood.".format(
            [v for v in log_var if v not in log_vars])
    if 'all' in log_var:
        log_vars_all = log_vars.copy()
        log_vars_all.remove('all')
        config.set('output', 'log_vars', str(log_vars_all))

    # Change matplotlib plot properties, e.g. label font size
    try:
        import matplotlib
    except ImportError:
        matplotlib = None
        if len(plot_var) > 0:
            import warnings
            warnings.warn(
                "Package 'matplotlib' not installed; disabling "
                "plotting. Run 'pip install matplotlib' to enable "
                "plotting.", ImportWarning)
            config.set('output', 'plot_vars', str({}))
    if matplotlib is not None:
        matplotlib.rcParams.update(eval(config.get('output',
                                                   'plotproperties')))

    # Check settings for parameter sweep
    param_name = config.get('parameter_sweep', 'param_name')
    try:
        config.get('cell', param_name)
    except KeyError:
        print("Unkown parameter name {} to sweep.".format(param_name))
        raise RuntimeError
    if not eval(config.get('parameter_sweep', 'param_values')):
        config.set('parameter_sweep', 'param_values',
                   str([eval(config.get('cell', param_name))]))

    spike_code = config.get('conversion', 'spike_code')
    spike_codes = config_string_to_set_of_strings(
        config.get('restrictions', 'spike_codes'))
    assert spike_code in spike_codes, \
        "Unknown spike code {} selected. Choose from {}.".format(spike_code,
                                                                 spike_codes)
    if spike_code == 'temporal_pattern':
        num_bits = str(config.getint('conversion', 'num_bits'))
        config.set('simulation', 'duration', num_bits)
        config.set('simulation', 'batch_size', '1')
    elif 'ttfs' in spike_code:
        config.set('cell', 'tau_refrac',
                   str(config.getint('simulation', 'duration')))
    assert keras_backend != 'theano' or spike_code == 'temporal_mean_rate', \
        "Keras backend 'theano' only works when the 'spike_code' parameter " \
        "is set to 'temporal_mean_rate' in snntoolbox config."

    with open(os.path.join(log_dir_of_current_run, '.config'), str('w')) as f:
        config.write(f)

    return config
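
A minimal sketch of how update_setup might be driven (INI format; the section and key names appear in the code above, while the file contents and paths are hypothetical):

# Contents of a hypothetical config file, e.g. /home/user/mnist/config.
# Unspecified keys fall back to the defaults loaded above.
#
#     [paths]
#     path_wd = /home/user/mnist
#     dataset_path = /home/user/mnist/dataset
#     filename_ann = mnist_cnn
#
#     [input]
#     model_lib = keras
#     dataset_format = npz
#
#     [simulation]
#     simulator = INI
#     duration = 50
#     batch_size = 50
#     num_to_test = 100

config = update_setup('/home/user/mnist/config')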
Example No. 4
def load(path, filename, **kwargs):
    """Load network from file.

    Parameters
    ----------

    path: str
        Path to the directory from which to load the Pytorch model parameters.

    filename: str
        Name of the file to load the Pytorch model from.

    filepath_custom_objects: str, optional
        Keyword argument giving the path to a file with custom object
        definitions, forwarded to ``get_custom_activations_dict``.

    Returns
    -------

    : dict[str, Union[keras.models.Sequential, function]]
        A dictionary of objects that constitute the input model. It must
        contain the following two keys:

        - 'model': keras.models.Sequential
            Keras model instance of the network.
        - 'val_fn': function
            Function that allows evaluating the original model.
    """

    import os
    import sys

    import numpy as np
    import onnx
    import torch
    from torch.autograd import Variable

    import keras
    from keras import models, metrics

    filepath = str(os.path.join(path, filename))

    # Create a dummy input of the correct shape (hard-coded here for
    # single-channel 28x28 images, e.g. MNIST).
    dummy_input = np.random.uniform(0, 1, (1, 1, 28, 28))
    dummy_input = Variable(torch.FloatTensor(dummy_input))
    input_shapes = [(1, 28, 28)]

    # Converter that maps the traced ONNX graph onto an equivalent Keras
    # model.
    from snntoolbox.pytorch2keras.onnx2keras.converter import onnx_to_keras

    # Load the Pytorch model definition. This assumes a module 'my_models'
    # that defines 'my_model' one directory above 'path'.
    path_model = os.path.join(path, "../")
    sys.path.append(path_model)
    from my_models import my_model
    model = my_model

    # Load the weights, saved with the recommended 'state_dict' method, and
    # trace the model by running the dummy input through it.
    model.load_state_dict(torch.load(filepath + '.pth'))
    dummy_output = model(dummy_input)

    if isinstance(dummy_output, torch.autograd.Variable):
        dummy_output = [dummy_output]
    if not isinstance(dummy_input, list):
        dummy_input = [dummy_input]
    dummy_input = tuple(dummy_input)
    # Export the traced model to ONNX, then reload it for conversion.
    input_names = ['input_{0}'.format(i) for i in range(len(dummy_input))]
    output_names = ['output_{0}'.format(i) for i in range(len(dummy_output))]
    print("output_names", output_names)
    torch.onnx.export(model,
                      dummy_input,
                      'model.onnx',
                      input_names=input_names,
                      output_names=output_names)
    onnx_model = onnx.load('model.onnx')
    k_model = onnx_to_keras(onnx_model=onnx_model,
                            input_names=input_names,
                            input_shapes=input_shapes)
    # Save the converted Keras model next to the Pytorch weights.
    keras.models.save_model(k_model, filepath + '.h5')

    from snntoolbox.parsing.utils import get_custom_activations_dict, \
            assemble_custom_dict, get_custom_layers_dict
    filepath_custom_objects = kwargs.get('filepath_custom_objects', None)
    if filepath_custom_objects is not None:
        filepath_custom_objects = str(filepath_custom_objects)  # python 2

    custom_dicts = assemble_custom_dict(
        get_custom_activations_dict(filepath_custom_objects),
        get_custom_layers_dict())
    try:
        model = models.load_model(filepath + '.h5', custom_dicts)
    except OSError as e:
        print(e)
        print("Trying to load without '.h5' extension.")
        model = models.load_model(filepath, custom_dicts)
    # Optimizer and loss of the original model are not needed at inference
    # time; compile with a common default choice.
    model.compile('sgd', 'categorical_crossentropy',
                  ['accuracy', metrics.top_k_categorical_accuracy])
    model.summary()
    return {'model': model, 'val_fn': model.evaluate}
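
A hedged sketch of the directory layout and call this loader expects (all names hypothetical):

# /home/user/project/my_models.py      defines 'my_model'
# /home/user/project/mnist/lenet.pth   state_dict of the trained weights
loaded = load('/home/user/project/mnist', 'lenet')
keras_model = loaded['model']
# x_test, y_test: your own test data.
score = loaded['val_fn'](x_test, y_test)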