    # Common imports assumed by every example below (each __getitem__ is a
    # method of a torch.utils.data.Dataset subclass):
    #   import os, random
    #   import cv2
    #   import numpy as np
    #   import torch
    #   from PIL import Image
    def __getitem__(self, idx):
        image_path = self.paths[idx][0][0]
        image_path = os.path.join(self.root_dir, image_path)

        box = self.bboxes[0, idx][0]
        eye = self.eyes[0, idx][0]
        # TODO: process gaze differently for training vs. testing
        gaze = self.gazes[0, idx].mean(axis=0)
        headPose = self.headPose[0, idx][0]
        image = cv2.imread(image_path, cv2.IMREAD_COLOR)

        if random.random() > 0.5 and self.training == 'train':
            eye = [1.0 - eye[0], eye[1]]
            gaze = [1.0 - gaze[0], gaze[1]]
            image = cv2.flip(image, 1)
        h, w = image.shape[:2]
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = Image.fromarray(image)
        image = data_transforms[self.training](image)

        # generate gaze field
        gaze_field = self.generate_data_field(eye_point=eye)
        # generate heatmap
        heatmap = get_paste_kernel((224 // 4, 224 // 4), gaze, kernel_map,
                                   (224 // 4, 224 // 4))
        sample = {
            'image': image,
            'eye_position': torch.FloatTensor(eye),
            'gaze_field': torch.from_numpy(gaze_field),
            'head_pose': torch.FloatTensor(headPose),
            'gt_position': torch.FloatTensor(gaze),
            'gt_heatmap': torch.FloatTensor(heatmap).unsqueeze(0)
        }

        return sample
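
A note on the helpers used above: generate_data_field is not part of this
listing. Judging from the projection code in Examples #4 and #5 (the field is
reshaped to (1, H*W, 2) and multiplied by a unit gaze direction), it plausibly
returns a (2, H, W) array of unit vectors pointing from the eye position
toward each pixel. A minimal sketch under that assumption (the 224 size is
likewise an assumption):

import numpy as np

def generate_data_field(self, eye_point, size=224):
    """Sketch: (2, size, size) field of unit vectors pointing from the
    normalized (x, y) eye position toward every pixel."""
    x = np.linspace(0.0, 1.0, size, dtype=np.float32)
    y = np.linspace(0.0, 1.0, size, dtype=np.float32)
    grid_x, grid_y = np.meshgrid(x, y)              # each (size, size)
    field = np.stack([grid_x - eye_point[0],
                      grid_y - eye_point[1]])       # (2, H, W) displacements
    norm = np.sqrt((field ** 2).sum(axis=0))
    norm[norm <= 0.0] = 1.0                         # avoid div-by-zero at the eye pixel
    return field / norm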
Example #2
    def __getitem__(self, idx):
        image_path = self.paths[idx][0][0]
        image_path = os.path.join(self.root_dir, image_path)
        box = self.bboxes[0, idx][0]
        eye = self.eyes[0, idx][0]
        # TODO: process gaze differently for training vs. testing
        gaze = self.gazes[0, idx].mean(axis=0)

        image = cv2.imread(image_path, cv2.IMREAD_COLOR)

        if random.random() > 0.5 and self.training == 'train':
            eye = [1.0 - eye[0], eye[1]]
            gaze = [1.0 - gaze[0], gaze[1]]
            image = cv2.flip(image, 1)

        # crop face
        x_c, y_c = eye
        x_0 = max(x_c - 0.15, 0.0)
        y_0 = max(y_c - 0.15, 0.0)
        x_1 = min(x_c + 0.15, 1.0)
        y_1 = min(y_c + 0.15, 1.0)
        h, w = image.shape[:2]
        face_image = image[int(y_0 * h):int(y_1 * h),
                           int(x_0 * w):int(x_1 * w), :]
        # process face_image for face net
        face_image = cv2.cvtColor(face_image, cv2.COLOR_BGR2RGB)
        face_image = Image.fromarray(face_image)
        face_image = data_transforms[self.training](face_image)
        # process image for saliency net
        # image = image_preprocess(image)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = Image.fromarray(image)
        image = data_transforms[self.training](image)

        # generate gaze field
        gaze_field = self.generate_data_field(eye_point=eye)
        # generate heatmap
        heatmap = get_paste_kernel((224 // 4, 224 // 4), gaze, kernel_map,
                                   (224 // 4, 224 // 4))
        sample = {
            'image': image,
            'face_image': face_image,
            'eye_position': torch.FloatTensor(eye),
            'gaze_field': torch.from_numpy(gaze_field),
            'gt_position': torch.FloatTensor(gaze),
            'gt_heatmap': torch.FloatTensor(heatmap).unsqueeze(0)
        }

        return sample
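
data_transforms, indexed by self.training ('train' or 'test'), must map a PIL
image to a normalized tensor. The original definition is not shown; a typical
torchvision reconstruction (the augmentations and normalization constants are
assumptions, not the original's) would be:

from torchvision import transforms

# Hypothetical reconstruction of data_transforms; the original's exact
# augmentations and normalization constants are unknown.
data_transforms = {
    phase: transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    for phase in ('train', 'test')
}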
Example #3
    def __getitem__(self, idx):
        get_data = self.data[idx]
        image_path = get_data['filename']
        image_path = os.path.join(self.root_dir, image_path)

        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        self.height, self.width, _ = image.shape

        # normalize head bbox and gaze coordinates to [0, 1]
        x_0 = get_data['ann']['bboxes'][-1, 0] / self.width  # xmin_headbb
        y_0 = get_data['ann']['bboxes'][-1, 1] / self.height  # ymin_headbb
        x_1 = get_data['ann']['bboxes'][-1, 2] / self.width  # xmax_headbb
        y_1 = get_data['ann']['bboxes'][-1, 3] / self.height  # ymax_headbb

        x_c = get_data['hx'] / self.width  # use head coordinates for eyes
        y_c = get_data['hy'] / self.height
        eye = np.array([x_c, y_c])

        gaze_x = get_data['gaze_cx'] / self.width
        gaze_y = get_data['gaze_cy'] / self.height
        gaze = np.array([gaze_x, gaze_y])

        if random.random() > 0.5 and self.training == 'train':
            eye = [1.0 - eye[0], eye[1]]
            gaze = [1.0 - gaze[0], gaze[1]]
            image = cv2.flip(image, 1)
            # mirror the head bbox so the face crop matches the flipped image
            x_0, x_1 = 1.0 - x_1, 1.0 - x_0

        h, w = image.shape[:2]
        face_image = image[int(y_0 * h):int(y_1 * h),
                           int(x_0 * w):int(x_1 * w), :]
        # process face_image for face net
        face_image = cv2.cvtColor(face_image, cv2.COLOR_BGR2RGB)
        face_image = Image.fromarray(face_image)
        face_image = data_transforms[self.training](face_image)
        # process image for saliency net
        #image = image_preprocess(image)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = Image.fromarray(image)
        image = data_transforms[self.training](image)

        # generate gaze field
        gaze_field = self.generate_data_field(eye_point=eye)
        # generate heatmap
        heatmap = get_paste_kernel((224 // 4, 224 // 4), gaze, kernel_map,
                                   (224 // 4, 224 // 4))
        sample = {
            'image': image,
            'face_image': face_image,
            'eye_position': torch.FloatTensor(eye),
            'gaze_field': torch.from_numpy(gaze_field),
            'gt_position': torch.FloatTensor(gaze),
            'gt_heatmap': torch.FloatTensor(heatmap).unsqueeze(0)
        }

        return sample
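
get_paste_kernel and the precomputed kernel_map are also external to this
listing. The call sites suggest it pastes a Gaussian bump centered at the
normalized gaze point onto a 56x56 (224 // 4) canvas. A self-contained sketch
under those assumptions (kernel size and sigma are guesses):

import numpy as np

def gaussian_kernel(size=21, sigma=3.0):
    """Odd-sized 2-D Gaussian with its peak normalized to 1."""
    ax = np.arange(size, dtype=np.float32) - size // 2
    xx, yy = np.meshgrid(ax, ax)
    k = np.exp(-(xx ** 2 + yy ** 2) / (2.0 * sigma ** 2))
    return k / k.max()

kernel_map = gaussian_kernel()

def get_paste_kernel(field_shape, point, kernel, out_shape):
    """Paste `kernel` onto a zero canvas of shape `out_shape` (h, w),
    centered at the normalized (x, y) `point`, clipping at the borders."""
    h, w = out_shape
    canvas = np.zeros((h, w), dtype=np.float32)
    cx, cy = int(point[0] * w), int(point[1] * h)
    k = kernel.shape[0] // 2
    y0, y1 = max(cy - k, 0), min(cy + k + 1, h)
    x0, x1 = max(cx - k, 0), min(cx + k + 1, w)
    if y1 > y0 and x1 > x0:
        canvas[y0:y1, x0:x1] = kernel[y0 - (cy - k):y1 - (cy - k),
                                      x0 - (cx - k):x1 - (cx - k)]
    return canvas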
Example #4
    def __getitem__(self, idx):

        data = self.data[idx]
        # print("keys",data.keys())
        image_path = data['filename']
        image_path = os.path.join(self.root_dir, image_path)

        eye = [float(data['hx']) / 640, float(data['hy']) / 480]
        gaze = [float(data['gaze_cx']) / 640, float(data['gaze_cy']) / 480]

        image_path = image_path.replace('\\', '/')
        image = cv2.imread(image_path, cv2.IMREAD_COLOR)

        flip_flag = False
        if random.random() > 0.5 and self.training == 'train':
            eye = [1.0 - eye[0], eye[1]]
            gaze = [1.0 - gaze[0], gaze[1]]
            image = cv2.flip(image, 1)
            flip_flag = True
            # print("FLIPPED!")

        # crop face
        x_c, y_c = eye
        x_0 = max(x_c - 0.15, 0.0)
        y_0 = max(y_c - 0.15, 0.0)
        x_1 = min(x_c + 0.15, 1.0)
        y_1 = min(y_c + 0.15, 1.0)
        h, w = image.shape[:2]
        face_image = image[int(y_0 * h):int(y_1 * h),
                           int(x_0 * w):int(x_1 * w), :]
        # process face_image for face net
        face_image = cv2.cvtColor(face_image, cv2.COLOR_BGR2RGB)
        face_image = Image.fromarray(face_image)
        face_image = data_transforms[self.training](face_image)
        # process image for saliency net
        # image = image_preprocess(image)
        original_image = np.copy(image)
        original_image = cv2.resize(original_image, (224, 224))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = Image.fromarray(image)
        image = data_transforms[self.training](image)

        # generate gaze field
        gaze_field = self.generate_data_field(eye_point=eye)
        # generate heatmap
        heatmap = get_paste_kernel((224 // 4, 224 // 4), gaze, kernel_map,
                                   (224 // 4, 224 // 4))

        # Get additional gaze data
        gaze_idx = np.copy(data['gazeIdx'])
        gaze_class = np.copy(data['gaze_item']).astype(np.int64)

        gaze_bbox = np.copy(data['ann']['bboxes'][gaze_idx])

        # normalize
        gaze_bbox[0] = gaze_bbox[0] / 640.0
        gaze_bbox[1] = gaze_bbox[1] / 480.0
        gaze_bbox[2] = gaze_bbox[2] / 640.0
        gaze_bbox[3] = gaze_bbox[3] / 480.0

        head_bbox = np.copy(
            data['ann']['bboxes'][-1])  # last in the list is the head bbox

        # normalize
        head_bbox[0] = head_bbox[0] / 640.0
        head_bbox[1] = head_bbox[1] / 480.0
        head_bbox[2] = head_bbox[2] / 640.0
        head_bbox[3] = head_bbox[3] / 480.0

        # flip bbox locations
        if flip_flag:
            # flip gaze bbox
            xmin = gaze_bbox[0]
            xmax = gaze_bbox[2]
            gaze_bbox[0] = 1 - xmax
            gaze_bbox[2] = 1 - xmin

            # flip head_bbox
            xmin = head_bbox[0]
            xmax = head_bbox[2]
            head_bbox[0] = 1 - xmax
            head_bbox[2] = 1 - xmin

        # generate directed gaze_field using head loc and item loc
        direction = np.array(gaze) - np.array(eye)
        norm = (direction[0]**2.0 + direction[1]**2.0)**0.5
        if norm <= 0.0:
            norm = 1.0
        direction = direction / norm
        direction = torch.from_numpy(direction)
        direction = torch.unsqueeze(direction, 0)

        channel, height, width = gaze_field.shape
        gaze_field_directed = torch.from_numpy(np.copy(gaze_field))
        gaze_field_directed = torch.unsqueeze(gaze_field_directed, 0)
        gaze_field_directed = gaze_field_directed.permute([0, 2, 3, 1])
        gaze_field_directed = gaze_field_directed.view([1, -1, 2])

        gaze_field_directed = torch.matmul(gaze_field_directed,
                                           direction.view([1, 2, 1]).float())
        gaze_field_directed = gaze_field_directed.view([1, height, width, 1])
        gaze_field_directed = gaze_field_directed.permute([0, 3, 1, 2])
        gaze_field_directed[gaze_field_directed < 0] = 0

        # sharpen the directed field with increasing exponents (alpha values)
        gaze_field_directed_2 = torch.pow(gaze_field_directed, 2)
        gaze_field_directed_3 = torch.pow(gaze_field_directed, 5)
        gaze_field_directed_4 = torch.pow(gaze_field_directed, 20)
        gaze_field_directed_5 = torch.pow(gaze_field_directed, 100)
        gaze_field_directed_6 = torch.pow(gaze_field_directed, 500)

        directed_gaze_field = torch.cat([
            gaze_field_directed, gaze_field_directed_2, gaze_field_directed_3,
            gaze_field_directed_4, gaze_field_directed_5, gaze_field_directed_6
        ])

        # create bbox masks

        c, h, w = image.shape
        gaze_bbox_mask = np.zeros((h, w))
        gaze_bbox_mask[int(gaze_bbox[1] * h):int(gaze_bbox[3] * h),
                       int(gaze_bbox[0] * w):int(gaze_bbox[2] * w)] = 1
        gaze_bbox_mask = cv2.resize(gaze_bbox_mask, (224, 224))
        gaze_bbox = gaze_bbox.astype(np.float32)

        head_bbox_mask = np.zeros((h, w))
        head_bbox_mask[int(head_bbox[1] * h):int(head_bbox[3] * h),
                       int(head_bbox[0] * w):int(head_bbox[2] * w)] = 1
        head_bbox_mask = cv2.resize(head_bbox_mask, (224, 224))
        head_bbox = head_bbox.astype(np.float32)

        # Create GazeMask heatmap
        seg = data['seg']
        if seg is not None:
            seg_mask = create_mask(np.array(seg).astype(np.int64))
            seg_mask = cv2.resize(seg_mask, (224 // 4, 224 // 4))

            if flip_flag:
                seg_mask = cv2.flip(seg_mask, 1)  # horizontal flip

            seg_mask = seg_mask.astype(np.float64) / 255.0
            if self.alpha < 1.0:
                gaze_mask = self.alpha * seg_mask + (1 - self.alpha) * heatmap
            else:  # alpha >= 1.0: element-wise max (avoids gaze_mask being undefined)
                gaze_mask = np.fmax(seg_mask, heatmap)

        else:
            gaze_mask = np.zeros((224 // 4, 224 // 4))

        sample = {
            'image': image,
            'face_image': face_image,
            'eye_position': torch.FloatTensor(eye),
            'gaze_field': torch.from_numpy(
                gaze_field),  # this gaze field does not have direction yet
            # direction is computed during training
            # "ground truth" directed gaze field is not used
            'gt_position': torch.FloatTensor(gaze),
            'gt_heatmap': torch.FloatTensor(heatmap).unsqueeze(0),
            'image_path': image_path,
            'original_image': original_image,
            'gaze_idx': gaze_idx,
            'gaze_class': gaze_class,
            'gaze_bbox': gaze_bbox,
            'head_bbox': head_bbox,
            'head_loc': eye,
            'directed_gaze_field': directed_gaze_field,
            # 'gt_bboxes': gt_bboxes,
            # 'gt_labels': gt_labels,
            'gaze_bbox_mask': gaze_bbox_mask,
            'head_bbox_mask': head_bbox_mask,
            'gaze_mask_heatmap':
            torch.FloatTensor(gaze_mask).unsqueeze(0),  # GazeMask
            'gaze_heatmap':
            torch.FloatTensor(heatmap).unsqueeze(0),  # original gaze heatmap
        }

        if self.use_gazemask:
            sample['gt_heatmap'] = torch.FloatTensor(gaze_mask).unsqueeze(0)

        return sample
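
The permute/view/matmul block above only projects each unit field vector onto
the gaze direction and raises the clamped result to a set of powers. An
equivalent, easier-to-read numpy sketch (the function name is mine, not the
original's):

import numpy as np

def build_directed_gaze_field(gaze_field, direction,
                              alphas=(1, 2, 5, 20, 100, 500)):
    """gaze_field: (2, H, W) unit vectors; direction: unit (x, y) vector.
    Returns (len(alphas), H, W): the cosine similarity between each field
    vector and the gaze direction, clamped at zero and progressively
    sharpened by the exponents."""
    cos = np.tensordot(direction, gaze_field, axes=([0], [0]))  # (H, W)
    cos = np.clip(cos, 0.0, None)
    return np.stack([cos ** a for a in alphas])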
Example #5
    def __getitem__(self, idx):
        image_path = self.paths[idx][0][0]
        image_path = os.path.join(self.root_dir, image_path)

        box = np.copy(self.bboxes[0, idx][0])
        # Note: original box annotations are [xmin, ymin, width, height]
        box[2] = box[0] + box[2]  # xmax = xmin + width
        box[3] = box[1] + box[3]  # ymax = ymin + height
        eye = self.eyes[0, idx][0]
        # TODO: process gaze differently for training vs. testing
        gaze = self.gazes[0, idx].mean(axis=0)

        image_path = image_path.replace('\\', '/')
        image = cv2.imread(image_path, cv2.IMREAD_COLOR)

        if random.random() > 0.5 and self.training == 'train':
            eye = [1.0 - eye[0], eye[1]]
            gaze = [1.0 - gaze[0], gaze[1]]
            image = cv2.flip(image, 1)

            xmin = box[0]
            xmax = box[2]
            box[0] = 1 - xmax
            box[2] = 1 - xmin

        # crop face
        x_c, y_c = eye
        x_0 = max(x_c - 0.15, 0.0)
        y_0 = max(y_c - 0.15, 0.0)
        x_1 = min(x_c + 0.15, 1.0)
        y_1 = min(y_c + 0.15, 1.0)
        h, w = image.shape[:2]
        face_image = image[int(y_0 * h):int(y_1 * h),
                           int(x_0 * w):int(x_1 * w), :]
        # process face_image for face net
        face_image = cv2.cvtColor(face_image, cv2.COLOR_BGR2RGB)
        face_image = Image.fromarray(face_image)
        face_image = data_transforms[self.training](face_image)

        head_bbox = np.array([x_0, y_0, x_1, y_1])
        # process image for saliency net
        # image = image_preprocess(image)
        original_image = np.copy(image)
        original_image = cv2.resize(original_image, (224, 224))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = Image.fromarray(image)
        image = data_transforms[self.training](image)

        # generate gaze field
        gaze_field = self.generate_data_field(eye_point=eye)
        # generate heatmap
        heatmap = get_paste_kernel((224 // 4, 224 // 4), gaze, kernel_map,
                                   (224 // 4, 224 // 4))

        # generate directed gaze_field using head loc and item loc
        direction = np.array(gaze) - np.array(eye)
        norm = (direction[0]**2.0 + direction[1]**2.0)**0.5
        if norm <= 0.0:
            norm = 1.0
        direction = direction / norm
        direction = torch.from_numpy(direction)
        direction = torch.unsqueeze(direction, 0)

        channel, height, width = gaze_field.shape
        gaze_field_directed = torch.from_numpy(np.copy(gaze_field))
        gaze_field_directed = torch.unsqueeze(gaze_field_directed, 0)
        gaze_field_directed = gaze_field_directed.permute([0, 2, 3, 1])
        gaze_field_directed = gaze_field_directed.view([1, -1, 2])

        gaze_field_directed = torch.matmul(gaze_field_directed,
                                           direction.view([1, 2, 1]).float())
        gaze_field_directed = gaze_field_directed.view([1, height, width, 1])
        gaze_field_directed = gaze_field_directed.permute([0, 3, 1, 2])
        gaze_field_directed[gaze_field_directed < 0] = 0

        # sharpen the directed field with increasing exponents (alpha values)
        gaze_field_directed_2 = torch.pow(gaze_field_directed, 2)
        gaze_field_directed_3 = torch.pow(gaze_field_directed, 5)
        gaze_field_directed_4 = torch.pow(gaze_field_directed, 20)
        gaze_field_directed_5 = torch.pow(gaze_field_directed, 100)
        gaze_field_directed_6 = torch.pow(gaze_field_directed, 500)

        directed_gaze_field = torch.cat([
            gaze_field_directed, gaze_field_directed_2, gaze_field_directed_3,
            gaze_field_directed_4, gaze_field_directed_5, gaze_field_directed_6
        ])

        # create bbox masks

        head_bbox_mask = np.zeros((h, w))
        head_bbox_mask[int(box[1] * h):int(box[3] * h),
                       int(box[0] * w):int(box[2] * w)] = 1
        head_bbox_mask = cv2.resize(head_bbox_mask, (224, 224))
        box = box.astype(np.float32)

        sample = {
            'image': image,
            # 'face_image': face_image,
            'eye_position': torch.FloatTensor(eye),
            'gaze_field': torch.from_numpy(
                gaze_field),  # this gaze field does not have direction yet
            # direction is computed during training
            # "ground truth" directed gaze field is not used
            'gt_position': torch.FloatTensor(gaze),
            'gt_heatmap': torch.FloatTensor(heatmap).unsqueeze(0),
            'image_path': image_path,
            'original_image': original_image,
            # 'gaze_idx': gaze_idx,       # not available in this dataset
            'gaze_class': -1,  # not available in this dataset
            # 'gaze_bbox': gaze_bbox,     # not available in this dataset
            'head_bbox': torch.FloatTensor(head_bbox),
            'head_loc': eye,
            'directed_gaze_field': directed_gaze_field,
            # 'gt_bboxes': gt_bboxes,             # not available in this dataset
            # 'gt_labels': gt_labels,             # not available in this dataset
            # 'gaze_bbox_mask': gaze_bbox_mask,   # not available in this dataset
            'head_bbox_mask': head_bbox_mask,
        }

        return sample
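
Examples #4 and #6 rasterize the gazed object's segmentation with
create_mask / self.create_mask, which is not shown either. Given the
seg_mask / 255.0 normalization they apply, it presumably returns a 0/255
uint8 mask. One plausible sketch, treating seg as a polygon of
pixel-coordinate vertices on the 640x480 canvas implied by the normalization
constants used above:

import numpy as np
import cv2

def create_mask(seg, height=480, width=640):
    """Hypothetical: fill a segmentation polygon into a 0/255 uint8 mask."""
    mask = np.zeros((height, width), dtype=np.uint8)
    cv2.fillPoly(mask, [seg.reshape(-1, 2).astype(np.int32)], 255)
    return mask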
Example #6
    def __getitem__(self, idx):

        data = self.data[idx]
        image_path = data['filename']
        image_path = os.path.join(self.root_dir, image_path)

        eye = [float(data['hx']) / 640, float(data['hy']) / 480]
        gaze = [float(data['gaze_cx']) / 640, float(data['gaze_cy']) / 480]

        image_path = image_path.replace('\\', '/')
        image = cv2.imread(image_path, cv2.IMREAD_COLOR)

        flip_flag = False
        if random.random() > 0.5 and self.training == 'train':
            eye = [1.0 - eye[0], eye[1]]
            gaze = [1.0 - gaze[0], gaze[1]]
            image = cv2.flip(image, 1)
            flip_flag = True

        # crop face
        x_c, y_c = eye
        x_0 = max(x_c - 0.15, 0.0)
        y_0 = max(y_c - 0.15, 0.0)
        x_1 = min(x_c + 0.15, 1.0)
        y_1 = min(y_c + 0.15, 1.0)
        h, w = image.shape[:2]
        face_image = image[int(y_0 * h):int(y_1 * h),
                           int(x_0 * w):int(x_1 * w), :]
        # process face_image for face net
        face_image = cv2.cvtColor(face_image, cv2.COLOR_BGR2RGB)
        face_image = Image.fromarray(face_image)
        face_image = data_transforms[self.training](face_image)
        # process image for saliency net
        #image = image_preprocess(image)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = Image.fromarray(image)
        image = data_transforms[self.training](image)

        # generate gaze field
        gaze_field = self.generate_data_field(eye_point=eye)

        # Get bounding boxes and class labels, plus the gt index of the gazed object
        gt_bboxes, gt_labels = np.zeros(1), np.zeros(1)  # placeholders when bboxes are unused
        gt_labels = np.expand_dims(gt_labels, axis=0)
        gaze_idx = np.copy(data['gazeIdx']).astype(np.int64)
        gaze_class = np.copy(data['gaze_item']).astype(np.int64)
        if self.use_bboxes:
            gt_bboxes = np.copy(data['ann']['bboxes']) / [640, 480, 640, 480]
            gt_labels = np.copy(data['ann']['labels'])

        # generate Default heatmap
        heatmap = get_paste_kernel((224 // 4, 224 // 4), gaze, kernel_map,
                                   (224 // 4, 224 // 4))

        # Create GazeMask heatmap
        seg = data['seg']
        if seg is not None:
            seg_mask = self.create_mask(np.array(seg).astype(np.int64))
            seg_mask = cv2.resize(seg_mask, (224 // 4, 224 // 4))

            if flip_flag:
                seg_mask = cv2.flip(seg_mask, 1)  # horizontal flip

            seg_mask = seg_mask.astype(np.float64) / 255.0
            if self.alpha < 1.0:
                gaze_mask = self.alpha * seg_mask + (1 - self.alpha) * heatmap
            else:  # alpha >= 1.0: element-wise max (avoids gaze_mask being undefined)
                gaze_mask = np.fmax(seg_mask, heatmap)

        else:
            gaze_mask = np.zeros((224 // 4, 224 // 4))

        if self.use_gazemask:
            heatmap = gaze_mask

        sample = {
            'image': image,
            'face_image': face_image,
            'eye_position': torch.FloatTensor(eye),
            'gaze_field': torch.from_numpy(gaze_field),
            'gt_position': torch.FloatTensor(gaze),
            'gt_heatmap': torch.FloatTensor(heatmap).unsqueeze(0),
            'gt_bboxes': gt_bboxes,
            'gt_labels': gt_labels,
            'gaze_idx': gaze_idx,
            'gaze_class': gaze_class,
            'image_path': image_path
        }

        return sample
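
For context, any of these datasets would be consumed through a standard
DataLoader; the class name and constructor arguments below are placeholders,
not the originals:

from torch.utils.data import DataLoader

# Hypothetical usage; GazeDataset and its constructor arguments are placeholders.
dataset = GazeDataset(root_dir='data/gazefollow', training='train')
loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

for batch in loader:
    images = batch['image']          # (B, 3, 224, 224)
    eyes = batch['eye_position']     # (B, 2), normalized (x, y)
    heatmaps = batch['gt_heatmap']   # (B, 1, 56, 56)
    break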