コード例 #1
0
def test_contains():
    """Check ``pm.contains`` against a variable holding 'abc' and a fed placeholder."""
    with pm.Node() as graph:
        needle = pm.placeholder()
        haystack = pm.variable('abc')
        membership = pm.contains(haystack, needle)

    # A character that is part of 'abc' must evaluate truthy ...
    found = graph(membership, {needle: 'a'})
    assert found
    # ... while a character outside of it must evaluate falsy.
    missing = graph(membership, {needle: 'x'})
    assert not missing
コード例 #2
0
def test_new():
    """Build a nested graph ("main" containing "graph2") from pm primitives.

    The lines visible here only construct nodes; nothing is evaluated and
    nothing is asserted.
    NOTE(review): this looks like an excerpt of a longer test -- confirm
    against the full file before relying on it.
    """
    # Not referenced again in the visible part of the function.
    test_a = np.array([1, 2, 3, 4])
    test_b = np.array([5, 6, 7, 8])
    # Created before any ``with pm.Node(...)`` block is entered.
    test_placeholder = pm.placeholder("hello")
    with pm.Node(name="main") as graph:
        a = pm.parameter(default=6, name="a")
        b = pm.parameter(default=5, name="b")
        # Rebinds ``a`` to the sum node.
        # NOTE(review): the node is named "a_mul_b" although the operation is
        # an addition -- confirm whether the name is intentional.
        a = (a + b).set_name("a_mul_b")
        with pm.Node(name="graph2") as graph2:
            # 3 x 3 x 3 nested list holding the integers 0..26.
            c = pm.variable([[[0, 1, 2], [3, 4, 5], [6, 7, 8]],
                             [[9, 10, 11], [12, 13, 14], [15, 16, 17]],
                             [[18, 19, 20], [21, 22, 23], [24, 25, 26]]],
                            name="c")
            c_2 = (c * 2).set_name(name="c2")
            e = pm.parameter(default=4, name="e")
            l = pm.placeholder("test")
            x = (l * e).set_name("placeholdermult")
            # Index nodes used to address ``c`` below.
            # NOTE(review): presumably the two numbers are lower/upper bounds;
            # whether the upper bound is inclusive is not visible here.
            i = pm.index(0, 1, name="i")
            j = pm.index(0, 1, name="j")
            k = pm.index(0, 2, name="k")
            # Indexes ``c`` with (i, j, k); the resulting node is named "e_i".
            e_i = pm.var_index(c, [i, j, k], "e_i")
コード例 #3
0
def generate_srdfg(onnx_graph):
    """Translate an ONNX graph into a ``pm.Node`` graph.

    Nodes are registered in this order: graph outputs, graph inputs,
    ``value_info`` entries, remaining initializers, parameters already
    present on the new graph, and finally every ONNX operator via
    ``convert_node``.

    Args:
        onnx_graph: Object exposing ``name``, ``initializer``, ``output``,
            ``input``, ``value_info`` and ``node`` (presumably an ONNX
            ``GraphProto`` -- confirm against the caller).

    Returns:
        The populated ``pm.Node`` named after ``onnx_graph.name``.

    Raises:
        AssertionError: If an output or input name is registered twice, if an
            input that is the target of a state variable was not registered
            by the output pass, or if an operator type is not in
            ``NODE_NAMES``.
    """
    initializers = get_initializers(onnx_graph.initializer)
    mgdfg = pm.Node(name=onnx_graph.name)
    # TODO: This is a hotfix for identifying gradient updates, but weights should have initializers
    state_variables = get_states_by_gradient(onnx_graph)
    # Maps a tensor name to its pm node, or to a {"name", "shape"} dict for
    # value_info entries that have no node yet.
    node_info = {}

    # TODO: If a value has an initializer, set the initializer value as the value for the node
    for o in onnx_graph.output:
        assert o.name not in node_info, \
            f"Output '{o.name}' is already registered"

        if o.name in state_variables:
            node_info[o.name] = pm.state(name=state_variables[o.name],
                                         shape=get_value_info_shape(o, mgdfg),
                                         graph=mgdfg)
            # Make the state reachable under its state-variable name as well.
            node_info[state_variables[o.name]] = node_info[o.name]
        else:
            node_info[o.name] = pm.output(name=o.name,
                                          shape=get_value_info_shape(o, mgdfg),
                                          graph=mgdfg)

    for i in onnx_graph.input:
        if i.name in state_variables.values():
            # Already created as a state while processing the outputs.
            assert i.name in node_info, \
                f"State input '{i.name}' was not registered by the output pass"
            continue
        assert i.name not in node_info, \
            f"Input '{i.name}' is already registered"

        if i.name in state_variables:
            node_info[i.name] = pm.state(name=state_variables[i.name],
                                         shape=get_value_info_shape(i, mgdfg),
                                         graph=mgdfg)
            node_info[state_variables[i.name]] = node_info[i.name]
        elif i.name in initializers and not itercheck(initializers[i.name]):
            # Initializer rejected by itercheck: expose it as a parameter
            # with the initializer as default.
            node_info[i.name] = pm.parameter(name=i.name,
                                             default=initializers[i.name],
                                             graph=mgdfg)
        elif i.name in initializers:
            node_info[i.name] = pm.state(name=i.name,
                                         shape=get_value_info_shape(i, mgdfg),
                                         graph=mgdfg)
        else:
            node_info[i.name] = pm.input(name=i.name,
                                         shape=get_value_info_shape(i, mgdfg),
                                         graph=mgdfg)

    for v in onnx_graph.value_info:
        if v.name in node_info:
            continue
        if v.name in initializers:
            node_info[v.name] = pm.variable(initializers[v.name],
                                            name=v.name,
                                            shape=get_value_info_shape(
                                                v, mgdfg),
                                            graph=mgdfg)
        else:
            # No node is created here; only name and shape are recorded.
            node_info[v.name] = {
                "name": v.name,
                "shape": get_value_info_shape(v, mgdfg)
            }

    for k, v in initializers.items():
        if k not in node_info:
            # TODO: Need to set the value here
            node_info[k] = pm.state(name=k,
                                    shape=get_value_info_shape(v, mgdfg),
                                    graph=mgdfg)
            # NOTE(review): earlier loops read state_variables values as
            # names, while a node object is stored here -- confirm that
            # convert_node expects this.
            state_variables[k] = node_info[k]

    # Pick up parameters that exist on the graph but are not tracked yet.
    for k, v in mgdfg.nodes.items():
        if isinstance(v, pm.parameter) and k not in node_info:
            node_info[k] = v

    for n in onnx_graph.node:
        assert n.op_type in NODE_NAMES, \
            f"Unsupported ONNX op type '{n.op_type}'"
        convert_node(n, mgdfg, node_info, state_variables)

    return mgdfg
コード例 #4
0
def test_slice():
    """Slicing a 100-element variable from a parameter defaulting to 1 leaves 99 items."""
    with pm.Node() as graph:
        sequence = pm.variable(range(100))
        start = pm.parameter(default=1)
        tail = sequence[start:]
    # Evaluate the slice node, then check how many elements survived.
    evaluated = graph(tail)
    assert len(evaluated) == 99
コード例 #5
0
def test_reversed():
    """Applying ``reversed`` to an 'abc' variable yields the characters back to front."""
    with pm.Node() as graph:
        letters = pm.variable('abc')
        backwards = reversed(letters)

    produced = list(graph(backwards))
    assert produced == ['c', 'b', 'a']
コード例 #6
0
def test_iter():
    """Unpack the 'alphabet' variable into three nodes and evaluate them together."""
    with pm.Node(name="outer") as graph:
        pm.variable('abc', name='alphabet', shape=3)
    # Looking the node up by name and unpacking it gives one node per element.
    first, second, third = graph['alphabet']
    expected = ('a', 'b', 'c')
    assert graph([first, second, third]) == expected