def load(path, filename, **kwargs):
    """Load network from file.

    If ``<path>/<filename>.json`` exists, the architecture is read from that
    JSON file and the weights from ``<filename>.h5`` (falling back to the
    bare ``<filename>`` if that raises ``OSError``); the model is then
    compiled with SGD / categorical crossentropy, because optimizer and loss
    are not stored in this format. Otherwise a full Keras model is loaded
    from ``<filename>.h5`` (or ``<filename>`` as fallback) and recompiled with
    its own optimizer and loss.

    Parameters
    ----------

    path: str
        Path to directory where to load model from.

    filename: str
        Name of file to load model from.

    kwargs:
        filepath_custom_objects: str, optional
            Converted to ``str`` and passed to ``get_custom_activations_dict``
            when loading a full model (non-JSON branch).

    Returns
    -------

    : dict[str, Union[keras.models.Sequential, function]]
        A dictionary of objects that constitute the input model. It must
        contain the following two keys:

        - 'model': keras.models.Sequential
            Keras model instance of the network.
        - 'val_fn': function
            Function that allows evaluating the original model.
    """

    filepath = str(os.path.join(path, filename))

    if os.path.exists(filepath + '.json'):
        # Use a context manager so the JSON file handle is closed promptly
        # instead of being left for the garbage collector.
        with open(filepath + '.json') as f:
            model = models.model_from_json(f.read())
        try:
            model.load_weights(filepath + '.h5')
        except OSError:
            # Allows h5 files without a .h5 extension to be loaded.
            model.load_weights(filepath)
        # With this loading method, optimizer and loss cannot be recovered.
        # Could be specified by user, but since they are not really needed
        # at inference time, set them to the most common choice.
        # TODO: Proper reinstantiation should be doable since Keras2
        model.compile('sgd', 'categorical_crossentropy',
                      ['accuracy', metrics.top_k_categorical_accuracy])
    else:
        filepath_custom_objects = kwargs.get('filepath_custom_objects', None)
        if filepath_custom_objects is not None:
            filepath_custom_objects = str(filepath_custom_objects)  # python 2

        custom_dicts = assemble_custom_dict(
            get_custom_activations_dict(filepath_custom_objects),
            get_custom_layers_dict())
        try:
            model = models.load_model(filepath + '.h5', custom_dicts)
        except OSError as e:
            print(e)
            print("Trying to load without '.h5' extension.")
            model = models.load_model(filepath, custom_dicts)
        # Recompile with the stored optimizer/loss, adding the metrics used
        # for evaluation.
        model.compile(model.optimizer, model.loss,
                      ['accuracy', metrics.top_k_categorical_accuracy])

    model.summary()

    return {'model': model, 'val_fn': model.evaluate}
def load(path, filename):
    """Load network from file.

    If ``<path>/<filename>.json`` exists, the architecture is read from it
    and the weights from ``<filename>.h5``, then the model is compiled with
    SGD / categorical crossentropy. Otherwise a full Keras model is loaded
    from ``<filename>.h5`` and recompiled with its own optimizer and loss.

    Parameters
    ----------

    path: str
        Path to directory where to load model from.

    filename: str
        Name of file to load model from.

    Returns
    -------

    : dict[str, Union[keras.models.Sequential, function]]
        A dictionary of objects that constitute the input model. It must
        contain the following two keys:

        - 'model': keras.models.Sequential
            Keras model instance of the network.
        - 'val_fn': function
            Function that allows evaluating the original model.
    """

    import os
    from keras import models, metrics

    filepath = os.path.join(path, filename)

    if os.path.exists(filepath + '.json'):
        # Close the JSON file deterministically rather than leaking the handle.
        with open(filepath + '.json') as f:
            model = models.model_from_json(f.read())
        model.load_weights(filepath + '.h5')
        # With this loading method, optimizer and loss cannot be recovered.
        # Could be specified by user, but since they are not really needed
        # at inference time, set them to the most common choice.
        # TODO: Proper reinstantiation should be doable since Keras2
        model.compile('sgd', 'categorical_crossentropy',
                      ['accuracy', metrics.top_k_categorical_accuracy])
    else:
        from snntoolbox.parsing.utils import get_custom_activations_dict
        model = models.load_model(filepath + '.h5',
                                  get_custom_activations_dict())
        # Recompile with the stored optimizer/loss, adding evaluation metrics.
        model.compile(model.optimizer, model.loss,
                      ['accuracy', metrics.top_k_categorical_accuracy])

    return {'model': model, 'val_fn': model.evaluate}
def update_setup(config_filepath):
    """Update default settings with user settings and check they are valid.

    Load settings from configuration file at ``config_filepath``, and check
    that parameter choices are valid. Non-specified settings are filled in
    with defaults.

    Parameters
    ----------

    config_filepath: str
        Path to the user configuration file. Its directory is used as the
        working directory if ``paths.path_wd`` is left empty.

    Returns
    -------

    config
        The configuration object returned by ``load_config``, overwritten with
        the user settings and with derived settings filled in. It is also
        written to ``<log_dir_of_current_run>/.config``.

    Raises
    ------

    AssertionError
        If a setting is invalid or a required file/directory is missing.
    RuntimeWarning
        If required ``.npz`` data set files are missing.
    RuntimeError
        If the parameter to sweep is not an option of the ``cell`` section.
    """

    from textwrap import dedent

    # Load defaults.
    config = load_config(os.path.abspath(os.path.join(
        os.path.dirname(__file__), '..', 'config_defaults')))

    # Overwrite with user settings.
    config.read(config_filepath)

    keras_backend = config.get('simulation', 'keras_backend')
    keras_backends = config_string_to_set_of_strings(
        config.get('restrictions', 'keras_backends'))
    assert keras_backend in keras_backends, \
        "Keras backend {} not supported. Choose from {}.".format(
            keras_backend, keras_backends)
    os.environ['KERAS_BACKEND'] = keras_backend
    # The keras import has to happen after setting the backend environment
    # variable!
    import keras.backend as k
    assert k.backend() == keras_backend, \
        "Keras backend set to {} in snntoolbox config file, but has already " \
        "been set to {} by a previous keras import. Set backend " \
        "appropriately in the keras config file.".format(keras_backend,
                                                         k.backend())
    if keras_backend == 'tensorflow':
        # Limit GPU usage of tensorflow.
        tf_config = k.tf.ConfigProto()
        tf_config.gpu_options.allow_growth = True
        k.tensorflow_backend.set_session(k.tf.Session(config=tf_config))

    # Name of input file must be given.
    filename_ann = config.get('paths', 'filename_ann')
    assert filename_ann != '', "Filename of input model not specified."

    # Check that simulator choice is valid.
    simulator = config.get('simulation', 'simulator')
    simulators = config_string_to_set_of_strings(
        config.get('restrictions', 'simulators'))
    assert simulator in simulators, \
        "Simulator '{}' not supported. Choose from {}".format(simulator,
                                                              simulators)

    # Warn user that it is not possible to use Brian2 simulator by loading a
    # pre-converted network from disk.
    if simulator == 'brian2' and not config.getboolean('tools', 'convert'):
        print(dedent("""\
            \n
            SNN toolbox Warning: When using Brian 2 simulator, you need to
            convert the network each time you start a new session. (No
            saving/reloading methods implemented.) Setting convert = True.
            \n"""))
        config.set('tools', 'convert', str(True))

    # Set default path if user did not specify it.
    if config.get('paths', 'path_wd') == '':
        config.set('paths', 'path_wd', os.path.dirname(config_filepath))

    # Check specified working directory exists.
    path_wd = config.get('paths', 'path_wd')
    assert os.path.exists(path_wd), \
        "Working directory {} does not exist.".format(path_wd)

    # Check that choice of input model library is valid.
    model_lib = config.get('input', 'model_lib')
    model_libs = config_string_to_set_of_strings(
        config.get('restrictions', 'model_libs'))
    assert model_lib in model_libs, "ERROR: Input model library '{}' ".format(
        model_lib) + "not supported yet. Possible values: {}".format(
        model_libs)

    # Check input model is found and has the right format for the specified
    # model library.
    if model_lib == 'caffe':
        caffemodel_filepath = os.path.join(path_wd,
                                           filename_ann + '.caffemodel')
        caffemodel_h5_filepath = os.path.join(path_wd,
                                              filename_ann + '.caffemodel.h5')
        assert os.path.isfile(caffemodel_filepath) or os.path.isfile(
            caffemodel_h5_filepath), "File {} or {} not found.".format(
            caffemodel_filepath, caffemodel_h5_filepath)
        prototxt_filepath = os.path.join(path_wd, filename_ann + '.prototxt')
        assert os.path.isfile(prototxt_filepath), \
            "File {} not found.".format(prototxt_filepath)
    elif model_lib == 'keras':
        h5_filepath = os.path.join(path_wd, filename_ann + '.h5')
        assert os.path.isfile(h5_filepath), \
            "File {} not found.".format(h5_filepath)
        json_file = filename_ann + '.json'
        if not os.path.isfile(os.path.join(path_wd, json_file)):
            import keras
            import h5py
            from snntoolbox.parsing.utils import get_custom_activations_dict
            # Remove optimizer_weights here, because they may cause the
            # load_model method to fail if the network was trained on a
            # different platform or keras version
            # (see https://github.com/fchollet/keras/issues/4044).
            with h5py.File(h5_filepath, 'a') as f:
                if 'optimizer_weights' in f.keys():
                    del f['optimizer_weights']
            # Try loading the model.
            keras.models.load_model(h5_filepath, get_custom_activations_dict())
    elif model_lib == 'lasagne':
        h5_filepath = os.path.join(path_wd, filename_ann + '.h5')
        pkl_filepath = os.path.join(path_wd, filename_ann + '.pkl')
        assert os.path.isfile(h5_filepath) or os.path.isfile(pkl_filepath), \
            "File {} not found.".format('.h5 or .pkl')
        py_filepath = os.path.join(path_wd, filename_ann + '.py')
        assert os.path.isfile(py_filepath), \
            "File {} not found.".format(py_filepath)
    else:
        print("For the specified input model library {}, ".format(model_lib) +
              "no test is implemented to check if input model files exist in "
              "the specified working directory!")

    # Set default path if user did not specify it.
    if config.get('paths', 'dataset_path') == '':
        config.set('paths', 'dataset_path', os.path.dirname(__file__))

    # Check that the data set path is valid.
    dataset_path = os.path.abspath(config.get('paths', 'dataset_path'))
    config.set('paths', 'dataset_path', dataset_path)
    assert os.path.exists(dataset_path), "Path to data set does not exist: " \
                                         "{}".format(dataset_path)

    # Check that data set path contains the data in the specified format.
    assert os.listdir(dataset_path), "Data set directory is empty."
    normalize = config.getboolean('tools', 'normalize')
    dataset_format = config.get('input', 'dataset_format')
    if dataset_format == 'npz' and normalize and not os.path.exists(
            os.path.join(dataset_path, 'x_norm.npz')):
        raise RuntimeWarning(
            "No data set file 'x_norm.npz' found in specified data set path " +
            "{}. Add it, or disable normalization.".format(dataset_path))
    if dataset_format == 'npz' and not (
            os.path.exists(os.path.join(dataset_path, 'x_test.npz')) and
            os.path.exists(os.path.join(dataset_path, 'y_test.npz'))):
        raise RuntimeWarning(
            "Data set file 'x_test.npz' or 'y_test.npz' was not found in "
            "specified data set path {}.".format(dataset_path))

    # NOTE(review): ``eval`` executes arbitrary code from the config file;
    # only use trusted configuration files.
    sample_idxs_to_test = eval(config.get('simulation', 'sample_idxs_to_test'))
    num_to_test = config.getint('simulation', 'num_to_test')
    # An empty selection (``[]`` or ``()``) means "no explicit indices", so
    # it must not override ``num_to_test``.
    if sample_idxs_to_test and len(sample_idxs_to_test) != num_to_test:
        print(dedent("""
            SNN toolbox warning: Settings mismatch. Adjusting 'num_to_test' to
            equal the number of 'sample_idxs_to_test'."""))
        config.set('simulation', 'num_to_test', str(len(sample_idxs_to_test)))

    # Create log directory if it does not exist.
    if config.get('paths', 'log_dir_of_current_run') == '':
        config.set('paths', 'log_dir_of_current_run', os.path.join(
            path_wd, 'log', 'gui', config.get('paths', 'runlabel')))
    log_dir_of_current_run = config.get('paths', 'log_dir_of_current_run')
    if not os.path.isdir(log_dir_of_current_run):
        os.makedirs(log_dir_of_current_run)

    # Specify filenames for models at different stages of the conversion.
    if config.get('paths', 'filename_parsed_model') == '':
        config.set('paths', 'filename_parsed_model', filename_ann + '_parsed')
    if config.get('paths', 'filename_snn') == '':
        config.set('paths', 'filename_snn', '{}_{}'.format(filename_ann,
                                                           simulator))

    if simulator != 'INI' and not config.getboolean('input', 'poisson_input'):
        config.set('input', 'poisson_input', str(True))
        print(dedent("""\
            SNN toolbox Warning: Currently, turning off Poisson input is
            only possible in INI simulator. Falling back on Poisson input."""))

    # Make sure the number of samples to test is not lower than the batch size.
    batch_size = config.getint('simulation', 'batch_size')
    if config.getint('simulation', 'num_to_test') < batch_size:
        print(dedent("""\
            SNN toolbox Warning: 'num_to_test' set lower than 'batch_size'.
            In simulators that test samples batch-wise (e.g. INIsim), this
            can lead to undesired behavior.
            Setting 'num_to_test' equal to 'batch_size'."""))
        config.set('simulation', 'num_to_test', str(batch_size))

    plot_var = get_plot_keys(config)
    plot_vars = config_string_to_set_of_strings(
        config.get('restrictions', 'plot_vars'))
    assert all([v in plot_vars for v in plot_var]), \
        "Plot variable(s) {} not understood.".format(
            [v for v in plot_var if v not in plot_vars])
    if 'all' in plot_var:
        plot_vars_all = plot_vars.copy()
        plot_vars_all.remove('all')
        config.set('output', 'plot_vars', str(plot_vars_all))

    log_var = get_log_keys(config)
    log_vars = config_string_to_set_of_strings(
        config.get('restrictions', 'log_vars'))
    assert all([v in log_vars for v in log_var]), \
        "Log variable(s) {} not understood.".format(
            [v for v in log_var if v not in log_vars])
    if 'all' in log_var:
        log_vars_all = log_vars.copy()
        log_vars_all.remove('all')
        config.set('output', 'log_vars', str(log_vars_all))

    # Change matplotlib plot properties, e.g. label font size
    try:
        import matplotlib
    except ImportError:
        matplotlib = None
        # Only warn (and disable plotting) if the user actually requested
        # plots; ``plot_vars`` is the set of *allowed* values, not requested.
        if len(plot_var) > 0:
            import warnings
            warnings.warn("Package 'matplotlib' not installed; disabling "
                          "plotting. Run 'pip install matplotlib' to enable "
                          "plotting.", ImportWarning)
            config.set('output', 'plot_vars', str({}))
    if matplotlib is not None:
        matplotlib.rcParams.update(eval(config.get('output',
                                                   'plotproperties')))

    # Check settings for parameter sweep
    param_name = config.get('parameter_sweep', 'param_name')
    # ConfigParser.get raises NoOptionError / NoSectionError (not KeyError)
    # for missing entries, so check explicitly to report a clear error.
    if not config.has_option('cell', param_name):
        raise RuntimeError(
            "Unknown parameter name {} to sweep.".format(param_name))
    if not eval(config.get('parameter_sweep', 'param_values')):
        config.set('parameter_sweep', 'param_values',
                   str([eval(config.get('cell', param_name))]))

    spike_code = config.get('conversion', 'spike_code')
    spike_codes = config_string_to_set_of_strings(
        config.get('restrictions', 'spike_codes'))
    assert spike_code in spike_codes, \
        "Unknown spike code {} selected. Choose from {}.".format(spike_code,
                                                                  spike_codes)
    if spike_code == 'temporal_pattern':
        num_bits = str(config.getint('conversion', 'num_bits'))
        config.set('simulation', 'duration', num_bits)
        config.set('simulation', 'batch_size', '1')
    elif 'ttfs' in spike_code:
        config.set('cell', 'tau_refrac',
                   str(config.getint('simulation', 'duration')))
    assert keras_backend != 'theano' or spike_code == 'temporal_mean_rate', \
        "Keras backend 'theano' only works when the 'spike_code' parameter " \
        "is set to 'temporal_mean_rate' in snntoolbox config."

    with open(os.path.join(log_dir_of_current_run, '.config'), str('w')) as f:
        config.write(f)

    return config
def load(path, filename, **kwargs):
    """Load a PyTorch network from file and convert it to a Keras model.

    The PyTorch model object is imported as ``my_model`` from a module
    ``my_models`` located in the parent directory of ``path``, and its
    parameters are loaded from ``<path>/<filename>.pth``. The model is then
    exported to ONNX (``model.onnx`` in the current working directory),
    converted to Keras, saved as ``<path>/<filename>.h5`` and reloaded from
    there.

    Parameters
    ----------

    path: str
        Path to directory where to load Pytorch model parameters from.

    filename: str
        Name of file to load Pytorch model from.

    kwargs:
        filepath_custom_objects: str, optional
            Converted to ``str`` and passed to ``get_custom_activations_dict``
            when reloading the converted Keras model.
        input_shape: tuple[int], optional
            Shape of a single input sample, without the batch dimension.
            Defaults to ``(1, 28, 28)``.

    Returns
    -------

    : dict[str, Union[keras.models.Sequential, function]]
        A dictionary of objects that constitute the input model. It must
        contain the following two keys:

        - 'model': keras.models.Sequential
            Keras model instance of the network.
        - 'val_fn': function
            Function that allows evaluating the original model.
    """

    import os
    import sys
    from keras import models, metrics
    from snntoolbox.pytorch2keras.onnx2keras.converter import onnx_to_keras
    from snntoolbox.parsing.utils import get_custom_activations_dict, \
        assemble_custom_dict, get_custom_layers_dict

    filepath = str(os.path.join(path, filename))

    # Create a dummy input (batch size 1) with the model's input shape; it is
    # used to trace the PyTorch model during ONNX export.
    input_shape = tuple(kwargs.get('input_shape', (1, 28, 28)))
    dummy_input = np.random.uniform(0, 1, (1,) + input_shape)
    dummy_input = Variable(torch.FloatTensor(dummy_input))
    input_shapes = [input_shape]

    # The model definition is expected in module ``my_models`` in the parent
    # directory of ``path``. Avoid growing sys.path on repeated calls.
    path_model = os.path.join(path, "../")
    if path_model not in sys.path:
        sys.path.append(path_model)
    from my_models import my_model

    torch_model = my_model
    # Recommended save method: only the state dict is stored on disk.
    # map_location lets weights saved on a GPU load on CPU-only machines.
    torch_model.load_state_dict(torch.load(filepath + '.pth',
                                           map_location='cpu'))
    # Switch to inference mode before running the dummy input: in training
    # mode the forward pass would update BatchNorm running statistics with
    # random data (and apply dropout).
    torch_model.eval()
    dummy_output = torch_model(dummy_input)

    if isinstance(dummy_output, torch.autograd.Variable):
        dummy_output = [dummy_output]

    if not isinstance(dummy_input, list):
        dummy_input = [dummy_input]
    dummy_input = tuple(dummy_input)

    # Export as ONNX model, then reload and convert to Keras.
    input_names = ['input_{0}'.format(i) for i in range(len(dummy_input))]
    output_names = ['output_{0}'.format(i) for i in range(len(dummy_output))]
    print("output_names", output_names)
    torch.onnx.export(torch_model, dummy_input, 'model.onnx',
                      input_names=input_names, output_names=output_names)
    onnx_model = onnx.load('model.onnx')
    k_model = onnx_to_keras(onnx_model=onnx_model, input_names=input_names,
                            input_shapes=input_shapes)

    # Save the Keras model so it can be reloaded with the custom objects.
    models.save_model(k_model, filepath + '.h5')

    filepath_custom_objects = kwargs.get('filepath_custom_objects', None)
    if filepath_custom_objects is not None:
        filepath_custom_objects = str(filepath_custom_objects)  # python 2

    custom_dicts = assemble_custom_dict(
        get_custom_activations_dict(filepath_custom_objects),
        get_custom_layers_dict())
    try:
        model = models.load_model(filepath + '.h5', custom_dicts)
    except OSError as e:
        print(e)
        print("Trying to load without '.h5' extension.")
        model = models.load_model(filepath, custom_dicts)

    # The converted model carries no training configuration; compile with the
    # most common choice since optimizer and loss are unused at inference.
    model.compile('sgd', 'categorical_crossentropy',
                  ['accuracy', metrics.top_k_categorical_accuracy])
    model.summary()

    return {'model': model, 'val_fn': model.evaluate}