コード例 #1
0
def cuda2np(image):
    """Return ``image`` as a numpy array.

    Args:
        image: a ``torch.Tensor`` (on CPU or on a CUDA device) or a
            ``numpy.ndarray``. The original note on this function says the
            expected layout is (d, h, w); nothing in the body depends on it.

    Returns:
        A ``numpy.ndarray`` for tensor input; ``image`` itself when it is
        already an ndarray. For any other type an error is logged and the
        input is returned unchanged.
    """
    if isinstance(image, torch.Tensor):
        # detach() first so that tensors which require grad can be converted
        # too (Tensor.numpy() refuses those); cpu() is a no-op for tensors
        # that already live on the CPU.
        return image.detach().cpu().numpy()
    if not isinstance(image, np.ndarray):
        logger.error('image should be torch.Tensor or numpy.ndarray')
    return image
コード例 #2
0
ファイル: cc_augment.py プロジェクト: yblsn12138/u2net_torch
def cc_augment(config_task,
               data,
               seg,
               patch_type,
               patch_size,
               patch_center_dist_from_border=30,
               do_elastic_deform=True,
               alpha=(0., 1000.),
               sigma=(10., 13.),
               do_rotation=True,
               angle_x=(0, 2 * np.pi),
               angle_y=(0, 2 * np.pi),
               angle_z=(0, 2 * np.pi),
               do_scale=True,
               scale=(0.75, 1.25),
               border_mode_data='constant',
               border_cval_data=0,
               order_data=3,
               border_mode_seg='constant',
               border_cval_seg=0,
               order_seg=0,
               random_crop=True,
               p_el_per_sample=1,
               p_scale_per_sample=1,
               p_rot_per_sample=1,
               tag=''):
    """Extract one (optionally spatially augmented) patch per sample.

    For every sample a patch centre is drawn, the patch is extracted and
    then checked; extraction is retried until the check passes.

    Args:
        config_task: object whose ``task`` and ``num_class`` attributes are
            read when ``patch_type == 'small'``.
        data: array indexed as [n, c, *spatial] ([n, c, d, h, w] per the
            original notes).
        seg: array with the same layout as ``data``, or ``None``. Only
            channel 0 is used to look for candidate centre voxels.
        patch_type: ``'fore'`` (centre near a voxel with ``seg > 0``),
            ``'small'`` (centre near a voxel equal to label 1 for
            ``'Task05_Prostate'``, else ``num_class - 1``) or ``'any'``.
            Forced to ``'any'`` when ``seg`` is ``None``.
        patch_size: list with one entry per spatial dimension (it is
            concatenated to a list, so it must be a ``list``).
        patch_center_dist_from_border: scalar or per-dimension minimum
            distance of the patch centre from the volume border. The
            original note says it should be at most half the patch size.
        do_elastic_deform, alpha, sigma, p_el_per_sample: switch, parameter
            ranges and probability of the elastic deformation.
        do_rotation, angle_x, angle_y, angle_z, p_rot_per_sample: switch,
            angle ranges and probability of the rotation.
        do_scale, scale, p_scale_per_sample: switch, range and probability
            of the scaling.
        border_mode_data, border_cval_data, order_data: forwarded to
            ``interpolate_img`` for ``data``.
        border_mode_seg, border_cval_seg, order_seg: forwarded to
            ``interpolate_img`` for ``seg``.
        random_crop: draw a random centre if true, otherwise use the volume
            centre.
        tag: free text that only appears in the retry log message.

    Returns:
        ``(data_result, seg_result, augs)``. ``seg_result`` is ``None``
        when ``seg`` is ``None``. ``(None, None, None)`` is returned when
        ``patch_type`` is ``'fore'``/``'small'`` and a sample contains no
        voxel of the requested label.
    """
    dim = len(patch_size)

    seg_result = None
    if seg is not None:
        seg_result = np.zeros([seg.shape[0], seg.shape[1]] + patch_size,
                              dtype=np.float32)

    data_result = np.zeros([data.shape[0], data.shape[1]] + patch_size,
                           dtype=np.float32)

    if not isinstance(patch_center_dist_from_border,
                      (list, tuple, np.ndarray)):
        patch_center_dist_from_border = dim * [patch_center_dist_from_border]

    augs = list()
    for sample_id in range(data.shape[0]):
        # without labels there is nothing to steer the centre towards
        if seg is None:
            patch_type = 'any'

        # bound up front so the retry log below can always format it
        cand_point_coord = None
        handler = 0
        n = 0
        while handler == 0:
            # Start every attempt from a fresh zero-centred mesh: the mesh
            # is transformed and translated below, so re-using it would
            # apply the next attempt's transforms on top of the previous
            # attempt's shifted coordinates.
            coords = create_zero_centered_coordinate_mesh(patch_size)

            # augmentation
            # NOTE(review): `augs` also collects the names drawn in
            # attempts that are later rejected, and is reset for the whole
            # batch in the plain-crop fallback below.
            modified_coords = False
            if np.random.uniform() < p_el_per_sample and do_elastic_deform:
                a = np.random.uniform(alpha[0], alpha[1])
                s = np.random.uniform(sigma[0], sigma[1])
                coords = elastic_deform_coordinates(coords, a, s)
                modified_coords = True

                augs.append('elastic')

            if np.random.uniform() < p_rot_per_sample and do_rotation:
                if angle_x[0] == angle_x[1]:
                    a_x = angle_x[0]
                else:
                    a_x = np.random.uniform(angle_x[0], angle_x[1])
                if dim == 3:
                    if angle_y[0] == angle_y[1]:
                        a_y = angle_y[0]
                    else:
                        a_y = np.random.uniform(angle_y[0], angle_y[1])
                    if angle_z[0] == angle_z[1]:
                        a_z = angle_z[0]
                    else:
                        a_z = np.random.uniform(angle_z[0], angle_z[1])
                    coords = rotate_coords_3d(coords, a_x, a_y, a_z)
                else:
                    coords = rotate_coords_2d(coords, a_x)
                modified_coords = True

                augs.append('rotation')

            if np.random.uniform() < p_scale_per_sample and do_scale:
                # shrink or enlarge with equal probability
                if np.random.random() < 0.5 and scale[0] < 1:
                    sc = np.random.uniform(scale[0], 1)
                else:
                    sc = np.random.uniform(max(scale[0], 1), scale[1])
                coords = scale_coords(coords, sc)
                modified_coords = True

                augs.append('scale')

            # pick a random labelled voxel; the centre is drawn around it
            if patch_type in ['fore', 'small'] and seg is not None:
                if seg.shape[1] > 1:
                    logger.error('TBD for seg with multiple channels')
                if patch_type == 'fore':
                    lab_coords = np.where(
                        seg[sample_id, 0, ...] > 0)  # lab_coords: tuple
                elif patch_type == 'small':
                    if config_task.task == 'Task05_Prostate':
                        small_label = 1
                    else:
                        small_label = config_task.num_class - 1
                    lab_coords = np.where(
                        seg[sample_id, 0, ...] == small_label)
                if len(lab_coords[0]) > 0:  # 0 means no such label exists
                    idx = np.random.choice(len(lab_coords[0]))
                    cand_point_coord = [
                        axis_coords[idx] for axis_coords in lab_coords
                    ]  # coordinates of one random labelled voxel
                else:
                    cand_point_coord = None

            if patch_type in ['fore', 'small'] and cand_point_coord is None:
                # No voxel of the requested label: report "no patch".
                # Returning here (instead of carrying None buffers into the
                # next sample) keeps multi-sample input from crashing.
                return None, None, None

            ctr_list = list()  # coords of the patch center
            for d in range(dim):
                if random_crop:
                    if patch_type in ['fore', 'small'] and seg is not None:
                        low = max(
                            patch_center_dist_from_border[d] - 1,
                            cand_point_coord[d] - (patch_size[d] / 2 - 1))
                        low = int(low)
                        upper = min(
                            cand_point_coord[d] + (patch_size[d] / 2 - 1),
                            data.shape[d + 2] -
                            (patch_center_dist_from_border[d] - 1)
                        )  # +/- patch_size[d] is better but computation costly
                        upper = int(upper)

                        if low == upper:
                            ctr = int(low)
                        elif low < upper:
                            ctr = int(np.random.randint(low, upper))
                        else:
                            # NOTE(review): `ctr` is not assigned on this
                            # path; unless logger.error aborts, the append
                            # below uses a stale or unbound value.
                            logger.error(
                                '(low:{} should be <= upper:{}). patch_type:{}, patch_center_dist_from_border:{}, cand_point_coord:{}, cand point seg value:{}, data.shape:{}, ctr_list:{}'
                                .format(
                                    low, upper, str(patch_type),
                                    str(patch_center_dist_from_border),
                                    str(cand_point_coord),
                                    seg[(sample_id, 0) +
                                        tuple(cand_point_coord)],
                                    str(data.shape), str(ctr_list)))
                    elif patch_type == 'any':
                        if patch_center_dist_from_border[d] == data.shape[
                                d + 2] - patch_center_dist_from_border[d]:
                            ctr = int(patch_center_dist_from_border[d])
                        elif patch_center_dist_from_border[d] < data.shape[
                                d + 2] - patch_center_dist_from_border[d]:
                            ctr = int(
                                np.random.randint(
                                    patch_center_dist_from_border[d],
                                    data.shape[d + 2] -
                                    patch_center_dist_from_border[d]))
                        else:
                            # NOTE(review): same unassigned `ctr` caveat as
                            # in the branch above.
                            logger.error(
                                'low should be <= upper. patch_type:{}, patch_center_dist_from_border:{}, data.shape:{}, ctr_list:{}'
                                .format(str(patch_type),
                                        str(patch_center_dist_from_border),
                                        str(data.shape), str(ctr_list)))
                else:  # center crop
                    ctr = int(np.round(data.shape[d + 2] / 2.))
                ctr_list.append(ctr)

            # extracting patch
            if n < 10 and modified_coords:
                # move the (transformed) mesh onto the chosen centre and
                # resample the volume there
                for d in range(dim):
                    coords[d] += ctr_list[d]
                for channel_id in range(data.shape[1]):
                    data_result[sample_id, channel_id] = interpolate_img(
                        data[sample_id, channel_id],
                        coords,
                        order_data,
                        border_mode_data,
                        cval=border_cval_data)
                if seg is not None:
                    for channel_id in range(seg.shape[1]):
                        seg_result[sample_id,
                                   channel_id] = interpolate_img(
                                       seg[sample_id, channel_id],
                                       coords,
                                       order_seg,
                                       border_mode_seg,
                                       cval=border_cval_seg,
                                       is_seg=True)
            else:
                # plain crop: used when no transform was drawn or after
                # 10 rejected attempts
                augs = list()
                if seg is None:
                    s = None
                else:
                    s = seg[sample_id:sample_id + 1]
                if random_crop:
                    d_tmps = list()
                    for channel_id in range(data.shape[1]):
                        d_tmp = utils.extract_roi_from_volume(
                            data[sample_id, channel_id],
                            ctr_list,
                            patch_size,
                            fill="zero")
                        d_tmps.append(d_tmp)
                    d = np.asarray(d_tmps)
                    if seg is not None:
                        s_tmps = list()
                        for channel_id in range(seg.shape[1]):
                            s_tmp = utils.extract_roi_from_volume(
                                seg[sample_id, channel_id],
                                ctr_list,
                                patch_size,
                                fill="zero")
                            s_tmps.append(s_tmp)
                        s = np.asarray(s_tmps)
                else:
                    d, s = center_crop_aug(data[sample_id:sample_id + 1],
                                           patch_size, s)
                data_result[sample_id] = d
                if seg is not None:
                    seg_result[sample_id] = s

            ## check patch
            # Only the patch of the current sample is inspected, so patches
            # already accepted for earlier samples cannot let an empty one
            # through.
            if patch_type == 'fore':
                # the target structure can be tiny, so any foreground
                # voxel counts
                has_target = np.any(seg_result[sample_id] > 0)
            elif patch_type == 'small':
                has_target = np.any(seg_result[sample_id] == small_label)
            else:
                has_target = True
            if has_target and np.any(data_result[sample_id] != 0):
                handler = 1
            else:
                handler = 0
            n += 1

            if n > 5:
                # The candidate voxel exists only for 'fore'/'small'; look
                # its label up with a tuple index so 2D input works too.
                cand_seg_value = None
                if seg is not None and cand_point_coord is not None:
                    cand_seg_value = seg[(sample_id, 0) +
                                         tuple(cand_point_coord)]
                seg_unique = None
                if seg_result is not None:
                    seg_unique = np.unique(seg_result, return_counts=True)
                logger.info(
                    'tag:{}, patch_type: {}; handler: {}; times: {}; cand point:{}; cand point seg value:{}; ctr_list:{}; data.shape:{}; np.unique(seg_result):{}; np.sum(data_result):{}'
                    .format(tag, patch_type, handler, n,
                            str(cand_point_coord), cand_seg_value,
                            str(ctr_list), str(data.shape), seg_unique,
                            np.sum(data_result)))

    return data_result, seg_result, augs
コード例 #3
0
ファイル: train.py プロジェクト: zsj0577/u2net_torch
    def gen_batch(self, batch_size, patch_size):
        """Assemble one training batch from entries of ``self.trainQueue``.

        Each queue entry is a dict with the keys ``'small'``, ``'fore'``
        and ``'any'``; every value is either ``None`` or a patch dict with
        the keys ``'image'``, ``'label'``, ``'weight'`` and ``'augs'``.

        For the samples with index ``i <= ceil(batch_size / 3)`` only a
        ``'small'`` or ``'fore'`` patch is accepted (entries offering
        neither are discarded and the next entry is fetched). For the
        remaining samples ``'any'`` is the fallback.

        Args:
            batch_size: number of samples in the batch.
            patch_size: the three spatial patch dimensions (d, h, w).

        Returns:
            ``(batchImg, batchLabel, batchWeight, batchAugs)`` where the
            first three are arrays of shape [n, mod, d, h, w],
            [n, d, h, w] and [n, d, h, w], and ``batchAugs`` is a list
            with one ``patch['augs']`` entry per sample.
        """
        batchImg = np.zeros([
            batch_size, self.config_task.num_modality, patch_size[0],
            patch_size[1], patch_size[2]
        ])  # n,mod,d,h,w
        batchLabel = np.zeros(
            [batch_size, patch_size[0], patch_size[1],
             patch_size[2]])  # n,d,h,w
        batchWeight = np.zeros(
            [batch_size, patch_size[0], patch_size[1],
             patch_size[2]])  # n,d,h,w
        batchAugs = list()

        # nn_unet3d: at least 1/3 samples in a batch contain at least one
        # foreground class
        fore_quota = math.ceil(batch_size / 3)

        for i in range(batch_size):
            temp_prob = np.random.uniform()

            handler = 0
            while handler == 0:
                # wait (polling once per second) until the producer has
                # put something into the queue, and log how long it took
                t_wait = 0
                if self.trainQueue.qsize() == 0:
                    logger.info(
                        '{} self.trainQueue size = {}, filling....(start time:{})'
                        .format(self.task, self.trainQueue.qsize(),
                                tinies.datestr()))
                while self.trainQueue.qsize() == 0:
                    time.sleep(1)
                    t_wait += 1
                if t_wait > 0:
                    logger.info('{} time to fill self.trainQueue: {}'.format(
                        self.task, t_wait))

                patches = self.trainQueue.get()
                if temp_prob < self.config_task.small_prob and patches[
                        'small'] is not None:
                    # the 'small' rule is the same for every sample index
                    patch = patches['small']
                    handler = 1
                elif i <= fore_quota:
                    if patches['fore'] is not None:
                        patch = patches['fore']
                        handler = 1
                    else:
                        # nothing usable in this entry: fetch another one
                        handler = 0
                        logger.warn('handler={}'.format(handler))
                elif 1 - temp_prob < self.config_task.fore_prob and patches[
                        'fore'] is not None:
                    patch = patches['fore']
                    handler = 1
                else:
                    patch = patches['any']
                    handler = 1
                if handler == 0:
                    logger.info('handler is 0, going back')

            # fill in a batch
            batchImg[i, ...] = patch['image']
            batchLabel[i, ...] = patch['label']
            batchWeight[i, ...] = patch['weight']
            batchAugs.append(patch['augs'])

        return (batchImg, batchLabel, batchWeight, batchAugs)
コード例 #4
0
# Build and register the configuration of every requested task.
for task in args.tasks:
    config.config_tasks[task] = config.set_config_task(
        args.trainMode, task, config.base_dir)

# A non-empty output tag is appended to the result directory name below,
# separated by an underscore.
if args.out_tag:
    args.out_tag = '_' + args.out_tag

#### Prepare datasets
# The fold definition file is looked up in the parent of the working directory.
fold_splits_path = os.path.join(os.path.dirname(os.getcwd()),
                                'fold_splits.json')
with open(fold_splits_path, mode='r') as f:
    # dict: {'Task02_Heart'/...}{'fold index'}{'train'/'val'}
    tasks_archive = json.load(f)

# fixed seed for numpy's global random state
np.random.seed(1993)

#### prep train
if args.trainMode == "independent":
    logger.error('trainMode should be one of parallel_adapter, shared_adapter')
else:
    ### model settings
    config.patch_size = [128, 128, 128]
    config.patch_weights = tinies.calPatchWeights(config.patch_size)

    # <out_dir>/res_<model>_<trainMode><out_tag>/<task1>_<task2>_...
    joined_tasks = '_'.join(args.tasks)
    run_dir_name = 'res_{}_{}{}'.format(args.model, args.trainMode,
                                        args.out_tag)
    config.out_dir = os.path.join(config.out_dir, run_dir_name, joined_tasks)
    tinies.sureDir(config.out_dir)
    config.eval_out_dir = os.path.join(config.out_dir, "eval_out")
    tinies.newdir(config.eval_out_dir)

    config.log_dir = os.path.join(config.out_dir, 'train_log')
    # this will create log_dir
    config.writer = MySummaryWriter(log_dir=config.log_dir)
    # 'b' reuse log_dir and backup log.log
    logger.set_logger_dir(os.path.join(config.log_dir, 'logger'), action="b")
    logger.info('--------------------------------Training for {}: {}--------------------------------'.format(args.trainMode, joined_tasks))
コード例 #5
0
def _predict_and_scatter(model, patches, centers, patch_weights, num_class,
                         prob_volume, num_valid=None):
    """Run ``model`` on one mini-batch and write its output into ``prob_volume``.

    Args:
        model: network called on a float CUDA tensor of shape
            [n, mod, d, h, w]; in "universal" train mode it returns a
            3-tuple whose first element is the prediction.
        patches: list of n arrays of shape [mod, d, h, w].
        centers: patch centre of each real (non-padding) patch, in the
            same order as ``patches``.
        patch_weights: CUDA tensor multiplied into every class channel.
        num_class: number of class channels to process.
        prob_volume: array [D, H, W, num_class], updated in place through
            ``set_roi_to_volume``.
        num_valid: number of leading batch entries to write back; defaults
            to the batch dimension of the model output.
    """
    batch = torch.from_numpy(np.asarray(patches,
                                        np.float32)).float().cuda()
    # Inference only: do not record an autograd graph for the forward pass.
    with torch.no_grad():
        if config.trainMode == "universal":
            prob, _share_map, _para_map = model(batch)
        else:
            prob = model(batch)
    prob = prob.detach().permute([0, 2, 3, 4, 1])  # n,d,h,w,c

    if num_valid is None:
        num_valid = prob.shape[0]
    for batch_idx in range(num_valid):
        sub_prob = prob[batch_idx]
        for i in range(num_class):
            sub_prob[..., i] = torch.mul(sub_prob[..., i], patch_weights)
        sub_prob = sub_prob.cpu().numpy()

        center = centers[batch_idx]
        for c in range(num_class):
            prob_volume[..., c] = set_roi_to_volume(prob_volume[..., c],
                                                    center,
                                                    sub_prob[..., c])


def batch_segmentation(config_task, temp_imgs, model):
    """Segment one multi-channel volume patch by patch.

    The volume is padded to at least the task patch size, resized to the
    scale of the model patch size (for the "shared"/"universal" train
    modes), cut into patches whose centres are half a patch apart,
    predicted in mini-batches of ``config.batch_size``, and the arg-max
    label map is resized back and cropped to the input shape.

    Args:
        config_task: task configuration; ``num_class``, ``patch_size`` and
            ``task`` style attributes are read from it (``num_class`` and
            ``patch_size`` in this function).
        temp_imgs: array of shape [mod, D, H, W].
        model: the network, see ``_predict_and_scatter`` for how it is
            called.

    Returns:
        ``uint8`` label volume with the spatial shape of the input.

    Raises:
        ValueError: if ``config.unifyPatch`` is not ``'resize'``.
    """
    # model patch size. if args.trainMode='independent', equal to
    # config_task.patch_size; else, not equal.
    model_patch_size = config.patch_size
    batch_size = config.batch_size
    num_class = config_task.num_class
    patch_weights = torch.from_numpy(config.patch_weights).float().cuda()

    data_channel, original_D, original_H, original_W = temp_imgs.shape

    # for some cases, e.g. Task04_Hippocampus, temp_imgs[0] is smaller than
    # patch_size, so pad to patch_size. remember to apply the same process
    # to get_train_dataflow()
    temp_imgs, pad_size = tinies.pad2gePatch(temp_imgs, config_task.patch_size,
                                             data_channel)

    data_channel, D, H, W = temp_imgs.shape

    ### before input to model, scale the image with factor of
    ### model_patch_size/task_specific_patch_size, so as to unify the patch
    ### size to the size required by the universal pipeline model.
    oldShape = [D, H, W]
    if config.unifyPatch == 'resize':
        # resize all tasks images to same size for shared/universal model
        if config.trainMode in ["shared", "universal"]:

            # tb visualization
            # NOTE(review): the figure is built but not written anywhere
            # in this function.
            tb_image = temp_imgs[0, ...]
            slice_indices = [8 * i for i in range(int(tb_image.shape[0] / 8))]
            img_fig = config.writer.tensor2figure(tb_image,
                                                  slice_indices,
                                                  colorslist=config.colorslist,
                                                  is_label=False,
                                                  fig_title='image')

            scale_factors = [
                model_patch_size[i] / config_task.patch_size[i]
                for i in range(len(model_patch_size))
            ]
            newShape = [
                int(oldShape[i] * scale_factors[i])
                for i in range(len(scale_factors))
            ]
            imgs_list = []
            for i in range(temp_imgs.shape[0]):
                imgs_list.append(
                    skimage.transform.resize(temp_imgs[i],
                                             output_shape=tuple(newShape),
                                             order=3,
                                             mode='constant'))  # bi-cubic.
            temp_imgs = np.asarray(imgs_list)

            # tb visualization
            tb_image = temp_imgs[0, ...]
            slice_indices = [8 * i for i in range(int(tb_image.shape[0] / 8))]
            img_fig = config.writer.tensor2figure(tb_image,
                                                  slice_indices,
                                                  colorslist=config.colorslist,
                                                  is_label=False,
                                                  fig_title='image')

    else:
        raise ValueError('{}: not yet implemented!!'.format(config.unifyPatch))

    data_channel, D, H, W = temp_imgs.shape
    temp_prob1 = np.zeros([D, H, W, num_class])

    data_mini_batch = []
    centers = []

    st_time = time.time()

    # half a patch is both the first centre and the stride between centres
    half_D = int(model_patch_size[0] / 2)
    half_H = int(model_patch_size[1] / 2)
    half_W = int(model_patch_size[2] / 2)

    for patch_center_W in range(half_W, W + half_W, half_W):
        # clamp so the patch does not run past the end of the volume
        patch_center_W = min(patch_center_W, W - half_W)
        for patch_center_H in range(half_H, H + half_H, half_H):
            patch_center_H = min(patch_center_H, H - half_H)
            for patch_center_D in range(half_D, D + half_D, half_D):
                patch_center_D = min(patch_center_D, D - half_D)
                temp_input_center = [
                    patch_center_D, patch_center_H, patch_center_W
                ]
                centers.append(temp_input_center)

                patch = []
                for chn in range(data_channel):
                    sub_patch = extract_roi_from_volume(temp_imgs[chn],
                                                        temp_input_center,
                                                        model_patch_size,
                                                        fill="zero")
                    patch.append(sub_patch)
                patch = np.asanyarray(patch, np.float32)  # [mod,d,h,w]
                # collect to batch
                data_mini_batch.append(patch)

                if len(data_mini_batch) == batch_size:
                    _predict_and_scatter(model, data_mini_batch, centers,
                                         patch_weights, num_class, temp_prob1)
                    data_mini_batch = []
                    centers = []

    remainder_batch_size = len(data_mini_batch)
    if remainder_batch_size > 0 and remainder_batch_size < batch_size:
        # treat the remainder as an independent batch: fill it up to
        # batch_size with zero patches and only write back the real ones
        for idx in range(batch_size - remainder_batch_size):
            data_mini_batch.append(
                np.zeros([data_channel] + list(model_patch_size)))
        _predict_and_scatter(model,
                             data_mini_batch,
                             centers,
                             patch_weights,
                             num_class,
                             temp_prob1,
                             num_valid=remainder_batch_size)
        data_mini_batch = []
    elif remainder_batch_size >= batch_size:
        logger.error(
            'the remainder data_mini_batch size is {} and batch_size = {}, code is wrong'
            .format(len(data_mini_batch), batch_size))

    logger.info('patch eval for-loop time elapsed:{}'.format(
        tinies.timer(st_time, time.time())))

    # argmax
    temp_pred1 = np.argmax(temp_prob1, axis=-1)

    if config.unifyPatch == 'resize':
        # resize all tasks images to same size for universal model
        if config.trainMode in ["shared", "universal"]:

            # tb visualization
            # NOTE(review): as above, the figures are built but not
            # written anywhere in this function.
            tb_image = temp_imgs[0, ...]
            tb_pred = temp_pred1
            slice_indices = config.writer.chooseSlices(tb_pred)
            img_fig = config.writer.tensor2figure(tb_image,
                                                  slice_indices,
                                                  colorslist=config.colorslist,
                                                  is_label=False,
                                                  fig_title='image')
            pred_fig = config.writer.tensor2figure(
                tb_pred,
                slice_indices,
                colorslist=config.colorslist,
                is_label=True,
                fig_title='pred')

            # resize
            # it will result in nothing if input an array of np.uint8 to
            # resize(order=0)
            temp_pred1 = temp_pred1.astype(np.float32)
            temp_pred1 = skimage.transform.resize(temp_pred1,
                                                  output_shape=tuple(oldShape),
                                                  order=0,
                                                  mode='constant')  # nearest.

            # tb visualization
            tb_image = temp_imgs[0, ...]
            tb_pred = temp_pred1
            slice_indices = config.writer.chooseSlices(tb_pred)
            img_fig = config.writer.tensor2figure(tb_image,
                                                  slice_indices,
                                                  colorslist=config.colorslist,
                                                  is_label=False,
                                                  fig_title='image')
            pred_fig = config.writer.tensor2figure(
                tb_pred,
                slice_indices,
                colorslist=config.colorslist,
                is_label=True,
                fig_title='pred')

    else:
        raise ValueError('{}: not yet implemented!!'.format(config.unifyPatch))

    temp_pred1 = np.asarray(temp_pred1, dtype=np.uint8)

    # for some cases, e.g. Task04_Hippocampus, temp_imgs[0] is smaller than
    # model_patch_size; crop the padding away to recover the original shape.
    if np.any(pad_size):
        temp_pred1 = temp_pred1[pad_size[0]:(original_D + pad_size[0]),
                                pad_size[1]:(original_H + pad_size[1]),
                                pad_size[2]:(original_W + pad_size[2])]

    return temp_pred1