예제 #1
0
파일: apply.py 프로젝트: llwu/reverseflow
def apply_backwards(arrow: Arrow,
                    outputs: List[np.ndarray],
                    port_attr=None) -> dict:
    """
    Takes out_port vals (excluding errors) and returns in_port vals (including params).

    Args:
        arrow: the arrow to run backwards.
        outputs: one value per non-error out_port, matched positionally to
            ``arrow.out_ports()`` with error ports skipped. Extra values are
            ignored; too few values raises IndexError.
        port_attr: optional port-attribute dict to seed propagation with.
            When None, one is built with ``propagate(arrow)``.

    Returns:
        A dict mapping each in_port of ``arrow`` that holds a 'value'
        attribute after propagation to that value. In_ports without a
        propagated value are omitted.

    FIXME: Mutates port_attr. A caller-supplied dict gets 'value' entries
    written into it for every out_port.
    """
    out_ports = [out_port for out_port in arrow.out_ports()
                 if not is_error_port(out_port)]
    if port_attr is None:
        port_attr = propagate(arrow)

    # Seed every non-error out_port with the value supplied by the caller.
    for i, out_port in enumerate(out_ports):
        port_attr.setdefault(out_port, {})['value'] = outputs[i]

    # Error ports get no caller-supplied value. They are set to zero: an
    # array of the known shape if there is one, otherwise the scalar 0.
    for out_port in arrow.out_ports():
        if not is_error_port(out_port):
            continue
        attrs = port_attr.setdefault(out_port, {})
        if 'shape' in attrs:
            attrs['value'] = np.zeros(attrs['shape'])
        else:
            # FIXME: there has to be a better way to do this
            print("WARNING: shape of error port unknown: %s" % (out_port))
            attrs['value'] = 0

    port_attr = propagate(arrow, port_attr)
    vals = extract_attribute('value', port_attr)
    return {port: vals[port] for port in arrow.in_ports() if port in vals}
예제 #2
0
def constant_dispatch(arr, port_attr: PortAttributes, state=None):
    """Propagate the 'constant' flag across every port of ``arr``.

    If every port that already carries a 'constant' flag is CONST, then all
    ports of the arrow are marked CONST. Note that this includes the case
    where no port carries a flag yet, since ``all`` of nothing is True.
    Otherwise, ports keep their known flag and any port without one is
    marked VAR.
    """
    known = extract_attribute('constant', port_attr)
    if all(flag == CONST for flag in known.values()):
        return {port: {'constant': CONST} for port in arr.ports()}
    result = {}
    for port in arr.ports():
        flag = known[port] if port in known else VAR
        result[port] = {'constant': flag}
    return result
예제 #3
0
def mul_symbt_disp(arr: "MulArrow", port_attr: PortAttributes):
    """Pass a 'symbolic_tensor' attribute through a multiplication arrow.

    Only the case where exactly one of the two in_ports carries a
    'symbolic_tensor' is handled. That tensor is copied unchanged to
    out_port 0.

    Raises:
        AssertionError: if both in_ports, or neither, carry a
            'symbolic_tensor'. It is raised explicitly rather than via
            ``assert`` so the check is not stripped under ``python -O``.
    """
    ptv = extract_attribute('symbolic_tensor', port_attr)
    lhs, rhs = arr.in_port(0), arr.in_port(1)
    lhs_known, rhs_known = lhs in ptv, rhs in ptv
    if lhs_known and rhs_known:
        raise AssertionError("Figure this out")
    if lhs_known:
        # NOTE: probably wrong -- the other factor is ignored, so nothing is
        # actually being multiplied.
        return {arr.out_port(0): {'symbolic_tensor': ptv[lhs]}}
    if rhs_known:
        return {arr.out_port(0): {'symbolic_tensor': ptv[rhs]}}
    raise AssertionError("why am i here")
예제 #4
0
def add_symbt_disp(arr: "AddArrow", port_attr: PortAttributes):
    """Pass a 'symbolic_tensor' attribute through an addition arrow.

    Only the case where exactly one of the two in_ports carries a
    'symbolic_tensor' is handled. That tensor is copied unchanged to
    out_port 0.

    Raises:
        AssertionError: if both in_ports, or neither, carry a
            'symbolic_tensor'. It is raised explicitly rather than via
            ``assert`` so the check is not stripped under ``python -O``.
    """
    ptv = extract_attribute('symbolic_tensor', port_attr)
    lhs, rhs = arr.in_port(0), arr.in_port(1)
    lhs_known, rhs_known = lhs in ptv, rhs in ptv
    if lhs_known and rhs_known:
        raise AssertionError("Figure this out")
    if lhs_known:
        # NOTE: seems wrong -- the other summand is ignored, so nothing is
        # actually being added.
        return {arr.out_port(0): {'symbolic_tensor': ptv[lhs]}}
    if rhs_known:
        return {arr.out_port(0): {'symbolic_tensor': ptv[rhs]}}
    raise AssertionError("why am i here")
예제 #5
0
def broadcast_dispatch(arr: Arrow, port_attr: PortAttributes):
    """Decide the output shape of a broadcasting arrow from its input shapes.

    Every out_port gets the same 'shape': the shape TensorFlow infers for
    ``tf.add`` applied to placeholders of the input shapes. Because
    ``tf.add`` is binary, this expects exactly two known input shapes.
    """
    pts = extract_attribute('shape', port_attr)
    in_shapes = list(extract(arr.in_ports(), pts).values())

    # FIXME: Broadcasting rules are complex, so cheat: build throwaway
    # placeholders in a private graph and let TF infer the result shape.
    g = tf.Graph()
    with g.as_default():
        placeholders = [
            tf.placeholder(dtype='float32', shape=in_shape)
            for in_shape in in_shapes
        ]
        summed = tf.add(*placeholders)
        shape = tuple(summed.get_shape().as_list())

    return {port: {'shape': shape} for port in arr.out_ports()}
예제 #6
0
def pow_dispatch2(arr: "PowArrow", port_attr: PortAttributes):
    """Recover the exponent (in_port 1) from the base and the output.

    The arrow computes ``out = base ** exponent``, with in_port 0 as the base
    and in_port 1 as the exponent. This matches the in_port-0 inverse
    ``base = out ** (1 / exponent)``. Solving for the exponent gives
    ``exponent = log(out) / log(base)``.
    """
    ptv = extract_attribute('value', port_attr)
    i = arr.in_ports()
    o = arr.out_ports()
    # log_base(out) = log(out) / log(base); the output goes in the numerator.
    return {i[1]: {'value': np.log(ptv[o[0]]) / np.log(ptv[i[0]])}}
예제 #7
0
def sub_dispatch1(arr: "SubArrow", port_attr: PortAttributes):
    """Forward-evaluate subtraction: out_port 0 <- in_port 0 - in_port 1."""
    values = extract_attribute('value', port_attr)
    ins, outs = arr.in_ports(), arr.out_ports()
    difference = values[ins[0]] - values[ins[1]]
    return {outs[0]: {'value': difference}}
예제 #8
0
def add_dispatch3(arr: "AddArrow", port_attr: PortAttributes):
    """Invert addition for the first summand: in_port 0 <- out_port 0 - in_port 1."""
    values = extract_attribute('value', port_attr)
    ins, outs = arr.in_ports(), arr.out_ports()
    total = values[outs[0]]
    other = values[ins[1]]
    return {ins[0]: {'value': total - other}}
예제 #9
0
def neg_fwd_disp(arr: "NegArrow", port_attr: PortAttributes):
    """Forward-evaluate negation: out_port 0 gets ``-value`` of in_port 0."""
    in_val = port_attr[arr.in_port(0)]['value']
    return {arr.out_port(0): {'value': -in_val}}
예제 #10
0
def log_fwd_disp(arr: "LogArrow", port_attr: PortAttributes):
    """Forward-evaluate natural log: out_port 0 gets ``np.log`` of in_port 0."""
    in_val = port_attr[arr.in_port(0)]['value']
    return {arr.out_port(0): {'value': np.log(in_val)}}
예제 #11
0
def pow_dispatch3(arr: "PowArrow", port_attr: PortAttributes):
    """Recover the base (in_port 0) as ``out ** (1 / exponent)``.

    The exponent is read from in_port 1 and the output from out_port 0.
    """
    values = extract_attribute('value', port_attr)
    ins, outs = arr.in_ports(), arr.out_ports()
    power = values[outs[0]]
    exponent = values[ins[1]]
    base = power ** (1.0 / exponent)
    return {ins[0]: {'value': base}}
예제 #12
0
def gather_shape_dispatch(arr: "GatherArrow", port_attr: PortAttributes):
    """Infer a gather arrow's output shape.

    In_port 0 holds the params and in_port 1 the indices. The output shape
    is ``indices.shape + params.shape[1:]``.
    """
    shapes = extract_attribute('shape', port_attr)
    ports = arr.in_ports()
    idx_shape = const_to_tuple(shapes[ports[1]])
    src_shape = const_to_tuple(shapes[ports[0]])
    out_shape = idx_shape + src_shape[1:]
    return {arr.out_ports()[0]: {'shape': out_shape}}
예제 #13
0
def div_dispatch2(arr: "DivArrow", port_attr: PortAttributes):
    """Invert division for the divisor: in_port 1 <- in_port 0 / out_port 0."""
    values = extract_attribute('value', port_attr)
    ins, outs = arr.in_ports(), arr.out_ports()
    dividend = values[ins[0]]
    quotient = values[outs[0]]
    return {ins[1]: {'value': dividend / quotient}}
예제 #14
0
def gathernd_shape_dispatch(arr: "GatherArrow", port_attr: PortAttributes):
    """Infer a gather_nd arrow's output shape.

    In_port 0 holds the params and in_port 1 the indices. The output shape
    is ``indices.shape[:-1] + params.shape[indices.shape[-1]:]``.
    """
    shapes = extract_attribute('shape', port_attr)
    ports = arr.in_ports()
    idx_shape = const_to_tuple(shapes[ports[1]])
    src_shape = const_to_tuple(shapes[ports[0]])
    index_depth = idx_shape[-1]
    out_shape = idx_shape[:-1] + src_shape[index_depth:]
    return {arr.out_ports()[0]: {'shape': out_shape}}
예제 #15
0
def stack_shape_disp(arr: "StackArrow", port_attr: PortAttributes):
    """Infer a stack arrow's out_port shape from its in_port shapes.

    Stacking ``n`` tensors that all have shape ``s`` along ``axis`` gives
    shape ``s[:axis] + (n,) + s[axis:]``, following ``tf.stack`` semantics.
    Negative axes count from the end of the output rank.

    Raises:
        ValueError: if there are no in_ports, the in_port shapes differ, or
            ``axis`` is outside ``[-(rank + 1), rank + 1)``.
    """
    in_shapes = [tuple(port_attr[port]['shape']) for port in arr.in_ports()]
    if not in_shapes:
        raise ValueError("stack_shape_disp: arrow has no in_ports")
    first = in_shapes[0]
    if any(shape != first for shape in in_shapes[1:]):
        raise ValueError(
            "stack_shape_disp: in_port shapes differ: %s" % (in_shapes,))
    # NOTE(review): assumes StackArrow stores tf.stack's axis as
    # ``arr.axis`` -- confirm against the StackArrow definition.
    axis = arr.axis
    out_rank = len(first) + 1
    if not -out_rank <= axis < out_rank:
        raise ValueError(
            "stack_shape_disp: axis %s out of range for rank %d"
            % (axis, out_rank))
    axis %= out_rank  # normalize a negative axis
    new_shape = first[:axis] + (len(in_shapes),) + first[axis:]
    return {arr.out_port(0): {'shape': new_shape}}
예제 #16
0
def fwd_disp(arr: "SelectArrow", port_attr: PortAttributes):
    """Forward-evaluate select using ``np.where`` semantics.

    Where in_port 0 (the condition) is true, the result takes in_port 1;
    elsewhere it takes in_port 2. The result is written to out_port 0.
    """
    values = extract_attribute('value', port_attr)
    ins, outs = arr.in_ports(), arr.out_ports()
    cond, if_true, if_false = values[ins[0]], values[ins[1]], values[ins[2]]
    return {outs[0]: {'value': np.where(cond, if_true, if_false)}}
예제 #17
0
def reshape_eval_dispatch(arr: "ReshapeArrow", port_attr: PortAttributes):
    """Forward-evaluate reshape: out_port 0 <- in_port 0 reshaped to in_port 1."""
    values = extract_attribute('value', port_attr)
    ins, outs = arr.in_ports(), arr.out_ports()
    tensor, target_shape = values[ins[0]], values[ins[1]]
    return {outs[0]: {'value': np.reshape(tensor, target_shape)}}