def predict_counts(self, logits, outputs=None):  # pylint: disable=unused-argument
  """Make count predictions from logits and counts.

  Args:
    logits: LabeledTensor with dtype=float32 and axes [batch, logit_axis].
    outputs: LabeledTensor with dtype=float32 and axes [batch, output_axis].
      Unused by the base class but in the signature for the benefit of
      subclasses that use counts from previous rounds to help predict future
      rounds. It is the responsibility of the implementation using `outputs`
      to ensure that this method respects the causal structure of the
      experiment.

  Returns:
    preds: LabeledTensor with dtype=float32 and axes [batch, target_axis].
  """
  # TODO(shoyer): consider using tf.nn.softplus instead of abs here
  # Constrain weight magnitudes to be non-negative, with a fixed sign per
  # selection round supplied by selection_signs.
  weights = abs(self.affinity_weights) * self.selection_signs
  if self.additional_output_axis:
    # Keep only the affinity targets before projecting to counts; the
    # additional-output targets are handled elsewhere.
    affinity_logits = lt.rename_axis(
        lt.select(logits, {'target': list(self.affinity_axis.labels)}),
        'target', 'affinity')
  else:
    affinity_logits = lt.rename_axis(logits, 'target', 'affinity')
  preds = lt.matmul(affinity_logits, weights) + self.bias
  return preds
def predict_outputs(self, logits, outputs=None):
  """Predict a score that should correlate with each output.

  Args:
    logits: LabeledTensor with dtype=float32 and axes [batch, logit_axis].
    outputs: optional LabeledTensor with dtype=float32 and axes
      [batch, output_axis]. Note that different output layers may not be
      directly comparable if they make use of `outputs` from prior rounds of
      selection in predictions.

  Returns:
    LabeledTensor with dtype=float32 and axes [batch, output_axis] giving
    predictions for each count and binding array.
  """
  counts = lt.rename_axis(
      self.predict_counts(logits, outputs), 'target', 'output')
  preds = counts
  if self.binding_arrays_map:
    affinity = self.predict_affinity(logits)
    # Pull out one affinity column per binding array, labeled by array name.
    per_array = [
        lt.select(affinity, {'affinity': name})
        for name in self.binding_arrays_map.values()
    ]
    binding_arrays = lt.pack(
        per_array, ('output', list(self.binding_arrays_map.keys())),
        axis_position=1)
    preds = lt.concat([counts, binding_arrays], 'output')
  if self.additional_output_axis:
    additional = lt.rename_axis(
        self.predict_additional_output(logits), 'target', 'output')
    preds = lt.concat([preds, additional], 'output')
  return preds
def affinity_loss_per_example_and_target(self, logits, outputs):
  """Calculate loss per example on predicting affinity.

  Relies on `predict_affinity`, which is presumed to be implemented by the
  concrete output layer, and scores its predictions against the binding
  array measurements in `outputs`.

  Args:
    logits: LabeledTensor with dtype=float32 and axes [batch, logit_axis].
    outputs: LabeledTensor with dtype=float32 and axes [batch, output_axis].
      These outputs should include everything from the preprocessing,
      whether it is used in the loss or not.

  Returns:
    LabeledTensor with dtype=float32 and axes [batch, target_axis] giving
    loss for each target.
  """
  binding_names = list(self.binding_arrays_map.keys())
  predicted = lt.rename_axis(
      _affinities_to_binding_arrays(self.binding_arrays_map,
                                    self.predict_affinity(logits)),
      'output', 'target')
  measured = lt.rename_axis(
      lt.select(outputs, {'output': binding_names}), 'output', 'target')
  return self.loss.per_example_and_target_array(predicted, measured)
def predict_affinity(self, logits):
  """See method on base class."""
  affinity_only = logits
  if self.additional_output_axis:
    # Drop the additional-output targets so only affinity targets remain.
    affinity_only = lt.select(
        logits, {'target': list(self.affinity_axis.labels)})
  return lt.rename_axis(affinity_only, 'target', 'affinity')
def _stack_inputs_by_rank(inputs):
  """Create 2D and 3D input tensors from a dictionary of inputs.

  3D inputs are stacked together for use in (optional) convolutional layers.
  2D inputs are only used in fully-connected layers.

  Args:
    inputs: Dict[str, lt.LabeledTensor] providing input features. All features
      must be 2D or 3D labeled tensors with a 'batch' axis as their first
      dimension. 3D tensors must have 'position' as their second axis. The
      last axis of all tensors is allowed to vary, because raw input features
      may have different names for labels that are more meaningful than
      generic "features" or "channels".

  Returns:
    Tuple[Optional[lt.LabeledTensor], Optional[lt.LabeledTensor]], where the
    first labeled tensor, if present, has axes ['batch', 'feature'] and the
    second labeled tensor, if present, has axes
    ['batch', 'position', 'channel'].

  Raises:
    ValueError: if the result tensors do not have the same batch axis.
  """
  inputs_2d = []
  inputs_3d = []
  # Iterate in sorted order so outputs are fixed across randomized dict
  # iteration order.
  for key in sorted(inputs):
    tensor = inputs[key]
    if len(tensor.axes) == 2:
      tensor = lt.rename_axis(tensor, list(tensor.axes.keys())[-1], 'feature')
      inputs_2d.append(tensor)
    elif len(tensor.axes) == 3:
      assert list(tensor.axes.values())[1].name == 'position'
      tensor = lt.rename_axis(tensor, list(tensor.axes.keys())[-1], 'channel')
      inputs_3d.append(tensor)
    else:
      raise AssertionError('unexpected rank')

  combined_2d = lt.concat(inputs_2d, 'feature') if inputs_2d else None
  combined_3d = lt.concat(inputs_3d, 'channel') if inputs_3d else None

  if combined_2d is not None and combined_3d is not None:
    # Bug fix: the original compared combined_2d's batch axis to itself,
    # which made this check a no-op. Compare 2D vs 3D batch axes instead.
    if list(combined_2d.axes.values())[0] != list(
        combined_3d.axes.values())[0]:
      raise ValueError('mismatched batch axis')
  return combined_2d, combined_3d
def predict_counts(self, logits, outputs):
  """See method on base class.

  Extends the base-class count predictions with an interaction term between
  affinity logits and (normalized) counts from the previous round.

  Args:
    logits: LabeledTensor with dtype=float32 and axes [batch, logit_axis].
    outputs: LabeledTensor with dtype=float32 and axes [batch, output_axis],
      providing counts from previous rounds for the interaction term.

  Returns:
    LabeledTensor with dtype=float32 and axes [batch, target_axis].
  """
  preds = super(LatentAffinityWithCrossDeps, self).predict_counts(
      logits, outputs)
  # Same signed-magnitude convention as the base class's affinity weights.
  interact_weights = abs(self.logit_by_prev_count) * self.selection_signs
  # Reuse predict_affinity instead of duplicating the identical
  # select/rename of affinity logits inline.
  affinity_logits = self.predict_affinity(logits)
  # We're calling _normed_prev_round_counts a second time here with the same
  # arguments, but that's actually OK because TensorFlow automatically
  # consolidates these calls.
  preds += (lt.matmul(affinity_logits, interact_weights) *
            self._normed_prev_round_counts(outputs))
  return preds