def ensemble_one_distance_histogram(pickle_files, weights):
  """Compute the weighted average of the distograms in `pickle_files`.

  Each file is parsed with `parsers.parse_distance_histogram_dict`. All files
  must describe the same sequence, and their 'probs' arrays must be square in
  the first two dimensions and agree with the first file on those dimensions.
  Nothing is written to disk by this function.

  Args:
    pickle_files: Paths of the distance histogram pickle files to average.
    weights: Per-file weights, paired with `pickle_files` in order.

  Returns:
    A `(sequence, histogram_dict)` tuple. `histogram_dict` is a shallow copy
    of the first parsed dict whose 'probs' entry is replaced by
    `sum(w * probs) / sum(weights)`. It is None when a file is missing
    (loading stops at the first missing file) or when `pickle_files` is
    empty; `sequence` is then whatever was read before stopping, or None.

  Raises:
    AssertionError: if the sequences differ or the 'probs' shapes are
      inconsistent.
  """
  dicts = []
  sequence = None
  for picklefile in pickle_files:
    if not tf.io.gfile.exists(picklefile):
      logging.warning('missing %s', picklefile)
      # Stop at the first missing file; the length check below reports it.
      break
    logging.info('loading pickle file %s', picklefile)
    distance_histogram_dict = parsers.parse_distance_histogram_dict(picklefile)
    if sequence is None:
      sequence = distance_histogram_dict['sequence']
    else:
      assert sequence == distance_histogram_dict['sequence'], '%s vs %s' % (
          sequence, distance_histogram_dict['sequence'])
    dicts.append(distance_histogram_dict)
    probs_shape = distance_histogram_dict['probs'].shape
    assert probs_shape[0] == probs_shape[1], (
        '%d vs %d' % (probs_shape[0], probs_shape[1]))
    first_shape = dicts[0]['probs'].shape
    # Shapes are tuples, so they must be formatted with %s; %d would raise a
    # TypeError while building the message and mask the assertion.
    assert first_shape[0:2] == probs_shape[0:2], (
        '%s vs %s' % (first_shape, probs_shape))
  if len(dicts) != len(pickle_files):
    logging.warning('length mismatch\n%s\nVS\n%s', dicts, pickle_files)
    return sequence, None
  if not dicts:
    # Nothing to average: report failure the same way as a missing file
    # instead of crashing on dicts[0] or dividing by an empty weight sum.
    logging.warning('no distance histograms to ensemble')
    return sequence, None
  ensemble_hist = (
      sum(w * c['probs'] for w, c in zip(weights, dicts)) / sum(weights))
  new_dict = dict(dicts[0])
  new_dict['probs'] = ensemble_hist
  return sequence, new_dict
# Example no. 2
# 0
def paste_distance_histograms(input_dir, output_dir, weights, crop_sizes,
                              crop_step):
    """Blend domain-level distograms into each chain-level distogram and save.

    For every target that has pickle files in `input_dir`, the chain
    distogram (`<target>.pickle`) is loaded and each single-segment domain
    returned by `generate_domains` is accumulated onto the residue window it
    covers, scaled by its weight. The chain itself counts with weight 1, and
    every cell is finally divided by the total weight it received. Domains
    are processed in order of their names.

    The result is written to `output_dir` as `<target>.pickle`, together with
    a `<target>.rr` file holding the contact map derived from it.

    Args:
        input_dir: String, path to the directory holding the chain and
            domain-level distogram pickle files.
        output_dir: String, path to the directory that receives the
            chain-level distogram and RR files; `tf.io.gfile.makedirs` is
            called on it first.
        weights: A dictionary mapping domain names to weights. Domains
            without an entry get a weight of 1e9.
        crop_sizes: The crop sizes, forwarded to `generate_domains`.
        crop_step: The step size for cropping, forwarded to
            `generate_domains`.

    Raises:
        ValueError: if a domain distogram's num_bins, min_range or max_range
            differs from the chain distogram's.
    """
    tf.io.gfile.makedirs(output_dir)

    # A target name is the file stem up to (not including) its first dash.
    pickle_paths = tf.io.gfile.glob(os.path.join(input_dir, "*.pickle"))
    targets = {
        os.path.splitext(os.path.basename(path))[0].split("-")[0]
        for path in pickle_paths
    }
    logging.info("Pasting distance histograms for %d targets", len(targets))

    histogram_fields = ("num_bins", "min_range", "max_range")

    for target in sorted(targets):
        logging.info("%s as chain", target)

        chain_hist = parsers.parse_distance_histogram_dict(
            os.path.join(input_dir, "%s.pickle" % target))

        # Running weighted sum; starts as a copy of the chain's own probs.
        blended = np.array(chain_hist["probs"])
        # Total weight per residue pair, kept rank 3 (depth 1) so that it
        # broadcasts across the histogram bins when dividing.
        total_weight = np.ones_like(blended[:, :, 0:1])

        domains = generate_domains(target=target,
                                   sequence=chain_hist["sequence"],
                                   crop_sizes=crop_sizes,
                                   crop_step=crop_step)

        for domain in sorted(domains, key=lambda entry: entry["name"]):
            name = domain["name"]
            if name == target:
                logging.info("Skipping %s as domain", target)
                continue

            span = domain["description"]
            if "," in span:
                logging.info("Skipping multisegment domain %s", name)
                continue

            crop_start, crop_end = span
            domain_path = os.path.join(input_dir, "%s.pickle" % name)
            weight = weights.get(name, 1e9)

            logging.info("Pasting %s: %d-%d. weight: %f", domain_path,
                         crop_start, crop_end, weight)

            domain_hist = parsers.parse_distance_histogram_dict(domain_path)

            # The histograms can only be mixed if they share one binning.
            for field in histogram_fields:
                domain_value = domain_hist[field]
                chain_value = chain_hist[field]
                if domain_value != chain_value:
                    raise ValueError("Field {} does not match {} {}".format(
                        field, domain_value, chain_value))

            side = crop_end - crop_start + 1
            weight_matrix = np.ones((side, side), dtype=np.float32) * weight

            # Crop bounds are 1-based and inclusive; convert to a 0-based
            # half-open slice applied to both residue axes.
            window = slice(crop_start - 1, crop_end)
            blended[window, window, :] += (
                domain_hist["probs"] * np.expand_dims(weight_matrix, 2))
            total_weight[window, window, 0] += weight_matrix

        # Normalise by the accumulated weight (broadcast over the bins).
        blended /= total_weight

        # Full-chain distogram used for folding.
        output_chain_pickle_path = os.path.join(output_dir,
                                                "{}.pickle".format(target))
        logging.info("Writing to %s", output_chain_pickle_path)

        chain_hist["probs"] = blended
        chain_hist["target"] = target
        distogram_io.save_distance_histogram_from_dict(
            output_chain_pickle_path, chain_hist)

        # Derive the contact map and store it next to the pickle as RR.
        contact_probs = distogram_io.contact_map_from_distogram(chain_hist)
        distogram_io.save_rr_file(
            filename=os.path.join(output_dir, "%s.rr" % target),
            probs=contact_probs,
            domain=target,
            sequence=chain_hist["sequence"])