Ejemplo n.º 1
0
def load(model: str = 'xlnet', pool_mode: str = 'last', **kwargs):
    """
    Load xlnet model.

    Parameters
    ----------
    model : str, optional (default='xlnet')
        Model architecture supported. Allowed values:

        * ``'xlnet'`` - XLNET architecture from google.
    pool_mode : str, optional (default='last')
        Model logits architecture supported. Allowed values:

        * ``'last'`` - last of the sequence.
        * ``'first'`` - first of the sequence.
        * ``'mean'`` - mean of the sequence.
        * ``'attn'`` - attention of the sequence.
    **kwargs
        Forwarded unchanged to both ``check_file`` and ``Model``.

    Returns
    -------
    result : malaya.transformers.xlnet.Model class

    Raises
    ------
    ValueError
        If ``pool_mode`` is not one of the supported values.
    """

    # Both arguments are matched case-insensitively.
    model = model.lower()
    pool_mode = pool_mode.lower()

    # Validate first so a bad argument fails before any file is touched.
    # ValueError is still caught by callers that catch ``Exception``.
    if pool_mode not in ('last', 'first', 'mean', 'attn'):
        raise ValueError(
            "pool_mode not supported, only support ['last', 'first', 'mean', 'attn']"
        )

    # Look the per-model settings up once instead of re-indexing PATH_XLNET.
    settings = PATH_XLNET[model]
    path = check_file(settings['model'], S3_PATH_XLNET[model], **kwargs)

    directory = settings['directory']
    xlnet_checkpoint = os.path.join(directory, 'model.ckpt')

    # Unpack the archive only when the checkpoint is not on disk yet.
    if not os.path.exists(xlnet_checkpoint):
        import tarfile

        # NOTE(review): extractall trusts the member paths stored in the
        # archive; consider a safe extraction filter if the archive source
        # is not fully trusted.
        with tarfile.open(path['model']) as tar:
            tar.extractall(path=settings['path'])

    vocab_model = os.path.join(directory, 'sp10m.cased.v9.model')
    vocab = os.path.join(directory, 'sp10m.cased.v9.vocab')
    tokenizer = SentencePieceTokenizer(vocab_file=vocab,
                                       spm_model_file=vocab_model)
    xlnet_config = xlnet_lib.XLNetConfig(
        json_path=os.path.join(directory, 'config.json'))

    # Use a dedicated name for the instance rather than rebinding the
    # ``model`` string parameter.
    xlnet_model = Model(xlnet_config,
                        tokenizer,
                        xlnet_checkpoint,
                        pool_mode=pool_mode,
                        **kwargs)
    # Restore the checkpoint into the session owned by the new instance.
    xlnet_model._saver.restore(xlnet_model._sess, xlnet_checkpoint)
    return xlnet_model
Ejemplo n.º 2
0
def load(model: str = 'xlnet', pool_mode: str = 'last', **kwargs):
    """
    Load xlnet model.

    Parameters
    ----------
    model : str, optional (default='xlnet')
        Model architecture supported. Allowed values:

        * ``'xlnet'`` - XLNET architecture from google.
    pool_mode : str, optional (default='last')
        Model logits architecture supported. Allowed values:

        * ``'last'`` - last of the sequence.
        * ``'first'`` - first of the sequence.
        * ``'mean'`` - mean of the sequence.
        * ``'attn'`` - attention of the sequence.
    **kwargs
        Forwarded unchanged to both ``check_file`` and ``Model``.

    Returns
    -------
    result : malaya.transformers.xlnet.Model class

    Raises
    ------
    ValueError
        If ``pool_mode`` is not one of the supported values.
    """

    # Both arguments are matched case-insensitively.
    model = model.lower()
    pool_mode = pool_mode.lower()

    # Validate before the project imports so a bad argument fails fast.
    # ValueError is still caught by callers that catch ``Exception``.
    if pool_mode not in ('last', 'first', 'mean', 'attn'):
        raise ValueError(
            "pool_mode not supported, only support ['last', 'first', 'mean', 'attn']"
        )

    from malaya.path import PATH_XLNET, S3_PATH_XLNET
    from malaya.function import check_file

    # Look the per-model settings up once instead of re-indexing PATH_XLNET.
    settings = PATH_XLNET[model]
    check_file(settings['model'], S3_PATH_XLNET[model], **kwargs)

    # os.path.join gives the same result as plain concatenation when the
    # directory already ends with a separator, and a correct path when it
    # does not.
    directory = settings['directory']
    xlnet_checkpoint = os.path.join(directory, 'model.ckpt')

    # Unpack the archive only when the checkpoint is not on disk yet.
    if not os.path.exists(xlnet_checkpoint):
        import tarfile

        # NOTE(review): extractall trusts the member paths stored in the
        # archive; consider a safe extraction filter if the archive source
        # is not fully trusted.
        with tarfile.open(settings['model']['model']) as tar:
            tar.extractall(path=settings['path'])

    import sentencepiece as spm

    sp_model = spm.SentencePieceProcessor()
    sp_model.Load(os.path.join(directory, 'sp10m.cased.v9.model'))
    xlnet_config = xlnet_lib.XLNetConfig(
        json_path=os.path.join(directory, 'config.json'))

    # Use a dedicated name for the instance rather than rebinding the
    # ``model`` string parameter.
    xlnet_model = Model(xlnet_config,
                        sp_model,
                        xlnet_checkpoint,
                        pool_mode=pool_mode,
                        **kwargs)
    # Restore the checkpoint into the session owned by the new instance.
    xlnet_model._saver.restore(xlnet_model._sess, xlnet_checkpoint)
    return xlnet_model