コード例 #1
0
    def evaluate(
        self,
        iter_unit,
        num_iter,
        batch_size,
        warmup_steps=50,
        log_every_n_steps=1,
        is_benchmark=False,
        export_dir=None,
        quantize=False,
        symmetric=False,
        use_qdq=False,
        use_final_conv=False,
    ):
        """Evaluate the model, log accuracy/latency and optionally export it.

        Args:
            iter_unit: Either "epoch" or "batch"; forwarded with `num_iter`
                to `runner_utils.parse_tfrecords_dataset`.
            num_iter: Number of iterations in units of `iter_unit`. Used
                directly as the step count when `data_dir` is not set.
            batch_size: Batch size used by the input function and reported
                to the logging hook as the global batch size.
            warmup_steps: Forwarded to `hooks.BenchmarkLoggingHook`.
            log_every_n_steps: Accepted for interface compatibility; not
                used by this method.
            is_benchmark: When True, `data_dir` may be None, in which case
                synthetic input is used.
            export_dir: If not None, the model is exported there after
                evaluation and the value returned by `export_savedmodel`
                is stored on `self.exported_path`.
            quantize: Passed to the estimator as a run parameter.
            symmetric: Passed to the estimator as a run parameter.
            use_qdq: Passed to the estimator as a run parameter.
            use_final_conv: Passed to the estimator as a run parameter.

        Raises:
            ValueError: If `iter_unit` is not "epoch"/"batch", or if
                `data_dir` is unset outside benchmark mode.
            RuntimeError: If Horovod is in use and this is not rank 0.
        """
        if iter_unit not in ["epoch", "batch"]:
            raise ValueError(
                '`iter_unit` value is unknown: %s (allowed: ["epoch", "batch"])'
                % iter_unit)

        if self.run_hparams.data_dir is None and not is_benchmark:
            raise ValueError('`data_dir` must be specified for evaluation!')

        if hvd_utils.is_using_hvd() and hvd.rank() != 0:
            raise RuntimeError('Multi-GPU inference is not supported')

        estimator_params = {
            'quantize': quantize,
            'symmetric': symmetric,
            'use_qdq': use_qdq,
            'use_final_conv': use_final_conv
        }

        image_classifier = self._get_estimator(
            mode='validation',
            run_params=estimator_params,
            use_xla=self.run_hparams.use_xla,
            use_dali=self.run_hparams.use_dali,
            gpu_memory_fraction=self.run_hparams.gpu_memory_fraction,
            gpu_id=self.run_hparams.gpu_id)

        has_dataset = self.run_hparams.data_dir is not None

        if has_dataset:
            filenames, _, num_steps, num_epochs, num_decay_steps = runner_utils.parse_tfrecords_dataset(
                data_dir=self.run_hparams.data_dir,
                mode="validation",
                iter_unit=iter_unit,
                num_iter=num_iter,
                global_batch_size=batch_size,
            )
        else:
            num_epochs = 1
            num_decay_steps = -1
            num_steps = num_iter

        # The DALI pipeline consumes both the record files and the index
        # files. `filenames` only exists when `data_dir` is set, so DALI is
        # only selected in that case; otherwise the synthetic path is used
        # instead of failing on an unbound `filenames`.
        use_dali_input = bool(has_dataset and self.run_hparams.use_dali
                              and self.run_hparams.data_idx_dir is not None)

        if use_dali_input:
            idx_filenames = runner_utils.parse_dali_idx_dataset(
                data_idx_dir=self.run_hparams.data_idx_dir, mode="validation")

        eval_hooks = []

        if hvd.rank() == 0:
            self.eval_logging_hook = hooks.BenchmarkLoggingHook(
                global_batch_size=batch_size, warmup_steps=warmup_steps)
            eval_hooks.append(self.eval_logging_hook)

            print('Starting Model Evaluation...')
            print("Evaluation Epochs", num_epochs)
            print("Evaluation Steps", num_steps)
            print("Decay Steps", num_decay_steps)
            print("Global Batch Size", batch_size)

        def evaluation_data_fn():
            """Build the evaluation input: DALI, TFRecords or synthetic."""
            deterministic = self.run_hparams.seed is not None

            if use_dali_input:
                if hvd.rank() == 0:
                    print("Using DALI input... ")

                return data_utils.get_dali_input_fn(
                    filenames=filenames,
                    idx_filenames=idx_filenames,
                    batch_size=batch_size,
                    height=self.run_hparams.height,
                    width=self.run_hparams.width,
                    training=False,
                    distort_color=self.run_hparams.distort_colors,
                    num_threads=self.run_hparams.num_preprocessing_threads,
                    deterministic=deterministic)

            if has_dataset:
                return data_utils.get_tfrecords_input_fn(
                    filenames=filenames,
                    batch_size=batch_size,
                    height=self.run_hparams.height,
                    width=self.run_hparams.width,
                    training=False,
                    distort_color=self.run_hparams.distort_colors,
                    num_threads=self.run_hparams.num_preprocessing_threads,
                    deterministic=deterministic)

            print("Using Synthetic Data ...\n")
            return data_utils.get_synth_input_fn(
                batch_size=batch_size,
                height=self.run_hparams.height,
                width=self.run_hparams.width,
                num_channels=self.run_hparams.n_channels,
                data_format=self.run_hparams.input_format,
                num_classes=self.run_hparams.n_classes,
                dtype=self.run_hparams.dtype,
            )

        try:
            eval_results = image_classifier.evaluate(
                input_fn=evaluation_data_fn,
                steps=num_steps,
                hooks=eval_hooks,
            )

            eval_throughput = self.eval_logging_hook.mean_throughput.value()
            # The hook's latencies are scaled by 1000 before being reported
            # (presumably seconds -> milliseconds; confirm against the hook).
            eval_latencies = np.array(self.eval_logging_hook.latencies) * 1000
            eval_latencies_q = np.quantile(eval_latencies, q=[0.9, 0.95, 0.99])
            eval_latencies_mean = np.mean(eval_latencies)

            dllogger.log(data={
                'top1_accuracy': float(eval_results['top1_accuracy']),
                'top5_accuracy': float(eval_results['top5_accuracy']),
                'eval_throughput': eval_throughput,
                'eval_latency_avg': eval_latencies_mean,
                'eval_latency_p90': eval_latencies_q[0],
                'eval_latency_p95': eval_latencies_q[1],
                'eval_latency_p99': eval_latencies_q[2],
            },
                         step=tuple())

            if export_dir is not None:
                dllogger.log(data={'export_dir': export_dir}, step=tuple())
                input_receiver_fn = data_utils.get_serving_input_receiver_fn(
                    batch_size=None,
                    height=self.run_hparams.height,
                    width=self.run_hparams.width,
                    num_channels=self.run_hparams.n_channels,
                    data_format=self.run_hparams.input_format,
                    dtype=self.run_hparams.dtype)

                # Keep the value returned by the export so callers can find
                # the exported model (it was previously discarded).
                self.exported_path = image_classifier.export_savedmodel(
                    export_dir, input_receiver_fn)

        except KeyboardInterrupt:
            print("Keyboard interrupt")

        print('Model evaluation finished')
コード例 #2
0
    def train(self,
              iter_unit,
              num_iter,
              run_iter,
              batch_size,
              warmup_steps=50,
              weight_decay=1e-4,
              lr_init=0.1,
              lr_warmup_epochs=5,
              momentum=0.9,
              log_every_n_steps=1,
              loss_scale=256,
              label_smoothing=0.0,
              mixup=0.0,
              use_cosine_lr=False,
              use_static_loss_scaling=False,
              is_benchmark=False,
              quantize=False,
              symmetric=False,
              quant_delay=0,
              finetune_checkpoint=None,
              use_final_conv=False,
              use_qdq=False):
        """Train the image classifier, resuming from the stored global step.

        Args:
            iter_unit: Either "epoch" or "batch"; unit of `num_iter` and
                `run_iter`.
            num_iter: Total training length in `iter_unit` units. Used
                directly as the step count when `data_dir` is not set.
            run_iter: How much to train in this call. -1 means all
                `num_steps`; otherwise it is multiplied by the steps per
                epoch when `iter_unit` is "epoch". The result is clamped to
                the steps remaining until `num_steps`.
            batch_size: Per-GPU batch size; the global batch size is this
                value times the number of GPUs.
            warmup_steps: Forwarded to `hooks.BenchmarkLoggingHook`
                (benchmark mode only).
            weight_decay, lr_init, lr_warmup_epochs, momentum, loss_scale,
            label_smoothing, mixup, use_cosine_lr, quantize, symmetric,
            quant_delay, use_final_conv, use_qdq: Passed to the estimator
                as run parameters.
            log_every_n_steps: Accepted for interface compatibility; not
                used by this method.
            use_static_loss_scaling: With mixed precision enabled
                (`use_tf_amp` or float16 dtype) selects static loss scaling
                and sets TF_ENABLE_AUTO_MIXED_PRECISION_LOSS_SCALING to "0"
                (otherwise "1"). Forced to False without mixed precision.
            is_benchmark: Allows `data_dir` to be None (synthetic input)
                and selects the benchmark logging hook.
            finetune_checkpoint: If truthy, added to the estimator run
                parameters under 'finetune_checkpoint'.

        Raises:
            ValueError: If `iter_unit` is not "epoch"/"batch", or if
                `data_dir` is unset outside benchmark mode.
        """
        if iter_unit not in ["epoch", "batch"]:
            raise ValueError(
                '`iter_unit` value is unknown: %s (allowed: ["epoch", "batch"])'
                % iter_unit)

        if self.run_hparams.data_dir is None and not is_benchmark:
            raise ValueError('`data_dir` must be specified for training!')

        if self.run_hparams.use_tf_amp or self.run_hparams.dtype == tf.float16:
            if use_static_loss_scaling:
                os.environ["TF_ENABLE_AUTO_MIXED_PRECISION_LOSS_SCALING"] = "0"
            else:
                os.environ["TF_ENABLE_AUTO_MIXED_PRECISION_LOSS_SCALING"] = "1"
        else:
            use_static_loss_scaling = False  # Make sure it hasn't been set to True on FP32 training

        num_gpus = 1 if not hvd_utils.is_using_hvd() else hvd.size()
        global_batch_size = batch_size * num_gpus

        has_dataset = self.run_hparams.data_dir is not None

        if has_dataset:
            filenames, num_samples, num_steps, num_epochs, num_decay_steps = runner_utils.parse_tfrecords_dataset(
                data_dir=self.run_hparams.data_dir,
                mode="train",
                iter_unit=iter_unit,
                num_iter=num_iter,
                global_batch_size=global_batch_size,
            )

            steps_per_epoch = num_steps / num_epochs

        else:
            num_epochs = 1
            num_steps = num_iter
            steps_per_epoch = num_steps
            num_decay_steps = num_steps
            num_samples = num_steps * batch_size

        if run_iter == -1:
            run_iter = num_steps
        else:
            run_iter = steps_per_epoch * run_iter if iter_unit == "epoch" else run_iter

        # The DALI pipeline consumes both the record files and the index
        # files. `filenames` only exists when `data_dir` is set, so DALI is
        # only selected in that case; otherwise the synthetic path is used
        # instead of failing on an unbound `filenames`.
        use_dali_input = bool(has_dataset and self.run_hparams.use_dali
                              and self.run_hparams.data_idx_dir is not None)

        if use_dali_input:
            idx_filenames = runner_utils.parse_dali_idx_dataset(
                data_idx_dir=self.run_hparams.data_idx_dir, mode="train")

        training_hooks = []

        if hvd.rank() == 0:
            print('Starting Model Training...')
            print("Training Epochs", num_epochs)
            print("Total Steps", num_steps)
            print("Steps per Epoch", steps_per_epoch)
            print("Decay Steps", num_decay_steps)
            print("Weight Decay Factor", weight_decay)
            print("Init Learning Rate", lr_init)
            print("Momentum", momentum)
            print("Num GPUs", num_gpus)
            print("Per-GPU Batch Size", batch_size)

            # The logging hook is only created (and later read) on rank 0.
            if is_benchmark:
                self.training_logging_hook = hooks.BenchmarkLoggingHook(
                    global_batch_size=global_batch_size,
                    warmup_steps=warmup_steps)
            else:
                self.training_logging_hook = hooks.TrainingLoggingHook(
                    global_batch_size=global_batch_size,
                    num_steps=num_steps,
                    num_samples=num_samples,
                    num_epochs=num_epochs,
                    steps_per_epoch=steps_per_epoch)
            training_hooks.append(self.training_logging_hook)

        if hvd_utils.is_using_hvd():
            bcast_hook = hvd.BroadcastGlobalVariablesHook(0)
            training_hooks.append(bcast_hook)

        training_hooks.append(hooks.PrefillStagingAreasHook())
        training_hooks.append(hooks.TrainingPartitionHook())

        estimator_params = {
            'batch_size': batch_size,
            'steps_per_epoch': steps_per_epoch,
            'num_gpus': num_gpus,
            'momentum': momentum,
            'lr_init': lr_init,
            'lr_warmup_epochs': lr_warmup_epochs,
            'weight_decay': weight_decay,
            'loss_scale': loss_scale,
            'apply_loss_scaling': use_static_loss_scaling,
            'label_smoothing': label_smoothing,
            'mixup': mixup,
            'num_decay_steps': num_decay_steps,
            'use_cosine_lr': use_cosine_lr,
            'use_final_conv': use_final_conv,
            'quantize': quantize,
            'use_qdq': use_qdq,
            'symmetric': symmetric,
            'quant_delay': quant_delay
        }

        if finetune_checkpoint:
            estimator_params['finetune_checkpoint'] = finetune_checkpoint

        image_classifier = self._get_estimator(
            mode='train',
            run_params=estimator_params,
            use_xla=self.run_hparams.use_xla,
            use_dali=self.run_hparams.use_dali,
            gpu_memory_fraction=self.run_hparams.gpu_memory_fraction,
            gpu_id=self.run_hparams.gpu_id)

        def training_data_fn():
            """Build the training input: DALI, TFRecords or synthetic."""
            deterministic = self.run_hparams.seed is not None

            if use_dali_input:
                if hvd.rank() == 0:
                    print("Using DALI input... ")

                return data_utils.get_dali_input_fn(
                    filenames=filenames,
                    idx_filenames=idx_filenames,
                    batch_size=batch_size,
                    height=self.run_hparams.height,
                    width=self.run_hparams.width,
                    training=True,
                    distort_color=self.run_hparams.distort_colors,
                    num_threads=self.run_hparams.num_preprocessing_threads,
                    deterministic=deterministic)

            if has_dataset:
                return data_utils.get_tfrecords_input_fn(
                    filenames=filenames,
                    batch_size=batch_size,
                    height=self.run_hparams.height,
                    width=self.run_hparams.width,
                    training=True,
                    distort_color=self.run_hparams.distort_colors,
                    num_threads=self.run_hparams.num_preprocessing_threads,
                    deterministic=deterministic)

            if hvd.rank() == 0:
                print("Using Synthetic Data ...")
            return data_utils.get_synth_input_fn(
                batch_size=batch_size,
                height=self.run_hparams.height,
                width=self.run_hparams.width,
                num_channels=self.run_hparams.n_channels,
                data_format=self.run_hparams.input_format,
                num_classes=self.run_hparams.n_classes,
                dtype=self.run_hparams.dtype,
            )

        # A ValueError from the lookup is treated as "no global step yet".
        try:
            current_step = image_classifier.get_variable_value("global_step")
        except ValueError:
            current_step = 0

        # Never train past `num_steps`, and never request a negative count.
        run_iter = max(0, min(run_iter, num_steps - current_step))
        print("Current step:", current_step)

        if run_iter > 0:
            try:
                image_classifier.train(
                    input_fn=training_data_fn,
                    steps=run_iter,
                    hooks=training_hooks,
                )
            except KeyboardInterrupt:
                print("Keyboard interrupt")

        if hvd.rank() == 0:
            if run_iter > 0:
                print('Ending Model Training ...')
                train_throughput = self.training_logging_hook.mean_throughput.value(
                )
                train_time = self.training_logging_hook.train_time
                dllogger.log(data={'train_throughput': train_throughput},
                             step=tuple())
                dllogger.log(data={'Total Training time': train_time},
                             step=tuple())
            else:
                print(
                    'Model already trained required number of steps. Skipped')
コード例 #3
0
    def train(
        self,
        iter_unit,
        num_iter,
        batch_size,
        warmup_steps=50,
        weight_decay=1e-4,
        lr_init=0.1,
        lr_warmup_epochs=5,
        momentum=0.9,
        log_every_n_steps=1,
        loss_scale=256,
        label_smoothing=0.0,
        use_cosine_lr=False,
        use_static_loss_scaling=False,
        is_benchmark=False
    ):
        """Train the image classifier for `num_iter` epochs or batches.

        Args:
            iter_unit: Either "epoch" or "batch"; unit of `num_iter`.
            num_iter: Training length in `iter_unit` units. Used directly
                as the step count when `data_dir` is not set.
            batch_size: Per-GPU batch size; the global batch size is this
                value times the number of GPUs.
            warmup_steps: Forwarded to `hooks.BenchmarkLoggingHook`
                (benchmark mode only).
            weight_decay, lr_init, lr_warmup_epochs, momentum, loss_scale,
            label_smoothing, use_cosine_lr: Passed to the estimator as run
                parameters.
            log_every_n_steps: Forwarded to the logging hook as `log_every`.
            use_static_loss_scaling: With mixed precision enabled
                (`use_tf_amp` or float16 dtype) selects static loss scaling
                and sets TF_ENABLE_AUTO_MIXED_PRECISION_LOSS_SCALING to "0"
                (otherwise "1"). Forced to False without mixed precision.
            is_benchmark: Allows `data_dir` to be None (synthetic input)
                and selects the benchmark logging hook.

        Raises:
            ValueError: If `iter_unit` is not "epoch"/"batch", or if
                `data_dir` is unset outside benchmark mode.
        """
        if iter_unit not in ["epoch", "batch"]:
            raise ValueError('`iter_unit` value is unknown: %s (allowed: ["epoch", "batch"])' % iter_unit)

        if self.run_hparams.data_dir is None and not is_benchmark:
            raise ValueError('`data_dir` must be specified for training!')

        if self.run_hparams.use_tf_amp or self.run_hparams.dtype == tf.float16:
            if use_static_loss_scaling:
                os.environ["TF_ENABLE_AUTO_MIXED_PRECISION_LOSS_SCALING"] = "0"
            else:
                LOGGER.log("TF Loss Auto Scaling is activated")
                os.environ["TF_ENABLE_AUTO_MIXED_PRECISION_LOSS_SCALING"] = "1"
        else:
            use_static_loss_scaling = False  # Make sure it hasn't been set to True on FP32 training

        num_gpus = 1 if not hvd_utils.is_using_hvd() else hvd.size()
        global_batch_size = batch_size * num_gpus

        has_dataset = self.run_hparams.data_dir is not None

        if has_dataset:
            filenames, num_samples, num_steps, num_epochs, num_decay_steps = runner_utils.parse_tfrecords_dataset(
                data_dir=self.run_hparams.data_dir,
                mode="train",
                iter_unit=iter_unit,
                num_iter=num_iter,
                global_batch_size=global_batch_size,
            )

            steps_per_epoch = num_steps / num_epochs

        else:
            num_epochs = 1
            num_steps = num_iter
            steps_per_epoch = num_steps
            num_decay_steps = num_steps
            num_samples = num_steps * batch_size

        # The index files are only consumed by the DALI input branch, and
        # that branch also needs `filenames`, which only exists when
        # `data_dir` is set. One flag drives both the parsing and the input
        # selection so they cannot disagree.
        use_dali_input = bool(has_dataset and self.run_hparams.use_dali
                              and self.run_hparams.data_idx_dir is not None)

        if use_dali_input:
            idx_filenames = runner_utils.parse_dali_idx_dataset(
                data_idx_dir=self.run_hparams.data_idx_dir,
                mode="train"
            )

        training_hooks = []

        if hvd.rank() == 0:
            LOGGER.log('Starting Model Training...')
            LOGGER.log("Training Epochs", num_epochs)
            LOGGER.log("Total Steps", num_steps)
            LOGGER.log("Steps per Epoch", steps_per_epoch)
            LOGGER.log("Decay Steps", num_decay_steps)
            LOGGER.log("Weight Decay Factor", weight_decay)
            LOGGER.log("Init Learning Rate", lr_init)
            LOGGER.log("Momentum", momentum)
            LOGGER.log("Num GPUs", num_gpus)
            LOGGER.log("Per-GPU Batch Size", batch_size)

            # Only rank 0 gets a logging hook.
            if is_benchmark:
                benchmark_logging_hook = hooks.BenchmarkLoggingHook(
                    log_file_path=os.path.join(self.run_hparams.log_dir, "training_benchmark.json"),
                    global_batch_size=global_batch_size,
                    log_every=log_every_n_steps,
                    warmup_steps=warmup_steps
                )

                training_hooks.append(benchmark_logging_hook)

            else:
                training_logging_hook = hooks.TrainingLoggingHook(
                    log_file_path=os.path.join(self.run_hparams.log_dir, "training.json"),
                    global_batch_size=global_batch_size,
                    num_steps=num_steps,
                    num_samples=num_samples,
                    num_epochs=num_epochs,
                    log_every=log_every_n_steps
                )

                training_hooks.append(training_logging_hook)

        if hvd_utils.is_using_hvd():
            bcast_hook = hvd.BroadcastGlobalVariablesHook(0)
            training_hooks.append(bcast_hook)

        training_hooks.append(hooks.PrefillStagingAreasHook())

        # NVTX
        nvtx_callback = NVTXHook(skip_n_steps=1, name='Train')
        training_hooks.append(nvtx_callback)

        estimator_params = {
            'batch_size': batch_size,
            'steps_per_epoch': steps_per_epoch,
            'num_gpus': num_gpus,
            'momentum': momentum,
            'lr_init': lr_init,
            'lr_warmup_epochs': lr_warmup_epochs,
            'weight_decay': weight_decay,
            'loss_scale': loss_scale,
            'apply_loss_scaling': use_static_loss_scaling,
            'label_smoothing': label_smoothing,
            'num_decay_steps': num_decay_steps,
            'use_cosine_lr': use_cosine_lr
        }

        image_classifier = self._get_estimator(
            mode='train',
            run_params=estimator_params,
            use_xla=self.run_hparams.use_xla,
            use_dali=self.run_hparams.use_dali,
            gpu_memory_fraction=self.run_hparams.gpu_memory_fraction
        )

        def training_data_fn():
            """Build the training input: DALI, TFRecords or synthetic."""
            deterministic = self.run_hparams.seed is not None

            if use_dali_input:
                if hvd.rank() == 0:
                    LOGGER.log("Using DALI input... ")

                return data_utils.get_dali_input_fn(
                    filenames=filenames,
                    idx_filenames=idx_filenames,
                    batch_size=batch_size,
                    height=self.run_hparams.height,
                    width=self.run_hparams.width,
                    training=True,
                    distort_color=self.run_hparams.distort_colors,
                    num_threads=self.run_hparams.num_preprocessing_threads,
                    deterministic=deterministic
                )

            if has_dataset:
                return data_utils.get_tfrecords_input_fn(
                    filenames=filenames,
                    batch_size=batch_size,
                    height=self.run_hparams.height,
                    width=self.run_hparams.width,
                    training=True,
                    distort_color=self.run_hparams.distort_colors,
                    num_threads=self.run_hparams.num_preprocessing_threads,
                    deterministic=deterministic
                )

            if hvd.rank() == 0:
                LOGGER.log("Using Synthetic Data ...")
            return data_utils.get_synth_input_fn(
                batch_size=batch_size,
                height=self.run_hparams.height,
                width=self.run_hparams.width,
                num_channels=self.run_hparams.n_channels,
                data_format=self.run_hparams.input_format,
                num_classes=self.run_hparams.n_classes,
                dtype=self.run_hparams.dtype,
            )

        try:
            image_classifier.train(
                input_fn=training_data_fn,
                steps=num_steps,
                hooks=training_hooks,
            )
        except KeyboardInterrupt:
            print("Keyboard interrupt")

        if hvd.rank() == 0:
            LOGGER.log('Ending Model Training ...')
コード例 #4
0
    def evaluate(
        self,
        iter_unit,
        num_iter,
        batch_size,
        warmup_steps=50,
        log_every_n_steps=1,
        is_benchmark=False
    ):
        """Evaluate the model and log top-1 / top-5 accuracy.

        Args:
            iter_unit: Either "epoch" or "batch"; forwarded with `num_iter`
                to `runner_utils.parse_tfrecords_dataset`.
            num_iter: Number of iterations in units of `iter_unit`. Used
                directly as the step count when `data_dir` is not set.
            batch_size: Batch size used by the input function and reported
                to the benchmark hook as the global batch size.
            warmup_steps: Forwarded to `hooks.BenchmarkLoggingHook`
                (benchmark mode only).
            log_every_n_steps: Forwarded to the benchmark hook as
                `log_every`.
            is_benchmark: Allows `data_dir` to be None (synthetic input)
                and attaches the benchmark logging hook on rank 0.

        Raises:
            ValueError: If `iter_unit` is not "epoch"/"batch", or if
                `data_dir` is unset outside benchmark mode.
            RuntimeError: If Horovod is in use and this is not rank 0.
        """
        if iter_unit not in ["epoch", "batch"]:
            raise ValueError('`iter_unit` value is unknown: %s (allowed: ["epoch", "batch"])' % iter_unit)

        if self.run_hparams.data_dir is None and not is_benchmark:
            raise ValueError('`data_dir` must be specified for evaluation!')

        if hvd_utils.is_using_hvd() and hvd.rank() != 0:
            raise RuntimeError('Multi-GPU inference is not supported')

        estimator_params = {}

        image_classifier = self._get_estimator(
            mode='validation',
            run_params=estimator_params,
            use_xla=self.run_hparams.use_xla,
            use_dali=self.run_hparams.use_dali,
            gpu_memory_fraction=self.run_hparams.gpu_memory_fraction
        )

        has_dataset = self.run_hparams.data_dir is not None

        if has_dataset:
            filenames, _, num_steps, num_epochs, num_decay_steps = runner_utils.parse_tfrecords_dataset(
                data_dir=self.run_hparams.data_dir,
                mode="validation",
                iter_unit=iter_unit,
                num_iter=num_iter,
                global_batch_size=batch_size,
            )

        else:
            num_epochs = 1
            num_decay_steps = -1
            num_steps = num_iter

        # The index files are only consumed by the DALI input branch, and
        # that branch also needs `filenames`, which only exists when
        # `data_dir` is set. One flag drives both the parsing and the input
        # selection so they cannot disagree.
        use_dali_input = bool(has_dataset and self.run_hparams.use_dali
                              and self.run_hparams.data_idx_dir is not None)

        if use_dali_input:
            idx_filenames = runner_utils.parse_dali_idx_dataset(
                data_idx_dir=self.run_hparams.data_idx_dir,
                mode="validation"
            )

        eval_hooks = []

        if hvd.rank() == 0:
            if is_benchmark:
                benchmark_logging_hook = hooks.BenchmarkLoggingHook(
                    log_file_path=os.path.join(self.run_hparams.log_dir, "eval_benchmark.json"),
                    global_batch_size=batch_size,
                    log_every=log_every_n_steps,
                    warmup_steps=warmup_steps
                )
                eval_hooks.append(benchmark_logging_hook)

            LOGGER.log('Starting Model Evaluation...')
            LOGGER.log("Evaluation Epochs", num_epochs)
            LOGGER.log("Evaluation Steps", num_steps)
            LOGGER.log("Decay Steps", num_decay_steps)
            LOGGER.log("Global Batch Size", batch_size)

        def evaluation_data_fn():
            """Build the evaluation input: DALI, TFRecords or synthetic."""
            deterministic = self.run_hparams.seed is not None

            if use_dali_input:
                if hvd.rank() == 0:
                    LOGGER.log("Using DALI input... ")

                return data_utils.get_dali_input_fn(
                    filenames=filenames,
                    idx_filenames=idx_filenames,
                    batch_size=batch_size,
                    height=self.run_hparams.height,
                    width=self.run_hparams.width,
                    training=False,
                    distort_color=self.run_hparams.distort_colors,
                    num_threads=self.run_hparams.num_preprocessing_threads,
                    deterministic=deterministic
                )

            if has_dataset:
                return data_utils.get_tfrecords_input_fn(
                    filenames=filenames,
                    batch_size=batch_size,
                    height=self.run_hparams.height,
                    width=self.run_hparams.width,
                    training=False,
                    distort_color=self.run_hparams.distort_colors,
                    num_threads=self.run_hparams.num_preprocessing_threads,
                    deterministic=deterministic
                )

            LOGGER.log("Using Synthetic Data ...\n")
            return data_utils.get_synth_input_fn(
                batch_size=batch_size,
                height=self.run_hparams.height,
                width=self.run_hparams.width,
                num_channels=self.run_hparams.n_channels,
                data_format=self.run_hparams.input_format,
                num_classes=self.run_hparams.n_classes,
                dtype=self.run_hparams.dtype,
            )

        try:
            eval_results = image_classifier.evaluate(
                input_fn=evaluation_data_fn,
                steps=num_steps,
                hooks=eval_hooks,
            )
            # Accuracies are multiplied by 100 before being logged.
            LOGGER.log('Top-1 Accuracy: %.3f' % float(eval_results['top1_accuracy'] * 100))
            LOGGER.log('Top-5 Accuracy: %.3f' % float(eval_results['top5_accuracy'] * 100))

        except KeyboardInterrupt:
            print("Keyboard interrupt")

        LOGGER.log('Ending Model Evaluation ...')
コード例 #5
0
def main(args: argparse.Namespace):
    """Run inference on a SavedModel and log throughput and accuracy.

    Steps, in order:
      1. Initialise Horovod and dllogger and log the parsed arguments.
      2. If `args.model` is None, train one benchmark step with a fresh
         runner, evaluate with `export_dir` set, and use the runner's
         `exported_path` as the model to load; otherwise use `args.model`.
      3. If `args.tf_trt` is set, convert the SavedModel with
         `trt.TrtGraphConverter` and load the converted model instead.
      4. Feed either the validation TFRecords from `args.data_dir` or one
         fixed random batch, and count predictions equal to the targets.

    Args:
        args: Parsed command-line arguments. This function reads
            `log_path`, `model`, `architecture`, `amp`, `batch_size`,
            `tf_trt` and `data_dir`.
    """
    hvd.init()

    dllogger.init(backends=[
        dllogger.JSONStreamBackend(verbosity=dllogger.Verbosity.VERBOSE,
                                   filename=args.log_path),
        dllogger.StdOutBackend(verbosity=dllogger.Verbosity.VERBOSE)
    ])
    dllogger.log(data=vars(args), step='PARAMETER')

    if args.model is None:
        # NOTE: this temporary directory is not removed by this function.
        saved_model_to_load = tempfile.mkdtemp(prefix="tftrt-savedmodel")
        r = runner.Runner(n_classes=1001,
                          architecture=args.architecture,
                          use_tf_amp=args.amp,
                          model_dir=saved_model_to_load)
        r.train("batch", 1, 1, args.batch_size, is_benchmark=True)
        r.evaluate("batch",
                   1,
                   args.batch_size,
                   export_dir=saved_model_to_load,
                   is_benchmark=True)

        saved_model_to_load = r.exported_path.decode("utf-8")
    else:
        saved_model_to_load = args.model

    output_tensor_name = "y_preds_ref:0" if not args.tf_trt else "ArgMax:0"
    batch_size = args.batch_size

    if args.tf_trt:
        converter = trt.TrtGraphConverter(
            input_saved_model_dir=str(saved_model_to_load),
            precision_mode="FP16" if args.amp else "FP32")
        converter.convert()
        converter.save(OUTPUT_SAVED_MODEL_PATH)
        saved_model_to_load = OUTPUT_SAVED_MODEL_PATH
    elif args.amp:
        os.environ["TF_ENABLE_AUTO_MIXED_PRECISION_GRAPH_REWRITE"] = "1"

    use_dataset = args.data_dir is not None

    if use_dataset:
        filenames, _, num_steps, _, _ = runner_utils.parse_tfrecords_dataset(
            data_dir=str(args.data_dir),
            mode="validation",
            iter_unit="epoch",
            num_iter=1,
            global_batch_size=batch_size,
        )

        dataset = data_utils.get_tfrecords_input_fn(filenames=filenames,
                                                    batch_size=batch_size,
                                                    height=224,
                                                    width=224,
                                                    training=False,
                                                    distort_color=False,
                                                    num_threads=1,
                                                    deterministic=True)
        iterator = dataset.make_initializable_iterator()
        next_item = iterator.get_next()
    else:
        # Synthetic run: enough steps to cover 60000 samples.
        num_steps = 60000 / batch_size

    with tf.Session() as sess:
        if use_dataset:
            sess.run(iterator.initializer)
        tf.saved_model.loader.load(sess,
                                   [tf.saved_model.tag_constants.SERVING],
                                   str(saved_model_to_load))

        # Initialised before the `try` so the `finally` block can always
        # report, whatever happens inside the loop.
        start_time = time.time()
        last_time = start_time
        image_processed = 0
        image_correct = 0

        try:
            for step in range(int(num_steps)):
                if use_dataset:
                    next_batch_image, next_batch_target = sess.run(next_item)
                elif step == 0:
                    # One random batch is generated once and reused for
                    # every synthetic step.
                    next_batch_image = np.random.normal(size=(batch_size,
                                                              224, 224, 3))
                    next_batch_target = np.random.randint(
                        0, 1000, size=(batch_size, ))
                output = sess.run(
                    [output_tensor_name],
                    feed_dict={"input_tensor:0": next_batch_image})
                image_processed += args.batch_size
                image_correct += np.sum(output == next_batch_target)

                if step % LOG_FREQUENCY == 0 and step != 0:
                    current_time = time.time()
                    current_throughput = LOG_FREQUENCY * batch_size / (
                        current_time - last_time)
                    dllogger.log(step=(0, step),
                                 data={"throughput": current_throughput})
                    last_time = current_time

        except tf.errors.OutOfRangeError:
            # The dataset iterator ran out before `num_steps`; report what
            # was processed so far.
            pass
        finally:
            # `last_time` only advances at periodic logs, so it lags behind
            # `image_processed` and equals `start_time` on short runs. Use
            # the real elapsed time and guard both divisions.
            elapsed = time.time() - start_time
            dllogger.log(step=tuple(),
                         data={
                             "throughput":
                             image_processed / elapsed if elapsed > 0 else 0.0,
                             "accuracy":
                             image_correct / image_processed
                             if image_processed else 0.0
                         })