import tempfile

import numpy as np
import mxnet as mx

from gluonnlp.models.gpt2 import (GPT2Model, GPT2ForLM,
                                  list_pretrained_gpt2, get_pretrained_gpt2)
from gluonnlp.loss import LabelSmoothCrossEntropyLoss


def test_gpt2(model_name, ctx):
    # test from pretrained
    assert len(list_pretrained_gpt2()) > 0
    with tempfile.TemporaryDirectory() as root, ctx:
        cfg, tokenizer, params_path, lm_params_path = \
            get_pretrained_gpt2(model_name, load_backbone=True, load_lm=True, root=root)
        assert cfg.MODEL.vocab_size == len(tokenizer.vocab)
        # test backbone
        gpt2_model = GPT2Model.from_cfg(cfg)
        gpt2_model.load_parameters(params_path)
        # test lm model
        gpt2_lm_model = GPT2ForLM(cfg)
        gpt2_lm_model.load_parameters(lm_params_path)
        # test forward
        batch_size = 3
        seq_length = 32
        vocab_size = len(tokenizer.vocab)
        input_ids = mx.np.array(
            np.random.randint(2, vocab_size, (batch_size, seq_length)),
            dtype=np.int32, ctx=ctx)
        logits, _ = gpt2_lm_model(input_ids,
                                  gpt2_lm_model.init_states(batch_size, ctx))
        mx.npx.waitall()
        # test backward
        label_smooth_loss = LabelSmoothCrossEntropyLoss(num_labels=vocab_size)
        with mx.autograd.record():
            logits, _ = gpt2_lm_model(input_ids,
                                      gpt2_lm_model.init_states(batch_size, ctx))
            loss = label_smooth_loss(logits, input_ids)
            loss.backward()
        mx.npx.waitall()
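# A minimal sketch of how the test above could be driven with pytest
# parametrization. The model name 'gpt2_124M' and the CPU-only context list
# are assumptions for illustration, not the original test module's fixtures.
import pytest


@pytest.mark.parametrize('model_name', ['gpt2_124M'])  # assumed smallest variant
@pytest.mark.parametrize('ctx', [mx.cpu()])
def test_gpt2_cpu_smoke(model_name, ctx):
    test_gpt2(model_name, ctx)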
import re
import random

from gluonnlp.models.gpt2 import get_pretrained_gpt2


def calculate_metrics(args):
    with open(args.file, encoding='utf-8') as of:
        samples = of.read()
    # Samples are separated by banner lines of the form
    # "==== SAMPLE 1 ====" written by the generation script.
    pattern = '=' * 40 + r' SAMPLE \d+ ' + '=' * 40 + '\n'
    samples = re.split(pattern, samples)[1:]
    num_samples = args.num_samples if args.num_samples else len(samples)
    assert num_samples <= len(samples), \
        f'The requested number of samples {num_samples} is greater than ' \
        f'the total number of samples {len(samples)}.'
    samples = samples[:num_samples]
    num_bleu_samples = args.num_bleu_samples if args.num_bleu_samples else num_samples
    assert num_bleu_samples <= num_samples, \
        f'The requested number of samples {num_bleu_samples} for ' \
        f'calculating self-BLEU is greater than the number of samples {num_samples}.'
    seed = args.seed if args.seed is not None else 0
    random.seed(seed)

    _, tokenizer, _, _ = get_pretrained_gpt2(load_backbone=False, load_lm=False)
    sample_ids = tokenizer.encode(samples, output_type=int)
    # Strip a trailing EOS token from each sample before computing the metrics.
    sample_ids = [ids[:-1] if ids and ids[-1] == tokenizer.vocab.eos_id else ids
                  for ids in sample_ids]

    self_bleu4 = calculate_self_bleu4(sample_ids, num_bleu_samples)
    zipf_coefficient = calculate_zipf_coefficient(sample_ids, tokenizer)
    repetition = calculate_repetition(sample_ids)
    print('Self BLEU 4: {}\n'
          'Zipf coefficient: {}\n'
          'Repetition: {}\n'.format(self_bleu4, zipf_coefficient, repetition))
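# Minimal driver sketch for the metrics function above. The flag names mirror
# the attributes read from `args` (file, num_samples, num_bleu_samples, seed)
# but are assumptions, not the script's actual argument parser.
import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Compute self-BLEU, Zipf coefficient and repetition '
                    'for a file of GPT-2 samples.')
    parser.add_argument('--file', required=True,
                        help='File of generated samples, separated by SAMPLE banners.')
    parser.add_argument('--num_samples', type=int, default=None,
                        help='Number of samples to evaluate; defaults to all.')
    parser.add_argument('--num_bleu_samples', type=int, default=None,
                        help='Number of samples used for self-BLEU.')
    parser.add_argument('--seed', type=int, default=None)
    calculate_metrics(parser.parse_args())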
import mxnet as mx

from gluonnlp.models.gpt2 import GPT2ForLM, get_pretrained_gpt2
from gluonnlp.sequence_sampler import BeamSearchSampler


# Interactive variant: each batch of samples is conditioned on a
# user-supplied prompt.
def sample_gpt2(args):
    ctx = mx.gpu(args.gpu) if args.gpu is not None else mx.cpu()
    cfg, tokenizer, _, lm_params_path = get_pretrained_gpt2(
        model_name=args.model_name,
        load_backbone=False,
        load_lm=True)
    cfg.defrost()
    cfg.MODEL.layout = args.layout
    cfg.freeze()

    if args.length is None:
        args.length = cfg.MODEL.max_length
    assert args.length <= cfg.MODEL.max_length, \
        "Can't get samples longer than window size: {}".format(cfg.MODEL.max_length)

    model = GPT2ForLM(cfg)
    model.hybridize()
    model.load_parameters(lm_params_path, ctx=ctx)
    gpt2decoder = GPT2Decoder(model)

    # Beam size 1 with sampling enabled turns the beam-search sampler into a
    # plain temperature/top-k/top-p sampler.
    sampler = BeamSearchSampler(beam_size=1,
                                decoder=gpt2decoder,
                                eos_id=None,
                                vocab_size=cfg.MODEL.vocab_size,
                                max_length_a=0,
                                max_length_b=args.length,
                                min_length=1,
                                temperature=args.temperature,
                                sampling=True,
                                sampling_topp=args.top_p,
                                sampling_topk=args.top_k,
                                early_return=False)
    start_states = gpt2decoder.init_states(args.batch_size, ctx)

    while True:
        raw_text = input('Model prompt >>> ')
        while not raw_text:
            print('Prompt should not be empty!')
            raw_text = input('Model prompt >>> ')
        context_tokens = tokenizer.encode(raw_text, output_type=int)
        batch_axis = 0 if args.layout == 'NT' else 1
        new_shape = (args.batch_size, len(context_tokens)) if args.layout == 'NT' else \
                    (len(context_tokens), args.batch_size)
        # Tile the prompt across the batch so every row starts from the same context.
        start_input = mx.np.broadcast_to(
            mx.np.expand_dims(mx.np.array(context_tokens, ctx=ctx), batch_axis),
            new_shape)
        generated = 0
        while generated < args.nsamples:
            samples, _, _ = sampler(start_input, start_states)
            for i in range(args.batch_size):
                generated += 1
                ids = samples[i][0].asnumpy().tolist()
                # Drop the leading token and truncate at the sampler's -1
                # padding, if present.
                ids = ids[1:ids.index(-1)] if -1 in ids else ids[1:]
                text = tokenizer.decode(ids)
                print('=' * 40 + ' SAMPLE ' + str(generated) + ' ' + '=' * 40)
                print(text)
                print('=' * 80)
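# GPT2Decoder is defined elsewhere in the generation script. The sketch below
# shows the per-step decoder interface that BeamSearchSampler consumes:
# init_states plus a call that maps (current tokens, states) to
# (log-probabilities, updated states). The method bodies here are assumptions
# for illustration, not the script's actual code.
class GPT2DecoderSketch:
    def __init__(self, lm_model):
        self._model = lm_model

    def init_states(self, batch_size, ctx):
        # Delegate to the language model's empty key/value-cache initializer.
        return self._model.init_states(batch_size, ctx)

    def __call__(self, data, states):
        # data: token ids for the current step, shape (batch_size,)
        data = mx.np.expand_dims(data, axis=1)  # -> (batch_size, 1)
        logits, new_states = self._model(data, states)
        # Return per-step log-probabilities over the vocabulary.
        return mx.npx.log_softmax(logits[:, -1, :], axis=-1), new_states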
# Unconditional variant: every sequence starts from the EOS token instead of
# a user prompt. It shares the imports and the GPT2Decoder helper of the
# interactive variant above.
def sample_gpt2(args):
    ctx = mx.gpu(args.gpu) if args.gpu is not None else mx.cpu()
    cfg, tokenizer, _, lm_params_path = get_pretrained_gpt2(
        model_name=args.model_name,
        load_backbone=False,
        load_lm=True)
    cfg.defrost()
    cfg.MODEL.layout = args.layout
    cfg.freeze()

    if args.length is None:
        args.length = cfg.MODEL.max_length
    assert args.length <= cfg.MODEL.max_length, \
        "Can't get samples longer than window size: {}".format(cfg.MODEL.max_length)

    model = GPT2ForLM(cfg)
    model.hybridize()
    model.load_parameters(lm_params_path, ctx=ctx)
    gpt2decoder = GPT2Decoder(model)

    sampler = BeamSearchSampler(beam_size=1,
                                decoder=gpt2decoder,
                                eos_id=None,
                                vocab_size=cfg.MODEL.vocab_size,
                                max_length_a=0,
                                max_length_b=args.length,
                                min_length=1,
                                temperature=args.temperature,
                                sampling=True,
                                sampling_topp=args.top_p,
                                sampling_topk=args.top_k,
                                early_return=False)
    # Unconditional generation seeds every row of the batch with the EOS token.
    start_input = mx.np.full(
        (args.batch_size, 1) if args.layout == 'NT' else (1, args.batch_size),
        tokenizer.vocab.eos_id, ctx=ctx)
    start_states = gpt2decoder.init_states(args.batch_size, ctx)

    generated = 0
    # nsamples <= 0 means sample indefinitely.
    while args.nsamples <= 0 or generated < args.nsamples:
        samples, _, _ = sampler(start_input, start_states)
        for i in range(args.batch_size):
            generated += 1
            ids = samples[i][0].asnumpy().tolist()
            # Drop the initial EOS token and truncate at the sampler's -1
            # padding, if present.
            ids = ids[1:ids.index(-1)] if -1 in ids else ids[1:]
            text = tokenizer.decode(ids)
            print('=' * 40 + ' SAMPLE ' + str(generated) + ' ' + '=' * 40)
            print(text)
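# Minimal CLI driver sketch for the samplers above. The flag names mirror the
# attributes read from `args` (model_name, layout, gpu, length, temperature,
# top_p, top_k, batch_size, nsamples) but are assumptions, not the scripts'
# actual argument parsers; the -1 defaults assume that negative values
# disable top-p/top-k filtering in the sampler.
import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Sample from a pretrained GPT-2 LM.')
    parser.add_argument('--model_name', type=str, default='gpt2_124M')
    parser.add_argument('--layout', type=str, choices=['NT', 'TN'], default='NT')
    parser.add_argument('--gpu', type=int, default=None,
                        help='GPU id; omit to run on CPU.')
    parser.add_argument('--length', type=int, default=None,
                        help='Sample length; defaults to the model window size.')
    parser.add_argument('--temperature', type=float, default=1.0)
    parser.add_argument('--top_p', type=float, default=-1.0)
    parser.add_argument('--top_k', type=int, default=-1)
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--nsamples', type=int, default=0,
                        help='Number of samples; <= 0 samples indefinitely.')
    sample_gpt2(parser.parse_args())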
# Variant of calculate_metrics that requires exactly num_samples samples and
# computes self-BLEU on subword strings rather than token ids.
def calculate_metrics(args):
    with open(args.file, encoding='utf-8') as of:
        samples = of.read()
    pattern = '=' * 40 + r' SAMPLE \d+ ' + '=' * 40 + '\n'
    samples = re.split(pattern, samples)[1:]
    samples = samples[:args.num_samples]
    assert len(samples) == args.num_samples, \
        'Fewer samples in {} than requested ({}).'.format(args.file, args.num_samples)

    _, tokenizer, _, _ = get_pretrained_gpt2(load_backbone=False, load_lm=False)
    sample_ids = tokenizer.encode(samples, output_type=int)
    # Strip a trailing EOS token from each sample before computing the metrics.
    sample_ids = [ids[:-1] if ids and ids[-1] == tokenizer.vocab.eos_id else ids
                  for ids in sample_ids]
    sample_strs = tokenizer.encode(samples, output_type=str)

    # Self-BLEU is computed on subword strings; Zipf and repetition on token ids.
    self_bleu4 = calculate_self_bleu4(sample_strs, args.num_bleu_samples)
    zipf_coefficient = calculate_zipf_coefficient(sample_ids, tokenizer)
    repetition = calculate_repetition(sample_ids)
    print('Self BLEU 4: {}\n'
          'Zipf coefficient: {}\n'
          'Repetition: {}\n'.format(self_bleu4, zipf_coefficient, repetition))
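# The metric helpers are used above but not shown. Below is a hedged sketch of
# possible implementations: self-BLEU-4 via NLTK, and the Zipf coefficient as
# the negated slope of a log-log linear fit over the token frequency
# distribution. These are illustrative reconstructions under those
# assumptions, not the repository's actual code; calculate_repetition would
# follow the same pattern over the token-id lists.
import random
from collections import Counter

import numpy as np
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction


def calculate_self_bleu4(samples, num_bleu_samples):
    """Average BLEU-4 of each sampled sequence against all the others."""
    smoothing = SmoothingFunction().method1
    pool = random.sample(range(len(samples)), num_bleu_samples)
    scores = [sentence_bleu([s for j, s in enumerate(samples) if j != i],
                            samples[i],
                            weights=(0.25, 0.25, 0.25, 0.25),
                            smoothing_function=smoothing)
              for i in pool]
    return sum(scores) / len(scores)


def calculate_zipf_coefficient(sample_ids, tokenizer):
    """Slope of log(frequency) vs. log(rank) over all generated tokens."""
    # tokenizer is unused in this sketch; kept for signature parity.
    counts = Counter(token for ids in sample_ids for token in ids)
    freqs = np.array(sorted(counts.values(), reverse=True), dtype=np.float64)
    ranks = np.arange(1, len(freqs) + 1, dtype=np.float64)
    slope, _ = np.polyfit(np.log(ranks), np.log(freqs), deg=1)
    return -slope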