def test_clustering():
    """Check that cluster centres are recovered for each joint pair type."""
    means_a = np.asarray([(10, 70), (58, 94), (66, 58), (95, 62)])
    means_b = np.asarray([(88, 12), (56, 15), (25, 21), (24, 89)])
    locations = np.concatenate([
        generate_fake_locations(100, means_a),
        generate_fake_locations(100, means_b),
    ], axis=0)
    np.random.shuffle(locations)
    pairs = [(0, 1), (1, 2), (2, 3)]
    joints = Joints(locations, pairs)
    # Learn two clusters per relationship type. Passing all-zero scales is
    # deliberately degenerate -- acceptable only inside a unit test.
    centers = from_dataset(joints, 2, np.zeros(len(locations)), 1)
    assert centers.ndim == 3
    # Three joint pairs, two clusters per pair, two coordinates (x, y)
    # per cluster.
    assert centers.shape == (3, 2, 2)
    for pair_idx, (first, second) in enumerate(pairs):
        expected_a = means_a[second] - means_a[first]
        expected_b = means_b[second] - means_b[first]
        found = centers[pair_idx]
        # Each specified mean must lie within Euclidean distance 1 of at
        # least one recovered cluster centre.
        assert (np.linalg.norm(found - expected_a, axis=1) < 1).any()
        assert (np.linalg.norm(found - expected_b, axis=1) < 1).any()
def test_clustering():
    """Test learning of clusters for joint types."""
    # NOTE(review): this is an exact duplicate of the test_clustering defined
    # earlier in the file; only the later definition runs under most test
    # collectors -- confirm whether the duplicate is intentional.
    first_means = np.asarray([
        (10, 70),
        (58, 94),
        (66, 58),
        (95, 62),
    ])
    second_means = np.asarray([
        (88, 12),
        (56, 15),
        (25, 21),
        (24, 89),
    ])
    sample_groups = [
        generate_fake_locations(100, first_means),
        generate_fake_locations(100, second_means),
    ]
    fake_locations = np.concatenate(sample_groups, axis=0)
    np.random.shuffle(fake_locations)
    fake_pairs = [(0, 1), (1, 2), (2, 3)]
    fake_joints = Joints(fake_locations, fake_pairs)
    # Two clusters per relationship type; an all-zero scale vector is
    # degenerate but tolerable for test purposes.
    centers = from_dataset(fake_joints, 2, np.zeros(len(fake_locations)), 1)
    assert centers.ndim == 3
    # Shape is (joint pairs, clusters per pair, coordinates per cluster).
    assert centers.shape == (3, 2, 2)
    for idx, (first_idx, second_idx) in enumerate(fake_pairs):
        first_mean = first_means[second_idx] - first_means[first_idx]
        second_mean = second_means[second_idx] - second_means[first_idx]
        found_means = centers[idx]
        first_dists = np.linalg.norm(found_means - first_mean, axis=1)
        second_dists = np.linalg.norm(found_means - second_mean, axis=1)
        # Each specified mean must be within Euclidean distance 1 of at
        # least one found cluster.
        assert (first_dists < 1).any()
        assert (second_dists < 1).any()
def do_pairwise_clustering(train_set, validate_set, test_set, cache_dir,
                           check_cache=True):
    """Derive pairwise cluster centroids and labels, caching both on disk.

    Centroids are computed from ``train_set`` (via ``from_dataset``) and
    saved to ``centroids.npy``; labels for all three sets are computed (via
    ``TrainingLabels``) and pickled to ``labels.pickle``. When
    ``check_cache`` is True and a cached artifact exists, it is loaded
    instead of recomputed.

    :param train_set: dataset providing ``joints``, ``scales`` and
        ``template_size``; used both for centroid derivation and labelling.
    :param validate_set: validation dataset to label.
    :param test_set: test dataset to label.
    :param cache_dir: directory in which cached artifacts live.
    :param check_cache: reuse cached artifacts when present; pass False to
        force recomputation.
    :return: ``(check_cache, centroids, train_labels, validate_labels,
        test_labels)``. The returned ``check_cache`` is False whenever
        anything was recomputed, signalling downstream stages that later
        caches are stale.
    """
    centroids_path = path.join(cache_dir, 'centroids.npy')
    if check_cache and path.exists(centroids_path):
        logging.info("Loading centroids from cache")
        centroids = np.load(centroids_path)
    else:
        if check_cache:
            msg = "Saved centroids not found"
        else:
            msg = "Ignoring saved centroids, if any"
        logging.info(msg + '; deriving centroids')
        centroids = from_dataset(
            train_set.joints, K, train_set.scales, train_set.template_size
        )
        logging.info('Caching centroids')
        np.save(centroids_path, centroids)
        # Invalidate downstream caches: anything derived from these fresh
        # centroids must also be recomputed.
        check_cache = False

    labels_path = path.join(cache_dir, 'labels.pickle')
    if check_cache and path.exists(labels_path):
        # BUG FIX: pickle streams are binary; the file must be opened in
        # 'rb' mode (text mode fails on Python 3).
        with open(labels_path, 'rb') as fp:
            train_labels, validate_labels, test_labels = pickle.load(fp)
    else:
        if check_cache:
            msg = "Cached labels not found"
        else:
            msg = "Ignoring cached labels, if any"
        logging.info(msg + '; calculating labels')
        logging.info("Labelling training set")
        train_labels = TrainingLabels(train_set, centroids)
        logging.info("Labelling validation set")
        validate_labels = TrainingLabels(validate_set, centroids)
        logging.info("Labelling test set")
        test_labels = TrainingLabels(test_set, centroids)
        logging.info("Pickling labels")
        # BUG FIX: likewise, dump in binary mode ('wb' rather than 'w').
        with open(labels_path, 'wb') as fp:
            pickle.dump((train_labels, validate_labels, test_labels), fp)
        check_cache = False

    return check_cache, centroids, train_labels, validate_labels, test_labels