Example #1
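PPO training entry point from the gail-tf codebase: it builds the Gym environment, wraps it with a video-recording Monitor, and runs pposgd_simple.learn with clipped-surrogate hyperparameters. Names such as U, set_global_seeds, wrappers, osp, and logging are assumed to be imported at module level in the original file.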
def train(args):
    from gailtf.baselines.ppo1 import mlp_policy, pposgd_simple
    U.make_session(num_cpu=args.num_cpu).__enter__()
    set_global_seeds(args.seed)
    env = gym.make(args.env_id)
    def policy_fn(name, ob_space, ac_space):
        return mlp_policy.MlpPolicy(name=name, ob_space=ob_space, ac_space=ac_space,
            hid_size=64, num_hid_layers=2)
    env = wrappers.Monitor(env, './video', force=True)
    #env = bench.Monitor(env, logger.get_dir() and
        #osp.join(logger.get_dir(), "monitor.json"))
    env.seed(args.seed)
    gym.logger.setLevel(logging.WARN)
    task_name = "ppo." + args.env_id.split("-")[0] + "." + ("%.2f"%args.entcoeff)
    args.checkpoint_dir = osp.join(args.checkpoint_dir, task_name)
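    # Run PPO with the clipped surrogate objective (clip_param=0.2) and a linear stepsize schedule.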
    pposgd_simple.learn(env, policy_fn, 
            max_timesteps=args.num_timesteps,
            timesteps_per_batch=2048,
            clip_param=0.2, entcoeff=args.entcoeff,
            optim_epochs=10, optim_stepsize=3e-4, optim_batchsize=64,
            gamma=0.99, lam=0.95, schedule='linear', ckpt_dir=args.checkpoint_dir,
            save_per_iter=args.save_per_iter, task=args.task,
            sample_stochastic=args.sample_stochastic,
            load_model_path=args.load_model_path,
            task_name=task_name
        )
    env.close()
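
A minimal driver for the train() function above might look like the following sketch. The flag names are inferred from the attributes the snippet reads off args; the defaults are illustrative assumptions, not values from the source.

def parse_args():
    # Hypothetical CLI; flag names mirror the attributes train() uses, defaults are illustrative only.
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--env_id', default='Hopper-v2')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--num_cpu', type=int, default=1)
    parser.add_argument('--entcoeff', type=float, default=0.0)
    parser.add_argument('--num_timesteps', type=int, default=int(1e6))
    parser.add_argument('--checkpoint_dir', default='checkpoint')
    parser.add_argument('--save_per_iter', type=int, default=50)
    parser.add_argument('--task', default='train')
    parser.add_argument('--sample_stochastic', action='store_true')
    parser.add_argument('--load_model_path', default=None)
    return parser.parse_args()

if __name__ == '__main__':
    train(parse_args())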
Example #2
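TRPO training entry point: each MPI worker derives its seed from its rank, only rank 0 keeps logging enabled, and trpo_mpi.learn is called with the hyperparameters summarized in the comments below. MPI, logger, wrappers, and the MlpPolicy class are assumed to be imported in the surrounding module.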
def train(args):
    import gailtf.baselines.common.tf_util as U
    sess = U.single_threaded_session()
    sess.__enter__()

    rank = MPI.COMM_WORLD.Get_rank()
    if rank != 0:
        logger.set_level(logger.DISABLED)
    workerseed = args.seed + 10000 * MPI.COMM_WORLD.Get_rank()
    set_global_seeds(workerseed)
    env = gym.make(args.env_id)
    env = wrappers.Monitor(env, './video', force=True)

    # Policy function

    def policy_fn(name, ob_space, ac_space):
        # Use the spaces passed in rather than reaching for the enclosing env.
        return MlpPolicy(name=name,
                         ob_space=ob_space,
                         ac_space=ac_space,
                         hid_size=32,
                         num_hid_layers=2)

    # env = bench.Monitor(env, logger.get_dir() and
    #    osp.join(logger.get_dir(), "%i.monitor.json" % rank))
    env.seed(workerseed)
    gym.logger.setLevel(logging.WARN)

    task_name = "trpo." + args.env_id.split("-")[0] + "." + ("%.2f" %
                                                             args.entcoeff)
    args.checkpoint_dir = osp.join(args.checkpoint_dir, task_name)

    # TRPO training
    # timesteps per batch: 1024, max KL: 0.01, cg_iters=10, cg_damping=0.1
    # max timesteps until termination (chosen manually): 300000
    # gamma=0.99, lam=0.98, vf_iters=5, vf_stepsize=1e-3
    trpo_mpi.learn(env,
                   policy_fn,
                   timesteps_per_batch=1024,
                   max_kl=0.01,
                   cg_iters=10,
                   cg_damping=0.1,
                   max_timesteps=args.num_timesteps,
                   gamma=0.99,
                   lam=0.98,
                   vf_iters=5,
                   vf_stepsize=1e-3,
                   sample_stochastic=args.sample_stochastic,
                   task_name=task_name,
                   save_per_iter=args.save_per_iter,
                   ckpt_dir=args.checkpoint_dir,
                   load_model_path=args.load_model_path,
                   task=args.task)
    env.close()
Example #3
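GAIL training/evaluation driver: it optionally pretrains the policy with behavior cloning (separate high-level and low-level policies), builds a TransitionClassifier discriminator over state-action transitions, and then runs trpo_mpi.learn or trpo_mpi.evaluate. The hard-coded HFO soccer command is only printed, not executed.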
def main(args):
    from gailtf.baselines.ppo1 import mlp_policy
    U.make_session(num_cpu=args.num_cpu).__enter__()
    set_global_seeds(args.seed)
    env = gym.make(args.env_id)

    def policy_fn(name, ob_space, ac_space, reuse=False):
        return mlp_policy.MlpPolicy(name=name,
                                    ob_space=ob_space,
                                    ac_space=ac_space,
                                    reuse=reuse,
                                    hid_size=64,
                                    num_hid_layers=2)

    env = bench.Monitor(
        env,
        logger.get_dir() and osp.join(logger.get_dir(), "monitor.json"))
    env.seed(args.seed)
    gym.logger.setLevel(logging.WARN)
    task_name = get_task_name(args)
    args.checkpoint_dir = osp.join(args.checkpoint_dir, task_name)
    args.log_dir = osp.join(args.log_dir, task_name)
    cmd = hfo_py.get_hfo_path(
    ) + ' --offense-npcs=1 --defense-npcs=1 --log-dir /home/yupeng/Desktop/workspace/src2/GAMIL-tf0/gail-tf/log/soccer_data/ --record --frames=200'
    print(cmd)
    # os.system(cmd)

    dataset = Mujoco_Dset(expert_data_path=args.expert_data_path,
                          ret_threshold=args.ret_threshold,
                          traj_limitation=args.traj_limitation)

    # previous: dataset = Mujoco_Dset(expert_path=args.expert_path, ret_threshold=args.ret_threshold, traj_limitation=args.traj_limitation)
    pretrained_weight = None

    if (args.pretrained and args.task == 'train') or args.algo == 'bc':
        # Pretrain with behavior cloning
        from gailtf.algo import behavior_clone
        if args.algo == 'bc' and args.task == 'evaluate':
            behavior_clone.evaluate(env,
                                    policy_fn,
                                    args.load_model_path_high,
                                    args.load_model_path_low,
                                    stochastic_policy=args.stochastic_policy)
            sys.exit()
        if args.task == 'train' and args.action_space_level == 'high':
            print("training high level policy")
            pretrained_weight_high = behavior_clone.learn(
                env,
                policy_fn,
                dataset,
                max_iters=args.BC_max_iter,
                pretrained=args.pretrained,
                ckpt_dir=args.checkpoint_dir + '/high_level',
                log_dir=args.log_dir + '/high_level',
                task_name=task_name,
                high_level=True)
        if args.task == 'train' and args.action_space_level == 'low':
            print("training low level policy")
            pretrained_weight_low = behavior_clone.learn(
                env,
                policy_fn,
                dataset,
                max_iters=args.BC_max_iter,
                pretrained=args.pretrained,
                ckpt_dir=args.checkpoint_dir + '/low_level',
                log_dir=args.log_dir + '/low_level',
                task_name=task_name,
                high_level=False)
        if args.algo == 'bc':
            sys.exit()

    from gailtf.network.adversary import TransitionClassifier
    # discriminator
    discriminator = TransitionClassifier(env,
                                         args.adversary_hidden_size,
                                         entcoeff=args.adversary_entcoeff)
    if args.algo == 'trpo':
        # Set up for MPI seed
        from mpi4py import MPI
        rank = MPI.COMM_WORLD.Get_rank()
        if rank != 0:
            logger.set_level(logger.DISABLED)
        workerseed = args.seed + 10000 * MPI.COMM_WORLD.Get_rank()
        set_global_seeds(workerseed)
        env.seed(workerseed)
        from gailtf.algo import trpo_mpi
        if args.task == 'train':
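            # GAIL training loop: alternate g_step TRPO policy updates with d_step discriminator updates.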
            trpo_mpi.learn(env,
                           policy_fn,
                           discriminator,
                           dataset,
                           pretrained=args.pretrained,
                           pretrained_weight=pretrained_weight,
                           g_step=args.g_step,
                           d_step=args.d_step,
                           timesteps_per_batch=1024,
                           max_kl=args.max_kl,
                           cg_iters=10,
                           cg_damping=0.1,
                           max_timesteps=args.num_timesteps,
                           entcoeff=args.policy_entcoeff,
                           gamma=0.995,
                           lam=0.97,
                           vf_iters=5,
                           vf_stepsize=1e-3,
                           ckpt_dir=args.checkpoint_dir,
                           log_dir=args.log_dir,
                           save_per_iter=args.save_per_iter,
                           load_model_path=args.load_model_path,
                           task_name=task_name)
        elif args.task == 'evaluate':
            trpo_mpi.evaluate(env,
                              policy_fn,
                              args.load_model_path,
                              timesteps_per_batch=1024,
                              number_trajs=10,
                              stochastic_policy=args.stochastic_policy)
        else:
            raise NotImplementedError
    else:
        raise NotImplementedError

    env.close()
Example #4
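Variant of the GAIL driver with a trajectory-level discriminator: expert data is loaded as whole trajectories (Mujoco_Traj_Dset), the TrajectoryClassifier is a recurrent model with attention over fixed-length sequences, and training runs through trpo_traj_mpi.learn with per-episode batching.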
def main(args):
    from gailtf.baselines.ppo1 import mlp_policy
    U.make_session(num_cpu=args.num_cpu).__enter__()
    set_global_seeds(args.seed)
    env = gym.make(args.env_id)
    def policy_fn(name, ob_space, ac_space, reuse=False):
        return mlp_policy.MlpPolicy(name=name, ob_space=ob_space, ac_space=ac_space,
            reuse=reuse, hid_size=64, num_hid_layers=2)
    env = bench.Monitor(env, logger.get_dir() and
        osp.join(logger.get_dir(), "monitor.json"))
    env.seed(args.seed)
    gym.logger.setLevel(logging.WARN)
    task_name = get_task_name(args)
    args.checkpoint_dir = osp.join(args.checkpoint_dir, task_name)
    args.log_dir = osp.join(args.log_dir, task_name)
    dataset = Mujoco_Traj_Dset(expert_path=args.expert_path,
                               ret_threshold=args.ret_threshold,
                               traj_limitation=args.traj_limitation,
                               sentence_size=args.adversary_seq_size)
    if args.adversary_seq_size is None:
        args.adversary_seq_size = dataset.sentence_size
    pretrained_weight = None
    if (args.pretrained and args.task == 'train') or args.algo == 'bc':
        # Pretrain with behavior cloning
        from gailtf.algo import behavior_clone
        if args.algo == 'bc' and args.task == 'evaluate':
            behavior_clone.evaluate(env, policy_fn, args.load_model_path, stochastic_policy=args.stochastic_policy)
            sys.exit()
        pretrained_weight = behavior_clone.learn(env, policy_fn, dataset,
            max_iters=args.BC_max_iter, pretrained=args.pretrained, 
            ckpt_dir=args.checkpoint_dir, log_dir=args.log_dir, task_name=task_name)
        if args.algo == 'bc':
            sys.exit()

    from gailtf.network.adversary_traj import TrajectoryClassifier
    # discriminator
    discriminator = TrajectoryClassifier(env,
                                         args.adversary_hidden_size,
                                         args.adversary_seq_size,
                                         args.adversary_attn_size,
                                         cell_type=args.adversary_cell_type,
                                         entcoeff=args.adversary_entcoeff)
    if args.algo == 'trpo':
        # Set up for MPI seed
        from mpi4py import MPI
        rank = MPI.COMM_WORLD.Get_rank()
        if rank != 0:
            logger.set_level(logger.DISABLED)
        workerseed = args.seed + 10000 * MPI.COMM_WORLD.Get_rank()
        set_global_seeds(workerseed)
        env.seed(workerseed)
        from gailtf.algo import trpo_traj_mpi
        if args.task == 'train':
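            # Trajectory-level GAIL: the discriminator scores whole episodes, so rollouts are batched per episode.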
            trpo_traj_mpi.learn(env, policy_fn, discriminator, dataset,
                pretrained=args.pretrained, pretrained_weight=pretrained_weight,
                g_step=args.g_step, d_step=args.d_step,
                episodes_per_batch=100,
                dropout_keep_prob=0.5, sequence_size=args.adversary_seq_size,
                max_kl=args.max_kl, cg_iters=10, cg_damping=0.1,
                max_timesteps=args.num_timesteps, 
                entcoeff=args.policy_entcoeff, gamma=0.995, lam=0.97, 
                vf_iters=5, vf_stepsize=1e-3,
                ckpt_dir=args.checkpoint_dir, log_dir=args.log_dir,
                save_per_iter=args.save_per_iter, load_model_path=args.load_model_path,
                task_name=task_name)
        elif args.task == 'evaluate':
            trpo_mpi.evaluate(env, policy_fn, args.load_model_path, timesteps_per_batch=1024,
                number_trajs=10, stochastic_policy=args.stochastic_policy)
        else:
            raise NotImplementedError
    else:
        raise NotImplementedError

    env.close()
Example #5
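A multi-task entry point that dispatches on args.task: playing a trained agent, sampling expert trajectories from an RL policy or from a human player, training an RL expert with TRPO, or training GAIL (with a CNN, MLP, or Wasserstein-style discriminator depending on the settings). Helpers such as policy_fn, printArgs, ActionWrapper, and traj_episode_generator are assumed to be defined elsewhere in the module.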
def train(args):
    global env

    if args.expert_path is not None:
        assert osp.exists(args.expert_path)
    if args.load_model_path is not None:
        assert osp.exists(args.load_model_path + '.meta')
        args.pretrained = False

    printArgs(args)

    # ================================================ ENVIRONMENT =====================================================
    U.make_session(num_cpu=args.num_cpu).__enter__()
    set_global_seeds(args.seed)

    if args.networkName == "MLP":
        env = gym.make(args.env_id)
        env = ActionWrapper(env, args.discrete)
    elif args.networkName == "CNN":
        env = make_atari(args.env_id)
        env = ActionWrapper(env, args.discrete)
        if args.deepmind:
            from gailtf.baselines.common.atari_wrappers import wrap_deepmind
            env = wrap_deepmind(env, False)

    env.metadata = 0
    env = bench.Monitor(env,
                        logger.get_dir()
                        and osp.join(logger.get_dir(), "monitor.json"),
                        allow_early_resets=True)
    env.seed(args.seed)
    gym.logger.setLevel(logging.WARN)

    discrete = (".D." if args.discrete else ".MD")
    # ============================================== PLAY AGENT ========================================================
    # ==================================================================================================================
    if args.task == 'play_agent':
        logger.log("Playing agent...")
        from environments.atari.atari_agent import playAtari
        agent = policy_fn(args, PI, env, reuse=False)
        playAtari(env,
                  agent,
                  U,
                  modelPath=args.load_model_path,
                  fps=15,
                  stochastic=args.stochastic_policy,
                  zoom=2,
                  delay=10)
        env.close()
        sys.exit()

    # ========================================== SAMPLE TRAJECTORY FROM RL =============================================
    # ==================================================================================================================

    if args.task == 'RL_expert':
        logger.log("Sampling trajectory...")
        stoch = 'stochastic.' if args.stochastic_policy else 'deterministic.'
        taskName = stoch + args.alg + "." + args.env_id + discrete + "." + str(
            args.maxSampleTrajectories)
        taskName = osp.join("data/expert", taskName)
        currentPolicy = policy_fn(args, PI, env, reuse=False)
        episodesGenerator = traj_episode_generator(
            currentPolicy,
            env,
            args.trajectoriesPerBatch,
            stochastic=args.stochastic_policy,
            render=args.visualize,
            downsample=args.downsample)
        sample_trajectory(args.load_model_path,
                          episodesGenerator,
                          taskName,
                          args.stochastic_policy,
                          max_sample_traj=args.maxSampleTrajectories)
        sys.exit()

    # ======================================== SAMPLE TRAJECTORY FROM HUMAN ============================================
    # ==================================================================================================================
    if args.task == 'human_expert':
        logger.log("Human plays...")
        taskName = "human." + args.env_id + "_" + args.networkName + "." + "50.pkl"
        args.checkpoint_dir = osp.join(args.checkpoint_dir, taskName)
        taskName = osp.join("data/expert", taskName)

        from environments.atari.atari_human import playAtari
        sampleTrajectories = playAtari(env, fps=15, zoom=2, taskName=taskName)

        pkl.dump(sampleTrajectories, open(taskName, "wb"))
        env.close()
        sys.exit()

    # =========================================== TRAIN RL EXPERT ======================================================
    # ==================================================================================================================

    if args.task == "train_RL_expert":
        logger.log("Training RL expert...")

        if args.alg == 'trpo':
            from gailtf.baselines.trpo_mpi import trpo_mpi
            taskName = args.alg + "." + args.env_id + "." + str(
                args.policy_hidden_size) + discrete + "." + str(
                    args.maxSampleTrajectories)

            rank = MPI.COMM_WORLD.Get_rank()
            if rank != 0:
                logger.set_level(logger.DISABLED)
            workerseed = args.seed + 10000 * MPI.COMM_WORLD.Get_rank()
            set_global_seeds(workerseed)
            env = gym.make(args.env_id)

            env = bench.Monitor(
                env,
                logger.get_dir()
                and osp.join(logger.get_dir(), "%i.monitor.json" % rank))
            env.seed(workerseed)
            gym.logger.setLevel(logging.WARN)

            args.checkpoint_dir = osp.join("data/training", taskName)
            trpo_mpi.learn(args,
                           env,
                           policy_fn,
                           timesteps_per_batch=1024,
                           max_iters=50_000,
                           vf_iters=5,
                           vf_stepsize=1e-3,
                           task_name=taskName)

            env.close()
            sys.exit()

        else:
            raise NotImplementedError

    # =================================================== GAIL =========================================================
    # ==================================================================================================================
    if args.task == 'train_gail':

        taskName = get_task_name(args)
        args.checkpoint_dir = osp.join(args.checkpoint_dir, taskName)
        args.log_dir = osp.join(args.log_dir, taskName)
        args.task_name = taskName

        dataset = Mujoco_Dset(expert_path=args.expert_path,
                              ret_threshold=args.ret_threshold,
                              traj_limitation=args.traj_limitation)

        # Pick the discriminator architecture: CNN for image observations, otherwise an MLP (optionally Wasserstein-style).
        if len(env.observation_space.shape) > 2:
            from gailtf.network.adversary_cnn import TransitionClassifier
        else:
            if args.wasserstein:
                from gailtf.network.w_adversary import TransitionClassifier
            else:
                from gailtf.network.adversary import TransitionClassifier

        discriminator = TransitionClassifier(env,
                                             args.adversary_hidden_size,
                                             entcoeff=args.adversary_entcoeff)

        pretrained_weight = None
        # pre-training with BC (optional):
        if (args.pretrained and args.task == 'train_gail') or args.alg == 'bc':
            # Pretrain with behavior cloning
            from gailtf.algo import behavior_clone
            if args.load_model_path is None:
                pretrained_weight = behavior_clone.learn(
                    args, env, policy_fn, dataset)
            if args.alg == 'bc':
                sys.exit()

        if args.alg == 'trpo':
            # Set up for MPI seed
            rank = MPI.COMM_WORLD.Get_rank()
            if rank != 0:
                logger.set_level(logger.DISABLED)
            workerseed = args.seed + 10000 * MPI.COMM_WORLD.Get_rank()
            set_global_seeds(workerseed)
            env.seed(workerseed)

            # if args.wasserstein:
            #     from gailtf.algo import w_trpo_mpi as trpo
            # else:
            from gailtf.algo import trpo_mpi as trpo

            trpo.learn(args,
                       env,
                       policy_fn,
                       discriminator,
                       dataset,
                       pretrained_weight=pretrained_weight,
                       cg_damping=0.1,
                       vf_iters=5,
                       vf_stepsize=1e-3)
        else:
            raise NotImplementedError

        env.close()
        sys.exit()
Example #6
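GAIL driver similar to Example #3, but it timestamps each run, records rollout videos with wrappers.Monitor, and creates per-run log and checkpoint directories under ./log/GAIL/ and ./checkpoint/GAIL/.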
def main(args):
    from gailtf.baselines.ppo1 import mlp_policy
    U.make_session(num_cpu=args.num_cpu).__enter__()
    set_global_seeds(args.seed)
    env = gym.make(args.env_id)

    def policy_fn(name, ob_space, ac_space, reuse=False):
        return mlp_policy.MlpPolicy(name=name,
                                    ob_space=ob_space,
                                    ac_space=ac_space,
                                    reuse=reuse,
                                    hid_size=64,
                                    num_hid_layers=2)

    tdatetime = dt.now()
    tstr = tdatetime.strftime('%Y-%m-%d-%H-%M')
    os.mkdir("./video/" + args.env_id + '/' + tstr)
    env = wrappers.Monitor(env,
                           "./video/" + args.env_id + '/' + tstr,
                           force=True)  # prepare video recording
    #env = bench.Monitor(env, logger.get_dir() and
    #    osp.join(logger.get_dir(), "monitor.json"))
    env.seed(args.seed)
    gym.logger.setLevel(logging.WARN)
    task_name = get_task_name(args)
    args.log_dir = "./log/GAIL/" + args.env_id + '/' + tstr + " " + task_name
    os.mkdir(args.log_dir)
    args.checkpoint_dir = "./checkpoint/GAIL/" + args.env_id + '/' + tstr + " " + task_name
    os.mkdir(args.checkpoint_dir)
    #args.checkpoint_dir = osp.join(args.checkpoint_dir, task_name)
    #args.log_dir = osp.join(args.log_dir, task_name)
    dataset = Mujoco_Dset(expert_path=args.expert_path,
                          ret_threshold=args.ret_threshold,
                          traj_limitation=args.traj_limitation)
    pretrained_weight = None
    if (args.pretrained and args.task == 'train') or args.algo == 'bc':
        # Pretrain with behavior cloning
        from gailtf.algo import behavior_clone
        if args.algo == 'bc' and args.task == 'evaluate':
            behavior_clone.evaluate(env,
                                    policy_fn,
                                    args.load_model_path,
                                    stochastic_policy=args.stochastic_policy)
            sys.exit()
        pretrained_weight = behavior_clone.learn(env,
                                                 policy_fn,
                                                 dataset,
                                                 max_iters=args.BC_max_iter,
                                                 pretrained=args.pretrained,
                                                 ckpt_dir=args.checkpoint_dir,
                                                 log_dir=args.log_dir,
                                                 task_name=task_name)
        if args.algo == 'bc':
            sys.exit()

    from gailtf.network.adversary import TransitionClassifier
    # discriminator
    discriminator = TransitionClassifier(env,
                                         args.adversary_hidden_size,
                                         entcoeff=args.adversary_entcoeff)
    if args.algo == 'trpo':
        # Set up for MPI seed
        from mpi4py import MPI
        rank = MPI.COMM_WORLD.Get_rank()
        if rank != 0:
            logger.set_level(logger.DISABLED)
        workerseed = args.seed + 10000 * MPI.COMM_WORLD.Get_rank()
        set_global_seeds(workerseed)
        env.seed(workerseed)
        from gailtf.algo import trpo_mpi
        if args.task == 'train':
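            # GAIL loop as in Example #3; checkpoints and logs go to the timestamped per-run directories created above.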
            trpo_mpi.learn(env,
                           policy_fn,
                           discriminator,
                           dataset,
                           pretrained=args.pretrained,
                           pretrained_weight=pretrained_weight,
                           g_step=args.g_step,
                           d_step=args.d_step,
                           timesteps_per_batch=1024,
                           max_kl=args.max_kl,
                           cg_iters=10,
                           cg_damping=0.1,
                           max_timesteps=args.num_timesteps,
                           entcoeff=args.policy_entcoeff,
                           gamma=0.995,
                           lam=0.97,
                           vf_iters=5,
                           vf_stepsize=1e-3,
                           ckpt_dir=args.checkpoint_dir,
                           log_dir=args.log_dir,
                           save_per_iter=args.save_per_iter,
                           load_model_path=args.load_model_path,
                           task_name=task_name)
        elif args.task == 'evaluate':
            trpo_mpi.evaluate(env,
                              policy_fn,
                              args.load_model_path,
                              timesteps_per_batch=1024,
                              number_trajs=10,
                              stochastic_policy=args.stochastic_policy)
        else:
            raise NotImplementedError
    else:
        raise NotImplementedError

    env.close()