def train(model, train_dataloader, device, distance, optim, epoch, lr_scheduler, dataset):
    """Run one training epoch with CTC loss and a per-step learning-rate schedule.

    Args:
        model: network called as ``model(x=spectrograms, this_model_train=True)``;
            expected to return (batch, time, n_classes) logits — TODO confirm shape.
        train_dataloader: yields ``(spectrograms, targets, input_lens, target_lens,
            word_wise_target)`` tuples; ``word_wise_target`` is unused here.
        device: torch device the batch tensors are moved to.
        distance: CTC-style loss callable taking ``(log_probs, targets, input_lens,
            target_lens)`` with log-probs shaped (T, N, C).
        optim: optimizer whose every param group receives the scheduled lr.
        epoch: current epoch number (logging only).
        lr_scheduler: object exposing ``new_lr()`` returning the next learning rate.
        dataset: unused; kept to preserve the existing call signature.

    Returns:
        The last learning rate applied.

    NOTE(review): appends running averages to the module-level ``train_losses``
    list — assumed to be defined at file scope; verify against the caller.
    """
    model.train()
    average_meter = AverageMeter()
    for i, data in enumerate(train_dataloader):
        spectrograms, targets, input_lens, target_lens, word_wise_target = data
        spectrograms, targets = Dataset.pad_batch(
            spectrograms=list(spectrograms), targets=list(targets)
        )
        spectrograms = spectrograms.to(device)
        targets = targets.to(device)

        # ==== forward ====
        output = model(x=spectrograms, this_model_train=True)
        # Functional form avoids constructing a fresh nn.LogSoftmax module on
        # every batch (as the original did); numerically identical.
        output = nn.functional.log_softmax(output, dim=2)
        # Reshape to (input_sequence_len, batch_size, n_classes) as required by
        # torch.nn.CTCLoss:
        # https://pytorch.org/docs/master/generated/torch.nn.CTCLoss.html
        output = output.transpose(0, 1)
        loss = distance(output, targets, input_lens, target_lens)

        # ==== backward ====
        optim.zero_grad()
        loss.backward()
        optim.step()

        # ==== adjustments ====
        # The scheduler is stepped per batch, not per epoch; push the new lr
        # into every param group by hand.
        lr = lr_scheduler.new_lr()
        for param_group in optim.param_groups:
            param_group['lr'] = lr

        # ==== log ====
        # Zero losses are excluded so the running average stays meaningful.
        if loss.item() != 0:
            average_meter.step(loss=loss.item())
        if i % 200 == 0:
            average_loss = average_meter.average()
            train_losses.append(average_loss)
            print(f'Loss: {average_loss} | Batch: {i} / {len(train_dataloader)} | Epoch: {epoch} | lr: {lr}')
    # NOTE(review): raises NameError if the dataloader is empty (lr never bound),
    # matching the original behavior.
    return lr
def test(model, test_dataloader, device, distance):
    """Evaluate ``model`` on ``test_dataloader`` using a word-wise CTC loss.

    Args:
        model: network called as ``model(spectrograms, this_model_train=True)``.
        test_dataloader: yields ``(spectrograms, targets, input_lens, target_lens,
            word_wise_target)``; only ``word_wise_target`` is used for the loss.
        device: torch device the input tensors are moved to.
        distance: CTC-style loss callable taking ``(log_probs, targets,
            input_lens, target_lens)``.

    NOTE(review): appends the final average to the module-level ``test_losses``
    list — assumed to be defined at file scope; verify against the caller.
    """
    model.eval()
    average_meter = AverageMeter()
    with torch.no_grad():
        for i, data in enumerate(test_dataloader):
            spectrograms, targets, input_lens, target_lens, word_wise_target = data
            spectrograms, targets = Dataset.pad_batch(
                spectrograms=list(spectrograms), targets=list(targets)
            )
            spectrograms = spectrograms.to(device)

            # ==== forward ====
            # NOTE(review): this_model_train=True during evaluation mirrors the
            # training call — confirm the flag is not meant to be False here.
            output = model(spectrograms, this_model_train=True)
            # Functional form avoids constructing a fresh nn.LogSoftmax module
            # on every batch (as the original did); numerically identical.
            output = nn.functional.log_softmax(output, dim=2)

            # Flatten the nested word-wise targets into a single row tensor of
            # shape (1, total_word_count).
            adjusted_targets = torch.stack([
                torch.Tensor([word_index])
                for target in word_wise_target
                for word_index in target
            ])
            adjusted_targets.transpose_(1, 0)

            # Zero-pad the output along time when it is shorter than the target
            # sequence so CTC's length requirement holds. 9896 is presumably the
            # model's output class count — TODO confirm against the model config.
            tensor_len_delta = adjusted_targets.shape[1] - output.shape[0]
            if tensor_len_delta > 0:
                output = torch.cat((output, torch.zeros(tensor_len_delta, 1, 9896).to(device)))
            loss = distance(output, adjusted_targets, (output.shape[0],), (adjusted_targets.shape[1],))

            # ==== log ====
            # Zero losses are excluded so the average stays meaningful.
            if loss.item() != 0:
                average_meter.step(loss=loss.item())
    average_loss = average_meter.average()
    test_losses.append(average_loss)
    print(f'Test evaluation: Average loss: {average_loss}')
# Top-level evaluation/decoding pass: build a batch-size-1 test loader, run the
# model on each utterance, and decode the outputs (optionally rescoring with a
# word-wise HMM — the `if USE_HMM:` body continues beyond this excerpt).
# NOTE(review): the test dataset is constructed with `url=train_url` even though
# `mode='test'` — confirm this is intentional and not a copy-paste slip.
# NOTE(review): `blank_label=28` and `n_states=29` presumably correspond to the
# 29-symbol character alphabet used elsewhere in this project — verify.
test_dataset = Dataset(root=root, url=train_url, mode='test', n_features=128, download=False) test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=1, shuffle=False, num_workers=0, collate_fn=collate_fn) hmm = HMM(root='data/hmm_data', n_states=29) for i, data in enumerate(test_dataloader): spectrograms, targets, input_lens, target_lens, _ = data spectrograms, targets = Dataset.pad_batch(spectrograms=list(spectrograms), targets=list(targets)) spectrograms = spectrograms.to(device) targets = targets.to(device) # ==== forward ==== output = model(spectrograms) # ==== log ==== probabilities = output output, targets = decoder(output=output, targets=targets, dataset=test_dataset, label_lens=target_lens, blank_label=28) if USE_HMM: # word wise hmm