def main(_):
    """Build the ResNet inference graph, freeze it and write it to disk.

    All configuration is read from the module-level ``FLAGS`` object. The
    positional argument is ignored.

    Steps:
      1. Create an ``input`` float32 placeholder whose layout follows
         ``FLAGS.input_format`` (channels-first for ``'NCHW'``, otherwise
         channels-last).
      2. Build the model named by ``FLAGS.model_name`` in inference mode.
      3. Optionally apply the quantized eval-graph rewrite
         (``FLAGS.quantize``).
      4. Restore ``FLAGS.checkpoint`` if given, otherwise run the global
         variable initializer.
      5. Convert variables to constants with the ``probs`` op as the only
         output node and write the result to ``FLAGS.output_file``.

    Raises:
        ValueError: if ``--output_file`` was not supplied.
    """
    # Initialize Horovod (TODO: Remove dependency of horovod for freezing graphs)
    hvd.init()

    if not FLAGS.output_file:
        raise ValueError(
            'You must supply the path to save to with --output_file')

    tf.logging.set_verbosity(tf.logging.INFO)

    with tf.Graph().as_default() as graph:
        # Placeholder layout: channels-first only when explicitly requested.
        side = FLAGS.image_size
        placeholder_shape = (
            [FLAGS.batch_size, 3, side, side]
            if FLAGS.input_format == 'NCHW'
            else [FLAGS.batch_size, side, side, 3]
        )
        input_images = tf.placeholder(
            name='input', dtype=tf.float32, shape=placeholder_shape)

        # Look up the architecture description and instantiate the model.
        arch_config = resnet.model_architectures[FLAGS.model_name]
        network = resnet.ResnetModel(
            FLAGS.model_name,
            FLAGS.num_classes,
            arch_config['layers'],
            arch_config['widths'],
            arch_config['expansions'],
            FLAGS.compute_format,
            FLAGS.input_format,
        )
        # Only `probs` is used below; the logits are built but not exported.
        probs, _logits = network.build_model(
            input_images,
            training=False,
            reuse=False,
            use_final_conv=FLAGS.use_final_conv,
        )

        if FLAGS.quantize:
            tf.contrib.quantize.experimental_create_eval_graph(
                symmetric=FLAGS.symmetric, use_qdq=FLAGS.use_qdq)

        # The saver must be created after the (optional) graph rewrite so it
        # sees every variable present in the final graph.
        saver = tf.train.Saver()
        with tf.Session() as sess:
            if FLAGS.checkpoint:
                saver.restore(sess, FLAGS.checkpoint)
            else:
                sess.run(tf.global_variables_initializer())

            output_node_names = [probs.op.name]
            frozen_graph_def = tf.graph_util.convert_variables_to_constants(
                sess, graph.as_graph_def(), output_node_names)

        # Write out the frozen graph, split into directory and file name.
        target_dir = os.path.dirname(FLAGS.output_file)
        target_name = os.path.basename(FLAGS.output_file)
        tf.io.write_graph(frozen_graph_def,
                          target_dir,
                          target_name,
                          as_text=FLAGS.write_text_graphdef)
Exemple #2
0
    def __init__(
            self,
            # ========= Model HParams ========= #
            n_classes=1001,
            architecture='resnet50',
            input_format='NHWC',  # NCHW or NHWC
            compute_format='NCHW',  # NCHW or NHWC
            dtype=tf.float32,  # tf.float32 or tf.float16
            n_channels=3,
            height=224,
            width=224,
            distort_colors=False,
            model_dir=None,
            log_dir=None,
            data_dir=None,
            data_idx_dir=None,
            weight_init="fan_out",

            # ======= Optimization HParams ======== #
            use_xla=False,
            use_tf_amp=False,
            use_dali=False,
            gpu_memory_fraction=1.0,
            gpu_id=0,

            # ======== Debug Flags ======== #
            debug_verbosity=0,
            seed=None):
        """Validate the configuration, set up the environment and build the model.

        Besides storing hyper-parameters in ``self.run_hparams`` and the
        network in ``self._model``, this constructor writes several
        process-wide environment variables (see the body) and, when a seed
        is given, seeds NumPy and TensorFlow.

        Args:
            n_classes: number of output classes passed to the model.
            architecture: key into ``resnet.model_architectures``; also used
                as the model name.
            input_format: ``'NHWC'`` or ``'NCHW'``.
            compute_format: ``'NHWC'`` or ``'NCHW'``.
            dtype: ``tf.float32`` or ``tf.float16``.
            n_channels: 1 (grayscale) or 3 (color).
            height: input height stored in the hyper-parameters.
            width: input width stored in the hyper-parameters.
            distort_colors: stored in the model hyper-parameters.
            model_dir: stored in the hyper-parameters; replaced by ``None``
                when Horovod is in use and this process is not rank 0.
            log_dir: same rank-0-only handling as ``model_dir``.
            data_dir: stored in the hyper-parameters.
            data_idx_dir: stored in the hyper-parameters.
            weight_init: forwarded to ``resnet.ResnetModel``.
            use_xla: stored in the hyper-parameters.
            use_tf_amp: enables the TF auto-mixed-precision graph rewrite via
                environment variable; incompatible with ``tf.float16``.
            use_dali: stored in the hyper-parameters and forwarded to the
                model; also selects 4 instead of 10 preprocessing threads.
            gpu_memory_fraction: stored in the hyper-parameters.
            gpu_id: stored in the hyper-parameters.
            debug_verbosity: accepted but not used by this constructor.
            seed: optional base seed; the effective seed is
                ``2 * (seed + hvd.rank())``.

        Raises:
            ValueError: on an unsupported ``dtype``, ``compute_format``,
                ``input_format`` or ``n_channels``.
            RuntimeError: if ``use_tf_amp`` is requested together with
                ``tf.float16``.
        """

        if dtype not in [tf.float32, tf.float16]:
            raise ValueError(
                "Unknown dtype received: %s (allowed: `tf.float32` and `tf.float16`)"
                % dtype)

        if compute_format not in ["NHWC", 'NCHW']:
            raise ValueError(
                "Unknown `compute_format` received: %s (allowed: ['NHWC', 'NCHW'])"
                % compute_format)

        if input_format not in ["NHWC", 'NCHW']:
            raise ValueError(
                "Unknown `input_format` received: %s (allowed: ['NHWC', 'NCHW'])"
                % input_format)

        if n_channels not in [1, 3]:
            raise ValueError(
                "Unsupported number of channels: %d (allowed: 1 (grayscale) and 3 (color))"
                % n_channels)

        # Derive a per-rank seed so each Horovod worker gets a distinct value.
        tf_seed = 2 * (seed + hvd.rank()) if seed is not None else None

        # ============================================
        # Optimisation Flags - Do not remove
        # ============================================

        os.environ['CUDA_CACHE_DISABLE'] = '0'

        os.environ['HOROVOD_GPU_ALLREDUCE'] = 'NCCL'

        #os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

        os.environ['TF_GPU_THREAD_MODE'] = 'gpu_private'
        os.environ['TF_GPU_THREAD_COUNT'] = '1' if not hvd_utils.is_using_hvd(
        ) else str(hvd.size())

        os.environ['TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT'] = '1'

        os.environ['TF_ADJUST_HUE_FUSED'] = '1'
        os.environ['TF_ADJUST_SATURATION_FUSED'] = '1'
        os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'

        os.environ['TF_SYNC_ON_FINISH'] = '0'
        os.environ['TF_AUTOTUNE_THRESHOLD'] = '2'
        os.environ['TF_DISABLE_NVTX_RANGES'] = '1'
        # Append to (rather than overwrite) any XLA flags already present.
        os.environ["TF_XLA_FLAGS"] = (
            os.environ.get("TF_XLA_FLAGS", "") +
            " --tf_xla_enable_lazy_compilation=false")

        # ============================================
        # TF-AMP Setup - Do not remove
        # ============================================

        if dtype == tf.float16:
            # With explicit FP16 the AMP variable is left untouched; asking
            # for AMP on top of FP16 is rejected.
            if use_tf_amp:
                raise RuntimeError(
                    "TF AMP can not be activated for FP16 precision")

        elif use_tf_amp:
            os.environ["TF_ENABLE_AUTO_MIXED_PRECISION_GRAPH_REWRITE"] = "1"
        else:
            os.environ["TF_ENABLE_AUTO_MIXED_PRECISION_GRAPH_REWRITE"] = "0"

        # =================================================

        # Each dimension is stored under its own name, so non-square inputs
        # keep the height and width the caller supplied.
        model_hparams = tf.contrib.training.HParams(
            width=width,
            height=height,
            n_channels=n_channels,
            n_classes=n_classes,
            dtype=dtype,
            input_format=input_format,
            compute_format=compute_format,
            distort_colors=distort_colors,
            seed=tf_seed)

        num_preprocessing_threads = 10 if not use_dali else 4
        run_config_performance = tf.contrib.training.HParams(
            num_preprocessing_threads=num_preprocessing_threads,
            use_tf_amp=use_tf_amp,
            use_xla=use_xla,
            use_dali=use_dali,
            gpu_memory_fraction=gpu_memory_fraction,
            gpu_id=gpu_id)

        # Only rank 0 (or a non-Horovod run) keeps the model/log directories;
        # every other worker gets None for both.
        is_primary = not hvd_utils.is_using_hvd() or hvd.rank() == 0
        run_config_additional = tf.contrib.training.HParams(
            model_dir=model_dir if is_primary else None,
            log_dir=log_dir if is_primary else None,
            data_dir=data_dir,
            data_idx_dir=data_idx_dir,
            num_preprocessing_threads=num_preprocessing_threads)

        self.run_hparams = Runner._build_hparams(model_hparams,
                                                 run_config_additional,
                                                 run_config_performance)

        # Resolve the architecture name to its configuration entry.
        model_name = architecture
        arch_config = resnet.model_architectures[architecture]

        self._model = resnet.ResnetModel(
            model_name=model_name,
            n_classes=model_hparams.n_classes,
            layers_count=arch_config["layers"],
            layers_depth=arch_config["widths"],
            expansions=arch_config["expansions"],
            input_format=model_hparams.input_format,
            compute_format=model_hparams.compute_format,
            dtype=model_hparams.dtype,
            weight_init=weight_init,
            use_dali=use_dali,
            # Optional architecture keys fall back to these defaults.
            cardinality=arch_config['cardinality']
            if 'cardinality' in arch_config else 1,
            use_se=arch_config['use_se']
            if 'use_se' in arch_config else False,
            se_ratio=arch_config['se_ratio']
            if 'se_ratio' in arch_config else 1)

        if self.run_hparams.seed is not None:
            np.random.seed(self.run_hparams.seed)
            tf.set_random_seed(self.run_hparams.seed)

        self.training_logging_hook = None
        self.eval_logging_hook = None