Esempio n. 1
0
def std_symbt_pred(arr: "SparseToDenseArrow", port_attr: PortAttributes):
    # sparse_indices: value
    # output_shape: value
    # sparse_values: SymbolicTensor
    a = ports_has(arr.in_ports()[0:2], 'value', port_attr)
    b = port_has(arr.in_port(2), 'symbolic_tensor', port_attr)
    # Hack, a FIXME to propagate should stop repropagation
    c = not port_has(arr.out_port(0), 'symbolic_tensor', port_attr)
    return a and b and c
Esempio n. 2
0
def add_pred1(arr: "AddArrow", port_attr: PortAttributes):
    return ports_has(arr.in_ports(), 'value', port_attr)
Esempio n. 3
0
def div_pred1(arr: "DivArrow", port_attr: PortAttributes):
    return ports_has(arr.in_ports(), 'value', port_attr) and np.all(
        port_attr[arr.in_port(1)]['value'])
Esempio n. 4
0
def mul_pred3(arr: "MulArrow", port_attr: PortAttributes):
    ports = [arr.in_ports()[1], arr.out_ports()[0]]
    # TODO: think harder about the zeros case
    return ports_has(ports, 'value', port_attr) and np.all(
        port_attr[ports[0]]['value'])
Esempio n. 5
0
def std_pred2(arr: "SparseToDenseArrow", port_attr: PortAttributes):
    return ports_has(arr.in_ports(), 'value', port_attr)
Esempio n. 6
0
def broadcast_pred(arr: Arrow, port_attr: PortAttributes):
    return ports_has(arr.in_ports(), 'shape', port_attr)
Esempio n. 7
0
def neg_fwd_pred(arr: "NegArrow", port_attr: PortAttributes):
    return ports_has(arr.in_ports(), 'value', port_attr)
Esempio n. 8
0
def pow_pred2(arr: "PowArrow", port_attr: PortAttributes):
    ports = [arr.in_ports()[0], arr.out_ports()[0]]
    return ports_has(ports, 'value', port_attr)
Esempio n. 9
0
def reshape_pred3(arr: "ReshapeArrow", port_attr: PortAttributes):
    return ports_has(arr.in_ports()[1:2], 'value', port_attr)
Esempio n. 10
0
def reshape_pred2(arr: "ReshapeArrow", port_attr: PortAttributes):
    return ports_has(arr.out_ports()[:1], 'shape', port_attr)
Esempio n. 11
0
def snd_pred4(arr: "ScatterNdArrow", port_attr: PortAttributes):
    return ports_has(arr.out_ports(), 'shape', port_attr)
Esempio n. 12
0
def snd_pred2(arr: "ScatterNdArrow", port_attr: PortAttributes):
    return ports_has(arr.in_ports(), 'value', port_attr)
Esempio n. 13
0
def std_pred4(arr: "SparseToDenseArrow", port_attr: PortAttributes):
    return ports_has(arr.out_ports(), 'shape', port_attr)
Esempio n. 14
0
def std_pred3(arr: "SparseToDenseArrow", port_attr: PortAttributes):
    return ports_has(arr.in_ports()[0:1], 'shape', port_attr)
Esempio n. 15
0
def div_pred2(arr: "DivArrow", port_attr: PortAttributes):
    ports = [arr.in_ports()[0], arr.out_ports()[0]]
    return ports_has(ports, 'value', port_attr) and np.all(
        port_attr[arr.out_port(0)]['value'])
Esempio n. 16
0
def fwd_pred(arr: "SelectArrow", port_attr: PortAttributes):
    return ports_has(arr.in_ports(), 'value', port_attr)
Esempio n. 17
0
def pow_pred1(arr: "PowArrow", port_attr: PortAttributes):
    return ports_has(arr.in_ports(), 'value', port_attr)
Esempio n. 18
0
def gathernd_shape_pred(arr: "GatherArrow", port_attr: PortAttributes):
    # FIXME: Can we infer shaep from output or aoutput and one input?
    return ports_has(arr.in_ports(), 'shape', port_attr)
Esempio n. 19
0
def log_bwd_pred(arr: "LogArrow", port_attr: PortAttributes):
    return ports_has(arr.out_ports(), 'value', port_attr)
Esempio n. 20
0
def mul_pred1(arr: "MulArrow", port_attr: PortAttributes):
    return ports_has(arr.in_ports(), 'value', port_attr)
Esempio n. 21
0
def add_pred3(arr: "AddArrow", port_attr: PortAttributes):
    ports = [arr.in_ports()[1], arr.out_ports()[0]]
    return ports_has(ports, 'value', port_attr)
Esempio n. 22
0
def mul_pred2(arr: "MulArrow", port_attr: PortAttributes):
    ports = [arr.in_ports()[0], arr.out_ports()[0]]
    return ports_has(ports, 'value', port_attr) and np.all(
        port_attr[ports[0]]['value'])
Esempio n. 23
0
def sub_pred1(arr: "SubArrow", port_attr: PortAttributes):
    return ports_has(arr.in_ports(), 'value', port_attr)
Esempio n. 24
0
def constant_pred(arr, port_attr: PortAttributes):
    return ports_has(arr.in_ports(), 'constant', port_attr)