Example #1
    def optimize_for_inference(frozen_graph_def: tf.GraphDef,
                               output_node_names: List[str],
                               graph_input: str) -> graph_pb2.GraphDef:
        """Optimize graph for inference.

        Args:
            frozen_graph_def: Frozen graph definition
            output_node_names: Names of the output nodes.
            graph_input: Name of the image input to the graph.

        Returns: Optimized inference graph definition.
        """
        logging.info('Starting graph optimization.')
        # Remove Identity ops from the initializers so the batch norms can be folded into the convolutions below.
        optimized_graph_def = tf.graph_util.remove_training_nodes(
            frozen_graph_def)
        optimized_graph_def = fold_batch_norms(optimized_graph_def)
        transforms = [
            'remove_nodes(op=Identity, op=CheckNumerics)',
            'strip_unused_nodes', 'fold_constants(ignore_errors=true)'
        ]

        optimized_graph_def = TransformGraph(optimized_graph_def,
                                             [f"{graph_input}:0"],
                                             output_node_names, transforms)

        logging.info('Completed graph optimization.')
        return optimized_graph_def
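
A minimal usage sketch for the function above, assuming a frozen GraphDef has already been serialized to disk; the file name `frozen.pb` and the node names `image_input` and `output-prob` are illustrative assumptions, not taken from the example itself.

    # Hypothetical caller: load a frozen graph, optimize it, and write it back out.
    # "frozen.pb", "image_input" and "output-prob" are assumed names for illustration.
    import tensorflow as tf

    frozen_graph_def = tf.GraphDef()
    with tf.io.gfile.GFile("frozen.pb", "rb") as f:
        frozen_graph_def.ParseFromString(f.read())

    optimized_graph_def = optimize_for_inference(frozen_graph_def,
                                                 output_node_names=["output-prob"],
                                                 graph_input="image_input")

    with tf.io.gfile.GFile("optimized.pb", "wb") as f:
        f.write(optimized_graph_def.SerializeToString())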
Example #2
def construct_graph(batch_size: int = 1) -> tf.Tensor:
    """Construct densenet inference graph on the IPU.

    Args:
        batch_size: Batch size for inference

    Returns: Output probability tensor.
    """
    # Set up the graph
    densenet_model = DenseNet(blocks=DENSENET_121_BLOCKS,
                              num_classes=NUM_CLASSES,
                              image_width=IMG_WIDTH,
                              image_height=IMG_HEIGHT,
                              image_channels=IMG_CHANNELS)
    image_input = tf.placeholder(dtype=tf.float16,
                                 shape=(batch_size, IMG_HEIGHT, IMG_WIDTH,
                                        IMG_CHANNELS),
                                 name="image_input")
    densenet_model(image_input)

    # Restore weights
    checkpoint_dir = CHECKPOINT_DIR
    if tf.train.latest_checkpoint(checkpoint_dir) is None:
        logging.info(
            'Checkpoint directory `%s` does not contain a checkpoint, '
            'attempting to download pre-trained weights.',
            Path(checkpoint_dir))
        get_densenet_weights(Path(checkpoint_dir))

    if tf.train.latest_checkpoint(checkpoint_dir) is None:
        raise ValueError(
            "Weight download failed. Please re-try downloading the weights using the `densenet_weights.py`"
            " script under models/tensorflow/")

    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, tf.train.latest_checkpoint(checkpoint_dir))
        logging.info('Successfully restored imagenet weights.')

        # Optimize inference graph
        logging.info('Starting graph optimization.')
        densenet_graph_def = tf.get_default_graph().as_graph_def()
        frozen_graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(
            sess, densenet_graph_def, output_node_names=["output-prob"])
        # Remove Identity ops from the initializers so the batch norms can be folded into the convolutions below.
        frozen_graph_def = tf.compat.v1.graph_util.remove_training_nodes(
            frozen_graph_def)
        optimized_graph_def = optimize_for_infer.fold_batch_norms(
            frozen_graph_def)
        logging.info('Completed graph optimization.')

    tf.reset_default_graph()
    with tf.device('/device:IPU:0'):
        with tf.variable_scope('', use_resource=True):
            return tf.import_graph_def(optimized_graph_def,
                                       input_map={},
                                       name="optimized",
                                       return_elements=["output-prob:0"])[0]
Example #3
    @classmethod
    def setUpClass(cls):
        # Set up input to the network
        img_width = img_height = 224
        img_channels = 3
        densenet_121_blocks = (6, 12, 24, 16)
        cls.batch_size = 1
        cls.num_classes = 1000
        # Set up image input placeholder
        cls.placeholder_input = tf.placeholder(dtype=tf.float16,
                                               shape=(cls.batch_size, img_height, img_width, img_channels),
                                               name="image_input")

        # Set compile and device options
        opts = utils.create_ipu_config(profiling=False, use_poplar_text_report=False)
        utils.auto_select_ipus(opts, [1])
        utils.configure_ipu_system(opts)

        # Construct Densenet model
        cls.densenet_model = DenseNet(blocks=densenet_121_blocks, num_classes=cls.num_classes,
                                      image_width=img_width, image_height=img_height, image_channels=img_channels)

        cls.densenet_model(cls.placeholder_input)

        # Restore weights
        checkpoint_file = CHECKPOINT_PATH

        if not Path(checkpoint_file + ".index").exists():
            print('Checkpoint file does not exist, attempting to download pre-trained weights')
            checkpoint_file = get_densenet_weights(Path(checkpoint_file))

        # Create test session
        saver = tf.train.Saver()

        with tf.Session() as sess:
            saver.restore(sess, checkpoint_file)
            logging.info('Restored imagenet weights.')

            # Optimize inference graph
            logging.info('Starting graph optimization.')
            densenet_graph_def = tf.get_default_graph().as_graph_def()
            frozen_graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(sess, densenet_graph_def,
                                                                                      output_node_names=["output-prob"])
            # Remove Identity ops from the initializers so the batch norms can be folded into the convolutions below.
            frozen_graph_def = tf.compat.v1.graph_util.remove_training_nodes(frozen_graph_def)
            optimized_graph_def = optimize_for_infer.fold_batch_norms(frozen_graph_def)

            logging.info('Completed graph optimization.')

        tf.reset_default_graph()
        with tf.device('/device:IPU:0'):
            with tf.variable_scope('', use_resource=True):
                cls.output = tf.import_graph_def(optimized_graph_def, input_map={}, name="optimized",
                                                 return_elements=["output-prob:0"])[0]
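
A minimal sketch of a test method that could sit alongside the `setUpClass` above; the method name, the feed name `optimized/image_input:0` (derived from the `name="optimized"` prefix), and the random float16 input are assumptions, with `numpy` assumed to be imported as `np`.

    def test_output_shape_and_normalization(self):
        # Hypothetical check: the imported graph should produce one softmax
        # distribution over num_classes per image in the batch.
        image = np.random.rand(self.batch_size, 224, 224, 3).astype(np.float16)
        with tf.Session() as sess:
            probs = sess.run(self.output,
                             feed_dict={"optimized/image_input:0": image})
        self.assertEqual(probs.shape, (self.batch_size, self.num_classes))
        # Each row of the softmax output should sum to approximately 1.
        np.testing.assert_allclose(probs.sum(axis=-1), 1.0, rtol=1e-2)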