def test_broadcast(self, data, target_shape, axes_mapping=None, mode='numpy', ref_out=None, test_raising=False):
    """Build a small Broadcast graph with a constant target shape and check ``Broadcast.infer``.

    :param data: passed to ``valued_const_with_data`` for the 'data' input when
        ``ref_out`` is given, otherwise passed to ``shaped_data`` for it.
    :param target_shape: value for the constant 'target_shape' input.
    :param axes_mapping: optional value for an extra constant 'axes_mapping'
        input; that input and its edge are only added when this is not None.
    :param mode: stored as the 'mode' attribute of the Broadcast op node.
    :param ref_out: expected output value. When None, the output shape is
        compared against ``target_shape`` instead.
    :param test_raising: when True, only assert that ``Broadcast.infer`` raises
        ``AssertionError`` and skip the output checks.
    """
    # A reference output value is only checkable when the data input carries a
    # value, so the node helper depends on whether ref_out was supplied.
    # (Renamed from `input` to avoid shadowing the builtin.)
    if ref_out is not None:
        data_input = valued_const_with_data('data', int64_array(data))
    else:
        data_input = shaped_data('data', int64_array(data))

    nodes = {
        **data_input,
        **valued_const_with_data('target_shape', int64_array(target_shape)),
        **regular_op_with_empty_data('broadcast', {'op': 'Broadcast', 'mode': mode}),
    }
    edges = [
        ('data', 'broadcast'),
        ('target_shape', 'broadcast'),
        ('broadcast', 'broadcast_d'),
    ]

    # The third input is optional; only wire it in when the case specifies it.
    if axes_mapping is not None:
        nodes.update(**valued_const_with_data('axes_mapping', int64_array(axes_mapping)))
        edges.append(('axes_mapping', 'broadcast'))

    graph = build_graph(nodes, edges)
    broadcast_node = Node(graph, 'broadcast')

    if test_raising:
        with self.assertRaises(AssertionError):
            Broadcast.infer(broadcast_node)
        return

    Broadcast.infer(broadcast_node)
    if ref_out is not None:
        self.assertTrue(np.array_equal(broadcast_node.out_node().value, np.array(ref_out)))
    else:
        self.assertTrue(np.array_equal(broadcast_node.out_node().shape, np.array(target_shape)))
def test_broadcast_dynamic(self, data, target_shape_shape, mode='numpy', ref_out_shape=None, test_raising=False):
    """Check ``Broadcast.infer`` when both inputs are built with ``shaped_data``.

    Unlike the constant-target variant, the 'target_shape' input here is
    created from a *shape* (``target_shape_shape``) rather than a value.

    :param data: shape passed to ``shaped_data`` for the 'data' input.
    :param target_shape_shape: shape passed to ``shaped_data`` for the
        'target_shape' input.
    :param mode: stored as the 'mode' attribute of the Broadcast op node.
    :param ref_out_shape: expected output shape, compared with ``np.array_equal``.
    :param test_raising: when True, only assert that ``Broadcast.infer`` raises
        ``AssertionError``.
    """
    # Assemble the node dict piece by piece, in the same order as the inputs.
    graph_nodes = {}
    graph_nodes.update(shaped_data('data', int64_array(data)))
    graph_nodes.update(shaped_data('target_shape', int64_array(target_shape_shape)))
    graph_nodes.update(regular_op_with_empty_data('broadcast', {'op': 'Broadcast', 'mode': mode}))

    # Both inputs feed the op; the op then feeds its output data node.
    graph_edges = [(source, 'broadcast') for source in ('data', 'target_shape')]
    graph_edges.append(('broadcast', 'broadcast_d'))

    node = Node(build_graph(graph_nodes, graph_edges), 'broadcast')

    if not test_raising:
        Broadcast.infer(node)
        self.assertTrue(np.array_equal(node.out_node().shape, ref_out_shape))
    else:
        with self.assertRaises(AssertionError):
            Broadcast.infer(node)