Example #1
0
class FeatureExtractor(object):
    '''
    Deploy-time feature extractor.

    Restores a network from a checkpoint, switches it to eval mode and writes
    one whitespace-separated ``.feat`` text file per input image.
    '''
    def __init__(self, args, device='cuda'):
        '''
        Load the checkpoint at ``args.model_path`` and build the deploy model.

        Args:
            args: parsed arguments; must expose ``model_path`` and is also
                forwarded to ``Configer`` as ``args_parser``.
            device: requested device string; replaced by ``'cpu'`` when no
                CUDA device is visible.
        '''
        if torch.cuda.device_count() == 0:
            device = 'cpu'

        self.device = torch.device(device)
        Log.info('Resuming from {}...'.format(args.model_path))
        # Deserialize onto the CPU first so that a checkpoint saved from a GPU
        # run can still be restored on a CPU-only host; the network is moved
        # to the target device below.
        # NOTE(review): torch.load unpickles arbitrary objects -- only load
        # checkpoints from trusted sources.
        checkpoint_dict = torch.load(args.model_path, map_location='cpu')
        self.configer = Configer(config_dict=checkpoint_dict['config_dict'],
                                 args_parser=args,
                                 valid_flag="deploy")
        self.net = ModelManager(self.configer).get_deploy_model()
        RunnerHelper.load_state_dict(self.net, checkpoint_dict['state_dict'],
                                     False)
        if device == 'cuda':
            self.net = DataParallelModel(self.net, gather_=True)

        self.net = self.net.to(self.device).eval()
        self.test_loader = DataLoader(self.configer)

    def run(self, root_dir, batch_size):
        '''
        Extract features for every image under ``root_dir``.

        Each feature vector is written as one space-separated line to
        ``<root_dir>_feat/<image stem>.feat``.

        Args:
            root_dir: directory handed to the test loader as ``test_dir``.
            batch_size: batch size handed to the test loader.
        '''
        out_dir = '{}_feat'.format(root_dir.rstrip('/'))
        loader = self.test_loader.get_testloader(test_dir=root_dir,
                                                 batch_size=batch_size)
        for data_dict in loader:
            with torch.no_grad():
                feat = self.net(data_dict['img'])

            norm_feat_arr = feat.cpu().numpy()
            for idx, meta in enumerate(data_dict['meta']):
                save_name = '{}.feat'.format(
                    os.path.splitext(meta['filename'])[0])
                save_path = os.path.join(out_dir, save_name)
                # The filename may carry sub-directories, so the parent
                # directory is ensured per file.
                FileHelper.make_dirs(save_path, is_file=True)
                line = " ".join(str(x) for x in norm_feat_arr[idx]) + '\n'
                # Context manager closes the file even if the write fails.
                with open(save_path, 'w') as ffeat:
                    ffeat.write(line)
class Tester(object):
    """
    Offset-map tester.

    Runs a network that outputs mask logits and direction logits, converts
    them to per-pixel offset vectors (via ``DTOffsetHelper``) and stores one
    ``.mat`` file per image in the configured output directory.
    """
    def __init__(self, configer):
        """
        Prepare the configuration, helpers, model and data loader.

        Note that ``configer`` is modified in place: random transforms are
        dropped from the validation transform sequence, the train/val data
        transformer settings are overwritten with the test input size, and
        the val/test batch sizes are taken from the ``batch_size`` environment
        variable (default 16).
        """
        self.crop_size = configer.get('train',
                                      'data_transformer')['input_size']
        val_trans_seq = [
            x for x in configer.get('val_trans', 'trans_seq')
            if 'random' not in x
        ]
        configer.update(('val_trans', 'trans_seq'), val_trans_seq)
        configer.get('val', 'data_transformer')['input_size'] = configer.get(
            'test', 'data_transformer').get('input_size', None)
        configer.update(('train', 'data_transformer'),
                        configer.get('val', 'data_transformer'))
        configer.update(('val', 'batch_size'),
                        int(os.environ.get('batch_size', 16)))
        configer.update(('test', 'batch_size'),
                        int(os.environ.get('batch_size', 16)))

        self.save_dir = configer.get('test', 'out_dir')
        self.dataset_name = configer.get('test', 'eval_set')
        self.sscrop = configer.get('test', 'sscrop')

        self.configer = configer
        self.batch_time = AverageMeter()
        self.data_time = AverageMeter()
        self.loss_manager = LossManager(configer)
        self.module_runner = ModuleRunner(configer)
        self.model_manager = ModelManager(configer)
        self.seg_data_loader = DataLoader(configer)
        self.seg_net = None
        self.test_loader = None
        self.test_size = None
        self.infer_time = 0
        self.infer_cnt = 0
        self._init_model()

        pprint.pprint(configer.params_root)

    def _init_model(self):
        """Build and load the network and pick the loader for the eval set."""
        self.seg_net = self.model_manager.semantic_segmentor()
        self.seg_net = self.module_runner.load_net(self.seg_net)

        assert self.dataset_name in ('train', 'val',
                                     'test'), 'Cannot infer dataset name'

        self.size_mode = self.configer.get(self.dataset_name,
                                           'data_transformer')['size_mode']

        if self.dataset_name != 'test':
            self.test_loader = self.seg_data_loader.get_valloader(
                self.dataset_name)
        else:
            self.test_loader = self.seg_data_loader.get_testloader(
                self.dataset_name)
        # val and test batch sizes are set to the same value in __init__.
        self.test_size = len(self.test_loader) * self.configer.get(
            'val', 'batch_size')

    def test(self, data_loader=None):
        """
        Predict an offset map for every image of the loader and save it.

        Each offset map is cropped to the un-padded ``border_size``, resized
        to ``ori_img_size`` with nearest-neighbour interpolation and written
        to ``<save_dir>/<image stem>.mat`` under the key ``'mat'``.

        Args:
            data_loader: unused; kept for interface compatibility.
        """
        self.seg_net.eval()
        start_time = time.time()
        image_id = 0

        Log.info('save dir {}'.format(self.save_dir))
        FileHelper.make_dirs(self.save_dir, is_file=False)
        dest_dir = self.save_dir
        # The destination never changes, so create it once up front.
        os.makedirs(dest_dir, exist_ok=True)

        print('Total batches', len(self.test_loader))
        for data_dict in self.test_loader:
            inputs = [data_dict['img']]
            names = data_dict['name']
            metas = data_dict['meta']

            with torch.no_grad():
                offsets, logits = self.extract_offset(inputs)
                print([x.shape for x in logits])
                for k in range(len(inputs[0])):
                    image_id += 1
                    ori_img_size = metas[k]['ori_img_size']
                    border_size = metas[k]['border_size']
                    offset = offsets[k].squeeze().cpu().numpy()
                    offset = cv2.resize(
                        offset[:border_size[1], :border_size[0]],
                        tuple(ori_img_size),
                        interpolation=cv2.INTER_NEAREST)
                    print(image_id)

                    # Strip the extension if there is one.
                    stem = names[k].rpartition('.')[0] or names[k]
                    dest_name = os.path.join(dest_dir, stem + '.mat')
                    print('Shape:', offset.shape, 'Saving to', dest_name)

                    mat_dict = {'mat': offset}

                    scipy.io.savemat(dest_name, mat_dict, do_compression=True)
                    # Read the file back; if the compressed file cannot be
                    # loaded, fall back to an uncompressed write.
                    try:
                        scipy.io.loadmat(dest_name)
                    except Exception as e:
                        print(e)
                        scipy.io.savemat(dest_name,
                                         mat_dict,
                                         do_compression=False)

            self.batch_time.update(time.time() - start_time)
            start_time = time.time()

        Log.info('Test Time {batch_time.sum:.3f}s'.format(
            batch_time=self.batch_time))

    def extract_offset(self, inputs):
        """
        Run inference and convert the network outputs to offset maps.

        Args:
            inputs: single-element list holding the image batch.

        Returns:
            ``(offsets, logits)``: two lists with one entry per image --
            the offset tensor from ``_get_offset`` and channel 1 of the
            resized mask output.

        Raises:
            RuntimeError: if neither ``sscrop`` nor the ``ss_test`` mode is
                configured.
        """
        if self.sscrop:
            outputs = self.sscrop_test(inputs, self.crop_size)
        elif self.configer.get('test', 'mode') == 'ss_test':
            outputs = self.ss_test(inputs)
        else:
            # Previously this fell through with ``outputs`` unbound.
            raise RuntimeError("Unsupported test mode: {}".format(
                self.configer.get('test', 'mode')))

        offsets = []
        logits = []

        for mask_logits, dir_logits, img in zip(*outputs[:2], inputs[0]):
            h, w = img.shape[1:]

            mask_logits = F.interpolate(mask_logits.unsqueeze(0),
                                        size=(h, w),
                                        mode='bilinear',
                                        align_corners=True)
            dir_logits = F.interpolate(dir_logits.unsqueeze(0),
                                       size=(h, w),
                                       mode='bilinear',
                                       align_corners=True)

            logits.append(mask_logits[:, 1])
            offsets.append(self._get_offset(mask_logits, dir_logits))
        print([x.shape for x in offsets])
        return offsets, logits

    def _get_offset(self, mask_logits, dir_logits):
        """
        Turn mask and direction predictions into an offset map.

        Pixels whose mask channel 1 is not above 0.5 get a zero offset.

        Returns:
            Offset tensor with the vector components in the last dimension.
        """
        edge_mask = mask_logits[:, 1] > 0.5

        # argmax is invariant under softmax, so the raw logits are used.
        dir_label = torch.argmax(dir_logits, dim=1).float()
        offset = DTOffsetHelper.label_to_vector(dir_label)
        offset = offset.permute(0, 2, 3, 1)
        offset[~edge_mask, :] = 0
        return offset

    def _flip(self, x, dim=-1):
        """Return ``x`` reversed along dimension ``dim``."""
        indices = [slice(None)] * x.dim()
        indices[dim] = torch.arange(x.size(dim) - 1,
                                    -1,
                                    -1,
                                    dtype=torch.long,
                                    device=x.device)
        return x[tuple(indices)]

    def _flip_offset(self, x):
        """
        Flip a direction map along its last dimension and re-order its
        channels with ``DTOffsetHelper.flipping_indices()``.
        """
        x = self._flip(x, dim=-1)
        if len(x.shape) == 4:
            return x[:, DTOffsetHelper.flipping_indices()]
        else:
            return x[DTOffsetHelper.flipping_indices()]

    def _flip_inputs(self, inputs):
        """Flip every input (or every item of every input list) along the last dim."""
        if self.size_mode == 'fix_size':
            return [self._flip(x, -1) for x in inputs]
        else:
            return [[self._flip(x, -1) for x in xs] for xs in inputs]

    def _flip_outputs(self, outputs):
        """Undo a flip: plain flip for output 0, direction-aware flip for output 1."""
        funcs = [self._flip, self._flip_offset]
        if self.size_mode == 'fix_size':
            return [f(x) for f, x in zip(funcs, outputs)]
        else:
            return [[f(x) for x in xs] for f, xs in zip(funcs, outputs)]

    def _tuple_sum(self, tup1, tup2, tup2_weight=1):
        """
        Element-wise ``tup1 + tup2 * tup2_weight``.

        tup1 / tup2: tuple of tensors or tuple of list of tensors.
        ``tup1`` may be ``None``, in which case only the weighted ``tup2`` is
        returned.
        """
        if tup1 is None:
            if self.size_mode == 'fix_size':
                return [y * tup2_weight for y in tup2]
            else:
                return [[y * tup2_weight for y in ys] for ys in tup2]
        else:
            if self.size_mode == 'fix_size':
                return [x + y * tup2_weight for x, y in zip(tup1, tup2)]
            else:
                return [[x + y * tup2_weight for x, y in zip(xs, ys)]
                        for xs, ys in zip(tup1, tup2)]

    def _scale_ss_inputs(self, inputs, scale):
        """
        Bilinearly rescale ``inputs[0]`` by ``scale``.

        Returns:
            ``([scaled_tensor], (h, w))`` where ``(h, w)`` is the original size.
        """
        n, c, h, w = inputs[0].shape
        size = (int(h * scale), int(w * scale))
        return [
            F.interpolate(inputs[0],
                          size=size,
                          mode="bilinear",
                          align_corners=True),
        ], (h, w)

    def sscrop_test(self, inputs, crop_size, scale=1):
        '''
        Sliding-window inference: overlapping crops are predicted with
        ``ss_test`` and averaged.

        Currently, sscrop_test does not support diverse_size testing.
        ``scale`` is accepted but not used.

        Returns:
            List of two tensors with 2 and 8 channels respectively.
        '''
        scaled_inputs = inputs
        img = scaled_inputs[0]
        n, c, h, w = img.size(0), img.size(1), img.size(2), img.size(3)
        ori_h, ori_w = h, w
        # Accumulators are allocated on the current CUDA device.
        full_probs = [
            torch.cuda.FloatTensor(n, dim, h, w).fill_(0) for dim in (2, 8)
        ]
        count_predictions = [
            torch.cuda.FloatTensor(n, dim, h, w).fill_(0) for dim in (2, 8)
        ]

        crop_counter = 0

        height_starts = self._decide_intersection(h, crop_size[0])
        width_starts = self._decide_intersection(w, crop_size[1])

        for height in height_starts:
            for width in width_starts:
                crop_inputs = [
                    x[..., height:height + crop_size[0],
                      width:width + crop_size[1]] for x in scaled_inputs
                ]
                prediction = self.ss_test(crop_inputs)

                for j in range(2):
                    count_predictions[j][:, :, height:height + crop_size[0],
                                         width:width + crop_size[1]] += 1
                    full_probs[j][:, :, height:height + crop_size[0],
                                  width:width + crop_size[1]] += prediction[j]
                crop_counter += 1
                Log.info('predicting {:d}-th crop'.format(crop_counter))

        for j in range(2):
            # Average the overlapping predictions.
            full_probs[j] /= count_predictions[j]
            full_probs[j] = F.interpolate(full_probs[j],
                                          size=(ori_h, ori_w),
                                          mode='bilinear',
                                          align_corners=True)
        return full_probs

    def _scale_ss_outputs(self, outputs, size):
        """Bilinearly resize every tensor in ``outputs`` to ``size``."""
        return [
            F.interpolate(x, size=size, mode="bilinear", align_corners=True)
            for x in outputs
        ]

    def ss_test(self, inputs, scale=1):
        """
        Single-scale inference.

        In ``fix_size`` mode the batch goes through ``self.seg_net`` directly;
        otherwise every sample is sent to its own model replica on the
        configured GPUs. When the network returns three outputs only the
        first and third are kept.

        Returns:
            Outputs resized back to the input resolution (per-output tensors
            in ``fix_size`` mode, per-output lists of tensors otherwise).
        """
        if self.size_mode == 'fix_size':

            scaled_inputs, orig_size = self._scale_ss_inputs(inputs, scale)
            print([x.shape for x in scaled_inputs])

            outputs = list(self.seg_net.forward(*scaled_inputs))
            if len(outputs) == 3:
                outputs = (outputs[0], outputs[2])
            else:
                outputs[0] = F.softmax(outputs[0], dim=1)
            if torch.cuda.is_available():
                torch.cuda.synchronize()

            return self._scale_ss_outputs(outputs, orig_size)

        else:
            device_ids = self.configer.get('gpu')
            replicas = nn.parallel.replicate(self.seg_net.module, device_ids)
            scaled_inputs, ori_sizes, outputs = [], [], []

            # zip stops at the shorter sequence: at most one sample per device.
            for *i, d in zip(*inputs, device_ids):
                scaled_i, ori_size_i = self._scale_ss_inputs(
                    [x.unsqueeze(0) for x in i], scale)
                scaled_inputs.append(
                    [x.cuda(d, non_blocking=True) for x in scaled_i])
                ori_sizes.append(ori_size_i)

            scaled_outputs = nn.parallel.parallel_apply(
                replicas[:len(scaled_inputs)], scaled_inputs)

            for o, ori_size in zip(scaled_outputs, ori_sizes):
                o = self._scale_ss_outputs(o, ori_size)
                if len(o) == 3:
                    o = (o[0], o[2])
                outputs.append([x.squeeze(0) for x in o])
            # Transpose from per-sample lists to per-output lists.
            outputs = list(map(list, zip(*outputs)))
            return outputs

    def _decide_intersection(self,
                             total_length,
                             crop_length,
                             crop_stride_ratio=1 / 3):
        """
        Compute the start offsets of sliding crops along one axis.

        Args:
            total_length: length of the axis.
            crop_length: length of one crop.
            crop_stride_ratio: stride as a fraction of ``crop_length``.

        Returns:
            Ascending list of start offsets covering the whole axis. When the
            axis is not longer than one crop the result is ``[0]``.
        """
        # Clamp to 1 so tiny crops cannot produce a zero stride.
        stride = max(int(crop_length * crop_stride_ratio), 1)
        if total_length <= crop_length:
            # A single (possibly clipped) crop already covers the axis.
            return [0]

        times = (total_length - crop_length) // stride + 1
        cropped_starting = [stride * i for i in range(times)]

        if total_length - cropped_starting[-1] > crop_length:
            cropped_starting.append(total_length -
                                    crop_length)  # must cover the total image

        return cropped_starting
Example #3
0
class VideoTester(object):
    """
      The class for Pose Estimation. Include train, val, val & predict.
    """
    def __init__(self, configer):
        """
        Create the helper objects and load the network.

        Args:
            configer: project configuration object; it is stored and handed to
                every helper constructed here.
        """
        self.configer = configer
        # Timing meters; only batch_time is updated in the code visible here.
        self.batch_time = AverageMeter()
        self.data_time = AverageMeter()
        self.seg_visualizer = SegVisualizer(configer)
        self.loss_manager = LossManager(configer)
        self.module_runner = ModuleRunner(configer)
        self.model_manager = ModelManager(configer)
        self.optim_scheduler = OptimScheduler(configer)
        self.seg_data_loader = DataLoader(configer)
        # Output directory; _init_model also inspects this path to choose
        # between the test and validation loaders.
        self.save_dir = self.configer.get('test', 'out_dir')
        # Placeholders filled in by _init_model().
        self.seg_net = None
        self.test_loader = None
        self.test_size = None
        self.infer_time = 0
        self.infer_cnt = 0
        self._init_model()

    def _init_model(self):
        """
        Build the segmentation network, load its weights and select a loader.

        The test loader is used when the output directory path contains the
        substring ``'test'``; otherwise the validation loader is used. The
        network is left in eval mode.
        """
        self.seg_net = self.model_manager.semantic_segmentor()
        self.seg_net = self.module_runner.load_net(self.seg_net)

        print(f"self.save_dir {self.save_dir}")
        use_test_split = 'test' in self.save_dir
        if use_test_split:
            self.test_loader = self.seg_data_loader.get_testloader()
            split = 'test'
        else:
            self.test_loader = self.seg_data_loader.get_valloader()
            split = 'val'

        # Number of batches times the batch size of the selected split.
        batch_size = self.configer.get(split, 'batch_size')
        self.test_size = len(self.test_loader) * batch_size
        if use_test_split:
            print(f"self.test_size {self.test_size}")

        self.seg_net.eval()

    def __relabel(self, label_map):
        height, width = label_map.shape
        label_dst = np.zeros((height, width), dtype=np.uint8)
        for i in range(self.configer.get('data', 'num_classes')):
            label_dst[label_map == i] = self.configer.get(
                'data', 'label_list')[i]

        label_dst = np.array(label_dst, dtype=np.uint8)

        return label_dst

    def test(self, data_loader=None):
        """
        Segment every frame of the configured input video and write an
        annotated video.

        The video path is read from the ``input_video`` config entry. For
        each frame the predicted label map is blended over the original
        frame, a ratio panel ``RATIO_IMG_W`` pixels wide is appended on the
        right, and the result is written to ``<input stem>_out.avi`` in the
        current working directory.

        Relies on ``self.get_ratio_all``, ``self.visualize`` and
        ``self.visualize_ratio`` being defined elsewhere on the class.

        Args:
            data_loader: unused; kept for interface compatibility.
        """
        print("test!!!")
        self.seg_net.eval()
        start_time = time.time()

        Log.info('save dir {}'.format(self.save_dir))
        FileHelper.make_dirs(self.save_dir, is_file=False)

        RATIO_IMG_W = 200  # width of the panel appended to each frame
        alpha = 0.5  # blending weight of the segmentation overlay

        # Reader.
        input_path = self.configer.get('input_video')
        cap = cv2.VideoCapture(input_path)
        out = None
        try:
            total_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)
            v_w = cap.get(cv2.CAP_PROP_FRAME_WIDTH)
            v_h = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
            fps = cap.get(cv2.CAP_PROP_FPS)

            # Writer.
            output_name = '.'.join(
                os.path.basename(input_path).split('.')[:-1])
            output_name = output_name + '_out.avi'
            fourcc = cv2.VideoWriter_fourcc(*'XVID')
            out = cv2.VideoWriter(output_name, fourcc, fps,
                                  (int(v_w + RATIO_IMG_W), int(v_h)))

            transform = trans.Compose([
                trans.ToTensor(),
                trans.Normalize(div_value=self.configer.get(
                    'normalize', 'div_value'),
                                mean=self.configer.get('normalize', 'mean'),
                                std=self.configer.get('normalize', 'std')),
            ])

            aug_val_transform = cv2_aug_transforms.CV2AugCompose(
                self.configer, split='val')

            for _frame_idx in tqdm(range(int(total_frames))):
                ret, img = cap.read()

                if not ret:
                    # The stream ended before the reported frame count.
                    break

                ori_img = img.copy()

                frame_h, frame_w = img.shape[:2]
                ori_img_size = (frame_w, frame_h)  # cv2 expects (w, h)

                # The augmentation returns a sequence; the image is item 0.
                img = aug_val_transform(img)[0]
                img = transform(img)

                with torch.no_grad():
                    # Forward pass.
                    outputs = self.ss_test([img])

                    if isinstance(outputs, torch.Tensor):
                        outputs = outputs.permute(0, 2, 3, 1).cpu().numpy()
                    else:
                        outputs = [
                            output.permute(0, 2, 3,
                                           1).cpu().numpy().squeeze()
                            for output in outputs
                        ]

                    logits = cv2.resize(outputs[0],
                                        ori_img_size,
                                        interpolation=cv2.INTER_CUBIC)
                    label_img = np.asarray(np.argmax(logits, axis=-1),
                                           dtype=np.uint8)
                    if self.configer.exists(
                            'data',
                            'reduce_zero_label') and self.configer.get(
                                'data', 'reduce_zero_label'):
                        label_img = label_img + 1
                        label_img = label_img.astype(np.uint8)
                    if self.configer.exists('data', 'label_list'):
                        label_img_ = self.__relabel(label_img)
                    else:
                        label_img_ = label_img

                    lines = self.get_ratio_all(label_img_)
                    vis_img = self.visualize(label_img_)

                    # Blend the overlay into ori_img in place.
                    cv2.addWeighted(vis_img, alpha, ori_img, 1 - alpha, 0,
                                    ori_img)

                    ratio_img = self.visualize_ratio(lines, (v_h, v_w),
                                                     RATIO_IMG_W)

                    target_img = cv2.hconcat([ori_img, ratio_img])
                    target_img = cv2.cvtColor(target_img, cv2.COLOR_RGB2BGR)

                    out.write(target_img)

                self.batch_time.update(time.time() - start_time)
                start_time = time.time()
        finally:
            # Release the handles so the output container is finalized even
            # when processing fails part-way through.
            cap.release()
            if out is not None:
                out.release()

        # Print the log info & reset the states.
        Log.info('Test Time {batch_time.sum:.3f}s'.format(
            batch_time=self.batch_time))

    def offset_test(self, inputs, offset_h_maps, offset_w_maps, scale=1):
        if isinstance(inputs, torch.Tensor):
            n, c, h, w = inputs.size(0), inputs.size(1), inputs.size(
                2), inputs.size(3)
            start = timeit.default_timer()
            outputs = self.seg_net.forward(inputs, offset_h_maps,
                                           offset_w_maps)
            torch.cuda.synchronize()
            end = timeit.default_timer()

            if (self.configer.get('loss', 'loss_type')
                    == "fs_auxce_loss") or (self.configer.get(
                        'loss', 'loss_type') == "triple_auxce_loss"):
                outputs = outputs[-1]
            elif self.configer.get('loss',
                                   'loss_type') == "pyramid_auxce_loss":
                outputs = outputs[1] + outputs[2] + outputs[3] + outputs[4]

            outputs = F.interpolate(outputs,
                                    size=(h, w),
                                    mode='bilinear',
                                    align_corners=True)
            return outputs
        else:
            raise RuntimeError("Unsupport data type: {}".format(type(inputs)))

    def ss_test(self, inputs, scale=1):
        if isinstance(inputs, torch.Tensor):
            n, c, h, w = inputs.size(0), inputs.size(1), inputs.size(
                2), inputs.size(3)
            scaled_inputs = F.interpolate(inputs,
                                          size=(int(h * scale),
                                                int(w * scale)),
                                          mode="bilinear",
                                          align_corners=True)
            start = timeit.default_timer()
            outputs = self.seg_net.forward(scaled_inputs)
            torch.cuda.synchronize()
            end = timeit.default_timer()
            outputs = outputs[-1]
            outputs = F.interpolate(outputs,
                                    size=(h, w),
                                    mode='bilinear',
                                    align_corners=True)
            return outputs
        elif isinstance(inputs, collections.Sequence):
            device_ids = self.configer.get('gpu')
            replicas = nn.parallel.replicate(self.seg_net.module, device_ids)
            scaled_inputs, ori_size, outputs = [], [], []
            for i, d in zip(inputs, device_ids):
                h, w = i.size(1), i.size(2)
                ori_size.append((h, w))
                i = F.interpolate(i.unsqueeze(0),
                                  size=(int(h * scale), int(w * scale)),
                                  mode="bilinear",
                                  align_corners=True)
                scaled_inputs.append(i.cuda(d, non_blocking=True))
            scaled_outputs = nn.parallel.parallel_apply(
                replicas[:len(scaled_inputs)], scaled_inputs)
            for i, output in enumerate(scaled_outputs):
                outputs.append(
                    F.interpolate(output[-1],
                                  size=ori_size[i],
                                  mode='bilinear',
                                  align_corners=True))
            return outputs
        else:
            raise RuntimeError("Unsupport data type: {}".format(type(inputs)))

    def flip(self, x, dim):
        indices = [slice(None)] * x.dim()
        indices[dim] = torch.arange(x.size(dim) - 1,
                                    -1,
                                    -1,
                                    dtype=torch.long,
                                    device=x.device)
        return x[tuple(indices)]

    def sscrop_test(self, inputs, crop_size, scale=1):
        '''
        Sliding-window inference on a rescaled batch.

        The batch is resized by ``scale``, overlapping crops of ``crop_size``
        are predicted with ``ss_test``, overlapping predictions are averaged
        and the result is resized back to the original resolution.

        Currently, sscrop_test does not support diverse_size testing.
        '''
        ori_h, ori_w = inputs.size(2), inputs.size(3)
        scaled_inputs = F.interpolate(inputs,
                                      size=(int(ori_h * scale),
                                            int(ori_w * scale)),
                                      mode="bilinear",
                                      align_corners=True)
        n, _, h, w = scaled_inputs.shape
        num_classes = self.configer.get('data', 'num_classes')
        # Accumulators live on the current CUDA device.
        full_probs = torch.cuda.FloatTensor(n, num_classes, h, w).fill_(0)
        count_predictions = torch.cuda.FloatTensor(n, num_classes, h,
                                                   w).fill_(0)

        crop_h, crop_w = crop_size[0], crop_size[1]
        height_starts = self._decide_intersection(h, crop_h)
        width_starts = self._decide_intersection(w, crop_w)

        crop_counter = 0
        for top in height_starts:
            rows = slice(top, top + crop_h)
            for left in width_starts:
                cols = slice(left, left + crop_w)
                prediction = self.ss_test(scaled_inputs[:, :, rows, cols])
                count_predictions[:, :, rows, cols] += 1
                full_probs[:, :, rows, cols] += prediction
                crop_counter += 1
                Log.info('predicting {:d}-th crop'.format(crop_counter))

        # Average the pixels covered by more than one crop.
        full_probs /= count_predictions
        return F.interpolate(full_probs,
                             size=(ori_h, ori_w),
                             mode='bilinear',
                             align_corners=True)

    def ms_test(self, inputs):
        if isinstance(inputs, torch.Tensor):
            n, c, h, w = inputs.size(0), inputs.size(1), inputs.size(
                2), inputs.size(3)
            full_probs = torch.cuda.FloatTensor(
                n, self.configer.get('data', 'num_classes'), h, w).fill_(0)
            if self.configer.exists('test', 'scale_weights'):
                for scale, weight in zip(
                        self.configer.get('test', 'scale_search'),
                        self.configer.get('test', 'scale_weights')):
                    probs = self.ss_test(inputs, scale)
                    flip_probs = self.ss_test(self.flip(inputs, 3), scale)
                    probs = probs + self.flip(flip_probs, 3)
                    full_probs += weight * probs
                return full_probs
            else:
                for scale in self.configer.get('test', 'scale_search'):
                    probs = self.ss_test(inputs, scale)
                    flip_probs = self.ss_test(self.flip(inputs, 3), scale)
                    probs = probs + self.flip(flip_probs, 3)
                    full_probs += probs
                return full_probs

        elif isinstance(inputs, collections.Sequence):
            device_ids = self.configer.get('gpu')
            full_probs = [
                torch.zeros(1, self.configer.get('data', 'num_classes'),
                            i.size(1), i.size(2)).cuda(device_ids[index],
                                                       non_blocking=True)
                for index, i in enumerate(inputs)
            ]
            flip_inputs = [self.flip(i, 2) for i in inputs]

            if self.configer.exists('test', 'scale_weights'):
                for scale, weight in zip(
                        self.configer.get('test', 'scale_search'),
                        self.configer.get('test', 'scale_weights')):
                    probs = self.ss_test(inputs, scale)
                    flip_probs = self.ss_test(flip_inputs, scale)
                    for i in range(len(inputs)):
                        full_probs[i] += weight * (probs[i] +
                                                   self.flip(flip_probs[i], 3))
                return full_probs
            else:
                for scale in self.configer.get('test', 'scale_search'):
                    probs = self.ss_test(inputs, scale)
                    flip_probs = self.ss_test(flip_inputs, scale)
                    for i in range(len(inputs)):
                        full_probs[i] += (probs[i] +
                                          self.flip(flip_probs[i], 3))
                return full_probs

        else:
            raise RuntimeError("Unsupport data type: {}".format(type(inputs)))

    def ms_test_depth(self, inputs, names):
        prob_list = []
        scale_list = []

        if isinstance(inputs, torch.Tensor):
            n, c, h, w = inputs.size(0), inputs.size(1), inputs.size(
                2), inputs.size(3)
            full_probs = torch.cuda.FloatTensor(
                n, self.configer.get('data', 'num_classes'), h, w).fill_(0)

            for scale in self.configer.get('test', 'scale_search'):
                probs = self.ss_test(inputs, scale)
                flip_probs = self.ss_test(self.flip(inputs, 3), scale)
                probs = probs + self.flip(flip_probs, 3)
                prob_list.append(probs)
                scale_list.append(scale)

            full_probs = self.fuse_with_depth(prob_list, scale_list, names)
            return full_probs

        else:
            raise RuntimeError("Unsupport data type: {}".format(type(inputs)))

    def fuse_with_depth(self, probs, scales, names):
        """
        Fuse per-scale predictions with weights derived from a stereo map.

        For every image, ``<stereo dir>/<name>.png`` is read and converted to
        a per-pixel bucket index in the range of ``scales``; each scale's
        prediction is then weighted by ``POWER_BASE ** |bucket - scale index|``
        and accumulated.

        Args:
            probs: list of 4-D prediction tensors, one per scale.
            scales: the scales matching ``probs``.
            names: per-image file stems used to locate the stereo maps.

        Returns:
            CUDA float tensor with the fused scores.

        Raises:
            FileNotFoundError: if a stereo map cannot be read.
        """
        MAX_DEPTH = 63
        POWER_BASE = 0.8
        # NOTE(review): hard-coded dataset locations; the split is inferred
        # from the output directory name.
        if 'test' in self.save_dir:
            stereo_path = "/msravcshare/dataset/cityscapes/stereo/test/"
        else:
            stereo_path = "/msravcshare/dataset/cityscapes/stereo/val/"

        n, h, w = probs[0].size(0), probs[0].size(2), probs[0].size(3)
        full_probs = torch.cuda.FloatTensor(
            n, self.configer.get('data', 'num_classes'), h, w).fill_(0)

        for index, name in enumerate(names):
            stereo_file = stereo_path + name + '.png'
            stereo_map = cv2.imread(stereo_file, -1)
            if stereo_map is None:
                # cv2.imread signals failure by returning None.
                raise FileNotFoundError(
                    'Cannot read stereo map: {}'.format(stereo_file))
            depth_map = stereo_map / 256.0
            depth_map = 0.5 / depth_map
            depth_map = 500 * depth_map

            depth_map = np.clip(depth_map, 0, MAX_DEPTH)
            # Quantize into as many buckets as there are scales.
            depth_map = depth_map // (MAX_DEPTH // len(scales))

            for prob, scale in zip(probs, scales):
                scale_index = self._locate_scale_index(scale, scales)
                weight_map = np.abs(depth_map - scale_index)
                weight_map = np.power(POWER_BASE, weight_map)
                weight_map = cv2.resize(weight_map, (w, h))
                full_probs[index, :, :, :] += torch.from_numpy(
                    np.expand_dims(weight_map, axis=0)).type(
                        torch.cuda.FloatTensor) * prob[index, :, :, :]

        return full_probs

    @staticmethod
    def _locate_scale_index(scale, scales):
        for idx, s in enumerate(scales):
            if scale == s:
                return idx
        return 0

    def ms_test_wo_flip(self, inputs):
        if isinstance(inputs, torch.Tensor):
            n, c, h, w = inputs.size(0), inputs.size(1), inputs.size(
                2), inputs.size(3)
            full_probs = torch.cuda.FloatTensor(
                n, self.configer.get('data', 'num_classes'), h, w).fill_(0)
            for scale in self.configer.get('test', 'scale_search'):
                probs = self.ss_test(inputs, scale)
                full_probs += probs
            return full_probs
        elif isinstance(inputs, collections.Sequence):
            device_ids = self.configer.get('gpu')
            full_probs = [
                torch.zeros(1, self.configer.get('data', 'num_classes'),
                            i.size(1), i.size(2)).cuda(device_ids[index],
                                                       non_blocking=True)
                for index, i, in enumerate(inputs)
            ]
            for scale in self.configer.get('test', 'scale_search'):
                probs = self.ss_test(inputs, scale)
                for i in range(len(inputs)):
                    full_probs[i] += probs[i]
            return full_probs
        else:
            raise RuntimeError("Unsupport data type: {}".format(type(inputs)))

    def mscrop_test(self, inputs, crop_size):
        '''
        Flip-augmented multi-scale testing: scales below 1 are predicted on
        the whole image, all other scales crop-wise via sscrop_test.

        Currently, mscrop_test does not support diverse_size testing
        '''
        batch, height, width = inputs.size(0), inputs.size(2), inputs.size(3)
        accumulated = torch.cuda.FloatTensor(
            batch, self.configer.get('data', 'num_classes'), height,
            width).fill_(0)

        for scale in self.configer.get('test', 'scale_search'):
            Log.info('Scale {0:.2f} prediction'.format(scale))

            def predict(batch_inputs, scale=scale):
                # Whole-image inference when shrinking, sliding crops otherwise.
                if scale < 1:
                    return self.ss_test(batch_inputs, scale)
                return self.sscrop_test(batch_inputs, crop_size, scale)

            straight = predict(inputs)
            mirrored = predict(self.flip(inputs, 3))
            # Flip the mirrored prediction back before merging.
            accumulated += straight + self.flip(mirrored, 3)

        return accumulated

    def _decide_intersection(self, total_length, crop_length):
        stride = crop_length
        times = (total_length - crop_length) // stride + 1
        cropped_starting = []
        for i in range(times):
            cropped_starting.append(stride * i)
        if total_length - cropped_starting[-1] > crop_length:
            cropped_starting.append(total_length -
                                    crop_length)  # must cover the total image
        return cropped_starting

    def dense_crf_process(self, images, outputs):
        '''
        Refine network outputs with a dense CRF, one image at a time.

        ``outputs`` is soft-maxed over dim 1; every sample is then run through
        a DenseCRF2D with one Gaussian and one bilateral pairwise term and the
        inference result is written back into the soft-maxed tensor, which is
        returned.

        Reference: https://github.com/kazuto1011/deeplab-pytorch/blob/master/libs/utils/crf.py
        '''
        # hyperparameters of the dense crf
        # NOTE(review): the numbers below look like a log of tuning runs
        # (setting, score) -- metric and dataset are not stated; confirm.
        # baseline = 79.5
        # bi_xy_std = 67, 79.1
        # bi_xy_std = 20, 79.6
        # bi_xy_std = 10, 79.7
        # bi_xy_std = 10, iter_max = 20, v4 79.7
        # bi_xy_std = 10, iter_max = 5, v5 79.7
        # bi_xy_std = 5, v3 79.7
        iter_max = 10
        pos_w = 3
        pos_xy_std = 1
        bi_w = 4
        bi_xy_std = 10
        bi_rgb_std = 3

        b = images.size(0)
        # Three per-channel values reshaped to (3, 1, 1) so they broadcast over
        # a channel-first image (np.transpose is a no-op on a 1-D array).
        # NOTE(review): presumably the mean subtracted during preprocessing --
        # TODO confirm against the data transform.
        mean_vector = np.expand_dims(np.expand_dims(np.transpose(
            np.array([102.9801, 115.9465, 122.7717])),
                                                    axis=1),
                                     axis=2)
        outputs = F.softmax(outputs, dim=1)
        for i in range(b):
            unary = outputs[i].data.cpu().numpy()
            C, H, W = unary.shape
            unary = dcrf_utils.unary_from_softmax(unary)
            unary = np.ascontiguousarray(unary)

            # Add the mean back, cast to bytes and go channel-last for the
            # bilateral term.
            # NOTE(review): converting images[i] with numpy assumes it is a
            # CPU tensor -- verify against the caller.
            image = np.ascontiguousarray(images[i]) + mean_vector
            image = image.astype(np.ubyte)
            image = np.ascontiguousarray(image.transpose(1, 2, 0))

            # DenseCRF2D takes (width, height, n_labels).
            d = dcrf.DenseCRF2D(W, H, C)
            d.setUnaryEnergy(unary)
            d.addPairwiseGaussian(sxy=pos_xy_std, compat=pos_w)
            d.addPairwiseBilateral(sxy=bi_xy_std,
                                   srgb=bi_rgb_std,
                                   rgbim=image,
                                   compat=bi_w)
            out_crf = np.array(d.inference(iter_max))
            # Reshape the inference result to (C, H, W) and store it in place
            # of the soft-max scores for this sample.
            outputs[i] = torch.from_numpy(out_crf).cuda().view(C, H, W)

        return outputs

    def visualize(self, label_img):
        """Colorize a single-channel label image with the category palette.

        Args:
            label_img: single-channel array of label ids (it is converted with
                ``cv2.COLOR_GRAY2RGB``).

        Returns:
            3-channel image in BGR order. Pixels whose id equals a category
            ``'id'`` get that category's ``'color'`` (applied in RGB order,
            before the final RGB -> BGR conversion); id 14 is reset to 0 first.
        """
        ignored_id = 14

        img = cv2.cvtColor(label_img.copy(), cv2.COLOR_GRAY2RGB)
        img[img == ignored_id] = 0
        out_img = img.copy()

        # ``img`` is never modified inside the loop, so the id plane (all three
        # channels are equal after GRAY2RGB) is extracted once.
        id_plane = img[:, :, 0]
        for label in HYUNDAI_POC_CATEGORIES:
            mask = id_plane == label['id']
            out_img[:, :, :3][mask] = label['color']

        return cv2.cvtColor(out_img, cv2.COLOR_RGB2BGR)

    def get_ratio_all(self, anno_img):
        """Return ``[name, percentage]`` for every entry of HYUNDAI_POC_CATEGORIES.

        The percentage is the category's pixel count (see ``get_ratio``)
        relative to the sum of all counts; every entry is 0 when that sum is 0.
        """
        rows = [[label['name'],
                 self.get_ratio(anno_img.copy(), label)]
                for label in HYUNDAI_POC_CATEGORIES]
        total_size = sum(count for _, count in rows)

        for row in rows:
            row[1] = row[1] / total_size * 100 if total_size else 0
        return rows

    def get_ratio(self, anno_img, label):
        """Count the elements of ``anno_img`` equal to ``label['id']``.

        The label with id 14 is never counted and always yields 0.
        """
        label_id = label['id']
        if label_id == 14:
            return 0

        matches = anno_img == label_id
        return np.count_nonzero(matches.astype(np.uint8))

    def visualize_ratio(self, ratios, video_size, ratio_w):
        """Render the category/ratio table as a BGR image.

        The image is ``int(video_size[0])`` pixels high and ``ratio_w`` pixels
        wide, split into 14 rows. The left half of a row is filled with the
        colour of the category whose id equals the row index; the first column
        shows the name and the second the ratio (two decimals unless it is
        already a string). Row 0 is a header.

        Args:
            ratios: list of ``[name, ratio]`` rows, e.g. from ``get_ratio_all``;
                the list itself is not modified.
            video_size: indexable whose first entry is used as the image height.
            ratio_w: width of the rendered image in pixels.

        Returns:
            ``numpy`` uint8 image of shape (height, ratio_w, 3) in BGR order.
        """
        ratio_list = ratios.copy()
        # Header row; the strings are Korean for "grade" and "ratio".
        ratio_list.insert(0, ['등급', '비율'])
        RATIO_IMG_W = ratio_w
        RATIO_IMG_H = int(video_size[0])
        TEXT_MARGIN_H = 20
        TEXT_MARGIN_W = 10
        row_count = 14

        ratio_img = np.full((RATIO_IMG_H, RATIO_IMG_W, 3), 255, np.uint8)

        row_h = RATIO_IMG_H / row_count
        center_x = int(RATIO_IMG_W / 2)

        # Colour swatches: row i gets the colour of the category with id i.
        for i in range(1, row_count):
            p_y = int(i * row_h)
            p_y_n = int((i + 1) * row_h)
            for label in HYUNDAI_POC_CATEGORIES:
                if label['id'] == i:
                    cv2.rectangle(ratio_img, (0, p_y), (center_x, p_y_n),
                                  label['color'], cv2.FILLED)

        # Grid: horizontal row separators and the vertical column separator.
        for i in range(1, row_count):
            p_y = int(i * row_h)
            cv2.line(ratio_img, (0, p_y), (RATIO_IMG_W, p_y), (0, 0, 0))

        cv2.line(ratio_img, (center_x, 0), (center_x, RATIO_IMG_H), (0, 0, 0))

        # Text is drawn with PIL. The conversion and the font are set up once
        # instead of once per row.
        pil_img = Image.fromarray(ratio_img)
        draw = ImageDraw.Draw(pil_img)
        font = ImageFont.truetype("NanumGothic.ttf", 15)
        color = (0, 0, 0)
        p_w = center_x + TEXT_MARGIN_W

        # Slicing keeps at most ``row_count`` rows and tolerates shorter
        # tables, which used to fail with IndexError.
        for i, row in enumerate(ratio_list[:row_count]):
            p_y = int(i * row_h) + TEXT_MARGIN_H
            draw.text((0, p_y), row[0], font=font, fill=color)
            if isinstance(row[1], str):
                value_text = row[1]
            else:
                value_text = "{:.02f}".format(row[1])
            draw.text((p_w, p_y), value_text, font=font, fill=color)

        ratio_img = np.array(pil_img)
        return cv2.cvtColor(ratio_img, cv2.COLOR_RGB2BGR)
Example #4
0
class Tester(object):
    """
      The class for Pose Estimation. Include train, val, val & predict.
    """
    def __init__(self, configer):
        """Store the configuration, build the helper objects and load the model.

        Args:
            configer: project configuration object; it is handed to every
                helper and read here through ``get('test', 'out_dir')``.
        """
        self.configer = configer
        # Running meters; ``batch_time`` is updated once per batch in ``test``.
        self.batch_time = AverageMeter()
        self.data_time = AverageMeter()
        # Project helpers, all constructed from the same configuration.
        self.seg_visualizer = SegVisualizer(configer)
        self.loss_manager = LossManager(configer)
        self.module_runner = ModuleRunner(configer)
        self.model_manager = ModelManager(configer)
        self.optim_scheduler = OptimScheduler(configer)
        self.seg_data_loader = DataLoader(configer)
        # Output directory; its name also selects the test/val loader in
        # ``_init_model``.
        self.save_dir = self.configer.get('test', 'out_dir')
        # Filled in by ``_init_model`` below.
        self.seg_net = None
        self.test_loader = None
        self.test_size = None
        # Counters initialised to zero here.
        self.infer_time = 0
        self.infer_cnt = 0
        self._init_model()

    def _init_model(self):
        self.seg_net = self.model_manager.semantic_segmentor()
        self.seg_net = self.module_runner.load_net(self.seg_net)

        if 'test' in self.save_dir:
            self.test_loader = self.seg_data_loader.get_testloader()
            self.test_size = len(self.test_loader) * self.configer.get(
                'test', 'batch_size')
        else:
            self.test_loader = self.seg_data_loader.get_valloader()
            self.test_size = len(self.test_loader) * self.configer.get(
                'val', 'batch_size')

        self.seg_net.eval()

    def __relabel(self, label_map):
        height, width = label_map.shape
        label_dst = np.zeros((height, width), dtype=np.uint8)
        for i in range(self.configer.get('data', 'num_classes')):
            label_dst[label_map == i] = self.configer.get(
                'data', 'label_list')[i]

        label_dst = np.array(label_dst, dtype=np.uint8)

        return label_dst

    def test(self, data_loader=None):
        """
          Run inference over ``self.test_loader`` and write the results to disk.

          For every sample the label map is saved under ``<save_dir>/label/``
          and a palettised image under ``<save_dir>/vis/`` (or
          ``<save_dir>/gt_vis/`` when the ``save_gt_label`` environment
          variable is set). When the ``test.save_prob`` config entry is truthy
          the per-pixel softmax scores are saved under ``<save_dir>/prob/``.

          Args:
              data_loader: unused; kept for interface compatibility.

          Raises:
              RuntimeError: if the configured dataset has no colour palette,
                  or if ``test.mode`` is not one of the supported modes.
        """
        self.seg_net.eval()
        start_time = time.time()
        image_id = 0

        Log.info('save dir {}'.format(self.save_dir))
        FileHelper.make_dirs(self.save_dir, is_file=False)

        dataset = self.configer.get('dataset')
        if dataset in ['cityscapes', 'gta5']:
            colors = get_cityscapes_colors()
        elif dataset == 'ade20k':
            colors = get_ade_colors()
        elif dataset == 'lip':
            colors = get_lip_colors()
        elif dataset == 'pascal_context':
            colors = get_pascal_context_colors()
        elif dataset == 'pascal_voc':
            colors = get_pascal_voc_colors()
        elif dataset == 'coco_stuff':
            colors = get_cocostuff_colors()
        else:
            raise RuntimeError("Unsupport colors")

        save_prob = bool(self.configer.get('test', 'save_prob'))
        reduce_zero_label = bool(
            self.configer.exists('data', 'reduce_zero_label')
            and self.configer.get('data', 'reduce_zero_label'))
        has_label_list = self.configer.exists('data', 'label_list')

        def softmax(X, axis=0):
            # Shift by the maximum for numerical stability. ``X`` itself is
            # left untouched so the caller can keep using the raw logits.
            shifted = X - np.max(X, axis=axis, keepdims=True)
            exp = np.exp(shifted)
            return exp / np.sum(exp, axis=axis, keepdims=True)

        for data_dict in self.test_loader:
            inputs = data_dict['img']
            names = data_dict['name']
            metas = data_dict['meta']

            # NOTE(review): ``labels`` is only bound when the output dir name
            # contains 'val', but it is read further down whenever
            # ``save_gt_label`` and ``reduce_zero_label`` are both set --
            # confirm that combination never occurs for a 'test' directory.
            if 'val' in self.save_dir and os.environ.get('save_gt_label'):
                labels = data_dict['labelmap']

            with torch.no_grad():
                # Forward pass.
                if self.configer.exists('data',
                                        'use_offset') and self.configer.get(
                                            'data', 'use_offset') == 'offline':
                    offset_h_maps = data_dict['offsetmap_h']
                    offset_w_maps = data_dict['offsetmap_w']
                    outputs = self.offset_test(inputs, offset_h_maps,
                                               offset_w_maps)
                else:
                    test_mode = self.configer.get('test', 'mode')
                    if test_mode == 'ss_test':
                        outputs = self.ss_test(inputs)
                    elif test_mode == 'ms_test':
                        outputs = self.ms_test(inputs)
                    elif test_mode == 'ms_test_depth':
                        outputs = self.ms_test_depth(inputs, names)
                    elif test_mode == 'sscrop_test':
                        crop_size = self.configer.get('test', 'crop_size')
                        outputs = self.sscrop_test(inputs, crop_size)
                    elif test_mode == 'mscrop_test':
                        crop_size = self.configer.get('test', 'crop_size')
                        outputs = self.mscrop_test(inputs, crop_size)
                    elif test_mode == 'crf_ss_test':
                        outputs = self.ss_test(inputs)
                        outputs = self.dense_crf_process(inputs, outputs)
                    else:
                        # Without this branch ``outputs`` stayed unbound and
                        # the code below failed with UnboundLocalError.
                        raise RuntimeError(
                            "Unsupport test mode: {}".format(test_mode))

                # Bring the scores to channel-last numpy arrays.
                if isinstance(outputs, torch.Tensor):
                    outputs = outputs.permute(0, 2, 3, 1).cpu().numpy()
                    n = outputs.shape[0]
                else:
                    outputs = [
                        output.permute(0, 2, 3, 1).cpu().numpy().squeeze()
                        for output in outputs
                    ]
                    n = len(outputs)

                for k in range(n):
                    image_id += 1
                    ori_img_size = metas[k]['ori_img_size']
                    border_size = metas[k]['border_size']
                    # Keep the first border_size[1] rows / border_size[0]
                    # columns, then resize to the original image size.
                    logits = cv2.resize(
                        outputs[k][:border_size[1], :border_size[0]],
                        tuple(ori_img_size),
                        interpolation=cv2.INTER_CUBIC)

                    # save the logits map
                    if save_prob:
                        prob_path = os.path.join(self.save_dir, "prob/",
                                                 '{}.npy'.format(names[k]))
                        FileHelper.make_dirs(prob_path, is_file=True)
                        np.save(prob_path, softmax(logits, axis=-1))

                    label_img = np.asarray(np.argmax(logits, axis=-1),
                                           dtype=np.uint8)
                    if reduce_zero_label:
                        label_img = label_img + 1
                        label_img = label_img.astype(np.uint8)
                    if has_label_list:
                        label_img_ = self.__relabel(label_img)
                    else:
                        label_img_ = label_img
                    label_img_ = Image.fromarray(label_img_, 'P')
                    Log.info('{:4d}/{:4d} label map generated'.format(
                        image_id, self.test_size))
                    label_path = os.path.join(self.save_dir, "label/",
                                              '{}.png'.format(names[k]))
                    FileHelper.make_dirs(label_path, is_file=True)
                    ImageHelper.save(label_img_, label_path)

                    # colorize the label-map
                    if os.environ.get('save_gt_label'):
                        # NOTE(review): the ground truth replaces the
                        # prediction only when reduce_zero_label is set;
                        # otherwise the prediction is written to gt_vis/ --
                        # confirm this is intended.
                        if reduce_zero_label:
                            label_img = labels[k] + 1
                            label_img = np.asarray(label_img, dtype=np.uint8)
                        vis_dir = "gt_vis/"
                    else:
                        vis_dir = "vis/"
                    color_img_ = Image.fromarray(label_img)
                    color_img_.putpalette(colors)
                    vis_path = os.path.join(self.save_dir, vis_dir,
                                            '{}.png'.format(names[k]))
                    FileHelper.make_dirs(vis_path, is_file=True)
                    ImageHelper.save(color_img_, save_path=vis_path)

            self.batch_time.update(time.time() - start_time)
            start_time = time.time()

        # Print the log info & reset the states.
        Log.info('Test Time {batch_time.sum:.3f}s'.format(
            batch_time=self.batch_time))

    def offset_test(self, inputs, offset_h_maps, offset_w_maps, scale=1):
        if isinstance(inputs, torch.Tensor):
            n, c, h, w = inputs.size(0), inputs.size(1), inputs.size(
                2), inputs.size(3)
            start = timeit.default_timer()
            outputs = self.seg_net.forward(inputs, offset_h_maps,
                                           offset_w_maps)
            torch.cuda.synchronize()
            end = timeit.default_timer()

            if (self.configer.get('loss', 'loss_type')
                    == "fs_auxce_loss") or (self.configer.get(
                        'loss', 'loss_type') == "triple_auxce_loss"):
                outputs = outputs[-1]
            elif self.configer.get('loss',
                                   'loss_type') == "pyramid_auxce_loss":
                outputs = outputs[1] + outputs[2] + outputs[3] + outputs[4]

            outputs = F.interpolate(outputs,
                                    size=(h, w),
                                    mode='bilinear',
                                    align_corners=True)
            return outputs
        else:
            raise RuntimeError("Unsupport data type: {}".format(type(inputs)))

    def ss_test(self, inputs, scale=1):
        if isinstance(inputs, torch.Tensor):
            n, c, h, w = inputs.size(0), inputs.size(1), inputs.size(
                2), inputs.size(3)
            scaled_inputs = F.interpolate(inputs,
                                          size=(int(h * scale),
                                                int(w * scale)),
                                          mode="bilinear",
                                          align_corners=True)
            start = timeit.default_timer()
            outputs = self.seg_net.forward(scaled_inputs)
            torch.cuda.synchronize()
            end = timeit.default_timer()
            outputs = outputs[-1]
            outputs = F.interpolate(outputs,
                                    size=(h, w),
                                    mode='bilinear',
                                    align_corners=True)
            return outputs
        elif isinstance(inputs, collections.Sequence):
            device_ids = self.configer.get('gpu')
            replicas = nn.parallel.replicate(self.seg_net.module, device_ids)
            scaled_inputs, ori_size, outputs = [], [], []
            for i, d in zip(inputs, device_ids):
                h, w = i.size(1), i.size(2)
                ori_size.append((h, w))
                i = F.interpolate(i.unsqueeze(0),
                                  size=(int(h * scale), int(w * scale)),
                                  mode="bilinear",
                                  align_corners=True)
                scaled_inputs.append(i.cuda(d, non_blocking=True))
            scaled_outputs = nn.parallel.parallel_apply(
                replicas[:len(scaled_inputs)], scaled_inputs)
            for i, output in enumerate(scaled_outputs):
                outputs.append(
                    F.interpolate(output[-1],
                                  size=ori_size[i],
                                  mode='bilinear',
                                  align_corners=True))
            return outputs
        else:
            raise RuntimeError("Unsupport data type: {}".format(type(inputs)))

    def flip(self, x, dim):
        """Return a copy of tensor ``x`` with the order of entries along ``dim`` reversed."""
        return torch.flip(x, dims=[dim])

    def sscrop_test(self, inputs, crop_size, scale=1):
        '''
        Sliding-crop inference at a single scale: the rescaled input is cut
        into crops of crop_size, each crop is predicted with ss_test, the
        predictions are averaged where crops overlap and the result is
        resized back to the input size.

        Currently, sscrop_test does not support diverse_size testing
        '''
        ori_h, ori_w = inputs.size(2), inputs.size(3)
        resized = F.interpolate(inputs,
                                size=(int(ori_h * scale), int(ori_w * scale)),
                                mode="bilinear",
                                align_corners=True)
        batch, scaled_h, scaled_w = (resized.size(0), resized.size(2),
                                     resized.size(3))

        num_classes = self.configer.get('data', 'num_classes')
        score_sum = torch.cuda.FloatTensor(batch, num_classes, scaled_h,
                                           scaled_w).fill_(0)
        hit_count = torch.cuda.FloatTensor(batch, num_classes, scaled_h,
                                           scaled_w).fill_(0)

        crop_h, crop_w = crop_size[0], crop_size[1]
        row_starts = self._decide_intersection(scaled_h, crop_h)
        col_starts = self._decide_intersection(scaled_w, crop_w)

        crop_counter = 0
        for top in row_starts:
            rows = slice(top, top + crop_h)
            for left in col_starts:
                cols = slice(left, left + crop_w)
                prediction = self.ss_test(resized[:, :, rows, cols])
                hit_count[:, :, rows, cols] += 1
                score_sum[:, :, rows, cols] += prediction
                crop_counter += 1
                Log.info('predicting {:d}-th crop'.format(crop_counter))

        # Average the scores over the number of crops covering each pixel.
        score_sum /= hit_count
        return F.interpolate(score_sum,
                             size=(ori_h, ori_w),
                             mode='bilinear',
                             align_corners=True)

    def ms_test(self, inputs):
        if isinstance(inputs, torch.Tensor):
            n, c, h, w = inputs.size(0), inputs.size(1), inputs.size(
                2), inputs.size(3)
            full_probs = torch.cuda.FloatTensor(
                n, self.configer.get('data', 'num_classes'), h, w).fill_(0)
            if self.configer.exists('test', 'scale_weights'):
                for scale, weight in zip(
                        self.configer.get('test', 'scale_search'),
                        self.configer.get('test', 'scale_weights')):
                    probs = self.ss_test(inputs, scale)
                    flip_probs = self.ss_test(self.flip(inputs, 3), scale)
                    probs = probs + self.flip(flip_probs, 3)
                    full_probs += weight * probs
                return full_probs
            else:
                for scale in self.configer.get('test', 'scale_search'):
                    probs = self.ss_test(inputs, scale)
                    flip_probs = self.ss_test(self.flip(inputs, 3), scale)
                    probs = probs + self.flip(flip_probs, 3)
                    full_probs += probs
                return full_probs

        elif isinstance(inputs, collections.Sequence):
            device_ids = self.configer.get('gpu')
            full_probs = [
                torch.zeros(1, self.configer.get('data', 'num_classes'),
                            i.size(1), i.size(2)).cuda(device_ids[index],
                                                       non_blocking=True)
                for index, i in enumerate(inputs)
            ]
            flip_inputs = [self.flip(i, 2) for i in inputs]

            if self.configer.exists('test', 'scale_weights'):
                for scale, weight in zip(
                        self.configer.get('test', 'scale_search'),
                        self.configer.get('test', 'scale_weights')):
                    probs = self.ss_test(inputs, scale)
                    flip_probs = self.ss_test(flip_inputs, scale)
                    for i in range(len(inputs)):
                        full_probs[i] += weight * (probs[i] +
                                                   self.flip(flip_probs[i], 3))
                return full_probs
            else:
                for scale in self.configer.get('test', 'scale_search'):
                    probs = self.ss_test(inputs, scale)
                    flip_probs = self.ss_test(flip_inputs, scale)
                    for i in range(len(inputs)):
                        full_probs[i] += (probs[i] +
                                          self.flip(flip_probs[i], 3))
                return full_probs

        else:
            raise RuntimeError("Unsupport data type: {}".format(type(inputs)))

    def ms_test_depth(self, inputs, names):
        prob_list = []
        scale_list = []

        if isinstance(inputs, torch.Tensor):
            n, c, h, w = inputs.size(0), inputs.size(1), inputs.size(
                2), inputs.size(3)
            full_probs = torch.cuda.FloatTensor(
                n, self.configer.get('data', 'num_classes'), h, w).fill_(0)

            for scale in self.configer.get('test', 'scale_search'):
                probs = self.ss_test(inputs, scale)
                flip_probs = self.ss_test(self.flip(inputs, 3), scale)
                probs = probs + self.flip(flip_probs, 3)
                prob_list.append(probs)
                scale_list.append(scale)

            full_probs = self.fuse_with_depth(prob_list, scale_list, names)
            return full_probs

        else:
            raise RuntimeError("Unsupport data type: {}".format(type(inputs)))

    def fuse_with_depth(self, probs, scales, names):
        """Blend per-scale predictions with pixel-wise weights derived from a stereo map.

        Args:
            probs: list of 4-D tensors, one per scale; ``probs[0]`` fixes the
                output size (n, h, w).
            scales: list of scales, aligned with ``probs``.
            names: per-sample file stems; ``<stereo_path><name>.png`` is read
                for each of them.

        Returns:
            CUDA float tensor of shape (n, num_classes, h, w) holding the
            weighted sum over all scales.
        """
        # Largest value kept after clipping, and base of the exponential weight.
        MAX_DEPTH = 63
        POWER_BASE = 0.8
        # Hard-coded dataset locations; the split is chosen from the output dir.
        if 'test' in self.save_dir:
            stereo_path = "/msravcshare/dataset/cityscapes/stereo/test/"
        else:
            stereo_path = "/msravcshare/dataset/cityscapes/stereo/val/"

        n, c, h, w = probs[0].size(0), probs[0].size(1), probs[0].size(
            2), probs[0].size(3)
        full_probs = torch.cuda.FloatTensor(
            n, self.configer.get('data', 'num_classes'), h, w).fill_(0)

        for index, name in enumerate(names):
            # Flag -1 loads the file unchanged (original bit depth/channels).
            # NOTE(review): cv2.imread returns None for an unreadable file, in
            # which case the division below fails -- confirm inputs exist.
            stereo_map = cv2.imread(stereo_path + name + '.png', -1)
            # NOTE(review): looks like a disparity -> depth conversion
            # (depth = 500 * 0.5 / (raw / 256)) -- TODO confirm the constants.
            # Zero-valued pixels divide by zero and are then clipped below.
            depth_map = stereo_map / 256.0
            depth_map = 0.5 / depth_map
            depth_map = 500 * depth_map

            # Clip, then quantise into bins of width MAX_DEPTH // len(scales).
            # NOTE(review): the top bin index can reach or exceed len(scales)
            # (and the bin width is 0 for more than MAX_DEPTH scales) -- verify.
            depth_map = np.clip(depth_map, 0, MAX_DEPTH)
            depth_map = depth_map // (MAX_DEPTH // len(scales))

            for prob, scale in zip(probs, scales):
                scale_index = self._locate_scale_index(scale, scales)
                # Weight decays as POWER_BASE ** |bin - scale_index|.
                weight_map = np.abs(depth_map - scale_index)
                weight_map = np.power(POWER_BASE, weight_map)
                # cv2.resize takes (width, height).
                weight_map = cv2.resize(weight_map, (w, h))
                # The leading axis added here broadcasts over the class axis.
                full_probs[index, :, :, :] += torch.from_numpy(
                    np.expand_dims(weight_map, axis=0)).type(
                        torch.cuda.FloatTensor) * prob[index, :, :, :]

        return full_probs

    @staticmethod
    def _locate_scale_index(scale, scales):
        for idx, s in enumerate(scales):
            if scale == s:
                return idx
        return 0

    def ms_test_wo_flip(self, inputs):
        if isinstance(inputs, torch.Tensor):
            n, c, h, w = inputs.size(0), inputs.size(1), inputs.size(
                2), inputs.size(3)
            full_probs = torch.cuda.FloatTensor(
                n, self.configer.get('data', 'num_classes'), h, w).fill_(0)
            for scale in self.configer.get('test', 'scale_search'):
                probs = self.ss_test(inputs, scale)
                full_probs += probs
            return full_probs
        elif isinstance(inputs, collections.Sequence):
            device_ids = self.configer.get('gpu')
            full_probs = [
                torch.zeros(1, self.configer.get('data', 'num_classes'),
                            i.size(1), i.size(2)).cuda(device_ids[index],
                                                       non_blocking=True)
                for index, i, in enumerate(inputs)
            ]
            for scale in self.configer.get('test', 'scale_search'):
                probs = self.ss_test(inputs, scale)
                for i in range(len(inputs)):
                    full_probs[i] += probs[i]
            return full_probs
        else:
            raise RuntimeError("Unsupport data type: {}".format(type(inputs)))

    def mscrop_test(self, inputs, crop_size):
        '''
        Flip-augmented multi-scale testing: scales below 1 are predicted on
        the whole image, all other scales crop-wise via sscrop_test.

        Currently, mscrop_test does not support diverse_size testing
        '''
        batch, height, width = inputs.size(0), inputs.size(2), inputs.size(3)
        accumulated = torch.cuda.FloatTensor(
            batch, self.configer.get('data', 'num_classes'), height,
            width).fill_(0)

        for scale in self.configer.get('test', 'scale_search'):
            Log.info('Scale {0:.2f} prediction'.format(scale))

            def predict(batch_inputs, scale=scale):
                # Whole-image inference when shrinking, sliding crops otherwise.
                if scale < 1:
                    return self.ss_test(batch_inputs, scale)
                return self.sscrop_test(batch_inputs, crop_size, scale)

            straight = predict(inputs)
            mirrored = predict(self.flip(inputs, 3))
            # Flip the mirrored prediction back before merging.
            accumulated += straight + self.flip(mirrored, 3)

        return accumulated

    def _decide_intersection(self, total_length, crop_length):
        stride = crop_length
        times = (total_length - crop_length) // stride + 1
        cropped_starting = []
        for i in range(times):
            cropped_starting.append(stride * i)
        if total_length - cropped_starting[-1] > crop_length:
            cropped_starting.append(total_length -
                                    crop_length)  # must cover the total image
        return cropped_starting

    def dense_crf_process(self, images, outputs):
        '''
        Refine network outputs with a dense CRF, one image at a time.

        ``outputs`` is soft-maxed over dim 1; every sample is then run through
        a DenseCRF2D with one Gaussian and one bilateral pairwise term and the
        inference result is written back into the soft-maxed tensor, which is
        returned.

        Reference: https://github.com/kazuto1011/deeplab-pytorch/blob/master/libs/utils/crf.py
        '''
        # hyperparameters of the dense crf
        # NOTE(review): the numbers below look like a log of tuning runs
        # (setting, score) -- metric and dataset are not stated; confirm.
        # baseline = 79.5
        # bi_xy_std = 67, 79.1
        # bi_xy_std = 20, 79.6
        # bi_xy_std = 10, 79.7
        # bi_xy_std = 10, iter_max = 20, v4 79.7
        # bi_xy_std = 10, iter_max = 5, v5 79.7
        # bi_xy_std = 5, v3 79.7
        iter_max = 10
        pos_w = 3
        pos_xy_std = 1
        bi_w = 4
        bi_xy_std = 10
        bi_rgb_std = 3

        b = images.size(0)
        # Three per-channel values reshaped to (3, 1, 1) so they broadcast over
        # a channel-first image (np.transpose is a no-op on a 1-D array).
        # NOTE(review): presumably the mean subtracted during preprocessing --
        # TODO confirm against the data transform.
        mean_vector = np.expand_dims(np.expand_dims(np.transpose(
            np.array([102.9801, 115.9465, 122.7717])),
                                                    axis=1),
                                     axis=2)
        outputs = F.softmax(outputs, dim=1)
        for i in range(b):
            unary = outputs[i].data.cpu().numpy()
            C, H, W = unary.shape
            unary = dcrf_utils.unary_from_softmax(unary)
            unary = np.ascontiguousarray(unary)

            # Add the mean back, cast to bytes and go channel-last for the
            # bilateral term.
            # NOTE(review): converting images[i] with numpy assumes it is a
            # CPU tensor -- verify against the caller.
            image = np.ascontiguousarray(images[i]) + mean_vector
            image = image.astype(np.ubyte)
            image = np.ascontiguousarray(image.transpose(1, 2, 0))

            # DenseCRF2D takes (width, height, n_labels).
            d = dcrf.DenseCRF2D(W, H, C)
            d.setUnaryEnergy(unary)
            d.addPairwiseGaussian(sxy=pos_xy_std, compat=pos_w)
            d.addPairwiseBilateral(sxy=bi_xy_std,
                                   srgb=bi_rgb_std,
                                   rgbim=image,
                                   compat=bi_w)
            out_crf = np.array(d.inference(iter_max))
            # Reshape the inference result to (C, H, W) and store it in place
            # of the soft-max scores for this sample.
            outputs[i] = torch.from_numpy(out_crf).cuda().view(C, H, W)

        return outputs