Example #1
    def __getitem__(self, index):
        img = ImageHelper.read_image(self.img_list[index],
                                     tool=self.configer.get('data', 'image_tool'),
                                     mode=self.configer.get('data', 'input_mode'))

        img_size = ImageHelper.get_size(img)
        bboxes, labels = self.__read_json_file(self.json_list[index])

        if self.aug_transform is not None:
            img, bboxes, labels = self.aug_transform(img, bboxes=bboxes, labels=labels)

        img_scale = ImageHelper.get_size(img)[0] / img_size[0]

        labels = torch.from_numpy(labels).long()
        bboxes = torch.from_numpy(bboxes).float()

        meta = dict(
            ori_img_size=img_size,
            border_size=ImageHelper.get_size(img),
            img_scale=img_scale,
        )
        if self.img_transform is not None:
            img = self.img_transform(img)

        return dict(
            img=DataContainer(img, stack=True),
            bboxes=DataContainer(bboxes, stack=False),
            labels=DataContainer(labels, stack=False),
            meta=DataContainer(meta, stack=False, cpu_only=True)
        )
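
Since each sample above is a dict of DataContainer objects, the default
PyTorch collate cannot batch it directly. A minimal wiring sketch, assuming
the `collate` from Examples #13/#14 and a hypothetical `dataset` instance:

# Wiring sketch (assumptions: `dataset` and `collate` as in Examples #13/#14).
from functools import partial
from torch.utils.data import DataLoader

loader = DataLoader(dataset,
                    batch_size=8,
                    shuffle=True,
                    collate_fn=partial(collate,
                                       trans_dict=dict(size_mode='max_size',
                                                       align_method='only_pad',
                                                       pad_mode='random')))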
Example #2
def stack(batch, data_key=None, trans_dict=None):
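    """Batch the `data_key` field of each sample: data marked stack=True is
    collated with default_collate (optionally grouped into chunks of
    trans_dict['samples_per_gpu']); otherwise the raw per-sample data is
    returned as (nested) lists."""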
    if isinstance(batch[0][data_key], DataContainer):
        if batch[0][data_key].stack:
            assert isinstance(batch[0][data_key].data,
                              (torch.Tensor, int_classes, float,
                               string_classes, collections.Mapping,
                               collections.Sequence))
            stacked = []
            if batch[0][data_key].samples_per_gpu:
                assert len(batch) % trans_dict['samples_per_gpu'] == 0
                for i in range(0, len(batch), trans_dict['samples_per_gpu']):
                    stacked.append(
                        default_collate([
                            sample[data_key].data
                            for sample in batch[i:i +
                                                trans_dict['samples_per_gpu']]
                        ]))
            else:
                stacked = default_collate(
                    [sample[data_key].data for sample in batch])

            if batch[0][data_key].return_dc:
                return DataContainer(stacked,
                                     batch[0][data_key].stack,
                                     batch[0][data_key].padding_value,
                                     cpu_only=batch[0][data_key].cpu_only)
            else:
                return stacked
        else:
            stacked = []
            if batch[0][data_key].samples_per_gpu:
                assert len(batch) % trans_dict['samples_per_gpu'] == 0
                for i in range(0, len(batch), trans_dict['samples_per_gpu']):
                    stacked.append([
                        sample[data_key].data
                        for sample in batch[i:i +
                                            trans_dict['samples_per_gpu']]
                    ])
            else:
                stacked = [sample[data_key].data for sample in batch]

            if batch[0][data_key].return_dc:
                return DataContainer(stacked,
                                     batch[0][data_key].stack,
                                     batch[0][data_key].padding_value,
                                     cpu_only=batch[0][data_key].cpu_only)
            else:
                return stacked
    else:
        return default_collate([sample[data_key] for sample in batch])
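
For reference, a minimal stand-in for DataContainer, assuming only the
attributes the `stack` function above actually reads; the real class may
carry additional logic:

# Minimal stand-in (illustration only, not the real implementation).
class DataContainer(object):

    def __init__(self, data, stack=False, padding_value=0,
                 cpu_only=False, samples_per_gpu=False, return_dc=False):
        self.data = data
        self.stack = stack
        self.padding_value = padding_value
        self.cpu_only = cpu_only
        self.samples_per_gpu = samples_per_gpu
        self.return_dc = return_dc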
Example #3
    def __getitem__(self, index):
        img = ImageHelper.read_image(self.img_list[index],
                                     tool=self.configer.get('data', 'image_tool'),
                                     mode=self.configer.get('data', 'input_mode'))
        label = self.label_list[index]

        if self.aug_transform is not None:
            img = self.aug_transform(img)

        if self.img_transform is not None:
            img = self.img_transform(img)

        return dict(
            img=DataContainer(img, stack=True),
            label=DataContainer(label, stack=True),
        )
Example #4
    def __getitem__(self, index):
        img = ImageHelper.read_image(
            self.img_list[index],
            tool=self.configer.get('data', 'image_tool'),
            mode=self.configer.get('data', 'input_mode'))
        if os.path.exists(self.mask_list[index]):
            maskmap = ImageHelper.read_image(self.mask_list[index],
                                             tool=self.configer.get(
                                                 'data', 'image_tool'),
                                             mode='P')
        else:
            maskmap = np.ones((img.size[1], img.size[0]), dtype=np.uint8)
            if self.configer.get('data', 'image_tool') == 'pil':
                maskmap = ImageHelper.np2img(maskmap)

        kpts, bboxes = self.__read_json_file(self.json_list[index])

        if self.aug_transform is not None and len(bboxes) > 0:
            img, maskmap, kpts, bboxes = self.aug_transform(img,
                                                            maskmap=maskmap,
                                                            kpts=kpts,
                                                            bboxes=bboxes)

        elif self.aug_transform is not None:
            img, maskmap, kpts = self.aug_transform(img,
                                                    maskmap=maskmap,
                                                    kpts=kpts)

        width, height = ImageHelper.get_size(maskmap)
        maskmap = ImageHelper.resize(
            maskmap, (width // self.configer.get('network', 'stride'),
                      height // self.configer.get('network', 'stride')),
            interpolation='nearest')

        maskmap = torch.from_numpy(np.array(maskmap, dtype=np.float32))
        maskmap = maskmap.unsqueeze(0)
        kpts = torch.from_numpy(kpts).float()

        heatmap = self.heatmap_generator(kpts, [width, height], maskmap)
        vecmap = self.paf_generator(kpts, [width, height], maskmap)
        if self.img_transform is not None:
            img = self.img_transform(img)

        return dict(img=DataContainer(img, stack=True),
                    heatmap=DataContainer(heatmap, stack=True),
                    maskmap=DataContainer(maskmap, stack=True),
                    vecmap=DataContainer(vecmap, stack=True))
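
The mask is resized to the network's output grid before the heatmap and PAF
targets are generated. A worked example of the stride arithmetic, assuming a
368x368 input and stride 8 (typical OpenPose-style settings):

# Worked example (assumed numbers): a 368x368 mask at stride 8 shrinks to
# the 46x46 grid the heatmap/vecmap targets live on.
width, height, stride = 368, 368, 8
assert (width // stride, height // stride) == (46, 46)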
Example #5
    def todc(data_list, gpu_list, cpu_only=False):
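        """Split `data_list` evenly across the devices in `gpu_list` and
        wrap the per-GPU chunks in a single DataContainer."""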
        assert len(data_list) % len(gpu_list) == 0
        samples_per_gpu = len(data_list) // len(gpu_list)
        stacked = []
        for i in range(0, len(data_list), samples_per_gpu):
            stacked.append(data_list[i:i + samples_per_gpu])

        return DataContainer(stacked, cpu_only=cpu_only)
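
A usage sketch for `todc`, with assumed tensor shapes and device ids (the
indentation suggests it lives on a helper class, perhaps the DCHelper
referenced in Example #12):

# Usage sketch (assumed shapes/devices): 8 samples over 4 GPUs gives
# 2 samples per GPU, so dc.data holds 4 two-element chunks.
import torch

inputs = [torch.randn(3, 224, 224) for _ in range(8)]
dc = todc(inputs, gpu_list=[0, 1, 2, 3])
assert len(dc.data) == 4 and len(dc.data[0]) == 2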
Example #6
    def __getitem__(self, index):
        img = ImageHelper.read_image(
            self.img_list[index],
            tool=self.configer.get('data', 'image_tool'),
            mode=self.configer.get('data', 'input_mode'))
        labels, bboxes, polygons = self.__read_json_file(self.json_list[index])

        if self.aug_transform is not None:
            img, bboxes, labels, polygons = self.aug_transform(
                img, bboxes=bboxes, labels=labels, polygons=polygons)

        if self.img_transform is not None:
            img = self.img_transform(img)

        return dict(img=DataContainer(img, stack=True),
                    bboxes=DataContainer(bboxes, stack=False),
                    labels=DataContainer(labels, stack=False),
                    polygons=DataContainer(polygons,
                                           stack=False,
                                           cpu_only=True))
Example #7
    def __getitem__(self, index):
        img = ImageHelper.read_image(self.img_list[index],
                                     tool=self.configer.get('data', 'image_tool'),
                                     mode=self.configer.get('data', 'input_mode'))

        bboxes, labels = self.__read_json_file(self.json_list[index])

        if self.aug_transform is not None:
            img, bboxes, labels = self.aug_transform(img, bboxes=bboxes, labels=labels)

        labels = torch.from_numpy(labels).long()
        bboxes = torch.from_numpy(bboxes).float()

        if self.img_transform is not None:
            img = self.img_transform(img)

        return dict(
            img=DataContainer(img, stack=True),
            bboxes=DataContainer(bboxes, stack=False),
            labels=DataContainer(labels, stack=False),
        )
Example #8
    def __getitem__(self, index):
        img = ImageHelper.read_image(
            self.img_list[index],
            tool=self.configer.get('data', 'image_tool'),
            mode=self.configer.get('data', 'input_mode'))

        kpts, bboxes = self.__read_json_file(self.json_list[index])

        if self.aug_transform is not None:
            img, kpts, bboxes = self.aug_transform(img,
                                                   kpts=kpts,
                                                   bboxes=bboxes)

        kpts = torch.from_numpy(kpts).float()
        heatmap = self.heatmap_generator(kpts, ImageHelper.get_size(img))
        if self.img_transform is not None:
            img = self.img_transform(img)

        return dict(
            img=DataContainer(img, stack=True),
            heatmap=DataContainer(heatmap, stack=True),
        )
Example #9
    def __getitem__(self, index):
        img = ImageHelper.read_image(
            self.img_list[index],
            tool=self.configer.get('data', 'image_tool'),
            mode=self.configer.get('data', 'input_mode'))
        img_size = ImageHelper.get_size(img)
        labelmap = ImageHelper.read_image(self.label_list[index],
                                          tool=self.configer.get(
                                              'data', 'image_tool'),
                                          mode='P')
        if self.configer.exists('data', 'label_list'):
            labelmap = self._encode_label(labelmap)

        if self.configer.exists('data', 'reduce_zero_label'):
            labelmap = self._reduce_zero_label(labelmap)

        ori_target = ImageHelper.tonp(labelmap)
        ori_target[ori_target == 255] = -1

        if self.aug_transform is not None:
            img, labelmap = self.aug_transform(img, labelmap=labelmap)

        border_size = ImageHelper.get_size(img)

        if self.img_transform is not None:
            img = self.img_transform(img)

        if self.label_transform is not None:
            labelmap = self.label_transform(labelmap)

        meta = dict(ori_img_size=img_size,
                    border_size=border_size,
                    ori_target=ori_target)
        return dict(
            img=DataContainer(img, stack=True),
            labelmap=DataContainer(labelmap, stack=True),
            meta=DataContainer(meta, stack=False, cpu_only=True),
        )
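
In the segmentation sample above, void pixels (255) are remapped to -1,
which pairs naturally with an ignore index in the loss. A minimal sketch,
assuming a standard cross-entropy criterion:

# Sketch (assumption): labels of -1 produced above are skipped by the loss.
import torch.nn as nn

criterion = nn.CrossEntropyLoss(ignore_index=-1)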
Example #10
def stack(batch, data_key=None, return_dc=False):
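    """Batch the `data_key` field: tensors marked stack=True go through
    default_collate; otherwise return the per-sample data as a plain list,
    re-wrapped in a DataContainer when return_dc is set."""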
    if isinstance(batch[0][data_key], DataContainer):
        if batch[0][data_key].stack:
            assert isinstance(batch[0][data_key].data, torch.Tensor)
            samples = [sample[data_key].data for sample in batch]
            return default_collate(samples)

        elif not return_dc:
            return [sample[data_key].data for sample in batch]

        else:
            return DataContainer([sample[data_key].data for sample in batch])

    else:
        return default_collate([sample[data_key] for sample in batch])
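
A usage sketch for this simpler variant, with assumed tensor shapes:

# Usage sketch (assumed shapes): stack=True fields are collated into one
# batched tensor, so two 3x4x4 images become a single 2x3x4x4 tensor.
import torch

batch = [dict(img=DataContainer(torch.randn(3, 4, 4), stack=True))
         for _ in range(2)]
imgs = stack(batch, data_key='img')
assert imgs.shape == (2, 3, 4, 4)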
Example #11
    def __getitem__(self, index):
        img_out, label_out = self._get_batch_per_gpu(index)
        img_list = []
        labelmap_list = []
        for img, labelmap in zip(img_out, label_out):
            if self.configer.exists('data', 'label_list'):
                labelmap = self._encode_label(labelmap)

            if self.configer.exists('data', 'reduce_zero_label'):
                labelmap = self._reduce_zero_label(labelmap)

            if self.aug_transform is not None:
                img, labelmap = self.aug_transform(img, labelmap=labelmap)

            if self.img_transform is not None:
                img = self.img_transform(img)

            if self.label_transform is not None:
                labelmap = self.label_transform(labelmap)

            img_list.append(img)
            labelmap_list.append(labelmap)

        border_width = [sample.size(2) for sample in img_list]
        border_height = [sample.size(1) for sample in img_list]
        target_width, target_height = max(border_width), max(border_height)
        if 'fit_stride' in self.configer.get('train', 'data_transformer'):
            stride = self.configer.get('train',
                                       'data_transformer')['fit_stride']
            pad_w = 0 if (target_width % stride
                          == 0) else stride - (target_width % stride)  # right
            pad_h = 0 if (target_height % stride
                          == 0) else stride - (target_height % stride)  # down
            target_width = target_width + pad_w
            target_height = target_height + pad_h

        batch_images = torch.zeros(self.configer.get('train', 'batch_per_gpu'),
                                   3, target_height, target_width)
        batch_labels = torch.ones(self.configer.get('train', 'batch_per_gpu'),
                                  target_height, target_width)
        batch_labels = (batch_labels * -1).long()
        for i, (img, labelmap) in enumerate(zip(img_list, labelmap_list)):
            pad_width = target_width - img.size(2)
            pad_height = target_height - img.size(1)
            if self.configer.get('train',
                                 'data_transformer')['pad_mode'] == 'random':
                left_pad = random.randint(0, pad_width)  # pad_left
                up_pad = random.randint(0, pad_height)  # pad_up
            else:
                left_pad = 0
                up_pad = 0

            batch_images[i, :, up_pad:up_pad + img.size(1),
                         left_pad:left_pad + img.size(2)] = img
            batch_labels[i, up_pad:up_pad + labelmap.size(0),
                         left_pad:left_pad + labelmap.size(1)] = labelmap

        return dict(
            img=DataContainer(batch_images, stack=False),
            labelmap=DataContainer(batch_labels, stack=False),
        )
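
A worked example of the fit_stride round-up used above, with assumed numbers:

# Worked example (assumed numbers): a 500-pixel-wide canvas at stride 8 is
# padded on the right by 4 pixels, up to the next multiple (504).
stride, target_width = 8, 500
pad_w = 0 if target_width % stride == 0 else stride - (target_width % stride)
assert pad_w == 4 and target_width + pad_w == 504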
Example #12
    def __test_img(self, image_path, json_path, raw_path, vis_path):
        Log.info('Image Path: {}'.format(image_path))
        image = ImageHelper.read_image(
            image_path,
            tool=self.configer.get('data', 'image_tool'),
            mode=self.configer.get('data', 'input_mode'))
        ori_img_bgr = ImageHelper.get_cv2_bgr(image,
                                              mode=self.configer.get(
                                                  'data', 'input_mode'))
        width, height = ImageHelper.get_size(image)
        scale1 = self.configer.get('test', 'resize_bound')[0] / min(
            width, height)
        scale2 = self.configer.get('test', 'resize_bound')[1] / max(
            width, height)
        scale = min(scale1, scale2)
        inputs = self.blob_helper.make_input(image, scale=scale)
        b, c, h, w = inputs.size()
        border_wh = [w, h]
        if self.configer.exists('test', 'fit_stride'):
            stride = self.configer.get('test', 'fit_stride')

            pad_w = 0 if (w % stride == 0) else stride - (w % stride)  # right
            pad_h = 0 if (h % stride == 0) else stride - (h % stride)  # down

            expand_image = torch.zeros(
                (b, c, h + pad_h, w + pad_w)).to(inputs.device)
            expand_image[:, :, 0:h, 0:w] = inputs
            inputs = expand_image

        data_dict = dict(
            img=inputs,
            meta=DataContainer([[
                dict(ori_img_size=ImageHelper.get_size(ori_img_bgr),
                     aug_img_size=border_wh,
                     img_scale=scale,
                     input_size=[inputs.size(3),
                                 inputs.size(2)])
            ]],
                               cpu_only=True))

        with torch.no_grad():
            # Forward pass.
            test_group = self.det_net(data_dict)

            test_indices_and_rois, test_roi_locs, test_roi_scores, test_rois_num = test_group

        batch_detections = self.decode(test_roi_locs, test_roi_scores,
                                       test_indices_and_rois, test_rois_num,
                                       self.configer,
                                       DCHelper.tolist(data_dict['meta']))
        json_dict = self.__get_info_tree(batch_detections[0],
                                         ori_img_bgr,
                                         scale=scale)

        image_canvas = self.det_parser.draw_bboxes(
            ori_img_bgr.copy(),
            json_dict,
            conf_threshold=self.configer.get('res', 'vis_conf_thre'))
        cv2.imwrite(vis_path, image_canvas)
        cv2.imwrite(raw_path, ori_img_bgr)

        Log.info('Json Path: {}'.format(json_path))
        JsonHelper.save_file(json_dict, json_path)
        return json_dict
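
The two candidate scales above bound the short and long image sides
respectively. A worked example, assuming resize_bound = [600, 1000] and a
640x480 image:

# Worked example (assumed numbers): min() keeps the short side at or below
# 600 while ensuring the long side never exceeds 1000.
width, height = 640, 480
scale1 = 600 / min(width, height)   # 1.25
scale2 = 1000 / max(width, height)  # 1.5625
scale = min(scale1, scale2)         # 1.25 -> the image becomes 800x600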
Example #13
def collate(batch, trans_dict):
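    """Resize and/or pad every sample in `batch` to a single target size,
    keeping kpts/bboxes/polygons geometry consistent, then batch each key
    via `stack` above."""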
    data_keys = batch[0].keys()
    if trans_dict['size_mode'] == 'fix_size':
        target_width, target_height = trans_dict['input_size']

    elif trans_dict['size_mode'] == 'multi_size':
        ms_input_size = trans_dict['ms_input_size']
        target_width, target_height = ms_input_size[random.randint(
            0,
            len(ms_input_size) - 1)]

    elif trans_dict['size_mode'] == 'max_size':
        border_width = [sample['img'].size(2) for sample in batch]
        border_height = [sample['img'].size(1) for sample in batch]
        target_width, target_height = max(border_width), max(border_height)

    else:
        raise NotImplementedError('Size Mode {} is invalid!'.format(
            trans_dict['size_mode']))

    if 'fit_stride' in trans_dict:
        stride = trans_dict['fit_stride']
        pad_w = 0 if (target_width % stride
                      == 0) else stride - (target_width % stride)  # right
        pad_h = 0 if (target_height % stride
                      == 0) else stride - (target_height % stride)  # down
        target_width = target_width + pad_w
        target_height = target_height + pad_h

    for i in range(len(batch)):
        if 'meta' in data_keys:
            batch[i]['meta'].data['input_size'] = [target_width, target_height]

        channels, height, width = batch[i]['img'].size()
        if height == target_height and width == target_width:
            continue

        scaled_size = [width, height]

        if trans_dict['align_method'] in ['only_scale', 'scale_and_pad']:
            w_scale_ratio = target_width / width
            h_scale_ratio = target_height / height
            if trans_dict['align_method'] == 'scale_and_pad':
                w_scale_ratio = min(w_scale_ratio, h_scale_ratio)
                h_scale_ratio = w_scale_ratio

            if 'kpts' in data_keys and batch[i]['kpts'].numel() > 0:
                batch[i]['kpts'].data[:, :, 0] *= w_scale_ratio
                batch[i]['kpts'].data[:, :, 1] *= h_scale_ratio

            if 'bboxes' in data_keys and batch[i]['bboxes'].numel() > 0:
                batch[i]['bboxes'].data[:, 0::2] *= w_scale_ratio
                batch[i]['bboxes'].data[:, 1::2] *= h_scale_ratio

            if 'polygons' in data_keys:
                for object_id in range(len(batch[i]['polygons'])):
                    for polygon_id in range(
                            len(batch[i]['polygons'][object_id])):
                        batch[i]['polygons'].data[object_id][polygon_id][
                            0::2] *= w_scale_ratio
                        batch[i]['polygons'].data[object_id][polygon_id][
                            1::2] *= h_scale_ratio

            scaled_size = (int(round(width * w_scale_ratio)),
                           int(round(height * h_scale_ratio)))
            if 'meta' in data_keys and 'border_size' in batch[i]['meta'].data:
                batch[i]['meta'].data['border_size'] = scaled_size

            scaled_size_hw = (scaled_size[1], scaled_size[0])

            batch[i]['img'] = DataContainer(TensorHelper.resize(
                batch[i]['img'].data,
                scaled_size_hw,
                mode='bilinear',
                align_corners=True),
                                            stack=batch[i]['img'].stack)
            if 'labelmap' in data_keys:
                batch[i]['labelmap'] = DataContainer(
                    TensorHelper.resize(batch[i]['labelmap'].data,
                                        scaled_size_hw,
                                        mode='nearest'),
                    stack=batch[i]['labelmap'].stack)

            if 'maskmap' in data_keys:
                batch[i]['maskmap'] = DataContainer(
                    TensorHelper.resize(batch[i]['maskmap'].data,
                                        scaled_size_hw,
                                        mode='nearest'),
                    stack=batch[i]['maskmap'].stack)

        pad_width = target_width - scaled_size[0]
        pad_height = target_height - scaled_size[1]
        assert pad_height >= 0 and pad_width >= 0
        if pad_width > 0 or pad_height > 0:
            assert trans_dict['align_method'] in ['only_pad', 'scale_and_pad']
            left_pad, up_pad = None, None
            if 'pad_mode' not in trans_dict or trans_dict[
                    'pad_mode'] == 'random':
                left_pad = random.randint(0, pad_width)  # pad_left
                up_pad = random.randint(0, pad_height)  # pad_up

            elif trans_dict['pad_mode'] == 'pad_border':
                direction = random.randint(0, 1)
                left_pad = pad_width if direction == 0 else 0
                up_pad = pad_height if direction == 0 else 0

            elif trans_dict['pad_mode'] == 'pad_left_up':
                left_pad = pad_width
                up_pad = pad_height

            elif trans_dict['pad_mode'] == 'pad_right_down':
                left_pad = 0
                up_pad = 0

            elif trans_dict['pad_mode'] == 'pad_center':
                left_pad = pad_width // 2
                up_pad = pad_height // 2

            else:
                Log.error('Invalid pad mode: {}'.format(
                    trans_dict['pad_mode']))
                exit(1)

            pad = (left_pad, pad_width - left_pad, up_pad, pad_height - up_pad)

            batch[i]['img'] = DataContainer(F.pad(batch[i]['img'].data,
                                                  pad=pad,
                                                  value=0),
                                            stack=batch[i]['img'].stack)

            if 'labelmap' in data_keys:
                batch[i]['labelmap'] = DataContainer(
                    F.pad(batch[i]['labelmap'].data, pad=pad, value=-1),
                    stack=batch[i]['labelmap'].stack)

            if 'maskmap' in data_keys:
                batch[i]['maskmap'] = DataContainer(
                    F.pad(batch[i]['maskmap'].data, pad=pad, value=1),
                    stack=batch[i]['maskmap'].stack)

            if 'polygons' in data_keys:
                for object_id in range(len(batch[i]['polygons'])):
                    for polygon_id in range(
                            len(batch[i]['polygons'][object_id])):
                        batch[i]['polygons'].data[object_id][polygon_id][
                            0::2] += left_pad
                        batch[i]['polygons'].data[object_id][polygon_id][
                            1::2] += up_pad

            if 'kpts' in data_keys and batch[i]['kpts'].numel() > 0:
                batch[i]['kpts'].data[:, :, 0] += left_pad
                batch[i]['kpts'].data[:, :, 1] += up_pad

            if 'bboxes' in data_keys and batch[i]['bboxes'].numel() > 0:
                batch[i]['bboxes'].data[:, 0::2] += left_pad
                batch[i]['bboxes'].data[:, 1::2] += up_pad

    return dict({
        key: stack(batch, data_key=key, trans_dict=trans_dict)
        for key in data_keys
    })
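
A reminder of the F.pad convention both collate functions rely on, with
assumed shapes:

# Sketch (assumed shapes): the 4-tuple (left, right, top, bottom) pads the
# last two dimensions, so width grows by left+right and height by top+bottom.
import torch
import torch.nn.functional as F

x = torch.zeros(3, 2, 2)                 # C x H x W
y = F.pad(x, pad=(1, 0, 2, 1), value=0)  # left=1, right=0, top=2, bottom=1
assert y.shape == (3, 5, 3)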
Example #14
def collate(batch, trans_dict):
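    """Like the collate in Example #13, with additional 'ade20k' and
    'random_size' size modes, and F.interpolate used directly for
    resizing."""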
    data_keys = batch[0].keys()
    if trans_dict['size_mode'] == 'ade20k':
        return dict({key: stack(batch, data_key=key, return_dc=True) for key in data_keys})

    elif trans_dict['size_mode'] == 'random_size':
        target_width, target_height = batch[0]['img'].size(2), batch[0]['img'].size(1)

    elif trans_dict['size_mode'] == 'fix_size':
        target_width, target_height = trans_dict['input_size']

    elif trans_dict['size_mode'] == 'multi_size':
        ms_input_size = trans_dict['ms_input_size']
        target_width, target_height = ms_input_size[random.randint(0, len(ms_input_size) - 1)]

    elif trans_dict['size_mode'] == 'max_size':
        border_width = [sample['img'].size(2) for sample in batch]
        border_height = [sample['img'].size(1) for sample in batch]
        target_width, target_height = max(border_width), max(border_height)

    else:
        raise NotImplementedError('Size Mode {} is invalid!'.format(trans_dict['size_mode']))

    if 'fit_stride' in trans_dict:
        stride = trans_dict['fit_stride']
        pad_w = 0 if (target_width % stride == 0) else stride - (target_width % stride)  # right
        pad_h = 0 if (target_height % stride == 0) else stride - (target_height % stride)  # down
        target_width = target_width + pad_w
        target_height = target_height + pad_h

    for i in range(len(batch)):
        if 'meta' in data_keys:
            batch[i]['meta'].data['input_size'] = [target_width, target_height]

        channels, height, width = batch[i]['img'].size()
        if height == target_height and width == target_width:
            continue

        scaled_size = [width, height]

        if trans_dict['align_method'] in ['only_scale', 'scale_and_pad']:
            w_scale_ratio = target_width / width
            h_scale_ratio = target_height / height
            if trans_dict['align_method'] == 'scale_and_pad':
                w_scale_ratio = min(w_scale_ratio, h_scale_ratio)
                h_scale_ratio = w_scale_ratio

            scaled_size = (int(round(width * w_scale_ratio)), int(round(height * h_scale_ratio)))
            if 'meta' in data_keys and 'border_size' in batch[i]['meta'].data:
                batch[i]['meta'].data['border_size'] = scaled_size

            scaled_size_hw = (scaled_size[1], scaled_size[0])
            batch[i]['img'] = DataContainer(
                F.interpolate(batch[i]['img'].data.unsqueeze(0),
                              scaled_size_hw, mode='bilinear',
                              align_corners=True).squeeze(0),
                stack=True)
            if 'labelmap' in data_keys:
                labelmap = batch[i]['labelmap'].data.unsqueeze(0).unsqueeze(0).float()
                labelmap = F.interpolate(labelmap, scaled_size_hw, mode='nearest').long().squeeze(0).squeeze(0)
                batch[i]['labelmap'] = DataContainer(labelmap, stack=True)

            if 'maskmap' in data_keys:
                maskmap = batch[i]['maskmap'].data.unsqueeze(0).unsqueeze(0).float()
                maskmap = F.interpolate(maskmap, scaled_size_hw, mode='nearest').long().squeeze(0).squeeze(0)
                batch[i]['maskmap'] = DataContainer(maskmap, stack=True)

        pad_width = target_width - scaled_size[0]
        pad_height = target_height - scaled_size[1]
        assert pad_height >= 0 and pad_width >= 0
        if pad_width > 0 or pad_height > 0:
            assert trans_dict['align_method'] in ['only_pad', 'scale_and_pad']
            left_pad = 0
            up_pad = 0
            if 'pad_mode' not in trans_dict or trans_dict['pad_mode'] == 'random':
                left_pad = random.randint(0, pad_width)  # pad_left
                up_pad = random.randint(0, pad_height)  # pad_up

            elif trans_dict['pad_mode'] == 'pad_left_up':
                left_pad = pad_width
                up_pad = pad_height

            elif trans_dict['pad_mode'] == 'pad_right_down':
                left_pad = 0
                up_pad = 0

            elif trans_dict['pad_mode'] == 'pad_center':
                left_pad = pad_width // 2
                up_pad = pad_height // 2

            elif trans_dict['pad_mode'] == 'pad_border':
                if random.randint(0, 1) == 0:
                    left_pad = pad_width
                    up_pad = pad_height
                else:
                    left_pad = 0
                    up_pad = 0
            else:
                Log.error('Invalid pad mode: {}'.format(trans_dict['pad_mode']))
                exit(1)

            pad = (left_pad, pad_width-left_pad, up_pad, pad_height-up_pad)

            batch[i]['img'] = DataContainer(F.pad(batch[i]['img'].data, pad=pad, value=0), stack=True)

            if 'labelmap' in data_keys:
                batch[i]['labelmap'] = DataContainer(F.pad(batch[i]['labelmap'].data, pad=pad, value=-1), stack=True)

            if 'maskmap' in data_keys:
                batch[i]['maskmap'] = DataContainer(F.pad(batch[i]['maskmap'].data, pad=pad, value=1), stack=True)

    return dict({key: stack(batch, data_key=key) for key in data_keys})