def _get_one_dim(self, val, key, axis):
    """Slice along one axis, keeping the dimensionality of the input

    Parameters
    ----------
    val
        Object to index.  Only ``val.shape[axis]`` is read here; the
        actual indexing is delegated to ``O.gather``.
    key : slice, int, numpy integer, or other index object
        * ``slice`` with no start/stop/step: ``val`` is returned as-is.
        * any other ``slice``: expanded to explicit integer positions.
        * ``int`` / NumPy integer scalar: wrapped modulo the axis length
          and passed as a one-element list.
        * anything else: forwarded to ``O.gather`` untouched.
    axis : int
        Axis of ``val`` to index along.

    Returns
    -------
    ``val`` itself for a bare ``:`` slice, otherwise whatever
    ``O.gather(val, indices, axis=axis)`` returns.
    """
    if isinstance(key, slice):
        # A bare ``:`` selects everything, so no gather is needed.
        if key.start is None and key.stop is None and key.step is None:
            return val
        # slice.indices() resolves omitted / negative bounds against the
        # length of the axis, giving concrete (start, stop, step).
        positions = np.arange(*key.indices(val.shape[axis]))
        return O.gather(val, positions, axis=axis)

    # NumPy integer scalars (np.int64, ...) are not ``int`` subclasses, so
    # they are matched explicitly; otherwise they would skip the wrapping
    # below and be forwarded as a bare scalar.
    if isinstance(key, (int, np.integer)):
        # Modulo maps negative indices onto valid positions.  Note that it
        # also wraps out-of-range indices instead of raising.
        position = int(key) % val.shape[axis]
        # Passed as a one-element list rather than a scalar -- presumably
        # so that the gathered axis is kept with length 1.
        return O.gather(val, [position], axis=axis)

    # Any other index object (list, array, tensor, ...) is passed through.
    return O.gather(val, key, axis=axis)
def __call__(self, x):
    """Perform the forward pass

    For each entry of ``self.embeddings`` (by position ``col``), the entry
    is called and ``O.gather`` is applied to the result using ``x[:, col]``
    as indices.  The gathered pieces are joined with ``O.cat(..., -1)``.
    """
    pieces = []
    for col, embedding in enumerate(self.embeddings):
        table = embedding()
        pieces.append(O.gather(table, x[:, col]))
    return O.cat(pieces, -1)
def test_gather():
    """Tests gather"""

    pf.set_backend('pytorch')

    table = torch.Tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])

    # With no axis argument, the indices pick along the 1st axis
    row_ids = torch.LongTensor([0, 1, 2, 1, 0])
    picked = ops.gather(table, row_ids)
    assert picked.ndim == 2
    assert picked.shape[0] == 5
    assert picked.shape[1] == 2
    expected_rows = [
        [1.0, 2.0],
        [3.0, 4.0],
        [5.0, 6.0],
        [3.0, 4.0],
        [1.0, 2.0],
    ]
    picked_np = picked.numpy()
    for r, row in enumerate(expected_rows):
        for c, want in enumerate(row):
            assert picked_np[r, c] == want

    # An explicit axis switches the lookup to that axis
    col_ids = torch.LongTensor([1, 0, 1, 0])
    picked = ops.gather(table, col_ids, axis=1)
    assert picked.ndim == 2
    assert picked.shape[0] == 3
    assert picked.shape[1] == 4
    expected_cols = [
        [2.0, 1.0, 2.0, 1.0],
        [4.0, 3.0, 4.0, 3.0],
        [6.0, 5.0, 6.0, 5.0],
    ]
    picked_np = picked.numpy()
    for r, row in enumerate(expected_cols):
        for c, want in enumerate(row):
            assert picked_np[r, c] == want
def test_gather():
    """Tests gather"""

    table = tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])

    # With no axis argument, the indices pick along the 1st axis
    row_ids = tf.constant([0, 1, 2, 1, 0], dtype=tf.int32)
    picked = ops.gather(table, row_ids)
    assert picked.ndim == 2
    assert picked.shape[0] == 5
    assert picked.shape[1] == 2
    expected_rows = [
        [1.0, 2.0],
        [3.0, 4.0],
        [5.0, 6.0],
        [3.0, 4.0],
        [1.0, 2.0],
    ]
    picked_np = picked.numpy()
    for r, row in enumerate(expected_rows):
        for c, want in enumerate(row):
            assert picked_np[r, c] == want

    # An explicit axis switches the lookup to that axis
    col_ids = tf.constant([1, 0, 1, 0])
    picked = ops.gather(table, col_ids, axis=1)
    assert picked.ndim == 2
    assert picked.shape[0] == 3
    assert picked.shape[1] == 4
    expected_cols = [
        [2.0, 1.0, 2.0, 1.0],
        [4.0, 3.0, 4.0, 3.0],
        [6.0, 5.0, 6.0, 5.0],
    ]
    picked_np = picked.numpy()
    for r, row in enumerate(expected_cols):
        for c, want in enumerate(row):
            assert picked_np[r, c] == want
def __call__(self, x):
    """Perform the forward pass

    Calls ``self.embeddings`` and returns ``O.gather`` applied to the
    result, with ``x`` as the indices.
    """
    table = self.embeddings()
    return O.gather(table, x)