def predict_outputs(self, logits, outputs=None):
    """Predict a score that should correlate with each output.

    Args:
      logits: LabeledTensor with dtype=float32 and axes [batch, logit_axis].
      outputs: optional LabeledTensor with dtype=float32 and axes [batch,
        output_axis]. Note that different output layers may not be directly
        comparable if they make use of `outputs` from prior rounds of selection
        in their predictions.

    Returns:
      LabeledTensor with dtype=float32 and axes [batch, output_axis] giving
      predictions for each count and binding array.
    """
    predicted_counts = lt.rename_axis(
        self.predict_counts(logits, outputs), 'target', 'output')

    if self.binding_arrays_map:
      predicted_affinity = self.predict_affinity(logits)
      predicted_binding_arrays = lt.pack([
          lt.select(predicted_affinity, {'affinity': target})
          for target in self.binding_arrays_map.values()
      ], ('output', list(self.binding_arrays_map.keys())),
                                         axis_position=1)
      preds = lt.concat([predicted_counts, predicted_binding_arrays], 'output')
    else:
      preds = predicted_counts

    if self.additional_output_axis:
      predicted_additional_output = lt.rename_axis(
          self.predict_additional_output(logits), 'target', 'output')
      preds = lt.concat([preds, predicted_additional_output], 'output')
    return preds
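
A minimal usage sketch (the `model` instance, tensor sizes, and axis labels here are hypothetical; `lt` is the TensorFlow labeled_tensor module used throughout):

logits = lt.LabeledTensor(
    tf.zeros([8, 4]),
    [('batch', 8), ('logit', ['round1', 'round2', 'array_a', 'array_b'])])
preds = model.predict_outputs(logits)
# preds has axes [batch, output]: count predictions first, followed by any
# binding-array predictions and any additional outputs.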
def _stack_inputs_by_rank(inputs):
    """Create 2D and 3D input tensors from a dictionary of inputs.

    3D inputs are stacked together for use in (optional) convolutional layers.
    2D inputs are only used in fully-connected layers.

    Args:
      inputs: Dict[str, lt.LabeledTensor] providing input features. All
        features must be 2D or 3D labeled tensors with a 'batch' axis as their
        first dimension. 3D tensors must have 'position' as their second axis.
        The name of the last axis may vary, because raw input features often
        use labels that are more meaningful than the generic 'feature' or
        'channel'.

    Returns:
      Tuple[Optional[lt.LabeledTensor], Optional[lt.LabeledTensor]], where the
      first labeled tensor, if present, has axes ['batch', 'feature'] and the
      second labeled tensor, if present, has axes ['batch', 'position',
      'channel'].

    Raises:
      ValueError: if the result tensors do not have the same batch axis.
    """
    inputs_2d = []
    inputs_3d = []
    for key in sorted(inputs):
        # sort keys so outputs are deterministic across dict iteration orders
        tensor = inputs[key]
        if len(tensor.axes) == 2:
            tensor = lt.rename_axis(tensor,
                                    list(tensor.axes.keys())[-1], 'feature')
            inputs_2d.append(tensor)
        elif len(tensor.axes) == 3:
            assert list(tensor.axes.values())[1].name == 'position'
            tensor = lt.rename_axis(tensor,
                                    list(tensor.axes.keys())[-1], 'channel')
            inputs_3d.append(tensor)
        else:
            raise AssertionError('unexpected rank')

    combined_2d = lt.concat(inputs_2d, 'feature') if inputs_2d else None
    combined_3d = lt.concat(inputs_3d, 'channel') if inputs_3d else None
    if combined_2d is not None and combined_3d is not None:
        # both result tensors must share the same batch axis
        if list(combined_2d.axes.values())[0] != list(
                combined_3d.axes.values())[0]:
            raise ValueError('mismatched batch axis')
    return combined_2d, combined_3d
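
For illustration, a sketch of how _stack_inputs_by_rank merges mixed-rank features (the feature names and sizes here are hypothetical):

inputs = {
    'sequence_one_hot': lt.LabeledTensor(
        tf.zeros([8, 40, 4]),
        [('batch', 8), ('position', 40), ('base', ['A', 'C', 'G', 'T'])]),
    'kmer_counts': lt.LabeledTensor(
        tf.zeros([8, 256]), [('batch', 8), ('kmer', 256)]),
}
combined_2d, combined_3d = _stack_inputs_by_rank(inputs)
# combined_2d has axes ['batch', 'feature'] (the 'kmer' axis renamed);
# combined_3d has axes ['batch', 'position', 'channel'] ('base' renamed).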
  def loss_per_example_and_target(self, logits, outputs, include_array=True):
    """See method on base class."""
    targets = _targets_from_outputs(outputs, self.logit_axis)
    loss = self.loss.per_example_and_target(logits, targets)
    if bool(set(self.binding_arrays_map.keys()) &
            set(outputs.axes['output'].labels)) and include_array:
      affinity_loss = self.affinity_loss_per_example_and_target(
          logits, outputs)
      return lt.concat([loss, affinity_loss], 'target')
    else:
      return loss
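
The set-intersection guard above computes affinity loss only when `outputs` actually carries binding-array columns; schematically, with hypothetical labels:

binding_array_names = {'array_a', 'array_b'}     # binding_arrays_map keys
output_labels = {'round1', 'round2', 'array_a'}  # outputs.axes['output'].labels
include_affinity = bool(binding_array_names & output_labels)  # True here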
  def loss_per_example_and_target(self, logits, outputs, include_array=True):
    """See method on base class."""
    with tf.name_scope('predictions'):
      if self.additional_output_axis:
        affinity_logits = lt.select(
            logits, {'target': list(self.affinity_axis.labels)})
        ao_logits = lt.select(
            logits, {'target': list(self.additional_output_axis.labels)})
        count_preds = self.predict_counts(affinity_logits, outputs)
        preds = lt.concat([count_preds, ao_logits], 'target')
      else:
        preds = self.predict_counts(logits, outputs)
    targets = _targets_from_outputs(outputs, self.all_target_axis)
    loss = self.loss.per_example_and_target(preds, targets)
    if bool(set(self.binding_arrays_map.keys()) &
            set(outputs.axes['output'].labels)) and include_array:
      affinity_loss = self.affinity_loss_per_example_and_target(
          logits, outputs)
      return lt.concat([loss, affinity_loss], 'target')
    else:
      return loss
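
This second variant first slices the logits by label along the 'target' axis, so count predictions only see the affinity logits. A schematic of the lt.select call with hypothetical labels:

# keep only the named columns of a [batch, target] LabeledTensor
affinity_logits = lt.select(logits, {'target': ['round1', 'round2']})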
  def average_loss_per_target(self, logits, outputs, include_array=True):
    """Calculate averaged over examples.

    This is the loss to use for training. If affinity loss is calculated and
    "include_array" is set to True, the count loss for the novel sequences
    included in the microarray and the affinity loss for the sequences not
    included in the microarray are excluded from the average loss calculation.
    Otherwise, return the average count loss over all samples.

    Args:
      logits: LabeledTensor with dtype=float32 and axes [batch, logit_axis].
      outputs: LabeledTensor with dtype=float32 and axes [batch, output_axis].
      include_array: Optional boolean variable indicating whether to also
                     compute affinity loss against binding array data.

    Returns:
      LabeledTensor with dtype=float32 and axes [output_axis].
    """
    # should be independent of mini-batch size
    loss_matrix = self.loss_per_example_and_target(logits,
                                                   outputs,
                                                   include_array)

    if bool(set(self.binding_arrays_map.keys()) &
            set(outputs.axes['output'].labels)) and include_array:
      count_loss = lt.select(loss_matrix,
                             {'target': list(self.target_axis.labels)})
      # Only the count loss for the samples with at least one non-zero
      # count output will be kept.
      loss_matrix_keep_idx = lt.reduce_any(
          lt.not_equal(
              lt.select(outputs, {'output': list(self.target_axis.labels)}),
              0.0), 'output')
      loss_matrix_keep = lt.boolean_mask(count_loss, loss_matrix_keep_idx)
      reduce_loss_matrix = utils.reduce_nanmean(loss_matrix_keep, 'batch')

      affinity_loss = lt.select(
          loss_matrix, {'target': list(self.binding_arrays_map.keys())})
      # Only the affinity loss for the samples with at least one non-zero
      # affinity output will be kept.
      affinity_loss_keep_idx = lt.reduce_any(
          lt.not_equal(
              lt.select(outputs,
                        {'output': list(self.binding_arrays_map.keys())}), 0.0),
          'output')
      affinity_loss_keep = lt.boolean_mask(affinity_loss,
                                           affinity_loss_keep_idx)
      reduce_affinity_loss = utils.reduce_nanmean(affinity_loss_keep, 'batch')
      # Count loss and affinity loss are concatenated.
      avg_loss = lt.concat([reduce_loss_matrix, reduce_affinity_loss],
                           'target')

      # Only the additional output loss for the samples with at least one
      # non-zero output value will be kept.
      if self.additional_output_axis:
        ao_labels = list(self.additional_output_axis.labels)
        ao_loss = lt.select(loss_matrix, {'target': ao_labels})
        ao_loss_keep_idx = lt.reduce_any(
            lt.not_equal(lt.select(outputs, {'output': ao_labels}), 0.0),
            'output')
        ao_loss_keep = lt.boolean_mask(ao_loss, ao_loss_keep_idx)
        reduce_ao_loss = utils.reduce_nanmean(ao_loss_keep, 'batch')
        avg_loss = lt.concat([avg_loss, reduce_ao_loss], 'target')

    else:
      avg_loss = utils.reduce_nanmean(loss_matrix, 'batch')

    return avg_loss
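
A NumPy analogue of the mask-then-average pattern above (illustrative values only):

import numpy as np

loss = np.array([[0.5, 0.2], [0.1, 0.3], [0.4, 0.6]])  # [batch, target]
outputs = np.array([[0., 0.], [1., 0.], [0., 2.]])     # [batch, output]
keep = (outputs != 0).any(axis=1)     # drop examples with all-zero outputs
avg = np.nanmean(loss[keep], axis=0)  # average over the surviving batch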
def preprocess(strs,
               experiment_proto,
               input_features=(SEQUENCE_ONE_HOT, ),
               mode=PREPROCESS_SKIP_ALL_ZERO_COUNTS,
               kmer_k_max=4,
               ratio_random_dna=1,
               total_reads_defining_positive=0,
               additional_output=None):
    """Build a small TF graph to preprocess a minibatch of tf.Example protos.

    Args:
      strs: LabeledTensor holding a minibatch of serialized tf.Example protos.
      experiment_proto: selection_pb2.Experiment describing the experiment.
      input_features: optional sequence of feature constants defined in this
        module.
      mode: optional preprocess mode defined in this module.
      kmer_k_max: optional integer giving the maximum kmer length to use if
        SEQUENCE_KMER_COUNT is in `input_features`.
      ratio_random_dna: optional ratio of random sequences to inject if mode ==
        PREPROCESS_INJECT_RANDOM_SEQUENCES.
      total_reads_defining_positive: optional integer indicating the sum of all
        read counts required to be seen to classify the tensor as a "positive"
        example when balancing input classes.
      additional_output: optional list of strings naming additional outputs.

    Returns:
      inputs: LabeledTensor with dtype=float32 and axes
        [batch_axis, input_position_axis, input_channel_axis], of
        one-hot-encoded rasterized sequences for input into machine learning
        models.
      outputs: LabeledTensor with dtype=float32 and axes
        [batch_axis, output_axis] denoting possible output tensors, including
        counts and binding array measurements.
    """
    with tf.name_scope('preprocess'):
        features = build_features(experiment_proto)
        parsed_feature_tensors = lt.parse_example(strs, features)
        count_names = selection.all_count_names(experiment_proto)

        if mode == PREPROCESS_SKIP_ALL_ZERO_COUNTS:
            skip_all_zero_counts = True
            feature_tensors = parsed_feature_tensors

        elif mode == PREPROCESS_ALL_COUNTS:
            skip_all_zero_counts = False
            feature_tensors = parsed_feature_tensors

        elif mode == PREPROCESS_INJECT_RANDOM_SEQUENCES:
            skip_all_zero_counts = False

            # replace zero counts with NaN in real data
            for count_name in count_names:
                count = parsed_feature_tensors[count_name]
                parsed_feature_tensors[count_name] = lt.LabeledTensor(
                    tf.where(count != 0, tf.cast(count, tf.float32),
                             tf.fill(tf.shape(count), np.float32(np.nan))),
                    count.axes)

            # only random sequences will have a count of zero
            input_batch_size = tf.shape(strs.tensor)[list(
                strs.axes.keys()).index('batch')]
            n_randoms = tf.cast(
                tf.cast(input_batch_size, tf.float32) * ratio_random_dna,
                tf.int32)
            random_feature_tensors = random_dna_features(
                experiment_proto, n_randoms)
            for count_name in count_names:
                random_feature_tensors[count_name] = lt.cast(
                    random_feature_tensors[count_name], tf.float32)

            feature_tensors = {
                k: lt.concat(
                    [random_feature_tensors[k], parsed_feature_tensors[k]],
                    'batch')
                for k in features
            }

            # shuffle random and non-random inputs because preprocess batches get
            # split across many mini-batches for training
            batch_size = tf.shape(feature_tensors['sequence'].tensor)[0]
            order = tf.random_shuffle(tf.range(batch_size, dtype=tf.int32))
            order.set_shape(feature_tensors['sequence'].tensor.get_shape())
            feature_tensors = {
                k: lt.LabeledTensor(tf.gather(v.tensor, order), v.axes)
                for k, v in feature_tensors.items()
            }

        else:
            raise ValueError('unknown mode: %r' % mode)  # pylint: disable=g-doc-exception

        feature_tensors = upsample_positives(
            feature_tensors,
            count_names,
            total_reads_defining_positive=total_reads_defining_positive,
            min_fraction_positive=0.1)

        inputs, outputs = create_input_and_outputs(
            feature_tensors,
            experiment_proto,
            input_features=input_features,
            kmer_k_max=kmer_k_max,
            skip_all_zero_counts=skip_all_zero_counts,
            additional_output=additional_output)

        return inputs, outputs
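
A hypothetical call site, assuming `serialized` is a LabeledTensor of serialized tf.Example protos with a 'batch' axis and `proto` is a selection_pb2.Experiment:

inputs, outputs = preprocess(
    serialized, proto,
    input_features=(SEQUENCE_ONE_HOT,),
    mode=PREPROCESS_INJECT_RANDOM_SEQUENCES,
    ratio_random_dna=1)
# inputs: one-hot sequences with axes [batch, position, channel];
# outputs: counts and binding-array values with axes [batch, output].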