Example #1
    def enable_streaming(self, secondary_model=None, return_string_parts=True):
        """
        Enables the DanSpeech system to perform speech recognition on a stream of audio data.

        :param secondary_model: A DanSpeech to perform speech recognition when a buffer of audio data has been build,
        hence this model can be given to provide better final transcriptions. If None, then the system will use the
        streaming model for the final output.
        """
        # Streaming declarations
        self.full_output = []
        self.iterating_transcript = ""
        if secondary_model:
            self.secondary_model = secondary_model.to(self.device)
            self.secondary_model.eval()
        else:
            self.secondary_model = None

        self.spectrograms = []

        # This is needed for streaming decoding
        self.greedy_decoder = GreedyDecoder(labels=self.labels, blank_index=self.labels.index('_'))

        # Use the InferenceSpectrogramAudioParser for streaming input
        self.audio_parser = InferenceSpectrogramAudioParser(audio_config=self.audio_config)

        self.string_parts = bool(return_string_parts)
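
A minimal usage sketch (not part of the library source), assuming recognizer is an already-configured DanSpeechRecognizer as defined in Example #4 and that audio arrives as a sequence of NumPy chunks; the chunk source microphone_chunks is hypothetical.

# Hypothetical streaming sketch: feed chunks through streaming_transcribe (Example #4).
recognizer.enable_streaming(secondary_model=None, return_string_parts=True)

chunks = list(microphone_chunks())  # assumed iterable of NumPy audio chunks
for idx, chunk in enumerate(chunks):
    part = recognizer.streaming_transcribe(chunk,
                                           is_last=(idx == len(chunks) - 1),
                                           is_first=(idx == 0))
    if part:
        print(part, end=" ", flush=True)

recognizer.disable_streaming()
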
Example #2
    def update_decoder(self, lm=None, alpha=None, beta=None, labels=None, beam_width=None):

        update = False

        # If neither the LM nor the decoder has been set, default to greedy decoding
        if not self.lm and not self.decoder:
            update = True
            self.lm = "greedy"

        if lm and self.lm != lm:
            update = True
            self.lm = lm

        if alpha and self.alpha != alpha:
            update = True
            self.alpha = alpha

        if beta and self.beta != beta:
            update = True
            self.beta = beta

        if labels and labels != self.labels:
            update = True
            self.labels = labels

        if beam_width and beam_width != self.beam_width:
            update = True
            self.beam_width = beam_width

        if update:
            if self.lm != "greedy":
                self.decoder = BeamCTCDecoder(labels=self.labels, lm_path=self.lm,
                                              alpha=self.alpha, beta=self.beta,
                                              beam_width=self.beam_width, num_processes=6, cutoff_prob=1.0,
                                              cutoff_top_n=40, blank_index=self.labels.index('_'))

            else:
                self.decoder = GreedyDecoder(labels=self.labels, blank_index=self.labels.index('_'))
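
A brief usage sketch, not taken from the library: switching the decoder of a DanSpeechRecognizer (Example #4) between beam search with a language model and plain greedy decoding. The LM path is a placeholder and assumes a KenLM binary compatible with BeamCTCDecoder.

# Hypothetical sketch: enable beam-search decoding with an external LM.
recognizer.update_decoder(lm="/path/to/danish_lm.klm",  # placeholder LM path
                          alpha=1.3, beta=0.2, beam_width=64)

# Passing the "greedy" sentinel switches back to a GreedyDecoder.
recognizer.update_decoder(lm="greedy")
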
Example #3
def _train_model(model_id=None,
                 train_data_paths=None,
                 train_data_weights=None,
                 validation_data_path=None,
                 epochs=20,
                 stored_model=None,
                 save_dir=None,
                 use_tensorboard=True,
                 augmented_training=False,
                 batch_size=32,
                 num_workers=6,
                 cuda=False,
                 lr=3e-4,
                 momentum=0.9,
                 weight_decay=1e-5,
                 max_norm=400,
                 learning_anneal=1.0,
                 context=20,
                 finetune=False,
                 continue_train=False,
                 train_new=False,
                 num_freeze_layers=None,
                 rnn_type='gru',
                 conv_layers=2,
                 rnn_hidden_layers=5,
                 rnn_hidden_size=800,
                 bidirectional=True,
                 distributed=False,
                 gpu_rank=None,
                 dist_backend='nccl',
                 rank=0,
                 dist_url='tcp://127.0.0.1:1550',
                 world_size=1,
                 danspeech_model=None,
                 augmentations=None,
                 sampling_rate=16000,
                 window="hamming",
                 window_stride=0.01,
                 window_size=0.02,
                 save_every_epoch=0):
    # set training device
    if augmentations is None:
        augmentations = []
    main_proc = rank == 0
    print("Is main proc: {}".format(main_proc))
    if cuda and not torch.cuda.is_available():
        warnings.warn("Specified GPU training but cuda is not available, falling back to CPU training...",
                      CudaNotAvailable)
        cuda = False

    device = torch.device("cuda" if cuda else "cpu")

    # prepare directories for storage and logging.
    if not save_dir:
        warnings.warn(
            "You did not specify a directory for saving the trained model.\n"
            "Defaulting to ~/.danspeech/custom/ directory.",
            NoModelSaveDirSpecified)

        save_dir = os.path.join(os.path.expanduser('~'), '.danspeech/models/')

    os.makedirs(save_dir, exist_ok=True)

    if not model_id:
        warnings.warn(
            "You did not specify a name for the trained model.\n"
            "Defaulting to danish_speaking_panda.pth", NoModelNameSpecified)

        model_id = "danish_speaking_panda"

    assert train_data_paths, "please specify path(s) to a valid directory with training data"
    assert validation_data_path, "please specify path to a valid directory with validation data"

    if main_proc and use_tensorboard:
        logging_process = True
        tensorboard_logger = TensorBoardLogger(model_id, save_dir)
    else:
        logging_process = False
        if main_proc:
            warnings.warn(
                "Tensorboard logging is disabled, so the training process will not be logged.",
                NoLoggingDirSpecified)

    # handle distributed processing
    if distributed:
        import torch.distributed as dist
        from torch.utils.data.distributed import DistributedSampler
        from torch.nn.parallel import DistributedDataParallel

        if gpu_rank is not None:
            torch.cuda.set_device(int(gpu_rank))

        dist.init_process_group(backend=dist_backend,
                                init_method=dist_url,
                                world_size=world_size,
                                rank=rank)

    # initialize training metrics
    loss_results = torch.Tensor(epochs)
    cer_results = torch.Tensor(epochs)
    wer_results = torch.Tensor(epochs)

    # initialize helper variables
    avg_loss = 0
    start_epoch = 0
    start_iter = 0

    # load and initialize model metrics based on wrapper function
    if train_new:
        with open(os.path.dirname(os.path.realpath(__file__)) + '/labels.json',
                  "r",
                  encoding="utf-8") as label_file:
            labels = str(''.join(json.load(label_file)))

        # changing the default audio config is highly experimental, make changes with care and expect vastly
        # different results compared to baseline
        audio_conf = get_audio_config(normalize=True,
                                      sample_rate=sampling_rate,
                                      window=window,
                                      window_stride=window_stride,
                                      window_size=window_size)

        rnn_type = rnn_type.lower()
        assert rnn_type in ["lstm", "rnn", "gru"], "rnn_type should be either lstm, rnn or gru"
        assert conv_layers in [1, 2, 3], "conv_layers must be set to either 1, 2 or 3"
        model = DeepSpeech(
            model_name=model_id,
            conv_layers=conv_layers,
            rnn_hidden_size=rnn_hidden_size,
            rnn_layers=rnn_hidden_layers,
            labels=labels,
            rnn_type=supported_rnns.get(rnn_type),
            audio_conf=audio_conf,
            bidirectional=bidirectional,
            streaming_inference_model=False,
            # streaming inference should always be disabled during training
            context=context)
        model = model.to(device)
        parameters = model.parameters()
        optimizer = torch.optim.SGD(parameters,
                                    lr=lr,
                                    momentum=momentum,
                                    nesterov=True,
                                    weight_decay=weight_decay)

    if finetune:
        if not stored_model and danspeech_model is None:
            raise ArgumentMissingForOption(
                "If you want to finetune, please provide either a DanSpeech model or the absolute path "
                "to a trained pytorch model package as the stored_model argument"
            )
        else:
            if danspeech_model:
                print("Using DanSpeech model: {}".format(
                    danspeech_model.model_name))
                model = danspeech_model
            else:
                print("Loading checkpoint model %s" % stored_model)
                package = torch.load(stored_model,
                                     map_location=lambda storage, loc: storage)
                model = DeepSpeech.load_model_package(package)

            if num_freeze_layers:
                # freezing layers can lead to unexpected results, use with caution
                model.freeze_layers(num_freeze_layers)

            model = model.to(device)

            parameters = model.parameters()
            optimizer = torch.optim.SGD(parameters,
                                        lr=lr,
                                        momentum=momentum,
                                        nesterov=True,
                                        weight_decay=weight_decay)

    if continue_train:
        # continue_training wrapper
        if not stored_model:
            raise ArgumentMissingForOption(
                "If you want to continue training, please support a package with previous"
                "training information or use the finetune option instead")
        else:
            print("Loading checkpoint model %s" % stored_model)
            package = torch.load(stored_model,
                                 map_location=lambda storage, loc: storage)
            model = DeepSpeech.load_model_package(package)
            model = model.to(device)
            # load stored training information
            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=lr,
                                        momentum=momentum,
                                        nesterov=True,
                                        weight_decay=weight_decay)
            optim_state = package['optim_dict']
            optimizer.load_state_dict(optim_state)
            start_epoch = int(
                package['epoch']) + 1  # Index start at 0 for training

            print("Last successfully trained Epoch: {0}".format(start_epoch))

            start_epoch += 1
            start_iter = 0

            avg_loss = int(package.get('avg_loss', 0))
            loss_results_ = package['loss_results']
            cer_results_ = package['cer_results']
            wer_results_ = package['wer_results']

            # ToDo: Make depend on the epoch from the package
            previous_epochs = loss_results_.size()[0]
            print("Previously set to run for: {0} epochs".format(
                previous_epochs))

            loss_results[0:previous_epochs] = loss_results_
            wer_results[0:previous_epochs] = wer_results_
            cer_results[0:previous_epochs] = cer_results_

            if logging_process:
                tensorboard_logger.load_previous_values(start_epoch, package)

    # initialize DanSpeech augmenter
    if augmented_training:
        augmenter = DanSpeechAugmenter(
            sampling_rate=model.audio_conf["sample_rate"],
            augmentation_list=augmentations)
    else:
        augmenter = None

    # initialize audio parser and dataset
    # audio parsers
    training_parser = SpectrogramAudioParser(audio_config=model.audio_conf,
                                             data_augmenter=augmenter)
    validation_parser = SpectrogramAudioParser(audio_config=model.audio_conf,
                                               data_augmenter=None)

    # instantiate data-sets
    multi_data_set = False
    if len(train_data_paths) > 1:
        assert len(train_data_paths) == len(
            train_data_weights), "Must provide weights for each dataset"
        multi_data_set = True
        training_set = DanSpeechMultiDataset(train_data_paths,
                                             train_data_weights,
                                             labels=model.labels,
                                             audio_parser=training_parser)
    else:
        training_set = DanSpeechDataset(train_data_paths[0],
                                        labels=model.labels,
                                        audio_parser=training_parser)

    validation_set = DanSpeechDataset(validation_data_path,
                                      labels=model.labels,
                                      audio_parser=validation_parser)

    # initialize batch loaders
    if not distributed:
        # initialize batch loaders for single GPU or CPU training
        if multi_data_set:
            train_batch_loader = MultiDatasetBatchDataLoader(
                training_set,
                batch_size=batch_size,
                num_workers=num_workers,
                pin_memory=True)
        else:
            train_batch_loader = BatchDataLoader(training_set,
                                                 batch_size=batch_size,
                                                 num_workers=num_workers,
                                                 shuffle=True,
                                                 pin_memory=True)
        validation_batch_loader = BatchDataLoader(validation_set,
                                                  batch_size=batch_size,
                                                  num_workers=num_workers,
                                                  shuffle=False)
    else:
        # initialize batch loaders for distributed training on multiple GPUs
        if multi_data_set:
            train_sampler = DistributedWeightedSamplerCustom(
                training_set, num_replicas=world_size, rank=rank)
        else:
            train_sampler = DistributedSamplerCustom(training_set,
                                                     num_replicas=world_size,
                                                     rank=rank)

        train_batch_loader = BatchDataLoader(training_set,
                                             batch_size=batch_size,
                                             num_workers=num_workers,
                                             sampler=train_sampler,
                                             pin_memory=True)

        validation_sampler = DistributedSampler(validation_set,
                                                num_replicas=world_size,
                                                rank=rank)

        validation_batch_loader = BatchDataLoader(validation_set,
                                                  batch_size=batch_size,
                                                  num_workers=num_workers,
                                                  sampler=validation_sampler)

    decoder = GreedyDecoder(model.labels)
    criterion = CTCLoss()
    best_wer = None
    model = model.to(device)

    if distributed:
        model = DistributedDataParallel(model, device_ids=[int(gpu_rank)])

    # verbose training outputs during progress
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    print(model)
    print("Initializations complete, starting training pass on model: %s \n" %
          model_id)
    print("Number of parameters: %d \n" % DeepSpeech.get_param_size(model))
    try:
        for epoch in range(start_epoch, epochs):
            if distributed and epoch != 0:
                # distributed sampling: keep the samplers in sync across GPUs for each epoch
                train_sampler.set_epoch(epoch)

            print('started training epoch %d' % (epoch + 1))
            model.train()

            # timings per epoch
            end = time.time()
            start_epoch_time = time.time()
            num_updates = len(train_batch_loader)

            # per epoch training loop, iterate over all mini-batches in the training set
            for i, (data) in enumerate(train_batch_loader, start=start_iter):
                if i == num_updates:
                    break

                # grab and prepare a sample for a training pass
                inputs, targets, input_percentages, target_sizes = data
                input_sizes = input_percentages.mul_(int(inputs.size(3))).int()

                # measure data loading time; this gives an indication of how many workers are required for
                # latency-free training.
                data_time.update(time.time() - end)

                # parse data and perform a training pass
                inputs = inputs.to(device)

                # compute the CTC-loss and average over mini-batch
                out, output_sizes = model(inputs, input_sizes)
                out = out.transpose(0, 1)
                float_out = out.float()
                loss = criterion(float_out, targets, output_sizes,
                                 target_sizes).to(device)
                loss = loss / inputs.size(0)

                # check for diverging losses
                if distributed:
                    loss_value = reduce_tensor(loss, world_size).item()
                else:
                    loss_value = loss.item()

                if loss_value == float("inf") or loss_value == -float("inf"):
                    warnings.warn(
                        "received an inf loss, setting loss value to 0",
                        InfiniteLossReturned)
                    loss_value = 0

                # update average loss, and loss tensor
                avg_loss += loss_value
                losses.update(loss_value, inputs.size(0))

                # compute gradients and back-propagate errors
                optimizer.zero_grad()
                loss.backward()

                # avoid exploding gradients by clip_grad_norm, defaults to 400
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               max_norm=max_norm)

                # stochastic gradient descent step
                optimizer.step()

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                print('Epoch: [{0}/{1}][{2}/{3}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                          (epoch + 1), (epochs), (i + 1),
                          len(train_batch_loader),
                          batch_time=batch_time,
                          data_time=data_time,
                          loss=losses))

                del loss, out, float_out

            # report epoch summaries and prepare validation run
            avg_loss /= len(train_batch_loader)
            loss_results[epoch] = avg_loss
            epoch_time = time.time() - start_epoch_time
            print('Training Summary Epoch: [{0}]\t'
                  'Time taken (s): {epoch_time:.0f}\t'
                  'Average Loss {loss:.3f}\t'.format(epoch + 1,
                                                     epoch_time=epoch_time,
                                                     loss=avg_loss))

            # prepare validation specific parameters, and set model ready for evaluation
            total_cer, total_wer = 0, 0
            model.eval()
            with torch.no_grad():
                for i, (data) in tqdm(enumerate(validation_batch_loader),
                                      total=len(validation_batch_loader)):
                    inputs, targets, input_percentages, target_sizes = data
                    input_sizes = input_percentages.mul_(int(
                        inputs.size(3))).int()

                    # unflatten targets
                    split_targets = []
                    offset = 0
                    targets = targets.numpy()
                    for size in target_sizes:
                        split_targets.append(targets[offset:offset + size])
                        offset += size

                    inputs = inputs.to(device)
                    out, output_sizes = model(inputs, input_sizes)
                    decoded_output, _ = decoder.decode(out, output_sizes)
                    target_strings = decoder.convert_to_strings(split_targets)

                    # compute accuracy metrics
                    wer, cer = 0, 0
                    for x in range(len(target_strings)):
                        transcript, reference = decoded_output[x][
                            0], target_strings[x][0]
                        wer += decoder.wer(transcript, reference) / float(
                            len(reference.split()))
                        cer += decoder.cer(transcript, reference) / float(
                            len(reference))

                    total_wer += wer
                    total_cer += cer
                    del out

            if distributed:
                # sums tensor across all devices if distributed training is enabled
                total_wer_tensor = torch.tensor(total_wer).to(device)
                total_wer_tensor = sum_tensor(total_wer_tensor)
                total_wer = total_wer_tensor.item()

                total_cer_tensor = torch.tensor(total_cer).to(device)
                total_cer_tensor = sum_tensor(total_cer_tensor)
                total_cer = total_cer_tensor.item()

                del total_wer_tensor, total_cer_tensor

            # compute average metrics for the validation pass
            avg_wer_epoch = (total_wer /
                             len(validation_batch_loader.dataset)) * 100
            avg_cer_epoch = (total_cer /
                             len(validation_batch_loader.dataset)) * 100

            # append metrics for logging
            loss_results[epoch] = avg_loss
            wer_results[epoch] = avg_wer_epoch
            cer_results[epoch] = avg_cer_epoch

            # log metrics for tensorboard
            if logging_process:
                logging_values = {
                    "loss_results": loss_results,
                    "wer": wer_results,
                    "cer": cer_results
                }
                tensorboard_logger.update(epoch, logging_values)

            # print validation metrics summary
            print('Validation Summary Epoch: [{0}]\t'
                  'Average WER {wer:.3f}\t'
                  'Average CER {cer:.3f}\t'.format(epoch + 1,
                                                   wer=avg_wer_epoch,
                                                   cer=avg_cer_epoch))

            # check if the model is uni- or bidirectional; only unidirectional models support streaming inference
            streaming_inference_model = not bidirectional

            # save model if it has the highest recorded performance on validation.
            if main_proc and (best_wer is None or best_wer > avg_wer_epoch):
                model_path = save_dir + model_id + '.pth'
                print("Found better validated model, saving to %s" %
                      model_path)
                torch.save(
                    serialize(model,
                              optimizer=optimizer,
                              epoch=epoch,
                              loss_results=loss_results,
                              wer_results=wer_results,
                              cer_results=cer_results,
                              distributed=distributed,
                              streaming_model=streaming_inference_model,
                              context=context), model_path)
                best_wer = avg_wer_epoch

            # optionally save a checkpoint every save_every_epoch epochs
            if main_proc and save_every_epoch != 0 and (
                    epoch + 1) % save_every_epoch == 0:
                model_epochs_save_path = save_dir + model_id + '_epoch_{}'.format(
                    epoch + 1) + '.pth'
                print(
                    "Saving since save_every_epoch option has been given to %s"
                    % model_epochs_save_path)
                torch.save(
                    serialize(model,
                              optimizer=optimizer,
                              epoch=epoch,
                              loss_results=loss_results,
                              wer_results=wer_results,
                              cer_results=cer_results,
                              distributed=distributed,
                              streaming_model=streaming_inference_model,
                              context=context), model_epochs_save_path)

            # anneal the learning rate once per epoch
            if learning_anneal != 1.0:
                for g in optimizer.param_groups:
                    g['lr'] = g['lr'] / learning_anneal
                print('Learning rate annealed to: {lr:.6f}'.format(
                    lr=g['lr']))

            # reset the accumulated loss for the next epoch
            avg_loss = 0

            # reset start iteration for next epoch
            start_iter = 0

    except KeyboardInterrupt:
        print('Successfully exited training and stopped all processes.')
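
A hedged sketch of how _train_model might be invoked; the leading underscore suggests it is normally reached through a public training wrapper, and the paths, weights and model id below are placeholders.

# Hypothetical invocation sketch; all paths and the model id are placeholders.
_train_model(model_id="my_danish_model",
             train_data_paths=["/data/train_set_a", "/data/train_set_b"],
             train_data_weights=[0.7, 0.3],  # one weight per training data path
             validation_data_path="/data/validation_set",
             epochs=20,
             batch_size=32,
             cuda=torch.cuda.is_available(),
             train_new=True,
             save_dir="/models/danspeech_custom/")
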
Example #4
class DanSpeechRecognizer(object):

    def __init__(self, model_name=None, lm_name=None,
                 alpha=1.3, beta=0.2, with_gpu=False,
                 beam_width=64):

        self.device = torch.device("cuda" if with_gpu else "cpu")
        print("Using device: {0}".format(self.device))

        # Init model if given
        if model_name:
            self.update_model(model_name)
        else:
            self.model = None
            self.model_name = None
            self.labels = None
            self.audio_config = None
            self.audio_parser = None

        # Always set alpha and beta
        self.alpha = alpha
        self.beta = beta
        self.beam_width = beam_width

        # Init LM if given
        if lm_name:
            if not self.model:
                raise ModelNotInitialized("Trying to initialize LM without also choosing a DanSpeech model.")
            else:
                self.update_decoder(lm_name)
                self.lm = lm_name
        else:
            self.lm = None
            self.decoder = None

    def update_model(self, model):
        self.audio_config = model.audio_conf
        self.model = model.to(self.device)
        self.model.eval()
        self.audio_parser = SpectrogramAudioParser(self.audio_config)

        self.labels = self.model.labels
        # When updating the model, always update the decoder since the labels may have changed
        self.update_decoder(labels=self.labels)

    def update_decoder(self, lm=None, alpha=None, beta=None, labels=None, beam_width=None):

        update = False

        # If neither the LM nor the decoder has been set, default to greedy decoding
        if not self.lm and not self.decoder:
            update = True
            self.lm = "greedy"

        if lm and self.lm != lm:
            update = True
            self.lm = lm

        if alpha and self.alpha != alpha:
            update = True
            self.alpha = alpha

        if beta and self.beta != beta:
            update = True
            self.beta = beta

        if labels and labels != self.labels:
            update = True
            self.labels = labels

        if beam_width and beam_width != self.beam_width:
            update = True
            self.beam_width = beam_width

        if update:
            if self.lm != "greedy":
                self.decoder = BeamCTCDecoder(labels=self.labels, lm_path=self.lm,
                                              alpha=self.alpha, beta=self.beta,
                                              beam_width=self.beam_width, num_processes=6, cutoff_prob=1.0,
                                              cutoff_top_n=40, blank_index=self.labels.index('_'))

            else:
                self.decoder = GreedyDecoder(labels=self.labels, blank_index=self.labels.index('_'))


    def enable_streaming(self, secondary_model=None, return_string_parts=True):
        """
        Enables the DanSpeech system to perform speech recognition on a stream of audio data.

        :param secondary_model: A DanSpeech to perform speech recognition when a buffer of audio data has been build,
        hence this model can be given to provide better final transcriptions. If None, then the system will use the
        streaming model for the final output.
        """
        # Streaming declarations
        self.full_output = []
        self.iterating_transcript = ""
        if secondary_model:
            self.secondary_model = secondary_model.to(self.device)
            self.secondary_model.eval()
        else:
            self.secondary_model = None

        self.spectrograms = []

        # This is needed for streaming decoding
        self.greedy_decoder = GreedyDecoder(labels=self.labels, blank_index=self.labels.index('_'))

        # Use the InferenceSpectrogramAudioParser for streaming input
        self.audio_parser = InferenceSpectrogramAudioParser(audio_config=self.audio_config)

        self.string_parts = bool(return_string_parts)

    def disable_streaming(self, keep_secondary_model=False):
        self.audio_parser = SpectrogramAudioParser(self.audio_config)
        self.greedy_decoder = None
        self.reset_streaming_params()
        self.string_parts = False

        if not keep_secondary_model:
            self.secondary_model = None


    def reset_streaming_params(self):
        self.iterating_transcript = ""
        self.full_output = []
        self.spectrograms = []


    def streaming_transcribe(self, recording, is_last, is_first):
        recording = self.audio_parser.parse_audio(recording, is_last)
        out = ""
        # This can happen on the last chunk of a recording, when there are too few samples
        # left to generate a spectrogram
        if len(recording) != 0:
            if self.secondary_model:
                self.spectrograms.append(recording)

            # Convert the recording to a batch of size one for the model
            recording = recording.view(1, 1, recording.size(0), recording.size(1))
            recording = recording.to(self.device)

            out = self.model(recording, is_first, is_last)

            # The first pass returns an empty string, since more context is needed before the first prediction
            if is_first:
                return ""

            self.full_output.append(out)

            # Decode the output with greedy decoding
            decoded_out, _ = self.greedy_decoder.decode(out)
            transcript = decoded_out[0][0]

            # Hack to avoid duplicating a character across chunk boundaries
            if self.iterating_transcript and transcript and self.iterating_transcript[-1] == transcript[0]:
                self.iterating_transcript = self.iterating_transcript + transcript[1:]
                transcript = transcript[1:]
            else:
                self.iterating_transcript += transcript

            if self.string_parts:
                out = transcript
            else:
                out = self.iterating_transcript

        if is_last:
            # If something was actually detected (requires at least two characters)
            if len(self.iterating_transcript) > 1:

                # If we use secondary model, pass full output through the model
                if self.secondary_model:

                    final = torch.cat(self.spectrograms, dim=1)
                    self.spectrograms = []

                    final = final.view(1, 1, final.size(0), final.size(1))
                    final = final.to(self.device)
                    input_sizes = torch.IntTensor([final.size(3)]).int()
                    out, _ = self.secondary_model(final, input_sizes)
                    decoded_out, _ = self.decoder.decode(out)
                    decoded_out = decoded_out[0][0]

                    self.reset_streaming_params()
                    return decoded_out

                else:
                    # if no secondary model, check whether we need to decode with LM or not
                    if self.lm != "greedy":
                        final_out = torch.cat(self.full_output, dim=1)
                        decoded_out, _ = self.decoder.decode(final_out)
                        decoded_out = decoded_out[0][0]
                        self.reset_streaming_params()
                        return decoded_out
                    else:
                        out = self.iterating_transcript
                        self.reset_streaming_params()
                        return out
            else:
                return ""

        return out

    def transcribe(self, recording, show_all=False):
        recording = self.audio_parser.parse_audio(recording)
        recording = recording.view(1, 1, recording.size(0), recording.size(1))
        recording = recording.to(self.device)
        input_sizes = torch.IntTensor([recording.size(3)]).int()
        out, output_sizes = self.model(recording, input_sizes)
        decoded_output, _ = self.decoder.decode(out, output_sizes)
        if show_all:
            if self.lm == 'greedy':
                warnings.warn("You are trying to get all beams but no LM has been instantiated.",
                              NoLmInstantiatedWarning)
            return decoded_output[0]
        else:
            return decoded_output[0][0]
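
A sketch of wiring DanSpeechRecognizer up for offline transcription, assuming the acoustic model is loaded with DeepSpeech.load_model as in Examples #5 and #6; the model path, LM path and the load_audio helper are placeholders.

# Hypothetical usage sketch; paths and the audio loading helper are placeholders.
model = DeepSpeech.load_model("/models/danish_acoustic_model.pth")

recognizer = DanSpeechRecognizer(with_gpu=torch.cuda.is_available(),
                                 alpha=1.3, beta=0.2, beam_width=64)
recognizer.update_model(model)
recognizer.update_decoder(lm="/models/danish_lm.klm")  # omit to keep greedy decoding

recording = load_audio("recording.wav")  # assumed helper returning a 1-D float array
print(recognizer.transcribe(recording))
print(recognizer.transcribe(recording, show_all=True))  # all beams when an LM is loaded
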
Example #5
def test_model(model_path,
               data_path,
               decoder="greedy",
               cuda=False,
               batch_size=96,
               num_workers=4,
               lm_path=None,
               alpha=1.3,
               beta=0.4,
               cutoff_top_n=40,
               cutoff_prob=1.0,
               beam_width=64,
               lm_workers=4,
               output_path=None,
               verbose=False):
    torch.set_grad_enabled(False)
    model = DeepSpeech.load_model(model_path)
    device = torch.device("cuda" if cuda else "cpu")
    model = model.to(device)
    model.eval()

    # -- Save the model name and the language model name for when the scores are saved
    model_name = model_path.split("/")[-1]
    data_name = data_path.split("/")[-1]

    if lm_path is None:
        lm = "None"
    else:
        lm = lm_path.split("/")[-1]

    print(model_name)

    # -- Tue: Implemented to make beam search possible
    if decoder == "beam":
        from danspeech.deepspeech.decoder import BeamCTCDecoder

        decode_type = decoder
        decoder = BeamCTCDecoder(model.labels,
                                 lm_path=lm_path,
                                 alpha=alpha,
                                 beta=beta,
                                 cutoff_top_n=cutoff_top_n,
                                 cutoff_prob=cutoff_prob,
                                 beam_width=beam_width,
                                 num_processes=lm_workers)
        greedy_decoder = GreedyDecoder(model.labels,
                                       blank_index=model.labels.index('_'))

    elif decoder == "greedy":

        decode_type = decoder
        decoder = GreedyDecoder(model.labels,
                                blank_index=model.labels.index('_'))
    else:
        raise AttributeError(
            "please specify a valid decoder, DanSpeech currently supports [greedy, beam]"
        )

    test_parser = SpectrogramAudioParser(audio_config=model.audio_conf,
                                         data_augmenter=None)
    test_dataset = DanSpeechDataset(data_path,
                                    'test.csv',
                                    labels=model.labels,
                                    audio_parser=test_parser)
    test_batch_loader = BatchDataLoader(test_dataset,
                                        batch_size=batch_size,
                                        num_workers=num_workers,
                                        shuffle=False)

    total_cer, total_wer, num_tokens, num_chars = 0, 0, 0, 0
    output_data = []
    testScore = []  # -- Tue: Create an empty list to save each score
    for i, (data) in tqdm(enumerate(test_batch_loader),
                          total=len(test_batch_loader)):
        inputs, targets, input_percentages, target_sizes = data
        input_sizes = input_percentages.mul_(int(inputs.size(3))).int()

        split_targets = []
        offset = 0
        targets = targets.numpy()
        for size in target_sizes:
            split_targets.append(targets[offset:offset + size])
            offset += size

        inputs = inputs.to(device)
        out, output_sizes = model(inputs, input_sizes)

        if decoder is None:
            output_data.append((out.numpy(), output_sizes.numpy()))
            continue

        decoded_output, _ = decoder.decode(out.data, output_sizes.data)

        # -- Tue: Implemented to make beam search possible
        if decode_type == "greedy":
            target_strings = decoder.convert_to_strings(split_targets)
        elif decode_type == "beam":
            target_strings = greedy_decoder.convert_to_strings(split_targets)

        for x in range(len(target_strings)):
            transcript, reference = decoded_output[x][0], target_strings[x][0]
            wer_inst = decoder.wer(transcript, reference)
            cer_inst = decoder.cer(transcript, reference)
            total_wer += wer_inst
            total_cer += cer_inst
            num_tokens += len(reference.split())
            num_chars += len(reference)

            # -- Tue: Appending save score to make statistics possible
            testScore.append([
                data_name, decode_type, lm, model_name,
                reference.lower(),
                float(wer_inst) / len(reference.split()),
                float(cer_inst) / len(reference),
                transcript.lower()
            ])

            if verbose:
                print("Ref:", reference.lower())
                print("Hyp:", transcript.lower())
                print("WER:",
                      float(wer_inst) / len(reference.split()), "CER:",
                      float(cer_inst) / len(reference), "\n")

    if decoder is not None:
        wer = float(total_wer) / num_tokens
        cer = float(total_cer) / num_chars

        print('Test Summary \t'
              'Average WER {wer:.3f}\t'
              'Average CER {cer:.3f}\t'.format(wer=wer * 100, cer=cer * 100))

        #-- Tue: Saving score for testing
        csvSaverTest(testScore)

        #return wer, cer
    elif output_path is not None:
        np.save(output_path, output_data)
    else:
        pass

    return
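
A sketch of evaluating a stored acoustic model with this test_model variant using beam-search decoding; the model, data and LM paths are placeholders, and csvSaverTest is assumed to persist the per-utterance scores as in the code above.

# Hypothetical evaluation sketch with beam-search decoding; all paths are placeholders.
test_model(model_path="/models/danish_acoustic_model.pth",
           data_path="/data/test_set",
           decoder="beam",
           lm_path="/models/danish_lm.klm",
           alpha=1.3,
           beta=0.4,
           beam_width=64,
           cuda=torch.cuda.is_available(),
           verbose=True)
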
Example #6
def test_model(model_path,
               data_path,
               decoder="greedy",
               cuda=False,
               batch_size=96,
               num_workers=4,
               lm_path=None,
               alpha=1.3,
               beta=0.4,
               cutoff_top_n=40,
               cutoff_prob=1.0,
               beam_width=64,
               lm_workers=4,
               verbose=False,
               transcriptions_out_file=None):
    torch.set_grad_enabled(False)
    model = DeepSpeech.load_model(model_path)
    device = torch.device("cuda" if cuda else "cpu")
    model = model.to(device)
    model.eval()

    if decoder == "beam":
        from danspeech.deepspeech.decoder import BeamCTCDecoder

        decoder = BeamCTCDecoder(model.labels,
                                 lm_path=lm_path,
                                 alpha=alpha,
                                 beta=beta,
                                 cutoff_top_n=cutoff_top_n,
                                 cutoff_prob=cutoff_prob,
                                 beam_width=beam_width,
                                 num_processes=lm_workers)
    elif decoder == "greedy":
        decoder = GreedyDecoder(model.labels,
                                blank_index=model.labels.index('_'))
    else:
        raise AttributeError(
            "please specify a valid decoder, DanSpeech currently supports [greedy, beam]"
        )

    target_decoder = GreedyDecoder(model.labels,
                                   blank_index=model.labels.index('_'))

    test_parser = SpectrogramAudioParser(audio_config=model.audio_conf,
                                         data_augmenter=None)
    test_dataset = DanSpeechDataset(data_path,
                                    labels=model.labels,
                                    audio_parser=test_parser)
    test_batch_loader = BatchDataLoader(test_dataset,
                                        batch_size=batch_size,
                                        num_workers=num_workers,
                                        shuffle=False)

    total_cer, total_wer, num_tokens, num_chars = 0, 0, 0, 0

    if transcriptions_out_file:
        out_f = open(transcriptions_out_file, "w", encoding="utf-8")
        out_f.write("reference,transcription,WER,CER\n")

    for i, (data) in tqdm(enumerate(test_batch_loader),
                          total=len(test_batch_loader)):
        inputs, targets, input_percentages, target_sizes = data
        input_sizes = input_percentages.mul_(int(inputs.size(3))).int()

        split_targets = []
        offset = 0
        targets = targets.numpy()
        for size in target_sizes:
            split_targets.append(targets[offset:offset + size])
            offset += size

        inputs = inputs.to(device)
        out, output_sizes = model(inputs, input_sizes)

        decoded_output, _ = decoder.decode(out.data, output_sizes.data)
        target_strings = target_decoder.convert_to_strings(split_targets)
        for x in range(len(target_strings)):
            transcript, reference = decoded_output[x][0], target_strings[x][0]
            wer_inst = decoder.wer(transcript, reference)
            cer_inst = decoder.cer(transcript, reference)
            total_wer += wer_inst
            total_cer += cer_inst
            num_tokens += len(reference.split())
            num_chars += len(reference)
            if verbose:
                print("Ref:", reference.lower())
                print("Hyp:", transcript.lower())
                print("WER:",
                      float(wer_inst) / len(reference.split()), "CER:",
                      float(cer_inst) / len(reference), "\n")

            if transcriptions_out_file:
                out_f.write("{0},{1},{2},{3}\n".format(
                    reference.lower(), transcript.lower(),
                    float(wer_inst) / len(reference.split()),
                    float(cer_inst) / len(reference)))

    if transcriptions_out_file:
        out_f.close()

    if decoder is not None:
        wer = float(total_wer) / num_tokens
        cer = float(total_cer) / num_chars

        print('Test Summary \t'
              'Average WER {wer:.3f}\t'
              'Average CER {cer:.3f}\t'.format(wer=wer * 100, cer=cer * 100))
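
This variant differs from Example #5 mainly by writing per-utterance references, transcriptions and error rates to a CSV file instead of calling csvSaverTest. A short sketch with greedy decoding; the paths are placeholders.

# Hypothetical evaluation sketch with greedy decoding and a transcription dump; paths are placeholders.
test_model(model_path="/models/danish_acoustic_model.pth",
           data_path="/data/test_set",
           decoder="greedy",
           cuda=torch.cuda.is_available(),
           transcriptions_out_file="test_transcriptions.csv")
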
Example #7
def _train_model(model_id=None,
                 training_set=None,
                 validation_set=None,
                 root_dir=None,
                 in_memory=False,
                 epochs=20,
                 stored_model=None,
                 model_save_dir=None,
                 tensorboard_log_dir=None,
                 augmented_training=False,
                 batch_size=8,
                 num_workers=6,
                 cuda=False,
                 lr=3e-4,
                 momentum=0.9,
                 weight_decay=1e-5,
                 max_norm=400,
                 context=20,
                 continue_train=False,
                 finetune=False,
                 train_new=False,
                 num_freeze_layers=None,
                 rnn_type='gru',
                 conv_layers=2,
                 rnn_hidden_layers=5,
                 rnn_hidden_size=800,
                 bidirectional=True,
                 distributed=False,
                 gpu_rank=None,
                 dist_backend='nccl',
                 rank=0,
                 dist_url='tcp://127.0.0.1:1550',
                 world_size=1,
                 augment_w_specaug=False,
                 augmentation_list=None,
                 augment_parameters=None,
                 augment_prob_dir=None,
                 score_ID=None):

    # Guard against the mutable default argument pitfall
    if augmentation_list is None:
        augmentation_list = []

    # Load scores
    scoresDict = loadLogger()

    # Add cross-val fold
    counter = 1
    if score_ID is not None:
        ID = score_ID
    else:
        ID = model_id

    while ID in list(scoresDict.keys()):
        if ID[-2] != "_":
            ID = ID + "_" + str(counter)
        else:
            ID = ID[:-1] + str(counter)
        counter += 1

    # Add the current time stamp
    currTime = time.strftime("%d/%m/%Y, %H:%M:%S")

    # Set values
    scoresDict[ID] = {
        "Time": currTime,
        "Augmentations": augmentation_list,
        "Avg_WER": 0,
        "Avg_CER": 0
    }

    # -- set training device
    main_proc = True
    device = torch.device("cuda" if cuda else "cpu")

    # -- prepare directories for storage and logging.
    if not model_save_dir:
        warnings.warn(
            "You did not specify a directory for saving the trained model.\n"
            "Defaulting to ~/.danspeech/custom/ directory.",
            NoModelSaveDirSpecified)

        model_save_dir = os.path.join(os.path.expanduser('~'),
                                      '.danspeech/models/')

    os.makedirs(model_save_dir, exist_ok=True)

    if not model_id:
        warnings.warn(
            "You did not specify a name for the trained model.\n"
            "Defaulting to danish_speaking_panda.pth", NoModelNameSpecified)

        model_id = "danish_speaking_panda"

    if main_proc and tensorboard_log_dir:
        logging_process = True
        tensorboard_logger = TensorBoardLogger(model_id, tensorboard_log_dir)
    else:
        logging_process = False
        warnings.warn(
            "You did not specify a directory for logging training process. Training process will not be logged.",
            NoLoggingDirSpecified)

    # -- handle distributed processing
    if distributed:
        import torch.distributed as dist
        from torch.utils.data.distributed import DistributedSampler
        from apex.parallel import DistributedDataParallel

        if gpu_rank is not None:
            torch.cuda.set_device(int(gpu_rank))

        dist.init_process_group(backend=dist_backend,
                                init_method=dist_url,
                                world_size=world_size,
                                rank=rank)

    # -- initialize training metrics
    loss_results = torch.Tensor(epochs)
    cer_results = torch.Tensor(epochs)
    wer_results = torch.Tensor(epochs)

    # -- initialize helper variables
    avg_loss = 0
    start_epoch = 0
    start_iter = 0

    # -- load and initialize model metrics based on wrapper function
    #if train_new:
    #    with open(os.path.dirname(os.path.realpath(__file__)) + '/labels.json', "r", encoding="utf-8") as label_file:
    #        labels = str(''.join(json.load(label_file)))
    #
    #    # -- changing the default audio config is highly experimental, make changes with care and expect vastly
    #    # -- different results compared to baseline
    #    audio_conf = get_default_audio_config()
    #
    #    rnn_type = rnn_type.lower()
    #    conv_layers = conv_layers
    #    assert rnn_type in ["lstm", "rnn", "gru"], "rnn_type should be either lstm, rnn or gru"
    #    assert conv_layers in [1, 2, 3], "conv_layers must be set to either 1, 2 or 3"
    #    model = DeepSpeech(model_name=model_id,
    #                       conv_layers=conv_layers,
    #                       rnn_hidden_size=rnn_hidden_size,
    #                       rnn_layers=rnn_hidden_layers,
    #                       labels=labels,
    #                       rnn_type=supported_rnns.get(rnn_type),
    #                       audio_conf=audio_conf,
    #                       bidirectional=bidirectional,
    #                       streaming_inference_model=False,  # -- streaming inference should always be disabled during training
    #                       context=context)
    #    parameters = model.parameters()
    #    optimizer = torch.optim.SGD(parameters, lr=lr,
    #                                momentum=momentum, nesterov=True, weight_decay=1e-5)

    if finetune:
        if not stored_model:
            raise ArgumentMissingForOption(
                "If you want to finetune, please provide the absolute path"
                "to a trained pytorch model object as the stored_model argument"
            )
        else:
            print("Loading checkpoint model %s" % stored_model)
            package = torch.load(stored_model,
                                 map_location=lambda storage, loc: storage)
            model = DeepSpeech.load_model_package(package)

            if num_freeze_layers:
                # -- freezing layers can lead to unexpected results, use with caution
                print("Freezing of layers initiated")
                model.freeze_layers(num_freeze_layers)

            parameters = model.parameters()
            optimizer = torch.optim.SGD(parameters,
                                        lr=lr,
                                        momentum=momentum,
                                        nesterov=True,
                                        weight_decay=weight_decay)

            if logging_process:
                tensorboard_logger.load_previous_values(start_epoch, package)

    #if continue_train:
    #    # -- continue_training wrapper
    #    if not stored_model:
    #        raise ArgumentMissingForOption("If you want to continue training, please support a package with previous"
    #                                       "training information or use the finetune option instead")
    #    else:
    #        print("Loading checkpoint model %s" % stored_model)
    #        package = torch.load(stored_model, map_location=lambda storage, loc: storage)
    #        model = DeepSpeech.load_model_package(package)
    #        # -- load stored training information
    #        optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum,
    #                                    nesterov=True, weight_decay=1e-5)
    #        optim_state = package['optim_dict']
    #        optimizer.load_state_dict(optim_state)
    #        start_epoch = int(package['epoch']) + 1  # -- Index start at 0 for training
    #
    #        print("Last successfully trained Epoch: {0}".format(start_epoch))
    #
    #        start_epoch += 1
    #        start_iter = 0
    #
    #        avg_loss = int(package.get('avg_loss', 0))
    #        loss_results_ = package['loss_results']
    #        cer_results_ = package['cer_results']
    #        wer_results_ = package['wer_results']
    #
    #        # ToDo: Make depend on the epoch from the package
    #        previous_epochs = loss_results_.size()[0]
    #        print("Previously set to run for: {0} epochs".format(previous_epochs))
    #
    #        loss_results[0:previous_epochs] = loss_results_
    #        wer_results[0:previous_epochs] = cer_results_
    #        cer_results[0:previous_epochs] = wer_results_
    #
    #        if logging_process:
    #            tensorboard_logger.load_previous_values(start_epoch, package)

#     if augment_w_specaug:
    training_parser = AugmenterAudioParser(audio_config=model.audio_conf,
                                           augmentation_list=augmentation_list,
                                           augment_args=augment_parameters,
                                           augment_prob_dir=augment_prob_dir)
    validation_parser = SpectrogramAudioParser(audio_config=model.audio_conf,
                                               data_augmenter=None)

    #     else:

    #         # -- initialize DanSpeech augmenter
    #         if augmented_training:
    #             print("Augmentations started")
    #             #augmenter = DanSpeechAugmenter(sampling_rate=model.audio_conf["sampling_rate"])
    #             augmenter = DanSpeechAugmenter(sampling_rate=model.audio_conf["sample_rate"])

    #         else:
    #             augmenter = None

    #         # -- initialize audio parser and dataset
    #         # -- audio parsers
    #         training_parser = SpectrogramAudioParser(audio_config=model.audio_conf, data_augmenter=augmenter)
    #         validation_parser = SpectrogramAudioParser(audio_config=model.audio_conf, data_augmenter=None)

    # -- instantiate data-sets
    training_set = DanSpeechDataset(root_dir,
                                    training_set,
                                    labels=model.labels,
                                    audio_parser=training_parser,
                                    in_memory=in_memory)
    validation_set = DanSpeechDataset(root_dir,
                                      validation_set,
                                      labels=model.labels,
                                      audio_parser=validation_parser,
                                      in_memory=in_memory)

    # -- Tue: extract metadata for the validation set, such as file names
    meta = validation_set.meta

    print("")
    # -- initialize batch loaders
    if not distributed:
        # -- initialize batch loaders for single GPU or CPU training
        train_batch_loader = BatchDataLoader(training_set,
                                             batch_size=batch_size,
                                             num_workers=num_workers,
                                             shuffle=True,
                                             pin_memory=True)
        validation_batch_loader = BatchDataLoader(validation_set,
                                                  batch_size=batch_size,
                                                  num_workers=num_workers,
                                                  shuffle=False)
    else:
        # -- initialize batch loaders for distributed training on multiple GPUs
        train_sampler = DistributedSampler(training_set,
                                           num_replicas=world_size,
                                           rank=rank)
        train_batch_loader = BatchDataLoader(training_set,
                                             batch_size=batch_size,
                                             num_workers=num_workers,
                                             sampler=train_sampler,
                                             pin_memory=True)

        validation_sampler = DistributedSampler(validation_set,
                                                num_replicas=world_size,
                                                rank=rank)
        validation_batch_loader = BatchDataLoader(validation_set,
                                                  batch_size=batch_size,
                                                  num_workers=num_workers,
                                                  sampler=validation_sampler)

        model = DistributedDataParallel(model)

    decoder = GreedyDecoder(model.labels)
    criterion = CTCLoss()
    model = model.to(device)
    best_wer = None

    # -- verbose training outputs during progress
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    print(model)
    print("Initializations complete, starting training pass on model: %s \n" %
          model_id)
    print("Number of parameters: %d \n" % DeepSpeech.get_param_size(model))

    try:
        for epoch in range(start_epoch, epochs):
            if distributed and epoch != 0:
                # -- distributed sampling: keep the samplers in sync across GPUs for each epoch
                train_sampler.set_epoch(epoch)

            print('started training epoch %d' % (epoch + 1))
            model.train()

            # -- timings per epoch
            end = time.time()
            start_epoch_time = time.time()
            num_updates = len(train_batch_loader)

            # -- per epoch training loop, iterate over all mini-batches in the training set
            for i, (data) in enumerate(train_batch_loader, start=start_iter):
                if i == num_updates:
                    break

                # -- grab and prepare a sample for a training pass
                inputs, targets, input_percentages, target_sizes = data
                input_sizes = input_percentages.mul_(int(inputs.size(3))).int()

                # -- measure data loading time; this indicates how many workers are needed
                # -- to keep training from stalling on data loading
                data_time.update(time.time() - end)

                # -- parse data and perform a training pass
                inputs = inputs.to(device)

                # -- compute the CTC-loss and average over mini-batch
                out, output_sizes = model(inputs, input_sizes)
                out = out.transpose(0, 1)
                float_out = out.float()
                loss = criterion(float_out, targets, output_sizes,
                                 target_sizes).to(device)
                loss = loss / inputs.size(0)
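                # Shape sketch (illustrative, for a mini-batch of N utterances, T output
                # time steps and an alphabet of C labels): the network returns (N, T, C),
                # the transpose above produces the (T, N, C) layout CTCLoss expects, and
                # dividing by inputs.size(0) turns the summed batch loss into a
                # per-utterance average.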

                # -- check for diverging losses
                if distributed:
                    loss_value = reduce_tensor(loss, world_size).item()
                else:
                    loss_value = loss.item()

                if loss_value == float("inf") or loss_value == -float("inf"):
                    warnings.warn(
                        "received an inf loss, setting loss value to 0",
                        InfiniteLossReturned)
                    loss_value = 0

                # -- update average loss, and loss tensor
                avg_loss += loss_value
                losses.update(loss_value, inputs.size(0))

                # -- compute gradients and back-propagate errors
                optimizer.zero_grad()
                loss.backward()

                # -- avoid exploding gradients by clip_grad_norm, defaults to 400
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               max_norm=max_norm)
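                # clip_grad_norm_ rescales all gradients in place whenever their combined
                # L2 norm exceeds max_norm. A minimal sketch of the idea (not the exact
                # torch implementation):
                #
                #   total_norm = sqrt(sum(p.grad.norm(2) ** 2 for p in model.parameters()))
                #   if total_norm > max_norm:
                #       scale = max_norm / (total_norm + 1e-6)
                #       for p in model.parameters():
                #           p.grad.mul_(scale)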

                # -- stochastic gradient descent step
                optimizer.step()

                # -- measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                #print('Epoch: [{0}/{1}][{2}/{3}]\t'
                #      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                #      'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                #      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                #    (epoch + 1), (epochs), (i + 1), len(train_batch_loader), batch_time=batch_time,
                #    data_time=data_time, loss=losses))

                del loss, out, float_out

            # -- report epoch summaries and prepare validation run
            avg_loss /= len(train_batch_loader)
            loss_results[epoch] = avg_loss
            epoch_time = time.time() - start_epoch_time
            print('Training Summary Epoch: [{0}]\t'
                  'Time taken (s): {epoch_time:.0f}\t'
                  'Average Loss {loss:.3f}\t'.format(epoch + 1,
                                                     epoch_time=epoch_time,
                                                     loss=avg_loss))

            # -- prepare validation specific parameters, and set model ready for evaluation
            total_cer, total_wer = 0, 0
            model.eval()
            # -- Tue: reset transcriptScore here so it only holds the current epoch's results
            transcriptScore = []
            with torch.no_grad():
                for i, (data) in tqdm(enumerate(validation_batch_loader),
                                      total=len(validation_batch_loader)):
                    inputs, targets, input_percentages, target_sizes = data
                    input_sizes = input_percentages.mul_(int(
                        inputs.size(3))).int()

                    # -- unflatten targets
                    split_targets = []
                    offset = 0
                    targets = targets.numpy()
                    for size in target_sizes:
                        split_targets.append(targets[offset:offset + size])
                        offset += size
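                    # The collate function flattens all label sequences into a single 1-D
                    # tensor, with target_sizes recording how many labels belong to each
                    # utterance. Illustrative example: targets = [5, 12, 3, 7, 7] with
                    # target_sizes = [2, 3] splits into [5, 12] and [3, 7, 7].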

                    inputs = inputs.to(device)
                    out, output_sizes = model(inputs, input_sizes)
                    decoded_output, _ = decoder.decode(out, output_sizes)
                    target_strings = decoder.convert_to_strings(split_targets)

                    # -- compute accuracy metrics
                    wer, cer = 0, 0

                    for x in range(len(target_strings)):

                        transcript, reference = decoded_output[x][
                            0], target_strings[x][0]
                        wer = decoder.wer(transcript, reference) / float(
                            len(reference.split()))
                        cer = decoder.cer(transcript, reference) / float(
                            len(reference))
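                        # decoder.wer/cer return edit distances, so dividing by the number
                        # of reference words/characters yields per-utterance error rates.
                        # Illustrative example: 2 word errors against an 8-word reference
                        # gives wer = 0.25 for this utterance.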

                        total_wer += wer  # Tue: accumulated here so per-file WER can be saved
                        total_cer += cer  # Tue: accumulated here so per-file CER can be saved

                        #filename,_ = meta[n] -- Does not choose the correct filename

                        transcriptScore.append([
                            currTime, score_ID, reference, wer, cer,
                            transcript, augmentation_list, epoch + 1, epochs
                        ])

                    del out

            if distributed:
                # -- sums tensor across all devices if distributed training is enabled
                total_wer_tensor = torch.tensor(total_wer).to(device)
                total_wer_tensor = sum_tensor(total_wer_tensor)
                total_wer = total_wer_tensor.item()

                total_cer_tensor = torch.tensor(total_cer).to(device)
                total_cer_tensor = sum_tensor(total_cer_tensor)
                total_cer = total_cer_tensor.item()

                del total_wer_tensor, total_cer_tensor
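                # each rank only validated its own shard, so summing the per-rank totals
                # here and dividing by the full dataset size below yields the global average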

            # -- compute average metrics for the validation pass
            avg_wer_epoch = (total_wer /
                             len(validation_batch_loader.dataset)) * 100
            avg_cer_epoch = (total_cer /
                             len(validation_batch_loader.dataset)) * 100
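            # total_wer/total_cer are sums of per-utterance error rates, so dividing by the
            # number of validation utterances and multiplying by 100 gives the epoch's
            # average WER/CER in percent.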

            # -- append metrics for logging
            loss_results[epoch] = avg_loss
            wer_results[epoch] = avg_wer_epoch
            cer_results[epoch] = avg_cer_epoch

            # -- Johan: Logging

            # -- accumulate validation scores after the first 5 epochs; the sum is
            # -- averaged over the remaining (epochs - 5) epochs once training ends
            if epoch >= 5:
                scoresDict[ID]['Avg_WER'] += avg_wer_epoch
                scoresDict[ID]['Avg_CER'] += avg_cer_epoch

            # -- log metrics for tensorboard
            if logging_process:
                logging_values = {
                    "loss_results": loss_results,
                    "wer": avg_wer_epoch,
                    "cer": avg_cer_epoch
                }
                tensorboard_logger.update(epoch, logging_values)

            # -- print validation metrics summary
            print('Validation Summary Epoch: [{0}]\t'
                  'Average WER {wer:.3f}\t'
                  'Average CER {cer:.3f}\t'.format(epoch + 1,
                                                   wer=avg_wer_epoch,
                                                   cer=avg_cer_epoch))

            # -- save the model if it achieves the best (lowest) WER recorded on validation so far
            # OBS! changed from best_wer > wer to best_wer > avg_wer_epoch - Tue
            if main_proc and ((best_wer is None) or (best_wer > avg_wer_epoch)):
                model_path = model_save_dir + model_id + '.pth'
                best_transcript = transcriptScore.copy()  # -- Tue: keep the best epoch's transcripts for the .csv

                # -- only a unidirectional model can be used for streaming inference,
                # -- since a bidirectional RNN needs the full utterance before decoding
                streaming_inference_model = not bidirectional
                print("Found better validated model, saving to %s" %
                      model_path)
                torch.save(
                    serialize(model,
                              optimizer=optimizer,
                              epoch=epoch,
                              loss_results=loss_results,
                              wer_results=wer_results,
                              cer_results=cer_results,
                              distributed=distributed,
                              streaming_model=streaming_inference_model,
                              context=context), model_path)
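                # the serialized package bundles the model weights with the optimizer state
                # and the per-epoch loss/WER/CER histories, so a later run can resume
                # training (e.g. via continue_train) or load the model for inference.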

                best_wer = avg_wer_epoch  # OBS! changed from wer to avg_wer_epoch - Tue

            # -- reset the running loss and start iteration for the next epoch
            avg_loss = 0
            start_iter = 0

        # -- average the accumulated scores over the epochs past the warm-up and persist them
        if epoch > epochs - 6:
            scoresDict[ID]['Avg_WER'] = scoresDict[ID]['Avg_WER'] / (epochs - 5)
            scoresDict[ID]['Avg_CER'] = scoresDict[ID]['Avg_CER'] / (epochs - 5)
            saveScores(scoresDict)

        if epoch == (epochs - 1):
            # -- the transcript scores are saved after the final epoch
            # -- (or on a KeyboardInterrupt below)
            csvSaver(best_transcript)

    except KeyboardInterrupt:
        print('Successfully exited training and stopped all processes.')
        # -- only save transcripts if at least one validation pass has completed
        if best_transcript is not None:
            csvSaver(best_transcript)