Beispiel #1
0
def pyrnn_to_pronn(pyrnn=None, output='en-default.pronn'):
    """
    Converts a legacy python RNN to the new protobuf format. Benefits of the
    new format include independence from particular python versions and no
    arbitrary code execution issues inherent in pickle.

    Args:
        pyrnn (kraken.lib.lstm.SeqRecognizer): pyrnn model to convert
        output (unicode): path the converted protobuf (pronn) model is
                          written to

    Raises:
        ValueError: if no model is passed.
    """
    if pyrnn is None:
        raise ValueError('pyrnn_to_pronn requires a pyrnn model to convert')

    def _store(dst, array):
        # Copy an array into a protobuf array message: its shape goes into
        # `dim`, its values (flattened with reshape(-1)) into `value`.
        dst.dim.extend(array.shape)
        dst.value.extend(array.reshape(-1).tolist())

    proto = pyrnn_pb2.pyrnn()
    proto.kind = 'pyrnn-bidi'
    proto.ninput = pyrnn.Ni
    proto.noutput = pyrnn.No
    # NOTE(review): the codec is stored in code2char's iteration order;
    # confirm that order matches the code indices the loader expects.
    proto.codec.extend(pyrnn.codec.code2char.values())

    parallel, softmax = pyrnn.lstm.nets
    fwdnet, revnet = parallel.nets
    # the reverse direction object exposes its LSTM weights via `.net`
    revnet = revnet.net
    for w in ('WGI', 'WGF', 'WGO', 'WCI', 'WIP', 'WFP', 'WOP'):
        _store(getattr(proto.fwdnet, w.lower()), getattr(fwdnet, w))
        _store(getattr(proto.revnet, w.lower()), getattr(revnet, w))
    _store(proto.softmax.w2, softmax.W2)

    # Serialize before opening the output so a failing serialization does not
    # leave an empty/truncated model file behind.
    data = proto.SerializeToString()
    with open(output, 'wb') as fp:
        fp.write(data)
Beispiel #2
0
def pyrnn_to_pronn(pyrnn=None, output='en-default.pronn'):
    """
    Converts a legacy python RNN to the new protobuf format. Benefits of the
    new format include independence from particular python versions and no
    arbitrary code execution issues inherent in pickle.

    Args:
        pyrnn (kraken.lib.lstm.SeqRecognizer): pyrnn model to convert
        output (unicode): path the converted protobuf (pronn) model is
                          written to

    Raises:
        ValueError: if no model is passed.
    """
    if pyrnn is None:
        raise ValueError('pyrnn_to_pronn requires a pyrnn model to convert')

    def _store(dst, array):
        # Copy an array into a protobuf array message: its shape goes into
        # `dim`, its values (flattened with reshape(-1)) into `value`.
        dst.dim.extend(array.shape)
        dst.value.extend(array.reshape(-1).tolist())

    proto = pyrnn_pb2.pyrnn()
    proto.kind = 'pyrnn-bidi'
    proto.ninput = pyrnn.Ni
    proto.noutput = pyrnn.No
    # NOTE(review): the codec is stored in code2char's iteration order;
    # confirm that order matches the code indices the loader expects.
    proto.codec.extend(pyrnn.codec.code2char.values())

    parallel, softmax = pyrnn.lstm.nets
    fwdnet, revnet = parallel.nets
    # the reverse direction object exposes its LSTM weights via `.net`
    revnet = revnet.net
    for w in ('WGI', 'WGF', 'WGO', 'WCI', 'WIP', 'WFP', 'WOP'):
        _store(getattr(proto.fwdnet, w.lower()), getattr(fwdnet, w))
        _store(getattr(proto.revnet, w.lower()), getattr(revnet, w))
    _store(proto.softmax.w2, softmax.W2)

    # Serialize before opening the output so a failing serialization does not
    # leave an empty/truncated model file behind.
    data = proto.SerializeToString()
    with open(output, 'wb') as fp:
        fp.write(data)
Beispiel #3
0
    def load_pronn_model(cls, path: str):
        """
        Loads a pronn model (legacy pyrnn network serialized as protobuf)
        into a VGSL model.

        Args:
            path: path to the protobuf model file.

        Returns:
            An instance of ``cls`` built from the spec
            ``[1,1,0,<ninput> Lbxo<hidden> O1ca<codec size>]`` with the
            converted bidirectional LSTM and output layer weights set and the
            model's codec attached.

        Raises:
            KrakenInvalidModelException: if the file does not contain a valid
                protobuf message or the message is incomplete.
        """
        with open(path, 'rb') as fp:
            net = pyrnn_pb2.pyrnn()
            try:
                net.ParseFromString(fp.read())
            except Exception as e:
                raise KrakenInvalidModelException('File does not contain valid proto msg') from e
        if not net.IsInitialized():
            raise KrakenInvalidModelException('Model incomplete')

        # extract codec
        codec = PytorchCodec(net.codec)

        ninput = net.ninput
        # hidden size is the first dimension of the forward `wgi` matrix
        hidden = net.fwdnet.wgi.dim[0]

        def _to_tensor(ar):
            # Rebuild a tensor from a protobuf array message: the flat `value`
            # entries viewed with the stored `dim` sizes.
            return torch.Tensor(ar.value).view(list(ar.dim))

        def _split(gates):
            # Stack the gate matrices row-wise, then split the columns into
            # the input-to-hidden part (first ninput + 1 columns; presumably
            # the input plus a bias column, TODO confirm) and the recurrent
            # hidden-to-hidden part.
            t = torch.cat(gates)
            return t[:, :ninput + 1], t[:, ninput + 1:]

        # extract weights; the first four are concatenated into the gate
        # matrices in this order, the last three are assigned to the layer's
        # weight_{i,f,o}p parameters (presumably peephole connections)
        weightnames = ('wgi', 'wgf', 'wci', 'wgo', 'wip', 'wfp', 'wop')
        fwd_w = [_to_tensor(getattr(net.fwdnet, w)) for w in weightnames]
        rev_w = [_to_tensor(getattr(net.revnet, w)) for w in weightnames]

        weight_ih_l0, weight_hh_l0 = _split(fwd_w[:4])
        weight_ih_l0_rev, weight_hh_l0_rev = _split(rev_w[:4])

        weight_lin = _to_tensor(net.softmax.w2)

        # build vgsl spec and set weights
        model = cls('[1,1,0,{} Lbxo{} O1ca{}]'.format(ninput, hidden, len(net.codec)))

        layer = model.nn.L_0.layer
        layer.weight_ih_l0 = torch.nn.Parameter(weight_ih_l0)
        layer.weight_hh_l0 = torch.nn.Parameter(weight_hh_l0)
        layer.weight_ih_l0_reverse = torch.nn.Parameter(weight_ih_l0_rev)
        layer.weight_hh_l0_reverse = torch.nn.Parameter(weight_hh_l0_rev)
        layer.weight_ip_l0 = torch.nn.Parameter(fwd_w[4])
        layer.weight_fp_l0 = torch.nn.Parameter(fwd_w[5])
        layer.weight_op_l0 = torch.nn.Parameter(fwd_w[6])
        layer.weight_ip_l0_reverse = torch.nn.Parameter(rev_w[4])
        layer.weight_fp_l0_reverse = torch.nn.Parameter(rev_w[5])
        layer.weight_op_l0_reverse = torch.nn.Parameter(rev_w[6])

        model.nn.O_1.lin.weight = torch.nn.Parameter(weight_lin)

        model.add_codec(codec)

        return model
Beispiel #4
0
    def load_pronn_model(cls, path: str):
        """
        Loads a pronn model (legacy pyrnn network serialized as protobuf)
        into a VGSL model.

        Args:
            path: path to the protobuf model file.

        Returns:
            An instance of ``cls`` built from the spec
            ``[1,1,0,<ninput> Lbxo<hidden> O1ca<codec size>]`` with the
            converted bidirectional LSTM and output layer weights set and the
            model's codec attached.

        Raises:
            KrakenInvalidModelException: if the file does not contain a valid
                protobuf message or the message is incomplete.
        """
        with open(path, 'rb') as fp:
            net = pyrnn_pb2.pyrnn()
            try:
                net.ParseFromString(fp.read())
            except Exception as e:
                raise KrakenInvalidModelException('File does not contain valid proto msg') from e
        if not net.IsInitialized():
            raise KrakenInvalidModelException('Model incomplete')

        # extract codec
        codec = PytorchCodec(net.codec)

        ninput = net.ninput
        # hidden size is the first dimension of the forward `wgi` matrix
        hidden = net.fwdnet.wgi.dim[0]

        def _to_tensor(ar):
            # Rebuild a tensor from a protobuf array message: the flat `value`
            # entries viewed with the stored `dim` sizes.
            return torch.Tensor(ar.value).view(list(ar.dim))

        def _split(gates):
            # Stack the gate matrices row-wise, then split the columns into
            # the input-to-hidden part (first ninput + 1 columns; presumably
            # the input plus a bias column, TODO confirm) and the recurrent
            # hidden-to-hidden part.
            t = torch.cat(gates)
            return t[:, :ninput + 1], t[:, ninput + 1:]

        # extract weights; the first four are concatenated into the gate
        # matrices in this order, the last three are assigned to the layer's
        # weight_{i,f,o}p parameters (presumably peephole connections)
        weightnames = ('wgi', 'wgf', 'wci', 'wgo', 'wip', 'wfp', 'wop')
        fwd_w = [_to_tensor(getattr(net.fwdnet, w)) for w in weightnames]
        rev_w = [_to_tensor(getattr(net.revnet, w)) for w in weightnames]

        weight_ih_l0, weight_hh_l0 = _split(fwd_w[:4])
        weight_ih_l0_rev, weight_hh_l0_rev = _split(rev_w[:4])

        weight_lin = _to_tensor(net.softmax.w2)

        # build vgsl spec and set weights
        model = cls('[1,1,0,{} Lbxo{} O1ca{}]'.format(ninput, hidden, len(net.codec)))

        layer = model.nn.L_0.layer
        layer.weight_ih_l0 = torch.nn.Parameter(weight_ih_l0)
        layer.weight_hh_l0 = torch.nn.Parameter(weight_hh_l0)
        layer.weight_ih_l0_reverse = torch.nn.Parameter(weight_ih_l0_rev)
        layer.weight_hh_l0_reverse = torch.nn.Parameter(weight_hh_l0_rev)
        layer.weight_ip_l0 = torch.nn.Parameter(fwd_w[4])
        layer.weight_fp_l0 = torch.nn.Parameter(fwd_w[5])
        layer.weight_op_l0 = torch.nn.Parameter(fwd_w[6])
        layer.weight_ip_l0_reverse = torch.nn.Parameter(rev_w[4])
        layer.weight_fp_l0_reverse = torch.nn.Parameter(rev_w[5])
        layer.weight_op_l0_reverse = torch.nn.Parameter(rev_w[6])

        model.nn.O_1.lin.weight = torch.nn.Parameter(weight_lin)

        model.add_codec(codec)

        return model
Beispiel #5
0
def load_pronn(fname):
    """
    Loads a legacy pyrnn model in protobuf format and instantiates a
    kraken.lib.lstm.SeqRecognizer object.

    Args:
        fname (str): Path to the protobuf file

    Returns:
        A kraken.lib.lstm.SeqRecognizer object

    Raises:
        KrakenInvalidModelException: if the file does not contain a valid
            protobuf message or the message is incomplete.
    """
    def _to_array(ar):
        # Rebuild an array from a protobuf array message: the flat `value`
        # entries reshaped to the stored `dim` sizes.
        return numpy.array(ar.value).reshape(tuple(ar.dim))

    logger.info(u'Trying to load pronn model from %s', fname)
    with open(fname, 'rb') as fp:
        logger.debug(u'Initializing protobuf message')
        proto = pyrnn_pb2.pyrnn()
        try:
            proto.ParseFromString(fp.read())
        except Exception:
            # deliberately not a bare `except:` so KeyboardInterrupt and
            # SystemExit are not turned into an invalid-model error
            logger.debug(u'File does not contain valid proto msg')
            raise KrakenInvalidModelException(
                'File does not contain valid proto msg')
    # the file is no longer needed once the message has been parsed
    if not proto.IsInitialized():
        logger.debug(u'Message in file incomplete')
        raise KrakenInvalidModelException('Model incomplete')
    # extract codec
    logger.debug(u'Extracting codec')
    codec = kraken.lib.lstm.Codec().init(proto.codec)
    hiddensize = proto.fwdnet.wgi.dim[0]
    # next build a line estimator
    logger.debug(u'Add line estimator')
    lnorm = kraken.lib.lineest.CenterNormalizer(proto.ninput)
    network = kraken.lib.lstm.SeqRecognizer(
        lnorm.target_height,
        hiddensize,
        codec=codec,
        normalize=kraken.lib.lstm.normalize_nfkc)
    logger.debug(u'Setting weights on BIDILSTM')
    parallel, softmax = network.lstm.nets
    fwdnet, revnet = parallel.nets
    # the reverse direction object exposes its LSTM weights via `.net`
    revnet = revnet.net
    for w in ('WGI', 'WGF', 'WGO', 'WCI', 'WIP', 'WFP', 'WOP'):
        setattr(fwdnet, w, _to_array(getattr(proto.fwdnet, w.lower())))
        setattr(revnet, w, _to_array(getattr(proto.revnet, w.lower())))
    softmax.W2 = _to_array(proto.softmax.w2)
    return network
Beispiel #6
0
def load_pronn(fname):
    """
    Loads a legacy pyrnn model in protobuf format and instantiates a
    kraken.lib.lstm.SeqRecognizer object.

    Args:
        fname (unicode): Path to the protobuf (pronn) file

    Returns:
        A kraken.lib.lstm.SeqRecognizer object

    Raises:
        KrakenInvalidModelException: if the file does not contain a valid
            protobuf message or the message is incomplete.
    """
    def _to_array(ar):
        # Rebuild an array from a protobuf array message: the flat `value`
        # entries reshaped to the stored `dim` sizes.
        return numpy.array(ar.value).reshape(tuple(ar.dim))

    with open(fname, 'rb') as fp:
        proto = pyrnn_pb2.pyrnn()
        try:
            proto.ParseFromString(fp.read())
        except Exception:
            # deliberately not a bare `except:` so KeyboardInterrupt and
            # SystemExit are not turned into an invalid-model error
            raise KrakenInvalidModelException(
                'File does not contain valid proto msg')
    # the file is no longer needed once the message has been parsed
    if not proto.IsInitialized():
        raise KrakenInvalidModelException('Model incomplete')
    # extract codec
    codec = kraken.lib.lstm.Codec().init(proto.codec)
    hiddensize = proto.fwdnet.wgi.dim[0]
    # next build a line estimator
    lnorm = kraken.lib.lineest.CenterNormalizer(proto.ninput)
    network = kraken.lib.lstm.SeqRecognizer(
        lnorm.target_height,
        hiddensize,
        codec=codec,
        normalize=kraken.lib.lstm.normalize_nfkc)
    parallel, softmax = network.lstm.nets
    fwdnet, revnet = parallel.nets
    # the reverse direction object exposes its LSTM weights via `.net`
    revnet = revnet.net
    for w in ('WGI', 'WGF', 'WGO', 'WCI', 'WIP', 'WFP', 'WOP'):
        setattr(fwdnet, w, _to_array(getattr(proto.fwdnet, w.lower())))
        setattr(revnet, w, _to_array(getattr(proto.revnet, w.lower())))
    softmax.W2 = _to_array(proto.softmax.w2)
    return network
Beispiel #7
0
def load_pronn(fname):
    """
    Loads a legacy pyrnn model in protobuf format and instantiates a
    kraken.lib.lstm.SeqRecognizer object.

    Args:
        fname (unicode): Path to the protobuf (pronn) file

    Returns:
        A kraken.lib.lstm.SeqRecognizer object

    Raises:
        KrakenInvalidModelException: if the file does not contain a valid
            protobuf message or the message is incomplete.
    """
    def _to_array(ar):
        # Rebuild an array from a protobuf array message: the flat `value`
        # entries reshaped to the stored `dim` sizes.
        return numpy.array(ar.value).reshape(tuple(ar.dim))

    with open(fname, 'rb') as fp:
        proto = pyrnn_pb2.pyrnn()
        try:
            proto.ParseFromString(fp.read())
        except Exception:
            # deliberately not a bare `except:` so KeyboardInterrupt and
            # SystemExit are not turned into an invalid-model error
            raise KrakenInvalidModelException('File does not contain valid proto msg')
    # the file is no longer needed once the message has been parsed
    if not proto.IsInitialized():
        raise KrakenInvalidModelException('Model incomplete')
    # extract codec
    codec = kraken.lib.lstm.Codec().init(proto.codec)
    hiddensize = proto.fwdnet.wgi.dim[0]
    # next build a line estimator
    lnorm = kraken.lib.lineest.CenterNormalizer(proto.ninput)
    network = kraken.lib.lstm.SeqRecognizer(lnorm.target_height,
                                            hiddensize,
                                            codec=codec,
                                            normalize=kraken.lib.lstm.normalize_nfkc)
    parallel, softmax = network.lstm.nets
    fwdnet, revnet = parallel.nets
    # the reverse direction object exposes its LSTM weights via `.net`
    revnet = revnet.net
    for w in ('WGI', 'WGF', 'WGO', 'WCI', 'WIP', 'WFP', 'WOP'):
        setattr(fwdnet, w, _to_array(getattr(proto.fwdnet, w.lower())))
        setattr(revnet, w, _to_array(getattr(proto.revnet, w.lower())))
    softmax.W2 = _to_array(proto.softmax.w2)
    return network