Example #1
def _unweighted_loss_and_weights(self, logits, processed_labels, features):
  """Computes loss spec."""
  time_to_event, censored = processed_labels

  time_to_event.shape.assert_is_compatible_with(censored.shape)
  with tf.control_dependencies([
      tf.assert_positive(time_to_event),
  ]):
    model = self._survival_model(
        params=logits,
        labels=processed_labels,
        event_index=self._event_index,
        model_hparams=self._model_hparams)

    tf.logging.info(model.params())
    log_pdf_value = model.log_pdf(time_to_event)
    log_survival_value = model.log_survival_func(time_to_event)
    batch_loss = survival_util.negative_log_likelihood_loss(
        censored=censored,
        log_pdf_value=log_pdf_value,
        log_survival_value=log_survival_value)
    # batch_loss has shape [batch_size, 1].
    tf.logging.info(batch_loss)

    scalar_loss = math_ops.reduce_mean(batch_loss, axis=-1, keepdims=True)
    weights = base_head.get_weights_and_check_match_logits(
        features=features, weight_column=self._weight_column, logits=logits)
    return scalar_loss, weights
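The combination of the two log terms happens inside survival_util.negative_log_likelihood_loss, whose body is not shown here. A minimal sketch, assuming the common convention that censored == 1 marks an unobserved event (the helper's actual convention is an assumption):

import tensorflow as tf

def nll_loss_sketch(censored, log_pdf_value, log_survival_value):
  # Observed events contribute -log f(t); censored ones contribute -log S(t).
  censored = tf.cast(censored, log_pdf_value.dtype)
  return -((1.0 - censored) * log_pdf_value + censored * log_survival_value)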
Example #2
def _unweighted_loss_and_weights(self, logits, labels, features):
  """Computes unweighted loss and weights."""
  if self._loss_fn:
    unweighted_loss = base_head.call_loss_fn(
        loss_fn=self._loss_fn, labels=labels, logits=logits,
        features=features, expected_loss_dim=1)
  else:
    unweighted_loss = nn.sigmoid_cross_entropy_with_logits(
        labels=labels, logits=logits)
  weights = base_head.get_weights_and_check_match_logits(
      features=features, weight_column=self._weight_column, logits=logits)
  return unweighted_loss, weights
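Note that sigmoid_cross_entropy_with_logits returns an elementwise loss with the same shape as logits, so no reduction or reshaping is needed before weighting. Its numerically stable closed form is max(x, 0) - x*z + log(1 + exp(-|x|)); a self-contained check with made-up values:

import tensorflow as tf

x = tf.constant([[2.0], [-1.5]])  # logits, shape [2, 1]
z = tf.constant([[1.0], [0.0]])   # binary labels, same shape
# Equal to -z*log(sigmoid(x)) - (1-z)*log(1 - sigmoid(x)), but stable.
loss = tf.nn.relu(x) - x * z + tf.math.log1p(tf.exp(-tf.abs(x)))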
Example #3
def _unweighted_loss_and_weights(self, logits, labels, features):
  """Computes loss spec."""
  if self._loss_fn:
    unweighted_loss = base_head.call_loss_fn(
        loss_fn=self._loss_fn, labels=labels, logits=logits,
        features=features, expected_loss_dim=self._logits_dimension)
  else:
    unweighted_loss = losses.mean_squared_error(
        labels=labels, predictions=logits, reduction=losses.Reduction.NONE)
  weights = base_head.get_weights_and_check_match_logits(
      features=features, weight_column=self._weight_column, logits=logits,
      allow_per_logit_weights=True)
  return unweighted_loss, weights
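With Reduction.NONE the squared error keeps shape [batch, logits_dimension], which is why allow_per_logit_weights=True is safe here: a weight tensor can match either per example or per logit. A small sketch of the broadcasting involved (values are hypothetical):

import tensorflow as tf

labels = tf.constant([[1.0, 2.0], [3.0, 4.0]])
logits = tf.constant([[1.5, 1.5], [2.0, 5.0]])
unweighted = tf.math.squared_difference(labels, logits)  # shape [2, 2]
per_example = tf.constant([[1.0], [0.5]])                # shape [2, 1]
per_logit = tf.constant([[1.0, 0.0], [0.5, 2.0]])        # shape [2, 2]
# Both weight shapes multiply cleanly against the unreduced loss.
weighted_a = unweighted * per_example
weighted_b = unweighted * per_logit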
Example #4
def _unweighted_loss_and_weights(self, logits, label_ids, features):
  """Computes loss spec."""
  if self._loss_fn:
    unweighted_loss = base_head.call_loss_fn(
        loss_fn=self._loss_fn,
        labels=label_ids,
        logits=logits,
        features=features,
        expected_loss_dim=1)
  else:
    unweighted_loss = tf.compat.v1.losses.sparse_softmax_cross_entropy(
        labels=label_ids,
        logits=logits,
        reduction=tf.compat.v1.losses.Reduction.NONE)
    # Restore the squeezed dim, so unweighted_loss matches the weights shape.
    unweighted_loss = tf.compat.v1.expand_dims(unweighted_loss, axis=-1)
  weights = base_head.get_weights_and_check_match_logits(
      features=features, weight_column=self._weight_column, logits=logits)
  return unweighted_loss, weights
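The expand_dims step matters because sparse softmax cross-entropy collapses the class dimension, returning one scalar per example. A minimal shape walk-through with made-up values:

import tensorflow as tf

label_ids = tf.constant([0, 2])              # shape [2]
logits = tf.constant([[2.0, 1.0, 0.1],
                      [0.5, 0.5, 3.0]])      # shape [2, 3]
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
    labels=label_ids, logits=logits)         # shape [2]
loss = tf.expand_dims(loss, axis=-1)         # shape [2, 1], matches weights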
Example #5
def _unweighted_loss_and_weights(self, logits, processed_labels, features):
  """Computes loss spec."""
  if self._loss_fn:
    unweighted_loss = base_head.call_loss_fn(
        loss_fn=self._loss_fn,
        labels=processed_labels,
        logits=logits,
        features=features,
        expected_loss_dim=1)
  else:
    unweighted_loss = tf.compat.v1.losses.sigmoid_cross_entropy(
        multi_class_labels=processed_labels,
        logits=logits,
        reduction=tf.compat.v1.losses.Reduction.NONE)
    # Averages loss over classes.
    unweighted_loss = tf.math.reduce_mean(
        unweighted_loss, axis=-1, keepdims=True)
  weights = base_head.get_weights_and_check_match_logits(
      features=features, weight_column=self._weight_column, logits=logits)
  return unweighted_loss, weights
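Unlike the binary case in Example #2, this multi-label head gets one sigmoid loss per class and then averages across classes, so the result is [batch, 1] regardless of the number of labels. A small sketch with hypothetical multi-hot labels:

import tensorflow as tf

labels = tf.constant([[1.0, 0.0, 1.0]])      # multi-hot, shape [1, 3]
logits = tf.constant([[2.0, -1.0, 0.5]])     # shape [1, 3]
per_class = tf.nn.sigmoid_cross_entropy_with_logits(
    labels=labels, logits=logits)            # shape [1, 3]
loss = tf.math.reduce_mean(per_class, axis=-1, keepdims=True)  # shape [1, 1]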
Example #6
def _unweighted_loss_and_weights(self, logits, processed_labels, features,
                                 mode):
    """Computes loss spec."""
    if self._model_hparams.model_name in STATE_SPACE_MODELS:
        _, loss = logits
        return loss, 1

    if self._model_hparams.has_mask:
        mask_obs_tensor = tf.concat([
            tf.expand_dims(features[obs + '_mask'], 2)
            for obs in self._model_hparams.observation_codes
        ], axis=2)
    else:
        mask_obs_tensor = tf.ones_like(processed_labels)

    if self._model_hparams.has_mask:
        true_len_hr = features['true_length_hr']
        seqlen = tf.reduce_max(true_len_hr)
    else:
        batch_size = logits.get_shape().as_list()[0]
        true_len_hr = tf.fill([batch_size],
                              self._model_hparams.context_window_size)
        seqlen = self._model_hparams.context_window_size
    if self._model_hparams.model_name == 'lstm_ds':
        seqlen = seqlen - 1

    # processed_labels is not trimmed to the true length; its shape is
    # [batch, context_window_size - 1, num_obs]. logits is already trimmed
    # to seqlen, with shape [batch, seqlen, num_obs]. seqlen counts values:
    # if seqlen = 5, each feature has 5 values. After slicing,
    # trimmed_processed_labels has shape [batch, seqlen, num_obs] and
    # batch_time_value_loss (below) has shape [batch, seqlen].
    trimmed_processed_labels = tf.slice(processed_labels, [0, 0, 0],
                                        [-1, seqlen, -1])
    trimmed_mask_obs_tensor = tf.slice(mask_obs_tensor, [0, 0, 0],
                                       [-1, seqlen, -1])

    # Compute the L2 value loss over the true sequence length / window size
    # and the observed values.
    batch_time_feature_loss = tf.multiply(
        tf.square(trimmed_processed_labels - logits),
        trimmed_mask_obs_tensor)
    batch_time_value_loss = tf.reduce_mean(batch_time_feature_loss, axis=2)

    # self._model_hparams.last_obs_len is the number of most recent
    # observations used for computing the loss.
    # batch_value_loss (below) has shape [batch].
    last_obs_len = self._model_hparams.last_obs_len
    assert last_obs_len < self._model_hparams.context_window_size

    # Zero out the loss outside [true_len_hr - last_obs_len, true_len_hr).
    true_len_mask = tf.sequence_mask(true_len_hr, seqlen)
    if last_obs_len == -1:
        # All observations up to the true length are included.
        selection_mask = tf.cast(true_len_mask, tf.float32)
    else:
        last_obs_len_mask = tf.sequence_mask(true_len_hr - last_obs_len,
                                             seqlen)
        selection_mask = tf.cast(
            tf.logical_xor(true_len_mask, last_obs_len_mask), tf.float32)

    trimmed_batch_time_value_loss = tf.multiply(batch_time_value_loss,
                                                selection_mask)
    batch_value_loss = tf.math.divide_no_nan(
        tf.reduce_sum(trimmed_batch_time_value_loss, axis=1),
        tf.reduce_sum(selection_mask, axis=1))

    scalar_loss = math_ops.reduce_mean(batch_value_loss,
                                       axis=-1,
                                       keepdims=True)

    weights = base_head.get_weights_and_check_match_logits(
        features=features,
        weight_column=self._weight_column,
        logits=logits)
    return scalar_loss, weights
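The XOR of two sequence_mask results is the trick that keeps only the last last_obs_len valid steps of each sequence. A standalone illustration with hypothetical lengths:

import tensorflow as tf

true_len_hr = tf.constant([5, 3])
seqlen, last_obs_len = 6, 2
full = tf.sequence_mask(true_len_hr, seqlen)                    # valid steps
earlier = tf.sequence_mask(true_len_hr - last_obs_len, seqlen)  # all but last k
# XOR keeps positions [true_len - k, true_len): row 0 -> {3, 4}, row 1 -> {1, 2}.
selection = tf.cast(tf.math.logical_xor(full, earlier), tf.float32)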