Example #1
class IOC(Off_Policy):
    '''
    Learning Options with Interest Functions, https://www.aaai.org/ojs/index.php/AAAI/article/view/5114/4987 
    Options of Interest: Temporal Abstraction with Interest Functions, http://arxiv.org/abs/2001.00271
    '''
    def __init__(
            self,
            envspec,
            q_lr=5.0e-3,
            intra_option_lr=5.0e-4,
            termination_lr=5.0e-4,
            interest_lr=5.0e-4,
            boltzmann_temperature=1.0,
            options_num=4,
            ent_coff=0.01,
            double_q=False,
            use_baseline=True,
            terminal_mask=True,
            termination_regularizer=0.01,
            assign_interval=1000,
            network_settings={
                'q': [32, 32],
                'intra_option': [32, 32],
                'termination': [32, 32],
                'interest': [32, 32]
            },
            **kwargs):
        super().__init__(envspec=envspec, **kwargs)
        self.assign_interval = assign_interval
        self.options_num = options_num
        self.termination_regularizer = termination_regularizer
        self.ent_coff = ent_coff
        self.use_baseline = use_baseline
        self.terminal_mask = terminal_mask
        self.double_q = double_q
        self.boltzmann_temperature = boltzmann_temperature

        def _create_net(name, representation_net=None):
            return ValueNetwork(
                name=name,
                representation_net=representation_net,
                value_net_type=OutputNetworkType.CRITIC_QVALUE_ALL,
                value_net_kwargs=dict(output_shape=self.options_num,
                                      network_settings=network_settings['q']))

        self.q_net = _create_net('q_net', self._representation_net)
        self._representation_target_net = self._create_representation_net(
            '_representation_target_net')
        self.q_target_net = _create_net('q_target_net',
                                        self._representation_target_net)

        self.intra_option_net = ValueNetwork(
            name='intra_option_net',
            value_net_type=OutputNetworkType.OC_INTRA_OPTION,
            value_net_kwargs=dict(
                vector_dim=self._representation_net.h_dim,
                output_shape=self.a_dim,
                options_num=self.options_num,
                network_settings=network_settings['intra_option']))
        self.termination_net = ValueNetwork(
            name='termination_net',
            value_net_type=OutputNetworkType.CRITIC_QVALUE_ALL,
            value_net_kwargs=dict(
                vector_dim=self._representation_net.h_dim,
                output_shape=self.options_num,
                network_settings=network_settings['termination'],
                out_activation='sigmoid'))
        self.interest_net = ValueNetwork(
            name='interest_net',
            value_net_type=OutputNetworkType.CRITIC_QVALUE_ALL,
            value_net_kwargs=dict(
                vector_dim=self._representation_net.h_dim,
                output_shape=self.options_num,
                network_settings=network_settings['interest'],
                out_activation='sigmoid'))

        self.actor_tv = self.intra_option_net.trainable_variables
        if self.is_continuous:
            self.log_std = tf.Variable(initial_value=-0.5 * np.ones(
                (self.options_num, self.a_dim), dtype=np.float32),
                                       trainable=True)  # [P, A]
            self.actor_tv += [self.log_std]
        update_target_net_weights(self.q_target_net.weights,
                                  self.q_net.weights)

        self.q_lr, self.intra_option_lr, self.termination_lr, self.interest_lr = map(
            self.init_lr, [q_lr, intra_option_lr, termination_lr, interest_lr])
        self.q_optimizer = self.init_optimizer(self.q_lr, clipvalue=5.)
        self.intra_option_optimizer = self.init_optimizer(self.intra_option_lr,
                                                          clipvalue=5.)
        self.termination_optimizer = self.init_optimizer(self.termination_lr,
                                                         clipvalue=5.)
        self.interest_optimizer = self.init_optimizer(self.interest_lr,
                                                      clipvalue=5.)

        self._worker_params_dict.update(self.q_net._policy_models)
        self._worker_params_dict.update(self.intra_option_net._policy_models)
        self._worker_params_dict.update(self.interest_net._policy_models)

        self._all_params_dict.update(self.q_net._all_models)
        self._all_params_dict.update(self.intra_option_net._all_models)
        self._all_params_dict.update(self.interest_net._all_models)
        self._all_params_dict.update(self.termination_net._all_models)
        self._all_params_dict.update(
            q_optimizer=self.q_optimizer,
            intra_option_optimizer=self.intra_option_optimizer,
            termination_optimizer=self.termination_optimizer,
            interest_optimizer=self.interest_optimizer)
        self._model_post_process()

    def _generate_random_options(self):
        return tf.constant(np.random.randint(0, self.options_num,
                                             self.n_agents),
                           dtype=tf.int32)

    def choose_action(self, s, visual_s, evaluation=False):
        if not hasattr(self, 'options'):
            self.options = self._generate_random_options()
        self.last_options = self.options

        a, self.options, self.cell_state = self._get_action(
            s, visual_s, self.cell_state, self.options)
        a = a.numpy()
        return a

    @tf.function
    def _get_action(self, s, visual_s, cell_state, options):
        with tf.device(self.device):
            feat, cell_state = self._representation_net(s,
                                                        visual_s,
                                                        cell_state=cell_state)
            q = self.q_net.value_net(feat)  # [B, P]
            pi = self.intra_option_net.value_net(feat)  # [B, P, A]
            options_onehot = tf.one_hot(options,
                                        self.options_num,
                                        dtype=tf.float32)  # [B, P]
            options_onehot_expanded = tf.expand_dims(options_onehot,
                                                     axis=-1)  # [B, P, 1]
            pi = tf.reduce_sum(pi * options_onehot_expanded, axis=1)  # [B, A]
            if self.is_continuous:
                log_std = tf.gather(self.log_std, options)
                mu = tf.math.tanh(pi)
                a, _ = gaussian_clip_rsample(mu, log_std)
            else:
                pi = pi / self.boltzmann_temperature
                dist = tfp.distributions.Categorical(
                    logits=tf.nn.log_softmax(pi))  # [B, ]
                a = dist.sample()
            interests = self.interest_net.value_net(feat)  # [B, P]
            # [B, P]; an alternative is interests * tf.nn.softmax(q)
            op_logits = interests * q
            new_options = tfp.distributions.Categorical(
                logits=tf.nn.log_softmax(op_logits)).sample()
        return a, new_options, cell_state

    def _target_params_update(self):
        if self.global_step % self.assign_interval == 0:
            update_target_net_weights(self.q_target_net.weights,
                                      self.q_net.weights)

    def learn(self, **kwargs):
        self.train_step = kwargs.get('train_step')

        for i in range(self.train_times_per_step):
            self._learn(
                function_dict={
                    'sample_data_list': [
                        's', 'visual_s', 'a', 'r', 's_', 'visual_s_', 'done',
                        'last_options', 'options'
                    ],
                    'train_data_list': [
                        's', 'visual_s', 'a', 'r', 's_', 'visual_s_', 'done',
                        'last_options', 'options'
                    ],
                    'summary_dict':
                    dict([['LEARNING_RATE/q_lr',
                           self.q_lr(self.train_step)],
                          [
                              'LEARNING_RATE/intra_option_lr',
                              self.intra_option_lr(self.train_step)
                          ],
                          [
                              'LEARNING_RATE/termination_lr',
                              self.termination_lr(self.train_step)
                          ], ['Statistics/option', self.options[0]]])
                })

    @tf.function(experimental_relax_shapes=True)
    def _train(self, memories, isw, cell_state):
        s, visual_s, a, r, s_, visual_s_, done, last_options, options = memories
        last_options = tf.cast(last_options, tf.int32)
        options = tf.cast(options, tf.int32)
        with tf.device(self.device):
            with tf.GradientTape(persistent=True) as tape:
                feat, _ = self._representation_net(s,
                                                   visual_s,
                                                   cell_state=cell_state)
                feat_, _ = self._representation_target_net(
                    s_, visual_s_, cell_state=cell_state)
                q = self.q_net.value_net(feat)  # [B, P]
                pi = self.intra_option_net.value_net(feat)  # [B, P, A]
                beta = self.termination_net.value_net(feat)  # [B, P]
                q_next = self.q_target_net.value_net(feat_)  # [B, P]
                beta_next = self.termination_net.value_net(feat_)  # [B, P]
                interests = self.interest_net.value_net(feat)  # [B, P]
                options_onehot = tf.one_hot(options,
                                            self.options_num,
                                            dtype=tf.float32)  # [B,] => [B, P]

                q_s = qu_eval = tf.reduce_sum(q * options_onehot,
                                              axis=-1,
                                              keepdims=True)  # [B, 1]
                beta_s_ = tf.reduce_sum(beta_next * options_onehot,
                                        axis=-1,
                                        keepdims=True)  # [B, 1]
                q_s_ = tf.reduce_sum(q_next * options_onehot,
                                     axis=-1,
                                     keepdims=True)  # [B, 1]
                if self.double_q:
                    q_ = self.q_net.value_net(feat)  # [B, P]
                    max_a_idx = tf.one_hot(
                        tf.argmax(q_, axis=-1),
                        self.options_num,
                        dtype=tf.float32)  # [B, P] => [B, ] => [B, P]
                    q_s_max = tf.reduce_sum(q_next * max_a_idx,
                                            axis=-1,
                                            keepdims=True)  # [B, 1]
                else:
                    q_s_max = tf.reduce_max(q_next, axis=-1,
                                            keepdims=True)  # [B, 1]
                u_target = (1 - beta_s_) * q_s_ + beta_s_ * q_s_max  # [B, 1]
                qu_target = tf.stop_gradient(r + self.gamma *
                                             (1 - done) * u_target)
                td_error = qu_target - qu_eval  # gradient : q
                q_loss = tf.reduce_mean(tf.square(td_error) *
                                        isw)  # [B, 1] => 1

                if self.use_baseline:
                    adv = tf.stop_gradient(qu_target - qu_eval)
                else:
                    adv = tf.stop_gradient(qu_target)
                options_onehot_expanded = tf.expand_dims(
                    options_onehot, axis=-1)  # [B, P] => [B, P, 1]
                pi = tf.reduce_sum(pi * options_onehot_expanded,
                                   axis=1)  # [B, P, A] => [B, A]
                if self.is_continuous:
                    log_std = tf.gather(self.log_std, options)
                    mu = tf.math.tanh(pi)
                    log_p = gaussian_likelihood_sum(a, mu, log_std)
                    entropy = gaussian_entropy(log_std)
                else:
                    pi = pi / self.boltzmann_temperature
                    log_pi = tf.nn.log_softmax(pi, axis=-1)  # [B, A]
                    entropy = -tf.reduce_sum(tf.exp(log_pi) * log_pi,
                                             axis=1,
                                             keepdims=True)  # [B, 1]
                    log_p = tf.reduce_sum(a * log_pi, axis=-1,
                                          keepdims=True)  # [B, 1]
                pi_loss = tf.reduce_mean(
                    -(log_p * adv + self.ent_coff * entropy)
                )  # [B, 1] * [B, 1] => [B, 1] => 1

                last_options_onehot = tf.one_hot(
                    last_options, self.options_num,
                    dtype=tf.float32)  # [B,] => [B, P]
                beta_s = tf.reduce_sum(beta * last_options_onehot,
                                       axis=-1,
                                       keepdims=True)  # [B, 1]

                # [B, P]; an alternative is interests * tf.nn.softmax(q)
                pi_op = tf.nn.softmax(interests * tf.stop_gradient(q))
                interest_loss = -tf.reduce_mean(beta_s * tf.reduce_sum(
                    pi_op * options_onehot, axis=-1, keepdims=True) *
                                                q_s)  # [B, 1] => 1

                v_s = tf.reduce_sum(q * pi_op, axis=-1,
                                    keepdims=True)  # [B, P] * [B, P] => [B, 1]
                beta_loss = beta_s * tf.stop_gradient(q_s - v_s)  # [B, 1]
                if self.terminal_mask:
                    beta_loss *= (1 - done)
                beta_loss = tf.reduce_mean(beta_loss)  # [B, 1] => 1

            q_grads = tape.gradient(q_loss, self.q_net.trainable_variables)
            intra_option_grads = tape.gradient(pi_loss, self.actor_tv)
            termination_grads = tape.gradient(
                beta_loss, self.termination_net.trainable_variables)
            interest_grads = tape.gradient(
                interest_loss, self.interest_net.trainable_variables)
            self.q_optimizer.apply_gradients(
                zip(q_grads, self.q_net.trainable_variables))
            self.intra_option_optimizer.apply_gradients(
                zip(intra_option_grads, self.actor_tv))
            self.termination_optimizer.apply_gradients(
                zip(termination_grads,
                    self.termination_net.trainable_variables))
            self.interest_optimizer.apply_gradients(
                zip(interest_grads, self.interest_net.trainable_variables))
            self.global_step.assign_add(1)
            return td_error, dict(
                [['LOSS/q_loss', tf.reduce_mean(q_loss)],
                 ['LOSS/pi_loss', tf.reduce_mean(pi_loss)],
                 ['LOSS/beta_loss',
                  tf.reduce_mean(beta_loss)],
                 ['LOSS/interest_loss',
                  tf.reduce_mean(interest_loss)],
                 ['Statistics/q_option_max',
                  tf.reduce_max(q_s)],
                 ['Statistics/q_option_min',
                  tf.reduce_min(q_s)],
                 ['Statistics/q_option_mean',
                  tf.reduce_mean(q_s)]])

    def store_data(self, s, visual_s, a, r, s_, visual_s_, done):
        """
        for off-policy training, use this function to store <s, a, r, s_, done> into ReplayBuffer.
        """
        assert isinstance(a,
                          np.ndarray), "store need action type is np.ndarray"
        assert isinstance(r,
                          np.ndarray), "store need reward type is np.ndarray"
        assert isinstance(done,
                          np.ndarray), "store need done type is np.ndarray"
        self._running_average(s)
        self.data.add(
            s,
            visual_s,
            a,
            r[:, np.newaxis],  # expand to [B, 1]
            s_,
            visual_s_,
            done[:, np.newaxis],  # expand to [B, 1]
            self.last_options,
            self.options)

    def no_op_store(self, s, visual_s, a, r, s_, visual_s_, done):
        pass
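
What distinguishes IOC from plain option-critic is the learned interest function: the policy over options is a softmax of interest-weighted option values (the op_logits = interests * q line in _get_action above). The NumPy sketch below reproduces just that option-selection step; the shapes follow the code comments, while the Q-values and interests are made-up illustrative numbers, not outputs of the class.

import numpy as np

def sample_options(q_values, interests, rng):
    """q_values, interests: [B, P]; returns one sampled option id per batch row."""
    logits = interests * q_values                 # interest-weighted option values
    logits -= logits.max(axis=-1, keepdims=True)  # numerical stability
    probs = np.exp(logits)
    probs /= probs.sum(axis=-1, keepdims=True)    # softmax over the P options
    return np.array([rng.choice(len(p), p=p) for p in probs])

rng = np.random.default_rng(0)
q = np.array([[1.0, 0.5, -0.2, 0.1]])        # Q(s, .) for one state, 4 options
interest = np.array([[0.9, 0.1, 0.8, 0.5]])  # sigmoid interest I(s, .)
print(sample_options(q, interest, rng))      # e.g. [0]
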
Example #2
class CURL(Off_Policy):
    """
    CURL: Contrastive Unsupervised Representations for Reinforcement Learning, http://arxiv.org/abs/2004.04136
    """
    def __init__(
            self,
            envspec,
            alpha=0.2,
            annealing=True,
            last_alpha=0.01,
            ployak=0.995,
            discrete_tau=1.0,
            network_settings={
                'actor_continuous': {
                    'share': [128, 128],
                    'mu': [64],
                    'log_std': [64],
                    'soft_clip': False,
                    'log_std_bound': [-20, 2]
                },
                'actor_discrete': [64, 32],
                'q': [128, 128],
                'encoder': 128
            },
            auto_adaption=True,
            actor_lr=5.0e-4,
            critic_lr=1.0e-3,
            alpha_lr=5.0e-4,
            curl_lr=5.0e-4,
            img_size=64,
            **kwargs):
        super().__init__(envspec=envspec, **kwargs)
        self.concat_vector_dim = self.obs_spec.total_vector_dim
        self.ployak = ployak
        self.discrete_tau = discrete_tau
        self.auto_adaption = auto_adaption
        self.annealing = annealing
        self.img_size = img_size
        self.img_dim = [img_size, img_size, self.obs_spec.visual_dims[0][-1]]
        self.vis_feat_size = network_settings['encoder']

        if self.auto_adaption:
            self.log_alpha = tf.Variable(initial_value=0.0,
                                         name='log_alpha',
                                         dtype=tf.float32,
                                         trainable=True)
        else:
            self.log_alpha = tf.Variable(initial_value=tf.math.log(alpha),
                                         name='log_alpha',
                                         dtype=tf.float32,
                                         trainable=False)
            if self.annealing:
                self.alpha_annealing = LinearAnnealing(alpha, last_alpha,
                                                       1.0e6)

        def _create_net(name):
            return DoubleValueNetwork(
                name=name,
                value_net_type=OutputNetworkType.CRITIC_QVALUE_ONE,
                value_net_kwargs=dict(vector_dim=self.concat_vector_dim +
                                      self.vis_feat_size,
                                      action_dim=self.a_dim,
                                      network_settings=network_settings['q']))

        self.critic_net = _create_net('critic_net')
        self.critic_target_net = _create_net('critic_target_net')

        if self.is_continuous:
            self.actor_net = ValueNetwork(
                name='actor_net',
                value_net_type=OutputNetworkType.ACTOR_CTS,
                value_net_kwargs=dict(
                    vector_dim=self.concat_vector_dim + self.vis_feat_size,
                    output_shape=self.a_dim,
                    network_settings=network_settings['actor_continuous']))
        else:
            self.actor_net = ValueNetwork(
                name='actor_net',
                value_net_type=OutputNetworkType.ACTOR_DCT,
                value_net_kwargs=dict(
                    vector_dim=self.concat_vector_dim + self.vis_feat_size,
                    output_shape=self.a_dim,
                    network_settings=network_settings['actor_discrete']))
            self.gumbel_dist = tfp.distributions.Gumbel(0, 1)

        # entropy = -log(1/|A|) = log |A|
        self.target_entropy = 0.98 * (-self.a_dim if self.is_continuous else
                                      np.log(self.a_dim))

        self.encoder = VisualEncoder(self.img_dim, self.vis_feat_size)
        self.encoder_target = VisualEncoder(self.img_dim, self.vis_feat_size)

        self.curl_w = tf.Variable(
            initial_value=tf.random.normal(shape=(self.vis_feat_size,
                                                  self.vis_feat_size)),
            name='curl_w',
            dtype=tf.float32,
            trainable=True)

        self.critic_tv = self.critic_net.trainable_variables + self.encoder.trainable_variables

        update_target_net_weights(
            self.critic_target_net.weights +
            self.encoder_target.trainable_variables,
            self.critic_net.weights + self.encoder.trainable_variables)
        self.actor_lr, self.critic_lr, self.alpha_lr, self.curl_lr = map(
            self.init_lr, [actor_lr, critic_lr, alpha_lr, curl_lr])
        self.optimizer_actor, self.optimizer_critic, self.optimizer_alpha, self.optimizer_curl = map(
            self.init_optimizer,
            [self.actor_lr, self.critic_lr, self.alpha_lr, self.curl_lr])

        self._worker_params_dict.update(self.actor_net._policy_models)
        self._worker_params_dict.update(encoder=self.encoder)

        self._all_params_dict.update(self.actor_net._all_models)
        self._all_params_dict.update(self.critic_net._all_models)
        self._all_params_dict.update(curl_w=self.curl_w,
                                     encoder=self.encoder,
                                     optimizer_actor=self.optimizer_actor,
                                     optimizer_critic=self.optimizer_critic,
                                     optimizer_alpha=self.optimizer_alpha,
                                     optimizer_curl=self.optimizer_curl)
        self._model_post_process()

    def choose_action(self, obs, evaluation=False):
        visual = center_crop_image(obs.first_visual()[:, 0], self.img_size)
        mu, pi = self._get_action(visual, obs.flatten_vector())
        a = mu.numpy() if evaluation else pi.numpy()
        return a

    @tf.function
    def _get_action(self, visual, vector):
        with tf.device(self.device):
            # concatenate the visual embedding with the flattened vector observation
            feat = tf.concat([self.encoder(visual), vector], axis=-1)
            if self.is_continuous:
                mu, log_std = self.actor_net.value_net(feat)
                pi, _ = squash_rsample(mu, log_std)
                mu = tf.tanh(mu)  # squash mu
            else:
                logits = self.actor_net.value_net(feat)
                mu = tf.argmax(logits, axis=1)
                cate_dist = tfp.distributions.Categorical(
                    logits=tf.nn.log_softmax(logits))
                pi = cate_dist.sample()
            return mu, pi

    def _process_before_train(self, data: BatchExperiences):
        visual = np.transpose(data.obs.first_visual()[:, 0].numpy(),
                              (0, 3, 1, 2))
        visual_ = np.transpose(data.obs_.first_visual()[:, 0].numpy(),
                               (0, 3, 1, 2))
        pos = np.transpose(random_crop(visual, self.img_size), (0, 2, 3, 1))
        visual = np.transpose(random_crop(visual, self.img_size), (0, 2, 3, 1))
        visual_ = np.transpose(random_crop(visual_, self.img_size),
                               (0, 2, 3, 1))
        return self.data_convert([visual, visual_, pos])

    def _target_params_update(self):
        update_target_net_weights(
            self.critic_target_net.weights +
            self.encoder_target.trainable_variables,
            self.critic_net.weights + self.encoder.trainable_variables,
            self.ployak)

    def learn(self, **kwargs):
        self.train_step = kwargs.get('train_step')

        for i in range(self.train_times_per_step):
            self._learn(
                function_dict={
                    'summary_dict':
                    dict([[
                        'LEARNING_RATE/actor_lr',
                        self.actor_lr(self.train_step)
                    ],
                          [
                              'LEARNING_RATE/critic_lr',
                              self.critic_lr(self.train_step)
                          ],
                          [
                              'LEARNING_RATE/alpha_lr',
                              self.alpha_lr(self.train_step)
                          ]])
                })

    @property
    def alpha(self):
        return tf.exp(self.log_alpha)

    def _train(self, BATCH: BatchExperiences, isw, cell_state):
        visual, visual_, pos = self._process_before_train(BATCH)
        td_error, summaries = self.train(BATCH, isw, cell_state, visual,
                                         visual_, pos)
        if self.annealing and not self.auto_adaption:
            self.log_alpha.assign(
                tf.math.log(
                    tf.cast(self.alpha_annealing(self.global_step.numpy()),
                            tf.float32)))
        return td_error, summaries

    @tf.function
    def train(self, BATCH, isw, cell_state, visual, visual_, pos):
        with tf.device(self.device):
            with tf.GradientTape(persistent=True) as tape:
                vis_feat = self.encoder(visual)
                vis_feat_ = self.encoder(visual_)
                target_vis_feat_ = self.encoder_target(visual_)
                feat = tf.concat(
                    [vis_feat, BATCH.obs.flatten_vector()], axis=-1)
                feat_ = tf.concat(
                    [vis_feat_, BATCH.obs_.flatten_vector()], axis=-1)
                target_feat_ = tf.concat(
                    [target_vis_feat_,
                     BATCH.obs_.flatten_vector()], axis=-1)
                if self.is_continuous:
                    target_mu, target_log_std = self.actor_net.value_net(feat_)
                    target_pi, target_log_pi = squash_rsample(
                        target_mu, target_log_std)
                else:
                    target_logits = self.actor_net.value_net(feat_)
                    target_cate_dist = tfp.distributions.Categorical(
                        logits=tf.nn.log_softmax(target_logits))
                    target_pi = target_cate_dist.sample()
                    target_log_pi = target_cate_dist.log_prob(target_pi)
                    target_pi = tf.one_hot(target_pi,
                                           self.a_dim,
                                           dtype=tf.float32)
                q1, q2 = self.critic_net.value_net(feat, BATCH.action)
                q1_target, q2_target = self.critic_target_net.value_net(
                    feat_, target_pi)
                q_target = tf.minimum(q1_target, q2_target)
                dc_r = tf.stop_gradient(
                    BATCH.reward + self.gamma * (1 - BATCH.done) *
                    (q_target - self.alpha * target_log_pi))
                td_error1 = q1 - dc_r
                td_error2 = q2 - dc_r
                q1_loss = tf.reduce_mean(tf.square(td_error1) * isw)
                q2_loss = tf.reduce_mean(tf.square(td_error2) * isw)
                critic_loss = 0.5 * q1_loss + 0.5 * q2_loss

                z_a = vis_feat  # [B, N]
                z_out = self.encoder_target(pos)
                logits = tf.matmul(
                    z_a, tf.matmul(self.curl_w, tf.transpose(z_out, [1, 0])))
                logits -= tf.reduce_max(logits, axis=-1, keepdims=True)
                curl_loss = tf.reduce_mean(
                    tf.keras.losses.sparse_categorical_crossentropy(
                        tf.range(self.batch_size), logits))
            critic_grads = tape.gradient(critic_loss, self.critic_tv)
            self.optimizer_critic.apply_gradients(
                zip(critic_grads, self.critic_tv))
            curl_grads = tape.gradient(curl_loss, [self.curl_w] +
                                       self.encoder.trainable_variables)
            self.optimizer_curl.apply_gradients(
                zip(curl_grads,
                    [self.curl_w] + self.encoder.trainable_variables))

            with tf.GradientTape() as tape:
                if self.is_continuous:
                    mu, log_std = self.actor_net.value_net(feat)
                    pi, log_pi = squash_rsample(mu, log_std)
                    entropy = gaussian_entropy(log_std)
                else:
                    logits = self.actor_net.value_net(feat)
                    logp_all = tf.nn.log_softmax(logits)
                    gumbel_noise = tf.cast(self.gumbel_dist.sample(
                        BATCH.action.shape),
                                           dtype=tf.float32)
                    _pi = tf.nn.softmax(
                        (logp_all + gumbel_noise) / self.discrete_tau)
                    _pi_true_one_hot = tf.one_hot(tf.argmax(_pi, axis=-1),
                                                  self.a_dim)
                    _pi_diff = tf.stop_gradient(_pi_true_one_hot - _pi)
                    pi = _pi_diff + _pi
                    log_pi = tf.reduce_sum(tf.multiply(logp_all, pi),
                                           axis=1,
                                           keepdims=True)
                    entropy = -tf.reduce_mean(
                        tf.reduce_sum(tf.exp(logp_all) * logp_all,
                                      axis=1,
                                      keepdims=True))
                q_s_pi = self.critic_net.get_min(feat, pi)
                actor_loss = -tf.reduce_mean(q_s_pi - self.alpha * log_pi)
            actor_grads = tape.gradient(actor_loss,
                                        self.actor_net.trainable_variables)
            self.optimizer_actor.apply_gradients(
                zip(actor_grads, self.actor_net.trainable_variables))

            if self.auto_adaption:
                with tf.GradientTape() as tape:
                    if self.is_continuous:
                        mu, log_std = self.actor_net.value_net(feat)
                        norm_dist = tfp.distributions.Normal(
                            loc=mu, scale=tf.exp(log_std))
                        log_pi = tf.reduce_sum(norm_dist.log_prob(
                            norm_dist.sample()),
                                               axis=-1,
                                               keepdims=True)  # [B, 1]
                    else:
                        logits = self.actor_net.value_net(feat)
                        norm_dist = tfp.distributions.Categorical(
                            logits=tf.nn.log_softmax(logits))
                        log_pi = norm_dist.log_prob(norm_dist.sample())
                    alpha_loss = -tf.reduce_mean(
                        self.alpha *
                        tf.stop_gradient(log_pi + self.target_entropy))
                alpha_grad = tape.gradient(alpha_loss, self.log_alpha)
                self.optimizer_alpha.apply_gradients([(alpha_grad,
                                                       self.log_alpha)])
            self.global_step.assign_add(1)
            summaries = dict(
                [['LOSS/actor_loss', actor_loss], ['LOSS/q1_loss', q1_loss],
                 ['LOSS/q2_loss', q2_loss], ['LOSS/critic_loss', critic_loss],
                 ['LOSS/curl_loss', curl_loss],
                 ['Statistics/log_alpha', self.log_alpha],
                 ['Statistics/alpha', self.alpha],
                 ['Statistics/entropy', entropy],
                 ['Statistics/q_min',
                  tf.reduce_min(tf.minimum(q1, q2))],
                 ['Statistics/q_mean',
                  tf.reduce_mean(tf.minimum(q1, q2))],
                 ['Statistics/q_max',
                  tf.reduce_max(tf.maximum(q1, q2))]])
            if self.auto_adaption:
                summaries.update({'LOSS/alpha_loss': alpha_loss})
            return (td_error1 + td_error2) / 2., summaries
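
The contrastive term computed in train() above scores each anchor encoding against the encodings of the positive crops through the learned bilinear matrix curl_w, and the matching index is the correct class. Below is a small, self-contained NumPy sketch of that InfoNCE-style loss with made-up values; it illustrates the idea and is not the library's implementation.

import numpy as np

def curl_bilinear_loss(z_a, z_pos, W):
    """z_a, z_pos: [B, N] encodings of two crops of the same frames; W: [N, N]."""
    logits = z_a @ W @ z_pos.T                    # [B, B] bilinear similarities
    logits -= logits.max(axis=-1, keepdims=True)  # stabilize the softmax
    log_probs = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
    idx = np.arange(len(z_a))                     # positive pair sits on the diagonal
    return -log_probs[idx, idx].mean()            # cross-entropy with labels arange(B)

rng = np.random.default_rng(0)
z = rng.normal(size=(8, 16))                      # pretend encoder outputs
print(curl_bilinear_loss(z, z + 0.01 * rng.normal(size=(8, 16)), np.eye(16)))
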
Example #3
class OC(Off_Policy):
    '''
    The Option-Critic Architecture. http://arxiv.org/abs/1609.05140
    '''
    def __init__(self,
                 envspec,
                 q_lr=5.0e-3,
                 intra_option_lr=5.0e-4,
                 termination_lr=5.0e-4,
                 use_eps_greedy=False,
                 eps_init=1,
                 eps_mid=0.2,
                 eps_final=0.01,
                 init2mid_annealing_step=1000,
                 boltzmann_temperature=1.0,
                 options_num=4,
                 ent_coff=0.01,
                 double_q=False,
                 use_baseline=True,
                 terminal_mask=True,
                 termination_regularizer=0.01,
                 assign_interval=1000,
                 network_settings={
                     'q': [32, 32],
                     'intra_option': [32, 32],
                     'termination': [32, 32]
                 },
                 **kwargs):
        super().__init__(envspec=envspec, **kwargs)
        self.expl_expt_mng = ExplorationExploitationClass(
            eps_init=eps_init,
            eps_mid=eps_mid,
            eps_final=eps_final,
            init2mid_annealing_step=init2mid_annealing_step,
            max_step=self.max_train_step)
        self.assign_interval = assign_interval
        self.options_num = options_num
        self.termination_regularizer = termination_regularizer
        self.ent_coff = ent_coff
        self.use_baseline = use_baseline
        self.terminal_mask = terminal_mask
        self.double_q = double_q
        self.boltzmann_temperature = boltzmann_temperature
        self.use_eps_greedy = use_eps_greedy

        def _create_net(name, representation_net=None):
            return ValueNetwork(
                name=name,
                representation_net=representation_net,
                value_net_type=OutputNetworkType.CRITIC_QVALUE_ALL,
                value_net_kwargs=dict(output_shape=self.options_num,
                                      network_settings=network_settings['q']))

        self.q_net = _create_net('q_net', self._representation_net)
        self._representation_target_net = self._create_representation_net(
            '_representation_target_net')
        self.q_target_net = _create_net('q_target_net',
                                        self._representation_target_net)

        self.intra_option_net = ValueNetwork(
            name='intra_option_net',
            value_net_type=OutputNetworkType.OC_INTRA_OPTION,
            value_net_kwargs=dict(
                vector_dim=self._representation_net.h_dim,
                output_shape=self.a_dim,
                options_num=self.options_num,
                network_settings=network_settings['intra_option']))
        self.termination_net = ValueNetwork(
            name='termination_net',
            value_net_type=OutputNetworkType.CRITIC_QVALUE_ALL,
            value_net_kwargs=dict(
                vector_dim=self._representation_net.h_dim,
                output_shape=self.options_num,
                network_settings=network_settings['termination'],
                out_activation='sigmoid'))

        self.actor_tv = self.intra_option_net.trainable_variables
        if self.is_continuous:
            self.log_std = tf.Variable(initial_value=-0.5 * np.ones(
                (self.options_num, self.a_dim), dtype=np.float32),
                                       trainable=True)  # [P, A]
            self.actor_tv += [self.log_std]
        update_target_net_weights(self.q_target_net.weights,
                                  self.q_net.weights)

        self.q_lr, self.intra_option_lr, self.termination_lr = map(
            self.init_lr, [q_lr, intra_option_lr, termination_lr])
        self.q_optimizer = self.init_optimizer(self.q_lr, clipvalue=5.)
        self.intra_option_optimizer = self.init_optimizer(self.intra_option_lr,
                                                          clipvalue=5.)
        self.termination_optimizer = self.init_optimizer(self.termination_lr,
                                                         clipvalue=5.)

        self._worker_params_dict.update(self.q_net._policy_models)
        self._worker_params_dict.update(self.intra_option_net._policy_models)
        self._worker_params_dict.update(self.termination_net._policy_models)

        self._all_params_dict.update(self.q_net._all_models)
        self._all_params_dict.update(self.intra_option_net._all_models)
        self._all_params_dict.update(self.termination_net._all_models)
        self._all_params_dict.update(
            q_optimizer=self.q_optimizer,
            intra_option_optimizer=self.intra_option_optimizer,
            termination_optimizer=self.termination_optimizer)
        self._model_post_process()

    def _generate_random_options(self):
        return tf.constant(np.random.randint(0, self.options_num,
                                             self.n_agents),
                           dtype=tf.int32)

    def choose_action(self, s, visual_s, evaluation=False):
        if not hasattr(self, 'options'):
            self.options = self._generate_random_options()
        self.last_options = self.options

        a, self.options, self.cell_state = self._get_action(
            s, visual_s, self.cell_state, self.options)
        if self.use_eps_greedy:
            if np.random.uniform() < self.expl_expt_mng.get_esp(
                    self.train_step, evaluation=evaluation):  # epsilon greedy
                self.options = self._generate_random_options()
        a = a.numpy()
        return a

    @tf.function
    def _get_action(self, s, visual_s, cell_state, options):
        with tf.device(self.device):
            feat, cell_state = self._representation_net(s,
                                                        visual_s,
                                                        cell_state=cell_state)
            q = self.q_net.value_net(feat)  # [B, P]
            pi = self.intra_option_net.value_net(feat)  # [B, P, A]
            beta = self.termination_net.value_net(feat)  # [B, P]
            options_onehot = tf.one_hot(options,
                                        self.options_num,
                                        dtype=tf.float32)  # [B, P]
            options_onehot_expanded = tf.expand_dims(options_onehot,
                                                     axis=-1)  # [B, P, 1]
            pi = tf.reduce_sum(pi * options_onehot_expanded, axis=1)  # [B, A]
            if self.is_continuous:
                log_std = tf.gather(self.log_std, options)
                mu = tf.math.tanh(pi)
                a, _ = gaussian_clip_rsample(mu, log_std)
            else:
                pi = pi / self.boltzmann_temperature
                dist = tfp.distributions.Categorical(
                    logits=tf.nn.log_softmax(pi))  # [B, ]
                a = dist.sample()
            max_options = tf.cast(tf.argmax(q, axis=-1),
                                  dtype=tf.int32)  # [B, P] => [B, ]
            if self.use_eps_greedy:
                new_options = max_options
            else:
                beta_probs = tf.reduce_sum(beta * options_onehot,
                                           axis=1)  # [B, P] => [B,]
                beta_dist = tfp.distributions.Bernoulli(probs=beta_probs)
                new_options = tf.where(beta_dist.sample() < 1, options,
                                       max_options)
        return a, new_options, cell_state

    def _target_params_update(self):
        if self.global_step % self.assign_interval == 0:
            update_target_net_weights(self.q_target_net.weights,
                                      self.q_net.weights)

    def learn(self, **kwargs):
        self.train_step = kwargs.get('train_step')

        for i in range(self.train_times_per_step):
            self._learn(
                function_dict={
                    'sample_data_list': [
                        's', 'visual_s', 'a', 'r', 's_', 'visual_s_', 'done',
                        'last_options', 'options'
                    ],
                    'train_data_list': [
                        's', 'visual_s', 'a', 'r', 's_', 'visual_s_', 'done',
                        'last_options', 'options'
                    ],
                    'summary_dict':
                    dict([['LEARNING_RATE/q_lr',
                           self.q_lr(self.train_step)],
                          [
                              'LEARNING_RATE/intra_option_lr',
                              self.intra_option_lr(self.train_step)
                          ],
                          [
                              'LEARNING_RATE/termination_lr',
                              self.termination_lr(self.train_step)
                          ], ['Statistics/option', self.options[0]]])
                })

    @tf.function(experimental_relax_shapes=True)
    def _train(self, memories, isw, cell_state):
        s, visual_s, a, r, s_, visual_s_, done, last_options, options = memories
        last_options = tf.cast(last_options, tf.int32)
        options = tf.cast(options, tf.int32)
        with tf.device(self.device):
            with tf.GradientTape(persistent=True) as tape:
                feat, _ = self._representation_net(s,
                                                   visual_s,
                                                   cell_state=cell_state)
                feat_, _ = self._representation_target_net(
                    s_, visual_s_, cell_state=cell_state)
                q = self.q_net.value_net(feat)  # [B, P]
                pi = self.intra_option_net.value_net(feat)  # [B, P, A]
                beta = self.termination_net.value_net(feat)  # [B, P]
                q_next = self.q_target_net.value_net(feat_)  # [B, P]
                beta_next = self.termination_net.value_net(feat_)  # [B, P]
                options_onehot = tf.one_hot(options,
                                            self.options_num,
                                            dtype=tf.float32)  # [B,] => [B, P]

                q_s = qu_eval = tf.reduce_sum(q * options_onehot,
                                              axis=-1,
                                              keepdims=True)  # [B, 1]
                beta_s_ = tf.reduce_sum(beta_next * options_onehot,
                                        axis=-1,
                                        keepdims=True)  # [B, 1]
                q_s_ = tf.reduce_sum(q_next * options_onehot,
                                     axis=-1,
                                     keepdims=True)  # [B, 1]
                # https://github.com/jeanharb/option_critic/blob/5d6c81a650a8f452bc8ad3250f1f211d317fde8c/neural_net.py#L94
                if self.double_q:
                    q_ = self.q_net.value_net(feat)  # [B, P]
                    max_a_idx = tf.one_hot(
                        tf.argmax(q_, axis=-1),
                        self.options_num,
                        dtype=tf.float32)  # [B, P] => [B, ] => [B, P]
                    q_s_max = tf.reduce_sum(q_next * max_a_idx,
                                            axis=-1,
                                            keepdims=True)  # [B, 1]
                else:
                    q_s_max = tf.reduce_max(q_next, axis=-1,
                                            keepdims=True)  # [B, 1]
                u_target = (1 - beta_s_) * q_s_ + beta_s_ * q_s_max  # [B, 1]
                qu_target = tf.stop_gradient(r + self.gamma *
                                             (1 - done) * u_target)
                td_error = qu_target - qu_eval  # gradient : q
                q_loss = tf.reduce_mean(tf.square(td_error) *
                                        isw)  # [B, 1] => 1

                # https://github.com/jeanharb/option_critic/blob/5d6c81a650a8f452bc8ad3250f1f211d317fde8c/neural_net.py#L130
                if self.use_baseline:
                    adv = tf.stop_gradient(qu_target - qu_eval)
                else:
                    adv = tf.stop_gradient(qu_target)
                options_onehot_expanded = tf.expand_dims(
                    options_onehot, axis=-1)  # [B, P] => [B, P, 1]
                pi = tf.reduce_sum(pi * options_onehot_expanded,
                                   axis=1)  # [B, P, A] => [B, A]
                if self.is_continuous:
                    log_std = tf.gather(self.log_std, options)
                    mu = tf.math.tanh(pi)
                    log_p = gaussian_likelihood_sum(a, mu, log_std)
                    entropy = gaussian_entropy(log_std)
                else:
                    pi = pi / self.boltzmann_temperature
                    log_pi = tf.nn.log_softmax(pi, axis=-1)  # [B, A]
                    entropy = -tf.reduce_sum(tf.exp(log_pi) * log_pi,
                                             axis=1,
                                             keepdims=True)  # [B, 1]
                    log_p = tf.reduce_sum(a * log_pi, axis=-1,
                                          keepdims=True)  # [B, 1]
                pi_loss = tf.reduce_mean(
                    -(log_p * adv + self.ent_coff * entropy)
                )  # [B, 1] * [B, 1] => [B, 1] => 1

                last_options_onehot = tf.one_hot(
                    last_options, self.options_num,
                    dtype=tf.float32)  # [B,] => [B, P]
                beta_s = tf.reduce_sum(beta * last_options_onehot,
                                       axis=-1,
                                       keepdims=True)  # [B, 1]
                if self.use_eps_greedy:
                    v_s = tf.reduce_max(
                        q, axis=-1,
                        keepdims=True) - self.termination_regularizer  # [B, 1]
                else:
                    v_s = (1 - beta_s) * q_s + beta_s * tf.reduce_max(
                        q, axis=-1, keepdims=True)  # [B, 1]
                    # v_s = tf.reduce_mean(q, axis=-1, keepdims=True)   # [B, 1]
                beta_loss = beta_s * tf.stop_gradient(q_s - v_s)  # [B, 1]
                # https://github.com/lweitkamp/option-critic-pytorch/blob/0c57da7686f8903ed2d8dded3fae832ee9defd1a/option_critic.py#L238
                if self.terminal_mask:
                    beta_loss *= (1 - done)
                beta_loss = tf.reduce_mean(beta_loss)  # [B, 1] => 1

            q_grads = tape.gradient(q_loss, self.q_net.trainable_variables)
            intra_option_grads = tape.gradient(pi_loss, self.actor_tv)
            termination_grads = tape.gradient(
                beta_loss, self.termination_net.trainable_variables)
            self.q_optimizer.apply_gradients(
                zip(q_grads, self.q_net.trainable_variables))
            self.intra_option_optimizer.apply_gradients(
                zip(intra_option_grads, self.actor_tv))
            self.termination_optimizer.apply_gradients(
                zip(termination_grads,
                    self.termination_net.trainable_variables))
            self.global_step.assign_add(1)
            return td_error, dict(
                [['LOSS/q_loss', tf.reduce_mean(q_loss)],
                 ['LOSS/pi_loss', tf.reduce_mean(pi_loss)],
                 ['LOSS/beta_loss',
                  tf.reduce_mean(beta_loss)],
                 ['Statistics/q_option_max',
                  tf.reduce_max(q_s)],
                 ['Statistics/q_option_min',
                  tf.reduce_min(q_s)],
                 ['Statistics/q_option_mean',
                  tf.reduce_mean(q_s)]])

    def store_data(self, s, visual_s, a, r, s_, visual_s_, done):
        """
        for off-policy training, use this function to store <s, a, r, s_, done> into ReplayBuffer.
        """
        assert isinstance(a,
                          np.ndarray), "store need action type is np.ndarray"
        assert isinstance(r,
                          np.ndarray), "store need reward type is np.ndarray"
        assert isinstance(done,
                          np.ndarray), "store need done type is np.ndarray"
        self._running_average(s)
        self.data.add(
            s,
            visual_s,
            a,
            r[:, np.newaxis],  # expand to [B, 1]
            s_,
            visual_s_,
            done[:, np.newaxis],  # expand to [B, 1]
            self.last_options,
            self.options)

    def no_op_store(self, s, visual_s, a, r, s_, visual_s_, done):
        pass
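
The critic target built in _train() above is the usual option-critic bootstrap: with probability beta the current option terminates at s' and the agent re-selects the greedy option, otherwise it keeps the current one. A minimal NumPy sketch of that target, with illustrative numbers and the shapes from the code comments:

import numpy as np

def option_value_target(q_next, beta_next, options, r, done, gamma=0.99):
    """q_next, beta_next: [B, P]; options: [B] int option ids; r, done: [B, 1]."""
    idx = np.arange(len(options))
    q_s_ = q_next[idx, options][:, None]        # Q(s', o) of the option being held
    beta_s_ = beta_next[idx, options][:, None]  # its termination probability at s'
    q_max = q_next.max(axis=-1, keepdims=True)  # value of the greedy option at s'
    u = (1 - beta_s_) * q_s_ + beta_s_ * q_max  # U(s', o)
    return r + gamma * (1 - done) * u           # bootstrapped target for Q(s, o)

q_next = np.array([[0.2, 0.8, 0.1, 0.4]])
beta_next = np.array([[0.1, 0.9, 0.5, 0.3]])
print(option_value_target(q_next, beta_next, np.array([0]),
                          r=np.array([[1.0]]), done=np.array([[0.0]])))
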
Example #4
class TAC(Off_Policy):
    """Tsallis Actor Critic, TAC with V neural Network. https://arxiv.org/abs/1902.00137
    """
    def __init__(
            self,
            envspec,
            alpha=0.2,
            annealing=True,
            last_alpha=0.01,
            ployak=0.995,
            entropic_index=1.5,
            discrete_tau=1.0,
            log_std_bound=[-20, 2],
            network_settings={
                'actor_continuous': {
                    'share': [128, 128],
                    'mu': [64],
                    'log_std': [64]
                },
                'actor_discrete': [64, 32],
                'q': [128, 128]
            },
            auto_adaption=True,
            actor_lr=5.0e-4,
            critic_lr=1.0e-3,
            alpha_lr=5.0e-4,
            **kwargs):
        super().__init__(envspec=envspec, **kwargs)
        self.ployak = ployak
        self.discrete_tau = discrete_tau
        self.entropic_index = 2 - entropic_index
        self.log_std_min, self.log_std_max = log_std_bound[:]
        self.auto_adaption = auto_adaption
        self.annealing = annealing

        if self.auto_adaption:
            self.log_alpha = tf.Variable(initial_value=0.0,
                                         name='log_alpha',
                                         dtype=tf.float32,
                                         trainable=True)
        else:
            self.log_alpha = tf.Variable(initial_value=tf.math.log(alpha),
                                         name='log_alpha',
                                         dtype=tf.float32,
                                         trainable=False)
            if self.annealing:
                self.alpha_annealing = LinearAnnealing(alpha, last_alpha, 1e6)

        def _create_net(name, representation_net=None):
            return DoubleValueNetwork(
                name=name,
                representation_net=representation_net,
                value_net_type=OutputNetworkType.CRITIC_QVALUE_ONE,
                value_net_kwargs=dict(action_dim=self.a_dim,
                                      network_settings=network_settings['q']))

        self.critic_net = _create_net('critic_net', self._representation_net)
        self._representation_target_net = self._create_representation_net(
            '_representation_target_net')
        self.critic_target_net = _create_net('critic_target_net',
                                             self._representation_target_net)

        if self.is_continuous:
            self.actor_net = ValueNetwork(
                name='actor_net',
                value_net_type=OutputNetworkType.ACTOR_CTS,
                value_net_kwargs=dict(
                    vector_dim=self._representation_net.h_dim,
                    output_shape=self.a_dim,
                    network_settings=network_settings['actor_continuous']))
        else:
            self.actor_net = ValueNetwork(
                name='actor_net',
                value_net_type=OutputNetworkType.ACTOR_DCT,
                value_net_kwargs=dict(
                    vector_dim=self._representation_net.h_dim,
                    output_shape=self.a_dim,
                    network_settings=network_settings['actor_discrete']))
            self.gumbel_dist = tfp.distributions.Gumbel(0, 1)

        # entropy = -log(1/|A|) = log |A|
        self.target_entropy = 0.98 * (-self.a_dim if self.is_continuous else
                                      np.log(self.a_dim))

        update_target_net_weights(self.critic_target_net.weights,
                                  self.critic_net.weights)
        self.actor_lr, self.critic_lr, self.alpha_lr = map(
            self.init_lr, [actor_lr, critic_lr, alpha_lr])
        self.optimizer_actor, self.optimizer_critic, self.optimizer_alpha = map(
            self.init_optimizer,
            [self.actor_lr, self.critic_lr, self.alpha_lr])

        self._worker_params_dict.update(self._representation_net._all_models)
        self._worker_params_dict.update(self.actor_net._policy_models)

        self._all_params_dict.update(self.actor_net._all_models)
        self._all_params_dict.update(self.critic_net._all_models)
        self._all_params_dict.update(log_alpha=self.log_alpha,
                                     optimizer_actor=self.optimizer_actor,
                                     optimizer_critic=self.optimizer_critic,
                                     optimizer_alpha=self.optimizer_alpha)
        self._model_post_process()

    @property
    def alpha(self):
        return tf.exp(self.log_alpha)

    def choose_action(self, s, visual_s, evaluation=False):
        mu, pi, self.cell_state = self._get_action(s, visual_s,
                                                   self.cell_state)
        a = mu.numpy() if evaluation else pi.numpy()
        return a

    @tf.function
    def _get_action(self, s, visual_s, cell_state):
        with tf.device(self.device):
            feat, cell_state = self._representation_net(s,
                                                        visual_s,
                                                        cell_state=cell_state)
            if self.is_continuous:
                mu, log_std = self.actor_net.value_net(feat)
                log_std = clip_nn_log_std(log_std, self.log_std_min,
                                          self.log_std_max)
                pi, _ = tsallis_squash_rsample(mu, log_std,
                                               self.entropic_index)
                mu = tf.tanh(mu)  # squash mu
            else:
                logits = self.actor_net.value_net(feat)
                mu = tf.argmax(logits, axis=1)
                cate_dist = tfp.distributions.Categorical(
                    logits=tf.nn.log_softmax(logits))
                pi = cate_dist.sample()
            return mu, pi, cell_state

    def _target_params_update(self):
        update_target_net_weights(self.critic_target_net.weights,
                                  self.critic_net.weights, self.ployak)

    def learn(self, **kwargs):
        self.train_step = kwargs.get('train_step')

        for i in range(self.train_times_per_step):
            self._learn(
                function_dict={
                    'summary_dict':
                    dict([[
                        'LEARNING_RATE/actor_lr',
                        self.actor_lr(self.train_step)
                    ],
                          [
                              'LEARNING_RATE/critic_lr',
                              self.critic_lr(self.train_step)
                          ],
                          [
                              'LEARNING_RATE/alpha_lr',
                              self.alpha_lr(self.train_step)
                          ]]),
                    'train_data_list':
                    ['ss', 'vvss', 'a', 'r', 'done', 's_', 'visual_s_']
                })

    def _train(self, memories, isw, cell_state):
        td_error, summaries = self.train(memories, isw, cell_state)
        if self.annealing and not self.auto_adaption:
            self.log_alpha.assign(
                tf.math.log(
                    tf.cast(self.alpha_annealing(self.global_step.numpy()),
                            tf.float32)))
        return td_error, summaries

    @tf.function(experimental_relax_shapes=True)
    def train(self, memories, isw, cell_state):
        ss, vvss, a, r, done, s_, visual_s_ = memories
        with tf.device(self.device):
            with tf.GradientTape(persistent=True) as tape:
                (feat,
                 feat_), _ = self._representation_net(ss,
                                                      vvss,
                                                      cell_state=cell_state,
                                                      need_split=True)
                if self.is_continuous:
                    mu, log_std = self.actor_net.value_net(feat)
                    log_std = clip_nn_log_std(log_std, self.log_std_min,
                                              self.log_std_max)
                    pi, log_pi = tsallis_squash_rsample(
                        mu, log_std, self.entropic_index)
                    entropy = gaussian_entropy(log_std)
                    target_mu, target_log_std = self.actor_net.value_net(feat_)
                    target_log_std = clip_nn_log_std(target_log_std,
                                                     self.log_std_min,
                                                     self.log_std_max)
                    target_pi, target_log_pi = tsallis_squash_rsample(
                        target_mu, target_log_std, self.entropic_index)
                else:
                    logits = self.actor_net.value_net(feat)
                    logp_all = tf.nn.log_softmax(logits)
                    gumbel_noise = tf.cast(self.gumbel_dist.sample(a.shape),
                                           dtype=tf.float32)
                    _pi = tf.nn.softmax(
                        (logp_all + gumbel_noise) / self.discrete_tau)
                    _pi_true_one_hot = tf.one_hot(tf.argmax(_pi, axis=-1),
                                                  self.a_dim)
                    _pi_diff = tf.stop_gradient(_pi_true_one_hot - _pi)
                    pi = _pi_diff + _pi
                    log_pi = tf.reduce_sum(tf.multiply(logp_all, pi),
                                           axis=1,
                                           keepdims=True)
                    entropy = -tf.reduce_mean(
                        tf.reduce_sum(tf.exp(logp_all) * logp_all,
                                      axis=1,
                                      keepdims=True))

                    target_logits = self.actor_net.value_net(feat_)
                    target_cate_dist = tfp.distributions.Categorical(
                        logits=tf.nn.log_softmax(target_logits))
                    target_pi = target_cate_dist.sample()
                    target_log_pi = target_cate_dist.log_prob(target_pi)
                    target_pi = tf.one_hot(target_pi,
                                           self.a_dim,
                                           dtype=tf.float32)
                q1, q2 = self.critic_net.get_value(feat, a)
                q_s_pi = self.critic_net.get_min(feat, pi)

                q1_target, q2_target, _ = self.critic_target_net(
                    s_, visual_s_, target_pi, cell_state=cell_state)
                q_target = tf.minimum(q1_target, q2_target)
                dc_r = tf.stop_gradient(
                    r + self.gamma * (1 - done) *
                    (q_target - self.alpha * target_log_pi))
                td_error1 = q1 - dc_r
                td_error2 = q2 - dc_r
                q1_loss = tf.reduce_mean(tf.square(td_error1) * isw)
                q2_loss = tf.reduce_mean(tf.square(td_error2) * isw)
                critic_loss = 0.5 * q1_loss + 0.5 * q2_loss
                actor_loss = -tf.reduce_mean(q_s_pi - self.alpha * log_pi)
                if self.auto_adaption:
                    alpha_loss = -tf.reduce_mean(
                        self.alpha *
                        tf.stop_gradient(log_pi + self.target_entropy))
            critic_grads = tape.gradient(critic_loss,
                                         self.critic_net.trainable_variables)
            self.optimizer_critic.apply_gradients(
                zip(critic_grads, self.critic_net.trainable_variables))
            actor_grads = tape.gradient(actor_loss,
                                        self.actor_net.trainable_variables)
            self.optimizer_actor.apply_gradients(
                zip(actor_grads, self.actor_net.trainable_variables))
            if self.auto_adaption:
                alpha_grad = tape.gradient(alpha_loss, self.log_alpha)
                self.optimizer_alpha.apply_gradients([(alpha_grad,
                                                       self.log_alpha)])
            self.global_step.assign_add(1)
            summaries = dict(
                [['LOSS/actor_loss', actor_loss], ['LOSS/q1_loss', q1_loss],
                 ['LOSS/q2_loss', q2_loss], ['LOSS/critic_loss', critic_loss],
                 ['Statistics/log_alpha', self.log_alpha],
                 ['Statistics/alpha', self.alpha],
                 ['Statistics/entropy', entropy],
                 ['Statistics/q_min',
                  tf.reduce_min(tf.minimum(q1, q2))],
                 ['Statistics/q_mean',
                  tf.reduce_mean(tf.minimum(q1, q2))],
                 ['Statistics/q_max',
                  tf.reduce_max(tf.maximum(q1, q2))]])
            if self.auto_adaption:
                summaries.update({'LOSS/alpha_loss': alpha_loss})
            return (td_error1 + td_error2) / 2, summaries
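
The discrete branch of the training step above relies on a straight-through Gumbel-Softmax estimator to obtain differentiable one-hot actions. A minimal, self-contained sketch of that trick (the function name and example shapes below are illustrative, not taken from the class) might look like:

import tensorflow as tf
import tensorflow_probability as tfp


def straight_through_gumbel_softmax(logits, tau=1.0):
    # Sample Gumbel(0, 1) noise with the same shape as the logits.
    gumbel = tfp.distributions.Gumbel(0., 1.).sample(tf.shape(logits))
    # Relaxed (differentiable) sample from the categorical distribution.
    soft = tf.nn.softmax((tf.nn.log_softmax(logits) + gumbel) / tau)
    # Forward pass uses the hard one-hot action; backward pass keeps the soft gradient.
    hard = tf.one_hot(tf.argmax(soft, axis=-1), tf.shape(logits)[-1])
    return tf.stop_gradient(hard - soft) + soft


# e.g. a batch of 4 states over 3 discrete actions
actions = straight_through_gumbel_softmax(tf.random.normal([4, 3]), tau=1.0)
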
Example #5
0
class SAC_V(Off_Policy):
    """
        Soft Actor Critic with Value neural network. https://arxiv.org/abs/1812.05905
        Soft Actor-Critic for Discrete Action Settings. https://arxiv.org/abs/1910.07207
    """
    def __init__(
            self,
            envspec,
            alpha=0.2,
            annealing=True,
            last_alpha=0.01,
            ployak=0.995,
            use_gumbel=True,
            discrete_tau=1.0,
            network_settings={
                'actor_continuous': {
                    'share': [128, 128],
                    'mu': [64],
                    'log_std': [64],
                    'soft_clip': False,
                    'log_std_bound': [-20, 2]
                },
                'actor_discrete': [64, 32],
                'q': [128, 128],
                'v': [128, 128]
            },
            actor_lr=5.0e-4,
            critic_lr=1.0e-3,
            alpha_lr=5.0e-4,
            auto_adaption=True,
            **kwargs):
        super().__init__(envspec=envspec, **kwargs)
        self.ployak = ployak
        self.use_gumbel = use_gumbel
        self.discrete_tau = discrete_tau
        self.auto_adaption = auto_adaption
        self.annealing = annealing

        if self.auto_adaption:
            self.log_alpha = tf.Variable(initial_value=0.0,
                                         name='log_alpha',
                                         dtype=tf.float32,
                                         trainable=True)
        else:
            self.log_alpha = tf.Variable(initial_value=tf.math.log(alpha),
                                         name='log_alpha',
                                         dtype=tf.float32,
                                         trainable=False)
            if self.annealing:
                self.alpha_annealing = LinearAnnealing(alpha, last_alpha, 1e6)

        def _create_net(name, representation_net=None):
            return ValueNetwork(
                name=name,
                representation_net=representation_net,
                value_net_type=OutputNetworkType.CRITIC_VALUE,
                value_net_kwargs=dict(network_settings=network_settings['v']))

        self.v_net = _create_net('v_net', self._representation_net)
        self._representation_target_net = self._create_representation_net(
            '_representation_target_net')
        self.v_target_net = _create_net('v_target_net',
                                        self._representation_target_net)

        if self.is_continuous:
            self.actor_net = ValueNetwork(
                name='actor_net',
                value_net_type=OutputNetworkType.ACTOR_CTS,
                value_net_kwargs=dict(
                    vector_dim=self._representation_net.h_dim,
                    output_shape=self.a_dim,
                    network_settings=network_settings['actor_continuous']))
        else:
            self.actor_net = ValueNetwork(
                name='actor_net',
                value_net_type=OutputNetworkType.ACTOR_DCT,
                value_net_kwargs=dict(
                    vector_dim=self._representation_net.h_dim,
                    output_shape=self.a_dim,
                    network_settings=network_settings['actor_discrete']))
            if self.use_gumbel:
                self.gumbel_dist = tfp.distributions.Gumbel(0, 1)

        # target entropy: -dim(A) for continuous actions (SAC heuristic);
        # log|A| (= -log(1/|A|), the entropy of a uniform policy) for discrete actions
        self.target_entropy = 0.98 * (-self.a_dim if self.is_continuous else
                                      np.log(self.a_dim))

        if self.is_continuous or self.use_gumbel:
            self.q_net = DoubleValueNetwork(
                name='q_net',
                value_net_type=OutputNetworkType.CRITIC_QVALUE_ONE,
                value_net_kwargs=dict(
                    vector_dim=self._representation_net.h_dim,
                    action_dim=self.a_dim,
                    network_settings=network_settings['q']))
        else:
            self.q_net = DoubleValueNetwork(
                name='q_net',
                value_net_type=OutputNetworkType.CRITIC_QVALUE_ALL,
                value_net_kwargs=dict(
                    vector_dim=self._representation_net.h_dim,
                    output_shape=self.a_dim,
                    network_settings=network_settings['q']))

        update_target_net_weights(self.v_target_net.weights,
                                  self.v_net.weights)
        self.actor_lr, self.critic_lr, self.alpha_lr = map(
            self.init_lr, [actor_lr, critic_lr, alpha_lr])
        self.optimizer_actor, self.optimizer_critic, self.optimizer_alpha = map(
            self.init_optimizer,
            [self.actor_lr, self.critic_lr, self.alpha_lr])

        self._worker_params_dict.update(self._representation_net._all_models)
        self._worker_params_dict.update(self.actor_net._policy_models)

        self._all_params_dict.update(self.actor_net._all_models)
        self._all_params_dict.update(self.v_net._all_models)
        self._all_params_dict.update(self.q_net._all_models)
        self._all_params_dict.update(log_alpha=self.log_alpha,
                                     optimizer_actor=self.optimizer_actor,
                                     optimizer_critic=self.optimizer_critic,
                                     optimizer_alpha=self.optimizer_alpha)
        self._model_post_process()

    @property
    def alpha(self):
        return tf.exp(self.log_alpha)

    def choose_action(self, obs, evaluation=False):
        mu, pi, self.cell_state = self._get_action(obs, self.cell_state)
        a = mu.numpy() if evaluation else pi.numpy()
        return a

    @tf.function
    def _get_action(self, obs, cell_state):
        with tf.device(self.device):
            feat, cell_state = self._representation_net(obs,
                                                        cell_state=cell_state)
            if self.is_continuous:
                mu, log_std = self.actor_net.value_net(feat)
                pi, _ = squash_rsample(mu, log_std)
                mu = tf.tanh(mu)  # squash mu
            else:
                logits = self.actor_net.value_net(feat)
                mu = tf.argmax(logits, axis=1)
                cate_dist = tfp.distributions.Categorical(
                    logits=tf.nn.log_softmax(logits))
                pi = cate_dist.sample()
            return mu, pi, cell_state

    def _target_params_update(self):
        update_target_net_weights(self.v_target_net.weights,
                                  self.v_net.weights, self.ployak)

    def learn(self, **kwargs):
        self.train_step = kwargs.get('train_step')

        for i in range(self.train_times_per_step):
            self._learn(
                function_dict={
                    'summary_dict':
                    dict([[
                        'LEARNING_RATE/actor_lr',
                        self.actor_lr(self.train_step)
                    ],
                          [
                              'LEARNING_RATE/critic_lr',
                              self.critic_lr(self.train_step)
                          ],
                          [
                              'LEARNING_RATE/alpha_lr',
                              self.alpha_lr(self.train_step)
                          ]]),
                })

    def _train(self, BATCH, isw, cell_state):
        if self.is_continuous or self.use_gumbel:
            td_error, summaries = self.train_continuous(BATCH, isw, cell_state)
        else:
            td_error, summaries = self.train_discrete(BATCH, isw, cell_state)
        if self.annealing and not self.auto_adaption:
            self.log_alpha.assign(
                tf.math.log(
                    tf.cast(self.alpha_annealing(self.global_step.numpy()),
                            tf.float32)))
        return td_error, summaries

    @tf.function
    def train_continuous(self, BATCH, isw, cell_state):
        with tf.device(self.device):
            with tf.GradientTape(persistent=True) as tape:
                feat, _ = self._representation_net(BATCH.obs,
                                                   cell_state=cell_state)
                v = self.v_net.value_net(feat)
                v_target, _ = self.v_target_net(BATCH.obs_,
                                                cell_state=cell_state)

                if self.is_continuous:
                    mu, log_std = self.actor_net.value_net(feat)
                    pi, log_pi = squash_rsample(mu, log_std)
                    entropy = gaussian_entropy(log_std)
                else:
                    logits = self.actor_net.value_net(feat)
                    logp_all = tf.nn.log_softmax(logits)
                    gumbel_noise = tf.cast(self.gumbel_dist.sample(
                        BATCH.action.shape),
                                           dtype=tf.float32)
                    _pi = tf.nn.softmax(
                        (logp_all + gumbel_noise) / self.discrete_tau)
                    _pi_true_one_hot = tf.one_hot(tf.argmax(_pi, axis=-1),
                                                  self.a_dim)
                    _pi_diff = tf.stop_gradient(_pi_true_one_hot - _pi)
                    pi = _pi_diff + _pi
                    log_pi = tf.reduce_sum(tf.multiply(logp_all, pi),
                                           axis=1,
                                           keepdims=True)
                    entropy = -tf.reduce_mean(
                        tf.reduce_sum(tf.exp(logp_all) * logp_all,
                                      axis=1,
                                      keepdims=True))
                q1, q2 = self.q_net.get_value(feat, BATCH.action)
                q1_pi, q2_pi = self.q_net.get_value(feat, pi)
                dc_r = tf.stop_gradient(BATCH.reward + self.gamma * v_target *
                                        (1 - BATCH.done))
                v_from_q_stop = tf.stop_gradient(
                    tf.minimum(q1_pi, q2_pi) - self.alpha * log_pi)
                td_v = v - v_from_q_stop
                td_error1 = q1 - dc_r
                td_error2 = q2 - dc_r
                q1_loss = tf.reduce_mean(tf.square(td_error1) * isw)
                q2_loss = tf.reduce_mean(tf.square(td_error2) * isw)
                v_loss_stop = tf.reduce_mean(tf.square(td_v) * isw)
                critic_loss = 0.5 * q1_loss + 0.5 * q2_loss + 0.5 * v_loss_stop
                actor_loss = -tf.reduce_mean(q1_pi - self.alpha * log_pi)
                if self.auto_adaption:
                    alpha_loss = -tf.reduce_mean(
                        self.alpha *
                        tf.stop_gradient(log_pi + self.target_entropy))
            actor_grads = tape.gradient(actor_loss,
                                        self.actor_net.trainable_variables)
            self.optimizer_actor.apply_gradients(
                zip(actor_grads, self.actor_net.trainable_variables))
            critic_grads = tape.gradient(
                critic_loss, self.q_net.trainable_variables +
                self.v_net.trainable_variables)
            self.optimizer_critic.apply_gradients(
                zip(
                    critic_grads, self.q_net.trainable_variables +
                    self.v_net.trainable_variables))
            if self.auto_adaption:
                alpha_grad = tape.gradient(alpha_loss, self.log_alpha)
                self.optimizer_alpha.apply_gradients([(alpha_grad,
                                                       self.log_alpha)])
            self.global_step.assign_add(1)
            summaries = dict(
                [['LOSS/actor_loss', actor_loss], ['LOSS/q1_loss', q1_loss],
                 ['LOSS/q2_loss', q2_loss], ['LOSS/v_loss', v_loss_stop],
                 ['LOSS/critic_loss', critic_loss],
                 ['Statistics/log_alpha', self.log_alpha],
                 ['Statistics/alpha', self.alpha],
                 ['Statistics/entropy', entropy],
                 ['Statistics/q_min',
                  tf.reduce_min(tf.minimum(q1, q2))],
                 ['Statistics/q_mean',
                  tf.reduce_mean(tf.minimum(q1, q2))],
                 ['Statistics/q_max',
                  tf.reduce_max(tf.maximum(q1, q2))],
                 ['Statistics/v_mean', tf.reduce_mean(v)]])
            if self.auto_adaption:
                summaries.update({'LOSS/alpha_loss': alpha_loss})
            return (td_error1 + td_error2) / 2, summaries

    @tf.function
    def train_discrete(self, BATCH, isw, cell_state):
        with tf.device(self.device):
            with tf.GradientTape(persistent=True) as tape:
                feat, _ = self._representation_net(BATCH.obs,
                                                   cell_state=cell_state)
                v = self.v_net.value_net(feat)  # [B, 1]
                v_target, _ = self.v_target_net(
                    BATCH.obs_, cell_state=cell_state)  # [B, 1]

                q1_all, q2_all = self.q_net.get_value(feat)  # [B, A]

                def q_function(x):
                    return tf.reduce_sum(x * BATCH.action,
                                         axis=-1,
                                         keepdims=True)  # [B, 1]

                q1 = q_function(q1_all)
                q2 = q_function(q2_all)
                logits = self.actor_net.value_net(feat)  # [B, A]
                logp_all = tf.nn.log_softmax(logits)  # [B, A]

                entropy = -tf.reduce_sum(tf.exp(logp_all) * logp_all,
                                         axis=1,
                                         keepdims=True)  # [B, 1]
                q_all = self.q_net.get_min(feat)  # [B, A]
                actor_loss = -tf.reduce_mean(
                    tf.reduce_sum((q_all - self.alpha * logp_all) *
                                  tf.exp(logp_all),
                                  axis=-1))  # [B, A] => [B,], then mean over batch

                dc_r = tf.stop_gradient(BATCH.reward + self.gamma * v_target *
                                        (1 - BATCH.done))
                td_v = v - tf.stop_gradient(
                    tf.minimum(
                        tf.reduce_sum(tf.exp(logp_all) * q1_all,
                                      axis=-1, keepdims=True),
                        tf.reduce_sum(tf.exp(logp_all) * q2_all,
                                      axis=-1, keepdims=True)))  # [B, 1], matches v
                td_error1 = q1 - dc_r
                td_error2 = q2 - dc_r
                q1_loss = tf.reduce_mean(tf.square(td_error1) * isw)
                q2_loss = tf.reduce_mean(tf.square(td_error2) * isw)
                v_loss_stop = tf.reduce_mean(tf.square(td_v) * isw)
                critic_loss = 0.5 * q1_loss + 0.5 * q2_loss + 0.5 * v_loss_stop

                if self.auto_adaption:
                    corr = tf.stop_gradient(self.target_entropy - entropy)
                    # corr = tf.stop_gradient(tf.reduce_sum((logp_all - self.a_dim) * tf.exp(logp_all), axis=-1))    #[B, A] => [B,]
                    alpha_loss = -tf.reduce_mean(self.alpha * corr)

            critic_grads = tape.gradient(
                critic_loss, self.q_net.trainable_variables +
                self.v_net.trainable_variables)
            self.optimizer_critic.apply_gradients(
                zip(
                    critic_grads, self.q_net.trainable_variables +
                    self.v_net.trainable_variables))
            actor_grads = tape.gradient(actor_loss,
                                        self.actor_net.trainable_variables)
            self.optimizer_actor.apply_gradients(
                zip(actor_grads, self.actor_net.trainable_variables))
            if self.auto_adaption:
                alpha_grad = tape.gradient(alpha_loss, self.log_alpha)
                self.optimizer_alpha.apply_gradients([(alpha_grad,
                                                       self.log_alpha)])
            self.global_step.assign_add(1)
            summaries = dict([['LOSS/actor_loss', actor_loss],
                              ['LOSS/q1_loss', q1_loss],
                              ['LOSS/q2_loss', q2_loss],
                              ['LOSS/v_loss', v_loss_stop],
                              ['LOSS/critic_loss', critic_loss],
                              ['Statistics/log_alpha', self.log_alpha],
                              ['Statistics/alpha', self.alpha],
                              ['Statistics/entropy',
                               tf.reduce_mean(entropy)],
                              ['Statistics/v_mean',
                               tf.reduce_mean(v)]])
            if self.auto_adaption:
                summaries.update({'LOSS/alpha_loss': alpha_loss})
            return (td_error1 + td_error2) / 2, summaries
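
Both examples above tune the temperature alpha automatically by minimizing -E[alpha * (log_pi + target_entropy)] with respect to log_alpha. A standalone sketch of that update, with illustrative names and a fake batch of log-probabilities standing in for the actor's output, could look like:

import tensorflow as tf

log_alpha = tf.Variable(0.0, dtype=tf.float32)
target_entropy = 0.98 * (-2.0)  # e.g. a 2-dimensional continuous action space
optimizer = tf.keras.optimizers.Adam(5e-4)


def update_alpha(log_pi):
    with tf.GradientTape() as tape:
        alpha = tf.exp(log_alpha)
        # Gradients flow only through alpha; log_pi is treated as a constant.
        alpha_loss = -tf.reduce_mean(
            alpha * tf.stop_gradient(log_pi + target_entropy))
    grad = tape.gradient(alpha_loss, log_alpha)
    optimizer.apply_gradients([(grad, log_alpha)])
    return alpha_loss


update_alpha(tf.random.normal([32, 1]))  # stand-in for a batch of log pi(a|s)

When the policy's entropy falls below target_entropy (so log_pi + target_entropy is positive on average), the gradient step increases alpha, which in turn pushes the actor losses above toward more exploration.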