Beispiel #1
0
class ASRConverter(Converter):
    ''' ASR preprocess: turn data.json entries into feature/target batches. '''
    def __init__(self, config):
        """Build the converter from ``config['data']['task']``.

        :param dict config: full configuration. ``config['data']['task']`` must
            have ``type == TASK_SET['asr']`` and a ``src`` section providing
            ``subsampling_factor`` and ``preprocess_conf``.
        :raises AssertionError: if the task type is not ASR (unless Python is
            run with ``-O``, which strips asserts).
        """
        super().__init__(config)
        taskconf = self.config['data']['task']
        assert taskconf['type'] == TASK_SET['asr']
        self.subsampling_factor = taskconf['src']['subsampling_factor']
        self.preprocess_conf = taskconf['src']['preprocess_conf']
        # The task type is passed as the loader mode (presumably 'asr' or
        # 'tts', as in ESPnet); outputs (targets) are always loaded.
        self.load_inputs_and_targets = LoadInputsAndTargets(
            mode=taskconf['type'],
            load_output=True,
            preprocess_conf=self.preprocess_conf)

    #pylint: disable=arguments-differ
    #pylint: disable=too-many-branches
    def transform(self, batch):
        """Load input features, targets and utterance ids for one batch.

        Adapted from ESPnet's ``LoadInputsAndTargets`` (espnet/utils/io_utils.py,
        https://github.com/espnet/espnet/blob/master/espnet/utils/io_utils.py).

        :param List[Tuple[str, dict]] batch: ``(uttid, info)`` pairs, a subset
            of the loaded data.json. ``info['input']`` and ``info['output']``
            are lists of dicts, each with a ``name`` and either a ``feat`` path
            (plus optional ``filetype``, default ``'mat'``) or, for legacy
            outputs only, a space-separated ``tokenid`` string.
        :return: ``(features, uttid_list)``. ``features`` is a tuple of the
            values of the OrderedDict built by ``_create_batch_asr`` /
            ``_create_batch_tts`` (names are dropped; presumably input streams
            followed by target streams, as in ESPnet). ``uttid_list`` is the
            list of utterance ids as returned by that helper, which may reorder
            or filter it.
        :rtype: Tuple[tuple, List[str]]
        :raises NotImplementedError: if the loader mode is neither 'asr' nor
            'tts'.
        """
        loader = self.load_inputs_and_targets
        mode = loader.mode

        def load_feat(entry):
            """Read the feature file referenced by one input/output entry."""
            #pylint: disable=protected-access
            return loader._get_from_loader(
                filepath=entry['feat'],
                filetype=entry.get('filetype', 'mat'))

        x_feats_dict = OrderedDict()  # OrderedDict[str, List[np.ndarray]]
        y_feats_dict = OrderedDict()  # OrderedDict[str, List[np.ndarray]]
        uttid_list = []  # List[str]

        for uttid, info in batch:
            uttid_list.append(uttid)

            if loader.load_input:
                # Note(kamo): This for-loop is for multiple inputs
                # {"input":
                #  [{"feat": "some/path.h5:F01_050C0101_PED_REAL",
                #    "filetype": "hdf5",
                #    "name": "input1", ...}], ...}
                for inp in info['input']:
                    x_data = load_feat(inp)
                    x_feats_dict.setdefault(inp['name'], []).append(x_data)

            elif mode == 'tts' and loader.use_speaker_embedding:
                # With several inputs only the one at index 1 is read; every
                # other input gets a None placeholder so the per-name lists
                # stay aligned with uttid_list.
                n_inputs = len(info['input'])
                for idx, inp in enumerate(info['input']):
                    if idx != 1 and n_inputs > 1:
                        x_data = None
                    else:
                        x_data = load_feat(inp)
                    x_feats_dict.setdefault(inp['name'], []).append(x_data)

            if loader.load_output:
                for out in info['output']:
                    if 'tokenid' in out:
                        # ======= Legacy format for output =======
                        # {"output": [{"tokenid": "1 2 3 4",
                        #              "name": "target1", ...}]}
                        y_data = np.fromiter(map(int, out['tokenid'].split()),
                                             dtype=np.int64)
                    else:
                        # ======= New format =======
                        # {"output":
                        #  [{"feat": "some/path.h5:F01_050C0101_PED_REAL",
                        #    "filetype": "hdf5",
                        #    "name": "target1", ...}], ...}
                        y_data = load_feat(out)
                    y_feats_dict.setdefault(out['name'], []).append(y_data)

        if mode == 'asr':
            #pylint: disable=protected-access
            return_batch, uttid_list = loader._create_batch_asr(
                x_feats_dict, y_feats_dict, uttid_list)
        elif mode == 'tts':
            # EOS id = (second dim of the first utterance's first output
            # shape) - 1.
            _, first_info = batch[0]
            eos = int(first_info['output'][0]['shape'][1]) - 1
            #pylint: disable=protected-access
            return_batch, uttid_list = loader._create_batch_tts(
                x_feats_dict, y_feats_dict, uttid_list, eos)
        else:
            raise NotImplementedError

        # Apply pre-processing only to the 'input1' feature, for now.
        if loader.preprocessing is not None and 'input1' in return_batch:
            return_batch['input1'] = loader.preprocessing(
                return_batch['input1'], uttid_list, **loader.preprocess_args)

        # Doesn't return the names now.
        return tuple(return_batch.values()), uttid_list