def transform(counts):
    """Standardize `counts` per target: (counts - mean) / stddev.

    NOTE(review): `log_transform`, `means`, `stddevs` and `lt` are not defined
    in this function; they are resolved from the surrounding scope, which is
    not visible here.

    Args:
      counts: object exposing `.axes['target'].labels` and supporting
        arithmetic with the selected means/stddevs (presumably an
        lt.LabeledTensor -- confirm against the caller).

    Returns:
      The (optionally log(1 + x)-transformed) counts, with the matching
      entries of `means` subtracted and divided by the matching entries of
      `stddevs`.
    """
    values = lt.log(1.0 + counts) if log_transform else counts

    # Pick out only the statistics for the targets present in the input, in
    # the same order as the input's 'target' axis labels.
    by_target = {'target': list(values.axes['target'].labels)}
    target_means = lt.select(means, by_target)
    target_stddevs = lt.select(stddevs, by_target)

    return (values - target_means) / target_stddevs
def compute_experiment_statistics(
        experiment_proto,
        input_paths,
        proto_w_stats_path,
        preprocess_mode=data.PREPROCESS_SKIP_ALL_ZERO_COUNTS,
        max_size=None,
        logdir=None,
        save_stats=False):
    """Calculate the mean and standard deviation of counts from input files.

    These statistics are used for normalization. If any statistic is missing
    or save_stats=True, compute the statistics. Save the statistics to
    proto_w_stats_path if save_stats=True.

    The input proto is never modified; a deep copy is updated and returned.

    Args:
      experiment_proto: selection_pb2.Experiment describing the experiment.
      input_paths: list of strings giving paths to sstables of input examples.
      proto_w_stats_path: string path to the validation proto file with
        stats. Only written when save_stats=True.
      preprocess_mode: optional preprocess mode defined in the `data` module.
      max_size: optional number of examples to examine to compute statistics.
        By default, examines the entire dataset. The limit is checked after
        each batch, so up to one extra batch of examples may be included.
      logdir: optional path to a directory in which to log events.
      save_stats: optional boolean indicating whether to update all the
        statistics and save to proto_w_stats_path.

    Returns:
      selection_pb2.Experiment with computed statistics.
    """
    experiment_proto = copy.deepcopy(experiment_proto)

    has_all_statistics = True

    # Named read protos (positive and negative) across every round.
    all_reads = {}
    for round_proto in experiment_proto.rounds.values():
        for reads in [round_proto.positive_reads, round_proto.negative_reads]:
            if reads.name:
                all_reads[reads.name] = reads
                if not reads.HasField('statistics'):
                    has_all_statistics = False

    # Named additional-output protos.
    all_ao = {}
    for ao_proto in experiment_proto.additional_output:
        if ao_proto.name:
            all_ao[ao_proto.name] = ao_proto
            if not ao_proto.HasField('statistics'):
                has_all_statistics = False

    if has_all_statistics and not save_stats:
        logger.info('All the statistics exist. Nothing to compute')
        return experiment_proto

    with tf.Graph().as_default():
        logger.info('Setting up graph for statistics')
        # we only care about outputs, which don't rely on training hyper
        # parameters
        hps = tf.HParams(preprocess_mode=preprocess_mode,
                         kmer_k_max=0,
                         ratio_random_dna=0.0,
                         total_reads_defining_positive=0,
                         additional_output=','.join([
                             x.name
                             for x in experiment_proto.additional_output
                         ]))
        _, outputs = data.input_pipeline(input_paths,
                                         experiment_proto,
                                         final_mbsz=100000,
                                         hps=hps,
                                         num_epochs=1,
                                         num_threads=1)
        # Number of examples in the current batch, used to enforce max_size.
        size_op = tf.shape(outputs)[list(
            outputs.axes.keys()).index('batch')]

        all_update_ops = []
        all_value_ops = {}
        # Reads first, then additional outputs, so that ops are created in
        # the same order as before the two loops were merged.
        for name in list(all_reads) + list(all_ao):
            value_ops, update_ops = _build_streaming_statistics_ops(
                lt.select(outputs, {'output': name}))
            all_update_ops.extend(update_ops.values())
            all_value_ops[name] = value_ops

        logger.info('Running statistics ops')
        sv = tf.train.Supervisor(logdir=logdir)
        with sv.managed_session() as sess:
            total = 0
            for results in run_until_exhausted(sv, sess,
                                               [size_op] + all_update_ops):
                # results[0] is the value of size_op for this batch.
                total += results[0]
                if max_size is not None and total >= max_size:
                    break
            all_statistics = {
                k: sess.run(v)
                for k, v in all_value_ops.items()
            }

        # Copy each computed value into the `statistics` field of its proto;
        # `.item()` converts the fetched value to a plain Python scalar.
        for protos_by_name in (all_reads, all_ao):
            for output_name, proto in protos_by_name.items():
                statistics = all_statistics[output_name]
                for stat_name, value in statistics.items():
                    setattr(proto.statistics, stat_name, value.item())

        logger.info('Computed statistics: %r', all_statistics)

        if save_stats:
            logger.info('Save the proto with statistics to %s',
                        proto_w_stats_path)
            _save_experiment_proto(experiment_proto, proto_w_stats_path)

    return experiment_proto


def _build_streaming_statistics_ops(values):
    """Build streaming mean/std-dev metrics for `values` and log(values + 1).

    Args:
      values: the selected output to summarize, as produced by `lt.select`.

    Returns:
      The (value_ops, update_ops) pair returned by
      `contrib_metrics.aggregate_metric_map`, keyed by 'mean', 'std_dev',
      'mean_log_plus_one' and 'std_dev_log_plus_one'.
    """
    log_values = lt.log(values + 1.0)
    return contrib_metrics.aggregate_metric_map({
        'mean': contrib_metrics.streaming_mean(values),
        'std_dev': streaming_std(values),
        'mean_log_plus_one': contrib_metrics.streaming_mean(log_values),
        'std_dev_log_plus_one': streaming_std(log_values),
    })


def _save_experiment_proto(experiment_proto, destination_path):
    """Write `experiment_proto` in text format to `destination_path`.

    The proto is staged in a local temporary file and then copied with
    `gfile.Copy` (presumably so that the destination may live on a non-local
    filesystem -- confirm). A unique temporary file is used instead of a
    fixed, shared path so that concurrent runs cannot overwrite each other's
    data, and the staging file is removed afterwards.
    """
    # Function-scope imports: the module's import block is not visible here.
    import os
    import tempfile

    fd, staging_path = tempfile.mkstemp(suffix='.pbtxt')
    try:
        with os.fdopen(fd, 'w') as f:
            f.write(text_format.MessageToString(experiment_proto))
        gfile.Copy(staging_path, destination_path, overwrite=True)
    finally:
        os.remove(staging_path)
# Example #3
# 0
def create_input_and_outputs(feature_tensors,
                             experiment_proto,
                             input_features=(SEQUENCE_ONE_HOT, ),
                             skip_all_zero_counts=True,
                             kmer_k_max=4,
                             additional_output=None):
    """Assemble model inputs and outputs from parsed features.

    Args:
      feature_tensors: Dict[str, tf.Tensor] with parsed features created by
        `build_features`. Must contain 'sequence', every count / binding
        array name of the experiment, any name in `additional_output`, and
        'partition_function' if STRUCTURE_PARTITION_FUNCTION is requested.
      experiment_proto: selection_pb2.Experiment describing the experiment.
      input_features: optional sequence of feature constants defined in this
        module; only the requested features are built.
      skip_all_zero_counts: if True, drop every example whose count outputs
        are all zero. Some sequences have no counts, e.g., because they were
        created artificially for validation purposes on the binding array,
        and should be skipped for training.
      kmer_k_max: optional integer giving the maximum kmer length to use if
        SEQUENCE_KMER_COUNT is in `input_features`.
      additional_output: optional list of strings naming additional outputs.
        Ignored when empty or when its first entry is empty.

    Returns:
      inputs: dict mapping each requested feature constant to its
        LabeledTensor:
          * SEQUENCE_ONE_HOT: float32 one-hot encoding with axes
            [batch, position, channel].
          * SEQUENCE_KMER_COUNT: kmer counts with axes [batch, kmer],
            shifted and scaled by the values from `_all_kmer_mean_and_std`.
          * STRUCTURE_PARTITION_FUNCTION: log of the partition function with
            axes [batch, partition_fn_axis].
      outputs: float32 LabeledTensor with the 'output' axis at position 1,
        labeled with the count names, then the binding array names, then any
        additional outputs.
    """
    seq_len = experiment_proto.sequence_length
    count_names = selection.all_count_names(experiment_proto)
    array_names = selection.binding_array_names(experiment_proto)

    sequences = feature_tensors['sequence']
    batch_axis = sequences.axes['batch']
    position_axis = ('position', list(range(seq_len)))

    inputs = {}

    if SEQUENCE_ONE_HOT in input_features:
        base_indices = custom_ops.dna_sequence_to_indices(sequences, seq_len)
        one_hot = tf.one_hot(base_indices, depth=4, dtype=tf.float32)
        inputs[SEQUENCE_ONE_HOT] = lt.LabeledTensor(
            one_hot,
            [batch_axis, position_axis, ('channel', list(dna.DNA_BASES))])

    if SEQUENCE_KMER_COUNT in input_features:
        kmer_tensor = custom_ops.count_all_dna_kmers(sequences, kmer_k_max)
        kmer_axis = lt.Axis('kmer', _kmer_labels(kmer_k_max))
        kmer_counts = lt.LabeledTensor(kmer_tensor, [batch_axis, kmer_axis])
        # Standardize each kmer count with its precomputed mean and std.
        kmer_means, kmer_stds = _all_kmer_mean_and_std(kmer_k_max, seq_len)
        mean_lt = lt.constant(kmer_means, tf.float32, axes=[kmer_axis])
        std_lt = lt.constant(kmer_stds, tf.float32, axes=[kmer_axis])
        centered = lt.cast(kmer_counts, tf.float32) - mean_lt
        inputs[SEQUENCE_KMER_COUNT] = centered / std_lt

    if STRUCTURE_PARTITION_FUNCTION in input_features:
        with tf.name_scope('structure_partition_fn'):
            partition_fn = lt.expand_dims(
                feature_tensors['partition_function'],
                ['batch', 'partition_fn_axis'])
            inputs[STRUCTURE_PARTITION_FUNCTION] = lt.log(partition_fn)

    # Output order: counts, binding arrays, then (optionally) extra outputs.
    output_names = count_names + array_names
    if additional_output and additional_output[0]:
        output_names += additional_output
    output_columns = [
        lt.cast(feature_tensors[name], tf.float32) for name in output_names
    ]
    outputs = lt.pack(output_columns, ('output', output_names),
                      axis_position=1)

    if not skip_all_zero_counts:
        return inputs, outputs

    with tf.name_scope('counts_filtering'):
        # Keep an example if at least one of its count outputs is non-zero.
        count_outputs = lt.select(outputs, {'output': count_names})
        has_counts = lt.reduce_any(lt.not_equal(count_outputs, 0.0), 'output')
        inputs = {
            feature: lt.boolean_mask(tensor, has_counts)
            for feature, tensor in inputs.items()
        }
        outputs = lt.boolean_mask(outputs, has_counts)

    return inputs, outputs