def pyrnn_to_pronn(pyrnn=None, output='en-default.pronn'):
    """
    Converts a legacy python RNN to the new protobuf format.

    Benefits of the new format include independence from particular python
    versions and no arbitrary code execution issues inherent in pickle.

    Args:
        pyrnn (kraken.lib.lstm.SeqRecognizer): pyrnn model to convert.
            Despite the ``None`` default this argument is required.
        output (unicode): path the serialized protobuf model is written to.

    Raises:
        ValueError: if no model is given.
    """
    if pyrnn is None:
        raise ValueError('pyrnn_to_pronn requires a pyrnn model to convert')

    proto = pyrnn_pb2.pyrnn()
    proto.kind = 'pyrnn-bidi'
    proto.ninput = pyrnn.Ni
    proto.noutput = pyrnn.No
    # Write the codec in code order so that the i-th entry of the serialized
    # list always corresponds to code i, regardless of the iteration order of
    # the code2char mapping.
    code2char = pyrnn.codec.code2char
    proto.codec.extend([code2char[code] for code in sorted(code2char)])

    parallel, softmax = pyrnn.lstm.nets
    fwdnet, revnet = parallel.nets
    # The reverse branch wraps the actual LSTM in its `.net` attribute.
    revnet = revnet.net
    for w in ('WGI', 'WGF', 'WGO', 'WCI', 'WIP', 'WFP', 'WOP'):
        # Same weight matrix, forward direction first, then reverse.
        for src, dst in ((fwdnet, proto.fwdnet), (revnet, proto.revnet)):
            weights = getattr(src, w)
            ar = getattr(dst, w.lower())
            # Shape and flattened values are stored separately so the loader
            # can restore the matrix with a reshape.
            ar.dim.extend(weights.shape)
            ar.value.extend(weights.reshape(-1).tolist())

    proto.softmax.w2.dim.extend(softmax.W2.shape)
    proto.softmax.w2.value.extend(softmax.W2.reshape(-1).tolist())
    with open(output, 'wb') as fp:
        fp.write(proto.SerializeToString())
def load_pronn_model(cls, path: str):
    """
    Loads a pronn model (legacy pyrnn network in protobuf format) to VGSL.

    Args:
        path: path to the protobuf model file.

    Returns:
        An instance of ``cls`` built from the VGSL spec
        ``[1,1,0,<ninput> Lbxo<hidden> O1ca<codec size>]`` with the pyrnn
        weights and codec attached.

    Raises:
        KrakenInvalidModelException: if the file does not contain a valid
            pyrnn protobuf message or the message is incomplete.
    """
    with open(path, 'rb') as fp:
        net = pyrnn_pb2.pyrnn()
        try:
            net.ParseFromString(fp.read())
        except Exception as err:
            raise KrakenInvalidModelException('File does not contain valid proto msg') from err
        if not net.IsInitialized():
            raise KrakenInvalidModelException('Model incomplete')

    # extract codec
    codec = PytorchCodec(net.codec)

    ninput = net.ninput
    # number of rows of the input gate matrix == hidden size
    hidden = net.fwdnet.wgi.dim[0]

    # extract weights: the four gate matrices in the order they are
    # concatenated below (input, forget, cell input, output -- matching the
    # i, f, g, o gate order of torch.nn.LSTM), followed by the three
    # peephole weights.
    weightnames = ('wgi', 'wgf', 'wci', 'wgo', 'wip', 'wfp', 'wop')

    fwd_w = []
    rev_w = []
    for w in weightnames:
        fwd_ar = getattr(net.fwdnet, w)
        rev_ar = getattr(net.revnet, w)
        fwd_w.append(torch.Tensor(fwd_ar.value).view(list(fwd_ar.dim)))
        rev_w.append(torch.Tensor(rev_ar.value).view(list(rev_ar.dim)))

    # Stack the gate matrices row-wise, then split the columns into the
    # input-side part (first ninput + 1 columns -- presumably a bias column
    # plus the inputs; TODO confirm against the pyrnn layout) and the
    # recurrent part.
    t = torch.cat(fwd_w[:4])
    weight_ih_l0 = t[:, :ninput+1]
    weight_hh_l0 = t[:, ninput+1:]

    t = torch.cat(rev_w[:4])
    weight_ih_l0_rev = t[:, :ninput+1]
    weight_hh_l0_rev = t[:, ninput+1:]

    weight_lin = torch.Tensor(net.softmax.w2.value).view(list(net.softmax.w2.dim))

    # build vgsl spec and set weights
    nn = cls('[1,1,0,{} Lbxo{} O1ca{}]'.format(ninput, hidden, len(net.codec)))

    nn.nn.L_0.layer.weight_ih_l0 = torch.nn.Parameter(weight_ih_l0)
    nn.nn.L_0.layer.weight_hh_l0 = torch.nn.Parameter(weight_hh_l0)
    nn.nn.L_0.layer.weight_ih_l0_reverse = torch.nn.Parameter(weight_ih_l0_rev)
    nn.nn.L_0.layer.weight_hh_l0_reverse = torch.nn.Parameter(weight_hh_l0_rev)
    # peephole weights (wip, wfp, wop) per direction
    nn.nn.L_0.layer.weight_ip_l0 = torch.nn.Parameter(fwd_w[4])
    nn.nn.L_0.layer.weight_fp_l0 = torch.nn.Parameter(fwd_w[5])
    nn.nn.L_0.layer.weight_op_l0 = torch.nn.Parameter(fwd_w[6])
    nn.nn.L_0.layer.weight_ip_l0_reverse = torch.nn.Parameter(rev_w[4])
    nn.nn.L_0.layer.weight_fp_l0_reverse = torch.nn.Parameter(rev_w[5])
    nn.nn.L_0.layer.weight_op_l0_reverse = torch.nn.Parameter(rev_w[6])

    nn.nn.O_1.lin.weight = torch.nn.Parameter(weight_lin)

    nn.add_codec(codec)

    return nn
def load_pronn(fname):
    """
    Loads a legacy pyrnn model in protobuf format and instantiates a
    kraken.lib.lstm.SeqRecognizer object.

    Args:
        fname (str): Path to the protobuf file

    Returns:
        A kraken.lib.lstm.SeqRecognizer object

    Raises:
        KrakenInvalidModelException: if the file does not contain a valid
            pyrnn protobuf message or the message is incomplete.
    """
    logger.info(u'Trying to load prornn model from {}'.format(fname))
    with open(fname, 'rb') as fp:
        logger.debug(u'Initializing protobuf message')
        proto = pyrnn_pb2.pyrnn()
        try:
            proto.ParseFromString(fp.read())
        except Exception as err:
            # Narrower than a bare except: KeyboardInterrupt/SystemExit must
            # not be reported as an invalid model. Chain the original cause.
            logger.debug(u'File does not contain valid proto msg')
            raise KrakenInvalidModelException(
                'File does not contain valid proto msg') from err
        if not proto.IsInitialized():
            logger.debug(u'Message in file incomplete')
            raise KrakenInvalidModelException('Model incomplete')

    # extract codec
    logger.debug(u'Extracting codec')
    codec = kraken.lib.lstm.Codec().init(proto.codec)
    # number of rows of the forward input gate matrix == hidden size
    hiddensize = proto.fwdnet.wgi.dim[0]

    # next build a line estimator
    logger.debug(u'Add line estimator')
    lnorm = kraken.lib.lineest.CenterNormalizer(proto.ninput)
    network = kraken.lib.lstm.SeqRecognizer(
        lnorm.target_height, hiddensize, codec=codec,
        normalize=kraken.lib.lstm.normalize_nfkc)

    logger.debug(u'Setting weights on BIDILSTM')
    parallel, softmax = network.lstm.nets
    fwdnet, revnet = parallel.nets
    # The reverse branch wraps the actual LSTM in its `.net` attribute.
    revnet = revnet.net
    for w in ('WGI', 'WGF', 'WGO', 'WCI', 'WIP', 'WFP', 'WOP'):
        fwd_ar = getattr(proto.fwdnet, w.lower())
        rev_ar = getattr(proto.revnet, w.lower())
        # Restore each matrix from its flattened values and stored shape.
        setattr(fwdnet, w, numpy.array(fwd_ar.value).reshape(fwd_ar.dim))
        setattr(revnet, w, numpy.array(rev_ar.value).reshape(rev_ar.dim))
    softmax.W2 = numpy.array(proto.softmax.w2.value).reshape(
        proto.softmax.w2.dim)
    return network
def load_pronn(fname):
    """
    Loads a legacy pyrnn model in protobuf format and instantiates a
    kraken.lib.lstm.SeqRecognizer object.

    Args:
        fname (unicode): Path to the protobuf model file

    Returns:
        A kraken.lib.lstm.SeqRecognizer object

    Raises:
        KrakenInvalidModelException: if the file does not contain a valid
            pyrnn protobuf message or the message is incomplete.
    """
    with open(fname, 'rb') as fp:
        proto = pyrnn_pb2.pyrnn()
        try:
            proto.ParseFromString(fp.read())
        except Exception as err:
            # Narrower than a bare except: KeyboardInterrupt/SystemExit must
            # not be reported as an invalid model. Chain the original cause.
            raise KrakenInvalidModelException(
                'File does not contain valid proto msg') from err
        if not proto.IsInitialized():
            raise KrakenInvalidModelException('Model incomplete')

    # extract codec
    codec = kraken.lib.lstm.Codec().init(proto.codec)
    # number of rows of the forward input gate matrix == hidden size
    hiddensize = proto.fwdnet.wgi.dim[0]

    # next build a line estimator
    lnorm = kraken.lib.lineest.CenterNormalizer(proto.ninput)
    network = kraken.lib.lstm.SeqRecognizer(
        lnorm.target_height, hiddensize, codec=codec,
        normalize=kraken.lib.lstm.normalize_nfkc)

    parallel, softmax = network.lstm.nets
    fwdnet, revnet = parallel.nets
    # The reverse branch wraps the actual LSTM in its `.net` attribute.
    revnet = revnet.net
    for w in ('WGI', 'WGF', 'WGO', 'WCI', 'WIP', 'WFP', 'WOP'):
        fwd_ar = getattr(proto.fwdnet, w.lower())
        rev_ar = getattr(proto.revnet, w.lower())
        # Restore each matrix from its flattened values and stored shape.
        setattr(fwdnet, w, numpy.array(fwd_ar.value).reshape(fwd_ar.dim))
        setattr(revnet, w, numpy.array(rev_ar.value).reshape(rev_ar.dim))
    softmax.W2 = numpy.array(proto.softmax.w2.value).reshape(
        proto.softmax.w2.dim)
    return network
def load_pronn(fname):
    """
    Loads a legacy pyrnn model in protobuf format and instantiates a
    kraken.lib.lstm.SeqRecognizer object.

    Args:
        fname (unicode): Path to the protobuf model file

    Returns:
        A kraken.lib.lstm.SeqRecognizer object

    Raises:
        KrakenInvalidModelException: if the file does not contain a valid
            pyrnn protobuf message or the message is incomplete.
    """
    with open(fname, 'rb') as fp:
        proto = pyrnn_pb2.pyrnn()
        try:
            proto.ParseFromString(fp.read())
        except Exception as err:
            # Narrower than a bare except: KeyboardInterrupt/SystemExit must
            # not be reported as an invalid model. Chain the original cause.
            raise KrakenInvalidModelException('File does not contain valid proto msg') from err
        if not proto.IsInitialized():
            raise KrakenInvalidModelException('Model incomplete')

    # extract codec
    codec = kraken.lib.lstm.Codec().init(proto.codec)
    # number of rows of the forward input gate matrix == hidden size
    hiddensize = proto.fwdnet.wgi.dim[0]

    # next build a line estimator
    lnorm = kraken.lib.lineest.CenterNormalizer(proto.ninput)
    network = kraken.lib.lstm.SeqRecognizer(lnorm.target_height,
                                            hiddensize,
                                            codec=codec,
                                            normalize=kraken.lib.lstm.normalize_nfkc)

    parallel, softmax = network.lstm.nets
    fwdnet, revnet = parallel.nets
    # The reverse branch wraps the actual LSTM in its `.net` attribute.
    revnet = revnet.net
    for w in ('WGI', 'WGF', 'WGO', 'WCI', 'WIP', 'WFP', 'WOP'):
        fwd_ar = getattr(proto.fwdnet, w.lower())
        rev_ar = getattr(proto.revnet, w.lower())
        # Restore each matrix from its flattened values and stored shape.
        setattr(fwdnet, w, numpy.array(fwd_ar.value).reshape(fwd_ar.dim))
        setattr(revnet, w, numpy.array(rev_ar.value).reshape(rev_ar.dim))
    softmax.W2 = numpy.array(proto.softmax.w2.value).reshape(proto.softmax.w2.dim)
    return network