Example #1
# Relies on module-level helpers from the same project (sample_sequence_batch_max,
# fast_sample_sequence_batch, sample_sequence_batch_opti, generate, untokenization_poem,
# poemFilter1, dropDuplicateContent) and on `torch` imported at module level.
def generating_poem(app, prefix, model, config, tokenizer, device, quick=False, num=5, batchGenerating=False, gpu='0', onlyMax=False, fast_pattern=False):
    torch.cuda.set_device(int(gpu))
    if len(prefix)>10:
        return []
    #print("start:", prefix)
    global a
    a = app
    n_ctx = model.config.n_ctx
    length = config['length']
    nsamples = num
    batch_size = config['batch_size']
    temperature = config['temperature']
    topk = config['topk']
    topp = config['topp']
    quick_pattern = quick
    repetition_penalty = config['repetition_penalty']
    if length == -1:
        length = model.config.n_ctx
    #print('generating-begin for %s'%prefix)
    raw_text = prefix[0]+prefix
    context_tokens = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(raw_text))
    if batchGenerating:
        if onlyMax:
            outs = sample_sequence_batch_max(model, context_tokens, length, n_ctx, tokenizer, nsamples=2,
                                              temperature=temperature, top_k=topk,
                                              top_p=topp, repitition_penalty=repetition_penalty,
                                              device=device)
        else:
            if fast_pattern:
                outs = fast_sample_sequence_batch(model, context_tokens, length, nsamples=nsamples,
                                                  temperature=temperature, top_k=topk,
                                                  repitition_penalty=repetition_penalty, device=device)
            else:
                outs = sample_sequence_batch_opti(model, context_tokens, length, n_ctx, tokenizer, nsamples,
                                                  temperature=temperature, top_k=topk,
                                                  top_p=topp, repitition_penalty=repetition_penalty,
                                                  device=device)
        S = []
        for out in outs:
            tmptext = untokenization_poem(out, tokenizer, config)
            poem = poemFilter1(tmptext[1:])
            if poem:
                S.append(poem)
    else:
        S = []
        for _ in range(nsamples):
            out = generate(
                n_ctx=n_ctx,
                model=model,
                context=context_tokens,
                length=length,
                is_fast_pattern=fast_pattern, tokenizer=tokenizer,
                temperature=temperature, top_k=topk, top_p=topp, repitition_penalty=repetition_penalty, device=device
            )
            tmptext = untokenization_poem(out, tokenizer, config)
            poem = poemFilter1(tmptext[1:])
            if poem:
                S.append(poem)
    S = dropDuplicateContent(S)
    return S
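A hypothetical call site for the example above. The checkpoint directory, the sampling parameters in `config`, and passing `None` for the app object are all illustrative assumptions, not values taken from the original code; a CUDA device is assumed because the function calls torch.cuda.set_device unconditionally.

import torch
from transformers import GPT2LMHeadModel, BertTokenizer

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = GPT2LMHeadModel.from_pretrained('model/poem')      # assumed checkpoint directory
tokenizer = BertTokenizer.from_pretrained('model/poem')    # assumed vocab location
model.to(device).eval()

config = {
    'length': -1,                 # -1 falls back to model.config.n_ctx
    'batch_size': 1,
    'temperature': 1.0,
    'topk': 8,
    'topp': 0.9,
    'repetition_penalty': 1.1,
}

poems = generating_poem(None, '春日', model, config, tokenizer, device,
                        num=5, batchGenerating=True, gpu='0')
print(poems)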
Example #2
# Nested helper: relies on variables from the enclosing scope (prefix, nsamples,
# tokenizer, model, config, temperature, topk, repetition_penalty, device,
# punc_mid, punc_end, config_predict) and on project helpers such as
# fast_sample_sequence_batch_poemHead, untokenization_poem and poemFilter1.
def getpoem(len_sent, nb_sents):
    #len_sent = 7
    #nb_sents = 8
    if len(prefix) < nb_sents:
        prefix0 = list(prefix) + [''] * (nb_sents - len(prefix) + 1)
    else:
        prefix0 = list(prefix)
    raw_text = '[MASK]' + prefix0[0]
    #raw_text = prefix0[0] + prefix0[0]
    context = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(raw_text))
    contexts = [[c for c in context] for _ in range(nsamples)]
    inputs = [[c for c in context] for _ in range(nsamples)]
    num = nsamples
    for ii in range(1, nb_sents + 1):
        outs = fast_sample_sequence_batch_poemHead(
            model,
            contexts,
            inputs,
            length=len_sent + 1,
            nsamples=num,
            temperature=temperature,
            top_k=topk,
            repitition_penalty=repetition_penalty,
            device=device)
        S = [untokenization_poem(out, tokenizer, config) for out in outs]
        if ii == nb_sents:
            break
        # Keep only candidates long enough to contain ii complete sentences.
        S = [
            tmptext for tmptext in S
            if len(tmptext) > ii * (len_sent + 1) - 1
        ]
        # Alternate the required punctuation: even-numbered sentences must end
        # with a sentence-final mark (punc_end), odd-numbered ones with a
        # mid-sentence mark (punc_mid).
        if ii % 2 == 0:
            S1 = [
                tt[:ii * (len_sent + 1)] for tt in S
                if tt[ii * (len_sent + 1) - 1] in punc_end
            ]
        else:
            S1 = [
                tt[:ii * (len_sent + 1)] for tt in S
                if tt[ii * (len_sent + 1) - 1] in punc_mid
            ]
        # Re-seed the next round with the surviving candidates plus the next head character.
        raw_texts = ['[MASK]' + s + prefix0[ii] for s in S1]
        num = len(raw_texts)
        contexts = [
            tokenizer.convert_tokens_to_ids(tokenizer.tokenize(raw))
            for raw in raw_texts
        ]
        inputs = contexts
        if num == 0:
            break
    R = []
    for s in S:
        # Make sure every surviving poem ends with a sentence-final mark.
        if s[-1] not in punc_end:
            t = s + '。'
        else:
            t = s
        poem = poemFilter1(t, prefix, config_predict.blackwords)
        if poem:
            R.append(poem)
    return R
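The loop above alternates the required punctuation after each generated line. Below is a self-contained toy illustration of that index arithmetic and punctuation test; the candidate strings, punctuation sets, and the 5-character line length are assumptions for illustration, not values from the example.

punc_mid = ','            # assumed mid-sentence marks
punc_end = '。?!'          # assumed sentence-final marks
len_sent = 5              # assumed characters per poem line

candidates = ['春眠不觉晓,处处闻啼鸟。', '春眠不觉晓。处处闻啼鸟,']
for ii in (1, 2):                          # sentence index, as in the loop above
    pos = ii * (len_sent + 1) - 1          # position of the punctuation after sentence ii
    wanted = punc_end if ii % 2 == 0 else punc_mid
    kept = [c[:ii * (len_sent + 1)] for c in candidates
            if len(c) > pos and c[pos] in wanted]
    print(ii, kept)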
Example #3
def generating_poem(app,
                    prefix,
                    model,
                    config,
                    tokenizer,
                    device,
                    config_predict,
                    quick=False,
                    num=5,
                    continue_writing=False,
                    removeHighFreqWords=False,
                    batchGenerating=False,
                    gpu='0',
                    onlyMax=False,
                    maxNb=20):
    if len(prefix) == 0 or len(prefix) > model.config.n_ctx:
        return []
    if sum([_is_chinese_char(c) for c in prefix]) < len(prefix) * 0.75:
        return []
    if gpu:
        torch.cuda.set_device(int(gpu))
        device = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        device = 'cpu'
    punc = '.,?!;\t 。,?!;'
    global a
    a = app
    fast_pattern = config_predict.fast_pattern
    n_ctx = model.config.n_ctx
    length = config['length']
    nsamples = num
    #maxNb = max(nsamples,maxNb)
    maxNb = nsamples  # the maxNb argument is effectively overridden: only nsamples candidates are generated
    temperature = config['temperature']
    topk = config['topk']
    topp = config['topp']
    quick_pattern = quick
    repetition_penalty = config['repetition_penalty']
    if length == -1:
        length = model.config.n_ctx
    #raw_text = prefix[0] + prefix
    raw_text = '[MASK]' + prefix
    context_tokens = tokenizer.convert_tokens_to_ids(
        tokenizer.tokenize(raw_text))
    t0 = time.time()
    outs = []
    if batchGenerating:
        S = []
        if onlyMax:
            outs = sample_sequence_batch_max(
                model,
                context_tokens,
                length,
                n_ctx,
                tokenizer,
                nsamples=2,
                temperature=temperature,
                top_k=topk,
                top_p=topp,
                repitition_penalty=repetition_penalty,
                device=device)
        else:
            if fast_pattern:
                outs = fast_sample_sequence_batch(
                    model,
                    context_tokens,
                    length,
                    nsamples=maxNb,
                    temperature=temperature,
                    top_k=topk,
                    repitition_penalty=repetition_penalty,
                    device=device)
            else:
                outs = sample_sequence_batch_opti(
                    model,
                    context_tokens,
                    length,
                    n_ctx,
                    tokenizer,
                    maxNb,
                    temperature=temperature,
                    top_k=topk,
                    top_p=topp,
                    repitition_penalty=repetition_penalty,
                    device=device)
    else:
        S = []
        for _ in range(maxNb):
            out = generate(n_ctx=n_ctx,
                           model=model,
                           context=context_tokens,
                           length=length,
                           is_fast_pattern=fast_pattern,
                           tokenizer=tokenizer,
                           is_quick=quick_pattern,
                           temperature=temperature,
                           top_k=topk,
                           top_p=topp,
                           repitition_penalty=repetition_penalty,
                           device=device)
            tmptext = untokenization(out, config, tokenizer, punc,
                                     continue_writing)
            S.append(tmptext)
    if batchGenerating:
        # Only the batch branches fill `outs`; rebuilding S here unconditionally
        # would discard the results collected in the non-batch loop above.
        S = []
        for out in outs:
            tmptext = untokenization_poem(out, tokenizer, config)
            poem = poemFilter1(tmptext, prefix, config_predict.blackwords)
            if poem:
                S.append(poem)
    S = dropDuplicateContent(S)
    S = S[:nsamples]
    return S
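The filtering and deduplication helpers are referenced but never defined in these excerpts. The sketch below is a plausible, minimal reading of the three-argument form used in Examples #2 and #3 (Example #1 calls a one-argument variant); the signatures and behaviour are assumptions about their apparent role, and the real project logic may differ.

def poemFilter1(text, prefix, blackwords):
    # Reject candidates containing blacklisted words; otherwise return the text.
    # How the original uses `prefix` is not visible in the excerpts, so it is unused here.
    if any(w in text for w in blackwords):
        return None
    return text

def dropDuplicateContent(poems):
    # Keep the first occurrence of each distinct poem, preserving order.
    seen, result = set(), []
    for p in poems:
        if p not in seen:
            seen.add(p)
            result.append(p)
    return result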