Example #1
    def compute_metrics(
        masked_lm_logits: jnp.ndarray,
        next_sentence_logits: jnp.ndarray,
        masked_lm_labels: jnp.ndarray,
        masked_lm_weights: jnp.ndarray,
        next_sentence_labels: jnp.ndarray,
    ):
        """Computes the pre-training loss and its components."""
        masked_lm_logits = nn.log_softmax(masked_lm_logits)
        masked_lm_labels = onehot(masked_lm_labels.reshape((-1, )),
                                  masked_lm_logits.shape[-1])
        masked_lm_weights = masked_lm_weights.reshape((-1, ))
        masked_lm_loss = -jnp.sum(
            jnp.sum(masked_lm_logits * masked_lm_labels, axis=-1) *
            masked_lm_weights) / jnp.sum(masked_lm_weights)

        next_sentence_logits = nn.log_softmax(next_sentence_logits)
        next_sentence_labels = next_sentence_labels.reshape((-1, ))
        next_sentence_loss = -jnp.mean(
            jnp.sum(
                onehot(next_sentence_labels, next_sentence_logits.shape[-1]) *
                next_sentence_logits,
                axis=-1,
            ))
        return {
            "loss": masked_lm_loss + next_sentence_loss,
            "masked_lm_loss": masked_lm_loss,
            "next_sentence_loss": next_sentence_loss,
        }
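A minimal usage sketch for the metrics function above, treating it as a free function and calling it on toy shapes; jax.nn.one_hot stands in for the `onehot` helper it expects, and every name and value below is illustrative, not from the original.

import jax
import jax.numpy as jnp
import flax.linen as nn  # provides the nn.log_softmax used above

onehot = jax.nn.one_hot  # assumed stand-in for the onehot helper above

batch, num_masked, vocab = 2, 3, 11
key = jax.random.PRNGKey(0)
masked_lm_logits = jax.random.normal(key, (batch * num_masked, vocab))
next_sentence_logits = jax.random.normal(key, (batch, 2))
masked_lm_labels = jnp.zeros((batch, num_masked), dtype=jnp.int32)
masked_lm_weights = jnp.ones((batch, num_masked))
next_sentence_labels = jnp.ones((batch,), dtype=jnp.int32)

metrics = compute_metrics(masked_lm_logits, next_sentence_logits,
                          masked_lm_labels, masked_lm_weights,
                          next_sentence_labels)
print({k: float(v) for k, v in metrics.items()})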
Example #2
  def __call__(self, x):
    """Define the convolutional network architecture.

    Architecture originates from "Human-level control through deep reinforcement
    learning.", Nature 518, no. 7540 (2015): 529-533.
    Note that this is different from the one in "Playing Atari with Deep
    Reinforcement Learning", arxiv.org/abs/1312.5602 (2013).

    Network is used to both estimate policy (logits) and expected state value;
    in other words, hidden layers' params are shared between policy and value
    networks, see e.g.:
    github.com/openai/baselines/blob/master/baselines/ppo1/cnn_policy.py
    """
    dtype = jnp.float32
    x = x.astype(dtype) / 255.
    x = nn.Conv(features=32, kernel_size=(8, 8), strides=(4, 4), name='conv1',
                dtype=dtype)(x)
    x = nn.relu(x)
    x = nn.Conv(features=64, kernel_size=(4, 4), strides=(2, 2), name='conv2',
                dtype=dtype)(x)
    x = nn.relu(x)
    x = nn.Conv(features=64, kernel_size=(3, 3), strides=(1, 1), name='conv3',
                dtype=dtype)(x)
    x = nn.relu(x)
    x = x.reshape((x.shape[0], -1))  # flatten
    x = nn.Dense(features=512, name='hidden', dtype=dtype)(x)
    x = nn.relu(x)
    logits = nn.Dense(features=self.num_outputs, name='logits', dtype=dtype)(x)
    policy_log_probabilities = nn.log_softmax(logits)
    value = nn.Dense(features=1, name='value', dtype=dtype)(x)
    return policy_log_probabilities, value
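A sketch of consuming the two heads returned above: jax.random.categorical samples actions directly from logits or log-probabilities, and the value head is a [batch, 1] state-value estimate. The toy arrays below merely stand in for the module's outputs.

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
# Stand-ins with the same shapes as the module's outputs (batch of 4, 6 actions).
policy_log_probabilities = jax.nn.log_softmax(jax.random.normal(key, (4, 6)))
value = jax.random.normal(key, (4, 1))

actions = jax.random.categorical(key, policy_log_probabilities, axis=-1)
print(actions.shape, value.squeeze(-1).shape)  # (4,) (4,)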
Example #3
def compute_per_pos_loss(logits,
                         targets,
                         weights=None,
                         label_smoothing=0.0):
  """Compute weighted cross entropy and entropy for log probs and targets.

  Args:
   logits: [batch, length, num_classes] float array.
   targets: categorical targets [batch, length] int array.
   weights: None or array of shape [batch, length].
   label_smoothing: label smoothing constant, used to determine the on and
     off values.

  Returns:
    Per-position loss of shape [batch, length] (weighted if weights is given).
  """
  if logits.ndim != targets.ndim + 1:
    raise ValueError('Incorrect shapes. Got shape %s logits and %s targets' %
                     (str(logits.shape), str(targets.shape)))
  vocab_size = logits.shape[-1]
  confidence = 1.0 - label_smoothing
  low_confidence = (1.0 - confidence) / (vocab_size - 1)
  normalizing_constant = -(
      confidence * jnp.log(confidence) + (vocab_size - 1) *
      low_confidence * jnp.log(low_confidence + 1e-20))
  soft_targets = common_utils.onehot(
      targets, vocab_size, on_value=confidence, off_value=low_confidence)

  loss = -jnp.sum(soft_targets * nn.log_softmax(logits), axis=-1)
  loss = loss - normalizing_constant

  if weights is not None:
    loss = loss * weights

  return loss
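The constant subtracted above is the entropy of the smoothed target distribution, so the per-position loss bottoms out near zero when the model predicts exactly that distribution. A small self-contained check with toy values, assuming `common_utils` is flax.training.common_utils as the function above expects:

import jax.numpy as jnp
import flax.linen as nn
from flax.training import common_utils

vocab_size, label_smoothing = 5, 0.1
confidence = 1.0 - label_smoothing
low_confidence = label_smoothing / (vocab_size - 1)
targets = jnp.array([[2]])  # [batch=1, length=1]
smoothed = common_utils.onehot(
    targets, vocab_size, on_value=confidence, off_value=low_confidence)
logits = jnp.log(smoothed)  # a model that matches the smoothed targets exactly
print(compute_per_pos_loss(logits, targets,
                           label_smoothing=label_smoothing))  # ~[[0.]]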
Example #4
def weighted_unnormalized_cross_entropy(logits, targets, weights=None):
  """Compute weighted cross entropy and entropy for log probs and targets.

  This computes sum_(x,y) ce(x, y) for a single, potentially padded minibatch.
  If the minibatch is padded (that is it contains null examples) it is assumed
  that weights is a binary mask where 0 indicates that the example is null.

  Args:
   logits: [batch, length, num_classes] float array.
   targets: one hot vector of shape [batch, ..., num_classes].
   weights: None or array of shape [batch, ...] (rank of targets - 1).

  Returns:
    Cross entropy loss computed per example, shape [batch, ...].
  """
  if logits.ndim != targets.ndim:
    raise ValueError('Incorrect shapes. Got shape %s logits and %s targets' %
                     (str(logits.shape), str(targets.shape)))

  loss = -jnp.sum(targets * nn.log_softmax(logits), axis=-1)
  if weights is not None:
    if weights.ndim != targets.ndim - 1:
      raise ValueError('Incorrect shapes. Got shape %s weights and %s targets' %
                       (str(weights.shape), str(targets.shape)))
    loss = loss * weights

  return loss
Example #5
def compute_weighted_cross_entropy(logits, targets, weights=None):
    """Compute weighted cross entropy and entropy for log probs and targets.

  Args:
   logits: `[batch, length, num_classes]` float array.
   targets: categorical targets `[batch, length]` int array.
   weights: None or array of shape [batch, length].

  Returns:
    Tuple of scalar loss and batch normalizing factor.
  """
    if logits.ndim != targets.ndim + 1:
        raise ValueError(
            'Incorrect shapes. Got shape %s logits and %s targets' %
            (str(logits.shape), str(targets.shape)))
    if logits.shape[1] != targets.shape[1]:  # Truncate logits.
        logits = logits[:, :targets.shape[1]]

    onehot_targets = common_utils.onehot(targets, logits.shape[-1])
    loss = -jnp.sum(onehot_targets * nn.log_softmax(logits), axis=-1)
    normalizing_factor = jnp.prod(jnp.asarray(targets.shape))
    if weights is not None:
        loss = loss * weights
        normalizing_factor = weights.sum()

    return loss.sum(), normalizing_factor
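A usage sketch for the pair returned above: the scalar is a sum, so the mean per-token loss is the sum divided by the normalizing factor. Toy uniform logits are used so the expected mean is exactly log(vocab); the function's own imports (`jnp`, `nn`, `common_utils`) are assumed to be in scope.

import jax.numpy as jnp

batch, length, vocab = 2, 4, 7
logits = jnp.zeros((batch, length, vocab))             # uniform predictions
targets = jnp.zeros((batch, length), dtype=jnp.int32)
weights = jnp.array([[1., 1., 1., 0.],                 # trailing positions padded
                     [1., 1., 0., 0.]])
loss_sum, norm = compute_weighted_cross_entropy(logits, targets, weights)
print(loss_sum / norm, jnp.log(vocab))                 # both ~1.9459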
Example #6
def cross_entropy(logits, targets, axis=-1):
    logprobs = nn.log_softmax(logits, axis=axis)
    nll = np.take_along_axis(logprobs,
                             np.expand_dims(targets, axis=axis),
                             axis=axis)
    ce = -np.mean(nll)
    return ce
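The gather above is equivalent to the onehot-and-sum formulation used in the other examples. A toy check with illustrative values; it assumes the `np` used by `cross_entropy` resolves to numpy or jax.numpy in its module.

import jax
import jax.numpy as jnp
import flax.linen as nn

key = jax.random.PRNGKey(0)
logits = jax.random.normal(key, (3, 5))
targets = jnp.array([0, 2, 4])

via_gather = cross_entropy(logits, targets)
via_onehot = -jnp.mean(jnp.sum(
    jax.nn.one_hot(targets, 5) * nn.log_softmax(logits), axis=-1))
print(via_gather, via_onehot)  # the two values should agree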
Example #7
    def __call__(
        self,
        input_ids: jnp.ndarray,
        input_mask: jnp.ndarray,
        type_ids: jnp.ndarray,
        labels: jnp.ndarray = None,
        *,
        deterministic: bool = False,
    ):
        """Applies BERT for sequence classification."""
        bert = BertModel(config=self.config, name="bert")
        _, pooled_output = bert(input_ids,
                                input_mask,
                                type_ids,
                                deterministic=deterministic)
        pooled_output = nn.Dropout(rate=self.config.hidden_dropout_prob,
                                   deterministic=deterministic)(pooled_output)
        logits = layers.OutputProjection(
            n_out=self.n_classes,
            kernel_init=get_kernel_init(self.config),
            name="classification",
        )(pooled_output)

        if labels is None:
            return logits
        elif logits.shape[-1] == 1:
            # Regression task
            loss = jnp.mean((logits[..., 0] - labels)**2)
            return {"loss": loss}
        else:
            # Classification task
            logits = nn.log_softmax(logits)
            loss = -jnp.mean(
                jnp.sum(onehot(labels, logits.shape[-1]) * logits, axis=-1))
            return {"loss": loss}
Example #8
    def __call__(self, x, train: bool = True):
        conv = partial(nn.Conv, use_bias=False, dtype=self.dtype)
        norm = partial(nn.BatchNorm,
                       use_running_average=not train,
                       momentum=0.9,
                       epsilon=1e-5,
                       dtype=self.dtype)

        x = conv(self.num_filters, (7, 7), (2, 2),
                 padding=[(3, 3), (3, 3)],
                 name='conv_init')(x)
        x = norm(name='bn_init')(x)
        x = nn.relu(x)
        x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
        for i, block_size in enumerate(self.stage_sizes):
            for j in range(block_size):
                strides = (2, 2) if i > 0 and j == 0 else (1, 1)
                x = self.block_cls(self.num_filters * 2**i,
                                   strides=strides,
                                   conv=conv,
                                   norm=norm,
                                   act=self.act)(x)
        x = jnp.mean(x, axis=(1, 2))
        x = nn.Dense(self.num_classes, dtype=self.dtype)(x)
        x = jnp.asarray(x, self.dtype)
        x = nn.log_softmax(x)
        return x
Example #9
 def __call__(self, x):
     x = OGDense(first, name="first")(x)
     for n in rest:
         x = OGDense(n)(x)
     x = nn.Dense(10, name="last")(x)
     x = nn.log_softmax(x)
     return x
Example #10
 def __call__(self, x):
     x = nn.Conv(features=8, kernel_size=(3, 3), strides=(1, 1))(x)
     x = activation(x)
     x = jnp.reshape(x, (x.shape[0], -1))
     x = nn.Dense(10)(x)
     x = nn.log_softmax(x)
     return x
Example #11
 def __call__(self, x):
     if self.nhidden > 0:
         x = nn.Dense(self.nhidden)(x)
         x = nn.relu(x)
     x = nn.Dense(self.nclasses)(x)
     x = nn.log_softmax(x)
     return x
Example #12
    def __call__(self, x):
        # Helper macro.
        R_ = lambda hidden_: ResidualUnit(hidden_features=hidden_,
                                          norm=self.norm,
                                          training=self.training,
                                          activation=nn.gelu)
        # First filter to make features.
        h = nn.Conv(features=self.hidden * self.alpha,
                    use_bias=False,
                    kernel_size=(3, 3),
                    kernel_init=INITS[self.kernel_init])(x)
        h = NORMS[self.norm](use_running_average=not self.training)(h)
        h = nn.gelu(h)
        # 2 stages of continuous segments:
        h = ResidualStitch(hidden_features=self.hidden * self.alpha,
                           output_features=self.hidden * self.alpha,
                           strides=(1, 1),
                           norm=self.norm,
                           training=self.training,
                           activation=nn.gelu)(h)
        h = StatefulContinuousBlock(R=R_(self.hidden * self.alpha),
                                    scheme=self.scheme,
                                    n_step=self.n_step,
                                    n_basis=self.n_basis,
                                    basis=self.basis,
                                    training=self.training)(h)

        # Pool and linearly classify:
        h = NORMS[self.norm](use_running_average=not self.training)(h)
        h = nn.gelu(h)
        h = nn.avg_pool(h, window_shape=(8, 8), strides=(8, 8))
        h = h.reshape((h.shape[0], -1))
        h = nn.Dense(features=self.n_classes)(h)
        return nn.log_softmax(h)  # log-probabilities, not softmax probabilities
Example #13
 def __call__(self, x):
   x = jnp.reshape(x, (-1, 28 * 28))
   x = nn.Dense(1024)(x)
   x = activation(x)
   x = nn.Dense(10)(x)
   x = nn.log_softmax(x)
   return x
Example #14
  def loss_fn(params):
    """loss function used for training."""
    logits = models.Transformer(config).apply(
        {"params": params},
        inputs,
        targets,
        inputs_positions=inputs_positions,
        targets_positions=targets_positions,
        inputs_segmentation=inputs_segmentation,
        targets_segmentation=targets_segmentation,
        rngs={"dropout": dropout_rng})

    vocab_size = logits.shape[-1]
    confidence = 1.0 - label_smoothing
    low_confidence = (1.0 - confidence) / (vocab_size - 1)
    normalizing_constant = -(
        confidence * jnp.log(confidence) +
        (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20))
    soft_targets = common_utils.onehot(
        targets, vocab_size, on_value=confidence, off_value=low_confidence)

    loss = -jnp.sum(soft_targets * nn.log_softmax(logits), axis=-1)
    loss = loss - normalizing_constant

    loss = loss * weights
    normalizing_factor = weights.sum()

    mean_loss = loss.sum() / normalizing_factor
    return mean_loss, logits
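A loss function returning (mean_loss, logits) like the one above is typically differentiated with jax.value_and_grad(..., has_aux=True). The sketch below is a self-contained toy analogue: the Transformer setup is replaced by a linear model so it runs standalone, and every name in it is illustrative.

import jax
import jax.numpy as jnp

inputs = jnp.ones((2, 3))
targets = jnp.array([0, 1])
params = {"w": jnp.zeros((3, 4))}

def toy_loss_fn(params):
    logits = inputs @ params["w"]
    log_probs = jax.nn.log_softmax(logits)
    nll = -jnp.take_along_axis(log_probs, targets[:, None], axis=-1)
    return jnp.mean(nll), logits

(mean_loss, logits), grads = jax.value_and_grad(toy_loss_fn, has_aux=True)(params)
print(mean_loss, grads["w"].shape)  # ~1.386 (= log 4), (3, 4)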
Example #15
    def __call__(self, input_ids, type_ids, labels=None, deterministic=False):
        """Applies model for sequence classification.

    Args:
      input_ids: <int>[BATCH_SIZE, MAX_SEQ_LENGTH] tokenized inputs.
      type_ids: <int>[BATCH_SIZE, MAX_SEQ_LENGTH] Ids partitioning input into
        different types.
      labels: True labels associated with inputs. Generally only required for
        training. Shape depends on task type:
        * Classification: <int>[BATCH_SIZE],
        * Regression: <float>[BATCH_SIZE].
      deterministic: Whether to apply dropout to input.

    Returns:
      * If labels supplied (training mode): Model loss and metrics.
      * If no labels supplied (prediction / evaluation mode): Logits with shape
        <float>[BATCH_SIZE, n_classes].
    """
        encoder_output = EncoderModel(self.config, name="encoder")(
            input_ids, type_ids, deterministic=deterministic)

        # All other classification and regression tasks use the pooled output.
        output = encoder_output.pooled_output
        # TODO(jamesleethorp): For WiC, the original SuperGLUE paper
        #  (https://arxiv.org/abs/1905.00537) concatenates the "CLS" and "word"
        #  output representations. We only use the pooled output.

        logits = layers.OutputProjection(n_out=self.n_classes,
                                         kernel_init=default_kernel_init,
                                         name="classification")(output)

        if labels is None:
            # Code path used during evaluation or prediction; metrics can be computed
            # from logits by the caller.
            return logits

        # Code path used during training.
        if self.config.dataset_name in (
                "glue/stsb",  # Regression task
                "super_glue/copa",  # "Regression" task
                "super_glue/record",  # "Regression" task
        ):
            # Logits have shape: [BATCH_SIZE, 1].
            per_example_loss = jnp.sum((logits[Ellipsis, 0] - labels)**2,
                                       axis=-1)
            batch_loss = jnp.mean(per_example_loss)
            return ClassificationStats(batch_loss=batch_loss,
                                       num_labels=labels.size)

        else:  # Classification task
            # Logits have shape: [BATCH_SIZE, self.n_classes].
            logits = nn.log_softmax(logits, axis=-1)
            per_example_loss = -jnp.sum(
                onehot(labels, logits.shape[-1]) * logits, axis=-1)
            batch_loss = jnp.mean(per_example_loss)
            correct_predictions = jnp.sum(logits.argmax(-1) == labels)
            return ClassificationStats(batch_loss=batch_loss,
                                       num_labels=labels.size,
                                       correct_predictions=correct_predictions)
Example #16
    def __call__(
        self,
        inputs,
    ):
        """Applies ResNet model. Number of residual blocks inferred from hparams."""
        num_classes = self.num_classes
        hparams = self.hparams
        num_filters = self.num_filters
        dtype = self.dtype

        x = aqt_flax_layers.ConvAqt(
            features=num_filters,
            kernel_size=(7, 7),
            strides=(2, 2),
            padding=[(3, 3), (3, 3)],
            use_bias=False,
            dtype=dtype,
            name='init_conv',
            train=self.train,
            quant_context=self.quant_context,
            paxis_name='batch',
            hparams=hparams.conv_init,
        )(inputs)
        x = nn.BatchNorm(use_running_average=not self.train,
                         momentum=0.9,
                         epsilon=1e-5,
                         dtype=dtype,
                         name='init_bn')(x)
        x = nn.relu(x)
        x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
        filter_multiplier = hparams.filter_multiplier
        for i, block_hparams in enumerate(hparams.residual_blocks):
            proj = block_hparams.conv_proj
            # For projection layers (unless it is the first layer), strides = (2, 2)
            if i > 0 and proj is not None:
                filter_multiplier *= 2
                strides = (2, 2)
            else:
                strides = (1, 1)
            x = ResidualBlock(filters=int(num_filters * filter_multiplier),
                              hparams=block_hparams,
                              quant_context=self.quant_context,
                              strides=strides,
                              train=self.train,
                              dtype=dtype)(x)
        x = jnp.mean(x, axis=(1, 2))

        x = aqt_flax_layers.DenseAqt(
            features=num_classes,
            dtype=dtype,
            train=self.train,
            quant_context=self.quant_context,
            paxis_name='batch',
            hparams=hparams.dense_layer,
        )(x, padding_mask=None)

        x = jnp.asarray(x, dtype)
        output = nn.log_softmax(x)
        return output
Example #17
 def __call__(self, x):
     x = nn.Dense(1024, bias_init=nn.initializers.normal(stddev=1.0))(x)
     x = nn.relu(x)
     x = nn.Dense(1024, bias_init=nn.initializers.normal(stddev=1.0))(x)
     x = nn.relu(x)
     x = nn.Dense(10)(x)
     x = nn.log_softmax(x)
     return x
Example #18
  def __call__(
      self,
      input_ids,
      input_mask,
      type_ids,
      labels = None,
      deterministic = False
  ):
    """Applies model for sequence classification.

    Args:
      input_ids: <int>[BATCH_SIZE, MAX_SEQ_LENGTH] tokenized inputs.
      input_mask: <bool>[BATCH_SIZE, MAX_SEQ_LENGTH] mask separating actual
        inputs from padding. Only used by BERT.
      type_ids: <int>[BATCH_SIZE, MAX_SEQ_LENGTH] Ids partitioning input into
        different types.
      labels: True labels associated with inputs. Generally only required for
        training. Shape depends on task type:
        * Classification: <int>[BATCH_SIZE]
        * Regression: <float>[BATCH_SIZE]
      deterministic: Whether or not to apply dropout to input.

    Returns:
      * If labels supplied (training mode): Model loss and metrics.
      * If no labels supplied (prediction / evaluation mode): Logits of shape
        <float>[BATCH_SIZE, n_classes].
    """
    _, pooled_output = EncoderModel(self.config, name="encoder")(
        input_ids, input_mask, type_ids, deterministic=deterministic)

    logits = layers.OutputProjection(
        n_out=self.n_classes,
        kernel_init=default_kernel_init,
        name="classification")(pooled_output)

    if labels is None:
      # Code path used during evaluation or prediction; metrics can be computed
      # from logits by the caller.
      return logits

    # Code path used during training.
    if self.config.dataset_name == "glue/stsb":  # Regression task
      loss = jnp.mean((logits[Ellipsis, 0] - labels)**2)
      return {"loss": loss, "num_labels": labels.size}
    else:  # Classification task
      logits = nn.log_softmax(logits)
      loss = -jnp.mean(
          jnp.sum(onehot(labels, logits.shape[-1]) * logits, axis=-1))
      correct_predictions = jnp.sum(logits.argmax(-1) == labels)
      return {
          "loss": loss,
          "correct_predictions": correct_predictions,
          "num_labels": labels.size
      }
Example #19
 def __call__(self, x: spec.Tensor, train: bool):
     del train
     input_size = 28 * 28
     num_hidden = 128
     num_classes = 10
     x = x.reshape((x.shape[0], input_size))  # Flatten.
     x = nn.Dense(features=num_hidden, use_bias=True)(x)
     x = nn.sigmoid(x)
     x = nn.Dense(features=num_classes, use_bias=True)(x)
     x = nn.log_softmax(x)
     return x
Example #20
 def loss(params):
     x_in = shift_right(x)
     logits = model.apply(params,
                          x_in,
                          rngs={
                              "permute": permute_key,
                              "dropout": dropout_key
                          })
     log_prob = nn.log_softmax(logits)
     x_onehot = onehot(x, num_classes=10)
     nll = -jnp.sum(x_onehot * log_prob, axis=-1)
     return jnp.mean(nll)
Example #21
File: train.py  Project: ykumards/flax
 def __call__(self, x):
     x = nn.Conv(features=32, kernel_size=(3, 3))(x)
     x = nn.relu(x)
     x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
     x = nn.Conv(features=64, kernel_size=(3, 3))(x)
     x = nn.relu(x)
     x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
     x = x.reshape((x.shape[0], -1))  # flatten
     x = nn.Dense(features=256)(x)
     x = nn.relu(x)
     x = nn.Dense(features=10)(x)
     x = nn.log_softmax(x)
     return x
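These __call__ methods belong to flax.linen modules defined with @nn.compact. Below is a sketch of wrapping the network above in such a module and initializing it on an MNIST-shaped batch; the class name CNN and the input shape are assumptions for illustration.

import jax
import jax.numpy as jnp
import flax.linen as nn

class CNN(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        return nn.log_softmax(x)

model = CNN()
variables = model.init(jax.random.PRNGKey(0), jnp.ones((1, 28, 28, 1)))
log_probs = model.apply(variables, jnp.ones((1, 28, 28, 1)))
print(log_probs.shape)  # (1, 10)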
Example #22
    def __call__(self, inputs, train: bool = True):
        """Passes the input through the network.
        Arguments:
            inputs:     [batch_size, height, width, channels]
            train:      bool
        Returns:
            output:     [batch_size, config.num_classes]
        """
        cfg = self.config
        conv = partial(nn.Conv,
                       use_bias=False,
                       dtype=cfg.dtype,
                       precision=cfg.precision,
                       kernel_init=cfg.kernel_init)
        norm = partial(nn.BatchNorm,
                       use_running_average=not train,
                       momentum=cfg.bn_momentum,
                       epsilon=cfg.bn_epsilon,
                       dtype=cfg.dtype)

        y = conv(cfg.initial_filters,
                 kernel_size=(7, 7),
                 strides=(2, 2),
                 padding=[(3, 3), (3, 3)])(inputs)
        y = norm()(y)
        y = cfg.activation_fn(y)
        y = nn.max_pool(y, (3, 3), strides=(2, 2), padding='SAME')
        for i, block_size in enumerate(cfg.stage_sizes[:-1]):
            for j in range(block_size):
                strides = (2, 2) if i > 0 and j == 0 else (1, 1)
                y = BottleneckResNetBlock(filters=cfg.initial_filters * 2**i,
                                          strides=strides,
                                          config=cfg,
                                          conv=conv,
                                          norm=norm)(y)
        for j in range(cfg.stage_sizes[-1]):
            strides = (2, 2) if j == 0 and cfg.stride_one is False else (1, 1)
            y = BoTBlock(filters=cfg.initial_filters * 2**(i + 1),
                         strides=strides,
                         config=cfg,
                         conv=conv,
                         norm=norm)(y)
        y = jnp.mean(y, axis=(1, 2))
        y = nn.Dense(cfg.num_classes,
                     dtype=cfg.dtype,
                     kernel_init=cfg.kernel_init,
                     bias_init=cfg.bias_init)(y)
        y = jnp.asarray(y, dtype=cfg.dtype)
        y = nn.log_softmax(y)
        return y
Example #23
    def __call__(self, x_dict: _InputBatch) -> _LogitBatch:
        # Each feature is of shape f32[B, 1]
        x_tuple = tuple(x_dict[feature] for feature in _FEATURE_KEYS_XF)
        x_array = jnp.concatenate(x_tuple, axis=-1)  # shape: f32[B, 4]
        assert x_array.ndim == 2
        assert x_array.shape[1] == 4
        x = x_array

        x = nn.Dense(features=8)(x)
        x = nn.relu(x)
        x = nn.Dense(features=8)(x)
        x = nn.relu(x)
        x = nn.Dense(features=3)(x)
        x = nn.log_softmax(x, axis=-1)
        return x
Example #24
 def __call__(self, x, with_classifier=True):
   x = nn.Conv(features=32, kernel_size=(3, 3))(x)
   x = nn.relu(x)
   x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
   x = nn.Conv(features=64, kernel_size=(3, 3))(x)
   x = nn.relu(x)
   x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
   # TODO: replace np.prod(x.shape[1:]) with -1 once we fix shape_polymorphism
   x = x.reshape((x.shape[0], np.prod(x.shape[1:])))  # flatten
   x = nn.Dense(features=256)(x)
   x = nn.relu(x)
   if not with_classifier:
     return x
   x = nn.Dense(features=10)(x)
   x = nn.log_softmax(x)
   return x
Example #25
def cross_entropy(logits, targets, weights=None, label_smoothing=0.0):
    vocab_size = logits.shape[-1]
    confidence = 1.0 - label_smoothing
    low_confidence = (1.0 - confidence) / (vocab_size - 1)
    normalizing_constant = -(
        confidence * jnp.log(confidence) +
        (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
    ).astype(logits.dtype)
    # Smooth the one-hot targets so they match the normalizing constant above.
    soft_targets = common_utils.onehot(
        targets, vocab_size, on_value=confidence, off_value=low_confidence)
    loss = -jnp.sum(soft_targets * nn.log_softmax(logits), axis=-1,
                    dtype=logits.dtype)
    loss -= normalizing_constant
    if weights is not None:
        loss *= weights
        normalizing_factor = weights.sum()
    else:
        normalizing_factor = np.prod(targets.shape, dtype=logits.dtype)
    return loss.sum(), normalizing_factor
Example #26
    def _compute_weighted_cross_entropy(logits, targets, weights=None):
        """Computes weighted cross entropy and entropy for log probs and targets.

    Args:
     logits: <float>[NUM_EXAMPLES, NUM_CLASSES] predicted logits.
     targets: <int>[NUM_EXAMPLES] true labels.
     weights: <float>[NUM_EXAMPLES] relative weights for labels.

    Returns:
      Loss and normalizing factor for input batch.
    """
        onehot_targets = onehot(targets, logits.shape[-1])
        loss = -jnp.sum(onehot_targets * nn.log_softmax(logits), axis=-1)
        normalizing_factor = onehot_targets.sum()
        if weights is not None:
            loss = loss * weights
            normalizing_factor = weights.sum()
        return loss.sum(), normalizing_factor
Example #27
 def __call__(self, x):
     alpha = self.alpha
     n_step = self.n_step
     h = nn.Conv(features=alpha,
                 kernel_size=(3, 3),
                 use_bias=False,
                 kernel_init=INITS[self.kernel_init])(x)
     h = NORMS[self.norm](use_running_average=not self.training)(h)
     h = nn.relu(h)
     for i in range(n_step):
         h += ResidualUnit(hidden_features=alpha,
                           norm=self.norm,
                           kernel_init=self.kernel_init,
                           training=self.training)(h)
     h = ResidualStitch(hidden_features=alpha,
                        output_features=2 * alpha,
                        strides=(2, 2),
                        norm=self.norm,
                        kernel_init=self.kernel_init,
                        training=self.training)(h)
     for i in range(n_step):
         h += ResidualUnit(hidden_features=2 * alpha,
                           norm=self.norm,
                           kernel_init=self.kernel_init,
                           training=self.training)(h)
     h = ResidualStitch(hidden_features=2 * alpha,
                        output_features=4 * alpha,
                        strides=(2, 2),
                        norm=self.norm,
                        kernel_init=self.kernel_init,
                        training=self.training)(h)
     for i in range(n_step):
         h += ResidualUnit(hidden_features=4 * alpha,
                           norm=self.norm,
                           kernel_init=self.kernel_init,
                           training=self.training)(h)
     h = NORMS[self.norm](use_running_average=not self.training)(h)
     # h = nn.pooling.avg_pool(h, (h.shape[-3], h.shape[-2]))
     # h = h.reshape(h.shape[0], -1)
     h = jnp.mean(h, axis=(1, 2))
     h = nn.Dense(features=self.n_classes)(h)
     return nn.log_softmax(h)  # log-probabilities, not softmax probabilities
Example #28
  def __call__(self, x):
    """Define the model architecture.

    Network is used to both estimate policy (logits) and expected state value;
    in other words, hidden layers' params are shared between policy and value
    networks.

    Args:
      x: input of shape N, H, W(1)

    Returns:
      policy_log_probabilities: action log-probabilities (log-softmax of the logits)
      value: state value
    """
    x = x.astype(self.dtype)
    x = nn.Conv(features=self.chan1,
                kernel_size=[3],
                strides=1,
                name='conv1',
                dtype=self.dtype)(x)
    x = nn.relu(x)
    x = nn.Conv(features=self.chan2,
                kernel_size=[3],
                strides=1,
                name='conv2',
                dtype=self.dtype)(x)
    x = nn.relu(x)

    x = x.reshape((x.shape[0], -1))  # flatten
    logits = nn.Dense(features=self.num_actions, name='logits',
                      dtype=self.dtype)(x)
    policy_log_probabilities = nn.log_softmax(logits)
    value = nn.Dense(features=1, name='value', dtype=self.dtype)(x)
    return policy_log_probabilities, value
Example #29
    def compute_weighted_cross_entropy(self,
                                       logits,
                                       targets,
                                       weights,
                                       label_smoothing=0.0):
        """Compute weighted cross entropy and entropy for log probs and targets.

    Args:
     logits: [batch, length, num_classes] float array.
     targets: categorical targets [batch, length] int array.
     weights: None or array of shape [batch, length].
     label_smoothing: label smoothing constant, used to determine the on and off
       values.

    Returns:
      Per-position loss of shape [batch, length] (weighted if weights is given).
    """
        if logits.ndim != targets.ndim + 1:
            raise ValueError(
                f'Incorrect shapes. Got shape {str(logits.shape)} logits '
                f'and {str(targets.shape)} targets')
        confidence = 1.0 - label_smoothing
        low_confidence = (1.0 - confidence) / (self._vocab_size - 1)
        normalizing_constant = -(confidence * jnp.log(confidence) +
                                 ((self._vocab_size - 1) * low_confidence *
                                  jnp.log(low_confidence + 1e-20)))
        soft_targets = common_utils.onehot(targets,
                                           self._vocab_size,
                                           on_value=confidence,
                                           off_value=low_confidence)

        loss = -jnp.sum(soft_targets * nn.log_softmax(logits), axis=-1)
        loss = loss - normalizing_constant

        if weights is not None:
            loss = loss * weights

        return loss
Example #30
def cross_entropy(logits, targets):
  """Compute weighted cross entropy and entropy for log probs and targets.

  Args:
   logits: [batch, length, num_classes] float array.
   targets: categorical targets [batch, length] int array.

  Returns:
    Per-example loss in bits per dimension, shape [batch].
  """
  if logits.ndim != targets.ndim + 1:
    raise ValueError('Incorrect shapes. Got shape %s logits and %s targets' %
                     (str(logits.shape), str(targets.shape)))
  vocab_size = logits.shape[-1]
  onehot_targets = common_utils.onehot(targets, vocab_size)

  loss = -jnp.sum(onehot_targets * nn.log_softmax(logits), axis=-1)

  d = np.prod(targets.shape[1:])

  loss = util_fns.sum_except_batch(loss) / d / np.log(2)

  return loss
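Dividing the summed negative log-likelihood by the number of non-batch target dimensions and by log(2), as done above, converts nats per example into bits per dimension. A self-contained toy check with a uniform model over 256 symbols, where the answer is exactly 8 bits/dim; `util_fns.sum_except_batch` is replaced by an explicit sum here.

import numpy as np
import jax
import jax.numpy as jnp

batch, length, vocab = 2, 8, 256
logits = jnp.zeros((batch, length, vocab))            # uniform predictions
targets = jnp.zeros((batch, length), dtype=jnp.int32)

per_pos = -jnp.sum(jax.nn.one_hot(targets, vocab) * jax.nn.log_softmax(logits),
                   axis=-1)                           # nats, shape [batch, length]
d = np.prod(targets.shape[1:])
bits_per_dim = per_pos.sum(axis=1) / d / np.log(2)
print(bits_per_dim)                                   # [8. 8.] == log2(vocab)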