def getFlyContours(img):
   """Return the external contours of ``img`` that are big enough to be a fly.

   A contour is kept when its area, divided by the area of the arena
   rectangle described by ``flyTrackerSettings.arenaCoords``, is strictly
   greater than ``flyTrackerSettings.minFlyAreaNorm``.

   Args:
      img: image handed straight to ``cv2.findContours`` (the caller in this
         file passes a binary threshold image).

   Returns:
      list of contours in the order ``cv2.findContours`` produced them.
   """
   minFlyAreaNorm = flyTrackerSettings.minFlyAreaNorm # 0.0045
   arenaCoords = flyTrackerSettings.arenaCoords

   # cv2.findContours returns (contours, hierarchy) on OpenCV 2.x and 4.x but
   # (image, contours, hierarchy) on 3.x.  The contours are always the
   # second-to-last element, so index from the end instead of branching on
   # the library version.
   contours = cv2.findContours(img,mode=cv2.RETR_EXTERNAL,method=cv2.CHAIN_APPROX_SIMPLE)[-2] # RETR_EXTERNAL? use sure_bg?

   arenaArea = abs((arenaCoords[1]-arenaCoords[3])*(arenaCoords[0]-arenaCoords[2]))

   # Build a fresh list rather than compacting ``contours`` in place: recent
   # OpenCV bindings hand the contours back as a tuple, which cannot be
   # assigned into.
   return [cnt for cnt in contours if cv2.contourArea(cnt)/arenaArea > minFlyAreaNorm]
def clusterFlies2(foreGround,frameCount,threshOffset=0):
   """Cluster the foreground image into one blob per fly and fit body shapes.

   Steps, in order:
     1. Convert the foreground to gray, optionally mask everything outside
        the arena circle, and threshold it (adaptive or global, depending on
        ``flyTrackerSettings.threshType``).
     2. Keep only fly-sized contours (``getFlyContours``) and copy the gray
        values inside them into ``fgObjects``.
     3. Run k-means on the nonzero pixel coordinates, increasing the number
        of clusters until the fit is compact enough; surplus clusters (the
        ones farthest from the previous fly centers) are erased as noise.
        Clusters are then matched to the previous flies with ``matchFlies``.
     4. For every fly, run a second k-means on (row, col, intensity) and keep
        the brightest clusters as the "body"; fit ellipses and a line.

   Args:
      foreGround: foreground image; it is converted with ``COLOR_BGR2GRAY``,
         so a 3-channel BGR image is expected.
      frameCount: frame number, used only in debug messages and debug image
         file names.
      threshOffset: extra value added to the automatically chosen threshold.

   Returns:
      tuple ``(newCenters, bodyEllipses2, bodyEllipses, allRedLines,
      chosenThreshold, flyFlags)`` where
        * ``newCenters`` is the first value returned by ``matchFlies``,
        * ``bodyEllipses2`` is an (nFlies, 5) float16 array with the ellipse
          fitted to all points of each fly cluster,
        * ``bodyEllipses`` is an (nFlies, 5) float16 array with the ellipse
          fitted to the bright body points only,
        * ``allRedLines`` is an (nFlies, 4) float16 array holding the
          ``cv2.fitLine`` result for the body points,
        * ``chosenThreshold`` is the threshold value that was used,
        * ``flyFlags`` is always 0.

   Reads its configuration from ``flyTrackerSettings``; when
   ``flyTrackerSettings.debugImages`` is set it also opens OpenCV windows and
   writes debug images through ``imgDB``.
   """
   # this is where all the major clustering happens!
   term_crit = (cv2.TERM_CRITERIA_EPS, 30, 0.1)
   nFlies = flyTrackerSettings.nFlies
   oldCenters = flyTrackerSettings.oldCenters

   if flyTrackerSettings.debugImages:
      cv2.imshow('raw foreground', foreGround/255*10)

   # we should throw up a flag when we expect the clustering to not be so good:
   # for instance, when the flies are directly adjacent to each other

   printDB('new clustering.............................................' + str(frameCount))
   t = time.time()
   imgDB('foreground.jpg',foreGround*5)

   ##############################
   # Extract all the foreground things we need
   ##############################
   # black out everything outside of the arena circle that was selected...this should be computed once outside of
   # this function. no need to recompute every frame
   fg2 = np.array(foreGround,np.uint8)   # only its shape is used (debug canvases)
   gray = cv2.cvtColor(foreGround,cv2.COLOR_BGR2GRAY)
   if flyTrackerSettings.circularMask:
      mask = np.zeros(gray.shape,np.uint8)
      # OpenCV points are (x, y) = (column, row) while shape is (rows, cols),
      # so the image center is (shape[1]/2, shape[0]/2).
      cv2.circle(mask,(int(gray.shape[1]/2),int(gray.shape[0]/2)), flyTrackerSettings.radius,255,-1)
      gray[mask!=255] = 0

   grayOld = copy.copy(gray)
   intensity = np.sum(gray)/1000      # have to normalize this for total number of pixels...normally 840x840

   printDB('intensity level:' + str(intensity))
   imgDB('gray.jpg',gray)
   gray = np.array(gray, np.uint8)
   chosenThreshold = intensity/250 + 1 + threshOffset + flyTrackerSettings.bgOffset
   if flyTrackerSettings.threshType=='adaptive':
      thresh = cv2.adaptiveThreshold(gray,255,cv2.ADAPTIVE_THRESH_GAUSSIAN_C,cv2.THRESH_BINARY_INV,19,chosenThreshold)
   else:
      ret, thresh = cv2.threshold(gray,chosenThreshold,255,cv2.THRESH_BINARY)

   imgDB('thresh1.jpg',thresh)
   # blur then re-threshold at 1: every pixel the blur touched becomes
   # foreground, which grows the thresholded blobs
   thresh = cv2.blur(thresh,(6,6))
   blurredThresh = cv2.blur(thresh,(2,2))   # debug output only
   imgDB('blur.jpg',blurredThresh)
   ret, thresh = cv2.threshold(thresh,1,255,cv2.THRESH_BINARY)
   imgDB('thresh.jpg',thresh)

   if flyTrackerSettings.debugImages:
      cv2.imshow('foreground', thresh)

   printDB(round(1000*(time.time() - t)))    # ~130ms
   t = time.time()
   ##############################
   # Find the contours in the image and try to use only the fly-sized ones
   ##############################
   contours = getFlyContours(thresh)

   # label image: pixels inside contour k get the value k+1, background stays 0
   markers = np.zeros((grayOld.shape[0],grayOld.shape[1],1),np.int32)
   for cntInd,cnt in enumerate(contours):
      # thickness -1 means "filled" on every OpenCV version
      cv2.drawContours(markers,contours,cntInd,cntInd+1,-1)

   # keep the original gray values, but only inside the fly-sized contours
   fgObjects = np.zeros(gray.shape,np.uint8)
   insideFly = np.logical_and(markers[:,:,0] > 0, grayOld > 0)
   fgObjects[insideFly] = grayOld[insideFly]
   kernel2 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(3,3))
   cv2.erode(fgObjects,kernel2,fgObjects)
   imgDB('markers.jpg',markers)
   imgDB('fgobj.jpg',fgObjects)

   printDB(round(1000*(time.time() - t)))    # ~18ms
   t = time.time()

   ##############################
   # Use k-means to cluster individual flies and denoise further
   ##############################

   # (row, col) coordinates of every remaining foreground pixel
   pts = np.vstack(np.nonzero(fgObjects)).astype(np.float32).T
   pts = pts[:,0:2]
   numK = flyTrackerSettings.nFlies-1
   compactness = 1000000000000000   # huge start value forces the first pass
   jumpMax = flyTrackerSettings.px2mm*4
   compactnessMax = 2000
   foundJump = False
   # The loop condition reads maxdist; give it a value so the condition is
   # well defined even when the first clause is false on the very first check.
   maxdist = 0

   # NOTE(review): len(pts) is 0 when no foreground pixel survived, which
   # makes the condition below raise ZeroDivisionError - confirm how callers
   # want an empty frame to be handled.
   while (compactness/len(pts) > compactnessMax and numK < 10) or (maxdist > jumpMax and numK < 10 and np.sum(flyTrackerSettings.oldCenters) != 0 and not foundJump):
      numK = numK+1
      if cv2to3.isCV2():
         compactness, bestLabelsKM, centers = cv2.kmeans(pts, numK, term_crit, attempts=10, flags=cv2.KMEANS_PP_CENTERS)
      else: # for cv3 add 3rd argument=None
         compactness, bestLabelsKM, centers = cv2.kmeans(pts, numK, None, term_crit, attempts=10, flags=cv2.KMEANS_PP_CENTERS)

      printDB('numK! ' + str(numK) + ', compactness! ' + str(compactness/len(pts)))
      if (numK > nFlies):
         # we've found a bunch of noise... let's hope it's the stuff that's far from the previous flies!
         D = findDist(centers, oldCenters)
         for ii in range(numK-nFlies):
            # the cluster whose nearest old center is farthest away is
            # treated as noise and erased from the foreground
            ind = bestLabelsKM == np.nonzero(D.min(axis=1) == D.min(axis=1).max())[0][0]
            noisePts = pts[ind[:,0],0:2].astype(np.uint)
            fgObjects[noisePts[:,0],noisePts[:,1]] = 0

            pts = np.vstack(np.nonzero(fgObjects)).astype(np.float32).T
            pts = pts[:,0:2]

            # re-cluster what is left with one cluster fewer
            if cv2to3.isCV2():
               compactness, bestLabelsKM, centers = cv2.kmeans(pts, numK-ii-1, term_crit, 10, cv2.KMEANS_PP_CENTERS)
            else:
               compactness, bestLabelsKM, centers = cv2.kmeans(pts, numK-ii-1, None, term_crit, 10, cv2.KMEANS_PP_CENTERS)

            D = findDist(centers, oldCenters)

         bestLabelsKM = np.reshape(bestLabelsKM,(len(bestLabelsKM),1))

      newCenters, newLabels, oldLabels = matchFlies(centers)

      # translate the k-means labels into the labels chosen by matchFlies
      bestLabels = np.zeros(bestLabelsKM.shape)
      for ii in range(len(newLabels)):
         bestLabels[bestLabelsKM == oldLabels[ii]] = newLabels[ii]

      bestLabelsKM = copy.copy(bestLabels)
      D = findDist(flyTrackerSettings.oldCenters, newCenters)
      # diagonal of D: distance between old center i and new center i
      matchedDist = D[np.eye(nFlies)==1]
      maxdist = np.amax(matchedDist)
      # smallest per-fly displacement (this used np.amax, which made it a
      # duplicate of maxdist and disabled the retry clause of the loop)
      mindist = np.amin(matchedDist)

      if (mindist > jumpMax):
         # every fly moved farther than jumpMax: stop retrying with more clusters
         # replace this (or add to it) delta timestamp
         if np.sum(flyTrackerSettings.oldCenters) != 0:
            foundJump = True

      if flyTrackerSettings.debugImages:
         # paint label 0 / label 1 / labels >= 2 into channels 0 / 1 / 2
         tmp = np.zeros(fg2.shape)
         for channel,ind in enumerate((bestLabelsKM==0, bestLabelsKM==1, bestLabelsKM>=2)):
            pts2 = pts[ind[:,0],0:2].astype(np.uint)
            tmp[pts2[:,0],pts2[:,1],channel] = 255

         imgDB('kMeans_' + str(numK) + '.jpg',tmp)

   # TODO: check how much the per-fly k-means contours overlap with the
   # larger foreground contours.

   printDB(round(1000*(time.time() - t)))    # 225ms
   t = time.time()

   # TODO: what we really want is to find situations where N flies are near
   # each other and to erode the image until ... (original note was left unfinished)

   ##############################
   # Extract each fly body and fit ellipses
   ##############################

   # NOTE(review): float16 has very coarse resolution for pixel coordinates in
   # the hundreds - confirm that callers really want half precision here.
   bodyEllipses = np.zeros((nFlies, 5), np.float16)
   bodyEllipses2 = np.zeros((nFlies, 5), np.float16)
   allRedLines = np.zeros((nFlies, 4), np.float16)

   # extract the body shape and fit an ellipse
   tmpBd = np.zeros(fg2.shape)
   fly_colors = [[255,0,0],[0,255,0],[0,0,255],[128,128,0],[0,128,127],[128,0,127]]
   for ii in range(nFlies):

      ind = bestLabelsKM==ii
      labelPoints = np.vstack((pts[ind[:,0],0], pts[ind[:,0],1],)).T.astype(np.uint)
      # Index with a *tuple* of (rows, cols): current NumPy no longer treats
      # a list of index arrays as a multi-dimensional index.
      grayPoints = grayOld[tuple(np.vsplit(labelPoints.T,2))].T
      labelPoints = labelPoints[grayPoints[:,0]!=0,:]
      grayPoints = grayPoints[grayPoints[:,0] != 0]
      # one sample per pixel: (row, col, gray value)
      flyBodyDat = np.append(labelPoints,grayPoints,axis=1).astype(np.float32)

      numBodyK = 8
      if cv2to3.isCV2():
         bodyCompactness, bestBodyLabels, bodyCenters = cv2.kmeans(flyBodyDat, numBodyK, term_crit, 10, cv2.KMEANS_PP_CENTERS)
      else:
         bodyCompactness, bestBodyLabels, bodyCenters = cv2.kmeans(flyBodyDat, numBodyK, None, term_crit, 10, cv2.KMEANS_PP_CENTERS)

      # Assume that the cluster with the lowest mean intensity is the 'shadow' and the 'wings'
      # Assume that the largest two clusters are just the body
      bodyThresh = np.amax(bodyCenters[:,-1])-1
      lblPts = np.zeros((numBodyK,1))
      for jj in range(numBodyK):
         lblPts[jj] = np.sum(bestBodyLabels==jj)

      # lower the intensity cut in steps of 5 until the clusters above it
      # hold at least 400 points (or every point of this fly)
      numPts = 0
      while numPts < min(400,bestBodyLabels.shape[0]):
         bodyThresh = bodyThresh - 5
         numPts = np.sum(lblPts[bodyCenters[:,-1] > bodyThresh])

      # True for every sample that belongs to a cluster brighter than the cut
      ind = np.zeros(bestBodyLabels.shape, dtype=bool)
      for jj in range(numBodyK):
         ind[bestBodyLabels == jj] = bodyCenters[jj,-1] > bodyThresh

      ptsBody = np.vstack((labelPoints[ind[:,0],0], labelPoints[ind[:,0],1])).T.astype(np.uint)
      # points are stored (row, col); reverse to (x, y) for the ellipse fit
      flyEllipse = cv2.fitEllipse(ptsBody[:,::-1])
      bodyEllipses[ii] = ellipseListToArray(flyEllipse)

      if flyTrackerSettings.debugImages:
         tmpBd[ptsBody[:,0],ptsBody[:,1],:] = fly_colors[ii]

      ptsTmp = np.vstack((labelPoints[:,0], labelPoints[:,1])).T
      flyEllipse2 = cv2.fitEllipse(ptsTmp[:,::-1])
      bodyEllipses2[ii] = ellipseListToArray(flyEllipse2)

      # I'm pretty sure this is the reduced points??
      thisLine = cv2.fitLine(ptsBody, cv2.DIST_L2, 0, 0.01, 0.01)
      lineVec = np.squeeze(thisLine)   # 4 values: direction, then a point on the line
      allRedLines[ii] = lineVec

      if flyTrackerSettings.debugImages:
         cv2.ellipse(tmp,flyEllipse,(64,0,255),3)
         if (flyEllipse[1][1]/flyEllipse[1][0] > 1.5):
            cv2.circle(tmp,(int(flyEllipse[0][0]),int(flyEllipse[0][1])),5,(255,255,255),3)

         cv2.ellipse(tmp,flyEllipse2,(64,64,255),1)
         if (flyEllipse2[1][1]/flyEllipse2[1][0] > 1.5):
            cv2.circle(tmp,(int(flyEllipse2[0][0]),int(flyEllipse2[0][1])),3,(128,128,128),3)

         # the line was fitted in (row, col) order, so swap the components to
         # get image (x, y); cv2.line needs plain integer tuples
         x1 = (int(lineVec[3]+lineVec[1]*50), int(lineVec[2]+lineVec[0]*50))
         x2 = (int(lineVec[3]-lineVec[1]*50), int(lineVec[2]-lineVec[0]*50))
         cv2.line(tmp, x1, x2 , (128,128,128), 2)

      # Identify wings
      # frames 203-204 are where I'm dying the most
      # split the fly's points along the line through the ellipse center at
      # the ellipse angle + 90 degrees (currently only visualised)
      a = np.tan(np.radians(flyEllipse[2]+90))
      b = flyEllipse[0][1] - a*flyEllipse[0][0]
      leftInd = labelPoints[:,0] > labelPoints[:,1]*a + b
      if flyTrackerSettings.debugImages:
         tmp[labelPoints[leftInd,0],labelPoints[leftInd,1],:] = 127

   printDB(round(1000*(time.time() - t)))    # 400ms
   t = time.time()
   if flyTrackerSettings.debugImages:
      imgDB('tmpout_' + str(frameCount) + '.jpg', tmp)
      imgDB('tmpBody_' + str(frameCount) + '.jpg', tmpBd)
      cv2.imshow('clustered', tmpBd)
      cv2.waitKey(1)

   flyFlags = 0

   return newCenters, bodyEllipses2, bodyEllipses, allRedLines, chosenThreshold, flyFlags