def test_replace_input(self):
    """wire_transform over Inputs rewires their consumers to cloned wires.

    Input 'a' is swapped for a clone named 'w2' and every other selected
    Input for a clone named 'w3'; afterwards the '&' net reads the clones
    and the original Inputs have no remaining consumers.
    """
    def swap_input(wire):
        replacement_name = 'w2' if wire.name == 'a' else 'w3'
        replacement = pyrtl.clone_wire(wire, replacement_name)
        return wire, replacement

    a, b = pyrtl.input_list('a/1 b/1')
    w1 = a & b
    o = pyrtl.Output(1, 'o')
    o <<= w1

    src_nets, dst_nets = pyrtl.working_block().net_connections()
    self.assertEqual(src_nets[w1], pyrtl.LogicNet('&', None, (a, b), (w1,)))
    for original in (a, b):
        self.assertIn(original, dst_nets)

    transform.wire_transform(swap_input, select_types=pyrtl.Input, exclude_types=())

    block = pyrtl.working_block()
    w2 = block.get_wirevector_by_name('w2')
    w3 = block.get_wirevector_by_name('w3')
    src_nets, dst_nets = block.net_connections()
    self.assertEqual(src_nets[w1], pyrtl.LogicNet('&', None, (w2, w3), (w1,)))
    for original in (a, b):
        self.assertNotIn(original, dst_nets)
def test_equivelence_of_same_nets(self):
    """Two LogicNets built separately from identical fields are equal.

    They must be distinct objects, yet compare equal.
    """
    a, b, c = (pyrtl.WireVector(1) for _ in range(3))

    def build():
        return pyrtl.LogicNet('+', 'xx', (a, b), (c,))

    first, second = build(), build()
    self.assertIsNot(first, second)
    self.assertEqual(first, second)
def test_equivelence_of_different_nets(self):
    """LogicNets that differ in any one field are reported as different.

    The fields covered are op, op_param, argument order or count, and
    destination wires or count. Each pair is checked with
    assertDifferentNets.
    """
    a, b, c = (pyrtl.WireVector() for _ in range(3))
    minus_john = pyrtl.LogicNet('-', 'John', (a, b), (c,))
    plus_john = pyrtl.LogicNet('+', 'John', (a, b), (c,))
    plus_xx = pyrtl.LogicNet('+', 'xx', (a, b), (c,))
    swapped_args = pyrtl.LogicNet('+', 'xx', (b, a), (c,))
    three_args = pyrtl.LogicNet('+', 'xx', (b, a, c), (c,))
    two_dests = pyrtl.LogicNet('+', 'xx', (b, a, c), (c, a))
    other_dest = pyrtl.LogicNet('+', 'xx', (b, a, c), (a,))

    distinct_pairs = [
        (minus_john, plus_john),     # op differs
        (plus_john, plus_xx),        # op_param differs
        (plus_xx, swapped_args),     # argument order differs
        (swapped_args, three_args),  # argument count differs
        (three_args, two_dests),     # destination count differs
        (three_args, other_dest),    # destination wire differs
        (two_dests, other_dest),     # destination count and wires differ
    ]
    for left, right in distinct_pairs:
        self.assertDifferentNets(left, right)

    # Edge case: a repeated argument is not the same as a single argument.
    repeated_arg = pyrtl.LogicNet('+', 'John', (a, a), (c,))
    single_arg = pyrtl.LogicNet('+', 'John', (a,), (c,))
    self.assertDifferentNets(repeated_arg, single_arg)
def test_no_logic_net_comparisons(self):
    """Ordering comparisons (<, <=, >, >=) between two LogicNets raise PyrtlError.

    Each operator runs in its own subTest, so a failure names the operator
    and the remaining operators are still checked.
    """
    import operator

    a = pyrtl.WireVector(bitwidth=3)
    b = pyrtl.WireVector(bitwidth=3)
    select = pyrtl.WireVector(bitwidth=3)
    outwire = pyrtl.WireVector(bitwidth=3)
    net1 = pyrtl.LogicNet(op='x', op_param=None, args=(select, a, b), dests=(outwire,))
    net2 = pyrtl.LogicNet(op='x', op_param=None, args=(select, b, a), dests=(outwire,))

    # Each comparison is evaluated only for the exception it raises; there is
    # no result to bind (the original's `foo = ...` was an unused local).
    for compare in (operator.lt, operator.le, operator.gt, operator.ge):
        with self.subTest(op=compare.__name__):
            with self.assertRaises(pyrtl.PyrtlError):
                compare(net1, net2)
def new_net(op='&', op_param=None, args=None, dests=None):
    """Build a LogicNet for tests, creating default wires where none are given.

    :param op: operator of the net (default '&').
    :param op_param: operator parameter (default None).
    :param args: a tuple of argument wires, an int (build that many 2-bit
        Inputs), or None (build two 2-bit Inputs).
    :param dests: a tuple of destination wires, or None (build a single
        2-bit Output named 'out').
    :return: the constructed pyrtl.LogicNet.
    """
    if args is None or isinstance(args, int):
        arg_count = 2 if args is None else args
        args = tuple(pyrtl.Input(2) for _ in range(arg_count))
    if dests is None:
        dests = (pyrtl.Output(2, 'out'),)
    return pyrtl.LogicNet(op=op, op_param=op_param, args=args, dests=dests)
def test_comparison(self):
    """A LogicNet cannot be ordered, even against itself.

    <, <=, >= and > all raise PyrtlError. Each operator runs in its own
    subTest so a failure reports which one misbehaved.
    """
    import operator

    net = pyrtl.LogicNet('+', 'xx', ("arg1", "arg2"), ("dest",))
    # Only the raised exception matters; the original bound the unused
    # result to a throwaway local `a`.
    for compare in (operator.lt, operator.le, operator.ge, operator.gt):
        with self.subTest(op=compare.__name__):
            with self.assertRaises(pyrtl.PyrtlError):
                compare(net, net)
def new_net(op='&', op_param=None, args=None, dests=None):
    """Build a LogicNet for tests, creating default wires where none are given.

    :param op: operator of the net (default '&'). When it is 'r',
        auto-created destinations are Registers instead of Outputs.
    :param op_param: operator parameter (default None).
    :param args: a tuple of argument wires, an int (build that many 2-bit
        Inputs), or None (build two 2-bit Inputs).
    :param dests: a tuple of destination wires, an int (build that many
        2-bit destination wires), or None (build one).
    :return: the constructed pyrtl.LogicNet.
    """
    if args is None or isinstance(args, int):
        arg_count = 2 if args is None else args
        args = tuple(pyrtl.Input(2) for _ in range(arg_count))
    if dests is None or isinstance(dests, int):
        dest_count = 1 if dests is None else dests
        dest_type = pyrtl.Register if op == 'r' else pyrtl.Output
        dests = tuple(dest_type(2) for _ in range(dest_count))
    return pyrtl.LogicNet(op=op, op_param=op_param, args=args, dests=dests)
def test_replace_output(self):
    """wire_transform over Outputs retargets the net that drove them.

    Each Output is cloned as 'w2'. Afterwards w1's single consumer is a 'w'
    net into 'w2', and Output 'o' no longer has a driving net.
    """
    def redirect(wire):
        replacement = pyrtl.clone_wire(wire, 'w2')
        return replacement, wire

    a, b = pyrtl.input_list('a/1 b/1')
    w1 = a & b
    o = pyrtl.Output(1, 'o')
    o <<= w1

    def wire_net_into(dest):
        return pyrtl.LogicNet('w', None, (w1,), (dest,))

    src_nets, dst_nets = pyrtl.working_block().net_connections()
    self.assertEqual(dst_nets[w1], [wire_net_into(o)])
    self.assertIn(o, src_nets)

    transform.wire_transform(redirect, select_types=pyrtl.Output, exclude_types=())

    block = pyrtl.working_block()
    w2 = block.get_wirevector_by_name('w2')
    src_nets, dst_nets = block.net_connections()
    self.assertEqual(dst_nets[w1], [wire_net_into(w2)])
    self.assertNotIn(o, src_nets)
def test_as_graph_trivial(self):
    """net_connections on a block holding a single '~' net from i to o.

    The test checks the maps returned for both values of the boolean
    argument.
    """
    i = pyrtl.Input(1)
    o = pyrtl.Output(1)
    block = pyrtl.working_block()
    net = pyrtl.LogicNet('~', None, (i,), (o,))
    block.add_net(net)

    def check_net_edges(src_g, dst_g):
        # o is produced by the net, and i feeds exactly that one net.
        self.assertEqual(src_g[o], net)
        self.assertEqual(dst_g[i][0], net)
        self.assertEqual(len(dst_g[i]), 1)

    # With False: i has no source entry and o has no consumer entry.
    src_g, dst_g = block.net_connections(False)
    self.check_graph_correctness(src_g, dst_g)
    check_net_edges(src_g, dst_g)
    self.assertNotIn(i, src_g)
    self.assertNotIn(o, dst_g)

    # With True: i maps to itself as a source, and o lists itself as its one consumer.
    src_g, dst_g = block.net_connections(True)
    self.check_graph_correctness(src_g, dst_g, True)
    check_net_edges(src_g, dst_g)
    self.assertIs(src_g[i], i)
    self.assertIs(dst_g[o][0], o)
    self.assertEqual(len(dst_g[o]), 1)
def test_self_equals(self):
    """A LogicNet built from WireVectors compares equal to itself."""
    a, b, c = (pyrtl.WireVector() for _ in range(3))
    net = pyrtl.LogicNet('+', 'xx', (a, b), (c,))
    self.assertEqual(net, net)
def test_string_format(self):
    """str() of a LogicNet renders as 'dests <-- op -- args (op_param)'."""
    net = pyrtl.LogicNet(op='+', op_param='xx', args=("arg1", "arg2"), dests=("dest",))
    expected = "dest <-- + -- arg1, arg2 (xx)"
    self.assertEqual(str(net), expected)
def test_self_equals(self):
    """The == operator reports a LogicNet with string fields equal to itself."""
    net = pyrtl.LogicNet(op='+', op_param='xx', args=("arg1", "arg2"), dests=("dest",))
    same = net == net
    self.assertTrue(same)