def predict_counts(self, logits, outputs=None):  # pylint: disable=unused-argument
    """Predicts counts from a sign-constrained linear map of the logits.

    The effective weights are `abs(self.affinity_weights)` scaled by
    `self.selection_signs`. The affinity portion of `logits` is multiplied by
    these weights, and `self.bias` is then added.

    Args:
      logits: LabeledTensor with dtype=float32 and axes [batch, logit_axis].
      outputs: LabeledTensor with dtype=float32 and axes [batch, output_axis].
        Not read by this implementation. It is part of the signature so that
        subclasses can use counts from earlier rounds to help predict later
        rounds. Any implementation that does read `outputs` is responsible
        for respecting the causal structure of the experiment.

    Returns:
      preds: LabeledTensor with dtype=float32 and axes [batch, target_axis].
    """
    # TODO(shoyer): consider using tf.nn.softplus instead of abs here
    # abs() makes every weight magnitude non-negative, so the sign of each
    # effective weight is determined by `selection_signs`.
    signed_weights = abs(self.affinity_weights) * self.selection_signs

    # With an additional output axis, restrict the 'target' axis to the
    # labels of `self.affinity_axis` before relabelling it as 'affinity'
    # (presumably the logits then carry extra non-affinity targets).
    if self.additional_output_axis:
      source_logits = lt.select(
          logits, {'target': list(self.affinity_axis.labels)})
    else:
      source_logits = logits
    affinity_logits = lt.rename_axis(source_logits, 'target', 'affinity')

    return lt.matmul(affinity_logits, signed_weights) + self.bias
 def predict_counts(self, logits, outputs):
   """Adds a previous-round interaction term to the parent's predictions.

   The parent class's `predict_counts(logits, outputs)` supplies the base
   predictions. The following term is then added to them:
   `matmul(affinity_logits, abs(logit_by_prev_count) * selection_signs)`
   multiplied by `self._normed_prev_round_counts(outputs)`.

   Args:
     logits: LabeledTensor of logits, as accepted by the base-class
       `predict_counts`.
     outputs: LabeledTensor of observed counts. It is forwarded to the parent
       implementation and to `self._normed_prev_round_counts`.

   Returns:
     LabeledTensor: the parent's predictions plus the interaction term.
   """
   base_preds = super(LatentAffinityWithCrossDeps, self).predict_counts(
       logits, outputs)

   # abs() keeps each magnitude non-negative; `selection_signs` sets the sign.
   signed_interact_weights = (
       abs(self.logit_by_prev_count) * self.selection_signs)

   # Restrict the 'target' axis to the affinity labels when an additional
   # output axis is present, then relabel it as 'affinity'.
   if self.additional_output_axis:
     restricted_logits = lt.select(
         logits, {'target': list(self.affinity_axis.labels)})
   else:
     restricted_logits = logits
   affinity_logits = lt.rename_axis(restricted_logits, 'target', 'affinity')

   projected = lt.matmul(affinity_logits, signed_interact_weights)
   # The parent implementation may already have built
   # `_normed_prev_round_counts(outputs)` with the same arguments. Per the
   # original author, TensorFlow consolidates these duplicate calls, so
   # recomputing the value here is acceptable.
   prev_round_counts = self._normed_prev_round_counts(outputs)
   return base_preds + projected * prev_round_counts
  def predict_affinity(self, logits):
    """Aggregates target logits into one value per affinity (protein).

    Args:
      logits: LabeledTensor with a 'target' axis. See the base-class method
        for the full contract.

    Returns:
      The LabeledTensor produced by
      `lt.matmul(count_logits, self.affinity_target_lt)`. For a protein that
      maps to several targets, the values of those targets are summed.

    Raises:
      Error: if this layer has no affinity_target_map
        (`self.affinity_target_lt` is None).
    """
    # Compare against None explicitly. The truth value of a (labeled) tensor
    # is not a meaningful presence test.
    # NOTE(review): assumes the constructor stores None when no
    # affinity_target_map is supplied -- confirm against __init__.
    if self.affinity_target_lt is None:
      raise Error(
          'No affinity_target_map has been designated. This FullyObserved '
          'layer cannot calculate the affinity. The FullyObserved layer '
          'must be initialized with an affinity_target_map to be capable '
          'of calculating affinity.')

    # With an additional output axis, keep only the labels of
    # `self.target_axis`. Presumably this makes the 'target' axis match the
    # one in affinity_target_lt.
    if self.additional_output_axis:
      count_logits = lt.select(logits,
                               {'target': list(self.target_axis.labels)})
    else:
      count_logits = logits

    # Matrix-multiply (target) x (target by protein) to get one value per
    # protein. Proteins linked to multiple targets receive the sum of those
    # targets' values.
    return lt.matmul(count_logits, self.affinity_target_lt)