Example 1
class DQN(Off_Policy):
    '''
    Deep Q-learning Network, DQN, [2013](https://arxiv.org/pdf/1312.5602.pdf), [2015](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf)
    DQN + LSTM, https://arxiv.org/abs/1507.06527
    '''
    def __init__(self,
                 envspec,
                 lr: float = 5.0e-4,
                 eps_init: float = 1,
                 eps_mid: float = 0.2,
                 eps_final: float = 0.01,
                 init2mid_annealing_step: int = 1000,
                 assign_interval: int = 1000,
                 network_settings: List[int] = [32, 32],
                 **kwargs):
        assert not envspec.is_continuous, 'DQN only supports discrete action spaces'
        super().__init__(envspec=envspec, **kwargs)
        self.expl_expt_mng = ExplorationExploitationClass(
            eps_init=eps_init,
            eps_mid=eps_mid,
            eps_final=eps_final,
            init2mid_annealing_step=init2mid_annealing_step,
            max_step=self.max_train_step)
        self.assign_interval = assign_interval

        def _create_net(name, representation_net=None):
            return ValueNetwork(
                name=name,
                representation_net=representation_net,
                value_net_type=OutputNetworkType.CRITIC_QVALUE_ALL,
                value_net_kwargs=dict(output_shape=self.a_dim,
                                      network_settings=network_settings))

        self.q_net = _create_net('dqn_q_net', self._representation_net)
        self._representation_target_net = self._create_representation_net(
            '_representation_target_net')
        self.q_target_net = _create_net('dqn_q_target_net',
                                        self._representation_target_net)

        update_target_net_weights(self.q_target_net.weights,
                                  self.q_net.weights)
        self.lr = self.init_lr(lr)
        self.optimizer = self.init_optimizer(self.lr)

        self._worker_params_dict.update(self.q_net._policy_models)

        self._all_params_dict.update(self.q_net._all_models)
        self._all_params_dict.update(optimizer=self.optimizer)
        self._model_post_process()

    def choose_action(self,
                      s: np.ndarray,
                      visual_s: np.ndarray,
                      evaluation: bool = False) -> np.ndarray:
        if np.random.uniform() < self.expl_expt_mng.get_esp(
                self.train_step, evaluation=evaluation):
            a = np.random.randint(0, self.a_dim, self.n_agents)
        else:
            a, self.cell_state = self._get_action(s, visual_s, self.cell_state)
            a = a.numpy()
        return a

    @tf.function
    def _get_action(self, s, visual_s, cell_state):
        with tf.device(self.device):
            q_values, cell_state = self.q_net(s,
                                              visual_s,
                                              cell_state=cell_state)
        return tf.argmax(q_values, axis=1), cell_state

    def _target_params_update(self):
        if self.global_step % self.assign_interval == 0:
            update_target_net_weights(self.q_target_net.weights,
                                      self.q_net.weights)

    def learn(self, **kwargs) -> NoReturn:
        self.train_step = kwargs.get('train_step')
        for i in range(self.train_times_per_step):
            self._learn(
                function_dict={
                    'summary_dict':
                    dict([['LEARNING_RATE/lr',
                           self.lr(self.train_step)]]),
                    'train_data_list':
                    ['s', 'visual_s', 'a', 'r', 's_', 'visual_s_', 'done']
                })

    @tf.function(experimental_relax_shapes=True)
    def _train(self, memories, isw, cell_state):
        s, visual_s, a, r, s_, visual_s_, done = memories
        with tf.device(self.device):
            with tf.GradientTape() as tape:
                q, _ = self.q_net(s, visual_s, cell_state=cell_state)
                q_next, _ = self.q_target_net(s_,
                                              visual_s_,
                                              cell_state=cell_state)
                q_eval = tf.reduce_sum(tf.multiply(q, a),
                                       axis=1,
                                       keepdims=True)
                q_target = tf.stop_gradient(
                    r + self.gamma *
                    (1 - done) * tf.reduce_max(q_next, axis=1, keepdims=True))
                td_error = q_eval - q_target
                q_loss = tf.reduce_mean(tf.square(td_error) * isw)
            grads = tape.gradient(q_loss, self.q_net.trainable_variables)
            self.optimizer.apply_gradients(
                zip(grads, self.q_net.trainable_variables))
            self.global_step.assign_add(1)
            return td_error, dict(
                [['LOSS/loss', q_loss],
                 ['Statistics/q_max',
                  tf.reduce_max(q_eval)],
                 ['Statistics/q_min',
                  tf.reduce_min(q_eval)],
                 ['Statistics/q_mean',
                  tf.reduce_mean(q_eval)]])

    @tf.function(experimental_relax_shapes=True)
    def _cal_td(self, memories, cell_state):
        s, visual_s, a, r, s_, visual_s_, done = memories
        with tf.device(self.device):
            q, _ = self.q_net(s, visual_s, cell_state=cell_state)
            q_next, _ = self.q_target_net(s_, visual_s_, cell_state=cell_state)
            q_eval = tf.reduce_sum(tf.multiply(q, a), axis=1, keepdims=True)
            q_target = tf.stop_gradient(
                r + self.gamma *
                (1 - done) * tf.reduce_max(q_next, axis=1, keepdims=True))
            td_error = q_eval - q_target
        return td_error

    def apex_learn(self, train_step, data, priorities):
        self.train_step = train_step
        return self._apex_learn(function_dict={
            'summary_dict':
            dict([['LEARNING_RATE/lr',
                   self.lr(self.train_step)]])
        },
                                data=data,
                                priorities=priorities)

    def apex_cal_td(self, data):
        return self._apex_cal_td(data=data)
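
The epsilon-greedy branch in choose_action relies on ExplorationExploitationClass, whose implementation is not shown here. The snippet below is a minimal, self-contained sketch of the two-stage schedule its constructor arguments suggest (anneal eps_init to eps_mid over init2mid_annealing_step steps, then eps_mid to eps_final over the remaining steps); the class name LinearTwoStageEpsilon and the linear interpolation are assumptions, only the get_esp signature is taken from the calls above.

class LinearTwoStageEpsilon:
    def __init__(self, eps_init=1.0, eps_mid=0.2, eps_final=0.01,
                 init2mid_annealing_step=1000, max_step=10000):
        self.eps_init = eps_init
        self.eps_mid = eps_mid
        self.eps_final = eps_final
        self.init2mid_annealing_step = init2mid_annealing_step
        self.max_step = max_step

    def get_esp(self, step, evaluation=False):
        # act greedily during evaluation (assumption)
        if evaluation:
            return 0.0
        # stage 1: eps_init -> eps_mid
        if step < self.init2mid_annealing_step:
            frac = step / self.init2mid_annealing_step
            return self.eps_init + frac * (self.eps_mid - self.eps_init)
        # stage 2: eps_mid -> eps_final over the remaining training steps
        frac = min(1.0, (step - self.init2mid_annealing_step)
                   / max(1, self.max_step - self.init2mid_annealing_step))
        return self.eps_mid + frac * (self.eps_final - self.eps_mid)
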
Example 2
class MAXSQN(make_off_policy_class(mode='share')):
    '''
    https://github.com/createamind/DRL/blob/master/spinup/algos/maxsqn/maxsqn.py
    '''
    def __init__(self,
                 s_dim,
                 visual_sources,
                 visual_resolution,
                 a_dim,
                 is_continuous,
                 alpha=0.2,
                 beta=0.1,
                 ployak=0.995,
                 eps_init=1,
                 eps_mid=0.2,
                 eps_final=0.01,
                 init2mid_annealing_step=1000,
                 use_epsilon=False,
                 q_lr=5.0e-4,
                 alpha_lr=5.0e-4,
                 auto_adaption=True,
                 hidden_units=[32, 32],
                 **kwargs):
        assert not is_continuous, 'MAXSQN only supports discrete action spaces'
        super().__init__(s_dim=s_dim,
                         visual_sources=visual_sources,
                         visual_resolution=visual_resolution,
                         a_dim=a_dim,
                         is_continuous=is_continuous,
                         **kwargs)
        self.expl_expt_mng = ExplorationExploitationClass(
            eps_init=eps_init,
            eps_mid=eps_mid,
            eps_final=eps_final,
            init2mid_annealing_step=init2mid_annealing_step,
            max_step=self.max_train_step)
        self.use_epsilon = use_epsilon
        self.ployak = ployak
        self.log_alpha = alpha if not auto_adaption else tf.Variable(
            initial_value=0.0,
            name='log_alpha',
            dtype=tf.float32,
            trainable=True)
        self.auto_adaption = auto_adaption
        self.target_entropy = beta * np.log(self.a_dim)

        def _q_net():
            return Critic(self.feat_dim, self.a_dim, hidden_units)

        self.critic_net = DoubleQ(_q_net)
        self.critic_target_net = DoubleQ(_q_net)
        self.critic_tv = self.critic_net.trainable_variables + self.other_tv
        update_target_net_weights(self.critic_target_net.weights,
                                  self.critic_net.weights)
        self.q_lr, self.alpha_lr = map(self.init_lr, [q_lr, alpha_lr])
        self.optimizer_critic, self.optimizer_alpha = map(
            self.init_optimizer, [self.q_lr, self.alpha_lr])

        self.model_recorder(
            dict(critic_net=self.critic_net,
                 optimizer_critic=self.optimizer_critic,
                 optimizer_alpha=self.optimizer_alpha))

    def show_logo(self):
        self.logger.info('''
   xx     xx                                      xxxxxx         xxxxxx       xxxx   xx   
   xxx   xxx                                     xxx xxx        xxxx xxx      xxxx   xx   
   xxx   xxx        xxxxx          x   xx        xx             xx    xx      xxxxx  xx   
   xxxx  xxx       xxxxxx          xx xxx        xxxxxx         xx    xxx     xx xxx xx   
   xxxx xx x        x  xxx         xxxxx          xxxxxx       xx      xx     xx  xxxxx   
   xxxx xx x        xxxxxx          xxx               xxx      xxx  x xxx     xx   xxxx   
   xx xxx  x       xxx  xx          xxx          xx    xx       xx xxxxx      xx   xxxx   
   xx xxx  x       xx  xxx         xxxxx        xxxxxxxxx       xxx xxxx      xx    xxx   
   xx xxx  x       xxxxxxxx       xxx xxx        xxxxxxx         xxxxxxx      xx     xx   
                                                                  xxxxxxx                       
        ''')

    @property
    def alpha(self):
        return tf.exp(self.log_alpha)

    def choose_action(self, s, visual_s, evaluation=False):
        if self.use_epsilon and np.random.uniform(
        ) < self.expl_expt_mng.get_esp(self.train_step, evaluation=evaluation):
            a = np.random.randint(0, self.a_dim, self.n_agents)
        else:
            mu, pi, self.cell_state = self._get_action(s, visual_s,
                                                       self.cell_state)
            a = pi.numpy()
        return a

    @tf.function
    def _get_action(self, s, visual_s, cell_state):
        with tf.device(self.device):
            feat, cell_state = self.get_feature(s,
                                                visual_s,
                                                cell_state=cell_state,
                                                record_cs=True)
            q = self.critic_net.Q1(feat)
            cate_dist = tfp.distributions.Categorical(logits=q / self.alpha)
            pi = cate_dist.sample()
        return tf.argmax(q, axis=1), pi, cell_state

    def learn(self, **kwargs):
        self.train_step = kwargs.get('train_step')
        for i in range(self.train_times_per_step):
            self._learn(
                function_dict={
                    'train_function':
                    self.train,
                    'update_function':
                    lambda: update_target_net_weights(
                        self.critic_target_net.weights, self.critic_net.
                        weights, self.ployak),
                    'summary_dict':
                    dict([['LEARNING_RATE/q_lr',
                           self.q_lr(self.train_step)],
                          [
                              'LEARNING_RATE/alpha_lr',
                              self.alpha_lr(self.train_step)
                          ]])
                })

    @tf.function(experimental_relax_shapes=True)
    def train(self, memories, isw, crsty_loss, cell_state):
        ss, vvss, a, r, done = memories
        with tf.device(self.device):
            with tf.GradientTape() as tape:
                feat, feat_ = self.get_feature(ss,
                                               vvss,
                                               cell_state=cell_state,
                                               s_and_s_=True)
                q1, q2 = self.critic_net(feat)
                q1_eval = tf.reduce_sum(tf.multiply(q1, a),
                                        axis=1,
                                        keepdims=True)
                q2_eval = tf.reduce_sum(tf.multiply(q2, a),
                                        axis=1,
                                        keepdims=True)

                q1_target, q2_target = self.critic_target_net(feat_)
                q1_target_max = tf.reduce_max(q1_target, axis=1, keepdims=True)
                q1_target_log_probs = tf.nn.log_softmax(q1_target / self.alpha,
                                                        axis=1) + 1e-8
                q1_target_entropy = -tf.reduce_mean(
                    tf.reduce_sum(
                        tf.exp(q1_target_log_probs) * q1_target_log_probs,
                        axis=1,
                        keepdims=True))

                q2_target_max = tf.reduce_max(q2_target, axis=1, keepdims=True)
                # q2_target_log_probs = tf.nn.log_softmax(q2_target, axis=1)
                # q2_target_log_max = tf.reduce_max(q2_target_log_probs, axis=1, keepdims=True)

                q_target = tf.minimum(
                    q1_target_max,
                    q2_target_max) + self.alpha * q1_target_entropy
                dc_r = tf.stop_gradient(r + self.gamma * q_target * (1 - done))
                td_error1 = q1_eval - dc_r
                td_error2 = q2_eval - dc_r
                q1_loss = tf.reduce_mean(tf.square(td_error1) * isw)
                q2_loss = tf.reduce_mean(tf.square(td_error2) * isw)
                loss = 0.5 * (q1_loss + q2_loss) + crsty_loss
            loss_grads = tape.gradient(loss, self.critic_tv)
            self.optimizer_critic.apply_gradients(
                zip(loss_grads, self.critic_tv))
            if self.auto_adaption:
                with tf.GradientTape() as tape:
                    q1 = self.critic_net.Q1(feat)
                    q1_log_probs = tf.nn.log_softmax(q1 / self.alpha,
                                                     axis=1) + 1e-8
                    q1_entropy = -tf.reduce_mean(
                        tf.reduce_sum(tf.exp(q1_log_probs) * q1_log_probs,
                                      axis=1,
                                      keepdims=True))
                    alpha_loss = -tf.reduce_mean(
                        self.alpha *
                        tf.stop_gradient(self.target_entropy - q1_entropy))
                alpha_grad = tape.gradient(alpha_loss, self.log_alpha)
                self.optimizer_alpha.apply_gradients([(alpha_grad,
                                                       self.log_alpha)])
            self.global_step.assign_add(1)
            summaries = dict(
                [['LOSS/loss', loss], ['Statistics/log_alpha', self.log_alpha],
                 ['Statistics/alpha', self.alpha],
                 ['Statistics/q1_entropy', q1_entropy],
                 ['Statistics/q_min',
                  tf.reduce_mean(tf.minimum(q1, q2))],
                 ['Statistics/q_mean', tf.reduce_mean(q1)],
                 ['Statistics/q_max',
                  tf.reduce_mean(tf.maximum(q1, q2))]])
            if self.auto_adaption:
                summaries.update({'LOSS/alpha_loss': alpha_loss})
            return (td_error1 + td_error2) / 2, summaries
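
The soft target assembled in train above can be written out on its own: the next-state value is the minimum of the two target heads' greedy values plus alpha times the entropy of the Boltzmann policy softmax(Q1/alpha), averaged over the batch exactly as the class does. A small NumPy sketch, not the repository's code:

import numpy as np

def maxsqn_soft_target(q1_next, q2_next, r, done, alpha=0.2, gamma=0.99):
    # q1_next, q2_next: [B, A] target-network Q-values for s'; r, done: [B, 1]
    q1_max = q1_next.max(axis=1, keepdims=True)   # [B, 1]
    q2_max = q2_next.max(axis=1, keepdims=True)   # [B, 1]
    logits = q1_next / alpha
    z = logits - logits.max(axis=1, keepdims=True)                 # stable log-softmax
    log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    entropy = -(np.exp(log_probs) * log_probs).sum(axis=1).mean()  # scalar, batch mean as in the class
    v_next = np.minimum(q1_max, q2_max) + alpha * entropy          # [B, 1]
    return r + gamma * (1.0 - done) * v_next                       # [B, 1]
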
Example 3
class BootstrappedDQN(make_off_policy_class(mode='share')):
    '''
    Deep Exploration via Bootstrapped DQN, http://arxiv.org/abs/1602.04621
    '''
    def __init__(self,
                 s_dim,
                 visual_sources,
                 visual_resolution,
                 a_dim,
                 is_continuous,
                 lr=5.0e-4,
                 eps_init=1,
                 eps_mid=0.2,
                 eps_final=0.01,
                 init2mid_annealing_step=1000,
                 assign_interval=1000,
                 head_num=4,
                 hidden_units=[32, 32],
                 **kwargs):
        assert not is_continuous, 'Bootstrapped DQN only supports discrete action spaces'
        super().__init__(s_dim=s_dim,
                         visual_sources=visual_sources,
                         visual_resolution=visual_resolution,
                         a_dim=a_dim,
                         is_continuous=is_continuous,
                         **kwargs)
        self.expl_expt_mng = ExplorationExploitationClass(
            eps_init=eps_init,
            eps_mid=eps_mid,
            eps_final=eps_final,
            init2mid_annealing_step=init2mid_annealing_step,
            max_step=self.max_train_step)
        self.assign_interval = assign_interval
        self.head_num = head_num
        self._probs = [1. / head_num for _ in range(head_num)]
        self.now_head = 0

        def _q_net():
            return NetWork(self.feat_dim, self.a_dim, self.head_num,
                           hidden_units)

        self.q_net = _q_net()
        self.q_target_net = _q_net()
        self.critic_tv = self.q_net.trainable_variables + self.other_tv
        update_target_net_weights(self.q_target_net.weights,
                                  self.q_net.weights)
        self.lr = self.init_lr(lr)
        self.optimizer = self.init_optimizer(self.lr)

        self.model_recorder(dict(model=self.q_net, optimizer=self.optimizer))

    def show_logo(self):
        self.logger.info('''
    xxxxxxx                      xxxxxxxx         xxxxxx      xxxx   xxxx  
     xx xxxx                      xxxxxxxx       xxx xxxx       xxx    x   
     xx  xxx                      xx    xxx     xxx   xxxx      xxxx   x   
     xx  xxx                      xx    xxx     xxx    xxx      xxxxx  x   
     xxxxxx      xxx xxxx xxx     xx     xx     xx     xxx      x xxxx x   
     xx xxxx     xxx xxxx xxx     xx     xx     xxx    xxx      x  xxxxx   
     xx  xxx     xxx  xx  xxx     xx    xxx     xxx    xxx      x   xxxx   
     xx   xx                      xx   xxxx     xxx   xxx       x    xxx   
     xx xxxx                      xxxxxxxx       xxxxxxxx      xxx    xx   
    xxxxxxxx                     xxxxxxx          xxxxx                    
                                                    xxxx                   
                                                      xxx 
        ''')

    def reset(self):
        super().reset()
        self.now_head = np.random.randint(self.head_num)

    def choose_action(self, s, visual_s, evaluation=False):
        if np.random.uniform() < self.expl_expt_mng.get_esp(
                self.train_step, evaluation=evaluation):
            a = np.random.randint(0, self.a_dim, self.n_agents)
        else:
            q, self.cell_state = self._get_action(s, visual_s, self.cell_state)
            q = q.numpy()
            a = np.argmax(q[self.now_head],
                          axis=1)  # [H, B, A] => [B, A] => [B, ]
        return a

    @tf.function
    def _get_action(self, s, visual_s, cell_state):
        with tf.device(self.device):
            feat, cell_state = self.get_feature(s,
                                                visual_s,
                                                cell_state=cell_state,
                                                record_cs=True)
            q_values = self.q_net(feat)  # [H, B, A]
        return q_values, cell_state

    def learn(self, **kwargs):
        self.train_step = kwargs.get('train_step')

        def _update():
            if self.global_step % self.assign_interval == 0:
                update_target_net_weights(self.q_target_net.weights,
                                          self.q_net.weights)

        for i in range(self.train_times_per_step):
            self._learn(
                function_dict={
                    'train_function':
                    self.train,
                    'update_function':
                    _update,
                    'summary_dict':
                    dict([['LEARNING_RATE/lr',
                           self.lr(self.train_step)]])
                })

    @tf.function(experimental_relax_shapes=True)
    def train(self, memories, isw, crsty_loss, cell_state):
        ss, vvss, a, r, done = memories
        batch_size = tf.shape(a)[0]
        with tf.device(self.device):
            with tf.GradientTape() as tape:
                feat, feat_ = self.get_feature(ss,
                                               vvss,
                                               cell_state=cell_state,
                                               s_and_s_=True)
                q = self.q_net(feat)  # [H, B, A]
                q_next = self.q_target_net(feat_)  # [H, B, A]
                q_eval = tf.reduce_sum(
                    tf.multiply(q, a), axis=-1,
                    keepdims=True)  # [H, B, A] * [B, A] => [H, B, 1]
                q_target = tf.stop_gradient(
                    r + self.gamma *
                    (1 - done) * tf.reduce_max(q_next, axis=-1, keepdims=True))
                td_error = q_eval - q_target  # [H, B, 1]
                td_error = tf.reduce_sum(td_error, axis=-1)  # [H, B]

                # sample a Bernoulli bootstrap mask so that each head only
                # trains on its own sub-sample of the batch
                mask_dist = tfp.distributions.Bernoulli(probs=self._probs)
                mask = tf.cast(
                    tf.transpose(mask_dist.sample(batch_size), [1, 0]),
                    tf.float32)  # [H, B]
                q_loss = tf.reduce_mean(
                    tf.square(td_error) * mask * isw) + crsty_loss
            grads = tape.gradient(q_loss, self.critic_tv)
            self.optimizer.apply_gradients(zip(grads, self.critic_tv))
            self.global_step.assign_add(1)
            return tf.reduce_mean(td_error, axis=0), dict([  # [H, B] => [B, ]
                ['LOSS/loss', q_loss],
                ['Statistics/q_max', tf.reduce_max(q_eval)],
                ['Statistics/q_min', tf.reduce_min(q_eval)],
                ['Statistics/q_mean',
                 tf.reduce_mean(q_eval)]
            ])
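
The Bernoulli mask sampled in train gives each of the head_num heads its own bootstrapped view of the replay batch, which is the source of the deep exploration in Bootstrapped DQN. A toy NumPy illustration of the masking (assumed shapes, not repository code):

import numpy as np

def masked_head_loss(td_error, p=0.5, seed=0):
    # td_error: [H, B] per-head TD errors for one replay batch
    H, B = td_error.shape
    rng = np.random.default_rng(seed)
    mask = rng.binomial(1, p, size=(H, B)).astype(np.float64)   # [H, B]
    # each head averages its squared TD error only over the transitions it was dealt
    per_head = (np.square(td_error) * mask).sum(axis=1) / np.maximum(mask.sum(axis=1), 1.0)
    return per_head.mean()
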
Example 4
class OC(Off_Policy):
    '''
    The Option-Critic Architecture. http://arxiv.org/abs/1609.05140
    '''
    def __init__(self,
                 envspec,
                 q_lr=5.0e-3,
                 intra_option_lr=5.0e-4,
                 termination_lr=5.0e-4,
                 use_eps_greedy=False,
                 eps_init=1,
                 eps_mid=0.2,
                 eps_final=0.01,
                 init2mid_annealing_step=1000,
                 boltzmann_temperature=1.0,
                 options_num=4,
                 ent_coff=0.01,
                 double_q=False,
                 use_baseline=True,
                 terminal_mask=True,
                 termination_regularizer=0.01,
                 assign_interval=1000,
                 network_settings={
                     'q': [32, 32],
                     'intra_option': [32, 32],
                     'termination': [32, 32]
                 },
                 **kwargs):
        super().__init__(envspec=envspec, **kwargs)
        self.expl_expt_mng = ExplorationExploitationClass(
            eps_init=eps_init,
            eps_mid=eps_mid,
            eps_final=eps_final,
            init2mid_annealing_step=init2mid_annealing_step,
            max_step=self.max_train_step)
        self.assign_interval = assign_interval
        self.options_num = options_num
        self.termination_regularizer = termination_regularizer
        self.ent_coff = ent_coff
        self.use_baseline = use_baseline
        self.terminal_mask = terminal_mask
        self.double_q = double_q
        self.boltzmann_temperature = boltzmann_temperature
        self.use_eps_greedy = use_eps_greedy

        def _create_net(name, representation_net=None):
            return ValueNetwork(
                name=name,
                representation_net=representation_net,
                value_net_type=OutputNetworkType.CRITIC_QVALUE_ALL,
                value_net_kwargs=dict(output_shape=self.options_num,
                                      network_settings=network_settings['q']))

        self.q_net = _create_net('q_net', self._representation_net)
        self._representation_target_net = self._create_representation_net(
            '_representation_target_net')
        self.q_target_net = _create_net('q_target_net',
                                        self._representation_target_net)

        self.intra_option_net = ValueNetwork(
            name='intra_option_net',
            value_net_type=OutputNetworkType.OC_INTRA_OPTION,
            value_net_kwargs=dict(
                vector_dim=self._representation_net.h_dim,
                output_shape=self.a_dim,
                options_num=self.options_num,
                network_settings=network_settings['intra_option']))
        self.termination_net = ValueNetwork(
            name='termination_net',
            value_net_type=OutputNetworkType.CRITIC_QVALUE_ALL,
            value_net_kwargs=dict(
                vector_dim=self._representation_net.h_dim,
                output_shape=self.options_num,
                network_settings=network_settings['termination'],
                out_activation='sigmoid'))

        self.actor_tv = self.intra_option_net.trainable_variables
        if self.is_continuous:
            self.log_std = tf.Variable(initial_value=-0.5 * np.ones(
                (self.options_num, self.a_dim), dtype=np.float32),
                                       trainable=True)  # [P, A]
            self.actor_tv += [self.log_std]
        update_target_net_weights(self.q_target_net.weights,
                                  self.q_net.weights)

        self.q_lr, self.intra_option_lr, self.termination_lr = map(
            self.init_lr, [q_lr, intra_option_lr, termination_lr])
        self.q_optimizer = self.init_optimizer(self.q_lr, clipvalue=5.)
        self.intra_option_optimizer = self.init_optimizer(self.intra_option_lr,
                                                          clipvalue=5.)
        self.termination_optimizer = self.init_optimizer(self.termination_lr,
                                                         clipvalue=5.)

        self._worker_params_dict.update(self.q_net._policy_models)
        self._worker_params_dict.update(self.intra_option_net._policy_models)
        self._worker_params_dict.update(self.termination_net._policy_models)

        self._all_params_dict.update(self.q_net._all_models)
        self._all_params_dict.update(self.intra_option_net._all_models)
        self._all_params_dict.update(self.termination_net._all_models)
        self._all_params_dict.update(
            q_optimizer=self.q_optimizer,
            intra_option_optimizer=self.intra_option_optimizer,
            termination_optimizer=self.termination_optimizer)
        self._model_post_process()

    def _generate_random_options(self):
        return tf.constant(np.random.randint(0, self.options_num,
                                             self.n_agents),
                           dtype=tf.int32)

    def choose_action(self, s, visual_s, evaluation=False):
        if not hasattr(self, 'options'):
            self.options = self._generate_random_options()
        self.last_options = self.options

        a, self.options, self.cell_state = self._get_action(
            s, visual_s, self.cell_state, self.options)
        if self.use_eps_greedy:
            if np.random.uniform() < self.expl_expt_mng.get_esp(
                    self.train_step, evaluation=evaluation):  # epsilon greedy
                self.options = self._generate_random_options()
        a = a.numpy()
        return a

    @tf.function
    def _get_action(self, s, visual_s, cell_state, options):
        with tf.device(self.device):
            feat, cell_state = self._representation_net(s,
                                                        visual_s,
                                                        cell_state=cell_state)
            q = self.q_net.value_net(feat)  # [B, P]
            pi = self.intra_option_net.value_net(feat)  # [B, P, A]
            beta = self.termination_net.value_net(feat)  # [B, P]
            options_onehot = tf.one_hot(options,
                                        self.options_num,
                                        dtype=tf.float32)  # [B, P]
            options_onehot_expanded = tf.expand_dims(options_onehot,
                                                     axis=-1)  # [B, P, 1]
            pi = tf.reduce_sum(pi * options_onehot_expanded, axis=1)  # [B, A]
            if self.is_continuous:
                log_std = tf.gather(self.log_std, options)
                mu = tf.math.tanh(pi)
                a, _ = gaussian_clip_rsample(mu, log_std)
            else:
                pi = pi / self.boltzmann_temperature
                dist = tfp.distributions.Categorical(
                    logits=tf.nn.log_softmax(pi))  # [B, ]
                a = dist.sample()
            max_options = tf.cast(tf.argmax(q, axis=-1),
                                  dtype=tf.int32)  # [B, P] => [B, ]
            if self.use_eps_greedy:
                new_options = max_options
            else:
                beta_probs = tf.reduce_sum(beta * options_onehot,
                                           axis=1)  # [B, P] => [B,]
                beta_dist = tfp.distributions.Bernoulli(probs=beta_probs)
                new_options = tf.where(beta_dist.sample() < 1, options,
                                       max_options)
        return a, new_options, cell_state

    def _target_params_update(self):
        if self.global_step % self.assign_interval == 0:
            update_target_net_weights(self.q_target_net.weights,
                                      self.q_net.weights)

    def learn(self, **kwargs):
        self.train_step = kwargs.get('train_step')

        for i in range(self.train_times_per_step):
            self._learn(
                function_dict={
                    'sample_data_list': [
                        's', 'visual_s', 'a', 'r', 's_', 'visual_s_', 'done',
                        'last_options', 'options'
                    ],
                    'train_data_list': [
                        's', 'visual_s', 'a', 'r', 's_', 'visual_s_', 'done',
                        'last_options', 'options'
                    ],
                    'summary_dict':
                    dict([['LEARNING_RATE/q_lr',
                           self.q_lr(self.train_step)],
                          [
                              'LEARNING_RATE/intra_option_lr',
                              self.intra_option_lr(self.train_step)
                          ],
                          [
                              'LEARNING_RATE/termination_lr',
                              self.termination_lr(self.train_step)
                          ], ['Statistics/option', self.options[0]]])
                })

    @tf.function(experimental_relax_shapes=True)
    def _train(self, memories, isw, cell_state):
        s, visual_s, a, r, s_, visual_s_, done, last_options, options = memories
        last_options = tf.cast(last_options, tf.int32)
        options = tf.cast(options, tf.int32)
        with tf.device(self.device):
            with tf.GradientTape(persistent=True) as tape:
                feat, _ = self._representation_net(s,
                                                   visual_s,
                                                   cell_state=cell_state)
                feat_, _ = self._representation_target_net(
                    s_, visual_s_, cell_state=cell_state)
                q = self.q_net.value_net(feat)  # [B, P]
                pi = self.intra_option_net.value_net(feat)  # [B, P, A]
                beta = self.termination_net.value_net(feat)  # [B, P]
                q_next = self.q_target_net.value_net(
                    feat_)  # [B, P], [B, P, A], [B, P]
                beta_next = self.termination_net.value_net(feat_)  # [B, P]
                options_onehot = tf.one_hot(options,
                                            self.options_num,
                                            dtype=tf.float32)  # [B,] => [B, P]

                q_s = qu_eval = tf.reduce_sum(q * options_onehot,
                                              axis=-1,
                                              keepdims=True)  # [B, 1]
                beta_s_ = tf.reduce_sum(beta_next * options_onehot,
                                        axis=-1,
                                        keepdims=True)  # [B, 1]
                q_s_ = tf.reduce_sum(q_next * options_onehot,
                                     axis=-1,
                                     keepdims=True)  # [B, 1]
                # https://github.com/jeanharb/option_critic/blob/5d6c81a650a8f452bc8ad3250f1f211d317fde8c/neural_net.py#L94
                if self.double_q:
                    q_ = self.q_net.value_net(
                        feat)  # [B, P], [B, P, A], [B, P]
                    max_a_idx = tf.one_hot(
                        tf.argmax(q_, axis=-1),
                        self.options_num,
                        dtype=tf.float32)  # [B, P] => [B, ] => [B, P]
                    q_s_max = tf.reduce_sum(q_next * max_a_idx,
                                            axis=-1,
                                            keepdims=True)  # [B, 1]
                else:
                    q_s_max = tf.reduce_max(q_next, axis=-1,
                                            keepdims=True)  # [B, 1]
                u_target = (1 - beta_s_) * q_s_ + beta_s_ * q_s_max  # [B, 1]
                qu_target = tf.stop_gradient(r + self.gamma *
                                             (1 - done) * u_target)
                td_error = qu_target - qu_eval  # gradient : q
                q_loss = tf.reduce_mean(tf.square(td_error) *
                                        isw)  # [B, 1] => 1

                # https://github.com/jeanharb/option_critic/blob/5d6c81a650a8f452bc8ad3250f1f211d317fde8c/neural_net.py#L130
                if self.use_baseline:
                    adv = tf.stop_gradient(qu_target - qu_eval)
                else:
                    adv = tf.stop_gradient(qu_target)
                options_onehot_expanded = tf.expand_dims(
                    options_onehot, axis=-1)  # [B, P] => [B, P, 1]
                pi = tf.reduce_sum(pi * options_onehot_expanded,
                                   axis=1)  # [B, P, A] => [B, A]
                if self.is_continuous:
                    log_std = tf.gather(self.log_std, options)
                    mu = tf.math.tanh(pi)
                    log_p = gaussian_likelihood_sum(a, mu, log_std)
                    entropy = gaussian_entropy(log_std)
                else:
                    pi = pi / self.boltzmann_temperature
                    log_pi = tf.nn.log_softmax(pi, axis=-1)  # [B, A]
                    entropy = -tf.reduce_sum(tf.exp(log_pi) * log_pi,
                                             axis=1,
                                             keepdims=True)  # [B, 1]
                    log_p = tf.reduce_sum(a * log_pi, axis=-1,
                                          keepdims=True)  # [B, 1]
                pi_loss = tf.reduce_mean(
                    -(log_p * adv + self.ent_coff * entropy)
                )  # [B, 1] * [B, 1] => [B, 1] => 1

                last_options_onehot = tf.one_hot(
                    last_options, self.options_num,
                    dtype=tf.float32)  # [B,] => [B, P]
                beta_s = tf.reduce_sum(beta * last_options_onehot,
                                       axis=-1,
                                       keepdims=True)  # [B, 1]
                if self.use_eps_greedy:
                    v_s = tf.reduce_max(
                        q, axis=-1,
                        keepdims=True) - self.termination_regularizer  # [B, 1]
                else:
                    v_s = (1 - beta_s) * q_s + beta_s * tf.reduce_max(
                        q, axis=-1, keepdims=True)  # [B, 1]
                    # v_s = tf.reduce_mean(q, axis=-1, keepdims=True)   # [B, 1]
                beta_loss = beta_s * tf.stop_gradient(q_s - v_s)  # [B, 1]
                # https://github.com/lweitkamp/option-critic-pytorch/blob/0c57da7686f8903ed2d8dded3fae832ee9defd1a/option_critic.py#L238
                if self.terminal_mask:
                    beta_loss *= (1 - done)
                beta_loss = tf.reduce_mean(beta_loss)  # [B, 1] => 1

            q_grads = tape.gradient(q_loss, self.q_net.trainable_variables)
            intra_option_grads = tape.gradient(pi_loss, self.actor_tv)
            termination_grads = tape.gradient(
                beta_loss, self.termination_net.trainable_variables)
            self.q_optimizer.apply_gradients(
                zip(q_grads, self.q_net.trainable_variables))
            self.intra_option_optimizer.apply_gradients(
                zip(intra_option_grads, self.actor_tv))
            self.termination_optimizer.apply_gradients(
                zip(termination_grads,
                    self.termination_net.trainable_variables))
            self.global_step.assign_add(1)
            return td_error, dict(
                [['LOSS/q_loss', tf.reduce_mean(q_loss)],
                 ['LOSS/pi_loss', tf.reduce_mean(pi_loss)],
                 ['LOSS/beta_loss',
                  tf.reduce_mean(beta_loss)],
                 ['Statistics/q_option_max',
                  tf.reduce_max(q_s)],
                 ['Statistics/q_option_min',
                  tf.reduce_min(q_s)],
                 ['Statistics/q_option_mean',
                  tf.reduce_mean(q_s)]])

    def store_data(self, s, visual_s, a, r, s_, visual_s_, done):
        """
        for off-policy training, use this function to store <s, a, r, s_, done> into ReplayBuffer.
        """
        assert isinstance(a, np.ndarray), "store requires action to be an np.ndarray"
        assert isinstance(r, np.ndarray), "store requires reward to be an np.ndarray"
        assert isinstance(done, np.ndarray), "store requires done to be an np.ndarray"
        self._running_average(s)
        self.data.add(
            s,
            visual_s,
            a,
            r[:, np.newaxis],  # expand reward to [B, 1]
            s_,
            visual_s_,
            done[:, np.newaxis],  # expand done to [B, 1]
            self.last_options,
            self.options)

    def no_op_store(self, s, visual_s, a, r, s_, visual_s_, done):
        pass
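
The critic target built in _train is the option-value backup from the Option-Critic paper: with probability beta the current option terminates and the agent falls back on the greedy option, otherwise it keeps bootstrapping through the same option. A NumPy sketch of just that target (not the repository's code):

import numpy as np

def option_value_target(q_next, beta_next, options, r, done, gamma=0.99):
    # q_next, beta_next: [B, P] target-net option values / termination probs for s'
    # options: [B] integer index of the option being executed; r, done: [B, 1]
    idx = np.arange(len(options))
    q_sw = q_next[idx, options][:, None]          # Q(s', w)          [B, 1]
    beta_sw = beta_next[idx, options][:, None]    # beta_w(s')        [B, 1]
    q_max = q_next.max(axis=1, keepdims=True)     # max_w' Q(s', w')  [B, 1]
    u = (1.0 - beta_sw) * q_sw + beta_sw * q_max  # U(s', w)          [B, 1]
    return r + gamma * (1.0 - done) * u           # [B, 1]
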
Example 5
class QRDQN(make_off_policy_class(mode='share')):
    '''
    Quantile Regression DQN
    Distributional Reinforcement Learning with Quantile Regression, https://arxiv.org/abs/1710.10044
    No double, no dueling, no noisy net.
    '''
    def __init__(self,
                 s_dim,
                 visual_sources,
                 visual_resolution,
                 a_dim,
                 is_continuous,
                 nums=20,
                 huber_delta=1.,
                 lr=5.0e-4,
                 eps_init=1,
                 eps_mid=0.2,
                 eps_final=0.01,
                 init2mid_annealing_step=1000,
                 assign_interval=1000,
                 hidden_units=[128, 128],
                 **kwargs):
        assert not is_continuous, 'QRDQN only supports discrete action spaces'
        assert nums > 0
        super().__init__(s_dim=s_dim,
                         visual_sources=visual_sources,
                         visual_resolution=visual_resolution,
                         a_dim=a_dim,
                         is_continuous=is_continuous,
                         **kwargs)
        self.nums = nums
        self.huber_delta = huber_delta
        self.quantiles = tf.reshape(
            tf.constant((2 * np.arange(self.nums) + 1) / (2.0 * self.nums),
                        dtype=tf.float32), [-1, self.nums])  # [1, N]
        self.batch_quantiles = tf.tile(self.quantiles,
                                       [self.a_dim, 1])  # [1, N] => [A, N]
        self.expl_expt_mng = ExplorationExploitationClass(
            eps_init=eps_init,
            eps_mid=eps_mid,
            eps_final=eps_final,
            init2mid_annealing_step=init2mid_annealing_step,
            max_step=self.max_train_step)
        self.assign_interval = assign_interval

        def _net():
            return NetWork(self.feat_dim, self.a_dim, self.nums, hidden_units)

        self.q_dist_net = _net()
        self.q_target_dist_net = _net()
        self.critic_tv = self.q_dist_net.trainable_variables + self.other_tv
        update_target_net_weights(self.q_target_dist_net.weights,
                                  self.q_dist_net.weights)
        self.lr = self.init_lr(lr)
        self.optimizer = self.init_optimizer(self.lr)
        self.model_recorder(
            dict(model=self.q_dist_net, optimizer=self.optimizer))

    def show_logo(self):
        self.logger.info('''
     xxxxxx       xxxxxxx        xxxxxxxx         xxxxxx      xxxx   xxxx  
    xxx xxxx       xxxxxxx        xxxxxxxx       xxx xxxx       xxx    x   
   xxx   xxxx      xx  xxx        xx    xxx     xxx   xxxx      xxxx   x   
   xxx    xxx      xx  xxx        xx    xxx     xxx    xxx      xxxxx  x   
   xx     xxx      xxxxxx         xx     xx     xx     xxx      x xxxx x   
   xxx    xxx      xxxxxx         xx     xx     xxx    xxx      x  xxxxx   
   xxx    xxx      xx xxxx        xx    xxx     xxx    xxx      x   xxxx   
   xxx   xxx       xx  xxx        xx   xxxx     xxx   xxx       x    xxx   
    xxxxxxxx      xxxxx xxxx      xxxxxxxx       xxxxxxxx      xxx    xx   
     xxxxx        xxxxx xxxx     xxxxxxx          xxxxx                    
       xxxx                                         xxxx                   
         xxx                                          xxx     
        ''')

    def choose_action(self, s, visual_s, evaluation=False):
        if np.random.uniform() < self.expl_expt_mng.get_esp(
                self.train_step, evaluation=evaluation):
            a = np.random.randint(0, self.a_dim, self.n_agents)
        else:
            a, self.cell_state = self._get_action(s, visual_s, self.cell_state)
            a = a.numpy()
        return a

    @tf.function
    def _get_action(self, s, visual_s, cell_state):
        with tf.device(self.device):
            feat, cell_state = self.get_feature(s,
                                                visual_s,
                                                cell_state=cell_state,
                                                record_cs=True)
            q = self.get_q(feat)  # [B, A]
        return tf.argmax(q, axis=-1), cell_state  # [B, 1]

    def learn(self, **kwargs):
        self.train_step = kwargs.get('train_step')

        def _update():
            if self.global_step % self.assign_interval == 0:
                update_target_net_weights(self.q_target_dist_net.weights,
                                          self.q_dist_net.weights)

        for i in range(self.train_times_per_step):
            self._learn(
                function_dict={
                    'train_function':
                    self.train,
                    'update_function':
                    _update,
                    'summary_dict':
                    dict([['LEARNING_RATE/lr',
                           self.lr(self.train_step)]])
                })

    @tf.function(experimental_relax_shapes=True)
    def train(self, memories, isw, crsty_loss, cell_state):
        ss, vvss, a, r, done = memories
        batch_size = tf.shape(a)[0]
        with tf.device(self.device):
            with tf.GradientTape() as tape:
                feat, feat_ = self.get_feature(ss,
                                               vvss,
                                               cell_state=cell_state,
                                               s_and_s_=True)
                indexs = tf.reshape(tf.range(batch_size), [-1, 1])  # [B, 1]
                q_dist = self.q_dist_net(feat)  # [B, A, N]
                q_dist = tf.transpose(
                    tf.reduce_sum(tf.transpose(q_dist, [2, 0, 1]) * a,
                                  axis=-1), [1, 0])  # [B, N]
                target_q_dist = self.q_target_dist_net(feat_)  # [B, A, N]
                target_q = tf.reduce_sum(self.batch_quantiles * target_q_dist,
                                         axis=-1)  # [B, A, N] => [B, A]
                a_ = tf.reshape(
                    tf.cast(tf.argmax(target_q, axis=-1), dtype=tf.int32),
                    [-1, 1])  # [B, 1]
                target_q_dist = tf.gather_nd(target_q_dist,
                                             tf.concat([indexs, a_],
                                                       axis=-1))  # [B, N]
                target = tf.tile(r, tf.constant([1, self.nums])) \
                    + self.gamma * tf.multiply(target_q_dist,  # bootstrap from the target-net quantile values, [B, N]
                                               (1.0 - tf.tile(done, tf.constant([1, self.nums]))))  # [B, N] * [B, N] = [B, N]
                q_eval = tf.reduce_sum(q_dist * self.quantiles,
                                       axis=-1)  # [B, 1]
                q_target = tf.reduce_sum(target * self.quantiles,
                                         axis=-1)  # [B, 1]
                td_error = q_eval - q_target  # [B, 1]

                quantile_error = tf.expand_dims(
                    q_dist, axis=-1) - tf.expand_dims(
                        target, axis=1)  # [B, N, 1] - [B, 1, N] => [B, N, N]
                huber = huber_loss(quantile_error,
                                   delta=self.huber_delta)  # [B, N, N]
                huber_abs = tf.abs(
                    self.quantiles -
                    tf.where(quantile_error < 0, tf.ones_like(quantile_error),
                             tf.zeros_like(quantile_error))
                )  # [1, N] - [B, N, N] => [B, N, N]
                loss = tf.reduce_mean(huber_abs * huber,
                                      axis=-1)  # [B, N, N] => [B, N]
                loss = tf.reduce_sum(loss, axis=-1)  # [B, N] => [B, ]
                loss = tf.reduce_mean(loss * isw) + crsty_loss  # [B, ] => 1
            grads = tape.gradient(loss, self.critic_tv)
            self.optimizer.apply_gradients(zip(grads, self.critic_tv))
            self.global_step.assign_add(1)
            return td_error, dict(
                [['LOSS/loss', loss],
                 ['Statistics/q_max',
                  tf.reduce_max(q_eval)],
                 ['Statistics/q_min',
                  tf.reduce_min(q_eval)],
                 ['Statistics/q_mean',
                  tf.reduce_mean(q_eval)]])

    @tf.function(experimental_relax_shapes=True)
    def get_q(self, feat):
        with tf.device(self.device):
            return tf.reduce_sum(self.batch_quantiles * self.q_dist_net(feat),
                                 axis=-1)  # [B, A, N] => [B, A]
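
The loss assembled in train is the quantile Huber (quantile regression) loss of the QR-DQN paper: for every pairwise error u = theta_i - T theta_j between online and target quantiles, the Huber loss is weighted by |tau_i - 1[u < 0]|, summed over online quantiles and averaged over target samples. A compact NumPy sketch following the paper's indexing (an assumed helper, not the repository's huber_loss):

import numpy as np

def quantile_huber_loss(q_dist, target, taus, kappa=1.0):
    # q_dist: [B, N] online quantile values, target: [B, N] Bellman targets
    # taus: [N] quantile midpoints (2i + 1) / (2N)
    u = q_dist[:, :, None] - target[:, None, :]                  # [B, N, N]
    huber = np.where(np.abs(u) <= kappa,
                     0.5 * np.square(u),
                     kappa * (np.abs(u) - 0.5 * kappa))          # [B, N, N]
    weight = np.abs(taus[None, :, None] - (u < 0.0))             # [B, N, N]
    # mean over target samples j, sum over online quantiles i, mean over batch
    return (weight * huber).mean(axis=-1).sum(axis=-1).mean()
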
Example 6
class C51(Off_Policy):
    '''
    Category 51, https://arxiv.org/abs/1707.06887
    No double, no dueling, no noisy net.
    '''
    def __init__(self,
                 envspec,
                 v_min=-10,
                 v_max=10,
                 atoms=51,
                 lr=5.0e-4,
                 eps_init=1,
                 eps_mid=0.2,
                 eps_final=0.01,
                 init2mid_annealing_step=1000,
                 assign_interval=1000,
                 network_settings=[128, 128],
                 **kwargs):
        assert not envspec.is_continuous, 'C51 only supports discrete action spaces'
        super().__init__(envspec=envspec, **kwargs)
        self.v_min = v_min
        self.v_max = v_max
        self.atoms = atoms
        self.delta_z = (self.v_max - self.v_min) / (self.atoms - 1)
        self.z = tf.reshape(
            tf.constant(
                [self.v_min + i * self.delta_z for i in range(self.atoms)],
                dtype=tf.float32), [-1, self.atoms])  # [1, N]
        self.zb = tf.tile(self.z, tf.constant([self.a_dim, 1]))  # [A, N]
        self.expl_expt_mng = ExplorationExploitationClass(
            eps_init=eps_init,
            eps_mid=eps_mid,
            eps_final=eps_final,
            init2mid_annealing_step=init2mid_annealing_step,
            max_step=self.max_train_step)
        self.assign_interval = assign_interval

        def _create_net(name, representation_net=None):
            return ValueNetwork(
                name=name,
                representation_net=representation_net,
                value_net_type=OutputNetworkType.C51_DISTRIBUTIONAL,
                value_net_kwargs=dict(action_dim=self.a_dim,
                                      atoms=self.atoms,
                                      network_settings=network_settings))

        self.q_dist_net = _create_net('q_dist_net', self._representation_net)
        self._representation_target_net = self._create_representation_net(
            '_representation_target_net')
        self.q_target_dist_net = _create_net('q_target_dist_net',
                                             self._representation_target_net)
        update_target_net_weights(self.q_target_dist_net.weights,
                                  self.q_dist_net.weights)
        self.lr = self.init_lr(lr)
        self.optimizer = self.init_optimizer(self.lr)

        self._worker_params_dict.update(self.q_dist_net._policy_models)

        self._all_params_dict.update(self.q_dist_net._all_models)
        self._all_params_dict.update(optimizer=self.optimizer)
        self._model_post_process()

    def choose_action(self, obs, evaluation=False):
        if np.random.uniform() < self.expl_expt_mng.get_esp(
                self.train_step, evaluation=evaluation):
            a = np.random.randint(0, self.a_dim, self.n_agents)
        else:
            a, self.cell_state = self._get_action(obs, self.cell_state)
            a = a.numpy()
        return a

    @tf.function
    def _get_action(self, obs, cell_state):
        with tf.device(self.device):
            q_dist, cell_state = self.q_dist_net(obs, cell_state=cell_state)
            q = tf.reduce_sum(self.zb * q_dist, axis=-1)  # [B, A, N] => [B, A]
        return tf.argmax(q, axis=-1), cell_state  # [B, 1]

    def _target_params_update(self):
        if self.global_step % self.assign_interval == 0:
            update_target_net_weights(self.q_target_dist_net.weights,
                                      self.q_dist_net.weights)

    def learn(self, **kwargs):
        self.train_step = kwargs.get('train_step')
        for i in range(self.train_times_per_step):
            self._learn(
                function_dict={
                    'summary_dict':
                    dict([['LEARNING_RATE/lr',
                           self.lr(self.train_step)]])
                })

    @tf.function
    def _train(self, BATCH, isw, cell_state):
        batch_size = tf.shape(BATCH.action)[0]
        with tf.device(self.device):
            with tf.GradientTape() as tape:
                indexes = tf.reshape(tf.range(batch_size), [-1, 1])  # [B, 1]
                q_dist, _ = self.q_dist_net(BATCH.obs,
                                            cell_state=cell_state)  # [B, A, N]
                q_dist = tf.transpose(
                    tf.reduce_sum(tf.transpose(q_dist, [2, 0, 1]) *
                                  BATCH.action,
                                  axis=-1), [1, 0])  # [B, N]
                q_eval = tf.reduce_sum(q_dist * self.z, axis=-1)
                target_q_dist, _ = self.q_target_dist_net(
                    BATCH.obs_, cell_state=cell_state)  # [B, A, N]
                target_q = tf.reduce_sum(self.zb * target_q_dist,
                                         axis=-1)  # [B, A, N] => [B, A]
                a_ = tf.reshape(
                    tf.cast(tf.argmax(target_q, axis=-1), dtype=tf.int32),
                    [-1, 1])  # [B, 1]
                target_q_dist = tf.gather_nd(target_q_dist,
                                             tf.concat([indexes, a_],
                                                       axis=-1))  # [B, N]
                target = tf.tile(BATCH.reward, tf.constant([1, self.atoms])) \
                    + self.gamma * tf.multiply(self.z,   # [1, N]
                                               (1.0 - tf.tile(BATCH.done, tf.constant([1, self.atoms]))))  # [B, N], [1, N]* [B, N] = [B, N]
                target = tf.clip_by_value(target, self.v_min,
                                          self.v_max)  # [B, N]
                b = (target - self.v_min) / self.delta_z  # [B, N]
                u, l = tf.math.ceil(b), tf.math.floor(b)  # [B, N]
                u_id, l_id = tf.cast(u, tf.int32), tf.cast(l,
                                                           tf.int32)  # [B, N]
                u_minus_b, b_minus_l = u - b, b - l  # [B, N]
                index_help = tf.tile(indexes,
                                     tf.constant([1, self.atoms]))  # [B, N]
                index_help = tf.expand_dims(index_help, -1)  # [B, N, 1]
                u_id = tf.concat(
                    [index_help, tf.expand_dims(u_id, -1)],
                    axis=-1)  # [B, N, 2]
                l_id = tf.concat(
                    [index_help, tf.expand_dims(l_id, -1)],
                    axis=-1)  # [B, N, 2]
                _cross_entropy = tf.stop_gradient(target_q_dist * u_minus_b) * tf.math.log(tf.gather_nd(q_dist, l_id)) \
                    + tf.stop_gradient(target_q_dist * b_minus_l) * tf.math.log(tf.gather_nd(q_dist, u_id))  # [B, N]
                # tf.debugging.check_numerics(_cross_entropy, '_cross_entropy')
                cross_entropy = -tf.reduce_sum(_cross_entropy, axis=-1)  # [B,]
                # tf.debugging.check_numerics(cross_entropy, 'cross_entropy')
                loss = tf.reduce_mean(cross_entropy * isw)
                td_error = cross_entropy
            grads = tape.gradient(loss, self.q_dist_net.trainable_variables)
            self.optimizer.apply_gradients(
                zip(grads, self.q_dist_net.trainable_variables))
            self.global_step.assign_add(1)
            return td_error, dict(
                [['LOSS/loss', loss],
                 ['Statistics/q_max',
                  tf.reduce_max(q_eval)],
                 ['Statistics/q_min',
                  tf.reduce_min(q_eval)],
                 ['Statistics/q_mean',
                  tf.reduce_mean(q_eval)]])
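
The l/u index bookkeeping in _train implements the C51 categorical projection: each target atom Tz_j = r + gamma*(1 - done)*z_j is clipped to [v_min, v_max] and its probability mass is split between the two neighbouring support atoms. The class above folds that split directly into its cross-entropy loss; the same projection written out explicitly as a NumPy sketch (loop form for clarity, not the repository's vectorised code):

import numpy as np

def c51_project(next_probs, r, done, z, v_min=-10.0, v_max=10.0, gamma=0.99):
    # next_probs: [B, N] target distribution of the greedy next action
    # r, done: [B, 1]; z: [N] fixed support atoms
    B, N = next_probs.shape
    delta_z = (v_max - v_min) / (N - 1)
    tz = np.clip(r + gamma * (1.0 - done) * z[None, :], v_min, v_max)  # [B, N]
    b = (tz - v_min) / delta_z                                         # [B, N]
    l, u = np.floor(b).astype(int), np.ceil(b).astype(int)
    proj = np.zeros((B, N))
    for i in range(B):
        for j in range(N):
            if l[i, j] == u[i, j]:                  # Tz_j landed exactly on an atom
                proj[i, l[i, j]] += next_probs[i, j]
            else:
                proj[i, l[i, j]] += next_probs[i, j] * (u[i, j] - b[i, j])
                proj[i, u[i, j]] += next_probs[i, j] * (b[i, j] - l[i, j])
    return proj  # [B, N] projected target distribution
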
Example 7
class MAXSQN(Off_Policy):
    '''
    https://github.com/createamind/DRL/blob/master/spinup/algos/maxsqn/maxsqn.py
    '''
    def __init__(self,
                 envspec,
                 alpha=0.2,
                 beta=0.1,
                 ployak=0.995,
                 eps_init=1,
                 eps_mid=0.2,
                 eps_final=0.01,
                 init2mid_annealing_step=1000,
                 use_epsilon=False,
                 q_lr=5.0e-4,
                 alpha_lr=5.0e-4,
                 auto_adaption=True,
                 network_settings=[32, 32],
                 **kwargs):
        assert not envspec.is_continuous, 'MAXSQN only supports discrete action spaces'
        super().__init__(envspec=envspec, **kwargs)
        self.expl_expt_mng = ExplorationExploitationClass(
            eps_init=eps_init,
            eps_mid=eps_mid,
            eps_final=eps_final,
            init2mid_annealing_step=init2mid_annealing_step,
            max_step=self.max_train_step)
        self.use_epsilon = use_epsilon
        self.ployak = ployak
        self.log_alpha = alpha if not auto_adaption else tf.Variable(
            initial_value=0.0,
            name='log_alpha',
            dtype=tf.float32,
            trainable=True)
        self.auto_adaption = auto_adaption
        self.target_entropy = beta * np.log(self.a_dim)

        def _create_net(name, representation_net=None):
            return DoubleValueNetwork(
                name=name,
                representation_net=representation_net,
                value_net_type=OutputNetworkType.CRITIC_QVALUE_ALL,
                value_net_kwargs=dict(output_shape=self.a_dim,
                                      network_settings=network_settings))

        self.critic_net = _create_net('critic_net', self._representation_net)
        self._representation_target_net = self._create_representation_net(
            '_representation_target_net')
        self.critic_target_net = _create_net('critic_target_net',
                                             self._representation_target_net)

        update_target_net_weights(self.critic_target_net.weights,
                                  self.critic_net.weights)
        self.q_lr, self.alpha_lr = map(self.init_lr, [q_lr, alpha_lr])
        self.optimizer_critic, self.optimizer_alpha = map(
            self.init_optimizer, [self.q_lr, self.alpha_lr])

        self._worker_params_dict.update(self.critic_net._policy_models)

        self._all_params_dict.update(self.critic_net._all_models)
        self._all_params_dict.update(optimizer_critic=self.optimizer_critic,
                                     optimizer_alpha=self.optimizer_alpha)
        self._model_post_process()

    @property
    def alpha(self):
        return tf.exp(self.log_alpha)

    def choose_action(self, s, visual_s, evaluation=False):
        if self.use_epsilon and np.random.uniform(
        ) < self.expl_expt_mng.get_esp(self.train_step, evaluation=evaluation):
            a = np.random.randint(0, self.a_dim, self.n_agents)
        else:
            mu, pi, self.cell_state = self._get_action(s, visual_s,
                                                       self.cell_state)
            a = pi.numpy()
        return a

    @tf.function
    def _get_action(self, s, visual_s, cell_state):
        with tf.device(self.device):
            q, _, cell_state = self.critic_net(s,
                                               visual_s,
                                               cell_state=cell_state)
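            # Boltzmann exploration: Q-values divided by the temperature alpha define a
            # softmax policy; pi is a sampled action, the returned argmax(q) is the greedy one.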
            cate_dist = tfp.distributions.Categorical(
                logits=tf.nn.log_softmax(q / self.alpha))
            pi = cate_dist.sample()
        return tf.argmax(q, axis=1), pi, cell_state

    def _target_params_update(self):
        update_target_net_weights(self.critic_target_net.weights,
                                  self.critic_net.weights, self.ployak)

    def learn(self, **kwargs):
        self.train_step = kwargs.get('train_step')
        for i in range(self.train_times_per_step):
            self._learn(
                function_dict={
                    'summary_dict':
                    dict([['LEARNING_RATE/q_lr',
                           self.q_lr(self.train_step)],
                          [
                              'LEARNING_RATE/alpha_lr',
                              self.alpha_lr(self.train_step)
                          ]]),
                    'train_data_list':
                    ['s', 'visual_s', 'a', 'r', 's_', 'visual_s_', 'done']
                })

    @tf.function(experimental_relax_shapes=True)
    def _train(self, memories, isw, cell_state):
        s, visual_s, a, r, s_, visual_s_, done = memories
        with tf.device(self.device):
            with tf.GradientTape(persistent=True) as tape:
                q1, q2, _ = self.critic_net(s, visual_s, cell_state=cell_state)
                q1_eval = tf.reduce_sum(tf.multiply(q1, a),
                                        axis=1,
                                        keepdims=True)
                q2_eval = tf.reduce_sum(tf.multiply(q2, a),
                                        axis=1,
                                        keepdims=True)

                q1_target, q2_target, _ = self.critic_target_net(
                    s_, visual_s_, cell_state=cell_state)
                q1_target_max = tf.reduce_max(q1_target, axis=1, keepdims=True)
                q1_target_log_probs = tf.nn.log_softmax(q1_target /
                                                        (self.alpha + 1e-8),
                                                        axis=1)
                q1_target_entropy = -tf.reduce_mean(
                    tf.reduce_sum(
                        tf.exp(q1_target_log_probs) * q1_target_log_probs,
                        axis=1,
                        keepdims=True))

                q2_target_max = tf.reduce_max(q2_target, axis=1, keepdims=True)
                # q2_target_log_probs = tf.nn.log_softmax(q2_target, axis=1)
                # q2_target_log_max = tf.reduce_max(q2_target_log_probs, axis=1, keepdims=True)
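                # soft state-value target used below: clipped double-Q maximum plus an
                # alpha-weighted entropy bonus.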

                q_target = tf.minimum(
                    q1_target_max,
                    q2_target_max) + self.alpha * q1_target_entropy
                dc_r = tf.stop_gradient(r + self.gamma * q_target * (1 - done))
                td_error1 = q1_eval - dc_r
                td_error2 = q2_eval - dc_r
                q1_loss = tf.reduce_mean(tf.square(td_error1) * isw)
                q2_loss = tf.reduce_mean(tf.square(td_error2) * isw)
                loss = 0.5 * (q1_loss + q2_loss)
                # q1_entropy is logged unconditionally in the summaries below, so compute it
                # outside the auto_adaption branch to avoid a NameError when auto_adaption is False.
                q1_log_probs = tf.nn.log_softmax(q1 / (self.alpha + 1e-8),
                                                 axis=1)
                q1_entropy = -tf.reduce_mean(
                    tf.reduce_sum(tf.exp(q1_log_probs) * q1_log_probs,
                                  axis=1,
                                  keepdims=True))
                if self.auto_adaption:
                    alpha_loss = -tf.reduce_mean(
                        self.alpha *
                        tf.stop_gradient(self.target_entropy - q1_entropy))
            loss_grads = tape.gradient(loss,
                                       self.critic_net.trainable_variables)
            self.optimizer_critic.apply_gradients(
                zip(loss_grads, self.critic_net.trainable_variables))
            if self.auto_adaption:
                alpha_grad = tape.gradient(alpha_loss, self.log_alpha)
                self.optimizer_alpha.apply_gradients([(alpha_grad,
                                                       self.log_alpha)])
            self.global_step.assign_add(1)
            summaries = dict(
                [['LOSS/loss', loss], ['Statistics/log_alpha', self.log_alpha],
                 ['Statistics/alpha', self.alpha],
                 ['Statistics/q1_entropy', q1_entropy],
                 ['Statistics/q_min',
                  tf.reduce_mean(tf.minimum(q1, q2))],
                 ['Statistics/q_mean', tf.reduce_mean(q1)],
                 ['Statistics/q_max',
                  tf.reduce_mean(tf.maximum(q1, q2))]])
            if self.auto_adaption:
                summaries.update({'LOSS/alpha_loss': alpha_loss})
            return (td_error1 + td_error2) / 2, summaries
Example n. 8
class QS:
    '''
    Q-learning/Sarsa/Expected Sarsa.
    '''
    def __init__(self,
                 envspec,
                 mode='q',
                 lr=0.2,
                 eps_init=1,
                 eps_mid=0.2,
                 eps_final=0.01,
                 init2mid_annealing_step=1000,
                 **kwargs):
        # s_dim / a_dim are taken from envspec here (this tabular agent does not call a base
        # class __init__ that would set them); assumes envspec exposes scalar s_dim and a_dim.
        assert not hasattr(envspec.s_dim, '__len__')
        assert not envspec.is_continuous
        self.mode = mode
        self.s_dim = envspec.s_dim
        self.a_dim = envspec.a_dim
        self.gamma = float(kwargs.get('gamma', 0.999))
        self.max_train_step = int(kwargs.get('max_train_step', 1000))
        self.step = 0
        self.train_step = 0
        self.n_agents = int(kwargs.get('n_agents', 0))
        if self.n_agents <= 0:
            raise ValueError('agents num must larger than zero.')
        self.expl_expt_mng = ExplorationExploitationClass(
            eps_init=eps_init,
            eps_mid=eps_mid,
            eps_final=eps_final,
            init2mid_annealing_step=init2mid_annealing_step,
            max_step=self.max_train_step)
        self.table = np.zeros(shape=(self.s_dim, self.a_dim))
        self.lr = lr
        self.next_a = np.zeros(self.n_agents, dtype=np.int32)
        self.mask = []
        ion()

    def one_hot2int(self, x):
        idx = [np.where(np.asarray(i))[0][0] for i in x]
        return idx

    def partial_reset(self, done):
        self.mask = np.where(done)[0]

    def choose_action(self, s, visual_s=None, evaluation=False):
        s = self.one_hot2int(s)
        if self.mode == 'q':
            return self._get_action(s, evaluation)
        elif self.mode == 'sarsa' or self.mode == 'expected_sarsa':
            a = self._get_action(s, evaluation)
            self.next_a[self.mask] = a[self.mask]
            return self.next_a

    def _get_action(self, s, evaluation=False, _max=False):
        a = np.array([np.argmax(self.table[i, :]) for i in s])
        if not _max:
            if np.random.uniform() < self.expl_expt_mng.get_esp(
                    self.train_step, evaluation=evaluation):
                a = np.random.randint(0, self.a_dim, self.n_agents)
        return a

    def learn(self, **kwargs):
        self.train_step = kwargs.get('train_step')

    def store_data(self, s, visual_s, a, r, s_, visual_s_, done):
        self.step += 1
        s = self.one_hot2int(s)
        s_ = self.one_hot2int(s_)
        if self.mode == 'q':
            a_ = self._get_action(s_, _max=True)
            value = self.table[s_, a_]
        else:
            self.next_a = self._get_action(s_)
            if self.mode == 'expected_sarsa':
                value = np.mean(self.table[s_, :], axis=-1)
            else:
                value = self.table[s_, self.next_a]
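        # tabular TD update: Q(s,a) <- (1-lr)*Q(s,a) + lr*(r + gamma*(1-done)*value), where
        # value is max_a' Q(s',a') for Q-learning, Q(s', next_a) for Sarsa, or the mean over
        # actions for the (uniform-policy) Expected Sarsa variant implemented above.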
        self.table[s, a] = (1 - self.lr) * self.table[s, a] + self.lr * (
            r + self.gamma * (1 - done) * value)
        if self.step % 1000 == 0:
            plot_heatmap(self.s_dim, self.a_dim, self.table)

    def close(self):
        ioff()

    def no_op_store(self, s, visual_s, a, r, s_, visual_s_, done):
        pass

    def __getattr__(self, x):
        # print(x)
        return lambda *args, **kwargs: 0
Example n. 9
class DQN(make_off_policy_class(mode='share')):
    '''
    Deep Q-learning Network, DQN, [2013](https://arxiv.org/pdf/1312.5602.pdf), [2015](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf)
    DQN + LSTM, https://arxiv.org/abs/1507.06527
    '''
    def __init__(self,
                 s_dim: Union[int, np.ndarray],
                 visual_sources: Union[int, np.ndarray],
                 visual_resolution: Union[List, np.ndarray],
                 a_dim: Union[int, np.ndarray],
                 is_continuous: Union[bool, np.ndarray],
                 lr: float = 5.0e-4,
                 eps_init: float = 1,
                 eps_mid: float = 0.2,
                 eps_final: float = 0.01,
                 init2mid_annealing_step: int = 1000,
                 assign_interval: int = 1000,
                 hidden_units: List[int] = [32, 32],
                 **kwargs):
        assert not is_continuous, 'dqn only support discrete action space'
        super().__init__(s_dim=s_dim,
                         visual_sources=visual_sources,
                         visual_resolution=visual_resolution,
                         a_dim=a_dim,
                         is_continuous=is_continuous,
                         **kwargs)
        self.expl_expt_mng = ExplorationExploitationClass(
            eps_init=eps_init,
            eps_mid=eps_mid,
            eps_final=eps_final,
            init2mid_annealing_step=init2mid_annealing_step,
            max_step=self.max_train_step)
        self.assign_interval = assign_interval

        def _q_net():
            return NetWork(self.feat_dim, self.a_dim, hidden_units)

        self.q_net = _q_net()
        self.q_target_net = _q_net()
        self.critic_tv = self.q_net.trainable_variables + self.other_tv
        update_target_net_weights(self.q_target_net.weights,
                                  self.q_net.weights)
        self.lr = self.init_lr(lr)
        self.optimizer = self.init_optimizer(self.lr)

        self.model_recorder(dict(model=self.q_net, optimizer=self.optimizer))

    def show_logo(self) -> NoReturn:
        self.logger.info('''
       xxxxxxxx         xxxxxx      xxxx   xxxx  
        xxxxxxxx       xxx xxxx       xxx    x   
        xx    xxx     xxx   xxxx      xxxx   x   
        xx    xxx     xxx    xxx      xxxxx  x   
        xx     xx     xx     xxx      x xxxx x   
        xx     xx     xxx    xxx      x  xxxxx   
        xx    xxx     xxx    xxx      x   xxxx   
        xx   xxxx     xxx   xxx       x    xxx   
        xxxxxxxx       xxxxxxxx      xxx    xx   
       xxxxxxx          xxxxx                    
                          xxxx                   
                            xxx
        ''')

    def choose_action(self,
                      s: np.ndarray,
                      visual_s: np.ndarray,
                      evaluation: bool = False) -> np.ndarray:
        if np.random.uniform() < self.expl_expt_mng.get_esp(
                self.train_step, evaluation=evaluation):
            a = np.random.randint(0, self.a_dim, self.n_agents)
        else:
            a, self.cell_state = self._get_action(s, visual_s, self.cell_state)
            a = a.numpy()
        return a

    @tf.function
    def _get_action(self, s, visual_s, cell_state):
        with tf.device(self.device):
            feat, cell_state = self.get_feature(s,
                                                visual_s,
                                                cell_state=cell_state,
                                                record_cs=True)
            q_values = self.q_net(feat)
        return tf.argmax(q_values, axis=1), cell_state

    def learn(self, **kwargs) -> NoReturn:
        self.train_step = kwargs.get('train_step')

        def _update():
            if self.global_step % self.assign_interval == 0:
                update_target_net_weights(self.q_target_net.weights,
                                          self.q_net.weights)

        for i in range(self.train_times_per_step):
            self._learn(
                function_dict={
                    'train_function':
                    self.train,
                    'update_function':
                    _update,
                    'summary_dict':
                    dict([['LEARNING_RATE/lr',
                           self.lr(self.train_step)]])
                })

    @tf.function(experimental_relax_shapes=True)
    def train(self, memories, isw, crsty_loss, cell_state):
        ss, vvss, a, r, done = memories
        with tf.device(self.device):
            with tf.GradientTape() as tape:
                feat, feat_ = self.get_feature(ss,
                                               vvss,
                                               cell_state=cell_state,
                                               s_and_s_=True)
                q = self.q_net(feat)
                q_next = self.q_target_net(feat_)
                q_eval = tf.reduce_sum(tf.multiply(q, a),
                                       axis=1,
                                       keepdims=True)
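                # standard DQN target: r + gamma * max_a' Q_target(s', a'), with no gradient
                # flowing through the target network.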
                q_target = tf.stop_gradient(
                    r + self.gamma *
                    (1 - done) * tf.reduce_max(q_next, axis=1, keepdims=True))
                td_error = q_eval - q_target
                q_loss = tf.reduce_mean(tf.square(td_error) * isw) + crsty_loss
            grads = tape.gradient(q_loss, self.critic_tv)
            self.optimizer.apply_gradients(zip(grads, self.critic_tv))
            self.global_step.assign_add(1)
            return td_error, dict(
                [['LOSS/loss', q_loss],
                 ['Statistics/q_max',
                  tf.reduce_max(q_eval)],
                 ['Statistics/q_min',
                  tf.reduce_min(q_eval)],
                 ['Statistics/q_mean',
                  tf.reduce_mean(q_eval)]])
Example n. 10
class DDDQN(Off_Policy):
    '''
    Dueling Double DQN, https://arxiv.org/abs/1511.06581
    '''
    def __init__(self,
                 envspec,
                 lr=5.0e-4,
                 eps_init=1,
                 eps_mid=0.2,
                 eps_final=0.01,
                 init2mid_annealing_step=1000,
                 assign_interval=2,
                 network_settings={
                     'share': [128],
                     'v': [128],
                     'adv': [128]
                 },
                 **kwargs):
        assert not envspec.is_continuous, 'dueling double dqn only support discrete action space'
        super().__init__(envspec=envspec, **kwargs)
        self.expl_expt_mng = ExplorationExploitationClass(
            eps_init=eps_init,
            eps_mid=eps_mid,
            eps_final=eps_final,
            init2mid_annealing_step=init2mid_annealing_step,
            max_step=self.max_train_step)
        self.assign_interval = assign_interval

        def _create_net(name, representation_net):
            return ValueNetwork(
                name=name,
                representation_net=representation_net,
                value_net_type=OutputNetworkType.CRITIC_DUELING,
                value_net_kwargs=dict(output_shape=self.a_dim,
                                      network_settings=network_settings))

        self.dueling_net = _create_net('dueling_net', self._representation_net)
        self._representation_target_net = self._create_representation_net(
            '_representation_target_net')
        self.dueling_target_net = _create_net('dueling_target_net',
                                              self._representation_target_net)
        update_target_net_weights(self.dueling_target_net.weights,
                                  self.dueling_net.weights)
        self.lr = self.init_lr(lr)
        self.optimizer = self.init_optimizer(self.lr)

        self._worker_params_dict.update(self.dueling_net._policy_models)

        self._all_params_dict.update(self.dueling_net._all_models)
        self._all_params_dict.update(optimizer=self.optimizer)
        self._model_post_process()

    def choose_action(self, obs, evaluation=False):
        if np.random.uniform() < self.expl_expt_mng.get_esp(
                self.train_step, evaluation=evaluation):
            a = np.random.randint(0, self.a_dim, self.n_agents)
        else:
            a, self.cell_state = self._get_action(obs, self.cell_state)
            a = a.numpy()
        return a

    @tf.function
    def _get_action(self, obs, cell_state):
        with tf.device(self.device):
            q_values, cell_state = self.dueling_net(obs, cell_state=cell_state)
        return tf.argmax(q_values, axis=-1), cell_state

    def _target_params_update(self):
        if self.global_step % self.assign_interval == 0:
            update_target_net_weights(self.dueling_target_net.weights,
                                      self.dueling_net.weights)

    def learn(self, **kwargs):
        self.train_step = kwargs.get('train_step')
        for i in range(self.train_times_per_step):
            self._learn(
                function_dict={
                    'summary_dict':
                    dict([['LEARNING_RATE/lr',
                           self.lr(self.train_step)]]),
                    'use_stack':
                    True
                })

    @tf.function
    def _train(self, BATCH, isw, cell_state):
        with tf.device(self.device):
            with tf.GradientTape() as tape:
                (feat,
                 feat_), _ = self._representation_net(BATCH.obs,
                                                      cell_state=cell_state,
                                                      need_split=True)
                q_target, _ = self.dueling_target_net(BATCH.obs_,
                                                      cell_state=cell_state)
                q = self.dueling_net.value_net(feat)
                q_eval = tf.reduce_sum(tf.multiply(q, BATCH.action),
                                       axis=1,
                                       keepdims=True)
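                # Double DQN: the next action is selected with the online dueling network and
                # evaluated with the target network to reduce overestimation bias.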
                next_q = self.dueling_net.value_net(feat_)
                next_max_action = tf.argmax(next_q,
                                            axis=1,
                                            name='next_action_int')
                next_max_action_one_hot = tf.one_hot(
                    tf.squeeze(next_max_action),
                    self.a_dim,
                    1.,
                    0.,
                    dtype=tf.float32)
                next_max_action_one_hot = tf.cast(next_max_action_one_hot,
                                                  tf.float32)

                q_target_next_max = tf.reduce_sum(tf.multiply(
                    q_target, next_max_action_one_hot),
                                                  axis=1,
                                                  keepdims=True)
                q_target = tf.stop_gradient(BATCH.reward + self.gamma *
                                            (1 - BATCH.done) *
                                            q_target_next_max)
                td_error = q_target - q_eval
                q_loss = tf.reduce_mean(tf.square(td_error) * isw)
            grads = tape.gradient(q_loss, self.dueling_net.trainable_variables)
            self.optimizer.apply_gradients(
                zip(grads, self.dueling_net.trainable_variables))
            self.global_step.assign_add(1)
            return td_error, dict(
                [['LOSS/loss', q_loss],
                 ['Statistics/q_max',
                  tf.reduce_max(q_eval)],
                 ['Statistics/q_min',
                  tf.reduce_min(q_eval)],
                 ['Statistics/q_mean',
                  tf.reduce_mean(q_eval)]])
Example n. 11
def test_exploration_exploitation_class():
    my_expl = ExplorationExploitationClass(eps_init=1, eps_mid=0.2, eps_final=0.01, eps_eval=0,
                                           init2mid_annealing_step=50, start_step=0, max_step=100)
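    # epsilon is expected to anneal from eps_init to eps_mid over init2mid_annealing_step
    # steps, then keep decaying toward eps_final until max_step; eps_eval is returned
    # whenever evaluation=True.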
    assert my_expl.get_esp(0) == 1
    assert my_expl.get_esp(0, evaluation=True) == 0
    assert my_expl.get_esp(80, evaluation=True) == 0
    assert my_expl.get_esp(2) < 1
    assert my_expl.get_esp(50) == 0.2
    assert my_expl.get_esp(51) < 0.2
    assert my_expl.get_esp(100) >= 0.01

    my_expl = ExplorationExploitationClass(eps_init=0.2, eps_mid=0.1, eps_final=0, eps_eval=0,
                                           init2mid_annealing_step=1000, start_step=0, max_step=10000)
    assert my_expl.get_esp(0) == 0.2
    assert my_expl.get_esp(0, evaluation=True) == 0
    assert my_expl.get_esp(500, evaluation=True) == 0
    assert my_expl.get_esp(500) < 0.2
    assert my_expl.get_esp(1000) == 0.1
    assert my_expl.get_esp(2000) < 0.1
    assert my_expl.get_esp(9000) > 0
Example n. 12
class DDDQN(make_off_policy_class(mode='share')):
    '''
    Dueling Double DQN, https://arxiv.org/abs/1511.06581
    '''

    def __init__(self,
                 s_dim,
                 visual_sources,
                 visual_resolution,
                 a_dim,
                 is_continuous,

                 lr=5.0e-4,
                 eps_init=1,
                 eps_mid=0.2,
                 eps_final=0.01,
                 init2mid_annealing_step=1000,
                 assign_interval=2,
                 hidden_units={
                     'share': [128],
                     'v': [128],
                     'adv': [128]
                 },
                 **kwargs):
        assert not is_continuous, 'dueling double dqn only support discrete action space'
        super().__init__(
            s_dim=s_dim,
            visual_sources=visual_sources,
            visual_resolution=visual_resolution,
            a_dim=a_dim,
            is_continuous=is_continuous,
            **kwargs)
        self.expl_expt_mng = ExplorationExploitationClass(eps_init=eps_init,
                                                          eps_mid=eps_mid,
                                                          eps_final=eps_final,
                                                          init2mid_annealing_step=init2mid_annealing_step,
                                                          max_step=self.max_train_step)
        self.assign_interval = assign_interval

        def _net(): return NetWork(self.feat_dim, self.a_dim, hidden_units)

        self.dueling_net = _net()
        self.dueling_target_net = _net()
        self.critic_tv = self.dueling_net.trainable_variables + self.other_tv
        update_target_net_weights(self.dueling_target_net.weights, self.dueling_net.weights)
        self.lr = self.init_lr(lr)
        self.optimizer = self.init_optimizer(self.lr)

        self.model_recorder(dict(
            model=self.dueling_net,
            optimizer=self.optimizer
        ))

    def show_logo(self):
        self.logger.info('''
   xxxxxxxx       xxxxxxxx       xxxxxxxx         xxxxxx      xxxx   xxxx  
    xxxxxxxx       xxxxxxxx       xxxxxxxx       xxx xxxx       xxx    x   
    xx    xxx      xx    xxx      xx    xxx     xxx   xxxx      xxxx   x   
    xx    xxx      xx    xxx      xx    xxx     xxx    xxx      xxxxx  x   
    xx     xx      xx     xx      xx     xx     xx     xxx      x xxxx x   
    xx     xx      xx     xx      xx     xx     xxx    xxx      x  xxxxx   
    xx    xxx      xx    xxx      xx    xxx     xxx    xxx      x   xxxx   
    xx   xxxx      xx   xxxx      xx   xxxx     xxx   xxx       x    xxx   
    xxxxxxxx       xxxxxxxx       xxxxxxxx       xxxxxxxx      xxx    xx   
   xxxxxxx        xxxxxxx        xxxxxxx          xxxxx                    
                                                    xxxx                   
                                                      xxx    
        ''')

    def choose_action(self, s, visual_s, evaluation=False):
        if np.random.uniform() < self.expl_expt_mng.get_esp(self.train_step, evaluation=evaluation):
            a = np.random.randint(0, self.a_dim, self.n_agents)
        else:
            a, self.cell_state = self._get_action(s, visual_s, self.cell_state)
            a = a.numpy()
        return a

    @tf.function
    def _get_action(self, s, visual_s, cell_state):
        with tf.device(self.device):
            feat, cell_state = self.get_feature(s, visual_s, cell_state=cell_state, record_cs=True)
            q = self.dueling_net(feat)
        return tf.argmax(q, axis=-1), cell_state

    def learn(self, **kwargs):
        self.train_step = kwargs.get('train_step')

        def _update():
            if self.global_step % self.assign_interval == 0:
                update_target_net_weights(self.dueling_target_net.weights, self.dueling_net.weights)
        for i in range(self.train_times_per_step):
            self._learn(function_dict={
                'train_function': self.train,
                'update_function': _update,
                'summary_dict': dict([['LEARNING_RATE/lr', self.lr(self.train_step)]])
            })

    @tf.function(experimental_relax_shapes=True)
    def train(self, memories, isw, crsty_loss, cell_state):
        ss, vvss, a, r, done = memories
        with tf.device(self.device):
            with tf.GradientTape() as tape:
                feat, feat_ = self.get_feature(ss, vvss, cell_state=cell_state, s_and_s_=True)
                q = self.dueling_net(feat)
                q_eval = tf.reduce_sum(tf.multiply(q, a), axis=1, keepdims=True)
                next_q = self.dueling_net(feat_)
                next_max_action = tf.argmax(next_q, axis=1, name='next_action_int')
                next_max_action_one_hot = tf.one_hot(tf.squeeze(next_max_action), self.a_dim, 1., 0., dtype=tf.float32)
                next_max_action_one_hot = tf.cast(next_max_action_one_hot, tf.float32)
                q_target = self.dueling_target_net(feat_)

                q_target_next_max = tf.reduce_sum(
                    tf.multiply(q_target, next_max_action_one_hot),
                    axis=1, keepdims=True)
                q_target = tf.stop_gradient(r + self.gamma * (1 - done) * q_target_next_max)
                td_error = q_eval - q_target
                q_loss = tf.reduce_mean(tf.square(td_error) * isw) + crsty_loss
            grads = tape.gradient(q_loss, self.critic_tv)
            self.optimizer.apply_gradients(
                zip(grads, self.critic_tv)
            )
            self.global_step.assign_add(1)
            return td_error, dict([
                ['LOSS/loss', q_loss],
                ['Statistics/q_max', tf.reduce_max(q_eval)],
                ['Statistics/q_min', tf.reduce_min(q_eval)],
                ['Statistics/q_mean', tf.reduce_mean(q_eval)]
            ])
Example n. 13
class BootstrappedDQN(Off_Policy):
    '''
    Deep Exploration via Bootstrapped DQN, http://arxiv.org/abs/1602.04621
    '''

    def __init__(self,
                 envspec,

                 lr=5.0e-4,
                 eps_init=1,
                 eps_mid=0.2,
                 eps_final=0.01,
                 init2mid_annealing_step=1000,
                 assign_interval=1000,
                 head_num=4,
                 network_settings=[32, 32],
                 **kwargs):
        assert not envspec.is_continuous, 'Bootstrapped DQN only support discrete action space'
        super().__init__(envspec=envspec, **kwargs)
        self.expl_expt_mng = ExplorationExploitationClass(eps_init=eps_init,
                                                          eps_mid=eps_mid,
                                                          eps_final=eps_final,
                                                          init2mid_annealing_step=init2mid_annealing_step,
                                                          max_step=self.max_train_step)
        self.assign_interval = assign_interval
        self.head_num = head_num
        self._probs = [1. / head_num for _ in range(head_num)]
        self.now_head = 0

        def _create_net(name, representation_net=None): return ValueNetwork(
            name=name,
            representation_net=representation_net,
            value_net_type=OutputNetworkType.CRITIC_QVALUE_BOOTSTRAP,
            value_net_kwargs=dict(output_shape=self.a_dim, head_num=self.head_num, network_settings=network_settings)
        )

        self.q_net = _create_net('q_net', self._representation_net)
        self._representation_target_net = self._create_representation_net('_representation_target_net')
        self.q_target_net = _create_net('q_target_net', self._representation_target_net)
        update_target_net_weights(self.q_target_net.weights, self.q_net.weights)
        self.lr = self.init_lr(lr)
        self.optimizer = self.init_optimizer(self.lr)

        self._worker_params_dict.update(self.q_net._policy_models)

        self._all_params_dict.update(self.q_net._all_models)
        self._all_params_dict.update(optimizer=self.optimizer)
        self._model_post_process()

    def reset(self):
        super().reset()
        self.now_head = np.random.randint(self.head_num)

    def choose_action(self, s, visual_s, evaluation=False):
        if np.random.uniform() < self.expl_expt_mng.get_esp(self.train_step, evaluation=evaluation):
            a = np.random.randint(0, self.a_dim, self.n_agents)
        else:
            q, self.cell_state = self._get_action(s, visual_s, self.cell_state)
            q = q.numpy()
            a = np.argmax(q[self.now_head], axis=1)  # [H, B, A] => [B, A] => [B, ]
        return a

    @tf.function
    def _get_action(self, s, visual_s, cell_state):
        with tf.device(self.device):
            q_values, cell_state = self.q_net(s, visual_s, cell_state=cell_state)  # [H, B, A]
        return q_values, cell_state

    def _target_params_update(self):
        if self.global_step % self.assign_interval == 0:
            update_target_net_weights(self.q_target_net.weights, self.q_net.weights)

    def learn(self, **kwargs):
        self.train_step = kwargs.get('train_step')
        for i in range(self.train_times_per_step):
            self._learn(function_dict={
                'summary_dict': dict([['LEARNING_RATE/lr', self.lr(self.train_step)]]),
                'train_data_list': ['s', 'visual_s', 'a', 'r', 's_', 'visual_s_', 'done']
            })

    @tf.function(experimental_relax_shapes=True)
    def _train(self, memories, isw, cell_state):
        s, visual_s, a, r, s_, visual_s_, done = memories
        batch_size = tf.shape(a)[0]
        with tf.device(self.device):
            with tf.GradientTape() as tape:
                q, _ = self.q_net(s, visual_s, cell_state=cell_state)    # [H, B, A]
                q_next, _ = self.q_target_net(s_, visual_s_, cell_state=cell_state)   # [H, B, A]
                q_eval = tf.reduce_sum(tf.multiply(q, a), axis=-1, keepdims=True)    # [H, B, A] * [B, A] => [H, B, 1]
                q_target = tf.stop_gradient(r + self.gamma * (1 - done) * tf.reduce_max(q_next, axis=-1, keepdims=True))
                td_error = q_eval - q_target    # [H, B, 1]
                td_error = tf.reduce_sum(td_error, axis=-1)  # [H, B]

                mask_dist = tfp.distributions.Bernoulli(probs=self._probs)
                mask = tf.transpose(mask_dist.sample(batch_size), [1, 0])   # [H, B]
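                # NOTE: the Bernoulli bootstrap mask is sampled but not applied to td_error or
                # q_loss in this snippet, so every head is trained on every transition.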
                q_loss = tf.reduce_mean(tf.square(td_error) * isw)
            grads = tape.gradient(q_loss, self.q_net.trainable_variables)
            self.optimizer.apply_gradients(
                zip(grads, self.q_net.trainable_variables)
            )
            self.global_step.assign_add(1)
            return tf.reduce_mean(td_error, axis=0), dict([  # [H, B] => [B, ]
                ['LOSS/loss', q_loss],
                ['Statistics/q_max', tf.reduce_max(q_eval)],
                ['Statistics/q_min', tf.reduce_min(q_eval)],
                ['Statistics/q_mean', tf.reduce_mean(q_eval)]
            ])
Example n. 14
class QRDQN(Off_Policy):
    '''
    Quantile Regression DQN
    Distributional Reinforcement Learning with Quantile Regression, https://arxiv.org/abs/1710.10044
    No double, no dueling, no noisy net.
    '''
    def __init__(self,
                 envspec,
                 nums=20,
                 huber_delta=1.,
                 lr=5.0e-4,
                 eps_init=1,
                 eps_mid=0.2,
                 eps_final=0.01,
                 init2mid_annealing_step=1000,
                 assign_interval=1000,
                 network_settings=[128, 128],
                 **kwargs):
        assert not envspec.is_continuous, 'qrdqn only support discrete action space'
        assert nums > 0
        super().__init__(envspec=envspec, **kwargs)
        self.nums = nums
        self.huber_delta = huber_delta
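        # fixed quantile fractions (midpoints) tau_i = (2i + 1) / (2N), i = 0..N-1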
        self.quantiles = tf.reshape(
            tf.constant((2 * np.arange(self.nums) + 1) / (2.0 * self.nums),
                        dtype=tf.float32), [-1, self.nums])  # [1, N]
        self.batch_quantiles = tf.tile(self.quantiles,
                                       [self.a_dim, 1])  # [1, N] => [A, N]
        self.expl_expt_mng = ExplorationExploitationClass(
            eps_init=eps_init,
            eps_mid=eps_mid,
            eps_final=eps_final,
            init2mid_annealing_step=init2mid_annealing_step,
            max_step=self.max_train_step)
        self.assign_interval = assign_interval

        def _create_net(name, representation_net=None):
            return ValueNetwork(
                name=name,
                representation_net=representation_net,
                value_net_type=OutputNetworkType.QRDQN_DISTRIBUTIONAL,
                value_net_kwargs=dict(action_dim=self.a_dim,
                                      nums=self.nums,
                                      network_settings=network_settings))

        self.q_dist_net = _create_net('q_dist_net', self._representation_net)
        self._representation_target_net = self._create_representation_net(
            '_representation_target_net')
        self.q_target_dist_net = _create_net('q_target_dist_net',
                                             self._representation_target_net)
        update_target_net_weights(self.q_target_dist_net.weights,
                                  self.q_dist_net.weights)
        self.lr = self.init_lr(lr)
        self.optimizer = self.init_optimizer(self.lr)

        self._worker_params_dict.update(self.q_dist_net._policy_models)

        self._all_params_dict.update(self.q_dist_net._all_models)
        self._all_params_dict.update(optimizer=self.optimizer)
        self._model_post_process()

    def choose_action(self, s, visual_s, evaluation=False):
        if np.random.uniform() < self.expl_expt_mng.get_esp(
                self.train_step, evaluation=evaluation):
            a = np.random.randint(0, self.a_dim, self.n_agents)
        else:
            a, self.cell_state = self._get_action(s, visual_s, self.cell_state)
            a = a.numpy()
        return a

    @tf.function
    def _get_action(self, s, visual_s, cell_state):
        with tf.device(self.device):
            q_values, cell_state = self.q_dist_net(s,
                                                   visual_s,
                                                   cell_state=cell_state)
            q = tf.reduce_sum(self.batch_quantiles * q_values,
                              axis=-1)  # [B, A, N] => [B, A]
        return tf.argmax(q, axis=-1), cell_state  # [B, 1]

    def _target_params_update(self):
        if self.global_step % self.assign_interval == 0:
            update_target_net_weights(self.q_target_dist_net.weights,
                                      self.q_dist_net.weights)

    def learn(self, **kwargs):
        self.train_step = kwargs.get('train_step')
        for i in range(self.train_times_per_step):
            self._learn(
                function_dict={
                    'summary_dict':
                    dict([['LEARNING_RATE/lr',
                           self.lr(self.train_step)]]),
                    'train_data_list':
                    ['s', 'visual_s', 'a', 'r', 's_', 'visual_s_', 'done']
                })

    @tf.function(experimental_relax_shapes=True)
    def _train(self, memories, isw, cell_state):
        s, visual_s, a, r, s_, visual_s_, done = memories
        batch_size = tf.shape(a)[0]
        with tf.device(self.device):
            with tf.GradientTape() as tape:
                indexs = tf.reshape(tf.range(batch_size), [-1, 1])  # [B, 1]
                q_dist, _ = self.q_dist_net(s, visual_s,
                                            cell_state=cell_state)  # [B, A, N]
                q_dist = tf.transpose(
                    tf.reduce_sum(tf.transpose(q_dist, [2, 0, 1]) * a,
                                  axis=-1), [1, 0])  # [B, N]
                target_q_dist, _ = self.q_target_dist_net(
                    s_, visual_s_, cell_state=cell_state)  # [B, A, N]
                target_q = tf.reduce_sum(self.batch_quantiles * target_q_dist,
                                         axis=-1)  # [B, A, N] => [B, A]
                a_ = tf.reshape(
                    tf.cast(tf.argmax(target_q, axis=-1), dtype=tf.int32),
                    [-1, 1])  # [B, 1]
                target_q_dist = tf.gather_nd(target_q_dist,
                                             tf.concat([indexs, a_],
                                                       axis=-1))  # [B, N]
                # QR-DQN target: r + gamma * (1 - done) * theta'(s', a*), where theta' are the
                # target network's quantile values for the greedy next action (not the fixed taus)
                target = tf.tile(r, tf.constant([1, self.nums])) \
                    + self.gamma * tf.multiply(target_q_dist,   # [B, N]
                                               (1.0 - tf.tile(done, tf.constant([1, self.nums]))))  # [B, N] * [B, N] = [B, N]
                q_eval = tf.reduce_sum(q_dist * self.quantiles,
                                       axis=-1)  # [B, 1]
                q_target = tf.reduce_sum(target * self.quantiles,
                                         axis=-1)  # [B, 1]
                td_error = q_eval - q_target  # [B, 1]

                quantile_error = tf.expand_dims(
                    q_dist, axis=-1) - tf.expand_dims(
                        target, axis=1)  # [B, N, 1] - [B, 1, N] => [B, N, N]
                huber = huber_loss(quantile_error,
                                   delta=self.huber_delta)  # [B, N, N]
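                # quantile regression weighting: |tau - 1{quantile_error < 0}| makes the Huber
                # loss asymmetric, pushing each estimate toward its target quantile.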
                huber_abs = tf.abs(
                    self.quantiles -
                    tf.where(quantile_error < 0, tf.ones_like(quantile_error),
                             tf.zeros_like(quantile_error))
                )  # [1, N] - [B, N, N] => [B, N, N]
                loss = tf.reduce_mean(huber_abs * huber,
                                      axis=-1)  # [B, N, N] => [B, N]
                loss = tf.reduce_sum(loss, axis=-1)  # [B, N] => [B, ]
                loss = tf.reduce_mean(loss * isw)  # [B, ] => 1
            grads = tape.gradient(loss, self.q_dist_net.trainable_variables)
            self.optimizer.apply_gradients(
                zip(grads, self.q_dist_net.trainable_variables))
            self.global_step.assign_add(1)
            return td_error, dict(
                [['LOSS/loss', loss],
                 ['Statistics/q_max',
                  tf.reduce_max(q_eval)],
                 ['Statistics/q_min',
                  tf.reduce_min(q_eval)],
                 ['Statistics/q_mean',
                  tf.reduce_mean(q_eval)]])
Example n. 15
class RAINBOW(Off_Policy):
    '''
    Rainbow DQN:    https://arxiv.org/abs/1710.02298
        1. Double
        2. Dueling
        3. PrioritizedExperienceReplay
        4. N-Step
        5. Distributional
        6. Noisy Net
    '''
    def __init__(self,
                 envspec,
                 v_min=-10,
                 v_max=10,
                 atoms=51,
                 lr=5.0e-4,
                 eps_init=1,
                 eps_mid=0.2,
                 eps_final=0.01,
                 init2mid_annealing_step=1000,
                 assign_interval=2,
                 network_settings={
                     'share': [128],
                     'v': [128],
                     'adv': [128]
                 },
                 **kwargs):
        assert not envspec.is_continuous, 'rainbow only support discrete action space'
        super().__init__(envspec=envspec, **kwargs)
        self.v_min = v_min
        self.v_max = v_max
        self.atoms = atoms
        self.delta_z = (self.v_max - self.v_min) / (self.atoms - 1)
        self.z = tf.reshape(
            tf.constant(
                [self.v_min + i * self.delta_z for i in range(self.atoms)],
                dtype=tf.float32), [-1, self.atoms])  # [1, N]
        self.zb = tf.tile(self.z, tf.constant([self.a_dim, 1]))  # [A, N]
        self.expl_expt_mng = ExplorationExploitationClass(
            eps_init=eps_init,
            eps_mid=eps_mid,
            eps_final=eps_final,
            init2mid_annealing_step=init2mid_annealing_step,
            max_step=self.max_train_step)
        self.assign_interval = assign_interval

        def _create_net(name, representation_net=None):
            return ValueNetwork(
                name=name,
                representation_net=representation_net,
                value_net_type=OutputNetworkType.RAINBOW_DUELING,
                value_net_kwargs=dict(action_dim=self.a_dim,
                                      atoms=self.atoms,
                                      network_settings=network_settings))

        self.rainbow_net = _create_net('rainbow_net', self._representation_net)
        self._representation_target_net = self._create_representation_net(
            '_representation_target_net')
        self.rainbow_target_net = _create_net('rainbow_target_net',
                                              self._representation_target_net)
        update_target_net_weights(self.rainbow_target_net.weights,
                                  self.rainbow_net.weights)
        self.lr = self.init_lr(lr)
        self.optimizer = self.init_optimizer(self.lr)

        self._worker_params_dict.update(self.rainbow_net._policy_models)

        self._all_params_dict.update(self.rainbow_net._all_models)
        self._all_params_dict.update(optimizer=self.optimizer)
        self._model_post_process()

    def choose_action(self, s, visual_s, evaluation=False):
        if np.random.uniform() < self.expl_expt_mng.get_esp(
                self.train_step, evaluation=evaluation):
            a = np.random.randint(0, self.a_dim, self.n_agents)
        else:
            a, self.cell_state = self._get_action(s, visual_s, self.cell_state)
            a = a.numpy()
        return a

    @tf.function
    def _get_action(self, s, visual_s, cell_state):
        with tf.device(self.device):
            q_values, cell_state = self.rainbow_net(s,
                                                    visual_s,
                                                    cell_state=cell_state)
            q = tf.reduce_sum(self.zb * q_values,
                              axis=-1)  # [B, A, N] => [B, A]
        return tf.argmax(q, axis=-1), cell_state  # [B, 1]

    def _target_params_update(self):
        if self.global_step % self.assign_interval == 0:
            update_target_net_weights(self.rainbow_target_net.weights,
                                      self.rainbow_net.weights)

    def learn(self, **kwargs):
        self.train_step = kwargs.get('train_step')
        for i in range(self.train_times_per_step):
            self._learn(
                function_dict={
                    'summary_dict':
                    dict([['LEARNING_RATE/lr',
                           self.lr(self.train_step)]]),
                    'train_data_list':
                    ['ss', 'vvss', 'a', 'r', 'done', 's_', 'visual_s_']
                })

    @tf.function(experimental_relax_shapes=True)
    def _train(self, memories, isw, cell_state):
        ss, vvss, a, r, done, s_, visual_s_ = memories
        batch_size = tf.shape(a)[0]
        with tf.device(self.device):
            with tf.GradientTape() as tape:
                (feat,
                 feat_), _ = self._representation_net(ss,
                                                      vvss,
                                                      cell_state=cell_state,
                                                      need_split=True)
                indexs = tf.reshape(tf.range(batch_size), [-1, 1])  # [B, 1]
                q_dist = self.rainbow_net.value_net(feat)  # [B, A, N]
                q_dist = tf.transpose(
                    tf.reduce_sum(tf.transpose(q_dist, [2, 0, 1]) * a,
                                  axis=-1), [1, 0])  # [B, N]
                q_eval = tf.reduce_sum(q_dist * self.z, axis=-1)
                target_q = self.rainbow_net.value_net(feat_)
                target_q = tf.reduce_sum(self.zb * target_q,
                                         axis=-1)  # [B, A, N] => [B, A]
                a_ = tf.reshape(
                    tf.cast(tf.argmax(target_q, axis=-1), dtype=tf.int32),
                    [-1, 1])  # [B, 1]

                target_q_dist, _ = self.rainbow_target_net(
                    s_, visual_s_, cell_state=cell_state)  # [B, A, N]
                target_q_dist = tf.gather_nd(target_q_dist,
                                             tf.concat([indexs, a_],
                                                       axis=-1))  # [B, N]
                target = tf.tile(r, tf.constant([1, self.atoms])) \
                    + self.gamma * tf.multiply(self.z,   # [1, N]
                                               (1.0 - tf.tile(done, tf.constant([1, self.atoms]))))  # [B, N], [1, N]* [B, N] = [B, N]
                target = tf.clip_by_value(target, self.v_min,
                                          self.v_max)  # [B, N]
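                # C51-style projection: each atom of the shifted support Tz = r + gamma*z is
                # clipped to [v_min, v_max] and its mass is split between the two nearest
                # support atoms l = floor(b) and u = ceil(b).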
                b = (target - self.v_min) / self.delta_z  # [B, N]
                u, l = tf.math.ceil(b), tf.math.floor(b)  # [B, N]
                u_id, l_id = tf.cast(u, tf.int32), tf.cast(l,
                                                           tf.int32)  # [B, N]
                u_minus_b, b_minus_l = u - b, b - l  # [B, N]
                index_help = tf.tile(indexs,
                                     tf.constant([1, self.atoms]))  # [B, N]
                index_help = tf.expand_dims(index_help, -1)  # [B, N, 1]
                u_id = tf.concat(
                    [index_help, tf.expand_dims(u_id, -1)],
                    axis=-1)  # [B, N, 2]
                l_id = tf.concat(
                    [index_help, tf.expand_dims(l_id, -1)],
                    axis=-1)  # [B, N, 2]
                _cross_entropy = tf.stop_gradient(target_q_dist * u_minus_b) * tf.math.log(tf.gather_nd(q_dist, l_id)) \
                    + tf.stop_gradient(target_q_dist * b_minus_l) * tf.math.log(tf.gather_nd(q_dist, u_id))  # [B, N]
                cross_entropy = -tf.reduce_sum(_cross_entropy, axis=-1)  # [B,]
                loss = tf.reduce_mean(cross_entropy * isw)
                td_error = cross_entropy
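                # the per-sample cross-entropy is returned as td_error, e.g. as the priority
                # signal for prioritized experience replay.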
            grads = tape.gradient(loss, self.rainbow_net.trainable_variables)
            self.optimizer.apply_gradients(
                zip(grads, self.rainbow_net.trainable_variables))
            self.global_step.assign_add(1)
            return td_error, dict(
                [['LOSS/loss', loss],
                 ['Statistics/q_max',
                  tf.reduce_max(q_eval)],
                 ['Statistics/q_min',
                  tf.reduce_min(q_eval)],
                 ['Statistics/q_mean',
                  tf.reduce_mean(q_eval)]])
Example n. 16
class C51(make_off_policy_class(mode='share')):
    '''
    Category 51, https://arxiv.org/abs/1707.06887
    No double, no dueling, no noisy net.
    '''

    def __init__(self,
                 s_dim,
                 visual_sources,
                 visual_resolution,
                 a_dim,
                 is_continuous,

                 v_min=-10,
                 v_max=10,
                 atoms=51,
                 lr=5.0e-4,
                 eps_init=1,
                 eps_mid=0.2,
                 eps_final=0.01,
                 init2mid_annealing_step=1000,
                 assign_interval=1000,
                 hidden_units=[128, 128],
                 **kwargs):
        assert not is_continuous, 'c51 only support discrete action space'
        super().__init__(
            s_dim=s_dim,
            visual_sources=visual_sources,
            visual_resolution=visual_resolution,
            a_dim=a_dim,
            is_continuous=is_continuous,
            **kwargs)
        self.v_min = v_min
        self.v_max = v_max
        self.atoms = atoms
        self.delta_z = (self.v_max - self.v_min) / (self.atoms - 1)
        self.z = tf.reshape(tf.constant([self.v_min + i * self.delta_z for i in range(self.atoms)], dtype=tf.float32), [-1, self.atoms])  # [1, N]
        self.zb = tf.tile(self.z, tf.constant([self.a_dim, 1]))  # [A, N]
        self.expl_expt_mng = ExplorationExploitationClass(eps_init=eps_init,
                                                          eps_mid=eps_mid,
                                                          eps_final=eps_final,
                                                          init2mid_annealing_step=init2mid_annealing_step,
                                                          max_step=self.max_train_step)
        self.assign_interval = assign_interval

        def _net(): return NetWork(self.feat_dim, self.a_dim, self.atoms, hidden_units)

        self.q_dist_net = _net()
        self.q_target_dist_net = _net()
        self.critic_tv = self.q_dist_net.trainable_variables + self.other_tv
        update_target_net_weights(self.q_target_dist_net.weights, self.q_dist_net.weights)
        self.lr = self.init_lr(lr)
        self.optimizer = self.init_optimizer(self.lr)

        self.model_recorder(dict(
            model=self.q_dist_net,
            optimizer=self.optimizer
        ))

    def show_logo(self):
        self.logger.info('''
     xxxxxxx         xxxxx          xxx      
    xxxx xxx         xxxx          xxxx      
   xxxx    x        xxxx             xx      
   xxx     x        xxxxx            xx      
   xxx                xxx            xx      
   xxx                 xxx           xx      
   xxx                  xx           xx      
    xxx    x        xx xx            xx      
    xxxxxxxx        xxxxx           xxxx     
      xxxxx          x              xxxx    
        ''')

    def choose_action(self, s, visual_s, evaluation=False):
        if np.random.uniform() < self.expl_expt_mng.get_esp(self.train_step, evaluation=evaluation):
            a = np.random.randint(0, self.a_dim, self.n_agents)
        else:
            a, self.cell_state = self._get_action(s, visual_s, self.cell_state)
            a = a.numpy()
        return a

    @tf.function
    def _get_action(self, s, visual_s, cell_state):
        with tf.device(self.device):
            feat, cell_state = self.get_feature(s, visual_s, cell_state=cell_state, record_cs=True)
            q = self.get_q(feat)  # [B, A]
        return tf.argmax(q, axis=-1), cell_state  # [B, 1]

    def learn(self, **kwargs):
        self.train_step = kwargs.get('train_step')

        def _update():
            if self.global_step % self.assign_interval == 0:
                update_target_net_weights(self.q_target_dist_net.weights, self.q_dist_net.weights)
        for i in range(self.train_times_per_step):
            self._learn(function_dict={
                'train_function': self.train,
                'update_function': _update,
                'summary_dict': dict([['LEARNING_RATE/lr', self.lr(self.train_step)]])
            })

    @tf.function(experimental_relax_shapes=True)
    def train(self, memories, isw, crsty_loss, cell_state):
        ss, vvss, a, r, done = memories
        batch_size = tf.shape(a)[0]
        with tf.device(self.device):
            with tf.GradientTape() as tape:
                feat, feat_ = self.get_feature(ss, vvss, cell_state=cell_state, s_and_s_=True)
                indexs = tf.reshape(tf.range(batch_size), [-1, 1])  # [B, 1]
                q_dist = self.q_dist_net(feat)  # [B, A, N]
                q_dist = tf.transpose(tf.reduce_sum(tf.transpose(q_dist, [2, 0, 1]) * a, axis=-1), [1, 0])  # [B, N]
                q_eval = tf.reduce_sum(q_dist * self.z, axis=-1)
                target_q_dist = self.q_target_dist_net(feat_)  # [B, A, N]
                target_q = tf.reduce_sum(self.zb * target_q_dist, axis=-1)  # [B, A, N] => [B, A]
                a_ = tf.reshape(tf.cast(tf.argmax(target_q, axis=-1), dtype=tf.int32), [-1, 1])  # [B, 1]
                target_q_dist = tf.gather_nd(target_q_dist, tf.concat([indexs, a_], axis=-1))   # [B, N]
                target = tf.tile(r, tf.constant([1, self.atoms])) \
                    + self.gamma * tf.multiply(self.z,   # [1, N]
                                               (1.0 - tf.tile(done, tf.constant([1, self.atoms]))))  # [B, N], [1, N]* [B, N] = [B, N]
                target = tf.clip_by_value(target, self.v_min, self.v_max)  # [B, N]
                b = (target - self.v_min) / self.delta_z  # [B, N]
                u, l = tf.math.ceil(b), tf.math.floor(b)  # [B, N]
                u_id, l_id = tf.cast(u, tf.int32), tf.cast(l, tf.int32)  # [B, N]
                u_minus_b, b_minus_l = u - b, b - l  # [B, N]
                index_help = tf.tile(indexs, tf.constant([1, self.atoms]))  # [B, N]
                index_help = tf.expand_dims(index_help, -1)  # [B, N, 1]
                u_id = tf.concat([index_help, tf.expand_dims(u_id, -1)], axis=-1)    # [B, N, 2]
                l_id = tf.concat([index_help, tf.expand_dims(l_id, -1)], axis=-1)    # [B, N, 2]
                _cross_entropy = tf.stop_gradient(target_q_dist * u_minus_b) * tf.math.log(tf.gather_nd(q_dist, l_id)) \
                    + tf.stop_gradient(target_q_dist * b_minus_l) * tf.math.log(tf.gather_nd(q_dist, u_id))  # [B, N]
                # tf.debugging.check_numerics(_cross_entropy, '_cross_entropy')
                cross_entropy = -tf.reduce_sum(_cross_entropy, axis=-1)  # [B,]
                # tf.debugging.check_numerics(cross_entropy, 'cross_entropy')
                loss = tf.reduce_mean(cross_entropy * isw) + crsty_loss
                td_error = cross_entropy
            grads = tape.gradient(loss, self.critic_tv)
            self.optimizer.apply_gradients(
                zip(grads, self.critic_tv)
            )
            self.global_step.assign_add(1)
            return td_error, dict([
                ['LOSS/loss', loss],
                ['Statistics/q_max', tf.reduce_max(q_eval)],
                ['Statistics/q_min', tf.reduce_min(q_eval)],
                ['Statistics/q_mean', tf.reduce_mean(q_eval)]
            ])

    @tf.function(experimental_relax_shapes=True)
    def get_q(self, feat):
        with tf.device(self.device):
            return tf.reduce_sum(self.zb * self.q_dist_net(feat), axis=-1)  # [B, A, N] => [B, A]
Example n. 17
class IQN(Off_Policy):
    '''
    Implicit Quantile Networks, https://arxiv.org/abs/1806.06923
    Double DQN
    '''
    def __init__(self,
                 envspec,
                 online_quantiles=8,
                 target_quantiles=8,
                 select_quantiles=32,
                 quantiles_idx=64,
                 huber_delta=1.,
                 lr=5.0e-4,
                 eps_init=1,
                 eps_mid=0.2,
                 eps_final=0.01,
                 init2mid_annealing_step=1000,
                 assign_interval=2,
                 network_settings={
                     'q_net': [128, 64],
                     'quantile': [128, 64],
                     'tile': [64]
                 },
                 **kwargs):
        assert not envspec.is_continuous, 'IQN only supports discrete action spaces'
        super().__init__(envspec=envspec, **kwargs)
        self.pi = tf.constant(np.pi)
        self.online_quantiles = online_quantiles
        self.target_quantiles = target_quantiles
        self.select_quantiles = select_quantiles
        self.quantiles_idx = quantiles_idx
        self.huber_delta = huber_delta
        self.assign_interval = assign_interval
        self.expl_expt_mng = ExplorationExploitationClass(
            eps_init=eps_init,
            eps_mid=eps_mid,
            eps_final=eps_final,
            init2mid_annealing_step=init2mid_annealing_step,
            max_step=self.max_train_step)

        def _create_net(name, representation_net=None):
            return ValueNetwork(name=name,
                                representation_net=representation_net,
                                value_net_type=OutputNetworkType.IQN_NET,
                                value_net_kwargs=dict(
                                    action_dim=self.a_dim,
                                    quantiles_idx=self.quantiles_idx,
                                    network_settings=network_settings))

        self.q_net = _create_net('q_net', self._representation_net)
        self._representation_target_net = self._create_representation_net(
            '_representation_target_net')
        self.q_target_net = _create_net('q_target_net',
                                        self._representation_target_net)
        update_target_net_weights(self.q_target_net.weights,
                                  self.q_net.weights)
        self.lr = self.init_lr(lr)
        self.optimizer = self.init_optimizer(self.lr)

        self._worker_params_dict.update(self.q_net._policy_models)

        self._all_params_dict.update(self.q_net._all_models)
        self._all_params_dict.update(optimizer=self.optimizer)
        self._model_post_process()

    def choose_action(self, obs, evaluation=False):
        if np.random.uniform() < self.expl_expt_mng.get_esp(
                self.train_step, evaluation=evaluation):
            a = np.random.randint(0, self.a_dim, self.n_agents)
        else:
            a, self.cell_state = self._get_action(obs, self.cell_state)
            a = a.numpy()
        return a

    @tf.function
    def _get_action(self, obs, cell_state):
        batch_size = tf.shape(obs)[0]
        with tf.device(self.device):
            _, select_quantiles_tiled = self._generate_quantiles(  # [N*B, 64]
                batch_size=batch_size,
                quantiles_num=self.select_quantiles,
                quantiles_idx=self.quantiles_idx)
            # [B, A]
            (_, q_values), cell_state = self.q_net(
                obs,
                select_quantiles_tiled,
                quantiles_num=self.select_quantiles,
                cell_state=cell_state)
        return tf.argmax(q_values, axis=-1), cell_state  # [B,]

    @tf.function
    def _generate_quantiles(self, batch_size, quantiles_num, quantiles_idx):
        with tf.device(self.device):
            _quantiles = tf.random.uniform([batch_size * quantiles_num, 1],
                                           minval=0,
                                           maxval=1)  # [N*B, 1]
            _quantiles_tiled = tf.tile(
                _quantiles, [1, quantiles_idx])  # [N*B, 1] => [N*B, 64]
            _quantiles_tiled = tf.cast(
                tf.range(quantiles_idx), tf.float32
            ) * self.pi * _quantiles_tiled  # pi * i * tau [N*B, 64] * [64, ] => [N*B, 64]
            _quantiles_tiled = tf.cos(_quantiles_tiled)  # [N*B, 64]
            _quantiles = tf.reshape(
                _quantiles,
                [batch_size, quantiles_num, 1])  # [N*B, 1] => [B, N, 1]
            return _quantiles, _quantiles_tiled

    def _target_params_update(self):
        if self.global_step % self.assign_interval == 0:
            update_target_net_weights(self.q_target_net.weights,
                                      self.q_net.weights)

    def learn(self, **kwargs):
        self.train_step = kwargs.get('train_step')
        for i in range(self.train_times_per_step):
            self._learn(
                function_dict={
                    'summary_dict':
                    dict([['LEARNING_RATE/lr',
                           self.lr(self.train_step)]]),
                    'use_stack':
                    True
                })

    @tf.function
    def _train(self, BATCH, isw, cell_state):
        batch_size = tf.shape(BATCH.action)[0]
        with tf.device(self.device):
            with tf.GradientTape() as tape:
                (feat,
                 feat_), _ = self._representation_net(BATCH.obs,
                                                      cell_state=cell_state,
                                                      need_split=True)
                quantiles, quantiles_tiled = self._generate_quantiles(  # [B, N, 1], [N*B, 64]
                    batch_size=batch_size,
                    quantiles_num=self.online_quantiles,
                    quantiles_idx=self.quantiles_idx)
                quantiles_value, q = self.q_net.value_net(
                    feat, quantiles_tiled,
                    quantiles_num=self.online_quantiles)  # [N, B, A], [B, A]
                _a = tf.reshape(
                    tf.tile(BATCH.action, [self.online_quantiles, 1]),
                    [self.online_quantiles, -1, self.a_dim
                     ])  # [B, A] => [N*B, A] => [N, B, A]
                quantiles_value = tf.reduce_sum(
                    quantiles_value * _a, axis=-1,
                    keepdims=True)  # [N, B, A] => [N, B, 1]
                q_eval = tf.reduce_sum(q * BATCH.action,
                                       axis=-1,
                                       keepdims=True)  # [B, A] => [B, 1]

                _, select_quantiles_tiled = self._generate_quantiles(  # [N*B, 64]
                    batch_size=batch_size,
                    quantiles_num=self.select_quantiles,
                    quantiles_idx=self.quantiles_idx)
                _, q_values = self.q_net.value_net(
                    feat_,
                    select_quantiles_tiled,
                    quantiles_num=self.select_quantiles)  # [B, A]
                next_max_action = tf.argmax(q_values, axis=-1)  # [B,]
                next_max_action = tf.one_hot(tf.squeeze(next_max_action),
                                             self.a_dim,
                                             1.,
                                             0.,
                                             dtype=tf.float32)  # [B, A]
                _next_max_action = tf.reshape(
                    tf.tile(next_max_action, [self.target_quantiles, 1]),
                    [self.target_quantiles, -1, self.a_dim
                     ])  # [B, A] => [N'*B, A] => [N', B, A]
                _, target_quantiles_tiled = self._generate_quantiles(  # [N'*B, 64]
                    batch_size=batch_size,
                    quantiles_num=self.target_quantiles,
                    quantiles_idx=self.quantiles_idx)

                (target_quantiles_value, target_q), _ = self.q_target_net(
                    BATCH.obs_,
                    target_quantiles_tiled,
                    quantiles_num=self.target_quantiles,
                    cell_state=cell_state)  # [N', B, A], [B, A]
                target_quantiles_value = tf.reduce_sum(
                    target_quantiles_value * _next_max_action,
                    axis=-1,
                    keepdims=True)  # [N', B, A] => [N', B, 1]
                target_q = tf.reduce_sum(target_q * BATCH.action,
                                         axis=-1,
                                         keepdims=True)  # [B, A] => [B, 1]
                q_target = tf.stop_gradient(
                    BATCH.reward + self.gamma *
                    (1 - BATCH.done) * target_q)  # [B, 1]
                td_error = q_target - q_eval  # [B, 1]

                _r = tf.reshape(
                    tf.tile(BATCH.reward, [self.target_quantiles, 1]),
                    [self.target_quantiles, -1, 1
                     ])  # [B, 1] => [N'*B, 1] => [N', B, 1]
                _done = tf.reshape(
                    tf.tile(BATCH.done, [self.target_quantiles, 1]),
                    [self.target_quantiles, -1, 1
                     ])  # [B, 1] => [N'*B, 1] => [N', B, 1]

                quantiles_value_target = tf.stop_gradient(
                    _r + self.gamma *
                    (1 - _done) * target_quantiles_value)  # [N', B, 1]
                quantiles_value_target = tf.transpose(quantiles_value_target,
                                                      [1, 2, 0])  # [B, 1, N']
                quantiles_value_online = tf.transpose(quantiles_value,
                                                      [1, 0, 2])  # [B, N, 1]
                quantile_error = quantiles_value_online - quantiles_value_target  # [B, N, 1] - [B, 1, N'] => [B, N, N']
                huber = huber_loss(quantile_error,
                                   delta=self.huber_delta)  # [B, N, N']
                huber_abs = tf.abs(
                    quantiles -
                    tf.where(quantile_error < 0, tf.ones_like(quantile_error),
                             tf.zeros_like(quantile_error))
                )  # [B, N, 1] - [B, N, N'] => [B, N, N']
                loss = tf.reduce_mean(huber_abs * huber,
                                      axis=-1)  # [B, N, N'] => [B, N]
                loss = tf.reduce_sum(loss, axis=-1)  # [B, N] => [B, ]
                loss = tf.reduce_mean(loss * isw)  # [B, ] => 1
            grads = tape.gradient(loss, self.q_net.trainable_variables)
            self.optimizer.apply_gradients(
                zip(grads, self.q_net.trainable_variables))
            self.global_step.assign_add(1)
            return td_error, dict(
                [['LOSS/loss', loss],
                 ['Statistics/q_max',
                  tf.reduce_max(q_eval)],
                 ['Statistics/q_min',
                  tf.reduce_min(q_eval)],
                 ['Statistics/q_mean',
                  tf.reduce_mean(q_eval)]])
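Two pieces of the `_train` loss above are easy to lose in the shape bookkeeping: the cosine quantile embedding that `_generate_quantiles` feeds to the network, and the asymmetric weight |τ − 1{δ < 0}| applied to the pairwise Huber terms. Below is a hedged NumPy sketch of just those two pieces, following the same sign convention as the code (δ = online − target); the function names are illustrative and not part of the class.

import numpy as np

# Sketch of the two IQN ingredients used in _train above (illustrative names only).
def cosine_embedding(tau, embed_dim=64):
    """tau: [M, 1] samples from U(0, 1) -> cos(pi * i * tau), shape [M, embed_dim]."""
    i = np.arange(embed_dim, dtype=np.float64)        # i = 0 .. embed_dim - 1
    return np.cos(np.pi * i * tau)                    # broadcasts [M, 1] against [embed_dim]

def quantile_huber_weights(tau, quantile_error):
    """tau: [B, N, 1]; quantile_error = online - target quantile values, [B, N, N'].
    Returns |tau - 1{error < 0}|, the asymmetric weight multiplying the Huber loss."""
    indicator = (quantile_error < 0).astype(np.float64)
    return np.abs(tau - indicator)                    # [B, N, N']

Broadcasting `tau` of shape [B, N, 1] against the indicator of shape [B, N, N'] pairs every online quantile with every target quantile, which is why the loss above is averaged over the last (target-quantile) axis and summed over the online-quantile axis before the importance-sampling weights are applied.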
Example n. 18
class AveragedDQN(Off_Policy):
    '''
    Averaged-DQN, http://arxiv.org/abs/1611.01929
    '''

    def __init__(self,
                 envspec,

                 target_k: int = 4,
                 lr: float = 5.0e-4,
                 eps_init: float = 1,
                 eps_mid: float = 0.2,
                 eps_final: float = 0.01,
                 init2mid_annealing_step: int = 1000,
                 assign_interval: int = 1000,
                 network_settings: List[int] = [32, 32],
                 **kwargs):
        assert not envspec.is_continuous, 'Averaged-DQN only supports discrete action spaces'
        super().__init__(envspec=envspec, **kwargs)
        self.expl_expt_mng = ExplorationExploitationClass(eps_init=eps_init,
                                                          eps_mid=eps_mid,
                                                          eps_final=eps_final,
                                                          init2mid_annealing_step=init2mid_annealing_step,
                                                          max_step=self.max_train_step)
        self.assign_interval = assign_interval
        self.target_k = target_k
        assert self.target_k > 0, 'target_k must be a positive integer'
        self.target_nets = []
        self.current_target_idx = 0

        def _create_net(name, representation_net=None):
            return ValueNetwork(
                name=name,
                representation_net=representation_net,
                value_net_type=OutputNetworkType.CRITIC_QVALUE_ALL,
                value_net_kwargs=dict(output_shape=self.a_dim,
                                      network_settings=network_settings))
        self.q_net = _create_net('dqn_q_net', self._representation_net)

        for i in range(self.target_k):
            target_q_net = _create_net(
                'dqn_q_target_net' + str(i),
                self._create_representation_net('_representation_target_net' + str(i))
            )
            update_target_net_weights(target_q_net.weights, self.q_net.weights)
            self.target_nets.append(target_q_net)

        self.lr = self.init_lr(lr)
        self.optimizer = self.init_optimizer(self.lr)

        self._worker_params_dict.update(self.q_net._policy_models)

        self._all_params_dict.update(self.q_net._all_models)
        self._all_params_dict.update(optimizer=self.optimizer)
        self._model_post_process()

    def choose_action(self, obs, evaluation: bool = False) -> np.ndarray:
        if np.random.uniform() < self.expl_expt_mng.get_esp(self.train_step, evaluation=evaluation):
            a = np.random.randint(0, self.a_dim, self.n_agents)
        else:
            a, self.cell_state = self._get_action(obs, self.cell_state)
            a = a.numpy()
        return a

    @tf.function
    def _get_action(self, obs, cell_state):
        with tf.device(self.device):
            q_values, cell_state = self.q_net(obs, cell_state=cell_state)
            for i in range(1, self.target_k):
                target_q_values, _ = self.target_nets[i](obs, cell_state=cell_state)
                q_values += target_q_values
        return tf.argmax(q_values, axis=1), cell_state  # dividing by target_k is unnecessary: argmax is invariant to the scaling

    def _target_params_update(self):
        if self.global_step % self.assign_interval == 0:
            update_target_net_weights(self.target_nets[self.current_target_idx].weights, self.q_net.weights)
            self.current_target_idx = (self.current_target_idx + 1) % self.target_k

    def learn(self, **kwargs) -> NoReturn:
        self.train_step = kwargs.get('train_step')
        for i in range(self.train_times_per_step):
            self._learn(function_dict={
                'summary_dict': dict([['LEARNING_RATE/lr', self.lr(self.train_step)]])
            })

    @tf.function
    def _train(self, BATCH, isw, cell_state):
        with tf.device(self.device):
            with tf.GradientTape() as tape:
                q, _ = self.q_net(BATCH.obs, cell_state=cell_state)
                q_next, _ = self.target_nets[0](BATCH.obs_, cell_state=cell_state)
                for i in range(1, self.target_k):
                    target_q_values, _ = self.target_nets[i](BATCH.obs_, cell_state=cell_state)
                    q_next += target_q_values
                q_next /= self.target_k
                q_eval = tf.reduce_sum(tf.multiply(q, BATCH.action), axis=1, keepdims=True)
                q_target = tf.stop_gradient(BATCH.reward + self.gamma * (1 - BATCH.done) * tf.reduce_max(q_next, axis=1, keepdims=True))
                td_error = q_target - q_eval
                q_loss = tf.reduce_mean(tf.square(td_error) * isw)
            grads = tape.gradient(q_loss, self.q_net.trainable_variables)
            self.optimizer.apply_gradients(
                zip(grads, self.q_net.trainable_variables)
            )
            self.global_step.assign_add(1)
            return td_error, dict([
                ['LOSS/loss', q_loss],
                ['Statistics/q_max', tf.reduce_max(q_eval)],
                ['Statistics/q_min', tf.reduce_min(q_eval)],
                ['Statistics/q_mean', tf.reduce_mean(q_eval)]
            ])
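Averaged-DQN bootstraps from the average of the K most recently frozen copies of the online network instead of a single target network, which is exactly what the loop over `self.target_nets` in `_train` accumulates before the `reduce_max`. A minimal sketch of that target, with hypothetical inputs, is:

import numpy as np

# Minimal sketch of the Averaged-DQN bootstrap target (illustrative names only).
def averaged_dqn_target(target_q_list, reward, done, gamma=0.99):
    """target_q_list: K arrays of next-state Q-values, each [B, A];
    reward, done: [B, 1]. Returns the TD target, [B, 1]."""
    q_next = np.mean(np.stack(target_q_list, axis=0), axis=0)              # [K, B, A] -> [B, A]
    return reward + gamma * (1.0 - done) * q_next.max(axis=-1, keepdims=True)

`_target_params_update` overwrites the snapshots in round-robin order via `current_target_idx`, so after warm-up the list always holds the K most recent copies of `q_net`.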
Example n. 19
class IQN(make_off_policy_class(mode='share')):
    '''
    Implicit Quantile Networks, https://arxiv.org/abs/1806.06923
    Target actions are selected with the online network (Double DQN style).
    '''

    def __init__(self,
                 s_dim,
                 visual_sources,
                 visual_resolution,
                 a_dim,
                 is_continuous,

                 online_quantiles=8,
                 target_quantiles=8,
                 select_quantiles=32,
                 quantiles_idx=64,
                 huber_delta=1.,
                 lr=5.0e-4,
                 eps_init=1,
                 eps_mid=0.2,
                 eps_final=0.01,
                 init2mid_annealing_step=1000,
                 assign_interval=2,
                 hidden_units={
                     'q_net': [128, 64],
                     'quantile': [128, 64],
                     'tile': [64]
                 },
                 **kwargs):
        assert not is_continuous, 'IQN only supports discrete action spaces'
        super().__init__(
            s_dim=s_dim,
            visual_sources=visual_sources,
            visual_resolution=visual_resolution,
            a_dim=a_dim,
            is_continuous=is_continuous,
            **kwargs)
        self.pi = tf.constant(np.pi)
        self.online_quantiles = online_quantiles
        self.target_quantiles = target_quantiles
        self.select_quantiles = select_quantiles
        self.quantiles_idx = quantiles_idx
        self.huber_delta = huber_delta
        self.assign_interval = assign_interval
        self.expl_expt_mng = ExplorationExploitationClass(eps_init=eps_init,
                                                          eps_mid=eps_mid,
                                                          eps_final=eps_final,
                                                          init2mid_annealing_step=init2mid_annealing_step,
                                                          max_step=self.max_train_step)

        def _net():
            return NetWork(self.feat_dim, self.a_dim, self.quantiles_idx, hidden_units)

        self.q_net = _net()
        self.q_target_net = _net()
        self.critic_tv = self.q_net.trainable_variables + self.other_tv
        update_target_net_weights(self.q_target_net.weights, self.q_net.weights)
        self.lr = self.init_lr(lr)
        self.optimizer = self.init_optimizer(self.lr)

        self.model_recorder(dict(
            model=self.q_net,
            optimizer=self.optimizer
        ))

    def show_logo(self):
        self.logger.info('''
    xxxxxxxx       xxxxxxx       xxx    xxx  
    xxxxxxxx      xxxxxxxxx      xxxx   xxx  
      xxx         xxxx  xxxx     xxxxx  xxx  
      xxx         xxx    xxx     xxxxx  xxx  
      xxx        xxxx    xxx     xxxxxx xxx  
      xxx        xxxx    xxx     xxxxxxxxxx  
      xxx        xxxx    xxx     xxx xxxxxx  
      xxx         xxxx  xxxx     xxx xxxxxx  
    xxxxxxxx      xxxxxxxxx      xxx  xxxxx  
    xxxxxxxx       xxxxxxx       xxx   xxxx  
                      xxxx                   
                       xxxx                  
                        xxxx                          
        ''')

    def choose_action(self, s, visual_s, evaluation=False):
        if np.random.uniform() < self.expl_expt_mng.get_esp(self.train_step, evaluation=evaluation):
            a = np.random.randint(0, self.a_dim, self.n_agents)
        else:
            a, self.cell_state = self._get_action(s, visual_s, self.cell_state)
            a = a.numpy()
        return a

    @tf.function
    def _get_action(self, s, visual_s, cell_state):
        batch_size = tf.shape(s)[0]
        with tf.device(self.device):
            feat, cell_state = self.get_feature(s, visual_s, cell_state=cell_state, record_cs=True)
            _, select_quantiles_tiled = self._generate_quantiles(   # [N*B, 64]
                batch_size=batch_size,
                quantiles_num=self.select_quantiles,
                quantiles_idx=self.quantiles_idx
            )
            _, q_values = self.q_net(feat, select_quantiles_tiled, quantiles_num=self.select_quantiles)  # [B, A]
        return tf.argmax(q_values, axis=-1), cell_state  # [B,]

    @tf.function
    def _generate_quantiles(self, batch_size, quantiles_num, quantiles_idx):
        with tf.device(self.device):
            _quantiles = tf.random.uniform([batch_size * quantiles_num, 1], minval=0, maxval=1)  # [N*B, 1]
            _quantiles_tiled = tf.tile(_quantiles, [1, quantiles_idx])  # [N*B, 1] => [N*B, 64]
            _quantiles_tiled = tf.cast(tf.range(quantiles_idx), tf.float32) * self.pi * _quantiles_tiled  # pi * i * tau [N*B, 64] * [64, ] => [N*B, 64]
            _quantiles_tiled = tf.cos(_quantiles_tiled)   # [N*B, 64]
            _quantiles = tf.reshape(_quantiles, [batch_size, quantiles_num, 1])    # [N*B, 1] => [B, N, 1]
            return _quantiles, _quantiles_tiled

    def learn(self, **kwargs):
        self.train_step = kwargs.get('train_step')

        def _update():
            if self.global_step % self.assign_interval == 0:
                update_target_net_weights(self.q_target_net.weights, self.q_net.weights)
        for i in range(self.train_times_per_step):
            self._learn(function_dict={
                'train_function': self.train,
                'update_function': _update,
                'summary_dict': dict([['LEARNING_RATE/lr', self.lr(self.train_step)]])
            })

    @tf.function(experimental_relax_shapes=True)
    def train(self, memories, isw, crsty_loss, cell_state):
        ss, vvss, a, r, done = memories
        batch_size = tf.shape(a)[0]
        with tf.device(self.device):
            with tf.GradientTape() as tape:
                feat, feat_ = self.get_feature(ss, vvss, cell_state=cell_state, s_and_s_=True)
                quantiles, quantiles_tiled = self._generate_quantiles(   # [B, N, 1], [N*B, 64]
                    batch_size=batch_size,
                    quantiles_num=self.online_quantiles,
                    quantiles_idx=self.quantiles_idx
                )
                quantiles_value, q = self.q_net(feat, quantiles_tiled, quantiles_num=self.online_quantiles)    # [N, B, A], [B, A]
                _a = tf.reshape(tf.tile(a, [self.online_quantiles, 1]), [self.online_quantiles, -1, self.a_dim])  # [B, A] => [N*B, A] => [N, B, A]
                quantiles_value = tf.reduce_sum(quantiles_value * _a, axis=-1, keepdims=True)   # [N, B, A] => [N, B, 1]
                q_eval = tf.reduce_sum(q * a, axis=-1, keepdims=True)  # [B, A] => [B, 1]

                _, select_quantiles_tiled = self._generate_quantiles(   # [N*B, 64]
                    batch_size=batch_size,
                    quantiles_num=self.select_quantiles,
                    quantiles_idx=self.quantiles_idx
                )
                _, q_values = self.q_net(feat_, select_quantiles_tiled, quantiles_num=self.select_quantiles)  # [B, A]
                next_max_action = tf.argmax(q_values, axis=-1)   # [B,]
                next_max_action = tf.one_hot(tf.squeeze(next_max_action), self.a_dim, 1., 0., dtype=tf.float32)  # [B, A]
                _next_max_action = tf.reshape(tf.tile(next_max_action, [self.target_quantiles, 1]), [self.target_quantiles, -1, self.a_dim])  # [B, A] => [N'*B, A] => [N', B, A]
                _, target_quantiles_tiled = self._generate_quantiles(   # [N'*B, 64]
                    batch_size=batch_size,
                    quantiles_num=self.target_quantiles,
                    quantiles_idx=self.quantiles_idx
                )

                target_quantiles_value, target_q = self.q_target_net(feat_, target_quantiles_tiled, quantiles_num=self.target_quantiles)  # [N', B, A], [B, A]
                target_quantiles_value = tf.reduce_sum(target_quantiles_value * _next_max_action, axis=-1, keepdims=True)   # [N', B, A] => [N', B, 1]
                target_q = tf.reduce_sum(target_q * a, axis=-1, keepdims=True)  # [B, A] => [B, 1]
                q_target = tf.stop_gradient(r + self.gamma * (1 - done) * target_q)   # [B, 1]
                td_error = q_eval - q_target    # [B, 1]

                _r = tf.reshape(tf.tile(r, [self.target_quantiles, 1]), [self.target_quantiles, -1, 1])  # [B, 1] => [N'*B, 1] => [N', B, 1]
                _done = tf.reshape(tf.tile(done, [self.target_quantiles, 1]), [self.target_quantiles, -1, 1])    # [B, 1] => [N'*B, 1] => [N', B, 1]

                quantiles_value_target = tf.stop_gradient(_r + self.gamma * (1 - _done) * target_quantiles_value)   # [N', B, 1]
                quantiles_value_target = tf.transpose(quantiles_value_target, [1, 2, 0])    # [B, 1, N']
                quantiles_value_online = tf.transpose(quantiles_value, [1, 0, 2])   # [B, N, 1]
                quantile_error = quantiles_value_online - quantiles_value_target    # [B, N, 1] - [B, 1, N'] => [B, N, N']
                huber = huber_loss(quantile_error, delta=self.huber_delta)  # [B, N, N']
                huber_abs = tf.abs(quantiles - tf.where(quantile_error < 0, tf.ones_like(quantile_error), tf.zeros_like(quantile_error)))   # [B, N, 1] - [B, N, N'] => [B, N, N']
                loss = tf.reduce_mean(huber_abs * huber, axis=-1)  # [B, N, N'] => [B, N]
                loss = tf.reduce_sum(loss, axis=-1)  # [B, N] => [B, ]
                loss = tf.reduce_mean(loss * isw) + crsty_loss  # [B, ] => 1
            grads = tape.gradient(loss, self.critic_tv)
            self.optimizer.apply_gradients(
                zip(grads, self.critic_tv)
            )
            self.global_step.assign_add(1)
            return td_error, dict([
                ['LOSS/loss', loss],
                ['Statistics/q_max', tf.reduce_max(q_eval)],
                ['Statistics/q_min', tf.reduce_min(q_eval)],
                ['Statistics/q_mean', tf.reduce_mean(q_eval)]
            ])
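All of the `choose_action` methods above delegate the exploration schedule to `ExplorationExploitationClass.get_esp`, whose implementation is not part of these snippets. Purely as orientation, a two-phase linear annealing consistent with the constructor arguments (eps_init → eps_mid over `init2mid_annealing_step` steps, then eps_mid → eps_final over the remaining steps) might look like the sketch below; this is an assumption, not the library's actual code.

# Hypothetical sketch of a two-phase linear epsilon schedule; the real
# ExplorationExploitationClass may differ (e.g. exponential decay).
def annealed_epsilon(step, eps_init=1.0, eps_mid=0.2, eps_final=0.01,
                     init2mid_annealing_step=1000, max_step=10000,
                     evaluation=False):
    if evaluation:
        return eps_final                      # act (almost) greedily when evaluating
    if step < init2mid_annealing_step:
        frac = step / init2mid_annealing_step
        return eps_init + frac * (eps_mid - eps_init)
    remaining = max(1, max_step - init2mid_annealing_step)
    frac = min(1.0, (step - init2mid_annealing_step) / remaining)
    return eps_mid + frac * (eps_final - eps_mid)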