Example #1
    def test_quantization_saved(self):
        for fake_yaml in [
                'dynamic_yaml.yaml', 'qat_yaml.yaml', 'ptq_yaml.yaml'
        ]:
            if fake_yaml == 'dynamic_yaml.yaml':
                model = torchvision.models.resnet18()
            else:
                model = copy.deepcopy(self.model)
            if fake_yaml == 'ptq_yaml.yaml':
                model.eval().fuse_model()
            quantizer = Quantization(fake_yaml)
            dataset = quantizer.dataset('dummy', (100, 3, 256, 256),
                                        label=True)
            quantizer.model = common.Model(model)
            if fake_yaml == 'qat_yaml.yaml':
                quantizer.q_func = q_func
            else:
                quantizer.calib_dataloader = common.DataLoader(dataset)
            quantizer.eval_dataloader = common.DataLoader(dataset)
            q_model = quantizer()
            q_model.save('./saved')
            # Load the configuration and weights via lpot.utils
            saved_model = load("./saved", model)
            eval_func(saved_model)
            shutil.rmtree('./saved', ignore_errors=True)
        from lpot.experimental import Benchmark
        evaluator = Benchmark('ptq_yaml.yaml')
        # Load the configuration and weights via lpot.model
        evaluator.model = common.Model(model)
        evaluator.b_dataloader = common.DataLoader(dataset)
        evaluator()
        evaluator.model = common.Model(model)
        evaluator()
Example #2
    def test_quantization_saved(self):
        from lpot.utils.pytorch import load

        for fake_yaml in [
                'dynamic_yaml.yaml', 'qat_yaml.yaml', 'ptq_yaml.yaml'
        ]:
            if fake_yaml == 'dynamic_yaml.yaml':
                model = torchvision.models.resnet18()
            else:
                model = copy.deepcopy(self.model)
            if fake_yaml == 'ptq_yaml.yaml':
                model.eval().fuse_model()
            quantizer = Quantization(fake_yaml)
            dataset = quantizer.dataset('dummy', (100, 3, 256, 256),
                                        label=True)
            quantizer.model = common.Model(model)
            quantizer.calib_dataloader = common.DataLoader(dataset)
            quantizer.eval_dataloader = common.DataLoader(dataset)
            if fake_yaml == 'qat_yaml.yaml':
                quantizer.q_func = q_func
            q_model = quantizer()
            q_model.save('./saved')
            # Load the configuration and weights via lpot.utils
            saved_model = load("./saved", model)
            eval_func(saved_model)
        from lpot.experimental import Benchmark
        evaluator = Benchmark('ptq_yaml.yaml')
        # Load the configuration and weights via lpot.model
        evaluator.model = common.Model(model)
        evaluator.b_dataloader = common.DataLoader(dataset)
        results = evaluator()
        evaluator.model = common.Model(model)
        fp32_results = evaluator()
        self.assertTrue(
            (fp32_results['accuracy'][0] - results['accuracy'][0]) < 0.01)
Example #3
    def test_adaptor(self):
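        # Quantize the rn50 (ResNet-50) model under both the static and the dynamic config.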
        for fake_yaml in ["static.yaml", "dynamic.yaml"]:
            quantizer = Quantization(fake_yaml)
            quantizer.calib_dataloader = self.cv_dataloader
            quantizer.eval_dataloader = self.cv_dataloader
            quantizer.model = common.Model(self.rn50_model)
            q_model = quantizer()
            eval_func(q_model)
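        # Quantize the mb_v2 model with the non-MSE tuning config.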
        for fake_yaml in ["non_MSE.yaml"]:
            quantizer = Quantization(fake_yaml)
            quantizer.calib_dataloader = self.cv_dataloader
            quantizer.eval_dataloader = self.cv_dataloader
            quantizer.model = common.Model(self.mb_v2_model)
            q_model = quantizer()
            eval_func(q_model)

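        # Quantize the ir3 model with the static config.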
        for fake_yaml in ["static.yaml"]:
            quantizer = Quantization(fake_yaml)
            quantizer.calib_dataloader = self.ir3_dataloader
            quantizer.eval_dataloader = self.ir3_dataloader
            quantizer.model = common.Model(self.ir3_model)
            q_model = quantizer()

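        # Benchmark the rn50 model in both performance and accuracy modes.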
        for mode in ["performance", "accuracy"]:
            fake_yaml = "benchmark.yaml"
            evaluator = Benchmark(fake_yaml)
            evaluator.b_dataloader = self.cv_dataloader
            evaluator.model = common.Model(self.rn50_model)
            evaluator(mode)
Example #4
    def test_footprint(self):
        from lpot.experimental import Benchmark, common
        from lpot.data import DATASETS
        dataset = DATASETS('tensorflow')['dummy']((100, 256, 256, 1),
                                                  label=True)

        benchmarker = Benchmark('fake_yaml_footprint.yaml')
        benchmarker.b_dataloader = common.DataLoader(dataset)
        benchmarker.model = self.constant_graph_1
        benchmarker()
Example #5
def main(_):
    arg_parser = ArgumentParser(description='Parse args')

    arg_parser.add_argument("--input-graph",
                            help='Specify the slim model',
                            dest='input_graph')

    arg_parser.add_argument("--output-graph",
                            help='Specify tune result model save dir',
                            dest='output_graph')

    arg_parser.add_argument("--config", default=None, help="tuning config")

    arg_parser.add_argument('--benchmark',
                            dest='benchmark',
                            action='store_true',
                            help='run benchmark')

    arg_parser.add_argument('--tune',
                            dest='tune',
                            action='store_true',
                            help='use lpot to tune.')

    args = arg_parser.parse_args()

    factory = TFSlimNetsFactory()
    # user-specific models can be registered with the slim net factory
    input_shape = [None, 299, 299, 3]
    factory.register('inception_v4', inception_v4, input_shape,
                     inception_v4_arg_scope)

    if args.tune:

        from lpot.experimental import Quantization
        quantizer = Quantization(args.config)
        quantizer.model = args.input_graph
        q_model = quantizer()
        q_model.save(args.output_graph)

    if args.benchmark:
        from lpot.experimental import Benchmark
        evaluator = Benchmark(args.config)
        evaluator.model = args.input_graph
        results = evaluator()
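        # Each benchmark result unpacks to (accuracy, batch size, per-batch timings).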
        for mode, result in results.items():
            acc, batch_size, result_list = result
            latency = np.array(result_list).mean() / batch_size

            print('\n{} mode benchmark result:'.format(mode))
            print('Accuracy is {:.3f}'.format(acc))
            print('Batch size = {}'.format(batch_size))
            print('Latency: {:.3f} ms'.format(latency * 1000))
            print('Throughput: {:.3f} images/sec'.format(1. / latency))
Example #6
    def run(self):
        if self.args.tune:
            from lpot.experimental import Quantization
            quantizer = Quantization(self.args.config)
            quantizer.model = self.args.input_graph
            q_model = quantizer()
            q_model.save(self.args.output_model)

        if self.args.benchmark:
            from lpot.experimental import Benchmark
            evaluator = Benchmark(self.args.config)
            evaluator.model = self.args.input_graph
            evaluator(self.args.mode)
Example #7
def main(_):
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)

    if FLAGS.benchmark:
        from lpot.experimental import Benchmark
        evaluator = Benchmark(FLAGS.config)
        evaluator.model = FLAGS.input_model
        evaluator(FLAGS.mode)

    elif FLAGS.tune:
        from lpot.experimental import Quantization
        quantizer = Quantization(FLAGS.config)
        quantizer.model = FLAGS.input_model
        q_model = quantizer()
        q_model.save(FLAGS.output_model)
Example #8
    def run(self):
        """ This is lpot function include tuning and benchmark option """

        if self.args.tune:
            from lpot.experimental import Quantization, common
            quantizer = Quantization(self.args.config)
            quantizer.model = common.Model(self.args.input_graph)
            q_model = quantizer()
            q_model.save(self.args.output_graph)

        if self.args.benchmark:
            from lpot.experimental import Benchmark, common
            evaluator = Benchmark(self.args.config)
            evaluator.model = common.Model(self.args.input_graph)
            evaluator(self.args.mode)
Example #9
def main(_):
    arg_parser = ArgumentParser(description='Parse args')

    arg_parser.add_argument("--input-graph",
                            help='Specify the slim model',
                            dest='input_graph')

    arg_parser.add_argument("--output-graph",
                            help='Specify tune result model save dir',
                            dest='output_graph')

    arg_parser.add_argument("--config", default=None, help="tuning config")

    arg_parser.add_argument('--benchmark',
                            dest='benchmark',
                            action='store_true',
                            help='run benchmark')

    arg_parser.add_argument('--mode',
                            dest='mode',
                            default='performance',
                            help='benchmark mode')

    arg_parser.add_argument('--tune',
                            dest='tune',
                            action='store_true',
                            help='use lpot to tune.')

    args = arg_parser.parse_args()

    factory = TFSlimNetsFactory()
    # user-specific models can be registered with the slim net factory
    input_shape = [None, 299, 299, 3]
    factory.register('inception_v4', inception_v4, input_shape,
                     inception_v4_arg_scope)

    if args.tune:
        from lpot.experimental import Quantization
        quantizer = Quantization(args.config)
        quantizer.model = args.input_graph
        q_model = quantizer()
        q_model.save(args.output_graph)

    if args.benchmark:
        from lpot.experimental import Benchmark
        evaluator = Benchmark(args.config)
        evaluator.model = args.input_graph
        evaluator(args.mode)
Example #10
def benchmark_model(
    input_graph: str,
    config: str,
    benchmark_mode: str,
    framework: str,
) -> None:
    """Execute benchmark."""
    from lpot.experimental import Benchmark, common

    if framework == "onnxrt":
        import onnx

        input_graph = onnx.load(input_graph)

    evaluator = Benchmark(config)
    evaluator.model = common.Model(input_graph)
    evaluator(benchmark_mode)
Example #11
    def test_performance(self):
        from lpot.data import DATASETS
        dataset = DATASETS('tensorflow')['dummy']((100, 256, 256, 1),
                                                  label=True)

        from lpot.experimental import Quantization, common
        from lpot.utils.utility import get_size

        quantizer = Quantization('fake_yaml.yaml')
        quantizer.calib_dataloader = common.DataLoader(dataset)
        quantizer.eval_dataloader = common.DataLoader(dataset)
        quantizer.model = self.constant_graph
        q_model = quantizer()

        from lpot.experimental import Benchmark, common
        benchmarker = Benchmark('fake_yaml.yaml')
        benchmarker.b_dataloader = common.DataLoader(dataset)
        benchmarker.model = self.constant_graph_1
        benchmarker()
Example #12
def main():

    from lpot.experimental import Quantization, common
    quantizer = Quantization('./conf.yaml')
    quantizer.model = common.Model("./mobilenet_v1_1.0_224_frozen.pb")
    quantized_model = quantizer()

    # Optionally, run a benchmark on the quantized model
    from lpot.experimental import Benchmark
    evaluator = Benchmark('./conf.yaml')
    evaluator.model = common.Model(quantized_model.graph_def)
    results = evaluator()
    batch_size = 1
    for mode, result in results.items():
        acc, batch_size, result_list = result
        latency = np.array(result_list).mean() / batch_size

        print('Accuracy is {:.3f}'.format(acc))
        print('Latency: {:.3f} ms'.format(latency * 1000))
Example #13
def benchmark_model(
    input_graph: str,
    config: str,
    benchmark_mode: str,
    framework: str,
    datatype: str = "",
) -> List[Dict[str, Any]]:
    """Execute benchmark."""
    from lpot.experimental import Benchmark, common

    benchmark_results = []

    if framework == "onnxrt":
        import onnx

        input_graph = onnx.load(input_graph)

    evaluator = Benchmark(config)
    evaluator.model = common.Model(input_graph)
    results = evaluator()
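    # results maps each benchmark mode to an (accuracy, batch_size, per-batch timings) tuple.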
    for mode, result in results.items():
        if benchmark_mode == mode:
            log.info(f"Mode: {mode}")
            acc, batch_size, result_list = result
            latency = (sum(result_list) / len(result_list)) / batch_size
            log.info(f"Batch size: {batch_size}")
            if mode == "accuracy":
                log.info(f"Accuracy: {acc:.3f}")
            elif mode == "performance":
                log.info(f"Latency: {latency * 1000:.3f} ms")
                log.info(f"Throughput: {1. / latency:.3f} images/sec")

            benchmark_results.append(
                {
                    "precision": datatype,
                    "mode": mode,
                    "batch_size": batch_size,
                    "accuracy": acc,
                    "latency": latency * 1000,
                    "throughput": 1.0 / latency,
                }, )
    return benchmark_results
Example #14
def main():
    arg_parser = ArgumentParser(description='Parse args')
    arg_parser.add_argument('--benchmark',
                            action='store_true',
                            help='run benchmark')
    arg_parser.add_argument('--tune', action='store_true', help='run tuning')
    args = arg_parser.parse_args()

    if args.tune:
        from lpot.experimental import Quantization, common
        quantizer = Quantization('./conf.yaml')
        quantizer.model = common.Model("./mobilenet_v1_1.0_224_frozen.pb")
        quantized_model = quantizer()
        quantized_model.save('./int8.pb')

    if args.benchmark:
        from lpot.experimental import Benchmark, common
        evaluator = Benchmark('./conf.yaml')
        evaluator.model = common.Model('int8.pb')
        evaluator(mode='accuracy')
Example #15
    def test_tuning_ipex(self):
        from lpot.experimental import Quantization
        model = torchvision.models.resnet18()
        quantizer = Quantization('ipex_yaml.yaml')
        dataset = quantizer.dataset('dummy', (100, 3, 256, 256), label=True)
        quantizer.model = common.Model(model)
        quantizer.calib_dataloader = common.DataLoader(dataset)
        quantizer.eval_dataloader = common.DataLoader(dataset)
        lpot_model = quantizer()
        lpot_model.save("./saved")
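        # Prefer TorchScript scripting; fall back to tracing with a dummy input if scripting fails.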
        try:
            script_model = torch.jit.script(model.to(ipex.DEVICE))
        except Exception:
            script_model = torch.jit.trace(
                model.to(ipex.DEVICE),
                torch.randn(10, 3, 224, 224).to(ipex.DEVICE))
        from lpot.experimental import Benchmark
        evaluator = Benchmark('ipex_yaml.yaml')
        evaluator.model = common.Model(script_model)
        evaluator.b_dataloader = common.DataLoader(dataset)
        results = evaluator()
Example #16
    def run(self):
        if self.args.tune:
            from lpot.experimental import Quantization
            quantizer = Quantization(self.args.config)
            quantizer.model = self.args.input_graph
            q_model = quantizer()
            q_model.save(self.args.output_model)
                
        if self.args.benchmark:
            from lpot.experimental import Benchmark
            evaluator = Benchmark(self.args.config)
            evaluator.model = self.args.input_graph
            results = evaluator()
            for mode, result in results.items():
                acc, batch_size, result_list = result
                latency = np.array(result_list).mean() / batch_size

                print('\n{} mode benchmark result:'.format(mode))
                print('Accuracy is {:.3f}'.format(acc))
                print('Batch size = {}'.format(batch_size))
                print('Latency: {:.3f} ms'.format(latency * 1000))
                print('Throughput: {:.3f} images/sec'.format(1. / latency))
Example #17
def main(_):
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)

    if FLAGS.mode == 'benchmark':
        from lpot.experimental import Benchmark
        evaluator = Benchmark(FLAGS.config)
        evaluator.model = FLAGS.input_model
        results = evaluator()
        for mode, result in results.items():
            acc, batch_size, result_list = result
            latency = np.array(result_list).mean() / batch_size
            print('\n{} mode benchmark result:'.format(mode))
            print('Accuracy is {:.3f}'.format(acc))
            print('Batch size = {}'.format(batch_size))
            print('Latency: {:.3f} ms'.format(latency * 1000))
            print('Throughput: {:.3f} images/sec'.format(1. / latency))
    elif FLAGS.mode == 'tune':
        from lpot.quantization import Quantization
        quantizer = Quantization(FLAGS.config)
        quantizer.model = FLAGS.input_model
        q_model = quantizer()
        q_model.save(FLAGS.output_model)
Example #18
        help="whether quantize the model"
    )
    parser.add_argument('--config', type=str, help="config yaml path")
    parser.add_argument('--output_model', type=str, help="output model path")

    parser.add_argument('--mode',
                        type=str,
                        help="benchmark mode of performance or accuracy")

    args = parser.parse_args()

    model = onnx.load(args.model_path)
    if args.benchmark:
        from lpot.experimental import Benchmark, common
        evaluator = Benchmark(args.config)
        evaluator.model = common.Model(model)
        evaluator(args.mode)

    if args.tune:
        from lpot.experimental import Quantization, common

        quantize = Quantization(args.config)
        quantize.model = common.Model(model)
        q_model = quantize()
        q_model.save(args.output_model)

        if args.benchmark:
            from lpot.experimental import Benchmark
            evaluator = Benchmark(args.config)
            evaluator.model = common.Model(q_model)
            evaluator(args.mode)
Example #19

class dataloader(object):
    def __init__(self, batch_size=100):
        mnist = keras.datasets.mnist
        (train_images, train_labels), (test_images,
                                       test_labels) = mnist.load_data()

        # Normalize the input images so that each pixel value is between 0 and 1.
        self.train_images = train_images / 255.0
        self.test_images = test_images / 255.0
        self.train_labels = train_labels
        self.test_labels = test_labels

        self.batch_size = batch_size
        self.i = 0

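    # Yield the test set in consecutive (images, labels) batches.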
    def __iter__(self):
        while self.i < len(self.test_images):
            yield self.test_images[self.i:self.i +
                                   self.batch_size], self.test_labels[
                                       self.i:self.i + self.batch_size]
            self.i = self.i + self.batch_size


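# Benchmark the quantized model's accuracy using the custom dataloader defined above.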
from lpot.experimental import Benchmark, common
evaluator = Benchmark('mnist.yaml')
evaluator.model = common.Model('quantized_model')
evaluator.b_dataloader = dataloader()
evaluator('accuracy')
Example #20
def main(_):
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)

    processors = {
        "cola": ColaProcessor,
        "mnli": MnliProcessor,
        "mrpc": MrpcProcessor,
        "xnli": XnliProcessor,
    }

    tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
                                                  FLAGS.init_checkpoint)

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    if FLAGS.max_seq_length > bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (FLAGS.max_seq_length, bert_config.max_position_embeddings))

    tf.compat.v1.gfile.MakeDirs(FLAGS.output_dir)

    task_name = FLAGS.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()

    label_list = processor.get_labels()

    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
                                           do_lower_case=FLAGS.do_lower_case)

    tpu_cluster_resolver = None
    if FLAGS.use_tpu and FLAGS.tpu_name:
        tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

    is_per_host = tf.compat.v1.estimator.tpu.InputPipelineConfig.PER_HOST_V2
    session_config = tf.compat.v1.ConfigProto(
        inter_op_parallelism_threads=FLAGS.num_inter_threads,
        intra_op_parallelism_threads=FLAGS.num_intra_threads)
    run_config = tf.compat.v1.estimator.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        master=FLAGS.master,
        model_dir=FLAGS.output_dir,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        tpu_config=tf.compat.v1.estimator.tpu.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            num_shards=FLAGS.num_tpu_cores,
            per_host_input_for_training=is_per_host),
        session_config=session_config)

    train_examples = None
    num_train_steps = None
    num_warmup_steps = None
    if FLAGS.do_train:
        train_examples = processor.get_train_examples(FLAGS.data_dir)
        num_train_steps = int(
            len(train_examples) / FLAGS.train_batch_size *
            FLAGS.num_train_epochs)
        num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

    model_fn = model_fn_builder(
        bert_config=bert_config,
        num_labels=len(label_list),
        init_checkpoint=FLAGS.init_checkpoint,
        learning_rate=FLAGS.learning_rate,
        num_train_steps=num_train_steps,
        num_warmup_steps=num_warmup_steps,
        use_tpu=FLAGS.use_tpu,
        use_one_hot_embeddings=FLAGS.use_tpu,
    )

    # If TPU is not available, this will fall back to normal Estimator on CPU
    # or GPU.
    estimator = tf.compat.v1.estimator.tpu.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=model_fn,
        config=run_config,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size,
        predict_batch_size=FLAGS.predict_batch_size)

    if FLAGS.do_train:
        train_file = os.path.join(FLAGS.output_dir, "train.tf_record")
        file_based_convert_examples_to_features(train_examples, label_list,
                                                FLAGS.max_seq_length,
                                                tokenizer, train_file)
        tf.compat.v1.logging.info("***** Running training *****")
        tf.compat.v1.logging.info("  Num examples = %d", len(train_examples))
        tf.compat.v1.logging.info("  Batch size = %d", FLAGS.train_batch_size)
        tf.compat.v1.logging.info("  Num steps = %d", num_train_steps)
        train_input_fn = file_based_input_fn_builder(
            input_file=train_file,
            seq_length=FLAGS.max_seq_length,
            is_training=True,
            drop_remainder=True)
        estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)

    if FLAGS.do_eval:
        eval_examples = processor.get_dev_examples(FLAGS.data_dir)
        num_actual_eval_examples = len(eval_examples)
        if FLAGS.use_tpu:
            # TPU requires a fixed batch size for all batches, therefore the number
            # of examples must be a multiple of the batch size, or else examples
            # will get dropped. So we pad with fake examples which are ignored
            # later on. These do NOT count towards the metric (all tf.metrics
            # support a per-instance weight, and these get a weight of 0.0).
            while len(eval_examples) % FLAGS.eval_batch_size != 0:
                eval_examples.append(PaddingInputExample())

        eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record")
        file_based_convert_examples_to_features(eval_examples, label_list,
                                                FLAGS.max_seq_length,
                                                tokenizer, eval_file)

        tf.compat.v1.logging.info("***** Running evaluation *****")
        tf.compat.v1.logging.info(
            "  Num examples = %d (%d actual, %d padding)", len(eval_examples),
            num_actual_eval_examples,
            len(eval_examples) - num_actual_eval_examples)
        tf.compat.v1.logging.info("  Batch size = %d", FLAGS.eval_batch_size)

        # This tells the estimator to run through the entire set.
        eval_steps = None
        # However, if running eval on the TPU, you will need to specify the
        # number of steps.
        if FLAGS.use_tpu:
            assert len(eval_examples) % FLAGS.eval_batch_size == 0
        eval_steps = int(len(eval_examples) // FLAGS.eval_batch_size) + 1

        eval_drop_remainder = True if FLAGS.use_tpu else False
        eval_input_fn = file_based_input_fn_builder(
            input_file=eval_file,
            seq_length=FLAGS.max_seq_length,
            is_training=False,
            drop_remainder=eval_drop_remainder)

        start = time.time()
        result = estimator.evaluate(input_fn=eval_input_fn,
                                    steps=eval_steps,
                                    hooks=[LoggerHook()])
        end = time.time() - start
        result['global_step'] = str(eval_steps)
        result['latency_total'] = str(end)
        result['latency_per_step'] = str(end / eval_steps)
        if FLAGS.eval_batch_size != 1:
            result['samples_per_sec'] = str(FLAGS.eval_batch_size /
                                            (end / eval_steps))

        output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
        with tf.compat.v1.gfile.GFile(output_eval_file, "w") as writer:
            tf.compat.v1.logging.info("***** Eval results *****")
            for key in sorted(result.keys()):
                tf.compat.v1.logging.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))

    # BELOW IS LPOT TUNING AND BENCHMARK CODE

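    # Minimal one-item dataset that passes the TFRecord path and batch size through the dataloader.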
    class Dataset(object):
        def __init__(self, file_name, batch_size):
            self.file_name = file_name
            self.batch_size = batch_size

        def __getitem__(self, idx):
            return (self.file_name, self.batch_size), 0

        def __len__(self):
            return 1

    def collate_fn(batch):
        """Puts each data field into a pd frame with outer dimension batch size"""
        elem = batch[0]
        return elem

    from lpot.metric import METRICS

    class Accuracy(object):
        def __init__(self):
            self.metric = METRICS('tensorflow')['Accuracy']()

        # it's ugly that the label is in the iterator
        def update(self, preds, label):
            logits, labels = preds
            self.metric.update(logits, labels)

        def reset(self):
            self.metric.reset()

        def result(self):
            return self.metric.result()

    if FLAGS.tune:

        eval_examples = processor.get_dev_examples(FLAGS.data_dir)
        eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record")

        convert_examples_to_features(examples=eval_examples,
                                     label_list=label_list,
                                     max_seq_length=FLAGS.max_seq_length,
                                     tokenizer=tokenizer,
                                     output_file=eval_file)

        estimator_input_fn = input_fn_builder(input_file=eval_file,
                                              seq_length=FLAGS.max_seq_length,
                                              is_training=False,
                                              drop_remainder=False)

        from lpot.experimental import Quantization, common
        quantizer = Quantization(FLAGS.config)
        dataset = Dataset(eval_file, FLAGS.eval_batch_size)
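        # Hand LPOT the Estimator together with its input_fn via common.Model.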
        quantizer.model = common.Model(estimator, input_fn=estimator_input_fn)
        quantizer.calib_dataloader = common.DataLoader(dataset,
                                                       collate_fn=collate_fn)
        quantizer.eval_dataloader = common.DataLoader(dataset,
                                                      collate_fn=collate_fn)
        quantizer.metric = common.Metric(metric_cls=Accuracy)
        q_model = quantizer()
        q_model.save(FLAGS.output_model)

    if FLAGS.benchmark:
        eval_examples = processor.get_dev_examples(FLAGS.data_dir)
        eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record")

        from lpot.experimental import Benchmark, common
        evaluator = Benchmark(FLAGS.config)
        dataset = Dataset(eval_file, FLAGS.eval_batch_size)
        evaluator.b_dataloader = common.DataLoader(
            dataset, batch_size=FLAGS.eval_batch_size, collate_fn=collate_fn)
        evaluator.metric = common.Metric(metric_cls=Accuracy)

        from lpot.model.model import get_model_type
        model_type = get_model_type(FLAGS.input_model)
        if model_type == 'frozen_pb':
            evaluator.model = FLAGS.input_model
        else:
            estimator_input_fn = input_fn_builder(
                input_file=eval_file,
                seq_length=FLAGS.max_seq_length,
                is_training=False,
                drop_remainder=False)
            evaluator.model = common.Model(estimator,
                                           input_fn=estimator_input_fn)
        evaluator(FLAGS.mode)