Example #1
    def findSamples(self, category, binImg, samples, outImg):
        contours = self.findContourAndBound(binImg.copy(), bounded=True,
                                            upperbounded=True,
                                            bound=self.minContourArea,
                                            upperbound=self.maxContourArea)
        contours = sorted(contours, key=cv2.contourArea, reverse=True)
        for contour in contours:
            # Find the center of each contour
            rect = cv2.minAreaRect(contour)
            centroid = rect[0]

            # Find the orientation of each contour
            points = np.int32(cv2.cv.BoxPoints(rect))
            edge1 = points[1] - points[0]
            edge2 = points[2] - points[1]

            if cv2.norm(edge1) > cv2.norm(edge2):
                angle = math.degrees(math.atan2(edge1[1], edge1[0]))
            else:
                angle = math.degrees(math.atan2(edge2[1], edge2[0]))

            if 90 < abs(Utils.normAngle(self.comms.curHeading) -
                        Utils.normAngle(angle)) < 270:
                angle = Utils.invertAngle(angle)

            samples.append({'centroid': centroid, 'angle': angle,
                            'area': cv2.contourArea(contour),
                            'category': category})

            if self.debugMode:
                Vision.drawRect(outImg, points)
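The heading correction above depends on Utils.normAngle and Utils.invertAngle, which are not shown in these examples. A minimal sketch of what such helpers might look like, assuming angles are kept in degrees and normalised to [0, 360) (the bodies are assumptions; only the names and call sites come from the example):

    # Hypothetical sketches of the angle helpers used above; not the actual Utils module.
    def normAngle(angle):
        """Normalise an angle in degrees to the range [0, 360)."""
        return angle % 360.0

    def invertAngle(angle):
        """Return the opposite direction, also normalised to [0, 360)."""
        return (angle + 180.0) % 360.0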
Example #2
    def print_table(ligs):
        for lig_num, lig in enumerate(ligs):
            for element_num, element in enumerate(lig):
                ligs[lig_num][element_num] = ActionSearch._format_.get(element,
                                                                       element)

        Utils.print_table(ligs)
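Utils.print_table receives a list of rows whose first row holds the column headers (see also Examples #7 and #10). The helper itself is not included here; a minimal sketch of a plain-text table printer, assuming every row has the same number of columns:

    # Hypothetical sketch; the real Utils.print_table is not part of these examples.
    def print_table(rows):
        """Print a list of rows (lists of values) as a column-aligned text table."""
        if not rows:
            return
        rows = [[str(cell) for cell in row] for row in rows]
        widths = [max(len(row[i]) for row in rows) for i in range(len(rows[0]))]
        for row in rows:
            print("  ".join(cell.ljust(width) for cell, width in zip(row, widths)))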
Example #3
    def execute(self, userdata):
        if self.comms.isKilled or self.comms.isAborted:
            self.comms.abortMission()
            return 'aborted'

        if not self.comms.retVal or \
           len(self.comms.retVal['foundLines']) == 0:
            return 'lost'

        lines = self.comms.retVal['foundLines']
        if len(lines) == 1 or self.comms.expectedLanes == 1:
            self.angleSampler.newSample(lines[0]['angle'])
        elif len(lines) >= 2:
            if self.comms.chosenLane == self.comms.LEFT:
                self.angleSampler.newSample(lines[0]['angle'])
            elif self.comms.chosenLane == self.comms.RIGHT:
                self.angleSampler.newSample(lines[1]['angle'])
            else:
                rospy.loginfo("Something goes wrong with chosenLane")

        variance = self.angleSampler.getVariance()
        rospy.loginfo("Variance: {}".format(variance))
        if (variance < 5.0):
            dAngle = Utils.toHeadingSpace(self.angleSampler.getMedian())
            adjustHeading = Utils.normAngle(self.comms.curHeading + dAngle)

            self.comms.sendMovement(h=adjustHeading, blocking=True)
            self.comms.adjustHeading = adjustHeading
            return 'aligned'
        else:
            rospy.sleep(rospy.Duration(0.05))
            return 'aligning'
Example #4
 def create_share(self, obj):
     share = Utils.create_object()
     share.title = obj.logo_title
     share.description = obj.short_description
     share.url = Utils.get_current_url(self.request)
     encoded_url = urlquote_plus(share.url)
     title = obj.logo_title
     encoded_title = urlquote_plus(title)
     encoded_detail = urlquote_plus(obj.short_description)
     url_detail = obj.short_description + '\n\n' + share.url
     encoded_url_detail = urlquote_plus(url_detail)
     share.image_url = Utils.get_url(
         self.request,
         PosterService.poster_image_url(obj)
     )
     encoded_image_url = urlquote_plus(share.image_url)
     # email shouldn't encode space
     share.email = 'subject=%s&body=%s' % (
         urlquote(title, ''), urlquote(url_detail, '')
     )
     #
     share.fb = 'u=%s' % encoded_url
     #
     share.twitter = 'text=%s' % encoded_url_detail
     #
     share.google_plus = 'url=%s' % encoded_url
     #
     share.linkedin = 'url=%s&title=%s&summary=%s' % (
         encoded_url, encoded_title, encoded_detail
     )
     #
     share.pinterest = 'url=%s&media=%s&description=%s' % (
         encoded_url, encoded_image_url, encoded_detail
     )
     return share
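Utils.create_object() is used above as a bare attribute container whose fields (title, description, url, ...) are filled in afterwards. The real helper is not shown; one way it could be written, as an assumption:

    # Hypothetical sketch of an empty attribute container, matching how create_object() is used above.
    class _Container(object):
        pass

    def create_object():
        return _Container()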
Example #5
    def execute(self, userdata):
        if self.comms.isAborted or self.comms.isKilled:
            self.comms.abortMission()
            return 'aborted'

        if not self.comms.retVal or \
           len(self.comms.retVal['matches']) == 0:
            return 'lost'

        self.comms.sendMovement(d=self.comms.aligningDepth,
                                blocking=True)
        try:
            # Align with the bins
            dAngle = Utils.toHeadingSpace(self.comms.nearest)
            adjustAngle = Utils.normAngle(dAngle + self.comms.curHeading)
            self.comms.adjustHeading = adjustAngle
            self.comms.visionFilter.visionMode = BinsVision.BINSMODE
            self.comms.sendMovement(h=adjustAngle,
                                    d=self.comms.aligningDepth,
                                    blocking=True)
            #self.comms.sendMovement(h=adjustAngle,
            #                        d=self.comms.sinkingDepth,
            #                        blocking=True)
            return 'aligned'
        except Exception as e:
            rospy.logerr(str(e))
            adjustAngle = self.comms.curHeading
            self.comms.adjustHeading = adjustAngle
            self.comms.sendMovement(h=adjustAngle, blocking=True)
            #self.comms.sendMovement(d=self.comms.sinkingDepth,
            #                        blocking=True)
            return 'aligned'
Example #6
 def camCallback(self, rosImg):
     outImg = self.visionFilter.gotFrame(Utils.rosimg2cv(rosImg))
     if self.canPublish and outImg is not None:
         try:
             self.outPub.publish(Utils.cv2rosimg(outImg))
         except Exception, e:
             pass
Example #7
    def run(self):
        images = self.client.images()

        # Parse images information
        images_enhanced = []
        for img in images:
            for repotag in img["RepoTags"]:
                registry, repository = self.parse_repository(
                    ":".join(repotag.split(":")[:-1]))
                images_enhanced.append({"IMAGE ID": img["Id"][:10],
                                        "CREATED": img["Created"],
                                        "VIRTUAL SIZE": img["VirtualSize"],
                                        "TAG": repotag.split(":")[-1],
                                        "REPOSITORY": repository,
                                        "REGISTRY": registry,
                                        })

        # Sort images (with facilities for sort key)
        sort_by = self.args.sort_by
        for column in self._FIELDS_:
            if column.startswith(sort_by.upper()):
                sort_by = column
                break
        images = sorted(images_enhanced, key=lambda x: x.get(sort_by))

        # Print images information
        for img in images:
            img["VIRTUAL SIZE"] = ActionImages.printable_size(
                img["VIRTUAL SIZE"])
            img["CREATED"] = ActionImages.printable_date(img["CREATED"])

        Utils.print_table([self._FIELDS_] + [[img[k]
                                              for k in self._FIELDS_] for img in images])
Example #8
    def _evaluate_bandwidth_availability(self, check_1, check_2):
        state = ''
        changed = False

        wma_1 = Utils.compute_wma(check_1)
        logging.debug('BANDWIDTH wma_1: %s' % (wma_1))

        if check_2 is not None:
            wma_2 = Utils.compute_wma(check_2)
            logging.debug('BANDWIDTH wma_2: %s' % (wma_2))

            wma_diff = wma_2 - wma_1
            wma_diff_abs = abs(wma_diff)
            variation = round(float(wma_diff_abs/wma_1*100), 1)
            logging.debug('BANDWIDTH variation: %s' % (variation))

            if variation >= float(self._cfg['bandwidth_avail_factor']):
                variation_dim = 'accretion'
                if wma_diff > 0.0:
                    variation_dim = 'degradation'

                state = 'BANDWIDTH availability --> %s%% %s   <b style="color: red;">(!!! NEW !!!)</b>' % (variation, variation_dim)
                changed = True
            else:
                state = 'BANDWIDTH availability --> %s' % (wma_2)
        else:
            state = 'BANDWIDTH availability --> %s' % (wma_1)

        logging.debug('BANDWIDTH check_final: %s' % (state))
        return changed, state
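The availability check relies on Utils.compute_wma, which is not shown here. Assuming check_1 and check_2 are sequences of numeric samples ordered oldest to newest (an assumption, not stated in the example), a linearly weighted moving average could look like this:

    # Hypothetical sketch; the real Utils.compute_wma and its input format are not shown here.
    def compute_wma(samples):
        """Linearly weighted moving average: the newest sample gets the largest weight."""
        weights = range(1, len(samples) + 1)
        return float(sum(w * s for w, s in zip(weights, samples))) / sum(weights)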
Example #9
 def __init__(self, config, phishtank, openphish, cleanmx):
     logging.debug("Instantiating the '%s' class" % (self.__class__.__name__))
     self._cfg = config
     self._phishtank = phishtank
     self._openphish = openphish
     self._cleanmx = cleanmx
     Utils.remove_dir_content(self._cfg['dir_out'])
Example #10
    def _clean_status_container(self, status):
        targets = []
        for container in self.client.containers(all=True):
            if container["Status"].startswith(status):
                # Sanitize
                if container["Names"] is None:
                    container["Names"] = ["NO_NAME"]
                targets.append(container)

        if len(targets) == 0:
            print "No containers %s found." % (status.lower())
            return

        # Display available elements
        print "%d containers %s founds." % (len(targets), status.lower())
        ligs = [["NAME", "IMAGE", "COMMAND"]]
        ligs += [[",".join(c["Names"]).replace("/", ""), c["Image"], c["Command"]]
                 for c in targets]
        Utils.print_table(ligs)

        if Utils.ask("Remove some of them", default="N"):
            for container in targets:
                if Utils.ask(
                        "\tRemove %s" % container["Names"][0].replace("/", ""),
                        default="N"):
                    # Force is false to avoid bad moves
                    print "\t-> Removing %s..." % container["Id"][:10]
                    self.client.remove_container(container["Id"], v=False,
                                                 link=False,
                                                 force=False)
Example #11
    def run(self):
        print("Running...")

        os.system('ls -i -t ' + self._cfg['dir_archive'] +'/* | cut -d\' \' -f2 | tail -n+1 | xargs rm -f')

        data_old = JSONAdapter.load(self._cfg['dir_in'], self._cfg['serial_file'])

        if data_old is not None:
            data_file_name_new = data_old[0] + '_' + self._cfg['serial_file']
            Utils.rename_file(self._cfg['dir_in'], self._cfg['dir_archive'], \
                              self._cfg['serial_file'], data_file_name_new)
        else:
            data_old = []

        if Utils.is_internet_up() is True:
            urls = self._cfg['urls_to_check']
            data_new = []
            data_new.insert(0, Utils.timestamp_to_string(time.time()))
            thread_id = 0
            threads = []
            display = Display(visible=0, size=(1024, 768))
            display.start()
            for url in urls:
                thread_id += 1
                alivechecker_thread = AliveChecker(thread_id, self._cfg, url)
                threads.append(alivechecker_thread)
                alivechecker_thread.start()

            # Waiting for all threads to complete
            for thread in threads:
                thread.join()
            display.stop()

            for thread in threads:
                data_new.append(thread.data)
                logging.debug('%s\n' % (thread.log))
                thread.browser.quit()

            if len(data_new) > 0:
                JSONAdapter.save(data_new, self._cfg['dir_in'], self._cfg['serial_file'])

                data_all = []
                if len(data_old) > 0:
                    data_all.append(data_old)
                data_all.append(data_new)
                JSONAdapter.save(data_all, self._cfg['dir_out'], self._cfg['serial_file'])

                state = self._evaluator.run(data_all)
                logging.debug('Final state: %s' % (state))

                if self._emailnotifiers and state != '':
                    EmailNotifiers.notify(state)
            else:
                logging.debug('Empty data')
        else:
            logging.error('Internet is definitely down!')
            sys.exit(2)

        print("Done...")
Example #12
 def sonarImageCallback(self, rosImg):
     # outImg = self.visionFilter.gotSonarFrame(Utils.rosimg2cv(rosImg))
     outImg = self.sonarFilter.gotSonarFrame(Utils.rosimg2cv(rosImg))
     if self.canPublish and outImg is not None:
         try:
             self.sonarPub.publish(Utils.cv2rosimg(outImg))
         except Exception, e:
             pass
Example #13
 def camCallback(self, rosImg):
     try:
         if self.processingCount == self.processingRate:
             self.retVal, outImg = self.visionFilter.gotFrame(Utils.rosimg2cv(rosImg))
             if self.canPublish and outImg is not None:
                 self.outPub.publish(Utils.cv2rosimg(outImg))
             self.processingCount = 0
         self.processingCount += 1
     except Exception as e:
         print e
Example #14
 def camCallback(self, img):
     rospy.loginfo("Solo")
     img = Utils.rosimg2cv(img) 
     red_img = Img(img, conn)
     if(1000 < red_img.area < 1500):
         red_img.drawBounding(red_img.mask_bgr)
         red_img.drawCentroid(red_img.mask_bgr)
     drawCenter(red_img.mask_bgr)
     self.img_pub.publish(Utils.cv2rosimg(red_img.mask_bgr))
     self.img_pub2.publish(Utils.cv2rosimg(red_img.enhanced_bgr))
Example #15
 def sonarCallback(self, rosimg):
     rospy.loginfo("Inside sonar")
     cvImg = Utils.rosimg2cv(rosimg)
     gray = cv2.cvtColor(cvImg, cv2.COLOR_BGR2GRAY)
     mask = cv2.threshold(gray, 200, 255, cv2.THRESH_BINARY)[1]
     mask_bgr = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
     sobel_bgr = self.sobelLine(mask)
     #self.sonarCont(mask, mask_bgr)
     sonarPub = rospy.Publisher("/Vision/image_filter_sonar", Image)
     sonarPub.publish(Utils.cv2rosimg(sobel_bgr))
Example #16
 def _dump_file_trails(self, apk_file):
     logging.debug("Dumping file_trails")
     path = apk_file.get_path()
     file = apk_file.get_filename()
     file_trails = dict()
     file_trails['file_name'] = file
     file_trails['file_md5_sum'] = Utils.compute_md5_file(path, file)
     file_trails['file_sha256_sum'] = Utils.compute_sha256_file(path, file)
     file_trails['file_dimension'] = Utils.get_size(path, file)
     return file_trails
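The hashing and size helpers used above are not included in these examples. A minimal sketch of compute_md5_file and get_size, with compute_sha256_file being analogous (the bodies are assumptions; only the names and (path, file) argument order come from the call sites):

    # Hypothetical sketches; only the names and argument order are taken from the example above.
    import hashlib
    import os

    def compute_md5_file(path, filename):
        """Return the hex MD5 digest of a file, read in chunks."""
        digest = hashlib.md5()
        with open(os.path.join(path, filename), 'rb') as f:
            for chunk in iter(lambda: f.read(8192), b''):
                digest.update(chunk)
        return digest.hexdigest()

    def get_size(path, filename):
        """Return the file size in bytes."""
        return os.path.getsize(os.path.join(path, filename))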
Example #17
    def execute(self, userdata):
        if self.comms.isKilled or self.comms.isAborted:
            self.comms.abortMission()
            return 'aborted'

        curCorner = self.comms.visionFilter.curCorner

        start = time.time()
        while not self.comms.retVal or \
              self.comms.retVal.get('foundLines', None) is None or \
              len(self.comms.retVal['foundLines']) == 0:
            if self.comms.isKilled or self.comms.isAborted:
                self.comms.abortMission()
                return 'aborted'
            if time.time() - start > self.timeout:
                if curCorner == 4: 
                    self.comms.failTask()
                    return 'lost'
                else:
                    self.comms.visionFilter.curCorner += 1
                    self.comms.detectingBox = True
                    return 'next_corner'
            rospy.sleep(rospy.Duration(0.1))

        # Calculate angle between box and lane
        if self.comms.visionFilter.curCorner == 0:
            boxCentroid = (self.centerX, self.centerY)
        else:
            boxCentroid = self.comms.visionFilter.corners[curCorner]
        laneCentroid = self.comms.retVal['foundLines'][0]['pos']
        boxLaneAngle = math.atan2(laneCentroid[1] - boxCentroid[1],
                                  laneCentroid[0] - boxCentroid[0])
        self.angleSampler.newSample(math.degrees(boxLaneAngle))

        variance = self.angleSampler.getVariance()
        rospy.loginfo("Variance: {}".format(variance))
        if (variance < 5.0):
            dAngle = Utils.toHeadingSpace(self.angleSampler.getMedian())
            adjustHeading = Utils.normAngle(self.comms.curHeading + dAngle)
            self.comms.inputHeading = adjustHeading
            rospy.loginfo("box-lane angle: {}".format(self.comms.inputHeading))
            self.comms.sendMovement(h=adjustHeading,
                                    d=self.comms.laneSearchDepth,
                                    blocking=True)
            self.comms.sendMovement(f=self.forward_dist, h=adjustHeading,
                                    d=self.comms.laneSearchDepth,
                                    blocking=True)
            self.comms.visionFilter.curCorner = 0
            return 'aligned'
        else:
            rospy.sleep(rospy.Duration(0.05))
            return 'aligning'
Example #18
 def _extract(self, path_in, file, password):
     try:
         if Utils.is_zip(path_in, file):
             Utils.extract_zip(path_in, file, path_in, password)
         elif Utils.is_rar(path_in, file):
             Utils.extract_rar(path_in, file, path_in, password)
         elif Utils.is_tar(path_in, file):
             Utils.extract_tar(path_in, file, path_in)
     except OSError, e:
         logging.error(e)
         return False
Example #19
    def save_ph_ws(ph_ws):
        logging.debug("XMLAdapter is storing phishing websites...")
        Utils.remove_dir_content(XMLAdapter.config['dir_out'])

        for cm_name, cm in ph_ws.items():
            if len(cm) > 0:
                try:
                    file_ph_ws_name = os.path.join(XMLAdapter.config['dir_out'], cm_name + '.xml')
                    file_ph_ws = open(file_ph_ws_name, 'w')

                    file_ph_ws.write(dicttoxml.dicttoxml(cm))  # serialise only this category's entries
                    file_ph_ws.close()
                except OSError, e:
                    logging.error("Error saving phishing websites in xml format: %s" % (e))
                    raise OSError
Example #20
    def run(self):
        # Format inputs
        reg_from = self.get_registryaddr("from")
        reg_to = self.get_registryaddr("to")
        imgsrc = self.args.IMGSRC
        if ":" not in imgsrc:
            imgsrc += ":latest"
        imgdst = self.args.IMGDST if (self.args.IMGDST != '-') else imgsrc
        if ":" not in imgdst:
            imgdst += ":latest"
        isrc = reg_from + imgsrc
        idst = reg_to + imgdst

        # Confirm transfer
        if not Utils.ask("Transfer %s -> %s" % (isrc, idst)):
            return

        # Search for source image availability
        isrc_id = None
        for img in self.client.images():
            if isrc in img["RepoTags"]:
                isrc_id = img["Id"]
                print "'%s' is locally available (%s), use it" % (isrc, img["Id"][:10])

        # Source image is not available, pull it
        if isrc_id is None:
            if not Utils.ask("'%s' is not locally available, try to pull it" % isrc):
                return
            # Try to pull Image without creds
            res = self.client.pull(isrc, insecure_registry=True)
            if "error" in res:
                print "An error as occurred (DEBUG: %s)" % res
                return
            print "'%s' successfully pulled !" % isrc

            raise NotImplementedError("Get image id")

        # Tag the element
        idst, idst_tag = ":".join(idst.split(":")[:-1]), idst.split(":")[-1]
        self.client.tag(isrc_id, idst, tag=idst_tag, force=False)

        # Push the element, insecure mode
        print "Pushing..."
        for status in self.client.push(idst, tag=idst_tag, stream=True,
                                       insecure_registry=True):
            sys.stdout.write("\r" + json.loads(status)["status"])
            sys.stdout.flush()
        print "\nTransfer complete !"
Example #21
 def turnRight(self):
     rospy.loginfo("Turning right...")
     # Turn to the right and look for another bin
     self.comms.sendMovement(h=Utils.normAngle(self.comms.adjustHeading+90),
                             d=self.comms.turnDepth,
                             blocking=True)
     self.comms.sendMovement(f=0.9, d=self.comms.turnDepth, blocking=True)
Example #22
 def get_queryset(self):
     queryset = super(PosterFunViewSet, self).get_queryset()
     queryset = queryset.filter(
         poster=self.kwargs['poster_id'],
         ip_address=Utils.get_client_ip(self.request._request)
     )
     return queryset
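Utils.get_client_ip is called with the underlying Django HttpRequest (self.request._request). The helper is not shown; a common way to resolve the client address, honouring X-Forwarded-For when a proxy sets it, is sketched below as an assumption:

    # Hypothetical sketch for a Django HttpRequest; the real Utils.get_client_ip may differ.
    def get_client_ip(request):
        forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR')
        if forwarded_for:
            return forwarded_for.split(',')[0].strip()
        return request.META.get('REMOTE_ADDR')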
Example #23
    def get_likely_vehicles(self, nexttrip):
        prevtrips = []

        self._connect()
        self.cur.execute('SELECT date, vehicle FROM VEHICLES where trip=? ORDER BY date DESC', [nexttrip])
        date_vehicles = self.cur.fetchall()

        for d, v in date_vehicles:
            self.cur.execute('SELECT trip FROM VEHICLES where date=? and vehicle=? order by time asc', [d, v])
            ts = self.cur.fetchall()
            if len(ts)> 1:
                # find the trip just before our trip
                prevtrip = None
                for triplist in ts:
                    trip = triplist[0]
                    if trip == nexttrip:
                        # previous trip (if any) led to this one
                        if prevtrip:
                            prevtrips.append(prevtrip)
                    else:
                        prevtrip = trip

        ranked_prev_trips = Utils.ranklist(prevtrips)

        vehicles = []
        for trip, count in ranked_prev_trips:
            # get the latest vehicle for each trip
            self.cur.execute('SELECT vehicle FROM VEHICLES where trip=? ORDER BY date DESC', [trip])
            vehicle = self.cur.fetchone()
            vehicles.append((vehicle, count))

        return vehicles
Example #24
    def execute(self, userdata):
        if self.comms.isAborted or self.comms.isKilled:
            self.comms.abortMission()
            return 'aborted'

        if not self.comms.retVal or \
           len(self.comms.retVal['samples']) == 0:
            self.comms.adjustDepth = self.comms.curDepth
            rospy.sleep(rospy.Duration(0.1))
            return 'approaching'

        curArea = self.comms.retVal['samples'][0]['area']
        rospy.loginfo("Area: {}".format(curArea))
        if curArea  > self.comms.grabbingArea or \
           self.comms.curDepth > self.comms.grabbingDepth:
            self.comms.adjustDepth = self.comms.curDepth
            return 'completed'

        samples = self.comms.retVal['samples']
        closest = min(samples,
                      key=lambda c:
                      Utils.distBetweenPoints(c['centroid'],
                                              (self.centerX, self.centerY)))
        dx = (closest['centroid'][0] - self.centerX) / self.width
        dy = (closest['centroid'][1] - self.centerY) / self.height

        self.comms.sendMovement(f=-self.ycoeff*dy, sm=self.xcoeff*dx,
                                d=self.comms.curDepth + 0.1,
                                h=self.comms.adjustHeading,
                                timeout=2,
                                blocking=False)
        return 'approaching'
Example #25
 def sonarImageCallback(self, rosImg):
     outImg = self.gotSonarFrame(self.rosImgToCVMono(rosImg))
     if outImg is not None:
         try:
             self.sonarPub.publish(Utils.cv2rosimg(outImg))
         except Exception, e:
             pass
Example #26
    def test_utils(self):

        testlist = [1, 2, 3, 3, 4, 5, 1, 1, 1, 2]
        rankedlist = Utils.ranklist(testlist)
        self.assertEqual(rankedlist,
            [(1, 4), (2, 2), (3, 2), (4, 1), (5, 1)])
        pass
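The expected output in test_utils pins down the behaviour of Utils.ranklist: count occurrences and return (value, count) pairs sorted by count descending, with ties ordered by value. A sketch consistent with that expectation (the actual implementation may differ):

    # Sketch consistent with the assertEqual above; not necessarily the real Utils.ranklist.
    from collections import Counter

    def ranklist(items):
        counts = Counter(items)
        return sorted(counts.items(), key=lambda pair: (-pair[1], pair[0]))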
Example #27
    def __init__(self, peer_id, log=None):
        super(PeerState, self).__init__()
        self.id      = peer_id # site ID
        self.peers   = [] # known peers
        self.strokes = [] # currently drawn strokes
        self.prqs    = [] # past requests
        self.processed_ops = []

        self.session = -1

        # attached ui
        self.window = None

        self.lock = Lock()

        # site log file
        self.log = log

        if self.id >= 0:
            self.engine = OperationEngine(self.id,log)
        else:
            # This is so that I can draw locally if I never join a session
            self.engine = OperationEngine(0,log)

        self.queue  = Queue(log)

        # Join/leave handling
        self.ip = ''
        self.port = 0
        self.cs = None
        self.uid = Utils.generateID()
        self.ips = []
        self.ports = []
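Utils.generateID is used here (and in Example #29) to produce unique peer and stroke identifiers. Its implementation is not shown; a minimal sketch using a random UUID, offered as an assumption:

    # Hypothetical sketch; the real Utils.generateID is not part of these examples.
    import uuid

    def generateID():
        """Return a random, globally unique identifier string."""
        return uuid.uuid4().hex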
Example #28
    def _download_openphish(self, path_in, user_agent, url_feed, username, password):
        logging.debug("Downloading OpenPhish feeds...")
        try:
            password_manager = urllib2.HTTPPasswordMgrWithDefaultRealm()
            password_manager.add_password(None, url_feed, username, password)

            auth = urllib2.HTTPBasicAuthHandler(password_manager)
            opener = urllib2.build_opener(auth)
            opener.addheaders = [('User-agent', user_agent)]
            urllib2.install_opener(opener)

            request = urllib2.Request(url_feed)
            handler = urllib2.urlopen(request)

            count = 0
            json_str = ''
            for d in handler.readlines():
                if count == 0:
                    json_str = '{'
                json_str += '\"phish_' + str(count) + "\" : " + d.strip() + ', \n'
                count = count + 1
            json_str = json_str[:-3] + '}'
            json_obj = json.loads(json_str)
            json_obj = Utils.decode_dict(json_obj)

            with open(os.path.join(path_in, 'openphish.txt'), 'w') as f:
                json.dump(json_obj, f, ensure_ascii=False)
        except Exception, e:
            logging.error("Error downloading OpenPhish feed '%s': %s" % (url_feed, e))
            raise Exception
Example #29
 def __init__(self, path=None, width=0, color=None, id='none'):
     # Avoid mutable default arguments, which would be shared across instances
     self.path  = path if path is not None else []
     self.width = width
     self.color = color if color is not None else [0, 0, 0, 255]
     if id == 'none':
         self.id = Utils.generateID()
     else:
         self.id    = id
Example #30
    def _walk_dir(self, dir, path_out):
        try:
            for root, dirs, files in os.walk(dir):
                for file in files:
                    # keep the result of replace(); it returns a new string
                    file = file.replace('$', '\$')

                    if Utils.is_zip(root, file) or \
                            Utils.is_rar(root, file) or \
                            Utils.is_tar(root, file):
                        self._extract_file(root, file)
                        Utils.remove_file(root, file)
                    else:
                        if Utils.is_apk(root, file):
                            Utils.rename_file(root, path_out, file)
                        else:
                            Utils.remove_file(root, file)
        except OSError, e:
            logging.error("Error walking dir '%s': %s" % (dir, e))
            raise OSError
Example #31
 def get_navi_search_fav_ele(self):
     return Utils().get_ele_by_resourceId(pkg_name + ':id/li_layout2')
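Examples #31 through #50 all resolve Android UI elements through Utils().get_ele_by_resourceId(pkg_name + ':id/...'). The helper itself is not shown; a minimal sketch, assuming the project drives the device through the python uiautomator binding (that library and the device attribute are assumptions; only the method name and resource-id argument come from the examples):

    # Hypothetical sketch, assuming the 'uiautomator' Python binding.
    from uiautomator import Device

    class Utils(object):
        def __init__(self):
            self.device = Device()  # default adb-connected device

        def get_ele_by_resourceId(self, resource_id):
            """Return the UI element selector matching an Android resource id."""
            return self.device(resourceId=resource_id)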
Example #32
 def get_navi_nearby_name_ele(self):
     return Utils().get_ele_by_resourceId(pkg_name + ':id/tv_name')
Example #33
 def get_navi_nearby_back_ele(self):
     return Utils().get_ele_by_resourceId(pkg_name + ':id/btn_back')
Example #34
 def __get_navi_home_ele(self):
     return Utils().get_ele_by_resourceId(pkg_name + ':id/go_home')
Example #35
 def extract_title(self, tree):
     title = tree.xpath(self.title_xpath)[0].strip()
     title = Utils.transform_coding(title.strip())
     return title
Example #36
 def get_navi_search_addr_list_title(self):
     return Utils().get_ele_by_resourceId(pkg_name + ':id/list_item_title')
Example #37
 def get_navi_ready_to_dest_ele(self):
     return Utils().get_ele_by_resourceId(pkg_name + ':id/iv_dest')
Example #38
 def __get_navi_active_confirmorcancel(self):
     return Utils().get_ele_by_resourceId(pkg_name +
                                          ':id/tv_active_confirmorcancel')
Example #39
 def get_navi_current_road_name_ele(self):
     return Utils().get_ele_by_resourceId(pkg_name + ':id/ll_current_name')
Example #40
 def get_navi_Compass_ele(self):
     return Utils().get_ele_by_resourceId(pkg_name + ':id/llCompass')
Example #41
 def get_navi_satellite_ele(self):
     return Utils().get_ele_by_resourceId(pkg_name + ':id/iv_satellite')
Example #42
 def get_navi_zoomSeekBar_ele(self):
     return Utils().get_ele_by_resourceId(pkg_name + ':id/zoomSeekBar')
Example #43
 def get_navi_time_indicator_ele(self):
     return Utils().get_ele_by_resourceId(pkg_name + ':id/fl_tmc_indicator')
Example #44
 def __get_navi_et_activenum_ele(self):
     return Utils().get_ele_by_resourceId(pkg_name + ':id/et_activenum')
Example #45
 def get_navi_navipager_ele(self):
     return Utils().get_ele_by_resourceId(pkg_name + ':id/pager')
Example #46
 def __get_navi_nearby_close_ele(self):
     return Utils().get_ele_by_resourceId(pkg_name + ':id/btn_close')
Example #47
 def get_navi_search_listview_ele(self):
     return Utils().get_ele_by_resourceId(pkg_name + ':id/searchtip')
Example #48
 def __get_navi_mapback_ele(self):
     return Utils().get_ele_by_resourceId(pkg_name + ':id/iv_mapback')
Example #49
 def get_navi_search_city_list_item(self):
     return Utils().get_ele_by_resourceId(pkg_name + ':id/list_item_tip')
Example #50
 def __get_navi_search_result_all_ele(self):
     return Utils().get_ele_by_resourceId(pkg_name + ':id/ll_result_all')
Example #51
class Ai2Thor():
    def __init__(self):
        self.visualize = False
        self.verbose = False
        self.save_imgs = True

        self.plot_loss = True
        # st()

        mapnames = []
        for i in [1, 201, 301, 401]:
            mapname = 'FloorPlan' + str(i)
            mapnames.append(mapname)

        # random.shuffle(mapnames)
        self.mapnames_train = mapnames
        self.num_episodes = len(self.mapnames_train)

        # get rest of the house in orders
        a = np.arange(2, 30)
        b = np.arange(202, 231)
        c = np.arange(302, 331)
        d = np.arange(402, 431)
        abcd = np.hstack((a, b, c, d))
        mapnames = []
        for i in range(a.shape[0]):
            mapname = 'FloorPlan' + str(a[i])
            mapnames.append(mapname)
            mapname = 'FloorPlan' + str(b[i])
            mapnames.append(mapname)
            mapname = 'FloorPlan' + str(c[i])
            mapnames.append(mapname)
            mapname = 'FloorPlan' + str(d[i])
            mapnames.append(mapname)

        self.mapnames_test = mapnames

        self.ignore_classes = []
        # classes to save
        # self.include_classes = [
        #     'ShowerDoor', 'Cabinet', 'CounterTop', 'Sink', 'Towel', 'HandTowel', 'TowelHolder', 'SoapBar',
        #     'ToiletPaper', 'ToiletPaperHanger', 'HandTowelHolder', 'SoapBottle', 'GarbageCan', 'Candle', 'ScrubBrush',
        #     'Plunger', 'SinkBasin', 'Cloth', 'SprayBottle', 'Toilet', 'Faucet', 'ShowerHead', 'Box', 'Bed', 'Book',
        #     'DeskLamp', 'BasketBall', 'Pen', 'Pillow', 'Pencil', 'CellPhone', 'KeyChain', 'Painting', 'CreditCard',
        #     'AlarmClock', 'CD', 'Laptop', 'Drawer', 'SideTable', 'Chair', 'Blinds', 'Desk', 'Curtains', 'Dresser',
        #     'Watch', 'Television', 'WateringCan', 'Newspaper', 'FloorLamp', 'RemoteControl', 'HousePlant', 'Statue',
        #     'Ottoman', 'ArmChair', 'Sofa', 'DogBed', 'BaseballBat', 'TennisRacket', 'VacuumCleaner', 'Mug', 'ShelvingUnit',
        #     'Shelf', 'StoveBurner', 'Apple', 'Lettuce', 'Bottle', 'Egg', 'Microwave', 'CoffeeMachine', 'Fork', 'Fridge',
        #     'WineBottle', 'Spatula', 'Bread', 'Tomato', 'Pan', 'Cup', 'Pot', 'SaltShaker', 'Potato', 'PepperShaker',
        #     'ButterKnife', 'StoveKnob', 'Toaster', 'DishSponge', 'Spoon', 'Plate', 'Knife', 'DiningTable', 'Bowl',
        #     'LaundryHamper', 'Vase', 'Stool', 'CoffeeTable', 'Poster', 'Bathtub', 'TissueBox', 'Footstool', 'BathtubBasin',
        #     'ShowerCurtain', 'TVStand', 'Boots', 'RoomDecor', 'PaperTowelRoll', 'Ladle', 'Kettle', 'Safe', 'GarbageBag', 'TeddyBear',
        #     'TableTopDecor', 'Dumbbell', 'Desktop', 'AluminumFoil', 'Window']

        # These are all classes shared between aithor and coco
        self.include_classes = [
            'Sink',
            'Toilet',
            'Bed',
            'Book',
            'CellPhone',
            'AlarmClock',
            'Laptop',
            'Chair',
            'Television',
            'RemoteControl',
            'HousePlant',
            'Ottoman',
            'ArmChair',
            'Sofa',
            'BaseballBat',
            'TennisRacket',
            'Mug',
            'Apple',
            'Bottle',
            'Microwave',
            'Fork',
            'Fridge',
            'WineBottle',
            'Cup',
            'ButterKnife',
            'Toaster',
            'Spoon',
            'Knife',
            'DiningTable',
            'Bowl',
            'Vase',
            'TeddyBear',
        ]

        self.maskrcnn_to_ithor = {
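            # NOTE: ids 44, 47, 49 and 62 appear more than once below; in a dict literal
            # the last entry for a duplicated key wins, so the earlier mappings are dropped.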
            81: 'Sink',
            70: 'Toilet',
            65: 'Bed',
            84: 'Book',
            77: 'CellPhone',
            85: 'AlarmClock',
            73: 'Laptop',
            62: 'Chair',
            72: 'Television',
            75: 'RemoteControl',
            64: 'HousePlant',
            62: 'Ottoman',
            62: 'ArmChair',
            63: 'Sofa',
            39: 'BaseballBat',
            43: 'TennisRacket',
            47: 'Mug',
            53: 'Apple',
            44: 'Bottle',
            78: 'Microwave',
            48: 'Fork',
            82: 'Fridge',
            44: 'WineBottle',
            47: 'Cup',
            49: 'ButterKnife',
            80: 'Toaster',
            50: 'Spoon',
            49: 'Knife',
            67: 'DiningTable',
            51: 'Bowl',
            86: 'Vase',
            88: 'TeddyBear',
        }

        self.ithor_to_maskrcnn = {
            'Sink': 81,
            'Toilet': 70,
            'Bed': 65,
            'Book': 84,
            'CellPhone': 77,
            'AlarmClock': 85,
            'Laptop': 73,
            'Chair': 62,
            'Television': 72,
            'RemoteControl': 75,
            'HousePlant': 64,
            'Ottoman': 62,
            'ArmChair': 62,
            'Sofa': 63,
            'BaseballBat': 39,
            'TennisRacket': 43,
            'Mug': 47,
            'Apple': 53,
            'Bottle': 44,
            'Microwave': 78,
            'Fork': 48,
            'Fridge': 82,
            'WineBottle': 44,
            'Cup': 47,
            'ButterKnife': 49,
            'Toaster': 80,
            'Spoon': 50,
            'Knife': 49,
            'DiningTable': 67,
            'Bowl': 51,
            'Vase': 86,
            'TeddyBear': 88,
        }

        self.maskrcnn_to_catname = {
            81: 'sink',
            67: 'dining table',
            65: 'bed',
            84: 'book',
            77: 'cell phone',
            70: 'toilet',
            85: 'clock',
            73: 'laptop',
            62: 'chair',
            72: 'tv',
            75: 'remote',
            64: 'potted plant',
            63: 'couch',
            39: 'baseball bat',
            43: 'tennis racket',
            47: 'cup',
            53: 'apple',
            44: 'bottle',
            78: 'microwave',
            48: 'fork',
            82: 'refrigerator',
            46: 'wine glass',
            49: 'knife',
            79: 'oven',
            80: 'toaster',
            50: 'spoon',
            67: 'dining table',
            51: 'bowl',
            86: 'vase',
            88: 'teddy bear',
        }

        self.obj_conf_dict = {
            'sink': [],
            'dining table': [],
            'bed': [],
            'book': [],
            'cell phone': [],
            'clock': [],
            'laptop': [],
            'chair': [],
            'tv': [],
            'remote': [],
            'potted plant': [],
            'couch': [],
            'baseball bat': [],
            'tennis racket': [],
            'cup': [],
            'apple': [],
            'bottle': [],
            'microwave': [],
            'fork': [],
            'refrigerator': [],
            'wine glass': [],
            'knife': [],
            'oven': [],
            'toaster': [],
            'spoon': [],
            'dining table': [],
            'bowl': [],
            'vase': [],
            'teddy bear': [],
        }

        self.data_store = {
            'sink': {},
            'dining table': {},
            'bed': {},
            'book': {},
            'cell phone': {},
            'clock': {},
            'laptop': {},
            'chair': {},
            'tv': {},
            'remote': {},
            'potted plant': {},
            'couch': {},
            'baseball bat': {},
            'tennis racket': {},
            'cup': {},
            'apple': {},
            'bottle': {},
            'microwave': {},
            'fork': {},
            'refrigerator': {},
            'wine glass': {},
            'knife': {},
            'oven': {},
            'toaster': {},
            'spoon': {},
            'dining table': {},
            'bowl': {},
            'vase': {},
            'teddy bear': {},
        }

        self.data_store_features = []
        self.feature_obj_ids = []
        self.first_time = True
        self.Softmax = nn.Softmax(dim=0)

        self.action_space = {
            0: "MoveLeft",
            1: "MoveRight",
            2: "MoveAhead",
            3: "MoveBack",
            4: "DoNothing"
        }
        self.num_actions = len(self.action_space)

        cfg_det = get_cfg()
        cfg_det.merge_from_file(
            model_zoo.get_config_file(
                "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
        cfg_det.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.1  # set threshold for this model
        cfg_det.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(
            "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
        cfg_det.MODEL.DEVICE = 'cuda'
        self.cfg_det = cfg_det
        self.maskrcnn = DefaultPredictor(cfg_det)

        self.normalize = transforms.Compose([
            transforms.Resize(256, interpolation=PIL.Image.BILINEAR),
            transforms.CenterCrop(256),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])

        # Initialize vgg
        vgg16 = torchvision.models.vgg16(pretrained=True).double().cuda()
        vgg16.eval()
        print(torch.nn.Sequential(*list(vgg16.features.children())))
        self.vgg_feat_extractor = torch.nn.Sequential(
            *list(vgg16.features.children())[:-2])
        print(self.vgg_feat_extractor)
        self.vgg_mean = torch.from_numpy(
            np.array([0.485, 0.456, 0.406]).reshape(1, 3, 1, 1))
        self.vgg_std = torch.from_numpy(
            np.array([0.229, 0.224, 0.225]).reshape(1, 3, 1, 1))

        self.conf_thresh_detect = 0.7  # for initially detecting a low confident object
        self.conf_thresh_init = 0.8  # for after turning head toward object threshold
        self.conf_thresh_end = 0.9  # if reach this then stop getting obs

        self.BATCH_SIZE = 50  # frames (not episodes) - this is approximate - it could be higher
        # self.percentile = 70
        self.max_iters = 100000
        self.max_frames = 10
        self.val_interval = 10  #10
        self.save_interval = 50

        # self.BATCH_SIZE = 2
        # self.percentile = 70
        # self.max_iters = 100000
        # self.max_frames = 2
        # self.val_interval = 1
        # self.save_interval = 1

        self.small_classes = []
        self.rot_interval = 5.0
        self.radius_max = 3.5  #3 #1.75
        self.radius_min = 1.0  #1.25
        self.num_flat_views = 3
        self.num_any_views = 7
        self.num_views = 25
        self.center_from_mask = False  # get object centroid from maskrcnn (True) or gt (False)

        self.obj_per_scene = 5

        mod = 'test00'

        # self.homepath = f'/home/nel/gsarch/aithor/data/test2'
        self.homepath = '/home/sirdome/katefgroup/gsarch/ithor/data/' + mod
        print(self.homepath)
        if not os.path.exists(self.homepath):
            os.mkdir(self.homepath)
        else:
            val = input("Delete homepath? [y/n]: ")
            if val == 'y':
                import shutil
                shutil.rmtree(self.homepath)
                os.mkdir(self.homepath)
            else:
                print("ENDING")
                assert (False)

        self.log_freq = 1
        self.log_dir = self.homepath + '/..' + '/log_cem/' + mod
        if not os.path.exists(self.log_dir):
            os.mkdir(self.log_dir)
        MAX_QUEUE = 10  # flushes when this amount waiting
        self.writer = SummaryWriter(self.log_dir,
                                    max_queue=MAX_QUEUE,
                                    flush_secs=60)

        self.W = 256
        self.H = 256

        self.fov = 90

        self.utils = Utils(self.fov, self.W, self.H)
        self.K = self.utils.get_habitat_pix_T_camX(self.fov)
        self.camera_matrix = self.utils.get_camera_matrix(
            self.W, self.H, self.fov)

        self.controller = Controller(
            scene='FloorPlan30',  # will change 
            gridSize=0.25,
            width=self.W,
            height=self.H,
            fieldOfView=self.fov,
            renderObjectImage=True,
            renderDepthImage=True,
        )

        self.init_network()

        self.run_episodes()

    def init_network(self):

        input_shape = np.array([3, self.W, self.H])

        self.localpnet = LocalPNET(input_shape=input_shape,
                                   num_actions=self.num_actions).cuda()

        self.loss = nn.CrossEntropyLoss()

        self.optimizer = torch.optim.Adam(params=self.localpnet.parameters(),
                                          lr=0.00001)

    def batch_iteration(self, mapnames, BATCH_SIZE):

        batch = {
            "actions": [],
            "obs_all": [],
            "seg_ims": [],
            "conf_end_change": [],
            "conf_avg_change": [],
            "conf_median_change": []
        }
        iter_idx = 0
        total_loss = torch.tensor(0.0).cuda()
        num_obs = 0
        while True:

            mapname = np.random.choice(mapnames)

            # self.basepath = self.homepath + f"/{mapname}_{episode}"
            # print("BASEPATH: ", self.basepath)

            # # self.basepath = f"/hdd/ayushj/habitat_data/{mapname}_{episode}"
            # if not os.path.exists(self.basepath):
            #     os.mkdir(self.basepath)

            self.controller.reset(scene=mapname)

            total_loss, obs, actions, seg_ims, confs = self.run(
                "train", total_loss)

            if obs is None:
                print("NO EPISODE LOSS.. SKIPPING BATCH INSTANCE")
                continue

            num_obs += len(actions)

            print("Total loss for train batch # ", iter_idx, " :", total_loss)

            confs = np.array(confs)
            conf_end_change = confs[-1] - confs[0]
            conf_avg_change = np.mean(np.diff(confs))
            conf_median_change = np.median(np.diff(confs))

            batch["actions"].append(actions)
            # These are only used for plotting
            batch["obs_all"].append(obs)
            batch["seg_ims"].append(seg_ims)
            batch["conf_end_change"].append(conf_end_change)
            batch["conf_avg_change"].append(conf_avg_change)
            batch["conf_median_change"].append(conf_median_change)

            iter_idx += 1

            # if len(batch["obs_all"]) == BATCH_SIZE:
            if num_obs >= BATCH_SIZE:
                print("NUM OBS IN BATCH=", num_obs)
                # batch["total_loss"] = total_loss
                print("Total loss for iter: ", total_loss)

                return total_loss, batch, num_obs
                # iter_idx = 0
                # total_loss = torch.tensor(0.0).cuda()
                # batch = {"actions": [], "obs_all": [], "seg_ims": [], "conf_end_change": [], "conf_avg_change": []}

    def run_episodes(self):
        self.ep_idx = 0
        # self.objects = []

        for episode in range(len(self.mapnames_train)):
            print("STARTING EPISODE ", episode)

            mapname = self.mapnames_train[episode]
            print("MAPNAME=", mapname)

            self.controller.reset(scene=mapname)

            # self.controller.start()

            self.basepath = self.homepath + f"/{mapname}_{episode}"
            print("BASEPATH: ", self.basepath)

            # self.basepath = f"/hdd/ayushj/habitat_data/{mapname}_{episode}"
            if not os.path.exists(self.basepath):
                os.mkdir(self.basepath)

            self.run(mode="train")

            self.ep_idx += 1

        self.ep_idx = 1
        self.best_inner_prods = []
        self.pred_ids = []
        self.true_ids = []
        self.pred_catnames = []
        self.true_catnames = []
        self.pred_catnames_all = []
        self.true_catnames_all = []
        self.conf_mats = []
        # self.pred_catnames = []
        for episode in range(len(self.mapnames_test)):
            print("STARTING EPISODE ", episode)

            mapname = self.mapnames_test[episode]
            print("MAPNAME=", mapname)

            self.controller.reset(scene=mapname)

            # self.controller.start()

            self.basepath = self.homepath + f"/{mapname}_{episode}"
            print("BASEPATH: ", self.basepath)

            # self.basepath = f"/hdd/ayushj/habitat_data/{mapname}_{episode}"
            if not os.path.exists(self.basepath):
                os.mkdir(self.basepath)

            self.run(mode="test")

            if self.ep_idx % 4 == 0:
                self.best_inner_prods = np.array(self.best_inner_prods)
                self.pred_ids = np.array(self.pred_ids)
                self.true_ids = np.array(self.true_ids)
                # for i in range(len(self.best_inner_prods)):s

                correct_pred = self.best_inner_prods[self.pred_ids ==
                                                     self.true_ids]
                incorrect_pred = self.best_inner_prods[
                    self.pred_ids != self.true_ids]

                bins = 50
                plt.figure(1)
                plt.clf()
                plt.hist([correct_pred, incorrect_pred],
                         alpha=0.5,
                         histtype='stepfilled',
                         label=['correct', 'incorrect'],
                         bins=bins)
                plt.title(f'testhouse{self.ep_idx//4}')
                plt.xlabel('inner product of nearest neighbor')
                plt.ylabel('Counts')
                plt.legend()
                plt_name = self.homepath + f'/correct_incorrect_testhouse{self.ep_idx//4}.png'
                plt.savefig(plt_name)

                conf_mat = confusion_matrix(self.pred_catnames,
                                            self.true_catnames,
                                            labels=self.include_classes)
                self.conf_mats.append(conf_mat)

                plt.figure(1)
                plt.clf()
                df_cm = pd.DataFrame(conf_mat,
                                     index=[i for i in self.include_classes],
                                     columns=[i for i in self.include_classes])
                plt.figure(figsize=(10, 7))
                sn.heatmap(df_cm, annot=True)
                plt_name = self.homepath + f'/confusion_matrix_testhouse{self.ep_idx//4}.png'
                plt.savefig(plt_name)
                # plt.show()

                self.pred_catnames_all.extend(self.pred_catnames)
                self.true_catnames_all.extend(self.true_catnames)
                self.best_inner_prods = []
                self.pred_ids = []
                self.true_ids = []
                self.true_catnames = []
                self.pred_catnames = []
                self.true_catnames = []

                conf_mat = confusion_matrix(self.pred_catnames_all,
                                            self.true_catnames_all,
                                            labels=self.include_classes)
                plt.figure(1)
                plt.clf()
                df_cm = pd.DataFrame(conf_mat,
                                     index=[i for i in self.include_classes],
                                     columns=[i for i in self.include_classes])
                plt.figure(figsize=(10, 7))
                sn.heatmap(df_cm, annot=True)
                plt_name = self.homepath + f'/confusion_matrix_testhouses_all.png'
                plt.savefig(plt_name)

            self.ep_idx += 1

        self.controller.stop()
        time.sleep(1)

    def run2(self):
        event = self.controller.step('GetReachablePositions')
        for obj in event.metadata['objects']:
            if obj['objectType'] not in self.objects:
                self.objects.append(obj['objectType'])

    def get_detectron_conf_center_obj(self, im, obj_mask, frame=None):
        im = Image.fromarray(im, mode="RGB")
        im = cv2.cvtColor(np.asarray(im), cv2.COLOR_RGB2BGR)

        outputs = self.maskrcnn(im)

        pred_masks = outputs['instances'].pred_masks
        pred_scores = outputs['instances'].scores
        pred_classes = outputs['instances'].pred_classes

        len_pad = 5

        W2_low = self.W // 2 - len_pad
        W2_high = self.W // 2 + len_pad
        H2_low = self.H // 2 - len_pad
        H2_high = self.H // 2 + len_pad

        if False:

            v = Visualizer(im[:, :, ::-1],
                           MetadataCatalog.get(self.cfg_det.DATASETS.TRAIN[0]),
                           scale=1.0)
            out = v.draw_instance_predictions(outputs['instances'].to("cpu"))
            seg_im = out.get_image()

            plt.figure(1)
            plt.clf()
            plt.imshow(seg_im)
            plt_name = self.homepath + f'/seg_all{frame}.png'
            plt.savefig(plt_name)

            seg_im[W2_low:W2_high, H2_low:H2_high, :] = 0.0
            plt.figure(1)
            plt.clf()
            plt.imshow(seg_im)
            plt_name = self.homepath + f'/seg_all_mask{frame}.png'
            plt.savefig(plt_name)

        ind_obj = None
        # max_overlap = 0
        sum_obj_mask = np.sum(obj_mask)
        mask_sum_thresh = 7000
        for idx in range(pred_masks.shape[0]):
            pred_mask_cur = pred_masks[idx].detach().cpu().numpy()
            pred_masks_center = pred_mask_cur[W2_low:W2_high, H2_low:H2_high]
            sum_pred_mask_cur = np.sum(pred_mask_cur)
            # print(torch.sum(pred_masks_center))
            if np.sum(pred_masks_center) > 0:
                if np.abs(sum_pred_mask_cur - sum_obj_mask) < mask_sum_thresh:
                    ind_obj = idx
                    mask_sum_thresh = np.abs(sum_pred_mask_cur - sum_obj_mask)
                # max_overlap = torch.sum(pred_masks_center)
        if ind_obj is None:
            print("RETURNING NONE")
            return None, None, None, None

        v = Visualizer(im[:, :, ::-1],
                       MetadataCatalog.get(self.cfg_det.DATASETS.TRAIN[0]),
                       scale=1.0)
        out = v.draw_instance_predictions(
            outputs['instances'][ind_obj].to("cpu"))
        seg_im = out.get_image()

        if False:
            plt.figure(1)
            plt.clf()
            plt.imshow(seg_im)
            plt_name = self.homepath + f'/seg{frame}.png'
            plt.savefig(plt_name)

        # print("OBJ CLASS ID=", int(pred_classes[ind_obj].detach().cpu().numpy()))
        # pred_boxes = outputs['instances'].pred_boxes.tensor
        # pred_classes = outputs['instances'].pred_classes
        # pred_scores = outputs['instances'].scores
        obj_score = float(pred_scores[ind_obj].detach().cpu().numpy())
        obj_pred_classes = int(pred_classes[ind_obj].detach().cpu().numpy())
        obj_pred_mask = pred_masks[ind_obj].detach().cpu().numpy()

        return obj_score, obj_pred_classes, obj_pred_mask, seg_im

    def detect_object_centroid(self, im, event):

        im = Image.fromarray(im, mode="RGB")
        im = cv2.cvtColor(np.asarray(im), cv2.COLOR_RGB2BGR)

        outputs = self.maskrcnn(im)

        v = Visualizer(im[:, :, ::-1],
                       MetadataCatalog.get(self.cfg_det.DATASETS.TRAIN[0]),
                       scale=1.2)
        out = v.draw_instance_predictions(outputs['instances'].to("cpu"))
        seg_im = out.get_image()

        if False:
            plt.figure(1)
            plt.clf()
            plt.imshow(seg_im)
            plt_name = self.homepath + '/seg_init.png'
            plt.savefig(plt_name)

        pred_masks = outputs['instances'].pred_masks
        pred_boxes = outputs['instances'].pred_boxes.tensor
        pred_classes = outputs['instances'].pred_classes
        pred_scores = outputs['instances'].scores

        obj_catids = []
        obj_scores = []
        obj_masks = []
        for segs in range(len(pred_masks)):
            if pred_scores[segs] <= self.conf_thresh_detect:
                obj_catids.append(pred_classes[segs].item())
                obj_scores.append(pred_scores[segs].item())
                obj_masks.append(pred_masks[segs])

        eulers_xyz_rad = np.radians(
            np.array([
                event.metadata['agent']['cameraHorizon'],
                event.metadata['agent']['rotation']['y'], 0.0
            ]))

        rx = eulers_xyz_rad[0]
        ry = eulers_xyz_rad[1]
        rz = eulers_xyz_rad[2]
        rotation_ = self.utils.eul2rotm(-rx, -ry, rz)

        translation_ = np.array(
            list(event.metadata['agent']['position'].values())) + np.array(
                [0.0, 0.675, 0.0])
        # need to invert since z is positive here by convention
        translation_[2] = -translation_[2]

        T_world_cam = np.eye(4)
        T_world_cam[0:3, 0:3] = rotation_
        T_world_cam[0:3, 3] = translation_

        if not obj_masks:
            return None, None
        elif self.center_from_mask:

            # want an object not on the edges of the image
            sum_interior = 0
            while sum_interior == 0:
                if len(obj_masks) == 0:
                    return None, None
                random_int = np.random.randint(low=0, high=len(obj_masks))
                obj_mask_focus = obj_masks.pop(random_int)
                print("OBJECT ID INIT=", obj_catids[random_int])
                sum_interior = torch.sum(obj_mask_focus[50:self.W - 50,
                                                        50:self.H - 50])

            depth = event.depth_frame

            xs, ys = np.meshgrid(np.linspace(-1 * 256 / 2., 1 * 256 / 2., 256),
                                 np.linspace(1 * 256 / 2., -1 * 256 / 2., 256))
            depth = depth.reshape(1, 256, 256)
            xs = xs.reshape(1, 256, 256)
            ys = ys.reshape(1, 256, 256)

            xys = np.vstack(
                (xs * depth, ys * depth, -depth, np.ones(depth.shape)))
            xys = xys.reshape(4, -1)
            xy_c0 = np.matmul(np.linalg.inv(self.K), xys)
            xyz = xy_c0.T[:, :3].reshape(256, 256, 3)
            xyz_obj_masked = xyz[obj_mask_focus]

            xyz_obj_masked = np.matmul(
                rotation_, xyz_obj_masked.T) + translation_.reshape(3, 1)
            xyz_obj_mid = np.mean(xyz_obj_masked, axis=1)

            xyz_obj_mid[2] = -xyz_obj_mid[2]
        else:

            # want an object not on the edges of the image
            sum_interior = 0
            while True:
                if len(obj_masks) == 0:
                    return None, None
                random_int = np.random.randint(low=0, high=len(obj_masks))
                obj_mask_focus = obj_masks.pop(random_int)
                # print("OBJECT ID INIT=", obj_catids[random_int])
                sum_interior = torch.sum(obj_mask_focus[50:self.W - 50,
                                                        50:self.H - 50])
                if sum_interior < 500:
                    continue  # exclude too small objects

                pixel_locs_obj = np.where(obj_mask_focus.cpu().numpy())
                x_mid = np.round(np.median(pixel_locs_obj[1]) / self.W, 4)
                y_mid = np.round(np.median(pixel_locs_obj[0]) / self.H, 4)

                if False:
                    plt.figure(1)
                    plt.clf()
                    plt.imshow(obj_mask_focus)
                    plt.plot(np.median(pixel_locs_obj[1]),
                             np.median(pixel_locs_obj[0]), 'x')
                    plt_name = self.homepath + '/seg_mask.png'
                    plt.savefig(plt_name)

                event = self.controller.step('TouchThenApplyForce',
                                             x=x_mid,
                                             y=y_mid,
                                             handDistance=1000000.0,
                                             direction=dict(x=0.0,
                                                            y=0.0,
                                                            z=0.0),
                                             moveMagnitude=0.0)
                obj_focus_id = event.metadata['actionReturn']['objectId']

                xyz_obj_mid = None
                for o in event.metadata['objects']:
                    if o['objectId'] == obj_focus_id:
                        if o['objectType'] not in self.include_classes_final:
                            continue
                        xyz_obj_mid = np.array(
                            list(o['axisAlignedBoundingBox']
                                 ['center'].values()))

                if xyz_obj_mid is not None:
                    break

        print("MIDPOINT=", xyz_obj_mid)
        return xyz_obj_mid, obj_mask_focus

    def run(self, mode=None, total_loss=None, summ_writer=None):

        event = self.controller.step('GetReachablePositions')
        if not event.metadata['reachablePositions']:
            # Some AI2-THOR versions return this empty until another action is taken
            event = self.controller.step(action='MoveAhead')
        self.nav_pts = event.metadata['reachablePositions']
        self.nav_pts = np.array([list(d.values()) for d in self.nav_pts])
        # objects = np.random.choice(event.metadata['objects'], self.obj_per_scene, replace=False)
        objects = event.metadata['objects']
        objects_inds = np.arange(len(event.metadata['objects']))
        np.random.shuffle(objects_inds)

        # objects = np.random.shuffle(event.metadata['objects'])
        # for obj in event.metadata['objects']: #objects:
        #     print(obj['name'])
        # objects = objects[0]
        successes = 0
        # meta_obj_idx = 0
        num_obs = 0
        # while successes < self.obj_per_scene and meta_obj_idx <= len(event.metadata['objects']) - 1:
        for obj in objects:
            # if meta_obj_idx > len(event.metadata['objects']) - 1:
            #     print("OUT OF OBJECT... RETURNING")
            #     return total_loss, None, None, None, None

            # obj = objects[objects_inds[meta_obj_idx]]
            # meta_obj_idx += 1
            print("Center object is ", obj['objectType'])

            # st()  # debugger breakpoint; keep disabled for unattended runs
            # if obj['name'] in ['Microwave_b200e0bc']:
            #     print(obj['name'])
            # else:
            #     continue
            # print(obj['name'])

            if obj['objectType'] not in self.include_classes:
                print("Continuing... Invalid Object")
                continue

            # Calculate distance to object center
            obj_center = np.array(
                list(obj['axisAlignedBoundingBox']['center'].values()))

            obj_center = np.expand_dims(obj_center, axis=0)
            distances = np.sqrt(np.sum((self.nav_pts - obj_center)**2, axis=1))

            # Get points with r_min < dist < r_max
            valid_pts = self.nav_pts[np.where(
                (distances > self.radius_min) * (distances < self.radius_max))]

            # Bin the valid points by yaw angle around the object (nbins bins over 360 deg)
            valid_pts_shift = valid_pts - obj_center

            dz = valid_pts_shift[:, 2]
            dx = valid_pts_shift[:, 0]
            dy = valid_pts_shift[:, 1]

            # Get yaw for binning
            valid_yaw = np.degrees(np.arctan2(dz, dx))

            if mode == "train":
                nbins = 10  #20
            else:
                nbins = 5
            bins = np.linspace(-180, 180, nbins + 1)
            bin_yaw = np.digitize(valid_yaw, bins)
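            # np.digitize returns 1-based indices: bin_yaw == b + 1 means the
            # yaw fell in [bins[b], bins[b + 1]), which is why the lookup in
            # the loop below compares against (b + 1).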

            num_valid_bins = np.unique(bin_yaw).size

            if False:
                import matplotlib.cm as cm
                colors = iter(cm.rainbow(np.linspace(0, 1, nbins)))
                plt.figure(2)
                plt.clf()
                print(np.unique(bin_yaw))
                for bi in range(nbins):
                    cur_bi = np.where(bin_yaw == (bi + 1))
                    points = valid_pts[cur_bi]
                    x_sample = points[:, 0]
                    z_sample = points[:, 2]
                    plt.plot(z_sample, x_sample, 'o', color=next(colors))
                plt.plot(self.nav_pts[:, 2],
                         self.nav_pts[:, 0],
                         'x',
                         color='red')
                plt.plot(obj_center[:, 2],
                         obj_center[:, 0],
                         'x',
                         color='black')
                plt_name = '/home/nel/gsarch/aithor/data/valid.png'
                plt.savefig(plt_name)

            if num_valid_bins == 0:
                continue

            if mode == "train":
                spawns_per_bin = 3  #20
            else:
                spawns_per_bin = 1  #int(self.num_views / num_valid_bins) + 2
            # print(f'spawns_per_bin: {spawns_per_bin}')

            action = "do_nothing"
            episodes = []
            valid_pts_selected = []
            camXs_T_camX0_4x4 = []
            camX0_T_camXs_4x4 = []
            origin_T_camXs = []
            origin_T_camXs_t = []
            cnt = 0
            for b in range(nbins):

                # get all angle indices in the current bin range
                inds_bin_cur = np.where(
                    bin_yaw == (b + 1))  # bins start 1 so need +1
                inds_bin_cur = list(inds_bin_cur[0])
                if len(inds_bin_cur) == 0:
                    continue

                for s in range(spawns_per_bin):

                    observations = {}

                    if len(inds_bin_cur) == 0:
                        continue

                    rand_ind = np.random.randint(0, len(inds_bin_cur))
                    s_ind = inds_bin_cur.pop(rand_ind)

                    pos_s = valid_pts[s_ind]
                    valid_pts_selected.append(pos_s)

                    # add height from center of agent to camera
                    pos_s[1] = pos_s[1] + 0.675

                    turn_yaw, turn_pitch = self.utils.get_rotation_to_obj(
                        obj_center, pos_s)
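                    # get_rotation_to_obj presumably returns the yaw and camera
                    # pitch that point the agent at obj_center from pos_s; the
                    # exact angle convention lives in Utils.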

                    event = self.controller.step('TeleportFull',
                                                 x=pos_s[0],
                                                 y=pos_s[1],
                                                 z=pos_s[2],
                                                 rotation=dict(x=0.0,
                                                               y=int(turn_yaw),
                                                               z=0.0),
                                                 horizon=int(turn_pitch))

                    rgb = event.frame

                    object_id = obj['objectId']

                    instance_detections2D = event.instance_detections2D

                    if object_id not in instance_detections2D:
                        print("NOT in instance detections 2D.. continuing")
                        continue
                    obj_instance_detection2D = instance_detections2D[
                        object_id]  # [start_x, start_y, end_x, end_y]

                    max_len = np.max(
                        np.array([
                            obj_instance_detection2D[2] -
                            obj_instance_detection2D[0],
                            obj_instance_detection2D[3] -
                            obj_instance_detection2D[1]
                        ]))
                    pad_len = max_len // 8

                    if pad_len == 0:
                        print("pad len 0.. continuing")
                        continue
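                    # The crop below indexes rgb in (row, col) order, so the
                    # bbox y-coordinates (indices 1 and 3) drive the "x_" row
                    # bounds and the x-coordinates (indices 0 and 2) drive the
                    # "y_" column bounds, despite the variable names.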

                    x_center = (obj_instance_detection2D[3] +
                                obj_instance_detection2D[1]) // 2
                    x_low = x_center - max_len - pad_len
                    if x_low < 0:
                        x_low = 0
                    x_high = x_center + max_len + pad_len  #x_low + max_len + 2*pad_len
                    if x_high > self.W:
                        x_high = self.W

                    y_center = (obj_instance_detection2D[2] +
                                obj_instance_detection2D[0]) // 2
                    y_low = y_center - max_len - pad_len  #-pad_len
                    if y_low < 0:
                        y_low = 0
                    y_high = y_center + max_len + pad_len  #y_low + max_len + 2*pad_len
                    if y_high > self.H:
                        y_high = self.H

                    rgb_crop = rgb[x_low:x_high, y_low:y_high, :]

                    rgb_crop = Image.fromarray(rgb_crop)

                    normalize_cropped_rgb = self.normalize(rgb_crop).unsqueeze(
                        0).double().cuda()

                    obj_features = self.vgg_feat_extractor(
                        normalize_cropped_rgb).view((512, -1))

                    obj_features = obj_features.detach().cpu().numpy()

                    # pca = PCA(n_components=10)
                    # obj_features = pca.fit_transform(obj_features.T).flatten()

                    # obj_features = torch.from_numpy(obj_features).view(-1).cuda()
                    obj_features = obj_features.flatten()
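                    # With a 256x256 crop and vgg16.features[:-2] (final
                    # ReLU/MaxPool dropped), the feature map is 512x16x16, so
                    # each object becomes a 512*16*16 = 131072-dim descriptor.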

                    if mode == "train":
                        if self.first_time:
                            self.first_time = False
                            self.data_store_features = obj_features
                            # self.data_store_features = self.data_store_features.cuda()
                            self.feature_obj_ids.append(
                                self.ithor_to_maskrcnn[obj['objectType']])
                        else:
                            # self.data_store_features = torch.vstack((self.data_store_features, obj_features))
                            self.data_store_features = np.vstack(
                                (self.data_store_features, obj_features))
                            self.feature_obj_ids.append(
                                self.ithor_to_maskrcnn[obj['objectType']])

                    elif mode == "test":

                        # obj_features = obj_features.unsqueeze(0)

                        # inner_prod = torch.abs(torch.mm(obj_features, self.data_store_features.T)).squeeze()

                        # inner_prod = inner_prod.detach().cpu().numpy()

                        # dist = np.squeeze(np.abs(np.matmul(obj_features, self.data_store_features.transpose())))

                        dist = np.linalg.norm(self.data_store_features -
                                              obj_features,
                                              axis=1)

                        k = 10

                        ind_knn = list(np.argsort(dist)[:k])

                        dist_knn = np.sort(dist)[:k]
                        dist_knn_norm = list(
                            self.Softmax(torch.from_numpy(-dist_knn)).numpy())

                        match_knn_id = [
                            self.feature_obj_ids[i] for i in ind_knn
                        ]

                        # for i in range(1, len(match_knn_id)):

                        # add softmax values from the same class (probably a really complex way of doing this)
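                        # Equivalent, more compact form of the vote below
                        # (sum the softmax weights per class id, then take the
                        # argmax; ties resolve to the first class seen):
                        #   from collections import defaultdict
                        #   votes = defaultdict(float)
                        #   for cid, w in zip(match_knn_id, dist_knn_norm):
                        #       votes[cid] += w
                        #   match_nn_id = max(votes, key=votes.get)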
                        idx = 0
                        dist_knn_norm_add = []
                        match_knn_id_add = []
                        while True:
                            if not match_knn_id:
                                break
                            match_knn_cur = match_knn_id.pop(0)
                            dist_knn_norm_cur = dist_knn_norm.pop(0)
                            match_knn_id_add.append(match_knn_cur)
                            idxs_ = []
                            for i in range(len(match_knn_id)):
                                if match_knn_id[i] == match_knn_cur:
                                    dist_knn_norm_cur += dist_knn_norm[i]
                                    # match_knn_id_.pop(i)
                                else:
                                    idxs_.append(i)
                            match_knn_id = [match_knn_id[idx] for idx in idxs_]
                            dist_knn_norm = [
                                dist_knn_norm[idx] for idx in idxs_
                            ]
                            dist_knn_norm_add.append(dist_knn_norm_cur)

                        dist_knn_norm_add = np.array(dist_knn_norm_add)

                        dist_knn_argmax = np.argmax(dist_knn_norm_add)

                        match_nn_id = match_knn_id_add[
                            dist_knn_argmax]  #self.feature_obj_ids[ind_nn]

                        match_nn_catname = self.maskrcnn_to_ithor[match_nn_id]

                        self.best_inner_prods.append(
                            dist_knn_norm_add[dist_knn_argmax])
                        self.pred_ids.append(match_nn_id)
                        # self.pred_catnames.append(match_nn_catname)
                        self.true_ids.append(
                            self.ithor_to_maskrcnn[obj['objectType']])
                        self.pred_catnames.append(match_nn_catname)
                        self.true_catnames.append(obj['objectType'])

                        print(match_nn_catname)

                        self.data_store_features = np.vstack(
                            (self.data_store_features, obj_features))
                        self.feature_obj_ids.append(
                            self.ithor_to_maskrcnn[obj['objectType']])

                    if False:
                        normalize_cropped_rgb = np.transpose(
                            normalize_cropped_rgb.squeeze(
                                0).detach().cpu().numpy(), (1, 2, 0))
                        plt.figure(1)
                        plt.clf()
                        plt.imshow(normalize_cropped_rgb)
                        # plt_name = self.homepath + '/seg_init.png'
                        plt.figure(2)
                        plt.clf()
                        plt.imshow(rgb)
                        plt.show()

                        plt.figure(3)
                        plt.clf()
                        plt.imshow(np.array(rgb_crop))
                        plt.show()
Beispiel #52
0
    def __init__(self):
        self.visualize = False
        self.verbose = False
        self.save_imgs = True

        self.plot_loss = True
        # st()

        mapnames = []
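        # FloorPlan1 / 201 / 301 / 401 are the first scene of each AI2-THOR
        # room type (kitchen, living room, bedroom, bathroom): one training
        # scene per room type.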
        for i in [1, 201, 301, 401]:
            mapname = 'FloorPlan' + str(i)
            mapnames.append(mapname)

        # random.shuffle(mapnames)
        self.mapnames_train = mapnames
        self.num_episodes = len(self.mapnames_train)

        # use the remaining scenes of each room type, in order, for testing
        a = np.arange(2, 30)
        b = np.arange(202, 231)
        c = np.arange(302, 331)
        d = np.arange(402, 431)
        abcd = np.hstack((a, b, c, d))
        mapnames = []
        for i in range(a.shape[0]):
            mapname = 'FloorPlan' + str(a[i])
            mapnames.append(mapname)
            mapname = 'FloorPlan' + str(b[i])
            mapnames.append(mapname)
            mapname = 'FloorPlan' + str(c[i])
            mapnames.append(mapname)
            mapname = 'FloorPlan' + str(d[i])
            mapnames.append(mapname)

        self.mapnames_test = mapnames

        self.ignore_classes = []
        # classes to save
        # self.include_classes = [
        #     'ShowerDoor', 'Cabinet', 'CounterTop', 'Sink', 'Towel', 'HandTowel', 'TowelHolder', 'SoapBar',
        #     'ToiletPaper', 'ToiletPaperHanger', 'HandTowelHolder', 'SoapBottle', 'GarbageCan', 'Candle', 'ScrubBrush',
        #     'Plunger', 'SinkBasin', 'Cloth', 'SprayBottle', 'Toilet', 'Faucet', 'ShowerHead', 'Box', 'Bed', 'Book',
        #     'DeskLamp', 'BasketBall', 'Pen', 'Pillow', 'Pencil', 'CellPhone', 'KeyChain', 'Painting', 'CreditCard',
        #     'AlarmClock', 'CD', 'Laptop', 'Drawer', 'SideTable', 'Chair', 'Blinds', 'Desk', 'Curtains', 'Dresser',
        #     'Watch', 'Television', 'WateringCan', 'Newspaper', 'FloorLamp', 'RemoteControl', 'HousePlant', 'Statue',
        #     'Ottoman', 'ArmChair', 'Sofa', 'DogBed', 'BaseballBat', 'TennisRacket', 'VacuumCleaner', 'Mug', 'ShelvingUnit',
        #     'Shelf', 'StoveBurner', 'Apple', 'Lettuce', 'Bottle', 'Egg', 'Microwave', 'CoffeeMachine', 'Fork', 'Fridge',
        #     'WineBottle', 'Spatula', 'Bread', 'Tomato', 'Pan', 'Cup', 'Pot', 'SaltShaker', 'Potato', 'PepperShaker',
        #     'ButterKnife', 'StoveKnob', 'Toaster', 'DishSponge', 'Spoon', 'Plate', 'Knife', 'DiningTable', 'Bowl',
        #     'LaundryHamper', 'Vase', 'Stool', 'CoffeeTable', 'Poster', 'Bathtub', 'TissueBox', 'Footstool', 'BathtubBasin',
        #     'ShowerCurtain', 'TVStand', 'Boots', 'RoomDecor', 'PaperTowelRoll', 'Ladle', 'Kettle', 'Safe', 'GarbageBag', 'TeddyBear',
        #     'TableTopDecor', 'Dumbbell', 'Desktop', 'AluminumFoil', 'Window']

        # These are the classes shared between AI2-THOR and COCO
        self.include_classes = [
            'Sink',
            'Toilet',
            'Bed',
            'Book',
            'CellPhone',
            'AlarmClock',
            'Laptop',
            'Chair',
            'Television',
            'RemoteControl',
            'HousePlant',
            'Ottoman',
            'ArmChair',
            'Sofa',
            'BaseballBat',
            'TennisRacket',
            'Mug',
            'Apple',
            'Bottle',
            'Microwave',
            'Fork',
            'Fridge',
            'WineBottle',
            'Cup',
            'ButterKnife',
            'Toaster',
            'Spoon',
            'Knife',
            'DiningTable',
            'Bowl',
            'Vase',
            'TeddyBear',
        ]
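        # Note: several iTHOR classes share one COCO id below (Chair/Ottoman/
        # ArmChair -> 62, Bottle/WineBottle -> 44, Mug/Cup -> 47,
        # ButterKnife/Knife -> 49), so the duplicate keys in this dict literal
        # collapse to whichever entry is listed last.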

        self.maskrcnn_to_ithor = {
            81: 'Sink',
            70: 'Toilet',
            65: 'Bed',
            84: 'Book',
            77: 'CellPhone',
            85: 'AlarmClock',
            73: 'Laptop',
            62: 'Chair',
            72: 'Television',
            75: 'RemoteControl',
            64: 'HousePlant',
            62: 'Ottoman',
            62: 'ArmChair',
            63: 'Sofa',
            39: 'BaseballBat',
            43: 'TennisRacket',
            47: 'Mug',
            53: 'Apple',
            44: 'Bottle',
            78: 'Microwave',
            48: 'Fork',
            82: 'Fridge',
            44: 'WineBottle',
            47: 'Cup',
            49: 'ButterKnife',
            80: 'Toaster',
            50: 'Spoon',
            49: 'Knife',
            67: 'DiningTable',
            51: 'Bowl',
            86: 'Vase',
            88: 'TeddyBear',
        }

        self.ithor_to_maskrcnn = {
            'Sink': 81,
            'Toilet': 70,
            'Bed': 65,
            'Book': 84,
            'CellPhone': 77,
            'AlarmClock': 85,
            'Laptop': 73,
            'Chair': 62,
            'Television': 72,
            'RemoteControl': 75,
            'HousePlant': 64,
            'Ottoman': 62,
            'ArmChair': 62,
            'Sofa': 63,
            'BaseballBat': 39,
            'TennisRacket': 43,
            'Mug': 47,
            'Apple': 53,
            'Bottle': 44,
            'Microwave': 78,
            'Fork': 48,
            'Fridge': 82,
            'WineBottle': 44,
            'Cup': 47,
            'ButterKnife': 49,
            'Toaster': 80,
            'Spoon': 50,
            'Knife': 49,
            'DiningTable': 67,
            'Bowl': 51,
            'Vase': 86,
            'TeddyBear': 88,
        }

        self.maskrcnn_to_catname = {
            81: 'sink',
            67: 'dining table',
            65: 'bed',
            84: 'book',
            77: 'cell phone',
            70: 'toilet',
            85: 'clock',
            73: 'laptop',
            62: 'chair',
            72: 'tv',
            75: 'remote',
            64: 'potted plant',
            63: 'couch',
            39: 'baseball bat',
            43: 'tennis racket',
            47: 'cup',
            53: 'apple',
            44: 'bottle',
            78: 'microwave',
            48: 'fork',
            82: 'refrigerator',
            46: 'wine glass',
            49: 'knife',
            79: 'oven',
            80: 'toaster',
            50: 'spoon',
            51: 'bowl',
            86: 'vase',
            88: 'teddy bear',
        }

        self.obj_conf_dict = {
            'sink': [],
            'dining table': [],
            'bed': [],
            'book': [],
            'cell phone': [],
            'clock': [],
            'laptop': [],
            'chair': [],
            'tv': [],
            'remote': [],
            'potted plant': [],
            'couch': [],
            'baseball bat': [],
            'tennis racket': [],
            'cup': [],
            'apple': [],
            'bottle': [],
            'microwave': [],
            'fork': [],
            'refrigerator': [],
            'wine glass': [],
            'knife': [],
            'oven': [],
            'toaster': [],
            'spoon': [],
            'bowl': [],
            'vase': [],
            'teddy bear': [],
        }

        self.data_store = {
            'sink': {},
            'dining table': {},
            'bed': {},
            'book': {},
            'cell phone': {},
            'clock': {},
            'laptop': {},
            'chair': {},
            'tv': {},
            'remote': {},
            'potted plant': {},
            'couch': {},
            'baseball bat': {},
            'tennis racket': {},
            'cup': {},
            'apple': {},
            'bottle': {},
            'microwave': {},
            'fork': {},
            'refrigerator': {},
            'wine glass': {},
            'knife': {},
            'oven': {},
            'toaster': {},
            'spoon': {},
            'bowl': {},
            'vase': {},
            'teddy bear': {},
        }

        self.data_store_features = []
        self.feature_obj_ids = []
        self.first_time = True
        self.Softmax = nn.Softmax(dim=0)

        self.action_space = {
            0: "MoveLeft",
            1: "MoveRight",
            2: "MoveAhead",
            3: "MoveBack",
            4: "DoNothing"
        }
        self.num_actions = len(self.action_space)

        cfg_det = get_cfg()
        cfg_det.merge_from_file(
            model_zoo.get_config_file(
                "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
        cfg_det.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.1  # set threshold for this model
        cfg_det.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(
            "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
        cfg_det.MODEL.DEVICE = 'cuda'
        self.cfg_det = cfg_det
        self.maskrcnn = DefaultPredictor(cfg_det)
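        # Off-the-shelf COCO Mask R-CNN (R50-FPN 3x) from the Detectron2 model
        # zoo; the low 0.1 score threshold keeps weak detections around so they
        # can be filtered against conf_thresh_detect later.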

        self.normalize = transforms.Compose([
            transforms.Resize(256, interpolation=PIL.Image.BILINEAR),
            transforms.CenterCrop(256),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])
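        # Standard ImageNet preprocessing for the VGG-16 feature extractor:
        # resize/center-crop the RGB crop to 256x256 and normalize with the
        # ImageNet channel means and stds.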

        # Initialize vgg
        vgg16 = torchvision.models.vgg16(pretrained=True).double().cuda()
        vgg16.eval()
        print(torch.nn.Sequential(*list(vgg16.features.children())))
        self.vgg_feat_extractor = torch.nn.Sequential(
            *list(vgg16.features.children())[:-2])
        print(self.vgg_feat_extractor)
        self.vgg_mean = torch.from_numpy(
            np.array([0.485, 0.456, 0.406]).reshape(1, 3, 1, 1))
        self.vgg_std = torch.from_numpy(
            np.array([0.229, 0.224, 0.225]).reshape(1, 3, 1, 1))

        self.conf_thresh_detect = 0.7  # for initially detecting a low-confidence object
        self.conf_thresh_init = 0.8  # threshold after turning the head toward the object
        self.conf_thresh_end = 0.9  # stop collecting observations once confidence reaches this

        self.BATCH_SIZE = 50  # approximate number of frames (not episodes); the actual count can run higher
        # self.percentile = 70
        self.max_iters = 100000
        self.max_frames = 10
        self.val_interval = 10  #10
        self.save_interval = 50

        # self.BATCH_SIZE = 2
        # self.percentile = 70
        # self.max_iters = 100000
        # self.max_frames = 2
        # self.val_interval = 1
        # self.save_interval = 1

        self.small_classes = []
        self.rot_interval = 5.0
        self.radius_max = 3.5  #3 #1.75
        self.radius_min = 1.0  #1.25
        self.num_flat_views = 3
        self.num_any_views = 7
        self.num_views = 25
        self.center_from_mask = False  # get object centroid from maskrcnn (True) or gt (False)

        self.obj_per_scene = 5

        mod = 'test00'

        # self.homepath = f'/home/nel/gsarch/aithor/data/test2'
        self.homepath = '/home/sirdome/katefgroup/gsarch/ithor/data/' + mod
        print(self.homepath)
        if not os.path.exists(self.homepath):
            os.mkdir(self.homepath)
        else:
            val = input("Delete homepath? [y/n]: ")
            if val == 'y':
                import shutil
                shutil.rmtree(self.homepath)
                os.mkdir(self.homepath)
            else:
                print("ENDING")
                raise SystemExit

        self.log_freq = 1
        self.log_dir = self.homepath + '/..' + '/log_cem/' + mod
        if not os.path.exists(self.log_dir):
            os.mkdir(self.log_dir)
        MAX_QUEUE = 10  # flush once this many summary events are waiting
        self.writer = SummaryWriter(self.log_dir,
                                    max_queue=MAX_QUEUE,
                                    flush_secs=60)

        self.W = 256
        self.H = 256

        self.fov = 90

        self.utils = Utils(self.fov, self.W, self.H)
        self.K = self.utils.get_habitat_pix_T_camX(self.fov)
        self.camera_matrix = self.utils.get_camera_matrix(
            self.W, self.H, self.fov)
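        # For fov = 90 deg and W = H = 256, a standard pinhole model gives a
        # focal length of (W / 2) / tan(fov / 2) = 128 px with the principal
        # point at the image center; the exact convention used here depends on
        # Utils.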

        self.controller = Controller(
            scene='FloorPlan30',  # initial scene only; will be changed later
            gridSize=0.25,
            width=self.W,
            height=self.H,
            fieldOfView=self.fov,
            renderObjectImage=True,
            renderDepthImage=True,
        )

        self.init_network()

        self.run_episodes()