Example #1
def test_gpt2(model_name, ctx):
    # test from pretrained
    assert len(list_pretrained_gpt2()) > 0
    with tempfile.TemporaryDirectory() as root, ctx:
        cfg, tokenizer, params_path, lm_params_path =\
            get_pretrained_gpt2(model_name, load_backbone=True, load_lm=True, root=root)
        assert cfg.MODEL.vocab_size == len(tokenizer.vocab)
        # test backbone
        gpt2_model = GPT2Model.from_cfg(cfg)
        gpt2_model.load_parameters(params_path)
        # test lm model
        gpt2_lm_model = GPT2ForLM(cfg)
        gpt2_lm_model.load_parameters(lm_params_path)

        # test forward
        batch_size = 3
        seq_length = 32
        vocab_size = len(tokenizer.vocab)
        input_ids = mx.np.array(np.random.randint(2, vocab_size,
                                                  (batch_size, seq_length)),
                                dtype=np.int32,
                                ctx=ctx)
        logits, _ = gpt2_lm_model(input_ids,
                                  gpt2_lm_model.init_states(batch_size, ctx))
        mx.npx.waitall()
        # test backward
        label_smooth_loss = LabelSmoothCrossEntropyLoss(num_labels=vocab_size)
        with mx.autograd.record():
            logits, _ = gpt2_lm_model(
                input_ids, gpt2_lm_model.init_states(batch_size, ctx))
            loss = label_smooth_loss(logits, input_ids)
            loss.backward()
        mx.npx.waitall()
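
test_gpt2 is parametrized over the model name and the device context. A minimal sketch of how it might be driven, assuming pytest is available and that 'gpt2_124M' is among the names returned by list_pretrained_gpt2() (both are assumptions, not shown above):

import mxnet as mx
import pytest

# Hypothetical driver: runs the test above on CPU for one assumed
# pretrained model name ('gpt2_124M' is a guess at a valid name).
@pytest.mark.parametrize('model_name', ['gpt2_124M'])
def test_gpt2_cpu(model_name):
    test_gpt2(model_name, mx.cpu())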
Example #2
def calculate_metrics(args):
    with open(args.file, encoding='utf-8') as of:
        samples = of.read()
    pattern = '=' * 40 + r' SAMPLE \d+ ' + '=' * 40 + '\n'
    samples = re.split(pattern, samples)[1:]

    num_samples = args.num_samples if args.num_samples else len(samples)
    assert num_samples <= len(samples), \
        f'The requested number of samples {num_samples} is greater than ' \
        f'total number of samples {len(samples)}'
    samples = samples[:num_samples]
    num_bleu_samples = args.num_bleu_samples if args.num_bleu_samples else num_samples
    assert num_bleu_samples <= num_samples, \
        f'The requested number of samples {num_bleu_samples} for ' \
        f'calculating self-BLEU is greater than number of samples {num_samples}.'
    seed = args.seed if args.seed is not None else 0
    random.seed(seed)

    _, tokenizer, _, _ = get_pretrained_gpt2(load_backbone=False,
                                             load_lm=False)
    sample_ids = tokenizer.encode(samples, output_type=int)
    # encode() on a list of strings returns one id list per sample;
    # strip a trailing EOS token from each sample, if present
    for ids in sample_ids:
        if ids and ids[-1] == tokenizer.vocab.eos_id:
            ids.pop()

    self_bleu4 = calculate_self_bleu4(sample_ids, num_bleu_samples)
    zipf_coefficient = calculate_zipf_coefficient(sample_ids, tokenizer)
    repetition = calculate_repetition(sample_ids)
    print('Self BLEU 4: {}\n'
          'Zipf coefficient: {}\n'
          'Repetition: {}\n'.format(self_bleu4, zipf_coefficient, repetition))
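
calculate_zipf_coefficient is referenced but not defined here. A sketch of the usual definition (the negated slope of log frequency against log rank over the token unigram counts, as in Holtzman et al., 2020), assuming scipy is available; it is illustrative and need not match the actual helper:

from collections import Counter

import numpy as np
from scipy import stats

def zipf_coefficient_sketch(sample_ids):
    # pool token ids across all samples and count unigram frequencies
    counts = Counter(tok for ids in sample_ids for tok in ids)
    freqs = np.array(sorted(counts.values(), reverse=True), dtype=np.float64)
    ranks = np.arange(1, len(freqs) + 1, dtype=np.float64)
    # linear fit in log-log space; the Zipf coefficient is the negated slope
    slope, _, _, _, _ = stats.linregress(np.log(ranks), np.log(freqs))
    return -slope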
Example #3
def sample_gpt2(args):
    ctx = mx.gpu(args.gpu) if args.gpu is not None else \
          mx.cpu()

    cfg, tokenizer, _, lm_params_path = get_pretrained_gpt2(
        model_name=args.model_name, load_backbone=False, load_lm=True)
    cfg.defrost()
    cfg.MODEL.layout = args.layout
    cfg.freeze()

    if args.length is None:
        args.length = cfg.MODEL.max_length
    assert args.length <= cfg.MODEL.max_length, \
           "Can't get samples longer than window size: {}".format(cfg.MODEL.max_length)

    model = GPT2ForLM(cfg)
    model.hybridize()
    model.load_parameters(lm_params_path, ctx=ctx)
    gpt2decoder = GPT2Decoder(model)

    sampler = BeamSearchSampler(beam_size=1,
                                decoder=gpt2decoder,
                                eos_id=None,
                                vocab_size=cfg.MODEL.vocab_size,
                                max_length_a=0,
                                max_length_b=args.length,
                                min_length=1,
                                temperature=args.temperature,
                                sampling=True,
                                sampling_topp=args.top_p,
                                sampling_topk=args.top_k,
                                early_return=False)
    start_states = gpt2decoder.init_states(args.batch_size, ctx)

    while True:
        raw_text = input('Model prompt >>> ')
        while not raw_text:
            print('Prompt should not be empty!')
            raw_text = input("Model prompt >>> ")
        context_tokens = tokenizer.encode(raw_text, output_type=int)
        batch_axis = 0 if args.layout == 'NT' else 1
        new_shape = (args.batch_size, len(context_tokens)) if args.layout == 'NT' else \
                    (len(context_tokens), args.batch_size)
        # broadcast the encoded prompt across the batch dimension
        start_input = mx.np.broadcast_to(
            mx.np.expand_dims(mx.np.array(context_tokens, ctx=ctx),
                              batch_axis), new_shape)
        generated = 0
        while generated < args.nsamples:
            samples, _, _ = sampler(start_input, start_states)
            for i in range(args.batch_size):
                generated += 1
                ids = samples[i][0].asnumpy().tolist()
                # drop the start token and everything from the first -1
                # padding id onward
                ids = ids[1:ids.index(-1)] if -1 in ids else \
                      ids[1:]
                text = tokenizer.decode(ids)
                print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
                print(text)
        print("=" * 80)
Example #4
def sample_gpt2(args):
    ctx = mx.gpu(args.gpu) if args.gpu is not None else \
          mx.cpu()

    cfg, tokenizer, _, lm_params_path = get_pretrained_gpt2(
        model_name=args.model_name, load_backbone=False, load_lm=True)
    cfg.defrost()
    cfg.MODEL.layout = args.layout
    cfg.freeze()

    if args.length is None:
        args.length = cfg.MODEL.max_length
    assert args.length <= cfg.MODEL.max_length, \
           "Can't get samples longer than window size: {}".format(cfg.MODEL.max_length)

    model = GPT2ForLM(cfg)
    model.hybridize()
    model.load_parameters(lm_params_path, ctx=ctx)
    gpt2decoder = GPT2Decoder(model)

    sampler = BeamSearchSampler(beam_size=1,
                                decoder=gpt2decoder,
                                eos_id=None,
                                vocab_size=cfg.MODEL.vocab_size,
                                max_length_a=0,
                                max_length_b=args.length,
                                min_length=1,
                                temperature=args.temperature,
                                sampling=True,
                                sampling_topp=args.top_p,
                                sampling_topk=args.top_k,
                                early_return=False)

    # unconditional generation: every sequence in the batch starts from
    # the EOS token
    start_input = mx.np.full((args.batch_size, 1) if args.layout == 'NT' else
                             (1, args.batch_size),
                             tokenizer.vocab.eos_id,
                             ctx=ctx)
    start_states = gpt2decoder.init_states(args.batch_size, ctx)

    generated = 0
    while args.nsamples <= 0 or generated < args.nsamples:
        samples, _, _ = sampler(start_input, start_states)
        for i in range(args.batch_size):
            generated += 1
            ids = samples[i][0].asnumpy().tolist()
            ids = ids[1:ids.index(-1)] if -1 in ids else \
                  ids[1:]
            text = tokenizer.decode(ids)
            print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
            print(text)
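
This variant samples unconditionally: instead of broadcasting an encoded prompt, every row of start_input holds the EOS id, matching GPT-2's convention of beginning a document with <|endoftext|>. The inline id trimming can be factored into a helper; a sketch, with the -1 padding semantics inferred from the slicing above:

def trim_sample(ids):
    # drop the leading start token, then cut at the first -1 padding id
    ids = ids[1:]
    return ids[:ids.index(-1)] if -1 in ids else ids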
Example #5
def calculate_metrics(args):
    with open(args.file, encoding='utf-8') as of:
        samples = of.read()
    pattern = '=' * 40 + r' SAMPLE \d+ ' + '=' * 40 + '\n'
    samples = re.split(pattern, samples)[1:]
    samples = samples[:args.num_samples]
    assert len(samples) == args.num_samples

    _, tokenizer, _, _ = get_pretrained_gpt2(load_backbone=False,
                                             load_lm=False)
    sample_ids = tokenizer.encode(samples, output_type=int)
    # encode() on a list of strings returns one id list per sample;
    # strip a trailing EOS token from each sample, if present
    for ids in sample_ids:
        if ids and ids[-1] == tokenizer.vocab.eos_id:
            ids.pop()
    sample_strs = tokenizer.encode(samples, output_type=str)

    self_bleu4 = calculate_self_bleu4(sample_strs, args.num_bleu_samples)
    zipf_coefficient = calculate_zipf_coefficient(sample_ids, tokenizer)
    repetition = calculate_repetition(sample_ids)
    print('Self BLEU 4: {}\n'
          'Zipf coefficient: {}\n'
          'Repetition: {}\n'.format(self_bleu4, zipf_coefficient, repetition))
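
calculate_self_bleu4 is likewise referenced without a definition. A sketch of the common self-BLEU-4 formulation (Zhu et al., 2018), where each sampled text is scored against all other samples as references and the scores are averaged, here using nltk's smoothed sentence-level BLEU; the real helper may differ in smoothing or parallelism:

import random

from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu

def self_bleu4_sketch(sample_strs, num_bleu_samples):
    # sample_strs: one list of token strings per generated sample
    smoothing = SmoothingFunction().method1
    picked = random.sample(range(len(sample_strs)), num_bleu_samples)
    scores = []
    for i in picked:
        references = [sample_strs[j] for j in range(len(sample_strs)) if j != i]
        scores.append(sentence_bleu(references, sample_strs[i],
                                    weights=(0.25, 0.25, 0.25, 0.25),
                                    smoothing_function=smoothing))
    return sum(scores) / len(scores)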