Example #1
0
    def array_infer(node: Node):
        """Infer outputs for an array write-style node (kind inferred from names — confirm).

        Expects exactly four inputs (enforced with ``assert``, which is skipped
        under ``python -O``):
          0 -- handle; ``str(handle.value)`` is the id of the node in
               ``node.graph`` whose ``element_shape`` is checked and updated,
          1 -- index (presumably the write position; not used for inference),
          2 -- value; its shape becomes the recorded ``element_shape``,
          3 -- incoming flow tensor; its shape and value are forwarded.

        If the referenced node already has a non-empty ``element_shape``, it
        must be compatible with ``value.shape`` according to ``match_shapes``,
        otherwise an ``AssertionError`` is raised.

        Every output node of ``node`` receives its own ``np.array`` copy of the
        flow input's shape, and of its value (``None`` when the flow input has
        no value).
        """
        assert len(node.in_nodes()) == 4

        handle = node.in_node(0)
        # Input 1 (index) does not affect shape inference, so it is not read.
        value = node.in_node(2)
        flow_in = node.in_node(3)

        ta_node = Node(node.graph, str(handle.value))
        if ta_node.has_valid('element_shape') and len(ta_node.element_shape) > 0:
            assert match_shapes(ta_node['element_shape'], value.shape), \
                'Shapes are not compatible: {} and {}'.format(ta_node['element_shape'], value.shape)
        ta_node['element_shape'] = value.shape

        output_shape = flow_in.shape
        output_value = flow_in.value

        # flow_out: forward the flow input's shape/value to every consumer.
        for _, out_node in node.graph.out_edges(node.id):
            # NOTE(review): `Graph.node` was removed in networkx 2.4 (replaced by
            # `Graph.nodes`); confirm against the pinned networkx version.
            out_attrs = node.graph.node[out_node]
            out_attrs['shape'] = np.array(output_shape)
            out_attrs['value'] = None if output_value is None else np.array(output_value)
    def array_infer(node: Node):
        """Infer outputs for an array scatter-style node (kind inferred from names — confirm).

        Expects exactly four inputs (enforced with ``assert``, which is skipped
        under ``python -O``):
          0 -- handle; ``str(handle.value)`` is the id of the node in
               ``node.graph`` whose ``element_shape`` is checked and updated,
          1 -- indices (not used for inference),
          2 -- value; its shape without the leading dimension is the element
               shape,
          3 -- incoming flow tensor; its shape and value are forwarded.

        If the referenced node already has a non-empty ``element_shape``, it
        must be compatible with ``value.shape[1:]`` according to
        ``match_shapes``, otherwise an ``AssertionError`` is raised. The element
        shape is then overwritten with ``value.shape[1:]`` in all cases.

        Every output node of ``node`` receives its own ``np.array`` copy of the
        flow input's shape, and of its value (``None`` when the flow input has
        no value).
        """
        assert len(node.in_nodes()) == 4

        handle = node.in_node(0)
        # Input 1 (indices) does not affect shape inference, so it is not read.
        value = node.in_node(2)
        flow_in = node.in_node(3)

        # Drop the leading dimension of the written value; presumably it
        # enumerates the scattered elements — TODO confirm.
        element_shape = value.shape[1:]

        ta_node = Node(node.graph, str(handle.value))
        if ta_node.has_valid('element_shape') and len(ta_node.element_shape) > 0:
            assert match_shapes(ta_node['element_shape'], element_shape), \
                'Shapes are not compatible: {} and {}'.format(ta_node['element_shape'], element_shape)

        # Assign element_shape anyway, because the original element_shape can contain -1
        ta_node['element_shape'] = element_shape

        output_shape = flow_in.shape
        output_value = flow_in.value

        # flow_out: forward the flow input's shape/value to every consumer.
        for _, out_node in node.graph.out_edges(node.id):
            # NOTE(review): `Graph.node` was removed in networkx 2.4 (replaced by
            # `Graph.nodes`); confirm against the pinned networkx version.
            out_attrs = node.graph.node[out_node]
            out_attrs['shape'] = np.array(output_shape)
            out_attrs['value'] = None if output_value is None else np.array(output_value)
Example #3
0
 def run_match_shapes(self, pattern: list, shape: list):
     """Return the result of ``match_shapes`` for ``pattern`` and ``shape``.

     Both lists are converted with ``shape_array`` first (``pattern`` before
     ``shape``), and the converted values are passed in that order.
     """
     pattern_arr = shape_array(pattern)
     candidate_arr = shape_array(shape)
     return match_shapes(pattern_arr, candidate_arr)
Example #4
0
 def run_match_shapes(self, pattern: list, shape: list):
     """Return the result of ``match_shapes`` for ``pattern`` and ``shape``.

     Both lists are first converted to ``np.int64`` numpy arrays (``pattern``
     before ``shape``), and the arrays are passed in that order.
     """
     pattern_arr, shape_arr = (
         np.array(dims, dtype=np.int64) for dims in (pattern, shape)
     )
     return match_shapes(pattern_arr, shape_arr)