def get_ims(slice_idx, vol, spacing, slice_intv, mask_list):
    """Prepare the network input tensors for one slice of a volume.

    The slice and its neighbours are loaded and preprocessed by
    ``load_prep_img``; the configured pixel mean is subtracted, the array is
    moved to channels-first layout and converted to a float tensor, and the
    tensor is split along its first dimension into chunks of
    ``cfg.INPUT.NUM_IMAGES_3DCE`` channels.

    Args:
        slice_idx: Forwarded to ``load_prep_img`` (the slice to load).
        vol: Forwarded to ``load_prep_img`` (the source volume).
        spacing: Forwarded to ``load_prep_img``.
        slice_intv: Forwarded to ``load_prep_img``.
        mask_list: Forwarded to ``load_prep_img``; the value it returns in
            this position is passed back to the caller.

    Returns:
        tuple: ``(ims, slice_2d, im_scale, crop, mask_list)`` where ``ims`` is
        the tuple of tensor chunks, ``slice_2d`` is one channel of the
        preprocessed (not mean-subtracted) array, and the remaining items are
        returned by ``load_prep_img`` unchanged.
    """
    input_cfg = cfg.INPUT
    total_slices = input_cfg.NUM_SLICES * input_cfg.NUM_IMAGES_3DCE

    prepared, scale, crop_box, masks = load_prep_img(
        vol, slice_idx, spacing, slice_intv, mask_list,
        input_cfg.IMG_DO_CLIP, num_slice=total_slices,
    )

    # HWC -> CHW, then float tensor, after removing the dataset pixel mean.
    centered = prepared - input_cfg.PIXEL_MEAN
    tensor = torch.from_numpy(centered.transpose((2, 0, 1))).to(dtype=torch.float)
    chunks = tensor.split(input_cfg.NUM_IMAGES_3DCE)

    # NOTE(review): the index is num_slice/2 + 1 rather than the arithmetic
    # middle of the stack -- confirm this matches load_prep_img's ordering.
    shown_slice = prepared[:, :, int(total_slices / 2) + 1]
    return chunks, shown_slice, scale, crop_box, masks
    def __getitem__(self, index):
        """
        Load one image and build its detection target and metadata.

        In the 'train' split, when ``cfg.INPUT.DATA_AUG_3D`` is not ``False``,
        the slice that is loaded may be replaced by a randomly chosen
        neighbouring slice (only if that file exists on disk). Boxes and RECIST
        coordinates are shifted by the crop offset (when clipping is enabled)
        and multiplied by the image scale returned by ``load_prep_img``.

        Args:
            index (int): Index

        Returns:
            tuple: (image, target, info).
                image: channels-first float tensor (after ``self.transforms``
                    if one is set).
                target: ``BoxList`` in "xyxy" mode with a "labels" field and,
                    depending on ``cfg``, "tags", "reliable_neg_tags",
                    "manual_annot_test_tags" and "masks" fields.
                info: dict with keys 'im_index', 'lesion_idxs', 'image_fn',
                    'diameters', 'crop', 'recists', 'window', 'spacing',
                    'im_scale', 'gender', 'age', 'z_coord'.
        """
        image_fn = self.image_fn_list[index]
        lesion_idx_grouped = self.lesion_idx_grouped[index]
        boxes0 = self.boxes[lesion_idx_grouped]
        slice_intv = self.slice_intv[lesion_idx_grouped][0]
        spacing = self.spacing[lesion_idx_grouped][0]
        # Copied because the coordinates are shifted/scaled in place further
        # down; without the copy an index that yields a view would corrupt
        # self.d_coordinate across calls (boxes get the same treatment below).
        recists = self.d_coordinate[lesion_idx_grouped].copy()
        diameters = self.diameter[lesion_idx_grouped]
        window = self.DICOM_window[lesion_idx_grouped][0]
        gender = float(self.gender[lesion_idx_grouped][0] == 'M')  # 1.0 for 'M', else 0.0
        age = self.age[lesion_idx_grouped][0] / 100
        if np.isnan(age) or age == 0:
            age = .5  # fallback used when age is NaN or 0
        z_coord = self.norm_location[lesion_idx_grouped[0], 2]

        num_slice = cfg.INPUT.NUM_SLICES * cfg.INPUT.NUM_IMAGES_3DCE
        is_train = self.split == 'train'
        if is_train and cfg.INPUT.DATA_AUG_3D is not False:
            # Maximum slice shift, derived from the smallest diameter so the
            # shifted slice stays close to the lesion; lesion should not be too small.
            slice_radius = diameters.min() / 2 * spacing / slice_intv * abs(cfg.INPUT.DATA_AUG_3D)
            slice_radius = int(slice_radius)
            if slice_radius > 0:
                if cfg.INPUT.DATA_AUG_3D > 0:
                    # uniform over 0..slice_radius
                    delta = np.random.randint(0, slice_radius + 1)
                else:  # give central slice higher prob
                    candidates = np.arange(slice_radius + 1)
                    weights = slice_radius - candidates.astype(float)
                    delta = np.random.choice(candidates, p=weights / weights.sum())
                # Scalar draw: avoids relying on the truth value of a 1-element array.
                if np.random.rand() > .5:
                    delta = -delta

                dirname, slicename = image_fn.split(os.sep)
                slice_idx = int(slicename[:-4])  # strip the 4-character extension
                image_fn1 = '%s%s%03d.png' % (dirname, os.sep, slice_idx + delta)
                # Only switch slices if the shifted one actually exists.
                if os.path.exists(os.path.join(self.data_path, image_fn1)):
                    image_fn = image_fn1
        im, im_scale, crop = load_prep_img(self.data_path, image_fn, spacing, slice_intv,
                                           cfg.INPUT.IMG_DO_CLIP, num_slice=num_slice, is_train=is_train)

        im -= cfg.INPUT.PIXEL_MEAN
        im = torch.from_numpy(im.transpose((2, 0, 1))).to(dtype=torch.float)

        # Map annotations into the coordinate frame of the loaded image:
        # subtract the crop offset (if clipping is on), then apply the scale.
        do_clip = cfg.INPUT.IMG_DO_CLIP
        offset = [crop[2], crop[0]] if do_clip else None

        boxes_new = boxes0.copy()
        if do_clip:
            boxes_new -= offset * 2  # offset repeated for the two box corners
        boxes_new *= im_scale
        boxes = torch.as_tensor(boxes_new).reshape(-1, 4)  # guard against no boxes
        target = BoxList(boxes, (im.shape[2], im.shape[1]), mode="xyxy")

        num_boxes = boxes.shape[0]
        classes = torch.ones(num_boxes, dtype=torch.int)  # lesion/nonlesion
        target.add_field("labels", classes)
        if cfg.MODEL.TAG_ON:
            tags = torch.zeros(num_boxes, self.num_tags, dtype=torch.int)
            reliable_neg_tags = torch.zeros(num_boxes, self.num_tags, dtype=torch.int)
            for box_i in range(num_boxes):
                lesion_idx = lesion_idx_grouped[box_i]
                if lesion_idx in self.lesion_tags:
                    pos_tags = self.lesion_tags[lesion_idx]
                    tags[box_i, pos_tags] = 1
                    # Tags exclusive with a positive tag (and not themselves
                    # positive) are marked as reliable negatives.
                    ex_tags = [ex for tag in pos_tags for ex in self.exclusive_list[tag]
                               if ex not in pos_tags]
                    reliable_neg_tags[box_i, ex_tags] = 1
                else:
                    # no tag exist for this lesion, the loss weights for this lesion should be zero
                    tags[box_i] = -1
                    reliable_neg_tags[box_i] = -1
            target.add_field("tags", tags)
            target.add_field("reliable_neg_tags", reliable_neg_tags)

            if self.split == 'test':
                manual_tags = torch.zeros(num_boxes, self.num_tags, dtype=torch.int)
                for box_i in range(num_boxes):
                    lesion_idx = lesion_idx_grouped[box_i]
                    if lesion_idx in self.manual_annot_test_tags:
                        manual_tags[box_i, self.manual_annot_test_tags[lesion_idx]] = 1
                    else:
                        # no tag exist for this lesion, the loss weights for this lesion should be zero
                        manual_tags[box_i] = -1
                target.add_field("manual_annot_test_tags", manual_tags)

        if do_clip:
            recists -= offset * 4  # offset repeated for every (x, y) pair
        recists *= im_scale
        if cfg.MODEL.MASK_ON:
            masks = []
            for recist in recists:
                mask = gen_mask_polygon_from_recist(recist)
                masks.append([mask])
            masks = SegmentationMask(masks, (im.shape[-1], im.shape[-2]))
            target.add_field("masks", masks)

        target = target.clip_to_image(remove_empty=False)

        if self.transforms is not None:
            im, target = self.transforms(im, target)

        infos = {'im_index': index, 'lesion_idxs': lesion_idx_grouped, 'image_fn': image_fn, 'diameters': diameters*spacing,
                 'crop': crop, 'recists': recists, 'window': window, 'spacing': spacing, 'im_scale': im_scale,
                 'gender': gender, 'age': age, 'z_coord': z_coord}
        return im, target, infos