def ndps_parallel(ndps):
    """
    Compose the given NamedDPs in parallel and wrap the result.

    The dps of ``ndps`` are combined with ``make_parallel_n``. A prefix
    mux maps a flat product of all functions (ndp by ndp, in order) onto
    the parallel dp's function space, and a postfix mux flattens the
    parallel dp's resource space into a flat product of all resources.

    Returns a SimpleWrap whose function names are the concatenation of
    each ndp's function names (likewise for resources), after
    ``simplify_if_only_one_name``. This function itself does not check
    for name clashes between the ndps.
    """
    parallel_dp = make_parallel_n([ndp.get_dp() for ndp in ndps])
    fun_space = parallel_dp.get_fun_space()
    res_space = parallel_dp.get_res_space()

    all_fnames, all_ftypes = [], []
    all_rnames, all_rtypes = [], []
    # Mux coordinates: the prefix maps the flat function product onto the
    # parallel dp's function space; the postfix flattens its resources.
    prefix_coords, postfix_coords = [], []

    for k, sub in enumerate(ndps):
        sub_fnames = sub.get_fnames()
        if sub_fnames:
            positions = []
            for fn in sub_fnames:
                ftype = sub.get_ftype(fn)
                positions.append(len(all_fnames))
                all_fnames.append(fn)
                all_ftypes.append(ftype)
            # A lone function is addressed by a bare index, not a list.
            prefix_coords.append(positions[0] if len(positions) == 1 else positions)
        else:
            prefix_coords.append([])

        sub_rnames = sub.get_rnames()
        lone = len(sub_rnames) == 1
        for m, rn in enumerate(sub_rnames):
            rtype = sub.get_rtype(rn)
            # A lone resource is addressed as k, otherwise as (k, m).
            postfix_coords.append(k if lone else (k, m))
            all_rnames.append(rn)
            all_rtypes.append(rtype)

    prefix = Mux(PosetProduct(all_ftypes), prefix_coords)
    assert fun_space == prefix.get_res_space()
    flat_res = PosetProduct(all_rtypes)
    postfix = Mux(res_space, postfix_coords)
    assert flat_res == postfix.get_res_space()

    composed = make_series(make_series(prefix, parallel_dp), postfix)
    from mocdp.comp.connection import simplify_if_only_one_name
    composed, all_fnames, all_rnames = simplify_if_only_one_name(
        composed, all_fnames, all_rnames)
    return SimpleWrap(composed, all_fnames, all_rnames)
def add_muxes(inner, cs, s_muxed, inner_name='_inner0', mux1_name='_mux1', mux2_name='_mux2'):
    """
    Add muxes before and after inner, bundling the signals ``cs`` into
    the single signal ``s_muxed``::

           ---(extraf)-->|       |---(extrar)-->
                         |       |
              +--c1----->| inner |---c1--+
    s_muxed --|          |       |       |--> s_muxed
              +--c2----->|       |---c2--+

    Each name in ``cs`` is used both as a resource and as a function of
    ``inner``. The remaining functions/resources of ``inner`` (extraf /
    extrar) are exposed unchanged.

    The parts are registered as ``inner_name`` (inner), ``mux1_name``
    (collects inner's resources ``cs`` into ``s_muxed``) and
    ``mux2_name`` (splits ``s_muxed`` into inner's functions ``cs``).

    Returns a CompositeNamedDP with functions ``extraf + [s_muxed]`` and
    resources ``extrar + [s_muxed]``.
    """
    extraf = [f for f in inner.get_fnames() if f not in cs]
    extrar = [r for r in inner.get_rnames() if r not in cs]
    fnames = extraf + [s_muxed]
    rnames = extrar + [s_muxed]

    if len(cs) == 1:
        # A single signal needs no bundling: identities suffice.
        only = cs[0]
        nto1 = SimpleWrap(Identity(inner.get_ftype(only)), fnames=only, rnames=s_muxed)
        _1ton = SimpleWrap(Identity(inner.get_rtype(only)), fnames=s_muxed, rnames=only)
    else:
        # n -> 1: pack all cs into one product, e.g. coords [0, 1, 2].
        F = PosetProduct(inner.get_ftypes(cs).subs)
        mux = Mux(F, list(range(len(cs))))
        nto1 = SimpleWrap(mux, fnames=cs, rnames=s_muxed)
        # 1 -> n: unpack the bundle back into the individual cs.
        mux2 = Mux(mux.get_res_space(), list(range(len(cs))))
        _1ton = SimpleWrap(mux2, fnames=s_muxed, rnames=cs)
        # Round trip must reproduce the original product of types.
        get_types_universe().check_equal(F, mux2.get_res_space())

    name2ndp = {inner_name: inner, mux1_name: nto1, mux2_name: _1ton}
    connections = [Connection(inner_name, n, mux1_name, n) for n in cs]
    connections += [Connection(mux2_name, n, inner_name, n) for n in cs]

    # Expose the remaining names; these helpers are passed name2ndp and
    # connections and presumably extend them in place (their return
    # values are not used) — confirm in their definitions.
    connect_functions_to_outside(name2ndp, connections, ndp_name=inner_name, fnames=extraf)
    connect_resources_to_outside(name2ndp, connections, ndp_name=inner_name, rnames=extrar)
    connect_resources_to_outside(name2ndp, connections, ndp_name=mux1_name, rnames=[s_muxed])
    connect_functions_to_outside(name2ndp, connections, ndp_name=mux2_name, fnames=[s_muxed])

    return CompositeNamedDP.from_parts(name2ndp=name2ndp, connections=connections,
                                       fnames=fnames, rnames=rnames)
def connect2(ndp1, ndp2, connections, split, repeated_ok=False):
    """
    Connect ``ndp1`` in series with ``ndp2`` and wrap the result.

    Note the argument split must be a list of strings so that
    orders are preserved and deterministic.

    Signal groups (``connections`` link resources ``s1`` of ndp1 to
    functions ``s2`` of ndp2)::

                ___
               |   |-----------------------------> A
               |   |--B1--+----------------------> B1 (split)
        f1 --->| 1 |      |          ___
               |   |      +---B2--->|   |
               |___|--C1-------C2-->| 2 |--> r2
        D ------------------------->|___|

      - B1 -> B2: connected resources that are also kept as outputs.
      - C1 -> C2: the remaining connections (internal only).
      - A: resources of ndp1 that are not connected.
      - D: functions of ndp2 that are not connected.

    The result has functions ``f1 + D`` and resources ``A + B1 + r2``.
    Internally it is the series

        ftot -> m1 -> X -> Y -> Z -> m2 -> rtot

    with X = dp1 | Id_D, Z = Id_(A, B1) | dp2, and m1, Y, m2 muxes.

    :param ndp1: upstream NamedDP.
    :param ndp2: downstream NamedDP (must not be ``ndp1`` itself).
    :param connections: non-empty collection of connections with
        attributes ``s1`` (resource of ndp1) and ``s2`` (function of ndp2).
    :param split: list of distinct, connected resource names of ndp1 that
        are also exported as resources.
    :param repeated_ok: if False, ndp1 and ndp2 must share neither
        function names nor resource names.
    :raises ValueError: if ``ndp1 is ndp2`` or ``split`` has repeats.
    :raises DPInternalError: (via raise_desc / raise_wrapped) on name
        clashes when not ``repeated_ok``, and wrapping any error raised
        while building the composition.
    """
    if ndp1 is ndp2:
        raise ValueError('Equal')

    def common(x, y):
        return len(set(x + y)) != len(set(x)) + len(set(y))

    if not repeated_ok:
        if (common(ndp1.get_fnames(), ndp2.get_fnames()) or
                common(ndp1.get_rnames(), ndp2.get_rnames())):
            raise_desc(DPInternalError, 'repeated names', ndp1=ndp1, ndp2=ndp2,
                       connections=connections, split=split)

    if len(set(split)) != len(split):
        msg = 'Repeated signals in split: %s' % str(split)
        raise ValueError(msg)

    try:
        if not connections:
            raise ValueError('Empty connections')

        # Explicit raises (not asserts) so lookups cannot silently
        # return None under ``python -O``.
        def s2_from_s1(s1):
            for c in connections:
                if c.s1 == s1:
                    return c.s2
            raise ValueError('Cannot find connection with s1 = %s' % s1)

        def s1_from_s2(s2):
            for c in connections:
                if c.s2 == s2:
                    return c.s1
            raise ValueError('Cannot find connection with s2 = %s' % s2)

        f1 = ndp1.get_fnames()
        r1 = ndp1.get_rnames()
        f2 = ndp2.get_fnames()
        r2 = ndp2.get_rnames()

        all_s2 = set([c.s2 for c in connections])
        all_s1 = set([c.s1 for c in connections])

        # Every split signal must be one of the connected resources.
        not_connected = [x for x in split if x not in all_s1]
        if not_connected:
            raise ValueError('Split signals are not connected: %s' % not_connected)

        # Real lists (not lazy map objects): they are concatenated and
        # membership-tested repeatedly below.
        B1 = list(split)
        B2 = [s2_from_s1(s) for s in B1]
        C2 = list_diff(all_s2, B2)
        C1 = [s1_from_s2(s) for s in C2]
        A = list_diff(r1, B1 + C1)
        D = list_diff(f2, B2 + C2)

        fntot = f1 + D
        rntot = A + B1 + r2

        if there_are_repetitions(fntot) or there_are_repetitions(rntot):
            raise_desc(NotImplementedError, 'Repeated names', fnames=fntot, rnames=rntot)

        # Total function and resource spaces.
        f1_types = ndp1.get_ftypes(f1)
        D_types = ndp2.get_ftypes(D)
        Ftot = PosetProduct(tuple(list(f1_types) + list(D_types)))
        Rtot = PosetProduct(tuple(list(ndp1.get_rtypes(A)) +
                                  list(ndp1.get_rtypes(B1)) +
                                  list(ndp2.get_rtypes(r2))))
        assert len(fntot) == len(Ftot), (fntot, Ftot)
        assert len(rntot) == len(Rtot), (rntot, Rtot)

        # m1: from ftot to Product(f1, D).
        m1_for_f1 = [fntot.index(s) for s in f1]
        m1_for_D = [fntot.index(s) for s in D]
        m1 = Mux(Ftot, [m1_for_f1, m1_for_D])

        # X = dp1 | Id_D
        Id_D = Identity(D_types)
        ndp1_p = its_dp_as_product(ndp1)
        X = make_parallel(ndp1_p, Id_D)
        m1_X = make_series(m1, X)  # also checks that m1 and X connect

        # Z = Id_(A, B1) | dp2
        A_B1_types = PosetProduct(tuple(ndp1.get_rtypes(A)) + tuple(ndp1.get_rtypes(B1)))
        Id_A_B1 = Identity(A_B1_types)
        ndp2_p = its_dp_as_product(ndp2)
        Z = make_parallel(Id_A_B1, ndp2_p)

        # m2: from Z's output to the flat product A + B1 + r2.
        m2coords_A = [(0, (A + B1).index(x)) for x in A]
        m2coords_B1 = [(0, (A + B1).index(x)) for x in B1]
        m2coords_r2 = [(1, r2.index(x)) for x in r2]
        m2coords = m2coords_A + m2coords_B1 + m2coords_r2
        m2 = Mux(Z.get_res_space(), m2coords)
        assert len(m2.get_res_space()) == len(rntot), ((m2.get_res_space(), rntot))
        make_series(Z, m2)  # make sure we can connect

        # Y: from X's output (r1, D) to Z's input ((A, B1), f2).
        # First component: pass A and B1 through from r1.
        Y_coords_A_B1 = [(0, r1.index(x)) for x in A + B1]
        # Second component: iterate ndp2's functions in order; connected
        # ones come from r1, unconnected ones from D.
        Y_coords_B2_C2_D = []
        for x in f2:
            if (x in B2) or (x in C2):
                Y_coords_B2_C2_D.append((0, r1.index(s1_from_s2(x))))
                assert x not in D
            elif x in D:
                Y_coords_B2_C2_D.append((1, D.index(x)))
            else:
                assert False, x  # D = f2 - (B2 + C2), so unreachable
        Y = Mux(m1_X.get_res_space(), [Y_coords_A_B1, Y_coords_B2_C2_D])

        # m1* Xp Y* Zp m2*
        Y_Z = make_series(Y, Z)
        Y_Z_m2 = make_series(Y_Z, m2)
        res_dp = make_series(m1_X, Y_Z_m2)

        res_dp, fnames, rnames = simplify_if_only_one_name(res_dp, fntot, rntot)
        return dpwrap(res_dp, fnames, rnames)
    except Exception as e:
        msg = 'connect2() failed'
        raise_wrapped(DPInternalError, e, msg, ndp1=ndp1, ndp2=ndp2,
                      connections=connections, split=split)
def cndp_abstract_loop2(ndp):
    """
    Abstracts the dp using the canonical form.

    ``get_canonical_elements(ndp)`` provides an ``inner`` NamedDP, the
    list of ``cycles`` and the externally visible ``extraf``/``extrar``
    names, which must equal ndp's own function/resource names.

    - With no cycle, inner's dp is wrapped directly with ``dpwrap``.
    - With one cycle, inner's dp is put between two muxes so that it
      maps ``(F1, F2)`` to ``(R1, R2)``. Here F1/R1 are the external
      types and ``F2 == R2`` is the type of inner's resource
      ``cycles[0]``. The result is then closed with ``DPLoop2`` and
      wrapped in a SimpleWrap.

    Raises NotImplementedError (via raise_desc) if more than one cycle
    remains; cycles are expected to be compacted beforehand.
    """
    from .composite_makecanonical import get_canonical_elements

    canonical = get_canonical_elements(ndp)
    cycles = canonical['cycles']
    if len(cycles) > 1:
        msg = ('I expected that the cycles were already compacted, while %s remain.'
               % cycles)
        raise_desc(NotImplementedError, msg, res=canonical)

    inner = canonical['inner']
    inner_dp = inner.get_dp()
    extraf = canonical['extraf']
    extrar = canonical['extrar']

    # The canonical form must expose exactly the ndp's own signals.
    assert extraf == ndp.get_fnames(), (extraf, ndp.get_fnames())
    assert extrar == ndp.get_rnames(), (extrar, ndp.get_rnames())

    F1 = ndp.get_ftypes(extraf)
    R1 = ndp.get_rtypes(extrar)

    def unwrap_single(items):
        # A lone element is used bare rather than as a one-element list.
        return items[0] if len(items) == 1 else items

    if not cycles:
        # No feedback: the inner dp already has the wanted interface.
        mcdp_dev_warning('this needs much more testing')
        from mocdp.comp.wrap import dpwrap
        return dpwrap(inner_dp, unwrap_single(extraf), unwrap_single(extrar))

    cycle = cycles[0]
    F2 = inner.get_rtype(cycle)
    R2 = F2
    dp0F = PosetProduct((F1, F2))

    # mux1: (F1, F2) -> inner's function space. External functions come
    # from slot 0; any other inner function is fed by slot 1 (the loop).
    coords1 = unwrap_single([(0, extraf.index(fname)) if fname in extraf else 1
                             for fname in inner.get_fnames()])
    mux1 = Mux(dp0F, coords1)
    assert mux1.get_res_space() == inner_dp.get_fun_space()

    # mux2: inner's resource space -> (R1, R2).
    mux0F = inner_dp.get_res_space()
    inner_rnames = inner.get_rnames()
    lone_resource = len(inner_rnames) == 1

    def rcoord(rname):
        position = inner_rnames.index(rname)
        # A lone inner resource is addressed with the empty coordinate.
        return () if lone_resource else position

    coords2 = [[rcoord(rname) for rname in extrar], rcoord(cycle)]
    mux2 = Mux(mux0F, coords2)

    dp0 = make_series(make_series(mux1, inner_dp), mux2)
    dp0R_expect = PosetProduct((R1, R2))
    assert dp0.get_res_space() == dp0R_expect

    dp = DPLoop2(dp0)

    # Target interface: ndp's own spaces, with 1-element products
    # collapsed to the bare type.
    own_fnames = ndp.get_fnames()
    own_rnames = ndp.get_rnames()
    F = ndp.get_ftypes(own_fnames)
    if len(own_fnames) == 1:
        F = F[0]
    R = ndp.get_rtypes(own_rnames)
    if len(own_rnames) == 1:
        R = R[0]

    if len(extraf) == 1:
        dp = make_series(Mux(F, [()]), dp)
    if len(extrar) == 1:
        dp = make_series(dp, Mux(PosetProduct((R,)), 0))

    tu = get_types_universe()
    tu.check_equal(dp.get_fun_space(), F)
    tu.check_equal(dp.get_res_space(), R)

    return SimpleWrap(dp, fnames=unwrap_single(extraf), rnames=unwrap_single(extrar))
def connect2(ndp1, ndp2, connections, split, repeated_ok=False):
    """
    Connect ``ndp1`` in series with ``ndp2`` and wrap the result.

    Note the argument split must be a list of strings so that
    orders are preserved and deterministic.

    Signal groups (``connections`` link resources ``s1`` of ndp1 to
    functions ``s2`` of ndp2)::

                ___
               |   |-----------------------------> A
               |   |--B1--+----------------------> B1 (split)
        f1 --->| 1 |      |          ___
               |   |      +---B2--->|   |
               |___|--C1-------C2-->| 2 |--> r2
        D ------------------------->|___|

      - B1 -> B2: connected resources that are also kept as outputs.
      - C1 -> C2: the remaining connections (internal only).
      - A: resources of ndp1 that are not connected.
      - D: functions of ndp2 that are not connected.

    The result has functions ``f1 + D`` and resources ``A + B1 + r2``.
    Internally it is the series

        ftot -> m1 -> X -> Y -> Z -> m2 -> rtot

    with X = dp1 | Id_D, Z = Id_(A, B1) | dp2, and m1, Y, m2 muxes.

    :param ndp1: upstream NamedDP.
    :param ndp2: downstream NamedDP (must not be ``ndp1`` itself).
    :param connections: non-empty collection of connections with
        attributes ``s1`` (resource of ndp1) and ``s2`` (function of ndp2).
    :param split: list of distinct, connected resource names of ndp1 that
        are also exported as resources.
    :param repeated_ok: if False, ndp1 and ndp2 must share neither
        function names nor resource names.
    :raises ValueError: if ``ndp1 is ndp2`` or ``split`` has repeats.
    :raises DPInternalError: (via raise_desc / raise_wrapped) on name
        clashes when not ``repeated_ok``, and wrapping any error raised
        while building the composition.
    """
    if ndp1 is ndp2:
        raise ValueError('Equal')

    def common(x, y):
        return len(set(x + y)) != len(set(x)) + len(set(y))

    if not repeated_ok:
        if (common(ndp1.get_fnames(), ndp2.get_fnames()) or
                common(ndp1.get_rnames(), ndp2.get_rnames())):
            raise_desc(DPInternalError, 'repeated names', ndp1=ndp1, ndp2=ndp2,
                       connections=connections, split=split)

    if len(set(split)) != len(split):
        msg = 'Repeated signals in split: %s' % str(split)
        raise ValueError(msg)

    try:
        if not connections:
            raise ValueError('Empty connections')

        # Explicit raises (not asserts) so lookups cannot silently
        # return None under ``python -O``.
        def s2_from_s1(s1):
            for c in connections:
                if c.s1 == s1:
                    return c.s2
            raise ValueError('Cannot find connection with s1 = %s' % s1)

        def s1_from_s2(s2):
            for c in connections:
                if c.s2 == s2:
                    return c.s1
            raise ValueError('Cannot find connection with s2 = %s' % s2)

        f1 = ndp1.get_fnames()
        r1 = ndp1.get_rnames()
        f2 = ndp2.get_fnames()
        r2 = ndp2.get_rnames()

        all_s2 = set([c.s2 for c in connections])
        all_s1 = set([c.s1 for c in connections])

        # Every split signal must be one of the connected resources.
        not_connected = [x for x in split if x not in all_s1]
        if not_connected:
            raise ValueError('Split signals are not connected: %s' % not_connected)

        # Real lists (not lazy map objects): they are concatenated and
        # membership-tested repeatedly below.
        B1 = list(split)
        B2 = [s2_from_s1(s) for s in B1]
        C2 = list_diff(all_s2, B2)
        C1 = [s1_from_s2(s) for s in C2]
        A = list_diff(r1, B1 + C1)
        D = list_diff(f2, B2 + C2)

        fntot = f1 + D
        rntot = A + B1 + r2

        if there_are_repetitions(fntot) or there_are_repetitions(rntot):
            raise_desc(NotImplementedError, 'Repeated names', fnames=fntot, rnames=rntot)

        # Total function and resource spaces.
        f1_types = ndp1.get_ftypes(f1)
        D_types = ndp2.get_ftypes(D)
        Ftot = PosetProduct(tuple(list(f1_types) + list(D_types)))
        Rtot = PosetProduct(tuple(list(ndp1.get_rtypes(A)) +
                                  list(ndp1.get_rtypes(B1)) +
                                  list(ndp2.get_rtypes(r2))))
        assert len(fntot) == len(Ftot), (fntot, Ftot)
        assert len(rntot) == len(Rtot), (rntot, Rtot)

        # m1: from ftot to Product(f1, D).
        m1_for_f1 = [fntot.index(s) for s in f1]
        m1_for_D = [fntot.index(s) for s in D]
        m1 = Mux(Ftot, [m1_for_f1, m1_for_D])

        # X = dp1 | Id_D
        Id_D = Identity(D_types)
        ndp1_p = its_dp_as_product(ndp1)
        X = make_parallel(ndp1_p, Id_D)
        m1_X = make_series(m1, X)  # also checks that m1 and X connect

        # Z = Id_(A, B1) | dp2
        A_B1_types = PosetProduct(tuple(ndp1.get_rtypes(A)) + tuple(ndp1.get_rtypes(B1)))
        Id_A_B1 = Identity(A_B1_types)
        ndp2_p = its_dp_as_product(ndp2)
        Z = make_parallel(Id_A_B1, ndp2_p)

        # m2: from Z's output to the flat product A + B1 + r2.
        m2coords_A = [(0, (A + B1).index(x)) for x in A]
        m2coords_B1 = [(0, (A + B1).index(x)) for x in B1]
        m2coords_r2 = [(1, r2.index(x)) for x in r2]
        m2coords = m2coords_A + m2coords_B1 + m2coords_r2
        m2 = Mux(Z.get_res_space(), m2coords)
        assert len(m2.get_res_space()) == len(rntot), ((m2.get_res_space(), rntot))
        make_series(Z, m2)  # make sure we can connect

        # Y: from X's output (r1, D) to Z's input ((A, B1), f2).
        # First component: pass A and B1 through from r1.
        Y_coords_A_B1 = [(0, r1.index(x)) for x in A + B1]
        # Second component: iterate ndp2's functions in order; connected
        # ones come from r1, unconnected ones from D.
        Y_coords_B2_C2_D = []
        for x in f2:
            if (x in B2) or (x in C2):
                Y_coords_B2_C2_D.append((0, r1.index(s1_from_s2(x))))
                assert x not in D
            elif x in D:
                Y_coords_B2_C2_D.append((1, D.index(x)))
            else:
                assert False, x  # D = f2 - (B2 + C2), so unreachable
        Y = Mux(m1_X.get_res_space(), [Y_coords_A_B1, Y_coords_B2_C2_D])

        # m1* Xp Y* Zp m2*
        Y_Z = make_series(Y, Z)
        Y_Z_m2 = make_series(Y_Z, m2)
        res_dp = make_series(m1_X, Y_Z_m2)

        res_dp, fnames, rnames = simplify_if_only_one_name(res_dp, fntot, rntot)
        return dpwrap(res_dp, fnames, rnames)
    except Exception as e:
        msg = 'connect2() failed'
        raise_wrapped(DPInternalError, e, msg, ndp1=ndp1, ndp2=ndp2,
                      connections=connections, split=split)
def compact_context(context):
    """
    If there are two subs with multiple connections,
    we take the product of their wires.

    Repeatedly takes the first pair (name1, name2) reported by
    ``find_nodes_with_multiple_connections`` and replaces their parallel
    wires with one wire named ``'_'.join(sorted(s1s))``::

        ndp1 -[s1s]-> mux -- sname -- demux -[s2s]-> ndp2

    ndp1 is replaced by ``connect2(ndp1, mux)`` and ndp2 by
    ``connect2(demux, ndp2)``. ``context`` is modified in place and
    returned once no such pair remains.
    """
    from .context_functions import find_nodes_with_multiple_connections
    from mcdp_dp import Mux
    from mocdp.comp.wrap import dpwrap
    from mocdp.comp.connection import connect2

    # Iterative rather than recursive, so the number of compacted pairs
    # is not bounded by the interpreter's recursion limit.
    while True:
        found = find_nodes_with_multiple_connections(context)
        if not found:
            return context

        name1, name2, their_connections = found[0]
        logger.debug('Will compact %s, %s, %s', *found[0])

        # Establish a deterministic order: list() of an unordered
        # collection would make the product layout depend on iteration
        # order.
        their_connections = sorted(their_connections, key=lambda c: (c.s1, c.s2))
        s1s = [c.s1 for c in their_connections]
        s2s = [c.s2 for c in their_connections]

        ndp1 = context.names[name1]
        ndp2 = context.names[name2]
        sname = '_'.join(sorted(s1s))

        # space -- [mux] -- R -- [demux] -- space
        space = ndp1.get_rtypes(s1s)
        N = len(their_connections)
        mux = Mux(space, [list(range(N))])
        muxndp = dpwrap(mux, s1s, sname)

        # example: R = PosetProduct((PosetProduct((A, B, C)),))
        R = mux.get_res_space()
        demux = Mux(R, [(0, i) for i in range(N)])
        R2 = demux.get_res_space()
        assert space == R2, (space, R2)
        demuxndp = dpwrap(demux, sname, s2s)

        replace1 = connect2(ndp1, muxndp,
                            connections=set([Connection('*', s, '*', s) for s in s1s]),
                            split=[], repeated_ok=False)
        replace2 = connect2(demuxndp, ndp2,
                            connections=set([Connection('*', s, '*', s) for s in s2s]),
                            split=[], repeated_ok=False)

        context.names[name1] = replace1
        context.names[name2] = replace2

        merged = set(their_connections)
        context.connections = [x for x in context.connections if x not in merged]
        context.connections.append(Connection(name1, sname, name2, sname))
def cndp_abstract_loop2(ndp):
    """
    Abstracts the dp using the canonical form.

    ``get_canonical_elements(ndp)`` provides an ``inner`` NamedDP, the
    list of ``cycles`` and the externally visible ``extraf``/``extrar``
    names, which must equal ndp's own function/resource names.

    - With no cycle, inner's dp is wrapped directly with ``dpwrap``.
    - With one cycle, inner's dp is put between two muxes so that it
      maps ``(F1, F2)`` to ``(R1, R2)``. Here F1/R1 are the external
      types and ``F2 == R2`` is the type of inner's resource
      ``cycles[0]``. The result is then closed with ``DPLoop2`` and
      wrapped in a SimpleWrap.

    Raises NotImplementedError (via raise_desc) if more than one cycle
    remains; cycles are expected to be compacted beforehand.
    """
    from .composite_makecanonical import get_canonical_elements

    canonical = get_canonical_elements(ndp)
    cycles = canonical['cycles']
    if len(cycles) > 1:
        msg = ('I expected that the cycles were already compacted, while %s remain.'
               % cycles)
        raise_desc(NotImplementedError, msg, res=canonical)

    inner = canonical['inner']
    inner_dp = inner.get_dp()
    extraf = canonical['extraf']
    extrar = canonical['extrar']

    # The canonical form must expose exactly the ndp's own signals.
    assert extraf == ndp.get_fnames(), (extraf, ndp.get_fnames())
    assert extrar == ndp.get_rnames(), (extrar, ndp.get_rnames())

    F1 = ndp.get_ftypes(extraf)
    R1 = ndp.get_rtypes(extrar)

    def unwrap_single(items):
        # A lone element is used bare rather than as a one-element list.
        return items[0] if len(items) == 1 else items

    if not cycles:
        # No feedback: the inner dp already has the wanted interface.
        mcdp_dev_warning('this needs much more testing')
        from mocdp.comp.wrap import dpwrap
        return dpwrap(inner_dp, unwrap_single(extraf), unwrap_single(extrar))

    cycle = cycles[0]
    F2 = inner.get_rtype(cycle)
    R2 = F2
    dp0F = PosetProduct((F1, F2))

    # mux1: (F1, F2) -> inner's function space. External functions come
    # from slot 0; any other inner function is fed by slot 1 (the loop).
    coords1 = unwrap_single([(0, extraf.index(fname)) if fname in extraf else 1
                             for fname in inner.get_fnames()])
    mux1 = Mux(dp0F, coords1)
    assert mux1.get_res_space() == inner_dp.get_fun_space()

    # mux2: inner's resource space -> (R1, R2).
    mux0F = inner_dp.get_res_space()
    inner_rnames = inner.get_rnames()
    lone_resource = len(inner_rnames) == 1

    def rcoord(rname):
        position = inner_rnames.index(rname)
        # A lone inner resource is addressed with the empty coordinate.
        return () if lone_resource else position

    coords2 = [[rcoord(rname) for rname in extrar], rcoord(cycle)]
    mux2 = Mux(mux0F, coords2)

    dp0 = make_series(make_series(mux1, inner_dp), mux2)
    dp0R_expect = PosetProduct((R1, R2))
    assert dp0.get_res_space() == dp0R_expect

    dp = DPLoop2(dp0)

    # Target interface: ndp's own spaces, with 1-element products
    # collapsed to the bare type.
    own_fnames = ndp.get_fnames()
    own_rnames = ndp.get_rnames()
    F = ndp.get_ftypes(own_fnames)
    if len(own_fnames) == 1:
        F = F[0]
    R = ndp.get_rtypes(own_rnames)
    if len(own_rnames) == 1:
        R = R[0]

    if len(extraf) == 1:
        dp = make_series(Mux(F, [()]), dp)
    if len(extrar) == 1:
        dp = make_series(dp, Mux(PosetProduct((R,)), 0))

    tu = get_types_universe()
    tu.check_equal(dp.get_fun_space(), F)
    tu.check_equal(dp.get_res_space(), R)

    return SimpleWrap(dp, fnames=unwrap_single(extraf), rnames=unwrap_single(extrar))
def compact_context(context):
    """
    If there are two subs with multiple connections,
    we take the product of their wires.

    Repeatedly takes the first pair (name1, name2) reported by
    ``find_nodes_with_multiple_connections`` and replaces their parallel
    wires with one wire named ``'_'.join(sorted(s1s))``::

        ndp1 -[s1s]-> mux -- sname -- demux -[s2s]-> ndp2

    ndp1 is replaced by ``connect2(ndp1, mux)`` and ndp2 by
    ``connect2(demux, ndp2)``. ``context`` is modified in place and
    returned once no such pair remains.
    """
    from .context_functions import find_nodes_with_multiple_connections
    from mcdp_dp import Mux
    from mocdp.comp.wrap import dpwrap
    from mocdp.comp.connection import connect2

    # Iterative rather than recursive, so the number of compacted pairs
    # is not bounded by the interpreter's recursion limit.
    while True:
        found = find_nodes_with_multiple_connections(context)
        if not found:
            return context

        name1, name2, their_connections = found[0]
        logger.debug('Will compact %s, %s, %s', *found[0])

        # Establish a deterministic order: list() of an unordered
        # collection would make the product layout depend on iteration
        # order.
        their_connections = sorted(their_connections, key=lambda c: (c.s1, c.s2))
        s1s = [c.s1 for c in their_connections]
        s2s = [c.s2 for c in their_connections]

        ndp1 = context.names[name1]
        ndp2 = context.names[name2]
        sname = '_'.join(sorted(s1s))

        # space -- [mux] -- R -- [demux] -- space
        space = ndp1.get_rtypes(s1s)
        N = len(their_connections)
        mux = Mux(space, [list(range(N))])
        muxndp = dpwrap(mux, s1s, sname)

        # example: R = PosetProduct((PosetProduct((A, B, C)),))
        R = mux.get_res_space()
        demux = Mux(R, [(0, i) for i in range(N)])
        R2 = demux.get_res_space()
        assert space == R2, (space, R2)
        demuxndp = dpwrap(demux, sname, s2s)

        replace1 = connect2(ndp1, muxndp,
                            connections=set([Connection('*', s, '*', s) for s in s1s]),
                            split=[], repeated_ok=False)
        replace2 = connect2(demuxndp, ndp2,
                            connections=set([Connection('*', s, '*', s) for s in s2s]),
                            split=[], repeated_ok=False)

        context.names[name1] = replace1
        context.names[name2] = replace2

        merged = set(their_connections)
        context.connections = [x for x in context.connections if x not in merged]
        context.connections.append(Connection(name1, sname, name2, sname))
def add_muxes(inner, cs, s_muxed, inner_name="_inner0", mux1_name="_mux1", mux2_name="_mux2"):
    """
    Add muxes before and after inner, bundling the signals ``cs`` into
    the single signal ``s_muxed``::

           ---(extraf)-->|       |---(extrar)-->
                         |       |
              +--c1----->| inner |---c1--+
    s_muxed --|          |       |       |--> s_muxed
              +--c2----->|       |---c2--+

    Each name in ``cs`` is used both as a resource and as a function of
    ``inner``. The remaining functions/resources of ``inner`` (extraf /
    extrar) are exposed unchanged.

    The parts are registered as ``inner_name`` (inner), ``mux1_name``
    (collects inner's resources ``cs`` into ``s_muxed``) and
    ``mux2_name`` (splits ``s_muxed`` into inner's functions ``cs``).

    Returns a CompositeNamedDP with functions ``extraf + [s_muxed]`` and
    resources ``extrar + [s_muxed]``.
    """
    extraf = [f for f in inner.get_fnames() if f not in cs]
    extrar = [r for r in inner.get_rnames() if r not in cs]
    fnames = extraf + [s_muxed]
    rnames = extrar + [s_muxed]

    if len(cs) == 1:
        # A single signal needs no bundling: identities suffice.
        only = cs[0]
        nto1 = SimpleWrap(Identity(inner.get_ftype(only)), fnames=only, rnames=s_muxed)
        _1ton = SimpleWrap(Identity(inner.get_rtype(only)), fnames=s_muxed, rnames=only)
    else:
        # n -> 1: pack all cs into one product, e.g. coords [0, 1, 2].
        F = PosetProduct(inner.get_ftypes(cs).subs)
        mux = Mux(F, list(range(len(cs))))
        nto1 = SimpleWrap(mux, fnames=cs, rnames=s_muxed)
        # 1 -> n: unpack the bundle back into the individual cs.
        mux2 = Mux(mux.get_res_space(), list(range(len(cs))))
        _1ton = SimpleWrap(mux2, fnames=s_muxed, rnames=cs)
        # Round trip must reproduce the original product of types.
        get_types_universe().check_equal(F, mux2.get_res_space())

    name2ndp = {inner_name: inner, mux1_name: nto1, mux2_name: _1ton}
    connections = [Connection(inner_name, n, mux1_name, n) for n in cs]
    connections += [Connection(mux2_name, n, inner_name, n) for n in cs]

    # Expose the remaining names; these helpers are passed name2ndp and
    # connections and presumably extend them in place (their return
    # values are not used) — confirm in their definitions.
    connect_functions_to_outside(name2ndp, connections, ndp_name=inner_name, fnames=extraf)
    connect_resources_to_outside(name2ndp, connections, ndp_name=inner_name, rnames=extrar)
    connect_resources_to_outside(name2ndp, connections, ndp_name=mux1_name, rnames=[s_muxed])
    connect_functions_to_outside(name2ndp, connections, ndp_name=mux2_name, fnames=[s_muxed])

    return CompositeNamedDP.from_parts(name2ndp=name2ndp, connections=connections,
                                       fnames=fnames, rnames=rnames)