import torch
from torch import Tensor


def _permute_feature(x: Tensor, feature_mask: Tensor) -> Tensor:
    # Permute the masked feature(s) across the batch dimension, leaving
    # all unmasked features unchanged.
    n = x.size(0)
    assert n > 1, "cannot permute features with batch_size = 1"

    # Draw a random permutation of batch indices, retrying until it is
    # not the identity permutation (so at least one example is shuffled).
    perm = torch.randperm(n)
    no_perm = torch.arange(n)
    while (perm == no_perm).all():
        perm = torch.randperm(n)

    # Take permuted values where the mask is set, original values elsewhere.
    return (x[perm] * feature_mask.to(dtype=x.dtype)) + (
        x * feature_mask.bitwise_not().to(dtype=x.dtype)
    )
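

# A minimal usage sketch: the batch values and mask below are illustrative
# examples, not taken from the source. It shuffles the first feature column
# across a batch of four examples while leaving the second column intact.
if __name__ == "__main__":
    x = torch.tensor([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0], [4.0, 40.0]])
    mask = torch.tensor([True, False])  # permute only the first feature
    permuted = _permute_feature(x, mask)
    print(permuted)  # column 0 is shuffled across rows; column 1 is unchanged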