Exemplo n.º 1
0
def apply_backwards(arrow: Arrow,
                    outputs: List[np.ndarray],
                    port_attr=None) -> dict:
    """
    Takes out_port vals (excluding errors) and returns in_port vals (including params).

    Every error out_port is fixed to zero: ``np.zeros(shape)`` when its
    'shape' attribute is known, and scalar ``0`` (with a printed warning)
    otherwise. All values are then pushed through ``arrow`` with ``propagate``.

    Args:
        arrow: Arrow to run backwards
        outputs: one value per non-error out_port of ``arrow``, in the order
            given by ``arrow.out_ports()``
        port_attr: optional port -> attribute-dict mapping (e.g. the result of
            ``propagate(arrow)``). It is copied before any 'value' entries are
            written, so the caller's mapping is not modified by this function.
    Returns:
        dict mapping each in_port of ``arrow`` whose 'value' was determined by
        propagation to that value; in_ports without a value are omitted.
    Raises:
        ValueError: if ``len(outputs)`` differs from the number of non-error
            out_ports of ``arrow``.
    """
    out_ports = [
        out_port for out_port in arrow.out_ports()
        if not is_error_port(out_port)
    ]
    if len(outputs) != len(out_ports):
        raise ValueError(
            "expected %d output values (one per non-error out_port), got %d"
            % (len(out_ports), len(outputs)))

    if port_attr is None:
        port_attr = propagate(arrow)
    else:
        # This function only (re)assigns keys of the per-port dicts, so a
        # one-level copy keeps the caller's data unmodified.
        # NOTE(review): assumes propagate() does not mutate attribute values
        # in place — confirm.
        port_attr = {port: dict(attrs) for port, attrs in port_attr.items()}

    for out_port, value in zip(out_ports, outputs):
        port_attr.setdefault(out_port, {})['value'] = value

    for out_port in arrow.out_ports():
        if not is_error_port(out_port):
            continue
        attrs = port_attr.setdefault(out_port, {})
        if 'shape' in attrs:
            # Run backwards assuming zero error on every error port
            attrs['value'] = np.zeros(attrs['shape'])
        else:
            # FIXME: there has to be a better way to do this
            print("WARNING: shape of error port unknown: %s" % (out_port))
            attrs['value'] = 0

    port_attr = propagate(arrow, port_attr)
    vals = extract_attribute('value', port_attr)
    return {port: vals[port] for port in arrow.in_ports() if port in vals}
Exemplo n.º 2
0
def inv_fwd_loss_arrow(arrow: Arrow,
                       inverse: Arrow,
                       DiffArrow=SquaredDifference) -> CompositeArrow:
    """
    Arrow which computes |f(f^-1(y)) - y|

    Args:
        arrow: Forward function f
        inverse: (Possibly parametric) inverse f^-1. Its non-param in_ports
            take y, and its non-error out_ports feed the in_ports of
            ``arrow`` in order.
        DiffArrow: Arrow class whose instances compare their two inputs
            (default SquaredDifference)
    Returns:
        CompositeArrow with one in_port per in_port of ``inverse`` (param
        ports kept as params). Its out_ports are:
        - the non-error outputs of ``inverse``;
        - then the error ports of ``inverse``, labelled "sub_arrow_error";
        - then one error port per out_port of ``arrow``, labelled
          "inv_fwd_error".
    """
    c = CompositeArrow(name="%s_inv_fwd_loss" % arrow.name)

    # Make all in_ports of inverse inputs to composition, remembering which
    # composite in_ports carry y (i.e. are not parameters).
    y_in_ports = []
    for inv_in_port in inverse.in_ports():
        in_port = c.add_port()
        make_in_port(in_port)
        if is_param_port(inv_in_port):
            make_param_port(in_port)
        else:
            y_in_ports.append(in_port)
        c.add_edge(in_port, inv_in_port)

    # Connect non-error out_ports of inverse to in_ports of f. Index only over
    # the non-error ports, so an interleaved error port cannot shift the
    # pairing.
    inv_out_ports = [p for p in inverse.out_ports() if not is_error_port(p)]
    for i, out_port in enumerate(inv_out_ports):
        c.add_edge(out_port, arrow.in_port(i))
        # add edge from inverse output to composition output
        c_out_port = c.add_port()
        make_out_port(c_out_port)
        c.add_edge(out_port, c_out_port)

    # Pass errors (if any) of parametric inverse through as error_ports
    for out_port in inverse.out_ports():
        if is_error_port(out_port):
            error_port = c.add_port()
            make_out_port(error_port)
            make_error_port(error_port)
            add_port_label(error_port, "sub_arrow_error")
            c.add_edge(out_port, error_port)

    # Find the difference between the y inputs (not the params) and the
    # outputs of f; make an error port for each.
    for i, out_port in enumerate(arrow.out_ports()):
        diff = DiffArrow()
        c.add_edge(y_in_ports[i], diff.in_port(0))
        c.add_edge(out_port, diff.in_port(1))
        error_port = c.add_port()
        make_out_port(error_port)
        make_error_port(error_port)
        add_port_label(error_port, "inv_fwd_error")
        c.add_edge(diff.out_port(0), error_port)

    assert c.is_wired_correctly()
    return c
Exemplo n.º 3
0
def unparam(arrow: Arrow, nnet: Arrow = None):
    """Remove the parametric inputs of ``arrow`` by predicting them with ``nnet``.

    Every non-param in_port of ``arrow`` becomes an in_port of the result and
    is also fed to ``nnet``. The outputs of ``nnet`` drive the param ports.

    Args:
        arrow: Y x Theta -> X
        nnet: Y -> Theta. Defaults to a TfArrow with one in_port per
            non-param in_port and one out_port per param port of ``arrow``.
    Returns:
        Y -> X composite whose out_ports mirror those of ``arrow`` (error
        ports stay error ports, labels are transferred).
    """
    c = CompositeArrow(name="%s_unparam" % arrow.name)

    plain_ports, param_ports = [], []
    for port in arrow.in_ports():
        if is_param_port(port):
            param_ports.append(port)
        else:
            plain_ports.append(port)

    if nnet is None:
        nnet = TfArrow(n_in_ports=len(plain_ports),
                       n_out_ports=len(param_ports))

    # Each ordinary input feeds both the wrapped arrow and the network
    for idx, inner in enumerate(plain_ports):
        outer = c.add_port()
        make_in_port(outer)
        transfer_labels(inner, outer)
        c.add_edge(outer, inner)
        c.add_edge(outer, nnet.in_port(idx))

    # The network's outputs supply the parameters, position by position
    for idx, param_port in enumerate(param_ports):
        c.add_edge(nnet.out_port(idx), param_port)

    # Expose every output of the wrapped arrow
    for inner in arrow.out_ports():
        outer = c.add_port()
        make_out_port(outer)
        if is_error_port(inner):
            make_error_port(outer)
        transfer_labels(inner, outer)
        c.add_edge(inner, outer)

    assert c.is_wired_correctly()
    return c
Exemplo n.º 4
0
def get_param_pairs(inv,
                    voxel_grids,
                    batch_size,
                    n,
                    port_attr=None,
                    pickle_to=None):
    """Sample ``n`` random batches from ``voxel_grids`` and run ``inv`` backwards.

    Args:
        inv: parametric inverse arrow
        voxel_grids: array of samples; batch members are drawn (with
            replacement) from axis 0
        batch_size: number of samples per batch
        n: number of batches
        port_attr: optional port attributes. Only the 'shape' of each
            non-error out_port of ``inv`` is read from it. Computed with
            ``propagate(inv)`` when omitted.
        pickle_to: optional path; when given, ``(inputs, params)`` is pickled
            there
    Returns:
        (inputs, params): for batch k, ``inputs[k]`` is the list of float32
        arrays fed to the non-error out_ports of ``inv``, and ``params[k]`` is
        the list of values obtained for its param in_ports.

    FIXME: mutates port_attr.
    NOTE(review): apply_backwards is called with port_attr=None, so the
    supplied port_attr is only read here (for shapes) and not mutated by the
    visible code. Each batch therefore re-runs propagate(inv) without any
    shape information the caller supplied — confirm this is intended.
    """
    if port_attr is None:
        port_attr = propagate(inv)

    shapes = []
    for port in inv.out_ports():
        if not is_error_port(port):
            shapes.append(port_attr[port]['shape'])

    def sample_batch():
        # One random batch, reshaped once per non-error out_port
        picked = np.random.randint(0, voxel_grids.shape[0], size=batch_size)
        return [voxel_grids[picked].reshape(shape).astype(np.float32)
                for shape in shapes]

    inputs = []
    params = []
    for _ in range(n):
        batch = sample_batch()
        inputs.append(batch)
        backwards = apply_backwards(inv, batch, port_attr=None)
        params.append([backwards[port] for port in inv.in_ports()
                       if is_param_port(port)])

    if pickle_to is not None:
        with open(pickle_to, 'wb') as f:
            pickle.dump((inputs, params), f)
    return inputs, params
Exemplo n.º 5
0
 def error_ports(self):
     """
     Return the error ports of this Arrow.
     Returns:
         List of ErrorPorts, in the order they appear in ``self._ports``
     """
     return list(filter(pa.is_error_port, self._ports))
Exemplo n.º 6
0
def comp(fwd: Arrow, right_inv: Arrow, DiffArrow=SquaredDifference):
    """Composition: pipe the output of a forward model into its right inverse.

    Args:
        fwd: X -> Y
        right_inv: Y -> X x Error
        DiffArrow: Arrow class used to compare each input x with its
            reconstruction (default SquaredDifference)
    Returns:
        X -> X composite. Its out_ports are the out_ports of ``right_inv``
        (error ports preserved, labels transferred), followed by one error
        port per input labelled "supervised_error".
    """
    c = CompositeArrow(name="fwd_to_right_inv")

    # Left boundary -> fwd
    for fwd_in in fwd.in_ports():
        boundary = c.add_port()
        make_in_port(boundary)
        c.add_edge(boundary, fwd_in)
        transfer_labels(fwd_in, boundary)

    # fwd -> right_inv, position by position
    for idx, fwd_out in enumerate(fwd.out_ports()):
        c.add_edge(fwd_out, right_inv.in_port(idx))

    # right_inv -> right boundary
    for inv_out in right_inv.out_ports():
        boundary = c.add_port()
        make_out_port(boundary)
        if is_error_port(inv_out):
            make_error_port(boundary)
        c.add_edge(inv_out, boundary)
        transfer_labels(inv_out, boundary)

    # Compare each x with right_inv(fwd(x)), one error port per input
    reconstructions = [port for port in right_inv.out_ports()
                       if not is_error_port(port)]  # len(X)
    assert len(reconstructions) == len(c.in_ports())
    for x_port, x_hat in zip(c.in_ports(), reconstructions):
        diff = DiffArrow()
        c.add_edge(x_port, diff.in_port(0))
        c.add_edge(x_hat, diff.in_port(1))
        error_port = c.add_port()
        make_out_port(error_port)
        make_error_port(error_port)
        add_port_label(error_port, "supervised_error")
        c.add_edge(diff.out_port(0), error_port)

    assert c.is_wired_correctly()
    return c
Exemplo n.º 7
0
def test_apply_backwards():
    """Invert the arrow from ``test_twoxyplusx`` and run it backwards on random data.

    Returns:
        (orig, arrow, outputs, in_vals): the forward arrow, its inverse, the
        random 2x2 arrays fed to each non-error out_port of the inverse, and
        the result of ``apply_backwards``.
    """
    orig = test_twoxyplusx()
    arrow = invert(orig)
    n_outputs = sum(1 for port in arrow.out_ports() if not is_error_port(port))
    outputs = [np.random.randn(2, 2) for _ in range(n_outputs)]
    in_vals = apply_backwards(arrow, outputs)
    return orig, arrow, outputs, in_vals
Exemplo n.º 8
0
def default_grans():
    """Return the default port predicates that select which tensors to grab.

    Returns:
        A fresh dict mapping each category name ('input', 'param', 'error',
        'output') to a one-argument predicate over ports. 'input' excludes
        param ports.
    """
    return dict(
        input=lambda p: is_in_port(p) and not is_param_port(p),
        param=lambda p: is_param_port(p),
        error=lambda p: is_error_port(p),
        output=lambda p: is_out_port(p),
    )
Exemplo n.º 9
0
def supervised_loss_arrow(arrow: Arrow,
                          DiffArrow=SquaredDifference) -> CompositeArrow:
    """
    Creates an arrow that computes |f(y) - x| against supplied training outputs.

    Args:
        arrow: f: Y -> X - The arrow to modify
        DiffArrow: d: X x X -> R - Arrow for computing difference
    Returns:
        Arrow with the same inputs and outputs as ``arrow``:
        - param ports stay param ports;
        - error ports pass through with their labels;
        - normal outputs pass through with their labels.
        In addition, for every non-error out_port of ``arrow`` it takes an
        extra input labelled 'train_output', which should contain examples in
        X. It also returns an extra error output labelled 'supervised_error',
        equal to d(train_output, f(y)).
    """
    c = CompositeArrow(name="%s_supervised" % arrow.name)

    # Pipe all inputs of composite to inputs of arrow (params stay params)
    for in_port in arrow.in_ports():
        c_in_port = c.add_port()
        make_in_port(c_in_port)
        if is_param_port(in_port):
            make_param_port(c_in_port)
        c.add_edge(c_in_port, in_port)

    for out_port in arrow.out_ports():
        if is_error_port(out_port):
            # if its an error port just pass through
            error_port = c.add_port()
            make_out_port(error_port)
            make_error_port(error_port)
            transfer_labels(out_port, error_port)
            c.add_edge(out_port, error_port)
        else:
            # If its normal outport then pass through, keeping its labels
            # (consistent with the error-port branch above)
            c_out_port = c.add_port()
            make_out_port(c_out_port)
            transfer_labels(out_port, c_out_port)
            c.add_edge(out_port, c_out_port)

            # And compute the error against a supplied training output
            diff = DiffArrow()
            in_port = c.add_port()
            make_in_port(in_port)
            add_port_label(in_port, "train_output")
            c.add_edge(in_port, diff.in_port(0))
            c.add_edge(out_port, diff.in_port(1))
            error_port = c.add_port()
            make_out_port(error_port)
            make_error_port(error_port)
            add_port_label(error_port, "supervised_error")
            c.add_edge(diff.out_port(0), error_port)

    assert c.is_wired_correctly()
    return c
Exemplo n.º 10
0
    def wrap(a: Arrow):
        """Wrap an arrow in a composite arrow with an identical port interface.

        Each port of ``a`` gets a matching boundary port on the composite.
        The boundary port keeps the same in/param/out/error kinds and labels,
        and is wired to ``a`` in the direction of data flow.
        """
        wrapper = CompositeArrow(name=a.name)

        def mirror(inner_port):
            # Create one boundary port matching inner_port and connect it
            outer_port = wrapper.add_port()
            if is_in_port(inner_port):
                make_in_port(outer_port)
                wrapper.add_edge(outer_port, inner_port)
            if is_param_port(inner_port):
                make_param_port(outer_port)
            if is_out_port(inner_port):
                make_out_port(outer_port)
                wrapper.add_edge(inner_port, outer_port)
            if is_error_port(inner_port):
                make_error_port(outer_port)
            transfer_labels(inner_port, outer_port)

        for inner_port in a.ports():
            mirror(inner_port)

        assert wrapper.is_wired_correctly()
        return wrapper