def __init__(self, config, mode):
    """Read task/solver settings from ``config`` and prepare data metadata.

    Args:
      config: nested config dict. Reads ``config['data']['task']`` keys
        ``dummy``, ``batch_mode``, ``shuffle_buffer_size`` and
        ``need_shuffle``, plus ``config['solver']['optimizer']['batch_size']``.
      mode: forwarded to the base-class constructor and to
        ``espnet_utils.get_batches``.

    In dummy mode no data is loaded: the feature shape is fixed to ``[40]``
    and the vocabulary size to 100. Otherwise batches are loaded through
    ``espnet_utils.get_batches`` and the feature shape / vocabulary size are
    read from the shape metadata of the first entry of the first batch.
    In both cases an ``espnet_utils.ASRConverter`` is built from
    ``self.config``.
    """
    super().__init__(config, mode)

    task_conf = config['data']['task']
    self.dummy = task_conf['dummy']
    self.batch_mode = task_conf['batch_mode']
    self.batch_size = config['solver']['optimizer']['batch_size']
    self._shuffle_buffer_size = task_conf['shuffle_buffer_size']
    self._need_shuffle = task_conf['need_shuffle']

    if not self.dummy:
        # Get batches from the data path. self.config is presumably set by
        # the base-class constructor -- confirm.
        metas = espnet_utils.get_batches(self.config, mode)
        self.batches = metas['data']
        self.n_utts = metas['n_utts']
        logging.info("utts: {}".format(self.n_utts))

        # Metadata dict of the first entry in the first batch; presumably
        # each entry is an (uttid, info) pair -- TODO confirm with get_batches.
        first_entry = self.batches[0][0][1]
        # input shape is [nframe, feat_shape, ...]; drop the frame axis.
        self._feat_shape = first_entry['input'][0]['shape'][1:]
        # output shape is [tgt_len, vocab_size].
        self._vocab_size = first_entry['output'][0]['shape'][1]
        logging.info('#input feat shape: ' + str(self.feat_shape))
        logging.info('#output dims: ' + str(self.vocab_size))
    else:
        # NOTE(review): dummy mode leaves self.batches / self.n_utts unset;
        # confirm no caller reads them in that case.
        self._feat_shape = [40]
        logging.info("Dummy data: feat {}".format(self.feat_shape))
        self._vocab_size = 100

    self._converter = espnet_utils.ASRConverter(self.config)
def test_converter(self):
    """Check that ASRConverter turns every batch into consistent, typed data.

    For each mode and dataset size, synthetic data is generated, batched with
    ``espnet_utils.get_batches`` and fed through ``ASRConverter``. The test
    verifies:
      * the utterance count reported by ``get_batches``;
      * that each converted batch has five fields of equal length;
      * the dtypes of features (float32) and targets (int64);
      * that the total number of converted utterances equals the number
        generated.
    """
    np.random.seed(100)
    nexamples_list = (10, 12)
    batch_size = 4
    self.config['solver']['optimizer']['batch_size'] = batch_size
    converter = espnet_utils.ASRConverter(self.config)

    for mode in (utils.TRAIN, utils.EVAL, utils.INFER):
        for nexamples in nexamples_list:
            # Called for its side effect of preparing the data that
            # get_batches reads; the returned reference arrays are unused.
            generate_json_data(self.config, mode, nexamples)

            data_metas = espnet_utils.get_batches(self.config, mode)
            batches = data_metas['data']
            self.assertEqual(data_metas['n_utts'], nexamples)

            n_converted = 0
            for batch in batches:
                batch_data = converter(batch)
                self.assertEqual(len(batch_data), 5)
                xs, ilens, ys, olens, uttids = batch_data

                # zip() silently truncates to the shortest input, so pin
                # that every field describes the same number of utterances.
                n_batch = len(xs)
                for field in (ilens, ys, olens, uttids):
                    self.assertEqual(len(field), n_batch)

                for x, y in zip(xs, ys):
                    self.assertDTypeEqual(x, np.float32)
                    self.assertDTypeEqual(y, np.int64)
                n_converted += n_batch

            self.assertEqual(n_converted, nexamples)