def distance_matrix(self, recreate: bool = False, model: str = "ted", **kwargs: Any) -> np.ndarray:
    """
    Return the pairwise distance matrix of the reaction trees, with caching

    Every key-word argument is forwarded unchanged to the
    `route_distances_calculator` function from the `route_distances` package.

    * With `model="lstm"` the key-word argument `model_path` is mandatory.
    * With `model="ted"` the optional key-word arguments `timeout` and
      `content` can be given.

    Computed matrices are stored in `self._distance_matrix`. The cache key is
    the value of `model_path` for the "lstm" model and the value of `content`
    (default "both") for any other model.

    :param recreate: if False, use a cached one if available
    :param model: the type of model to use "ted" or "lstm"
    :raises KeyError: if `model` is "lstm" and no `model_path` is provided
    :return: the square distance matrix
    """
    using_lstm = model == "lstm"
    if using_lstm and not kwargs.get("model_path"):
        raise KeyError(
            "Need to provide 'model_path' argument when using LSTM model for computing distances"
        )

    # Pick the key under which the matrix for this configuration is cached
    if using_lstm:
        cache_key = kwargs.get("model_path", "")
    else:
        cache_key = kwargs.get("content", "both")

    if not recreate:
        cached = self._distance_matrix.get(cache_key)
        if cached is not None:
            return cached

    # Nothing usable in the cache (or a refresh was requested): compute and store
    compute_distances = route_distances_calculator(model, **kwargs)
    matrix = compute_distances(self.dicts)
    self._distance_matrix[cache_key] = matrix
    return matrix
def main() -> None:
    """
    Entry-point for CLI tool

    Reads the command-line arguments, merges the input files into one table
    and then, depending on the arguments:

    * unless `only_clustering` is set, computes a distance matrix for each row,
      using the "ted" calculator when `args.model` is "ted" and otherwise an
      "lstm" calculator with `args.model` as the model path;
    * if `nclusters` is given, clusters each row.

    The resulting table is written to `args.output` under the HDF key "table".
    """
    args = _get_args()
    tqdm.pandas()
    data = _merge_inputs(args.files)

    if not args.only_clustering:
        if args.model == "ted":
            calculator = route_distances_calculator("ted", content="both")
        else:
            calculator = route_distances_calculator(
                "lstm",
                model_path=args.model,
                fp_size=args.fp_size,
                lstm_size=args.lstm_size,
            )
        dist_data = data.progress_apply(_calc_distances, axis=1, calculator=calculator)
        data = data.assign(
            distance_matrix=dist_data.distance_matrix,
            distances_time=dist_data.distances_time,
        )

    if args.nclusters is not None:
        cluster_data = data.progress_apply(
            _do_clustering,
            axis=1,
            nclusters=args.nclusters,
            min_density=args.min_density,
        )
        data = data.assign(
            cluster_labels=cluster_data.cluster_labels,
            cluster_time=cluster_data.cluster_time,
        )

    with warnings.catch_warnings():
        # This will suppress a PerformanceWarning
        warnings.simplefilter("ignore")
        # `key` is passed by keyword: positional use is deprecated in recent
        # pandas releases and rejected once the argument becomes keyword-only.
        data.to_hdf(args.output, key="table")
def distance_to(self, other: "ReactionTree", content: str = "both") -> float:
    """
    Compute the distance between this reaction tree and another one

    The distance is a tree edit distance: inserting or deleting a node has
    unit cost, and substituting a node costs the Jaccard distance.

    :param other: the reaction tree to compare to
    :param content: determine what part of the tree to include in the calculation
    :return: the distance between the routes
    """
    ted_calculator = route_distances_calculator("ted", content=content)
    tree_pair = [self.to_dict(), other.to_dict()]
    # Of the pairwise matrix for the two trees, only the off-diagonal entry
    # (this tree versus `other`) is needed.
    return ted_calculator(tree_pair)[0, 1]