コード例 #1
0
def test_gather1():
    """Check that ``pm.gather`` agrees with ``np.take`` along axis 1.

    Builds a ``pm.Node`` graph named "gather_op" with two inputs ("input",
    "indices") and a gather node named "res". It then runs ``graph("res", ...)``
    on a random 4-D float32 array and compares the output against NumPy's
    reference ``np.take``.
    """
    axis = 1
    # Fixed seed so a failing comparison can be reproduced run-to-run.
    rng = np.random.RandomState(0)
    x = rng.randn(5, 4, 3, 2).astype(np.float32)
    # All indices must be < x.shape[axis] (== 4).
    idx = np.array([0, 1, 3])

    with pm.Node(name="gather_op") as graph:
        data = pm.input(name="input", shape=x.shape)
        indices = pm.input(name="indices", shape=idx.shape)
        # The result is looked up below by its name "res", so the Python
        # value returned here does not need a binding.
        pm.gather(data, indices, axis=axis, name="res")

    pm_y = graph("res", {"input": x, "indices": idx})
    np_y = np.take(x, idx, axis=axis)
    # assert_allclose(actual, desired): the graph output is the value under
    # test and NumPy is the reference, so failure messages label them correctly.
    np.testing.assert_allclose(pm_y, np_y)
コード例 #2
0
 def define_graph(self, data, indices, output, axis=0):
     """Populate ``output`` with ``data`` gathered at ``indices`` along ``axis``.

     For now this delegates to the built-in ``pm.gather`` op and writes the
     resulting value into ``output``.
     """
     # TODO: Fix this to use manual implementation
     gathered = pm.gather(data, indices, axis=axis)
     output.write(gathered)