def assignMembership(points, centroids):
	"""Assign every point to its nearest centroid.

	Distances are computed with PS.distance. A centroid only replaces the
	current best when it is strictly closer, so ties go to the lowest
	centroid index.

	Args:
		points: iterable of points (kMeans passes (x, y) tuples).
		centroids: sequence of centroids; a point's cluster number is the
			index of its nearest centroid in this sequence.

	Returns:
		A list of (clusterNum, point) pairs in the same order as points.

	Raises:
		ValueError: if some point has no centroid at a distance strictly less
			than infinity (e.g. centroids is empty, or every distance is
			inf/NaN).
	"""
	membership = []
	for point in points:
		# Reset for every point so a point that matches no centroid can never
		# inherit the cluster chosen for the previous point.
		bestCluster = None
		bestDistance = float("inf")
		for c, centroid in enumerate(centroids):
			d = PS.distance(point, centroid)
			if d < bestDistance:
				bestDistance = d
				bestCluster = c
		if bestCluster is None:
			raise ValueError("no centroid at finite distance from point %r" % (point,))
		# format (clusterNum, (x, y))
		membership.append((bestCluster, point))
	return membership
def _readPoints(path):
	"""Read 2-D points from a text file holding one "x y" pair per line."""
	points = []
	with open(path, "r") as f:
		for line in f:
			# split() with no separator tolerates repeated spaces/tabs.
			fields = line.split()
			# Skip blank lines (such as a trailing empty line); float('') would raise.
			if not fields:
				continue
			points.append((float(fields[0]), float(fields[1])))
	return points


def _splitPoints(points, parts):
	"""Split points into `parts` contiguous chunks whose sizes differ by at most one."""
	base, extra = divmod(len(points), parts)
	chunks = []
	start = 0
	for r in range(parts):
		# The first `extra` chunks each take one leftover point, so no point is
		# dropped when len(points) is not a multiple of `parts`.
		end = start + base + (1 if r < extra else 0)
		chunks.append(points[start:end])
		start = end
	return chunks


def _writeCentroids(path, centroids):
	"""Write one "x y" line per centroid, each coordinate formatted with %f."""
	with open(path, "w+") as fo:
		for c in centroids:
			fo.write("%f %f\n" % (c[0], c[1]))


def kMeans(k, e, i, o):
	"""Run k-means across all ranks of MPI.COMM_WORLD.

	Rank 0 reads the points from file `i`, splits them into one contiguous
	chunk per rank, and picks the initial centroids with getInitialCentroids.
	Then, on every iteration:

	- each rank assigns its chunk to the nearest centroids;
	- rank 0 gathers all memberships and recomputes the centroids with
	  updateCentroids;
	- the new centroids are broadcast back to every rank.

	Iteration continues while PS.diffCentroids(old, new) is greater than `e`.
	Finally rank 0 writes the centroids to file `o`.

	Args:
		k: number of clusters.
		e: convergence threshold compared against PS.diffCentroids.
		i: path of the input file, one "x y" point per line (read on rank 0 only).
		o: path of the output file (written on rank 0 only).

	Returns:
		42 on every rank (constant return value kept for existing callers).
	"""
	comm = MPI.COMM_WORLD
	rank = comm.Get_rank()
	size = comm.Get_size()

	# Rank 0 loads the data, divides it among the ranks and picks the centroids.
	if rank == 0:
		points = _readPoints(i)
		chunks = _splitPoints(points, size)
		newCen = getInitialCentroids(points, k)
	else:
		chunks = None
		newCen = None
	# Every rank receives the initial centroids and its own chunk of points.
	newCen = comm.bcast(newCen, root=0)
	chunk = comm.scatter(chunks, root=0)

	# Starting "previous" centroids for the convergence test.
	# NOTE(review): if the initial centroids are within `e` of these origin
	# points, the loop body never runs — confirm this is acceptable.
	oldCen = [(0, 0)] * k
	while PS.diffCentroids(oldCen, newCen) > e:
		oldCen = newCen[:]
		# Each rank assigns its own points using the current centroids.
		membership = assignMembership(chunk, newCen)
		allMembers = comm.gather(membership, root=0)
		if rank == 0:
			# Flatten the per-rank lists, then recompute centroids on the root.
			allMembership = []
			for members in allMembers:
				allMembership.extend(members)
			newCen = updateCentroids(allMembership, k)
		# Broadcast so every rank evaluates the loop condition on the same centroids.
		newCen = comm.bcast(newCen, root=0)

	if rank == 0:
		_writeCentroids(o, newCen)
	return 42