def run_bc(args, num_trials, render_len=5):
    CHECKPOINT_PATH = "/home/james/Desktop/Gridworld/EXPERIMENTS/FINAL_COMPARISON/Checkpoint_Policies/bc/checkpoint_20"
    env_name = "independent_navigation-v8_0"
    args.env_name = env_name
    env = Independent_NavigationV8_0(args)
    obs_space = env.observation_space[-1]
    ppo = PPO(5, obs_space, "primal7", env.n_agents, True, True, 8, 512, recurrent=False)
    model_info = torch.load(CHECKPOINT_PATH)
    ppo.load(model_info)
    DEVICE = 'cpu'
    render_frames, results = benchmark_func(args, True, ppo, num_trials, render_len, DEVICE)
    return render_frames, results
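
#Hypothetical usage sketch for run_bc (assumes `args` is an argparse Namespace
#compatible with Independent_NavigationV8_0):
#   frames, results = run_bc(args, num_trials=100)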
def run(args):
    if args.ppo_use_gpu:
        device = 'gpu'
    else:
        device = 'cpu'
    buff = PPO_Buffer(args.n_agents, args.ppo_workers, args.ppo_rollout_length,
                      args.ppo_recurrent)
    env = make_parallel_env(args, args.seed, buff.nworkers)
    ppo = PPO(env.action_space[0].n, env.observation_space[0],
              args.ppo_base_policy_type, env.n_agents[0], args.ppo_share_actor,
              args.ppo_share_value, args.ppo_k_epochs, args.ppo_minibatch_size,
              args.ppo_lr_a, args.ppo_lr_v, args.ppo_hidden_dim,
              args.ppo_eps_clip, args.ppo_entropy_coeff, args.ppo_recurrent,
              args.ppo_heur_block)
    buff.init_model(ppo)
    logger = Logger(args, "No summary", "no policy summary", ppo)

    obs = env.reset()
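    #Initialise per-worker, per-agent recurrent states; (None, None) placeholders when not recurrent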
    if args.ppo_recurrent:
        hx_cx_actr = [{i: ppo.init_hx_cx(device)
                       for i in ob.keys()} for ob in obs]
        hx_cx_cr = [{i: ppo.init_hx_cx(device)
                     for i in ob.keys()} for ob in obs]
    else:
        hx_cx_actr = [{i: (None, None) for i in ob.keys()} for ob in obs]
        hx_cx_cr = [{i: (None, None) for i in ob.keys()} for ob in obs]

    stats = {}
    for it in range(args.ppo_iterations):
        print("Iteration: {}".format(it))
        render_frames = []
        render_cntr = 0
        info2 = [{"valid_act": {i: None for i in ob.keys()}} for ob in obs]
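        #Collect transitions from all parallel workers until the rollout buffer is full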
        while not buff.is_full:
            if args.ppo_heur_valid_act:
                val_act_hldr = copy.deepcopy(env.return_valid_act())
                info2 = [{"valid_act": hldr} for hldr in val_act_hldr]
            else:
                val_act_hldr = [{i: None for i in ob.keys()} for ob in obs]
            ppo.prep_device(device)
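            #One forward pass per worker: per-agent action probs, sampled actions, values, next recurrent states and blocking predictions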
            a_probs, a_select, value, hx_cx_actr_n, hx_cx_cr_n, blocking = zip(
                *[ppo.forward(ob, ha, hc, dev=device, valid_act_heur=inf2["valid_act"])
                  for ob, ha, hc, inf2 in zip(obs, hx_cx_actr, hx_cx_cr, info2)])
            a_env_dict = [{key: val.item()
                           for key, val in hldr.items()} for hldr in a_select]
            next_obs, r, dones, info = env.step(a_env_dict)
            if it % args.render_rate == 0 and render_cntr < args.render_length:
                render_frames.append(env.render(indices=[0])[0])
                if info[0]["terminate"]:
                    render_cntr += 1
            next_obs_ = get_n_obs(next_obs, info)
            buff.add(obs, r, value, next_obs_, a_probs, a_select, info, dones,
                     hx_cx_actr, hx_cx_cr, hx_cx_actr_n, hx_cx_cr_n, blocking,
                     val_act_hldr)
            hx_cx_actr, hx_cx_cr = hx_cx_actr_n, hx_cx_cr_n
            #Reset hidden and cell states when episode done
            for i, inf in enumerate(info):
                if inf["terminate"] and args.ppo_recurrent:
                    hx_cx_actr, hx_cx_cr = list(hx_cx_actr), list(hx_cx_cr)
                    hx_cx_actr[i] = {
                        i2: ppo.init_hx_cx(device)
                        for i2 in hx_cx_actr[i].keys()
                    }
                    hx_cx_cr[i] = {
                        i2: ppo.init_hx_cx(device)
                        for i2 in hx_cx_cr[i].keys()
                    }

            obs = next_obs

        if buff.is_full:
            #Compute GAE advantages over the rollout, then run the clipped PPO update
            observations, a_prob, a_select, adv, v, infos, h_actr, h_cr, \
                blk_labels, blk_pred, val_act = buff.sample(
                    args.ppo_discount, args.ppo_gae_lambda,
                    blocking=args.ppo_heur_block,
                    use_valid_act=args.ppo_heur_valid_act)
            stats["action_loss"], stats["value_loss"] = ppo.update(
                observations, a_prob, a_select, adv, v, 0, h_actr, h_cr,
                blk_labels, blk_pred, val_act, dev=device)

            if args.ppo_recurrent:
                #Detach recurrent states so gradients do not flow across update boundaries (truncated BPTT)
                for a, c in zip(hx_cx_actr, hx_cx_cr):
                    for a2, c2 in zip(a.values(), c.values()):
                        a2[0].detach_()
                        a2[1].detach_()
                        c2[0].detach_()
                        c2[1].detach_()
        if (it + 1) % args.benchmark_frequency == 0:
            rend_f, ter_i = benchmark_func(args, ppo,
                                           args.benchmark_num_episodes,
                                           args.benchmark_render_length,
                                           device)
            logger.benchmark_info(ter_i, rend_f, it)

        #Logging:
        stats["iterations"] = it
        stats["num_timesteps"] = len(infos)
        terminal_t_info = [inf for inf in infos if inf["terminate"]]
        stats["num_episodes"] = len(terminal_t_info)
        logger.log(stats, terminal_t_info, render_frames, checkpoint=True)
        render_frames = []
        render_cntr = 0

    rend_f, ter_i = benchmark_func(args, ppo, args.benchmark_num_episodes,
                                   args.benchmark_render_length, device)
    logger.benchmark_info(ter_i, rend_f, it)
def run(args):  #Curriculum train:
    if args.ppo_use_gpu:
        device = 'gpu'
    else:
        device = 'cpu'

    assert 0.0 <= args.ppo_bc_iteration_prob <= 1.0

    curr_manager = CurriculumManager(args, CurriculumLogger)
    #Get env args for first env
    env_args = curr_manager.init_env_args()
    #Determine number of workers for buffer and init env
    buff = PPO_Buffer(env_args.n_agents, args.ppo_workers,
                      args.ppo_rollout_length, args.ppo_recurrent)
    seed = np.random.randint(0, 100000)
    env = make_parallel_env(env_args, seed, buff.nworkers)
    #Init ppo model
    ppo = PPO(env.action_space[0].n, env.observation_space[0],
              args.ppo_base_policy_type, env.n_agents[0], args.ppo_share_actor,
              args.ppo_share_value, args.ppo_k_epochs, args.ppo_minibatch_size,
              args.ppo_lr_a, args.ppo_lr_v, args.ppo_hidden_dim,
              args.ppo_eps_clip, args.ppo_entropy_coeff, args.ppo_recurrent,
              args.ppo_heur_block)
    #Add model to buffer
    buff.init_model(ppo)

    logger = curr_manager.init_logger(ppo, benchmark_func)
    global_iterations = 0
    env.close()
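    #Curriculum loop: sample an env config, rebuild the buffer/env, then alternate BC and PPO updates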
    while not curr_manager.is_done:
        env_args = curr_manager.sample_env()
        buff.__init__(env_args.n_agents, args.ppo_workers,
                      args.ppo_rollout_length,
                      args.ppo_recurrent)  #recalculates nworkers
        ppo.extend_agent_indexes(env_args.n_agents)
        buff.init_model(ppo)
        seed = np.random.randint(0, 10000)
        env = make_parallel_env(env_args, seed, buff.nworkers)

        obs = env.reset()
        if args.ppo_recurrent:
            hx_cx_actr = [{i: ppo.init_hx_cx(device)
                           for i in ob.keys()} for ob in obs]
            hx_cx_cr = [{i: ppo.init_hx_cx(device)
                         for i in ob.keys()} for ob in obs]
        else:
            hx_cx_actr = [{i: (None, None) for i in ob.keys()} for ob in obs]
            hx_cx_cr = [{i: (None, None) for i in ob.keys()} for ob in obs]

        extra_stats = {}
        env_id = curr_manager.curr_env_id
        up_i = 0
        #for up_i in range(curr_manager.n_updates):
        while up_i < curr_manager.n_updates:
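            #With probability ppo_bc_iteration_prob, run a behaviour-cloning update instead of collecting a PPO rollout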
            bc_iteration = np.random.choice(
                [True, False],
                p=[args.ppo_bc_iteration_prob, 1 - args.ppo_bc_iteration_prob])
            if bc_iteration:
                n_sub_updates = (
                    (args.ppo_workers * args.ppo_rollout_length) //
                    args.ppo_minibatch_size)
                bc_training_iteration2(args, env_args, ppo, device, 64,
                                       n_sub_updates * 4)
            else:
                print("Iteration: {}".format(global_iterations))
                info2 = [{
                    "valid_act": {i: None
                                  for i in ob.keys()}
                } for ob in obs]
                while not buff.is_full:
                    if args.ppo_heur_valid_act:
                        val_act_hldr = copy.deepcopy(env.return_valid_act())
                        info2 = [{"valid_act": hldr} for hldr in val_act_hldr]
                    else:
                        val_act_hldr = [{i: None
                                         for i in ob.keys()} for ob in obs]

                    a_probs, a_select, value, hx_cx_actr_n, hx_cx_cr_n, blocking = zip(
                        *[ppo.forward(ob, ha, hc, dev=device, valid_act_heur=inf2["valid_act"])
                          for ob, ha, hc, inf2 in zip(obs, hx_cx_actr, hx_cx_cr, info2)])

                    a_env_dict = [{
                        key: val.item()
                        for key, val in hldr.items()
                    } for hldr in a_select]
                    next_obs, r, dones, info = env.step(a_env_dict)
                    logger.record_render(env_id, env, info[0])
                    next_obs_ = get_n_obs(next_obs, info)

                    buff.add(obs, r, value, next_obs_, a_probs, a_select, info,
                             dones, hx_cx_actr, hx_cx_cr, hx_cx_actr_n,
                             hx_cx_cr_n, blocking, val_act_hldr)

                    hx_cx_actr = [dict(hldr) for hldr in hx_cx_actr_n]
                    hx_cx_cr = [dict(hldr) for hldr in hx_cx_cr_n]
                    #Reset hidden and cell states when episode done
                    for i, inf in enumerate(info):
                        if inf["terminate"] and args.ppo_recurrent:
                            hx_cx_actr, hx_cx_cr = list(hx_cx_actr), list(
                                hx_cx_cr)
                            hx_cx_actr[i] = {
                                i2: ppo.init_hx_cx(device)
                                for i2 in hx_cx_actr[i].keys()
                            }
                            hx_cx_cr[i] = {
                                i2: ppo.init_hx_cx(device)
                                for i2 in hx_cx_cr[i].keys()
                            }
                    obs = next_obs

                if buff.is_full:
                    observations, a_prob, a_select, adv, v, infos, h_actr, h_cr, \
                        blk_labels, blk_pred, val_act = buff.sample(
                            args.ppo_discount, args.ppo_gae_lambda,
                            blocking=args.ppo_heur_block,
                            use_valid_act=args.ppo_heur_valid_act)
                    extra_stats["action_loss"], extra_stats["value_loss"] = ppo.update(
                        observations, a_prob, a_select, adv, v, 0, h_actr, h_cr,
                        blk_labels, blk_pred, val_act, dev=device)
                    if args.ppo_recurrent:
                        for a, c in zip(hx_cx_actr, hx_cx_cr):
                            for a2, c2 in zip(a.values(), c.values()):
                                a2[0].detach_()
                                a2[1].detach_()
                                c2[0].detach_()
                                c2[1].detach_()
                    global_iterations += 1
                    logger.log(env_id, infos, extra_stats)
                    if up_i == curr_manager.n_updates - 1:
                        logger.release_render(env_id)
                up_i += 1
        env.close()
    print("Done")
def train_PO_FOV_data(custom_args=None):
    '''Check whether the data has been processed into train, validation and test sets.
        If not, make a copy of the data and process it into the three sets.
        Then start training with the given parameters. '''

    #Helper functions
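    #loss_f below: negative log-likelihood of the expert action under the predicted action distribution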
    def loss_f(pred, label):
        action_label_prob = torch.gather(pred, -1, label.long())
        log_actions = -torch.log(action_label_prob)
        loss = log_actions.mean()
        return loss

    def get_validation_loss(validate_loader, ppo_policy):
        with torch.no_grad():
            valid_loss_hldr = []
            for data in validate_loader:
                if len(data) == 2:
                    ob, a = data
                elif len(data) == 3:
                    ob = (data[0], data[1])
                    a = data[-1]
                else:
                    raise Exception("Data incorrect length")
                (a_pred, _, _, _) = ppo_policy.actors[0].forward(ob)
                valid_loss_hldr.append(loss_f(F.softmax(a_pred, dim=-1), a).item())
        return np.mean(valid_loss_hldr)

    def save(model, logger, end):
        name = "checkpoint_" + str(end)
        #checkpoint_path = os.path.join(logger.checkpoint_dir, name)
        #model.save(checkpoint_path)
        logger.make_checkpoint(False, str(end))

    def is_processed(path):
        '''Return whether the folder contains all
            three files: train, validation and test. '''
        files = get_file_names_in_fldr(path)
        file_markers = ["train", "test", "validation"]
        is_processed_flag = True
        for m in file_markers:
            sub_flag = False
            for f in files:
                if m in f:
                    sub_flag = True
            if not sub_flag:
                is_processed_flag = False
                break
        return is_processed_flag

    parser = argparse.ArgumentParser("Train arguments")
    parser.add_argument("--folder_name", type=str)
    parser.add_argument("--mb_size", default=32, type=int)
    parser.add_argument("--lr", default=5e-5, type=float)
    parser.add_argument("--n_epoch", default=50, type=int)
    parser.add_argument("--weight_decay", default=0.0001, type=float)
    parser.add_argument("--env_name", default='none', type=str)
    parser.add_argument("--alternative_plot_dir", default="none")
    parser.add_argument("--working_directory", default="none")
    parser.add_argument("--name",
                        default="NO_NAME",
                        type=str,
                        help="Experiment name")
    parser.add_argument("--replace_checkpoints", default=True, type=bool)
    #Placeholders:
    parser.add_argument("--n_agents", default=1, type=int)
    parser.add_argument("--map_shape", default=(5, 5), type=object)
    parser.add_argument("--obj_density", default=0.0, type=float)
    parser.add_argument("--view_d", default=3, type=int)
    parser.add_argument("--use_default_rewards", default=True, type=bool)
    parser.add_argument("--use_custom_rewards", default=False, type=bool)
    parser.add_argument("--base_path", default="none", type=str)

    #Best previous performing hyp param is: mb:32 lr:5e-5  weightdecay: 0.0001

    if custom_args is None:
        args = parser.parse_args()
    else:
        args, unkn = parser.parse_known_args(custom_args)

    experiment_group_name = args.folder_name  #"BC_5x5"
    if args.working_directory == "none":
        import __main__
        work_dir = os.path.join(os.path.dirname(__main__.__file__),
                                'EXPERIMENTS', experiment_group_name)
    else:
        work_dir = '/home/james/Desktop/Gridworld/EXPERIMENTS/' + experiment_group_name

    if args.alternative_plot_dir == "none":
        import __main__
        plot_dir = os.path.join(os.path.dirname(__main__.__file__),
                                'CENTRAL_TENSORBOARD', experiment_group_name)
    else:
        plot_dir = '/home/james/Desktop/Gridworld/CENTRAL_TENSORBOARD/' + experiment_group_name

    work_dir = "/home/jellis/workspace2/gridworld/EXPERIMENTS/" + args.folder_name
    plot_dir = "/home/jellis/workspace2/gridoworld/CENTRAL_TENSORBOARD/" + args.folder_name
    args.working_directory = work_dir
    args.alternative_plot_dir = plot_dir

    #BASE_DATA_FOLDER_PATH = '/home/james/Desktop/Gridworld/BC_Data'

    if args.base_path == "none":
        import __main__
        BASE_DATA_FOLDER_PATH = os.path.dirname(__main__.__file__)
        BASE_DATA_FOLDER_PATH = os.path.join(BASE_DATA_FOLDER_PATH, "BC_Data")
    else:
        #base_path = args.base_path
        BASE_DATA_FOLDER_PATH = '/home/james/Desktop/Gridworld/BC_Data'

    data_fldr_path = os.path.join(BASE_DATA_FOLDER_PATH, args.folder_name)
    #NOTE: this hard-coded path overrides the path computed above
    data_fldr_path = "/home/jellis/workspace2/BC_Data/"
    if is_processed(data_fldr_path):
        train_f, val_f, test_f = get_data_files(data_fldr_path)
    else:
        #Copy data:
        to_dir = data_fldr_path + "_cpy"
        os.makedirs(to_dir)
        copy_tree(data_fldr_path, to_dir)
        #Split and save data:
        data, files = combine_all_data(data_fldr_path)
        delete_files(files)
        train_f, val_f, test_f = split_data(data, data_fldr_path, "data_")

    DEVICE = 'gpu'
    #Train on data: keep best policy
    #Get data from files:
    train_data = torch.load(train_f)
    val_data = torch.load(val_f)
    test_data = torch.load(test_f)

    env_hldr = make_env(args)
    observation_space = env_hldr.observation_space[-1]

    name = "BC_" + args.folder_name \
            + "mbsize_" + str(args.mb_size) \
            + "_lr_" + str(args.lr) \
            + "_epochs_" + str(args.n_epoch) \
            + "_weightdecay_" + str(args.weight_decay)

    #args.extend(["--name", name])
    #args = parser.parse_args(args)
    args.name = name

    ppo = PPO(5, observation_space, "primal7", 1, True, True, 1, 1, args.lr,
              0.001, 120, 0.2, 0.01, False, False)

    logger = Logger(args, env_hldr.summary(), "none", ppo)

    #Make training data loader
    (obs, actions) = zip(*train_data)
    #(obs, actions) = (np.array(obs), np.array(actions))
    if isinstance(observation_space, tuple):
        (obs1, obs2) = zip(*obs)
        (obs, actions) = ((np.array(obs1), np.array(obs2)), np.array(actions))
    else:
        (obs, actions) = (np.array(obs), np.array(actions))
    ppo.prep_device(DEVICE)

    if isinstance(observation_space, tuple):
        obs = (ppo.tens_to_dev(DEVICE, torch.from_numpy(obs[0]).float()), \
            ppo.tens_to_dev(DEVICE, torch.from_numpy(obs[1]).float()))
    else:
        obs = [ppo.tens_to_dev(DEVICE, torch.from_numpy(obs).float())]

    #obs = ppo.tens_to_dev(DEVICE, torch.from_numpy(obs).float())
    action_labels = ppo.tens_to_dev(
        DEVICE,
        torch.from_numpy(actions).reshape((-1, 1)).float())
    train_dataset = torch.utils.data.TensorDataset(*obs, action_labels)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.mb_size,
                                               shuffle=True)

    #Make validation data_loader
    (obs, actions) = zip(*val_data)
    #(obs, actions) = (np.array(obs), np.array(actions))
    if isinstance(observation_space, tuple):
        (obs1, obs2) = zip(*obs)
        (obs, actions) = ((np.array(obs1), np.array(obs2)), np.array(actions))
    else:
        (obs, actions) = (np.array(obs), np.array(actions))
    #obs = ppo.tens_to_dev(DEVICE, torch.from_numpy(obs).float())
    if isinstance(observation_space, tuple):
        obs = (ppo.tens_to_dev(DEVICE, torch.from_numpy(obs[0]).float()), \
            ppo.tens_to_dev(DEVICE, torch.from_numpy(obs[1]).float()))
    else:
        obs = [ppo.tens_to_dev(DEVICE, torch.from_numpy(obs).float())]

    val_action_labels = ppo.tens_to_dev(
        DEVICE,
        torch.from_numpy(actions).reshape((-1, 1)).float())
    valid_dataset = torch.utils.data.TensorDataset(*obs, val_action_labels)
    valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                               batch_size=args.mb_size,
                                               shuffle=True)

    #Make test data_loader
    (obs, actions) = zip(*test_data)
    #(obs, actions) = (np.array(obs), np.array(actions))
    if isinstance(observation_space, tuple):
        (obs1, obs2) = zip(*obs)
        (obs, actions) = ((np.array(obs1), np.array(obs2)), np.array(actions))
    else:
        (obs, actions) = (np.array(obs), np.array(actions))
    #obs = ppo.tens_to_dev(DEVICE, torch.from_numpy(obs).float())
    if isinstance(observation_space, tuple):
        obs = (ppo.tens_to_dev(DEVICE, torch.from_numpy(obs[0]).float()), \
            ppo.tens_to_dev(DEVICE, torch.from_numpy(obs[1]).float()))
    else:
        obs = [ppo.tens_to_dev(DEVICE, torch.from_numpy(obs).float())]

    test_action_labels = ppo.tens_to_dev(
        DEVICE,
        torch.from_numpy(actions).reshape((-1, 1)).float())
    test_dataset = torch.utils.data.TensorDataset(*obs, test_action_labels)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=args.mb_size,
                                              shuffle=True)

    optimizer = optim.Adam(ppo.actors[0].parameters(),
                           lr=args.lr,
                           weight_decay=args.weight_decay)

    #Train:
    prev_val_loss = 1e6
    best_epoch_nr = None
    val_loss_is_greater_cntr = 0
    val_loss_is_greater_threshold = 8
    best_policy = None
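    #Early stopping: halt when validation loss fails to improve for val_loss_is_greater_threshold consecutive epochs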
    print(name)
    for epoch in range(args.n_epoch):
        epoch_loss_hldr = []
        iterations = 0
        for data in train_loader:
            #ob, a = data
            if len(data) == 2:
                ob, a = data
            elif len(data) == 3:
                ob = (data[0], data[1])
                a = data[-1]
            else:
                raise Exception("Data incorrect length")
            (a_pred, _, _, _) = ppo.actors[0].forward(ob)
            loss = loss_f(F.softmax(a_pred, dim=-1), a)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss_hldr.append(loss.item())
            iterations += 1

        print("iterations: {}".format(iterations))
        epoch_loss = np.mean(epoch_loss_hldr)
        valid_loss = get_validation_loss(valid_loader, ppo)
        test_loss = get_validation_loss(test_loader, ppo)
        log_info = {
            "train_loss": epoch_loss,
            "validation_loss": valid_loss,
            "test_loss": test_loss
        }

        logger.plot_tensorboard_custom_keys(log_info, external_iteration=epoch)
        if valid_loss < prev_val_loss:
            save(ppo, logger, epoch)
            best_policy = copy.deepcopy(ppo)
            prev_val_loss = valid_loss
            val_loss_is_greater_cntr = 0
            best_epoch_nr = epoch
        else:
            val_loss_is_greater_cntr += 1

        print("Epoch: {}  Train Loss: {} Validation Loss: {}".format(
            epoch, epoch_loss, valid_loss))
        if val_loss_is_greater_cntr > val_loss_is_greater_threshold:
            print("Ending training")
            break
    print("Done")

    # free up memory:
    try:
        del train_data
        del val_data
        del test_data
        del action_labels
        del train_dataset
        del train_loader
        del val_action_labels
        del valid_dataset
        del valid_loader
    except NameError:
        pass

    assert best_policy is not None
    print("Best epoch nr is {}".format(best_epoch_nr))

    # 32 x 32
    variable_args_dict = {
        "n_agents": [4, 10, 30, 35, 40, 45, 50, 60, 70],
        "obj_density": [0.0, 0.1, 0.2, 0.3],
        "map_shape": [(32, 32)]
    }
    evaluate_across_evs(best_policy,
                        logger,
                        args,
                        variable_args_dict,
                        1000,
                        30,
                        DEVICE,
                        greedy=False)

    # 40 x 40
    variable_args_dict = {
        "n_agents": [4, 10, 30, 35, 40, 45, 50, 60, 70],
        "obj_density": [0.0, 0.1, 0.2, 0.3],
        "map_shape": [(40, 40)]
    }
    evaluate_across_evs(best_policy,
                        logger,
                        args,
                        variable_args_dict,
                        1000,
                        30,
                        DEVICE,
                        greedy=False)
def evaluate_checkpoint2():

    #CHECKPOINT_PATH = "/home/james/Desktop/Gridworld/EXPERIMENTS/BC1AgentShortestPath_2_V0/BC_BC1AgentShortestPath_2_V0mbsize_32_lr_5e-05_epochs_50_weightdecay_0.0001_N0/checkpoint/checkpoint_40"
    #experiment_group_name = "Results_Shortest Path"

    #CHECKPOINT_PATH = "/home/james/Desktop/Gridworld/EXPERIMENTS/BC1AgentDirVec_2_V1/BC_BC1AgentDirVec_2_V1mbsize_32_lr_5e-05_epochs_50_weightdecay_0.0001_N0/checkpoint/checkpoint_31"
    #CHECKPOINT_PATH = "/home/james/Desktop/Gridworld/EXPERIMENTS/5_2_0_CL/ppo_arc_primal7_sr_-0.1_ocr_-0.4_acr_-0.4_grr_0.3_fer_2.0_viewd_3_disc_0.5_lambda_1.0_entropy_0.01_minibatch_512_rollouts_256_workers_4_kepochs_8_curr_ppo_cl_inc_size_seed_1_N4/checkpoint/checkpoint_17600"
    CHECKPOINT_PATH = "/home/james/Desktop/Gridworld/EXPERIMENTS/5_3_0_CL/ppo_arc_primal7_sr_-0.1_ocr_-0.4_acr_-0.4_grr_0.3_fer_2.0_viewd_3_disc_0.5_lambda_1.0_entropy_0.01_minibatch_512_rollouts_256_workers_4_kepochs_8_curr_ppo_cl_inc_size_seed_1_N0/checkpoint/checkpoint_5200"
    #CHECKPOINT_PATH = "/home/james/Desktop/Gridworld/EXPERIMENTS/9_0_0/BC_9_0_0mbsize_32_lr_5e-05_epochs_50_weightdecay_1e-05_N0/checkpoint/checkpoint_20"
    experiment_group_name = "5_3_0_CL_benchmark3"
    work_dir = '/home/james/Desktop/Gridworld/EXPERIMENTS/' + experiment_group_name
    plot_dir = '/home/james/Desktop/Gridworld/CENTRAL_TENSORBOARD/' + experiment_group_name

    parser = argparse.ArgumentParser("Train arguments")
    parser.add_argument("--alternative_plot_dir", default="none")
    parser.add_argument("--working_directory", default="none")
    parser.add_argument("--name",
                        default="benchmark3",
                        type=str,
                        help="Experiment name")
    parser.add_argument("--replace_checkpoints", default=True, type=bool)
    #Placeholders:
    parser.add_argument("--env_name",
                        default="independent_navigation-v0",
                        type=str)
    parser.add_argument("--n_agents", default=1, type=int)
    parser.add_argument("--map_shape", default=(7, 7), type=object)
    parser.add_argument("--obj_density", default=0.0, type=float)
    parser.add_argument("--view_d", default=3, type=int)
    parser.add_argument("--use_default_rewards", default=True, type=bool)
    parser.add_argument("--use_custom_rewards", default=False, type=bool)
    parser.add_argument("--base_path", default="none", type=str)
    args = parser.parse_args()

    args.working_directory = work_dir
    args.alternative_plot_dir = plot_dir
    ##########
    args = make_env_args(args, {})
    env_hldr = make_env(args)
    observation_space = env_hldr.observation_space[-1]

    ppo = PPO(5, observation_space, "primal7", 1, True, True, 1, 1, 0.001,
              0.001, 120, 0.2, 0.01, False, False)

    ppo.load(torch.load(CHECKPOINT_PATH))
    logger = Logger(args, "NONE", "none", ppo)

    # variable_args_dict = {
    #     "n_agents": [10,20,30,40],
    #     "obj_density": [0.0,0.1, 0.2, 0.3],
    #     "map_shape": [(32, 32)]
    # }

    # evaluate_across_evs(ppo, logger, args, variable_args_dict, 200, 10, 'gpu', greedy=False)

    variable_args_dict = {
        "n_agents": [4],
        "obj_density": [0.0, 0.1, 0.2, 0.3],
        "map_shape": [(7, 7)]
    }
    evaluate_across_evs(ppo,
                        logger,
                        args,
                        variable_args_dict,
                        500,
                        30,
                        'gpu',
                        greedy=False)
def evaluate_checkpoint():

    #CHECKPOINT_PATH = "/home/james/Desktop/Gridworld/EXPERIMENTS/BC1AgentShortestPath_2_V0/BC_BC1AgentShortestPath_2_V0mbsize_32_lr_5e-05_epochs_50_weightdecay_0.0001_N0/checkpoint/checkpoint_40"
    #experiment_group_name = "Results_Shortest Path"

    #CHECKPOINT_PATH = "/home/james/Desktop/Gridworld/EXPERIMENTS/BC1AgentDirVec_2_V1/BC_BC1AgentDirVec_2_V1mbsize_32_lr_5e-05_epochs_50_weightdecay_0.0001_N0/checkpoint/checkpoint_31"
    CHECKPOINT_PATH = "/home/james/Desktop/Gridworld/EXPERIMENTS/9_0_0/BC_9_0_0mbsize_32_lr_5e-05_epochs_50_weightdecay_0.0001_N0/checkpoint/checkpoint_31"
    experiment_group_name = "9_0_0"
    work_dir = '/home/james/Desktop/Gridworld/EXPERIMENTS/' + experiment_group_name
    plot_dir = '/home/james/Desktop/Gridworld/CENTRAL_TENSORBOARD/' + experiment_group_name

    parser = argparse.ArgumentParser("Train arguments")
    parser.add_argument("--alternative_plot_dir", default="none")
    parser.add_argument("--working_directory", default="none")
    parser.add_argument("--name",
                        default="9_0_0_benchmark",
                        type=str,
                        help="Experiment name")
    parser.add_argument("--replace_checkpoints", default=True, type=bool)
    #Placeholders:
    parser.add_argument("--env_name",
                        default="independent_navigation-v8_0",
                        type=str)
    parser.add_argument("--n_agents", default=1, type=int)
    parser.add_argument("--map_shape", default=(5, 5), type=object)
    parser.add_argument("--obj_density", default=0.0, type=float)
    parser.add_argument("--view_d", default=3, type=int)
    parser.add_argument("--use_default_rewards", default=True, type=bool)
    parser.add_argument("--use_custom_rewards", default=False, type=bool)
    parser.add_argument("--base_path", default="none", type=str)
    args = parser.parse_args()

    args.working_directory = work_dir
    args.alternative_plot_dir = plot_dir
    ##########
    args = make_env_args(args, {})
    env_hldr = make_env(args)
    observation_space = env_hldr.observation_space[-1]

    ppo = PPO(5, observation_space, "primal7", 1, True, True, 1, 1, 0.001,
              0.001, 120, 0.2, 0.01, False, False)

    ppo.load(torch.load(CHECKPOINT_PATH))
    logger = Logger(args, "NONE", "none", ppo)

    # 32 x 32
    variable_args_dict = {
        "n_agents": [10, 30, 35, 40, 45, 50, 60, 70],
        "obj_density": [0.0, 0.1, 0.2, 0.3],
        "map_shape": [(32, 32)]
    }
    evaluate_across_evs(ppo,
                        logger,
                        args,
                        variable_args_dict,
                        1000,
                        30,
                        'gpu',
                        greedy=False)

    # 40 x 40
    variable_args_dict = {
        "n_agents": [10, 30, 35, 40, 45, 50, 60, 70],
        "obj_density": [0.0, 0.1, 0.2, 0.3],
        "map_shape": [(40, 40)]
    }
    evaluate_across_evs(ppo,
                        logger,
                        args,
                        variable_args_dict,
                        1000,
                        30,
                        'gpu',
                        greedy=False)

    # 50 x 50
    variable_args_dict = {
        "n_agents": [10, 30, 35, 40, 45, 50, 60, 70],
        "obj_density": [0.0, 0.1, 0.2, 0.3],
        "map_shape": [(50, 50)]
    }
    evaluate_across_evs(ppo,
                        logger,
                        args,
                        variable_args_dict,
                        1000,
                        30,
                        'gpu',
                        greedy=False)
def BC_train():
    #Helper functions
    def loss_f(pred, label):
        action_label_prob = torch.gather(pred, -1, label.long())
        log_actions = -torch.log(action_label_prob)
        loss = log_actions.mean()
        return loss

    def get_validation_loss(validate_loader, ppo_policy):
        with torch.no_grad():
            valid_loss_hldr = []
            for ob, a in validate_loader:
                (a_pred, _, _, _) = ppo_policy.actors[0].forward(ob)
                valid_loss_hldr.append(loss_f(a_pred, a).item())
        return np.mean(valid_loss_hldr)

    def save(model, logger, end):
        name = "checkpoint_" + str(end)
        checkpoint_path = os.path.join(logger.checkpoint_dir, name)
        model.save(checkpoint_path)

    experiment_group_name = "BC_5x5"
    work_dir = '/home/james/Desktop/Gridworld/EXPERIMENTS/' + experiment_group_name
    plot_dir = '/home/james/Desktop/Gridworld/CENTRAL_TENSORBOARD/' + experiment_group_name  #+ "_Central"
    # work_dir = experiment_group_name
    #  plot_dir = experiment_group_name + "_Central"

    #
    parser = argparse.ArgumentParser("Generate Data")

    #
    parser.add_argument("--map_shape", default=(5, 5), type=object)
    parser.add_argument("--n_agents", default=4, type=int)
    parser.add_argument("--env_name",
                        default='independent_navigation-v0',
                        type=str)
    parser.add_argument("--use_default_rewards", default=True, type=bool)
    parser.add_argument("--obj_density", default=0.2, type=int)
    parser.add_argument("--use_custom_rewards",
                        default=False,
                        action='store_true')
    parser.add_argument("--custom_env_ind", default=1, type=int)
    parser.add_argument("--ppo_recurrent", default=False, action='store_true')

    parser.add_argument("--alternative_plot_dir", default="none")
    parser.add_argument("--working_directory", default="none")
    parser.add_argument("--name",
                        default="NO_NAME",
                        type=str,
                        help="Experiment name")

    base_args = [
        "--working_directory", work_dir, "--alternative_plot_dir", plot_dir
    ]

    DEVICE = 'gpu'

    parameter_grid1 = {  #Previous runs: 1) mbsize_32_lr_0.0001_epochs_100_weightdecay_0.0001  2) mbsize_32_lr_5e-05_epochs_40_weightdecay_0.0001
        "bc_mb_size": [32],
        "bc_lr": [0.000005],  #0.0001
        "n_epoch": [150],
        "weight_decay": [0.0001]
    }
    grid1 = ParameterGrid(parameter_grid1)
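    #Each grid combination trains a fresh BC policy from scratch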

    data_folder_path = '/home/james/Desktop/Gridworld/BC_Data/5x5'

    combine_data = False

    if combine_data:
        data, files = combine_all_data(data_folder_path)
        delete_files(files)
        train_f, val_f, test_f = split_data(data, data_folder_path, "5x5")
    else:
        train_f, val_f, test_f = get_data_files(data_folder_path)

    #Get data from files:
    train_data = torch.load(train_f)
    val_data = torch.load(val_f)
    test_data = torch.load(test_f)

    for param in grid1:
        name = "BC_" + "5x5_" \
            + "mbsize_" + str(param["bc_mb_size"]) \
            + "_lr_" + str(param["bc_lr"]) \
            + "_epochs_" + str(param["n_epoch"]) \
            + "_weightdecay_" + str(param["weight_decay"]) \

        args.extend(["--name", name])
        args = parser.parse_args(args)

        ppo = PPO(5, spaces.Box(low=0, high=1, shape=(5, 5, 5),
                                dtype=int), "primal6", 1, True, True, 1, 1,
                  param["bc_lr"], 0.001, 120, 0.2, 0.01, False, False)

        logger = Logger(args, "5x5", "none", ppo)

        #Make training data loader
        (obs, actions) = zip(*train_data)
        (obs, actions) = (np.array(obs), np.array(actions))
        ppo.prep_device(DEVICE)
        obs = ppo.tens_to_dev(DEVICE, torch.from_numpy(obs).float())
        action_labels = ppo.tens_to_dev(
            DEVICE,
            torch.from_numpy(actions).reshape((-1, 1)).float())
        train_dataset = torch.utils.data.TensorDataset(obs, action_labels)
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=param["bc_mb_size"], shuffle=True)

        #Make validation data_loader
        (obs, actions) = zip(*val_data)
        (obs, actions) = (np.array(obs), np.array(actions))
        obs = ppo.tens_to_dev(DEVICE, torch.from_numpy(obs).float())
        val_action_labels = ppo.tens_to_dev(
            DEVICE,
            torch.from_numpy(actions).reshape((-1, 1)).float())
        valid_dataset = torch.utils.data.TensorDataset(obs, val_action_labels)
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset, batch_size=param["bc_mb_size"], shuffle=True)

        # optimizer = ppo.actors[0].optimizer

        optimizer = optim.Adam(ppo.actors[0].parameters(),
                               lr=param["bc_lr"],
                               weight_decay=param["weight_decay"])

        #Train:
        #print("Nr iterations: {}".format(len(train_loader) / param["bc_mb_size"]))
        #print("dataset size: {}".format(len(train_data)))
        print(name)
        for epoch in range(param["n_epoch"]):
            epoch_loss_hldr = []
            iterations = 0
            for data in train_loader:
                ob, a = data
                (a_pred, _, _, _) = ppo.actors[0].forward(ob)
                loss = loss_f(a_pred, a)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                epoch_loss_hldr.append(loss.item())
                iterations += 1
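            #Checkpoint every 2 epochs; benchmark greedy and sampling variants every 10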
            if (epoch + 1) % 2 == 0:
                save(ppo, logger, epoch)

            if (epoch + 1) % 10 == 0:
                ppo.extend_agent_indexes(args.n_agents)
                rend_frames, all_info = benchmark_func(args,
                                                       ppo,
                                                       100,
                                                       30,
                                                       DEVICE,
                                                       greedy=True)
                logger.benchmark_info(all_info,
                                      rend_frames,
                                      epoch,
                                      end_str="_greedy")
                rend_frames, all_info = benchmark_func(args,
                                                       ppo,
                                                       100,
                                                       30,
                                                       DEVICE,
                                                       greedy=False)
                logger.benchmark_info(all_info,
                                      rend_frames,
                                      epoch,
                                      end_str="_NotGreedy")

            print("iterations: {}".format(iterations))
            epoch_loss = np.mean(epoch_loss_hldr)
            valid_loss = get_validation_loss(valid_loader, ppo)
            log_info = {
                "train_loss": epoch_loss,
                "validation_loss": valid_loss
            }
            logger.plot_tensorboard_custom_keys(log_info,
                                                external_iteration=epoch)
            print("Epoch: {}  Train Loss: {} Validation Loss: {}".format(
                epoch, epoch_loss, valid_loss))
        print("Done")

        save(ppo, logger, "end")

        #Evaluate policy (benchmark)
        ppo.extend_agent_indexes(args.n_agents)
        rend_frames, all_info = benchmark_func(args, ppo, 100, 30, DEVICE)
        logger.benchmark_info(all_info, rend_frames, param["n_epoch"] + 1)