Exemplo n.º 1
0
class Tester(object):
    def __init__(self, configer):
        """
          Prepare the offset tester: rewrite the evaluation-related config
          entries, create the helper objects and load the network.

          Args:
            configer: project configuration object; this method uses its
                ``get``, ``update`` and ``params_root`` members.
        """
        # Read the crop size from the 'train' transformer *before* that entry
        # is replaced by the 'val' transformer further down.
        self.crop_size = configer.get('train',
                                      'data_transformer')['input_size']
        # Keep only the validation transforms that do not contain 'random'.
        val_trans_seq = [
            x for x in configer.get('val_trans', 'trans_seq')
            if 'random' not in x
        ]
        configer.update(('val_trans', 'trans_seq'), val_trans_seq)
        # The 'val' transformer takes its input_size from the 'test' section
        # (None when the 'test' section does not define one) ...
        configer.get('val', 'data_transformer')['input_size'] = configer.get(
            'test', 'data_transformer').get('input_size', None)
        # ... and the 'train' transformer entry is then set to the 'val' one.
        configer.update(('train', 'data_transformer'),
                        configer.get('val', 'data_transformer'))
        # Batch size for both splits comes from the `batch_size` environment
        # variable and defaults to 16.
        configer.update(('val', 'batch_size'),
                        int(os.environ.get('batch_size', 16)))
        configer.update(('test', 'batch_size'),
                        int(os.environ.get('batch_size', 16)))

        self.save_dir = configer.get('test', 'out_dir')
        self.dataset_name = configer.get('test', 'eval_set')
        self.sscrop = configer.get('test', 'sscrop')

        self.configer = configer
        self.batch_time = AverageMeter()
        self.data_time = AverageMeter()
        self.loss_manager = LossManager(configer)
        self.module_runner = ModuleRunner(configer)
        self.model_manager = ModelManager(configer)
        self.seg_data_loader = DataLoader(configer)
        # Filled in by _init_model().
        self.seg_net = None
        self.test_loader = None
        self.test_size = None
        self.infer_time = 0
        self.infer_cnt = 0
        self._init_model()

        # Dump the effective configuration after all overrides above.
        pprint.pprint(configer.params_root)

    def _init_model(self):
        """
          Build and load the network, then create the data loader for the
          split named by ``self.dataset_name`` ('train', 'val' or 'test').
        """
        net = self.model_manager.semantic_segmentor()
        self.seg_net = self.module_runner.load_net(net)

        split = self.dataset_name
        assert split in ('train', 'val', 'test'), 'Cannot infer dataset name'

        transformer_cfg = self.configer.get(split, 'data_transformer')
        self.size_mode = transformer_cfg['size_mode']

        # Only the 'test' split goes through the test loader.
        if split == 'test':
            loader = self.seg_data_loader.get_testloader(split)
        else:
            loader = self.seg_data_loader.get_valloader(split)
        self.test_loader = loader

        # The 'val' batch size is used for every split.
        batch_size = self.configer.get('val', 'batch_size')
        self.test_size = len(loader) * batch_size

    def test(self, data_loader=None):
        """
          Run the network over ``self.test_loader`` and write one offset map
          per image into ``self.save_dir`` as ``<name>.mat`` (stored under the
          key ``'mat'``).

          Args:
            data_loader: unused; kept so existing callers keep working.
        """
        self.seg_net.eval()
        start_time = time.time()
        image_id = 0

        Log.info('save dir {}'.format(self.save_dir))
        FileHelper.make_dirs(self.save_dir, is_file=False)
        dest_dir = self.save_dir

        print('Total batches', len(self.test_loader))
        for batch in self.test_loader:
            inputs = [batch['img']]
            names = batch['name']
            metas = batch['meta']

            with torch.no_grad():
                offsets, logits = self.extract_offset(inputs)
                print([x.shape for x in logits])
                for k in range(len(inputs[0])):
                    image_id += 1
                    ori_img_size = metas[k]['ori_img_size']
                    border_size = metas[k]['border_size']
                    offset = offsets[k].squeeze().cpu().numpy()
                    # Crop to border_size (presumably the un-padded extent,
                    # given as (width, height) -- TODO confirm), then resize
                    # to the original image size.
                    offset = cv2.resize(
                        offset[:border_size[1], :border_size[0]],
                        tuple(ori_img_size),
                        interpolation=cv2.INTER_NEAREST)
                    print(image_id)

                    # Replace the extension with '.mat'; names without a dot
                    # simply get '.mat' appended.
                    stem = names[k].rpartition('.')[0] or names[k]
                    dest_name = os.path.join(dest_dir, stem + '.mat')
                    # The name may contain sub-directories, so create the
                    # actual parent directory of the output file.
                    parent = os.path.dirname(dest_name)
                    if parent:
                        os.makedirs(parent, exist_ok=True)
                    print('Shape:', offset.shape, 'Saving to', dest_name)

                    # Separate name: do not shadow the batch dictionary.
                    mat_payload = {'mat': offset}

                    scipy.io.savemat(dest_name,
                                     mat_payload,
                                     do_compression=True)
                    # Best effort: if the compressed file cannot be read back,
                    # rewrite it without compression.
                    try:
                        scipy.io.loadmat(dest_name)
                    except Exception as e:
                        print(e)
                        scipy.io.savemat(dest_name,
                                         mat_payload,
                                         do_compression=False)

            self.batch_time.update(time.time() - start_time)
            start_time = time.time()

        Log.info('Test Time {batch_time.sum:.3f}s'.format(
            batch_time=self.batch_time))

    def extract_offset(self, inputs):
        if self.sscrop:
            outputs = self.sscrop_test(inputs, self.crop_size)
        elif self.configer.get('test', 'mode') == 'ss_test':
            outputs = self.ss_test(inputs)

        offsets = []
        logits = []

        for mask_logits, dir_logits, img in zip(*outputs[:2], inputs[0]):
            h, w = img.shape[1:]

            mask_logits = F.interpolate(mask_logits.unsqueeze(0),
                                        size=(h, w),
                                        mode='bilinear',
                                        align_corners=True)
            dir_logits = F.interpolate(dir_logits.unsqueeze(0),
                                       size=(h, w),
                                       mode='bilinear',
                                       align_corners=True)

            logit = torch.softmax(dir_logits, dim=1)
            zero_mask = mask_logits.argmax(dim=1, keepdim=True) == 0
            logits.append(mask_logits[:, 1])

            offset = self._get_offset(mask_logits, dir_logits)
            offsets.append(offset)
        print([x.shape for x in offsets])
        return offsets, logits

    def _get_offset(self, mask_logits, dir_logits):
        """
          Turn direction predictions into offset vectors, zeroing every
          position whose mask channel 1 is not above 0.5.
        """
        is_edge = mask_logits[:, 1] > 0.5
        dir_probs = torch.softmax(dir_logits, dim=1)
        # Unpacking requires a 4-D tensor, exactly as before.
        _batch, _channels, _height, _width = dir_probs.shape

        direction = dir_probs.argmax(dim=1).float()
        vectors = DTOffsetHelper.label_to_vector(direction)
        vectors = vectors.permute(0, 2, 3, 1)
        vectors[~is_edge, :] = 0
        return vectors

    def _flip(self, x, dim=-1):
        indices = [slice(None)] * x.dim()
        indices[dim] = torch.arange(x.size(dim) - 1,
                                    -1,
                                    -1,
                                    dtype=torch.long,
                                    device=x.device)
        return x[tuple(indices)]

    def _flip_offset(self, x):
        """
          Flip ``x`` along its last dimension and reorder its channels with
          ``DTOffsetHelper.flipping_indices()`` (dim 1 for 4-D input, dim 0
          otherwise).
        """
        mirrored = self._flip(x, dim=-1)
        channel_order = DTOffsetHelper.flipping_indices()
        if mirrored.dim() == 4:
            return mirrored[:, channel_order]
        return mirrored[channel_order]

    def _flip_inputs(self, inputs):

        if self.size_mode == 'fix_size':
            return [self._flip(x, -1) for x in inputs]
        else:
            return [[self._flip(x, -1) for x in xs] for xs in inputs]

    def _flip_outputs(self, outputs):
        funcs = [self._flip, self._flip_offset]
        if self.size_mode == 'fix_size':
            return [f(x) for f, x in zip(funcs, outputs)]
        else:
            return [[f(x) for x in xs] for f, xs in zip(funcs, outputs)]

    def _tuple_sum(self, tup1, tup2, tup2_weight=1):
        """
        tup1 / tup2: tuple of tensors or tuple of list of tensors
        """

        if tup1 is None:
            if self.size_mode == 'fix_size':
                return [y * tup2_weight for y in tup2]
            else:
                return [[y * tup2_weight for y in ys] for ys in tup2]
        else:
            if self.size_mode == 'fix_size':
                return [x + y * tup2_weight for x, y in zip(tup1, tup2)]
            else:
                return [[x + y * tup2_weight for x, y in zip(xs, ys)]
                        for xs, ys in zip(tup1, tup2)]

    def _scale_ss_inputs(self, inputs, scale):
        n, c, h, w = inputs[0].shape
        size = (int(h * scale), int(w * scale))
        return [
            F.interpolate(inputs[0],
                          size=size,
                          mode="bilinear",
                          align_corners=True),
        ], (h, w)

    def sscrop_test(self, inputs, crop_size, scale=1):
        '''
        Sliding-window inference: run ss_test on overlapping crops and average
        the predictions where crops overlap.

        Currently, sscrop_test does not support diverse_size testing.

        inputs: list whose first element is the image batch (4-D tensor).
        crop_size: (crop_height, crop_width).
        scale: accepted but not used by this implementation.

        Returns a list with two tensors (2 and 8 channels) at the input size.
        '''
        scaled_inputs = inputs
        img = scaled_inputs[0]
        n, c, h, w = img.size(0), img.size(1), img.size(2), img.size(3)
        ori_h, ori_w = h, w
        # Accumulators for the two outputs (2-channel and 8-channel) and the
        # number of crops that covered each pixel. Allocated on the GPU.
        full_probs = [
            torch.cuda.FloatTensor(n, dim, h, w).fill_(0) for dim in (2, 8)
        ]
        count_predictions = [
            torch.cuda.FloatTensor(n, dim, h, w).fill_(0) for dim in (2, 8)
        ]

        crop_counter = 0

        # Top-left corners of the crops along each axis.
        height_starts = self._decide_intersection(h, crop_size[0])
        width_starts = self._decide_intersection(w, crop_size[1])

        for height in height_starts:
            for width in width_starts:
                crop_inputs = [
                    x[..., height:height + crop_size[0],
                      width:width + crop_size[1]] for x in scaled_inputs
                ]
                prediction = self.ss_test(crop_inputs)

                # Add the crop prediction into the window it came from.
                for j in range(2):
                    count_predictions[j][:, :, height:height + crop_size[0],
                                         width:width + crop_size[1]] += 1
                    full_probs[j][:, :, height:height + crop_size[0],
                                  width:width + crop_size[1]] += prediction[j]
                crop_counter += 1
                Log.info('predicting {:d}-th crop'.format(crop_counter))

        for j in range(2):
            # Average over the number of crops covering each pixel.
            full_probs[j] /= count_predictions[j]
            # (ori_h, ori_w) equals (h, w) here, so the size is unchanged.
            full_probs[j] = F.interpolate(full_probs[j],
                                          size=(ori_h, ori_w),
                                          mode='bilinear',
                                          align_corners=True)
        return full_probs

    def _scale_ss_outputs(self, outputs, size):
        return [
            F.interpolate(x, size=size, mode="bilinear", align_corners=True)
            for x in outputs
        ]

    def ss_test(self, inputs, scale=1):
        """
          Single-scale inference.

          With size_mode 'fix_size' the whole batch goes through
          ``self.seg_net`` at once; otherwise each sample is resized on its
          own and run on a separate replica of ``self.seg_net.module``.

          Args:
            inputs: list of input batches ('fix_size') or list of per-sample
                lists (other size modes).
            scale: resize factor applied before the forward pass.

          Returns:
            The network outputs resized back to the original input size.
        """
        if self.size_mode == 'fix_size':

            scaled_inputs, orig_size = self._scale_ss_inputs(inputs, scale)
            print([x.shape for x in scaled_inputs])

            # start/end bracket the forward pass but are not used afterwards.
            start = timeit.default_timer()
            outputs = list(self.seg_net.forward(*scaled_inputs))
            if len(outputs) == 3:
                # Three outputs: keep the first and the last one as they are.
                outputs = (outputs[0], outputs[2])
            else:
                # Otherwise the first output is turned into probabilities.
                outputs[0] = F.softmax(outputs[0], dim=1)
            torch.cuda.synchronize()
            end = timeit.default_timer()

            return self._scale_ss_outputs(outputs, orig_size)

        else:
            device_ids = self.configer.get('gpu')
            replicas = nn.parallel.replicate(self.seg_net.module, device_ids)
            scaled_inputs, ori_sizes, outputs = [], [], []

            # One sample per device; zip stops at the shorter of the sample
            # list and the device list.
            for *i, d in zip(*inputs, device_ids):
                scaled_i, ori_size_i = self._scale_ss_inputs(
                    [x.unsqueeze(0) for x in i], scale)
                scaled_inputs.append(
                    [x.cuda(d, non_blocking=True) for x in scaled_i])
                ori_sizes.append(ori_size_i)

            scaled_outputs = nn.parallel.parallel_apply(
                replicas[:len(scaled_inputs)], scaled_inputs)

            for o, ori_size in zip(scaled_outputs, ori_sizes):
                o = self._scale_ss_outputs(o, ori_size)
                if len(o) == 3:
                    o = (o[0], o[2])
                outputs.append([x.squeeze(0) for x in o])
            # Transpose from per-sample lists to per-output lists.
            outputs = list(map(list, zip(*outputs)))
            return outputs

    def _decide_intersection(self,
                             total_length,
                             crop_length,
                             crop_stride_ratio=1 / 3):
        stride = int(crop_length *
                     crop_stride_ratio)  # set the stride as the paper do
        times = (total_length - crop_length) // stride + 1
        cropped_starting = []
        for i in range(times):
            cropped_starting.append(stride * i)

        if total_length - cropped_starting[-1] > crop_length:
            cropped_starting.append(total_length -
                                    crop_length)  # must cover the total image

        return cropped_starting
Exemplo n.º 2
0
class ImageTester(object):
    """
      The class for Pose Estimation. Include train, val, val & predict.
    """
    def __init__(self, configer):
        """
          Create the helper objects from ``configer`` and load the network.

          Args:
            configer: project configuration object; ``('test', 'out_dir')``
                is read here as the output directory.
        """
        self.configer = configer
        self.batch_time = AverageMeter()
        self.data_time = AverageMeter()
        self.seg_visualizer = SegVisualizer(configer)
        self.loss_manager = LossManager(configer)
        self.module_runner = ModuleRunner(configer)
        self.model_manager = ModelManager(configer)
        self.optim_scheduler = OptimScheduler(configer)
        self.seg_data_loader = DataLoader(configer)
        self.save_dir = self.configer.get('test', 'out_dir')
        # seg_net is filled in by _init_model(); test_loader and test_size
        # are not assigned there and stay None.
        self.seg_net = None
        self.test_loader = None
        self.test_size = None
        self.infer_time = 0
        self.infer_cnt = 0
        self._init_model()

    def _init_model(self):
        """
          Build the segmentation network, load its weights and switch it to
          evaluation mode.
        """
        network = self.model_manager.semantic_segmentor()
        self.seg_net = self.module_runner.load_net(network)

        print(f"self.save_dir {self.save_dir}")
        # NOTE: no data loader is created here; test() reads a single image
        # from disk, so self.test_loader and self.test_size stay None.

        self.seg_net.eval()

    def __relabel(self, label_map):
        height, width = label_map.shape
        label_dst = np.zeros((height, width), dtype=np.uint8)
        for i in range(self.configer.get('data', 'num_classes')):
            label_dst[label_map == i] = self.configer.get('data', 'label_list')[i]

        label_dst = np.array(label_dst, dtype=np.uint8)

        return label_dst

    def test(self, img_path=None, output_dir=None, data_loader=None):
        """
          Segment a single image and save the predicted label map as
          ``label_<name>.png``.

          Args:
            img_path: image to read; falls back to the configured
                ``'input_image'`` when None.
            output_dir: directory for the result; falls back to
                ``self.save_dir`` when None.
            data_loader: unused; kept so existing callers keep working.

          Raises:
            IOError: if the image cannot be read.
        """
        print("test!!!")
        self.seg_net.eval()
        start_time = time.time()

        Log.info('save dir {}'.format(self.save_dir))
        FileHelper.make_dirs(self.save_dir, is_file=False)

        # NOTE(review): the result is not used in this method; the call is
        # kept in case get_ade_colors() has side effects -- confirm and remove.
        colors = get_ade_colors()

        # Reader.
        if img_path is not None:
            input_path = img_path
        else:
            input_path = self.configer.get('input_image')

        input_image = cv2.imread(input_path)
        if input_image is None:
            # cv2.imread signals failure with None instead of raising, which
            # previously surfaced as an AttributeError a few lines later.
            raise IOError('Cannot read image: {}'.format(input_path))

        transform = trans.Compose([
            trans.ToTensor(),
            trans.Normalize(div_value=self.configer.get('normalize', 'div_value'),
                            mean=self.configer.get('normalize', 'mean'),
                            std=self.configer.get('normalize', 'std')), ])

        aug_val_transform = cv2_aug_transforms.CV2AugCompose(self.configer, split='val')

        # Remember the original size as (width, height), the order expected
        # by cv2.resize.
        h, w, _ = input_image.shape
        ori_img_size = [w, h]

        input_image = aug_val_transform(input_image)
        input_image = input_image[0]
        input_image = transform(input_image)

        with torch.no_grad():
            # Forward pass; a list input makes ss_test return a list.
            outputs = self.ss_test([input_image])

            # Move channels last so OpenCV can resize the result.
            if isinstance(outputs, torch.Tensor):
                outputs = outputs.permute(0, 2, 3, 1).cpu().numpy()
            else:
                outputs = [output.permute(0, 2, 3, 1).cpu().numpy().squeeze() for output in outputs]

            logits = cv2.resize(outputs[0],
                                tuple(ori_img_size), interpolation=cv2.INTER_CUBIC)
            label_img = np.asarray(np.argmax(logits, axis=-1), dtype=np.uint8)
            if self.configer.exists('data', 'reduce_zero_label') and self.configer.get('data', 'reduce_zero_label'):
                label_img = label_img + 1
                label_img = label_img.astype(np.uint8)
            if self.configer.exists('data', 'label_list'):
                label_img_ = self.__relabel(label_img)
            else:
                label_img_ = label_img
            label_img_ = Image.fromarray(label_img_, 'P')

            # splitext keeps the full name for files without an extension
            # (the previous split('.') logic produced an empty name).
            input_name = os.path.splitext(os.path.basename(input_path))[0]
            target_dir = self.save_dir if output_dir is None else output_dir
            label_path = os.path.join(target_dir, 'label_{}.png'.format(input_name))
            FileHelper.make_dirs(label_path, is_file=True)
            ImageHelper.save(label_img_, label_path)

        self.batch_time.update(time.time() - start_time)

        # Print the log info & reset the states.
        Log.info('Test Time {batch_time.sum:.3f}s'.format(batch_time=self.batch_time))


    def offset_test(self, inputs, offset_h_maps, offset_w_maps, scale=1):
        if isinstance(inputs, torch.Tensor):
            n, c, h, w = inputs.size(0), inputs.size(1), inputs.size(2), inputs.size(3)
            start = timeit.default_timer()
            outputs = self.seg_net.forward(inputs, offset_h_maps, offset_w_maps)
            torch.cuda.synchronize()
            end = timeit.default_timer()

            if (self.configer.get('loss', 'loss_type') == "fs_auxce_loss") or (self.configer.get('loss', 'loss_type') == "triple_auxce_loss"):
                outputs = outputs[-1]
            elif self.configer.get('loss', 'loss_type') == "pyramid_auxce_loss":
                outputs = outputs[1] + outputs[2] + outputs[3] + outputs[4]

            outputs = F.interpolate(outputs, size=(h, w), mode='bilinear', align_corners=True)
            return outputs
        else:
            raise RuntimeError("Unsupport data type: {}".format(type(inputs)))


    def ss_test(self, inputs, scale=1):
        if isinstance(inputs, torch.Tensor):
            n, c, h, w = inputs.size(0), inputs.size(1), inputs.size(2), inputs.size(3)
            scaled_inputs = F.interpolate(inputs, size=(int(h*scale), int(w*scale)), mode="bilinear", align_corners=True)
            start = timeit.default_timer()
            outputs = self.seg_net.forward(scaled_inputs)
            torch.cuda.synchronize()
            end = timeit.default_timer()
            outputs = outputs[-1]
            outputs = F.interpolate(outputs, size=(h, w), mode='bilinear', align_corners=True)
            return outputs
        elif isinstance(inputs, collections.Sequence):
            device_ids = self.configer.get('gpu')
            replicas = nn.parallel.replicate(self.seg_net.module, device_ids)
            scaled_inputs, ori_size, outputs = [], [], []
            for i, d in zip(inputs, device_ids):
                h, w = i.size(1), i.size(2)
                ori_size.append((h, w))
                i = F.interpolate(i.unsqueeze(0), size=(int(h*scale), int(w*scale)), mode="bilinear", align_corners=True)
                scaled_inputs.append(i.cuda(d, non_blocking=True))
            scaled_outputs = nn.parallel.parallel_apply(replicas[:len(scaled_inputs)], scaled_inputs)
            for i, output in enumerate(scaled_outputs):
                outputs.append(F.interpolate(output[-1], size=ori_size[i], mode='bilinear', align_corners=True))
            return outputs
        else:
            raise RuntimeError("Unsupport data type: {}".format(type(inputs)))


    def flip(self, x, dim):
        indices = [slice(None)] * x.dim()
        indices[dim] = torch.arange(x.size(dim) - 1, -1, -1,
                                    dtype=torch.long, device=x.device)
        return x[tuple(indices)]


    def sscrop_test(self, inputs, crop_size, scale=1):
        '''
        Sliding-window inference at one scale: resize the batch by `scale`,
        run ss_test on each crop, average overlapping predictions and resize
        the result back to the original input size.

        Currently, sscrop_test does not support diverse_size testing.

        inputs: 4-D tensor.
        crop_size: (crop_height, crop_width).
        '''
        n, c, ori_h, ori_w = inputs.size(0), inputs.size(1), inputs.size(2), inputs.size(3)
        scaled_inputs = F.interpolate(inputs, size=(int(ori_h*scale), int(ori_w*scale)), mode="bilinear", align_corners=True)
        n, c, h, w = scaled_inputs.size(0), scaled_inputs.size(1), scaled_inputs.size(2), scaled_inputs.size(3)
        # Accumulated predictions and per-pixel crop counts, allocated on the GPU.
        full_probs = torch.cuda.FloatTensor(n, self.configer.get('data', 'num_classes'), h, w).fill_(0)
        count_predictions = torch.cuda.FloatTensor(n, self.configer.get('data', 'num_classes'), h, w).fill_(0)

        crop_counter = 0

        # Top-left corners of the crops along each axis.
        height_starts = self._decide_intersection(h, crop_size[0])
        width_starts = self._decide_intersection(w, crop_size[1])

        for height in height_starts:
            for width in width_starts:
                crop_inputs = scaled_inputs[:, :, height:height+crop_size[0], width:width + crop_size[1]]
                prediction = self.ss_test(crop_inputs)
                count_predictions[:, :, height:height+crop_size[0], width:width + crop_size[1]] += 1
                full_probs[:, :, height:height+crop_size[0], width:width + crop_size[1]] += prediction 
                crop_counter += 1
                Log.info('predicting {:d}-th crop'.format(crop_counter))

        # Average over the number of crops that covered each pixel.
        full_probs /= count_predictions
        full_probs = F.interpolate(full_probs, size=(ori_h, ori_w), mode='bilinear', align_corners=True)
        return full_probs


    def ms_test(self, inputs):
        if isinstance(inputs, torch.Tensor):  
            n, c, h, w = inputs.size(0), inputs.size(1), inputs.size(2), inputs.size(3)
            full_probs = torch.cuda.FloatTensor(n, self.configer.get('data', 'num_classes'), h, w).fill_(0)
            if self.configer.exists('test', 'scale_weights'):
                for scale, weight in zip(self.configer.get('test', 'scale_search'), self.configer.get('test', 'scale_weights')):
                    probs = self.ss_test(inputs, scale)
                    flip_probs = self.ss_test(self.flip(inputs, 3), scale)
                    probs = probs + self.flip(flip_probs, 3)
                    full_probs += weight * probs
                return full_probs
            else:
                for scale in self.configer.get('test', 'scale_search'):
                    probs = self.ss_test(inputs, scale)
                    flip_probs = self.ss_test(self.flip(inputs, 3), scale)
                    probs = probs + self.flip(flip_probs, 3)
                    full_probs += probs
                return full_probs

        elif isinstance(inputs, collections.Sequence):
            device_ids = self.configer.get('gpu')
            full_probs = [torch.zeros(1, self.configer.get('data', 'num_classes'), 
                i.size(1), i.size(2)).cuda(device_ids[index], non_blocking=True)
                for index, i in enumerate(inputs)]
            flip_inputs = [self.flip(i, 2) for i in inputs]

            if self.configer.exists('test', 'scale_weights'):
                for scale, weight in zip(self.configer.get('test', 'scale_search'), self.configer.get('test', 'scale_weights')):
                    probs = self.ss_test(inputs, scale)
                    flip_probs = self.ss_test(flip_inputs, scale)
                    for i in range(len(inputs)):
                        full_probs[i] += weight * (probs[i] + self.flip(flip_probs[i], 3))
                return full_probs
            else:
                for scale in self.configer.get('test', 'scale_search'):
                    probs = self.ss_test(inputs, scale)
                    flip_probs = self.ss_test(flip_inputs, scale)
                    for i in range(len(inputs)):
                        full_probs[i] += (probs[i] + self.flip(flip_probs[i], 3))
                return full_probs

        else:
            raise RuntimeError("Unsupport data type: {}".format(type(inputs)))


    def ms_test_depth(self, inputs, names):
        prob_list = []
        scale_list = []

        if isinstance(inputs, torch.Tensor):  
            n, c, h, w = inputs.size(0), inputs.size(1), inputs.size(2), inputs.size(3)
            full_probs = torch.cuda.FloatTensor(n, self.configer.get('data', 'num_classes'), h, w).fill_(0)

            for scale in self.configer.get('test', 'scale_search'):
                probs = self.ss_test(inputs, scale)
                flip_probs = self.ss_test(self.flip(inputs, 3), scale)
                probs = probs + self.flip(flip_probs, 3)
                prob_list.append(probs)
                scale_list.append(scale)

            full_probs = self.fuse_with_depth(prob_list, scale_list, names)
            return full_probs

        else:
            raise RuntimeError("Unsupport data type: {}".format(type(inputs)))


    def fuse_with_depth(self, probs, scales, names):
        """
          Fuse per-scale predictions with per-pixel weights derived from a
          stereo map read from disk for each image.

          Args:
            probs: list of 4-D prediction tensors, one per scale.
            scales: the scales the entries of ``probs`` were computed at.
            names: image names; ``<stereo_path><name>.png`` is read for each.

          Returns:
            Tensor with the weighted sum over scales for every image.
        """
        MAX_DEPTH = 63
        POWER_BASE = 0.8
        # Hard-coded dataset locations, selected by the output directory name.
        if 'test' in self.save_dir:
            stereo_path = "/msravcshare/dataset/cityscapes/stereo/test/"
        else:
            stereo_path = "/msravcshare/dataset/cityscapes/stereo/val/"

        n, c, h, w = probs[0].size(0), probs[0].size(1), probs[0].size(2), probs[0].size(3)
        full_probs = torch.cuda.FloatTensor(n, self.configer.get('data', 'num_classes'), h, w).fill_(0)

        for index, name in enumerate(names):
            # Flag -1 reads the file unchanged.
            # NOTE(review): cv2.imread returns None when the file is missing,
            # which would fail on the division below -- confirm inputs exist.
            stereo_map = cv2.imread(stereo_path + name + '.png', -1)
            depth_map = stereo_map / 256.0
            # NOTE(review): zero entries in the stereo map divide by zero
            # here; the np.clip below then caps the result at MAX_DEPTH.
            depth_map = 0.5 / depth_map
            depth_map = 500 * depth_map

            # Quantise the clipped values into len(scales)-wide buckets.
            depth_map = np.clip(depth_map, 0, MAX_DEPTH)
            depth_map = depth_map // (MAX_DEPTH // len(scales))

            for prob, scale in zip(probs, scales):
                scale_index = self._locate_scale_index(scale, scales)
                # Weight decays as POWER_BASE ** |bucket - scale index|.
                weight_map = np.abs(depth_map - scale_index)
                weight_map = np.power(POWER_BASE, weight_map)
                weight_map = cv2.resize(weight_map, (w, h))
                full_probs[index, :, :, :] += torch.from_numpy(np.expand_dims(weight_map, axis=0)).type(torch.cuda.FloatTensor) * prob[index, :, :, :]

        return full_probs

    @staticmethod
    def _locate_scale_index(scale, scales):
        for idx, s in enumerate(scales):
            if scale == s:
                return idx
        return 0


    def ms_test_wo_flip(self, inputs):
        if isinstance(inputs, torch.Tensor):  
            n, c, h, w = inputs.size(0), inputs.size(1), inputs.size(2), inputs.size(3)
            full_probs = torch.cuda.FloatTensor(n, self.configer.get('data', 'num_classes'), h, w).fill_(0)
            for scale in self.configer.get('test', 'scale_search'):
                probs = self.ss_test(inputs, scale)
                full_probs += probs
            return full_probs
        elif isinstance(inputs, collections.Sequence):
            device_ids = self.configer.get('gpu')
            full_probs = [torch.zeros(1, self.configer.get('data', 'num_classes'), 
                i.size(1), i.size(2)).cuda(device_ids[index], non_blocking=True)
                for index, i, in enumerate(inputs)]
            for scale in self.configer.get('test', 'scale_search'):
                probs = self.ss_test(inputs, scale)
                for i in range(len(inputs)):
                    full_probs[i] += probs[i]
            return full_probs
        else:
            raise RuntimeError("Unsupport data type: {}".format(type(inputs)))


    def mscrop_test(self, inputs, crop_size):
        '''
        Multi-scale, flip-augmented inference that uses whole-image testing
        for scales below 1 and sliding-window testing otherwise.

        Currently, mscrop_test does not support diverse_size testing
        '''
        n, c, h, w = inputs.size(0), inputs.size(1), inputs.size(2), inputs.size(3)
        num_classes = self.configer.get('data', 'num_classes')
        full_probs = torch.cuda.FloatTensor(n, num_classes, h, w).fill_(0)

        for scale in self.configer.get('test', 'scale_search'):
            Log.info('Scale {0:.2f} prediction'.format(scale))
            if scale < 1:
                # Small scales fit in one pass.
                probs = self.ss_test(inputs, scale)
                flip_probs = self.ss_test(self.flip(inputs, 3), scale)
            else:
                probs = self.sscrop_test(inputs, crop_size, scale)
                flip_probs = self.sscrop_test(self.flip(inputs, 3), crop_size, scale)
            # Un-flip the mirrored prediction before adding it.
            combined = probs + self.flip(flip_probs, 3)
            full_probs += combined
        return full_probs


    def _decide_intersection(self, total_length, crop_length):
        stride = crop_length
        times = (total_length - crop_length) // stride + 1
        cropped_starting = []
        for i in range(times):
            cropped_starting.append(stride*i)
        if total_length - cropped_starting[-1] > crop_length:
            cropped_starting.append(total_length - crop_length)  # must cover the total image
        return cropped_starting


    def dense_crf_process(self, images, outputs):
        '''
        Refine the predictions of every image in the batch with a dense CRF.

        Reference: https://github.com/kazuto1011/deeplab-pytorch/blob/master/libs/utils/crf.py

        images: batch of images; the per-channel mean below is added back
            before they are handed to the CRF.
        outputs: network outputs; softmax is applied over dim 1 here.

        Returns the softmaxed outputs with each entry replaced by the CRF
        inference result.
        '''
        # hyperparameters of the dense crf 
        # baseline = 79.5
        # bi_xy_std = 67, 79.1
        # bi_xy_std = 20, 79.6
        # bi_xy_std = 10, 79.7
        # bi_xy_std = 10, iter_max = 20, v4 79.7
        # bi_xy_std = 10, iter_max = 5, v5 79.7
        # bi_xy_std = 5, v3 79.7
        iter_max = 10
        pos_w = 3
        pos_xy_std = 1
        bi_w = 4
        bi_xy_std = 10
        bi_rgb_std = 3

        b = images.size(0)
        # Per-channel mean reshaped to (3, 1, 1) so it broadcasts over H and W.
        mean_vector = np.expand_dims(np.expand_dims(np.transpose(np.array([102.9801, 115.9465, 122.7717])), axis=1), axis=2)
        outputs = F.softmax(outputs, dim=1)
        for i in range(b):
            unary = outputs[i].data.cpu().numpy()
            C, H, W = unary.shape
            unary = dcrf_utils.unary_from_softmax(unary)
            unary = np.ascontiguousarray(unary)
            
            # Undo the mean subtraction and convert to an HxWxC byte image.
            # NOTE(review): assumes images[i] is convertible by numpy (i.e.
            # a CPU tensor) -- confirm against the caller.
            image = np.ascontiguousarray(images[i]) + mean_vector
            image = image.astype(np.ubyte)
            image = np.ascontiguousarray(image.transpose(1, 2, 0))

            d = dcrf.DenseCRF2D(W, H, C)
            d.setUnaryEnergy(unary)
            d.addPairwiseGaussian(sxy=pos_xy_std, compat=pos_w)
            d.addPairwiseBilateral(sxy=bi_xy_std, srgb=bi_rgb_std, rgbim=image, compat=bi_w)
            out_crf = np.array(d.inference(iter_max))
            # Write the refined result back in place of the softmax output.
            outputs[i] = torch.from_numpy(out_crf).cuda().view(C, H, W)

        return outputs
    

    def visualize(self, label_img):
        """
          Colourise a single-channel label image with the colours from
          ``HYUNDAI_POC_CATEGORIES`` and return it in BGR order. Label 14 is
          reset to 0 before colouring.
        """
        gray_as_rgb = cv2.cvtColor(label_img.copy(), cv2.COLOR_GRAY2RGB)
        ignored_id = 14
        gray_as_rgb[gray_as_rgb == ignored_id] = 0

        painted = gray_as_rgb.copy()
        # All three channels hold the label id; the first one is used for
        # the comparison.
        id_plane = gray_as_rgb[:, :, 0]
        for category in HYUNDAI_POC_CATEGORIES:
            selected = id_plane == category['id']
            painted[:, :, :3][selected] = category['color']

        return cv2.cvtColor(painted, cv2.COLOR_RGB2BGR)
    
    def get_ratio_all(self, anno_img):
        """
          Compute the percentage of counted pixels that belongs to each
          category in ``HYUNDAI_POC_CATEGORIES``.

          Returns:
            List of ``[name, percentage]`` pairs; every percentage is 0 when
            no pixel was counted at all.
        """
        rows = [[category['name'], self.get_ratio(anno_img.copy(), category)]
                for category in HYUNDAI_POC_CATEGORIES]
        grand_total = sum(count for _, count in rows)
        for row in rows:
            row[1] = row[1] / grand_total * 100 if grand_total else 0
        return rows
    
    def get_ratio(self, anno_img, label):
        total = 0
        label_id = label['id']
        if label_id == 14:
            return total
        label_img = (anno_img == label_id).astype(np.uint8)
        # label_img = cv2.cvtColor(label_img, cv2.COLOR_BGR2GRAY)
        total = np.count_nonzero(label_img)
        return total

    def visualize_ratio(self, ratios, video_size, ratio_w):
        """
          Render the ratio table as an image: a header row followed by one
          row per entry, with the category colour behind the name column.

          Args:
            ratios: list of ``[name, value]`` pairs; ``value`` is drawn as is
                when it is a string and with two decimals otherwise.
            video_size: only ``video_size[0]`` is used, as the image height.
            ratio_w: width of the rendered image.

          Returns:
            The rendered table as a BGR uint8 array.
        """
        ratio_list = ratios.copy()
        ratio_list.insert(0, ['등급','비율'])
        RATIO_IMG_W = ratio_w
        RATIO_IMG_H = int(video_size[0])
        TEXT_MARGIN_H = 20
        TEXT_MARGIN_W = 10
        row_count = 14

        ratio_img = np.full((RATIO_IMG_H, RATIO_IMG_W, 3), 255, np.uint8)

        row_h = RATIO_IMG_H / row_count
        center_w = RATIO_IMG_W / 2

        # Coloured background of the name column: row i shows category id i.
        for i in range(1, row_count):
            p_y = int(i * row_h)
            p_y_n = int((i+1) * row_h)
            for label in HYUNDAI_POC_CATEGORIES:
                if label['id'] == i:
                    cv2.rectangle(ratio_img, (0, p_y), (int(center_w), p_y_n), label['color'], cv2.FILLED)

        # Horizontal row separators and the vertical column separator.
        for i in range(1, row_count):
            p_y = int(i * row_h)
            cv2.line(ratio_img, (0, p_y), (RATIO_IMG_W, p_y), (0,0,0))

        cv2.line(ratio_img, (int(center_w), 0), (int(center_w), RATIO_IMG_H), (0,0,0))

        # Text is drawn with PIL. Convert once and load the font once instead
        # of repeating both for every row.
        canvas = Image.fromarray(ratio_img)
        font = ImageFont.truetype("NanumGothic.ttf", 15)
        draw = ImageDraw.Draw(canvas)
        color = (0, 0, 0)
        p_w = int(center_w) + TEXT_MARGIN_W
        # Slicing keeps at most row_count rows and avoids an IndexError when
        # fewer entries than rows are supplied.
        for i, row in enumerate(ratio_list[:row_count]):
            p_y = int(i * row_h) + TEXT_MARGIN_H
            draw.text((0, p_y), row[0], font=font, fill=color)
            if isinstance(row[1], str):
                draw.text((p_w, p_y), row[1],font=font,fill=color)
            else:
                draw.text((p_w, p_y), "{:.02f}".format(row[1]),font=font,fill=color)
        ratio_img = np.array(canvas)

        ratio_img = cv2.cvtColor(ratio_img, cv2.COLOR_RGB2BGR)
        return ratio_img
Exemplo n.º 3
0
class Tester(object):
    """
      The class for Pose Estimation. Include train, val, val & predict.
    """
    def __init__(self, configer):
        """
          Create the helper objects used for testing and load the network.

          Args:
            configer: Project configuration object. It is stored on the
              instance and passed to every helper constructed here.
        """
        self.configer = configer
        # Timing meters; batch_time is updated once per batch in test().
        self.batch_time = AverageMeter()
        self.data_time = AverageMeter()
        # Project helpers, all configured from the same configer.
        self.seg_visualizer = SegVisualizer(configer)
        self.loss_manager = LossManager(configer)
        self.module_runner = ModuleRunner(configer)
        self.model_manager = ModelManager(configer)
        self.optim_scheduler = OptimScheduler(configer)
        self.seg_data_loader = DataLoader(configer)
        # Output directory; several methods also check whether it contains
        # 'test' or 'val' to decide which data split is being processed.
        self.save_dir = self.configer.get('test', 'out_dir')
        # Filled in by _init_model() below.
        self.seg_net = None
        self.test_loader = None
        self.test_size = None
        # Initialized to zero here; no method of this class updates them.
        self.infer_time = 0
        self.infer_cnt = 0
        self._init_model()

    def _init_model(self):
        self.seg_net = self.model_manager.semantic_segmentor()
        self.seg_net = self.module_runner.load_net(self.seg_net)

        if 'test' in self.save_dir:
            self.test_loader = self.seg_data_loader.get_testloader()
            self.test_size = len(self.test_loader) * self.configer.get(
                'test', 'batch_size')
        else:
            self.test_loader = self.seg_data_loader.get_valloader()
            self.test_size = len(self.test_loader) * self.configer.get(
                'val', 'batch_size')

        self.seg_net.eval()

    def __relabel(self, label_map):
        height, width = label_map.shape
        label_dst = np.zeros((height, width), dtype=np.uint8)
        for i in range(self.configer.get('data', 'num_classes')):
            label_dst[label_map == i] = self.configer.get(
                'data', 'label_list')[i]

        label_dst = np.array(label_dst, dtype=np.uint8)

        return label_dst

    def test(self, data_loader=None):
        """
          Run inference over ``self.test_loader`` and write the results to disk.

          For every image the following files are written below
          ``self.save_dir``:
            * ``label/<name>.png``: the predicted label map (after the optional
              ``reduce_zero_label`` shift and ``label_list`` remapping);
            * ``vis/<name>.png``: the palette-colored label map, or
              ``gt_vis/<name>.png`` instead when the ``save_gt_label``
              environment variable is set;
            * ``prob/<name>.npy``: the per-pixel softmax, only when
              ``test.save_prob`` is enabled.

          Args:
            data_loader: Unused; the loader selected in ``_init_model`` is
              always the one iterated. Kept for interface compatibility.

          Raises:
            RuntimeError: If the dataset has no known color palette, if
              ``test.mode`` is not a supported mode, or if ground-truth labels
              are required but were not loaded.
        """
        self.seg_net.eval()
        start_time = time.time()
        image_id = 0

        Log.info('save dir {}'.format(self.save_dir))
        FileHelper.make_dirs(self.save_dir, is_file=False)

        dataset = self.configer.get('dataset')
        if dataset in ['cityscapes', 'gta5']:
            colors = get_cityscapes_colors()
        elif dataset == 'ade20k':
            colors = get_ade_colors()
        elif dataset == 'lip':
            colors = get_lip_colors()
        elif dataset == 'pascal_context':
            colors = get_pascal_context_colors()
        elif dataset == 'pascal_voc':
            colors = get_pascal_voc_colors()
        elif dataset == 'coco_stuff':
            colors = get_cocostuff_colors()
        else:
            raise RuntimeError("Unsupport colors")

        save_prob = self.configer.get('test', 'save_prob')

        def softmax(x, axis=0):
            # Numerically stable softmax that leaves its input untouched.
            shifted = x - np.max(x, axis=axis, keepdims=True)
            exp = np.exp(shifted)
            return exp / np.sum(exp, axis=axis, keepdims=True)

        # Settings that do not change while iterating: read them once.
        reduce_zero_label = self.configer.exists(
            'data', 'reduce_zero_label') and self.configer.get(
                'data', 'reduce_zero_label')
        has_label_list = self.configer.exists('data', 'label_list')
        use_offline_offset = self.configer.exists(
            'data', 'use_offset') and self.configer.get(
                'data', 'use_offset') == 'offline'
        save_gt_label = os.environ.get('save_gt_label')

        for data_dict in self.test_loader:
            inputs = data_dict['img']
            names = data_dict['name']
            metas = data_dict['meta']

            # Ground truth is only read for the validation split.
            labels = None
            if save_gt_label and 'val' in self.save_dir:
                labels = data_dict['labelmap']

            with torch.no_grad():
                # Forward pass.
                if use_offline_offset:
                    outputs = self.offset_test(inputs,
                                               data_dict['offsetmap_h'],
                                               data_dict['offsetmap_w'])
                else:
                    mode = self.configer.get('test', 'mode')
                    if mode == 'ss_test':
                        outputs = self.ss_test(inputs)
                    elif mode == 'ms_test':
                        outputs = self.ms_test(inputs)
                    elif mode == 'ms_test_depth':
                        outputs = self.ms_test_depth(inputs, names)
                    elif mode == 'sscrop_test':
                        crop_size = self.configer.get('test', 'crop_size')
                        outputs = self.sscrop_test(inputs, crop_size)
                    elif mode == 'mscrop_test':
                        crop_size = self.configer.get('test', 'crop_size')
                        outputs = self.mscrop_test(inputs, crop_size)
                    elif mode == 'crf_ss_test':
                        outputs = self.ss_test(inputs)
                        outputs = self.dense_crf_process(inputs, outputs)
                    else:
                        # Without this branch an unknown mode surfaced later
                        # as a NameError on ``outputs``.
                        raise RuntimeError(
                            "Unsupport test mode: {}".format(mode))

                # Channels-last numpy arrays, one entry per image.
                if isinstance(outputs, torch.Tensor):
                    outputs = outputs.permute(0, 2, 3, 1).cpu().numpy()
                    n = outputs.shape[0]
                else:
                    outputs = [
                        output.permute(0, 2, 3, 1).cpu().numpy().squeeze()
                        for output in outputs
                    ]
                    n = len(outputs)

                for k in range(n):
                    image_id += 1
                    ori_img_size = metas[k]['ori_img_size']
                    border_size = metas[k]['border_size']
                    # Crop to the region given by border_size, then resize
                    # to the original image size.
                    logits = cv2.resize(
                        outputs[k][:border_size[1], :border_size[0]],
                        tuple(ori_img_size),
                        interpolation=cv2.INTER_CUBIC)

                    # save the logits map
                    if save_prob:
                        prob_path = os.path.join(self.save_dir, "prob/",
                                                 '{}.npy'.format(names[k]))
                        FileHelper.make_dirs(prob_path, is_file=True)
                        np.save(prob_path, softmax(logits, axis=-1))

                    label_img = np.asarray(np.argmax(logits, axis=-1),
                                           dtype=np.uint8)
                    if reduce_zero_label:
                        label_img = label_img + 1
                        label_img = label_img.astype(np.uint8)
                    if has_label_list:
                        label_img_ = self.__relabel(label_img)
                    else:
                        label_img_ = label_img
                    label_img_ = Image.fromarray(label_img_, 'P')
                    Log.info('{:4d}/{:4d} label map generated'.format(
                        image_id, self.test_size))
                    label_path = os.path.join(self.save_dir, "label/",
                                              '{}.png'.format(names[k]))
                    FileHelper.make_dirs(label_path, is_file=True)
                    ImageHelper.save(label_img_, label_path)

                    # colorize the label-map
                    if save_gt_label:
                        if reduce_zero_label:
                            if labels is None:
                                raise RuntimeError(
                                    "save_gt_label needs ground-truth labels, "
                                    "which are only loaded when 'val' is in "
                                    "the save dir: {}".format(self.save_dir))
                            label_img = np.asarray(labels[k] + 1,
                                                   dtype=np.uint8)
                        # NOTE(review): without reduce_zero_label the image
                        # written to gt_vis/ is still the prediction, not the
                        # ground truth. Confirm this is intended.
                        vis_dir = "gt_vis/"
                    else:
                        vis_dir = "vis/"

                    color_img_ = Image.fromarray(label_img)
                    color_img_.putpalette(colors)
                    vis_path = os.path.join(self.save_dir, vis_dir,
                                            '{}.png'.format(names[k]))
                    FileHelper.make_dirs(vis_path, is_file=True)
                    ImageHelper.save(color_img_, save_path=vis_path)

            self.batch_time.update(time.time() - start_time)
            start_time = time.time()

        # Print the log info & reset the states.
        Log.info('Test Time {batch_time.sum:.3f}s'.format(
            batch_time=self.batch_time))

    def offset_test(self, inputs, offset_h_maps, offset_w_maps, scale=1):
        if isinstance(inputs, torch.Tensor):
            n, c, h, w = inputs.size(0), inputs.size(1), inputs.size(
                2), inputs.size(3)
            start = timeit.default_timer()
            outputs = self.seg_net.forward(inputs, offset_h_maps,
                                           offset_w_maps)
            torch.cuda.synchronize()
            end = timeit.default_timer()

            if (self.configer.get('loss', 'loss_type')
                    == "fs_auxce_loss") or (self.configer.get(
                        'loss', 'loss_type') == "triple_auxce_loss"):
                outputs = outputs[-1]
            elif self.configer.get('loss',
                                   'loss_type') == "pyramid_auxce_loss":
                outputs = outputs[1] + outputs[2] + outputs[3] + outputs[4]

            outputs = F.interpolate(outputs,
                                    size=(h, w),
                                    mode='bilinear',
                                    align_corners=True)
            return outputs
        else:
            raise RuntimeError("Unsupport data type: {}".format(type(inputs)))

    def ss_test(self, inputs, scale=1):
        if isinstance(inputs, torch.Tensor):
            n, c, h, w = inputs.size(0), inputs.size(1), inputs.size(
                2), inputs.size(3)
            scaled_inputs = F.interpolate(inputs,
                                          size=(int(h * scale),
                                                int(w * scale)),
                                          mode="bilinear",
                                          align_corners=True)
            start = timeit.default_timer()
            outputs = self.seg_net.forward(scaled_inputs)
            torch.cuda.synchronize()
            end = timeit.default_timer()
            outputs = outputs[-1]
            outputs = F.interpolate(outputs,
                                    size=(h, w),
                                    mode='bilinear',
                                    align_corners=True)
            return outputs
        elif isinstance(inputs, collections.Sequence):
            device_ids = self.configer.get('gpu')
            replicas = nn.parallel.replicate(self.seg_net.module, device_ids)
            scaled_inputs, ori_size, outputs = [], [], []
            for i, d in zip(inputs, device_ids):
                h, w = i.size(1), i.size(2)
                ori_size.append((h, w))
                i = F.interpolate(i.unsqueeze(0),
                                  size=(int(h * scale), int(w * scale)),
                                  mode="bilinear",
                                  align_corners=True)
                scaled_inputs.append(i.cuda(d, non_blocking=True))
            scaled_outputs = nn.parallel.parallel_apply(
                replicas[:len(scaled_inputs)], scaled_inputs)
            for i, output in enumerate(scaled_outputs):
                outputs.append(
                    F.interpolate(output[-1],
                                  size=ori_size[i],
                                  mode='bilinear',
                                  align_corners=True))
            return outputs
        else:
            raise RuntimeError("Unsupport data type: {}".format(type(inputs)))

    def flip(self, x, dim):
        indices = [slice(None)] * x.dim()
        indices[dim] = torch.arange(x.size(dim) - 1,
                                    -1,
                                    -1,
                                    dtype=torch.long,
                                    device=x.device)
        return x[tuple(indices)]

    def sscrop_test(self, inputs, crop_size, scale=1):
        '''
        Single-scale inference over a grid of crops.

        The input is rescaled by ``scale``, covered with crops of size
        ``crop_size`` (height, width) whose start offsets come from
        ``_decide_intersection``, and every crop is run through ``ss_test``.
        Crop predictions are summed and divided by the number of crops that
        covered each pixel; the result is resized to the original input size.

        Currently, sscrop_test does not support diverse_size testing
        '''
        ori_h, ori_w = inputs.size(2), inputs.size(3)
        scaled_inputs = F.interpolate(inputs,
                                      size=(int(ori_h * scale),
                                            int(ori_w * scale)),
                                      mode="bilinear",
                                      align_corners=True)
        n = scaled_inputs.size(0)
        h, w = scaled_inputs.size(2), scaled_inputs.size(3)

        # Accumulators for the summed predictions and the per-pixel coverage.
        full_probs = torch.cuda.FloatTensor(
            n, self.configer.get('data', 'num_classes'), h, w).fill_(0)
        count_predictions = torch.cuda.FloatTensor(
            n, self.configer.get('data', 'num_classes'), h, w).fill_(0)

        crop_h, crop_w = crop_size[0], crop_size[1]
        row_starts = self._decide_intersection(h, crop_h)
        col_starts = self._decide_intersection(w, crop_w)

        crop_counter = 0
        for top in row_starts:
            rows = slice(top, top + crop_h)
            for left in col_starts:
                cols = slice(left, left + crop_w)
                prediction = self.ss_test(scaled_inputs[:, :, rows, cols])
                count_predictions[:, :, rows, cols] += 1
                full_probs[:, :, rows, cols] += prediction
                crop_counter += 1
                Log.info('predicting {:d}-th crop'.format(crop_counter))

        full_probs /= count_predictions
        return F.interpolate(full_probs,
                             size=(ori_h, ori_w),
                             mode='bilinear',
                             align_corners=True)

    def ms_test(self, inputs):
        if isinstance(inputs, torch.Tensor):
            n, c, h, w = inputs.size(0), inputs.size(1), inputs.size(
                2), inputs.size(3)
            full_probs = torch.cuda.FloatTensor(
                n, self.configer.get('data', 'num_classes'), h, w).fill_(0)
            if self.configer.exists('test', 'scale_weights'):
                for scale, weight in zip(
                        self.configer.get('test', 'scale_search'),
                        self.configer.get('test', 'scale_weights')):
                    probs = self.ss_test(inputs, scale)
                    flip_probs = self.ss_test(self.flip(inputs, 3), scale)
                    probs = probs + self.flip(flip_probs, 3)
                    full_probs += weight * probs
                return full_probs
            else:
                for scale in self.configer.get('test', 'scale_search'):
                    probs = self.ss_test(inputs, scale)
                    flip_probs = self.ss_test(self.flip(inputs, 3), scale)
                    probs = probs + self.flip(flip_probs, 3)
                    full_probs += probs
                return full_probs

        elif isinstance(inputs, collections.Sequence):
            device_ids = self.configer.get('gpu')
            full_probs = [
                torch.zeros(1, self.configer.get('data', 'num_classes'),
                            i.size(1), i.size(2)).cuda(device_ids[index],
                                                       non_blocking=True)
                for index, i in enumerate(inputs)
            ]
            flip_inputs = [self.flip(i, 2) for i in inputs]

            if self.configer.exists('test', 'scale_weights'):
                for scale, weight in zip(
                        self.configer.get('test', 'scale_search'),
                        self.configer.get('test', 'scale_weights')):
                    probs = self.ss_test(inputs, scale)
                    flip_probs = self.ss_test(flip_inputs, scale)
                    for i in range(len(inputs)):
                        full_probs[i] += weight * (probs[i] +
                                                   self.flip(flip_probs[i], 3))
                return full_probs
            else:
                for scale in self.configer.get('test', 'scale_search'):
                    probs = self.ss_test(inputs, scale)
                    flip_probs = self.ss_test(flip_inputs, scale)
                    for i in range(len(inputs)):
                        full_probs[i] += (probs[i] +
                                          self.flip(flip_probs[i], 3))
                return full_probs

        else:
            raise RuntimeError("Unsupport data type: {}".format(type(inputs)))

    def ms_test_depth(self, inputs, names):
        prob_list = []
        scale_list = []

        if isinstance(inputs, torch.Tensor):
            n, c, h, w = inputs.size(0), inputs.size(1), inputs.size(
                2), inputs.size(3)
            full_probs = torch.cuda.FloatTensor(
                n, self.configer.get('data', 'num_classes'), h, w).fill_(0)

            for scale in self.configer.get('test', 'scale_search'):
                probs = self.ss_test(inputs, scale)
                flip_probs = self.ss_test(self.flip(inputs, 3), scale)
                probs = probs + self.flip(flip_probs, 3)
                prob_list.append(probs)
                scale_list.append(scale)

            full_probs = self.fuse_with_depth(prob_list, scale_list, names)
            return full_probs

        else:
            raise RuntimeError("Unsupport data type: {}".format(type(inputs)))

    def fuse_with_depth(self, probs, scales, names):
        """
          Fuse per-scale predictions with weights derived from a stereo map.

          For every image a 16-bit PNG ``<stereo dir>/<name>.png`` is read,
          converted to a depth value (``500 * 0.5 / (value / 256)``), clipped
          to ``[0, 63]`` and quantized into ``len(scales)`` bins. The
          prediction of the scale with index ``s`` is weighted per pixel by
          ``0.8 ** |bin - s|`` and all weighted predictions are summed.

          Args:
            probs: List of prediction tensors, one per scale. The first one
              defines batch size, height and width.
            scales: Scales matching ``probs``.
            names: Image names, one per batch entry.

          Returns:
            CUDA float tensor with ``data.num_classes`` channels holding the
            fused scores.

          Raises:
            FileNotFoundError: If the stereo map of an image cannot be read.
        """
        MAX_DEPTH = 63
        POWER_BASE = 0.8
        if 'test' in self.save_dir:
            stereo_path = "/msravcshare/dataset/cityscapes/stereo/test/"
        else:
            stereo_path = "/msravcshare/dataset/cityscapes/stereo/val/"

        n, h, w = probs[0].size(0), probs[0].size(2), probs[0].size(3)
        full_probs = torch.cuda.FloatTensor(
            n, self.configer.get('data', 'num_classes'), h, w).fill_(0)

        # Width of one depth bin; the same for every image.
        depth_bin = MAX_DEPTH // len(scales)

        for index, name in enumerate(names):
            stereo_file = stereo_path + name + '.png'
            stereo_map = cv2.imread(stereo_file, -1)
            if stereo_map is None:
                # cv2.imread signals failure by returning None.
                raise FileNotFoundError(
                    "Cannot read stereo map: {}".format(stereo_file))

            # Zero entries divide to +inf, which the clip below maps to
            # MAX_DEPTH; silence the matching numpy warning.
            with np.errstate(divide='ignore'):
                depth_map = stereo_map / 256.0
                depth_map = 0.5 / depth_map
                depth_map = 500 * depth_map

            depth_map = np.clip(depth_map, 0, MAX_DEPTH)
            depth_map = depth_map // depth_bin

            for prob, scale in zip(probs, scales):
                scale_index = self._locate_scale_index(scale, scales)
                weight_map = np.abs(depth_map - scale_index)
                weight_map = np.power(POWER_BASE, weight_map)
                weight_map = cv2.resize(weight_map, (w, h))
                full_probs[index, :, :, :] += torch.from_numpy(
                    np.expand_dims(weight_map, axis=0)).type(
                        torch.cuda.FloatTensor) * prob[index, :, :, :]

        return full_probs

    @staticmethod
    def _locate_scale_index(scale, scales):
        for idx, s in enumerate(scales):
            if scale == s:
                return idx
        return 0

    def ms_test_wo_flip(self, inputs):
        if isinstance(inputs, torch.Tensor):
            n, c, h, w = inputs.size(0), inputs.size(1), inputs.size(
                2), inputs.size(3)
            full_probs = torch.cuda.FloatTensor(
                n, self.configer.get('data', 'num_classes'), h, w).fill_(0)
            for scale in self.configer.get('test', 'scale_search'):
                probs = self.ss_test(inputs, scale)
                full_probs += probs
            return full_probs
        elif isinstance(inputs, collections.Sequence):
            device_ids = self.configer.get('gpu')
            full_probs = [
                torch.zeros(1, self.configer.get('data', 'num_classes'),
                            i.size(1), i.size(2)).cuda(device_ids[index],
                                                       non_blocking=True)
                for index, i, in enumerate(inputs)
            ]
            for scale in self.configer.get('test', 'scale_search'):
                probs = self.ss_test(inputs, scale)
                for i in range(len(inputs)):
                    full_probs[i] += probs[i]
            return full_probs
        else:
            raise RuntimeError("Unsupport data type: {}".format(type(inputs)))

    def mscrop_test(self, inputs, crop_size):
        '''
        Multi-scale, flip-augmented inference that switches to crop-based
        prediction for scales >= 1.

        Scales below 1 are predicted on the whole image with ``ss_test``; all
        other scales go through ``sscrop_test`` with the given ``crop_size``.
        For every scale the prediction of the input and the un-flipped
        prediction of the flipped input are added to the running sum.

        Currently, mscrop_test does not support diverse_size testing
        '''
        n, h, w = inputs.size(0), inputs.size(2), inputs.size(3)
        full_probs = torch.cuda.FloatTensor(
            n, self.configer.get('data', 'num_classes'), h, w).fill_(0)

        for scale in self.configer.get('test', 'scale_search'):
            Log.info('Scale {0:.2f} prediction'.format(scale))
            whole_image = scale < 1
            if whole_image:
                probs = self.ss_test(inputs, scale)
            else:
                probs = self.sscrop_test(inputs, crop_size, scale)

            flipped_inputs = self.flip(inputs, 3)
            if whole_image:
                flip_probs = self.ss_test(flipped_inputs, scale)
            else:
                flip_probs = self.sscrop_test(flipped_inputs, crop_size, scale)

            full_probs += probs + self.flip(flip_probs, 3)

        return full_probs

    def _decide_intersection(self, total_length, crop_length):
        stride = crop_length
        times = (total_length - crop_length) // stride + 1
        cropped_starting = []
        for i in range(times):
            cropped_starting.append(stride * i)
        if total_length - cropped_starting[-1] > crop_length:
            cropped_starting.append(total_length -
                                    crop_length)  # must cover the total image
        return cropped_starting

    def dense_crf_process(self, images, outputs):
        '''
        Post-process the network scores of every image in the batch with a
        dense CRF (Gaussian and bilateral pairwise terms).

        ``outputs`` is first turned into probabilities with a softmax over
        dimension 1; each batch entry is then replaced by the CRF inference
        result. The RGB image handed to the bilateral term is ``images[i]``
        plus a fixed per-channel mean, cast to unsigned bytes.

        Reference: https://github.com/kazuto1011/deeplab-pytorch/blob/master/libs/utils/crf.py
        '''
        # hyperparameters of the dense crf
        # baseline = 79.5
        # bi_xy_std = 67, 79.1
        # bi_xy_std = 20, 79.6
        # bi_xy_std = 10, 79.7
        # bi_xy_std = 10, iter_max = 20, v4 79.7
        # bi_xy_std = 10, iter_max = 5, v5 79.7
        # bi_xy_std = 5, v3 79.7
        crf_iterations = 10
        gaussian_compat = 3
        gaussian_sxy = 1
        bilateral_compat = 4
        bilateral_sxy = 10
        bilateral_srgb = 3

        # Per-channel mean, shaped (3, 1, 1) so it broadcasts over H and W.
        channel_mean = np.array([102.9801, 115.9465,
                                 122.7717]).reshape(3, 1, 1)

        outputs = F.softmax(outputs, dim=1)
        for idx in range(images.size(0)):
            probs = outputs[idx].data.cpu().numpy()
            C, H, W = probs.shape
            unary = np.ascontiguousarray(
                dcrf_utils.unary_from_softmax(probs))

            # Channel-first image -> uint8, channel-last, contiguous.
            image = np.ascontiguousarray(images[idx]) + channel_mean
            image = np.ascontiguousarray(
                image.astype(np.ubyte).transpose(1, 2, 0))

            crf = dcrf.DenseCRF2D(W, H, C)
            crf.setUnaryEnergy(unary)
            crf.addPairwiseGaussian(sxy=gaussian_sxy, compat=gaussian_compat)
            crf.addPairwiseBilateral(sxy=bilateral_sxy,
                                     srgb=bilateral_srgb,
                                     rgbim=image,
                                     compat=bilateral_compat)
            refined = np.array(crf.inference(crf_iterations))
            outputs[idx] = torch.from_numpy(refined).cuda().view(C, H, W)

        return outputs
Exemplo n.º 4
0
class Trainer(object):
    """
      The class for Pose Estimation. Include train, val, val & predict.
    """
    def __init__(self, configer):
        """
          Create the meters and helper objects used for training, then build
          the model, optimizer and data loaders via ``_init_model``.

          Args:
            configer: Project configuration object. It is stored on the
              instance and passed to every helper constructed here.
        """
        self.configer = configer
        # Running meters; __train() updates them and resets them after each
        # progress log line.
        self.batch_time = AverageMeter()
        self.foward_time = AverageMeter()
        self.backward_time = AverageMeter()
        self.loss_time = AverageMeter()
        self.data_time = AverageMeter()
        self.train_losses = AverageMeter()
        self.val_losses = AverageMeter()
        # Project helpers, all configured from the same configer.
        self.seg_visualizer = SegVisualizer(configer)
        self.loss_manager = LossManager(configer)
        self.module_runner = ModuleRunner(configer)
        self.model_manager = ModelManager(configer)
        self.data_loader = DataLoader(configer)
        self.optim_scheduler = OptimScheduler(configer)
        # These two also receive this (still partially constructed) trainer.
        self.data_helper = DataHelper(configer, self)
        self.evaluator = get_evaluator(configer, self)

        # Placeholders. _init_model() below assigns seg_net, the two loaders,
        # optimizer and scheduler; it does not set running_score.
        self.seg_net = None
        self.train_loader = None
        self.val_loader = None
        self.optimizer = None
        self.scheduler = None
        self.running_score = None

        self._init_model()

    def _init_model(self):
        """
          Build the network, the optimizer/scheduler pair, the data loaders
          and the loss.

          Parameter groups come from ``group_weight`` when
          ``optim.group_method`` is 'decay'; otherwise the method must be
          ``None`` and ``_get_parameters`` is used. In distributed runs the
          loss module is additionally moved to the device by the module runner.
        """
        self.seg_net = self.model_manager.semantic_segmentor()
        self.seg_net = self.module_runner.load_net(self.seg_net)

        group_method = self.configer.get('optim', 'group_method')
        Log.info('Params Group Method: {}'.format(group_method))
        if group_method != 'decay':
            assert group_method is None
            params_group = self._get_parameters()
        else:
            params_group = self.group_weight(self.seg_net)

        optimizer, scheduler = self.optim_scheduler.init_optimizer(params_group)
        self.optimizer = optimizer
        self.scheduler = scheduler

        self.train_loader = self.data_loader.get_trainloader()
        self.val_loader = self.data_loader.get_valloader()

        pixel_loss = self.loss_manager.get_seg_loss()
        self.pixel_loss = pixel_loss
        if is_distributed():
            self.pixel_loss = self.module_runner.to_device(pixel_loss)

    @staticmethod
    def group_weight(module):
        group_decay = []
        group_no_decay = []
        for m in module.modules():
            if isinstance(m, nn.Linear):
                group_decay.append(m.weight)
                if m.bias is not None:
                    group_no_decay.append(m.bias)
            elif isinstance(m, nn.modules.conv._ConvNd):
                group_decay.append(m.weight)
                if m.bias is not None:
                    group_no_decay.append(m.bias)
            else:
                if hasattr(m, 'weight'):
                    group_no_decay.append(m.weight)
                if hasattr(m, 'bias'):
                    group_no_decay.append(m.bias)

        assert len(list(module.parameters())) == len(group_decay) + len(group_no_decay)
        groups = [dict(params=group_decay), dict(params=group_no_decay, weight_decay=.0)]
        return groups

    def _get_parameters(self):
        bb_lr = []
        nbb_lr = []
        params_dict = dict(self.seg_net.named_parameters())
        for key, value in params_dict.items():
            if 'backbone' not in key:
                nbb_lr.append(value)
            else:
                bb_lr.append(value)

        params = [{'params': bb_lr, 'lr': self.configer.get('lr', 'base_lr')},
                  {'params': nbb_lr, 'lr': self.configer.get('lr', 'base_lr') * self.configer.get('lr', 'nbb_mult')}]
        return params

    def __train(self):
        """
          Train function of every epoch during train phase.

          Iterates once over ``self.train_loader``. Every iteration steps the
          scheduler, optionally applies warm-up, runs forward, loss, backward
          and optimizer step, and increments the 'iters' counter. Progress is
          logged every ``solver.display_iter`` iterations (rank 0 only in
          distributed runs), validation runs every ``solver.test_interval``
          iterations, and the loop stops once ``solver.max_iters`` is reached.
          The 'epoch' counter is incremented at the end.
        """
        self.seg_net.train()
        self.pixel_loss.train()
        start_time = time.time()

        # Settings that stay fixed for the whole epoch.
        use_swa = "swa" in self.configer.get('lr', 'lr_policy')
        if use_swa:
            normal_max_iters = int(self.configer.get('solver', 'max_iters') * 0.75)
            swa_step_max_iters = (self.configer.get('solver', 'max_iters') - normal_max_iters) // 5 + 1
        distributed = is_distributed()

        def reduce_to_rank0(tensor, world_size):
            """
            Sum ``tensor`` over all processes onto rank 0.

            Works on a copy, so the loss that is later back-propagated is not
            overwritten in place by the collective.
            """
            if world_size < 2:
                return tensor
            import torch.distributed as dist
            with torch.no_grad():
                reduced = tensor.clone()
                dist.reduce(reduced, dst=0)
            return reduced

        if hasattr(self.train_loader.sampler, 'set_epoch'):
            self.train_loader.sampler.set_epoch(self.configer.get('epoch'))

        for i, data_dict in enumerate(self.train_loader):
            if self.configer.get('lr', 'metric') == 'iters':
                self.scheduler.step(self.configer.get('iters'))
            else:
                self.scheduler.step(self.configer.get('epoch'))

            if self.configer.get('lr', 'is_warm'):
                self.module_runner.warm_lr(
                    self.configer.get('iters'),
                    self.scheduler, self.optimizer, backbone_list=[0,]
                )

            (inputs, targets), batch_size = self.data_helper.prepare_data(data_dict)
            self.data_time.update(time.time() - start_time)

            foward_start_time = time.time()
            outputs = self.seg_net(*inputs)
            self.foward_time.update(time.time() - foward_start_time)

            loss_start_time = time.time()
            if distributed:
                backward_loss = self.pixel_loss(outputs, targets)
                # The displayed value is the mean over processes; it is only
                # meaningful on rank 0, the destination of the reduction.
                world_size = get_world_size()
                display_loss = reduce_to_rank0(backward_loss, world_size) / world_size
            else:
                backward_loss = display_loss = self.pixel_loss(outputs, targets, gathered=self.configer.get('network', 'gathered'))

            self.train_losses.update(display_loss.item(), batch_size)
            self.loss_time.update(time.time() - loss_start_time)

            backward_start_time = time.time()
            self.optimizer.zero_grad()
            backward_loss.backward()
            self.optimizer.step()
            self.backward_time.update(time.time() - backward_start_time)

            # Update the vars of the train phase.
            self.batch_time.update(time.time() - start_time)
            start_time = time.time()
            self.configer.plus_one('iters')

            # Print the log info & reset the states.
            if self.configer.get('iters') % self.configer.get('solver', 'display_iter') == 0 and \
                    (not distributed or get_rank() == 0):
                Log.info('Train Epoch: {0}\tTrain Iteration: {1}\t'
                         'Time {batch_time.sum:.3f}s / {2}iters, ({batch_time.avg:.3f})\t'
                         'Forward Time {foward_time.sum:.3f}s / {2}iters, ({foward_time.avg:.3f})\t'
                         'Backward Time {backward_time.sum:.3f}s / {2}iters, ({backward_time.avg:.3f})\t'
                         'Loss Time {loss_time.sum:.3f}s / {2}iters, ({loss_time.avg:.3f})\t'
                         'Data load {data_time.sum:.3f}s / {2}iters, ({data_time.avg:.3f})\n'
                         'Learning rate = {3}\tLoss = {loss.val:.8f} (ave = {loss.avg:.8f})\n'.format(
                         self.configer.get('epoch'), self.configer.get('iters'),
                         self.configer.get('solver', 'display_iter'),
                         self.module_runner.get_lr(self.optimizer), batch_time=self.batch_time,
                         foward_time=self.foward_time, backward_time=self.backward_time, loss_time=self.loss_time,
                         data_time=self.data_time, loss=self.train_losses))
                self.batch_time.reset()
                self.foward_time.reset()
                self.backward_time.reset()
                self.loss_time.reset()
                self.data_time.reset()
                self.train_losses.reset()

            # save checkpoints for swa
            if use_swa and \
                    self.configer.get('iters') > normal_max_iters and \
                    ((self.configer.get('iters') - normal_max_iters) % swa_step_max_iters == 0 or
                     self.configer.get('iters') == self.configer.get('solver', 'max_iters')):
                self.optimizer.update_swa()

            if self.configer.get('iters') == self.configer.get('solver', 'max_iters'):
                break

            # Check to val the current model.
            if self.configer.get('iters') % self.configer.get('solver', 'test_interval') == 0:
                self.__val()

        self.configer.plus_one('epoch')


    def __val(self, data_loader=None):
        """
          Validation function during the train phase.

          Puts the network and the loss module in eval mode, runs every batch
          of ``data_loader`` through the network without gradients, feeds the
          predictions to ``self.evaluator``, stores the average validation
          loss in the configer, asks ``self.module_runner`` to save the
          network and finally switches both modules back to train mode.

          Args:
              data_loader: iterable of batch dicts. ``self.val_loader`` is
                  used when it is ``None``.

          Raises:
              AssertionError: re-raised from the loss module (after logging
                  the output/target lengths) instead of continuing with an
                  undefined or stale ``loss``.
        """
        self.seg_net.eval()
        self.pixel_loss.eval()
        start_time = time.time()
        replicas = self.evaluator.prepare_validaton()

        data_loader = self.val_loader if data_loader is None else data_loader
        for j, data_dict in enumerate(data_loader):
            if j % 10 == 0:
                Log.info('{} images processed\n'.format(j))

            is_lip = self.configer.get('dataset') == 'lip'
            if is_lip:
                (inputs, targets, inputs_rev, targets_rev), batch_size = \
                    self.data_helper.prepare_data(data_dict, want_reverse=True)
            else:
                (inputs, targets), batch_size = self.data_helper.prepare_data(data_dict)

            with torch.no_grad():
                if is_lip:
                    # Run the batch and its reversed copy in a single forward
                    # pass, then average the two predictions.
                    inputs = torch.cat([inputs[0], inputs_rev[0]], dim=0)
                    outputs = self.seg_net(inputs)
                    outputs_ = self.module_runner.gather(outputs)
                    if isinstance(outputs_, (list, tuple)):
                        outputs_ = outputs_[-1]

                    half = int(outputs_.size(0) / 2)
                    outputs = outputs_[0:half, :, :, :].clone()
                    reversed_part = outputs_[half:int(outputs_.size(0)), :, :, :]
                    outputs_rev = reversed_part.clone()
                    if outputs_rev.shape[1] == 20:
                        # Swap the channel pairs 14/15, 16/17 and 18/19 of the
                        # reversed prediction.
                        # NOTE(review): presumably left/right class pairs of
                        # the 20-class label set -- confirm against the dataset.
                        for first, second in ((14, 15), (16, 17), (18, 19)):
                            outputs_rev[:, first, :, :] = reversed_part[:, second, :, :]
                            outputs_rev[:, second, :, :] = reversed_part[:, first, :, :]
                    # Flip the last axis back before averaging.
                    outputs_rev = torch.flip(outputs_rev, [3])
                    outputs = (outputs + outputs_rev) / 2.
                    self.evaluator.update_score(outputs, data_dict['meta'])

                elif self.data_helper.conditions.diverse_size:
                    # One replica per input; loss and score are updated sample by sample.
                    outputs = nn.parallel.parallel_apply(replicas[:len(inputs)], inputs)

                    for i, outputs_i in enumerate(outputs):
                        loss = self.pixel_loss(outputs_i, targets[i])
                        self.val_losses.update(loss.item(), 1)
                        if isinstance(outputs_i, torch.Tensor):
                            outputs_i = [outputs_i]
                        self.evaluator.update_score(outputs_i, data_dict['meta'][i:i+1])

                else:
                    outputs = self.seg_net(*inputs)

                    try:
                        loss = self.pixel_loss(
                            outputs, targets,
                            gathered=self.configer.get('network', 'gathered')
                        )
                    except AssertionError:
                        # Keep the diagnostic, but do not continue: without a
                        # fresh ``loss`` the update below would either fail
                        # with a NameError or silently reuse the previous one.
                        Log.error('Loss assertion failed: {} outputs vs {} targets'.format(
                            len(outputs), len(targets)))
                        raise

                    if not is_distributed():
                        outputs = self.module_runner.gather(outputs)
                    self.val_losses.update(loss.item(), batch_size)
                    self.evaluator.update_score(outputs, data_dict['meta'])

            self.batch_time.update(time.time() - start_time)
            start_time = time.time()

        self.evaluator.update_performance()

        self.configer.update(['val_loss'], self.val_losses.avg)
        self.module_runner.save_net(self.seg_net, save_mode='performance')
        self.module_runner.save_net(self.seg_net, save_mode='val_loss')
        cudnn.benchmark = True

        # Print the log info & reset the states.
        if not is_distributed() or get_rank() == 0:
            Log.info(
                'Test Time {batch_time.sum:.3f}s, ({batch_time.avg:.3f})\t'
                'Loss {loss.avg:.8f}\n'.format(
                    batch_time=self.batch_time, loss=self.val_losses))
            self.evaluator.print_scores()

        self.batch_time.reset()
        self.val_losses.reset()
        self.evaluator.reset()
        self.seg_net.train()
        self.pixel_loss.train()

    def train(self):
        """
          Entry point of the training run.

          When a checkpoint is resumed (``network.resume`` is set) and either
          ``network.resume_val`` or ``network.resume_train`` is truthy, only a
          single validation pass over the 'val' / 'train' split is executed
          and the method returns.  Otherwise ``__train`` is repeated until
          ``solver.max_iters`` iterations are reached, SWA weights are swapped
          in when the lr policy contains 'swa', and a final validation on the
          'val' split is run.
        """
        if self.configer.get('network', 'resume') is not None:
            if self.configer.get('network', 'resume_val'):
                self.__val(data_loader=self.data_loader.get_valloader(dataset='val'))
                return
            if self.configer.get('network', 'resume_train'):
                self.__val(data_loader=self.data_loader.get_valloader(dataset='train'))
                return

        while self.configer.get('iters') < self.configer.get('solver', 'max_iters'):
            self.__train()

        # use swa to average the model
        if 'swa' in self.configer.get('lr', 'lr_policy'):
            self.optimizer.swap_swa_sgd()
            self.optimizer.bn_update(self.train_loader, self.seg_net)

        self.__val(data_loader=self.data_loader.get_valloader(dataset='val'))

    def summary(self):
        """
          Print a model summary for the first sample of the train loader.

          The network is switched to eval mode and is left in that mode.
          Nothing is printed when the train loader yields no batch.
        """
        from lib.utils.summary import get_model_summary

        self.seg_net.eval()

        for data_dict in self.train_loader:
            # Only the first image of the first batch is used for the summary.
            print(get_model_summary(self.seg_net, data_dict['img'][0:1]))
            return
Exemplo n.º 5
0
class Trainer(object):
    """
      Trainer for semantic segmentation.

      Builds the network, optimizer, scheduler, data loaders and loss from
      the configer, then runs the train and validation loops.
    """
    def __init__(self, configer):
        """
          Create the meters and project helpers and build the model.

          Args:
              configer: project configuration object; it is stored and
                  handed to every helper created here.
        """
        self.configer = configer
        self.batch_time = AverageMeter()
        self.data_time = AverageMeter()
        self.train_losses = AverageMeter()
        self.val_losses = AverageMeter()
        self.running_score = RunningScore(configer)
        self.seg_visualizer = SegVisualizer(configer)
        self.loss_manager = LossManager(configer)
        self.module_runner = ModuleRunner(configer)
        self.model_manager = ModelManager(configer)
        self.data_loader = DataLoader(configer)
        self.optim_scheduler = OptimScheduler(configer)

        # Filled in by _init_model().
        self.seg_net = None
        self.train_loader = None
        self.val_loader = None
        self.optimizer = None
        self.scheduler = None

        self._init_model()

    def _init_model(self):
        """
          Build the network, the optimizer / scheduler, both data loaders
          and the segmentation loss.

          Raises:
              AssertionError: if ``optim.group_method`` is neither 'decay'
                  nor ``None``.
        """
        self.seg_net = self.model_manager.semantic_segmentor()
        self.seg_net = self.module_runner.load_net(self.seg_net)

        group_method = self.configer.get('optim', 'group_method')
        Log.info('Params Group Method: {}'.format(group_method))
        if group_method == 'decay':
            params_group = self.group_weight(self.seg_net)
        else:
            assert group_method is None
            params_group = self._get_parameters()

        self.optimizer, self.scheduler = self.optim_scheduler.init_optimizer(
            params_group)

        self.train_loader = self.data_loader.get_trainloader()
        self.val_loader = self.data_loader.get_valloader()

        self.pixel_loss = self.loss_manager.get_seg_loss()

    @staticmethod
    def group_weight(module):
        """
          Split the parameters of ``module`` into a weight-decay group and a
          no-decay group.

          Weights of ``nn.Linear`` and convolution layers go to the decay
          group.  Their biases, and the ``weight`` / ``bias`` of every other
          sub-module, go to the group created with ``weight_decay=.0``.
          Attributes that are ``None`` (e.g. a normalization layer built
          with ``affine=False``) are skipped so that no ``None`` entry ends
          up in an optimizer group.

          Args:
              module: the ``nn.Module`` whose sub-modules are inspected.

          Returns:
              A list with two optimizer param-group dicts: decay, no-decay.

          Raises:
              AssertionError: if the number of collected entries differs
                  from ``len(list(module.parameters()))``.
        """
        group_decay = []
        group_no_decay = []
        for m in module.modules():
            weight = getattr(m, 'weight', None)
            bias = getattr(m, 'bias', None)
            if isinstance(m, (nn.Linear, nn.modules.conv._ConvNd)):
                group_decay.append(weight)
                if bias is not None:
                    group_no_decay.append(bias)
            else:
                if weight is not None:
                    group_no_decay.append(weight)
                if bias is not None:
                    group_no_decay.append(bias)

        assert len(list(
            module.parameters())) == len(group_decay) + len(group_no_decay)
        groups = [
            dict(params=group_decay),
            dict(params=group_no_decay, weight_decay=.0)
        ]
        return groups

    def _get_parameters(self):
        """
          Build two param groups: parameters whose name contains 'backbone'
          use ``lr.base_lr``; all others use ``lr.base_lr * lr.nbb_mult``.

          Returns:
              A list of two dicts with the keys 'params' and 'lr'
              (backbone group first).
        """
        bb_lr = []
        nbb_lr = []
        for key, value in self.seg_net.named_parameters():
            if 'backbone' in key:
                bb_lr.append(value)
            else:
                nbb_lr.append(value)

        base_lr = self.configer.get('lr', 'base_lr')
        params = [{
            'params': bb_lr,
            'lr': base_lr
        }, {
            'params': nbb_lr,
            'lr': base_lr * self.configer.get('lr', 'nbb_mult')
        }]
        return params

    def __train(self):
        """
          Train function of every epoch during train phase.

          Iterates once over the train loader (stopping early when
          ``solver.max_iters`` is reached), logs every
          ``solver.display_iter`` iterations and validates every
          ``solver.test_interval`` iterations.  The epoch counter is
          incremented at the end.
        """
        self.seg_net.train()
        start_time = time.time()

        for data_dict in self.train_loader:
            # The scheduler is stepped with either the iteration or the
            # epoch counter, depending on 'lr.metric'.
            if self.configer.get('lr', 'metric') == 'iters':
                self.scheduler.step(self.configer.get('iters'))
            else:
                self.scheduler.step(self.configer.get('epoch'))

            if self.configer.get('lr', 'is_warm'):
                self.module_runner.warm_lr(self.configer.get('iters'),
                                           self.scheduler,
                                           self.optimizer,
                                           backbone_list=[
                                               0,
                                           ])
            # The batch is fed to the network as delivered by the loader.
            inputs = data_dict['img']
            targets = data_dict['labelmap']
            self.data_time.update(time.time() - start_time)

            # Forward pass.
            outputs = self.seg_net(inputs)
            # Compute the loss of the train batch & backward.
            loss = self.pixel_loss(outputs,
                                   targets,
                                   gathered=self.configer.get(
                                       'network', 'gathered'))
            if self.configer.exists('train', 'loader') and self.configer.get(
                    'train', 'loader') == 'ade20k':
                # For this loader the batch size is taken from the config
                # instead of from the input tensor.
                batch_size = self.configer.get(
                    'train', 'batch_size') * self.configer.get(
                        'train', 'batch_per_gpu')
                self.train_losses.update(loss.item(), batch_size)
            else:
                self.train_losses.update(loss.item(), inputs.size(0))

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Update the vars of the train phase.
            self.batch_time.update(time.time() - start_time)
            start_time = time.time()
            self.configer.plus_one('iters')

            # Print the log info & reset the states.
            if self.configer.get('iters') % self.configer.get(
                    'solver', 'display_iter') == 0:
                Log.info(
                    'Train Epoch: {0}\tTrain Iteration: {1}\t'
                    'Time {batch_time.sum:.3f}s / {2}iters, ({batch_time.avg:.3f})\t'
                    'Data load {data_time.sum:.3f}s / {2}iters, ({data_time.avg:.3f})\n'
                    'Learning rate = {3}\tLoss = {loss.val:.8f} (ave = {loss.avg:.8f})\n'
                    .format(self.configer.get('epoch'),
                            self.configer.get('iters'),
                            self.configer.get('solver', 'display_iter'),
                            self.module_runner.get_lr(self.optimizer),
                            batch_time=self.batch_time,
                            data_time=self.data_time,
                            loss=self.train_losses))
                self.batch_time.reset()
                self.data_time.reset()
                self.train_losses.reset()

            if self.configer.get('iters') == self.configer.get(
                    'solver', 'max_iters'):
                break

            # Check to val the current model.
            if self.configer.get('iters') % self.configer.get(
                    'solver', 'test_interval') == 0:
                self.__val()

        self.configer.plus_one('epoch')

    def __val(self, data_loader=None):
        """
          Validation function during the train phase.

          Runs every batch without gradients, accumulates loss and running
          score, writes 'performance' (mean IoU) and 'val_loss' into the
          configer, asks the module runner to save the network, logs the
          results and switches the network back to train mode.

          Args:
              data_loader: iterable of batch dicts; ``self.val_loader`` is
                  used when it is ``None``.
        """
        self.seg_net.eval()
        start_time = time.time()

        data_loader = self.val_loader if data_loader is None else data_loader
        for data_dict in data_loader:
            inputs = data_dict['img']
            targets = data_dict['labelmap']

            with torch.no_grad():
                # Change the data type.
                inputs, targets = self.module_runner.to_device(inputs, targets)
                # Forward pass.
                outputs = self.seg_net(inputs)
                # Compute the loss of the val batch.
                loss = self.pixel_loss(outputs,
                                       targets,
                                       gathered=self.configer.get(
                                           'network', 'gathered'))
                outputs = self.module_runner.gather(outputs)

            self.val_losses.update(loss.item(), inputs.size(0))
            # Only the last element of the gathered outputs is scored.
            self._update_running_score(outputs[-1], data_dict['meta'])

            # Update the vars of the val phase.
            self.batch_time.update(time.time() - start_time)
            start_time = time.time()

        self.configer.update(['performance'],
                             self.running_score.get_mean_iou())
        self.configer.update(['val_loss'], self.val_losses.avg)
        self.module_runner.save_net(self.seg_net, save_mode='performance')
        self.module_runner.save_net(self.seg_net, save_mode='val_loss')

        # Print the log info & reset the states.
        Log.info('Test Time {batch_time.sum:.3f}s, ({batch_time.avg:.3f})\t'
                 'Loss {loss.avg:.8f}\n'.format(batch_time=self.batch_time,
                                                loss=self.val_losses))
        Log.info('Mean IOU: {}\n'.format(self.running_score.get_mean_iou()))
        Log.info('Pixel ACC: {}\n'.format(self.running_score.get_pixel_acc()))
        self.batch_time.reset()
        self.val_losses.reset()
        self.running_score.reset()
        self.seg_net.train()

    def _update_running_score(self, pred, metas):
        """
          Update the running score with the predictions of one batch.

          For every sample the prediction is cropped to ``border_size``,
          resized to ``ori_img_size`` with cubic interpolation, reduced with
          an argmax over the last axis and compared with ``ori_target``.

          Args:
              pred: 4-D tensor; axis 1 is moved to the last position before
                  use.
              metas: per-sample dicts providing 'ori_img_size',
                  'border_size' and 'ori_target'.
        """
        pred = pred.permute(0, 2, 3, 1)
        for i in range(pred.size(0)):
            meta = metas[i]
            ori_img_size = meta['ori_img_size']
            border_size = meta['border_size']
            ori_target = meta['ori_target']
            # border_size is indexed as (width, height) here.
            valid_logits = pred[i, :border_size[1], :border_size[0]]
            total_logits = cv2.resize(valid_logits.cpu().numpy(),
                                      tuple(ori_img_size),
                                      interpolation=cv2.INTER_CUBIC)
            labelmap = np.argmax(total_logits, axis=-1)
            self.running_score.update(labelmap[None], ori_target[None])

    def train(self):
        """
          Run the whole training schedule.

          Optionally validates a resumed checkpoint first, trains until
          ``solver.max_iters`` iterations are reached and finally validates
          on the 'val' and on the 'train' split.
        """
        if self.configer.get('network',
                             'resume') is not None and self.configer.get(
                                 'network', 'resume_val'):
            self.__val()

        while self.configer.get('iters') < self.configer.get(
                'solver', 'max_iters'):
            self.__train()

        self.__val(data_loader=self.data_loader.get_valloader(dataset='val'))
        self.__val(data_loader=self.data_loader.get_valloader(dataset='train'))
Exemplo n.º 6
0
class Tester(object):
    def __init__(self, configer):
        """
          Create the project helpers used for inference and load the network.

          Args:
              configer: project configuration object; it is stored and
                  handed to every helper created here.
        """
        self.configer = configer
        self.blob_helper = BlobHelper(configer)
        self.seg_visualizer = SegVisualizer(configer)
        self.seg_parser = SegParser(configer)
        self.seg_model_manager = ModelManager(configer)
        self.seg_data_loader = DataLoader(configer)
        self.module_runner = ModuleRunner(configer)

        # CUDA is selected whenever a 'gpu' entry is configured.
        device_name = 'cuda' if self.configer.get('gpu') is not None else 'cpu'
        self.device = torch.device(device_name)

        # Replaced by the loaded network in _init_model().
        self.seg_net = None
        self._init_model()

    def _init_model(self):
        """
          Build the segmentation network, load it through the module runner
          and switch it to eval mode.
        """
        self.seg_net = self.seg_model_manager.semantic_segmentor()
        loaded_net = self.module_runner.load_net(self.seg_net)
        self.seg_net = loaded_net
        loaded_net.eval()

    def _get_blob(self, ori_image, scale=None):
        assert scale is not None
        image = None
        if self.configer.exists('test', 'input_size'):
            image = self.blob_helper.make_input(image=ori_image,
                                                input_size=self.configer.get('test', 'input_size'),
                                                scale=scale)

        elif self.configer.exists('test', 'min_side_length') and not self.configer.exists('test', 'max_side_length'):
            image = self.blob_helper.make_input(image=ori_image,
                                                min_side_length=self.configer.get('test', 'min_side_length'),
                                                scale=scale)

        elif not self.configer.exists('test', 'min_side_length') and self.configer.exists('test', 'max_side_length'):
            image = self.blob_helper.make_input(image=ori_image,
                                                max_side_length=self.configer.get('test', 'max_side_length'),
                                                scale=scale)

        elif self.configer.exists('test', 'min_side_length') and self.configer.exists('test', 'max_side_length'):
            image = self.blob_helper.make_input(image=ori_image,
                                                min_side_length=self.configer.get('test', 'min_side_length'),
                                                max_side_length=self.configer.get('test', 'max_side_length'),
                                                scale=scale)

        else:
            Log.error('Test setting error')
            exit(1)

        b, c, h, w = image.size()
        border_hw = [h, w]
        if self.configer.exists('test', 'fit_stride'):
            stride = self.configer.get('test', 'fit_stride')

            pad_w = 0 if (w % stride == 0) else stride - (w % stride)  # right
            pad_h = 0 if (h % stride == 0) else stride - (h % stride)  # down

            expand_image = torch.zeros((b, c, h + pad_h, w + pad_w)).to(image.device)
            expand_image[:, :, 0:h, 0:w] = image
            image = expand_image

        return image, border_hw

    def __test_img(self, image_path, label_path, vis_path, raw_path):
        """
          Segment one image and write the visualisation, the raw image and
          the label map.

          The inference routine is chosen by 'test.mode' ('ss_test',
          'sscrop_test', 'ms_test' or 'mscrop_test'); any other value logs
          an error and exits with status 1.

          Args:
              image_path: path of the image to read.
              label_path: where the palette-mode label image is saved.
              vis_path: where the colorized overlay is saved.
              raw_path: where the unmodified input image is saved.
        """
        Log.info('Image Path: {}'.format(image_path))
        ori_image = ImageHelper.read_image(image_path,
                                           tool=self.configer.get('data', 'image_tool'),
                                           mode=self.configer.get('data', 'input_mode'))

        mode = self.configer.get('test', 'mode')
        routines = (
            ('ss_test', self.ss_test),
            ('sscrop_test', self.sscrop_test),
            ('ms_test', self.ms_test),
            ('mscrop_test', self.mscrop_test),
        )
        total_logits = None
        for mode_name, routine in routines:
            if mode == mode_name:
                total_logits = routine(ori_image)
                break
        else:
            Log.error('Invalid test mode:{}'.format(self.configer.get('test', 'mode')))
            exit(1)

        # Per-pixel class index with the highest accumulated logit.
        label_img = np.array(np.argmax(total_logits, axis=-1), dtype=np.uint8)

        ori_img_bgr = ImageHelper.get_cv2_bgr(ori_image, mode=self.configer.get('data', 'input_mode'))
        image_canvas = self.seg_parser.colorize(label_img, image_canvas=ori_img_bgr)
        ImageHelper.save(image_canvas, save_path=vis_path)
        ImageHelper.save(ori_image, save_path=raw_path)

        # Optional mapping from train ids to the configured label list.
        if self.configer.exists('data', 'label_list'):
            label_img = self.__relabel(label_img)

        shift_labels = self.configer.exists('data', 'reduce_zero_label') and \
            self.configer.get('data', 'reduce_zero_label')
        if shift_labels:
            label_img = (label_img + 1).astype(np.uint8)

        label_img = Image.fromarray(label_img, 'P')
        Log.info('Label Path: {}'.format(label_path))
        ImageHelper.save(label_img, label_path)

    def ss_test(self, ori_image):
        """
          Single-scale inference on the whole image.

          Returns:
              float32 array of shape (ori_height, ori_width, num_classes)
              holding the logits resized to the original image size.
        """
        ori_width, ori_height = ImageHelper.get_size(ori_image)
        num_classes = self.configer.get('data', 'num_classes')
        total_logits = np.zeros((ori_height, ori_width, num_classes), np.float32)

        blob, border_hw = self._get_blob(ori_image, scale=1.0)
        logits = self._predict(blob)
        # Drop the stride padding before resizing back to the input size.
        valid_logits = logits[:border_hw[0], :border_hw[1]]
        total_logits += cv2.resize(valid_logits, (ori_width, ori_height),
                                   interpolation=cv2.INTER_CUBIC)
        return total_logits

    def sscrop_test(self, ori_image):
        """
          Single-scale inference that tiles the blob into crops when it is
          larger than 'test.crop_size' in both dimensions, and otherwise
          predicts on the whole blob.

          Returns:
              float32 array of shape (ori_height, ori_width, num_classes).
        """
        ori_width, ori_height = ImageHelper.get_size(ori_image)
        num_classes = self.configer.get('data', 'num_classes')
        total_logits = np.zeros((ori_height, ori_width, num_classes), np.float32)

        blob, border_hw = self._get_blob(ori_image, scale=1.0)
        crop_size = self.configer.get('test', 'crop_size')
        # crop_size is indexed as (width, height).
        exceeds_crop = blob.size()[3] > crop_size[0] and blob.size()[2] > crop_size[1]
        logits = self._crop_predict(blob, crop_size) if exceeds_crop else self._predict(blob)

        valid_logits = logits[:border_hw[0], :border_hw[1]]
        total_logits += cv2.resize(valid_logits, (ori_width, ori_height),
                                   interpolation=cv2.INTER_CUBIC)
        return total_logits

    def mscrop_test(self, ori_image):
        """
          Multi-scale inference with crop tiling and horizontal flipping.

          For every scale in 'test.scale_search' the image and its mirrored
          copy are both evaluated at that scale.  A blob is tiled into crops
          when it is larger than 'test.crop_size' in both dimensions, and is
          otherwise predicted as a whole.  All logits are resized to the
          original image size and summed.

          The mirrored pass uses the current scale, as ``ms_test`` does,
          rather than a fixed 1.0.

          Returns:
              float32 array of shape (ori_height, ori_width, num_classes).
        """
        ori_width, ori_height = ImageHelper.get_size(ori_image)
        crop_size = self.configer.get('test', 'crop_size')
        total_logits = np.zeros((ori_height, ori_width, self.configer.get('data', 'num_classes')), np.float32)

        def infer(blob):
            # crop_size is indexed as (width, height).
            if blob.size()[3] > crop_size[0] and blob.size()[2] > crop_size[1]:
                return self._crop_predict(blob, crop_size)
            return self._predict(blob)

        for scale in self.configer.get('test', 'scale_search'):
            image, border_hw = self._get_blob(ori_image, scale=scale)
            results = infer(image)
            results = cv2.resize(results[:border_hw[0], :border_hw[1]],
                                 (ori_width, ori_height), interpolation=cv2.INTER_CUBIC)
            total_logits += results

            if self.configer.get('data', 'image_tool') == 'cv2':
                mirror_image = cv2.flip(ori_image, 1)
            else:
                mirror_image = ori_image.transpose(Image.FLIP_LEFT_RIGHT)

            image, border_hw = self._get_blob(mirror_image, scale=scale)
            results = infer(image)
            results = results[:border_hw[0], :border_hw[1]]
            # Undo the horizontal flip before accumulating.
            results = cv2.resize(results[:, ::-1], (ori_width, ori_height), interpolation=cv2.INTER_CUBIC)
            total_logits += results

        return total_logits

    def ms_test(self, ori_image):
        """
          Multi-scale inference with horizontal flipping.

          Every scale in 'test.scale_search' contributes two predictions:
          one on the image and one on its mirrored copy (flipped back
          before accumulation).  All logits are resized to the original
          image size and summed.

          Returns:
              float32 array of shape (ori_height, ori_width, num_classes).
        """
        ori_width, ori_height = ImageHelper.get_size(ori_image)
        num_classes = self.configer.get('data', 'num_classes')
        total_logits = np.zeros((ori_height, ori_width, num_classes), np.float32)
        target_size = (ori_width, ori_height)

        for scale in self.configer.get('test', 'scale_search'):
            # Pass 1: the image as is.
            blob, border_hw = self._get_blob(ori_image, scale=scale)
            logits = self._predict(blob)
            total_logits += cv2.resize(logits[:border_hw[0], :border_hw[1]],
                                       target_size, interpolation=cv2.INTER_CUBIC)

            # Pass 2: the horizontally mirrored image.
            uses_cv2 = self.configer.get('data', 'image_tool') == 'cv2'
            mirror_image = cv2.flip(ori_image, 1) if uses_cv2 \
                else ori_image.transpose(Image.FLIP_LEFT_RIGHT)

            blob, border_hw = self._get_blob(mirror_image, scale=scale)
            logits = self._predict(blob)
            logits = logits[:border_hw[0], :border_hw[1]]
            # Reverse the width axis to undo the mirroring.
            total_logits += cv2.resize(logits[:, ::-1], target_size,
                                       interpolation=cv2.INTER_CUBIC)

        return total_logits

    def _crop_predict(self, image, crop_size):
        height, width = image.size()[2:]
        np_image = image.squeeze(0).permute(1, 2, 0).cpu().numpy()
        height_starts = self._decide_intersection(height, crop_size[1])
        width_starts = self._decide_intersection(width, crop_size[0])
        split_crops = []
        for height in height_starts:
            for width in width_starts:
                image_crop = np_image[height:height + crop_size[1], width:width + crop_size[0]]
                split_crops.append(image_crop[np.newaxis, :])

        split_crops = np.concatenate(split_crops, axis=0)  # (n, crop_image_size, crop_image_size, 3)
        inputs = torch.from_numpy(split_crops).permute(0, 3, 1, 2).to(self.device)
        with torch.no_grad():
            results = self.seg_net.forward(inputs)
            results = results[-1].permute(0, 2, 3, 1).cpu().numpy()

        reassemble = np.zeros((np_image.shape[0], np_image.shape[1], results.shape[-1]), np.float32)
        index = 0
        for height in height_starts:
            for width in width_starts:
                reassemble[height:height+crop_size[1], width:width+crop_size[0]] += results[index]
                index += 1

        return reassemble

    def _decide_intersection(self, total_length, crop_length):
        stride = int(crop_length * self.configer.get('test', 'crop_stride_ratio'))            # set the stride as the paper do
        times = (total_length - crop_length) // stride + 1
        cropped_starting = []
        for i in range(times):
            cropped_starting.append(stride*i)

        if total_length - cropped_starting[-1] > crop_length:
            cropped_starting.append(total_length - crop_length)  # must cover the total image

        return cropped_starting

    def _predict(self, inputs):
        with torch.no_grad():
            results = self.seg_net.forward(inputs)
            results = results[-1].squeeze(0).permute(1, 2, 0).cpu().numpy()

        return results

    def __relabel(self, label_map):
        height, width = label_map.shape
        label_dst = np.zeros((height, width), dtype=np.uint8)
        for i in range(self.configer.get('data', 'num_classes')):
            label_dst[label_map == i] = self.configer.get('data', 'label_list')[i]

        label_dst = np.array(label_dst, dtype=np.uint8)

        return label_dst

    def test(self):
        """
          Run inference on a single image ('test.test_img') or on every file
          of a directory ('test.test_dir').

          Exactly one of the two settings must be given; otherwise an error
          is logged and the process exits with status 1.  For each image a
          label map, a copy of the raw image and a visualisation are written
          below the results directory.
        """
        base_dir = os.path.join(self.configer.get('project_dir'), 'results', self.configer.get('dataset'))

        test_img = self.configer.get('test', 'test_img')
        test_dir = self.configer.get('test', 'test_dir')
        if test_img is None and test_dir is None:
            Log.error('test_img & test_dir not exists.')
            exit(1)

        if test_img is not None and test_dir is not None:
            Log.error('Either test_img or test_dir.')
            exit(1)

        def process(image_path, filename, out_root):
            # File name without its last extension.
            stem = '.'.join(filename.split('.')[:-1])
            label_path = os.path.join(out_root, 'label', '{}.png'.format(stem))
            raw_path = os.path.join(out_root, 'raw', filename)
            vis_path = os.path.join(out_root, 'vis', '{}_vis.png'.format(stem))
            for out_path in (label_path, raw_path, vis_path):
                FileHelper.make_dirs(out_path, is_file=True)

            self.__test_img(image_path, label_path, vis_path, raw_path)

        if test_img is not None:
            base_dir = os.path.join(base_dir, 'test_img')
            process(test_img, test_img.rstrip().split('/')[-1], base_dir)
            return

        base_dir = os.path.join(base_dir, 'test_dir',
                                self.configer.get('checkpoints', 'checkpoints_name'),
                                self.configer.get('test', 'out_dir'))
        FileHelper.make_dirs(base_dir)

        for filename in FileHelper.list_dir(test_dir):
            process(os.path.join(test_dir, filename), filename, base_dir)

    def debug(self):
        """
          Visual sanity check of the training data.

          Colorizes the label map of each training sample on top of its
          image, writes the result below
          '<project_dir>/vis/results/<dataset>/debug' and shows it in a
          window, waiting for a key press after every sample.
        """
        base_dir = os.path.join(self.configer.get('project_dir'),
                                'vis/results', self.configer.get('dataset'), 'debug')

        if not os.path.exists(base_dir):
            os.makedirs(base_dir)

        # Number of samples handled so far, across all batches.
        count = 0
        for i, data_dict in enumerate(self.seg_data_loader.get_trainloader()):
            inputs = data_dict['img']
            targets = data_dict['labelmap']
            for j in range(inputs.size(0)):
                count = count + 1
                # Stop the whole process (exit status 1) after 20 samples.
                if count > 20:
                    exit(1)

                image_bgr = self.blob_helper.tensor2bgr(inputs[j])
                label_map = targets[j].numpy()
                image_canvas = self.seg_parser.colorize(label_map, image_canvas=image_bgr)
                # Output file is named by batch index and index inside the batch.
                cv2.imwrite(os.path.join(base_dir, '{}_{}_vis.png'.format(i, j)), image_canvas)
                cv2.imshow('main', image_canvas)
                # Blocks until a key is pressed.
                cv2.waitKey()