# Beispiel #1 (Example #1)
def main(args):
    """Generate text from a trained word-level language model checkpoint.

    Loads the corpus dictionary and a Transformer/RNN model from
    ``args.checkpoint``, then samples ``args.words`` tokens by
    temperature-scaled multinomial sampling, writing them to ``args.outf``
    (20 words per line).
    """
    random_seed(args.seed)
    if torch.cuda.is_available() and not args.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda")
    device = torch.device("cuda" if args.cuda else "cpu")

    corpus = data.Corpus(args.data)
    ntokens = len(corpus.dictionary)
    print('loaded dictionary')
    if args.model == 'Transformer':
        model = TransformerModel(
            ntokens,
            args.emsize,
            args.nhead,
            args.nhid,
            args.nlayers,
            args.dropout).to(device)
    else:
        model = RNNModel(
            args.model,
            ntokens,
            args.emsize,
            args.nhid,
            args.nlayers,
            args.dropout,
            args.tied).to(device)

    # map_location lets a checkpoint saved on GPU load on a CPU-only host.
    checkpoint = torch.load(args.checkpoint, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    print('loaded model')

    is_transformer_model = hasattr(
        model, 'model_type') and model.model_type == 'Transformer'
    if not is_transformer_model:
        hidden = model.init_hidden(1)
    # Start from a single random token id, shape (1, 1).
    # (Renamed from `input` to avoid shadowing the builtin.)
    model_input = torch.randint(ntokens, (1, 1), dtype=torch.long).to(device)
    with open(args.outf, 'w') as outf, torch.no_grad():  # no tracking history
        for i in range(args.words):
            if is_transformer_model:
                # The Transformer re-consumes the whole generated prefix
                # each step, so the sampled token is appended to the input.
                output = model(model_input, False)
                word_weights = output[-1].squeeze().div(
                    args.temperature).exp().cpu()
                word_idx = torch.multinomial(word_weights, 1)[0]
                word_tensor = torch.Tensor([[word_idx]]).long().to(device)
                model_input = torch.cat([model_input, word_tensor], 0)
            else:
                # The RNN carries context in `hidden`; only the latest
                # token is fed back in.
                output, hidden = model(model_input, hidden)
                word_weights = output.squeeze().div(args.temperature).exp().cpu()
                word_idx = torch.multinomial(word_weights, 1)[0]
                model_input.fill_(word_idx)

            word = corpus.dictionary.idx2word[word_idx]

            outf.write(word + ('\n' if i % 20 == 19 else ' '))

            if i % args.log_interval == 0:
                print('| Generated {}/{} words'.format(i, args.words))
def main(args):
    """Generate stego text embedding a private bit stream via Huffman codes.

    For each utterance: start from a common capitalized corpus token, then at
    each step Huffman-encode the top ``2**args.bit_num`` next-word candidates
    by probability and pick the candidate whose code matches the next bits of
    the secret stream. Generated sentences go to ``generated<k>_bit.txt`` and
    the embedded bits to ``bitfile_<k>_bit.txt`` under ``args.save_path``.

    Raises:
        NotImplementedError: if the loaded model is a Transformer (the
            embedding loop requires an explicit recurrent hidden state).
    """
    random_seed(args.seed)
    if torch.cuda.is_available() and not args.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda")
    device = torch.device("cuda" if args.cuda else "cpu")

    corpus = data.Corpus(args.data)
    ntokens = len(corpus.dictionary)
    word2idx = corpus.dictionary.word2idx
    idx2word = corpus.dictionary.idx2word
    args.vocab_size = len(word2idx)
    print('loaded dictionary')

    if args.model == 'Transformer':
        model = TransformerModel(
            ntokens,
            args.emsize,
            args.nhead,
            args.nhid,
            args.nlayers,
            args.dropout).to(device)
    else:
        model = RNNModel(
            args.model,
            ntokens,
            args.emsize,
            args.nhid,
            args.nlayers,
            args.dropout,
            args.tied).to(device)

    # map_location lets a checkpoint saved on GPU load on a CPU-only host.
    checkpoint = torch.load(args.checkpoint, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    is_transformer_model = hasattr(
        model, 'model_type') and model.model_type == 'Transformer'
    if is_transformer_model:
        # The original code reached this state with an undefined `hidden`
        # and a no-op `assert NotImplementedError`; fail fast instead.
        raise NotImplementedError(
            'steganographic generation is only implemented for RNN models')
    print('loaded model')

    # get as starting words only most common starting word
    # from data corpus (heuristic from baseline): frequent capitalized tokens.
    most_common_first_words_ids = [
        i[0] for i in Counter(corpus.train.tolist()).most_common()
        if idx2word[i[0]][0].isupper()][:200]

    # private message (binary code) to embed
    with open(args.bit_stream_path, 'r') as bit_file_in:
        bit_stream = bit_file_in.readline()
    # NOTE(review): the starting offset is drawn from the vocabulary size,
    # not len(bit_stream) -- looks suspicious, kept for compatibility.
    bit_index = random.randint(0, len(word2idx))
    soft = torch.nn.Softmax(0)
    num_candidates = 2 ** int(args.bit_num)

    with open(args.save_path + 'generated' + str(args.bit_num) + '_bit.txt',
              'w') as outfile, \
            open(args.save_path + 'bitfile_' + str(args.bit_num) + '_bit.txt',
                 'w') as bitfile:
        for _ in tqdm.tqdm(range(args.utterances_to_generate)):
            input_ = torch.LongTensor([random.choice(
                most_common_first_words_ids)]).unsqueeze(0).to(device)
            hidden = model.init_hidden(1)

            # First word is sampled freely (no bits embedded).
            output, hidden = model(input_, hidden)
            probs = soft(output.reshape(-1)).tolist()
            gen = np.random.choice(len(corpus.dictionary), 1,
                                   p=np.array(probs) / sum(probs))[0]
            gen_res = [idx2word[gen]]
            bit = ""
            # NOTE(review): `input_` is never updated with the generated
            # token inside this loop; context evolves only via `hidden`.
            # Preserved as-is -- confirm against the original author intent.
            for _ in range(args.len_of_generation - 2):
                output, hidden = model(input_, hidden)
                sorted_, indices = torch.sort(
                    output.reshape(-1), descending=True)
                # (word_id, probability) for the top candidates.
                words_prob = [(j, i) for i, j in
                              zip(sorted_[:num_candidates].tolist(),
                                  indices[:num_candidates].tolist())]

                # Huffman-code the candidate probabilities; shorter codes go
                # to more likely words.
                nodes = createNodes([item[1] for item in words_prob])
                root = createHuffmanTree(nodes)
                codes = huffmanEncoding(nodes, root)

                # Pick the candidate whose code is the next prefix of the
                # bit stream; try progressively longer prefixes.
                for i in range(num_candidates):
                    prefix = bit_stream[bit_index:bit_index + i + 1]
                    if prefix in codes:
                        gen = words_prob[codes.index(prefix)][0]
                        gen_res.append(idx2word[gen])
                        # NOTE(review): an end token only stops the prefix
                        # search, not the whole utterance (original behavior).
                        if idx2word[gen] in ['\n', '', "<eos>"]:
                            break
                        bit += prefix
                        bit_index = bit_index + i + 1
                        break

            gen_sen = ' '.join(
                [word for word in gen_res if word not in ["\n", "", "<eos>"]])
            outfile.write(gen_sen + "\n")
            bitfile.write(bit)