def _unweighted_loss_and_weights(self, logits, processed_labels, features):
  """Computes loss spec."""
  time_to_event, censored = processed_labels
  time_to_event.shape.assert_is_compatible_with(censored.shape)
  with tf.control_dependencies([
      tf.assert_positive(time_to_event),
  ]):
    model = self._survival_model(
        params=logits,
        labels=processed_labels,
        event_index=self._event_index,
        model_hparams=self._model_hparams)
    tf.logging.info(model.params())
    log_pdf_value = model.log_pdf(time_to_event)
    log_survival_value = model.log_survival_func(time_to_event)
    batch_loss = survival_util.negative_log_likelihood_loss(
        censored=censored,
        log_pdf_value=log_pdf_value,
        log_survival_value=log_survival_value)
    # batch_loss has shape [batch_size, 1]
    tf.logging.info(batch_loss)
    scalar_loss = math_ops.reduce_mean(batch_loss, axis=-1, keepdims=True)
    weights = base_head.get_weights_and_check_match_logits(
        features=features, weight_column=self._weight_column, logits=logits)
    return scalar_loss, weights
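# Illustrative sketch (not part of the head): the censored negative
# log-likelihood that survival_util.negative_log_likelihood_loss is assumed
# to compute. For an observed event the per-example loss is -log f(t); for a
# right-censored example it is -log S(t). The helper name below is
# hypothetical and exists only for illustration.
import tensorflow as tf


def _censored_nll_sketch(censored, log_pdf_value, log_survival_value):
  """Per-example loss: -(1 - c) * log f(t) - c * log S(t), with c = censored."""
  censored = tf.cast(censored, log_pdf_value.dtype)
  return -(1.0 - censored) * log_pdf_value - censored * log_survival_value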
def _unweighted_loss_and_weights(self, logits, labels, features):
  """Computes unweighted loss and weights."""
  if self._loss_fn:
    unweighted_loss = base_head.call_loss_fn(
        loss_fn=self._loss_fn,
        labels=labels,
        logits=logits,
        features=features,
        expected_loss_dim=1)
  else:
    unweighted_loss = nn.sigmoid_cross_entropy_with_logits(
        labels=labels, logits=logits)
  weights = base_head.get_weights_and_check_match_logits(
      features=features, weight_column=self._weight_column, logits=logits)
  return unweighted_loss, weights
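# Illustrative sketch (separate from the head): sigmoid cross-entropy is
# element-wise and returns a loss with the same shape as the logits, so the
# binary head above needs no reshaping before matching against the weights.
# The public tf.nn op is used here; values are made up for illustration.
import tensorflow as tf

logits = tf.constant([[1.2], [-0.7], [0.3]])   # shape [3, 1]
labels = tf.constant([[1.0], [0.0], [1.0]])    # shape [3, 1]
loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)
assert loss.shape == logits.shape              # loss is also [3, 1]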
def _unweighted_loss_and_weights(self, logits, labels, features):
  """Computes loss spec."""
  if self._loss_fn:
    unweighted_loss = base_head.call_loss_fn(
        loss_fn=self._loss_fn,
        labels=labels,
        logits=logits,
        features=features,
        expected_loss_dim=self._logits_dimension)
  else:
    unweighted_loss = losses.mean_squared_error(
        labels=labels,
        predictions=logits,
        reduction=losses.Reduction.NONE)
  weights = base_head.get_weights_and_check_match_logits(
      features=features,
      weight_column=self._weight_column,
      logits=logits,
      allow_per_logit_weights=True)
  return unweighted_loss, weights
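# Illustrative sketch: with Reduction.NONE, mean_squared_error keeps one loss
# value per logit, i.e. shape [batch_size, logits_dimension]; this is why the
# regression head above fetches weights with allow_per_logit_weights=True.
# Values are made up for illustration.
import tensorflow as tf

labels = tf.constant([[1.0, 2.0], [3.0, 4.0]])        # [batch=2, dim=2]
predictions = tf.constant([[1.5, 2.0], [2.0, 5.0]])
per_logit_loss = tf.compat.v1.losses.mean_squared_error(
    labels=labels,
    predictions=predictions,
    reduction=tf.compat.v1.losses.Reduction.NONE)
assert per_logit_loss.shape == labels.shape           # [2, 2], one loss per logit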
def _unweighted_loss_and_weights(self, logits, label_ids, features):
  """Computes loss spec."""
  if self._loss_fn:
    unweighted_loss = base_head.call_loss_fn(
        loss_fn=self._loss_fn,
        labels=label_ids,
        logits=logits,
        features=features,
        expected_loss_dim=1)
  else:
    unweighted_loss = tf.compat.v1.losses.sparse_softmax_cross_entropy(
        labels=label_ids,
        logits=logits,
        reduction=tf.compat.v1.losses.Reduction.NONE)
    # Restore the squeezed dim, so unweighted_loss matches the weights shape.
    unweighted_loss = tf.compat.v1.expand_dims(unweighted_loss, axis=-1)
  weights = base_head.get_weights_and_check_match_logits(
      features=features, weight_column=self._weight_column, logits=logits)
  return unweighted_loss, weights
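# Illustrative sketch: sparse softmax cross-entropy collapses the class
# dimension and returns a rank-1 loss of shape [batch_size], which is why the
# multi-class head above re-adds the last dimension to match the
# [batch_size, 1] weights. The tf.nn variant is used here for a self-contained
# example; the shape behavior is the same.
import tensorflow as tf

logits = tf.constant([[2.0, 0.5, -1.0], [0.1, 1.0, 0.3]])  # [batch=2, classes=3]
label_ids = tf.constant([0, 2])                             # [batch=2]
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
    labels=label_ids, logits=logits)                        # shape [2]
loss = tf.expand_dims(loss, axis=-1)                        # shape [2, 1]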
def _unweighted_loss_and_weights(self, logits, processed_labels, features):
  """Computes loss spec."""
  if self._loss_fn:
    unweighted_loss = base_head.call_loss_fn(
        loss_fn=self._loss_fn,
        labels=processed_labels,
        logits=logits,
        features=features,
        expected_loss_dim=1)
  else:
    unweighted_loss = tf.compat.v1.losses.sigmoid_cross_entropy(
        multi_class_labels=processed_labels,
        logits=logits,
        reduction=tf.compat.v1.losses.Reduction.NONE)
    # Averages loss over classes.
    unweighted_loss = tf.math.reduce_mean(
        unweighted_loss, axis=-1, keepdims=True)
  weights = base_head.get_weights_and_check_match_logits(
      features=features, weight_column=self._weight_column, logits=logits)
  return unweighted_loss, weights
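# Illustrative sketch: for the multi-label head above, the per-class sigmoid
# cross-entropy of shape [batch_size, n_classes] is averaged over the class
# axis with keepdims=True, so the result matches the [batch_size, 1] weights.
# Values are made up for illustration.
import tensorflow as tf

logits = tf.constant([[1.0, -2.0, 0.5], [0.0, 3.0, -1.0]])  # [batch=2, classes=3]
labels = tf.constant([[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]])
per_class = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)
per_example = tf.math.reduce_mean(per_class, axis=-1, keepdims=True)  # [2, 1]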
def _unweighted_loss_and_weights(self, logits, processed_labels, features,
                                 mode):
  """Computes loss spec."""
  if self._model_hparams.model_name in STATE_SPACE_MODELS:
    _, loss = logits
    return loss, 1

  if self._model_hparams.has_mask:
    mask_obs_tensor = tf.concat([
        tf.expand_dims(features[obs + '_mask'], 2)
        for obs in self._model_hparams.observation_codes
    ], axis=2)
  else:
    mask_obs_tensor = tf.ones_like(processed_labels)

  if self._model_hparams.has_mask:
    true_len_hr = features['true_length_hr']
    seqlen = tf.reduce_max(true_len_hr)
  else:
    batch_size = logits.get_shape().as_list()[0]
    true_len_hr = tf.fill([batch_size],
                          self._model_hparams.context_window_size)
    seqlen = self._model_hparams.context_window_size
  if self._model_hparams.model_name == 'lstm_ds':
    seqlen = seqlen - 1

  # processed_labels is not trimmed with true len and has shape
  # [batch, context_window_size - 1, num_obs].
  # logits is trimmed to seqlen and has shape [batch, seqlen, num_obs].
  # batch_time_value_loss has shape [batch, seqlen].
  # seqlen is the number of values; if seqlen = 5, there are 5 values for
  # each feature.
  # trimmed_processed_labels has shape [batch, seqlen, num_obs].
  trimmed_processed_labels = tf.slice(processed_labels, [0, 0, 0],
                                      [-1, seqlen, -1])
  trimmed_mask_obs_tensor = tf.slice(mask_obs_tensor, [0, 0, 0],
                                     [-1, seqlen, -1])
  # Compute L2 value loss based on true sequence len/window size and obs val.
  batch_time_feature_loss = tf.multiply(
      tf.square(trimmed_processed_labels - logits), trimmed_mask_obs_tensor)
  batch_time_value_loss = tf.reduce_mean(batch_time_feature_loss, axis=2)
  # self._model_hparams.last_obs_len is the number of most recent observations
  # used for computing loss.
  # batch_value_loss has shape [batch, 1].
  last_obs_len = self._model_hparams.last_obs_len
  assert last_obs_len < self._model_hparams.context_window_size
  # Zero out the loss outside [seqlen - last_obs_len, seqlen].
  true_len_mask = tf.sequence_mask(true_len_hr, seqlen)
  if last_obs_len == -1:
    # All obs up to true len are included.
    selection_mask = tf.cast(true_len_mask, tf.float32)
  else:
    last_obs_len_mask = tf.sequence_mask(true_len_hr - last_obs_len, seqlen)
    selection_mask = tf.cast(
        tf.logical_xor(true_len_mask, last_obs_len_mask), tf.float32)
  trimmed_batch_time_value_loss = tf.multiply(batch_time_value_loss,
                                              selection_mask)
  batch_value_loss = tf.div_no_nan(
      tf.reduce_sum(trimmed_batch_time_value_loss, axis=1),
      tf.reduce_sum(selection_mask, axis=1))
  scalar_loss = math_ops.reduce_mean(batch_value_loss, axis=-1, keepdims=True)
  weights = base_head.get_weights_and_check_match_logits(
      features=features, weight_column=self._weight_column, logits=logits)
  return scalar_loss, weights
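# Illustrative sketch: the XOR of the two sequence masks above keeps only the
# last `last_obs_len` valid timesteps of each sequence. With true_len_hr = 6,
# last_obs_len = 2 and seqlen = 8, positions 4 and 5 are selected. Values are
# made up for illustration.
import tensorflow as tf

true_len_hr = tf.constant([6])
last_obs_len = 2
seqlen = 8
true_len_mask = tf.sequence_mask(true_len_hr, seqlen)                    # True at 0..5
last_obs_len_mask = tf.sequence_mask(true_len_hr - last_obs_len, seqlen)  # True at 0..3
selection_mask = tf.cast(
    tf.logical_xor(true_len_mask, last_obs_len_mask), tf.float32)
# selection_mask == [[0., 0., 0., 0., 1., 1., 0., 0.]]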