Example #1
def main():

    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--env_name', type=str, default='HalfCheetah-v1')
    # Experiment meta-params
    parser.add_argument('--exp_name', type=str, default='mb_mpc')
    parser.add_argument('--seed', type=int, default=3)
    parser.add_argument('--render', action='store_true')
    # Training args
    parser.add_argument('--learning_rate', '-lr', type=float, default=1e-3)
    parser.add_argument('--onpol_iters', '-n', type=int, default=1)
    parser.add_argument('--dyn_iters', '-nd', type=int, default=60)
    parser.add_argument('--batch_size', '-b', type=int, default=512)
    # Data collection
    parser.add_argument('--random_paths', '-r', type=int, default=10)
    parser.add_argument('--onpol_paths', '-d', type=int, default=10)
    parser.add_argument('--simulated_paths', '-sp', type=int, default=1000)
    parser.add_argument('--ep_len', '-ep', type=int, default=1000)
    # Neural network architecture args
    parser.add_argument('--n_layers', '-l', type=int, default=2)
    parser.add_argument('--size', '-s', type=int, default=500)
    # MPC Controller
    parser.add_argument('--mpc_horizon', '-m', type=int, default=15)
    args = parser.parse_args()

    # Set seed
    np.random.seed(args.seed)
    tf.set_random_seed(args.seed)

    # Make data directory if it does not already exist
    if not(os.path.exists('data')):
        os.makedirs('data')
    logdir = args.exp_name + '_' + args.env_name + '_' + time.strftime("%d-%m-%Y_%H-%M-%S")
    logdir = os.path.join('data', logdir)
    if not(os.path.exists(logdir)):
        os.makedirs(logdir)

    # Make env
    if args.env_name == "HalfCheetah-v1":
        env = HalfCheetahEnvNew()
        cost_fn = cheetah_cost_fn
    train(env=env,
          cost_fn=cost_fn,
          logdir=logdir,
          render=args.render,
          learning_rate=args.learning_rate,
          onpol_iters=args.onpol_iters,
          dynamics_iters=args.dyn_iters,
          batch_size=args.batch_size,
          num_paths_random=args.random_paths,
          num_paths_onpol=args.onpol_paths,
          num_simulated_paths=args.simulated_paths,
          env_horizon=args.ep_len,
          mpc_horizon=args.mpc_horizon,
          n_layers=args.n_layers,
          size=args.size,
          activation=tf.nn.relu,
          output_activation=None,
          )
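For context on the MPC-related flags above (--mpc_horizon is the planning horizon H, --simulated_paths is the number of candidate action sequences K): they typically parameterize a random-shooting MPC controller like the MPCcontroller constructed in Example #8. A minimal sketch of that idea, assuming a learned dynamics model with a batched predict(obs, acts) method and a vectorized cost_fn(obs, acts, next_obs); these names and signatures are assumptions, not the project's exact API:

import numpy as np

class RandomShootingMPC:
    """Sample K random action sequences of length H, roll them through a
    learned dynamics model, and execute the first action of the cheapest one."""

    def __init__(self, env, dyn_model, cost_fn, horizon=15, num_simulated_paths=1000):
        self.env = env
        self.dyn_model = dyn_model  # assumed: predict(obs_batch, act_batch) -> next_obs_batch
        self.cost_fn = cost_fn      # assumed: vectorized per-sample cost
        self.horizon = horizon
        self.num_simulated_paths = num_simulated_paths

    def get_action(self, ob):
        K, H = self.num_simulated_paths, self.horizon
        low, high = self.env.action_space.low, self.env.action_space.high
        obs = np.tile(ob, (K, 1))                                      # K copies of the current state
        acts = np.random.uniform(low, high, size=(H, K) + low.shape)   # (H, K, act_dim)
        total_cost = np.zeros(K)
        for t in range(H):
            next_obs = self.dyn_model.predict(obs, acts[t])
            total_cost += self.cost_fn(obs, acts[t], next_obs)
            obs = next_obs
        return acts[0, np.argmin(total_cost)]                          # first action of the best sequence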
Example #2
def main():
    args = get_parser().parse_args()

    # Establish the logger.
    format = "[%(asctime)-15s %(pathname)s:%(lineno)-3s] %(message)s"
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter(format))
    logger = logging.getLogger("mjw")
    logger.propagate = False
    logger.addHandler(handler)
    if args.verbose:
        logger.setLevel(logging.DEBUG)

    # Set seed
    np.random.seed(args.seed)
    tf.set_random_seed(args.seed)

    # Make data directory if it does not already exist
    if not (os.path.exists('data')):
        os.makedirs('data')
    timestamp = time.strftime("%d-%m-%Y_%H-%M-%S")
    logdir = "{}_{}_{}".format(args.exp_name, args.env_name, timestamp)
    logdir = os.path.join('data', logdir)
    if not (os.path.exists(logdir)):
        os.makedirs(logdir)

    # Make env
    if args.env_name == "HalfCheetah-v1":
        env = HalfCheetahEnvNew()
        cost_fn = cheetah_cost_fn
    train(
        env=env,
        cost_fn=cost_fn,
        logdir=logdir,
        render=args.render,
        learning_rate=args.learning_rate,
        onpol_iters=args.onpol_iters,
        dynamics_iters=args.dyn_iters,
        batch_size=args.batch_size,
        num_paths_random=args.random_paths,
        num_paths_onpol=args.onpol_paths,
        num_simulated_paths=args.simulated_paths,
        env_horizon=args.ep_len,
        mpc_horizon=args.mpc_horizon,
        n_layers=args.n_layers,
        size=args.size,
        activation=tf.nn.relu,
        output_activation=None,
    )
Example #3
def experiment(variant):
    from cheetah_env import HalfCheetahEnvNew
    from cost_functions import cheetah_cost_fn, \
        hopper_cost_fn, \
        swimmer_cost_fn
    from hopper_env import HopperEnvNew
    from main_solution import train_dagger
    from railrl.core import logger
    from swimmer_env import SwimmerEnvNew
    env_name_or_class = variant['env_name_or_class']

    if isinstance(env_name_or_class, str):
        if 'cheetah' in env_name_or_class.lower():
            env = HalfCheetahEnvNew()
            cost_fn = cheetah_cost_fn
        elif 'hopper' in env_name_or_class.lower():
            env = HopperEnvNew()
            cost_fn = hopper_cost_fn
        elif 'swimmer' in env_name_or_class.lower():
            env = SwimmerEnvNew()
            cost_fn = swimmer_cost_fn
        else:
            raise NotImplementedError
    else:
        env = env_name_or_class()
        from railrl.envs.wrappers import NormalizedBoxEnv
        env = NormalizedBoxEnv(env)
        if env_name_or_class == Pusher2DEnv:
            cost_fn = pusher2d_cost_fn
        elif env_name_or_class == Reacher7Dof:
            cost_fn = reacher7dof_cost_fn
        elif env_name_or_class == HalfCheetah:
            cost_fn = half_cheetah_cost_fn
        else:
            if variant['multitask']:
                env = MultitaskToFlatEnv(env)
            cost_fn = env.cost_fn

    train_dagger(env=env,
                 cost_fn=cost_fn,
                 logdir=logger.get_snapshot_dir(),
                 **variant['dagger_params'])
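The experiment function above only reads three keys from variant; a hedged example of how it might be called, with the 'dagger_params' contents purely illustrative (they mirror the train() keyword arguments seen in the other examples):

variant = dict(
    env_name_or_class='HalfCheetah',    # or an env class such as Pusher2DEnv
    multitask=False,
    dagger_params=dict(                 # forwarded as train_dagger(**...); names are illustrative
        render=False,
        learning_rate=1e-3,
        onpol_iters=10,
        dynamics_iters=60,
        batch_size=512,
    ),
)
experiment(variant)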
Example #4
def train_PG(exp_name='',
             env_name='HalfCheetah',
             n_iter=100, 
             gamma=1.0, 
             min_timesteps_per_batch=1000, 
             max_path_length=None,
             learning_rate=5e-3, 
             reward_to_go=False, 
             animate=True, 
             logdir=None, 
             normalize_advantages=False,
             nn_baseline=False, 
             seed=0,
             # network arguments
             n_layers=1,
             size=32,
             ):

    start = time.time()

    # Configure output directory for logging
    logz.configure_output_dir(logdir)

    # Log experimental parameters
    args = inspect.getargspec(train_PG)[0]
    locals_ = locals()
    params = {k: locals_[k] if k in locals_ else None for k in args}
    logz.save_params(params)

    # Set random seeds
    tf.set_random_seed(seed)
    np.random.seed(seed)

    # Make the gym environment
    env = HalfCheetahEnvNew()
    # env = gym.make("RoboschoolHalfCheetah-v1")

    # Is this env continuous, or discrete?
    discrete = isinstance(env.action_space, gym.spaces.Discrete)

    # Maximum length for episodes
    max_path_length = max_path_length

    # Observation and action sizes
    ob_dim = env.observation_space.shape[0]
    ac_dim = env.action_space.n if discrete else env.action_space.shape[0]

    # Print environment information
    print("Environment name: ",  "HalfCheetah")
    print("Action space is discrete: ", discrete)
    print("Action space dim: ", ac_dim)
    print("Observation space dim: ", ob_dim)
    print("Max_path_length ", max_path_length)



    #========================================================================================#
    # Tensorflow Engineering: Config, Session, Variable initialization
    #========================================================================================#


    tf_config = tf.ConfigProto(inter_op_parallelism_threads=1, intra_op_parallelism_threads=4) 

    sess = tf.Session(config=tf_config)

    sess.__enter__() # equivalent to `with sess:`

    data_buffer_ppo = DataBuffer_general(10000, 4)


    timesteps_per_actorbatch=1000
    max_timesteps = 10000000
    clip_param=0.2
    entcoeff=0.0
    optim_epochs=10
    optim_stepsize=3e-4 
    optim_batchsize=64
    gamma=0.99
    lam=0.95
    schedule='linear'
    callback=None # you can do anything in the callback, since it takes locals(), globals()
    adam_epsilon=1e-5

    policy_nn = MlpPolicy_bc(sess=sess, env=env, hid_size=128, num_hid_layers=2, clip_param=clip_param , entcoeff=entcoeff)
    # policy_nn = MlpPolicy(sess=sess, env=env, hid_size=64, num_hid_layers=2, clip_param=clip_param , entcoeff=entcoeff, adam_epsilon=adam_epsilon)

    tf.global_variables_initializer().run() #pylint: disable=E1101


    # Prepare for rollouts
    # ----------------------------------------

    # seg_gen = traj_segment_generator_old(policy_nn, env, timesteps_per_actorbatch)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=100) # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100) # rolling buffer for episode rewards


    while True:

        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult =  max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        logger.log("********** Iteration %i ************"%iters_so_far)

        data_buffer_ppo.clear()
        seg = traj_segment_generator(policy_nn, env, timesteps_per_actorbatch)
        # seg = seg_gen.__next__()

        add_vtarg_and_adv(seg, gamma, lam)

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg["tdlamret"]
        vpredbefore = seg["vpred"] # predicted value function before update
        atarg = (atarg - atarg.mean()) / atarg.std() # standardized advantage function estimate
        # d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret), shuffle=not policy_nn.recurrent)

        for n in range(len(ob)):
            data_buffer_ppo.add([ob[n], ac[n], atarg[n], tdlamret[n]])
        print("data_buffer_ppo", data_buffer_ppo.size)

        optim_batchsize = optim_batchsize or ob.shape[0]

        if hasattr(policy_nn, "ob_rms"): policy_nn.ob_rms.update(ob) # update running mean/std for policy

        policy_nn.assign_old_eq_new() # set old parameter values to new parameter values

        # logger.log("Optimizing...")
        # logger.log(fmt_row(13, policy_nn.loss_names))

        # Here we do a bunch of optimization epochs over the data
        for _ in range(optim_epochs):
            losses = [] # list of tuples, each of which gives the loss for a minibatch
            for i in range(int(timesteps_per_actorbatch/optim_batchsize)):
                sample_ob_no, sample_ac_na, sample_adv_n, sample_b_n_target = data_buffer_ppo.sample(optim_batchsize)

                newlosses = policy_nn.lossandupdate_ppo(sample_ob_no, sample_ac_na, sample_adv_n, sample_b_n_target, cur_lrmult, optim_stepsize*cur_lrmult)
                losses.append(newlosses)

            # logger.log(fmt_row(13, np.mean(losses, axis=0)))



        # logger.log("Evaluating losses...")
        # losses = []
        # # for batch in d.iterate_once(optim_batchsize):
        # sample_ob_no, sample_ac_na, sample_adv_n, sample_b_n_target = data_buffer_ppo.sample(optim_batchsize)

        # newlosses = policy_nn.compute_losses(sample_ob_no, sample_ac_na, sample_adv_n, sample_b_n_target, cur_lrmult)
        # losses.append(newlosses)
        # meanlosses,_,_ = mpi_moments(losses, axis=0)
        # logger.log(fmt_row(13, meanlosses))
        # for (lossval, name) in zipsame(meanlosses, policy_nn.loss_names):
        #     logger.record_tabular("loss_"+name, lossval)
        # logger.record_tabular("ev_tdlam_before", explained_variance(vpredbefore, tdlamret))
        lrlocal = (seg["ep_lens"], seg["ep_rets"]) # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal) # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)
        # logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        # logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        # logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1
        # logger.record_tabular("EpisodesSoFar", episodes_so_far)
        # logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        # logger.record_tabular("TimeElapsed", time.time() - tstart)
        # if MPI.COMM_WORLD.Get_rank()==0:
        #     logger.dump_tabular()




        # Log diagnostics
        # returns = [path["reward"].sum() for path in paths]
        # ep_lengths = [pathlength(path) for path in paths]

        ep_lengths = seg["ep_lens"]
        returns =  seg["ep_rets"]

        logz.log_tabular("Time", time.time() - start)
        logz.log_tabular("Iteration", iters_so_far)
        logz.log_tabular("AverageReturn", np.mean(returns))
        logz.log_tabular("StdReturn", np.std(returns))
        logz.log_tabular("MaxReturn", np.max(returns))
        logz.log_tabular("MinReturn", np.min(returns))
        logz.log_tabular("EpLenMean", np.mean(ep_lengths))
        logz.log_tabular("EpLenStd", np.std(ep_lengths))
        # logz.log_tabular("TimestepsThisBatch", timesteps_this_batch)
        logz.log_tabular("TimestepsSoFar", timesteps_so_far)
        logz.dump_tabular()
        logz.pickle_tf_vars()
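The PPO loop in this example depends on a helper add_vtarg_and_adv(seg, gamma, lam) that is not shown. Below is a minimal sketch of the GAE(λ) computation such a helper usually performs; the segment field names ("rew", "new", "vpred", "nextvpred") follow the OpenAI Baselines convention and are an assumption here, not something confirmed by the snippet:

import numpy as np

def add_vtarg_and_adv(seg, gamma, lam):
    """Append GAE(lambda) advantages ('adv') and TD(lambda) value targets
    ('tdlamret') to a trajectory segment dict."""
    new = np.append(seg["new"], 0)                     # 1 where an episode started at step t
    vpred = np.append(seg["vpred"], seg["nextvpred"])  # bootstrap with the value of the next state
    T = len(seg["rew"])
    adv = np.zeros(T, dtype=np.float32)
    lastgaelam = 0.0
    for t in reversed(range(T)):
        nonterminal = 1 - new[t + 1]
        delta = seg["rew"][t] + gamma * vpred[t + 1] * nonterminal - vpred[t]
        adv[t] = lastgaelam = delta + gamma * lam * nonterminal * lastgaelam
    seg["adv"] = adv
    seg["tdlamret"] = adv + seg["vpred"]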
Example #5
File: main.py  Project: Snowstu/MBMF
def main():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--env_name', type=str,
                        default='HalfCheetah-v2')  #HalfCheetah-v2
    # Experiment meta-params
    parser.add_argument('--exp_name', type=str, default='mb_mpc')
    parser.add_argument('--seed', type=int, default=3)
    parser.add_argument('--render', action='store_true')
    # Training args
    parser.add_argument('--learning_rate', '-lr', type=float, default=1e-3)
    parser.add_argument('--onpol_iters', '-n', type=int,
                        default=5)  # Aggregation iters 10
    parser.add_argument('--dyn_iters', '-nd', type=int,
                        default=60)  # epochs 60
    parser.add_argument('--batch_size', '-b', type=int, default=512)
    # Data collection
    parser.add_argument('--random_paths', '-r', type=int,
                        default=700)  # random path nums 700
    parser.add_argument('--onpol_paths', '-d', type=int,
                        default=10)  # mpc path nums   10
    parser.add_argument('--ep_len', '-ep', type=int,
                        default=1000)  # 1000   path length   1000
    # Neural network architecture args
    parser.add_argument('--n_layers', '-l', type=int, default=2)
    parser.add_argument('--size', '-s', type=int, default=500)
    # MPC Controller
    parser.add_argument('--mpc_horizon', '-m', type=int,
                        default=15)  # mpc simulation H  10
    parser.add_argument('--simulated_paths', '-sp', type=int,
                        default=10000)  # mpc  candidate  K 100
    args = parser.parse_args()

    print(args)
    # Set seed
    np.random.seed(args.seed)
    tf.set_random_seed(args.seed)

    # Make data directory if it does not already exist

    # Make env
    if args.env_name == 'HalfCheetah-v2':
        env = HalfCheetahEnvNew()
        cost_fn = cheetah_cost_fn

    env_name = args.env_name  # HalfCheetah-v2  My3LineDirect-v1
    cost_fn = cheetah_cost_fn
    env = gym.make(env_name)
    # env.set_goals(45 * 3.14 / 180.0)  # the angle must be converted to radians

    logdir = configure_log_dir(logname=env_name, txt='-train')
    utils.LOG_PATH = logdir

    with open(logdir + '/info.txt', 'wt') as f:
        print('Hello World!\n', file=f)
        print(args, file=f)

    train(
        env=env,
        cost_fn=cost_fn,
        logdir=logdir,
        render=args.render,
        learning_rate=args.learning_rate,
        onpol_iters=args.onpol_iters,
        dynamics_iters=args.dyn_iters,
        batch_size=args.batch_size,
        num_paths_random=args.random_paths,
        num_paths_onpol=args.onpol_paths,
        num_simulated_paths=args.simulated_paths,
        env_horizon=args.ep_len,
        mpc_horizon=args.mpc_horizon,
        n_layers=args.n_layers,
        size=args.size,
        activation='relu',
        output_activation=None,
    )
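configure_log_dir and utils.LOG_PATH come from this project's own utilities and are not shown in the snippet. One plausible sketch of such a helper, assuming it only creates a timestamped directory and returns its path (the exact directory layout is a guess):

import os
import time

def configure_log_dir(logname, txt=''):
    """Create and return a timestamped log directory,
    e.g. data/HalfCheetah-v2-train_01-01-2020_12-00-00 (layout is illustrative)."""
    logdir = os.path.join('data', '{}{}_{}'.format(logname, txt, time.strftime("%d-%m-%Y_%H-%M-%S")))
    if not os.path.exists(logdir):
        os.makedirs(logdir)
    return logdir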
Example #6
def main():

    assert (FLAGS.mpc or FLAGS.ppo) == True

    # Set seed
    np.random.seed(FLAGS.seed)
    tf.set_random_seed(FLAGS.seed)

    # Make data directory if it does not already exist
    if not (os.path.exists('data')):
        os.makedirs('data')
    # logdir = FLAGS.exp_name + '_' + FLAGS.env_name + '_' + time.strftime("%d-%m-%Y_%H-%M-%S")
    logdir = FLAGS.exp_name + '_' + FLAGS.env_name

    logdir = os.path.join('data', logdir)
    if not (os.path.exists(logdir)):
        os.makedirs(logdir)

    # Make env
    if FLAGS.env_name == "HalfCheetah-v1":
        env = HalfCheetahEnvNew()
        cost_fn = cheetah_cost_fn

        # env = gym.make(FLAGS.env_name)
        env.seed(FLAGS.seed)
    else:
        env = gym.make(FLAGS.env_name)
        env.seed(FLAGS.seed)
        cost_fn = None

    train(
        env=env,
        cost_fn=cost_fn,
        logdir=logdir,
        render=FLAGS.render,
        learning_rate=FLAGS.learning_rate,
        onpol_iters=FLAGS.onpol_iters,
        dynamics_iters=FLAGS.dyn_iters,
        batch_size=FLAGS.batch_size,
        num_paths_random=FLAGS.random_paths,
        num_paths_onpol=FLAGS.onpol_paths,
        num_simulated_paths=FLAGS.simulated_paths,
        env_horizon=FLAGS.ep_len,
        mpc_horizon=FLAGS.mpc_horizon,
        n_layers=FLAGS.n_layers,
        size=FLAGS.size,
        activation=tf.nn.relu,
        output_activation=None,
        clip_param=FLAGS.clip_param,
        entcoeff=FLAGS.entcoeff,
        gamma=FLAGS.gamma,
        lam=FLAGS.lam,
        optim_epochs=FLAGS.optim_epochs,
        optim_batchsize=FLAGS.optim_batchsize,
        schedule=FLAGS.schedule,
        bc_lr=FLAGS.bc_lr,
        ppo_lr=FLAGS.ppo_lr,
        timesteps_per_actorbatch=FLAGS.timesteps_per_actorbatch,
        MPC=FLAGS.mpc,
        BEHAVIORAL_CLONING=FLAGS.bc,
        PPO=FLAGS.ppo,
    )
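This variant reads its configuration from TF1-style FLAGS rather than argparse; the flag definitions are not included in the snippet. A partial sketch of how they might be declared with tf.app.flags (defaults are illustrative and simply mirror the argparse examples above):

import tensorflow as tf

flags = tf.app.flags
flags.DEFINE_string('exp_name', 'mpc_bc_ppo', 'Experiment name used in the log directory.')
flags.DEFINE_string('env_name', 'HalfCheetah-v1', 'Gym environment id.')
flags.DEFINE_integer('seed', 3, 'Random seed for numpy and TensorFlow.')
flags.DEFINE_boolean('render', False, 'Render rollouts.')
flags.DEFINE_float('learning_rate', 1e-3, 'Dynamics model learning rate.')
flags.DEFINE_integer('mpc_horizon', 10, 'MPC planning horizon.')
flags.DEFINE_boolean('mpc', False, 'Enable the MPC controller.')
flags.DEFINE_boolean('bc', False, 'Enable behavioral cloning.')
flags.DEFINE_boolean('ppo', False, 'Enable PPO.')
# ... the remaining flags (dyn_iters, batch_size, clip_param, entcoeff, gamma, lam, etc.)
# would follow the same pattern.
FLAGS = flags.FLAGS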
Example #7
def main():

    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--env_name', type=str, default='HalfCheetah-v1')
    # Experiment meta-params
    parser.add_argument('--exp_name', type=str, default='mpc_bc_ppo')
    parser.add_argument('--seed', type=int, default=3)
    parser.add_argument('--render', action='store_true')
    # Model Training args
    parser.add_argument('--learning_rate', '-lr', type=float, default=1e-3)
    parser.add_argument('--onpol_iters', '-n', type=int, default=100)
    parser.add_argument('--dyn_iters', '-nd', type=int, default=260)
    parser.add_argument('--batch_size', '-b', type=int, default=512)

    # BC and PPO Training args
    parser.add_argument('--bc_lr', '-bc_lr', type=float, default=1e-3)
    parser.add_argument('--ppo_lr', '-ppo_lr', type=float, default=1e-4)

    parser.add_argument('--clip_param', '-cp', type=float, default=0.2)
    parser.add_argument('--gamma', '-g', type=float, default=0.99)
    parser.add_argument('--entcoeff', '-ent', type=float, default=0.0)
    parser.add_argument('--lam', type=float, default=0.95)
    parser.add_argument('--optim_epochs', type=int, default=500)
    parser.add_argument('--optim_batchsize', type=int, default=128)
    parser.add_argument('--schedule', type=str, default='constant')
    parser.add_argument('--timesteps_per_actorbatch', '-b2', type=int, default=1000)
    # Data collection
    parser.add_argument('--random_paths', '-r', type=int, default=10)
    parser.add_argument('--onpol_paths', '-d', type=int, default=1)
    parser.add_argument('--simulated_paths', '-sp', type=int, default=400)
    parser.add_argument('--ep_len', '-ep', type=int, default=1000)
    # Neural network architecture args
    parser.add_argument('--n_layers', '-l', type=int, default=2)
    parser.add_argument('--size', '-s', type=int, default=256)
    # MPC Controller
    parser.add_argument('--mpc_horizon', '-m', type=int, default=10)

    parser.add_argument('--mpc', action='store_true')
    parser.add_argument('--bc', action='store_true')
    parser.add_argument('--ppo', action='store_true')

    args = parser.parse_args()

    assert (args.mpc or args.ppo) == True

    # Set seed
    np.random.seed(args.seed)
    tf.set_random_seed(args.seed)

    # Make data directory if it does not already exist
    if not(os.path.exists('data')):
        os.makedirs('data')
    # logdir = args.exp_name + '_' + args.env_name + '_' + time.strftime("%d-%m-%Y_%H-%M-%S")
    logdir = args.exp_name + '_' + args.env_name

    logdir = os.path.join('data', logdir)
    if not(os.path.exists(logdir)):
        os.makedirs(logdir)

    # Make env
    if args.env_name == "HalfCheetah-v1":
        env = HalfCheetahEnvNew()
        cost_fn = cheetah_cost_fn
        
    train(env=env,
          cost_fn=cost_fn,
          logdir=logdir,
          render=args.render,
          learning_rate=args.learning_rate,
          onpol_iters=args.onpol_iters,
          dynamics_iters=args.dyn_iters,
          batch_size=args.batch_size,
          num_paths_random=args.random_paths,
          num_paths_onpol=args.onpol_paths,
          num_simulated_paths=args.simulated_paths,
          env_horizon=args.ep_len,
          mpc_horizon=args.mpc_horizon,
          n_layers=args.n_layers,
          size=args.size,
          activation=tf.nn.relu,
          output_activation=None,
          clip_param=args.clip_param,
          entcoeff=args.entcoeff,
          gamma=args.gamma,
          lam=args.lam,
          optim_epochs=args.optim_epochs,
          optim_batchsize=args.optim_batchsize,
          schedule=args.schedule,
          bc_lr=args.bc_lr,
          ppo_lr=args.ppo_lr,
          timesteps_per_actorbatch=args.timesteps_per_actorbatch,
          MPC=args.mpc,
          BEHAVIORAL_CLONING=args.bc,
          PPO=args.ppo,
          )
Example #8
def train_PG(
             exp_name='',
             env_name='',
             n_iter=100, 
             gamma=1.0, 
             min_timesteps_per_batch=1000, 
             max_path_length=None,
             learning_rate=5e-3, 
             reward_to_go=False, 
             animate=True, 
             logdir=None, 
             normalize_advantages=False,
             nn_baseline=False, 
             seed=0,
             # network arguments
             n_layers=1,
             size=32,

             # mb mpc arguments
             model_learning_rate=1e-3,
             onpol_iters=10,
             dynamics_iters=260,
             batch_size=512,
             num_paths_random=10, 
             num_paths_onpol=10, 
             num_simulated_paths=1000,
             env_horizon=1000, 
             mpc_horizon=10,
             m_n_layers=2,
             m_size=500,
             ):

    start = time.time()

    # Configure output directory for logging
    logz.configure_output_dir(logdir)

    # Log experimental parameters
    args = inspect.getargspec(train_PG)[0]
    locals_ = locals()
    params = {k: locals_[k] if k in locals_ else None for k in args}
    logz.save_params(params)

    # Set random seeds
    tf.set_random_seed(seed)
    np.random.seed(seed)

    # Make the gym environment
    # env = gym.make(env_name)
    env = HalfCheetahEnvNew()
    cost_fn = cheetah_cost_fn
    activation=tf.nn.relu
    output_activation=None

    # Is this env continuous, or discrete?
    discrete = isinstance(env.action_space, gym.spaces.Discrete)

    # Maximum length for episodes
    # max_path_length = max_path_length or env.spec.max_episode_steps
    max_path_length = max_path_length

    # Observation and action sizes
    ob_dim = env.observation_space.shape[0]
    ac_dim = env.action_space.n if discrete else env.action_space.shape[0]

    # Print environment information
    print("-------- env info --------")
    print("Environment name: ", env_name)
    print("Action space is discrete: ", discrete)
    print("Action space dim: ", ac_dim)
    print("Observation space dim: ", ob_dim)
    print("Max_path_length ", max_path_length)




    #========================================================================================#
    # Random data collection
    #========================================================================================#

    random_controller = RandomController(env)
    data_buffer_model = DataBuffer()
    data_buffer_ppo = DataBuffer_general(10000, 4)

    # sample path
    print("collecting random data .....  ")
    paths = sample(env, 
               random_controller, 
               num_paths=num_paths_random, 
               horizon=env_horizon, 
               render=False,
               verbose=False)

    # add into buffer
    for path in paths:
        for n in range(len(path['observations'])):
            data_buffer_model.add(path['observations'][n], path['actions'][n], path['next_observations'][n])

    print("data buffer size: ", data_buffer_model.size)

    normalization = compute_normalization(data_buffer_model)

    #========================================================================================#
    # Tensorflow Engineering: Config, Session, Variable initialization
    #========================================================================================#
    tf_config = tf.ConfigProto() 
    tf_config.allow_soft_placement = True
    tf_config.intra_op_parallelism_threads = 4
    tf_config.inter_op_parallelism_threads = 1
    sess = tf.Session(config=tf_config)

    dyn_model = NNDynamicsModel(env=env, 
                                n_layers=n_layers, 
                                size=size, 
                                activation=activation, 
                                output_activation=output_activation, 
                                normalization=normalization,
                                batch_size=batch_size,
                                iterations=dynamics_iters,
                                learning_rate=learning_rate,
                                sess=sess)

    mpc_controller = MPCcontroller(env=env, 
                                   dyn_model=dyn_model, 
                                   horizon=mpc_horizon, 
                                   cost_fn=cost_fn, 
                                   num_simulated_paths=num_simulated_paths)


    policy_nn = policy_network_ppo(sess, ob_dim, ac_dim, discrete, n_layers, size, learning_rate)

    if nn_baseline:
        value_nn = value_network(sess, ob_dim, n_layers, size, learning_rate)

    sess.__enter__() # equivalent to `with sess:`

    tf.global_variables_initializer().run()


    #========================================================================================#
    # Training Loop
    #========================================================================================#

    total_timesteps = 0

    for itr in range(n_iter):
        print("********** Iteration %i ************"%itr)

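        # NOTE: MPC and PG are assumed to be module-level boolean switches (not
        # defined inside train_PG) that select the MPC controller and the
        # policy-gradient update, respectively.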
        if MPC:
            dyn_model.fit(data_buffer_model)
        returns = []
        costs = []

        # Collect paths until we have enough timesteps
        timesteps_this_batch = 0
        paths = []

        while True:
            # print("data buffer size: ", data_buffer_model.size)
            current_path = {'observations': [], 'actions': [], 'reward': [], 'next_observations':[]}

            ob = env.reset()
            obs, acs, mpc_acs, rewards = [], [], [], []
            animate_this_episode=(len(paths)==0 and (itr % 10 == 0) and animate)
            steps = 0
            return_ = 0
 
            while True:
                # print("steps ", steps)
                if animate_this_episode:
                    env.render()
                    time.sleep(0.05)
                obs.append(ob)

                if MPC:
                    mpc_ac = mpc_controller.get_action(ob)
                else:
                    mpc_ac = random_controller.get_action(ob)

                ac = policy_nn.predict(ob, mpc_ac)

                ac = ac[0]

                if not PG:
                    ac = mpc_ac

                acs.append(ac)
                mpc_acs.append(mpc_ac)

                current_path['observations'].append(ob)

                ob, rew, done, _ = env.step(ac)

                current_path['reward'].append(rew)
                current_path['actions'].append(ac)
                current_path['next_observations'].append(ob)

                return_ += rew
                rewards.append(rew)

                steps += 1
                if done or steps > max_path_length:
                    break


            if MPC:
                # cost & return
                cost = path_cost(cost_fn, current_path)
                costs.append(cost)
                returns.append(return_)
                print("total return: ", return_)
                print("costs: ", cost)

                # add into buffers
                for n in range(len(current_path['observations'])):
                    data_buffer_model.add(current_path['observations'][n], current_path['actions'][n], current_path['next_observations'][n])

            for n in range(len(current_path['observations'])):
                data_buffer_ppo.add(current_path['observations'][n], current_path['actions'][n], current_path['reward'][n], current_path['next_observations'][n])
        
            path = {"observation" : np.array(obs), 
                    "reward" : np.array(rewards), 
                    "action" : np.array(acs),
                    "mpc_action" : np.array(mpc_acs)}



            paths.append(path)
            timesteps_this_batch += pathlength(path)
            # print("timesteps_this_batch", timesteps_this_batch)
            if timesteps_this_batch > min_timesteps_per_batch:
                break
        total_timesteps += timesteps_this_batch


        print("data_buffer_ppo.size:", data_buffer_ppo.size)


        # Build arrays for observation, action for the policy gradient update by concatenating 
        # across paths
        ob_no = np.concatenate([path["observation"] for path in paths])
        ac_na = np.concatenate([path["action"] for path in paths])
        mpc_ac_na = np.concatenate([path["mpc_action"] for path in paths])


        # Computing Q-values
     
        if reward_to_go:
            q_n = []
            for path in paths:
                for t in range(len(path["reward"])):
                    t_ = 0
                    q = 0
                    while t_ < len(path["reward"]):
                        if t_ >= t:
                            q += gamma**(t_-t) * path["reward"][t_]
                        t_ += 1
                    q_n.append(q)
            q_n = np.asarray(q_n)

        else:
            q_n = []
            for path in paths:
                for t in range(len(path["reward"])):
                    t_ = 0
                    q = 0
                    while t_ < len(path["reward"]):
                        q += gamma**t_ * path["reward"][t_]
                        t_ += 1
                    q_n.append(q)
            q_n = np.asarray(q_n)


        # Computing Baselines
        if nn_baseline:

            # b_n = sess.run(baseline_prediction, feed_dict={sy_ob_no :ob_no})
            b_n = value_nn.predict(ob_no)
            b_n = normalize(b_n)
            b_n = denormalize(b_n, np.std(q_n), np.mean(q_n))
            adv_n = q_n - b_n
        else:
            adv_n = q_n.copy()

        # Advantage Normalization
        if normalize_advantages:
            adv_n = normalize(adv_n)

        # Optimizing Neural Network Baseline
        if nn_baseline:
            b_n_target = normalize(q_n)
            value_nn.fit(ob_no, b_n_target)
                # sess.run(baseline_update_op, feed_dict={sy_ob_no :ob_no, sy_baseline_target_n:b_n_target})


        # Performing the Policy Update

        # policy_nn.fit(ob_no, ac_na, adv_n)
        policy_nn.fit(ob_no, ac_na, adv_n, mpc_ac_na)

        # sess.run(update_op, feed_dict={sy_ob_no :ob_no, sy_ac_na:ac_na, sy_adv_n:adv_n})

        # Log diagnostics
        returns = [path["reward"].sum() for path in paths]
        ep_lengths = [pathlength(path) for path in paths]
        logz.log_tabular("Time", time.time() - start)
        logz.log_tabular("Iteration", itr)
        logz.log_tabular("AverageReturn", np.mean(returns))
        logz.log_tabular("StdReturn", np.std(returns))
        logz.log_tabular("MaxReturn", np.max(returns))
        logz.log_tabular("MinReturn", np.min(returns))
        logz.log_tabular("EpLenMean", np.mean(ep_lengths))
        logz.log_tabular("EpLenStd", np.std(ep_lengths))
        logz.log_tabular("TimestepsThisBatch", timesteps_this_batch)
        logz.log_tabular("TimestepsSoFar", total_timesteps)
        logz.dump_tabular()
        logz.pickle_tf_vars()
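The reward-to-go branch above recomputes each discounted sum from scratch, which is quadratic in the path length. An equivalent linear-time version using a single backward pass (a sketch, not part of the original code):

import numpy as np

def discounted_rewards_to_go(rewards, gamma):
    """q[t] = sum_{t' >= t} gamma**(t' - t) * r[t'], computed in one backward pass."""
    q = np.zeros(len(rewards))
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        q[t] = running
    return q

# Drop-in replacement for the reward_to_go branch:
# q_n = np.concatenate([discounted_rewards_to_go(path["reward"], gamma) for path in paths])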