def call(self, qs, state):
    assert_rank(qs, 3)
    assert_rank(state, 3)
    B, seqlen = qs.shape[:2]
    tf.debugging.assert_shapes([
        [qs, (B, seqlen, self.n_agents)],
        [state, (B, seqlen, None)],
    ])
    # flatten the batch and time dimensions so mixing is applied per step
    qs = tf.reshape(qs, (-1, self.n_agents))
    state = tf.reshape(state, (-1, state.shape[-1]))

    # first mixing layer: non-negative, state-conditioned weights keep the
    # mixed value monotonic in each agent's Q
    w1 = tf.math.abs(self.w1(state))
    w1 = tf.reshape(w1, (-1, self.n_agents, self.hidden_dim))
    b = self.b(state)
    h = tf.nn.elu(tf.einsum('ba,bah->bh', qs, w1) + b)
    tf.debugging.assert_shapes([
        [b, (B * seqlen, self.hidden_dim)],
        [h, (B * seqlen, self.hidden_dim)],
    ])

    # second mixing layer with a state-conditioned bias
    w2 = tf.math.abs(self.w2(state))
    w2 = tf.reshape(w2, (-1, self.hidden_dim, 1))
    v = self.v(state)
    y = tf.einsum('bh,bho->bo', h, w2) + v
    y = tf.reshape(y, [-1, seqlen])
    tf.debugging.assert_shapes([
        [v, (B * seqlen, 1)],
        [y, (B, seqlen)],
    ])

    return y
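# Hedged usage sketch (not from the original repo): a self-contained illustration of the
# state-conditioned mixing performed above, with the hypernetworks self.w1/self.b/self.w2/self.v
# replaced by plain Dense layers. All names and sizes below are assumptions for the demo only.
import tensorflow as tf

B, seqlen, n_agents, state_dim, hidden_dim = 4, 8, 3, 16, 32
w1_net = tf.keras.layers.Dense(n_agents * hidden_dim)
b_net = tf.keras.layers.Dense(hidden_dim)
w2_net = tf.keras.layers.Dense(hidden_dim)
v_net = tf.keras.layers.Dense(1)

qs = tf.random.normal((B * seqlen, n_agents))       # per-agent Q-values, already flattened over time
state = tf.random.normal((B * seqlen, state_dim))   # global state used to generate the mixing weights

w1 = tf.reshape(tf.math.abs(w1_net(state)), (-1, n_agents, hidden_dim))  # abs() => non-negative weights
h = tf.nn.elu(tf.einsum('ba,bah->bh', qs, w1) + b_net(state))
w2 = tf.reshape(tf.math.abs(w2_net(state)), (-1, hidden_dim, 1))
q_tot = tf.einsum('bh,bho->bo', h, w2) + v_net(state)                    # [B * seqlen, 1]
assert q_tot.shape.as_list() == [B * seqlen, 1]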
def _process_additional_input(self, x, prev_action, prev_reward):
    results = []
    if prev_action is not None:
        # [B] -> [B, 1] -> one-hot [B, 1, action_dim]
        prev_action = tf.reshape(prev_action, (-1, 1))
        prev_action = tf.one_hot(
            prev_action, self.actor.action_dim, dtype=x.dtype)
        results.append(prev_action)
    if prev_reward is not None:
        # [B] -> [B, 1, 1]
        prev_reward = tf.reshape(prev_reward, (-1, 1, 1))
        results.append(prev_reward)
    assert_rank(results, 3)
    return results
def quantile_regression_loss(qtv, target, tau_hat, kappa=1., return_error=False):
    assert qtv.shape[-1] == 1, qtv.shape
    assert target.shape[-2] == 1, target.shape
    assert tau_hat.shape[-1] == 1, tau_hat.shape
    assert_rank([qtv, target, tau_hat])

    error = target - qtv                                        # [B, N, N']
    weight = tf.abs(tau_hat - tf.cast(error < 0, tf.float32))   # [B, N, N']
    huber = huber_loss(error, threshold=kappa)                  # [B, N, N']
    qr_loss = tf.reduce_sum(tf.reduce_mean(weight * huber, axis=-1), axis=-2)   # [B]

    if return_error:
        return error, qr_loss
    return qr_loss
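# Hedged sketch (assumption, not the repo's code): a minimal, self-contained version of the
# quantile-Huber loss above with huber_loss inlined, to make the [B, N, N'] broadcasting explicit.
# Shapes and the _huber helper are illustrative only.
import tensorflow as tf

def _huber(x, threshold=1.):
    abs_x = tf.abs(x)
    return tf.where(abs_x <= threshold,
                    .5 * tf.square(x),
                    threshold * (abs_x - .5 * threshold))

B, N, N_prime = 2, 4, 5
qtv = tf.random.normal((B, N, 1))           # online quantile values
target = tf.random.normal((B, 1, N_prime))  # target quantile values
tau_hat = tf.random.uniform((B, N, 1))      # quantile midpoints

error = target - qtv                                        # [B, N, N'] via broadcasting
weight = tf.abs(tau_hat - tf.cast(error < 0, tf.float32))   # asymmetric quantile weight
loss = tf.reduce_sum(tf.reduce_mean(weight * _huber(error), axis=-1), axis=-2)  # [B]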
def call(self, x, state, mask, additional_input=[]):
    xs = [x] + additional_input
    mask = tf.expand_dims(mask, axis=-1)
    assert_rank(xs + [mask], 3)
    if not self._state_mask:
        # mask out the inputs instead of the recurrent state
        for i, v in enumerate(xs):
            xs[i] *= tf.cast(mask, v.dtype)
    x = tf.concat(xs, axis=-1) if len(xs) > 1 else xs[0]
    if not mask.dtype.is_compatible_with(global_policy().compute_dtype):
        mask = tf.cast(mask, global_policy().compute_dtype)
    x = self._rnn((x, mask), initial_state=state)
    x, state = x[0], GRUState(x[1])
    return x, state
def retrace(reward, next_qs, next_action, next_pi, next_mu_a, discount,
            lambda_=.95, ratio_clip=1, axis=0, tbo=False, regularization=None):
    """Computes retrace targets.

    discount = gamma * (1 - done).
    axis specifies the time dimension.
    """
    if isinstance(discount, (int, float)):
        discount = discount * tf.ones_like(reward)
    if next_action.dtype.is_integer:
        next_action = tf.one_hot(next_action, next_pi.shape[-1], dtype=next_pi.dtype)
    assert_rank_and_shape_compatibility([next_action, next_pi], reward.shape.ndims + 1)

    # truncated importance ratio c = lambda * min(pi/mu, ratio_clip)
    next_pi_a = tf.reduce_sum(next_pi * next_action, axis=-1)
    next_ratio = next_pi_a / next_mu_a
    if ratio_clip is not None:
        next_ratio = tf.minimum(next_ratio, ratio_clip)
    next_c = next_ratio * lambda_

    if tbo:
        next_qs = inverse_h(next_qs)
    next_v = tf.reduce_sum(next_qs * next_pi, axis=-1)
    if regularization is not None:
        next_v -= regularization
    next_q = tf.reduce_sum(next_qs * next_action, axis=-1)

    current = reward + discount * (next_v - next_c * next_q)

    # swap 'axis' with the 0-th dimension so the scan runs over time
    dims = list(range(reward.shape.ndims))
    dims = [axis] + dims[1:axis] + [0] + dims[axis + 1:]
    if axis != 0:
        next_q = tf.transpose(next_q, dims)
        current = tf.transpose(current, dims)
        discount = tf.transpose(discount, dims)
        next_c = tf.transpose(next_c, dims)
    assert_rank([current, discount, next_c])

    # backward recursion: target_t = current_t + discount_t * c_{t+1} * target_{t+1}
    target = static_scan(
        lambda acc, x: x[0] + x[1] * x[2] * acc,
        next_q[-1], (current, discount, next_c),
        reverse=True)

    if axis != 0:
        target = tf.transpose(target, dims)
    if tbo:
        target = h(target)
    return target
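# Hedged sketch (assumption, not the repo's code): the backward recursion that static_scan performs
# above, written as an explicit Python loop over a toy sequence with the time axis first.
# All tensors below are random placeholders for illustration only.
import tensorflow as tf

T, B = 5, 3
current = tf.random.normal((T, B))   # r_t + gamma_t * (V(s_{t+1}) - c_{t+1} * Q(s_{t+1}, a_{t+1}))
discount = .99 * tf.ones((T, B))     # gamma * (1 - done)
next_c = tf.random.uniform((T, B))   # lambda * min(pi/mu, ratio_clip)
next_q = tf.random.normal((T, B))    # Q(s_{t+1}, a_{t+1})

acc = next_q[-1]                     # bootstrap from the last step
targets = []
for t in reversed(range(T)):
    acc = current[t] + discount[t] * next_c[t] * acc
    targets.append(acc)
target = tf.stack(targets[::-1])     # [T, B] retrace targets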
def call(self, x, states):
    x, mask = tf.nest.flatten(x)
    h = states[0]
    assert_rank([x, h, mask], 2)
    if mask is not None:
        h = h * mask
    # a single normalization over the concatenated input and state;
    # applying separate normalizations to x and h significantly increases the running time
    x = self.x_ln(tf.matmul(tf.concat([x, h], -1), self.kernel))
    # x = self.x_ln(tf.matmul(x, self.kernel)) + self.h_ln(tf.matmul(h, self.recurrent_kernel))
    if self.use_bias:
        x = tf.nn.bias_add(x, self.bias)
    r, c, z = tf.split(x, 3, 1)
    r, z = self.recurrent_activation(r), self.recurrent_activation(z)
    c = self.activation(c)
    # note: the reset gate r is computed but not applied in this variant
    h = z * c + (1 - z) * h
    return h, GRUState(h)
def encode(self, x, state, mask, prev_action=None, prev_reward=None):
    # add a time dimension of size 1 when a single step is passed in
    if x.shape.ndims % 2 == 0:
        x = tf.expand_dims(x, 1)
    if mask.shape.ndims < 2:
        mask = tf.reshape(mask, (-1, 1))
    assert_rank(mask, 2)

    x = self.encoder(x)
    if hasattr(self, 'rnn'):
        additional_rnn_input = self._process_additional_input(
            x, prev_action, prev_reward)
        x, state = self.rnn(x, state, mask,
            additional_input=additional_rnn_input)
    else:
        state = None
    # remove the dummy time dimension before returning
    if x.shape[1] == 1:
        x = tf.squeeze(x, 1)
    return x, state
def call(self, x, states):
    x, mask = tf.nest.flatten(x)
    assert_rank([x, mask], 2)
    h, c = states
    if mask is not None:
        h = h * mask
        c = c * mask
    x = self.x_ln(tf.matmul(x, self.kernel)) \
        + self.h_ln(tf.matmul(h, self.recurrent_kernel))
    if self.use_bias:
        x = tf.nn.bias_add(x, self.bias)
    i, f, c_, o = tf.split(x, 4, 1)
    i, f, o = self.recurrent_activation(i), \
        self.recurrent_activation(f), self.recurrent_activation(o)
    c_ = self.activation(c_)
    c = f * c + i * c_
    h = o * self.activation(self.c_ln(c))
    return h, LSTMState(h, c)
def encode(self, obs, state, online=True):
    encoder = self.q_encoder if online else self.target_q_encoder
    rnn = self.q_rnn if online else self.target_q_rnn
    if obs.shape.ndims % 2 != 0:
        obs = tf.expand_dims(obs, 1)
    assert_rank(obs, 4)

    x = encoder(obs)                                            # [B, S, A, F]
    seqlen, n_agents = x.shape[1:3]
    tf.debugging.assert_equal(n_agents, self.qmixer.n_agents)
    x = tf.transpose(x, [0, 2, 1, 3])                           # [B, A, S, F]
    x = tf.reshape(x, [-1, *x.shape[2:]])                       # [B * A, S, F]
    x = rnn(x, state)
    x, state = x[0], self.State(*x[1:])
    x = tf.reshape(x, (-1, n_agents, seqlen, x.shape[-1]))      # [B, A, S, F]
    x = tf.transpose(x, [0, 2, 1, 3])                           # [B, S, A, F]
    if seqlen == 1:
        x = tf.squeeze(x, 1)
    return x, state
def _process_additional_input(self, x, prev_action, prev_reward):
    results = []
    if prev_action is not None:
        if self.actor.is_action_discrete:
            # [B] -> [B, 1] -> one-hot [B, 1, action_dim]
            if prev_action.shape.ndims < 2:
                prev_action = tf.reshape(prev_action, (-1, 1))
            prev_action = tf.one_hot(
                prev_action, self.actor.action_dim, dtype=x.dtype)
        else:
            # continuous actions are reshaped to [B, 1, action_dim]
            if prev_action.shape.ndims < 3:
                prev_action = tf.reshape(
                    prev_action, (-1, 1, self.actor.action_dim))
        assert_rank(prev_action, 3)
        results.append(prev_action)
    if prev_reward is not None:
        # rewards are reshaped or expanded to [B, 1, 1]
        if prev_reward.shape.ndims < 2:
            prev_reward = tf.reshape(prev_reward, (-1, 1, 1))
        elif prev_reward.shape.ndims == 2:
            prev_reward = tf.expand_dims(prev_reward, -1)
        assert_rank(prev_reward, 3)
        results.append(prev_reward)
    assert_rank(results, 3)
    return results