Example #1
    def __getitem__(self, index):
        sf = self.scale_factor
        rf = self.rot_factor
        if self.is_train:
            a = self.anno[self.train_list[index]]
        else:
            a = self.anno[self.valid_list[index]]

        img_path = os.path.join(self.img_folder, a['img_paths'])
        pts = torch.Tensor(a['joint_self'])
        # pts[:, 0:2] -= 1  # Convert pts to zero based

        # c = torch.Tensor(a['objpos']) - 1
        c = torch.Tensor(a['objpos'])
        s = a['scale_provided']

        # Adjust center/scale slightly to avoid cropping limbs
        if c[0] != -1:
            c[1] = c[1] + 15 * s
            s = s * 1.25

        # For single-person pose estimation with a centered/scaled figure
        nparts = pts.size(0)
        img = load_image(img_path)  # CxHxW

        r = 0
        if self.is_train:
            s = s * torch.randn(1).mul_(sf).add_(1).clamp(1 - sf, 1 + sf)[0]
            r = torch.randn(1).mul_(rf).clamp(
                -2 * rf, 2 * rf)[0] if random.random() <= 0.6 else 0

            # Flip
            if random.random() <= 0.5:
                img = fliplr(img)
                pts = shufflelr(pts, img.size(2), self.DATA_INFO.hflip_indices)
                c[0] = img.size(2) - c[0]

            # Color
            img[0, :, :].mul_(random.uniform(0.8, 1.2)).clamp_(0, 1)
            img[1, :, :].mul_(random.uniform(0.8, 1.2)).clamp_(0, 1)
            img[2, :, :].mul_(random.uniform(0.8, 1.2)).clamp_(0, 1)

        # Prepare image and groundtruth map
        inp = crop(img, c, s, self.inp_res, rot=r)
        inp = color_normalize(inp, self.DATA_INFO.rgb_mean,
                              self.DATA_INFO.rgb_stddev)

        # Generate ground truth
        tpts = pts.clone()
        target = torch.zeros(nparts, *self.out_res)
        target_weight = tpts[:, 2].clone().view(nparts, 1)

        for i in range(nparts):
            # if tpts[i, 2] > 0: # This is evil!!
            if tpts[i, 1] > 0:
                tpts[i, 0:2] = to_torch(
                    transform(tpts[i, 0:2] + 1, c, s, self.out_res, rot=r))
                target[i], vis = draw_labelmap(target[i],
                                               tpts[i] - 1,
                                               self.sigma,
                                               type=self.label_type)
                target_weight[i, 0] *= vis

        # Meta info
        if not isinstance(s, torch.Tensor):
            s = torch.tensor(s)  # wrap the scalar scale; torch.Tensor(s) would misread a bare number as a size

        meta = {
            'index': index,
            'center': c,
            'scale': s,
            'pts': pts,
            'tpts': tpts,
            'target_weight': target_weight
        }

        return inp, target, meta
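The `draw_labelmap` helper used above (and, as `imutils.draw_labelmap`, in the gaze examples further down) is not part of the listing. Below is a minimal sketch, assuming the usual Gaussian variant that splats a 2D Gaussian of standard deviation `sigma` at `pt = (x, y)` into an H x W heatmap and, as the call `target[i], vis = draw_labelmap(...)` suggests, also returns a visibility flag. The gaze examples' `imutils` version appears to return only the heatmap.

import torch

def draw_labelmap(heatmap, pt, sigma, type='Gaussian'):
    # Hypothetical sketch, not the repository implementation: write a 2D
    # Gaussian of std-dev `sigma` centred at pt = (x, y) into the H x W
    # `heatmap`, and return a visibility flag that is 0 when the point falls
    # outside the map. Only the 'Gaussian' branch is sketched here.
    h, w = heatmap.shape
    x0, y0 = float(pt[0]), float(pt[1])
    if x0 < 0 or y0 < 0 or x0 >= w or y0 >= h:
        return heatmap, 0
    ys = torch.arange(h, dtype=torch.float32).view(-1, 1)
    xs = torch.arange(w, dtype=torch.float32).view(1, -1)
    g = torch.exp(-((xs - x0) ** 2 + (ys - y0) ** 2) / (2 * sigma ** 2))
    return torch.max(heatmap, g), 1  # element-wise max keeps existing peaks when maps overlap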
    def __getitem__(self, index):
        sequence_path = self.all_sequence_paths[index]
        df = pd.read_csv(
            sequence_path,
            header=None,
            index_col=False,
            names=['path', 'xmin', 'ymin', 'xmax', 'ymax', 'gazex', 'gazey'])
        show_name = sequence_path.split('/')[-3]
        clip = sequence_path.split('/')[-2]
        seq_len = len(df.index)

        # moving-avg smoothing
        window_size = 11  # should be an odd number
        df['xmin'] = myutils.smooth_by_conv(window_size, df, 'xmin')
        df['ymin'] = myutils.smooth_by_conv(window_size, df, 'ymin')
        df['xmax'] = myutils.smooth_by_conv(window_size, df, 'xmax')
        df['ymax'] = myutils.smooth_by_conv(window_size, df, 'ymax')

        if not self.test:
            # cond for data augmentation
            cond_jitter = np.random.random_sample()
            cond_flip = np.random.random_sample()
            cond_color = np.random.random_sample()
            if cond_color < 0.5:
                n1 = np.random.uniform(0.5, 1.5)
                n2 = np.random.uniform(0.5, 1.5)
                n3 = np.random.uniform(0.5, 1.5)
            cond_crop = np.random.random_sample()

            # If the sequence is longer than seq_len_limit, cut it down to the limit with a randomly sampled start index
            if seq_len > self.seq_len_limit:
                sampled_ind = np.random.randint(0,
                                                seq_len - self.seq_len_limit)
                seq_len = self.seq_len_limit
            else:
                sampled_ind = 0

            if cond_crop < 0.5:
                sliced_x_min = df['xmin'].iloc[sampled_ind:sampled_ind +
                                               seq_len]
                sliced_x_max = df['xmax'].iloc[sampled_ind:sampled_ind +
                                               seq_len]
                sliced_y_min = df['ymin'].iloc[sampled_ind:sampled_ind +
                                               seq_len]
                sliced_y_max = df['ymax'].iloc[sampled_ind:sampled_ind +
                                               seq_len]

                sliced_gaze_x = df['gazex'].iloc[sampled_ind:sampled_ind +
                                                 seq_len]
                sliced_gaze_y = df['gazey'].iloc[sampled_ind:sampled_ind +
                                                 seq_len]

                check_sum = sliced_gaze_x.sum() + sliced_gaze_y.sum()
                all_outside = check_sum == -2 * seq_len

                # Calculate the smallest crop that keeps the face (and the gaze target, when it falls inside the frame)
                if all_outside:
                    crop_x_min = np.min(
                        [sliced_x_min.min(),
                         sliced_x_max.min()])
                    crop_y_min = np.min(
                        [sliced_y_min.min(),
                         sliced_y_max.min()])
                    crop_x_max = np.max(
                        [sliced_x_min.max(),
                         sliced_x_max.max()])
                    crop_y_max = np.max(
                        [sliced_y_min.max(),
                         sliced_y_max.max()])
                else:
                    crop_x_min = np.min([
                        sliced_gaze_x.min(),
                        sliced_x_min.min(),
                        sliced_x_max.min()
                    ])
                    crop_y_min = np.min([
                        sliced_gaze_y.min(),
                        sliced_y_min.min(),
                        sliced_y_max.min()
                    ])
                    crop_x_max = np.max([
                        sliced_gaze_x.max(),
                        sliced_x_min.max(),
                        sliced_x_max.max()
                    ])
                    crop_y_max = np.max([
                        sliced_gaze_y.max(),
                        sliced_y_min.max(),
                        sliced_y_max.max()
                    ])

                # Randomly select the crop's top-left corner
                if crop_x_min >= 0:
                    crop_x_min = np.random.uniform(0, crop_x_min)
                if crop_y_min >= 0:
                    crop_y_min = np.random.uniform(0, crop_y_min)

                # Get image size
                path = os.path.join(self.data_dir, show_name, clip,
                                    df['path'].iloc[0])
                img = Image.open(path)
                img = img.convert('RGB')
                width, height = img.size

                # Find the range of valid crop width and height starting from the (crop_x_min, crop_y_min)
                crop_width_min = crop_x_max - crop_x_min
                crop_height_min = crop_y_max - crop_y_min
                crop_width_max = width - crop_x_min
                crop_height_max = height - crop_y_min
                # Randomly select a width and a height
                crop_width = np.random.uniform(crop_width_min, crop_width_max)
                crop_height = np.random.uniform(crop_height_min,
                                                crop_height_max)
        else:
            sampled_ind = 0


        faces, images, head_channels, heatmaps, paths, gazes, imsizes, gaze_inouts = [], [], [], [], [], [], [], []
        index_tracker = -1
        for i, row in df.iterrows():
            index_tracker = index_tracker + 1
            if not self.test:
                if index_tracker < sampled_ind or index_tracker >= (
                        sampled_ind + self.seq_len_limit):
                    continue

            face_x1 = row['xmin']  # note: Already in image coordinates
            face_y1 = row['ymin']  # note: Already in image coordinates
            face_x2 = row['xmax']  # note: Already in image coordinates
            face_y2 = row['ymax']  # note: Already in image coordinates
            gaze_x = row['gazex']  # note: Already in image coordinates
            gaze_y = row['gazey']  # note: Already in image coordinates

            impath = os.path.join(self.data_dir, show_name, clip, row['path'])
            img = Image.open(impath)
            img = img.convert('RGB')

            width, height = img.size
            imsize = torch.FloatTensor([width, height])
            # imsizes.append(imsize)

            face_x1, face_y1, face_x2, face_y2 = map(
                float, [face_x1, face_y1, face_x2, face_y2])
            gaze_x, gaze_y = map(float, [gaze_x, gaze_y])
            if gaze_x == -1 and gaze_y == -1:
                gaze_inside = False
            else:
                if gaze_x < 0:  # move a gaze point that was slightly outside the image back inside
                    gaze_x = 0
                if gaze_y < 0:
                    gaze_y = 0
                gaze_inside = True

            if not self.test:
                ## data augmentation
                # Jitter (expansion-only) bounding box size.
                if cond_jitter < 0.5:
                    k = cond_jitter * 0.1
                    face_x1 -= k * abs(face_x2 - face_x1)
                    face_y1 -= k * abs(face_y2 - face_y1)
                    face_x2 += k * abs(face_x2 - face_x1)
                    face_y2 += k * abs(face_y2 - face_y1)
                    face_x1 = np.clip(face_x1, 0, width)
                    face_x2 = np.clip(face_x2, 0, width)
                    face_y1 = np.clip(face_y1, 0, height)
                    face_y2 = np.clip(face_y2, 0, height)

                # Random Crop
                if cond_crop < 0.5:
                    # Crop it
                    img = TF.crop(img, crop_y_min, crop_x_min, crop_height,
                                  crop_width)

                    # Record the crop's (x, y) offset
                    offset_x, offset_y = crop_x_min, crop_y_min

                    # convert coordinates into the cropped frame
                    face_x1, face_y1, face_x2, face_y2 = face_x1 - offset_x, face_y1 - offset_y, face_x2 - offset_x, face_y2 - offset_y
                    if gaze_inside:
                        gaze_x, gaze_y = gaze_x - offset_x, gaze_y - offset_y
                    else:
                        gaze_x = -1
                        gaze_y = -1

                    width, height = crop_width, crop_height

                # Flip?
                if cond_flip < 0.5:
                    img = img.transpose(Image.FLIP_LEFT_RIGHT)
                    x_max_2 = width - face_x1
                    x_min_2 = width - face_x2
                    face_x2 = x_max_2
                    face_x1 = x_min_2
                    if gaze_x != -1 and gaze_y != -1:
                        gaze_x = width - gaze_x

                # Random color change
                if cond_color < 0.5:
                    img = TF.adjust_brightness(img, brightness_factor=n1)
                    img = TF.adjust_contrast(img, contrast_factor=n2)
                    img = TF.adjust_saturation(img, saturation_factor=n3)

            # Face crop
            face = img.copy().crop(
                (int(face_x1), int(face_y1), int(face_x2), int(face_y2)))

            # Head channel image
            head_channel = imutils.get_head_box_channel(
                face_x1,
                face_y1,
                face_x2,
                face_y2,
                width,
                height,
                resolution=self.input_size,
                coordconv=False).unsqueeze(0)
            if self.transform is not None:
                img = self.transform(img)
                face = self.transform(face)

            # Deconv output
            if gaze_inside:
                gaze_x /= float(width)  # fractional gaze
                gaze_y /= float(height)
                gaze_heatmap = torch.zeros(
                    self.output_size,
                    self.output_size)  # set the size of the output
                gaze_map = imutils.draw_labelmap(
                    gaze_heatmap,
                    [gaze_x * self.output_size, gaze_y * self.output_size],
                    3,
                    type='Gaussian')
                gazes.append(torch.FloatTensor([gaze_x, gaze_y]))
            else:
                gaze_map = torch.zeros(self.output_size, self.output_size)
                gazes.append(torch.FloatTensor([-1, -1]))
            faces.append(face)
            images.append(img)
            head_channels.append(head_channel)
            heatmaps.append(gaze_map)
            gaze_inouts.append(torch.FloatTensor([int(gaze_inside)]))

        if self.imshow:
            for i in range(len(faces)):
                fig = plt.figure(111)
                img = 255 - imutils.unnorm(images[i].numpy()) * 255
                img = np.clip(img, 0, 255)
                plt.imshow(np.transpose(img, (1, 2, 0)))
                plt.imshow(imresize(heatmaps[i],
                                    (self.input_size, self.input_size)),
                           cmap='jet',
                           alpha=0.3)
                plt.imshow(imresize(1 - head_channels[i].squeeze(0),
                                    (self.input_size, self.input_size)),
                           alpha=0.2)
                plt.savefig(
                    os.path.join('debug',
                                 'viz_%d_inout=%d.png' % (i, int(gaze_inouts[i]))))
                plt.close('all')

        faces = torch.stack(faces)
        images = torch.stack(images)
        head_channels = torch.stack(head_channels)
        heatmaps = torch.stack(heatmaps)
        gazes = torch.stack(gazes)
        gaze_inouts = torch.stack(gaze_inouts)
        # imsizes = torch.stack(imsizes)
        # print(faces.shape, images.shape, head_channels.shape, heatmaps.shape)

        if self.test:
            return images, faces, head_channels, heatmaps, gazes, gaze_inouts
        else:  # train
            return images, faces, head_channels, heatmaps, gaze_inouts
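`myutils.smooth_by_conv` is not shown either. A minimal sketch of one plausible implementation, assuming it takes (window_size, dataframe, column name) and returns a same-length moving average computed by convolution with edge padding:

import numpy as np

def smooth_by_conv(window_size, df, col):
    # Hypothetical stand-in for myutils.smooth_by_conv: same-length moving
    # average over one DataFrame column. Edges are padded by repeating the
    # first/last value so boundary frames still average over a full window.
    pad = window_size // 2
    values = df[col].to_numpy(dtype=float)
    padded = np.concatenate([np.full(pad, values[0]), values, np.full(pad, values[-1])])
    kernel = np.ones(window_size) / window_size
    return np.convolve(padded, kernel, mode='valid')  # output length == len(values)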
    def __getitem__(self, index):
        if self.test:
            g = self.X_test.get_group(self.keys[index])
            cont_gaze = []
            for i, row in g.iterrows():
                path = row['path']
                x_min = row['bbox_x_min']
                y_min = row['bbox_y_min']
                x_max = row['bbox_x_max']
                y_max = row['bbox_y_max']
                eye_x = row['eye_x']
                eye_y = row['eye_y']
                gaze_x = row['gaze_x']
                gaze_y = row['gaze_y']
                cont_gaze.append([gaze_x, gaze_y])  # all ground-truth gaze points are stacked up
            for j in range(len(cont_gaze), 20):
                cont_gaze.append([-1, -1])  # pad with dummy gaze points so batches have a fixed size
            cont_gaze = torch.FloatTensor(cont_gaze)
            gaze_inside = True  # always consider test samples as inside

        else:
            path = self.X_train.iloc[index]
            eye_x, eye_y, gaze_x, gaze_y = self.y_train.iloc[index]
            gaze_inside = True  # bool(inout)

        img = Image.open(os.path.join(self.data_dir, path))
        img = img.convert('RGB')
        width, height = img.size
        # print('gaze coords: ', type(gaze_x), type(gaze_y), gaze_x, gaze_y)
        # print('eye coords: ', type(eye_x), type(eye_y), eye_x, eye_y)
        # expand face bbox a bit
        k = 0.1
        x_min = (eye_x - 0.15) * width
        y_min = (eye_y - 0.15) * height
        x_max = (eye_x + 0.15) * width
        y_max = (eye_y + 0.15) * height
        if x_min < 0:
            x_min = 0
        if y_min < 0:
            y_min = 0
        if x_max < 0:
            x_max = 0
        if y_max < 0:
            y_max = 0
        x_min -= k * abs(x_max - x_min)
        y_min -= k * abs(y_max - y_min)
        x_max += k * abs(x_max - x_min)
        y_max += k * abs(y_max - y_min)

        # x_min = eye_x - 0.15
        # y_min = eye_y - 0.15
        # x_max = eye_x + 0.15
        # y_max = eye_y + 0.15
        # if x_min < 0:
        #     x_min = 0
        # if y_min < 0:
        #     y_min = 0
        # if x_max < 0:
        #     x_max = 0
        # if y_max < 0:
        #     y_max = 0

        # print('bbx',  [x_min, y_min, x_max, y_max])

        x_min, y_min, x_max, y_max = map(float, [x_min, y_min, x_max, y_max])
        # print(x_min, y_min, x_max, y_max)
        if self.imshow:
            img.save("origin_img.jpg")

        if self.test:
            imsize = torch.IntTensor([width, height])
        else:
            ## data augmentation

            # Jitter (expansion-only) bounding box size
            if np.random.random_sample() <= 0.5:
                k = np.random.random_sample() * 0.2
                x_min -= k * abs(x_max - x_min)
                y_min -= k * abs(y_max - y_min)
                x_max += k * abs(x_max - x_min)
                y_max += k * abs(y_max - y_min)

            # Random Crop
            if np.random.random_sample() <= 0.5:
                # Calculate the smallest crop that still contains both the face and the gaze target
                crop_x_min = np.min([gaze_x * width, x_min, x_max])
                crop_y_min = np.min([gaze_y * height, y_min, y_max])
                crop_x_max = np.max([gaze_x * width, x_min, x_max])
                crop_y_max = np.max([gaze_y * height, y_min, y_max])

                # Randomly select the crop's top-left corner
                if crop_x_min >= 0:
                    crop_x_min = np.random.uniform(0, crop_x_min)
                if crop_y_min >= 0:
                    crop_y_min = np.random.uniform(0, crop_y_min)

                # Find the range of valid crop width and height starting from the (crop_x_min, crop_y_min)
                crop_width_min = crop_x_max - crop_x_min
                crop_height_min = crop_y_max - crop_y_min
                crop_width_max = width - crop_x_min
                crop_height_max = height - crop_y_min
                # Randomly select a width and a height
                crop_width = np.random.uniform(crop_width_min, crop_width_max)
                crop_height = np.random.uniform(crop_height_min,
                                                crop_height_max)

                # Crop it
                img = TF.crop(img, crop_y_min, crop_x_min, crop_height,
                              crop_width)

                # Record the crop's (x, y) offset
                offset_x, offset_y = crop_x_min, crop_y_min

                # convert coordinates into the cropped frame
                x_min, y_min, x_max, y_max = x_min - offset_x, y_min - offset_y, x_max - offset_x, y_max - offset_y
                # if gaze_inside:
                gaze_x, gaze_y = (gaze_x * width - offset_x) / float(crop_width), \
                                 (gaze_y * height - offset_y) / float(crop_height)
                # else:
                #     gaze_x = -1; gaze_y = -1

                width, height = crop_width, crop_height

            # Random flip
            if np.random.random_sample() <= 0.5:
                img = img.transpose(Image.FLIP_LEFT_RIGHT)
                x_max_2 = width - x_min
                x_min_2 = width - x_max
                x_max = x_max_2
                x_min = x_min_2
                gaze_x = 1 - gaze_x

            # Random color change
            if np.random.random_sample() <= 0.5:
                img = TF.adjust_brightness(img,
                                           brightness_factor=np.random.uniform(
                                               0.5, 1.5))
                img = TF.adjust_contrast(img,
                                         contrast_factor=np.random.uniform(
                                             0.5, 1.5))
                img = TF.adjust_saturation(img,
                                           saturation_factor=np.random.uniform(
                                               0, 1.5))
        # print('bbx2',  [x_min, y_min, x_max, y_max])

        head_channel = imutils.get_head_box_channel(
            x_min,
            y_min,
            x_max,
            y_max,
            width,
            height,
            resolution=self.input_size,
            coordconv=False).unsqueeze(0)

        # Crop the face
        face = img.crop((int(x_min), int(y_min), int(x_max), int(y_max)))

        if self.imshow:
            img.save("img_aug.jpg")
            face.save('face_aug.jpg')

        if self.transform is not None:
            img = self.transform(img)
            face = self.transform(face)
        # print('imsize2', img.size())

        # generate the heat map used for deconv prediction
        gaze_heatmap = torch.zeros(
            self.output_size, self.output_size)  # set the size of the output
        # print([gaze_x * self.output_size, gaze_y * self.output_size])
        # print(self.output_size)
        if self.test:  # aggregated heatmap
            num_valid = 0
            for gaze_x, gaze_y in cont_gaze:
                if gaze_x != -1:
                    num_valid += 1
                    gaze_heatmap = imutils.draw_labelmap(
                        gaze_heatmap,
                        [gaze_x * self.output_size, gaze_y * self.output_size],
                        3,
                        type='Gaussian')
            gaze_heatmap /= num_valid
        else:
            # if gaze_inside:
            gaze_heatmap = imutils.draw_labelmap(
                gaze_heatmap,
                [gaze_x * self.output_size, gaze_y * self.output_size],
                3,
                type='Gaussian')

        if self.imshow:
            fig = plt.figure(111)
            img = 255 - imutils.unnorm(img.numpy()) * 255
            img = np.clip(img, 0, 255)
            plt.imshow(np.transpose(img, (1, 2, 0)))
            plt.imshow(imresize(gaze_heatmap,
                                (self.input_size, self.input_size)),
                       cmap='jet',
                       alpha=0.3)
            plt.imshow(imresize(1 - head_channel.squeeze(0),
                                (self.input_size, self.input_size)),
                       alpha=0.2)
            plt.savefig('viz_aug.png')

        if self.test:
            return img, face, head_channel, gaze_heatmap, cont_gaze, imsize, path
        else:
            return img, face, head_channel, gaze_heatmap, path, gaze_inside
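`imutils.get_head_box_channel` is also external. A rough sketch of what such a helper typically produces, assuming a binary resolution x resolution map that is 1 inside the head box (rescaled from image to map coordinates) and 0 elsewhere; the real version may differ, and the `coordconv` variant is not sketched:

import torch

def get_head_box_channel(x_min, y_min, x_max, y_max, width, height, resolution, coordconv=False):
    # Hypothetical sketch of the head-channel helper used above: scale the head
    # box from image coordinates to the resolution x resolution grid and fill
    # the box with ones. The coordconv option of the real helper is ignored here.
    head_box = torch.tensor([x_min / width, y_min / height,
                             x_max / width, y_max / height]) * resolution
    x1, y1, x2, y2 = head_box.round().int().clamp(0, resolution).tolist()
    channel = torch.zeros(resolution, resolution)
    channel[y1:y2, x1:x2] = 1.0
    return channel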
def test():
    transform = _get_transform()

    # Prepare data
    print("Loading Data")
    val_dataset = VideoAttTarget_video(videoattentiontarget_val_data,
                                       videoattentiontarget_val_label,
                                       transform=transform,
                                       test=True,
                                       seq_len_limit=50)
    val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=0,
                                             collate_fn=video_pack_sequences)

    # Define device
    device = torch.device('cuda', args.device)

    # Load model
    num_lstm_layers = 2
    print("Constructing model")
    model = ModelSpatioTemporal(num_lstm_layers=num_lstm_layers)
    model.cuda(device)

    print("Loading weights")
    model_dict = model.state_dict()
    snapshot = torch.load(args.model_weights)
    snapshot = snapshot['model']
    model_dict.update(snapshot)
    model.load_state_dict(model_dict)

    print('Evaluation in progress ...')
    model.train(False)
    AUC = []
    in_vs_out_groundtruth = []
    in_vs_out_pred = []
    distance = []
    chunk_size = 3
    with torch.no_grad():
        for batch_val, (img_val, face_val, head_channel_val, gaze_heatmap_val,
                        cont_gaze, inout_label_val,
                        lengths_val) in enumerate(val_loader):
            print('\tprogress = ', batch_val + 1, '/', len(val_loader))
            X_pad_data_img, X_pad_sizes = pack_padded_sequence(
                img_val, lengths_val, batch_first=True)
            X_pad_data_head, _ = pack_padded_sequence(head_channel_val,
                                                      lengths_val,
                                                      batch_first=True)
            X_pad_data_face, _ = pack_padded_sequence(face_val,
                                                      lengths_val,
                                                      batch_first=True)
            Y_pad_data_cont_gaze, _ = pack_padded_sequence(cont_gaze,
                                                           lengths_val,
                                                           batch_first=True)
            Y_pad_data_heatmap, _ = pack_padded_sequence(gaze_heatmap_val,
                                                         lengths_val,
                                                         batch_first=True)
            Y_pad_data_inout, _ = pack_padded_sequence(inout_label_val,
                                                       lengths_val,
                                                       batch_first=True)

            hx = (torch.zeros(
                (num_lstm_layers, args.batch_size, 512, 7, 7)).cuda(device),
                  torch.zeros((num_lstm_layers, args.batch_size, 512, 7,
                               7)).cuda(device)
                  )  # (num_layers, batch_size, feature dims)
            last_index = 0
            previous_hx_size = args.batch_size

            for i in range(0, lengths_val[0], chunk_size):
                X_pad_sizes_slice = X_pad_sizes[i:i + chunk_size].cuda(device)
                curr_length = np.sum(X_pad_sizes_slice.cpu().detach().numpy())
                # slice padded data
                X_pad_data_slice_img = X_pad_data_img[last_index:last_index +
                                                      curr_length].cuda(device)
                X_pad_data_slice_head = X_pad_data_head[last_index:last_index +
                                                        curr_length].cuda(
                                                            device)
                X_pad_data_slice_face = X_pad_data_face[last_index:last_index +
                                                        curr_length].cuda(
                                                            device)
                Y_pad_data_slice_cont_gaze = Y_pad_data_cont_gaze[
                    last_index:last_index + curr_length].cuda(device)
                Y_pad_data_slice_heatmap = Y_pad_data_heatmap[
                    last_index:last_index + curr_length].cuda(device)
                Y_pad_data_slice_inout = Y_pad_data_inout[
                    last_index:last_index + curr_length].cuda(device)
                last_index += curr_length

                # detach previous hidden states to stop gradient flow
                keep = min(X_pad_sizes_slice[0], previous_hx_size)
                prev_hx = (hx[0][:, :keep, :, :, :].detach(),
                           hx[1][:, :keep, :, :, :].detach())

                # forward pass
                deconv, inout_val, hx = model(X_pad_data_slice_img, X_pad_data_slice_head, X_pad_data_slice_face, \
                                                         hidden_scene=prev_hx, batch_sizes=X_pad_sizes_slice)

                for b_i in range(len(Y_pad_data_slice_cont_gaze)):
                    if Y_pad_data_slice_inout[b_i]:  # ONLY for 'inside' cases
                        # AUC: area under curve of ROC
                        multi_hot = torch.zeros(
                            output_resolution,
                            output_resolution)  # set the size of the output
                        gaze_x = Y_pad_data_slice_cont_gaze[b_i, 0]
                        gaze_y = Y_pad_data_slice_cont_gaze[b_i, 1]
                        multi_hot = imutils.draw_labelmap(
                            multi_hot,
                            [gaze_x * output_resolution,
                             gaze_y * output_resolution],
                            3,
                            type='Gaussian')
                        multi_hot = (multi_hot > 0).float()  # binarize the ground-truth heatmap
                        multi_hot = misc.to_numpy(multi_hot)

                        scaled_heatmap = imresize(
                            deconv[b_i].squeeze(),
                            (output_resolution, output_resolution),
                            interp='bilinear')
                        auc_score = evaluation.auc(scaled_heatmap, multi_hot)
                        AUC.append(auc_score)

                        # distance: L2 distance between ground truth and argmax point
                        pred_x, pred_y = evaluation.argmax_pts(
                            deconv[b_i].squeeze())
                        norm_p = [
                            pred_x / output_resolution,
                            pred_y / output_resolution
                        ]
                        dist_score = evaluation.L2_dist(
                            Y_pad_data_slice_cont_gaze[b_i], norm_p).item()
                        distance.append(dist_score)

                # in vs out classification
                in_vs_out_groundtruth.extend(
                    Y_pad_data_slice_inout.cpu().numpy())
                in_vs_out_pred.extend(inout_val.cpu().numpy())

                previous_hx_size = X_pad_sizes_slice[-1]

            try:
                print("\tAUC:{:.4f}"
                      "\tdist:{:.4f}"
                      "\tin vs out AP:{:.4f}".format(
                          torch.mean(torch.tensor(AUC)),
                          torch.mean(torch.tensor(distance)),
                          evaluation.ap(in_vs_out_groundtruth,
                                        in_vs_out_pred)))
            except Exception:
                pass  # metrics are undefined until at least one 'inside' frame has been scored

    print("Summary ")
    print("\tAUC:{:.4f}"
          "\tdist:{:.4f}"
          "\tin vs out AP:{:.4f}".format(
              torch.mean(torch.tensor(AUC)),
              torch.mean(torch.tensor(distance)),
              evaluation.ap(in_vs_out_groundtruth, in_vs_out_pred)))
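
`evaluation.auc` is likewise not shown. One common way to compute this heatmap AUC, assuming the predicted map is scored pixel-wise against the binarized ground-truth map (as the `evaluation.auc(scaled_heatmap, multi_hot)` call above suggests), is the ROC AUC over flattened pixels:

import numpy as np
from sklearn.metrics import roc_auc_score

def auc(pred_heatmap, multi_hot):
    # Hypothetical sketch of the AUC metric: every pixel of the binarized
    # ground-truth map is a label and the predicted heatmap value at that pixel
    # is its score; return the area under the resulting ROC curve.
    scores = np.asarray(pred_heatmap, dtype=float).reshape(-1)
    labels = np.asarray(multi_hot, dtype=int).reshape(-1)
    return roc_auc_score(labels, scores)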