def _as_python_scalar(value):
    """Convert a one-element tensor to a plain Python number.

    Used before handing values to the running meters. `AverageMeter` is
    defined outside this block; assuming it accumulates the values it
    receives (TODO confirm), giving it a tensor that is still attached to the
    autograd graph would keep history alive for the whole epoch.

    Arguments:
        value -- tensor or plain number

    Returns:
        `value.item()` for one-element tensors, a detached tensor for larger
        tensors, and `value` itself for anything that is not a tensor.
    """
    if torch.is_tensor(value):
        return value.item() if value.numel() == 1 else value.detach()
    return value


def train(logging_start_epoch, epoch, data, model, criterion, optimizer):
    """Main training procedure.

    Runs one pass over `data`: forward pass, loss evaluation, backward pass
    with gradient-norm clipping, optimizer step and (optionally) logging.

    Arguments:
        logging_start_epoch -- number of the first epoch to be logged
        epoch -- current epoch
        data -- DataLoader which can provide batches for an epoch
        model -- model to be trained
        criterion -- instance of loss function to be optimized
        optimizer -- instance of optimizer which will be used for parameter updates
    """
    text_logger = logging.getLogger(__name__)

    model.train()

    # initialize counters, etc.
    # NOTE: the learning rate is read once here and is not refreshed inside the loop
    learning_rate = optimizer.param_groups[0]['lr']
    cla = 0
    batches_per_epoch = len(data)
    start_time = time.time()

    total_loss = AverageMeter('Total Loss', ':.4e')
    mel_pre_loss = AverageMeter('Mel Pre Loss', ':.4e')
    mel_post_loss = AverageMeter('Mel Post Loss', ':.4e')
    lang_class_acc = AverageMeter('Lang Class Acc', ':.4e')
    progress = ProgressMeter(batches_per_epoch,
                             total_loss,
                             mel_pre_loss,
                             mel_post_loss,
                             lang_class_acc,
                             prefix="Epoch: [{}]".format(epoch),
                             logger=text_logger)

    # loop through epoch batches
    for i, batch in enumerate(data):

        global_step = i + epoch * batches_per_epoch
        optimizer.zero_grad()

        # parse batch
        batch = list(map(to_gpu, batch))
        src, src_len, trg_mel, trg_lin, trg_len, stop_trg, spkrs, langs = batch
        batch_size = src.size(0)

        # get teacher forcing ratio
        if hp.constant_teacher_forcing:
            tf = hp.teacher_forcing
        else:
            tf = cos_decay(
                max(global_step - hp.teacher_forcing_start_steps, 0),
                hp.teacher_forcing_steps)

        # run the model
        post_pred, pre_pred, stop_pred, alignment, spkrs_pred, enc_output = model(
            src, src_len, trg_mel, trg_len, spkrs, langs, tf)

        # evaluate loss function
        post_trg = trg_lin if hp.predict_linear else trg_mel
        classifier = model._reversal_classifier if hp.reversal_classifier else None
        loss, batch_losses = criterion(src_len, trg_len, pre_pred, trg_mel,
                                       post_pred, post_trg, stop_pred,
                                       stop_trg, alignment, spkrs, spkrs_pred,
                                       enc_output, classifier)

        # evaluate adversarial classifier accuracy, if present
        if hp.reversal_classifier:
            input_mask = lengths_to_mask(src_len)
            # target class for every input position, filled in speaker by speaker
            trg_spkrs = torch.zeros_like(input_mask, dtype=torch.int64)
            for s in range(hp.speaker_number):
                speaker_mask = (spkrs == s)
                trg_spkrs[speaker_mask] = s
            # argmax is taken on the raw scores: softmax is monotonic, so
            # applying it first cannot change which class wins
            matches = (trg_spkrs == torch.argmax(spkrs_pred, dim=-1))
            # positions outside the input mask never count as correct
            matches[~input_mask] = False
            cla = torch.sum(matches).item() / torch.sum(input_mask).item()
            lang_class_acc.update(cla, batch_size)

        # compute gradients and make a step
        loss.backward()
        gradient = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                  hp.gradient_clipping)
        optimizer.step()

        # update the running meters with plain numbers, never with tensors
        # that are still part of the autograd graph
        total_loss.update(_as_python_scalar(loss), batch_size)
        mel_pre_loss.update(_as_python_scalar(batch_losses['mel_pre']),
                            batch_size)
        mel_post_loss.update(_as_python_scalar(batch_losses['mel_pos']),
                             batch_size)

        # log training progress
        if epoch >= logging_start_epoch:
            Logger.training(global_step, batch_losses, gradient, learning_rate,
                            time.time() - start_time, cla)
            progress.print(i)

        # update criterion states (params and decay of the loss and so on ...)
        criterion.update_states()

        # restart the timer so the next reported duration covers one batch
        start_time = time.time()
def evaluate(epoch, data, model, criterion):
    """Main evaluation procedure.

    Every batch is passed through the model twice: once with teacher forcing
    ratio 1.0 (used for the loss) and once with ratio 0.0 (used for the mel
    cepstral distortion and for the logged outputs).

    Arguments:
        epoch -- current epoch
        data -- DataLoader which can provide validation batches
        model -- model to be evaluated
        criterion -- instance of loss function to measure performance

    Returns:
        Sum of all loss terms, each averaged over the number of batches.

    Raises:
        ValueError -- if `data` contains no batches
    """
    text_logger = logging.getLogger(__name__)

    num_batches = len(data)
    if num_batches == 0:
        # the logging below needs tensors of at least one processed batch
        raise ValueError("evaluate() requires at least one validation batch")

    model.eval()

    # initialize counters, etc.
    mcd, mcd_count = 0, 0
    cla, cla_count = 0, 0
    eval_losses = {}

    total_loss = AverageMeter('Total Loss', ':.4e')
    mel_pre_loss = AverageMeter('Mel Pre Loss', ':.4e')
    mel_post_loss = AverageMeter('Mel Post Loss', ':.4e')
    lang_class_acc = AverageMeter('Lang Class Acc', ':.4e')
    progress = ProgressMeter(num_batches,
                             total_loss,
                             mel_pre_loss,
                             mel_post_loss,
                             lang_class_acc,
                             prefix="Epoch: [{}]".format(epoch),
                             logger=text_logger)

    # loop through epoch batches
    with torch.no_grad():
        for i, batch in enumerate(data):

            # parse batch
            batch = list(map(to_gpu, batch))
            src, src_len, trg_mel, trg_lin, trg_len, stop_trg, spkrs, langs = batch
            batch_size = src.size(0)

            # run the model (twice, with and without teacher forcing)
            post_pred, pre_pred, stop_pred, alignment, spkrs_pred, enc_output = model(
                src, src_len, trg_mel, trg_len, spkrs, langs, 1.0)
            post_pred_0, _, stop_pred_0, alignment_0, _, _ = model(
                src, src_len, trg_mel, trg_len, spkrs, langs, 0.0)
            stop_pred_probs = torch.sigmoid(stop_pred_0)

            # evaluate loss function
            post_trg = trg_lin if hp.predict_linear else trg_mel
            classifier = model._reversal_classifier if hp.reversal_classifier else None
            loss, batch_losses = criterion(src_len, trg_len, pre_pred, trg_mel,
                                           post_pred, post_trg, stop_pred,
                                           stop_trg, alignment, spkrs,
                                           spkrs_pred, enc_output, classifier)
            total_loss.update(loss, batch_size)
            mel_pre_loss.update(batch_losses['mel_pre'], batch_size)
            mel_post_loss.update(batch_losses['mel_pos'], batch_size)

            # compute mel cepstral distortion
            for j, (gen, ref, stop) in enumerate(
                    zip(post_pred_0, trg_mel, stop_pred_probs)):
                # cut the generated output hp.stop_frames after the first
                # position whose stop probability exceeds 0.5; keep the whole
                # output if there is no such position
                num_frames = gen.size()[1]
                stop_idxes = np.where(stop.cpu().numpy() > 0.5)[0]
                if len(stop_idxes) > 0:
                    stop_idx = min(np.min(stop_idxes) + hp.stop_frames,
                                   num_frames)
                else:
                    stop_idx = num_frames
                gen = gen[:, :stop_idx].detach().cpu().numpy()
                ref = ref[:, :trg_len[j]].detach().cpu().numpy()
                if hp.normalize_spectrogram:
                    gen = audio.denormalize_spectrogram(
                        gen, not hp.predict_linear)
                    ref = audio.denormalize_spectrogram(ref, True)
                if hp.predict_linear:
                    gen = audio.linear_to_mel(gen)
                # running mean over all utterances seen so far
                mcd = (mcd_count * mcd + audio.mel_cepstral_distorision(
                    gen, ref, 'dtw')) / (mcd_count + 1)
                mcd_count += 1

            # compute adversarial classifier accuracy
            if hp.reversal_classifier:
                input_mask = lengths_to_mask(src_len)
                # target class for every input position, filled in speaker by speaker
                trg_spkrs = torch.zeros_like(input_mask, dtype=torch.int64)
                for s in range(hp.speaker_number):
                    speaker_mask = (spkrs == s)
                    trg_spkrs[speaker_mask] = s
                # argmax is taken on the raw scores: softmax is monotonic, so
                # applying it first cannot change which class wins
                matches = (trg_spkrs == torch.argmax(spkrs_pred, dim=-1))
                # positions outside the input mask never count as correct
                matches[~input_mask] = False
                batch_cla = (torch.sum(matches).item() /
                             torch.sum(input_mask).item())
                # running mean over batches, reported to Logger.evaluation
                cla = (cla_count * cla + batch_cla) / (cla_count + 1)
                cla_count += 1
                # the meter does its own averaging, so it gets the accuracy
                # of this batch and not the running mean
                lang_class_acc.update(batch_cla, batch_size)

            # add batch losses to epoch losses
            for k, v in batch_losses.items():
                eval_losses[k] = v + eval_losses[k] if k in eval_losses else v

    # normalize loss per batch
    for k in eval_losses:
        eval_losses[k] /= num_batches

    # log evaluation; the tensors passed to the logger are those of the last batch
    progress.print(i)
    Logger.evaluation(epoch + 1, eval_losses, mcd, src_len, trg_len, src,
                      post_trg, post_pred, post_pred_0, stop_pred_probs,
                      stop_trg, alignment_0, cla)

    return sum(eval_losses.values())