Exemplo n.º 1
0
def input_fn(dataset, mode, batch_size, num_epoch=None):
    '''Wrap ``dataset`` in a zero-argument input function.

    In ``utils.TRAIN`` mode the global ``batch_size`` is converted to a
    per-device size with ``utils.per_device_batch_size``, using the GPU count
    returned by ``utils.gpu_device_names``. In any other mode a single device
    is used: the full ``batch_size`` is kept and ``num_epoch`` is forced to 1.

    params: dataset, callable invoked as ``dataset(mode, batch_size, num_epoch)``
      (presumably it builds a tf.data.Dataset -- TODO confirm against callers)
    params: mode, learning phase
    params: batch_size, total (global) batch size
    params: num_epoch, number of epochs; overridden to 1 outside TRAIN
    returns: a no-argument callable that invokes ``dataset`` with the
      resolved per-device batch size and epoch count
    '''
    training = mode == utils.TRAIN
    if training:
        _, gpu_count = utils.gpu_device_names()
        device_batch = utils.per_device_batch_size(batch_size, gpu_count)
        epochs = num_epoch
    else:
        # Eval/infer run on a single device; splitting the batch across GPUs
        # would drop remainder samples (e.g. a batch of 32 over 3 GPUs).
        device_batch = batch_size
        epochs = 1

    summary = "Learning Phase: {}, Total Batch size:{}, Per device batch size: {}"
    logging.info(summary.format(mode, batch_size, device_batch))

    def _build_dataset():
        return dataset(mode, device_batch, epochs)

    return _build_dataset
Exemplo n.º 2
0
def get_batches(config, mode):
    '''Make minibatches of utterance metas and report the dataset size.

    Reads the single JSON file listed in ``config['data'][mode]['paths']``,
    takes its ``'utts'`` mapping, and sorts it by utterance id. It then calls
    ``make_batchset`` with the batching options from ``config['data']['task']``
    and ``config['solver']['optimizer']``.

    Args:
      config: nested config dict; the keys read are visible in the body.
      mode: one of ``utils.TRAIN``, ``utils.EVAL``, ``utils.INFER``.

    Returns:
      dict with ``'data'`` (whatever ``make_batchset`` returns) and
      ``'n_utts'`` (number of utterances in the JSON file).

    Raises:
      AssertionError: ``mode`` is unknown, ``paths`` does not hold exactly one
        file, or the task type is not a key of ``TASK_SET``.
      ValueError: the task type matches neither ``TASK_SET['asr']`` nor
        ``TASK_SET['tts']``. This is only reachable when asserts are disabled
        (``python -O``) or when TASK_SET's values differ from its keys.
    '''
    assert mode in (utils.TRAIN, utils.EVAL, utils.INFER), \
        "unknown mode: {}".format(mode)

    # read meta of json
    logging.info("load json data")
    json_path = config['data'][mode]['paths']
    assert len(json_path) == 1, \
        "expected exactly one json path, got: {}".format(json_path)
    #pylint: disable=invalid-name
    with open(json_path[0], 'r', encoding='utf-8') as f:
        metas_raw = json.load(f)['utts']

    # sort by utts id so the resulting order does not depend on file order
    metas = OrderedDict(sorted(metas_raw.items(), key=lambda t: t[0]))

    # dataset size
    utts = len(metas)
    logging.info('# utts: %d', utts)

    task_cfg = config['data']['task']
    optim_cfg = config['solver']['optimizer']

    # make batchset
    use_sortagrad = task_cfg['sortagrad']
    task = task_cfg['type']
    assert task in TASK_SET, "task type must be in: {}, got: {}".format(
        list(TASK_SET.keys()), task)
    # The same json serves asr and tts; for tts the roles of the json's
    # 'src' and 'tgt' entries are swapped.
    if task == TASK_SET['asr']:
        src, tgt = 'src', 'tgt'
    elif task == TASK_SET['tts']:
        src, tgt = 'tgt', 'src'
    else:
        raise ValueError("task type must be in: {}, got: {}".format(
            list(TASK_SET.keys()), task))
    maxlen_src = task_cfg[src]['max_len']
    maxlen_tgt = task_cfg[tgt]['max_len']
    batch_sort_key = task_cfg['batch_sort_key']
    num_batches = task_cfg['num_batches']

    # the configured batch size is global; split it across visible GPUs
    _, ngpu = utils.gpu_device_names()
    global_batch_size = optim_cfg['batch_size']
    batch_size = utils.per_device_batch_size(global_batch_size, ngpu)
    batch_bins = optim_cfg['batch_bins']
    batch_frames_in = optim_cfg['batch_frames_in']
    batch_frames_out = optim_cfg['batch_frames_out']
    batch_frames_inout = optim_cfg['batch_frames_inout']
    batch_strategy = optim_cfg['batch_strategy']

    minibatches = make_batchset(task=task,
                                data=metas,
                                batch_size=batch_size,
                                max_length_in=maxlen_src,
                                max_length_out=maxlen_tgt,
                                num_batches=num_batches,
                                batch_sort_key=batch_sort_key,
                                # at least one sample per GPU; 1 on CPU
                                min_batch_size=ngpu or 1,
                                shortest_first=use_sortagrad,
                                batch_bins=batch_bins,
                                batch_frames_in=batch_frames_in,
                                batch_frames_out=batch_frames_out,
                                batch_frames_inout=batch_frames_inout,
                                batch_strategy=batch_strategy)

    return {'data': minibatches, 'n_utts': utts}