class AutoSKUMerger(object):
    """
      Iteratively trains an image classifier and merges the classes it confuses,
      rewriting the label file after every merge round.
    """
    def __init__(self, configer):
        """
          Keep the configer, reset the per-round state and relabel the dataset.

          Args:
              configer: project configuration object; `_relabel` reads and
                  updates its 'data' section.
        """
        self.configer = configer
        # Merge rounds are counted from 1.
        self.round = 1
        # Placeholders that `_init()` fills in at the start of every round.
        self.cls_net = self.train_loader = self.val_loader = None
        self.optimizer = self.scheduler = self.runner_state = None

        self._relabel()

    def _relabel(self):
        """
          Rewrite the label file so that class ids are contiguous integers.

          Reads 'data.label_path' (one "<image_path> <label>" entry per line),
          skips entries whose image does not exist under 'data.data_dir', maps
          every distinct label to an id in order of first appearance and writes
          the result to '<label_path>_new'. The configer is pointed at the new
          file, 'data.num_classes' is set to [number of distinct labels] and the
          current label file is copied to '<merge_dir>/ori_label.txt'.

          Blank or single-token lines are ignored. The process exits when the
          same image path is listed twice.
        """
        old_label_path = self.configer.get('data', 'label_path')
        new_label_path = '{}_new'.format(old_label_path)
        self.configer.update('data.label_path', new_label_path)
        data_dir = self.configer.get('data', 'data_dir')

        label_dict = dict()
        seen_paths = set()
        # Both files live in one `with` so the output is flushed and closed even
        # when the duplicate check below aborts the run.
        with open(old_label_path, 'r') as fr, open(new_label_path, 'w') as fw:
            for line in fr:
                line_items = line.strip().split()
                if len(line_items) < 2:
                    # Nothing to relabel on a blank or malformed line.
                    continue

                img_path, label = line_items[0], line_items[1]
                if not os.path.exists(os.path.join(data_dir, img_path)):
                    continue

                if label not in label_dict:
                    # The next free id always equals the number of labels seen.
                    label_dict[label] = len(label_dict)

                if img_path in seen_paths:
                    Log.error('Duplicate Error: {}'.format(img_path))
                    exit()

                seen_paths.add(img_path)
                fw.write('{} {}\n'.format(img_path, label_dict[label]))

        shutil.copy(self.configer.get('data', 'label_path'),
                    os.path.join(self.configer.get('data', 'merge_dir'), 'ori_label.txt'))
        self.configer.update('data.num_classes', [len(label_dict)])
        Log.info('Num Classes is {}...'.format(self.configer.get('data', 'num_classes')))

    def _init(self):
        """
          (Re)build meters, model, optimizer, data loaders and loss for one round.

          Called before every training run so that the network and loaders pick
          up the current 'data.num_classes' and label file from the configer.
        """
        # Timing / loss meters.
        self.batch_time, self.data_time = AverageMeter(), AverageMeter()
        self.train_losses, self.val_losses = AverageMeter(), AverageMeter()

        # Project helpers, all driven by the shared configer.
        self.cls_model_manager = ModelManager(self.configer)
        self.cls_data_loader = DataLoader(self.configer)
        self.cls_running_score = RunningScore(self.configer)

        # Fresh bookkeeping for this round.
        self.runner_state = {
            'iters': 0, 'last_iters': 0,
            'epoch': 0, 'last_epoch': 0,
            'performance': 0, 'val_loss': 0,
            'max_performance': 0, 'min_val_loss': 0,
        }

        self.cls_net = self.cls_model_manager.get_model()
        self.cls_net = RunnerHelper.load_net(self, self.cls_net)

        # The solver section is looked up by the name stored under train.solver.
        solver_name = self.configer.get('train', 'solver')
        self.solver_dict = self.configer.get(solver_name)
        param_groups = self._get_parameters()
        self.optimizer, self.scheduler = Trainer.init(param_groups, self.solver_dict)

        self.cls_net, self.optimizer = RunnerHelper.to_dtype(self, self.cls_net, self.optimizer)
        self.train_loader = self.cls_data_loader.get_trainloader()
        self.val_loader = self.cls_data_loader.get_valloader()
        self.loss = self.cls_model_manager.get_loss()

    def _get_parameters(self):
        lr_1 = []
        lr_2 = []
        params_dict = dict(self.cls_net.named_parameters())
        for key, value in params_dict.items():
            if value.requires_grad:
                if 'backbone' in key:
                    if self.configer.get('network', 'bb_lr_scale') == 0.0:
                        value.requires_grad = False
                    else:
                        lr_1.append(value)
                else:
                    lr_2.append(value)

        params = [{'params': lr_1, 'lr': self.solver_dict['lr']['base_lr']*self.configer.get('network', 'bb_lr_scale')},
                  {'params': lr_2, 'lr': self.solver_dict['lr']['base_lr']}]
        return params

    def _merge_class(self, cmatrix, fscore_list):
        """
          Merge classes that are confused with each other and rewrite the labels.

          A pair (i, j) is linked when cmatrix[i][j] divided by the sum of row i
          exceeds the round threshold max(min_thre, max_thre - round_decay * round).
          Linked classes are grouped transitively with a union-find structure and
          every group receives one new id. When 'merge.enable_fscore' is set, a
          class whose F-score relative to fscore_list[-1] is below
          fscore_ratio * round / max_round does not start a new id; unless it is
          absorbed through a group it is dropped from the label file.

          The label file at 'data.label_path' is rewritten in place, a copy is
          stored as '<merge_dir>/label_<round>.txt' and 'data.num_classes' is
          updated.

          Args:
              cmatrix: square confusion matrix (2-D array); `val` builds it with
                  confusion_matrix(y_true, y_pred).
              fscore_list: F-scores indexed by class id, with the average score
                  in the last position.

          Returns:
              Number of classes removed in this round (old count - new count).
        """
        Log.info('Merging class...')
        Log.info('Avg F1-score: {}'.format(fscore_list[-1]))
        threshold = max(self.configer.get('merge', 'min_thre'),
                        self.configer.get('merge', 'max_thre') - self.configer.get('merge', 'round_decay') * self.round)
        num_classes = self.configer.get('data', 'num_classes')[0]
        label_path = self.configer.get('data', 'label_path')

        # Step 1: collect every ordered pair whose confusion ratio is too high.
        h, w = cmatrix.shape[0], cmatrix.shape[1]
        per_class_num = np.sum(cmatrix, 1)
        pairs_list = list()
        pre_dict = dict()  # union-find parent pointers, one entry per involved class
        for i in range(h):
            if per_class_num[i] == 0:
                # An empty row cannot exceed the threshold; skipping it avoids 0/0.
                continue

            for j in range(w):
                if i != j and cmatrix[i][j] * 1.0 / per_class_num[i] > threshold:
                    pairs_list.append((i, j))
                    pre_dict[i] = i
                    pre_dict[j] = j

        def find_root(node):
            """Return the representative of `node`, compressing the path to it."""
            root = node
            while pre_dict[root] != root:
                root = pre_dict[root]

            while node != root:
                parent = pre_dict[node]
                pre_dict[node] = root
                node = parent

            return root

        # Step 2: union the two ends of every pair.
        for first, second in pairs_list:
            root_first = find_root(first)
            root_second = find_root(second)
            if root_first != root_second:
                pre_dict[root_first] = root_second

        # Step 3: gather the non-root members under their root.
        pairs_dict = dict()
        for node in pre_dict:
            root = find_root(node)
            if root != node:
                pairs_dict.setdefault(root, []).append(node)

        # Step 4: let every member of a group know all the other members.
        mutual_pairs_dict = {}
        for root, members in pairs_dict.items():
            group = [root] + members
            for node in group:
                mutual_pairs_dict[node] = [other for other in group if other != node]

        # Step 5: hand out new contiguous ids, one per group / surviving class.
        enable_fscore = self.configer.get('merge', 'enable_fscore')
        if enable_fscore:
            power = self.round / self.configer.get('merge', 'max_round')
            fscore_floor = self.configer.get('merge', 'fscore_ratio') * power

        id_map_list = [-1] * num_classes
        label_cnt = 0
        for i in range(num_classes):
            if id_map_list[i] != -1:
                continue

            if enable_fscore and fscore_list[i] / fscore_list[-1] < fscore_floor:
                continue

            id_map_list[i] = label_cnt
            for v in mutual_pairs_dict.get(i, ()):
                # Groups are disjoint, so a partner can never be mapped already.
                assert id_map_list[v] == -1
                id_map_list[v] = label_cnt

            label_cnt += 1

        # Step 6: rewrite the label file through a temporary per-round file.
        round_label_path = '{}_{}'.format(label_path, self.round)
        with open(label_path, 'r') as fr, open(round_label_path, 'w') as fw:
            for line in fr:
                path, label = line.strip().split()
                map_label = id_map_list[int(label)]
                if map_label == -1:
                    # The class was dropped in this round.
                    continue

                fw.write('{} {}\n'.format(path, map_label))

        shutil.move(round_label_path, label_path)
        shutil.copy(label_path,
                    os.path.join(self.configer.get('data', 'merge_dir'), 'label_{}.txt'.format(self.round)))
        self.configer.update('data.num_classes', [label_cnt])
        return num_classes - label_cnt

    def run(self):
        """
          Alternate training and class merging, then train on the final labels.

          Each round rebuilds the trainer, trains, validates and merges confused
          classes. The loop ends after 'merge.max_round' rounds, or earlier when
          a round merges fewer than 'merge.cnt_thre' classes or improves the
          accuracy by less than 'merge.acc_thre'. The resulting label file is
          copied to '<merge_dir>/merge_label.txt' and a final training run
          follows.
        """
        last_acc = 0.0
        while self.round <= self.configer.get('merge', 'max_round'):
            Log.info('Merge Round: {}'.format(self.round))
            Log.info('num classes: {}'.format(self.configer.get('data', 'num_classes')))
            self._init()
            self.train()
            acc, cmatrix, fscore_list = self.val(self.cls_data_loader.get_valloader())
            merge_cnt = self._merge_class(cmatrix, fscore_list)

            too_few_merged = merge_cnt < self.configer.get('merge', 'cnt_thre')
            if too_few_merged or (acc - last_acc) < self.configer.get('merge', 'acc_thre'):
                break

            last_acc = acc
            self.round += 1

        final_label_path = os.path.join(self.configer.get('data', 'merge_dir'), 'merge_label.txt')
        shutil.copy(self.configer.get('data', 'label_path'), final_label_path)
        # Final training pass on the merged label set.
        self._init()
        self.train()
        Log.info('num classes: {}'.format(self.configer.get('data', 'num_classes')))

    def train(self):
        """
          Train the network until solver_dict['max_iters'] iterations are done.

          For every batch: update the solver state through `Trainer.update`, run
          the forward pass, back-propagate loss_dict['loss'] and step the
          optimizer. Progress is logged every 'display_iter' iterations and
          `val()` runs every 'test_interval' iterations. With an iteration-based
          schedule (lr.metric == 'iters') training validates once more and stops
          exactly at 'max_iters'.
        """
        self.cls_net.train()
        start_time = time.time()
        while self.runner_state['iters'] < self.solver_dict['max_iters']:
            self.runner_state['epoch'] += 1
            for data_dict in self.train_loader:
                # Presumably adjusts the learning rate for the current step -- confirm in Trainer.
                Trainer.update(self, solver_dict=self.solver_dict)
                self.data_time.update(time.time() - start_time)
                # Forward pass.
                out = self.cls_net(data_dict)

                # Compute the loss of the train batch & backward.
                loss_dict = self.loss(out)
                loss = loss_dict['loss']
                self.train_losses.update(loss.item(), data_dict['img'].size(0))
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # Update the vars of the train phase.
                self.batch_time.update(time.time() - start_time)
                start_time = time.time()
                self.runner_state['iters'] += 1

                # Print the log info & reset the states.
                if self.runner_state['iters'] % self.solver_dict['display_iter'] == 0:
                    # Both averages now use '.3f'; the data-load field used to read '3f'.
                    Log.info('Train Epoch: {0}\tTrain Iteration: {1}\t'
                             'Time {batch_time.sum:.3f}s / {2}iters, ({batch_time.avg:.3f})\t'
                             'Data load {data_time.sum:.3f}s / {2}iters, ({data_time.avg:.3f})\n'
                             'Learning rate = {3}\tLoss = {loss.val:.8f} (ave = {loss.avg:.8f})\n'.format(
                                 self.runner_state['epoch'], self.runner_state['iters'],
                                 self.solver_dict['display_iter'],
                                 RunnerHelper.get_lr(self.optimizer), batch_time=self.batch_time,
                                 data_time=self.data_time, loss=self.train_losses))

                    self.batch_time.reset()
                    self.data_time.reset()
                    self.train_losses.reset()

                if self.solver_dict['lr']['metric'] == 'iters' and self.runner_state['iters'] == self.solver_dict['max_iters']:
                    self.val()
                    break

                # Check to val the current model.
                if self.runner_state['iters'] % self.solver_dict['test_interval'] == 0:
                    self.val()

    def val(self, loader=None):
        """
          Evaluate the network, save a checkpoint and return the metrics.

          Args:
              loader: data loader to evaluate on; `self.val_loader` when None.

          Returns:
              (acc, cmatrix, fscore_list): top-1 accuracy of output 'out0', the
              confusion matrix of true vs. predicted labels, and the numbers
              parsed from the second-to-last column of every non-empty row of
              the classification report (header rows excluded).
        """
        self.cls_net.eval()
        start_time = time.time()

        if loader is None:
            loader = self.val_loader

        list_y_true, list_y_pred = [], []
        with torch.no_grad():
            for data_dict in loader:
                out = self.cls_net(data_dict)
                loss_dict = self.loss(out)
                out_dict, label_dict, _ = RunnerHelper.gather(self, out)
                # Accumulate accuracy and the raw labels / predictions.
                self.cls_running_score.update(out_dict, label_dict)
                list_y_true += label_dict['out0'].view(-1).cpu().numpy().tolist()
                list_y_pred += out_dict['out0'].max(1)[1].view(-1).cpu().numpy().tolist()

                self.val_losses.update(loss_dict['loss'].mean().item(), data_dict['img'].size(0))

                # Update the vars of the val phase.
                self.batch_time.update(time.time() - start_time)
                start_time = time.time()

            acc = self.cls_running_score.top1_acc.avg['out0']
            RunnerHelper.save_net(self, self.cls_net, performance=acc)
            self.runner_state['performance'] = acc

            # Print the log info & reset the states.
            Log.info('Test Time {batch_time.sum:.3f}s'.format(batch_time=self.batch_time))
            Log.info('Test Set: {} images'.format(len(list_y_true)))
            Log.info('TestLoss = {loss.avg:.8f}'.format(loss=self.val_losses))
            Log.info('Top1 ACC = {}'.format(acc))

            cmatrix = confusion_matrix(list_y_true, list_y_pred)
            fscore_str = classification_report(list_y_true, list_y_pred, digits=5)
            # Skip the two header lines, then take the second-to-last field of each row.
            fscore_list = []
            for report_line in fscore_str.split('\n')[2:]:
                fields = report_line.strip().split()
                if len(fields) > 0:
                    fscore_list.append(float(fields[-2]))

            self.batch_time.reset()
            self.val_losses.reset()
            self.cls_running_score.reset()
            self.cls_net.train()
            return acc, cmatrix, fscore_list
# Example #2
# 0
class ImageTester(object):
    """
      Inference helper for semantic segmentation: loads the network, predicts
      label maps (single-scale, multi-scale, cropped) and renders the results.
    """
    def __init__(self, configer):
        """
          Build the project helpers from the configer and load the network.

          Args:
              configer: project configuration object; 'test.out_dir' becomes
                  `self.save_dir`.
        """
        self.configer = configer
        self.batch_time, self.data_time = AverageMeter(), AverageMeter()
        self.seg_visualizer = SegVisualizer(configer)
        self.loss_manager = LossManager(configer)
        self.module_runner = ModuleRunner(configer)
        self.model_manager = ModelManager(configer)
        self.optim_scheduler = OptimScheduler(configer)
        self.seg_data_loader = DataLoader(configer)
        self.save_dir = configer.get('test', 'out_dir')
        # Set by `_init_model` (network) or left unset (loader / size).
        self.seg_net = self.test_loader = self.test_size = None
        self.infer_time = self.infer_cnt = 0
        self._init_model()

    def _init_model(self):
        """
          Create the segmentation network, load its weights and switch to eval.

          The data-loader setup that used to live here is disabled, so
          `self.test_loader` and `self.test_size` keep the value assigned in
          `__init__`.
        """
        network = self.model_manager.semantic_segmentor()
        self.seg_net = self.module_runner.load_net(network)

        print("self.save_dir {}".format(self.save_dir))

        self.seg_net.eval()

    def __relabel(self, label_map):
        height, width = label_map.shape
        label_dst = np.zeros((height, width), dtype=np.uint8)
        for i in range(self.configer.get('data', 'num_classes')):
            label_dst[label_map == i] = self.configer.get('data', 'label_list')[i]

        label_dst = np.array(label_dst, dtype=np.uint8)

        return label_dst

    def test(self, img_path=None, output_dir=None, data_loader=None):
        """
          Segment a single image and save the predicted label map as a PNG.

          Args:
              img_path: image to segment; the configer entry 'input_image' is
                  used when None.
              output_dir: directory for 'label_<name>.png'; `self.save_dir` is
                  used when None.
              data_loader: unused; kept so existing callers keep working.

          Raises:
              RuntimeError: if the image cannot be read.
        """
        print("test!!!")
        self.seg_net.eval()
        start_time = time.time()

        Log.info('save dir {}'.format(self.save_dir))
        FileHelper.make_dirs(self.save_dir, is_file=False)

        # Reader.
        input_path = img_path if img_path is not None else self.configer.get('input_image')
        input_image = cv2.imread(input_path)
        if input_image is None:
            # cv2.imread reports a missing or undecodable file by returning None.
            raise RuntimeError('Cannot read image: {}'.format(input_path))

        transform = trans.Compose([
            trans.ToTensor(),
            trans.Normalize(div_value=self.configer.get('normalize', 'div_value'),
                            mean=self.configer.get('normalize', 'mean'),
                            std=self.configer.get('normalize', 'std')), ])

        aug_val_transform = cv2_aug_transforms.CV2AugCompose(self.configer, split='val')

        # Remember the original size: the prediction is resized back to it.
        h, w, _ = input_image.shape
        ori_img_size = (w, h)

        # The augmentation pipeline returns a sequence; the image is its first item.
        input_image = aug_val_transform(input_image)[0]
        input_image = transform(input_image)

        with torch.no_grad():
            # Forward pass.
            outputs = self.ss_test([input_image])

            if isinstance(outputs, torch.Tensor):
                outputs = outputs.permute(0, 2, 3, 1).cpu().numpy()
            else:
                outputs = [output.permute(0, 2, 3, 1).cpu().numpy().squeeze() for output in outputs]

            logits = cv2.resize(outputs[0], ori_img_size, interpolation=cv2.INTER_CUBIC)
            label_img = np.asarray(np.argmax(logits, axis=-1), dtype=np.uint8)
            if self.configer.exists('data', 'reduce_zero_label') and self.configer.get('data', 'reduce_zero_label'):
                label_img = label_img + 1
                label_img = label_img.astype(np.uint8)
            if self.configer.exists('data', 'label_list'):
                label_img_ = self.__relabel(label_img)
            else:
                label_img_ = label_img
            label_img_ = Image.fromarray(label_img_, 'P')

            # splitext keeps the full stem, also for names without an extension.
            input_name = os.path.splitext(os.path.basename(input_path))[0]
            target_dir = self.save_dir if output_dir is None else output_dir
            label_path = os.path.join(target_dir, 'label_{}.png'.format(input_name))
            FileHelper.make_dirs(label_path, is_file=True)
            ImageHelper.save(label_img_, label_path)

        self.batch_time.update(time.time() - start_time)

        # Print the log info & reset the states.
        Log.info('Test Time {batch_time.sum:.3f}s'.format(batch_time=self.batch_time))


    def offset_test(self, inputs, offset_h_maps, offset_w_maps, scale=1):
        if isinstance(inputs, torch.Tensor):
            n, c, h, w = inputs.size(0), inputs.size(1), inputs.size(2), inputs.size(3)
            start = timeit.default_timer()
            outputs = self.seg_net.forward(inputs, offset_h_maps, offset_w_maps)
            torch.cuda.synchronize()
            end = timeit.default_timer()

            if (self.configer.get('loss', 'loss_type') == "fs_auxce_loss") or (self.configer.get('loss', 'loss_type') == "triple_auxce_loss"):
                outputs = outputs[-1]
            elif self.configer.get('loss', 'loss_type') == "pyramid_auxce_loss":
                outputs = outputs[1] + outputs[2] + outputs[3] + outputs[4]

            outputs = F.interpolate(outputs, size=(h, w), mode='bilinear', align_corners=True)
            return outputs
        else:
            raise RuntimeError("Unsupport data type: {}".format(type(inputs)))


    def ss_test(self, inputs, scale=1):
        if isinstance(inputs, torch.Tensor):
            n, c, h, w = inputs.size(0), inputs.size(1), inputs.size(2), inputs.size(3)
            scaled_inputs = F.interpolate(inputs, size=(int(h*scale), int(w*scale)), mode="bilinear", align_corners=True)
            start = timeit.default_timer()
            outputs = self.seg_net.forward(scaled_inputs)
            torch.cuda.synchronize()
            end = timeit.default_timer()
            outputs = outputs[-1]
            outputs = F.interpolate(outputs, size=(h, w), mode='bilinear', align_corners=True)
            return outputs
        elif isinstance(inputs, collections.Sequence):
            device_ids = self.configer.get('gpu')
            replicas = nn.parallel.replicate(self.seg_net.module, device_ids)
            scaled_inputs, ori_size, outputs = [], [], []
            for i, d in zip(inputs, device_ids):
                h, w = i.size(1), i.size(2)
                ori_size.append((h, w))
                i = F.interpolate(i.unsqueeze(0), size=(int(h*scale), int(w*scale)), mode="bilinear", align_corners=True)
                scaled_inputs.append(i.cuda(d, non_blocking=True))
            scaled_outputs = nn.parallel.parallel_apply(replicas[:len(scaled_inputs)], scaled_inputs)
            for i, output in enumerate(scaled_outputs):
                outputs.append(F.interpolate(output[-1], size=ori_size[i], mode='bilinear', align_corners=True))
            return outputs
        else:
            raise RuntimeError("Unsupport data type: {}".format(type(inputs)))


    def flip(self, x, dim):
        indices = [slice(None)] * x.dim()
        indices[dim] = torch.arange(x.size(dim) - 1, -1, -1,
                                    dtype=torch.long, device=x.device)
        return x[tuple(indices)]


    def sscrop_test(self, inputs, crop_size, scale=1):
        '''
        Single-scale prediction assembled from crops.

        The input is rescaled by `scale`, cut into crops of `crop_size`
        (height, width), each crop is predicted with `ss_test`, and overlapping
        predictions are averaged before resizing back to the original size.
        Currently, sscrop_test does not support diverse_size testing
        '''
        ori_h, ori_w = inputs.size(2), inputs.size(3)
        scaled_inputs = F.interpolate(inputs, size=(int(ori_h*scale), int(ori_w*scale)), mode="bilinear", align_corners=True)
        n, h, w = scaled_inputs.size(0), scaled_inputs.size(2), scaled_inputs.size(3)
        full_probs = torch.cuda.FloatTensor(n, self.configer.get('data', 'num_classes'), h, w).fill_(0)
        count_predictions = torch.cuda.FloatTensor(n, self.configer.get('data', 'num_classes'), h, w).fill_(0)

        crop_h, crop_w = crop_size[0], crop_size[1]
        height_starts = self._decide_intersection(h, crop_h)
        width_starts = self._decide_intersection(w, crop_w)

        crop_counter = 0
        for top in height_starts:
            rows = slice(top, top + crop_h)
            for left in width_starts:
                cols = slice(left, left + crop_w)
                prediction = self.ss_test(scaled_inputs[:, :, rows, cols])
                # Track how often each pixel was predicted so overlaps average out.
                count_predictions[:, :, rows, cols] += 1
                full_probs[:, :, rows, cols] += prediction
                crop_counter += 1
                Log.info('predicting {:d}-th crop'.format(crop_counter))

        full_probs /= count_predictions
        return F.interpolate(full_probs, size=(ori_h, ori_w), mode='bilinear', align_corners=True)


    def ms_test(self, inputs):
        if isinstance(inputs, torch.Tensor):  
            n, c, h, w = inputs.size(0), inputs.size(1), inputs.size(2), inputs.size(3)
            full_probs = torch.cuda.FloatTensor(n, self.configer.get('data', 'num_classes'), h, w).fill_(0)
            if self.configer.exists('test', 'scale_weights'):
                for scale, weight in zip(self.configer.get('test', 'scale_search'), self.configer.get('test', 'scale_weights')):
                    probs = self.ss_test(inputs, scale)
                    flip_probs = self.ss_test(self.flip(inputs, 3), scale)
                    probs = probs + self.flip(flip_probs, 3)
                    full_probs += weight * probs
                return full_probs
            else:
                for scale in self.configer.get('test', 'scale_search'):
                    probs = self.ss_test(inputs, scale)
                    flip_probs = self.ss_test(self.flip(inputs, 3), scale)
                    probs = probs + self.flip(flip_probs, 3)
                    full_probs += probs
                return full_probs

        elif isinstance(inputs, collections.Sequence):
            device_ids = self.configer.get('gpu')
            full_probs = [torch.zeros(1, self.configer.get('data', 'num_classes'), 
                i.size(1), i.size(2)).cuda(device_ids[index], non_blocking=True)
                for index, i in enumerate(inputs)]
            flip_inputs = [self.flip(i, 2) for i in inputs]

            if self.configer.exists('test', 'scale_weights'):
                for scale, weight in zip(self.configer.get('test', 'scale_search'), self.configer.get('test', 'scale_weights')):
                    probs = self.ss_test(inputs, scale)
                    flip_probs = self.ss_test(flip_inputs, scale)
                    for i in range(len(inputs)):
                        full_probs[i] += weight * (probs[i] + self.flip(flip_probs[i], 3))
                return full_probs
            else:
                for scale in self.configer.get('test', 'scale_search'):
                    probs = self.ss_test(inputs, scale)
                    flip_probs = self.ss_test(flip_inputs, scale)
                    for i in range(len(inputs)):
                        full_probs[i] += (probs[i] + self.flip(flip_probs[i], 3))
                return full_probs

        else:
            raise RuntimeError("Unsupport data type: {}".format(type(inputs)))


    def ms_test_depth(self, inputs, names):
        prob_list = []
        scale_list = []

        if isinstance(inputs, torch.Tensor):  
            n, c, h, w = inputs.size(0), inputs.size(1), inputs.size(2), inputs.size(3)
            full_probs = torch.cuda.FloatTensor(n, self.configer.get('data', 'num_classes'), h, w).fill_(0)

            for scale in self.configer.get('test', 'scale_search'):
                probs = self.ss_test(inputs, scale)
                flip_probs = self.ss_test(self.flip(inputs, 3), scale)
                probs = probs + self.flip(flip_probs, 3)
                prob_list.append(probs)
                scale_list.append(scale)

            full_probs = self.fuse_with_depth(prob_list, scale_list, names)
            return full_probs

        else:
            raise RuntimeError("Unsupport data type: {}".format(type(inputs)))


    def fuse_with_depth(self, probs, scales, names):
        """
          Combine per-scale scores with weights derived from stereo maps.

          For every image a map is read from a fixed stereo directory, converted
          with 500 * (0.5 / (value / 256)), clipped to [0, MAX_DEPTH] and binned
          into len(scales) levels. Each scale's scores are weighted per pixel by
          POWER_BASE ** |bin - scale index| and summed.

          Args:
              probs: list of 4-D score tensors, one per scale.
              scales: the scales matching `probs`.
              names: one file stem per batch entry ('<name>.png' is read).

          Returns:
              CUDA float tensor with the fused scores.
        """
        MAX_DEPTH = 63
        POWER_BASE = 0.8
        # The directory is picked from the output dir name (hard-coded dataset paths).
        stereo_path = ("/msravcshare/dataset/cityscapes/stereo/test/"
                       if 'test' in self.save_dir
                       else "/msravcshare/dataset/cityscapes/stereo/val/")

        first = probs[0]
        n, h, w = first.size(0), first.size(2), first.size(3)
        full_probs = torch.cuda.FloatTensor(n, self.configer.get('data', 'num_classes'), h, w).fill_(0)

        for index, name in enumerate(names):
            stereo_map = cv2.imread(stereo_path + name + '.png', -1)
            depth_map = 500 * (0.5 / (stereo_map / 256.0))
            depth_map = np.clip(depth_map, 0, MAX_DEPTH)
            depth_map = depth_map // (MAX_DEPTH // len(scales))

            for prob, scale in zip(probs, scales):
                scale_index = self._locate_scale_index(scale, scales)
                # The further a pixel's bin is from this scale, the smaller the weight.
                weight_map = np.power(POWER_BASE, np.abs(depth_map - scale_index))
                weight_map = cv2.resize(weight_map, (w, h))
                weight_tensor = torch.from_numpy(np.expand_dims(weight_map, axis=0)).type(torch.cuda.FloatTensor)
                full_probs[index, :, :, :] += weight_tensor * prob[index, :, :, :]

        return full_probs

    @staticmethod
    def _locate_scale_index(scale, scales):
        for idx, s in enumerate(scales):
            if scale == s:
                return idx
        return 0


    def ms_test_wo_flip(self, inputs):
        if isinstance(inputs, torch.Tensor):  
            n, c, h, w = inputs.size(0), inputs.size(1), inputs.size(2), inputs.size(3)
            full_probs = torch.cuda.FloatTensor(n, self.configer.get('data', 'num_classes'), h, w).fill_(0)
            for scale in self.configer.get('test', 'scale_search'):
                probs = self.ss_test(inputs, scale)
                full_probs += probs
            return full_probs
        elif isinstance(inputs, collections.Sequence):
            device_ids = self.configer.get('gpu')
            full_probs = [torch.zeros(1, self.configer.get('data', 'num_classes'), 
                i.size(1), i.size(2)).cuda(device_ids[index], non_blocking=True)
                for index, i, in enumerate(inputs)]
            for scale in self.configer.get('test', 'scale_search'):
                probs = self.ss_test(inputs, scale)
                for i in range(len(inputs)):
                    full_probs[i] += probs[i]
            return full_probs
        else:
            raise RuntimeError("Unsupport data type: {}".format(type(inputs)))


    def mscrop_test(self, inputs, crop_size):
        '''
        Multi-scale, flip-augmented prediction that crops at scales >= 1.

        Scales below 1 are predicted on the whole image with `ss_test`; all
        other scales go through `sscrop_test` with `crop_size`.
        Currently, mscrop_test does not support diverse_size testing
        '''
        n, h, w = inputs.size(0), inputs.size(2), inputs.size(3)
        full_probs = torch.cuda.FloatTensor(n, self.configer.get('data', 'num_classes'), h, w).fill_(0)
        for scale in self.configer.get('test', 'scale_search'):
            Log.info('Scale {0:.2f} prediction'.format(scale))
            # Pick the predictor once; both take (inputs, [crop_size,] scale).
            if scale < 1:
                predictor, extra_args = self.ss_test, ()
            else:
                predictor, extra_args = self.sscrop_test, (crop_size,)

            probs = predictor(inputs, *extra_args, scale)
            flip_probs = predictor(self.flip(inputs, 3), *extra_args, scale)
            full_probs += probs + self.flip(flip_probs, 3)
        return full_probs


    def _decide_intersection(self, total_length, crop_length):
        stride = crop_length
        times = (total_length - crop_length) // stride + 1
        cropped_starting = []
        for i in range(times):
            cropped_starting.append(stride*i)
        if total_length - cropped_starting[-1] > crop_length:
            cropped_starting.append(total_length - crop_length)  # must cover the total image
        return cropped_starting


    def dense_crf_process(self, images, outputs):
        '''
        Refine network scores with a dense CRF, one image at a time.

        `outputs` is soft-maxed over dim 1 and every entry of the batch is
        replaced by the CRF inference result. The fixed mean vector is added to
        each image before it is converted to bytes for the bilateral term.
        Reference: https://github.com/kazuto1011/deeplab-pytorch/blob/master/libs/utils/crf.py
        '''
        # hyperparameters of the dense crf
        # baseline = 79.5
        # bi_xy_std = 67, 79.1
        # bi_xy_std = 20, 79.6
        # bi_xy_std = 10, 79.7
        # bi_xy_std = 10, iter_max = 20, v4 79.7
        # bi_xy_std = 10, iter_max = 5, v5 79.7
        # bi_xy_std = 5, v3 79.7
        iter_max = 10
        pos_w, pos_xy_std = 3, 1
        bi_w, bi_xy_std, bi_rgb_std = 4, 10, 3

        # Per-channel mean shaped (3, 1, 1) so it broadcasts over a CHW image.
        mean_vector = np.array([102.9801, 115.9465, 122.7717]).reshape(3, 1, 1)
        outputs = F.softmax(outputs, dim=1)
        for i in range(images.size(0)):
            unary = outputs[i].data.cpu().numpy()
            C, H, W = unary.shape
            unary = np.ascontiguousarray(dcrf_utils.unary_from_softmax(unary))

            image = (np.ascontiguousarray(images[i]) + mean_vector).astype(np.ubyte)
            image = np.ascontiguousarray(image.transpose(1, 2, 0))

            crf = dcrf.DenseCRF2D(W, H, C)
            crf.setUnaryEnergy(unary)
            crf.addPairwiseGaussian(sxy=pos_xy_std, compat=pos_w)
            crf.addPairwiseBilateral(sxy=bi_xy_std, srgb=bi_rgb_std, rgbim=image, compat=bi_w)
            out_crf = np.array(crf.inference(iter_max))
            outputs[i] = torch.from_numpy(out_crf).cuda().view(C, H, W)

        return outputs
    

    def visualize(self, label_img):
        """
          Colorize a single-channel label image with the category colors.

          Pixels equal to 14 are reset to 0 first; every pixel whose id matches
          an entry of HYUNDAI_POC_CATEGORIES then receives that entry's color.

          Returns:
              The colored image after an RGB -> BGR conversion.
        """
        ignored_id = 14
        canvas = cv2.cvtColor(label_img.copy(), cv2.COLOR_GRAY2RGB)
        canvas[canvas == ignored_id] = 0

        colored = canvas.copy()
        # All three channels hold the label id, so the first one is enough.
        id_channel = canvas[:, :, 0]
        for label in HYUNDAI_POC_CATEGORIES:
            colored[:, :, :3][id_channel == label['id']] = label['color']

        return cv2.cvtColor(colored, cv2.COLOR_RGB2BGR)
    
    def get_ratio_all(self, anno_img):
        """
          Compute the share of pixels per category, in percent.

          Returns:
              List of [name, percentage] in the order of HYUNDAI_POC_CATEGORIES;
              every percentage is 0 when no category pixel was counted.
        """
        lines = [[label['name'], self.get_ratio(anno_img.copy(), label)]
                 for label in HYUNDAI_POC_CATEGORIES]
        total_size = sum(entry[1] for entry in lines)
        for entry in lines:
            entry[1] = entry[1] / total_size * 100 if total_size else 0
        return lines
    
    def get_ratio(self, anno_img, label):
        total = 0
        label_id = label['id']
        if label_id == 14:
            return total
        label_img = (anno_img == label_id).astype(np.uint8)
        # label_img = cv2.cvtColor(label_img, cv2.COLOR_BGR2GRAY)
        total = np.count_nonzero(label_img)
        return total

    def visualize_ratio(self, ratios, video_size, ratio_w):
        """
          Render the category ratios as a two-column table image.

          The table has 14 rows: a header followed by one row per entry of
          `ratios`. The left half of each body row is filled with the color of
          the category whose id equals the row number; names are written in the
          left column and values in the right one.

          Args:
              ratios: list of [name, value]; string values are drawn as is,
                  numbers with two decimals. Entries beyond the 13 body rows
                  are ignored; missing entries leave their rows empty.
              video_size: sequence whose first item is the image height.
              ratio_w: image width in pixels.

          Returns:
              The table image after an RGB -> BGR conversion.
        """
        row_count = 14
        text_margin_h = 20
        text_margin_w = 10
        img_w = ratio_w
        img_h = int(video_size[0])

        ratio_list = ratios.copy()
        ratio_list.insert(0, ['등급','비율'])

        ratio_img = np.full((img_h, img_w, 3), 255, np.uint8)
        row_h = img_h / row_count
        center_x = int(img_w / 2)

        # Colored swatches in the left column, one per category id.
        for i in range(1, row_count):
            top, bottom = int(i * row_h), int((i + 1) * row_h)
            for label in HYUNDAI_POC_CATEGORIES:
                if label['id'] == i:
                    cv2.rectangle(ratio_img, (0, top), (center_x, bottom), label['color'], cv2.FILLED)

        # Grid lines are drawn after the swatches so they stay visible.
        for i in range(1, row_count):
            line_y = int(i * row_h)
            cv2.line(ratio_img, (0, line_y), (img_w, line_y), (0,0,0))

        cv2.line(ratio_img, (center_x, 0), (center_x, img_h), (0,0,0))

        # Convert to PIL once and load the font once for all rows.
        pil_img = Image.fromarray(ratio_img)
        font = ImageFont.truetype("NanumGothic.ttf", 15)
        draw = ImageDraw.Draw(pil_img)
        color = (0, 0, 0)
        text_x = center_x + text_margin_w
        for i, row in enumerate(ratio_list[:row_count]):
            text_y = int(i * row_h) + text_margin_h
            draw.text((0, text_y), row[0], font=font, fill=color)
            value = row[1]
            value_text = value if isinstance(value, str) else "{:.02f}".format(value)
            draw.text((text_x, text_y), value_text, font=font, fill=color)

        ratio_img = np.array(pil_img)
        return cv2.cvtColor(ratio_img, cv2.COLOR_RGB2BGR)
class Tester(object):
    def __init__(self, configer):
        """Adjust the evaluation config, build the helper objects and load the net.

        The constructor rewrites several config entries in place: the
        validation transform sequence loses every entry containing 'random',
        the 'val' transformer takes the 'test' input size, the 'train'
        transformer is replaced by the 'val' one, and both 'val' and 'test'
        batch sizes come from the ``batch_size`` environment variable
        (default 16).
        """
        self.crop_size = configer.get('train', 'data_transformer')['input_size']

        deterministic_trans = []
        for trans_name in configer.get('val_trans', 'trans_seq'):
            if 'random' not in trans_name:
                deterministic_trans.append(trans_name)
        configer.update(('val_trans', 'trans_seq'), deterministic_trans)

        test_input_size = configer.get('test', 'data_transformer').get(
            'input_size', None)
        configer.get('val', 'data_transformer')['input_size'] = test_input_size
        configer.update(('train', 'data_transformer'),
                        configer.get('val', 'data_transformer'))

        for split in ('val', 'test'):
            configer.update((split, 'batch_size'),
                            int(os.environ.get('batch_size', 16)))

        self.save_dir = configer.get('test', 'out_dir')
        self.dataset_name = configer.get('test', 'eval_set')
        self.sscrop = configer.get('test', 'sscrop')

        self.configer = configer
        self.batch_time, self.data_time = AverageMeter(), AverageMeter()
        self.loss_manager = LossManager(configer)
        self.module_runner = ModuleRunner(configer)
        self.model_manager = ModelManager(configer)
        self.seg_data_loader = DataLoader(configer)

        # Filled in by _init_model().
        self.seg_net = None
        self.test_loader = None
        self.test_size = None
        self.infer_time = 0
        self.infer_cnt = 0
        self._init_model()

        pprint.pprint(configer.params_root)

    def _init_model(self):
        """Build and load the network, then pick the loader for ``self.dataset_name``."""
        net = self.model_manager.semantic_segmentor()
        self.seg_net = self.module_runner.load_net(net)

        split = self.dataset_name
        assert split in ('train', 'val', 'test'), 'Cannot infer dataset name'

        transformer_cfg = self.configer.get(split, 'data_transformer')
        self.size_mode = transformer_cfg['size_mode']

        # Only the 'test' split goes through the test loader.
        if split == 'test':
            self.test_loader = self.seg_data_loader.get_testloader(split)
        else:
            self.test_loader = self.seg_data_loader.get_valloader(split)

        num_batches = len(self.test_loader)
        self.test_size = num_batches * self.configer.get('val', 'batch_size')

    def test(self, data_loader=None):
        """
          Predict an offset map for every image of ``self.test_loader`` and
          save it as ``<name>.mat`` (key ``'mat'``) inside ``self.save_dir``.

          ``data_loader`` is accepted but not used.
        """
        self.seg_net.eval()
        tick = time.time()
        image_id = 0

        Log.info('save dir {}'.format(self.save_dir))
        FileHelper.make_dirs(self.save_dir, is_file=False)

        print('Total batches', len(self.test_loader))
        for batch in self.test_loader:
            inputs = [batch['img']]
            names = batch['name']
            metas = batch['meta']

            dest_dir = self.save_dir

            with torch.no_grad():
                offsets, logits = self.extract_offset(inputs)
                print([x.shape for x in logits])
                for k in range(len(inputs[0])):
                    image_id += 1
                    meta = metas[k]
                    ori_img_size = meta['ori_img_size']
                    border_size = meta['border_size']

                    # Crop to the border size, then resize to the original
                    # image size without interpolating between offsets.
                    raw_offset = offsets[k].squeeze().cpu().numpy()
                    cropped = raw_offset[:border_size[1], :border_size[0]]
                    offset = cv2.resize(cropped,
                                        tuple(ori_img_size),
                                        interpolation=cv2.INTER_NEAREST)
                    print(image_id)

                    os.makedirs(dest_dir, exist_ok=True)

                    # Replace the extension (text after the last '.') by '.mat';
                    # names without a usable stem just get '.mat' appended.
                    stem = names[k].rpartition('.')[0] or names[k]
                    dest_name = os.path.join(dest_dir, stem + '.mat')
                    print('Shape:', offset.shape, 'Saving to', dest_name)

                    payload = {'mat': offset}

                    scipy.io.savemat(dest_name, payload, do_compression=True)
                    # Read the file back; if that fails, rewrite it
                    # uncompressed.
                    try:
                        scipy.io.loadmat(dest_name)
                    except Exception as e:
                        print(e)
                        scipy.io.savemat(dest_name,
                                         payload,
                                         do_compression=False)

            self.batch_time.update(time.time() - tick)
            tick = time.time()

        Log.info('Test Time {batch_time.sum:.3f}s'.format(
            batch_time=self.batch_time))

    def extract_offset(self, inputs):
        if self.sscrop:
            outputs = self.sscrop_test(inputs, self.crop_size)
        elif self.configer.get('test', 'mode') == 'ss_test':
            outputs = self.ss_test(inputs)

        offsets = []
        logits = []

        for mask_logits, dir_logits, img in zip(*outputs[:2], inputs[0]):
            h, w = img.shape[1:]

            mask_logits = F.interpolate(mask_logits.unsqueeze(0),
                                        size=(h, w),
                                        mode='bilinear',
                                        align_corners=True)
            dir_logits = F.interpolate(dir_logits.unsqueeze(0),
                                       size=(h, w),
                                       mode='bilinear',
                                       align_corners=True)

            logit = torch.softmax(dir_logits, dim=1)
            zero_mask = mask_logits.argmax(dim=1, keepdim=True) == 0
            logits.append(mask_logits[:, 1])

            offset = self._get_offset(mask_logits, dir_logits)
            offsets.append(offset)
        print([x.shape for x in offsets])
        return offsets, logits

    def _get_offset(self, mask_logits, dir_logits):
        """Convert direction scores into an offset field, zeroed outside the mask.

        Positions where channel 1 of ``mask_logits`` is not above 0.5 get a
        zero offset. The arg-max direction label of every position is mapped
        to a vector by ``DTOffsetHelper.label_to_vector`` and the result is
        permuted with ``(0, 2, 3, 1)``.
        """
        edge_mask = mask_logits[:, 1] > 0.5

        # Softmax is kept from the original pipeline before the arg-max.
        dir_probs = torch.softmax(dir_logits, dim=1)
        dir_label = torch.argmax(dir_probs, dim=1).float()

        offset = DTOffsetHelper.label_to_vector(dir_label)
        offset = offset.permute(0, 2, 3, 1)
        offset[~edge_mask, :] = 0
        return offset

    def _flip(self, x, dim=-1):
        indices = [slice(None)] * x.dim()
        indices[dim] = torch.arange(x.size(dim) - 1,
                                    -1,
                                    -1,
                                    dtype=torch.long,
                                    device=x.device)
        return x[tuple(indices)]

    def _flip_offset(self, x):
        """Flip ``x`` along its last dimension and reorder its channels.

        The channel order comes from ``DTOffsetHelper.flipping_indices()``; it
        is applied to dimension 1 for 4-D input and to dimension 0 otherwise.
        """
        flipped = self._flip(x, dim=-1)
        channel_order = DTOffsetHelper.flipping_indices()
        if len(flipped.shape) != 4:
            return flipped[channel_order]
        return flipped[:, channel_order]

    def _flip_inputs(self, inputs):

        if self.size_mode == 'fix_size':
            return [self._flip(x, -1) for x in inputs]
        else:
            return [[self._flip(x, -1) for x in xs] for xs in inputs]

    def _flip_outputs(self, outputs):
        funcs = [self._flip, self._flip_offset]
        if self.size_mode == 'fix_size':
            return [f(x) for f, x in zip(funcs, outputs)]
        else:
            return [[f(x) for x in xs] for f, xs in zip(funcs, outputs)]

    def _tuple_sum(self, tup1, tup2, tup2_weight=1):
        """
        tup1 / tup2: tuple of tensors or tuple of list of tensors
        """

        if tup1 is None:
            if self.size_mode == 'fix_size':
                return [y * tup2_weight for y in tup2]
            else:
                return [[y * tup2_weight for y in ys] for ys in tup2]
        else:
            if self.size_mode == 'fix_size':
                return [x + y * tup2_weight for x, y in zip(tup1, tup2)]
            else:
                return [[x + y * tup2_weight for x, y in zip(xs, ys)]
                        for xs, ys in zip(tup1, tup2)]

    def _scale_ss_inputs(self, inputs, scale):
        n, c, h, w = inputs[0].shape
        size = (int(h * scale), int(w * scale))
        return [
            F.interpolate(inputs[0],
                          size=size,
                          mode="bilinear",
                          align_corners=True),
        ], (h, w)

    def sscrop_test(self, inputs, crop_size, scale=1):
        '''
        Sliding-window inference: predict overlapping crops and average them.

        Two accumulators with 2 and 8 channels are allocated on the GPU, one
        per network output. ``scale`` is accepted but not used.
        Currently, sscrop_test does not support diverse_size testing
        '''
        scaled_inputs = inputs
        image = scaled_inputs[0]
        n, h, w = image.size(0), image.size(2), image.size(3)
        full_probs = [
            torch.cuda.FloatTensor(n, dim, h, w).fill_(0) for dim in (2, 8)
        ]
        count_predictions = [
            torch.cuda.FloatTensor(n, dim, h, w).fill_(0) for dim in (2, 8)
        ]
        crop_h, crop_w = crop_size[0], crop_size[1]

        crop_counter = 0

        height_starts = self._decide_intersection(h, crop_h)
        width_starts = self._decide_intersection(w, crop_w)

        for top in height_starts:
            rows = slice(top, top + crop_h)
            for left in width_starts:
                cols = slice(left, left + crop_w)
                crop_inputs = [x[..., rows, cols] for x in scaled_inputs]
                prediction = self.ss_test(crop_inputs)

                for j in range(2):
                    # Count how often each pixel was covered so the sum can
                    # be turned into a mean afterwards.
                    count_predictions[j][:, :, rows, cols] += 1
                    full_probs[j][:, :, rows, cols] += prediction[j]
                crop_counter += 1
                Log.info('predicting {:d}-th crop'.format(crop_counter))

        for j in range(2):
            full_probs[j] /= count_predictions[j]
            full_probs[j] = F.interpolate(full_probs[j],
                                          size=(h, w),
                                          mode='bilinear',
                                          align_corners=True)
        return full_probs

    def _scale_ss_outputs(self, outputs, size):
        return [
            F.interpolate(x, size=size, mode="bilinear", align_corners=True)
            for x in outputs
        ]

    def ss_test(self, inputs, scale=1):
        """Single-scale forward pass.

        In 'fix_size' mode the whole batch goes through ``self.seg_net`` at
        once. In any other mode each sample is resized on its own, sent to one
        of the configured GPUs and run on a replica of ``self.seg_net.module``.
        When the network returns three outputs, only the first and the third
        are kept.
        """
        if self.size_mode != 'fix_size':
            device_ids = self.configer.get('gpu')
            replicas = nn.parallel.replicate(self.seg_net.module, device_ids)
            scaled_inputs, ori_sizes, outputs = [], [], []

            # One sample per device; zip stops at the shorter of the two.
            for *sample, device in zip(*inputs, device_ids):
                batched = [x.unsqueeze(0) for x in sample]
                scaled_sample, ori_size = self._scale_ss_inputs(batched, scale)
                scaled_inputs.append(
                    [x.cuda(device, non_blocking=True) for x in scaled_sample])
                ori_sizes.append(ori_size)

            scaled_outputs = nn.parallel.parallel_apply(
                replicas[:len(scaled_inputs)], scaled_inputs)

            for raw, ori_size in zip(scaled_outputs, ori_sizes):
                resized = self._scale_ss_outputs(raw, ori_size)
                if len(resized) == 3:
                    resized = (resized[0], resized[2])
                outputs.append([x.squeeze(0) for x in resized])
            # Transpose from per-sample lists to per-output lists.
            return list(map(list, zip(*outputs)))

        scaled_inputs, orig_size = self._scale_ss_inputs(inputs, scale)
        print([x.shape for x in scaled_inputs])

        # The timer values are captured but not used in this method.
        t_begin = timeit.default_timer()
        outputs = list(self.seg_net.forward(*scaled_inputs))
        if len(outputs) != 3:
            outputs[0] = F.softmax(outputs[0], dim=1)
        else:
            outputs = (outputs[0], outputs[2])
        torch.cuda.synchronize()
        t_end = timeit.default_timer()

        return self._scale_ss_outputs(outputs, orig_size)

    def _decide_intersection(self,
                             total_length,
                             crop_length,
                             crop_stride_ratio=1 / 3):
        stride = int(crop_length *
                     crop_stride_ratio)  # set the stride as the paper do
        times = (total_length - crop_length) // stride + 1
        cropped_starting = []
        for i in range(times):
            cropped_starting.append(stride * i)

        if total_length - cropped_starting[-1] > crop_length:
            cropped_starting.append(total_length -
                                    crop_length)  # must cover the total image

        return cropped_starting
# Example #4
# 0
class Tester(object):
    """
      The class for Pose Estimation. Include train, val, val & predict.
    """
    def __init__(self, configer):
        """Create the helper objects from ``configer`` and load the network."""
        self.configer = configer

        # Timing meters.
        self.batch_time, self.data_time = AverageMeter(), AverageMeter()

        # Helpers, all built from the same config object.
        self.seg_visualizer = SegVisualizer(configer)
        self.loss_manager = LossManager(configer)
        self.module_runner = ModuleRunner(configer)
        self.model_manager = ModelManager(configer)
        self.optim_scheduler = OptimScheduler(configer)
        self.seg_data_loader = DataLoader(configer)

        self.save_dir = configer.get('test', 'out_dir')

        # Filled in by _init_model().
        self.seg_net = self.test_loader = self.test_size = None
        self.infer_time = self.infer_cnt = 0
        self._init_model()

    def _init_model(self):
        """Load the network and choose the loader from the output directory name.

        A ``save_dir`` containing 'test' selects the test loader and the
        'test' batch size; anything else selects the validation ones.
        """
        net = self.model_manager.semantic_segmentor()
        self.seg_net = self.module_runner.load_net(net)

        if 'test' in self.save_dir:
            split = 'test'
            self.test_loader = self.seg_data_loader.get_testloader()
        else:
            split = 'val'
            self.test_loader = self.seg_data_loader.get_valloader()
        self.test_size = len(self.test_loader) * self.configer.get(
            split, 'batch_size')

        self.seg_net.eval()

    def __relabel(self, label_map):
        height, width = label_map.shape
        label_dst = np.zeros((height, width), dtype=np.uint8)
        for i in range(self.configer.get('data', 'num_classes')):
            label_dst[label_map == i] = self.configer.get(
                'data', 'label_list')[i]

        label_dst = np.array(label_dst, dtype=np.uint8)

        return label_dst

    def test(self, data_loader=None):
        """
          Run inference over ``self.test_loader`` and write the results to
          ``self.save_dir``.

          For every image the label map is saved under ``label/`` and a
          palettised image under ``vis/`` (``gt_vis/`` when the
          ``save_gt_label`` environment variable is set). Softmax
          probabilities are saved under ``prob/`` when config
          ``test.save_prob`` is truthy. ``data_loader`` is accepted but not
          used.

          Raises:
              RuntimeError: for an unsupported dataset palette, an unsupported
                  ``test.mode``, or when ground-truth labels are needed but
                  were not loaded.
        """
        self.seg_net.eval()
        start_time = time.time()
        image_id = 0

        Log.info('save dir {}'.format(self.save_dir))
        FileHelper.make_dirs(self.save_dir, is_file=False)

        colors = self._dataset_colors()

        # Loop-invariant switches, read once.
        save_prob = bool(self.configer.get('test', 'save_prob'))
        reduce_zero_label = bool(
            self.configer.exists('data', 'reduce_zero_label')
            and self.configer.get('data', 'reduce_zero_label'))
        has_label_list = self.configer.exists('data', 'label_list')
        save_gt_label = bool(os.environ.get('save_gt_label'))

        def softmax(X, axis=0):
            # Non-mutating, numerically stabilised softmax.
            shifted = X - np.max(X, axis=axis, keepdims=True)
            exp = np.exp(shifted)
            return exp / np.sum(exp, axis=axis, keepdims=True)

        for data_dict in self.test_loader:
            inputs = data_dict['img']
            names = data_dict['name']
            metas = data_dict['meta']

            # Ground truth is only read for save dirs containing 'val'.
            labels = None
            if save_gt_label and 'val' in self.save_dir:
                labels = data_dict['labelmap']

            with torch.no_grad():
                outputs = self._forward_by_mode(data_dict, inputs, names)

                if isinstance(outputs, torch.Tensor):
                    outputs = outputs.permute(0, 2, 3, 1).cpu().numpy()
                    n = outputs.shape[0]
                else:
                    outputs = [
                        output.permute(0, 2, 3, 1).cpu().numpy().squeeze()
                        for output in outputs
                    ]
                    n = len(outputs)

                for k in range(n):
                    image_id += 1
                    ori_img_size = metas[k]['ori_img_size']
                    border_size = metas[k]['border_size']
                    logits = cv2.resize(
                        outputs[k][:border_size[1], :border_size[0]],
                        tuple(ori_img_size),
                        interpolation=cv2.INTER_CUBIC)

                    # save the logits map
                    if save_prob:
                        prob_path = os.path.join(self.save_dir, "prob/",
                                                 '{}.npy'.format(names[k]))
                        FileHelper.make_dirs(prob_path, is_file=True)
                        np.save(prob_path, softmax(logits, axis=-1))

                    label_img = np.asarray(np.argmax(logits, axis=-1),
                                           dtype=np.uint8)
                    if reduce_zero_label:
                        label_img = label_img + 1
                        label_img = label_img.astype(np.uint8)
                    if has_label_list:
                        label_img_ = self.__relabel(label_img)
                    else:
                        label_img_ = label_img
                    label_img_ = Image.fromarray(label_img_, 'P')
                    Log.info('{:4d}/{:4d} label map generated'.format(
                        image_id, self.test_size))
                    label_path = os.path.join(self.save_dir, "label/",
                                              '{}.png'.format(names[k]))
                    FileHelper.make_dirs(label_path, is_file=True)
                    ImageHelper.save(label_img_, label_path)

                    # colorize the label-map
                    vis_dir = "vis/"
                    if save_gt_label:
                        vis_dir = "gt_vis/"
                        # NOTE(review): without reduce_zero_label the
                        # prediction, not the ground truth, is written to
                        # gt_vis/ (unchanged from before) - confirm intent.
                        if reduce_zero_label:
                            if labels is None:
                                raise RuntimeError(
                                    "Ground-truth labels are not loaded: "
                                    "'save_gt_label' needs a save dir "
                                    "containing 'val'")
                            label_img = labels[k] + 1
                            label_img = np.asarray(label_img, dtype=np.uint8)
                    color_img_ = Image.fromarray(label_img)
                    color_img_.putpalette(colors)
                    vis_path = os.path.join(self.save_dir, vis_dir,
                                            '{}.png'.format(names[k]))
                    FileHelper.make_dirs(vis_path, is_file=True)
                    ImageHelper.save(color_img_, save_path=vis_path)

            self.batch_time.update(time.time() - start_time)
            start_time = time.time()

        # Print the log info & reset the states.
        Log.info('Test Time {batch_time.sum:.3f}s'.format(
            batch_time=self.batch_time))

    def _dataset_colors(self):
        """Return the colour palette for config ``dataset``."""
        dataset = self.configer.get('dataset')
        if dataset in ['cityscapes', 'gta5']:
            return get_cityscapes_colors()
        if dataset == 'ade20k':
            return get_ade_colors()
        if dataset == 'lip':
            return get_lip_colors()
        if dataset == 'pascal_context':
            return get_pascal_context_colors()
        if dataset == 'pascal_voc':
            return get_pascal_voc_colors()
        if dataset == 'coco_stuff':
            return get_cocostuff_colors()
        raise RuntimeError("Unsupport colors")

    def _forward_by_mode(self, data_dict, inputs, names):
        """Run the forward pass selected by the config and return its outputs."""
        if self.configer.exists('data', 'use_offset') and self.configer.get(
                'data', 'use_offset') == 'offline':
            return self.offset_test(inputs, data_dict['offsetmap_h'],
                                    data_dict['offsetmap_w'])

        mode = self.configer.get('test', 'mode')
        if mode == 'ss_test':
            return self.ss_test(inputs)
        if mode == 'ms_test':
            return self.ms_test(inputs)
        if mode == 'ms_test_depth':
            return self.ms_test_depth(inputs, names)
        if mode == 'sscrop_test':
            return self.sscrop_test(inputs,
                                    self.configer.get('test', 'crop_size'))
        if mode == 'mscrop_test':
            return self.mscrop_test(inputs,
                                    self.configer.get('test', 'crop_size'))
        if mode == 'crf_ss_test':
            return self.dense_crf_process(inputs, self.ss_test(inputs))
        # Previously an unknown mode failed later on an unbound local.
        raise RuntimeError('Unsupport test mode: {}'.format(mode))

    def offset_test(self, inputs, offset_h_maps, offset_w_maps, scale=1):
        if isinstance(inputs, torch.Tensor):
            n, c, h, w = inputs.size(0), inputs.size(1), inputs.size(
                2), inputs.size(3)
            start = timeit.default_timer()
            outputs = self.seg_net.forward(inputs, offset_h_maps,
                                           offset_w_maps)
            torch.cuda.synchronize()
            end = timeit.default_timer()

            if (self.configer.get('loss', 'loss_type')
                    == "fs_auxce_loss") or (self.configer.get(
                        'loss', 'loss_type') == "triple_auxce_loss"):
                outputs = outputs[-1]
            elif self.configer.get('loss',
                                   'loss_type') == "pyramid_auxce_loss":
                outputs = outputs[1] + outputs[2] + outputs[3] + outputs[4]

            outputs = F.interpolate(outputs,
                                    size=(h, w),
                                    mode='bilinear',
                                    align_corners=True)
            return outputs
        else:
            raise RuntimeError("Unsupport data type: {}".format(type(inputs)))

    def ss_test(self, inputs, scale=1):
        if isinstance(inputs, torch.Tensor):
            n, c, h, w = inputs.size(0), inputs.size(1), inputs.size(
                2), inputs.size(3)
            scaled_inputs = F.interpolate(inputs,
                                          size=(int(h * scale),
                                                int(w * scale)),
                                          mode="bilinear",
                                          align_corners=True)
            start = timeit.default_timer()
            outputs = self.seg_net.forward(scaled_inputs)
            torch.cuda.synchronize()
            end = timeit.default_timer()
            outputs = outputs[-1]
            outputs = F.interpolate(outputs,
                                    size=(h, w),
                                    mode='bilinear',
                                    align_corners=True)
            return outputs
        elif isinstance(inputs, collections.Sequence):
            device_ids = self.configer.get('gpu')
            replicas = nn.parallel.replicate(self.seg_net.module, device_ids)
            scaled_inputs, ori_size, outputs = [], [], []
            for i, d in zip(inputs, device_ids):
                h, w = i.size(1), i.size(2)
                ori_size.append((h, w))
                i = F.interpolate(i.unsqueeze(0),
                                  size=(int(h * scale), int(w * scale)),
                                  mode="bilinear",
                                  align_corners=True)
                scaled_inputs.append(i.cuda(d, non_blocking=True))
            scaled_outputs = nn.parallel.parallel_apply(
                replicas[:len(scaled_inputs)], scaled_inputs)
            for i, output in enumerate(scaled_outputs):
                outputs.append(
                    F.interpolate(output[-1],
                                  size=ori_size[i],
                                  mode='bilinear',
                                  align_corners=True))
            return outputs
        else:
            raise RuntimeError("Unsupport data type: {}".format(type(inputs)))

    def flip(self, x, dim):
        indices = [slice(None)] * x.dim()
        indices[dim] = torch.arange(x.size(dim) - 1,
                                    -1,
                                    -1,
                                    dtype=torch.long,
                                    device=x.device)
        return x[tuple(indices)]

    def sscrop_test(self, inputs, crop_size, scale=1):
        '''
        Sliding-window inference at one scale: the input is resized by
        ``scale``, overlapping crops are predicted with ``ss_test`` and
        averaged, and the result is resized back to the input size.
        Currently, sscrop_test does not support diverse_size testing
        '''
        ori_h, ori_w = inputs.size(2), inputs.size(3)
        scaled_inputs = F.interpolate(inputs,
                                      size=(int(ori_h * scale),
                                            int(ori_w * scale)),
                                      mode="bilinear",
                                      align_corners=True)
        n, h, w = (scaled_inputs.size(0), scaled_inputs.size(2),
                   scaled_inputs.size(3))
        num_classes = self.configer.get('data', 'num_classes')
        full_probs = torch.cuda.FloatTensor(n, num_classes, h, w).fill_(0)
        count_predictions = torch.cuda.FloatTensor(n, num_classes, h,
                                                   w).fill_(0)
        crop_h, crop_w = crop_size[0], crop_size[1]

        crop_counter = 0

        height_starts = self._decide_intersection(h, crop_h)
        width_starts = self._decide_intersection(w, crop_w)

        for top in height_starts:
            rows = slice(top, top + crop_h)
            for left in width_starts:
                cols = slice(left, left + crop_w)
                prediction = self.ss_test(scaled_inputs[:, :, rows, cols])
                # Track the coverage so the sum can be turned into a mean.
                count_predictions[:, :, rows, cols] += 1
                full_probs[:, :, rows, cols] += prediction
                crop_counter += 1
                Log.info('predicting {:d}-th crop'.format(crop_counter))

        full_probs /= count_predictions
        return F.interpolate(full_probs,
                             size=(ori_h, ori_w),
                             mode='bilinear',
                             align_corners=True)

    def ms_test(self, inputs):
        if isinstance(inputs, torch.Tensor):
            n, c, h, w = inputs.size(0), inputs.size(1), inputs.size(
                2), inputs.size(3)
            full_probs = torch.cuda.FloatTensor(
                n, self.configer.get('data', 'num_classes'), h, w).fill_(0)
            if self.configer.exists('test', 'scale_weights'):
                for scale, weight in zip(
                        self.configer.get('test', 'scale_search'),
                        self.configer.get('test', 'scale_weights')):
                    probs = self.ss_test(inputs, scale)
                    flip_probs = self.ss_test(self.flip(inputs, 3), scale)
                    probs = probs + self.flip(flip_probs, 3)
                    full_probs += weight * probs
                return full_probs
            else:
                for scale in self.configer.get('test', 'scale_search'):
                    probs = self.ss_test(inputs, scale)
                    flip_probs = self.ss_test(self.flip(inputs, 3), scale)
                    probs = probs + self.flip(flip_probs, 3)
                    full_probs += probs
                return full_probs

        elif isinstance(inputs, collections.Sequence):
            device_ids = self.configer.get('gpu')
            full_probs = [
                torch.zeros(1, self.configer.get('data', 'num_classes'),
                            i.size(1), i.size(2)).cuda(device_ids[index],
                                                       non_blocking=True)
                for index, i in enumerate(inputs)
            ]
            flip_inputs = [self.flip(i, 2) for i in inputs]

            if self.configer.exists('test', 'scale_weights'):
                for scale, weight in zip(
                        self.configer.get('test', 'scale_search'),
                        self.configer.get('test', 'scale_weights')):
                    probs = self.ss_test(inputs, scale)
                    flip_probs = self.ss_test(flip_inputs, scale)
                    for i in range(len(inputs)):
                        full_probs[i] += weight * (probs[i] +
                                                   self.flip(flip_probs[i], 3))
                return full_probs
            else:
                for scale in self.configer.get('test', 'scale_search'):
                    probs = self.ss_test(inputs, scale)
                    flip_probs = self.ss_test(flip_inputs, scale)
                    for i in range(len(inputs)):
                        full_probs[i] += (probs[i] +
                                          self.flip(flip_probs[i], 3))
                return full_probs

        else:
            raise RuntimeError("Unsupport data type: {}".format(type(inputs)))

    def ms_test_depth(self, inputs, names):
        prob_list = []
        scale_list = []

        if isinstance(inputs, torch.Tensor):
            n, c, h, w = inputs.size(0), inputs.size(1), inputs.size(
                2), inputs.size(3)
            full_probs = torch.cuda.FloatTensor(
                n, self.configer.get('data', 'num_classes'), h, w).fill_(0)

            for scale in self.configer.get('test', 'scale_search'):
                probs = self.ss_test(inputs, scale)
                flip_probs = self.ss_test(self.flip(inputs, 3), scale)
                probs = probs + self.flip(flip_probs, 3)
                prob_list.append(probs)
                scale_list.append(scale)

            full_probs = self.fuse_with_depth(prob_list, scale_list, names)
            return full_probs

        else:
            raise RuntimeError("Unsupport data type: {}".format(type(inputs)))

    def fuse_with_depth(self, probs, scales, names):
        """Blend per-scale predictions with weights derived from stereo maps.

        For every image a ``<name>.png`` map is read from a hard-coded stereo
        directory (test or val, chosen by ``self.save_dir``), converted to a
        value clipped to ``[0, MAX_DEPTH]`` and quantised into
        ``len(scales)`` bins. Each scale's prediction is weighted by
        ``POWER_BASE ** |bin - scale_index|``.

        Args:
            probs: list of prediction tensors, one per scale.
            scales: the scales matching ``probs``.
            names: image names, one per batch entry.

        Returns:
            Tensor with the weighted sum over all scales.

        Raises:
            RuntimeError: if a stereo map cannot be read.
        """
        MAX_DEPTH = 63
        POWER_BASE = 0.8
        if 'test' in self.save_dir:
            stereo_path = "/msravcshare/dataset/cityscapes/stereo/test/"
        else:
            stereo_path = "/msravcshare/dataset/cityscapes/stereo/val/"

        n, h, w = probs[0].size(0), probs[0].size(2), probs[0].size(3)
        full_probs = torch.cuda.FloatTensor(
            n, self.configer.get('data', 'num_classes'), h, w).fill_(0)

        for index, name in enumerate(names):
            stereo_file = stereo_path + name + '.png'
            stereo_map = cv2.imread(stereo_file, -1)
            if stereo_map is None:
                # cv2.imread signals failure by returning None.
                raise RuntimeError(
                    'Cannot read stereo map: {}'.format(stereo_file))

            # Zero entries divide to inf and are clipped right below, so the
            # divide warning carries no information.
            with np.errstate(divide='ignore'):
                depth_map = stereo_map / 256.0
                depth_map = 0.5 / depth_map
                depth_map = 500 * depth_map

            depth_map = np.clip(depth_map, 0, MAX_DEPTH)
            depth_map = depth_map // (MAX_DEPTH // len(scales))

            for prob, scale in zip(probs, scales):
                scale_index = self._locate_scale_index(scale, scales)
                weight_map = np.abs(depth_map - scale_index)
                weight_map = np.power(POWER_BASE, weight_map)
                weight_map = cv2.resize(weight_map, (w, h))
                full_probs[index, :, :, :] += torch.from_numpy(
                    np.expand_dims(weight_map, axis=0)).type(
                        torch.cuda.FloatTensor) * prob[index, :, :, :]

        return full_probs

    @staticmethod
    def _locate_scale_index(scale, scales):
        for idx, s in enumerate(scales):
            if scale == s:
                return idx
        return 0

    def ms_test_wo_flip(self, inputs):
        """Accumulate single-scale predictions over all test scales, without flipping.

        Args:
            inputs: either one batched ``torch.Tensor`` (indexed with
                ``size(0)``, ``size(2)``, ``size(3)``) or a sequence of
                per-image tensors (indexed with ``size(1)``, ``size(2)``).

        Returns:
            For a tensor input, a CUDA float tensor of shape
            ``(n, num_classes, h, w)``.  For a sequence input, a list with one
            ``(1, num_classes, size(1), size(2))`` tensor per item, each placed
            on the GPU ``configer.get('gpu')[index]``.

        Raises:
            RuntimeError: if ``inputs`` is neither a tensor nor a sequence.
        """
        # ``collections.Sequence`` was removed in Python 3.10; the ABC lives in
        # ``collections.abc``.  Imported locally so the block is self-contained.
        from collections.abc import Sequence

        if isinstance(inputs, torch.Tensor):
            n, h, w = inputs.size(0), inputs.size(2), inputs.size(3)
            full_probs = torch.cuda.FloatTensor(
                n, self.configer.get('data', 'num_classes'), h, w).fill_(0)
            for scale in self.configer.get('test', 'scale_search'):
                full_probs += self.ss_test(inputs, scale)
            return full_probs

        if isinstance(inputs, Sequence):
            device_ids = self.configer.get('gpu')
            # One zero-initialised accumulator per input, on that input's GPU.
            full_probs = [
                torch.zeros(1, self.configer.get('data', 'num_classes'),
                            item.size(1), item.size(2)).cuda(
                                device_ids[index], non_blocking=True)
                for index, item in enumerate(inputs)
            ]
            for scale in self.configer.get('test', 'scale_search'):
                probs = self.ss_test(inputs, scale)
                for i in range(len(inputs)):
                    full_probs[i] += probs[i]
            return full_probs

        raise RuntimeError("Unsupport data type: {}".format(type(inputs)))

    def mscrop_test(self, inputs, crop_size):
        """Multi-scale test with horizontal flipping.

        Scales below 1 are predicted on the whole image via ``ss_test``; all
        other scales go through the cropped predictor ``sscrop_test``.  For
        every scale the prediction of the input and the flipped-back prediction
        of the flipped input are added to the running total.

        Currently, mscrop_test does not support diverse_size testing.
        """
        batch, _, height, width = (inputs.size(0), inputs.size(1),
                                   inputs.size(2), inputs.size(3))
        accumulated = torch.cuda.FloatTensor(
            batch, self.configer.get('data', 'num_classes'), height,
            width).fill_(0)

        for scale in self.configer.get('test', 'scale_search'):
            Log.info('Scale {0:.2f} prediction'.format(scale))

            if scale < 1:
                def predict(batch_inputs, scale=scale):
                    return self.ss_test(batch_inputs, scale)
            else:
                def predict(batch_inputs, scale=scale):
                    return self.sscrop_test(batch_inputs, crop_size, scale)

            direct = predict(inputs)
            mirrored = predict(self.flip(inputs, 3))
            accumulated += direct + self.flip(mirrored, 3)

        return accumulated

    def _decide_intersection(self, total_length, crop_length):
        stride = crop_length
        times = (total_length - crop_length) // stride + 1
        cropped_starting = []
        for i in range(times):
            cropped_starting.append(stride * i)
        if total_length - cropped_starting[-1] > crop_length:
            cropped_starting.append(total_length -
                                    crop_length)  # must cover the total image
        return cropped_starting

    def dense_crf_process(self, images, outputs):
        """Refine a batch of class scores with a dense CRF.

        The scores are turned into probabilities with a softmax over dim 1;
        each sample is then run through ``dcrf.DenseCRF2D`` with one Gaussian
        and one bilateral pairwise term, and the inference result is written
        back into the probability tensor.

        Reference: https://github.com/kazuto1011/deeplab-pytorch/blob/master/libs/utils/crf.py

        Args:
            images: batch of input images; a fixed per-channel offset is added
                back before the images are cast to bytes.
            outputs: batch of class scores for ``images``.

        Returns:
            The softmax tensor with every sample replaced by its CRF result.
        """
        # Dense-CRF hyperparameters.  Scores noted by the original author while
        # tuning (the metric is not stated in this file):
        #   baseline = 79.5
        #   bi_xy_std = 67 -> 79.1
        #   bi_xy_std = 20 -> 79.6
        #   bi_xy_std = 10 -> 79.7
        #   bi_xy_std = 10, iter_max = 20 (v4) -> 79.7
        #   bi_xy_std = 10, iter_max = 5 (v5) -> 79.7
        #   bi_xy_std = 5 (v3) -> 79.7
        crf_iterations = 10
        gaussian_compat = 3
        gaussian_sxy = 1
        bilateral_compat = 4
        bilateral_sxy = 10
        bilateral_srgb = 3

        batch_size = images.size(0)
        # Shaped (3, 1, 1) so it broadcasts over a channel-first image.
        channel_offset = np.array(
            [102.9801, 115.9465, 122.7717]).reshape(3, 1, 1)
        outputs = F.softmax(outputs, dim=1)

        for sample in range(batch_size):
            probabilities = outputs[sample].data.cpu().numpy()
            num_classes, height, width = probabilities.shape
            unary = np.ascontiguousarray(
                dcrf_utils.unary_from_softmax(probabilities))

            # Add the offset back, cast to bytes, and move channels last for
            # the bilateral term.
            picture = np.ascontiguousarray(images[sample]) + channel_offset
            picture = picture.astype(np.ubyte)
            picture = np.ascontiguousarray(picture.transpose(1, 2, 0))

            crf = dcrf.DenseCRF2D(width, height, num_classes)
            crf.setUnaryEnergy(unary)
            crf.addPairwiseGaussian(sxy=gaussian_sxy, compat=gaussian_compat)
            crf.addPairwiseBilateral(sxy=bilateral_sxy,
                                     srgb=bilateral_srgb,
                                     rgbim=picture,
                                     compat=bilateral_compat)
            refined = np.array(crf.inference(crf_iterations))
            outputs[sample] = torch.from_numpy(refined).cuda().view(
                num_classes, height, width)

        return outputs
# Example #5
# 0
class Trainer(object):
    """
      The trainer for semantic segmentation: builds the network, optimizer,
      data loaders and loss, then runs the training loop with periodic
      validation.
    """
    def __init__(self, configer):
        """Create the meters and helper objects, then build the model.

        Args:
            configer: project configuration object; it is stored on the
                instance and passed to every helper constructed here.
        """
        self.configer = configer
        # Running meters for timings and loss values; they are reset after
        # each log line in the train/val loops.
        self.batch_time = AverageMeter()
        self.foward_time = AverageMeter()
        self.backward_time = AverageMeter()
        self.loss_time = AverageMeter()
        self.data_time = AverageMeter()
        self.train_losses = AverageMeter()
        self.val_losses = AverageMeter()
        # Helper objects, all configured from the same configer.
        self.seg_visualizer = SegVisualizer(configer)
        self.loss_manager = LossManager(configer)
        self.module_runner = ModuleRunner(configer)
        self.model_manager = ModelManager(configer)
        self.data_loader = DataLoader(configer)
        self.optim_scheduler = OptimScheduler(configer)
        # These two helpers also receive this trainer, which is only partially
        # initialised at this point (the attributes below are not set yet).
        self.data_helper = DataHelper(configer, self)
        self.evaluator = get_evaluator(configer, self)

        # Placeholders; everything except running_score is assigned by
        # _init_model() (running_score is not assigned in the visible code).
        self.seg_net = None
        self.train_loader = None
        self.val_loader = None
        self.optimizer = None
        self.scheduler = None
        self.running_score = None

        self._init_model()

    def _init_model(self):
        """Build the network, its optimizer/scheduler, the loaders and the loss.

        The parameter groups come from ``group_weight`` when the configured
        ``optim.group_method`` is ``'decay'``; otherwise the method must be
        ``None`` and the groups come from ``_get_parameters``.
        """
        self.seg_net = self.model_manager.semantic_segmentor()
        self.seg_net = self.module_runner.load_net(self.seg_net)

        group_method = self.configer.get('optim', 'group_method')
        Log.info('Params Group Method: {}'.format(group_method))
        if group_method != 'decay':
            assert group_method is None
            params_group = self._get_parameters()
        else:
            params_group = self.group_weight(self.seg_net)

        optimizer_and_scheduler = self.optim_scheduler.init_optimizer(params_group)
        self.optimizer, self.scheduler = optimizer_and_scheduler

        self.train_loader = self.data_loader.get_trainloader()
        self.val_loader = self.data_loader.get_valloader()

        self.pixel_loss = self.loss_manager.get_seg_loss()
        # Only the distributed path moves the loss module explicitly.
        if is_distributed():
            self.pixel_loss = self.module_runner.to_device(self.pixel_loss)

    @staticmethod
    def group_weight(module):
        group_decay = []
        group_no_decay = []
        for m in module.modules():
            if isinstance(m, nn.Linear):
                group_decay.append(m.weight)
                if m.bias is not None:
                    group_no_decay.append(m.bias)
            elif isinstance(m, nn.modules.conv._ConvNd):
                group_decay.append(m.weight)
                if m.bias is not None:
                    group_no_decay.append(m.bias)
            else:
                if hasattr(m, 'weight'):
                    group_no_decay.append(m.weight)
                if hasattr(m, 'bias'):
                    group_no_decay.append(m.bias)

        assert len(list(module.parameters())) == len(group_decay) + len(group_no_decay)
        groups = [dict(params=group_decay), dict(params=group_no_decay, weight_decay=.0)]
        return groups

    def _get_parameters(self):
        bb_lr = []
        nbb_lr = []
        params_dict = dict(self.seg_net.named_parameters())
        for key, value in params_dict.items():
            if 'backbone' not in key:
                nbb_lr.append(value)
            else:
                bb_lr.append(value)

        params = [{'params': bb_lr, 'lr': self.configer.get('lr', 'base_lr')},
                  {'params': nbb_lr, 'lr': self.configer.get('lr', 'base_lr') * self.configer.get('lr', 'nbb_mult')}]
        return params

    def __train(self):
        """
          Train function of every epoch during train phase.

          Iterates once over the train loader.  Each batch steps the LR
          scheduler, runs forward / loss / backward / optimizer step, logs
          every ``solver.display_iter`` iterations and validates every
          ``solver.test_interval`` iterations.  The loop stops early when
          ``iters`` reaches ``solver.max_iters``; the epoch counter is
          incremented on exit either way.
        """
        self.seg_net.train()
        self.pixel_loss.train()
        start_time = time.time()

        # SWA schedule: the first 75% of max_iters are "normal" iterations.
        # The remainder is split into steps of swa_step_max_iters, and
        # update_swa() is called at the end of each step (see below).
        if "swa" in self.configer.get('lr', 'lr_policy'):
            normal_max_iters = int(self.configer.get('solver', 'max_iters') * 0.75)
            swa_step_max_iters = (self.configer.get('solver', 'max_iters') - normal_max_iters) // 5 + 1

        # Samplers that expose set_epoch are told the current epoch.
        if hasattr(self.train_loader.sampler, 'set_epoch'):
            self.train_loader.sampler.set_epoch(self.configer.get('epoch'))

        for i, data_dict in enumerate(self.train_loader):
            # The scheduler is stepped with the iteration or the epoch
            # counter, depending on the configured LR metric.
            if self.configer.get('lr', 'metric') == 'iters':
                self.scheduler.step(self.configer.get('iters'))
            else:
                self.scheduler.step(self.configer.get('epoch'))


            if self.configer.get('lr', 'is_warm'):
                self.module_runner.warm_lr(
                    self.configer.get('iters'),
                    self.scheduler, self.optimizer, backbone_list=[0,]
                )

            (inputs, targets), batch_size = self.data_helper.prepare_data(data_dict)
            self.data_time.update(time.time() - start_time)

            foward_start_time = time.time()
            outputs = self.seg_net(*inputs)
            self.foward_time.update(time.time() - foward_start_time)

            loss_start_time = time.time()
            if is_distributed():
                import torch.distributed as dist
                def reduce_tensor(inp):
                    """
                    Reduce the loss from all processes so that
                    process with rank 0 has the averaged results.

                    NOTE(review): ``reduced_inp`` is the same object as
                    ``inp``, so ``dist.reduce`` operates on the caller's
                    tensor (here the loss that is later back-propagated)
                    rather than on a copy -- confirm this is intended.
                    """
                    world_size = get_world_size()
                    if world_size < 2:
                        return inp
                    with torch.no_grad():
                        reduced_inp = inp
                        dist.reduce(reduced_inp, dst=0)
                    return reduced_inp
                loss = self.pixel_loss(outputs, targets)
                backward_loss = loss
                # The reduced value is divided by the world size for display.
                display_loss = reduce_tensor(backward_loss) / get_world_size()
            else:
                backward_loss = display_loss = self.pixel_loss(outputs, targets, gathered=self.configer.get('network', 'gathered'))

            self.train_losses.update(display_loss.item(), batch_size)
            self.loss_time.update(time.time() - loss_start_time)

            backward_start_time = time.time()
            self.optimizer.zero_grad()
            backward_loss.backward()
            self.optimizer.step()
            self.backward_time.update(time.time() - backward_start_time)

            # Update the vars of the train phase.
            self.batch_time.update(time.time() - start_time)
            start_time = time.time()
            self.configer.plus_one('iters')

            # Print the log info & reset the states.
            # Only rank 0 logs when running distributed.
            # NOTE(review): '{data_time.avg:3f}' is a width-3 spec, unlike the
            # '.3f' precision used for the other timings -- probably a typo.
            if self.configer.get('iters') % self.configer.get('solver', 'display_iter') == 0 and \
                (not is_distributed() or get_rank() == 0):
                Log.info('Train Epoch: {0}\tTrain Iteration: {1}\t'
                         'Time {batch_time.sum:.3f}s / {2}iters, ({batch_time.avg:.3f})\t'
                         'Forward Time {foward_time.sum:.3f}s / {2}iters, ({foward_time.avg:.3f})\t'
                         'Backward Time {backward_time.sum:.3f}s / {2}iters, ({backward_time.avg:.3f})\t'
                         'Loss Time {loss_time.sum:.3f}s / {2}iters, ({loss_time.avg:.3f})\t'
                         'Data load {data_time.sum:.3f}s / {2}iters, ({data_time.avg:3f})\n'
                         'Learning rate = {3}\tLoss = {loss.val:.8f} (ave = {loss.avg:.8f})\n'.format(
                         self.configer.get('epoch'), self.configer.get('iters'),
                         self.configer.get('solver', 'display_iter'),
                         self.module_runner.get_lr(self.optimizer), batch_time=self.batch_time,
                         foward_time=self.foward_time, backward_time=self.backward_time, loss_time=self.loss_time,
                         data_time=self.data_time, loss=self.train_losses))
                self.batch_time.reset()
                self.foward_time.reset()
                self.backward_time.reset()
                self.loss_time.reset()
                self.data_time.reset()
                self.train_losses.reset()

            # save checkpoints for swa
            # Runs only past the normal phase: at the end of every SWA step
            # and at the very last iteration.
            if 'swa' in self.configer.get('lr', 'lr_policy') and \
               self.configer.get('iters') > normal_max_iters and \
               ((self.configer.get('iters') - normal_max_iters) % swa_step_max_iters == 0 or \
                self.configer.get('iters') == self.configer.get('solver', 'max_iters')):
               self.optimizer.update_swa()

            # Stop before the periodic validation once the budget is used up.
            if self.configer.get('iters') == self.configer.get('solver', 'max_iters'):
                break

            # Check to val the current model.
            # if self.configer.get('epoch') % self.configer.get('solver', 'test_interval') == 0:
            if self.configer.get('iters') % self.configer.get('solver', 'test_interval') == 0:
                self.__val()

        self.configer.plus_one('epoch')


    def __val(self, data_loader=None):
        """
          Validation function during the train phase.

          Puts the network and the loss in eval mode, feeds every batch of the
          loader to the evaluator, stores the average validation loss in the
          configer under 'val_loss', saves the network in 'performance' and
          'val_loss' modes, logs the results and switches back to train mode.

          Args:
              data_loader: loader to validate on; ``self.val_loader`` is used
                  when it is None.

          Raises:
              AssertionError: re-raised from the loss computation after the
                  output/target counts have been logged.
        """
        self.seg_net.eval()
        self.pixel_loss.eval()
        start_time = time.time()
        replicas = self.evaluator.prepare_validaton()

        data_loader = self.val_loader if data_loader is None else data_loader
        is_lip = self.configer.get('dataset') == 'lip'
        for j, data_dict in enumerate(data_loader):
            if j % 10 == 0:
                Log.info('{} images processed\n'.format(j))

            if is_lip:
                (inputs, targets, inputs_rev, _targets_rev), batch_size = \
                    self.data_helper.prepare_data(data_dict, want_reverse=True)
            else:
                (inputs, targets), batch_size = self.data_helper.prepare_data(data_dict)

            with torch.no_grad():
                if is_lip:
                    # Run the batch and its reversed copy in one forward pass,
                    # then average the two halves.  No loss is computed on
                    # this path, so val_losses stays untouched.
                    stacked = torch.cat([inputs[0], inputs_rev[0]], dim=0)
                    gathered = self.module_runner.gather(self.seg_net(stacked))
                    if isinstance(gathered, (list, tuple)):
                        gathered = gathered[-1]

                    half = gathered.size(0) // 2
                    outputs = gathered[:half].clone()
                    reversed_half = gathered[half:]
                    outputs_rev = reversed_half.clone()
                    if outputs_rev.shape[1] == 20:
                        # Swap the channel pairs (14,15), (16,17), (18,19) of
                        # the reversed prediction -- presumably the left/right
                        # classes of the 20-class label set; TODO confirm.
                        for first, second in ((14, 15), (16, 17), (18, 19)):
                            outputs_rev[:, first] = reversed_half[:, second]
                            outputs_rev[:, second] = reversed_half[:, first]
                    outputs_rev = torch.flip(outputs_rev, [3])
                    outputs = (outputs + outputs_rev) / 2.
                    self.evaluator.update_score(outputs, data_dict['meta'])

                elif self.data_helper.conditions.diverse_size:
                    # Differently sized inputs: one replica per input, loss
                    # and score updated sample by sample.
                    outputs = nn.parallel.parallel_apply(replicas[:len(inputs)], inputs)

                    for i in range(len(outputs)):
                        loss = self.pixel_loss(outputs[i], targets[i])
                        self.val_losses.update(loss.item(), 1)
                        outputs_i = outputs[i]
                        if isinstance(outputs_i, torch.Tensor):
                            outputs_i = [outputs_i]
                        self.evaluator.update_score(outputs_i, data_dict['meta'][i:i+1])

                else:
                    outputs = self.seg_net(*inputs)

                    try:
                        loss = self.pixel_loss(
                            outputs, targets,
                            gathered=self.configer.get('network', 'gathered')
                        )
                    except AssertionError:
                        # Log the mismatch and re-raise: continuing would use
                        # an unbound (or stale, previous-batch) ``loss`` below.
                        Log.error('Loss assertion failed: {} outputs vs {} targets'.format(
                            len(outputs), len(targets)))
                        raise

                    if not is_distributed():
                        outputs = self.module_runner.gather(outputs)
                    self.val_losses.update(loss.item(), batch_size)
                    self.evaluator.update_score(outputs, data_dict['meta'])

            self.batch_time.update(time.time() - start_time)
            start_time = time.time()

        self.evaluator.update_performance()

        self.configer.update(['val_loss'], self.val_losses.avg)
        self.module_runner.save_net(self.seg_net, save_mode='performance')
        self.module_runner.save_net(self.seg_net, save_mode='val_loss')
        # NOTE(review): the cudnn autotuner is switched on after every
        # validation pass; kept as in the original.
        cudnn.benchmark = True

        # Print the log info & reset the states.
        if not is_distributed() or get_rank() == 0:
            Log.info(
                'Test Time {batch_time.sum:.3f}s, ({batch_time.avg:.3f})\t'
                'Loss {loss.avg:.8f}\n'.format(
                    batch_time=self.batch_time, loss=self.val_losses))
            self.evaluator.print_scores()

        self.batch_time.reset()
        self.val_losses.reset()
        self.evaluator.reset()
        self.seg_net.train()
        self.pixel_loss.train()

    def train(self):
        # cudnn.benchmark = True
        # self.__val()
        if self.configer.get('network', 'resume') is not None:
            if self.configer.get('network', 'resume_val'):
                self.__val(data_loader=self.data_loader.get_valloader(dataset='val'))
                return
            elif self.configer.get('network', 'resume_train'):
                self.__val(data_loader=self.data_loader.get_valloader(dataset='train'))
                return
            # return

        if self.configer.get('network', 'resume') is not None and self.configer.get('network', 'resume_val'):
            self.__val(data_loader=self.data_loader.get_valloader(dataset='val'))
            return

        while self.configer.get('iters') < self.configer.get('solver', 'max_iters'):
            self.__train()

        # use swa to average the model
        if 'swa' in self.configer.get('lr', 'lr_policy'):
            self.optimizer.swap_swa_sgd()
            self.optimizer.bn_update(self.train_loader, self.seg_net)

        self.__val(data_loader=self.data_loader.get_valloader(dataset='val'))

    def summary(self):
        """Print a model summary computed on the first image of the first train batch.

        The network is switched to eval mode and left there.  Nothing is
        printed when the train loader yields no batch.
        """
        from lib.utils.summary import get_model_summary

        self.seg_net.eval()

        # Only the first batch is needed, hence the return inside the loop.
        for data_dict in self.train_loader:
            print(get_model_summary(self.seg_net, data_dict['img'][0:1]))
            return
class MultiTaskClassifier(object):
    """
      The class for the training phase of Image classification.
    """
    def __init__(self, configer):
        """Build the model, optimizer, per-source data loaders and loss.

        Args:
            configer: project configuration object; stored on the instance and
                passed to the helpers constructed here.
        """
        self.configer = configer
        # Mutable training progress and best-so-far bookkeeping.
        self.runner_state = dict(iters=0,
                                 last_iters=0,
                                 epoch=0,
                                 last_epoch=0,
                                 performance=0,
                                 val_loss=0,
                                 max_performance=0,
                                 min_val_loss=0)

        self.batch_time = AverageMeter()
        self.data_time = AverageMeter()
        # Losses are tracked per key, hence the dict-based meters.
        self.train_losses = DictAverageMeter()
        self.val_losses = DictAverageMeter()
        self.cls_model_manager = ModelManager(configer)
        self.cls_data_loader = DataLoader(configer)
        self.running_score = RunningScore(configer)

        self.cls_net = self.cls_model_manager.get_model()
        # The solver section is looked up indirectly: 'train.solver' holds the
        # name of the configuration entry to use.
        self.solver_dict = self.configer.get(
            self.configer.get('train', 'solver'))
        # _get_parameters() reads self.cls_net and self.solver_dict, so both
        # must be set before this call.
        # NOTE(review): the Trainer class defined earlier in this file has no
        # `init` attribute; presumably a different Trainer helper is meant
        # here -- confirm against the imports.
        self.optimizer, self.scheduler = Trainer.init(self._get_parameters(),
                                                      self.solver_dict)
        self.cls_net = RunnerHelper.load_net(self, self.cls_net)
        self.cls_net, self.optimizer = RunnerHelper.to_dtype(
            self, self.cls_net, self.optimizer)

        # One train loader and one val loader per data source, keyed by the
        # integer source index.
        self.train_loaders = dict()
        self.val_loaders = dict()
        for source in range(self.configer.get('data', 'num_data_sources')):
            self.train_loaders[source] = self.cls_data_loader.get_trainloader(
                source=source)
            self.val_loaders[source] = self.cls_data_loader.get_valloader(
                source=source)
        # Mixup pastes src0 images into src1 images (see run()), so exactly
        # two sources are required.
        if self.configer.get('data', 'mixup'):
            assert (self.configer.get('data', 'num_data_sources') == 2
                    ), "mixup only support src0 and src1 load the same dataset"

        self.loss = self.cls_model_manager.get_loss()

    def _get_parameters(self):
        lr_1 = []
        lr_2 = []
        params_dict = dict(self.cls_net.named_parameters())
        for key, value in params_dict.items():
            if value.requires_grad:
                if 'backbone' in key:
                    if self.configer.get('network', 'bb_lr_scale') == 0.0:
                        value.requires_grad = False
                    else:
                        lr_1.append(value)
                else:
                    lr_2.append(value)

        params = [{
            'params':
            lr_1,
            'lr':
            self.solver_dict['lr']['base_lr'] *
            self.configer.get('network', 'bb_lr_scale')
        }, {
            'params': lr_2,
            'lr': self.solver_dict['lr']['base_lr']
        }]
        return params

    def run(self):
        """
          Train function of every epoch during train phase.

          Runs the whole training loop until ``runner_state['iters']`` reaches
          ``solver_dict['max_iters']``.  Each iteration draws one batch from
          every data source, optionally rescales / mixes the images, does one
          optimisation step, and periodically logs, saves and validates.  A
          final validation runs after the loop.
        """
        # With 'resume_val' set the model is validated once up front; training
        # still proceeds afterwards.
        if self.configer.get('network', 'resume_val'):
            self.val()

        self.cls_net.train()
        # One live iterator per data source; each is re-created when exhausted.
        train_loaders = dict()
        for source in self.train_loaders:
            train_loaders[source] = iter(self.train_loaders[source])
        start_time = time.time()
        # Adjust the learning rate after every epoch.
        while self.runner_state['iters'] < self.solver_dict['max_iters']:
            data_dict = dict()
            for source in train_loaders:
                try:
                    tmp_data_dict = next(train_loaders[source])
                    # Log.info('iter={}, source={}'.format(self.runner_state['iters'], source))
                except StopIteration:
                    # Epochs are counted by passes over source 0 only.
                    if source == 0 or source == '0':
                        self.runner_state['epoch'] += 1
                    # Log.info('Repeat: iter={}, source={}'.format(self.runner_state['iters'], source))
                    train_loaders[source] = iter(self.train_loaders[source])
                    tmp_data_dict = next(train_loaders[source])
                # Merge into one dict, prefixing every key with its source,
                # e.g. 'img' from source 0 becomes 'src0_img'.
                for k, v in tmp_data_dict.items():
                    data_dict['src{}_{}'.format(source, k)] = v

            # Multi-scale training: every '*_img' entry is resized by one
            # random ratio drawn between the first and last configured values.
            if self.configer.get('data', 'multiscale') is not None:
                scale_ratios = self.configer.get('data', 'multiscale')
                scale_ratio = random.uniform(scale_ratios[0], scale_ratios[-1])
                for key in data_dict:
                    if key.endswith('_img'):
                        data_dict[key] = F.interpolate(
                            data_dict[key],
                            scale_factor=[scale_ratio, scale_ratio],
                            mode='bilinear',
                            align_corners=True)
            # Mixup: a shrunken copy of the src0 batch (0.4-0.6 of its size per
            # axis) is pasted into a random corner of the src1 batch, in place.
            if self.configer.get('data', 'mixup'):
                src0_resize = F.interpolate(data_dict['src0_img'],
                                            scale_factor=[
                                                random.uniform(0.4, 0.6),
                                                random.uniform(0.4, 0.6)
                                            ],
                                            mode='bilinear',
                                            align_corners=True)
                b, c, h, w = src0_resize.shape
                pos = random.randint(0, 3)
                if pos == 0:  # top-left
                    data_dict['src1_img'][:, :, 0:h, 0:w] = src0_resize
                elif pos == 1:  # top-right
                    data_dict['src1_img'][:, :, 0:h, -w:] = src0_resize
                elif pos == 2:  # bottom-left
                    data_dict['src1_img'][:, :, -h:, 0:w] = src0_resize
                else:  # bottom-right
                    data_dict['src1_img'][:, :, -h:, -w:] = src0_resize

            data_dict = RunnerHelper.to_device(self, data_dict)
            # Scheduler / warm-up update; group 0 is the backbone group built
            # by _get_parameters(), with its scaled base learning rate.
            Trainer.update(
                self,
                warm_list=(0, ),
                warm_lr_list=(self.solver_dict['lr']['base_lr'] *
                              self.configer.get('network', 'bb_lr_scale'), ),
                solver_dict=self.solver_dict)
            self.data_time.update(time.time() - start_time)
            # Forward pass.
            out = self.cls_net(data_dict)
            loss_dict, loss_weight_dict = self.loss(out)
            # Compute the loss of the train batch & backward.
            loss = loss_dict['loss']
            self.train_losses.update(
                {key: loss.item()
                 for key, loss in loss_dict.items()},
                data_dict['src0_img'].size(0))
            self.optimizer.zero_grad()
            # fp16 runs scale the loss through amp before back-propagating.
            if self.configer.get('dtype') == 'fp16':
                with amp.scale_loss(loss, self.optimizer) as scaled_losses:
                    scaled_losses.backward()
            else:
                loss.backward()
            if self.configer.get('network', 'clip_grad'):
                RunnerHelper.clip_grad(self.cls_net, 10.)

            self.optimizer.step()

            # Update the vars of the train phase.
            self.batch_time.update(time.time() - start_time)
            start_time = time.time()
            self.runner_state['iters'] += 1

            # Print the log info & reset the states.
            if self.runner_state['iters'] % self.solver_dict[
                    'display_iter'] == 0:
                Log.info(
                    'Train Epoch: {0}\tTrain Iteration: {1}\t'
                    'Time {batch_time.sum:.3f}s / {2}iters, ({batch_time.avg:.3f})\t'
                    'Data load {data_time.sum:.3f}s / {2}iters, ({data_time.avg:3f})\n'
                    'Learning rate = {5}\tLoss = {3}\nLossWeight = {4}\n'.
                    format(self.runner_state['epoch'],
                           self.runner_state['iters'],
                           self.solver_dict['display_iter'],
                           self.train_losses.info(),
                           loss_weight_dict,
                           RunnerHelper.get_lr(self.optimizer),
                           batch_time=self.batch_time,
                           data_time=self.data_time))

                self.batch_time.reset()
                self.data_time.reset()
                self.train_losses.reset()

            # Final checkpoint: written by local rank 0 once the iteration
            # budget is exhausted (iteration-based LR metric only).
            if self.solver_dict['lr'][
                    'metric'] == 'iters' and self.runner_state[
                        'iters'] == self.solver_dict['max_iters']:
                if self.configer.get('local_rank') == 0:
                    RunnerHelper.save_net(self, self.cls_net, postfix='final')
                break

            # Periodic checkpoint, local rank 0 only.
            if self.runner_state['iters'] % self.solver_dict[
                    'save_iters'] == 0 and self.configer.get(
                        'local_rank') == 0:
                RunnerHelper.save_net(self, self.cls_net)

            # Check to val the current model.
            if self.runner_state['iters'] % self.solver_dict[
                    'test_interval'] == 0:
                self.val()
                if self.configer.get('local_rank') == 0:
                    RunnerHelper.save_net(
                        self,
                        self.cls_net,
                        performance=self.runner_state['performance'])

        self.val()

    def val(self):
        """
          Validation function during the train phase.

          Draws one batch from every validation source per step.  A source
          that runs out is restarted, so it contributes repeated batches until
          every source has been exhausted at least once.  Logs the losses and
          the top-1/3/5 accuracies, stores the first top-1 entry whose key
          contains 'src0_label0' in ``runner_state['performance']``, resets
          the meters and puts the network back in train mode.
        """
        self.cls_net.eval()
        start_time = time.time()
        val_loaders = dict()
        val_to_end = dict()
        for source in self.val_loaders:
            val_loaders[source] = iter(self.val_loaders[source])
            val_to_end[source] = False

        all_to_end = False
        with torch.no_grad():
            while not all_to_end:
                data_dict = dict()
                for source in val_loaders:
                    try:
                        tmp_data_dict = next(val_loaders[source])
                    except StopIteration:
                        # Mark the source as finished, but restart it so every
                        # step still has a batch from each source.
                        val_to_end[source] = True
                        val_loaders[source] = iter(self.val_loaders[source])
                        tmp_data_dict = next(val_loaders[source])
                    for k, v in tmp_data_dict.items():
                        data_dict['src{}_{}'.format(source, k)] = v
                # Forward pass.
                data_dict = RunnerHelper.to_device(self, data_dict)
                out = self.cls_net(data_dict)
                loss_dict, loss_weight_dict = self.loss(out)
                out_dict, label_dict, _ = RunnerHelper.gather(self, out)
                # Compute the loss of the val batch.
                self.running_score.update(out_dict, label_dict)
                self.val_losses.update(
                    {
                        key: loss.mean().item()
                        for key, loss in loss_dict.items()
                    }, data_dict['src0_img'].size(0))
                # Update the vars of the val phase.
                self.batch_time.update(time.time() - start_time)
                start_time = time.time()
                # Stop once every source has been scanned completely.
                all_to_end = all(val_to_end.values())

            Log.info('Test Time {batch_time.sum:.3f}s'.format(
                batch_time=self.batch_time))
            Log.info('TestLoss = {}'.format(self.val_losses.info()))
            Log.info('TestLossWeight = {}'.format(loss_weight_dict))
            Log.info('Top1 ACC = {}'.format(self.running_score.get_top1_acc()))
            Log.info('Top3 ACC = {}'.format(self.running_score.get_top3_acc()))
            Log.info('Top5 ACC = {}'.format(self.running_score.get_top5_acc()))
            # safe_load instead of load: yaml.load without an explicit Loader
            # raises TypeError on PyYAML >= 6 and can build arbitrary objects
            # on older versions; only plain mappings are expected here.
            top1_acc = yaml.safe_load(self.running_score.get_top1_acc())
            for key in top1_acc:
                if 'src0_label0' in key:
                    self.runner_state['performance'] = top1_acc[key]
                    Log.info('Use acc of {} to compare performace'.format(key))
                    break
            self.running_score.reset()
            self.val_losses.reset()
            self.batch_time.reset()
            self.cls_net.train()
# Example #7
# 0
class ImageClassifier(object):
    """
      The class for the training phase of Image classification.
    """
    def __init__(self, configer):
        """
          Build every component of the classification runner from configer:
          state counters, meters, model, optimizer/scheduler, loaders and loss.
        """
        self.configer = configer
        # Counters and best/last metrics tracked over the whole run.
        self.runner_state = {
            'iters': 0,
            'last_iters': 0,
            'epoch': 0,
            'last_epoch': 0,
            'performance': 0,
            'val_loss': 0,
            'max_performance': 0,
            'min_val_loss': 0,
        }

        # Timing meters plus per-key loss meters.
        self.batch_time, self.data_time = AverageMeter(), AverageMeter()
        self.train_losses, self.val_losses = (DictAverageMeter(),
                                              DictAverageMeter())
        self.cls_model_manager = ModelManager(configer)
        self.cls_data_loader = DataLoader(configer)
        self.running_score = RunningScore(configer)

        self.cls_net = self.cls_model_manager.get_model()
        # The solver section is looked up by the name stored in train.solver.
        solver_name = self.configer.get('train', 'solver')
        self.solver_dict = self.configer.get(solver_name)
        # _get_parameters() reads self.cls_net and self.solver_dict,
        # so both must be set before this point.
        param_groups = self._get_parameters()
        self.optimizer, self.scheduler = Trainer.init(param_groups,
                                                      self.solver_dict)
        self.cls_net = RunnerHelper.load_net(self, self.cls_net)
        converted = RunnerHelper.to_dtype(self, self.cls_net, self.optimizer)
        self.cls_net, self.optimizer = converted

        self.train_loader = self.cls_data_loader.get_trainloader()
        self.val_loader = self.cls_data_loader.get_valloader()
        self.loss = self.cls_model_manager.get_loss()

    def _get_parameters(self):
        lr_1 = []
        lr_2 = []
        params_dict = dict(self.cls_net.named_parameters())
        for key, value in params_dict.items():
            if value.requires_grad:
                if 'backbone' in key:
                    if self.configer.get('network', 'bb_lr_scale') == 0.0:
                        value.requires_grad = False
                    else:
                        lr_1.append(value)
                else:
                    lr_2.append(value)

        params = [{
            'params':
            lr_1,
            'lr':
            self.solver_dict['lr']['base_lr'] *
            self.configer.get('network', 'bb_lr_scale')
        }, {
            'params': lr_2,
            'lr': self.solver_dict['lr']['base_lr']
        }]
        return params

    def run(self):
        if self.configer.get('network', 'resume_val'):
            self.val()

        while self.runner_state['iters'] < self.solver_dict['max_iters']:
            self.train()

        self.val()

    def train(self):
        """
          Train function of every epoch during train phase.

          Makes one pass over self.train_loader. Every batch goes through
          Trainer.update, a forward pass, the loss, backward and an optimizer
          step. The iteration counter in runner_state then drives logging
          (display_iter), checkpointing (save_iters, local_rank 0 only) and
          validation (test_interval). The pass stops early once max_iters is
          reached when the lr metric is 'iters'.
        """
        self.cls_net.train()
        start_time = time.time()
        # One call of this method counts as one epoch.
        self.runner_state['epoch'] += 1
        for data_dict in self.train_loader:
            # Prefix every key with 'src0_'; 'src0_img' is read below.
            data_dict = {'src0_{}'.format(k): v for k, v in data_dict.items()}
            # Per-iteration solver update. NOTE(review): warm_list=(0, )
            # presumably selects param group 0 (the backbone group) for
            # warm-up -- confirm against Trainer.update.
            Trainer.update(
                self,
                warm_list=(0, ),
                warm_lr_list=(self.solver_dict['lr']['base_lr'] *
                              self.configer.get('network', 'bb_lr_scale'), ),
                solver_dict=self.solver_dict)
            self.data_time.update(time.time() - start_time)
            data_dict = RunnerHelper.to_device(self, data_dict)
            # Forward pass.
            out = self.cls_net(data_dict)
            # The second return value of the loss is not needed here.
            loss_dict, _ = self.loss(out)

            # Record every loss term, then back-propagate the 'loss' entry.
            loss = loss_dict['loss']
            self.train_losses.update(
                {key: value.item()
                 for key, value in loss_dict.items()},
                data_dict['src0_img'].size(0))
            self.optimizer.zero_grad()
            loss.backward()
            if self.configer.get('network', 'clip_grad'):
                RunnerHelper.clip_grad(self.cls_net, 10.)

            self.optimizer.step()

            # Update the vars of the train phase.
            self.batch_time.update(time.time() - start_time)
            start_time = time.time()
            self.runner_state['iters'] += 1

            # Print the log info & reset the states.
            if self.runner_state['iters'] % self.solver_dict[
                    'display_iter'] == 0:
                # data_time.avg uses '.3f' like batch_time.avg; the former
                # ':3f' was a width spec that printed six decimals.
                Log.info(
                    'Train Epoch: {0}\tTrain Iteration: {1}\t'
                    'Time {batch_time.sum:.3f}s / {2}iters, ({batch_time.avg:.3f})\t'
                    'Data load {data_time.sum:.3f}s / {2}iters, ({data_time.avg:.3f})\n'
                    'Learning rate = {4}\tLoss = {3}\n'.format(
                        self.runner_state['epoch'],
                        self.runner_state['iters'],
                        self.solver_dict['display_iter'],
                        self.train_losses.info(),
                        RunnerHelper.get_lr(self.optimizer),
                        batch_time=self.batch_time,
                        data_time=self.data_time))

                self.batch_time.reset()
                self.data_time.reset()
                self.train_losses.reset()

            # Stop as soon as the iteration budget is used up. The save and
            # val checks below are skipped for that last iteration; run()
            # performs the final validation itself.
            if self.solver_dict['lr'][
                    'metric'] == 'iters' and self.runner_state[
                        'iters'] == self.solver_dict['max_iters']:
                break

            # Only the process with local_rank 0 writes checkpoints.
            if self.runner_state['iters'] % self.solver_dict[
                    'save_iters'] == 0 and self.configer.get(
                        'local_rank') == 0:
                RunnerHelper.save_net(self, self.cls_net)

            # Check to val the current model.
            if self.runner_state['iters'] % self.solver_dict[
                    'test_interval'] == 0:
                self.val()

    def val(self):
        """
          Validation function during the train phase.

          Runs cls_net in eval mode over self.val_loader without gradients,
          accumulates the losses and the running score, logs the loss and the
          top-1/3/5 accuracies, resets the meters and puts cls_net back into
          train mode.
        """
        self.cls_net.eval()
        start_time = time.time()
        with torch.no_grad():
            for data_dict in self.val_loader:
                # Prefix every key with 'src0_'; 'src0_img' is read below.
                data_dict = {
                    'src0_{}'.format(k): v
                    for k, v in data_dict.items()
                }
                # Forward pass.
                data_dict = RunnerHelper.to_device(self, data_dict)
                out = self.cls_net(data_dict)
                # self.loss returns a pair (train() unpacks it the same way);
                # only the dict of loss terms is needed here.
                loss_dict, _ = self.loss(out)
                out_dict, label_dict, _ = RunnerHelper.gather(self, out)
                self.running_score.update(out_dict, label_dict)
                self.val_losses.update(
                    {key: value.item()
                     for key, value in loss_dict.items()},
                    data_dict['src0_img'].size(0))

                # Update the vars of the val phase.
                self.batch_time.update(time.time() - start_time)
                start_time = time.time()

            # Print the log info & reset the states.
            Log.info('Test Time {batch_time.sum:.3f}s'.format(
                batch_time=self.batch_time))
            Log.info('TestLoss = {}'.format(self.val_losses.info()))
            Log.info('Top1 ACC = {}'.format(self.running_score.get_top1_acc()))
            Log.info('Top3 ACC = {}'.format(self.running_score.get_top3_acc()))
            Log.info('Top5 ACC = {}'.format(self.running_score.get_top5_acc()))
            self.batch_time.reset()
            self.val_losses.reset()
            self.running_score.reset()
            self.cls_net.train()
示例#8
0
class Trainer(object):
    """
      The class for Semantic Segmentation. Includes train, val & predict.
    """
    def __init__(self, configer):
        """
          Create the meters and helper objects from configer, then build the
          model, optimizer, scheduler, loaders and loss via _init_model().
        """
        self.configer = configer
        # Timing and loss meters.
        self.batch_time, self.data_time = AverageMeter(), AverageMeter()
        self.train_losses, self.val_losses = AverageMeter(), AverageMeter()
        # Helper objects, all configured from the same configer.
        self.running_score = RunningScore(configer)
        self.seg_visualizer = SegVisualizer(configer)
        self.loss_manager = LossManager(configer)
        self.module_runner = ModuleRunner(configer)
        self.model_manager = ModelManager(configer)
        self.data_loader = DataLoader(configer)
        self.optim_scheduler = OptimScheduler(configer)

        # Placeholders that _init_model() fills in.
        self.seg_net = self.train_loader = self.val_loader = None
        self.optimizer = self.scheduler = None

        self._init_model()

    def _init_model(self):
        """
          Build the segmentation net, its optimizer/scheduler, the train/val
          loaders and the pixel loss.

          optim.group_method picks the param grouping: 'decay' uses
          group_weight(); otherwise it must be None and _get_parameters()
          is used.
        """
        self.seg_net = self.model_manager.semantic_segmentor()
        self.seg_net = self.module_runner.load_net(self.seg_net)

        group_method = self.configer.get('optim', 'group_method')
        Log.info('Params Group Method: {}'.format(group_method))
        if group_method != 'decay':
            # Any value other than 'decay' or None is rejected.
            assert group_method is None
            params_group = self._get_parameters()
        else:
            params_group = self.group_weight(self.seg_net)

        optim_pair = self.optim_scheduler.init_optimizer(params_group)
        self.optimizer, self.scheduler = optim_pair

        self.train_loader = self.data_loader.get_trainloader()
        self.val_loader = self.data_loader.get_valloader()

        self.pixel_loss = self.loss_manager.get_seg_loss()

    @staticmethod
    def group_weight(module):
        group_decay = []
        group_no_decay = []
        for m in module.modules():
            if isinstance(m, nn.Linear):
                group_decay.append(m.weight)
                if m.bias is not None:
                    group_no_decay.append(m.bias)
            elif isinstance(m, nn.modules.conv._ConvNd):
                group_decay.append(m.weight)
                if m.bias is not None:
                    group_no_decay.append(m.bias)
            else:
                if hasattr(m, 'weight'):
                    group_no_decay.append(m.weight)
                if hasattr(m, 'bias'):
                    group_no_decay.append(m.bias)

        assert len(list(
            module.parameters())) == len(group_decay) + len(group_no_decay)
        groups = [
            dict(params=group_decay),
            dict(params=group_no_decay, weight_decay=.0)
        ]
        return groups

    def _get_parameters(self):
        bb_lr = []
        nbb_lr = []
        params_dict = dict(self.seg_net.named_parameters())
        for key, value in params_dict.items():
            if 'backbone' not in key:
                nbb_lr.append(value)
            else:
                bb_lr.append(value)

        params = [{
            'params': bb_lr,
            'lr': self.configer.get('lr', 'base_lr')
        }, {
            'params':
            nbb_lr,
            'lr':
            self.configer.get('lr', 'base_lr') *
            self.configer.get('lr', 'nbb_mult')
        }]
        return params

    def __train(self):
        """
          Train function of every epoch during train phase.

          Makes one pass over self.train_loader. Every batch steps the
          scheduler (on 'iters' or 'epoch' depending on lr.metric), optionally
          applies lr warm-up, runs forward/loss/backward/step and increments
          the 'iters' counter in configer. That counter drives logging
          (solver.display_iter) and validation (solver.test_interval); the pass
          stops early at solver.max_iters. The 'epoch' counter is incremented
          when the pass ends.
        """
        self.seg_net.train()
        start_time = time.time()

        for data_dict in self.train_loader:
            # Step the scheduler on the counter selected by lr.metric.
            if self.configer.get('lr', 'metric') == 'iters':
                self.scheduler.step(self.configer.get('iters'))
            else:
                self.scheduler.step(self.configer.get('epoch'))

            if self.configer.get('lr', 'is_warm'):
                # Param group 0 is passed as the backbone group.
                self.module_runner.warm_lr(self.configer.get('iters'),
                                           self.scheduler,
                                           self.optimizer,
                                           backbone_list=[
                                               0,
                                           ])
            inputs = data_dict['img']
            targets = data_dict['labelmap']
            self.data_time.update(time.time() - start_time)
            # NOTE(review): unlike the val path, inputs/targets are not passed
            # through module_runner.to_device here -- presumably the wrapped
            # net and loss handle device placement; confirm.

            # Forward pass.
            outputs = self.seg_net(inputs)
            # Compute the loss of the train batch & backward.
            loss = self.pixel_loss(outputs,
                                   targets,
                                   gathered=self.configer.get(
                                       'network', 'gathered'))
            if self.configer.exists('train', 'loader') and self.configer.get(
                    'train', 'loader') == 'ade20k':
                # For the 'ade20k' loader the sample count is taken from the
                # config instead of from the input.
                batch_size = self.configer.get(
                    'train', 'batch_size') * self.configer.get(
                        'train', 'batch_per_gpu')
                self.train_losses.update(loss.item(), batch_size)
            else:
                self.train_losses.update(loss.item(), inputs.size(0))

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Update the vars of the train phase.
            self.batch_time.update(time.time() - start_time)
            start_time = time.time()
            self.configer.plus_one('iters')

            # Print the log info & reset the states.
            if self.configer.get('iters') % self.configer.get(
                    'solver', 'display_iter') == 0:
                # data_time.avg uses '.3f' like batch_time.avg; the former
                # ':3f' was a width spec that printed six decimals.
                Log.info(
                    'Train Epoch: {0}\tTrain Iteration: {1}\t'
                    'Time {batch_time.sum:.3f}s / {2}iters, ({batch_time.avg:.3f})\t'
                    'Data load {data_time.sum:.3f}s / {2}iters, ({data_time.avg:.3f})\n'
                    'Learning rate = {3}\tLoss = {loss.val:.8f} (ave = {loss.avg:.8f})\n'
                    .format(self.configer.get('epoch'),
                            self.configer.get('iters'),
                            self.configer.get('solver', 'display_iter'),
                            self.module_runner.get_lr(self.optimizer),
                            batch_time=self.batch_time,
                            data_time=self.data_time,
                            loss=self.train_losses))
                self.batch_time.reset()
                self.data_time.reset()
                self.train_losses.reset()

            # Stop as soon as the iteration budget is used up.
            if self.configer.get('iters') == self.configer.get(
                    'solver', 'max_iters'):
                break

            # Check to val the current model.
            if self.configer.get('iters') % self.configer.get(
                    'solver', 'test_interval') == 0:
                self.__val()

        self.configer.plus_one('epoch')

    def __val(self, data_loader=None):
        """
          Validation function during the train phase.

          Evaluates seg_net on data_loader (self.val_loader when None),
          stores the mean IoU and the average val loss in configer, asks the
          module runner to save the net in 'performance' and 'val_loss' modes,
          logs the metrics, resets the meters and switches seg_net back to
          train mode.
        """
        self.seg_net.eval()
        start_time = time.time()

        loader = data_loader if data_loader is not None else self.val_loader
        for data_dict in loader:
            inputs, targets = data_dict['img'], data_dict['labelmap']

            with torch.no_grad():
                # Change the data type.
                inputs, targets = self.module_runner.to_device(inputs, targets)
                # Forward pass.
                outputs = self.seg_net(inputs)
                # Compute the loss of the val batch.
                gathered = self.configer.get('network', 'gathered')
                loss = self.pixel_loss(outputs, targets, gathered=gathered)
                outputs = self.module_runner.gather(outputs)

            self.val_losses.update(loss.item(), inputs.size(0))
            # Only the last entry of the gathered outputs is scored.
            self._update_running_score(outputs[-1], data_dict['meta'])

            # Update the vars of the val phase.
            self.batch_time.update(time.time() - start_time)
            start_time = time.time()

        # Publish the metrics before saving so save_net can read them.
        self.configer.update(['performance'],
                             self.running_score.get_mean_iou())
        self.configer.update(['val_loss'], self.val_losses.avg)
        for mode in ('performance', 'val_loss'):
            self.module_runner.save_net(self.seg_net, save_mode=mode)

        # Print the log info & reset the states.
        summary = ('Test Time {batch_time.sum:.3f}s, ({batch_time.avg:.3f})\t'
                   'Loss {loss.avg:.8f}\n')
        Log.info(summary.format(batch_time=self.batch_time,
                                loss=self.val_losses))
        Log.info('Mean IOU: {}\n'.format(self.running_score.get_mean_iou()))
        Log.info('Pixel ACC: {}\n'.format(self.running_score.get_pixel_acc()))
        self.batch_time.reset()
        self.val_losses.reset()
        self.running_score.reset()
        self.seg_net.train()

    def _update_running_score(self, pred, metas):
        """
          Feed one batch of predictions into running_score.

          For every sample the logits are cropped to the sample's border_size,
          resized to its ori_img_size with cubic interpolation, reduced to a
          label map by argmax over the last axis and compared with ori_target.
          Assumes pred is laid out as (batch, classes, height, width) before
          the permute -- TODO confirm against the model output.
        """
        channels_last = pred.permute(0, 2, 3, 1)
        for idx in range(channels_last.size(0)):
            meta = metas[idx]
            ori_img_size = meta['ori_img_size']
            border_size = meta['border_size']
            ori_target = meta['ori_target']
            # border_size[1] limits axis 1 and border_size[0] limits axis 2.
            row_limit, col_limit = border_size[1], border_size[0]
            cropped = channels_last[idx, :row_limit, :col_limit].cpu().numpy()
            total_logits = cv2.resize(cropped,
                                      tuple(ori_img_size),
                                      interpolation=cv2.INTER_CUBIC)
            labelmap = np.argmax(total_logits, axis=-1)
            # A leading axis of size 1 is added to both arrays.
            self.running_score.update(labelmap[None], ori_target[None])

    def train(self):
        # cudnn.benchmark = True
        if self.configer.get('network',
                             'resume') is not None and self.configer.get(
                                 'network', 'resume_val'):
            self.__val()

        while self.configer.get('iters') < self.configer.get(
                'solver', 'max_iters'):
            self.__train()

        self.__val(data_loader=self.data_loader.get_valloader(dataset='val'))
        self.__val(data_loader=self.data_loader.get_valloader(dataset='train'))