def main(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if args.fp16:
        optim_level = Optimization.mxprO3
    else:
        optim_level = Optimization.mxprO0

    model_definition = toml.load(args.model_toml)
    dataset_vocab = model_definition['labels']['labels']
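    # RNN-T decoding needs a blank symbol; add_blank_label appends it to the dataset vocabulary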
    ctc_vocab = add_blank_label(dataset_vocab)

    val_manifest = args.val_manifest
    featurizer_config = model_definition['input_eval']
    featurizer_config["optimization_level"] = optim_level

    if args.max_duration is not None:
        featurizer_config['max_duration'] = args.max_duration
    if args.pad_to is not None:
        featurizer_config['pad_to'] = args.pad_to if args.pad_to >= 0 else "max"

    data_layer = AudioToTextDataLayer(
        dataset_dir=args.dataset_dir,
        featurizer_config=featurizer_config,
        manifest_filepath=val_manifest,
        labels=dataset_vocab,
        batch_size=args.batch_size,
        pad_to_max=featurizer_config['pad_to'] == "max",
        shuffle=False,
        multi_gpu=False)

    audio_preprocessor = AudioPreprocessing(**featurizer_config)

    audio_preprocessor.eval()

    eval_transforms = torchvision.transforms.Compose([
        lambda xs: [*audio_preprocessor(xs[0:2]), *xs[2:]],
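        # features come out as (batch, features, time); permute to (time, batch, features)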
        lambda xs: [xs[0].permute(2, 0, 1), *xs[1:]],
    ])

    eval(
        data_layer=data_layer,
        audio_processor=eval_transforms,
        args=args)
Example 2
class PytorchSUT:
    def __init__(self, config_toml, checkpoint_path, dataset_dir,
                 manifest_filepath, perf_count):
        config = toml.load(config_toml)

        dataset_vocab = config['labels']['labels']
        rnnt_vocab = add_blank_label(dataset_vocab)
        featurizer_config = config['input_eval']

        self.sut = lg.ConstructSUT(self.issue_queries, self.flush_queries,
                                   self.process_latencies)
        self.qsl = AudioQSLInMemory(dataset_dir,
                                    manifest_filepath,
                                    dataset_vocab,
                                    featurizer_config["sample_rate"],
                                    perf_count)
        self.audio_preprocessor = AudioPreprocessing(**featurizer_config)
        self.audio_preprocessor.eval()
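        # script the preprocessor, then freeze it so its parameters and attributes become graph constants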
        self.audio_preprocessor = torch.jit.script(self.audio_preprocessor)
        self.audio_preprocessor = torch.jit._recursive.wrap_cpp_module(
            torch._C._freeze_module(self.audio_preprocessor._c))

        model = RNNT(
            feature_config=featurizer_config,
            rnnt=config['rnnt'],
            num_classes=len(rnnt_vocab)
        )
        model.load_state_dict(load_and_migrate_checkpoint(checkpoint_path),
                              strict=True)
        model.eval()
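        # script and freeze each RNN-T submodule (encoder, prediction, joint) before scripting the full model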
        model.encoder = torch.jit.script(model.encoder)
        model.encoder = torch.jit._recursive.wrap_cpp_module(
            torch._C._freeze_module(model.encoder._c))
        model.prediction = torch.jit.script(model.prediction)
        model.prediction = torch.jit._recursive.wrap_cpp_module(
            torch._C._freeze_module(model.prediction._c))
        model.joint = torch.jit.script(model.joint)
        model.joint = torch.jit._recursive.wrap_cpp_module(
            torch._C._freeze_module(model.joint._c))
        model = torch.jit.script(model)

        self.greedy_decoder = ScriptGreedyDecoder(len(rnnt_vocab) - 1, model)
Example 3
def main(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
    #torch.set_default_dtype(torch.double)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.benchmark = args.cudnn_benchmark
    #print("CUDNN BENCHMARK ", args.cudnn_benchmark)
    if args.cuda:
        assert(torch.cuda.is_available())

    model_definition = toml.load(args.model_toml)
    dataset_vocab = model_definition['labels']['labels']
    ctc_vocab = add_blank_label(dataset_vocab)

    val_manifest = args.val_manifest
    featurizer_config = model_definition['input_eval']

    if args.pad_to is not None:
        featurizer_config['pad_to'] = args.pad_to if args.pad_to >= 0 else "max"

    #print('model_config')
    #print_dict(model_definition)
    #print('feature_config')
    #print_dict(featurizer_config)
    data_layer = None
    data_layer = AudioToTextDataLayer(
        dataset_dir=args.dataset_dir,
        featurizer_config=featurizer_config,
        manifest_filepath=val_manifest,
        labels=dataset_vocab,
        batch_size=args.batch_size,
        pad_to_max=featurizer_config['pad_to'] == "max",
        shuffle=False,
        sampler='bucket'  # sort by duration
        )
    audio_preprocessor = AudioPreprocessing(**featurizer_config)

    model = RNNT(
        feature_config=featurizer_config,
        rnnt=model_definition['rnnt'],
        num_classes=len(ctc_vocab)
    )
  
    if args.ckpt is not None and args.mode in [3]:
        #print("loading model from ", args.ckpt)
        checkpoint = torch.load(args.ckpt, map_location="cpu")
        model.load_state_dict(checkpoint['state_dict'], strict=False)

    audio_preprocessor.featurizer.normalize = "per_feature"

    if args.cuda:
        audio_preprocessor.cuda()
    audio_preprocessor.eval()

    eval_transforms = []
    if args.cuda:
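        # move waveform and length tensors to the GPU before feature extraction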
        eval_transforms.append(lambda xs: [xs[0].cuda(), xs[1].cuda(), *xs[2:]])
   
    eval_transforms.append(lambda xs: [*audio_preprocessor(xs[0:2]), *xs[2:]])
    # These are just some very confusing transposes, that's all.
    # BxFxT -> TxBxF
    eval_transforms.append(lambda xs: [xs[0].permute(2, 0, 1), *xs[1:]])
    eval_transforms = torchvision.transforms.Compose(eval_transforms)

    if args.cuda:
        model.cuda()
    # Ideally, I would jit this as well... But this is just the constructor...
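    # blank id is the last index of the extended vocabulary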
    greedy_decoder = RNNTGreedyDecoder(len(ctc_vocab) - 1, model)

    eval(
        data_layer=data_layer,
        audio_processor=eval_transforms,
        greedy_decoder=greedy_decoder,
        labels=ctc_vocab,
        args=args)
Example 4
def main(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.benchmark = args.cudnn_benchmark

    multi_gpu = args.local_rank is not None
    if multi_gpu:
        print("DISTRIBUTED with ", torch.distributed.get_world_size())

    if args.fp16:
        optim_level = Optimization.mxprO3
    else:
        optim_level = Optimization.mxprO0

    model_definition = toml.load(args.model_toml)
    dataset_vocab = model_definition['labels']['labels']
    ctc_vocab = add_blank_label(dataset_vocab)

    val_manifest = args.val_manifest
    featurizer_config = model_definition['input_eval']
    featurizer_config["optimization_level"] = optim_level

    if args.max_duration is not None:
        featurizer_config['max_duration'] = args.max_duration
    if args.pad_to is not None:
        featurizer_config['pad_to'] = args.pad_to if args.pad_to >= 0 else "max"

    print('model_config')
    print_dict(model_definition)
    print('feature_config')
    print_dict(featurizer_config)
    data_layer = None
    
    if args.wav is None:
        data_layer = AudioToTextDataLayer(
            dataset_dir=args.dataset_dir, 
            featurizer_config=featurizer_config,
            manifest_filepath=val_manifest,
            # sampler='bucket',
            sort_by_duration=args.sort_by_duration,
            labels=dataset_vocab,
            batch_size=args.batch_size,
            pad_to_max=featurizer_config['pad_to'] == "max",
            shuffle=False,
            multi_gpu=multi_gpu)
    audio_preprocessor = AudioPreprocessing(**featurizer_config)

    #encoderdecoder = JasperEncoderDecoder(jasper_model_definition=jasper_model_definition, feat_in=1024, num_classes=len(ctc_vocab))
    model = RNNT(
        feature_config=featurizer_config,
        rnnt=model_definition['rnnt'],
        num_classes=len(ctc_vocab)
    )

    if args.ckpt is not None:
        print("loading model from ", args.ckpt)
        checkpoint = torch.load(args.ckpt, map_location="cpu")
        model.load_state_dict(checkpoint['state_dict'], strict=False)

    if args.ipex:
        import intel_extension_for_pytorch as ipex
        from rnn import IPEXStackTime
        model.joint_net.eval()
        data_type = torch.bfloat16 if args.mix_precision else torch.float32
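        # swap in the IPEX-optimized StackTime implementation when the subsampling factor is 2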
        if model.encoder["stack_time"].factor == 2:
            model.encoder["stack_time"] = IPEXStackTime(model.encoder["stack_time"].factor)
        model.joint_net = ipex.optimize(model.joint_net, dtype=data_type, auto_kernel_selection=True)
        model.prediction["embed"] = model.prediction["embed"].to(data_type)
        if args.jit:
            print("running jit path")
            model.joint_net.eval()
            if args.mix_precision:
                with torch.cpu.amp.autocast(), torch.no_grad():
                    model.joint_net = torch.jit.trace(model.joint_net, torch.randn(args.batch_size, 1, 1, model_definition['rnnt']['encoder_n_hidden'] + model_definition['rnnt']['pred_n_hidden']), check_trace=False)
            else:
                with torch.no_grad():
                    model.joint_net = torch.jit.trace(model.joint_net, torch.randn(args.batch_size, 1, 1, model_definition['rnnt']['encoder_n_hidden'] + model_definition['rnnt']['pred_n_hidden']), check_trace=False)
            model.joint_net = torch.jit.freeze(model.joint_net)
    else:
        model = model.to("cpu")

    #greedy_decoder = GreedyCTCDecoder()

    # print("Number of parameters in encoder: {0}".format(model.jasper_encoder.num_weights()))
    if args.wav is None:
        N = len(data_layer)
        # step_per_epoch = math.ceil(N / (args.batch_size * (1 if not torch.distributed.is_available() else torch.distributed.get_world_size())))
        step_per_epoch = math.ceil(N / (args.batch_size * (1 if not torch.distributed.is_initialized() else torch.distributed.get_world_size())))

        if args.steps is not None:
            print('-----------------')
            # print('Have {0} examples to eval on.'.format(args.steps * args.batch_size * (1 if not torch.distributed.is_available() else torch.distributed.get_world_size())))
            print('Have {0} examples to eval on.'.format(args.steps * args.batch_size * (1 if not torch.distributed.is_initialized() else torch.distributed.get_world_size())))
            print('Have {0} warm up steps / (gpu * epoch).'.format(args.warm_up))
            print('Have {0} measure steps / (gpu * epoch).'.format(args.steps))
            print('-----------------')
        else:
            print('-----------------')
            print('Have {0} examples to eval on.'.format(N))
            print('Have {0} warm up steps / (gpu * epoch).'.format(args.warm_up))
            print('Have {0} measure steps / (gpu * epoch).'.format(step_per_epoch))
            print('-----------------')
    else:
        audio_preprocessor.featurizer.normalize = "per_feature"

    print("audio_preprocessor.normalize: ", audio_preprocessor.featurizer.normalize)
    audio_preprocessor.eval()

    # eval_transforms = torchvision.transforms.Compose([
    #     lambda xs: [x.to(ipex.DEVICE) if args.ipex else x.cpu() for x in xs],
    #     lambda xs: [*audio_preprocessor(xs[0:2]), *xs[2:]],
    #     lambda xs: [xs[0].permute(2, 0, 1), *xs[1:]],
    # ])

    eval_transforms = torchvision.transforms.Compose([
        lambda xs: [x.cpu() for x in xs],
        lambda xs: [*audio_preprocessor(xs[0:2]), *xs[2:]],
        lambda xs: [xs[0].permute(2, 0, 1), *xs[1:]],
    ])

    model.eval()
    if args.ipex:
        ipex.nn.utils._model_convert.replace_lstm_with_ipex_lstm(model)

    greedy_decoder = RNNTGreedyDecoder(len(ctc_vocab) - 1, model.module if multi_gpu else model)

    eval(
        data_layer=data_layer,
        audio_processor=eval_transforms,
        encoderdecoder=model,
        greedy_decoder=greedy_decoder,
        labels=ctc_vocab,
        args=args,
        multi_gpu=multi_gpu)
Example 5
def main(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.benchmark = args.cudnn_benchmark
    print("CUDNN BENCHMARK ", args.cudnn_benchmark)
    assert(torch.cuda.is_available())

    if args.local_rank is not None:
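        # bind this process to its GPU and join the NCCL process group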
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
    multi_gpu = args.local_rank is not None
    if multi_gpu:
        print("DISTRIBUTED with ", torch.distributed.get_world_size())

    if args.fp16:
        optim_level = Optimization.mxprO3
    else:
        optim_level = Optimization.mxprO0

    model_definition = toml.load(args.model_toml)
    dataset_vocab = model_definition['labels']['labels']
    ctc_vocab = add_blank_label(dataset_vocab)

    val_manifest = args.val_manifest
    featurizer_config = model_definition['input_eval']
    featurizer_config["optimization_level"] = optim_level

    if args.max_duration is not None:
        featurizer_config['max_duration'] = args.max_duration
    if args.pad_to is not None:
        featurizer_config['pad_to'] = args.pad_to if args.pad_to >= 0 else "max"

    print('model_config')
    print_dict(model_definition)
    print('feature_config')
    print_dict(featurizer_config)
    data_layer = None
    
    if args.wav is None:
        data_layer = AudioToTextDataLayer(
            dataset_dir=args.dataset_dir, 
            featurizer_config=featurizer_config,
            manifest_filepath=val_manifest,
            labels=dataset_vocab,
            batch_size=args.batch_size,
            pad_to_max=featurizer_config['pad_to'] == "max",
            shuffle=False,
            multi_gpu=multi_gpu)
    audio_preprocessor = AudioPreprocessing(**featurizer_config)

    #encoderdecoder = JasperEncoderDecoder(jasper_model_definition=jasper_model_definition, feat_in=1024, num_classes=len(ctc_vocab))
    model = RNNT(
        feature_config=featurizer_config,
        rnnt=model_definition['rnnt'],
        num_classes=len(ctc_vocab)
    )

    if args.ckpt is not None:
        print("loading model from ", args.ckpt)
        checkpoint = torch.load(args.ckpt, map_location="cpu")
        model.load_state_dict(checkpoint['state_dict'], strict=False)

    #greedy_decoder = GreedyCTCDecoder()

    # print("Number of parameters in encoder: {0}".format(model.jasper_encoder.num_weights()))
    if args.wav is None:
        N = len(data_layer)
        step_per_epoch = math.ceil(N / (args.batch_size * (1 if not torch.distributed.is_initialized() else torch.distributed.get_world_size())))

        if args.steps is not None:
            print('-----------------')
            print('Have {0} examples to eval on.'.format(args.steps * args.batch_size * (1 if not torch.distributed.is_initialized() else torch.distributed.get_world_size())))
            print('Have {0} steps / (gpu * epoch).'.format(args.steps))
            print('-----------------')
        else:
            print('-----------------')
            print('Have {0} examples to eval on.'.format(N))
            print('Have {0} steps / (gpu * epoch).'.format(step_per_epoch))
            print('-----------------')
    else:
        audio_preprocessor.featurizer.normalize = "per_feature"

    print("audio_preprocessor.normalize: ", audio_preprocessor.featurizer.normalize)
    audio_preprocessor.cuda()
    audio_preprocessor.eval()

    eval_transforms = torchvision.transforms.Compose([
        lambda xs: [x.cuda() for x in xs],
        lambda xs: [*audio_preprocessor(xs[0:2]), *xs[2:]],
        lambda xs: [xs[0].permute(2, 0, 1), *xs[1:]],
    ])

    model.cuda()
    if args.fp16:
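        # initialize mixed precision (AMP) at the selected optimization level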
        model = amp.initialize(
            models=model,
            opt_level=AmpOptimizations[optim_level])

    model = model_multi_gpu(model, multi_gpu)

    greedy_decoder = RNNTGreedyDecoder(len(ctc_vocab) - 1, model.module if multi_gpu else model)

    eval(
        data_layer=data_layer,
        audio_processor=eval_transforms,
        encoderdecoder=model,
        greedy_decoder=greedy_decoder,
        labels=ctc_vocab,
        args=args,
        multi_gpu=multi_gpu)
Example 6
class PytorchSUT:
    def __init__(self, config_toml, checkpoint_path, dataset_dir,
                 manifest_filepath, perf_count):
        config = toml.load(config_toml)

        dataset_vocab = config['labels']['labels']
        rnnt_vocab = add_blank_label(dataset_vocab)
        featurizer_config = config['input_eval']

        self.sut = lg.ConstructSUT(self.issue_queries, self.flush_queries,
                                   self.process_latencies)
        self.qsl = AudioQSLInMemory(dataset_dir,
                                    manifest_filepath,
                                    dataset_vocab,
                                    featurizer_config["sample_rate"],
                                    perf_count)
        self.audio_preprocessor = AudioPreprocessing(**featurizer_config)
        self.audio_preprocessor.eval()
        self.audio_preprocessor = torch.jit.script(self.audio_preprocessor)
        self.audio_preprocessor = torch.jit._recursive.wrap_cpp_module(
            torch._C._freeze_module(self.audio_preprocessor._c))

        model = RNNT(
            feature_config=featurizer_config,
            rnnt=config['rnnt'],
            num_classes=len(rnnt_vocab)
        )
        model.load_state_dict(load_and_migrate_checkpoint(checkpoint_path),
                              strict=True)
        model.eval()
        model.encoder = torch.jit.script(model.encoder)
        model.encoder = torch.jit._recursive.wrap_cpp_module(
            torch._C._freeze_module(model.encoder._c))
        model.prediction = torch.jit.script(model.prediction)
        model.prediction = torch.jit._recursive.wrap_cpp_module(
            torch._C._freeze_module(model.prediction._c))
        model.joint = torch.jit.script(model.joint)
        model.joint = torch.jit._recursive.wrap_cpp_module(
            torch._C._freeze_module(model.joint._c))
        model = torch.jit.script(model)

        self.greedy_decoder = ScriptGreedyDecoder(len(rnnt_vocab) - 1, model)

    def issue_queries(self, query_samples):
        for query_sample in query_samples:
            waveform = self.qsl[query_sample.index]
            assert waveform.ndim == 1
            waveform_length = np.array(waveform.shape[0], dtype=np.int64)
            waveform = np.expand_dims(waveform, 0)
            waveform_length = np.expand_dims(waveform_length, 0)
            with torch.no_grad():
                waveform = torch.from_numpy(waveform)
                waveform_length = torch.from_numpy(waveform_length)
                feature, feature_length = self.audio_preprocessor.forward((waveform, waveform_length))
                assert feature.ndim == 3
                assert feature_length.ndim == 1
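                # (batch, features, time) -> (time, batch, features)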
                feature = feature.permute(2, 0, 1)

                _, _, transcript = self.greedy_decoder.forward(feature, feature_length)

            assert len(transcript) == 1
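            # pack the predicted token ids into a C array and hand LoadGen its (address, byte length)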
            response_array = array.array('q', transcript[0])
            bi = response_array.buffer_info()
            response = lg.QuerySampleResponse(query_sample.id, bi[0],
                                              bi[1] * response_array.itemsize)
            lg.QuerySamplesComplete([response])

    def flush_queries(self):
        pass

    def process_latencies(self, latencies_ns):
        print("Average latency (ms) per query:")
        print(np.mean(latencies_ns)/1000000.0)
        print("Median latency (ms): ")
        print(np.percentile(latencies_ns, 50)/1000000.0)
        print("90 percentile latency (ms): ")
        print(np.percentile(latencies_ns, 90)/1000000.0)

    def __del__(self):
        lg.DestroySUT(self.sut)
        print("Finished destroying SUT.")
Example 7
class Consumer(mp.Process):
    def run(self):
        core_list = range(self.start_core, self.end_core + 1)
        num_cores = len(core_list)
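        # pin this worker to its assigned core range and size the OpenMP thread pool to match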
        os.sched_setaffinity(self.pid, core_list)
        cmd = "taskset -p -c %d-%d %d" % (self.start_core, self.end_core,
                                          self.pid)
        print(cmd)
        os.system(cmd)
        os.environ['OMP_NUM_THREADS'] = '{}'.format(self.end_core -
                                                    self.start_core + 1)
        print("### set rank {} to cores [{}:{}]; omp num threads = {}".format(
            self.rank, self.start_core, self.end_core, num_cores))

        torch.set_num_threads(num_cores)

        if not self.model_init:
            print("lazy_init rank {}".format(self.rank))
            config = toml.load(self.config_toml)
            dataset_vocab = config['labels']['labels']
            rnnt_vocab = add_blank_label(dataset_vocab)
            featurizer_config = config['input_eval']
            self.audio_preprocessor = AudioPreprocessing(**featurizer_config)
            self.audio_preprocessor.eval()
            self.audio_preprocessor = torch.jit.script(self.audio_preprocessor)
            self.audio_preprocessor = torch.jit._recursive.wrap_cpp_module(
                torch._C._freeze_module(self.audio_preprocessor._c))

            model = RNNT(feature_config=featurizer_config,
                         rnnt=config['rnnt'],
                         num_classes=len(rnnt_vocab))
            checkpoint = torch.load(self.checkpoint_path, map_location="cpu")
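            # rename joint_net -> joint.net and drop preprocessor buffers that the standalone RNNT model does not contain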
            migrated_state_dict = {}
            for key, value in checkpoint['state_dict'].items():
                key = key.replace("joint_net", "joint.net")
                migrated_state_dict[key] = value
            del migrated_state_dict["audio_preprocessor.featurizer.fb"]
            del migrated_state_dict["audio_preprocessor.featurizer.window"]
            model.load_state_dict(migrated_state_dict, strict=True)

            if self.ipex:
                import intel_pytorch_extension as ipex
                if self.bf16:
                    ipex.enable_auto_mixed_precision(
                        mixed_dtype=torch.bfloat16)
                ipex.core.enable_auto_dnnl()
                model = model.to(ipex.DEVICE)

            model.eval()
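            # under IPEX only the joint network is scripted and frozen; the rest of the model stays eager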
            if not self.ipex:
                model.encoder = torch.jit.script(model.encoder)
                model.encoder = torch.jit._recursive.wrap_cpp_module(
                    torch._C._freeze_module(model.encoder._c))
                model.prediction = torch.jit.script(model.prediction)
                model.prediction = torch.jit._recursive.wrap_cpp_module(
                    torch._C._freeze_module(model.prediction._c))
            model.joint = torch.jit.script(model.joint)
            model.joint = torch.jit._recursive.wrap_cpp_module(
                torch._C._freeze_module(model.joint._c))
            if not self.ipex:
                model = torch.jit.script(model)

            self.greedy_decoder = ScriptGreedyDecoder(
                len(rnnt_vocab) - 1, model)

            self.model_init = True

        if self.warmup:
            self.do_warmup()

        self.lock.acquire()
        self.init_counter.value += 1
        self.lock.release()

        if self.rank == 0 and self.cosim:
            print('Running with cosim mode, performance will be slow!!!')
        if self.rank == 0 and self.profile:
            print('Start profiler')
            with profiler.profile(record_shapes=True) as prof:
                self.run_queue(debug=True)
            print(prof.key_averages().table(sort_by="self_cpu_time_total",
                                            row_limit=20))
            print(prof.key_averages().table(sort_by="cpu_time_total",
                                            row_limit=20))
            print(
                prof.key_averages(group_by_input_shape=True).table(
                    sort_by="self_cpu_time_total", row_limit=40))
            print(
                prof.key_averages(group_by_input_shape=True).table(
                    sort_by="cpu_time_total", row_limit=40))
            while self.run_queue():
                pass
        else:
            while self.run_queue():
                pass
Example 8
class Consumer(mp.Process):
    def __init__(self, task_queue, result_queue, lock, init_counter, rank,
                 start_core, end_core, num_cores, qsl, config_toml,
                 checkpoint_path, dataset_dir, manifest_filepath, perf_count,
                 cosim, profile, ipex, bf16, warmup):

        mp.Process.__init__(self)

        ### sub process
        self.task_queue = task_queue
        self.result_queue = result_queue
        self.lock = lock
        self.init_counter = init_counter
        self.rank = rank
        self.start_core = start_core
        self.end_core = end_core

        self.qsl = qsl
        self.config_toml = config_toml
        self.checkpoint_path = checkpoint_path
        self.dataset_dir = dataset_dir
        self.manifest_filepath = manifest_filepath
        self.perf_count = perf_count
        self.cosim = cosim
        self.profile = profile
        self.ipex = ipex
        self.bf16 = bf16
        self.warmup = warmup

        self.model_init = False

    # warmup basically goes through samples with different feature lengths so
    # all shapes can be prepared
    def do_warmup(self):
        print('Start warmup...')
        length_list = {}
        count = 0
        idxs = self.qsl.idxs()
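        # decode one sample per unique feature length so all shapes are prepared before measurement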
        for i in idxs:
            feature_list = []
            feature_length_list = []
            waveform = self.qsl[i]
            feature_element, feature_length = self.audio_preprocessor.forward(
                (torch.from_numpy(waveform).unsqueeze(0),
                 torch.tensor(len(waveform)).unsqueeze(0)))
            feature_list.append(feature_element.squeeze(0).transpose_(0, 1))
            feature_length_list.append(feature_length.squeeze(0))
            feature = torch.nn.utils.rnn.pad_sequence(feature_list,
                                                      batch_first=True)
            feature_length = torch.tensor(feature_length_list)

            if feature_length[0].item() in length_list:
                continue
            length_list[feature_length[0].item()] = True

            assert feature.ndim == 3
            assert feature_length.ndim == 1
            if self.ipex:
                import intel_pytorch_extension as ipex
                if self.bf16:
                    ipex.enable_auto_mixed_precision(
                        mixed_dtype=torch.bfloat16)
                ipex.core.enable_auto_dnnl()
                feature = feature.to(ipex.DEVICE)
                feature_length = feature_length.to(ipex.DEVICE)
            feature_ = feature.permute(1, 0, 2)
            _, _, transcripts = self.greedy_decoder.forward_batch(
                feature_, feature_length, self.rank)

            count += 1
            if self.rank == 0 and count % 10 == 0:
                print('Warmup {} samples'.format(count))
        print('Warmup done')

    def run_queue(self, debug=False):
        next_task = self.task_queue.get()
        if next_task is None:
            self.task_queue.task_done()
            return False

        query_id_list = next_task.query_id_list
        query_idx_list = next_task.query_idx_list
        query_len = len(query_id_list)
        with torch.no_grad():
            t1 = time.time()
            serial_audio_processor = True
            if serial_audio_processor:
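                # serial path: preprocess each utterance individually, then pad the features to a common length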
                feature_list = []
                feature_length_list = []
                for idx in query_idx_list:
                    waveform = self.qsl[idx]
                    feature_element, feature_length = self.audio_preprocessor.forward(
                        (torch.from_numpy(waveform).unsqueeze(0),
                         torch.tensor(len(waveform)).unsqueeze(0)))
                    feature_list.append(
                        feature_element.squeeze(0).transpose_(0, 1))
                    feature_length_list.append(feature_length.squeeze(0))
                feature = torch.nn.utils.rnn.pad_sequence(feature_list,
                                                          batch_first=True)
                feature_length = torch.tensor(feature_length_list)
            else:
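                # batched path: pad the raw waveforms first and run the preprocessor on the whole batch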
                waveform_list = []
                for idx in query_idx_list:
                    waveform = self.qsl[idx]
                    waveform_list.append(torch.from_numpy(waveform))
                waveform_batch = torch.nn.utils.rnn.pad_sequence(
                    waveform_list, batch_first=True)
                waveform_lengths = torch.tensor(
                    [waveform.shape[0] for waveform in waveform_list],
                    dtype=torch.int64)

                feature, feature_length = self.audio_preprocessor.forward(
                    (waveform_batch, waveform_lengths))

            assert feature.ndim == 3
            assert feature_length.ndim == 1
            if self.ipex:
                import intel_pytorch_extension as ipex
                if self.bf16:
                    ipex.enable_auto_mixed_precision(
                        mixed_dtype=torch.bfloat16)
                ipex.core.enable_auto_dnnl()
                feature = feature.to(ipex.DEVICE)
                feature_length = feature_length.to(ipex.DEVICE)
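            # both paths yield (time, batch, features) for the decoder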
            if serial_audio_processor:
                feature_ = feature.permute(1, 0, 2)
            else:
                feature_ = feature.permute(2, 0, 1)
            t3 = time.time()
            if query_len == 1:
                _, _, transcripts = self.greedy_decoder.forward_single_batch(
                    feature_, feature_length, self.ipex, self.rank)
            else:
                _, _, transcripts = self.greedy_decoder.forward_batch(
                    feature_, feature_length, self.ipex, self.rank)
            t4 = time.time()
            # cosim: cross-check the batched decoder output against the single-sample reference decoder
            if self.cosim:
                _, _, transcripts0 = self.greedy_decoder.forward(
                    feature, feature_length)
                if transcripts0 != transcripts:
                    print(
                        'vvvvvv difference between reference and batch impl. vvvvvv'
                    )
                    for i in range(query_len):
                        if transcripts0[i] != transcripts[i]:
                            for j in range(len(transcripts0[i])):
                                if transcripts0[i][j] != transcripts[i][j]:
                                    break
                            print('[{}] reference'.format(i))
                            print('{} diff {}'.format(transcripts0[i][0:j],
                                                      transcripts0[i][j:]))
                            print('[{}] batch'.format(i))
                            print('{} diff {}'.format(transcripts[i][0:j],
                                                      transcripts[i][j:]))
                            print('')
                    print(
                        '^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^'
                    )
                else:
                    print('.', end='', flush=True)

        t6 = time.time()
        assert len(transcripts) == query_len
        for id, trans in zip(query_id_list, transcripts):
            self.result_queue.put(Output(id, trans))
        t2 = time.time()
        dur = t2 - t1
        if debug:
            print('Audio {} Infer {} Total {}'.format(t3 - t1, t4 - t3,
                                                      t2 - t1))
            if query_len > 1:
                print("#### rank {} finish {} sample in {:.3f} sec".format(
                    self.rank, query_len, dur))
            else:
                print(
                    "#### rank {} finish sample of feature_len={} in {:.3f} sec"
                    .format(self.rank, feature_length[0].item(), dur))

        self.task_queue.task_done()
        return True

    def run(self):
        core_list = range(self.start_core, self.end_core + 1)
        num_cores = len(core_list)
        os.sched_setaffinity(self.pid, core_list)
        cmd = "taskset -p -c %d-%d %d" % (self.start_core, self.end_core,
                                          self.pid)
        print(cmd)
        os.system(cmd)
        os.environ['OMP_NUM_THREADS'] = '{}'.format(self.end_core -
                                                    self.start_core + 1)
        print("### set rank {} to cores [{}:{}]; omp num threads = {}".format(
            self.rank, self.start_core, self.end_core, num_cores))

        torch.set_num_threads(num_cores)

        if not self.model_init:
            print("lazy_init rank {}".format(self.rank))
            config = toml.load(self.config_toml)
            dataset_vocab = config['labels']['labels']
            rnnt_vocab = add_blank_label(dataset_vocab)
            featurizer_config = config['input_eval']
            self.audio_preprocessor = AudioPreprocessing(**featurizer_config)
            self.audio_preprocessor.eval()
            self.audio_preprocessor = torch.jit.script(self.audio_preprocessor)
            self.audio_preprocessor = torch.jit._recursive.wrap_cpp_module(
                torch._C._freeze_module(self.audio_preprocessor._c))

            model = RNNT(feature_config=featurizer_config,
                         rnnt=config['rnnt'],
                         num_classes=len(rnnt_vocab))
            checkpoint = torch.load(self.checkpoint_path, map_location="cpu")
            migrated_state_dict = {}
            for key, value in checkpoint['state_dict'].items():
                key = key.replace("joint_net", "joint.net")
                migrated_state_dict[key] = value
            del migrated_state_dict["audio_preprocessor.featurizer.fb"]
            del migrated_state_dict["audio_preprocessor.featurizer.window"]
            model.load_state_dict(migrated_state_dict, strict=True)

            if self.ipex:
                import intel_pytorch_extension as ipex
                if self.bf16:
                    ipex.enable_auto_mixed_precision(
                        mixed_dtype=torch.bfloat16)
                ipex.core.enable_auto_dnnl()
                model = model.to(ipex.DEVICE)

            model.eval()
            if not self.ipex:
                model.encoder = torch.jit.script(model.encoder)
                model.encoder = torch.jit._recursive.wrap_cpp_module(
                    torch._C._freeze_module(model.encoder._c))
                model.prediction = torch.jit.script(model.prediction)
                model.prediction = torch.jit._recursive.wrap_cpp_module(
                    torch._C._freeze_module(model.prediction._c))
            model.joint = torch.jit.script(model.joint)
            model.joint = torch.jit._recursive.wrap_cpp_module(
                torch._C._freeze_module(model.joint._c))
            if not self.ipex:
                model = torch.jit.script(model)

            self.greedy_decoder = ScriptGreedyDecoder(
                len(rnnt_vocab) - 1, model)

            self.model_init = True

        if self.warmup:
            self.do_warmup()

        self.lock.acquire()
        self.init_counter.value += 1
        self.lock.release()

        if self.rank == 0 and self.cosim:
            print('Running with cosim mode, performance will be slow!!!')
        if self.rank == 0 and self.profile:
            print('Start profiler')
            with profiler.profile(record_shapes=True) as prof:
                self.run_queue(debug=True)
            print(prof.key_averages().table(sort_by="self_cpu_time_total",
                                            row_limit=20))
            print(prof.key_averages().table(sort_by="cpu_time_total",
                                            row_limit=20))
            print(
                prof.key_averages(group_by_input_shape=True).table(
                    sort_by="self_cpu_time_total", row_limit=40))
            print(
                prof.key_averages(group_by_input_shape=True).table(
                    sort_by="cpu_time_total", row_limit=40))
            while self.run_queue():
                pass
        else:
            while self.run_queue():
                pass