def run(args): #Curriculum train: if args.ppo_use_gpu: device = 'gpu' else: device = 'cpu' assert args.ppo_bc_iteration_prob >= 0.0 and args.ppo_bc_iteration_prob <= 1.0 curr_manager = CurriculumManager(args, CurriculumLogger) #Get env args for first env env_args = curr_manager.init_env_args() #Determine number of workers for buffer and init env buff = PPO_Buffer(env_args.n_agents, args.ppo_workers, args.ppo_rollout_length, args.ppo_recurrent) seed = np.random.randint(0, 100000) env = make_parallel_env(env_args, seed, buff.nworkers) #Init ppo model ppo = PPO(env.action_space[0].n, env.observation_space[0], args.ppo_base_policy_type, env.n_agents[0], args.ppo_share_actor, args.ppo_share_value, args.ppo_k_epochs, args.ppo_minibatch_size, args.ppo_lr_a, args.ppo_lr_v, args.ppo_hidden_dim, args.ppo_eps_clip, args.ppo_entropy_coeff, args.ppo_recurrent, args.ppo_heur_block) #Add model to buffer buff.init_model(ppo) logger = curr_manager.init_logger(ppo, benchmark_func) global_iterations = 0 env.close() while not curr_manager.is_done: env_args = curr_manager.sample_env() buff.__init__(env_args.n_agents, args.ppo_workers, args.ppo_rollout_length, args.ppo_recurrent) #recalculates nworkers ppo.extend_agent_indexes(env_args.n_agents) buff.init_model(ppo) seed = np.random.randint(0, 10000) env = make_parallel_env(env_args, seed, buff.nworkers) obs = env.reset() if args.ppo_recurrent: hx_cx_actr = [{i: ppo.init_hx_cx(device) for i in ob.keys()} for ob in obs] hx_cx_cr = [{i: ppo.init_hx_cx(device) for i in ob.keys()} for ob in obs] else: hx_cx_actr = [{i: (None, None) for i in ob.keys()} for ob in obs] hx_cx_cr = [{i: (None, None) for i in ob.keys()} for ob in obs] extra_stats = {} env_id = curr_manager.curr_env_id up_i = 0 #for up_i in range(curr_manager.n_updates): while up_i < curr_manager.n_updates: bc_iteration = np.random.choice( [True, False], p=[args.ppo_bc_iteration_prob, 1 - args.ppo_bc_iteration_prob]) if bc_iteration: n_sub_updates = ( (args.ppo_workers * 
args.ppo_rollout_length) // args.ppo_minibatch_size) bc_training_iteration2(args, env_args, ppo, device, 64, n_sub_updates * 4) else: print("Iteration: {}".format(global_iterations)) info2 = [{ "valid_act": {i: None for i in ob.keys()} } for ob in obs] while buff.is_full == False: if args.ppo_heur_valid_act: val_act_hldr = copy.deepcopy(env.return_valid_act()) info2 = [{"valid_act": hldr} for hldr in val_act_hldr] else: val_act_hldr = [{i: None for i in ob.keys()} for ob in obs] a_probs, a_select, value, hx_cx_actr_n, hx_cx_cr_n, blocking = zip(*[ppo.forward(ob,ha,hc, dev=device, valid_act_heur = inf2["valid_act"] ) \ for ob,ha,hc, inf2 in zip(obs, hx_cx_actr, hx_cx_cr, info2)]) a_env_dict = [{ key: val.item() for key, val in hldr.items() } for hldr in a_select] next_obs, r, dones, info = env.step(a_env_dict) logger.record_render(env_id, env, info[0]) next_obs_ = get_n_obs(next_obs, info) buff.add(obs, r, value, next_obs_, a_probs, a_select, info, dones, hx_cx_actr, hx_cx_cr, hx_cx_actr_n, hx_cx_cr_n, blocking, val_act_hldr) hx_cx_actr = [{k: v for k, v in hldr.items()} for hldr in hx_cx_actr_n] hx_cx_cr = [{k: v for k, v in hldr.items()} for hldr in hx_cx_cr_n] #Reset hidden and cell states when epidode done for i, inf in enumerate(info): if inf["terminate"] and args.ppo_recurrent: hx_cx_actr, hx_cx_cr = list(hx_cx_actr), list( hx_cx_cr) hx_cx_actr[i] = { i2: ppo.init_hx_cx(device) for i2 in hx_cx_actr[i].keys() } hx_cx_cr[i] = { i2: ppo.init_hx_cx(device) for i2 in hx_cx_cr[i].keys() } obs = next_obs if buff.is_full: observations, a_prob, a_select, adv, v, infos, h_actr, h_cr, blk_labels, blk_pred, val_act = buff.sample(args.ppo_discount,\ args.ppo_gae_lambda,args.ppo_gae_lambda, blocking=args.ppo_heur_block, use_valid_act = args.ppo_heur_valid_act) extra_stats["action_loss"], extra_stats["value_loss"] \ = ppo.update(observations, a_prob, a_select, adv, v, 0, h_actr, h_cr, blk_labels, blk_pred,val_act, dev=device) if args.ppo_recurrent: for a, c in 
zip(hx_cx_actr, hx_cx_cr): for a2, c2 in zip(a.values(), c.values()): a2[0].detach_() a2[1].detach_() c2[0].detach_() c2[1].detach_() global_iterations += 1 logger.log(env_id, infos, extra_stats) if up_i == curr_manager.n_updates - 1: logger.release_render(env_id) up_i += 1 env.close() print("Done")
def BC_train():
    """Behaviour-clone a PPO actor on pre-recorded 5x5 gridworld data.

    For every hyper-parameter combination in ``parmeter_grid1``:
    - a fresh PPO model is built;
    - its first actor is trained with Adam to minimise the negative
      log-probability of the recorded actions.

    During training:
    - a checkpoint is saved every 2 epochs;
    - greedy and non-greedy benchmarks run every 10 epochs;
    - train and validation losses go to TensorBoard through ``Logger``.

    Paths, device and the hyper-parameter grid are hard-coded below.
    """

    def loss_f(pred, label):
        """Mean negative log-probability of the labelled actions.

        Assumes ``pred`` holds per-action probabilities along its last axis
        (TODO confirm against the actor's output); ``label`` is the (N, 1)
        column of action indices built by ``make_loader``.
        """
        action_label_prob = torch.gather(pred, -1, label.long())
        log_actions = -torch.log(action_label_prob)
        return log_actions.mean()

    def get_validation_loss(validate_loader, ppo_policy):
        """Mean of the per-batch ``loss_f`` values over ``validate_loader``, without gradients."""
        with torch.no_grad():
            valid_loss_hldr = []
            for ob, a in validate_loader:
                (a_pred, _, _, _) = ppo_policy.actors[0].forward(ob)
                valid_loss_hldr.append(loss_f(a_pred, a).item())
            return np.mean(valid_loss_hldr)

    def save(model, logger, end):
        """Save ``model`` as ``checkpoint_<end>`` inside ``logger.checkpoint_dir``."""
        name = "checkpoint_" + str(end)
        checkpoint_path = os.path.join(logger.checkpoint_dir, name)
        model.save(checkpoint_path)

    def make_loader(model, data, batch_size, shuffle, device):
        """Build a DataLoader of (obs, action-label) tensors from (obs, action) pairs.

        Both tensors are moved to ``device`` up front via ``model.tens_to_dev``;
        action labels become an (N, 1) float column.
        """
        obs, actions = zip(*data)
        obs, actions = np.array(obs), np.array(actions)
        obs_t = model.tens_to_dev(device, torch.from_numpy(obs).float())
        labels_t = model.tens_to_dev(
            device, torch.from_numpy(actions).reshape((-1, 1)).float())
        dataset = torch.utils.data.TensorDataset(obs_t, labels_t)
        return torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, shuffle=shuffle)

    experiment_group_name = "BC_5x5"
    work_dir = '/home/james/Desktop/Gridworld/EXPERIMENTS/' + experiment_group_name
    plot_dir = '/home/james/Desktop/Gridworld/CENTRAL_TENSORBOARD/' + experiment_group_name

    parser = argparse.ArgumentParser("Generate Data")
    parser.add_argument("--map_shape", default=(5, 5), type=object)
    parser.add_argument("--n_agents", default=4, type=int)
    parser.add_argument("--env_name", default='independent_navigation-v0', type=str)
    parser.add_argument("--use_default_rewards", default=True, type=bool)
    # float, not int: a density such as "0.2" passed on the command line
    # cannot be parsed by int().
    parser.add_argument("--obj_density", default=0.2, type=float)
    parser.add_argument("--use_custom_rewards", default=False, action='store_true')
    parser.add_argument("--custom_env_ind", default=1, type=int)
    parser.add_argument("--ppo_recurrent", default=False, action='store_true')
    parser.add_argument("--alternative_plot_dir", default="none")
    parser.add_argument("--working_directory", default="none")
    parser.add_argument("--name", default="NO_NAME", type=str, help="Experiment name")
    base_cli_args = [
        "--working_directory", work_dir, "--alternative_plot_dir", plot_dir
    ]

    DEVICE = 'gpu'
    parmeter_grid1 = {
        # Earlier configs: 1: mbsize_32_lr_0.0001_epochs_100_weightdecay_0.0001
        #                  2: mbsize_32_lr_5e-05_epochs_40_weightdecay_0.0001
        "bc_mb_size": [32],
        "bc_lr": [0.000005],  # previously 0.0001
        "n_epoch": [150],
        "weight_decay": [0.0001],
    }
    grid1 = ParameterGrid(parmeter_grid1)

    data_folder_path = '/home/james/Desktop/Gridworld/BC_Data/5x5'
    combine_data = False
    if combine_data:
        data, files = combine_all_data(data_folder_path)
        delete_files(files)
        train_f, val_f, test_f = split_data(data, data_folder_path, "5x5")
    else:
        train_f, val_f, test_f = get_data_files(data_folder_path)

    # Get data from files (test split is loaded but not used yet).
    train_data = torch.load(train_f)
    val_data = torch.load(val_f)
    test_data = torch.load(test_f)

    for param in grid1:
        name = ("BC_" + "5x5_"
                + "mbsize_" + str(param["bc_mb_size"])
                + "_lr_" + str(param["bc_lr"])
                + "_epochs_" + str(param["n_epoch"])
                + "_weightdecay_" + str(param["weight_decay"]))
        # Parse a fresh argv per grid point. Extending a shared list and then
        # rebinding it to the parsed Namespace breaks the second grid point.
        args = parser.parse_args(base_cli_args + ["--name", name])

        ppo = PPO(5, spaces.Box(low=0, high=1, shape=(5, 5, 5), dtype=int),
                  "primal6", 1, True, True, 1, 1, param["bc_lr"], 0.001, 120,
                  0.2, 0.01, False, False)
        logger = Logger(args, "5x5", "none", ppo)

        ppo.prep_device(DEVICE)
        train_loader = make_loader(ppo, train_data, param["bc_mb_size"],
                                   True, DEVICE)
        # No shuffling for validation: the same batch partition every epoch
        # keeps the mean-of-batch-means comparable across epochs.
        valid_loader = make_loader(ppo, val_data, param["bc_mb_size"],
                                   False, DEVICE)

        optimizer = optim.Adam(ppo.actors[0].parameters(), lr=param["bc_lr"],
                               weight_decay=param["weight_decay"])

        print(name)
        for epoch in range(param["n_epoch"]):
            epoch_loss_hldr = []
            for ob, a in train_loader:
                (a_pred, _, _, _) = ppo.actors[0].forward(ob)
                loss = loss_f(a_pred, a)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                epoch_loss_hldr.append(loss.item())
            iterations = len(epoch_loss_hldr)

            if (epoch + 1) % 2 == 0:
                save(ppo, logger, epoch)
            if (epoch + 1) % 10 == 0:
                ppo.extend_agent_indexes(args.n_agents)
                rend_frames, all_info = benchmark_func(args, ppo, 100, 30, DEVICE, greedy=True)
                logger.benchmark_info(all_info, rend_frames, epoch, end_str="_greedy")
                rend_frames, all_info = benchmark_func(args, ppo, 100, 30, DEVICE, greedy=False)
                logger.benchmark_info(all_info, rend_frames, epoch, end_str="_NotGreedy")

            print("iterations: {}".format(iterations))
            epoch_loss = np.mean(epoch_loss_hldr)
            valid_loss = get_validation_loss(valid_loader, ppo)
            log_info = {
                "train_loss": epoch_loss,
                "validation_loss": valid_loss
            }
            logger.plot_tensorboard_custom_keys(log_info, external_iteration=epoch)
            print("Epoch: {} Train Loss: {} Validation Loss: {}".format(
                epoch, epoch_loss, valid_loss))

        print("Done")
        save(ppo, logger, "end")
        # Evaluate the trained policy (benchmark).
        ppo.extend_agent_indexes(args.n_agents)
        rend_frames, all_info = benchmark_func(args, ppo, 100, 30, DEVICE)
        logger.benchmark_info(all_info, rend_frames, param["n_epoch"] + 1)