예제 #1
0
def partition(comp_arrow: CompositeArrow) -> List[Set[Arrow]]:
    """Partition the sub_arrows of comp_arrow into sequential layers.

    An arrow is placed in a layer once every edge into it comes either from
    an in_port of comp_arrow or from an out_port of an arrow in an earlier
    layer, so the layers can be evaluated in order.

    Args:
        comp_arrow: Composite arrow whose sub_arrows are partitioned
    Returns:
        List of layers, each a set of sub_arrows
    Raises:
        ValueError: if some sub_arrows can never become ready (e.g. a cycle,
            or an in_port that no edge feeds)
    """
    # Min-priority queue: sub_arrow -> number of inputs not yet resolved
    arrow_colors = pqdict()
    for sub_arrow in comp_arrow.get_sub_arrows():
        arrow_colors[sub_arrow] = sub_arrow.num_in_ports()

    def resolve(port):
        """Decrement the pending-input count of every sub_arrow fed by `port`."""
        for in_port in comp_arrow.neigh_in_ports(port):
            # Edges that lead straight to comp_arrow's own out_ports do not
            # feed a sub_arrow, so they are skipped.
            if in_port.arrow != comp_arrow:
                assert in_port.arrow in arrow_colors, "sub_arrow not in arrow_colors"
                arrow_colors[in_port.arrow] -= 1

    for port in comp_arrow.in_ports():
        resolve(port)

    partition_arrows = []
    while len(arrow_colors) > 0:
        _, top_priority = arrow_colors.topitem()
        if top_priority != 0:
            # Nothing is ready: without this check the loop would append
            # empty layers forever.
            raise ValueError(
                "Cannot partition %s: no remaining sub_arrow has all of its "
                "inputs resolved (cycle or unconnected in_port?)"
                % comp_arrow.name)

        arrow_layer = set()
        while len(arrow_colors) > 0 and arrow_colors.topitem()[1] == 0:
            ready_arrow, _ = arrow_colors.popitem()
            arrow_layer.add(ready_arrow)
        partition_arrows.append(arrow_layer)

        # Outputs of this layer resolve inputs of later layers
        for arrow in arrow_layer:
            for out_port in arrow.out_ports():
                resolve(out_port)

    return partition_arrows
예제 #2
0
파일: invert.py 프로젝트: llwu/reverseflow
def invert(comp_arrow: CompositeArrow,
           dispatch: Dict[Arrow, Callable] = default_dispatch) -> Arrow:
    """Build a parametric inverse of `comp_arrow`.

    Note: `comp_arrow` itself is modified in place (`duplify` is called on
    it) before its port attributes are propagated.

    Args:
        comp_arrow: Arrow to invert
        dispatch: Dict mapping arrow class to its invert function
    Returns:
        An (approximate) parametric inverse of `comp_arrow`"""
    # Replace multi-edges with dupl arrows, then propagate port attributes
    comp_arrow.duplify()
    attrs = propagate(comp_arrow)
    # inner_invert returns (inverse_arrow, port_map); only the arrow is needed
    inverse = inner_invert(comp_arrow, attrs, dispatch)[0]
    return inverse
예제 #3
0
def conv(a: CompositeArrow, args: TensorVarList, state) -> Sequence[Tensor]:
    """Convert composite arrow `a` by interpreting its sub-arrows.

    The conversion runs inside a TensorFlow name scope named after `a`, and
    `conv` itself is handed to `interpret` as the converter.

    Args:
        a: Composite arrow to convert
        args: One value per in_port of `a`
        state: Conversion state; must contain a 'port_grab' entry
    Returns:
        Whatever `interpret` returns; presumably one tensor per out_port of
        `a` (TODO confirm)
    """
    n_args = len(args)
    assert n_args == a.num_in_ports()
    with tf.name_scope(a.name):
        # FIXME: A horrible horrible hack -- port_grab is smuggled via state
        grabbed_ports = state['port_grab']
        return interpret(conv, a, args, state, grabbed_ports)
예제 #4
0
def gen_arrow_inputs(comp_arrow: CompositeArrow, inputs: List, arrow_colors):
    """Seed per-arrow input maps with the external inputs of comp_arrow.

    For each port listed in ``comp_arrow.edges[...]`` for the ith in_port of
    comp_arrow, the owning arrow receives ``inputs[i]`` at that port's index,
    and its priority in `arrow_colors` is decremented by one (in place).

    Args:
        comp_arrow: Composite arrow being interpreted
        inputs: values, the ith of which feeds the ith in_port of comp_arrow
        arrow_colors: arrow -> number of inputs still unresolved (mutated)
    Returns:
        dict mapping each arrow from ``comp_arrow.get_all_arrows()`` to a dict
        {in_port index: value}
    """
    # Keyed by port index because inputs need not arrive in index order
    arrow_inputs = {arrow: {} for arrow in comp_arrow.get_all_arrows()}  # type: Dict[Arrow, MutableMapping[int, Any]]

    for idx, value in enumerate(inputs):
        boundary_port = comp_arrow.in_ports()[idx]
        for in_port in comp_arrow.edges[boundary_port]:
            receiver = in_port.arrow
            arrow_colors[receiver] -= 1
            arrow_inputs[receiver][in_port.index] = value

    return arrow_inputs
예제 #5
0
def filter_arrows(fil, arrow: CompositeArrow, deep=True):
    """Collect the sub_arrows of `arrow` for which predicate `fil` holds.

    Args:
        fil: predicate called on each sub_arrow
        arrow: composite arrow to search
        deep: if True, also search inside nested CompositeArrows (recursively)
    Returns:
        set of matching arrows
    """
    matches = set()
    for child in arrow.get_sub_arrows():
        if fil(child):
            matches.add(child)
        if deep and isinstance(child, CompositeArrow):
            matches.update(filter_arrows(fil, child, deep=deep))
    return matches
예제 #6
0
def gen_arrow_colors(comp_arrow: CompositeArrow):
    """
    Build the priority queue that schedules interpretation of comp_arrow.
    Args:
        comp_arrow: Composite Arrow
    Returns:
        arrow_colors: pqdict mapping each sub_arrow to its number of
        in_ports, and comp_arrow itself to its number of out_ports
    """
    # Each priority is the number of inputs of that arrow not yet seen.
    # Interpretation decrements it as values arrive (from the composition's
    # inputs or from arrows already converted); 0 means ready.
    arrow_colors = pqdict()  # type: MutableMapping[Arrow, int]
    children = comp_arrow.get_sub_arrows()
    for child in children:
        arrow_colors[child] = child.num_in_ports()

    # TODO: Unify -- the composition is tracked through its out_ports
    arrow_colors[comp_arrow] = comp_arrow.num_out_ports()
    return arrow_colors
예제 #7
0
def inner_interpret(conv: Callable, comp_arrow: CompositeArrow, inputs: List,
                    arrow_colors: MutableMapping[Arrow, int],
                    arrow_inputs: MutableMapping[Arrow, MutableMapping[int, Any]],
                    state: Dict, port_grab: Dict[Port, Any]):
    """Interpret comp_arrow by converting its sub_arrows in dependency order.

    Repeatedly pops the arrow with the lowest remaining-input count from
    `arrow_colors`, converts it with `conv`, and forwards each output to the
    in_ports it feeds.

    Args:
        conv: converter called as ``conv(sub_arrow, inputs, state)``; must
            return one value per out_port of sub_arrow
        comp_arrow: composite arrow being interpreted
        inputs: one value per in_port of comp_arrow. Only its length is checked
            here; presumably the values were already seeded into
            `arrow_inputs` (e.g. by gen_arrow_inputs) -- TODO confirm
        arrow_colors: priority queue arrow -> number of unresolved inputs;
            emptied by this call
        arrow_inputs: arrow -> {in_port index: value}; filled in place
        state: passed through to `conv`
        port_grab: ports whose values should be captured; each entry is
            overwritten in place with the value received at that port, if any
    Returns:
        values received by comp_arrow's out_ports, ordered by port index
    """
    assert len(inputs) == comp_arrow.num_in_ports(), "wrong # inputs"

    while len(arrow_colors) > 0:
        sub_arrow, priority = arrow_colors.popitem()
        if sub_arrow is comp_arrow:
            # The composition itself only collects outputs; nothing to convert
            continue
        assert priority == 0, "Must resolve {} more inputs to {} first".format(
            priority, sub_arrow)
        received = arrow_inputs[sub_arrow]
        sub_inputs = [received[i] for i in sorted(received)]
        outputs = conv(sub_arrow, sub_inputs, state)

        assert len(outputs) == len(sub_arrow.out_ports()), "diff num outputs"

        # Decrement the priority of every arrow fed by this one (including
        # comp_arrow itself, whose out_ports collect the final results)
        for i, out_port in enumerate(sub_arrow.out_ports()):
            for neigh_in_port in comp_arrow.neigh_in_ports(out_port):
                neigh_arrow = neigh_in_port.arrow
                arrow_colors[neigh_arrow] = arrow_colors[neigh_arrow] - 1
                arrow_inputs[neigh_arrow][neigh_in_port.index] = outputs[i]

    # Extract some port, kind of a hack
    for port in port_grab:
        if port.arrow in arrow_inputs:
            if port.index in arrow_inputs[port.arrow]:
                port_grab[port] = arrow_inputs[port.arrow][port.index]

    collected = arrow_inputs[comp_arrow]
    return [collected[i] for i in sorted(collected)]
예제 #8
0
def propagate(comp_arrow: CompositeArrow,
              port_attr: PortAttributes=None,
              state=None,
              already_prop=None,
              only_prop=None) -> PortAttributes:
    """
    Propagate port attributes around a composite arrow, inferring unknown
    attributes from known ones.

    Starts from `port_attr` plus the attributes stored on the ports
    themselves (extract_port_attr). It then repeatedly takes a pending
    sub_arrow and applies each of its dispatch functions
    (``sub_arrow.get_dispatches()``) whose predicate holds, pushing the results
    to neighbouring ports via ``update_neigh``. It stops when no arrow is pending.
    Args:
        comp_arrow: Composite Arrow to propagate through
        port_attr: port -> {attr_key: attr_value} map of initial attributes;
          it is copied, so the caller's map is not mutated
        state: currently unused by this function
        already_prop: set of (sub_arrow, dispatch) pairs that have already
          been applied and are not applied again; updated in place
        only_prop: forwarded unchanged to update_neigh (presumably restricts
          which attributes are propagated -- TODO confirm)
    Returns:
        port -> {attr_key: attr_value} map (a defaultdict of dicts) for the
        ports reached during propagation
    """
    already_prop = set() if already_prop is None else already_prop
    print("Propagating")
    # Copy port_attr to avoid affecting input
    port_attr = {} if port_attr is None else port_attr
    _port_attr = defaultdict(dict)
    for port, attr in port_attr.items():
        # Only touch _port_attr[port] when there is something to copy, so
        # ports with empty attribute dicts do not gain an entry
        for attr_key, attr_value in attr.items():
            _port_attr[port][attr_key] = attr_value

    # update port_attr with values stored on port
    extract_port_attr(comp_arrow, _port_attr)

    # Every (nested) sub_arrow starts out pending
    updated = set(comp_arrow.get_sub_arrows_nested())
    update_neigh(_port_attr, _port_attr, comp_arrow, comp_arrow, only_prop, updated)
    while len(updated) > 0:
        sub_arrow = updated.pop()
        sub_port_attr = {port: _port_attr[port]
                         for port in sub_arrow.ports()
                         if port in _port_attr}

        pred_dispatches = sub_arrow.get_dispatches()
        for pred, dispatch in pred_dispatches.items():
            # Each (sub_arrow, dispatch) pair is applied at most once
            if pred(sub_arrow, sub_port_attr) and (sub_arrow, dispatch) not in already_prop:
                new_sub_port_attr = dispatch(sub_arrow, sub_port_attr)
                update_neigh(new_sub_port_attr, _port_attr, sub_arrow.parent, comp_arrow, only_prop, updated)
                already_prop.add((sub_arrow, dispatch))
        if isinstance(sub_arrow, CompositeArrow):
            # Push attributes across the boundary of a nested composite,
            # presumably both into its interior and out to its parent
            update_neigh(sub_port_attr, _port_attr, sub_arrow, comp_arrow, only_prop, updated)
            update_neigh(sub_port_attr, _port_attr, sub_arrow.parent, comp_arrow, only_prop, updated)
    print("Done Propagating")
    return _port_attr
예제 #9
0
def eliminate(arrow: CompositeArrow):
    """Eliminate redundant parameters of a parametric arrow.

    The original author describes this as a huge hack. Each param in_port
    gets a SymbolicTensor and propagation is re-run. Symbolic tensors that
    appear on the ports of dupl arrows are then grouped as candidate
    equivalences.
    Args:
        arrow: Parametric Arrow prime for eliminate!
    Returns:
        New Parametric Arrow with fewer parameters"""
    # Pass 1: learn the shape of every param port
    attrs = propagate(arrow)
    symbt_ports = {}
    for in_port in arrow.in_ports():
        if not is_param_port(in_port):
            continue
        port_shape = get_port_shape(in_port, attrs)
        # One symbolic tensor per param port
        symbt_ports[in_port] = {
            'symbolic_tensor': SymbolicTensor(shape=port_shape,
                                              name="port%s" % in_port.index,
                                              port=in_port)
        }

    # Pass 2: propagate the symbolic tensors
    attrs = propagate(arrow, symbt_ports)

    # Symbolic tensors meeting on a dupl's ports should be equivalent
    dupls = filter_arrows(lambda a: a.name in dupl_names, arrow)
    dupl_to_equiv = {}
    # Param ports whose symbolic tensor reached at least one dupl
    valid_ports = set()
    for dupl in dupls:
        tensors = []
        for dupl_port in dupl.ports():
            if 'symbolic_tensor' in attrs[dupl_port]:
                tensor = attrs[dupl_port]['symbolic_tensor']
                valid_ports.add(tensor.port)
                tensors.append(tensor)
        dupl_to_equiv[dupl] = tensors

    equiv_thetas = find_equivalent_thetas(dupl_to_equiv)
    return create_arrow(arrow, equiv_thetas, attrs, valid_ports,
                        symbt_ports)
예제 #10
0
def gen_data(lmbda: float, size: int):
    """Generate samples for the |x0 - x1| example plus hand-built arrows.

    Forward arrow "forward": (x0, x1) -> |x0 - x1|.
    Inverse arrow "inverse": (y, theta0, theta1) -> (x0, x1), where theta0
    and theta1 are parameter ports.
    Each sample x (two values) is drawn from ExponentialRV(lmbda). Both arrows
    are checked by assert, within 1e-5, to reproduce y and x respectively.

    Args:
        lmbda: parameter of ExponentialRV (presumably the rate -- TODO confirm)
        size: number of samples
    Returns:
        (data, forward_arrow, inverse_arrow); data is a list of
        (x, theta, y) tuples with theta = (sign(x0 - x1), x1)
    """
    def new_port(owner, *markers):
        # Add a port to `owner` and apply each marker function in order
        port = owner.add_port()
        for mark in markers:
            mark(port)
        return port

    # Forward: subtract, then absolute value
    fwd = CompositeArrow(name="forward")
    x0_in = new_port(fwd, make_in_port)
    x1_in = new_port(fwd, make_in_port)
    y_out = new_port(fwd, make_out_port)

    sub = SubArrow()
    absolute = AbsArrow()
    fwd.add_edge(x0_in, sub.in_ports()[0])
    fwd.add_edge(x1_in, sub.in_ports()[1])
    fwd.add_edge(sub.out_ports()[0], absolute.in_ports()[0])
    fwd.add_edge(absolute.out_ports()[0], y_out)
    assert fwd.is_wired_correctly()

    # Inverse: undo abs with the sign parameter, then undo the subtraction
    inv = CompositeArrow(name="inverse")
    y_in = new_port(inv, make_in_port)
    theta0_in = new_port(inv, make_in_port, make_param_port)
    theta1_in = new_port(inv, make_in_port, make_param_port)
    x0_out = new_port(inv, make_out_port)
    x1_out = new_port(inv, make_out_port)

    inv_sub = InvSubArrow()
    inv_abs = InvAbsArrow()
    inv.add_edge(y_in, inv_abs.in_ports()[0])
    inv.add_edge(theta0_in, inv_abs.in_ports()[1])
    inv.add_edge(inv_abs.out_ports()[0], inv_sub.in_ports()[0])
    inv.add_edge(theta1_in, inv_sub.in_ports()[1])
    inv.add_edge(inv_sub.out_ports()[0], x0_out)
    inv.add_edge(inv_sub.out_ports()[1], x1_out)
    assert inv.is_wired_correctly()

    tol = 1e-5
    samples = []
    rv = ExponentialRV(lmbda)
    for _ in range(size):
        x = tuple(rv.sample(shape=[2]))
        gap = x[0] - x[1]
        y = np.abs(gap)
        theta = (np.sign(gap), x[1])
        y_hat = apply(fwd, list(x))
        x_hat = apply(inv, [y, theta[0], theta[1]])
        assert np.abs(y_hat[0] - y) < tol
        assert np.abs(x_hat[0] - x[0]) < tol
        assert np.abs(x_hat[1] - x[1]) < tol
        samples.append((x, theta, y))
    return samples, fwd, inv
예제 #11
0
파일: invert.py 프로젝트: llwu/reverseflow
def inner_invert(comp_arrow: CompositeArrow, port_attr: PortAttributes,
                 dispatch: Dict[Arrow, Callable]):
    """Construct a parametric inverse of comp_arrow.

    Args:
        comp_arrow: Arrow to invert
        port_attr: port -> attribute map (e.g. as returned by `propagate`);
            used to decide which ports are constant and to copy known shapes
        dispatch: Dict mapping arrow class to invert function (passed on to
            invert_sub_arrow)
    Returns:
        (inv_comp_arrow, comp_port_map), where inv_comp_arrow is the
        (approximate) parametric inverse of comp_arrow. Port i of
        inv_comp_arrow corresponds to port i of comp_arrow, with its direction
        flipped unless the port is constant. The param_ports and error_ports
        of the inverted sub-arrows are appended after these. comp_port_map is
        the identity map {i: i} over the port indices of comp_arrow."""
    # Empty composition for inverse
    inv_comp_arrow = CompositeArrow(name="%s_inv" % comp_arrow.name)

    # Add a port on inverse arrow for every port on arrow
    for port in comp_arrow.ports():
        inv_port = inv_comp_arrow.add_port()
        # Non-constant ports swap direction; constant ports keep theirs
        if is_in_port(port):
            if is_constant(port, port_attr):
                make_in_port(inv_port)
            else:
                make_out_port(inv_port)
        elif is_out_port(port):
            if is_constant(port, port_attr):
                make_out_port(inv_port)
            else:
                make_in_port(inv_port)
        # Transfer port information
        # FIXME: Which port_attr should be transferred from port to inv_port
        if 'shape' not in port_attr[port]:
            print('WARNING: shape unknown for %s' % port)
        else:
            set_port_shape(inv_port, get_port_shape(port, port_attr))

    # invert each sub_arrow
    arrow_to_inv = dict()
    arrow_to_port_map = dict()
    for sub_arrow in comp_arrow.get_sub_arrows():
        inv_sub_arrow, port_map = invert_sub_arrow(sub_arrow, port_attr,
                                                   dispatch)
        assert sub_arrow is not None
        # The inverted sub-arrow must not already belong to a composition
        assert inv_sub_arrow.parent is None
        arrow_to_port_map[sub_arrow] = port_map
        arrow_to_inv[sub_arrow] = inv_sub_arrow

    # Add comp_arrow to inv
    assert comp_arrow is not None
    arrow_to_inv[comp_arrow] = inv_comp_arrow
    comp_port_map = {i: i for i in range(comp_arrow.num_ports())}
    arrow_to_port_map[comp_arrow] = comp_port_map

    # Then, rewire up all the edges
    for out_port, in_port in comp_arrow.edges.items():
        left_inv_port = get_inv_port(out_port, arrow_to_port_map, arrow_to_inv)
        right_inv_port = get_inv_port(in_port, arrow_to_port_map, arrow_to_inv)
        # Each original edge is reversed as needed: exactly one of the two
        # inverse ports must send and the other receive
        both = [left_inv_port, right_inv_port]
        projecting = list(
            filter(lambda x: would_project(x, inv_comp_arrow), both))
        receiving = list(
            filter(lambda x: would_receive(x, inv_comp_arrow), both))
        assert len(projecting) == 1, "Should be only 1 projecting"
        assert len(receiving) == 1, "Should be only 1 receiving"
        inv_comp_arrow.add_edge(projecting[0], receiving[0])

    # Apply the module-level `transforms` to the wired inverse, in order
    for transform in transforms:
        transform(inv_comp_arrow)

    # Create new ports on inverse compositions for parametric and error ports
    # (these must still be unconnected, hence the edge assertions)
    for sub_arrow in inv_comp_arrow.get_sub_arrows():
        for port in sub_arrow.ports():
            if is_param_port(port):
                assert port not in inv_comp_arrow.edges.keys()
                assert port not in inv_comp_arrow.edges.values()
                param_port = inv_comp_arrow.add_port()
                inv_comp_arrow.add_edge(param_port, port)
                make_in_port(param_port)
                make_param_port(param_port)
            elif is_error_port(port):
                assert port not in inv_comp_arrow.edges.keys()
                assert port not in inv_comp_arrow.edges.values()
                error_port = inv_comp_arrow.add_port()
                inv_comp_arrow.add_edge(port, error_port)
                make_out_port(error_port)
                make_error_port(error_port)

    return inv_comp_arrow, comp_port_map
예제 #12
0
파일: reparam.py 프로젝트: llwu/reverseflow
def reparam(comp_arrow: CompositeArrow,
            phi_shape: Tuple,
            nn_takes_input=True):
    """Reparameterize an arrow so that all parametric inputs are a function of phi.

    A TfArrow `nn` is inserted. Its first input is a new param in_port phi;
    if `nn_takes_input`, every non-param in_port of comp_arrow is also fed to
    nn. The outputs of nn drive the param ports of comp_arrow in order. All
    other ports are mirrored on the result, keeping their in/out/error flags
    and labels.
    Args:
        comp_arrow: Arrow to reparameterize
        phi_shape: Shape of parameter input
        nn_takes_input: whether nn also receives the non-param inputs
    Returns:
        CompositeArrow named "<comp_arrow.name>_reparam"
    """
    wrapper = CompositeArrow(name="%s_reparam" % comp_arrow.name)
    phi = wrapper.add_port()
    set_port_shape(phi, phi_shape)
    make_in_port(phi)
    make_param_port(phi)

    n_nn_inputs = 1 + (comp_arrow.num_in_ports() - comp_arrow.num_param_ports()
                       if nn_takes_input else 0)
    nn = TfArrow(n_in_ports=n_nn_inputs,
                 n_out_ports=comp_arrow.num_param_ports())
    wrapper.add_edge(phi, nn.in_port(0))

    next_nn_out = 0  # next nn out_port to hand to a param port
    next_nn_in = 1   # nn in_port 0 is taken by phi
    for port in comp_arrow.ports():
        if is_param_port(port):
            wrapper.add_edge(nn.out_port(next_nn_out), port)
            next_nn_out += 1
            continue

        outer = wrapper.add_port()
        if is_out_port(port):
            make_out_port(outer)
            wrapper.add_edge(port, outer)
        if is_in_port(port):
            make_in_port(outer)
            wrapper.add_edge(outer, port)
            if nn_takes_input:
                wrapper.add_edge(outer, nn.in_port(next_nn_in))
                next_nn_in += 1
        if is_error_port(port):
            make_error_port(outer)
        for label in get_port_labels(port):
            add_port_label(outer, label)

    assert wrapper.is_wired_correctly()
    return wrapper
예제 #13
0
파일: gan.py 프로젝트: llwu/reverseflow
def g_from_g_theta(inv: Arrow, g_theta: Arrow):
    """
    Construct an amortized random variable g from a parametric inverse `inv`
    and a function `g_theta` which constructs parameters for `inv`.
    Args:
      inv: parametric inverse; its param_ports are driven by g_theta
      g_theta: Y x Y ... Y x Z -> Theta
    Returns:
      Y x Y x .. x Z -> X x ... X x Error x ... Error
    """
    c = CompositeArrow(name="%s_g_theta" % inv.name)
    inv_in_ports, inv_param_ports, inv_out_ports, inv_error_ports = split_ports(
        inv)
    _, _, g_theta_out_ports, _ = split_ports(g_theta)
    assert len(g_theta_out_ports) == len(inv_param_ports)

    # Every y feeds both g_theta and inv
    for idx, inv_in in enumerate(inv_in_ports):
        y_port = c.add_port()
        make_in_port(y_port)
        c.add_edge(y_port, g_theta.in_port(idx))
        c.add_edge(y_port, inv_in)

    # Noise z is g_theta's input right after the ys
    z_port = c.add_port()
    make_in_port(z_port)
    c.add_edge(z_port, g_theta.in_port(len(inv_in_ports)))

    # g_theta supplies inv's parameters
    for idx in range(len(g_theta_out_ports)):
        c.add_edge(g_theta.out_port(idx), inv_param_ports[idx])

    for inv_out in inv_out_ports:
        outer_out = c.add_port()
        make_out_port(outer_out)
        c.add_edge(inv_out, outer_out)

    for inv_err in inv_error_ports:
        outer_err = c.add_port()
        make_out_port(outer_err)
        make_error_port(outer_err)
        c.add_edge(inv_err, outer_err)

    assert c.is_wired_correctly()
    return c
예제 #14
0
파일: gan.py 프로젝트: llwu/reverseflow
def ConcatShuffleArrow(n_inputs: int, ndims: int):
    """Stack n_inputs arrays and shuffle them along the stacked axis.

    The first n_inputs in_ports are the arrays, stacked along axis 0. The
    last in_port is an int32 permutation vector, used to gather along that
    stacked axis. The gathered result is then transposed so the stacked axis
    becomes the last one.
    Args:
        n_inputs: number of arrays to stack
        ndims: rank of each input array (the stacked tensor has rank ndims + 1)
    Returns:
        CompositeArrow named "ConcatShuffle" with n_inputs + 1 in_ports and
        one out_port
    """
    c = CompositeArrow(name="ConcatShuffle")

    stack = StackArrow(n_inputs, axis=0)
    for idx in range(n_inputs):
        array_port = c.add_port()
        make_in_port(array_port)
        c.add_edge(array_port, stack.in_port(idx))

    # Permutation vector (last in_port)
    perm_port = c.add_port()
    make_in_port(perm_port)
    set_port_dtype(perm_port, 'int32')
    gather = GatherArrow()

    c.add_edge(stack.out_port(0), gather.in_port(0))
    c.add_edge(perm_port, gather.in_port(1))

    # FIXME: gather seems to affect only the first dimension, so afterwards
    # rotate the stacked axis to the end; gather_nd would probably be better.
    axes = list(range(ndims + 1))
    rotated = axes[1:] + axes[:1]

    transpose = TransposeArrow(rotated)
    c.add_edge(gather.out_port(0), transpose.in_port(0))

    result = c.add_port()
    make_out_port(result)
    c.add_edge(transpose.out_port(0), result)
    assert c.is_wired_correctly()
    return c
예제 #15
0
파일: gan.py 프로젝트: llwu/reverseflow
def set_gan_arrow(arrow: Arrow,
                  cond_gen: Arrow,
                  disc: Arrow,
                  n_fake_samples: int,
                  ndims: int,
                  x_shapes=None,
                  z_shape=None) -> CompositeArrow:
    """
    Arrow which computes the loss for an amortized random variable using a set GAN.

    In_port layout of the result: the in_ports of `arrow` (X), then one noise
    in_port (Z) per fake sample, then an int32 permutation vector of length
    n_fake_samples + 1 (RAND_PERM).
    Out_port layout: discriminator loss, generator loss, then the non-error
    outputs of every generator copy, then the error outputs of every
    generator copy. All losses and generator errors are error ports.
    Args:
        arrow: Forward function
        cond_gen: Y x Z -> X - Conditional Generator; deep-copied once per
            fake sample
        disc: X^n -> {0,1}^n
        n_fake_samples: n, number of samples seen by discriminator at once
        ndims: passed to ConcatShuffleArrow (rank of each shuffled array)
        x_shapes: optional shapes for the X in_ports
        z_shape: optional shape for every noise in_port
    Returns:
        CompositeArrow: X x Z x ... Z x RAND_PERM ->
            d_Loss x g_Loss x X x ... X x Error x ... Error
    """
    # TODO: Assumes that f has single in_port and single out_port, generalize
    c = CompositeArrow(name="%s_set_gan" % arrow.name)
    assert cond_gen.num_in_ports() == 2, "don't handle case of more than one Y input"
    cond_gens = [deepcopy(cond_gen) for _ in range(n_fake_samples)]

    comp_in_ports = []
    # Connect x to arrow inputs
    for i in range(arrow.num_in_ports()):
        in_port = c.add_port()
        make_in_port(in_port)
        comp_in_ports.append(in_port)
        c.add_edge(in_port, arrow.in_port(i))
        if x_shapes is not None:
            set_port_shape(in_port, x_shapes[i])

    # Connect f(x) to every generator
    for i in range(n_fake_samples):
        for j in range(arrow.num_out_ports()):
            c.add_edge(arrow.out_port(j), cond_gens[i].in_port(j))

    # Connect a separate noise input to each generator's last in_port
    for i in range(n_fake_samples):
        noise_in_port = c.add_port()
        make_in_port(noise_in_port)
        if z_shape is not None:
            set_port_shape(noise_in_port, z_shape)
        cg_noise_in_port_id = cond_gens[i].num_in_ports() - 1
        c.add_edge(noise_in_port, cond_gens[i].in_port(cg_noise_in_port_id))

    stack_shuffles = []
    rand_perm_in_port = c.add_port()
    make_in_port(rand_perm_in_port)
    set_port_shape(rand_perm_in_port, (n_fake_samples + 1, ))
    set_port_dtype(rand_perm_in_port, 'int32')

    # For every non-error output of g, i.e. x and y if f(x, y) = z, stack
    # the samples from all generators plus the real x together and shuffle.
    # Error outputs are assumed to come after the non-error ones.
    cond_gen_non_error_out_ports = (cond_gen.num_out_ports() -
                                    cond_gen.num_error_ports())
    for i in range(cond_gen_non_error_out_ports):
        stack_shuffle = ConcatShuffleArrow(n_fake_samples + 1, ndims)
        stack_shuffles.append(stack_shuffle)
        # Add each output from generator to shuffle set
        for j in range(n_fake_samples):
            c.add_edge(cond_gens[j].out_port(i), stack_shuffle.in_port(j))

        # Add the posterior sample x to the shuffle set
        c.add_edge(comp_in_ports[i], stack_shuffle.in_port(n_fake_samples))
        c.add_edge(rand_perm_in_port,
                   stack_shuffle.in_port(n_fake_samples + 1))

    gan_loss_arrow = GanLossArrow(n_fake_samples + 1)

    # Connect output of each stack shuffle to discriminator
    for i in range(cond_gen_non_error_out_ports):
        c.add_edge(stack_shuffles[i].out_port(0), disc.in_port(i))

    c.add_edge(disc.out_port(0), gan_loss_arrow.in_port(0))
    c.add_edge(rand_perm_in_port, gan_loss_arrow.in_port(1))

    # Add generator and discriminator loss ports
    loss_d_port = c.add_port()
    make_out_port(loss_d_port)
    make_error_port(loss_d_port)
    c.add_edge(gan_loss_arrow.out_port(0), loss_d_port)

    loss_g_port = c.add_port()
    make_out_port(loss_g_port)
    make_error_port(loss_g_port)
    c.add_edge(gan_loss_arrow.out_port(1), loss_g_port)

    # Connect fake samples to output of composition
    for i in range(n_fake_samples):
        for j in range(cond_gen_non_error_out_ports):
            sample = c.add_port()
            make_out_port(sample)
            c.add_edge(cond_gens[i].out_port(j), sample)

    # Pipe up error ports of every generator copy. Use a distinct loop name
    # so the `cond_gen` parameter is not shadowed.
    for gen in cond_gens:
        for i in range(cond_gen_non_error_out_ports, gen.num_out_ports()):
            error_port = c.add_port()
            make_out_port(error_port)
            # These carry generator errors, so flag them as error ports like
            # every other error pass-through in this module.
            make_error_port(error_port)
            c.add_edge(gen.out_port(i), error_port)

    assert c.is_wired_correctly()
    return c
예제 #16
0
def unparam(arrow: Arrow, nnet: Arrow = None):
    """Unparameterize an arrow by placing a network between its normal inputs
    and its parametric inputs.

    Each non-param in_port of `arrow` becomes an in_port of the result and is
    also fed to `nnet`. The outputs of `nnet` drive the param in_ports in
    order. Out_ports (with error flags and labels) are mirrored.
    Args:
        arrow: Y x Theta -> X
        nnet: Y -> Theta; defaults to a fresh TfArrow of matching arity
    Returns:
        Y -> X
    """
    c = CompositeArrow(name="%s_unparam" % arrow.name)
    in_ports = []
    param_ports = []
    for port in arrow.in_ports():
        if is_param_port(port):
            param_ports.append(port)
        else:
            in_ports.append(port)
    if nnet is None:
        nnet = TfArrow(n_in_ports=len(in_ports), n_out_ports=len(param_ports))

    # Each ordinary input feeds both the arrow and the network
    for idx, inner_in in enumerate(in_ports):
        outer_in = c.add_port()
        make_in_port(outer_in)
        transfer_labels(inner_in, outer_in)
        c.add_edge(outer_in, inner_in)
        c.add_edge(outer_in, nnet.in_port(idx))

    # The network's outputs drive the parameters
    for idx, param_port in enumerate(param_ports):
        c.add_edge(nnet.out_port(idx), param_port)

    for inner_out in arrow.out_ports():
        outer_out = c.add_port()
        make_out_port(outer_out)
        if is_error_port(inner_out):
            make_error_port(outer_out)
        transfer_labels(inner_out, outer_out)
        c.add_edge(inner_out, outer_out)

    assert c.is_wired_correctly()
    return c
예제 #17
0
def create_arrow(arrow: CompositeArrow, equiv_thetas, port_attr, valid_ports,
                 symbt_ports):
    """Wrap `arrow` so that equivalent parameters share one slim parameter.

    Every in_port in `valid_ports` is driven by gathering from one flat
    parameter vector of length ``nclasses = num_unique_elem(equiv_thetas)``,
    then reshaping to the port's shape. Every other in_port and every
    out_port is passed through unchanged.
    Args:
        arrow: arrow whose parameters are being eliminated
        equiv_thetas: maps each symbol of a port's SymbolicTensor to the id of
            its equivalence class (presumably ids in 0..nclasses-1 -- TODO
            confirm)
        port_attr: port attributes, used to look up port shapes
        valid_ports: in_ports of `arrow` whose parameters are replaced
        symbt_ports: port -> {'symbolic_tensor': SymbolicTensor} map
    Returns:
        New CompositeArrow named "<arrow.name>_elim". Its last port is the slim
        param in_port, with shape (batch_size, nclasses // batch_size).
        batch_size is 1 when no replaced port has a leading batch dimension.
    """
    # New parameter space should have nclasses elements
    nclasses = num_unique_elem(equiv_thetas)
    new_arrow = CompositeArrow(name="%s_elim" % arrow.name)
    for out_port in arrow.out_ports():
        c_out_port = new_arrow.add_port()
        make_out_port(c_out_port)
        transfer_labels(out_port, c_out_port)
        if is_error_port(out_port):
            make_error_port(c_out_port)
        new_arrow.add_edge(out_port, c_out_port)

    # The slim param (added last) is flattened to a vector of length nclasses
    flat_shape = SourceArrow(np.array([nclasses], dtype=np.int32))
    flatten = ReshapeArrow()
    new_arrow.add_edge(flat_shape.out_port(0), flatten.in_port(1))
    slim_param_flat = flatten.out_port(0)
    batch_size = None
    for in_port in arrow.in_ports():
        if in_port in valid_ports:
            # Element k of this port is gathered from the class of its k-th
            # symbol (presumably symbols are in flattened order -- confirm)
            symbt = symbt_ports[in_port]['symbolic_tensor']
            indices = [equiv_thetas[theta] for theta in symbt.symbols]
            shape = get_port_shape(in_port, port_attr)
            if len(shape) > 1:
                # Leading dim of multi-dimensional ports is the batch size;
                # all such ports must agree on it
                if batch_size is not None:
                    assert shape[0] == batch_size
                batch_size = shape[0]
            gather = GatherArrow()
            src = SourceArrow(np.array(indices, dtype=np.int32))
            shape_shape = SourceArrow(np.array(shape, dtype=np.int32))
            reshape = ReshapeArrow()
            new_arrow.add_edge(slim_param_flat, gather.in_port(0))
            new_arrow.add_edge(src.out_port(0), gather.in_port(1))
            new_arrow.add_edge(gather.out_port(0), reshape.in_port(0))
            new_arrow.add_edge(shape_shape.out_port(0), reshape.in_port(1))
            new_arrow.add_edge(reshape.out_port(0), in_port)
        else:
            new_in_port = new_arrow.add_port()
            make_in_port(new_in_port)
            if is_param_port(in_port):
                make_param_port(new_in_port)
            transfer_labels(in_port, new_in_port)
            new_arrow.add_edge(new_in_port, in_port)

    if batch_size is None:
        # No replaced port has a leading batch dimension (or nothing was
        # replaced). Treat the whole vector as a single batch instead of
        # failing with `int % None` below.
        batch_size = 1
    assert nclasses % batch_size == 0
    slim_param = new_arrow.add_port()
    make_in_port(slim_param)
    make_param_port(slim_param)
    new_arrow.add_edge(slim_param, flatten.in_port(0))
    set_port_shape(slim_param, (batch_size, nclasses // batch_size))

    assert new_arrow.is_wired_correctly()
    return new_arrow
예제 #18
0
def comp(fwd: Arrow, right_inv: Arrow, DiffArrow=SquaredDifference):
    """Composition: pipe the output of the forward model into a right inverse.

    For each input x, an extra error out_port labelled "supervised_error" is
    also added. It carries DiffArrow applied to x and the matching non-error
    output of right_inv.
    Args:
        fwd: X -> Y
        right_inv: Y -> X x Error
        DiffArrow: arrow class comparing two values
    Returns:
        X -> X x Error (including the supervised errors)
    """
    c = CompositeArrow(name="fwd_to_right_inv")

    # Left boundary -> fwd
    for fwd_in in fwd.in_ports():
        outer_in = c.add_port()
        make_in_port(outer_in)
        c.add_edge(outer_in, fwd_in)
        transfer_labels(fwd_in, outer_in)

    # fwd -> right_inv
    for idx, fwd_out in enumerate(fwd.out_ports()):
        c.add_edge(fwd_out, right_inv.in_port(idx))

    # right_inv -> right boundary
    for inv_out in right_inv.out_ports():
        outer_out = c.add_port()
        make_out_port(outer_out)
        if is_error_port(inv_out):
            make_error_port(outer_out)

        c.add_edge(inv_out, outer_out)
        transfer_labels(inv_out, outer_out)

    # Compare each input x with right_inv(f(x))
    reconstructions = [p for p in right_inv.out_ports()
                       if not is_error_port(p)]  # len(X)
    assert len(reconstructions) == len(c.in_ports())
    for idx, outer_in in enumerate(c.in_ports()):
        diff = DiffArrow()
        c.add_edge(outer_in, diff.in_port(0))
        c.add_edge(reconstructions[idx], diff.in_port(1))
        err = c.add_port()
        make_out_port(err)
        make_error_port(err)
        add_port_label(err, "supervised_error")
        c.add_edge(diff.out_port(0), err)

    assert c.is_wired_correctly()
    return c
예제 #19
0
def eliminate_gathernd(arrow: CompositeArrow):
    """Eliminate redundant parameters around InvGatherNd arrows.

    For every dupl arrow in `arrow`, each neighbouring InvGatherNd has its
    theta in_port demoted from a param port. That theta is instead fed by one
    shared UpdateArrow (`slim_param_arrow`). The UpdateArrow combines the
    values already known from the InvGatherNd inputs with a new, smaller
    parameter on its in_port 2.

    Note: `arrow` is modified in place; this function returns None.
    Args:
        arrow: Parametric Arrow prime for eliminate!
    """
    dupls = filter_arrows(lambda a: a.name in dupl_names, arrow)
    for dupl in dupls:
        # One shared replacement parameter source per dupl
        slim_param_arrow = UpdateArrow()
        constraints = None  # latest arrow in the chain of known-value updates
        free = None         # mask of positions not covered by any indices
        shape = None        # common output shape of the InvGatherNd arrows
        shape_source = None
        for p in dupl.in_ports():
            # Arrow attached to this in_port of the dupl (first neighbour only)
            inv = arrow.neigh_ports(p)[0].arrow
            if inv.name == 'InvGatherNd':
                # InvGatherNd in_ports, in the order unpacked here
                out, theta, indices = inv.in_ports()
                # theta is no longer a free parameter; the shared
                # slim_param_arrow drives it instead
                make_not_param_port(theta)
                arrow.add_edge(slim_param_arrow.out_port(0), theta)

                indices_val = get_port_value(indices)
                # All InvGatherNd arrows around one dupl must share a shape
                if shape is not None:
                    assert np.array_equal(
                        shape, np.array(get_port_shape(inv.out_port(0))))
                else:
                    shape = np.array(get_port_shape(inv.out_port(0)))
                    shape_source = SourceArrow(shape)

                # Rebind `out` to the port feeding InvGatherNd's first in_port
                out = arrow.neigh_ports(out)[0]

                # NOTE(review): presumably `unset` marks positions of `shape`
                # not covered by indices_val and `unique` lists covered
                # positions once each -- TODO confirm against
                # complement_bool_list
                unset, unique = complement_bool_list(indices_val, shape)
                # A position stays free only if no InvGatherNd covers it
                free = unset if free is None else np.logical_and(
                    unset, free)  # FIXME: don't really need bool anymore
                unique_source = SourceArrow(unique)
                unique_inds = SourceArrow(indices_val[tuple(
                    np.transpose(unique))])

                # add in a constraint for these new known values:
                # gather them from `out` and write them into the chain,
                # which starts from a float32 zeros tensor of `shape`
                unique_upds = GatherNdArrow()
                arrow.add_edge(out, unique_upds.in_port(0))
                arrow.add_edge(unique_source.out_port(0),
                               unique_upds.in_port(1))
                tmp = UpdateArrow()
                if constraints is None:
                    constraints = SourceArrow(np.zeros(shape,
                                                       dtype=np.float32))
                # UpdateArrow in_ports as wired here:
                # (base tensor, indices, updates, shape)
                arrow.add_edge(constraints.out_port(0), tmp.in_port(0))
                arrow.add_edge(unique_inds.out_port(0), tmp.in_port(1))
                arrow.add_edge(unique_upds.out_port(0), tmp.in_port(2))
                arrow.add_edge(shape_source.out_port(0), tmp.in_port(3))
                constraints = tmp
        if constraints is not None:
            # put in knowns
            arrow.add_edge(constraints.out_port(0),
                           slim_param_arrow.in_port(0))
            # put in params: in_port 2 (the values for the free positions)
            # becomes the new parameter
            make_param_port(slim_param_arrow.in_port(2))
            arrow.add_edge(shape_source.out_port(0),
                           slim_param_arrow.in_port(3))
            # Coordinates of the free positions
            inds = np.transpose(np.nonzero(free))
            # If every slice along axis 0 has the same free mask, group the
            # free coordinates into shape[0] equal chunks (one per batch row)
            if (free == free[0]).all():
                print("Assuming batched input")
                inds = np.array(np.split(inds, shape[0]))
            else:
                print(
                    "WARNING: Unbatched input, haven't designed for this case")
            inds_source = SourceArrow(inds)
            arrow.add_edge(inds_source.out_port(0),
                           slim_param_arrow.in_port(1))
예제 #20
0
def supervised_loss_arrow(arrow: Arrow,
                          DiffArrow=SquaredDifference) -> CompositeArrow:
    """
    Create an arrow that computes a supervised error |f(y) - x|.

    Args:
        arrow: f: Y -> X - The arrow to wrap
        DiffArrow: d: X x X -> R - Arrow class (zero-argument constructor)
            used to compute the difference between an expected output and
            the actual output of `arrow`
    Returns:
        A CompositeArrow with the same inputs and outputs as `arrow`
        (param in_ports stay param ports). For every non-error output of
        `arrow` it additionally takes an input labelled 'train_output' (the
        expected value) and returns an error output labelled
        'supervised_error' computed as DiffArrow(train_output, output).
        Error outputs of `arrow` are passed through as error ports, with
        their labels transferred via `transfer_labels`.
    """
    c = CompositeArrow(name="%s_supervised" % arrow.name)

    # Expose every in_port of `arrow` as an in_port of the composite,
    # preserving param-port status.
    for in_port in arrow.in_ports():
        c_in_port = c.add_port()
        make_in_port(c_in_port)
        if is_param_port(in_port):
            make_param_port(c_in_port)
        c.add_edge(c_in_port, in_port)

    # Note: the extra 'train_output' in_ports are created inside this loop,
    # so in port order they come after all of `arrow`'s own inputs.
    for out_port in arrow.out_ports():
        if is_error_port(out_port):
            # Existing error outputs are passed straight through.
            error_port = c.add_port()
            make_out_port(error_port)
            make_error_port(error_port)
            transfer_labels(out_port, error_port)
            c.add_edge(out_port, error_port)
            continue

        # Normal output: pass it through to the composite ...
        c_out_port = c.add_port()
        make_out_port(c_out_port)
        c.add_edge(out_port, c_out_port)

        # ... and compare it against a new 'train_output' input.
        diff = DiffArrow()
        train_port = c.add_port()
        make_in_port(train_port)
        add_port_label(train_port, "train_output")
        c.add_edge(train_port, diff.in_port(0))
        c.add_edge(out_port, diff.in_port(1))
        error_port = c.add_port()
        make_out_port(error_port)
        make_error_port(error_port)
        add_port_label(error_port, "supervised_error")
        c.add_edge(diff.out_port(0), error_port)

    assert c.is_wired_correctly()
    return c
예제 #21
0
    def __init__(self, lmbda: float) -> None:
        """Build the quantile (inverse CDF) arrow of an exponential RV.

        The composite built here maps an input probability p to
        -log(1 - p) / lmbda and is stored on ``self.quantile``.

        Args:
            lmbda: constant divisor in the formula above, i.e. the rate of
                the exponential distribution whose inverse CDF this is
        """
        quantile = CompositeArrow(name="ExponentialRVQuantile")
        connect = quantile.add_edge

        # Composite interface: one input p, one output.
        p_port = quantile.add_port()
        make_in_port(p_port)
        result_port = quantile.add_port()
        make_out_port(result_port)

        rate = SourceArrow(lmbda)
        one = SourceArrow(1.0)

        # 1 - p
        complement = SubArrow()
        connect(one.out_ports()[0], complement.in_ports()[0])
        connect(p_port, complement.in_ports()[1])

        # log(1 - p)
        log_node = LogArrow()
        connect(complement.out_ports()[0], log_node.in_ports()[0])

        # -log(1 - p)
        neg_node = NegArrow()
        connect(log_node.out_ports()[0], neg_node.in_ports()[0])

        # -log(1 - p) / lmbda
        ratio = DivArrow()
        connect(neg_node.out_ports()[0], ratio.in_ports()[0])
        connect(rate.out_ports()[0], ratio.in_ports()[1])
        connect(ratio.out_ports()[0], result_port)

        assert quantile.is_wired_correctly()
        self.quantile = quantile
예제 #22
0
def SumNArrow(ninputs: int) -> CompositeArrow:
    """
    Create arrow f(x1, ..., xn) = sum(x1, ..., xn).

    The sum is built as a left-leaning chain of ``ninputs - 1`` AddArrows:
    ((x1 + x2) + x3) + ... + xn.

    Args:
        ninputs: number of inputs; must be at least 2
    Returns:
        CompositeArrow with `ninputs` in_ports and one out_port
    """
    assert ninputs > 1, "SumNArrow needs at least two inputs"
    c = CompositeArrow(name="SumNArrow")

    # `acc_port` always carries the running partial sum; it starts as the
    # first composite input.
    acc_port = c.add_port()
    make_in_port(acc_port)

    for _ in range(ninputs - 1):
        add = AddArrow()
        c.add_edge(acc_port, add.in_port(0))
        next_port = c.add_port()
        make_in_port(next_port)
        c.add_edge(next_port, add.in_port(1))
        acc_port = add.out_port(0)

    out_port = c.add_port()
    make_out_port(out_port)
    # Wire the final partial sum explicitly rather than relying on the loop
    # variable `add` leaking out of the loop.
    c.add_edge(acc_port, out_port)

    assert c.is_wired_correctly()
    assert c.num_in_ports() == ninputs
    return c
예제 #23
0
def inv_fwd_loss_arrow(arrow: Arrow,
                       inverse: Arrow,
                       DiffArrow=SquaredDifference) -> CompositeArrow:
    """
    Arrow which computes |f(f^-1(y)) - y|.

    Args:
        arrow: Forward function f
        inverse: (Parametric) inverse f^-1 of `arrow`. Its non-param in_ports
            are assumed to be the y inputs, in the same order as `arrow`'s
            out_ports; its non-error out_ports feed `arrow`'s in_ports in order.
        DiffArrow: Arrow class (zero-argument constructor) used to compute the
            difference between y and f(f^-1(y))
    Returns:
        CompositeArrow whose in_ports mirror `inverse`'s in_ports (param ports
        stay param ports). Its out_ports are, in order:
        - the non-error outputs of `inverse`;
        - the error outputs of `inverse`, labelled 'sub_arrow_error';
        - one 'inv_fwd_error' error output per out_port of `arrow`.
    """
    c = CompositeArrow(name="%s_inv_fwd_loss" % arrow.name)

    # Make all in_ports of inverse inputs to composition, remembering which
    # ones carry y (i.e. are not parameters) so they can be compared below.
    y_ports = []
    for inv_in_port in inverse.in_ports():
        in_port = c.add_port()
        make_in_port(in_port)
        if is_param_port(inv_in_port):
            make_param_port(in_port)
        else:
            y_ports.append(in_port)
        c.add_edge(in_port, inv_in_port)

    # Connect the non-error out_ports of inverse, in order, to the in_ports
    # of f. A dedicated counter is used so that error outputs interleaved
    # among them do not shift the target in_port index.
    fwd_in_index = 0
    for out_port in inverse.out_ports():
        if not is_error_port(out_port):
            c.add_edge(out_port, arrow.in_port(fwd_in_index))
            fwd_in_index += 1
            # add edge from inverse output to composition output
            c_out_port = c.add_port()
            make_out_port(c_out_port)
            c.add_edge(out_port, c_out_port)

    # Pass errors (if any) of parametric inverse through as error_ports
    for out_port in inverse.out_ports():
        if is_error_port(out_port):
            error_port = c.add_port()
            make_out_port(error_port)
            make_error_port(error_port)
            add_port_label(error_port, "sub_arrow_error")
            c.add_edge(out_port, error_port)

    # Find the difference between the y inputs of the inverse and the outputs
    # of fwd; make an error port for each. Only non-param inputs are compared,
    # so param ports placed before the y inputs cannot be picked up by mistake.
    for i, out_port in enumerate(arrow.out_ports()):
        diff = DiffArrow()
        c.add_edge(y_ports[i], diff.in_port(0))
        c.add_edge(out_port, diff.in_port(1))
        error_port = c.add_port()
        make_out_port(error_port)
        make_error_port(error_port)
        add_port_label(error_port, "inv_fwd_error")
        c.add_edge(diff.out_port(0), error_port)

    assert c.is_wired_correctly()
    return c