def optimize_for_inference(frozen_graph_def: tf.GraphDef,
                           output_node_names: List,
                           graph_input: str) -> graph_pb2.GraphDef:
    """Run the inference optimization passes over a frozen graph.

    The graph goes through three stages in order: training-only nodes are
    removed, batch norms are folded, and the graph transform passes listed
    in the body are applied.

    Args:
        frozen_graph_def: Frozen graph definition to optimize.
        output_node_names: Names of the output nodes of the graph.
        graph_input: Name of the image input to the graph; ``:0`` is
            appended to it before it is handed to the transform step.

    Returns:
        The optimized inference graph definition.
    """
    logging.info('Starting graph optimization.')

    # Identity ops in the initializers have to be removed first so that the
    # batch norms can be fused with the convolutions in the folding step.
    pruned_graph_def = tf.graph_util.remove_training_nodes(frozen_graph_def)
    folded_graph_def = fold_batch_norms(pruned_graph_def)

    transforms = [
        'remove_nodes(op=Identity, op=CheckNumerics)',
        'strip_unused_nodes',
        'fold_constants(ignore_errors=true)',
    ]
    input_names = [f"{graph_input}:0"]
    transformed_graph_def = TransformGraph(folded_graph_def,
                                           input_names,
                                           output_node_names,
                                           transforms)

    logging.info('Completed graph optimization.')
    return transformed_graph_def
def construct_graph(batch_size: int = 1) -> tf.Tensor:
    """Build the optimized DenseNet inference graph placed on the IPU.

    The model is first built and its weights restored from the checkpoint
    directory (downloading them if no checkpoint is found). The graph is then
    frozen, stripped of training nodes and has its batch norms folded. Finally
    the default graph is reset and the optimized graph is imported under the
    ``optimized`` name scope on ``/device:IPU:0``.

    Args:
        batch_size: Batch size for inference.

    Returns:
        The ``output-prob:0`` tensor of the imported graph.

    Raises:
        ValueError: If no checkpoint can be found even after the download
            attempt.
    """
    # Build the network on a half-precision image placeholder.
    model = DenseNet(blocks=DENSENET_121_BLOCKS,
                     num_classes=NUM_CLASSES,
                     image_width=IMG_WIDTH,
                     image_height=IMG_HEIGHT,
                     image_channels=IMG_CHANNELS)
    input_shape = (batch_size, IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS)
    image_input = tf.placeholder(dtype=tf.float16,
                                 shape=input_shape,
                                 name="image_input")
    model(image_input)

    # Make sure there is a checkpoint to restore from.
    checkpoint_dir = CHECKPOINT_DIR
    if tf.train.latest_checkpoint(checkpoint_dir) is None:
        logging.info(
            'Checkpoint directory `%s` does not contain a checkpoint, attempting to download pre-trained weights.',
            Path(checkpoint_dir))
        get_densenet_weights(Path(checkpoint_dir))

    if tf.train.latest_checkpoint(checkpoint_dir) is None:
        raise ValueError(
            "Weight download failed. Please re-try downloading the weights using the `densenet_weights.py` script under models/tensorflow/")

    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, tf.train.latest_checkpoint(checkpoint_dir))
        logging.info('Successfully restored imagenet weights.')

        # Freeze the variables and optimize the graph for inference.
        logging.info('Starting graph optimization.')
        graph_def = tf.get_default_graph().as_graph_def()
        frozen = tf.compat.v1.graph_util.convert_variables_to_constants(
            sess, graph_def, output_node_names=["output-prob"])
        # Identity ops in the initializers must be removed first so that the
        # batch norms can be fused with the convolutions afterwards.
        frozen = tf.compat.v1.graph_util.remove_training_nodes(frozen)
        optimized_graph_def = optimize_for_infer.fold_batch_norms(frozen)
        logging.info('Completed graph optimization.')

    # Start from a clean default graph and import the optimized graph into it.
    tf.reset_default_graph()
    with tf.device('/device:IPU:0'), tf.variable_scope('', use_resource=True):
        imported = tf.import_graph_def(optimized_graph_def,
                                       input_map={},
                                       name="optimized",
                                       return_elements=["output-prob:0"])
        return imported[0]
def setUpClass(cls):
    """Build the optimized DenseNet graph once and store it on the class.

    Sets ``cls.batch_size``, ``cls.num_classes``, ``cls.placeholder_input``,
    ``cls.densenet_model`` and ``cls.output``. The model is built and its
    weights restored (downloading them when the checkpoint index file is
    missing). The graph is frozen, stripped of training nodes and has its
    batch norms folded. The default graph is then reset and the optimized
    graph is imported under the ``optimized`` name scope on ``/device:IPU:0``.
    """
    # Network input dimensions and architecture.
    img_height = 224
    img_width = 224
    img_channels = 3
    densenet_121_blocks = (6, 12, 24, 16)

    cls.batch_size = 1
    cls.num_classes = 1000

    # Half-precision placeholder the model is built on.
    input_shape = (cls.batch_size, img_height, img_width, img_channels)
    cls.placeholder_input = tf.placeholder(dtype=tf.float16,
                                           shape=input_shape,
                                           name="image_input")

    # Compile and device options: a single IPU, no profiling.
    ipu_options = utils.create_ipu_config(profiling=False,
                                          use_poplar_text_report=False)
    utils.auto_select_ipus(ipu_options, [1])
    utils.configure_ipu_system(ipu_options)

    # Build the DenseNet model on the placeholder.
    cls.densenet_model = DenseNet(blocks=densenet_121_blocks,
                                  num_classes=cls.num_classes,
                                  image_width=img_width,
                                  image_height=img_height,
                                  image_channels=img_channels)
    cls.densenet_model(cls.placeholder_input)

    # Locate the weights, downloading them if the index file is missing.
    checkpoint_file = CHECKPOINT_PATH
    index_file = Path(checkpoint_file + ".index")
    if not index_file.exists():
        print('Checkpoint file does not exist, attempting to download pre-trained weights')
        checkpoint_file = get_densenet_weights(Path(checkpoint_file))

    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, checkpoint_file)
        logging.info('Restored imagenet weights.')

        # Freeze the variables and optimize the graph for inference.
        logging.info('Starting graph optimization.')
        graph_def = tf.get_default_graph().as_graph_def()
        frozen = tf.compat.v1.graph_util.convert_variables_to_constants(
            sess, graph_def, output_node_names=["output-prob"])
        # Identity ops in the initializers must be removed first so that the
        # batch norms can be fused with the convolutions afterwards.
        frozen = tf.compat.v1.graph_util.remove_training_nodes(frozen)
        optimized_graph_def = optimize_for_infer.fold_batch_norms(frozen)
        logging.info('Completed graph optimization.')

    # Start from a clean default graph and import the optimized graph into it.
    tf.reset_default_graph()
    with tf.device('/device:IPU:0'), tf.variable_scope('', use_resource=True):
        imported = tf.import_graph_def(optimized_graph_def,
                                       input_map={},
                                       name="optimized",
                                       return_elements=["output-prob:0"])
        cls.output = imported[0]