def main_mcts_env(run_name, save_gifs=True, num_parallel_envs=100,
                  mcts_steps=100, max_rollout_depth=10, run_len=100):
    """Run the MCTS agent directly on the avoidance task environment.

    For every episode an ``envs.AvoidanceTask`` is created and stepped
    ``run_len`` times. At each step a fresh ``MCTS`` planner is built on the
    task and asked for an action via ``run_mcts``.

    :param run_name: (str) prefix used for the names of the saved gifs
    :param save_gifs: (bool) if True, write one gif per episode to
        ``<cwd>/mcts_results/``
    :param num_parallel_envs: (int) number of episodes; despite the name they
        are executed one after another
    :param mcts_steps: (int) value passed to ``MCTS.run_mcts`` at every step
    :param max_rollout_depth: (int) value passed to ``MCTS`` as ``max_rollout``
    :param run_len: (int) number of environment steps per episode
    :return: flat list with the reward returned by every step of every
        episode, in execution order
    """
    action_force = 0.6
    results = []

    path = os.getcwd() + '/mcts_results/'
    if save_gifs:
        # Create the output directory once. `exist_ok=True` avoids reporting
        # a failure merely because an earlier episode/run already created it.
        try:
            os.makedirs(path, exist_ok=True)
        except OSError:
            print("Creation of the directory %s failed" % path)

    for t in tqdm(range(num_parallel_envs)):
        # NOTE(review): `env` is not assigned anywhere in this function, so it
        # must be provided at module level; otherwise this raises NameError.
        env_task = envs.AvoidanceTask(env, action_force=action_force)
        imgs = []
        for _step in tqdm(range(run_len)):
            # A new search tree is built for every single decision.
            mcts = MCTS(env_task, max_rollout=max_rollout_depth)
            action = mcts.run_mcts(mcts_steps)
            img, _, r, _ = env_task.step(action)
            results.append(r)
            imgs.append(img)

        if save_gifs:
            # Presumably frames are floats in [0, 1]; rescale to 8-bit for
            # the gif writer -- TODO confirm against the environment.
            frames = (255 * np.array(imgs)).astype(np.uint8)
            gif_name = '{}_{}.gif'.format(run_name, t)
            imageio.mimsave(os.path.join(path, gif_name), frames, fps=24)

    return results
def make_env():
    """Create environment.

    Builds an ``envs.BillardsEnv`` from a fixed set of keyword arguments and
    wraps it in an ``envs.AvoidanceTask``.

    :return: the ``envs.AvoidanceTask`` instance
    """
    billiards_kwargs = {
        'res': 32,
        'hw': 10,
        'n': 3,
        't': 1.,
        'm': 1.,
        'granularity': 50,
        'r': 1,
        'friction_coefficient': 0,
    }
    billiards = envs.BillardsEnv(**billiards_kwargs)
    # The meaning of the positional `4` is defined by AvoidanceTask and is
    # not visible from here.
    task = envs.AvoidanceTask(billiards, 4, action_force=0.6)
    return task
def mcts_based_sampling(model, run_len=100, num_parallel_envs=10, depth=10,
                        mcts_steps=20, res=32, n=3):
    """Collect avoidance-task rollouts whose actions are chosen by MCTS on a model.

    ``num_parallel_envs`` seeded ``envs.AvoidanceTask`` instances are stepped
    in lockstep for ``run_len`` steps. At every step ``run_mcts_model`` picks
    one action per environment from the current image/action buffers, and the
    resulting frames, states, one-hot actions, rewards and done flags are
    recorded.

    :param model: world model handed to ``run_mcts_model``
    :param run_len: (int) number of steps recorded per environment
    :param num_parallel_envs: (int) number of environments; environment ``s``
        is created with ``seed=s``
    :param depth: (int) passed to ``run_mcts_model`` as ``max_rollout_depth``
    :param mcts_steps: (int) passed to ``run_mcts_model`` as ``mcts_steps``
    :param res: (int) image resolution of the environments
    :param n: (int) number of objects in each environment
    :return: dict with the arrays ``'X'`` (frames), ``'y'`` (states),
        ``'action'``, ``'reward'``, ``'done'`` plus metadata entries
        (``'type'``, ``'action_force'``, ``'action_space'``, ``'coord_lim'``
        and the environment settings)
    """
    num_actions = 9
    # Metadata stored alongside the data. `res` and `n` come from the
    # arguments so the metadata matches the environments actually created.
    config = {
        'res': res,
        'hw': 10,
        'n': n,
        't': 1.,
        'm': 1.,  # dt = 0.81
        'granularity': 50,
        'r': 1,
    }

    all_envs = [
        envs.AvoidanceTask(
            envs.BillardsEnv(n=n, hw=config['hw'], r=1., res=res, seed=s),
            action_force=0.6, num_stacked=8)
        for s in range(num_parallel_envs)
    ]
    img, actions = initialize_img(all_envs)

    all_imgs = np.zeros((num_parallel_envs, run_len, res, res, 3))
    all_states = np.zeros((num_parallel_envs, run_len, n, 4))
    all_actions = np.zeros((num_parallel_envs, run_len, num_actions))
    all_rewards = np.zeros((num_parallel_envs, run_len, 1))
    all_dones = np.zeros((num_parallel_envs, run_len, 1))

    for i in tqdm(range(run_len)):
        mcts_actions = run_mcts_model(
            img, model, actions,
            max_rollout_depth=depth,
            num_parallel_envs=num_parallel_envs,
            mcts_steps=mcts_steps)

        for j in range(num_parallel_envs):
            ret_img, state, r, done = all_envs[j].step(mcts_actions[j])
            all_imgs[j, i] = ret_img
            all_states[j, i] = state

            one_hot = np.zeros(num_actions)
            one_hot[mcts_actions[j]] = 1.
            # The action is stored one slot before its resulting frame; for
            # i == 0 the index -1 wraps to the last slot.
            # NOTE(review): kept exactly as in the original -- confirm this
            # alignment is what the data consumers expect.
            all_actions[j, i - 1] = one_hot

            all_rewards[j, i] = r
            all_dones[j, i] = done

        # Both buffers must be carried over to the next planning step; the
        # second return value was previously bound to an unused name, which
        # left the action buffer stuck at its initial value.
        img, actions = update_buffer(img, all_imgs[:, i], actions, mcts_actions)

    data = {
        'X': all_imgs,
        'y': all_states,
        'action': all_actions,
        'reward': all_rewards,
        'done': all_dones,
        'type': 'avoidance',
        'action_force': 0.6,
        'action_space': num_actions,
    }
    data.update(config)
    data['coord_lim'] = config['hw']
    return data
def main_mcts_model(run_name, restore_point, save_gifs=True,
                    num_parallel_envs=100, mcts_steps=100,
                    max_rollout_depth=10, run_len=100):
    """
    Runs the MCTS agent on a pretrained STOVE world model.

    The trainer is built via ``main`` with logging disabled and the reference
    data files fixed to ``./data/avoidance_train.pkl`` and
    ``./data/avoidance_test.pkl``. The collected rewards are additionally
    pickled to the file ``quicksave`` in the working directory.

    :param run_name: (str) name for logging files
    :param restore_point: (str) file location of the STOVE model
    :param save_gifs: (bool) flag enabling the logging of the gifs
    :param num_parallel_envs: (int) number of environments stepped in
        lockstep; environment ``s`` is created with ``seed=s``
    :param mcts_steps: (int) passed to ``run_mcts_model`` as ``mcts_steps``
    :param max_rollout_depth: (int) passed to ``run_mcts_model`` as
        ``max_rollout_depth``
    :param run_len: (int) number of steps per environment
    :return: total results from the run, a flat list of float rewards
        ordered by step and then by environment
    """
    extras = {'nolog': True,
              'traindata': './data/avoidance_train.pkl',
              'testdata': './data/avoidance_test.pkl'}
    trainer = main(extras=extras, restore=restore_point)
    model = trainer.stove

    with torch.no_grad():
        res = 32
        for t in tqdm(range(1)):
            all_envs = [
                envs.AvoidanceTask(
                    envs.BillardsEnv(n=3, hw=10, r=1., res=res, seed=s),
                    action_force=0.6, num_stacked=8)
                for s in range(num_parallel_envs)
            ]
            # Initialise the frame/action buffers from the environments.
            img, actions = initialize_img(all_envs)

            imgs = [[] for _ in range(num_parallel_envs)]
            results = []
            for _step in tqdm(range(run_len)):
                next_actions = run_mcts_model(
                    img, model, actions,
                    max_rollout_depth=max_rollout_depth,
                    num_parallel_envs=num_parallel_envs,
                    mcts_steps=mcts_steps)

                step_imgs = []
                for j in range(num_parallel_envs):
                    ret_img, _, r, _ = all_envs[j].step(next_actions[j])
                    imgs[j].append(ret_img)
                    step_imgs.append(ret_img)
                    results.append(float(r))

                # Feed the new frame of every environment into the buffer.
                # Previously only the last environment's frame (the leaked
                # loop variable) was passed for the whole batch.
                # NOTE(review): assumes update_buffer accepts a
                # (num_envs, ...) stack, as mcts_based_sampling passes it.
                img, actions = update_buffer(
                    img, np.array(step_imgs), actions, next_actions)

            imgs = np.array(imgs)
            if save_gifs:
                for i in range(num_parallel_envs):
                    # Presumably frames are floats in [0, 1] -- TODO confirm.
                    print_imgs = (255 * imgs[i]).astype(np.uint8)
                    imageio.mimsave(
                        './{}_{}_{}.gif'.format(run_name, t, i),
                        print_imgs, fps=24)

    # Use a context manager so the file is flushed and closed deterministically.
    with open('quicksave', 'wb') as quicksave_file:
        pickle.dump(results, quicksave_file)
    return results