示例#1
0
def std_symbt_pred(arr: "SparseToDenseArrow", port_attr: PortAttributes):
    """Predicate guarding symbolic-tensor handling of a sparse-to-dense arrow.

    Combines three checks against ``port_attr``:

    * input ports 0 and 1 (sparse_indices, output_shape) pass ``ports_has``
      for 'value';
    * input port 2 (sparse_values) passes ``port_has`` for 'symbolic_tensor';
    * output port 0 does NOT pass ``port_has`` for 'symbolic_tensor'.

    Returns the ``and`` of the three results.
    """
    leading_inputs = arr.in_ports()[0:2]
    concrete_ok = ports_has(leading_inputs, 'value', port_attr)
    symbolic_ok = port_has(arr.in_port(2), 'symbolic_tensor', port_attr)
    # HACK: refuse to fire once the output already carries a symbolic tensor.
    # FIXME: propagation itself should be what stops re-propagation.
    output_unset = not port_has(arr.out_port(0), 'symbolic_tensor', port_attr)
    return concrete_ok and symbolic_ok and output_unset
示例#2
0
def add_pred1(arr: "AddArrow", port_attr: PortAttributes):
    """Ask ``ports_has`` whether all input ports of ``arr`` have a 'value'.

    The result of ``ports_has`` is returned unchanged.
    """
    operands = arr.in_ports()
    return ports_has(operands, 'value', port_attr)
示例#3
0
def div_pred1(arr: "DivArrow", port_attr: PortAttributes):
    """Predicate over the inputs of a division arrow.

    First asks ``ports_has`` whether every input port has a 'value'. If that
    result is falsy it is returned as-is; otherwise returns ``np.all`` of the
    'value' stored for input port 1, i.e. whether every element of it is
    truthy (presumably a nonzero-divisor check -- confirm with callers).
    """
    inputs_known = ports_has(arr.in_ports(), 'value', port_attr)
    if not inputs_known:
        return inputs_known
    second_input_value = port_attr[arr.in_port(1)]['value']
    return np.all(second_input_value)
示例#4
0
def mul_pred3(arr: "MulArrow", port_attr: PortAttributes):
    """Predicate over input port 1 and output port 0 of a multiply arrow.

    Asks ``ports_has`` whether both ports have a 'value'. If that result is
    falsy it is returned as-is; otherwise returns ``np.all`` of the 'value'
    stored for input port 1 (true only when every element is truthy).
    """
    operand = arr.in_ports()[1]
    product = arr.out_ports()[0]
    both_known = ports_has([operand, product], 'value', port_attr)
    if not both_known:
        return both_known
    # TODO: think harder about the zeros case
    return np.all(port_attr[operand]['value'])
示例#5
0
def std_pred2(arr: "SparseToDenseArrow", port_attr: PortAttributes):
    """Ask ``ports_has`` whether all input ports of ``arr`` have a 'value'.

    The result of ``ports_has`` is returned unchanged.
    """
    all_inputs = arr.in_ports()
    return ports_has(all_inputs, 'value', port_attr)
示例#6
0
def broadcast_pred(arr: Arrow, port_attr: PortAttributes):
    """Ask ``ports_has`` whether all input ports of ``arr`` have a 'shape'.

    The result of ``ports_has`` is returned unchanged.
    """
    inputs = arr.in_ports()
    return ports_has(inputs, 'shape', port_attr)
示例#7
0
def neg_fwd_pred(arr: "NegArrow", port_attr: PortAttributes):
    """Ask ``ports_has`` whether all input ports of ``arr`` have a 'value'.

    The result of ``ports_has`` is returned unchanged.
    """
    sources = arr.in_ports()
    return ports_has(sources, 'value', port_attr)
示例#8
0
def pow_pred2(arr: "PowArrow", port_attr: PortAttributes):
    """Ask ``ports_has`` whether input port 0 and output port 0 have a 'value'.

    The two ports are passed together as a list; the result of ``ports_has``
    is returned unchanged.
    """
    first_input = arr.in_ports()[0]
    result = arr.out_ports()[0]
    return ports_has([first_input, result], 'value', port_attr)
示例#9
0
def reshape_pred3(arr: "ReshapeArrow", port_attr: PortAttributes):
    """Ask ``ports_has`` whether input port 1 of ``arr`` has a 'value'.

    The port is selected with a slice, so ``ports_has`` receives a sequence
    rather than a bare port. Its result is returned unchanged.
    """
    second_input_only = arr.in_ports()[1:2]
    return ports_has(second_input_only, 'value', port_attr)
示例#10
0
def reshape_pred2(arr: "ReshapeArrow", port_attr: PortAttributes):
    """Ask ``ports_has`` whether output port 0 of ``arr`` has a 'shape'.

    The port is selected with a slice, so ``ports_has`` receives a sequence
    rather than a bare port. Its result is returned unchanged.
    """
    first_output_only = arr.out_ports()[:1]
    return ports_has(first_output_only, 'shape', port_attr)
示例#11
0
def snd_pred4(arr: "ScatterNdArrow", port_attr: PortAttributes):
    """Ask ``ports_has`` whether all output ports of ``arr`` have a 'shape'.

    The result of ``ports_has`` is returned unchanged.
    """
    outputs = arr.out_ports()
    return ports_has(outputs, 'shape', port_attr)
示例#12
0
def snd_pred2(arr: "ScatterNdArrow", port_attr: PortAttributes):
    """Ask ``ports_has`` whether all input ports of ``arr`` have a 'value'.

    The result of ``ports_has`` is returned unchanged.
    """
    inputs = arr.in_ports()
    return ports_has(inputs, 'value', port_attr)
示例#13
0
def std_pred4(arr: "SparseToDenseArrow", port_attr: PortAttributes):
    """Ask ``ports_has`` whether all output ports of ``arr`` have a 'shape'.

    The result of ``ports_has`` is returned unchanged.
    """
    dense_outputs = arr.out_ports()
    return ports_has(dense_outputs, 'shape', port_attr)
示例#14
0
def std_pred3(arr: "SparseToDenseArrow", port_attr: PortAttributes):
    """Ask ``ports_has`` whether input port 0 of ``arr`` has a 'shape'.

    The port is selected with a slice, so ``ports_has`` receives a sequence
    rather than a bare port. Its result is returned unchanged.
    """
    first_input_only = arr.in_ports()[0:1]
    return ports_has(first_input_only, 'shape', port_attr)
示例#15
0
def div_pred2(arr: "DivArrow", port_attr: PortAttributes):
    """Predicate over input port 0 and output port 0 of a division arrow.

    Asks ``ports_has`` whether both ports have a 'value'. If that result is
    falsy it is returned as-is; otherwise returns ``np.all`` of the 'value'
    stored for output port 0 (true only when every element is truthy).
    """
    numerator = arr.in_ports()[0]
    quotient = arr.out_ports()[0]
    both_known = ports_has([numerator, quotient], 'value', port_attr)
    if not both_known:
        return both_known
    # The output is looked up again through out_port(0), as in the original.
    quotient_value = port_attr[arr.out_port(0)]['value']
    return np.all(quotient_value)
示例#16
0
def fwd_pred(arr: "SelectArrow", port_attr: PortAttributes):
    """Ask ``ports_has`` whether all input ports of ``arr`` have a 'value'.

    The result of ``ports_has`` is returned unchanged.
    """
    inputs = arr.in_ports()
    return ports_has(inputs, 'value', port_attr)
示例#17
0
def pow_pred1(arr: "PowArrow", port_attr: PortAttributes):
    """Ask ``ports_has`` whether all input ports of ``arr`` have a 'value'.

    The result of ``ports_has`` is returned unchanged.
    """
    operands = arr.in_ports()
    return ports_has(operands, 'value', port_attr)
示例#18
0
def gathernd_shape_pred(arr: "GatherArrow", port_attr: PortAttributes):
    """Ask ``ports_has`` whether all input ports of ``arr`` have a 'shape'.

    The result of ``ports_has`` is returned unchanged.
    """
    # FIXME: Can we infer the shape from the output, or from the output and
    # one input?
    inputs = arr.in_ports()
    return ports_has(inputs, 'shape', port_attr)
示例#19
0
def log_bwd_pred(arr: "LogArrow", port_attr: PortAttributes):
    """Ask ``ports_has`` whether all output ports of ``arr`` have a 'value'.

    The result of ``ports_has`` is returned unchanged.
    """
    outputs = arr.out_ports()
    return ports_has(outputs, 'value', port_attr)
示例#20
0
def mul_pred1(arr: "MulArrow", port_attr: PortAttributes):
    """Ask ``ports_has`` whether all input ports of ``arr`` have a 'value'.

    The result of ``ports_has`` is returned unchanged.
    """
    factors = arr.in_ports()
    return ports_has(factors, 'value', port_attr)
示例#21
0
def add_pred3(arr: "AddArrow", port_attr: PortAttributes):
    """Ask ``ports_has`` whether input port 1 and output port 0 have a 'value'.

    The two ports are passed together as a list; the result of ``ports_has``
    is returned unchanged.
    """
    second_input = arr.in_ports()[1]
    total = arr.out_ports()[0]
    return ports_has([second_input, total], 'value', port_attr)
示例#22
0
def mul_pred2(arr: "MulArrow", port_attr: PortAttributes):
    """Predicate over input port 0 and output port 0 of a multiply arrow.

    Asks ``ports_has`` whether both ports have a 'value'. If that result is
    falsy it is returned as-is; otherwise returns ``np.all`` of the 'value'
    stored for input port 0 (true only when every element is truthy).
    """
    operand = arr.in_ports()[0]
    product = arr.out_ports()[0]
    both_known = ports_has([operand, product], 'value', port_attr)
    if not both_known:
        return both_known
    return np.all(port_attr[operand]['value'])
示例#23
0
def sub_pred1(arr: "SubArrow", port_attr: PortAttributes):
    """Ask ``ports_has`` whether all input ports of ``arr`` have a 'value'.

    The result of ``ports_has`` is returned unchanged.
    """
    operands = arr.in_ports()
    return ports_has(operands, 'value', port_attr)
示例#24
0
def constant_pred(arr, port_attr: PortAttributes):
    """Ask ``ports_has`` whether all input ports of ``arr`` have 'constant'.

    The result of ``ports_has`` is returned unchanged.
    """
    inputs = arr.in_ports()
    return ports_has(inputs, 'constant', port_attr)