#!/usr/bin/env python
"""Build a pylearn2 model from a YAML file and restore its parameters from
a PyTables file.

Usage: <script> <yaml-path> <pytables-path>
"""

__author__ = "Pawel Swietojanski"

import sys, types, tables, numpy, theano

from pylearn2.scripts.pkl_to_pytables import ModelPyTables
from pylearn2.utils import serial
from theano import shared


def create_model(yaml_path):
    """Build the model described by the pylearn2 YAML file at ``yaml_path``.

    :param yaml_path: path to a YAML file accepted by
        ``serial.load_train_file``.
    :returns: the loaded object, asserted to be a pylearn2 ``Model``.
    """
    from pylearn2.models.model import Model
    model = serial.load_train_file(yaml_path)
    assert isinstance(model, Model)
    return model


if __name__ == "__main__":
    # Validate the command line explicitly: an ``assert`` would be skipped
    # when the interpreter runs with -O.
    if len(sys.argv) != 3:
        sys.stderr.write("usage: %s <yaml-path> <pytables-path>\n" % sys.argv[0])
        sys.exit(1)
    _, yaml_path, pytables_path_to = sys.argv
    model = create_model(yaml_path)
    # Load the parameter arrays stored under the name "Model" and push them
    # into the freshly built model.
    params = ModelPyTables.pytables_to_params(pytables_path_to, name="Model")
    model.set_params(params)
    # NOTE(review): the model is not saved or used after set_params here;
    # confirm whether an output step is missing.
def main(args=None):
    """Adapt selected pool parameters of a trained model on new data.

    Renders the adaptation YAML template given by ``--adapt-yaml`` into
    ``<si-model-dir>/adapt_final<JOB>.yaml``, builds the training object
    from it, loads the parameters stored in ``<si-model-dir>/cnn_best.h5``,
    freezes every parameter whose name matches the freeze pattern, and runs
    the training main loop so only the remaining parameters are updated.

    :param args: full argument vector with the program name in ``args[0]``
        (shaped like ``sys.argv``). ``None`` means ``sys.argv``.
    :raises IOError: if the adaptation template or ``cnn_best.h5`` is
        missing.

    Exits with status 1 on a wrong number of positional arguments, and with
    status 0 when every parameter would be frozen (nothing to adapt).
    """
    import sys  # local import keeps this block self-contained

    usage = "pool_adaptation.py [options] <si-model-dir> <sa-model-dir> <feats-scp> <targets-pdf> "
    if args is None:
        # The positional checks below expect args[0] to be the program name,
        # as in sys.argv. optparse's own default (sys.argv[1:]) would always
        # fail the length check.
        args = sys.argv

    parser = OptionParser(usage=usage)
    parser.add_option("--adapt-yaml", dest="adapt_yaml", default="", help="Provide the adaptation yaml template to start with")
    parser.add_option("--freeze-means", dest="freeze_means", default=False, help="Skip means while updating pools")
    parser.add_option("--freeze-betas", dest="freeze_betas", default=False, help="Skip precisions while updating pools")
    parser.add_option("--freeze-amp", dest="freeze_amp", default="true", help="Skip activation function amplitudes while updating pools")
    parser.add_option("--freeze-slopes", dest="freeze_slopes", default="true", help="Skip activation function slopes while updating pools")
    parser.add_option("--freeze-layer-ids", dest="freeze_layer_ids", default="", help="Freeze pool means and precisions in these layers, e.g. --freeze-layer-ids 012")
    parser.add_option("--job", dest="JOB", default=0, help="JOB ID used to store model in")
    parser.add_option("--debug", dest="debug", default=False, help="Prints activations and shapes in text format rather than binary Kaldi archives")
    parser.add_option("--adapt-lr", dest="adapt_lr", type="float", default=0.05, help="Learning rate passed to the adaptation template as adapt_lr")
    parser.add_option("--adapt-momentum", dest="adapt_momentum", type="float", default=0.5, help="Momentum passed to the adaptation template as adapt_momentum")
    (options, args) = parser.parse_args(args=args)

    for value in (options.adapt_yaml, options.freeze_means,
                  options.freeze_betas, options.freeze_amp,
                  options.freeze_slopes, options.freeze_layer_ids,
                  options.JOB):
        print(value)
    print('ARGS:  %s' % (args,))

    if len(args) != 5:
        print(usage)
        sys.exit(1)
    si_model_dir, sa_model_dir, feats_scp, targets_pdf = args[1:5]

    model_yaml = "%s/adapt_final%s.yaml" % (si_model_dir, options.JOB)
    params_path = "%s/cnn_best.h5" % si_model_dir
    for path in (options.adapt_yaml, params_path):
        if not os.path.isfile(path):
            raise IOError('File %s not found' % path)

    # Values substituted into the %(name)-style placeholders of the template.
    template_vars = {
        'adapt_flist': feats_scp,
        'adapt_pdfs': targets_pdf,
        'adapt_lr': options.adapt_lr,
        'adapt_momentum': options.adapt_momentum,
        'sa_dir': sa_model_dir,
        'JOB': options.JOB,
    }
    with open(options.adapt_yaml, 'r') as template_file:
        adapt_template = template_file.read()
    with open(model_yaml, 'w') as yaml_file:
        yaml_file.write(adapt_template % template_vars)

    print('Building model %s' % model_yaml)
    train_obj = serial.load_train_file(model_yaml)
    print('Loading params from %s' % params_path)
    params = ModelPyTables.pytables_to_params(params_path, name='Model')
    train_obj.model.set_params(params)

    # Weights/biases of softmax, hidden and nlrf layers are always frozen;
    # pool parameters are frozen according to the command-line options.
    freeze_regex = 'softmax_[Wb]|h[0-9]_[Wb]|nlrf_[Wb]'
    if options.freeze_layer_ids != '':
        layers = options.freeze_layer_ids
        freeze_regex = "%s|g[%s]p_u|g[%s]p_beta" % (freeze_regex, layers, layers)
    optional_patterns = (
        (options.freeze_means, 'g[0-9]p_u'),
        (options.freeze_betas, 'g[0-9]p_beta'),
        (options.freeze_amp, 'g[0-9]p_amp'),
        (options.freeze_slopes, 'g[0-9]p_arg'),
    )
    for flag, pattern in optional_patterns:
        if flag == 'true':
            freeze_regex = "%s|%s" % (freeze_regex, pattern)
    # re.match anchors only at the start of the parameter name.
    freeze_pattern = re.compile(freeze_regex)

    all_params = train_obj.model.get_params()
    params_to_freeze = []
    for param in all_params:
        if freeze_pattern.match(str(param)) is not None and param not in params_to_freeze:
            params_to_freeze.append(param)

    if len(params_to_freeze) == len(all_params):
        print('None of the parameters were set to be updated. Freeze list is %s' % (params_to_freeze,))
        sys.exit(0)

    train_obj.model.freeze(params_to_freeze)
    print('Will update those params only:  %s' % (train_obj.model.get_params(),))
    train_obj.main_loop()
    # Unfreeze so get_params() returns all model params again.
    train_obj.model.freeze_set = set()
    print('Unfreezed params are  %s' % (train_obj.model.get_params(),))