Пример #1
0
    def test_infer_constant_input_sum(self):
        """Check shape and value inference of ``SparseSegmentSum`` on the ``*2`` fixture graph.

        Builds the graph from ``nodes_attributes2``/``edges2``/``inputs2``
        (module-level fixtures not shown here), runs ``SparseSegmentSum.infer``
        on ``sparse_segment_sum_node`` and compares both the ``shape`` and the
        ``value`` stored on ``output_segments`` against hand-computed references.
        """
        graph = build_graph(nodes_attributes2, edges2, inputs2)

        sparse_segment_sum_node = Node(graph, 'sparse_segment_sum_node')
        SparseSegmentSum.infer(sparse_segment_sum_node)

        # prepare reference results
        ref_output_segments_shape = int64_array([2, 4])
        # np.float was only an alias of the builtin float; it was deprecated in
        # NumPy 1.20 and removed in 1.24, so use the explicit np.float64 dtype.
        ref_output_segments_value = np.array([[0, 0, 0, 0], [5, 6, 7, 8]],
                                             dtype=np.float64)

        # get resulted shapes and values
        res_output_segments_shape = graph.node['output_segments']['shape']
        res_output_segments_value = graph.node['output_segments']['value']

        self.assertTrue(
            np.array_equal(ref_output_segments_shape,
                           res_output_segments_shape),
            'Shapes do not match expected: {} and given: {}'.format(
                ref_output_segments_shape, res_output_segments_shape))
        self.assertTrue(
            np.array_equal(ref_output_segments_value,
                           res_output_segments_value),
            'Values do not match expected: {} and given: {}'.format(
                ref_output_segments_value, res_output_segments_value))
    def test_partial_infer(self):
        """Check the output shape that ``SparseSegmentSum.infer`` produces on the ``*1`` fixture graph.

        The graph comes from ``nodes_attributes1``/``edges1``/``inputs1``
        (module-level fixtures not shown here). Only the ``shape`` attribute
        of ``output_segments`` is verified.
        """
        expected_shape = int64_array([40, 4, 5])

        fixture_graph = build_graph(nodes_attributes1, edges1, inputs1)
        SparseSegmentSum.infer(Node(fixture_graph, 'sparse_segment_sum_node'))

        inferred_shape = fixture_graph.node['output_segments']['shape']
        failure_message = 'Shapes do not match expected: {} and given: {}'.format(
            expected_shape, inferred_shape)
        self.assertTrue(np.array_equal(expected_shape, inferred_shape), failure_message)
Пример #3
0
    def extract(cls, node):
        """Register ``SparseSegmentSum`` attributes on ``node`` and return ``cls.enabled``.

        No operation-specific attributes are parsed here: an empty attribute
        dict is handed to ``SparseSegmentSum.update_node_stat``.
        """
        SparseSegmentSum.update_node_stat(node, {})
        return cls.enabled
Пример #4
0
    def extract(node):
        """Register ``SparseSegmentSum`` attributes on ``node`` and return the enclosing class's ``enabled``.

        An empty attribute dict is passed to ``SparseSegmentSum.update_node_stat``.
        The method receives neither ``self`` nor ``cls`` (presumably it is a
        staticmethod; the decorator is not visible here), so it reaches the
        enclosing class through the implicit ``__class__`` cell (PEP 3135).
        """
        SparseSegmentSum.update_node_stat(node, {})
        return __class__.enabled