def test_onehot_like():
    """Check that ``onehot_like`` builds a one-hot vector shaped like its input.

    Each case is ``(hot index, extra keyword args, expected hot value)``:
    the first uses the default hot value, the second an explicit ``value=``.
    For every case the result must match the input's shape and dtype, hold
    the expected value at the hot index, and be zero everywhere else.
    """
    template = np.array([0.1, 0.5, 0.7, 0.4])
    cases = (
        (2, {}, 1),
        (3, {"value": -77.5}, -77.5),
    )
    for hot, extra_kwargs, expected in cases:
        encoded = onehot_like(template, hot, **extra_kwargs)
        assert encoded.shape == template.shape
        assert encoded.dtype == template.dtype
        assert np.all(encoded[:hot] == 0)
        assert encoded[hot] == expected
        # For the last index this slice is empty and the check is vacuous.
        assert np.all(encoded[hot + 1:] == 0)
def best_other_class(logits, exclude):
    """
    Returns the index of the largest logit, ignoring the class that is
    passed as `exclude`.

    Parameters
    ----------
    logits : array_like
        Class scores. Integer inputs are promoted to float64 internally.
        The caller's array is never modified.
    exclude : int
        Index of the class to ignore.

    Returns
    -------
    numpy.intp
        Index (as returned by `np.argmax`) of the largest remaining logit.
        If the remaining logits contain NaN, `np.argmax` returns the first
        NaN's index, as it would for any array.
    """
    logits = np.asarray(logits)
    # Work on a floating copy so that -inf can be stored. Integer dtypes
    # cannot hold inf, and the caller's array must not be mutated.
    if np.issubdtype(logits.dtype, np.floating):
        other_logits = logits.copy()
    else:
        other_logits = logits.astype(np.float64)
    # Overwrite the excluded entry rather than subtracting inf from it.
    # Subtraction would turn a +inf or NaN logit into NaN, and np.argmax
    # would then return the excluded index itself.
    other_logits[exclude] = -np.inf
    return np.argmax(other_logits)