def test_max(self):
    """Compare Tensor.max against reference lambdas via helper_test_op.

    In each call the first lambda is the reference and the second uses the
    tinygrad Tensor API. Cases: full reduction, full reduction scaled by 0.5,
    an explicit input with repeated maxima, and a reduction along axis=1.
    """
    def ref_half_max(x):
        return x.max().mul(0.5)

    def tiny_half_max(x):
        return Tensor.max(x).mul(0.5)

    # Full reduction, plain and scaled by 0.5.
    helper_test_op([(45, 3)], lambda x: x.max(), Tensor.max)
    helper_test_op([(45, 3)], ref_half_max, tiny_half_max)

    # Explicit values (no random shape): three entries tie for the maximum.
    tied_row = [[1.0, 1.0, 0.0, 1.0]]
    helper_test_op(None, ref_half_max, tiny_half_max, vals=[tied_row])

    # Reduction along axis 1; the reference result is indexed with [0]
    # (presumably selecting the values from a (values, indices) pair).
    helper_test_op(
        [(3, 4, 5, 6)],
        lambda x: x.max(axis=1)[0],
        lambda x: Tensor.max(x, axis=1),
    )
def test_max(self):
    """Compare Tensor.max against reference lambdas via helper_test_op.

    In each call the first lambda is the reference and the second uses the
    tinygrad Tensor API. Cases: full reduction, full reduction scaled by 0.5,
    and a reduction along axis=1.
    """
    helper_test_op([(45, 3)], lambda x: x.max(), Tensor.max)
    helper_test_op([(45, 3)], lambda x: x.max().mul(0.5),
                   lambda x: Tensor.max(x).mul(0.5))
    # TODO: restore coverage for an input with tied maxima, e.g. the explicit
    # value [[1.0, 1.0, 0.0, 1.0]] scaled by 0.5; that case was disabled
    # because it was broken.
    # The reference result is indexed with [0] (presumably selecting the
    # values from a (values, indices) pair).
    helper_test_op([(3, 4, 5, 6)], lambda x: x.max(axis=1)[0],
                   lambda x: Tensor.max(x, axis=1))
def test_max(self):
    """Compare Tensor.max against reference lambdas on self.device.

    In each call the first lambda is the reference and the second uses the
    tinygrad Tensor API. Cases: full reduction, full reduction scaled by 0.5,
    reduction along axis=1, and the axis=1 reduction scaled by 0.5.
    """
    helper_test_op([(45, 3)], lambda x: x.max(), Tensor.max,
                   device=self.device)
    helper_test_op([(45, 3)], lambda x: x.max().mul(0.5),
                   lambda x: Tensor.max(x).mul(0.5), device=self.device)
    # The reference result is indexed with [0] (presumably selecting the
    # values from a (values, indices) pair).
    helper_test_op([(3, 4, 5, 6)], lambda x: x.max(axis=1)[0],
                   lambda x: Tensor.max(x, axis=1), device=self.device)
    helper_test_op([(3, 4, 5, 6)], lambda x: x.max(axis=1)[0].mul(0.5),
                   lambda x: Tensor.max(x, axis=1).mul(0.5),
                   device=self.device)
    # A former fifth call repeated the plain axis=1 case above verbatim; it
    # was removed because it added runtime without adding coverage.
def test_max_axis(self):
    """Compare Tensor.max along axis=1 on a (3, 4, 5, 6) input.

    The first lambda is the reference and the second uses the tinygrad
    Tensor API; self.device is passed through to helper_test_op.
    """
    shape = (3, 4, 5, 6)
    helper_test_op(
        [shape],
        # [0] presumably selects the values from a (values, indices) pair.
        lambda x: x.max(axis=1)[0],
        lambda x: Tensor.max(x, axis=1),
        device=self.device,
    )