Example #1
def simulate(path, color_output="color_state.gif", object_output="object_state.gif"):
    sess = tf.Session()
    graph = restore_tf_graph(sess, path)
    env = GoalGridWorld()

    state = env.reset(train=False)
    running = True
    count = 0
    color_images, object_images = [], []
    
    while running:
        a, _ = sess.run([graph['pi'], graph['v']], feed_dict={graph['x']: state.reshape(1,-1)})
        state, reward, done, _ = env.step(a[0])

        color_images = render(env.state, color_images)
        object_images = render(env.object_state, object_images)
        running = not done

        count += 1 
        if count > 100:
            break

    save_gif(color_images, path=color_output)
    save_gif(object_images, path=object_output)
    print("____________________________")
    print("Target: {}".format(env.target))
    print("Reward: {}".format(reward))
    print("____________________________")
Example #2
def load_policy(fpath, itr='last', deterministic=False):

    # handle which epoch to load from
    if itr=='last':
        saves = [int(x[11:]) for x in os.listdir(fpath) if 'simple_save' in x and len(x)>11]
        itr = '%d'%max(saves) if len(saves) > 0 else ''
    else:
        itr = '%d'%itr

    # load the things!
    sess = tf.Session()
    model = restore_tf_graph(sess, osp.join(fpath, 'simple_save'+itr))

    # get the correct op for executing actions
    if deterministic and 'mu' in model.keys():
        # 'deterministic' is only a valid option for SAC policies
        print('Using deterministic action op.')
        action_op = model['mu']
    else:
        print('Using default action op.')
        action_op = model['pi']

    # make function for producing an action given a single state
    get_action = lambda x : sess.run(action_op, feed_dict={model['x']: x[None,:]})[0]

    # try to load environment from save
    # (sometimes this will fail because the environment could not be pickled)
    try:
        state = joblib.load(osp.join(fpath, 'vars'+itr+'.pkl'))
        env = state['env']
    except:
        env = None

    return env, get_action
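A minimal usage sketch for this loader, mirroring Spinning Up's test-time rollout pattern; the run directory is hypothetical and the loop assumes the pickled environment was recovered successfully:

env, get_action = load_policy('./data/ppo_run/ppo_run_s0')  # hypothetical path
assert env is not None, "environment could not be unpickled; construct it manually"

o, ep_ret, ep_len = env.reset(), 0, 0
for _ in range(1000):
    o, r, d, _ = env.step(get_action(o))
    ep_ret += r
    ep_len += 1
    if d:
        print('EpRet %.3f \t EpLen %d' % (ep_ret, ep_len))
        o, ep_ret, ep_len = env.reset(), 0, 0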
Example #3
def load_policy(model_path, itr='last'):
    # handle which epoch to load from
    if itr == 'last':
        saves = [
            int(x[11:]) for x in os.listdir(model_path)
            if 'simple_save' in x and len(x) > 11
        ]
        itr = '%d' % max(saves) if len(saves) > 0 else ''
    else:
        itr = '%d' % itr

    # load the things!
    sess = tf.Session()
    model = restore_tf_graph(sess, osp.join(model_path, 'simple_save' + itr))

    # get the correct op for executing actions
    pi = model['pi']
    v = model['v']

    # make function for producing an action given a single state
    get_probs = lambda x, y: sess.run(
        pi,
        feed_dict={
            model['x']: x.reshape(-1, MAX_QUEUE_SIZE * JOB_FEATURES),
            model['mask']: y.reshape(-1, MAX_QUEUE_SIZE)
        })
    get_v = lambda x: sess.run(
        v,
        feed_dict={model['x']: x.reshape(-1, MAX_QUEUE_SIZE * JOB_FEATURES)})
    return get_probs, get_v
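A hedged usage sketch for the two returned callables; the checkpoint path is hypothetical, and MAX_QUEUE_SIZE / JOB_FEATURES are assumed to be module-level constants of the surrounding scheduler code:

import numpy as np

get_probs, get_v = load_policy('./trained_models/ppo_hpc', itr='last')  # hypothetical path

obs = np.zeros(MAX_QUEUE_SIZE * JOB_FEATURES, dtype=np.float32)  # flat observation
mask = np.ones(MAX_QUEUE_SIZE, dtype=np.float32)                 # all queue slots selectable

pi_out = get_probs(obs, mask)  # whatever the saved 'pi' op yields (e.g. a sampled action)
value = get_v(obs)             # state-value estimate for the observation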
Example #4
def load_policy(fpath, itr='last', deterministic=False):
    # handle which epoch to load from
    if itr == 'last':
        saves = [
            int(x[11:]) for x in os.listdir(fpath)
            if 'simple_save' in x and len(x) > 11
        ]
        itr = '%d' % max(saves) if len(saves) > 0 else ''
    else:
        itr = '%d' % itr

    # load the things!
    sess = tf.Session()
    model = restore_tf_graph(sess, osp.join(fpath, 'simple_save' + itr))

    # get the correct op for executing actions
    if deterministic and 'mu' in model.keys():
        # 'deterministic' is only a valid option for SAC policies
        print('Using deterministic action op.')
        action_op = model['mu']
    else:
        print('Using default action op.')
        action_op = model['pi']

    # make function for producing an action given a single state
    get_action = lambda x: \
        sess.run(action_op, feed_dict={model['x']: x[None, :]})[0]

    return get_action
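Unlike Example #2, this variant returns only the action function, so the caller must construct the environment separately; a minimal sketch with a hypothetical run directory and Gym environment id:

import gym

get_action = load_policy('./data/sac_run/sac_run_s0', deterministic=True)  # hypothetical path
env = gym.make('HalfCheetah-v2')  # hypothetical environment id

o, d, ep_ret = env.reset(), False, 0
while not d:
    o, r, d, _ = env.step(get_action(o))
    ep_ret += r
print('EpRet:', ep_ret)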
Example #5
def policy_loader(model_path, itr='last'):
    if itr == 'last':
        saves = [
            int(x[8:]) for x in os.listdir(model_path)
            if 'tf1_save' in x and len(x) > 8
        ]
        itr = '%d' % max(saves) if len(saves) > 0 else ''
    else:
        itr = '%d' % itr
    sess = tf.Session()
    model = restore_tf_graph(sess, osp.join(model_path, 'tf1_save' + itr))
    pi = model['pi']
    v = model['v']
    out = model['out']
    get_out = lambda x, y: sess.run(
        out,
        feed_dict={
            model['x']: x.reshape(-1, MAX_QUEUE_SIZE * TASK_FEATURES),
            model['mask']: y.reshape(-1, MAX_QUEUE_SIZE)
        })
    get_probs = lambda x, y: sess.run(
        pi,
        feed_dict={
            model['x']: x.reshape(-1, MAX_QUEUE_SIZE * TASK_FEATURES),
            model['mask']: y.reshape(-1, MAX_QUEUE_SIZE)
        })
    get_v = lambda x: sess.run(
        v,
        feed_dict={model['x']: x.reshape(-1, MAX_QUEUE_SIZE * TASK_FEATURES)})
    return get_probs, get_out
Example #6
def load_policy(fpath=None,
                itr='last',
                deterministic=False,
                hidden_sizes=[64, 64],
                activation=tf.nn.tanh,
                output_activation=None,
                action_space=None):

    # handle which epoch to load from
    if itr == 'last':
        saves = [
            int(x[11:]) for x in os.listdir(fpath)
            if 'simple_save' in x and len(x) > 11
        ]
        itr = '%d' % max(saves) if len(saves) > 0 else ''
    else:
        itr = '%d' % itr

    # load the things!
    sess = tf.Session()
    model = restore_tf_graph(sess, osp.join(fpath, 'simple_save' + itr))

    if deterministic and 'mu' in model.keys():
        print('Using deterministic action op.')
        mu = model['mu']
    else:
        print('Using default action op.')
        mu = model['pi']

    x = model['x']
    a = model['a']
    act_dim = a.shape.as_list()[-1]
    act_limit = action_space.high
    # saver = tf.train.Saver()
    with tf.variable_scope('main'):
        # with tf.variable_scope('pi'):
        #     pi = act_limit * core.mlp(x, list(hidden_sizes)+[act_dim], activation, output_activation)
        # todo: load pi parameters from model after tf.initialize
        # pi = model['pi']
        # saver.restore(sess, osp.join(fpath, 'simple_save'+itr+'/variables'))

        with tf.variable_scope('q1'):
            q1 = tf.squeeze(core.mlp(tf.concat([x, a], axis=-1),
                                     list(hidden_sizes) + [1], activation,
                                     None),
                            axis=1)
        with tf.variable_scope('q2'):
            q2 = tf.squeeze(core.mlp(tf.concat([x, a], axis=-1),
                                     list(hidden_sizes) + [1], activation,
                                     None),
                            axis=1)
        with tf.variable_scope('q1', reuse=True):
            q1_pi = tf.squeeze(core.mlp(tf.concat([x, model['pi']], axis=-1),
                                        list(hidden_sizes) + [1], activation,
                                        None),
                               axis=1)
    return sess, model['x'], model['a'], model['pi'], q1, q2, q1_pi
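A cautious usage sketch for the handles returned above. Only model['pi'] comes from the restored graph; the freshly built q1/q2/q1_pi networks are new variables whose weights are not loaded here (the commented todo hints at this), so they are initialized below before being evaluated. The path and the in-scope env object are assumptions:

import numpy as np

sess, x, a, pi, q1, q2, q1_pi = load_policy(
    fpath='./data/sac_run/sac_run_s0',   # hypothetical path
    action_space=env.action_space)       # assumes an env object is in scope

# Initialize only the variables the restore did not cover (the new Q heads),
# leaving the loaded policy weights untouched.
uninit_names = set(sess.run(tf.report_uninitialized_variables()))
uninit_vars = [v for v in tf.global_variables()
               if v.name.split(':')[0].encode() in uninit_names]
sess.run(tf.variables_initializer(uninit_vars))

obs_batch = np.zeros((32, x.shape.as_list()[-1]), dtype=np.float32)
actions = sess.run(pi, feed_dict={x: obs_batch})        # restored policy head
q1_pi_vals = sess.run(q1_pi, feed_dict={x: obs_batch})  # randomly initialized Q head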
Example #7
def load_model(model_dir, model_save_name):

    sess = tf.compat.v1.Session(config=tf_config)

    model = restore_tf_graph(sess=sess,
                             fpath=os.path.join(model_dir, model_save_name))
    config = load_json_obj(os.path.join(model_dir, 'config'))

    if config['rl_params']['env_type'] == 'discrete':
        if 'sim' in config['rl_params']['platform']:
            from braille_rl.envs.sim.disc_sim_braille_env.mockKBGymEnv import mockKBGymEnv as disc_mockKBGymEnv
            env = disc_mockKBGymEnv(
                mode=config['rl_params']['env_mode'],
                max_steps=config['rl_params']['max_ep_len'])
        elif 'robot' in config['rl_params']['platform']:
            from braille_rl.envs.robot.disc_ur5_braille_env.ur5GymEnv import UR5GymEnv as disc_UR5GymEnv
            env = disc_UR5GymEnv(mode=config['rl_params']['env_mode'],
                                 max_steps=config['rl_params']['max_ep_len'])

    elif config['rl_params']['env_type'] == 'continuous':
        if 'sim' in config['rl_params']['platform']:
            from braille_rl.envs.sim.cont_sim_braille_env.mockKBGymEnv import mockKBGymEnv as cont_mockKBGymEnv
            env = cont_mockKBGymEnv(
                mode=config['rl_params']['env_mode'],
                max_steps=config['rl_params']['max_ep_len'])
        elif 'robot' in config['rl_params']['platform']:
            from braille_rl.envs.robot.cont_ur5_braille_env.ur5GymEnv import UR5GymEnv as cont_UR5GymEnv
            env = cont_UR5GymEnv(mode=config['rl_params']['env_mode'],
                                 max_steps=config['rl_params']['max_ep_len'])

    print('Config: ')
    print_sorted_dict(config)
    print('')
    print('')

    # open a file, where you stored the pickled data
    file = open(os.path.join(model_dir, 'vars' + '.pkl'), 'rb')
    saved_state = joblib.load(file)
    file.close()

    print('Resume State: ')
    print_sorted_dict(saved_state)
    print('')
    print('')

    return (sess, model, config['logger_kwargs'], config['rl_params'],
            config['network_params'], env, saved_state)
Example #8
def load_policy(fpath, itr='last', deterministic=False, act_high=1):

    # handle which epoch to load from
    if itr == 'last':
        saves = [
            int(x[11:]) for x in os.listdir(fpath)
            if 'simple_save' in x and len(x) > 11
        ]
        itr = '%d' % max(saves) if len(saves) > 0 else ''
    else:
        itr = '%d' % itr

    # load the things!
    sess = tf.Session()
    model = restore_tf_graph(sess, osp.join(fpath, 'simple_save' + itr))

    if deterministic and 'mu' in model.keys():
        print('Using deterministic action op.')
        mu = model['mu']
    else:
        print('Using default action op.')
        mu = model['pi']

    x = model['x']
    a = model['a']
    act_dim = a.shape.as_list()[-1]
    # log_std = tf.constant(0.01*act_high, dtype=tf.float32, shape=(act_dim,))
    log_std = tf.get_variable(name='log_std',
                              initializer=math.log(0.01 * act_high[0]) *
                              np.ones(act_dim, dtype=np.float32))
    std = tf.exp(log_std)
    with tf.variable_scope("pi", reuse=True):
        pi = mu + tf.random_normal(tf.shape(mu)) * std
    with tf.variable_scope("log_pi"):
        logp = core.gaussian_likelihood(a, mu, log_std)
        logp_pi = core.gaussian_likelihood(pi, mu, log_std)

    if 'v' in model.keys():
        print("value function already in model")
        v = model['v']
    else:
        _, _, _, v = core.mlp_actor_critic(x, a, **ac_kwargs)

    # get_action = lambda x : sess.run(mu, feed_dict={model['x']: x[None,:]})[0]
    sess.run(tf.initialize_variables([log_std]))

    return sess, model['x'], model['a'], mu, pi, logp, logp_pi, v
Example #9
def get_policy_model(fpath, sess, itr='last', use_model_only=True):

    # handle which epoch to load from
    if itr == 'last':
        saves = [
            int(x[11:]) for x in os.listdir(fpath)
            if 'simple_save' in x and len(x) > 11
        ]
        itr = '%d' % max(saves) if len(saves) > 0 else ''
    else:
        itr = '%d' % itr

    # load the things!
    if use_model_only:
        # We need this to get the same agent performance on resume.
        model = restore_tf_graph_model_only(
            sess, osp.join(fpath, f'{TF_MODEL_ONLY_DIR}/'))
    else:
        model = restore_tf_graph(sess, osp.join(fpath, 'simple_save' + itr))

    return model, itr
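A short usage sketch for this variant, which expects the caller to own the session; the path is hypothetical:

sess = tf.Session()
model, itr = get_policy_model('./data/dqn_run/dqn_run_s0', sess, itr='last')
print('Restored iteration:', itr, '| available ops:', list(model.keys()))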
Example #10
def load_tf_policy(fpath, itr, deterministic=False):
    """ Load a tensorflow policy saved with Spinning Up Logger."""

    fname = osp.join(fpath, 'tf1_save'+itr)
    print('\n\nLoading from %s.\n\n'%fname)

    # load the things!
    sess = tf.Session()
    model = restore_tf_graph(sess, fname)

    # get the correct op for executing actions
    if deterministic and 'mu' in model.keys():
        # 'deterministic' is only a valid option for SAC policies
        print('Using deterministic action op.')
        action_op = model['mu']
    else:
        print('Using default action op.')
        action_op = model['pi']

    # make function for producing an action given a single state
    get_action = lambda x : sess.run(action_op, feed_dict={model['x']: x[None,:]})[0]

    return get_action
Example #11
def load_policy(sess, fpath):
    model = restore_tf_graph(sess, osp.join(fpath, 'simple_save'))
    get_action = lambda x: sess.run(model['pi'],
                                    feed_dict={model['x']: x[None, :]})[0]
    return get_action
Example #12
def ppo(workload_file,
        model_path,
        ac_kwargs=dict(),
        seed=0,
        traj_per_epoch=4000,
        epochs=50,
        gamma=0.99,
        clip_ratio=0.2,
        pi_lr=3e-4,
        vf_lr=1e-3,
        train_pi_iters=80,
        train_v_iters=80,
        lam=0.97,
        max_ep_len=1000,
        target_kl=0.01,
        logger_kwargs=dict(),
        save_freq=10,
        pre_trained=0,
        trained_model=None,
        attn=False,
        shuffle=False,
        backfil=False,
        skip=False,
        score_type=0,
        batch_job_slice=0,
        sched_algo=4):
    logger = EpochLogger(**logger_kwargs)
    logger.save_config(locals())

    tf.set_random_seed(seed)
    np.random.seed(seed)

    env = HPCEnvSkip(shuffle=shuffle,
                     backfil=backfil,
                     skip=skip,
                     job_score_type=score_type,
                     batch_job_slice=batch_job_slice,
                     build_sjf=False,
                     sched_algo=sched_algo)
    env.seed(seed)
    env.my_init(workload_file=workload_file, sched_file=model_path)

    obs_dim = env.observation_space.shape
    act_dim = env.action_space.shape

    # Share information about action space with policy architecture
    ac_kwargs['action_space'] = env.action_space
    ac_kwargs['attn'] = attn

    # Inputs to computation graph

    buf = PPOBuffer(obs_dim, act_dim, traj_per_epoch * JOB_SEQUENCE_SIZE,
                    gamma, lam)

    if pre_trained:
        sess = tf.Session()
        model = restore_tf_graph(sess, trained_model)
        logger.log('load pre-trained model')
        # Count variables
        var_counts = tuple(count_vars(scope) for scope in ['pi', 'v'])
        logger.log('\nNumber of parameters: \t pi: %d, \t v: %d\n' %
                   var_counts)

        x_ph = model['x']
        a_ph = model['a']
        mask_ph = model['mask']
        adv_ph = model['adv']
        ret_ph = model['ret']
        logp_old_ph = model['logp_old_ph']

        pi = model['pi']
        v = model['v']
        # logits = model['logits']
        out = model['out']
        logp = model['logp']
        logp_pi = model['logp_pi']
        pi_loss = model['pi_loss']
        v_loss = model['v_loss']
        approx_ent = model['approx_ent']
        approx_kl = model['approx_kl']
        clipfrac = model['clipfrac']
        clipped = model['clipped']

        # Optimizers
        # graph = tf.get_default_graph()
        # op = sess.graph.get_operations()
        # [print(m.values()) for m in op]
        # train_pi = graph.get_tensor_by_name('pi/conv2d/kernel/Adam:0')
        # train_v = graph.get_tensor_by_name('v/conv2d/kernel/Adam:0')
        train_pi = tf.get_collection("train_pi")[0]
        train_v = tf.get_collection("train_v")[0]
        # train_pi_optimizer = MpiAdamOptimizer(learning_rate=pi_lr, name='AdamLoad')
        # train_pi = train_pi_optimizer.minimize(pi_loss)
        # train_v_optimizer = MpiAdamOptimizer(learning_rate=vf_lr, name='AdamLoad')
        # train_v = train_v_optimizer.minimize(v_loss)
        # sess.run(tf.variables_initializer(train_pi_optimizer.variables()))
        # sess.run(tf.variables_initializer(train_v_optimizer.variables()))
        # Need all placeholders in *this* order later (to zip with data from buffer)
        all_phs = [x_ph, a_ph, mask_ph, adv_ph, ret_ph, logp_old_ph]
        # Every step, get: action, value, and logprob
        get_action_ops = [pi, v, logp_pi, out]

    else:
        x_ph, a_ph = placeholders_from_spaces(env.observation_space,
                                              env.action_space)
        # y_ph = placeholder(JOB_SEQUENCE_SIZE*3) # 3 is the number of sequence features
        mask_ph = placeholder(env.action_space.n)
        adv_ph, ret_ph, logp_old_ph = placeholders(None, None, None)

        # Main outputs from computation graph
        pi, logp, logp_pi, v, out = actor_critic(x_ph, a_ph, mask_ph,
                                                 **ac_kwargs)

        # Need all placeholders in *this* order later (to zip with data from buffer)
        all_phs = [x_ph, a_ph, mask_ph, adv_ph, ret_ph, logp_old_ph]

        # Every step, get: action, value, and logprob
        get_action_ops = [pi, v, logp_pi, out]

        # Experience buffer

        # Count variables
        var_counts = tuple(count_vars(scope) for scope in ['pi', 'v'])
        logger.log('\nNumber of parameters: \t pi: %d, \t v: %d\n' %
                   var_counts)

        # PPO objectives
        ratio = tf.exp(logp - logp_old_ph)  # pi(a|s) / pi_old(a|s)
        min_adv = tf.where(adv_ph > 0, (1 + clip_ratio) * adv_ph,
                           (1 - clip_ratio) * adv_ph)
        pi_loss = -tf.reduce_mean(tf.minimum(ratio * adv_ph, min_adv))
        v_loss = tf.reduce_mean((ret_ph - v)**2)

        # Info (useful to watch during learning)
        # a sample estimate for KL-divergence, easy to compute
        approx_kl = tf.reduce_mean(logp_old_ph - logp)
        # a sample estimate for entropy, also easy to compute
        approx_ent = tf.reduce_mean(-logp)
        clipped = tf.logical_or(ratio > (1 + clip_ratio),
                                ratio < (1 - clip_ratio))
        clipfrac = tf.reduce_mean(tf.cast(clipped, tf.float32))

        # Optimizers
        train_pi = tf.train.AdamOptimizer(
            learning_rate=pi_lr).minimize(pi_loss)
        train_v = tf.train.AdamOptimizer(learning_rate=vf_lr).minimize(v_loss)
        sess = tf.Session()
        sess.run(tf.global_variables_initializer())
        tf.add_to_collection("train_pi", train_pi)
        tf.add_to_collection("train_v", train_v)

    # Setup model saving
    # logger.setup_tf_saver(sess, inputs={'x': x_ph}, outputs={'action_probs': action_probs, 'log_picked_action_prob': log_picked_action_prob, 'v': v})
    logger.setup_tf_saver(sess,
                          inputs={
                              'x': x_ph,
                              'a': a_ph,
                              'adv': adv_ph,
                              'mask': mask_ph,
                              'ret': ret_ph,
                              'logp_old_ph': logp_old_ph
                          },
                          outputs={
                              'pi': pi,
                              'v': v,
                              'out': out,
                              'pi_loss': pi_loss,
                              'logp': logp,
                              'logp_pi': logp_pi,
                              'v_loss': v_loss,
                              'approx_ent': approx_ent,
                              'approx_kl': approx_kl,
                              'clipped': clipped,
                              'clipfrac': clipfrac
                          })

    def update():
        inputs = {k: v for k, v in zip(all_phs, buf.get())}
        pi_l_old, v_l_old, ent = sess.run([pi_loss, v_loss, approx_ent],
                                          feed_dict=inputs)

        # Training
        for i in range(train_pi_iters):
            _, kl = sess.run([train_pi, approx_kl], feed_dict=inputs)
            kl = mpi_avg(kl)
            if kl > 1.5 * target_kl:
                logger.log(
                    'Early stopping at step %d due to reaching max kl.' % i)
                break
        logger.store(StopIter=i)
        for _ in range(train_v_iters):
            sess.run(train_v, feed_dict=inputs)

        # Log changes from update
        pi_l_new, v_l_new, kl, cf = sess.run(
            [pi_loss, v_loss, approx_kl, clipfrac], feed_dict=inputs)
        logger.store(LossPi=pi_l_old,
                     LossV=v_l_old,
                     KL=kl,
                     Entropy=ent,
                     ClipFrac=cf,
                     DeltaLossPi=(pi_l_new - pi_l_old),
                     DeltaLossV=(v_l_new - v_l_old))

    start_time = time.time()
    [o, co], r, d = env.reset(), 0, False
    ep_ret, ep_len, show_ret, sjf, f1, skip_count = 0, 0, 0, 0, 0, 0

    # Main loop: collect experience in env and update/log each epoch
    start_time = time.time()
    for epoch in range(epochs):
        t = 0
        while True:
            # [no_skip, skip]
            lst = [1, 1]
            #for i in range(0, MAX_QUEUE_SIZE * JOB_FEATURES, JOB_FEATURES):
            #    job = o[i:i + JOB_FEATURES]
            #    # the skip time of will_skip job exceeds MAX_SKIP_TIME
            #    if job[-2] == 1.0:
            #        lst = [1,0]

            a, v_t, logp_t, output = sess.run(get_action_ops,
                                              feed_dict={
                                                  x_ph:
                                                  o.reshape(1, -1),
                                                  mask_ph:
                                                  np.array(lst).reshape(1, -1)
                                              })
            # print(a, end=" ")
            '''
            action = np.random.choice(np.arange(MAX_QUEUE_SIZE), p=action_probs)
            log_action_prob = np.log(action_probs[action])
            '''

            # save and log
            buf.store(o, None, a, np.array(lst), r, v_t, logp_t)
            logger.store(VVals=v_t)
            if a[0] == 1:
                skip_count += 1
            o, r, d, r2, sjf_t, f1_t = env.step(a[0])
            ep_ret += r
            ep_len += 1
            show_ret += r2
            sjf += sjf_t
            f1 += f1_t

            if d:
                t += 1
                buf.finish_path(r)
                logger.store(EpRet=ep_ret,
                             EpLen=ep_len,
                             ShowRet=show_ret,
                             SJF=sjf,
                             F1=f1,
                             SkipRatio=skip_count / ep_len)
                [o, co], r, d = env.reset(), 0, False
                ep_ret, ep_len, show_ret, sjf, f1, skip_count = 0, 0, 0, 0, 0, 0
                if t >= traj_per_epoch:
                    # print ("state:", state, "\nlast action in a traj: action_probs:\n", action_probs, "\naction:", action)
                    break
        # print("Sample time:", (time.time()-start_time)/num_total, num_total)
        # Save model
        if (epoch % save_freq == 0) or (epoch == epochs - 1):
            logger.save_state({'env': env}, None)

        # Perform PPO update!
        # start_time = time.time()
        update()
        # print("Train time:", time.time()-start_time)

        # Log info about epoch
        logger.log_tabular('Epoch', epoch)
        logger.log_tabular('EpRet', with_min_and_max=True)
        logger.log_tabular('EpLen', with_min_and_max=True)
        logger.log_tabular('VVals', with_min_and_max=True)
        logger.log_tabular('TotalEnvInteracts',
                           (epoch + 1) * traj_per_epoch * JOB_SEQUENCE_SIZE)
        logger.log_tabular('LossPi', average_only=True)
        logger.log_tabular('LossV', average_only=True)
        logger.log_tabular('DeltaLossPi', average_only=True)
        logger.log_tabular('DeltaLossV', average_only=True)
        logger.log_tabular('Entropy', average_only=True)
        logger.log_tabular('KL', average_only=True)
        logger.log_tabular('ClipFrac', average_only=True)
        logger.log_tabular('StopIter', average_only=True)
        logger.log_tabular('ShowRet', average_only=True)
        logger.log_tabular('SJF', average_only=True)
        logger.log_tabular('F1', average_only=True)
        logger.log_tabular('SkipRatio', average_only=True)
        logger.log_tabular('Time', time.time() - start_time)
        logger.dump_tabular()
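A hedged invocation sketch for the training entry point above; the workload trace, output directory, and experiment name are hypothetical placeholders:

logger_kwargs = dict(output_dir='./logs/ppo-hpc', exp_name='ppo-hpc')  # hypothetical

ppo(workload_file='./data/lublin_256.swf',   # hypothetical workload trace
    model_path='./trained_models/ppo-hpc',   # hypothetical checkpoint/output path
    seed=0,
    traj_per_epoch=100,
    epochs=50,
    logger_kwargs=logger_kwargs,
    shuffle=True,
    backfil=False,
    skip=True,
    score_type=0,
    batch_job_slice=0,
    sched_algo=4)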
Example #13
def run_evaluation(model_dir, model_save_name, seed=1):

    # set up the trained model
    sess = tf.compat.v1.Session(config=tf_config)
    model  = restore_tf_graph(sess=sess, fpath=os.path.join(model_dir, model_save_name))
    config = load_json_obj(os.path.join(model_dir, 'config'))

    if 'sim' in config['rl_params']['platform']:
        from braille_rl.envs.sim.disc_sim_braille_env.mockKBGymEnv import mockKBGymEnv
        env = mockKBGymEnv(mode=config['rl_params']['env_mode'], max_steps=config['rl_params']['max_ep_len'])
    elif 'robot' in config['rl_params']['platform']:
        from braille_rl.envs.robot.disc_ur5_braille_env.ur5GymEnv import UR5GymEnv
        env = UR5GymEnv(mode=config['rl_params']['env_mode'], max_steps=config['rl_params']['max_ep_len'])

    # define necessary inputs/outputs
    x_ph      = model['x_ph']
    g_ph      = model['g_ph']
    prev_a_ph = model['prev_a_ph']
    pi        = model['pi']

    obs_dim = config['network_params']['input_dims']
    test_state_buffer = StateBuffer(m=obs_dim[2])
    max_ep_len = config['rl_params']['max_ep_len']
    test_act_noise = 0.0

    act_dim = env.action_space.n
    goal_dim = len(env.goal_list)

    # set seeding
    tf.set_random_seed(seed)
    np.random.seed(seed)
    env.seed(seed)
    env.action_space.np_random.seed(seed)
    random.seed(seed)

    # create list of key sequences to be typed
    if 'arrows' in config['rl_params']['env_mode']:
        test_goal_list = permutation(['UP', 'DOWN', 'LEFT', 'RIGHT'])

    elif 'alphabet' in config['rl_params']['env_mode']:
        test_goal_list = [random.sample(env.goal_list,  len(env.goal_list)) for i in range(10)]

    print('Key Sequences: ')
    for sequence in test_goal_list:
        print(sequence)
    print('')

    def get_action(state, one_hot_goal, prev_a, noise_scale):
        state = state.astype('float32') / 255.
        if np.random.random_sample() < noise_scale:
            a = env.action_space.sample()
        else:
            a = sess.run(pi, feed_dict={x_ph: [state],
                                        g_ph: [one_hot_goal],
                                        prev_a_ph: [prev_a]})[0]
        return a

    def reset(state_buffer, goal):
        o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0
        o = process_image_observation(o, obs_dim) # thresholding done in env
        r = process_reward(r)
        state = state_buffer.init_state(init_obs=o)
        prev_a = np.zeros(act_dim)

        # new random goal when the env is reset
        goal_id = env.goal_list.index(goal)
        one_hot_goal = np.eye(goal_dim)[goal_id]
        env.goal_button = goal

        return o, r, d, ep_ret, ep_len, state, one_hot_goal, prev_a

    def test_agent(sequence):
        n=len(sequence)

        correct_count = 0
        step_count = 0

        goal_list = []
        achieved_goal_list = []

        for j in range(n):
            test_o, test_r, test_d, test_ep_ret, test_ep_len, test_state, test_one_hot_goal, test_prev_a = reset(test_state_buffer, sequence[j])

            while not(test_d or (test_ep_len == max_ep_len)):

                test_a = get_action(test_state, test_one_hot_goal, test_prev_a, test_act_noise)

                test_o, test_r, test_d, _ = env.step(test_a)
                test_o = process_image_observation(test_o, obs_dim)
                test_r = process_reward(test_r)
                test_state = test_state_buffer.append_state(test_o)

                test_ep_ret += test_r
                test_ep_len += 1

                test_one_hot_a = process_action(test_a, act_dim)
                test_prev_a = test_one_hot_a

                if test_r == 1:
                    correct_count += 1

                if test_d:
                    achieved_goal_list.append(env.latest_button)
                    goal_list.append(sequence[j])

            step_count += test_ep_len

        acc = correct_count/n

        print('Sequence Steps: {}'.format(step_count))
        print('Sequence Accuracy: {}'.format(acc))

        return correct_count, n, step_count, goal_list, achieved_goal_list

    csv_dir = os.path.join(model_dir, 'evaluation_output.csv')

    with open(csv_dir, "w", newline='') as csv_file:
        writer = csv.writer(csv_file, delimiter=',')
        writer.writerow(['Episode','Example Sequence', 'Typed Sequence'])

        achieved_goals = []
        goals = []
        total_correct = 0
        total_count = 0
        total_steps = 0
        for (i,sequence) in enumerate(test_goal_list):
            correct, num_elements, step_count, goal_list, achieved_goal_list = test_agent(sequence)

            total_correct += correct
            total_count += num_elements
            total_steps += step_count

            writer.writerow([i, goal_list, achieved_goal_list])

            achieved_goals.append(achieved_goal_list)
            goals.append(goal_list)

        overall_acc = total_correct/total_count
        print('')
        print('Total Steps:   ', total_steps)
        print('Total Count:   ', total_count)
        print('Total Correct: ', total_correct)
        print('Overall Acc:   ', overall_acc)

        writer.writerow([])
        writer.writerow(['Total Steps', 'Total Count', 'Total Correct', 'Overall Acc'])
        writer.writerow([total_steps, total_count, total_correct, overall_acc])

    goals = np.hstack(goals)
    achieved_goals = np.hstack(achieved_goals)
    cnf_matrix = confusion_matrix(goals, achieved_goals)
    plot_confusion_matrix(cnf_matrix, classes=env.goal_list, normalize=True, cmap=plt.cm.Blues,
                          title='Normalized Confusion matrix', dirname=None, save_flag=False)

    env.close()
Example #14
def load_policy(fpath,
                itr='last',
                deterministic=False,
                eval_temp=1.0,
                use_temp=True,
                env_name=None,
                env_version=1,
                meta_learning_or_finetune=False):
    # handle which epoch to load from
    if itr == 'last':
        saves = [
            int(x[11:]) for x in os.listdir(fpath)
            if 'simple_save' in x and len(x) > 11
        ]
        itr = '%d' % max(saves) if len(saves) > 0 else ''
    else:
        itr = '%d' % itr

    # load the things!
    sess = tf.Session()
    model = restore_tf_graph(
        sess,
        osp.join(fpath, 'simple_save' + itr),
        meta_learning_or_finetune=meta_learning_or_finetune)

    # get the correct op for executing actions
    if deterministic and 'mu' in model.keys():
        # 'deterministic' is only a valid option for SAC policies
        print('Using deterministic action op.')
        action_op = model['mu']
    else:
        print('Using default action op.')
        action_op = model['pi']

    # make function for producing an action given a single state
    if not use_temp:
        get_action = lambda x: sess.run(action_op,
                                        feed_dict={model['x']: x[None, :]})[0]
    else:
        get_action = lambda x: sess.run(action_op,
                                        feed_dict={
                                            model['x']: x[None, :],
                                            model['temperature']: eval_temp
                                        })[0]

    if env_name is None:
        # try to load environment from save
        # (sometimes this will fail because the environment could not be pickled)
        try:
            state = joblib.load(osp.join(fpath, 'vars' + itr + '.pkl'))
            env = state['env']
        except:
            env = None
    else:
        # env = (lambda: gym.make(env_name))()
        if args.env_version in (1, 2):
            env = get_custom_env_fn(env_name, env_version)()
        if args.env_version >= 3:
            env = get_custom_env_fn(env_name,
                                    env_version,
                                    target_arcs=args.target_arcs,
                                    env_input=args.env_input,
                                    env_n_sample=5000)()

    return env, get_action
Example #15
def td3(env_fn,
        expert=None,
        policy_path=None,
        actor_critic=core.mlp_actor_critic,
        ac_kwargs=dict(),
        seed=0,
        steps_per_epoch=500,
        epochs=1000,
        replay_size=int(5e3),
        gamma=0.99,
        polyak=0.995,
        pi_lr=1e-4,
        q_lr=1e-4,
        batch_size=64,
        start_epochs=500,
        dagger_epochs=500,
        pretrain_epochs=50,
        dagger_noise=0.02,
        act_noise=0.02,
        target_noise=0.02,
        noise_clip=0.5,
        policy_delay=2,
        max_ep_len=500,
        logger_kwargs=dict(),
        save_freq=50,
        UPDATE_STEP=10):
    """

    Args:
        env_fn : A function which creates a copy of the environment.
            The environment must satisfy the OpenAI Gym API.

        actor_critic: A function which takes in placeholder symbols 
            for state, ``x_ph``, and action, ``a_ph``, and returns the main 
            outputs from the agent's Tensorflow computation graph:

            ===========  ================  ======================================
            Symbol       Shape             Description
            ===========  ================  ======================================
            ``pi``       (batch, act_dim)  | Deterministically computes actions
                                           | from policy given states.
            ``q1``       (batch,)          | Gives one estimate of Q* for 
                                           | states in ``x_ph`` and actions in
                                           | ``a_ph``.
            ``q2``       (batch,)          | Gives another estimate of Q* for 
                                           | states in ``x_ph`` and actions in
                                           | ``a_ph``.
            ``q1_pi``    (batch,)          | Gives the composition of ``q1`` and 
                                           | ``pi`` for states in ``x_ph``: 
                                           | q1(x, pi(x)).
            ===========  ================  ======================================

        ac_kwargs (dict): Any kwargs appropriate for the actor_critic 
            function you provided to TD3.

        seed (int): Seed for random number generators.

        steps_per_epoch (int): Number of steps of interaction (state-action pairs) 
            for the agent and the environment in each epoch.

        epochs (int): Number of epochs to run and train agent.

        replay_size (int): Maximum length of replay buffer.

        gamma (float): Discount factor. (Always between 0 and 1.)

        polyak (float): Interpolation factor in polyak averaging for target 
            networks. Target networks are updated towards main networks 
            according to:

            .. math:: \\theta_{\\text{targ}} \\leftarrow 
                \\rho \\theta_{\\text{targ}} + (1-\\rho) \\theta

            where :math:`\\rho` is polyak. (Always between 0 and 1, usually 
            close to 1.)

        pi_lr (float): Learning rate for policy.

        q_lr (float): Learning rate for Q-networks.

        batch_size (int): Minibatch size for SGD.

        start_epochs (int): Number of epochs of uniform-random action selection
            before running the learned policy. Helps exploration.

        act_noise (float): Stddev for Gaussian exploration noise added to 
            policy at training time. (At test time, no noise is added.)

        target_noise (float): Stddev for smoothing noise added to target 
            policy.

        noise_clip (float): Limit for absolute value of target policy 
            smoothing noise.

        policy_delay (int): Policy will only be updated once every 
            policy_delay times for each update of the Q-networks.

        max_ep_len (int): Maximum length of trajectory / episode / rollout.

        logger_kwargs (dict): Keyword args for EpochLogger.

        save_freq (int): How often (in terms of gap between epochs) to save
            the current policy and value function.

    """

    logger = EpochLogger(**logger_kwargs)
    logger.save_config(locals())
    test_logger_kwargs = dict()
    test_logger_kwargs['output_dir'] = osp.join(logger_kwargs['output_dir'],
                                                "test")
    test_logger_kwargs['exp_name'] = logger_kwargs['exp_name']
    test_logger = EpochLogger(**test_logger_kwargs)

    # test_logger_kwargs = dict()
    # test_logger_kwargs['output_dir'] = osp.join(logger_kwargs['output_dir'], "test")
    # test_logger_kwargs['exp_name'] = logger_kwargs['exp_name']
    # test_logger = EpochLogger(**test_logger_kwargs)

    # pretrain_logger_kwargs = dict()
    # pretrain_logger_kwargs['output_dir'] = osp.join(logger_kwargs['output_dir'], "pretrain")
    # pretrain_logger_kwargs['exp_name'] = logger_kwargs['exp_name']
    # pretrain_logger = EpochLogger(**pretrain_logger_kwargs)

    tf.set_random_seed(seed)
    np.random.seed(seed)

    env, test_env = env_fn(), env_fn()
    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.shape[0]

    # Action limit for clamping: critically, this code does not assume all dimensions share the same bound!
    act_limit = env.action_space.high / 2
    act_high_limit = env.action_space.high
    act_low_limit = env.action_space.low

    act_noise_limit = act_noise * act_limit
    sess = tf.Session()
    if policy_path is None:
        # Share information about action space with policy architecture
        ac_kwargs['action_space'] = env.action_space

        # Inputs to computation graph
        x_ph, a_ph, x2_ph, r_ph, d_ph = core.placeholders(
            obs_dim, act_dim, obs_dim, None, None)
        tfa_ph = core.placeholder(act_dim)

        # Main outputs from computation graph
        with tf.variable_scope('main'):
            pi, q1, q2, q1_pi = actor_critic(x_ph, a_ph, **ac_kwargs)

        # Target policy network
        with tf.variable_scope('target'):
            pi_targ, _, _, _ = actor_critic(x2_ph, a_ph, **ac_kwargs)

        # Target Q networks
        with tf.variable_scope('target', reuse=True):

            # Target policy smoothing, by adding clipped noise to target actions
            epsilon = tf.random_normal(tf.shape(pi_targ), stddev=target_noise)
            epsilon = tf.clip_by_value(epsilon, -noise_clip, noise_clip)
            a2 = pi_targ + epsilon
            a2 = tf.clip_by_value(a2, act_low_limit, act_high_limit)

            # Target Q-values, using action from target policy
            _, q1_targ, q2_targ, _ = actor_critic(x2_ph, a2, **ac_kwargs)

    else:
        # sess = tf.Session()
        model = restore_tf_graph(sess, osp.join(policy_path, 'simple_save'))
        x_ph, a_ph, x2_ph, r_ph, d_ph = (model['x_ph'], model['a_ph'],
                                         model['x2_ph'], model['r_ph'],
                                         model['d_ph'])
        pi, q1, q2, q1_pi = model['pi'], model['q1'], model['q2'], model['q1_pi']
        pi_targ, q1_targ, q2_targ = (model['pi_targ'], model['q1_targ'],
                                     model['q2_targ'])
        tfa_ph = core.placeholder(act_dim)
        dagger_epochs = 0

    # Experience buffer
    replay_buffer = ReplayBuffer(obs_dim=obs_dim,
                                 act_dim=act_dim,
                                 size=replay_size)
    dagger_replay_buffer = DaggerReplayBuffer(obs_dim=obs_dim,
                                              act_dim=act_dim,
                                              size=replay_size)
    # Count variables
    var_counts = tuple(
        core.count_vars(scope)
        for scope in ['main/pi', 'main/q1', 'main/q2', 'main'])
    print(
        '\nNumber of parameters: \t pi: %d, \t q1: %d, \t q2: %d, \t total: %d\n'
        % var_counts)

    if policy_path is None:
        # Bellman backup for Q functions, using Clipped Double-Q targets
        min_q_targ = tf.minimum(q1_targ, q2_targ)
        backup = tf.stop_gradient(r_ph + gamma * (1 - d_ph) * min_q_targ)

        # dagger loss
        dagger_pi_loss = tf.reduce_mean(tf.square(pi - tfa_ph))
        # TD3 losses
        pi_loss = -tf.reduce_mean(q1_pi)
        q1_loss = tf.reduce_mean((q1 - backup)**2)
        q2_loss = tf.reduce_mean((q2 - backup)**2)
        q_loss = tf.add(q1_loss, q2_loss)
        pi_loss = tf.identity(pi_loss, name="pi_loss")
        q1_loss = tf.identity(q1_loss, name="q1_loss")
        q2_loss = tf.identity(q2_loss, name="q2_loss")
        q_loss = tf.identity(q_loss, name="q_loss")

        # Separate train ops for pi, q
        dagger_pi_optimizer = tf.train.AdamOptimizer(learning_rate=pi_lr)
        pi_optimizer = tf.train.AdamOptimizer(learning_rate=pi_lr)
        q_optimizer = tf.train.AdamOptimizer(learning_rate=q_lr)
        train_dagger_pi_op = dagger_pi_optimizer.minimize(
            dagger_pi_loss,
            var_list=get_vars('main/pi'),
            name='train_dagger_pi_op')
        train_pi_op = pi_optimizer.minimize(pi_loss,
                                            var_list=get_vars('main/pi'),
                                            name='train_pi_op')
        train_q_op = q_optimizer.minimize(q_loss,
                                          var_list=get_vars('main/q'),
                                          name='train_q_op')

        # Polyak averaging for target variables
        target_update = tf.group([
            tf.assign(v_targ, polyak * v_targ + (1 - polyak) * v_main)
            for v_main, v_targ in zip(get_vars('main'), get_vars('target'))
        ])

        # Initializing targets to match main variables
        target_init = tf.group([
            tf.assign(v_targ, v_main)
            for v_main, v_targ in zip(get_vars('main'), get_vars('target'))
        ])
        sess.run(tf.global_variables_initializer())
    else:
        graph = tf.get_default_graph()
        # opts = graph.get_operations()
        # print (opts)
        pi_loss = model['pi_loss']
        q1_loss = model['q1_loss']
        q2_loss = model['q2_loss']
        q_loss = model['q_loss']
        train_q_op = graph.get_operation_by_name('train_q_op')
        train_pi_op = graph.get_operation_by_name('train_pi_op')
        # target_update = graph.get_operation_by_name('target_update')
        # target_init = graph.get_operation_by_name('target_init')
        # Polyak averaging for target variables
        target_update = tf.group([
            tf.assign(v_targ, polyak * v_targ + (1 - polyak) * v_main)
            for v_main, v_targ in zip(get_vars('main'), get_vars('target'))
        ])

        # Initializing targets to match main variables
        target_init = tf.group([
            tf.assign(v_targ, v_main)
            for v_main, v_targ in zip(get_vars('main'), get_vars('target'))
        ])

    # sess = tf.Session()
    # sess.run(tf.global_variables_initializer())
    sess.run(target_init)

    # Setup model saving
    logger.setup_tf_saver(sess, inputs={'x_ph': x_ph, 'a_ph': a_ph, 'x2_ph': x2_ph, 'r_ph': r_ph, 'd_ph': d_ph}, \
         outputs={'pi': pi, 'q1': q1, 'q2': q2, 'q1_pi': q1_pi, 'pi_targ': pi_targ, 'q1_targ': q1_targ, 'q2_targ': q2_targ, \
             'pi_loss': pi_loss, 'q1_loss': q1_loss, 'q2_loss': q2_loss, 'q_loss': q_loss})

    def get_action(o, noise_scale):
        a = sess.run(pi, feed_dict={x_ph: o.reshape(1, -1)})[0]
        # todo: add act_limit scale noise
        a += noise_scale * np.random.randn(act_dim)
        return np.clip(a, act_low_limit, act_high_limit)

    def choose_action(s, add_noise=False):
        s = s[np.newaxis, :]
        a = sess.run(pi, {x_ph: s})[0]
        if add_noise:
            noise = dagger_noise * act_high_limit * np.random.normal(
                size=a.shape)
            a = a + noise
        return np.clip(a, act_low_limit, act_high_limit)

    def test_agent(n=81, test_num=1):
        n = env.unwrapped._set_test_mode(True)
        con_flag = False
        for j in range(n):
            o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0
            while not (d or (ep_len == max_ep_len)):
                # Take deterministic actions at test time (noise_scale=0)
                o, r, d, info = env.step(choose_action(np.array(o), 0))
                ep_ret += r
                ep_len += 1
                if d:
                    test_logger.store(TestEpRet=ep_ret, TestEpLen=ep_len)
                    test_logger.store(arrive_des=info['arrive_des'])
                    test_logger.store(
                        arrive_des_appro=info['arrive_des_appro'])
                    if not info['out_of_range']:
                        test_logger.store(converge_dis=info['converge_dis'])
                        con_flag = True
                    test_logger.store(out_of_range=info['out_of_range'])
                    # print(info)
        # test_logger.dump_tabular()
        # time.sleep(10)
        if not con_flag:
            test_logger.store(converge_dis=10000)
        env.unwrapped._set_test_mode(False)

    start_time = time.time()
    env.unwrapped._set_test_mode(False)
    o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0
    total_steps = steps_per_epoch * epochs
    test_num = 0

    total_env_t = 0
    print(colorize("begin dagger training", 'green', bold=True))
    # Main loop for dagger pretrain
    for epoch in range(1, dagger_epochs + 1, 1):
        obs, acs, rewards = [], [], []
        # number of timesteps
        for t in range(steps_per_epoch):
            # action = env.action_space.sample()
            # action = ppo.choose_action(np.array(observation))
            obs.append(o)
            ref_action = call_ref_controller(env, expert)
            if (epoch < pretrain_epochs):
                action = ref_action
            else:
                action = choose_action(np.array(o), True)

            o2, r, d, info = env.step(action)
            ep_ret += r
            ep_len += 1
            total_env_t += 1

            acs.append(ref_action)
            rewards.append(r)
            # Store experience to replay buffer
            replay_buffer.store(o, action, r, o2, d)

            o = o2

            if (t == steps_per_epoch - 1):
                # print ("reached the end")
                d = True

            if d:
                # collected data to replaybuffer
                max_step = len(np.array(rewards))
                q = [
                    np.sum(
                        np.power(gamma, np.arange(max_step - t)) * rewards[t:])
                    for t in range(max_step)
                ]
                dagger_replay_buffer.stores(obs, acs, rewards, q)

                # update policy
                for _ in range(int(max_step / 5)):
                    batch = dagger_replay_buffer.sample_batch(batch_size)
                    feed_dict = {x_ph: batch['obs1'], tfa_ph: batch['acts']}
                    q_step_ops = [dagger_pi_loss, train_dagger_pi_op]
                    for j in range(UPDATE_STEP):
                        outs = sess.run(q_step_ops, feed_dict)
                    logger.store(LossPi=outs[0])

                # train q function
                for j in range(int(max_step / 5)):
                    batch = replay_buffer.sample_batch(batch_size)
                    feed_dict = {
                        x_ph: batch['obs1'],
                        x2_ph: batch['obs2'],
                        a_ph: batch['acts'],
                        r_ph: batch['rews'],
                        d_ph: batch['done']
                    }
                    q_step_ops = [q_loss, q1, q2, train_q_op]
                    # for _ in range(UPDATE_STEP):
                    outs = sess.run(q_step_ops, feed_dict)
                    logger.store(LossQ=outs[0], Q1Vals=outs[1], Q2Vals=outs[2])

                    if j % policy_delay == 0:
                        # Delayed target update
                        outs = sess.run([target_update], feed_dict)
                        # logger.store(LossPi=outs[0])

                # logger.store(LossQ=1000000, Q1Vals=1000000, Q2Vals=1000000)
                logger.store(EpRet=ep_ret, EpLen=ep_len)
                o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0
                break

        # End of epoch wrap-up
        if epoch > 0 and (epoch % save_freq == 0) or (epoch == dagger_epochs):
            # Save model
            logger.save_state({}, None)

            # Test the performance of the deterministic version of the agent.
            test_num += 1
            test_agent(test_num=test_num)

            # Log info about epoch
            test_logger.log_tabular('epoch', epoch)
            test_logger.log_tabular('TestEpRet', average_only=True)
            test_logger.log_tabular('TestEpLen', average_only=True)
            test_logger.log_tabular('arrive_des', average_only=True)
            test_logger.log_tabular('converge_dis', average_only=True)
            test_logger.log_tabular('out_of_range', average_only=True)
            test_logger.dump_tabular()

    sess.run(target_init)
    print(colorize("begin td3 training", 'green', bold=True))
    # Main loop: collect experience in env and update/log each epoch
    # total_env_t = 0
    for epoch in range(1, epochs + 1, 1):

        # End of epoch wrap-up
        if epoch > 0 and (epoch % save_freq == 0) or (epoch == epochs):

            # Save model
            logger.save_state({}, None)

            # Test the performance of the deterministic version of the agent.
            test_num += 1
            test_agent(test_num=test_num)

            # Log info about epoch
            test_logger.log_tabular('epoch', epoch)
            test_logger.log_tabular('TestEpRet', average_only=True)
            test_logger.log_tabular('TestEpLen', average_only=True)
            test_logger.log_tabular('arrive_des', average_only=True)
            test_logger.log_tabular('converge_dis', average_only=True)
            test_logger.log_tabular('out_of_range', average_only=True)
            test_logger.dump_tabular()
        """
        Until start_steps have elapsed, randomly sample actions
        from a uniform distribution for better exploration. Afterwards, 
        use the learned policy (with some noise, via act_noise). 
        """
        # o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0
        for t in range(steps_per_epoch):
            if epoch > start_epochs:
                a = get_action(np.array(o), act_noise_limit)
            else:
                a = env.action_space.sample()
                # ref_action = call_ref_controller(env, expert)

            # Step the env
            o2, r, d, _ = env.step(a)
            ep_ret += r
            ep_len += 1
            total_env_t += 1

            # Ignore the "done" signal if it comes from hitting the time
            # horizon (that is, when it's an artificial terminal signal
            # that isn't based on the agent's state)
            # d = False if ep_len==max_ep_len else d

            # Store experience to replay buffer
            replay_buffer.store(o, a, r, o2, d)

            # Super critical, easy to overlook step: make sure to update
            # most recent observation!
            o = o2

            if (t == steps_per_epoch - 1):
                # print ("reached the end")
                d = True

            if d:
                """
                Perform all TD3 updates at the end of the trajectory
                (in accordance with source code of TD3 published by
                original authors).
                """
                for j in range(ep_len):
                    batch = replay_buffer.sample_batch(batch_size)
                    feed_dict = {
                        x_ph: batch['obs1'],
                        x2_ph: batch['obs2'],
                        a_ph: batch['acts'],
                        r_ph: batch['rews'],
                        d_ph: batch['done']
                    }
                    q_step_ops = [q_loss, q1, q2, train_q_op]
                    # for _ in range(UPDATE_STEP):
                    outs = sess.run(q_step_ops, feed_dict)
                    logger.store(LossQ=outs[0], Q1Vals=outs[1], Q2Vals=outs[2])

                    if j % policy_delay == 0:
                        # Delayed policy update
                        outs = sess.run([pi_loss, train_pi_op, target_update],
                                        feed_dict)
                        logger.store(LossPi=outs[0])

                logger.store(EpRet=ep_ret, EpLen=ep_len)
                o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0
                break
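A minimal invocation sketch for this DAgger-pretrained TD3 variant. The environment id, expert controller, and output directory are hypothetical; whatever environment is used must expose the _set_test_mode hook and the info keys (arrive_des, converge_dis, out_of_range, ...) that test_agent expects:

import gym

logger_kwargs = dict(output_dir='./logs/td3_dagger', exp_name='td3_dagger')  # hypothetical

td3(env_fn=lambda: gym.make('CustomReach-v0'),  # hypothetical custom environment id
    expert=my_expert_controller,   # assumed controller consumed by call_ref_controller
    policy_path=None,              # or a directory containing a previous 'simple_save'
    seed=0,
    steps_per_epoch=500,
    epochs=1000,
    save_freq=50,
    logger_kwargs=logger_kwargs)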
Example #16
def mars(workload_file,
         model_path,
         ac_kwargs=dict(),
         seed=0,
         traj_per_epoch=4000,
         epochs=50,
         gamma=0.99,
         clip_ratio=0.2,
         pi_lr=3e-4,
         vf_lr=1e-3,
         train_pi_iters=80,
         train_v_iters=80,
         lam=0.97,
         max_ep_len=1000,
         target_kl=0.01,
         logger_kwargs=dict(),
         save_freq=10,
         pre_trained=0,
         trained_model=None,
         attn=False,
         shuffle=False,
         backfil=False,
         skip=False,
         score_type=0,
         batch_job_slice=0):
    logger = EpochLogger(**logger_kwargs)
    logger.save_config(locals())
    tf.set_random_seed(seed)
    np.random.seed(seed)
    env = HPC_Environment(shuffle=shuffle,
                          backfil=backfil,
                          skip=skip,
                          job_score_type=score_type,
                          batch_job_slice=batch_job_slice,
                          build_sjf=False)
    env.seed(seed)
    env.my_init(workload_file=workload_file, sched_file=model_path)
    obs_dim = env.observation_space.shape
    act_dim = env.action_space.shape
    ac_kwargs['action_space'] = env.action_space
    ac_kwargs['attn'] = attn
    buf = MARSBuffer(obs_dim, act_dim, traj_per_epoch * TASK_SEQUENCE_SIZE,
                     gamma, lam)

    if pre_trained:
        sess = tf.Session()
        model = restore_tf_graph(sess, trained_model)
        logger.log('loading model')
        var_counts = tuple(count_vars(scope) for scope in ['pi', 'v'])
        logger.log('\nNumber of parameters: \t pi: %d, \t v: %d\n' %
                   var_counts)
        x_ph = model['x']
        a_ph = model['a']
        mask_ph = model['mask']
        adv_ph = model['adv']
        ret_ph = model['ret']
        logp_old_ph = model['logp_old_ph']
        pi = model['pi']
        v = model['v']
        out = model['out']
        logp = model['logp']
        logp_pi = model['logp_pi']
        pi_loss = model['pi_loss']
        v_loss = model['v_loss']
        approx_ent = model['approx_ent']
        approx_kl = model['approx_kl']
        clipfrac = model['clipfrac']
        clipped = model['clipped']
        train_pi = tf.get_collection("train_pi")[0]
        train_v = tf.get_collection("train_v")[0]
        all_phs = [x_ph, a_ph, mask_ph, adv_ph, ret_ph, logp_old_ph]
        get_action_ops = [pi, v, logp_pi, out]

    else:
        x_ph, a_ph = placeholders_from_spaces(env.observation_space,
                                              env.action_space)
        mask_ph = placeholder(MAX_QUEUE_SIZE)
        adv_ph, ret_ph, logp_old_ph = placeholders(None, None, None)
        pi, logp, logp_pi, v, out = actor_critic(x_ph, a_ph, mask_ph,
                                                 **ac_kwargs)
        all_phs = [x_ph, a_ph, mask_ph, adv_ph, ret_ph, logp_old_ph]
        get_action_ops = [pi, v, logp_pi, out]
        var_counts = tuple(count_vars(scope) for scope in ['pi', 'v'])
        logger.log('\nNumber of parameters: \t pi: %d, \t v: %d\n' %
                   var_counts)
        ratio = tf.exp(logp - logp_old_ph)
        min_adv = tf.where(adv_ph > 0, (1 + clip_ratio) * adv_ph,
                           (1 - clip_ratio) * adv_ph)
        pi_loss = -tf.reduce_mean(tf.minimum(ratio * adv_ph, min_adv))
        v_loss = tf.reduce_mean((ret_ph - v)**2)
        approx_kl = tf.reduce_mean(logp_old_ph - logp)
        approx_ent = tf.reduce_mean(-logp)
        clipped = tf.logical_or(ratio > (1 + clip_ratio),
                                ratio < (1 - clip_ratio))
        clipfrac = tf.reduce_mean(tf.cast(clipped, tf.float32))
        train_pi = tf.train.AdamOptimizer(learning_rate=pi_lr).minimize(pi_loss)
        train_v = tf.train.AdamOptimizer(learning_rate=vf_lr).minimize(v_loss)
        sess = tf.Session()
        sess.run(tf.global_variables_initializer())
        tf.add_to_collection("train_pi", train_pi)
        tf.add_to_collection("train_v", train_v)
    logger.setup_tf_saver(sess,
                          inputs={
                              'x': x_ph,
                              'a': a_ph,
                              'adv': adv_ph,
                              'mask': mask_ph,
                              'ret': ret_ph,
                              'logp_old_ph': logp_old_ph
                          },
                          outputs={
                              'pi': pi,
                              'v': v,
                              'out': out,
                              'pi_loss': pi_loss,
                              'logp': logp,
                              'logp_pi': logp_pi,
                              'v_loss': v_loss,
                              'approx_ent': approx_ent,
                              'approx_kl': approx_kl,
                              'clipped': clipped,
                              'clipfrac': clipfrac
                          })

    def update():
        inputs = {k: v for k, v in zip(all_phs, buf.get())}
        pi_l_old, v_l_old, ent = sess.run([pi_loss, v_loss, approx_ent],
                                          feed_dict=inputs)
        for i in range(train_pi_iters):
            _, kl = sess.run([train_pi, approx_kl], feed_dict=inputs)
            kl = mpi_avg(kl)
            if kl > 1.5 * target_kl:
                logger.log('Early stopping at step %d due to reaching max kl.' % i)
                break
        logger.store(StopIter=i)
        for _ in range(train_v_iters):
            sess.run(train_v, feed_dict=inputs)
        pi_l_new, v_l_new, kl, cf = sess.run(
            [pi_loss, v_loss, approx_kl, clipfrac], feed_dict=inputs)
        logger.store(LossPi=pi_l_old,
                     LossV=v_l_old,
                     KL=kl,
                     Entropy=ent,
                     ClipFrac=cf,
                     DeltaLossPi=(pi_l_new - pi_l_old),
                     DeltaLossV=(v_l_new - v_l_old))

    start_time = time.time()
    [o, co], r, d, ep_ret, ep_len, show_ret, sjf, f1 = env.reset(
    ), 0, False, 0, 0, 0, 0, 0
    num_total = 0
    for epoch in range(epochs):
        t = 0
        while True:
            lst = []
            for i in range(0, MAX_QUEUE_SIZE * TASK_FEATURES, TASK_FEATURES):
                if all(o[i:i + TASK_FEATURES] == [0] + [1] *
                       (TASK_FEATURES - 2) + [0]):
                    lst.append(0)
                elif all(o[i:i + TASK_FEATURES] == [1] * TASK_FEATURES):
                    lst.append(0)
                else:
                    lst.append(1)

            a, v_t, logp_t, output = sess.run(get_action_ops,
                                              feed_dict={
                                                  x_ph:
                                                  o.reshape(1, -1),
                                                  mask_ph:
                                                  np.array(lst).reshape(1, -1)
                                              })

            num_total += 1
            buf.store(o, None, a, np.array(lst), r, v_t, logp_t)
            logger.store(VVals=v_t)
            o, r, d, r2, sjf_t, f1_t = env.step(a[0])
            ep_ret += r
            ep_len += 1
            show_ret += r2
            sjf += sjf_t
            f1 += f1_t
            if d:
                t += 1
                buf.finish_path(r)
                logger.store(EpRet=ep_ret,
                             EpLen=ep_len,
                             ShowRet=show_ret,
                             SJF=sjf,
                             F1=f1)
                [o, co], r, d, ep_ret, ep_len, show_ret, sjf, f1 = env.reset(
                ), 0, False, 0, 0, 0, 0, 0
                if t >= traj_per_epoch:
                    break
        if (epoch % save_freq == 0) or (epoch == epochs - 1):
            logger.save_state({'env': env}, None)
        update()
        logger.log_tabular('Epoch', epoch)
        logger.log_tabular('EpRet', with_min_and_max=True)
        logger.log_tabular('EpLen', with_min_and_max=True)
        logger.log_tabular('VVals', with_min_and_max=True)
        logger.log_tabular('TotalEnvInteracts',
                           (epoch + 1) * traj_per_epoch * TASK_SEQUENCE_SIZE)
        logger.log_tabular('LossPi', average_only=True)
        logger.log_tabular('LossV', average_only=True)
        logger.log_tabular('DeltaLossPi', average_only=True)
        logger.log_tabular('DeltaLossV', average_only=True)
        logger.log_tabular('Entropy', average_only=True)
        logger.log_tabular('KL', average_only=True)
        logger.log_tabular('ClipFrac', average_only=True)
        logger.log_tabular('StopIter', average_only=True)
        logger.log_tabular('ShowRet', average_only=True)
        logger.log_tabular('SJF', average_only=True)
        logger.log_tabular('F1', average_only=True)
        logger.log_tabular('Time', time.time() - start_time)
        logger.dump_tabular()
def ppo(env_fn,
        expert=None,
        policy_path=None,
        actor_critic=core.mlp_actor_critic_m,
        ac_kwargs=dict(),
        seed=0,
        steps_per_epoch=5000,
        epochs=10000,
        dagger_epochs=500,
        pretrain_epochs=50,
        gamma=0.99,
        clip_ratio=0.2,
        pi_lr=1e-4,
        dagger_noise=0.01,
        batch_size=64,
        replay_size=int(5e3),
        vf_lr=1e-4,
        train_pi_iters=80,
        train_v_iters=80,
        lam=0.999,
        max_ep_len=500,
        target_kl=0.01,
        logger_kwargs=dict(),
        save_freq=10,
        test_freq=10):
    """

    Args:
        env_fn : A function which creates a copy of the environment.
            The environment must satisfy the OpenAI Gym API.

        actor_critic: A function which takes in placeholder symbols 
            for state, ``x_ph``, and action, ``a_ph``, and returns the main 
            outputs from the agent's Tensorflow computation graph:

            ===========  ================  ======================================
            Symbol       Shape             Description
            ===========  ================  ======================================
            ``pi``       (batch, act_dim)  | Samples actions from policy given 
                                           | states.
            ``logp``     (batch,)          | Gives log probability, according to
                                           | the policy, of taking actions ``a_ph``
                                           | in states ``x_ph``.
            ``logp_pi``  (batch,)          | Gives log probability, according to
                                           | the policy, of the action sampled by
                                           | ``pi``.
            ``v``        (batch,)          | Gives the value estimate for states
                                           | in ``x_ph``. (Critical: make sure 
                                           | to flatten this!)
            ===========  ================  ======================================

        ac_kwargs (dict): Any kwargs appropriate for the actor_critic 
            function you provided to PPO.

        seed (int): Seed for random number generators.

        steps_per_epoch (int): Number of steps of interaction (state-action pairs) 
            for the agent and the environment in each epoch.

        epochs (int): Number of epochs of interaction (equivalent to
            number of policy updates) to perform.

        gamma (float): Discount factor. (Always between 0 and 1.)

        clip_ratio (float): Hyperparameter for clipping in the policy objective.
            Roughly: how far can the new policy go from the old policy while 
            still profiting (improving the objective function)? The new policy 
            can still go farther than the clip_ratio says, but it doesn't help
            on the objective anymore. (Usually small, 0.1 to 0.3.)

        pi_lr (float): Learning rate for policy optimizer.

        vf_lr (float): Learning rate for value function optimizer.

        train_pi_iters (int): Maximum number of gradient descent steps to take 
            on policy loss per epoch. (Early stopping may cause optimizer
            to take fewer than this.)

        train_v_iters (int): Number of gradient descent steps to take on 
            value function per epoch.

        lam (float): Lambda for GAE-Lambda. (Always between 0 and 1,
            close to 1.)

        max_ep_len (int): Maximum length of trajectory / episode / rollout.

        target_kl (float): Roughly what KL divergence we think is appropriate
            between new and old policies after an update. This will get used 
            for early stopping. (Usually small, 0.01 or 0.05.)

        policy_path (str): path of pretrained policy model
            train from scratch if None

        logger_kwargs (dict): Keyword args for EpochLogger.

        save_freq (int): How often (in terms of gap between epochs) to save
            the current policy and value function.

    """

    logger = EpochLogger(**logger_kwargs)
    logger.save_config(locals())
    test_logger_kwargs = dict()
    test_logger_kwargs['output_dir'] = osp.join(logger_kwargs['output_dir'],
                                                "test")
    test_logger_kwargs['exp_name'] = logger_kwargs['exp_name']
    test_logger = EpochLogger(**test_logger_kwargs)
    test_logger.save_config(locals())

    seed += 10000 * proc_id()
    tf.set_random_seed(seed)
    np.random.seed(seed)

    env = env_fn()
    obs_dim = env.observation_space.shape
    act_dim = env.action_space.shape

    # Share information about action space with policy architecture
    ac_kwargs['action_space'] = env.action_space
    act_high_limit = env.action_space.high
    act_low_limit = env.action_space.low

    sess = tf.Session()
    if policy_path is None:
        # Inputs to computation graph
        x_ph, a_ph = core.placeholders_from_spaces(env.observation_space,
                                                   env.action_space)
        adv_ph, ret_ph, logp_old_ph = core.placeholders(None, None, None)
        tfa_ph = core.placeholder(act_dim)

        # Main outputs from computation graph
        mu, pi, logp, logp_pi, v = actor_critic(x_ph, a_ph, **ac_kwargs)
        sess.run(tf.global_variables_initializer())

    else:
        # load pretrained model
        # sess, x_ph, a_ph, mu, pi, logp, logp_pi, v = load_policy(policy_path, itr='last', deterministic=False, act_high=env.action_space.high)
        # # get_action_2 = lambda x : sess.run(mu, feed_dict={x_ph: x[None,:]})[0]
        # adv_ph, ret_ph, logp_old_ph = core.placeholders(None, None, None)
        model = restore_tf_graph(sess, osp.join(policy_path, 'simple_save'))
        x_ph, a_ph, adv_ph, ret_ph, logp_old_ph = model['x_ph'], model[
            'a_ph'], model['adv_ph'], model['ret_ph'], model['logp_old_ph']
        mu, pi, logp, logp_pi, v = model['mu'], model['pi'], model[
            'logp'], model['logp_pi'], model['v']
        # tfa_ph = core.placeholder(act_dim)
        tfa_ph = model['tfa_ph']

    # Need all placeholders in *this* order later (to zip with data from buffer)
    all_phs = [x_ph, a_ph, adv_ph, ret_ph, logp_old_ph]

    # Every step, get: action, value, and logprob
    get_action_ops = [pi, v, logp_pi]

    # Experience buffer
    local_steps_per_epoch = int(steps_per_epoch / num_procs())
    print("---------------", local_steps_per_epoch)
    buf = PPOBuffer(obs_dim, act_dim, steps_per_epoch, gamma, lam)
    # print(obs_dim)
    # print(act_dim)
    dagger_replay_buffer = DaggerReplayBuffer(obs_dim=obs_dim[0],
                                              act_dim=act_dim[0],
                                              size=replay_size)
    # Count variables
    var_counts = tuple(core.count_vars(scope) for scope in ['pi', 'v'])
    logger.log('\nNumber of parameters: \t pi: %d, \t v: %d\n' % var_counts)

    # PPO objectives
    if policy_path is None:
        ratio = tf.exp(logp - logp_old_ph)  # pi(a|s) / pi_old(a|s)
        min_adv = tf.where(adv_ph > 0, (1 + clip_ratio) * adv_ph,
                           (1 - clip_ratio) * adv_ph)
        pi_loss = -tf.reduce_mean(tf.minimum(ratio * adv_ph, min_adv))
        v_loss = tf.reduce_mean((ret_ph - v)**2)
        dagger_pi_loss = tf.reduce_mean(tf.square(mu - tfa_ph))

        # Info (useful to watch during learning)
        approx_kl = tf.reduce_mean(
            logp_old_ph -
            logp)  # a sample estimate for KL-divergence, easy to compute
        approx_ent = tf.reduce_mean(
            -logp)  # a sample estimate for entropy, also easy to compute
        clipped = tf.logical_or(ratio > (1 + clip_ratio), ratio <
                                (1 - clip_ratio))
        clipfrac = tf.reduce_mean(tf.cast(clipped, tf.float32))

        # Optimizers
        dagger_pi_optimizer = tf.train.AdamOptimizer(learning_rate=pi_lr)
        optimizer_pi = tf.train.AdamOptimizer(learning_rate=pi_lr)
        optimizer_v = tf.train.AdamOptimizer(learning_rate=vf_lr)
        train_dagger_pi_op = dagger_pi_optimizer.minimize(
            dagger_pi_loss, name='train_dagger_pi_op')
        train_pi = optimizer_pi.minimize(pi_loss, name='train_pi_op')
        train_v = optimizer_v.minimize(v_loss, name='train_v_op')

        sess.run(tf.variables_initializer(optimizer_pi.variables()))
        sess.run(tf.variables_initializer(optimizer_v.variables()))
        sess.run(tf.variables_initializer(dagger_pi_optimizer.variables()))
    else:
        graph = tf.get_default_graph()
        dagger_pi_loss = model['dagger_pi_loss']
        pi_loss = model['pi_loss']
        v_loss = model['v_loss']
        approx_ent = model['approx_ent']
        approx_kl = model['approx_kl']
        clipfrac = model['clipfrac']

        train_dagger_pi_op = graph.get_operation_by_name('train_dagger_pi_op')
        train_pi = graph.get_operation_by_name('train_pi_op')
        train_v = graph.get_operation_by_name('train_v_op')
    # sess = tf.Session()
    # sess.run(tf.global_variables_initializer())

    # Sync params across processes
    # sess.run(sync_all_params())

    tf.summary.FileWriter("log/", sess.graph)
    # Setup model saving
    logger.setup_tf_saver(sess, inputs={'x_ph': x_ph, 'a_ph': a_ph, 'tfa_ph': tfa_ph, 'adv_ph': adv_ph, 'ret_ph': ret_ph, 'logp_old_ph': logp_old_ph}, \
        outputs={'mu': mu, 'pi': pi, 'v': v, 'logp': logp, 'logp_pi': logp_pi, 'clipfrac': clipfrac, 'approx_kl': approx_kl, \
            'pi_loss': pi_loss, 'v_loss': v_loss, 'dagger_pi_loss': dagger_pi_loss, 'approx_ent': approx_ent})

    def update():
        inputs = {k: v for k, v in zip(all_phs, buf.get())}
        pi_l_old, v_l_old, ent = sess.run([pi_loss, v_loss, approx_ent],
                                          feed_dict=inputs)

        # Training
        for i in range(train_pi_iters):
            _, kl = sess.run([train_pi, approx_kl], feed_dict=inputs)
            kl = mpi_avg(kl)
            if kl > 1.5 * target_kl:
                logger.log(
                    'Early stopping at step %d due to reaching max kl.' % i)
                break
        logger.store(StopIter=i)
        for _ in range(train_v_iters):
            sess.run(train_v, feed_dict=inputs)

        # Log changes from update
        pi_l_new, v_l_new, kl, cf = sess.run(
            [pi_loss, v_loss, approx_kl, clipfrac], feed_dict=inputs)
        logger.store(LossPi=pi_l_old,
                     LossV=v_l_old,
                     KL=kl,
                     Entropy=ent,
                     ClipFrac=cf,
                     DeltaLossPi=(pi_l_new - pi_l_old),
                     DeltaLossV=(v_l_new - v_l_old))

    def choose_action(s, add_noise=False):
        s = s[np.newaxis, :]
        a = sess.run(mu, {x_ph: s})[0]
        if add_noise:
            noise = dagger_noise * act_high_limit * np.random.normal(
                size=a.shape)
            a = a + noise
        return np.clip(a, act_low_limit, act_high_limit)

    def test_agent(n=81, test_num=1):
        n = env.unwrapped._set_test_mode(True)
        con_flag = False
        for j in range(n):
            o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0
            while not (d or (ep_len == max_ep_len)):
                # Take deterministic actions at test time (noise_scale=0)
                o, r, d, info = env.step(choose_action(np.array(o), 0))
                ep_ret += r
                ep_len += 1
                if d:
                    test_logger.store(TestEpRet=ep_ret, TestEpLen=ep_len)
                    test_logger.store(arrive_des=info['arrive_des'])
                    test_logger.store(
                        arrive_des_appro=info['arrive_des_appro'])
                    if not info['out_of_range']:
                        test_logger.store(converge_dis=info['converge_dis'])
                        con_flag = True
                    test_logger.store(out_of_range=info['out_of_range'])
                    # print(info)
        # test_logger.dump_tabular()
        # time.sleep(10)
        if not con_flag:
            test_logger.store(converge_dis=10000)
        env.unwrapped._set_test_mode(False)

    def ref_test_agent(n=81, test_num=1):
        n = env.unwrapped._set_test_mode(True)
        con_flag = False
        for j in range(n):
            o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0
            while not (d or (ep_len == max_ep_len)):
                # Take deterministic actions at test time (noise_scale=0)
                a = call_ref_controller(env, expert)
                o, r, d, info = env.step(a)
                ep_ret += r
                ep_len += 1
                if d:
                    test_logger.store(TestEpRet=ep_ret, TestEpLen=ep_len)
                    test_logger.store(arrive_des=info['arrive_des'])
                    test_logger.store(
                        arrive_des_appro=info['arrive_des_appro'])
                    if not info['out_of_range']:
                        test_logger.store(converge_dis=info['converge_dis'])
                        con_flag = True
                    test_logger.store(out_of_range=info['out_of_range'])
                    # print(info)
        # test_logger.dump_tabular()
        if not con_flag:
            test_logger.store(converge_dis=10000)
        env.unwrapped._set_test_mode(False)

    ref_test_agent(test_num=-1)
    test_logger.log_tabular('epoch', -1)
    test_logger.log_tabular('TestEpRet', average_only=True)
    test_logger.log_tabular('TestEpLen', average_only=True)
    test_logger.log_tabular('arrive_des', average_only=True)
    test_logger.log_tabular('arrive_des_appro', average_only=True)
    test_logger.log_tabular('converge_dis', average_only=True)
    test_logger.log_tabular('out_of_range', average_only=True)
    test_logger.dump_tabular()

    start_time = time.time()
    o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0
    test_policy_epochs = 91
    episode_steps = 500
    total_env_t = 0
    test_num = 0
    print(colorize("begin dagger training", 'green', bold=True))
    for epoch in range(1, dagger_epochs + 1, 1):
        # test policy
        if epoch > 0 and (epoch % save_freq == 0) or (epoch == epochs):
            # Save model
            logger.save_state({}, None)

            # Test the performance of the deterministic version of the agent.
            test_num += 1
            test_agent(test_num=test_num)

            test_logger.log_tabular('epoch', epoch)
            test_logger.log_tabular('TestEpRet', average_only=True)
            test_logger.log_tabular('TestEpLen', average_only=True)
            test_logger.log_tabular('arrive_des', average_only=True)
            test_logger.log_tabular('arrive_des_appro', average_only=True)
            test_logger.log_tabular('converge_dis', average_only=True)
            test_logger.log_tabular('out_of_range', average_only=True)
            test_logger.dump_tabular()

        # train policy
        o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0
        env.unwrapped._set_test_mode(False)
        obs, acs, rewards = [], [], []
        for t in range(local_steps_per_epoch):
            a, v_t, logp_t = sess.run(
                get_action_ops, feed_dict={x_ph: np.array(o).reshape(1, -1)})
            # a = get_action_2(np.array(o))
            # save and log
            obs.append(o)
            ref_action = call_ref_controller(env, expert)
            if (epoch < pretrain_epochs):
                action = ref_action
            else:
                action = choose_action(np.array(o), True)

            buf.store(o, action, r, v_t, logp_t)
            logger.store(VVals=v_t)

            o, r, d, _ = env.step(action)
            acs.append(ref_action)
            rewards.append(r)

            ep_ret += r
            ep_len += 1
            total_env_t += 1

            terminal = d or (ep_len == max_ep_len)
            if terminal or (t == local_steps_per_epoch - 1):
                if not (terminal):
                    print('Warning: trajectory cut off by epoch at %d steps.' %
                          ep_len)
                # if trajectory didn't reach terminal state, bootstrap value target
                last_val = r if d else sess.run(
                    v, feed_dict={x_ph: np.array(o).reshape(1, -1)})
                buf.finish_path(last_val)
                if terminal:
                    # only save EpRet / EpLen if trajectory finished
                    logger.store(EpRet=ep_ret, EpLen=ep_len)
                o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0

        # Perform the DAgger and partial PPO update (a standalone sketch of the
        # DAgger regression step follows after this function).
        inputs = {k: v for k, v in zip(all_phs, buf.get())}
        # pi_l_old, v_l_old, ent = sess.run([pi_loss, v_loss, approx_ent], feed_dict=inputs)
        for _ in range(train_v_iters):
            sess.run(train_v, feed_dict=inputs)

        # Log changes from update
        max_step = len(np.array(rewards))
        dagger_replay_buffer.stores(obs, acs, rewards)
        for _ in range(int(local_steps_per_epoch / 10)):
            batch = dagger_replay_buffer.sample_batch(batch_size)
            feed_dict = {x_ph: batch['obs1'], tfa_ph: batch['acts']}
            q_step_ops = [dagger_pi_loss, train_dagger_pi_op]
            for j in range(10):
                outs = sess.run(q_step_ops, feed_dict)
            logger.store(LossPi=outs[0])

        c_v_loss = sess.run(v_loss, feed_dict=inputs)
        logger.store(LossV=c_v_loss,
                     KL=0,
                     Entropy=0,
                     ClipFrac=0,
                     DeltaLossPi=0,
                     DeltaLossV=0,
                     StopIter=0)

        # Log info about epoch
        logger.log_tabular('Epoch', epoch)
        logger.log_tabular('EpRet', with_min_and_max=True)
        logger.log_tabular('EpLen', average_only=True)
        logger.log_tabular('VVals', with_min_and_max=True)
        logger.log_tabular('TotalEnvInteracts', (epoch + 1) * steps_per_epoch)
        logger.log_tabular('LossPi', average_only=True)
        logger.log_tabular('LossV', average_only=True)
        logger.log_tabular('DeltaLossPi', average_only=True)
        logger.log_tabular('DeltaLossV', average_only=True)
        logger.log_tabular('Entropy', average_only=True)
        logger.log_tabular('KL', average_only=True)
        logger.log_tabular('ClipFrac', average_only=True)
        logger.log_tabular('StopIter', average_only=True)
        logger.log_tabular('Time', time.time() - start_time)
        logger.dump_tabular()

    # Main loop: collect experience in env and update/log each epoch
    print(colorize("begin ppo training", 'green', bold=True))
    for epoch in range(1, epochs + 1, 1):
        # test policy
        if epoch > 0 and (epoch % save_freq == 0) or (epoch
                                                      == epochs) or epoch == 1:
            # Save model
            logger.save_state({}, None)

            # Test the performance of the deterministic version of the agent.
            test_num += 1
            test_agent(test_num=test_num)

            test_logger.log_tabular('epoch', epoch)
            test_logger.log_tabular('TestEpRet', average_only=True)
            test_logger.log_tabular('TestEpLen', average_only=True)
            test_logger.log_tabular('arrive_des', average_only=True)
            test_logger.log_tabular('arrive_des_appro', average_only=True)
            test_logger.log_tabular('converge_dis', average_only=True)
            test_logger.log_tabular('out_of_range', average_only=True)
            test_logger.dump_tabular()

        # train policy
        o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0
        env.unwrapped._set_test_mode(False)
        for t in range(local_steps_per_epoch):
            a, v_t, logp_t = sess.run(
                get_action_ops, feed_dict={x_ph: np.array(o).reshape(1, -1)})
            # a = a[0]
            # a = get_action_2(np.array(o))
            # a = np.clip(a, act_low_limit, act_high_limit)
            # if epoch < pretrain_epochs:
            #     a = env.action_space.sample()
            # a = np.clip(a, act_low_limit, act_high_limit)
            # save and log
            buf.store(o, a, r, v_t, logp_t)
            logger.store(VVals=v_t)

            o, r, d, _ = env.step(a[0])
            ep_ret += r
            ep_len += 1

            terminal = d or (ep_len == max_ep_len)
            if terminal or (t == local_steps_per_epoch - 1):
                if not (terminal):
                    print('Warning: trajectory cut off by epoch at %d steps.' %
                          ep_len)
                # if trajectory didn't reach terminal state, bootstrap value target
                last_val = r if d else sess.run(
                    v, feed_dict={x_ph: np.array(o).reshape(1, -1)})
                buf.finish_path(last_val)
                if terminal:
                    # only save EpRet / EpLen if trajectory finished
                    logger.store(EpRet=ep_ret, EpLen=ep_len)
                o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0

        # Perform PPO update!
        update()

        # Log info about epoch
        logger.log_tabular('Epoch', epoch)
        logger.log_tabular('EpRet', with_min_and_max=True)
        logger.log_tabular('EpLen', average_only=True)
        logger.log_tabular('VVals', with_min_and_max=True)
        logger.log_tabular('TotalEnvInteracts', (epoch + 1) * steps_per_epoch)
        logger.log_tabular('LossPi', average_only=True)
        logger.log_tabular('LossV', average_only=True)
        logger.log_tabular('DeltaLossPi', average_only=True)
        logger.log_tabular('DeltaLossV', average_only=True)
        logger.log_tabular('Entropy', average_only=True)
        logger.log_tabular('KL', average_only=True)
        logger.log_tabular('ClipFrac', average_only=True)
        logger.log_tabular('StopIter', average_only=True)
        logger.log_tabular('Time', time.time() - start_time)
        logger.dump_tabular()
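# --- Illustrative aside (hypothetical helper, not part of the original listing) ---
# A minimal NumPy sketch of the DAgger regression step performed inside ppo() above:
# states visited by the learner are labeled with the expert controller's actions and
# the deterministic policy mean is regressed onto those labels with an MSE loss.
# A linear policy trained by plain gradient descent stands in for the network and
# Adam updates used in the real code.
def dagger_regression_sketch(states, expert_actions, lr=1e-2, iters=100):
    """states: (N, obs_dim) array; expert_actions: (N, act_dim) array."""
    states = np.asarray(states, dtype=float)
    expert_actions = np.asarray(expert_actions, dtype=float)
    W = np.zeros((states.shape[1], expert_actions.shape[1]))
    for _ in range(iters):
        pred = states.dot(W)                                  # policy mean mu(s) = s W
        grad = states.T.dot(pred - expert_actions) / len(states)
        W -= lr * grad                                        # gradient step on the MSE loss
    return W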
def sac(env_fn,  expert=None, policy_path=None, actor_critic=core.mlp_actor_critic, ac_kwargs=dict(), seed=0, 
        steps_per_epoch=500, epochs=100000, replay_size=int(5e3), gamma=0.99, 
        dagger_noise=0.02, polyak=0.995, lr=1e-4, alpha=0.2, batch_size=64, dagger_epochs=200, pretrain_epochs=50,
        max_ep_len=500, logger_kwargs=dict(), save_freq=50, update_steps=10):
    """

    Args:
        env_fn : A function which creates a copy of the environment.
            The environment must satisfy the OpenAI Gym API.

        actor_critic: A function which takes in placeholder symbols 
            for state, ``x_ph``, and action, ``a_ph``, and returns the main 
            outputs from the agent's Tensorflow computation graph:

            ===========  ================  ======================================
            Symbol       Shape             Description
            ===========  ================  ======================================
            ``mu``       (batch, act_dim)  | Computes mean actions from policy
                                           | given states.
            ``pi``       (batch, act_dim)  | Samples actions from policy given 
                                           | states.
            ``logp_pi``  (batch,)          | Gives log probability, according to
                                           | the policy, of the action sampled by
                                           | ``pi``. Critical: must be differentiable
                                           | with respect to policy parameters all
                                           | the way through action sampling.
            ``q1``       (batch,)          | Gives one estimate of Q* for 
                                           | states in ``x_ph`` and actions in
                                           | ``a_ph``.
            ``q2``       (batch,)          | Gives another estimate of Q* for 
                                           | states in ``x_ph`` and actions in
                                           | ``a_ph``.
            ``q1_pi``    (batch,)          | Gives the composition of ``q1`` and 
                                           | ``pi`` for states in ``x_ph``: 
                                           | q1(x, pi(x)).
            ``q2_pi``    (batch,)          | Gives the composition of ``q2`` and 
                                           | ``pi`` for states in ``x_ph``: 
                                           | q2(x, pi(x)).
            ``v``        (batch,)          | Gives the value estimate for states
                                           | in ``x_ph``. 
            ===========  ================  ======================================

        ac_kwargs (dict): Any kwargs appropriate for the actor_critic 
            function you provided to SAC.

        seed (int): Seed for random number generators.

        steps_per_epoch (int): Number of steps of interaction (state-action pairs) 
            for the agent and the environment in each epoch.

        epochs (int): Number of epochs to run and train agent.

        replay_size (int): Maximum length of replay buffer.

        gamma (float): Discount factor. (Always between 0 and 1.)

        polyak (float): Interpolation factor in polyak averaging for target 
            networks. Target networks are updated towards main networks 
            according to:

            .. math:: \\theta_{\\text{targ}} \\leftarrow 
                \\rho \\theta_{\\text{targ}} + (1-\\rho) \\theta

            where :math:`\\rho` is polyak. (Always between 0 and 1, usually 
            close to 1.)

        lr (float): Learning rate (used for both policy and value learning).

        alpha (float): Entropy regularization coefficient. (Equivalent to 
            inverse of reward scale in the original SAC paper.)

        batch_size (int): Minibatch size for SGD.

        dagger_epochs (int): Number of initial epochs trained with DAgger-style
            imitation of the expert controller before switching to pure SAC
            training.

        pretrain_epochs (int): Number of DAgger epochs at the start of training
            during which the expert's action is executed directly in the
            environment.

        max_ep_len (int): Maximum length of trajectory / episode / rollout.

        logger_kwargs (dict): Keyword args for EpochLogger.

        save_freq (int): How often (in terms of gap between epochs) to save
            the current policy and value function.

    """

    logger = EpochLogger(**logger_kwargs)
    logger.save_config(locals())
    test_logger_kwargs = dict()
    test_logger_kwargs['output_dir'] = osp.join(logger_kwargs['output_dir'], "test")
    test_logger_kwargs['exp_name'] = logger_kwargs['exp_name']
    test_logger = EpochLogger(**test_logger_kwargs)

    tf.set_random_seed(seed)
    np.random.seed(seed)

    env = env_fn()
    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.shape[0]
    print(obs_dim)
    print(act_dim)
    
    # Share information about action space with policy architecture
    ac_kwargs['action_space'] = env.action_space
    act_high_limit = env.action_space.high
    act_low_limit = env.action_space.low

    sess = tf.Session()
    if policy_path is None:
        # Inputs to computation graph
        x_ph, a_ph, x2_ph, r_ph, d_ph = core.placeholders(obs_dim, act_dim, obs_dim, None, None)
        tfa_ph = core.placeholder(act_dim)
        # Main outputs from computation graph
        with tf.variable_scope('main'):
            mu, pi, logp_pi, q1, q2, q1_pi, q2_pi, v = actor_critic(x_ph, a_ph, **ac_kwargs)
        
        # Target value network
        with tf.variable_scope('target'):
            _, _, _, _, _, _, _, v_targ  = actor_critic(x2_ph, a_ph, **ac_kwargs)
        # sess.run(tf.global_variables_initializer())
    
    else:
        # load pretrained model
        model = restore_tf_graph(sess, osp.join(policy_path, 'simple_save'))
        x_ph, a_ph, x2_ph, r_ph, d_ph = model['x_ph'], model['a_ph'], model['x2_ph'], model['r_ph'], model['d_ph']
        mu, pi, logp_pi, q1, q2, q1_pi, q2_pi, v = model['mu'], model['pi'], model['logp_pi'], model['q1'], model['q2'], model['q1_pi'], model['q2_pi'], model['v']
        # tfa_ph = core.placeholder(act_dim)
        tfa_ph = model['tfa_ph']

    # Experience buffer
    replay_buffer = ReplayBuffer(obs_dim=obs_dim, act_dim=act_dim, size=replay_size)
    dagger_replay_buffer = DaggerReplayBuffer(obs_dim=obs_dim, act_dim=act_dim, size=replay_size)
    # Count variables
    var_counts = tuple(core.count_vars(scope) for scope in 
                       ['main/pi', 'main/q1', 'main/q2', 'main/v', 'main'])
    print(('\nNumber of parameters: \t pi: %d, \t' + \
           'q1: %d, \t q2: %d, \t v: %d, \t total: %d\n')%var_counts)


    # print(obs_dim)
    # print(act_dim)

    # SAC objectives
    if policy_path is None:
        # Min Double-Q:
        min_q_pi = tf.minimum(q1_pi, q2_pi)

        # Targets for Q and V regression (a standalone NumPy sketch of these
        # targets follows after this function)
        q_backup = tf.stop_gradient(r_ph + gamma*(1-d_ph)*v_targ)
        v_backup = tf.stop_gradient(min_q_pi - alpha * logp_pi)

        # Soft actor-critic losses
        dagger_pi_loss = tf.reduce_mean(tf.square(mu-tfa_ph))
        pi_loss = tf.reduce_mean(alpha * logp_pi - q1_pi)
        q1_loss = 0.5 * tf.reduce_mean((q_backup - q1)**2)
        q2_loss = 0.5 * tf.reduce_mean((q_backup - q2)**2)
        v_loss = 0.5 * tf.reduce_mean((v_backup - v)**2)
        value_loss = q1_loss + q2_loss + v_loss

        # Policy train op 
        # (has to be separate from value train op, because q1_pi appears in pi_loss)
        dagger_pi_optimizer = tf.train.AdamOptimizer(learning_rate=lr)
        train_dagger_pi_op = dagger_pi_optimizer.minimize(dagger_pi_loss, name='train_dagger_pi_op')

        pi_optimizer = tf.train.AdamOptimizer(learning_rate=lr)
        train_pi_op = pi_optimizer.minimize(pi_loss, var_list=get_vars('main/pi'), name='train_pi_op')
        # sess.run(tf.variables_initializer(pi_optimizer.variables()))

        # Value train op
        # (control dep of train_pi_op because sess.run otherwise evaluates in nondeterministic order)
        value_optimizer = tf.train.AdamOptimizer(learning_rate=lr)
        value_params = get_vars('main/q') + get_vars('main/v')
        with tf.control_dependencies([train_pi_op]):
            train_value_op = value_optimizer.minimize(value_loss, var_list=value_params, name='train_value_op')
            # sess.run(tf.variables_initializer(value_optimizer.variables()))

        # Polyak averaging for target variables
        # (control flow because sess.run otherwise evaluates in nondeterministic order)
        with tf.control_dependencies([train_value_op]):
            target_update = tf.group([tf.assign(v_targ, polyak*v_targ + (1-polyak)*v_main)
                                    for v_main, v_targ in zip(get_vars('main'), get_vars('target'))])

        # All ops to call during one training step
        step_ops = [pi_loss, q1_loss, q2_loss, v_loss, q1, q2, v, logp_pi, 
                    train_pi_op, train_value_op, target_update]

        # Initializing targets to match main variables
        target_init = tf.group([tf.assign(v_targ, v_main)
                                for v_main, v_targ in zip(get_vars('main'), get_vars('target'))])
        sess.run(tf.global_variables_initializer())
    else:
        graph = tf.get_default_graph()
        dagger_pi_loss = model['dagger_pi_loss']
        pi_loss = model['pi_loss']
        q1_loss = model['q1_loss']
        q2_loss = model['q2_loss']        
        v_loss = model['v_loss']

        train_dagger_pi_op = graph.get_operation_by_name('train_dagger_pi_op')
        train_value_op = graph.get_operation_by_name('train_value_op')
        train_pi_op = graph.get_operation_by_name('train_pi_op')
        
        # Polyak averaging for target variables
        # (control flow because sess.run otherwise evaluates in nondeterministic order)
        with tf.control_dependencies([train_value_op]):
            target_update = tf.group([tf.assign(v_targ, polyak*v_targ + (1-polyak)*v_main)
                                    for v_main, v_targ in zip(get_vars('main'), get_vars('target'))])

        # All ops to call during one training step
        step_ops = [pi_loss, q1_loss, q2_loss, v_loss, q1, q2, v, logp_pi, 
                    train_pi_op, train_value_op, target_update]

        # Initializing targets to match main variables
        target_init = tf.group([tf.assign(v_targ, v_main)
                                for v_main, v_targ in zip(get_vars('main'), get_vars('target'))])
    # sess = tf.Session()
    # sess.run(tf.global_variables_initializer())
    dagger_step_ops = [q1_loss, q2_loss, v_loss, q1, q2, v, logp_pi, train_value_op, target_update]
    tf.summary.FileWriter("log/", sess.graph)
    # Setup model saving
    logger.setup_tf_saver(sess, inputs={'x_ph': x_ph, 'a_ph': a_ph, 'tfa_ph': tfa_ph, 'x2_ph': x2_ph, 'r_ph': r_ph, 'd_ph': d_ph}, \
        outputs={'mu': mu, 'pi': pi, 'v': v, 'logp_pi': logp_pi, 'q1': q1, 'q2': q2, 'q1_pi': q1_pi, 'q2_pi': q2_pi, \
            'pi_loss': pi_loss, 'v_loss': v_loss, 'dagger_pi_loss': dagger_pi_loss, 'q1_loss': q1_loss, 'q2_loss': q2_loss})
    
    def get_action(o, deterministic=False):
        act_op = mu if deterministic else pi
        a = sess.run(act_op, feed_dict={x_ph: o.reshape(1,-1)})[0]
        return np.clip(a, act_low_limit, act_high_limit)

    def choose_action(s, add_noise=False):
        s = s[np.newaxis, :]
        a = sess.run(mu, {x_ph: s})[0]
        if add_noise:
            noise = dagger_noise * act_high_limit * np.random.normal(size=a.shape)
            a = a + noise
        return np.clip(a, act_low_limit, act_high_limit)

    def test_agent(n=81, test_num=1):
        n = env.unwrapped._set_test_mode(True)
        con_flag = False
        for j in range(n):
            o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0
            while not(d or (ep_len == max_ep_len)):
                # Take deterministic actions at test time (noise_scale=0)
                o, r, d, info = env.step(choose_action(np.array(o), 0))
                ep_ret += r
                ep_len += 1
                if d:
                    test_logger.store(TestEpRet=ep_ret, TestEpLen=ep_len)
                    test_logger.store(arrive_des=info['arrive_des'])
                    test_logger.store(arrive_des_appro=info['arrive_des_appro'])
                    if not info['out_of_range']:
                        test_logger.store(converge_dis=info['converge_dis'])
                        con_flag = True
                    test_logger.store(out_of_range=info['out_of_range'])
                    # print(info)
        # test_logger.dump_tabular()
        # time.sleep(10)
        if not con_flag:
            test_logger.store(converge_dis=10000)
        env.unwrapped._set_test_mode(False)

    def ref_test_agent(n=81, test_num=1):
        n = env.unwrapped._set_test_mode(True)
        con_flag = False
        for j in range(n):
            o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0
            while not(d or (ep_len == max_ep_len)):
                # Take deterministic actions at test time (noise_scale=0)
                a  = call_ref_controller(env, expert)
                o, r, d, info = env.step(a)
                ep_ret += r
                ep_len += 1
                if d:
                    test_logger.store(TestEpRet=ep_ret, TestEpLen=ep_len)
                    test_logger.store(arrive_des=info['arrive_des'])
                    test_logger.store(arrive_des_appro=info['arrive_des_appro'])
                    if not info['out_of_range']:
                        test_logger.store(converge_dis=info['converge_dis'])
                        con_flag = True
                    test_logger.store(out_of_range=info['out_of_range'])
                    # print(info)
        # test_logger.dump_tabular()
        if not con_flag:
            test_logger.store(converge_dis=10000)
        env.unwrapped._set_test_mode(False)

    # ref_test_agent(test_num = -1)
    # test_logger.log_tabular('epoch', -1)
    # test_logger.log_tabular('TestEpRet', average_only=True)
    # test_logger.log_tabular('TestEpLen', average_only=True)
    # test_logger.log_tabular('arrive_des', average_only=True)
    # test_logger.log_tabular('arrive_des_appro', average_only=True)
    # test_logger.log_tabular('converge_dis', average_only=True)
    # test_logger.log_tabular('out_of_range', average_only=True)
    # test_logger.dump_tabular()



    start_time = time.time()
    o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0
    episode_steps = 500
    total_env_t = 0
    test_num = 0
    print(colorize("begin dagger training", 'green', bold=True))
    for epoch in range(1, dagger_epochs + 1, 1):
        # test policy
        if epoch > 0 and (epoch % save_freq == 0) or (epoch == epochs):
            # Save model
            logger.save_state({}, None)
            
            # Test the performance of the deterministic version of the agent.
            test_num += 1
            test_agent(test_num=test_num)
            
            test_logger.log_tabular('epoch', epoch)
            test_logger.log_tabular('TestEpRet', average_only=True)
            test_logger.log_tabular('TestEpLen', average_only=True)
            test_logger.log_tabular('arrive_des', average_only=True)
            test_logger.log_tabular('arrive_des_appro', average_only=True)
            test_logger.log_tabular('converge_dis', average_only=True)
            test_logger.log_tabular('out_of_range', average_only=True)
            test_logger.dump_tabular()

            # Log info about epoch
            logger.log_tabular('Epoch', epoch)
            logger.log_tabular('EpRet', with_min_and_max=True)
            logger.log_tabular('EpLen', average_only=True)
            logger.log_tabular('TotalEnvInteracts', t)
            logger.log_tabular('Q1Vals', with_min_and_max=True) 
            logger.log_tabular('Q2Vals', with_min_and_max=True) 
            logger.log_tabular('VVals', with_min_and_max=True) 
            logger.log_tabular('LogPi', with_min_and_max=True)
            logger.log_tabular('LossPi', average_only=True)
            logger.log_tabular('LossQ1', average_only=True)
            logger.log_tabular('LossQ2', average_only=True)
            logger.log_tabular('LossV', average_only=True)
            logger.log_tabular('Time', time.time()-start_time)
            logger.dump_tabular()

        # train policy
        o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0
        env.unwrapped._set_test_mode(False)
        obs, acs, rewards = [], [], []
        for t in range(steps_per_epoch):
            obs.append(o)
            ref_action = call_ref_controller(env, expert)
            if(epoch < pretrain_epochs):
                action = ref_action
            else:
                action = choose_action(np.array(o), True)
            
            o2, r, d, _ = env.step(action)
            acs.append(ref_action)
            rewards.append(r)

            if (t == steps_per_epoch - 1):
                # print ("reached the end")
                d = True

            # Store experience to replay buffer (store the pre-step observation o,
            # then advance to o2)
            replay_buffer.store(o, action, r, o2, d)
            o = o2

            ep_ret += r
            ep_len += 1
            total_env_t += 1

            if d:
                # Perform a partial SAC update!
                for j in range(ep_len):
                    batch = replay_buffer.sample_batch(batch_size)
                    feed_dict = {x_ph: batch['obs1'],
                                x2_ph: batch['obs2'],
                                a_ph: batch['acts'],
                                r_ph: batch['rews'],
                                d_ph: batch['done'],
                                }
                    outs = sess.run(dagger_step_ops, feed_dict)
                    logger.store(LossQ1=outs[0], LossQ2=outs[1],
                                LossV=outs[2], Q1Vals=outs[3], Q2Vals=outs[4],
                                VVals=outs[5], LogPi=outs[6])

                # Perform dagger policy update
                dagger_replay_buffer.stores(obs, acs, rewards)
                for _ in range(int(ep_len/5)):
                    batch = dagger_replay_buffer.sample_batch(batch_size)
                    feed_dict = {x_ph: batch['obs1'], tfa_ph: batch['acts']}
                    q_step_ops = [dagger_pi_loss, train_dagger_pi_op]
                    for j in range(10):
                        outs = sess.run(q_step_ops, feed_dict)
                    logger.store(LossPi = outs[0])

                logger.store(EpRet=ep_ret, EpLen=ep_len)
                o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0
                break

    # Main loop: collect experience in env and update/log each epoch
    print(colorize("begin sac training", 'green', bold=True))
    for epoch in range(1, epochs + 1, 1):
        # test policy
        if epoch > 0 and (epoch % save_freq == 0) or (epoch == epochs):
            # Save model
            logger.save_state({}, None)
            
            # Test the performance of the deterministic version of the agent.
            test_num += 1
            test_agent(test_num=test_num)
            
            test_logger.log_tabular('epoch', epoch)
            test_logger.log_tabular('TestEpRet', average_only=True)
            test_logger.log_tabular('TestEpLen', average_only=True)
            test_logger.log_tabular('arrive_des', average_only=True)
            # test_logger.log_tabular('arrive_des_appro', average_only=True)
            test_logger.log_tabular('converge_dis', average_only=True)
            test_logger.log_tabular('out_of_range', average_only=True)
            test_logger.dump_tabular()

            # Log info about epoch
            logger.log_tabular('Epoch', epoch)
            logger.log_tabular('EpRet', with_min_and_max=True)
            logger.log_tabular('EpLen', average_only=True)
            logger.log_tabular('VVals', with_min_and_max=True)
            logger.log_tabular('TotalEnvInteracts', (epoch+1)*steps_per_epoch)
            logger.log_tabular('LossPi', average_only=True)
            logger.log_tabular('LossV', average_only=True)
            # logger.log_tabular('DeltaLossPi', average_only=True)
            # logger.log_tabular('DeltaLossV', average_only=True)
            # logger.log_tabular('Entropy', average_only=True)
            # logger.log_tabular('KL', average_only=True)
            # logger.log_tabular('ClipFrac', average_only=True)
            # logger.log_tabular('StopIter', average_only=True)
            logger.log_tabular('Time', time.time()-start_time)
            logger.dump_tabular()

        # train policy
        o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0
        env.unwrapped._set_test_mode(False)
        for t in range(steps_per_epoch):
            a = get_action(np.array(o))

            o2, r, d, _ = env.step(a)
            ep_ret += r
            ep_len += 1
            if (t == steps_per_epoch-1):
                # print ("reached the end")
                d = True

            replay_buffer.store(o, a, r, o2, d)
            o = o2
            if d:
                """
                Perform all SAC updates at the end of the trajectory.
                This is a slight difference from the SAC specified in the
                original paper.
                """
                for j in range(ep_len):
                    batch = replay_buffer.sample_batch(batch_size)
                    feed_dict = {x_ph: batch['obs1'],
                                x2_ph: batch['obs2'],
                                a_ph: batch['acts'],
                                r_ph: batch['rews'],
                                d_ph: batch['done'],
                                }
                    outs = sess.run(step_ops, feed_dict)
                    logger.store(LossPi=outs[0], LossQ1=outs[1], LossQ2=outs[2],
                                LossV=outs[3], Q1Vals=outs[4], Q2Vals=outs[5],
                                VVals=outs[6], LogPi=outs[7])

                logger.store(EpRet=ep_ret, EpLen=ep_len)
                o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0
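# --- Illustrative aside (hypothetical helper, not part of the original listing) ---
# A minimal NumPy sketch of the SAC regression targets built inside sac() above:
#   q_backup = r + gamma * (1 - d) * V_targ(s')
#   v_backup = min(Q1(s, a~pi), Q2(s, a~pi)) - alpha * log pi(a~pi | s)
# Inputs are plain arrays; the TensorFlow graph wraps both targets in stop_gradient.
def sac_targets_sketch(r, d, v_targ_next, q1_pi, q2_pi, logp_pi,
                       gamma=0.99, alpha=0.2):
    q_backup = r + gamma * (1.0 - d) * v_targ_next
    v_backup = np.minimum(q1_pi, q2_pi) - alpha * logp_pi
    return q_backup, v_backup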
Ejemplo n.º 19
0
def ppo(env_fn,
        actor_critic=core.mlp_actor_critic,
        ac_kwargs=dict(),
        seed=0,
        steps_per_epoch=4000,
        episodes_per_epoch=None,
        epochs=50,
        gamma=0.99,
        clip_ratio=0.2,
        pi_lr=3e-4,
        vf_lr=1e-3,
        train_pi_iters=80,
        train_v_iters=80,
        lam=0.97,
        max_ep_len=1000,
        target_kl=0.01,
        logger_kwargs=dict(),
        save_freq=10,
        custom_h=None,
        eval_episodes=50,
        do_checkpoint_eval=False,
        env_name=None,
        eval_temp=1.0,
        train_starting_temp=1.0,
        env_version=None,
        env_input=None,
        target_arcs=None,
        early_stop_epochs=None,
        save_all_eval=False,
        meta_learning=False,
        finetune=False,
        finetune_model_path=None):
    """

    Args:
        env_fn : A function which creates a copy of the environment.
            The environment must satisfy the OpenAI Gym API.

        actor_critic: A function which takes in placeholder symbols 
            for state, ``x_ph``, and action, ``a_ph``, and returns the main 
            outputs from the agent's Tensorflow computation graph:

            ===========  ================  ======================================
            Symbol       Shape             Description
            ===========  ================  ======================================
            ``pi``       (batch, act_dim)  | Samples actions from policy given 
                                           | states.
            ``logp``     (batch,)          | Gives log probability, according to
                                           | the policy, of taking actions ``a_ph``
                                           | in states ``x_ph``.
            ``logp_pi``  (batch,)          | Gives log probability, according to
                                           | the policy, of the action sampled by
                                           | ``pi``.
            ``v``        (batch,)          | Gives the value estimate for states
                                           | in ``x_ph``. (Critical: make sure 
                                           | to flatten this!)
            ===========  ================  ======================================

        ac_kwargs (dict): Any kwargs appropriate for the actor_critic 
            function you provided to PPO.

        seed (int): Seed for random number generators.

        steps_per_epoch (int): Number of steps of interaction (state-action pairs) 
            for the agent and the environment in each epoch.

        epochs (int): Number of epochs of interaction (equivalent to
            number of policy updates) to perform.

        gamma (float): Discount factor. (Always between 0 and 1.)

        clip_ratio (float): Hyperparameter for clipping in the policy objective.
            Roughly: how far can the new policy go from the old policy while 
            still profiting (improving the objective function)? The new policy 
            can still go farther than the clip_ratio says, but it doesn't help
            on the objective anymore. (Usually small, 0.1 to 0.3.)

        pi_lr (float): Learning rate for policy optimizer.

        vf_lr (float): Learning rate for value function optimizer.

        train_pi_iters (int): Maximum number of gradient descent steps to take 
            on policy loss per epoch. (Early stopping may cause optimizer
            to take fewer than this.)

        train_v_iters (int): Number of gradient descent steps to take on 
            value function per epoch.

        lam (float): Lambda for GAE-Lambda. (Always between 0 and 1,
            close to 1.)

        max_ep_len (int): Maximum length of trajectory / episode / rollout.

        target_kl (float): Roughly what KL divergence we think is appropriate
            between new and old policies after an update. This will get used 
            for early stopping. (Usually small, 0.01 or 0.05.)

        logger_kwargs (dict): Keyword args for EpochLogger.

        save_freq (int): How often (in terms of gap between epochs) to save
            the current policy and value function.

    """
    # create logger for training
    logger = EpochLogger(meta_learning_or_finetune=(finetune or meta_learning),
                         **logger_kwargs)
    logger.save_config(locals())

    # create logger for evaluation to keep track of evaluation values at each checkpoint (or save frequency)
    # using eval_progress.txt. It is different from the logger_eval used inside one evaluation epoch.
    logger_eval_progress = EpochLogger(output_fname='progress_eval.txt',
                                       **logger_kwargs)

    # create logger for evaluation and save best performance, best structure, and best model in simple_save999999
    logger_eval = EpochLogger(**dict(
        exp_name=logger_kwargs['exp_name'],
        output_dir=os.path.join(logger.output_dir, "simple_save999999")))

    # create logger for tensorboard
    tb_logdir = "{}/tb_logs/".format(logger.output_dir)
    tb_logger = Logger(log_dir=tb_logdir)

    seed += 10000 * proc_id()
    tf.set_random_seed(seed)
    np.random.seed(seed)
    logger.log('set tf and np random seed = {}'.format(seed))

    env = env_fn()
    obs_dim = env.observation_space.shape
    act_dim = env.action_space.shape

    # Share information about action space with policy architecture
    ac_kwargs['action_space'] = env.action_space

    if custom_h is not None:
        hidden_layers_str_list = custom_h.split('-')
        hidden_layers_int_list = [int(h) for h in hidden_layers_str_list]
        ac_kwargs['hidden_sizes'] = hidden_layers_int_list

    # create a tf session with gpu_options.allow_growth enabled so that a single
    # program does not claim all of the GPU memory
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    # log tf graph
    tf.summary.FileWriter(tb_logdir, sess.graph)

    if not finetune:
        # Inputs to computation graph
        x_ph, a_ph = core.placeholders_from_spaces(env.observation_space,
                                                   env.action_space)
        adv_ph, ret_ph, logp_old_ph = core.placeholders(None, None, None)

        temperature_ph = tf.placeholder(tf.float32, shape=(), name="init")

        # Main outputs from computation graph
        pi, logp, logp_pi, v = actor_critic(x_ph, a_ph, temperature_ph,
                                            **ac_kwargs)

        # PPO objectives
        ratio = tf.exp(logp - logp_old_ph)  # pi(a|s) / pi_old(a|s)
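        # tf.where picks (1 + clip_ratio) * adv_ph when adv_ph > 0 and
        # (1 - clip_ratio) * adv_ph otherwise, so min(ratio * adv_ph, min_adv)
        # equals min(ratio * adv, clip(ratio, 1 - clip_ratio, 1 + clip_ratio) * adv).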
        min_adv = tf.where(adv_ph > 0, (1 + clip_ratio) * adv_ph,
                           (1 - clip_ratio) * adv_ph)
        pi_loss = -tf.reduce_mean(tf.minimum(ratio * adv_ph, min_adv))
        v_loss = tf.reduce_mean((ret_ph - v)**2)

        # Info (useful to watch during learning)
        approx_kl = tf.reduce_mean(
            logp_old_ph -
            logp)  # a sample estimate for KL-divergence, easy to compute
        approx_ent = tf.reduce_mean(
            -logp)  # a sample estimate for entropy, also easy to compute
        clipped = tf.logical_or(ratio > (1 + clip_ratio), ratio <
                                (1 - clip_ratio))
        clipfrac = tf.reduce_mean(tf.cast(clipped, tf.float32))

        # Optimizers
        train_pi = tf.compat.v1.train.AdamOptimizer(
            learning_rate=pi_lr).minimize(pi_loss, name='train_pi')
        train_v = tf.compat.v1.train.AdamOptimizer(
            learning_rate=vf_lr).minimize(v_loss, name='train_v')
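        # The optimizer ops are given explicit names ('train_pi', 'train_v'),
        # presumably so the finetuning branch below can fetch them from the
        # restored graph as model['train_pi'] / model['train_v'].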

        sess.run(tf.global_variables_initializer())

    else:  # do finetuning -- load model from finetune_model_path
        assert finetune_model_path is not None, "Please specify the path to the meta learnt model using --finetune_model_path"
        if 'simple_save' in finetune_model_path:
            model = restore_tf_graph(sess,
                                     fpath=finetune_model_path,
                                     meta_learning_or_finetune=finetune)
        else:
            model = restore_tf_graph(sess,
                                     fpath=finetune_model_path +
                                     '/simple_save999999',
                                     meta_learning_or_finetune=finetune)

        # get placeholders
        x_ph, a_ph, adv_ph = model['x'], model['a'], model['adv']
        ret_ph, logp_old_ph, temperature_ph = model['ret'], model[
            'logp_old'], model['temperature']

        # get model output
        pi, logp, logp_pi, v = model['pi'], model['logp'], model[
            'logp_pi'], model['v']
        pi_loss, v_loss = model['pi_loss'], model['v_loss']
        approx_kl, approx_ent, clipfrac = model['approx_kl'], model[
            'approx_ent'], model['clipfrac']

        # get Optimizers
        train_pi = model['train_pi']
        train_v = model['train_v']

    # Need all placeholders in *this* order later (to zip with data from buffer)
    all_phs = [x_ph, a_ph, adv_ph, ret_ph, logp_old_ph, temperature_ph]

    # Every step, get: action, value, and logprob
    get_action_ops = [pi, v, logp_pi]

    # Experience buffer
    local_steps_per_epoch = int(steps_per_epoch / num_procs())
    buf = PPOBuffer(obs_dim, act_dim, local_steps_per_epoch, gamma, lam)
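    # PPOBuffer is expected to compute GAE-lambda advantage estimates when
    # finish_path() is called (the standard Spinning Up buffer behaviour):
    #     delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
    #     A_t     = sum_l (gamma * lam)^l * delta_{t+l}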

    # Count variables
    var_counts = tuple(core.count_vars(scope) for scope in ['pi', 'v'])
    logger.log('\nNumber of parameters: \t pi: %d, \t v: %d\n' % var_counts)

    # # log tf graph
    # tf.summary.FileWriter(tb_logdir, sess.graph)

    # Sync params across processes
    sess.run(sync_all_params())

    # Setup model saving
    logger.setup_tf_saver(sess,
                          inputs={
                              'x': x_ph,
                              'a': a_ph,
                              'adv': adv_ph,
                              'ret': ret_ph,
                              'logp_old': logp_old_ph,
                              'temperature': temperature_ph
                          },
                          outputs={
                              'pi': pi,
                              'v': v,
                              'logp': logp,
                              'logp_pi': logp_pi,
                              'pi_loss': pi_loss,
                              'v_loss': v_loss,
                              'approx_kl': approx_kl,
                              'approx_ent': approx_ent,
                              'clipfrac': clipfrac
                          })
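    # Note: the input/output keys registered here are exactly the keys that the
    # finetuning branch reads back from the restored graph (model['x'], model['adv'],
    # model['pi_loss'], ...), so a renamed key must be changed in both places.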

    def update():
        inputs = {k: v for k, v in zip(all_phs, buf.get())}
        pi_l_old, v_l_old, ent = sess.run([pi_loss, v_loss, approx_ent],
                                          feed_dict=inputs)

        # Training
        for i in range(train_pi_iters):
            _, kl = sess.run([train_pi, approx_kl], feed_dict=inputs)
            kl = mpi_avg(kl)
            if kl > 1.5 * target_kl:
                logger.log(
                    'Early stopping at step %d due to reaching max kl.' % i)
                break
        logger.store(StopIter=i)
        for _ in range(train_v_iters):
            sess.run(train_v, feed_dict=inputs)

        # Log changes from update
        pi_l_new, v_l_new, kl, cf = sess.run(
            [pi_loss, v_loss, approx_kl, clipfrac], feed_dict=inputs)
        logger.store(LossPi=pi_l_old,
                     LossV=v_l_old,
                     KL=kl,
                     Entropy=ent,
                     ClipFrac=cf,
                     DeltaLossPi=(pi_l_new - pi_l_old),
                     DeltaLossV=(v_l_new - v_l_old))

    start_time = time.time()
    o, r, d, ep_ret, ep_len, ep_dummy_action_count, ep_len_normalized = env.reset(
    ), 0, False, 0, 0, 0, []

    # initialize variables for keeping track of BEST eval performance
    best_eval_AverageEpRet = -0.05  # a negative value so that best model is saved at least once.
    best_eval_StdEpRet = 1.0e30

    # below are used for early stopping. We stop early if
    # 1) a best model has been saved, and
    # 2) early_stop_epochs epochs have passed without a new best-model save
    saved = False
    early_stop_count_started = False
    episode_count_after_saved = 0

    # Main loop: collect experience in env and update/log each epoch
    for epoch in range(epochs):
        current_temp = _get_current_temperature(epoch, epochs,
                                                train_starting_temp)
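        # _get_current_temperature presumably anneals the softmax temperature,
        # starting from train_starting_temp; the value is fed through temperature_ph
        # both when acting and when bootstrapping the value below.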
        for t in range(local_steps_per_epoch):
            a, v_t, logp_t = sess.run(get_action_ops,
                                      feed_dict={
                                          x_ph: o.reshape(1, -1),
                                          temperature_ph: current_temp
                                      })

            # save and log
            buf.store(o, a, r, v_t, logp_t, current_temp)
            logger.store(VVals=v_t)

            o, r, d, _ = env.step(a[0])
            ep_ret += r
            ep_len += 1

            if env_version >= 4:
                ep_len_normalized.append(ep_len / env.allowed_steps)
                if env.action_is_dummy:  # a is dummy action
                    ep_dummy_action_count += 1

            terminal = d or (ep_len == max_ep_len)

            if terminal or (t == local_steps_per_epoch - 1):
                if not terminal:
                    print('Warning: trajectory cut off by epoch at %d steps.' %
                          ep_len)
                # if trajectory didn't reach terminal state, bootstrap value target
                last_val = r if d else sess.run(v,
                                                feed_dict={
                                                    x_ph: o.reshape(1, -1),
                                                    temperature_ph:
                                                    current_temp
                                                })
                buf.finish_path(last_val)
                if terminal:
                    # only save EpRet / EpLen if trajectory finished
                    logger.store(EpRet=ep_ret, EpLen=ep_len)
                    if env_version >= 4:
                        logger.store(EpDummyCount=ep_dummy_action_count)
                        logger.store(EpTotalArcs=env.adjacency_matrix.sum())

                        assert len(ep_len_normalized) > 0
                        ep_len_normalized = np.asarray(
                            ep_len_normalized, dtype=np.float32).mean()
                        logger.store(EpDummyStepsNormalized=ep_len_normalized)

                o, r, d, ep_ret, ep_len, ep_dummy_action_count, ep_len_normalized = env.reset(
                ), 0, False, 0, 0, 0, []

        # Save model
        if (epoch % save_freq == 0) or (epoch == epochs - 1):

            if meta_learning:
                # Save a new model every save_freq and at the last epoch. Do not overwrite the previous save.
                logger.save_state({'env_name': env_name}, epoch)
            else:
                # Save a new model every save_freq and at the last epoch. Only keep one copy - the current model
                logger.save_state({'env_name': env_name})

            # Evaluate and save best model
            if do_checkpoint_eval and epoch > 0:
                # below is a hack: best-model-related files are saved at itr 999999, hence simple_save999999.
                # This way, test_policy and plot can be used directly on the best models.
                # saved best models includes:
                # 1) a copy of the env_name
                # 2) the best rl model with parameters
                # 3) a pickle file "best_eval_performance_n_structure" storing best_performance, best_structure and epoch
                # note that 1) and 2) are spinningup defaults, and 3) is a custom save
                best_eval_AverageEpRet, best_eval_StdEpRet, saved = eval_and_save_best_model(
                    best_eval_AverageEpRet,
                    best_eval_StdEpRet,
                    # a new best logger is created and passed in so that the new logger can leverage the directory
                    # structure without messing up the logger in the training loop
                    # eval_logger=EpochLogger(**dict(
                    #     exp_name=logger_kwargs['exp_name'],
                    #     output_dir=os.path.join(logger.output_dir, "simple_save999999"))),
                    eval_logger=logger_eval,
                    train_logger=logger,
                    eval_progress_logger=logger_eval_progress,
                    tb_logger=tb_logger,
                    epoch=epoch,
                    # the env_name is passed in so that an env can be created when and where it is needed. This
                    # works around a logx.save_state() error where an env pointer cannot be pickled
                    env_name="F{}x{}T{}_SP{}_v{}".format(
                        env.n_plant, env.n_product, env.target_arcs,
                        env.n_sample, env_version)
                    if env_version >= 3 else env_name,
                    env_version=env_version,
                    env_input=env_input,
                    render=
                    False,  # change this to True if you want to visualize how arcs are added during evaluation
                    target_arcs=env.input_target_arcs,
                    get_action=lambda x: sess.run(pi,
                                                  feed_dict={
                                                      x_ph: x[None, :],
                                                      temperature_ph: eval_temp
                                                  })[0],
                    # number of samples to draw when simulate demand
                    n_sample=5000,
                    num_episodes=eval_episodes,
                    seed=seed,
                    save_all_eval=save_all_eval)

        # Perform PPO update!
        update()

        # Log into tensorboard
        for key, with_min_and_max in [("EpRet", True), ("EpLen", False),
                                      ("VVals", True), ("LossPi", False),
                                      ("LossV", False), ("DeltaLossPi", False),
                                      ("DeltaLossV", False), ("Entropy", False),
                                      ("KL", False), ("ClipFrac", False),
                                      ("StopIter", False)]:
            log_key_to_tb(tb_logger,
                          logger,
                          epoch,
                          key=key,
                          with_min_and_max=with_min_and_max)
        tb_logger.log_scalar(tag="TotalEnvInteracts",
                             value=(epoch + 1) * steps_per_epoch,
                             step=epoch)
        tb_logger.log_scalar(tag="Time",
                             value=time.time() - start_time,
                             step=epoch)
        tb_logger.log_scalar(tag="epoch_temp", value=current_temp, step=epoch)
        if env_version >= 4:
            log_key_to_tb(tb_logger,
                          logger,
                          epoch,
                          key="EpDummyCount",
                          with_min_and_max=False)
            log_key_to_tb(tb_logger,
                          logger,
                          epoch,
                          key="EpTotalArcs",
                          with_min_and_max=False)

            if 'EpDummyStepsNormalized' in logger.epoch_dict.keys():
                if len(logger.epoch_dict['EpDummyStepsNormalized']) > 0:
                    log_key_to_tb(tb_logger,
                                  logger,
                                  epoch,
                                  key="EpDummyStepsNormalized",
                                  with_min_and_max=False)

        # Log info about epoch
        logger.log_tabular('Epoch', epoch)
        logger.log_tabular('EpRet', with_min_and_max=True)
        logger.log_tabular('EpLen', average_only=True)
        logger.log_tabular('VVals', with_min_and_max=True)
        logger.log_tabular('TotalEnvInteracts', (epoch + 1) * steps_per_epoch)
        logger.log_tabular('LossPi', average_only=True)
        logger.log_tabular('LossV', average_only=True)
        logger.log_tabular('DeltaLossPi', average_only=True)
        logger.log_tabular('DeltaLossV', average_only=True)
        logger.log_tabular('Entropy', average_only=True)
        logger.log_tabular('KL', average_only=True)
        logger.log_tabular('ClipFrac', average_only=True)
        logger.log_tabular('StopIter', average_only=True)
        logger.log_tabular('Time', time.time() - start_time)
        logger.log_tabular('EpochTemp', current_temp)
        if env_version >= 4:
            logger.log_tabular('EpDummyCount', with_min_and_max=True)
            if 'EpDummyStepsNormalized' in logger.epoch_dict.keys():
                if len(logger.epoch_dict['EpDummyStepsNormalized']) > 0:
                    logger.log_tabular('EpDummyStepsNormalized',
                                       average_only=True)
            logger.log_tabular('EpTotalArcs', average_only=True)

        logger.dump_tabular()

        if early_stop_epochs > 0:
            # check for early stop
            if saved:
                # start counting the epochs elapsed after a "saved" event
                early_stop_count_started = True

                # reset the count to 0
                episode_count_after_saved = 0

            else:
                # count this epoch only if counting has started, i.e., early_stop_count_started == True
                if early_stop_count_started:
                    episode_count_after_saved += 1
                    if episode_count_after_saved > early_stop_epochs:
                        logger.log('Early Stopped at epoch {}.'.format(epoch),
                                   color='cyan')
                        break
def load_policy(fpath,
                itr='last',
                deterministic=False,
                act_high=1,
                hidden_sizes=(64, 64),
                activation=tf.tanh):

    # handle which epoch to load from
    if itr == 'last':
        saves = [
            int(x[11:]) for x in os.listdir(fpath)
            if 'simple_save' in x and len(x) > 11
        ]
        itr = '%d' % max(saves) if len(saves) > 0 else ''
    else:
        itr = '%d' % itr

    # load the things!
    sess = tf.Session()
    print("itr:", itr)
    model = restore_tf_graph(sess, osp.join(fpath, 'simple_save' + itr))

    if deterministic and 'mu' in model.keys():
        print('Using deterministic action op.')
        with tf.variable_scope("pi", reuse=True):
            mu = model['mu']
    else:
        print('Using default action op.')
        with tf.variable_scope("pi", reuse=True):
            mu = model['pi']

    x = model['x']
    a = model['a']

    # Helper for building Q-function / value heads on top of the restored placeholders.
    vf_mlp = lambda x: tf.squeeze(
        core.mlp(x, list(hidden_sizes) + [1], activation, None), axis=1)

    # Remember which variables were restored from the checkpoint so that only the
    # newly created ones are initialized further below.
    restored_vars = set(tf.global_variables())

    # Rebuild a Gaussian policy head, reusing the restored 'pi' scope. This has to
    # happen before the Q-function graphs, which take `pi` as an input.
    LOG_STD_MAX = 2
    LOG_STD_MIN = -20
    act_dim = a.shape.as_list()[-1]
    with tf.variable_scope("pi", reuse=True):
        # log_std = tf.constant(0.01*act_high, dtype=tf.float32, shape=(act_dim,))
        net = core.mlp(x, list(hidden_sizes), activation, activation)
        log_std = tf.layers.dense(net, act_dim, activation=tf.tanh)
    log_std = LOG_STD_MIN + 0.5 * (LOG_STD_MAX - LOG_STD_MIN) * (log_std + 1)
    # log_std = tf.get_variable(name='log_std', initializer=math.log(0.01*act_high[0])*np.ones(act_dim, dtype=np.float32))
    std = tf.exp(log_std)
    with tf.variable_scope("pi", reuse=True):
        pi = mu + tf.random_normal(tf.shape(mu)) * std
        logp_pi = core.gaussian_likelihood(pi, mu, log_std)

    with tf.variable_scope('q1'):
        q1 = vf_mlp(tf.concat([x, a], axis=-1))
    with tf.variable_scope('q1', reuse=True):
        q1_pi = vf_mlp(tf.concat([x, pi], axis=-1))
    with tf.variable_scope('q2'):
        q2 = vf_mlp(tf.concat([x, a], axis=-1))
    with tf.variable_scope('q2', reuse=True):
        q2_pi = vf_mlp(tf.concat([x, pi], axis=-1))

    if 'v' in model.keys():
        print("value function already in model")
        with tf.variable_scope('v'):
            v = model['v']
    else:
        with tf.variable_scope('v'):
            v = vf_mlp(x)

    # Initialize only the variables created after the restore; re-running a global
    # initializer here would overwrite the restored weights, and log_std is a
    # tensor rather than a tf.Variable, so it cannot be passed to an initializer.
    new_vars = [var for var in tf.global_variables() if var not in restored_vars]
    sess.run(tf.variables_initializer(new_vars))

    # get_action = lambda x : sess.run(mu, feed_dict={model['x']: x[None,:]})[0]

    return sess, model['x'], model[
        'a'], mu, pi, logp_pi, q1, q2, q1_pi, q2_pi, v
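
# A minimal usage sketch for the loader above (hypothetical experiment path and
# a hypothetical `env`; not part of the original source):
#
#   sess, x_ph, a_ph, mu, pi, logp_pi, q1, q2, q1_pi, q2_pi, v = \
#       load_policy('/path/to/experiment', itr='last', deterministic=True)
#   obs = env.reset()
#   act = sess.run(mu, feed_dict={x_ph: obs[None, :]})[0]  # deterministic action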