def ensemble_one_distance_histogram(pickle_files, weights):
  """Load distogram pickles for one target and return their weighted average.

  Args:
    pickle_files: Sequence of paths to distance-histogram pickle files. All
      files must describe the same sequence and have square 'probs' arrays of
      identical leading (L, L) shape.
    weights: Sequence of per-file weights, paired positionally with
      `pickle_files`.

  Returns:
    A `(sequence, ensemble_dict)` tuple. `ensemble_dict` is a shallow copy of
    the first loaded dict whose 'probs' entry is replaced by the weighted
    average of every loaded 'probs'. `ensemble_dict` is None when any file is
    missing (loading stops at the first missing file) or when `pickle_files`
    is empty; `sequence` is None if nothing was loaded.
  """
  dicts = []
  sequence = None
  for picklefile in pickle_files:
    if not tf.io.gfile.exists(picklefile):
      logging.warning('missing %s', picklefile)
      break
    logging.info('loading pickle file %s', picklefile)
    distance_histogram_dict = parsers.parse_distance_histogram_dict(picklefile)
    if sequence is None:
      sequence = distance_histogram_dict['sequence']
    else:
      assert sequence == distance_histogram_dict['sequence'], '%s vs %s' % (
          sequence, distance_histogram_dict['sequence'])
    dicts.append(distance_histogram_dict)

    probs_shape = distance_histogram_dict['probs'].shape
    assert probs_shape[0] == probs_shape[1], '%d vs %d' % (
        probs_shape[0], probs_shape[1])
    # Shapes are tuples, so they must be formatted with %s; %d would raise a
    # TypeError while building the message and mask the real failure.
    first_shape = dicts[0]['probs'].shape
    assert first_shape[0:2] == probs_shape[0:2], '%s vs %s' % (
        first_shape, probs_shape)

  if len(dicts) != len(pickle_files):
    logging.warning('length mismatch\n%s\nVS\n%s', dicts, pickle_files)
    return sequence, None
  if not dicts:
    # Nothing to average; avoid dividing by an empty weight sum.
    return sequence, None

  # Normalize by the weights actually paired with loaded dicts so the result
  # is a true weighted average even if `weights` is longer than the input.
  paired = list(zip(weights, dicts))
  total_weight = sum(w for w, _ in paired)
  ensemble_hist = sum(w * d['probs'] for w, d in paired) / total_weight

  new_dict = dict(dicts[0])
  new_dict['probs'] = ensemble_hist
  return sequence, new_dict
def _check_histogram_params(domain_dict, chain_dict):
  """Raise ValueError unless both distograms share their binning parameters."""
  for field in ["num_bins", "min_range", "max_range"]:
    if domain_dict[field] != chain_dict[field]:
      raise ValueError("Field {} does not match {} {}".format(
          field, domain_dict[field], chain_dict[field]))


def _paste_domain(combined_cmap, counter_map, domain_probs, crop_start,
                  crop_end, weight):
  """Accumulate weighted domain probs into a 1-based inclusive window in place.

  Args:
    combined_cmap: Rank-3 array being accumulated (modified in place).
    counter_map: Rank-3 array with a depth of 1, holding per-cell weight totals
      (modified in place).
    domain_probs: The domain's 'probs' array; must match the window shape.
    crop_start: 1-based first residue of the window (inclusive).
    crop_end: 1-based last residue of the window (inclusive).
    weight: Scalar weight for this domain.

  Raises:
    ValueError: if `domain_probs` does not match the window's shape.
  """
  window = slice(crop_start - 1, crop_end)
  domain_probs = np.asarray(domain_probs)
  target_shape = combined_cmap[window, window, :].shape
  if domain_probs.shape != target_shape:
    raise ValueError("Domain probs shape {} does not match window {}-{} "
                     "shape {}".format(domain_probs.shape, crop_start,
                                       crop_end, target_shape))
  # float32 matches the precision and dtype promotion of the float32 weight
  # matrix this replaces, without allocating an (n, n) array per domain.
  weight32 = np.float32(weight)
  combined_cmap[window, window, :] += domain_probs * weight32
  counter_map[window, window, 0] += weight32


def paste_distance_histograms(
    input_dir, output_dir, weights, crop_sizes, crop_step):
  """Paste together distograms for given domains of given targets and write.

  Domain distance histograms are 'pasted', meaning they are added, with a
  weight, into their window of the chain-level contact map. The result is then
  normalized per cell by the accumulated weight, with the chain contributing a
  weight of 1. Domains are processed in order of their name.

  Args:
    input_dir: String, path to directory containing chain and domain-level
      distogram files.
    output_dir: String, path to directory to write out chain-level distogram
      files.
    weights: A dictionary mapping domain name to weight. Domains absent from
      the dictionary get a weight of 1e9.
    crop_sizes: The crop sizes, passed through to `generate_domains`.
    crop_step: The step size for cropping, passed through to
      `generate_domains`.

  Raises:
    ValueError: if histogram parameters or domain shapes don't match.
  """
  tf.io.gfile.makedirs(output_dir)
  pickle_paths = tf.io.gfile.glob(os.path.join(input_dir, "*.pickle"))
  # The target name is the file stem up to the first "-"; presumably domain
  # files are named "<target>-<suffix>.pickle" — confirm against producers.
  targets = {
      os.path.splitext(os.path.basename(path))[0].split("-")[0]
      for path in pickle_paths
  }
  logging.info("Pasting distance histograms for %d targets", len(targets))

  for target in sorted(targets):
    logging.info("%s as chain", target)

    chain_pickle_path = os.path.join(input_dir, "%s.pickle" % target)
    distance_histogram_dict = parsers.parse_distance_histogram_dict(
        chain_pickle_path)
    combined_cmap = np.array(distance_histogram_dict["probs"])
    # Make the counter map 1-deep but still rank 3, so it broadcasts across
    # the histogram bins when normalizing.
    counter_map = np.ones_like(combined_cmap[:, :, 0:1])
    sequence = distance_histogram_dict["sequence"]
    target_domains = generate_domains(target=target,
                                      sequence=sequence,
                                      crop_sizes=crop_sizes,
                                      crop_step=crop_step)

    # Paste in each domain.
    for domain in sorted(target_domains, key=lambda x: x["name"]):
      if domain["name"] == target:
        logging.info("Skipping %s as domain", target)
        continue
      if "," in domain["description"]:
        logging.info("Skipping multisegment domain %s", domain["name"])
        continue
      crop_start, crop_end = domain["description"]

      domain_pickle_path = os.path.join(input_dir,
                                        "%s.pickle" % domain["name"])
      weight = weights.get(domain["name"], 1e9)
      logging.info("Pasting %s: %d-%d. weight: %f", domain_pickle_path,
                   crop_start, crop_end, weight)
      domain_distance_histogram_dict = parsers.parse_distance_histogram_dict(
          domain_pickle_path)
      _check_histogram_params(domain_distance_histogram_dict,
                              distance_histogram_dict)
      _paste_domain(combined_cmap, counter_map,
                    domain_distance_histogram_dict["probs"],
                    crop_start, crop_end, weight)

    # Broadcast across the histogram bins.
    combined_cmap /= counter_map

    # Write out full-chain cmap for folding.
    output_chain_pickle_path = os.path.join(output_dir,
                                            "{}.pickle".format(target))
    logging.info("Writing to %s", output_chain_pickle_path)
    distance_histogram_dict["probs"] = combined_cmap
    distance_histogram_dict["target"] = target

    # Save the distogram pickle file.
    distogram_io.save_distance_histogram_from_dict(
        output_chain_pickle_path, distance_histogram_dict)

    # Compute the contact map and save it as an RR file.
    contact_probs = distogram_io.contact_map_from_distogram(
        distance_histogram_dict)
    rr_path = os.path.join(output_dir, "%s.rr" % target)
    distogram_io.save_rr_file(filename=rr_path,
                              probs=contact_probs,
                              domain=target,
                              sequence=distance_histogram_dict["sequence"])