def __init__(self):
    """Build the "Update" composite: a scatter-nd result added to an input.

    The first output of a ScatterNdArrow is wired into the second input of
    an AddArrow.  The composite exposes the first input of the add followed
    by all inputs of the scatter, and its outputs are the add's outputs.
    """
    wiring = Bimap()
    scatter_nd = ScatterNdArrow()
    adder = AddArrow()

    # must be such that we are only adding these elements to zeros
    wiring.add(scatter_nd.out_port(0), adder.in_port(1))

    exposed_inputs = adder.in_ports()[:1] + scatter_nd.in_ports()
    super().__init__(edges=wiring,
                     in_ports=exposed_inputs,
                     out_ports=adder.out_ports(),
                     name="Update")
def __init__(self) -> None:
    """Build the 'DimsBarBatch' composite from a rank, a constant and a range.

    A constant source of 1 feeds the first input of a RangeArrow and the
    output of a RankArrow feeds its second input.  The composite takes the
    rank arrow's inputs and yields the range arrow's outputs.
    NOTE(review): presumably this enumerates every dimension except the
    leading (batch) one -- confirm against RangeArrow's semantics.
    """
    rank = RankArrow()
    one = SourceArrow(1)
    dim_range = RangeArrow()

    wiring = Bimap()  # type: EdgeMap
    wiring.add(one.out_ports()[0], dim_range.in_ports()[0])
    wiring.add(rank.out_ports()[0], dim_range.in_ports()[1])

    super().__init__(edges=wiring,
                     in_ports=rank.in_ports(),
                     out_ports=dim_range.out_ports(),
                     name='DimsBarBatch')
def __init__(self, arith_arrow: PrimitiveArrow):
    """Build "BroadcastArith": one BroadcastArrow in front of each input.

    Args:
        arith_arrow: arrow whose every input port receives the output of
            its own, freshly created BroadcastArrow.  Its output ports
            become the composite's output ports.
    """
    wiring = Bimap()
    exposed_inputs = []
    exposed_outputs = arith_arrow.out_ports()

    for target_port in arith_arrow.in_ports():
        broadcaster = cfarrows.BroadcastArrow()
        # The broadcaster's inputs become inputs of the composite.
        exposed_inputs.extend(broadcaster.in_ports())
        wiring.add(broadcaster.out_ports()[0], target_port)

    super().__init__(edges=wiring,
                     in_ports=exposed_inputs,
                     out_ports=exposed_outputs,
                     name="BroadcastArith")
def inv_gather(arrow: GatherArrow,
               port_attr: PortAttributes) -> Tuple[Arrow, PortMap]:
    """Construct the "InvGather" composite for a GatherArrow.

    Args:
        arrow: the gather arrow being inverted.
        port_attr: port attributes; the 'shape' of the arrow's first input
            and the 'value' of its second input are read from it.

    Returns:
        A pair of the replacement arrow and a map from the ports of
        ``arrow`` to the ports of the replacement.  When the output of
        ``arrow`` is constant, a plain GatherArrow with an identity port
        map is returned instead.
    """
    if is_constant(arrow.out_ports()[0], port_attr):
        return GatherArrow(), {0: 0, 1: 1, 2: 2}

    shape = port_attr[arrow.in_ports()[0]]['shape']
    if isinstance(shape, tuple):
        shape = list(shape)
    gathered_indices = port_attr[arrow.in_ports()[1]]['value']
    missing_indices = complement(gathered_indices, shape)

    dense_params = SparseToDenseArrow()
    dense_output = SparseToDenseArrow()
    shape_dupl = DuplArrow()
    # NOTE(review): this second DuplArrow is never wired into the graph; it
    # is still constructed so that behaviour stays exactly as before.
    dupl2 = DuplArrow()
    # FIXME: don't do this, complement could be huge
    missing_source = SourceArrow(np.array(missing_indices))
    shape_source = SourceArrow(np.array(shape))
    adder = AddArrow()

    wiring = Bimap()
    wiring.add(missing_source.out_ports()[0], dense_params.in_ports()[0])
    wiring.add(shape_source.out_ports()[0], shape_dupl.in_ports()[0])
    wiring.add(shape_dupl.out_ports()[0], dense_params.in_ports()[1])
    wiring.add(shape_dupl.out_ports()[1], dense_output.in_ports()[1])
    wiring.add(dense_params.out_ports()[0], adder.in_ports()[0])
    wiring.add(dense_output.out_ports()[0], adder.in_ports()[1])

    # orig_out_port, params, inp_list
    exposed_inputs = [
        dense_output.in_ports()[2],
        dense_params.in_ports()[2],
        dense_output.in_ports()[0],
    ]
    exposed_outputs = [adder.out_ports()[0]]
    inverse = CompositeArrow(in_ports=exposed_inputs,
                             out_ports=exposed_outputs,
                             edges=wiring,
                             name="InvGather")
    # The second input is the parameter port.
    make_param_port(inverse.in_ports()[1])
    return inverse, {0: 3, 1: 2, 2: 0}
def test_mixed_knowns() -> CompositeArrow:
    """Return a composite mixing constant-fed and free sub-arrows.

    Both inputs of the arrow built by ``test_twoxyplusx`` are fed by
    constant sources of 2, while an AddArrow is left unconnected and has
    its two inputs and its output exposed on the composite.
    """
    inner = test_twoxyplusx()
    first_const = SourceArrow(2)
    second_const = SourceArrow(2)
    free_add = AddArrow()

    wiring = Bimap()  # type: EdgeMap
    wiring.add(first_const.out_ports()[0], inner.in_ports()[0])
    wiring.add(second_const.out_ports()[0], inner.in_ports()[1])

    exposed_inputs = [free_add.in_ports()[0], free_add.in_ports()[1]]
    exposed_outputs = [inner.out_ports()[0], free_add.out_ports()[0]]
    return CompositeArrow(in_ports=exposed_inputs,
                          out_ports=exposed_outputs,
                          edges=wiring,
                          name="test_mixed_knowns")
def __init__(self):
    """Build the "InvDiv" parametric composite.

    Inputs are [first operand of a DivArrow, theta]; theta is duplicated so
    that one copy becomes the second operand of the div and the other copy
    is passed straight through.  Outputs are [div result, theta].  Port 1
    of the composite (theta) is marked as a parameter port.

    NOTE(review): assuming DivArrow divides port 0 by port 1, the outputs
    are (in0 / theta, theta), whose *product* is in0.  An inverse of
    division would be expected to multiply instead -- confirm whether the
    arrow or the name is the intended one.
    """
    wiring = Bimap()  # type: EdgeMap
    theta_dupl = DuplArrow()
    divider = DivArrow()

    wiring.add(theta_dupl.out_ports()[0], divider.in_ports()[1])

    exposed_inputs = [divider.in_ports()[0], theta_dupl.in_ports()[0]]
    exposed_outputs = [divider.out_ports()[0], theta_dupl.out_ports()[1]]
    super().__init__(edges=wiring,
                     in_ports=exposed_inputs,
                     out_ports=exposed_outputs,
                     name="InvDiv")
    make_param_port(self.ports()[1])
def __init__(self):
    """Build the "InvSub" parametric composite.

    Inputs are [first operand of an AddArrow, theta]; theta is duplicated,
    one copy feeding the second operand of the add and the other being
    passed through.  Outputs are [add result, theta].  Port 1 of the
    composite (theta) is marked as a parameter port.
    """
    wiring = Bimap()  # type: EdgeMap
    theta_dupl = DuplArrow()
    adder = AddArrow()

    exposed_inputs = [adder.in_ports()[0], theta_dupl.in_ports()[0]]
    exposed_outputs = [adder.out_ports()[0], theta_dupl.out_ports()[1]]
    wiring.add(theta_dupl.out_ports()[0], adder.in_ports()[1])

    super().__init__(edges=wiring,
                     in_ports=exposed_inputs,
                     out_ports=exposed_outputs,
                     name="InvSub")
    make_param_port(self.ports()[1])
def __init__(self) -> None:
    """Build the 'InvPow' parametric composite.

    Inputs are [second input of a LogBaseArrow, theta]; theta is
    duplicated, one copy feeding the first input of the log arrow and the
    other being passed through.  Outputs are [theta, log result].  Port 1
    of the composite (theta) is marked as a parameter port.
    """
    wiring = Bimap()  # type: EdgeMap
    theta_dupl = DuplArrow()
    log_base = LogBaseArrow()

    wiring.add(theta_dupl.out_ports()[0], log_base.in_ports()[0])

    exposed_inputs = [log_base.in_ports()[1], theta_dupl.in_ports()[0]]
    exposed_outputs = [theta_dupl.out_ports()[1], log_base.out_ports()[0]]
    super().__init__(edges=wiring,
                     in_ports=exposed_inputs,
                     out_ports=exposed_outputs,
                     name='InvPow')
    make_param_port(self.ports()[1])
def test_xyplusx_flat() -> CompositeArrow:
    """f(x,y) = x * y + x

    x is duplicated: one copy goes to the multiply together with y, the
    other goes directly to the add, which also receives the product.
    """
    multiplier = MulArrow()
    adder = AddArrow()
    x_dupl = DuplArrow()

    wiring = Bimap()  # type: EdgeMap
    wiring.add(x_dupl.out_ports()[0], multiplier.in_ports()[0])  # x -> mul
    wiring.add(x_dupl.out_ports()[1], adder.in_ports()[0])  # x -> add
    wiring.add(multiplier.out_ports()[0], adder.in_ports()[1])  # x*y -> add

    exposed_inputs = [x_dupl.in_ports()[0], multiplier.in_ports()[1]]
    composite = CompositeArrow(in_ports=exposed_inputs,
                               out_ports=[adder.out_ports()[0]],
                               edges=wiring)
    # The name is assigned after construction, not passed to the constructor.
    composite.name = "test_xyplusx_flat"
    return composite
def inv_dupl_approx(arrow: DuplArrow,
                    port_values: PortAttributes) -> Tuple[Arrow, PortMap]:
    """Construct the "InvDuplApprox" composite for a DuplArrow.

    An ApproxIdentityArrow takes one input per duplicated output of
    ``arrow``; its first ``n`` outputs feed an InvDuplArrow and its next
    output is exposed as an error port.

    Args:
        arrow: the dupl arrow being inverted; only ``n_out_ports`` is read.
        port_values: not consulted here (no check is made that the dupl
            input is non-constant).

    Returns:
        The composite and a map from the ports of ``arrow`` to the ports
        of the composite.
    """
    n_copies = arrow.n_out_ports
    merger = InvDuplArrow(n_duplications=n_copies)
    approx = ApproxIdentityArrow(n_inputs=n_copies)

    wiring = Bimap()  # type: EdgeMap
    for idx in range(n_copies):
        wiring.add(approx.out_ports()[idx], merger.in_ports()[idx])

    # The output just past the n pass-through outputs is the error.
    exposed_outputs = merger.out_ports() + [approx.out_ports()[n_copies]]
    inv_arrow = CompositeArrow(edges=wiring,
                               in_ports=approx.in_ports(),
                               out_ports=exposed_outputs,
                               name="InvDuplApprox")
    make_error_port(inv_arrow.out_ports()[-1])

    # Port 0 of the dupl (its input) maps to the second-to-last port of the
    # composite; dupl output i + 1 maps to composite input i.
    port_map = {0: inv_arrow.ports()[-2].index}
    for idx in range(n_copies):
        port_map[idx + 1] = idx

    inv_arrow.name = "InvDuplApprox"
    return inv_arrow, port_map
def inv_gathernd(arrow: GatherNdArrow,
                 port_attr: PortAttributes) -> Tuple[Arrow, PortMap]:
    """Construct the "InvGatherNd" composite for a GatherNdArrow.

    Args:
        arrow: the gather-nd arrow being inverted.
        port_attr: port attributes; the 'shape' of the arrow's first input
            and the 'value' of its second input are read from it.

    Returns:
        A pair of the replacement arrow and a map from the ports of
        ``arrow`` to the ports of the replacement.  When the output of
        ``arrow`` is constant, a plain GatherNdArrow with an identity port
        map is returned instead.
    """
    if is_constant(arrow.out_ports()[0], port_attr):
        return GatherNdArrow(), {0: 0, 1: 1, 2: 2}

    shape = np.array(port_attr[arrow.in_ports()[0]]['shape'])
    gathered_indices = port_attr[arrow.in_ports()[1]]['value']
    complement_mask = complement_bool(gathered_indices, shape)

    # fixme: don't do this, complement could be huge
    mask_source = SourceArrow(np.array(complement_mask, dtype=np.float32))
    shape_source = SourceArrow(shape)
    scatter_nd = ScatterNdArrow()
    multiplier = MulArrow()
    adder = AddArrow()

    wiring = Bimap()
    wiring.add(shape_source.out_port(0), scatter_nd.in_port(2))
    wiring.add(mask_source.out_port(0), multiplier.in_port(1))
    wiring.add(scatter_nd.out_port(0), adder.in_port(0))
    wiring.add(multiplier.out_port(0), adder.in_port(1))

    # orig_out_port, params, inp_list
    exposed_inputs = [
        scatter_nd.in_port(1),
        multiplier.in_port(0),
        scatter_nd.in_port(0),
    ]
    exposed_outputs = [adder.out_port(0)]
    inverse = CompositeArrow(in_ports=exposed_inputs,
                             out_ports=exposed_outputs,
                             edges=wiring,
                             name="InvGatherNd")
    # The second input is the parameter port.
    make_param_port(inverse.in_ports()[1])
    return inverse, {0: 3, 1: 2, 2: 0}
def __init__(self, n_inputs: int) -> None:
    """Build the 'Mean' composite: sum of ``n_inputs`` inputs divided by n.

    The constant ``n_inputs`` is cast to ``floatX()``, passed through a
    BroadcastArrow and used as the second operand of a DivArrow whose
    first operand is the output of an AddNArrow over the inputs.

    Args:
        n_inputs: number of inputs of the composite.
    """
    wiring = Bimap()  # type: EdgeMap
    summer = AddNArrow(n_inputs)
    count_source = SourceArrow(n_inputs)
    count_cast = CastArrow(floatX())
    count_broadcast = BroadcastArrow()
    divider = DivArrow()

    # n -> cast -> broadcast
    wiring.add(count_source.out_ports()[0], count_cast.in_ports()[0])
    wiring.add(count_cast.out_port(0), count_broadcast.in_port(0))
    # sum -> first operand, broadcast n -> second operand
    wiring.add(summer.out_ports()[0], divider.in_ports()[0])
    wiring.add(count_broadcast.out_ports()[0], divider.in_ports()[1])

    super().__init__(edges=wiring,
                     in_ports=summer.in_ports(),
                     out_ports=divider.out_ports(),
                     name='Mean')
def __init__(self, n_inputs: int) -> None:
    """Build the 'VarFromMean' composite.

    The first input is duplicated ``n_inputs`` times; each copy is the
    first operand of a SubArrow whose second operand is one of the
    remaining inputs.  Each difference goes through an AbsArrow, the
    results are summed with an AddNArrow, and the sum is reduced with a
    ReduceMeanArrow over the dimensions produced by a DimsBarBatchArrow.

    Args:
        n_inputs: number of inputs following the first one.
    """
    first_dupl = DuplArrow(n_duplications=n_inputs)
    subs = [SubArrow() for _ in range(n_inputs)]
    abss = [AbsArrow() for _ in range(n_inputs)]
    summer = AddNArrow(n_inputs)
    wiring = Bimap()  # type: EdgeMap

    exposed_inputs = [first_dupl.in_ports()[0]]
    exposed_inputs += [sub.in_ports()[1] for sub in subs]

    for idx, (sub, absolute) in enumerate(zip(subs, abss)):
        wiring.add(first_dupl.out_ports()[idx], sub.in_ports()[0])
        wiring.add(sub.out_ports()[0], absolute.in_ports()[0])
        wiring.add(absolute.out_ports()[0], summer.in_ports()[idx])

    # The sum is needed twice: as the tensor to reduce and as the source
    # of the reduction dimensions.
    sum_dupl = DuplArrow(n_duplications=2)
    wiring.add(summer.out_ports()[0], sum_dupl.in_ports()[0])
    reducer = ReduceMeanArrow(n_inputs=2)
    dims = DimsBarBatchArrow()
    wiring.add(sum_dupl.out_ports()[0], reducer.in_ports()[0])
    wiring.add(sum_dupl.out_ports()[1], dims.in_ports()[0])
    wiring.add(dims.out_ports()[0], reducer.in_ports()[1])

    super().__init__(edges=wiring,
                     in_ports=exposed_inputs,
                     out_ports=reducer.out_ports(),
                     name='VarFromMean')
def test_inv_twoxyplusx() -> CompositeArrow:
    """approximate parametric inverse of twoxyplusx

    Inputs are [first input of the InvAdd, InvAdd parameter, InvMul
    parameter]; outputs are [InvDupl output, second InvMul output, error
    output of the ApproxIdentity].
    """
    inv_add = InvAddArrow()
    inv_mul = InvMulArrow()
    two_as_int = SourceArrow(2)
    two_as_float = CastArrow(floatX())
    halve = DivArrow()
    approx = ApproxIdentityArrow(2)
    inv_dupl = InvDuplArrow()

    wiring = Bimap()  # type: EdgeMap
    wiring.add(two_as_int.out_ports()[0], two_as_float.in_ports()[0])
    wiring.add(inv_add.out_ports()[0], approx.in_ports()[0])
    wiring.add(inv_add.out_ports()[1], inv_mul.in_ports()[0])
    wiring.add(inv_mul.out_ports()[0], halve.in_ports()[0])
    wiring.add(two_as_float.out_ports()[0], halve.in_ports()[1])
    wiring.add(halve.out_ports()[0], approx.in_ports()[1])
    wiring.add(approx.out_ports()[0], inv_dupl.in_ports()[0])
    wiring.add(approx.out_ports()[1], inv_dupl.in_ports()[1])

    param_inports = [inv_add.in_ports()[1], inv_mul.in_ports()[1]]
    exposed_outputs = [
        inv_dupl.out_ports()[0],
        inv_mul.out_ports()[1],
        approx.out_ports()[2],
    ]
    # NOTE(review): the name ends in "PlusY" while the function inverts
    # "twoxyplusx" -- possibly a typo; left untouched as it is runtime data.
    inverse = CompositeArrow(in_ports=[inv_add.in_ports()[0]] + param_inports,
                             out_ports=exposed_outputs,
                             edges=wiring,
                             name="InvTwoXYPlusY")
    for param_idx in (1, 2):
        make_param_port(inverse.in_ports()[param_idx])
    make_error_port(inverse.out_ports()[2])
    return inverse
def test_twoxyplusx() -> CompositeArrow:
    """f(x,y) = 2 * x * y + x

    x is duplicated: one copy is multiplied by y and then by a broadcast
    constant 2.0, the other copy is added to that product.
    """
    two_source = SourceArrow(2.0)
    two_broadcast = BroadcastArrow()
    xy_mul = MulArrow()
    double_mul = MulArrow()
    adder = AddArrow()
    x_dupl = DuplArrow()

    wiring = Bimap()  # type: EdgeMap
    wiring.add(x_dupl.out_ports()[0], xy_mul.in_ports()[0])  # x -> x*y
    wiring.add(x_dupl.out_ports()[1], adder.in_ports()[0])  # x -> add
    wiring.add(two_source.out_ports()[0], two_broadcast.in_ports()[0])
    wiring.add(two_broadcast.out_ports()[0], double_mul.in_ports()[0])
    wiring.add(xy_mul.out_ports()[0], double_mul.in_ports()[1])
    wiring.add(double_mul.out_ports()[0], adder.in_ports()[1])  # 2xy -> add

    exposed_inputs = [x_dupl.in_ports()[0], xy_mul.in_ports()[1]]
    return CompositeArrow(in_ports=exposed_inputs,
                          out_ports=[adder.out_ports()[0]],
                          edges=wiring,
                          name="test_twoxyplusx")
def test_bimap():
    """Check that a pair added to a Bimap can be looked up in both directions."""
    ages = Bimap()  # type: Bimap[str, int]
    ages.add("myage", 99)
    # Forward lookup goes key -> value, inverse lookup goes value -> key.
    assert ages.fwd("myage") == 99
    assert ages.inv(99) == "myage"