Example #1

import argparse
import copy
import os

import numpy as np
import torch
import torch.optim as optim
import torch.utils.data
from gym import spaces
from sklearn.model_selection import ParameterGrid

#Project-specific helpers (PPO, PPO_Buffer, CurriculumManager, CurriculumLogger,
#Logger, make_parallel_env, benchmark_func, get_n_obs, bc_training_iteration2,
#combine_all_data, delete_files, split_data, get_data_files) are assumed to be
#imported from the project's own modules.

def run(args):  #Curriculum train:
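    #Curriculum training: build the curriculum manager, rollout buffer, parallel
    #environments and PPO model, then for each sampled environment alternate
    #behaviour-cloning iterations with PPO rollout-and-update iterations.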
    if args.ppo_use_gpu:
        device = 'gpu'
    else:
        device = 'cpu'

    assert 0.0 <= args.ppo_bc_iteration_prob <= 1.0

    curr_manager = CurriculumManager(args, CurriculumLogger)
    #Get env args for first env
    env_args = curr_manager.init_env_args()
    #Determine number of workers for buffer and init env
    buff = PPO_Buffer(env_args.n_agents, args.ppo_workers,
                      args.ppo_rollout_length, args.ppo_recurrent)
    seed = np.random.randint(0, 100000)
    env = make_parallel_env(env_args, seed, buff.nworkers)
    #Init ppo model
    ppo = PPO(env.action_space[0].n, env.observation_space[0],
              args.ppo_base_policy_type, env.n_agents[0], args.ppo_share_actor,
              args.ppo_share_value, args.ppo_k_epochs, args.ppo_minibatch_size,
              args.ppo_lr_a, args.ppo_lr_v, args.ppo_hidden_dim,
              args.ppo_eps_clip, args.ppo_entropy_coeff, args.ppo_recurrent,
              args.ppo_heur_block)
    #Add model to buffer
    buff.init_model(ppo)

    logger = curr_manager.init_logger(ppo, benchmark_func)
    global_iterations = 0
    env.close()
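    #The initial env was only needed to size the PPO model; a fresh set of
    #parallel envs is created for every configuration the curriculum samples.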
    while not curr_manager.is_done:
        env_args = curr_manager.sample_env()
        #Re-create the buffer so nworkers is recalculated for the new env config
        buff = PPO_Buffer(env_args.n_agents, args.ppo_workers,
                          args.ppo_rollout_length, args.ppo_recurrent)
        ppo.extend_agent_indexes(env_args.n_agents)
        buff.init_model(ppo)
        seed = np.random.randint(0, 10000)
        env = make_parallel_env(env_args, seed, buff.nworkers)

        obs = env.reset()
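        #One (hidden, cell) state pair per agent per worker; (None, None)
        #placeholders stand in when the policy is not recurrent.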
        if args.ppo_recurrent:
            hx_cx_actr = [{i: ppo.init_hx_cx(device)
                           for i in ob.keys()} for ob in obs]
            hx_cx_cr = [{i: ppo.init_hx_cx(device)
                         for i in ob.keys()} for ob in obs]
        else:
            hx_cx_actr = [{i: (None, None) for i in ob.keys()} for ob in obs]
            hx_cx_cr = [{i: (None, None) for i in ob.keys()} for ob in obs]

        extra_stats = {}
        env_id = curr_manager.curr_env_id
        up_i = 0
        while up_i < curr_manager.n_updates:
            bc_iteration = np.random.choice(
                [True, False],
                p=[args.ppo_bc_iteration_prob, 1 - args.ppo_bc_iteration_prob])
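            #With probability ppo_bc_iteration_prob this iteration is a
            #behaviour-cloning update; otherwise a PPO rollout and update is run.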
            if bc_iteration:
                n_sub_updates = (
                    (args.ppo_workers * args.ppo_rollout_length) //
                    args.ppo_minibatch_size)
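                #Number of PPO minibatches that fit in one rollout; the BC
                #iteration below runs four times as many supervised updates.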
                bc_training_iteration2(args, env_args, ppo, device, 64,
                                       n_sub_updates * 4)
            else:
                print("Iteration: {}".format(global_iterations))
                info2 = [{
                    "valid_act": {i: None
                                  for i in ob.keys()}
                } for ob in obs]
                while not buff.is_full:
                    if args.ppo_heur_valid_act:
                        val_act_hldr = copy.deepcopy(env.return_valid_act())
                        info2 = [{"valid_act": hldr} for hldr in val_act_hldr]
                    else:
                        val_act_hldr = [{i: None
                                         for i in ob.keys()} for ob in obs]

                    a_probs, a_select, value, hx_cx_actr_n, hx_cx_cr_n, blocking = zip(
                        *[ppo.forward(ob, ha, hc, dev=device,
                                      valid_act_heur=inf2["valid_act"])
                          for ob, ha, hc, inf2 in zip(obs, hx_cx_actr, hx_cx_cr, info2)])

                    #Convert each agent's action tensor to a plain int keyed by agent id
                    a_env_dict = [{
                        key: val.item()
                        for key, val in hldr.items()
                    } for hldr in a_select]
                    next_obs, r, dones, info = env.step(a_env_dict)
                    logger.record_render(env_id, env, info[0])
                    next_obs_ = get_n_obs(next_obs, info)
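                    #The buffer stores next_obs_ (from get_n_obs), while the raw
                    #next_obs becomes the current observation for the next step.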

                    buff.add(obs, r, value, next_obs_, a_probs, a_select, info,
                             dones, hx_cx_actr, hx_cx_cr, hx_cx_actr_n,
                             hx_cx_cr_n, blocking, val_act_hldr)

                    hx_cx_actr = [dict(hldr) for hldr in hx_cx_actr_n]
                    hx_cx_cr = [dict(hldr) for hldr in hx_cx_cr_n]
                    #Reset hidden and cell states when an episode terminates
                    for i, inf in enumerate(info):
                        if inf["terminate"] and args.ppo_recurrent:
                            hx_cx_actr[i] = {
                                i2: ppo.init_hx_cx(device)
                                for i2 in hx_cx_actr[i].keys()
                            }
                            hx_cx_cr[i] = {
                                i2: ppo.init_hx_cx(device)
                                for i2 in hx_cx_cr[i].keys()
                            }
                    obs = next_obs

                if buff.is_full:
                    #Sample advantages and targets from the full rollout, then run the PPO update
                    (observations, a_prob, a_select, adv, v, infos, h_actr,
                     h_cr, blk_labels, blk_pred, val_act) = buff.sample(
                         args.ppo_discount, args.ppo_gae_lambda,
                         args.ppo_gae_lambda, blocking=args.ppo_heur_block,
                         use_valid_act=args.ppo_heur_valid_act)
                    extra_stats["action_loss"], extra_stats["value_loss"] = ppo.update(
                        observations, a_prob, a_select, adv, v, 0, h_actr,
                        h_cr, blk_labels, blk_pred, val_act, dev=device)
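                    #Detach recurrent states in place so the next update's backward
                    #pass does not propagate through the previous rollout's graph.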
                    if args.ppo_recurrent:
                        for a, c in zip(hx_cx_actr, hx_cx_cr):
                            for a2, c2 in zip(a.values(), c.values()):
                                a2[0].detach_()
                                a2[1].detach_()
                                c2[0].detach_()
                                c2[1].detach_()
                    global_iterations += 1
                    logger.log(env_id, infos, extra_stats)
                    if up_i == curr_manager.n_updates - 1:
                        logger.release_render(env_id)
                up_i += 1
        env.close()
    print("Done")
def BC_train():
    #Helper functions
    def loss_f(pred, label):
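        #Negative log-likelihood of the expert action: pred holds per-action
        #probabilities and label holds the expert action index.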
        action_label_prob = torch.gather(pred, -1, label.long())
        log_actions = -torch.log(action_label_prob)
        loss = log_actions.mean()
        return loss

    def get_validation_loss(validate_loader, ppo_policy):
        with torch.no_grad():
            valid_loss_hldr = []
            for ob, a in validate_loader:
                (a_pred, _, _, _) = ppo_policy.actors[0].forward(ob)
                valid_loss_hldr.append(loss_f(a_pred, a).item())
        return np.mean(valid_loss_hldr)

    def save(model, logger, end):
        name = "checkpoint_" + str(end)
        checkpoint_path = os.path.join(logger.checkpoint_dir, name)
        model.save(checkpoint_path)

    experiment_group_name = "BC_5x5"
    work_dir = '/home/james/Desktop/Gridworld/EXPERIMENTS/' + experiment_group_name
    plot_dir = '/home/james/Desktop/Gridworld/CENTRAL_TENSORBOARD/' + experiment_group_name  #+ "_Central"
    # work_dir = experiment_group_name
    #  plot_dir = experiment_group_name + "_Central"

    #Argument parser and experiment settings
    parser = argparse.ArgumentParser("Generate Data")

    #Environment arguments
    parser.add_argument("--map_shape", default=(5, 5), type=object)
    parser.add_argument("--n_agents", default=4, type=int)
    parser.add_argument("--env_name",
                        default='independent_navigation-v0',
                        type=str)
    parser.add_argument("--use_default_rewards", default=True, type=bool)
    parser.add_argument("--obj_density", default=0.2, type=int)
    parser.add_argument("--use_custom_rewards",
                        default=False,
                        action='store_true')
    parser.add_argument("--custom_env_ind", default=1, type=int)
    parser.add_argument("--ppo_recurrent", default=False, action='store_true')

    parser.add_argument("--alternative_plot_dir", default="none")
    parser.add_argument("--working_directory", default="none")
    parser.add_argument("--name",
                        default="NO_NAME",
                        type=str,
                        help="Experiment name")

    base_cli_args = [
        "--working_directory", work_dir, "--alternative_plot_dir", plot_dir
    ]

    DEVICE = 'gpu'
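    #'gpu' is the device token passed to the PPO helpers (prep_device / tens_to_dev) below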

    parameter_grid1 = {
        "bc_mb_size": [32],
        "bc_lr": [0.000005],
        "n_epoch": [150],
        "weight_decay": [0.0001]
    }
    grid1 = ParameterGrid(parameter_grid1)

    data_folder_path = '/home/james/Desktop/Gridworld/BC_Data/5x5'

    combine_data = False

    if combine_data:
        data, files = combine_all_data(data_folder_path)
        delete_files(files)
        train_f, val_f, test_f = split_data(data, data_folder_path, "5x5")
    else:
        train_f, val_f, test_f = get_data_files(data_folder_path)

    #Get data from files:
    train_data = torch.load(train_f)
    val_data = torch.load(val_f)
    test_data = torch.load(test_f)
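    #Each file holds a torch-saved sequence of (observation, expert action) pairs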

    for param in grid1:
        name = "BC_" + "5x5_" \
            + "mbsize_" + str(param["bc_mb_size"]) \
            + "_lr_" + str(param["bc_lr"]) \
            + "_epochs_" + str(param["n_epoch"]) \
            + "_weightdecay_" + str(param["weight_decay"]) \

        #Parse a fresh argument list each grid iteration (args becomes a Namespace)
        args = parser.parse_args(base_cli_args + ["--name", name])

        #Positional args mirror the PPO constructor call in run(): 5 actions, a
        #5x5x5 observation Box, "primal6" base policy, 1 agent, shared actor and
        #value nets, k_epochs=1, minibatch size 1, actor lr=bc_lr, value lr=0.001,
        #hidden dim 120, clip 0.2, entropy coeff 0.01, not recurrent, no blocking heuristic
        ppo = PPO(5, spaces.Box(low=0, high=1, shape=(5, 5, 5),
                                dtype=int), "primal6", 1, True, True, 1, 1,
                  param["bc_lr"], 0.001, 120, 0.2, 0.01, False, False)

        logger = Logger(args, "5x5", "none", ppo)

        #Make training data loader
        (obs, actions) = zip(*train_data)
        (obs, actions) = (np.array(obs), np.array(actions))
        ppo.prep_device(DEVICE)
        obs = ppo.tens_to_dev(DEVICE, torch.from_numpy(obs).float())
        action_labels = ppo.tens_to_dev(
            DEVICE,
            torch.from_numpy(actions).reshape((-1, 1)).float())
        train_dataset = torch.utils.data.TensorDataset(obs, action_labels)
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=param["bc_mb_size"], shuffle=True)

        #Make validation data_loader
        (obs, actions) = zip(*val_data)
        (obs, actions) = (np.array(obs), np.array(actions))
        obs = ppo.tens_to_dev(DEVICE, torch.from_numpy(obs).float())
        val_action_labels = ppo.tens_to_dev(
            DEVICE,
            torch.from_numpy(actions).reshape((-1, 1)).float())
        valid_dataset = torch.utils.data.TensorDataset(obs, val_action_labels)
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset, batch_size=param["bc_mb_size"], shuffle=True)

        # optimizer = ppo.actors[0].optimizer

        #Adam over the (shared) actor network only; the value head is not trained during BC
        optimizer = optim.Adam(ppo.actors[0].parameters(),
                               lr=param["bc_lr"],
                               weight_decay=param["weight_decay"])

        #Train:
        #print("Nr iterations: {}".format(len(train_loader) / param["bc_mb_size"]))
        #print("dataset size: {}".format(len(train_data)))
        print(name)
        for epoch in range(param["n_epoch"]):
            epoch_loss_hldr = []
            iterations = 0
            for data in train_loader:
                ob, a = data
                (a_pred, _, _, _) = ppo.actors[0].forward(ob)
                loss = loss_f(a_pred, a)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                epoch_loss_hldr.append(loss.item())
                iterations += 1
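            #Checkpoint every 2 epochs; every 10 epochs benchmark the policy with
            #both greedy and sampled (non-greedy) action selection.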
            if (epoch + 1) % 2 == 0:
                save(ppo, logger, epoch)

            if (epoch + 1) % 10 == 0:
                ppo.extend_agent_indexes(args.n_agents)
                rend_frames, all_info = benchmark_func(args,
                                                       ppo,
                                                       100,
                                                       30,
                                                       DEVICE,
                                                       greedy=True)
                logger.benchmark_info(all_info,
                                      rend_frames,
                                      epoch,
                                      end_str="_greedy")
                rend_frames, all_info = benchmark_func(args,
                                                       ppo,
                                                       100,
                                                       30,
                                                       DEVICE,
                                                       greedy=False)
                logger.benchmark_info(all_info,
                                      rend_frames,
                                      epoch,
                                      end_str="_NotGreedy")

            print("iterations: {}".format(iterations))
            epoch_loss = np.mean(epoch_loss_hldr)
            valid_loss = get_validation_loss(valid_loader, ppo)
            log_info = {
                "train_loss": epoch_loss,
                "validation_loss": valid_loss
            }
            logger.plot_tensorboard_custom_keys(log_info,
                                                external_iteration=epoch)
            print("Epoch: {}  Train Loss: {} Validation Loss: {}".format(
                epoch, epoch_loss, valid_loss))
        print("Done")

        save(ppo, logger, "end")

        #Evaluate policy (benchmark)
        ppo.extend_agent_indexes(args.n_agents)
        rend_frames, all_info = benchmark_func(args, ppo, 100, 30, DEVICE)
        logger.benchmark_info(all_info, rend_frames, param["n_epoch"] + 1)