Code example #1
def train(env_id, gpu, num_timesteps, seed, config):
    from ppo.ppo_rl import PPO
    set_global_seeds(seed, gpu)
    env = gym.make(env_id)
    env = bench.Monitor(
        env,
        logger.get_dir() and osp.join(logger.get_dir(), "monitor.json"))
    env.seed(seed)
    gym.logger.setLevel(logging.WARN)
    if hasattr(config, 'wrap_env_fn'):
        env = config.wrap_env_fn(env)
        env.seed(seed)
    ppo_rl = PPO(env,
                 gpu=gpu,
                 policy=config.policy,
                 timesteps_per_batch=config.timesteps_per_batch,
                 clip_param=config.clip_param,
                 entcoeff=config.entcoeff,
                 optim_epochs=config.optim_epochs,
                 optim_stepsize=config.optim_stepsize,
                 optim_batchsize=config.optim_batchsize,
                 gamma=config.gamma,
                 lam=config.lam,
                 max_timesteps=num_timesteps,
                 schedule=config.schedule)
    ppo_rl.run()
    env.close()
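
A rough usage sketch for example #1: the call below builds a hypothetical config namespace whose attribute names mirror the fields the function reads; the env id and all hyperparameter values are assumptions, not part of the original project. Since the namespace defines no wrap_env_fn, the hasattr branch in the function is skipped.

# Hedged sketch only; all values are illustrative.
from types import SimpleNamespace

config = SimpleNamespace(
    policy='mlp',
    timesteps_per_batch=2048,
    clip_param=0.2,
    entcoeff=0.0,
    optim_epochs=10,
    optim_stepsize=3e-4,
    optim_batchsize=64,
    gamma=0.99,
    lam=0.95,
    schedule='linear',
)
train('Hopper-v2', gpu=0, num_timesteps=1_000_000, seed=0, config=config)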
Code example #2
def train(env, gpu, num_timesteps, seed, config, log_dir, load_path):
    from ppo.ppo_rl import PPO
    set_global_seeds(seed, gpu)
    env = bench.Monitor(env,
                        logger.get_dir()
                        and osp.join(logger.get_dir(), "monitor.json"),
                        allow_early_resets=True)
    env.seed(seed)
    gym.logger.setLevel(logging.WARN)
    if hasattr(config, 'wrap_env_fn'):
        env = config.wrap_env_fn(env)
        env.seed(seed)
    ppo_rl = PPO(env,
                 gpu=gpu,
                 policy=config.policy,
                 prob_dist=config.prob_dist,
                 num_hid_layers=config.num_hid_layers,
                 hid_size=config.hid_size,
                 timesteps_per_batch=config.timesteps_per_batch,
                 clip_param=config.clip_param,
                 beta=config.beta,
                 entcoeff=config.entcoeff,
                 optim_epochs=config.optim_epochs,
                 optim_stepsize=config.optim_stepsize,
                 optim_batchsize=config.optim_batchsize,
                 gamma=config.gamma,
                 lam=config.lam,
                 max_timesteps=num_timesteps,
                 schedule=config.schedule,
                 record_video_freq=config.record_video_freq,
                 log_dir=log_dir,
                 load_path=load_path)
    ppo_rl.run()
    env.close()
Code example #3
    def save_model(self, modelfn=None):
        modelfn = modelfn if modelfn else 'checkpoint.pt'
        modelpath = os.path.join(logger.get_dir(), 'models', modelfn)
        os.makedirs(os.path.join(logger.get_dir(), 'models'), exist_ok=True)
        state_dict = {
            'epoch': self._epoch + 1,
            'state_dict': self.ac.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'lr_scheduler': self.lr_scheduler.state_dict(),
        }
        torch.save(state_dict, modelpath)
Code example #4
File: run_atari.py  Project: Baldwin054212/DQfD-1
def train(args):
    total_timesteps = int(args.num_timesteps)
    pre_train_timesteps = int(args.pre_train_timesteps)
    seed = args.seed

    env = make_env(args.env,
                   args.seed,
                   args.max_episode_steps,
                   wrapper_kwargs={'frame_stack': True})
    if args.save_video_interval != 0:
        env = Monitor(env,
                      osp.join(logger.get_dir(), "videos"),
                      video_callable=(lambda ep: ep % 1 == 0),
                      force=True)
    model = dqfd.learn(
        env=env,
        network='cnn',
        checkpoint_path=args.save_path,
        seed=seed,
        total_timesteps=total_timesteps,
        pre_train_timesteps=pre_train_timesteps,
        load_path=args.load_path,
        demo_path=args.demo_path,
    )

    return model, env
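
Examples #4 and #5 are normally driven from a command line; a hypothetical way to call example #4's train() directly is to hand it an argparse-style namespace. The attribute names below follow the fields read above, while the concrete values and the demo path are assumptions.

# Illustrative sketch only; values and paths are assumptions.
from argparse import Namespace

args = Namespace(
    env='PongNoFrameskip-v4',
    seed=0,
    max_episode_steps=10000,
    num_timesteps=1e6,
    pre_train_timesteps=5e4,
    save_video_interval=0,           # skip the video Monitor wrapper
    save_path=None,
    load_path=None,
    demo_path='data/pong_demo.pkl',  # hypothetical demonstration file
)
model, env = train(args)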
Code example #5
File: run_atari.py  Project: Kokkini/DQfD
def train(args):
    total_timesteps = int(args.num_timesteps)
    pre_train_timesteps = int(args.pre_train_timesteps)
    seed = args.seed

    env = make_env(args.env,
                   args.seed,
                   args.max_episode_steps,
                   wrapper_kwargs={
                       'frame_stack': True,
                       'episode_life': True
                   })
    if args.save_video_interval != 0:
        env = Monitor(
            env,
            osp.join(logger.get_dir(), "videos"),
            video_callable=(lambda ep: ep % args.save_video_interval == 0),
            force=True)
    model = dqfd.learn(env=env,
                       network='cnn',
                       checkpoint_path=args.save_path,
                       seed=seed,
                       total_timesteps=total_timesteps,
                       pre_train_timesteps=pre_train_timesteps,
                       load_path=args.load_path,
                       demo_path=args.demo_path,
                       buffer_size=int(args.buffer_size),
                       batch_size=args.batch_size,
                       exploration_fraction=args.exploration_fraction,
                       exploration_final_eps=args.exploration_final_eps,
                       epsilon_schedule=args.epsilon_schedule,
                       lr=args.lr,
                       print_freq=args.print_freq)

    return model, env
Code example #6
    def load_model(self, model_path=None):
        if model_path is None:
            model_path = os.path.join(logger.get_dir(), 'models', 'checkpoint.pt')
        state_dict = torch.load(model_path)
        self._epoch = state_dict['epoch']
        self.ac.load_state_dict(state_dict['state_dict'])
        self.optimizer.load_state_dict(state_dict['optimizer'])
        self.lr_scheduler.load_state_dict(state_dict['lr_scheduler'])
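
Examples #3 and #6 form one checkpoint round trip: save_model() writes the epoch counter, network weights, optimizer state and LR-scheduler state to <logger dir>/models/checkpoint.pt, and load_model() restores exactly those keys (note that the saved epoch is self._epoch + 1, so loading resumes at the next epoch). A minimal pairing sketch, assuming `trainer` is an instance of the class these methods belong to:

# Hypothetical round trip; `trainer` is an assumed instance exposing both methods.
trainer.save_model('checkpoint.pt')   # writes <logger dir>/models/checkpoint.pt
trainer.load_model()                  # restores the same file by default
print('resuming from epoch', trainer._epoch)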
Code example #7
    def save(self, path=None):
        """Save model to a pickle located at `path`"""
        if path is None:
            path = os.path.join(logger.get_dir(), "model.pkl")

        with tempfile.TemporaryDirectory() as td:
            self.save_state(os.path.join(td, "model"))
            arc_name = os.path.join(td, "packed.zip")
            with zipfile.ZipFile(arc_name, 'w') as zipf:
                for root, dirs, files in os.walk(td):
                    for fname in files:
                        file_path = os.path.join(root, fname)
                        if file_path != arc_name:
                            zipf.write(file_path,
                                       os.path.relpath(file_path, td))
            with open(arc_name, "rb") as f:
                model_data = f.read()
        with open(path, "wb") as f:
            cloudpickle.dump((model_data, self._act_params), f)
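
Example #7 pickles a (model_data, act_params) tuple with cloudpickle, where model_data holds the raw bytes of the zipped TensorFlow checkpoint. A matching load path might look like the sketch below; restoring the zipped state into a session is project-specific and omitted here.

# Hedged sketch of reading back the pickle written by example #7.
import cloudpickle

with open(path, "rb") as f:          # `path` as passed to save()
    model_data, act_params = cloudpickle.load(f)
# model_data: bytes of packed.zip; act_params: the saved act-construction parameters.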
Code example #8
File: main.py  Project: jiameij/vae-for-IL
def main(args):
    from ppo1 import mlp_policy
    U.make_session(num_cpu=args.num_cpu).__enter__()
    set_global_seeds(args.seed)
    env = gym.make(args.env_id)
    def policy_fn(name, ob_space, ac_space, reuse=False):
        return mlp_policy.MlpPolicy(name=name, ob_space=ob_space, ac_space=ac_space,
            reuse=reuse, hid_size=64, num_hid_layers=2)
    env = bench.Monitor(env, logger.get_dir() and
        osp.join(logger.get_dir(), "monitor.json"))
    env.seed(args.seed)
    gym.logger.setLevel(logging.WARN)
    task_name = get_task_name(args)
    args.checkpoint_dir = osp.join(args.checkpoint_dir, task_name)
    args.log_dir = osp.join(args.log_dir, task_name)
    dataset = Mujoco_Dset(expert_path=args.expert_path, ret_threshold=args.ret_threshold, traj_limitation=args.traj_limitation)
    pretrained_weight = None
    if (args.pretrained and args.task == 'train') or args.algo == 'bc':
        # Pretrain with behavior cloning
        from gailtf.algo import behavior_clone
        if args.algo == 'bc' and args.task == 'evaluate':
            behavior_clone.evaluate(env, policy_fn, args.load_model_path, stochastic_policy=args.stochastic_policy)
            sys.exit()
        pretrained_weight = behavior_clone.learn(env, policy_fn, dataset,
            max_iters=args.BC_max_iter, pretrained=args.pretrained,
            ckpt_dir=args.checkpoint_dir, log_dir=args.log_dir, task_name=task_name)
        if args.algo == 'bc':
            sys.exit()

    from gailtf.network.adversary import TransitionClassifier
    # discriminator
    discriminator = TransitionClassifier(env, args.adversary_hidden_size, entcoeff=args.adversary_entcoeff)
    if args.algo == 'trpo':
        # Set up for MPI seed
        from mpi4py import MPI
        rank = MPI.COMM_WORLD.Get_rank()
        if rank != 0:
            logger.set_level(logger.DISABLED)
        workerseed = args.seed + 10000 * MPI.COMM_WORLD.Get_rank()
        set_global_seeds(workerseed)
        env.seed(workerseed)
        from gailtf.algo import trpo_mpi
        if args.task == 'train':
            trpo_mpi.learn(env, policy_fn, discriminator, dataset,
                pretrained=args.pretrained, pretrained_weight=pretrained_weight,
                g_step=args.g_step, d_step=args.d_step,
                timesteps_per_batch=1024,
                max_kl=args.max_kl, cg_iters=10, cg_damping=0.1,
                max_timesteps=args.num_timesteps,
                entcoeff=args.policy_entcoeff, gamma=0.995, lam=0.97,
                vf_iters=5, vf_stepsize=1e-3,
                ckpt_dir=args.checkpoint_dir, log_dir=args.log_dir,
                save_per_iter=args.save_per_iter, load_model_path=args.load_model_path,
                task_name=task_name)
        elif args.task == 'evaluate':
            trpo_mpi.evaluate(env, policy_fn, args.load_model_path, timesteps_per_batch=1024,
                number_trajs=10, stochastic_policy=args.stochastic_policy)
        else: raise NotImplementedError
    elif args.algo == 'ppo':
        # Set up for MPI seed
        from mpi4py import MPI
        rank = MPI.COMM_WORLD.Get_rank()
        if rank != 0:
            logger.set_level(logger.DISABLED)
        workerseed = args.seed + 10000 * MPI.COMM_WORLD.Get_rank()
        set_global_seeds(workerseed)
        env.seed(workerseed)
        from gailtf.algo import ppo_mpi
        if args.task == 'train':
            ppo_mpi.learn(env, policy_fn, discriminator, dataset,
                # pretrained=args.pretrained, pretrained_weight=pretrained_weight,
                timesteps_per_batch=1024,
                g_step=args.g_step, d_step=args.d_step,
                # max_kl=args.max_kl, cg_iters=10, cg_damping=0.1,
                clip_param=0.2, entcoeff=args.policy_entcoeff,
                max_timesteps=args.num_timesteps,
                gamma=0.99, lam=0.95,
                # vf_iters=5, vf_stepsize=1e-3,
                optim_epochs=10, optim_stepsize=3e-4, optim_batchsize=64,
                d_stepsize=3e-4,
                schedule='linear', ckpt_dir=args.checkpoint_dir,
                save_per_iter=100, task=args.task,
                sample_stochastic=args.stochastic_policy,
                load_model_path=args.load_model_path,
                task_name=task_name)
        elif args.task == 'evaluate':
            ppo_mpi.evaluate(env, policy_fn, args.load_model_path, timesteps_per_batch=1024,
                              number_trajs=10, stochastic_policy=args.stochastic_policy)
        else:
            raise NotImplementedError
    else: raise NotImplementedError

    env.close()
Code example #9
File: training.py  Project: huschen/walk_prosthetics
def train(env,
          eval_env,
          agent,
          render=False,
          render_eval=False,
          sanity_run=False,
          nb_epochs=500,
          nb_epoch_cycles=20,
          nb_rollout_steps=100,
          nb_train_steps=50,
          param_noise_adaption_interval=50,
          hist_files=None,
          start_ckpt=None,
          demo_files=None):

    rank = MPI.COMM_WORLD.Get_rank()
    mpi_size = MPI.COMM_WORLD.Get_size()
    if rank == 0:
        logdir = logger.get_dir()
    else:
        logdir = None

    memory = agent.memory
    batch_size = agent.batch_size

    with tf_util.single_threaded_session() as sess:
        # Prepare everything.
        agent.initialize(sess, start_ckpt=start_ckpt)
        sess.graph.finalize()
        agent.reset()
        dbg_tf_init(sess, agent.dbg_vars)

        total_nb_train = 0
        total_nb_rollout = 0
        total_nb_eval = 0

        # pre-train demo and critic_step
        # train_params: (nb_steps, lr_scale)
        total_nb_train = pretrain_demo(agent,
                                       env,
                                       demo_files,
                                       total_nb_train,
                                       train_params=[(100, 1.0)],
                                       start_ckpt=start_ckpt)
        load_history(agent, env, hist_files)

        # main training
        obs = env.reset()
        reset = False
        episode_step = 0
        last_episode_step = 0

        for i_epoch in range(nb_epochs):
            t_epoch_start = time.time()
            logger.info('\n%s epoch %d starts:' %
                        (datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
                         i_epoch))
            for i_cycle in range(nb_epoch_cycles):
                logger.info(
                    '\n%s cycles_%d of epoch_%d' %
                    (datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f'),
                     i_cycle, i_epoch))

                # rollout
                rcd_obs, rcd_action, rcd_r, rcd_new_obs, rcd_done = [], [], [], [], []
                if not sanity_run and mpi_size == 1 and last_episode_step != 0:
                    # todo: use mpi_max(last_episode_step)
                    # dynamically set nb_rollout_steps
                    nb_rollout_steps = max(last_episode_step * 4, batch_size)
                logger.info(
                    '[%d, %d] rollout for %d steps.' %
                    (total_nb_rollout, memory.nb_entries, nb_rollout_steps))
                t_rollout_start = time.time()

                for i_rollout in range(nb_rollout_steps):
                    rollout_log = i_cycle == 0
                    # 50% param_noise, 40% action_noise
                    action, q = agent.pi(obs,
                                         total_nb_rollout,
                                         compute_Q=True,
                                         rollout_log=rollout_log,
                                         apply_param_noise=i_rollout % 10 < 5,
                                         apply_action_noise=i_rollout % 10 > 5)
                    assert action.shape == env.action_space.shape
                    new_obs, r, done, reset, info = env.step(action)

                    if rank == 0 and render:
                        env.render()

                    episode_step += 1
                    total_nb_rollout += 1

                    if rollout_log:
                        summary_list = [('rollout/%s' % tp, info[tp])
                                        for tp in ['rwd_walk', 'rwd_total']]
                        tp = 'rwd_agent'
                        summary_list += [
                            ('rollout/%s_x%d' % (tp, info['rf_agent']),
                             info[tp] * info['rf_agent'])
                        ]
                        summary_list += [('rollout/q', q)]
                        if r != 0:
                            summary_list += [('rollout/q_div_r', q / r)]
                        agent.add_list_summary(summary_list, total_nb_rollout)

                    # store at the end of cycle to speed up MPI rollout
                    # agent.store_transition(obs, action, r, new_obs, done)
                    rcd_obs.append(obs)
                    rcd_action.append(action)
                    rcd_r.append(r)
                    rcd_new_obs.append(new_obs)
                    rcd_done.append(done)

                    obs = new_obs
                    if reset:
                        # Episode done.
                        last_episode_step = episode_step
                        episode_step = 0

                        agent.reset()
                        obs = env.reset()

                agent.store_multrans(memory, rcd_obs, rcd_action, rcd_r,
                                     rcd_new_obs, rcd_done)

                t_train_start = time.time()
                steps_per_second = float(nb_rollout_steps) / (t_train_start -
                                                              t_rollout_start)
                agent.add_list_summary(
                    [('rollout/steps_per_second', steps_per_second)],
                    total_nb_rollout)

                # Train.
                if not sanity_run:
                    # dynamically set nb_train_steps
                    if memory.nb_entries > batch_size * 20:
                        # using 1% of data for training every step?
                        nb_train_steps = max(
                            int(memory.nb_entries * 0.01 / batch_size), 1)
                    else:
                        nb_train_steps = 0
                logger.info('[%d] training for %d steps.' %
                            (total_nb_train, nb_train_steps))
                for _ in range(nb_train_steps):
                    # Adapt param noise, if necessary.
                    if memory.nb_entries >= batch_size and total_nb_train % param_noise_adaption_interval == 0:
                        agent.adapt_param_noise(total_nb_train)

                    agent.train_main(total_nb_train)
                    agent.update_target_net()
                    total_nb_train += 1

                if i_epoch == 0 and i_cycle < 5:
                    rollout_duration = t_train_start - t_rollout_start
                    train_duration = time.time() - t_train_start
                    logger.info(
                        'rollout_time(%d) = %.3fs, train_time(%d) = %.3fs' %
                        (nb_rollout_steps, rollout_duration, nb_train_steps,
                         train_duration))
                    logger.info(
                        'rollout_speed=%.3fs/step, train_speed = %.3fs/step' %
                        (np.divide(rollout_duration, nb_rollout_steps),
                         np.divide(train_duration, nb_train_steps)))

            logger.info('')
            mpi_size = MPI.COMM_WORLD.Get_size()
            # Log stats.
            stats = agent.get_stats(memory)
            combined_stats = stats.copy()

            def as_scalar(x):
                if isinstance(x, np.ndarray):
                    assert x.size == 1
                    return x[0]
                elif np.isscalar(x):
                    return x
                else:
                    raise ValueError('expected scalar, got %s' % x)

            combined_stats_sums = MPI.COMM_WORLD.allreduce(
                np.array([as_scalar(x) for x in combined_stats.values()]))
            combined_stats = {
                k: v / mpi_size
                for (k, v) in zip(combined_stats.keys(), combined_stats_sums)
            }

            # exclude logging zobs_dbg_%d, zobs_dbg_%d_normalized
            summary_list = [(key, combined_stats[key])
                            for key, v in combined_stats.items()
                            if 'dbg' not in key]
            agent.add_list_summary(summary_list, i_epoch)

            # only print out train stats for epoch_0 for sanity check
            if i_epoch > 0:
                combined_stats = {}

            # Evaluation and statistics.
            if eval_env is not None:
                logger.info('[%d, %d] run evaluation' %
                            (i_epoch, total_nb_eval))
                total_nb_eval = eval_episode(eval_env, render_eval, agent,
                                             combined_stats, total_nb_eval)

            logger.info('epoch %d duration: %.2f mins' %
                        (i_epoch, (time.time() - t_epoch_start) / 60))
            for key in sorted(combined_stats.keys()):
                logger.record_tabular(key, combined_stats[key])
            logger.dump_tabular()
            logger.info('')

            if rank == 0:
                agent.store_ckpt(os.path.join(logdir, '%s.ckpt' % 'ddpg'),
                                 i_epoch)
Code example #10
def run(mode, render, render_eval, verbose_eval, sanity_run, env_kwargs,
        model_kwargs, train_kwargs):
    if sanity_run:
        # Mode to sanity check the basic code.
        # Fixed seed and logging dir.
        # Dynamic setting of nb_rollout_steps and nb_train_steps in training.train() is disabled.
        print('SANITY CHECK MODE!!!')

    # Configure MPI, logging, random seeds, etc.
    mpi_rank = MPI.COMM_WORLD.Get_rank()
    mpi_size = MPI.COMM_WORLD.Get_size()

    if mpi_rank == 0:
        logger.configure(dir='logs' if sanity_run else datetime.datetime.now().
                         strftime("train_%m%d_%H%M"))
        logdir = logger.get_dir()
    else:
        logger.set_level(logger.DISABLED)
        logdir = None
    logdir = MPI.COMM_WORLD.bcast(logdir, root=0)

    start_time = time.time()
    # fixed seed when running sanity check, same seed hourly for training.
    seed = 1000000 * mpi_rank
    seed += int(start_time) // 3600 if not sanity_run else 0

    seed_list = MPI.COMM_WORLD.gather(seed, root=0)
    logger.info('mpi_size {}: seeds={}, logdir={}'.format(
        mpi_size, seed_list, logger.get_dir()))

    # Create envs.
    envs = []
    if mode in [MODE_TRAIN]:
        train_env = cust_env.ProsEnvMon(
            visualize=render,
            seed=seed,
            fn_step=None,
            fn_epis=logdir and os.path.join(logdir, '%d' % mpi_rank),
            reset_dflt_interval=2,
            **env_kwargs)
        logger.info('action, observation space:', train_env.action_space.shape,
                    train_env.observation_space.shape)
        envs.append(train_env)
    else:
        train_env = None

    # Always run eval_env: either in evaluation mode during MODE_TRAIN, or in MODE_SAMPLE / MODE_TEST.
    # Reset to random states (reset_dflt_interval=0) in MODE_SAMPLE,
    # reset to the default state (reset_dflt_interval=1) in evaluation of MODE_TRAIN or in MODE_TEST.
    reset_dflt_interval = 0 if mode in [MODE_SAMPLE] else 1
    eval_env = cust_env.ProsEnvMon(
        visualize=render_eval,
        seed=seed,
        fn_step=logdir and os.path.join(logdir, 'eval_step_%d.csv' % mpi_rank),
        fn_epis=logdir and os.path.join(logdir, 'eval_%d' % mpi_rank),
        reset_dflt_interval=reset_dflt_interval,
        verbose=verbose_eval,
        **env_kwargs)
    envs.append(eval_env)

    # Create DDPG agent
    tf.reset_default_graph()
    set_global_seeds(seed)
    assert (eval_env is not None), 'Empty Eval Environment!'

    action_range = (min(eval_env.action_space.low),
                    max(eval_env.action_space.high))
    logger.info('\naction_range', action_range)
    nb_demo_kine, nb_key_states = eval_env.obs_cust_params
    agent = ddpg.DDPG(eval_env.observation_space.shape,
                      eval_env.action_space.shape,
                      nb_demo_kine,
                      nb_key_states,
                      action_range=action_range,
                      save_ckpt=mpi_rank == 0,
                      **model_kwargs)
    logger.debug('Using agent with the following configuration:')
    logger.debug(str(agent.__dict__.items()))

    # Set up agent mimic reward interface, for environment
    for env in envs:
        env.set_agent_intf_fp(agent.get_mimic_rwd)

    # Run..
    logger.info('\nEnv params:', env_kwargs)
    logger.info('Model params:', model_kwargs)
    if mode == MODE_TRAIN:
        logger.info('Start training', train_kwargs)
        training.train(train_env,
                       eval_env,
                       agent,
                       render=render,
                       render_eval=render_eval,
                       sanity_run=sanity_run,
                       **train_kwargs)

    elif mode == MODE_SAMPLE:
        sampling.sample(eval_env, agent, render=render_eval, **train_kwargs)
    else:
        training.test(eval_env, agent, render_eval=render_eval, **train_kwargs)

    # Close up.
    if train_env:
        train_env.close()
    if eval_env:
        eval_env.close()

    mpi_complete(start_time, mpi_rank, mpi_size, non_blocking_mpi=True)
Code example #11
    def __init__(self,
                 observation_shape,
                 action_shape,
                 nb_demo_kine,
                 nb_key_states,
                 batch_size=128,
                 noise_type='',
                 actor=None,
                 critic=None,
                 layer_norm=True,
                 observation_range=(-5., 5.),
                 action_range=(-1., 1.),
                 return_range=(-np.inf, np.inf),
                 normalize_returns=False,
                 normalize_observations=True,
                 reward_scale=1.,
                 clip_norm=None,
                 demo_l2_reg=0.,
                 critic_l2_reg=0.,
                 actor_lr=1e-4,
                 critic_lr=1e-3,
                 demo_lr=5e-3,
                 gamma=0.99,
                 tau=0.001,
                 enable_popart=False,
                 save_ckpt=True):

        # Noise
        nb_actions = action_shape[-1]
        param_noise, action_noise = process_noise_type(noise_type, nb_actions)

        logger.info('param_noise', param_noise)
        logger.info('action_noise', action_noise)

        # States recording
        self.memory = Memory(limit=int(2e5),
                             action_shape=action_shape,
                             observation_shape=observation_shape)

        # Models
        self.nb_demo_kine = nb_demo_kine
        self.actor = actor or Actor(
            nb_actions, nb_demo_kine, layer_norm=layer_norm)
        self.nb_key_states = nb_key_states
        self.critic = critic or Critic(nb_key_states, layer_norm=layer_norm)
        self.nb_obs_org = nb_key_states

        # Inputs.
        self.obs0 = tf.placeholder(tf.float32,
                                   shape=(None, ) + observation_shape,
                                   name='obs0')
        self.obs1 = tf.placeholder(tf.float32,
                                   shape=(None, ) + observation_shape,
                                   name='obs1')
        self.terminals1 = tf.placeholder(tf.float32,
                                         shape=(None, 1),
                                         name='terminals1')
        self.rewards = tf.placeholder(tf.float32,
                                      shape=(None, 1),
                                      name='rewards')
        self.actions = tf.placeholder(tf.float32,
                                      shape=(None, ) + action_shape,
                                      name='actions')
        # self.critic_target_Q is fed with the values computed by self.target_Q_obs0
        self.critic_target_Q = tf.placeholder(tf.float32,
                                              shape=(None, 1),
                                              name='critic_target_Q')
        self.param_noise_stddev = tf.placeholder(tf.float32,
                                                 shape=(),
                                                 name='param_noise_stddev')

        # change in observations
        self.obs_delta_kine = (self.obs1 - self.obs0)[:, :self.nb_demo_kine]
        self.obs_delta_kstates = (self.obs1 -
                                  self.obs0)[:, :self.nb_key_states]

        # Parameters.
        self.gamma = gamma
        self.tau = tau
        self.normalize_observations = normalize_observations
        self.normalize_returns = normalize_returns
        self.action_noise = action_noise
        self.param_noise = param_noise
        self.action_range = action_range
        self.return_range = return_range
        self.observation_range = observation_range

        self.actor_lr = actor_lr
        self.critic_lr = critic_lr
        self.demo_lr = demo_lr
        self.clip_norm = clip_norm
        self.enable_popart = enable_popart
        self.reward_scale = reward_scale
        self.batch_size = batch_size
        self.stats_sample = None
        self.critic_l2_reg = critic_l2_reg
        self.demo_l2_reg = demo_l2_reg

        # Observation normalization.
        if self.normalize_observations:
            with tf.variable_scope('obs_rms'):
                self.obs_rms = RunningMeanStd(shape=observation_shape)
        else:
            self.obs_rms = None

        self.normalized_obs0 = tf.clip_by_value(
            obs_norm_partial(self.obs0, self.obs_rms, self.nb_obs_org),
            self.observation_range[0], self.observation_range[1])
        normalized_obs1 = tf.clip_by_value(
            obs_norm_partial(self.obs1, self.obs_rms, self.nb_obs_org),
            self.observation_range[0], self.observation_range[1])

        # Return normalization.
        if self.normalize_returns:
            with tf.variable_scope('ret_rms'):
                self.ret_rms = RunningMeanStd()
        else:
            self.ret_rms = None

        # Create target networks.
        target_actor = copy(self.actor)
        target_actor.name = 'target_actor'
        self.target_actor = target_actor
        target_critic = copy(self.critic)
        target_critic.name = 'target_critic'
        self.target_critic = target_critic

        # Create networks and core TF parts that are shared across set-up parts.
        # the actor output is in [0, 1]; it needs to be normalised to [-1, 1] before being fed into the critic
        self.actor_tf, self.demo_aprx = self.actor(self.normalized_obs0)

        # critic loss
        # normalized_critic_tf, pred_rwd, pred_obs_delta: critic_loss
        self.normalized_critic_tf, self.pred_rwd, self.pred_obs_delta = self.critic(
            self.normalized_obs0, act_norm(self.actions))
        # self.critic_tf: only in logging [reference_Q_mean/std]
        self.critic_tf = ret_denormalize(
            tf.clip_by_value(self.normalized_critic_tf, self.return_range[0],
                             self.return_range[1]), self.ret_rms)

        # actor loss
        normalized_critic_with_actor_tf = self.critic(self.normalized_obs0,
                                                      act_norm(self.actor_tf),
                                                      reuse=True)[0]
        # self.critic_with_actor_tf: actor loss, and logging [reference_Q_tf_mean/std]
        self.critic_with_actor_tf = ret_denormalize(
            tf.clip_by_value(normalized_critic_with_actor_tf,
                             self.return_range[0], self.return_range[1]),
            self.ret_rms)

        # target Q
        self.target_action = tf.clip_by_value(
            target_actor(normalized_obs1)[0], self.action_range[0],
            self.action_range[1])
        self.target_Q_obs1 = ret_denormalize(
            target_critic(normalized_obs1, act_norm(self.target_action))[0],
            self.ret_rms)
        self.target_Q_obs0 = self.rewards + (
            1. - self.terminals1) * gamma * self.target_Q_obs1

        # Set up parts.
        if self.param_noise is not None:
            self.setup_param_noise(self.normalized_obs0)

        self.setup_actor_optimizer()
        self.setup_critic_optimizer()
        if self.normalize_returns and self.enable_popart:
            self.setup_popart()
        self.setup_stats()
        self.setup_target_network_updates()
        self.dbg_vars = self.actor.dbg_vars + self.critic.dbg_vars

        self.sess = None
        # Set up checkpoint saver
        self.save_ckpt = save_ckpt
        if save_ckpt:
            self.saver = tf.train.Saver(tf.global_variables(), max_to_keep=20)
        else:
            # saver for loading ckpt
            self.saver = tf.train.Saver()

        self.main_summaries = tf.summary.merge_all()
        logdir = logger.get_dir()
        if logdir:
            self.train_writer = tf.summary.FileWriter(
                os.path.join(logdir, 'tb'), tf.get_default_graph())
        else:
            self.train_writer = None
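
The target-Q wiring near the end of example #11 (self.target_Q_obs0) is the usual DDPG bootstrap target, r + (1 - terminal) * gamma * Q_target(s', pi_target(s')). A tiny NumPy sketch of the same arithmetic, with hypothetical array names:

# Illustrative only; mirrors the self.target_Q_obs0 expression above.
import numpy as np

gamma = 0.99
rewards    = np.array([[1.0], [0.5]])
terminals1 = np.array([[0.0], [1.0]])   # 1.0 where the episode ended
q_obs1     = np.array([[3.0], [2.0]])   # target critic's Q(s', pi_target(s'))
target_q_obs0 = rewards + (1.0 - terminals1) * gamma * q_obs1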