def td_cp_single(f_order, alpha):
    """Run TD learning on CartPole for one step size and print the result.

    Trains a weight vector for 100 episodes with a TD update, then runs
    another 100 episodes with the weights frozen and prints the mean of the
    squared TD errors collected there.

    Args:
        f_order: Order passed through to ``vw`` / ``dvwdw``; the weight
            vector has ``(f_order + 1) ** 4`` entries.
        alpha: Step size applied to every TD update.

    Returns:
        None. Progress and the final mean squared TD error are printed.
    """
    d = 4  # exponent for the feature count; presumably the state dimension
    cartpole = CartPole()
    print('cartpole ', f_order, ' td')
    weight = np.zeros((1, (f_order + 1) ** d))

    # Training phase: update the weights over 100 episodes.
    print('alpha = ', alpha)
    for _ in range(100):
        s = cartpole.d_zero()
        count = 0
        # An episode stops when s[0] or s[1] leaves its bound, or after
        # 1010 steps.
        while (np.abs(s[0]) < cartpole.edge
               and np.abs(s[1]) < cartpole.fail_angle
               and count < 1010):
            a = cartpole.pi(s)
            new_s, r = cartpole.P_and_R(s, a)
            # No discount factor is applied to the next-state value.
            td_error = r + vw(weight, new_s, f_order) - vw(weight, s, f_order)
            # presumably dvwdw returns a column vector, hence the transpose
            # to match the (1, n) weight row -- TODO confirm.
            weight += alpha * td_error * dvwdw(weight, s, f_order).T
            s = new_s
            # Per-step dump of the weights (very verbose).
            print(weight)
            count += 1

    # Evaluation phase: collect squared TD errors over another 100 episodes
    # without touching the weights.
    td_list = []
    for _ in range(100):
        s = cartpole.d_zero()
        count = 0
        while (np.abs(s[0]) < cartpole.edge
               and np.abs(s[1]) < cartpole.fail_angle
               and count < 1010):
            a = cartpole.pi(s)
            new_s, r = cartpole.P_and_R(s, a)
            td_error = r + vw(weight, new_s, f_order) - vw(weight, s, f_order)
            td_list.append(td_error ** 2)
            s = new_s
            count += 1
        # NOTE(review): the original indentation was lost; this zero entry
        # is assumed to be appended once per episode -- confirm.
        td_list.append(0)

    print('square td = ', np.mean(np.array(td_list)))
def td_cp(lrs, f_order):
    """Evaluate TD learning on CartPole for each step size in ``lrs``.

    For every step size the weights are reset to zero, trained for 100
    episodes with a TD update, and then evaluated for another 100 episodes
    with the weights frozen. The mean of the squared TD errors from the
    evaluation episodes is recorded.

    Args:
        lrs: Iterable of step sizes to try, in order.
        f_order: Order passed through to ``vw`` / ``dvwdw``; the weight
            vector has ``(f_order + 1) ** 4`` entries.

    Returns:
        A list with one entry per step size, in the order of ``lrs``: the
        mean squared TD error, or ``1e100`` when that mean is NaN.
    """
    d = 4  # exponent for the feature count; presumably the state dimension
    alpha_result = []
    cartpole = CartPole()
    print('cartpole ', f_order, ' td')

    for alpha in lrs:
        # One weight per feature; the original comment refers to a kth
        # order Fourier basis, which presumably lives in vw / dvwdw.
        weight = np.zeros((1, (f_order + 1) ** d))

        # Training phase: update the weights over 100 episodes.
        print('alpha = ', alpha)
        for _ in range(100):
            s = cartpole.d_zero()
            count = 0
            # An episode stops when s[0] or s[1] leaves its bound, or after
            # 1010 steps.
            while (np.abs(s[0]) < cartpole.edge
                   and np.abs(s[1]) < cartpole.fail_angle
                   and count < 1010):
                a = cartpole.pi(s)
                new_s, r = cartpole.P_and_R(s, a)
                # No discount factor is applied to the next-state value.
                td_error = (r + vw(weight, new_s, f_order)
                            - vw(weight, s, f_order))
                # presumably dvwdw returns a column vector, hence the
                # transpose to match the (1, n) weight row -- TODO confirm.
                weight += alpha * td_error * dvwdw(weight, s, f_order).T
                s = new_s
                count += 1

        # Evaluation phase: collect squared TD errors over another 100
        # episodes without touching the weights.
        td_list = []
        for _ in range(100):
            s = cartpole.d_zero()
            count = 0
            while (np.abs(s[0]) < cartpole.edge
                   and np.abs(s[1]) < cartpole.fail_angle
                   and count < 1010):
                a = cartpole.pi(s)
                new_s, r = cartpole.P_and_R(s, a)
                td_error = (r + vw(weight, new_s, f_order)
                            - vw(weight, s, f_order))
                td_list.append(td_error ** 2)
                s = new_s
                count += 1
            # NOTE(review): the original indentation was lost; this zero
            # entry is assumed to be appended once per episode -- confirm.
            td_list.append(0)

        msv = np.mean(np.array(td_list))
        print('square td = ', msv)
        # A NaN mean is recorded as a very large value instead, so that
        # step size compares as the worst result.
        if np.isnan(msv):
            alpha_result.append(1e100)
        else:
            alpha_result.append(msv)
        # NOTE(review): separator assumed to be printed once per step size;
        # the original indentation was lost -- confirm.
        print('##########################')

    return alpha_result