class AutoSKUMerger(object):
    """
      The class for the training phase of Image classification.
    """
    def __init__(self, configer):
        self.configer = configer
        self.cls_net = None
        self.train_loader = None
        self.val_loader = None
        self.optimizer = None
        self.scheduler = None
        self.runner_state = None
        self.round = 1

        self._relabel()

    def _relabel(self):
        """
          Rewrite the label file so that labels become consecutive integer ids.

          Reads the file at ``data.label_path`` (whitespace separated, first field is an image
          path relative to ``data.data_dir``, second field is the label), keeps only entries
          whose image exists on disk, and writes ``<path> <id>`` lines to ``<label_path>_new``.
          Ids are assigned in order of first appearance of each label.

          Side effects on the configuration:
            * ``data.label_path`` is updated to the ``_new`` file,
            * ``data.num_classes`` is updated to ``[<number of distinct labels>]``.

          Lines with fewer than two fields (e.g. blank lines) are skipped. If the same image
          path is listed twice an error is logged and the process exits.
        """
        old_label_path = self.configer.get('data', 'label_path')
        new_label_path = '{}_new'.format(old_label_path)
        self.configer.update('data.label_path', new_label_path)
        data_dir = self.configer.get('data', 'data_dir')

        label_dict = dict()   # original label -> new integer id
        seen_paths = set()    # image paths already written, for duplicate detection
        # Both files are opened through ``with`` so they are closed even when the duplicate
        # check below terminates the process.
        with open(old_label_path, 'r') as fr, open(new_label_path, 'w') as fw:
            for line in fr:
                line_items = line.split()
                if len(line_items) < 2:
                    # Blank or malformed line: nothing to relabel.
                    continue

                rel_path, label_name = line_items[0], line_items[1]
                if not os.path.exists(os.path.join(data_dir, rel_path)):
                    continue

                if label_name not in label_dict:
                    label_dict[label_name] = len(label_dict)

                if rel_path in seen_paths:
                    Log.error('Duplicate Error: {}'.format(rel_path))
                    exit()

                seen_paths.add(rel_path)
                fw.write('{} {}\n'.format(rel_path, label_dict[label_name]))

        # NOTE(review): ``data.label_path`` was updated above, so this presumably copies the
        # relabeled file (not the untouched input) to ``ori_label.txt`` - confirm that is intended.
        shutil.copy(self.configer.get('data', 'label_path'),
                    os.path.join(self.configer.get('data', 'merge_dir'), 'ori_label.txt'))
        self.configer.update('data.num_classes', [len(label_dict)])
        Log.info('Num Classes is {}...'.format(self.configer.get('data', 'num_classes')))

    def _init(self):
        """
          (Re)create all state needed for one training pass.

          Builds fresh meters, the model manager, data loader factory and running score, resets
          ``runner_state``, creates the network, the optimizer/scheduler from the solver section
          named by ``train.solver``, and fetches the train/val loaders and the loss.
        """
        # Meters for timing and for the train / val losses.
        self.batch_time = AverageMeter()
        self.data_time = AverageMeter()
        self.train_losses = AverageMeter()
        self.val_losses = AverageMeter()
        self.cls_model_manager = ModelManager(self.configer)
        self.cls_data_loader = DataLoader(self.configer)
        self.cls_running_score = RunningScore(self.configer)
        # Counters start from zero every time this method runs.
        self.runner_state = dict(iters=0, last_iters=0, epoch=0,
                                 last_epoch=0, performance=0,
                                 val_loss=0, max_performance=0, min_val_loss=0)
        self.cls_net = self.cls_model_manager.get_model()
        # NOTE(review): presumably restores checkpoint weights / wraps the net - confirm in RunnerHelper.
        self.cls_net = RunnerHelper.load_net(self, self.cls_net)
        # ``train.solver`` holds the name of the config section that describes the solver.
        self.solver_dict = self.configer.get(self.configer.get('train', 'solver'))
        # ``_get_parameters`` reads ``self.cls_net`` and ``self.solver_dict``, so both must be set first.
        self.optimizer, self.scheduler = Trainer.init(self._get_parameters(), self.solver_dict)

        self.cls_net, self.optimizer = RunnerHelper.to_dtype(self, self.cls_net, self.optimizer)
        self.train_loader = self.cls_data_loader.get_trainloader()
        self.val_loader = self.cls_data_loader.get_valloader()
        self.loss = self.cls_model_manager.get_loss()

    def _get_parameters(self):
        lr_1 = []
        lr_2 = []
        params_dict = dict(self.cls_net.named_parameters())
        for key, value in params_dict.items():
            if value.requires_grad:
                if 'backbone' in key:
                    if self.configer.get('network', 'bb_lr_scale') == 0.0:
                        value.requires_grad = False
                    else:
                        lr_1.append(value)
                else:
                    lr_2.append(value)

        params = [{'params': lr_1, 'lr': self.solver_dict['lr']['base_lr']*self.configer.get('network', 'bb_lr_scale')},
                  {'params': lr_2, 'lr': self.solver_dict['lr']['base_lr']}]
        return params

    def _merge_class(self, cmatrix, fscore_list):
        """
          Merge classes that are confused with each other and rewrite the label file.

          Two classes ``i != j`` are linked when ``cmatrix[i][j] / sum(cmatrix[i])`` exceeds
          ``max(merge.min_thre, merge.max_thre - merge.round_decay * round)``. Linked classes are
          grouped transitively and every group receives a single new id. When
          ``merge.enable_fscore`` is set, a class whose F-score is below
          ``merge.fscore_ratio * (round / merge.max_round) * fscore_list[-1]`` does not start a
          group of its own; unless a group partner pulls it in, its samples are dropped.

          The file at ``data.label_path`` is rewritten in place with the new ids, a copy is
          stored as ``<data.merge_dir>/label_<round>.txt`` and ``data.num_classes`` is updated.

          Args:
            cmatrix: 2-D confusion matrix indexed by class id (``val`` builds it with true
              labels as the first ``confusion_matrix`` argument).
            fscore_list: per-class F-scores indexed by class id; the last element is used as
              the average F-score.

          Returns:
            The number of classes removed in this round (old count minus new count).
        """
        Log.info('Merging class...')
        Log.info('Avg F1-score: {}'.format(fscore_list[-1]))
        threshold = max(self.configer.get('merge', 'min_thre'),
                        self.configer.get('merge', 'max_thre')
                        - self.configer.get('merge', 'round_decay') * self.round)
        groups = self._group_confused_classes(cmatrix, threshold)

        num_classes = self.configer.get('data', 'num_classes')[0]
        power = self.round / self.configer.get('merge', 'max_round')
        min_fscore = None
        if self.configer.get('merge', 'enable_fscore'):
            # Same test as ``fscore / avg < ratio * power`` but written as a product, so an
            # average F-score of 0 cannot raise ZeroDivisionError.
            min_fscore = self.configer.get('merge', 'fscore_ratio') * power * fscore_list[-1]

        id_map_list = [-1] * num_classes   # old class id -> new class id, -1 means "dropped"
        label_cnt = 0
        for i in range(num_classes):
            if id_map_list[i] != -1:
                # Already labelled as a member of an earlier group.
                continue

            if min_fscore is not None and fscore_list[i] < min_fscore:
                continue

            # A class without confusion partners forms a group on its own.
            for member in groups.get(i, (i,)):
                id_map_list[member] = label_cnt

            label_cnt += 1

        label_path = self.configer.get('data', 'label_path')
        round_label_path = '{}_{}'.format(label_path, self.round)
        # Write to a temporary file first, then replace the active label file.
        with open(label_path, 'r') as fr, open(round_label_path, 'w') as fw:
            for line in fr:
                line_items = line.split()
                if not line_items:
                    continue

                path, label = line_items
                map_label = id_map_list[int(label)]
                if map_label == -1:
                    continue

                fw.write('{} {}\n'.format(path, map_label))

        shutil.move(round_label_path, label_path)
        shutil.copy(label_path,
                    os.path.join(self.configer.get('data', 'merge_dir'), 'label_{}.txt'.format(self.round)))
        self.configer.update('data.num_classes', [label_cnt])
        return num_classes - label_cnt

    @staticmethod
    def _group_confused_classes(cmatrix, threshold):
        """
          Group class ids that are transitively linked by a confusion ratio above ``threshold``.

          Returns:
            dict mapping every class id that belongs to a group to the list of all ids in that
            group (the id itself included). Classes without any link are absent.
        """
        parent = dict()   # union-find forest over the linked class ids

        def find_root(node):
            root = node
            while parent[root] != root:
                root = parent[root]

            # Path compression: point every visited node straight at the root.
            while node != root:
                next_node = parent[node]
                parent[node] = root
                node = next_node

            return root

        row_totals = np.sum(cmatrix, 1)
        for i in range(cmatrix.shape[0]):
            if row_totals[i] == 0:
                # An empty row has no confusion information; skipping it also avoids 0 / 0.
                continue

            for j in range(cmatrix.shape[1]):
                if i == j:
                    continue

                if cmatrix[i][j] * 1.0 / row_totals[i] > threshold:
                    parent.setdefault(i, i)
                    parent.setdefault(j, j)
                    root_i, root_j = find_root(i), find_root(j)
                    if root_i != root_j:
                        parent[root_i] = root_j

        members_by_root = dict()
        for node in parent:
            members_by_root.setdefault(find_root(node), []).append(node)

        return {node: members for members in members_by_root.values() for node in members}

    def run(self):
        """
          Alternate training and class merging, then train once more on the final labels.

          Each round re-initialises the runner, trains, validates and merges confused classes.
          The loop ends when ``merge.max_round`` is exceeded, when fewer than ``merge.cnt_thre``
          classes were merged, or when accuracy improved by less than ``merge.acc_thre``.
        """
        prev_acc = 0.0
        while self.round <= self.configer.get('merge', 'max_round'):
            Log.info('Merge Round: {}'.format(self.round))
            Log.info('num classes: {}'.format(self.configer.get('data', 'num_classes')))
            self._init()
            self.train()
            eval_loader = self.cls_data_loader.get_valloader()
            acc, cmatrix, fscore_list = self.val(eval_loader)
            removed = self._merge_class(cmatrix, fscore_list)

            # Too few classes merged in this round: stop merging.
            if removed < self.configer.get('merge', 'cnt_thre'):
                break

            # Accuracy did not improve enough over the previous round: stop merging.
            if (acc - prev_acc) < self.configer.get('merge', 'acc_thre'):
                break

            prev_acc = acc
            self.round += 1

        # Keep a copy of the final label file, then train on it from scratch.
        final_labels = self.configer.get('data', 'label_path')
        merged_copy = os.path.join(self.configer.get('data', 'merge_dir'), 'merge_label.txt')
        shutil.copy(final_labels, merged_copy)
        self._init()
        self.train()
        Log.info('num classes: {}'.format(self.configer.get('data', 'num_classes')))

    def train(self):
        """
          Train until ``runner_state['iters']`` reaches ``solver_dict['max_iters']``.

          Iterates over ``self.train_loader`` epoch after epoch. Progress is logged every
          ``display_iter`` iterations and ``val`` runs every ``test_interval`` iterations. When
          ``solver_dict['lr']['metric']`` is ``'iters'`` the current epoch is cut short as soon
          as ``max_iters`` is hit (after a final validation); otherwise the iteration limit is
          only checked between epochs.
        """
        self.cls_net.train()
        start_time = time.time()
        while self.runner_state['iters'] < self.solver_dict['max_iters']:
            self.runner_state['epoch'] += 1
            for data_dict in self.train_loader:
                # NOTE(review): presumably adjusts the learning rate for the current step - confirm in Trainer.update.
                Trainer.update(self, solver_dict=self.solver_dict)
                self.data_time.update(time.time() - start_time)

                # Forward pass and loss of the train batch.
                out = self.cls_net(data_dict)
                loss_dict = self.loss(out)
                loss = loss_dict['loss']
                self.train_losses.update(loss.item(), data_dict['img'].size(0))

                # Backward pass and parameter update.
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # Update the vars of the train phase.
                self.batch_time.update(time.time() - start_time)
                start_time = time.time()
                self.runner_state['iters'] += 1

                # Print the log info & reset the states.
                if self.runner_state['iters'] % self.solver_dict['display_iter'] == 0:
                    # ``data_time.avg`` uses ``.3f`` like the other timings (was ``3f``, i.e. a
                    # minimum width of 3 with default precision).
                    Log.info('Train Epoch: {0}\tTrain Iteration: {1}\t'
                             'Time {batch_time.sum:.3f}s / {2}iters, ({batch_time.avg:.3f})\t'
                             'Data load {data_time.sum:.3f}s / {2}iters, ({data_time.avg:.3f})\n'
                             'Learning rate = {3}\tLoss = {loss.val:.8f} (ave = {loss.avg:.8f})\n'.format(
                                 self.runner_state['epoch'], self.runner_state['iters'],
                                 self.solver_dict['display_iter'],
                                 RunnerHelper.get_lr(self.optimizer), batch_time=self.batch_time,
                                 data_time=self.data_time, loss=self.train_losses))

                    self.batch_time.reset()
                    self.data_time.reset()
                    self.train_losses.reset()

                # Iteration-based schedules stop exactly at max_iters, after one last validation.
                if self.solver_dict['lr']['metric'] == 'iters' \
                        and self.runner_state['iters'] == self.solver_dict['max_iters']:
                    self.val()
                    break

                # Check to val the current model.
                if self.runner_state['iters'] % self.solver_dict['test_interval'] == 0:
                    self.val()

    def val(self, loader=None):
        """
          Evaluate the current network and collect per-class statistics.

          Args:
            loader: iterable of validation batches; ``self.val_loader`` is used when ``None``.

          Returns:
            Tuple ``(acc, cmatrix, fscore_list)``:
              * ``acc`` - ``top1_acc.avg['out0']`` of the running score,
              * ``cmatrix`` - confusion matrix with one row/column per class id in
                ``range(data.num_classes[0])``,
              * ``fscore_list`` - F-scores parsed from the text classification report: one entry
                per class id followed by the summary rows; the last entry is the final
                (average) row.

          The network is saved through ``RunnerHelper.save_net`` and switched back to train
          mode before returning.
        """
        self.cls_net.eval()
        start_time = time.time()

        loader = self.val_loader if loader is None else loader
        list_y_true, list_y_pred = [], []
        with torch.no_grad():
            for data_dict in loader:
                out = self.cls_net(data_dict)
                loss_dict = self.loss(out)
                out_dict, label_dict, _ = RunnerHelper.gather(self, out)
                self.cls_running_score.update(out_dict, label_dict)
                # Ground-truth ids and arg-max predictions of the 'out0' head.
                y_true = label_dict['out0'].view(-1).cpu().numpy().tolist()
                y_pred = out_dict['out0'].max(1)[1].view(-1).cpu().numpy().tolist()
                list_y_true.extend(y_true)
                list_y_pred.extend(y_pred)

                self.val_losses.update(loss_dict['loss'].mean().item(), data_dict['img'].size(0))

                # Update the vars of the val phase.
                self.batch_time.update(time.time() - start_time)
                start_time = time.time()

            RunnerHelper.save_net(self, self.cls_net, performance=self.cls_running_score.top1_acc.avg['out0'])
            self.runner_state['performance'] = self.cls_running_score.top1_acc.avg['out0']
            # Print the log info & reset the states.
            Log.info('Test Time {batch_time.sum:.3f}s'.format(batch_time=self.batch_time))
            Log.info('Test Set: {} images'.format(len(list_y_true)))
            Log.info('TestLoss = {loss.avg:.8f}'.format(loss=self.val_losses))
            Log.info('Top1 ACC = {}'.format(self.cls_running_score.top1_acc.avg['out0']))
            acc = self.cls_running_score.top1_acc.avg['out0']

            # Pass the full id range explicitly: without ``labels`` both functions only cover
            # the ids that actually occur, so a class missing from the validation data would
            # shift every following row while ``_merge_class`` indexes rows by class id.
            class_ids = list(range(self.configer.get('data', 'num_classes')[0]))
            cmatrix = confusion_matrix(list_y_true, list_y_pred, labels=class_ids)
            fscore_str = classification_report(list_y_true, list_y_pred, labels=class_ids, digits=5)
            # Skip the header and the blank line; the F-score is the second to last column.
            fscore_list = [float(line.strip().split()[-2])
                           for line in fscore_str.split('\n')[2:] if len(line.strip().split()) > 0]
            self.batch_time.reset()
            self.val_losses.reset()
            self.cls_running_score.reset()
            self.cls_net.train()
            return acc, cmatrix, fscore_list
# Example #2
# 0
class Trainer(object):
    """
      The class for Pose Estimation. Include train, val, val & predict.
    """
    def __init__(self, configer):
        """
          Create the meters and helper objects, then build the model via ``_init_model``.

          Args:
            configer: configuration object handed to every helper.
        """
        self.configer = configer

        # One AverageMeter per timed phase / tracked loss.
        for meter_name in ('batch_time', 'foward_time', 'backward_time', 'loss_time',
                           'data_time', 'train_losses', 'val_losses'):
            setattr(self, meter_name, AverageMeter())

        self.seg_visualizer = SegVisualizer(configer)
        self.loss_manager = LossManager(configer)
        self.module_runner = ModuleRunner(configer)
        self.model_manager = ModelManager(configer)
        self.data_loader = DataLoader(configer)
        self.optim_scheduler = OptimScheduler(configer)
        # These two also receive the trainer itself, so they are created last.
        self.data_helper = DataHelper(configer, self)
        self.evaluator = get_evaluator(configer, self)

        # Filled in by ``_init_model`` (``running_score`` stays unset there).
        self.seg_net = self.train_loader = self.val_loader = None
        self.optimizer = self.scheduler = None
        self.running_score = None

        self._init_model()

    def _init_model(self):
        """
          Build the network, optimizer, scheduler, data loaders and loss.

          ``optim.group_method`` selects the parameter grouping: ``'decay'`` uses
          ``group_weight``; the only other accepted value is ``None``, which uses
          ``_get_parameters``.
        """
        self.seg_net = self.module_runner.load_net(self.model_manager.semantic_segmentor())

        group_method = self.configer.get('optim', 'group_method')
        Log.info('Params Group Method: {}'.format(group_method))
        if group_method != 'decay':
            assert self.configer.get('optim', 'group_method') is None
            params_group = self._get_parameters()
        else:
            params_group = self.group_weight(self.seg_net)

        self.optimizer, self.scheduler = self.optim_scheduler.init_optimizer(params_group)

        self.train_loader = self.data_loader.get_trainloader()
        self.val_loader = self.data_loader.get_valloader()

        pixel_loss = self.loss_manager.get_seg_loss()
        # In distributed runs the loss is passed through ``module_runner.to_device``.
        self.pixel_loss = pixel_loss
        if is_distributed():
            self.pixel_loss = self.module_runner.to_device(pixel_loss)

    @staticmethod
    def group_weight(module):
        group_decay = []
        group_no_decay = []
        for m in module.modules():
            if isinstance(m, nn.Linear):
                group_decay.append(m.weight)
                if m.bias is not None:
                    group_no_decay.append(m.bias)
            elif isinstance(m, nn.modules.conv._ConvNd):
                group_decay.append(m.weight)
                if m.bias is not None:
                    group_no_decay.append(m.bias)
            else:
                if hasattr(m, 'weight'):
                    group_no_decay.append(m.weight)
                if hasattr(m, 'bias'):
                    group_no_decay.append(m.bias)

        assert len(list(module.parameters())) == len(group_decay) + len(group_no_decay)
        groups = [dict(params=group_decay), dict(params=group_no_decay, weight_decay=.0)]
        return groups

    def _get_parameters(self):
        bb_lr = []
        nbb_lr = []
        params_dict = dict(self.seg_net.named_parameters())
        for key, value in params_dict.items():
            if 'backbone' not in key:
                nbb_lr.append(value)
            else:
                bb_lr.append(value)

        params = [{'params': bb_lr, 'lr': self.configer.get('lr', 'base_lr')},
                  {'params': nbb_lr, 'lr': self.configer.get('lr', 'base_lr') * self.configer.get('lr', 'nbb_mult')}]
        return params

    def __train(self):
        """
          Run one pass over ``self.train_loader`` (one epoch).

          For every batch: step the scheduler (by iteration or epoch, depending on
          ``lr.metric``), optionally apply warm-up, run forward / loss / backward / optimizer
          step, and update the timing and loss meters. Progress is logged every
          ``solver.display_iter`` iterations (rank 0 only when distributed), validation runs
          every ``solver.test_interval`` iterations, and the pass stops early once
          ``solver.max_iters`` is reached. The ``epoch`` counter is incremented at the end.
        """
        self.seg_net.train()
        self.pixel_loss.train()
        start_time = time.time()

        if "swa" in self.configer.get('lr', 'lr_policy'):
            # The first 75% of the iterations run normally; the remainder is split into steps
            # at which the SWA average is updated.
            normal_max_iters = int(self.configer.get('solver', 'max_iters') * 0.75)
            swa_step_max_iters = (self.configer.get('solver', 'max_iters') - normal_max_iters) // 5 + 1

        if hasattr(self.train_loader.sampler, 'set_epoch'):
            self.train_loader.sampler.set_epoch(self.configer.get('epoch'))

        for data_dict in self.train_loader:
            if self.configer.get('lr', 'metric') == 'iters':
                self.scheduler.step(self.configer.get('iters'))
            else:
                self.scheduler.step(self.configer.get('epoch'))

            if self.configer.get('lr', 'is_warm'):
                self.module_runner.warm_lr(
                    self.configer.get('iters'),
                    self.scheduler, self.optimizer, backbone_list=[0,]
                )

            (inputs, targets), batch_size = self.data_helper.prepare_data(data_dict)
            self.data_time.update(time.time() - start_time)

            foward_start_time = time.time()
            outputs = self.seg_net(*inputs)
            self.foward_time.update(time.time() - foward_start_time)

            loss_start_time = time.time()
            if is_distributed():
                backward_loss = self.pixel_loss(outputs, targets)
                display_loss = self._reduce_loss(backward_loss) / get_world_size()
            else:
                backward_loss = display_loss = self.pixel_loss(outputs, targets, gathered=self.configer.get('network', 'gathered'))

            self.train_losses.update(display_loss.item(), batch_size)
            self.loss_time.update(time.time() - loss_start_time)

            backward_start_time = time.time()
            self.optimizer.zero_grad()
            backward_loss.backward()
            self.optimizer.step()
            self.backward_time.update(time.time() - backward_start_time)

            # Update the vars of the train phase.
            self.batch_time.update(time.time() - start_time)
            start_time = time.time()
            self.configer.plus_one('iters')

            # Print the log info & reset the states.
            if self.configer.get('iters') % self.configer.get('solver', 'display_iter') == 0 and \
                (not is_distributed() or get_rank() == 0):
                # ``data_time.avg`` uses ``.3f`` like the other timings (was ``3f``).
                Log.info('Train Epoch: {0}\tTrain Iteration: {1}\t'
                         'Time {batch_time.sum:.3f}s / {2}iters, ({batch_time.avg:.3f})\t'
                         'Forward Time {foward_time.sum:.3f}s / {2}iters, ({foward_time.avg:.3f})\t'
                         'Backward Time {backward_time.sum:.3f}s / {2}iters, ({backward_time.avg:.3f})\t'
                         'Loss Time {loss_time.sum:.3f}s / {2}iters, ({loss_time.avg:.3f})\t'
                         'Data load {data_time.sum:.3f}s / {2}iters, ({data_time.avg:.3f})\n'
                         'Learning rate = {3}\tLoss = {loss.val:.8f} (ave = {loss.avg:.8f})\n'.format(
                         self.configer.get('epoch'), self.configer.get('iters'),
                         self.configer.get('solver', 'display_iter'),
                         self.module_runner.get_lr(self.optimizer), batch_time=self.batch_time,
                         foward_time=self.foward_time, backward_time=self.backward_time, loss_time=self.loss_time,
                         data_time=self.data_time, loss=self.train_losses))
                self.batch_time.reset()
                self.foward_time.reset()
                self.backward_time.reset()
                self.loss_time.reset()
                self.data_time.reset()
                self.train_losses.reset()

            # save checkpoints for swa
            if 'swa' in self.configer.get('lr', 'lr_policy') and \
               self.configer.get('iters') > normal_max_iters and \
               ((self.configer.get('iters') - normal_max_iters) % swa_step_max_iters == 0 or \
                self.configer.get('iters') == self.configer.get('solver', 'max_iters')):
               self.optimizer.update_swa()

            if self.configer.get('iters') == self.configer.get('solver', 'max_iters'):
                break

            # Check to val the current model.
            if self.configer.get('iters') % self.configer.get('solver', 'test_interval') == 0:
                self.__val()

        self.configer.plus_one('epoch')

    @staticmethod
    def _reduce_loss(inp):
        """
          Sum ``inp`` over all processes onto rank 0 for display purposes.

          The reduction works on a detached copy so that the tensor later used for
          ``backward()`` is not overwritten in place. With fewer than two processes ``inp`` is
          returned unchanged.
        """
        import torch.distributed as dist

        if get_world_size() < 2:
            return inp

        with torch.no_grad():
            reduced_inp = inp.detach().clone()
            dist.reduce(reduced_inp, dst=0)
        return reduced_inp


    def __val(self, data_loader=None):
        """
          Evaluate the network, save checkpoints and log the scores.

          Args:
            data_loader: iterable of validation batches; ``self.val_loader`` when ``None``.

          Three evaluation paths exist:
            * dataset ``'lip'``: the batch and its reversed copy are evaluated together and the
              two predictions are averaged (no validation loss is recorded on this path),
            * ``data_helper.conditions.diverse_size``: every input is run on its own replica,
            * otherwise: a regular batched forward pass.

          Afterwards ``val_loss`` is stored in the configuration, the network is saved in
          ``'performance'`` and ``'val_loss'`` mode, the meters and evaluator are reset and the
          network and loss are switched back to train mode.

          Raises:
            AssertionError: re-raised from the loss on the regular path after the output /
              target lengths were printed.
        """
        self.seg_net.eval()
        self.pixel_loss.eval()
        start_time = time.time()
        replicas = self.evaluator.prepare_validaton()

        data_loader = self.val_loader if data_loader is None else data_loader
        for j, data_dict in enumerate(data_loader):
            if j % 10 == 0:
                Log.info('{} images processed\n'.format(j))

            is_lip = self.configer.get('dataset') == 'lip'
            if is_lip:
                (inputs, targets, inputs_rev, targets_rev), batch_size = self.data_helper.prepare_data(data_dict, want_reverse=True)
            else:
                (inputs, targets), batch_size = self.data_helper.prepare_data(data_dict)

            with torch.no_grad():
                if is_lip:
                    inputs = torch.cat([inputs[0], inputs_rev[0]], dim=0)
                    outputs = self.seg_net(inputs)
                    outputs_ = self.module_runner.gather(outputs)
                    if isinstance(outputs_, (list, tuple)):
                        outputs_ = outputs_[-1]

                    # First half of the batch: original inputs, second half: reversed inputs.
                    half = int(outputs_.size(0) / 2)
                    outputs = outputs_[0:half, :, :, :].clone()
                    reversed_part = outputs_[half:int(outputs_.size(0)), :, :, :]
                    outputs_rev = reversed_part.clone()
                    if outputs_rev.shape[1] == 20:
                        # Swap the channel pairs (14, 15), (16, 17), (18, 19).
                        # NOTE(review): presumably the left/right class pairs of LIP - confirm.
                        for first, second in ((14, 15), (16, 17), (18, 19)):
                            outputs_rev[:, first, :, :] = reversed_part[:, second, :, :]
                            outputs_rev[:, second, :, :] = reversed_part[:, first, :, :]
                    outputs_rev = torch.flip(outputs_rev, [3])
                    outputs = (outputs + outputs_rev) / 2.
                    self.evaluator.update_score(outputs, data_dict['meta'])

                elif self.data_helper.conditions.diverse_size:
                    outputs = nn.parallel.parallel_apply(replicas[:len(inputs)], inputs)

                    for i in range(len(outputs)):
                        loss = self.pixel_loss(outputs[i], targets[i])
                        self.val_losses.update(loss.item(), 1)
                        outputs_i = outputs[i]
                        if isinstance(outputs_i, torch.Tensor):
                            outputs_i = [outputs_i]
                        self.evaluator.update_score(outputs_i, data_dict['meta'][i:i+1])

                else:
                    outputs = self.seg_net(*inputs)

                    try:
                        loss = self.pixel_loss(
                            outputs, targets,
                            gathered=self.configer.get('network', 'gathered')
                        )
                    except AssertionError:
                        # Show the mismatch, then propagate: continuing would use an undefined
                        # (first batch) or stale (later batches) ``loss`` below.
                        print(len(outputs), len(targets))
                        raise

                    if not is_distributed():
                        outputs = self.module_runner.gather(outputs)
                    self.val_losses.update(loss.item(), batch_size)
                    self.evaluator.update_score(outputs, data_dict['meta'])

            self.batch_time.update(time.time() - start_time)
            start_time = time.time()

        self.evaluator.update_performance()

        self.configer.update(['val_loss'], self.val_losses.avg)
        self.module_runner.save_net(self.seg_net, save_mode='performance')
        self.module_runner.save_net(self.seg_net, save_mode='val_loss')
        cudnn.benchmark = True

        # Print the log info & reset the states.
        if not is_distributed() or get_rank() == 0:
            Log.info(
                'Test Time {batch_time.sum:.3f}s, ({batch_time.avg:.3f})\t'
                'Loss {loss.avg:.8f}\n'.format(
                    batch_time=self.batch_time, loss=self.val_losses))
            self.evaluator.print_scores()

        self.batch_time.reset()
        self.val_losses.reset()
        self.evaluator.reset()
        self.seg_net.train()
        self.pixel_loss.train()

    def train(self):
        # cudnn.benchmark = True
        # self.__val()
        if self.configer.get('network', 'resume') is not None:
            if self.configer.get('network', 'resume_val'):
                self.__val(data_loader=self.data_loader.get_valloader(dataset='val'))
                return
            elif self.configer.get('network', 'resume_train'):
                self.__val(data_loader=self.data_loader.get_valloader(dataset='train'))
                return
            # return

        if self.configer.get('network', 'resume') is not None and self.configer.get('network', 'resume_val'):
            self.__val(data_loader=self.data_loader.get_valloader(dataset='val'))
            return

        while self.configer.get('iters') < self.configer.get('solver', 'max_iters'):
            self.__train()

        # use swa to average the model
        if 'swa' in self.configer.get('lr', 'lr_policy'):
            self.optimizer.swap_swa_sgd()
            self.optimizer.bn_update(self.train_loader, self.seg_net)

        self.__val(data_loader=self.data_loader.get_valloader(dataset='val'))

    def summary(self):
        """
          Print a model summary computed on the first image of the first training batch.

          The network is put into eval mode. Nothing is printed when the train loader yields
          no batch.
        """
        from lib.utils.summary import get_model_summary
        self.seg_net.eval()

        # Only the first batch is needed, so return right after it.
        for data_dict in self.train_loader:
            print(get_model_summary(self.seg_net, data_dict['img'][0:1]))
            return
class MultiTaskClassifier(object):
    """
      The class for the training phase of Image classification.
    """
    def __init__(self, configer):
        """
          Build the model, optimizer, per-source data loaders and loss from configer.
        """
        self.configer = configer
        # Progress bookkeeping; every counter starts at zero.
        self.runner_state = {
            'iters': 0,
            'last_iters': 0,
            'epoch': 0,
            'last_epoch': 0,
            'performance': 0,
            'val_loss': 0,
            'max_performance': 0,
            'min_val_loss': 0,
        }

        # Meters for timing and losses.
        self.batch_time = AverageMeter()
        self.data_time = AverageMeter()
        self.train_losses = DictAverageMeter()
        self.val_losses = DictAverageMeter()

        # Helpers built from the configuration.
        self.cls_model_manager = ModelManager(configer)
        self.cls_data_loader = DataLoader(configer)
        self.running_score = RunningScore(configer)

        # The network must exist before _get_parameters() can group its parameters.
        self.cls_net = self.cls_model_manager.get_model()
        solver_name = self.configer.get('train', 'solver')
        self.solver_dict = self.configer.get(solver_name)
        self.optimizer, self.scheduler = Trainer.init(
            self._get_parameters(), self.solver_dict)
        self.cls_net = RunnerHelper.load_net(self, self.cls_net)
        self.cls_net, self.optimizer = RunnerHelper.to_dtype(
            self, self.cls_net, self.optimizer)

        # One train loader and one val loader per data source, keyed by source index.
        self.train_loaders = {}
        self.val_loaders = {}
        num_sources = self.configer.get('data', 'num_data_sources')
        for src_id in range(num_sources):
            self.train_loaders[src_id] = self.cls_data_loader.get_trainloader(
                source=src_id)
            self.val_loaders[src_id] = self.cls_data_loader.get_valloader(
                source=src_id)
        if self.configer.get('data', 'mixup'):
            assert (num_sources == 2
                    ), "mixup only support src0 and src1 load the same dataset"

        self.loss = self.cls_model_manager.get_loss()

    def _get_parameters(self):
        lr_1 = []
        lr_2 = []
        params_dict = dict(self.cls_net.named_parameters())
        for key, value in params_dict.items():
            if value.requires_grad:
                if 'backbone' in key:
                    if self.configer.get('network', 'bb_lr_scale') == 0.0:
                        value.requires_grad = False
                    else:
                        lr_1.append(value)
                else:
                    lr_2.append(value)

        params = [{
            'params':
            lr_1,
            'lr':
            self.solver_dict['lr']['base_lr'] *
            self.configer.get('network', 'bb_lr_scale')
        }, {
            'params': lr_2,
            'lr': self.solver_dict['lr']['base_lr']
        }]
        return params

    def run(self):
        """
          Train on all data sources until solver_dict['max_iters'] iterations are done.

          Every iteration draws one batch from each source loader and merges
          them into a single dict whose keys are prefixed with 'src{source}_'.
          Optional multiscale resizing and mixup pasting are applied before one
          optimizer step is taken. Exhausted loaders are restarted; only
          source 0 advances the epoch counter. Logging, checkpointing and
          validation are triggered by 'display_iter', 'save_iters' and
          'test_interval' of the solver config. A validation pass is run first
          when 'network.resume_val' is set, and once more after the loop.
        """
        if self.configer.get('network', 'resume_val'):
            self.val()

        self.cls_net.train()
        train_loaders = dict()
        for source in self.train_loaders:
            train_loaders[source] = iter(self.train_loaders[source])
        start_time = time.time()
        while self.runner_state['iters'] < self.solver_dict['max_iters']:
            # Collect one batch per source into a single, source-prefixed dict.
            data_dict = dict()
            for source in train_loaders:
                try:
                    tmp_data_dict = next(train_loaders[source])
                except StopIteration:
                    # Source exhausted: restart it. Only source 0 counts epochs.
                    if source == 0 or source == '0':
                        self.runner_state['epoch'] += 1
                    train_loaders[source] = iter(self.train_loaders[source])
                    tmp_data_dict = next(train_loaders[source])
                for k, v in tmp_data_dict.items():
                    data_dict['src{}_{}'.format(source, k)] = v

            if self.configer.get('data', 'multiscale') is not None:
                # One random ratio per iteration, shared by every '*_img' entry.
                scale_ratios = self.configer.get('data', 'multiscale')
                scale_ratio = random.uniform(scale_ratios[0], scale_ratios[-1])
                for key in data_dict:
                    if key.endswith('_img'):
                        data_dict[key] = F.interpolate(
                            data_dict[key],
                            scale_factor=[scale_ratio, scale_ratio],
                            mode='bilinear',
                            align_corners=True)
            if self.configer.get('data', 'mixup'):
                # Paste a down-scaled copy (0.4-0.6 per axis) of the source-0
                # images into one random corner of the source-1 images, in place.
                src0_resize = F.interpolate(data_dict['src0_img'],
                                            scale_factor=[
                                                random.uniform(0.4, 0.6),
                                                random.uniform(0.4, 0.6)
                                            ],
                                            mode='bilinear',
                                            align_corners=True)
                _, _, h, w = src0_resize.shape
                pos = random.randint(0, 3)
                if pos == 0:  # top-left
                    data_dict['src1_img'][:, :, 0:h, 0:w] = src0_resize
                elif pos == 1:  # top-right
                    data_dict['src1_img'][:, :, 0:h, -w:] = src0_resize
                elif pos == 2:  # bottom-left
                    data_dict['src1_img'][:, :, -h:, 0:w] = src0_resize
                else:  # bottom-right
                    data_dict['src1_img'][:, :, -h:, -w:] = src0_resize

            data_dict = RunnerHelper.to_device(self, data_dict)
            # NOTE(review): presumably steps the scheduler and applies warm-up
            # to param group 0 (the backbone group) - confirm against Trainer.update.
            Trainer.update(
                self,
                warm_list=(0, ),
                warm_lr_list=(self.solver_dict['lr']['base_lr'] *
                              self.configer.get('network', 'bb_lr_scale'), ),
                solver_dict=self.solver_dict)
            self.data_time.update(time.time() - start_time)
            # Forward pass.
            out = self.cls_net(data_dict)
            loss_dict, loss_weight_dict = self.loss(out)
            # Compute the loss of the train batch & backward.
            loss = loss_dict['loss']
            self.train_losses.update(
                {key: value.item()
                 for key, value in loss_dict.items()},
                data_dict['src0_img'].size(0))
            self.optimizer.zero_grad()
            if self.configer.get('dtype') == 'fp16':
                with amp.scale_loss(loss, self.optimizer) as scaled_losses:
                    scaled_losses.backward()
            else:
                loss.backward()
            if self.configer.get('network', 'clip_grad'):
                RunnerHelper.clip_grad(self.cls_net, 10.)

            self.optimizer.step()

            # Update the vars of the train phase.
            self.batch_time.update(time.time() - start_time)
            start_time = time.time()
            self.runner_state['iters'] += 1

            # Print the log info & reset the states.
            if self.runner_state['iters'] % self.solver_dict[
                    'display_iter'] == 0:
                Log.info(
                    'Train Epoch: {0}\tTrain Iteration: {1}\t'
                    'Time {batch_time.sum:.3f}s / {2}iters, ({batch_time.avg:.3f})\t'
                    'Data load {data_time.sum:.3f}s / {2}iters, ({data_time.avg:.3f})\n'
                    'Learning rate = {5}\tLoss = {3}\nLossWeight = {4}\n'.
                    format(self.runner_state['epoch'],
                           self.runner_state['iters'],
                           self.solver_dict['display_iter'],
                           self.train_losses.info(),
                           loss_weight_dict,
                           RunnerHelper.get_lr(self.optimizer),
                           batch_time=self.batch_time,
                           data_time=self.data_time))

                self.batch_time.reset()
                self.data_time.reset()
                self.train_losses.reset()

            # Last iteration under the 'iters' metric: save the final model and stop.
            if self.solver_dict['lr'][
                    'metric'] == 'iters' and self.runner_state[
                        'iters'] == self.solver_dict['max_iters']:
                if self.configer.get('local_rank') == 0:
                    RunnerHelper.save_net(self, self.cls_net, postfix='final')
                break

            # Periodic checkpoint, written by local rank 0 only.
            if self.runner_state['iters'] % self.solver_dict[
                    'save_iters'] == 0 and self.configer.get(
                        'local_rank') == 0:
                RunnerHelper.save_net(self, self.cls_net)

            # Check to val the current model.
            if self.runner_state['iters'] % self.solver_dict[
                    'test_interval'] == 0:
                self.val()
                if self.configer.get('local_rank') == 0:
                    RunnerHelper.save_net(
                        self,
                        self.cls_net,
                        performance=self.runner_state['performance'])

        self.val()

    def val(self):
        """
          Validate on all data sources and record the comparison metric.

          One batch per source is merged per step (keys prefixed with
          'src{source}_'). A source that runs out is flagged as finished and
          restarted so the other sources can keep going; the loop stops once
          every source has been flagged. Losses and top-k accuracies are
          logged, and the top-1 accuracy whose key contains 'src0_label0' is
          stored in runner_state['performance']. Meters are reset and the
          network is put back into train mode at the end.
        """
        self.cls_net.eval()
        start_time = time.time()
        val_loaders = dict()
        val_to_end = dict()
        all_to_end = False
        for source in self.val_loaders:
            val_loaders[source] = iter(self.val_loaders[source])
            val_to_end[source] = False
        with torch.no_grad():
            while not all_to_end:
                data_dict = dict()
                for source in val_loaders:
                    try:
                        tmp_data_dict = next(val_loaders[source])
                    except StopIteration:
                        # This source is done; restart it so a full batch dict
                        # can still be assembled while other sources continue.
                        val_to_end[source] = True
                        val_loaders[source] = iter(self.val_loaders[source])
                        tmp_data_dict = next(val_loaders[source])
                    for k, v in tmp_data_dict.items():
                        data_dict['src{}_{}'.format(source, k)] = v
                # Forward pass.
                data_dict = RunnerHelper.to_device(self, data_dict)
                out = self.cls_net(data_dict)
                loss_dict, loss_weight_dict = self.loss(out)
                out_dict, label_dict, _ = RunnerHelper.gather(self, out)
                # Compute the loss of the val batch.
                self.running_score.update(out_dict, label_dict)
                self.val_losses.update(
                    {
                        key: value.mean().item()
                        for key, value in loss_dict.items()
                    }, data_dict['src0_img'].size(0))
                # Update the vars of the val phase.
                self.batch_time.update(time.time() - start_time)
                start_time = time.time()
                # check whether scan over all data sources
                all_to_end = all(val_to_end.values())

            Log.info('Test Time {batch_time.sum:.3f}s'.format(
                batch_time=self.batch_time))
            Log.info('TestLoss = {}'.format(self.val_losses.info()))
            Log.info('TestLossWeight = {}'.format(loss_weight_dict))
            Log.info('Top1 ACC = {}'.format(self.running_score.get_top1_acc()))
            Log.info('Top3 ACC = {}'.format(self.running_score.get_top3_acc()))
            Log.info('Top5 ACC = {}'.format(self.running_score.get_top5_acc()))
            # safe_load: yaml.load() without an explicit Loader is rejected by
            # PyYAML >= 6 and can construct arbitrary objects on older versions.
            top1_acc = yaml.safe_load(self.running_score.get_top1_acc())
            for key in top1_acc:
                if 'src0_label0' in key:
                    self.runner_state['performance'] = top1_acc[key]
                    Log.info('Use acc of {} to compare performace'.format(key))
                    break
            self.running_score.reset()
            self.val_losses.reset()
            self.batch_time.reset()
            self.cls_net.train()
# Example #4
# 0
class ImageClassifier(object):
    """
      The class for the training phase of Image classification.
    """
    def __init__(self, configer):
        """
          Build the model, optimizer, data loaders and loss from configer.
        """
        self.configer = configer
        # Progress bookkeeping; every counter starts at zero.
        self.runner_state = {
            'iters': 0,
            'last_iters': 0,
            'epoch': 0,
            'last_epoch': 0,
            'performance': 0,
            'val_loss': 0,
            'max_performance': 0,
            'min_val_loss': 0,
        }

        # Meters for timing and losses.
        self.batch_time = AverageMeter()
        self.data_time = AverageMeter()
        self.train_losses = DictAverageMeter()
        self.val_losses = DictAverageMeter()

        # Helpers built from the configuration.
        self.cls_model_manager = ModelManager(configer)
        self.cls_data_loader = DataLoader(configer)
        self.running_score = RunningScore(configer)

        # The network must exist before _get_parameters() can group its parameters.
        self.cls_net = self.cls_model_manager.get_model()
        solver_name = self.configer.get('train', 'solver')
        self.solver_dict = self.configer.get(solver_name)
        self.optimizer, self.scheduler = Trainer.init(
            self._get_parameters(), self.solver_dict)
        self.cls_net = RunnerHelper.load_net(self, self.cls_net)
        self.cls_net, self.optimizer = RunnerHelper.to_dtype(
            self, self.cls_net, self.optimizer)

        self.train_loader = self.cls_data_loader.get_trainloader()
        self.val_loader = self.cls_data_loader.get_valloader()
        self.loss = self.cls_model_manager.get_loss()

    def _get_parameters(self):
        lr_1 = []
        lr_2 = []
        params_dict = dict(self.cls_net.named_parameters())
        for key, value in params_dict.items():
            if value.requires_grad:
                if 'backbone' in key:
                    if self.configer.get('network', 'bb_lr_scale') == 0.0:
                        value.requires_grad = False
                    else:
                        lr_1.append(value)
                else:
                    lr_2.append(value)

        params = [{
            'params':
            lr_1,
            'lr':
            self.solver_dict['lr']['base_lr'] *
            self.configer.get('network', 'bb_lr_scale')
        }, {
            'params': lr_2,
            'lr': self.solver_dict['lr']['base_lr']
        }]
        return params

    def run(self):
        if self.configer.get('network', 'resume_val'):
            self.val()

        while self.runner_state['iters'] < self.solver_dict['max_iters']:
            self.train()

        self.val()

    def train(self):
        """
          Run one pass over train_loader (one epoch) with periodic log/save/val.

          The epoch counter is incremented on entry. Each batch has its keys
          prefixed with 'src0_' before being fed to the network. The pass
          stops early when the lr metric is 'iters' and the iteration counter
          reaches solver_dict['max_iters'].
        """
        self.cls_net.train()
        start_time = time.time()
        self.runner_state['epoch'] += 1
        for data_dict in self.train_loader:
            data_dict = {'src0_{}'.format(k): v for k, v in data_dict.items()}
            # NOTE(review): presumably steps the scheduler and applies warm-up
            # to param group 0 (the backbone group) - confirm against Trainer.update.
            Trainer.update(
                self,
                warm_list=(0, ),
                warm_lr_list=(self.solver_dict['lr']['base_lr'] *
                              self.configer.get('network', 'bb_lr_scale'), ),
                solver_dict=self.solver_dict)
            self.data_time.update(time.time() - start_time)
            data_dict = RunnerHelper.to_device(self, data_dict)
            # Forward pass.
            out = self.cls_net(data_dict)
            loss_dict, _ = self.loss(out)

            # Compute the loss of the train batch & backward.
            loss = loss_dict['loss']
            self.train_losses.update(
                {key: value.item()
                 for key, value in loss_dict.items()},
                data_dict['src0_img'].size(0))
            self.optimizer.zero_grad()
            loss.backward()
            if self.configer.get('network', 'clip_grad'):
                RunnerHelper.clip_grad(self.cls_net, 10.)

            self.optimizer.step()

            # Update the vars of the train phase.
            self.batch_time.update(time.time() - start_time)
            start_time = time.time()
            self.runner_state['iters'] += 1

            # Print the log info & reset the states.
            if self.runner_state['iters'] % self.solver_dict[
                    'display_iter'] == 0:
                Log.info(
                    'Train Epoch: {0}\tTrain Iteration: {1}\t'
                    'Time {batch_time.sum:.3f}s / {2}iters, ({batch_time.avg:.3f})\t'
                    'Data load {data_time.sum:.3f}s / {2}iters, ({data_time.avg:.3f})\n'
                    'Learning rate = {4}\tLoss = {3}\n'.format(
                        self.runner_state['epoch'],
                        self.runner_state['iters'],
                        self.solver_dict['display_iter'],
                        self.train_losses.info(),
                        RunnerHelper.get_lr(self.optimizer),
                        batch_time=self.batch_time,
                        data_time=self.data_time))

                self.batch_time.reset()
                self.data_time.reset()
                self.train_losses.reset()

            # Iteration budget used up under the 'iters' metric: stop this pass.
            if self.solver_dict['lr'][
                    'metric'] == 'iters' and self.runner_state[
                        'iters'] == self.solver_dict['max_iters']:
                break

            # Periodic checkpoint, written by local rank 0 only.
            if self.runner_state['iters'] % self.solver_dict[
                    'save_iters'] == 0 and self.configer.get(
                        'local_rank') == 0:
                RunnerHelper.save_net(self, self.cls_net)

            # Check to val the current model.
            if self.runner_state['iters'] % self.solver_dict[
                    'test_interval'] == 0:
                self.val()

    def val(self):
        """
          Validate on val_loader and log the loss and top-1/3/5 accuracies.

          Each batch has its keys prefixed with 'src0_' before being fed to
          the network. Meters and the running score are reset afterwards and
          the network is put back into train mode.
        """
        self.cls_net.eval()
        start_time = time.time()
        with torch.no_grad():
            for data_dict in self.val_loader:
                data_dict = {
                    'src0_{}'.format(k): v
                    for k, v in data_dict.items()
                }
                # Forward pass.
                data_dict = RunnerHelper.to_device(self, data_dict)
                out = self.cls_net(data_dict)
                # The loss yields a pair whose first item is the loss dict,
                # exactly as train() unpacks it; the second item is unused here.
                loss_dict, _ = self.loss(out)
                out_dict, label_dict, _ = RunnerHelper.gather(self, out)
                self.running_score.update(out_dict, label_dict)
                self.val_losses.update(
                    {key: value.item()
                     for key, value in loss_dict.items()},
                    data_dict['src0_img'].size(0))

                # Update the vars of the val phase.
                self.batch_time.update(time.time() - start_time)
                start_time = time.time()

            # Print the log info & reset the states.
            Log.info('Test Time {batch_time.sum:.3f}s'.format(
                batch_time=self.batch_time))
            Log.info('TestLoss = {}'.format(self.val_losses.info()))
            Log.info('Top1 ACC = {}'.format(self.running_score.get_top1_acc()))
            Log.info('Top3 ACC = {}'.format(self.running_score.get_top3_acc()))
            Log.info('Top5 ACC = {}'.format(self.running_score.get_top5_acc()))
            self.batch_time.reset()
            self.val_losses.reset()
            self.running_score.reset()
            self.cls_net.train()
# Example #5
# 0
class Trainer(object):
    """
      The class for Pose Estimation. Include train, val, val & predict.
    """
    def __init__(self, configer):
        """
          Create the meters, helpers and managers, then build the model via _init_model().
        """
        self.configer = configer

        # Placeholders that _init_model() fills in.
        self.seg_net = None
        self.train_loader = self.val_loader = None
        self.optimizer = self.scheduler = None

        # Timing / loss meters.
        self.batch_time = AverageMeter()
        self.data_time = AverageMeter()
        self.train_losses = AverageMeter()
        self.val_losses = AverageMeter()

        # Helpers and managers, all built from the same configuration.
        self.running_score = RunningScore(configer)
        self.seg_visualizer = SegVisualizer(configer)
        self.loss_manager = LossManager(configer)
        self.module_runner = ModuleRunner(configer)
        self.model_manager = ModelManager(configer)
        self.data_loader = DataLoader(configer)
        self.optim_scheduler = OptimScheduler(configer)

        self._init_model()

    def _init_model(self):
        """
          Build the network, optimizer/scheduler, data loaders and pixel loss.

          Parameter grouping follows 'optim.group_method': 'decay' uses
          group_weight(); the only other accepted value is None, which uses
          _get_parameters().
        """
        self.seg_net = self.model_manager.semantic_segmentor()
        self.seg_net = self.module_runner.load_net(self.seg_net)

        group_method = self.configer.get('optim', 'group_method')
        Log.info('Params Group Method: {}'.format(group_method))
        if group_method == 'decay':
            params_group = self.group_weight(self.seg_net)
        else:
            assert group_method is None
            params_group = self._get_parameters()

        optimizer_and_scheduler = self.optim_scheduler.init_optimizer(params_group)
        self.optimizer, self.scheduler = optimizer_and_scheduler

        self.train_loader = self.data_loader.get_trainloader()
        self.val_loader = self.data_loader.get_valloader()

        self.pixel_loss = self.loss_manager.get_seg_loss()

    @staticmethod
    def group_weight(module):
        group_decay = []
        group_no_decay = []
        for m in module.modules():
            if isinstance(m, nn.Linear):
                group_decay.append(m.weight)
                if m.bias is not None:
                    group_no_decay.append(m.bias)
            elif isinstance(m, nn.modules.conv._ConvNd):
                group_decay.append(m.weight)
                if m.bias is not None:
                    group_no_decay.append(m.bias)
            else:
                if hasattr(m, 'weight'):
                    group_no_decay.append(m.weight)
                if hasattr(m, 'bias'):
                    group_no_decay.append(m.bias)

        assert len(list(
            module.parameters())) == len(group_decay) + len(group_no_decay)
        groups = [
            dict(params=group_decay),
            dict(params=group_no_decay, weight_decay=.0)
        ]
        return groups

    def _get_parameters(self):
        bb_lr = []
        nbb_lr = []
        params_dict = dict(self.seg_net.named_parameters())
        for key, value in params_dict.items():
            if 'backbone' not in key:
                nbb_lr.append(value)
            else:
                bb_lr.append(value)

        params = [{
            'params': bb_lr,
            'lr': self.configer.get('lr', 'base_lr')
        }, {
            'params':
            nbb_lr,
            'lr':
            self.configer.get('lr', 'base_lr') *
            self.configer.get('lr', 'nbb_mult')
        }]
        return params

    def __train(self):
        """
          Run one pass over train_loader (one epoch) with periodic log and validation.

          The scheduler is stepped with the 'iters' or the 'epoch' counter
          depending on 'lr.metric'. The pass stops early when 'iters' reaches
          'solver.max_iters'. The 'epoch' counter is incremented on exit.
        """
        self.seg_net.train()
        start_time = time.time()

        for data_dict in self.train_loader:
            if self.configer.get('lr', 'metric') == 'iters':
                self.scheduler.step(self.configer.get('iters'))
            else:
                self.scheduler.step(self.configer.get('epoch'))

            if self.configer.get('lr', 'is_warm'):
                self.module_runner.warm_lr(self.configer.get('iters'),
                                           self.scheduler,
                                           self.optimizer,
                                           backbone_list=[
                                               0,
                                           ])
            inputs = data_dict['img']
            targets = data_dict['labelmap']
            self.data_time.update(time.time() - start_time)
            # NOTE(review): unlike __val, inputs/targets are not passed through
            # module_runner.to_device here - confirm placement is handled elsewhere.

            # Forward pass.
            outputs = self.seg_net(inputs)
            # Compute the loss of the train batch & backward.
            loss = self.pixel_loss(outputs,
                                   targets,
                                   gathered=self.configer.get(
                                       'network', 'gathered'))
            if self.configer.exists('train', 'loader') and self.configer.get(
                    'train', 'loader') == 'ade20k':
                # For the 'ade20k' loader the meter weight comes from the
                # configured batch size instead of inputs.size(0).
                batch_size = self.configer.get(
                    'train', 'batch_size') * self.configer.get(
                        'train', 'batch_per_gpu')
                self.train_losses.update(loss.item(), batch_size)
            else:
                self.train_losses.update(loss.item(), inputs.size(0))

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Update the vars of the train phase.
            self.batch_time.update(time.time() - start_time)
            start_time = time.time()
            self.configer.plus_one('iters')

            # Print the log info & reset the states.
            if self.configer.get('iters') % self.configer.get(
                    'solver', 'display_iter') == 0:
                Log.info(
                    'Train Epoch: {0}\tTrain Iteration: {1}\t'
                    'Time {batch_time.sum:.3f}s / {2}iters, ({batch_time.avg:.3f})\t'
                    'Data load {data_time.sum:.3f}s / {2}iters, ({data_time.avg:.3f})\n'
                    'Learning rate = {3}\tLoss = {loss.val:.8f} (ave = {loss.avg:.8f})\n'
                    .format(self.configer.get('epoch'),
                            self.configer.get('iters'),
                            self.configer.get('solver', 'display_iter'),
                            self.module_runner.get_lr(self.optimizer),
                            batch_time=self.batch_time,
                            data_time=self.data_time,
                            loss=self.train_losses))
                self.batch_time.reset()
                self.data_time.reset()
                self.train_losses.reset()

            if self.configer.get('iters') == self.configer.get(
                    'solver', 'max_iters'):
                break

            # Check to val the current model.
            if self.configer.get('iters') % self.configer.get(
                    'solver', 'test_interval') == 0:
                self.__val()

        self.configer.plus_one('epoch')

    def __val(self, data_loader=None):
        """
          Evaluate seg_net on a loader, record the metrics and save checkpoints.

          Args:
            data_loader: loader to evaluate on; self.val_loader when None.

          The mean IoU and the average loss are written to the configer under
          'performance' and 'val_loss', save_net is called for both save
          modes, the results are logged, the meters are reset and the network
          is put back into train mode.
        """
        self.seg_net.eval()
        start_time = time.time()

        if data_loader is None:
            data_loader = self.val_loader
        for batch in data_loader:
            inputs, targets = batch['img'], batch['labelmap']

            with torch.no_grad():
                # Change the data type.
                inputs, targets = self.module_runner.to_device(inputs, targets)
                # Forward pass.
                outputs = self.seg_net(inputs)
                # Compute the loss of the val batch.
                loss = self.pixel_loss(
                    outputs,
                    targets,
                    gathered=self.configer.get('network', 'gathered'))
                outputs = self.module_runner.gather(outputs)

            self.val_losses.update(loss.item(), inputs.size(0))
            # Only the last entry of the gathered outputs is scored.
            self._update_running_score(outputs[-1], batch['meta'])

            # Update the vars of the val phase.
            self.batch_time.update(time.time() - start_time)
            start_time = time.time()

        # Publish the metrics first; save_net is then called once per save mode.
        self.configer.update(['performance'],
                             self.running_score.get_mean_iou())
        self.configer.update(['val_loss'], self.val_losses.avg)
        for save_mode in ('performance', 'val_loss'):
            self.module_runner.save_net(self.seg_net, save_mode=save_mode)

        # Print the log info & reset the states.
        timing_line = ('Test Time {batch_time.sum:.3f}s, ({batch_time.avg:.3f})\t'
                       'Loss {loss.avg:.8f}\n')
        Log.info(timing_line.format(batch_time=self.batch_time,
                                    loss=self.val_losses))
        Log.info('Mean IOU: {}\n'.format(self.running_score.get_mean_iou()))
        Log.info('Pixel ACC: {}\n'.format(self.running_score.get_pixel_acc()))
        self.batch_time.reset()
        self.val_losses.reset()
        self.running_score.reset()
        self.seg_net.train()

    def _update_running_score(self, pred, metas):
        """
          Score each prediction of a batch against its original-size target.

          Args:
            pred: batch of predictions; axes are permuted with (0, 2, 3, 1)
              before use (presumably NCHW logits - confirm against caller).
            metas: per-sample dicts providing 'ori_img_size', 'border_size'
              and 'ori_target'.

          Each sample is cropped to its border size, resized to the original
          image size with cubic interpolation, reduced with argmax over the
          last axis and passed to running_score.update().
        """
        channels_last = pred.permute(0, 2, 3, 1)
        num_samples = channels_last.size(0)
        for idx in range(num_samples):
            meta = metas[idx]
            ori_img_size = meta['ori_img_size']
            border_size = meta['border_size']
            ori_target = meta['ori_target']

            # Drop the padded border before resizing back to the original size.
            cropped = channels_last[idx, :border_size[1], :border_size[0]]
            total_logits = cv2.resize(cropped.cpu().numpy(),
                                      tuple(ori_img_size),
                                      interpolation=cv2.INTER_CUBIC)
            labelmap = np.argmax(total_logits, axis=-1)
            self.running_score.update(labelmap[None], ori_target[None])

    def train(self):
        """
          Train until 'solver.max_iters', then validate on the 'val' and 'train' splits.

          When 'network.resume' is set together with 'network.resume_val', a
          validation pass on the default loader is run before training.
        """
        resumed = self.configer.get('network', 'resume') is not None
        if resumed and self.configer.get('network', 'resume_val'):
            self.__val()

        while self.configer.get('iters') < self.configer.get('solver', 'max_iters'):
            self.__train()

        # Final evaluation, first on the 'val' split and then on the 'train' split.
        for split in ('val', 'train'):
            self.__val(data_loader=self.data_loader.get_valloader(dataset=split))
# Example #6
# 0
class Tester(object):
    def __init__(self, configer):
        """Create the helper objects from ``configer`` and load the network.

        Args:
            configer: Configuration object; it is handed to every helper
                created here and queried with ``get('gpu')`` to pick the
                device.
        """
        self.configer = configer
        self.blob_helper = BlobHelper(configer)
        self.seg_visualizer = SegVisualizer(configer)
        self.seg_parser = SegParser(configer)
        self.seg_model_manager = ModelManager(configer)
        self.seg_data_loader = DataLoader(configer)
        self.module_runner = ModuleRunner(configer)
        # Use CUDA whenever a 'gpu' entry is configured, otherwise the CPU.
        self.device = torch.device('cpu' if self.configer.get('gpu') is None else 'cuda')
        # Filled in by _init_model() right below.
        self.seg_net = None

        self._init_model()

    def _init_model(self):
        self.seg_net = self.seg_model_manager.semantic_segmentor()
        self.seg_net = self.module_runner.load_net(self.seg_net)
        self.seg_net.eval()

    def _get_blob(self, ori_image, scale=None):
        assert scale is not None
        image = None
        if self.configer.exists('test', 'input_size'):
            image = self.blob_helper.make_input(image=ori_image,
                                                input_size=self.configer.get('test', 'input_size'),
                                                scale=scale)

        elif self.configer.exists('test', 'min_side_length') and not self.configer.exists('test', 'max_side_length'):
            image = self.blob_helper.make_input(image=ori_image,
                                                min_side_length=self.configer.get('test', 'min_side_length'),
                                                scale=scale)

        elif not self.configer.exists('test', 'min_side_length') and self.configer.exists('test', 'max_side_length'):
            image = self.blob_helper.make_input(image=ori_image,
                                                max_side_length=self.configer.get('test', 'max_side_length'),
                                                scale=scale)

        elif self.configer.exists('test', 'min_side_length') and self.configer.exists('test', 'max_side_length'):
            image = self.blob_helper.make_input(image=ori_image,
                                                min_side_length=self.configer.get('test', 'min_side_length'),
                                                max_side_length=self.configer.get('test', 'max_side_length'),
                                                scale=scale)

        else:
            Log.error('Test setting error')
            exit(1)

        b, c, h, w = image.size()
        border_hw = [h, w]
        if self.configer.exists('test', 'fit_stride'):
            stride = self.configer.get('test', 'fit_stride')

            pad_w = 0 if (w % stride == 0) else stride - (w % stride)  # right
            pad_h = 0 if (h % stride == 0) else stride - (h % stride)  # down

            expand_image = torch.zeros((b, c, h + pad_h, w + pad_w)).to(image.device)
            expand_image[:, :, 0:h, 0:w] = image
            image = expand_image

        return image, border_hw

    def __test_img(self, image_path, label_path, vis_path, raw_path):
        """Segment one image and save the visualisation, raw image and labels.

        The inference routine is chosen by ``test.mode`` ('ss_test',
        'sscrop_test', 'ms_test' or 'mscrop_test'); any other value logs an
        error and exits with status 1.

        Args:
            image_path: Path of the image to read.
            label_path: Destination of the predicted label image.
            vis_path: Destination of the colorized overlay.
            raw_path: Destination of the copy of the input image.
        """
        Log.info('Image Path: {}'.format(image_path))
        ori_image = ImageHelper.read_image(image_path,
                                           tool=self.configer.get('data', 'image_tool'),
                                           mode=self.configer.get('data', 'input_mode'))

        # Every supported mode name is also the name of the method running it.
        total_logits = None
        test_mode = self.configer.get('test', 'mode')
        for mode_name in ('ss_test', 'sscrop_test', 'ms_test', 'mscrop_test'):
            if test_mode == mode_name:
                total_logits = getattr(self, mode_name)(ori_image)
                break
        else:
            Log.error('Invalid test mode:{}'.format(self.configer.get('test', 'mode')))
            exit(1)

        label_img = np.array(np.argmax(total_logits, axis=-1), dtype=np.uint8)

        # Overlay the prediction on the input and keep a copy of the input.
        ori_img_bgr = ImageHelper.get_cv2_bgr(ori_image, mode=self.configer.get('data', 'input_mode'))
        overlay = self.seg_parser.colorize(label_img, image_canvas=ori_img_bgr)
        ImageHelper.save(overlay, save_path=vis_path)
        ImageHelper.save(ori_image, save_path=raw_path)

        if self.configer.exists('data', 'label_list'):
            label_img = self.__relabel(label_img)

        shift_labels = (self.configer.exists('data', 'reduce_zero_label')
                        and self.configer.get('data', 'reduce_zero_label'))
        if shift_labels:
            label_img = (label_img + 1).astype(np.uint8)

        label_img = Image.fromarray(label_img, 'P')
        Log.info('Label Path: {}'.format(label_path))
        ImageHelper.save(label_img, label_path)

    def ss_test(self, ori_image):
        """Single-scale inference on the whole image.

        Args:
            ori_image: Input image as accepted by ``ImageHelper.get_size`` and
                ``self._get_blob``.

        Returns:
            float32 array of shape (ori_height, ori_width, num_classes) holding
            the prediction resized back to the original image size.
        """
        ori_width, ori_height = ImageHelper.get_size(ori_image)
        num_classes = self.configer.get('data', 'num_classes')
        total_logits = np.zeros((ori_height, ori_width, num_classes), np.float32)

        blob, border_hw = self._get_blob(ori_image, scale=1.0)
        # Keep only the part of the output inside the unpadded blob border.
        valid_scores = self._predict(blob)[:border_hw[0], :border_hw[1]]
        total_logits += cv2.resize(valid_scores, (ori_width, ori_height),
                                   interpolation=cv2.INTER_CUBIC)
        return total_logits

    def sscrop_test(self, ori_image):
        """Single-scale inference that falls back to crops for large blobs.

        Crop-based prediction is used only when the blob is strictly larger
        than ``test.crop_size`` in both width (``crop_size[0]``) and height
        (``crop_size[1]``); otherwise the whole blob is predicted at once.

        Args:
            ori_image: Input image as accepted by ``ImageHelper.get_size`` and
                ``self._get_blob``.

        Returns:
            float32 array of shape (ori_height, ori_width, num_classes).
        """
        ori_width, ori_height = ImageHelper.get_size(ori_image)
        num_classes = self.configer.get('data', 'num_classes')
        total_logits = np.zeros((ori_height, ori_width, num_classes), np.float32)

        blob, border_hw = self._get_blob(ori_image, scale=1.0)
        crop_size = self.configer.get('test', 'crop_size')
        needs_crops = blob.size()[3] > crop_size[0] and blob.size()[2] > crop_size[1]
        scores = self._crop_predict(blob, crop_size) if needs_crops else self._predict(blob)

        valid_scores = scores[:border_hw[0], :border_hw[1]]
        total_logits += cv2.resize(valid_scores, (ori_width, ori_height),
                                   interpolation=cv2.INTER_CUBIC)
        return total_logits

    def mscrop_test(self, ori_image):
        """Multi-scale inference with horizontal flipping and crop fallback.

        For every scale in ``test.scale_search`` the image and its horizontal
        mirror are both predicted at that scale. The mirrored prediction is
        flipped back, each result is resized to the original image size and
        all results are summed. Crop-based prediction is used when the blob is
        strictly larger than ``test.crop_size`` in both width and height.

        The mirrored pass uses the current loop scale. It previously always
        used ``scale=1.0``, which re-added the same flipped full-scale
        prediction once per entry of ``scale_search``.

        Args:
            ori_image: Input image; a numpy array when ``data.image_tool`` is
                'cv2', otherwise an object exposing ``transpose`` (PIL-style).

        Returns:
            float32 array of shape (ori_height, ori_width, num_classes) with
            the accumulated scores.
        """
        ori_width, ori_height = ImageHelper.get_size(ori_image)
        crop_size = self.configer.get('test', 'crop_size')
        total_logits = np.zeros((ori_height, ori_width, self.configer.get('data', 'num_classes')), np.float32)

        # The mirror does not depend on the scale, so build it once.
        if self.configer.get('data', 'image_tool') == 'cv2':
            mirror_image = cv2.flip(ori_image, 1)
        else:
            mirror_image = ori_image.transpose(Image.FLIP_LEFT_RIGHT)

        for scale in self.configer.get('test', 'scale_search'):
            for source_image, is_mirrored in ((ori_image, False), (mirror_image, True)):
                image, border_hw = self._get_blob(source_image, scale=scale)
                if image.size()[3] > crop_size[0] and image.size()[2] > crop_size[1]:
                    results = self._crop_predict(image, crop_size)
                else:
                    results = self._predict(image)

                # Drop the padded border before resizing.
                results = results[:border_hw[0], :border_hw[1]]
                if is_mirrored:
                    # Undo the horizontal flip so the scores line up again.
                    results = results[:, ::-1]

                total_logits += cv2.resize(results, (ori_width, ori_height),
                                           interpolation=cv2.INTER_CUBIC)

        return total_logits

    def ms_test(self, ori_image):
        """Multi-scale inference with horizontal flipping on the whole image.

        For every scale in ``test.scale_search`` the image and its horizontal
        mirror are predicted; the mirrored result is flipped back, both are
        resized to the original image size and added to the running sum.

        Args:
            ori_image: Input image; a numpy array when ``data.image_tool`` is
                'cv2', otherwise an object exposing ``transpose`` (PIL-style).

        Returns:
            float32 array of shape (ori_height, ori_width, num_classes) with
            the accumulated scores.
        """
        ori_width, ori_height = ImageHelper.get_size(ori_image)
        num_classes = self.configer.get('data', 'num_classes')
        total_logits = np.zeros((ori_height, ori_width, num_classes), np.float32)
        output_size = (ori_width, ori_height)

        for scale in self.configer.get('test', 'scale_search'):
            # Straight pass.
            blob, border_hw = self._get_blob(ori_image, scale=scale)
            scores = self._predict(blob)[:border_hw[0], :border_hw[1]]
            total_logits += cv2.resize(scores, output_size, interpolation=cv2.INTER_CUBIC)

            # Mirrored pass.
            use_cv2 = self.configer.get('data', 'image_tool') == 'cv2'
            flipped = cv2.flip(ori_image, 1) if use_cv2 else ori_image.transpose(Image.FLIP_LEFT_RIGHT)
            blob, border_hw = self._get_blob(flipped, scale=scale)
            scores = self._predict(blob)[:border_hw[0], :border_hw[1]]
            # Reverse the columns so the scores match the unflipped image.
            total_logits += cv2.resize(scores[:, ::-1], output_size, interpolation=cv2.INTER_CUBIC)

        return total_logits

    def _crop_predict(self, image, crop_size):
        height, width = image.size()[2:]
        np_image = image.squeeze(0).permute(1, 2, 0).cpu().numpy()
        height_starts = self._decide_intersection(height, crop_size[1])
        width_starts = self._decide_intersection(width, crop_size[0])
        split_crops = []
        for height in height_starts:
            for width in width_starts:
                image_crop = np_image[height:height + crop_size[1], width:width + crop_size[0]]
                split_crops.append(image_crop[np.newaxis, :])

        split_crops = np.concatenate(split_crops, axis=0)  # (n, crop_image_size, crop_image_size, 3)
        inputs = torch.from_numpy(split_crops).permute(0, 3, 1, 2).to(self.device)
        with torch.no_grad():
            results = self.seg_net.forward(inputs)
            results = results[-1].permute(0, 2, 3, 1).cpu().numpy()

        reassemble = np.zeros((np_image.shape[0], np_image.shape[1], results.shape[-1]), np.float32)
        index = 0
        for height in height_starts:
            for width in width_starts:
                reassemble[height:height+crop_size[1], width:width+crop_size[0]] += results[index]
                index += 1

        return reassemble

    def _decide_intersection(self, total_length, crop_length):
        stride = int(crop_length * self.configer.get('test', 'crop_stride_ratio'))            # set the stride as the paper do
        times = (total_length - crop_length) // stride + 1
        cropped_starting = []
        for i in range(times):
            cropped_starting.append(stride*i)

        if total_length - cropped_starting[-1] > crop_length:
            cropped_starting.append(total_length - crop_length)  # must cover the total image

        return cropped_starting

    def _predict(self, inputs):
        with torch.no_grad():
            results = self.seg_net.forward(inputs)
            results = results[-1].squeeze(0).permute(1, 2, 0).cpu().numpy()

        return results

    def __relabel(self, label_map):
        height, width = label_map.shape
        label_dst = np.zeros((height, width), dtype=np.uint8)
        for i in range(self.configer.get('data', 'num_classes')):
            label_dst[label_map == i] = self.configer.get('data', 'label_list')[i]

        label_dst = np.array(label_dst, dtype=np.uint8)

        return label_dst

    def test(self):
        """Run inference on ``test.test_img`` or on every file of ``test.test_dir``.

        Exactly one of the two settings must be given; otherwise an error is
        logged and the process exits with status 1. For each image three
        outputs are written below the result directory: ``label/<stem>.png``,
        ``raw/<filename>`` and ``vis/<stem>_vis.png``.

        The stem is taken with ``os.path.splitext`` so that file names without
        an extension keep their name. The previous split-on-dot logic gave
        such files an empty stem, so they all wrote to the same output files.
        """
        base_dir = os.path.join(self.configer.get('project_dir'), 'results', self.configer.get('dataset'))

        test_img = self.configer.get('test', 'test_img')
        test_dir = self.configer.get('test', 'test_dir')
        if test_img is None and test_dir is None:
            Log.error('test_img & test_dir not exists.')
            exit(1)

        if test_img is not None and test_dir is not None:
            Log.error('Either test_img or test_dir.')
            exit(1)

        # Build (image_path, filename) pairs so both modes share one code path.
        if test_img is not None:
            base_dir = os.path.join(base_dir, 'test_img')
            jobs = [(test_img, test_img.rstrip().split('/')[-1])]
        else:
            base_dir = os.path.join(base_dir, 'test_dir',
                                    self.configer.get('checkpoints', 'checkpoints_name'),
                                    self.configer.get('test', 'out_dir'))
            FileHelper.make_dirs(base_dir)
            jobs = [(os.path.join(test_dir, filename), filename)
                    for filename in FileHelper.list_dir(test_dir)]

        for image_path, filename in jobs:
            stem = os.path.splitext(filename)[0]
            label_path = os.path.join(base_dir, 'label', '{}.png'.format(stem))
            raw_path = os.path.join(base_dir, 'raw', filename)
            vis_path = os.path.join(base_dir, 'vis', '{}_vis.png'.format(stem))
            for out_path in (label_path, raw_path, vis_path):
                FileHelper.make_dirs(out_path, is_file=True)

            self.__test_img(image_path, label_path, vis_path, raw_path)

    def debug(self):
        """Visualise up to 20 training samples together with their label maps.

        Each sample is colorized, written to
        ``<project_dir>/vis/results/<dataset>/debug`` and shown in an OpenCV
        window that waits for a key press. Reaching the 21st sample ends the
        process through ``exit(1)``.
        """
        out_dir = os.path.join(self.configer.get('project_dir'),
                               'vis/results', self.configer.get('dataset'), 'debug')
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)

        max_samples = 20
        shown = 0
        for batch_idx, batch in enumerate(self.seg_data_loader.get_trainloader()):
            images = batch['img']
            labelmaps = batch['labelmap']
            for sample_idx in range(images.size(0)):
                shown += 1
                if shown > max_samples:
                    exit(1)

                image_bgr = self.blob_helper.tensor2bgr(images[sample_idx])
                labels = labelmaps[sample_idx].numpy()
                canvas = self.seg_parser.colorize(labels, image_canvas=image_bgr)
                file_name = '{}_{}_vis.png'.format(batch_idx, sample_idx)
                cv2.imwrite(os.path.join(out_dir, file_name), canvas)
                cv2.imshow('main', canvas)
                cv2.waitKey()