Example #1
def translate_rnn(line: str, line_number: int,
                  s2s: Union[S2S_basic.S2S, S2S_attention.S2S],
                  src_vocab: Vocab, tgt_vocab: Vocab,
                  lang_vec: dict, device: torch.device):

    line = " ".join([src_vocab.start_token, line, src_vocab.end_token])

    line = line.split()

    lang_token = line[1]

    # inputs: (input_length,)
    inputs = torch.tensor([src_vocab.get_index(token) for token in line],
                          device=device)

    # inputs: (input_length, 1)
    inputs = inputs.view(-1, 1)

    if lang_token.startswith("<") and lang_token.endswith(">"):
        # add language vector
        # input_embedding: (input_length, 1, embedding_size)
        # lang_encoding: (embedding_size, )
        lang_encoding = torch.tensor(lang_vec[lang_token], device=device)
        input_embedding = s2s.encoder.embedding(inputs) + lang_encoding
    else:
        input_embedding = s2s.encoder.embedding(inputs)
        print("line {} does not add language embedding".format(line_number))

    encoder_output, encoder_hidden_state = s2s.encoder.rnn(input_embedding)

    decoder_hidden_state = combine_bidir_hidden_state(s2s,
                                                      encoder_hidden_state)

    decoder_input = torch.tensor(
        [[tgt_vocab.get_index(tgt_vocab.start_token)]], device=device)

    max_length = (inputs.size(0) - 2) * 3

    pred_line = []

    for i in range(max_length):

        # decoder_output: (1, 1, vocab_size)
        # decoder_hidden_state: (num_layers * num_directions, batch_size, hidden_size)
        decoder_output, decoder_hidden_state = decode_batch(
            s2s, decoder_input, decoder_hidden_state, encoder_output)

        # pred: (1, 1)
        pred = torch.argmax(decoder_output, dim=2)

        if tgt_vocab.get_token(pred[0, 0].item()) == tgt_vocab.end_token:
            break

        decoder_input = pred

        pred_line.append(tgt_vocab.get_token(pred[0, 0].item()))

    return pred_line
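
A minimal sketch (toy sizes, no real model) of how the language vector of shape (embedding_size,) broadcasts over the (input_length, 1, embedding_size) embedding tensor in the branch above; "<eng>" is just an assumed language token:

import torch

embedding_size = 4
input_length = 3

# stand-in for s2s.encoder.embedding(inputs)
input_embedding = torch.zeros(input_length, 1, embedding_size)
# stand-in for lang_vec["<eng>"]
lang_encoding = torch.arange(embedding_size, dtype=torch.float)

# broadcasting adds the same language vector at every position
out = input_embedding + lang_encoding
print(out.shape)  # torch.Size([3, 1, 4])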
Example #2
def greedy_decoding_transformer(s2s: transformer.S2S,
                                data_tensor: torch.Tensor,
                                tgt_vocab: Vocab,
                                device: torch.device,
                                tgt_prefix: Optional[List[str]] = None):

    # src: (batch_size, input_length)
    src = data_tensor
    src_mask = s2s.make_src_mask(src)

    batch_size = src.size(0)

    encoder_src = s2s.encoder(src, src_mask)

    tgt = torch.tensor([[tgt_vocab.get_index(tgt_vocab.start_token)]],
                       device=device)
    tgt = tgt.expand(batch_size, -1)

    # pred_list: List[List] (tgt_length, batch_size)
    pred_list = []

    if tgt_prefix is not None:
        # tgt_prefix_tensor: (batch_size, )
        tgt_prefix = [
            tgt_vocab.get_index(prefix_token) for prefix_token in tgt_prefix
        ]
        tgt_prefix_tensor = torch.tensor(tgt_prefix, device=device)
        # tgt_prefix_tensor: (batch_size, 1)
        tgt_prefix_tensor = tgt_prefix_tensor.unsqueeze(1)
        # tgt: (batch_size, 2)
        tgt = torch.cat([tgt, tgt_prefix_tensor], dim=1)
        pred_list.append(tgt_prefix)

    max_length = src.size(1) * 3

    end_token_index = tgt_vocab.get_index(tgt_vocab.end_token)

    for i in range(0 if tgt_prefix is None else 1, max_length):

        # tgt: (batch_size, i + 1)
        tgt_mask = s2s.make_tgt_mask(tgt)

        # output: (batch_size, tgt_length, vocab_size)
        output = s2s.decoder(tgt, encoder_src, tgt_mask, src_mask)
        # output: (batch_size, vocab_size)
        output = output[:, -1, :]

        # pred: (batch_size, )
        pred = torch.argmax(output, dim=-1)

        if torch.all(pred == end_token_index).item():
            break

        tgt = torch.cat([tgt, pred.unsqueeze(1)], dim=1)

        pred_list.append(pred.tolist())

    return convert_index_to_token(pred_list, tgt_vocab, batch_size,
                                  end_token_index)
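
The loop above grows tgt by one column per step; a toy illustration (random logits instead of a real decoder, index 1 assumed to be the start token) of the argmax-and-concatenate pattern:

import torch

batch_size, vocab_size = 2, 10
tgt = torch.full((batch_size, 1), 1, dtype=torch.long)  # assumed start-token index

for _ in range(3):
    logits = torch.randn(batch_size, tgt.size(1), vocab_size)  # stand-in for the decoder output
    pred = torch.argmax(logits[:, -1, :], dim=-1)              # (batch_size,)
    tgt = torch.cat([tgt, pred.unsqueeze(1)], dim=1)

print(tgt.shape)  # torch.Size([2, 4])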
Example #3
def train(cfg):
    """
    Run the training loop over the dataset in batches of segments.
    :param cfg: configuration object
    :return: None
    """
    datasets = build_dataset(cfg)
    algo = TFIDFClustring(cfg)
    vocab = Vocab(cfg)
    summary = SummaryTxt(cfg)
    keyword = Keyword(cfg, summary)

    processed_news_num = 0
    batch_size = cfg.SOLVER.BATCH_SIZE

    print('start training:')
    for seg_id in trange(0, datasets.file_num, batch_size):
        seg = []
        for batch_idx in range(batch_size):
            batch, seg_size = datasets.getitem(seg_id + batch_idx)
            seg.extend(batch)
            processed_news_num += seg_size

        algo.run(segments=seg,
                 vocab=vocab,
                 seg_id=seg_id,
                 keyword=keyword,
                 summary=summary)
        # keyword.update_per_seg(new_updated_topic=new_updated_topic)
        print("seg idx: {}. processed news: {}".format(seg_id,
                                                       processed_news_num))
Example #4
def convert_data_to_index(data: List[str], vocab: Vocab):
    data2index = []

    for sentence in data:
        sentence = " ".join([vocab.start_token, sentence, vocab.end_token])
        data2index.append(
            [vocab.get_index(token) for token in sentence.split()])

    return data2index
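
A self-contained illustration of convert_data_to_index, using a small stub in place of Vocab (only start_token, end_token and get_index are needed; the token strings are assumptions):

class _StubVocab:
    start_token = "<s>"
    end_token = "</s>"

    def __init__(self):
        self._index = {"<s>": 0, "</s>": 1}

    def get_index(self, token):
        # assign a fresh index to unseen tokens
        return self._index.setdefault(token, len(self._index))

print(convert_data_to_index(["hello world", "hello again"], _StubVocab()))
# [[0, 2, 3, 1], [0, 2, 4, 1]]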
Example #5
def greedy_decoding_rnn(s2s: Union[S2S_basic.S2S, S2S_attention.S2S],
                        data_tensor: torch.Tensor, tgt_vocab: Vocab,
                        device: torch.device):

    # inputs: (input_length, batch_size)
    inputs = data_tensor

    batch_size = inputs.size(1)

    encoder_output, encoder_hidden_state = s2s.encoder(inputs)

    decoder_hidden_state = combine_bidir_hidden_state(s2s,
                                                      encoder_hidden_state)

    decoder_input = torch.tensor(
        [[tgt_vocab.get_index(tgt_vocab.start_token)]], device=device)
    decoder_input = decoder_input.expand(-1, batch_size)

    max_length = inputs.size(0) * 3

    # pred_list: List[List] (tgt_length, batch_size)
    pred_list = []

    end_token_index = tgt_vocab.get_index(tgt_vocab.end_token)

    for i in range(max_length):

        # decoder_output: (1, batch_size, vocab_size)
        # decoder_hidden_state: (num_layers * num_directions, batch_size, hidden_size)
        decoder_output, decoder_hidden_state = decode_batch(
            s2s, decoder_input, decoder_hidden_state, encoder_output)

        # pred: (1, batch_size)
        pred = torch.argmax(decoder_output, dim=2)

        if torch.all(pred == end_token_index).item():
            break

        decoder_input = pred

        pred_list.append(pred.squeeze(0).tolist())

    return convert_index_to_token(pred_list, tgt_vocab, batch_size,
                                  end_token_index)
Example #6
def load_corpus_data(data_path,
                     language_name,
                     start_token,
                     end_token,
                     mask_token,
                     vocab_path,
                     rebuild_vocab,
                     unk="UNK",
                     threshold=0):
    if rebuild_vocab:
        v = Vocab(language_name,
                  start_token,
                  end_token,
                  mask_token,
                  threshold=threshold)

    corpus = []

    with open(data_path) as f:

        data = f.read().strip().split("\n")

        for line in data:
            line = line.strip()
            line = " ".join([start_token, line, end_token])

            if rebuild_vocab:
                v.add_sentence(line)

            corpus.append(line)

    data2index = []

    if rebuild_vocab:
        v.add_unk(unk)
        v.save(vocab_path)
    else:
        v = Vocab.load(vocab_path)

    for line in corpus:
        data2index.append([v.get_index(token) for token in line.split()])

    return data2index, v
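
A hedged usage sketch of load_corpus_data, assuming a plain-text corpus with one sentence per line; the token strings and the vocab file name are made up for the example:

import tempfile

with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as f:
    f.write("hello world\nhow are you\n")
    corpus_path = f.name

data2index, vocab = load_corpus_data(corpus_path,
                                     language_name="en",
                                     start_token="<s>",
                                     end_token="</s>",
                                     mask_token="<mask>",
                                     vocab_path="en_vocab",
                                     rebuild_vocab=True,
                                     unk="UNK",
                                     threshold=0)
print(len(data2index))  # 2 sentences, each wrapped with the start and end tokens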
Example #7
def convert_index_to_token(pred_list: List[List], tgt_vocab: Vocab,
                           batch_size: int, end_token_index: int):

    # pred_line: List[List] (tgt_length, batch_size)
    pred_line = []

    for j in range(batch_size):
        line = []
        for i in range(len(pred_list)):
            if pred_list[i][j] == end_token_index:
                break
            line.append(tgt_vocab.get_token(pred_list[i][j]))
        pred_line.append(line)

    return pred_line
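
A small worked example of the transposition this helper performs, with a stub in place of Vocab (only get_token is needed):

class _StubVocab:
    def __init__(self, tokens):
        self._tokens = tokens

    def get_token(self, index):
        return self._tokens[index]

# pred_list is step-major: (tgt_length, batch_size), here for 2 sentences
stub = _StubVocab(["<s>", "</s>", "a", "b", "c"])
pred_list = [[2, 3], [3, 1], [1, 4]]  # index 1 plays the role of the end token

print(convert_index_to_token(pred_list, stub, batch_size=2, end_token_index=1))
# [['a', 'b'], ['b']]  -- each sentence stops at its first end token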
Example #8
def beam_search_transformer(s2s: transformer.S2S, data_tensor: torch.Tensor,
                            tgt_vocab: Vocab, beam_size: int,
                            device: torch.device):

    # src: (1, input_length)
    src = data_tensor
    src = src.expand(beam_size, -1)
    src_mask = s2s.make_src_mask(src)

    encoder_src = s2s.encoder(src, src_mask)

    max_length = src.size(1) * 3

    # tgt: (1, 1)
    tgt = torch.tensor([[tgt_vocab.get_index(tgt_vocab.start_token)]],
                       device=device)

    # tgt: (beam_size, 1)
    tgt = tgt.expand(beam_size, -1)
    scores = torch.zeros(beam_size, device=device)

    complete_seqs = []
    complete_seqs_scores = []

    step = 1

    while True:

        tgt_mask = s2s.make_tgt_mask(tgt)

        # output: (beam_size, tgt_length, vocab_size)
        output = s2s.decoder(tgt, encoder_src, tgt_mask, src_mask)

        # output: (1 * beam_size, vocab_size)
        output = output[:, -1, :]

        # output: (1 * beam_size, vocab_size)
        output = F.log_softmax(output, dim=-1)

        # sub_sentence_scores: (1 * beam_size, vocab_size)
        sub_sentence_scores = output + scores.unsqueeze(1)

        if step == 1:
            pred_prob, pred_indices = sub_sentence_scores[0].topk(beam_size,
                                                                  dim=-1)
        else:
            # sub_sentence_scores: (beam_size * vocab_size)
            sub_sentence_scores = sub_sentence_scores.view(-1)
            pred_prob, pred_indices = sub_sentence_scores.topk(beam_size,
                                                               dim=-1)

        # beam_id: (beam_size, )
        beam_id = pred_indices.floor_divide(len(tgt_vocab))
        # token_id: (beam_size, )
        token_id = pred_indices % len(tgt_vocab)

        # next_tgt: (beam_size, input_length + 1)
        next_tgt = torch.cat([tgt[beam_id], token_id.unsqueeze(1)], dim=1)

        if step == max_length:
            complete_seqs.extend(next_tgt.tolist())
            complete_seqs_scores.extend(pred_prob.tolist())
            break

        complete_indices = []

        for i, indices in enumerate(token_id):

            if tgt_vocab.get_token(indices.item()) == tgt_vocab.end_token:
                complete_indices.append(i)

        if len(complete_indices) > 0:
            complete_seqs.extend(next_tgt[complete_indices].tolist())

            complete_pred_indices = beam_id[complete_indices] * len(
                tgt_vocab) + token_id[complete_indices]

            if step == 1:
                complete_seqs_scores.extend(
                    sub_sentence_scores[0][complete_pred_indices].tolist())

                if len(complete_indices) == beam_size:
                    break

                sub_sentence_scores[0][complete_pred_indices] = -1e9
                pred_prob, pred_indices = sub_sentence_scores[0].topk(
                    beam_size, dim=-1)
            else:
                complete_seqs_scores.extend(
                    sub_sentence_scores[complete_pred_indices].tolist())

                if len(complete_indices) == beam_size:
                    break

                sub_sentence_scores[complete_pred_indices] = -1e9
                pred_prob, pred_indices = sub_sentence_scores.topk(beam_size,
                                                                   dim=-1)

            # beam_id: (beam_size, )
            beam_id = pred_indices.floor_divide(len(tgt_vocab))
            # token_id: (beam_size, )
            token_id = pred_indices % len(tgt_vocab)
            # next_tgt: (beam_size, input_length + 1)
            next_tgt = torch.cat([tgt[beam_id], token_id.unsqueeze(1)], dim=1)

        step += 1

        tgt = next_tgt
        scores = pred_prob

    best_sentence_id = 0
    for i in range(len(complete_seqs_scores)):
        if complete_seqs_scores[i] > complete_seqs_scores[best_sentence_id]:
            best_sentence_id = i

    best_sentence = complete_seqs[best_sentence_id]

    best_sentence = [
        tgt_vocab.get_token(index) for index in best_sentence[1:-1]
    ]

    return best_sentence
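
The beam bookkeeping above flattens the (beam_size, vocab_size) score matrix and then recovers, for each top-k candidate, which beam it extends and which token it appends; a toy sketch with a 5-token vocabulary:

import torch

beam_size, vocab_size = 3, 5
scores = torch.randn(beam_size, vocab_size)

flat = scores.view(-1)                    # (beam_size * vocab_size,)
top_scores, top_idx = flat.topk(beam_size)

beam_id = top_idx // vocab_size           # which beam each candidate came from
token_id = top_idx % vocab_size           # which token it appends
print(beam_id.tolist(), token_id.tolist())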
Example #9
def beam_search_rnn(s2s: Union[S2S_attention.S2S, S2S_basic.S2S],
                    data_tensor: torch.Tensor, tgt_vocab: Vocab,
                    beam_size: int, device: torch.device):

    # batch_size == beam_size

    # inputs: (input_length, beam_size)
    inputs = data_tensor
    inputs = inputs.expand(-1, beam_size)

    encoder_output, encoder_hidden_state = s2s.encoder(inputs)

    # decoder_input: (1, beam_size)
    decoder_input = torch.tensor(
        [[tgt_vocab.get_index(tgt_vocab.start_token)]], device=device)
    decoder_input = decoder_input.expand(-1, beam_size)

    # decoder_hidden_state: (num_layers, beam_size, hidden_size)
    decoder_hidden_state = combine_bidir_hidden_state(s2s,
                                                      encoder_hidden_state)

    max_length = inputs.size(0) * 3

    scores = torch.zeros(beam_size, device=device)

    complete_seqs = []
    complete_seqs_scores = []
    step = 1

    while True:

        # output: (1, beam_size, vocab_size)
        # decoder_hidden_state: (num_layers, beam_size, hidden_size)
        output, decoder_hidden_state = decode_batch(
            s2s, decoder_input[-1].unsqueeze(0), decoder_hidden_state,
            encoder_output)

        output = F.log_softmax(output, dim=-1)

        # sub_sentence_scores: (beam_size, vocab_size)
        sub_sentence_scores = scores.unsqueeze(1) + output.squeeze(0)

        if step == 1:
            pred_prob, pred_indices = sub_sentence_scores[0].topk(beam_size,
                                                                  dim=-1)
        else:
            # sub_sentence_scores: (beam_size * vocab_size)
            sub_sentence_scores = sub_sentence_scores.view(-1)
            pred_prob, pred_indices = sub_sentence_scores.topk(beam_size,
                                                               dim=-1)

        # beam_id: (beam_size, )
        beam_id = pred_indices.floor_divide(len(tgt_vocab))

        # token_id: (beam_size, )
        token_id = pred_indices % len(tgt_vocab)

        # decoder_input[-1][beam_id]: (beam_size, )
        # next_decoder_input: (step + 1, beam_size)
        # decoder_input: (step, beam_size)
        next_decoder_input = torch.cat(
            [decoder_input[:, beam_id],
             token_id.unsqueeze(0)], dim=0)

        if step == max_length:
            complete_seqs.extend(next_decoder_input.t().tolist())
            complete_seqs_scores.extend(pred_prob.tolist())
            break

        complete_indices = []

        for i, indices in enumerate(token_id):

            if tgt_vocab.get_token(indices.item()) == tgt_vocab.end_token:
                complete_indices.append(i)

        if len(complete_indices) > 0:
            complete_seqs.extend(
                next_decoder_input[:, complete_indices].t().tolist())

            complete_pred_indices = beam_id[complete_indices] * len(
                tgt_vocab) + token_id[complete_indices]

            if step == 1:
                complete_seqs_scores.extend(
                    sub_sentence_scores[0][complete_pred_indices].tolist())

                if len(complete_pred_indices) == beam_size:
                    break

                sub_sentence_scores[0][complete_pred_indices] = -1e9
                pred_prob, pred_indices = sub_sentence_scores[0].topk(
                    beam_size, dim=-1)
            else:
                complete_seqs_scores.extend(
                    sub_sentence_scores[complete_pred_indices].tolist())

                if len(complete_pred_indices) == beam_size:
                    break

                sub_sentence_scores[complete_pred_indices] = -1e9
                pred_prob, pred_indices = sub_sentence_scores.topk(beam_size,
                                                                   dim=-1)

            beam_id = pred_indices.floor_divide(len(tgt_vocab))
            token_id = pred_indices % len(tgt_vocab)

            next_decoder_input = torch.cat(
                [decoder_input[:, beam_id],
                 token_id.unsqueeze(0)], dim=0)

        step += 1

        if isinstance(decoder_hidden_state, tuple):
            h, c = decoder_hidden_state
            h = h[:, beam_id]
            c = c[:, beam_id]
            decoder_hidden_state = (h, c)
        else:
            decoder_hidden_state = decoder_hidden_state[:, beam_id]

        decoder_input = next_decoder_input
        scores = pred_prob

    best_sentence_id = 0
    for i in range(len(complete_seqs_scores)):
        if complete_seqs_scores[i] > complete_seqs_scores[best_sentence_id]:
            best_sentence_id = i

    best_sentence = complete_seqs[best_sentence_id]

    best_sentence = [
        tgt_vocab.get_token(index) for index in best_sentence[1:-1]
    ]

    return best_sentence
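
When the surviving beams are reshuffled, the recurrent state has to follow them; a minimal sketch of reindexing an LSTM-style (h, c) state along the beam dimension, as done at the end of the loop above:

import torch

num_layers, beam_size, hidden_size = 2, 3, 4
h = torch.randn(num_layers, beam_size, hidden_size)
c = torch.randn(num_layers, beam_size, hidden_size)

# e.g. new beams 0 and 2 both extend old beam 1, new beam 1 extends old beam 0
beam_id = torch.tensor([1, 0, 1])

h = h[:, beam_id]   # (num_layers, beam_size, hidden_size), beams reordered
c = c[:, beam_id]
print(h.shape, c.shape)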
Example #10
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("-data_path", type=str, default="../data")
    parser.add_argument("-input_data", type=str, default="sections.pickle")
    parser.add_argument("-vocab", type=str, default="vocab.csv")
    parser.add_argument("-encoding", type=str, default="word2index")
    parser.add_argument("-tokenization", type=str, default="word_tokenize")
    parser.add_argument("-speaker_lables", type=bool, default=True)
    parser.add_argument("-lowercase", type=bool, default=True)
    parser.add_argument("-splitting", type=bool, default=True)
    parser.add_argument("-min_occ", type=int, default=3)
    args = parser.parse_args()

    vocab = Vocab.create(args.data_path, "train_" + args.input_data,
                         args.vocab, args.tokenization, args.lowercase,
                         args.splitting, args.min_occ)

    print("Creating new PhotoBook segment dataset...")
    chain_id = 0
    for set_name in ['dev', 'val', 'test', 'train']:
        with open(
                os.path.join(args.data_path, set_name + "_" + args.input_data),
                'rb') as f:
            dialogue_segments = pickle.load(f)
        segment_dataset, chain_dataset, chain_id = create_segment_datasets(
            dialogue_segments, vocab, chain_id, args.encoding,
            args.tokenization, args.speaker_lables, args.lowercase,
            args.splitting)

        with open(os.path.join(args.data_path, set_name + "_segments.json"),
Example #11
def test_item_file(end_test_file, embedding_file_path, vocab_file_path,
                   use_gpu):
    embed = torch.Tensor(np.load(embedding_file_path)['arr_0'])
    with open(vocab_file_path) as f:
        word2id = json.load(f)
    vocab = Vocab(embed, word2id)
    #with open(end_test_file) as f:
    #    examples = [json.loads(line) for line in f]
    with open(end_test_file) as f:
        examples = list()
        for line in f:
            if line and not line.isspace():
                examples.append(json.loads(line))
    #print(examples[0])
    test_dataset = Dataset(examples)

    test_iter = DataLoader(dataset=test_dataset,
                           batch_size=args.batch_size,
                           shuffle=False)
    load_dir = os.path.join(args.input, 'model_files', 'CNN_RNN.pt')
    if use_gpu:
        checkpoint = torch.load(load_dir)
    else:
        checkpoint = torch.load(load_dir,
                                map_location=lambda storage, loc: storage)
    if not use_gpu:
        checkpoint['args'].device = None
    net = getattr(models, checkpoint['args'].model)(checkpoint['args'])
    net.load_state_dict(checkpoint['model'])
    if use_gpu:
        net.cuda()
    net.eval()
    doc_num = len(test_dataset)

    all_targets = []
    all_results = []
    all_probs = []
    all_acc = []
    all_p = []
    all_r = []
    all_f1 = []
    all_sum = []
    for batch in tqdm(test_iter):
        features, targets, summaries, doc_lens = vocab.make_features(batch)
        if use_gpu:
            probs = net(Variable(features).cuda(), doc_lens)
        else:
            probs = net(Variable(features), doc_lens)
        start = 0
        for doc_id, doc_len in enumerate(doc_lens):
            doc = batch['doc'][doc_id].split('\n')[:doc_len]
            stop = start + doc_len
            prob = probs[start:stop]
            hyp = []
            for _p, _d in zip(prob, doc):
                print(_p)
                print(_d)
                if _p > 0.5:
                    hyp.append(_d)
            if len(hyp) > 0:
                print(hyp)
                all_sum.append("###".join(hyp))
            else:
                all_sum.append('')
            all_targets.append(targets[start:stop])
            all_probs.append(prob)
            start = stop
    file_path_elems = end_test_file.split('/')
    file_name = 'TR-' + file_path_elems[len(file_path_elems) - 1]
    with open(os.path.join(args.output, file_name), mode='w',
              encoding='utf-8') as f:
        for text in all_sum:
            f.write(text.strip() + '\n')
    for item in all_probs:
        all_results.append([1 if tmp > 0.5 else 0 for tmp in item.tolist()])
    print(len(all_results))
    print(len(all_targets))
    print(len(all_probs))
    for _1, _2, _3 in zip(all_results, all_targets, all_probs):
        _2 = _2.tolist()
        _3 = _3.tolist()
        print("*" * 3)
        print('probs : ', _3)
        print('results : ', _1)
        print('targets : ', _2)
        # sklearn metrics expect (y_true, y_pred)
        tmp_acc = accuracy_score(_2, _1)
        tmp_p = precision_score(_2, _1)
        tmp_r = recall_score(_2, _1)
        tmp_f1 = f1_score(_2, _1)
        print('acc : ', tmp_acc)
        print('p : ', tmp_p)
        print('r : ', tmp_r)
        print('f1 : ', tmp_f1)
        all_acc.append(tmp_acc)
        all_p.append(tmp_p)
        all_r.append(tmp_r)
        all_f1.append(tmp_f1)
    print('all dataset acc : ', np.mean(all_acc))
    print('all dataset p : ', np.mean(all_p))
    print('all dataset r : ', np.mean(all_r))
    print('all dataset f1 : ', np.mean(all_f1))
    print('all results length : ', len(all_results))
Example #12
import argparse
import os

from utils.ChainDataset import ChainDataset
from utils.SegmentDataset import SegmentDataset  # assumed to sit alongside ChainDataset in utils
from utils.Vocab import Vocab

# Tests the SegmentDataset and ChainDataset classes
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("-data_path", type=str, default="../data")
    parser.add_argument("-segment_file", type=str, default="segments.json")
    parser.add_argument("-chains_file", type=str, default="val_chains.json")
    parser.add_argument("-vocab_file", type=str, default="vocab.csv")
    parser.add_argument("-vectors_file", type=str, default="vectors.json")
    parser.add_argument("-split", type=str, default="val")
    args = parser.parse_args()

    print("Loading the vocab...")
    vocab = Vocab(os.path.join(args.data_path, args.vocab_file), 3)

    print("Testing the SegmentDataset class initialization...")

    segment_val_set = SegmentDataset(data_dir=args.data_path,
                                     segment_file=args.segment_file,
                                     vectors_file=args.vectors_file,
                                     split=args.split)

    print("Testing the SegmentDataset class item getter...")
    print("Dataset contains {} segment samples".format(len(segment_val_set)))
    sample_id = 2
    sample = segment_val_set[sample_id]
    print("Segment {}:".format(sample_id))
    print("Image set: {}".format(sample["image_set"]))
    print("Target image index(es): {}".format(sample["targets"]))
Example #13
def translate_transformer(line: str, line_number: int, s2s: transformer.S2S,
                          src_vocab: Vocab, tgt_vocab: Vocab, lang_vec: dict,
                          device: torch.device):

    line = " ".join([src_vocab.start_token, line, src_vocab.end_token])

    line = line.split()

    max_length = (len(line) - 2) * 3

    lang_token = line[1]

    # inputs: (input_length, )
    src = torch.tensor([src_vocab.get_index(token) for token in line],
                       device=device)
    # inputs: (1, input_length)
    src = src.view(1, -1)

    src_mask = s2s.make_src_mask(src)

    src = s2s.encoder.token_embedding(src) * s2s.encoder.scale

    # src: (1, input_length, d_model)
    src = s2s.encoder.pos_embedding(src)

    if lang_token.startswith("<") and lang_token.endswith(">"):
        # lang_encoding: (d_model, )
        lang_encoding = torch.tensor(lang_vec[lang_token], device=device)
        src = src + lang_encoding

    else:
        print("line {} does not add language embedding".format(line_number))

    for layer in s2s.encoder.layers:
        src, self_attention = layer(src, src_mask)

    del self_attention

    encoder_src = src

    tgt = None

    pred_line = [tgt_vocab.get_index(tgt_vocab.start_token)]

    for i in range(max_length):

        if tgt is None:
            tgt = torch.tensor([pred_line], device=device)

        tgt_mask = s2s.make_tgt_mask(tgt)

        # output: (1, tgt_input_length, vocab_size)
        output = s2s.decoder(tgt, encoder_src, tgt_mask, src_mask)

        # pred: index of the most likely token at the last position
        pred = torch.argmax(output, dim=-1)[0, -1]

        if tgt_vocab.get_token(pred.item()) == tgt_vocab.end_token:
            break

        tgt = torch.cat([tgt, pred.unsqueeze(0).unsqueeze(1)], dim=1)
        pred_line.append(pred.item())

    pred_line = [tgt_vocab.get_token(index) for index in pred_line[1:]]
    return pred_line
Example #14
def main():
    parser = argparse.ArgumentParser()

    parser.add_argument("--device", required=True)
    parser.add_argument("--load", required=True)
    parser.add_argument("--src_vocab_path", required=True)
    parser.add_argument("--tgt_vocab_path", required=True)
    parser.add_argument("--test_src_path", required=True)
    parser.add_argument("--test_tgt_path", required=True)
    parser.add_argument("--lang_vec_path", required=True)
    parser.add_argument("--translation_output", required=True)
    parser.add_argument("--transformer", action="store_true")

    args, unknown = parser.parse_known_args()

    src_vocab = Vocab.load(args.src_vocab_path)
    tgt_vocab = Vocab.load(args.tgt_vocab_path)

    src_data = read_data(args.test_src_path)

    lang_vec = load_lang_vec(args.lang_vec_path)  # lang_vec: dict

    device = args.device

    print("load from {}".format(args.load))

    if args.transformer:

        max_src_length = max(len(line.split()) for line in src_data) + 2
        max_tgt_length = max_src_length * 3
        padding_value = src_vocab.get_index(src_vocab.mask_token)
        assert padding_value == tgt_vocab.get_index(tgt_vocab.mask_token)

        s2s = load_transformer(args.load,
                               len(src_vocab),
                               max_src_length,
                               len(tgt_vocab),
                               max_tgt_length,
                               padding_value,
                               device=device)

    else:
        s2s = load_model(args.load, device=device)

    s2s.eval()

    pred_data = []
    for i, line in enumerate(src_data):
        pred_data.append(
            translate(line, i, s2s, src_vocab, tgt_vocab, lang_vec, device))

    write_data(pred_data, args.translation_output)

    pred_data_tok_path = args.translation_output + ".tok"

    tok_command = "sed -r 's/(@@ )|(@@ ?$)//g' {} > {}".format(
        args.translation_output, pred_data_tok_path)
    call(tok_command, shell=True)

    bleu_calculation_command = "perl /data/rrjin/NMT/scripts/multi-bleu.perl {} < {}".format(
        args.test_tgt_path, pred_data_tok_path)
    call(bleu_calculation_command, shell=True)
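
The sed command above strips BPE continuation markers ("@@ "); an equivalent hedged Python sketch, in case shelling out is not desired:

import re

def remove_bpe(line: str) -> str:
    # drop "@@ " joints and a trailing "@@" at the end of a line
    return re.sub(r"(@@ )|(@@ ?$)", "", line)

print(remove_bpe("im@@ press@@ ive results"))  # "impressive results"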
Example #15
def train():
    print("*"*100)
    print("train begin")
    # use gpu
    use_gpu = args.device is not None
    if torch.cuda.is_available() and not use_gpu:
        print("WARNING: You have a CUDA device, should run with -device 0")
    if use_gpu:
        # set cuda device and seed
        torch.cuda.set_device(args.device)
    torch.cuda.manual_seed(args.seed)
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    numpy.random.seed(args.seed)
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)

    # prepare file paths
    embedding_file_path = os.path.join(args.project, "embedding.npz")
    vocab_file_path = os.path.join(args.project, "word2id.json")
    end_train_file = os.path.join(args.input, "train_files", "train.txt")
    train_files_dir = os.path.join(args.input, "train_files")

    # merge text files that share the same suffix
    merge_same_suf_text_file(train_files_dir, end_train_file, '.txt')

    print('Loading vocab, train and val datasets. Wait a second, please.')
    embed = torch.Tensor(np.load(embedding_file_path)['arr_0'])  # embed = torch.Tensor(list(np.load(args.embedding)))
    with open(vocab_file_path) as f:
        word2id = json.load(f)
    vocab = Vocab(embed, word2id)
    with open(end_train_file) as f:
        examples = list()
        for line in tqdm(f):
            if line and not line.isspace():
                examples.append(json.loads(line))
    train_dataset = Dataset(examples)
    print(train_dataset[:1])

    args.embed_num = embed.size(0)  # read the vocabulary size from the embedding matrix
    args.embed_dim = embed.size(1)  # read the embedding dimension
    args.kernel_sizes = [int(ks) for ks in args.kernel_sizes.split(',')]
    net = getattr(models, args.model)(args, embed)
    if use_gpu:
        net.cuda()
    train_iter = DataLoader(dataset=train_dataset, batch_size=args.batch_size, shuffle=False)
    criterion = nn.BCELoss()
    params = sum(p.numel() for p in list(net.parameters())) / 1e6
    print('#Params: %.1fM' % (params))

    min_loss = float('inf')
    optimizer = torch.optim.Adam(net.parameters(), lr=args.learning_rate)
    net.train()

    t1 = time()
    for epoch in range(1, args.max_epoch + 1):
        print("*"*10, 'epoch ', str(epoch), '*'*50)
        for i, batch in enumerate(train_iter):
            print("*"*10, 'batch', i, '*'*10)
            features, targets, _, doc_lens = vocab.make_features(batch, args.seq_trunc)
            features, targets = Variable(features), Variable(targets.float())
            if use_gpu:
                features = features.cuda()
                targets = targets.cuda()
            probs = net(features, doc_lens)
            loss = criterion(probs, targets)
            optimizer.zero_grad()
            loss.backward()
            clip_grad_norm(net.parameters(), args.max_norm)
            optimizer.step()
            net.save()
            print('Epoch: %2d Loss: %f' % (epoch, loss))
    t2 = time()
    print('Total Cost:%f h' % ((t2 - t1) / 3600))
    print("模型配置文件保存至输出文件夹")
Example #16
def evaluation(local_rank, args):
    rank = args.nr * args.gpus + local_rank
    dist.init_process_group(backend="nccl",
                            init_method=args.init_method,
                            rank=rank,
                            world_size=args.world_size)

    device = torch.device("cuda", local_rank)
    torch.cuda.set_device(device)

    # List[str]
    src_data = read_data(args.test_src_path)

    tgt_prefix_data = None
    if args.tgt_prefix_file_path is not None:
        tgt_prefix_data = read_data(args.tgt_prefix_file_path)

    max_src_len = max(len(line.split()) for line in src_data) + 2
    max_tgt_len = max_src_len * 3
    logging.info("max src sentence length: {}".format(max_src_len))

    src_vocab = Vocab.load(args.src_vocab_path)
    tgt_vocab = Vocab.load(args.tgt_vocab_path)

    padding_value = src_vocab.get_index(src_vocab.mask_token)

    assert padding_value == tgt_vocab.get_index(tgt_vocab.mask_token)

    src_data = convert_data_to_index(src_data, src_vocab)

    dataset = DataPartition(src_data, args.world_size, tgt_prefix_data,
                            args.work_load_per_process).dataset(rank)

    logging.info("dataset size: {}, rank: {}".format(len(dataset), rank))

    data_loader = DataLoader(
        dataset=dataset,
        batch_size=(args.batch_size if args.batch_size else 1),
        shuffle=False,
        pin_memory=True,
        drop_last=False,
        collate_fn=lambda batch: collate_eval(
            batch,
            padding_value,
            batch_first=(True if args.transformer else False)))

    if not os.path.isdir(args.translation_output_dir):
        os.makedirs(args.translation_output_dir)

    if args.beam_size:
        logging.info("Beam size: {}".format(args.beam_size))

    if args.is_prefix:
        args.model_load = args.model_load + "*"

    for model_path in glob.glob(args.model_load):
        logging.info("Load model from: {}, rank: {}".format(model_path, rank))

        if args.transformer:

            s2s = load_transformer(model_path,
                                   len(src_vocab),
                                   max_src_len,
                                   len(tgt_vocab),
                                   max_tgt_len,
                                   padding_value,
                                   training=False,
                                   share_dec_pro_emb=args.share_dec_pro_emb,
                                   device=device)

        else:
            s2s = load_model(model_path, device=device)

        s2s.eval()

        if args.record_time:
            import time
            start_time = time.time()

        pred_data = []

        for data, tgt_prefix_batch in data_loader:
            if args.beam_size:
                pred_data.append(
                    beam_search_decoding(s2s, data.to(device,
                                                      non_blocking=True),
                                         tgt_vocab, args.beam_size, device))
            else:
                pred_data.extend(
                    greedy_decoding(s2s, data.to(device, non_blocking=True),
                                    tgt_vocab, device, tgt_prefix_batch))

        if args.record_time:
            end_time = time.time()
            logging.info("Time spend: {} seconds, rank: {}".format(
                end_time - start_time, rank))

        _, model_name = os.path.split(model_path)

        if args.beam_size:
            translation_file_name_prefix = "{}_beam_size_{}".format(
                model_name, args.beam_size)
        else:
            translation_file_name_prefix = "{}_greedy".format(model_name)

        p = os.path.join(
            args.translation_output_dir,
            "{}_translations.rank{}".format(translation_file_name_prefix,
                                            rank))

        write_data(pred_data, p)

        if args.need_tok:

            # replace '@@ ' with ''
            p_tok = os.path.join(
                args.translation_output_dir,
                "{}_translations_tok.rank{}".format(
                    translation_file_name_prefix, rank))

            tok_command = "sed -r 's/(@@ )|(@@ ?$)//g' {} > {}".format(
                p, p_tok)

            call(tok_command, shell=True)
Example #17
import json
from collections import defaultdict
from utils.Vocab import Vocab

with open('data/test_segments.json', 'r') as file:
    test_sg = json.load(file)

with open('data/test_chains.json', 'r') as file:
    test = json.load(file)

vocab = Vocab('data/vocab.csv', 3)

# given an img, provides the chains for which it was the target

target2chains = defaultdict(list)

for ch in test:
    target_id = ch['target']
    segment_list = ch['segments']

    target2chains[target_id].append(segment_list)

id_list = []

# segments ids, in the order in which they were encountered in the chains in the whole dataset

for c in test:
    segments = c['segments']

    for s in segments:
        if s not in id_list:
            id_list.append(s)
Example #18
parser.add_argument("--val_tgt_path", required=True)
parser.add_argument("--src_vocab_path", required=True)
parser.add_argument("--tgt_vocab_path", required=True)
parser.add_argument("--picture_path", required=True)

args, unknown = parser.parse_known_args()

device = args.device

s2s = load_model(args.load, device=device)

if not isinstance(s2s, S2S_attention.S2S):

    raise Exception("The model doesn't have an attention mechanism")

src_vocab = Vocab.load(args.src_vocab_path)
tgt_vocab = Vocab.load(args.tgt_vocab_path)

with open(args.val_tgt_path) as f:

    data = f.read().split("\n")
    tgt_data = [normalizeString(line, to_ascii=False) for line in data]

with torch.no_grad():

    s2s.eval()

    with open(args.val_src_path) as f:

        data = f.read().split("\n")
Example #19
    print(args)

    # prepare datasets and obtain the arguments

    t = datetime.datetime.now()
    timestamp = str(t.date()) + '-' + str(t.hour) + '-' + str(
        t.minute) + '-' + str(t.second)

    seed = args.seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    print("Loading the vocab...")
    vocab = Vocab(os.path.join(args.data_path, args.vocab_file), 3)

    trainset = HistoryDataset(data_dir=args.data_path,
                              segment_file='train_' + args.segment_file,
                              vectors_file=args.vectors_file,
                              chain_file='train_' + args.chains_file,
                              split=args.split)

    testset = HistoryDataset(data_dir=args.data_path,
                             segment_file='test_' + args.segment_file,
                             vectors_file=args.vectors_file,
                             chain_file='test_' + args.chains_file,
                             split='test')

    valset = HistoryDataset(data_dir=args.data_path,
                            segment_file='val_' + args.segment_file,
Example #20
 def __init__(self, cfg, summary):
     Vocab.__init__(self, cfg)
     self.file_path = cfg.OUTPUT_DIR
     self.summary = summary
     pass