import os
import glob
from datetime import timedelta
from timeit import default_timer as timer  # assumed: `timer` is a wall-clock timer

import numpy as np
import torch

# Project-level helpers (make_env_a2c_smb, SubprocVecEnv, DummyVecEnv, Model,
# save_config, plot_all_data) and the globals `args` / `model_architecture`
# are assumed to be imported/defined elsewhere in this script.


def train(config):
    base_dir = os.path.join('./results/', args.algo, model_architecture, config.env_id)
    try:
        os.makedirs(base_dir)
    except OSError:
        files = glob.glob(os.path.join(base_dir, '*.*'))
        for f in files:
            os.remove(f)

    log_dir = os.path.join(base_dir, 'logs/')
    try:
        os.makedirs(log_dir)
    except OSError:
        files = glob.glob(os.path.join(log_dir, '*.csv')) + glob.glob(os.path.join(log_dir, '*.png'))
        for f in files:
            os.remove(f)

    model_dir = os.path.join(base_dir, 'saved_model/')
    try:
        os.makedirs(model_dir)
    except OSError:
        files = glob.glob(os.path.join(model_dir, '*.dump'))
        for f in files:
            os.remove(f)

    # save configuration for later reference
    save_config(config, base_dir)

    seed = np.random.randint(0, int(1e6))
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    # torch.set_num_threads(1)

    envs = [make_env_a2c_smb(config.env_id, seed, i, log_dir,
                             stack_frames=config.stack_frames,
                             action_repeat=config.action_repeat,
                             reward_type=config.reward_type)
            for i in range(config.num_agents)]
    envs = SubprocVecEnv(envs) if config.num_agents > 1 else DummyVecEnv(envs)

    # extra single environment (rank 16); not used in the training loop below
    env = make_env_a2c_smb(config.env_id, seed, 16, log_dir,
                           stack_frames=config.stack_frames,
                           action_repeat=config.action_repeat,
                           reward_type=config.reward_type)

    model = Model(env=envs, config=config, log_dir=base_dir)

    obs = envs.reset()
    obs = torch.from_numpy(obs.astype(np.float32)).to(config.device)
    model.config.rollouts.observations[0].copy_(obs)

    episode_rewards = np.zeros(config.num_agents, dtype=float)
    final_rewards = np.zeros(config.num_agents, dtype=float)

    start = timer()
    print_threshold = args.print_threshold
    max_dist = np.zeros(config.num_agents)

    for frame_idx in range(1, config.MAX_FRAMES + 1):
        for step in range(config.rollout):
            with torch.no_grad():
                values, actions, action_log_prob, states = model.get_action(
                    model.config.rollouts.observations[step],
                    model.config.rollouts.states[step],
                    model.config.rollouts.masks[step])
            cpu_actions = actions.view(-1).cpu().numpy()

            obs, reward, done, info = envs.step(cpu_actions)
            obs = torch.from_numpy(obs.astype(np.float32)).to(config.device)

            episode_rewards += reward
            masks = 1. - done.astype(np.float32)
            final_rewards *= masks
            final_rewards += (1. - masks) * episode_rewards
            episode_rewards *= masks

            for index, inf in enumerate(info):
                if inf['x_pos'] < 60000:  # x_pos >= 60000 looks like a simulator glitch; ignore those readings
                    max_dist[index] = np.max((max_dist[index], inf['x_pos']))
                if done[index]:
                    model.save_distance(max_dist[index],
                                        (frame_idx - 1) * config.rollout * config.num_agents
                                        + step * config.num_agents + index)
            max_dist *= masks

            rewards = torch.from_numpy(reward.astype(np.float32)).view(-1, 1).to(config.device)
            masks = torch.from_numpy(masks).to(config.device).view(-1, 1)

            obs *= masks.view(-1, 1, 1, 1)

            model.config.rollouts.insert(obs, states, actions.view(-1, 1),
                                         action_log_prob, values, rewards, masks)

        with torch.no_grad():
            next_value = model.get_values(model.config.rollouts.observations[-1],
                                          model.config.rollouts.states[-1],
                                          model.config.rollouts.masks[-1])

        value_loss, action_loss, dist_entropy = model.update(model.config.rollouts, next_value)
        model.config.rollouts.after_update()

        if frame_idx % print_threshold == 0:
            # save model
            if frame_idx % (print_threshold * 10) == 0:
                model.save_w()

            # print
            end = timer()
            total_num_steps = (frame_idx + 1) * config.num_agents * config.rollout
            print("Updates {}, num timesteps {}, FPS {}, max distance {:.1f}, "
                  "mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}, "
                  "entropy {:.5f}, value loss {:.5f}, policy loss {:.5f}".
                  format(frame_idx, total_num_steps, int(total_num_steps / (end - start)),
                         np.mean(max_dist),
                         np.mean(final_rewards), np.median(final_rewards),
                         np.min(final_rewards), np.max(final_rewards),
                         dist_entropy, value_loss, action_loss))

            # plot
            if frame_idx % (print_threshold * 1) == 0:
                try:
                    # Sometimes monitor doesn't properly flush the outputs
                    plot_all_data(log_dir, config.env_id, 'A2C',
                                  config.MAX_FRAMES * config.num_agents * config.rollout,
                                  bin_size=(10, 10), smooth=1,
                                  time=timedelta(seconds=int(timer() - start)),
                                  ipynb=False, action_repeat=config.action_repeat)
                except IOError:
                    pass

    # final plot
    try:
        # Sometimes monitor doesn't properly flush the outputs
        plot_all_data(log_dir, config.env_id, 'A2C',
                      config.MAX_FRAMES * config.num_agents * config.rollout,
                      bin_size=(10, 10), smooth=1,
                      time=timedelta(seconds=int(timer() - start)),
                      ipynb=False, action_repeat=config.action_repeat)
    except IOError:
        pass

    model.save_w()
    envs.close()
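# ---------------------------------------------------------------------------
# Illustrative sketch, not part of train(): a tiny, runnable demonstration of
# the per-worker reward/mask bookkeeping used in the rollout loop above. The
# worker count and the reward/return values below are made up for the example.
# ---------------------------------------------------------------------------
def _reward_bookkeeping_demo():
    reward = np.array([1.0, 0.5, 2.0], dtype=np.float32)  # this step's rewards from 3 workers
    done = np.array([False, True, False])                  # worker 1 just finished its episode
    episode_rewards = np.array([3.0, 7.5, 1.0])            # running returns before this step
    final_rewards = np.zeros(3)                             # last completed-episode returns

    episode_rewards += reward                               # accumulate this step's reward
    masks = 1. - done.astype(np.float32)                    # 0.0 where an episode just ended, else 1.0
    final_rewards *= masks                                  # clear slots of workers that just finished
    final_rewards += (1. - masks) * episode_rewards         # record their completed returns
    episode_rewards *= masks                                # reset their running returns to zero

    print(final_rewards)    # [0. 8. 0.]  -> worker 1's finished episode returned 8.0
    print(episode_rewards)  # [4. 0. 3.]  -> workers 0 and 2 keep accumulating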
# Learn
agent.update(prev_observation, action, reward, observation, frame_idx)
episode_reward += reward

# Episode End
if done:
    agent.finish_nstep()
    agent.save_reward(episode_reward)
    observation = env.reset()
    episode_reward = 0

# Log Info
if frame_idx % 10000 == 0:
    agent.save_weight()
    try:
        plot_all_data(log_dir, env_id, exp_name, config.MAX_FRAMES,
                      bin_size=(10, 100, 100, 1),
                      save_filename=exp_name + '.png', smooth=1,
                      time=timedelta(seconds=int(timer() - start)),
                      ipynb=False)
    except IOError:
        pass
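# ---------------------------------------------------------------------------
# Illustrative sketch only (an assumption, not this repo's code): the minimal
# Gym-style interaction loop that the "Learn / Episode End / Log Info"
# fragment above would sit inside. `env`, `agent`, `config`, and
# `agent.get_action` are placeholder names; only the overall shape is shown.
# ---------------------------------------------------------------------------
def _example_outer_loop(env, agent, config):
    observation = env.reset()
    episode_reward = 0  # running return; the fragment above resets it on episode end

    for frame_idx in range(1, config.MAX_FRAMES + 1):
        action = agent.get_action(observation)            # placeholder policy call
        prev_observation = observation
        observation, reward, done, _ = env.step(action)   # one environment step

        # the fragment above (agent.update(...), episode-end reset, periodic
        # logging) would be placed here, using prev_observation/action/reward.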
# Episode End
if done:
    model.finish_nstep()
    model.reset_hx()
    observation = env.reset()
    model.save_reward(episode_reward)
    episode_reward = 0

# Log Info
if frame_idx % 10000 == 0:
    model.save_w()
    try:
        clear_output(True)
        plot_all_data(log_dir, env_id, 'C51', param.MAX_FRAMES,
                      bin_size=(10, 100, 100, 1), smooth=1,
                      time=timedelta(seconds=int(timer() - start)),
                      ipynb=False)
    except IOError:
        pass

# Final save and plot after the training loop
model.save_w()
env.close()
plot_all_data(log_dir, env_id, 'C51', param.MAX_FRAMES,
              bin_size=(10, 100, 100, 1), smooth=1,
              time=timedelta(seconds=int(timer() - start)),
              ipynb=False)