def simulate(trial_id):
    """Evaluate the module-level policy ``pol`` on a single trial.

    One 'constant_high' environment is built with the ground truth stored in
    ``TRIALS[trial_id]``; the *same* environment object is handed to
    ``evaluate`` 30 times.

    Args:
        trial_id: Key/index into ``TRIALS``.

    Returns:
        Whatever ``evaluate`` returns (assumed DataFrame-like, since a column
        is assigned to it), with an added ``'trial_id'`` column.
    """
    trial_env = make_env(
        'constant_high',
        cost=COST,
        ground_truth=TRIALS[trial_id],
    )
    # Thirty references to one shared environment, not thirty copies.
    repeated_envs = [trial_env for _ in range(30)]
    results = evaluate(pol, repeated_envs)
    results['trial_id'] = trial_id
    return results
def chunk_exp_Q(env_type, seed, chunk_i):
    """Estimate Q-values for one chunk of states and pickle them.

    Loads the policy stored at ``data/policies/{env_type}_{seed}.pkl`` and,
    for every state in ``STATE_CHUNKS[chunk_i]`` and every action the
    environment offers in that state, computes a value ``q``:

    * for the environment's terminating action, ``env.expected_term_reward``;
    * otherwise, the mean over 1000 samples of the immediate reward of taking
      the action plus the summed rewards of a policy episode run afterwards.

    The rows (``state``, ``action``, ``q``) are written as a pickled
    DataFrame to ``data/exp_Q/{env_type}_{seed}/{chunk_i}.pkl``.

    Args:
        env_type: First component of the policy/output name.
        seed: Second component of the policy/output name.
        chunk_i: Index into ``STATE_CHUNKS``.
    """
    name = f'{env_type}_{seed}'
    os.makedirs(f'data/exp_Q/{name}', exist_ok=True)
    pol = load(f'data/policies/{name}.pkl')
    states = STATE_CHUNKS[chunk_i]
    env = make_env(
        'constant_high',
        cost=COST,
        term_belief=True,
        ground_truth=False,
    )

    def sampled_return(state, action):
        # Force the environment into `state`, take `action`, and make the
        # resulting state the environment's `init` before running the
        # policy episode.
        env._state = state
        env.init, reward, *_ = env.step(action)
        return reward + sum(run_episode(pol, env)['rewards'])

    def q_rows(state):
        for action in env.actions(state):
            if action == env.term_action:
                q = env.expected_term_reward(state)
            else:
                q = np.mean(
                    [sampled_return(state, action) for _ in range(1000)]
                )
            yield {'state': state, 'action': action, 'q': q}

    rows = [row for state in states for row in q_rows(state)]
    pd.DataFrame(rows).to_pickle(f'data/exp_Q/{name}/{chunk_i}.pkl')
def yoked_rollout(pol, trial_id, clicks, n=1):
    """Yield (state, action, return) samples yoked to a recorded click sequence.

    For each of ``n`` repetitions:

    1. ``env.init`` is restored to its original value and the policy runs one
       episode; a sample is yielded with the initial state, the first action
       of that episode, and the episode's summed rewards.
    2. For every click in ``clicks`` (in order), the environment is reset, the
       click is applied, the resulting state becomes the new ``env.init``,
       and the policy runs an episode; the yielded return is the click's
       reward plus the episode's summed rewards.

    NOTE(review): the click loop is taken to be nested inside the repetition
    loop, which is why ``env.init`` is restored at the top of each
    repetition; confirm against the original layout.

    Args:
        pol: Policy passed to ``run_episode``.
        trial_id: Key/index into ``TRIALS`` selecting the ground truth.
        clicks: Iterable of actions to replay.
        n: Number of repetitions.

    Yields:
        dict with keys ``'s'`` (state), ``'a'`` (action), ``'q'`` (return).
    """
    env = make_env(
        'constant_high',
        cost=COST,
        ground_truth=TRIALS[trial_id],
    )
    original_init = env.init

    for _ in range(n):
        # The click replay below overwrites env.init, so restore it first.
        env.init = original_init
        episode = run_episode(pol, env)
        first_action = episode['actions'][0]
        yield {'s': env.init, 'a': first_action, 'q': sum(episode['rewards'])}

        for click in clicks:
            state_before = env.reset()
            next_state, reward, *_ = env.step(click)
            env.init = next_state
            total = reward + sum(run_episode(pol, env)['rewards'])
            yield {'s': state_before, 'a': click, 'q': total}
def chunk_write_rollouts(env_type, seed, chunk_i):
    """Generate yoked rollouts for one chunk of click data and dump them.

    Loads the policy at ``data/policies/{env_type}_{seed}.pkl`` and, for each
    ``(trial_id, clicks)`` pair in ``CLICK_CHUNKS[chunk_i]``, collects the
    samples produced by ``yoked_rollout(..., n=300)``. Each pair is written
    to ``data/rollouts/{env_type}_{seed}/{chunk_i}_{i}.pkl`` as a dict with:

    * ``'s'``: list of ``encode_state``-encoded states,
    * ``'a'``: array of actions,
    * ``'q'``: array of returns,
    * ``'phi'``: stacked feature vectors,
    * ``'trial_id'``: the trial identifier.

    The output directory is created if missing, matching ``chunk_exp_Q``.

    Args:
        env_type: First component of the policy/output name.
        seed: Second component of the policy/output name.
        chunk_i: Index into ``CLICK_CHUNKS``.

    Returns:
        Total number of samples written across all pairs in the chunk.
    """
    name = f'{env_type}_{seed}'
    pol = load(f'data/policies/{name}.pkl')
    env = make_env('constant_high', cost=COST)
    chunk = CLICK_CHUNKS[chunk_i]

    out_dir = f'data/rollouts/{name}'
    # Ensure the destination exists before the first dump.
    os.makedirs(out_dir, exist_ok=True)

    n = 0
    for i, (trial_id, clicks) in enumerate(chunk):
        data = {'s': [], 'a': [], 'q': [], 'phi': [], 'trial_id': trial_id}
        for step in yoked_rollout(pol, trial_id, clicks, n=300):
            state, action = step['s'], step['a']
            data['s'].append(encode_state(state))
            data['a'].append(action)
            data['q'].append(step['q'])
            # Drop the first feature, then overwrite the new leading entry
            # with the expected termination reward for this state.
            phi = pol.phi(state, action, compute_all=True)[1:]
            phi[0] = env.expected_term_reward(state)
            data['phi'].append(phi)

        data['phi'] = np.stack(data['phi'])
        data['a'] = np.array(data['a'])
        data['q'] = np.array(data['q'])
        dump(data, f'{out_dir}/{chunk_i}_{i}.pkl')
        n += len(data['q'])
    return n