Exemplo n.º 1
0
    def ng_run(self,
               tf_target_node,
               tf_feed_dict=None,
               print_ng_result=False,
               verbose=False):
        """
        Run and get ngraph results

        Imports the protobuf text file at ``self.pb_txt_path`` with
        ``TFImporter``, builds a computation for the requested node, runs the
        importer's init ops (if any) and evaluates the node.

        Args:
            tf_target_node: target node in tf; its ``name`` is used to look
                up the corresponding imported ngraph op
            tf_feed_dict: feed_dict in tf (placeholder node -> value), or
                None when the graph needs no inputs
            print_ng_result: prints ng_result if set to True
            verbose: prints tf's node_def if set to True (passed through to
                the importer)

        Returns:
            ng_result: the first element of the computation's result
        """
        # init importer, transformer
        importer = TFImporter()
        importer.import_protobuf(self.pb_txt_path, verbose=verbose)
        transformer = ngt.make_transformer()

        # set target node
        # A tf tensor name is "<op_name>:<output_index>"; drop everything
        # after the last colon instead of slicing a fixed two characters so
        # that indices >= 10 (or names without a suffix) are handled too.
        ng_target_node = importer.get_op_handle_by_name(
            tf_target_node.name.rsplit(":", 1)[0])

        # convert tf's feed dict to parallel lists of ng placeholder nodes
        # and values; a missing feed dict simply yields two empty lists
        feed_items = []
        if tf_feed_dict is not None:
            feed_items = list(tf_feed_dict.items())
        ng_placeholder_nodes = [
            importer.get_op_handle_by_name(node.name.rsplit(":", 1)[0])
            for (node, _) in feed_items
        ]
        ng_placeholder_vals = [val for (_, val) in feed_items]

        # evaluate ngraph result
        ng_result_comp = transformer.computation([ng_target_node],
                                                 *ng_placeholder_nodes)
        # variable initializers collected by the importer must run before
        # the target computation is evaluated
        if importer.init_ops:
            init_comp = transformer.computation(importer.init_ops)
            init_comp()

        # the computation was built from a one-element list, so unwrap it
        ng_result = ng_result_comp(*ng_placeholder_vals)[0]
        if print_ng_result:
            print(ng_result)

        return ng_result
    def ng_run(self,
               tf_target_node,
               tf_init_op=None,
               tf_feed_dict=None,
               print_ng_result=False,
               verbose=False):
        """
        Run and get ngraph results

        Parses ``self.graph_string`` into a ``tf.GraphDef``, imports it with
        ``TFImporter`` and evaluates the requested node inside an
        ``ExecutorFactory`` context.

        Args:
            tf_target_node: target node in tf; its ``name`` is used to look
                up the corresponding imported ngraph op
            tf_init_op: optional tf init op; when given, its imported
                counterpart is executed before the target is evaluated
            tf_feed_dict: feed_dict in tf (placeholder node -> value);
                None is treated as an empty feed dict
            print_ng_result: prints ng_result if set to True
            verbose: prints tf's node_def if set to True (passed through to
                the importer)

        Returns:
            ng_result
        """
        # avoid a shared mutable default; None means "no placeholders"
        if tf_feed_dict is None:
            tf_feed_dict = {}

        # init importer
        importer = TFImporter()
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(self.graph_string)
        importer.import_graph_def(graph_def, verbose=verbose)

        # set target node
        # A tf tensor name is "<op_name>:<output_index>"; drop everything
        # after the last colon instead of slicing a fixed two characters so
        # that indices >= 10 (or names without a suffix) are handled too.
        ng_target_node = importer.get_op_handle_by_name(
            tf_target_node.name.rsplit(":", 1)[0])

        # get targeting nodes for ng, convert tf's feed dict to an ng one
        ng_feed_dict = {
            importer.get_op_handle_by_name(node.name.rsplit(":", 1)[0]): val
            for (node, val) in tf_feed_dict.items()
        }

        # evaluate ngraph
        with ExecutorFactory() as ex:
            ng_result_comp = ex.transformer.computation(
                ng_target_node, *ng_feed_dict.keys())

            # run the initializer (if any) before evaluating the target
            if tf_init_op is not None:
                ex.transformer.computation(
                    importer.get_op_handle(tf_init_op))()

            ng_result = ng_result_comp(feed_dict=ng_feed_dict)

        # honour the documented flag, which was previously ignored
        if print_ng_result:
            print(ng_result)

        return ng_result