Example #1
File: nn.py Project: xlnwel/d2rl
    def call(self, qs, state):
        assert_rank(qs, 3)
        assert_rank(state, 3)
        B, seqlen = qs.shape[:2]
        tf.debugging.assert_shapes([
            [qs, (B, seqlen, self.n_agents)],
            [state, (B, seqlen, None)],
        ])
        qs = tf.reshape(qs, (-1, self.n_agents))
        state = tf.reshape(state, (-1, state.shape[-1]))

        # taking abs keeps the mixing weights non-negative, so the mixed value is monotonic in each agent's Q
        w1 = tf.math.abs(self.w1(state))
        w1 = tf.reshape(w1, (-1, self.n_agents, self.hidden_dim))
        b = self.b(state)
        h = tf.nn.elu(tf.einsum('ba,bah->bh', qs, w1) + b)
        tf.debugging.assert_shapes([
            [b, (B * seqlen, self.hidden_dim)],
            [h, (B * seqlen, self.hidden_dim)],
        ])
        w2 = tf.math.abs(self.w2(state))
        w2 = tf.reshape(w2, (-1, self.hidden_dim, 1))
        v = self.v(state)
        y = tf.einsum('bh,bho->bo', h, w2) + v
        y = tf.reshape(y, [-1, seqlen])
        tf.debugging.assert_shapes([
            [v, (B * seqlen, 1)],
            [y, (B, seqlen)],
        ])

        return y
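The assert_rank helper used throughout these examples comes from the same project and is not shown in this listing. As a rough sketch only, assuming its job is simply to check tensor ranks at the call sites above (the real implementation in xlnwel/d2rl may differ), it could look like this:

# Hypothetical sketch of the assert_rank helper used in these examples.
# Assumption: it takes a tensor or a list of tensors plus an optional expected rank;
# when no rank is given, it checks that all tensors share the same rank.
def assert_rank(tensors, rank=None):
    if not isinstance(tensors, (list, tuple)):
        tensors = [tensors]
    if rank is None:
        rank = tensors[0].shape.ndims
    for t in tensors:
        assert t.shape.ndims == rank, (t.shape, rank)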
Example #2
    def _process_additional_input(self, x, prev_action, prev_reward):
        results = []
        if prev_action is not None:
            prev_action = tf.reshape(prev_action, (-1, 1))
            prev_action = tf.one_hot(prev_action,
                                     self.actor.action_dim,
                                     dtype=x.dtype)
            results.append(prev_action)
        if prev_reward is not None:
            prev_reward = tf.reshape(prev_reward, (-1, 1, 1))
            results.append(prev_reward)
        assert_rank(results, 3)
        return results
Example #3
def quantile_regression_loss(qtv, target, tau_hat, kappa=1., return_error=False):
    assert qtv.shape[-1] == 1, qtv.shape
    assert target.shape[-2] == 1, target.shape
    assert tau_hat.shape[-1] == 1, tau_hat.shape
    assert_rank([qtv, target, tau_hat])
    error = target - qtv           # [B, N, N']
    weight = tf.abs(tau_hat - tf.cast(error < 0, tf.float32))   # [B, N, N']
    huber = huber_loss(error, threshold=kappa)                  # [B, N, N']
    qr_loss = tf.reduce_sum(tf.reduce_mean(weight * huber, axis=-1), axis=-1)  # mean over N', sum over N -> [B]

    if return_error:
        return error, qr_loss
    return qr_loss
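A hedged usage sketch of quantile_regression_loss, following the shape comments above (qtv is [B, N, 1], target is [B, 1, N'], tau_hat is [B, N, 1]). huber_loss is another project helper not shown in this listing; the stand-in below assumes the usual Huber definition with threshold kappa:

import tensorflow as tf

# Assumed stand-in for the project's huber_loss helper.
def huber_loss(x, threshold=1.):
    abs_x = tf.abs(x)
    return tf.where(abs_x <= threshold,
                    .5 * tf.square(x),
                    threshold * (abs_x - .5 * threshold))

B, N, N_prime = 32, 8, 8
qtv = tf.random.normal((B, N, 1))           # predicted quantile values
target = tf.random.normal((B, 1, N_prime))  # target quantile values
tau_hat = tf.random.uniform((B, N, 1))      # quantile fractions in (0, 1)
loss = quantile_regression_loss(qtv, target, tau_hat)  # -> shape [B]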
Example #4
File: gru.py Project: xlnwel/d2rl
    def call(self, x, state, mask, additional_input=[]):
        xs = [x] + additional_input
        mask = tf.expand_dims(mask, axis=-1)
        assert_rank(xs + [mask], 3)
        if not self._state_mask:
            # mask out inputs
            for i, v in enumerate(xs):
                xs[i] *= tf.cast(mask, v.dtype)
        x = tf.concat(xs, axis=-1) if len(xs) > 1 else xs[0]
        if not mask.dtype.is_compatible_with(global_policy().compute_dtype):
            mask = tf.cast(mask, global_policy().compute_dtype)
        x = self._rnn((x, mask), initial_state=state)
        x, state = x[0], GRUState(x[1])
        return x, state
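global_policy above is not defined in the snippet; the cast presumably keeps the mask in the layer's compute dtype under mixed precision. Assuming it is the Keras mixed-precision policy (the project may alias it differently):

# Assumption: global_policy refers to the Keras mixed-precision policy.
from tensorflow.keras import mixed_precision

policy = mixed_precision.global_policy()
print(policy.compute_dtype)  # 'float32' by default, 'float16' under a mixed_float16 policy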
Example #5
def retrace(reward, next_qs, next_action, next_pi, next_mu_a, discount, 
        lambda_=.95, ratio_clip=1, axis=0, tbo=False, regularization=None):
    """
    discount = gamma * (1-done). 
    axis specifies the time dimension
    """
    if isinstance(discount, (int, float)):
        discount = discount * tf.ones_like(reward)
    if next_action.dtype.is_integer:
        next_action = tf.one_hot(next_action, next_pi.shape[-1], dtype=next_pi.dtype)
    assert_rank_and_shape_compatibility([next_action, next_pi], reward.shape.ndims + 1)
    next_pi_a = tf.reduce_sum(next_pi * next_action, axis=-1)
    next_ratio = next_pi_a / next_mu_a
    if ratio_clip is not None:
        next_ratio = tf.minimum(next_ratio, ratio_clip)
    next_c = next_ratio * lambda_

    if tbo:
        next_qs = inverse_h(next_qs)
    next_v = tf.reduce_sum(next_qs * next_pi, axis=-1)
    if regularization is not None:
        next_v -= regularization
    next_q = tf.reduce_sum(next_qs * next_action, axis=-1)
    current = reward + discount * (next_v - next_c * next_q)

    # swap 'axis' with the 0-th dimension
    dims = list(range(reward.shape.ndims))
    dims = [axis] + dims[1:axis] + [0] + dims[axis + 1:]
    if axis != 0:
        next_q = tf.transpose(next_q, dims)
        current = tf.transpose(current, dims)
        discount = tf.transpose(discount, dims)
        next_c = tf.transpose(next_c, dims)

    assert_rank([current, discount, next_c])
    target = static_scan(
        lambda acc, x: x[0] + x[1] * x[2] * acc,
        next_q[-1], (current, discount, next_c), 
        reverse=True)

    if axis != 0:
        target = tf.transpose(target, dims)

    if tbo:
        target = h(target)
        
    return target
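static_scan is another project helper that is not shown here. Judging from the lambda above, it accumulates backwards over the leading (time) axis; a plain-Python reference loop for the same recursion (an assumption about static_scan's semantics, not its implementation) would be:

import tensorflow as tf

# Reference loop for the assumed reverse accumulation:
# target[t] = current[t] + discount[t] * next_c[t] * target[t + 1],
# bootstrapped from next_q at the final step.
def reference_retrace_targets(next_q, current, discount, next_c):
    acc = next_q[-1]
    targets = []
    for t in reversed(range(current.shape[0])):
        acc = current[t] + discount[t] * next_c[t] * acc
        targets.append(acc)
    return tf.stack(targets[::-1])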
Example #6
File: gru.py Project: xlnwel/d2rl
    def call(self, x, states):
        x, mask = tf.nest.flatten(x)
        h = states[0]
        assert_rank([x, h, mask], 2)
        if mask is not None:
            h = h * mask
        
        # applying separate normalizations to x and h significantly increases the running time
        x = self.x_ln(tf.matmul(tf.concat([x, h], -1), self.kernel))
        # x = self.x_ln(tf.matmul(x, self.kernel)) + self.h_ln(tf.matmul(h, self.recurrent_kernel))
        if self.use_bias:
            x = tf.nn.bias_add(x, self.bias)
        r, c, z = tf.split(x, 3, 1)
        r, z = self.recurrent_activation(r), self.recurrent_activation(z)
        c = self.activation(c)
        h = z * c + (1-z) * h

        return h, GRUState(h)
Example #7
    def encode(self, x, state, mask, prev_action=None, prev_reward=None):
        if x.shape.ndims % 2 == 0:
            x = tf.expand_dims(x, 1)
        if mask.shape.ndims < 2:
            mask = tf.reshape(mask, (-1, 1))
        assert_rank(mask, 2)
        x = self.encoder(x)
        if hasattr(self, 'rnn'):
            additional_rnn_input = self._process_additional_input(
                x, prev_action, prev_reward)
            x, state = self.rnn(x,
                                state,
                                mask,
                                additional_input=additional_rnn_input)
        else:
            state = None
        if x.shape[1] == 1:
            x = tf.squeeze(x, 1)
        return x, state
Example #8
    def call(self, x, states):
        x, mask = tf.nest.flatten(x)
        assert_rank([x, mask], 2)
        h, c = states
        if mask is not None:
            h = h * mask
            c = c * mask

        x = self.x_ln(tf.matmul(x, self.kernel)) + self.h_ln(
            tf.matmul(h, self.recurrent_kernel))
        if self.use_bias:
            x = tf.nn.bias_add(x, self.bias)
        i, f, c_, o = tf.split(x, 4, 1)
        i = self.recurrent_activation(i)
        f = self.recurrent_activation(f)
        o = self.recurrent_activation(o)
        c_ = self.activation(c_)
        c = f * c + i * c_
        h = o * self.activation(self.c_ln(c))

        return h, LSTMState(h, c)
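GRUState and LSTMState, returned by the recurrent cells in Examples #4, #6, and #8, are not defined in these snippets. Given how they are constructed, a plausible definition is a pair of namedtuples; this is an assumption, not the project's actual code:

import collections

# Assumed lightweight state containers; the real definitions in xlnwel/d2rl may differ.
GRUState = collections.namedtuple('GRUState', 'h')
LSTMState = collections.namedtuple('LSTMState', ['h', 'c'])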
Example #9
File: nn.py Project: xlnwel/d2rl
    def encode(self, obs, state, online=True):
        encoder = self.q_encoder if online else self.target_q_encoder
        rnn = self.q_rnn if online else self.target_q_rnn

        if obs.shape.ndims % 2 != 0:
            obs = tf.expand_dims(obs, 1)
        assert_rank(obs, 4)

        x = encoder(obs)  # [B, S, A, F]
        seqlen, n_agents = x.shape[1:3]

        tf.debugging.assert_equal(n_agents, self.qmixer.n_agents)
        x = tf.transpose(x, [0, 2, 1, 3])  # [B, A, S, F]
        x = tf.reshape(x, [-1, *x.shape[2:]])  # [B * A, S, F]
        x = rnn(x, state)
        x, state = x[0], self.State(*x[1:])
        x = tf.reshape(x, (-1, n_agents, seqlen, x.shape[-1]))  # [B, A, S, F]
        x = tf.transpose(x, [0, 2, 1, 3])  # [B, S, A, F]

        if seqlen == 1:
            x = tf.squeeze(x, 1)

        return x, state
Example #10
    def _process_additional_input(self, x, prev_action, prev_reward):
        results = []
        if prev_action is not None:
            if self.actor.is_action_discrete:
                if prev_action.shape.ndims < 2:
                    prev_action = tf.reshape(prev_action, (-1, 1))
                prev_action = tf.one_hot(prev_action,
                                         self.actor.action_dim,
                                         dtype=x.dtype)
            else:
                if prev_action.shape.ndims < 3:
                    prev_action = tf.reshape(prev_action,
                                             (-1, 1, self.actor.action_dim))
            assert_rank(prev_action, 3)
            results.append(prev_action)
        if prev_reward is not None:
            if prev_reward.shape.ndims < 2:
                prev_reward = tf.reshape(prev_reward, (-1, 1, 1))
            elif prev_reward.shape.ndims == 2:
                prev_reward = tf.expand_dims(prev_reward, -1)
            assert_rank(prev_reward, 3)
            results.append(prev_reward)
        assert_rank(results, 3)
        return results