import copy


def run(args):
    device = 'gpu' if args.ppo_use_gpu else 'cpu'

    # Rollout buffer, parallel environments and PPO model
    buff = PPO_Buffer(args.n_agents, args.ppo_workers, args.ppo_rollout_length,
                      args.ppo_recurrent)
    env = make_parallel_env(args, args.seed, buff.nworkers)
    ppo = PPO(env.action_space[0].n, env.observation_space[0],
              args.ppo_base_policy_type, env.n_agents[0], args.ppo_share_actor,
              args.ppo_share_value, args.ppo_k_epochs, args.ppo_minibatch_size,
              args.ppo_lr_a, args.ppo_lr_v, args.ppo_hidden_dim,
              args.ppo_eps_clip, args.ppo_entropy_coeff, args.ppo_recurrent,
              args.ppo_heur_block)
    buff.init_model(ppo)
    logger = Logger(args, "No summary", "no policy summary", ppo)

    obs = env.reset()
    # Per-worker, per-agent hidden/cell states for the recurrent actor and critic
    if args.ppo_recurrent:
        hx_cx_actr = [{i: ppo.init_hx_cx(device) for i in ob.keys()} for ob in obs]
        hx_cx_cr = [{i: ppo.init_hx_cx(device) for i in ob.keys()} for ob in obs]
    else:
        hx_cx_actr = [{i: (None, None) for i in ob.keys()} for ob in obs]
        hx_cx_cr = [{i: (None, None) for i in ob.keys()} for ob in obs]

    stats = {}
    for it in range(args.ppo_iterations):
        print("Iteration: {}".format(it))
        render_frames = []
        render_cntr = 0
        info2 = [{"valid_act": {i: None for i in ob.keys()}} for ob in obs]

        # Collect rollouts until the buffer is full
        while not buff.is_full:
            if args.ppo_heur_valid_act:
                val_act_hldr = copy.deepcopy(env.return_valid_act())
                info2 = [{"valid_act": hldr} for hldr in val_act_hldr]
            else:
                val_act_hldr = [{i: None for i in ob.keys()} for ob in obs]

            ppo.prep_device(device)
            a_probs, a_select, value, hx_cx_actr_n, hx_cx_cr_n, blocking = zip(
                *[ppo.forward(ob, ha, hc, dev=device,
                              valid_act_heur=inf2["valid_act"])
                  for ob, ha, hc, inf2 in zip(obs, hx_cx_actr, hx_cx_cr, info2)])
            a_env_dict = [{key: val.item() for key, val in hldr.items()}
                          for hldr in a_select]
            next_obs, r, dones, info = env.step(a_env_dict)

            if it % args.render_rate == 0:
                if render_cntr < args.render_length:
                    render_frames.append(env.render(indices=[0])[0])
                    if info[0]["terminate"]:
                        render_cntr += 1

            next_obs_ = get_n_obs(next_obs, info)
            buff.add(obs, r, value, next_obs_, a_probs, a_select, info, dones,
                     hx_cx_actr, hx_cx_cr, hx_cx_actr_n, hx_cx_cr_n, blocking,
                     val_act_hldr)
            hx_cx_actr, hx_cx_cr = hx_cx_actr_n, hx_cx_cr_n

            # Reset hidden and cell states when an episode is done
            for i, inf in enumerate(info):
                if inf["terminate"] and args.ppo_recurrent:
                    hx_cx_actr, hx_cx_cr = list(hx_cx_actr), list(hx_cx_cr)
                    hx_cx_actr[i] = {i2: ppo.init_hx_cx(device)
                                     for i2 in hx_cx_actr[i].keys()}
                    hx_cx_cr[i] = {i2: ppo.init_hx_cx(device)
                                   for i2 in hx_cx_cr[i].keys()}
            obs = next_obs

        if buff.is_full:
            observations, a_prob, a_select, adv, v, infos, h_actr, h_cr, \
                blk_labels, blk_pred, val_act = buff.sample(
                    args.ppo_discount, args.ppo_gae_lambda,
                    blocking=args.ppo_heur_block,
                    use_valid_act=args.ppo_heur_valid_act)
            stats["action_loss"], stats["value_loss"] = ppo.update(
                observations, a_prob, a_select, adv, v, 0, h_actr, h_cr,
                blk_labels, blk_pred, val_act, dev=device)
            if args.ppo_recurrent:
                # Detach hidden/cell states so gradients do not flow across updates
                for a, c in zip(hx_cx_actr, hx_cx_cr):
                    for a2, c2 in zip(a.values(), c.values()):
                        a2[0].detach_()
                        a2[1].detach_()
                        c2[0].detach_()
                        c2[1].detach_()

        if (it + 1) % args.benchmark_frequency == 0:
            rend_f, ter_i = benchmark_func(args, ppo, args.benchmark_num_episodes,
                                           args.benchmark_render_length, device)
            logger.benchmark_info(ter_i, rend_f, it)

        # Logging:
        stats["iterations"] = it
        stats["num_timesteps"] = len(infos)
        terminal_t_info = [inf for inf in infos if inf["terminate"]]
        stats["num_episodes"] = len(terminal_t_info)
        logger.log(stats, terminal_t_info, render_frames, checkpoint=True)
        render_frames = []
        render_cntr = 0

    # Final benchmark after training
    rend_f, ter_i = benchmark_func(args, ppo, args.benchmark_num_episodes,
                                   args.benchmark_render_length, device)
    logger.benchmark_info(ter_i, rend_f, it)
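# The rollout loop above calls ppo.forward() once per worker and uses
# zip(*[...]) to transpose the list of per-worker result tuples into
# per-quantity tuples (a_probs, a_select, value, ...). The function below is
# an illustrative, self-contained sketch of that transpose pattern only; it is
# not part of the training code and toy_forward is a hypothetical stand-in
# for ppo.forward().
def _zip_transpose_example():
    def toy_forward(ob):
        # Stand-in for ppo.forward(): returns (action, value) for one worker.
        return ob + 1, ob * 2

    obs = [0, 1, 2]  # one entry per worker
    actions, values = zip(*[toy_forward(ob) for ob in obs])
    assert actions == (1, 2, 3)
    assert values == (0, 2, 4)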
import copy

import numpy as np


def run(args):
    # Curriculum train:
    device = 'gpu' if args.ppo_use_gpu else 'cpu'
    assert 0.0 <= args.ppo_bc_iteration_prob <= 1.0

    curr_manager = CurriculumManager(args, CurriculumLogger)
    # Get env args for the first env
    env_args = curr_manager.init_env_args()
    # Determine the number of workers for the buffer and init the env
    buff = PPO_Buffer(env_args.n_agents, args.ppo_workers,
                      args.ppo_rollout_length, args.ppo_recurrent)
    seed = np.random.randint(0, 100000)
    env = make_parallel_env(env_args, seed, buff.nworkers)
    # Init PPO model
    ppo = PPO(env.action_space[0].n, env.observation_space[0],
              args.ppo_base_policy_type, env.n_agents[0], args.ppo_share_actor,
              args.ppo_share_value, args.ppo_k_epochs, args.ppo_minibatch_size,
              args.ppo_lr_a, args.ppo_lr_v, args.ppo_hidden_dim,
              args.ppo_eps_clip, args.ppo_entropy_coeff, args.ppo_recurrent,
              args.ppo_heur_block)
    # Add model to buffer
    buff.init_model(ppo)
    logger = curr_manager.init_logger(ppo, benchmark_func)
    global_iterations = 0
    env.close()

    while not curr_manager.is_done:
        # Sample the next curriculum environment and re-initialise the buffer
        env_args = curr_manager.sample_env()
        buff.__init__(env_args.n_agents, args.ppo_workers,
                      args.ppo_rollout_length, args.ppo_recurrent)  # recalculates nworkers
        ppo.extend_agent_indexes(env_args.n_agents)
        buff.init_model(ppo)
        seed = np.random.randint(0, 10000)
        env = make_parallel_env(env_args, seed, buff.nworkers)
        obs = env.reset()

        if args.ppo_recurrent:
            hx_cx_actr = [{i: ppo.init_hx_cx(device) for i in ob.keys()} for ob in obs]
            hx_cx_cr = [{i: ppo.init_hx_cx(device) for i in ob.keys()} for ob in obs]
        else:
            hx_cx_actr = [{i: (None, None) for i in ob.keys()} for ob in obs]
            hx_cx_cr = [{i: (None, None) for i in ob.keys()} for ob in obs]

        extra_stats = {}
        env_id = curr_manager.curr_env_id
        up_i = 0
        # for up_i in range(curr_manager.n_updates):
        while up_i < curr_manager.n_updates:
            # With probability ppo_bc_iteration_prob run a behaviour-cloning iteration
            bc_iteration = np.random.choice(
                [True, False],
                p=[args.ppo_bc_iteration_prob, 1 - args.ppo_bc_iteration_prob])
            if bc_iteration:
                n_sub_updates = ((args.ppo_workers * args.ppo_rollout_length)
                                 // args.ppo_minibatch_size)
                bc_training_iteration2(args, env_args, ppo, device, 64,
                                       n_sub_updates * 4)
            else:
                print("Iteration: {}".format(global_iterations))
                info2 = [{"valid_act": {i: None for i in ob.keys()}} for ob in obs]

                # Collect rollouts until the buffer is full
                while not buff.is_full:
                    if args.ppo_heur_valid_act:
                        val_act_hldr = copy.deepcopy(env.return_valid_act())
                        info2 = [{"valid_act": hldr} for hldr in val_act_hldr]
                    else:
                        val_act_hldr = [{i: None for i in ob.keys()} for ob in obs]

                    a_probs, a_select, value, hx_cx_actr_n, hx_cx_cr_n, blocking = zip(
                        *[ppo.forward(ob, ha, hc, dev=device,
                                      valid_act_heur=inf2["valid_act"])
                          for ob, ha, hc, inf2 in zip(obs, hx_cx_actr, hx_cx_cr, info2)])
                    a_env_dict = [{key: val.item() for key, val in hldr.items()}
                                  for hldr in a_select]
                    next_obs, r, dones, info = env.step(a_env_dict)
                    logger.record_render(env_id, env, info[0])

                    next_obs_ = get_n_obs(next_obs, info)
                    buff.add(obs, r, value, next_obs_, a_probs, a_select, info,
                             dones, hx_cx_actr, hx_cx_cr, hx_cx_actr_n,
                             hx_cx_cr_n, blocking, val_act_hldr)
                    hx_cx_actr = [{k: v for k, v in hldr.items()} for hldr in hx_cx_actr_n]
                    hx_cx_cr = [{k: v for k, v in hldr.items()} for hldr in hx_cx_cr_n]

                    # Reset hidden and cell states when an episode is done
                    for i, inf in enumerate(info):
                        if inf["terminate"] and args.ppo_recurrent:
                            hx_cx_actr, hx_cx_cr = list(hx_cx_actr), list(hx_cx_cr)
                            hx_cx_actr[i] = {i2: ppo.init_hx_cx(device)
                                             for i2 in hx_cx_actr[i].keys()}
                            hx_cx_cr[i] = {i2: ppo.init_hx_cx(device)
                                           for i2 in hx_cx_cr[i].keys()}
                    obs = next_obs

                if buff.is_full:
                    observations, a_prob, a_select, adv, v, infos, h_actr, h_cr, \
                        blk_labels, blk_pred, val_act = buff.sample(
                            args.ppo_discount, args.ppo_gae_lambda,
                            blocking=args.ppo_heur_block,
                            use_valid_act=args.ppo_heur_valid_act)
                    extra_stats["action_loss"], extra_stats["value_loss"] = ppo.update(
                        observations, a_prob, a_select, adv, v, 0, h_actr, h_cr,
                        blk_labels, blk_pred, val_act, dev=device)
                    if args.ppo_recurrent:
                        # Detach hidden/cell states so gradients do not flow across updates
                        for a, c in zip(hx_cx_actr, hx_cx_cr):
                            for a2, c2 in zip(a.values(), c.values()):
                                a2[0].detach_()
                                a2[1].detach_()
                                c2[0].detach_()
                                c2[1].detach_()
                    global_iterations += 1
                    logger.log(env_id, infos, extra_stats)

            if up_i == curr_manager.n_updates - 1:
                logger.release_render(env_id)
            up_i += 1
        env.close()
    print("Done")