def test_squeeze():
    """Tests squeeze.

    A [3, 2, 1] input must come back as a rank-2 [3, 2] tensor, and an
    input that is already [3, 2] must be returned with that same shape.
    """
    for in_shape in ([3, 2, 1], [3, 2]):
        squeezed = ops.squeeze(tf.random.normal(in_shape))
        assert isinstance(squeezed, tf.Tensor)
        assert squeezed.ndim == 2
        assert squeezed.shape[0] == 3
        assert squeezed.shape[1] == 2
def log_prob(self, y):
    """Log probability of ``y``, squeezing ``y`` first when ranks match.

    Works around a broadcasting problem when logits/probs and ``y`` have
    the same number of dimensions: if ``len(y.shape)`` equals
    ``self.ndim``, ``y`` is passed through ``O.squeeze`` before being
    handed to the parent class's ``log_prob``.  Otherwise ``y`` is
    forwarded unchanged.

    NOTE(review): if ``O.squeeze`` removes *every* size-1 axis (as
    ``tf.squeeze`` does when no ``axis`` is given), a ``y`` with size-1
    batch axes loses more than its trailing axis and may then broadcast
    against the logits/probs incorrectly -- confirm against ``O.squeeze``.
    """
    rank_matches = self.ndim == len(y.shape)
    values = O.squeeze(y) if rank_matches else y
    return super().log_prob(values)