def test_gather1():
    """Check that ``pm.gather`` agrees with ``np.take`` along axis 1.

    Builds a small graph with two inputs (``input`` and ``indices``) and a
    gather node named ``res``. It evaluates that node on a random 4-D float32
    array and compares the result against NumPy's ``np.take`` reference.
    """
    axis = 1
    # Fixed seed so any failure can be reproduced exactly.
    rng = np.random.default_rng(0)
    x = rng.standard_normal((5, 4, 3, 2)).astype(np.float32)
    idx = np.array([0, 1, 3])

    with pm.Node(name="gather_op") as graph:
        data = pm.input(name="input", shape=x.shape)
        indices = pm.input(name="indices", shape=idx.shape)
        # The return value isn't needed: the node is looked up by its
        # name ("res") when the graph is evaluated below.
        pm.gather(data, indices, axis=axis, name="res")

    pm_y = graph("res", {"input": x, "indices": idx})
    np_y = np.take(x, idx, axis=axis)
    # Argument order is (actual, desired). rtol is scaled by |desired|,
    # so the NumPy reference must be the second argument.
    np.testing.assert_allclose(pm_y, np_y)
def define_graph(self, data, indices, output, axis=0):
    """Wire a gather of ``data`` by ``indices`` into ``output``.

    Args:
        data: Source node passed to ``pm.gather``.
        indices: Index node passed to ``pm.gather``.
        output: Destination; the gather result is passed to its ``write``.
        axis: Axis forwarded to ``pm.gather`` (defaults to 0).
    """
    # TODO: Fix this to use manual implementation
    gathered = pm.gather(data, indices, axis=axis)
    output.write(gathered)