Example #1

# Imports implied by the snippet (torch, torch.nn, HF datasets and transformers)
import torch
import torch.nn as nn
from datasets import Dataset
from transformers import GPT2LMHeadModel, GPT2TokenizerFast


class DataParallel(nn.DataParallel):
    def __init__(self, module):
        super(DataParallel, self).__init__(module)
        self.replicas = None

    def replicate(self, module, device_ids):
        if self.replicas is None:
            from torch.nn.parallel.replicate import replicate
            self.replicas = replicate(module, device_ids, not torch.is_grad_enabled())
        return self.replicas

dset_sbert = Dataset.load_from_disk("/home/ahemf/processed_datasets/dsets_448_sbert")
model_id = "distilgpt2"  # 'gpt2'
model = GPT2LMHeadModel.from_pretrained(model_id).eval()
max_length = 256
stride = max_length
for p in model.parameters():
    p.requires_grad = False
if torch.cuda.device_count() > 1:
    model = DataParallel(model)

model.to(device)
for p in model.parameters():
    p.requires_grad = False
tokenizer = GPT2TokenizerFast.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token


def perplexity(x, device):
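
# Sketch only: the original perplexity() body is cut off above. One plausible
# sliding-window implementation over a tokenized batch `x` (an encoding with an
# "input_ids" tensor); the windowing and names here are assumptions, not the
# source code.
def perplexity_sketch(x, device):
    input_ids_full = x["input_ids"]
    nlls = []
    for begin in range(0, input_ids_full.size(1), stride):
        window = input_ids_full[:, begin:begin + max_length].to(device)
        with torch.no_grad():
            # GPT2LMHeadModel returns the mean cross-entropy loss when labels are given
            loss = model(window, labels=window)[0].mean()
        nlls.append(loss * window.size(1))
    return torch.exp(torch.stack(nlls).sum() / input_ids_full.size(1))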
Example #2
    def __init__(self, config):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"using device: {self.device}")
        self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
        self.model = GPT2LMHeadModel.from_pretrained("gpt2").to(self.device)
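
    # Sketch only (not in the original class): a hypothetical helper that uses the
    # attributes set up above to generate a continuation for a prompt.
    def generate_text(self, prompt, max_length=50):
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        with torch.no_grad():
            output = self.model.generate(input_ids, max_length=max_length, do_sample=True, top_p=0.95)
        return self.tokenizer.decode(output[0], skip_special_tokens=True)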
Example #3
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('experiment', type=str)

    # Default parameters are set based on single GPU training
    parser.add_argument('--lr', type=float, default=5e-5)
    parser.add_argument("--seed", type=int, default=0)

    parser.add_argument('--data_type', type=str, default='t1', choices=['t' + str(i) for i in range(9)], help="t: type")
    parser.add_argument('--model_type', type=str, default='cvae', choices=['cvae', 'ae_vae_fusion'])
    parser.add_argument('--iterations', type=int, default=101640 * 4)  # wp 850001  wi 300001 ax 300001 yp 800001
    parser.add_argument('--dataset', type=str, default='wi', choices=['ax', 'yp', 'wp', 'wi'], help="Dataset to use for training")
    parser.add_argument('--warmup', type=int, default=10000,
                        help="Amount of iterations to warmup, then decay. (-1 for no warmup and decay)")

    parser.add_argument('--batch-sizes', nargs='+', type=int, default=[1],
                        help='batch size per GPU. Lists the schedule.')
    parser.add_argument('--seq-lens', nargs='+', type=int, default=[1024],
                        help='seq length per sample. Lists the schedule.')
    parser.add_argument('--switch-time', type=float, default=0,
                        help="Percentage of iterations to spend on short sequence training.")
    parser.add_argument('--data-dir', type=str, default='data')
    parser.add_argument('--out-dir', type=str, default='out')
    parser.add_argument('--load', type=str, help='path to load model from') # , default='out/test/'
    parser.add_argument('--workers', default=1, type=int, metavar='N',
                        help='number of data loading workers')
    # use GPU
    parser.add_argument('--gpu', default=0, type=int)
    parser.add_argument('--no_gpu', action="store_true")

    parser.add_argument('--fp16', action='store_true', help="Train using FP16?")
    parser.add_argument('--fp16_opt_level', default='O0', type=str, required=False)

    # KL cost annealing, increase beta from beta_0 to 1 in beta_warmup steps
    parser.add_argument('--beta_0', default=1.00, type=float)
    parser.add_argument('--beta_warmup', type=int, default=50000)
    # cyc_vae parameters
    parser.add_argument('--cycle', type=int, default=101640)

    parser.add_argument('--add_input', action="store_true")
    parser.add_argument('--add_attn', action="store_true")
    parser.add_argument('--add_softmax', action="store_true")
    parser.add_argument('--attn_proj_vary', action="store_true")

    parser.add_argument('--learn_prior', action="store_true")

    args = parser.parse_args('test --batch-sizes 1 --seq-lens 1024 '
                             '--add_input --learn_prior --fp16'.split()) # wi.12.proj_vary_beta_cvae

    if args.model_type == 'cvae':
        args.learn_prior = True
    else:
        args.learn_prior = False

    # GPU
    if not torch.cuda.is_available(): args.no_gpu = True
    gpu = not args.no_gpu
    if gpu:
        print("There are ", torch.cuda.device_count(), " available GPUs!")
        # print('Setting GPUs {}'.format(args.device))
        print('Using GPU device {}'.format(args.gpu))
        torch.cuda.set_device(args.gpu)
        print('Current single GPU: {}'.format(torch.cuda.current_device()))
    device = torch.device(args.gpu if gpu else "cpu")

    # randomness
    np.random.seed(args.seed)
    prng = np.random.RandomState()
    torch.random.manual_seed(args.seed)
    if gpu: torch.cuda.manual_seed(args.seed); torch.cuda.manual_seed_all(args.seed)

    # logging
    save_folder = os.path.join(args.out_dir, args.experiment)
    os.makedirs(save_folder, exist_ok=True)
    t_writer = SummaryWriter(os.path.join(save_folder, 'train'), flush_secs=5)
    v_writer = SummaryWriter(os.path.join(save_folder, 'val'), flush_secs=5)
    importlib.reload(logging)
    logging.basicConfig(filename=os.path.join(save_folder, 'train.log'),
                        level=logging.INFO, format='%(asctime)s--- %(message)s')
    logging.info('\n*******************************************************************************\n')
    logging.info("the configuration:")
    logging.info(str(args).replace(',', '\n'))

    print('Loading models...')
    cache_dir = os.path.join(args.out_dir, 'model_cache')
    os.makedirs(cache_dir, exist_ok=True)
    # Load pre-trained teacher tokenizer (vocabulary)
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2', cache_dir=cache_dir)
    # Hack to allow tokenizing longer sequences.
    tokenizer.max_len = int(1e12)
    gpt2_model = GPT2LMHeadModel.from_pretrained('gpt2', cache_dir=cache_dir)
    print('gpt2_params:', num_params(gpt2_model))  # gpt2: 124439808
    config = GPT2Config()

    # add special tokens
    # special_tokens_dict = {
    #     'pad_token': '<|startoftext|>',
    #     'cls_token': '<|startofcond|>',
    #     'sep_token': '<|sepofcond|>',
    #     'mask_token': '<|endofcond|>'
    # }
    # num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
    # print('We have added', num_added_toks, 'special tokens')
    # # Notice: resize_token_embeddings expect to receive the full size of the new vocab
    # gpt2_model.resize_token_embeddings(len(tokenizer))
    # assert tokenizer.pad_token == '<|startoftext|>'

    VAE = VAEModel(config, add_input=args.add_input, add_attn=args.add_attn, add_softmax=args.add_softmax,
                   attn_proj_vary=args.attn_proj_vary, learn_prior=args.learn_prior)
    init_para_frompretrained(VAE.transformer, gpt2_model.transformer, share_para=True)
    init_para_frompretrained(VAE.encoder, gpt2_model.transformer, share_para=False)
    if args.learn_prior:
        init_para_frompretrained(VAE.encoder_prior, VAE.encoder, share_para=True)
        VAE.encoder_prior.averageSelfAttention.attention_weights = VAE.encoder.averageSelfAttention.attention_weights
    VAE.lm_head.weight = gpt2_model.lm_head.weight
    if VAE.add_softmax:
        VAE.lm_head_rep = Conv1D(*gpt2_model.lm_head.weight.size())
        # VAE.lm_head_rep = LM_head_rep(*gpt2_model.lm_head.weight.size()[::-1])
    print('VAE_params:', num_params(VAE))  # 286694400
    if args.load:
        print('Loading model weights...')
        state = torch.load(os.path.join(args.load, 'model_latest.pt'))  # , map_location='cpu' model_latest.pt
        if 'module' in list(state.keys())[0]:  # model_path is data parallel model with attr 'module'
            state_copy = copy.copy(state)
            keys = state_copy.keys()
            for k in keys:
                state[k.replace('module.', '')] = state.pop(k)
        VAE.load_state_dict(state)
        gc.collect()
    print('Done.')

    # fix pre-trained parameters before certain iterations
    tuning_all_after_iters = 40000
    tuning_all = False
    for name, parameter in VAE.named_parameters():
        # print((name, parameter.requires_grad))
        new_pars = ['c_z', 'attention_weights', 'mean', 'logvar', 'input_proj', 'attn_proj', 'Nu_fc1', 'Nu_fc2', 'lm_head_rep']

        if not any(n in name for n in new_pars):
            parameter.requires_grad = False

    print('Setup data...')
    # Batch and sequence length schedule
    assert len(args.batch_sizes) == len(args.seq_lens)
    batch_schedule = list(zip(map(int, args.batch_sizes), map(int, args.seq_lens)))
    assert len(batch_schedule) <= 2, 'Currently supporting at most two schedules'
    cur_b_schedule = len(batch_schedule) - 1 if args.switch_time == 0 else 0
    print('Batch schedule', batch_schedule)
    train_loader, val_loader, test_loader = prepare_dataset(
        args.data_dir, args.dataset, tokenizer,
        batch_schedule[cur_b_schedule][0], batch_schedule[cur_b_schedule][1],
        batch_schedule[-1][0], batch_schedule[-1][1],
        batch_schedule[-1][0], batch_schedule[-1][1],
        make_test=True,
        num_workers=args.workers, data_type=args.data_type
    )
    print('Done.')

    ###
    val_loader = test_loader
    ###

    print('Wrapping models and optimizers...')
    # Apply linear scaling rule to increase batch size for short sequence training.
    lr_schedule = switch_schedule(linear_schedule(args), batch_schedule[cur_b_schedule][0] / batch_schedule[-1][0],
                                  int(args.iterations * args.switch_time))
    VAE = VAE.to(device)
    VAE.train()

    optimizer = AdamW(VAE.parameters(), lr=args.lr, correct_bias=True)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_schedule)
    VAE, optimizer = amp.initialize(VAE, optimizer, opt_level=args.fp16_opt_level)

    loss_fn = nn.CrossEntropyLoss(reduction='none')
    print('Done.')

    print('Begin training iterations')
    logging.info("Begin training iterations")
    max_val_batches = 20000  # max num. of val batches
    logging.info("Total iteration: %d" % args.iterations)
    e = 0  # number of epoch
    num_iters = 0
    optimizer.zero_grad()
    beta = args.beta_0
    endoftext = tokenizer.convert_tokens_to_ids("<|endoftext|>")

    def val_step(val_loader):
        VAE.eval()

        n_words_bpe = 0
        n_words = 0
        logp_sum = 0.0
        kl_loss_sum = 0.0

        logging.info("Validation loop.         Batches: %d" % len(val_loader))
        logging.info("Validation loop. max_val_batches: %d" % max_val_batches)

        # val_iter = iter(val_loader); x_mask, x_tokens, y_mask, y_tokens, input_tokens, target_tokens, mask = next(val_iter)
        with tqdm(total=min(len(val_loader), max_val_batches)) as pbar:
            for i, (x_mask, x_tokens, y_mask, y_tokens, input_tokens, target_tokens, mask) in enumerate(val_loader):
                with torch.no_grad():
                    if args.model_type == 'cvae':
                        loss, ce_loss, kl_loss = compute_loss(device, VAE, x_mask, x_tokens, y_mask, y_tokens,
                                                              input_tokens, target_tokens, mask, loss_fn, 1.0)
                    else:
                        loss, ce_loss, kl_loss = compute_loss_ae(device, VAE, x_mask, x_tokens, y_mask, y_tokens,
                                                              input_tokens, target_tokens, mask, loss_fn, 1.0)

                if len(target_tokens.size()) == 1:
                    target_tokens = target_tokens.unsqueeze(0)
                n, l = target_tokens.size()

                text = target_tokens[0, :].tolist()
                logprob = ce_loss.tolist()
                assert len(text) == len(logprob)

                # only for story
                idx = text.index(endoftext)
                text = text[idx + 1:]
                logprob = logprob[idx + 1:]

                if endoftext in text:
                    idx = text.index(endoftext)
                    text = text[:idx]
                    logprob = logprob[:idx]

                logp_sum += sum(logprob)

                n_words_bpe += len(text)

                story = [tokenizer.decode(target_tokens[i, :]) for i in range(n)]
                story = [s[s.find("<|endoftext|>") + len("<|endoftext|>"):] for s in story]
                story = [s[:s.find("<|endoftext|>") + len("<|endoftext|>")] if "<|endoftext|>" in s else s for s in
                         story]
                words = sum([len(
                    [t for t in re.split('("|\'|!|\?|\.|,|:| |\n|’|“|”|;|\(|\)|`)', s) if t != ' ' and t != '']) for
                    s in story])
                n_words += words

                kl_loss_sum += kl_loss.item()

                if i > max_val_batches:
                    break
                pbar.update(1)

        loss_bpe = logp_sum / n_words_bpe
        ppl_bpe = round(math.exp(min(logp_sum / n_words_bpe, 100)), 3)
        ppl_word = round(math.exp(min(logp_sum / n_words, 100)), 3)
        kl = kl_loss_sum / len(val_loader)

        v_writer.add_scalar('loss', loss_bpe, num_iters)
        v_writer.add_scalar('ppl_bpe', ppl_bpe, num_iters)
        v_writer.add_scalar('ppl_word', ppl_word, num_iters)
        v_writer.add_scalar('kl', kl, num_iters)
        logging.info('val loss    : %.4f' % loss_bpe)
        logging.info('val ppl_bpe : %.4f' % ppl_bpe)
        logging.info('val ppl_word: %.4f' % ppl_word)
        logging.info('val   kl    : %.4f' % kl)

        VAE.train()

    def test_plot(test_loader, num_iters):
        VAE.eval()

        # get embedding
        X_emb = None
        y = None

        # test_iter = iter(test_loader); x_mask, x_tokens, y_mask, y_tokens, input_tokens, target_tokens, mask = next(test_iter)
        with tqdm(total=len(test_loader)) as pbar:
            for i, (x_mask, x_tokens, y_mask, y_tokens, input_tokens, target_tokens, mask) in enumerate(
                    test_loader):
                y_mask = y_mask.to(device)
                y_tokens = y_tokens.to(device)
                x_mask = x_mask.to(device)
                x_tokens = x_tokens.to(device)
                with torch.no_grad():
                    if args.model_type == 'cvae':
                        latent_mean, latent_logvar = VAE.encoder_prior(input_ids=x_tokens, attention_mask=x_mask)[:2]
                    else:
                        latent_mean, latent_logvar = VAE.encoder(input_ids=x_tokens, attention_mask=x_mask)[:2]

                if args.dataset == 'ax' or args.dataset == 'yp':
                    label = [tokenizer.decode(l)[:2] for l in x_tokens.tolist()]
                elif args.dataset == 'wp':
                    label = []
                    prompts = [tokenizer.decode(l)[:6].lower() for l in x_tokens.tolist()]
                    for prom in prompts:
                        if prom[0] in ['[', '('] and prom[5] in [']', ')']:
                            label.append(prom[2:4])
                        else:
                            label.append(None)
                elif args.dataset == 'wi':
                    # 0. TV, play, miniseries, telenovela; 1.film; 2. music; 3. manga, comic, 4. book, novel, story 5. game
                    label = []
                    prompts = [tokenizer.decode(l) for l in x_tokens.tolist()]
                    for prom in prompts:
                        if 'TV' in prom or 'play' in prom or 'miniseries' in prom or 'telenovela' in prom:
                            label.append(0)
                        elif 'film' in prom:
                            label.append(1)
                        elif 'music' in prom:
                            label.append(2)
                        elif 'manga' in prom or 'comic' in prom:
                            label.append(3)
                        elif 'book' in prom or 'novel' in prom or 'story' in prom:
                            label.append(4)
                        elif 'game' in prom:
                            label.append(5)
                        else:
                            label.append(None)
                else:
                    raise Exception

                if i == 0:
                    X_emb = latent_mean.data
                    y = label
                else:
                    X_emb = torch.cat((X_emb, latent_mean.data), dim=0)
                    y.extend(label)
                pbar.update(1)
        X_emb = X_emb.cpu().numpy()

        try:
            if args.dataset == 'yp':
                y = ['0' if l in ['0', '1'] else l for l in y]
                y = ['4' if l in ['3', '4'] else l for l in y]
                X_emb = X_emb[[l != '2' for l in y], :]
                y = [l for l in y if l != '2']

            if args.dataset == 'wp':
                topics = [['wp', 'sp', 'tt'], ['eu'], ['cw'], ['pm'], ['mp', 'ip'], ['pi', 'cc'], ['ot'], ['rf']]
                match = [[True if l in t else False for t in topics] for l in y]
                y = [m.index(True) if True in m else None for m in match]
                X_emb = X_emb[[l is not None for l in y], :]
                y = [l for l in y if l is not None]

            if args.dataset == 'wi':
                X_emb = X_emb[[l is not None for l in y], :]
                y = [l for l in y if l is not None]

            # to 2D
            # X_emb_2d = TSNE(n_components=2, init='pca', verbose=1).fit_transform(X_emb)
            X_emb_2d = TSNE(n_components=2, verbose=1, perplexity=40).fit_transform(X_emb)

            def remove_outliers(data, r=2.0):
                outliers_data = abs(data - np.mean(data, axis=0)) >= r * np.std(data, axis=0)
                outliers = np.any(outliers_data, axis=1)
                keep = np.logical_not(outliers)
                return outliers, keep

            outliers, keep = remove_outliers(X_emb_2d)
            X_emb_2d = X_emb_2d[keep, :]
            y = [l for l, k in zip(y, keep.tolist()) if k]

            # plot
            fig = plt.figure(figsize=(4, 4))
            ax = fig.add_axes([0, 0, 1, 1])
            cc = ['r', 'b', 'g', 'y', 'k', 'c', 'm', 'tab:blue']
            for i, l in enumerate(sorted(set(y))):
                idx = [yl == l for yl in y]
                plt.scatter(X_emb_2d[idx, 0], X_emb_2d[idx, 1], c=cc[i], s=10, edgecolor='none', alpha=0.5)
            ax.axis('off')  # hide the axes
            plt.savefig(os.path.join(save_folder, 'tSNE_' + '{:07d}'.format(num_iters) + '.png'))
            plt.close(fig)
        except:
            pass

        VAE.train()

    def generate(test_loader, num_iters):
        VAE.eval()

        n_samples = 0
        bleu4_sum = 0.0
        rouge_scores_values_sum = [0.0] * 9

        args.nsamples = 1
        args.batch_size = 1
        args.temperature = 0.95
        args.top_k = 100
        args.top_p = 0.95
        model_type = args.model_type

        # write samples to file
        samples_file = open(os.path.join(save_folder, 'generate-' + '%07d' % num_iters + '.txt'), 'w', encoding='utf8')

        # test_iter = iter(test_loader); x_mask, x_tokens, y_mask, y_tokens, input_tokens, target_tokens, mask = next(test_iter)
        with tqdm(total=len(test_loader)) as pbar:
            for i_test, (x_mask, x_tokens, y_mask, y_tokens, input_tokens, target_tokens, mask) in enumerate(
                    test_loader):

                if i_test >= 10: break

                length = -1
                if length == -1:
                    length = VAE.config.n_ctx - x_tokens.size(1) - 1
                elif length > VAE.config.n_ctx - x_tokens.size(1) - 1:
                    raise ValueError("Can't get samples longer than window size: %s" % VAE.config.n_ctx)

                eff_samples = []
                n, l = target_tokens.size()
                storys = [tokenizer.decode(target_tokens[i, :]) for i in range(n)]
                storys = [s[s.find("<|endoftext|>") + len("<|endoftext|>"):] for s in storys]
                storys_str = [s[:s.find("<|endoftext|>") + len("<|endoftext|>")] if "<|endoftext|>" in s else s for s in
                              storys]

                for _ in range(args.nsamples // args.batch_size):
                    # model, batch_size, temperature, top_k, top_p, eos_token, sample = VAE, args.batch_size, args.temperature, args.top_k, args.top_p, tokenizer.encoder['<|endoftext|>'], True
                    out, _ = sample_sequence(
                        model=VAE,
                        tokenizer=tokenizer,
                        length=length,
                        batch_size=args.batch_size,
                        x_mask=x_mask,
                        x_tokens=x_tokens,
                        y_mask=y_mask,
                        y_tokens=y_tokens,
                        temperature=args.temperature,
                        top_k=args.top_k,
                        top_p=args.top_p,
                        device=device,
                        eos_token=tokenizer.encoder['<|endoftext|>'],
                        model_type=model_type
                    )
                    out = out.tolist()

                    # extract story, check metrics
                    for i in range(len(out)):
                        text = out[i]
                        text = text[text.index(endoftext) + 1:]

                        if endoftext in text:
                            idx = text.index(endoftext)
                            text = text[:idx]

                        text = tokenizer.decode(text).strip()

                        # score for one long text, higher than 0.075 usually means repetition
                        # rep_score = repeat_score(text.split(), ngram=[3, 4, 5, 6, 7, 8])
                        # if rep_score > 0.075:
                        #     # print(rep_score)
                        #     continue

                        try:
                            # check bleu
                            bleu4 = sentence_bleu([storys_str[i].split()], text,
                                                  smoothing_function=SmoothingFunction().method7)

                            # check rouge
                            rouge = Rouge()
                            rouge_scores = rouge.get_scores(text, storys_str[i])
                            rouge_scores_values = [v for k in rouge_scores[0].keys() for v in
                                                   rouge_scores[0][k].values()]

                            bleu4_sum += bleu4
                            rouge_scores_values_sum = [v1 + v2 for v1, v2 in
                                                       zip(rouge_scores_values_sum, rouge_scores_values)]
                            n_samples += 1
                        except:
                            bleu4 = 0.0
                            rouge_scores = [{'rouge-1': {'f': 0.0, 'p': 0.0, 'r': 0.0},
                                             'rouge-2': {'f': 0.0, 'p': 0.0, 'r': 0.0},
                                             'rouge-l': {'f': 0.0, 'p': 0.0, 'r': 0.0}}]

                        eff_samples.append((text, bleu4, rouge_scores))

                    pbar.update(1)

                for i in range(len(eff_samples)):
                    samples_file.write("=" * 50 + " SAMPLE " + str(i_test) + " " + "=" * 50)
                    samples_file.write('\n' * 2)

                    samples_file.write("=" * 40 + " Outlines  " + "=" * 40)
                    samples_file.write('\n' * 2)
                    samples_file.write(tokenizer.decode(x_tokens[i, :][x_mask[i, :] == 1].tolist()))
                    samples_file.write('\n' * 2)
                    samples_file.write("=" * 40 + " Story " + "=" * 40)
                    samples_file.write('\n' * 2)
                    samples_file.write(storys_str[i])
                    samples_file.write('\n' * 2)

                    samples_file.write("=" * 40 + " Generated " + "=" * 40)
                    samples_file.write('\n' * 2)
                    samples_file.write(eff_samples[i][0])
                    samples_file.write('\n' * 4)
                    samples_file.flush()

        print('Test complete with %05d samples.' % n_samples)
        logging.info("Test complete with %05d samples.", n_samples)
        logging.info("Iteration completed: %d" % num_iters)

        bleu4 = round(bleu4_sum / n_samples, 3)
        rouge_scores_values = [round(r / n_samples, 3) for r in rouge_scores_values_sum]
        print(' bleu-4:', bleu4)
        print(' rouge :', rouge_scores_values)
        logging.info(' bleu-4: %f', bleu4)
        logging.info(' rouge : %s', str(rouge_scores_values))

        VAE.train()

    test_plot(test_loader, num_iters)
    val_step(val_loader)
    generate(test_loader, num_iters)
    torch.save(VAE.state_dict(), os.path.join(save_folder, 'model_' + '{:07d}'.format(num_iters) + '.pt'))

    while num_iters < args.iterations:
        # Run epoch
        st = time.time()

        # Training
        print('Training loop. Batches:', len(train_loader))
        logging.info('\n----------------------------------------------------------------------')
        logging.info("Training loop.       Batches: %d" % len(train_loader))

        # train_iter = iter(train_loader); x_mask, x_tokens, y_mask, y_tokens, input_tokens, target_tokens, mask = next(train_iter)
        with tqdm(total=len(train_loader)) as pbar:
            for i, (x_mask, x_tokens, y_mask, y_tokens, input_tokens, target_tokens, mask) in enumerate(train_loader):

                # if num_iters % args.cycle >= args.cycle - args.beta_warmup:
                #     beta = min(1.0, beta + (1. - args.beta_0) / args.beta_warmup)

                if not tuning_all and num_iters >= tuning_all_after_iters:
                    for name, parameter in VAE.named_parameters():
                        # print((name, parameter.requires_grad))
                        parameter.requires_grad = True
                    tuning_all = True

                output = train_step(device, VAE, optimizer, x_mask, x_tokens, y_mask, y_tokens,
                                    input_tokens, target_tokens, mask, loss_fn, beta, args.model_type)
                loss, ce_loss, kl_loss = output[-1]

                lr = scheduler.get_last_lr()[0]
                # Log to Tensorboard
                t_writer.add_scalar('loss', loss, num_iters)
                t_writer.add_scalar('ppl', math.exp(min(ce_loss, 10)), num_iters)
                t_writer.add_scalar('lr', lr, num_iters)
                t_writer.add_scalar('iter_time', time.time() - st, num_iters)
                t_writer.add_scalar('kl', kl_loss, num_iters)
                t_writer.add_scalar('beta', beta, num_iters)

                if args.model_type == 'ae_vae_fusion':
                    loss, ce_loss, kl_loss = output[0]
                    # Log to Tensorboard
                    t_writer.add_scalar('ae_loss', loss, num_iters)
                    t_writer.add_scalar('ae_kl', kl_loss, num_iters)

                st = time.time()
                end = num_iters >= args.iterations

                if args.warmup != -1:
                    scheduler.step()

                if end: break
                num_iters += 1
                pbar.update(1)

                if num_iters % args.cycle == 0:
                    beta = args.beta_0
                    logging.info('KL annealing restart')

                if num_iters % 10000 == 0:
                    test_plot(test_loader, num_iters)
                    val_step(val_loader)
                    generate(test_loader, num_iters)

                if num_iters % 50000 == 0:
                    print('Saving model...')
                    logging.info("Iteration completed: %d, remained %d" % (num_iters, args.iterations - num_iters))
                    logging.info("Saving model...")
                    logging.info('\n------------------------------------------------------')
                    torch.save(VAE.state_dict(), os.path.join(save_folder, 'model_' + '{:07d}'.format(num_iters) + '.pt'))

                if args.switch_time > 0 and num_iters == int(args.iterations * args.switch_time):
                    print('Switch to long sequence training')
                    logging.info("Switch to long sequence training")
                    cur_b_schedule += 1
                    train_loader, val_loader, test_loader = prepare_dataset(
                        args.data_dir, args.dataset, tokenizer,
                        batch_schedule[cur_b_schedule][0], batch_schedule[cur_b_schedule][1],
                        batch_schedule[-1][0], batch_schedule[-1][1],
                        batch_schedule[-1][0], batch_schedule[-1][1],
                        make_test=True,
                        num_workers=args.workers, data_type=args.data_type
                    )
        if not end:
            e += 1
            logging.info("Training loop. The ith epoch completed: %d" % e)

    torch.save(VAE.state_dict(), os.path.join(save_folder, 'model_latest.pt'))
    print('Training complete.')
    logging.info("Training complete.")
Example #4
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import math

prompt = 'Do you believe that certain materials, such as books, music, movies, magazines, etc., should be removed from the shelves if they are found offensive?'

device = torch.device("cuda")# if args.cuda else "cpu")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
config = GPT2Config()
config.output_hidden_states = True
model = GPT2LMHeadModel.from_pretrained('gpt2', config=config)

model.to(device)
model.cuda()

text_idx = tokenizer.encode(prompt)
word_vectors = model.transformer.wte.weight[text_idx,:]
avg_vec = word_vectors.mean(dim=0, keepdim=True)
distances = model.transformer.wte.weight @ avg_vec.T

x, topk = torch.topk(distances, k=100, dim=0)
torch.save(topk, 'target_words.pt')
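
# Sketch only (not in the original): map the saved top-k indices back to vocabulary
# strings to inspect which tokens lie closest to the averaged prompt embedding.
top_tokens = [tokenizer.decode([i]) for i in topk.squeeze(-1).tolist()]
print(top_tokens[:20])
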
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', default='0', type=str, required=False, help='device(s) to generate on')
    parser.add_argument('--length', default=-1, type=int, required=False, help='generation length')
    parser.add_argument('--batch_size', default=1, type=int, required=False, help='batch size for generation')
    parser.add_argument('--nsamples', default=1, type=int, required=False, help='number of samples to generate')
    parser.add_argument('--temperature', default=1, type=float, required=False, help='generation temperature')
    parser.add_argument('--topk', default=5, type=int, required=False, help='top-k: sample from the k most likely tokens')
    parser.add_argument('--topp', default=0, type=float, required=False, help='top-p: cumulative probability threshold')
    parser.add_argument('--model_config', default='config/model_config_small.json', type=str, required=False,
                        help='model config')
    parser.add_argument('--tokenizer_path', default='cache/vocab_small.txt', type=str, required=False, help='vocabulary path')
    parser.add_argument('--model_path', default='model/model_epoch2', type=str, required=False, help='model path')
    parser.add_argument('--prefix', default='萧何', type=str, required=False, help='opening text of the generated article')
    parser.add_argument('--no_wordpiece', action='store_true', help='do not use word-piece tokenization')
    parser.add_argument('--segment', action='store_true', help='treat Chinese at the word level')
    parser.add_argument('--fast_pattern', action='store_true', help='use the faster text generation routine')
    parser.add_argument('--save_samples', action='store_true', help='save the generated samples')
    parser.add_argument('--save_samples_path', default='.', type=str, required=False, help="path to save samples")
    parser.add_argument('--repetition_penalty', default=1.0, type=float, required=False)

    args = parser.parse_args()
    print('args:\n' + args.__repr__())

    if args.segment:
        from tokenizations import tokenization_bert_word_level as tokenization_bert
    else:
        from tokenizations import tokenization_bert

    os.environ["CUDA_VISIBLE_DEVICES"] = args.device  # 此处设置程序使用哪些显卡
    length = args.length
    batch_size = args.batch_size
    nsamples = args.nsamples
    temperature = args.temperature
    topk = args.topk
    topp = args.topp
    repetition_penalty = args.repetition_penalty

    device = "cuda" if torch.cuda.is_available() else "cpu"

    tokenizer = tokenization_bert.BertTokenizer(vocab_file=args.tokenizer_path)
    model = GPT2LMHeadModel.from_pretrained(args.model_path)
    model.to(device)
    model.eval()

    n_ctx = model.config.n_ctx

    if length == -1:
        length = model.config.n_ctx
    if args.save_samples:
        if not os.path.exists(args.save_samples_path):
            os.makedirs(args.save_samples_path)
        samples_file = open(args.save_samples_path + '/samples.txt', 'w', encoding='utf8')
    while True:
        words = list(args.prefix)
        print(words)
        for word in words:
            raw_text = word
            context_tokens = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(raw_text))
            generated = 0
            for _ in range(nsamples // batch_size):
                out = generate(
                    n_ctx=n_ctx,
                    model=model,
                    context=context_tokens,
                    length=length,
                    is_fast_pattern=args.fast_pattern, tokenizer=tokenizer,
                    temperature=temperature, top_k=topk, top_p=topp, repitition_penalty=repetition_penalty, device=device
                )
                for i in range(batch_size):
                    generated += 1
                    text = tokenizer.convert_ids_to_tokens(out)
                    for i, item in enumerate(text[:-1]):  # make sure English words are separated by spaces
                        if is_word(item) and is_word(text[i + 1]):
                            text[i] = item + ' '
                    for i, item in enumerate(text):
                        if item == '[MASK]':
                            text[i] = ''
                        elif item == '[CLS]':
                            text[i] = '\n\n'
                        elif item == '[SEP]':
                            text[i] = '\n'
                    info = "=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40 + "\n"
                    print(info)
                    text = ''.join(text).replace('##', '').strip()
                    print(text)
                    if args.save_samples:
                        samples_file.write(info)
                        samples_file.write(text)
                        samples_file.write('\n')
                        samples_file.write('=' * 90)
                        samples_file.write('\n' * 2)

        print("=" * 80)
        if generated == nsamples:
            # close file when finish writing.
            if args.save_samples:
                samples_file.close()
            break
Example #6
    def load_model(self):
        from transformers import GPT2Tokenizer, GPT2LMHeadModel
        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        self.model = GPT2LMHeadModel.from_pretrained('gpt2')
        self.model.to(self.device)
Example #7
def main():
    wandb.init(project="gpt2_comet_atomic")

    config = wandb.config
    config.TRAIN_BATCH_SIZE = int(os.environ.get("TRAIN_BATCH_SIZE", 2))
    config.VALID_BATCH_SIZE = int(os.environ.get("VALID_BATCH_SIZE", 2))
    config.TRAIN_EPOCHS = int(os.environ.get("TRAIN_EPOCHS", 3))
    config.VAL_EPOCHS = int(os.environ.get("VAL_EPOCHS", 1))
    config.LEARNING_RATE = float(os.environ.get("LEARNING_RATE", "1e-5"))
    config.SEED = int(os.environ.get("SEED", 42))
    config.IN_LEN = int(os.environ.get("IN_LEN", 16))
    config.OUT_LEN = int(os.environ.get("OUT_LEN", 34))
    config.SUMMARY_LEN = 0  # Used for t5
    config.OUT_DIR = os.environ.get("OUT_DIR", "/models")
    config.DO_TRAIN = os.environ.get("DO_TRAIN", "False") == "True"
    config.DO_PRED = os.environ.get("DO_PRED", "True") == "True"
    config.PRED_FILE = str(os.environ.get("PRED_FILE", ""))
    config.TOP_K = int(os.environ.get("TOP_K", 40))
    config.PRED_BATCH = 64
    config.TOKENIZER = os.environ.get('TOKENIZER', "gpt2-xl")

    torch.manual_seed(config.SEED)  # pytorch random seed
    np.random.seed(config.SEED)  # numpy random seed
    torch.backends.cudnn.deterministic = True

    model_name = "gpt2" if 'GPT2_MODEL' not in os.environ else os.environ[
        'GPT2_MODEL']

    try:
        tokenizer = GPT2Tokenizer.from_pretrained(model_name)
    except:
        tokenizer = GPT2Tokenizer.from_pretrained(config.TOKENIZER)

    tokenizer.add_special_tokens({
        'eos_token':
        '[EOS]',
        'additional_special_tokens': [
            'LocationOfAction', 'HinderedBy', 'HasFirstSubevent',
            'NotHasProperty', 'NotHasA', 'HasA', 'AtLocation', 'NotCapableOf',
            'CausesDesire', 'HasPainCharacter', 'NotDesires', 'MadeUpOf',
            'InstanceOf', 'SymbolOf', 'xReason', 'isAfter', 'HasPrerequisite',
            'UsedFor', 'MadeOf', 'MotivatedByGoal', 'Causes', 'oEffect',
            'CreatedBy', 'ReceivesAction', 'NotMadeOf', 'xWant', 'PartOf',
            'DesireOf', 'HasPainIntensity', 'xAttr', 'DefinedAs', 'oReact',
            'xIntent', 'HasSubevent', 'oWant', 'HasProperty', 'IsA',
            'HasSubEvent', 'LocatedNear', 'Desires', 'isFilledBy', 'isBefore',
            'InheritsFrom', 'xNeed', 'xEffect', 'xReact', 'HasLastSubevent',
            'RelatedTo', 'CapableOf', 'NotIsA', 'ObjectUse', '[GEN]'
        ]
    })
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})

    train_dataset = pd.read_csv(os.environ.get(
        'TRAIN_DATA_PATH', "/tmp/gpt2data/atomic_train.tsv"),
                                encoding='latin-1',
                                sep="\t")
    if DEBUG:
        train_dataset = train_dataset.head(NUM_INST)
    # train_dataset = train_dataset[['head_event', 'tail_event', 'relation']]
    train_dataset.head_event = train_dataset.head_event + ' ' + train_dataset.relation \
                               + " [GEN]"
    train_dataset.tail_event = train_dataset.tail_event + ' [EOS]'
    logger.info(train_dataset.head())
    logger.info(train_dataset.tail_event)

    val_dataset = pd.read_csv(os.environ.get('DEV_DATA_PATH',
                                             "/tmp/gpt2data/atomic_dev.tsv"),
                              encoding='latin-1',
                              sep="\t")
    if DEBUG:
        val_dataset = val_dataset.head(NUM_INST)
    val_dataset = val_dataset[['head_event', 'tail_event', 'relation']]
    val_dataset.head_event = val_dataset.head_event + ' ' + val_dataset.relation + " [GEN]"
    val_dataset.tail_event = val_dataset.tail_event + ' [EOS]'
    logger.info(val_dataset.tail_event)
    logger.info(val_dataset.head())

    test_dataset = pd.read_csv(os.environ.get('TEST_DATA_PATH',
                                              "/tmp/gpt2data/atomic_test.tsv"),
                               encoding='latin-1',
                               sep="\t")
    if DEBUG:
        test_dataset = test_dataset.head(NUM_INST)
    test_dataset = test_dataset[['head_event', 'tail_event', 'relation']]
    test_dataset.head_event = test_dataset.head_event + ' ' + test_dataset.relation \
                              + " [GEN]"
    test_dataset.tail_event = test_dataset.tail_event + ' [EOS]'
    logger.info(test_dataset.tail_event)
    logger.info(test_dataset.head())

    val_dataset_mini = pd.read_csv(os.environ.get(
        'DEV_DATA_PATH', "/tmp/gpt2data/atomic_dev.tsv"),
                                   encoding='latin-1',
                                   sep="\t")
    if DEBUG:
        val_dataset_mini = val_dataset_mini.head(5)
    val_dataset_mini = val_dataset_mini.sample(n=min(
        int(val_dataset_mini.size / 3), 100),
                                               random_state=config.SEED)
    val_dataset_mini = val_dataset_mini[[
        'head_event', 'tail_event', 'relation'
    ]]
    val_dataset_mini.head_event = val_dataset_mini.head_event + ' ' + val_dataset_mini.relation + " [GEN]"
    val_dataset_mini.tail_event = val_dataset_mini.tail_event + ' [EOS]'
    logger.info(val_dataset_mini.tail_event)
    logger.info(val_dataset_mini.head())

    logger.info("TRAIN Dataset tuple count: {}".format(train_dataset.shape))
    logger.info("DEV Dataset tuple_count: {}".format(val_dataset.shape))
    logger.info("DEV MINI Dataset tuple_count: {}".format(
        val_dataset_mini.shape))

    training_set = KGDataset(train_dataset,
                             tokenizer,
                             config.OUT_LEN,
                             config.SUMMARY_LEN,
                             model="gpt2")
    val_set = KGDataset(val_dataset,
                        tokenizer,
                        config.IN_LEN,
                        config.OUT_LEN - config.IN_LEN,
                        model="gpt2",
                        is_eval=True)
    val_set_mini = KGDataset(val_dataset.head(2000),
                             tokenizer,
                             config.IN_LEN,
                             config.OUT_LEN - config.IN_LEN,
                             model="gpt2",
                             is_eval=True)
    test_set = KGDataset(test_dataset,
                         tokenizer,
                         config.IN_LEN,
                         config.OUT_LEN - config.IN_LEN,
                         model="gpt2",
                         is_eval=True)

    train_params = {
        'batch_size': config.TRAIN_BATCH_SIZE,
        'shuffle': True,
        'num_workers': 0
    }

    val_params = {'batch_size': 1, 'shuffle': False, 'num_workers': 0}

    training_loader = DataLoader(training_set, **train_params, drop_last=True)
    val_loader = DataLoader(val_set, **val_params, drop_last=True)
    test_loader = DataLoader(test_set, **val_params, drop_last=True)
    val_loader_mini = DataLoader(val_set_mini, **val_params, drop_last=True)

    logging.info("Loading model from {}".format(model_name))
    model = GPT2LMHeadModel.from_pretrained(model_name, use_cdn=False)
    logging.info("Move model to device {}".format(device))
    model = model.to(device)
    model.resize_token_embeddings(len(tokenizer))

    optimizer = torch.optim.Adam(params=model.parameters(),
                                 lr=config.LEARNING_RATE)

    wandb.watch(model, log="all")

    if config.DO_TRAIN:
        logger.info('Initiating Fine-Tuning for the model on our dataset')

        for epoch in range(config.TRAIN_EPOCHS):
            train(epoch,
                  tokenizer,
                  model,
                  device,
                  training_loader,
                  optimizer,
                  val_loader_mini,
                  model_class="gpt2")
            model.save_pretrained('{}/checkpoint_{}'.format(
                config.OUT_DIR, epoch))
            tokenizer.save_pretrained('{}/checkpoint_{}'.format(
                config.OUT_DIR, epoch))
        model.save_pretrained('/models')

    if config.DO_PRED:

        if config.PRED_FILE.endswith("jsonl"):
            records = read_jsonl_lines(config.PRED_FILE)
            pred_dataset = pd.DataFrame.from_records(records)
            pred_dataset = pred_dataset.rename(columns={
                "head": "head_event",
                "tails": "tail_event"
            })
            pred_dataset = pred_dataset.explode('tail_event')
        else:
            pred_dataset = pd.read_csv(config.PRED_FILE,
                                       encoding='latin-1',
                                       sep="\t")

        if DEBUG:
            pred_dataset = pred_dataset.head(NUM_INST)

        pred_dataset = pred_dataset.drop_duplicates(['head_event', 'relation'],
                                                    ignore_index=True)

        pred_dataset.head_event = pred_dataset.head_event + ' ' + pred_dataset.relation + " [GEN]"
        pred_dataset.tail_event = pred_dataset.tail_event + ' [EOS]'
        logger.info(pred_dataset.tail_event)
        logger.info(pred_dataset.head())

        pred_set = KGDataset(pred_dataset,
                             tokenizer,
                             config.IN_LEN,
                             config.OUT_LEN - config.IN_LEN,
                             model="gpt2",
                             is_eval=True)
        pred_loader = DataLoader(pred_set, **val_params, drop_last=False)

        pred_generations = beam_generations(tokenizer,
                                            model,
                                            device,
                                            pred_loader,
                                            top_k=config.TOP_K)
        write_items(os.path.join(config.OUT_DIR, "pred_generations.jsonl"),
                    [json.dumps(r) for r in pred_generations])

        # Resave the model to keep generations and model associated
        model.save_pretrained('/models')
        tokenizer.save_pretrained('/models')
    # Create tokenizers
    model_name = pretrained_models[0]
    gpt2_tokenizer = None
    if pretrained_models == [model_name]:
        gpt2_tokenizer = GPT2Tokenizer.from_pretrained(model_name)
    else:
        raise NotImplementedError("Only the following tokenizers are supported: {}".format(model_name))

    num_keywords = len(TARG)
    num_possible_labels = int(1 + num_keywords)

    model = None
    tokenizer = None
    if 'gpt2' in model_name:
        model = GPT2LMHeadModel.from_pretrained(model_name)
        tokenizer = gpt2_tokenizer
    else:
        raise NotImplementedError("model_name == {} not supported".format(model_name))

    model.transformer.output_hidden_states = True  # necessary to pull activation tensors
    device = torch.device("cpu")
    if torch.cuda.is_available():
        model = model.cuda()
        device = torch.device("cuda")

    try:
        # BEGINNING ##############################################################################################################################

        print("and so it begins", flush=True)
        dataset = []
Example #9
        length = MAX_LENGTH  # avoid infinite loop
    return length


def main(
    max_length=10_000,
    temperature=1,
    top_k=0,
    top_p=0.9,
    repetition_penalty=1.0,
    num_return_sequences=1,
    prompt_text="Your flavors are good, but your texture's all wrong.",
):
    # Initialize the model and tokenizer
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = GPT2LMHeadModel.from_pretrained("gpt2")
    # model.to(args.device)
    encoded_prompt = tokenizer.encode(prompt_text,
                                      add_special_tokens=False,
                                      return_tensors="pt")
    # encoded_prompt = encoded_prompt.to(args.device)
    if encoded_prompt.size()[-1] == 0:
        input_ids = None
    else:
        input_ids = encoded_prompt
    max_length += len(encoded_prompt[0])
    output_sequences = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        temperature=temperature,
        top_k=top_k,
Example #10
with open(os.path.join(MODEL_DIR, 'additional_ids_to_tokens.pkl'), 'rb') as f:
    additional_ids_to_tokens = pickle.load(f)
additional_tokens_to_ids = {v: k for k, v in additional_ids_to_tokens.items()}
try:
    ilm.tokenize_util.update_tokenizer(additional_ids_to_tokens, tokenizer)
except ValueError:
    print('Already updated')
print(additional_tokens_to_ids)

# Load model

import torch
from transformers import GPT2LMHeadModel

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GPT2LMHeadModel.from_pretrained(MODEL_DIR)
model.eval()
_ = model.to(device)

context = """
.  _ funny _ bad behavior .
""".strip()

context_ids = ilm.tokenize_util.encode(context, tokenizer)
print(context_ids)
# Replace blanks with appropriate tokens from left to right
_blank_id = ilm.tokenize_util.encode(' _', tokenizer)[0]
context_ids[context_ids.index(
    _blank_id)] = additional_tokens_to_ids['<|infill_one|>']
context_ids[context_ids.index(
    _blank_id)] = additional_tokens_to_ids['<|infill_one|>']
Example #11
    assert len(past) == len(cond_past)
    for i in range(len(past)):
        past[i] += WEIGHT * cond_past[i]
    return past


def top_k_filtering(logits, top_k=1, filter_value=-float("Inf"), min_tokens_to_keep=1):
    top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))
    ids_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
    ids_to_retain = torch.topk(logits, top_k)[1][0]
    logits[ids_to_remove] = filter_value
    return logits, ids_to_retain
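
# Sketch only (not part of the original script): typical use of the filter above,
# shown on random logits standing in for the model's next-token logits.
_example_logits = torch.randn(1, 50257)        # hypothetical (batch, vocab_size) logits
_filtered, _kept = top_k_filtering(_example_logits, top_k=50)
_probs = torch.softmax(_filtered, dim=-1)      # renormalize over the surviving tokens
_next_token = torch.multinomial(_probs, num_samples=1)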


tokenizer = GPT2Tokenizer.from_pretrained(MODEL)
model = GPT2LMHeadModel.from_pretrained(MODEL).to(DEV)
COND_IDS = tokenizer.encode(COND)
cond_ids = torch.tensor([COND_IDS]).to(DEV)


input_ids = torch.tensor([tokenizer.encode(PREFIX, add_special_tokens=True)]).to(DEV)
input_past = model(input_ids[:, :-1])[1]

for t in range(input_ids.shape[1]-1, LENGTH):  # +1 for the last time step of prefix

    with torch.no_grad():
        position_ids = torch.tensor([[t]]).to(DEV)
        past = input_past
        for i in range(cond_ids.shape[1]):
            cur_past = model(cond_ids[:, i:i+1], position_ids=position_ids)[1]
            past = cat_past(past, cur_past)
for book in ["meditations", "epictetus_discourses"]:
    with open(f"{book}.txt", 'r+') as f:
        text = f.read()
        text = re.sub("[A-Z][A-Z]+\.", " ", text)
        text = re.sub("[A-Z][A-Z]+", " ", text)
        text = re.sub("CHAPTER [0-9]+\.", " ", text)
        text = re.sub("CHAPTER [0-9]+", " ", text)
        text = re.sub("X\.", " ", text)
        text = re.sub("V\.", " ", text)
        text = re.sub("[0-9]+", "", text)
        text = re.sub("[0-9]+\.", "", text)
        text = ' '.join(text.split())
        corpus += text

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
datagen = DataGenerator(corpus, seq_size, tokenizer)
dataloader = DataLoader(datagen,
                        batch_size=batch_size,
                        shuffle=True,
                        num_workers=4)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
loss_list = []

timestamp = time.strftime("%y_%m_%d_%H_%M_%S")
if continue_training_path:
    # load a network for continued training
    print('continue training')
    model.load_state_dict(
        torch.load(continue_training_path +
Example #13
    def __init__(self) -> None:
        self.lists = {}

        # M-BERT
        from transformers import BertTokenizerFast, BertForMaskedLM
        self.bert_multilingual_tokenizer = BertTokenizerFast.from_pretrained(
            'bert-base-multilingual-cased')
        self.bert_multilingual_model = BertForMaskedLM.from_pretrained(
            'bert-base-multilingual-cased').eval()
        self.lists["M-BERT"] = {
            "Tokenizer": self.bert_multilingual_tokenizer,
            "Model": self.bert_multilingual_model
        }
        print("====================================")
        print("[BERT] Google Multilingual BERT 로드 완료")
        print("====================================")

        # KR-BERT
        from transformers import BertTokenizerFast, BertForMaskedLM
        self.krbert_tokenizer = BertTokenizerFast.from_pretrained(
            'snunlp/KR-Medium')
        self.krbert_model = BertForMaskedLM.from_pretrained(
            'snunlp/KR-Medium').eval()
        self.lists["KR-Medium"] = {
            "Tokenizer": self.krbert_tokenizer,
            "Model": self.krbert_model
        }
        print("====================================")
        print("[BERT] KR-BERT 로드 완료")
        print("====================================")

        # BERT
        from transformers import BertTokenizerFast, BertForMaskedLM
        self.bert_kor_tokenizer = BertTokenizerFast.from_pretrained(
            'kykim/bert-kor-base')
        self.bert_kor_model = BertForMaskedLM.from_pretrained(
            'kykim/bert-kor-base').eval()
        self.lists["bert-kor-base"] = {
            "Tokenizer": self.bert_kor_tokenizer,
            "Model": self.bert_kor_model
        }
        print("====================================")
        print("[BERT] BERT-kor-base 로드 완료")
        print("====================================")

        # ALBERT
        from transformers import AlbertForMaskedLM
        self.albert_tokenizer = BertTokenizerFast.from_pretrained(
            'kykim/albert-kor-base')
        self.albert_model = AlbertForMaskedLM.from_pretrained(
            'kykim/albert-kor-base').eval()
        self.lists["albert-kor-base"] = {
            "Tokenizer": self.albert_tokenizer,
            "Model": self.albert_model
        }
        print("====================================")
        print("[BERT] ALBERT-kor-base 로드 완료")
        print("====================================")

        # XLM-Roberta
        from transformers import XLMRobertaTokenizerFast, XLMRobertaForMaskedLM
        self.xlmroberta_tokenizer = XLMRobertaTokenizerFast.from_pretrained(
            'xlm-roberta-base')
        self.xlmroberta_model = XLMRobertaForMaskedLM.from_pretrained(
            'xlm-roberta-base').eval()
        self.lists["xlm-roberta-base"] = {
            "Tokenizer": self.xlmroberta_tokenizer,
            "Model": self.xlmroberta_model
        }
        print("====================================")
        print("[BERT] XLM-Roberta-kor 로드 완료")
        print("====================================")

        from transformers import BertTokenizerFast, EncoderDecoderModel
        self.tokenizer_bertshared = BertTokenizerFast.from_pretrained(
            "kykim/bertshared-kor-base")
        self.bertshared_model = EncoderDecoderModel.from_pretrained(
            "kykim/bertshared-kor-base")
        self.lists["bertshared-kor-base"] = {
            "Tokenizer": self.tokenizer_bertshared,
            "Model": self.bertshared_model
        }
        print("====================================")
        print("[Seq2seq + BERT] bertshared-kor-base 로드 완료")
        print("====================================")

        # gpt3-kor-small_based_on_gpt2
        from transformers import BertTokenizerFast, GPT2LMHeadModel
        self.tokenizer_gpt3 = BertTokenizerFast.from_pretrained(
            "kykim/gpt3-kor-small_based_on_gpt2")
        self.model_gpt3 = GPT2LMHeadModel.from_pretrained(
            "kykim/gpt3-kor-small_based_on_gpt2")
        self.lists["gpt3-kor-small_based_on_gpt2"] = {
            "Tokenizer": self.tokenizer_gpt3,
            "Model": self.model_gpt3
        }
        print("====================================")
        print("[GPT3] gpt3-small-based-on-gpt2 로드 완료")
        print("====================================")

        # electra-base-kor
        from transformers import ElectraTokenizerFast, ElectraModel
        self.tokenizer_electra = ElectraTokenizerFast.from_pretrained(
            "kykim/electra-kor-base")
        self.electra_model = ElectraModel.from_pretrained(
            "kykim/electra-kor-base")
        self.lists["electra-kor-base"] = {
            "Tokenizer": self.tokenizer_electra,
            "Model": self.electra_model
        }
        print("====================================")
        print("[ELECTRA] electra-kor-base 로드 완료")
        print("====================================")

        from transformers import ElectraTokenizerFast, ElectraForQuestionAnswering
        self.electra_tokenizer_QA = ElectraTokenizerFast.from_pretrained(
            "monologg/koelectra-base-v3-finetuned-korquad")
        self.electra_model_QA = ElectraForQuestionAnswering.from_pretrained(
            "monologg/koelectra-base-v3-finetuned-korquad")
        self.lists["electra-kor-QA"] = {
            "Tokenizer": self.electra_tokenizer_QA,
            "Model": self.electra_model_QA
        }
        print("====================================")
        print("[ELECTRA] koelectra-base-v3-finetuned-korquad 로드 완료")
        print("====================================")
Example #14
    collate_fn=data_collator,
    drop_last=DATALOADER_DROP_LAST,
)

# eval_dataset = PoetryDataset('data/whitman/input.txt', is_eval=True)
# eval_sampler = SequentialSampler(eval_dataset)
# eval_iterator = DataLoader(
#     eval_dataset,
#     sampler=eval_sampler,
#     batch_size=BATCH_SIZE,
#     collate_fn=data_collator,
#     drop_last=DATALOADER_DROP_LAST,
# )
# -

model = GPT2LMHeadModel.from_pretrained(MODEL_NAME).to(device)
model.resize_token_embeddings(len(tokenizer))


def create_optimizer_and_scheduler(
    weight_decay: float = 0.0,
    learning_rate: float = 5e-5,
    beta1: float = 0.9,
    beta2: float = 0.999,
    epsilon: float = 1e-8,
    warmup_steps=0,
):
    """
    Setup the optimizer and the learning rate scheduler.
    We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
    Trainer's init through :obj:`optimizers`, or subclass and override this method in a subclass.
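    """
    # --- Completion sketch (added; an assumption, not the original file's body): the
    # standard transformers recipe of AdamW with weight-decay parameter groups plus a
    # linear warmup/decay schedule. `model` is the GPT2LMHeadModel created above;
    # `num_training_steps` is assumed to be defined elsewhere in the script.
    from torch.optim import AdamW
    from transformers import get_linear_schedule_with_warmup

    no_decay = ["bias", "LayerNorm.weight"]
    grouped_parameters = [
        {"params": [p for n, p in model.named_parameters()
                    if not any(nd in n for nd in no_decay)],
         "weight_decay": weight_decay},
        {"params": [p for n, p in model.named_parameters()
                    if any(nd in n for nd in no_decay)],
         "weight_decay": 0.0},
    ]
    optimizer = AdamW(grouped_parameters, lr=learning_rate,
                      betas=(beta1, beta2), eps=epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps,
        num_training_steps=num_training_steps)  # num_training_steps: assumed global
    return optimizer, scheduler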
Example #15
0
#!/usr/bin/env python3
# coding: utf-8
########################################
# authors                              #
# marcalph https://github.com/marcalph #
########################################
""" composition baseline
"""
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch
torch.manual_seed(42)
tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
model = GPT2LMHeadModel.from_pretrained("distilgpt2", pad_token_id=tokenizer.eos_token_id)


def compose(string, possibilities=3):
    """ composition func
    Parameters
    ----------
    string : starting string
    Returns
    -------
    output : completed string
    """
    input = tokenizer.encode(string, return_tensors='pt')
    sampling_outputs = model.generate(
        input,
        do_sample=True,
        max_length=50,
        top_k=50,
        top_p=0.95,
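        # --- Completion sketch (added): the original call is truncated here. The rest
        # of the call and the decoding below are assumptions based on the standard
        # model.generate() API.
        num_return_sequences=possibilities,
    )
    output = [tokenizer.decode(seq, skip_special_tokens=True)
              for seq in sampling_outputs]
    return output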
Example #16
0
from transformers import GPT2LMHeadModel, GPT2Tokenizer

CACHE_DIR = '/scratch/gpfs/hgazula/.cache/'
tokenizer = GPT2Tokenizer.from_pretrained('gpt2-xl',
                                          add_prefix_space=True,
                                          cache_dir=CACHE_DIR)
tokenizer.pad_token = tokenizer.eos_token

lm_model = GPT2LMHeadModel.from_pretrained("gpt2-xl",
                                           output_hidden_states=True,
                                           cache_dir=CACHE_DIR)
lm_model.eval()

sentences = [
    "i'm asking because i wanna measure it you finish the iced tea and this"
]
tokens = tokenizer.tokenize(sentences[0])
ids = tokenizer.convert_tokens_to_ids(tokens)
tok_to_str = tokenizer.convert_tokens_to_string(tokens[18])
print(tokenizer.encode(sentences[0]))
print(tokenizer.decode(ids[18]))

# sentences = ['Hello',
#  'Hello world',
#  'Hello world there',
#  'Hello world there you',
#  'Hello world there you are',
#  'world there you are high'
#  ]

input_ids = tokenizer(sentences, padding=True, return_tensors='pt')
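# --- Sketch (added; not in the original snippet): run the padded batch through the
# model. Because it was loaded with output_hidden_states=True, the output exposes one
# hidden-state tensor per layer (plus the embedding layer), each of shape
# (batch, seq_len, hidden_size).
import torch

with torch.no_grad():
    outputs = lm_model(**input_ids)  # `input_ids` above is actually a BatchEncoding
hidden_states = outputs.hidden_states
print(len(hidden_states), hidden_states[-1].shape)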
Example #17
0
async def analyze(request):

    # identify mentor and question that user selected
    input_form = await request.form()
    question_prompt = input_form[
        'question']  # used to identify the language model prompt category
    mentor = input_form['mentor']  # used for model generation signature

    # format linked text
    hyperlink_format = '<a href="{website}" style="color:blue; border-bottom: 1px solid" target="_blank" rel="noopener">{text}</a>'

    # generate answer prompts for model
    if question_prompt == "42":
        answer = answer_to_life()
        return JSONResponse({'result': answer})
    elif question_prompt == "empty":
        answer_link = hyperlink_format.format(
            website='https://giphy.com/gifs/life-bid-dJv3R2vXEjPRm/fullscreen',
            text="I'll try to answer.")
        answer = "Choose your question and " + answer_link
        return JSONResponse({'result': answer})
    else:
        answer = answer_prompt(
            question_prompt
        )  # use a language model prompt/seed from "question_prompts.py"

    # load gpt-2 model for selected mentor
    if mentor == "seneca":
        model = GPT2LMHeadModel.from_pretrained(path /
                                                f'{model_seneca_folder}')
        tokenizer = GPT2Tokenizer.from_pretrained(path /
                                                  f'{model_seneca_folder}')
    elif mentor == "paul-graham":
        model = GPT2LMHeadModel.from_pretrained(path /
                                                f'{model_paul_graham_folder}')
        tokenizer = GPT2Tokenizer.from_pretrained(
            path / f'{model_paul_graham_folder}')
    elif mentor == "david-goggins":
        model = GPT2LMHeadModel.from_pretrained(
            path / f'{model_david_goggins_folder}')
        tokenizer = GPT2Tokenizer.from_pretrained(
            path / f'{model_david_goggins_folder}')
    elif mentor == "brene-brown":
        model = GPT2LMHeadModel.from_pretrained(path /
                                                f'{model_brene_brown_folder}')
        tokenizer = GPT2Tokenizer.from_pretrained(
            path / f'{model_brene_brown_folder}')
    elif mentor == "tony-robbins":
        model = GPT2LMHeadModel.from_pretrained(path /
                                                f'{model_tony_robbins_folder}')
        tokenizer = GPT2Tokenizer.from_pretrained(
            path / f'{model_tony_robbins_folder}')

    # generate model prediction (string of text)
    model_prediction = predict(answer,
                               metadata,
                               model=model,
                               tokenizer=tokenizer)
    ram_mb = memory_usage_psutil()[1]
    print(f"After generating prediction, {ram_mb} MB of RAM is in use")

    # reset variables to save RAM
    model = None
    tokenizer = None

    # garbage collect to free up memory
    gc.collect()
    ram_mb = memory_usage_psutil()[1]
    print(f"After garbage collect, {ram_mb} MB of RAM is in use")

    # find position of final punctuation mark in model prediction (period, question, or exclamation)
    punct_period = model_prediction.rfind('.')
    punct_excl = model_prediction.rfind('!')
    punct_quest = model_prediction.rfind('?')

    punct_per_quote = model_prediction.rfind(
        '."') + 1  # if ends in quotation, include final quotation
    punct_exc_quote = model_prediction.rfind(
        '!"') + 1  # if ends in quotation, include final quotation
    punct_ques_quote = model_prediction.rfind(
        '?"') + 1  # if ends in quotation, include final quotation

    # find position to truncate model prediction
    min_chars = 100  # used as default in case generated language does not include ending punctuation and returns '-1'
    punct_max = max(punct_period, punct_excl, punct_quest, punct_per_quote,
                    punct_exc_quote, punct_ques_quote, min_chars)

    # truncate model prediction so it ends with punctuation mark
    trunc_model_prediction = model_prediction[:punct_max + 1]
    print(trunc_model_prediction
          )  # printed in logs to ensure it is working as expected

    # add a closing signature (not generated) for mentor
    if mentor == "seneca":
        sig_start = "</br></br>Get the rest of my free stoic letters "
        sig_link = hyperlink_format.format(
            website='https://tim.blog/2017/07/06/tao-of-seneca/', text='here')
        sig_end = ". Farewell."
        sig = sig_start + sig_link + sig_end
    elif mentor == "paul-graham":
        sig_start = "</br></br>Want to start a startup? Get funded by "
        sig_link = hyperlink_format.format(
            website='https://www.ycombinator.com/apply.html',
            text='Y Combinator')
        sig = sig_start + sig_link
    elif mentor == "brene-brown":
        sig_start = "</br></br>Be vulnerable. Be seen. "
        sig_link = hyperlink_format.format(
            website='https://youtu.be/-s6DQrqVHxM?t=37',
            text='Get in the arena.')
        sig = sig_start + sig_link
    elif mentor == "david-goggins":
        sig_start = "</br></br>"
        sig_link = hyperlink_format.format(
            website='https://www.youtube.com/watch?v=DS0ed93UQeY',
            text='#TakingSouls')
        sig = sig_start + sig_link
    elif mentor == "tony-robbins":
        sig_start = "</br></br>God Bless. Thank You.</br>#IamNotYourGuru"
        sig = sig_start

    return JSONResponse({'result': trunc_model_prediction + sig})
Example #18
0
	def __init__(self, pretrained='gpt2'):
		self.tokenizer = GPT2Tokenizer.from_pretrained(pretrained)
		self.gpt2_model = GPT2LMHeadModel.from_pretrained(pretrained).to(device)
		self.pad_id = self.tokenizer._convert_token_to_id('<|endoftext|>')
Example #19
0
    def __init__(self):
        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        self.model = GPT2LMHeadModel.from_pretrained('gpt2').to(device)
        self.model.eval()
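    # --- Sketch (added; an assumption about how such a wrapper is typically used):
    # score a sentence by its GPT-2 perplexity (exp of the mean token cross-entropy).
    def perplexity(self, sentence):
        import torch
        ids = self.tokenizer.encode(sentence, return_tensors='pt').to(device)
        with torch.no_grad():
            outputs = self.model(ids, labels=ids)
        return torch.exp(outputs[0]).item()  # outputs[0] is the mean LM loss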
Example #20
0
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--device',
                        default='0,1,2,3',
                        type=str,
                        required=False,
                        help='which GPUs to use')
    parser.add_argument('--length',
                        default=-1,
                        type=int,
                        required=False,
                        help='length of generated text')
    parser.add_argument('--temperature',
                        default=1,
                        type=float,
                        required=False,
                        help='sampling temperature; higher is more random')
    parser.add_argument('--topk',
                        default=8,
                        type=int,
                        required=False,
                        help='top-k: sample from the k most likely tokens')
    parser.add_argument('--topp',
                        default=0,
                        type=float,
                        required=False,
                        help='top-p: cumulative probability cutoff for nucleus sampling')
    parser.add_argument('--model_config',
                        default='config/model_config_small.json',
                        type=str,
                        required=False,
                        help='path to the model config file')
    parser.add_argument('--tokenizer_path',
                        default='cache/vocab_seg_all.txt',
                        type=str,
                        required=False,
                        help='path to the vocabulary file')
    parser.add_argument('--model_path',
                        default='model/final_model',
                        type=str,
                        required=False,
                        help='path to the model')
    parser.add_argument('--save_path',
                        default='generated/',
                        type=str,
                        required=False,
                        help='directory for the generated files')
    parser.add_argument('--articles_per_title',
                        default=5,
                        type=int,
                        required=False,
                        help='number of articles to generate per title')
    parser.add_argument('--titles',
                        default='萧炎',
                        type=str,
                        required=False,
                        help='list of titles, given as a single string separated by spaces')
    parser.add_argument('--titles_file',
                        default='',
                        type=str,
                        required=False,
                        help='file with one title per line; if set, --titles is ignored')
    parser.add_argument('--no_wordpiece',
                        action='store_true',
                        help='do not use WordPiece tokenization')
    parser.add_argument('--segment', action='store_true', help='tokenize Chinese at the word level')
    parser.add_argument('--repetition_penalty',
                        default=1.0,
                        type=float,
                        required=False)

    args = parser.parse_args()
    print('args:\n' + args.__repr__())

    if args.segment:
        from tokenizations import tokenization_bert_word_level as tokenization_bert
    else:
        from tokenizations import tokenization_bert

    os.environ["CUDA_VISIBLE_DEVICES"] = args.device  # 此处设置程序使用哪些显卡
    length = args.length
    temperature = args.temperature
    topk = args.topk
    topp = args.topp
    repetition_penalty = args.repetition_penalty

    titles = args.titles.split()  # list whose elements are the titles to generate from
    if args.titles_file:
        with open(args.titles_file, 'r') as f:
            titles = [line.strip('\n') for line in f.readlines()]
    articles_per_title = args.articles_per_title  # how many articles to generate per title
    save_path = args.save_path  # where to save the generated files

    device = "cuda" if torch.cuda.is_available() else "cpu"

    tokenizer = tokenization_bert.BertTokenizer(vocab_file=args.tokenizer_path)
    model = GPT2LMHeadModel.from_pretrained(args.model_path)
    model.to(device)
    model.eval()

    n_ctx = model.config.n_ctx

    if not os.path.exists(save_path):
        os.mkdir(save_path)
    if length == -1:
        length = model.config.n_ctx

    for i, title in enumerate(titles):
        for j in range(articles_per_title):
            with open(save_path + str(i) + '-' + str(j) + '.txt', 'w') as f:
                context_tokens = tokenizer.convert_tokens_to_ids(
                    tokenizer.tokenize(title))
                generated = 0
                out = sample_sequence(n_ctx=n_ctx,
                                      model=model,
                                      length=length,
                                      context=context_tokens,
                                      tokenizer=tokenizer,
                                      temperature=temperature,
                                      top_k=topk,
                                      top_p=topp,
                                      repitition_penalty=repetition_penalty,
                                      device=device)
                out = out.tolist()[0]

                generated += 1
                text = tokenizer.convert_ids_to_tokens(out)

                # use idx (not i) so the outer title index `i`, used in the file name, is not overwritten
                for idx, item in enumerate(text[:-1]):  # make sure English words are surrounded by spaces
                    if is_word(item) and is_word(text[idx + 1]):
                        text[idx] = item + ' '

                for idx, item in enumerate(text):
                    if item == '[MASK]':
                        text[idx] = ''
                    if item == '[CLS]' or item == '[SEP]':
                        text[idx] = '\n'

                print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
                text = ''.join(text).replace('##', '').strip()
                # text = ''.join(text.split('\n')[:-1])
                print(text)
                f.write(text + '\n')
                print("=" * 80)
Example #21
0
def load_model(model) -> GPT2LMHeadModel:
    if model in SRC_MODELS:
        model = SRC_MODELS[model]
    return GPT2LMHeadModel.from_pretrained(str(model))
Example #22
0
import logging
logging.basicConfig(level=logging.INFO)

# Load pre-trained model tokenizer (vocabulary)
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# Encode a text inputs
text = "Who was Jim Henson ? Jim Henson was a"
indexed_tokens = tokenizer.encode(text)

# Convert indexed tokens in a PyTorch tensor
tokens_tensor = torch.tensor([indexed_tokens])
'''
Use GPT2LMHeadModel to generate the next token following our text
'''
# Load pre-trained model (weights)
model = GPT2LMHeadModel.from_pretrained('gpt2')

# Set the model in evaluation mode to deactivate the DropOut modules
# This is IMPORTANT to have reproducible results during evaluation!
model.eval()

# Predict all tokens
with torch.no_grad():
    outputs = model(tokens_tensor)
    predictions = outputs[0]

# get the predicted next sub-word (in our case, the word 'man')
predicted_index = torch.argmax(predictions[0, -1, :]).item()
predicted_text = tokenizer.decode(indexed_tokens + [predicted_index])
assert predicted_text == 'Who was Jim Henson? Jim Henson was a man'
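# --- Sketch (added; not in the original snippet): instead of the single argmax token,
# look at the model's top-5 candidates for the next sub-word.
probs = torch.softmax(predictions[0, -1, :], dim=-1)
top_probs, top_indices = torch.topk(probs, k=5)
for p, idx in zip(top_probs.tolist(), top_indices.tolist()):
    print(f"{tokenizer.decode([idx])!r}: {p:.3f}")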

Example #23
0
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch

# adapted from - https://huggingface.co/blog/how-to-generate

seed_text = 'My favorite thing about a wedding at the Breakers Palm Beach is'

# how would your tokenizer be trained? It has to choose the right vocabulary (best guess)
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

model = GPT2LMHeadModel.from_pretrained('gpt2', pad_token_id=tokenizer.eos_token_id)

input_ids = tokenizer.encode(seed_text, return_tensors='pt')

greedy_output = model.generate(input_ids, max_length=50)
print(tokenizer.decode(greedy_output[0], skip_special_tokens=True))
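# --- Sketch (added): greedy decoding above tends to repeat itself; the blog post this
# snippet cites also shows nucleus (top-p) sampling, roughly as follows.
torch.manual_seed(0)  # assumed seed, just to make the sample repeatable
sample_output = model.generate(
    input_ids,
    do_sample=True,   # sample instead of taking the argmax at every step
    max_length=50,
    top_p=0.92,       # keep the smallest token set whose cumulative probability >= 0.92
    top_k=0,          # disable top-k filtering so only top-p applies
)
print(tokenizer.decode(sample_output[0], skip_special_tokens=True))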

Example #24
0
import os
import connexion
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

cwd = os.getcwd()
model_path = cwd + '/distilgpt2'
model = GPT2LMHeadModel.from_pretrained(model_path)
tokenizer = GPT2Tokenizer.from_pretrained(model_path)


def generate(seed):
    input_ids = torch.tensor(tokenizer.encode(seed)).unsqueeze(0)
    outputs = model.generate(input_ids=input_ids, max_length=200)
    return (tokenizer.decode(outputs[0], skip_special_tokens=True))


application = connexion.App(__name__, specification_dir='.')
application.add_api('api.yaml')

if __name__ == "__main__":
    application.run()
Example #25
0
                        help="For debugging purpose")
    args = parser.parse_args()

    args.device = torch.device("cuda")
    args.n_gpu = torch.cuda.device_count()

    if args.model == 'gpt2-medium':
        args.batch_size = 2
    else:
        args.batch_size = 5

    if args.do_rl:
        args.batch_size = 1

    tokenizer = GPT2Tokenizer.from_pretrained(args.model)
    model = GPT2LMHeadModel.from_pretrained(args.model)
    model = nn.DataParallel(model)
    model.to(args.device)

    if not os.path.exists(args.id):
        os.mkdir(args.id)

    criterion = nn.CrossEntropyLoss(reduction='none', ignore_index=-1)
    if args.do_train:
        tb_writer = SummaryWriter(
            log_dir='tensorboard/GPT2-{}'.format(args.model))
        dataset = GPTTableDatabase('data/train_lm.json', None, None, tokenizer,
                                   args.batch_size, args.max_len)
        model.train()
        optimizer = optim.Adam(model.parameters(), args.learning_rate)
Example #26
0
 def build_model(self):
     """build model"""
     self.model = GPT2LMHeadModel.from_pretrained(os.path.join(self.dpath, 'gpt2'))
     self.loss = CrossEntropyLoss(ignore_index=self.pad_id)
Example #27
0
import torch

from transformers import (GPT2LMHeadModel, GPT2Tokenizer)

app = Flask(__name__)

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')
model = GPT2LMHeadModel.from_pretrained('gpt2-medium')
model.to(device)
length = 32


@app.route("/generate/", methods=['GET', 'POST'])
def root():
    prompt_text = str(request.values['Body'])
    print(prompt_text)
    encoded_prompt = tokenizer.encode(prompt_text,
                                      add_special_tokens=False,
                                      return_tensors="pt")

    output_sequences = model.generate(input_ids=encoded_prompt,
                                      max_length=length,
                                      temperature=1.0,
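                                      # --- Completion sketch (added): the call is
                                      # truncated here; the remaining arguments and the
                                      # response below are assumptions based on the
                                      # standard generate() API.
                                      do_sample=True,
                                      top_p=0.9)
    generated_text = tokenizer.decode(output_sequences[0], skip_special_tokens=True)
    return generated_text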
Example #28
0
args = get_train_args()

logger.info(json.dumps(args.__dict__, indent=2))
logger.info((f"Training in '{device}' "
             f"with {num_gpu} GPU{'s'*(num_gpu > 1)} "))

writer = SummaryWriter(log_dir=args.log_dir)

logger.info(f"Loading tokenizer from {args.model_name}...")
tokenizer = GPT2Tokenizer.from_pretrained(args.model_name)

logger.info(f"Loading model from {args.model_name}...")
model = GPT2LMHeadModel.from_pretrained(
    args.model_name,
    pad_token_id=tokenizer.eos_token_id,
    gradient_checkpointing=True,
)
model = model.to(device)
model.train()

dataset = ScreenwriterData(
    tokenizer,
    block_size=args.block_size,
    recompute=args.recompute_data,
)

data_loader = DataLoader(
    dataset,
    batch_size=args.batch_size,
    shuffle=True,
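    # --- Completion sketch (added): the DataLoader call is truncated here. Closing it
    # and the minimal training loop below are assumptions; the batch format of
    # ScreenwriterData (a tensor of token ids per batch) is assumed as well.
)
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)  # args.lr: assumed flag
for step, batch in enumerate(data_loader):
    batch = batch.to(device)
    outputs = model(batch, labels=batch)  # causal-LM loss; labels are shifted internally
    outputs.loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    writer.add_scalar("train/loss", outputs.loss.item(), step)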
Example #29
0
          outputs = model(input_ids, labels=target_ids)
          log_likelihood = outputs[0] * trg_len

      lls.append(log_likelihood)

  ppl = torch.exp(torch.stack(lls).sum() / end_loc)
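# --- Sketch (added): the fragment above is the tail of the standard sliding-window
# perplexity loop; the full computation looks roughly like this, assuming `encodings`
# holds the tokenized evaluation text and `max_length`/`stride` are defined.
#
# lls = []
# for i in range(0, encodings.input_ids.size(1), stride):
#     begin_loc = max(i + stride - max_length, 0)
#     end_loc = min(i + stride, encodings.input_ids.size(1))
#     trg_len = end_loc - i  # number of tokens actually scored in this window
#     input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
#     target_ids = input_ids.clone()
#     target_ids[:, :-trg_len] = -100  # mask out the overlapping context tokens
#     with torch.no_grad():
#         outputs = model(input_ids, labels=target_ids)
#         log_likelihood = outputs[0] * trg_len
#     lls.append(log_likelihood)
# ppl = torch.exp(torch.stack(lls).sum() / end_loc)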


train_fpaths = ["/home/data/train/hgg_train.txt", "/home/data/train/fish_train.txt", "/home/data/train/restaurant_train.txt", "/home/data/train/timetravel_train.txt", "/home/data/train/worldwar_train.txt", "/home/data/train/universe_train.txt"]
dataset = Books("<|sentence|>", truncate=False, gpt2_type="gpt2", train_fpaths=train_fpaths)
gpt2_type = "gpt2"


model = train(
    dataset,
    GPT2LMHeadModel.from_pretrained(gpt2_type),
    GPT2Tokenizer.from_pretrained(gpt2_type),
    batch_size=16,
    epochs=10,
    lr=3e-5,
    max_seq_len=140,
    warmup_steps=5000,
    gpt2_type=gpt2_type,
    device=device,
    output_dir="/home/output/models/",
    output_prefix="spinoff",
    save_model_on_epoch=True
)

story = generate(model.to("cpu"), GPT2Tokenizer.from_pretrained(gpt2_type),"<|sentence|>",entry_count=100)
for i, sentence in enumerate(story):
Example #30
0
def main():
    parser = ArgumentParser()
    parser.add_argument('lang')
    parser.add_argument('model')
    parser.add_argument('-n', type=int, default=None)
    args = parser.parse_args()

    with open(Path('data') / args.lang / 'config.json') as f:
        cfg = json.load(f)

    model_path = Path('data') / args.lang / 'models' / args.model

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('Using device:', device)

    os.environ['TOKENIZERS_PARALLELISM'] = str(False)
    # tokenizer_tgt = Tokenizer.from_file('tgt.tokenizer.json')
    if args.lang == 'ita':
        tokenizer_tgt = GPT2Tokenizer.from_pretrained(
            'LorenzoDeMattei/GePpeTto')
    else:
        tokenizer_tgt = Tokenizer.from_file(
            str(
                Path('data') / args.lang / 'preparation' / 'vocabularies' /
                'tokenizer.json'))

    # model: GPT2LMHeadModel = EmbeddingTunerModel.load_from_checkpoint(model_path).m
    model = GPT2LMHeadModel.from_pretrained(str(model_path))
    model.to(device)

    if args.n is not None:
        tokenizer_eng = GPT2Tokenizer.from_pretrained('gpt2')

        dict_path = Path(
            'data') / args.lang / 'dictionaries' / f'{args.model}.tsv'
        with open(dict_path) as f_map:
            token_id_map = [
                tokenizer_eng.convert_tokens_to_ids(
                    line.strip().split('\t')[1]) for line in f_map
            ]

        print(f'generating {args.n:,} random texts (unconditioned)')

        out_dir = Path('data') / args.lang / 'results' / 'examples'
        os.makedirs(out_dir, exist_ok=True)
        name = str(int(time()))

        tgt_out_path = out_dir / f'{name}.{args.lang}.txt'
        src_out_path = out_dir / f'{name}.eng.txt'

        print(
            f'generating {args.n} {args.lang} examples to {tgt_out_path} [{src_out_path}]'
        )
        with open(tgt_out_path, 'w') as f_tgt, open(src_out_path,
                                                    'w') as f_eng:
            for i, (tgt, eng) in enumerate(
                    gen(tokenizer_tgt,
                        model,
                        device,
                        n=args.n,
                        tokenizer_eng=tokenizer_eng,
                        token_id_map=token_id_map,
                        cfg=cfg)):
                print(f'{i:,}/{args.n:,}')
                f_tgt.write(tgt + '\n\n')
                f_eng.write(eng + '\n\n')

        return

    while True:
        print('\n##########################################')
        prompt = input(' > ').strip()

        for txt in gen(tokenizer_tgt, model, device, prompt, cfg=cfg):
            print('\n' + txt)