def predict_counts(self, logits, outputs=None):  # pylint: disable=unused-argument
  """Turn affinity logits into count predictions.

  The prediction is `matmul(affinity_logits, weights) + bias`, where the
  weights are the absolute affinity weights multiplied by the selection signs.

  Args:
    logits: LabeledTensor with dtype=float32 and axes [batch, logit_axis].
    outputs: LabeledTensor with dtype=float32 and axes [batch, output_axis].
      Not used by this implementation. It is part of the signature so that
      subclasses can use counts from previous rounds to help predict future
      rounds. Any implementation that does use `outputs` is responsible for
      ensuring that it respects the causal structure of the experiment.

  Returns:
    preds: LabeledTensor with dtype=float32 and axes [batch, target_axis].
  """
  # TODO(shoyer): consider using tf.nn.softplus instead of abs here
  signed_weights = abs(self.affinity_weights) * self.selection_signs

  # When an additional output axis is present, only the entries of the
  # 'target' axis whose labels belong to the affinity axis are kept.
  if self.additional_output_axis:
    target_logits = lt.select(
        logits, {'target': list(self.affinity_axis.labels)})
  else:
    target_logits = logits
  affinity_logits = lt.rename_axis(target_logits, 'target', 'affinity')

  return lt.matmul(affinity_logits, signed_weights) + self.bias
def predict_counts(self, logits, outputs):
  """See method on base class.

  Adds an interaction term to the parent-class prediction: the affinity
  logits projected through signed `logit_by_prev_count` weights, multiplied
  by the normalized counts from the previous round.
  """
  counts = super(LatentAffinityWithCrossDeps, self).predict_counts(
      logits, outputs)
  cross_weights = abs(self.logit_by_prev_count) * self.selection_signs

  # When an additional output axis is present, only the entries of the
  # 'target' axis whose labels belong to the affinity axis are kept.
  if self.additional_output_axis:
    target_logits = lt.select(
        logits, {'target': list(self.affinity_axis.labels)})
  else:
    target_logits = logits
  affinity_logits = lt.rename_axis(target_logits, 'target', 'affinity')

  interaction = lt.matmul(affinity_logits, cross_weights)
  # NOTE: per the original author, _normed_prev_round_counts is being called
  # a second time here with the same arguments; that is considered OK because
  # TensorFlow automatically consolidates these calls.
  interaction = interaction * self._normed_prev_round_counts(outputs)

  counts += interaction
  return counts
def predict_affinity(self, logits):
  """See method on base class.

  Raises:
    Error: if `self.affinity_target_lt` is falsy, i.e. no target-to-affinity
      mapping is available to this layer.
  """
  if not self.affinity_target_lt:
    raise Error(
        'No affinity_target_map has been designated. '
        'This FullyObserved layer cannot calculate the affinity. '
        'The FullyObserved layer must be initialized with an '
        'affinity_target_map to be capable of calculating affinity.')

  # When an additional output axis is present, restrict the logits to the
  # labels of the target axis before projecting.
  count_logits = (
      lt.select(logits, {'target': list(self.target_axis.labels)})
      if self.additional_output_axis else logits)

  # Matrix multiply (target) x (target by protein) to get a vector of length
  # protein. For proteins with multiple targets, the multiplication takes
  # the sum of the values.
  return lt.matmul(count_logits, self.affinity_target_lt)