示例#1
0
	#EDIT PRM STAR PARAM
	args = parser.parse_args()
	args.log_dir = './output/max_nodes-' + args.map_type + str(args.max_nodes) + "-obs-thres" + str(args.obstacle_threshold) +\
				   "-k_nearest-" + \
				   str(args.k_nearest) + "-connection_radius-" + str(args.connection_radius) + "-date-" + \
				   datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + '/'

	if not os.path.exists(args.log_dir):
		os.makedirs(args.log_dir)
	tic = time.time()
	# load map
	map_data, resolution = load_hilbert_map(map_type=args.map_type)
	#resolution = 0.2
	#with open("freiburg_ground_map_q_resolution_final.pickle", 'rb') as tf:
	#	map_data = pickle.load(tf)
	map_array = convert_map_dict_to_array(map_data, resolution)
	map_data["yq"] = 1.0 * (map_data["yq"] > args.obstacle_threshold)
	# get samples from hilbert maps
	sample_list = hilbert_samples(map_data.copy(), args.exp_factor, args.obstacle_threshold, num_samples=args.number_of_samples)
	# take unique samples
	sample_list = [list(t) for t in set(tuple(element) for element in sample_list)]
	# truncated based on max nodes
	sample_list = sample_list[:args.max_nodes]
	# find k nearest neighbor
	nbrs = NearestNeighbors(n_neighbors=args.k_nearest, algorithm='ball_tree').fit(sample_list)
	distances, indices = nbrs.kneighbors(sample_list)
	# create graph
	prm_graph = nx.Graph()
	# add graph nodes
	for indx, s in enumerate(sample_list):
		prm_graph.add_node(indx, pos=(s[0], s[1]))
示例#2
0
def get_top_n_persistence_node_location(n,
                                        map_type,
                                        obs_threshold,
                                        location_type="death",
                                        feature_type=0):
    """
    Find the map locations whose value equals the birth or death value of the
    ``n`` longest-lived persistence intervals of the loaded map.

    The map is loaded with ``load_hilbert_map``, converted to an array, turned
    into a Freudenthal complex, and its persistence is computed. Intervals of
    dimension ``feature_type`` are filtered against ``obs_threshold``, ranked
    by life span (death - birth), and every entry of ``map_data['yq']`` that
    is numerically equal to the selected endpoint of a top interval
    contributes the matching ``map_data["Xq"]`` entry to the result.

    :param n: top number of persistence intervals to keep; if fewer intervals
        survive the filtering, only those are used. ``n <= 0`` keeps none.
    :param map_type: intel or drive
    :param obs_threshold: for feature_type 1, intervals born above this value
        are discarded; for feature_type 0, intervals dying above this value
        are discarded
    :param location_type: string representing birth or death
    :param feature_type: 0 for connected components, 1 for loops; any other
        value is passed through without threshold filtering
    :return: tuple ``(top_persistence_node, life_spans)``: the list of
        matching ``map_data["Xq"]`` entries (one per (cell, interval) match,
        in cell order) and the life spans of the selected intervals in
        descending order
    :raises ValueError: if ``location_type`` is neither "birth" nor "death"
    """
    if location_type == "death":
        location_type_index = 1
    elif location_type == "birth":
        location_type_index = 0
    else:
        raise ValueError("Invalid location type")

    map_data, resolution = load_hilbert_map(map_type=map_type)
    map_array = convert_map_dict_to_array(map_data, resolution)

    fc = FreudenthalComplex(map_array)
    st = fc.init_freudenthal_2d()
    print_complex_attributes(st)

    if st.make_filtration_non_decreasing():
        print("modified filtration value")
    st.initialize_filtration()

    # Compute persistence once; it must be computed before the intervals can
    # be queried, and the result is reused for the small-diagram debug dump.
    persistence_pairs = st.persistence()
    if len(persistence_pairs) <= 10:
        for pair in persistence_pairs:
            print(pair)

    # Force an (m, 2) layout so an empty result can still be column-indexed.
    intervals = np.asarray(
        st.persistence_intervals_in_dimension(feature_type)).reshape(-1, 2)

    # Column compared against the threshold: birth (0) for loops,
    # death (1) for connected components.
    if feature_type == 1:
        threshold_column = 0
    elif feature_type == 0:
        threshold_column = 1
    else:
        threshold_column = None
    if threshold_column is not None:
        # Drop rows strictly above the threshold (same comparison as a
        # row-by-row `> obs_threshold` removal, including for NaN values).
        above = intervals[:, threshold_column] > obs_threshold
        intervals = intervals[~above]

    life_span = intervals[:, 1] - intervals[:, 0]
    # Longest-lived first, at most n of them (none when n <= 0).
    winner_index = life_span.argsort()[::-1][:max(n, 0)]
    print("len winner index ", len(winner_index))
    winner_persistence = intervals[winner_index]
    print(winner_persistence, "winner_persistence")

    # Endpoint (birth or death) of every selected interval; its length is the
    # number of intervals actually available, which may be smaller than n.
    targets = winner_persistence[:, location_type_index]

    top_persistence_node = []
    for indx, intensity in enumerate(map_data['yq']):
        # One broadcast comparison against all targets; tight tolerances so
        # only cells carrying (numerically) the exact endpoint value match.
        matches = np.isclose(intensity, targets, rtol=1e-10, atol=1e-13)
        for j in np.flatnonzero(matches):
            top_persistence_node.append(map_data["Xq"][indx])
            print(j, intensity)
    return top_persistence_node, life_span[winner_index]