Пример #1
0
def reparam(comp_arrow: CompositeArrow,
            phi_shape: Tuple,
            nn_takes_input=True):
    """Wrap `comp_arrow` so its parametric inputs are produced from `phi`.

    A new composite arrow is built with a single parametric in port `phi`.
    `phi` feeds a `TfArrow` whose out ports are wired, in port order, to the
    parametric ports of `comp_arrow`.  Every other port of `comp_arrow` is
    re-exposed on the new composite with the same direction, error flag and
    labels.

    Args:
        comp_arrow: Arrow to reparameterize
        phi_shape: Shape of parameter input
        nn_takes_input: If true, each non-parametric in port of `comp_arrow`
            is also wired into the `TfArrow`, after `phi`
    Returns:
        The new composite arrow (asserted to be wired correctly)
    """
    wrapper = CompositeArrow(name="%s_reparam" % comp_arrow.name)

    # phi is the only parametric input of the wrapper.
    phi = wrapper.add_port()
    set_port_shape(phi, phi_shape)
    make_in_port(phi)
    make_param_port(phi)

    # Port 0 of the network is phi; the remaining ones (if any) are the
    # non-parametric inputs of comp_arrow.
    extra_inputs = 0
    if nn_takes_input:
        extra_inputs = (comp_arrow.num_in_ports()
                        - comp_arrow.num_param_ports())
    nn = TfArrow(n_in_ports=1 + extra_inputs,
                 n_out_ports=comp_arrow.num_param_ports())
    wrapper.add_edge(phi, nn.in_port(0))

    next_nn_out = 0  # next network output to hand to a parametric port
    next_nn_in = 1   # next free network input (0 is taken by phi)
    for port in comp_arrow.ports():
        if is_param_port(port):
            wrapper.add_edge(nn.out_port(next_nn_out), port)
            next_nn_out += 1
            continue

        # Non-parametric port: mirror it on the wrapper's boundary.
        outer = wrapper.add_port()
        if is_out_port(port):
            make_out_port(outer)
            wrapper.add_edge(port, outer)
        if is_in_port(port):
            make_in_port(outer)
            wrapper.add_edge(outer, port)
            if nn_takes_input:
                wrapper.add_edge(outer, nn.in_port(next_nn_in))
                next_nn_in += 1
        if is_error_port(port):
            make_error_port(outer)
        for label in get_port_labels(port):
            add_port_label(outer, label)

    assert wrapper.is_wired_correctly()
    return wrapper
Пример #2
0
def inner_invert(comp_arrow: CompositeArrow, port_attr: PortAttributes,
                 dispatch: Dict[Arrow, Callable]):
    """Construct a parametric inverse of a composite arrow.

    Args:
        comp_arrow: Composite arrow to invert
        port_attr: Port attributes, used to test whether a port is constant
            and to look up port shapes
        dispatch: Dict mapping arrow class to invert function
    Returns:
        A pair `(inv_comp_arrow, comp_port_map)`.  `inv_comp_arrow` is the
        (approximate) parametric inverse of `comp_arrow`; `comp_port_map` is
        the identity map over the port indices of `comp_arrow`.
        The ith in_port of comp_arrow will be corresponding ith out_port
        error_ports and param_ports will follow
    """
    # Empty composition that will hold the inverse
    inverse = CompositeArrow(name="%s_inv" % comp_arrow.name)

    # One port on the inverse per port on comp_arrow.  Non-constant ports
    # flip direction; constant ports keep theirs.
    for port in comp_arrow.ports():
        mirrored = inverse.add_port()
        if is_in_port(port):
            flip = not is_constant(port, port_attr)
            mark = make_out_port if flip else make_in_port
            mark(mirrored)
        elif is_out_port(port):
            flip = not is_constant(port, port_attr)
            mark = make_in_port if flip else make_out_port
            mark(mirrored)
        # Transfer port information
        # FIXME: Which port_attr get transferred from port to the new port
        if 'shape' in port_attr[port]:
            set_port_shape(mirrored, get_port_shape(port, port_attr))
        else:
            print('WARNING: shape unknown for %s' % port)

    # Invert each sub_arrow, remembering its inverse and its port map
    arrow_to_inv = {}
    arrow_to_port_map = {}
    for sub_arrow in comp_arrow.get_sub_arrows():
        inverted = invert_sub_arrow(sub_arrow, port_attr, dispatch)
        inv_sub_arrow, port_map = inverted
        assert sub_arrow is not None
        assert inv_sub_arrow.parent is None
        arrow_to_port_map[sub_arrow] = port_map
        arrow_to_inv[sub_arrow] = inv_sub_arrow

    # The composite itself maps to its inverse with identical port indices
    assert comp_arrow is not None
    port_indices = range(comp_arrow.num_ports())
    comp_port_map = dict(zip(port_indices, port_indices))
    arrow_to_inv[comp_arrow] = inverse
    arrow_to_port_map[comp_arrow] = comp_port_map

    # Rewire every edge between the corresponding inverse ports; which end
    # projects and which receives is decided by would_project/would_receive.
    for src, dst in comp_arrow.edges.items():
        candidates = [
            get_inv_port(src, arrow_to_port_map, arrow_to_inv),
            get_inv_port(dst, arrow_to_port_map, arrow_to_inv),
        ]
        projecting = [p for p in candidates if would_project(p, inverse)]
        receiving = [p for p in candidates if would_receive(p, inverse)]
        assert len(projecting) == 1, "Should be only 1 projecting"
        assert len(receiving) == 1, "Should be only 1 receiving"
        inverse.add_edge(projecting[0], receiving[0])

    # NOTE(review): `transforms` is not defined in this function; presumably
    # a module-level collection of callables - confirm.
    for transform in transforms:
        transform(inverse)

    # Create new boundary ports on the inverse composition for every
    # parametric and error port of its sub arrows
    for sub_arrow in inverse.get_sub_arrows():
        for port in sub_arrow.ports():
            if is_param_port(port):
                exposes_param = True
            elif is_error_port(port):
                exposes_param = False
            else:
                continue

            # Such ports must not be wired to anything yet
            assert port not in inverse.edges.keys()
            assert port not in inverse.edges.values()

            boundary = inverse.add_port()
            if exposes_param:
                inverse.add_edge(boundary, port)
                make_in_port(boundary)
                make_param_port(boundary)
            else:
                inverse.add_edge(port, boundary)
                make_out_port(boundary)
                make_error_port(boundary)

    return inverse, comp_port_map