def loss_function(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor]) -> Tuple[tf.Tensor, Dict[str, tf.Tensor]]:
    """Combine six per-residue prediction losses into one scalar objective.

    Heads scored, in order:
      * backbone angles (phi/psi), via ``self.compute_angle_loss``;
      * RSA, via ``tf.losses.mean_squared_error``;
      * disorder and interface, via ``tf.losses.sigmoid_cross_entropy``;
      * 3-state and 8-state secondary structure, via
        ``classification_loss_and_accuracy``.

    All heads except the angle head are weighted by a float mask. That mask
    is the product of the length-derived sequence mask and
    ``inputs['valid_mask']``. The angle head receives ``valid_mask`` alone.

    Args:
        inputs: Batch features and labels. Keys read: ``'primary'``,
            ``'protein_length'``, ``'valid_mask'``, ``'phi'``, ``'psi'``,
            ``'rsa'``, ``'disorder'``, ``'interface'``, ``'ss3'``, ``'ss8'``.
        outputs: Model predictions. Keys read: ``'phi_pred'``,
            ``'psi_pred'``, ``'rsa_pred'``, ``'disorder_pred'``,
            ``'interface_pred'``, ``'ss3_pred'``, ``'ss8_pred'``.

    Returns:
        A ``(loss, metrics)`` pair. ``loss`` is the plain (unweighted) sum of
        the six head losses. ``metrics`` maps ``'SS3ACC'`` and ``'SS8ACC'`` to
        the secondary-structure accuracies.
    """
    # Residues beyond each protein's length are zeroed out. So are residues
    # whose valid_mask entry is zero.
    length_mask = rk.utils.convert_sequence_length_to_sequence_mask(
        inputs['primary'], inputs['protein_length'])
    valid_weights = tf.cast(inputs['valid_mask'], tf.float32)
    residue_weights = tf.cast(length_mask, tf.float32) * valid_weights

    # NOTE(review): the angle head is weighted by valid_mask only, not by the
    # combined residue_weights used by the other heads. Presumably valid_mask
    # is already zero on padding -- confirm.
    angle_loss = self.compute_angle_loss(
        inputs['phi'], inputs['psi'],
        outputs['phi_pred'], outputs['psi_pred'],
        valid_weights)

    rsa_loss = tf.losses.mean_squared_error(
        inputs['rsa'], outputs['rsa_pred'], residue_weights)
    disorder_loss = tf.losses.sigmoid_cross_entropy(
        inputs['disorder'], outputs['disorder_pred'], residue_weights)
    interface_loss = tf.losses.sigmoid_cross_entropy(
        inputs['interface'], outputs['interface_pred'], residue_weights)

    ss3_loss, ss3_accuracy = classification_loss_and_accuracy(
        inputs['ss3'], outputs['ss3_pred'], residue_weights)
    ss8_loss, ss8_accuracy = classification_loss_and_accuracy(
        inputs['ss8'], outputs['ss8_pred'], residue_weights)

    # Summed left to right, in the same order the heads were built.
    total_loss = (angle_loss + rsa_loss + disorder_loss
                  + interface_loss + ss3_loss + ss8_loss)

    return total_loss, {'SS3ACC': ss3_accuracy, 'SS8ACC': ss8_accuracy}
def loss_function(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor]) -> Tuple[tf.Tensor, Dict[str, tf.Tensor]]:
    """Compute a masked classification loss plus accuracy, ECE and perplexity.

    The mask depends on ``self._mask_name``:
      * ``'sequence_mask'``: the mask is derived from
        ``inputs['protein_length']``;
      * any other value: the mask is read from ``outputs[self._mask_name]``.

    Args:
        inputs: Batch features. Keys read: ``self._label_name`` and, for the
            ``'sequence_mask'`` case, ``'protein_length'``.
        outputs: Model outputs. Keys read: ``self._output_name`` (the logits)
            and, for the non-``'sequence_mask'`` case, ``self._mask_name``.

    Returns:
        A ``(loss, metrics)`` pair. ``loss`` and the accuracy come from
        ``classification_loss_and_accuracy``. ``metrics`` has three entries:
          * ``self.key_metric``: the accuracy;
          * ``'ECE'``: ``exp(loss)``;
          * ``'Perplexity'``: the mask-weighted mean over positions of the
            exponentiated entropy of the softmax over the last axis.
    """
    labels = inputs[self._label_name]
    logits = outputs[self._output_name]

    if self._mask_name == 'sequence_mask':
        mask = rk.utils.convert_sequence_length_to_sequence_mask(
            labels, inputs['protein_length'])
    else:
        mask = outputs[self._mask_name]

    loss, accuracy = classification_loss_and_accuracy(labels, logits, mask)

    # If the helper's loss is the mean cross-entropy (not shown here), this is
    # the exponentiated cross-entropy against the true labels.
    exp_cross_entropy = tf.exp(loss)

    # Per-position exp(entropy) of the predicted distribution. This uses only
    # the logits, never the labels, so it measures the model's own
    # uncertainty.
    probs = tf.nn.softmax(logits)
    log_probs = tf.nn.log_softmax(logits)
    entropy = -tf.reduce_sum(probs * log_probs, -1)
    position_perplexity = tf.exp(entropy)

    # Multiplying by ones_like broadcasts the mask up to the perplexity's
    # shape, so the weight sum counts every weighted position.
    position_weights = (tf.ones_like(position_perplexity)
                        * tf.cast(mask, position_perplexity.dtype))
    weighted_total = tf.reduce_sum(position_perplexity * position_weights)
    # The small epsilon guards against division by zero when the mask is
    # all zeros.
    weight_total = tf.reduce_sum(position_weights) + 1e-10
    mean_perplexity = weighted_total / weight_total

    return loss, {
        self.key_metric: accuracy,
        'ECE': exp_cross_entropy,
        'Perplexity': mean_perplexity,
    }
def loss_function(self,
                  inputs: Dict[str, tf.Tensor],
                  outputs: Dict[str, tf.Tensor]) -> Tuple[tf.Tensor, Dict[str, tf.Tensor]]:
    """Compute an unmasked classification loss, with accuracy as the key metric.

    Args:
        inputs: Batch features. The label is read from ``inputs[self._label]``.
        outputs: Model outputs. The prediction is read from
            ``outputs[self._output_name]``.

    Returns:
        A ``(loss, metrics)`` pair from ``classification_loss_and_accuracy``,
        called without a mask. ``metrics`` maps ``self.key_metric`` to the
        accuracy.
    """
    # Arguments are evaluated left to right, so the label lookup still
    # happens before the prediction lookup, as in the original.
    loss, accuracy = classification_loss_and_accuracy(
        inputs[self._label], outputs[self._output_name])
    return loss, {self.key_metric: accuracy}
def loss_function(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor]) -> Tuple[tf.Tensor, Dict[str, tf.Tensor]]:
    """Compute a masked classification loss, reporting accuracy as the key metric.

    When ``self._mask_name`` is ``'sequence_mask'``, the mask is built from
    ``inputs['protein_length']``. Otherwise it is taken from
    ``outputs[self._mask_name]``.

    Args:
        inputs: Batch features. Keys read: ``self._label_name`` and, for the
            ``'sequence_mask'`` case, ``'protein_length'``.
        outputs: Model outputs. Keys read: ``self._output_name`` (the logits)
            and, for the non-``'sequence_mask'`` case, ``self._mask_name``.

    Returns:
        A ``(loss, metrics)`` pair from ``classification_loss_and_accuracy``.
        ``metrics`` maps ``self.key_metric`` to the accuracy.
    """
    labels = inputs[self._label_name]
    logits = outputs[self._output_name]

    # Only the selected branch is evaluated, matching the original if/else.
    mask = (
        rk.utils.convert_sequence_length_to_sequence_mask(
            labels, inputs['protein_length'])
        if self._mask_name == 'sequence_mask'
        else outputs[self._mask_name]
    )

    loss, accuracy = classification_loss_and_accuracy(labels, logits, mask)
    return loss, {self.key_metric: accuracy}