Example #1
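A hedged note on context: this snippet assumes the imports below plus the project-local `Net`, `KFAC`, and `get_datasets` definitions, which are not shown on this page.
import pickle

import torch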
def load_files(tasks_nb, datafile_name):
    with open(datafile_name, 'rb') as f:  # `f` avoids shadowing the builtin input()
        saved_data = pickle.load(f)
    train_datasets, test_datasets = get_datasets(task_number=tasks_nb,
                                                 batch_size_train=128,
                                                 batch_size_test=4096,
                                                 saved_data=saved_data)

    kfacs = []
    all_models = {}

    for i in range(tasks_nb):
        model_name = '{:d}-0'.format(i)
        model = Net().cuda()
        model.load_state_dict(torch.load('models/{:s}.pt'.format(model_name)))
        all_models[model_name] = model

        with open('kfacs/{:d}_weights.pkl'.format(i), 'rb') as f:
            weights = pickle.load(f)
        with open('kfacs/{:d}_maa.pkl'.format(i), 'rb') as f:
            m_aa = pickle.load(f)
        with open('kfacs/{:d}_mgg.pkl'.format(i), 'rb') as f:
            m_gg = pickle.load(f)
        
        kfac = KFAC(model, train_datasets[i], False)
        kfac.weights = weights
        kfac.m_aa = m_aa
        kfac.m_gg = m_gg
        kfacs.append([kfac])

    return train_datasets, test_datasets, kfacs, all_models
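A minimal usage sketch, assuming the pickled dataset file and the `models/`/`kfacs/` directory layout above; the file name and task count are placeholders:
# Hypothetical invocation; 'data.pkl' and the task count are placeholders.
train_ds, test_ds, kfacs, models = load_files(5, 'data.pkl')
print('{} K-FAC approximations loaded'.format(len(kfacs)))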
Example #2
class Process(object):
    def __init__(self, args):
        self.file_path = os.path.dirname(
            os.path.abspath(__file__))  # current file path
        self.use_cuda = args.use_cuda
        self.scale = args.scale
        self.dir = args.dir
        self.grasp_angle = args.grasp_angle
        self.voxel_size = args.voxel_size
        self.color_topic = args.color_topic
        self.depth_topic = args.depth_topic
        self.episode = args.episode
        self.run = args.run
        self.num_objs = args.num_objs
        self.save_root = self.file_path + "/exp_2/{}/ep{}_run{}".format(
            self.dir, self.episode, self.run)
        self._create_directories()
        self.suck_weight = 1.0
        self.grasp_weight = 0.25
        self.count = 0
        self.last_iter_fail = None
        self.last_fail_primitive = None
        self.gripper_angle_list = [0, -45, -90, 45]  # gripper rotation (deg) for angle indices 0-3
        self.bridge = CvBridge()
        self.background_color = self.file_path + "/" + args.color_bg
        self.background_depth = self.file_path + "/" + args.depth_bg
        self.action_wrapper = ActionWrapper()
        # Service
        self.service = rospy.Service(
            "~start", Empty,
            self.callback)  # Start process, until workspace is empty
        self.save_background = rospy.Service(
            "~save_bg", Empty,
            self.save_cb)  # Save bin background color and depth image
        self.reset_service = rospy.Service(
            "~reset", Empty, self.reset_cb
        )  # Reset `self.episode`, `self.run` and create new save root
        # Service client
        self.record_bag_client = rospy.ServiceProxy(
            "/autonomous_recording_node/start_recording", recorder)
        self.stop_record_client = rospy.ServiceProxy(
            "/autonomous_recording_node/stop_recording", Empty)
        try:
            self.camera_info = rospy.wait_for_message(self.color_topic.replace(
                "image_raw", "camera_info"),
                                                      CameraInfo,
                                                      timeout=5.0)
        except rospy.ROSException:
            rospy.logerr(
                "Can't get camera intrinsic after 5 seconds, terminate node..."
            )
            rospy.signal_shutdown("No intrinsic")
            sys.exit(0)
        load_ts = time.time()
        rospy.loginfo("Loading model...")
        self.suck_net = Net(args.n_classes)
        self.grasp_net = Net(args.n_classes)
        self.suck_net.load_state_dict(
            torch.load(self.file_path + "/" + args.suck_model))
        self.grasp_net.load_state_dict(
            torch.load(self.file_path + "/" + args.grasp_model))
        if self.use_cuda:
            self.suck_net = self.suck_net.cuda()
            self.grasp_net = self.grasp_net.cuda()
        rospy.loginfo("Load complete, time elasped: {}".format(time.time() -
                                                               load_ts))
        rospy.loginfo("current episode: \t{}".format(self.episode))
        rospy.loginfo("current run: \t{}".format(self.run))
        rospy.loginfo("current code: \t{}".format(
            self.encode_index(self.episode, self.run)))
        rospy.loginfo("Service ready")

    # encode recording index
    def encode_index(self, episode, run):
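        # e.g. episode=12, run=3 -> 30000 + 12 * 10 + 3 = 30123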
        res = 30000 + episode * 10 + run
        return res

    # Create directories for saving data
    def _create_directories(self):
        self.save_paths = [
            self.save_root + "/color/",  # color image from camera
            self.save_root + "/depth/",  # depth image from camera
            self.save_root + "/color_heightmap/",  # converted color heightmap
            self.save_root + "/depth_heightmap/",  # converted depth heightmap
            self.save_root +
            "/mixed_img/",  # color image and prediction heatmap
            self.save_root + "/pc/",  # pointcloud pcd files
            self.save_root + "/viz/"
        ]  # corresponding action and symbols
        for path in self.save_paths:
            if not os.path.exists(path):
                os.makedirs(path)

    # Save bin background color and depth image
    def save_cb(self, req):
        color_topic = rospy.wait_for_message(self.color_topic, Image)
        depth_topic = rospy.wait_for_message(self.depth_topic, Image)
        color_img = self.bridge.imgmsg_to_cv2(color_topic,
                                              desired_encoding="passthrough")
        color_img = cv2.cvtColor(color_img, cv2.COLOR_RGB2BGR)
        depth_img = self.bridge.imgmsg_to_cv2(depth_topic,
                                              desired_encoding="passthrough")
        cv2.imwrite(self.file_path + "/color_background.jpg", color_img)
        cv2.imwrite(self.file_path + "/depth_background.png", depth_img)
        rospy.loginfo("Background images saved")
        return EmptyResponse()

    # Forward pass and get suck prediction
    def _suck(self, color, depth):
        fg_mask = utils.background_subtraction(color, depth,
                                               self.background_color,
                                               self.background_depth)
        color_tensor, depth_tensor = utils.preprocessing(color, depth)
        if self.use_cuda:
            color_tensor = color_tensor.cuda()
            depth_tensor = depth_tensor.cuda()
        predict = self.suck_net(color_tensor, depth_tensor)  # call the module so hooks run
        suctionable = predict.detach().cpu().numpy()[0, 1]
        suctionable = cv2.resize(suctionable,
                                 dsize=(suctionable.shape[1] * self.scale,
                                        suctionable.shape[0] * self.scale))
        suctionable[fg_mask == 0] = 0.0  # Background
        return suctionable

    # Forward pass and get grasp prediction (with number `self.grasp_angle`)
    def _grasp(self, color, depth):
        color_heightmap, depth_heightmap = utils.generate_heightmap(
            color, depth, self.camera_info, self.background_color,
            self.background_depth, self.voxel_size)
        graspable = np.zeros((self.grasp_angle, depth_heightmap.shape[0],
                              depth_heightmap.shape[1]))  # (angle, rows, cols)
        for i in range(self.grasp_angle):
            angle = -np.degrees(np.pi / self.grasp_angle * i)
            rotated_color_heightmap, rotated_depth_heightmap = utils.rotate_heightmap(
                color_heightmap, depth_heightmap, angle)
            color_tensor, depth_tensor = utils.preprocessing(
                rotated_color_heightmap, rotated_depth_heightmap)
            if self.use_cuda:
                color_tensor = color_tensor.cuda()
                depth_tensor = depth_tensor.cuda()
            predict = self.grasp_net(color_tensor, depth_tensor)
            grasp = predict.detach().cpu().numpy()[0, 1]
            affordance = cv2.resize(grasp,
                                    dsize=(grasp.shape[1] * self.scale,
                                           grasp.shape[0] * self.scale))
            affordance[rotated_depth_heightmap == 0] = 0.0  # Background
            # affordance[depth_heightmap==0] = 0.0 # Background
            graspable[i, :, :] = affordance
        return color_heightmap, depth_heightmap, graspable

    # Start process, until workspace is empty
    def callback(self, req):
        rospy.loginfo("Receive start command")
        self.action_wrapper.reset()
        empty = False
        iter_count = 0
        valid_count = 0
        grasped_count = 0
        primitive = []  # `best_action_idx` `pixel y` `pixel x`
        position = []  # execution position in `base_link` frame
        suck_fail = 0
        grasp_fail = 0
        suck_weight = self.suck_weight
        grasp_weight = self.grasp_weight
        # Start recording
        self.record_bag_client(
            recorderRequest(self.encode_index(self.episode, self.run)))
        while not empty and iter_count < 2 * self.num_objs:
            rospy.loginfo("Baseline method, iter: {}".format(iter_count))
            # Get color and depth images
            color_topic = rospy.wait_for_message(self.color_topic, Image)
            depth_topic = rospy.wait_for_message(self.depth_topic, Image)
            color_img = self.bridge.imgmsg_to_cv2(
                color_topic, desired_encoding="passthrough")
            color_img = cv2.cvtColor(color_img, cv2.COLOR_RGB2BGR)
            depth_img = self.bridge.imgmsg_to_cv2(
                depth_topic, desired_encoding="passthrough")
            suckable = self._suck(color_img, depth_img)
            color_heightmap, depth_heightmap, graspable = self._grasp(
                color_img, depth_img)
            cv2.imwrite(self.save_paths[0] + "{:06}.png".format(iter_count),
                        color_img)  # color
            cv2.imwrite(self.save_paths[1] + "{:06}.png".format(iter_count),
                        depth_img)  # depth
            cv2.imwrite(self.save_paths[2] + "{:06}.png".format(iter_count),
                        color_heightmap)  # color heightmap
            cv2.imwrite(self.save_paths[3] + "{:06}.png".format(iter_count),
                        depth_heightmap)  # depth heightmap
            # Last-fail punishment (From Appendix A. `Avoid repeating unsuccessful attempts`)
            if self.last_iter_fail:
                if self.last_fail_primitive[0] == 0:  # suction
                    y = self.last_fail_primitive[1]
                    x = self.last_fail_primitive[2]
                    z_cam = self.last_fail_primitive[3]
                    punish_mask = utils.create_mask(
                        suckable.shape,
                        y,
                        x,
                        0.02,
                        img_type="image",
                        camera_info=self.camera_info,
                        z=z_cam)
                    suckable = np.multiply(suckable, punish_mask)
                else:  # gripper
                    y = self.last_fail_primitive[1]
                    x = self.last_fail_primitive[2]
                    punish_mask = utils.create_mask(graspable[0].shape,
                                                    y,
                                                    x,
                                                    0.02,
                                                    img_type="heightmap",
                                                    voxel_size=self.voxel_size)
                    graspable[self.last_fail_primitive[0] - 1] = np.multiply(
                        graspable[self.last_fail_primitive[0] - 1],
                        punish_mask)
            # Prevent unsuccessful grasping with same primitive (From Appendix A. `Encouraging exploration upon repeat failure`)
            if suck_fail == 2: suck_weight = 0.5 * self.suck_weight
            elif suck_fail >= 3: suck_weight = 0.25 * self.suck_weight
            if grasp_fail == 2: grasp_weight = 0.5 * self.grasp_weight
            elif grasp_fail >= 3: grasp_weight = 0.25 * self.grasp_weight
            rospy.loginfo("suck weight: {}\tgrasp weight: {}".format(
                suck_weight, grasp_weight))
            # Multiply by primitive weight (From Appendix A. `Suction first, grasp later.`)
            suckable *= suck_weight
            graspable *= grasp_weight
            suck_hm = utils.viz_affordance(suckable)
            suck_combined = cv2.addWeighted(color_img, 0.8, suck_hm, 0.8, 0)
            cv2.imwrite(self.save_paths[4] +
                        "suck_{:06}.png".format(iter_count),
                        suck_combined)  # mixed
            grasps_combined = []
            for i in range(self.grasp_angle):
                angle = -np.degrees(np.pi / self.grasp_angle * i)
                rotated_color_heightmap, _ = utils.rotate_heightmap(
                    color_heightmap, depth_heightmap, angle)
                grasp_hm = utils.viz_affordance(graspable[i])
                grasp_combined = cv2.addWeighted(rotated_color_heightmap, 0.8,
                                                 grasp_hm, 0.8, 0)
                cv2.imwrite(self.save_paths[4] +
                            "/grasp_{:06}_{}.png".format(iter_count, i),
                            grasp_combined)  # mixed
                grasps_combined.append(
                    grasp_combined)  # Stored rotated combined image
                '''angle = np.degrees(np.pi/self.grasp_angle*i)
				rotate_predict = utils.rotate_img(graspable[i], angle)
				grasp_hm = utils.viz_affordance(rotate_predict)
				grasp_combined = cv2.addWeighted(color_heightmap, 0.8, grasp_hm, 0.8, 0)
				cv2.imwrite(self.save_paths[4]+"/grasp_{:06}_{}.png".format(iter_count, i), grasp_combined) # mixed
				grasp_combined = utils.rotate_img(grasp_combined, -angle)
				grasps_combined.append(grasp_combined)'''
            # Select action and get position
            best_action = [
                np.max(suckable),
                np.max(graspable[0]),
                np.max(graspable[1]),
                np.max(graspable[2]),
                np.max(graspable[3])
            ]
            print "Action value: ", best_action
            best_action_idx = np.where(
                best_action == np.max(best_action))[0][0]
            gripper_angle = 0
            targetPt = np.zeros((3, 1))
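            # `static` appears to be a fixed axis permutation between the
            # camera/heightmap frame and the arm base frame (setup-specific).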
            static = np.array([[0.0, -1.0, 0.0], [-1.0, 0.0, 0.0],
                               [0.0, 0.0, -1.0]])
            tool_id = 3
            will_collide = False
            if best_action_idx == 0:  # suck
                rospy.loginfo("Use \033[1;31msuction\033[0m")
                suck_pixel = np.where(suckable == np.max(suckable))
                suck_pixel = [suck_pixel[1][0], suck_pixel[0][0]]  # x, y
                primitive.append(
                    [best_action_idx, suck_pixel[1], suck_pixel[0]])
                cam_z = (depth_img[suck_pixel[1], suck_pixel[0]]).astype(
                    np.float32) / 1000
                cam_x = (suck_pixel[0] -
                         self.camera_info.K[2]) * cam_z / self.camera_info.K[0]
                cam_y = (suck_pixel[1] -
                         self.camera_info.K[5]) * cam_z / self.camera_info.K[4]
                camPt = np.array([[cam_x], [cam_y], [cam_z], [1.0]])
                camera_pose = np.loadtxt(self.file_path + "/camera_pose.txt")
                cam2arm = np.matmul(static.T, camera_pose[:3])
                targetPt = np.matmul(cam2arm, camPt).reshape(3)
                action = utils.draw_action(suck_combined, suck_pixel)
                cv2.imwrite(self.save_paths[6] +
                            "action_{:06}.png".format(iter_count),
                            action)  # viz
            else:  # gripper
                tool_id = 1
                best_angle_idx = best_action_idx - 1
                angle = np.degrees(np.pi / self.grasp_angle * best_angle_idx)
                gripper_angle = self.gripper_angle_list[best_angle_idx]
                rospy.loginfo(
                    "Use \033[1;31mparallel-jaw gripper with angle: {}\033[0m".
                    format(gripper_angle))
                binMiddleBottom = np.loadtxt(self.file_path + "/bin_pose.txt")
                rotate_predict = utils.rotate_img(graspable[best_angle_idx],
                                                  angle)
                grasp_pixel = np.where(
                    rotate_predict == np.max(rotate_predict))
                grasp_pixel = [grasp_pixel[1][0], grasp_pixel[0][0]]  # x, y
                primitive.append(
                    [best_action_idx, grasp_pixel[1], grasp_pixel[0]])
                u = grasp_pixel[0] - graspable[best_angle_idx].shape[0] / 2
                v = grasp_pixel[1] - graspable[best_angle_idx].shape[1] / 2
                tempPt = np.zeros((3, 1))
                # Position in fake link
                tempPt[0] = binMiddleBottom[0] + u * self.voxel_size  # X
                tempPt[1] = binMiddleBottom[1] + v * self.voxel_size  # Y
                tempPt[2] = depth_heightmap[grasp_pixel[1],
                                            grasp_pixel[0]].astype(
                                                np.float32) / 1000  # Z
                targetPt = np.matmul(static, tempPt).reshape(3)
                targetPt[2] = -targetPt[2] - binMiddleBottom[2]  # flip Z and offset by bin bottom (frame convention)
                gripper_center = np.where(graspable[best_angle_idx] == np.max(
                    graspable[best_angle_idx]))
                gripper_center = [gripper_center[1][0], gripper_center[0][0]]
                action = utils.draw_action(grasps_combined[best_angle_idx],
                                           gripper_center, "grasp")
                if angle != 0:  # Then rotate back
                    action_rotate = utils.rotate_img(action, angle)
                else:
                    action_rotate = action
                action_rotate[np.where(
                    color_heightmap == np.array([0, 0, 0]))] = 0
                cv2.imwrite(
                    self.save_paths[6] + "action_{:06}.png".format(iter_count),
                    action_rotate)
                self.action_wrapper.get_pc(
                    self.save_paths[5] + "{:06}_before.png".format(iter_count))
                will_collide = self.action_wrapper.check_if_collide(
                    targetPt, np.radians(gripper_angle))
            print "Target position: [{}, {}, {}]".format(
                targetPt[0], targetPt[1], targetPt[2])
            position.append(targetPt)
            # Check if out of range
            in_range = False
            bin_range = np.loadtxt(
                self.file_path +
                "/bin_range.txt")  # Range in `base_link` frame
            if bin_range[0][0] < targetPt[0] < bin_range[0][1] and \
               bin_range[1][0] < targetPt[1] < bin_range[1][1] and \
               bin_range[2][0] < targetPt[2] < bin_range[2][1]:
                in_range = True
            if will_collide or not in_range:
                self.last_iter_fail = True
                if best_action_idx == 0:
                    self.last_fail_primitive = [
                        best_action_idx, suck_pixel[1], suck_pixel[0], camPt[2]
                    ]
                else:
                    self.last_fail_primitive = [
                        best_action_idx, grasp_pixel[1], grasp_pixel[0]
                    ]
                if will_collide: rospy.logwarn("Will collide, abort action")
                if not in_range: rospy.logwarn("Out of range, abort action")
                iter_count += 1
                if best_action_idx == 0: suck_fail += 1
                if best_action_idx != 0: grasp_fail += 1
                self.action_wrapper.publish_data(iter_count, -1, False)
                continue
            self.action_wrapper.take_action(
                tool_id, targetPt, np.radians(gripper_angle))  # execute action
            valid_count += 1
            action_success = self.action_wrapper.check_if_success(
                tool_id,
                self.save_paths[5] + "{:06}_check.pcd".format(iter_count))
            self.action_wrapper.publish_data(iter_count, best_action_idx,
                                             action_success)
            if not action_success:  # fail
                self.last_iter_fail = True
                if best_action_idx == 0:  # suction fail
                    self.last_fail_primitive = [
                        best_action_idx, suck_pixel[1], suck_pixel[0], camPt[2]
                    ]  # suction, y, x (pixel), z (meter in camera coordinate)
                    suck_fail += 1
                else:  # gripper fail
                    self.last_fail_primitive = [
                        best_action_idx, gripper_center[1], gripper_center[0]
                    ]  # gripper, y, x (pixel)
                    grasp_fail += 1
                rospy.loginfo("Action fail")
                self.action_wrapper.reset()
                rospy.sleep(1.0)
            else:  # success
                grasped_count += 1
                suck_fail = grasp_fail = 0
                suck_weight = self.suck_weight
                grasp_weight = self.grasp_weight
                self.last_iter_fail = False
                self.last_fail_primitive = []
                rospy.loginfo("Action success")
                self.action_wrapper.place()
            self.action_wrapper.get_pc(self.save_paths[5] +
                                       "{:06}_after.pcd".format(iter_count))
            empty = self.action_wrapper.check_if_empty(
                self.save_paths[5] + "{:06}_after.pcd".format(iter_count))
            iter_count += 1
            rospy.sleep(1.0)
        self.stop_record_client()
        self.last_iter_fail = None
        self.last_fail_primitive = []
        rospy.loginfo("Complete")
        print("================================================")
        print("Number of iterations: {}".format(iter_count))
        print("Valid iterations: {}".format(valid_count))
        print("Grasped objects: {}".format(grasped_count))
        print("Pass test: {}".format(empty))
        print("================================================")
        with open(
                self.save_root + "/{}.txt".format(
                    self.encode_index(self.episode, self.run)), 'w') as f:
            f.write("Number of iterations: {}\n".format(iter_count))
            f.write("Valid iterations: {}\n".format(valid_count))
            f.write("Grasped objects: {}\n".format(grasped_count))
            f.write("Pass test: {}\n".format(empty))
        np.savetxt(self.save_root + "/position.csv", position, delimiter=",")
        np.savetxt(self.save_root + "/action_target.csv",
                   primitive,
                   delimiter=",")

        return EmptyResponse()

    # Reset `self.episode`, `self.run` and create new save root
    def reset_cb(self, req):
        rospy.loginfo("current episode: \t{}".format(self.episode))
        rospy.loginfo("current run: \t{}".format(self.run))
        new_episode = int(raw_input("Input new episode: "))
        new_run = int(raw_input("Input new run: "))
        rospy.loginfo("episode set from {} to {}".format(
            self.episode, new_episode))
        rospy.loginfo("run set from {} to {}".format(self.run, new_run))
        self.episode = new_episode
        self.run = new_run
        rospy.loginfo("current code: \t{}".format(
            self.encode_index(self.episode, self.run)))
        self.save_root = self.file_path + "/exp_2/{}/ep{}_run{}".format(
            self.dir, self.episode, self.run)
        self._create_directories()
        self.last_iter_fail = None
        self.last_fail_primitive = []

        return EmptyResponse()
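A minimal entry point for this node, sketched under the assumption that an `argparse` namespace supplies the fields read in `__init__`; the original launch code is not shown and `build_args()` is hypothetical:
if __name__ == "__main__":
    rospy.init_node("baseline_process")
    args = build_args()  # hypothetical: fills color_topic, suck_model, etc.
    node = Process(args)
    rospy.spin()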
Example #3
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument(
        '--test-batch-size',
        type=int,
        default=100,
        metavar='N',
        help='input batch size for testing (default: %(default)s)')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: %(default)s)')
    parser.add_argument('--dataset',
                        choices=['mnist', 'fashion-mnist'],
                        default='mnist',
                        metavar='D',
                        help='mnist/fashion-mnist (default: %(default)s)')
    parser.add_argument('--nonlin',
                        choices=['softplus', 'sigmoid', 'tanh'],
                        default='softplus',
                        metavar='D',
                        help='softplus/sigmoid/tanh (default: %(default)s)')
    parser.add_argument('--num-layers',
                        type=int,
                        choices=[2, 3, 4],
                        default=2,
                        metavar='N',
                        help='2/3/4 (default: %(default)s)')
    parser.add_argument('--epsilon',
                        type=float,
                        default=1.58,
                        metavar='E',
                        help='ball radius (default: %(default)s)')
    parser.add_argument('--test-epsilon',
                        type=float,
                        default=1.58,
                        metavar='E',
                        help='ball radius at test time (default: %(default)s)')
    parser.add_argument(
        '--step-size',
        type=float,
        default=0.005,
        metavar='L',
        help='step size for finding adversarial example (default: %(default)s)'
    )
    parser.add_argument(
        '--num-steps',
        type=int,
        default=200,
        metavar='L',
        help=
        'number of steps for finding adversarial example (default: %(default)s)'
    )
    parser.add_argument(
        '--beta',
        type=float,
        default=0.005,
        metavar='L',
        help='regularization coefficient for Lipschitz (default: %(default)s)')

    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    if args.dataset == 'mnist':
        dataset = datasets.MNIST
    elif args.dataset == 'fashion-mnist':
        dataset = datasets.FashionMNIST
    else:
        raise ValueError('Unknown dataset %s' % args.dataset)

    torch.manual_seed(args.seed)
    device = torch.device("cuda" if use_cuda else "cpu")
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    test_loader = torch.utils.data.DataLoader(dataset(
        './' + args.dataset,
        train=False,
        transform=transforms.Compose([transforms.ToTensor()])),
                                              batch_size=args.test_batch_size,
                                              shuffle=False,
                                              **kwargs)

    model = Net(args.num_layers, args.nonlin).to(device)
    model_name = 'saved_models/' + args.dataset + '_' + str(
        args.num_layers) + '_' + args.nonlin + '_L2_' + str(
            args.epsilon) + '_EIGEN_' + str(args.beta)
    model.load_state_dict(torch.load(model_name))

    print(args)
    print(model_name)

    acc, empirical_acc = test_standard_adv(args, model, device, test_loader)
    certified_acc = test_cert(args, model, device, test_loader)

    print('Accuracy: {:.4f}, Empirical Robust Accuracy: {:.4f}, '
          'Certified Robust Accuracy: {:.4f}\n'.format(acc, empirical_acc,
                                                       certified_acc))
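The snippet omits the usual entry-point guard; a standard one would be:
if __name__ == '__main__':
    main()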
Example #4
class temp_tracking():
    global gesture_id
    def __init__(self):
        self.cap = cv2.VideoCapture(0)
        fourcc = cv2.VideoWriter_fourcc(*'XVID')
        self.out = cv2.VideoWriter('G3_01123.avi',fourcc, 20.0, (640,480))
        self.hand_mask = []
        self.trigger = False
        self.after_trigger = False
        if torch.cuda.is_available():
            self.net = Net().cuda()
        else:
            self.net = Net()
        self.net.load_state_dict(torch.load(f='/home/intuitivecompting/catkin_ws/src/ur5/ur5_with_gripper/icl_phri_robotiq_control/src/model'))
        self.last_select = None
        self.tip_deque = deque(maxlen=20)
        self.tip_deque1 = deque(maxlen=20)
        self.tip_deque2 = deque(maxlen=20)
        self.mode = None
        self.center = None
        self.onehand_center = None
        self.two_hand_mode = None
        self.pick_center = None
        self.gesture_mode = None
        self.pick_tip = None

    def test(self, box, draw_img):
        global gesture_id
        net = self.net
        frame = self.image.copy()
        preprocess = transforms.Compose([transforms.Resize((50, 50)),
                                         transforms.ToTensor()])
        #preprocess = transforms.Compose([transforms.Pad(30),
         #                                             transforms.ToTensor()])
        x,y,w,h = box
        temp = frame[y:y+h, x:x+w, :]
        #temp = cv2.cvtColor(temp,cv2.COLOR_BGR2RGB)
        temp = cv2.blur(temp,(5,5))
        hsv = cv2.cvtColor(temp,cv2.COLOR_BGR2HSV)
        temp = cv2.inRange(hsv, Hand_low, Hand_high)
        image = Image.fromarray(temp)
        img_tensor = preprocess(image)
        img_tensor.unsqueeze_(0)
        if torch.cuda.is_available():
            img_variable = Variable(img_tensor).cuda()
            out = np.argmax(net(img_variable).cpu().data.numpy()[0])
            #print(np.max(net(img_variable).cpu().data.numpy()[0]))
        else:
            img_variable = Variable(img_tensor)
            out = np.argmax(net(img_variable).data.numpy()[0])
            # if np.max(net(img_variable).cpu().data.numpy()[0]) > 0.3:
            #     out = np.argmax(net(img_variable).cpu().data.numpy()[0])
            # else:
            #     out = -1
        #cv2.rectangle(draw_img,(x,y),(x+w,y+h),(255, 0, 0),2)
        #cv2.putText(draw_img,str(out + 1),(x,y),cv2.FONT_HERSHEY_SIMPLEX, 1.0,(0,0,255))
        gesture_id = int(out) + 1
        return int(out) + 1

    def get_current_frame(self):
        self.cap.release()
        self.cap = cv2.VideoCapture(0)
        OK, origin = self.cap.read()
        if OK:
            rect = camrectify(origin)
            warp = warp_img(rect)
            return warp.copy()
    

    def update(self):
        '''
        Gesture flags distinguish the different interaction scenarios below.
        '''
        global color_flag
        OK, origin = self.cap.read()

        x = None
        if OK:
            #print(self.mode)
            rect = camrectify(origin)
            # self.out.write(rect)
            # rect = cv2.flip(rect,0)
            # rect = cv2.flip(rect,1)
            warp = warp_img(rect)
            thresh = get_objectmask(warp)
            cv2.imshow('thresh', thresh)
            self.image = warp.copy()
            draw_img1 = warp.copy()
            self.get_bound(draw_img1, thresh, visualization=True)
            cx, cy = None, None
            lx, rx = None, None

            # self.handls = []
            # hsv = cv2.cvtColor(warp.copy(),cv2.COLOR_BGR2HSV)
            # hand_mask = cv2.inRange(hsv, Hand_low, Hand_high)
            # hand_mask = cv2.dilate(hand_mask, kernel = np.ones((7,7),np.uint8))
            # (_,hand_contours, hand_hierarchy)=cv2.findContours(hand_mask,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
            # for i , contour in enumerate(hand_contours):
            #     area = cv2.contourArea(contour)
            #     if area>600 and area < 100000 and hand_hierarchy[0, i, 3] == -1:					
            #         x,y,w,h = cv2.boundingRect(contour)
            #         self.handls.append([x, y, w, h])
            
            result = hand_tracking(warp_img(rect), cache(10), cache(10)).get_result()
            num_hand_view = len(result)
            # if num_hand_view == 0:
            #     self.tip_deque.clear()
            #     self.tip_deque1.clear()
            #     self.tip_deque2.clear()
            if num_hand_view == 0:
                if len(self.hand_mask) > 0 and self.after_trigger:
                    if color_flag is not None:
                        object_mask = get_objectmask(deepcopy(self.image))
                        if color_flag == "yellow":
                            color_mask = get_yellow_objectmask(deepcopy(self.image))
                        elif color_flag == "blue":
                            color_mask = get_blue_objectmask(deepcopy(self.image))
                        # elif color_flag == "green":
                        #     color_mask = get_green_objectmask(deepcopy(self.image))
                        mask = self.hand_mask[0]
                        for i in range(1, len(self.hand_mask), 1):
                            mask = cv2.bitwise_or(self.hand_mask[i],mask)                     
                        mask = cv2.bitwise_and(mask,color_mask)
                        temp_result = []
                        for cx, cy in self.surfacels:
                            if mask[cy, cx] == 255:
                                temp_result.append((cx, cy))
                    else:
                        object_mask = get_objectmask(deepcopy(self.image))
                        mask = self.hand_mask[0]
                        for i in range(1, len(self.hand_mask), 1):
                            mask = cv2.bitwise_or(self.hand_mask[i],mask)                     
                        mask = cv2.bitwise_and(mask,object_mask)
                        temp_result = []
                        for cx, cy in self.surfacels:
                            if mask[cy, cx] == 255:
                                temp_result.append((cx, cy))
                    '''
                    multihand
                    '''
                    self.draw = draw_img1
                    print("getting bitwise and when there is one finger after palm")
                    #print([temp_result, tips[0], center,3])
                    #self.hand_mask = []
                    #self.after_trigger = False
                    #netsend(temp_result, flag=1, need_unpack=True)
                    self.last_select = temp_result
                    self.mode = 3
                   # self.center = center
                    #return [temp_result, tips[0], center,3]
                else:
                    netsend([777,888], need_unpack=False, flag=-19)
            '''
            one hand in the view
            '''
            if num_hand_view == 1:
                center = result[0][0]
                tips = result[0][1]
                radius = result[0][2]
                box = result[0][3]
                fake_tip, fake_center = result[0][4]
                app = result[0][5]
                cv2.drawContours(draw_img1, [app],-1,(255, 0, 0),1)
                for k in range(len(tips)):
                    cv2.circle(draw_img1,tips[k],10,(255, 0, 0),2)
                    cv2.line(draw_img1,tips[k],center,(255, 0, 0),2)
                num_tips = len(tips)
                label = self.test(box, draw_img1)
                self.onehand_center = center
                #print(box)
                #label = -1
                #label = classifier(draw_img1,self.image, box)
                #self.tip_deque.appendleft(tips)
            # '''
            # one hand and one finger, flag == 1
            # '''
                
                    #rospy.loginfo("mask, trigger:{},{}".format(len(self.hand_mask), self.after_trigger))
                #if num_tips == 1 and len(self.boxls) > 0 and label == 1:
                if len(self.hand_mask) > 0 and self.after_trigger:
                    if color_flag is not None:
                        object_mask = get_objectmask(deepcopy(self.image))
                        if color_flag == "yellow":
                            color_mask = get_yellow_objectmask(deepcopy(self.image))
                            netsend([777,888], need_unpack=False, flag=-200)
                        elif color_flag == "blue":
                            color_mask = get_blue_objectmask(deepcopy(self.image))
                            netsend([777,888], need_unpack=False, flag=-100)
                        # elif color_flag == "green":
                        #     color_mask = get_green_objectmask(deepcopy(self.image))
                        mask = self.hand_mask[0]
                        for i in range(1, len(self.hand_mask), 1):
                            mask = cv2.bitwise_or(self.hand_mask[i],mask)                     
                        mask = cv2.bitwise_and(mask,color_mask)
                        temp_result = []
                        for cx, cy in self.surfacels:
                            if mask[cy, cx] == 255:
                                temp_result.append((cx, cy))
                    else:
                        object_mask = get_objectmask(deepcopy(self.image))
                        mask = self.hand_mask[0]
                        for i in range(1, len(self.hand_mask), 1):
                            mask = cv2.bitwise_or(self.hand_mask[i],mask)   
                        #print(mask.dtype, object_mask.dtype)                  
                        mask = cv2.bitwise_and(mask,object_mask)
                        temp_result = []
                        for cx, cy in self.surfacels:
                            if mask[cy, cx] == 255:
                                temp_result.append((cx, cy))
                    '''
                    multihand
                    '''
                    self.draw = draw_img1
                    print("getting bitwise and when there is one finger after palm")
                    if len(tips) == 0:
                        rospy.logwarn("no finger tips")
                    else:
                        #print([temp_result, tips[0], center,3])
                        #self.hand_mask = []
                        #self.after_trigger = False
                        self.last_select = temp_result
                        self.mode = 3
                        #self.center = center
                        return [temp_result, tips[0], center,3]

                if len(self.boxls) > 0 and num_tips == 1 and label != 4:        
                    if len(self.hand_mask) == 0 or not self.after_trigger:
                        #rospy.loginfo("single pointing")
                        #point = max(tips, key=lambda x: np.sqrt((x[0]- center[0])**2 + (x[1] - center[1])**2))
                        point = tips[0]
                        self.tip_deque.appendleft(point)
                        #
                        length_ls = []
                        for x, y, w, h in self.boxls:
                            length_ls.append((get_k_dis((point[0], point[1]), (center[0], center[1]), (x+w/2, y+h/2)), (x+w/2, y+h/2)))
                        length_ls = [x for x in length_ls
                                     if (point[1] - x[1][1]) * (point[1] - center[1]) <= 0]
                        length_ls = [x for x in length_ls if x[1][1] - point[1] < 0]
                        length_ls = [x for x in length_ls if x[0] < 15]
                        if len(length_ls) > 0:
                            x,y = min(length_ls, key=lambda x: distant((x[1][0], x[1][1]), (point[0], point[1])))[1]
                            ind = test_insdie((x, y), self.boxls)
                            x, y, w, h = self.boxls[ind]
                            cx, cy = self.surfacels[ind]
                            cv2.rectangle(draw_img1,(x,y),(x+w,y+h),(0,0,255),2)
                            #cv2.circle(draw_img1, (cx, cy), 5, (0, 0, 255), -1)
                            #cv2.putText(draw_img1,"pointed",(x,y),cv2.FONT_HERSHEY_SIMPLEX, 1.0,(0,0,255))
                            
                            '''
                            flag is 1
                            '''
                            if self.trigger:
                                self.pick_tip = tuple([point[0],point[1]])
                            self.draw = draw_img1
                            self.last_select = [(cx, cy)]
                            netsend([cx, cy], need_unpack=False)
                            self.mode = 1
                            self.pick_center = center
                            return [[point[0],point[1]],(cx, cy), center,1]
                        else:
                            self.draw = draw_img1
                            self.mode = 1
                            self.pick_center = center
                            return [[point[0],point[1]], center,1]
            #  '''
            # one hand and two finger, flag == 2
            # '''
                elif num_tips == 2 and len(self.boxls) > 0 and label != 4:
                    boxls = deepcopy(self.boxls)
                    length_lsr = []
                    length_lsl = []
                    rpoint, lpoint = tips
                    for x, y, w, h in self.boxls:
                        length_lsr.append((get_k_dis((rpoint[0], rpoint[1]), (center[0], center[1]), (x+w/2, y+h/2)), (x+w/2, y+h/2)))
                    length_lsr = [x for x in length_lsr
                                  if (rpoint[1] - x[1][1]) * (rpoint[1] - center[1]) <= 0]
                    length_lsr = [x for x in length_lsr if x[0] < 20]
                    if len(length_lsr) > 0:
                        rx,ry = min(length_lsr, key=lambda x: distant((x[1][0], x[1][1]), (rpoint[0], rpoint[1])))[1]
                        rind = test_insdie((rx, ry), self.boxls)
                        rx, ry = self.surfacels[rind]
                        x, y, w, h = self.boxls[rind]
                        #rx, ry = int(x+w/2), int(y+h/2)
                        del boxls[rind]
                        cv2.rectangle(draw_img1,(x,y),(x+w,y+h),(0,0,255),2)
                        #cv2.putText(draw_img1,"pointed_right",(x,y),cv2.FONT_HERSHEY_SIMPLEX, 1.0,(0,0,255))
                        if len(boxls) > 0:
                            for x, y, w, h in boxls:
                                length_lsl.append((get_k_dis((lpoint[0], lpoint[1]), (center[0], center[1]), (x+w/2, y+h/2)), (x+w/2, y+h/2)))
                            length_lsl = [x for x in length_lsl
                                          if (lpoint[1] - x[1][1]) * (lpoint[1] - center[1]) <= 0]
                            length_lsl = [x for x in length_lsl if x[0] < 20]
                            if len(length_lsl) > 0:
                                lx,ly = min(length_lsl, key=lambda x: distant((x[1][0], x[1][1]), (lpoint[0], lpoint[1])))[1]
                                lind = test_insdie((lx, ly), boxls)
                                lx, ly = self.surfacels[lind]
                                x, y, w, h = boxls[lind]
                                #lx, ly = int(x+w/2), int(y+h/2)
                                cv2.rectangle(draw_img1,(x,y),(x+w,y+h),(0,0,255),2)
                                #cv2.putText(draw_img1,"pointed_left",(x,y),cv2.FONT_HERSHEY_SIMPLEX, 1.0,(0,0,255))
                                '''
                                flag is 2
                                '''
                                self.draw = draw_img1
                                self.last_select = [[rx, ry], [lx, ly]]
                                netsend([[rx, ry], [lx, ly]])
                                self.mode = 2
                                #self.center = center
                                self.pick_center = center
                                return [[tips[0][0], tips[0][1]], [tips[1][0], tips[1][1]], [rx, ry], [lx, ly], center,2]

                # '''
                # one hand and multi finger, flag == 3
                # '''
                elif num_tips > 0 and label == 3:
                    temp_center = (center[0], center[1] - 30)
                    if not self.trigger:
                        netsend(list(temp_center), need_unpack=False, flag=-18)
                    elif self.trigger:
                        # surface = np.ones(self.image.shape)
                        # cv2.circle(surface, center, 120, (255, 255, 255), -1)
                        # grayscaled = cv2.cvtColor(surface,cv2.COLOR_BGR2GRAY)
                        # retval, threshold = cv2.threshold(grayscaled, 10, 255, cv2.THRESH_BINARY)
                        # self.hand_mask.append(threshold)
                        self.hand_mask = []
                        self.hand_mask.append(get_handmask(deepcopy(self.image), center))
                        rospy.loginfo("get brushed")
                        self.draw = draw_img1
                        self.trigger = False
                        self.mode = 3
                        rospy.loginfo("send center information :{}".format(list(temp_center)))
                        netsend(list(temp_center), need_unpack=False, flag=-8)
                        self.pick_center = center
                        #self.center = center
                        return [temp_center,3]

                elif label == 4 and len(self.boxls) > 0 and len(tips) > 0 and len(tips) < 4:
                    #point = max(tips, key=lambda x: np.sqrt((x[0]- center[0])**2 + (x[1] - center[1])**2))
                    point = fake_tip
                    center = fake_center
                    length_ls = []
                    for x, y, w, h in self.boxls:
                        length_ls.append((get_k_dis((point[0], point[1]), (center[0], center[1]), (x+w/2, y+h/2)), (x+w/2, y+h/2)))
                    #length_ls = filter(lambda x: (point[1] - x[1][1]) * (point[1] - center[1]) <= 0, length_ls)
                    #length_ls = filter(lambda x: (point[0] - x[1][0]) * (center[0] - x[1][0]) > 0, length_ls)
                    length_ls = [x for x in length_ls if x[1][1] - point[1] < 0]
                    #print("haha", len(length_ls))
                    length_ls = [x for x in length_ls if x[0] < 50]
                    #print("ddd", len(length_ls))
                    sub_result = []
                    if color_flag is not None:
                        object_mask = get_objectmask(deepcopy(self.image))
                        if color_flag == "yellow":
                            color_mask = get_yellow_objectmask(deepcopy(self.image))
                        elif color_flag == "blue":
                            color_mask = get_blue_objectmask(deepcopy(self.image))
                        if len(length_ls) > 0:
                            for i in range(len(length_ls)):
                                # x,y = min(length_ls, key=lambda x: distant((x[1][0], x[1][1]), (point[0], point[1])))[1]
                                # ind = test_insdie((x, y), self.boxls)
                                x,y = length_ls[i][1]
                                ind = test_insdie((x, y), self.boxls)
                                x, y, w, h = self.boxls[ind]
                                cx, cy = self.surfacels[ind]
                                if color_mask[cy, cx] == 255:
                                    sub_result.append((cx, cy))
                                cv2.rectangle(draw_img1,(x,y),(x+w,y+h),(0,0,255),2)
                                #cv2.circle(draw_img1, (cx, cy), 5, (0, 0, 255), -1)
                                #cv2.putText(draw_img1,"general",(x,y),cv2.FONT_HERSHEY_SIMPLEX, 1.0,(0,0,255))
                            
                            '''
                            flag is 1
                            '''
                            self.draw = draw_img1
                            self.last_select = sub_result
                            self.mode = 6
                            self.pick_center = center
                            #self.center = center
                            return [sub_result, center ,6]
                        else:
                            self.draw = draw_img1
                            return None
                    
                    else:
                        if len(length_ls) > 0:
                            for i in range(len(length_ls)):
                            # x,y = min(length_ls, key=lambda x: distant((x[1][0], x[1][1]), (point[0], point[1])))[1]
                            # ind = test_insdie((x, y), self.boxls)
                                x,y = length_ls[i][1]
                                ind = test_insdie((x, y), self.boxls)
                                x, y, w, h = self.boxls[ind]
                                cx, cy = self.surfacels[ind]
                                sub_result.append((cx, cy))
                                cv2.rectangle(draw_img1,(x,y),(x+w,y+h),(0,0,255),2)
                                #cv2.circle(draw_img1, (cx, cy), 5, (0, 0, 255), -1)
                                #cv2.putText(draw_img1,"general",(x,y),cv2.FONT_HERSHEY_SIMPLEX, 1.0,(0,0,255))
                            netsend(sub_result, need_unpack=True)
                            self.draw = draw_img1
                            self.last_select = sub_result
                            self.mode = 6
                            self.pick_center = center
                            #self.center = center
                            return [sub_result, center ,6]
                        else:
                            self.draw = draw_img1
                            return None
                    
            '''
            two hand in the view
            '''
            if num_hand_view == 2:
                lcenter = result[0][0]
                ltips = result[0][1]
                lnum_tips = len(ltips)
                lradius = result[0][2]
                lbox = result[0][3]
                llabel = self.test(lbox, draw_img1) 
                app = result[0][5]
                cv2.drawContours(draw_img1, [app],-1,(255, 0, 0),1)
                for k in range(len(ltips)):
                    cv2.circle(draw_img1,ltips[k],10,(255, 0, 0),2)
                    cv2.line(draw_img1,ltips[k],lcenter,(255, 0, 0),2)

                rcenter = result[1][0]
                rtips = result[1][1]
                rnum_tips = len(rtips)
                rradius = result[1][2]
                rbox = result[1][3]
                rlabel = self.test(rbox, draw_img1)
                lapp = result[1][5]
                cv2.drawContours(draw_img1, [lapp],-1,(255, 0, 0),1)
                for k in range(len(rtips)):
                    cv2.circle(draw_img1,rtips[k],10,(255, 0, 0),2)
                    cv2.line(draw_img1,rtips[k],rcenter,(255, 0, 0),2)
                # '''
                # two hand is both one finger pointing, ONLY PLACE
                # '''
                if set([lnum_tips, rnum_tips]) == set([1,1]) and len(self.boxls) > 0 and set([llabel, rlabel]) == set([1,1]):
                    self.draw = draw_img1
                    
                    '''
                    flag is 4
                    '''
                    self.mode = 4
                    self.two_hand_mode = 4
                    self.tip_deque1.appendleft((ltips[0][0], ltips[0][1]))
                    self.tip_deque2.appendleft((rtips[0][0], rtips[0][1]))
                    self.center = [list(lcenter), list(rcenter)]
                    return [[rtips[0][0], rtips[0][1]], [ltips[0][0], ltips[0][1]], [list(rcenter), list(lcenter)], 4]

                elif max(set([lnum_tips, rnum_tips])) >= 2 and min(set([lnum_tips, rnum_tips])) == 1 and max(set([llabel, rlabel])) < 4 and self.onehand_center:
                    #sub_result = filter(lambda x: len(x[1]) == 1 , [[rcenter, rtips], [lcenter, ltips]])
                    sub_result = max([[rcenter, rtips], [lcenter, ltips]], key=lambda x: distant(x[0], self.onehand_center))
                    center = sub_result[0]
                    tips = sub_result[1]
                    # center = sub_result[0][0]
                    # tips = sub_result[0][1]
                    self.tip_deque.appendleft((tips[0][0], tips[0][1]))
                    self.draw = draw_img1
                    
                    if max(set([lnum_tips, rnum_tips])) == 2 and set([lnum_tips, rnum_tips]) == set([1,2]):
                        self.mode = 1
                        self.two_hand_mode = 1
                        return [[tips[0][0], tips[0][1]], 1]
                    else:
                        self.mode = 5
                        self.two_hand_mode = 5
                        return [[tips[0][0], tips[0][1]], 5]
                
                elif min(set([lnum_tips, rnum_tips])) == 1 and max(set([llabel, rlabel])) == 4 and self.onehand_center:
                    #sub_result = filter(lambda x: len(x[1]) == 1 , [[rcenter, rtips], [lcenter, ltips]])
                    sub_result = max([[rcenter, rtips], [lcenter, ltips]], key=lambda x: distant(x[0], self.onehand_center))
                    center = sub_result[0]
                    tips = sub_result[1]
                    # center = sub_result[0][0]
                    # tips = sub_result[0][1]
                    self.tip_deque.appendleft((tips[0][0], tips[0][1]))
                    self.draw = draw_img1
                    self.mode = 1
                    self.two_hand_mode = 1
                    #rospy.loginfo("jdjdjdjjs")
                    return [[tips[0][0], tips[0][1]], 1]
        self.draw = draw_img1       

    def get_bound(self, image, object_mask, visualization=True):
        self.surfacels = []
        self.boxls = []
        _, object_contours, object_hierarchy = cv2.findContours(
            object_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        if len(object_contours) > 0:
            for i , contour in enumerate(object_contours):
                area = cv2.contourArea(contour)
                if 250 < area < 800 and object_hierarchy[0, i, 3] == -1:
                    M = cv2.moments(contour)
                    cx = int(M['m10']/M['m00'])
                    cy = int(M['m01']/M['m00'])
                    x,y,w,h = cv2.boundingRect(contour)
                    self.surfacels.append((int(x+w/2), int(y+h/2)))
                    self.boxls.append((x, y, w, h))
        if len(self.boxls) > 0:
            boxls_arr = np.array(self.boxls)
            order = boxls_arr[:, 0].argsort()  # sort boxes left-to-right by x
            self.boxls = boxls_arr[order].tolist()
            # Apply the same permutation to the surface centers so they stay
            # aligned with their boxes (the original re-sorted them separately,
            # which could pair centers with the wrong boxes).
            self.surfacels = np.array(self.surfacels)[order].tolist()
            #print(self.surfacels)

        # for x, y, w, h in self.boxls:
        #     sub = image[y:y+h, x:x+w, :]
        #     hsv = cv2.cvtColor(sub,cv2.COLOR_BGR2HSV)
        #     top_mask = cv2.inRange(hsv, Top_low, Top_high)
        #     (_,top_contours, object_hierarchy)=cv2.findContours(top_mask,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
        #     max_area = 0
        #     for i , contour in enumerate(top_contours):
        #         area = cv2.contourArea(contour)
        #         if area>max_area and object_hierarchy[0, i, 3] == -1:					
        #             M = cv2.moments(contour)
        #             cx = int(M['m10']/M['m00'])
        #             cy = int(M['m01']/M['m00'])
        #             max_area = area
        #     self.surfacels.append((cx+x, cy+y))

    def reset(self): 
        self.tip_deque.clear()
        self.tip_deque1.clear()
        self.tip_deque2.clear()
        self.two_hand_mode = None
        self.mode = 0
        self.last_select = None
        self.center = None
        self.pick_tip = None
        self.pick_center = None
        self.hand_mask = []
        self.onehand_center = None

    def __del__(self):
        self.cap.release()
        self.out.release()
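A minimal driving loop for this tracker, assuming the globals (`gesture_id`, `color_flag`) and helper functions used above are defined elsewhere; a sketch only:
if __name__ == '__main__':
    tracker = temp_tracking()
    while True:
        result = tracker.update()  # None when no gesture is recognized
        if result is not None:
            print(result)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break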
Example #5
class object_detector(): 
    def __init__(self, start): 	
        self.cap = cv2.VideoCapture(0)	
        self.start_time = start

        self.stored_flag = False
        self.trained_flag = False
        self.milstone_flag = False
        self.incremental_train_flag = False
        self.tracking_flag = False

        self.boxls = None
        self.count = 1
        self.new_count = 1
        self.path = "/home/intuitivecompting/Desktop/color/Smart-Projector/script/datasets/"
        if MODE == 'all':
            self.file = open(self.path + "read.txt", "w")
            self.milestone_file = open(self.path + "mileston_read.txt", "w")
        self.user_input = 0
        self.predict = None
        self.memory = cache(10)
        self.memory1 = cache(10)
        self.hand_memory = cache(10)

        self.node_sequence = []
        #-----------------------create GUI-----------------------#
        self.gui_img = np.zeros((130,640,3), np.uint8)
        cv2.circle(self.gui_img,(160,50),30,(255,0,0),-1)
        cv2.putText(self.gui_img,"start",(130,110),cv2.FONT_HERSHEY_SIMPLEX, 1.0,(255,0,0))
        cv2.circle(self.gui_img,(320,50),30,(0,255,0),-1)
        cv2.putText(self.gui_img,"stop",(290,110),cv2.FONT_HERSHEY_SIMPLEX, 1.0,(0,255,0))
        cv2.circle(self.gui_img,(480,50),30,(0,0,255),-1)
        cv2.putText(self.gui_img,"quit",(450,110),cv2.FONT_HERSHEY_SIMPLEX, 1.0,(0,0,255))
        cv2.namedWindow('gui_img')
        cv2.namedWindow('gui_img1')
        cv2.setMouseCallback('gui_img',self.gui_callback)
        cv2.setMouseCallback('gui_img1',self.gui_callback)
        #-----------------------Training sign--------------#
        self.training_surface = np.ones((610,640,3), np.uint8) * 255
        cv2.putText(self.training_surface,'Training...',(120,300),cv2.FONT_HERSHEY_SIMPLEX, 3.0,(255,192,203), 5)
        #----------------------new coming item id------------------#
        self.new_come_id = None
        self.old_come_id = None
        self.new_come_side = None
        self.old_come_side = None
        self.new_coming_lock = True
        self.once_lock = True
        #---------------------set some flag-------------------#
        self.storing = None
        self.quit = None
        self.once = True
        #---------------------set gui image----------------------#
        self.temp_surface = None
        #----------------------for easier development-----------------#
        if MODE == 'test':
            if not GPU:
                self.net = Net()
            else:
                self.net = Net().cuda()
            self.net.load_state_dict(torch.load(f=self.path + 'model'))
            self.user_input = 5
            self.stored_flag = True


    def update(self, save=True, train=False):
        
        self.boxls = []
        OK, origin = self.cap.read()
        if OK:
            rect = self.camrectify(origin)

            #-------------warp the image---------------------#
            warp = self.warp(rect)

            #-------------segment the object----------------#
            hsv = cv2.cvtColor(warp, cv2.COLOR_BGR2HSV)
            # mask the green table surface plus dilated hand/skin regions;
            # whatever remains is treated as object foreground
            green_mask = cv2.inRange(hsv, Green_low, Green_high)
            hand_mask = cv2.inRange(hsv, Hand_low, Hand_high)
            hand_mask = cv2.dilate(hand_mask, kernel=np.ones((7, 7), np.uint8))
            skin_mask = cv2.inRange(hsv, Skin_low, Skin_high)
            skin_mask = cv2.dilate(skin_mask, kernel=np.ones((7, 7), np.uint8))

            thresh = 255 - green_mask
            thresh = cv2.subtract(thresh, hand_mask)
            thresh = cv2.subtract(thresh, skin_mask)
            thresh[477:, 50:610] = 0  # ignore the strip at the bottom of the frame
            cv2.imshow('threshold', thresh)
            draw_img1 = warp.copy()
            draw_img2 = warp.copy()
            draw_img3 = warp.copy()
            self.train_img = warp.copy()
            #-------------get the bounding box--------------
            self.get_bound(draw_img1, thresh, hand_mask, only=False, visualization=True)
            #--------------get bags of words and training-------#
            if MODE == 'all':
                #----------------------------storing image for each item---------#
                if not self.stored_flag:
                    self.temp_surface = np.vstack((draw_img1, self.gui_img))                    
                    self.stored_flag = self.store()
                    cv2.imshow('gui_img', self.temp_surface)
                #--------------------------training, just once------------------#
                if self.stored_flag and not self.trained_flag:  
                    cv2.destroyWindow('gui_img')
                    #cv2.imshow('training', self.training_surface)
                    self.trained_flag = self.train()
                #------------------------assembling and saving milstone---------#
                if self.trained_flag and not self.milstone_flag: 
                    self.test(draw_img2)
                    self.temp_surface = np.vstack((draw_img2, self.gui_img))
                    cv2.imshow('gui_img1', self.temp_surface)
                #-----------------------training saved milstone image---------#
                if self.milstone_flag and not self.incremental_train_flag:
                    cv2.destroyWindow('gui_img1')
                    self.incremental_train_flag = self.train(is_incremental=True)
                #-----------------------finalized tracking------------------#
                if self.incremental_train_flag and not self.tracking_flag:
                    self.test(draw_img3, is_tracking=True)
                    cv2.imshow('tracking', draw_img3)
            elif MODE == 'test':
                self.test(draw_img2)
                self.temp_surface = np.vstack((draw_img2, self.gui_img))
                cv2.imshow('gui_img', self.temp_surface)
                #cv2.imshow('track', draw_img2)
                #-----------------------training saved milstone image---------#
                if self.milstone_flag and not self.incremental_train_flag:
                    cv2.destroyWindow('gui_img')
                    self.incremental_train_flag = self.train(is_incremental=True)
                #-----------------------finalized tracking------------------#
                if self.incremental_train_flag and not self.tracking_flag:
                    self.test(draw_img3, is_tracking=True)
                    cv2.imshow('gui_img1', draw_img3)
            elif MODE == 'train':
                if not self.trained_flag:  
                    #cv2.destroyWindow('gui_img')
                    #cv2.imshow('training', self.training_surface)
                    self.trained_flag = self.train()
                #------------------------assembling and saving milstone---------#
                if self.trained_flag and not self.milstone_flag: 
                    self.test(draw_img2)
                    self.temp_surface = np.vstack((draw_img2, self.gui_img))
                    cv2.imshow('gui_img1', self.temp_surface)
                #-----------------------training saved milstone image---------#
                if self.milstone_flag and not self.incremental_train_flag:
                    cv2.destroyWindow('gui_img1')
                    self.incremental_train_flag = self.train(is_incremental=True)
                #-----------------------finalized tracking------------------#
                if self.incremental_train_flag and not self.tracking_flag:
                    self.test(draw_img3, is_tracking=True)
                    cv2.imshow('tracking', draw_img3)
        
    def get_bound(self, img, object_mask, hand_mask, only=True, visualization=True):
        # OpenCV 3.x findContours returns (image, contours, hierarchy)
        (_, object_contours, object_hierarchy) = cv2.findContours(object_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        (_, hand_contours, hand_hierarchy) = cv2.findContours(hand_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        hand_m_ls = []
        object_m_ls = []
        if len(hand_contours) > 0:
            for i , contour in enumerate(hand_contours):
                area = cv2.contourArea(contour)
                if area>600 and area < 100000 and hand_hierarchy[0, i, 3] == -1:					
                    M = cv2.moments(contour)
                    cx = int(M['m10']/M['m00'])
                    cy = int(M['m01']/M['m00'])
                    hand_m_ls.append((cx, cy))
        if len(object_contours) > 0:
            for i , contour in enumerate(object_contours):
                area = cv2.contourArea(contour)
                if area>100 and area < 100000 and object_hierarchy[0, i, 3] == -1:					
                    M = cv2.moments(contour)
                    cx = int(M['m10']/M['m00'])
                    cy = int(M['m01']/M['m00'])
                    object_m_ls.append((cx, cy))
                    x,y,w,h = cv2.boundingRect(contour)
                    self.boxls.append([x, y, w, h])
        temp_i = []
        temp_j = []
        for (x3, y3) in hand_m_ls:
            for i in range(len(object_m_ls)):
                for j in range(i + 1, len(object_m_ls)):
                    x1, y1 = object_m_ls[i]
                    x2, y2 = object_m_ls[j]
                    d12 = distant((x1, y1), (x2, y2))
                    d13 = distant((x1, y1), (x3, y3))
                    d23 = distant((x2, y2), (x3, y3))
                    # merge two object blobs when the hand lies near the line
                    # between their centroids and all three points are close
                    dis = self.get_k_dis((x1, y1), (x2, y2), (x3, y3))
                    if dis < 60 and d12 < 140 and d13 < 100 and d23 < 100:
                        temp_i.append(i)
                        temp_j.append(j)

        if len(temp_i) > 0 and len(temp_j) > 0 and len(self.boxls) >= 1:
            for (i, j) in zip(temp_i, temp_j):
                if self.boxls[i] != 0 and self.boxls[j] != 0:
                    # replace box j with the union of boxes i and j; mark i consumed
                    x = min(self.boxls[i][0], self.boxls[j][0])
                    y = min(self.boxls[i][1], self.boxls[j][1])
                    x_max = max(self.boxls[i][0] + self.boxls[i][2], self.boxls[j][0] + self.boxls[j][2])
                    y_max = max(self.boxls[i][1] + self.boxls[i][3], self.boxls[j][1] + self.boxls[j][3])
                    w, h = x_max - x, y_max - y
                    self.boxls[i] = 0
                    self.boxls[j] = [x, y, w, h]

            # list(...) keeps this working on Python 3, where filter() is lazy
            self.boxls = list(filter(lambda a: a != 0, self.boxls))

        # sort the boxes left-to-right by x coordinate
        if len(self.boxls) > 0:
            boxls_arr = np.array(self.boxls)
            self.boxls = boxls_arr[boxls_arr[:, 0].argsort()].tolist()
        if visualization and len(self.boxls) > 0:
            # annotate only the largest box, the one store() and test() use
            ind = max(range(len(self.boxls)), key=lambda k: self.boxls[k][2] * self.boxls[k][3])
            x, y, w, h = self.boxls[ind]
            cv2.rectangle(img, (x, y), (x + w, y + h), (0, 0, 255), 2)
            cv2.putText(img, str(self.user_input), (x, y), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 255))

     
    def gui_callback(self, event, x, y, flags, param):
        if event == cv2.EVENT_LBUTTONDBLCLK and (self.temp_surface[y, x] == np.array([255, 0, 0])).all() and not self.storing:
            self.count = 1
            self.user_input += 1
            self.storing = True
            if self.user_input > 5:
                if self.once:
                    temp_node = node((self.new_come_id, self.old_come_id), (self.new_come_side, self.old_come_side),self.user_input)
                    self.once = False
                else:
                    temp_node = node((self.new_come_id, self.user_input - 1), (self.new_come_side, self.old_come_side), self.user_input)
                self.node_sequence.append(temp_node)
            print("start")
        if event == cv2.EVENT_LBUTTONDBLCLK and (self.temp_surface[y, x] == np.array([0, 255, 0])).all() and self.storing:
            self.storing = False
            self.new_coming_lock = True
            self.new_come_id = None
            self.old_come_id = None
            self.new_come_side = None
            self.old_come_side = None

            print("stop")
        if event == cv2.EVENT_LBUTTONDBLCLK and (self.temp_surface[y, x] == np.array([0, 0, 255])).all():
            self.storing = None
            self.quit = True
            print("quit")
            if self.stored_flag:
                x,y,w,h = self.boxls[0]
                sub_img = self.train_img[y:y+h, x:x+w, :]
                cv2.imwrite('test_imgs/saved' + str(self.user_input) + '.jpg', sub_img)

    def store(self, is_milestone=False):
        if is_milestone:
            self.file = open(self.path + "read.txt", "a")
            img_dir = os.path.join(self.path + "image", "milstone" + str(self.new_count) + ".jpg")
        else:
            img_dir = os.path.join(self.path + "image", str(self.new_count) + ".jpg")
        file = self.file
        self.createFolder(self.path + "image")
        if self.quit:
            file.close()
            print('finish output')
            return True
        if len(self.boxls) > 0:
            if self.storing:
                cv2.putText(self.temp_surface,"recording",(450,50),cv2.FONT_HERSHEY_SIMPLEX, 1.0,(0,0,255), 2)
                frame = self.train_img
                ind = max(range(len(self.boxls)), key=lambda i:self.boxls[i][2]*self.boxls[i][3])
            #-------------capturing img for each of item--------------#
                x,y,w,h = self.boxls[ind]
                temp = frame[y:y+h, x:x+w, :]
                
                cv2.imwrite(img_dir, temp)         
                file.write(img_dir + " " + str(self.user_input) + "\n")
                if self.count % 100 == 0:
                    print('saved {} images so far'.format(self.count))
                self.count += 1
                self.new_count += 1 
                return False
        else:
            return False

    def train(self, is_incremental=False):
        if is_incremental:
            # persist the assembly-pair graph gathered so far
            pickle.dump(node.pair_list, open("node.p", "wb"))
        start_time = time.time()
        # both branches train a fresh network from the same label file; to
        # fine-tune instead, reload the saved weights before training
        if not GPU:
            self.net = Net()
        else:
            self.net = Net().cuda()
        reader_train = self.reader(self.path, "read.txt")
        optimizer = optim.SGD(self.net.parameters(), lr=LR, momentum=MOMENTUM, nesterov=True)
        schedule = optim.lr_scheduler.StepLR(optimizer, step_size=STEP, gamma=GAMMA)
        trainset = CovnetDataset(reader=reader_train,
                                 transforms=transforms.Compose([transforms.Resize((200, 100)),
                                                                transforms.ToTensor()]))
        trainloader = DataLoader(dataset=trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
#-----------------------------------training----------------------------------------------------------------        
        loss_ls = []
        count = 0
        count_ls = []
        t = tqdm.trange(EPOTH, desc='Training')
        temp = 0
        for _ in t:  # loop over the dataset multiple times
            schedule.step()  # decay the learning rate once per epoch
            running_loss = 0.0
            i = 0
            for data in trainloader:
                # get the inputs
                inputs, labels = data
                if GPU:
                    inputs, labels = inputs.cuda(), labels.cuda()
                inputs, labels = Variable(inputs), Variable(labels.long())
                # zero the parameter gradients
                optimizer.zero_grad()
                # forward + backward + optimize
                outputs = self.net(inputs)
                loss = F.cross_entropy(outputs, labels.view(1, -1)[0])
                loss.backward()
                optimizer.step()
                t.set_description('loss=%g' % temp)

                loss_ls.append(loss.item())
                count += 1
                count_ls.append(count)

                running_loss += loss.item()
                if i % 10 == 9:  # report the mean loss of the last 10 batches
                    temp = running_loss / 10
                    running_loss = 0.0
                i += 1
        plt.plot(count_ls, loss_ls)
        plt.show(block=False)
        print('Finished training in {} seconds'.format(int(time.time() - start_time)))

        self.quit = None

        if not is_incremental:
            self.user_input = 5
            torch.save(self.net.state_dict(), f=self.path + 'model')
        else:
            torch.save(self.net.state_dict(), f=self.path + 'milestone_model')
        return True
#---------------------------------testing-----------------------------------------------
        
    def test(self, draw_img, is_tracking=False):
        self.predict = []
        net = self.net
        num_object = len(self.boxls)
        frame = self.train_img
        preprocess = transforms.Compose([transforms.Resize((200, 100)),
                                         transforms.ToTensor()])
        for i in range(num_object):
            x,y,w,h = self.boxls[i]
            temp = frame[y:y+h, x:x+w, :]
            temp = cv2.cvtColor(temp,cv2.COLOR_BGR2RGB)
            image = Image.fromarray(temp)
            img_tensor = preprocess(image)
            img_tensor.unsqueeze_(0)
            # build the input on the matching device; the CPU path must never call .cuda()
            if GPU:
                img_variable = Variable(img_tensor).cuda()
                out = np.argmax(net(img_variable).cpu().data.numpy()[0])
            else:
                img_variable = Variable(img_tensor)
                out = np.argmax(net(img_variable).data.numpy()[0])
            cv2.rectangle(draw_img,(x,y),(x+w,y+h),(0,0,255),2)
            cv2.putText(draw_img,str(out),(x,y),cv2.FONT_HERSHEY_SIMPLEX, 1.0,(0,0,255))
            self.predict.append(((x, y, w, h), out))
        if not is_tracking:
            if self.old_come_side is not None and self.new_come_side is None:
                cv2.putText(draw_img,"Point to next!",(220,50),cv2.FONT_HERSHEY_SIMPLEX, 0.7,(0,0,255), 2)
            if self.new_come_side is not None and self.old_come_side is not None:
                cv2.putText(draw_img,"Start Assemble! Click Start when finish",(180,50),cv2.FONT_HERSHEY_SIMPLEX, 0.7,(0,0,255), 2)
            lab, color, ind, coord = self.store_side(frame)
            if lab:
                self.get_pair(frame.copy(), num_object, lab, color, ind, coord)
            self.milstone_flag = self.store(is_milestone=True)
            

    
    def store_side(self, frame):
        img = frame.copy()
        point, center = hand_tracking(img, self.hand_memory).get_result()
        if point and len(self.boxls) > 0:
            red_center = side_finder(img, color='red')
            blue_center = side_finder(img, color='blue')
            tape = red_center + blue_center
            length_ls = []
            for (x, y) in tape:
                length_ls.append((self.get_k_dis((point[0], point[1]), (center[0], center[1]), (x, y)), (x, y)))
            # pick the tape mark closest to the pointing ray
            x, y = min(length_ls, key=lambda item: item[0])[1]
            cv2.circle(img, (x, y), 10, [255, 255, 0], -1)
            ind = test_insdie((x, y), self.boxls)

            if ind is not None:
                color = None
                if (x, y) in red_center:
                    color = 'red'
                elif (x, y) in blue_center:
                    color = 'blue'
                return self.predict[ind][1], color, ind, (x, y)
            else:
                return None, None, None, None
        else:
            return None, None, None, None
            


    def get_pair(self,image, num_object, label, color, index, coord):
        '''
        pointing from left to right
        '''
        if self.once and self.once_lock and num_object == 2:
            if index == 0:
                self.memory.append(self.predict[0][1])
                if self.memory.full and self.new_come_id is None:
                    # majority vote over the buffered predictions
                    self.old_come_id = max(set(self.memory.list), key=self.memory.list.count)
                    self.old_come_side = self.draw_point(image, coord, index)
                
            if self.memory.full and index == 1:
                self.memory1.append(self.predict[-1][1])
                if self.memory1.full:
                    self.new_come_id = max(set(self.memory1.list), key=self.memory1.list.count)
                    self.new_come_side = self.draw_point(image, coord, index)
                    
            if self.memory.full and self.memory1.full:
                self.once_lock = False
                self.memory.clear()
                self.memory1.clear()
                print("new_come_id:{}, old_come_id:{}".format(self.new_come_id, self.old_come_id))
                print("new_come_side:{}, old_come_side:{}".format(self.new_come_side, self.old_come_side))
        
        # same left-to-right pointing convention, for items after the first pair
        if not self.once and num_object == 2 and self.new_coming_lock:
            if index == 0:
                self.memory.append(0)
                if self.memory.full:
                    self.old_come_side = self.draw_point(image, coord, index, is_milestone=True)
                    self.memory.clear()
            elif index == 1 and self.new_come_id is None:               
                self.memory1.append(self.predict[-1][1])
                if self.memory1.full:
                    self.new_come_id = max(set(self.memory1.list), key=self.memory1.list.count)                    
                    self.new_come_side = self.draw_point(image, coord, index)
                    self.memory1.clear()

                    

            if self.new_come_side and self.old_come_side:
                self.new_coming_lock = False
                print("new_come_id:{}".format(self.new_come_id))
                print("new_come_side:{}, old_come_side:{}".format(self.new_come_side, self.old_come_side))



    def draw_point(self, image, coord, index, is_milestone=False):
        # crop the selected object, save it, and return the pointed-at
        # coordinate relative to the crop
        x, y, w, h = self.boxls[index]
        sub_img = image[y:y+h, x:x+w, :]
        if not is_milestone:
            cv2.imwrite('test_imgs/saved' + str(self.predict[index][1]) + '.jpg', sub_img)
        else:
            cv2.imwrite('test_imgs/saved' + str(self.user_input) + '.jpg', sub_img)
        return (coord[0] - x, coord[1] - y)


    def warp(self, img):
        # perspective-rectify the table region to a 640x480 top-down view
        pts1 = np.float32([[101, 160], [531, 133], [0, 480], [640, 480]])
        pts2 = np.float32([[0, 0], [640, 0], [0, 480], [640, 480]])
        M = cv2.getPerspectiveTransform(pts1, pts2)
        dst = cv2.warpPerspective(img, M, (640, 480))
        return dst
            

    @staticmethod
    def get_k_dis(p1, p2, p):
        # distance proxy from point p to the line through p1 and p2:
        # area of the triangle (p, p1, p2) divided by the base length |p1 p2|
        coord = (p, p1, p2)
        return Polygon(coord).area / distant(p1, p2)
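
# Usage sketch (not from the source): a minimal driver loop for the detector,
# using only the constructor, update(), and quit flag shown above.
def run_detector():
    import time
    detector = object_detector(start=time.time())
    while not detector.quit:
        detector.update()
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
    cv2.destroyAllWindows()
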
class DQN:
    def __init__(self,
                 memory_size=50000,
                 batch_size=128,
                 gamma=0.99,
                 lr=1e-3,
                 n_step=500000):
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.gamma = gamma

        # memory
        self.memory_size = memory_size
        self.Memory = ReplayMemory(self.memory_size)
        self.batch_size = batch_size

        # network
        self.target_net = Net().to(self.device)
        self.eval_net = Net().to(self.device)
        self.target_update()  # initialize same weight
        self.target_net.eval()

        # optim
        self.optimizer = optim.Adam(self.eval_net.parameters(), lr=lr)

    def select_action(self, state, eps):
        prob = random.random()
        if prob > eps:
            return self.eval_net.act(state), False
        else:
            return (torch.tensor(
                [[random.randrange(0, 9)]],
                device=self.device,
                dtype=torch.long,
            ), True)

    def select_dummy_action(self, state):
        state = state.reshape(3, 3, 3)

        open_spots = state[:, :, 0].reshape(-1)

        p = open_spots / open_spots.sum()

        return np.random.choice(np.arange(9), p=p)
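
    # Quick check (illustrative, not from the source): on an empty board
    # channel 0 of the 3x3x3 state is all ones, so open_spots.sum() == 9 and
    # select_dummy_action samples each of the nine cells with probability 1/9.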

    def target_update(self):
        self.target_net.load_state_dict(self.eval_net.state_dict())

    def learn(self):
        if len(self.Memory) < self.batch_size:
            return

        # random batch sampling
        transitions = self.Memory.sampling(self.batch_size)
        batch = Transition(*zip(*transitions))

        non_final_mask = torch.tensor(
            tuple(map(lambda s: s is not None, batch.next_state)),
            device=self.device,
            dtype=torch.bool,
        )

        non_final_next_states = torch.cat(
            [s for s in batch.next_state if s is not None]).to(self.device)
        state_batch = torch.cat(batch.state).to(self.device)
        action_batch = torch.cat(batch.action).to(self.device)
        reward_batch = torch.cat(batch.reward).to(self.device)

        # Q(s)
        Q_s = self.eval_net(state_batch).gather(1, action_batch)

        # maxQ(s') no grad for target_net
        Q_s_ = torch.zeros(self.batch_size, device=self.device)
        Q_s_[non_final_mask] = self.target_net(non_final_next_states).max(
            1)[0].detach()

        # Q_target=R+γ*maxQ(s')
        Q_target = reward_batch + (Q_s_ * self.gamma)

        # loss_fnc---(R+γ*maxQ(s'))-Q(s)
        # huber loss with delta=1
        loss = F.smooth_l1_loss(Q_s, Q_target.unsqueeze(1))

        # Optimize the model
        self.optimizer.zero_grad()
        loss.backward()
        for param in self.eval_net.parameters():
            param.grad.data.clamp_(-1, 1)
        self.optimizer.step()

    def load_net(self, name):
        self.action_net = torch.load(name).cpu()

    def load_weight(self, name):
        self.eval_net.load_state_dict(torch.load(name))
        self.eval_net = self.eval_net.cpu()

    def act(self, state):
        with torch.no_grad():
            p = F.softmax(self.action_net(state), dim=1).cpu().numpy()
            # zero out occupied cells: a cell is open when channel 0 of its
            # one-hot encoding is the argmax
            valid_moves = (state.cpu().numpy().reshape(
                3, 3, 3).argmax(axis=2).reshape(-1) == 0)
            p = valid_moves * p
            return p.argmax()
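

# Usage sketch (assumptions, not from the source): a minimal training loop for
# DQN. `env` is assumed to be a Gym-style tic-tac-toe environment, and
# `Memory.push` is assumed to store transitions (only sampling() appears above).
def train_agent(env, episodes=500):
    agent = DQN()
    eps = 1.0
    for episode in range(episodes):
        state = env.reset()
        done = False
        while not done:
            action, _ = agent.select_action(state, eps)
            next_state, reward, done, _ = env.step(action.item())
            agent.Memory.push(state, action,
                              None if done else next_state, reward)  # assumed API
            state = next_state
            agent.learn()
        eps = max(0.05, eps * 0.995)   # simple epsilon decay (an assumption)
        if episode % 10 == 0:
            agent.target_update()      # periodically sync the target network
    return agent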