Example #1
def load_corpus_data(data_path,
                     language_name,
                     start_token,
                     end_token,
                     mask_token,
                     vocab_path,
                     rebuild_vocab,
                     unk="UNK",
                     threshold=0):
    # Build a fresh vocabulary from the corpus when requested;
    # otherwise the saved vocabulary is loaded from disk further down.
    if rebuild_vocab:
        v = Vocab(language_name,
                  start_token,
                  end_token,
                  mask_token,
                  threshold=threshold)

    corpus = []

    with open(data_path) as f:
        data = f.read().strip().split("\n")

        for line in data:
            line = line.strip()
            line = " ".join([start_token, line, end_token])

            if rebuild_vocab:
                v.add_sentence(line)

            corpus.append(line)

    data2index = []

    if rebuild_vocab:
        # Register the UNK token and save the vocabulary for later runs.
        v.add_unk(unk)
        v.save(vocab_path)
    else:
        v = Vocab.load(vocab_path)

    # Convert every sentence into a list of token indices.
    for line in corpus:
        data2index.append([v.get_index(token) for token in line.split()])

    return data2index, v
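
For context, a call to load_corpus_data might look like the following sketch. The token strings and file paths are placeholders rather than values from the original project, and Vocab is assumed to be the project's own vocabulary class.

# Hypothetical usage; all paths and special-token strings below are placeholders.
data2index, vocab = load_corpus_data(
    data_path="train.en",        # one sentence per line
    language_name="en",
    start_token="<s>",
    end_token="</s>",
    mask_token="<mask>",
    vocab_path="vocab.en",
    rebuild_vocab=True,          # build and save the vocabulary on the first run
    threshold=0)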
Example #2
parser.add_argument("--val_tgt_path", required=True)
parser.add_argument("--src_vocab_path", required=True)
parser.add_argument("--tgt_vocab_path", required=True)
parser.add_argument("--picture_path", required=True)

args, unknown = parser.parse_known_args()

device = args.device

s2s = load_model(args.load, device=device)

if not isinstance(s2s, S2S_attention.S2S):
    raise Exception("The model doesn't have an attention mechanism")

src_vocab = Vocab.load(args.src_vocab_path)
tgt_vocab = Vocab.load(args.tgt_vocab_path)

with open(args.val_tgt_path) as f:
    data = f.read().split("\n")
    tgt_data = [normalizeString(line, to_ascii=False) for line in data]

with torch.no_grad():
    s2s.eval()

    with open(args.val_src_path) as f:
        data = f.read().split("\n")
Example #3
def evaluation(local_rank, args):
    rank = args.nr * args.gpus + local_rank
    dist.init_process_group(backend="nccl",
                            init_method=args.init_method,
                            rank=rank,
                            world_size=args.world_size)

    device = torch.device("cuda", local_rank)
    torch.cuda.set_device(device)

    # List[str]
    src_data = read_data(args.test_src_path)

    tgt_prefix_data = None
    if args.tgt_prefix_file_path is not None:
        tgt_prefix_data = read_data(args.tgt_prefix_file_path)

    # +2 accounts for the start and end tokens wrapped around every sentence.
    max_src_len = max(len(line.split()) for line in src_data) + 2
    max_tgt_len = max_src_len * 3
    logging.info("max src sentence length: {}".format(max_src_len))

    src_vocab = Vocab.load(args.src_vocab_path)
    tgt_vocab = Vocab.load(args.tgt_vocab_path)

    # The mask token doubles as the padding token; source and target
    # vocabularies must map it to the same index (asserted below).
    padding_value = src_vocab.get_index(src_vocab.mask_token)

    assert padding_value == tgt_vocab.get_index(tgt_vocab.mask_token)

    src_data = convert_data_to_index(src_data, src_vocab)

    dataset = DataPartition(src_data, args.world_size, tgt_prefix_data,
                            args.work_load_per_process).dataset(rank)

    logging.info("dataset size: {}, rank: {}".format(len(dataset), rank))

    data_loader = DataLoader(
        dataset=dataset,
        batch_size=(args.batch_size if args.batch_size else 1),
        shuffle=False,
        pin_memory=True,
        drop_last=False,
        collate_fn=lambda batch: collate_eval(
            batch,
            padding_value,
            batch_first=(True if args.transformer else False)))

    if not os.path.isdir(args.translation_output_dir):
        os.makedirs(args.translation_output_dir)

    if args.beam_size:
        logging.info("Beam size: {}".format(args.beam_size))

    if args.is_prefix:
        # Treat model_load as a prefix and glob every matching checkpoint.
        args.model_load = args.model_load + "*"

    for model_path in glob.glob(args.model_load):
        logging.info("Load model from: {}, rank: {}".format(model_path, rank))

        if args.transformer:
            s2s = load_transformer(model_path,
                                   len(src_vocab),
                                   max_src_len,
                                   len(tgt_vocab),
                                   max_tgt_len,
                                   padding_value,
                                   training=False,
                                   share_dec_pro_emb=args.share_dec_pro_emb,
                                   device=device)

        else:
            s2s = load_model(model_path, device=device)

        s2s.eval()

        if args.record_time:
            import time
            start_time = time.time()

        pred_data = []

        for data, tgt_prefix_batch in data_loader:
            if args.beam_size:
                pred_data.append(
                    beam_search_decoding(s2s, data.to(device,
                                                      non_blocking=True),
                                         tgt_vocab, args.beam_size, device))
            else:
                pred_data.extend(
                    greedy_decoding(s2s, data.to(device, non_blocking=True),
                                    tgt_vocab, device, tgt_prefix_batch))

        if args.record_time:
            end_time = time.time()
            logging.info("Time spend: {} seconds, rank: {}".format(
                end_time - start_time, rank))

        _, model_name = os.path.split(model_path)

        if args.beam_size:
            translation_file_name_prefix = "{}_beam_size_{}".format(
                model_name, args.beam_size)
        else:
            translation_file_name_prefix = "{}_greedy".format(model_name)

        p = os.path.join(
            args.translation_output_dir,
            "{}_translations.rank{}".format(translation_file_name_prefix,
                                            rank))

        write_data(pred_data, p)

        if args.need_tok:
            # replace '@@ ' with ''
            p_tok = os.path.join(
                args.translation_output_dir,
                "{}_translations_tok.rank{}".format(
                    translation_file_name_prefix, rank))

            tok_command = "sed -r 's/(@@ )|(@@ ?$)//g' {} > {}".format(
                p, p_tok)

            call(tok_command, shell=True)
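
The evaluation function above takes the local rank as its first argument, which is the shape torch.multiprocessing.spawn expects, so a launcher along the following lines would presumably drive it. This launcher is a sketch and not part of the original snippet; it assumes args carries the gpus, nr, world_size and init_method fields that the function reads.

# Hypothetical launcher for evaluation(); not part of the original code.
import torch.multiprocessing as mp

def launch_evaluation(args):
    # Spawn one process per local GPU; each process receives its local_rank
    # as the first positional argument, matching evaluation's signature.
    mp.spawn(evaluation, args=(args,), nprocs=args.gpus, join=True)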
Example #4
def main():
    parser = argparse.ArgumentParser()

    parser.add_argument("--device", required=True)
    parser.add_argument("--load", required=True)
    parser.add_argument("--src_vocab_path", required=True)
    parser.add_argument("--tgt_vocab_path", required=True)
    parser.add_argument("--test_src_path", required=True)
    parser.add_argument("--test_tgt_path", required=True)
    parser.add_argument("--lang_vec_path", required=True)
    parser.add_argument("--translation_output", required=True)
    parser.add_argument("--transformer", action="store_true")

    args, unknown = parser.parse_known_args()

    src_vocab = Vocab.load(args.src_vocab_path)
    tgt_vocab = Vocab.load(args.tgt_vocab_path)

    src_data = read_data(args.test_src_path)

    lang_vec = load_lang_vec(args.lang_vec_path)  # lang_vec: dict

    device = args.device

    print("load from {}".format(args.load))

    if args.transformer:
        # Length in tokens, matching Example #3; +2 accounts for the start
        # and end tokens.
        max_src_length = max(len(line.split()) for line in src_data) + 2
        max_tgt_length = max_src_length * 3
        padding_value = src_vocab.get_index(src_vocab.mask_token)
        assert padding_value == tgt_vocab.get_index(tgt_vocab.mask_token)

        s2s = load_transformer(args.load,
                               len(src_vocab),
                               max_src_length,
                               len(tgt_vocab),
                               max_tgt_length,
                               padding_value,
                               device=device)

    else:
        s2s = load_model(args.load, device=device)

    s2s.eval()

    pred_data = []
    for i, line in enumerate(src_data):
        pred_data.append(
            translate(line, i, s2s, src_vocab, tgt_vocab, lang_vec, device))

    write_data(pred_data, args.translation_output)

    pred_data_tok_path = args.translation_output + ".tok"

    tok_command = "sed -r 's/(@@ )|(@@ ?$)//g' {} > {}".format(
        args.translation_output, pred_data_tok_path)
    call(tok_command, shell=True)

    bleu_calculation_command = "perl /data/rrjin/NMT/scripts/multi-bleu.perl {} < {}".format(
        args.test_tgt_path, pred_data_tok_path)
    call(bleu_calculation_command, shell=True)
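
The sed command used in the last two examples only strips the BPE continuation markers ("@@ " and a trailing "@@") so that subword pieces are merged back into whole words before BLEU scoring. As an illustration of what it does, an equivalent in plain Python could look like this sketch:

# Pure-Python equivalent of the sed-based detokenization (illustration only).
import re

def remove_bpe(path_in, path_out):
    with open(path_in) as fin, open(path_out, "w") as fout:
        for line in fin:
            # Remove "@@ " inside the line and a bare "@@" at the end of it.
            fout.write(re.sub(r"(@@ )|(@@ ?$)", "", line))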