Ejemplo n.º 1
0
  def predict_counts(self, logits, outputs=None):  # pylint: disable=unused-argument
    """Predict counts for each target from the logits.

    Args:
      logits: LabeledTensor with dtype=float32 and axes [batch, logit_axis].
      outputs: LabeledTensor with dtype=float32 and axes [batch, output_axis].
        Ignored by this implementation; it is part of the signature so that
        subclasses can use counts from earlier rounds when predicting later
        rounds. Any implementation that reads `outputs` is responsible for
        respecting the causal structure of the experiment.

    Returns:
      preds: LabeledTensor with dtype=float32 and axes [batch, target_axis].
    """
    # TODO(shoyer): consider using tf.nn.softplus instead of abs here
    # abs() makes the weight magnitudes non-negative; selection_signs presumably
    # supplies the sign of each weight (confirm against where it is built).
    signed_weights = abs(self.affinity_weights) * self.selection_signs

    # With additional (non-affinity) targets present, keep only the affinity
    # labels before treating the 'target' axis as the 'affinity' axis.
    source_logits = logits
    if self.additional_output_axis:
      source_logits = lt.select(
          logits, {'target': list(self.affinity_axis.labels)})
    affinity_logits = lt.rename_axis(source_logits, 'target', 'affinity')

    return lt.matmul(affinity_logits, signed_weights) + self.bias
Ejemplo n.º 2
0
  def predict_outputs(self, logits, outputs=None):
    """Predict a score that should correlate with each output.

    Args:
      logits: LabeledTensor with dtype=float32 and axes [batch, logit_axis].
      outputs: optional LabeledTensor with dtype=float32 and axes [batch,
        output_axis]. Different output layers may not be directly comparable
        if they make use of `outputs` from prior rounds of selection in their
        predictions.

    Returns:
      LabeledTensor with dtype=float32 and axes [batch, output_axis] giving
      predictions for each count and binding array.
    """
    # Count predictions come first along the 'output' axis.
    preds = lt.rename_axis(
        self.predict_counts(logits, outputs), 'target', 'output')

    if self.binding_arrays_map:
      # Each binding array is predicted by the affinity it maps to.
      affinity = self.predict_affinity(logits)
      per_array = [
          lt.select(affinity, {'affinity': affinity_label})
          for affinity_label in self.binding_arrays_map.values()
      ]
      array_names = list(self.binding_arrays_map.keys())
      binding_preds = lt.pack(
          per_array, ('output', array_names), axis_position=1)
      preds = lt.concat([preds, binding_preds], 'output')

    if self.additional_output_axis:
      extra_preds = lt.rename_axis(
          self.predict_additional_output(logits), 'target', 'output')
      preds = lt.concat([preds, extra_preds], 'output')

    return preds
Ejemplo n.º 3
0
  def affinity_loss_per_example_and_target(self, logits, outputs):
    """Calculate the per-example loss of the affinity predictions.

    Calls `predict_affinity`, which is presumably implemented by the current
    output layer, maps those affinities onto the binding arrays, and compares
    them against the binding-array columns of `outputs`.

    Args:
      logits: LabeledTensor with dtype=float32 and axes [batch, logit_axis].
      outputs: LabeledTensor with dtype=float32 and axes [batch, output_axis].
        These outputs should include everything from the preprocessing,
        whether or not it is used in the loss.

    Returns:
      LabeledTensor with dtype=float32 and axes [batch, target_axis] giving
      loss for each target.
    """
    mapped_affinity = _affinities_to_binding_arrays(
        self.binding_arrays_map, self.predict_affinity(logits))
    predicted = lt.rename_axis(mapped_affinity, 'output', 'target')

    # Only the binding-array columns of `outputs` participate in this loss.
    array_names = list(self.binding_arrays_map.keys())
    observed = lt.select(outputs, {'output': array_names})
    observed = lt.rename_axis(observed, 'output', 'target')

    return self.loss.per_example_and_target_array(predicted, observed)
Ejemplo n.º 4
0
 def predict_affinity(self, logits):
   """See method on base class.

   Relabels the 'target' axis of `logits` as 'affinity'. When
   `additional_output_axis` is set, only the labels of `affinity_axis` are
   kept first.
   """
   chosen = logits
   if self.additional_output_axis:
     chosen = lt.select(logits, {'target': list(self.affinity_axis.labels)})
   return lt.rename_axis(chosen, 'target', 'affinity')
Ejemplo n.º 5
0
def _stack_inputs_by_rank(inputs):
    """Create 2D and 3D input tensors from a dictionary of inputs.

  3D inputs are stacked together for use in (optional) convolutional layers.
  2D inputs are only used in fully-connected layers.

  Args:
    inputs: Dict[str, lt.LabeledTensor] providing input features. All features
      must be 2D or 3D labeled tensors with a 'batch' axis as their first
      dimension. 3D tensors must have 'position' as their second axis. The last
      axis of all tensors is allowed to vary, because raw input features may
      have different names for labels that are more meaningful than generic
      "features" or "channels".

  Returns:
    Tuple[Optional[lt.LabeledTensor], Optional[lt.LabeledTensor]], where the
    first labeled tensor, if present, has axes ['batch', 'feature'] and the
    second labeled tensor, if present, has axes ['batch', 'position',
    'channel'].

  Raises:
    ValueError: if the result tensors do not have the same batch axis.
  """
    inputs_2d = []
    inputs_3d = []
    for key in sorted(inputs):
        # outputs should be fixed across randomized dict iteration order
        tensor = inputs[key]
        if len(tensor.axes) == 2:
            tensor = lt.rename_axis(tensor,
                                    list(tensor.axes.keys())[-1], 'feature')
            inputs_2d.append(tensor)
        elif len(tensor.axes) == 3:
            assert list(tensor.axes.values())[1].name == 'position'
            tensor = lt.rename_axis(tensor,
                                    list(tensor.axes.keys())[-1], 'channel')
            inputs_3d.append(tensor)
        else:
            raise AssertionError('unexpected rank')

    combined_2d = lt.concat(inputs_2d, 'feature') if inputs_2d else None
    combined_3d = lt.concat(inputs_3d, 'channel') if inputs_3d else None
    if combined_2d is not None and combined_3d is not None:
        if list(combined_2d.axes.values())[0] != list(
                combined_2d.axes.values())[0]:
            raise ValueError('mismatched batch axis')
    return combined_2d, combined_3d
Ejemplo n.º 6
0
 def predict_counts(self, logits, outputs):
   """See method on base class.

   Adds to the parent class's count predictions an interaction term: the
   affinity logits multiplied by sign-constrained cross-round weights, scaled
   by `self._normed_prev_round_counts(outputs)`.
   """
   preds = super(LatentAffinityWithCrossDeps, self).predict_counts(logits,
                                                                   outputs)
   # abs() makes magnitudes non-negative; selection_signs presumably sets the
   # sign of each interaction weight.
   cross_weights = abs(self.logit_by_prev_count) * self.selection_signs

   source_logits = logits
   if self.additional_output_axis:
     source_logits = lt.select(
         logits, {'target': list(self.affinity_axis.labels)})
   affinity_logits = lt.rename_axis(source_logits, 'target', 'affinity')

   # NOTE(review): per the original author, _normed_prev_round_counts is
   # presumably also called (with the same arguments) in the parent's
   # predict_counts; they state TensorFlow consolidates the duplicate call.
   interaction = (lt.matmul(affinity_logits, cross_weights) *
                  self._normed_prev_round_counts(outputs))
   preds += interaction
   return preds