示例#1
0
def train(model, train_dataloader, device, distance, optim, epoch, lr_scheduler, dataset):
    """Run one training epoch and return the learning rate in effect afterwards.

    Each batch is padded with ``Dataset.pad_batch``, moved to ``device``,
    passed through ``model``, log-softmaxed over dim 2 and transposed to
    ``(input_sequence_len, batch_size, n_classes)`` before the loss is computed.
    After every optimizer step the learning rate of all param groups is
    replaced by ``lr_scheduler.new_lr()``.

    Args:
        model: network called as ``model(x=..., this_model_train=True)``.
        train_dataloader: iterable of 5-tuples
            ``(spectrograms, targets, input_lens, target_lens, word_wise_target)``;
            ``word_wise_target`` is not used here.
        device: device the padded spectrograms and targets are moved to.
        distance: loss callable invoked as
            ``distance(log_probs, targets, input_lens, target_lens)``
            (the layout follows ``torch.nn.CTCLoss``).
        optim: optimizer; its ``param_groups`` learning rates are overwritten
            after each step.
        epoch: epoch number, used only in the log line.
        lr_scheduler: object whose ``new_lr()`` returns the next learning rate.
        dataset: unused; kept so existing call sites keep working.

    Returns:
        The learning rate set after the last batch. If the dataloader yields
        no batches, the optimizer's current learning rate is returned instead
        (previously this raised UnboundLocalError).

    Side effects:
        On batches whose index is a multiple of 200 (including the first),
        appends the running average loss to ``train_losses`` (a list defined
        outside this function) and prints a progress line.
    """
    model.train()
    average_meter = AverageMeter()
    # Initialise so an empty dataloader does not hit an unbound ``lr`` at return.
    lr = optim.param_groups[0]['lr']

    for i, (spectrograms, targets, input_lens, target_lens, _) in enumerate(train_dataloader):
        spectrograms, targets = Dataset.pad_batch(
            spectrograms=list(spectrograms),
            targets=list(targets),
        )
        spectrograms = spectrograms.to(device)
        targets = targets.to(device)

        # ==== forward ====
        output = model(x=spectrograms, this_model_train=True).log_softmax(dim=2)
        # reshape to '(input_sequence_len, batch_size, n_classes)' as described in
        # 'https://pytorch.org/docs/master/generated/torch.nn.CTCLoss.html'
        output = output.transpose(0, 1)
        loss = distance(output, targets, input_lens, target_lens)

        # ==== backward ====
        optim.zero_grad()
        loss.backward()
        optim.step()

        # ==== adjustments ====
        lr = lr_scheduler.new_lr()
        for param_group in optim.param_groups:
            param_group['lr'] = lr

        # ==== log ====
        # .item() forces a device sync; call it once per batch.
        loss_value = loss.item()
        # Batches with exactly zero loss are left out of the running average.
        if loss_value != 0:
            average_meter.step(loss=loss_value)
        if i % 200 == 0:
            average_loss = average_meter.average()
            train_losses.append(average_loss)
            print(f'Loss: {average_loss} | Batch: {i} / {len(train_dataloader)} | Epoch: {epoch} | lr: {lr}')

    return lr
示例#2
0
def test(model, test_dataloader, device, distance):
    """Evaluate ``model`` against word-wise targets and record the average loss.

    Runs under ``torch.no_grad()`` with the model in eval mode. For every batch
    the word-wise targets are flattened into a single ``(1, n_words)`` float
    tensor. If there are more target words than entries along dim 0 of the
    model output, the output is padded along dim 0. The loss is then computed
    with ``distance(output, targets, (output.shape[0],), (n_words,))``.

    NOTE(review): the single-element length tuples and the flattening of all
    targets in a batch into one row assume ``batch_size == 1`` — confirm at the
    call site.

    Args:
        model: network called as ``model(spectrograms, this_model_train=True)``.
        test_dataloader: iterable of 5-tuples
            ``(spectrograms, targets, input_lens, target_lens, word_wise_target)``.
        device: device the padded spectrograms are moved to.
        distance: loss callable (e.g. ``torch.nn.CTCLoss``) invoked as
            ``distance(log_probs, targets, input_lengths, target_lengths)``.

    Side effects:
        Appends the average loss to ``test_losses`` (a list defined outside
        this function) and prints it.
    """
    model.eval()
    average_meter = AverageMeter()

    with torch.no_grad():
        for spectrograms, targets, input_lens, target_lens, word_wise_target in test_dataloader:
            spectrograms, targets = Dataset.pad_batch(
                spectrograms=list(spectrograms),
                targets=list(targets),
            )
            spectrograms = spectrograms.to(device)

            # ==== forward ====
            output = model(spectrograms, this_model_train=True).log_softmax(dim=2)
            # NOTE(review): unlike `train`, the output is not transposed here, yet
            # dim 0 is treated as the time axis below; confirm the layout the model
            # returns for this call.

            # Flatten the word indices of all targets into shape (1, n_words).
            adjusted_targets = torch.stack(
                [torch.Tensor([word_index]) for target in word_wise_target for word_index in target]
            ).transpose(1, 0)

            # Pad dim 0 of the output when there are more target words than
            # output steps.
            tensor_len_delta = adjusted_targets.shape[1] - output.shape[0]
            if tensor_len_delta > 0:
                # Derive the trailing dims, dtype and device from `output` instead
                # of hard-coding (1, 9896). When the original hard-coded version
                # worked, this produces the same tensor.
                # NOTE(review): zeros are log(1), i.e. probability 1 for every class,
                # which is not a normalised log-distribution — confirm intent.
                padding = output.new_zeros((tensor_len_delta, *output.shape[1:]))
                output = torch.cat((output, padding))

            loss = distance(output, adjusted_targets, (output.shape[0],), (adjusted_targets.shape[1],))

            # ==== log ====
            loss_value = loss.item()
            # Batches with exactly zero loss are left out of the average.
            if loss_value != 0:
                average_meter.step(loss=loss_value)

    average_loss = average_meter.average()
    test_losses.append(average_loss)
    print(f'Test evaluation: Average loss: {average_loss}')
示例#3
0
# Evaluation data pipeline.
# NOTE(review): the test split is constructed from `train_url` with mode='test';
# confirm that Dataset selects the test portion from this URL (a dedicated test
# URL may have been intended).
test_dataset = Dataset(root=root,
                       url=train_url,
                       mode='test',
                       n_features=128,
                       download=False)
# One sample per batch in a fixed order (shuffle=False), loaded in the main
# process (num_workers=0). Batches are assembled by the project's `collate_fn`.
test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=1,
                                              shuffle=False,
                                              num_workers=0,
                                              collate_fn=collate_fn)

# HMM loaded from 'data/hmm_data' with 29 states.
# presumably 29 matches the number of output classes incl. the blank
# (the decoder below uses blank_label=28) — TODO confirm.
hmm = HMM(root='data/hmm_data', n_states=29)

for i, data in enumerate(test_dataloader):
    spectrograms, targets, input_lens, target_lens, _ = data
    spectrograms, targets = Dataset.pad_batch(spectrograms=list(spectrograms),
                                              targets=list(targets))
    spectrograms = spectrograms.to(device)
    targets = targets.to(device)

    # ==== forward ====
    output = model(spectrograms)

    # ==== log ====
    probabilities = output
    output, targets = decoder(output=output,
                              targets=targets,
                              dataset=test_dataset,
                              label_lens=target_lens,
                              blank_label=28)

    if USE_HMM:  # word wise hmm