Example #1
def load_dl(name, bsize=B_SIZE, device=torch.device('cpu')):
    # build a collate function bound to the target device
    collate_fn = make_collate_fn(device)
    print('loading data...')
    # task, data_path_dict and t_dict are module-level globals of the source project
    path = data_path_dict[task]
    path = op.join(path, name)  # op is os.path
    p = t_dict[task]()
    p.load(path)
    dataloader = DataLoader(p.to_dataset(device=device),
                            batch_size=bsize,
                            shuffle=True,
                            collate_fn=collate_fn)
    print('done')
    return dataloader
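make_collate_fn itself is project-specific and is not shown in this example. A minimal sketch of a device-aware collate factory consistent with the call above; the names, tensor shapes, and (sequence, label) batch layout are assumptions, not the project's actual implementation:

import torch
from torch.nn.utils.rnn import pad_sequence

def make_collate_fn(device):
    # return a collate function that pads variable-length sequences
    # and moves the resulting batch tensors to the given device
    def collate(batch):
        seqs, labels = zip(*batch)                           # batch: list of (sequence, label) pairs
        padded = pad_sequence(list(seqs), batch_first=True)  # pad to the longest sequence in the batch
        labels = torch.stack(list(labels))
        return padded.to(device), labels.to(device)
    return collate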
Example #2
File: main.py  Project: mujjingun/LAS
def main(args):
    # select device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # load mapping
    with open(args.path + 'mapping.pkl', 'rb') as f:
        mapping = pickle.load(f)
    pad_idx = max(mapping.values()) + 1
    vocab_size = pad_idx + 1

    # inverse mapping
    inv_map = {v: k for k, v in mapping.items()}
    inv_map[pad_idx] = ""

    # convert an index sequence to a sentence:
    # drop the leading start token and truncate at the first EOS (index 1)
    def seq_to_sen(seq):
        seq = seq + [1]  # append an EOS so .index(1) always succeeds
        eos_idx = seq.index(1)
        return "".join(inv_map[i] for i in seq[1:eos_idx])

    # load dataset
    data = dataset.SpeechDataset('train-clean-100.csv', args.path, True)

    # train validation split
    train_size = math.ceil(len(data) * 0.9)
    val_size = len(data) - train_size
    train, val = torch.utils.data.random_split(data, [train_size, val_size])

    # make dataloaders
    collate_fn = dataset.make_collate_fn(pad_idx)
    train = torch.utils.data.DataLoader(
        train,
        batch_size=args.batch_size,
        collate_fn=collate_fn,
        num_workers=4
    )
    val = torch.utils.data.DataLoader(
        val,
        batch_size=args.batch_size,
        collate_fn=collate_fn,
        num_workers=4
    )
    test = torch.utils.data.DataLoader(
        dataset.SpeechDataset('test-clean.csv', args.path, True),
        batch_size=args.batch_size,
        collate_fn=collate_fn,
        num_workers=4
    )

    # construct model
    las_model = model.LAS(device, vocab_size, pad_idx,
                          start_lr=args.start_lr,
                          decay_steps=args.decay_steps)

    # load model
    if os.path.exists(args.file_name):
        print("Loading model from file ", args.file_name)
        las_model.load(args.file_name)
        print("Loaded.")

    if not args.test:
        for epoch in range(args.epochs):
            print("Epoch ", epoch)

            train_losses = []
            val_losses = []

            pbar = tqdm.tqdm(train)
            for source, target in pbar:
                source, target = source.to(device), target.to(device)
                loss = las_model.train_step(source, target)
                train_losses.append(loss)
                pbar.set_description("Loss = {:.6f}".format(loss))
            print("Train loss ", np.mean(train_losses))

            for source, target in val:
                source, target = source.to(device), target.to(device)
                with torch.no_grad():
                    loss = las_model.loss(source, target, 0).item()
                val_losses.append(loss)
            print("Val loss ", np.mean(val_losses))

            # save model
            print("Saving to ", args.file_name)
            las_model.save(args.file_name)
            print("Saved.")
    else:
        scores = []
        generated = []
        for source, target in tqdm.tqdm(test):
            source = source.to(device)
            pred_batch = las_model.predict(source, max_length=100)
            for seq, truth in zip(pred_batch, target.numpy().tolist()):
                sen = seq_to_sen(seq)
                truth = seq_to_sen(truth)
                wer = jiwer.wer(truth, sen)
                generated.append(sen)
                scores.append(wer)
        with open("pred.txt", "w") as file:
            file.write("\n".join(generated))
        print("Avg WER = ", np.mean(scores))
Example #3
        args.data_path,
        "sample_submission_toy.csv"
        if args.toy in ["True", "toy"] else "sample_submission.csv",
    ))

tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                          do_lower_case=("uncased"
                                                         in args.bert_model))

test_set = get_test_set(args, test_df, tokenizer)
test_loader = DataLoader(
    test_set,
    batch_sampler=BucketingSampler(test_set.lengths,
                                   batch_size=args.batch_size,
                                   maxlen=args.max_sequence_length),
    collate_fn=make_collate_fn(),
)

for fold, train_set, valid_set, train_fold_df, val_fold_df in cross_validation_split(
        args, train_df, tokenizer):

    print()
    print("Fold:", fold)
    print()

    valid_loader = DataLoader(
        valid_set,
        batch_sampler=BucketingSampler(
            valid_set.lengths,
            batch_size=args.batch_size,
            maxlen=args.max_sequence_length,