def main():
    """Train an actor-critic agent (a2c / ppo / acktr) on a vectorised env.

    Reads the run configuration from ``get_args()``, builds the vectorised
    environments, the ``Policy`` network and the agent, optionally resumes
    from ``<load_dir or save_dir>/<env_name>.tar``, and then runs the
    rollout / update loop. Depending on the arguments it also logs to
    stdout, evaluates, checkpoints and plots to visdom at fixed intervals.

    Helpers such as ``get_args``, ``make_vec_envs``, ``Policy``,
    ``init_agent``, ``get_vec_normalize``, ``Evaluator``, ``Plotter`` and the
    rollout storages are defined elsewhere in the project.

    Raises:
        AssertionError: if ``args.algo`` is unsupported, or a recurrent
            policy is requested together with ACKTR.
        Exception: if ``args.auto_expand`` is set and the checkpoint does not
            have exactly one recursion fewer than requested.
    """
    import random

    # Imported for side effects only -- presumably they register the gym
    # environments; TODO confirm.
    import gym_micropolis
    import game_of_life

    args = get_args()
    args.log_dir = args.save_dir + '/logs'
    assert args.algo in ['a2c', 'ppo', 'acktr']
    if args.recurrent_policy:
        assert args.algo in ['a2c', 'ppo'], \
            'Recurrent policy is not implemented for ACKTR'

    num_updates = int(args.num_frames) // args.num_steps // args.num_processes

    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    # NOTE(review): raises IndexError when save_dir does not contain
    # 'trained_models/'.
    graph_name = args.save_dir.split('trained_models/')[1].replace('/', ' ')
    past_steps = 0

    try:
        os.makedirs(args.log_dir)
    except OSError:
        # The log dir could not be created (typically it already exists);
        # stale monitor files are only cleared when overwriting.
        if args.overwrite:
            for stale in glob.glob(
                    os.path.join(args.log_dir, '*.monitor.csv')):
                os.remove(stale)

    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")

    if args.vis:
        from visdom import Visdom
        viz = Visdom(port=args.port)
        win = None
        win_eval = None

    if 'GameOfLife' in args.env_name:
        print('env name: {}'.format(args.env_name))
        num_actions = 1
    envs = make_vec_envs(args.env_name, args.seed, args.num_processes,
                         args.gamma, args.log_dir, args.add_timestep, device,
                         False, None, args=args)

    # Derive the network input dimensions from the observation space.
    if isinstance(envs.observation_space, gym.spaces.Discrete):
        num_inputs = envs.observation_space.n
        # Fix: these were left unbound for Discrete observations, causing a
        # NameError when building the Policy below. 1x1 mirrors the
        # non-image Box case.
        in_w = 1
        in_h = 1
    elif isinstance(envs.observation_space, gym.spaces.Box):
        if len(envs.observation_space.shape) == 3:
            in_w = envs.observation_space.shape[1]
            in_h = envs.observation_space.shape[2]
        else:
            in_w = 1
            in_h = 1
        num_inputs = envs.observation_space.shape[0]

    # Derive the network output dimensions from the action space.
    if isinstance(envs.action_space, gym.spaces.Discrete):
        out_w = 1
        out_h = 1
        if 'Micropolis' in args.env_name:
            if args.power_puzzle:
                num_actions = 1
            else:
                num_actions = 19  # TODO: have this already from env
        elif 'GameOfLife' in args.env_name:
            num_actions = 1
        else:
            num_actions = envs.action_space.n
    elif isinstance(envs.action_space, gym.spaces.Box):
        if len(envs.action_space.shape) == 3:
            out_w = envs.action_space.shape[1]
            out_h = envs.action_space.shape[2]
        elif len(envs.action_space.shape) == 1:
            out_w = 1
            out_h = 1
        num_actions = envs.action_space.shape[-1]
    print('num actions {}'.format(num_actions))

    # With auto_expand the network is first built with one recursion fewer
    # (matching the checkpoint) and expanded after the weights are loaded.
    if args.auto_expand:
        args.n_recs -= 1
    actor_critic = Policy(
        envs.observation_space.shape, envs.action_space,
        base_kwargs={'map_width': args.map_width,
                     'num_actions': num_actions,
                     'recurrent': args.recurrent_policy,
                     'in_w': in_w, 'in_h': in_h, 'num_inputs': num_inputs,
                     'out_w': out_w, 'out_h': out_h},
        curiosity=args.curiosity, algo=args.algo, model=args.model,
        args=args)
    if args.auto_expand:
        args.n_recs += 1

    evaluator = None
    agent = init_agent(actor_critic, args)

    if args.load_dir:
        saved_model = os.path.join(args.load_dir, args.env_name + '.tar')
    else:
        saved_model = os.path.join(args.save_dir, args.env_name + '.tar')
    vec_norm = get_vec_normalize(envs)

    if os.path.exists(saved_model) and not args.overwrite:
        # map_location lets a checkpoint whose tensors were saved from a GPU
        # be resumed on a CPU-only machine (and vice versa).
        # NOTE: torch.load unpickles arbitrary objects; only load
        # checkpoints from trusted sources.
        checkpoint = torch.load(saved_model, map_location=device)
        saved_args = checkpoint['args']
        actor_critic.load_state_dict(checkpoint['model_state_dict'])
        # Parameters must be on the target device before the optimizer state
        # is restored. The former unconditional `.cuda()` call was dropped:
        # it crashed CPU-only runs and is redundant with `.to(device)`.
        actor_critic.to(device)
        agent.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if args.auto_expand:
            if not args.n_recs - saved_args.n_recs == 1:
                print('can expand by 1 rec only from saved model, not {}'
                      .format(args.n_recs - saved_args.n_recs))
                raise Exception
            actor_critic.base.auto_expand()
            print('expanded net: \n{}'.format(actor_critic.base))
        ob_rms = checkpoint['ob_rms']

        # Prefer the optimizer's own step counter; fall back to the value
        # stored in the checkpoint when the optimizer holds no state yet
        # (previously this raised StopIteration).
        optim_state = agent.optimizer.state_dict()['state']
        if optim_state:
            past_steps = next(iter(optim_state.values()))['step']
        else:
            past_steps = checkpoint['past_steps']
        # The counter may be stored as a tensor; range() below needs an int.
        past_steps = int(past_steps)
        print('Resuming from step {}'.format(past_steps))

        if vec_norm is not None:
            vec_norm.eval()
            vec_norm.ob_rms = ob_rms

        # Continue with the saved configuration, but let the current command
        # line override run-control and path settings.
        for name in ('num_frames', 'vis_interval', 'eval_interval',
                     'overwrite', 'n_recs', 'intra_shr', 'inter_shr',
                     'map_width', 'render', 'print_map', 'load_dir',
                     'experiment_name', 'log_dir', 'save_dir'):
            setattr(saved_args, name, getattr(args, name))
        args = saved_args
    actor_critic.to(device)

    if 'LSTM' in args.model:
        recurrent_hidden_state_size = \
            actor_critic.base.get_recurrent_state_size()
    else:
        recurrent_hidden_state_size = \
            actor_critic.recurrent_hidden_state_size

    if args.curiosity:
        rollouts = CuriosityRolloutStorage(
            args.num_steps, args.num_processes,
            envs.observation_space.shape, envs.action_space,
            recurrent_hidden_state_size,
            actor_critic.base.feature_state_size(), args=args)
    else:
        rollouts = RolloutStorage(
            args.num_steps, args.num_processes,
            envs.observation_space.shape, envs.action_space,
            recurrent_hidden_state_size, args=args)

    obs = envs.reset()
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    episode_rewards = deque(maxlen=10)

    start = time.time()
    model = actor_critic.base
    reset_eval = False
    plotter = None
    if args.model == 'FractalNet' or args.model == 'fractal':
        n_cols = model.n_cols
        if args.rule == 'wide1' and args.n_recs > 3:
            col_step = 3
        else:
            col_step = 1
    else:
        n_cols = 0
        col_step = 1

    # These environments are never rendered from this loop.
    skip_render = ('Micropolis' in args.env_name
                   or 'GameOfLife' in args.env_name)

    for j in range(past_steps, num_updates):
        if reset_eval:
            # The evaluator was given these same envs, so start fresh
            # after an evaluation pass.
            print('post eval reset')
            obs = envs.reset()
            rollouts.obs[0].copy_(obs)
            rollouts.to(device)
            reset_eval = False
        if args.model == 'FractalNet' and args.drop_path:
            model.set_drop_path()
        if args.model == 'fixed' and model.RAND:
            model.num_recursions = random.randint(1, model.map_width * 2)

        player_act = None
        for step in range(args.num_steps):
            # Sample actions
            with torch.no_grad():
                if args.render and not skip_render:
                    if args.num_processes != 1:
                        envs.render()
                    envs.venv.venv.render()
                value, action, action_log_probs, recurrent_hidden_states = \
                    actor_critic.act(
                        rollouts.obs[step],
                        rollouts.recurrent_hidden_states[step],
                        rollouts.masks[step],
                        player_act=player_act,
                        icm_enabled=args.curiosity,
                        deterministic=False)

            # Observe reward and next obs
            obs, reward, done, infos = envs.step(action)

            # When rendering, the first env's info dict may carry a
            # 'player_move' entry; it is passed to the policy on the next
            # step as player_act.
            player_act = None
            if args.render and infos[0] and 'player_move' in infos[0]:
                player_act = infos[0]['player_move']

            if args.curiosity:
                # run icm
                # NOTE(review): `action_bin` is not assigned anywhere in
                # this function; unless it exists as a module-level global
                # this path raises NameError. Left unchanged because the
                # intended value cannot be determined from here.
                with torch.no_grad():
                    feature_state, feature_state_pred, action_dist_pred = \
                        actor_critic.icm_act(
                            (rollouts.obs[step], obs, action_bin))

                intrinsic_reward = args.eta * (
                    (feature_state - feature_state_pred).pow(2)).sum() / 2.
                if args.no_reward:
                    reward = 0
                reward += intrinsic_reward.cpu()

            for info in infos:
                if 'episode' in info:
                    episode_rewards.append(info['episode']['r'])

            # If done then clean the history of observations.
            masks = torch.FloatTensor(
                [[0.0] if done_ else [1.0] for done_ in done])

            if args.curiosity:
                rollouts.insert(obs, recurrent_hidden_states, action,
                                action_log_probs, value, reward, masks,
                                feature_state, feature_state_pred,
                                action_bin, action_dist_pred)
            else:
                rollouts.insert(obs, recurrent_hidden_states, action,
                                action_log_probs, value, reward, masks)

        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.obs[-1],
                rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1]).detach()

        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                 args.tau)

        if args.curiosity:
            value_loss, action_loss, dist_entropy, fwd_loss, inv_loss = \
                agent.update(rollouts)
        else:
            value_loss, action_loss, dist_entropy = agent.update(rollouts)

        rollouts.after_update()

        total_num_steps = (j + 1) * args.num_processes * args.num_steps

        if not dist_entropy:
            dist_entropy = 0
        if j % args.log_interval == 0 and len(episode_rewards) > 1:
            end = time.time()
            # FPS counts only the frames produced since this process started.
            frames_this_run = (
                total_num_steps
                - past_steps * args.num_processes * args.num_steps)
            print("Updates {}, num timesteps {}, FPS {} \n Last {} training "
                  "episodes: mean/median reward {:.1f}/{:.1f}, min/max "
                  "reward {:.1f}/{:.1f}\n "
                  "dist entropy {:.1f}, val/act loss {:.1f}/{:.1f},".format(
                      j, total_num_steps,
                      int(frames_this_run / (end - start)),
                      len(episode_rewards),
                      np.mean(episode_rewards),
                      np.median(episode_rewards),
                      np.min(episode_rewards),
                      np.max(episode_rewards),
                      dist_entropy, value_loss, action_loss))
            if args.curiosity:
                print("fwd/inv icm loss {:.1f}/{:.1f}\n".format(
                    fwd_loss, inv_loss))

        if (args.eval_interval is not None
                and len(episode_rewards) > 1
                and j % args.eval_interval == 0):
            if evaluator is None:
                evaluator = Evaluator(args, actor_critic, device, envs=envs,
                                      vec_norm=vec_norm)

            model = evaluator.actor_critic.base

            # Column -1 (the overall net) is always evaluated and plotted,
            # followed by every col_step-th individual column.
            col_idx = [-1, *range(0, n_cols, col_step)]
            for i in col_idx:
                evaluator.evaluate(column=i)
            if args.vis:
                try:
                    # Sometimes monitor doesn't properly flush the outputs
                    win_eval = evaluator.plotter.visdom_plot(
                        viz, win_eval, evaluator.eval_log_dir, graph_name,
                        args.algo, args.num_frames, n_graphs=col_idx)
                except IOError:
                    pass
            reset_eval = True

        if j % args.save_interval == 0 and args.save_dir != "":
            save_path = os.path.join(args.save_dir)
            try:
                os.makedirs(save_path)
            except OSError:
                pass

            # Save copies so the live model stays on its device: the model
            # copy is moved to CPU, the agent copy supplies optimizer state.
            ob_rms = getattr(get_vec_normalize(envs), 'ob_rms', None)
            save_model = copy.deepcopy(actor_critic)
            save_agent = copy.deepcopy(agent)
            if args.cuda:
                save_model.cpu()
            optim_save = save_agent.optimizer.state_dict()

            torch.save({
                'past_steps': next(iter(
                    agent.optimizer.state_dict()['state'].values()))['step'],
                'model_state_dict': save_model.state_dict(),
                'optimizer_state_dict': optim_save,
                'ob_rms': ob_rms,
                'args': args,
            }, os.path.join(save_path, args.env_name + ".tar"))

        if args.vis and j % args.vis_interval == 0:
            if plotter is None:
                plotter = Plotter(n_cols, args.log_dir, args.num_processes)
            try:
                # Sometimes monitor doesn't properly flush the outputs
                win = plotter.visdom_plot(viz, win, args.log_dir, graph_name,
                                          args.algo, args.num_frames)
            except IOError:
                pass
def train(self):
    """Run one training update and its periodic logging/eval/save/plot.

    Collects ``args.num_steps`` environment steps via ``self.step()``,
    computes returns, updates the agent, and then -- depending on
    ``self.n_train`` and the configured intervals -- prints statistics,
    runs the evaluator, saves a checkpoint and updates the visdom plots.

    Attributes written: ``reset_eval``, ``player_act``, ``n_step``,
    ``evaluator``, ``win_eval``, ``plotter``, ``win`` and, when saving,
    ``agent``, ``save_model``, ``optim_save``, ``args`` and ``ob_rms``.
    It also sets ``envs.dist_entropy``.
    """
    evaluator = self.evaluator
    episode_rewards = self.episode_rewards
    args = self.args
    actor_critic = self.actor_critic
    rollouts = self.rollouts
    agent = self.agent
    envs = self.envs
    plotter = self.plotter
    n_train = self.n_train
    start = self.start
    n_cols = self.n_cols
    model = self.model
    device = self.device
    vec_norm = self.vec_norm

    if self.reset_eval:
        # Flag set by the evaluation branch below on a previous call:
        # restart the training envs before collecting again.
        obs = envs.reset()
        rollouts.obs[0].copy_(obs)
        rollouts.to(device)
        self.reset_eval = False
    if args.model == 'FractalNet' and args.drop_path:
        model.set_drop_path()
    if args.model == 'fixed' and model.RAND:
        model.num_recursions = random.randint(1, model.map_width * 2)

    self.player_act = None
    # The loop index is stored on self so that self.step() can read it.
    for self.n_step in range(args.num_steps):
        # Sample actions
        self.step()

    with torch.no_grad():
        next_value = actor_critic.get_value(
            rollouts.obs[-1],
            rollouts.recurrent_hidden_states[-1],
            rollouts.masks[-1]).detach()

    rollouts.compute_returns(next_value, args.use_gae, args.gamma, args.tau)

    if args.curiosity:
        value_loss, action_loss, dist_entropy, fwd_loss, inv_loss = \
            agent.update(rollouts)
    else:
        value_loss, action_loss, dist_entropy = agent.update(rollouts)
    envs.dist_entropy = dist_entropy

    rollouts.after_update()

    total_num_steps = (n_train + 1) * args.num_processes * args.num_steps

    if not dist_entropy:
        dist_entropy = 0
    if (args.log and n_train % args.log_interval == 0
            and len(episode_rewards) > 1):
        end = time.time()
        print("Updates {}, num timesteps {}, FPS {} \n Last {} training "
              "episodes: mean/median reward {:.6f}/{:.6f}, min/max reward "
              "{:.6f}/{:.6f}\n "
              "dist entropy {:.6f}, val/act loss {:.6f}/{:.6f},".format(
                  n_train, total_num_steps,
                  int((self.n_frames - self.past_frames) / (end - start)),
                  len(episode_rewards),
                  round(np.mean(episode_rewards), 6),
                  round(np.median(episode_rewards), 6),
                  round(np.min(episode_rewards), 6),
                  round(np.max(episode_rewards), 6),
                  round(dist_entropy, 6),
                  round(value_loss, 6),
                  round(action_loss, 6)))
        if args.curiosity:
            print("fwd/inv icm loss {:.1f}/{:.1f}\n".format(
                fwd_loss, inv_loss))

    if (args.eval_interval is not None
            and len(episode_rewards) > 1
            and n_train % args.eval_interval == 0):
        if evaluator is None:
            evaluator = Evaluator(args, actor_critic, device, envs=envs,
                                  vec_norm=vec_norm,
                                  fieldnames=self.fieldnames)
            self.evaluator = evaluator

        # Column -1 (the overall net) is always evaluated and plotted,
        # followed by every col_step-th individual column.
        col_idx = [-1, *range(0, n_cols, self.col_step)]
        for i in col_idx:
            evaluator.evaluate(column=i)
        viz = self.viz
        win_eval = self.win_eval
        graph_name = self.graph_name
        if args.vis:
            try:
                # Sometimes monitor doesn't properly flush the outputs
                win_eval = evaluator.plotter.visdom_plot(
                    viz, win_eval, evaluator.eval_log_dir, graph_name,
                    args.algo, args.num_frames, n_graphs=col_idx)
                # Fix: keep the window handle so the next call passes it
                # back to visdom_plot; it was previously lost.
                self.win_eval = win_eval
            except IOError:
                pass
        self.reset_eval = True

    if (args.save and n_train % args.save_interval == 0
            and args.save_dir != ""):
        save_path = os.path.join(args.save_dir)
        try:
            os.makedirs(save_path)
        except OSError:
            pass

        # Save copies so the live model stays on its device: the model copy
        # is moved to CPU, the agent copy supplies the optimizer state.
        ob_rms = getattr(get_vec_normalize(envs), 'ob_rms', None)
        save_model = copy.deepcopy(actor_critic)
        save_agent = copy.deepcopy(agent)
        if args.cuda:
            save_model.cpu()
        optim_save = save_agent.optimizer.state_dict()
        # Presumably get_save_dict() assembles the checkpoint from these
        # attributes -- verify against its definition.
        self.agent = agent
        self.save_model = save_model
        self.optim_save = optim_save
        self.args = args
        self.ob_rms = ob_rms
        torch.save(self.get_save_dict(),
                   os.path.join(save_path, args.env_name + ".tar"))

        print('model saved at {}'.format(save_path))

    if args.vis and n_train % args.vis_interval == 0:
        if plotter is None:
            plotter = Plotter(n_cols, args.log_dir, args.num_processes)
            # Fix: remember the plotter; previously a new one was built on
            # every plotting call because only the local was assigned.
            self.plotter = plotter
        try:
            # Sometimes monitor doesn't properly flush the outputs
            viz = self.viz
            win = self.win
            graph_name = self.graph_name
            win = plotter.visdom_plot(viz, win, args.log_dir, graph_name,
                                      args.algo, args.num_frames)
            # Fix: persist the window handle across calls.
            self.win = win
        except IOError:
            pass