Example #1
 def init_phi(self, 
         pf_batch_size=32,
         pf_weight_decay=0.,
         pf_update_method='adam',
         **kwargs):
     
     self.pf_batch_size = pf_batch_size
     self.pf_weight_decay = pf_weight_decay
     self.pf_update_method = \
         FirstOrderOptimizer(
             update_method=pf_update_method,
             learning_rate=self.pf_learning_rate
         )
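The examples in this listing construct FirstOrderOptimizer in two styles: by naming an update method and a learning rate (as above), or by handing over a TensorFlow optimizer class directly (as in Example #11). A minimal sketch of both, assuming an rllab-style FirstOrderOptimizer; the exact import path and accepted keywords vary between rllab forks:

import tensorflow as tf
# Import path is an assumption; it differs between rllab variants.
from sandbox.rocky.tf.optimizers.first_order_optimizer import FirstOrderOptimizer

# Style used above: name an update method and a learning rate.
pf_optimizer = FirstOrderOptimizer(update_method='adam', learning_rate=1e-3)

# Style used in Example #11: pass a TF optimizer class and its arguments.
vpg_optimizer = FirstOrderOptimizer(
    tf_optimizer_cls=tf.train.AdamOptimizer,
    tf_optimizer_args=dict(learning_rate=1e-3),
    batch_size=None,
    max_epochs=1,
)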
Example #2
    def __init__(self,
                 optimizer=None,
                 optimizer_args=None,
                 step_size=0.01,
                 use_maml=True,
                 **kwargs):
        assert optimizer is not None  # only for use with MAML TRPO

        self.optimizer = optimizer
        self.offPolicy_optimizer = FirstOrderOptimizer(max_epochs=1)
        self.step_size = step_size
        self.use_maml = use_maml
        self.kl_constrain_step = -1  # needs to be 0 or -1 (original pol params, or new pol params)
        super(MAMLNPO, self).__init__(**kwargs)
Example #3
 def __init__(self,
              env,
              policy_list,
              baseline_list,
              optimizer=None,
              optimizer_args=None,
              use_maml=True,
              **kwargs):
     Serializable.quick_init(self, locals())
     if optimizer is None:
         default_args = dict(
             batch_size=None,
             max_epochs=1,
         )
         if optimizer_args is None:
             optimizer_args = default_args
         else:
             optimizer_args = dict(default_args, **optimizer_args)
          optimizer_list = [
              FirstOrderOptimizer(**optimizer_args)
              for _ in range(len(policy_list))
          ]
      else:
          # Without this branch, `optimizer_list` is undefined when an optimizer
          # is passed in; here the supplied optimizer is reused for every policy.
          optimizer_list = [optimizer] * len(policy_list)
      self.optimizer_list = optimizer_list
     self.opt_info = None
     super(BMAMLREPTILE, self).__init__(env=env,
                                        policy_list=policy_list,
                                        baseline_list=baseline_list,
                                        **kwargs)
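The dict(default_args, **optimizer_args) idiom above recurs in most of these constructors: caller-supplied optimizer_args override the defaults, and keys the caller did not set keep their default values. A minimal illustration of the merge semantics:

default_args = dict(batch_size=None, max_epochs=1)
optimizer_args = dict(max_epochs=5)

# Caller-supplied keys win; missing keys fall back to the defaults.
merged = dict(default_args, **optimizer_args)
assert merged == {'batch_size': None, 'max_epochs': 5}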
Example #4
    def __init__(
            self,
            env,
            policy,
            baseline,
            env_path,
            env_num,
            env_keep_itr,
            optimizer=None,
            optimizer_args=None,
            transfer=True,
            record_env=True,
            **kwargs):
        Serializable.quick_init(self, locals())
        if optimizer is None:
            default_args = dict(
                batch_size=None,
                max_epochs=1,
            )
            if optimizer_args is None:
                optimizer_args = default_args
            else:
                optimizer_args = dict(default_args, **optimizer_args)
            optimizer = FirstOrderOptimizer(**optimizer_args)
        self.optimizer = optimizer
        self.opt_info = None

        self.transfer = transfer
        self.record_env = record_env
        self.env_path = env_path
        self.env_num = env_num
        self.env_keep_itr = env_keep_itr
        super(VPG_t, self).__init__(env=env, policy=policy, baseline=baseline, sampler_cls=QMDPSampler,sampler_args=dict(), **kwargs)
Example #5
 def __init__(self,
              env,
              policy,
              baseline,
              optimizer=None,
              optimizer_args=None,
              use_maml=True,
              **kwargs):
     Serializable.quick_init(self, locals())
     if optimizer is None:
         default_args = dict(
             batch_size=None,
             max_epochs=1,
         )
         if optimizer_args is None:
             optimizer_args = default_args
         else:
             optimizer_args = dict(default_args, **optimizer_args)
         optimizer = FirstOrderOptimizer(**optimizer_args)
     self.optimizer = optimizer
     self.opt_info = None
     self.use_maml = use_maml
     super(MAMLVPG, self).__init__(env=env,
                                   policy=policy,
                                   baseline=baseline,
                                   use_maml=use_maml,
                                   **kwargs)
Example #6
    def __init__(
            self,
            env,
            policy,
            baseline,
            optimizer=None,
            optimizer_args=None,
            **kwargs):
        Serializable.quick_init(self, locals())
        if optimizer is None:
            default_args = dict(
                batch_size=None,
                max_epochs=1,
            )
            if optimizer_args is None:
                optimizer_args = default_args
            else:
                optimizer_args = dict(default_args, **optimizer_args)
            optimizer = FirstOrderOptimizer(**optimizer_args)
        self.optimizer = optimizer
        self.opt_info = None
        if "extra_input" in kwargs.keys():
            self.extra_input = kwargs["extra_input"]
        else:
            self.extra_input = ""
        if "extra_input_dim" in kwargs.keys():
            self.extra_input_dim = kwargs["extra_input_dim"]
        else:
            self.extra_input_dim = 0

        super(VPG, self).__init__(env=env, policy=policy, baseline=baseline, **kwargs)
Example #7
 def __init__(
         self,
         optimizer=None,
         optimizer_args=None,
         step_size=0.01,
          use_maml=True,
          **kwargs):
      assert optimizer is not None  # only for use with MAML TRPO
      # NOTE: with the assert above, the `optimizer is None` fallback below is unreachable.
      if optimizer is None:
          if optimizer_args is None:
              optimizer_args = dict()
          optimizer = PenaltyLbfgsOptimizer(**optimizer_args)
      if not use_maml:
          default_args = dict(
              batch_size=None,
              max_epochs=1,
          )
          optimizer = FirstOrderOptimizer(**default_args)
     self.optimizer = optimizer
     self.step_size = step_size
     self.use_maml = use_maml
     self.kl_constrain_step = -1  # needs to be 0 or -1 (original pol params, or new pol params)
    
     super(MAMLNPO, self).__init__(**kwargs)
Example #8
def get_baseline(env, value_function, num_slices):
    if (value_function == 'zero'):
        baseline = ZeroBaseline(env.spec)
    else:
        value_network = get_value_network(env)

        if (value_function == 'conj'):
            baseline_optimizer = ConjugateGradientOptimizer(
                subsample_factor=1.0, num_slices=num_slices)
        elif (value_function == 'adam'):
            baseline_optimizer = FirstOrderOptimizer(
                max_epochs=3,
                batch_size=512,
                num_slices=num_slices,
                ignore_last=True,
                #verbose=True
            )
        else:
            logger.log("Inappropirate value function")
            exit(0)

        baseline = DeterministicMLPBaseline(env.spec,
                                            num_slices=num_slices,
                                            regressor_args=dict(
                                                network=value_network,
                                                optimizer=baseline_optimizer,
                                                normalize_inputs=False))

    return baseline
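A hedged usage sketch of get_baseline; env is assumed to be an rllab-style environment exposing env.spec, and value_function must be one of the three strings handled above:

# Hypothetical call: 'zero' returns a ZeroBaseline, 'conj' fits the value network
# with ConjugateGradientOptimizer, 'adam' fits it with FirstOrderOptimizer, and
# anything else logs "Inappropriate value function" and exits.
baseline = get_baseline(env, value_function='adam', num_slices=1)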
Example #9
    def init_critic(self,
            min_pool_size=10000,
            replay_pool_size=1000000,
            replacement_prob=1.0,
            qf_batch_size=32,
            qf_weight_decay=0.,
            qf_update_method='adam',
            qf_learning_rate=1e-3,
            qf_use_target=True,
            qf_mc_ratio=0,
            qf_residual_phi=0,
            soft_target_tau=0.001,
            **kwargs):
        self.soft_target_tau = soft_target_tau
        self.min_pool_size = min_pool_size
        self.replay_pool_size = replay_pool_size
        self.replacement_prob = replacement_prob
        self.qf_batch_size = qf_batch_size
        self.qf_weight_decay = qf_weight_decay
        self.qf_update_method = \
            FirstOrderOptimizer(
                update_method=qf_update_method,
                learning_rate=qf_learning_rate,
            )
        self.qf_learning_rate = qf_learning_rate
        self.qf_use_target = qf_use_target

        self.qf_mc_ratio = qf_mc_ratio
        if self.qf_mc_ratio > 0:
            self.mc_y_averages = []

        self.qf_residual_phi = qf_residual_phi
        if self.qf_residual_phi > 0:
            self.residual_y_averages = []
            self.qf_residual_loss_averages = []

        self.qf_loss_averages = []
        self.q_averages = []
        self.y_averages = []
Example #10
 def __init__(self, env, policy_or_policies, baseline_or_baselines, optimizer=None,
              optimizer_args=None, **kwargs):
     Serializable.quick_init(self, locals())
     if optimizer is None:
         default_args = dict(
             batch_size=None,
             max_epochs=1,)
         if optimizer_args is None:
             optimizer_args = default_args
         else:
             optimizer_args = dict(default_args, **optimizer_args)
         optimizer = FirstOrderOptimizer(**optimizer_args)
     self.optimizer = optimizer
     self.opt_info = None
     super(MAVPG, self).__init__(env=env, policy_or_policies=policy_or_policies,
                                 baseline_or_baselines=baseline_or_baselines, **kwargs)
Example #11
def opt_vpg(env,
            baseline,
            policy,
            learning_rate=1e-5,
            batch_size=4000,
            **kwargs):
    # no idea what batch size, learning rate, etc. should be
    optimiser = FirstOrderOptimizer(
        tf_optimizer_cls=tf.train.AdamOptimizer,
        tf_optimizer_args=dict(learning_rate=learning_rate),
        # batch_size actually gets passed to BatchPolopt (parent of VPG)
        # instead of TF optimiser (makes sense, I guess)
        batch_size=None,
        max_epochs=1)
    return VPG(env=env,
               policy=policy,
               baseline=baseline,
               n_itr=int(1e9),
               optimizer=optimiser,
               batch_size=batch_size,
               **kwargs)
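A minimal usage sketch for opt_vpg; env, policy, and baseline are assumed to be constructed elsewhere, and extra keyword arguments are forwarded to VPG (and from there to BatchPolopt) via **kwargs:

# Hypothetical wiring; argument values mirror the defaults in opt_vpg above.
algo = opt_vpg(env=env,
               baseline=baseline,
               policy=policy,
               learning_rate=1e-5,
               batch_size=4000,
               max_path_length=500)  # forwarded to VPG via **kwargs
algo.train()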
Example #12
class Poleval():
    """
    Base class defining methods for policy evaluation.
    """
    def init_critic(self,
            min_pool_size=10000,
            replay_pool_size=1000000,
            replacement_prob=1.0,
            qf_batch_size=32,
            qf_weight_decay=0.,
            qf_update_method='adam',
            qf_learning_rate=1e-3,
            qf_use_target=True,
            qf_mc_ratio=0,
            qf_residual_phi=0,
            soft_target_tau=0.001,
            **kwargs):
        self.soft_target_tau = soft_target_tau
        self.min_pool_size = min_pool_size
        self.replay_pool_size = replay_pool_size
        self.replacement_prob = replacement_prob
        self.qf_batch_size = qf_batch_size
        self.qf_weight_decay = qf_weight_decay
        self.qf_update_method = \
            FirstOrderOptimizer(
                update_method=qf_update_method,
                learning_rate=qf_learning_rate,
            )
        self.qf_learning_rate = qf_learning_rate
        self.qf_use_target = qf_use_target

        self.qf_mc_ratio = qf_mc_ratio
        if self.qf_mc_ratio > 0:
            self.mc_y_averages = []

        self.qf_residual_phi = qf_residual_phi
        if self.qf_residual_phi > 0:
            self.residual_y_averages = []
            self.qf_residual_loss_averages = []

        self.qf_loss_averages = []
        self.q_averages = []
        self.y_averages = []

    def log_critic_training(self):
        if self.qf is None: return
        if len(self.q_averages) == 0: return
        all_qs = np.concatenate(self.q_averages)
        all_ys = np.concatenate(self.y_averages)
        average_q_loss = np.mean(self.qf_loss_averages)
        logger.record_tabular('AverageQLoss', average_q_loss)
        logger.record_tabular('AverageQ', np.mean(all_qs))
        logger.record_tabular('AverageAbsQ', np.mean(np.abs(all_qs)))
        logger.record_tabular('AverageY', np.mean(all_ys))
        logger.record_tabular('AverageAbsY', np.mean(np.abs(all_ys)))
        logger.record_tabular('AverageAbsQYDiff',
                              np.mean(np.abs(all_qs - all_ys)))
        if self.qf_mc_ratio > 0:
            all_mc_ys = np.concatenate(self.mc_y_averages)
            logger.record_tabular('AverageMcY', np.mean(all_mc_ys))
            logger.record_tabular('AverageAbsMcY', np.mean(np.abs(all_mc_ys)))
            self.mc_y_averages = []
        if self.qf_residual_phi > 0:
            all_residual_ys = np.concatenate(self.residual_y_averages)
            average_q_residual_loss = np.mean(self.qf_residual_loss_averages)
            logger.record_tabular('AverageResQLoss', average_q_residual_loss)
            logger.record_tabular('AverageResY', np.mean(all_residual_ys))
            logger.record_tabular('AverageAbsResY', np.mean(np.abs(all_residual_ys)))
            logger.record_tabular('AverageAbsResQYDiff',
                              np.mean(np.abs(all_qs - all_residual_ys)))
            self.residual_y_averages = []
            self.qf_residual_loss_averages = []
        self.qf_loss_averages = []
        self.q_averages = []
        self.y_averages = []

    def init_opt_critic(self):
        if self.qf is None: return

        if self.qf_use_target:
            logger.log("[init_opt] using target qf.")
            target_qf = Serializable.clone(self.qf, name="target_qf")
        else:
            logger.log("[init_opt] no target qf.")
            target_qf = self.qf

        obs = self.qf.env_spec.observation_space.new_tensor_variable(
            'qf_obs',
            extra_dims=1,
        )
        action = self.qf.env_spec.action_space.new_tensor_variable(
            'qf_action',
            extra_dims=1,
        )
        yvar = tf.placeholder(dtype=tf.float32, shape=[None], name='ys')

        qf_weight_decay_term = 0.5 * self.qf_weight_decay * \
                               sum([tf.reduce_sum(tf.square(param)) for param in
                                    self.qf.get_params(regularizable=True)])

        qval = self.qf.get_qval_sym(obs, action)

        qf_loss = tf.reduce_mean(tf.square(yvar - qval))
        qf_input_list = [yvar, obs, action]
        qf_output_list = [qf_loss, qval]

        # set up residual gradient method
        if self.qf_residual_phi > 0:
            next_obs = self.qf.env_spec.observation_space.new_tensor_variable(
                'qf_next_obs',
                extra_dims=1,
            )
            rvar = tf.placeholder(dtype=tf.float32, shape=[None], name='rs')
            terminals = tf.placeholder(dtype=tf.float32, shape=[None], name='terminals')
            discount = tf.placeholder(dtype=tf.float32, shape=(), name='discount')
            qf_loss *= (1. - self.qf_residual_phi)
            next_qval = self.qf.get_e_qval_sym(next_obs, self.policy)
            residual_ys = rvar + (1.-terminals)*discount*next_qval
            qf_residual_loss = tf.reduce_mean(tf.square(residual_ys-qval))
            qf_loss += self.qf_residual_phi * qf_residual_loss
            qf_input_list += [next_obs, rvar, terminals, discount]
            qf_output_list += [qf_residual_loss, residual_ys]

        # set up monte carlo Q fitting method
        if self.qf_mc_ratio > 0:
            mc_obs = self.qf.env_spec.observation_space.new_tensor_variable(
                'qf_mc_obs',
                extra_dims=1,
            )
            mc_action = self.qf.env_spec.action_space.new_tensor_variable(
                'qf_mc_action',
                extra_dims=1,
            )
            mc_yvar = tf.placeholder(dtype=tf.float32, shape=[None], name='mc_ys')
            mc_qval = self.qf.get_qval_sym(mc_obs, mc_action)
            qf_mc_loss = tf.reduce_mean(tf.square(mc_yvar - mc_qval))
            qf_loss = (1.-self.qf_mc_ratio)*qf_loss + self.qf_mc_ratio*qf_mc_loss
            qf_input_list += [mc_yvar, mc_obs, mc_action]

        qf_reg_loss = qf_loss + qf_weight_decay_term
        self.qf_update_method.update_opt(
            loss=qf_reg_loss, target=self.qf, inputs=qf_input_list)
        qf_output_list += [self.qf_update_method._train_op]

        f_train_qf = tensor_utils.compile_function(
            inputs=qf_input_list,
            outputs=qf_output_list,
        )

        self.opt_info_critic = dict(
            f_train_qf=f_train_qf,
            target_qf=target_qf,
        )

    def do_critic_training(self, itr, batch, samples_data=None):

        obs, actions, rewards, next_obs, terminals = ext.extract(
            batch,
            "observations", "actions", "rewards", "next_observations",
            "terminals"
        )

        target_qf = self.opt_info_critic["target_qf"]
        if self.qf_dqn:
            next_qvals = target_qf.get_max_qval(next_obs)
        else:
            target_policy = self.opt_info["target_policy"]
            next_qvals = target_qf.get_e_qval(next_obs, target_policy)

        ys = rewards + (1. - terminals) * self.discount * next_qvals
        inputs = (ys, obs, actions)

        if self.qf_residual_phi:
            inputs += (next_obs, rewards, terminals, self.discount)
        if self.qf_mc_ratio > 0:
            mc_inputs = ext.extract(
                samples_data,
                "qvalues", "observations", "actions"
            )
            inputs += mc_inputs
            self.mc_y_averages.append(mc_inputs[0])

        qf_outputs = self.opt_info_critic['f_train_qf'](*inputs)
        qf_loss = qf_outputs.pop(0)
        qval = qf_outputs.pop(0)
        if self.qf_residual_phi:
            qf_residual_loss = qf_outputs.pop(0)
            residual_ys = qf_outputs.pop(0)
            self.qf_residual_loss_averages.append(qf_residual_loss)
            self.residual_y_averages.append(residual_ys)

        if self.qf_use_target:
            target_qf.set_param_values(
                target_qf.get_param_values() * (1.0 - self.soft_target_tau) +
                self.qf.get_param_values() * self.soft_target_tau)

        self.qf_loss_averages.append(qf_loss)
        self.q_averages.append(qval)
        self.y_averages.append(ys)
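do_critic_training above regresses the critic onto the one-step TD target y = r + (1 - terminal) * discount * Q_target(s', pi(s')), optionally mixing in a residual-gradient term (weighted by qf_residual_phi) and a Monte-Carlo regression term (weighted by qf_mc_ratio), exactly as assembled in init_opt_critic. A tiny NumPy sketch of the target computation, mirroring the line ys = rewards + (1. - terminals) * self.discount * next_qvals:

import numpy as np

# Illustrative values only; terminals == 1.0 marks the end of an episode.
rewards = np.array([1.0, 0.5, 0.0])
terminals = np.array([0.0, 0.0, 1.0])
next_qvals = np.array([2.0, 1.0, 3.0])
discount = 0.99

ys = rewards + (1.0 - terminals) * discount * next_qvals
# -> [2.98, 1.49, 0.0]: terminal transitions do not bootstrap from the next state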
Example #13
class DDPG(RLAlgorithm):
    """
    Deep Deterministic Policy Gradient.
    """
    def __init__(self,
                 env,
                 policy,
                 oracle_policy,
                 qf,
                 gate_qf,
                 agent_strategy,
                 batch_size=32,
                 n_epochs=200,
                 epoch_length=1000,
                 min_pool_size=10000,
                 replay_pool_size=1000000,
                 replacement_prob=1.0,
                 discount=0.99,
                 max_path_length=250,
                 qf_weight_decay=0.,
                 qf_update_method='adam',
                 qf_learning_rate=1e-3,
                 policy_weight_decay=0,
                 policy_update_method='adam',
                 policy_learning_rate=1e-3,
                 policy_updates_ratio=1.0,
                 eval_samples=10000,
                 soft_target=True,
                 soft_target_tau=0.001,
                 n_updates_per_sample=1,
                 scale_reward=1.0,
                 include_horizon_terminal_transitions=False,
                 plot=False,
                 pause_for_plot=False):
        """
        :param env: Environment
        :param policy: Policy
        :param qf: Q function
        :param agent_strategy: Exploration strategy
        :param batch_size: Number of samples for each minibatch.
        :param n_epochs: Number of epochs. Policy will be evaluated after each epoch.
        :param epoch_length: How many timesteps for each epoch.
        :param min_pool_size: Minimum size of the pool to start training.
        :param replay_pool_size: Size of the experience replay pool.
        :param discount: Discount factor for the cumulative return.
        :param max_path_length: Maximum length of each sampled rollout.
        :param qf_weight_decay: Weight decay factor for parameters of the Q function.
        :param qf_update_method: Online optimization method for training Q function.
        :param qf_learning_rate: Learning rate for training Q function.
        :param policy_weight_decay: Weight decay factor for parameters of the policy.
        :param policy_update_method: Online optimization method for training the policy.
        :param policy_learning_rate: Learning rate for training the policy.
        :param eval_samples: Number of samples (timesteps) for evaluating the policy.
        :param soft_target_tau: Interpolation parameter for doing the soft target update.
        :param n_updates_per_sample: Number of Q function and policy updates per new sample obtained
        :param scale_reward: The scaling factor applied to the rewards when training
        :param include_horizon_terminal_transitions: whether to include transitions with terminal=True because the
        horizon was reached. This might make the Q-value backup less stable for certain tasks.
        :param plot: Whether to visualize the policy performance after each eval_interval.
        :param pause_for_plot: Whether to pause before continuing when plotting.
        :return:
        """
        self.env = env
        self.policy = policy
        self.oracle_policy = oracle_policy
        self.qf = qf
        self.discrete_qf = gate_qf
        self.gate_qf = gate_qf
        self.agent_strategy = agent_strategy
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.epoch_length = epoch_length
        self.min_pool_size = min_pool_size
        self.replay_pool_size = replay_pool_size
        self.replacement_prob = replacement_prob
        self.discount = discount
        self.max_path_length = max_path_length
        self.qf_weight_decay = qf_weight_decay

        self.qf_update_method = \
            FirstOrderOptimizer(
                update_method=qf_update_method,
                learning_rate=qf_learning_rate,
            )

        self.gate_qf_update_method = \
            FirstOrderOptimizer(
                update_method=qf_update_method,
                learning_rate=qf_learning_rate,
            )

        self.qf_learning_rate = qf_learning_rate
        self.policy_weight_decay = policy_weight_decay


        self.policy_update_method = \
            FirstOrderOptimizer(
                update_method=policy_update_method,
                learning_rate=policy_learning_rate,
            )

        self.policy_gate_update_method = \
            FirstOrderOptimizer(
                update_method=policy_update_method,
                learning_rate=policy_learning_rate,
            )


        self.gating_func_update_method = \
            FirstOrderOptimizer(
                update_method='adam',
                learning_rate=policy_learning_rate,
            )

        self.policy_learning_rate = policy_learning_rate
        self.policy_updates_ratio = policy_updates_ratio
        self.eval_samples = eval_samples
        self.soft_target_tau = soft_target_tau
        self.n_updates_per_sample = n_updates_per_sample
        self.include_horizon_terminal_transitions = include_horizon_terminal_transitions
        self.plot = plot
        self.pause_for_plot = pause_for_plot

        self.qf_loss_averages = []
        self.policy_surr_averages = []
        self.q_averages = []
        self.y_averages = []
        self.paths = []
        self.es_path_returns = []
        self.paths_samples_cnt = 0

        self.scale_reward = scale_reward

        self.train_policy_itr = 0
        self.train_gate_policy_itr = 0

        self.opt_info = None

    def start_worker(self):
        parallel_sampler.populate_task(self.env, self.policy)
        if self.plot:
            plotter.init_plot(self.env, self.policy)

    @overrides
    def train(self, e, environment_name, penalty):
        with tf.Session() as sess:

            self.initialize_uninitialized(sess)

            # This seems like a rather sequential method
            pool = SimpleReplayPool(
                max_pool_size=self.replay_pool_size,
                observation_dim=self.env.observation_space.flat_dim,
                action_dim=self.env.action_space.flat_dim,
                replacement_prob=self.replacement_prob,
            )

            binary_pool = SimpleReplayPool(
                max_pool_size=self.replay_pool_size,
                observation_dim=self.env.observation_space.flat_dim,
                action_dim=2,
                replacement_prob=self.replacement_prob,
            )

            self.start_worker()
            self.init_opt()

            num_experiment = e

            self.initialize_uninitialized(sess)
            itr = 0
            path_length = 0
            path_return = 0
            terminal = False
            initial = False

            ### assigning query cost here
            query_cost = 0.9

            observation = self.env.reset()

            with tf.variable_scope("sample_policy"):
                sample_policy = Serializable.clone(self.policy)

            with tf.variable_scope("sample_target_gate_qf"):
                target_gate_qf = Serializable.clone(self.gate_qf)

            oracle_policy = self.oracle_policy

            oracle_interaction = 0
            agent_interaction = 0
            agent_interaction_per_episode = np.zeros(shape=(self.n_epochs))
            oracle_interaction_per_episode = np.zeros(shape=(self.n_epochs))

            for epoch in range(self.n_epochs):
                logger.push_prefix('epoch #%d | ' % epoch)
                logger.log("Training started")
                train_qf_itr, train_policy_itr = 0, 0

                for epoch_itr in pyprind.prog_bar(range(self.epoch_length)):
                    # Execute policy
                    if terminal:  # or path_length > self.max_path_length:
                        # Note that if the last time step ends an episode, the very
                        # last state and observation will be ignored and not added
                        # to the replay pool
                        observation = self.env.reset()
                        self.agent_strategy.reset()
                        sample_policy.reset()
                        self.es_path_returns.append(path_return)
                        path_length = 0
                        path_return = 0
                        initial = True
                    else:
                        initial = False

                    ## softmax binary output here from Beta(s)
                    agent_action, binary_action = self.agent_strategy.get_action_with_binary(
                        itr, observation, policy=sample_policy)  # qf=qf)

                    sigma = np.round(binary_action)
                    oracle_action = self.get_oracle_action(
                        itr, observation, policy=oracle_policy)

                    action = sigma[0] * agent_action + sigma[1] * oracle_action

                    next_observation, reward, terminal, _ = self.env.step(
                        action)

                    ## sigma[1] for oracle interaction
                    if sigma[1] == 1.0:
                        oracle_interaction += 1
                        if penalty == True:
                            reward = reward - query_cost

                    ## for no oracle interaction
                    elif sigma[0] == 1.0:
                        agent_interaction += 1

                    path_length += 1
                    path_return += reward
                    """
                    CHECK THIS - To do here
                    Discrete binary actions to be added to the replay buffer
                    Not the binary action probabilities
                    """
                    binary_action = sigma

                    if not terminal and path_length >= self.max_path_length:
                        terminal = True
                        if self.include_horizon_terminal_transitions:
                            pool.add_sample(observation, action,
                                            reward * self.scale_reward,
                                            terminal, initial)
                            binary_pool.add_sample(observation, binary_action,
                                                   reward * self.scale_reward,
                                                   terminal, initial)

                    else:
                        pool.add_sample(observation, action,
                                        reward * self.scale_reward, terminal,
                                        initial)
                        binary_pool.add_sample(observation, binary_action,
                                               reward * self.scale_reward,
                                               terminal, initial)

                    observation = next_observation

                    if pool.size >= self.min_pool_size:
                        for update_itr in range(self.n_updates_per_sample):
                            # Train policy
                            # batches from pool containing continuous actions and discrete actions
                            batch = pool.random_batch(self.batch_size)
                            binary_batch = binary_pool.random_batch(
                                self.batch_size)

                            itrs = self.do_training(itr, batch, binary_batch)
                            train_qf_itr += itrs[0]
                            train_policy_itr += itrs[1]
                        sample_policy.set_param_values(
                            self.policy.get_param_values())

                    itr += 1

                agent_interaction_per_episode[epoch] = agent_interaction
                oracle_interaction_per_episode[epoch] = oracle_interaction
                np.save(
                    '/Users/Riashat/Documents/PhD_Research/RLLAB/rllab/learning_active_learning/learning_ask_help/DDPG/Oracle_Interactions/oracle_interactons_'
                    + str(environment_name) + '_' + 'exp_' +
                    str(num_experiment) + '.npy',
                    oracle_interaction_per_episode)
                np.save(
                    '/Users/Riashat/Documents/PhD_Research/RLLAB/rllab/learning_active_learning/learning_ask_help/DDPG/Oracle_Interactions/agent_interactions_'
                    + str(environment_name) + '_' + 'exp_' +
                    str(num_experiment) + '.npy',
                    agent_interaction_per_episode)
                # np.save('/home/ml/rislam4/Documents/RLLAB/rllab/Active_Imitation_Learning/Imitation_Learning_RL/learning_ask_help/DDPG/Oracle_Interactions/oracle_interactons_'  + str(environment_name) +  '_' + 'exp_' + str(num_experiment) + '.npy', oracle_interaction_per_episode)
                # np.save('/home/ml/rislam4/Documents/RLLAB/rllab/Active_Imitation_Learning/Imitation_Learning_RL/learning_ask_help/DDPG/Oracle_Interactions/agent_interactions_'  + str(environment_name) +  '_' + 'exp_' + str(num_experiment) + '.npy', agent_interaction_per_episode)

                logger.record_tabular('Oracle Interactions',
                                      oracle_interaction)
                logger.record_tabular('Agent Interactions', agent_interaction)

                logger.log("Training finished")
                logger.log("Trained qf %d steps, policy %d steps" %
                           (train_qf_itr, train_policy_itr))
                # logger.log("Pool sizes agent (%d) oracle (%d)" %(agent_only_pool.size, oracle_only_pool.size))

                if pool.size >= self.min_pool_size:
                    self.evaluate(epoch, pool)
                    params = self.get_epoch_snapshot(epoch)
                    logger.save_itr_params(epoch, params)
                logger.dump_tabular(with_prefix=False)
                logger.pop_prefix()
                if self.plot:
                    self.update_plot()
                    if self.pause_for_plot:
                        input("Plotting evaluation run: Press Enter to "
                              "continue...")
            self.env.terminate()
            self.policy.terminate()

    def init_opt(self):

        with tf.variable_scope("target_policy"):
            target_policy = Serializable.clone(self.policy)

        oracle_policy = self.oracle_policy

        with tf.variable_scope("target_qf"):
            target_qf = Serializable.clone(self.qf)

        with tf.variable_scope("target_gate_qf"):
            target_gate_qf = Serializable.clone(self.gate_qf)

        obs = self.obs = self.env.observation_space.new_tensor_variable(
            'obs',
            extra_dims=1,
        )
        action = self.env.action_space.new_tensor_variable(
            'action',
            extra_dims=1,
        )

        discrete_action = tensor_utils.new_tensor(
            'discrete_action',
            ndim=2,
            dtype=tf.float32,
        )

        yvar = tensor_utils.new_tensor(
            'ys',
            ndim=1,
            dtype=tf.float32,
        )

        qf_weight_decay_term = 0.5 * self.qf_weight_decay * \
                               sum([tf.reduce_sum(tf.square(param)) for param in
                                    self.qf.get_params(regularizable=True)])

        policy_weight_decay_term = 0.5 * self.policy_weight_decay * \
                                   sum([tf.reduce_sum(tf.square(param))
                                        for param in self.policy.get_params(regularizable=True)])

        policy_qval_novice = self.qf.get_qval_sym(
            obs, self.policy.get_novice_policy_sym(obs), deterministic=True)

        policy_qval_gate = self.discrete_qf.get_qval_sym(
            obs,
            self.policy.get_action_binary_gate_sym(obs),
            deterministic=True)

        qval = self.qf.get_qval_sym(obs, action)
        qf_loss = tf.reduce_mean(tf.square(yvar - qval))
        qf_reg_loss = qf_loss + qf_weight_decay_term

        discrete_qval = self.gate_qf.get_qval_sym(obs, discrete_action)
        discrete_qf_loss = tf.reduce_mean(tf.square(yvar - discrete_qval))
        discrete_qf_reg_loss = discrete_qf_loss + qf_weight_decay_term

        qf_input_list = [yvar, obs, action]
        discrete_qf_input_list = [yvar, obs, discrete_action]

        policy_input_list = [obs]
        policy_gate_input_list = [obs]

        gating_network = self.policy.get_action_binary_gate_sym(obs)

        policy_surr = -tf.reduce_mean(policy_qval_novice)
        policy_reg_surr = policy_surr + policy_weight_decay_term

        policy_gate_surr = -tf.reduce_mean(
            policy_qval_gate) + policy_weight_decay_term
        policy_reg_gate_surr = policy_gate_surr + policy_weight_decay_term

        self.qf_update_method.update_opt(loss=qf_reg_loss,
                                         target=self.qf,
                                         inputs=qf_input_list)

        self.gate_qf_update_method.update_opt(loss=discrete_qf_reg_loss,
                                              target=self.gate_qf,
                                              inputs=discrete_qf_input_list)

        self.policy_update_method.update_opt(loss=policy_reg_surr,
                                             target=self.policy,
                                             inputs=policy_input_list)

        self.policy_gate_update_method.update_opt(
            loss=policy_reg_gate_surr,
            target=self.policy,
            inputs=policy_gate_input_list)

        f_train_qf = tensor_utils.compile_function(
            inputs=qf_input_list,
            outputs=[qf_loss, qval, self.qf_update_method._train_op],
        )

        f_train_discrete_qf = tensor_utils.compile_function(
            inputs=discrete_qf_input_list,
            outputs=[
                discrete_qf_loss, discrete_qval,
                self.gate_qf_update_method._train_op
            ],
        )

        f_train_policy = tensor_utils.compile_function(
            inputs=policy_input_list,
            outputs=[policy_surr, self.policy_update_method._train_op],
        )

        f_train_policy_gate = tensor_utils.compile_function(
            inputs=policy_gate_input_list,
            outputs=[
                policy_gate_surr, self.policy_gate_update_method._train_op,
                gating_network
            ],
        )

        self.opt_info = dict(
            f_train_qf=f_train_qf,
            f_train_discrete_qf=f_train_discrete_qf,
            f_train_policy=f_train_policy,
            f_train_policy_gate=f_train_policy_gate,
            target_qf=target_qf,
            target_gate_qf=target_gate_qf,
            target_policy=target_policy,
            oracle_policy=oracle_policy,
        )

    def do_training(self, itr, batch, binary_batch):

        obs, actions, rewards, next_obs, terminals = ext.extract(
            batch, "observations", "actions", "rewards", "next_observations",
            "terminals")

        binary_obs, binary_actions, binary_rewards, binary_next_obs, binary_terminals = ext.extract(
            binary_batch, "observations", "actions", "rewards",
            "next_observations", "terminals")

        target_qf = self.opt_info["target_qf"]
        target_gate_qf = self.opt_info["target_gate_qf"]
        target_policy = self.opt_info["target_policy"]

        ## training critic for pi(s)
        next_actions, _ = target_policy.get_actions(next_obs)
        next_qvals = target_qf.get_qval(next_obs, next_actions)
        ys = rewards + (1. -
                        terminals) * self.discount * next_qvals.reshape(-1)

        ## training the critic
        f_train_qf = self.opt_info["f_train_qf"]
        qf_loss, qval, _ = f_train_qf(ys, obs, actions)
        target_qf.set_param_values(target_qf.get_param_values() *
                                   (1.0 - self.soft_target_tau) +
                                   self.qf.get_param_values() *
                                   self.soft_target_tau)

        self.qf_loss_averages.append(qf_loss)
        self.q_averages.append(qval)
        self.y_averages.append(ys)

        ## for training the actor for pi(s)
        self.train_policy_itr += self.policy_updates_ratio
        train_policy_itr = 0

        while self.train_policy_itr > 0:
            f_train_policy = self.opt_info["f_train_policy"]
            policy_surr, _ = f_train_policy(obs)
            target_policy.set_param_values(target_policy.get_param_values() *
                                           (1.0 - self.soft_target_tau) +
                                           self.policy.get_param_values() *
                                           self.soft_target_tau)
            self.policy_surr_averages.append(policy_surr)
            self.train_policy_itr -= 1
            train_policy_itr += 1
        """
        Training the gate function with Q-learning here
        """
        # next_binary_actions, _ = target_policy.get_binary_actions(next_obs)

        next_max_qvals = target_gate_qf.get_max_qval(next_obs)
        ys_discrete_qf = binary_rewards + (
            1. - terminals) * self.discount * next_max_qvals.reshape(-1)

        f_train_discrete_qf = self.opt_info["f_train_discrete_qf"]
        qf_loss, qval, _ = f_train_discrete_qf(ys_discrete_qf, binary_obs,
                                               binary_actions)

        ## for training the actor with Q-learning critic
        self.train_gate_policy_itr += self.policy_updates_ratio
        train_gate_policy_itr = 0

        while self.train_gate_policy_itr > 0:
            f_train_policy_gate = self.opt_info["f_train_policy_gate"]
            policy_surr, _, gating_outputs = f_train_policy_gate(obs)
            # target_policy.set_param_values(
            #     target_policy.get_param_values() * (1.0 - self.soft_target_tau) +
            #     self.policy.get_param_values() * self.soft_target_tau)
            self.policy_surr_averages.append(policy_surr)
            self.train_gate_policy_itr -= 1
            train_gate_policy_itr += 1

        return 1, train_policy_itr  # number of itrs qf, policy are trained

    #evaluation of the learnt policy
    def evaluate(self, epoch, pool):
        logger.log("Collecting samples for evaluation")
        paths = parallel_sampler.sample_paths(
            policy_params=self.policy.get_param_values(),
            max_samples=self.eval_samples,
            max_path_length=self.max_path_length,
        )

        average_discounted_return = np.mean([
            special.discount_return(path["rewards"], self.discount)
            for path in paths
        ])

        returns = [sum(path["rewards"]) for path in paths]

        all_qs = np.concatenate(self.q_averages)
        all_ys = np.concatenate(self.y_averages)

        average_q_loss = np.mean(self.qf_loss_averages)
        average_policy_surr = np.mean(self.policy_surr_averages)
        average_action = np.mean(
            np.square(np.concatenate([path["actions"] for path in paths])))

        policy_reg_param_norm = np.linalg.norm(
            self.policy.get_param_values(regularizable=True))
        qfun_reg_param_norm = np.linalg.norm(
            self.qf.get_param_values(regularizable=True))

        logger.record_tabular('Epoch', epoch)
        logger.record_tabular('Iteration', epoch)
        logger.record_tabular('AverageReturn', np.mean(returns))
        logger.record_tabular('StdReturn', np.std(returns))
        logger.record_tabular('MaxReturn', np.max(returns))
        logger.record_tabular('MinReturn', np.min(returns))
        if len(self.es_path_returns) > 0:
            logger.record_tabular('AverageEsReturn',
                                  np.mean(self.es_path_returns))
            logger.record_tabular('StdEsReturn', np.std(self.es_path_returns))
            logger.record_tabular('MaxEsReturn', np.max(self.es_path_returns))
            logger.record_tabular('MinEsReturn', np.min(self.es_path_returns))
        logger.record_tabular('AverageDiscountedReturn',
                              average_discounted_return)
        logger.record_tabular('AverageQLoss', average_q_loss)
        logger.record_tabular('AveragePolicySurr', average_policy_surr)
        logger.record_tabular('AverageQ', np.mean(all_qs))
        logger.record_tabular('AverageAbsQ', np.mean(np.abs(all_qs)))
        logger.record_tabular('AverageY', np.mean(all_ys))
        logger.record_tabular('AverageAbsY', np.mean(np.abs(all_ys)))
        logger.record_tabular('AverageAbsQYDiff',
                              np.mean(np.abs(all_qs - all_ys)))
        logger.record_tabular('AverageAction', average_action)

        logger.record_tabular('PolicyRegParamNorm', policy_reg_param_norm)
        logger.record_tabular('QFunRegParamNorm', qfun_reg_param_norm)

        self.env.log_diagnostics(paths)
        self.policy.log_diagnostics(paths)

        self.qf_loss_averages = []
        self.policy_surr_averages = []

        self.q_averages = []
        self.y_averages = []
        self.es_path_returns = []

    def update_plot(self):
        if self.plot:
            plotter.update_plot(self.policy, self.max_path_length)

    def get_epoch_snapshot(self, epoch):
        return dict(
            env=self.env,
            epoch=epoch,
            qf=self.qf,
            policy=self.policy,
            target_qf=self.opt_info["target_qf"],
            target_policy=self.opt_info["target_policy"],
            es=self.agent_strategy,
        )

    def get_oracle_action(self, t, observation, policy, **kwargs):
        action, _ = policy.get_action(observation)
        # ou_state = self.evolve_state()
        return np.clip(action, self.env.action_space.low,
                       self.env.action_space.high)

    ### initialise any variables in the TF graph
    ### that have not been initialised so far
    def initialize_uninitialized(self, sess):
        global_vars = tf.global_variables()
        is_not_initialized = sess.run(
            [tf.is_variable_initialized(var) for var in global_vars])
        not_initialized_vars = [
            v for (v, f) in zip(global_vars, is_not_initialized) if not f
        ]

        print([str(i.name) for i in not_initialized_vars])  # only for testing
        if len(not_initialized_vars):
            sess.run(tf.variables_initializer(not_initialized_vars))
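Both target networks above are moved with the usual soft (Polyak) update, theta_target <- (1 - tau) * theta_target + tau * theta, implemented through get_param_values / set_param_values. A small NumPy sketch of that rule:

import numpy as np

# Illustration of the soft target update in do_training; parameters are flat vectors here.
tau = 0.001
target_params = np.array([0.0, 1.0, 2.0])
live_params = np.array([1.0, 1.0, 0.0])

target_params = target_params * (1.0 - tau) + live_params * tau
# The target parameters drift slowly toward the live network's parameters.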
Example #14
class BatchPolopt(RLAlgorithm):
    """
    Base class for batch sampling-based policy optimization methods.
    This includes various policy gradient methods like vpg, npg, ppo, trpo, etc.
    """
    def __init__(
            self,
            env,
            policy,
            baseline,
            scope=None,
            n_itr=500,
            start_itr=0,
            batch_size=5000,
            max_path_length=500,
            discount=0.99,
            gae_lambda=1,
            plot=False,
            pause_for_plot=False,
            center_adv=True,
            positive_adv=False,
            store_paths=False,
            whole_paths=True,
            fixed_horizon=False,
            sampler_cls=None,
            sampler_args=None,
            force_batch_sampler=False,
            # qprop params
            qf=None,
            min_pool_size=10000,
            replay_pool_size=1000000,
            replacement_prob=1.0,
            qf_updates_ratio=1,
            qprop_use_mean_action=True,
            qprop_min_itr=0,
            qprop_batch_size=None,
            qprop_use_advantage=True,
            qprop_use_qf_baseline=False,
            qprop_eta_option='ones',
            qf_weight_decay=0.,
            qf_update_method='adam',
            qf_learning_rate=1e-3,
            qf_batch_size=32,
            qf_baseline=None,
            soft_target=True,
            soft_target_tau=0.001,
            scale_reward=1.0,
            **kwargs):
        """
        :param env: Environment
        :param policy: Policy
        :type policy: Policy
        :param baseline: Baseline
        :param scope: Scope for identifying the algorithm. Must be specified if running multiple algorithms
        simultaneously, each using different environments and policies
        :param n_itr: Number of iterations.
        :param start_itr: Starting iteration.
        :param batch_size: Number of samples per iteration.
        :param max_path_length: Maximum length of a single rollout.
        :param discount: Discount.
        :param gae_lambda: Lambda used for generalized advantage estimation.
        :param plot: Plot evaluation run after each iteration.
        :param pause_for_plot: Whether to pause before continuing when plotting.
        :param center_adv: Whether to rescale the advantages so that they have mean 0 and standard deviation 1.
        :param positive_adv: Whether to shift the advantages so that they are always positive. When used in
        conjunction with center_adv the advantages will be standardized before shifting.
        :param store_paths: Whether to save all paths data to the snapshot.
        :param qf: q function for q-prop.
        :return:
        """
        self.env = env
        self.policy = policy
        self.baseline = baseline
        self.scope = scope
        self.n_itr = n_itr
        self.start_itr = start_itr
        self.batch_size = batch_size
        self.max_path_length = max_path_length
        self.discount = discount
        self.gae_lambda = gae_lambda
        self.plot = plot
        self.pause_for_plot = pause_for_plot
        self.center_adv = center_adv
        self.positive_adv = positive_adv
        self.store_paths = store_paths
        self.whole_paths = whole_paths
        self.fixed_horizon = fixed_horizon
        self.qf = qf
        if self.qf is not None:
            self.qprop = True
            self.qprop_optimizer = Serializable.clone(self.optimizer)
            self.min_pool_size = min_pool_size
            self.replay_pool_size = replay_pool_size
            self.replacement_prob = replacement_prob
            self.qf_updates_ratio = qf_updates_ratio
            self.qprop_use_mean_action = qprop_use_mean_action
            self.qprop_min_itr = qprop_min_itr
            self.qprop_use_qf_baseline = qprop_use_qf_baseline
            self.qf_weight_decay = qf_weight_decay
            self.qf_update_method = \
                FirstOrderOptimizer(
                    update_method=qf_update_method,
                    learning_rate=qf_learning_rate,
                )
            self.qf_learning_rate = qf_learning_rate
            self.qf_batch_size = qf_batch_size
            self.qf_baseline = qf_baseline
            if qprop_batch_size is None:
                self.qprop_batch_size = self.batch_size
            else:
                self.qprop_batch_size = qprop_batch_size
            self.qprop_use_advantage = qprop_use_advantage
            self.qprop_eta_option = qprop_eta_option
            self.soft_target_tau = soft_target_tau
            self.scale_reward = scale_reward

            self.qf_loss_averages = []
            self.q_averages = []
            self.y_averages = []
            if self.start_itr >= self.qprop_min_itr:
                self.batch_size = self.qprop_batch_size
                if self.qprop_use_qf_baseline:
                    self.baseline = self.qf_baseline
                self.qprop_enable = True
            else:
                self.qprop_enable = False
        else:
            self.qprop = False
        if sampler_cls is None:
            if self.policy.vectorized and not force_batch_sampler:
                sampler_cls = VectorizedSampler
            else:
                sampler_cls = BatchSampler
        if sampler_args is None:
            sampler_args = dict()

        self.sampler = sampler_cls(self, **sampler_args)

    def start_worker(self):
        self.sampler.start_worker()
        if self.plot:
            plotter.init_plot(self.env, self.policy)

    def shutdown_worker(self):
        self.sampler.shutdown_worker()

    def obtain_samples(self, itr):
        return self.sampler.obtain_samples(itr)

    def process_samples(self, itr, paths):
        return self.sampler.process_samples(itr, paths)

    def train(self):
        with tf.Session() as sess:
            sess.run(tf.initialize_all_variables())
            if self.qprop:
                pool = SimpleReplayPool(
                    max_pool_size=self.replay_pool_size,
                    observation_dim=self.env.observation_space.flat_dim,
                    action_dim=self.env.action_space.flat_dim,
                    replacement_prob=self.replacement_prob,
                )
            self.start_worker()
            self.init_opt()
            # This initializes the optimizer parameters
            sess.run(tf.initialize_all_variables())
            start_time = time.time()
            for itr in range(self.start_itr, self.n_itr):
                itr_start_time = time.time()
                with logger.prefix('itr #%d | ' % itr):
                    if self.qprop and not self.qprop_enable and \
                            itr >= self.qprop_min_itr:
                        logger.log(
                            "Restarting workers with batch size %d->%d..." %
                            (self.batch_size, self.qprop_batch_size))
                        self.shutdown_worker()
                        self.batch_size = self.qprop_batch_size
                        self.start_worker()
                        if self.qprop_use_qf_baseline:
                            self.baseline = self.qf_baseline
                        self.qprop_enable = True
                    logger.log("Obtaining samples...")
                    paths = self.obtain_samples(itr)
                    logger.log("Processing samples...")
                    samples_data = self.process_samples(itr, paths)
                    logger.log("Logging diagnostics...")
                    self.log_diagnostics(paths)
                    if self.qprop:
                        logger.log("Adding samples to replay pool...")
                        self.add_pool(itr, paths, pool)
                        logger.log("Optimizing critic before policy...")
                        self.optimize_critic(itr, pool)
                    logger.log("Optimizing policy...")
                    self.optimize_policy(itr, samples_data)
                    params = self.get_itr_snapshot(itr,
                                                   samples_data)  # , **kwargs)
                    if self.store_paths:
                        params["paths"] = samples_data["paths"]
                    logger.save_itr_params(itr, params)
                    logger.log("Saved")
                    logger.record_tabular('Time', time.time() - start_time)
                    logger.record_tabular('ItrTime',
                                          time.time() - itr_start_time)
                    logger.dump_tabular(with_prefix=False)
                    if self.plot:
                        self.update_plot()
                        if self.pause_for_plot:
                            input("Plotting evaluation run: Press Enter to "
                                  "continue...")
        self.shutdown_worker()

    def log_diagnostics(self, paths):
        self.env.log_diagnostics(paths)
        self.policy.log_diagnostics(paths)
        self.baseline.log_diagnostics(paths)

    def init_opt(self):
        """
        Initialize the optimization procedure. If using tensorflow, this may
        include declaring all the variables and compiling functions
        """
        raise NotImplementedError

    def init_opt_critic(self, vars_info, qbaseline_info):
        assert (not self.policy.recurrent)

        # Compute Taylor expansion Q function
        delta = vars_info["action_var"] - qbaseline_info["action_mu"]
        control_variate = tf.reduce_sum(delta * qbaseline_info["qprime"], 1)
        if not self.qprop_use_advantage:
            control_variate += qbaseline_info["qvalue"]
            logger.log("Qprop, using Q-value over A-value")
        f_control_variate = tensor_utils.compile_function(
            inputs=[vars_info["obs_var"], vars_info["action_var"]],
            outputs=[control_variate, qbaseline_info["qprime"]],
        )

        target_qf = Serializable.clone(self.qf, name="target_qf")

        # y need to be computed first
        obs = self.env.observation_space.new_tensor_variable(
            'obs',
            extra_dims=1,
        )

        # The yi values are computed separately as above and then passed to
        # the training functions below
        action = self.env.action_space.new_tensor_variable(
            'action',
            extra_dims=1,
        )
        yvar = tf.placeholder(dtype=tf.float32, shape=[None], name='ys')

        qf_weight_decay_term = 0.5 * self.qf_weight_decay * \
                               sum([tf.reduce_sum(tf.square(param)) for param in
                                    self.qf.get_params(regularizable=True)])

        qval = self.qf.get_qval_sym(obs, action)

        qf_loss = tf.reduce_mean(tf.square(yvar - qval))
        qf_reg_loss = qf_loss + qf_weight_decay_term

        qf_input_list = [yvar, obs, action]

        self.qf_update_method.update_opt(loss=qf_reg_loss,
                                         target=self.qf,
                                         inputs=qf_input_list)

        f_train_qf = tensor_utils.compile_function(
            inputs=qf_input_list,
            outputs=[qf_loss, qval, self.qf_update_method._train_op],
        )

        self.opt_info_critic = dict(
            f_train_qf=f_train_qf,
            target_qf=target_qf,
            f_control_variate=f_control_variate,
        )

    def get_control_variate(self, observations, actions):
        control_variate, qprime = self.opt_info_critic["f_control_variate"](
            observations, actions)
        return control_variate

    def get_itr_snapshot(self, itr, samples_data):
        """
        Returns all the data that should be saved in the snapshot for this
        iteration.
        """
        raise NotImplementedError

    def optimize_policy(self, itr, samples_data):
        raise NotImplementedError

    def add_pool(self, itr, paths, pool):
        # Add samples to replay pool
        path_lens = []
        for path in paths:
            path_len = path["observations"].shape[0]
            for i in range(path_len):
                observation = path["observations"][i]
                action = path["actions"][i]
                reward = path["rewards"][i]
                terminal = path["terminals"][i]
                initial = i == 0
                pool.add_sample(observation, action,
                                reward * self.scale_reward, terminal, initial)
            path_lens.append(path_len)
        path_lens = np.array(path_lens)
        logger.log(
            "PathsInfo epsN=%d, meanL=%.2f, maxL=%d, minL=%d" %
            (len(paths), path_lens.mean(), path_lens.max(), path_lens.min()))
        logger.log("Put %d transitions to replay, size=%d" %
                   (path_lens.sum(), pool.size))

    def optimize_critic(self, itr, pool):
        # Train the critic
        if pool.size >= self.min_pool_size:
            #qf_itrs = float(self.batch_size)/self.qf_batch_size*self.qf_updates_ratio
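            # Number of critic fitting steps scales with the on-policy batch size:
            # qf_itrs = ceil(batch_size * qf_updates_ratio)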
            qf_itrs = float(self.batch_size) * self.qf_updates_ratio
            qf_itrs = int(np.ceil(qf_itrs))
            logger.log("Fitting critic for %d iterations, batch size=%d" %
                       (qf_itrs, self.qf_batch_size))
            for i in range(qf_itrs):
                # Train the critic on a random minibatch from the replay pool
                batch = pool.random_batch(self.qf_batch_size)
                self.do_training(itr, batch)

    def do_training(self, itr, batch):

        obs, actions, rewards, next_obs, terminals = ext.extract(
            batch, "observations", "actions", "rewards", "next_observations",
            "terminals")

        # compute the on-policy y values
        target_qf = self.opt_info_critic["target_qf"]

        next_actions, next_actions_dict = self.policy.get_actions(next_obs)
        if self.qprop_use_mean_action:
            next_actions = next_actions_dict["mean"]
        next_qvals = target_qf.get_qval(next_obs, next_actions)

        ys = rewards + (1. - terminals) * self.discount * next_qvals

        f_train_qf = self.opt_info_critic["f_train_qf"]

        qf_loss, qval, _ = f_train_qf(ys, obs, actions)

        target_qf.set_param_values(target_qf.get_param_values() *
                                   (1.0 - self.soft_target_tau) +
                                   self.qf.get_param_values() *
                                   self.soft_target_tau)

        self.qf_loss_averages.append(qf_loss)
        self.q_averages.append(qval)
        self.y_averages.append(ys)
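
        # Note: `ys` above is the one-step TD(0) target
        #   y = r + (1 - terminal) * gamma * Q_target(s', a'),
        # with a' drawn from the current policy (or its mean), and the target network
        # is updated by Polyak averaging:
        #   theta_target <- (1 - tau) * theta_target + tau * theta_qf.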

    def update_plot(self):
        if self.plot:
            plotter.update_plot(self.policy, self.max_path_length)
Exemplo n.º 15
0
class Poleval():
    """
    Base class defining methods for policy evaluation.
    """
    def init_critic(self,
            min_pool_size=10000,
            replay_pool_size=1000000,
            replacement_prob=1.0,
            qf_batch_size=32,
            qf_weight_decay=0.,
            qf_update_method='adam',
            qf_learning_rate=1e-3,
            qf_use_target=True,
            qf_mc_ratio = 0,
            qf_residual_phi = 0,
            soft_target_tau=0.001,
            **kwargs):
        self.soft_target_tau = soft_target_tau
        self.min_pool_size = min_pool_size
        self.replay_pool_size = replay_pool_size
        self.replacement_prob = replacement_prob
        self.qf_batch_size = qf_batch_size
        self.qf_weight_decay = qf_weight_decay
        self.qf_update_method = \
            FirstOrderOptimizer(
                update_method=qf_update_method,
                learning_rate=qf_learning_rate,
            )
        self.qf_learning_rate = qf_learning_rate
        self.qf_use_target = qf_use_target

        self.qf_mc_ratio = qf_mc_ratio
        if self.qf_mc_ratio > 0:
            self.mc_y_averages = []

        self.qf_residual_phi = qf_residual_phi
        if self.qf_residual_phi > 0:
            self.residual_y_averages = []
            self.qf_residual_loss_averages = []

        self.qf_loss_averages = []
        self.q_averages = []
        self.y_averages = []
    

    def init_phi(self, 
            pf_batch_size=32,
            pf_weight_decay=0.,
            pf_update_method='adam',
            **kwargs):
        
        self.pf_batch_size = pf_batch_size
        self.pf_weight_decay = pf_weight_decay
        self.pf_update_method = \
            FirstOrderOptimizer(
                update_method=pf_update_method,
                learning_rate=self.pf_learning_rate
            )

    def log_critic_training(self):
        if self.qf is None: return
        if len(self.q_averages) == 0: return
        all_qs = np.concatenate(self.q_averages)
        all_ys = np.concatenate(self.y_averages)
        average_q_loss = np.mean(self.qf_loss_averages)
        logger.record_tabular('AverageQLoss', average_q_loss)
        logger.record_tabular('AverageQ', np.mean(all_qs))
        logger.record_tabular('AverageAbsQ', np.mean(np.abs(all_qs)))
        logger.record_tabular('AverageY', np.mean(all_ys))
        logger.record_tabular('AverageAbsY', np.mean(np.abs(all_ys)))
        logger.record_tabular('AverageAbsQYDiff',
                              np.mean(np.abs(all_qs - all_ys)))
        if self.qf_mc_ratio > 0:
            all_mc_ys = np.concatenate(self.mc_y_averages)
            logger.record_tabular('AverageMcY', np.mean(all_mc_ys))
            logger.record_tabular('AverageAbsMcY', np.mean(np.abs(all_mc_ys)))
            self.mc_y_averages = []
        if self.qf_residual_phi > 0:
            all_residual_ys = np.concatenate(self.residual_y_averages)
            average_q_residual_loss = np.mean(self.qf_residual_loss_averages)
            logger.record_tabular('AverageResQLoss', average_q_residual_loss)
            logger.record_tabular('AverageResY', np.mean(all_residual_ys))
            logger.record_tabular('AverageAbsResY', np.mean(np.abs(all_residual_ys)))
            logger.record_tabular('AverageAbsResQYDiff',
                              np.mean(np.abs(all_qs - all_residual_ys)))
            self.residual_y_averages = []
            self.qf_residual_loss_averages = []
        self.qf_loss_averages = []
        self.q_averages = []
        self.y_averages = []
    
    # TODO: add more detailed logging of phi training
    def log_phi_training(self):
        if self.pf is None: return
        if len(self.pf_loss_averages) == 0: return
        
        average_phi_loss = np.mean(self.pf_loss_averages)
        logger.record_tabular('AveragePhiLoss', average_phi_loss)

        self.pf_loss_averages = []
        

    def init_opt_critic(self):
        if self.qf is None: logger.log("qf is None"); return

        if self.qf_use_target:
            logger.log("[init_opt] using target qf.")
            target_qf = Serializable.clone(self.qf, name="target_qf")
        else:
            logger.log("[init_opt] no target qf.")
            target_qf = self.qf

        obs = self.qf.env_spec.observation_space.new_tensor_variable(
            'qf_obs',
            extra_dims=1,
        )
        action = self.qf.env_spec.action_space.new_tensor_variable(
            'qf_action',
            extra_dims=1,
        )
        yvar = tf.placeholder(dtype=tf.float32, shape=[None], name='ys')

        qf_weight_decay_term = 0.5 * self.qf_weight_decay * \
                               sum([tf.reduce_sum(tf.square(param)) for param in
                                    self.qf.get_params(regularizable=True)])

        qval = self.qf.get_qval_sym(obs, action)

        qf_loss = tf.reduce_mean(tf.square(yvar - qval))
        qf_input_list = [yvar, obs, action]
        qf_output_list = [qf_loss, qval]

        # set up residual gradient method
        if self.qf_residual_phi > 0:
            next_obs = self.qf.env_spec.observation_space.new_tensor_variable(
                'qf_next_obs',
                extra_dims=1,
            )
            rvar = tf.placeholder(dtype=tf.float32, shape=[None], name='rs')
            terminals = tf.placeholder(dtype=tf.float32, shape=[None], name='terminals')
            discount = tf.placeholder(dtype=tf.float32, shape=(), name='discount')
            qf_loss *= (1. - self.qf_residual_phi)
            next_qval = self.qf.get_e_qval_sym(next_obs, self.policy)
            residual_ys = rvar + (1.-terminals)*discount*next_qval
            qf_residual_loss = tf.reduce_mean(tf.square(residual_ys-qval))
            qf_loss += self.qf_residual_phi * qf_residual_loss
            qf_input_list += [next_obs, rvar, terminals, discount]
            qf_output_list += [qf_residual_loss, residual_ys]

        # set up monte carlo Q fitting method
        if self.qf_mc_ratio > 0:
            mc_obs = self.qf.env_spec.observation_space.new_tensor_variable(
                'qf_mc_obs',
                extra_dims=1,
            )
            mc_action = self.qf.env_spec.action_space.new_tensor_variable(
                'qf_mc_action',
                extra_dims=1,
            )
            mc_yvar = tf.placeholder(dtype=tf.float32, shape=[None], name='mc_ys')
            mc_qval = self.qf.get_qval_sym(mc_obs, mc_action)
            qf_mc_loss = tf.reduce_mean(tf.square(mc_yvar - mc_qval))
            qf_loss = (1.-self.qf_mc_ratio)*qf_loss + self.qf_mc_ratio*qf_mc_loss
            qf_input_list += [mc_yvar, mc_obs, mc_action]
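
        # Composite critic loss assembled above:
        #   L = (1 - qf_residual_phi) * MSE(y, Q)
        #       + qf_residual_phi * MSE(r + (1 - terminal) * gamma * E[Q(s', pi)], Q)
        # and, when qf_mc_ratio > 0,
        #   L <- (1 - qf_mc_ratio) * L + qf_mc_ratio * MSE(y_mc, Q(s_mc, a_mc)),
        # with an L2 weight-decay term added below.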

        qf_reg_loss = qf_loss + qf_weight_decay_term
        self.qf_update_method.update_opt(
            loss=qf_reg_loss, target=self.qf, inputs=qf_input_list)
        qf_output_list += [self.qf_update_method._train_op]

        f_train_qf = tensor_utils.compile_function(
            inputs=qf_input_list,
            outputs=qf_output_list,
        )

        self.opt_info_critic = dict(
            f_train_qf=f_train_qf,
            target_qf=target_qf,
        )
    
    def init_opt_phi(self):
        if self.pf is None: return
        is_recurrent = int(self.policy.recurrent)
        obs = self.pf.env_spec.observation_space.new_tensor_variable(
            'phi_obs',
            extra_dims=1,
        )
        action = self.pf.env_spec.action_space.new_tensor_variable(
            'phi_action',
            extra_dims=1,
        )

        origin_advantages = tensor_utils.new_tensor(
            name="origin_adv_var",
            ndim=1 + is_recurrent,
            dtype=tf.float32,
        )

        eta_var = tensor_utils.new_tensor(
            name="eta_var",
            ndim=1 + is_recurrent,
            dtype=tf.float32,
        )

        pf_weight_decay_term = .5 * self.pf_weight_decay * \
                                sum([tf.reduce_sum(tf.square(param)) for param in
                                    self.pf.get_params(regularizable=True)])
        
        self.pf_loss_averages = []
        
        # TODO: support multiple variance-reduction choices

        if self.use_gradient_vr:
            logger.log("using gradient as variance reduction")
            phi_mse = self.pf.get_gradient_cv_sym(obs, action, 
                        origin_advantages, eta_var, self.policy)
            phi_mse = phi_mse["mu_mse"] + phi_mse["var_mse"]
        else:
            logger.log("using reward as variance reduction")
            phi_mse = self.pf.get_adv_cv_sym(obs, action, 
                        origin_advantages, eta_var, self.policy)
        
        pf_loss = tf.reduce_mean(phi_mse)
         
        pf_input_list = [obs, action, origin_advantages, eta_var]
        pf_output_list = [pf_loss]
        pf_reg_loss = pf_loss + pf_weight_decay_term
        self.pf_update_method.update_opt(loss=pf_reg_loss, 
                        target=self.pf, inputs=pf_input_list)
        
        # parameters of phi
        params = self.pf.get_params(trainable=True)
        
        for param in params:
            logger.log("parameter of phi %s, shape=%s"%(param.name, param.shape))

        pf_output_list += [self.pf_update_method._train_op]

        f_train_pf = tensor_utils.compile_function(
            inputs=pf_input_list,
            outputs=pf_output_list)
        
        self.opt_train_phi = dict(
            f_train_pf=f_train_pf
        )

        # optimal eta calculation
        
        opt_eta = self.pf.get_opt_eta_sym(obs, action, 
                origin_advantages, self.policy)
        
        self.opt_eta_phi = tensor_utils.compile_function(
            inputs=[obs, action, origin_advantages],
            outputs=opt_eta)


    
    def do_critic_training(self, itr, batch, samples_data=None):

        obs, actions, rewards, next_obs, terminals = ext.extract(
            batch,
            "observations", "actions", "rewards", "next_observations",
            "terminals"
        )

        target_qf = self.opt_info_critic["target_qf"]
        if self.qf_dqn:
            next_qvals = target_qf.get_max_qval(next_obs)
        else:
            target_policy = self.opt_info["target_policy"]
            next_qvals = target_qf.get_e_qval(next_obs, target_policy)

        ys = rewards + (1. - terminals) * self.discount * next_qvals
        inputs = (ys, obs, actions)

        if self.qf_residual_phi:
            inputs += (next_obs, rewards, terminals, self.discount)
        if self.qf_mc_ratio > 0:
            mc_inputs = ext.extract(
                samples_data,
                "qvalues", "observations", "actions"
            )
            inputs += mc_inputs
            self.mc_y_averages.append(mc_inputs[0])

        qf_outputs = self.opt_info_critic['f_train_qf'](*inputs)
        qf_loss = qf_outputs.pop(0)
        qval = qf_outputs.pop(0)
        if self.qf_residual_phi:
            qf_residual_loss = qf_outputs.pop(0)
            residual_ys = qf_outputs.pop(0)
            self.qf_residual_loss_averages.append(qf_residual_loss)
            self.residual_y_averages.append(residual_ys)

        if self.qf_use_target:
            target_qf.set_param_values(
                target_qf.get_param_values() * (1.0 - self.soft_target_tau) +
                self.qf.get_param_values() * self.soft_target_tau)

        self.qf_loss_averages.append(qf_loss)
        self.q_averages.append(qval)
        self.y_averages.append(ys)
    
    def do_phi_training(self, itr, indices=None, samples_data=None):
        
        batch_samples = samples_data
        # Mini-batch variant (currently disabled): index into samples_data, e.g.
        # dict(
        #     observations=samples_data["observations"][indices],
        #     actions=samples_data["actions"][indices],
        #     origin_advantages=samples_data["origin_advantages"][indices],
        # )
        inputs = ext.extract(
            batch_samples, 
            "observations", 
            "actions", 
            "origin_advantages", 
            "etas",)

        # FIXME: the state-info handling below is currently unused; write a better
        # version of this
        agent_infos = samples_data["agent_infos"]
        state_info_list = [agent_infos[k] for k in self.policy.state_info_keys]

        inputs += tuple(state_info_list)

        # TODO: add support for recurrent policies
        if self.policy.recurrent:
            inputs += (samples_data["valid"], )
        
        pf_outputs = self.opt_train_phi['f_train_pf'](*inputs)
        pf_loss = pf_outputs.pop(0)
        self.pf_loss_averages.append(pf_loss)
Exemplo n.º 16
0
    def init_opt(self):

        if self.pf is not None:
            optimizer = [FirstOrderOptimizer(**self.optimizer_args) for _ in range(2)]
        else:
            optimizer = FirstOrderOptimizer(**self.optimizer_args)

        self.optimizer = optimizer

        is_recurrent = int(self.policy.recurrent)

        obs_var = self.env.observation_space.new_tensor_variable(
            'obs',
            extra_dims=1 + is_recurrent,
        )
        action_var = self.env.action_space.new_tensor_variable(
            'action',
            extra_dims=1 + is_recurrent,
        )
        advantage_var = tensor_utils.new_tensor(
            name='advantage',
            ndim=1 + is_recurrent,
            dtype=tf.float32,
        )

        advantage_bar = None
        if self.phi:
            advantage_bar = tensor_utils.new_tensor(
                name='advantage_bar',
                ndim=1 + is_recurrent,
                dtype=tf.float32,
            )

        dist = self.policy.distribution

        old_dist_info_vars = {
            k: tf.placeholder(tf.float32, shape=[None] * (1 + is_recurrent) + list(shape), name='old_%s' % k)
            for k, shape in dist.dist_info_specs
            }
        old_dist_info_vars_list = [old_dist_info_vars[k] for k in dist.dist_info_keys]

        state_info_vars = {
            k: tf.placeholder(tf.float32, shape=[None] * (1 + is_recurrent) + list(shape), name=k)
            for k, shape in self.policy.state_info_specs
            }
        state_info_vars_list = [state_info_vars[k] for k in self.policy.state_info_keys]

        if is_recurrent:
            valid_var = tf.placeholder(tf.float32, shape=[None, None], name="valid")
        else:
            valid_var = None

        dist_info_vars = self.policy.dist_info_sym(obs_var, state_info_vars)
        logli = dist.log_likelihood_sym(action_var, dist_info_vars)
        kl = dist.kl_sym(old_dist_info_vars, dist_info_vars)

        # formulate as a minimization problem
        # The gradient of the surrogate objective is the policy gradient
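        # i.e. surr_obj = -E[ log pi(a|s) * A(s, a) ]; in the recurrent case the
        # expectation is taken only over valid (non-padded) timesteps.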
        if is_recurrent:
            surr_obj = - tf.reduce_sum(logli * advantage_var * valid_var) / tf.reduce_sum(valid_var)
            mean_kl = tf.reduce_sum(kl * valid_var) / tf.reduce_sum(valid_var)
            max_kl = tf.reduce_max(kl * valid_var)
        else:
            surr_obj = - tf.reduce_mean(logli * advantage_var)
            mean_kl = tf.reduce_mean(kl)
            max_kl = tf.reduce_max(kl)

        input_list = [obs_var, action_var, advantage_var] + state_info_vars_list
        if is_recurrent:
            input_list.append(valid_var)

        vars_info = {
            "mean_kl": mean_kl,
            "input_list": input_list,
            "obs_var": obs_var,
            "action_var": action_var,
            "advantage_var": advantage_var,
            "advantage_bar": advantage_bar,
            "surr_loss": surr_obj,
            "dist_info_vars": dist_info_vars,
            "lr": logli,
        }

        if self.qprop:
            eta_var = tensor_utils.new_tensor(
                'eta',
                ndim=1 + is_recurrent,
                dtype=tf.float32,
            )
            qvalue = self.qf.get_e_qval_sym(vars_info["obs_var"], self.policy, deterministic=True)
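            # Q-Prop surrogate: -E[ log pi(a|s) * A ] - E[ eta * Qbar(s, mu(s)) ].
            # Note that vars_info["lr"] holds the log-likelihood (see above), and
            # advantage_var is presumably fed with the control-variate-corrected
            # advantage at optimization time (computed outside this method).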
            qprop_surr_loss = - tf.reduce_mean(vars_info["lr"] *
                vars_info["advantage_var"]) - tf.reduce_mean(
                qvalue * eta_var)
            input_list += [eta_var]
            self.optimizer.update_opt(
                loss=qprop_surr_loss,
                target=self.policy,
                inputs=input_list,
            )
            # calculate covariance between \hat{A} and \bar{A}
            control_variate = self.qf.get_cv_sym(obs_var,
                    action_var, self.policy)
            f_control_variate = tensor_utils.compile_function(
                inputs=[obs_var, action_var],
                outputs=control_variate,
            )
            self.opt_info_qprop = dict(
                f_control_variate=f_control_variate,
            )
        elif self.phi:
            # Here we use two updates, handled by separate optimizers:
            # first the gradient w.r.t. \theta of \mu,
            # then the gradient w.r.t. \theta of \sigma.
            # TODO: fix the gradwrtmu parameters; find a more elegant method
            # TODO: change to an adaptive eta
            eta_var = tensor_utils.new_tensor(
                'eta',
                ndim=1 + is_recurrent,
                dtype=tf.float32,
            )

            phival, mean_var = self.pf._get_e_phival_sym(vars_info["obs_var"], 
                    self.policy, gradwrtmu=True, deterministic=True)
            
            scv_surr_mu_loss = - tf.reduce_mean(vars_info['lr'] * 
                vars_info["advantage_var"]) - tf.reduce_mean(eta_var * phival)
            
            self.optimizer[0].update_opt(
                loss = scv_surr_mu_loss,
                target=self.policy.mean_network,
                inputs=input_list)

            # gradient w.r.t \theta of \sigma
            # using stein gradient variance reduction methods
            self.pf.phi_sigma_full = False
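            # NOTE: the flag is hard-coded to False above, so the full-Stein branch
            # below never runs and the Q-Prop-style sigma update in the else-branch
            # is always used.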
            if self.pf.phi_sigma_full:
                logger.log("Use full stein sigma variance reduction methods")
                # using standard stein variance reduction w.r.t \sigma
                grad_info, dist_info = self.policy.get_grad_info_sym(vars_info['obs_var'],
                                    vars_info['action_var'])
                
                # tf.gradients returns a list (one gradient per input tensor);
                # take the single gradient w.r.t. mean_var
                phi_primes = tf.gradients(phival, mean_var)[0]
                

                var_gradient = grad_info["logpi_dvar"] * tf.expand_dims(vars_info["advantage_var"], axis=1) - \
                                tf.expand_dims(vars_info['advantage_bar'], axis=1) * grad_info["logpi_dvar"] + \
                                2. * grad_info['logpi_dmu'] * phi_primes
                

                var_loss = - tf.reduce_mean(tf.reduce_sum(tf.stop_gradient(
                            var_gradient) * tf.exp(2.*dist_info["log_std"]), axis=1))

                self.optimizer[1].update_opt(
                    loss=var_loss,
                    target=self.policy.std_network,
                    inputs=input_list + [vars_info['advantage_bar']]
                )

            # the same as Q-prop for sigma updates
            else:
                logger.log("Did not use full stein variance reduction for sigma")
                surr_sigma_loss = - tf.reduce_mean(vars_info['lr'] * 
                vars_info["advantage_var"])

                self.optimizer[1].update_opt(
                    loss=surr_sigma_loss,
                    target=self.policy.std_network,
                    inputs=input_list
                )

            stein_phi = self.pf.get_phi_bar_sym(obs_var, 
                        action_var, self.policy)
            f_stein_phi = tensor_utils.compile_function(
                inputs=[obs_var, action_var],
                outputs=stein_phi,
            )
                
            self.opt_info_phi = dict(
                f_stein_phi=f_stein_phi, 
            )
                
        
        else:
            self.optimizer.update_opt(loss=surr_obj, target=self.policy, inputs=input_list)

        f_kl = tensor_utils.compile_function(
            inputs=input_list + old_dist_info_vars_list,
            outputs=[mean_kl, max_kl],
        )
        self.opt_info = dict(
            f_kl=f_kl,
            target_policy=self.policy,
        )


        self.init_opt_critic()
        # init optimization for phi training
        self.init_opt_phi()

        logger.log("Parameters...")
        for param in self.policy.mean_network.get_params(trainable=True):
            logger.log('mean with name=%s, shape=%s'% (param.name, param.shape))
        
        for param in self.policy.std_network.get_params(trainable=True):
            logger.log('std with name=%s, shape=%s'% (param.name, str(param.shape)))
Exemplo n.º 17
0
    def __init__(self,
                 env,
                 policy,
                 qf,
                 es,
                 batch_size=32,
                 n_epochs=200,
                 epoch_length=1000,
                 min_pool_size=10000,
                 replay_pool_size=1000000,
                 replacement_prob=1.0,
                 discount=0.99,
                 max_path_length=250,
                 qf_weight_decay=0.,
                 qf_update_method='adam',
                 qf_learning_rate=1e-3,
                 policy_weight_decay=0,
                 policy_update_method='adam',
                 policy_learning_rate=1e-3,
                 policy_updates_ratio=1.0,
                 eval_samples=10000,
                 soft_target=True,
                 soft_target_tau=0.001,
                 n_updates_per_sample=1,
                 scale_reward=1.0,
                 include_horizon_terminal_transitions=False,
                 plot=False,
                 pause_for_plot=False,
                 **kwargs):
        """
        :param env: Environment
        :param policy: Policy
        :param qf: Q function
        :param es: Exploration strategy
        :param batch_size: Number of samples for each minibatch.
        :param n_epochs: Number of epochs. Policy will be evaluated after each epoch.
        :param epoch_length: How many timesteps for each epoch.
        :param min_pool_size: Minimum size of the pool to start training.
        :param replay_pool_size: Size of the experience replay pool.
        :param discount: Discount factor for the cumulative return.
        :param max_path_length: Maximum length of a single rollout (path).
        :param qf_weight_decay: Weight decay factor for parameters of the Q function.
        :param qf_update_method: Online optimization method for training Q function.
        :param qf_learning_rate: Learning rate for training Q function.
        :param policy_weight_decay: Weight decay factor for parameters of the policy.
        :param policy_update_method: Online optimization method for training the policy.
        :param policy_learning_rate: Learning rate for training the policy.
        :param eval_samples: Number of samples (timesteps) for evaluating the policy.
        :param soft_target_tau: Interpolation parameter for doing the soft target update.
        :param n_updates_per_sample: Number of Q function and policy updates per new sample obtained
        :param scale_reward: The scaling factor applied to the rewards when training
        :param include_horizon_terminal_transitions: whether to include transitions with terminal=True because the
        horizon was reached. This might make the Q-value backup less stable for certain tasks.
        :param plot: Whether to visualize the policy performance after each eval_interval.
        :param pause_for_plot: Whether to pause before continuing when plotting.
        :return:
        """
        self.env = env
        self.policy = policy
        self.qf = qf
        self.es = es
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.epoch_length = epoch_length
        self.min_pool_size = min_pool_size
        self.replay_pool_size = replay_pool_size
        self.replacement_prob = replacement_prob
        self.discount = discount
        self.max_path_length = max_path_length
        self.qf_weight_decay = qf_weight_decay
        self.qf_update_method = \
            FirstOrderOptimizer(
                update_method=qf_update_method,
                learning_rate=qf_learning_rate,
            )
        self.qf_learning_rate = qf_learning_rate
        self.policy_weight_decay = policy_weight_decay
        self.policy_update_method = \
            FirstOrderOptimizer(
                update_method=policy_update_method,
                learning_rate=policy_learning_rate,
            )
        self.policy_learning_rate = policy_learning_rate
        self.policy_updates_ratio = policy_updates_ratio
        self.eval_samples = eval_samples
        self.soft_target_tau = soft_target_tau
        self.n_updates_per_sample = n_updates_per_sample
        self.include_horizon_terminal_transitions = include_horizon_terminal_transitions
        self.plot = plot
        self.pause_for_plot = pause_for_plot

        self.qf_loss_averages = []
        self.policy_surr_averages = []
        self.q_averages = []
        self.y_averages = []
        self.paths = []
        self.es_path_returns = []
        self.paths_samples_cnt = 0

        self.scale_reward = scale_reward

        self.train_policy_itr = 0

        self.opt_info = None
Exemplo n.º 18
0
    def init_opt(self):

        ###############################
        #
        # Variable Definitions
        #
        ###############################

        all_task_dist_info_vars = []
        all_obs_vars = []

        for i, policy in enumerate(self.local_policies):

            task_obs_var = self.env_partitions[
                i].observation_space.new_tensor_variable('obs%d' % i,
                                                         extra_dims=1)
            task_dist_info_vars = []

            for j, other_policy in enumerate(self.local_policies):

                state_info_vars = dict()  # Not handling recurrent policies
                dist_info_vars = other_policy.dist_info_sym(
                    task_obs_var, state_info_vars)
                task_dist_info_vars.append(dist_info_vars)

            all_obs_vars.append(task_obs_var)
            all_task_dist_info_vars.append(task_dist_info_vars)

        obs_var = self.env.observation_space.new_tensor_variable('obs',
                                                                 extra_dims=1)
        action_var = self.env.action_space.new_tensor_variable('action',
                                                               extra_dims=1)
        advantage_var = tensor_utils.new_tensor('advantage',
                                                ndim=1,
                                                dtype=tf.float32)

        old_dist_info_vars = {
            k: tf.placeholder(tf.float32,
                              shape=[None] + list(shape),
                              name='old_%s' % k)
            for k, shape in self.policy.distribution.dist_info_specs
        }

        old_dist_info_vars_list = [
            old_dist_info_vars[k]
            for k in self.policy.distribution.dist_info_keys
        ]

        input_list = [obs_var, action_var, advantage_var
                      ] + old_dist_info_vars_list + all_obs_vars

        ###############################
        #
        # Local Policy Optimization
        #
        ###############################

        self.optimizers = []
        self.metrics = []

        for n, policy in enumerate(self.local_policies):

            state_info_vars = dict()
            dist_info_vars = policy.dist_info_sym(obs_var, state_info_vars)
            dist = policy.distribution

            kl = dist.kl_sym(old_dist_info_vars, dist_info_vars)
            lr = dist.likelihood_ratio_sym(action_var, old_dist_info_vars,
                                           dist_info_vars)
            surr_loss = -tf.reduce_mean(lr * advantage_var)

            if self.constrain_together:
                additional_loss = Metrics.kl_on_others(
                    n, dist, all_task_dist_info_vars)
            else:
                additional_loss = tf.constant(0.0)

            local_loss = surr_loss + self.penalty * additional_loss
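            # The local loss above is this policy's surrogate plus a penalty term
            # (assuming Metrics.kl_on_others measures the KL divergence against the
            # other local policies on each task's observations) that keeps the local
            # policies' action distributions close to one another.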

            kl_metric = tensor_utils.compile_function(inputs=input_list,
                                                      outputs=additional_loss,
                                                      log_name="KLPenalty%d" %
                                                      n)
            self.metrics.append(kl_metric)

            mean_kl_constraint = tf.reduce_mean(kl)

            optimizer = self.optimizer_class(**self.optimizer_args)
            optimizer.update_opt(
                loss=local_loss,
                target=policy,
                leq_constraint=(mean_kl_constraint, self.step_size),
                inputs=input_list,
                constraint_name="mean_kl_%d" % n,
            )
            self.optimizers.append(optimizer)

        ###############################
        #
        # Global Policy Optimization
        #
        ###############################

        # Behaviour Cloning Loss

        state_info_vars = dict()
        center_dist_info_vars = self.policy.dist_info_sym(
            obs_var, state_info_vars)
        behaviour_cloning_loss = tf.losses.mean_squared_error(
            action_var, center_dist_info_vars['mean'])
        self.center_optimizer = FirstOrderOptimizer(max_epochs=1,
                                                    verbose=True,
                                                    batch_size=1000)
        self.center_optimizer.update_opt(behaviour_cloning_loss, self.policy,
                                         [obs_var, action_var])

        # TRPO Loss

        kl = dist.kl_sym(old_dist_info_vars, center_dist_info_vars)
        lr = dist.likelihood_ratio_sym(action_var, old_dist_info_vars,
                                       center_dist_info_vars)
        center_trpo_loss = -tf.reduce_mean(lr * advantage_var)
        mean_kl_constraint = tf.reduce_mean(kl)

        optimizer = self.optimizer_class(**self.optimizer_args)
        optimizer.update_opt(
            loss=center_trpo_loss,
            target=self.policy,
            leq_constraint=(mean_kl_constraint, self.step_size),
            inputs=[obs_var, action_var, advantage_var] +
            old_dist_info_vars_list,
            constraint_name="mean_kl_center",
        )

        self.center_trpo_optimizer = optimizer

        # Reset Local Policies to Global Policy

        assignment_operations = []

        for policy in self.local_policies:
            for param_local, param_center in zip(
                    policy.get_params_internal(),
                    self.policy.get_params_internal()):
                if 'std' not in param_local.name:
                    assignment_operations.append(
                        tf.assign(param_local, param_center))

        self.reset_to_center = tf.group(*assignment_operations)
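        # Running this op copies the central policy's parameters into every local
        # policy, skipping parameters whose name contains 'std' so each local policy
        # keeps its own std parameters.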

        return dict()
Exemplo n.º 19
0
class NPO(BatchPolopt):
    """
    Natural Policy Optimization.
    """
    def __init__(self,
                 optimizer_class=None,
                 optimizer_args=None,
                 step_size=0.01,
                 penalty=0.0,
                 **kwargs):

        self.optimizer_class = default(optimizer_class, PenaltyLbfgsOptimizer)
        self.optimizer_args = default(optimizer_args, dict())

        self.penalty = penalty
        self.constrain_together = penalty > 0

        self.step_size = step_size

        self.metrics = []
        super(NPO, self).__init__(**kwargs)

    @overrides
    def init_opt(self):

        ###############################
        #
        # Variable Definitions
        #
        ###############################

        all_task_dist_info_vars = []
        all_obs_vars = []

        for i, policy in enumerate(self.local_policies):

            task_obs_var = self.env_partitions[
                i].observation_space.new_tensor_variable('obs%d' % i,
                                                         extra_dims=1)
            task_dist_info_vars = []

            for j, other_policy in enumerate(self.local_policies):

                state_info_vars = dict()  # Not handling recurrent policies
                dist_info_vars = other_policy.dist_info_sym(
                    task_obs_var, state_info_vars)
                task_dist_info_vars.append(dist_info_vars)

            all_obs_vars.append(task_obs_var)
            all_task_dist_info_vars.append(task_dist_info_vars)

        obs_var = self.env.observation_space.new_tensor_variable('obs',
                                                                 extra_dims=1)
        action_var = self.env.action_space.new_tensor_variable('action',
                                                               extra_dims=1)
        advantage_var = tensor_utils.new_tensor('advantage',
                                                ndim=1,
                                                dtype=tf.float32)

        old_dist_info_vars = {
            k: tf.placeholder(tf.float32,
                              shape=[None] + list(shape),
                              name='old_%s' % k)
            for k, shape in self.policy.distribution.dist_info_specs
        }

        old_dist_info_vars_list = [
            old_dist_info_vars[k]
            for k in self.policy.distribution.dist_info_keys
        ]

        input_list = [obs_var, action_var, advantage_var
                      ] + old_dist_info_vars_list + all_obs_vars

        ###############################
        #
        # Local Policy Optimization
        #
        ###############################

        self.optimizers = []
        self.metrics = []

        for n, policy in enumerate(self.local_policies):

            state_info_vars = dict()
            dist_info_vars = policy.dist_info_sym(obs_var, state_info_vars)
            dist = policy.distribution

            kl = dist.kl_sym(old_dist_info_vars, dist_info_vars)
            lr = dist.likelihood_ratio_sym(action_var, old_dist_info_vars,
                                           dist_info_vars)
            surr_loss = -tf.reduce_mean(lr * advantage_var)

            if self.constrain_together:
                additional_loss = Metrics.kl_on_others(
                    n, dist, all_task_dist_info_vars)
            else:
                additional_loss = tf.constant(0.0)

            local_loss = surr_loss + self.penalty * additional_loss

            kl_metric = tensor_utils.compile_function(inputs=input_list,
                                                      outputs=additional_loss,
                                                      log_name="KLPenalty%d" %
                                                      n)
            self.metrics.append(kl_metric)

            mean_kl_constraint = tf.reduce_mean(kl)

            optimizer = self.optimizer_class(**self.optimizer_args)
            optimizer.update_opt(
                loss=local_loss,
                target=policy,
                leq_constraint=(mean_kl_constraint, self.step_size),
                inputs=input_list,
                constraint_name="mean_kl_%d" % n,
            )
            self.optimizers.append(optimizer)

        ###############################
        #
        # Global Policy Optimization
        #
        ###############################

        # Behaviour Cloning Loss

        state_info_vars = dict()
        center_dist_info_vars = self.policy.dist_info_sym(
            obs_var, state_info_vars)
        behaviour_cloning_loss = tf.losses.mean_squared_error(
            action_var, center_dist_info_vars['mean'])
        self.center_optimizer = FirstOrderOptimizer(max_epochs=1,
                                                    verbose=True,
                                                    batch_size=1000)
        self.center_optimizer.update_opt(behaviour_cloning_loss, self.policy,
                                         [obs_var, action_var])
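        # Distillation: the central policy's mean action is regressed (MSE) onto the
        # supplied actions; in optimize_global_policy below, those actions are the
        # local policies' mean actions.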

        # TRPO Loss

        kl = dist.kl_sym(old_dist_info_vars, center_dist_info_vars)
        lr = dist.likelihood_ratio_sym(action_var, old_dist_info_vars,
                                       center_dist_info_vars)
        center_trpo_loss = -tf.reduce_mean(lr * advantage_var)
        mean_kl_constraint = tf.reduce_mean(kl)

        optimizer = self.optimizer_class(**self.optimizer_args)
        optimizer.update_opt(
            loss=center_trpo_loss,
            target=self.policy,
            leq_constraint=(mean_kl_constraint, self.step_size),
            inputs=[obs_var, action_var, advantage_var] +
            old_dist_info_vars_list,
            constraint_name="mean_kl_center",
        )

        self.center_trpo_optimizer = optimizer

        # Reset Local Policies to Global Policy

        assignment_operations = []

        for policy in self.local_policies:
            for param_local, param_center in zip(
                    policy.get_params_internal(),
                    self.policy.get_params_internal()):
                if 'std' not in param_local.name:
                    assignment_operations.append(
                        tf.assign(param_local, param_center))

        self.reset_to_center = tf.group(*assignment_operations)

        return dict()

    def optimize_local_policies(self, itr, all_samples_data):

        dist_info_keys = self.policy.distribution.dist_info_keys
        for n, optimizer in enumerate(self.optimizers):

            obs_act_adv_values = tuple(
                ext.extract(all_samples_data[n], "observations", "actions",
                            "advantages"))
            dist_info_list = tuple([
                all_samples_data[n]["agent_infos"][k] for k in dist_info_keys
            ])
            all_task_obs_values = tuple([
                samples_data["observations"]
                for samples_data in all_samples_data
            ])

            all_input_values = obs_act_adv_values + dist_info_list + all_task_obs_values
            optimizer.optimize(all_input_values)

            kl_penalty = sliced_fun(self.metrics[n], 1)(all_input_values)
            logger.record_tabular('KLPenalty%d' % n, kl_penalty)

    def optimize_global_policy(self, itr, all_samples_data):

        all_observations = np.concatenate([
            samples_data['observations'] for samples_data in all_samples_data
        ])
        all_actions = np.concatenate([
            samples_data['agent_infos']['mean']
            for samples_data in all_samples_data
        ])

        num_itrs = 1 if itr % self.distillation_period != 0 else 30

        for _ in range(num_itrs):
            self.center_optimizer.optimize([all_observations, all_actions])

        paths = self.global_sampler.obtain_samples(itr)
        samples_data = self.global_sampler.process_samples(itr, paths)

        obs_values = tuple(
            ext.extract(samples_data, "observations", "actions", "advantages"))
        dist_info_list = [
            samples_data["agent_infos"][k]
            for k in self.policy.distribution.dist_info_keys
        ]

        all_input_values = obs_values + tuple(dist_info_list)

        self.center_trpo_optimizer.optimize(all_input_values)
        self.env.log_diagnostics(paths)

    @overrides
    def optimize_policy(self, itr, all_samples_data):

        self.optimize_local_policies(itr, all_samples_data)
        self.optimize_global_policy(itr, all_samples_data)

        if itr % self.distillation_period == 0:
            sess = tf.get_default_session()
            sess.run(self.reset_to_center)
            logger.log('Reset Local Policies to Global Policies')

        return dict()
Exemplo n.º 20
0
class DDPG(RLAlgorithm):
    """
    Deep Deterministic Policy Gradient.
    """
    def __init__(self,
                 env,
                 policy,
                 qf,
                 es,
                 batch_size=32,
                 n_epochs=200,
                 epoch_length=1000,
                 min_pool_size=10000,
                 replay_pool_size=1000000,
                 replacement_prob=1.0,
                 discount=0.99,
                 max_path_length=250,
                 qf_weight_decay=0.,
                 qf_update_method='adam',
                 qf_learning_rate=1e-3,
                 policy_weight_decay=0,
                 policy_update_method='adam',
                 policy_learning_rate=1e-3,
                 policy_updates_ratio=1.0,
                 eval_samples=10000,
                 soft_target=True,
                 soft_target_tau=0.001,
                 n_updates_per_sample=1,
                 scale_reward=1.0,
                 include_horizon_terminal_transitions=False,
                 plot=False,
                 pause_for_plot=False,
                 **kwargs):
        """
        :param env: Environment
        :param policy: Policy
        :param qf: Q function
        :param es: Exploration strategy
        :param batch_size: Number of samples for each minibatch.
        :param n_epochs: Number of epochs. Policy will be evaluated after each epoch.
        :param epoch_length: How many timesteps for each epoch.
        :param min_pool_size: Minimum size of the pool to start training.
        :param replay_pool_size: Size of the experience replay pool.
        :param discount: Discount factor for the cumulative return.
        :param max_path_length: Maximum length of a single rollout (path).
        :param qf_weight_decay: Weight decay factor for parameters of the Q function.
        :param qf_update_method: Online optimization method for training Q function.
        :param qf_learning_rate: Learning rate for training Q function.
        :param policy_weight_decay: Weight decay factor for parameters of the policy.
        :param policy_update_method: Online optimization method for training the policy.
        :param policy_learning_rate: Learning rate for training the policy.
        :param eval_samples: Number of samples (timesteps) for evaluating the policy.
        :param soft_target_tau: Interpolation parameter for doing the soft target update.
        :param n_updates_per_sample: Number of Q function and policy updates per new sample obtained
        :param scale_reward: The scaling factor applied to the rewards when training
        :param include_horizon_terminal_transitions: whether to include transitions with terminal=True because the
        horizon was reached. This might make the Q-value backup less stable for certain tasks.
        :param plot: Whether to visualize the policy performance after each eval_interval.
        :param pause_for_plot: Whether to pause before continuing when plotting.
        :return:
        """
        self.env = env
        self.on_policy_env = Serializable.clone(env)
        self.policy = policy
        self.qf = qf
        self.es = es
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.epoch_length = epoch_length
        self.min_pool_size = min_pool_size
        self.replay_pool_size = replay_pool_size
        self.replacement_prob = replacement_prob
        self.discount = discount
        self.max_path_length = max_path_length
        self.qf_weight_decay = qf_weight_decay
        self.qf_update_method = \
            FirstOrderOptimizer(
                update_method=qf_update_method,
                learning_rate=qf_learning_rate,
            )
        self.qf_learning_rate = qf_learning_rate
        self.policy_weight_decay = policy_weight_decay
        self.policy_update_method = \
            FirstOrderOptimizer(
                update_method=policy_update_method,
                learning_rate=policy_learning_rate,
            )
        self.policy_learning_rate = policy_learning_rate
        self.policy_updates_ratio = policy_updates_ratio
        self.eval_samples = eval_samples
        self.train_step = tf.placeholder(tf.float32,
                                         shape=(),
                                         name="train_step")
        self.global_train_step = 0.0

        self.soft_target_tau = soft_target_tau
        self.n_updates_per_sample = n_updates_per_sample
        self.include_horizon_terminal_transitions = include_horizon_terminal_transitions
        self.plot = plot
        self.pause_for_plot = pause_for_plot

        self.qf_loss_averages = []
        self.policy_surr_averages = []
        self.q_averages = []
        self.y_averages = []
        self.paths = []
        self.es_path_returns = []
        self.paths_samples_cnt = 0
        self.random_dist = Bernoulli(None, [.5])
        self.sigma_type = kwargs.get('sigma_type', 'gated')

        self.scale_reward = scale_reward

        self.train_policy_itr = 0

        self.opt_info = None

    def start_worker(self):
        parallel_sampler.populate_task(self.env, self.policy)
        if self.plot:
            plotter.init_plot(self.env, self.policy)

    @overrides
    def train(self):
        gc_dump_time = time.time()
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            # This seems like a rather sequential method
            pool = SimpleReplayPool(
                max_pool_size=self.replay_pool_size,
                observation_dim=self.env.observation_space.flat_dim,
                action_dim=self.env.action_space.flat_dim,
                replacement_prob=self.replacement_prob,
            )

            on_policy_pool = SimpleReplayPool(
                max_pool_size=self.replay_pool_size,
                observation_dim=self.on_policy_env.observation_space.flat_dim,
                action_dim=self.on_policy_env.action_space.flat_dim,
            )

            self.start_worker()

            self.init_opt()
            # This initializes the optimizer parameters
            sess.run(tf.global_variables_initializer())
            itr = 0
            path_length = 0
            path_return = 0
            terminal = False
            initial = False
            observation = self.env.reset()
            on_policy_terminal = False
            on_policy_initial = False
            on_policy_path_length = 0
            on_policy_path_return = 0
            on_policy_observation = self.on_policy_env.reset()

            #with tf.variable_scope("sample_policy"):
            #with suppress_params_loading():
            #sample_policy = pickle.loads(pickle.dumps(self.policy))
            with tf.variable_scope("sample_policy"):
                sample_policy = Serializable.clone(self.policy)

            for epoch in range(self.n_epochs):
                logger.push_prefix('epoch #%d | ' % epoch)
                logger.log("Training started")
                train_qf_itr, train_policy_itr = 0, 0
                for epoch_itr in pyprind.prog_bar(range(self.epoch_length)):
                    # Execute policy
                    if terminal:  # or path_length > self.max_path_length:
                        # Note that if the last time step ends an episode, the very
                        # last state and observation will be ignored and not added
                        # to the replay pool
                        observation = self.env.reset()
                        self.es.reset()
                        sample_policy.reset()
                        self.es_path_returns.append(path_return)
                        path_length = 0
                        path_return = 0
                        initial = True
                    else:
                        initial = False

                    if on_policy_terminal:  # or path_length > self.max_path_length:
                        # Note that if the last time step ends an episode, the very
                        # last state and observation will be ignored and not added
                        # to the replay pool
                        on_policy_observation = self.on_policy_env.reset()
                        sample_policy.reset()
                        on_policy_path_length = 0
                        on_policy_path_return = 0
                        on_policy_initial = True
                    else:
                        on_policy_initial = False

                    action = self.es.get_action(itr,
                                                observation,
                                                policy=sample_policy)  # qf=qf)
                    on_policy_action = self.get_action_on_policy(
                        self.on_policy_env,
                        on_policy_observation,
                        policy=sample_policy)

                    next_observation, reward, terminal, _ = self.env.step(
                        action)
                    on_policy_next_observation, on_policy_reward, on_policy_terminal, _ = self.on_policy_env.step(
                        on_policy_action)

                    path_length += 1
                    path_return += reward
                    on_policy_path_length += 1
                    on_policy_path_return += on_policy_reward

                    if not terminal and path_length >= self.max_path_length:
                        terminal = True
                        # only include the terminal transition in this case if the flag was set
                        if self.include_horizon_terminal_transitions:
                            pool.add_sample(observation, action,
                                            reward * self.scale_reward,
                                            terminal, initial)
                    else:
                        pool.add_sample(observation, action,
                                        reward * self.scale_reward, terminal,
                                        initial)

                    if not on_policy_terminal and on_policy_path_length >= self.max_path_length:
                        on_policy_terminal = True
                        # only include the terminal transition in this case if the flag was set
                        if self.include_horizon_terminal_transitions:
                            on_policy_pool.add_sample(
                                on_policy_observation, on_policy_action,
                                on_policy_reward * self.scale_reward,
                                on_policy_terminal, on_policy_initial)
                    else:
                        on_policy_pool.add_sample(
                            on_policy_observation, on_policy_action,
                            on_policy_reward * self.scale_reward,
                            on_policy_terminal, on_policy_initial)

                    on_policy_observation = on_policy_next_observation
                    observation = next_observation

                    if pool.size >= self.min_pool_size:
                        self.global_train_step += 1
                        for update_itr in range(self.n_updates_per_sample):
                            # Train policy
                            batch = pool.random_batch(self.batch_size)
                            on_policy_batch = on_policy_pool.random_batch(
                                self.batch_size)
                            itrs = self.do_training(itr, on_policy_batch,
                                                    batch)
                            train_qf_itr += itrs[0]
                            train_policy_itr += itrs[1]
                        sample_policy.set_param_values(
                            self.policy.get_param_values())

                    itr += 1
                    if time.time() - gc_dump_time > 100:
                        gc.collect()
                        gc_dump_time = time.time()

                logger.log("Training finished")
                logger.log("Trained qf %d steps, policy %d steps" %
                           (train_qf_itr, train_policy_itr))
                if pool.size >= self.min_pool_size:
                    self.evaluate(epoch)
                    params = self.get_epoch_snapshot(epoch)
                    logger.save_itr_params(epoch, params)

                logger.dump_tabular(with_prefix=False)
                logger.pop_prefix()
                if self.plot:
                    self.update_plot()
                    if self.pause_for_plot:
                        input("Plotting evaluation run: Press Enter to "
                              "continue...")
            self.env.terminate()
            self.policy.terminate()

    def get_action_on_policy(self, env, observation, policy, **kwargs):
        action, _ = policy.get_action(observation)
        action_space = env.action_space
        return np.clip(action, action_space.low, action_space.high)

    def init_opt(self):

        # First, create "target" policy and Q functions
        with tf.variable_scope("target_policy"):
            target_policy = Serializable.clone(self.policy)
        with tf.variable_scope("target_qf"):
            target_qf = Serializable.clone(self.qf)

        # The target values y need to be computed first
        obs = self.env.observation_space.new_tensor_variable(
            'obs',
            extra_dims=1,
        )

        # The yi values are computed separately as above and then passed to
        # the training functions below
        action = self.env.action_space.new_tensor_variable(
            'action',
            extra_dims=1,
        )

        yvar = tensor_utils.new_tensor(
            'ys',
            ndim=1,
            dtype=tf.float32,
        )

        obs_offpolicy = self.env.observation_space.new_tensor_variable(
            'obs_offpolicy',
            extra_dims=1,
        )

        action_offpolicy = self.env.action_space.new_tensor_variable(
            'action_offpolicy',
            extra_dims=1,
        )

        yvar_offpolicy = tensor_utils.new_tensor(
            'ys_offpolicy',
            ndim=1,
            dtype=tf.float32,
        )

        qf_weight_decay_term = 0.5 * self.qf_weight_decay * \
                               sum([tf.reduce_sum(tf.square(param)) for param in
                                    self.qf.get_params(regularizable=True)])

        qval = self.qf.get_qval_sym(obs, action)
        qval_off = self.qf.get_qval_sym(obs_offpolicy, action_offpolicy)

        qf_loss = tf.reduce_mean(tf.square(yvar - qval))
        qf_loss_off = tf.reduce_mean(tf.square(yvar_offpolicy - qval_off))

        # TODO: penalize dramatic changes in gating_func
        # if PENALIZE_GATING_DISTRIBUTION_DIVERGENCE:


        policy_weight_decay_term = 0.5 * self.policy_weight_decay * \
                                   sum([tf.reduce_sum(tf.square(param))
                                        for param in self.policy.get_params(regularizable=True)])
        policy_qval = self.qf.get_qval_sym(obs,
                                           self.policy.get_action_sym(obs),
                                           deterministic=True)

        policy_qval_off = self.qf.get_qval_sym(
            obs_offpolicy,
            self.policy.get_action_sym(obs_offpolicy),
            deterministic=True)

        policy_surr = -tf.reduce_mean(policy_qval)
        policy_surr_off = -tf.reduce_mean(policy_qval_off)

        if self.sigma_type in ('unified-gated', 'unified-gated-decaying'):
            print("Using Gated Sigma!")

            input_to_gates = tf.concat([obs, obs_offpolicy], axis=1)

            assert input_to_gates.get_shape().as_list()[-1] == obs.get_shape(
            ).as_list()[-1] + obs_offpolicy.get_shape().as_list()[-1]

            # TODO: right now this is a soft-gate, should make a hard-gate (options vs mixtures)
            gating_func = MLP(
                name="sigma_gate",
                output_dim=1,
                hidden_sizes=(64, 64),
                hidden_nonlinearity=tf.nn.relu,
                output_nonlinearity=tf.nn.sigmoid,
                input_var=input_to_gates,
                input_shape=tuple(
                    input_to_gates.get_shape().as_list()[1:])).output
        elif self.sigma_type == 'unified':
            # sample a bernoulli random variable
            print("Using Bernoulli sigma!")
            gating_func = tf.cast(self.random_dist.sample(qf_loss.get_shape()),
                                  tf.float32)
        elif self.sigma_type == 'unified-decaying':
            print("Using decaying sigma!")
            gating_func = tf.train.exponential_decay(1.0,
                                                     self.train_step,
                                                     20,
                                                     0.96,
                                                     staircase=True)
        else:
            raise Exception("sigma type not supported")

        qf_inputs_list = [
            yvar, obs, action, yvar_offpolicy, obs_offpolicy, action_offpolicy,
            self.train_step
        ]
        qf_reg_loss = (qf_loss * (1.0 - gating_func)
                       + qf_loss_off * gating_func
                       + qf_weight_decay_term)

        policy_input_list = [obs, obs_offpolicy, self.train_step]
        policy_reg_surr = (policy_surr * (1.0 - gating_func)
                           + policy_surr_off * gating_func
                           + policy_weight_decay_term)

        if self.sigma_type == 'unified-gated-decaying':
            print("Adding a decaying factor to gated sigma!")
            decaying_factor = tf.train.exponential_decay(.5,
                                                         self.train_step,
                                                         20,
                                                         0.96,
                                                         staircase=True)
            penalty = decaying_factor * tf.nn.l2_loss(gating_func)
            qf_reg_loss += penalty
            policy_reg_surr += penalty

        self.qf_update_method.update_opt(qf_reg_loss,
                                         target=self.qf,
                                         inputs=qf_inputs_list)

        self.policy_update_method.update_opt(policy_reg_surr,
                                             target=self.policy,
                                             inputs=policy_input_list)

        f_train_qf = tensor_utils.compile_function(
            inputs=qf_inputs_list,
            outputs=[qf_loss, qval, self.qf_update_method._train_op],
        )

        f_train_policy = tensor_utils.compile_function(
            inputs=policy_input_list,
            outputs=[policy_surr, self.policy_update_method._train_op],
        )

        self.opt_info = dict(
            f_train_qf=f_train_qf,
            f_train_policy=f_train_policy,
            target_qf=target_qf,
            target_policy=target_policy,
        )

    def do_training(self, itr, batch, offpolicy_batch):

        obs, actions, rewards, next_obs, terminals = ext.extract(
            batch, "observations", "actions", "rewards", "next_observations",
            "terminals")

        obs_off, actions_off, rewards_off, next_obs_off, terminals_off = ext.extract(
            offpolicy_batch, "observations", "actions", "rewards",
            "next_observations", "terminals")

        # compute the on-policy y values
        target_qf = self.opt_info["target_qf"]
        target_policy = self.opt_info["target_policy"]

        next_actions, _ = target_policy.get_actions(next_obs)
        next_qvals = target_qf.get_qval(next_obs, next_actions)

        ys = rewards + (1. -
                        terminals) * self.discount * next_qvals.reshape(-1)

        next_actions_off, _ = target_policy.get_actions(next_obs_off)
        next_qvals_off = target_qf.get_qval(next_obs_off, next_actions_off)

        ys_off = rewards_off + (
            1. - terminals_off) * self.discount * next_qvals_off.reshape(-1)

        f_train_qf = self.opt_info["f_train_qf"]
        f_train_policy = self.opt_info["f_train_policy"]

        qf_loss, qval, _ = f_train_qf(ys, obs, actions, ys_off, obs_off,
                                      actions_off, self.global_train_step)

        target_qf.set_param_values(target_qf.get_param_values() *
                                   (1.0 - self.soft_target_tau) +
                                   self.qf.get_param_values() *
                                   self.soft_target_tau)
        self.qf_loss_averages.append(qf_loss)
        self.q_averages.append(qval)
        self.y_averages.append(ys)  #TODO: also add ys_off

        self.train_policy_itr += self.policy_updates_ratio
        train_policy_itr = 0
        while self.train_policy_itr > 0:
            f_train_policy = self.opt_info["f_train_policy"]
            policy_surr, _ = f_train_policy(obs, obs_off,
                                            self.global_train_step)
            target_policy.set_param_values(target_policy.get_param_values() *
                                           (1.0 - self.soft_target_tau) +
                                           self.policy.get_param_values() *
                                           self.soft_target_tau)
            self.policy_surr_averages.append(policy_surr)
            self.train_policy_itr -= 1
            train_policy_itr += 1

        return 1, train_policy_itr  # number of itrs qf, policy are trained
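
    # Illustrative aside (not part of the original snippet): the
    # AverageDiscountedReturn logged in evaluate() below is computed per path as
    #     G = sum(discount**t * r_t for t, r_t in enumerate(path["rewards"])),
    # e.g. rewards [1, 1, 1] with discount 0.99 give 1 + 0.99 + 0.9801 = 2.9701.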

    def evaluate(self, epoch):
        paths = parallel_sampler.sample_paths(
            policy_params=self.policy.get_param_values(),
            max_samples=self.eval_samples,
            max_path_length=self.max_path_length,
        )

        average_discounted_return = np.mean([
            special.discount_return(path["rewards"], self.discount)
            for path in paths
        ])

        returns = [sum(path["rewards"]) for path in paths]

        all_qs = np.concatenate(self.q_averages)
        all_ys = np.concatenate(self.y_averages)

        average_q_loss = np.mean(self.qf_loss_averages)
        average_policy_surr = np.mean(self.policy_surr_averages)
        average_action = np.mean(
            np.square(np.concatenate([path["actions"] for path in paths])))

        policy_reg_param_norm = np.linalg.norm(
            self.policy.get_param_values(regularizable=True))
        qfun_reg_param_norm = np.linalg.norm(
            self.qf.get_param_values(regularizable=True))

        logger.record_tabular('Epoch', epoch)
        logger.record_tabular('Iteration', epoch)
        logger.record_tabular('AverageReturn', np.mean(returns))
        logger.record_tabular('StdReturn', np.std(returns))
        logger.record_tabular('MaxReturn', np.max(returns))
        logger.record_tabular('MinReturn', np.min(returns))
        if len(self.es_path_returns) > 0:
            logger.record_tabular('AverageEsReturn',
                                  np.mean(self.es_path_returns))
            logger.record_tabular('StdEsReturn', np.std(self.es_path_returns))
            logger.record_tabular('MaxEsReturn', np.max(self.es_path_returns))
            logger.record_tabular('MinEsReturn', np.min(self.es_path_returns))
        logger.record_tabular('AverageDiscountedReturn',
                              average_discounted_return)
        logger.record_tabular('AverageQLoss', average_q_loss)
        logger.record_tabular('AveragePolicySurr', average_policy_surr)
        logger.record_tabular('AverageQ', np.mean(all_qs))
        logger.record_tabular('AverageAbsQ', np.mean(np.abs(all_qs)))
        logger.record_tabular('AverageY', np.mean(all_ys))
        logger.record_tabular('AverageAbsY', np.mean(np.abs(all_ys)))
        logger.record_tabular('AverageAbsQYDiff',
                              np.mean(np.abs(all_qs - all_ys)))
        logger.record_tabular('AverageAction', average_action)

        logger.record_tabular('PolicyRegParamNorm', policy_reg_param_norm)
        logger.record_tabular('QFunRegParamNorm', qfun_reg_param_norm)

        self.env.log_diagnostics(paths)
        self.policy.log_diagnostics(paths)

        self.qf_loss_averages = []
        self.policy_surr_averages = []

        self.q_averages = []
        self.y_averages = []
        self.es_path_returns = []

    def update_plot(self):
        if self.plot:
            plotter.update_plot(self.policy, self.max_path_length)

    def get_epoch_snapshot(self, epoch):
        return dict(
            env=self.env,
            epoch=epoch,
            qf=self.qf,
            policy=self.policy,
            target_qf=self.opt_info["target_qf"],
            target_policy=self.opt_info["target_policy"],
            es=self.es,
        )
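
A minimal Python sketch (added for illustration, not part of the snippet above) of the sigma-gated objective built in init_opt: a gate in [0, 1] convexly blends the on-policy and off-policy losses, so a gate of 0 recovers the purely on-policy update and a gate of 1 the purely off-policy one. The loss and gate values below are made up.

def blended_loss(loss_on, loss_off, gate, weight_decay_term=0.0):
    # mirrors qf_reg_loss = qf_loss * (1 - gate) + qf_loss_off * gate + weight decay
    return loss_on * (1.0 - gate) + loss_off * gate + weight_decay_term

print(blended_loss(2.0, 4.0, 0.0))   # -> 2.0 (only the on-policy loss)
print(blended_loss(2.0, 4.0, 1.0))   # -> 4.0 (only the off-policy loss)
print(blended_loss(2.0, 4.0, 0.25))  # -> 2.5 (convex combination of the two)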
Exemplo n.º 21
0
    def __init__(
            self,
            env,
            policy,
            baseline,
            scope=None,
            n_itr=500,
            start_itr=0,
            batch_size=5000,
            max_path_length=500,
            discount=0.99,
            gae_lambda=1,
            plot=False,
            pause_for_plot=False,
            center_adv=True,
            positive_adv=False,
            store_paths=False,
            whole_paths=True,
            fixed_horizon=False,
            sampler_cls=None,
            sampler_args=None,
            force_batch_sampler=False,
            # qprop params
            qf=None,
            min_pool_size=10000,
            replay_pool_size=1000000,
            replacement_prob=1.0,
            qf_updates_ratio=1,
            qprop_use_mean_action=True,
            qprop_min_itr=0,
            qprop_batch_size=None,
            qprop_use_advantage=True,
            qprop_use_qf_baseline=False,
            qprop_eta_option='ones',
            qf_weight_decay=0.,
            qf_update_method='adam',
            qf_learning_rate=1e-3,
            qf_batch_size=32,
            qf_baseline=None,
            soft_target=True,
            soft_target_tau=0.001,
            scale_reward=1.0,
            **kwargs):
        """
        :param env: Environment
        :param policy: Policy
        :type policy: Policy
        :param baseline: Baseline
        :param scope: Scope for identifying the algorithm. Must be specified if running multiple algorithms
        simultaneously, each using different environments and policies
        :param n_itr: Number of iterations.
        :param start_itr: Starting iteration.
        :param batch_size: Number of samples per iteration.
        :param max_path_length: Maximum length of a single rollout.
        :param discount: Discount.
        :param gae_lambda: Lambda used for generalized advantage estimation.
        :param plot: Plot evaluation run after each iteration.
        :param pause_for_plot: Whether to pause before continuing when plotting.
        :param center_adv: Whether to rescale the advantages so that they have mean 0 and standard deviation 1.
        :param positive_adv: Whether to shift the advantages so that they are always positive. When used in
        conjunction with center_adv the advantages will be standardized before shifting.
        :param store_paths: Whether to save all paths data to the snapshot.
        :param qf: q function for q-prop.
        :return:
        """
        self.env = env
        self.policy = policy
        self.baseline = baseline
        self.scope = scope
        self.n_itr = n_itr
        self.start_itr = start_itr
        self.batch_size = batch_size
        self.max_path_length = max_path_length
        self.discount = discount
        self.gae_lambda = gae_lambda
        self.plot = plot
        self.pause_for_plot = pause_for_plot
        self.center_adv = center_adv
        self.positive_adv = positive_adv
        self.store_paths = store_paths
        self.whole_paths = whole_paths
        self.fixed_horizon = fixed_horizon
        self.qf = qf
        if self.qf is not None:
            self.qprop = True
            self.qprop_optimizer = Serializable.clone(self.optimizer)
            self.min_pool_size = min_pool_size
            self.replay_pool_size = replay_pool_size
            self.replacement_prob = replacement_prob
            self.qf_updates_ratio = qf_updates_ratio
            self.qprop_use_mean_action = qprop_use_mean_action
            self.qprop_min_itr = qprop_min_itr
            self.qprop_use_qf_baseline = qprop_use_qf_baseline
            self.qf_weight_decay = qf_weight_decay
            self.qf_update_method = \
                FirstOrderOptimizer(
                    update_method=qf_update_method,
                    learning_rate=qf_learning_rate,
                )
            self.qf_learning_rate = qf_learning_rate
            self.qf_batch_size = qf_batch_size
            self.qf_baseline = qf_baseline
            if qprop_batch_size is None:
                self.qprop_batch_size = self.batch_size
            else:
                self.qprop_batch_size = qprop_batch_size
            self.qprop_use_advantage = qprop_use_advantage
            self.qprop_eta_option = qprop_eta_option
            self.soft_target_tau = soft_target_tau
            self.scale_reward = scale_reward

            self.qf_loss_averages = []
            self.q_averages = []
            self.y_averages = []
            if self.start_itr >= self.qprop_min_itr:
                self.batch_size = self.qprop_batch_size
                if self.qprop_use_qf_baseline:
                    self.baseline = self.qf_baseline
                self.qprop_enable = True
            else:
                self.qprop_enable = False
        else:
            self.qprop = False
        if sampler_cls is None:
            if self.policy.vectorized and not force_batch_sampler:
                sampler_cls = VectorizedSampler
            else:
                sampler_cls = BatchSampler
        if sampler_args is None:
            sampler_args = dict()

        self.sampler = sampler_cls(self, **sampler_args)
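
The qf_weight_decay stored above enters the Q-function objective elsewhere in these snippets as an L2 penalty of the form 0.5 * qf_weight_decay * sum_i ||w_i||^2 over the regularizable parameters. A small NumPy sketch with made-up parameter arrays:

import numpy as np

qf_weight_decay = 1e-2
params = [np.array([[0.5, -0.5], [1.0, 0.0]]), np.array([0.1, -0.2])]

# 0.5 * decay * sum of squared entries over all regularizable parameters
weight_decay_term = 0.5 * qf_weight_decay * sum(np.sum(np.square(p)) for p in params)
print(weight_decay_term)  # -> 0.5 * 0.01 * 1.55 = 0.00775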
Exemplo n.º 22
0
class MAMLNPO(BatchMAMLPolopt):
    """
    Natural Policy Optimization.
    """
    def __init__(self,
                 optimizer=None,
                 optimizer_args=None,
                 step_size=0.01,
                 use_maml=True,
                 **kwargs):
        assert optimizer is not None  # only for use with MAML TRPO

        self.optimizer = optimizer
        self.offPolicy_optimizer = FirstOrderOptimizer(max_epochs=1)
        self.step_size = step_size
        self.use_maml = use_maml
        self.kl_constrain_step = -1  # needs to be 0 or -1 (original pol params, or new pol params)
        super(MAMLNPO, self).__init__(**kwargs)

    def make_vars(self, stepnum='0'):
        # lists over the meta_batch_size
        obs_vars, action_vars, adv_vars, imp_vars = [], [], [], []
        for i in range(self.meta_batch_size):
            obs_vars.append(
                self.env.observation_space.new_tensor_variable(
                    'obs' + stepnum + '_' + str(i),
                    extra_dims=1,
                ))
            action_vars.append(
                self.env.action_space.new_tensor_variable(
                    'action' + stepnum + '_' + str(i),
                    extra_dims=1,
                ))
            adv_vars.append(
                tensor_utils.new_tensor(
                    name='advantage' + stepnum + '_' + str(i),
                    ndim=1,
                    dtype=tf.float32,
                ))

            imp_vars.append(
                tensor_utils.new_tensor(
                    name='imp_ratios' + stepnum + '_' + str(i),
                    ndim=1,
                    dtype=tf.float32,
                ))

        return obs_vars, action_vars, adv_vars, imp_vars

    @overrides
    def init_opt(self):
        is_recurrent = int(self.policy.recurrent)
        assert not is_recurrent  # not supported

        dist = self.policy.distribution

        old_dist_info_vars, old_dist_info_vars_list = [], []
        for i in range(self.meta_batch_size):
            old_dist_info_vars.append({
                k: tf.placeholder(tf.float32,
                                  shape=[None] + list(shape),
                                  name='old_%s_%s' % (i, k))
                for k, shape in dist.dist_info_specs
            })
            old_dist_info_vars_list += [
                old_dist_info_vars[i][k] for k in dist.dist_info_keys
            ]

        state_info_vars, state_info_vars_list = {}, []

        all_surr_objs, input_list = [], []
        new_params = None
        for j in range(self.num_grad_updates):
            obs_vars, action_vars, adv_vars, _ = self.make_vars(str(j))
            surr_objs = []

            cur_params = new_params
            # if there are several grad_updates, new_params is overwritten
            new_params = []
            kls = []

            for i in range(self.meta_batch_size):
                if j == 0:
                    dist_info_vars, params = self.policy.dist_info_sym(
                        obs_vars[i],
                        state_info_vars,
                        all_params=self.policy.all_params)
                    if self.kl_constrain_step == 0:
                        kl = dist.kl_sym(old_dist_info_vars[i], dist_info_vars)
                        kls.append(kl)
                else:
                    dist_info_vars, params = self.policy.updated_dist_info_sym(
                        i,
                        all_surr_objs[-1][i],
                        obs_vars[i],
                        params_dict=cur_params[i])

                new_params.append(params)
                logli = dist.log_likelihood_sym(action_vars[i], dist_info_vars)

                # formulate as a minimization problem
                # The gradient of the surrogate objective is the policy gradient
                surr_objs.append(-tf.reduce_mean(logli * adv_vars[i]))

            input_list += obs_vars + action_vars + adv_vars + state_info_vars_list
            if j == 0:
                # For computing the fast update for sampling
                self.policy.set_init_surr_obj(input_list, surr_objs)
                init_input_list = input_list

            all_surr_objs.append(surr_objs)

        obs_vars, action_vars, adv_vars, _ = self.make_vars('test')
        surr_objs = []
        for i in range(self.meta_batch_size):
            dist_info_vars, _ = self.policy.updated_dist_info_sym(
                i,
                all_surr_objs[-1][i],
                obs_vars[i],
                params_dict=new_params[i])

            if self.kl_constrain_step == -1:  # if we only care about the kl of the last step, the last item in kls will be the overall
                kl = dist.kl_sym(old_dist_info_vars[i], dist_info_vars)
                kls.append(kl)
            lr = dist.likelihood_ratio_sym(action_vars[i],
                                           old_dist_info_vars[i],
                                           dist_info_vars)
            surr_objs.append(-tf.reduce_mean(lr * adv_vars[i]))

        if self.use_maml:
            surr_obj = tf.reduce_mean(tf.stack(
                surr_objs, 0))  # mean over meta_batch_size (the diff tasks)
            input_list += obs_vars + action_vars + adv_vars + old_dist_info_vars_list
        else:
            surr_obj = tf.reduce_mean(
                tf.stack(all_surr_objs[0],
                         0))  # if not meta, just use the first surr_obj
            input_list = init_input_list

        if self.use_maml:
            mean_kl = tf.reduce_mean(
                tf.concat(kls, 0)
            )  ##CF shouldn't this have the option of self.kl_constrain_step == -1?
            max_kl = tf.reduce_max(tf.concat(kls, 0))

            self.optimizer.update_opt(loss=surr_obj,
                                      target=self.policy,
                                      leq_constraint=(mean_kl, self.step_size),
                                      inputs=input_list,
                                      constraint_name="mean_kl")
        else:
            self.optimizer.update_opt(
                loss=surr_obj,
                target=self.policy,
                inputs=input_list,
            )
        return dict()

    @overrides
    def init_opt_offPolicy(self):

        is_recurrent = int(self.policy.recurrent)
        assert not is_recurrent  # not supported
        dist = self.policy.distribution
        state_info_vars, state_info_vars_list = {}, []
        all_surr_objs, input_list = [], []
        new_params = None

        for j in range(self.num_grad_updates):
            obs_vars, action_vars, adv_vars, imp_vars = self.make_vars(str(j))
            surr_objs = []

            cur_params = new_params
            # if there are several grad_updates, new_params is overwritten
            new_params = []

            for i in range(self.meta_batch_size):
                if j == 0:
                    dist_info_vars, params = self.policy.dist_info_sym(
                        obs_vars[i],
                        state_info_vars,
                        all_params=self.policy.all_params)

                else:
                    dist_info_vars, params = self.policy.updated_dist_info_sym(
                        i,
                        all_surr_objs[-1][i],
                        obs_vars[i],
                        params_dict=cur_params[i])

                new_params.append(params)
                logli = dist.log_likelihood_sym(action_vars[i], dist_info_vars)

                # formulate as a minimization problem
                # The gradient of the surrogate objective is the policy gradient
                surr_objs.append(-tf.reduce_mean(logli * imp_vars[i] *
                                                 adv_vars[i]))

            input_list += obs_vars + action_vars + adv_vars + imp_vars
            all_surr_objs.append(surr_objs)

        obs_vars, action_vars, _, _ = self.make_vars('test')
        surr_objs = []
        for i in range(self.meta_batch_size):

            dist_info_vars, _ = self.policy.updated_dist_info_sym(
                i,
                all_surr_objs[-1][i],
                obs_vars[i],
                params_dict=new_params[i])
            logli = dist.log_likelihood_sym(action_vars[i], dist_info_vars)
            surr_objs.append(-tf.reduce_mean(logli))

        surr_obj = tf.reduce_mean(tf.stack(
            surr_objs, 0))  # mean over meta_batch_size (the diff tasks)
        input_list += obs_vars + action_vars

        self.offPolicy_optimizer.update_opt(
            loss=surr_obj,
            target=self.policy,
            inputs=input_list,
        )

    def offPolicy_optimization_step(self, samples_data, expert_data):

        input_list = []
        #for step in range(len(all_samples_data)):  # these are the gradient steps
        obs_list, action_list, adv_list, imp_list, expert_obs_list, expert_action_list = [], [], [], [], [], []
        for i in range(self.meta_batch_size):

            inputs = ext.extract(samples_data[i], "observations", "actions",
                                 "advantages", 'traj_imp_weights')
            obs_list.append(inputs[0])
            action_list.append(inputs[1])
            adv_list.append(inputs[2])
            imp_list.append(inputs[3])

            expert_inputs = ext.extract(expert_data[i], "observations",
                                        "actions")
            expert_obs_list.append(expert_inputs[0])
            expert_action_list.append(expert_inputs[1])

        input_list += obs_list + action_list + adv_list + imp_list + expert_obs_list + expert_action_list

        self.offPolicy_optimizer.optimize(input_list)

    @overrides
    def optimize_policy(self, itr, all_samples_data):
        assert len(
            all_samples_data
        ) == self.num_grad_updates + 1  # we collected the rollouts to compute the grads and then the test!

        if not self.use_maml:
            all_samples_data = [all_samples_data[0]]

        input_list = []
        for step in range(
                len(all_samples_data)):  # these are the gradient steps
            obs_list, action_list, adv_list = [], [], []
            for i in range(self.meta_batch_size):

                inputs = ext.extract(all_samples_data[step][i], "observations",
                                     "actions", "advantages")
                obs_list.append(inputs[0])
                action_list.append(inputs[1])
                adv_list.append(inputs[2])
            input_list += obs_list + action_list + adv_list  # [ [obs_0], [act_0], [adv_0], [obs_1], ... ]

            if step == 0:  ##CF not used?
                init_inputs = input_list

        if self.use_maml:
            dist_info_list = []
            for i in range(self.meta_batch_size):
                agent_infos = all_samples_data[
                    self.kl_constrain_step][i]['agent_infos']
                dist_info_list += [
                    agent_infos[k]
                    for k in self.policy.distribution.dist_info_keys
                ]
            input_list += tuple(dist_info_list)
            logger.log("Computing KL before")
            mean_kl_before = self.optimizer.constraint_val(input_list)

        logger.log("Computing loss before")
        loss_before = self.optimizer.loss(input_list)
        logger.log("Optimizing")
        self.optimizer.optimize(input_list)
        logger.log("Computing loss after")
        loss_after = self.optimizer.loss(input_list)
        if self.use_maml:
            logger.log("Computing KL after")
            mean_kl = self.optimizer.constraint_val(input_list)
            logger.record_tabular('MeanKLBefore',
                                  mean_kl_before)  # this now won't be 0!
            logger.record_tabular('MeanKL', mean_kl)
        logger.record_tabular('LossBefore', loss_before)
        logger.record_tabular('LossAfter', loss_after)
        logger.record_tabular('dLoss', loss_before - loss_after)
        return dict()

    @overrides
    def get_itr_snapshot(self, itr, samples_data):
        return dict(
            itr=itr,
            policy=self.policy,
            baseline=self.baseline,
            env=self.env,
        )
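
A small NumPy sketch (illustrative only) of the two surrogate objectives used in init_opt above: the inner adaptation step minimizes -mean(log pi(a|s) * A), while the outer meta step minimizes -mean(ratio * A), where the ratio is pi_new / pi_old evaluated at the sampled actions. All arrays below are made up.

import numpy as np

advantages = np.array([1.0, -0.5, 2.0])
logp_new = np.array([-1.0, -2.0, -0.5])   # log pi_new(a|s) on the sampled actions
logp_old = np.array([-1.2, -1.8, -0.6])   # log pi_old(a|s) on the same actions

inner_surr = -np.mean(logp_new * advantages)   # policy-gradient surrogate (inner step)
ratio = np.exp(logp_new - logp_old)            # likelihood ratio pi_new / pi_old
outer_surr = -np.mean(ratio * advantages)      # likelihood-ratio surrogate (outer step)

print(inner_surr, outer_surr)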
Exemplo n.º 23
0
class DDPG(RLAlgorithm):
    """
    Deep Deterministic Policy Gradient.
    """
    def __init__(self,
                 env,
                 policy,
                 oracle_policy,
                 qf,
                 agent_strategy,
                 batch_size=32,
                 n_epochs=200,
                 epoch_length=1000,
                 min_pool_size=10000,
                 replay_pool_size=1000000,
                 replacement_prob=1.0,
                 discount=0.99,
                 max_path_length=250,
                 qf_weight_decay=0.,
                 qf_update_method='adam',
                 qf_learning_rate=1e-3,
                 policy_weight_decay=0,
                 policy_update_method='adam',
                 policy_learning_rate=1e-3,
                 policy_updates_ratio=1.0,
                 eval_samples=10000,
                 soft_target=True,
                 soft_target_tau=0.001,
                 n_updates_per_sample=1,
                 scale_reward=1.0,
                 include_horizon_terminal_transitions=False,
                 plot=False,
                 pause_for_plot=False):
        """
        :param env: Environment
        :param policy: Policy
        :param oracle_policy: Oracle policy whose actions are blended with the agent's via the binary gate
        :param qf: Q function
        :param agent_strategy: Exploration strategy for the learning agent
        :param batch_size: Number of samples for each minibatch.
        :param n_epochs: Number of epochs. Policy will be evaluated after each epoch.
        :param epoch_length: How many timesteps for each epoch.
        :param min_pool_size: Minimum size of the pool to start training.
        :param replay_pool_size: Size of the experience replay pool.
        :param discount: Discount factor for the cumulative return.
        :param max_path_length: Maximum length of a single rollout.
        :param qf_weight_decay: Weight decay factor for parameters of the Q function.
        :param qf_update_method: Online optimization method for training Q function.
        :param qf_learning_rate: Learning rate for training Q function.
        :param policy_weight_decay: Weight decay factor for parameters of the policy.
        :param policy_update_method: Online optimization method for training the policy.
        :param policy_learning_rate: Learning rate for training the policy.
        :param eval_samples: Number of samples (timesteps) for evaluating the policy.
        :param soft_target_tau: Interpolation parameter for doing the soft target update.
        :param n_updates_per_sample: Number of Q function and policy updates per new sample obtained
        :param scale_reward: The scaling factor applied to the rewards when training
        :param include_horizon_terminal_transitions: whether to include transitions with terminal=True because the
        horizon was reached. This might make the Q-value backup less stable for certain tasks.
        :param plot: Whether to visualize the policy performance after each eval_interval.
        :param pause_for_plot: Whether to pause before continuing when plotting.
        :return:
        """
        self.env = env
        self.policy = policy
        self.oracle_policy = oracle_policy
        self.qf = qf
        self.agent_strategy = agent_strategy
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.epoch_length = epoch_length
        self.min_pool_size = min_pool_size
        self.replay_pool_size = replay_pool_size
        self.replacement_prob = replacement_prob
        self.discount = discount
        self.max_path_length = max_path_length
        self.qf_weight_decay = qf_weight_decay

        self.qf_update_method = \
            FirstOrderOptimizer(
                update_method=qf_update_method,
                learning_rate=qf_learning_rate,
            )
        self.qf_learning_rate = qf_learning_rate
        self.policy_weight_decay = policy_weight_decay


        self.policy_update_method = \
            FirstOrderOptimizer(
                update_method=policy_update_method,
                learning_rate=policy_learning_rate,
            )

        self.gating_func_update_method = \
            FirstOrderOptimizer(
                update_method='adam',
                learning_rate=policy_learning_rate,
            )

        self.policy_learning_rate = policy_learning_rate
        self.policy_updates_ratio = policy_updates_ratio
        self.eval_samples = eval_samples
        self.soft_target_tau = soft_target_tau
        self.n_updates_per_sample = n_updates_per_sample
        self.include_horizon_terminal_transitions = include_horizon_terminal_transitions
        self.plot = plot
        self.pause_for_plot = pause_for_plot

        self.qf_loss_averages = []
        self.policy_surr_averages = []
        self.q_averages = []
        self.y_averages = []
        self.paths = []
        self.es_path_returns = []
        self.paths_samples_cnt = 0

        self.scale_reward = scale_reward

        self.train_policy_itr = 0

        self.opt_info = None

    def start_worker(self):
        parallel_sampler.populate_task(self.env, self.policy)
        if self.plot:
            plotter.init_plot(self.env, self.policy)

    @overrides
    def train(self):
        with tf.Session() as sess:
            # sess.run(tf.global_variables_initializer())
            # only initialise the uninitialised ones
            self.initialize_uninitialized(sess)

            pool = SimpleReplayPool(
                max_pool_size=self.replay_pool_size,
                observation_dim=self.env.observation_space.flat_dim,
                action_dim=self.env.action_space.flat_dim,
                replacement_prob=self.replacement_prob,
            )

            self.start_worker()
            self.init_opt()

            # This initializes the optimizer parameters
            self.initialize_uninitialized(sess)
            itr = 0
            path_length = 0
            path_return = 0
            terminal = False
            initial = False
            observation = self.env.reset()

            with tf.variable_scope("sample_policy"):
                sample_policy = Serializable.clone(self.policy)

            oracle_policy = self.oracle_policy

            for epoch in range(self.n_epochs):
                logger.push_prefix('epoch #%d | ' % epoch)
                logger.log("Training started")
                train_qf_itr, train_policy_itr = 0, 0

                for epoch_itr in pyprind.prog_bar(range(self.epoch_length)):
                    # Execute policy
                    if terminal:  # or path_length > self.max_path_length:
                        # Note that if the last time step ends an episode, the very
                        # last state and observation will be ignored and not added
                        # to the replay pool
                        observation = self.env.reset()
                        self.agent_strategy.reset()
                        sample_policy.reset()
                        self.es_path_returns.append(path_return)
                        path_length = 0
                        path_return = 0
                        initial = True
                    else:
                        initial = False

                    ### both continuous actions
                    ### binary_action is continuous here
                    ### it will be approximated as a discrete action with the regularizers
                    ### taken from conditional computation (Bengio)
                    agent_action, binary_action = self.agent_strategy.get_action_with_binary(
                        itr, observation, policy=sample_policy)  # qf=qf)
                    sigma = np.round(binary_action)
                    oracle_action = self.get_oracle_action(
                        itr, observation, policy=oracle_policy)

                    action = sigma[0] * agent_action + sigma[1] * oracle_action
                    next_observation, reward, terminal, _ = self.env.step(
                        action)
                    path_length += 1
                    path_return += reward

                    ### including both the agent and oracle samples in the same replay buffer
                    if not terminal and path_length >= self.max_path_length:
                        terminal = True
                        # only include the terminal transition in this case if the flag was set
                        if self.include_horizon_terminal_transitions:
                            pool.add_sample(observation, action,
                                            reward * self.scale_reward,
                                            terminal, initial)

                    else:
                        pool.add_sample(observation, action,
                                        reward * self.scale_reward, terminal,
                                        initial)

                    observation = next_observation

                    if pool.size >= self.min_pool_size:

                        for update_itr in range(self.n_updates_per_sample):
                            # Train policy
                            batch = pool.random_batch(self.batch_size)
                            itrs = self.do_training(itr, batch)
                            train_qf_itr += itrs[0]
                            train_policy_itr += itrs[1]
                        sample_policy.set_param_values(
                            self.policy.get_param_values())

                    itr += 1

                logger.log("Training finished")
                logger.log("Trained qf %d steps, policy %d steps" %
                           (train_qf_itr, train_policy_itr))
                # logger.log("Pool sizes agent (%d) oracle (%d)" %(agent_only_pool.size, oracle_only_pool.size))

                if pool.size >= self.min_pool_size:
                    self.evaluate(epoch, pool)
                    params = self.get_epoch_snapshot(epoch)
                    logger.save_itr_params(epoch, params)
                logger.dump_tabular(with_prefix=False)
                logger.pop_prefix()
                if self.plot:
                    self.update_plot()
                    if self.pause_for_plot:
                        input("Plotting evaluation run: Press Enter to "
                              "continue...")
            self.env.terminate()
            self.policy.terminate()

    def init_opt(self, lambda_s=100, lambda_v=10, tau=.5):

        with tf.variable_scope("target_policy"):
            target_policy = Serializable.clone(self.policy)

        oracle_policy = self.oracle_policy

        with tf.variable_scope("target_qf"):
            target_qf = Serializable.clone(self.qf)

        obs = self.obs = self.env.observation_space.new_tensor_variable(
            'obs',
            extra_dims=1,
        )

        action = self.env.action_space.new_tensor_variable(
            'action',
            extra_dims=1,
        )

        yvar = tensor_utils.new_tensor(
            'ys',
            ndim=1,
            dtype=tf.float32,
        )

        qf_weight_decay_term = 0.5 * self.qf_weight_decay * \
                               sum([tf.reduce_sum(tf.square(param)) for param in
                                    self.qf.get_params(regularizable=True)])

        qval = self.qf.get_qval_sym(obs, action)

        qf_loss = tf.reduce_mean(tf.square(yvar - qval))
        qf_reg_loss = qf_loss + qf_weight_decay_term

        policy_weight_decay_term = 0.5 * self.policy_weight_decay * \
                                   sum([tf.reduce_sum(tf.square(param))
                                        for param in self.policy.get_params(regularizable=True)])

        qf_input_list = [yvar, obs, action]
        policy_input_list = [obs]

        obs_oracle = self.env.observation_space.new_tensor_variable(
            'obs_oracle',
            extra_dims=1,
        )

        action_oracle = self.env.action_space.new_tensor_variable(
            'action_oracle',
            extra_dims=1,
        )

        yvar_oracle = tensor_utils.new_tensor(
            'ys_oracle',
            ndim=1,
            dtype=tf.float32,
        )

        qval_oracle = self.qf.get_qval_sym(obs_oracle, action_oracle)
        qf_loss_oracle = tf.reduce_mean(tf.square(yvar_oracle - qval_oracle))
        qf_reg_loss_oracle = qf_loss_oracle + qf_weight_decay_term

        policy_qval_novice = self.qf.get_qval_sym(
            obs, self.policy.get_novice_policy_sym(obs), deterministic=True)

        gating_network = self.policy.get_action_binary_gate_sym(obs)

        policy_qval_oracle = self.qf.get_qval_sym(
            obs, self.policy.get_action_oracle_sym(obs), deterministic=True)

        combined_losses = tf.concat([
            tf.reshape(policy_qval_novice, [-1, 1]),
            tf.reshape(policy_qval_oracle, [-1, 1])
        ],
                                    axis=1)

        combined_loss = -tf.reduce_mean(tf.reshape(
            tf.reduce_mean(combined_losses * gating_network, axis=1), [-1, 1]),
                                        axis=0)

        lambda_s_loss = tf.constant(0.0)

        if lambda_s > 0.0:
            lambda_s_loss = lambda_s * (tf.reduce_mean(
                (tf.reduce_mean(gating_network, axis=0) - tau)**
                2) + tf.reduce_mean(
                    (tf.reduce_mean(gating_network, axis=1) - tau)**2))

        lambda_v_loss = tf.constant(0.0)

        if lambda_v > 0.0:
            mean0, var0 = tf.nn.moments(gating_network, axes=[0])
            mean, var1 = tf.nn.moments(gating_network, axes=[1])
            lambda_v_loss = -lambda_v * (tf.reduce_mean(var0) +
                                         tf.reduce_mean(var1))

        policy_surr = combined_loss
        policy_reg_surr = combined_loss + policy_weight_decay_term + lambda_s_loss + lambda_v_loss
        gf_input_list = [obs_oracle, action_oracle, yvar_oracle
                         ] + qf_input_list

        self.qf_update_method.update_opt(loss=qf_reg_loss,
                                         target=self.qf,
                                         inputs=qf_input_list)

        self.policy_update_method.update_opt(loss=policy_reg_surr,
                                             target=self.policy,
                                             inputs=policy_input_list)

        f_train_qf = tensor_utils.compile_function(
            inputs=qf_input_list,
            outputs=[qf_loss, qval, self.qf_update_method._train_op],
        )

        f_train_policy = tensor_utils.compile_function(
            inputs=policy_input_list,
            outputs=[
                policy_surr, self.policy_update_method._train_op,
                gating_network
            ],
        )

        self.opt_info = dict(
            f_train_qf=f_train_qf,
            f_train_policy=f_train_policy,
            target_qf=target_qf,
            target_policy=target_policy,
            oracle_policy=oracle_policy,
        )

    def do_training(self, itr, batch):

        obs, actions, rewards, next_obs, terminals = ext.extract(
            batch, "observations", "actions", "rewards", "next_observations",
            "terminals")

        target_qf = self.opt_info["target_qf"]
        target_policy = self.opt_info["target_policy"]
        oracle_policy = self.opt_info["oracle_policy"]

        next_actions, _ = target_policy.get_actions(next_obs)
        next_qvals = target_qf.get_qval(next_obs, next_actions)
        ys = rewards + (1. -
                        terminals) * self.discount * next_qvals.reshape(-1)

        f_train_qf = self.opt_info["f_train_qf"]

        qf_loss, qval, _ = f_train_qf(ys, obs, actions)

        target_qf.set_param_values(target_qf.get_param_values() *
                                   (1.0 - self.soft_target_tau) +
                                   self.qf.get_param_values() *
                                   self.soft_target_tau)

        self.qf_loss_averages.append(qf_loss)
        self.q_averages.append(qval)
        self.y_averages.append(ys)

        self.train_policy_itr += self.policy_updates_ratio
        train_policy_itr = 0

        while self.train_policy_itr > 0:
            f_train_policy = self.opt_info["f_train_policy"]
            policy_surr, _, gating_outputs = f_train_policy(obs)

            target_policy.set_param_values(target_policy.get_param_values() *
                                           (1.0 - self.soft_target_tau) +
                                           self.policy.get_param_values() *
                                           self.soft_target_tau)
            self.policy_surr_averages.append(policy_surr)

            self.train_policy_itr -= 1
            train_policy_itr += 1

        return 1, train_policy_itr  # number of itrs qf, policy are trained

    def evaluate(self, epoch, pool):
        logger.log("Collecting samples for evaluation")
        paths = parallel_sampler.sample_paths(
            policy_params=self.policy.get_param_values(),
            max_samples=self.eval_samples,
            max_path_length=self.max_path_length,
        )

        average_discounted_return = np.mean([
            special.discount_return(path["rewards"], self.discount)
            for path in paths
        ])

        returns = [sum(path["rewards"]) for path in paths]

        all_qs = np.concatenate(self.q_averages)
        all_ys = np.concatenate(self.y_averages)

        average_q_loss = np.mean(self.qf_loss_averages)
        average_policy_surr = np.mean(self.policy_surr_averages)
        average_action = np.mean(
            np.square(np.concatenate([path["actions"] for path in paths])))

        policy_reg_param_norm = np.linalg.norm(
            self.policy.get_param_values(regularizable=True))
        qfun_reg_param_norm = np.linalg.norm(
            self.qf.get_param_values(regularizable=True))

        logger.record_tabular('Epoch', epoch)
        logger.record_tabular('Iteration', epoch)
        logger.record_tabular('AverageReturn', np.mean(returns))
        logger.record_tabular('StdReturn', np.std(returns))
        logger.record_tabular('MaxReturn', np.max(returns))
        logger.record_tabular('MinReturn', np.min(returns))
        if len(self.es_path_returns) > 0:
            logger.record_tabular('AverageEsReturn',
                                  np.mean(self.es_path_returns))
            logger.record_tabular('StdEsReturn', np.std(self.es_path_returns))
            logger.record_tabular('MaxEsReturn', np.max(self.es_path_returns))
            logger.record_tabular('MinEsReturn', np.min(self.es_path_returns))
        logger.record_tabular('AverageDiscountedReturn',
                              average_discounted_return)
        logger.record_tabular('AverageQLoss', average_q_loss)
        logger.record_tabular('AveragePolicySurr', average_policy_surr)
        logger.record_tabular('AverageQ', np.mean(all_qs))
        logger.record_tabular('AverageAbsQ', np.mean(np.abs(all_qs)))
        logger.record_tabular('AverageY', np.mean(all_ys))
        logger.record_tabular('AverageAbsY', np.mean(np.abs(all_ys)))
        logger.record_tabular('AverageAbsQYDiff',
                              np.mean(np.abs(all_qs - all_ys)))
        logger.record_tabular('AverageAction', average_action)

        logger.record_tabular('PolicyRegParamNorm', policy_reg_param_norm)
        logger.record_tabular('QFunRegParamNorm', qfun_reg_param_norm)

        self.env.log_diagnostics(paths)
        self.policy.log_diagnostics(paths)

        self.qf_loss_averages = []
        self.policy_surr_averages = []

        self.q_averages = []
        self.y_averages = []
        self.es_path_returns = []

    def update_plot(self):
        if self.plot:
            plotter.update_plot(self.policy, self.max_path_length)

    def get_epoch_snapshot(self, epoch):
        return dict(
            env=self.env,
            epoch=epoch,
            qf=self.qf,
            policy=self.policy,
            target_qf=self.opt_info["target_qf"],
            target_policy=self.opt_info["target_policy"],
            es=self.agent_strategy,
        )

    def get_oracle_action(self, t, observation, policy, **kwargs):
        action, _ = policy.get_action(observation)
        return np.clip(action, self.env.action_space.low,
                       self.env.action_space.high)

    ### to initialise variables in the TF graph
    ### which have not been initialised so far
    def initialize_uninitialized(self, sess):
        global_vars = tf.global_variables()
        is_not_initialized = sess.run(
            [tf.is_variable_initialized(var) for var in global_vars])
        not_initialized_vars = [
            v for (v, f) in zip(global_vars, is_not_initialized) if not f
        ]
        print([str(i.name) for i in not_initialized_vars])  # only for testing
        if len(not_initialized_vars):
            sess.run(tf.variables_initializer(not_initialized_vars))
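
A minimal NumPy sketch (illustrative, not part of the class above) of the soft target update performed in do_training: the target network parameters track the learned parameters by Polyak averaging with rate soft_target_tau.

import numpy as np

def soft_update(target_params, source_params, tau=0.001):
    # target <- target * (1 - tau) + source * tau, as in do_training above
    return target_params * (1.0 - tau) + source_params * tau

target = np.zeros(3)
source = np.ones(3)
for _ in range(5):
    target = soft_update(target, source, tau=0.5)
print(target)  # -> [0.96875 0.96875 0.96875], converging geometrically to the source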
Exemplo n.º 24
0
class DDPG(RLAlgorithm):
    """
    Deep Deterministic Policy Gradient.
    """
    def __init__(self,
                 env,
                 policy,
                 qf,
                 es,
                 batch_size=32,
                 n_epochs=200,
                 epoch_length=1000,
                 min_pool_size=10000,
                 replay_pool_size=1000000,
                 replacement_prob=1.0,
                 discount=0.99,
                 max_path_length=250,
                 qf_weight_decay=0.,
                 qf_update_method='adam',
                 qf_learning_rate=1e-3,
                 policy_weight_decay=0,
                 policy_update_method='adam',
                 policy_learning_rate=1e-3,
                 policy_updates_ratio=1.0,
                 eval_samples=10000,
                 soft_target=True,
                 soft_target_tau=0.001,
                 n_updates_per_sample=1,
                 scale_reward=1.0,
                 include_horizon_terminal_transitions=False,
                 plot=False,
                 pause_for_plot=False,
                 **kwargs):
        """
        :param env: Environment
        :param policy: Policy
        :param qf: Q function
        :param es: Exploration strategy
        :param batch_size: Number of samples for each minibatch.
        :param n_epochs: Number of epochs. Policy will be evaluated after each epoch.
        :param epoch_length: How many timesteps for each epoch.
        :param min_pool_size: Minimum size of the pool to start training.
        :param replay_pool_size: Size of the experience replay pool.
        :param discount: Discount factor for the cumulative return.
        :param max_path_length: Maximum length of a single rollout.
        :param qf_weight_decay: Weight decay factor for parameters of the Q function.
        :param qf_update_method: Online optimization method for training Q function.
        :param qf_learning_rate: Learning rate for training Q function.
        :param policy_weight_decay: Weight decay factor for parameters of the policy.
        :param policy_update_method: Online optimization method for training the policy.
        :param policy_learning_rate: Learning rate for training the policy.
        :param eval_samples: Number of samples (timesteps) for evaluating the policy.
        :param soft_target_tau: Interpolation parameter for doing the soft target update.
        :param n_updates_per_sample: Number of Q function and policy updates per new sample obtained
        :param scale_reward: The scaling factor applied to the rewards when training
        :param include_horizon_terminal_transitions: whether to include transitions with terminal=True because the
        horizon was reached. This might make the Q-value backup less stable for certain tasks.
        :param plot: Whether to visualize the policy performance after each eval_interval.
        :param pause_for_plot: Whether to pause before continuing when plotting.
        :return:
        """
        self.env = env
        self.policy = policy
        self.qf = qf
        self.es = es
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.epoch_length = epoch_length
        self.min_pool_size = min_pool_size
        self.replay_pool_size = replay_pool_size
        self.replacement_prob = replacement_prob
        self.discount = discount
        self.max_path_length = max_path_length
        self.qf_weight_decay = qf_weight_decay
        self.qf_update_method = \
            FirstOrderOptimizer(
                update_method=qf_update_method,
                learning_rate=qf_learning_rate,
            )
        self.qf_learning_rate = qf_learning_rate
        self.policy_weight_decay = policy_weight_decay
        self.policy_update_method = \
            FirstOrderOptimizer(
                update_method=policy_update_method,
                learning_rate=policy_learning_rate,
            )
        self.policy_learning_rate = policy_learning_rate
        self.policy_updates_ratio = policy_updates_ratio
        self.eval_samples = eval_samples
        self.soft_target_tau = soft_target_tau
        self.n_updates_per_sample = n_updates_per_sample
        self.include_horizon_terminal_transitions = include_horizon_terminal_transitions
        self.plot = plot
        self.pause_for_plot = pause_for_plot

        self.qf_loss_averages = []
        self.policy_surr_averages = []
        self.q_averages = []
        self.y_averages = []
        self.paths = []
        self.es_path_returns = []
        self.paths_samples_cnt = 0

        self.scale_reward = scale_reward

        self.train_policy_itr = 0

        self.opt_info = None

    def start_worker(self):
        parallel_sampler.populate_task(self.env, self.policy)
        if self.plot:
            plotter.init_plot(self.env, self.policy)

    @overrides
    def train(self):
        gc_dump_time = time.time()
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            # This seems like a rather sequential method
            pool = SimpleReplayPool(
                max_pool_size=self.replay_pool_size,
                observation_dim=self.env.observation_space.flat_dim,
                action_dim=self.env.action_space.flat_dim,
                replacement_prob=self.replacement_prob,
            )
            self.start_worker()

            self.init_opt()
            # This initializes the optimizer parameters
            sess.run(tf.global_variables_initializer())
            itr = 0
            path_length = 0
            path_return = 0
            terminal = False
            initial = False
            observation = self.env.reset()

            #with tf.variable_scope("sample_policy"):
            #with suppress_params_loading():
            #sample_policy = pickle.loads(pickle.dumps(self.policy))
            with tf.variable_scope("sample_policy"):
                sample_policy = Serializable.clone(self.policy)

            for epoch in range(self.n_epochs):
                logger.push_prefix('epoch #%d | ' % epoch)
                logger.log("Training started")
                train_qf_itr, train_policy_itr = 0, 0
                for epoch_itr in pyprind.prog_bar(range(self.epoch_length)):
                    # Execute policy
                    if terminal:  # or path_length > self.max_path_length:
                        # Note that if the last time step ends an episode, the very
                        # last state and observation will be ignored and not added
                        # to the replay pool
                        observation = self.env.reset()
                        self.es.reset()
                        sample_policy.reset()
                        self.es_path_returns.append(path_return)
                        path_length = 0
                        path_return = 0
                        initial = True
                    else:
                        initial = False

                    actions = []

                    for i in range(100):
                        action, _ = sample_policy.get_action_with_dropout(
                            observation)
                        actions.append(action)

                    tiled_observations = [observation] * len(actions)

                    all_qvals = []

                    for i in range(100):
                        q_vals = self.qf.get_qval_dropout(
                            np.vstack(tiled_observations), np.vstack(actions))
                        all_qvals.append(q_vals)

                    action_max = np.argmax(np.vstack(all_qvals)) % len(actions)

                    # Use the selected candidate both for the environment step and for
                    # the replay-pool samples added below.
                    action = actions[action_max]
                    next_observation, reward, terminal, _ = self.env.step(action)
                    path_length += 1
                    path_return += reward

                    if not terminal and path_length >= self.max_path_length:
                        terminal = True
                        # only include the terminal transition in this case if the flag was set
                        if self.include_horizon_terminal_transitions:
                            pool.add_sample(observation, action,
                                            reward * self.scale_reward,
                                            terminal, initial)
                    else:
                        pool.add_sample(observation, action,
                                        reward * self.scale_reward, terminal,
                                        initial)

                    observation = next_observation

                    if pool.size >= self.min_pool_size:
                        for update_itr in range(self.n_updates_per_sample):
                            # Train policy
                            batch = pool.random_batch(self.batch_size)
                            itrs = self.do_training(itr, batch)
                            train_qf_itr += itrs[0]
                            train_policy_itr += itrs[1]
                        sample_policy.set_param_values(
                            self.policy.get_param_values())

                    itr += 1
                    if time.time() - gc_dump_time > 100:
                        gc.collect()
                        gc_dump_time = time.time()

                logger.log("Training finished")
                logger.log("Trained qf %d steps, policy %d steps" %
                           (train_qf_itr, train_policy_itr))
                if pool.size >= self.min_pool_size:
                    self.evaluate(epoch, pool)
                    params = self.get_epoch_snapshot(epoch)
                    logger.save_itr_params(epoch, params)
                logger.dump_tabular(with_prefix=False)
                logger.pop_prefix()
                if self.plot:
                    self.update_plot()
                    if self.pause_for_plot:
                        input("Plotting evaluation run: Press Enter to "
                              "continue...")
            self.env.terminate()
            self.policy.terminate()

    def init_opt(self):

        # First, create "target" policy and Q functions
        with tf.variable_scope("target_policy"):
            target_policy = Serializable.clone(self.policy)
        with tf.variable_scope("target_qf"):
            target_qf = Serializable.clone(self.qf)

        # y need to be computed first
        obs = self.env.observation_space.new_tensor_variable(
            'obs',
            extra_dims=1,
        )

        # The yi values are computed separately as above and then passed to
        # the training functions below
        action = self.env.action_space.new_tensor_variable(
            'action',
            extra_dims=1,
        )

        yvar = tensor_utils.new_tensor(
            'ys',
            ndim=1,
            dtype=tf.float32,
        )

        qf_weight_decay_term = 0.5 * self.qf_weight_decay * \
                               sum([tf.reduce_sum(tf.square(param)) for param in
                                    self.qf.get_params(regularizable=True)])

        qval = self.qf.get_qval_sym(obs, action)

        qf_loss = tf.reduce_mean(tf.square(yvar - qval))
        qf_reg_loss = qf_loss + qf_weight_decay_term

        policy_weight_decay_term = 0.5 * self.policy_weight_decay * \
                                   sum([tf.reduce_sum(tf.square(param))
                                        for param in self.policy.get_params(regularizable=True)])
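        # Deterministic policy gradient surrogate: train the policy to maximize Q(s, pi(s))
        # by minimizing the negative mean Q value defined below.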
        policy_qval = self.qf.get_qval_sym(obs,
                                           self.policy.get_action_sym(obs),
                                           deterministic=True)
        policy_surr = -tf.reduce_mean(policy_qval)

        policy_reg_surr = policy_surr + policy_weight_decay_term

        qf_input_list = [yvar, obs, action]
        policy_input_list = [obs]

        self.qf_update_method.update_opt(loss=qf_reg_loss,
                                         target=self.qf,
                                         inputs=qf_input_list)
        self.policy_update_method.update_opt(loss=policy_reg_surr,
                                             target=self.policy,
                                             inputs=policy_input_list)

        f_train_qf = tensor_utils.compile_function(
            inputs=qf_input_list,
            outputs=[qf_loss, qval, self.qf_update_method._train_op],
        )

        f_train_policy = tensor_utils.compile_function(
            inputs=policy_input_list,
            outputs=[policy_surr, self.policy_update_method._train_op],
        )

        self.opt_info = dict(
            f_train_qf=f_train_qf,
            f_train_policy=f_train_policy,
            target_qf=target_qf,
            target_policy=target_policy,
        )

    def do_training(self, itr, batch):

        obs, actions, rewards, next_obs, terminals = ext.extract(
            batch, "observations", "actions", "rewards", "next_observations",
            "terminals")

        # compute the on-policy y values
        target_qf = self.opt_info["target_qf"]
        target_policy = self.opt_info["target_policy"]

        next_actions, _ = target_policy.get_actions(next_obs)
        next_qvals = target_qf.get_qval(next_obs, next_actions)

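        # TD target computed from the target networks:
        # y = r + (1 - terminal) * discount * Q_target(s', pi_target(s'))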
        ys = rewards + (1. -
                        terminals) * self.discount * next_qvals.reshape(-1)

        f_train_qf = self.opt_info["f_train_qf"]
        qf_loss, qval, _ = f_train_qf(ys, obs, actions)
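        # Soft (Polyak) target update: theta_target <- (1 - tau) * theta_target + tau * theta,
        # with tau = self.soft_target_tau.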
        target_qf.set_param_values(target_qf.get_param_values() *
                                   (1.0 - self.soft_target_tau) +
                                   self.qf.get_param_values() *
                                   self.soft_target_tau)
        self.qf_loss_averages.append(qf_loss)
        self.q_averages.append(qval)
        self.y_averages.append(ys)

        self.train_policy_itr += self.policy_updates_ratio
        train_policy_itr = 0
        while self.train_policy_itr > 0:
            f_train_policy = self.opt_info["f_train_policy"]
            policy_surr, _ = f_train_policy(obs)
            target_policy.set_param_values(target_policy.get_param_values() *
                                           (1.0 - self.soft_target_tau) +
                                           self.policy.get_param_values() *
                                           self.soft_target_tau)
            self.policy_surr_averages.append(policy_surr)
            self.train_policy_itr -= 1
            train_policy_itr += 1
        return 1, train_policy_itr  # number of itrs qf, policy are trained

    def evaluate(self, epoch, pool):
        logger.log("Collecting samples for evaluation")
        paths = parallel_sampler.sample_paths(
            policy_params=self.policy.get_param_values(),
            max_samples=self.eval_samples,
            max_path_length=self.max_path_length,
        )
        self.env.reset()

        average_discounted_return = np.mean([
            special.discount_return(path["rewards"], self.discount)
            for path in paths
        ])

        returns = [sum(path["rewards"]) for path in paths]

        all_qs = np.concatenate(self.q_averages)
        all_ys = np.concatenate(self.y_averages)

        average_q_loss = np.mean(self.qf_loss_averages)
        average_policy_surr = np.mean(self.policy_surr_averages)
        average_action = np.mean(
            np.square(np.concatenate([path["actions"] for path in paths])))

        policy_reg_param_norm = np.linalg.norm(
            self.policy.get_param_values(regularizable=True))
        qfun_reg_param_norm = np.linalg.norm(
            self.qf.get_param_values(regularizable=True))

        logger.record_tabular('Epoch', epoch)
        logger.record_tabular('Iteration', epoch)
        logger.record_tabular('AverageReturn', np.mean(returns))
        logger.record_tabular('StdReturn', np.std(returns))
        logger.record_tabular('MaxReturn', np.max(returns))
        logger.record_tabular('MinReturn', np.min(returns))
        if len(self.es_path_returns) > 0:
            logger.record_tabular('AverageEsReturn',
                                  np.mean(self.es_path_returns))
            logger.record_tabular('StdEsReturn', np.std(self.es_path_returns))
            logger.record_tabular('MaxEsReturn', np.max(self.es_path_returns))
            logger.record_tabular('MinEsReturn', np.min(self.es_path_returns))
        logger.record_tabular('AverageDiscountedReturn',
                              average_discounted_return)
        logger.record_tabular('AverageQLoss', average_q_loss)
        logger.record_tabular('AveragePolicySurr', average_policy_surr)
        logger.record_tabular('AverageQ', np.mean(all_qs))
        logger.record_tabular('AverageAbsQ', np.mean(np.abs(all_qs)))
        logger.record_tabular('AverageY', np.mean(all_ys))
        logger.record_tabular('AverageAbsY', np.mean(np.abs(all_ys)))
        logger.record_tabular('AverageAbsQYDiff',
                              np.mean(np.abs(all_qs - all_ys)))
        logger.record_tabular('AverageAction', average_action)

        logger.record_tabular('PolicyRegParamNorm', policy_reg_param_norm)
        logger.record_tabular('QFunRegParamNorm', qfun_reg_param_norm)

        self.env.log_diagnostics(paths)
        self.policy.log_diagnostics(paths)

        self.qf_loss_averages = []
        self.policy_surr_averages = []

        self.q_averages = []
        self.y_averages = []
        self.es_path_returns = []

    def update_plot(self):
        if self.plot:
            plotter.update_plot(self.policy, self.max_path_length)

    def get_epoch_snapshot(self, epoch):
        return dict(
            env=self.env,
            epoch=epoch,
            qf=self.qf,
            policy=self.policy,
            target_qf=self.opt_info["target_qf"],
            target_policy=self.opt_info["target_policy"],
            es=self.es,
        )
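
The training loop above selects actions by sampling with dropout: it draws a batch of candidate actions from the policy with dropout enabled, scores them with stochastic dropout passes through the Q function, and commits to the candidate with the highest sampled Q value. A minimal standalone sketch of that selection step, assuming hypothetical policy and Q-function objects that expose get_action_with_dropout and get_qval_dropout with the same signatures as in the example:

import numpy as np

def select_action_with_dropout(policy, qf, observation, n_samples=100):
    # Draw candidate actions from the policy with dropout active.
    candidates = [policy.get_action_with_dropout(observation)[0] for _ in range(n_samples)]
    tiled_obs = np.vstack([observation] * len(candidates))

    # Score every candidate with several stochastic (dropout) passes through the Q function.
    q_samples = [qf.get_qval_dropout(tiled_obs, np.vstack(candidates)) for _ in range(n_samples)]

    # Flattened argmax over all (pass, candidate) Q estimates; the modulo recovers the candidate index.
    best = np.argmax(np.vstack(q_samples)) % len(candidates)
    return candidates[best]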
Exemplo n.º 25
0
def main(_):

    env = TfEnv(
        AtariEnv(args.env,
                 force_reset=True,
                 record_video=False,
                 record_log=False,
                 resize_size=args.resize_size,
                 atari_noop=args.atari_noop,
                 atari_eplife=args.atari_eplife,
                 atari_firereset=args.atari_firereset))

    policy_network = ConvNetwork(
        name='prob_network',
        input_shape=env.observation_space.shape,
        output_dim=env.action_space.n,
        # number of channels/filters for each conv layer
        conv_filters=(16, 32),
        # filter size
        conv_filter_sizes=(8, 4),
        conv_strides=(4, 2),
        conv_pads=('VALID', 'VALID'),
        hidden_sizes=(256, ),
        hidden_nonlinearity=tf.nn.relu,
        output_nonlinearity=tf.nn.softmax,
        batch_normalization=False)
    policy = CategoricalMLPPolicy(name='policy',
                                  env_spec=env.spec,
                                  prob_network=policy_network)

    if (args.value_function == 'zero'):
        baseline = ZeroBaseline(env.spec)
    else:
        value_network = get_value_network(env)
        baseline_batch_size = args.batch_size * 10

        if (args.value_function == 'conj'):
            baseline_optimizer = ConjugateGradientOptimizer(
                subsample_factor=1.0, num_slices=args.num_slices)
        elif (args.value_function == 'adam'):
            baseline_optimizer = FirstOrderOptimizer(
                max_epochs=3,
                batch_size=512,
                num_slices=args.num_slices,
                verbose=True)
        else:
            logger.log("Inappropirate value function")
            exit(0)
        '''
      baseline = GaussianMLPBaseline(
          env.spec,
          num_slices=args.num_slices,
          regressor_args=dict(
              step_size=0.01,
              mean_network=value_network,
              optimizer=baseline_optimizer,
              subsample_factor=1.0,
              batchsize=baseline_batch_size,
              use_trust_region=False
          )
      )
      '''
        baseline = DeterministicMLPBaseline(env.spec,
                                            num_slices=args.num_slices,
                                            regressor_args=dict(
                                                network=value_network,
                                                optimizer=baseline_optimizer,
                                                normalize_inputs=False))

    algo = TRPO(env=env,
                policy=policy,
                baseline=baseline,
                batch_size=args.batch_size,
                max_path_length=4500,
                n_itr=args.n_itr,
                discount=args.discount_factor,
                step_size=args.step_size,
                clip_reward=(not args.reward_no_scale),
                optimizer_args={
                    "subsample_factor": 1.0,
                    "num_slices": args.num_slices
                }
                #       plot=True
                )

    config = tf.ConfigProto(allow_soft_placement=True,
                            intra_op_parallelism_threads=args.n_cpu,
                            inter_op_parallelism_threads=args.n_cpu)
    config.gpu_options.allow_growth = True  # pylint: disable=E1101
    sess = tf.Session(config=config)
    sess.__enter__()
    algo.train(sess)
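
main() above reads a module-level args namespace that is not defined in the example. A minimal argparse sketch covering only the flags the example actually uses (flag names inferred from usage; every default here is an assumption):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--env', type=str, default='PongNoFrameskip-v4')  # assumed default
parser.add_argument('--resize_size', type=int, default=84)            # assumed default
parser.add_argument('--atari_noop', action='store_true')
parser.add_argument('--atari_eplife', action='store_true')
parser.add_argument('--atari_firereset', action='store_true')
parser.add_argument('--value_function', choices=['zero', 'conj', 'adam'], default='adam')
parser.add_argument('--batch_size', type=int, default=100000)         # assumed default
parser.add_argument('--num_slices', type=int, default=1)
parser.add_argument('--n_itr', type=int, default=500)                 # assumed default
parser.add_argument('--discount_factor', type=float, default=0.99)
parser.add_argument('--step_size', type=float, default=0.01)
parser.add_argument('--reward_no_scale', action='store_true')
parser.add_argument('--n_cpu', type=int, default=4)                   # assumed default
args = parser.parse_args()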
Exemplo n.º 26
0
    def __init__(
            self,
            env,
            qf,
            es,
            policy=None,
            policy_batch_size=32,
            n_epochs=200,
            epoch_length=1000,
            discount=0.99,
            max_path_length=250,
            policy_weight_decay=0,
            policy_update_method='adam',
            policy_learning_rate=1e-3,
            policy_step_size=0.01,
            policy_optimizer_args=dict(),
            policy_updates_ratio=1.0,
            policy_use_target=True,
            policy_sample_last=False,
            eval_samples=10000,
            updates_ratio=1.0, # #updates/#samples
            scale_reward=1.0,
            include_horizon_terminal_transitions=False,
            save_freq=0,
            save_format='pickle',
            restore_auto=True,
            **kwargs):

        self.env = env
        self.policy = policy
        # DQN-style mode: when no explicit policy is given, act greedily with the Q function.
        self.qf_dqn = self.policy is None
        self.qf = qf
        self.es = es
        if self.es is None:
            self.es = ExplorationStrategy()
        self.n_epochs = n_epochs
        self.epoch_length = epoch_length
        self.discount = discount
        self.max_path_length = max_path_length

        self.init_critic(**kwargs)

        if not self.qf_dqn:
            self.policy_weight_decay = policy_weight_decay
            if policy_update_method == 'adam':
                self.policy_update_method = \
                    FirstOrderOptimizer(
                        update_method=policy_update_method,
                        learning_rate=policy_learning_rate,
                        **policy_optimizer_args,
                    )
                self.policy_learning_rate = policy_learning_rate
            elif policy_update_method == 'cg':
                self.policy_update_method = \
                    ConjugateGradientOptimizer(
                        **policy_optimizer_args,
                    )
                self.policy_step_size = policy_step_size
            self.policy_optimizer_args = policy_optimizer_args
            self.policy_updates_ratio = policy_updates_ratio
            self.policy_use_target = policy_use_target
            self.policy_batch_size = policy_batch_size
            self.policy_sample_last = policy_sample_last
            self.policy_surr_averages = []
            self.exec_policy = self.policy
        else:
            self.policy_batch_size = 0
            self.exec_policy = self.qf

        self.eval_samples = eval_samples
        self.updates_ratio = updates_ratio
        self.include_horizon_terminal_transitions = include_horizon_terminal_transitions

        self.paths = []
        self.es_path_returns = []
        self.paths_samples_cnt = 0

        self.scale_reward = scale_reward

        self.train_policy_itr = 0

        self.save_freq = save_freq
        self.save_format = save_format
        self.restore_auto = restore_auto


class Persistence_Length_Exploration(RLAlgorithm):
    """
    PolyRL Exploration Strategy
    """
    def __init__(self,
                 env,
                 policy,
                 qf,
                 L_p=0.08,
                 b_step_size=0.0004,
                 sigma=0.1,
                 max_exploratory_steps=20,
                 batch_size=32,
                 n_epochs=200,
                 epoch_length=1000,
                 min_pool_size=10000,
                 replay_pool_size=1000000,
                 replacement_prob=1.0,
                 discount=0.99,
                 max_path_length=250,
                 qf_weight_decay=0.,
                 qf_update_method='adam',
                 qf_learning_rate=1e-3,
                 policy_weight_decay=0,
                 policy_update_method='adam',
                 policy_learning_rate=1e-3,
                 policy_updates_ratio=1.0,
                 eval_samples=10000,
                 soft_target=True,
                 soft_target_tau=0.001,
                 n_updates_per_sample=1,
                 scale_reward=1.0,
                 include_horizon_terminal_transitions=False,
                 plot=False,
                 pause_for_plot=False):
        """
        :param env: Environment
        :param policy: Policy
        :param qf: Q function
        :param es: Exploration strategy
        :param batch_size: Number of samples for each minibatch.
        :param n_epochs: Number of epochs. Policy will be evaluated after each epoch.
        :param epoch_length: How many timesteps for each epoch.
        :param min_pool_size: Minimum size of the pool to start training.
        :param replay_pool_size: Size of the experience replay pool.
        :param discount: Discount factor for the cumulative return.
        :param max_path_length: Maximum length of a single rollout before it is cut off.
        :param qf_weight_decay: Weight decay factor for parameters of the Q function.
        :param qf_update_method: Online optimization method for training Q function.
        :param qf_learning_rate: Learning rate for training Q function.
        :param policy_weight_decay: Weight decay factor for parameters of the policy.
        :param policy_update_method: Online optimization method for training the policy.
        :param policy_learning_rate: Learning rate for training the policy.
        :param eval_samples: Number of samples (timesteps) for evaluating the policy.
        :param soft_target_tau: Interpolation parameter for doing the soft target update.
        :param n_updates_per_sample: Number of Q function and policy updates per new sample obtained
        :param scale_reward: The scaling factor applied to the rewards when training
        :param include_horizon_terminal_transitions: whether to include transitions with terminal=True because the
        horizon was reached. This might make the Q value back up less stable for certain tasks.
        :param plot: Whether to visualize the policy performance after each eval_interval.
        :param pause_for_plot: Whether to pause before continuing when plotting.
        :return:
        """
        self.env = env
        self.policy = policy
        self.qf = qf
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.epoch_length = epoch_length
        self.min_pool_size = min_pool_size
        self.replay_pool_size = replay_pool_size
        self.replacement_prob = replacement_prob
        self.discount = discount
        self.max_path_length = max_path_length
        self.qf_weight_decay = qf_weight_decay
        self.qf_update_method = \
            FirstOrderOptimizer(
                update_method=qf_update_method,
                learning_rate=qf_learning_rate,
            )
        self.qf_learning_rate = qf_learning_rate
        self.policy_weight_decay = policy_weight_decay
        self.policy_update_method = \
            FirstOrderOptimizer(
                update_method=policy_update_method,
                learning_rate=policy_learning_rate,
            )
        self.policy_learning_rate = policy_learning_rate
        self.policy_updates_ratio = policy_updates_ratio
        self.eval_samples = eval_samples
        self.soft_target_tau = soft_target_tau
        self.n_updates_per_sample = n_updates_per_sample
        self.include_horizon_terminal_transitions = include_horizon_terminal_transitions
        self.plot = plot
        self.pause_for_plot = pause_for_plot

        self.qf_loss_averages = []
        self.policy_surr_averages = []
        self.q_averages = []
        self.y_averages = []
        self.paths = []
        self.es_path_returns = []
        self.paths_samples_cnt = 0
        """
        PolyRL Hyperparameters
        """
        self.b_step_size = b_step_size
        self.L_p = L_p
        self.sigma = sigma
        self.max_exploratory_steps = max_exploratory_steps

        self.scale_reward = scale_reward
        self.train_policy_itr = 0

        self.opt_info = None

    def start_worker(self):
        parallel_sampler.populate_task(self.env, self.policy)
        if self.plot:
            plotter.init_plot(self.env, self.policy)

    @overrides
    def lp_exploration(self):
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            # This seems like a rather sequential method
            pool = SimpleReplayPool(
                max_pool_size=self.replay_pool_size,
                observation_dim=self.env.observation_space.flat_dim,
                action_dim=self.env.action_space.flat_dim,
                replacement_prob=self.replacement_prob,
            )
            self.start_worker()

            with tf.variable_scope("sample_policy", reuse=True):
                sample_policy = Serializable.clone(self.policy)

            self.init_opt()

            # This initializes the optimizer parameters
            sess.run(tf.global_variables_initializer())

            itr = 0
            path_length = 0
            path_return = 0
            terminal = False
            initial = False
            observation = self.env.reset()

            self.initial_action = self.env.action_space.sample()

            chain_actions = np.array([self.initial_action])
            chain_states = np.array([observation])

            action_trajectory_chain = 0
            state_trajectory_chain = 0

            end_traj_action = 0
            end_traj_state = 0

            H_vector = np.random.uniform(
                low=self.env.action_space.low,
                high=self.env.action_space.high,
                size=(self.env.action_space.shape[0], ))
            H = (self.b_step_size / LA.norm(H_vector)) * H_vector

            all_H = np.array([H])
            all_theta = np.array([])

            last_action_chosen = self.initial_action

            for epoch in range(self.max_exploratory_steps):

                print("LP Exploration Episode", epoch)
                print("Replay Buffer Sample Size", pool.size)

                # logger.push_prefix('epoch #%d | ' % epoch)
                # logger.log("Training started")
                train_qf_itr, train_policy_itr = 0, 0

                if epoch == 0:
                    next_action = last_action_chosen + H
                else:
                    next_action = last_action_chosen

                for epoch_itr in pyprind.prog_bar(range(self.epoch_length)):

                    if self.env.action_space.shape[0] == 6:
                        one_vector = self.one_vector_6D()

                    elif self.env.action_space.shape[0] == 21:
                        one_vector = self.one_vector_21D()

                    elif self.env.action_space.shape[0] == 3:
                        one_vector = self.one_vector_3D()

                    elif self.env.action_space.shape[0] == 10:
                        one_vector = self.one_vector_10D()

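                    # Mean turning angle implied by the persistence length:
                    # cos(theta_mean) = exp(-b_step_size / L_p), spread over the angular
                    # components by the random +/-1 one_vector.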
                    theta_mean = np.arccos(
                        np.exp(np.true_divide(-self.b_step_size,
                                              self.L_p))) * one_vector
                    sigma_iden = self.sigma**2 * np.identity(
                        self.env.action_space.shape[0] - 1)

                    eta = np.random.multivariate_normal(theta_mean, sigma_iden)
                    eta = np.concatenate((np.array([0]), eta), axis=0)
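                    # The leading zero leaves the radial component (step size) untouched;
                    # eta perturbs only the angles of H in spherical coordinates.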
                    """
                    Map H_t to Spherical coordinate
                    """
                    if self.env.action_space.shape[0] == 3:
                        H_conversion = self.cart2pol_3D(H)
                    elif self.env.action_space.shape[0] == 6:
                        H_conversion = self.cart2pol_6D(H)
                    elif self.env.action_space.shape[0] == 10:
                        H_conversion = self.cart2pol_10D(H)
                    elif self.env.action_space.shape[0] == 21:
                        H_conversion = self.cart2pol_21D(H)

                    H = H_conversion + eta
                    """
                    Map H_t to Cartesian coordinate
                    """
                    if self.env.action_space.shape[0] == 3:
                        H_conversion = self.pol2cart_3D(H)
                    elif self.env.action_space.shape[0] == 6:
                        H_conversion = self.pol2cart_6D(H)
                    elif self.env.action_space.shape[0] == 10:
                        H_conversion = self.pol2cart_10D(H)
                    elif self.env.action_space.shape[0] == 21:
                        H_conversion = self.pol2cart_21D(H)

                    H = H_conversion

                    phi_t = next_action
                    phi_t_1 = phi_t + H

                    chosen_action = np.array([phi_t_1])
                    chain_actions = np.append(chain_actions,
                                              chosen_action,
                                              axis=0)

                    chosen_action = chosen_action[0, :]

                    # Execute policy
                    if terminal:  # or path_length > self.max_path_length:
                        # Note that if the last time step ends an episode, the very
                        # last state and observation will be ignored and not added
                        # to the replay pool
                        observation = self.env.reset()
                        # self.es.reset()
                        sample_policy.reset()
                        self.es_path_returns.append(path_return)
                        path_length = 0
                        path_return = 0
                        initial = True
                    else:
                        initial = False

                    chosen_state, reward, terminal, _ = self.env.step(
                        chosen_action)

                    chain_states = np.append(chain_states,
                                             np.array([chosen_state]),
                                             axis=0)

                    action = chosen_action
                    state = chosen_state
                    end_traj_state = chosen_state
                    end_traj_action = chosen_action

                    #updates to be used in next iteration
                    H = phi_t_1 - phi_t
                    all_H = np.append(all_H, np.array([H]), axis=0)
                    next_action = phi_t_1

                    path_length += 1
                    path_return += reward

                    if not terminal and path_length >= self.max_path_length:
                        terminal = True

                        #originally, it was only line above
                        #added these below
                        terminal_state = chosen_state
                        last_action_chosen = self.env.action_space.sample()
                        H_vector = np.random.uniform(
                            low=self.env.action_space.low,
                            high=self.env.action_space.high,
                            size=(self.env.action_space.shape[0], ))
                        H = (self.b_step_size / LA.norm(H_vector)) * H_vector
                        next_action = last_action_chosen + H

                        # path_length = 0
                        # path_return = 0
                        # state = self.env.reset()
                        # sample_policy.reset()

                        # only include the terminal transition in this case if the flag was set
                        if self.include_horizon_terminal_transitions:
                            pool.add_sample(observation, action,
                                            reward * self.scale_reward,
                                            terminal, initial)
                    else:
                        pool.add_sample(observation, action,
                                        reward * self.scale_reward, terminal,
                                        initial)

                    observation = state

                    if pool.size >= self.min_pool_size:
                        for update_itr in range(self.n_updates_per_sample):
                            # Train policy
                            batch = pool.random_batch(self.batch_size)
                            itrs = self.do_training(itr, batch)

                            train_qf_itr += itrs[0]
                            train_policy_itr += itrs[1]
                        sample_policy.set_param_values(
                            self.policy.get_param_values())

                    itr += 1

                last_action_chosen = action
                #last_action_chosen = last_action_chosen[0, :]

            action_trajectory_chain = chain_actions
            state_trajectory_chain = chain_states
            end_trajectory_action = end_traj_action
            end_trajectory_state = end_traj_state

            self.env.terminate()
            self.policy.terminate()

            return self.qf, self.policy, action_trajectory_chain, state_trajectory_chain, end_trajectory_action, end_trajectory_state

    def init_opt(self):

        # First, create "target" policy and Q functions
        with tf.variable_scope("target_policy", reuse=True):
            target_policy = Serializable.clone(self.policy)
        with tf.variable_scope("target_qf", reuse=True):
            target_qf = Serializable.clone(self.qf)

        # target_policy = pickle.loads(pickle.dumps(self.policy))
        # target_qf = pickle.loads(pickle.dumps(self.qf))

        # y need to be computed first
        obs = self.env.observation_space.new_tensor_variable(
            'obs',
            extra_dims=1,
        )

        # The yi values are computed separately as above and then passed to
        # the training functions below
        action = self.env.action_space.new_tensor_variable(
            'action',
            extra_dims=1,
        )

        yvar = tensor_utils.new_tensor(
            'ys',
            ndim=1,
            dtype=tf.float32,
        )

        qf_weight_decay_term = 0.5 * self.qf_weight_decay * \
                               sum([tf.reduce_sum(tf.square(param)) for param in
                                    self.qf.get_params(regularizable=True)])

        qval = self.qf.get_qval_sym(obs, action)

        qf_loss = tf.reduce_mean(tf.square(yvar - qval))
        qf_reg_loss = qf_loss + qf_weight_decay_term

        policy_weight_decay_term = 0.5 * self.policy_weight_decay * \
                                   sum([tf.reduce_sum(tf.square(param))
                                        for param in self.policy.get_params(regularizable=True)])
        policy_qval = self.qf.get_qval_sym(obs,
                                           self.policy.get_action_sym(obs),
                                           deterministic=True)
        policy_surr = -tf.reduce_mean(policy_qval)

        policy_reg_surr = policy_surr + policy_weight_decay_term

        qf_input_list = [yvar, obs, action]
        policy_input_list = [obs]

        self.qf_update_method.update_opt(loss=qf_reg_loss,
                                         target=self.qf,
                                         inputs=qf_input_list)
        self.policy_update_method.update_opt(loss=policy_reg_surr,
                                             target=self.policy,
                                             inputs=policy_input_list)

        f_train_qf = tensor_utils.compile_function(
            inputs=qf_input_list,
            outputs=[qf_loss, qval, self.qf_update_method._train_op],
        )

        f_train_policy = tensor_utils.compile_function(
            inputs=policy_input_list,
            outputs=[policy_surr, self.policy_update_method._train_op],
        )

        self.opt_info = dict(
            f_train_qf=f_train_qf,
            f_train_policy=f_train_policy,
            target_qf=target_qf,
            target_policy=target_policy,
        )

    def do_training(self, itr, batch):

        obs, actions, rewards, next_obs, terminals = ext.extract(
            batch, "observations", "actions", "rewards", "next_observations",
            "terminals")

        # compute the on-policy y values
        target_qf = self.opt_info["target_qf"]
        target_policy = self.opt_info["target_policy"]

        next_actions, _ = target_policy.get_actions(next_obs)
        next_qvals = target_qf.get_qval(next_obs, next_actions)

        ys = rewards + (1. -
                        terminals) * self.discount * next_qvals.reshape(-1)

        f_train_qf = self.opt_info["f_train_qf"]
        qf_loss, qval, _ = f_train_qf(ys, obs, actions)
        target_qf.set_param_values(target_qf.get_param_values() *
                                   (1.0 - self.soft_target_tau) +
                                   self.qf.get_param_values() *
                                   self.soft_target_tau)
        self.qf_loss_averages.append(qf_loss)
        self.q_averages.append(qval)
        self.y_averages.append(ys)

        self.train_policy_itr += self.policy_updates_ratio
        train_policy_itr = 0

        while self.train_policy_itr > 0:
            f_train_policy = self.opt_info["f_train_policy"]
            policy_surr, _ = f_train_policy(obs)
            target_policy.set_param_values(target_policy.get_param_values() *
                                           (1.0 - self.soft_target_tau) +
                                           self.policy.get_param_values() *
                                           self.soft_target_tau)
            self.policy_surr_averages.append(policy_surr)
            self.train_policy_itr -= 1
            train_policy_itr += 1

        return 1, train_policy_itr

    def evaluate(self, epoch, pool):
        logger.log("Collecting samples for evaluation")
        paths = parallel_sampler.sample_paths(
            policy_params=self.policy.get_param_values(),
            max_samples=self.eval_samples,
            max_path_length=self.max_path_length,
        )

        average_discounted_return = np.mean([
            special.discount_return(path["rewards"], self.discount)
            for path in paths
        ])

        returns = [sum(path["rewards"]) for path in paths]

        all_qs = np.concatenate(self.q_averages)
        all_ys = np.concatenate(self.y_averages)

        average_q_loss = np.mean(self.qf_loss_averages)
        average_policy_surr = np.mean(self.policy_surr_averages)
        average_action = np.mean(
            np.square(np.concatenate([path["actions"] for path in paths])))

        policy_reg_param_norm = np.linalg.norm(
            self.policy.get_param_values(regularizable=True))
        qfun_reg_param_norm = np.linalg.norm(
            self.qf.get_param_values(regularizable=True))

        logger.record_tabular('Epoch', epoch)
        logger.record_tabular('Iteration', epoch)
        logger.record_tabular('AverageReturn', np.mean(returns))
        logger.record_tabular('StdReturn', np.std(returns))
        logger.record_tabular('MaxReturn', np.max(returns))
        logger.record_tabular('MinReturn', np.min(returns))
        if len(self.es_path_returns) > 0:
            logger.record_tabular('AverageEsReturn',
                                  np.mean(self.es_path_returns))
            logger.record_tabular('StdEsReturn', np.std(self.es_path_returns))
            logger.record_tabular('MaxEsReturn', np.max(self.es_path_returns))
            logger.record_tabular('MinEsReturn', np.min(self.es_path_returns))
        logger.record_tabular('AverageDiscountedReturn',
                              average_discounted_return)
        logger.record_tabular('AverageQLoss', average_q_loss)
        logger.record_tabular('AveragePolicySurr', average_policy_surr)
        logger.record_tabular('AverageQ', np.mean(all_qs))
        logger.record_tabular('AverageAbsQ', np.mean(np.abs(all_qs)))
        logger.record_tabular('AverageY', np.mean(all_ys))
        logger.record_tabular('AverageAbsY', np.mean(np.abs(all_ys)))
        logger.record_tabular('AverageAbsQYDiff',
                              np.mean(np.abs(all_qs - all_ys)))
        logger.record_tabular('AverageAction', average_action)

        logger.record_tabular('PolicyRegParamNorm', policy_reg_param_norm)
        logger.record_tabular('QFunRegParamNorm', qfun_reg_param_norm)

        self.env.log_diagnostics(paths)
        self.policy.log_diagnostics(paths)

        self.qf_loss_averages = []
        self.policy_surr_averages = []

        self.q_averages = []
        self.y_averages = []
        self.es_path_returns = []

    def update_plot(self):
        if self.plot:
            plotter.update_plot(self.policy, self.max_path_length)

    def get_epoch_snapshot(self, epoch):
        return dict(
            env=self.env,
            epoch=epoch,
            qf=self.qf,
            policy=self.policy,
            target_qf=self.opt_info["target_qf"],
            target_policy=self.opt_info["target_policy"],
            es=self.es,
        )

    def one_vector_3D(self):
        one_vec_x1 = random.randint(-1, 0)
        if one_vec_x1 == 0:
            one_vec_x1 = 1

        one_vec_x2 = random.randint(-1, 0)
        if one_vec_x2 == 0:
            one_vec_x2 = 1

        one_vector = np.array([one_vec_x1, one_vec_x2])

        return one_vector

    def one_vector_6D(self):
        one_vec_x1 = random.randint(-1, 0)
        if one_vec_x1 == 0:
            one_vec_x1 = 1

        one_vec_x2 = random.randint(-1, 0)
        if one_vec_x2 == 0:
            one_vec_x2 = 1

        one_vec_x3 = random.randint(-1, 0)
        if one_vec_x3 == 0:
            one_vec_x3 = 1

        one_vec_x4 = random.randint(-1, 0)
        if one_vec_x4 == 0:
            one_vec_x4 = 1

        one_vec_x5 = random.randint(-1, 0)
        if one_vec_x5 == 0:
            one_vec_x5 = 1

        one_vector = np.array(
            [one_vec_x1, one_vec_x2, one_vec_x3, one_vec_x4, one_vec_x5])

        return one_vector

    def one_vector_10D(self):
        one_vec_x1 = random.randint(-1, 0)
        if one_vec_x1 == 0:
            one_vec_x1 = 1

        one_vec_x2 = random.randint(-1, 0)
        if one_vec_x2 == 0:
            one_vec_x2 = 1

        one_vec_x3 = random.randint(-1, 0)
        if one_vec_x3 == 0:
            one_vec_x3 = 1

        one_vec_x4 = random.randint(-1, 0)
        if one_vec_x4 == 0:
            one_vec_x4 = 1

        one_vec_x5 = random.randint(-1, 0)
        if one_vec_x5 == 0:
            one_vec_x5 = 1

        one_vec_x6 = random.randint(-1, 0)
        if one_vec_x6 == 0:
            one_vec_x6 = 1

        one_vec_x7 = random.randint(-1, 0)
        if one_vec_x7 == 0:
            one_vec_x7 = 1

        one_vec_x8 = random.randint(-1, 0)
        if one_vec_x8 == 0:
            one_vec_x8 = 1

        one_vec_x9 = random.randint(-1, 0)
        if one_vec_x9 == 0:
            one_vec_x9 = 1

        # one_vec_x10 = random.randint(-1, 0)
        # if one_vec_x10 == 0:
        #     one_vec_x10 = 1

        one_vector = np.array([
            one_vec_x1, one_vec_x2, one_vec_x3, one_vec_x4, one_vec_x5,
            one_vec_x6, one_vec_x7, one_vec_x8, one_vec_x9
        ])

        return one_vector

    def one_vector_21D(self):
        one_vec_x1 = random.randint(-1, 0)
        if one_vec_x1 == 0:
            one_vec_x1 = 1

        one_vec_x2 = random.randint(-1, 0)
        if one_vec_x2 == 0:
            one_vec_x2 = 1

        one_vec_x3 = random.randint(-1, 0)
        if one_vec_x3 == 0:
            one_vec_x3 = 1

        one_vec_x4 = random.randint(-1, 0)
        if one_vec_x4 == 0:
            one_vec_x4 = 1

        one_vec_x5 = random.randint(-1, 0)
        if one_vec_x5 == 0:
            one_vec_x5 = 1

        one_vec_x6 = random.randint(-1, 0)
        if one_vec_x6 == 0:
            one_vec_x6 = 1

        one_vec_x7 = random.randint(-1, 0)
        if one_vec_x7 == 0:
            one_vec_x7 = 1

        one_vec_x8 = random.randint(-1, 0)
        if one_vec_x8 == 0:
            one_vec_x8 = 1

        one_vec_x9 = random.randint(-1, 0)
        if one_vec_x9 == 0:
            one_vec_x9 = 1

        one_vec_x10 = random.randint(-1, 0)
        if one_vec_x10 == 0:
            one_vec_x10 = 1

        one_vec_x11 = random.randint(-1, 0)
        if one_vec_x11 == 0:
            one_vec_x11 = 1

        one_vec_x12 = random.randint(-1, 0)
        if one_vec_x12 == 0:
            one_vec_x12 = 1

        one_vec_x13 = random.randint(-1, 0)
        if one_vec_x13 == 0:
            one_vec_x13 = 1

        one_vec_x14 = random.randint(-1, 0)
        if one_vec_x14 == 0:
            one_vec_x14 = 1

        one_vec_x15 = random.randint(-1, 0)
        if one_vec_x15 == 0:
            one_vec_x15 = 1

        one_vec_x16 = random.randint(-1, 0)
        if one_vec_x16 == 0:
            one_vec_x16 = 1

        one_vec_x17 = random.randint(-1, 0)
        if one_vec_x17 == 0:
            one_vec_x17 = 1

        one_vec_x18 = random.randint(-1, 0)
        if one_vec_x18 == 0:
            one_vec_x18 = 1

        one_vec_x19 = random.randint(-1, 0)
        if one_vec_x19 == 0:
            one_vec_x19 = 1

        one_vec_x20 = random.randint(-1, 0)
        if one_vec_x20 == 0:
            one_vec_x20 = 1

        one_vector = np.array([
            one_vec_x1, one_vec_x2, one_vec_x3, one_vec_x4, one_vec_x5,
            one_vec_x6, one_vec_x7, one_vec_x8, one_vec_x9, one_vec_x10,
            one_vec_x11, one_vec_x12, one_vec_x13, one_vec_x14, one_vec_x15,
            one_vec_x16, one_vec_x17, one_vec_x18, one_vec_x19, one_vec_x20
        ])

        return one_vector
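
    # Note: the one_vector_*D helpers above all draw independent random +/-1 signs for the
    # (action_dim - 1) angular components; an equivalent general-n form (an assumption, not
    # part of the original code) would be np.random.choice([-1, 1], size=action_dim - 1).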

    def cart2pol_6D(self, cartesian):

        x_1 = cartesian[0]
        x_2 = cartesian[1]
        x_3 = cartesian[2]
        x_4 = cartesian[3]
        x_5 = cartesian[4]
        x_6 = cartesian[5]

        modulus = x_1**2 + x_2**2 + x_3**2 + x_4**2 + x_5**2 + x_6**2

        radius = np.sqrt(modulus)
        phi_1 = np.arccos(x_1 / radius)
        phi_2 = np.arccos(x_2 / radius)
        phi_3 = np.arccos(x_3 / radius)

        phi_4 = np.arccos(x_4 / (np.sqrt(x_4**2 + x_5**2 + x_6**2)))

        if x_6 >= 0:
            phi_5 = np.arccos(x_5 / (np.sqrt(x_5**2 + x_6**2)))
        else:
            phi_5 = (2 * np.pi) - np.arccos(x_5 / (np.sqrt(x_5**2 + x_6**2)))

        spherical = np.array([radius, phi_1, phi_2, phi_3, phi_4, phi_5])

        return spherical

    def cart2pol_3D(self, cartesian):

        x_1 = cartesian[0]
        x_2 = cartesian[1]
        x_3 = cartesian[2]

        modulus = x_1**2 + x_2**2 + x_3**2

        radius = np.sqrt(modulus)
        phi_1 = np.arccos(x_1 / radius)
        phi_2 = np.arccos(x_2 / radius)

        spherical = np.array([radius, phi_1, phi_2])

        return spherical

    def cart2pol_10D(self, cartesian):

        x_1 = cartesian[0]
        x_2 = cartesian[1]
        x_3 = cartesian[2]
        x_4 = cartesian[3]
        x_5 = cartesian[4]
        x_6 = cartesian[5]
        x_7 = cartesian[6]
        x_8 = cartesian[7]
        x_9 = cartesian[8]
        x_10 = cartesian[9]

        modulus = x_1**2 + x_2**2 + x_3**2 + x_4**2 + x_5**2 + x_6**2 + x_7**2 + x_8**2 + x_9**2 + x_10**2

        radius = np.sqrt(modulus)
        phi_1 = np.arccos(x_1 / radius)
        phi_2 = np.arccos(x_2 / radius)
        phi_3 = np.arccos(x_3 / radius)
        phi_4 = np.arccos(x_4 / radius)
        phi_5 = np.arccos(x_5 / radius)
        phi_6 = np.arccos(x_6 / radius)
        phi_7 = np.arccos(x_7 / radius)

        phi_8 = np.arccos(x_8 / (np.sqrt(x_10**2 + x_9**2 + x_8**2)))

        if x_10 >= 0:
            phi_9 = np.arccos(x_9 / (np.sqrt(x_10**2 + x_9**2)))
        else:
            phi_9 = (2 * np.pi) - np.arccos(x_9 / (np.sqrt(x_10**2 + x_9**2)))

        spherical = np.array([
            radius, phi_1, phi_2, phi_3, phi_4, phi_5, phi_6, phi_7, phi_8,
            phi_9
        ])

        return spherical

    def cart2pol_21D(self, cartesian):

        x_1 = cartesian[0]
        x_2 = cartesian[1]
        x_3 = cartesian[2]
        x_4 = cartesian[3]
        x_5 = cartesian[4]
        x_6 = cartesian[5]
        x_7 = cartesian[6]
        x_8 = cartesian[7]
        x_9 = cartesian[8]
        x_10 = cartesian[9]
        x_11 = cartesian[10]
        x_12 = cartesian[11]
        x_13 = cartesian[12]
        x_14 = cartesian[13]
        x_15 = cartesian[14]
        x_16 = cartesian[15]
        x_17 = cartesian[16]
        x_18 = cartesian[17]
        x_19 = cartesian[18]
        x_20 = cartesian[19]
        x_21 = cartesian[20]

        modulus = x_1**2 + x_2**2 + x_3**2 + x_4**2 + x_5**2 + x_6**2 + x_7**2 + x_8**2 + x_9**2 + x_10**2 + x_11**2 + x_12**2 + x_13**2 + x_14**2 + x_15**2 + x_16**2 + x_17**2 + x_18**2 + x_19**2 + x_20**2 + x_21**2

        radius = np.sqrt(modulus)
        phi_1 = np.arccos(x_1 / radius)
        phi_2 = np.arccos(x_2 / radius)
        phi_3 = np.arccos(x_3 / radius)
        phi_4 = np.arccos(x_4 / radius)
        phi_5 = np.arccos(x_5 / radius)
        phi_6 = np.arccos(x_6 / radius)
        phi_7 = np.arccos(x_7 / radius)
        phi_8 = np.arccos(x_8 / radius)
        phi_9 = np.arccos(x_9 / radius)
        phi_10 = np.arccos(x_10 / radius)
        phi_11 = np.arccos(x_11 / radius)
        phi_12 = np.arccos(x_12 / radius)
        phi_13 = np.arccos(x_13 / radius)
        phi_14 = np.arccos(x_14 / radius)
        phi_15 = np.arccos(x_15 / radius)
        phi_16 = np.arccos(x_16 / radius)
        phi_17 = np.arccos(x_17 / radius)
        phi_18 = np.arccos(x_18 / radius)

        phi_19 = np.arccos(x_19 / (np.sqrt(x_21**2 + x_20**2 + x_19**2)))

        if x_21 >= 0:
            phi_20 = np.arccos(x_20 / (np.sqrt(x_21**2 + x_20**2)))
        else:
            phi_20 = (2 * np.pi) - np.arccos(x_20 /
                                             (np.sqrt(x_21**2 + x_20**2)))

        spherical = np.array([
            radius, phi_1, phi_2, phi_3, phi_4, phi_5, phi_6, phi_7, phi_8,
            phi_9, phi_10, phi_11, phi_12, phi_13, phi_14, phi_15, phi_16,
            phi_17, phi_18, phi_19, phi_20
        ])

        return spherical

    def pol2cart_6D(self, polar):

        radius = polar[0]
        phi_1 = polar[1]
        phi_2 = polar[2]
        phi_3 = polar[3]
        phi_4 = polar[4]
        phi_5 = polar[5]

        x_1 = radius * np.cos(phi_1)
        x_2 = radius * np.sin(phi_1) * np.cos(phi_2)
        x_3 = radius * np.sin(phi_1) * np.sin(phi_2) * np.cos(phi_3)
        x_4 = radius * np.sin(phi_1) * np.sin(phi_2) * np.sin(phi_3) * np.cos(
            phi_4)
        x_5 = radius * np.sin(phi_1) * np.sin(phi_2) * np.sin(phi_3) * np.sin(
            phi_4) * np.cos(phi_5)
        x_6 = radius * np.sin(phi_1) * np.sin(phi_2) * np.sin(phi_3) * np.sin(
            phi_4) * np.sin(phi_5)

        cartesian = np.array([x_1, x_2, x_3, x_4, x_5, x_6])

        return cartesian

    def pol2cart_3D(self, polar):

        radius = polar[0]
        phi_1 = polar[1]
        phi_2 = polar[2]

        x_1 = radius * np.cos(phi_1)
        x_2 = radius * np.sin(phi_1) * np.cos(phi_2)
        x_3 = radius * np.sin(phi_1) * np.sin(phi_2)

        cartesian = np.array([x_1, x_2, x_3])

        return cartesian

    def pol2cart_10D(self, polar):

        radius = polar[0]
        phi_1 = polar[1]
        phi_2 = polar[2]
        phi_3 = polar[3]
        phi_4 = polar[4]
        phi_5 = polar[5]
        phi_6 = polar[6]
        phi_7 = polar[7]
        phi_8 = polar[8]
        phi_9 = polar[9]

        x_1 = radius * np.cos(phi_1)
        x_2 = radius * np.sin(phi_1) * np.cos(phi_2)
        x_3 = radius * np.sin(phi_1) * np.sin(phi_2) * np.cos(phi_3)
        x_4 = radius * np.sin(phi_1) * np.sin(phi_2) * np.sin(phi_3) * np.cos(
            phi_4)
        x_5 = radius * np.sin(phi_1) * np.sin(phi_2) * np.sin(phi_3) * np.sin(
            phi_4) * np.cos(phi_5)
        x_6 = radius * np.sin(phi_1) * np.sin(phi_2) * np.sin(phi_3) * np.sin(
            phi_4) * np.sin(phi_5) * np.cos(phi_6)
        x_7 = radius * np.sin(phi_1) * np.sin(phi_2) * np.sin(phi_3) * np.sin(
            phi_4) * np.sin(phi_5) * np.sin(phi_6) * np.cos(phi_7)
        x_8 = radius * np.sin(phi_1) * np.sin(phi_2) * np.sin(phi_3) * np.sin(
            phi_4) * np.sin(phi_5) * np.sin(phi_6) * np.sin(phi_7) * np.cos(
                phi_8)
        x_9 = radius * np.sin(phi_1) * np.sin(phi_2) * np.sin(phi_3) * np.sin(
            phi_4) * np.sin(phi_5) * np.sin(phi_6) * np.sin(phi_7) * np.sin(
                phi_8) * np.cos(phi_9)
        x_10 = radius * np.sin(phi_1) * np.sin(phi_2) * np.sin(phi_3) * np.sin(
            phi_4) * np.sin(phi_5) * np.sin(phi_6) * np.sin(phi_7) * np.sin(
                phi_8) * np.sin(phi_9)

        cartesian = np.array(
            [x_1, x_2, x_3, x_4, x_5, x_6, x_7, x_8, x_9, x_10])

        return cartesian

    def pol2cart_21D(self, polar):

        radius = polar[0]
        phi_1 = polar[1]
        phi_2 = polar[2]
        phi_3 = polar[3]
        phi_4 = polar[4]
        phi_5 = polar[5]
        phi_6 = polar[6]
        phi_7 = polar[7]
        phi_8 = polar[8]
        phi_9 = polar[9]
        phi_10 = polar[10]
        phi_11 = polar[11]
        phi_12 = polar[12]
        phi_13 = polar[13]
        phi_14 = polar[14]
        phi_15 = polar[15]
        phi_16 = polar[16]
        phi_17 = polar[17]
        phi_18 = polar[18]
        phi_19 = polar[19]
        phi_20 = polar[20]

        x_1 = radius * np.cos(phi_1)
        x_2 = radius * np.sin(phi_1) * np.cos(phi_2)
        x_3 = radius * np.sin(phi_1) * np.sin(phi_2) * np.cos(phi_3)
        x_4 = radius * np.sin(phi_1) * np.sin(phi_2) * np.sin(phi_3) * np.cos(
            phi_4)
        x_5 = radius * np.sin(phi_1) * np.sin(phi_2) * np.sin(phi_3) * np.sin(
            phi_4) * np.cos(phi_5)
        x_6 = radius * np.sin(phi_1) * np.sin(phi_2) * np.sin(phi_3) * np.sin(
            phi_4) * np.sin(phi_5) * np.cos(phi_6)
        x_7 = radius * np.sin(phi_1) * np.sin(phi_2) * np.sin(phi_3) * np.sin(
            phi_4) * np.sin(phi_5) * np.sin(phi_6) * np.cos(phi_7)
        x_8 = radius * np.sin(phi_1) * np.sin(phi_2) * np.sin(phi_3) * np.sin(
            phi_4) * np.sin(phi_5) * np.sin(phi_6) * np.sin(phi_7) * np.cos(
                phi_8)
        x_9 = radius * np.sin(phi_1) * np.sin(phi_2) * np.sin(phi_3) * np.sin(
            phi_4) * np.sin(phi_5) * np.sin(phi_6) * np.sin(phi_7) * np.sin(
                phi_8) * np.cos(phi_9)
        x_10 = radius * np.sin(phi_1) * np.sin(phi_2) * np.sin(phi_3) * np.sin(
            phi_4) * np.sin(phi_5) * np.sin(phi_6) * np.sin(phi_7) * np.sin(
                phi_8) * np.sin(phi_9) * np.cos(phi_10)
        x_11 = radius * np.sin(phi_1) * np.sin(phi_2) * np.sin(phi_3) * np.sin(
            phi_4) * np.sin(phi_5) * np.sin(phi_6) * np.sin(phi_7) * np.sin(
                phi_8) * np.sin(phi_9) * np.sin(phi_10) * np.cos(phi_11)
        x_12 = radius * np.sin(phi_1) * np.sin(phi_2) * np.sin(phi_3) * np.sin(
            phi_4) * np.sin(phi_5) * np.sin(phi_6) * np.sin(phi_7) * np.sin(
                phi_8) * np.sin(phi_9) * np.sin(phi_10) * np.sin(
                    phi_11) * np.cos(phi_12)
        x_13 = radius * np.sin(phi_1) * np.sin(phi_2) * np.sin(phi_3) * np.sin(
            phi_4) * np.sin(phi_5) * np.sin(phi_6) * np.sin(phi_7) * np.sin(
                phi_8) * np.sin(phi_9) * np.sin(phi_10) * np.sin(
                    phi_11) * np.sin(phi_12) * np.cos(phi_13)
        x_14 = radius * np.sin(phi_1) * np.sin(phi_2) * np.sin(phi_3) * np.sin(
            phi_4) * np.sin(phi_5) * np.sin(phi_6) * np.sin(phi_7) * np.sin(
                phi_8) * np.sin(phi_9) * np.sin(phi_10) * np.sin(
                    phi_11) * np.sin(phi_12) * np.sin(phi_13) * np.cos(phi_14)
        x_15 = radius * np.sin(phi_1) * np.sin(phi_2) * np.sin(phi_3) * np.sin(
            phi_4) * np.sin(phi_5) * np.sin(phi_6) * np.sin(phi_7) * np.sin(
                phi_8) * np.sin(phi_9) * np.sin(phi_10) * np.sin(
                    phi_11) * np.sin(phi_12) * np.sin(phi_13) * np.sin(
                        phi_14) * np.cos(phi_15)
        x_16 = radius * np.sin(phi_1) * np.sin(phi_2) * np.sin(phi_3) * np.sin(
            phi_4) * np.sin(phi_5) * np.sin(phi_6) * np.sin(phi_7) * np.sin(
                phi_8) * np.sin(phi_9) * np.sin(phi_10) * np.sin(
                    phi_11) * np.sin(phi_12) * np.sin(phi_13) * np.sin(
                        phi_14) * np.sin(phi_15) * np.cos(phi_16)
        x_17 = radius * np.sin(phi_1) * np.sin(phi_2) * np.sin(phi_3) * np.sin(
            phi_4) * np.sin(phi_5) * np.sin(phi_6) * np.sin(phi_7) * np.sin(
                phi_8) * np.sin(phi_9) * np.sin(phi_10) * np.sin(
                    phi_11) * np.sin(phi_12) * np.sin(phi_13) * np.sin(
                        phi_14) * np.sin(phi_15) * np.sin(phi_16) * np.cos(
                            phi_17)
        x_18 = radius * np.sin(phi_1) * np.sin(phi_2) * np.sin(phi_3) * np.sin(
            phi_4) * np.sin(phi_5) * np.sin(phi_6) * np.sin(phi_7) * np.sin(
                phi_8) * np.sin(phi_9) * np.sin(phi_10) * np.sin(
                    phi_11) * np.sin(phi_12) * np.sin(phi_13) * np.sin(
                        phi_14) * np.sin(phi_15) * np.sin(phi_16) * np.sin(
                            phi_17) * np.cos(phi_18)
        x_19 = radius * np.sin(phi_1) * np.sin(phi_2) * np.sin(phi_3) * np.sin(
            phi_4) * np.sin(phi_5) * np.sin(phi_6) * np.sin(phi_7) * np.sin(
                phi_8) * np.sin(phi_9) * np.sin(phi_10) * np.sin(
                    phi_11) * np.sin(phi_12) * np.sin(phi_13) * np.sin(
                        phi_14) * np.sin(phi_15) * np.sin(phi_16) * np.sin(
                            phi_17) * np.sin(phi_18) * np.cos(phi_19)
        x_20 = radius * np.sin(phi_1) * np.sin(phi_2) * np.sin(phi_3) * np.sin(
            phi_4) * np.sin(phi_5) * np.sin(phi_6) * np.sin(phi_7) * np.sin(
                phi_8) * np.sin(phi_9) * np.sin(phi_10) * np.sin(
                    phi_11) * np.sin(phi_12) * np.sin(phi_13) * np.sin(
                        phi_14) * np.sin(phi_15) * np.sin(phi_16) * np.sin(
                            phi_17) * np.sin(phi_18) * np.sin(phi_19) * np.cos(
                                phi_20)
        x_21 = radius * np.sin(phi_1) * np.sin(phi_2) * np.sin(phi_3) * np.sin(
            phi_4) * np.sin(phi_5) * np.sin(phi_6) * np.sin(phi_7) * np.sin(
                phi_8) * np.sin(phi_9) * np.sin(phi_10) * np.sin(
                    phi_11) * np.sin(phi_12) * np.sin(phi_13) * np.sin(
                        phi_14) * np.sin(phi_15) * np.sin(phi_16) * np.sin(
                            phi_17) * np.sin(phi_18) * np.sin(phi_19) * np.sin(
                                phi_20)

        cartesian = np.array([
            x_1, x_2, x_3, x_4, x_5, x_6, x_7, x_8, x_9, x_10, x_11, x_12,
            x_13, x_14, x_15, x_16, x_17, x_18, x_19, x_20, x_21
        ])

        return cartesian
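
The pol2cart_3D/6D/10D/21D methods above all follow the same hyperspherical-to-Cartesian recursion: x_k = r * sin(phi_1) * ... * sin(phi_{k-1}) * cos(phi_k), with the last coordinate ending in a sine instead of a cosine. A general-n sketch of that pattern, offered as a standalone helper under that assumption rather than as part of the original class:

import numpy as np

def pol2cart_nd(polar):
    # polar = [radius, phi_1, ..., phi_{n-1}] -> Cartesian vector of length n.
    radius = polar[0]
    angles = np.asarray(polar[1:], dtype=float)
    cartesian = np.empty(len(angles) + 1)
    running = radius  # accumulates radius * sin(phi_1) * ... * sin(phi_{k-1})
    for k, phi in enumerate(angles):
        cartesian[k] = running * np.cos(phi)
        running *= np.sin(phi)
    cartesian[-1] = running  # last coordinate uses only sines
    return cartesian

For n = 3 this reproduces pol2cart_3D exactly. Note that the cart2pol_*D methods above divide the intermediate angles by the full radius, so this helper mirrors the pol2cart_*D pattern rather than inverting the cart2pol_*D ones.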