示例#1
0
def _fast_matrix_shift(x, funnel_factor=1, is_upsampling=False):
    """Fast matrix shift."""

    if funnel_factor == 1 and not is_upsampling:
        shift = 1
        batch_size, n_head = x.shape[0], x.shape[1]
        queries_len, keys_len = x.shape[2], x.shape[3]
        zero_pad = jnp.zeros((batch_size, n_head, queries_len, shift))
        x = jnp.concatenate([zero_pad, x], axis=3)
        x = x.reshape(batch_size, n_head, keys_len + shift, queries_len)
        x = x[:, :, shift:, :]
        return x

    if is_upsampling:
        k = funnel_factor
        shift = 1
    else:
        k = 1
        shift = funnel_factor

    bsz, n_head = x.shape[0], x.shape[1]
    qlen, klen = x.shape[2], (x.shape[3] + 1) // 2

    zero_pad = jnp.zeros((bsz, n_head, qlen, shift))
    x = jnp.concatenate([zero_pad, x], axis=3)
    x = x.reshape(bsz, n_head, 2 * klen - 1 + shift, qlen)
    x = x[:, :, shift:, :]
    x = x.reshape(bsz, n_head, qlen, klen * 2 - 1)
    x = x[:, :, :, shift - 1:shift - 1 + klen:k]
    return x
示例#2
0
文件: sm3.py 项目: stephenjfox/trax
 def init(self, w):
   """Creates zero-initialized optimizer slots for the weight array ``w``.

   Returns:
     Tuple ``(momentum, v1s, v2s)``. ``momentum`` is zeros shaped like ``w``
     when ``self._has_momentum`` is set, else ``[]``. ``v1s`` holds one 1-D
     zero vector per axis of ``w``, sized like that axis. ``v2s`` is a second
     such list when ``self._graft`` is set, else ``[]``.
   """
   def per_axis_zeros():
     return [jnp.zeros(size, dtype=w.dtype) for size in w.shape]

   momentum = jnp.zeros_like(w) if self._has_momentum else []
   v1s = per_axis_zeros()
   v2s = per_axis_zeros() if self._graft else []
   return (momentum, v1s, v2s)
示例#3
0
def _fast_inference_init_state(input_signature, buffer_length):
    """Returns an initial state for causal attention layer fast inference."""
    def zeros_for(batch_size, shape_dtype):
        shape, dtype = shape_dtype.as_tuple()
        d_feature = shape[-1]
        return jnp.zeros((batch_size, buffer_length, d_feature), dtype=dtype)

    batch_size = input_signature[0].shape[0]
    k = zeros_for(batch_size, input_signature[1])
    v = zeros_for(batch_size, input_signature[2])
    mask = jnp.zeros((batch_size, 1, buffer_length))
    seq_indices = jnp.zeros((batch_size, ), dtype=jnp.int32)
    return (k, v, mask, seq_indices)
示例#4
0
 def init(self, weights):
     """Creates zero-initialized optimizer slots for ``weights``.

     With ``self._factored`` set and rank >= 2, two float32 accumulators are
     created, shaped ``shape[:-1]`` and ``shape[:-2] + shape[-1:]``.
     Otherwise a single accumulator shaped like ``weights`` is used. A
     momentum slot shaped like ``weights`` is appended when
     ``self._do_momentum`` is set.

     Returns:
       List of slot arrays.
     """
     shape = weights.shape
     factor_second_moment = self._factored and len(shape) >= 2
     if factor_second_moment:
         row_acc = jnp.zeros(shape[:-1], dtype=jnp.float32)
         col_acc = jnp.zeros(shape[:-2] + shape[-1:], dtype=jnp.float32)
         slots = [row_acc, col_acc]
     else:
         slots = [jnp.zeros_like(weights)]
     if self._do_momentum:
         slots.append(jnp.zeros_like(weights))
     return slots
示例#5
0
 def init_weights_and_state(self, input_signature):
   """Initializes the transform weights and, in 'predict' mode, the state.

   For ``self._transform == 'diag'`` the weights are a float32 vector of
   length ``d_feature`` filled with 1e-4. For ``'any'`` they are an
   orthogonal ``(d_feature, d_feature)`` matrix. Otherwise they are empty.
   In 'predict' mode the state is ``(int32 zeros of length batch, self.rng)``.
   """
   d_feature = input_signature.shape[-1]
   transform = self._transform
   if transform == 'diag':
     # Initialize it to a small value because JAX has a bug in softplus.
     weights = jnp.zeros((d_feature,), dtype=jnp.float32) + 1e-4
   elif transform == 'any':
     orthogonal_init = trax.layers.initializers.OrthogonalInitializer()
     weights = orthogonal_init((d_feature, d_feature), self.rng)
   else:
     weights = layer_base.EMPTY_WEIGHTS
   if self._mode == 'predict':
     per_example = jnp.zeros((input_signature.shape[0],), dtype=jnp.int32)
     self.state = (per_example, self.rng)
   self.weights = weights
示例#6
0
def ResidualZero(*layers, shortcut=None):
    """Wraps a series of layers with a ReZero-style residual connection.

    Instead of computing `(shortcut) + (output of layers)`, like in a classical
    Residual connection, ResidualZero computes
    `(shortcut) + alpha * (output of layers)`, where `alpha` is a learnable
    scalar initialized with zero.

    Args:
      *layers: One or more layers, to be applied in series.
      shortcut: If None (the usual case), the Residual layer computes the
          element-wise sum of the stack-top input with the output of the layer
          series. If specified, the `shortcut` layer applies to a copy of the
          inputs and (elementwise) adds its output to the output from the main
          layer series.

    Returns:
      A layer representing a residual connection paired with a layer series.
    """
    flat_layers = _ensure_flat(layers)
    if len(flat_layers) == 1:
        body = flat_layers[0]
    else:
        body = tl.Serial(flat_layers)
    # `alpha`: zero-initialized learnable gate multiplied onto the body output.
    alpha = tl.Weights(lambda shape, rng: jnp.zeros(shape, dtype=jnp.float32))
    # TODO(jaszczur): perhaps change inner Serial to Branch?
    gated_body = tl.Serial(body, alpha, tl.Multiply())
    return tl.Serial(
        tl.Branch(shortcut, gated_body),
        tl.Add(),  # pylint: disable=no-value-for-parameter
    )
示例#7
0
 def init_weights_and_state(self, input_signature):
   """In 'predict' mode, creates a zero cache and a zero position counter.

   The cache has shape ``(batch_size, 2 * self._total_kv_pooling, d_feature)``
   and the dtype from ``input_signature``. Other modes create no state.
   """
   if self._mode != 'predict':
     return
   shape, dtype = input_signature.as_tuple()
   batch_size, _, d_feature = shape
   cache_len = 2 * self._total_kv_pooling
   cache = jnp.zeros((batch_size, cache_len, d_feature), dtype=dtype)
   self.state = (cache, jnp.array(0))
示例#8
0
def prepare_attention_input(encoder_activations, decoder_activations, inputs):
    """Prepare queries, keys, values and mask for an attention layer.

    Args:
        encoder_activations: array (batch_size, padded_input_length, d_model);
            output of the input encoder. Used as both keys and values.
        decoder_activations: array (batch_size, padded_input_length, d_model);
            output of the pre-attention decoder. Used as queries.
        inputs: array (batch_size, padded_input_length) of padded input
            tokens; token id 0 is treated as padding.

    Returns:
        Tuple (queries, keys, values, mask). The mask has shape
        (batch_size, 1, decoder_len, encoder_len): 1 for real tokens,
        0 for padding.
    """
    # Token id 0 marks padding; every other id is a real token.
    is_token = inputs != 0
    batch, encoder_len = is_token.shape[0], is_token.shape[1]
    decoder_len = decoder_activations.shape[1]

    # Insert singleton axes for attention heads and decoder length.
    mask = fastnp.reshape(is_token, (batch, 1, 1, encoder_len))
    # Adding zeros broadcasts over the decoder-length axis (and yields floats).
    mask = mask + fastnp.zeros((1, 1, decoder_len, 1))

    queries = decoder_activations
    keys = values = encoder_activations
    return queries, keys, values, mask
示例#9
0
 def f(x):  # pylint: disable=invalid-name
   """Returns float32 zeros of shape (batch, depth_multiplier * depth).

   Raises:
     ValueError: if ``x`` is not rank 3.
   """
   if len(x.shape) == 3:
     batch_size, feature_depth = x.shape[0], x.shape[-1]
     return jnp.zeros((batch_size, depth_multiplier * feature_depth),
                      dtype=jnp.float32)
   raise ValueError(f'Layer input should be a rank 3 tensor representing'
                    f' (batch_size, sequence_length, feature_depth); '
                    f'instead got shape {x.shape}.')
示例#10
0
  def init_weights_and_state(self, input_signature):
    """Initializes sinusoidal positional encodings and an optional projection.

    The encoding table is deterministic: for ``self._max_len`` positions it
    holds sine values in even feature columns and cosine values in odd ones.
    When ``self._d_feature`` is set, the table uses that depth and a
    Glorot-uniform ``(d_feature, input_depth)`` projection is drawn from
    ``self.rng``. In 'predict' mode a scalar int32 zero is stored as state.
    Odd feature depths are supported.

    Args:
      input_signature: :py:class:`ShapeDtype` instance characterizing the input
          this layer should compute on.
    """
    d_feature = input_signature.shape[-1]
    if self._d_feature is not None:
      d_feature = self._d_feature
    pe = np.zeros((self._max_len, d_feature), dtype=np.float32)
    position = np.arange(0, self._max_len)[:, np.newaxis]
    div_term = np.exp(
        np.arange(0, d_feature, 2) * -(np.log(10000.0) / d_feature))
    pe[:, 0::2] = np.sin(position * div_term)
    # An odd d_feature has one fewer odd column than even column; trim
    # div_term so the cosine assignment still broadcasts (no-op when even).
    pe[:, 1::2] = np.cos(position * div_term[:d_feature // 2])
    # pe: [self._max_len, d_feature]
    if self._use_bfloat16:
      pe = pe.astype(jnp.bfloat16)
    w = jnp.array(pe)  # Trainable parameters, initialized above.
    if self._d_feature is not None:
      ff = init.GlorotUniformInitializer()(
          (d_feature, input_signature.shape[-1]), self.rng)
      self.weights = w, ff
    else:
      self.weights = w
    if self._mode == 'predict':
      self.state = jnp.zeros((), dtype=jnp.int32)
示例#11
0
def NoUpsampling(shorten_factor, d_model, *args, **kwargs):
    """Returns a 'ReturnZero' layer that outputs zeros instead of upsampling.

    For an input of shape (B, L, D) the output is zeros of shape
    (B, L * shorten_factor, D) with the input's dtype. ``d_model`` and any
    extra arguments are accepted and ignored (presumably so this can stand
    in for other upsampling constructors).
    """
    del d_model, args, kwargs

    def zeros_like_upsampled(x):
        batch, length, depth = x.shape[0], x.shape[1], x.shape[2]
        return jnp.zeros((batch, length * shorten_factor, depth), dtype=x.dtype)

    return core.Fn('ReturnZero', zeros_like_upsampled)
示例#12
0
    def init_weights_and_state(self, input_signature):
        """Creates the predict-mode position counter; no weights are set here.

        In 'predict' mode the state becomes a scalar int32 zero. In any other
        mode nothing is initialized.

        Args:
          input_signature: `ShapeDtype` instance characterizing the input this
              layer should compute on (not read by this method).
        """
        if self._mode != 'predict':
            return
        self.state = jnp.zeros((), dtype=jnp.int32)
示例#13
0
文件: rnn.py 项目: yangliuy/trax
 def F(encoder_activations, decoder_activations, input_tokens):
     """Returns (queries, keys, values, mask) for encoder-decoder attention.

     Keys and values are the encoder activations; queries are the decoder
     activations. The mask is 1 where ``input_tokens`` is non-zero (not
     padding) and 0 elsewhere, broadcast to
     ``[batch, 1, decoder_len, encoder_len]``.
     """
     batch, encoder_len = input_tokens.shape[0], input_tokens.shape[1]
     decoder_len = decoder_activations.shape[1]
     # Singleton axes for attention heads and decoder length.
     not_padding = jnp.reshape(input_tokens != 0, (batch, 1, 1, encoder_len))
     # Adding zeros broadcasts the mask over decoder positions.
     mask = not_padding + jnp.zeros((1, 1, decoder_len, 1))
     return decoder_activations, encoder_activations, encoder_activations, mask
示例#14
0
def _fast_matrix_shift(x):
  # Implements necessary shift for relative positional attention calculations.
  shift = 1
  batch_size, n_head = x.shape[0], x.shape[1]
  queries_len, keys_len = x.shape[2], x.shape[3]
  zero_pad = jnp.zeros((batch_size, n_head, queries_len, shift))
  x = jnp.concatenate([zero_pad, x], axis=3)
  x = x.reshape(batch_size, n_head, keys_len + shift, queries_len)
  x = x[:, :, shift:, :]
  return x
示例#15
0
    def _run_value_model(self, observations, dist_inputs):
        """Runs the value network, optionally as a Q-function over actions.

        Args:
          observations: batch of observations; ``observations.shape[:2]`` is
            the leading shape of the default ``dist_inputs``.
          dist_inputs: policy distribution inputs, or None to use zeros of
            shape ``observations.shape[:2] + (self._policy_dist.n_inputs,)``.

        Returns:
          Tuple ``(values, actions, log_probs)``. ``values`` is the value-model
          output multiplied by ``self._value_network_scale`` with its last
          axis squeezed. When ``self._q_value`` is false, ``actions`` and
          ``log_probs`` are None. Otherwise ``actions`` are either all discrete
          actions (``self._sample_all_discrete_actions``) or samples from
          ``self._policy_dist``, and ``log_probs`` are their log-probabilities.
          In both cases the sample axis has been swapped with the batch axis.
        """
        if dist_inputs is None:
            dist_inputs = jnp.zeros(observations.shape[:2] +
                                    (self._policy_dist.n_inputs, ))

        actions = None
        if self._q_value:
            if self._sample_all_discrete_actions:
                # Since we want to sample all actions, start by creating their list.
                act = np.arange(self._vocab_size)
                # Now act is a vector [0, ..., vocab_size-1], but we'll need to tile it.
                # Add extra dimensions so it's the same dimensionality as dist_inputs.
                act = jnp.reshape(act,
                                  [-1] + [1] * (len(dist_inputs.shape) - 1))
                # Now act is [vocab_size, 1, ..., 1], dimensionality of dist_inputs.
            # Prepend a sample axis of size self._q_value_n_samples.
            dist_inputs = jnp.broadcast_to(
                dist_inputs, (self._q_value_n_samples, ) + dist_inputs.shape)
            if self._sample_all_discrete_actions:
                actions = act + jnp.zeros(dist_inputs.shape[:-1],
                                          dtype=jnp.int32)
                actions = jnp.swapaxes(actions, 0, 1)
            # Swapping the n_samples and batch_size axes, so the input is split
            # between accelerators along the batch_size axis.
            dist_inputs = jnp.swapaxes(dist_inputs, 0, 1)
            if not self._sample_all_discrete_actions:
                actions = self._policy_dist.sample(dist_inputs)
            log_probs = self._policy_dist.log_prob(dist_inputs, actions)
            obs = observations
            # Insert a singleton axis at position 1 (presumably so observations
            # broadcast against the per-sample actions; confirm in the model).
            obs = jnp.reshape(obs, [obs.shape[0], 1] + list(obs.shape[1:]))
            inputs = (obs, actions)
        else:
            log_probs = None
            inputs = (observations, )

        # Replicate the eval model's weights and state across all devices.
        n_devices = fastmath.device_count()
        weights = tl.for_n_devices(self._value_eval_model.weights, n_devices)
        state = tl.for_n_devices(self._value_eval_model.state, n_devices)
        rng = self._value_eval_model.rng
        values, _ = self._value_eval_jit(inputs, weights, state, rng)
        values *= self._value_network_scale
        values = jnp.squeeze(values,
                             axis=-1)  # Remove the singleton depth dim.
        return (values, actions, log_probs)
示例#16
0
    def forward(self, inputs):
        """Returns the input activations, with added positional information.

        Outside 'predict' mode, ``self.weights[:, :L, :]`` (``L`` =
        ``inputs.shape[1]``) is added to ``inputs``. In 'train' mode with
        ``self._start_from_zero_prob < 1.0``, the slice instead starts at a
        random offset from ``fastmath.random.randint(0,
        self._max_offset_to_add)``; that offset is reset to 0 with probability
        ``self._start_from_zero_prob``. With non-zero dropout, encodings are
        kept with probability ``1 - self._dropout`` and rescaled by
        ``1 / keep_prob``. The mask is broadcast over
        ``self._dropout_broadcast_dims``.

        In 'predict' mode ``self.state`` holds per-example positions (it is
        indexed as ``state[i]`` below) and is advanced by ``inputs.shape[1]``.

        Raises:
          ValueError: in 'predict' mode when dropout is non-zero.
        """
        if self._mode != 'predict':
            x = inputs
            symbol_size = jnp.shape(x)[1]
            if self._mode != 'train' or self._start_from_zero_prob >= 1.0:
                px = self.weights[:, :symbol_size, :]
            else:
                rng1, rng2 = fastmath.random.split(self.rng, 2)
                start = fastmath.random.randint(rng1, (), 0,
                                                self._max_offset_to_add)
                start_from_zero = fastmath.random.uniform(
                    rng2, (), jnp.float32, 0, 1)
                # With probability _start_from_zero_prob, use offset 0.
                start = jnp.where(start_from_zero < self._start_from_zero_prob,
                                  jnp.zeros((), dtype=jnp.int32), start)
                px = fastmath.dynamic_slice_in_dim(self.weights,
                                                   start,
                                                   symbol_size,
                                                   axis=1)
            if self._dropout == 0:
                return x + px
            else:
                noise_shape = list(px.shape)
                # Size-1 dims share one dropout decision along that axis.
                for dim in self._dropout_broadcast_dims:
                    noise_shape[dim] = 1
                keep_prob = 1.0 - self._dropout
                keep = fastmath.random.bernoulli(self.rng, keep_prob,
                                                 tuple(noise_shape))
                multiplier = keep.astype(x.dtype) / keep_prob
                return x + px * multiplier
        else:
            if self._dropout != 0:
                raise ValueError(f'In predict mode, but dropout rate '
                                 f'({self._dropout}) is not zero.')

            # State in this class is only used for fast inference. In that case,
            # the model is called with consecutive elements position-by-position.
            # This positional encoding layer needs to store the index of the current
            # position then and increment it on each call -- that's how state is used
            # and updated below.
            state = self.state
            if inputs.shape[1] == 1:
                self.state = state + 1
                # Gather one encoding row per example; add a length-1 axis 1.
                return inputs + jnp.expand_dims(self.weights[0, state, :], 1)
            else:
                # Multi-token step: slice a window per example starting at its
                # own stored position, then stack along the batch axis.
                emb = []
                for i in range(inputs.shape[0]):
                    emb.append(
                        fastmath.dynamic_slice_in_dim(self.weights[0],
                                                      state[i],
                                                      inputs.shape[1],
                                                      axis=0))
                self.state = state + inputs.shape[1]
                res = inputs + jnp.stack(emb, 0)
                return res
示例#17
0
 def f(decoder_input, mask):
     """Broadcasts an encoder padding mask over decoder positions.

     Returns ``mask`` reshaped to ``(batch, 1, 1, encoder_len)`` plus zeros of
     shape ``(1, 1, decoder_len, 1)``, i.e. ``(batch, 1, decoder_len,
     encoder_len)``.

     Raises:
       ValueError: if ``decoder_input`` is not rank 3.
     """
     if len(decoder_input.shape) == 3:
         batch_size, encoder_len = mask.shape[0], mask.shape[-1]
         decoder_len = decoder_input.shape[1]
         expanded = mask.reshape((batch_size, 1, 1, encoder_len))
         return expanded + jnp.zeros((1, 1, decoder_len, 1))
     raise ValueError(
         f'Decoder input to EncoderDecoderMask must be a rank 3 tensor with '
         f'shape (batch_size, decoder_sequence_length, d_model); instead got '
         f'shape {decoder_input.shape}.')
示例#18
0
 def init_weights_and_state(self, input_signature):
   """Helper to initialize batch norm weights and state.

   Weights are ``(beta, gamma)``: float32 zeros/ones over the non-normalized
   axes, or ``()`` when centering/scaling is disabled. State is
   ``(running_mean, running_var, n_batches)``. The running statistics keep
   the input rank with the normalized axes collapsed to size 1.
   """
   axis = self._axis
   axes = (axis,) if jnp.isscalar(axis) else axis
   dims = input_signature.shape
   # Parameter shape: input dims with the normalized axes removed.
   param_shape = tuple(size for idx, size in enumerate(dims) if idx not in axes)
   # Statistics shape: same rank as the input, normalized axes set to 1.
   stats_shape = tuple(1 if idx in axes else size for idx, size in enumerate(dims))
   # TODO(jonni): Should beta and gamma match the dtype in the input signature?
   beta = jnp.zeros(param_shape, dtype='float32') if self._center else ()
   gamma = jnp.ones(param_shape, dtype='float32') if self._scale else ()
   running_mean = jnp.zeros(stats_shape, dtype=jnp.float32)
   running_var = jnp.ones(stats_shape, dtype=jnp.float32)
   n_batches = jnp.zeros((), dtype=jnp.int64)
   self.weights = (beta, gamma)
   self.state = (running_mean, running_var, n_batches)
示例#19
0
    def test_custom_initializer_shape(self):
        def zeros_init(shape, rng):
            del rng  # The zero initializer ignores randomness.
            return jnp.zeros(shape, dtype=jnp.float32)

        zeros_layer = tl.Weights(zeros_init, (2, 2))
        zeros_layer.init(())
        self.assertEqual(zeros_layer(()).tolist(), [[0., 0.], [0., 0.]])

        random_layer = tl.Weights(init.RandomNormalInitializer(), (2, 2))
        random_layer.init(())
        sample = random_layer(())
        self.assertEqual(sample.shape, (2, 2))
        self.assertNotEqual(sample.tolist(), [[0., 0.], [0., 0.]])
示例#20
0
文件: attention.py 项目: google/trax
    def forward(self, inputs):
        """Returns the input activations, with added positional information.

        When ``self._d_feature`` is set, ``self.weights`` is ``(table, ff)``.
        The whole table is projected by ``ff`` before slicing, so offsets
        beyond ``inputs.shape[1]`` (train-time random start, predict-mode
        position) see the correct rows.

        Outside 'predict' mode the encodings for positions
        ``[start, start + inputs.shape[1])`` are added. ``start`` is 0, except
        in 'train' mode with ``self._start_from_zero_prob < 1.0``, where it is
        random and reset to 0 with probability ``self._start_from_zero_prob``.
        Optional dropout rescales kept encodings by ``1 / keep_prob``. In
        'predict' mode the encodings starting at ``self.state`` are added and
        the state advances by ``inputs.shape[1]``.

        Raises:
          ValueError: in 'predict' mode when dropout is non-zero.
        """
        weights = self.weights
        if self._d_feature is not None:
            weights, ff = weights
            # Project the full table, not only its first inputs.shape[1] rows:
            # the random train-time offset and the predict-mode position below
            # slice past that prefix, and an out-of-range dynamic-slice start is
            # clamped (under JAX), which would silently reuse the encodings of
            # the first positions.
            weights = jnp.dot(weights, ff)
        if len(weights.shape
               ) < 3:  # old checkpoints have 1 in first dim already
            weights = weights[None, :, :]  # [1, self._max_len, d_feature]
        if self._mode != 'predict':
            x = inputs
            symbol_size = jnp.shape(x)[1]
            if self._mode != 'train' or self._start_from_zero_prob >= 1.0:
                px = weights[:, :symbol_size, :]
            else:
                rng1, rng2 = fastmath.random.split(self.rng, 2)
                start = fastmath.random.randint(rng1, (), 0,
                                                self._max_offset_to_add)
                start_from_zero = fastmath.random.uniform(
                    rng2, (), jnp.float32, 0, 1)
                start = jnp.where(start_from_zero < self._start_from_zero_prob,
                                  jnp.zeros((), dtype=jnp.int32), start)
                px = fastmath.dynamic_slice_in_dim(weights,
                                                   start,
                                                   symbol_size,
                                                   axis=1)
            if self._dropout == 0:
                return x + px
            else:
                noise_shape = list(px.shape)
                # Size-1 dims share one dropout decision along that axis.
                for dim in self._dropout_broadcast_dims:
                    noise_shape[dim] = 1
                keep_prob = 1.0 - self._dropout
                keep = fastmath.random.bernoulli(self.rng, keep_prob,
                                                 tuple(noise_shape))
                multiplier = keep.astype(x.dtype) / keep_prob
                return x + px * multiplier
        else:
            if self._dropout != 0:
                raise ValueError(f'In predict mode, but dropout rate '
                                 f'({self._dropout}) is not zero.')

            # State in this class is only used for fast inference. In that case,
            # the model is called with consecutive elements position-by-position.
            # This positional encoding layer stores the index of the current
            # position and increments it on each call.
            emb = fastmath.dynamic_slice_in_dim(weights,
                                                self.state,
                                                inputs.shape[1],
                                                axis=1)
            self.state += inputs.shape[1]
            return inputs + emb
示例#21
0
    def favor(query, key, value):
        """Computes FAVOR attention as a prefix-sum numerator over a denominator.

        ``relu`` feature maps, shifted by ``numerical_stabilizer``, are applied
        to queries and keys. Axis 1 is moved to the front for the
        ``favor_numerator``/``favor_denominator`` helpers (presumably so they
        iterate over that axis) and moved back afterwards. The result is
        numerator * (1 / denominator) with a trailing singleton axis on the
        denominator.
        """
        q_features = relu(query) + numerical_stabilizer
        k_features = relu(key) + numerical_stabilizer
        batch, feature_dim = key.shape[0], key.shape[-1]
        numerator_init = jnp.zeros((batch, feature_dim, value.shape[-1]))
        denominator_init = jnp.zeros((batch, feature_dim))

        q_seq = jnp.moveaxis(q_features, 1, 0)
        k_seq = jnp.moveaxis(k_features, 1, 0)
        v_seq = jnp.moveaxis(value, 1, 0)

        numerator = favor_numerator(numerator_init, precision, q_seq, k_seq,
                                    v_seq)
        denominator = favor_denominator(denominator_init, precision, q_seq,
                                        k_seq)
        numerator = jnp.moveaxis(numerator, 0, 1)
        inv_denominator = jnp.reciprocal(jnp.moveaxis(denominator, 0, 1))
        return numerator * inv_denominator[..., None]
示例#22
0
    def test_weights_and_state_signature(self):

        class MyLayer(tl.Layer):
            """Toy layer: (2, 3) zero weights, state shaped like the input."""

            def init_weights_and_state(self, input_signature):
                self.weights = jnp.zeros((2, 3))
                self.state = jnp.ones(input_signature.shape)

            def forward(self, inputs):
                return self.weights + self.state

        weights_sig, state_sig = MyLayer().weights_and_state_signature(
            jnp.zeros((3, 4)))
        self.assertEqual(weights_sig.shape, (2, 3))
        self.assertEqual(state_sig.shape, (3, 4))
示例#23
0
    def init_weights_and_state(self, input_signature):
        """Creates per-channel scale/offset weights and the optional epsilon.

        The channel count is the last axis of ``input_signature.shape``.
        Weights are ``(gamma, beta, epsilon_l)``: float32 ones, float32 zeros,
        and either ``(self._init_learnt_epsilon,)`` when epsilon is learned or
        empty weights otherwise.
        """
        # Usually (B, W, H, C)
        channel_shape = (input_signature.shape[-1], )
        gamma = jnp.ones(channel_shape, dtype=jnp.float32)
        beta = jnp.zeros(channel_shape, dtype=jnp.float32)
        if self._learn_epsilon:
            epsilon_l = (self._init_learnt_epsilon, )
        else:
            epsilon_l = base.EMPTY_WEIGHTS
        self.weights = gamma, beta, epsilon_l
示例#24
0
 def init_weights_and_state(self, input_signature):
   """Initializes sinusoidal positional encodings and predict-mode state.

   Builds a deterministic ``[1, self._max_len, d_feature]`` table, with sine
   values in even feature columns and cosine values in odd ones, and stores
   it as trainable weights. Odd feature depths are supported. In 'predict'
   mode the state becomes an int32 zero vector of length batch_size.

   Args:
     input_signature: object whose ``shape`` gives the batch size (axis 0)
       and the feature depth (last axis).
   """
   d_feature = input_signature.shape[-1]
   pe = np.zeros((self._max_len, d_feature), dtype=np.float32)
   position = np.arange(0, self._max_len)[:, np.newaxis]
   div_term = np.exp(
       np.arange(0, d_feature, 2) * -(np.log(10000.0) / d_feature))
   pe[:, 0::2] = np.sin(position * div_term)
   # An odd d_feature has one fewer odd column than even column; trim
   # div_term so the cosine assignment still broadcasts (no-op when even).
   pe[:, 1::2] = np.cos(position * div_term[:d_feature // 2])
   pe = pe[np.newaxis, :, :]  # [1, self._max_len, d_feature]
   self.weights = jnp.array(pe)  # Trainable parameters, initialized above.
   if self._mode == 'predict':
     batch_size = input_signature.shape[0]
     self.state = jnp.zeros((batch_size,), dtype=jnp.int32)
示例#25
0
 def init_weights_and_state(self, input_signature):
     """Initializes per-base, per-digit weights plus a start-position vector.

     The feature depth (last axis of ``input_signature.shape``) must be
     divisible by ``self._n_digits``. For every entry of ``self._bases``, a
     list of ``self._n_digits`` weights of shape ``(1, d_feature // n_digits)``
     is drawn with ``self._initializer``. In 'predict' mode a scalar int32
     zero is stored as state.
     """
     d_feature = input_signature.shape[-1]
     assert d_feature % self._n_digits == 0
     digit_width = d_feature // self._n_digits
     digits_rng, start_rng = fastmath.random.split(self.rng, 2)

     def digit_weights():
         # NOTE(review): every base re-splits the same key, so all bases get
         # the same per-digit keys -- confirm this is intended.
         return [self._initializer((1, digit_width), key)
                 for key in fastmath.random.split(digits_rng, self._n_digits)]

     base_weights = [digit_weights() for _ in self._bases]
     # Special vector to mark the starting position.
     start_vec = self._initializer((1, 1, d_feature), start_rng)
     self.weights = (base_weights, start_vec)
     if self._mode == 'predict':
         self.state = jnp.zeros((), dtype=jnp.int32)
示例#26
0
def prepare_attention_input(encoder_activations, decoder_activations, inputs):
    """Returns (queries, keys, values, mask) for encoder-decoder attention.

    Keys and values are the encoder activations; queries are the decoder
    activations. The mask is 1 for real tokens and 0 for padding (token id 0),
    with shape [batch size, attention heads (1), decoder-len, encoder-len].
    """
    batch, encoder_len = inputs.shape[0], inputs.shape[1]
    decoder_len = decoder_activations.shape[1]
    # Distinguish real tokens from padding, then add head/decoder-length axes.
    token_mask = fastnp.reshape(inputs != 0, (batch, 1, 1, encoder_len))
    # Broadcast over decoder positions by adding zeros.
    mask = token_mask + fastnp.zeros((1, 1, decoder_len, 1))
    return decoder_activations, encoder_activations, encoder_activations, mask
示例#27
0
文件: attention.py 项目: google/trax
def _fast_inference_init_state(input_signature,
                               buffer_length,
                               predict_mask=None):
    """Returns an initial state for causal attention layer fast inference."""
    def zeros_for(batch_size, shape_dtype):
        shape, dtype = shape_dtype.as_tuple()
        d_feature = shape[-1]
        return jnp.zeros((batch_size, buffer_length, d_feature), dtype=dtype)

    batch_size = input_signature[0].shape[0]
    k = zeros_for(batch_size, input_signature[1])
    v = zeros_for(batch_size, input_signature[2])
    if predict_mask is not None:
        mask_for_predict = jnp.zeros((buffer_length, )) != 0
        return (mask_for_predict, k, v, jnp.array(0))
    else:
        return (k, v, jnp.array(0))
示例#28
0
 def policy_inputs(self, trajectory, values):
     """Create inputs to policy model from a TrajectoryNp and values.

     Observations, actions, mask and (if present) dist_inputs are trimmed
     along axis 1 to the length of the computed advantages. Missing
     dist_inputs are replaced by zeros of shape
     ``advantages.shape + (self._policy_dist.n_inputs,)``.

     Returns:
       Tuple ``(obs, act, advantages, dist_inputs, mask)``.

     Raises:
       ValueError: if advantages are not rank 2, or any trimmed input's
         leading dims (the full shape, for the mask) differ from the
         advantages shape.
     """
     # How much TD to use is determined by the added policy slice length,
     # as the policy batches need to be this much longer to calculate TD.
     advantages = self._advantage_estimator(
         rewards=trajectory.rewards,
         returns=trajectory.returns,
         values=values,
         dones=trajectory.dones,
         gamma=self._task.gamma,
         n_extra_steps=self._added_policy_slice_length,
     )
     # Trim everything to the (possibly shorter) advantages length.
     length = advantages.shape[1]

     def trim(arr):
         return arr[:, :length]

     obs = trim(trajectory.observations)
     act = trim(trajectory.actions)
     mask = trim(trajectory.mask)  # Mask to zero-out padding.
     if trajectory.dist_inputs is None:
         dist_inputs = jnp.zeros(advantages.shape +
                                 (self._policy_dist.n_inputs, ))
     else:
         dist_inputs = trim(trajectory.dist_inputs)

     # Shape checks to help debugging.
     if len(advantages.shape) != 2:
         raise ValueError(
             'Advantages are expected to have shape [batch_size, length], '
             'got: %s' % str(advantages.shape))
     named_inputs = (('actions', act), ('observations', obs),
                     ('dist_inputs', dist_inputs))
     for label, arr in named_inputs:
         if arr.shape[:2] != advantages.shape:
             raise ValueError(
                 'First 2 dimensions of %s should be the same as in '
                 'advantages, %s != %s' % (label, arr.shape[:2],
                                           advantages.shape))
     if mask.shape != advantages.shape:
         raise ValueError('Mask and advantages shapes should be the same'
                          ', %s != %s' % (mask.shape, advantages.shape))
     return (obs, act, advantages, dist_inputs, mask)
示例#29
0
  def init_weights_and_state(self, input_signature):
    """Initializes sinusoidal positional encoding vectors and predict state.

    The encoding table is deterministic, not random: a
    ``[1, self._max_len, d_feature]`` array with sine values in even feature
    columns and cosine values in odd ones, stored as trainable weights. Odd
    feature depths are supported. In 'predict' mode the state becomes an int32
    zero vector of length batch_size.

    Args:
      input_signature: `ShapeDtype` instance characterizing the input this
          layer should compute on.
    """
    d_feature = input_signature.shape[-1]
    pe = np.zeros((self._max_len, d_feature), dtype=np.float32)
    position = np.arange(0, self._max_len)[:, np.newaxis]
    div_term = np.exp(
        np.arange(0, d_feature, 2) * -(np.log(10000.0) / d_feature))
    pe[:, 0::2] = np.sin(position * div_term)
    # An odd d_feature has one fewer odd column than even column; trim
    # div_term so the cosine assignment still broadcasts (no-op when even).
    pe[:, 1::2] = np.cos(position * div_term[:d_feature // 2])
    pe = pe[np.newaxis, :, :]  # [1, self._max_len, d_feature]
    self.weights = jnp.array(pe)  # Trainable parameters, initialized above.
    if self._mode == 'predict':
      batch_size = input_signature.shape[0]
      self.state = jnp.zeros((batch_size,), dtype=jnp.int32)
示例#30
0
    def test_forward(self):
        doubler = tl.PureLayer(lambda x: 2 * x)

        # Path 1: Layer.__call__ with explicitly supplied weights; the layer's
        # own weights must remain empty afterwards.
        via_call = doubler(np.array([1, 2]), weights=jnp.zeros((2, 3)))
        self.assertEqual(via_call.tolist(), [2, 4])
        self.assertEmpty(doubler.weights)

        # Path 2: PureLayer.forward directly.
        via_forward = doubler.forward(np.array([3, 4]))
        self.assertEqual(via_forward.tolist(), [6, 8])

        # Path 3: Layer.pure_fn returns a pair; only the output is checked.
        via_pure_fn, _ = doubler.pure_fn(np.array([5, 6]), tl.EMPTY_WEIGHTS,
                                         tl.EMPTY_WEIGHTS, None)
        self.assertEqual(via_pure_fn.tolist(), [10, 12])