Example #1
File: active.py Project: yuekai146/NMT
def query_instances(args, unlabeled_dataset, oracle, active_func="random"):
    # lc stands for least confident
    # te stands for token entropy
    # tte stands for total token entropy
    assert active_func in [
        "random", "longest", "shortest", "lc", "margin", "te", "tte"
    ]

    # lengths holds the token count of each sentence, so BPE markers must be removed first
    lengths = np.array([
        len(remove_special_tok(remove_bpe(s)).split())
        for s in unlabeled_dataset
    ])

    # Preparations before querying instances
    # Reloading network parameters
    args.use_cuda = (not args.no_cuda) and torch.cuda.is_available()
    net, _ = model.get()

    assert os.path.exists(args.checkpoint)
    net, src_vocab, tgt_vocab = load_model(args.checkpoint, net)

    if args.use_cuda:
        net = net.cuda()

    # Initialize inference dataset (Unlabeled dataset)
    infer_dataset = Dataset(unlabeled_dataset, src_vocab)
    if args.batch_size is not None:
        infer_dataset.BATCH_SIZE = args.batch_size
    if args.max_batch_size is not None:
        infer_dataset.max_batch_size = args.max_batch_size
    if args.tokens_per_batch is not None:
        infer_dataset.tokens_per_batch = args.tokens_per_batch

    infer_dataiter = iter(
        infer_dataset.get_iterator(shuffle=True,
                                   group_by_size=True,
                                   include_indices=True))

    # Start ranking unlabeled dataset
    indices = np.arange(len(unlabeled_dataset))
    if active_func == "random":
        result = get_scores(args, net, active_func, infer_dataiter, src_vocab,
                            tgt_vocab)
        random.shuffle(result)
        indices = [item[1] for item in result]
        indices = np.array(indices).astype('int')
        for pos, idx in enumerate(indices):
            print("S:", unlabeled_dataset[idx])
            # result is shuffled, so pair by position, not by dataset index
            print("H:", result[pos][2])
            print("T:", oracle[idx])
            print("V:", result[pos][0])
            print("I:", args.input, args.reference, idx)
    elif active_func == "longest":
        result = get_scores(args, net, active_func, infer_dataiter, src_vocab,
                            tgt_vocab)
        result = [(len(
            remove_special_tok(remove_bpe(
                unlabeled_dataset[item[1]])).split(' ')), item[1], item[2])
                  for item in result]
        result = sorted(result, key=lambda item: -item[0])
        indices = [item[1] for item in result]
        indices = np.array(indices).astype('int')
        for pos, idx in enumerate(indices):
            print("S:", unlabeled_dataset[idx])
            # result is sorted, so pair by position, not by dataset index
            print("H:", result[pos][2])
            print("T:", oracle[idx])
            print("V:", -result[pos][0])
            print("I:", args.input, args.reference, idx)
    elif active_func == "shortest":
        result = get_scores(args, net, active_func, infer_dataiter, src_vocab,
                            tgt_vocab)
        result = [(len(
            remove_special_tok(remove_bpe(
                unlabeled_dataset[item[1]])).split(' ')), item[1], item[2])
                  for item in result]
        result = sorted(result, key=lambda item: item[0])
        indices = [item[1] for item in result]
        indices = np.array(indices).astype('int')
        for pos, idx in enumerate(indices):
            print("S:", unlabeled_dataset[idx])
            # result is sorted, so pair by position, not by dataset index
            print("H:", result[pos][2])
            print("T:", oracle[idx])
            print("V:", result[pos][0])
            print("I:", args.input, args.reference, idx)
        # redundant: indices are already sorted by length at this point
        indices = indices[np.argsort(lengths[indices])]
    elif active_func in ["lc", "margin", "te", "tte"]:
        result = get_scores(args, net, active_func, infer_dataiter, src_vocab,
                            tgt_vocab)
        result = sorted(result, key=lambda item: item[0])
        indices = [item[1] for item in result]
        indices = np.array(indices).astype('int')

        for idx in range(len(result)):
            print("S:", unlabeled_dataset[result[idx][1]])
            print("H:", result[idx][2])
            print("T:", oracle[result[idx][1]])
            print("V:", result[idx][0])
            print("I:", args.input, args.reference, result[idx][1])
Example #2
File: active.py Project: yuekai146/NMT
def query_instances(args,
                    unlabeled_dataset,
                    oracle,
                    active_func="random",
                    labeled_dataset=None):
    # lc stands for least confident
    # te stands for token entropy
    # tte stands for total token entropy
    assert active_func in [
        "random", "longest", "shortest", "lc", "margin", "te", "tte", "dden"
    ]

    # lengths holds the token count of each sentence, so BPE markers must be removed first
    lengths = np.array([
        len(remove_special_tok(remove_bpe(s)).split())
        for s in unlabeled_dataset
    ])

    # Preparations before querying instances
    # Reloading network parameters
    args.use_cuda = (not args.no_cuda) and torch.cuda.is_available()
    net, _ = model.get()

    assert os.path.exists(args.checkpoint)
    net, src_vocab, tgt_vocab = load_model(args.checkpoint, net)

    if args.use_cuda:
        net = net.cuda()

    # Initialize inference dataset (Unlabeled dataset)
    infer_dataset = Dataset(unlabeled_dataset, src_vocab)
    if args.batch_size is not None:
        infer_dataset.BATCH_SIZE = args.batch_size
    if args.max_batch_size is not None:
        infer_dataset.max_batch_size = args.max_batch_size
    if args.tokens_per_batch is not None:
        infer_dataset.tokens_per_batch = args.tokens_per_batch

    infer_dataiter = iter(
        infer_dataset.get_iterator(shuffle=True,
                                   group_by_size=True,
                                   include_indices=True))

    # Start ranking unlabeled dataset
    indices = np.arange(len(unlabeled_dataset))
    if active_func == "random":
        result = get_scores(args, net, active_func, infer_dataiter, src_vocab,
                            tgt_vocab)
        random.shuffle(result)
        indices = [item[1] for item in result]
        indices = np.array(indices).astype('int')
        for pos, idx in enumerate(indices):
            print("S:", unlabeled_dataset[idx])
            # result is shuffled, so pair by position, not by dataset index
            print("H:", result[pos][2])
            print("T:", oracle[idx])
            print("V:", result[pos][0])
            print("I:", args.input, args.reference,
                  idx + args.previous_num_sents)
    elif active_func == "longest":
        result = get_scores(args, net, active_func, infer_dataiter, src_vocab,
                            tgt_vocab)
        result = [(len(
            remove_special_tok(remove_bpe(
                unlabeled_dataset[item[1]])).split(' ')), item[1], item[2])
                  for item in result]
        result = sorted(result, key=lambda item: -item[0])
        indices = [item[1] for item in result]
        indices = np.array(indices).astype('int')
        for pos, idx in enumerate(indices):
            print("S:", unlabeled_dataset[idx])
            # result is sorted, so pair by position, not by dataset index
            print("H:", result[pos][2])
            print("T:", oracle[idx])
            print("V:", -result[pos][0])
            print("I:", args.input, args.reference,
                  idx + args.previous_num_sents)
    elif active_func == "shortest":
        result = get_scores(args, net, active_func, infer_dataiter, src_vocab,
                            tgt_vocab)
        result = [(len(
            remove_special_tok(remove_bpe(
                unlabeled_dataset[item[1]])).split(' ')), item[1], item[2])
                  for item in result]
        result = sorted(result, key=lambda item: item[0])
        indices = [item[1] for item in result]
        indices = np.array(indices).astype('int')
        for pos, idx in enumerate(indices):
            print("S:", unlabeled_dataset[idx])
            # result is sorted, so pair by position, not by dataset index
            print("H:", result[pos][2])
            print("T:", oracle[idx])
            print("V:", result[pos][0])
            print("I:", args.input, args.reference,
                  idx + args.previous_num_sents)
        # redundant: indices are already sorted by length at this point
        indices = indices[np.argsort(lengths[indices])]
    elif active_func in ["lc", "margin", "te", "tte"]:
        result = get_scores(args, net, active_func, infer_dataiter, src_vocab,
                            tgt_vocab)
        result = sorted(result, key=lambda item: item[0])
        indices = [item[1] for item in result]
        indices = np.array(indices).astype('int')

        for idx in range(len(result)):
            print("S:", unlabeled_dataset[result[idx][1]])
            print("H:", result[idx][2])
            print("T:", oracle[result[idx][1]])
            print("V:", result[idx][0])
            print("I:", args.input, args.reference,
                  result[idx][1] + args.previous_num_sents)
    elif active_func == "dden":
        punc = [
            ".", ",", "?", "!", "'", "<", ">", ":", ";", "(", ")", "{", "}",
            "[", "]", "-", "..", "...", "...."
        ]
        lamb1 = 1
        lamb2 = 1
        p_u = {}
        unlabeled_dataset_without_bpe = []
        labeled_dataset_without_bpe = [[], []]
        for s in unlabeled_dataset:
            unlabeled_dataset_without_bpe.append(
                remove_special_tok(remove_bpe(s)))
        for s in labeled_dataset[0]:
            labeled_dataset_without_bpe[0].append(
                remove_special_tok(remove_bpe(s)))
        for s in labeled_dataset[1]:
            labeled_dataset_without_bpe[1].append(
                remove_special_tok(remove_bpe(s)))
        for s in unlabeled_dataset_without_bpe:
            sentence = s.split()
            for token in sentence:
                if token not in punc:
                    p_u[token] = p_u.get(token, 0) + 1
        total_dden = 0
        for token in p_u:
            p_u[token] = math.log(p_u[token] + 1)
            total_dden += p_u[token]
        for token in p_u:
            p_u[token] /= total_dden
        count_l = {}
        for s in labeled_dataset_without_bpe[0]:
            sentence = s.split()
            for token in sentence:
                if token not in punc:
                    count_l[token] = count_l.get(token, 0) + 1
        dden = []
        for s in unlabeled_dataset_without_bpe:
            sentence = s.split()
            len_for_sentence = 0
            sum_for_sentence = 0
            for token in sentence:
                if token not in punc:
                    # exp(0) == 1, so tokens unseen in the labeled set are undecayed
                    sum_for_sentence += p_u[token] * math.exp(
                        -lamb1 * count_l.get(token, 0))
                len_for_sentence += 1
            if len_for_sentence != 0:
                sum_for_sentence /= len_for_sentence
            dden.append(sum_for_sentence)
        unlabeled_with_index = []
        for i in range(len(unlabeled_dataset)):
            unlabeled_with_index.append((dden[i], i))
        unlabeled_with_index.sort(key=lambda x: x[0], reverse=True)
        count_batch = {}
        dden_new = []
        for _, i in unlabeled_with_index:
            sentence = unlabeled_dataset_without_bpe[i].split()
            len_for_sentence = 0
            sum_for_sentence = 0
            for token in sentence:
                if token not in punc:
                    p_tmp = p_u[token]
                    if token in count_batch:
                        # decay tokens already covered earlier in this batch
                        p_tmp *= math.exp(-lamb2 * count_batch[token])
                    if token in count_l:
                        p_tmp *= math.exp(-lamb1 * count_l[token])
                    sum_for_sentence += p_tmp
                len_for_sentence += 1
            for token in sentence:
                if token not in punc:
                    count_batch[token] = count_batch.get(token, 0) + 1
            if len_for_sentence != 0:
                sum_for_sentence /= len_for_sentence
            dden_new.append((sum_for_sentence, i))
        dden_new.sort(key=lambda x: x[1])  # restore original dataset order
        ddens = np.array([dden_num for dden_num, _ in dden_new])
        indices = indices[np.argsort(-ddens)]
        for idx in indices:
            print("S:", unlabeled_dataset[idx])
            print("T:", oracle[idx])
            print("V:", -ddens[idx])
            print("I:", args.input, args.reference, idx)
Example #3
def query_instances(args,
                    unlabeled_dataset,
                    active_func="random",
                    tok_budget=None):
    # lc stands for least confident
    # te stands for token entropy
    # tte stands for total token entropy
    assert active_func in [
        "random", "longest", "shortest", "lc", "margin", "te", "tte"
    ]
    assert isinstance(tok_budget, int)

    # lengths holds the token count of each sentence, so BPE markers must be removed first
    lengths = np.array([
        len(remove_special_tok(remove_bpe(s)).split())
        for s in unlabeled_dataset
    ])
    total_num = sum(lengths)
    if total_num < tok_budget:
        tok_budget = total_num

    # Preparations before querying instances
    if active_func in ["lc", "margin", "te", "tte"]:
        # Reloading network parameters
        args.use_cuda = (not args.no_cuda) and torch.cuda.is_available()
        net, _ = model.get()

        assert os.path.exists(args.checkpoint)
        net, src_vocab, tgt_vocab = load_model(args.checkpoint, net)

        if args.use_cuda:
            net = net.cuda()

        # Initialize inference dataset (Unlabeled dataset)
        infer_dataset = Dataset(unlabeled_dataset, src_vocab)
        if args.batch_size is not None:
            infer_dataset.BATCH_SIZE = args.batch_size
        if args.max_batch_size is not None:
            infer_dataset.max_batch_size = args.max_batch_size
        if args.tokens_per_batch is not None:
            infer_dataset.tokens_per_batch = args.tokens_per_batch

        infer_dataiter = iter(
            infer_dataset.get_iterator(shuffle=True,
                                       group_by_size=True,
                                       include_indices=True))

    # Start ranking unlabeled dataset
    indices = np.arange(len(unlabeled_dataset))
    if active_func == "random":
        np.random.shuffle(indices)
    elif active_func == "longest":
        indices = indices[np.argsort(-lengths[indices])]
    elif active_func == "shortest":
        indices = indices[np.argsort(lengths[indices])]
    elif active_func in ["lc", "margin", "te", "tte"]:
        result = get_scores(args, net, active_func, infer_dataiter, src_vocab,
                            tgt_vocab)
        result = sorted(result, key=lambda item: item[0])
        indices = [item[1] for item in result]
        indices = np.array(indices).astype('int')

    include = np.cumsum(lengths[indices]) <= tok_budget
    include = indices[include]
    return [unlabeled_dataset[idx] for idx in include], include
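The closing three lines turn a ranking into a selection: sentences are taken in ranked order for as long as their cumulative token count stays within tok_budget. Since lengths are positive, the cumulative sum is nondecreasing, so the boolean mask is a prefix of the ranking. A standalone illustration with toy numbers:

import numpy as np

lengths = np.array([12, 5, 30, 8])   # tokens per sentence, in pool order
indices = np.array([1, 3, 0, 2])     # ranking, most informative first
tok_budget = 20

include = np.cumsum(lengths[indices]) <= tok_budget  # [True, True, False, False]
print(indices[include])  # [1 3]: 5 + 8 = 13 tokens fit, adding 12 more would not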