def main(_):
    """Build a ResNet inference graph, freeze its variables and save it.

    All configuration comes from the module-level ``FLAGS`` object. The graph
    takes a float32 placeholder named ``'input'``, and the frozen output is the
    model's ``probs`` tensor. Variables are either restored from
    ``FLAGS.checkpoint`` or, when no checkpoint is given, freshly initialised.

    Raises:
        ValueError: if ``FLAGS.output_file`` is empty.
    """
    # TODO: Remove dependency of horovod for freezing graphs
    hvd.init()

    target_path = FLAGS.output_file
    if not target_path:
        raise ValueError(
            'You must supply the path to save to with --output_file')

    tf.logging.set_verbosity(tf.logging.INFO)

    # Channel axis position follows the requested input layout.
    side = FLAGS.image_size
    channels_first = FLAGS.input_format == 'NCHW'
    placeholder_shape = ([FLAGS.batch_size, 3, side, side] if channels_first
                         else [FLAGS.batch_size, side, side, 3])

    with tf.Graph().as_default() as graph:
        images = tf.placeholder(name='input',
                                dtype=tf.float32,
                                shape=placeholder_shape)

        # Architecture description selected by --model_name (not necessarily
        # ResNet-50).
        arch_spec = resnet.model_architectures[FLAGS.model_name]
        network = resnet.ResnetModel(
            FLAGS.model_name,
            FLAGS.num_classes,
            arch_spec['layers'],
            arch_spec['widths'],
            arch_spec['expansions'],
            FLAGS.compute_format,
            FLAGS.input_format,
        )
        probs, _ = network.build_model(images,
                                       training=False,
                                       reuse=False,
                                       use_final_conv=FLAGS.use_final_conv)

        if FLAGS.quantize:
            tf.contrib.quantize.experimental_create_eval_graph(
                symmetric=FLAGS.symmetric, use_qdq=FLAGS.use_qdq)

        # The saver must be created after every variable exists in the graph.
        saver = tf.train.Saver()
        with tf.Session() as sess:
            if not FLAGS.checkpoint:
                sess.run(tf.global_variables_initializer())
            else:
                saver.restore(sess, FLAGS.checkpoint)
            # Replace variables with constants so the graph is self-contained,
            # keeping only what is needed to compute `probs`.
            frozen_def = tf.graph_util.convert_variables_to_constants(
                sess, graph.as_graph_def(), [probs.op.name])

    out_dir, out_name = os.path.split(target_path)
    tf.io.write_graph(frozen_def,
                      out_dir,
                      out_name,
                      as_text=FLAGS.write_text_graphdef)
def __init__(
        self,
        # ========= Model HParams ========= #
        n_classes=1001,
        architecture='resnet50',
        input_format='NHWC',  # NCHW or NHWC
        compute_format='NCHW',  # NCHW or NHWC
        dtype=tf.float32,  # tf.float32 or tf.float16
        n_channels=3,
        height=224,
        width=224,
        distort_colors=False,
        model_dir=None,
        log_dir=None,
        data_dir=None,
        data_idx_dir=None,
        weight_init="fan_out",
        # ======= Optimization HParams ======== #
        use_xla=False,
        use_tf_amp=False,
        use_dali=False,
        gpu_memory_fraction=1.0,
        gpu_id=0,
        # ======== Debug Flags ======== #
        debug_verbosity=0,
        seed=None):
    """Validate the configuration, set process-wide TF flags and build the model.

    The arguments are validated before any global state is touched. Several
    TensorFlow/Horovod environment variables are then set for the whole
    process. The model, run and performance hyper-parameters are merged into
    ``self.run_hparams`` via ``Runner._build_hparams``, and a
    ``resnet.ResnetModel`` is stored in ``self._model``.

    Args:
        n_classes: number of output classes.
        architecture: key into ``resnet.model_architectures``.
        input_format: layout of the input tensors, ``'NHWC'`` or ``'NCHW'``.
        compute_format: layout used inside the model, ``'NHWC'`` or ``'NCHW'``.
        dtype: ``tf.float32`` or ``tf.float16``.
        n_channels: 1 (grayscale) or 3 (color).
        height: input image height.
        width: input image width.
        distort_colors: forwarded to the model hparams.
        model_dir: kept only on the chief worker; other Horovod ranks get
            ``None``.
        log_dir: kept only on the chief worker; other Horovod ranks get
            ``None``.
        data_dir: forwarded to the run hparams.
        data_idx_dir: forwarded to the run hparams.
        weight_init: forwarded to ``resnet.ResnetModel``.
        use_xla: forwarded to the run hparams.
        use_tf_amp: enables the TF auto-mixed-precision graph rewrite.
            Incompatible with ``dtype=tf.float16``.
        use_dali: forwarded to the model. Also lowers the number of
            preprocessing threads from 10 to 4.
        gpu_memory_fraction: forwarded to the run hparams.
        gpu_id: forwarded to the run hparams.
        debug_verbosity: accepted for interface compatibility; not used here.
        seed: base random seed. Each Horovod rank derives its own seed,
            ``2 * (seed + rank)``. ``None`` disables seeding.

    Raises:
        ValueError: on an unsupported ``dtype``, ``compute_format``,
            ``input_format`` or ``n_channels``.
        RuntimeError: if ``use_tf_amp`` is requested with ``dtype=tf.float16``.
    """
    # ============================================
    # Argument validation (before any global side effects)
    # ============================================
    if dtype not in [tf.float32, tf.float16]:
        raise ValueError(
            "Unknown dtype received: %s (allowed: `tf.float32` and `tf.float16`)"
            % dtype)

    if compute_format not in ["NHWC", 'NCHW']:
        raise ValueError(
            "Unknown `compute_format` received: %s (allowed: ['NHWC', 'NCHW'])"
            % compute_format)

    if input_format not in ["NHWC", 'NCHW']:
        raise ValueError(
            "Unknown `input_format` received: %s (allowed: ['NHWC', 'NCHW'])"
            % input_format)

    if n_channels not in [1, 3]:
        raise ValueError(
            "Unsupported number of channels: %d (allowed: 1 (grayscale) and 3 (color))"
            % n_channels)

    # Checked up front so a failing constructor does not leave the process
    # environment half-modified.
    if use_tf_amp and dtype == tf.float16:
        raise RuntimeError("TF AMP can not be activated for FP16 precision")

    # Per-rank seed so that Horovod workers do not draw identical streams.
    tf_seed = 2 * (seed + hvd.rank()) if seed is not None else None

    using_hvd = hvd_utils.is_using_hvd()
    # Only the chief (or a non-distributed run) writes checkpoints and logs.
    is_chief = not using_hvd or hvd.rank() == 0

    # ============================================
    # Optimisation Flags - Do not remove
    # ============================================
    os.environ['CUDA_CACHE_DISABLE'] = '0'
    os.environ['HOROVOD_GPU_ALLREDUCE'] = 'NCCL'
    # Intentionally left disabled: os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
    os.environ['TF_GPU_THREAD_MODE'] = 'gpu_private'
    os.environ['TF_GPU_THREAD_COUNT'] = str(hvd.size()) if using_hvd else '1'
    os.environ['TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT'] = '1'
    os.environ['TF_ADJUST_HUE_FUSED'] = '1'
    os.environ['TF_ADJUST_SATURATION_FUSED'] = '1'
    os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'
    os.environ['TF_SYNC_ON_FINISH'] = '0'
    os.environ['TF_AUTOTUNE_THRESHOLD'] = '2'
    os.environ['TF_DISABLE_NVTX_RANGES'] = '1'

    # Append the XLA flag only once, so that constructing several runners in
    # one process does not keep growing TF_XLA_FLAGS.
    xla_flag = "--tf_xla_enable_lazy_compilation=false"
    current_xla_flags = os.environ.get("TF_XLA_FLAGS", "")
    if xla_flag not in current_xla_flags.split():
        os.environ["TF_XLA_FLAGS"] = current_xla_flags + " " + xla_flag

    # ============================================
    # TF-AMP Setup - Do not remove
    # ============================================
    # Always set explicitly: an inherited "1" must not enable the AMP rewrite
    # for an FP16 graph, or when AMP was not requested.
    os.environ["TF_ENABLE_AUTO_MIXED_PRECISION_GRAPH_REWRITE"] = (
        "1" if use_tf_amp else "0")

    # =================================================

    model_hparams = tf.contrib.training.HParams(
        width=width,
        height=height,
        n_channels=n_channels,
        n_classes=n_classes,
        dtype=dtype,
        input_format=input_format,
        compute_format=compute_format,
        distort_colors=distort_colors,
        seed=tf_seed)

    num_preprocessing_threads = 4 if use_dali else 10

    run_config_performance = tf.contrib.training.HParams(
        num_preprocessing_threads=num_preprocessing_threads,
        use_tf_amp=use_tf_amp,
        use_xla=use_xla,
        use_dali=use_dali,
        gpu_memory_fraction=gpu_memory_fraction,
        gpu_id=gpu_id)

    run_config_additional = tf.contrib.training.HParams(
        model_dir=model_dir if is_chief else None,
        log_dir=log_dir if is_chief else None,
        data_dir=data_dir,
        data_idx_dir=data_idx_dir,
        num_preprocessing_threads=num_preprocessing_threads)

    self.run_hparams = Runner._build_hparams(model_hparams,
                                             run_config_additional,
                                             run_config_performance)

    arch_config = resnet.model_architectures[architecture]

    self._model = resnet.ResnetModel(
        model_name=architecture,
        n_classes=model_hparams.n_classes,
        layers_count=arch_config["layers"],
        layers_depth=arch_config["widths"],
        expansions=arch_config["expansions"],
        input_format=model_hparams.input_format,
        compute_format=model_hparams.compute_format,
        dtype=model_hparams.dtype,
        weight_init=weight_init,
        use_dali=use_dali,
        # Optional entries in the architecture description, with defaults.
        cardinality=arch_config.get('cardinality', 1),
        use_se=arch_config.get('use_se', False),
        se_ratio=arch_config.get('se_ratio', 1))

    if self.run_hparams.seed is not None:
        np.random.seed(self.run_hparams.seed)
        tf.set_random_seed(self.run_hparams.seed)

    self.training_logging_hook = None
    self.eval_logging_hook = None