def log_prob(w, D):
    """Negative log-likelihood and its gradient for an RW-softmax agent.

    The parameter vector ``w`` lives in unconstrained space: ``w[0]`` is
    mapped to the learning rate via ``sigmoid`` and ``w[1]`` to the
    inverse softmax temperature via ``stable_exp``.  Each row of ``D``
    packs one trial as (state ``[:7]``, action ``[7:11]``, reward ``[11]``,
    next state ``[12:]``) — assumes this layout; confirm against the data
    generator.

    Returns
    -------
    tuple
        ``(-logprob, -J @ grad)`` where ``J`` is the diagonal Jacobian of
        the reparameterization (chain rule back to unconstrained ``w``).
    """
    agent = RWSoftmaxAgent(task=MyBanditTask(),
                           learning_rate=sigmoid(w[0]),
                           inverse_softmax_temp=stable_exp(w[1]))
    n_trials = D.shape[0]
    for trial in range(n_trials):
        state = D[trial, :7]
        action = D[trial, 7:11]
        reward = D[trial, 11]
        next_state = D[trial, 12:]
        agent.log_prob(state, action)
        agent.learning(state, action, reward, next_state, None)
    # Jacobian of (sigmoid, exp) reparameterization, applied via chain rule.
    jac = np.diag([grad.sigmoid(w[0]), grad.exp(w[1])])
    return -agent.logprob_, -jac @ agent.grad_
def log_prob(w, D):
    """Negative log-likelihood and its gradient for a two-armed-bandit
    RW-softmax agent.

    ``w`` is unconstrained: ``w[0]`` maps to the learning rate through a
    clipped ``sigmoid`` and ``w[1]`` to the inverse softmax temperature
    through a clipped ``stable_exp`` (clipping guards against overflow).
    Each row of ``D`` packs one trial as (state ``[:3]``, action ``[3:5]``,
    reward ``[5]``, next state ``[6:]``) — assumes this layout; confirm
    against the data generator.

    Returns
    -------
    tuple
        ``(-logprob, -J * grad)`` where ``J`` holds the diagonal of the
        reparameterization Jacobian, applied elementwise (chain rule).
    """
    lr = sigmoid(w[0], a_min=-6, a_max=6)
    ist = stable_exp(w[1], a_min=-10, a_max=10)
    agent = RWSoftmaxAgent(TwoArmedBandit(), lr, ist)
    for t in range(D.shape[0]):
        x = D[t, :3]
        u = D[t, 3:5]
        r = D[t, 5]
        x_ = D[t, 6:]
        agent.log_prob(x, u)
        agent.learning(x, u, r, x_, None)
    # Diagonal Jacobian stored as a vector; elementwise product applies
    # the chain rule back to unconstrained w.
    J = np.array([grad.sigmoid(w[0]), grad.exp(w[1])])
    return -agent.logprob_, -J * agent.grad_
def reparam_jac_rwssm(x):
    """Diagonal Jacobian of the RW-softmax-model reparameterization.

    Builds a 3x3 diagonal matrix whose entries are the derivatives of
    each parameter's transform at ``x``: sigmoid' for ``x[0]``, exp' for
    ``x[1]``, and ``x[2]`` itself for the third parameter.

    NOTE(review): the third diagonal entry is the raw value ``x[2]``
    rather than a transform derivative — confirm this matches the third
    parameter's intended reparameterization.
    """
    diagonal = np.array([grad.sigmoid(x[0]), grad.exp(x[1]), x[2]])
    return np.diag(diagonal)
def test_sigmoid():
    """Check the analytical sigmoid gradient against autograd.

    Compares ``grad.sigmoid`` with autograd's ``elementwise_grad`` of
    ``utils.sigmoid`` over a small grid of inputs.
    """
    x = np.linspace(-5, 5, 10)
    # Pass the function directly; the lambda wrapper was redundant.
    ag = elementwise_grad(utils.sigmoid)(x)
    fg = grad.sigmoid(x)
    # np.linalg.norm returns a scalar, so np.all(...) around the
    # comparison was a no-op; compare the scalar directly.
    assert np.linalg.norm(ag - fg) < 1e-6