Example #1
def test_seg_gen(sequence_size=1000,
                 attention_size=30,
                 hidden_size=30,
                 env_id='Hopper-v1',
                 cell_type='lstm'):
    from gailtf.baselines.ppo1 import mlp_policy
    from gailtf.network.adversary_traj import TrajectoryClassifier
    import gym
    env = gym.make("Hopper-v1")

    def policy_fn(name, ob_space, ac_space, reuse=False):
        return mlp_policy.MlpPolicy(name=name,
                                    ob_space=ob_space,
                                    ac_space=ac_space,
                                    reuse=reuse,
                                    hid_size=64,
                                    num_hid_layers=2)

    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_fn('pi', ob_space, ac_space)
    discriminator = TrajectoryClassifier(env, hidden_size, sequence_size,
                                         attention_size, cell_type)
    U.make_session(num_cpu=2).__enter__()
    U.initialize()
    seg_gen = traj_segment_generator(pi, env, discriminator, 10, True,
                                     sequence_size)
    for i in range(10):
        seg = seg_gen.__next__()
        ob, ac = traj2trans(seg["ep_trajs"], seg["ep_lens"], ob_space.shape[0])
        add_vtarg_and_adv(seg, gamma=0.995, lam=0.97)
        print(seg['adv'].shape, seg['tdlamret'].shape, seg['ob'].shape,
              seg['nextvpred'])
Example #2
def train(args):
    from gailtf.baselines.ppo1 import mlp_policy, pposgd_simple
    U.make_session(num_cpu=args.num_cpu).__enter__()
    set_global_seeds(args.seed)
    env = gym.make(args.env_id)
    def policy_fn(name, ob_space, ac_space):
        return mlp_policy.MlpPolicy(name=name, ob_space=ob_space, ac_space=ac_space,
            hid_size=64, num_hid_layers=2)
    env = wrappers.Monitor(env, './video', force=True)
    #env = bench.Monitor(env, logger.get_dir() and
        #osp.join(logger.get_dir(), "monitor.json"))
    env.seed(args.seed)
    gym.logger.setLevel(logging.WARN)
    task_name = "ppo." + args.env_id.split("-")[0] + "." + ("%.2f"%args.entcoeff)
    args.checkpoint_dir = osp.join(args.checkpoint_dir, task_name)
    pposgd_simple.learn(env, policy_fn, 
            max_timesteps=args.num_timesteps,
            timesteps_per_batch=2048,
            clip_param=0.2, entcoeff=args.entcoeff,
            optim_epochs=10, optim_stepsize=3e-4, optim_batchsize=64,
            gamma=0.99, lam=0.95, schedule='linear', ckpt_dir=args.checkpoint_dir,
            save_per_iter=args.save_per_iter, task=args.task,
            sample_stochastic=args.sample_stochastic,
            load_model_path=args.load_model_path,
            task_name=task_name
        )
    env.close()
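For context, `train` expects an argparse-style namespace. A minimal sketch of how it might be invoked follows; the flag names and defaults below are inferred from the attributes `train` reads and are assumptions, not the repository's actual CLI:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--env_id', default='Hopper-v1')          # assumed default
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--num_cpu', type=int, default=1)
parser.add_argument('--num_timesteps', type=int, default=int(1e6))
parser.add_argument('--entcoeff', type=float, default=0.0)
parser.add_argument('--checkpoint_dir', default='checkpoint')  # assumed default
parser.add_argument('--save_per_iter', type=int, default=100)
parser.add_argument('--task', default='train')
parser.add_argument('--sample_stochastic', action='store_true')
parser.add_argument('--load_model_path', default=None)
train(parser.parse_args())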
Example #3
def test(expert_path,
         sequence_size=1000,
         attention_size=30,
         hidden_size=30,
         env_id='Hopper-v1',
         cell_type='lstm'):
    from gailtf.dataset.mujoco_traj import Mujoco_Traj_Dset
    from gailtf.network.adversary_traj import TrajectoryClassifier
    import gym
    U.make_session(num_cpu=2).__enter__()
    dset = Mujoco_Traj_Dset(expert_path)
    env = gym.make(env_id)
    t1, tl1 = dset.get_next_traj_batch(10)
    t2, tl2 = dset.get_next_traj_batch(10)
    discriminator = TrajectoryClassifier(env, hidden_size, sequence_size,
                                         attention_size, cell_type)
    U.initialize()

    *losses, g = discriminator.lossandgrad(t1, tl1, t2, tl2, 0.5)
    rs1 = discriminator.get_rewards(t1, tl1)
    #cv1,cv2 = discriminator.check_values(t1,tl1,t2,tl2,0.5)
    print(rs1.shape)
Example #4
def main(args):
    from gailtf.baselines.ppo1 import mlp_policy
    U.make_session(num_cpu=args.num_cpu).__enter__()
    set_global_seeds(args.seed)
    env = gym.make(args.env_id)

    def policy_fn(name, ob_space, ac_space, reuse=False):
        return mlp_policy.MlpPolicy(name=name,
                                    ob_space=ob_space,
                                    ac_space=ac_space,
                                    reuse=reuse,
                                    hid_size=64,
                                    num_hid_layers=2)

    env = bench.Monitor(
        env,
        logger.get_dir() and osp.join(logger.get_dir(), "monitor.json"))
    env.seed(args.seed)
    gym.logger.setLevel(logging.WARN)
    task_name = get_task_name(args)
    args.checkpoint_dir = osp.join(args.checkpoint_dir, task_name)
    args.log_dir = osp.join(args.log_dir, task_name)
    cmd = (hfo_py.get_hfo_path() +
           ' --offense-npcs=1 --defense-npcs=1'
           ' --log-dir /home/yupeng/Desktop/workspace/src2/GAMIL-tf0/gail-tf/log/soccer_data/'
           ' --record --frames=200')
    print(cmd)
    # os.system(cmd)

    dataset = Mujoco_Dset(expert_data_path=args.expert_data_path,
                          ret_threshold=args.ret_threshold,
                          traj_limitation=args.traj_limitation)

    # previous: dataset = Mujoco_Dset(expert_path=args.expert_path, ret_threshold=args.ret_threshold, traj_limitation=args.traj_limitation)
    pretrained_weight = None

    if (args.pretrained and args.task == 'train') or args.algo == 'bc':
        # Pretrain with behavior cloning
        from gailtf.algo import behavior_clone
        if args.algo == 'bc' and args.task == 'evaluate':
            behavior_clone.evaluate(env,
                                    policy_fn,
                                    args.load_model_path_high,
                                    args.load_model_path_low,
                                    stochastic_policy=args.stochastic_policy)
            sys.exit()
        if args.task == 'train' and args.action_space_level == 'high':
            print("training high level policy")
            pretrained_weight_high = behavior_clone.learn(
                env,
                policy_fn,
                dataset,
                max_iters=args.BC_max_iter,
                pretrained=args.pretrained,
                ckpt_dir=args.checkpoint_dir + '/high_level',
                log_dir=args.log_dir + '/high_level',
                task_name=task_name,
                high_level=True)
        if args.task == 'train' and args.action_space_level == 'low':
            print("training low level policy")
            pretrained_weight_low = behavior_clone.learn(
                env,
                policy_fn,
                dataset,
                max_iters=args.BC_max_iter,
                pretrained=args.pretrained,
                ckpt_dir=args.checkpoint_dir + '/low_level',
                log_dir=args.log_dir + '/low_level',
                task_name=task_name,
                high_level=False)
        if args.algo == 'bc':
            sys.exit()

    from gailtf.network.adversary import TransitionClassifier
    # discriminator
    discriminator = TransitionClassifier(env,
                                         args.adversary_hidden_size,
                                         entcoeff=args.adversary_entcoeff)
    if args.algo == 'trpo':
        # Set up for MPI seed
        from mpi4py import MPI
        rank = MPI.COMM_WORLD.Get_rank()
        if rank != 0:
            logger.set_level(logger.DISABLED)
        workerseed = args.seed + 10000 * MPI.COMM_WORLD.Get_rank()
        set_global_seeds(workerseed)
        env.seed(workerseed)
        from gailtf.algo import trpo_mpi
        if args.task == 'train':
            trpo_mpi.learn(env,
                           policy_fn,
                           discriminator,
                           dataset,
                           pretrained=args.pretrained,
                           pretrained_weight=pretrained_weight,
                           g_step=args.g_step,
                           d_step=args.d_step,
                           timesteps_per_batch=1024,
                           max_kl=args.max_kl,
                           cg_iters=10,
                           cg_damping=0.1,
                           max_timesteps=args.num_timesteps,
                           entcoeff=args.policy_entcoeff,
                           gamma=0.995,
                           lam=0.97,
                           vf_iters=5,
                           vf_stepsize=1e-3,
                           ckpt_dir=args.checkpoint_dir,
                           log_dir=args.log_dir,
                           save_per_iter=args.save_per_iter,
                           load_model_path=args.load_model_path,
                           task_name=task_name)
        elif args.task == 'evaluate':
            trpo_mpi.evaluate(env,
                              policy_fn,
                              args.load_model_path,
                              timesteps_per_batch=1024,
                              number_trajs=10,
                              stochastic_policy=args.stochastic_policy)
        else:
            raise NotImplementedError
    else:
        raise NotImplementedError

    env.close()
Example #5
def main(args):
    from gailtf.baselines.ppo1 import mlp_policy
    U.make_session(num_cpu=args.num_cpu).__enter__()
    set_global_seeds(args.seed)
    env = gym.make(args.env_id)
    def policy_fn(name, ob_space, ac_space, reuse=False):
        return mlp_policy.MlpPolicy(name=name, ob_space=ob_space, ac_space=ac_space,
            reuse=reuse, hid_size=64, num_hid_layers=2)
    env = bench.Monitor(env, logger.get_dir() and
        osp.join(logger.get_dir(), "monitor.json"))
    env.seed(args.seed)
    gym.logger.setLevel(logging.WARN)
    task_name = get_task_name(args)
    args.checkpoint_dir = osp.join(args.checkpoint_dir, task_name)
    args.log_dir = osp.join(args.log_dir, task_name)
    dataset = Mujoco_Traj_Dset(expert_path=args.expert_path,
                               ret_threshold=args.ret_threshold,
                               traj_limitation=args.traj_limitation,
                               sentence_size=args.adversary_seq_size)
    if args.adversary_seq_size is None:
        args.adversary_seq_size = dataset.sentence_size
    pretrained_weight = None
    if (args.pretrained and args.task == 'train') or args.algo == 'bc':
        # Pretrain with behavior cloning
        from gailtf.algo import behavior_clone
        if args.algo == 'bc' and args.task == 'evaluate':
            behavior_clone.evaluate(env, policy_fn, args.load_model_path, stochastic_policy=args.stochastic_policy)
            sys.exit()
        pretrained_weight = behavior_clone.learn(env, policy_fn, dataset,
            max_iters=args.BC_max_iter, pretrained=args.pretrained, 
            ckpt_dir=args.checkpoint_dir, log_dir=args.log_dir, task_name=task_name)
        if args.algo == 'bc':
            sys.exit()

    from gailtf.network.adversary_traj import TrajectoryClassifier
    # discriminator
    discriminator = TrajectoryClassifier(env, args.adversary_hidden_size,
                                         args.adversary_seq_size, args.adversary_attn_size,
                                         cell_type=args.adversary_cell_type,
                                         entcoeff=args.adversary_entcoeff)
    if args.algo == 'trpo':
        # Set up for MPI seed
        from mpi4py import MPI
        rank = MPI.COMM_WORLD.Get_rank()
        if rank != 0:
            logger.set_level(logger.DISABLED)
        workerseed = args.seed + 10000 * MPI.COMM_WORLD.Get_rank()
        set_global_seeds(workerseed)
        env.seed(workerseed)
        from gailtf.algo import trpo_traj_mpi
        if args.task == 'train':
            trpo_traj_mpi.learn(env, policy_fn, discriminator, dataset,
                pretrained=args.pretrained, pretrained_weight=pretrained_weight,
                g_step=args.g_step, d_step=args.d_step,
                episodes_per_batch=100,
                dropout_keep_prob=0.5, sequence_size=args.adversary_seq_size,
                max_kl=args.max_kl, cg_iters=10, cg_damping=0.1,
                max_timesteps=args.num_timesteps, 
                entcoeff=args.policy_entcoeff, gamma=0.995, lam=0.97, 
                vf_iters=5, vf_stepsize=1e-3,
                ckpt_dir=args.checkpoint_dir, log_dir=args.log_dir,
                save_per_iter=args.save_per_iter, load_model_path=args.load_model_path,
                task_name=task_name)
        elif args.task == 'evaluate':
            trpo_mpi.evaluate(env, policy_fn, args.load_model_path, timesteps_per_batch=1024,
                number_trajs=10, stochastic_policy=args.stochastic_policy)
        else:
            raise NotImplementedError
    else:
        raise NotImplementedError

    env.close()
Example #6
def replayACS(env, modelPath, transpose=True, fps=30, zoom=None):
    """
    Replays a game from recorded trajectories using actions
    This method is not precise though, because it indirectly recovers environment states from actions.
    Sometimes it gets asynchronous and distorts the real trajectory.
    :param env: Atari environment
    :param modelPath: path to trained model
    :param transpose:
    :param fps:
    :param zoom:
    :return:
    """
    global obs
    with open(modelPath, 'rb') as rfp:
        trajectories = pkl.load(rfp)

    U.make_session(num_cpu=1).__enter__()

    U.initialize()

    tempEnv = env
    while not isinstance(tempEnv, ActionWrapper):
        try:
            tempEnv = tempEnv.env
        except AttributeError:
            break
    # using ActionWrapper:
    if isinstance(tempEnv, ActionWrapper):
        obs_s = tempEnv.screen_space
    else:
        obs_s = env.observation_space

    # assert type(obs_s) == Box
    assert len(obs_s.shape) == 2 or (len(obs_s.shape) == 3 and obs_s.shape[2] in [1, 3])

    if zoom is None:
        zoom = 1

    video_size = int(obs_s.shape[0] * zoom), int(obs_s.shape[1] * zoom)

    if transpose:
        video_size = tuple(reversed(video_size))

    # setup the screen using pygame
    flags = RESIZABLE | HWSURFACE | DOUBLEBUF
    screen = pygame.display.set_mode(video_size, flags)
    pygame.event.set_blocked(pygame.MOUSEMOTION)
    clock = pygame.time.Clock()

    # =================================================================================================================

    running = True
    envDone = False

    playerScore = opponentScore = 0
    wins = losses = ties = gamesTotal = totalPlayer = totalOpponent = 0

    while running:
        trl = len(trajectories)

        for i in range(trl):
            obs = env.reset()
            print("\nRunning trajectory {}".format(i))
            print("Length {}".format(len(trajectories[i]['ac'])))

            for ac in tqdm(trajectories[i]['ac']):
                if not isinstance(ac, list):
                    ac = np.atleast_1d(ac)

                obs, reward, envDone, info = env.step(ac)

                # track of player score:
                if reward > 0:
                    playerScore += abs(reward)
                else:
                    opponentScore += abs(reward)

                if hasattr(env, 'getImage'):
                    obs = env.getImage()

                if obs is not None:
                    if len(obs.shape) == 2:
                        obs = obs[:, :, None]
                    if obs.shape[2] == 1:
                        obs = obs.repeat(3, axis=2)
                    display_arr(screen, obs, video_size, transpose)

                    pygame.display.flip()
                    clock.tick(fps)

            msg = format("End of game: score %d - %d" % (playerScore, opponentScore))
            print(colorize(msg, color='red'))
            gamesTotal += 1
            if playerScore > opponentScore:
                wins += 1
            elif opponentScore > playerScore:
                losses += 1
            else:
                ties += 1

            totalPlayer += playerScore
            totalOpponent += opponentScore

            playerScore = opponentScore = 0

            msg = format("Status so far: \nGames played - %d wins - %d losses - %d ties - %d\n Total score: %d - %d" % (
                gamesTotal, wins, losses, ties, totalPlayer, totalOpponent))
            print(colorize(msg, color='red'))
    pygame.quit()
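A hypothetical invocation of `replayACS`; the environment id and trajectory path below are placeholders for illustration only, not files or settings from the repository:

import gym

env = gym.make('PongNoFrameskip-v4')              # placeholder Atari environment
replayACS(env, './trajectories/pong_expert.pkl',  # placeholder path to the pickled trajectories
          transpose=True, fps=30, zoom=2)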
Example #7
def main(args):
    from gailtf.baselines.ppo1 import mlp_policy
    U.make_session(num_cpu=args.num_cpu).__enter__()
    set_global_seeds(args.seed)
    env = gym.make(args.env_id)

    def policy_fn(name, ob_space, ac_space, reuse=False):
        return mlp_policy.MlpPolicy(name=name,
                                    ob_space=ob_space,
                                    ac_space=ac_space,
                                    reuse=reuse,
                                    hid_size=64,
                                    num_hid_layers=2)

    tdatetime = dt.now()
    tstr = tdatetime.strftime('%Y-%m-%d-%H-%M')
    os.mkdir("./video/" + args.env_id + '/' + tstr)
    env = wrappers.Monitor(env,
                           "./video/" + args.env_id + '/' + tstr,
                           force=True)  #動画準備
    #env = bench.Monitor(env, logger.get_dir() and
    #    osp.join(logger.get_dir(), "monitor.json"))
    env.seed(args.seed)
    gym.logger.setLevel(logging.WARN)
    task_name = get_task_name(args)
    args.log_dir = "./log/GAIL/" + args.env_id + '/' + tstr + " " + task_name
    os.makedirs(args.log_dir, exist_ok=True)
    args.checkpoint_dir = "./checkpoint/GAIL/" + args.env_id + '/' + tstr + " " + task_name
    os.makedirs(args.checkpoint_dir, exist_ok=True)
    #args.checkpoint_dir = osp.join(args.checkpoint_dir, task_name)
    #args.log_dir = osp.join(args.log_dir, task_name)
    dataset = Mujoco_Dset(expert_path=args.expert_path,
                          ret_threshold=args.ret_threshold,
                          traj_limitation=args.traj_limitation)
    pretrained_weight = None
    if (args.pretrained and args.task == 'train') or args.algo == 'bc':
        # Pretrain with behavior cloning
        from gailtf.algo import behavior_clone
        if args.algo == 'bc' and args.task == 'evaluate':
            behavior_clone.evaluate(env,
                                    policy_fn,
                                    args.load_model_path,
                                    stochastic_policy=args.stochastic_policy)
            sys.exit()
        pretrained_weight = behavior_clone.learn(env,
                                                 policy_fn,
                                                 dataset,
                                                 max_iters=args.BC_max_iter,
                                                 pretrained=args.pretrained,
                                                 ckpt_dir=args.checkpoint_dir,
                                                 log_dir=args.log_dir,
                                                 task_name=task_name)
        if args.algo == 'bc':
            sys.exit()

    from gailtf.network.adversary import TransitionClassifier
    # discriminator
    discriminator = TransitionClassifier(env,
                                         args.adversary_hidden_size,
                                         entcoeff=args.adversary_entcoeff)
    if args.algo == 'trpo':
        # Set up for MPI seed
        from mpi4py import MPI
        rank = MPI.COMM_WORLD.Get_rank()
        if rank != 0:
            logger.set_level(logger.DISABLED)
        workerseed = args.seed + 10000 * MPI.COMM_WORLD.Get_rank()
        set_global_seeds(workerseed)
        env.seed(workerseed)
        from gailtf.algo import trpo_mpi
        if args.task == 'train':
            trpo_mpi.learn(env,
                           policy_fn,
                           discriminator,
                           dataset,
                           pretrained=args.pretrained,
                           pretrained_weight=pretrained_weight,
                           g_step=args.g_step,
                           d_step=args.d_step,
                           timesteps_per_batch=1024,
                           max_kl=args.max_kl,
                           cg_iters=10,
                           cg_damping=0.1,
                           max_timesteps=args.num_timesteps,
                           entcoeff=args.policy_entcoeff,
                           gamma=0.995,
                           lam=0.97,
                           vf_iters=5,
                           vf_stepsize=1e-3,
                           ckpt_dir=args.checkpoint_dir,
                           log_dir=args.log_dir,
                           save_per_iter=args.save_per_iter,
                           load_model_path=args.load_model_path,
                           task_name=task_name)
        elif args.task == 'evaluate':
            trpo_mpi.evaluate(env,
                              policy_fn,
                              args.load_model_path,
                              timesteps_per_batch=1024,
                              number_trajs=10,
                              stochastic_policy=args.stochastic_policy)
        else:
            raise NotImplementedError
    else:
        raise NotImplementedError

    env.close()