def test_broadcast(self, data, target_shape, axes_mapping=None, mode='numpy', ref_out=None, test_raising=False):
        """Run ``Broadcast.infer`` on a small graph and check its output.

        The graph has a 'data' node and a constant 'target_shape' node feeding
        a 'broadcast' op, plus an optional constant 'axes_mapping' input.

        :param data: converted with ``int64_array`` and used to create the
            'data' node: via ``valued_const_with_data`` when ``ref_out`` is
            given, via ``shaped_data`` otherwise (presumably tensor values in
            the first case and a tensor shape in the second; confirm against
            those helpers).
        :param target_shape: content of the constant 'target_shape' input; also
            the expected output shape when ``ref_out`` is None.
        :param axes_mapping: optional content of a constant 'axes_mapping'
            node that is connected to the op only when not None.
        :param mode: value of the op's 'mode' attribute.
        :param ref_out: expected output value; when None only the output shape
            is compared.
        :param test_raising: when True, the test only checks that inference
            raises ``AssertionError``.
        """
        check_value = ref_out is not None

        # The 'data' node is created by a different helper depending on
        # whether the output value or only the output shape is verified.
        make_data_node = valued_const_with_data if check_value else shaped_data
        data_nodes = make_data_node('data', int64_array(data))
        shape_nodes = valued_const_with_data('target_shape', int64_array(target_shape))
        op_nodes = regular_op_with_empty_data('broadcast', {'op': 'Broadcast', 'mode': mode})

        nodes = {**data_nodes, **shape_nodes, **op_nodes}
        edges = [
            ('data', 'broadcast'),
            ('target_shape', 'broadcast'),
            ('broadcast', 'broadcast_d'),
        ]

        if axes_mapping is not None:
            # Optional third input, added after the mandatory edges.
            nodes.update(**valued_const_with_data('axes_mapping', int64_array(axes_mapping)))
            edges.append(('axes_mapping', 'broadcast'))

        broadcast_node = Node(build_graph(nodes, edges), 'broadcast')

        if test_raising:
            with self.assertRaises(AssertionError):
                Broadcast.infer(broadcast_node)
            return

        Broadcast.infer(broadcast_node)

        result = broadcast_node.out_node()
        if check_value:
            actual, expected = result.value, np.array(ref_out)
        else:
            actual, expected = result.shape, np.array(target_shape)
        self.assertTrue(np.array_equal(actual, expected))
Exemple #2
0
    def test_broadcast_dynamic(self, data, target_shape_shape, mode='numpy', ref_out_shape=None, test_raising=False):
        """Run ``Broadcast.infer`` when 'target_shape' is not a constant.

        Both the 'data' and 'target_shape' nodes are created with
        ``shaped_data``, so the target shape is described only by
        ``target_shape_shape`` (presumably the shape of the target-shape
        tensor; confirm against ``shaped_data``).

        :param data: converted with ``int64_array`` and passed to
            ``shaped_data`` for the 'data' node.
        :param target_shape_shape: converted with ``int64_array`` and passed to
            ``shaped_data`` for the 'target_shape' node.
        :param mode: value of the op's 'mode' attribute.
        :param ref_out_shape: expected shape of the op's output node.
        :param test_raising: when True, the test only checks that inference
            raises ``AssertionError``.
        """
        data_nodes = shaped_data('data', int64_array(data))
        shape_nodes = shaped_data('target_shape', int64_array(target_shape_shape))
        op_nodes = regular_op_with_empty_data('broadcast', {'op': 'Broadcast', 'mode': mode})

        graph = build_graph(
            {**data_nodes, **shape_nodes, **op_nodes},
            [
                ('data', 'broadcast'),
                ('target_shape', 'broadcast'),
                ('broadcast', 'broadcast_d'),
            ],
        )
        broadcast_node = Node(graph, 'broadcast')

        if test_raising:
            with self.assertRaises(AssertionError):
                Broadcast.infer(broadcast_node)
            return

        Broadcast.infer(broadcast_node)

        inferred_shape = broadcast_node.out_node().shape
        self.assertTrue(np.array_equal(inferred_shape, ref_out_shape))