コード例 #1
0
ファイル: test_models_gpt2.py プロジェクト: zheyuye/gluon-nlp
def test_gpt2(model_name, ctx):
    """Smoke-test a pretrained GPT-2 checkpoint end to end.

    Fetches the artifacts for ``model_name`` into a temporary directory,
    loads both the backbone and the language-model parameters, then runs one
    forward pass and one recorded forward/backward pass on random token ids.

    Parameters
    ----------
    model_name
        Name handed to ``get_pretrained_gpt2``.
    ctx
        Device context; entered as a context manager and also used to place
        the input array and the initial states.
    """
    # At least one pretrained model must be registered.
    assert len(list_pretrained_gpt2()) > 0

    with tempfile.TemporaryDirectory() as root, ctx:
        pretrained = get_pretrained_gpt2(model_name,
                                         load_backbone=True,
                                         load_lm=True,
                                         root=root)
        cfg, tokenizer, params_path, lm_params_path = pretrained
        vocab_size = len(tokenizer.vocab)
        # The config and the tokenizer must agree on the vocabulary size.
        assert cfg.MODEL.vocab_size == vocab_size

        # Backbone: build from the config and load its weights.
        backbone = GPT2Model.from_cfg(cfg)
        backbone.load_parameters(params_path)

        # Language-model head: build from the config and load its weights.
        lm_net = GPT2ForLM(cfg)
        lm_net.load_parameters(lm_params_path)

        # Random token ids drawn from [2, vocab_size).
        batch_size, seq_length = 3, 32
        random_ids = np.random.randint(2, vocab_size, (batch_size, seq_length))
        input_ids = mx.np.array(random_ids, dtype=np.int32, ctx=ctx)

        # Forward pass only.
        initial_states = lm_net.init_states(batch_size, ctx)
        logits, _ = lm_net(input_ids, initial_states)
        mx.npx.waitall()

        # Forward + backward pass, using the inputs themselves as labels.
        criterion = LabelSmoothCrossEntropyLoss(num_labels=vocab_size)
        with mx.autograd.record():
            initial_states = lm_net.init_states(batch_size, ctx)
            logits, _ = lm_net(input_ids, initial_states)
            loss = criterion(logits, input_ids)
            loss.backward()
        mx.npx.waitall()
コード例 #2
0
def _parse_gpu_id(value):
    """Convert the raw ``--gpu`` command-line string into a device id.

    The ``--gpu`` help text tells users to "set None to use cpu", so the
    literal ``None`` (any letter case) is mapped to the ``None`` object.
    Every other value must be an integer, as with the previous ``type=int``.

    Raises
    ------
    argparse.ArgumentTypeError
        If ``value`` is neither ``None`` nor an integer literal. The message
        mirrors the one argparse emits for ``type=int``.
    """
    text = value.strip()
    if text.lower() == 'none':
        return None
    try:
        return int(text)
    except ValueError:
        raise argparse.ArgumentTypeError('invalid int value: %r' % (value,))


def parse_args():
    """Build the command-line parser for the GPT-2 unconditional sampler.

    Returns
    -------
    argparse.Namespace
        Parsed options: ``model_name``, ``seed``, ``nsamples``,
        ``batch_size``, ``length``, ``temperature``, ``top_k``, ``top_p``,
        ``gpu`` and ``layout``. ``gpu`` is an ``int`` or ``None``.
    """
    parser = argparse.ArgumentParser(
        description=
        'GPT-2 unconditional sampler. Load a GPT-2 model and sample.')
    parser.add_argument('--model_name',
                        type=str,
                        default='gpt2_124M',
                        choices=list_pretrained_gpt2(),
                        help='Model name')
    parser.add_argument('--seed',
                        type=int,
                        default=None,
                        help='The random seed')
    parser.add_argument('--nsamples',
                        type=int,
                        default=0,
                        help='Number of samples to return')
    parser.add_argument('--batch_size',
                        type=int,
                        default=1,
                        help='Number of batches')
    parser.add_argument(
        '--length',
        type=int,
        default=None,
        help='Number of tokens in generated text, if None (default), is '
        'determined by model max_length')
    parser.add_argument('--temperature', type=float, default=1.0, help='')
    # Adjacent string literals are concatenated verbatim, so each fragment
    # needs an explicit trailing space to keep the title and URL apart.
    parser.add_argument('--top_k',
                        type=int,
                        default=-1,
                        help='Multinomial sampling with topk, '
                        'see [ACL2018] "Hierarchical Neural Story Generation" '
                        'https://www.aclweb.org/anthology/P18-1082.pdf')
    parser.add_argument(
        '--top_p',
        type=float,
        default=-1.0,
        help='Multinomial sampling with topp, '
        'see [ICLR2020] "The Curious Case of Neural Text Degeneration" '
        'https://arxiv.org/abs/1904.09751')
    # A plain ``type=int`` rejects ``--gpu None``, which made the CPU option
    # described in the help text unreachable from the command line.
    # NOTE(review): assumes the caller treats ``args.gpu is None`` as "use
    # cpu", as the help text states -- confirm against the call site.
    parser.add_argument('--gpu',
                        type=_parse_gpu_id,
                        default=0,
                        help='Which gpu to use, set None to use cpu')
    parser.add_argument('--layout',
                        type=str,
                        choices=['NT', 'TN'],
                        default='NT',
                        help='Layout of the inference model')
    return parser.parse_args()
コード例 #3
0
def test_list_pretrained_gpt2():
    """Check that at least one pretrained GPT-2 model is listed."""
    available = list_pretrained_gpt2()
    assert len(available) > 0
コード例 #4
0
            hiddens, states = gpt2_model(
                inputs[:, i:i+1],
                states,
                mx.np.array(i, dtype=np.int32, ctx=ctx)
            )
            hiddens_l.append(hiddens)
        hiddens_concat = mx.np.concatenate(hiddens_l, axis=1)
        assert_allclose(one_time_states.asnumpy(),
                        states.asnumpy(), 1E-4, 1E-4)
        assert_allclose(one_time_hiddens.asnumpy(),
                        hiddens_concat.asnumpy(), 1E-4, 1E-4)


@pytest.mark.slow
@pytest.mark.remote_required
@pytest.mark.parametrize('model_name', list_pretrained_gpt2())
def test_gpt2(model_name, ctx):
    # test from pretrained
    assert len(list_pretrained_gpt2()) > 0
    with tempfile.TemporaryDirectory() as root, ctx:
        cfg, tokenizer, params_path, lm_params_path =\
            get_pretrained_gpt2(model_name, load_backbone=True, load_lm=True, root=root)
        assert cfg.MODEL.vocab_size == len(tokenizer.vocab)
        # test backbone
        gpt2_model = GPT2Model.from_cfg(cfg)
        gpt2_model.load_parameters(params_path)
        # test lm model
        gpt2_lm_model = GPT2ForLM(cfg)
        gpt2_lm_model.load_parameters(lm_params_path)

        # test forward