import numpy as np
import torch
from tqdm import tqdm

# get_distance() is assumed to come from the project's utility module;
# a hypothetical sketch of it appears at the end of this file.


def evaluate(model, data_loader, criterion, device, save_output=False):
    total_loss = 0.
    total_num = 0
    total_dist = 0
    total_length = 0
    total_sent_num = 0
    transcripts_list = []

    model.eval()
    with torch.no_grad():
        for i_batch, data in tqdm(enumerate(data_loader),
                                  total=len(data_loader)):
            batch_x, batch_y, feat_lengths, script_lengths = data

            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            feat_lengths = feat_lengths.to(device)

            # Drop the leading SOS token so targets align with decoder outputs.
            target = batch_y[:, 1:]

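            # Decode without teacher forcing: with teacher_forcing_ratio=0.0
            # the decoder feeds back its own predictions at every step.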
            logit = model(batch_x,
                          feat_lengths,
                          None,
                          teacher_forcing_ratio=0.0)
            logit = torch.stack(logit, dim=1).to(device)
            y_hat = logit.max(-1)[1]

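            # The decoder may emit more steps than the reference; truncate the
            # logits to the target length before computing the loss.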
            logit = logit[:, :target.size(1), :]
            loss = criterion(logit.contiguous().view(-1, logit.size(-1)),
                             target.contiguous().view(-1))
            total_loss += loss.item()
            total_num += feat_lengths.sum().item()

            dist, length, transcript = get_distance(target, y_hat)

            total_dist += dist
            total_length += length
            if save_output:
                transcripts_list.append(transcript)
            total_sent_num += target.size(0)
    aver_loss = total_loss / total_num
    aver_cer = float(total_dist / total_length) * 100
    return aver_loss, aver_cer, transcripts_list
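

# Example wiring for evaluate() (a sketch; val_loader and PAD_token are
# hypothetical names, not defined in this file):
#
#   criterion = nn.CrossEntropyLoss(reduction='sum', ignore_index=PAD_token)
#   val_loss, val_cer, _ = evaluate(model, val_loader, criterion, device)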


def demo(model, test_data, criterion, device, save_output=False):
    total_loss = 0.
    total_num = 0
    total_dist = 0
    total_length = 0
    total_sent_num = 0
    transcripts_list = []

    model.eval()
    with torch.no_grad():
        for i in range(len(test_data)):
            # Each test_data[i] is assumed to be a tuple of
            # (2-D feature array, raw transcript, integer label sequence).
            batch_x = test_data[i][0][np.newaxis, np.newaxis, :, :]
            batch_y = test_data[i][1]  # raw transcript (unused here)
            target = test_data[i][2][np.newaxis, :]
            feat_lengths = np.array([batch_x.shape[3]])

            batch_x = torch.from_numpy(batch_x)
            batch_x = batch_x.to(device)

            feat_lengths = torch.from_numpy(feat_lengths)
            feat_lengths = feat_lengths.to(device)

            # Drop the leading SOS token, as in evaluate().
            target = target[:, 1:]
            target = torch.from_numpy(target)
            target = target.to(device)
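            # Greedy decode of a single utterance (no teacher forcing).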
            logit = model(batch_x, feat_lengths, None, teacher_forcing_ratio=0.0)
            logit = torch.stack(logit, dim=1).to(device)
            y_hat = logit.max(-1)[1]

            dist, length, transcript = get_distance(target, y_hat)

            total_dist += dist
            total_length += length
            if save_output:
                transcripts_list.append(transcript)
            total_sent_num += target.size(0)
    # No loss is computed in demo mode (criterion goes unused here).
    aver_loss = 0
    aver_cer = float(total_dist / total_length) * 100 if total_length > 0 else 0
    return aver_loss, aver_cer, transcripts_list


def train(model,
          data_loader,
          criterion,
          optimizer,
          device,
          epoch,
          max_norm=400,
          teacher_forcing_ratio=1):
    total_loss = 0.
    total_num = 0
    total_dist = 0
    total_length = 0
    total_sent_num = 0

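    # The decoder is fully teacher-forced by default: with
    # teacher_forcing_ratio=1 it always consumes the ground-truth previous
    # token instead of its own prediction.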
    model.train()
    for i_batch, data in enumerate(data_loader):
        batch_x, batch_y, feat_lengths, script_lengths = data
        optimizer.zero_grad()

        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)
        feat_lengths = feat_lengths.to(device)

        # Drop the leading SOS token so targets align with decoder outputs.
        target = batch_y[:, 1:]

        logit = model(batch_x,
                      feat_lengths,
                      batch_y,
                      teacher_forcing_ratio=teacher_forcing_ratio)

        logit = torch.stack(logit, dim=1).to(device)
        y_hat = logit.max(-1)[1]

        loss = criterion(logit.contiguous().view(-1, logit.size(-1)),
                         target.contiguous().view(-1))
        total_loss += loss.item()
        total_num += feat_lengths.sum().item()

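        # Backpropagate, then clip the global gradient norm at max_norm to
        # keep the recurrent updates stable.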
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        optimizer.step()

        dist, length, _ = get_distance(target, y_hat)
        total_dist += dist
        total_length += length
        cer = float(dist / length) * 100

        total_sent_num += batch_y.size(0)
        if i_batch % 1000 == 0 and i_batch != 0:
            # Periodic checkpoint; args is assumed to be a module-level
            # namespace (e.g. from argparse) that provides model_path.
            state = {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict()
            }
            torch.save(state, args.model_path)

        print('Epoch: [{0}][{1}/{2}]\t Loss {loss:.4f}\t Cer {cer:.4f}'.format(
            (epoch + 1), (i_batch + 1), len(data_loader), loss=loss.item(),
            cer=cer))

    return total_loss / total_num, (total_dist / total_length) * 100
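

# ---------------------------------------------------------------------------
# get_distance() is imported from elsewhere in the project. The sketch below
# shows one plausible implementation (batch-summed label edit distance and
# reference length, plus the raw predicted sequences); it is an assumption
# for reference only, NOT the project's actual function.
# ---------------------------------------------------------------------------

def _levenshtein(ref, hyp):
    """Classic dynamic-programming edit distance between two sequences."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        curr = [i]
        for j, h in enumerate(hyp, 1):
            curr.append(min(prev[j] + 1,                # deletion
                            curr[j - 1] + 1,            # insertion
                            prev[j - 1] + (r != h)))    # substitution
        prev = curr
    return prev[-1]


def _get_distance_sketch(targets, y_hats):
    """Hypothetical stand-in for get_distance(). It matches the call sites
    above, but the real version presumably also strips PAD/EOS tokens and
    maps label ids back to characters for the transcripts."""
    total_dist, total_length, transcripts = 0, 0, []
    for tgt, hyp in zip(targets, y_hats):
        ref_seq = tgt.tolist()
        hyp_seq = hyp.tolist()
        total_dist += _levenshtein(ref_seq, hyp_seq)
        total_length += len(ref_seq)
        transcripts.append(hyp_seq)
    return total_dist, total_length, transcripts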