def benchmark_func(env_args, recurrent, model, num_episodes, render_len, device, greedy = False, valid_act = False):
    """Evaluate ``model`` in a freshly created single-worker environment.

    Runs ``num_episodes`` episodes with the first actor in eval mode and
    records per-episode metrics read from the terminal ``info`` dict.  The
    environment is always closed and the actor's previous train/eval mode
    is restored on exit, even if an exception is raised.

    Args:
        env_args: environment configuration passed to ``make_parallel_env``.
        recurrent: if True, fresh hidden/cell states from
            ``model.init_hx_cx(device)`` are created at the start of every
            episode and carried from step to step.
        model: policy exposing ``actors``, ``init_hx_cx`` and ``forward``.
        num_episodes: number of episodes to run.
        render_len: frames are recorded only for episodes whose index is
            smaller than this value.
        device: device identifier forwarded to the model.
        greedy: forwarded to ``model.forward``.
        valid_act: if True, ``env.return_valid_act()`` is queried every step
            and passed to the model as the valid-action heuristic.

    Returns:
        ``(render_frames, results)``: the recorded frames (the first one is
        taken before any episode starts) and a dict mapping each metric name
        to a list with one entry per finished episode.
    """
    results = {"agents_on_goal": [],
               "all_done": [],
               "episode_length": [],
               "execution_time": [],
               "obstacle_collisions": [],
               "agent_collisions": []}
    env = make_parallel_env(env_args, np.random.randint(0, 10000), 1)
    actor = model.actors[0]
    # Remember the actor's mode so that benchmarking in the middle of training
    # does not leave the trained actor stuck in eval mode.
    was_training = getattr(actor, "training", None)
    actor.eval()
    render_frames = []
    try:
        render_frames.append(env.render(indices=[0])[0])
        # Evaluation needs no gradients; without no_grad the recurrent hidden
        # states would chain one autograd graph across the whole episode.
        with torch.no_grad():
            for ep in range(num_episodes):
                obs = env.reset()
                if recurrent:
                    hx_cx_actr = [{i: model.init_hx_cx(device) for i in ob.keys()} for ob in obs]
                    hx_cx_cr = [{i: model.init_hx_cx(device) for i in ob.keys()} for ob in obs]
                else:
                    hx_cx_actr = [{i: (None, None) for i in ob.keys()} for ob in obs]
                    hx_cx_cr = [{i: (None, None) for i in ob.keys()} for ob in obs]
                info2 = [{"valid_act": {i: None for i in ob.keys()}} for ob in obs]
                total_ep_time = 0
                for t in itertools.count():
                    # perf_counter is monotonic and meant for measuring durations.
                    t1 = time.perf_counter()
                    if valid_act:
                        val_act_hldr = copy.deepcopy(env.return_valid_act())
                        info2 = [{"valid_act": hldr} for hldr in val_act_hldr]

                    outputs = [model.forward(ob, ha, hc, dev=device,
                                             valid_act_heur=inf2["valid_act"],
                                             greedy=greedy)
                               for ob, ha, hc, inf2 in zip(obs, hx_cx_actr, hx_cx_cr, info2)]
                    _, a_select, _, hx_cx_actr_n, hx_cx_cr_n, _ = zip(*outputs)

                    a_env_dict = [{key: val.item() for key, val in hldr.items()} for hldr in a_select]
                    next_obs, _, _, info = env.step(a_env_dict)
                    total_ep_time += time.perf_counter() - t1

                    hx_cx_actr = [dict(hldr) for hldr in hx_cx_actr_n]
                    hx_cx_cr = [dict(hldr) for hldr in hx_cx_cr_n]

                    terminated = info[0]["terminate"]
                    if terminated:
                        results["agents_on_goal"].append(info[0]["agent_dones"])
                        results["all_done"].append(info[0]["all_agents_on_goal"])
                        results["episode_length"].append(t + 1)
                        results["obstacle_collisions"].append(info[0]["total_obstacle_collisions"])
                        results["agent_collisions"].append(info[0]["total_agent_collisions"])
                        results["execution_time"].append(total_ep_time)

                    if ep < render_len:
                        if terminated:
                            render_frames.append(info[0]["terminal_render"])
                        else:
                            render_frames.append(env.render(indices=[0])[0])

                    if terminated:
                        break
                    obs = copy.deepcopy(next_obs)
    finally:
        env.close()
        if was_training is not None:
            actor.train(was_training)

    return render_frames, results
def benchmark_func(args,
                   model,
                   num_episodes,
                   render_len,
                   device,
                   greedy=False):
    """Evaluate ``model`` in a freshly created single-worker environment.

    The recurrent hidden/cell states returned by ``model.forward`` are carried
    from one step to the next within an episode. Previously they were
    discarded, so a recurrent policy was evaluated with its initial memory at
    every step. The environment is always closed on exit.

    Args:
        args: configuration namespace passed to ``make_parallel_env``; only
            ``args.ppo_recurrent`` is read directly here.
        model: policy exposing ``prep_device``, ``init_hx_cx`` and ``forward``.
        num_episodes: number of episodes to run.
        render_len: frames are recorded only for episodes whose index is
            smaller than this value.
        device: device identifier forwarded to the model.
        greedy: forwarded to ``model.forward``.

    Returns:
        ``(render_frames, all_info)``: the recorded frames (the first one is
        taken right after the initial reset) and the list of per-step
        ``info`` lists returned by ``env.step``.
    """
    env = make_parallel_env(args, np.random.randint(0, 10000), 1)
    render_frames = []
    all_info = []
    try:
        # The environment is reset only once; each later episode starts from
        # the observation returned by the previous episode's terminating step
        # (presumably the env auto-resets on termination -- confirm).
        obs = env.reset()
        render_frames.append(env.render(indices=[0])[0])
        # Device placement does not change during the rollout, so do it once.
        model.prep_device(device)
        # Evaluation needs no gradients.
        with torch.no_grad():
            for ep in range(num_episodes):
                if args.ppo_recurrent:
                    hx_cx_actr = [{i: model.init_hx_cx(device)
                                   for i in ob.keys()} for ob in obs]
                    hx_cx_cr = [{i: model.init_hx_cx(device)
                                 for i in ob.keys()} for ob in obs]
                else:
                    hx_cx_actr = [{i: (None, None) for i in ob.keys()} for ob in obs]
                    hx_cx_cr = [{i: (None, None) for i in ob.keys()} for ob in obs]
                info2 = [{"valid_act": {i: None for i in ob.keys()}} for ob in obs]
                for _ in itertools.count():
                    outputs = [model.forward(ob, ha, hc, dev=device,
                                             valid_act_heur=inf2["valid_act"],
                                             greedy=greedy)
                               for ob, ha, hc, inf2 in zip(obs, hx_cx_actr, hx_cx_cr, info2)]
                    # Keep the updated recurrent states for the next step.
                    _, a_select, _, hx_cx_actr, hx_cx_cr, _ = zip(*outputs)
                    a_env_dict = [{key: val.item()
                                   for key, val in hldr.items()} for hldr in a_select]
                    next_obs, _, _, info = env.step(a_env_dict)
                    if ep < render_len:
                        if info[0]["terminate"]:
                            render_frames.append(info[0]["terminal_render"])
                        else:
                            render_frames.append(env.render(indices=[0])[0])
                    obs = next_obs
                    all_info.append(info)
                    if info[0]["terminate"]:
                        break
    finally:
        env.close()
    return render_frames, all_info
def run(args):
    """Train a PPO policy on a single (non-curriculum) environment configuration.

    Each iteration fills ``PPO_Buffer`` with rollouts from a parallel env
    created with ``buff.nworkers`` workers, runs one PPO update, and logs
    statistics. Whenever ``(it + 1) % args.benchmark_frequency == 0`` the
    policy is also evaluated with ``benchmark_func``, and one final benchmark
    runs after the last iteration.

    Args:
        args: parsed configuration namespace; the ``ppo_*``, ``render_*``,
            ``benchmark_*``, ``n_agents`` and ``seed`` fields are read below.
    """
    # NOTE(review): 'gpu' is passed verbatim to ppo.prep_device / init_hx_cx /
    # update; presumably those helpers translate it (torch itself expects
    # 'cuda') -- confirm.
    if args.ppo_use_gpu:
        device = 'gpu'
    else:
        device = 'cpu'
    buff = PPO_Buffer(args.n_agents, args.ppo_workers, args.ppo_rollout_length,
                      args.ppo_recurrent)
    env = make_parallel_env(args, args.seed, buff.nworkers)
    ppo = PPO(env.action_space[0].n, env.observation_space[0],
              args.ppo_base_policy_type, env.n_agents[0], args.ppo_share_actor,
              args.ppo_share_value, args.ppo_k_epochs, args.ppo_minibatch_size,
              args.ppo_lr_a, args.ppo_lr_v, args.ppo_hidden_dim,
              args.ppo_eps_clip, args.ppo_entropy_coeff, args.ppo_recurrent,
              args.ppo_heur_block)
    buff.init_model(ppo)
    logger = Logger(args, "No summary", "no policy summary", ppo)

    obs = env.reset()
    # One dict of (hidden, cell) states per worker, keyed by the agent ids in
    # that worker's observation; (None, None) placeholders for feed-forward policies.
    if args.ppo_recurrent:
        hx_cx_actr = [{i: ppo.init_hx_cx(device)
                       for i in ob.keys()} for ob in obs]
        hx_cx_cr = [{i: ppo.init_hx_cx(device)
                     for i in ob.keys()} for ob in obs]
    else:
        hx_cx_actr = [{i: (None, None) for i in ob.keys()} for ob in obs]
        hx_cx_cr = [{i: (None, None) for i in ob.keys()} for ob in obs]

    stats = {}
    for it in range(args.ppo_iterations):
        print("Iteration: {}".format(it))
        render_frames = []
        render_cntr = 0
        # Without the valid-action heuristic the model receives None per agent.
        info2 = [{"valid_act": {i: None for i in ob.keys()}} for ob in obs]
        # Collect rollouts until the buffer reports it is full.
        while buff.is_full == False:
            if args.ppo_heur_valid_act:
                val_act_hldr = copy.deepcopy(env.return_valid_act())
                info2 = [{"valid_act": hldr} for hldr in val_act_hldr]
            else:
                val_act_hldr = [{i: None for i in ob.keys()} for ob in obs]
            ppo.prep_device(device)
            # One forward pass per worker; zip(*...) regroups the per-worker
            # 6-tuples into one tuple per output.
            a_probs, a_select, value, hx_cx_actr_n, hx_cx_cr_n, blocking = zip(*[ppo.forward(ob,ha,hc, dev=device, valid_act_heur = inf2["valid_act"] ) \
                                    for ob,ha,hc, inf2 in zip(obs, hx_cx_actr, hx_cx_cr, info2)])
            a_env_dict = [{key: val.item()
                           for key, val in hldr.items()} for hldr in a_select]
            next_obs, r, dones, info = env.step(a_env_dict)
            # Record worker 0 on rendering iterations, for up to
            # args.render_length terminated episodes.
            if it % args.render_rate == 0:
                if render_cntr < args.render_length:
                    render_frames.append(env.render(indices=[0])[0])
                    if info[0]["terminate"]:
                        render_cntr += 1
            next_obs_ = get_n_obs(next_obs, info)
            buff.add(obs, r, value, next_obs_, a_probs, a_select, info, dones,
                     hx_cx_actr, hx_cx_cr, hx_cx_actr_n, hx_cx_cr_n, blocking,
                     val_act_hldr)
            hx_cx_actr, hx_cx_cr = hx_cx_actr_n, hx_cx_cr_n
            #Reset hidden and cell states when the episode is done
            for i, inf in enumerate(info):
                if inf["terminate"] and args.ppo_recurrent:
                    # zip(*...) produced tuples; convert so items can be replaced.
                    hx_cx_actr, hx_cx_cr = list(hx_cx_actr), list(hx_cx_cr)
                    hx_cx_actr[i] = {
                        i2: ppo.init_hx_cx(device)
                        for i2 in hx_cx_actr[i].keys()
                    }
                    hx_cx_cr[i] = {
                        i2: ppo.init_hx_cx(device)
                        for i2 in hx_cx_cr[i].keys()
                    }

            obs = next_obs

        # The collection loop only exits once the buffer is full, so this
        # branch always runs; `infos` used in the logging below is bound here.
        if buff.is_full:
            observations, a_prob, a_select, adv, v, infos, h_actr, h_cr, blk_labels, blk_pred, val_act  = buff.sample(args.ppo_discount, \
                args.ppo_gae_lambda, blocking=args.ppo_heur_block, use_valid_act = args.ppo_heur_valid_act)
            stats["action_loss"], stats["value_loss"] \
            = ppo.update(observations, a_prob, a_select, adv, v, 0, h_actr, h_cr, blk_labels, blk_pred,val_act, dev=device)

            # Detach the carried-over hidden states in place so later backward
            # passes presumably stop at this update boundary.
            if args.ppo_recurrent:
                for a, c in zip(hx_cx_actr, hx_cx_cr):
                    for a2, c2 in zip(a.values(), c.values()):
                        a2[0].detach_()
                        a2[1].detach_()
                        c2[0].detach_()
                        c2[1].detach_()
        if (it + 1) % args.benchmark_frequency == 0:
            rend_f, ter_i = benchmark_func(args, ppo,
                                           args.benchmark_num_episodes,
                                           args.benchmark_render_length,
                                           device)
            logger.benchmark_info(ter_i, rend_f, it)

        #Logging:
        stats["iterations"] = it
        stats["num_timesteps"] = len(infos)
        terminal_t_info = [inf for inf in infos if inf["terminate"]]
        stats["num_episodes"] = len(terminal_t_info)
        logger.log(stats, terminal_t_info, render_frames, checkpoint=True)
        render_frames = []
        render_cntr = 0

    # Final benchmark after training.
    # NOTE(review): `it` is unbound here when args.ppo_iterations == 0, which
    # raises NameError.
    rend_f, ter_i = benchmark_func(args, ppo, args.benchmark_num_episodes,
                                   args.benchmark_render_length, device)
    logger.benchmark_info(ter_i, rend_f, it)
示例#4
0
def run(args):  #Curriculum train:
    """Train a PPO policy over a curriculum of environments.

    ``CurriculumManager`` supplies environment configurations. For every
    sampled environment the buffer is re-initialised, the model's agent
    indexes are extended, a new parallel env is created, and
    ``curr_manager.n_updates`` updates are performed. Each update is either:

    - with probability ``args.ppo_bc_iteration_prob``, a behaviour-cloning
      step via ``bc_training_iteration2``; or
    - otherwise, a PPO rollout followed by an update.

    Args:
        args: parsed configuration namespace (the ``ppo_*`` fields are read).

    Raises:
        AssertionError: if ``args.ppo_bc_iteration_prob`` is outside [0, 1].
    """
    if args.ppo_use_gpu:
        device = 'gpu'
    else:
        device = 'cpu'

    assert args.ppo_bc_iteration_prob >= 0.0 and args.ppo_bc_iteration_prob <= 1.0

    curr_manager = CurriculumManager(args, CurriculumLogger)
    # Get env args for the first env.
    env_args = curr_manager.init_env_args()
    # Determine the number of workers for the buffer and init the env.
    buff = PPO_Buffer(env_args.n_agents, args.ppo_workers,
                      args.ppo_rollout_length, args.ppo_recurrent)
    seed = np.random.randint(0, 100000)
    env = make_parallel_env(env_args, seed, buff.nworkers)
    # Init the PPO model from the first env's action/observation spaces.
    ppo = PPO(env.action_space[0].n, env.observation_space[0],
              args.ppo_base_policy_type, env.n_agents[0], args.ppo_share_actor,
              args.ppo_share_value, args.ppo_k_epochs, args.ppo_minibatch_size,
              args.ppo_lr_a, args.ppo_lr_v, args.ppo_hidden_dim,
              args.ppo_eps_clip, args.ppo_entropy_coeff, args.ppo_recurrent,
              args.ppo_heur_block)
    # Add the model to the buffer.
    buff.init_model(ppo)

    logger = curr_manager.init_logger(ppo, benchmark_func)
    global_iterations = 0
    # The first env was only needed to build the model.
    env.close()
    while not curr_manager.is_done:
        env_args = curr_manager.sample_env()
        # Re-running __init__ recalculates nworkers for the new env.
        buff.__init__(env_args.n_agents, args.ppo_workers,
                      args.ppo_rollout_length,
                      args.ppo_recurrent)
        ppo.extend_agent_indexes(env_args.n_agents)
        buff.init_model(ppo)
        seed = np.random.randint(0, 10000)
        env = make_parallel_env(env_args, seed, buff.nworkers)

        obs = env.reset()
        if args.ppo_recurrent:
            hx_cx_actr = [{i: ppo.init_hx_cx(device)
                           for i in ob.keys()} for ob in obs]
            hx_cx_cr = [{i: ppo.init_hx_cx(device)
                         for i in ob.keys()} for ob in obs]
        else:
            hx_cx_actr = [{i: (None, None) for i in ob.keys()} for ob in obs]
            hx_cx_cr = [{i: (None, None) for i in ob.keys()} for ob in obs]

        extra_stats = {}
        env_id = curr_manager.curr_env_id
        up_i = 0
        while up_i < curr_manager.n_updates:
            bc_iteration = np.random.choice(
                [True, False],
                p=[args.ppo_bc_iteration_prob, 1 - args.ppo_bc_iteration_prob])
            if bc_iteration:
                n_sub_updates = (
                    (args.ppo_workers * args.ppo_rollout_length) //
                    args.ppo_minibatch_size)
                # NOTE(review): the BC minibatch size is hard-coded to 64.
                bc_training_iteration2(args, env_args, ppo, device, 64,
                                       n_sub_updates * 4)
            else:
                print("Iteration: {}".format(global_iterations))
                info2 = [{
                    "valid_act": {i: None
                                  for i in ob.keys()}
                } for ob in obs]
                while buff.is_full == False:
                    if args.ppo_heur_valid_act:
                        val_act_hldr = copy.deepcopy(env.return_valid_act())
                        info2 = [{"valid_act": hldr} for hldr in val_act_hldr]
                    else:
                        val_act_hldr = [{i: None
                                         for i in ob.keys()} for ob in obs]

                    # One forward pass per worker; zip(*...) regroups the
                    # per-worker 6-tuples into one tuple per output.
                    a_probs, a_select, value, hx_cx_actr_n, hx_cx_cr_n, blocking = zip(*[
                        ppo.forward(ob, ha, hc, dev=device,
                                    valid_act_heur=inf2["valid_act"])
                        for ob, ha, hc, inf2 in zip(obs, hx_cx_actr, hx_cx_cr, info2)
                    ])

                    a_env_dict = [{
                        key: val.item()
                        for key, val in hldr.items()
                    } for hldr in a_select]
                    next_obs, r, dones, info = env.step(a_env_dict)
                    logger.record_render(env_id, env, info[0])
                    next_obs_ = get_n_obs(next_obs, info)

                    buff.add(obs, r, value, next_obs_, a_probs, a_select, info,
                             dones, hx_cx_actr, hx_cx_cr, hx_cx_actr_n,
                             hx_cx_cr_n, blocking, val_act_hldr)

                    hx_cx_actr = [{k: v
                                   for k, v in hldr.items()}
                                  for hldr in hx_cx_actr_n]
                    hx_cx_cr = [{k: v
                                 for k, v in hldr.items()}
                                for hldr in hx_cx_cr_n]
                    # Reset hidden and cell states when an episode is done.
                    for i, inf in enumerate(info):
                        if inf["terminate"] and args.ppo_recurrent:
                            hx_cx_actr, hx_cx_cr = list(hx_cx_actr), list(
                                hx_cx_cr)
                            hx_cx_actr[i] = {
                                i2: ppo.init_hx_cx(device)
                                for i2 in hx_cx_actr[i].keys()
                            }
                            hx_cx_cr[i] = {
                                i2: ppo.init_hx_cx(device)
                                for i2 in hx_cx_cr[i].keys()
                            }
                    obs = next_obs

                if buff.is_full:
                    # Positional arguments are (discount, gae_lambda), as in the
                    # single-env trainer. This call used to pass ppo_gae_lambda
                    # twice, which either raised "multiple values for argument"
                    # or silently bound lambda to the sampler's next parameter.
                    (observations, a_prob, a_select, adv, v, infos, h_actr,
                     h_cr, blk_labels, blk_pred, val_act) = buff.sample(
                         args.ppo_discount,
                         args.ppo_gae_lambda,
                         blocking=args.ppo_heur_block,
                         use_valid_act=args.ppo_heur_valid_act)
                    extra_stats["action_loss"], extra_stats["value_loss"] = ppo.update(
                        observations, a_prob, a_select, adv, v, 0, h_actr, h_cr,
                        blk_labels, blk_pred, val_act, dev=device)
                    # Detach carried-over hidden states in place after the update.
                    if args.ppo_recurrent:
                        for a, c in zip(hx_cx_actr, hx_cx_cr):
                            for a2, c2 in zip(a.values(), c.values()):
                                a2[0].detach_()
                                a2[1].detach_()
                                c2[0].detach_()
                                c2[1].detach_()
                    global_iterations += 1
                    logger.log(env_id, infos, extra_stats)
                    # NOTE(review): renders are only released when the last
                    # update of this env is a PPO update, not a BC update.
                    if up_i == curr_manager.n_updates - 1:
                        logger.release_render(env_id)
            up_i += 1
        env.close()
    print("Done")
示例#5
0
def _bc_action_loss(logits, expert_actions):
    """Return the mean negative log-likelihood of the expert actions.

    Args:
        logits: unnormalised action scores, one row per sample (a softmax is
            taken over the last dimension).
        expert_actions: integer action labels shaped ``(N, 1)``.
    """
    # log_softmax is the numerically stable form of log(softmax(x)); the
    # latter returns -inf (and NaN gradients) once a probability underflows.
    log_probs = F.log_softmax(logits, dim=-1)
    return -torch.gather(log_probs, -1, expert_actions.long()).mean()


def _bc_blocking_loss(pred, label):
    """Return the mean binary cross-entropy of blocking probabilities ``pred``."""
    # Clamp both ends: with an upper bound of exactly 1.0 a saturated
    # prediction gives log(1 - pred) = -inf and (1 - label) * -inf = NaN.
    pred = torch.clamp(pred, 1e-15, 1.0 - 1e-7)
    loss = -(label * torch.log(pred) + (1 - label) * torch.log(1 - pred))
    return loss.mean()


def _bc_valid_act_loss(logits, valid_act_multi_hot):
    """Return the mean sigmoid cross-entropy between logits and a valid-action mask."""
    # Mathematically equal to -[y*log(sigmoid(x)) + (1-y)*log(1-sigmoid(x))],
    # but evaluated in log-space so saturated logits stay finite.
    return F.binary_cross_entropy_with_logits(logits, valid_act_multi_hot)


def _bc_gradient_step(args, ppo, device, num_samples, a_probs_steps,
                      expert_a_steps, blocking_pred_steps, is_blocking_steps,
                      valid_act_steps):
    """Take one optimiser step on the behaviour-cloning loss.

    Each ``*_steps`` argument is a list with one dict per environment step,
    keyed by agent id (the keys of ``a_probs_steps[0]`` are used for all of
    them). They are flattened in (step, agent) order and truncated to
    ``num_samples`` entries.
    """
    keys = a_probs_steps[0].keys()

    def flatten(per_step):
        return [step[k] for step in per_step for k in keys]

    # NOTE(review): the flattened lists hold (steps * agents) entries but are
    # truncated to num_samples, so with several agents only the first
    # num_samples agent-samples are used -- confirm this is intended.
    logits = torch.cat(flatten(a_probs_steps)[:num_samples])
    expert_a = torch.from_numpy(
        np.array(flatten(expert_a_steps)[:num_samples])).reshape(-1, 1)
    expert_a = ppo.tens_to_dev(device, expert_a)
    loss = _bc_action_loss(logits, expert_a)

    if args.ppo_heur_block:
        blocking_pred = torch.cat(flatten(blocking_pred_steps)[:num_samples])
        block_flags = flatten(is_blocking_steps)[:num_samples]
        # float() maps both bool and 0/1 int flags to proper labels; indexing
        # a tensor with a list of ints would instead treat them as positions.
        is_blocking = torch.tensor([float(b) for b in block_flags],
                                   dtype=torch.float32).reshape(-1, 1)
        is_blocking = ppo.tens_to_dev(device, is_blocking)
        loss = loss + 0.5 * _bc_blocking_loss(blocking_pred, is_blocking)

    if args.ppo_heur_valid_act:
        # Mask width follows the policy's action dimension.
        n_actions = logits.shape[-1]
        multi_hots = []
        for valid_actions in flatten(valid_act_steps):
            multi_hot = torch.zeros(size=(1, n_actions), dtype=torch.float32)
            for act in valid_actions:
                multi_hot[0][act] = 1
            multi_hots.append(multi_hot)
        valid_mask = ppo.tens_to_dev(device, torch.cat(multi_hots[:num_samples]))
        loss = loss + 0.5 * _bc_valid_act_loss(logits, valid_mask)

    optimizer = ppo.actors[0].optimizer
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


def bc_training_iteration2(args, env_args, ppo, device, minibatch_size,
                           n_sub_updates):
    """Behaviour-clone the first actor towards M* expert trajectories.

    A single-worker env is created. For each of ``n_sub_updates`` updates:

    - Expert joint actions are planned with ``graph.mstar_search4_OD`` from
      the agents' current positions to their goals.
    - The env is stepped with those actions while the policy's outputs are
      recorded.
    - Once ``minibatch_size`` steps have been collected, one gradient step is
      taken on the cloning loss.

    Every plan is followed to completion, so each new plan starts at an
    episode boundary. The env is always closed on exit.

    Args:
        args: configuration namespace (``ppo_recurrent``,
            ``ppo_heur_valid_act`` and ``ppo_heur_block`` are read).
        env_args: environment configuration passed to ``make_parallel_env``.
        ppo: model exposing ``forward``, ``init_hx_cx``, ``tens_to_dev`` and
            ``actors[0].optimizer``.
        device: device identifier forwarded to the model.
        minibatch_size: number of environment steps collected per update.
        n_sub_updates: number of gradient steps to take.

    Returns:
        None. Returns early, without finishing the remaining updates, if M*
        returns no plan.
    """
    def start_positions(env_hldr):
        '''Assumes agent keys in env.agents are the same as agent ids.'''
        return [env_hldr.agents[i].pos
                for i in range(len(env_hldr.agents.values()))]

    def goal_positions(env_hldr):
        '''Assumes goal keys in env.goals are the same as agent ids.'''
        return [env_hldr.goals[i].pos
                for i in range(len(env_hldr.goals.values()))]

    # Make a single env.
    env = make_parallel_env(env_args, np.random.randint(0, 1000000), 1)
    try:
        obs = env.reset()

        info2 = [{"valid_act": {i: None for i in ob.keys()}} for ob in obs]
        if args.ppo_recurrent:
            hx_cx_actr = [{i: ppo.init_hx_cx(device)
                           for i in ob.keys()} for ob in obs]
            hx_cx_cr = [{i: ppo.init_hx_cx(device)
                         for i in ob.keys()} for ob in obs]
        else:
            hx_cx_actr = [{i: (None, None) for i in ob.keys()} for ob in obs]
            hx_cx_cr = [{i: (None, None) for i in ob.keys()} for ob in obs]

        num_samples = minibatch_size
        for _ in range(n_sub_updates):
            buff_a_probs = []
            buff_expert_a = []
            buff_blocking_pred = []
            buff_is_blocking = []
            buff_valid_act = []
            info = None
            while len(buff_a_probs) < num_samples:
                if info is not None:
                    # The previous plan was followed to its end.
                    assert info[0]["terminate"]
                env_hldr = env.return_env()
                all_actions = env_hldr[0].graph.mstar_search4_OD(
                    start_positions(env_hldr[0]),
                    goal_positions(env_hldr[0]),
                    inflation=1.2)
                if all_actions is None:
                    print("Mstar returned no solution. Exiting behaviour cloning")
                    return

                for mstar_action in all_actions:
                    env_hldr2 = env.return_env()
                    if args.ppo_heur_valid_act:
                        buff_valid_act.append({
                            k: env_hldr2[0].graph.get_valid_actions(k)
                            for k in env_hldr2[0].agents.keys()
                        })

                    a_probs, _, _, hx_cx_actr_n, hx_cx_cr_n, blocking = zip(*[
                        ppo.forward(ob, ha, hc, dev=device,
                                    valid_act_heur=inf2["valid_act"])
                        for ob, ha, hc, inf2 in zip(obs, hx_cx_actr, hx_cx_cr, info2)
                    ])

                    buff_a_probs.append(a_probs[0])
                    buff_expert_a.append(mstar_action)
                    if args.ppo_heur_block:
                        buff_blocking_pred.append(blocking[0])
                        buff_is_blocking.append(
                            copy.deepcopy(env_hldr2[0].blocking_hldr))

                    # Step with the expert action, not the policy's choice.
                    next_obs, _, _, info = env.step([mstar_action])

                    hx_cx_actr = [dict(hldr) for hldr in hx_cx_actr_n]
                    hx_cx_cr = [dict(hldr) for hldr in hx_cx_cr_n]
                    # Reset hidden and cell states when an episode is done.
                    if args.ppo_recurrent:
                        for worker, inf in enumerate(info):
                            if inf["terminate"]:
                                hx_cx_actr[worker] = {
                                    agent: ppo.init_hx_cx(device)
                                    for agent in hx_cx_actr[worker].keys()
                                }
                                hx_cx_cr[worker] = {
                                    agent: ppo.init_hx_cx(device)
                                    for agent in hx_cx_cr[worker].keys()
                                }
                    obs = next_obs

                    if len(buff_a_probs) == num_samples:
                        _bc_gradient_step(args, ppo, device, num_samples,
                                          buff_a_probs, buff_expert_a,
                                          buff_blocking_pred, buff_is_blocking,
                                          buff_valid_act)
    finally:
        env.close()
示例#6
0
def bc_training_iteration2(args,
                           env_args,
                           ppo,
                           device,
                           minibatch_size,
                           n_sub_updates,
                           logger=None,
                           env_id=None,
                           bench_freq=None,
                           iteration=None):
    def make_start_postion_list(env_hldr):
        '''Assumes agent keys in evn.agents is the same as agent id's '''
        start_positions = []
        for i in range(len(env_hldr.agents.values())):
            start_positions.append(env_hldr.agents[i].pos)
        return start_positions

    def make_end_postion_list(env_hldr):
        '''Assumes agent keys in evn.agents is the same as agent id's '''
        end_positions = []
        for i in range(len(env_hldr.goals.values())):
            end_positions.append(env_hldr.goals[i].pos)
        return end_positions

    #make single env:
    env = make_parallel_env(env_args, np.random.randint(0, 1e6), 1)
    #for i in range(n_sub_updates):
    obs = env.reset()

    info2 = [{"valid_act": {i: None for i in ob.keys()}} for ob in obs]
    if args.ppo_recurrent:
        hx_cx_actr = [{i: ppo.init_hx_cx(device)
                       for i in ob.keys()} for ob in obs]
        hx_cx_cr = [{i: ppo.init_hx_cx(device)
                     for i in ob.keys()} for ob in obs]
    else:
        hx_cx_actr = [{i: (None, None) for i in ob.keys()} for ob in obs]
        hx_cx_cr = [{i: (None, None) for i in ob.keys()} for ob in obs]

    # Get start and end_positions
    # run mstar
    #

    num_samples = minibatch_size * n_sub_updates
    #for i in range(n_sub_updates):
    inflation = 1.2
    buff_a_probs = []
    buff_expert_a = []
    buff_blocking_pred = []
    buff_is_blocking = []
    buff_valid_act = []
    info = None
    #Collect samples:
    update_cntr = 0
    #while len(buff_a_probs) < num_samples:
    while update_cntr < n_sub_updates:
        if not info is None:
            assert info[0]["terminate"] == True
        env_hldr = env.return_env()
        #env_hldr[0].render(mode='human')
        start_pos = make_start_postion_list(env_hldr[0])
        end_pos = make_end_postion_list(env_hldr[0])

        for i in range(5):
            all_actions = env_hldr[0].graph.mstar_search4_OD(
                start_pos, end_pos, inflation=inflation, memory_limit=4 * 1e6)
            if all_actions is None:
                inflation += 1.5
                if i == 4:
                    print(
                        "Mstar ran out of memory. No solution found. Exiting behaviour cloning"
                    )
                    return
            else:
                break

        for i, mstar_action in enumerate(all_actions):
            env_hldr2 = env.return_env()
            if args.ppo_heur_valid_act:
                #val_act_hldr = copy.deepcopy(env.return_valid_act())
                #info2 = [{"valid_act":hldr} for hldr in val_act_hldr]
                buff_valid_act.append({
                    k: env_hldr2[0].graph.get_valid_actions(k)
                    for k in env_hldr2[0].agents.keys()
                })

            a_probs, a_select, value, hx_cx_actr_n, hx_cx_cr_n, blocking = zip(*[ppo.forward(ob,ha,hc, dev=device, valid_act_heur = inf2["valid_act"] ) \
                                                for ob,ha,hc, inf2 in zip(obs, hx_cx_actr, hx_cx_cr, info2)])

            #env_hldr2[0].render(mode='human')
            buff_a_probs.append(a_probs[0])
            buff_expert_a.append(mstar_action)
            if args.ppo_heur_block:
                buff_blocking_pred.append(blocking[0])
                #env_hldr2 = env.return_env()
                buff_is_blocking.append(
                    copy.deepcopy(env_hldr2[0].blocking_hldr))

            next_obs, r, dones, info = env.step([mstar_action])

            next_obs_ = get_n_obs(next_obs, info)

            hx_cx_actr = [{k: v
                           for k, v in hldr.items()} for hldr in hx_cx_actr_n]
            hx_cx_cr = [{k: v for k, v in hldr.items()} for hldr in hx_cx_cr_n]
            for i, inf in enumerate(info):
                if inf["terminate"] and args.ppo_recurrent:
                    hx_cx_actr, hx_cx_cr = list(hx_cx_actr), list(hx_cx_cr)
                    hx_cx_actr[i] = {
                        i2: ppo.init_hx_cx(device)
                        for i2 in hx_cx_actr[i].keys()
                    }
                    hx_cx_cr[i] = {
                        i2: ppo.init_hx_cx(device)
                        for i2 in hx_cx_cr[i].keys()
                    }
            obs = next_obs

            # if len(buff_a_probs) == num_samples:
            #     break
            if len(buff_a_probs) == minibatch_size:
                # Update:
                update_cntr += 1
                keys = buff_a_probs[0].keys()
                buff_a_probs_flat = []
                for ap in buff_a_probs:
                    for k in keys:
                        buff_a_probs_flat.append(ap[k])

                buff_expert_a_flat = []
                for hldr in buff_expert_a:
                    for k in keys:
                        buff_expert_a_flat.append(hldr[k])

                if args.ppo_heur_block:
                    buff_blocking_pred_flat = []
                    for hldr in buff_blocking_pred:
                        for k in keys:
                            buff_blocking_pred_flat.append(hldr[k])

                    buff_is_blocking_flat = []
                    for hldr in buff_is_blocking:
                        for k in keys:
                            buff_is_blocking_flat.append(hldr[k])
                if args.ppo_heur_valid_act:
                    buff_valid_act_flat = []
                    for time_step in buff_valid_act:
                        for k in keys:
                            hldr = time_step[k]
                            multi_hot = torch.zeros(size=(1, 5),
                                                    dtype=torch.float32)
                            for i in hldr:
                                multi_hot[0][i] = 1
                            buff_valid_act_flat.append(multi_hot)

                #Discard excess samples:
                #all_data = []
                a_probs = torch.cat(buff_a_probs_flat[:num_samples])
                expert_a = torch.from_numpy(
                    np.array(buff_expert_a_flat[:num_samples])).reshape(-1, 1)
                expert_a = ppo.tens_to_dev(device, expert_a)
                #all_data.append(a_probs)
                #all_data.append(expert_a)

                if args.ppo_heur_valid_act:
                    buff_valid_act_flat = torch.cat(
                        buff_valid_act_flat[:num_samples])
                    buff_valid_act_flat = ppo.tens_to_dev(
                        device, buff_valid_act_flat)
                    #print("Buff valid act: {} \n {}".format(buff_valid_act, buff_valid_act_flat))

                if args.ppo_heur_block:
                    blocking_pred = torch.cat(
                        buff_blocking_pred_flat[:num_samples])
                    #all_data.append(blocking_pred)
                    #make tensor of ones and zeros
                    block_bool = buff_is_blocking_flat[:num_samples]
                    is_blocking = torch.zeros(len(block_bool))
                    is_blocking[block_bool] = 1
                    #is_blocking = torch.cat(is_blocking)
                    is_blocking = is_blocking.reshape(-1, 1)
                    is_blocking = ppo.tens_to_dev(device, is_blocking)
                    #all_data.append(is_blocking)

                #train_dataset = torch.utils.data.TensorDataset(*all_data)
                #train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = minibatch_size, shuffle=True)

                def loss_f_actions(pred, label):
                    action_label_prob = torch.gather(F.softmax(pred), -1,
                                                     label.long())
                    log_actions = -torch.log(action_label_prob)
                    #Should sum before taking mean...
                    #This is equivalent to scaling the correct loss by 0.2.
                    loss = log_actions.mean()
                    return loss

                def loss_f_blocking(pred, label):
                    #pred_action_prob = torch.gather(pred,-1, label.long())
                    #categorical loss:
                    #test_categorical_loss = CrossEntropyLoss(pred, label.detach())
                    pred = torch.clamp(pred, 1e-15, 1.0)
                    loss = -(label * torch.log(pred) +
                             (1 - label) * torch.log(1 - pred))
                    #loss = -(label*torch.log(pred_action_prob) + (1-label)*torch.log(1-pred_action_prob))
                    #log_actions = -torch.log(action_label_prob)
                    #loss = log_actions.mean()
                    return loss.mean()

                def loss_f_valid(all_act_prob, valid_act_multi_hot):
                    sigmoid_act = F.sigmoid(all_act_prob)

                    #hldr1 = 1-sigmoid_act
                    #hldr2 = 1-valid_act_multi_hot
                    #hldr3 = torch.log(sigmoid_act)

                    valid_act_loss = -(
                        torch.log(sigmoid_act) * valid_act_multi_hot +
                        torch.log(1 - sigmoid_act) * (1 - valid_act_multi_hot))
                    return valid_act_loss.mean()

                #indexes = np.arange(0,num_samples)
                #np.random.shuffle(indexes)
                #mb_start = 0
                #for i in range(minibatch_size, num_samples, minibatch_size):
                #ind = indexes[mb_start:i]
                #mb_start = i

                a_pred = a_probs  #[ind] #data[0]
                a = expert_a  #[ind] #data[1]
                action_loss = loss_f_actions(a_pred, a)
                loss2 = action_loss
                block_loss, vld_loss = None, None

                if args.ppo_heur_block:
                    block_loss = 0.5 * loss_f_blocking(blocking_pred,
                                                       is_blocking)
                    loss2 += block_loss

                if args.ppo_heur_valid_act:
                    vld_loss = 0.5 * loss_f_valid(a_pred, buff_valid_act_flat)
                    loss2 += vld_loss

                ppo.actors[0].optimizer.zero_grad()
                loss2.backward(retain_graph=True)
                torch.nn.utils.clip_grad_norm_(ppo.actors[0].parameters(),
                                               1000)
                ppo.actors[0].optimizer.step()
                buff_a_probs = []
                buff_expert_a = []
                buff_blocking_pred = []
                buff_is_blocking = []
                buff_valid_act = []
                if logger is not None:
                    hldr = dict()
                    hldr["bc_action_loss"] = action_loss.item()
                    if block_loss is not None:
                        hldr["bc_block_loss"] = block_loss.item()
                    if vld_loss is not None:
                        hldr["bc_valid_loss"] = vld_loss.item()
                    # if block_loss is not None and vld_loss is not None:
                    #     hldr = {"bc_action_loss": action_loss.item(),
                    #             "bc_block_loss": block_loss.item(),
                    #             "bc_valid_loss": vld_loss.item()}
                    logger.log_bc(hldr)
# Example #7
# 0
def bc_training_iteration(args, env_args, ppo, device, minibatch_size,
                          n_sub_updates):
    """Run ``n_sub_updates`` behaviour-cloning updates of ``ppo.actors[0]``.

    For each update, M* expert plans are executed in a single environment
    until ``minibatch_size`` time steps have been collected. The actor's
    action outputs along those trajectories are fitted to the expert actions
    with a negative log-likelihood loss. When ``args.ppo_heur_block`` is set,
    a binary cross-entropy loss on the blocking prediction is added. One
    optimizer step is then taken.

    Args:
        args: Options object; ``ppo_recurrent`` and ``ppo_heur_block`` are read.
        env_args: Forwarded to ``make_parallel_env``.
        ppo: Model providing ``forward``, ``init_hx_cx``, ``tens_to_dev`` and
            ``actors[0].optimizer``.
        device: Device passed to ``ppo.forward`` / ``ppo.tens_to_dev``.
        minibatch_size: Number of environment steps collected per update.
            After flattening over agents, only the first ``minibatch_size``
            samples are used.
        n_sub_updates: Number of collect-then-update rounds.

    Raises:
        RuntimeError: If the M* planner returns an empty plan. The collection
            loop could otherwise never make progress.
    """
    def make_start_postion_list(env_hldr):
        '''Return each agent's ``pos``, ordered by agent id.

        Assumes the keys of ``env_hldr.agents`` are the ids ``0..n-1``.'''
        return [env_hldr.agents[i].pos
                for i in range(len(env_hldr.agents.values()))]

    def make_end_postion_list(env_hldr):
        '''Return each goal's ``pos``, ordered by agent id.

        Assumes the keys of ``env_hldr.goals`` are the ids ``0..n-1``.'''
        return [env_hldr.goals[i].pos
                for i in range(len(env_hldr.goals.values()))]

    def flatten_per_agent(buffer, keys):
        '''Flatten per-step ``{agent_id: value}`` dicts into one list.

        The result is step-major, with agents in ``keys`` order.'''
        return [step[k] for step in buffer for k in keys]

    def loss_f_actions(pred, label):
        '''Mean negative log-probability assigned to the expert action.

        NOTE(review): ``pred`` is treated as probabilities, not logits. Confirm
        this against ``ppo.forward``.'''
        action_label_prob = torch.gather(pred, -1, label.long())
        return (-torch.log(action_label_prob)).mean()

    def loss_f_blocking(pred, label, eps=1e-7):
        '''Mean binary cross-entropy between a blocking probability and a 0/1 label.

        ``pred`` is clamped to ``[eps, 1 - eps]`` so that neither log term can
        be ``-inf``.'''
        pred = torch.clamp(pred, eps, 1.0 - eps)
        label = label.reshape(pred.shape)
        loss = -(label * torch.log(pred) + (1 - label) * torch.log(1 - pred))
        return loss.mean()

    # Single environment, seeded randomly on each call.
    env = make_parallel_env(env_args, np.random.randint(0, int(1e6)), 1)
    for _ in range(n_sub_updates):
        obs = env.reset()

        info2 = [{"valid_act": {i: None for i in ob.keys()}} for ob in obs]
        if args.ppo_recurrent:
            hx_cx_actr = [{i: ppo.init_hx_cx(device)
                           for i in ob.keys()} for ob in obs]
            hx_cx_cr = [{i: ppo.init_hx_cx(device)
                         for i in ob.keys()} for ob in obs]
        else:
            hx_cx_actr = [{i: (None, None) for i in ob.keys()} for ob in obs]
            hx_cx_cr = [{i: (None, None) for i in ob.keys()} for ob in obs]

        buff_a_probs = []
        buff_expert_a = []
        buff_blocking_pred = []
        buff_is_blocking = []
        num_samples = minibatch_size
        info = None
        while len(buff_a_probs) < num_samples:
            # Each plan is followed until the episode terminates (or the buffer
            # fills, which also ends this loop). A new plan is therefore only
            # computed after a terminated episode.
            if info is not None:
                assert info[0]["terminate"]
            env_hldr = env.return_env()
            start_pos = make_start_postion_list(env_hldr[0])
            end_pos = make_end_postion_list(env_hldr[0])

            all_actions = env_hldr[0].graph.mstar_search4_OD(start_pos,
                                                             end_pos,
                                                             inflation=1.2)

            steps_taken = 0
            for mstar_action in all_actions:
                steps_taken += 1
                a_probs, a_select, value, hx_cx_actr_n, hx_cx_cr_n, blocking = zip(*[
                    ppo.forward(ob, ha, hc, dev=device,
                                valid_act_heur=inf2["valid_act"])
                    for ob, ha, hc, inf2 in zip(obs, hx_cx_actr, hx_cx_cr, info2)])
                # Blocking flags are read before stepping, i.e. for the state
                # the actor has just been evaluated on.
                env_hldr2 = env.return_env()
                buff_a_probs.append(a_probs[0])
                buff_expert_a.append(mstar_action)
                if args.ppo_heur_block:
                    buff_blocking_pred.append(blocking[0])
                    buff_is_blocking.append(
                        copy.deepcopy(env_hldr2[0].blocking_hldr))

                # The expert action, not the actor's sample, drives the env.
                next_obs, r, dones, info = env.step([mstar_action])

                # Result currently unused; the call is kept as in the original.
                next_obs_ = get_n_obs(next_obs, info)

                # Carry the hidden/cell states returned by forward into the
                # next step.
                hx_cx_actr = [dict(hldr) for hldr in hx_cx_actr_n]
                hx_cx_cr = [dict(hldr) for hldr in hx_cx_cr_n]
                # Reset hidden and cell states when an episode is done.
                for env_idx, inf in enumerate(info):
                    if inf["terminate"] and args.ppo_recurrent:
                        hx_cx_actr[env_idx] = {
                            agent: ppo.init_hx_cx(device)
                            for agent in hx_cx_actr[env_idx].keys()
                        }
                        hx_cx_cr[env_idx] = {
                            agent: ppo.init_hx_cx(device)
                            for agent in hx_cx_cr[env_idx].keys()
                        }
                obs = next_obs

                if len(buff_a_probs) == num_samples:
                    break
                # If the episode ends before the plan does, the remaining
                # expert actions belong to a finished episode. Stop and replan
                # from the post-termination env state.
                if info[0]["terminate"]:
                    break

            if steps_taken == 0:
                raise RuntimeError(
                    "M* returned an empty plan; cannot collect expert samples")

        # Train the actor on the expert data.
        keys = buff_a_probs[0].keys()

        # Flatten step-major / agent-minor, then discard excess samples.
        a_probs = torch.cat(flatten_per_agent(buff_a_probs, keys)[:num_samples])
        expert_a = torch.from_numpy(np.array(
            flatten_per_agent(buff_expert_a, keys)[:num_samples])).reshape(-1, 1)
        expert_a = ppo.tens_to_dev(device, expert_a)

        loss2 = loss_f_actions(a_probs, expert_a)
        if args.ppo_heur_block:
            blocking_pred = torch.cat(
                flatten_per_agent(buff_blocking_pred, keys)[:num_samples])
            # The blocking flags are plain booleans, not tensors, so build a
            # 0/1 float column instead of concatenating them.
            block_flags = flatten_per_agent(buff_is_blocking, keys)[:num_samples]
            is_blocking = torch.tensor([float(b) for b in block_flags],
                                       dtype=torch.float32).reshape(-1, 1)
            is_blocking = ppo.tens_to_dev(device, is_blocking)
            loss2 = loss2 + loss_f_blocking(blocking_pred, is_blocking)

        ppo.actors[0].optimizer.zero_grad()
        loss2.backward()
        ppo.actors[0].optimizer.step()
    del env