Ejemplo n.º 1
0
def partition(comp_arrow: CompositeArrow) -> List[Set[Arrow]]:
    """Partition the sub-arrows of ``comp_arrow`` into sequential layers.

    A sub-arrow enters a layer once all of its in-ports have been fed. An
    in-port is fed either by a boundary in-port of ``comp_arrow`` or by an
    out-port of a sub-arrow in an earlier layer.

    Args:
        comp_arrow: Composite arrow whose sub-arrows are partitioned.

    Returns:
        A list of layers (sets of sub-arrows). Every sub-arrow in layer ``k``
        is fed only by the boundary inputs and by sub-arrows in layers < k.

    Raises:
        ValueError: if some sub-arrows can never become ready (e.g. the
            wiring has a cycle or an in-port that is never fed). Previously
            this case looped forever appending empty layers.
    """
    partition_arrows = []
    # Priority = number of in-ports still waiting for a value. pqdict is a
    # min-priority queue, so ready arrows (priority 0) surface first.
    arrow_colors = pqdict()
    for sub_arrow in comp_arrow.get_sub_arrows():
        arrow_colors[sub_arrow] = sub_arrow.num_in_ports()

    def _mark_fed(in_ports):
        """Decrement the pending-input count of each sub-arrow fed via ``in_ports``."""
        for in_port in in_ports:
            # Edges that end on comp_arrow's own boundary (its out-ports, or a
            # direct in->out pass-through) do not feed any sub-arrow.
            if in_port.arrow == comp_arrow:
                continue
            assert in_port.arrow in arrow_colors, "sub_arrow not in arrow_colors"
            arrow_colors[in_port.arrow] -= 1

    for port in comp_arrow.in_ports():
        _mark_fed(comp_arrow.neigh_in_ports(port))

    while arrow_colors:
        # Pop every sub-arrow whose inputs are all available.
        arrow_layer = set()
        while arrow_colors and arrow_colors.topitem()[1] == 0:
            sub_arrow, _ = arrow_colors.popitem()
            arrow_layer.add(sub_arrow)

        if not arrow_layer:
            # Nothing is ready, so no further progress is possible.
            raise ValueError(
                "cannot partition %s: %d sub_arrows never have all in_ports fed"
                % (comp_arrow.name, len(arrow_colors)))

        partition_arrows.append(arrow_layer)

        # Outputs of this layer feed the arrows of later layers.
        for arrow in arrow_layer:
            for out_port in arrow.out_ports():
                _mark_fed(comp_arrow.neigh_in_ports(out_port))

    return partition_arrows
Ejemplo n.º 2
0
def comp(fwd: Arrow, right_inv: Arrow, DiffArrow=SquaredDifference):
    """Compose a forward model with its right inverse and add round-trip errors.

    The returned arrow is built as follows:

    - The in-ports of ``fwd`` become its in-ports.
    - The k-th out-port of ``fwd`` feeds the k-th in-port of ``right_inv``.
    - Every out-port of ``right_inv`` is exposed as an out-port, and error
      ports stay error ports.
    - For each boundary input, a fresh ``DiffArrow()`` compares that input
      (its in-port 0) with the matching non-error output of ``right_inv``
      (its in-port 1). Its out-port 0 is exposed as an extra error out-port
      labelled ``"supervised_error"``.

    Args:
        fwd: X -> Y
        right_inv: Y -> X x Error
        DiffArrow: Callable taking no arguments that returns the comparison arrow.

    Returns:
        A wired CompositeArrow named ``"fwd_to_right_inv"``.
    """
    composite = CompositeArrow(name="fwd_to_right_inv")

    # Left boundary: each input of fwd becomes an input of the composite.
    for fwd_in in fwd.in_ports():
        boundary_in = composite.add_port()
        make_in_port(boundary_in)
        composite.add_edge(boundary_in, fwd_in)
        transfer_labels(fwd_in, boundary_in)

    # Middle: outputs of fwd drive inputs of right_inv positionally.
    for idx, fwd_out in enumerate(fwd.out_ports()):
        composite.add_edge(fwd_out, right_inv.in_port(idx))

    # Right boundary: each output of right_inv becomes an output, error flag kept.
    for inv_out in right_inv.out_ports():
        boundary_out = composite.add_port()
        make_out_port(boundary_out)
        if is_error_port(inv_out):
            make_error_port(boundary_out)
        composite.add_edge(inv_out, boundary_out)
        transfer_labels(inv_out, boundary_out)

    # Supervised error: compare each original input with its reconstruction.
    estimates = [p for p in right_inv.out_ports() if not is_error_port(p)]
    boundary_inputs = composite.in_ports()
    assert len(estimates) == len(boundary_inputs)
    for idx, x_port in enumerate(boundary_inputs):
        diff = DiffArrow()
        composite.add_edge(x_port, diff.in_port(0))
        composite.add_edge(estimates[idx], diff.in_port(1))
        err_port = composite.add_port()
        make_out_port(err_port)
        make_error_port(err_port)
        add_port_label(err_port, "supervised_error")
        composite.add_edge(diff.out_port(0), err_port)

    assert composite.is_wired_correctly()
    return composite
Ejemplo n.º 3
0
def create_arrow(arrow: CompositeArrow, equiv_thetas, port_attr, valid_ports,
                 symbt_ports):
    """Wrap ``arrow`` so that its equivalent parameters share one slim parameter.

    For each in-port of ``arrow`` in ``valid_ports``, the wrapper does not
    expose the port. Instead it gathers the port's value from a flat vector
    of ``nclasses`` entries (one per equivalence class) and reshapes it to
    the port's propagated shape. All other in-ports and all out-ports are
    passed through with their labels and param/error flags.

    Args:
        arrow: Parametric composite arrow to wrap.
        equiv_thetas: Mapping from each symbol to its equivalence-class id.
            The ids are used directly as gather indices (presumably in
            ``[0, nclasses)`` -- TODO confirm).
        port_attr: Port attributes from shape propagation, used to look up
            port shapes.
        valid_ports: In-ports of ``arrow`` to be driven by the slim parameter.
        symbt_ports: Mapping from port to a dict whose ``'symbolic_tensor'``
            entry provides ``.symbols`` for that port.

    Returns:
        A new CompositeArrow named ``"<arrow.name>_elim"``. Its single new
        param port has shape ``(batch_size, nclasses // batch_size)``.
        ``batch_size`` is the leading dimension shared by the replaced ports
        of rank > 1, or 1 if there are none.
    """
    # New parameter space should have nclasses elements
    nclasses = num_unique_elem(equiv_thetas)
    new_arrow = CompositeArrow(name="%s_elim" % arrow.name)
    for out_port in arrow.out_ports():
        c_out_port = new_arrow.add_port()
        make_out_port(c_out_port)
        transfer_labels(out_port, c_out_port)
        if is_error_port(out_port):
            make_error_port(c_out_port)
        new_arrow.add_edge(out_port, c_out_port)

    # The slim parameter (whatever its 2D shape) is flattened to [nclasses]
    # so that it can be gathered from by class id.
    flat_shape = SourceArrow(np.array([nclasses], dtype=np.int32))
    flatten = ReshapeArrow()
    new_arrow.add_edge(flat_shape.out_port(0), flatten.in_port(1))
    slim_param_flat = flatten.out_port(0)
    batch_size = None
    for in_port in arrow.in_ports():
        if in_port in valid_ports:
            symbt = symbt_ports[in_port]['symbolic_tensor']
            # Class id of every symbol of this port, in symbol order.
            indices = [equiv_thetas[theta] for theta in symbt.symbols]
            shape = get_port_shape(in_port, port_attr)
            if len(shape) > 1:
                # All replaced ports of rank > 1 must agree on the leading dim.
                if batch_size is not None:
                    assert shape[0] == batch_size
                batch_size = shape[0]
            # slim_param_flat --gather(indices)--> reshape(shape) --> in_port
            gather = GatherArrow()
            src = SourceArrow(np.array(indices, dtype=np.int32))
            shape_shape = SourceArrow(np.array(shape, dtype=np.int32))
            reshape = ReshapeArrow()
            new_arrow.add_edge(slim_param_flat, gather.in_port(0))
            new_arrow.add_edge(src.out_port(0), gather.in_port(1))
            new_arrow.add_edge(gather.out_port(0), reshape.in_port(0))
            new_arrow.add_edge(shape_shape.out_port(0), reshape.in_port(1))
            new_arrow.add_edge(reshape.out_port(0), in_port)
        else:
            # Ports without equivalence constraints are passed through unchanged.
            new_in_port = new_arrow.add_port()
            make_in_port(new_in_port)
            if is_param_port(in_port):
                make_param_port(new_in_port)
            transfer_labels(in_port, new_in_port)
            new_arrow.add_edge(new_in_port, in_port)

    if batch_size is None:
        # No replaced port has a leading (batch) dimension. Use 1 so that the
        # slim parameter is (1, nclasses), instead of failing on
        # `nclasses % None`.
        batch_size = 1
    assert nclasses % batch_size == 0
    slim_param = new_arrow.add_port()
    make_in_port(slim_param)
    make_param_port(slim_param)
    new_arrow.add_edge(slim_param, flatten.in_port(0))
    set_port_shape(slim_param, (batch_size, nclasses // batch_size))

    assert new_arrow.is_wired_correctly()
    return new_arrow
Ejemplo n.º 4
0
def gen_arrow_inputs(comp_arrow: CompositeArrow, inputs: List, arrow_colors):
    """Route the boundary ``inputs`` of ``comp_arrow`` to the ports they feed.

    Take the i-th boundary in-port of ``comp_arrow``. Every port it is wired
    to (per ``comp_arrow.edges``) receives ``inputs[i]``, stored under that
    port's index. The count of that port's arrow in ``arrow_colors`` is then
    decremented by one.

    Args:
        comp_arrow: Composite arrow whose boundary inputs are being fed.
        inputs: Values for the boundary in-ports, in in-port order.
        arrow_colors: Mapping from arrow to its remaining-input priority.
            It is mutated in place.

    Returns:
        Dict mapping every arrow from ``comp_arrow.get_all_arrows()`` to a
        dict ``{in_port.index: value}`` of the inputs routed to it.
    """
    # Store a map from an arrow to its inputs
    # Use a dict because no guarantee we'll create input tensors in order
    arrow_inputs = dict()  # type: Dict[Arrow, MutableMapping[int, Any]]
    for sub_arrow in comp_arrow.get_all_arrows():
        arrow_inputs[sub_arrow] = dict()

    # The boundary ports do not change inside the loop; fetch them once
    # instead of once per input.
    boundary_in_ports = comp_arrow.in_ports()

    # Decrement priority of every arrow connected to the input
    for i, input_value in enumerate(inputs):
        for in_port in comp_arrow.edges[boundary_in_ports[i]]:
            sub_arrow = in_port.arrow
            arrow_colors[sub_arrow] -= 1
            arrow_inputs[sub_arrow][in_port.index] = input_value

    return arrow_inputs
Ejemplo n.º 5
0
def eliminate(arrow: CompositeArrow):
    """Return a version of ``arrow`` with redundant parameters merged.

    Args:
        arrow: Parametric CompositeArrow whose param ports may carry
            equivalent (redundant) parameters.

    Returns:
        The arrow built by ``create_arrow``: a new parametric arrow with
        fewer parameters.
    """
    # Warning: This is a huge hack

    # Pass 1: propagate to learn each param port's shape, then attach a fresh
    # SymbolicTensor of that shape to every param port.
    port_attr = propagate(arrow)
    symbt_ports = {}
    for param_port in filter(is_param_port, arrow.in_ports()):
        tensor = SymbolicTensor(shape=get_port_shape(param_port, port_attr),
                                name="port%s" % param_port.index,
                                port=param_port)
        symbt_ports[param_port] = {'symbolic_tensor': tensor}

    # Pass 2: propagate again, this time seeded with the symbolic tensors.
    port_attr = propagate(arrow, symbt_ports)

    # HACK: the symbolic tensors found on the ports of each arrow whose name
    # is in dupl_names are treated as equivalent. Only the param ports behind
    # those tensors take part in elimination.
    dupl_to_equiv = {}
    valid_ports = set()
    for dupl in filter_arrows(lambda a: a.name in dupl_names, arrow):
        tensors = [port_attr[p]['symbolic_tensor'] for p in dupl.ports()
                   if 'symbolic_tensor' in port_attr[p]]
        valid_ports.update(t.port for t in tensors)
        dupl_to_equiv[dupl] = tensors

    equiv_thetas = find_equivalent_thetas(dupl_to_equiv)
    return create_arrow(arrow, equiv_thetas, port_attr, valid_ports,
                        symbt_ports)