Пример #1
0
    def _create_if_node(self, cond_context):
        output_shapes, output_dtypes = self._get_output_shape_dtype(
            cond_context)
        pred_node = self.g.get_node_by_output(cond_context.pred_input)
        while pred_node.type == "Identity":
            pred_node = pred_node.inputs[0]
        if pred_node.is_const():
            # Constant folding for if node
            if pred_node.get_tensor_value():
                branch_outputs = cond_context.true_branch_context.output
            else:
                branch_outputs = cond_context.false_branch_context.output
            for merge, out in zip(cond_context.merges, branch_outputs):
                self.g.replace_all_inputs(merge.output[0], out)
            return None

        true_graph = utils.construct_graph_from_nodes(
            self.g, list(cond_context.true_branch_context.nodes),
            cond_context.true_branch_context.output, output_shapes,
            output_dtypes)
        false_graph = utils.construct_graph_from_nodes(
            self.g, list(cond_context.false_branch_context.nodes),
            cond_context.false_branch_context.output, output_shapes,
            output_dtypes)
        branches = {"then_branch": true_graph, "else_branch": false_graph}
        if_node = self.g.make_node(
            "If", [cond_context.pred_input],
            op_name_scope=cond_context.cond_scope,
            outputs=[m.output[0] for m in cond_context.merges],
            shapes=output_shapes,
            dtypes=output_dtypes,
            skip_conversion=False,
            branches=branches)
        return if_node
Пример #2
0
 def _create_if_node(self, cond_context):
     """Build an ONNX ``If`` node that replaces a matched cond pattern.

     Both branches become subgraphs of ``self.g`` whose outputs carry the
     shapes and dtypes returned by ``self._get_output_shape_dtype``. The new
     ``If`` node reuses the merge outputs as its own outputs.

     Returns:
         The ``If`` node produced by ``self.g.make_node``.
     """
     shapes, dtypes = self._get_output_shape_dtype(cond_context)

     def build_branch(branch_ctx):
         # One subgraph per side of the conditional, sharing shapes/dtypes.
         return utils.construct_graph_from_nodes(
             self.g, list(branch_ctx.nodes), branch_ctx.output, shapes, dtypes)

     then_graph = build_branch(cond_context.true_branch_context)
     else_graph = build_branch(cond_context.false_branch_context)
     merge_outputs = [merge.output[0] for merge in cond_context.merges]
     return self.g.make_node(
         "If",
         [cond_context.pred_input],
         op_name_scope=cond_context.cond_scope,
         outputs=merge_outputs,
         shapes=shapes,
         dtypes=dtypes,
         skip_conversion=False,
         branches={"then_branch": then_graph, "else_branch": else_graph},
     )
Пример #3
0
 def _create_if_node(self, cond_context):
     """Create an ONNX ``If`` node, then attach both branch subgraphs to it.

     The ``If`` node is created first, with outputs taken from the merge
     nodes. Each branch is then built with the shapes and dtypes that
     ``self.g`` records for that branch's own outputs. The two subgraphs are
     stored as the ``then_branch`` / ``else_branch`` graph attributes.

     Returns:
         The created ``If`` node.
     """
     if_node = self.g.make_node(
         "If",
         [cond_context.pred_input],
         op_name_scope=cond_context.cond_scope,
         outputs=[merge.output[0] for merge in cond_context.merges],
         skip_conversion=False,
     )
     log.debug("set graph for if branchs")

     def build_branch(branch_ctx):
         # Shapes/dtypes come from each branch's outputs, not from a shared list.
         outs = branch_ctx.output
         return utils.construct_graph_from_nodes(
             self.g,
             list(branch_ctx.nodes),
             outs,
             [self.g.get_shape(name) for name in outs],
             [self.g.get_dtype(name) for name in outs],
         )

     # Build both subgraphs before attaching either, as in the original order.
     then_graph = build_branch(cond_context.true_branch_context)
     else_graph = build_branch(cond_context.false_branch_context)
     if_node.set_body_graph_as_attr("then_branch", then_graph)
     if_node.set_body_graph_as_attr("else_branch", else_graph)
     return if_node
 def construct_graph_from_nodes(parent_g, nodes, outputs):
     """Build a subgraph of ``parent_g`` from ``nodes``.

     Thin adapter over ``utils.construct_graph_from_nodes``. Each element of
     ``outputs`` exposes ``id``, ``shape`` and ``dtype``; these are unpacked
     into the three parallel lists that the helper expects.
     """
     output_ids = [out.id for out in outputs]
     output_shapes = [out.shape for out in outputs]
     output_dtypes = [out.dtype for out in outputs]
     return utils.construct_graph_from_nodes(
         parent_g, nodes, output_ids, output_shapes, output_dtypes)