Example #1
def get_tokenized_dataset(dataset):
    train, dev, test = get_dataset_strings(dataset)
    languages = {"antoloji": "tr", "tur": "tr", "eng": "en", "cz": "cz"}
    language = languages[dataset]
    tokenizer = get_tokenizer(language)
    train_set = [tokenize_string(x, tokenizer) for x in train]
    dev_set = [tokenize_string(x, tokenizer) for x in dev]
    test_set = [tokenize_string(x, tokenizer) for x in test]
    return train_set, dev_set, test_set
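
A minimal usage sketch (not from the repository), assuming tokenize_string returns a list of sentencepiece ids per line; "tur" is one of the dataset keys in the mapping above, and the prints are illustrative only.

train_set, dev_set, test_set = get_tokenized_dataset("tur")
print(len(train_set), len(dev_set), len(test_set))
print(train_set[0])  # presumably the token ids of the first training line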
Example #2
def translate_devset(args):
    save_to = args.output
    token_selector = make_critic(32000, 2, 2).to(device)
    model = make_model(32000, 32000, N=6).to(device)
    try:
        checkpoint = torch.load(args.checkpoint)
    except RuntimeError:
        checkpoint = torch.load(args.checkpoint,
                                map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['model_state_dict'])
    token_selector.load_state_dict(checkpoint['selector_state_dict'])
    token_selector.eval()
    model.eval()
    tokenizer = get_tokenizer(args.language)
    val_iter, val_indices = make_val_iterator(args.input,
                                              tokenizer,
                                              batch_size=128)
    pad_idx = 3
    val_iter = (rebatch_single(pad_idx, b) for b in val_iter)
    decoded = [""] * len(val_indices)
    for batch in val_iter:
        dae_input = get_dae_input(batch.src,
                                  token_selector).transpose(0, 1).to(device)
        # create src and src mask from selected tokens
        dae_input_mask = (dae_input != 3).unsqueeze(-2).to(device)
        out = greedy_decode(model,
                            dae_input,
                            dae_input_mask,
                            max_len=args.max_len,
                            start_symbol=1)
        for c, decoded_row in enumerate(out):
            padded_src = batch.src[c, :].tolist()
            src_seq = []
            for item in padded_src:
                if item == 3:
                    break
                src_seq.append(item)
            index = val_indices[tuple(src_seq)]
            # truncate the decoded row at the end-of-sentence id (2) before detokenizing
            to_spm = []
            for item in decoded_row:
                if item == 2:
                    break
                to_spm.append(int(item))
            decoded_string = tokenizer.Decode(to_spm)
            decoded[index] = decoded_string
            print(decoded_string.encode('utf8'))
        print("Decoded batch of", batch.src.shape)

    # write the decoded strings back out in their original line order
    with open(save_to, "w", encoding="utf-8") as outfile:
        for line in decoded:
            outfile.write(line + "\n")

    # give the output file the same modification time as the checkpoint that produced it
    checkpoint_time = os.stat(args.checkpoint).st_mtime
    os.utime(save_to, (checkpoint_time, checkpoint_time))
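
A hypothetical command-line entry point for translate_devset, mirroring the argument attributes the function reads (input, output, checkpoint, language, max_len); the flag defaults shown here are assumptions, not the repository's actual script.

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", required=True)
    parser.add_argument("--output", required=True)
    parser.add_argument("--checkpoint", required=True)
    parser.add_argument("--language", default="tr")
    parser.add_argument("--max_len", type=int, default=100)
    args = parser.parse_args()
    translate_devset(args)
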
Example #3
def get_dataset(dataset):
    languages = {
        "antoloji": "tr",
        "tur": "tr",
        "cz": "cz",
        "turkish": "tr",
        "eng": "en",
        "tur-lower": "tr",
        "cz-lower": "cz",
        "turkish-lower": "tr",
        "eng-lower": "en"
    }
    language = languages[dataset]
    tokenizer = get_tokenizer(language)

    def tok(seq):
        return tokenizer.EncodeAsIds(seq)

    src = data.Field(tokenize=tok,
                     init_token=1,
                     eos_token=2,
                     pad_token=3,
                     use_vocab=False)
    tgt = data.Field(tokenize=tok,
                     init_token=1,
                     eos_token=2,
                     pad_token=3,
                     use_vocab=False)
    mt_train = datasets.TranslationDataset(
        path='data/{}/{}.train'.format(language, dataset),
        exts=('.src', '.tgt'),
        fields=(src, tgt))
    mt_dev = datasets.TranslationDataset(
        path='data/{}/{}.dev'.format(language, dataset),
        exts=('.src', '.tgt'),
        fields=(src, tgt))
    mt_test = datasets.TranslationDataset(
        path='data/{}/{}.test'.format(language, dataset),
        exts=('.src', '.tgt'),
        fields=(src, tgt))
    return mt_train, mt_dev, mt_test
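
A sketch of how the returned datasets might be batched, assuming the legacy torchtext API that data.Field and datasets.TranslationDataset belong to; BucketIterator and the batch size are assumptions here, and the repository's own iterator helpers (e.g. get_training_iterators) may do this differently.

mt_train, mt_dev, mt_test = get_dataset("tur")
train_iter, dev_iter, test_iter = data.BucketIterator.splits(
    (mt_train, mt_dev, mt_test),
    batch_size=32,
    sort_key=lambda x: len(x.src))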
Example #4
        checkpoint = torch.load(last_file)
        model.load_state_dict(checkpoint['model_state_dict'])
        model_opt.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        model.steps = int(last.split(".")[0])
    else:
        model.steps = 0

    print("Training with 1 GPU.")
    model = model.to(device)
    loss_train = SimpleLossCompute(model.generator, criterion, model_opt)
    loss_val = SimpleLossCompute(model.generator, criterion, None)
    for epoch in range(epochs):
        model.train()
        run_epoch((rebatch(pad_idx, b) for b in train_iter),
                  model, loss_train, tokenizer,
                  model_opt=model_opt, criterion=criterion,
                  save_path=save_path, exp_name=exp_name)
        #model.eval()
        #loss = run_epoch((rebatch(pad_idx, b) for b in valid_iter), model, loss_val, tokenizer,
        #                 criterion=criterion, validate=True)
        #print(loss)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", default="tur")
    parser.add_argument("--acc_steps", default=8)
    args = parser.parse_args()
    dataset_lang = {"tur": "tr", "eng": "en", "cz": "cz", "tur-lower": "tr", "eng-lower": "en", "cz-lower": "cz"}
    tokenizer = get_tokenizer(dataset_lang[args.dataset])
    run_training(args.dataset, tokenizer, args.acc_steps)
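
A hypothetical sketch of the checkpoint writer this fragment would resume from; the dictionary keys are taken from the loading code above, while the "<steps>.pt" file naming (implied by last.split(".")[0]) and the save_checkpoint name are assumptions.

def save_checkpoint(model, model_opt, save_path):
    # torch and the optimizer wrapper (model_opt) are assumed to be in scope
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': model_opt.optimizer.state_dict(),
    }, "{}/{}.pt".format(save_path, model.steps))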

Example #5
import sys, os

import torch
import torch.nn as nn
import sentencepiece as spm

sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))

from src.utils.utils import get_tokenizer

from src.data_utils.data import get_training_iterators

model = nn.Transformer(nhead=8, num_encoder_layers=6)
gen = nn.Linear(512, 32000)
embed = nn.Embedding(32000, 512)

crit = nn.CrossEntropyLoss()

tokenizer = get_tokenizer("tr")

with open("data/tr/tur.train.src") as infile:
    src_raw = [
        tokenizer.EncodeAsIds(infile.readline().strip()) for x in range(10)
    ]

with open("data/tr/tur.train.tgt") as infile:
    tgt_raw = [
        tokenizer.EncodeAsIds(infile.readline().strip()) for x in range(10)
    ]

# pad the variable-length id sequences into (seq_len, batch) LongTensors
# so that the embedding layer and the transformer can consume them
src_ids = nn.utils.rnn.pad_sequence([torch.tensor(s) for s in src_raw],
                                    padding_value=3)
tgt_ids = nn.utils.rnn.pad_sequence([torch.tensor(t) for t in tgt_raw],
                                    padding_value=3)

for i in range(50):
    src, tgt = embed(src_ids), embed(tgt_ids)
    memory = model.encoder(src)
    out = model.decoder(