def test_make_batchset(self):
    """Check that ``make_batchset`` honours ``min_batch_size`` for every task.

    For each task in ``espnet_utils.TASK_SET`` the batch set is built with
    several length-limit / ``min_batch_size`` combinations, and every
    resulting batch must contain at least ``min_batch_size`` entries.
    """
    def assert_min_size(batches, floor):
        # Count the batches that reach the floor; the count must equal the
        # total number of batches, i.e. no batch may be smaller.
        n_ok = sum([len(b) >= floor for b in batches])
        self.assertEqual(n_ok, len(batches))

    def log_sizes(batches):
        # Record the size of every batch to ease debugging of failures.
        logging.info('batch: {}'.format(([len(b) for b in batches])))

    # NOTE(review): the meaning of the three arguments is defined by
    # make_dummy_json, which is outside this block -- confirm there.
    dummy_json = make_dummy_json(128, [128, 512], [16, 128])

    for task in espnet_utils.TASK_SET:
        # check w/o adaptive batch size (keyword-argument call form)
        batchset = espnet_utils.make_batchset(
            task,
            data=dummy_json,
            batch_size=24,
            max_length_in=2**10,
            max_length_out=2**10,
            min_batch_size=1)
        assert_min_size(batchset, 1)
        log_sizes(batchset)

        # same limits passed positionally, with a larger minimum batch size
        batchset = espnet_utils.make_batchset(
            task, dummy_json, 24, 2**10, 2**10, min_batch_size=10)
        assert_min_size(batchset, 10)
        log_sizes(batchset)

        # check w/ adaptive batch size
        batchset = espnet_utils.make_batchset(
            task, dummy_json, 24, 256, 64, min_batch_size=10)
        assert_min_size(batchset, 10)
        log_sizes(batchset)

        # The same configuration is built and checked a second time
        # (this time without logging the batch sizes).
        batchset = espnet_utils.make_batchset(
            task, dummy_json, 24, 256, 64, min_batch_size=10)
        assert_min_size(batchset, 10)
def test_sortagrad(self):
    """Check batch ordering when ``shortest_first=True`` is requested.

    Across batches, the first sample's leading ``shape`` entry must be
    non-decreasing (short to long); within a batch, that entry must be
    non-increasing from sample to sample (long to short).
    """
    def first_dim(sample, field):
        # Leading 'shape' entry of the first item stored under `field`;
        # presumably the sequence length -- confirm against the json schema.
        return sample[1][field][0]['shape'][0]

    dummy_json = make_dummy_json(128, [1, 700], [1, 700])

    for task in espnet_utils.TASK_SET:
        # NOTE(review): there is no fallback branch; a task other than
        # 'asr'/'tts' would reuse the previous iteration's batchset/key
        # (or hit unbound names on the first iteration) -- confirm that
        # TASK_SET only contains these two.
        if task == 'asr':
            batchset = espnet_utils.make_batchset(
                task, dummy_json, 16, 2**10, 2**10,
                batch_sort_key='input', shortest_first=True)
            key = 'input'
        elif task == 'tts':
            # Sorting is requested on "input" while the check below reads
            # the 'output' field for this task.
            batchset = espnet_utils.make_batchset(
                task, dummy_json, 16, 2**10, 2**10,
                batch_sort_key="input", shortest_first=True)
            key = 'output'

        last_head_len = first_dim(batchset[0][0], key)
        for batch in batchset:
            head_len = first_dim(batch[0], key)
            # short to long
            self.assertGreaterEqual(head_len, last_head_len)

            running_len = head_len
            for sample in batch:
                sample_len = first_dim(sample, key)
                # long to short in minibatch
                self.assertLessEqual(sample_len, running_len)
                running_len = sample_len

            last_head_len = head_len