Example 1
def train_epoch(epoch, global_step, model, loss_fun, metrics, config,
                max_label_length, device, tensorboard_logger,
                seq_len_scheduler, overfit_small_batch,
                starting_dataset_sampler_state, optimizers, lr_schedulers,
                do_validation, _valid_epoch, checkpoint_dir):
    """
    Training logic for an epoch

    :param epoch: Current training epoch.
    :return: A log that contains all information you want to save.

    Note:
        If you have additional information to record, for example:
            > additional_log = {"x": x, "y": y}
        merge it with log before return. i.e.
            > log = {**log, **additional_log}
            > return log

        The metrics in log must have the key 'metrics'.
    """
    model.train()
    tensorboard_logger.set_step(global_step, 'train')
    tr_data = config['dataset']['data_use']['train_with']
    _all_feats = config['dataset']['dataset_definition']['datasets'][tr_data][
        'features']
    _all_labs = config['dataset']['dataset_definition']['datasets'][tr_data][
        'labels']

    dataset = get_dataset(
        config['training']['dataset_type'],
        config['exp']['data_cache_root'],
        f"{tr_data}_{config['exp']['name']}",
        {feat: _all_feats[feat]
         for feat in config['dataset']['features_use']},
        {lab: _all_labs[lab]
         for lab in config['dataset']['labels_use']},
        config['training']['batching']['max_seq_length_train'],
        model.context_left,
        model.context_right,
        normalize_features=True,
        phoneme_dict=config['dataset']['dataset_definition']['phoneme_dict'],
        max_seq_len=seq_len_scheduler.get_seq_len(epoch),
        max_label_length=max_label_length,
        overfit_small_batch=overfit_small_batch)

    dataloader = KaldiDataLoader(
        dataset,
        config['training']['batching']['batch_size_train'],
        config["exp"]["n_gpu"] > 0,
        batch_ordering=model.batch_ordering,
        shuffle=True)

    if starting_dataset_sampler_state is not None:
        dataloader.sampler.load_state_dict(starting_dataset_sampler_state)
        starting_dataset_sampler_state = None

    assert len(dataset) >= config['training']['batching']['batch_size_train'], \
        f"Length of train dataset {len(dataset)} too small " \
        + f"for batch_size of {config['training']['batching']['batch_size_train']}"

    total_train_loss = 0
    total_train_metrics = {metric: 0 for metric in metrics}

    accumulated_train_losses = {}
    accumulated_train_metrics = {metric: 0 for metric in metrics}
    n_steps_chunk = 0
    last_train_logging = time.time()
    last_checkpoint = time.time()

    n_steps_this_epoch = 0

    with tqdm(disable=not logger.isEnabledFor(logging.INFO),
              total=len(dataloader),
              position=0) as pbar:
        pbar.set_description('T e:{} l: {} a: {}'.format(epoch, '-', '-'))
        pbar.update(dataloader.start())
        # TODO remove for epoch after 0

        for batch_idx, (_, inputs, targets) in enumerate(dataloader):
            global_step += 1
            n_steps_this_epoch += 1

            # TODO assert out.shape[1] >= lab_dnn.max() and lab_dnn.min() >= 0, \
            #     "lab_dnn max of {} is bigger than shape of output {} or min {} is smaller than 0" \
            #         .format(lab_dnn.max().cpu().numpy(), out.shape[1], lab_dnn.min().cpu().numpy())

            inputs = to_device(device, inputs)
            if "lab_phn" not in targets:
                targets = to_device(device, targets)

            for opti in optimizers.values():
                opti.zero_grad()

            with torch.autograd.detect_anomaly():
                # TODO check the performance penalty; detect_anomaly is meant
                #  for debugging and slows execution down
                output = model(inputs)
                loss = loss_fun(output, targets)
                loss["loss_final"].backward()

            if config['training']['clip_grad_norm'] > 0:
                trainable_params = filter(lambda p: p.requires_grad,
                                          model.parameters())
                torch.nn.utils.clip_grad_norm_(
                    trainable_params, config['training']['clip_grad_norm'])
            for opti in optimizers.values():
                opti.step()

            # detach so metrics etc. don't accumulate gradients
            inputs = detach(inputs)
            targets = detach(targets)
            loss = detach(loss)

            #### Logging ####
            n_steps_chunk += 1
            for _loss, loss_value in loss.items():
                if _loss not in accumulated_train_losses:
                    accumulated_train_losses[_loss] = 0
                accumulated_train_losses[_loss] += loss_value
            total_train_loss += loss["loss_final"].item()

            if config['exp']['compute_train_metrics']:
                """
                If the metric computation is fast like with plain accuracy on a discrete output, it is better to 
                perform it in a batched fashion on the GPU. 
                The alternative would be to copy the result (blocking) to the CPU and then compute 
                the metrics asynchronously (not batched).
                On the otherhand, if the metrics computation is not implemented on GPU or does not benefit from
                batching that much, it is preferred to copy the result (blocking) to the CPU and then compute 
                the metrics asynchronously (not batched).
                | main thread |  metrics thread |
                =================================
                     |        '
                forward pass  '
                     |        '        
                     +---> output -> comput metric -----+
                     |        '                         |
                forward pass  '                         +-> accumulate metrics
                     |        '                         |
                     +---> output -> comput metric -----+
                     |        '
                forward pass  '
                     |        '

                """

                _train_metrics = eval_metrics((output, targets), metrics)
                for metric, metric_value in _train_metrics.items():
                    accumulated_train_metrics[metric] += metric_value
                    total_train_metrics[metric] += metric_value

            pbar.set_description('T e:{} l: {:.4f}'.format(
                epoch, loss["loss_final"].item()))
            pbar.update()

            # Log training stats every 30 s, smoothed as the running average
            # over the chunk
            if (time.time() - last_train_logging) > 30:
                # TODO add flag for delayed logging
                last_train_logging = time.time()
                tensorboard_logger.set_step(global_step, 'train')
                for _loss, loss_value in accumulated_train_losses.items():
                    tensorboard_logger.add_scalar(_loss,
                                                  loss_value / n_steps_chunk)

                if config['exp']['compute_train_metrics']:
                    for metric, metric_value in accumulated_train_metrics.items():
                        tensorboard_logger.add_scalar(
                            metric, metric_value / n_steps_chunk)

                # most_recent_inputs = inputs
                # for feat_name in most_recent_inputs:
                #     if isinstance(most_recent_inputs[feat_name], dict) \
                #             and 'sequence_lengths' in most_recent_inputs[feat_name]:
                #         total_padding = torch.sum(
                #             (torch.ones_like(most_recent_inputs[feat_name]['sequence_lengths'])
                #              * most_recent_inputs[feat_name]['sequence_lengths'][0])
                #             - most_recent_inputs[feat_name]['sequence_lengths'])
                #         tensorboard_logger.add_scalar('total_padding_{}'.format(feat_name),
                #                                            total_padding.item())

                accumulated_train_losses = {}
                if config['exp']['compute_train_metrics']:
                    accumulated_train_metrics = {
                        metric: 0
                        for metric in metrics
                    }
                n_steps_chunk = 0

                if (time.time() - last_checkpoint
                    ) > config['exp']['checkpoint_interval_minutes'] * 60:
                    save_checkpoint(
                        epoch,
                        global_step,
                        model,
                        optimizers,
                        lr_schedulers,
                        seq_len_scheduler,
                        config,
                        checkpoint_dir,
                        dataset_sampler_state=dataloader.sampler.state_dict())

                    last_checkpoint = time.time()

                #### /Logging ####

            del inputs
            del targets

    if n_steps_this_epoch > 0:
        tensorboard_logger.set_step(epoch, 'train')
        tensorboard_logger.add_scalar('train_loss_avg',
                                      total_train_loss / n_steps_this_epoch)
        if config['exp']['compute_train_metrics']:
            for metric in total_train_metrics:
                tensorboard_logger.add_scalar(
                    metric + "_avg",
                    total_train_metrics[metric] / n_steps_this_epoch)

        # TODO add this flag to valid as well since ctcdecode is very slow,
        #  or compute it asynchronously
        if config['exp']['compute_train_metrics']:
            log = {
                'train_loss_avg': total_train_loss / n_steps_this_epoch,
                'train_metrics_avg': {
                    metric: total_train_metrics[metric] / n_steps_this_epoch
                    for metric in total_train_metrics
                }
            }
        else:
            log = {'train_loss_avg': total_train_loss / n_steps_this_epoch}
        if do_validation and (not overfit_small_batch or epoch == 1):
            valid_log = _valid_epoch(epoch)
            log.update(valid_log)
        else:
            log.update({'valid_loss': -1, 'valid_metrics': {}})
    else:
        raise RuntimeError("Training epoch hat 0 batches.")

    return log, starting_dataset_sampler_state, global_step
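
Below is a minimal, self-contained sketch of the asynchronous metric pattern described in the docstring inside train_epoch. The names (metric_consumer, dummy_metric) and the toy data are illustrative only; the actual code uses a multiprocessing Pool with metrics_accumulator workers, as in the later examples.

import queue
import threading

def metric_consumer(q, totals):
    # Drain (output, target) pairs until the None sentinel arrives,
    # accumulating the metric off the main thread.
    while True:
        item = q.get()
        if item is None:
            break
        output, target = item
        totals['dummy_metric'] += float(output == target)  # placeholder metric

q = queue.Queue(maxsize=8)
totals = {'dummy_metric': 0.0}
consumer = threading.Thread(target=metric_consumer, args=(q, totals))
consumer.start()

for output, target in [(1, 1), (0, 1), (2, 2)]:  # stand-ins for forward passes
    q.put((output, target))  # the main thread keeps training meanwhile

q.put(None)  # sentinel: tell the consumer to finish
consumer.join()
print(totals)  # {'dummy_metric': 2.0}
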
Example 2
def valid_epoch_sync_metrics(epoch, model, loss_fun, metrics, config,
                             max_label_length, device, tensorboard_logger):
    model.eval()

    valid_loss = 0
    accumulated_valid_metrics = {metric: 0 for metric in metrics}

    valid_data = config['dataset']['data_use']['valid_with']
    _all_feats = config['dataset']['dataset_definition']['datasets'][
        valid_data]['features']
    _all_labs = config['dataset']['dataset_definition']['datasets'][
        valid_data]['labels']
    dataset = get_dataset(
        config['training']['dataset_type'],
        config['exp']['data_cache_root'],
        f"{valid_data}_{config['exp']['name']}",
        {feat: _all_feats[feat]
         for feat in config['dataset']['features_use']},
        {lab: _all_labs[lab]
         for lab in config['dataset']['labels_use']},
        config['training']['batching']['max_seq_length_valid'],
        model.context_left,
        model.context_right,
        normalize_features=True,
        phoneme_dict=config['dataset']['dataset_definition']['phoneme_dict'],
        max_seq_len=config['training']['batching']['max_seq_length_valid'],
        max_label_length=max_label_length)

    if config['training']['batching']['batch_size_valid'] != 1:
        logger.warning(
            "batch_size_valid != 1; consider setting it to 1 to avoid padding zeros")
    dataloader = KaldiDataLoader(
        dataset,
        config['training']['batching']['batch_size_valid'],
        config["exp"]["n_gpu"] > 0,
        batch_ordering=model.batch_ordering)

    assert len(dataset) >= config['training']['batching']['batch_size_valid'], \
        f"Length of valid dataset {len(dataset)} too small " \
        + f"for batch_size of {config['training']['batching']['batch_size_valid']}"

    n_steps_this_epoch = 0
    with tqdm(disable=not logger.isEnabledFor(logging.INFO),
              total=len(dataloader)) as pbar:
        pbar.set_description('V e:{} l: {} '.format(epoch, '-'))
        for batch_idx, (sample_name, inputs, targets) in enumerate(dataloader):
            n_steps_this_epoch += 1

            inputs = to_device(device, inputs)
            if "lab_phn" not in targets:
                targets = to_device(device, targets)

            output = model(inputs)
            loss = loss_fun(output, targets)

            output = detach_cpu(output)
            targets = detach_cpu(targets)
            loss = detach_cpu(loss)

            #### Logging ####
            valid_loss += loss["loss_final"].item()
            _valid_metrics = eval_metrics((output, targets), metrics)
            for metric, metric_value in _valid_metrics.items():
                accumulated_valid_metrics[metric] += metric_value

            pbar.set_description('V e:{} l: {:.4f} '.format(
                epoch, loss["loss_final"].item()))
            pbar.update()

            do_plotting = True
            if do_plotting and n_steps_this_epoch in (1, 60):
                # raise NotImplementedError("TODO: add plots to tensorboard")
                output = output['out_phn']
                inputs = inputs["fbank"].numpy()
                _phoneme_dict = dataset.state.phoneme_dict
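                # ctcdecode expects a string vocabulary; since only the decoded
                # index sequences are needed here, arbitrary distinct unicode
                # characters stand in for the phoneme labels (the gap between
                # the two chr() ranges presumably skips unsuitable codepoints).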
                vocabulary_size = len(_phoneme_dict.reducedIdx2phoneme) + 1
                vocabulary = [
                    chr(c) for c in list(range(65, 65 + 58)) +
                    list(range(65 + 58 + 69, 65 + 58 + 69 + 500))
                ][:vocabulary_size]
                decoder = ctcdecode.CTCBeamDecoder(vocabulary,
                                                   log_probs_input=True,
                                                   beam_width=1)

                decoder_logits = output.permute(0, 2, 1)
                # We expect batch x seq x label_size
                beam_result, beam_scores, timesteps, out_seq_len = decoder.decode(
                    decoder_logits)

                _targets = []
                curr_l = 0
                for l in targets['target_sequence_lengths']:
                    _targets.append(targets['lab_phn'][curr_l:curr_l + l])
                    curr_l += l
                for i in range(len(inputs)):
                    _beam_result = beam_result[i, 0, :out_seq_len[i, 0]]
                    # logger.debug(sample_name)
                    result_decoded = [
                        _phoneme_dict.reducedIdx2phoneme[l.item() - 1]
                        for l in _beam_result
                    ]
                    result_decoded = " ".join(result_decoded)
                    logger.debug("RES: " + result_decoded)
                    # plot_phns = True
                    # if plot_phns:
                    label_decoded = " ".join([
                        _phoneme_dict.reducedIdx2phoneme[l.item() - 1]
                        for l in _targets[i]
                    ])
                    logger.debug("LAB: " + label_decoded)
                    text = sample_id_to_transcript(
                        sample_name[i],
                        "/mnt/data/datasets/LibriSpeech/dev-clean")
                    logger.debug("TXT: " + text)

                    # if plot_phns:
                    plot_alignment_spectrogram_ctc(
                        sample_name[i],
                        inputs[i],
                        (np.exp(output.numpy()[i]).T /
                         np.exp(output.numpy()[i]).sum(axis=1)).T,
                        _phoneme_dict,
                        label_decoded,
                        text,
                        result_decoded=result_decoded)
                    # else:
                    #     plot_alignment_spectrogram(sample_name, inputs["fbank"][i],
                    #                                (np.exp(output).T / np.exp(output).sum(axis=1)).T,
                    #                                _phoneme_dict, result_decoded=result_decoded)

            #### /Logging ####

    tensorboard_logger.set_step(epoch, 'valid')
    tensorboard_logger.add_scalar('valid_loss',
                                  valid_loss / n_steps_this_epoch)
    for metric in accumulated_valid_metrics:
        tensorboard_logger.add_scalar(
            metric, accumulated_valid_metrics[metric] / n_steps_this_epoch)

    return {
        'valid_loss': valid_loss / n_steps_this_epoch,
        'valid_metrics': {
            metric: accumulated_valid_metrics[metric] / n_steps_this_epoch
            for metric in accumulated_valid_metrics
        }
    }
Example 3
def evaluate(model,
             metrics,
             device,
             out_folder,
             exp_name,
             max_label_length,
             epoch,
             dataset_type,
             data_cache_root,
             test_with,
             all_feats_dict,
             features_use,
             all_labs_dict,
             labels_use,
             phoneme_dict,
             decoding_info,
             lab_graph_dir=None,
             tensorboard_logger=None):
    model.eval()
    batch_size = 1
    max_seq_length = -1

    accumulated_test_metrics = {metric: 0 for metric in metrics}

    test_data = test_with
    dataset = get_dataset(
        dataset_type,
        data_cache_root,
        f"{test_data}_{exp_name}",
        {feat: all_feats_dict[feat]
         for feat in features_use},
        {lab: all_labs_dict[lab]
         for lab in labels_use},
        max_seq_length,
        model.context_left,
        model.context_right,
        normalize_features=True,
        phoneme_dict=phoneme_dict,
        max_seq_len=max_seq_length,
        max_label_length=max_label_length)

    dataloader = KaldiDataLoader(dataset,
                                 batch_size,
                                 use_gpu=False,
                                 batch_ordering=model.batch_ordering)

    assert len(dataset) >= batch_size, \
        f"Length of test dataset {len(dataset)} too small " \
        + f"for batch_size of {batch_size}"

    n_steps_this_epoch = 0
    warned_size = False

    with Pool(os.cpu_count()) as pool:
        multip_process = Manager()
        metrics_q = multip_process.Queue(maxsize=os.cpu_count())
        # accumulated_test_metrics_future_list = pool.apply_async(metrics_accumulator, (metrics_q, metrics))
        accumulated_test_metrics_future_list = [
            pool.apply_async(metrics_accumulator, (metrics_q, metrics))
            for _ in range(os.cpu_count())
        ]
        with KaldiOutputWriter(out_folder, test_data, model.out_names,
                               epoch) as writer:
            with tqdm(disable=not logger.isEnabledFor(logging.INFO),
                      total=len(dataloader),
                      position=0) as pbar:
                pbar.set_description('E e:{}    '.format(epoch))
                for batch_idx, (sample_names, inputs,
                                targets) in enumerate(dataloader):
                    n_steps_this_epoch += 1

                    inputs = to_device(device, inputs)
                    if "lab_phn" not in targets:
                        targets = to_device(device, targets)

                    output = model(inputs)

                    output = detach_cpu(output)
                    targets = detach_cpu(targets)

                    #### Logging ####
                    metrics_q.put((output, targets))

                    pbar.set_description('E e:{} '.format(epoch))
                    pbar.update()
                    #### /Logging ####

                    warned_label = False
                    for output_label in output:
                        if output_label in model.out_names:
                            # squeeze that batch
                            output[output_label] = output[
                                output_label].squeeze(1)
                            # remove blank/padding 0th dim
                            # if config["arch"]["framewise_labels"] == "shuffled_frames":
                            out_save = output[output_label].data.cpu().numpy()
                            # else:
                            #     raise NotImplementedError("TODO make sure the right dimension is taken")
                            #     out_save = output[output_label][:, :-1].data.cpu().numpy()

                            if out_save.ndim == 3 and out_save.shape[0] == 1:
                                out_save = out_save.squeeze(0)

                            if dataset.state.dataset_type != DatasetType.SEQUENTIAL_APPENDED_CONTEXT \
                                    and dataset.state.dataset_type != DatasetType.SEQUENTIAL:
                                raise NotImplementedError(
                                    "TODO rescaling with prior")

                            # if config['dataset']['dataset_definition']['decoding']['normalize_posteriors']:
                            #     # read the label counts from the config
                            #     counts = config['dataset']['dataset_definition'] \
                            #         ['data_info']['labels']['lab_phn']['lab_count']
                            #     if out_save.shape[-1] == len(counts) - 1:
                            #         if not warned_size:
                            #             logger.info(
                            #                 f"Counts length is {len(counts)} but output"
                            #                 + f" has size {out_save.shape[-1]}."
                            #                 + f" Assuming that counts is 1-indexed")
                            #             warned_size = True
                            #         counts = counts[1:]
                            #     # if ctc:
                            #     #     blank_scale = 1.0
                            #     #     # TODO try different blank_scales 4.0 5.0 6.0 7.0
                            #     #     counts[0] /= blank_scale
                            #     #     # for i in range(1, 8):
                            #     #     #     counts[i] /= noise_scale  # TODO try noise_scale for SIL, SPN etc.
                            #
                            #     # normalize by the log-prior estimated from counts
                            #     prior = np.log(counts / np.sum(counts))
                            #     out_save = out_save - prior

                            # shape == NC
                            assert len(out_save.shape) == 2
                            assert len(sample_names) == 1
                            writer.write_mat(output_label, out_save.squeeze(),
                                             sample_names[0])

                        else:
                            if not warned_label:
                                logger.debug(
                                    "Skipping saving forward for decoding for key {}"
                                    .format(output_label))
                                warned_label = True

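            # One None sentinel per accumulator worker signals end-of-stream;
            # each worker then returns its partial metric sums, merged below.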
            for _ in accumulated_test_metrics_future_list:
                metrics_q.put(None)
            for _accumulated_test_metrics in accumulated_test_metrics_future_list:
                _accumulated_test_metrics = _accumulated_test_metrics.get()
                for metric, metric_value in _accumulated_test_metrics.items():
                    accumulated_test_metrics[metric] += metric_value


    test_metrics = {
        metric: accumulated_test_metrics[metric] / len(dataloader)
        for metric in accumulated_test_metrics
    }
    if tensorboard_logger is not None:
        tensorboard_logger.set_step(epoch, 'eval')
        for metric, metric_value in test_metrics.items():
            tensorboard_logger.add_scalar(metric, metric_value)

    # decoding_results = []
    #### DECODING ####
    # for out_lab in model.out_names:
    out_lab = model.out_names[0]  # TODO query from model or sth

    # forward_data_lst = config['data_use']['test_with'] #TODO multiple forward sets
    # forward_data_lst = [config['dataset']['data_use']['test_with']]
    # forward_dec_outs = config['test'][out_lab]['require_decoding']

    # for data in forward_data_lst:
    logger.debug('Decoding {} output {}'.format(test_with, out_lab))

    if out_lab == 'out_cd':
        _label = 'lab_cd'
    elif out_lab == 'out_phn':
        _label = 'lab_phn'
    else:
        raise NotImplementedError(out_lab)

    lab_field = all_labs_dict[_label]

    out_folder = os.path.abspath(out_folder)
    out_dec_folder = '{}/decode_{}_{}'.format(out_folder, test_with, out_lab)

    # logits_test_clean_100_ep006_out_phn.ark
    files_dec_list = glob(
        f'{out_folder}/exp_files/logits_{test_with}_ep*_{out_lab}.ark')

    if lab_graph_dir is None:
        lab_graph_dir = os.path.abspath(lab_field['lab_graph'])
    if _label == 'lab_phn':
        decode_ctc(data=os.path.abspath(lab_field['lab_data_folder']),
                   graphdir=lab_graph_dir,
                   out_folder=out_dec_folder,
                   featstrings=files_dec_list)
    elif _label == 'lab_cd':
        decode_ce(**decoding_info,
                  alidir=os.path.abspath(lab_field['label_folder']),
                  data=os.path.abspath(lab_field['lab_data_folder']),
                  graphdir=lab_graph_dir,
                  out_folder=out_dec_folder,
                  featstrings=files_dec_list)
    else:
        raise ValueError(_label)

    decoding_results = best_wer(out_dec_folder, decoding_info['scoring_type'])
    logger.info(decoding_results)

    if tensorboard_logger is not None:
        tensorboard_logger.add_text("WER results", str(decoding_results))

    # TODO plotting curves

    return {'test_metrics': test_metrics, "decoding_results": decoding_results}
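
For reference, the posterior normalization that the commented-out block in evaluate sketches is the usual hybrid-ASR conversion from posteriors to scaled likelihoods: subtract the log-prior, estimated from per-label frame counts, from the log-posteriors. A minimal sketch under that assumption; normalize_posteriors and the toy numbers are illustrative, not part of the codebase:

import numpy as np

def normalize_posteriors(log_posteriors, counts):
    # Subtract the log-prior estimated from label counts, turning
    # log-posteriors p(label | frame) into scaled log-likelihoods.
    counts = np.asarray(counts, dtype=np.float64)
    log_prior = np.log(counts / counts.sum())
    return log_posteriors - log_prior

# toy example: 2 frames, 3 labels
log_post = np.log(np.array([[0.7, 0.2, 0.1],
                            [0.1, 0.3, 0.6]]))
print(normalize_posteriors(log_post, counts=[700, 200, 100]))
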
Example 4
def valid_epoch_async_metrics(epoch, model, loss_fun, metrics, config,
                              max_label_length, device, tensorboard_logger):
    """
    Validate after training an epoch
    :return: A log that contains information about validation
    Note:
        The validation metrics in log must have the key 'val_metrics'.
    """
    model.eval()

    valid_loss = 0
    accumulated_valid_metrics = {metric: 0 for metric in metrics}

    valid_data = config['dataset']['data_use']['valid_with']
    _all_feats = config['dataset']['dataset_definition']['datasets'][
        valid_data]['features']
    _all_labs = config['dataset']['dataset_definition']['datasets'][
        valid_data]['labels']
    dataset = get_dataset(
        config['training']['dataset_type'],
        config['exp']['data_cache_root'],
        f"{valid_data}_{config['exp']['name']}",
        {feat: _all_feats[feat]
         for feat in config['dataset']['features_use']},
        {lab: _all_labs[lab]
         for lab in config['dataset']['labels_use']},
        config['training']['batching']['max_seq_length_valid'],
        model.context_left,
        model.context_right,
        normalize_features=True,
        phoneme_dict=config['dataset']['dataset_definition']['phoneme_dict'],
        max_seq_len=config['training']['batching']['max_seq_length_valid'],
        max_label_length=max_label_length)

    dataloader = KaldiDataLoader(
        dataset,
        config['training']['batching']['batch_size_valid'],
        config["exp"]["n_gpu"] > 0,
        batch_ordering=model.batch_ordering)

    assert len(dataset) >= config['training']['batching']['batch_size_valid'], \
        f"Length of valid dataset {len(dataset)} too small " \
        + f"for batch_size of {config['training']['batching']['batch_size_valid']}"

    n_steps_this_epoch = 0

    with Pool(os.cpu_count()) as pool:
        multip_process = Manager()
        metrics_q = multip_process.Queue(maxsize=os.cpu_count())
        # accumulated_valid_metrics_future_list = pool.apply_async(metrics_accumulator, (metrics_q, metrics))
        accumulated_valid_metrics_future_list = [
            pool.apply_async(metrics_accumulator, (metrics_q, metrics))
            for _ in range(os.cpu_count())
        ]
        with tqdm(disable=not logger.isEnabledFor(logging.INFO),
                  total=len(dataloader)) as pbar:
            pbar.set_description('V e:{} l: {} '.format(epoch, '-'))
            for batch_idx, (_, inputs, targets) in enumerate(dataloader):
                n_steps_this_epoch += 1

                inputs = to_device(device, inputs)
                if "lab_phn" not in targets:
                    targets = to_device(device, targets)

                output = model(inputs)
                loss = loss_fun(output, targets)

                output = detach_cpu(output)
                targets = detach_cpu(targets)
                loss = detach_cpu(loss)

                #### Logging ####
                valid_loss += loss["loss_final"].item()
                metrics_q.put((output, targets))
                # _valid_metrics = eval_metrics((output, targets), metrics)
                # for metric, metric_value in _valid_metrics.items():
                #     accumulated_valid_metrics[metric] += metric_value

                pbar.set_description('V e:{} l: {:.4f} '.format(
                    epoch, loss["loss_final"].item()))
                pbar.update()
                #### /Logging ####
        for _ in accumulated_valid_metrics_future_list:
            metrics_q.put(None)
        for _accumulated_valid_metrics in accumulated_valid_metrics_future_list:
            _accumulated_valid_metrics = _accumulated_valid_metrics.get()
            for metric, metric_value in _accumulated_valid_metrics.items():
                accumulated_valid_metrics[metric] += metric_value

    tensorboard_logger.set_step(epoch, 'valid')
    tensorboard_logger.add_scalar('valid_loss',
                                  valid_loss / n_steps_this_epoch)
    logger.info(f'valid_loss: {valid_loss / n_steps_this_epoch}')
    for metric in accumulated_valid_metrics:
        tensorboard_logger.add_scalar(
            metric, accumulated_valid_metrics[metric] / n_steps_this_epoch)
        logger.info(
            f'{metric}: {accumulated_valid_metrics[metric] / n_steps_this_epoch}'
        )

    return {
        'valid_loss': valid_loss / n_steps_this_epoch,
        'valid_metrics': {
            metric: accumulated_valid_metrics[metric] / n_steps_this_epoch
            for metric in accumulated_valid_metrics
        }
    }
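
metrics_accumulator itself is not shown in these examples. Judging from its call sites (worker processes fed a queue of (output, targets) pairs, shut down with one None sentinel each, and returning a dict of per-metric sums), a plausible reconstruction looks as follows; this is inferred from usage, not the original implementation:

def metrics_accumulator(metrics_q, metrics):
    # Consume (output, targets) pairs until the None sentinel, accumulate
    # each metric's value with eval_metrics, and return the partial sums
    # to the parent process via the pool's AsyncResult.
    accumulated = {metric: 0 for metric in metrics}
    while True:
        item = metrics_q.get()
        if item is None:
            break
        output, targets = item
        batch_metrics = eval_metrics((output, targets), metrics)
        for metric, metric_value in batch_metrics.items():
            accumulated[metric] += metric_value
    return accumulated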