Exemple #1
0
def std_symbt_pred(arr: "SparseToDenseArrow", port_attr: PortAttributes):
    """Decide whether the symbolic-tensor rule for a SparseToDense arrow applies.

    Input ports:
      0: sparse_indices -- must carry a 'value'
      1: output_shape   -- must carry a 'value'
      2: sparse_values  -- must carry a 'symbolic_tensor'
    Output port 0 must NOT already carry a 'symbolic_tensor'.

    All three checks are evaluated before combining. They are joined with
    ``and``, so the first falsy check result is returned unchanged.
    """
    index_and_shape_known = ports_has(arr.in_ports()[:2], 'value', port_attr)
    values_are_symbolic = port_has(arr.in_port(2), 'symbolic_tensor', port_attr)
    # HACK/FIXME: refuse to fire once the output already holds a symbolic
    # tensor, so the rule is not re-applied. Ideally the propagation loop
    # itself would prevent repropagation instead of each predicate.
    output_still_unset = not port_has(
        arr.out_port(0), 'symbolic_tensor', port_attr)
    return index_and_shape_known and values_are_symbolic and output_still_unset
Exemple #2
0
def add_pred1(arr: "AddArrow", port_attr: PortAttributes):
    """Report whether every input port of the Add arrow has a 'value'.

    The result is whatever ``ports_has`` returns for ``arr.in_ports()``.
    """
    operands = arr.in_ports()
    return ports_has(operands, 'value', port_attr)
Exemple #3
0
def div_pred1(arr: "DivArrow", port_attr: PortAttributes):
    """Report whether a Div arrow's inputs are known and in-port 1 is all nonzero.

    Every input port must have a 'value'. If a port lacks one, the falsy
    ``ports_has`` result is returned unchanged. Otherwise the result is
    ``np.all`` over the in-port-1 value (presumably the divisor), which is
    truthy only when every element is nonzero.
    """
    inputs_known = ports_has(arr.in_ports(), 'value', port_attr)
    if not inputs_known:
        return inputs_known
    divisor = port_attr[arr.in_port(1)]['value']
    return np.all(divisor)
Exemple #4
0
def mul_pred3(arr: "MulArrow", port_attr: PortAttributes):
    """Report whether a Mul arrow's in-port 1 and out-port 0 have values,
    with the in-port-1 value entirely nonzero.

    If either port lacks a 'value', the falsy ``ports_has`` result is
    returned unchanged. Otherwise ``np.all`` over the in-port-1 value is
    returned. Presumably this gates recovering in-port 0 by division, which
    is why zeros are rejected.
    """
    known_factor = arr.in_ports()[1]
    product = arr.out_ports()[0]
    # TODO: think harder about the zeros case
    both_known = ports_has([known_factor, product], 'value', port_attr)
    if not both_known:
        return both_known
    return np.all(port_attr[known_factor]['value'])
Exemple #5
0
def std_pred2(arr: "SparseToDenseArrow", port_attr: PortAttributes):
    """Report whether all input ports of a SparseToDense arrow have a 'value'.

    Delegates to ``ports_has`` over ``arr.in_ports()``.
    """
    all_inputs = arr.in_ports()
    return ports_has(all_inputs, 'value', port_attr)
Exemple #6
0
def broadcast_pred(arr: Arrow, port_attr: PortAttributes):
    """Report whether every input port of ``arr`` has a 'shape'.

    Delegates to ``ports_has`` over ``arr.in_ports()``.
    """
    sources = arr.in_ports()
    return ports_has(sources, 'shape', port_attr)
Exemple #7
0
def neg_fwd_pred(arr: "NegArrow", port_attr: PortAttributes):
    """Report whether the Neg arrow's input port(s) have a 'value'.

    Delegates to ``ports_has`` over ``arr.in_ports()``.
    """
    operand_ports = arr.in_ports()
    return ports_has(operand_ports, 'value', port_attr)
Exemple #8
0
def pow_pred2(arr: "PowArrow", port_attr: PortAttributes):
    """Report whether a Pow arrow's in-port 0 and out-port 0 both have a 'value'.

    Delegates to ``ports_has`` over exactly those two ports.
    """
    base_port = arr.in_ports()[0]
    result_port = arr.out_ports()[0]
    return ports_has([base_port, result_port], 'value', port_attr)
Exemple #9
0
def reshape_pred3(arr: "ReshapeArrow", port_attr: PortAttributes):
    """Report whether the Reshape arrow's in-port 1 has a 'value'.

    A slice (not an index) is used, so a missing in-port 1 yields an empty
    list for ``ports_has`` instead of raising ``IndexError``.
    """
    second_input_only = arr.in_ports()[1:2]
    return ports_has(second_input_only, 'value', port_attr)
Exemple #10
0
def reshape_pred2(arr: "ReshapeArrow", port_attr: PortAttributes):
    """Report whether the Reshape arrow's first output port has a 'shape'.

    Uses a slice, so an arrow with no outputs passes an empty list to
    ``ports_has`` rather than raising.
    """
    leading_output = arr.out_ports()[:1]
    return ports_has(leading_output, 'shape', port_attr)
Exemple #11
0
def snd_pred4(arr: "ScatterNdArrow", port_attr: PortAttributes):
    """Report whether every output port of the ScatterNd arrow has a 'shape'.

    Delegates to ``ports_has`` over ``arr.out_ports()``.
    """
    outputs = arr.out_ports()
    return ports_has(outputs, 'shape', port_attr)
Exemple #12
0
def snd_pred2(arr: "ScatterNdArrow", port_attr: PortAttributes):
    """Report whether every input port of the ScatterNd arrow has a 'value'.

    Delegates to ``ports_has`` over ``arr.in_ports()``.
    """
    inputs = arr.in_ports()
    return ports_has(inputs, 'value', port_attr)
Exemple #13
0
def std_pred4(arr: "SparseToDenseArrow", port_attr: PortAttributes):
    """Report whether every output port of a SparseToDense arrow has a 'shape'.

    Delegates to ``ports_has`` over ``arr.out_ports()``.
    """
    dense_outputs = arr.out_ports()
    return ports_has(dense_outputs, 'shape', port_attr)
Exemple #14
0
def std_pred3(arr: "SparseToDenseArrow", port_attr: PortAttributes):
    """Report whether in-port 0 of a SparseToDense arrow has a 'shape'.

    Uses a slice, so an arrow without inputs passes an empty list to
    ``ports_has`` rather than raising.
    """
    first_input = arr.in_ports()[:1]
    return ports_has(first_input, 'shape', port_attr)
Exemple #15
0
def div_pred2(arr: "DivArrow", port_attr: PortAttributes):
    """Report whether a Div arrow's in-port 0 and out-port 0 have values,
    with the output value entirely nonzero.

    If either port lacks a 'value', the falsy ``ports_has`` result is
    returned unchanged. Otherwise ``np.all`` over the out-port-0 value is
    returned. Presumably this gates recovering in-port 1 as
    in-port-0 / output, which requires a nonzero output.
    """
    numerator_port = arr.in_ports()[0]
    quotient_port = arr.out_ports()[0]
    both_known = ports_has([numerator_port, quotient_port], 'value', port_attr)
    if not both_known:
        return both_known
    quotient = port_attr[arr.out_port(0)]['value']
    return np.all(quotient)
Exemple #16
0
def fwd_pred(arr: "SelectArrow", port_attr: PortAttributes):
    """Report whether every input port of the Select arrow has a 'value'.

    Delegates to ``ports_has`` over ``arr.in_ports()``.
    """
    select_inputs = arr.in_ports()
    return ports_has(select_inputs, 'value', port_attr)
Exemple #17
0
def pow_pred1(arr: "PowArrow", port_attr: PortAttributes):
    """Report whether every input port of the Pow arrow has a 'value'.

    Delegates to ``ports_has`` over ``arr.in_ports()``.
    """
    base_and_exponent = arr.in_ports()
    return ports_has(base_and_exponent, 'value', port_attr)
Exemple #18
0
def gathernd_shape_pred(arr: "GatherArrow", port_attr: PortAttributes):
    """Report whether every input port of the gather arrow has a 'shape'.

    Delegates to ``ports_has`` over ``arr.in_ports()``.
    """
    # FIXME: can the shape be inferred from the output, or from the output
    # plus one input, instead of requiring all inputs?
    gather_inputs = arr.in_ports()
    return ports_has(gather_inputs, 'shape', port_attr)
Exemple #19
0
def log_bwd_pred(arr: "LogArrow", port_attr: PortAttributes):
    """Report whether every output port of the Log arrow has a 'value'.

    Delegates to ``ports_has`` over ``arr.out_ports()``.
    """
    log_outputs = arr.out_ports()
    return ports_has(log_outputs, 'value', port_attr)
Exemple #20
0
def mul_pred1(arr: "MulArrow", port_attr: PortAttributes):
    """Report whether every input port of the Mul arrow has a 'value'.

    Delegates to ``ports_has`` over ``arr.in_ports()``.
    """
    factors = arr.in_ports()
    return ports_has(factors, 'value', port_attr)
Exemple #21
0
def add_pred3(arr: "AddArrow", port_attr: PortAttributes):
    """Report whether an Add arrow's in-port 1 and out-port 0 both have a 'value'.

    Delegates to ``ports_has`` over exactly those two ports.
    """
    second_addend = arr.in_ports()[1]
    total = arr.out_ports()[0]
    return ports_has([second_addend, total], 'value', port_attr)
Exemple #22
0
def mul_pred2(arr: "MulArrow", port_attr: PortAttributes):
    """Report whether a Mul arrow's in-port 0 and out-port 0 have values,
    with the in-port-0 value entirely nonzero.

    If either port lacks a 'value', the falsy ``ports_has`` result is
    returned unchanged. Otherwise ``np.all`` over the in-port-0 value is
    returned. Presumably this gates recovering in-port 1 by division.
    """
    known_factor = arr.in_ports()[0]
    product = arr.out_ports()[0]
    both_known = ports_has([known_factor, product], 'value', port_attr)
    if not both_known:
        return both_known
    return np.all(port_attr[known_factor]['value'])
Exemple #23
0
def sub_pred1(arr: "SubArrow", port_attr: PortAttributes):
    """Report whether every input port of the Sub arrow has a 'value'.

    Delegates to ``ports_has`` over ``arr.in_ports()``.
    """
    minuend_and_subtrahend = arr.in_ports()
    return ports_has(minuend_and_subtrahend, 'value', port_attr)
Exemple #24
0
def constant_pred(arr, port_attr: PortAttributes):
    """Report whether every input port of ``arr`` has a 'constant' entry.

    Delegates to ``ports_has`` over ``arr.in_ports()``.
    """
    incoming = arr.in_ports()
    return ports_has(incoming, 'constant', port_attr)