Example no. 1
def generate_data(name = "independent_navigation-v0", custom_args = None):
    
    #Environment:
    parser = argparse.ArgumentParser("Generate Data")
    parser.add_argument("--file_number", default = 0, type=int)
    parser.add_argument("--map_shape", default = 5, type=int)
    parser.add_argument("--n_agents", default = 4, type=int)
    parser.add_argument("--env_name", default = name, type= str)
    parser.add_argument("--use_default_rewards", default=True, type=bool)
    parser.add_argument("--view_d", default = 3, type=int)
    parser.add_argument("--obj_density", default = 0.2, type=float)
    parser.add_argument("--use_custom_rewards", default = False, action='store_true')
    parser.add_argument("--custom_env_ind", default= 1, type=int)
    parser.add_argument("--n_episodes", default= -1, type=int)
    parser.add_argument("--folder_name", default= "none", type=str)
    parser.add_argument("--data_name", default= "none", type=str)
    parser.add_argument("--base_path", default= "none", type=str)

    if custom_args is None:
        args = parser.parse_args()
    else:
        args = parser.parse_args(custom_args)

    args.map_shape = (args.map_shape, args.map_shape)

    if args.n_episodes <= 1:
        EPISODES = 5000
    else:
        EPISODES = args.n_episodes
   
    if args.base_path == "none":
        import __main__
        base_path = os.path.dirname(__main__.__file__)
        base_path = os.path.join(base_path, "BC_Data")
    else:
        base_path = args.base_path

    data_folder = args.folder_name #'none'
    if args.data_name == "none":
        data_name = "none_" + str(args.file_number) + ".pt"
    else:
        data_name = args.data_name
    
    data_folder_dir = os.path.join(base_path, data_folder) #base_path + '/' + data_folder + '/'
    if not os.path.exists(data_folder_dir):
        os.makedirs(data_folder_dir)

    all_data = generate(make_env(args), EPISODES, args.n_agents)
    torch.save(all_data, os.path.join(data_folder_dir, data_name))
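
# Hedged usage sketch (illustration only, not part of the original pipeline):
# generate_data can be driven programmatically via an argv-style list, e.g.
#
#   generate_data(custom_args=[
#       "--n_agents", "4",
#       "--map_shape", "5",
#       "--obj_density", "0.2",
#       "--n_episodes", "5000",
#       "--folder_name", "BC_5x5",
#       "--data_name", "demo_0.pt",
#   ])
#
# which writes the generated data to <base_path>/BC_5x5/demo_0.pt via torch.save.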
def benchmark(config, logger, policy, num_episodes, render_length, curr_episode):
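    '''Roll out the policy greedily for `num_episodes` episodes in a single
        parallel environment, recording render frames for the first
        `render_length` episodes, and pass the collected step infos to
        logger.benchmark_info. '''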
    env = make_parallel_env(config, 1, np.random.randint(0, 10000))  #seed=num_episodes//100
    hldr = make_env(config)
    max_steps = hldr.max_step

    render_frames = []
    all_infos = []

    for ep_i in range(num_episodes):
        obs = flat_np_lst_env_stack(env.reset(), flat=False) #*
        if ep_i < render_length:
            render_frames.append(env.render(indices = [0])[0])

        policy.prep_rollouts(device='cpu')
        
        for et_i in range(max_steps):
            torch_obs = [torch.tensor(obs[:, i],
                                requires_grad=False)
                        for i in range(policy.nagents)]
            torch_agent_actions = policy.step(torch_obs, explore=False) 
            # convert actions to numpy arrays
            agent_actions = [ac.data.numpy() for ac in torch_agent_actions]
            # rearrange actions to be per environment
            actions = [[ac[i] for ac in agent_actions] for i in range(1)]
            actions_dict = wrap_actions(actions)#[lst_to_dict(a) for a in actions] #*
            next_obs, rewards, dones, infos = env.step(actions_dict)
            if ep_i < render_length:
                render_frames.append(env.render(indices = [0])[0])

            all_infos.append(infos)
            next_obs = flat_np_lst_env_stack(next_obs, flat = False) #*
            rewards = flat_np_lst_env_stack(rewards, flat=False) #*
            dones = flat_np_lst_env_stack(dones, flat = False) #*
            obs = next_obs
            if infos[0]["terminate"]:
                obs = flat_np_lst_env_stack(env.reset(), flat=False)
                break

    logger.benchmark_info(all_infos, render_frames, curr_episode, parallel_env = True)
def main(args):
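    '''Parse the experiment arguments, build the environment and dispatch to the
        selected training backend (--policy: IC3, MAAC, A2C, PPO or CURR_PPO). '''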
    parser = argparse.ArgumentParser("Experiment parameters")
    #Environment:
    parser.add_argument("--map_shape", default=5, type=int)
    parser.add_argument("--n_agents", default=1, type=int)
    parser.add_argument("--obj_density", default=0.0, type=float)
    parser.add_argument("--view_d", default=2, type=int)
    parser.add_argument("--env_name",
                        default="cooperative_navigation-v0",
                        type=str)
    parser.add_argument("--use_custom_rewards",
                        default=False,
                        action='store_true')
    parser.add_argument("--step_r", default=-10, type=float)
    parser.add_argument("--agent_collision_r", default=-10, type=float)
    parser.add_argument("--obstacle_collision_r", default=-10, type=float)
    parser.add_argument("--goal_reached_r", default=-10, type=float)
    parser.add_argument("--finish_episode_r", default=-10, type=float)
    parser.add_argument("--block_r", default=-10, type=float)
    parser.add_argument("--custom_env_ind", default=-1, type=int)

    #Policy:
    parser.add_argument("--policy", default="PPO", type=str)

    #           PPO Parameters:
    parser.add_argument("--ppo_hidden_dim", default=120, type=int)
    parser.add_argument("--ppo_lr_a", default=0.001, type=float)
    parser.add_argument("--ppo_lr_v", default=0.001, type=float)
    parser.add_argument("--ppo_base_policy_type", default="mlp", type=str)
    parser.add_argument("--ppo_recurrent", default=False, action='store_true')
    parser.add_argument("--ppo_heur_block", default=False, action='store_true')
    parser.add_argument("--ppo_heur_valid_act",
                        default=False,
                        action='store_true')
    parser.add_argument("--ppo_heur_no_prev_state",
                        default=False,
                        action='store_true')
    parser.add_argument("--ppo_workers", default=1, type=int)
    parser.add_argument("--ppo_rollout_length", default=32, type=int)
    parser.add_argument("--ppo_share_actor",
                        default=False,
                        action='store_true')
    parser.add_argument("--ppo_share_value",
                        default=False,
                        action='store_true')
    parser.add_argument("--ppo_iterations", default=7000, type=int)
    parser.add_argument("--ppo_k_epochs", default=1, type=int)
    parser.add_argument("--ppo_eps_clip", default=0.2, type=float)
    parser.add_argument("--ppo_minibatch_size", default=32, type=int)
    parser.add_argument("--ppo_entropy_coeff", default=0.01, type=float)
    parser.add_argument("--ppo_value_coeff", default=0.5, type=float)
    parser.add_argument("--ppo_discount", default=0.95, type=float)
    parser.add_argument("--ppo_gae_lambda", default=1.0, type=float)
    parser.add_argument("--ppo_use_gpu", default=False,
                        action='store_true')  #ppo_curr_n_updates
    parser.add_argument("--ppo_curr_n_updates", default=5, type=int)
    parser.add_argument("--ppo_bc_iteration_prob", default=0.0, type=float)
    parser.add_argument("--ppo_continue_from_checkpoint",
                        default=False,
                        action='store_true')

    ################################################################################################
    #           IC3Net: #All parameters from https://github.com/IC3Net/IC3Net/blob/master/comm.py
    #
    parser.add_argument("--hid_size", default=120, type=int)
    parser.add_argument("--recurrent", default=False, action='store_true')
    parser.add_argument("--detach_gap", default=10, type=int)
    parser.add_argument("--comm_passes", default=1, type=int)
    parser.add_argument('--share_weights',
                        default=False,
                        action='store_true',
                        help='Share weights for hops')
    #parser.add_argument("--comm_mode", default = 1, type = int,
    #help= "if mode == 0 -- no communication; mode==1--ic3net communication; mode ==2 -- commNet communication")
    parser.add_argument(
        "--comm_mode",
        default="avg",
        type=str,
        help="Average or sum the hidden states to obtain the comm vector")
    parser.add_argument(
        "--hard_attn",
        default=True,
        action='store_false',
        help="to communicate or not. If hard_attn == False, no comm")
    parser.add_argument(
        "--comm_mask_zero",
        default=False,
        action='store_true',
        help="to communicate or not. If hard_attn == False, no comm")
    parser.add_argument('--comm_action_one',
                        default=False,
                        action='store_true',
                        help='Always communicate.')
    parser.add_argument("--ic3_base_policy_type", default="mlp", type=str)

    ################################################################################################
    #           MAAC parameters from: https://github.com/shariqiqbal2810/MAAC
    #
    parser.add_argument("--maac_buffer_length", default=int(1e6), type=int)
    parser.add_argument("--maac_n_episodes", default=50000, type=int)
    parser.add_argument("--maac_n_rollout_threads", default=6, type=int)
    parser.add_argument("--maac_steps_per_update", default=100, type=int)
    parser.add_argument("--maac_num_updates",
                        default=4,
                        type=int,
                        help="Number of updates per update cycle")
    parser.add_argument("--maac_batch_size",
                        default=1024,
                        type=int,
                        help="Batch size for training")
    parser.add_argument("--maac_pol_hidden_dim", default=128, type=int)
    parser.add_argument("--maac_critic_hidden_dim", default=128, type=int)
    parser.add_argument("--maac_attend_heads", default=4, type=int)
    parser.add_argument("--maac_pi_lr", default=0.001, type=float)
    parser.add_argument("--maac_q_lr", default=0.001, type=float)
    parser.add_argument("--maac_tau", default=0.001, type=float)
    parser.add_argument("--maac_gamma", default=0.9, type=float)
    parser.add_argument("--maac_reward_scale", default=100., type=float)
    parser.add_argument("--maac_use_gpu", action='store_true')
    parser.add_argument("--maac_share_actor", action='store_true')
    parser.add_argument("--maac_base_policy_type",
                        default='cnn_old')  #Types: mlp, cnn_old, cnn_new

    #Training:
    parser.add_argument(
        "--n_workers",
        default=1,
        type=int,
        help="The number of parallel environments sampled from")
    parser.add_argument("--n_steps",
                        default=1,
                        type=int,
                        help="For AC type policies")

    parser.add_argument("--device", default='cuda', type=str)
    parser.add_argument("--iterations", default=int(3 * 1e6), type=int)
    parser.add_argument("--lr", default=0.001, type=float)
    # parser.add_argument("--lr_step", default = -1, type=int)
    # parser.add_argument("lr_")
    parser.add_argument("--discount", default=1.0, type=float)
    parser.add_argument("--lambda_", default=1.0, type=float)
    parser.add_argument("--value_coeff",
                        default=0.01,
                        type=float,
                        help="Value function update coefficient")
    parser.add_argument("--entropy_coeff",
                        default=0.05,
                        type=float,
                        help="Entropy regularization coefficient")
    parser.add_argument("--model", default="mlp", type=str)
    parser.add_argument("--seed", default=2, type=int)
    parser.add_argument('--batch_size',
                        type=int,
                        default=2000,
                        help='number of steps before each update (per thread)')

    #Saving and rendering
    parser.add_argument("--working_directory", default='none', type=str)
    parser.add_argument(
        "--mean_stats",
        default=1,
        type=int,
        help="The number of iterations over which stats are averaged")
    parser.add_argument("--checkpoint_frequency", default=int(10e3), type=int)
    parser.add_argument("--print_frequency", default=int(5e1), type=int)
    parser.add_argument("--replace_checkpoints", default=True, type=bool)
    parser.add_argument("--render_rate", default=int(5e2 - 10), type=int)
    parser.add_argument("--render_length",
                        default=7,
                        type=int,
                        help="Number of episodes of rendering to save")
    parser.add_argument("--name",
                        default="NO_NAME",
                        type=str,
                        help="Experiment name")
    parser.add_argument(
        "--alternative_plot_dir",
        default="none",
        help="Creates a single folder to store tensorboard plots for comparison"
    )
    parser.add_argument(
        "--benchmark_frequency",
        default=int(3000),
        type=int,
        help=
        "Frequency (iterations or episodes) with which to evalue the greedy policy."
    )
    parser.add_argument("--benchmark_num_episodes", default=int(500), type=int)
    parser.add_argument("--benchmark_render_length",
                        default=int(100),
                        type=int)

    args = parser.parse_args(args)
    args.map_shape = (args.map_shape, args.map_shape)

    config.set_global_seed(args.seed)
    env = make_env(args)

    print("Running: \n {}".format(env.summary()))
    if args.policy == 'IC3':
        policy = make_policy(args, env)

        trainer = Trainer(args, policy, env)

        logger = Logger(args, env.summary(), policy.summary(), policy)

        for iteration in range(args.iterations):
            print("IC3 iteration: {} of {}".format(iteration, args.iterations))
            batch, stats, render_frames = trainer.sample_batch(
                render=logger.should_render())
            stats["value_loss"], stats["action_loss"] = policy.update(batch)

            stats["iterations"] = iteration
            terminal_t_info = [
                inf for i, inf in enumerate(batch.misc) if inf["terminate"]
            ]

            avg_comm = np.average(
                [np.average(inf['comm_action']) for inf in batch.misc])
            logger.log(stats,
                       terminal_t_info,
                       render_frames,
                       commActions=avg_comm)

            if iteration % args.checkpoint_frequency == 0 and iteration != 0:
                path = logger.checkpoint_dir + "/checkpoint_" + str(iteration)
                policy.save(path)

            if iteration % args.benchmark_frequency == 0 and iteration != 0:
                trainer.benchmark_ic3(policy, logger,
                                      args.benchmark_num_episodes,
                                      args.benchmark_render_length, iteration)
        #Benchmark
        trainer.benchmark_ic3(policy, logger, args.benchmark_num_episodes,
                              args.benchmark_render_length, iteration)
        #Checkpoint
        path = logger.checkpoint_dir + "/checkpoint_" + str(args.iterations)
        policy.save(path)
    elif args.policy == 'MAAC':
        from utils.logger import Maac_Logger
        from Agents.MAAC import run
        logger = Maac_Logger(args)
        run(args, logger)
    elif args.policy == 'A2C':
        raise NotImplementedError
        # from Agents.Ind_A2C import run_a2c
        #run_a2c(args)
    elif args.policy == 'PPO':
        from Agents.PPO_IL import run
        run(args)
    elif args.policy == 'CURR_PPO':
        from Agents.PPO_CurriculumTrain import run
        run(args)
    else:
        raise Exception("Policy type not implemented")
def train_PO_FOV_data(custom_args=None):
    '''Check whether the data has been processed into train, validation and test sets.
        If not, make a copy of the data and process it into the different sets.
        Then start training with the given parameters. '''

    #Helper functions
    def loss_f(pred, label):
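        '''Mean negative log-likelihood of the labelled actions: `pred` holds
            per-action probabilities (a softmax output) and `label` holds the
            integer indices of the demonstrated actions. '''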
        action_label_prob = torch.gather(pred, -1, label.long())
        log_actions = -torch.log(action_label_prob)
        loss = log_actions.mean()
        return loss

    def get_validation_loss(validate_loader, ppo_policy):
        with torch.no_grad():
            mae_like = 0
            total = 0
            valid_loss_hldr = []
            for data in validate_loader:
                if len(data) == 2:
                    ob, a = data
                elif len(data) == 3:
                    ob = (data[0], data[1])
                    a = data[-1]
                else:
                    raise Exception("Data incorrect length")
                (a_pred, _, _, _) = ppo_policy.actors[0].forward(ob)
                valid_loss_hldr.append(loss_f(F.softmax(a_pred, dim=-1), a).item())
        return np.mean(valid_loss_hldr)
        #valid_loss = {"validation_loss": torch.mean(valid_loss_hldr).item()}
        #logger.plot_tensorboard(valid_loss)

    def save(model, logger, end):
        name = "checkpoint_" + str(end)
        #checkpoint_path = os.path.join(logger.checkpoint_dir, name)
        #model.save(checkpoint_path)
        logger.make_checkpoint(False, str(end))

    def is_processed(path):
        '''Returns whether the folder at `path` already contains
            separate train, validation and test files. '''
        files = get_file_names_in_fldr(path)
        file_markers = ["train", "test", "validation"]
        is_processed_flag = True
        for m in file_markers:
            sub_flag = False
            for f in files:
                if m in f:
                    sub_flag = True
            if not sub_flag:
                is_processed_flag = False
                break
        return is_processed_flag

    parser = argparse.ArgumentParser("Train arguments")
    parser.add_argument("--folder_name", type=str)
    parser.add_argument("--mb_size", default=32, type=int)
    parser.add_argument("--lr", default=5e-5, type=float)
    parser.add_argument("--n_epoch", default=50, type=int)
    parser.add_argument("--weight_decay", default=0.0001, type=float)
    parser.add_argument("--env_name", default='none', type=str)
    parser.add_argument("--alternative_plot_dir", default="none")
    parser.add_argument("--working_directory", default="none")
    parser.add_argument("--name",
                        default="NO_NAME",
                        type=str,
                        help="Experiment name")
    parser.add_argument("--replace_checkpoints", default=True, type=bool)
    #Placeholders:
    parser.add_argument("--n_agents", default=1, type=int)
    parser.add_argument("--map_shape", default=(5, 5), type=object)
    parser.add_argument("--obj_density", default=0.0, type=float)
    parser.add_argument("--view_d", default=3, type=int)
    parser.add_argument("--use_default_rewards", default=True, type=bool)
    parser.add_argument("--use_custom_rewards", default=False, type=bool)
    parser.add_argument("--base_path", default="none", type=str)

    #Best previous performing hyp param is: mb:32 lr:5e-5  weightdecay: 0.0001

    if custom_args is None:
        args = parser.parse_args()
    else:
        args, unkn = parser.parse_known_args(custom_args)

    experiment_group_name = args.folder_name  #"BC_5x5"
    if args.working_directory == "none":
        import __main__
        work_dir = os.path.join(os.path.dirname(__main__.__file__),
                                'EXPERIMENTS', experiment_group_name)
    else:
        work_dir = '/home/james/Desktop/Gridworld/EXPERIMENTS/' + experiment_group_name

    if args.alternative_plot_dir == "none":
        import __main__
        plot_dir = os.path.join(os.path.dirname(__main__.__file__),
                                'CENTRAL_TENSORBOARD', experiment_group_name)
    else:
        plot_dir = '/home/james/Desktop/Gridworld/CENTRAL_TENSORBOARD/' + experiment_group_name

    work_dir = "/home/jellis/workspace2/gridworld/EXPERIMENTS/" + args.folder_name
    plot_dir = "/home/jellis/workspace2/gridoworld/CENTRAL_TENSORBOARD/" + args.folder_name
    args.working_directory = work_dir
    args.alternative_plot_dir = plot_dir

    #BASE_DATA_FOLDER_PATH = '/home/james/Desktop/Gridworld/BC_Data'

    if args.base_path == "none":
        import __main__
        BASE_DATA_FOLDER_PATH = os.path.dirname(__main__.__file__)
        BASE_DATA_FOLDER_PATH = os.path.join(BASE_DATA_FOLDER_PATH, "BC_Data")
    else:
        BASE_DATA_FOLDER_PATH = args.base_path

    data_fldr_path = os.path.join(BASE_DATA_FOLDER_PATH, args.folder_name)
    #data_fldr_path = "/home/jellis/workspace2/BC_Data/"
    if is_processed(data_fldr_path):
        train_f, val_f, test_f = get_data_files(data_fldr_path)
    else:
        #Copy data:
        to_dir = data_fldr_path + "_cpy"
        os.makedirs(to_dir)
        copy_tree(data_fldr_path, to_dir)
        #Split and save data:
        data, files = combine_all_data(data_fldr_path)
        delete_files(files)
        train_f, val_f, test_f = split_data(data, data_fldr_path, "data_")

    DEVICE = 'gpu'
    #Train on data: keep best policy
    #Get data from files:
    train_data = torch.load(train_f)
    val_data = torch.load(val_f)
    test_data = torch.load(test_f)

    env_hldr = make_env(args)
    observation_space = env_hldr.observation_space[-1]

    name = "BC_" + args.folder_name \
            + "mbsize_" + str(args.mb_size) \
            + "_lr_" + str(args.lr) \
            + "_epochs_" + str(args.n_epoch) \
            + "_weightdecay_" + str(args.weight_decay)

    #args.extend(["--name", name])
    #args = parser.parse_args(args)
    args.name = name

    ppo = PPO(5, observation_space, "primal7", 1, True, True, 1, 1, args.lr,
              0.001, 120, 0.2, 0.01, False, False)

    logger = Logger(args, env_hldr.summary(), "none", ppo)

    #Make training data loader
    (obs, actions) = zip(*train_data)
    #(obs, actions) = (np.array(obs), np.array(actions))
    if type(observation_space) == tuple:
        (obs1, obs2) = zip(*obs)
        (obs, actions) = ((np.array(obs1), np.array(obs2)), np.array(actions))
    else:
        (obs, actions) = (np.array(obs), np.array(actions))
    ppo.prep_device(DEVICE)

    if type(observation_space) == tuple:
        obs = (ppo.tens_to_dev(DEVICE, torch.from_numpy(obs[0]).float()), \
            ppo.tens_to_dev(DEVICE, torch.from_numpy(obs[1]).float()))
    else:
        obs = [ppo.tens_to_dev(DEVICE, torch.from_numpy(obs).float())]

# obs = ppo.tens_to_dev(DEVICE, torch.from_numpy(obs).float())
    action_labels = ppo.tens_to_dev(
        DEVICE,
        torch.from_numpy(actions).reshape((-1, 1)).float())
    train_dataset = torch.utils.data.TensorDataset(*obs, action_labels)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.mb_size,
                                               shuffle=True)

    #Make validation data_loader
    (obs, actions) = zip(*val_data)
    #(obs, actions) = (np.array(obs), np.array(actions))
    if type(observation_space) == tuple:
        (obs1, obs2) = zip(*obs)
        (obs, actions) = ((np.array(obs1), np.array(obs2)), np.array(actions))
    else:
        (obs, actions) = (np.array(obs), np.array(actions))
    #obs = ppo.tens_to_dev(DEVICE, torch.from_numpy(obs).float())
    if type(observation_space) == tuple:
        obs = (ppo.tens_to_dev(DEVICE, torch.from_numpy(obs[0]).float()), \
            ppo.tens_to_dev(DEVICE, torch.from_numpy(obs[1]).float()))
    else:
        obs = [ppo.tens_to_dev(DEVICE, torch.from_numpy(obs).float())]

    val_action_labels = ppo.tens_to_dev(
        DEVICE,
        torch.from_numpy(actions).reshape((-1, 1)).float())
    valid_dataset = torch.utils.data.TensorDataset(*obs, val_action_labels)
    valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                               batch_size=args.mb_size,
                                               shuffle=True)

    #Make test data_loader
    (obs, actions) = zip(*test_data)
    #(obs, actions) = (np.array(obs), np.array(actions))
    if type(observation_space) == tuple:
        (obs1, obs2) = zip(*obs)
        (obs, actions) = ((np.array(obs1), np.array(obs2)), np.array(actions))
    else:
        (obs, actions) = (np.array(obs), np.array(actions))
    #obs = ppo.tens_to_dev(DEVICE, torch.from_numpy(obs).float())
    if type(observation_space) == tuple:
        obs = (ppo.tens_to_dev(DEVICE, torch.from_numpy(obs[0]).float()), \
            ppo.tens_to_dev(DEVICE, torch.from_numpy(obs[1]).float()))
    else:
        obs = [ppo.tens_to_dev(DEVICE, torch.from_numpy(obs).float())]

    test_action_labels = ppo.tens_to_dev(
        DEVICE,
        torch.from_numpy(actions).reshape((-1, 1)).float())
    test_dataset = torch.utils.data.TensorDataset(*obs, test_action_labels)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=args.mb_size,
                                              shuffle=True)
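
    # The three loader blocks above (train/validation/test) repeat the same
    # conversion pattern. A minimal refactoring sketch, assuming the same
    # observation layout and the PPO.tens_to_dev interface used above
    # (defined for illustration only; never called):
    def _make_loader(raw_data, obs_space, policy, device, batch_size):
        (obs, actions) = zip(*raw_data)
        if isinstance(obs_space, tuple):
            # Two-part observations: convert each part separately.
            (obs1, obs2) = zip(*obs)
            obs = (policy.tens_to_dev(device, torch.from_numpy(np.array(obs1)).float()),
                   policy.tens_to_dev(device, torch.from_numpy(np.array(obs2)).float()))
        else:
            obs = [policy.tens_to_dev(device, torch.from_numpy(np.array(obs)).float())]
        labels = policy.tens_to_dev(
            device,
            torch.from_numpy(np.array(actions)).reshape((-1, 1)).float())
        dataset = torch.utils.data.TensorDataset(*obs, labels)
        return torch.utils.data.DataLoader(dataset,
                                           batch_size=batch_size,
                                           shuffle=True)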

    optimizer = optim.Adam(ppo.actors[0].parameters(),
                           lr=args.lr,
                           weight_decay=args.weight_decay)

    #Train:
    prev_val_loss = 1e6
    best_epoch_nr = None
    val_loss_is_greater_cntr = 0
    val_loss_is_greater_threshold = 8
    best_policy = None
    print(name)
    for epoch in range(args.n_epoch):
        epoch_loss_hldr = []
        iterations = 0
        for data in train_loader:
            #ob, a = data
            if len(data) == 2:
                ob, a = data
            elif len(data) == 3:
                ob = (data[0], data[1])
                a = data[-1]
            else:
                raise Exception("Data incorrect length")
            (a_pred, _, _, _) = ppo.actors[0].forward(ob)
            loss = loss_f(F.softmax(a_pred, dim=-1), a)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss_hldr.append(loss.item())
            iterations += 1

        print("iterations: {}".format(iterations))
        epoch_loss = np.mean(epoch_loss_hldr)
        valid_loss = get_validation_loss(valid_loader, ppo)
        test_loss = get_validation_loss(test_loader, ppo)
        log_info = {
            "train_loss": epoch_loss,
            "validation_loss": valid_loss,
            "test_loss": test_loss
        }

        logger.plot_tensorboard_custom_keys(log_info, external_iteration=epoch)
        if valid_loss < prev_val_loss:
            save(ppo, logger, epoch)
            best_policy = copy.deepcopy(ppo)
            prev_val_loss = valid_loss
            val_loss_is_greater_cntr = 0
            best_epoch_nr = epoch
        else:
            val_loss_is_greater_cntr += 1

        print("Epoch: {}  Train Loss: {} Validation Loss: {}".format(
            epoch, epoch_loss, valid_loss))
        if val_loss_is_greater_cntr > val_loss_is_greater_threshold:
            print("Ending training")
            break
    print("Done")

    # free up memory:
    try:
        del train_data
        del val_data
        del test_data
        del action_labels
        del train_dataset
        del train_loader
        del val_action_labels
        del valid_dataset
        del valid_loader
    except:
        pass

    assert best_policy is not None
    print("Best epoch nr is {}".format(best_epoch_nr))

    # 32 x 32
    variable_args_dict = {
        "n_agents": [4, 10, 30, 35, 40, 45, 50, 60, 70],
        "obj_density": [0.0, 0.1, 0.2, 0.3],
        "map_shape": [(32, 32)]
    }
    evaluate_across_evs(best_policy,
                        logger,
                        args,
                        variable_args_dict,
                        1000,
                        30,
                        DEVICE,
                        greedy=False)

    # 40 x 40
    variable_args_dict = {
        "n_agents": [4, 10, 30, 35, 40, 45, 50, 60, 70],
        "obj_density": [0.0, 0.1, 0.2, 0.3],
        "map_shape": [(40, 40)]
    }
    evaluate_across_evs(best_policy,
                        logger,
                        args,
                        variable_args_dict,
                        1000,
                        30,
                        DEVICE,
                        greedy=False)
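
# Hedged usage sketch (illustration only): the trainer also accepts an
# argv-style list, e.g.
#
#   train_PO_FOV_data(custom_args=["--folder_name", "BC_5x5",
#                                  "--env_name", "independent_navigation-v8_0",
#                                  "--mb_size", "32",
#                                  "--lr", "5e-5",
#                                  "--n_epoch", "50",
#                                  "--weight_decay", "0.0001"])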
def evaluate_checkpoint2():
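    '''Load a PPO checkpoint from a hard-coded path and benchmark it with
        evaluate_across_evs over a grid of agent counts and obstacle densities. '''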

    #CHECKPOINT_PATH = "/home/james/Desktop/Gridworld/EXPERIMENTS/BC1AgentShortestPath_2_V0/BC_BC1AgentShortestPath_2_V0mbsize_32_lr_5e-05_epochs_50_weightdecay_0.0001_N0/checkpoint/checkpoint_40"
    #experiment_group_name = "Results_Shortest Path"

    #CHECKPOINT_PATH = "/home/james/Desktop/Gridworld/EXPERIMENTS/BC1AgentDirVec_2_V1/BC_BC1AgentDirVec_2_V1mbsize_32_lr_5e-05_epochs_50_weightdecay_0.0001_N0/checkpoint/checkpoint_31"
    #CHECKPOINT_PATH = "/home/james/Desktop/Gridworld/EXPERIMENTS/5_2_0_CL/ppo_arc_primal7_sr_-0.1_ocr_-0.4_acr_-0.4_grr_0.3_fer_2.0_viewd_3_disc_0.5_lambda_1.0_entropy_0.01_minibatch_512_rollouts_256_workers_4_kepochs_8_curr_ppo_cl_inc_size_seed_1_N4/checkpoint/checkpoint_17600"
    CHECKPOINT_PATH = "/home/james/Desktop/Gridworld/EXPERIMENTS/5_3_0_CL/ppo_arc_primal7_sr_-0.1_ocr_-0.4_acr_-0.4_grr_0.3_fer_2.0_viewd_3_disc_0.5_lambda_1.0_entropy_0.01_minibatch_512_rollouts_256_workers_4_kepochs_8_curr_ppo_cl_inc_size_seed_1_N0/checkpoint/checkpoint_5200"
    #CHECKPOINT_PATH = "/home/james/Desktop/Gridworld/EXPERIMENTS/9_0_0/BC_9_0_0mbsize_32_lr_5e-05_epochs_50_weightdecay_1e-05_N0/checkpoint/checkpoint_20"
    experiment_group_name = "5_3_0_CL_benchmark3"
    work_dir = '/home/james/Desktop/Gridworld/EXPERIMENTS/' + experiment_group_name
    plot_dir = '/home/james/Desktop/Gridworld/CENTRAL_TENSORBOARD/' + experiment_group_name

    parser = argparse.ArgumentParser("Train arguments")
    parser.add_argument("--alternative_plot_dir", default="none")
    parser.add_argument("--working_directory", default="none")
    parser.add_argument("--name",
                        default="benchmark3",
                        type=str,
                        help="Experiment name")
    parser.add_argument("--replace_checkpoints", default=True, type=bool)
    #Placeholders:
    parser.add_argument("--env_name",
                        default="independent_navigation-v0",
                        type=str)
    parser.add_argument("--n_agents", default=1, type=int)
    parser.add_argument("--map_shape", default=(7, 7), type=object)
    parser.add_argument("--obj_density", default=0.0, type=float)
    parser.add_argument("--view_d", default=3, type=int)
    parser.add_argument("--use_default_rewards", default=True, type=bool)
    parser.add_argument("--use_custom_rewards", default=False, type=bool)
    parser.add_argument("--base_path", default="none", type=str)
    args = parser.parse_args()

    args.working_directory = work_dir
    args.alternative_plot_dir = plot_dir
    ##########
    args = make_env_args(args, {})
    env_hldr = make_env(args)
    observation_space = env_hldr.observation_space[-1]

    ppo = PPO(5, observation_space, "primal7", 1, True, True, 1, 1, 0.001,
              0.001, 120, 0.2, 0.01, False, False)

    ppo.load(torch.load(CHECKPOINT_PATH))
    logger = Logger(args, "NONE", "none", ppo)

    # variable_args_dict = {
    #     "n_agents": [10,20,30,40],
    #     "obj_density": [0.0,0.1, 0.2, 0.3],
    #     "map_shape": [(32, 32)]
    # }

    # evaluate_across_evs(ppo, logger, args, variable_args_dict, 200, 10, 'gpu', greedy=False)

    variable_args_dict = {
        "n_agents": [4],
        "obj_density": [0.0, 0.1, 0.2, 0.3],
        "map_shape": [(7, 7)]
    }
    evaluate_across_evs(ppo,
                        logger,
                        args,
                        variable_args_dict,
                        500,
                        30,
                        'gpu',
                        greedy=False)
def evaluate_checkpoint():

    #CHECKPOINT_PATH = "/home/james/Desktop/Gridworld/EXPERIMENTS/BC1AgentShortestPath_2_V0/BC_BC1AgentShortestPath_2_V0mbsize_32_lr_5e-05_epochs_50_weightdecay_0.0001_N0/checkpoint/checkpoint_40"
    #experiment_group_name = "Results_Shortest Path"

    #CHECKPOINT_PATH = "/home/james/Desktop/Gridworld/EXPERIMENTS/BC1AgentDirVec_2_V1/BC_BC1AgentDirVec_2_V1mbsize_32_lr_5e-05_epochs_50_weightdecay_0.0001_N0/checkpoint/checkpoint_31"
    CHECKPOINT_PATH = "/home/james/Desktop/Gridworld/EXPERIMENTS/9_0_0/BC_9_0_0mbsize_32_lr_5e-05_epochs_50_weightdecay_0.0001_N0/checkpoint/checkpoint_31"
    experiment_group_name = "9_0_0"
    work_dir = '/home/james/Desktop/Gridworld/EXPERIMENTS/' + experiment_group_name
    plot_dir = '/home/james/Desktop/Gridworld/CENTRAL_TENSORBOARD/' + experiment_group_name

    parser = argparse.ArgumentParser("Train arguments")
    parser.add_argument("--alternative_plot_dir", default="none")
    parser.add_argument("--working_directory", default="none")
    parser.add_argument("--name",
                        default="9_0_0_benchmark",
                        type=str,
                        help="Experiment name")
    parser.add_argument("--replace_checkpoints", default=True, type=bool)
    #Placeholders:
    parser.add_argument("--env_name",
                        default="independent_navigation-v8_0",
                        type=str)
    parser.add_argument("--n_agents", default=1, type=int)
    parser.add_argument("--map_shape", default=(5, 5), type=object)
    parser.add_argument("--obj_density", default=0.0, type=float)
    parser.add_argument("--view_d", default=3, type=int)
    parser.add_argument("--use_default_rewards", default=True, type=bool)
    parser.add_argument("--use_custom_rewards", default=False, type=bool)
    parser.add_argument("--base_path", default="none", type=str)
    args = parser.parse_args()

    args.working_directory = work_dir
    args.alternative_plot_dir = plot_dir
    ##########
    args = make_env_args(args, {})
    env_hldr = make_env(args)
    observation_space = env_hldr.observation_space[-1]

    ppo = PPO(5, observation_space, "primal7", 1, True, True, 1, 1, 0.001,
              0.001, 120, 0.2, 0.01, False, False)

    ppo.load(torch.load(CHECKPOINT_PATH))
    logger = Logger(args, "NONE", "none", ppo)

    # 32 x 32
    variable_args_dict = {
        "n_agents": [10, 30, 35, 40, 45, 50, 60, 70],
        "obj_density": [0.0, 0.1, 0.2, 0.3],
        "map_shape": [(32, 32)]
    }
    evaluate_across_evs(ppo,
                        logger,
                        args,
                        variable_args_dict,
                        1000,
                        30,
                        'gpu',
                        greedy=False)

    # 40 x 40
    variable_args_dict = {
        "n_agents": [10, 30, 35, 40, 45, 50, 60, 70],
        "obj_density": [0.0, 0.1, 0.2, 0.3],
        "map_shape": [(40, 40)]
    }
    evaluate_across_evs(ppo,
                        logger,
                        args,
                        variable_args_dict,
                        1000,
                        30,
                        'gpu',
                        greedy=False)

    # 50 x 50
    variable_args_dict = {
        "n_agents": [10, 30, 35, 40, 45, 50, 60, 70],
        "obj_density": [0.0, 0.1, 0.2, 0.3],
        "map_shape": [(50, 50)]
    }
    evaluate_across_evs(ppo,
                        logger,
                        args,
                        variable_args_dict,
                        1000,
                        30,
                        'gpu',
                        greedy=False)
def run(config, logger0):
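    '''MAAC training loop: build a parallel environment and an AttentionSAC model,
        fill a replay buffer by rolling out episodes, run maac_num_updates
        critic/policy updates every maac_steps_per_update environment steps, and
        periodically render, checkpoint and benchmark before saving the final
        model. '''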
    config.maac_n_rollout_threads = 1
    start_time = time.time()
    logger = None 
    run_num = 1
    torch.manual_seed(run_num)
    np.random.seed(run_num)
    env = make_parallel_env(config, config.maac_n_rollout_threads, run_num)

    hldr = make_env(config)
    episode_length = hldr.max_step
    del hldr


    model = AttentionSAC.init_from_env(env, 
                                       tau=config.maac_tau,
                                       pi_lr=config.maac_pi_lr,
                                       q_lr=config.maac_q_lr,
                                       gamma=config.maac_gamma,
                                       pol_hidden_dim=config.maac_pol_hidden_dim,
                                       critic_hidden_dim=config.maac_critic_hidden_dim,
                                       attend_heads=config.maac_attend_heads,
                                       reward_scale=config.maac_reward_scale,
                                       share_actor = config.maac_share_actor,
                                       base_policy_type = config.maac_base_policy_type)
    print("model.nagents: {}".format(model.nagents))
    replay_buffer = ReplayBuffer(config.maac_buffer_length, model.nagents,
                                 [obsp.shape for obsp in env.observation_space],
                                 [acsp.shape[0] if isinstance(acsp, Box) else acsp.n
                                  for acsp in env.action_space])

    t = 0
    render_frames = []
    render_counter = 0
    for ep_i in range(0, config.maac_n_episodes, config.maac_n_rollout_threads):
        elapsed_hours = (time.time() - start_time) / 3600.0
        ETA = (elapsed_hours / float(ep_i + 1)) * float(config.maac_n_episodes) - elapsed_hours
        print("Episodes %i-%i of %i ETA %f" % (ep_i + 1,
                                        ep_i + 1 + config.maac_n_rollout_threads,
                                        config.maac_n_episodes, ETA))
        obs = flat_np_lst_env_stack(env.reset(), flat=False) #*
        model.prep_rollouts(device='cpu')

        all_infos = []

        for et_i in range(episode_length):
            torch_obs = [torch.tensor(obs[:, i],
                                  requires_grad=False)
                         for i in range(model.nagents)]
            # get actions as torch Variables
            torch_agent_actions = model.step(torch_obs, explore=True) 
            # convert actions to numpy arrays
            agent_actions = [ac.data.numpy() for ac in torch_agent_actions]
            # rearrange actions to be per environment
            actions = [[ac[i] for ac in agent_actions] for i in range(config.maac_n_rollout_threads)]
            #print("actions in maac: {}".format(actions))
            actions_dict = wrap_actions(actions)#[lst_to_dict(a) for a in actions] #*
            next_obs, rewards, dones, infos = env.step(actions_dict)

           # rewards = [{i:0.01 for i in range(model.nagents)}]

            if (ep_i+ 1)%config.render_rate <= (config.render_length * config.maac_n_rollout_threads):
                render_frames.append(env.render(indices = [0])[0])
                if et_i == 0: render_counter += 1
                
            all_infos.append(infos)

            #print("Obs before: {}".format(next_obs))
            next_obs = flat_np_lst_env_stack(next_obs, flat = False) #*
           # print("Rewards before: {}".format(rewards))
            rewards = flat_np_lst_env_stack(rewards, flat=False) #*
            if np.isnan(rewards).any():
                hldr = 1
            dones = flat_np_lst_env_stack(dones, flat = False) #*
            #print("Dones looks like: {}".format(dones))

            replay_buffer.push(obs, agent_actions, rewards, next_obs, dones)
            obs = next_obs
            t += config.maac_n_rollout_threads
            if (len(replay_buffer) >= config.maac_batch_size and
                (t % config.maac_steps_per_update) < config.maac_n_rollout_threads):
                if config.maac_use_gpu:
                    model.prep_training(device='gpu')
                else:
                    model.prep_training(device='cpu')
                for u_i in range(config.maac_num_updates):
                    sample = replay_buffer.sample(config.maac_batch_size,
                                                  to_gpu=config.maac_use_gpu)
                    model.update_critic(sample, logger=logger)
                    model.update_policies(sample, logger=logger)
                    model.update_all_targets()
                model.prep_rollouts(device='cpu')

            if infos[0]["terminate"]:
                obs = flat_np_lst_env_stack(env.reset(), flat=False)
                break

        
        if render_counter == config.render_length:
            render_counter = 0
            logger0.log_ep_info(all_infos, render_frames, ep_i)
            render_frames = []
        else:
            logger0.log_ep_info(all_infos, [], ep_i)

        if ep_i % config.checkpoint_frequency < config.maac_n_rollout_threads:
            model.prep_rollouts(device='cpu')
            model.save(logger0.checkpoint_dir + '/checkpoint_' + str(ep_i //config.checkpoint_frequency) + '.pt')

        if ep_i % config.benchmark_frequency < config.maac_n_rollout_threads and ep_i != 0:
            benchmark(config, logger0, model, config.benchmark_num_episodes, config.benchmark_render_length, ep_i)

    model.save(logger0.checkpoint_dir + '/model.pt')
    benchmark(config, logger0, model, config.benchmark_num_episodes, config.benchmark_render_length, ep_i)
    env.close()
def init_env():
    env = make_env(args)
    np.random.seed(seed + rank * 1000)
    return env
def run_gridworld():
    #mode = "rgb_array"
    mode = "human"
    parser = argparse.ArgumentParser("Testing")

    #Environment:
    parser.add_argument("--env_name", default = 'independent_navigation-v0', type= str, \
         help="The env type: 'independent_navigation-v0'=Fully observable MAPF problem  ;" \
            + "'independent_navigation-v8_0'=PO MAPF problem with distance to shortest path observation" \
            + "'independent_navigation-v8_1'=PO MAPF problem with direction vector")
    parser.add_argument("--map_shape", default=10, type=int)
    parser.add_argument("--n_agents", default=1, type=int)
    parser.add_argument("--verbose",
                        default=False,
                        action='store_true',
                        help='Prints the observations.')
    parser.add_argument("--use_default_rewards", default=True, type=bool)
    parser.add_argument("--obj_density", default=0.2, type=float)
    parser.add_argument("--view_d", default=3, type=int)
    parser.add_argument("--ppo_recurrent", default=False, action='store_true')
    parser.add_argument("--ppo_heur_block", default=False, action='store_true')
    parser.add_argument("--ppo_heur_valid_act",
                        default=False,
                        action='store_true')
    parser.add_argument("--ppo_heur_no_prev_state",
                        default=False,
                        action='store_true')

    parser.add_argument("--use_custom_rewards",
                        default=False,
                        action='store_true')
    parser.add_argument("--step_r", default=-10, type=float)
    parser.add_argument("--agent_collision_r", default=-10, type=float)
    parser.add_argument("--obstacle_collision_r", default=-10, type=float)
    parser.add_argument("--goal_reached_r", default=-10, type=float)
    parser.add_argument("--finish_episode_r", default=-10, type=float)
    parser.add_argument("--block_r", default=-10, type=float)

    args = parser.parse_args()
    name = args.env_name
    version = int(name[-1])

    if name == 'independent_navigation-v8_0':
        version = 3

    args.map_shape = (args.map_shape, args.map_shape)
    env = make_env(args)
    obs = env.reset()
    env.render(mode=mode)

    #The following is modified from: https://github.com/AIcrowd/flatland-challenge-starter-kit
    print(
        "Manual control: actions: stay: 0, up: 1, right: 2, down: 3, left: 4, step: s \n Commands: <agent_handle> <action>, e.g. '0 2 s'"
    )
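    # Example (assuming two agents): "0 2 1 3 s" queues action 2 (right) for
    # agent 0 and action 3 (down) for agent 1, then steps the environment;
    # "q" quits.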
    cmd = ""
    while 'q' not in cmd:
        cmd = input(">> ")
        print(cmd)
        cmds = cmd.split(" ")
        action_dict = {}
        i = 0
        while i < len(cmds):
            if cmds[i] == 'q':
                import sys
                sys.exit()
            elif cmds[i] == 's':
                (obs, rewards, dones, info) = env.step(action_dict)
                if args.verbose:
                    if version == 0:
                        headers = ["Channels"]
                        rows = [["Obstacle Channel"], ["Other Agent Channel"], ["Own Goals Channel"], \
                            ["Own Position Channel"], ["Other Goal Channel"]]
                        for agnt in range(args.n_agents):
                            headers.append("Agent {}".format(agnt))
                            rows[0].append(obs[agnt][0])
                            rows[1].append(obs[agnt][1])
                            rows[2].append(obs[agnt][2])
                            rows[3].append(obs[agnt][4])
                            rows[4].append(obs[agnt][3])
                        print(
                            tabulate(rows,
                                     headers=headers,
                                     tablefmt='fancy_grid'))
                        if args.ppo_heur_block:
                            print("Blocking is: {}".format(info["blocking"]))
                        if args.ppo_heur_valid_act:
                            print("Valid actions are: {}".format({
                                k: env.graph.get_valid_actions(k)
                                for k in env.agents.keys()
                            }))
                    elif version == 1 or version == 2:
                        headers = ["Channels"]
                        rows = [["Obstacle Channel"], ["Other Agent Channel"], ["Own Goals Channel"], \
                            ["Other Goal Channel"], ["Vector"]]
                        for agnt in range(args.n_agents):
                            headers.append("Agent {}".format(agnt))
                            rows[0].append(obs[agnt][0][0])
                            rows[1].append(obs[agnt][0][1])
                            rows[2].append(obs[agnt][0][2])
                            rows[3].append(obs[agnt][0][3])
                            rows[4].append(obs[agnt][1])
                        print(
                            tabulate(rows,
                                     headers=headers,
                                     tablefmt='fancy_grid'))
                        if args.ppo_heur_block:
                            print("Blocking is: {}".format(info["blocking"]))
                        if args.ppo_heur_valid_act:
                            print("Valid actions are: {}".format({
                                k: env.graph.get_valid_actions(k)
                                for k in env.agents.keys()
                            }))
                    elif version == 3:
                        headers = ["Channels"]
                        rows = [["Obstacle Channel"], ["Other Agent Channel"], ["Own Goals Channel"], \
                            ["Other Goal Channel"], ["Shortest Path Heur"]]
                        for agnt in range(args.n_agents):
                            headers.append("Agent {}".format(agnt))
                            rows[0].append(obs[agnt][0])
                            rows[1].append(obs[agnt][1])
                            rows[2].append(obs[agnt][2])
                            rows[3].append(obs[agnt][3])
                            rows[4].append(obs[agnt][4].round(3))
                        print(
                            tabulate(rows,
                                     headers=headers,
                                     tablefmt='fancy_grid'))
                        if args.ppo_heur_block:
                            print("Blocking is: {}".format(info["blocking"]))
                        if args.ppo_heur_valid_act:
                            print("Valid actions are: {}".format({
                                k: env.graph.get_valid_actions(k)
                                for k in env.agents.keys()
                            }))
                    else:
                        raise NotImplementedError
                print("Rewards: ", rewards)
                print("Dones: ", dones)
                print("Collisions: ", info["step_collisions"])
            else:
                agent_id = int(cmds[i])
                action = int(cmds[i + 1])
                action_dict[agent_id] = action
                i = i + 1
            i += 1
            r = env.render(mode=mode)
def init_env():
    env = make_env(args)
    return env