예제 #1
0
def unparam(arrow: Arrow, nnet: Arrow = None):
    """Unparameterize an arrow by placing a network in front of its params.

    The non-parametric inputs of `arrow` become the inputs of the result and
    are also fed (in order) into `nnet`, whose outputs drive the parametric
    inputs of `arrow` (in order).

    Args:
        arrow: Y x Theta -> X
        nnet: Y -> Theta. When omitted, a TfArrow with one input per
            non-parametric in_port and one output per parametric in_port
            is created.
    Returns:
        CompositeArrow: Y -> X
    """
    c = CompositeArrow(name="%s_unparam" % arrow.name)

    # Partition arrow's in_ports into ordinary inputs and parametric inputs.
    in_ports, param_ports = [], []
    for port in arrow.in_ports():
        (param_ports if is_param_port(port) else in_ports).append(port)

    if nnet is None:
        nnet = TfArrow(n_in_ports=len(in_ports), n_out_ports=len(param_ports))

    # Each boundary input feeds both the wrapped arrow and the network.
    for idx, inner in enumerate(in_ports):
        boundary = c.add_port()
        make_in_port(boundary)
        transfer_labels(inner, boundary)
        c.add_edge(boundary, inner)
        c.add_edge(boundary, nnet.in_port(idx))

    # Network outputs replace the parametric inputs, position by position.
    for idx, theta_port in enumerate(param_ports):
        c.add_edge(nnet.out_port(idx), theta_port)

    # Expose every output of the wrapped arrow, preserving error status.
    for inner in arrow.out_ports():
        boundary = c.add_port()
        make_out_port(boundary)
        if is_error_port(inner):
            make_error_port(boundary)
        transfer_labels(inner, boundary)
        c.add_edge(inner, boundary)

    assert c.is_wired_correctly()
    return c
예제 #2
0
파일: reparam.py 프로젝트: llwu/reverseflow
def reparam(comp_arrow: CompositeArrow,
            phi_shape: Tuple,
            nn_takes_input=True):
    """Reparameterize an arrow so all parametric inputs are functions of phi.

    A new parametric input `phi` (of shape `phi_shape`) is fed into a TfArrow
    whose outputs drive every parametric in_port of `comp_arrow`, in port
    order. All other ports of `comp_arrow` are exposed on the result in their
    original order, keeping direction, error status and labels.

    Args:
        comp_arrow: Arrow to reparameterize
        phi_shape: Shape of parameter input
        nn_takes_input: if True, the non-parametric inputs of `comp_arrow`
            are also fed to the network, after phi
    Returns:
        CompositeArrow wrapping `comp_arrow`
    """
    result = CompositeArrow(name="%s_reparam" % comp_arrow.name)
    phi = result.add_port()
    set_port_shape(phi, phi_shape)
    make_in_port(phi)
    make_param_port(phi)

    n_params = comp_arrow.num_param_ports()
    n_nn_inputs = 1 + (comp_arrow.num_in_ports() - n_params
                       if nn_takes_input else 0)
    nn = TfArrow(n_in_ports=n_nn_inputs, n_out_ports=n_params)
    result.add_edge(phi, nn.in_port(0))

    next_nn_out = 0  # next network output to route to a param port
    next_nn_in = 1   # network input slot 0 is taken by phi
    for port in comp_arrow.ports():
        if is_param_port(port):
            result.add_edge(nn.out_port(next_nn_out), port)
            next_nn_out += 1
            continue

        outer = result.add_port()
        if is_out_port(port):
            make_out_port(outer)
            result.add_edge(port, outer)
        if is_in_port(port):
            make_in_port(outer)
            result.add_edge(outer, port)
            if nn_takes_input:
                result.add_edge(outer, nn.in_port(next_nn_in))
                next_nn_in += 1
        if is_error_port(port):
            make_error_port(outer)
        for label in get_port_labels(port):
            add_port_label(outer, label)

    assert result.is_wired_correctly()
    return result
예제 #3
0
def SumNArrow(ninputs: int):
    """Create the arrow f(x1, ..., xn) = x1 + ... + xn.

    The sum is built as a left-leaning chain of ``ninputs - 1`` AddArrows.

    Args:
        ninputs: number of inputs; must be greater than 1
    Returns:
        CompositeArrow with `ninputs` in_ports and one out_port
    """
    assert ninputs > 1
    c = CompositeArrow(name="SumNArrow")
    acc = c.add_port()
    make_in_port(acc)

    # Fold each further input into the running sum.
    for _ in range(1, ninputs):
        adder = AddArrow()
        c.add_edge(acc, adder.in_port(0))
        term = c.add_port()
        make_in_port(term)
        c.add_edge(term, adder.in_port(1))
        acc = adder.out_port(0)

    result = c.add_port()
    make_out_port(result)
    c.add_edge(acc, result)

    assert c.is_wired_correctly()
    assert c.num_in_ports() == ninputs
    return c
예제 #4
0
파일: gan.py 프로젝트: llwu/reverseflow
def ConcatShuffleArrow(n_inputs: int, ndims: int):
    """Stack `n_inputs` arrays and shuffle them along the stacking axis.

    In ports: the first `n_inputs` are the arrays to stack (along a new axis
    0); the last is an int32 permutation vector used to gather along that
    axis. The gathered tensor is then transposed so the stacking axis moves
    from the first to the last position.

    Args:
        n_inputs: number of arrays to stack
        ndims: rank of each input array; the transpose permutation covers
            ndims + 1 axes
    Returns:
        CompositeArrow with n_inputs + 1 in_ports and one out_port
    """
    c = CompositeArrow(name="ConcatShuffle")

    stack = StackArrow(n_inputs, axis=0)
    for slot in range(n_inputs):
        src = c.add_port()
        make_in_port(src)
        c.add_edge(src, stack.in_port(slot))

    # Permutation vector
    perm = c.add_port()
    make_in_port(perm)
    set_port_dtype(perm, 'int32')
    gather = GatherArrow()
    c.add_edge(stack.out_port(0), gather.in_port(0))
    c.add_edge(perm, gather.in_port(1))

    # Gather only permutes the leading axis, so rotate that axis to the end:
    # (0, 1, ..., ndims) -> (1, ..., ndims, 0).
    # FIXME: a gather_nd based formulation could avoid this transpose.
    axes = list(range(ndims + 1))
    tp_perm = axes[1:] + [axes[0]]
    transpose = TransposeArrow(tp_perm)
    c.add_edge(gather.out_port(0), transpose.in_port(0))

    out = c.add_port()
    make_out_port(out)
    c.add_edge(transpose.out_port(0), out)
    assert c.is_wired_correctly()
    return c
예제 #5
0
파일: gan.py 프로젝트: llwu/reverseflow
def g_from_g_theta(inv: Arrow, g_theta: Arrow):
    """
    Construct an amortized random variable g from a parametric inverse `inv`
    and a function `g_theta` which constructs parameters for `inv`.

    Each y input feeds both g_theta and inv; one extra noise input z is fed
    to g_theta right after the y inputs; the outputs of g_theta drive the
    parametric inputs of inv.

    Args:
      inv: parametric inverse whose ports are split by `split_ports`
      g_theta: Y x Y ... Y x Z -> Theta
    Returns:
      Y x Y x .. x Z -> X x ... X x Error x ... Error
    """
    c = CompositeArrow(name="%s_g_theta" % inv.name)
    inv_ins, inv_params, inv_outs, inv_errors = split_ports(inv)
    g_theta_outs = split_ports(g_theta)[2]
    assert len(g_theta_outs) == len(inv_params)

    # Connect y to g_theta and inv
    n_y = len(inv_ins)
    for idx, inv_port in enumerate(inv_ins):
        y = c.add_port()
        make_in_port(y)
        c.add_edge(y, g_theta.in_port(idx))
        c.add_edge(y, inv_port)

    # Noise input occupies the g_theta in_port right after the y inputs
    z = c.add_port()
    make_in_port(z)
    c.add_edge(z, g_theta.in_port(n_y))

    # g_theta's outputs become inv's parameters
    for idx, _ in enumerate(g_theta_outs):
        c.add_edge(g_theta.out_port(idx), inv_params[idx])

    for src in inv_outs:
        dst = c.add_port()
        make_out_port(dst)
        c.add_edge(src, dst)

    for src in inv_errors:
        dst = c.add_port()
        make_out_port(dst)
        make_error_port(dst)
        c.add_edge(src, dst)

    assert c.is_wired_correctly()
    return c
예제 #6
0
파일: gan.py 프로젝트: llwu/reverseflow
def set_gan_arrow(arrow: Arrow,
                  cond_gen: Arrow,
                  disc: Arrow,
                  n_fake_samples: int,
                  ndims: int,
                  x_shapes=None,
                  z_shape=None) -> CompositeArrow:
    """
    Arrow which computes the loss for an amortized random variable using a
    set GAN.

    Args:
        arrow: Forward function
        cond_gen: Y x Z -> X - Conditional Generator (deep-copied once per
            fake sample)
        disc: X^n -> {0,1}^n
        n_fake_samples: n, number of samples seen by discriminator at once
        ndims: rank passed to ConcatShuffleArrow for each stacked output
        x_shapes: optional shapes for the X inputs, indexed like arrow's
            in_ports
        z_shape: optional shape for every noise input
    Returns:
        CompositeArrow: X x Z x ... Z x RAND_PERM ->
            d_Loss x g_Loss x (generated samples) x Error x ... Error
        The generated samples are the non-error outputs of each generator
        copy (copy-major order). The trailing Error outputs are the error
        outputs of each generator copy and are marked as error ports.
    """
    # TODO: Assumes that f has single in_port and single out_port, generalize
    c = CompositeArrow(name="%s_set_gan" % arrow.name)
    assert cond_gen.num_in_ports() == 2, \
        "don't handle case of more than one Y input"
    gens = [deepcopy(cond_gen) for _ in range(n_fake_samples)]

    x_ports = []
    # Connect x to arrow inputs
    for i in range(arrow.num_in_ports()):
        x_port = c.add_port()
        make_in_port(x_port)
        x_ports.append(x_port)
        c.add_edge(x_port, arrow.in_port(i))
        if x_shapes is not None:
            set_port_shape(x_port, x_shapes[i])

    # Connect f(x) to the Y inputs of every generator copy
    for gen in gens:
        for j in range(arrow.num_out_ports()):
            c.add_edge(arrow.out_port(j), gen.in_port(j))

    # Connect a fresh noise input to each generator's last in_port
    for gen in gens:
        z_port = c.add_port()
        make_in_port(z_port)
        if z_shape is not None:
            set_port_shape(z_port, z_shape)
        c.add_edge(z_port, gen.in_port(gen.num_in_ports() - 1))

    rand_perm_in_port = c.add_port()
    make_in_port(rand_perm_in_port)
    set_port_shape(rand_perm_in_port, (n_fake_samples + 1, ))
    set_port_dtype(rand_perm_in_port, 'int32')

    # For every non-error output of the generator, stack the fake samples
    # from all copies together with the real sample and shuffle them
    n_gen_non_error = cond_gen.num_out_ports() - cond_gen.num_error_ports()
    stack_shuffles = []
    for i in range(n_gen_non_error):
        stack_shuffle = ConcatShuffleArrow(n_fake_samples + 1, ndims)
        stack_shuffles.append(stack_shuffle)
        for j, gen in enumerate(gens):
            c.add_edge(gen.out_port(i), stack_shuffle.in_port(j))
        # The real sample goes last; the permutation vector follows the
        # n_fake_samples + 1 stacked arrays
        c.add_edge(x_ports[i], stack_shuffle.in_port(n_fake_samples))
        c.add_edge(rand_perm_in_port,
                   stack_shuffle.in_port(n_fake_samples + 1))

    gan_loss_arrow = GanLossArrow(n_fake_samples + 1)

    # Connect output of each stack shuffle to discriminator
    for i, stack_shuffle in enumerate(stack_shuffles):
        c.add_edge(stack_shuffle.out_port(0), disc.in_port(i))

    c.add_edge(disc.out_port(0), gan_loss_arrow.in_port(0))
    c.add_edge(rand_perm_in_port, gan_loss_arrow.in_port(1))

    # Add discriminator and generator loss ports
    loss_d_port = c.add_port()
    make_out_port(loss_d_port)
    make_error_port(loss_d_port)
    c.add_edge(gan_loss_arrow.out_port(0), loss_d_port)

    loss_g_port = c.add_port()
    make_out_port(loss_g_port)
    make_error_port(loss_g_port)
    c.add_edge(gan_loss_arrow.out_port(1), loss_g_port)

    # Connect fake samples to output of composition
    for gen in gens:
        for j in range(n_gen_non_error):
            sample = c.add_port()
            make_out_port(sample)
            c.add_edge(gen.out_port(j), sample)

    # Pipe up generator error ports, keeping them marked as errors
    # (consistent with how the other composition builders forward errors)
    for gen in gens:
        for i in range(n_gen_non_error, gen.num_out_ports()):
            error_port = c.add_port()
            make_out_port(error_port)
            make_error_port(error_port)
            c.add_edge(gen.out_port(i), error_port)

    assert c.is_wired_correctly()
    return c
예제 #7
0
def create_arrow(arrow: CompositeArrow, equiv_thetas, port_attr, valid_ports,
                 symbt_ports):
    """Build a copy of `arrow` whose redundant parameters share one input.

    Every in_port in `valid_ports` is no longer exposed. Instead it is driven
    by gathering from a single new parametric input (the "slim" parameter,
    one element per equivalence class of theta symbols) and reshaping to the
    port's original shape. All other in_ports and all out_ports are exposed
    as-is, keeping their labels and param/error status.

    Args:
        arrow: Arrow whose parameters are being eliminated
        equiv_thetas: maps each theta symbol to its equivalence-class id;
            the ids are used as gather indices into the flat slim parameter
            (presumably 0..nclasses-1 — confirm)
        port_attr: port attributes used to look up port shapes
        valid_ports: in_ports of `arrow` to be driven by the slim parameter
        symbt_ports: maps each valid port to a dict whose 'symbolic_tensor'
            entry has a `symbols` sequence (presumably one symbol per element
            in flattened order — confirm)
    Returns:
        CompositeArrow named "<arrow.name>_elim". Its slim parameter is the
        last in_port, with shape (batch_size, nclasses // batch_size) when a
        valid port has rank > 1, otherwise (nclasses,).
    """
    # New parameter space should have nclasses elements
    nclasses = num_unique_elem(equiv_thetas)
    new_arrow = CompositeArrow(name="%s_elim" % arrow.name)
    for out_port in arrow.out_ports():
        c_out_port = new_arrow.add_port()
        make_out_port(c_out_port)
        transfer_labels(out_port, c_out_port)
        if is_error_port(out_port):
            make_error_port(c_out_port)
        new_arrow.add_edge(out_port, c_out_port)

    # The slim parameter is flattened to [nclasses] before gathering
    flat_shape = SourceArrow(np.array([nclasses], dtype=np.int32))
    flatten = ReshapeArrow()
    new_arrow.add_edge(flat_shape.out_port(0), flatten.in_port(1))
    slim_param_flat = flatten.out_port(0)
    batch_size = None
    for in_port in arrow.in_ports():
        if in_port in valid_ports:
            symbt = symbt_ports[in_port]['symbolic_tensor']
            indices = [equiv_thetas[theta] for theta in symbt.symbols]
            shape = get_port_shape(in_port, port_attr)
            if len(shape) > 1:
                # All batched ports must agree on the leading dimension
                if batch_size is not None:
                    assert shape[0] == batch_size
                batch_size = shape[0]
            gather = GatherArrow()
            src = SourceArrow(np.array(indices, dtype=np.int32))
            shape_shape = SourceArrow(np.array(shape, dtype=np.int32))
            reshape = ReshapeArrow()
            new_arrow.add_edge(slim_param_flat, gather.in_port(0))
            new_arrow.add_edge(src.out_port(0), gather.in_port(1))
            new_arrow.add_edge(gather.out_port(0), reshape.in_port(0))
            new_arrow.add_edge(shape_shape.out_port(0), reshape.in_port(1))
            new_arrow.add_edge(reshape.out_port(0), in_port)
        else:
            new_in_port = new_arrow.add_port()
            make_in_port(new_in_port)
            if is_param_port(in_port):
                make_param_port(new_in_port)
            transfer_labels(in_port, new_in_port)
            new_arrow.add_edge(new_in_port, in_port)

    if batch_size is None:
        # No batched valid port: keep the slim parameter flat rather than
        # dividing by None (which raised TypeError before)
        slim_shape = (nclasses, )
    else:
        assert nclasses % batch_size == 0
        slim_shape = (batch_size, nclasses // batch_size)

    slim_param = new_arrow.add_port()
    make_in_port(slim_param)
    make_param_port(slim_param)
    new_arrow.add_edge(slim_param, flatten.in_port(0))
    set_port_shape(slim_param, slim_shape)

    assert new_arrow.is_wired_correctly()
    return new_arrow
예제 #8
0
def eliminate_gathernd(arrow: CompositeArrow):
    """Eliminates redundant parameters in GatherNd

    For every sub-arrow of `arrow` whose name is in the module-level
    `dupl_names` (defined elsewhere), each neighbouring 'InvGatherNd' arrow
    has its theta input un-marked as a parameter and re-driven by one shared
    UpdateArrow (`slim_param_arrow`). That UpdateArrow combines a chain of
    known-value constraints with a new parametric input covering the
    positions left "free".

    The graph is modified in place through `arrow.add_edge`.
    NOTE(review): the original docstring promised a new arrow, but no
    `return` statement is visible here; callers presumably rely on the
    in-place mutation — confirm.

    Args:
        arrow: Parametric Arrow prime for eliminate!
    Returns:
        New Parameteric Arrow with fewer parameters (see NOTE above)"""
    dupls = filter_arrows(lambda a: a.name in dupl_names, arrow)
    for dupl in dupls:
        # One shared UpdateArrow per duplicate arrow; its output replaces the
        # theta input of every neighbouring InvGatherNd
        slim_param_arrow = UpdateArrow()
        constraints = None  # last UpdateArrow in the chain of known values
        free = None  # boolean mask, AND-ed over all InvGatherNd neighbours
        shape = None  # output shape shared by all InvGatherNd neighbours
        shape_source = None
        for p in dupl.in_ports():
            # Arrow owning the first port adjacent to this input
            inv = arrow.neigh_ports(p)[0].arrow
            if inv.name == 'InvGatherNd':
                # Unpacking requires exactly three in_ports in this order
                out, theta, indices = inv.in_ports()
                make_not_param_port(theta)
                arrow.add_edge(slim_param_arrow.out_port(0), theta)

                indices_val = get_port_value(indices)
                if shape is not None:
                    # All InvGatherNd neighbours must share one output shape
                    assert np.array_equal(
                        shape, np.array(get_port_shape(inv.out_port(0))))
                else:
                    shape = np.array(get_port_shape(inv.out_port(0)))
                    shape_source = SourceArrow(shape)

                # Replace `out` with its first neighbour port (presumably the
                # port that feeds InvGatherNd's first input — confirm)
                out = arrow.neigh_ports(out)[0]

                # NOTE(review): complement_bool_list is defined elsewhere; from
                # its use, `unset` is a mask over `shape` and `unique` is a
                # set of positions usable to index indices_val — confirm
                unset, unique = complement_bool_list(indices_val, shape)
                free = unset if free is None else np.logical_and(
                    unset, free)  # FIXME: don't really need bool anymore
                unique_source = SourceArrow(unique)
                unique_inds = SourceArrow(indices_val[tuple(
                    np.transpose(unique))])

                # add in a constraint for these new known values
                unique_upds = GatherNdArrow()
                arrow.add_edge(out, unique_upds.in_port(0))
                arrow.add_edge(unique_source.out_port(0),
                               unique_upds.in_port(1))
                # Chain a new UpdateArrow onto the constraints built so far,
                # starting from a float32 zeros tensor of `shape`
                tmp = UpdateArrow()
                if constraints is None:
                    constraints = SourceArrow(np.zeros(shape,
                                                       dtype=np.float32))
                arrow.add_edge(constraints.out_port(0), tmp.in_port(0))
                arrow.add_edge(unique_inds.out_port(0), tmp.in_port(1))
                arrow.add_edge(unique_upds.out_port(0), tmp.in_port(2))
                arrow.add_edge(shape_source.out_port(0), tmp.in_port(3))
                constraints = tmp
        if constraints is not None:
            # put in knowns
            arrow.add_edge(constraints.out_port(0),
                           slim_param_arrow.in_port(0))
            # put in params
            make_param_port(slim_param_arrow.in_port(2))
            arrow.add_edge(shape_source.out_port(0),
                           slim_param_arrow.in_port(3))
            # Coordinates of the positions still free after all constraints
            inds = np.transpose(np.nonzero(free))
            # If every slice along axis 0 has the same mask, treat axis 0 as a
            # batch dimension and group the coordinates per batch element
            if (free == free[0]).all():
                print("Assuming batched input")
                inds = np.array(np.split(inds, shape[0]))
            else:
                print(
                    "WARNING: Unbatched input, haven't designed for this case")
            inds_source = SourceArrow(inds)
            arrow.add_edge(inds_source.out_port(0),
                           slim_param_arrow.in_port(1))
예제 #9
0
def gen_data(lmbda: float, size: int):
    """Build f(x1, x2) = |x1 - x2| and its parametric inverse, then sample data.

    Every sample is checked: the forward arrow must reproduce y and the
    inverse (given theta) must reproduce x, each within 1e-5.

    Args:
        lmbda: rate parameter passed to ExponentialRV
        size: number of samples to draw
    Returns:
        (data, forward_arrow, inverse_arrow). Each data item is (x, theta, y)
        with x a pair of samples, y = |x[0] - x[1]| and
        theta = (sign(x[0] - x[1]), x[1]).
    """
    def new_port(owner, is_input, is_param=False):
        # Add a port to `owner` and set its direction (and param flag).
        port = owner.add_port()
        if is_input:
            make_in_port(port)
        else:
            make_out_port(port)
        if is_param:
            make_param_port(port)
        return port

    # Forward: (x1, x2) -> |x1 - x2|
    fwd = CompositeArrow(name="forward")
    fwd_a = new_port(fwd, True)
    fwd_b = new_port(fwd, True)
    fwd_out = new_port(fwd, False)
    sub, absolute = SubArrow(), AbsArrow()
    fwd.add_edge(fwd_a, sub.in_ports()[0])
    fwd.add_edge(fwd_b, sub.in_ports()[1])
    fwd.add_edge(sub.out_ports()[0], absolute.in_ports()[0])
    fwd.add_edge(absolute.out_ports()[0], fwd_out)
    assert fwd.is_wired_correctly()

    # Inverse: (y, sign, x2) -> (x1, x2)
    inv = CompositeArrow(name="inverse")
    inv_y = new_port(inv, True)
    theta_sign = new_port(inv, True, is_param=True)
    theta_x = new_port(inv, True, is_param=True)
    inv_out_1 = new_port(inv, False)
    inv_out_2 = new_port(inv, False)
    inv_sub, inv_abs = InvSubArrow(), InvAbsArrow()
    inv.add_edge(inv_y, inv_abs.in_ports()[0])
    inv.add_edge(theta_sign, inv_abs.in_ports()[1])
    inv.add_edge(inv_abs.out_ports()[0], inv_sub.in_ports()[0])
    inv.add_edge(theta_x, inv_sub.in_ports()[1])
    inv.add_edge(inv_sub.out_ports()[0], inv_out_1)
    inv.add_edge(inv_sub.out_ports()[1], inv_out_2)
    assert inv.is_wired_correctly()

    eps = 1e-5
    dist = ExponentialRV(lmbda)
    data = []
    for _ in range(size):
        x = tuple(dist.sample(shape=[2]))
        delta = x[0] - x[1]
        y = np.abs(delta)
        theta = (np.sign(delta), x[1])
        y_ = apply(fwd, list(x))
        x_ = apply(inv, [y, theta[0], theta[1]])
        assert np.abs(y_[0] - y) < eps
        assert np.abs(x_[0] - x[0]) < eps
        assert np.abs(x_[1] - x[1]) < eps
        data.append((x, theta, y))
    return data, fwd, inv
예제 #10
0
    def __init__(self, lmbda: float) -> None:
        """Build the quantile arrow p -> -ln(1 - p) / lmbda.

        Args:
            lmbda: rate parameter, baked into the arrow as a constant source
        """
        quantile = CompositeArrow(name="ExponentialRVQuantile")
        p_port = quantile.add_port()
        make_in_port(p_port)
        q_port = quantile.add_port()
        make_out_port(q_port)

        rate = SourceArrow(lmbda)
        one = SourceArrow(1.0)
        # 1 - p
        complement = SubArrow()
        quantile.add_edge(one.out_ports()[0], complement.in_ports()[0])
        quantile.add_edge(p_port, complement.in_ports()[1])
        # ln(1 - p)
        log = LogArrow()
        quantile.add_edge(complement.out_ports()[0], log.in_ports()[0])

        # -ln(1 - p) / lmbda
        neg = NegArrow()
        quantile.add_edge(log.out_ports()[0], neg.in_ports()[0])
        scale = DivArrow()
        quantile.add_edge(neg.out_ports()[0], scale.in_ports()[0])
        quantile.add_edge(rate.out_ports()[0], scale.in_ports()[1])
        quantile.add_edge(scale.out_ports()[0], q_port)

        assert quantile.is_wired_correctly()
        self.quantile = quantile
예제 #11
0
def inv_fwd_loss_arrow(arrow: Arrow,
                       inverse: Arrow,
                       DiffArrow=SquaredDifference) -> CompositeArrow:
    """
    Arrow which computes |f(f^-1(y)) - y|

    All in_ports of `inverse` become in_ports of the composition (keeping
    their param status). The non-error outputs of `inverse` feed `arrow`
    (in order) and are also exposed as outputs. The error outputs of
    `inverse` are passed through, labelled "sub_arrow_error". Then one
    error output per out_port of `arrow` is added, labelled
    "inv_fwd_error", holding DiffArrow(y_i, f(f^-1(y))_i).

    Args:
        arrow: Forward function f
        inverse: Parametric inverse of f
        DiffArrow: constructor of the arrow measuring the difference
    Returns:
        CompositeArrow
    """
    c = CompositeArrow(name="%s_inv_fwd_loss" % arrow.name)

    # Make all in_ports of inverse inputs to composition, remembering which
    # composite inputs carry y (i.e. are not parametric)
    y_ports = []
    for inv_in_port in inverse.in_ports():
        in_port = c.add_port()
        make_in_port(in_port)
        if is_param_port(inv_in_port):
            make_param_port(in_port)
        else:
            y_ports.append(in_port)
        c.add_edge(in_port, inv_in_port)

    inv_out_ports = list(inverse.out_ports())
    x_ports = [p for p in inv_out_ports if not is_error_port(p)]
    inv_error_ports = [p for p in inv_out_ports if is_error_port(p)]

    # Connect the non-error outputs of inverse, in order, to the in_ports of
    # f. Counting only non-error outputs keeps the mapping right even when
    # error ports are interleaved.
    for i, out_port in enumerate(x_ports):
        c.add_edge(out_port, arrow.in_port(i))
        c_out_port = c.add_port()
        # add edge from inverse output to composition output
        make_out_port(c_out_port)
        c.add_edge(out_port, c_out_port)

    # Pass errors (if any) of parametric inverse through as error_ports
    for out_port in inv_error_ports:
        error_port = c.add_port()
        make_out_port(error_port)
        make_error_port(error_port)
        add_port_label(error_port, "sub_arrow_error")
        c.add_edge(out_port, error_port)

    # find difference between y inputs to inverse and outputs of fwd
    # make error port for each; y_ports[i] (not c.in_port(i)) so that a
    # parametric input can never be compared by mistake
    for i, out_port in enumerate(arrow.out_ports()):
        diff = DiffArrow()
        c.add_edge(y_ports[i], diff.in_port(0))
        c.add_edge(out_port, diff.in_port(1))
        error_port = c.add_port()
        make_out_port(error_port)
        make_error_port(error_port)
        add_port_label(error_port, "inv_fwd_error")
        c.add_edge(diff.out_port(0), error_port)

    assert c.is_wired_correctly()
    return c
예제 #12
0
def supervised_loss_arrow(arrow: Arrow,
                          DiffArrow=SquaredDifference) -> CompositeArrow:
    """
    Creates an arrow that computes f alongside the supervised error |f(y) - x|.

    Args:
        arrow: f: Y -> X - The arrow to wrap
        DiffArrow: d: X x X -> R - Arrow for computing difference
    Returns:
        An arrow with the same inputs and outputs as `arrow`, except that for
        every non-error output it takes an additional input labelled
        'train_output' (the expected value) and returns an additional error
        output labelled 'supervised_error' holding DiffArrow(target, output).
        Error outputs of `arrow` are passed through with their labels.
    """
    c = CompositeArrow(name="%s_supervised" % arrow.name)

    # Mirror every input (keeping its parametric status) and pipe it in
    for inner in arrow.in_ports():
        outer = c.add_port()
        make_in_port(outer)
        if is_param_port(inner):
            make_param_port(outer)
        c.add_edge(outer, inner)

    for inner in arrow.out_ports():
        if is_error_port(inner):
            # Error outputs are passed straight through
            passthrough = c.add_port()
            make_out_port(passthrough)
            make_error_port(passthrough)
            transfer_labels(inner, passthrough)
            c.add_edge(inner, passthrough)
            continue

        # Regular output: expose it ...
        exposed = c.add_port()
        make_out_port(exposed)
        c.add_edge(inner, exposed)

        # ... and compare it with a supervised target
        diff = DiffArrow()
        target = c.add_port()
        make_in_port(target)
        add_port_label(target, "train_output")
        c.add_edge(target, diff.in_port(0))
        c.add_edge(inner, diff.in_port(1))
        loss = c.add_port()
        make_out_port(loss)
        make_error_port(loss)
        add_port_label(loss, "supervised_error")
        c.add_edge(diff.out_port(0), loss)

    assert c.is_wired_correctly()
    return c
예제 #13
0
def comp(fwd: Arrow, right_inv: Arrow, DiffArrow=SquaredDifference):
    """Composition: pipe output of forward model into input of right inverse.

    Besides the outputs of right_inv, one extra error output labelled
    'supervised_error' is added per input x, holding
    DiffArrow(x, right_inv(fwd(x))) for the matching non-error output.

    Args:
        fwd: X -> Y
        right_inv: Y -> X x Error
        DiffArrow: constructor of the arrow measuring the difference
    Returns:
        X -> X x Error
    """
    c = CompositeArrow(name="fwd_to_right_inv")

    # Left boundary -> fwd
    for fwd_in in fwd.in_ports():
        boundary_in = c.add_port()
        make_in_port(boundary_in)
        c.add_edge(boundary_in, fwd_in)
        transfer_labels(fwd_in, boundary_in)

    # fwd -> right_inv
    for idx, fwd_out in enumerate(fwd.out_ports()):
        c.add_edge(fwd_out, right_inv.in_port(idx))

    # right_inv -> right boundary
    for inv_out in right_inv.out_ports():
        boundary_out = c.add_port()
        make_out_port(boundary_out)
        if is_error_port(inv_out):
            make_error_port(boundary_out)
        c.add_edge(inv_out, boundary_out)
        transfer_labels(inv_out, boundary_out)

    # Difference between X and right_inv(f(x)); one X per composite input
    x_hat_ports = [p for p in right_inv.out_ports() if not is_error_port(p)]
    boundary_ins = c.in_ports()
    assert len(x_hat_ports) == len(boundary_ins)
    for idx, boundary_in in enumerate(boundary_ins):
        diff = DiffArrow()
        c.add_edge(boundary_in, diff.in_port(0))
        c.add_edge(x_hat_ports[idx], diff.in_port(1))
        error_port = c.add_port()
        make_out_port(error_port)
        make_error_port(error_port)
        add_port_label(error_port, "supervised_error")
        c.add_edge(diff.out_port(0), error_port)

    assert c.is_wired_correctly()
    return c
예제 #14
0
파일: invert.py 프로젝트: llwu/reverseflow
def inner_invert(comp_arrow: CompositeArrow, port_attr: PortAttributes,
                 dispatch: Dict[Arrow, Callable]):
    """Construct a parametric inverse of comp_arrow

    Args:
        comp_arrow: Arrow to invert
        port_attr: attributes per port; read by `is_constant` and for the
            'shape' entry of each port of comp_arrow
        dispatch: Dict mapping arrow class to invert function
    Returns:
        (inv_comp_arrow, comp_port_map)
        inv_comp_arrow: A (approximate) parametric inverse of `comp_arrow`.
            Its port i corresponds to port i of comp_arrow, with direction
            flipped unless the port is constant. Param ports (in) and
            error ports (out) of the inverted sub-arrows follow.
        comp_port_map: identity map {i: i} over comp_arrow's ports"""
    # Empty compositon for inverse
    inv_comp_arrow = CompositeArrow(name="%s_inv" % comp_arrow.name)
    # import pdb; pdb.set_trace()

    # Add a port on inverse arrow for every port on arrow
    # (non-constant ports flip direction; constant ports keep it)
    for port in comp_arrow.ports():
        inv_port = inv_comp_arrow.add_port()
        if is_in_port(port):
            if is_constant(port, port_attr):
                make_in_port(inv_port)
            else:
                make_out_port(inv_port)
        elif is_out_port(port):
            if is_constant(port, port_attr):
                make_out_port(inv_port)
            else:
                make_in_port(inv_port)
        # Transfer port information
        # FIXME: What port_attr go transfered from port to inv_port
        # NOTE(review): port_attr[port] raises KeyError if the port has no
        # entry at all; only a missing 'shape' key is tolerated here
        if 'shape' not in port_attr[port]:
            print('WARNING: shape unknown for %s' % port)
        else:
            set_port_shape(inv_port, get_port_shape(port, port_attr))

    # invert each sub_arrow
    arrow_to_inv = dict()
    arrow_to_port_map = dict()
    for sub_arrow in comp_arrow.get_sub_arrows():
        inv_sub_arrow, port_map = invert_sub_arrow(sub_arrow, port_attr,
                                                   dispatch)
        # NOTE(review): this check runs after sub_arrow was already used
        assert sub_arrow is not None
        assert inv_sub_arrow.parent is None
        arrow_to_port_map[sub_arrow] = port_map
        arrow_to_inv[sub_arrow] = inv_sub_arrow

    # Add comp_arrow to inv
    assert comp_arrow is not None
    arrow_to_inv[comp_arrow] = inv_comp_arrow
    comp_port_map = {i: i for i in range(comp_arrow.num_ports())}
    arrow_to_port_map[comp_arrow] = comp_port_map

    # Then, rewire up all the edges: for each original edge, exactly one of
    # the two inverse ports must project and the other must receive
    for out_port, in_port in comp_arrow.edges.items():
        left_inv_port = get_inv_port(out_port, arrow_to_port_map, arrow_to_inv)
        right_inv_port = get_inv_port(in_port, arrow_to_port_map, arrow_to_inv)
        both = [left_inv_port, right_inv_port]
        projecting = list(
            filter(lambda x: would_project(x, inv_comp_arrow), both))
        receiving = list(
            filter(lambda x: would_receive(x, inv_comp_arrow), both))
        assert len(projecting) == 1, "Should be only 1 projecting"
        assert len(receiving) == 1, "Should be only 1 receiving"
        inv_comp_arrow.add_edge(projecting[0], receiving[0])

    # `transforms` is a module-level sequence of callables (defined
    # elsewhere), each applied to the inverse in order
    for transform in transforms:
        transform(inv_comp_arrow)

    # Create new ports on inverse compositions for parametric and error ports
    # (these sub-arrow ports must still be unconnected at this point)
    for sub_arrow in inv_comp_arrow.get_sub_arrows():
        for port in sub_arrow.ports():
            if is_param_port(port):
                assert port not in inv_comp_arrow.edges.keys()
                assert port not in inv_comp_arrow.edges.values()
                param_port = inv_comp_arrow.add_port()
                inv_comp_arrow.add_edge(param_port, port)
                make_in_port(param_port)
                make_param_port(param_port)
            elif is_error_port(port):
                assert port not in inv_comp_arrow.edges.keys()
                assert port not in inv_comp_arrow.edges.values()
                error_port = inv_comp_arrow.add_port()
                inv_comp_arrow.add_edge(port, error_port)
                make_out_port(error_port)
                make_error_port(error_port)

    return inv_comp_arrow, comp_port_map