def train(local_rank, args):

    setup_seed(args.seed)

    # global rank = node index * GPUs per node + this process's local GPU index
    rank = args.nr * args.gpus + local_rank

    saved_model_dir, _ = os.path.split(args.checkpoint)

    if not os.path.isdir(saved_model_dir):
        os.makedirs(saved_model_dir)

    src_data, src_vocab = load_corpus_data(args.src_path, args.src_language,
                                           args.start_token, args.end_token,
                                           args.mask_token,
                                           args.src_vocab_path,
                                           args.rebuild_vocab, args.unk,
                                           args.threshold)

    tgt_data, tgt_vocab = load_corpus_data(args.tgt_path, args.tgt_language,
                                           args.start_token, args.end_token,
                                           args.mask_token,
                                           args.tgt_vocab_path,
                                           args.rebuild_vocab, args.unk,
                                           args.threshold)

    args.src_vocab_size = len(src_vocab)
    args.tgt_vocab_size = len(tgt_vocab)

    logging.info("Source language vocab size: {}".format(len(src_vocab)))
    logging.info("Target language vocab size: {}".format(len(tgt_vocab)))

    assert len(src_data) == len(tgt_data)

    if args.sort_sentence_by_length:
        src_data, tgt_data = sort_src_sentence_by_length(
            list(zip(src_data, tgt_data)))

    logging.info("Transformer")

    max_src_len = max(len(line) for line in src_data)
    max_tgt_len = max(len(line) for line in tgt_data)

    args.max_src_len = max_src_len
    args.max_tgt_len = max_tgt_len

    padding_value = src_vocab.get_index(args.mask_token)

    assert padding_value == tgt_vocab.get_index(args.mask_token)
    args.padding_value = padding_value

    logging.info("Multi GPU training")

    dist.init_process_group(backend="nccl",
                            init_method=args.init_method,
                            rank=rank,
                            world_size=args.world_size)

    device = torch.device("cuda", local_rank)

    torch.cuda.set_device(device)

    if args.load:

        logging.info("Load existing model from {}".format(args.load))
        s2s, optimizer_state_dict = load_transformer(args,
                                                     training=True,
                                                     device=device)
        s2s = nn.parallel.DistributedDataParallel(s2s, device_ids=[local_rank])
        optimizer = get_optimizer(s2s.parameters(), args)
        optimizer.load_state_dict(optimizer_state_dict)

    else:
        logging.info("New model")
        s2s = build_transformer(args, device)
        s2s.init_parameters()
        s2s = nn.parallel.DistributedDataParallel(s2s, device_ids=[local_rank])
        optimizer = get_optimizer(s2s.parameters(), args)

    s2s.train()

    if args.label_smoothing:
        logging.info("Label Smoothing!")
        criterion = LabelSmoothingLoss(args.label_smoothing, padding_value)
    else:
        criterion = nn.CrossEntropyLoss(ignore_index=padding_value)

    train_data = NMTDataset(src_data, tgt_data)

    # release cpu memory
    del src_data
    del tgt_data

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_data, num_replicas=args.world_size, rank=rank)
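    # shuffle stays False because the DistributedSampler already shuffles and
    # partitions the data across ranks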
    train_loader = DataLoader(train_data,
                              args.batch_size,
                              shuffle=False,
                              sampler=train_sampler,
                              drop_last=True,
                              pin_memory=True,
                              collate_fn=lambda batch: collate(
                                  batch, padding_value, batch_first=True))

    for i in range(args.start_epoch, args.end_epoch):

        # re-seed the sampler so every epoch gets a different shuffle
        train_sampler.set_epoch(i)

        epoch_loss = 0.0

        start_time = time.time()

        steps = 0

        for j, (input_batch, target_batch) in enumerate(train_loader):

            # with gradient accumulation, update the weights only every
            # `update_freq` batches ((j + 1) % 1 == 0 covers update_freq == 1)
            need_update = (j + 1) % args.update_freq == 0

            input_batch = input_batch.to(device, non_blocking=True)
            target_batch = target_batch.to(device, non_blocking=True)

            # teacher forcing: the decoder input is the target shifted right
            # (drop the last token); the loss target drops the first token
            output = s2s(input_batch, target_batch[:, :-1])
            del input_batch
            output = output.view(-1, output.size(-1))
            target_batch = target_batch[:, 1:].contiguous().view(-1)

            batch_loss = criterion(output, target_batch)
            del target_batch
            del output

            # Synchronize all processes before the backward pass. DDP's gradient
            # all-reduce happens during backward() and overlaps with the backward
            # computation, so when backward() returns, param.grad already holds
            # the synchronized gradients.
            dist.barrier()
            batch_loss.backward()

            if need_update:
                optimizer.step()
                optimizer.zero_grad()

            batch_loss = batch_loss.item()

            epoch_loss += batch_loss

            steps += 1

        # flush gradients left over from an incomplete accumulation window
        if steps % args.update_freq != 0:
            optimizer.step()
            optimizer.zero_grad()

        epoch_loss /= steps

        epoch_ppl = math.exp(epoch_loss)

        logging.info(
            "Epoch: {}, time: {} seconds, loss: {}, perplexity: {}, local rank: {}"
            .format(i,
                    time.time() - start_time, epoch_loss, epoch_ppl,
                    local_rank))
        if local_rank == 0:
            torch.save(save_transformer(s2s, optimizer, args),
                       "{}_{}_{}".format(args.checkpoint, i, steps))

    torch.save(save_transformer(s2s, optimizer, args),
               args.checkpoint + "_rank{}".format(local_rank))
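
# The worker above is typically launched with one process per GPU. A minimal
# launch sketch (an assumption, not part of the original script: it presumes
# args.gpus gives the GPU count per node and that args carries the fields
# train() reads):
import torch.multiprocessing as mp

def launch(args):
    # spawn() calls train(local_rank, args) once per local GPU, passing the
    # process index as local_rank
    mp.spawn(train, nprocs=args.gpus, args=(args,))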
Example #2
    drop_last=False)

if args.is_prefix:
    args.model_load = args.model_load + "*"

for model_path in glob.glob(args.model_load):

    print("Load model from {}".format(model_path))

    if args.transformer:

        s2s = load_transformer(model_path,
                               len(src_vocab),
                               max_src_len,
                               len(tgt_vocab),
                               max_tgt_len,
                               padding_value,
                               training=False,
                               share_dec_pro_emb=args.share_dec_pro_emb,
                               device=device)

    else:
        s2s = load_model(model_path, device=device)

    s2s.eval()

    pred_data = []

    if args.record_time:
        import time
        start_time = time.time()
Example #3
def evaluation(local_rank, args):
    rank = args.nr * args.gpus + local_rank
    dist.init_process_group(backend="nccl",
                            init_method=args.init_method,
                            rank=rank,
                            world_size=args.world_size)

    device = torch.device("cuda", local_rank)
    torch.cuda.set_device(device)

    # List[str]
    src_data = read_data(args.test_src_path)

    tgt_prefix_data = None
    if args.tgt_prefix_file_path is not None:
        tgt_prefix_data = read_data(args.tgt_prefix_file_path)

    # +2 reserves room for the start/end tokens; the decoding budget is capped
    # at three times the longest source length
    max_src_len = max(len(line.split()) for line in src_data) + 2
    max_tgt_len = max_src_len * 3
    logging.info("max src sentence length: {}".format(max_src_len))

    src_vocab = Vocab.load(args.src_vocab_path)
    tgt_vocab = Vocab.load(args.tgt_vocab_path)

    padding_value = src_vocab.get_index(src_vocab.mask_token)

    assert padding_value == tgt_vocab.get_index(tgt_vocab.mask_token)

    src_data = convert_data_to_index(src_data, src_vocab)

    dataset = DataPartition(src_data, args.world_size, tgt_prefix_data,
                            args.work_load_per_process).dataset(rank)

    logging.info("dataset size: {}, rank: {}".format(len(dataset), rank))

    data_loader = DataLoader(
        dataset=dataset,
        batch_size=(args.batch_size if args.batch_size else 1),
        shuffle=False,
        pin_memory=True,
        drop_last=False,
        collate_fn=lambda batch: collate_eval(
            batch,
            padding_value,
            batch_first=(True if args.transformer else False)))

    if not os.path.isdir(args.translation_output_dir):
        os.makedirs(args.translation_output_dir)

    if args.beam_size:
        logging.info("Beam size: {}".format(args.beam_size))

    if args.is_prefix:
        args.model_load = args.model_load + "*"

    for model_path in glob.glob(args.model_load):
        logging.info("Load model from: {}, rank: {}".format(model_path, rank))

        if args.transformer:

            s2s = load_transformer(model_path,
                                   len(src_vocab),
                                   max_src_len,
                                   len(tgt_vocab),
                                   max_tgt_len,
                                   padding_value,
                                   training=False,
                                   share_dec_pro_emb=args.share_dec_pro_emb,
                                   device=device)

        else:
            s2s = load_model(model_path, device=device)

        s2s.eval()

        if args.record_time:
            import time
            start_time = time.time()

        pred_data = []

        for data, tgt_prefix_batch in data_loader:
            if args.beam_size:
                pred_data.append(
                    beam_search_decoding(s2s, data.to(device,
                                                      non_blocking=True),
                                         tgt_vocab, args.beam_size, device))
            else:
                pred_data.extend(
                    greedy_decoding(s2s, data.to(device, non_blocking=True),
                                    tgt_vocab, device, tgt_prefix_batch))

        if args.record_time:
            end_time = time.time()
            logging.info("Time spend: {} seconds, rank: {}".format(
                end_time - start_time, rank))

        _, model_name = os.path.split(model_path)

        if args.beam_size:
            translation_file_name_prefix = "{}_beam_size_{}".format(
                model_name, args.beam_size)
        else:
            translation_file_name_prefix = "{}_greedy".format(model_name)

        p = os.path.join(
            args.translation_output_dir,
            "{}_translations.rank{}".format(translation_file_name_prefix,
                                            rank))

        write_data(pred_data, p)

        if args.need_tok:

            # strip BPE continuation markers ('@@ ') to restore full tokens
            p_tok = os.path.join(
                args.translation_output_dir,
                "{}_translations_tok.rank{}".format(
                    translation_file_name_prefix, rank))

            tok_command = "sed -r 's/(@@ )|(@@ ?$)//g' {} > {}".format(
                p, p_tok)

            call(tok_command, shell=True)
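
# The sed call above requires a POSIX shell. A pure-Python equivalent of the
# same detokenization is sketched below (an addition for illustration, not
# part of the original script); it applies the identical regular expression
# line by line.
import re

BPE_MARKER = re.compile(r"(@@ )|(@@ ?$)")

def strip_bpe(input_path, output_path):
    # remove '@@ ' continuation markers, mirroring sed -r 's/(@@ )|(@@ ?$)//g'
    with open(input_path, encoding="utf-8") as fin, \
            open(output_path, "w", encoding="utf-8") as fout:
        for line in fin:
            fout.write(BPE_MARKER.sub("", line.rstrip("\n")) + "\n")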
Example #4
def train(args):

    setup_seed(args.seed)

    saved_model_dir, _ = os.path.split(args.checkpoint)

    if not os.path.isdir(saved_model_dir):
        os.makedirs(saved_model_dir)

    device = args.device

    torch.cuda.set_device(device)

    src_data, src_vocab = load_corpus_data(args.src_path, args.src_language,
                                           args.start_token, args.end_token,
                                           args.mask_token,
                                           args.src_vocab_path,
                                           args.rebuild_vocab, args.unk,
                                           args.threshold)

    tgt_data, tgt_vocab = load_corpus_data(args.tgt_path, args.tgt_language,
                                           args.start_token, args.end_token,
                                           args.mask_token,
                                           args.tgt_vocab_path,
                                           args.rebuild_vocab, args.unk,
                                           args.threshold)

    args.src_vocab_size = len(src_vocab)
    args.tgt_vocab_size = len(tgt_vocab)

    logging.info("Source language vocab size: {}".format(len(src_vocab)))
    logging.info("Target language vocab size: {}".format(len(tgt_vocab)))

    assert len(src_data) == len(tgt_data)

    if args.sort_sentence_by_length:
        src_data, tgt_data = sort_src_sentence_by_length(
            list(zip(src_data, tgt_data)))

    logging.info("Transformer")

    max_src_len = max(len(line) for line in src_data)
    max_tgt_len = max(len(line) for line in tgt_data)

    args.max_src_len = max_src_len
    args.max_tgt_len = max_tgt_len

    padding_value = src_vocab.get_index(args.mask_token)

    assert padding_value == tgt_vocab.get_index(args.mask_token)
    args.padding_value = padding_value

    if args.load:

        logging.info("Load existing model from {}".format(args.load))
        s2s, optimizer_state_dict = load_transformer(args,
                                                     training=True,
                                                     device=device)
        optimizer = get_optimizer(s2s.parameters(), args)
        optimizer.load_state_dict(optimizer_state_dict)

    else:
        logging.info("New model")
        s2s = build_transformer(args, device)
        s2s.init_parameters()
        optimizer = get_optimizer(s2s.parameters(), args)

    s2s.train()

    if args.label_smoothing:
        criterion = LabelSmoothingLoss(args.label_smoothing, padding_value)
    else:
        criterion = nn.CrossEntropyLoss(ignore_index=padding_value)

    train_data = NMTDataset(src_data, tgt_data)

    # release cpu memory
    del src_data
    del tgt_data

    train_loader = DataLoader(train_data,
                              args.batch_size,
                              shuffle=True,
                              pin_memory=True,
                              collate_fn=lambda batch: collate(
                                  batch, padding_value, batch_first=True))

    for i in range(args.start_epoch, args.end_epoch):

        epoch_loss = 0.0

        start_time = time.time()

        steps = 0

        for j, (input_batch, target_batch) in enumerate(train_loader):

            batch_loss = s2s.train_batch(
                input_batch.to(device, non_blocking=True),
                target_batch.to(device, non_blocking=True), criterion,
                optimizer, j, args.update_freq)

            epoch_loss += batch_loss

            steps += 1

        # flush gradients left over from an incomplete accumulation window
        if steps % args.update_freq != 0:
            optimizer.step()
            optimizer.zero_grad()

        epoch_loss /= steps

        epoch_ppl = math.exp(epoch_loss)

        torch.save(save_transformer(s2s, optimizer, args),
                   "{}_{}_{}".format(args.checkpoint, i, steps))
        logging.info(
            "Epoch: {}, time: {} seconds, loss: {}, perplexity: {}".format(
                i,
                time.time() - start_time, epoch_loss, epoch_ppl))
Example #5
args, unknown = parser.parse_known_args()

src_vocab = Vocab.load(args.src_vocab_path)

if args.transformer:

    assert args.tgt_vocab_path is not None

    tgt_vocab = Vocab.load(args.tgt_vocab_path)

    padding_value = src_vocab.get_index(src_vocab.mask_token)

    assert padding_value == tgt_vocab.get_index(tgt_vocab.mask_token)

    s2s = load_transformer(args.load, len(src_vocab), 10, len(tgt_vocab), 10,
                           padding_value)

else:

    # multi_layer_token_embedding is only supported for the transformer right now
    assert args.lang_vec_type == "token_embedding"
    s2s = load_model(args.load)

s2s.eval()

lang_token_list = read_data(args.lang_token_list_path)

lang_vec = {}

for lang_token in lang_token_list:
Example #6
def main():
    parser = argparse.ArgumentParser()

    parser.add_argument("--device", required=True)
    parser.add_argument("--load", required=True)
    parser.add_argument("--src_vocab_path", required=True)
    parser.add_argument("--tgt_vocab_path", required=True)
    parser.add_argument("--test_src_path", required=True)
    parser.add_argument("--test_tgt_path", required=True)
    parser.add_argument("--lang_vec_path", required=True)
    parser.add_argument("--translation_output", required=True)
    parser.add_argument("--transformer", action="store_true")

    args, unknown = parser.parse_known_args()

    src_vocab = Vocab.load(args.src_vocab_path)
    tgt_vocab = Vocab.load(args.tgt_vocab_path)

    src_data = read_data(args.test_src_path)

    lang_vec = load_lang_vec(args.lang_vec_path)  # lang_vec: dict

    device = args.device

    print("load from {}".format(args.load))

    if args.transformer:

        # measure sentence length in tokens, matching the evaluation code
        max_src_length = max(len(line.split()) for line in src_data) + 2
        max_tgt_length = max_src_length * 3
        padding_value = src_vocab.get_index(src_vocab.mask_token)
        assert padding_value == tgt_vocab.get_index(tgt_vocab.mask_token)

        s2s = load_transformer(args.load,
                               len(src_vocab),
                               max_src_length,
                               len(tgt_vocab),
                               max_tgt_length,
                               padding_value,
                               device=device)

    else:
        s2s = load_model(args.load, device=device)

    s2s.eval()

    pred_data = []
    for i, line in enumerate(src_data):
        pred_data.append(
            translate(line, i, s2s, src_vocab, tgt_vocab, lang_vec, device))

    write_data(pred_data, args.translation_output)

    pred_data_tok_path = args.translation_output + ".tok"

    tok_command = "sed -r 's/(@@ )|(@@ ?$)//g' {} > {}".format(
        args.translation_output, pred_data_tok_path)
    call(tok_command, shell=True)

    bleu_calculation_command = "perl /data/rrjin/NMT/scripts/multi-bleu.perl {} < {}".format(
        args.test_tgt_path, pred_data_tok_path)
    call(bleu_calculation_command, shell=True)
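
# Where multi-bleu.perl is unavailable, the sacrebleu package offers a
# scriptable alternative. This is a sketch using a third-party library the
# original code does not depend on; its scores can differ slightly from
# multi-bleu.perl because of tokenization differences.
import sacrebleu

def corpus_bleu(hyp_path, ref_path):
    with open(hyp_path, encoding="utf-8") as f:
        hypotheses = [line.rstrip("\n") for line in f]
    with open(ref_path, encoding="utf-8") as f:
        references = [line.rstrip("\n") for line in f]
    # sacrebleu takes the hypothesis list plus a list of reference streams
    return sacrebleu.corpus_bleu(hypotheses, [references]).score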