def sample(model,
           enc,
           length,
           context,
           temperature=1.0,
           device='cpu',
           topk=-1):
    """Generate ``length`` tokens by ancestral sampling and report statistics.

    At every step the last vocabulary entry (commented as end-of-text) and
    token 628 (commented as "2 newlines") are masked out. The logits are
    sorted, optionally truncated to the ``topk`` best, divided by
    ``temperature``, and one token is drawn with ``torch.multinomial``.

    Args:
        model: language model called as ``model(tokens, past=past)`` and
            returning ``(logits, past)``.
        enc: tokenizer; not used by this function.
        length: number of tokens to generate (must be > 0).
        context: list of token ids; only the last 1022 are fed to the model.
        temperature: divisor applied to the (possibly truncated) logits.
        device: torch device on which the context tensor is created.
        topk: when > 0, sample only among the ``topk`` most likely tokens;
            otherwise sample from the full distribution.

    Returns:
        ``(tokens, avg_NLL, avg_KL, avg_Hq)``:
        - ``tokens``: the generated token ids.
        - ``avg_NLL``: the mean negative log-probability of each chosen
          token under the untempered, untruncated model.
        - ``avg_KL``: the mean of ``kl(...)`` between the sampling
          distribution and the model distribution.
        - ``avg_Hq``: the mean of ``entropy(...)`` of the sampling
          distribution.

    Raises:
        RuntimeError: if ``past[0].shape[3]`` reaches 1023 (presumably the
            cached sequence length -- confirm against ``limit_past``).
    """
    assert length > 0

    context = torch.tensor(context[-1022:], device=device, dtype=torch.long)

    prev = context
    output = context
    past = None

    total_log_probs = 0
    total_entropy_ptau = 0
    total_num = 0
    total_kl = 0  # in bits

    with torch.no_grad():
        while total_num < length:
            if past and past[0].shape[3] >= 1023:
                raise RuntimeError

            logits, past = model(prev.unsqueeze(0), past=past)
            past = limit_past(past)
            logits[0, -1, -1] = -1e10  # endoftext can't happen
            logits[0, -1, 628] = -1e10  # 2 newlines can't happen
            logits, indices = logits[0, -1, :].sort(descending=True)
            # Model distribution before truncation/temperature (reference for NLL and KL).
            base_log_probs = F.log_softmax(logits, dim=-1)

            if topk > 0:
                logits = logits[:topk]

            logits = logits / temperature
            log_probs = F.log_softmax(logits, dim=-1)
            probs = torch.exp(log_probs)

            # Align the reference with the (possibly truncated) sampling
            # distribution. Slicing with ``[:topk]`` is wrong when topk <= 0:
            # ``[:-1]`` drops one entry and ``[:0]`` is empty.
            total_kl += kl(probs, log_probs, base_log_probs[:len(log_probs)])

            selection = torch.multinomial(probs, num_samples=1).item()
            total_log_probs += base_log_probs[selection].item()

            total_entropy_ptau += entropy(probs, log_probs)

            # ``selection`` is a rank in sorted order; map it back to a token id.
            prev = indices[selection].view(1)
            output = torch.cat((output, prev))
            total_num += 1

    avg_NLL = -total_log_probs / total_num
    avg_KL = total_kl / total_num
    avg_Hq = total_entropy_ptau / total_num

    return output[len(context):].tolist(), avg_NLL, avg_KL, avg_Hq
# Example #2
def encode_arithmetic(model, enc, message, context, finish_sent=False, device='cuda', temp=1.0, precision=16, topk=50000):
    """Encode the bit list ``message`` into generated tokens with arithmetic coding.

    How it works:
    - At each step the temperature-scaled next-token distribution is
      quantized into integer sub-intervals of the current coding interval
      ``[cur_interval[0], cur_interval[1])``.
    - The next ``precision`` message bits (zero-padded past the end) select a
      sub-interval, and hence a token.
    - Leading bits on which the new interval's bounds agree are consumed, and
      the interval is renormalized.

    Args:
        model: called as ``model(tokens, past=past)`` returning ``(logits, past)``.
        enc: tokenizer; ``enc.decode`` is used to stop on ``'<eos>'``, and it
            is passed to ``is_sent_finish``.
        message: list of bits (0/1) to hide.
        context: list of token ids; only the last 1022 are used.
        finish_sent: once the message is consumed, keep emitting the most
            likely token until ``is_sent_finish`` reports a sentence end.
        device: torch device for the context tensor.
        temp: temperature applied to the logits before quantization.
        precision: number of bits in the coding interval.
        topk: upper bound on how many candidate tokens are kept per step.

    Returns:
        ``(tokens, avg_NLL, avg_KL, words_per_bit, avg_Hq)``. The averages are
        over steps that encoded message bits and are ``nan`` if there were
        none (e.g. an empty message). ``words_per_bit`` (encoding steps per
        consumed bit) is ``nan`` if no message bit was consumed.
    """
    context = torch.tensor(context[-1022:], device=device, dtype=torch.long)

    max_val = 2**precision
    cur_interval = [0, max_val] # bottom inclusive, top exclusive

    prev = context
    output = context
    past = None

    total_num_for_stats = 0
    total_log_probs = 0
    total_kl = 0 # in bits
    total_entropy_ptau = 0

    with torch.no_grad():
        i = 0
        sent_finish = False
        while i < len(message) or (finish_sent and not sent_finish):
            logits, past = model(prev.unsqueeze(0), past=past)
            past = limit_past(past)
            logits[0, -1, -1] = -1e20 # endoftext token can't happen
            logits[0, -1, 628] = -1e20 # 2 newlines token can't happen
            logits, indices = logits[0, -1, :].sort(descending=True)
            logits = logits.double()
            logits_temp = logits / temp
            probs_temp = F.softmax(logits_temp, dim=0)
            log_probs_temp = F.log_softmax(logits_temp, dim=0)
            log_probs = F.log_softmax(logits, dim=0)

            # conditions for having reached the end of the message
            if i >= len(message):
                selection = 0
                sent_finish = is_sent_finish(indices[selection].item(), enc)
            else:
                # Cutoff low probabilities that would be rounded to 0. If none
                # is below the threshold keep them all rather than failing on
                # an empty ``nonzero()`` result.
                cur_int_range = cur_interval[1]-cur_interval[0]
                cur_threshold = 1/cur_int_range
                below = (probs_temp < cur_threshold).nonzero()
                cutoff = below[0].item() if len(below) > 0 else len(probs_temp)
                k = min(max(2, cutoff), topk)
                probs_temp_int = probs_temp[:k] # Cutoff all but top k

                # Rescale to correct range
                probs_temp_int = probs_temp_int/probs_temp_int.sum()*cur_int_range

                # Round probabilities to integers given precision
                probs_temp_int = probs_temp_int.round().long()
                cum_probs = probs_temp_int.cumsum(0)

                # Remove any elements from the bottom if rounding caused the total prob to be too large
                overfill_index = (cum_probs > cur_int_range).nonzero()
                if len(overfill_index) > 0:
                    cum_probs = cum_probs[:overfill_index[0].item()]

                # Shift every cumulative bound so the last one equals the range;
                # the deficit ends up in the first (most likely) token's width.
                cum_probs += cur_int_range-cum_probs[-1] # add

                # Get out resulting probabilities
                probs_final = cum_probs.clone()
                probs_final[1:] = cum_probs[1:] - cum_probs[:-1]

                # Convert to position in range
                cum_probs += cur_interval[0]

                # Get selected index based on binary fraction from message bits
                message_bits = message[i:i+precision]
                if i+precision > len(message):
                    message_bits = message_bits + [0]*(i+precision-len(message))
                message_idx = bits2int(reversed(message_bits))
                selection = (cum_probs > message_idx).nonzero()[0].item()

                # Calculate new range as ints
                new_int_bottom = cum_probs[selection-1] if selection > 0 else cur_interval[0]
                new_int_top = cum_probs[selection]

                # Convert range to bits
                new_int_bottom_bits_inc = list(reversed(int2bits(new_int_bottom, precision)))
                new_int_top_bits_inc = list(reversed(int2bits(new_int_top-1, precision))) # -1 here because upper bound is exclusive

                # Consume most significant bits which are now fixed and update interval
                num_bits_encoded = num_same_from_beg(new_int_bottom_bits_inc, new_int_top_bits_inc)
                i += num_bits_encoded

                new_int_bottom_bits = new_int_bottom_bits_inc[num_bits_encoded:] + [0]*num_bits_encoded
                new_int_top_bits = new_int_top_bits_inc[num_bits_encoded:] + [1]*num_bits_encoded

                cur_interval[0] = bits2int(reversed(new_int_bottom_bits))
                cur_interval[1] = bits2int(reversed(new_int_top_bits))+1 # +1 here because upper bound is exclusive

                # Gather statistics
                total_log_probs += log_probs[selection].item()

                q = probs_final.double()/probs_final.sum()
                logq = q.log()
                total_kl += kl(q, logq, log_probs[:len(q)])
                total_entropy_ptau += entropy(probs_temp, log_probs_temp)
                total_num_for_stats += 1

            # Update history with new token
            prev = indices[selection].view(1)
            output = torch.cat((output, prev))

            # For text->bits->text
            partial = enc.decode(output[len(context):].tolist())
            if '<eos>' in partial:
                break

    # Guard the averages: with an empty message (or an early '<eos>' break)
    # no step may have encoded bits, which used to raise ZeroDivisionError.
    nan = float('nan')
    if total_num_for_stats > 0:
        avg_NLL = -total_log_probs/total_num_for_stats
        avg_KL = total_kl/total_num_for_stats
        avg_Hq = total_entropy_ptau/total_num_for_stats
    else:
        avg_NLL = avg_KL = avg_Hq = nan
    words_per_bit = total_num_for_stats/i if i > 0 else nan

    return output[len(context):].tolist(), avg_NLL, avg_KL, words_per_bit, avg_Hq
# Example #3
def decode_arithmetic(model, enc, text, context, device='cuda', temp=1.0, precision=16, topk=50000):
    """Recover the bit list hidden in ``text`` by arithmetic-coding steganography.

    ``text`` is re-tokenized and replayed through the model. For each token
    the quantized coding interval is rebuilt the same way the encoder builds
    it, and the bits fixed by the token's sub-interval are emitted.

    Re-tokenizing can split or merge tokens differently from generation. When
    a token's rank falls outside the kept candidates, it is repaired
    heuristically: a candidate that is a prefix of it, or a longer candidate
    equal to it plus the following tokens, replaces it. If nothing fits,
    rank 0 is assumed and a warning is printed.

    Args:
        model: called as ``model(tokens, past=past)`` returning ``(logits, past)``.
        enc: tokenizer with ``encode`` and a ``decoder`` id->text mapping.
        text: stego text to decode.
        context: list of token ids; only the last 1022 are used.
        device: torch device for created tensors.
        temp, precision, topk: must match the values used when encoding.

    Returns:
        list of decoded bits.
    """
    # inp is a list of token indices
    # context is a list of token indices
    inp = enc.encode(text)
    # common BPE error case: 198, 198 (2 newlines) is interpreted as 628 (2 newlines)
    i = 0
    while i < len(inp):
        if inp[i] == 628:
            inp[i] = 198
            inp[i+1:i+1] = [198]
            i += 2
        else:
            i += 1

    context = torch.tensor(context[-1022:], device=device, dtype=torch.long)

    max_val = 2**precision
    cur_interval = [0, max_val] # bottom inclusive, top exclusive

    prev = context
    past = None
    message = []
    with torch.no_grad():
        i = 0
        while i < len(inp):
            logits, past = model(prev.unsqueeze(0), past=past)
            past = limit_past(past)
            logits[0, -1, -1] = -1e10 # endoftext can't happen
            logits[0, -1, 628] = -1e10 # 2 newlines can't happen
            logits, indices = logits[0, -1, :].sort(descending=True)
            logits = logits.double()
            logits_temp = logits / temp
            probs_temp = F.softmax(logits_temp, dim=0)

            # Cutoff low probabilities that would be rounded to 0. If none is
            # below the threshold keep them all rather than failing on an
            # empty ``nonzero()`` result.
            cur_int_range = cur_interval[1]-cur_interval[0]
            cur_threshold = 1/cur_int_range
            below = (probs_temp < cur_threshold).nonzero()
            cutoff = below[0].item() if len(below) > 0 else len(probs_temp)
            k = min(max(2, cutoff), topk)
            probs_temp_int = probs_temp[:k] # Cutoff all but top k

            # Rescale to correct range
            probs_temp_int = probs_temp_int/probs_temp_int.sum()*cur_int_range

            # Round probabilities to integers given precision
            probs_temp_int = probs_temp_int.round().long()
            cum_probs = probs_temp_int.cumsum(0)

            # Remove any elements from the bottom if rounding caused the total prob to be too large
            overfill_index = (cum_probs > cur_int_range).nonzero()
            if len(overfill_index) > 0:
                k = overfill_index[0].item()
                cum_probs = cum_probs[:k]

            # Shift every cumulative bound so the last one equals the range;
            # the deficit ends up in the first (most likely) token's width.
            cum_probs += cur_int_range-cum_probs[-1] # add

            # Convert to position in range
            cum_probs += cur_interval[0]

            rank = (indices == inp[i]).nonzero().item()

            # Handle most errors that could happen because of BPE with heuristic
            if rank >= k:
                true_token_text = enc.decoder[inp[i]]
                for rank_idx in range(k):
                    prop_token_text = enc.decoder[indices[rank_idx].item()]
                    # common case that is not caught
                    if inp[i] == 128 and indices[rank_idx] == 198:
                        rank = rank_idx
                        inp[i] = indices[rank_idx].item()
                        break

                    # Is there a more likely prefix token that could be the actual token generated?
                    if len(prop_token_text) <= len(true_token_text) and \
                            prop_token_text == true_token_text[:len(prop_token_text)]:
                        rank = rank_idx
                        suffix = true_token_text[len(prop_token_text):]
                        suffix_tokens = enc.encode(suffix) # a list
                        inp[i] = indices[rank_idx].item()
                        inp[i+1:i+1] = suffix_tokens # insert suffix tokens into list
                        break

                    # Is there a more likely longer token that could be the actual token generated?
                    elif len(prop_token_text) > len(true_token_text) and \
                              true_token_text == prop_token_text[:len(true_token_text)]:
                        whole_text = true_token_text
                        num_extra = 1
                        # Stop at the end of ``inp`` instead of raising IndexError.
                        while len(whole_text) < len(prop_token_text) and \
                                i+num_extra < len(inp):
                            whole_text += enc.decoder[inp[i+num_extra]]
                            num_extra += 1
                        if prop_token_text == whole_text[:len(prop_token_text)]:
                            rank = rank_idx
                            inp[i] = indices[rank_idx].item()
                            # Drop the num_extra-1 tokens merged into inp[i]. A
                            # single slice delete avoids the index shift that
                            # deleting element by element would cause.
                            del inp[i+1:i+num_extra]

                            if len(whole_text) > len(prop_token_text):
                                suffix = whole_text[len(prop_token_text):]
                                suffix_tokens = enc.encode(suffix) # a list
                                inp[i+1:i+1] = suffix_tokens # insert suffix tokens into list
                            break
                else:
                    print('Unable to fix BPE error: token received: %s=%d, text: %s' % (true_token_text, inp[i], text))
                    rank = 0

            selection = rank

            # Calculate new range as ints
            new_int_bottom = cum_probs[selection-1] if selection > 0 else cur_interval[0]
            new_int_top = cum_probs[selection]

            # Convert range to bits
            new_int_bottom_bits_inc = list(reversed(int2bits(new_int_bottom, precision)))
            new_int_top_bits_inc = list(reversed(int2bits(new_int_top-1, precision))) # -1 here because upper bound is exclusive

            # Emit most significant bits which are now fixed and update interval
            num_bits_encoded = num_same_from_beg(new_int_bottom_bits_inc, new_int_top_bits_inc)
            if i == len(inp)-1:
                # Last token: emit the full bottom bound to flush the interval.
                new_bits = new_int_bottom_bits_inc
            else:
                new_bits = new_int_top_bits_inc[:num_bits_encoded]
            message += new_bits

            new_int_bottom_bits = new_int_bottom_bits_inc[num_bits_encoded:] + [0]*num_bits_encoded
            new_int_top_bits = new_int_top_bits_inc[num_bits_encoded:] + [1]*num_bits_encoded

            cur_interval[0] = bits2int(reversed(new_int_bottom_bits))
            cur_interval[1] = bits2int(reversed(new_int_top_bits))+1 # +1 here because upper bound is exclusive

            # Update history with new token
            prev = torch.tensor([inp[i]], device=device, dtype=torch.long)
            i += 1

    return message
# Example #4
def decode_huffman(model, enc, text, context, bits_per_word, device='cpu'):
    """Recover the bits hidden in ``text`` by Huffman-coding steganography.

    ``text`` is re-tokenized and replayed through the model. At each step a
    Huffman code is built over the ``2**bits_per_word`` most likely tokens,
    and the code of the received token's rank is appended to the message.

    If the token is not among those candidates (re-tokenization can split or
    merge tokens differently from generation), a candidate that is a prefix
    of it, or a longer candidate equal to it plus the following tokens, is
    used instead. If none fits, rank 0 is assumed and a warning is printed.

    Args:
        model: called as ``model(tokens, past=past)`` returning ``(logits, past)``.
        enc: tokenizer with ``encode`` and a ``decoder`` id->text mapping.
        text: stego text to decode.
        context: list of token ids; only the last 1022 are used.
        bits_per_word: must match the value used for encoding.
        device: torch device for created tensors.

    Returns:
        list of decoded bits as ints.

    Raises:
        RuntimeError: if ``past[0].shape[3]`` reaches 1023 (presumably the
            cached sequence length -- confirm against ``limit_past``).
    """
    # inp is a list of token indices
    # context is a list of token indices
    inp = enc.encode(text)
    # Split each token 628 ("2 newlines") into two 198 tokens; the model is
    # never allowed to emit 628 below.
    i = 0
    while i < len(inp):
        if inp[i] == 628:
            inp[i] = 198
            inp[i + 1:i + 1] = [198]
            i += 2
        else:
            i += 1

    context = torch.tensor(context[-1022:], device=device, dtype=torch.long)
    prev = context
    past = None

    message = []
    with torch.no_grad():
        i = 0
        while i < len(inp):
            if past and past[0].shape[3] >= 1023:
                raise RuntimeError

            logits, past = model(prev.unsqueeze(0), past=past)
            past = limit_past(past)
            logits[0, -1, -1] = -1e10  # endoftext can't happen
            logits[0, -1, 628] = -1e10  # 2 newlines can't happen
            logits, indices = logits[0, -1, :].sort(descending=True)

            # Get the top 2**bits options
            indices = indices[:2**bits_per_word]
            log_probs = F.log_softmax(logits, dim=-1)[:2**bits_per_word]
            probs = torch.exp(log_probs)

            if inp[i] not in indices:
                true_token_text = enc.decoder[inp[i]]
                for rank_idx in range(2**bits_per_word):
                    prop_token_text = enc.decoder[indices[rank_idx].item()]
                    # common case that is not caught
                    if inp[i] == 128 and indices[rank_idx] == 198:
                        rank = rank_idx
                        inp[i] = indices[rank_idx].item()
                        break

                    # Is there a more likely prefix token that could be the actual token generated?
                    if len(prop_token_text) <= len(true_token_text) and \
                            prop_token_text == true_token_text[:len(prop_token_text)]:
                        rank = rank_idx
                        suffix = true_token_text[len(prop_token_text):]
                        suffix_tokens = enc.encode(suffix)  # a list
                        inp[i] = indices[rank_idx].item()
                        inp[i + 1:i + 1] = suffix_tokens  # insert suffix tokens into list
                        break

                    # Is there a more likely longer token that could be the actual token generated?
                    elif len(prop_token_text) > len(true_token_text) and \
                            true_token_text == prop_token_text[:len(true_token_text)]:
                        whole_text = true_token_text
                        num_extra = 1
                        # Stop at the end of ``inp`` instead of raising IndexError.
                        while len(whole_text) < len(prop_token_text) and \
                                i + num_extra < len(inp):
                            whole_text += enc.decoder[inp[i + num_extra]]
                            num_extra += 1
                        if prop_token_text == whole_text[:len(prop_token_text)]:
                            rank = rank_idx
                            inp[i] = indices[rank_idx].item()
                            # Drop the num_extra-1 tokens merged into inp[i]. A
                            # single slice delete avoids the index shift that
                            # deleting element by element would cause.
                            del inp[i + 1:i + num_extra]

                            if len(whole_text) > len(prop_token_text):
                                suffix = whole_text[len(prop_token_text):]
                                suffix_tokens = enc.encode(suffix)  # a list
                                inp[i + 1:i + 1] = suffix_tokens  # insert suffix tokens into list
                            break
                else:
                    print(
                        'Unable to fix BPE error: token received: %s=%d, text: %s'
                        % (true_token_text, inp[i], text))
                    rank = 0
            else:
                rank = (indices == inp[i]).nonzero().item()

            # Rebuild the Huffman code over the candidate probabilities;
            # ``coding.codes`` is indexed by candidate rank.
            probs_array = probs.cpu().numpy()
            coding = HuffmanCoding()
            coding.make_heap_from_array(probs_array)
            coding.merge_nodes()
            coding.make_codes()

            message.extend(map(int, coding.codes[rank]))
            prev = torch.tensor([inp[i]], device=device, dtype=torch.long)
            i += 1

    return message
# Example #5
def encode_huffman(model,
                   enc,
                   message,
                   context,
                   bits_per_word,
                   finish_sent=False,
                   device='cpu'):
    """Encode ``message`` into generated tokens with a per-step Huffman code.

    At each step a Huffman tree is built over the ``2**bits_per_word`` most
    likely tokens. Message bits are read one per tree edge (0 -> left,
    1 -> right, and bits past the end count as 0) until a leaf token is
    reached.

    Args:
        model: called as ``model(tokens, past=past)`` returning ``(logits, past)``.
        enc: tokenizer passed to ``is_sent_finish``.
        message: list of bits (0/1) to hide.
        context: list of token ids; only the last 1022 are used.
        bits_per_word: log2 of the number of candidate tokens per step.
        finish_sent: once the message is consumed, keep emitting the most
            likely token until ``is_sent_finish`` reports a sentence end.
        device: torch device for created tensors.

    Returns:
        ``(tokens, avg_NLL, avg_KL, words_per_bit)``. The averages are ``nan``
        if no step encoded message bits (e.g. an empty message), and
        ``words_per_bit`` is ``nan`` if no message bit was consumed.
    """
    length = len(message)
    # Exact ln 2 (math.log(2)); a truncated constant would make
    # q = exp(-len(code) * ln2) deviate from the code's 2**-len(code).
    ln2 = 0.6931471805599453

    context = torch.tensor(context[-1022:], device=device, dtype=torch.long)

    prev = context
    output = context
    past = None

    total_num_for_stats = 0
    total_log_probs = 0
    total_kl = 0  # in bits

    with torch.no_grad():
        i = 0
        sent_finish = False
        while i < length or (finish_sent and not sent_finish):
            logits, past = model(prev.unsqueeze(0), past=past)
            past = limit_past(past)
            logits[0, -1, -1] = -1e10  # endoftext can't happen
            logits[0, -1, 628] = -1e10  # 2 newlines can't happen
            logits, indices = logits[0, -1, :].sort(descending=True)

            # Get the top 2**bits options
            indices = indices[:2**bits_per_word]
            log_probs = F.log_softmax(logits, dim=-1)[:2**bits_per_word]
            probs = torch.exp(log_probs)

            if i >= length:
                selection = 0
                sent_finish = is_sent_finish(indices[0].item(), enc)
            else:
                probs_array = probs.cpu().numpy()
                coding = HuffmanCoding()
                coding.make_heap_from_array(probs_array)
                coding.merge_nodes()
                root = coding.make_codes()

                # Walk from the root to a leaf, consuming one bit per edge.
                while root.token is None:
                    if i >= length or message[i] == 0:
                        root = root.left
                    else:
                        root = root.right
                    i += 1
                selection = root.token

                # Coding distribution implied by the code lengths: 2**-len(code).
                logq = torch.tensor(
                    [-len(coding.codes[idx]) for idx in range(len(probs_array))],
                    dtype=torch.float,
                    device=device)  # in bits
                logq = logq * ln2  # in nats
                q = torch.exp(logq)
                total_kl += kl(q, logq, log_probs)
                total_log_probs += log_probs[selection].item()
                total_num_for_stats += 1

            prev = indices[selection].view(1)
            output = torch.cat((output, prev))

    # Guard the averages: an empty message used to raise ZeroDivisionError.
    nan = float('nan')
    if total_num_for_stats > 0:
        avg_NLL = -total_log_probs / total_num_for_stats
        avg_KL = total_kl / total_num_for_stats
    else:
        avg_NLL = avg_KL = nan
    words_per_bit = total_num_for_stats / i if i > 0 else nan

    return output[len(context):].tolist(), avg_NLL, avg_KL, words_per_bit
def encode_block(model,
                 enc,
                 message,
                 context,
                 block_size,
                 bin2words,
                 words2bin,
                 finish_sent=False,
                 device='cpu'):
    """Encode ``message`` with the bin (block) scheme.

    Each step takes the next ``block_size`` message bits, converts them with
    ``bits2int`` into a bin number, and emits the most likely token among
    ``bin2words[bin]``.

    Args:
        model: called as ``model(tokens, past=past)`` returning ``(logits, past)``.
        enc: tokenizer passed to ``is_sent_finish``.
        message: list of bits (0/1) to hide.
        context: list of token ids; only the last 1022 are used.
        block_size: number of message bits consumed per token.
        bin2words: maps a bin number to the token ids in that bin.
        words2bin: unused here; kept for signature parity with ``decode_block``.
        finish_sent: once the message is consumed, keep emitting the most
            likely token until ``is_sent_finish`` reports a sentence end.
        device: torch device for the context tensor.

    Returns:
        ``(tokens, avg_NLL, avg_KL, words_per_bit)``. The averages are ``nan``
        if no step encoded message bits (e.g. an empty message), and
        ``words_per_bit`` is ``nan`` if no message bit was consumed.
    """
    length = len(message)
    ln2 = 0.6931471805599453  # math.log(2): bits -> nats

    context = torch.tensor(context[-1022:], device=device, dtype=torch.long)

    prev = context
    output = context
    past = None

    total_num_for_stats = 0
    total_log_probs = 0
    total_kl = 0  # in bits

    with torch.no_grad():
        i = 0
        sent_finish = False
        while i < length or (finish_sent and not sent_finish):
            logits, past = model(prev.unsqueeze(0), past=past)
            past = limit_past(past)
            logits[0, -1, -1] = -1e10  # endoftext can't happen
            logits[0, -1, 628] = -1e10  # 2 newlines can't happen
            logits = logits[0, -1, :]
            log_probs = F.log_softmax(logits, dim=-1)

            if i >= length:
                _, indices = logits.sort(descending=True)
                sent_finish = is_sent_finish(indices[0].item(), enc)
            else:
                # Coding distribution q: the best token of every bin gets
                # 2**-block_size; every other token gets a log value of
                # -1e10, i.e. probability ~0. Only the argmax is needed here,
                # so no full sort per bin.
                logq = torch.full_like(logits, -1e10)
                for bin_val in range(2**block_size):
                    filtered_logits = torch.full_like(logits, -1e10)
                    available_tokens = bin2words[bin_val]
                    filtered_logits[available_tokens] = logits[available_tokens]
                    logq[filtered_logits.argmax()] = -block_size  # in bits

                logq = logq * ln2  # in nats
                q = torch.exp(logq)

                # Then find the actual word for the right bin
                m_part = message[i:i + block_size]

                # Keep the sort here so the selection matches decode_block's ranking.
                filtered_logits = torch.full_like(logits, -1e10)
                available_tokens = bin2words[bits2int(m_part)]
                filtered_logits[available_tokens] = logits[available_tokens]
                _, indices = filtered_logits.sort(descending=True)

                total_kl += kl(q, logq, log_probs)
                total_log_probs += log_probs[indices[0]].item()
                i += block_size
                total_num_for_stats += 1

            prev = indices[0].view(1)
            output = torch.cat((output, prev))

    # Guard the averages: an empty message used to raise ZeroDivisionError.
    nan = float('nan')
    if total_num_for_stats > 0:
        avg_NLL = -total_log_probs / total_num_for_stats
        avg_KL = total_kl / total_num_for_stats
    else:
        avg_NLL = avg_KL = nan
    words_per_bit = total_num_for_stats / i if i > 0 else nan

    return output[len(context):].tolist(), avg_NLL, avg_KL, words_per_bit
def decode_block(model,
                 enc,
                 text,
                 context,
                 block_size,
                 bin2words,
                 words2bin,
                 device='cpu'):
    """Recover the bits hidden in ``text`` by ``encode_block``.

    ``text`` is re-tokenized and replayed through the model. Each token
    contributes ``int2bits(bin, block_size)``, where ``bin`` is
    ``words2bin[token]``.

    ``encode_block`` always emits the top-ranked token of a bin. If the
    received token is not the top of its bin (rank > 0), re-tokenization
    probably split or merged tokens differently. In that case every bin's
    top token is tried as a prefix of the token, or as a longer token equal
    to it plus the following tokens. If none fits, the received token's own
    bin is used and a warning is printed.

    Args:
        model: called as ``model(tokens, past=past)`` returning ``(logits, past)``.
        enc: tokenizer with ``encode`` and a ``decoder`` id->text mapping.
        text: stego text to decode.
        context: list of token ids; only the last 1022 are used.
        block_size: number of bits per token; must match the encoder.
        bin2words: maps a bin number to the token ids in that bin.
        words2bin: maps a token id to its bin number.
        device: torch device for created tensors.

    Returns:
        list of decoded bits.

    Raises:
        RuntimeError: if ``past[0].shape[3]`` reaches 1023 (presumably the
            cached sequence length -- confirm against ``limit_past``).
    """
    # inp is a list of token indices
    # context is a list of token indices
    inp = enc.encode(text)
    # Split each token 628 ("2 newlines") into two 198 tokens; the model is
    # never allowed to emit 628 below.
    i = 0
    while i < len(inp):
        if inp[i] == 628:
            inp[i] = 198
            inp[i + 1:i + 1] = [198]
            i += 2
        else:
            i += 1

    context = torch.tensor(context[-1022:], device=device, dtype=torch.long)
    prev = context
    past = None

    message = []
    with torch.no_grad():
        i = 0
        while i < len(inp):
            if past and past[0].shape[3] >= 1023:
                raise RuntimeError
            bin_num = words2bin[inp[i]]

            logits, past = model(prev.unsqueeze(0), past=past)
            past = limit_past(past)
            logits[0, -1, -1] = -1e10  # endoftext can't happen
            logits[0, -1, 628] = -1e10  # 2 newlines can't happen

            logits = logits[0, -1, :]
            filtered_logits = torch.full_like(logits, -1e10)
            available_tokens = bin2words[bin_num]
            filtered_logits[available_tokens] = logits[available_tokens]
            filtered_logits, indices = filtered_logits.sort(descending=True)

            rank = (indices == inp[i]).nonzero().item()

            # Handle errors that could happen because of BPE
            if rank > 0:
                true_token_text = enc.decoder[inp[i]]
                # Use a separate loop variable so a failed repair keeps the
                # received token's own bin instead of the last bin tried.
                for cand_bin in range(len(bin2words)):
                    filtered_logits = torch.full_like(logits, -1e10)
                    available_tokens = bin2words[cand_bin]
                    filtered_logits[available_tokens] = logits[available_tokens]
                    filtered_logits, indices = filtered_logits.sort(descending=True)

                    prop_token_text = enc.decoder[indices[0].item()]

                    # Is there a more likely prefix token that could be the actual token generated?
                    if len(prop_token_text) < len(true_token_text) and \
                            prop_token_text == true_token_text[:len(prop_token_text)]:
                        suffix = true_token_text[len(prop_token_text):]
                        suffix_tokens = enc.encode(suffix)  # a list
                        inp[i] = indices[0].item()
                        inp[i + 1:i + 1] = suffix_tokens  # insert suffix tokens into list
                        bin_num = cand_bin
                        break

                    # Is there a more likely longer token that could be the actual token generated?
                    elif len(prop_token_text) > len(true_token_text) and \
                            true_token_text == prop_token_text[:len(true_token_text)]:
                        whole_text = true_token_text
                        num_extra = 1
                        # Stop at the end of ``inp`` instead of raising IndexError.
                        while len(whole_text) < len(prop_token_text) and \
                                i + num_extra < len(inp):
                            whole_text += enc.decoder[inp[i + num_extra]]
                            num_extra += 1
                        if prop_token_text == whole_text[:len(prop_token_text)]:
                            inp[i] = indices[0].item()
                            # Drop the num_extra-1 tokens merged into inp[i]. A
                            # single slice delete avoids the index shift that
                            # deleting element by element would cause.
                            del inp[i + 1:i + num_extra]

                            if len(whole_text) > len(prop_token_text):
                                suffix = whole_text[len(prop_token_text):]
                                suffix_tokens = enc.encode(suffix)  # a list
                                inp[i + 1:i + 1] = suffix_tokens  # insert suffix tokens into list
                            bin_num = cand_bin
                            break
                else:
                    print(
                        'Unable to fix BPE error: token received: %s=%d, text: %s'
                        % (true_token_text, inp[i], text))

            message.extend(int2bits(bin_num, block_size))
            prev = torch.tensor([inp[i]], device=device, dtype=torch.long)
            i += 1

    return message