class RAINBOW(Off_Policy): ''' Rainbow DQN: https://arxiv.org/abs/1710.02298 1. Double 2. Dueling 3. PrioritizedExperienceReplay 4. N-Step 5. Distributional 6. Noisy Net ''' def __init__(self, s_dim, visual_sources, visual_resolution, a_dim_or_list, is_continuous, v_min=-10, v_max=10, atoms=51, lr=5.0e-4, eps_init=1, eps_mid=0.2, eps_final=0.01, init2mid_annealing_episode=100, assign_interval=2, hidden_units={ 'share': [128], 'v': [128], 'adv': [128] }, **kwargs): assert not is_continuous, 'rainbow only support discrete action space' super().__init__(s_dim=s_dim, visual_sources=visual_sources, visual_resolution=visual_resolution, a_dim_or_list=a_dim_or_list, is_continuous=is_continuous, **kwargs) self.v_min = v_min self.v_max = v_max self.atoms = atoms self.delta_z = (self.v_max - self.v_min) / (self.atoms - 1) self.z = tf.reshape( tf.constant( [self.v_min + i * self.delta_z for i in range(self.atoms)], dtype=tf.float32), [-1, self.atoms]) # [1, N] self.zb = tf.tile(self.z, tf.constant([self.a_counts, 1])) # [A, N] self.expl_expt_mng = ExplorationExploitationClass( eps_init=eps_init, eps_mid=eps_mid, eps_final=eps_final, init2mid_annealing_episode=init2mid_annealing_episode, max_episode=self.max_episode) self.assign_interval = assign_interval self.visual_net = Nn.VisualNet('visual_net', self.visual_dim) self.rainbow_net = Nn.rainbow_dueling(self.s_dim, self.a_counts, self.atoms, 'rainbow_net', hidden_units, visual_net=self.visual_net) self.rainbow_target_net = Nn.rainbow_dueling( self.s_dim, self.a_counts, self.atoms, 'rainbow_target_net', hidden_units, visual_net=self.visual_net) self.update_target_net_weights(self.rainbow_target_net.weights, self.rainbow_net.weights) self.lr = tf.keras.optimizers.schedules.PolynomialDecay( lr, self.max_episode, 1e-10, power=1.0) self.optimizer = tf.keras.optimizers.Adam( learning_rate=self.lr(self.episode)) def show_logo(self): self.recorder.logger.info(''' x x xxxxxxx xxx xxx xxxxxxx xx xx xx xxx xx xx xx xxxx xx xxx xxxxx xxx xxxxxx xxxxxx xxxxxx xxxxxx xx xxxxxx xx xx xxx xxxxxx xxx xxx xxx xx xxx xx xx xxxxxx xxxx xx xx xx xx xx xx xx xx xx xx xx xxxx xxxxx xx xx xx xx xx xxx xx xxxxxxx xx xxx xx xx x xx xx xx xx xxx xx xxx xxxxxxx xxxxx xxxx xxxxxxx xxx xxx xxx xxxxx xxxxx xx xx xxxxx xxxx xxx xx xxx xxxxxxx xxx xx xx ''') def choose_action(self, s, visual_s, evaluation=False): if np.random.uniform() < self.expl_expt_mng.get_esp( self.episode, evaluation=evaluation): a = np.random.randint(0, self.a_counts, len(s)) else: a = self._get_action(s, visual_s).numpy() return sth.int2action_index(a, self.a_dim_or_list) @tf.function def _get_action(self, s, visual_s): s, visual_s = self.cast(s, visual_s) with tf.device(self.device): q = self.get_q(s, visual_s) # [B, A] return tf.argmax(q, axis=-1) # [B, 1] def learn(self, **kwargs): self.episode = kwargs['episode'] for i in range(kwargs['step']): if self.data.is_lg_batch_size: s, visual_s, a, r, s_, visual_s_, done = self.data.sample() if self.use_priority: self.IS_w = self.data.get_IS_w() td_error, summaries = self.train(s, visual_s, a, r, s_, visual_s_, done) if self.use_priority: td_error = np.squeeze(td_error.numpy()) self.data.update(td_error, self.episode) if self.global_step % self.assign_interval == 0: self.update_target_net_weights( self.rainbow_target_net.weights, self.rainbow_net.weights) summaries.update( dict([['LEARNING_RATE/lr', self.lr(self.episode)]])) self.write_training_summaries(self.global_step, summaries) @tf.function(experimental_relax_shapes=True) def train(self, s, visual_s, a, r, 
s_, visual_s_, done): s, visual_s, a, r, s_, visual_s_, done = self.cast( s, visual_s, a, r, s_, visual_s_, done) with tf.device(self.device): with tf.GradientTape() as tape: indexs = tf.reshape(tf.range(s.shape[0]), [-1, 1]) # [B, 1] q_dist = self.rainbow_net(s, visual_s) # [B, A, N] q_dist = tf.transpose( tf.reduce_sum(tf.transpose(q_dist, [2, 0, 1]) * a, axis=-1), [1, 0]) # [B, N] q_eval = tf.reduce_sum(q_dist * self.z, axis=-1) target_q = self.get_q(s_, visual_s_) # [B, A] a_ = tf.reshape( tf.cast(tf.argmax(target_q, axis=-1), dtype=tf.int32), [-1, 1]) # [B, 1] target_q_dist = self.rainbow_target_net(s_, visual_s_) # [B, A, N] target_q_dist = tf.gather_nd(target_q_dist, tf.concat([indexs, a_], axis=-1)) # [B, N] target = tf.tile(r, tf.constant([1, self.atoms])) \ + self.gamma * tf.multiply(self.z, # [1, N] (1.0 - tf.tile(done, tf.constant([1, self.atoms])))) # [B, N], [1, N]* [B, N] = [B, N] target = tf.clip_by_value(target, self.v_min, self.v_max) # [B, N] b = (target - self.v_min) / self.delta_z # [B, N] u, l = tf.math.ceil(b), tf.math.floor(b) # [B, N] u_id, l_id = tf.cast(u, tf.int32), tf.cast(l, tf.int32) # [B, N] u_minus_b, b_minus_l = u - b, b - l # [B, N] index_help = tf.tile(indexs, tf.constant([1, self.atoms])) # [B, N] index_help = tf.expand_dims(index_help, -1) # [B, N, 1] u_id = tf.concat( [index_help, tf.expand_dims(u_id, -1)], axis=-1) # [B, N, 2] l_id = tf.concat( [index_help, tf.expand_dims(l_id, -1)], axis=-1) # [B, N, 2] _cross_entropy = tf.stop_gradient(target_q_dist * u_minus_b) * tf.math.log(tf.gather_nd(q_dist, l_id)) \ + tf.stop_gradient(target_q_dist * b_minus_l) * tf.math.log(tf.gather_nd(q_dist, u_id)) # [B, N] cross_entropy = -tf.reduce_sum(_cross_entropy, axis=-1) # [B,] loss = tf.reduce_mean(cross_entropy * self.IS_w) td_error = cross_entropy grads = tape.gradient(loss, self.rainbow_net.tv) self.optimizer.apply_gradients(zip(grads, self.rainbow_net.tv)) self.global_step.assign_add(1) return td_error, dict( [['LOSS/loss', loss], ['Statistics/q_max', tf.reduce_max(q_eval)], ['Statistics/q_min', tf.reduce_min(q_eval)], ['Statistics/q_mean', tf.reduce_mean(q_eval)]]) @tf.function(experimental_relax_shapes=True) def get_q(self, s, visual_s): with tf.device(self.device): return tf.reduce_sum(self.zb * self.rainbow_net(s, visual_s), axis=-1) # [B, A, N] => [B, A]
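# Hedged sketch: C51 distributional target projection. The train() above forms the
# categorical cross-entropy target by splitting the probability mass of each shifted
# atom r + gamma*z between its two nearest support atoms (the u/l, u_minus_b/b_minus_l
# terms). The standalone NumPy sketch below shows that projection step in isolation;
# the function and argument names are illustrative and not part of this repository.
import numpy as np

def project_categorical_target(next_probs, r, done, gamma, v_min, v_max, atoms):
    """Project r + gamma*(1-done)*z onto the fixed support z. Shapes: [B, N] in and out."""
    delta_z = (v_max - v_min) / (atoms - 1)
    z = v_min + delta_z * np.arange(atoms)                                       # [N]
    tz = np.clip(r[:, None] + gamma * (1.0 - done[:, None]) * z, v_min, v_max)   # [B, N]
    b = (tz - v_min) / delta_z                                                   # fractional atom index
    l, u = np.floor(b).astype(int), np.ceil(b).astype(int)
    proj = np.zeros_like(next_probs)
    for i in range(next_probs.shape[0]):
        for j in range(atoms):
            if l[i, j] == u[i, j]:
                # shifted atom lands exactly on a support point
                proj[i, l[i, j]] += next_probs[i, j]
            else:
                # split the mass between the two neighbouring atoms
                proj[i, l[i, j]] += next_probs[i, j] * (u[i, j] - b[i, j])
                proj[i, u[i, j]] += next_probs[i, j] * (b[i, j] - l[i, j])
    return proj
# The per-sample loss is then the cross-entropy -sum(proj * log(predicted_dist)).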
class IQN(make_off_policy_class(mode='share')): ''' Implicit Quantile Networks, https://arxiv.org/abs/1806.06923 Double DQN ''' def __init__(self, s_dim, visual_sources, visual_resolution, a_dim, is_continuous, online_quantiles=8, target_quantiles=8, select_quantiles=32, quantiles_idx=64, huber_delta=1., lr=5.0e-4, eps_init=1, eps_mid=0.2, eps_final=0.01, init2mid_annealing_episode=100, assign_interval=2, hidden_units={ 'q_net': [128, 64], 'quantile': [128, 64], 'tile': [64] }, **kwargs): assert not is_continuous, 'iqn only support discrete action space' super().__init__(s_dim=s_dim, visual_sources=visual_sources, visual_resolution=visual_resolution, a_dim=a_dim, is_continuous=is_continuous, **kwargs) self.pi = tf.constant(np.pi) self.online_quantiles = online_quantiles self.target_quantiles = target_quantiles self.select_quantiles = select_quantiles self.quantiles_idx = quantiles_idx self.huber_delta = huber_delta self.assign_interval = assign_interval self.expl_expt_mng = ExplorationExploitationClass( eps_init=eps_init, eps_mid=eps_mid, eps_final=eps_final, init2mid_annealing_episode=init2mid_annealing_episode, max_episode=self.max_episode) _net = lambda: rls.iqn_net(self.feat_dim, self.a_dim, self. quantiles_idx, hidden_units) self.q_net = _net() self.q_target_net = _net() self.critic_tv = self.q_net.trainable_variables + self.other_tv self.update_target_net_weights(self.q_target_net.weights, self.q_net.weights) self.lr = self.init_lr(lr) self.optimizer = self.init_optimizer(self.lr) self.model_recorder(dict(model=self.q_net, optimizer=self.optimizer)) def show_logo(self): self.recorder.logger.info(''' xxxxxxxx xxxxxxx xxx xxx xxxxxxxx xxxxxxxxx xxxx xxx xxx xxxx xxxx xxxxx xxx xxx xxx xxx xxxxx xxx xxx xxxx xxx xxxxxx xxx xxx xxxx xxx xxxxxxxxxx xxx xxxx xxx xxx xxxxxx xxx xxxx xxxx xxx xxxxxx xxxxxxxx xxxxxxxxx xxx xxxxx xxxxxxxx xxxxxxx xxx xxxx xxxx xxxx xxxx ''') def choose_action(self, s, visual_s, evaluation=False): if np.random.uniform() < self.expl_expt_mng.get_esp( self.episode, evaluation=evaluation): a = np.random.randint(0, self.a_dim, self.n_agents) else: a, self.cell_state = self._get_action(s, visual_s, self.cell_state) a = a.numpy() return a @tf.function def _get_action(self, s, visual_s, cell_state): batch_size = tf.shape(s)[0] with tf.device(self.device): feat, cell_state = self.get_feature(s, visual_s, cell_state=cell_state, record_cs=True) _, select_quantiles_tiled = self._generate_quantiles( # [N*B, 64] batch_size=batch_size, quantiles_num=self.select_quantiles, quantiles_idx=self.quantiles_idx) _, q_values = self.q_net( feat, select_quantiles_tiled, quantiles_num=self.select_quantiles) # [B, A] return tf.argmax(q_values, axis=-1), cell_state # [B,] @tf.function def _generate_quantiles(self, batch_size, quantiles_num, quantiles_idx): with tf.device(self.device): _quantiles = tf.random.uniform([batch_size * quantiles_num, 1], minval=0, maxval=1) # [N*B, 1] _quantiles_tiled = tf.tile( _quantiles, [1, quantiles_idx]) # [N*B, 1] => [N*B, 64] _quantiles_tiled = tf.cast( tf.range(quantiles_idx), tf.float32 ) * self.pi * _quantiles_tiled # pi * i * tau [N*B, 64] * [64, ] => [N*B, 64] _quantiles_tiled = tf.cos(_quantiles_tiled) # [N*B, 64] _quantiles = tf.reshape( _quantiles, [batch_size, quantiles_num, 1]) # [N*B, 1] => [B, N, 1] return _quantiles, _quantiles_tiled def learn(self, **kwargs): self.episode = kwargs['episode'] def _update(): if self.global_step % self.assign_interval == 0: self.update_target_net_weights(self.q_target_net.weights, self.q_net.weights) for i in 
range(kwargs['step']): self._learn( function_dict={ 'train_function': self.train, 'update_function': _update, 'summary_dict': dict([['LEARNING_RATE/lr', self.lr(self.episode)]]) }) @tf.function(experimental_relax_shapes=True) def train(self, memories, isw, crsty_loss, cell_state): ss, vvss, a, r, done = memories batch_size = tf.shape(a)[0] with tf.device(self.device): with tf.GradientTape() as tape: feat, feat_ = self.get_feature(ss, vvss, cell_state=cell_state, s_and_s_=True) quantiles, quantiles_tiled = self._generate_quantiles( # [B, N, 1], [N*B, 64] batch_size=batch_size, quantiles_num=self.online_quantiles, quantiles_idx=self.quantiles_idx) quantiles_value, q = self.q_net( feat, quantiles_tiled, quantiles_num=self.online_quantiles) # [N, B, A], [B, A] _a = tf.reshape(tf.tile(a, [self.online_quantiles, 1]), [self.online_quantiles, -1, self.a_dim ]) # [B, A] => [N*B, A] => [N, B, A] quantiles_value = tf.reduce_sum( quantiles_value * _a, axis=-1, keepdims=True) # [N, B, A] => [N, B, 1] q_eval = tf.reduce_sum(q * a, axis=-1, keepdims=True) # [B, A] => [B, 1] _, select_quantiles_tiled = self._generate_quantiles( # [N*B, 64] batch_size=batch_size, quantiles_num=self.select_quantiles, quantiles_idx=self.quantiles_idx) _, q_values = self.q_net( feat_, select_quantiles_tiled, quantiles_num=self.select_quantiles) # [B, A] next_max_action = tf.argmax(q_values, axis=-1) # [B,] next_max_action = tf.one_hot(tf.squeeze(next_max_action), self.a_dim, 1., 0., dtype=tf.float32) # [B, A] _next_max_action = tf.reshape( tf.tile(next_max_action, [self.target_quantiles, 1]), [self.target_quantiles, -1, self.a_dim ]) # [B, A] => [N'*B, A] => [N', B, A] _, target_quantiles_tiled = self._generate_quantiles( # [N'*B, 64] batch_size=batch_size, quantiles_num=self.target_quantiles, quantiles_idx=self.quantiles_idx) target_quantiles_value, target_q = self.q_target_net( feat_, target_quantiles_tiled, quantiles_num=self.target_quantiles) # [N', B, A], [B, A] target_quantiles_value = tf.reduce_sum( target_quantiles_value * _next_max_action, axis=-1, keepdims=True) # [N', B, A] => [N', B, 1] target_q = tf.reduce_sum(target_q * a, axis=-1, keepdims=True) # [B, A] => [B, 1] q_target = tf.stop_gradient(r + self.gamma * (1 - done) * target_q) # [B, 1] td_error = q_eval - q_target # [B, 1] _r = tf.reshape(tf.tile(r, [self.target_quantiles, 1]), [self.target_quantiles, -1, 1 ]) # [B, 1] => [N'*B, 1] => [N', B, 1] _done = tf.reshape(tf.tile(done, [self.target_quantiles, 1]), [self.target_quantiles, -1, 1 ]) # [B, 1] => [N'*B, 1] => [N', B, 1] quantiles_value_target = tf.stop_gradient( _r + self.gamma * (1 - _done) * target_quantiles_value) # [N', B, 1] quantiles_value_target = tf.transpose(quantiles_value_target, [1, 2, 0]) # [B, 1, N'] quantiles_value_online = tf.transpose(quantiles_value, [1, 0, 2]) # [B, N, 1] quantile_error = quantiles_value_online - quantiles_value_target # [B, N, 1] - [B, 1, N'] => [B, N, N'] huber = huber_loss(quantile_error, delta=self.huber_delta) # [B, N, N'] huber_abs = tf.abs( quantiles - tf.where(quantile_error < 0, tf.ones_like(quantile_error), tf.zeros_like(quantile_error)) ) # [B, N, 1] - [B, N, N'] => [B, N, N'] loss = tf.reduce_mean(huber_abs * huber, axis=-1) # [B, N, N'] => [B, N] loss = tf.reduce_sum(loss, axis=-1) # [B, N] => [B, ] loss = tf.reduce_mean(loss * isw) + crsty_loss # [B, ] => 1 grads = tape.gradient(loss, self.critic_tv) self.optimizer.apply_gradients(zip(grads, self.critic_tv)) self.global_step.assign_add(1) return td_error, dict( [['LOSS/loss', loss], ['Statistics/q_max', 
tf.reduce_max(q_eval)], ['Statistics/q_min', tf.reduce_min(q_eval)], ['Statistics/q_mean', tf.reduce_mean(q_eval)]])
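# Hedged sketch: IQN cosine quantile embedding. _generate_quantiles() above builds
# cos(pi * i * tau) features for each sampled quantile fraction tau; inside the network
# these are passed through a dense layer and combined with the state feature. A minimal
# NumPy version of that embedding, with illustrative weight names (not the internals of
# the repository's rls.iqn_net):
import numpy as np

def iqn_quantile_embedding(tau, w, b, quantiles_idx=64):
    """tau: [K] sampled fractions in (0, 1) -> embedding [K, feat_dim]."""
    i = np.arange(quantiles_idx)                     # 0 .. quantiles_idx-1
    cos_feat = np.cos(np.pi * i * tau[:, None])      # [K, quantiles_idx]
    return np.maximum(0.0, cos_feat @ w + b)         # ReLU(cos(pi*i*tau) @ W + b)

# usage (toy shapes): 8 quantile samples embedded into a 128-d feature
tau = np.random.uniform(size=8)
phi = iqn_quantile_embedding(tau, np.random.randn(64, 128) * 0.05, np.zeros(128))
# phi is then multiplied elementwise with the state feature before the Q head.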
class OC(Off_Policy): ''' The Option-Critic Architecture. http://arxiv.org/abs/1609.05140 ''' def __init__(self, s_dim, visual_sources, visual_resolution, a_dim_or_list, is_continuous, q_lr=5.0e-3, intra_option_lr=5.0e-4, termination_lr=5.0e-4, use_eps_greedy=False, eps_init=1, eps_mid=0.2, eps_final=0.01, init2mid_annealing_episode=100, boltzmann_temperature=1.0, options_num=4, ent_coff=0.01, double_q=False, use_baseline=True, terminal_mask=True, termination_regularizer=0.01, assign_interval=1000, hidden_units={ 'q': [32, 32], 'intra_option': [32, 32], 'termination': [32, 32] }, **kwargs): super().__init__(s_dim=s_dim, visual_sources=visual_sources, visual_resolution=visual_resolution, a_dim_or_list=a_dim_or_list, is_continuous=is_continuous, **kwargs) self.expl_expt_mng = ExplorationExploitationClass( eps_init=eps_init, eps_mid=eps_mid, eps_final=eps_final, init2mid_annealing_episode=init2mid_annealing_episode, max_episode=self.max_episode) self.assign_interval = assign_interval self.options_num = options_num self.termination_regularizer = termination_regularizer self.ent_coff = ent_coff self.use_baseline = use_baseline self.terminal_mask = terminal_mask self.double_q = double_q self.boltzmann_temperature = boltzmann_temperature self.use_eps_greedy = use_eps_greedy _q_net = lambda: Nn.critic_q_all(self.rnn_net.hdim, self.options_num, hidden_units['q']) self.q_net = _q_net() self.q_target_net = _q_net() self.intra_option_net = Nn.oc_intra_option( self.rnn_net.hdim, self.a_counts, self.options_num, hidden_units['intra_option']) self.termination_net = Nn.critic_q_all(self.rnn_net.hdim, self.options_num, hidden_units['termination'], 'sigmoid') self.critic_tv = self.q_net.trainable_variables + self.other_tv self.actor_tv = self.intra_option_net.trainable_variables if self.is_continuous: self.log_std = tf.Variable(initial_value=-0.5 * np.ones( (self.options_num, self.a_counts), dtype=np.float32), trainable=True) # [P, A] self.actor_tv += [self.log_std] self.update_target_net_weights(self.q_target_net.weights, self.q_net.weights) self.q_lr, self.intra_option_lr, self.termination_lr = map( self.init_lr, [q_lr, intra_option_lr, termination_lr]) self.q_optimizer = self.init_optimizer(self.q_lr, clipvalue=5.) self.intra_option_optimizer = self.init_optimizer(self.intra_option_lr, clipvalue=5.) self.termination_optimizer = self.init_optimizer(self.termination_lr, clipvalue=5.) 
self.model_recorder( dict(q_net=self.q_net, intra_option_net=self.intra_option_net, termination_net=self.termination_net, q_optimizer=self.q_optimizer, intra_option_optimizer=self.intra_option_optimizer, termination_optimizer=self.termination_optimizer)) def show_logo(self): self.recorder.logger.info(''' xxxxxx xxxxxxx xxx xxxx xxxx xxx xxx xxx xxxx x xx xxx xxx x xx xxx xxx xx xxx xxx xx xxx xxx xxx xxx xxx x xxxxxxxx xxxxxxxx xxxxx xxxxx ''') def _generate_random_options(self): return tf.constant(np.random.randint(0, self.options_num, self.n_agents), dtype=tf.int32) def choose_action(self, s, visual_s, evaluation=False): if not hasattr(self, 'options'): self.options = self._generate_random_options() self.last_options = self.options a, self.options, self.cell_state = self._get_action( s, visual_s, self.cell_state, self.options) if self.use_eps_greedy: if np.random.uniform() < self.expl_expt_mng.get_esp( self.episode, evaluation=evaluation): # epsilon greedy self.options = self._generate_random_options() a = a.numpy() return a @tf.function def _get_action(self, s, visual_s, cell_state, options): with tf.device(self.device): feat, cell_state = self.get_feature(s, visual_s, cell_state=cell_state, record_cs=True, train=False) q = self.q_net(feat) # [B, P] pi = self.intra_option_net(feat) # [B, P, A] beta = self.termination_net(feat) # [B, P] options_onehot = tf.one_hot(options, self.options_num, dtype=tf.float32) # [B, P] options_onehot_expanded = tf.expand_dims(options_onehot, axis=-1) # [B, P, 1] pi = tf.reduce_sum(pi * options_onehot_expanded, axis=1) # [B, A] if self.is_continuous: log_std = tf.gather(self.log_std, options) mu = tf.math.tanh(pi) a, _ = gaussian_clip_rsample(mu, log_std) else: pi = pi / self.boltzmann_temperature dist = tfp.distributions.Categorical(logits=pi) # [B, ] a = dist.sample() max_options = tf.cast(tf.argmax(q, axis=-1), dtype=tf.int32) # [B, P] => [B, ] if self.use_eps_greedy: new_options = max_options else: beta_probs = tf.reduce_sum(beta * options_onehot, axis=1) # [B, P] => [B,] beta_dist = tfp.distributions.Bernoulli(probs=beta_probs) new_options = tf.where(beta_dist.sample() < 1, options, max_options) return a, new_options, cell_state def learn(self, **kwargs): self.episode = kwargs['episode'] def _update(): if self.global_step % self.assign_interval == 0: self.update_target_net_weights(self.q_target_net.weights, self.q_net.weights) for i in range(kwargs['step']): self._learn( function_dict={ 'train_function': self.train, 'update_function': _update, 'sample_data_list': [ 's', 'visual_s', 'a', 'r', 's_', 'visual_s_', 'done', 'last_options', 'options' ], 'train_data_list': [ 'ss', 'vvss', 'a', 'r', 'done', 'last_options', 'options' ], 'summary_dict': dict([['LEARNING_RATE/q_lr', self.q_lr(self.episode)], [ 'LEARNING_RATE/intra_option_lr', self.intra_option_lr(self.episode) ], [ 'LEARNING_RATE/termination_lr', self.termination_lr(self.episode) ], ['Statistics/option', self.options[0]]]) }) @tf.function(experimental_relax_shapes=True) def train(self, memories, isw, crsty_loss, cell_state): ss, vvss, a, r, done, last_options, options = memories last_options = tf.cast(last_options, tf.int32) options = tf.cast(options, tf.int32) with tf.device(self.device): with tf.GradientTape(persistent=True) as tape: feat, feat_ = self.get_feature(ss, vvss, cell_state=cell_state, s_and_s_=True) q = self.q_net(feat) # [B, P] pi = self.intra_option_net(feat) # [B, P, A] beta = self.termination_net(feat) # [B, P] q_next = self.q_target_net(feat_) # [B, P], [B, P, A], [B, P] 
beta_next = self.termination_net(feat_) # [B, P] options_onehot = tf.one_hot(options, self.options_num, dtype=tf.float32) # [B,] => [B, P] q_s = qu_eval = tf.reduce_sum(q * options_onehot, axis=-1, keepdims=True) # [B, 1] beta_s_ = tf.reduce_sum(beta_next * options_onehot, axis=-1, keepdims=True) # [B, 1] q_s_ = tf.reduce_sum(q_next * options_onehot, axis=-1, keepdims=True) # [B, 1] # https://github.com/jeanharb/option_critic/blob/5d6c81a650a8f452bc8ad3250f1f211d317fde8c/neural_net.py#L94 if self.double_q: q_ = self.q_net(feat) # [B, P], [B, P, A], [B, P] max_a_idx = tf.one_hot( tf.argmax(q_, axis=-1), self.options_num, dtype=tf.float32) # [B, P] => [B, ] => [B, P] q_s_max = tf.reduce_sum(q_next * max_a_idx, axis=-1, keepdims=True) # [B, 1] else: q_s_max = tf.reduce_max(q_next, axis=-1, keepdims=True) # [B, 1] u_target = (1 - beta_s_) * q_s_ + beta_s_ * q_s_max # [B, 1] qu_target = tf.stop_gradient(r + self.gamma * (1 - done) * u_target) td_error = qu_target - qu_eval # gradient : q q_loss = tf.reduce_mean( tf.square(td_error) * isw) + crsty_loss # [B, 1] => 1 # https://github.com/jeanharb/option_critic/blob/5d6c81a650a8f452bc8ad3250f1f211d317fde8c/neural_net.py#L130 if self.use_baseline: adv = tf.stop_gradient(qu_target - qu_eval) else: adv = tf.stop_gradient(qu_target) options_onehot_expanded = tf.expand_dims( options_onehot, axis=-1) # [B, P] => [B, P, 1] pi = tf.reduce_sum(pi * options_onehot_expanded, axis=1) # [B, P, A] => [B, A] if self.is_continuous: log_std = tf.gather(self.log_std, options) mu = tf.math.tanh(pi) log_p = gaussian_likelihood_sum(a, mu, log_std) entropy = gaussian_entropy(log_std) else: pi = pi / self.boltzmann_temperature log_pi = tf.nn.log_softmax(pi, axis=-1) # [B, A] entropy = -tf.reduce_sum(tf.exp(log_pi) * log_pi, axis=1, keepdims=True) # [B, 1] log_p = tf.reduce_sum(a * log_pi, axis=-1, keepdims=True) # [B, 1] pi_loss = tf.reduce_mean( -(log_p * adv + self.ent_coff * entropy) ) # [B, 1] * [B, 1] => [B, 1] => 1 last_options_onehot = tf.one_hot( last_options, self.options_num, dtype=tf.float32) # [B,] => [B, P] beta_s = tf.reduce_sum(beta * last_options_onehot, axis=-1, keepdims=True) # [B, 1] if self.use_eps_greedy: v_s = tf.reduce_max( q, axis=-1, keepdims=True) - self.termination_regularizer # [B, 1] else: v_s = (1 - beta_s) * q_s + beta_s * tf.reduce_max( q, axis=-1, keepdims=True) # [B, 1] # v_s = tf.reduce_mean(q, axis=-1, keepdims=True) # [B, 1] beta_loss = beta_s * tf.stop_gradient(q_s - v_s) # [B, 1] # https://github.com/lweitkamp/option-critic-pytorch/blob/0c57da7686f8903ed2d8dded3fae832ee9defd1a/option_critic.py#L238 if self.terminal_mask: beta_loss *= (1 - done) beta_loss = tf.reduce_mean(beta_loss) # [B, 1] => 1 q_grads = tape.gradient(q_loss, self.critic_tv) intra_option_grads = tape.gradient(pi_loss, self.actor_tv) termination_grads = tape.gradient( beta_loss, self.termination_net.trainable_variables) self.q_optimizer.apply_gradients(zip(q_grads, self.critic_tv)) self.intra_option_optimizer.apply_gradients( zip(intra_option_grads, self.actor_tv)) self.termination_optimizer.apply_gradients( zip(termination_grads, self.termination_net.trainable_variables)) self.global_step.assign_add(1) return td_error, dict( [['LOSS/q_loss', tf.reduce_mean(q_loss)], ['LOSS/pi_loss', tf.reduce_mean(pi_loss)], ['LOSS/beta_loss', tf.reduce_mean(beta_loss)], ['Statistics/q_option_max', tf.reduce_max(q_s)], ['Statistics/q_option_min', tf.reduce_min(q_s)], ['Statistics/q_option_mean', tf.reduce_mean(q_s)]]) def store_data(self, s, visual_s, a, r, s_, visual_s_, done): 
""" for off-policy training, use this function to store <s, a, r, s_, done> into ReplayBuffer. """ assert isinstance(a, np.ndarray), "store need action type is np.ndarray" assert isinstance(r, np.ndarray), "store need reward type is np.ndarray" assert isinstance(done, np.ndarray), "store need done type is np.ndarray" self.data.add( s, visual_s, a, r[:, np.newaxis], # 升维 s_, visual_s_, done[:, np.newaxis], # 升维 self.last_options, self.options) def no_op_store(self, s, visual_s, a, r, s_, visual_s_, done): pass
class MAXSQN(Off_Policy): ''' https://github.com/createamind/DRL/blob/master/spinup/algos/maxsqn/maxsqn.py ''' def __init__(self, s_dim, visual_sources, visual_resolution, a_dim_or_list, is_continuous, alpha=0.2, beta=0.1, ployak=0.995, eps_init=1, eps_mid=0.2, eps_final=0.01, init2mid_annealing_episode=100, use_epsilon=False, q_lr=5.0e-4, alpha_lr=5.0e-4, auto_adaption=True, hidden_units=[32, 32], **kwargs): assert not is_continuous, 'maxsqn only support discrete action space' super().__init__(s_dim=s_dim, visual_sources=visual_sources, visual_resolution=visual_resolution, a_dim_or_list=a_dim_or_list, is_continuous=is_continuous, **kwargs) self.expl_expt_mng = ExplorationExploitationClass( eps_init=eps_init, eps_mid=eps_mid, eps_final=eps_final, init2mid_annealing_episode=init2mid_annealing_episode, max_episode=self.max_episode) self.use_epsilon = use_epsilon self.ployak = ployak self.log_alpha = alpha if not auto_adaption else tf.Variable( initial_value=0.0, name='log_alpha', dtype=tf.float32, trainable=True) self.auto_adaption = auto_adaption self.target_alpha = beta * np.log(self.a_counts) _q_net = lambda: Nn.critic_q_all(self.rnn_net.hdim, self.a_counts, hidden_units) self.q1_net = _q_net() self.q1_target_net = _q_net() self.q2_net = _q_net() self.q2_target_net = _q_net() self.critic_tv = self.q1_net.trainable_variables + self.q2_net.trainable_variables + self.other_tv self.update_target_net_weights( self.q1_target_net.weights + self.q2_target_net.weights, self.q1_net.weights + self.q2_net.weights) self.q_lr, self.alpha_lr = map(self.init_lr, [q_lr, alpha_lr]) self.optimizer_critic, self.optimizer_alpha = map( self.init_optimizer, [self.q_lr, self.alpha_lr]) self.model_recorder( dict(q1_net=self.q1_net, q2_net=self.q2_net, optimizer_critic=self.optimizer_critic, optimizer_alpha=self.optimizer_alpha)) def show_logo(self): self.recorder.logger.info(''' xx xx xxxxxx xxxxxx xxxx xx xxx xxx xxx xxx xxxx xxx xxxx xx xxx xxx xxxxx x xx xx xx xx xxxxx xx xxxx xxx xxxxxx xx xxx xxxxxx xx xxx xx xxx xx xxxx xx x x xxx xxxxx xxxxxx xx xx xx xxxxx xxxx xx x xxxxxx xxx xxx xxx x xxx xx xxxx xx xxx x xxx xx xxx xx xx xx xxxxx xx xxxx xx xxx x xx xxx xxxxx xxxxxxxxx xxx xxxx xx xxx xx xxx x xxxxxxxx xxx xxx xxxxxxx xxxxxxx xx xx xxxxxxx ''') def choose_action(self, s, visual_s, evaluation=False): if self.use_epsilon and np.random.uniform( ) < self.expl_expt_mng.get_esp(self.episode, evaluation=evaluation): a = np.random.randint(0, self.a_counts, self.n_agents) else: mu, pi, self.cell_state = self._get_action(s, visual_s, self.cell_state) a = pi.numpy() return a @tf.function def _get_action(self, s, visual_s, cell_state): with tf.device(self.device): feat, cell_state = self.get_feature(s, visual_s, cell_state=cell_state, record_cs=True, train=False) q = self.q1_net(feat) cate_dist = tfp.distributions.Categorical(logits=q / tf.exp(self.log_alpha)) pi = cate_dist.sample() return tf.argmax(q, axis=1), pi, cell_state def learn(self, **kwargs): self.episode = kwargs['episode'] for i in range(kwargs['step']): self._learn( function_dict={ 'train_function': self.train, 'update_function': lambda: self.update_target_net_weights( self.q1_target_net.weights + self.q2_target_net. 
                    weights, self.q1_net.weights + self.q2_net.weights, self.ployak),
                    'summary_dict':
                    dict([['LEARNING_RATE/q_lr', self.q_lr(self.episode)],
                          ['LEARNING_RATE/alpha_lr', self.alpha_lr(self.episode)]])
                })

    @tf.function(experimental_relax_shapes=True)
    def train(self, memories, isw, crsty_loss, cell_state):
        ss, vvss, a, r, done = memories
        with tf.device(self.device):
            with tf.GradientTape() as tape:
                feat, feat_ = self.get_feature(ss, vvss, cell_state=cell_state, s_and_s_=True)
                q1 = self.q1_net(feat)
                q1_eval = tf.reduce_sum(tf.multiply(q1, a), axis=1, keepdims=True)
                q2 = self.q2_net(feat)
                q2_eval = tf.reduce_sum(tf.multiply(q2, a), axis=1, keepdims=True)
                q1_target = self.q1_target_net(feat_)
                q1_target_max = tf.reduce_max(q1_target, axis=1, keepdims=True)
                q1_target_log_probs = tf.nn.log_softmax(q1_target / tf.exp(self.log_alpha), axis=1) + 1e-8
                q1_target_log_max = tf.reduce_max(q1_target_log_probs, axis=1, keepdims=True)
                q1_target_entropy = -tf.reduce_mean(
                    tf.reduce_sum(tf.exp(q1_target_log_probs) * q1_target_log_probs, axis=1, keepdims=True))
                q2_target = self.q2_target_net(feat_)
                q2_target_max = tf.reduce_max(q2_target, axis=1, keepdims=True)
                # q2_target_log_probs = tf.nn.log_softmax(q2_target, axis=1)
                # q2_target_log_max = tf.reduce_max(q2_target_log_probs, axis=1, keepdims=True)
                q_target = tf.minimum(q1_target_max, q2_target_max) + tf.exp(self.log_alpha) * q1_target_entropy
                dc_r = tf.stop_gradient(r + self.gamma * q_target * (1 - done))
                td_error1 = q1_eval - dc_r
                td_error2 = q2_eval - dc_r
                q1_loss = tf.reduce_mean(tf.square(td_error1) * isw)
                q2_loss = tf.reduce_mean(tf.square(td_error2) * isw)
                loss = 0.5 * (q1_loss + q2_loss) + crsty_loss
            loss_grads = tape.gradient(loss, self.critic_tv)
            self.optimizer_critic.apply_gradients(zip(loss_grads, self.critic_tv))
            if self.auto_adaption:
                with tf.GradientTape() as tape:
                    q1 = self.q1_net(feat)
                    # entropy of the current Boltzmann policy softmax(q1 / alpha)
                    q1_log_probs = tf.nn.log_softmax(q1 / tf.exp(self.log_alpha), axis=1) + 1e-8
                    q1_log_max = tf.reduce_max(q1_log_probs, axis=1, keepdims=True)
                    q1_entropy = -tf.reduce_mean(
                        tf.reduce_sum(tf.exp(q1_log_probs) * q1_log_probs, axis=1, keepdims=True))
                    alpha_loss = -tf.reduce_mean(
                        self.log_alpha * tf.stop_gradient(self.target_alpha - q1_entropy))
                alpha_grad = tape.gradient(alpha_loss, self.log_alpha)
                self.optimizer_alpha.apply_gradients([(alpha_grad, self.log_alpha)])
            self.global_step.assign_add(1)
            summaries = dict(
                [['LOSS/loss', loss],
                 ['Statistics/log_alpha', self.log_alpha],
                 ['Statistics/alpha', tf.exp(self.log_alpha)],
                 ['Statistics/q1_entropy', q1_entropy],
                 ['Statistics/q_min', tf.reduce_mean(tf.minimum(q1, q2))],
                 ['Statistics/q_mean', tf.reduce_mean(q1)],
                 ['Statistics/q_max', tf.reduce_mean(tf.maximum(q1, q2))]])
            if self.auto_adaption:
                summaries.update({'LOSS/alpha_loss': alpha_loss})
            return (td_error1 + td_error2) / 2, summaries
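# Hedged sketch: maxsqn soft-Q target. MAXSQN treats softmax(Q / alpha) as the policy
# and adds its entropy, scaled by alpha, to a clipped double-Q target, as in train()
# above. A NumPy restatement with illustrative names (the class averages the entropy
# over the batch; this version keeps it per-sample):
import numpy as np

def maxsqn_target(q1_next, q2_next, r, done, gamma, log_alpha):
    """q1_next, q2_next: [B, A]; r, done: [B, 1]."""
    alpha = np.exp(log_alpha)
    logits = q1_next / alpha
    logits = logits - logits.max(-1, keepdims=True)                   # numerical stability
    log_pi = logits - np.log(np.exp(logits).sum(-1, keepdims=True))   # log softmax
    entropy = -(np.exp(log_pi) * log_pi).sum(-1, keepdims=True)       # [B, 1]
    q_next = np.minimum(q1_next.max(-1, keepdims=True),
                        q2_next.max(-1, keepdims=True))               # clipped double-Q
    return r + gamma * (1 - done) * (q_next + alpha * entropy)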
class QRDQN(make_off_policy_class(mode='share')): ''' Quantile Regression DQN Distributional Reinforcement Learning with Quantile Regression, https://arxiv.org/abs/1710.10044 No double, no dueling, no noisy net. ''' def __init__(self, s_dim, visual_sources, visual_resolution, a_dim, is_continuous, nums=20, huber_delta=1., lr=5.0e-4, eps_init=1, eps_mid=0.2, eps_final=0.01, init2mid_annealing_step=1000, assign_interval=1000, hidden_units=[128, 128], **kwargs): assert not is_continuous, 'qrdqn only support discrete action space' assert nums > 0 super().__init__(s_dim=s_dim, visual_sources=visual_sources, visual_resolution=visual_resolution, a_dim=a_dim, is_continuous=is_continuous, **kwargs) self.nums = nums self.huber_delta = huber_delta self.quantiles = tf.reshape( tf.constant((2 * np.arange(self.nums) + 1) / (2.0 * self.nums), dtype=tf.float32), [-1, self.nums]) # [1, N] self.batch_quantiles = tf.tile(self.quantiles, [self.a_dim, 1]) # [1, N] => [A, N] self.expl_expt_mng = ExplorationExploitationClass( eps_init=eps_init, eps_mid=eps_mid, eps_final=eps_final, init2mid_annealing_step=init2mid_annealing_step, max_step=self.max_train_step) self.assign_interval = assign_interval def _net(): return rls.qrdqn_distributional(self.feat_dim, self.a_dim, self.nums, hidden_units) self.q_dist_net = _net() self.q_target_dist_net = _net() self.critic_tv = self.q_dist_net.trainable_variables + self.other_tv self.update_target_net_weights(self.q_target_dist_net.weights, self.q_dist_net.weights) self.lr = self.init_lr(lr) self.optimizer = self.init_optimizer(self.lr) self.model_recorder( dict(model=self.q_dist_net, optimizer=self.optimizer)) def show_logo(self): self.recorder.logger.info(''' xxxxxx xxxxxxx xxxxxxxx xxxxxx xxxx xxxx xxx xxxx xxxxxxx xxxxxxxx xxx xxxx xxx x xxx xxxx xx xxx xx xxx xxx xxxx xxxx x xxx xxx xx xxx xx xxx xxx xxx xxxxx x xx xxx xxxxxx xx xx xx xxx x xxxx x xxx xxx xxxxxx xx xx xxx xxx x xxxxx xxx xxx xx xxxx xx xxx xxx xxx x xxxx xxx xxx xx xxx xx xxxx xxx xxx x xxx xxxxxxxx xxxxx xxxx xxxxxxxx xxxxxxxx xxx xx xxxxx xxxxx xxxx xxxxxxx xxxxx xxxx xxxx xxx xxx ''') def choose_action(self, s, visual_s, evaluation=False): if np.random.uniform() < self.expl_expt_mng.get_esp( self.train_step, evaluation=evaluation): a = np.random.randint(0, self.a_dim, self.n_agents) else: a, self.cell_state = self._get_action(s, visual_s, self.cell_state) a = a.numpy() return a @tf.function def _get_action(self, s, visual_s, cell_state): with tf.device(self.device): feat, cell_state = self.get_feature(s, visual_s, cell_state=cell_state, record_cs=True) q = self.get_q(feat) # [B, A] return tf.argmax(q, axis=-1), cell_state # [B, 1] def learn(self, **kwargs): self.train_step = kwargs.get('train_step') def _update(): if self.global_step % self.assign_interval == 0: self.update_target_net_weights(self.q_target_dist_net.weights, self.q_dist_net.weights) for i in range(kwargs['step']): self._learn( function_dict={ 'train_function': self.train, 'update_function': _update, 'summary_dict': dict([['LEARNING_RATE/lr', self.lr(self.train_step)]]) }) @tf.function(experimental_relax_shapes=True) def train(self, memories, isw, crsty_loss, cell_state): ss, vvss, a, r, done = memories batch_size = tf.shape(a)[0] with tf.device(self.device): with tf.GradientTape() as tape: feat, feat_ = self.get_feature(ss, vvss, cell_state=cell_state, s_and_s_=True) indexs = tf.reshape(tf.range(batch_size), [-1, 1]) # [B, 1] q_dist = self.q_dist_net(feat) # [B, A, N] q_dist = tf.transpose( tf.reduce_sum(tf.transpose(q_dist, [2, 0, 1]) 
* a, axis=-1), [1, 0]) # [B, N] target_q_dist = self.q_target_dist_net(feat_) # [B, A, N] target_q = tf.reduce_sum(self.batch_quantiles * target_q_dist, axis=-1) # [B, A, N] => [B, A] a_ = tf.reshape( tf.cast(tf.argmax(target_q, axis=-1), dtype=tf.int32), [-1, 1]) # [B, 1] target_q_dist = tf.gather_nd(target_q_dist, tf.concat([indexs, a_], axis=-1)) # [B, N] target = tf.tile(r, tf.constant([1, self.nums])) \ + self.gamma * tf.multiply(self.quantiles, # [1, N] (1.0 - tf.tile(done, tf.constant([1, self.nums])))) # [B, N], [1, N]* [B, N] = [B, N] q_eval = tf.reduce_sum(q_dist * self.quantiles, axis=-1) # [B, 1] q_target = tf.reduce_sum(target * self.quantiles, axis=-1) # [B, 1] td_error = q_eval - q_target # [B, 1] quantile_error = tf.expand_dims( q_dist, axis=-1) - tf.expand_dims( target, axis=1) # [B, N, 1] - [B, 1, N] => [B, N, N] huber = huber_loss(quantile_error, delta=self.huber_delta) # [B, N, N] huber_abs = tf.abs( self.quantiles - tf.where(quantile_error < 0, tf.ones_like(quantile_error), tf.zeros_like(quantile_error)) ) # [1, N] - [B, N, N] => [B, N, N] loss = tf.reduce_mean(huber_abs * huber, axis=-1) # [B, N, N] => [B, N] loss = tf.reduce_sum(loss, axis=-1) # [B, N] => [B, ] loss = tf.reduce_mean(loss * isw) + crsty_loss # [B, ] => 1 grads = tape.gradient(loss, self.critic_tv) self.optimizer.apply_gradients(zip(grads, self.critic_tv)) self.global_step.assign_add(1) return td_error, dict( [['LOSS/loss', loss], ['Statistics/q_max', tf.reduce_max(q_eval)], ['Statistics/q_min', tf.reduce_min(q_eval)], ['Statistics/q_mean', tf.reduce_mean(q_eval)]]) @tf.function(experimental_relax_shapes=True) def get_q(self, feat): with tf.device(self.device): return tf.reduce_sum(self.batch_quantiles * self.q_dist_net(feat), axis=-1) # [B, A, N] => [B, A]
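# Hedged sketch: quantile regression Huber loss. QRDQN (and IQN above) minimise an
# asymmetric quantile Huber loss over all pairs of predicted and target quantiles. The
# NumPy sketch below uses the standard QR-DQN convention u = target - prediction with
# weight |tau - 1{u < 0}|; names are illustrative and shapes follow the comments in train().
import numpy as np

def quantile_huber_loss(pred, target, taus, delta=1.0):
    """pred: [B, N] predicted quantiles, target: [B, N'] target quantiles, taus: [N]."""
    u = target[:, None, :] - pred[:, :, None]                  # pairwise errors [B, N, N']
    abs_u = np.abs(u)
    huber = np.where(abs_u <= delta, 0.5 * u ** 2, delta * (abs_u - 0.5 * delta))
    weight = np.abs(taus[None, :, None] - (u < 0).astype(np.float64))
    return (weight * huber).mean(-1).sum(-1).mean()            # mean over N', sum over N, mean over B

# usage (toy): taus are the fixed midpoints (2i + 1) / (2N) used by this class
N = 20
taus = (2 * np.arange(N) + 1) / (2.0 * N)
loss = quantile_huber_loss(np.random.randn(4, N), np.random.randn(4, N), taus)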
class DDDQN(Off_Policy): ''' Dueling Double DQN, https://arxiv.org/abs/1511.06581 ''' def __init__(self, s_dim, visual_sources, visual_resolution, a_dim_or_list, is_continuous, lr=5.0e-4, eps_init=1, eps_mid=0.2, eps_final=0.01, init2mid_annealing_episode=100, assign_interval=2, hidden_units={ 'share': [128], 'v': [128], 'adv': [128] }, **kwargs): assert not is_continuous, 'dueling double dqn only support discrete action space' super().__init__(s_dim=s_dim, visual_sources=visual_sources, visual_resolution=visual_resolution, a_dim_or_list=a_dim_or_list, is_continuous=is_continuous, **kwargs) self.expl_expt_mng = ExplorationExploitationClass( eps_init=eps_init, eps_mid=eps_mid, eps_final=eps_final, init2mid_annealing_episode=init2mid_annealing_episode, max_episode=self.max_episode) self.assign_interval = assign_interval self.visual_net = Nn.VisualNet('visual_net', self.visual_dim) self.dueling_net = Nn.critic_dueling(self.s_dim, self.a_counts, 'dueling_net', hidden_units, visual_net=self.visual_net) self.dueling_target_net = Nn.critic_dueling(self.s_dim, self.a_counts, 'dueling_target_net', hidden_units, visual_net=self.visual_net) self.update_target_net_weights(self.dueling_target_net.weights, self.dueling_net.weights) self.lr = tf.keras.optimizers.schedules.PolynomialDecay( lr, self.max_episode, 1e-10, power=1.0) self.optimizer = tf.keras.optimizers.Adam( learning_rate=self.lr(self.episode)) def show_logo(self): self.recorder.logger.info(''' xxxxxxxx xxxxxxxx xxxxxxxx xxxxxx xxxx xxxx xxxxxxxx xxxxxxxx xxxxxxxx xxx xxxx xxx x xx xxx xx xxx xx xxx xxx xxxx xxxx x xx xxx xx xxx xx xxx xxx xxx xxxxx x xx xx xx xx xx xx xx xxx x xxxx x xx xx xx xx xx xx xxx xxx x xxxxx xx xxx xx xxx xx xxx xxx xxx x xxxx xx xxxx xx xxxx xx xxxx xxx xxx x xxx xxxxxxxx xxxxxxxx xxxxxxxx xxxxxxxx xxx xx xxxxxxx xxxxxxx xxxxxxx xxxxx xxxx xxx ''') def choose_action(self, s, visual_s, evaluation=False): if np.random.uniform() < self.expl_expt_mng.get_esp( self.episode, evaluation=evaluation): a = np.random.randint(0, self.a_counts, len(s)) else: a = self._get_action(s, visual_s).numpy() return sth.int2action_index(a, self.a_dim_or_list) @tf.function def _get_action(self, s, visual_s): s, visual_s = self.cast(s, visual_s) with tf.device(self.device): q = self.dueling_net(s, visual_s) return tf.argmax(q, axis=-1) def learn(self, **kwargs): self.episode = kwargs['episode'] for i in range(kwargs['step']): if self.data.is_lg_batch_size: s, visual_s, a, r, s_, visual_s_, done = self.data.sample() if self.use_priority: self.IS_w = self.data.get_IS_w() td_error, summaries = self.train(s, visual_s, a, r, s_, visual_s_, done) if self.use_priority: td_error = np.squeeze(td_error.numpy()) self.data.update(td_error, self.episode) if self.global_step % self.assign_interval == 0: self.update_target_net_weights( self.dueling_target_net.weights, self.dueling_net.weights) summaries.update( dict([['LEARNING_RATE/lr', self.lr(self.episode)]])) self.write_training_summaries(self.global_step, summaries) @tf.function(experimental_relax_shapes=True) def train(self, s, visual_s, a, r, s_, visual_s_, done): s, visual_s, a, r, s_, visual_s_, done = self.cast( s, visual_s, a, r, s_, visual_s_, done) with tf.device(self.device): with tf.GradientTape() as tape: q = self.dueling_net(s, visual_s) q_eval = tf.reduce_sum(tf.multiply(q, a), axis=1, keepdims=True) next_q = self.dueling_net(s_, visual_s_) next_max_action = tf.argmax(next_q, axis=1, name='next_action_int') next_max_action_one_hot = tf.one_hot( tf.squeeze(next_max_action), 
self.a_counts, 1., 0., dtype=tf.float32) next_max_action_one_hot = tf.cast(next_max_action_one_hot, tf.float32) q_target = self.dueling_target_net(s_, visual_s_) q_target_next_max = tf.reduce_sum(tf.multiply( q_target, next_max_action_one_hot), axis=1, keepdims=True) q_target = tf.stop_gradient(r + self.gamma * (1 - done) * q_target_next_max) td_error = q_eval - q_target q_loss = tf.reduce_mean(tf.square(td_error) * self.IS_w) grads = tape.gradient(q_loss, self.dueling_net.tv) self.optimizer.apply_gradients(zip(grads, self.dueling_net.tv)) self.global_step.assign_add(1) return td_error, dict( [['LOSS/loss', q_loss], ['Statistics/q_max', tf.reduce_max(q_eval)], ['Statistics/q_min', tf.reduce_min(q_eval)], ['Statistics/q_mean', tf.reduce_mean(q_eval)]])
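# Hedged sketch: dueling aggregation and double-DQN target. DDDQN combines two ideas:
# a dueling head Q(s, a) = V(s) + A(s, a) - mean_a A(s, a), and the double-DQN target in
# which the online net picks the next action and the target net evaluates it (as in
# train() above). A minimal NumPy restatement with illustrative names:
import numpy as np

def dueling_q(v, adv):
    """v: [B, 1] state value, adv: [B, A] advantages -> Q: [B, A]."""
    return v + adv - adv.mean(-1, keepdims=True)

def double_dqn_target(q_online_next, q_target_next, r, done, gamma):
    """Online net selects argmax_a Q(s', a); the target net provides its value."""
    a_star = q_online_next.argmax(-1)                                         # [B]
    q_next = np.take_along_axis(q_target_next, a_star[:, None], axis=-1)      # [B, 1]
    return r + gamma * (1 - done) * q_next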
class DQN_GCN(Off_Policy): ''' Deep Q-learning Network, DQN, [2013](https://arxiv.org/pdf/1312.5602.pdf), [2015](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf) ''' def __init__(self, s_dim, visual_sources, visual_resolution, a_dim_or_list, is_continuous, lr=5.0e-4, eps_init=1, eps_mid=0.2, eps_final=0.01, init2mid_annealing_episode=100, assign_interval=1000, hidden_units=[32, 32], **kwargs): assert not is_continuous, 'dqn only support discrete action space' super().__init__(s_dim=s_dim, visual_sources=visual_sources, visual_resolution=visual_resolution, a_dim_or_list=a_dim_or_list, is_continuous=is_continuous, **kwargs) self.expl_expt_mng = ExplorationExploitationClass( eps_init=eps_init, eps_mid=eps_mid, eps_final=eps_final, init2mid_annealing_episode=init2mid_annealing_episode, max_episode=self.max_episode) self.assign_interval = assign_interval self.visual_net = Nn.VisualNet('visual_net', self.visual_dim) self.q_net = Nn.critic_q_all_gcn(self.s_dim, self.a_counts, 'q_net', hidden_units, visual_net=self.visual_net) self.q_target_net = Nn.critic_q_all_gcn(self.s_dim, self.a_counts, 'q_target_net', hidden_units, visual_net=self.visual_net) self.update_target_net_weights(self.q_target_net.weights, self.q_net.weights) self.lr = tf.keras.optimizers.schedules.PolynomialDecay( lr, self.max_episode, 1e-10, power=1.0) self.optimizer = tf.keras.optimizers.Adam( learning_rate=self.lr(self.episode)) def show_logo(self): self.recorder.logger.info(''' xxxxxxxx xxxxxx xxxx xxxx xxxxxxxx xxx xxxx xxx x xx xxx xxx xxxx xxxx x xx xxx xxx xxx xxxxx x xx xx xx xxx x xxxx x xx xx xxx xxx x xxxxx xx xxx xxx xxx x xxxx xx xxxx xxx xxx x xxx xxxxxxxx xxxxxxxx xxx xx xxxxxxx xxxxx xxxx xxxxxxxxxxxxxx ''') def choose_action(self, adj, x, visual_s, evaluation=False): if np.random.uniform() < self.expl_expt_mng.get_esp( self.episode, evaluation=evaluation): a = np.random.randint(0, self.a_counts, len(adj)) else: a = self._get_action(adj, x, visual_s).numpy() return sth.int2action_index(a, self.a_dim_or_list) @tf.function def _get_action(self, adj, x, visual_s): adj, x, visual_s = self.cast(adj, x, visual_s) with tf.device(self.device): q_values = self.q_net(adj, x, visual_s) return tf.argmax(q_values, axis=1) def learn(self, **kwargs): self.episode = kwargs['episode'] for i in range(kwargs['step']): if self.data.is_lg_batch_size: adj, x, visual_s, a, r, adj_, x_, visual_s_, done = self.data.sample( ) if self.use_priority: self.IS_w = self.data.get_IS_w() td_error, summaries = self.train(adj, x, visual_s, a, r, adj_, x_, visual_s_, done) if self.use_priority: td_error = np.squeeze(td_error.numpy()) self.data.update(td_error, self.episode) if self.global_step % self.assign_interval == 0: self.update_target_net_weights(self.q_target_net.weights, self.q_net.weights) summaries.update( dict([['LEARNING_RATE/lr', self.lr(self.episode)]])) self.write_training_summaries(self.global_step, summaries) @tf.function(experimental_relax_shapes=True) def train(self, adj, x, visual_s, a, r, adj_, x_, visual_s_, done): adj, x, visual_s, a, r, adj_, x_, visual_s_, done = self.cast( adj, x, visual_s, a, r, adj_, x_, visual_s_, done) with tf.device(self.device): with tf.GradientTape() as tape: q = self.q_net(adj, x, visual_s) q_next = self.q_target_net(adj_, x_, visual_s_) q_eval = tf.reduce_sum(tf.multiply(q, a), axis=1, keepdims=True) q_target = tf.stop_gradient( r + self.gamma * (1 - done) * tf.reduce_max(q_next, axis=1, keepdims=True)) td_error = q_eval - q_target q_loss = tf.reduce_mean(tf.square(td_error) * 
self.IS_w) grads = tape.gradient(q_loss, self.q_net.tv) self.optimizer.apply_gradients(zip(grads, self.q_net.tv)) self.global_step.assign_add(1) return td_error, dict( [['LOSS/loss', q_loss], ['Statistics/q_max', tf.reduce_max(q_eval)], ['Statistics/q_min', tf.reduce_min(q_eval)], ['Statistics/q_mean', tf.reduce_mean(q_eval)]])
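# Hedged sketch: a graph-convolution step as used by a GCN-based Q-network. DQN_GCN
# feeds an adjacency matrix and node features into Nn.critic_q_all_gcn, whose internals
# are not shown in this section. The sketch below is the standard Kipf & Welling
# propagation rule H' = relu(D^{-1/2}(A + I)D^{-1/2} H W), included only to illustrate
# what (adj, x) are used for; all names are illustrative.
import numpy as np

def gcn_layer(adj, x, w):
    """adj: [n, n] adjacency, x: [n, f_in] node features, w: [f_in, f_out]."""
    a_hat = adj + np.eye(adj.shape[-1])                              # add self-loops
    d_inv_sqrt = 1.0 / np.sqrt(a_hat.sum(-1))
    a_norm = a_hat * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]       # symmetric normalisation
    return np.maximum(0.0, a_norm @ x @ w)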
def test_exploration_exploitation_class():
    my_expl = ExplorationExploitationClass(eps_init=1,
                                           eps_mid=0.2,
                                           eps_final=0.01,
                                           eps_eval=0,
                                           init2mid_annealing_step=50,
                                           start_step=0,
                                           max_step=100)
    assert my_expl.get_esp(0) == 1
    assert my_expl.get_esp(0, evaluation=True) == 0
    assert my_expl.get_esp(80, evaluation=True) == 0
    assert my_expl.get_esp(2) < 1
    assert my_expl.get_esp(50) == 0.2
    assert my_expl.get_esp(51) < 0.2
    assert my_expl.get_esp(100) >= 0.01

    my_expl = ExplorationExploitationClass(eps_init=0.2,
                                           eps_mid=0.1,
                                           eps_final=0,
                                           eps_eval=0,
                                           init2mid_annealing_step=1000,
                                           start_step=0,
                                           max_step=10000)
    assert my_expl.get_esp(0) == 0.2
    assert my_expl.get_esp(0, evaluation=True) == 0
    assert my_expl.get_esp(500, evaluation=True) == 0
    assert my_expl.get_esp(500) < 0.2
    assert my_expl.get_esp(1000) == 0.1
    assert my_expl.get_esp(2000) < 0.1
    assert my_expl.get_esp(9000) > 0
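# Hedged sketch: an epsilon schedule consistent with the assertions above.
# ExplorationExploitationClass itself is not shown in this section; a piecewise-linear,
# two-phase schedule satisfies every assertion in the test (the real class may use a
# different decay shape between eps_mid and eps_final). Names mirror the constructor
# arguments but the class below is illustrative only.
class LinearTwoPhaseEps:
    def __init__(self, eps_init, eps_mid, eps_final, eps_eval=0,
                 init2mid_annealing_step=1000, start_step=0, max_step=10000):
        self.eps_init, self.eps_mid, self.eps_final = eps_init, eps_mid, eps_final
        self.eps_eval = eps_eval
        self.start = start_step
        self.mid_step = start_step + init2mid_annealing_step
        self.max_step = max_step

    def get_esp(self, step, evaluation=False):
        if evaluation:
            return self.eps_eval
        if step <= self.mid_step:
            # phase 1: anneal eps_init -> eps_mid
            frac = (step - self.start) / (self.mid_step - self.start)
            return self.eps_init + frac * (self.eps_mid - self.eps_init)
        # phase 2: anneal eps_mid -> eps_final
        frac = min(1.0, (step - self.mid_step) / (self.max_step - self.mid_step))
        return self.eps_mid + frac * (self.eps_final - self.eps_mid)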
class QRDQN(Off_Policy): ''' Quantile Regression DQN Distributional Reinforcement Learning with Quantile Regression, https://arxiv.org/abs/1710.10044 No double, no dueling, no noisy net. ''' def __init__(self, s_dim, visual_sources, visual_resolution, a_dim_or_list, is_continuous, nums=20, huber_delta=1., lr=5.0e-4, eps_init=1, eps_mid=0.2, eps_final=0.01, init2mid_annealing_episode=100, assign_interval=1000, hidden_units=[128, 128], **kwargs): assert not is_continuous, 'qrdqn only support discrete action space' assert nums > 0 super().__init__( s_dim=s_dim, visual_sources=visual_sources, visual_resolution=visual_resolution, a_dim_or_list=a_dim_or_list, is_continuous=is_continuous, **kwargs) self.nums = nums self.huber_delta = huber_delta self.quantiles = tf.reshape(tf.constant((2 * np.arange(self.nums) + 1) / (2.0 * self.nums), dtype=tf.float32), [-1, self.nums]) # [1, N] self.batch_quantiles = tf.tile(self.quantiles, [self.a_counts, 1]) # [1, N] => [A, N] self.expl_expt_mng = ExplorationExploitationClass(eps_init=eps_init, eps_mid=eps_mid, eps_final=eps_final, init2mid_annealing_episode=init2mid_annealing_episode, max_episode=self.max_episode) self.assign_interval = assign_interval self.visual_net = Nn.VisualNet('visual_net', self.visual_dim) self.q_dist_net = Nn.qrdqn_distributional(self.s_dim, self.a_counts, self.nums, 'q_dist_net', hidden_units, visual_net=self.visual_net) self.q_target_dist_net = Nn.qrdqn_distributional(self.s_dim, self.a_counts, self.nums, 'q_target_dist_net', hidden_units, visual_net=self.visual_net) self.update_target_net_weights(self.q_target_dist_net.weights, self.q_dist_net.weights) self.lr = tf.keras.optimizers.schedules.PolynomialDecay(lr, self.max_episode, 1e-10, power=1.0) self.optimizer = tf.keras.optimizers.Adam(learning_rate=self.lr(self.episode)) def show_logo(self): self.recorder.logger.info(''' xxxxxx xxxxxxx xxxxxxxx xxxxxx xxxx xxxx xxx xxxx xxxxxxx xxxxxxxx xxx xxxx xxx x xxx xxxx xx xxx xx xxx xxx xxxx xxxx x xxx xxx xx xxx xx xxx xxx xxx xxxxx x xx xxx xxxxxx xx xx xx xxx x xxxx x xxx xxx xxxxxx xx xx xxx xxx x xxxxx xxx xxx xx xxxx xx xxx xxx xxx x xxxx xxx xxx xx xxx xx xxxx xxx xxx x xxx xxxxxxxx xxxxx xxxx xxxxxxxx xxxxxxxx xxx xx xxxxx xxxxx xxxx xxxxxxx xxxxx xxxx xxxx xxx xxx ''') def choose_action(self, s, visual_s, evaluation=False): if np.random.uniform() < self.expl_expt_mng.get_esp(self.episode, evaluation=evaluation): a = np.random.randint(0, self.a_counts, len(s)) else: a = self._get_action(s, visual_s).numpy() return sth.int2action_index(a, self.a_dim_or_list) @tf.function def _get_action(self, s, visual_s): s, visual_s = self.cast(s, visual_s) with tf.device(self.device): q = self.get_q(s, visual_s) # [B, A] return tf.argmax(q, axis=-1) # [B, 1] def learn(self, **kwargs): self.episode = kwargs['episode'] for i in range(kwargs['step']): if self.data.is_lg_batch_size: s, visual_s, a, r, s_, visual_s_, done = self.data.sample() if self.use_priority: self.IS_w = self.data.get_IS_w() td_error, summaries = self.train(s, visual_s, a, r, s_, visual_s_, done) if self.use_priority: td_error = np.squeeze(td_error.numpy()) self.data.update(td_error, self.episode) if self.global_step % self.assign_interval == 0: self.update_target_net_weights(self.q_target_dist_net.weights, self.q_dist_net.weights) summaries.update(dict([['LEARNING_RATE/lr', self.lr(self.episode)]])) self.write_training_summaries(self.global_step, summaries) @tf.function(experimental_relax_shapes=True) def train(self, s, visual_s, a, r, s_, visual_s_, done): s, visual_s, a, r, 
s_, visual_s_, done = self.cast(s, visual_s, a, r, s_, visual_s_, done) with tf.device(self.device): with tf.GradientTape() as tape: indexs = tf.reshape(tf.range(s.shape[0]), [-1, 1]) # [B, 1] q_dist = self.q_dist_net(s, visual_s) # [B, A, N] q_dist = tf.transpose(tf.reduce_sum(tf.transpose(q_dist, [2, 0, 1]) * a, axis=-1), [1, 0]) # [B, N] target_q_dist = self.q_target_dist_net(s_, visual_s_) # [B, A, N] target_q = tf.reduce_sum(self.batch_quantiles * target_q_dist, axis=-1) # [B, A, N] => [B, A] a_ = tf.reshape(tf.cast(tf.argmax(target_q, axis=-1), dtype=tf.int32), [-1, 1]) # [B, 1] target_q_dist = tf.gather_nd(target_q_dist, tf.concat([indexs, a_], axis=-1)) # [B, N] target = tf.tile(r, tf.constant([1, self.nums])) \ + self.gamma * tf.multiply(self.quantiles, # [1, N] (1.0 - tf.tile(done, tf.constant([1, self.nums])))) # [B, N], [1, N]* [B, N] = [B, N] q_eval = tf.reduce_sum(q_dist * self.quantiles, axis=-1) # [B, 1] q_target = tf.reduce_sum(target * self.quantiles, axis=-1) # [B, 1] td_error = q_eval - q_target # [B, 1] quantile_error = tf.expand_dims(q_dist, axis=-1) - tf.expand_dims(target, axis=1) # [B, N, 1] - [B, 1, N] => [B, N, N] huber = huber_loss(quantile_error, delta=self.huber_delta) # [B, N, N] huber_abs = tf.abs(self.quantiles - tf.where(quantile_error < 0, tf.ones_like(quantile_error), tf.zeros_like(quantile_error))) # [1, N] - [B, N, N] => [B, N, N] loss = tf.reduce_mean(huber_abs * huber, axis=-1) # [B, N, N] => [B, N] loss = tf.reduce_sum(loss, axis=-1) # [B, N] => [B, ] loss = tf.reduce_mean(loss * self.IS_w) # [B, ] => 1 grads = tape.gradient(loss, self.q_dist_net.tv) self.optimizer.apply_gradients( zip(grads, self.q_dist_net.tv) ) self.global_step.assign_add(1) return td_error, dict([ ['LOSS/loss', loss], ['Statistics/q_max', tf.reduce_max(q_eval)], ['Statistics/q_min', tf.reduce_min(q_eval)], ['Statistics/q_mean', tf.reduce_mean(q_eval)] ]) @tf.function(experimental_relax_shapes=True) def get_q(self, s, visual_s): with tf.device(self.device): return tf.reduce_sum(self.batch_quantiles * self.q_dist_net(s, visual_s), axis=-1) # [B, A, N] => [B, A]
class DDDQN(make_off_policy_class(mode='share')): ''' Dueling Double DQN, https://arxiv.org/abs/1511.06581 ''' def __init__(self, s_dim, visual_sources, visual_resolution, a_dim, is_continuous, lr=5.0e-4, eps_init=1, eps_mid=0.2, eps_final=0.01, init2mid_annealing_step=1000, assign_interval=2, hidden_units={ 'share': [128], 'v': [128], 'adv': [128] }, **kwargs): assert not is_continuous, 'dueling double dqn only support discrete action space' super().__init__(s_dim=s_dim, visual_sources=visual_sources, visual_resolution=visual_resolution, a_dim=a_dim, is_continuous=is_continuous, **kwargs) self.expl_expt_mng = ExplorationExploitationClass( eps_init=eps_init, eps_mid=eps_mid, eps_final=eps_final, init2mid_annealing_step=init2mid_annealing_step, max_step=self.max_train_step) self.assign_interval = assign_interval def _net(): return rls.critic_dueling(self.feat_dim, self.a_dim, hidden_units) self.dueling_net = _net() self.dueling_target_net = _net() self.critic_tv = self.dueling_net.trainable_variables + self.other_tv self.update_target_net_weights(self.dueling_target_net.weights, self.dueling_net.weights) self.lr = self.init_lr(lr) self.optimizer = self.init_optimizer(self.lr) self.model_recorder( dict(model=self.dueling_net, optimizer=self.optimizer)) def show_logo(self): self.recorder.logger.info(''' xxxxxxxx xxxxxxxx xxxxxxxx xxxxxx xxxx xxxx xxxxxxxx xxxxxxxx xxxxxxxx xxx xxxx xxx x xx xxx xx xxx xx xxx xxx xxxx xxxx x xx xxx xx xxx xx xxx xxx xxx xxxxx x xx xx xx xx xx xx xx xxx x xxxx x xx xx xx xx xx xx xxx xxx x xxxxx xx xxx xx xxx xx xxx xxx xxx x xxxx xx xxxx xx xxxx xx xxxx xxx xxx x xxx xxxxxxxx xxxxxxxx xxxxxxxx xxxxxxxx xxx xx xxxxxxx xxxxxxx xxxxxxx xxxxx xxxx xxx ''') def choose_action(self, s, visual_s, evaluation=False): if np.random.uniform() < self.expl_expt_mng.get_esp( self.train_step, evaluation=evaluation): a = np.random.randint(0, self.a_dim, self.n_agents) else: a, self.cell_state = self._get_action(s, visual_s, self.cell_state) a = a.numpy() return a @tf.function def _get_action(self, s, visual_s, cell_state): with tf.device(self.device): feat, cell_state = self.get_feature(s, visual_s, cell_state=cell_state, record_cs=True) q = self.dueling_net(feat) return tf.argmax(q, axis=-1), cell_state def learn(self, **kwargs): self.train_step = kwargs.get('train_step') def _update(): if self.global_step % self.assign_interval == 0: self.update_target_net_weights(self.dueling_target_net.weights, self.dueling_net.weights) for i in range(kwargs['step']): self._learn( function_dict={ 'train_function': self.train, 'update_function': _update, 'summary_dict': dict([['LEARNING_RATE/lr', self.lr(self.train_step)]]) }) @tf.function(experimental_relax_shapes=True) def train(self, memories, isw, crsty_loss, cell_state): ss, vvss, a, r, done = memories with tf.device(self.device): with tf.GradientTape() as tape: feat, feat_ = self.get_feature(ss, vvss, cell_state=cell_state, s_and_s_=True) q = self.dueling_net(feat) q_eval = tf.reduce_sum(tf.multiply(q, a), axis=1, keepdims=True) next_q = self.dueling_net(feat_) next_max_action = tf.argmax(next_q, axis=1, name='next_action_int') next_max_action_one_hot = tf.one_hot( tf.squeeze(next_max_action), self.a_dim, 1., 0., dtype=tf.float32) next_max_action_one_hot = tf.cast(next_max_action_one_hot, tf.float32) q_target = self.dueling_target_net(feat_) q_target_next_max = tf.reduce_sum(tf.multiply( q_target, next_max_action_one_hot), axis=1, keepdims=True) q_target = tf.stop_gradient(r + self.gamma * (1 - done) * q_target_next_max) td_error = 
q_eval - q_target q_loss = tf.reduce_mean(tf.square(td_error) * isw) + crsty_loss grads = tape.gradient(q_loss, self.critic_tv) self.optimizer.apply_gradients(zip(grads, self.critic_tv)) self.global_step.assign_add(1) return td_error, dict( [['LOSS/loss', q_loss], ['Statistics/q_max', tf.reduce_max(q_eval)], ['Statistics/q_min', tf.reduce_min(q_eval)], ['Statistics/q_mean', tf.reduce_mean(q_eval)]])
class BootstrappedDQN(make_off_policy_class(mode='share')): ''' Deep Exploration via Bootstrapped DQN, http://arxiv.org/abs/1602.04621 ''' def __init__(self, s_dim, visual_sources, visual_resolution, a_dim, is_continuous, lr=5.0e-4, eps_init=1, eps_mid=0.2, eps_final=0.01, init2mid_annealing_episode=100, assign_interval=1000, head_num=4, hidden_units=[32, 32], **kwargs): assert not is_continuous, 'Bootstrapped DQN only support discrete action space' super().__init__(s_dim=s_dim, visual_sources=visual_sources, visual_resolution=visual_resolution, a_dim=a_dim, is_continuous=is_continuous, **kwargs) self.expl_expt_mng = ExplorationExploitationClass( eps_init=eps_init, eps_mid=eps_mid, eps_final=eps_final, init2mid_annealing_episode=init2mid_annealing_episode, max_episode=self.max_episode) self.assign_interval = assign_interval self.head_num = head_num self._probs = [1. / head_num for _ in range(head_num)] self.now_head = 0 _q_net = lambda: rls.critic_q_bootstrap(self.feat_dim, self.a_dim, self .head_num, hidden_units) self.q_net = _q_net() self.q_target_net = _q_net() self.critic_tv = self.q_net.trainable_variables + self.other_tv self.update_target_net_weights(self.q_target_net.weights, self.q_net.weights) self.lr = self.init_lr(lr) self.optimizer = self.init_optimizer(self.lr) self.model_recorder(dict(model=self.q_net, optimizer=self.optimizer)) def show_logo(self): self.recorder.logger.info(''' xxxxxxx xxxxxxxx xxxxxx xxxx xxxx xx xxxx xxxxxxxx xxx xxxx xxx x xx xxx xx xxx xxx xxxx xxxx x xx xxx xx xxx xxx xxx xxxxx x xxxxxx xxx xxxx xxx xx xx xx xxx x xxxx x xx xxxx xxx xxxx xxx xx xx xxx xxx x xxxxx xx xxx xxx xx xxx xx xxx xxx xxx x xxxx xx xx xx xxxx xxx xxx x xxx xx xxxx xxxxxxxx xxxxxxxx xxx xx xxxxxxxx xxxxxxx xxxxx xxxx xxx ''') def reset(self): super().reset() self.now_head = np.random.randint(self.head_num) def choose_action(self, s, visual_s, evaluation=False): if np.random.uniform() < self.expl_expt_mng.get_esp( self.episode, evaluation=evaluation): a = np.random.randint(0, self.a_dim, self.n_agents) else: q, self.cell_state = self._get_action(s, visual_s, self.cell_state) q = q.numpy() a = np.argmax(q[self.now_head], axis=1) # [H, B, A] => [B, A] => [B, ] return a @tf.function def _get_action(self, s, visual_s, cell_state): with tf.device(self.device): feat, cell_state = self.get_feature(s, visual_s, cell_state=cell_state, record_cs=True) q_values = self.q_net(feat) # [H, B, A] return q_values, cell_state def learn(self, **kwargs): self.episode = kwargs['episode'] def _update(): if self.global_step % self.assign_interval == 0: self.update_target_net_weights(self.q_target_net.weights, self.q_net.weights) for i in range(kwargs['step']): self._learn( function_dict={ 'train_function': self.train, 'update_function': _update, 'summary_dict': dict([['LEARNING_RATE/lr', self.lr(self.episode)]]) }) @tf.function(experimental_relax_shapes=True) def train(self, memories, isw, crsty_loss, cell_state): ss, vvss, a, r, done = memories batch_size = tf.shape(a)[0] with tf.device(self.device): with tf.GradientTape() as tape: feat, feat_ = self.get_feature(ss, vvss, cell_state=cell_state, s_and_s_=True) q = self.q_net(feat) # [H, B, A] q_next = self.q_target_net(feat_) # [H, B, A] q_eval = tf.reduce_sum( tf.multiply(q, a), axis=-1, keepdims=True) # [H, B, A] * [B, A] => [H, B, 1] q_target = tf.stop_gradient( r + self.gamma * (1 - done) * tf.reduce_max(q_next, axis=-1, keepdims=True)) td_error = q_eval - q_target # [H, B, 1] td_error = tf.reduce_sum(td_error, axis=-1) # [H, B] mask_dist = 
tfp.distributions.Bernoulli(probs=self._probs)
                # [H, B] bootstrap mask: each head trains only on its own sampled subset of the batch
                mask = tf.cast(tf.transpose(mask_dist.sample(batch_size), [1, 0]), tf.float32)
                q_loss = tf.reduce_mean(tf.square(td_error) * mask * isw) + crsty_loss
            grads = tape.gradient(q_loss, self.critic_tv)
            self.optimizer.apply_gradients(zip(grads, self.critic_tv))
            self.global_step.assign_add(1)
            return tf.reduce_mean(td_error, axis=0), dict(  # [H, B] => [B, ]
                [['LOSS/loss', q_loss],
                 ['Statistics/q_max', tf.reduce_max(q_eval)],
                 ['Statistics/q_min', tf.reduce_min(q_eval)],
                 ['Statistics/q_mean', tf.reduce_mean(q_eval)]])
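# Hedged sketch: bootstrap head masking. Bootstrapped DQN trains H independent Q-heads,
# each on a random subset of the data, by sampling a Bernoulli mask per (head, transition)
# and weighting the per-head TD loss with it, mirroring the mask sampled in train() above.
# A NumPy restatement with illustrative names:
import numpy as np

def masked_bootstrap_loss(td_error, mask_probs, rng=np.random):
    """td_error: [H, B] per-head TD errors; mask_probs: [H] Bernoulli keep-probabilities."""
    h, b = td_error.shape
    mask = (rng.uniform(size=(h, b)) < np.asarray(mask_probs)[:, None]).astype(np.float64)
    # squared TD error averaged over the transitions each head keeps, then over heads
    per_head = (mask * td_error ** 2).sum(-1) / np.maximum(mask.sum(-1), 1.0)
    return per_head.mean()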
class DQN(make_off_policy_class(mode='share')): ''' Deep Q-learning Network, DQN, [2013](https://arxiv.org/pdf/1312.5602.pdf), [2015](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf) DQN + LSTM, https://arxiv.org/abs/1507.06527 ''' def __init__(self, s_dim, visual_sources, visual_resolution, a_dim, is_continuous, lr=5.0e-4, eps_init=1, eps_mid=0.2, eps_final=0.01, init2mid_annealing_step=1000, assign_interval=1000, hidden_units=[32, 32], **kwargs): assert not is_continuous, 'dqn only support discrete action space' super().__init__(s_dim=s_dim, visual_sources=visual_sources, visual_resolution=visual_resolution, a_dim=a_dim, is_continuous=is_continuous, **kwargs) self.expl_expt_mng = ExplorationExploitationClass( eps_init=eps_init, eps_mid=eps_mid, eps_final=eps_final, init2mid_annealing_step=init2mid_annealing_step, max_step=self.max_train_step) self.assign_interval = assign_interval def _q_net(): return rls.critic_q_all(self.feat_dim, self.a_dim, hidden_units) self.q_net = _q_net() self.q_target_net = _q_net() self.critic_tv = self.q_net.trainable_variables + self.other_tv self.update_target_net_weights(self.q_target_net.weights, self.q_net.weights) self.lr = self.init_lr(lr) self.optimizer = self.init_optimizer(self.lr) self.model_recorder(dict(model=self.q_net, optimizer=self.optimizer)) def show_logo(self): self.recorder.logger.info(''' xxxxxxxx xxxxxx xxxx xxxx xxxxxxxx xxx xxxx xxx x xx xxx xxx xxxx xxxx x xx xxx xxx xxx xxxxx x xx xx xx xxx x xxxx x xx xx xxx xxx x xxxxx xx xxx xxx xxx x xxxx xx xxxx xxx xxx x xxx xxxxxxxx xxxxxxxx xxx xx xxxxxxx xxxxx xxxx xxx ''') def choose_action(self, s, visual_s, evaluation=False): if np.random.uniform() < self.expl_expt_mng.get_esp( self.train_step, evaluation=evaluation): a = np.random.randint(0, self.a_dim, self.n_agents) else: a, self.cell_state = self._get_action(s, visual_s, self.cell_state) a = a.numpy() return a @tf.function def _get_action(self, s, visual_s, cell_state): with tf.device(self.device): feat, cell_state = self.get_feature(s, visual_s, cell_state=cell_state, record_cs=True) q_values = self.q_net(feat) return tf.argmax(q_values, axis=1), cell_state def learn(self, **kwargs): self.train_step = kwargs.get('train_step') def _update(): if self.global_step % self.assign_interval == 0: self.update_target_net_weights(self.q_target_net.weights, self.q_net.weights) for i in range(kwargs['step']): self._learn( function_dict={ 'train_function': self.train, 'update_function': _update, 'summary_dict': dict([['LEARNING_RATE/lr', self.lr(self.train_step)]]) }) @tf.function(experimental_relax_shapes=True) def train(self, memories, isw, crsty_loss, cell_state): ss, vvss, a, r, done = memories with tf.device(self.device): with tf.GradientTape() as tape: feat, feat_ = self.get_feature(ss, vvss, cell_state=cell_state, s_and_s_=True) q = self.q_net(feat) q_next = self.q_target_net(feat_) q_eval = tf.reduce_sum(tf.multiply(q, a), axis=1, keepdims=True) q_target = tf.stop_gradient( r + self.gamma * (1 - done) * tf.reduce_max(q_next, axis=1, keepdims=True)) td_error = q_eval - q_target q_loss = tf.reduce_mean(tf.square(td_error) * isw) + crsty_loss grads = tape.gradient(q_loss, self.critic_tv) self.optimizer.apply_gradients(zip(grads, self.critic_tv)) self.global_step.assign_add(1) return td_error, dict( [['LOSS/loss', q_loss], ['Statistics/q_max', tf.reduce_max(q_eval)], ['Statistics/q_min', tf.reduce_min(q_eval)], ['Statistics/q_mean', tf.reduce_mean(q_eval)]])
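# Illustrative sketch (separate from the DQN class above): the one-step DQN target and TD error
# that train() computes, written out in NumPy with made-up values.
def _dqn_target_sketch(gamma=0.99):
    import numpy as np
    q      = np.array([[1.0, 2.0, 3.0], [0.5, 0.1, 0.2]])   # online Q(s, .), [B, A]
    q_next = np.array([[0.3, 0.9, 0.4], [1.5, 1.2, 0.8]])   # target Q(s', .), [B, A]
    a      = np.array([[0., 0., 1.], [1., 0., 0.]])         # one-hot actions taken, [B, A]
    r      = np.array([[1.0], [0.0]])                       # rewards, [B, 1]
    done   = np.array([[0.0], [1.0]])                       # termination flags, [B, 1]
    q_eval = np.sum(q * a, axis=1, keepdims=True)                               # Q(s, a), [B, 1]
    q_target = r + gamma * (1 - done) * np.max(q_next, axis=1, keepdims=True)   # treated as a constant
    td_error = q_eval - q_target
    return td_error, np.mean(np.square(td_error))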
class DRQN(Off_Policy): ''' DQN + LSTM, https://arxiv.org/abs/1507.06527 ''' def __init__(self, s_dim, visual_sources, visual_resolution, a_dim_or_list, is_continuous, lr=5.0e-4, eps_init=1, eps_mid=0.2, eps_final=0.01, init2mid_annealing_episode=100, assign_interval=1000, hidden_units={ 'lstm': 8, 'dense': [32] }, **kwargs): assert not is_continuous, 'drqn only support discrete action space' super().__init__(s_dim=s_dim, visual_sources=visual_sources, visual_resolution=visual_resolution, a_dim_or_list=a_dim_or_list, is_continuous=is_continuous, **kwargs) self.expl_expt_mng = ExplorationExploitationClass( eps_init=eps_init, eps_mid=eps_mid, eps_final=eps_final, init2mid_annealing_episode=init2mid_annealing_episode, max_episode=self.max_episode) self.assign_interval = assign_interval self.visual_net = Nn.VisualNet('visual_net', self.visual_dim) self.q_net = Nn.drqn_critic_q_all(self.s_dim, self.a_counts, 'q_net', hidden_units, visual_net=self.visual_net) self.q_target_net = Nn.drqn_critic_q_all(self.s_dim, self.a_counts, 'q_target_net', hidden_units, visual_net=self.visual_net) self.update_target_net_weights(self.q_target_net.weights, self.q_net.weights) self.lr = tf.keras.optimizers.schedules.PolynomialDecay( lr, self.max_episode, 1e-10, power=1.0) self.optimizer = tf.keras.optimizers.Adam( learning_rate=self.lr(self.episode)) self.cell_state = None self.buffer_type = 'EpisodeER' def show_logo(self): self.recorder.logger.info(''' xxxxxxxx xxxxxxx xxxxxx xxxx xxxx xxxxxxxx xxxxxxx xxx xxxx xxx x xx xxx xx xxx xxx xxxx xxxx x xx xxx xx xxx xxx xxx xxxxx x xx xx xxxxxx xx xxx x xxxx x xx xx xxxxxx xxx xxx x xxxxx xx xxx xx xxxx xxx xxx x xxxx xx xxxx xx xxx xxx xxx x xxx xxxxxxxx xxxxx xxxx xxxxxxxx xxx xx xxxxxxx xxxxx xxxx xxxxx xxxx xxx ''') def choose_action(self, s, visual_s, evaluation=False): if np.random.uniform() < self.expl_expt_mng.get_esp( self.episode, evaluation=evaluation): a = np.random.randint(0, self.a_counts, len(s)) else: a = self._get_action(s, visual_s).numpy() return sth.int2action_index(a, self.a_dim_or_list) @tf.function def _get_action(self, s, visual_s): s, visual_s = self.cast(s, visual_s) s = tf.expand_dims(s, axis=1) visual_s = tf.expand_dims(visual_s, axis=1) with tf.device(self.device): q_values, self.cell_state = self.q_net(s, visual_s, self.cell_state) return tf.argmax(q_values, axis=-1) def learn(self, **kwargs): self.episode = kwargs['episode'] def _update(): if self.global_step % self.assign_interval == 0: self.update_target_net_weights(self.q_target_net.weights, self.q_net.weights) for i in range(kwargs['step']): self._learn( function_dict={ 'train_function': self.train, 'update_function': _update, 'summary_dict': dict([['LEARNING_RATE/lr', self.lr(self.episode)]]) }) @tf.function(experimental_relax_shapes=True) def train(self, s, visual_s, a, r, s_, visual_s_, done): pad = lambda x: tf.keras.preprocessing.sequence.pad_sequences( x, padding='post', dtype='float32', value=0.) s, visual_s, a, r, s_, visual_s_ = map( pad, [s, visual_s, a, r, s_, visual_s_]) done = tf.keras.preprocessing.sequence.pad_sequences(done, padding='post', dtype='float32', value=1.) 
a, r, done = map(lambda x: tf.reshape(x, (-1, x.shape[-1])), [a, r, done])  # [B, T, N] => [B*T, N]
        with tf.device(self.device):
            with tf.GradientTape() as tape:
                q, _ = self.q_net(s, visual_s)
                q_next, _ = self.q_target_net(s_, visual_s_)
                q_eval = tf.reduce_sum(tf.multiply(q, a), axis=-1, keepdims=True)
                q_target = tf.stop_gradient(
                    r + self.gamma * (1 - done) * tf.reduce_max(q_next, axis=-1, keepdims=True))
                td_error = q_eval - q_target
                q_loss = tf.reduce_mean(tf.square(td_error) * self.IS_w)
            grads = tape.gradient(q_loss, self.q_net.tv)
            self.optimizer.apply_gradients(zip(grads, self.q_net.tv))
            self.global_step.assign_add(1)
            return td_error, dict(
                [['LOSS/loss', q_loss],
                 ['Statistics/q_max', tf.reduce_max(q_eval)],
                 ['Statistics/q_min', tf.reduce_min(q_eval)],
                 ['Statistics/q_mean', tf.reduce_mean(q_eval)]])
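# Illustrative sketch (standalone, toy data): DRQN trains on whole padded episodes, which is why
# done is padded with 1: for padded steps (1 - done) is 0, so they add nothing to the bootstrapped
# target beyond a zero reward.
def _episode_padding_sketch():
    import tensorflow as tf
    pad = tf.keras.preprocessing.sequence.pad_sequences
    r = [[[1.0], [1.0], [1.0]],
         [[0.5]]]                                    # two episodes of unequal length
    done = [[[0.0], [0.0], [1.0]],
            [[1.0]]]
    r_pad = pad(r, padding='post', dtype='float32', value=0.)        # [B, T, 1], rewards padded with 0
    done_pad = pad(done, padding='post', dtype='float32', value=1.)  # padded steps marked "done"
    bootstrap_mask = 1. - done_pad                   # [[1, 1, 0], [0, 0, 0]] along the time axis
    return r_pad, bootstrap_mask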
class MAXSQN(Off_Policy): ''' https://github.com/createamind/DRL/blob/master/spinup/algos/maxsqn/maxsqn.py ''' def __init__(self, s_dim, visual_sources, visual_resolution, a_dim_or_list, is_continuous, alpha=0.2, beta=0.1, ployak=0.995, eps_init=1, eps_mid=0.2, eps_final=0.01, init2mid_annealing_episode=100, use_epsilon=False, q_lr=5.0e-4, alpha_lr=5.0e-4, auto_adaption=True, hidden_units=[32, 32], **kwargs): assert not is_continuous, 'maxsqn only support discrete action space' super().__init__(s_dim=s_dim, visual_sources=visual_sources, visual_resolution=visual_resolution, a_dim_or_list=a_dim_or_list, is_continuous=is_continuous, **kwargs) self.expl_expt_mng = ExplorationExploitationClass( eps_init=eps_init, eps_mid=eps_mid, eps_final=eps_final, init2mid_annealing_episode=init2mid_annealing_episode, max_episode=self.max_episode) self.use_epsilon = use_epsilon self.ployak = ployak self.log_alpha = alpha if not auto_adaption else tf.Variable( initial_value=0.0, name='log_alpha', dtype=tf.float32, trainable=True) self.auto_adaption = auto_adaption self.target_alpha = beta * np.log(self.a_counts) self.visual_net = Nn.VisualNet('visual_net', self.visual_dim) self.q1_net = Nn.critic_q_all(self.s_dim, self.a_counts, 'q1_net', hidden_units, visual_net=self.visual_net) self.q1_target_net = Nn.critic_q_all(self.s_dim, self.a_counts, 'q1_target_net', hidden_units, visual_net=self.visual_net) self.q2_net = Nn.critic_q_all(self.s_dim, self.a_counts, 'q2_net', hidden_units, visual_net=self.visual_net) self.q2_target_net = Nn.critic_q_all(self.s_dim, self.a_counts, 'q2_target_net', hidden_units, visual_net=self.visual_net) self.update_target_net_weights( self.q1_target_net.weights + self.q2_target_net.weights, self.q1_net.weights + self.q2_net.weights) self.q_lr = tf.keras.optimizers.schedules.PolynomialDecay( q_lr, self.max_episode, 1e-10, power=1.0) self.alpha_lr = tf.keras.optimizers.schedules.PolynomialDecay( alpha_lr, self.max_episode, 1e-10, power=1.0) self.optimizer_critic = tf.keras.optimizers.Adam( learning_rate=self.q_lr(self.episode)) self.optimizer_alpha = tf.keras.optimizers.Adam( learning_rate=self.alpha_lr(self.episode)) def show_logo(self): self.recorder.logger.info(''' xx xx xxxxxx xxxxxx xxxx xx xxx xxx xxx xxx xxxx xxx xxxx xx xxx xxx xxxxx x xx xx xx xx xxxxx xx xxxx xxx xxxxxx xx xxx xxxxxx xx xxx xx xxx xx xxxx xx x x xxx xxxxx xxxxxx xx xx xx xxxxx xxxx xx x xxxxxx xxx xxx xxx x xxx xx xxxx xx xxx x xxx xx xxx xx xx xx xxxxx xx xxxx xx xxx x xx xxx xxxxx xxxxxxxxx xxx xxxx xx xxx xx xxx x xxxxxxxx xxx xxx xxxxxxx xxxxxxx xx xx xxxxxxx ''') def choose_action(self, s, visual_s, evaluation=False): if self.use_epsilon and np.random.uniform( ) < self.expl_expt_mng.get_esp(self.episode, evaluation=evaluation): a = np.random.randint(0, self.a_counts, len(s)) else: a = self._get_action(s, visual_s)[-1].numpy() return sth.int2action_index(a, self.a_dim_or_list) @tf.function def _get_action(self, s, visual_s): s, visual_s = self.cast(s, visual_s) with tf.device(self.device): q = self.q1_net(s, visual_s) cate_dist = tfp.distributions.Categorical(logits=q / tf.exp(self.log_alpha)) pi = cate_dist.sample() return tf.argmax(q, axis=1), pi def learn(self, **kwargs): self.episode = kwargs['episode'] for i in range(kwargs['step']): if self.data.is_lg_batch_size: s, visual_s, a, r, s_, visual_s_, done = self.data.sample() if self.use_priority: self.IS_w = self.data.get_IS_w() td_error, summaries = self.train(s, visual_s, a, r, s_, visual_s_, done) if self.use_priority: td_error = 
np.squeeze(td_error.numpy())
                    self.data.update(td_error, self.episode)
                self.update_target_net_weights(
                    self.q1_target_net.weights + self.q2_target_net.weights,
                    self.q1_net.weights + self.q2_net.weights,
                    self.ployak)
                summaries.update(
                    dict([['LEARNING_RATE/q_lr', self.q_lr(self.episode)],
                          ['LEARNING_RATE/alpha_lr', self.alpha_lr(self.episode)]]))
                self.write_training_summaries(self.global_step, summaries)

    @tf.function(experimental_relax_shapes=True)
    def train(self, s, visual_s, a, r, s_, visual_s_, done):
        s, visual_s, a, r, s_, visual_s_, done = self.cast(
            s, visual_s, a, r, s_, visual_s_, done)
        with tf.device(self.device):
            with tf.GradientTape() as tape:
                q1 = self.q1_net(s, visual_s)
                q1_eval = tf.reduce_sum(tf.multiply(q1, a), axis=1, keepdims=True)
                q2 = self.q2_net(s, visual_s)
                q2_eval = tf.reduce_sum(tf.multiply(q2, a), axis=1, keepdims=True)
                q1_target = self.q1_target_net(s_, visual_s_)
                q1_target_max = tf.reduce_max(q1_target, axis=1, keepdims=True)
                q1_target_log_probs = tf.nn.log_softmax(
                    q1_target / tf.exp(self.log_alpha), axis=1) + 1e-8
                q1_target_log_max = tf.reduce_max(q1_target_log_probs, axis=1, keepdims=True)
                q1_target_entropy = -tf.reduce_mean(
                    tf.reduce_sum(
                        tf.exp(q1_target_log_probs) * q1_target_log_probs,
                        axis=1, keepdims=True))
                q2_target = self.q2_target_net(s_, visual_s_)
                q2_target_max = tf.reduce_max(q2_target, axis=1, keepdims=True)
                # q2_target_log_probs = tf.nn.log_softmax(q2_target, axis=1)
                # q2_target_log_max = tf.reduce_max(q2_target_log_probs, axis=1, keepdims=True)
                # Soft target: clipped double-Q max plus an entropy bonus scaled by alpha.
                q_target = tf.minimum(q1_target_max, q2_target_max) \
                    + tf.exp(self.log_alpha) * q1_target_entropy
                dc_r = tf.stop_gradient(r + self.gamma * q_target * (1 - done))
                td_error1 = q1_eval - dc_r
                td_error2 = q2_eval - dc_r
                q1_loss = tf.reduce_mean(tf.square(td_error1) * self.IS_w)
                q2_loss = tf.reduce_mean(tf.square(td_error2) * self.IS_w)
                loss = 0.5 * (q1_loss + q2_loss)
            loss_grads = tape.gradient(loss, self.q1_net.tv + self.q2_net.tv)
            self.optimizer_critic.apply_gradients(
                zip(loss_grads, self.q1_net.tv + self.q2_net.tv))
            if self.auto_adaption:
                with tf.GradientTape() as tape:
                    q1 = self.q1_net(s, visual_s)
                    # Entropy of the current online policy, i.e. softmax(q1 / alpha).
                    q1_log_probs = tf.nn.log_softmax(
                        q1 / tf.exp(self.log_alpha), axis=1) + 1e-8
                    q1_log_max = tf.reduce_max(q1_log_probs, axis=1, keepdims=True)
                    q1_entropy = -tf.reduce_mean(
                        tf.reduce_sum(tf.exp(q1_log_probs) * q1_log_probs,
                                      axis=1, keepdims=True))
                    alpha_loss = -tf.reduce_mean(
                        self.log_alpha * tf.stop_gradient(self.target_alpha - q1_entropy))
                alpha_grads = tape.gradient(alpha_loss, [self.log_alpha])
                self.optimizer_alpha.apply_gradients(
                    zip(alpha_grads, [self.log_alpha]))
            self.global_step.assign_add(1)
            # q1_entropy only exists when auto_adaption is on; it is added to summaries below.
            summaries = dict(
                [['LOSS/loss', loss],
                 ['Statistics/log_alpha', self.log_alpha],
                 ['Statistics/alpha', tf.exp(self.log_alpha)],
                 ['Statistics/q_min', tf.reduce_mean(tf.minimum(q1, q2))],
                 ['Statistics/q_mean', tf.reduce_mean(q1)],
                 ['Statistics/q_max', tf.reduce_mean(tf.maximum(q1, q2))]])
            if self.auto_adaption:
                summaries.update({'LOSS/alpha_loss': alpha_loss,
                                  'Statistics/q1_entropy': q1_entropy})
            return (td_error1 + td_error2) / 2, summaries
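# Illustrative sketch (not the class above): MAXSQN's target is a soft, maximum-entropy variant of
# the double-Q target, i.e. the clipped max over both target nets plus alpha times the entropy of
# the Boltzmann policy softmax(Q / alpha). Made-up numbers, NumPy only.
def _soft_q_target_sketch(gamma=0.99, alpha=0.2):
    import numpy as np
    def log_softmax(x, axis=-1):
        x = x - np.max(x, axis=axis, keepdims=True)
        return x - np.log(np.sum(np.exp(x), axis=axis, keepdims=True))
    q1_next = np.array([[1.0, 2.0, 0.5]])   # target net 1, Q(s', .), [B, A]
    q2_next = np.array([[0.8, 1.7, 0.6]])   # target net 2, Q(s', .), [B, A]
    r, done = np.array([[1.0]]), np.array([[0.0]])
    log_pi = log_softmax(q1_next / alpha, axis=1)
    entropy = -np.sum(np.exp(log_pi) * log_pi, axis=1, keepdims=True)
    q_target = np.minimum(q1_next.max(axis=1, keepdims=True),
                          q2_next.max(axis=1, keepdims=True)) + alpha * entropy
    return r + gamma * (1 - done) * q_target     # the soft Bellman backup (dc_r)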
class QS: ''' Q-learning/Sarsa/Expected Sarsa. ''' def __init__(self, s_dim, visual_sources, visual_resolution, a_dim_or_list, is_continuous, mode='q', lr=0.2, eps_init=1, eps_mid=0.2, eps_final=0.01, init2mid_annealing_episode=100, **kwargs): assert not hasattr(s_dim, '__len__') assert not is_continuous self.mode = mode self.s_dim = s_dim self.a_dim_or_list = a_dim_or_list self.gamma = float(kwargs.get('gamma', 0.999)) self.max_episode = int(kwargs.get('max_episode', 1000)) self.step = 0 self.a_counts = int(np.asarray(a_dim_or_list).prod()) self.episode = 0 # episode of now self.n_agents = int(kwargs.get('n_agents', 0)) if self.n_agents <= 0: raise ValueError('agents num must larger than zero.') self.expl_expt_mng = ExplorationExploitationClass( eps_init=eps_init, eps_mid=eps_mid, eps_final=eps_final, init2mid_annealing_episode=init2mid_annealing_episode, max_episode=self.max_episode) self.table = np.zeros(shape=(self.s_dim, self.a_counts)) self.lr = lr self.next_a = np.zeros(self.n_agents, dtype=np.int32) self.mask = [] ion() def one_hot2int(self, x): idx = [np.where(np.asarray(i))[0][0] for i in x] return idx def partial_reset(self, done): self.mask = np.where(done)[0] def choose_action(self, s, visual_s=None, evaluation=False): s = self.one_hot2int(s) if self.mode == 'q': return self._get_action(s, evaluation) elif self.mode == 'sarsa' or self.mode == 'expected_sarsa': a = self._get_action(s, evaluation) self.next_a[self.mask] = a[self.mask] return self.next_a def _get_action(self, s, evaluation=False, _max=False): a = np.array([np.argmax(self.table[i, :]) for i in s]) if not _max: if np.random.uniform() < self.expl_expt_mng.get_esp( self.episode, evaluation=evaluation): a = np.random.randint(0, self.a_counts, self.n_agents) return a def learn(self, **kwargs): self.episode = kwargs['episode'] def store_data(self, s, visual_s, a, r, s_, visual_s_, done): self.step += 1 s = self.one_hot2int(s) s_ = self.one_hot2int(s_) if self.mode == 'q': a_ = self._get_action(s_, _max=True) value = self.table[s_, a_] else: self.next_a = self._get_action(s_) if self.mode == 'expected_sarsa': value = np.mean(self.table[s_, :], axis=-1) else: value = self.table[s_, self.next_a] self.table[s, a] = (1 - self.lr) * self.table[s, a] + self.lr * ( r + self.gamma * (1 - done) * value) if self.step % 1000 == 0: plot_heatmap(self.s_dim, self.a_counts, self.table) def close(self): ioff() def no_op_store(self, s, visual_s, a, r, s_, visual_s_, done): pass def __getattr__(self, x): # print(x) return lambda *args, **kwargs: 0
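# Illustrative sketch (separate from the QS class above): the three tabular update rules the mode
# flag switches between differ only in how the next-state value is estimated. Expected Sarsa is
# shown as the plain mean over actions, matching the class (a uniform-policy expectation).
def _tabular_update_sketch(lr=0.2, gamma=0.99):
    import numpy as np
    table = np.zeros((4, 2))                  # toy Q table: 4 states x 2 actions
    s, a, r, s_, done = 0, 1, 1.0, 2, 0.0
    next_a = 0                                # action actually taken next (needed by Sarsa)
    values = {
        'q': table[s_].max(),                 # Q-learning: greedy next value
        'sarsa': table[s_, next_a],           # Sarsa: value of the action taken next
        'expected_sarsa': table[s_].mean(),   # Expected Sarsa: mean over next-state values
    }
    return {mode: (1 - lr) * table[s, a] + lr * (r + gamma * (1 - done) * v)
            for mode, v in values.items()}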
class C51(Off_Policy): ''' Category 51, https://arxiv.org/abs/1707.06887 No double, no dueling, no noisy net. ''' def __init__(self, s_dim, visual_sources, visual_resolution, a_dim_or_list, is_continuous, v_min=-10, v_max=10, atoms=51, lr=5.0e-4, eps_init=1, eps_mid=0.2, eps_final=0.01, init2mid_annealing_episode=100, assign_interval=1000, hidden_units=[128, 128], **kwargs): assert not is_continuous, 'c51 only support discrete action space' super().__init__( s_dim=s_dim, visual_sources=visual_sources, visual_resolution=visual_resolution, a_dim_or_list=a_dim_or_list, is_continuous=is_continuous, **kwargs) self.v_min = v_min self.v_max = v_max self.atoms = atoms self.delta_z = (self.v_max - self.v_min) / (self.atoms - 1) self.z = tf.reshape(tf.constant([self.v_min + i * self.delta_z for i in range(self.atoms)], dtype=tf.float32), [-1, self.atoms]) # [1, N] self.zb = tf.tile(self.z, tf.constant([self.a_counts, 1])) # [A, N] self.expl_expt_mng = ExplorationExploitationClass(eps_init=eps_init, eps_mid=eps_mid, eps_final=eps_final, init2mid_annealing_episode=init2mid_annealing_episode, max_episode=self.max_episode) self.assign_interval = assign_interval _net = lambda: Nn.c51_distributional(self.rnn_net.hdim, self.a_counts, self.atoms, hidden_units) self.q_dist_net = _net() self.q_target_dist_net = _net() self.critic_tv = self.q_dist_net.trainable_variables + self.other_tv self.update_target_net_weights(self.q_target_dist_net.weights, self.q_dist_net.weights) self.lr = self.init_lr(lr) self.optimizer = self.init_optimizer(self.lr) self.model_recorder(dict( model=self.q_dist_net, optimizer=self.optimizer )) def show_logo(self): self.recorder.logger.info(''' xxxxxxx xxxxx xxx xxxx xxx xxxx xxxx xxxx x xxxx xx xxx x xxxxx xx xxx xxx xx xxx xxx xx xxx xx xx xxx x xx xx xx xxxxxxxx xxxxx xxxx xxxxx x xxxx ''') def choose_action(self, s, visual_s, evaluation=False): if np.random.uniform() < self.expl_expt_mng.get_esp(self.episode, evaluation=evaluation): a = np.random.randint(0, self.a_counts, self.n_agents) else: a, self.cell_state = self._get_action(s, visual_s, self.cell_state) a = a.numpy() return a @tf.function def _get_action(self, s, visual_s, cell_state): with tf.device(self.device): feat, cell_state = self.get_feature(s, visual_s, cell_state=cell_state, record_cs=True, train=False) q = self.get_q(feat) # [B, A] return tf.argmax(q, axis=-1), cell_state # [B, 1] def learn(self, **kwargs): self.episode = kwargs['episode'] def _update(): if self.global_step % self.assign_interval == 0: self.update_target_net_weights(self.q_target_dist_net.weights, self.q_dist_net.weights) for i in range(kwargs['step']): self._learn(function_dict={ 'train_function': self.train, 'update_function': _update, 'summary_dict': dict([['LEARNING_RATE/lr', self.lr(self.episode)]]) }) @tf.function(experimental_relax_shapes=True) def train(self, memories, isw, crsty_loss, cell_state): ss, vvss, a, r, done = memories batch_size = tf.shape(a)[0] with tf.device(self.device): with tf.GradientTape() as tape: feat, feat_ = self.get_feature(ss, vvss, cell_state=cell_state, s_and_s_=True) indexs = tf.reshape(tf.range(batch_size), [-1, 1]) # [B, 1] q_dist = self.q_dist_net(feat) # [B, A, N] q_dist = tf.transpose(tf.reduce_sum(tf.transpose(q_dist, [2, 0, 1]) * a, axis=-1), [1, 0]) # [B, N] q_eval = tf.reduce_sum(q_dist * self.z, axis=-1) target_q_dist = self.q_target_dist_net(feat_) # [B, A, N] target_q = tf.reduce_sum(self.zb * target_q_dist, axis=-1) # [B, A, N] => [B, A] a_ = tf.reshape(tf.cast(tf.argmax(target_q, axis=-1), 
dtype=tf.int32), [-1, 1]) # [B, 1] target_q_dist = tf.gather_nd(target_q_dist, tf.concat([indexs, a_], axis=-1)) # [B, N] target = tf.tile(r, tf.constant([1, self.atoms])) \ + self.gamma * tf.multiply(self.z, # [1, N] (1.0 - tf.tile(done, tf.constant([1, self.atoms])))) # [B, N], [1, N]* [B, N] = [B, N] target = tf.clip_by_value(target, self.v_min, self.v_max) # [B, N] b = (target - self.v_min) / self.delta_z # [B, N] u, l = tf.math.ceil(b), tf.math.floor(b) # [B, N] u_id, l_id = tf.cast(u, tf.int32), tf.cast(l, tf.int32) # [B, N] u_minus_b, b_minus_l = u - b, b - l # [B, N] index_help = tf.tile(indexs, tf.constant([1, self.atoms])) # [B, N] index_help = tf.expand_dims(index_help, -1) # [B, N, 1] u_id = tf.concat([index_help, tf.expand_dims(u_id, -1)], axis=-1) # [B, N, 2] l_id = tf.concat([index_help, tf.expand_dims(l_id, -1)], axis=-1) # [B, N, 2] _cross_entropy = tf.stop_gradient(target_q_dist * u_minus_b) * tf.math.log(tf.gather_nd(q_dist, l_id)) \ + tf.stop_gradient(target_q_dist * b_minus_l) * tf.math.log(tf.gather_nd(q_dist, u_id)) # [B, N] # tf.debugging.check_numerics(_cross_entropy, '_cross_entropy') cross_entropy = -tf.reduce_sum(_cross_entropy, axis=-1) # [B,] # tf.debugging.check_numerics(cross_entropy, 'cross_entropy') loss = tf.reduce_mean(cross_entropy * isw) + crsty_loss td_error = cross_entropy grads = tape.gradient(loss, self.critic_tv) self.optimizer.apply_gradients( zip(grads, self.critic_tv) ) self.global_step.assign_add(1) return td_error, dict([ ['LOSS/loss', loss], ['Statistics/q_max', tf.reduce_max(q_eval)], ['Statistics/q_min', tf.reduce_min(q_eval)], ['Statistics/q_mean', tf.reduce_mean(q_eval)] ]) @tf.function(experimental_relax_shapes=True) def get_q(self, feat): with tf.device(self.device): return tf.reduce_sum(self.zb * self.q_dist_net(feat), axis=-1) # [B, A, N] => [B, A]
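# Illustrative sketch (standalone NumPy, single transition): the core of the C51 train step is
# projecting the shifted support r + gamma * (1 - done) * z back onto the fixed atoms, splitting
# each target probability between the two nearest atoms; the loss is then the cross-entropy
# between this projection and the online distribution for (s, a).
def _categorical_projection_sketch(v_min=-10.0, v_max=10.0, atoms=51, gamma=0.99, r=1.0, done=0.0):
    import numpy as np
    delta_z = (v_max - v_min) / (atoms - 1)
    z = v_min + delta_z * np.arange(atoms)               # fixed support, [N]
    target_dist = np.full(atoms, 1.0 / atoms)            # p(z | s', a*) from the target net, [N]
    target = np.clip(r + gamma * (1.0 - done) * z, v_min, v_max)
    b = (target - v_min) / delta_z                       # fractional atom index, [N]
    l, u = np.floor(b).astype(int), np.ceil(b).astype(int)
    proj = np.zeros(atoms)
    np.add.at(proj, l, target_dist * (u - b))            # mass to the lower neighbour
    np.add.at(proj, u, target_dist * (b - l))            # mass to the upper neighbour
    np.add.at(proj, l[u == l], target_dist[u == l])      # b exactly on an atom: keep the full mass
    return proj                                          # sums to ~1.0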
class IQN(Off_Policy): ''' Implicit Quantile Networks Double DQN ''' def __init__(self, s_dim, visual_sources, visual_resolution, a_dim_or_list, is_continuous, online_quantiles=8, target_quantiles=8, select_quantiles=32, quantiles_idx=64, huber_delta=1., lr=5.0e-4, eps_init=1, eps_mid=0.2, eps_final=0.01, init2mid_annealing_episode=100, assign_interval=2, hidden_units={ 'q_net': [128, 64], 'quantile': [128, 64], 'tile': [64] }, **kwargs): assert not is_continuous, 'iqn only support discrete action space' super().__init__( s_dim=s_dim, visual_sources=visual_sources, visual_resolution=visual_resolution, a_dim_or_list=a_dim_or_list, is_continuous=is_continuous, **kwargs) self.pi = tf.constant(np.pi) self.online_quantiles = online_quantiles self.target_quantiles = target_quantiles self.select_quantiles = select_quantiles self.quantiles_idx = quantiles_idx self.huber_delta = huber_delta self.assign_interval = assign_interval self.expl_expt_mng = ExplorationExploitationClass(eps_init=eps_init, eps_mid=eps_mid, eps_final=eps_final, init2mid_annealing_episode=init2mid_annealing_episode, max_episode=self.max_episode) self.visual_net = Nn.VisualNet('visual_net', self.visual_dim) self.q_net = Nn.iqn_net(self.s_dim, self.a_counts, self.quantiles_idx, 'q_net', hidden_units, visual_net=self.visual_net) self.q_target_net = Nn.iqn_net(self.s_dim, self.a_counts, self.quantiles_idx, 'q_target_net', hidden_units, visual_net=self.visual_net) self.update_target_net_weights(self.q_target_net.weights, self.q_net.weights) self.lr = tf.keras.optimizers.schedules.PolynomialDecay(lr, self.max_episode, 1e-10, power=1.0) self.optimizer = tf.keras.optimizers.Adam(learning_rate=self.lr(self.episode)) self.recorder.logger.info(''' xxxxxxxx xxxxxxx xxx xxx xxxxxxxx xxxxxxxxx xxxx xxx xxx xxxx xxxx xxxxx xxx xxx xxx xxx xxxxx xxx xxx xxxx xxx xxxxxx xxx xxx xxxx xxx xxxxxxxxxx xxx xxxx xxx xxx xxxxxx xxx xxxx xxxx xxx xxxxxx xxxxxxxx xxxxxxxxx xxx xxxxx xxxxxxxx xxxxxxx xxx xxxx xxxx xxxx xxxx ''') def choose_action(self, s, visual_s, evaluation=False): if np.random.uniform() < self.expl_expt_mng.get_esp(self.episode, evaluation=evaluation): a = np.random.randint(0, self.a_counts, len(s)) else: a = self._get_action(s, visual_s).numpy() return sth.int2action_index(a, self.a_dim_or_list) @tf.function def _get_action(self, s, visual_s): s, visual_s = self.cast(s, visual_s) with tf.device(self.device): _, select_quantiles_tiled = self._generate_quantiles( # [N*B, 64] batch_size=s.shape[0], quantiles_num=self.select_quantiles, quantiles_idx=self.quantiles_idx ) _, q_values = self.q_net(s, visual_s, select_quantiles_tiled, quantiles_num=self.select_quantiles) # [B, A] return tf.argmax(q_values, axis=-1) # [B,] @tf.function def _generate_quantiles(self, batch_size, quantiles_num, quantiles_idx): with tf.device(self.device): _quantiles = tf.random.uniform([batch_size * quantiles_num, 1], minval=0, maxval=1) # [N*B, 1] _quantiles_tiled = tf.tile(_quantiles, [1, quantiles_idx]) # [N*B, 1] => [N*B, 64] _quantiles_tiled = tf.cast(tf.range(quantiles_idx), tf.float32) * self.pi * _quantiles_tiled # pi * i * tau [N*B, 64] * [64, ] => [N*B, 64] _quantiles_tiled = tf.cos(_quantiles_tiled) # [N*B, 64] _quantiles = tf.reshape(_quantiles, [batch_size, quantiles_num, 1]) # [N*B, 1] => [B, N, 1] return _quantiles, _quantiles_tiled def learn(self, **kwargs): self.episode = kwargs['episode'] for i in range(kwargs['step']): if self.data.is_lg_batch_size: s, visual_s, a, r, s_, visual_s_, done = self.data.sample() if self.use_priority: self.IS_w = 
self.data.get_IS_w()
                    td_error, summaries = self.train(s, visual_s, a, r, s_, visual_s_, done)
                    if self.use_priority:
                        td_error = np.squeeze(td_error.numpy())
                        self.data.update(td_error, self.episode)
                    if self.global_step % self.assign_interval == 0:
                        self.update_target_net_weights(self.q_target_net.weights, self.q_net.weights)
                    summaries.update(dict([
                        ['LEARNING_RATE/lr', self.lr(self.episode)]
                    ]))
                    self.write_training_summaries(self.global_step, summaries)

    @tf.function(experimental_relax_shapes=True)
    def train(self, s, visual_s, a, r, s_, visual_s_, done):
        s, visual_s, a, r, s_, visual_s_, done = self.cast(s, visual_s, a, r, s_, visual_s_, done)
        with tf.device(self.device):
            with tf.GradientTape() as tape:
                quantiles, quantiles_tiled = self._generate_quantiles(  # [B, N, 1], [N*B, 64]
                    batch_size=s.shape[0],
                    quantiles_num=self.online_quantiles,
                    quantiles_idx=self.quantiles_idx)
                quantiles_value, q = self.q_net(
                    s, visual_s, quantiles_tiled,
                    quantiles_num=self.online_quantiles)  # [N, B, A], [B, A]
                _a = tf.reshape(tf.tile(a, [self.online_quantiles, 1]),
                                [self.online_quantiles, -1, self.a_counts])  # [B, A] => [N*B, A] => [N, B, A]
                quantiles_value = tf.reduce_sum(quantiles_value * _a, axis=-1, keepdims=True)  # [N, B, A] => [N, B, 1]
                q_eval = tf.reduce_sum(q * a, axis=-1, keepdims=True)  # [B, A] => [B, 1]
                next_max_action = self._get_action(s_, visual_s_)  # [B, ]
                next_max_action = tf.one_hot(
                    tf.squeeze(next_max_action), self.a_counts, 1., 0., dtype=tf.float32)  # [B, A]
                _next_max_action = tf.reshape(
                    tf.tile(next_max_action, [self.target_quantiles, 1]),
                    [self.target_quantiles, -1, self.a_counts])  # [B, A] => [N'*B, A] => [N', B, A]
                _, target_quantiles_tiled = self._generate_quantiles(  # [N'*B, 64]
                    batch_size=s_.shape[0],
                    quantiles_num=self.target_quantiles,
                    quantiles_idx=self.quantiles_idx)
                target_quantiles_value, target_q = self.q_target_net(
                    s_, visual_s_, target_quantiles_tiled,
                    quantiles_num=self.target_quantiles)  # [N', B, A], [B, A]
                target_quantiles_value = tf.reduce_sum(
                    target_quantiles_value * _next_max_action, axis=-1, keepdims=True)  # [N', B, A] => [N', B, 1]
                # Double-DQN style: evaluate the target net at the greedy action picked by the online net.
                target_q = tf.reduce_sum(target_q * next_max_action, axis=-1, keepdims=True)  # [B, A] => [B, 1]
                q_target = tf.stop_gradient(r + self.gamma * (1 - done) * target_q)  # [B, 1]
                td_error = q_eval - q_target  # [B, 1], used for priority updates
                _r = tf.reshape(tf.tile(r, [self.target_quantiles, 1]),
                                [self.target_quantiles, -1, 1])  # [B, 1] => [N'*B, 1] => [N', B, 1]
                _done = tf.reshape(tf.tile(done, [self.target_quantiles, 1]),
                                   [self.target_quantiles, -1, 1])  # [B, 1] => [N'*B, 1] => [N', B, 1]
                quantiles_value_target = tf.stop_gradient(
                    _r + self.gamma * (1 - _done) * target_quantiles_value)  # [N', B, 1]
                quantiles_value_target = tf.transpose(quantiles_value_target, [1, 2, 0])  # [B, 1, N']
                quantiles_value_online = tf.transpose(quantiles_value, [1, 0, 2])  # [B, N, 1]
                quantile_error = quantiles_value_online - quantiles_value_target  # [B, N, 1] - [B, 1, N'] => [B, N, N']
                huber = huber_loss(quantile_error, delta=self.huber_delta)  # [B, N, N']
                huber_abs = tf.abs(
                    quantiles - tf.where(quantile_error < 0,
                                         tf.ones_like(quantile_error),
                                         tf.zeros_like(quantile_error)))  # [B, N, 1] - [B, N, N'] => [B, N, N']
                loss = tf.reduce_mean(huber_abs * huber, axis=-1)  # [B, N, N'] => [B, N]
                loss = tf.reduce_sum(loss, axis=-1)  # [B, N] => [B, ]
                loss = tf.reduce_mean(loss * self.IS_w)  # [B, ] => scalar
            grads = tape.gradient(loss, self.q_net.tv)
            self.optimizer.apply_gradients(zip(grads, self.q_net.tv))
            self.global_step.assign_add(1)
            return td_error, dict([
                ['LOSS/loss', loss],
                ['Statistics/q_max', tf.reduce_max(q_eval)],
                ['Statistics/q_min', tf.reduce_min(q_eval)],
                ['Statistics/q_mean', tf.reduce_mean(q_eval)]
            ])
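# Illustrative sketch (standalone NumPy, random numbers): IQN's loss pairs every online quantile
# estimate with every target quantile and weights the Huber loss asymmetrically by
# |tau - 1{error < 0}|, reducing over target quantiles (mean), online quantiles (sum) and batch (mean).
def _quantile_huber_sketch(B=2, N=8, N_target=8, delta=1.0, seed=0):
    import numpy as np
    rng = np.random.default_rng(seed)
    def huber(x):
        return np.where(np.abs(x) <= delta, 0.5 * x * x, delta * (np.abs(x) - 0.5 * delta))
    tau = rng.uniform(size=(B, N, 1))           # sampled quantile fractions for the online net
    online = rng.normal(size=(B, N, 1))         # Z_tau(s, a)
    target = rng.normal(size=(B, 1, N_target))  # r + gamma * Z_tau'(s', a*)
    error = online - target                     # pairwise quantile errors, [B, N, N']
    rho = np.abs(tau - (error < 0)) * huber(error)
    return np.mean(np.sum(np.mean(rho, axis=-1), axis=-1))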