Beispiel #1
0
def SumNArrow(ninputs: int):
    """Create arrow f(x1, ..., xn) = sum(x1, ..., xn).

    The sum is built as a chain of ``ninputs - 1`` binary ``AddArrow``s.
    The running partial sum feeds in-port 0 of each adder, and a fresh
    in-port of the composite feeds in-port 1.

    Args:
        ninputs: number of inputs; must be at least 2
    Returns:
        CompositeArrow named "SumNArrow" with ``ninputs`` in-ports and one
        out-port
    Raises:
        AssertionError: if ``ninputs < 2`` or a post-construction wiring
            check fails
    """
    assert ninputs > 1, "SumNArrow needs at least two inputs"
    c = CompositeArrow(name="SumNArrow")
    # ``running`` is the port currently carrying the partial sum. It starts as
    # the composite's first in-port, then becomes each successive adder's output.
    running = c.add_port()
    make_in_port(running)

    for _ in range(ninputs - 1):
        add = AddArrow()
        c.add_edge(running, add.in_port(0))
        operand = c.add_port()
        make_in_port(operand)
        c.add_edge(operand, add.in_port(1))
        running = add.out_port(0)

    out_port = c.add_port()
    make_out_port(out_port)
    # Wire the tracked partial sum rather than the leaked loop variable ``add``,
    # so this edge does not depend on the loop having executed.
    c.add_edge(running, out_port)

    assert c.is_wired_correctly()
    assert c.num_in_ports() == ninputs
    return c
Beispiel #2
0
def conv(a: CompositeArrow, args: TensorVarList, state) -> Sequence[Tensor]:
    """Convert a composite arrow by interpreting its internal wiring.

    Args:
        a: the composite arrow to convert
        args: input values; their count must equal ``a.num_in_ports()``
        state: shared conversion state; must contain a ``'port_grab'`` entry,
            which is also forwarded to ``interpret`` as its own argument
    Returns:
        Whatever ``interpret`` returns for ``a`` (annotated as a sequence of
        tensors).
    """
    assert len(args) == a.num_in_ports()
    # Everything built while interpreting ``a`` is grouped under its name.
    with tf.name_scope(a.name):
        # FIXME: A horrible horrible hack -- the port_grab table travels inside
        # ``state`` and is handed to ``interpret`` as a separate argument too.
        return interpret(conv, a, args, state, state['port_grab'])
Beispiel #3
0
def inner_interpret(conv: Callable, comp_arrow: CompositeArrow, inputs: List,
                    arrow_colors: MutableMapping[Arrow, int],
                    arrow_inputs: MutableMapping[Arrow, Dict], state: Dict,
                    port_grab: Dict[Port, Any]) -> List:
    """Evaluate the sub-arrows of ``comp_arrow`` in dependency order.

    Sub-arrows are popped from ``arrow_colors`` and converted with ``conv``.
    Each one's outputs are then pushed to every in-port they are wired to.
    (The previous docstring described this as converting ``comp_arrow`` to a
    TensorFlow graph. The body itself only depends on what ``conv`` does.)

    Args:
        conv: called as ``conv(sub_arrow, inputs, state)``; must return one
            output per out-port of ``sub_arrow``.
        comp_arrow: the composite arrow whose internal wiring is walked.
        inputs: only used to check that its length equals
            ``comp_arrow.num_in_ports()``. NOTE(review): the values are not
            read here. Presumably the caller has already seeded them into
            ``arrow_inputs``; confirm against the caller.
        arrow_colors: maps each arrow to its number of still-unresolved
            inputs, and is consumed (emptied) by this function. Assumes
            ``popitem`` returns the entry with the lowest count first
            (e.g. a priority queue) -- TODO confirm. A plain dict's
            ``popitem`` is LIFO and would ignore the counts.
        arrow_inputs: maps each arrow to ``{in-port index: value}``. Mutated:
            every computed output is stored under the arrow/port it feeds.
        state: passed unchanged to every ``conv`` call.
        port_grab: mutated in place. Each key port whose value was recorded
            in ``arrow_inputs`` is assigned that value.

    Returns:
        The values of ``arrow_inputs[comp_arrow]``, ordered by port index.

    Raises:
        AssertionError: on a wrong number of ``inputs``, on popping a
            sub-arrow that still has unresolved inputs, or when ``conv``
            returns the wrong number of outputs.
    """
    assert len(inputs) == comp_arrow.num_in_ports(), "wrong # inputs"

    # NOTE(review): emit_list is never used below.
    emit_list = []
    while len(arrow_colors) > 0:
        sub_arrow, priority = arrow_colors.popitem()
        # comp_arrow itself is not converted; popping it only removes it.
        if sub_arrow is not comp_arrow:
            assert priority == 0, "Must resolve {} more inputs to {} first".format(
                priority, sub_arrow)
            # Gather this sub-arrow's inputs ordered by in-port index. This
            # rebinds the ``inputs`` parameter, which is not needed after the
            # length check above.
            inputs = [
                arrow_inputs[sub_arrow][i]
                for i in sorted(arrow_inputs[sub_arrow].keys())
            ]
            outputs = conv(sub_arrow, inputs, state)

            assert len(outputs) == len(
                sub_arrow.out_ports()), "diff num outputs"

            # Propagate each output to every in-port it is wired to: record
            # the value and decrement the consuming arrow's unresolved-input
            # count. NOTE(review): edges into comp_arrow itself are NOT
            # skipped. Presumably this is how values reach
            # arrow_inputs[comp_arrow], which is read as the result below.
            for i, out_port in enumerate(sub_arrow.out_ports()):
                neigh_in_ports = comp_arrow.neigh_in_ports(out_port)
                for neigh_in_port in neigh_in_ports:
                    neigh_arrow = neigh_in_port.arrow
                    arrow_colors[neigh_arrow] = arrow_colors[neigh_arrow] - 1
                    arrow_inputs[neigh_arrow][neigh_in_port.index] = outputs[i]

    # Copy any recorded value for each requested port into port_grab
    # (the original author called this a hack).
    for port in port_grab:
        if port.arrow in arrow_inputs:
            if port.index in arrow_inputs[port.arrow]:
                port_grab[port] = arrow_inputs[port.arrow][port.index]

    # Result: values recorded for comp_arrow, ordered by port index.
    outputs_dict = arrow_inputs[comp_arrow]
    out_port_indices = sorted(list(outputs_dict.keys()))
    return [outputs_dict[i] for i in out_port_indices]
Beispiel #4
0
def reparam(comp_arrow: CompositeArrow,
            phi_shape: Tuple,
            nn_takes_input=True):
    """Reparameterize an arrow so that all parametric inputs are a function of phi.

    Builds a new composite arrow around ``comp_arrow``:
    - It has one new parametric in-port ``phi`` with shape ``phi_shape``.
    - ``phi`` feeds a ``TfArrow`` whose outputs drive every param port of
      ``comp_arrow``.
    - Every non-param port of ``comp_arrow`` is re-exposed on the wrapper,
      keeping its in/out/error status and its labels.

    Args:
        comp_arrow: Arrow to reparameterize
        phi_shape: Shape of parameter input
        nn_takes_input: if True, the TfArrow also receives every
            non-parametric in-port of ``comp_arrow`` (after phi); otherwise
            it receives phi only.
    Returns:
        The wrapping CompositeArrow, named ``"<comp_arrow.name>_reparam"``.
    """
    # Named ``wrapper`` (not ``reparam``) so the function's own name is not
    # shadowed inside its body.
    wrapper = CompositeArrow(name="%s_reparam" % comp_arrow.name)

    # phi: the single new parametric input of the wrapper.
    phi = wrapper.add_port()
    set_port_shape(phi, phi_shape)
    make_in_port(phi)
    make_param_port(phi)

    # The network always sees phi; optionally also every non-param input.
    n_nn_inputs = 1
    if nn_takes_input:
        n_nn_inputs += comp_arrow.num_in_ports() - comp_arrow.num_param_ports()
    nn = TfArrow(n_in_ports=n_nn_inputs, n_out_ports=comp_arrow.num_param_ports())
    wrapper.add_edge(phi, nn.in_port(0))

    next_nn_out = 0  # next nn output to wire into a param port of comp_arrow
    next_nn_in = 1   # next nn input to receive a forwarded in-port (0 is phi)
    for port in comp_arrow.ports():
        if is_param_port(port):
            wrapper.add_edge(nn.out_port(next_nn_out), port)
            next_nn_out += 1
            continue
        # Re-expose this non-param port on the wrapper, preserving its kind
        # and labels.
        re_port = wrapper.add_port()
        if is_out_port(port):
            make_out_port(re_port)
            wrapper.add_edge(port, re_port)
        if is_in_port(port):
            make_in_port(re_port)
            wrapper.add_edge(re_port, port)
            if nn_takes_input:
                wrapper.add_edge(re_port, nn.in_port(next_nn_in))
                next_nn_in += 1
        if is_error_port(port):
            make_error_port(re_port)
        for label in get_port_labels(port):
            add_port_label(re_port, label)

    assert wrapper.is_wired_correctly()
    return wrapper