예제 #1
0
def unparam(arrow: Arrow, nnet: Arrow = None):
    """Drive an arrow's parametric in-ports from a second arrow.

    Builds a composite that exposes only the non-parametric in-ports of
    `arrow`. Every exposed in-port is fanned out to both `arrow` and `nnet`,
    and the out-ports of `nnet` are wired, in order, to the parametric
    in-ports of `arrow`. All out-ports of `arrow` are re-exported, keeping
    their labels and their error-port marking.

    Args:
        arrow: Y x Theta -> X
        nnet: Y -> Theta. When omitted, a TfArrow with one in-port per
            non-parametric input and one out-port per parametric input
            is created.
    Returns:
        Composite arrow Y -> X
    """
    composite = CompositeArrow(name="%s_unparam" % arrow.name)

    # Split the in-ports of `arrow` into ordinary and parametric ones.
    plain_ports = []
    theta_ports = []
    for port in arrow.in_ports():
        if is_param_port(port):
            theta_ports.append(port)
        else:
            plain_ports.append(port)

    if nnet is None:
        nnet = TfArrow(n_in_ports=len(plain_ports),
                       n_out_ports=len(theta_ports))

    # Each ordinary input feeds both the wrapped arrow and the network.
    for idx, inner_in in enumerate(plain_ports):
        outer_in = composite.add_port()
        make_in_port(outer_in)
        transfer_labels(inner_in, outer_in)
        composite.add_edge(outer_in, inner_in)
        composite.add_edge(outer_in, nnet.in_port(idx))

    # The network's outputs stand in for the parametric inputs.
    for idx, theta_port in enumerate(theta_ports):
        composite.add_edge(nnet.out_port(idx), theta_port)

    # Re-export every output of the wrapped arrow.
    for inner_out in arrow.out_ports():
        outer_out = composite.add_port()
        make_out_port(outer_out)
        if is_error_port(inner_out):
            make_error_port(outer_out)
        transfer_labels(inner_out, outer_out)
        composite.add_edge(inner_out, outer_out)

    assert composite.is_wired_correctly()
    return composite
예제 #2
0
    def __init__(self, lmbda: float) -> None:
        """Build the quantile composite and store it as `self.quantile`.

        The composite has one in-port (call it p) and one out-port, wired as
        Source(1.0), p -> Sub -> Log -> Neg -> Div <- Source(lmbda).
        Presumably this is -ln(1 - p) / lmbda, assuming Sub and Div take
        their operands in port order -- TODO confirm.

        Args:
            lmbda: constant fed to the second in-port of the final division
        """
        quantile = CompositeArrow(name="ExponentialRVQuantile")
        connect = quantile.add_edge

        p_port = quantile.add_port()
        make_in_port(p_port)
        result_port = quantile.add_port()
        make_out_port(result_port)

        rate_const = SourceArrow(lmbda)
        unit_const = SourceArrow(1.0)

        # 1.0 and p into the subtraction
        complement = SubArrow()
        connect(unit_const.out_ports()[0], complement.in_ports()[0])
        connect(p_port, complement.in_ports()[1])

        # log of the difference
        log_arrow = LogArrow()
        connect(complement.out_ports()[0], log_arrow.in_ports()[0])

        # negate, then divide by the lmbda constant
        neg_arrow = NegArrow()
        connect(log_arrow.out_ports()[0], neg_arrow.in_ports()[0])
        scaled = DivArrow()
        connect(neg_arrow.out_ports()[0], scaled.in_ports()[0])
        connect(rate_const.out_ports()[0], scaled.in_ports()[1])
        connect(scaled.out_ports()[0], result_port)

        assert quantile.is_wired_correctly()
        self.quantile = quantile
예제 #3
0
    def __init__(self, n_inputs: int) -> None:
        """Wire up sum((x_i - m)^2) / n_inputs as a composite arrow.

        In-port 0 (m, the "mean" in the arrow's name) feeds the second
        in-port of every subtraction; in-ports 1..n_inputs (the x_i) feed
        the first in-port of one subtraction each. Every difference is
        squared, the squares are summed by an AddNArrow, and the sum is
        divided by `n_inputs`, which is emitted as a constant and cast to
        `floatX()`.

        Args:
            n_inputs: number of x_i inputs; the arrow gets n_inputs + 1
                in-ports and a single out-port
        """
        super().__init__(name="VarFromMean")
        comp_arrow = self
        in_ports = [comp_arrow.add_port() for _ in range(n_inputs + 1)]
        for in_port in in_ports:
            make_in_port(in_port)
        out_port = comp_arrow.add_port()
        make_out_port(out_port)

        sub_arrows = [SubArrow() for _ in range(n_inputs)]
        squares = [SquareArrow() for _ in range(n_inputs)]
        addn = AddNArrow(n_inputs)
        for i in range(n_inputs):
            # x_i - m, squared, into the i-th slot of the n-ary sum
            comp_arrow.add_edge(in_ports[0], sub_arrows[i].in_ports()[1])
            comp_arrow.add_edge(in_ports[i + 1], sub_arrows[i].in_ports()[0])
            comp_arrow.add_edge(sub_arrows[i].out_ports()[0],
                                squares[i].in_ports()[0])
            comp_arrow.add_edge(squares[i].out_ports()[0], addn.in_ports()[i])

        # Divide the summed squares by n_inputs (cast to the float type).
        nn = SourceArrow(n_inputs)
        cast = CastArrow(floatX())
        variance = DivArrow()
        comp_arrow.add_edge(nn.out_ports()[0], cast.in_ports()[0])
        comp_arrow.add_edge(addn.out_ports()[0], variance.in_ports()[0])
        comp_arrow.add_edge(cast.out_ports()[0], variance.in_ports()[1])

        comp_arrow.add_edge(variance.out_ports()[0], out_port)
        # The check must be *called*: a bare bound method is always truthy,
        # so without the parentheses this assertion could never fail.
        assert comp_arrow.is_wired_correctly()
예제 #4
0
def SumNArrow(ninputs: int):
    """Create the arrow f(x1, ..., xn) = sum(x1, ..., xn).

    The sum is built as a chain of binary AddArrows: the running total
    enters in-port 0 of each adder and a freshly created in-port of the
    composite enters in-port 1.

    Args:
        ninputs: number of inputs; must be greater than 1
    Returns:
        Composite arrow with `ninputs` in-ports and one out-port
    """
    assert ninputs > 1
    composite = CompositeArrow(name="SumNArrow")

    # The first input seeds the running total.
    running = composite.add_port()
    make_in_port(running)

    for _ in range(1, ninputs):
        adder = AddArrow()
        composite.add_edge(running, adder.in_port(0))
        operand = composite.add_port()
        make_in_port(operand)
        composite.add_edge(operand, adder.in_port(1))
        running = adder.out_port(0)

    # The last adder in the chain carries the full sum.
    total_port = composite.add_port()
    make_out_port(total_port)
    composite.add_edge(adder.out_port(0), total_port)

    assert composite.is_wired_correctly()
    assert composite.num_in_ports() == ninputs
    return composite
예제 #5
0
def gen_data(lmbda: float, size: int):
    """Build a forward arrow, its parametric inverse, and checked samples.

    The forward arrow wires two in-ports through SubArrow then AbsArrow to
    one out-port. The inverse arrow takes one ordinary in-port plus two
    parametric in-ports, wired through InvAbsArrow then InvSubArrow to two
    out-ports. Each sample x is a pair drawn from ExponentialRV(lmbda);
    y = |x[0] - x[1]| and theta = (sign(x[0] - x[1]), x[1]). Every sample is
    run through both arrows and asserted to round-trip within 1e-5.

    Args:
        lmbda: parameter passed to ExponentialRV
        size: number of samples to generate
    Returns:
        (data, forward_arrow, inverse_arrow), where data is a list of
        (x, theta, y) tuples
    """
    # ---- forward arrow ----
    fwd = CompositeArrow(name="forward")
    fwd_in_a = fwd.add_port()
    make_in_port(fwd_in_a)
    fwd_in_b = fwd.add_port()
    make_in_port(fwd_in_b)
    fwd_out = fwd.add_port()
    make_out_port(fwd_out)

    diff_arrow = SubArrow()
    abs_arrow = AbsArrow()
    fwd.add_edge(fwd_in_a, diff_arrow.in_ports()[0])
    fwd.add_edge(fwd_in_b, diff_arrow.in_ports()[1])
    fwd.add_edge(diff_arrow.out_ports()[0], abs_arrow.in_ports()[0])
    fwd.add_edge(abs_arrow.out_ports()[0], fwd_out)
    assert fwd.is_wired_correctly()

    # ---- inverse arrow ----
    inv = CompositeArrow(name="inverse")
    inv_in = inv.add_port()
    make_in_port(inv_in)
    theta_a = inv.add_port()
    make_in_port(theta_a)
    make_param_port(theta_a)
    theta_b = inv.add_port()
    make_in_port(theta_b)
    make_param_port(theta_b)
    inv_out_a = inv.add_port()
    make_out_port(inv_out_a)
    inv_out_b = inv.add_port()
    make_out_port(inv_out_b)

    undo_sub = InvSubArrow()
    undo_abs = InvAbsArrow()
    inv.add_edge(inv_in, undo_abs.in_ports()[0])
    inv.add_edge(theta_a, undo_abs.in_ports()[1])
    inv.add_edge(undo_abs.out_ports()[0], undo_sub.in_ports()[0])
    inv.add_edge(theta_b, undo_sub.in_ports()[1])
    inv.add_edge(undo_sub.out_ports()[0], inv_out_a)
    inv.add_edge(undo_sub.out_ports()[1], inv_out_b)
    assert inv.is_wired_correctly()

    # ---- samples, each verified against both arrows ----
    tolerance = 1e-5
    samples = []
    rv = ExponentialRV(lmbda)
    for _ in range(size):
        x = tuple(rv.sample(shape=[2]))
        delta = x[0] - x[1]
        y = np.abs(delta)
        theta = (np.sign(delta), x[1])
        y_check = apply(fwd, list(x))
        x_check = apply(inv, [y, theta[0], theta[1]])
        assert np.abs(y_check[0] - y) < tolerance
        assert np.abs(x_check[0] - x[0]) < tolerance
        assert np.abs(x_check[1] - x[1]) < tolerance
        samples.append((x, theta, y))
    return samples, fwd, inv
예제 #6
0
 def __init__(self, n_in_ports: int, n_out_ports: int, name: str) -> None:
     """Create the arrow's ports: in-ports first, then out-ports.

     Args:
         n_in_ports: number of in-ports, at indices [0, n_in_ports)
         n_out_ports: number of out-ports, at the remaining indices
         name: forwarded to the parent constructor
     """
     super().__init__(name=name)
     total = n_in_ports + n_out_ports
     self.n_in_ports = n_in_ports
     self.n_out_ports = n_out_ports
     self._ports = [Port(self, idx) for idx in range(total)]
     # One separate attribute dict per port.
     self.port_attr = [{} for _ in range(total)]

     in_indices = range(n_in_ports)
     out_indices = range(n_in_ports, total)
     for idx in in_indices:
         make_in_port(self._ports[idx])
     for idx in out_indices:
         make_out_port(self._ports[idx])
예제 #7
0
    def __init__(self) -> None:
        """Wire the TriangleWave composite: three in-ports, one out-port.

        Naming the in-ports a, l, u in creation order, and assuming the
        binary arrows take their operands in port order and IfArrow takes
        (condition, then, else) -- TODO confirm -- the wiring is:

            period = (u - l) * 2
            t = (a - (a // period) * period) + l
            out = (u * 2 - t) if t > u else t
        """
        super().__init__(name="TriangleWave")
        wire = self.add_edge

        a_port = self.add_port()
        make_in_port(a_port)
        lo_port = self.add_port()
        make_in_port(lo_port)
        hi_port = self.add_port()
        make_in_port(hi_port)
        result_port = self.add_port()
        make_out_port(result_port)

        # period = (u - l) * 2
        const_two = SourceArrow(2.0)
        span = SubArrow()
        wire(hi_port, span.in_ports()[0])
        wire(lo_port, span.in_ports()[1])
        period = MulArrow()
        wire(span.out_ports()[0], period.in_ports()[0])
        wire(const_two.out_ports()[0], period.in_ports()[1])

        # t = (a - (a // period) * period) + l
        n_periods = FloorDivArrow()
        wire(a_port, n_periods.in_ports()[0])
        wire(period.out_ports()[0], n_periods.in_ports()[1])
        whole_part = MulArrow()
        wire(n_periods.out_ports()[0], whole_part.in_ports()[0])
        wire(period.out_ports()[0], whole_part.in_ports()[1])
        remainder = SubArrow()
        wire(a_port, remainder.in_ports()[0])
        wire(whole_part.out_ports()[0], remainder.in_ports()[1])
        shifted = AddArrow()
        wire(remainder.out_ports()[0], shifted.in_ports()[0])
        wire(lo_port, shifted.in_ports()[1])

        # condition t > u and the alternative value u * 2 - t
        above_hi = GreaterArrow()
        wire(shifted.out_ports()[0], above_hi.in_ports()[0])
        wire(hi_port, above_hi.in_ports()[1])
        double_hi = MulArrow()
        wire(hi_port, double_hi.in_ports()[0])
        wire(const_two.out_ports()[0], double_hi.in_ports()[1])
        reflected = SubArrow()
        wire(double_hi.out_ports()[0], reflected.in_ports()[0])
        wire(shifted.out_ports()[0], reflected.in_ports()[1])

        # select between the two candidates
        chooser = IfArrow()
        wire(above_hi.out_ports()[0], chooser.in_ports()[0])
        wire(reflected.out_ports()[0], chooser.in_ports()[1])
        wire(shifted.out_ports()[0], chooser.in_ports()[2])
        wire(chooser.out_ports()[0], result_port)
        assert self.is_wired_correctly()
예제 #8
0
def graph_to_arrow(output_tensors: Sequence[Tensor],
                   name: str,
                   input_tensors: Sequence[Tensor] = None) -> Arrow:
    """Convert a tensorflow graph into an arrow.

    Walks the graph backwards from `output_tensors`, mapping each op to an
    arrow via `arrow_from_op` and adding an edge for every
    producer -> consumer tensor connection encountered.

    Assume inputs are 'Placeholder' tensors
    Args:
        output_tensors: Tensors designated as outputs. Any sequence is
                        accepted (list or tuple); it is not modified.
        name: Name of the composite arrow
        input_tensors: Tensors designated as inputs.  If not given then
                       we assume any placeholder tensors connected (indirectly)
                       to the outputs are input tensors
    Returns:
        A 'CompositeArrow' equivalent to graph which computes 'output_tensors'
    """
    op_to_arrow = dict()
    seen_tensors = set()
    comp_arrow = CompositeArrow(name=name)

    # If in_ports are given don't dynamically find them
    # FIXME: Should this really be optional?
    given_in_ports = input_tensors is not None
    if given_in_ports:
        # Make an in_port for every input tensor
        tensor_to_in_port = dict()
        for tensor in input_tensors:
            in_port = comp_arrow.add_port()
            make_in_port(in_port)
            set_port_shape(in_port,
                           const_to_tuple(tensor.get_shape().as_list()))
            tensor_to_in_port[tensor] = in_port

    # Make an out_port for every output tensor
    for tensor in output_tensors:
        out_port = comp_arrow.add_port()
        make_out_port(out_port)
        arrow = arrow_from_op(tensor.op, op_to_arrow)
        left = arrow.out_ports()[tensor.value_index]
        comp_arrow.add_edge(left, out_port)

    # Starting from outputs.  Copy into a real list: the work stack is
    # popped below, and slicing a tuple would give a tuple with no `pop`.
    to_see_tensors = list(output_tensors)
    while to_see_tensors:
        tensor = to_see_tensors.pop()
        seen_tensors.add(tensor)
        if is_input_tensor(tensor):
            if given_in_ports:
                left_port = tensor_to_in_port[tensor]
            else:
                left_port = comp_arrow.add_port()
                make_in_port(left_port)
                # FIXME: We are only taking shapes from placeholder inputs
                # is this sufficient?
                set_port_shape(left_port,
                               const_to_tuple(tensor.get_shape().as_list()))
        else:
            out_port_id = tensor.value_index
            left_arrow = arrow_from_op(tensor.op, op_to_arrow)
            left_port = left_arrow.out_ports()[out_port_id]
            # Presumably queues the not-yet-seen inputs of this op on the
            # work stack -- helper is defined elsewhere, TODO confirm.
            update_seen(tensor.op, seen_tensors, to_see_tensors)

        # Connect this tensor to every consumer input slot it occupies.
        for rec_op in tensor.consumers():
            for i, input_tensor in enumerate(rec_op.inputs):
                # NOTE(review): relies on Tensor `==` yielding a plain
                # truth value -- confirm for the TensorFlow version in use.
                if tensor == input_tensor:
                    in_port_id = i
                    right_arrow = arrow_from_op(rec_op, op_to_arrow)
                    comp_arrow.add_edge(left_port,
                                        right_arrow.in_ports()[in_port_id])

    assert comp_arrow.is_wired_correctly()
    return comp_arrow