def eval_bleu(self):
    """Translate the dev data of every language pair and record BLEU scores.

    For each pair in ``self.pairs`` this method:

    1. translates the pair's ``src_batches`` with ``self._translate``;
    2. writes the best and the beam translations (still BPE-segmented) to
       ``self.args.dump_dir``;
    3. merges BPE with ``ut.remove_bpe``;
    4. scores the merged best translation against the pair's ``ref_file``
       with ``ut.calc_bleu``;
    5. keeps a copy of both merged files whose name is suffixed with the
       BLEU score.

    Side effects: switches ``self.model`` to eval mode, appends to
    ``self.stats[pair]['dev_bleus']`` and to ``self.stats['avg_bleus']``.
    If ``self.pairs`` is empty, a warning is logged and nothing is appended
    to ``self.stats['avg_bleus']`` (previously this raised
    ``ZeroDivisionError``).
    """
    self.logger.info('Evaluate dev BLEU')
    start = time.time()
    self.model.eval()
    avg_bleus = []
    dump_dir = self.args.dump_dir
    with torch.no_grad():
        for pair in self.pairs:
            self.logger.info('--> {}'.format(pair))
            # Pair names are split on '2' (presumably "src2tgt" -- verify
            # against the data manager).
            src_lang, tgt_lang = pair.split('2')
            src_lang_idx = self.data_manager.lang_vocab[src_lang]
            tgt_lang_idx = self.data_manager.lang_vocab[tgt_lang]
            logit_mask = self.data_manager.logit_masks[tgt_lang]
            data = self.data_manager.translate_data[pair]
            src_batches = data['src_batches']
            sorted_idxs = data['sorted_idxs']
            ref_file = data['ref_file']

            all_best_trans, all_beam_trans = self._translate(
                src_batches, sorted_idxs, src_lang_idx, tgt_lang_idx,
                logit_mask)

            # Opening with mode 'w' already truncates, so no separate
            # "touch" of the file is needed.
            best_trans_file = join(dump_dir,
                                   '{}_val_trans.txt.bpe'.format(pair))
            with open(best_trans_file, 'w') as fout:
                fout.write(''.join(all_best_trans))

            beam_trans_file = join(dump_dir,
                                   '{}_beam_trans.txt.bpe'.format(pair))
            with open(beam_trans_file, 'w') as fout:
                fout.write(''.join(all_beam_trans))

            # merge BPE
            nobpe_best_trans_file = join(dump_dir,
                                         '{}_val_trans.txt'.format(pair))
            ut.remove_bpe(best_trans_file, nobpe_best_trans_file)
            nobpe_beam_trans_file = join(dump_dir,
                                         '{}_beam_trans.txt'.format(pair))
            ut.remove_bpe(beam_trans_file, nobpe_beam_trans_file)

            # calculate BLEU
            bleu, msg = ut.calc_bleu(self.args.bleu_script,
                                     nobpe_best_trans_file, ref_file)
            self.logger.info(msg)
            avg_bleus.append(bleu)
            self.stats[pair]['dev_bleus'].append(bleu)

            # save translation with BLEU score for future reference
            trans_file = '{}-{}'.format(nobpe_best_trans_file, bleu)
            shutil.copyfile(nobpe_best_trans_file, trans_file)
            beam_file = '{}-{}'.format(nobpe_beam_trans_file, bleu)
            shutil.copyfile(nobpe_beam_trans_file, beam_file)

    if not avg_bleus:
        # Nothing was evaluated; an average is undefined.
        self.logger.warning('No language pairs to evaluate, skip avg_bleu')
        return

    avg_bleu = sum(avg_bleus) / len(avg_bleus)
    self.stats['avg_bleus'].append(avg_bleu)
    self.logger.info('avg_bleu = {}'.format(avg_bleu))
    self.logger.info(
        'Done evaluating dev BLEU, it takes {} seconds'.format(
            ut.format_seconds(time.time() - start)))
Esempio n. 2
0
def gen_batch2str(src, generated, gen_len, src_vocab, tgt_vocab):
    """Render a batch of sources and generated hypotheses as text lines.

    ``src``, ``generated`` and ``gen_len`` are moved to CPU and converted to
    nested Python lists (presumably torch tensors of token ids / lengths --
    confirm with the caller). Row ``i`` of ``generated`` is cut to
    ``gen_len[i]`` tokens, ids are mapped through ``vocab.itos``, and BPE
    and special tokens are removed.

    Returns a flat list with two entries per row: ``"S: <source>"``
    followed by ``"H: <hypothesis>"``.
    """
    hyp_rows = generated.cpu().numpy().tolist()
    hyp_lens = gen_len.cpu().numpy().tolist()
    src_rows = src.cpu().numpy().tolist()

    lines = []
    for row_idx, hyp_ids in enumerate(hyp_rows):
        # Keep only the tokens that were actually generated for this row.
        hyp_ids = hyp_ids[:hyp_lens[row_idx]]
        hyp_text = " ".join(tgt_vocab.itos[tok] for tok in hyp_ids)
        src_text = " ".join(src_vocab.itos[tok] for tok in src_rows[row_idx])
        hyp_text = remove_special_tok(remove_bpe(hyp_text))
        src_text = remove_special_tok(remove_bpe(src_text))
        lines += ["S: " + src_text, "H: " + hyp_text]
    return lines
Esempio n. 3
0
def translate(args, net, src_vocab, tgt_vocab):
    """Translate every line of ``args.text`` one sentence at a time.

    Each line is whitespace-tokenised, passed to ``translate_sentence``,
    cleaned of BPE and special tokens, printed, and collected.

    Returns the list of cleaned translations in input order.
    """
    tokenized = [line.split() for line in args.text]
    outputs = []

    for tokens in tokenized:
        raw = translate_sentence(tokens, net, args, src_vocab, tgt_vocab)
        cleaned = remove_special_tok(remove_bpe(raw))
        outputs.append(cleaned)
        print(cleaned)

    return outputs
Esempio n. 4
0
def translate(args, net, src_vocab, tgt_vocab):
    """Translate ``args.text`` with greedy batch decoding or per sentence.

    * ``args.greedy`` true: builds a ``ParallelDataset`` from ``args.text``
      and ``args.ref_text``, decodes each batch with ``greedy`` and prints
      the size of each source batch. The decoded strings are produced but
      intentionally not collected, so the returned list is empty in this
      mode (this mirrors the original behaviour).
    * otherwise: each line is whitespace-tokenised, translated with
      ``translate_sentence``, cleaned of BPE / special tokens, printed and
      collected.

    Returns the list of collected translations.
    """
    translated = []

    if args.greedy:
        infer_dataset = ParallelDataset(args.text, args.ref_text, src_vocab,
                                        tgt_vocab)
        # Optional batching overrides supplied on the command line.
        if args.batch_size is not None:
            infer_dataset.BATCH_SIZE = args.batch_size
        if args.max_batch_size is not None:
            infer_dataset.max_batch_size = args.max_batch_size
        if args.tokens_per_batch is not None:
            infer_dataset.tokens_per_batch = args.tokens_per_batch

        infer_dataiter = iter(infer_dataset.get_iterator(True, True))
        for raw_batch in infer_dataiter:
            # Bind ``src`` unconditionally: previously it was only assigned
            # inside the CUDA branch, so CPU runs failed with a NameError.
            src = raw_batch.src
            src_mask = (src != src_vocab.stoi[config.PAD]).unsqueeze(-2)
            if args.use_cuda:
                src, src_mask = src.cuda(), src_mask.cuda()
            generated, gen_len = greedy(args, net, src, src_mask, src_vocab,
                                        tgt_vocab)
            # NOTE(review): the result is computed but deliberately not
            # printed or added to ``translated``; the original had that part
            # disabled.
            new_translations = gen_batch2str(src, raw_batch.tgt, generated,
                                             gen_len, src_vocab, tgt_vocab)
            print('src size : {}'.format(src.size()))
    else:
        for line in args.text:
            s_trans = translate_sentence(line.split(), net, args, src_vocab,
                                         tgt_vocab)
            s_trans = remove_special_tok(remove_bpe(s_trans))
            translated.append(s_trans)
            print(s_trans)

    return translated
Esempio n. 5
0
def back_translation_modify(args):
    """Move queried sentences to the labeled set and build training data.

    Inputs (all paths taken from ``args``):

    * ``unlabeled_dataset`` / ``oracle``: parallel line files.
    * ``labeled_dataset``: ``"src_path,tgt_path"``.
    * ``active_out_1``: records of 5 lines (source, hypothesis, target,
      value, index); ``active_out_2``: records of 4 lines (source, target,
      value, index). In both, the value is the last space-separated field of
      its line and the index is the last whitespace-separated field.

    Each record's rank in the two value-sorted lists is summed and records
    are ordered by that combined rank. Sentences are taken from the top
    while the cumulative source length (BPE and special tokens removed)
    stays within ``args.tok_budget``; they are appended to the labeled set
    with their oracle target. All other sentences become the new unlabeled
    set / oracle (shuffled together). Sentences inside the following
    ``args.back_translation_tok_budget`` tokens are added to the training
    files with the model hypothesis as target.

    Outputs: ``output_unlabeled_dataset``, ``output_labeled_dataset``
    (``"src,tgt"``), ``output_oracle``, ``output_new_queries``
    (``"src,tgt"``) and ``output_train`` (``"src,tgt"``).

    Compared with the original, an empty remainder no longer crashes on
    ``zip(*combined)`` unpacking, files are closed via ``with``, and an
    empty list is written as an empty file so it reads back as zero lines.
    """

    def read_lines(path):
        # Files are expected to end with a newline, so the element after
        # the final "\n" is dropped unconditionally (original convention).
        with open(path, 'r') as f:
            return f.read().split("\n")[:-1]

    def write_lines(path, lines):
        # One newline-terminated line per entry; nothing for an empty list.
        with open(path, 'w') as f:
            f.write("".join(line + "\n" for line in lines))

    unlabeled_dataset = read_lines(args.unlabeled_dataset)

    src_labeled_dataset, tgt_labeled_dataset = args.labeled_dataset.split(",")
    labeled_dataset = [
        read_lines(src_labeled_dataset),
        read_lines(tgt_labeled_dataset),
    ]

    # Read oracle
    oracle = read_lines(args.oracle)
    assert len(oracle) == len(unlabeled_dataset)

    # Read and sort active out_1: [source, hypothesis, target, value, index]
    raw_1 = read_lines(args.active_out_1)
    assert len(raw_1) % 5 == 0
    assert len(raw_1) // 5 == len(oracle)
    active_out_1 = sorted(
        [[
            raw_1[i], raw_1[i + 1], raw_1[i + 2],
            float(raw_1[i + 3].split(' ')[-1]), raw_1[i + 4]
        ] for i in range(0, len(raw_1), 5)],
        key=lambda item: item[3])

    # Read and sort active out_2; it has no hypothesis line, so a
    # placeholder keeps the record layout identical to active_out_1.
    raw_2 = read_lines(args.active_out_2)
    assert len(raw_2) % 4 == 0
    assert len(raw_2) // 4 == len(oracle)
    active_out_2 = sorted(
        [[
            raw_2[i], "H: ", raw_2[i + 1],
            float(raw_2[i + 2].split(' ')[-1]), raw_2[i + 3]
        ] for i in range(0, len(raw_2), 4)],
        key=lambda item: item[3])

    # Combine the two rankings: slot ``idx`` holds the record and the sum of
    # its positions in both sorted lists.
    active_out = [0] * len(active_out_1)
    for rank, item in enumerate(active_out_1):
        idx = int(item[4].split()[-1])
        active_out[idx] = [item[0], item[1], item[2], rank]

    for rank, item in enumerate(active_out_2):
        idx = int(item[4].split()[-1])
        assert active_out[idx][0] == item[0]
        assert active_out[idx][2] == item[2]
        active_out[idx][3] += rank

    active_out = sorted(active_out, key=lambda item: item[3])

    # Change datasets
    indices = np.arange(len(active_out))
    lengths = np.array([
        len(remove_special_tok(remove_bpe(item[0][len("S: "):])).split(' '))
        for item in active_out
    ])
    cumulative = np.cumsum(lengths)
    include_oracle = cumulative <= args.tok_budget

    new_queries_src = [
        active_out[idx][0][len("S: "):].strip()
        for idx in indices[include_oracle]
    ]
    new_queries_tgt = [
        active_out[idx][2][len("T: "):].strip()
        for idx in indices[include_oracle]
    ]
    labeled_dataset[0].extend(new_queries_src)
    labeled_dataset[1].extend(new_queries_tgt)

    combined = [(active_out[idx][0][len("S: "):].strip(),
                 active_out[idx][2][len("T: "):].strip())
                for idx in indices[np.logical_not(include_oracle)]]
    random.shuffle(combined)
    # Comprehensions instead of ``zip(*combined)`` unpacking, which raises
    # ValueError when nothing is left unlabeled.
    unlabeled_dataset = [src for src, _ in combined]
    oracle = [tgt for _, tgt in combined]

    # Store new labeled, unlabeled, oracle dataset
    write_lines(args.output_unlabeled_dataset, unlabeled_dataset)

    output_src_labeled_dataset, output_tgt_labeled_dataset = \
        args.output_labeled_dataset.split(",")
    write_lines(output_src_labeled_dataset, labeled_dataset[0])
    write_lines(output_tgt_labeled_dataset, labeled_dataset[1])

    write_lines(args.output_oracle, oracle)

    output_new_queries_src, output_new_queries_tgt = \
        args.output_new_queries.split(',')
    write_lines(output_new_queries_src, new_queries_src)
    write_lines(output_new_queries_tgt, new_queries_tgt)

    # Training data = labeled data + pseudo-parallel pairs (source with the
    # model hypothesis) taken from the next slice of the token budget.
    output_train_src, output_train_tgt = args.output_train.split(',')
    include_pseudo = cumulative <= (args.tok_budget +
                                    args.back_translation_tok_budget)
    include_pseudo = np.logical_xor(include_pseudo, include_oracle)
    pseudo_indices = indices[include_pseudo]
    train_src = labeled_dataset[0] + [
        active_out[idx][0][len("S: "):].strip() for idx in pseudo_indices
    ]
    train_tgt = labeled_dataset[1] + [
        active_out[idx][1][len("H: "):].strip() for idx in pseudo_indices
    ]

    write_lines(output_train_src, train_src)
    write_lines(output_train_tgt, train_tgt)
Esempio n. 6
0
def supvised_learning_modify(args):
    """Move the top-ranked unlabeled sentences into the labeled dataset.

    Inputs (all paths taken from ``args``):

    * ``unlabeled_dataset`` / ``oracle``: parallel line files.
    * ``labeled_dataset``: ``"src_path,tgt_path"``.
    * ``active_out``: records of 4 lines (source, target, value, index);
      the value is the last space-separated field of its line.

    Records are shuffled (so ties are broken randomly) and then sorted by
    value. Sentences are taken from the top while the cumulative source
    length (BPE and special tokens removed) stays within
    ``args.tok_budget`` and are appended to the labeled set; the rest
    becomes the new unlabeled set / oracle, shuffled together.

    Outputs: ``output_unlabeled_dataset``, ``output_labeled_dataset``
    (``"src,tgt"``) and ``output_oracle``.

    Compared with the original, an empty remainder no longer crashes on
    ``zip(*combined)`` unpacking, every file is closed via ``with``, and an
    empty list is written as an empty file so it reads back as zero lines.
    """

    def read_lines(path):
        # Files are expected to end with a newline, so the element after
        # the final "\n" is dropped unconditionally (original convention).
        with open(path, 'r') as f:
            return f.read().split("\n")[:-1]

    def write_lines(path, lines):
        # One newline-terminated line per entry; nothing for an empty list.
        with open(path, 'w') as f:
            f.write("".join(line + "\n" for line in lines))

    unlabeled_dataset = read_lines(args.unlabeled_dataset)

    src_labeled_dataset, tgt_labeled_dataset = args.labeled_dataset.split(",")
    labeled_dataset = [
        read_lines(src_labeled_dataset),
        read_lines(tgt_labeled_dataset),
    ]

    # Read oracle
    oracle = read_lines(args.oracle)
    assert len(oracle) == len(unlabeled_dataset)

    # Read active out: [source, target, value, index]
    raw = read_lines(args.active_out)
    assert len(raw) % 4 == 0
    assert len(raw) // 4 == len(oracle)
    active_out = [[
        raw[i], raw[i + 1],
        float(raw[i + 2].split(' ')[-1]), raw[i + 3]
    ] for i in range(0, len(raw), 4)]
    # Shuffle before the stable sort so equal values end up in random order.
    random.shuffle(active_out)
    active_out = sorted(active_out, key=lambda item: item[2])

    # Change datasets
    indices = np.arange(len(active_out))
    lengths = np.array([
        len(remove_special_tok(remove_bpe(item[0][len("S: "):])).split(' '))
        for item in active_out
    ])
    within_budget = np.cumsum(lengths) <= args.tok_budget
    include = indices[within_budget]
    not_include = indices[np.logical_not(within_budget)]

    for idx in include:
        labeled_dataset[0].append(active_out[idx][0][len("S: "):].strip())
        labeled_dataset[1].append(active_out[idx][1][len("T: "):].strip())

    combined = [(active_out[idx][0][len("S: "):].strip(),
                 active_out[idx][1][len("T: "):].strip())
                for idx in not_include]
    random.shuffle(combined)
    # Comprehensions instead of ``zip(*combined)`` unpacking, which raises
    # ValueError when nothing is left unlabeled.
    unlabeled_dataset = [src for src, _ in combined]
    oracle = [tgt for _, tgt in combined]

    # Store new labeled, unlabeled, oracle dataset
    write_lines(args.output_unlabeled_dataset, unlabeled_dataset)

    output_src_labeled_dataset, output_tgt_labeled_dataset = \
        args.output_labeled_dataset.split(",")
    write_lines(output_src_labeled_dataset, labeled_dataset[0])
    write_lines(output_tgt_labeled_dataset, labeled_dataset[1])

    write_lines(args.output_oracle, oracle)
Esempio n. 7
0
def query_instances(args,
                    unlabeled_dataset,
                    oracle,
                    active_func="random",
                    labeled_dataset=None):
    """Rank the unlabeled sentences with an active-learning function.

    The ranking is printed to stdout, one record per sentence in ranked
    order. For every function except ``"dden"`` a record is five lines:
    ``S:`` source, ``H:`` model hypothesis, ``T:`` oracle target, ``V:``
    ranking value and ``I:`` input path, reference path and sentence index
    (offset by ``args.previous_num_sents``). ``"dden"`` records have no
    ``H:`` line and print the index without the offset.

    Args:
        args: namespace providing ``no_cuda``, ``checkpoint``, ``input``,
            ``reference``, ``previous_num_sents`` and the optional batching
            overrides ``batch_size`` / ``max_batch_size`` /
            ``tokens_per_batch``. ``args.use_cuda`` is set here.
        unlabeled_dataset: list of source sentences.
        oracle: list of reference targets aligned with ``unlabeled_dataset``.
        active_func: one of ``random``, ``longest``, ``shortest``, ``lc``
            (least confident), ``margin``, ``te`` (token entropy), ``tte``
            (total token entropy), ``dden``.
        labeled_dataset: ``[source_sentences, target_sentences]``; required
            for ``dden``, ignored otherwise.

    Raises:
        AssertionError: unknown ``active_func`` or missing checkpoint.
        ValueError: ``dden`` requested without ``labeled_dataset``.

    Fix over the original: the random/longest/shortest branches looked up
    ``result[idx]`` with the *dataset* index after ``result`` had been
    shuffled or sorted, so the ``H:`` and ``V:`` lines belonged to a
    different sentence. Each record is now printed from its own item.
    """
    # lc stands for least confident
    # te stands for token entropy
    # tte stands for total token entropy
    assert active_func in [
        "random", "longest", "shortest", "lc", "margin", "te", "tte", "dden"
    ]
    if active_func == "dden" and labeled_dataset is None:
        raise ValueError("active_func 'dden' requires labeled_dataset")

    # Preparations before querying instances
    # Reloading network parameters
    args.use_cuda = (not args.no_cuda) and torch.cuda.is_available()
    net, _ = model.get()

    assert os.path.exists(args.checkpoint)
    net, src_vocab, tgt_vocab = load_model(args.checkpoint, net)

    if args.use_cuda:
        net = net.cuda()

    # Initialize inference dataset (Unlabeled dataset)
    infer_dataset = Dataset(unlabeled_dataset, src_vocab)
    if args.batch_size is not None:
        infer_dataset.BATCH_SIZE = args.batch_size
    if args.max_batch_size is not None:
        infer_dataset.max_batch_size = args.max_batch_size
    if args.tokens_per_batch is not None:
        infer_dataset.tokens_per_batch = args.tokens_per_batch

    infer_dataiter = iter(
        infer_dataset.get_iterator(shuffle=True,
                                   group_by_size=True,
                                   include_indices=True))

    def emit(idx, hypothesis, value):
        # One five-line record for the sentence at dataset position ``idx``.
        print("S:", unlabeled_dataset[idx])
        print("H:", hypothesis)
        print("T:", oracle[idx])
        print("V:", value)
        print("I:", args.input, args.reference,
              idx + args.previous_num_sents)

    def by_length(scored):
        # Replace the model score with the BPE-free source token count.
        return [(len(
            remove_special_tok(remove_bpe(
                unlabeled_dataset[item[1]])).split(' ')), item[1], item[2])
                for item in scored]

    # Start ranking unlabeled dataset
    if active_func != "dden":
        # Items are indexed as item[0] = value, item[1] = dataset index,
        # item[2] = hypothesis.
        result = get_scores(args, net, active_func, infer_dataiter,
                            src_vocab, tgt_vocab)

    if active_func == "random":
        random.shuffle(result)
        indices = np.array([item[1] for item in result]).astype('int')
        for item, idx in zip(result, indices):
            emit(idx, item[2], item[0])
    elif active_func == "longest":
        result = sorted(by_length(result), key=lambda item: -item[0])
        indices = np.array([item[1] for item in result]).astype('int')
        for item, idx in zip(result, indices):
            # Negated so that smaller values rank first, like the others.
            emit(idx, item[2], -item[0])
    elif active_func == "shortest":
        result = sorted(by_length(result), key=lambda item: item[0])
        indices = np.array([item[1] for item in result]).astype('int')
        for item, idx in zip(result, indices):
            emit(idx, item[2], item[0])
    elif active_func in ["lc", "margin", "te", "tte"]:
        result = sorted(result, key=lambda item: item[0])
        for item in result:
            emit(item[1], item[2], item[0])
    elif active_func == "dden":
        punc = {
            ".", ",", "?", "!", "'", "<", ">", ":", ";", "(", ")", "{", "}",
            "[", "]", "-", "..", "...", "...."
        }
        lamb1 = 1

        def count_tokens(sentences):
            # Frequency of every non-punctuation token.
            counts = {}
            for sent in sentences:
                for token in sent.split():
                    if token not in punc:
                        counts[token] = counts.get(token, 0) + 1
            return counts

        unlabeled_plain = [
            remove_special_tok(remove_bpe(s)) for s in unlabeled_dataset
        ]
        labeled_src_plain = [
            remove_special_tok(remove_bpe(s)) for s in labeled_dataset[0]
        ]

        # p_u: log-smoothed token frequencies over the unlabeled pool,
        # normalised to sum to one.
        p_u = {
            token: math.log(count + 1)
            for token, count in count_tokens(unlabeled_plain).items()
        }
        total_dden = sum(p_u.values())
        for token in p_u:
            p_u[token] /= total_dden

        # count_l: token frequencies on the labeled source side.
        count_l = count_tokens(labeled_src_plain)

        # First pass: score each sentence; tokens already seen in the
        # labeled data are decayed by exp(-lamb1 * count).
        dden = []
        for sent in unlabeled_plain:
            tokens = sent.split()
            sum_for_sentence = 0
            for token in tokens:
                if token in punc:
                    continue
                weight = p_u[token]
                if token in count_l:
                    weight *= math.exp(-lamb1 * count_l[token])
                sum_for_sentence += weight
            # Normalised by the full token count, punctuation included.
            if tokens:
                sum_for_sentence /= len(tokens)
            dden.append(sum_for_sentence)

        # Second pass, visiting sentences from highest to lowest first-pass
        # score: tokens already used by an earlier sentence contribute 0.
        order = sorted(range(len(unlabeled_dataset)),
                       key=lambda i: dden[i],
                       reverse=True)
        count_batch = {}
        rescored = [None] * len(unlabeled_dataset)
        for i in order:
            tokens = unlabeled_plain[i].split()
            sum_for_sentence = 0
            for token in tokens:
                if token in punc:
                    continue
                weight = p_u[token]
                if token in count_batch:
                    # NOTE(review): the original zeroed the weight and then
                    # multiplied by exp(-lamb2 * count_batch[token]) with
                    # lamb2 = 1, which had no effect. Behaviour is kept;
                    # confirm whether a soft decay was intended instead.
                    weight = 0.0
                if token in count_l:
                    weight *= math.exp(-lamb1 * count_l[token])
                sum_for_sentence += weight
            # Register this sentence's tokens only after it has been scored.
            for token in tokens:
                if token not in punc:
                    count_batch[token] = count_batch.get(token, 0) + 1
            if tokens:
                sum_for_sentence /= len(tokens)
            rescored[i] = sum_for_sentence

        ddens = np.array(rescored)
        indices = np.arange(len(unlabeled_dataset))
        indices = indices[np.argsort(-ddens)]
        for idx in indices:
            # NOTE(review): unlike the other functions there is no "H:" line
            # and the index is not offset by args.previous_num_sents.
            print("S:", unlabeled_dataset[idx])
            print("T:", oracle[idx])
            print("V:", -ddens[idx])
            print("I:", args.input, args.reference, idx)
Esempio n. 8
0
def main():
    """Command-line entry point with two sub-commands.

    ``score``: read the unlabeled input and the reference file (and, when
    ``--input_labeled`` is given, ``<prefix>.de`` / ``<prefix>.en``) and
    call ``query_instances``.

    ``modify``: read the ranking written in score mode (records of four
    lines: source, target, value, index), move the top-ranked sentences
    whose cumulative source length fits ``--tok_budget`` into the labeled
    dataset and write the new labeled / unlabeled / oracle files.

    Fixes over the original: omitting ``--input_labeled`` no longer crashes
    on ``None + '.de'``; a missing sub-command is reported by argparse
    instead of failing with an AttributeError; an empty remainder no longer
    crashes on ``zip(*combined)`` unpacking; every file is closed via
    ``with``; an empty list is written as an empty file.
    """
    parser = argparse.ArgumentParser()
    # ``dest`` lets argparse record the chosen sub-command in ``args.mode``.
    subparsers = parser.add_subparsers(dest='mode',
                                       help='two modes, score or modify')
    subparsers.required = True

    # Add argument for score mode
    parser_score = subparsers.add_parser(
        'score', help='Get active function scores for each unlabeled sentence')
    parser_score.add_argument("-a",
                              "--active_func",
                              type=str,
                              help="Active query function type",
                              required=True)
    parser_score.add_argument("-i",
                              "--input",
                              type=str,
                              help="where to read unlabeled data")
    parser_score.add_argument("-lb",
                              "--input_labeled",
                              type=str,
                              help="where to read labeled data")
    parser_score.add_argument("-ref",
                              "--reference",
                              type=str,
                              help="where to read oracle data")
    parser_score.add_argument(
        '-ckpt',
        '--checkpoint',
        type=str,
        help="Checkpoint path to reload network parameters")
    parser_score.add_argument(
        '-max_len',
        type=int,
        default=250,
        help="Maximum length for generating translations")
    parser_score.add_argument('-no_cuda',
                              action="store_true",
                              help="Use cpu to do translation")
    parser_score.add_argument('--batch_size',
                              type=int,
                              default=None,
                              help="Batch size for generating translations")
    parser_score.add_argument(
        '--max_batch_size',
        type=int,
        default=None,
        help="Maximum batch size if tokens_per_batch is not None")
    parser_score.add_argument(
        '--tokens_per_batch',
        type=int,
        default=None,
        help="Maximum number of tokens in a batch when generating translations"
    )

    # Add argument for modify mode
    parser_modify = subparsers.add_parser(
        'modify',
        help=
        'Change labeled, unlabeled oracle dataset after activation function values is calculated'
    )
    parser_modify.add_argument("-U",
                               "--unlabeled_dataset",
                               type=str,
                               help="where to read unlabelded dataset",
                               required=True)
    parser_modify.add_argument(
        "-L",
        "--labeled_dataset",
        type=str,
        help="where to read labeled dataset, split by comma, e.g. l.de,l.en",
        required=True)
    parser_modify.add_argument("--oracle",
                               type=str,
                               help="where to read oracle dataset",
                               required=True)
    parser_modify.add_argument("-tb",
                               "--tok_budget",
                               type=int,
                               help="Token budget",
                               required=True)
    parser_modify.add_argument("-OU",
                               "--output_unlabeled_dataset",
                               type=str,
                               help="path to store new unlabeled dataset",
                               required=True)
    parser_modify.add_argument("-OL",
                               "--output_labeled_dataset",
                               type=str,
                               help="path to store new labeled dataset",
                               required=True)
    parser_modify.add_argument("-OO",
                               "--output_oracle",
                               type=str,
                               help="path to oracle",
                               required=True)
    parser_modify.add_argument('-AO',
                               '--active_out',
                               type=str,
                               help="path to active function output")
    args = parser.parse_args()

    def read_text(path):
        # Score-mode reader: drop the last element only if it is empty.
        with open(path, 'r') as f:
            lines = f.read().split('\n')
        if lines[-1] == "":
            lines = lines[:-1]
        return lines

    def read_lines(path):
        # Modify-mode reader: files are expected to end with a newline, so
        # the element after the final "\n" is dropped unconditionally.
        with open(path, 'r') as f:
            return f.read().split("\n")[:-1]

    def write_lines(path, lines):
        # One newline-terminated line per entry; nothing for an empty list.
        with open(path, 'w') as f:
            f.write("".join(line + "\n" for line in lines))

    if args.mode == "score":
        text = read_text(args.input)
        ref_text = read_text(args.reference)

        # argparse always defines the attribute, so test its value rather
        # than its presence.
        if args.input_labeled is not None:
            lab_text = [
                read_text(args.input_labeled + '.de'),
                read_text(args.input_labeled + '.en'),
            ]
            query_instances(args, text, ref_text, args.active_func, lab_text)
        else:
            query_instances(args, text, ref_text, args.active_func)

    else:
        # Read labeled and unlabeled datasets
        unlabeled_dataset = read_lines(args.unlabeled_dataset)

        src_labeled_dataset, tgt_labeled_dataset = args.labeled_dataset.split(
            ",")
        labeled_dataset = [
            read_lines(src_labeled_dataset),
            read_lines(tgt_labeled_dataset),
        ]

        # Read oracle
        oracle = read_lines(args.oracle)
        assert len(oracle) == len(unlabeled_dataset)

        # Read active out: [source, target, value, index], sorted by value
        raw = read_lines(args.active_out)
        assert len(raw) % 4 == 0
        assert len(raw) // 4 == len(oracle)
        active_out = sorted(
            [[
                raw[i], raw[i + 1],
                float(raw[i + 2].split(' ')[-1]), raw[i + 3]
            ] for i in range(0, len(raw), 4)],
            key=lambda item: item[2])

        # Change datasets
        indices = np.arange(len(active_out))
        lengths = np.array([
            len(
                remove_special_tok(remove_bpe(
                    item[0][len("S: "):])).split(' ')) for item in active_out
        ])
        within_budget = np.cumsum(lengths) <= args.tok_budget
        include = indices[within_budget]
        not_include = indices[np.logical_not(within_budget)]

        for idx in include:
            labeled_dataset[0].append(active_out[idx][0][len("S: "):])
            labeled_dataset[1].append(active_out[idx][1][len("T: "):])

        combined = [(active_out[idx][0][len("S: "):],
                     active_out[idx][1][len("T: "):])
                    for idx in not_include]
        random.shuffle(combined)
        # Comprehensions instead of ``zip(*combined)`` unpacking, which
        # raises ValueError when nothing is left unlabeled.
        unlabeled_dataset = [src for src, _ in combined]
        oracle = [tgt for _, tgt in combined]

        # Store new labeled, unlabeled, oracle dataset
        write_lines(args.output_unlabeled_dataset, unlabeled_dataset)

        output_src_labeled_dataset, output_tgt_labeled_dataset = \
            args.output_labeled_dataset.split(",")
        write_lines(output_src_labeled_dataset, labeled_dataset[0])
        write_lines(output_tgt_labeled_dataset, labeled_dataset[1])

        write_lines(args.output_oracle, oracle)
Esempio n. 9
0
def query_instances(args, unlabeled_dataset, oracle, active_func="random"):
    """Rank the unlabeled sentences with an active-learning function.

    The ranking is printed to stdout, one five-line record per sentence in
    ranked order: ``S:`` source, ``H:`` model hypothesis, ``T:`` oracle
    target, ``V:`` ranking value and ``I:`` input path, reference path and
    sentence index.

    Args:
        args: namespace providing ``no_cuda``, ``checkpoint``, ``input``,
            ``reference`` and the optional batching overrides
            ``batch_size`` / ``max_batch_size`` / ``tokens_per_batch``.
            ``args.use_cuda`` is set here.
        unlabeled_dataset: list of source sentences.
        oracle: list of reference targets aligned with ``unlabeled_dataset``.
        active_func: one of ``random``, ``longest``, ``shortest``, ``lc``
            (least confident), ``margin``, ``te`` (token entropy), ``tte``
            (total token entropy).

    Raises:
        AssertionError: unknown ``active_func`` or missing checkpoint.

    Fix over the original: the random/longest/shortest branches looked up
    ``result[idx]`` with the *dataset* index after ``result`` had been
    shuffled or sorted, so the ``H:`` and ``V:`` lines belonged to a
    different sentence. Each record is now printed from its own item.
    """
    # lc stands for least confident
    # te stands for token entropy
    # tte stands for total token entropy
    assert active_func in [
        "random", "longest", "shortest", "lc", "margin", "te", "tte"
    ]

    # Preparations before querying instances
    # Reloading network parameters
    args.use_cuda = (not args.no_cuda) and torch.cuda.is_available()
    net, _ = model.get()

    assert os.path.exists(args.checkpoint)
    net, src_vocab, tgt_vocab = load_model(args.checkpoint, net)

    if args.use_cuda:
        net = net.cuda()

    # Initialize inference dataset (Unlabeled dataset)
    infer_dataset = Dataset(unlabeled_dataset, src_vocab)
    if args.batch_size is not None:
        infer_dataset.BATCH_SIZE = args.batch_size
    if args.max_batch_size is not None:
        infer_dataset.max_batch_size = args.max_batch_size
    if args.tokens_per_batch is not None:
        infer_dataset.tokens_per_batch = args.tokens_per_batch

    infer_dataiter = iter(
        infer_dataset.get_iterator(shuffle=True,
                                   group_by_size=True,
                                   include_indices=True))

    def emit(idx, hypothesis, value):
        # One five-line record for the sentence at dataset position ``idx``.
        print("S:", unlabeled_dataset[idx])
        print("H:", hypothesis)
        print("T:", oracle[idx])
        print("V:", value)
        print("I:", args.input, args.reference, idx)

    def by_length(scored):
        # Replace the model score with the BPE-free source token count.
        return [(len(
            remove_special_tok(remove_bpe(
                unlabeled_dataset[item[1]])).split(' ')), item[1], item[2])
                for item in scored]

    # Start ranking unlabeled dataset.
    # Items are indexed as item[0] = value, item[1] = dataset index,
    # item[2] = hypothesis.
    result = get_scores(args, net, active_func, infer_dataiter, src_vocab,
                        tgt_vocab)

    if active_func == "random":
        random.shuffle(result)
        indices = np.array([item[1] for item in result]).astype('int')
        for item, idx in zip(result, indices):
            emit(idx, item[2], item[0])
    elif active_func == "longest":
        result = sorted(by_length(result), key=lambda item: -item[0])
        indices = np.array([item[1] for item in result]).astype('int')
        for item, idx in zip(result, indices):
            # Negated so that smaller values rank first, like the others.
            emit(idx, item[2], -item[0])
    elif active_func == "shortest":
        result = sorted(by_length(result), key=lambda item: item[0])
        indices = np.array([item[1] for item in result]).astype('int')
        for item, idx in zip(result, indices):
            emit(idx, item[2], item[0])
    elif active_func in ["lc", "margin", "te", "tte"]:
        result = sorted(result, key=lambda item: item[0])
        for item in result:
            emit(item[1], item[2], item[0])
Esempio n. 10
0
def _read_record_lines(path):
    """Return the lines of the text file at ``path`` without trailing blanks.

    The file is split on ``'\\n'`` only, so other line-separator characters
    that may occur inside a sentence are left alone. Trailing empty lines
    are dropped, which makes a file with or without a final newline (or one
    holding only a newline) parse the same way.
    """
    with open(path, 'r') as f:
        lines = f.read().split('\n')
    while lines and lines[-1] == '':
        lines.pop()
    return lines


def _format_active_record(record):
    """Turn one parsed ``(S, T, value, I)`` record into its five output lines.

    The ``H:`` line is filled with the same text as the ``T:`` line, taken
    from the record's second field with its 3-character prefix removed.
    """
    source_line, target_line, value, info_line = record
    target_text = target_line[len('T: '):]
    return [
        source_line,
        'H: ' + target_text,
        'T: ' + target_text,
        'V: ' + str(value),
        info_line,
    ]


def _write_active_records(path, records):
    """Write ``records`` to ``path`` as newline-joined five-line groups."""
    out = []
    for record in records:
        out.extend(_format_active_record(record))
    # The context manager closes the handle even if the write fails.
    with open(path, 'w') as f:
        f.write('\n'.join(out) + '\n')


def main():
    """Command-line entry point with two sub-commands: ``get`` and ``translate``.

    ``get`` reads a file of 4-line records (``--active_out``), optionally
    sorts them by their numeric value, and splits them by cumulative token
    count into three groups written next to ``--output``:

    * ``<output>_oracle``: records whose running token total is within
      ``--tok_budget``;
    * ``<output>_pseudo`` (or ``<output>_pseudo_<n>`` when ``--output_num``
      is greater than 1): the records that additionally fit within
      ``--tok_budget + --back_translation_tok_budget``;
    * ``<output>_others``: everything else.

    ``translate`` reads a file of 5-line records (``--input``), loads the
    checkpoint given by ``--ckpt``, calls ``translate`` and writes the
    returned lines to ``--output``.

    Invalid or missing command-line options end the program through
    ``parser.error``; a record file whose line count does not match the
    expected grouping raises ``ValueError``.
    """
    parser = argparse.ArgumentParser()
    # ``dest`` records which sub-command was chosen so that it does not have
    # to be inferred from which attributes happen to exist on ``args``.
    subparsers = parser.add_subparsers(dest='mode',
                                       help='two modes, get or translate')

    parser_get = subparsers.add_parser(
        'get', help='Get texts that needs to be labeled or translated')
    parser_get.add_argument(
        '-AO',
        '--active_out',
        type=str,
        default=None,
        help="Output file generated by active.py score mode")
    parser_get.add_argument('-tb',
                            '--tok_budget',
                            type=int,
                            help="Token budget")
    parser_get.add_argument('-bttb',
                            '--back_translation_tok_budget',
                            type=int,
                            help="Back translation token budget")
    parser_get.add_argument('--sort',
                            action="store_true",
                            help="Whether to sort active out by value")
    parser_get.add_argument('-o', '--output', type=str, help="Output filepath")
    parser_get.add_argument('-on',
                            '--output_num',
                            type=int,
                            default=1,
                            help="Number of files to split the pseudo output into")

    parser_trans = subparsers.add_parser('translate',
                                         help='Translate sentences')
    parser_trans.add_argument('-i', '--input', type=str, help='Input file')
    parser_trans.add_argument('-o', '--output', type=str, help="Output file")
    parser_trans.add_argument('--ckpt', required=True)
    parser_trans.add_argument('--max_len', type=int, default=250)
    parser_trans.add_argument('--gen_a', type=float, default=1.3)
    parser_trans.add_argument('--gen_b', type=int, default=5)
    parser_trans.add_argument('--no_cuda', action='store_true')
    parser_trans.add_argument('--batch_size', type=int, default=None)
    parser_trans.add_argument('--max_batch_size', type=int, default=None)
    parser_trans.add_argument('--tokens_per_batch', type=int, default=None)

    args = parser.parse_args()
    if args.mode is None:
        # Without this check a bare invocation would fall through to code
        # that reads options belonging to a sub-command.
        parser.error("a sub-command is required: get or translate")

    if args.mode == "get":
        if (args.active_out is None or args.output is None
                or args.tok_budget is None):
            parser.error(
                "get requires --active_out, --output and --tok_budget")
        if args.output_num < 1:
            parser.error("--output_num must be at least 1")
        # A missing back-translation budget means no pseudo records.
        bt_budget = args.back_translation_tok_budget or 0

        lines = _read_record_lines(args.active_out)
        if len(lines) % 4 != 0:
            raise ValueError(
                "%s: expected records of 4 lines, got %d lines" %
                (args.active_out, len(lines)))
        active_out = [(lines[idx], lines[idx + 1],
                       float(lines[idx + 2].split(' ')[-1]), lines[idx + 3])
                      for idx in range(0, len(lines), 4)]
        if args.sort:
            active_out.sort(key=lambda item: item[2])

        indices = np.arange(len(active_out))
        lengths = np.array([
            len(
                remove_special_tok(remove_bpe(
                    item[0][len("S: "):])).split(' ')) for item in active_out
        ])

        # Partition the records with boolean masks over the running total.
        running_total = np.cumsum(lengths)
        oracle_mask = running_total <= args.tok_budget
        within_both = running_total <= (args.tok_budget + bt_budget)
        pseudo_mask = np.logical_xor(within_both, oracle_mask)
        others_mask = np.logical_not(np.logical_or(oracle_mask, pseudo_mask))

        include_oracle = indices[oracle_mask]
        include_pseudo = indices[pseudo_mask]
        others = indices[others_mask]

        # Output oracle and others
        _write_active_records(args.output + '_oracle',
                              [active_out[idx] for idx in include_oracle])
        _write_active_records(args.output + '_others',
                              [active_out[idx] for idx in others])

        # Output pseudo
        if args.output_num > 1:
            chunk_size = len(include_pseudo) // args.output_num + 1
            for n in range(args.output_num):
                chunk = include_pseudo[n * chunk_size:(n + 1) * chunk_size]
                _write_active_records(args.output + '_pseudo_' + str(n),
                                      [active_out[idx] for idx in chunk])
        else:
            _write_active_records(args.output + '_pseudo',
                                  [active_out[idx] for idx in include_pseudo])
    elif args.mode == 'translate':
        if args.input is None or args.output is None:
            parser.error("translate requires --input and --output")
        if args.max_len <= 10:
            parser.error("--max_len must be greater than 10")
        args.use_cuda = (not args.no_cuda) and torch.cuda.is_available()

        # Parse the input before loading the model so that a bad file is
        # reported immediately, and let any I/O or parsing error propagate
        # instead of continuing with an undefined ``active_out``.
        lines = _read_record_lines(args.input)
        if len(lines) % 5 != 0:
            raise ValueError(
                "%s: expected records of 5 lines, got %d lines" %
                (args.input, len(lines)))
        active_out = [(lines[idx], lines[idx + 1], lines[idx + 2],
                       float(lines[idx + 3].split(' ')[-1]),
                       lines[idx + 4]) for idx in range(0, len(lines), 5)]
        args.text = [a[0][len('S: '):].strip() for a in active_out]
        args.ref_text = [a[2][len('T: '):].strip() for a in active_out]

        net, _ = model.get()
        net, src_vocab, tgt_vocab = load_model(args.ckpt, net)

        if args.use_cuda:
            net = net.cuda()

        out = translate(args, net, src_vocab, tgt_vocab, active_out)

        with open(args.output, 'w') as f:
            f.write('\n'.join(out) + '\n')
Esempio n. 11
0
def query_instances(args,
                    unlabeled_dataset,
                    active_func="random",
                    tok_budget=None):
    """Rank ``unlabeled_dataset`` with ``active_func`` and keep a budgeted prefix.

    Supported ``active_func`` values:

    * ``"random"``: shuffled order;
    * ``"longest"`` / ``"shortest"``: ordered by token count;
    * ``"lc"`` (least confident), ``"margin"``, ``"te"`` (token entropy),
      ``"tte"`` (total token entropy): ordered by ascending value of the
      first field of each item returned by ``get_scores``, which needs the
      model stored at ``args.checkpoint``.

    Token counts are taken after ``remove_bpe`` and ``remove_special_tok``.
    ``tok_budget`` must be an ``int`` and is capped at the total token count
    of the dataset.

    Returns a pair ``(sentences, include)``: ``include`` holds the ranked
    indices whose cumulative token count does not exceed the budget, and
    ``sentences`` the matching entries of ``unlabeled_dataset``.
    """
    model_based_funcs = ["lc", "margin", "te", "tte"]
    assert active_func in ["random", "longest", "shortest"] + model_based_funcs
    assert isinstance(tok_budget, int)
    uses_model = active_func in model_based_funcs

    # lengths represents number of tokens, so BPE should be removed
    token_counts = []
    for sentence in unlabeled_dataset:
        plain_text = remove_special_tok(remove_bpe(sentence))
        token_counts.append(len(plain_text.split()))
    lengths = np.array(token_counts)

    # Never ask for more tokens than the dataset contains.
    total_num = sum(lengths)
    if total_num < tok_budget:
        tok_budget = total_num

    # Preparations before querying instances
    if uses_model:
        # Reloading network parameters
        args.use_cuda = (args.no_cuda == False) and torch.cuda.is_available()
        net, _ = model.get()

        assert os.path.exists(args.checkpoint)
        net, src_vocab, tgt_vocab = load_model(args.checkpoint, net)

        if args.use_cuda:
            net = net.cuda()

        # Initialize inference dataset (Unlabeled dataset)
        infer_dataset = Dataset(unlabeled_dataset, src_vocab)
        # Command-line overrides, applied only when a value was supplied.
        overrides = (
            ("batch_size", "BATCH_SIZE"),
            ("max_batch_size", "max_batch_size"),
            ("tokens_per_batch", "tokens_per_batch"),
        )
        for arg_name, dataset_attr in overrides:
            value = getattr(args, arg_name)
            if value is not None:
                setattr(infer_dataset, dataset_attr, value)

        batches = infer_dataset.get_iterator(shuffle=True,
                                             group_by_size=True,
                                             include_indices=True)
        infer_dataiter = iter(batches)

    # Start ranking unlabeled dataset
    indices = np.arange(len(unlabeled_dataset))
    if uses_model:
        scored = list(
            get_scores(args, net, active_func, infer_dataiter, src_vocab,
                       tgt_vocab))
        scored.sort(key=lambda item: item[0])
        indices = np.array([item[1] for item in scored]).astype('int')
    elif active_func == "random":
        np.random.shuffle(indices)
    elif active_func == "longest":
        by_length_desc = np.argsort(-lengths[indices])
        indices = indices[by_length_desc]
    elif active_func == "shortest":
        by_length_asc = np.argsort(lengths[indices])
        indices = indices[by_length_asc]

    # Keep the ranked prefix whose running token total fits the budget.
    within_budget = np.cumsum(lengths[indices]) <= tok_budget
    include = indices[within_budget]
    selected = [unlabeled_dataset[idx] for idx in include]
    return selected, include