예제 #1
0
    def loss_function(
        self, inputs: Dict[str, tf.Tensor],
        outputs: Dict[str,
                      tf.Tensor]) -> Tuple[tf.Tensor, Dict[str, tf.Tensor]]:
        """Sum the per-property losses of this multi-task model.

        Combines an angle loss on phi/psi (``self.compute_angle_loss``), a
        mean-squared-error loss on ``rsa``, sigmoid cross-entropy losses on
        ``disorder`` and ``interface``, and classification losses on ``ss3``
        and ``ss8``.

        Args:
            inputs: Feature dict. Keys read: 'primary', 'protein_length',
                'valid_mask', 'phi', 'psi', 'rsa', 'disorder', 'interface',
                'ss3', 'ss8'.
            outputs: Model output dict. Keys read: 'phi_pred', 'psi_pred',
                'rsa_pred', 'disorder_pred', 'interface_pred', 'ss3_pred',
                'ss8_pred'.

        Returns:
            A ``(loss, metrics)`` pair. ``loss`` is the plain (unweighted)
            sum of the six task losses. ``metrics`` maps 'SS3ACC' and
            'SS8ACC' to the accuracies returned for ss3 and ss8.
        """
        length_mask = rk.utils.convert_sequence_length_to_sequence_mask(
            inputs['primary'], inputs['protein_length'])
        valid = tf.cast(inputs['valid_mask'], tf.float32)
        # A position contributes only if it lies within the protein length
        # AND is flagged valid. Note the angle loss receives ``valid`` alone.
        weights = tf.cast(length_mask, tf.float32) * valid

        angle_loss = self.compute_angle_loss(
            inputs['phi'], inputs['psi'],
            outputs['phi_pred'], outputs['psi_pred'],
            valid)
        rsa_loss = tf.losses.mean_squared_error(
            inputs['rsa'], outputs['rsa_pred'], weights)

        # Binary per-position targets, evaluated in the original order.
        disorder_loss, interface_loss = [
            tf.losses.sigmoid_cross_entropy(
                inputs[label_key], outputs[pred_key], weights)
            for label_key, pred_key in (('disorder', 'disorder_pred'),
                                        ('interface', 'interface_pred'))]

        # Secondary-structure classification heads (3- and 8-class labels).
        (ss3_loss, ss3_acc), (ss8_loss, ss8_acc) = [
            classification_loss_and_accuracy(
                inputs[label_key], outputs[pred_key], weights)
            for label_key, pred_key in (('ss3', 'ss3_pred'),
                                        ('ss8', 'ss8_pred'))]

        total = (angle_loss + rsa_loss + disorder_loss + interface_loss
                 + ss3_loss + ss8_loss)
        return total, {'SS3ACC': ss3_acc, 'SS8ACC': ss8_acc}
    def loss_function(
        self, inputs: Dict[str, tf.Tensor],
        outputs: Dict[str,
                      tf.Tensor]) -> Tuple[tf.Tensor, Dict[str, tf.Tensor]]:
        """Masked classification loss with accuracy, ECE and perplexity metrics.

        The mask comes from one of two places. When ``self._mask_name`` is
        'sequence_mask', it is built from ``inputs['protein_length']``.
        Otherwise it is read from ``outputs[self._mask_name]``.

        Args:
            inputs: Feature dict. Reads ``self._label_name`` and, for the
                length-derived mask, 'protein_length'.
            outputs: Model output dict. Reads ``self._output_name`` (logits)
                and, when not using the length-derived mask,
                ``self._mask_name``.

        Returns:
            A ``(loss, metrics)`` pair. ``metrics`` contains three entries:
            the accuracy under ``self.key_metric``; 'ECE', equal to
            ``exp(loss)``; and 'Perplexity', the mask-weighted mean over
            positions of ``exp(entropy)`` of the predicted softmax
            distribution.
        """
        labels = inputs[self._label_name]
        logits = outputs[self._output_name]
        if self._mask_name == 'sequence_mask':
            mask = rk.utils.convert_sequence_length_to_sequence_mask(
                labels, inputs['protein_length'])
        else:
            mask = outputs[self._mask_name]

        loss, accuracy = classification_loss_and_accuracy(labels, logits, mask)
        ece = tf.exp(loss)

        # NOTE: this "perplexity" is exp(entropy) of the model's own predicted
        # distribution over the last axis; it does not use the labels.
        probs = tf.nn.softmax(logits)
        log_probs = tf.nn.log_softmax(logits)
        entropy = -tf.reduce_sum(probs * log_probs, axis=-1)
        token_perplexity = tf.exp(entropy)

        token_weights = tf.ones_like(token_perplexity) * tf.cast(
            mask, token_perplexity.dtype)
        # The epsilon keeps the division finite when every position is masked.
        mean_perplexity = tf.reduce_sum(token_perplexity * token_weights) / (
            tf.reduce_sum(token_weights) + 1e-10)

        return loss, {
            self.key_metric: accuracy,
            'ECE': ece,
            'Perplexity': mean_perplexity,
        }
예제 #3
0
    def loss_function(self,
                      inputs: Dict[str, tf.Tensor],
                      outputs: Dict[str, tf.Tensor]) -> Tuple[tf.Tensor, Dict[str, tf.Tensor]]:
        """Classification loss and accuracy, computed without a mask.

        Args:
            inputs: Feature dict; the label is read from ``inputs[self._label]``.
            outputs: Model output dict; predictions are read from
                ``outputs[self._output_name]``.

        Returns:
            A ``(loss, metrics)`` pair. ``metrics`` maps ``self.key_metric``
            to the accuracy returned by ``classification_loss_and_accuracy``.
        """
        loss, accuracy = classification_loss_and_accuracy(
            inputs[self._label], outputs[self._output_name])
        return loss, {self.key_metric: accuracy}
예제 #4
0
파일: Task.py 프로젝트: nitrogenase/TAPE
 def loss_function(
     self, inputs: Dict[str, tf.Tensor],
     outputs: Dict[str,
                   tf.Tensor]) -> Tuple[tf.Tensor, Dict[str, tf.Tensor]]:
     """Masked classification loss with accuracy as the key metric.

     The mask comes from one of two places. When ``self._mask_name`` is
     'sequence_mask', it is derived from ``inputs['protein_length']``.
     Otherwise it is taken from ``outputs[self._mask_name]``.

     Returns:
         A ``(loss, metrics)`` pair. ``metrics`` maps ``self.key_metric``
         to the accuracy returned by ``classification_loss_and_accuracy``.
     """
     labels = inputs[self._label_name]
     logits = outputs[self._output_name]
     use_length_mask = self._mask_name == 'sequence_mask'
     mask = (rk.utils.convert_sequence_length_to_sequence_mask(
         labels, inputs['protein_length'])
             if use_length_mask else outputs[self._mask_name])
     loss, accuracy = classification_loss_and_accuracy(labels, logits, mask)
     return loss, {self.key_metric: accuracy}