Пример #1
0
    def test_make_batchset(self):
        """Check that make_batchset never yields a minibatch below min_batch_size.

        For every task in ``espnet_utils.TASK_SET`` the same dummy data set
        (128 utterances) is batched with ``batch_size=24`` under several
        length limits.  The ``2**10`` limits are the "w/o adaptive batch
        size" case.  The tighter 256/64 limits are the "w/ adaptive batch
        size" case, where the batch size is presumably shrunk for long
        utterances (behavior of make_batchset, not shown here).  In every
        case each minibatch must still hold at least ``min_batch_size``
        utterances.
        """
        dummy_json = make_dummy_json(128, [128, 512], [16, 128])
        # (max_length_in, max_length_out, min_batch_size)
        configs = [
            # check w/o adaptive batch size
            (2**10, 2**10, 1),
            (2**10, 2**10, 10),
            # check w/ adaptive batch size
            (256, 64, 10),
        ]
        for task in espnet_utils.TASK_SET:
            for max_length_in, max_length_out, min_batch_size in configs:
                batchset = espnet_utils.make_batchset(
                    task,
                    data=dummy_json,
                    batch_size=24,
                    max_length_in=max_length_in,
                    max_length_out=max_length_out,
                    min_batch_size=min_batch_size)
                sizes = [len(batch) for batch in batchset]
                logging.info('batch: {}'.format(sizes))
                # An empty batchset would make the size check vacuously
                # true, so require at least one minibatch.
                self.assertGreater(len(sizes), 0,
                                   msg='task={}: empty batchset'.format(task))
                self.assertGreaterEqual(
                    min(sizes), min_batch_size,
                    msg='task={}, sizes={}'.format(task, sizes))
Пример #2
0
    def test_sortagrad(self):
        """Check the SortaGrad ordering produced by make_batchset.

        With ``shortest_first=True`` the minibatches must be ordered short
        to long, comparing the first sample of each minibatch.  The samples
        inside each minibatch must be ordered long to short.  Lengths are
        read from ``sample[1][key][0]['shape'][0]``.  ``key`` is 'input'
        for asr and 'output' for tts.  Presumably this is because the tts
        task treats the json 'output' entry as its sort length; confirm
        against make_batchset.  Tasks without a known key are skipped.
        """
        dummy_json = make_dummy_json(128, [1, 700], [1, 700])
        # Which json entry holds the length that the ordering applies to.
        length_key_by_task = {'asr': 'input', 'tts': 'output'}

        for task in espnet_utils.TASK_SET:
            key = length_key_by_task.get(task)
            if key is None:
                # Previously an unknown task reused the prior iteration's
                # batchset (or hit a NameError); skip it explicitly.
                continue
            batchset = espnet_utils.make_batchset(task,
                                                  dummy_json,
                                                  16,
                                                  2**10,
                                                  2**10,
                                                  batch_sort_key='input',
                                                  shortest_first=True)
            self.assertGreater(len(batchset), 0,
                               msg='task={}: empty batchset'.format(task))

            # short to long across minibatches (first sample of each)
            start_lens = [batch[0][1][key][0]['shape'][0]
                          for batch in batchset]
            for prev_len, cur_len in zip(start_lens, start_lens[1:]):
                self.assertGreaterEqual(cur_len, prev_len)

            # long to short in minibatch
            for batch in batchset:
                lens = [sample[1][key][0]['shape'][0] for sample in batch]
                for prev_len, cur_len in zip(lens, lens[1:]):
                    self.assertLessEqual(cur_len, prev_len)