def test_partial_infer(self):
        """Check the output shapes produced by ``SparseFillEmptyRows.infer``.

        Builds a graph from the ``nodes_attributes`` / ``edges1`` / ``inputs1``
        fixtures, runs shape inference on ``sparse_fill_empty_rows_node`` and
        compares the ``shape`` attribute of each of its three output data
        nodes against a reference shape.
        """
        graph = build_graph(nodes_attributes, edges1, inputs1)
        SparseFillEmptyRows.infer(Node(graph, 'sparse_fill_empty_rows_node'))

        # Reference shape for every output data node, checked in this order.
        expected_shapes = {
            'output_indices': int64_array([20, 2]),
            'output_values': int64_array([20]),
            'empty_row_indicator': int64_array([4]),
        }
        # Read every inferred shape up front, before any assertion runs.
        inferred_shapes = {
            data_name: graph.node[data_name]['shape']
            for data_name in expected_shapes
        }

        for data_name, want in expected_shapes.items():
            got = inferred_shapes[data_name]
            self.assertTrue(
                np.array_equal(want, got),
                'shapes do not match expected: {} and given: {}'.format(
                    want, got))
Ejemplo n.º 2
0
    def test_partial_infer_for_some_out_ports(self):
        """Check shape inference when only some output ports are connected.

        The node gets all four inputs (ports 0-3), but only output ports 0 and
        2 have outgoing edges in this graph. The test verifies the shapes set
        on the two connected output data nodes.
        """
        op_name = 'sparse_fill_empty_rows_node'

        # Input data nodes feed in-ports 0..3 in this order.
        input_names = ('input_indices', 'input_values', 'dense_shape',
                       'default_value')
        edges = [(src, op_name, {'in': port})
                 for port, src in enumerate(input_names)]
        # Only out-ports 0 and 2 are wired; each data node then feeds a result.
        edges += [
            (op_name, 'output_indices', {'out': 0}),
            (op_name, 'empty_row_indicator', {'out': 2}),
            ('output_indices', 'result_indices', {'out': 0}),
            ('empty_row_indicator', 'result_empty_row_indicator', {'out': 0}),
        ]
        graph = build_graph(nodes_attributes, edges, inputs1)

        SparseFillEmptyRows.infer(Node(graph, op_name))

        # Reference shape for each connected output data node.
        expected_shapes = {
            'output_indices': int64_array([20, 2]),
            'empty_row_indicator': int64_array([4]),
        }
        # Read every inferred shape up front, before any assertion runs.
        inferred_shapes = {
            data_name: graph.node[data_name]['shape']
            for data_name in expected_shapes
        }

        for data_name, want in expected_shapes.items():
            got = inferred_shapes[data_name]
            self.assertTrue(
                np.array_equal(want, got),
                'shapes do not match expected: {} and given: {}'.format(
                    want, got))
Ejemplo n.º 3
0
    def extract(cls, node):
        """Attach SparseFillEmptyRows attributes to ``node``.

        Nothing is read from the source node: a fresh, empty attribute dict is
        handed to ``SparseFillEmptyRows.update_node_stat``. Takes ``cls`` as
        its first parameter, so it is presumably decorated as a classmethod
        outside this view (verify).

        Returns:
            The ``enabled`` attribute of ``cls``.
        """
        SparseFillEmptyRows.update_node_stat(node, {})
        return cls.enabled
    def extract(node):
        """Attach SparseFillEmptyRows attributes to ``node``.

        Nothing is read from the source node: a fresh, empty attribute dict is
        handed to ``SparseFillEmptyRows.update_node_stat``. Has no ``self`` or
        ``cls`` parameter, so it is presumably a staticmethod (decorator
        outside this view; verify).

        Returns:
            ``__class__.enabled``. ``__class__`` is the compiler-provided cell
            for the class whose body defines this function, so a subclass's
            ``enabled`` override is not consulted.
        """
        SparseFillEmptyRows.update_node_stat(node, {})
        return __class__.enabled