import json
import logging
from collections import OrderedDict

# `utils`, `make_batchset`, and `TASK_SET` are assumed to be provided by the
# surrounding package (project-local device helpers and batching utilities).


def input_fn(dataset, mode, batch_size, num_epoch=None):
  ''' Build an input function for an Estimator-style training loop.

  params: dataset, callable that builds a tf.data.Dataset,
      invoked as dataset(mode, per_device_batch_size, num_epoch)
  params: mode, learning phase, one of utils.TRAIN, utils.EVAL, utils.INFER
  params: batch_size, global batch size across all devices
  params: num_epoch, number of epochs; forced to 1 for eval and infer
  '''
  if mode == utils.TRAIN:
    _, num_gpus = utils.gpu_device_names()
    per_device_batch_size = utils.per_device_batch_size(batch_size, num_gpus)
  else:
    # Use a single device for eval or infer, otherwise remainder samples
    # would be dropped, e.g. a batch of 32 split across 3 gpus.
    per_device_batch_size = batch_size
    num_epoch = 1

  logging.info(
      "Learning Phase: {}, Total batch size: {}, Per device batch size: {}".
      format(mode, batch_size, per_device_batch_size))

  def _input_fn():
    return dataset(mode, per_device_batch_size, num_epoch)

  return _input_fn
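
# Usage sketch (illustrative, not part of the original module): `input_fn`
# follows the tf.estimator convention of returning a zero-argument callable.
# `make_dataset` below is a hypothetical builder with the assumed signature
# (mode, per_device_batch_size, num_epoch) -> tf.data.Dataset:
#
#   train_input_fn = input_fn(make_dataset, utils.TRAIN, batch_size=32,
#                             num_epoch=10)
#   estimator.train(input_fn=train_input_fn)
#
#   eval_input_fn = input_fn(make_dataset, utils.EVAL, batch_size=32)
#   estimator.evaluate(input_fn=eval_input_fn)
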
def get_batches(config, mode):
  ''' Make batches of metas and get the dataset size. '''
  assert mode in (utils.TRAIN, utils.EVAL, utils.INFER)

  # read metas from json
  logging.info("load json data")
  json_path = config['data'][mode]['paths']
  assert len(json_path) == 1
  #pylint: disable=invalid-name
  with open(json_path[0], 'r', encoding='utf-8') as f:
    metas_raw = json.load(f)['utts']

  # sort by utt id
  metas = OrderedDict(sorted(metas_raw.items(), key=lambda t: t[0]))

  # dataset size
  utts = len(metas)
  logging.info('# utts: ' + str(utts))

  # make batchset
  use_sortagrad = config['data']['task']['sortagrad']
  task = config['data']['task']['type']
  assert task in list(TASK_SET.keys())

  # asr and tts share the same json; for tts the src/tgt roles are swapped
  if task == TASK_SET['asr']:
    src = 'src'
    tgt = 'tgt'
  elif task == TASK_SET['tts']:
    src = 'tgt'
    tgt = 'src'
  else:
    raise ValueError("task type must be in: {}, got: {}".format(
        list(TASK_SET.keys()), task))

  maxlen_src = config['data']['task'][src]['max_len']
  maxlen_tgt = config['data']['task'][tgt]['max_len']
  batch_sort_key = config['data']['task']['batch_sort_key']
  num_batches = config['data']['task']['num_batches']

  _, ngpu = utils.gpu_device_names()
  global_batch_size = config['solver']['optimizer']['batch_size']
  batch_size = utils.per_device_batch_size(global_batch_size, ngpu)

  batch_bins = config['solver']['optimizer']['batch_bins']
  batch_frames_in = config['solver']['optimizer']['batch_frames_in']
  batch_frames_out = config['solver']['optimizer']['batch_frames_out']
  batch_frames_inout = config['solver']['optimizer']['batch_frames_inout']
  batch_strategy = config['solver']['optimizer']['batch_strategy']

  minibatches = make_batchset(
      task=task,
      data=metas,
      batch_size=batch_size,
      max_length_in=maxlen_src,
      max_length_out=maxlen_tgt,
      num_batches=num_batches,
      batch_sort_key=batch_sort_key,
      min_batch_size=ngpu if ngpu else 1,
      shortest_first=use_sortagrad,
      batch_bins=batch_bins,
      batch_frames_in=batch_frames_in,
      batch_frames_out=batch_frames_out,
      batch_frames_inout=batch_frames_inout,
      batch_strategy=batch_strategy)

  return {'data': minibatches, 'n_utts': utts}
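
# Usage sketch (illustrative): the config layout below is inferred from the
# key accesses in `get_batches`; all concrete values, the 'train' mode key,
# and the 'seq' strategy name are assumptions, not documented defaults.
#
#   config = {
#       'data': {
#           'train': {'paths': ['data/train.json']},
#           'task': {
#               'type': 'asr', 'sortagrad': False,
#               'batch_sort_key': 'input', 'num_batches': 0,
#               'src': {'max_len': 3000}, 'tgt': {'max_len': 150},
#           },
#       },
#       'solver': {
#           'optimizer': {
#               'batch_size': 32, 'batch_bins': 0, 'batch_frames_in': 0,
#               'batch_frames_out': 0, 'batch_frames_inout': 0,
#               'batch_strategy': 'seq',
#           },
#       },
#   }
#   batches = get_batches(config, utils.TRAIN)
#   logging.info('num minibatches: %d', len(batches['data']))
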