def run(args):
    """Train a PPO policy on parallel environments, benchmarking periodically.

    Each of ``args.ppo_iterations`` iterations steps the parallel
    environments until the rollout buffer reports it is full, performs one
    PPO update on the sampled batch and logs the resulting statistics.
    ``benchmark_func`` is run every ``args.benchmark_frequency`` iterations
    and once more after the last iteration. The parallel environment is
    closed when the function exits, whether normally or through an exception.

    Args:
        args: Run configuration object. The attributes read here are the
            ``ppo_*`` hyper-parameters, ``n_agents``, ``seed``,
            ``render_rate``, ``render_length`` and the ``benchmark_*``
            settings.
    """
    device = 'gpu' if args.ppo_use_gpu else 'cpu'
    buff = PPO_Buffer(args.n_agents, args.ppo_workers, args.ppo_rollout_length,
                      args.ppo_recurrent)
    env = make_parallel_env(args, args.seed, buff.nworkers)
    try:
        ppo = PPO(env.action_space[0].n, env.observation_space[0],
                  args.ppo_base_policy_type, env.n_agents[0],
                  args.ppo_share_actor, args.ppo_share_value,
                  args.ppo_k_epochs, args.ppo_minibatch_size, args.ppo_lr_a,
                  args.ppo_lr_v, args.ppo_hidden_dim, args.ppo_eps_clip,
                  args.ppo_entropy_coeff, args.ppo_recurrent,
                  args.ppo_heur_block)
        buff.init_model(ppo)
        logger = Logger(args, "No summary", "no policy summary", ppo)

        def fresh_states(agent_ids):
            # One (hx, cx) entry per agent; non-recurrent policies carry
            # (None, None) placeholders instead of real states.
            if args.ppo_recurrent:
                return {i: ppo.init_hx_cx(device) for i in agent_ids}
            return {i: (None, None) for i in agent_ids}

        obs = env.reset()
        # One dict of per-agent states per parallel worker, for actor and critic.
        hx_cx_actr = [fresh_states(ob.keys()) for ob in obs]
        hx_cx_cr = [fresh_states(ob.keys()) for ob in obs]

        stats = {}
        # Pre-bind so the final benchmark below has an iteration index even
        # when ppo_iterations is 0 and the loop body never runs.
        it = 0
        for it in range(args.ppo_iterations):
            print("Iteration: {}".format(it))
            render_frames = []
            render_cntr = 0
            info2 = [{"valid_act": {i: None for i in ob.keys()}} for ob in obs]

            # Collect transitions until the buffer holds a full rollout.
            while not buff.is_full:
                if args.ppo_heur_valid_act:
                    val_act_hldr = copy.deepcopy(env.return_valid_act())
                    info2 = [{"valid_act": hldr} for hldr in val_act_hldr]
                else:
                    val_act_hldr = [{i: None for i in ob.keys()} for ob in obs]
                ppo.prep_device(device)

                # One forward pass per worker; zip(*...) regroups the
                # per-worker result tuples into per-quantity tuples.
                outputs = [
                    ppo.forward(ob, ha, hc, dev=device,
                                valid_act_heur=inf2["valid_act"])
                    for ob, ha, hc, inf2 in zip(obs, hx_cx_actr, hx_cx_cr,
                                                info2)
                ]
                (a_probs, a_select, value, hx_cx_actr_n, hx_cx_cr_n,
                 blocking) = zip(*outputs)

                a_env_dict = [{key: val.item()
                               for key, val in hldr.items()}
                              for hldr in a_select]
                next_obs, r, dones, info = env.step(a_env_dict)

                # Record frames of worker 0 on render iterations, counting
                # episodes via its "terminate" flag.
                if (it % args.render_rate == 0
                        and render_cntr < args.render_length):
                    render_frames.append(env.render(indices=[0])[0])
                    if info[0]["terminate"]:
                        render_cntr += 1

                next_obs_ = get_n_obs(next_obs, info)
                buff.add(obs, r, value, next_obs_, a_probs, a_select, info,
                         dones, hx_cx_actr, hx_cx_cr, hx_cx_actr_n,
                         hx_cx_cr_n, blocking, val_act_hldr)

                # Carry the new states forward as lists so that individual
                # workers can be reset without touching the tuples that were
                # just handed to the buffer.
                hx_cx_actr, hx_cx_cr = list(hx_cx_actr_n), list(hx_cx_cr_n)
                # Reset hidden and cell states of workers whose episode ended.
                for i, inf in enumerate(info):
                    if inf["terminate"] and args.ppo_recurrent:
                        hx_cx_actr[i] = fresh_states(hx_cx_actr[i].keys())
                        hx_cx_cr[i] = fresh_states(hx_cx_cr[i].keys())

                obs = next_obs

            # The loop above only exits once the buffer is full, so a batch
            # is always available here.
            (observations, a_prob, a_select, adv, v, infos, h_actr, h_cr,
             blk_labels, blk_pred, val_act) = buff.sample(
                 args.ppo_discount, args.ppo_gae_lambda,
                 blocking=args.ppo_heur_block,
                 use_valid_act=args.ppo_heur_valid_act)
            stats["action_loss"], stats["value_loss"] = ppo.update(
                observations, a_prob, a_select, adv, v, 0, h_actr, h_cr,
                blk_labels, blk_pred, val_act, dev=device)

            if args.ppo_recurrent:
                # Detach the carried states in place before they are reused
                # in the next rollout (assumes init_hx_cx/forward return
                # torch tensors).
                for a, c in zip(hx_cx_actr, hx_cx_cr):
                    for a2, c2 in zip(a.values(), c.values()):
                        a2[0].detach_()
                        a2[1].detach_()
                        c2[0].detach_()
                        c2[1].detach_()

            if (it + 1) % args.benchmark_frequency == 0:
                rend_f, ter_i = benchmark_func(args, ppo,
                                               args.benchmark_num_episodes,
                                               args.benchmark_render_length,
                                               device)
                logger.benchmark_info(ter_i, rend_f, it)

            # Logging:
            stats["iterations"] = it
            stats["num_timesteps"] = len(infos)
            terminal_t_info = [inf for inf in infos if inf["terminate"]]
            stats["num_episodes"] = len(terminal_t_info)
            logger.log(stats, terminal_t_info, render_frames, checkpoint=True)

        # Final benchmark after the last training iteration.
        rend_f, ter_i = benchmark_func(args, ppo, args.benchmark_num_episodes,
                                       args.benchmark_render_length, device)
        logger.benchmark_info(ter_i, rend_f, it)
    finally:
        # Release the parallel environment even if training raised.
        env.close()
# Example #2
# 0
def run(args):
    """Train a PPO policy over a curriculum of environments.

    A probe environment built from the curriculum manager's initial
    environment arguments is used only to size the PPO model and is closed
    straight away. Then, until ``curr_manager.is_done``, an environment is
    sampled from the curriculum and ``curr_manager.n_updates`` updates are
    run on it. Each update is, with probability
    ``args.ppo_bc_iteration_prob``, a call to ``bc_training_iteration2``;
    otherwise it is a PPO rollout followed by ``ppo.update``. Every
    environment is closed when its stage ends, including on error.

    Args:
        args: Run configuration object. The attributes read here are the
            ``ppo_*`` hyper-parameters (including ``ppo_bc_iteration_prob``).

    Raises:
        ValueError: If ``args.ppo_bc_iteration_prob`` is outside ``[0, 1]``.
    """
    device = 'gpu' if args.ppo_use_gpu else 'cpu'

    # Explicit check rather than assert so it still runs under ``python -O``.
    if not 0.0 <= args.ppo_bc_iteration_prob <= 1.0:
        raise ValueError(
            "ppo_bc_iteration_prob must be in [0, 1], got {}".format(
                args.ppo_bc_iteration_prob))

    curr_manager = CurriculumManager(args, CurriculumLogger)
    # Get env args for the first env.
    env_args = curr_manager.init_env_args()
    # Determine number of workers for the buffer and init the probe env.
    buff = PPO_Buffer(env_args.n_agents, args.ppo_workers,
                      args.ppo_rollout_length, args.ppo_recurrent)
    seed = np.random.randint(0, 100000)
    env = make_parallel_env(env_args, seed, buff.nworkers)
    try:
        # Init ppo model from the probe env's spaces.
        ppo = PPO(env.action_space[0].n, env.observation_space[0],
                  args.ppo_base_policy_type, env.n_agents[0],
                  args.ppo_share_actor, args.ppo_share_value,
                  args.ppo_k_epochs, args.ppo_minibatch_size, args.ppo_lr_a,
                  args.ppo_lr_v, args.ppo_hidden_dim, args.ppo_eps_clip,
                  args.ppo_entropy_coeff, args.ppo_recurrent,
                  args.ppo_heur_block)
        # Add model to buffer.
        buff.init_model(ppo)
        logger = curr_manager.init_logger(ppo, benchmark_func)
    finally:
        # The probe env is not stepped; always release it.
        env.close()

    def fresh_states(agent_ids):
        # One (hx, cx) entry per agent; non-recurrent policies carry
        # (None, None) placeholders instead of real states.
        if args.ppo_recurrent:
            return {i: ppo.init_hx_cx(device) for i in agent_ids}
        return {i: (None, None) for i in agent_ids}

    global_iterations = 0
    while not curr_manager.is_done:
        env_args = curr_manager.sample_env()
        # Re-run the constructor on the existing buffer object; this
        # recalculates nworkers for the sampled environment.
        buff.__init__(env_args.n_agents, args.ppo_workers,
                      args.ppo_rollout_length, args.ppo_recurrent)
        ppo.extend_agent_indexes(env_args.n_agents)
        buff.init_model(ppo)
        seed = np.random.randint(0, 10000)
        env = make_parallel_env(env_args, seed, buff.nworkers)
        try:
            obs = env.reset()
            # One dict of per-agent states per parallel worker.
            hx_cx_actr = [fresh_states(ob.keys()) for ob in obs]
            hx_cx_cr = [fresh_states(ob.keys()) for ob in obs]

            extra_stats = {}
            env_id = curr_manager.curr_env_id
            up_i = 0
            # n_updates is re-read from the manager on every pass.
            while up_i < curr_manager.n_updates:
                bc_iteration = np.random.choice(
                    [True, False],
                    p=[args.ppo_bc_iteration_prob,
                       1 - args.ppo_bc_iteration_prob])
                if bc_iteration:
                    n_sub_updates = (
                        (args.ppo_workers * args.ppo_rollout_length) //
                        args.ppo_minibatch_size)
                    # NOTE(review): 64 and the factor 4 are hard-coded here;
                    # their meaning is defined by bc_training_iteration2.
                    bc_training_iteration2(args, env_args, ppo, device, 64,
                                           n_sub_updates * 4)
                    up_i += 1
                    continue

                print("Iteration: {}".format(global_iterations))
                info2 = [{"valid_act": {i: None for i in ob.keys()}}
                         for ob in obs]

                # Collect transitions until the buffer holds a full rollout.
                while not buff.is_full:
                    if args.ppo_heur_valid_act:
                        val_act_hldr = copy.deepcopy(env.return_valid_act())
                        info2 = [{"valid_act": hldr} for hldr in val_act_hldr]
                    else:
                        val_act_hldr = [{i: None for i in ob.keys()}
                                        for ob in obs]

                    # One forward pass per worker; zip(*...) regroups the
                    # per-worker result tuples into per-quantity tuples.
                    outputs = [
                        ppo.forward(ob, ha, hc, dev=device,
                                    valid_act_heur=inf2["valid_act"])
                        for ob, ha, hc, inf2 in zip(obs, hx_cx_actr,
                                                    hx_cx_cr, info2)
                    ]
                    (a_probs, a_select, value, hx_cx_actr_n, hx_cx_cr_n,
                     blocking) = zip(*outputs)

                    a_env_dict = [{key: val.item()
                                   for key, val in hldr.items()}
                                  for hldr in a_select]
                    next_obs, r, dones, info = env.step(a_env_dict)
                    logger.record_render(env_id, env, info[0])
                    next_obs_ = get_n_obs(next_obs, info)

                    buff.add(obs, r, value, next_obs_, a_probs, a_select,
                             info, dones, hx_cx_actr, hx_cx_cr, hx_cx_actr_n,
                             hx_cx_cr_n, blocking, val_act_hldr)

                    # Shallow-copy the per-worker dicts so resets below do
                    # not alias the objects just handed to the buffer.
                    hx_cx_actr = [dict(hldr) for hldr in hx_cx_actr_n]
                    hx_cx_cr = [dict(hldr) for hldr in hx_cx_cr_n]
                    # Reset hidden and cell states when an episode is done.
                    for i, inf in enumerate(info):
                        if inf["terminate"] and args.ppo_recurrent:
                            hx_cx_actr[i] = fresh_states(hx_cx_actr[i].keys())
                            hx_cx_cr[i] = fresh_states(hx_cx_cr[i].keys())
                    obs = next_obs

                # The loop above only exits once the buffer is full, so a
                # batch is always available here.
                # NOTE(review): gae_lambda is passed twice positionally;
                # confirm against PPO_Buffer.sample's signature that the
                # third positional argument is intended.
                (observations, a_prob, a_select, adv, v, infos, h_actr, h_cr,
                 blk_labels, blk_pred, val_act) = buff.sample(
                     args.ppo_discount, args.ppo_gae_lambda,
                     args.ppo_gae_lambda, blocking=args.ppo_heur_block,
                     use_valid_act=args.ppo_heur_valid_act)
                extra_stats["action_loss"], extra_stats["value_loss"] = \
                    ppo.update(observations, a_prob, a_select, adv, v, 0,
                               h_actr, h_cr, blk_labels, blk_pred, val_act,
                               dev=device)

                if args.ppo_recurrent:
                    # Detach the carried states in place before they are
                    # reused in the next rollout (assumes torch tensors).
                    for a, c in zip(hx_cx_actr, hx_cx_cr):
                        for a2, c2 in zip(a.values(), c.values()):
                            a2[0].detach_()
                            a2[1].detach_()
                            c2[0].detach_()
                            c2[1].detach_()

                global_iterations += 1
                logger.log(env_id, infos, extra_stats)
                # Only reached when the final update of the stage is a PPO
                # update rather than a BC iteration.
                if up_i == curr_manager.n_updates - 1:
                    logger.release_render(env_id)
                up_i += 1
        finally:
            # Release this stage's environment even if training raised.
            env.close()
    print("Done")