class object_detector:
    """Camera-driven object detector and assembly recorder.

    ``update`` grabs one frame from camera 0, warps it to a 640x480 top-down
    view, segments objects into bounding boxes (``self.boxls``) and then
    advances a state machine whose phases depend on the global ``MODE``:

    * ``'all'``   -- record labelled crops, train ``Net``, record assembly
      milestones, retrain on all crops, then track.
    * ``'train'`` -- like ``'all'`` but starts by training on the existing
      ``read.txt``.
    * ``'test'``  -- loads ``<path>/model`` in ``__init__`` and starts at the
      assembly phase.

    The user drives the phases by double-clicking the coloured buttons drawn
    in ``self.gui_img``: blue = start, green = stop, red = quit.
    """

    def __init__(self, start):
        """Open the camera, build the GUI windows and initialise all state.

        Args:
            start: caller-supplied start time, stored as ``self.start_time``
                (not read elsewhere in this class).
        """
        self.cap = cv2.VideoCapture(0)
        self.start_time = start

        # Phase flags, advanced in order by update().
        self.stored_flag = False
        self.trained_flag = False
        self.milstone_flag = False
        self.incremental_train_flag = False
        self.tracking_flag = False

        self.boxls = None
        self.count = 1
        self.new_count = 1
        self.path = "/home/intuitivecompting/Desktop/color/Smart-Projector/script/datasets/"
        # Label files are only created up front in 'all' mode; otherwise
        # store(is_milestone=True) opens read.txt in append mode when needed.
        self.file = None
        self.milestone_file = None
        if MODE == 'all':
            self.file = open(self.path + "read.txt", "w")
            self.milestone_file = open(self.path + "mileston_read.txt", "w")
        self.user_input = 0
        self.predict = None
        self.memory = cache(10)
        self.memory1 = cache(10)
        self.hand_memory = cache(10)

        self.node_sequence = []
        #-----------------------create GUI-----------------------#
        self.gui_img = np.zeros((130, 640, 3), np.uint8)
        cv2.circle(self.gui_img, (160, 50), 30, (255, 0, 0), -1)
        cv2.putText(self.gui_img, "start", (130, 110), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 0, 0))
        cv2.circle(self.gui_img, (320, 50), 30, (0, 255, 0), -1)
        cv2.putText(self.gui_img, "stop", (290, 110), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 255, 0))
        cv2.circle(self.gui_img, (480, 50), 30, (0, 0, 255), -1)
        cv2.putText(self.gui_img, "quit", (450, 110), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 255))
        cv2.namedWindow('gui_img')
        cv2.namedWindow('gui_img1')
        cv2.setMouseCallback('gui_img', self.gui_callback)
        cv2.setMouseCallback('gui_img1', self.gui_callback)
        #-----------------------Training sign--------------#
        self.training_surface = np.ones((610, 640, 3), np.uint8) * 255
        cv2.putText(self.training_surface, 'Training...', (120, 300), cv2.FONT_HERSHEY_SIMPLEX, 3.0, (255, 192, 203), 5)
        #----------------------new coming item id------------------#
        self.new_come_id = None
        self.old_come_id = None
        self.new_come_side = None
        self.old_come_side = None
        self.new_coming_lock = True
        self.once_lock = True
        #---------------------set some flag-------------------#
        self.storing = None
        self.quit = None
        self.once = True
        #---------------------set gui image----------------------#
        self.temp_surface = None
        #----------------------for easier developing-----------------#
        if MODE == 'test':
            self.net = self._new_net()
            self.net.load_state_dict(torch.load(f=self.path + 'model'))
            self.user_input = 5
            self.stored_flag = True

    def _new_net(self):
        """Construct a fresh ``Net``, moved to the GPU when ``GPU`` is set."""
        return Net().cuda() if GPU else Net()

    def update(self, save=True, train=False):
        """Process one camera frame and advance the mode-dependent pipeline.

        ``save`` and ``train`` are accepted for backward compatibility but
        are not used.
        """
        self.boxls = []
        ok, origin = self.cap.read()
        if not ok:
            return
        rect = self.camrectify(origin)

        #-------------warp the image---------------------#
        warp = self.warp(rect)

        #-------------segment the object----------------#
        # Objects = everything that is neither green background, nor the
        # hand colour range, nor the skin colour range.
        hsv = cv2.cvtColor(warp, cv2.COLOR_BGR2HSV)
        green_mask = cv2.inRange(hsv, Green_low, Green_high)
        dilate_kernel = np.ones((7, 7), np.uint8)
        hand_mask = cv2.dilate(cv2.inRange(hsv, Hand_low, Hand_high), kernel=dilate_kernel)
        skin_mask = cv2.dilate(cv2.inRange(hsv, Skin_low, Skin_high), kernel=dilate_kernel)

        thresh = 255 - green_mask
        thresh = cv2.subtract(thresh, hand_mask)
        thresh = cv2.subtract(thresh, skin_mask)
        # Blank the strip at rows >= 477, columns 50..609 so it never
        # produces object boxes.
        thresh[477:, 50:610] = 0
        cv2.imshow('afg', thresh)
        draw_img1 = warp.copy()
        draw_img2 = warp.copy()
        draw_img3 = warp.copy()
        self.train_img = warp.copy()

        #-------------get the bounding box--------------
        self.get_bound(draw_img1, thresh, hand_mask, only=False, visualization=True)

        if MODE == 'all':
            # Phase 1: record crops for each item until the user clicks quit.
            if not self.stored_flag:
                self.temp_surface = np.vstack((draw_img1, self.gui_img))
                self.stored_flag = self.store()
                cv2.imshow('gui_img', self.temp_surface)
            # Phase 2: train the classifier once.
            if self.stored_flag and not self.trained_flag:
                cv2.destroyWindow('gui_img')
                self.trained_flag = self.train()
            self._assemble_and_track(draw_img2, draw_img3)
        elif MODE == 'test':
            self.test(draw_img2)
            self.temp_surface = np.vstack((draw_img2, self.gui_img))
            cv2.imshow('gui_img', self.temp_surface)
            # Retrain including the saved milestone images.
            if self.milstone_flag and not self.incremental_train_flag:
                cv2.destroyWindow('gui_img')
                self.incremental_train_flag = self.train(is_incremental=True)
            # Final tracking.
            if self.incremental_train_flag and not self.tracking_flag:
                self.test(draw_img3, is_tracking=True)
                cv2.imshow('gui_img1', draw_img3)
        elif MODE == 'train':
            if not self.trained_flag:
                self.trained_flag = self.train()
            self._assemble_and_track(draw_img2, draw_img3)

    def _assemble_and_track(self, draw_img2, draw_img3):
        """Run the assembly, incremental-training and tracking phases ('all' and 'train' modes)."""
        # Phase 3: assembling while saving milestones.
        if self.trained_flag and not self.milstone_flag:
            self.test(draw_img2)
            self.temp_surface = np.vstack((draw_img2, self.gui_img))
            cv2.imshow('gui_img1', self.temp_surface)
        # Phase 4: retrain including the saved milestone images.
        if self.milstone_flag and not self.incremental_train_flag:
            cv2.destroyWindow('gui_img1')
            self.incremental_train_flag = self.train(is_incremental=True)
        # Phase 5: final tracking.
        if self.incremental_train_flag and not self.tracking_flag:
            self.test(draw_img3, is_tracking=True)
            cv2.imshow('tracking', draw_img3)

    @staticmethod
    def _find_contours(mask):
        """Return ``(contours, hierarchy)`` for ``mask`` under any OpenCV version."""
        # OpenCV 3.x returns (image, contours, hierarchy); 2.4.x and 4.x
        # return (contours, hierarchy). The last two items are always these.
        result = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        return result[-2], result[-1]

    @staticmethod
    def _centroid(contour):
        """Return the integer (truncated) centroid of ``contour`` from its image moments."""
        moments = cv2.moments(contour)
        return (int(moments['m10'] / moments['m00']),
                int(moments['m01'] / moments['m00']))

    def get_bound(self, img, object_mask, hand_mask, only=True, visualization=True):
        """Append object bounding boxes found in ``object_mask`` to ``self.boxls``.

        Top-level object contours with 100 < area < 100000 become ``[x, y, w, h]``
        boxes. For every top-level hand contour (600 < area < 100000), two
        objects are merged into one box when the hand centroid is close to
        the segment joining their centroids (``get_k_dis`` < 60, centroids
        < 140 apart, each < 100 from the hand). The boxes are then sorted by
        x. If ``visualization`` is true, the largest box is drawn on ``img``
        and labelled with ``self.user_input``.

        ``only`` is accepted for backward compatibility but unused.
        """
        object_contours, object_hierarchy = self._find_contours(object_mask)
        hand_contours, hand_hierarchy = self._find_contours(hand_mask)

        # Hierarchy entry [.., 3] == -1 means the contour has no parent.
        hand_centroids = [
            self._centroid(contour)
            for i, contour in enumerate(hand_contours)
            if 600 < cv2.contourArea(contour) < 100000 and hand_hierarchy[0, i, 3] == -1
        ]
        object_centroids = []
        for i, contour in enumerate(object_contours):
            area = cv2.contourArea(contour)
            if 100 < area < 100000 and object_hierarchy[0, i, 3] == -1:
                object_centroids.append(self._centroid(contour))
                x, y, w, h = cv2.boundingRect(contour)
                self.boxls.append([x, y, w, h])

        pairs = []
        for hand in hand_centroids:
            for i in range(len(object_centroids)):
                for j in range(i + 1, len(object_centroids)):
                    p1, p2 = object_centroids[i], object_centroids[j]
                    d12 = distant(p1, p2)
                    d13 = distant(p1, hand)
                    d23 = distant(p2, hand)
                    dis = self.get_k_dis(p1, p2, hand)
                    if dis < 60 and d12 < 140 and d13 < 100 and d23 < 100:
                        pairs.append((i, j))

        if pairs and self.boxls:
            for i, j in pairs:
                box_i, box_j = self.boxls[i], self.boxls[j]
                # A box already absorbed into another one is skipped.
                if box_i is None or box_j is None:
                    continue
                x = min(box_i[0], box_j[0])
                y = min(box_i[1], box_j[1])
                x_max = max(box_i[0] + box_i[2], box_j[0] + box_j[2])
                y_max = max(box_i[1] + box_i[3], box_j[1] + box_j[3])
                self.boxls[i] = None
                self.boxls[j] = [x, y, x_max - x, y_max - y]
            # List comprehension (not filter()) so len() works on Python 3.
            self.boxls = [box for box in self.boxls if box is not None]

        #---------------sorting the list according to the x coordinate of each item
        self.boxls.sort(key=lambda box: box[0])

        if visualization and self.boxls:
            x, y, w, h = max(self.boxls, key=lambda box: box[2] * box[3])
            cv2.rectangle(img, (x, y), (x + w, y + h), (0, 0, 255), 2)
            cv2.putText(img, str(self.user_input), (x, y), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 255))

    def gui_callback(self, event, x, y, flags, param):
        """OpenCV mouse callback handling double-clicks on the start/stop/quit buttons.

        The button is identified by the exact BGR colour of the clicked pixel
        of ``self.temp_surface``: blue = start, green = stop, red = quit.
        """
        if event != cv2.EVENT_LBUTTONDBLCLK or self.temp_surface is None:
            return
        height, width = self.temp_surface.shape[:2]
        if not (0 <= x < width and 0 <= y < height):
            return
        pixel = self.temp_surface[y, x]

        if (pixel == np.array([255, 0, 0])).all() and not self.storing:
            self.count = 1
            self.user_input += 1
            self.storing = True
            if self.user_input > 5:
                if self.once:
                    temp_node = node((self.new_come_id, self.old_come_id), (self.new_come_side, self.old_come_side), self.user_input)
                    self.once = False
                else:
                    temp_node = node((self.new_come_id, self.user_input - 1), (self.new_come_side, self.old_come_side), self.user_input)
                self.node_sequence.append(temp_node)
            print("start")
        elif (pixel == np.array([0, 255, 0])).all() and self.storing:
            self.storing = False
            self.new_coming_lock = True
            self.new_come_id = None
            self.old_come_id = None
            self.new_come_side = None
            self.old_come_side = None
            print("stop")
        elif (pixel == np.array([0, 0, 255])).all():
            self.storing = None
            self.quit = True
            print("quit")
            if self.stored_flag and self.boxls:
                bx, by, bw, bh = self.boxls[0]
                sub_img = self.train_img[by:by + bh, bx:bx + bw, :]
                cv2.imwrite('test_imgs/saved' + str(self.user_input) + '.jpg', sub_img)

    def store(self, is_milestone=False):
        """Save a crop of the largest box and append ``"<path> <label>"`` to read.txt.

        The label is ``self.user_input``. Nothing is saved unless
        ``self.storing`` is set and at least one box exists.

        Returns:
            True once the user has clicked quit (the label file is closed),
            otherwise False.
        """
        if is_milestone:
            # Reuse a single append-mode handle instead of reopening read.txt
            # on every frame.
            if self.file is None or self.file.closed:
                self.file = open(self.path + "read.txt", "a")
            img_dir = os.path.join(self.path + "image", "milstone" + str(self.new_count) + ".jpg")
        else:
            img_dir = os.path.join(self.path + "image", str(self.new_count) + ".jpg")
        file = self.file
        self.createFolder(self.path + "image")
        if self.quit:
            file.close()
            print('finish output')
            return True
        if not self.boxls or not self.storing:
            return False

        # NOTE(review): in the milestone phase test() calls store() before
        # update() rebuilds temp_surface, so this text is drawn on the stale
        # surface and never shown.
        cv2.putText(self.temp_surface, "recording", (450, 50), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 255), 2)
        frame = self.train_img
        #-------------capturing img for the largest item--------------#
        x, y, w, h = max(self.boxls, key=lambda box: box[2] * box[3])
        cv2.imwrite(img_dir, frame[y:y + h, x:x + w, :])
        file.write(img_dir + " " + str(self.user_input) + "\n")
        if self.count % 100 == 0:
            print('output imgs ' + str(self.count) + 'img')
        self.count += 1
        self.new_count += 1
        return False

    def train(self, is_incremental=False):
        """Train a fresh ``Net`` on the crops listed in ``read.txt``.

        Plots the per-batch loss curve (non-blocking), resets ``self.quit``
        and saves the state dict to ``<path>model``. When ``is_incremental``
        is true, ``node.pair_list`` is pickled to ``node.p`` first and the
        weights go to ``<path>milestone_model`` instead. The initial training
        also sets ``self.user_input`` to 5.

        Returns:
            True when training has finished.
        """
        if is_incremental:
            with open("node.p", "wb") as node_file:
                pickle.dump(node.pair_list, node_file)
        start_time = time.time()
        # Both the initial and the incremental run train from scratch on the
        # full read.txt (which by then also lists the milestone crops).
        reader_train = self.reader(self.path, "read.txt")
        self.net = self._new_net()
        optimizer = optim.SGD(self.net.parameters(), lr=LR, momentum=MOMENTUM, nesterov=True)
        schedule = optim.lr_scheduler.StepLR(optimizer, step_size=STEP, gamma=GAMMA)
        trainset = CovnetDataset(reader=reader_train,
                                 transforms=transforms.Compose([transforms.Resize((200, 100)),
                                                                transforms.ToTensor()]))
        trainloader = DataLoader(dataset=trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

        loss_ls = []
        count_ls = []
        count = 0
        temp = 0  # running mean over the last 10 batches, shown in the progress bar
        t = tqdm.trange(EPOTH, desc='Training')
        for _ in t:
            # NOTE: stepping the scheduler at epoch start is the pre-1.1
            # PyTorch convention; newer PyTorch expects it after
            # optimizer.step() and warns otherwise.
            schedule.step()
            running_loss = 0.0
            for i, (inputs, labels) in enumerate(trainloader):
                if GPU:
                    inputs, labels = inputs.cuda(), labels.cuda()
                inputs, labels = Variable(inputs), Variable(labels.long())
                optimizer.zero_grad()
                outputs = self.net(inputs)
                # Flatten the labels to 1-D as cross_entropy expects.
                loss = F.cross_entropy(outputs, labels.view(1, -1)[0])
                loss.backward()
                optimizer.step()
                t.set_description('loss=%g' % (temp))

                loss_value = loss.item()
                loss_ls.append(loss_value)
                count += 1
                count_ls.append(count)

                running_loss += loss_value
                if i % 10 == 9:
                    temp = running_loss / 10
                    running_loss = 0.0
        plt.plot(count_ls, loss_ls)
        plt.show(block=False)
        print('Finished Training, using {} second'.format(int(time.time() - start_time)))

        self.quit = None

        if not is_incremental:
            self.user_input = 5
            torch.save(self.net.state_dict(), f=self.path + 'model')
        else:
            torch.save(self.net.state_dict(), f=self.path + 'milestone_model')
        return True

    def test(self, draw_img, is_tracking=False):
        """Classify every box in ``self.boxls`` and annotate ``draw_img``.

        Stores ``((x, y, w, h), label)`` for each box in ``self.predict``.
        Unless ``is_tracking``, also draws the assembly prompts, updates the
        pointing/pairing state via ``store_side``/``get_pair`` and records a
        milestone crop via ``store(is_milestone=True)``.
        """
        self.predict = []
        num_object = len(self.boxls)
        frame = self.train_img
        preprocess = transforms.Compose([transforms.Resize((200, 100)),
                                         transforms.ToTensor()])
        for (x, y, w, h) in self.boxls:
            crop = cv2.cvtColor(frame[y:y + h, x:x + w, :], cv2.COLOR_BGR2RGB)
            img_tensor = preprocess(Image.fromarray(crop)).unsqueeze(0)
            # Inference only: no autograd graph is needed. Only move the
            # input to the GPU when GPU is enabled.
            with torch.no_grad():
                if GPU:
                    scores = self.net(Variable(img_tensor).cuda()).cpu()
                else:
                    scores = self.net(Variable(img_tensor))
            out = np.argmax(scores.data.numpy()[0])
            cv2.rectangle(draw_img, (x, y), (x + w, y + h), (0, 0, 255), 2)
            cv2.putText(draw_img, str(out), (x, y), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 255))
            self.predict.append(((x, y, w, h), out))
        if not is_tracking:
            if self.old_come_side is not None and self.new_come_side is None:
                cv2.putText(draw_img, "Point to next!", (220, 50), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)
            if self.new_come_side is not None and self.old_come_side is not None:
                cv2.putText(draw_img, "Start Assemble! Click Start when finish", (180, 50), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)
            lab, color, ind, coord = self.store_side(frame)
            # A falsy label (None or 0) is skipped; recorded labels start at 1.
            if lab:
                self.get_pair(frame.copy(), num_object, lab, color, ind, coord)
            self.milstone_flag = self.store(is_milestone=True)

    def store_side(self, frame):
        """Find the tape mark the user points at and the box containing it.

        The hand pointing line is given by ``hand_tracking``; the red/blue
        tape marks come from ``side_finder``. The mark with the smallest
        ``get_k_dis`` to that line is selected.

        Returns:
            ``(label, color, box_index, (x, y))`` where ``color`` is 'red',
            'blue' or None, or four Nones when there is no hand, no box, no
            tape mark, or the mark lies in no box.
        """
        img = frame.copy()
        point, center = hand_tracking(img, self.hand_memory).get_result()
        if not point or not self.boxls:
            return None, None, None, None
        red_center = side_finder(img, color='red')
        blue_center = side_finder(img, color='blue')
        tape = red_center + blue_center
        candidates = [
            (self.get_k_dis((point[0], point[1]), (center[0], center[1]), (tx, ty)), (tx, ty))
            for (tx, ty) in tape
        ]
        if not candidates:
            return None, None, None, None
        x, y = min(candidates, key=lambda item: item[0])[1]
        ind = test_insdie((x, y), self.boxls)
        if ind is None:
            return None, None, None, None
        color = None
        if (x, y) in red_center:
            color = 'red'
        elif (x, y) in blue_center:
            color = 'blue'
        return self.predict[ind][1], color, ind, (x, y)

    def get_pair(self, image, num_object, label, color, index, coord):
        """Update which two parts are being joined and at which points.

        Only acts when exactly two boxes are visible. Boxes are sorted by x,
        so index 0 is the left part and index 1 the right one; the user
        points at the left part first, then the right one. Predictions are
        accumulated in ``self.memory``/``self.memory1`` (``cache(10)``
        objects) and the majority label is taken once a cache is full.
        """
        # First pairing: identify both parts and their pointed-at sides.
        if self.once and self.once_lock and num_object == 2:
            if index == 0:
                self.memory.append(self.predict[0][1])
                if self.memory.full and self.new_come_id is None:
                    self.old_come_id = max(set(self.memory.list), key=self.memory.list.count)
                    self.old_come_side = self.draw_point(image, coord, index)

            if self.memory.full and index == 1:
                self.memory1.append(self.predict[-1][1])
                if self.memory1.full:
                    self.new_come_id = max(set(self.memory1.list), key=self.memory1.list.count)
                    self.new_come_side = self.draw_point(image, coord, index)

            if self.memory.full and self.memory1.full:
                self.once_lock = False
                self.memory.clear()
                self.memory1.clear()
                print("new_come_id:{}, old_come_id:{}".format(self.new_come_id, self.old_come_id))
                print("new_come_side:{}, old_come_side:{}".format(self.new_come_side, self.old_come_side))
        print(self.new_come_side, self.old_come_side)

        # Later pairings: the left box is the assembled milestone; only the
        # newly added part (right box) is identified.
        if not self.once and num_object == 2 and self.new_coming_lock:
            if index == 0:
                self.memory.append(0)
                if self.memory.full:
                    self.old_come_side = self.draw_point(image, coord, index, is_milestone=True)
                    self.memory.clear()
            elif index == 1 and self.new_come_id is None:
                self.memory1.append(self.predict[-1][1])
                if self.memory1.full:
                    self.new_come_id = max(set(self.memory1.list), key=self.memory1.list.count)
                    self.new_come_side = self.draw_point(image, coord, index)
                    self.memory1.clear()

            if self.new_come_side and self.old_come_side:
                self.new_coming_lock = False
                print("new_come_id:{}".format(self.new_come_id))
                print("new_come_side:{}, old_come_side:{}".format(self.new_come_side, self.old_come_side))

    def draw_point(self, image, coord, index, is_milestone=False):
        """Save the crop of box ``index`` and return ``coord`` relative to that box.

        The crop is written to ``test_imgs/saved<name>.jpg`` where ``name`` is
        the predicted label of the box, or ``self.user_input`` for milestones.
        """
        x, y, w, h = self.boxls[index]
        sub_img = image[y:y + h, x:x + w, :]
        name = self.user_input if is_milestone else self.predict[index][1]
        cv2.imwrite('test_imgs/saved' + str(name) + '.jpg', sub_img)
        return (coord[0] - x, coord[1] - y)

    def warp(self, img):
        """Perspective-warp ``img`` so the calibrated quad fills a 640x480 image."""
        pts1 = np.float32([[101, 160], [531, 133], [0, 480], [640, 480]])
        pts2 = np.float32([[0, 0], [640, 0], [0, 480], [640, 480]])
        M = cv2.getPerspectiveTransform(pts1, pts2)
        return cv2.warpPerspective(img, M, (640, 480))

    @staticmethod
    def get_k_dis(p1, p2, p):
        """Return area of triangle (p, p1, p2) divided by ``distant(p1, p2)``.

        If ``distant`` is Euclidean (not visible here), this is half the
        perpendicular distance from ``p`` to the line through ``p1`` and
        ``p2``. Returns ``inf`` when ``p1`` and ``p2`` coincide, so the point
        never passes a proximity threshold.

        Takes plain 2-tuples (Python 2 tuple-parameter syntax removed so the
        method also parses on Python 3); positional callers are unaffected.
        """
        x1, y1 = p1
        x2, y2 = p2
        x, y = p
        base = distant((x1, y1), (x2, y2))
        if base == 0:
            return float('inf')
        coord = ((x, y), (x1, y1), (x2, y2))
        return Polygon(coord).area / base
# ---- Example #2 ----
def main():
    """Train and evaluate a Lipschitz-regularised MNIST/Fashion-MNIST model.

    Parses command-line options, builds the data loaders and ``Net``, runs a
    baseline evaluation, then for each epoch calls ``train_robust`` followed
    by ``test_standard_adv`` (clean and empirical robust accuracy) and
    ``test_cert`` (certified accuracy). The state dict with the best clean
    accuracy so far is saved under ``saved_models/``.

    Raises:
        ValueError: if ``--dataset`` is not a supported dataset.
    """
    import os

    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument(
        '--batch-size',
        type=int,
        default=64,
        metavar='N',
        help='input batch size for training (default: %(default)s)')
    parser.add_argument(
        '--test-batch-size',
        type=int,
        default=100,
        metavar='N',
        help='input batch size for testing (default: %(default)s)')
    parser.add_argument(
        '--epochs',
        type=int,
        default=30,
        metavar='N',
        help='number of epochs to train (default: %(default)s)')
    parser.add_argument('--lr',
                        type=float,
                        default=1e-3,
                        metavar='LR',
                        help='learning rate (default: %(default)s)')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: %(default)s)')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=10,
        metavar='N',
        help=
        'how many batches to wait before logging training status (default: %(default)s)'
    )
    parser.add_argument('--dataset',
                        choices=['mnist', 'fashion-mnist'],
                        default='mnist',
                        metavar='D',
                        help='mnist/fashion-mnist (default: %(default)s)')
    parser.add_argument('--nonlin',
                        choices=['softplus', 'sigmoid', 'tanh'],
                        default='softplus',
                        metavar='D',
                        help='softplus/sigmoid/tanh (default: %(default)s)')
    # Parsed as int so the choices and the default have the same type.
    parser.add_argument('--num-layers',
                        type=int,
                        choices=[2, 3, 4],
                        default=2,
                        metavar='N',
                        help='2/3/4 (default: %(default)s)')
    parser.add_argument('--epsilon',
                        type=float,
                        default=1.58,
                        metavar='E',
                        help='ball radius (default: %(default)s)')
    parser.add_argument('--test-epsilon',
                        type=float,
                        default=1.58,
                        metavar='E',
                        help='ball radius (default: %(default)s)')
    parser.add_argument(
        '--step-size',
        type=float,
        default=0.005,
        metavar='L',
        help='step size for finding adversarial example (default: %(default)s)'
    )
    parser.add_argument(
        '--num-steps',
        type=int,
        default=200,
        metavar='L',
        help=
        'number of steps for finding adversarial example (default: %(default)s)'
    )
    parser.add_argument(
        '--beta',
        type=float,
        default=0.005,
        metavar='L',
        help='regularization coefficient for Lipschitz (default: %(default)s)')

    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    if args.dataset == 'mnist':
        dataset = datasets.MNIST
    elif args.dataset == 'fashion-mnist':
        dataset = datasets.FashionMNIST
    else:
        # Format the message; passing the value as a second argument would
        # leave '%s' unformatted.
        raise ValueError('Unknown dataset %s' % args.dataset)

    torch.manual_seed(args.seed)
    device = torch.device("cuda" if use_cuda else "cpu")
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    train_loader = torch.utils.data.DataLoader(dataset(
        './' + args.dataset,
        train=True,
        download=True,
        transform=transforms.Compose([transforms.ToTensor()])),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               **kwargs)
    test_loader = torch.utils.data.DataLoader(dataset(
        './' + args.dataset,
        train=False,
        transform=transforms.Compose([transforms.ToTensor()])),
                                              batch_size=args.test_batch_size,
                                              shuffle=False,
                                              **kwargs)

    model = Net(int(args.num_layers), args.nonlin).to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    model_dir = 'saved_models'
    # torch.save does not create missing directories.
    if not os.path.isdir(model_dir):
        os.makedirs(model_dir)
    model_name = model_dir + '/' + args.dataset + '_' + str(
        args.num_layers) + '_' + args.nonlin + '_L2_' + str(
            args.epsilon) + '_EIGEN_' + str(args.beta)

    print(args)
    print(model_name)

    # Baseline evaluation before any training.
    acc, empirical_acc = test_standard_adv(args, model, device, test_loader)
    certified_acc = test_cert(args, model, device, test_loader)

    best_acc = 0.
    best_empirical_acc = 0.
    best_certified_acc = 0.
    for epoch in range(1, args.epochs + 1):
        train_robust(args, model, device, train_loader, optimizer, epoch)
        acc, empirical_acc = test_standard_adv(args, model, device,
                                               test_loader)
        certified_acc = test_cert(args, model, device, test_loader)

        # Checkpoint selection is by clean accuracy only.
        if acc > best_acc:
            best_acc = acc
            best_empirical_acc = empirical_acc
            best_certified_acc = certified_acc
            torch.save(model.state_dict(), model_name)
        print('Saved model: Accuracy: {:.4f}, Empirical Robust Accuracy: {:.4f}, Certified Robust Accuracy: {:.4f}\n'.\
            format(best_acc, best_empirical_acc, best_certified_acc))
class DQN:
    """DQN agent with 9 discrete actions over a 3x3 board.

    Keeps an online network (``eval_net``) that :meth:`learn` optimizes and a
    target network (``target_net``) used for bootstrap targets.
    :meth:`target_update` copies the online weights into the target network.
    """

    def __init__(self,
                 memory_size=50000,
                 batch_size=128,
                 gamma=0.99,
                 lr=1e-3,
                 n_step=500000):
        """Create the replay memory, the online/target networks and the optimizer.

        Args:
            memory_size: capacity passed to ``ReplayMemory``.
            batch_size: number of transitions sampled per :meth:`learn` call.
            gamma: discount factor applied to the bootstrapped next-state value.
            lr: Adam learning rate for ``eval_net``.
            n_step: accepted for interface compatibility; not used by this class.
        """
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.gamma = gamma

        # memory
        self.memory_size = memory_size
        self.Memory = ReplayMemory(self.memory_size)
        self.batch_size = batch_size

        # network
        self.target_net = Net().to(self.device)
        self.eval_net = Net().to(self.device)
        self.target_update()  # start both networks from identical weights
        self.target_net.eval()

        # optim
        self.optimizer = optim.Adam(self.eval_net.parameters(), lr=lr)

    def select_action(self, state, eps):
        """Epsilon-greedy action selection.

        With probability ``1 - eps`` the action comes from
        ``self.eval_net.act(state)``. Otherwise it is a uniformly random index
        in ``[0, 9)``, wrapped as a ``(1, 1)`` long tensor on ``self.device``.

        Returns:
            ``(action, is_random)``.
        """
        if random.random() > eps:
            return self.eval_net.act(state), False
        random_action = torch.tensor(
            [[random.randrange(0, 9)]],
            device=self.device,
            dtype=torch.long,
        )
        return random_action, True

    def select_dummy_action(self, state):
        """Sample a square index in ``[0, 9)``, weighted by channel 0 of the board.

        ``state`` is reshaped to ``(3, 3, 3)``. Channel 0, flattened, is used as
        unnormalized sampling weights. This assumes channel 0 is 1 for empty
        squares and 0 otherwise; confirm against the environment's encoding.

        Raises:
            ValueError: if channel 0 sums to zero (no open squares).
        """
        board = state.reshape(3, 3, 3)
        weights = board[:, :, 0].reshape(-1)
        total = weights.sum()
        if total == 0:
            # Dividing would produce NaN probabilities; fail with a clear message.
            raise ValueError("select_dummy_action: no open squares on the board")
        return np.random.choice(np.arange(9), p=weights / total)

    def target_update(self):
        """Copy the online network's weights into the target network."""
        self.target_net.load_state_dict(self.eval_net.state_dict())

    def learn(self):
        """Run one DQN optimization step on a sampled mini-batch.

        Does nothing until the replay memory holds ``batch_size`` transitions.
        Transitions whose ``next_state`` is ``None`` are treated as terminal
        and bootstrap from 0.
        """
        if len(self.Memory) < self.batch_size:
            return

        # random batch sampling
        transitions = self.Memory.sampling(self.batch_size)
        batch = Transition(*zip(*transitions))

        non_final_mask = torch.tensor(
            [s is not None for s in batch.next_state],
            device=self.device,
            dtype=torch.bool,
        )
        state_batch = torch.cat(batch.state).to(self.device)
        action_batch = torch.cat(batch.action).to(self.device)
        reward_batch = torch.cat(batch.reward).to(self.device)

        # Q(s, a) from the online network
        Q_s = self.eval_net(state_batch).gather(1, action_batch)

        # max_a' Q_target(s', a'); stays 0 for terminal transitions
        Q_s_ = torch.zeros(self.batch_size, device=self.device)
        non_final_next_states = [s for s in batch.next_state if s is not None]
        # torch.cat([]) raises, so skip the target pass when every sample is terminal.
        if non_final_next_states:
            with torch.no_grad():  # target values are constants for the loss
                Q_s_[non_final_mask] = self.target_net(
                    torch.cat(non_final_next_states).to(self.device)).max(1)[0]

        # Q_target = R + gamma * max Q(s')
        Q_target = reward_batch + (Q_s_ * self.gamma)

        # Huber loss (smooth L1, threshold 1) between Q(s, a) and the target
        loss = F.smooth_l1_loss(Q_s, Q_target.unsqueeze(1))

        # Optimize the model
        self.optimizer.zero_grad()
        loss.backward()
        for param in self.eval_net.parameters():
            # Parameters the loss never reached have no gradient.
            if param.grad is not None:
                param.grad.data.clamp_(-1, 1)
        self.optimizer.step()

    def load_net(self, name):
        """Load a whole pickled network from ``name`` into ``self.action_net`` on CPU.

        NOTE(review): ``torch.load`` unpickles arbitrary objects. Only load
        trusted files.
        """
        self.action_net = torch.load(name, map_location="cpu").cpu()

    def load_weight(self, name):
        """Load a state dict from ``name`` into ``eval_net``, then move it to CPU."""
        # map_location lets GPU-saved checkpoints load on CPU-only machines.
        self.eval_net.load_state_dict(torch.load(name, map_location="cpu"))
        self.eval_net = self.eval_net.cpu()

    def act(self, state):
        """Return the highest-scoring square from ``action_net`` among valid moves.

        Requires :meth:`load_net` to have been called. ``state`` is reshaped
        to ``(3, 3, 3)``. A square is valid when channel 0 holds its largest
        value. Assumes ``action_net`` returns one score per square, shaped
        ``(1, 9)`` or ``(9,)``; confirm against ``Net``.

        Returns:
            Flat index of the chosen square (NumPy integer).
        """
        with torch.no_grad():
            scores = self.action_net(state).cpu().numpy()
            board = state.cpu().numpy().reshape(3, 3, 3)
            valid_moves = board.argmax(axis=2).reshape(-1) == 0
            # Softmax is monotonic, so the argmax of masked raw scores picks the
            # same move. Masking with -inf instead of multiplying by 0 prevents
            # an underflowed valid probability from tying with an occupied square.
            masked = np.where(valid_moves, scores, -np.inf)
            return masked.argmax()