import numpy as np
import torch
from tqdm import tqdm


def evaluate(model, data_loader, criterion, device, save_output=False):
    """Run the model over a loader without teacher forcing; return (avg loss, CER %, transcripts)."""
    total_loss = 0.
    total_num = 0
    total_dist = 0
    total_length = 0
    total_sent_num = 0
    transcripts_list = []

    model.eval()
    with torch.no_grad():
        for i_batch, data in tqdm(enumerate(data_loader), total=len(data_loader)):
            batch_x, batch_y, feat_lengths, script_lengths = data

            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            feat_lengths = feat_lengths.to(device)

            # Drop the leading SOS token; the decoder should predict everything after it.
            target = batch_y[:, 1:]

            # Decode free-running (teacher_forcing_ratio=0) so evaluation reflects inference.
            logit = model(batch_x, feat_lengths, None, teacher_forcing_ratio=0.0)
            logit = torch.stack(logit, dim=1).to(device)
            y_hat = logit.max(-1)[1]

            # Trim the decoded sequence to the target length before computing the loss.
            logit = logit[:, :target.size(1), :]
            loss = criterion(logit.contiguous().view(-1, logit.size(-1)),
                             target.contiguous().view(-1))
            total_loss += loss.item()
            total_num += feat_lengths.sum().item()

            dist, length, transcript = get_distance(target, y_hat)
            total_dist += dist
            total_length += length
            if save_output:
                transcripts_list.append(transcript)
            total_sent_num += target.size(0)

    aver_loss = total_loss / total_num
    aver_cer = float(total_dist / total_length) * 100
    return aver_loss, aver_cer, transcripts_list
def demo(model, test_data, criterion, device, save_output=False):
    """Greedy-decode each utterance in a plain (feature, script, label-id) list; loss is not computed."""
    total_dist = 0
    total_length = 0
    total_sent_num = 0
    transcripts_list = []

    model.eval()
    with torch.no_grad():
        for i in range(len(test_data)):
            # Add batch and channel dimensions so a single feature matrix looks like a batch of one.
            batch_x = test_data[i][0][np.newaxis, np.newaxis, :, :]
            target = test_data[i][2][np.newaxis, :]
            feat_lengths = np.array([batch_x.shape[3]])

            batch_x = torch.from_numpy(batch_x).to(device)
            feat_lengths = torch.from_numpy(feat_lengths).to(device)

            # Drop the leading SOS token to align the reference with the decoder outputs.
            target = torch.from_numpy(target[:, 1:]).to(device)

            logit = model(batch_x, feat_lengths, None, teacher_forcing_ratio=0.0)
            logit = torch.stack(logit, dim=1).to(device)
            y_hat = logit.max(-1)[1]

            dist, length, transcript = get_distance(target, y_hat)
            total_dist += dist
            total_length += length
            if save_output:
                transcripts_list.append(transcript)
            total_sent_num += target.size(0)

    # Demo mode only collects transcripts; loss and CER are returned as placeholders.
    aver_loss = 0.0
    aver_cer = 0.0
    return aver_loss, aver_cer, transcripts_list
def train(model, data_loader, criterion, optimizer, device, epoch, max_norm=400, teacher_forcing_ratio=1.0):
    """Train for one epoch and return (avg loss per frame, CER %)."""
    total_loss = 0.
    total_num = 0
    total_dist = 0
    total_length = 0
    total_sent_num = 0

    model.train()
    for i_batch, data in enumerate(data_loader):
        batch_x, batch_y, feat_lengths, script_lengths = data
        optimizer.zero_grad()

        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)
        feat_lengths = feat_lengths.to(device)

        # Drop the leading SOS token; the decoder predicts everything after it.
        target = batch_y[:, 1:]

        logit = model(batch_x, feat_lengths, batch_y, teacher_forcing_ratio=teacher_forcing_ratio)
        logit = torch.stack(logit, dim=1).to(device)
        y_hat = logit.max(-1)[1]

        loss = criterion(logit.contiguous().view(-1, logit.size(-1)),
                         target.contiguous().view(-1))
        total_loss += loss.item()
        total_num += feat_lengths.sum().item()

        loss.backward()
        # Clip gradients to keep the recurrent decoder from exploding.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        optimizer.step()

        dist, length, _ = get_distance(target, y_hat)
        total_dist += dist
        total_length += length
        cer = float(dist / length) * 100
        total_sent_num += batch_y.size(0)

        # Every 1000 batches, checkpoint the model/optimizer and log progress.
        # `args.model_path` is expected to come from the module-level argparse namespace.
        if i_batch % 1000 == 0 and i_batch != 0:
            state = {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict()
            }
            torch.save(state, args.model_path)
            print('Epoch: [{0}][{1}/{2}]\t Loss {loss:.4f}\t Cer {cer:.4f}'.format(
                epoch + 1, i_batch + 1, len(data_loader), loss=loss.item(), cer=cer))

    return total_loss / total_num, (total_dist / total_length) * 100
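

# A minimal, self-contained sketch of the criterion call used in train() and
# evaluate(), run on random tensors instead of real decoder outputs. The
# vocabulary size, padding index (0), and the choice of CrossEntropyLoss are
# illustrative assumptions, not values taken from this project.
if __name__ == '__main__':
    batch, dec_len, vocab = 2, 5, 10
    dummy_logit = torch.randn(batch, dec_len, vocab)           # (B, T, V) stacked decoder outputs
    dummy_target = torch.randint(1, vocab, (batch, dec_len))   # (B, T) label ids with SOS removed
    dummy_criterion = torch.nn.CrossEntropyLoss(reduction='sum', ignore_index=0)
    # Same flattening as in train()/evaluate(): (B*T, V) logits against (B*T,) targets.
    dummy_loss = dummy_criterion(dummy_logit.contiguous().view(-1, vocab),
                                 dummy_target.contiguous().view(-1))
    print('dummy summed loss: {:.4f}'.format(dummy_loss.item()))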