Example #1
0
def gpt2(
    model: str = '345M',
    generate_length: int = 256,
    temperature: float = 1.0,
    top_k: int = 40,
    **kwargs,
):

    """
    Load GPT2 model to generate a string given a prefix string.

    Parameters
    ----------
    model : str, optional (default='345M')
        Model architecture supported (case-insensitive). Allowed values:

        * ``'117M'`` - GPT2 117M parameters.
        * ``'345M'`` - GPT2 345M parameters.

    generate_length : int, optional (default=256)
        Length of sentence to generate; must be at least 10.
    temperature : float, optional (default=1.0)
        Sampling temperature; must satisfy 0 < temperature <= 1.0.
    top_k : int, optional (default=40)
        k for top-k sampling selection; must be at least 5.
    **kwargs
        Forwarded unchanged to ``malaya.transformers.gpt2.load``.

    Returns
    -------
    result: malaya.transformers.gpt2.Model class

    Raises
    ------
    ValueError
        If ``model`` is not a supported architecture, or if
        ``generate_length``, ``temperature`` or ``top_k`` is out of range.
    """

    # Normalise case so e.g. '345m' is accepted as '345M'.
    model = model.upper()
    if model not in _gpt2_availability:
        raise ValueError(
            'model not supported, please check supported models from `malaya.generator.available_gpt2()`.'
        )

    # The bounds are inclusive (10 and 5 are accepted), so the messages say "at least".
    if generate_length < 10:
        raise ValueError('generate_length must be at least 10')
    if not 0 < temperature <= 1.0:
        raise ValueError('temperature must satisfy 0 < temperature <= 1.0')
    if top_k < 5:
        raise ValueError('top_k must be at least 5')

    # Imported lazily so the heavy transformer dependencies are only loaded
    # once the arguments are known to be valid.
    from malaya.transformers.gpt2 import load

    # disable_eager_execution is a global TensorFlow setting, so tell the
    # caller before switching it off.
    if tf.executing_eagerly():
        logging.warning(
            'Load pretrained GPT2 model will disable eager execution.'
        )
        tf.compat.v1.disable_eager_execution()

    return load(
        model = model,
        generate_length = generate_length,
        temperature = temperature,
        top_k = top_k,
        **kwargs,
    )
Example #2
0
def gpt2(
    model: str = '345M',
    generate_length: int = 256,
    temperature: float = 1.0,
    top_k: int = 40,
    **kwargs
):

    """
    Load GPT2 model to generate a string given a prefix string.

    Parameters
    ----------
    model : str, optional (default='345M')
        Model architecture supported (case-insensitive). Allowed values:

        * ``'117M'`` - GPT2 117M parameters.
        * ``'345M'`` - GPT2 345M parameters.

    generate_length : int, optional (default=256)
        Length of sentence to generate; must be at least 10.

    temperature : float, optional (default=1.0)
        Sampling temperature; must satisfy 0 < temperature <= 1.0.

    top_k : int, optional (default=40)
        k for top-k sampling selection; must be at least 5.

    **kwargs
        Forwarded unchanged to ``malaya.transformers.gpt2.load``.

    Returns
    -------
    result: malaya.transformers.gpt2.Model class

    Raises
    ------
    ValueError
        If ``model`` is not a supported architecture, or if
        ``generate_length``, ``temperature`` or ``top_k`` is out of range.
        (ValueError subclasses Exception, so existing ``except Exception``
        handlers still catch it.)
    """

    supported_models = ('117M', '345M')

    # Normalise case so e.g. '345m' is accepted as '345M'.
    model = model.upper()
    if model not in supported_models:
        raise ValueError(
            "model not supported, for now only supported ['117M', '345M']"
        )

    # The bounds are inclusive (10 and 5 are accepted), so the messages say "at least".
    if generate_length < 10:
        raise ValueError('generate_length must be at least 10')
    if not 0 < temperature <= 1.0:
        raise ValueError('temperature must satisfy 0 < temperature <= 1.0')
    if top_k < 5:
        raise ValueError('top_k must be at least 5')

    # Imported lazily so the heavy transformer dependencies are only loaded
    # once the arguments are known to be valid.
    from malaya.transformers.gpt2 import load

    return load(
        model = model,
        generate_length = generate_length,
        temperature = temperature,
        top_k = top_k,
        **kwargs
    )