Exemplo n.º 1
0
def inv_gathernd(arrow: GatherNdArrow,
                 port_attr: PortAttributes) -> Tuple[Arrow, PortMap]:
    """Build a parametric inverse of a GatherNd arrow.

    If the gathered output is constant, a fresh GatherNdArrow is returned
    together with an identity port map.

    Otherwise this builds a CompositeArrow named "InvGatherNd":
      * a ScatterNdArrow takes the original tensor shape (from a SourceArrow)
        on its in-port 2;
      * a MulArrow multiplies a parameter tensor by a float32 mask computed by
        `complement_bool` (presumably the positions not selected by the
        indices -- TODO confirm against `complement_bool`);
      * an AddArrow sums the scatter result and the masked parameter.

    Ports of the inverse (in-ports first, then the single out-port):
        0: ScatterNd in-port 1 (fed by the original gather output)
        1: Mul in-port 0 (the parameter; made a param port)
        2: ScatterNd in-port 0 (fed by the original indices)
        3: Add output (the reconstructed input tensor)

    Args:
        arrow: The GatherNdArrow to invert.
        port_attr: Port attributes. Reads 'shape' of input port 0 and
            'value' of input port 1, so the indices must have a known value.

    Returns:
        (inverse arrow, port map from `arrow` port index to inverse port
        index).
    """
    if is_constant(arrow.out_ports()[0], port_attr):
        return GatherNdArrow(), {0: 0, 1: 1, 2: 2}

    orig_inputs = arrow.in_ports()
    shape_arr = np.array(port_attr[orig_inputs[0]]['shape'])
    known_indices = port_attr[orig_inputs[1]]['value']
    mask = complement_bool(known_indices, shape_arr)

    # fixme: don't do this, complement could be huge
    mask_src = SourceArrow(np.array(mask, dtype=np.float32))
    shape_src = SourceArrow(shape_arr)
    scatter = ScatterNdArrow()
    masked_param = MulArrow()
    combine = AddArrow()

    # Wiring, in the same order the edges were originally added.
    wiring = [
        (shape_src.out_port(0), scatter.in_port(2)),
        (mask_src.out_port(0), masked_param.in_port(1)),
        (scatter.out_port(0), combine.in_port(0)),
        (masked_param.out_port(0), combine.in_port(1)),
    ]
    edges = Bimap()
    for src_port, dst_port in wiring:
        edges.add(src_port, dst_port)

    inv = CompositeArrow(
        in_ports=[scatter.in_port(1),       # original output
                  masked_param.in_port(0),  # parameter
                  scatter.in_port(0)],      # indices
        out_ports=[combine.out_port(0)],
        edges=edges,
        name="InvGatherNd")
    make_param_port(inv.in_ports()[1])
    return inv, {0: 3, 1: 2, 2: 0}
Exemplo n.º 2
0
def inv_neg(arrow: NegArrow,
            port_attr: PortAttributes) -> Tuple[Arrow, PortMap]:
    """Invert a negation arrow.

    Negation is its own inverse, so the inverse is a fresh NegArrow with the
    input and output port roles swapped. If the input is constant, a deep
    copy of `arrow` is returned with an identity port map.

    Args:
        arrow: The NegArrow to invert.
        port_attr: Port attributes, consulted only to test whether the
            input port is constant.

    Returns:
        (inverse arrow, port map from `arrow` port index to inverse port
        index).
    """
    if is_constant(arrow.in_ports()[0], port_attr):
        return deepcopy(arrow), {0: 0, 1: 1}
    # Original input (port 0) becomes the inverse's output (port 1), and
    # vice versa.
    return NegArrow(), {0: 1, 1: 0}
Exemplo n.º 3
0
def inv_gather(arrow: GatherArrow,
               port_attr: PortAttributes) -> Tuple[Arrow, PortMap]:
    """Build a parametric inverse of a Gather arrow.

    If the gathered output is constant, a fresh GatherArrow is returned
    together with an identity port map.

    Otherwise this builds a CompositeArrow named "InvGather" from two
    SparseToDenseArrows that share the original tensor shape through one
    DuplArrow:
      * `std1` receives the complementary indices (from `complement`) on
        in-port 0 and a parameter tensor on in-port 2;
      * `std2` receives the original indices on in-port 0 and the gathered
        values on in-port 2;
    and the two dense results are summed by an AddArrow.

    Ports of the inverse (in-ports first, then the single out-port):
        0: std2 in-port 2 (fed by the original gather output)
        1: std1 in-port 2 (the parameter; made a param port)
        2: std2 in-port 0 (fed by the original indices)
        3: Add output (the reconstructed input tensor)

    Args:
        arrow: The GatherArrow to invert.
        port_attr: Port attributes. Reads 'shape' of input port 0 and
            'value' of input port 1, so the indices must have a known value.

    Returns:
        (inverse arrow, port map from `arrow` port index to inverse port
        index).
    """
    if is_constant(arrow.out_ports()[0], port_attr):
        return GatherArrow(), {0: 0, 1: 1, 2: 2}
    tensor_shape = port_attr[arrow.in_ports()[0]]['shape']
    if isinstance(tensor_shape, tuple):
        tensor_shape = list(tensor_shape)
    index_list_value = port_attr[arrow.in_ports()[1]]['value']
    index_list_compl = complement(index_list_value, tensor_shape)
    std1 = SparseToDenseArrow()  # complement indices <- parameter
    std2 = SparseToDenseArrow()  # original indices <- gathered values
    # A single DuplArrow fans the tensor shape out to both SparseToDense
    # arrows; no second dupl is needed.
    dupl = DuplArrow()
    # FIXME: don't do this, complement could be huge
    source_compl = SourceArrow(np.array(index_list_compl))
    source_tensor_shape = SourceArrow(np.array(tensor_shape))
    add = AddArrow()
    edges = Bimap()
    edges.add(source_compl.out_ports()[0], std1.in_ports()[0])
    edges.add(source_tensor_shape.out_ports()[0], dupl.in_ports()[0])
    edges.add(dupl.out_ports()[0], std1.in_ports()[1])
    edges.add(dupl.out_ports()[1], std2.in_ports()[1])
    edges.add(std1.out_ports()[0], add.in_ports()[0])
    edges.add(std2.out_ports()[0], add.in_ports()[1])
    # orig_out_port, params, inp_list
    in_ports = [std2.in_ports()[2], std1.in_ports()[2], std2.in_ports()[0]]
    out_ports = [add.out_ports()[0]]
    op = CompositeArrow(in_ports=in_ports,
                        out_ports=out_ports,
                        edges=edges,
                        name="InvGather")
    make_param_port(op.in_ports()[1])
    return op, {0: 3, 1: 2, 2: 0}
Exemplo n.º 4
0
def inv_gathernd_elim(arrow: GatherNdArrow,
                      port_attr: PortAttributes) -> Tuple[Arrow, PortMap]:
    """Invert a GatherNd arrow using the dedicated InvGatherNdArrow.

    If the gathered output is constant, a fresh GatherNdArrow is returned
    together with an identity port map. Otherwise an InvGatherNdArrow is
    created. Its out-port 0 is given the shape of the original input
    tensor, and its in-port 2 is given the known value of the original
    indices.

    Args:
        arrow: The GatherNdArrow to invert.
        port_attr: Port attributes. Reads 'shape' of input port 0 and
            'value' of input port 1.

    Returns:
        (inverse arrow, port map from `arrow` port index to inverse port
        index).
    """
    if is_constant(arrow.out_ports()[0], port_attr):
        return GatherNdArrow(), {0: 0, 1: 1, 2: 2}

    result = InvGatherNdArrow()
    tensor_shape = port_attr[arrow.in_port(0)]['shape']
    set_port_shape(result.out_port(0), tensor_shape)
    index_value = port_attr[arrow.in_port(1)]['value']
    set_port_value(result.in_port(2), index_value)

    mapping = {0: 3, 1: 2, 2: 0}
    return result, mapping
Exemplo n.º 5
0
def inv_broadcast(arrow: BroadcastArrow,
                  port_attr: PortAttributes) -> Tuple[Arrow, PortMap]:
    """Invert a broadcast arrow.

    If every input is constant, a fresh BroadcastArrow is returned together
    with an identity port map.

    Otherwise the inverse defaults to an IdentityArrow. If `ports_has`
    reports a 'shape' attribute for the arrow's ports and the input has
    fewer dimensions than the output, an InvBroadcastArrow(in_shape,
    out_shape) is used instead. A slice-and-reshape CompositeArrow was once
    sketched here as an alternative; InvBroadcastArrow replaces it.

    The port map is the identity when the output is constant. Otherwise the
    input and output ports are swapped.

    Args:
        arrow: The BroadcastArrow to invert.
        port_attr: Port attributes, used for constancy checks and shapes.

    Returns:
        (inverse arrow, port map from `arrow` port index to inverse port
        index).
    """
    if all(is_constant(p, port_attr) for p in arrow.in_ports()):
        return BroadcastArrow(), {0: 0, 1: 1}

    result = IdentityArrow()
    if is_constant(arrow.out_ports()[0], port_attr):
        mapping = {0: 0, 1: 1}
    else:
        mapping = {0: 1, 1: 0}

    if not ports_has(arrow.ports(), 'shape', port_attr):
        return result, mapping

    src_shape = port_attr[arrow.in_ports()[0]]['shape']
    dst_shape = port_attr[arrow.out_ports()[0]]['shape']
    if len(src_shape) < len(dst_shape):
        result = InvBroadcastArrow(src_shape, dst_shape)
    return result, mapping
Exemplo n.º 6
0
def generic_binary_inv(arrow: Arrow, port_values: PortAttributes,
                       PInverseArrow, Port0ConstArrow, Port0ConstPortMap,
                       Port1ConstArrow,
                       Port1ConstPortMap) -> Tuple[Arrow, PortMap]:
    """Choose an inverse for a two-input arrow based on which inputs are constant.

    Args:
        arrow: Binary arrow being inverted. Its in-ports 0 and 1 are tested
            for constancy; port 2 is presumably its output.
        port_values: Port attributes passed to `is_constant`.
        PInverseArrow: Zero-argument constructor for the parametric inverse,
            used when neither input is constant.
        Port0ConstArrow: Zero-argument constructor used when only input 0
            is constant.
        Port0ConstPortMap: Port map paired with `Port0ConstArrow`.
        Port1ConstArrow: Zero-argument constructor used when only input 1
            is constant.
        Port1ConstPortMap: Port map paired with `Port1ConstArrow`.

    Returns:
        (inverse arrow, port map). The port map is always a fresh dict, so
        callers may store or mutate it without affecting the supplied
        templates or other inversions.
    """
    # FIXME: Is this actually correct for mul/add/sub
    port_0_const = is_constant(arrow.in_ports()[0], port_values)
    port_1_const = is_constant(arrow.in_ports()[1], port_values)

    if port_0_const and port_1_const:
        # If both ports constant just return arrow as is
        inv_arrow = deepcopy(arrow)
        port_map = {0: 0, 1: 1, 2: 2}
    elif port_0_const:
        inv_arrow = Port0ConstArrow()
        # Copy: never hand out the caller's template dict by reference.
        port_map = dict(Port0ConstPortMap)
    elif port_1_const:
        inv_arrow = Port1ConstArrow()
        port_map = dict(Port1ConstPortMap)
    else:
        # Neither constant, do 'normal' parametric inversion
        inv_arrow = PInverseArrow()
        port_map = {0: 2, 1: 3, 2: 0}

    return inv_arrow, port_map
Exemplo n.º 7
0
def inv_exp(arrow: ExpArrow,
            port_attr: PortAttributes) -> Tuple[Arrow, PortMap]:
    """Build an approximate inverse of exp.

    If the input is constant, a deep copy of `arrow` is returned with an
    identity port map.

    Otherwise this returns a CompositeArrow named "approx_invexp". Its
    incoming value passes through IntervalBoundIdentity(1e-6, 1e6) and then
    a LogArrow. The IntervalBoundIdentity presumably keeps log's argument
    positive and reports bound violations on its second output -- TODO
    confirm. That second output is exposed as an error port.

    Composite ports: 0 = in, 1 = out (log result), 2 = error out.

    Args:
        arrow: The ExpArrow to invert.
        port_attr: Port attributes, used to test whether the input is
            constant.

    Returns:
        (inverse arrow, port map swapping the original input and output).
    """
    if is_constant(arrow.in_ports()[0], port_attr):
        return deepcopy(arrow), {0: 0, 1: 1}

    bounded = IntervalBoundIdentity(0.000001, 1000000.0)
    logarithm = LogArrow()
    inverse = CompositeArrow(name="approx_invexp")

    src = inverse.add_port()
    make_in_port(src)
    result = inverse.add_port()
    make_out_port(result)
    err = inverse.add_port()
    make_out_port(err)
    make_error_port(err)

    inverse.add_edge(src, bounded.in_ports()[0])
    inverse.add_edge(bounded.out_ports()[0], logarithm.in_ports()[0])
    inverse.add_edge(logarithm.out_ports()[0], result)
    inverse.add_edge(bounded.out_ports()[1], err)
    # NOTE(review): the return value is ignored; if this returns a bool
    # rather than raising, this check has no effect -- confirm.
    inverse.is_wired_correctly()
    return inverse, {0: 1, 1: 0}
Exemplo n.º 8
0
def inv_sin(arrow: SinArrow,
            port_attr: PortAttributes) -> Tuple[Arrow, PortMap]:
    """Build an approximate inverse of sin.

    If the input is constant, a deep copy of `arrow` is returned with an
    identity port map.

    Otherwise this returns a CompositeArrow named "approx_asin". Its
    incoming value passes through IntervalBoundIdentity(-0.999, 0.999) and
    then an ASinArrow. The second output of the IntervalBoundIdentity is
    exposed as an error port.

    Composite ports: 0 = in, 1 = out (asin result), 2 = error out.

    Args:
        arrow: The SinArrow to invert.
        port_attr: Port attributes, used to test whether the input is
            constant.

    Returns:
        (inverse arrow, port map swapping the original input and output).
    """
    if is_constant(arrow.in_ports()[0], port_attr):
        return deepcopy(arrow), {0: 0, 1: 1}

    bounded = IntervalBoundIdentity(-0.999, 0.999)
    arcsine = ASinArrow()
    inverse = CompositeArrow(name="approx_asin")

    src = inverse.add_port()
    make_in_port(src)
    result = inverse.add_port()
    make_out_port(result)
    err = inverse.add_port()
    make_out_port(err)
    make_error_port(err)

    inverse.add_edge(src, bounded.in_ports()[0])
    inverse.add_edge(bounded.out_ports()[0], arcsine.in_ports()[0])
    inverse.add_edge(arcsine.out_ports()[0], result)
    inverse.add_edge(bounded.out_ports()[1], err)
    # NOTE(review): the return value is ignored; if this returns a bool
    # rather than raising, this check has no effect -- confirm.
    inverse.is_wired_correctly()
    return inverse, {0: 1, 1: 0}
Exemplo n.º 9
0
def inv_cos(arrow: CosArrow,
            port_attr: PortAttributes) -> Tuple[Arrow, PortMap]:
    """Build an approximate inverse of cos.

    If the input is constant, a deep copy of `arrow` is returned with an
    identity port map.

    Otherwise this returns a CompositeArrow named "approx_acos". Its
    incoming value passes through IntervalBoundIdentity(-0.999, 0.999) and
    then an ACosArrow. The second output of the IntervalBoundIdentity is
    exposed as an error port.

    Composite ports: 0 = in, 1 = out (acos result), 2 = error out.

    Args:
        arrow: The CosArrow to invert.
        port_attr: Port attributes, used to test whether the input is
            constant.

    Returns:
        (inverse arrow, port map swapping the original input and output).
    """
    if is_constant(arrow.in_ports()[0], port_attr):
        return deepcopy(arrow), {0: 0, 1: 1}

    # FIXME: More rigorous than 0.999, should be 1.0 but get NaNs
    bounded = IntervalBoundIdentity(-0.999, 0.999)
    arccos = ACosArrow()
    inverse = CompositeArrow(name="approx_acos")

    src = inverse.add_port()
    make_in_port(src)
    result = inverse.add_port()
    make_out_port(result)
    err = inverse.add_port()
    make_out_port(err)
    make_error_port(err)

    inverse.add_edge(src, bounded.in_ports()[0])
    inverse.add_edge(bounded.out_ports()[0], arccos.in_ports()[0])
    inverse.add_edge(arccos.out_ports()[0], result)
    inverse.add_edge(bounded.out_ports()[1], err)
    # NOTE(review): the return value is ignored; if this returns a bool
    # rather than raising, this check has no effect -- confirm.
    inverse.is_wired_correctly()
    return inverse, {0: 1, 1: 0}
Exemplo n.º 10
0
def inner_invert(comp_arrow: CompositeArrow, port_attr: PortAttributes,
                 dispatch: Dict[Arrow, Callable]):
    """Construct a parametric inverse of a composite arrow.

    Args:
        comp_arrow: Composite arrow to invert.
        port_attr: Port attributes of `comp_arrow` and its sub-arrows. Used
            to decide which ports are constant and to copy port shapes onto
            the inverse.
        dispatch: Passed through to `invert_sub_arrow`; by the original
            documentation, a dict mapping arrow class to invert function.

    Returns:
        (inv_comp_arrow, comp_port_map).
          * inv_comp_arrow: an (approximate) parametric inverse of
            `comp_arrow`, named "<name>_inv". For every port i of
            `comp_arrow` it has a port i. Constant ports keep their
            direction; non-constant ports are flipped (in <-> out). Param
            and error ports of the inverted sub-arrows are appended after
            these.
          * comp_port_map: the identity map {i: i} over `comp_arrow`'s
            ports.
    """
    # Empty composition for inverse
    inv_comp_arrow = CompositeArrow(name="%s_inv" % comp_arrow.name)

    # Add a port on the inverse for every port on comp_arrow.
    # Constant ports keep their direction; all others are flipped.
    for port in comp_arrow.ports():
        inv_port = inv_comp_arrow.add_port()
        if is_in_port(port):
            if is_constant(port, port_attr):
                make_in_port(inv_port)
            else:
                make_out_port(inv_port)
        elif is_out_port(port):
            if is_constant(port, port_attr):
                make_out_port(inv_port)
            else:
                make_in_port(inv_port)
        # Transfer port information
        # FIXME: Which port_attr should be transferred from port to inv_port?
        if 'shape' not in port_attr[port]:
            print('WARNING: shape unknown for %s' % port)
        else:
            set_port_shape(inv_port, get_port_shape(port, port_attr))

    # Invert each sub_arrow independently
    arrow_to_inv = dict()
    arrow_to_port_map = dict()
    for sub_arrow in comp_arrow.get_sub_arrows():
        assert sub_arrow is not None
        inv_sub_arrow, port_map = invert_sub_arrow(sub_arrow, port_attr,
                                                   dispatch)
        # The inverse must be a fresh arrow not yet placed in a composition
        assert inv_sub_arrow.parent is None
        arrow_to_port_map[sub_arrow] = port_map
        arrow_to_inv[sub_arrow] = inv_sub_arrow

    # Register comp_arrow itself so edges touching its boundary ports can be
    # translated like any sub-arrow edge (identity port map).
    arrow_to_inv[comp_arrow] = inv_comp_arrow
    comp_port_map = {i: i for i in range(comp_arrow.num_ports())}
    arrow_to_port_map[comp_arrow] = comp_port_map

    # Rewire: each original edge maps to one edge between the corresponding
    # inverse ports, directed from whichever end projects to whichever
    # receives.
    for out_port, in_port in comp_arrow.edges.items():
        left_inv_port = get_inv_port(out_port, arrow_to_port_map, arrow_to_inv)
        right_inv_port = get_inv_port(in_port, arrow_to_port_map, arrow_to_inv)
        both = [left_inv_port, right_inv_port]
        projecting = [p for p in both if would_project(p, inv_comp_arrow)]
        receiving = [p for p in both if would_receive(p, inv_comp_arrow)]
        assert len(projecting) == 1, "Should be only 1 projecting"
        assert len(receiving) == 1, "Should be only 1 receiving"
        inv_comp_arrow.add_edge(projecting[0], receiving[0])

    # `transforms` is a module-level sequence defined outside this function.
    for transform in transforms:
        transform(inv_comp_arrow)

    # Create new ports on the inverse composition for parametric and error
    # ports of sub-arrows; these must not be wired yet.
    for sub_arrow in inv_comp_arrow.get_sub_arrows():
        for port in sub_arrow.ports():
            if is_param_port(port):
                assert port not in inv_comp_arrow.edges.keys()
                assert port not in inv_comp_arrow.edges.values()
                param_port = inv_comp_arrow.add_port()
                inv_comp_arrow.add_edge(param_port, port)
                make_in_port(param_port)
                make_param_port(param_port)
            elif is_error_port(port):
                assert port not in inv_comp_arrow.edges.keys()
                assert port not in inv_comp_arrow.edges.values()
                error_port = inv_comp_arrow.add_port()
                inv_comp_arrow.add_edge(port, error_port)
                make_out_port(error_port)
                make_error_port(error_port)

    return inv_comp_arrow, comp_port_map