Esempio n. 1
0
 def policy(self, trajectory, temperature=1):
   """Picks the next action to play given the trajectory so far.

   Args:
     trajectory: trajectory whose most recent `self._max_slice_length` steps
       are converted to numpy and fed to the value model.
     temperature: when 1, a random action is taken with probability
       `self._exploration_rate(self._epoch)`; any other value (0 is used
       elsewhere to trigger eval) always takes the argmax action.

   Returns:
     A pair `(sample, values)`: the chosen action and the per-action values
     of the current (last) observation. Under the 'jax' backend both are
     copied before being returned.
   """
   recent = trajectory[-self._max_slice_length:]
   recent_np = recent.to_np(timestep_to_np=self.task.timestep_to_np)
   # Prepend a batch axis of size 1 before querying the value model.
   batched_obs = recent_np.observations[None, ...]
   all_values = self._run_value_model(batched_obs, use_eval_model=False)
   # Values and observations are expected to be (batch, length, ...), where
   # length counts consecutive observations of the trajectory.
   # NOTE(review): only the batch axis is actually compared here.
   assert all_values.shape[:1] == batched_obs.shape[:1]
   # Keep the single batch entry and the values of the current observation.
   current_values = all_values[0, -1, :]
   # The random draw happens on every call, before the exploration rate is
   # queried and before temperature is checked.
   draw = np.random.random_sample()
   explore = draw < self._exploration_rate(self._epoch) and temperature == 1
   if explore:
     action = np.array(self.task.action_space.sample())
   else:
     action = jnp.argmax(current_values)  # greedy choice
   result = (action, current_values)
   if fastmath.backend_name() == 'jax':
     result = fastmath.nested_map(lambda x: x.copy(), result)
   return result
Esempio n. 2
0
 def sample(self, inputs, temperature=1.0):
   """Draws category indices from the logits encoded in `inputs`.

   Args:
     inputs: model output; `self._unflatten_inputs` turns it into logits
       (presumably restoring a per-category last axis — confirm).
     temperature: 0.0 takes the argmax over the last axis deterministically;
       any other value is forwarded to `tl.logsoftmax_sample`.

   Returns:
     The sampled (or argmax) category indices.
   """
   logits = self._unflatten_inputs(inputs)
   if temperature != 0.0:
     # LogSoftmax is unnecessary here: normalizing subtracts the same constant
     # from every logit, and sampling (a max over logits plus noise) is
     # invariant to such a shift.
     return tl.logsoftmax_sample(logits, temperature)
   return jnp.argmax(logits, axis=-1)
Esempio n. 3
0
 def f(model_output, targets, weights):  # pylint: disable=invalid-name
   """Returns the fraction of sequences whose every position is correct.

   Predictions are the argmax over the last axis of `model_output`. A
   position whose weight equals 0 is treated as padding and always counts as
   correct; a sequence is reduced over the last axis of the comparison.
   """
   predictions = jnp.argmax(model_output, axis=-1)
   shapes.assert_same_shape(predictions, targets)
   is_padding = jnp.equal(weights, 0)
   is_match = jnp.equal(predictions, targets)
   # A position passes when it matches the target or is padding.
   per_sequence_ok = jnp.all(jnp.logical_or(is_match, is_padding), axis=-1)
   return jnp.average(per_sequence_ok)
Esempio n. 4
0
 def f(model_output, targets):  # pylint: disable=invalid-name
   """Returns the mean F-beta score over categories >= initial_category_index.

   Predictions are the argmax over the last axis of `model_output`. Per-
   category precision/recall come from `_precision_recall` and are combined
   by `_f_score` with beta**2 (both helpers and `beta` /
   `initial_category_index` come from the enclosing scope).
   """
   beta_squared = beta ** 2
   predictions = jnp.argmax(model_output, axis=-1)
   # Start from an empty float array and concatenate raveled scores once;
   # this is exactly what appending them one at a time would produce.
   pieces = [jnp.empty(0)]
   for category in range(initial_category_index, model_output.shape[-1]):
     _, _, _, precision, recall = _precision_recall(
         predictions, targets, category)
     pieces.append(jnp.ravel(_f_score(precision, recall, beta_squared)))
   return jnp.mean(jnp.concatenate(pieces))
Esempio n. 5
0
    def f(model_output, targets):  # pylint: disable=invalid-name
        """Returns the mean F-beta score over categories >= initial_category_index.

        For each category k, precision and recall are totals over ALL
        positions of `predictions` / `targets` (any rank). Undefined ratios
        (0/0, i.e. category never predicted / never present) become 0.

        Args:
          model_output: logits whose last axis indexes categories.
          targets: integer category labels, shaped like the argmax of
            `model_output`.

        Returns:
          Scalar mean of the per-category F-beta scores (NaN if the category
          range is empty). `beta` and `initial_category_index` come from the
          enclosing scope.
        """
        def non_nan(x):  # pylint: disable=invalid-name
            return jnp.where(jnp.isnan(x), 0., x)

        beta2 = beta**2
        predictions = jnp.argmax(model_output, axis=-1)
        n_categories = model_output.shape[-1]
        f_scores = []
        for k in range(initial_category_index, n_categories):
            predicted_k = predictions == k
            actual_k = targets == k
            # jnp.sum, not builtin sum: builtin sum only reduces the first
            # axis (yielding per-position vectors for batched inputs) and
            # iterates element by element in Python.
            n_correct = jnp.sum(predicted_k & actual_k)
            precision = non_nan(n_correct / jnp.sum(predicted_k))
            recall = non_nan(n_correct / jnp.sum(actual_k))
            f_scores.append(non_nan((beta2 + 1) * (precision * recall) /
                                    ((beta2 * precision) + recall)))
        return jnp.mean(jnp.array(f_scores))
Esempio n. 6
0
 def f(model_output):  # pylint: disable=invalid-name
   """Returns the index of the largest entry of `model_output` along `axis`."""
   predicted = jnp.argmax(model_output, axis=axis)
   return predicted
Esempio n. 7
0
 def f(model_output, targets):  # pylint: disable=invalid-name
   """Returns the fraction of positions whose argmax prediction equals target."""
   predictions = jnp.argmax(model_output, axis=-1)
   shapes.assert_same_shape(predictions, targets)
   hits = jnp.equal(predictions, targets).sum()
   return hits / predictions.size
Esempio n. 8
0
 def f(model_output, targets, weights):  # pylint: disable=invalid-name
   """Returns accuracy of argmax predictions, weighted by `weights`.

   Computes sum(weights * correct) / sum(weights).
   """
   predictions = jnp.argmax(model_output, axis=-1)
   shapes.assert_same_shape(predictions, targets)
   weighted_hits = jnp.sum(jnp.equal(predictions, targets) * weights)
   total_weight = jnp.sum(weights)
   return weighted_hits / total_weight
Esempio n. 9
0
def gumbel_sample(log_probs, temperature=1.0):
    """Samples category indices with the Gumbel-max trick.

    Args:
      log_probs: array of log-probabilities; the last axis indexes categories.
      temperature: scale applied to the Gumbel noise; 0.0 reduces to a plain
        argmax of `log_probs`.

    Returns:
      Array of indices along the last axis of `log_probs`.
    """
    # Keep uniforms strictly inside (0, 1) so both logarithms stay finite.
    uniform = numpy.random.uniform(low=1e-6, high=1.0 - 1e-6,
                                   size=log_probs.shape)
    gumbel_noise = -np.log(-np.log(uniform))
    perturbed = log_probs + gumbel_noise * temperature
    return np.argmax(perturbed, axis=-1)
Esempio n. 10
0
    def forward(self, x):
        """Executes this layer as part of a forward pass through the model.

        A ReLU feed-forward block whose hidden units are arranged as
        `self._d1` groups of `self._d2` candidates. Mask logits predicted from
        `x` (low-rank projection `m1` then `m2`, plus bias `mb`) are
        discretized with Gumbel noise scaled by `self._temperature`, picking
        one candidate per group for every flattened input row. How the mask
        is applied depends on `self._mode` ('train', 'predict', or any other
        value).

        Args:
          x: Tensor of same shape and dtype as the input signature used to
            initialize this layer.

        Returns:
          Tensor of same shape and dtype as the input.
        """
        m1, m2, mb, w1, w2, b2 = self.weights
        if self._mode != 'predict':
            # Flatten to 2-D: w1 gets d_ff columns, w2 gets d_ff rows.
            # NOTE(review): `w1.T` reverses ALL axes before flattening. If w1
            # is stored as [d1, d2, d_model] (as the predict-mode gather
            # w1[idx1, idx2, :] suggests), its hidden units end up d2-major,
            # while quant_mask and w2 are flattened d1-major. Confirm that
            # train/eval and predict select the same weights for a given unit.
            w1 = jnp.reshape(w1.T, (-1, self._d_ff))
            w2 = jnp.reshape(w2, (self._d_ff, -1))
        x_shape = x.shape
        x = jnp.reshape(x,
                        [-1, x_shape[-1]])  # Easier to operate on flattened x.

        # Q: should we add bias and/or put relu after the low-rank m1 dot?
        # Per-row mask logits, grouped as [joint_batch, d1, d2].
        mask_logits = jnp.dot(jnp.dot(x, m1), m2) + mb
        mask_logits = jnp.reshape(mask_logits, [-1, self._d1, self._d2])
        # Softmax over the d2 candidates of each group, in log space.
        mask_logsumexp = fastmath.logsumexp(mask_logits,
                                            axis=-1,
                                            keepdims=True)
        log_mask = mask_logits - mask_logsumexp
        mask = jnp.exp(log_mask)
        # Gumbel-softmax with straight-through discretization.
        # rng1 drives the Gumbel noise; rng2 the train-mode mask choice below.
        rng1, rng2 = fastmath.random.split(self.rng, 2)
        # Uniforms are kept inside [1e-6, 1 - 1e-6] so both logs stay finite.
        u = fastmath.random.uniform(rng1, mask.shape, jnp.float32, 1e-6,
                                    1.0 - 1e-6)
        g = -jnp.log(-jnp.log(u))
        # Integer index of the chosen candidate per group: [joint_batch, d1].
        # With self._temperature == 0 this is a plain argmax of log_mask.
        quant_mask = jnp.argmax(log_mask + g * self._temperature, axis=-1)
        if self._mode == 'train':
            # Tricks from Section 2.1 in https://arxiv.org/abs/1801.09797
            # NOTE(review): one-hot depth is self._n_elements_in_block,
            # presumably equal to self._d2 — confirm.
            quant_mask = tl.one_hot(quant_mask, self._n_elements_in_block)
            quant_mask = fastmath.stop_gradient(quant_mask)
            # Forward value stays the one-hot mask; gradients flow through
            # the soft mask.
            quant_mask += mask - fastmath.stop_gradient(
                mask)  # straight-through
            # One uniform draw per call: with probability self._quant_prob the
            # quantized mask is used for the whole batch, otherwise the soft
            # mask (mixing in the soft mask is meant to improve training
            # stability; see the paper above).
            select = fastmath.random.uniform(rng2, (), jnp.float32, 0.0, 1.0)
            quant_mask = jnp.where(select < self._quant_prob, quant_mask, mask)
            quant_mask = jnp.reshape(quant_mask, [-1, self._d_ff])

        if self._mode == 'train':
            # In training, run full matmul to get benefits from the above tricks.
            mid = jnp.dot(x, w1) * quant_mask  # [joint_batch, d_ff]
            relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)  # ReLU
            res = jnp.dot(relu, w2) + b2
        elif self._mode == 'predict':
            # w1 = jnp.reshape(w1.T, (self._d1, self._d2, -1))
            # w2 = jnp.reshape(w2, (self._d1, self._d2, -1))
            # This implementation mimics inference by gathering only the
            # selected weights. It's not efficient for a large joint_batch,
            # but at inference that will be 1 most of the time.
            # Shapes:
            # quant_mask is [joint_batch, self._d1] (integer indices).
            # w1/w2 are not reshaped in this mode; the gathers below index
            # their first two axes as (d1, d2), and the einsum needs w1's last
            # axis to match x's feature axis — presumably w1 is
            # [self._d1, self._d2, d_model] (confirm against the initializer).
            # idx1 pairs every row with group indices 0..d1-1; idx2 holds the
            # selected candidate for each (row, group).
            batch_size = quant_mask.shape[0]
            idx1 = jnp.array([jnp.arange(self._d1)] * batch_size)
            # flatten indices and select from w1
            idx1 = jnp.reshape(idx1, [-1])
            idx2 = jnp.reshape(quant_mask, [-1])
            w = w1[idx1,
                   idx2, :]  # now we have per-element weights with batch dim
            w = jnp.reshape(w, [batch_size, self._d1, -1])
            # mid[a, j] = dot(x[a], w[a, j]) -> [joint_batch, d1]
            mid = jnp.einsum('ai,aji->aj', x, w)
            relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)  # ReLU
            # w2 is gathered the same way; presumably
            # [self._d1, self._d2, d_model].
            v = w2[idx1, idx2, :]
            v = jnp.reshape(v, [batch_size, self._d1, -1])
            # res[a, j] = sum_i relu[a, i] * v[a, i, j], plus bias.
            res = jnp.einsum('ai,aij->aj', relu, v) + b2
        else:
            # Any other mode: hard one-hot mask, no straight-through or soft
            # mixing (the Gumbel noise above is still applied).
            quant_mask = tl.one_hot(quant_mask, self._n_elements_in_block)
            quant_mask = jnp.reshape(quant_mask, [-1, self._d_ff])
            mid = jnp.dot(x, w1) * quant_mask  # [joint_batch, d_ff]
            relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)  # ReLU
            res = jnp.dot(relu, w2) + b2

        return jnp.reshape(res, x_shape)  # un-flatten if needed
Esempio n. 11
0
    def forward(self, x):
        """Executes this layer as part of a forward pass through the model.

        A ReLU feed-forward block whose d_ff hidden units are treated as
        `self._num_experts` contiguous blocks of `self._n_elements_in_block`
        units. Each flattened input row selects one block ("expert") from the
        logits `x @ m1`, discretized with Gumbel noise scaled by
        `self._temperature`.

        Args:
          x: Tensor of same shape and dtype as the input signature used to
            initialize this layer.

        Returns:
          Tensor of same shape and dtype as the input.
        """
        m1, w1, w2, b2 = self.weights
        x_shape = x.shape
        x = jnp.reshape(x,
                        [-1, x_shape[-1]])  # Easier to operate on flattened x.

        # Q: check if we need bias and/or put relu after the m1 dot?
        # Per-row expert logits; presumably [joint_batch, num_experts].
        mask_logits = jnp.dot(x, m1)
        # Softmax.
        mask_logsumexp = fastmath.logsumexp(mask_logits,
                                            axis=-1,
                                            keepdims=True)
        log_mask = mask_logits - mask_logsumexp
        mask = jnp.exp(log_mask)
        # Gumbel-softmax with straight-through discretization.
        # TODO(lukaszkaiser, chowdhery): Extract this block and share
        # rng1 drives the Gumbel noise; rng2 the train-mode mask choice below.
        rng1, rng2 = fastmath.random.split(self.rng, 2)
        # Uniforms are kept inside [1e-6, 1 - 1e-6] so both logs stay finite.
        u = fastmath.random.uniform(rng1, mask.shape, jnp.float32, 1e-6,
                                    1.0 - 1e-6)
        g = -jnp.log(-jnp.log(u))
        # Integer expert index per flattened row: [joint_batch].
        selected_experts = jnp.argmax(log_mask + g * self._temperature,
                                      axis=-1)
        if self._mode == 'train':
            # Tricks from Section 2.1 in https://arxiv.org/abs/1801.09797
            quant_mask = tl.one_hot(selected_experts, self._num_experts)
            quant_mask = fastmath.stop_gradient(quant_mask)
            # Forward value stays the one-hot mask; gradients flow through
            # the soft mask.
            quant_mask += mask - fastmath.stop_gradient(
                mask)  # straight-through
            # One uniform draw in [-1, 1) per call: about half of the calls use
            # the quantized mask for the whole batch, the rest the soft mask
            # (meant to improve training stability; see the paper above).
            # Q: is selecting 50% of batches the best? Other %? Mixed in-batch?
            select = fastmath.random.uniform(rng2, (), jnp.float32, -1.0, 1.0)
            quant_mask = jnp.where(select > 0.0, quant_mask, mask)
        else:
            quant_mask = tl.one_hot(selected_experts, self._num_experts)
        # [joint_batch, num_experts, 1], ready to broadcast over each block.
        quant_mask = jnp.reshape(quant_mask, [-1, self._num_experts, 1])
        quant_mask_shape = quant_mask.shape
        batch_size = quant_mask.shape[0]

        if self._mode == 'predict' and batch_size == 1:
            # This implementation mimics inference for batch_size 1: only the
            # selected expert's columns of w1 and rows of w2 are used, and
            # quant_mask itself is not applied.
            start_idx = selected_experts[0] * self._n_elements_in_block
            # w1 is [d_model, d_ff], w is [d_model, n_elements_in_block]
            w = fastmath.dynamic_slice(
                w1, [0, start_idx], [w1.shape[0], self._n_elements_in_block])
            mid = jnp.dot(x, w)
            relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)  # ReLU
            # w2 is [d_ff, d_model], v is [n_elements_in_block, d_model]
            v = fastmath.dynamic_slice(
                w2, [start_idx, 0], [self._n_elements_in_block, w2.shape[-1]])
            v = jnp.reshape(v, [self._n_elements_in_block, -1])
            res = jnp.dot(relu, v) + b2
        else:
            # Dense path: repeat each expert's mask value over its
            # n_elements_in_block units, then flatten to [joint_batch, d_ff]
            # (presumably num_experts * n_elements_in_block == d_ff).
            expanded_mask = jnp.broadcast_to(
                quant_mask, (quant_mask_shape[0], quant_mask.shape[1],
                             self._n_elements_in_block))
            expanded_mask = jnp.reshape(expanded_mask, (-1, self._d_ff))
            mid = jnp.dot(x, w1) * expanded_mask  # [joint_batch, d_ff]
            relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)  # ReLU
            res = jnp.dot(relu, w2) + b2

        return jnp.reshape(res, x_shape)  # un-flatten if needed
Esempio n. 12
0
 def f(model_output, target_category):  # pylint: disable=invalid-name
   """Returns 1.0 where the argmax along `axis` equals the target, else 0.0."""
   # TODO(pkozakowski): A shape assertion here breaks some tests. Fix and
   # re-enable: shapes.assert_same_shape(predicted_category, target_category)
   predicted_category = jnp.argmax(model_output, axis=axis)
   is_hit = jnp.equal(predicted_category, target_category)
   return is_hit.astype(jnp.float32)
Esempio n. 13
0
 def f(model_output):  # pylint: disable=invalid-name
   """Returns the predicted category: the argmax of `model_output` along `axis`."""
   return jnp.argmax(model_output, axis=axis)