Esempio n. 1
0
def main():
    """Command-line driver: cluster a CSV dataset with k-means.

    Builds the argument parser from ``lab5.get_k_means_args()``, loads the
    dataset, runs ``KMeans.disk_k_means`` with the requested ``k`` and
    distance metric, and hands the outcome to ``print_stats``.

    Returns:
        int: always 0 when the run completes.

    Raises:
        AssertionError: if the number of centroids or clusters differs
            from ``k``, or if the count returned by ``print_stats`` does
            not equal ``dataset.size()``.
    """
    # PARSE ARGS
    options = lab5.get_k_means_args().parse_args()
    csv_path = options.csv_filename

    # An explicitly supplied header file wins; otherwise one is derived
    # from the data filename only on request; otherwise there is none.
    if options.header_filename:
        header_path = options.header_filename
    elif options.infer_header:
        header_path = get_header_filename(csv_path)
    else:
        header_path = None

    cluster_count = options.k
    metric = get_distance_metric(options)

    # Echo the effective configuration, one line at a time.
    banner = (
        ("Data   Filename: %s", csv_path),
        ("Header Filename: %s", header_path),
        ("k              : %d", cluster_count),
    )
    for template, value in banner:
        print(template % value)

    # READ DATA
    points = Dataset(csv_path, header_path)

    # CALC K MEANS
    clusterer = KMeans(metric)
    centroids, clusters = clusterer.disk_k_means(points, cluster_count)
    # Sanity check: exactly one centroid per cluster, and k of each.
    assert len(centroids) == len(clusters) == cluster_count

    clustered = print_stats(points, clusterer, clusters, centroids)
    # Sanity check: the count reported back must match the dataset size.
    assert clustered == points.size(), "num(%d) != D.size(%d)" % (
        clustered,
        points.size(),
    )
    print("Datapoints clustered: %d" % clustered)
    print("Datapoints total    : %d" % points.size())
    return 0
Esempio n. 2
0
def main():
    """Command-line driver: agglomerative clustering of a CSV dataset.

    Builds the argument parser from ``lab5.get_hierarchical_args()``, loads
    the dataset and builds a dendrogram with ``Agglomerative``. When a
    threshold was supplied, the dendrogram is passed to ``get_clusters``
    and the resulting clusters, their count, the trimmed tree and the
    ``print_stats`` totals are printed. Without a threshold the full
    dendrogram is printed instead.

    Returns:
        int: always 0 when the run completes.
    """
    # PARSE ARGS
    parser = lab5.get_hierarchical_args()
    args = parser.parse_args()

    data_filename = args.csv_filename
    # An explicitly supplied header file wins; otherwise one is derived
    # from the data filename only on request; otherwise there is none.
    header_filename = None
    if args.header_filename:
        header_filename = args.header_filename
    elif args.infer_header:
        header_filename = get_header_filename(data_filename)

    threshold = args.threshold
    # Compare against None rather than relying on truthiness, so that an
    # explicit threshold of 0 is honoured instead of being silently ignored.
    # NOTE(review): assumes the parser leaves ``threshold`` as None when the
    # option is omitted -- confirm against lab5.get_hierarchical_args().
    has_threshold = threshold is not None

    distance_metric = get_distance_metric(args)
    cluster_distance = get_cluster_distance(args)

    print('Data   Filename: %s' % data_filename)
    print('Header Filename: %s' % header_filename)
    if has_threshold:
        print('Threshold      : %.3f' % threshold)

    # READ DATA
    dataset = Dataset(data_filename, header_filename)

    # CALC AGGLOMERATIVE
    agglomerative = Agglomerative(distance_metric, cluster_distance)
    dendrogram = agglomerative.agglomerative(dataset)

    # NOTE(review): the result is never used below; the call is retained in
    # case get_all_clusters has side effects -- confirm, then remove.
    all_clusters = get_all_clusters(dendrogram)

    if not has_threshold:
        # No cut requested: show the complete dendrogram.
        print_tree(dendrogram)
        return 0

    trimmed_tree, centroids, clusters = get_clusters(dendrogram, threshold)
    for cluster in clusters:
        print(cluster)
    print(len(clusters))
    print_tree(trimmed_tree)

    num_datapoints = print_stats(dataset, agglomerative, clusters, centroids)
    print('Datapoints clustered: %d' % num_datapoints)
    print('Datapoints total    : %d' % dataset.size())
    return 0