Exemplo n.º 1
0
 def _get_one_dim(self, val, key, axis):
     """Slice along one axis, keeping the dimensionality of the input"""
     if isinstance(key, slice):
         if any(k is not None for k in [key.start, key.stop, key.step]):
             ix = np.arange(*key.indices(val.shape[axis]))
             return O.gather(val, ix, axis=axis)
         else:
             return val
     elif isinstance(key, int):
         key %= val.shape[axis]
         return O.gather(val, [key], axis=axis)
     else:
         return O.gather(val, key, axis=axis)
Exemplo n.º 2
0
 def __call__(self, x):
     """Perform the forward pass.

     Column ``i`` of ``x`` is used as the indices into the table returned by
     ``self.embeddings[i]()``; the per-column lookups are then concatenated
     along the last axis.
     """
     per_column = [
         O.gather(table(), x[:, col])
         for col, table in enumerate(self.embeddings)
     ]
     return O.cat(per_column, -1)
Exemplo n.º 3
0
def test_gather():
    """Tests gather on the PyTorch backend (default axis and explicit axis)."""

    pf.set_backend('pytorch')

    vals = torch.Tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])

    # Default: the indices select along the first axis (rows)
    row_idx = torch.LongTensor([0, 1, 2, 1, 0])
    got = ops.gather(vals, row_idx)
    assert got.ndim == 2
    assert (got.shape[0], got.shape[1]) == (5, 2)
    assert got.numpy().tolist() == [
        [1.0, 2.0],
        [3.0, 4.0],
        [5.0, 6.0],
        [3.0, 4.0],
        [1.0, 2.0],
    ]

    # Explicit axis=1: the indices select columns instead
    col_idx = torch.LongTensor([1, 0, 1, 0])
    got = ops.gather(vals, col_idx, axis=1)
    assert got.ndim == 2
    assert (got.shape[0], got.shape[1]) == (3, 4)
    assert got.numpy().tolist() == [
        [2.0, 1.0, 2.0, 1.0],
        [4.0, 3.0, 4.0, 3.0],
        [6.0, 5.0, 6.0, 5.0],
    ]
Exemplo n.º 4
0
def test_gather():
    """Tests gather on TensorFlow tensors (default axis and explicit axis)."""

    vals = tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])

    # Default: the indices select along the first axis (rows)
    row_idx = tf.constant([0, 1, 2, 1, 0], dtype=tf.int32)
    got = ops.gather(vals, row_idx)
    assert got.ndim == 2
    assert (got.shape[0], got.shape[1]) == (5, 2)
    assert got.numpy().tolist() == [
        [1.0, 2.0],
        [3.0, 4.0],
        [5.0, 6.0],
        [3.0, 4.0],
        [1.0, 2.0],
    ]

    # Explicit axis=1: the indices select columns instead
    col_idx = tf.constant([1, 0, 1, 0], dtype=tf.int32)
    got = ops.gather(vals, col_idx, axis=1)
    assert got.ndim == 2
    assert (got.shape[0], got.shape[1]) == (3, 4)
    assert got.numpy().tolist() == [
        [2.0, 1.0, 2.0, 1.0],
        [4.0, 3.0, 4.0, 3.0],
        [6.0, 5.0, 6.0, 5.0],
    ]
Exemplo n.º 5
0
 def __call__(self, x):
     """Perform the forward pass.

     Calls ``self.embeddings()`` to obtain the table, then gathers its
     entries at indices ``x`` using ``O.gather``'s default axis.
     """
     table = self.embeddings()
     return O.gather(table, x)