class FCNSegmentor(object):
    """
      The class for semantic segmentation. Include train & val.

      NOTE(review): ``runner_state`` starts as an empty dict, yet ``train``
      increments ``runner_state['iters']`` and ``runner_state['epoch']``.
      Those keys are presumably populated by ``RunnerHelper.load_net`` --
      confirm against that helper.
    """
    def __init__(self, configer):
        """
          Build the meters, helpers, network, optimizer, loaders and loss.

          Args:
            configer: project configuration object; it is queried through
              ``configer.get(...)`` and handed to every helper built here.
        """
        self.configer = configer
        self.batch_time = AverageMeter()
        self.data_time = AverageMeter()
        self.train_losses = DictAverageMeter()
        self.val_losses = DictAverageMeter()
        self.seg_running_score = SegRunningScore(configer)
        self.seg_visualizer = SegVisualizer(configer)
        self.seg_model_manager = ModelManager(configer)
        self.seg_data_loader = DataLoader(configer)

        # Everything below is filled in by _init_model().
        self.seg_net = None
        self.train_loader = None
        self.val_loader = None
        self.optimizer = None
        self.scheduler = None
        self.loss = None
        self.runner_state = dict()

        self._init_model()

    def _init_model(self):
        """
          Create the network, optimizer/scheduler, data loaders and loss.
        """
        self.seg_net = self.seg_model_manager.get_seg_model()
        self.seg_net = RunnerHelper.load_net(self, self.seg_net)

        self.optimizer, self.scheduler = Trainer.init(
            self._get_parameters(), self.configer.get('solver'))

        self.train_loader = self.seg_data_loader.get_trainloader()
        self.val_loader = self.seg_data_loader.get_valloader()

        self.loss = self.seg_model_manager.get_seg_loss()

    def _get_parameters(self):
        """
          Split the network parameters into two optimizer groups.

          Parameters whose name contains ``'backbone'`` go to the first
          group, all the others to the second one.

          Returns:
            A list of two dicts with the keys ``'params'`` and ``'lr'``.
        """
        base_lr = self.configer.get('solver', 'lr')['base_lr']
        backbone_params = []
        head_params = []
        for name, param in dict(self.seg_net.named_parameters()).items():
            if 'backbone' in name:
                backbone_params.append(param)
            else:
                head_params.append(param)

        # The multiplier of the second group is 1.0 here, so both groups
        # currently train with the same learning rate.
        return [
            {'params': backbone_params, 'lr': base_lr},
            {'params': head_params, 'lr': base_lr * 1.0},
        ]

    def train(self):
        """
          Train function of every epoch during train phase.

          Iterates once over ``self.train_loader``; logs every
          ``solver.display_iter`` iterations, saves a checkpoint every
          ``solver.save_iters`` iterations (only when ``local_rank`` is 0)
          and validates every ``solver.test_interval`` iterations unless
          ``network.distributed`` is set.
        """
        self.seg_net.train()
        start_time = time.time()

        for data_dict in self.train_loader:
            Trainer.update(self,
                           warm_list=(0, ),
                           solver_dict=self.configer.get('solver'))
            self.data_time.update(time.time() - start_time)

            # Forward pass.
            data_dict = RunnerHelper.to_device(self, data_dict)
            out = self.seg_net(data_dict)
            # Compute the loss of the train batch & backward.
            loss_dict = self.loss(out)
            loss = loss_dict['loss']
            self.train_losses.update(
                {key: value.item() for key, value in loss_dict.items()},
                data_dict['img'].size(0))
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Update the vars of the train phase.
            self.batch_time.update(time.time() - start_time)
            start_time = time.time()
            self.runner_state['iters'] += 1
            iters = self.runner_state['iters']

            # Print the log info & reset the states.
            display_iter = self.configer.get('solver', 'display_iter')
            if iters % display_iter == 0:
                # The data_time average uses '.3f' like the other timings;
                # the former '3f' was a width spec that printed 6 decimals.
                Log.info(
                    'Train Epoch: {0}\tTrain Iteration: {1}\t'
                    'Time {batch_time.sum:.3f}s / {2}iters, ({batch_time.avg:.3f})\t'
                    'Data load {data_time.sum:.3f}s / {2}iters, ({data_time.avg:.3f})\n'
                    'Learning rate = {4}\tLoss = {3}\n'.format(
                        self.runner_state['epoch'],
                        iters,
                        display_iter,
                        self.train_losses.info(),
                        RunnerHelper.get_lr(self.optimizer),
                        batch_time=self.batch_time,
                        data_time=self.data_time))
                self.batch_time.reset()
                self.data_time.reset()
                self.train_losses.reset()

            if iters % self.configer.get('solver.save_iters') == 0 \
                    and self.configer.get('local_rank') == 0:
                RunnerHelper.save_net(self, self.seg_net)

            # Stop in the middle of the epoch once the iteration budget is
            # reached (only when the lr schedule is iteration based).
            if self.configer.get('solver', 'lr')['metric'] == 'iters' \
                    and iters == self.configer.get('solver', 'max_iters'):
                break

            # Check to val the current model.
            if iters % self.configer.get('solver', 'test_interval') == 0 \
                    and not self.configer.get('network.distributed'):
                self.val()

        self.runner_state['epoch'] += 1

    def val(self, data_loader=None):
        """
          Validation function during the train phase.

          Args:
            data_loader: optional loader to validate on; defaults to
              ``self.val_loader``.

          Stores the mean IoU and the averaged validation loss in
          ``runner_state``, saves the network, logs the metrics, resets the
          meters and switches the network back to train mode.
        """
        self.seg_net.eval()
        start_time = time.time()

        if data_loader is None:
            data_loader = self.val_loader

        for data_dict in data_loader:
            data_dict = RunnerHelper.to_device(self, data_dict)
            with torch.no_grad():
                # Forward pass.
                out = self.seg_net(data_dict)
                # Compute the loss of the val batch.
                loss_dict = self.loss(out)
                out_dict, _ = RunnerHelper.gather(self, out)

            self.val_losses.update(
                {key: value.item() for key, value in loss_dict.items()},
                data_dict['img'].size(0))
            self._update_running_score(out_dict['out'],
                                       DCHelper.tolist(data_dict['meta']))

            # Update the vars of the val phase.
            self.batch_time.update(time.time() - start_time)
            start_time = time.time()

        # Read the metrics once and reuse them for the state, the checkpoint
        # and the log lines.
        mean_iou = self.seg_running_score.get_mean_iou()
        val_loss = self.val_losses.avg['loss']
        self.runner_state['performance'] = mean_iou
        self.runner_state['val_loss'] = val_loss
        RunnerHelper.save_net(self,
                              self.seg_net,
                              performance=mean_iou,
                              val_loss=val_loss)

        # Print the log info & reset the states.
        Log.info('Test Time {batch_time.sum:.3f}s, ({batch_time.avg:.3f})\t'
                 'Loss = {0}\n'.format(self.val_losses.info(),
                                       batch_time=self.batch_time))
        Log.info('Mean IOU: {}\n'.format(mean_iou))
        Log.info('Pixel ACC: {}\n'.format(
            self.seg_running_score.get_pixel_acc()))
        self.batch_time.reset()
        self.val_losses.reset()
        self.seg_running_score.reset()
        self.seg_net.train()

    def _update_running_score(self, pred, metas):
        """
          Feed one batch of predictions into the running score.

          Args:
            pred: 4-D tensor, permuted here with ``(0, 2, 3, 1)``
              (assumes an (N, C, H, W) layout -- TODO confirm).
            metas: per-sample dicts providing ``'border_wh'``,
              ``'ori_target'`` and ``'ori_img_wh'``.
        """
        # Move the second axis last so that argmax over the last axis picks
        # the label after resizing.
        pred = pred.permute(0, 2, 3, 1)
        for i in range(pred.size(0)):
            meta = metas[i]
            border_size = meta['border_wh']
            ori_target = meta['ori_target']
            # Crop to the border size, then resize to the original image
            # size before taking the argmax.
            cropped = pred[i, :border_size[1], :border_size[0]].cpu().numpy()
            total_logits = cv2.resize(cropped,
                                      tuple(meta['ori_img_wh']),
                                      interpolation=cv2.INTER_CUBIC)
            labelmap = np.argmax(total_logits, axis=-1)
            self.seg_running_score.update(labelmap[None], ori_target[None])
Exemple #2
0
class ImageTranslator(object):
    """
      The class for GAN based image translation. Include train & val.

      The generator and the discriminator parameters are trained with two
      separate optimizers (``optimizer_G`` / ``optimizer_D``).

      NOTE(review): ``runner_state`` starts as an empty dict, yet ``train``
      reads ``runner_state['epoch']`` and increments ``runner_state['iters']``.
      Those keys are presumably populated by ``RunnerHelper.load_net`` --
      confirm against that helper.
    """
    def __init__(self, configer):
        """
          Build the meters, helpers, network, optimizers and loaders.

          Args:
            configer: project configuration object; it is queried through
              ``configer.get(...)`` and handed to every helper built here.
        """
        self.configer = configer
        self.batch_time = AverageMeter()
        self.data_time = AverageMeter()
        self.train_losses = AverageMeter()
        self.val_losses = AverageMeter()
        self.model_manager = ModelManager(configer)
        self.seg_data_loader = DataLoader(configer)

        self.gan_net = None
        self.train_loader = None
        self.val_loader = None
        # Kept for compatibility with code that may look these up; this
        # class itself only uses the *_G / *_D pairs below.
        self.optimizer = None
        self.scheduler = None
        self.optimizer_G = None
        self.scheduler_G = None
        self.optimizer_D = None
        self.scheduler_D = None
        self.runner_state = dict()

        self._init_model()

    def _init_model(self):
        """
          Create the network, both optimizer/scheduler pairs and the loaders.
        """
        self.gan_net = self.model_manager.gan_model()
        self.gan_net = RunnerHelper.load_net(self, self.gan_net)

        # Split the parameters once instead of once per optimizer.
        params_G, params_D = self._get_parameters()
        self.optimizer_G, self.scheduler_G = Trainer.init(
            params_G, self.configer.get('solver'))
        self.optimizer_D, self.scheduler_D = Trainer.init(
            params_D, self.configer.get('solver'))

        self.train_loader = self.seg_data_loader.get_trainloader()
        self.val_loader = self.seg_data_loader.get_valloader()

    def _get_parameters(self):
        """
          Split the network parameters into generator and discriminator.

          A parameter is assigned to the generator when its name contains
          the (case-sensitive) substring ``'G'``; everything else goes to
          the discriminator.

          Returns:
            A tuple ``(params_G, params_D)`` of parameter lists.
        """
        params_G = []
        params_D = []
        for name, param in dict(self.gan_net.named_parameters()).items():
            if 'G' in name:
                params_G.append(param)
            else:
                params_D.append(param)

        return params_G, params_D

    def train(self):
        """
          Train function of every epoch during train phase.

          Steps both schedulers with the current epoch, then iterates once
          over ``self.train_loader``; logs every ``solver.display_iter``
          iterations and validates every ``solver.test_interval`` iterations.
        """
        self.gan_net.train()
        start_time = time.time()
        # Adjust the learning rate after every epoch.
        self.scheduler_G.step(self.runner_state['epoch'])
        self.scheduler_D.step(self.runner_state['epoch'])
        for data_dict in self.train_loader:
            self.data_time.update(time.time() - start_time)

            # Forward pass.
            out_dict = self.gan_net(data_dict)

            # NOTE(review): both losses come from the same forward pass. If
            # they share graph nodes, the second backward() may fail because
            # the first one released the graph and optimizer_G.step() has
            # already modified parameters in place -- confirm that the
            # network detaches what it has to.
            self.optimizer_G.zero_grad()
            loss_G = out_dict['loss_G'].mean()
            loss_G.backward()
            self.optimizer_G.step()

            self.optimizer_D.zero_grad()
            loss_D = out_dict['loss_D'].mean()
            loss_D.backward()
            self.optimizer_D.step()

            loss = loss_G + loss_D
            self.train_losses.update(loss.item(),
                                     len(DCHelper.tolist(data_dict['meta'])))

            # Update the vars of the train phase.
            self.batch_time.update(time.time() - start_time)
            start_time = time.time()
            self.runner_state['iters'] += 1
            iters = self.runner_state['iters']

            # Print the log info & reset the states.
            display_iter = self.configer.get('solver', 'display_iter')
            if iters % display_iter == 0:
                # The data_time average uses '.3f' like the other timings;
                # the former '3f' was a width spec that printed 6 decimals.
                Log.info(
                    'Train Epoch: {0}\tTrain Iteration: {1}\t'
                    'Time {batch_time.sum:.3f}s / {2}iters, ({batch_time.avg:.3f})\t'
                    'Data load {data_time.sum:.3f}s / {2}iters, ({data_time.avg:.3f})\n'
                    'Learning rate = {3}\tLoss = {loss.val:.8f} (ave = {loss.avg:.8f})\n'
                    .format(self.runner_state['epoch'],
                            iters,
                            display_iter, [
                                RunnerHelper.get_lr(self.optimizer_G),
                                RunnerHelper.get_lr(self.optimizer_D)
                            ],
                            batch_time=self.batch_time,
                            data_time=self.data_time,
                            loss=self.train_losses))
                self.batch_time.reset()
                self.data_time.reset()
                self.train_losses.reset()

            # Stop in the middle of the epoch once the iteration budget is
            # reached (only when the lr schedule is iteration based).
            if self.configer.get('solver', 'lr')['metric'] == 'iters' \
                    and iters == self.configer.get('solver', 'max_iters'):
                break

            # Check to val the current model.
            if iters % self.configer.get('solver', 'test_interval') == 0:
                self.val()

        self.runner_state['epoch'] += 1

    def val(self):
        """
          Validation function during the train phase.

          Accumulates ``loss_G + loss_D`` over ``self.val_loader``, saves the
          network with the averaged loss, logs it, resets the meters and
          switches the network back to train mode.
        """
        self.gan_net.eval()
        start_time = time.time()

        for data_dict in self.val_loader:
            with torch.no_grad():
                # Forward pass.
                out_dict = self.gan_net(data_dict)

            # Compute the loss of the val batch.
            batch_loss = (out_dict['loss_G'].mean().item() +
                          out_dict['loss_D'].mean().item())
            self.val_losses.update(batch_loss,
                                   len(DCHelper.tolist(data_dict['meta'])))
            # Update the vars of the val phase.
            self.batch_time.update(time.time() - start_time)
            start_time = time.time()

        RunnerHelper.save_net(self, self.gan_net, val_loss=self.val_losses.avg)

        # Print the log info & reset the states.
        Log.info('Test Time {batch_time.sum:.3f}s, ({batch_time.avg:.3f})\t'
                 'Loss {loss.avg:.8f}\n'.format(batch_time=self.batch_time,
                                                loss=self.val_losses))
        self.batch_time.reset()
        self.val_losses.reset()
        self.gan_net.train()
Exemple #3
0
class PoseEstimator(object):
    """
      The class for Pose Estimation. Include train & val.

      NOTE(review): ``runner_state`` starts as an empty dict, yet ``train``
      increments ``runner_state['epoch']`` and ``runner_state['iters']``.
      Those keys are presumably populated by ``RunnerHelper.load_net`` --
      confirm against that helper.
    """
    def __init__(self, configer):
        """
          Build the meters, helpers, network, optimizer, loaders and loss.

          Args:
            configer: project configuration object; it is queried through
              ``configer.get(...)`` and handed to every helper built here.
        """
        self.configer = configer
        self.batch_time = AverageMeter()
        self.data_time = AverageMeter()
        self.train_losses = DictAverageMeter()
        self.val_losses = DictAverageMeter()
        self.pose_visualizer = PoseVisualizer(configer)
        self.pose_model_manager = ModelManager(configer)
        self.pose_data_loader = DataLoader(configer)

        # Everything below is filled in by _init_model().
        self.pose_net = None
        self.train_loader = None
        self.val_loader = None
        self.optimizer = None
        self.scheduler = None
        self.pose_loss = None
        self.runner_state = dict()

        self._init_model()

    def _init_model(self):
        """
          Create the network, optimizer/scheduler, data loaders and loss.
        """
        self.pose_net = self.pose_model_manager.get_pose_model()
        self.pose_net = RunnerHelper.load_net(self, self.pose_net)

        self.optimizer, self.scheduler = Trainer.init(
            self._get_parameters(), self.configer.get('solver'))

        self.train_loader = self.pose_data_loader.get_trainloader()
        self.val_loader = self.pose_data_loader.get_valloader()

        self.pose_loss = self.pose_model_manager.get_pose_loss()

    def _get_parameters(self):
        """
          Split the network parameters into two optimizer groups.

          Parameters whose name contains ``'backbone'`` go to the first
          group, all the others to the second one. Both groups use the base
          learning rate and a weight decay of 0.0.

          Returns:
            A list of two dicts with the keys ``'params'``, ``'lr'`` and
            ``'weight_decay'``.
        """
        base_lr = self.configer.get('solver', 'lr')['base_lr']
        backbone_params = []
        other_params = []
        for name, param in dict(self.pose_net.named_parameters()).items():
            if 'backbone' in name:
                backbone_params.append(param)
            else:
                other_params.append(param)

        return [
            {'params': backbone_params, 'lr': base_lr, 'weight_decay': 0.0},
            {'params': other_params, 'lr': base_lr, 'weight_decay': 0.0},
        ]

    def train(self):
        """
          Train function of every epoch during train phase.

          Increments the epoch counter first, then iterates once over
          ``self.train_loader``; logs every ``solver.display_iter``
          iterations and validates every ``solver.test_interval`` iterations.
        """
        self.pose_net.train()
        start_time = time.time()
        # The epoch counter is advanced before the loop in this runner.
        self.runner_state['epoch'] += 1
        for data_dict in self.train_loader:
            Trainer.update(self,
                           warm_list=(0,),
                           solver_dict=self.configer.get('solver'))
            self.data_time.update(time.time() - start_time)
            # Forward pass.
            out = self.pose_net(data_dict)

            # Compute the loss of the train batch & backward.
            loss_dict = self.pose_loss(out)
            loss = loss_dict['loss']
            self.train_losses.update(
                {key: value.item() for key, value in loss_dict.items()},
                data_dict['img'].size(0))

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Update the vars of the train phase.
            self.batch_time.update(time.time() - start_time)
            start_time = time.time()
            self.runner_state['iters'] += 1
            iters = self.runner_state['iters']

            # Print the log info & reset the states.
            display_iter = self.configer.get('solver', 'display_iter')
            if iters % display_iter == 0:
                # The data_time average uses '.3f' like the other timings;
                # the former '3f' was a width spec that printed 6 decimals.
                Log.info(
                    'Train Epoch: {0}\tTrain Iteration: {1}\t'
                    'Time {batch_time.sum:.3f}s / {2}iters, ({batch_time.avg:.3f})\t'
                    'Data load {data_time.sum:.3f}s / {2}iters, ({data_time.avg:.3f})\n'
                    'Learning rate = {4}\tLoss = {3}\n'.format(
                        self.runner_state['epoch'],
                        iters,
                        display_iter,
                        self.train_losses.info(),
                        RunnerHelper.get_lr(self.optimizer),
                        batch_time=self.batch_time,
                        data_time=self.data_time))

                self.batch_time.reset()
                self.data_time.reset()
                self.train_losses.reset()

            # Stop in the middle of the epoch once the iteration budget is
            # reached (only when the lr schedule is iteration based).
            if self.configer.get('solver', 'lr')['metric'] == 'iters' \
                    and iters == self.configer.get('solver', 'max_iters'):
                break

            # Check to val the current model.
            if iters % self.configer.get('solver', 'test_interval') == 0:
                self.val()

    def val(self):
        """
          Validation function during the train phase.

          Accumulates the loss dict over ``self.val_loader``, stores the
          averaged ``'loss'`` in ``runner_state['val_loss']``, saves the
          network, logs the losses, resets the meters and switches the
          network back to train mode.
        """
        self.pose_net.eval()
        start_time = time.time()

        with torch.no_grad():
            for data_dict in self.val_loader:
                # Forward pass.
                out = self.pose_net(data_dict)
                # Compute the loss of the val batch.
                loss_dict = self.pose_loss(out)

                self.val_losses.update(
                    {key: value.item() for key, value in loss_dict.items()},
                    data_dict['img'].size(0))

                # Update the vars of the val phase.
                self.batch_time.update(time.time() - start_time)
                start_time = time.time()

            val_loss = self.val_losses.avg['loss']
            self.runner_state['val_loss'] = val_loss
            RunnerHelper.save_net(self, self.pose_net, val_loss=val_loss)
            # Print the log info & reset the states.
            Log.info(
                'Test Time {batch_time.sum:.3f}s, ({batch_time.avg:.3f})\t'
                'Loss {0}\n'.format(self.val_losses.info(),
                                    batch_time=self.batch_time))
            self.batch_time.reset()
            self.val_losses.reset()
            self.pose_net.train()
Exemple #4
0
class YOLOv3(object):
    """
      The class for YOLO v3. Include train & val.

      NOTE(review): ``runner_state`` starts as an empty dict, yet ``train``
      increments ``runner_state['epoch']`` and ``runner_state['iters']``.
      Those keys are presumably populated by ``RunnerHelper.load_net`` --
      confirm against that helper.
    """
    def __init__(self, configer):
        """
          Build the meters, helpers, network, optimizer and loaders.

          Args:
            configer: project configuration object; it is queried through
              ``configer.get(...)`` and handed to every helper built here.
        """
        self.configer = configer
        self.batch_time = AverageMeter()
        self.data_time = AverageMeter()
        self.train_losses = AverageMeter()
        self.val_losses = AverageMeter()
        self.det_visualizer = DetVisualizer(configer)
        self.det_model_manager = ModelManager(configer)
        self.det_data_loader = DataLoader(configer)
        self.det_running_score = DetRunningScore(configer)

        # Everything below is filled in by _init_model().
        self.det_net = None
        self.train_loader = None
        self.val_loader = None
        self.optimizer = None
        self.scheduler = None
        self.runner_state = dict()

        self._init_model()

    def _init_model(self):
        """
          Create the network, optimizer/scheduler and the data loaders.
        """
        self.det_net = self.det_model_manager.object_detector()
        self.det_net = RunnerHelper.load_net(self, self.det_net)

        self.optimizer, self.scheduler = Trainer.init(
            self._get_parameters(), self.configer.get('solver'))

        self.train_loader = self.det_data_loader.get_trainloader()
        self.val_loader = self.det_data_loader.get_valloader()

    def _get_parameters(self):
        """
          Split the network parameters into two optimizer groups.

          Parameters whose name contains ``'backbone'`` use the base
          learning rate; all the others use ten times the base rate.

          Returns:
            A list of two dicts with the keys ``'params'`` and ``'lr'``.
        """
        base_lr = self.configer.get('solver', 'lr')['base_lr']
        lr_1 = []
        lr_10 = []
        for name, param in dict(self.det_net.named_parameters()).items():
            if 'backbone' in name:
                lr_1.append(param)
            else:
                lr_10.append(param)

        return [
            {'params': lr_1, 'lr': base_lr},
            {'params': lr_10, 'lr': base_lr * 10.},
        ]

    def train(self):
        """
          Train function of every epoch during train phase.

          Increments the epoch counter first, then iterates once over
          ``self.train_loader``; logs every ``solver.display_iter``
          iterations and validates every ``solver.test_interval`` iterations.
        """
        self.det_net.train()
        start_time = time.time()
        # The epoch counter is advanced before the loop in this runner.
        self.runner_state['epoch'] += 1

        for data_dict in self.train_loader:
            Trainer.update(self,
                           warm_list=(0, ),
                           warm_lr_list=(self.configer.get('solver',
                                                           'lr')['base_lr'], ),
                           solver_dict=self.configer.get('solver'))

            self.data_time.update(time.time() - start_time)
            # Forward pass; the network returns its own loss.
            out_dict = self.det_net(data_dict)
            # Compute the loss of the train batch & backward.
            loss = out_dict['loss'].mean()
            self.train_losses.update(loss.item(),
                                     len(DCHelper.tolist(data_dict['meta'])))
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Update the vars of the train phase.
            self.batch_time.update(time.time() - start_time)
            start_time = time.time()
            self.runner_state['iters'] += 1
            iters = self.runner_state['iters']

            # Print the log info & reset the states.
            display_iter = self.configer.get('solver', 'display_iter')
            if iters % display_iter == 0:
                # The data_time average uses '.3f' like the other timings;
                # the former '3f' was a width spec that printed 6 decimals.
                Log.info(
                    'Train Epoch: {0}\tTrain Iteration: {1}\t'
                    'Time {batch_time.sum:.3f}s / {2}iters, ({batch_time.avg:.3f})\t'
                    'Data load {data_time.sum:.3f}s / {2}iters, ({data_time.avg:.3f})\n'
                    'Learning rate = {3}\tLoss = {loss.val:.8f} (ave = {loss.avg:.8f})\n'
                    .format(self.runner_state['epoch'],
                            iters,
                            display_iter,
                            RunnerHelper.get_lr(self.optimizer),
                            batch_time=self.batch_time,
                            data_time=self.data_time,
                            loss=self.train_losses))
                self.batch_time.reset()
                self.data_time.reset()
                self.train_losses.reset()

            # Stop in the middle of the epoch once the iteration budget is
            # reached (only when the lr schedule is iteration based).
            if self.configer.get('solver', 'lr')['metric'] == 'iters' \
                    and iters == self.configer.get('solver', 'max_iters'):
                break

            # Check to val the current model.
            if iters % self.configer.get('solver', 'test_interval') == 0:
                self.val()

    def val(self):
        """
          Validation function during the train phase.

          Accumulates the loss and the detection score over
          ``self.val_loader``, saves the network, logs the loss and the mAP,
          resets the meters and switches the network back to train mode.
        """
        self.det_net.eval()
        start_time = time.time()
        with torch.no_grad():
            for data_dict in self.val_loader:
                # Forward pass.
                out_dict = self.det_net(data_dict)

                # Compute the loss of the val batch.
                loss = out_dict['loss'].mean()
                # Unpack the per-sample metadata once per batch.
                metas = DCHelper.tolist(data_dict['meta'])
                self.val_losses.update(loss.item(), len(metas))

                batch_detections = YOLOv3Test.decode(out_dict['dets'],
                                                     self.configer, metas)
                batch_pred_bboxes = self.__get_object_list(batch_detections)

                self.det_running_score.update(
                    batch_pred_bboxes,
                    [item['ori_bboxes'] for item in metas],
                    [item['ori_labels'] for item in metas])

                # Update the vars of the val phase.
                self.batch_time.update(time.time() - start_time)
                start_time = time.time()

            RunnerHelper.save_net(self,
                                  self.det_net,
                                  iters=self.runner_state['iters'])
            # Print the log info & reset the states.
            Log.info(
                'Test Time {batch_time.sum:.3f}s, ({batch_time.avg:.3f})\t'
                'Loss {loss.avg:.8f}\n'.format(batch_time=self.batch_time,
                                               loss=self.val_losses))
            Log.info('Val mAP: {}'.format(self.det_running_score.get_mAP()))
            self.det_running_score.reset()
            self.batch_time.reset()
            self.val_losses.reset()
            self.det_net.train()

    def __get_object_list(self, batch_detections):
        """
          Convert decoded detections into plain Python lists.

          Args:
            batch_detections: one entry per image; each entry is either
              ``None`` or an iterable of 7-element rows
              ``(x1, y1, x2, y2, conf, cls_conf, cls_pred)`` whose elements
              expose ``.item()``.

          Returns:
            One list per image holding ``[x1, y1, x2, y2, cls_pred, conf]``
            rows, with ``conf`` rounded to two decimals and ``cls_pred``
            cast to int. Images without detections yield an empty list.
        """
        batch_pred_bboxes = list()
        for detections in batch_detections:
            object_list = list()
            if detections is not None:
                for x1, y1, x2, y2, conf, _cls_conf, cls_pred in detections:
                    cf = float('%.2f' % conf.item())
                    object_list.append([
                        x1.item(),
                        y1.item(),
                        x2.item(),
                        y2.item(),
                        int(cls_pred.item()),
                        cf,
                    ])

            batch_pred_bboxes.append(object_list)

        return batch_pred_bboxes
Exemple #5
0
class FaceGAN(object):
    """
      The class for the face GAN. Include train & val.

      NOTE(review): ``runner_state`` starts as an empty dict, yet ``train``
      increments ``runner_state['iters']`` and ``runner_state['epoch']``.
      Those keys are presumably populated by ``RunnerHelper.load_net`` --
      confirm against that helper.
    """
    def __init__(self, configer):
        """
          Build the meters, helpers, network, optimizer and loaders.

          Args:
            configer: project configuration object; it is queried through
              ``configer.get(...)`` and handed to every helper built here.
        """
        self.configer = configer
        self.batch_time = AverageMeter()
        self.data_time = AverageMeter()
        self.train_losses = AverageMeter()
        self.val_losses = AverageMeter()
        self.model_manager = ModelManager(configer)
        self.seg_data_loader = DataLoader(configer)

        # Everything below is filled in by _init_model().
        self.gan_net = None
        self.train_loader = None
        self.val_loader = None
        self.optimizer = None
        self.scheduler = None
        self.runner_state = dict()

        self._init_model()

    def _init_model(self):
        """
          Create the network, optimizer/scheduler and the data loaders.
        """
        self.gan_net = self.model_manager.gan_model()
        self.gan_net = RunnerHelper.load_net(self, self.gan_net)

        self.optimizer, self.scheduler = Trainer.init(
            self._get_parameters(), self.configer.get('solver'))

        self.train_loader = self.seg_data_loader.get_trainloader()
        self.val_loader = self.seg_data_loader.get_valloader()

    def _get_parameters(self):
        """
          Return all the parameters of the network as a single group.
        """
        return self.gan_net.parameters()

    def train(self):
        """
          Train function of every epoch during train phase.

          Iterates once over ``self.train_loader``; logs every
          ``solver.display_iter`` iterations and validates every
          ``solver.test_interval`` iterations.
        """
        self.gan_net.train()
        start_time = time.time()
        for data_dict in self.train_loader:
            Trainer.update(self, solver_dict=self.configer.get('solver'))
            self.data_time.update(time.time() - start_time)

            # Forward pass; the network returns its own loss.
            out_dict = self.gan_net(data_dict)
            loss = out_dict['loss'].mean()
            self.train_losses.update(loss.item(),
                                     len(DCHelper.tolist(data_dict['meta'])))
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Update the vars of the train phase.
            self.batch_time.update(time.time() - start_time)
            start_time = time.time()
            self.runner_state['iters'] += 1
            iters = self.runner_state['iters']

            # Print the log info & reset the states.
            display_iter = self.configer.get('solver', 'display_iter')
            if iters % display_iter == 0:
                # The data_time average uses '.3f' like the other timings;
                # the former '3f' was a width spec that printed 6 decimals.
                Log.info(
                    'Train Epoch: {0}\tTrain Iteration: {1}\t'
                    'Time {batch_time.sum:.3f}s / {2}iters, ({batch_time.avg:.3f})\t'
                    'Data load {data_time.sum:.3f}s / {2}iters, ({data_time.avg:.3f})\n'
                    'Learning rate = {3}\tLoss = {loss.val:.8f} (ave = {loss.avg:.8f})\n'
                    .format(self.runner_state['epoch'],
                            iters,
                            display_iter,
                            RunnerHelper.get_lr(self.optimizer),
                            batch_time=self.batch_time,
                            data_time=self.data_time,
                            loss=self.train_losses))
                self.batch_time.reset()
                self.data_time.reset()
                self.train_losses.reset()

            # Stop in the middle of the epoch once the iteration budget is
            # reached (only when the lr schedule is iteration based).
            if self.configer.get('solver', 'lr')['metric'] == 'iters' \
                    and iters == self.configer.get('solver', 'max_iters'):
                break

            # Check to val the current model.
            if iters % self.configer.get('solver', 'test_interval') == 0:
                self.val()

        self.runner_state['epoch'] += 1

    def val(self, data_loader=None):
        """
          Validation function during the train phase.

          Args:
            data_loader: optional loader to validate on; defaults to
              ``self.val_loader``.

          For every batch the loss is accumulated and the ``featA`` /
          ``featB`` outputs are scored against the ``labelA`` / ``labelB``
          metadata with ``FaceGANTest.decode``; the two metrics are logged
          per batch. Afterwards the network is saved with the averaged loss,
          the meters are reset and the network goes back to train mode.
        """
        self.gan_net.eval()
        start_time = time.time()

        if data_loader is None:
            data_loader = self.val_loader

        for data_dict in data_loader:
            with torch.no_grad():
                # Forward pass.
                out_dict = self.gan_net(data_dict)

            # Unpack the per-sample metadata once per batch.
            meta_list = DCHelper.tolist(data_dict['meta'])
            # Compute the loss of the val batch.
            self.val_losses.update(out_dict['loss'].mean().item(),
                                   len(meta_list))

            probe_features = []
            gallery_features = []
            probe_labels = []
            gallery_labels = []
            for idx, meta in enumerate(meta_list):
                gallery_features.append(out_dict['featB'][idx].cpu().numpy())
                gallery_labels.append(meta['labelB'])
                probe_features.append(out_dict['featA'][idx].cpu().numpy())
                probe_labels.append(meta['labelA'])

            rank_1, vr_far_001 = FaceGANTest.decode(probe_features,
                                                    gallery_features,
                                                    probe_labels,
                                                    gallery_labels)
            Log.info('Rank1 accuracy is {}'.format(rank_1))
            Log.info('VR@FAR=0.1% accuracy is {}'.format(vr_far_001))
            # Update the vars of the val phase.
            self.batch_time.update(time.time() - start_time)
            start_time = time.time()

        RunnerHelper.save_net(self, self.gan_net, val_loss=self.val_losses.avg)

        # Print the log info & reset the states.
        Log.info('Test Time {batch_time.sum:.3f}s, ({batch_time.avg:.3f})\t'
                 'Loss {loss.avg:.8f}\n'.format(batch_time=self.batch_time,
                                                loss=self.val_losses))
        self.batch_time.reset()
        self.val_losses.reset()
        self.gan_net.train()
Exemple #6
0
class ImageClassifier(object):
    """
      The class for the training phase of Image classification.
    """
    def __init__(self, configer):
        """
          Set up the runner from the given configer: meters, model/data
          factories, the network, the optimizer/scheduler pair, the data
          loaders and the loss.
        """
        self.configer = configer
        self.runner_state = {}

        # Timing meters and per-key loss meters for the train / val phases.
        self.batch_time = AverageMeter()
        self.data_time = AverageMeter()
        self.train_losses = DictAverageMeter()
        self.val_losses = DictAverageMeter()

        # Factories for the model / data and the running classification score.
        self.cls_model_manager = ModelManager(configer)
        self.cls_data_loader = DataLoader(configer)
        self.running_score = ClsRunningScore(configer)

        # The network is created first and then handed to RunnerHelper.load_net
        # together with this runner.
        self.cls_net = self.cls_model_manager.get_cls_model()
        # solver_dict has to exist before _get_parameters() runs, since that
        # method reads it.
        self.solver_dict = self.configer.get('solver')
        self.cls_net = RunnerHelper.load_net(self, self.cls_net)

        param_groups = self._get_parameters()
        self.optimizer, self.scheduler = Trainer.init(param_groups,
                                                      self.solver_dict)

        self.train_loader = self.cls_data_loader.get_trainloader()
        self.val_loader = self.cls_data_loader.get_valloader()
        self.loss = self.cls_model_manager.get_cls_loss()

    def _init_model(self):
        """
          (Re)build the network, the optimizer/scheduler pair, the data
          loaders and the loss.

          The loss is stored as ``self.loss`` -- the attribute that train()
          and val() call -- and additionally as ``self.ce_loss`` so that any
          code using the earlier attribute name keeps working.

          Note: _get_parameters() reads ``self.solver_dict``, so that
          attribute must already be set when this method is called.
        """
        self.cls_net = self.cls_model_manager.get_cls_model()
        self.cls_net = RunnerHelper.load_net(self, self.cls_net)
        self.optimizer, self.scheduler = Trainer.init(
            self._get_parameters(), self.configer.get('solver'))

        self.train_loader = self.cls_data_loader.get_trainloader()
        self.val_loader = self.cls_data_loader.get_valloader()

        # train()/val() use self.loss; ce_loss is kept as a compatible alias.
        self.loss = self.cls_model_manager.get_cls_loss()
        self.ce_loss = self.loss

    def _get_parameters(self):
        if self.solver_dict.get('optim.wdall', default=True):
            lr_1 = []
            lr_2 = []
            params_dict = dict(self.cls_net.named_parameters())
            for key, value in params_dict.items():
                if value.requires_grad:
                    if 'backbone' in key:
                        if self.configer.get('solver.lr.bb_lr_scale') == 0.0:
                            value.requires_grad = False
                        else:
                            lr_1.append(value)
                    else:
                        lr_2.append(value)

            params = [{
                'params':
                lr_1,
                'lr':
                self.solver_dict['lr']['base_lr'] *
                self.configer.get('solver.lr.bb_lr_scale')
            }, {
                'params': lr_2,
                'lr': self.solver_dict['lr']['base_lr']
            }]
        else:
            no_decay_list = []
            decay_list = []
            no_decay_name = []
            decay_name = []
            for m in self.cls_net.modules():
                if (hasattr(m, 'groups') and m.groups > 1) or isinstance(m, torch.nn.BatchNorm2d) \
                        or m.__class__.__name__ == 'GL':
                    no_decay_list += m.parameters(recurse=False)
                    for name, p in m.named_parameters(recurse=False):
                        no_decay_name.append(m.__class__.__name__ + name)
                else:
                    for name, p in m.named_parameters(recurse=False):
                        if 'bias' in name:
                            no_decay_list.append(p)
                            no_decay_name.append(m.__class__.__name__ + name)
                        else:
                            decay_list.append(p)
                            decay_name.append(m.__class__.__name__ + name)
            Log.info('no decay list = {}'.format(no_decay_name))
            Log.info('decay list = {}'.format(decay_name))
            params = [{
                'params': no_decay_list,
                'weight_decay': 0
            }, {
                'params': decay_list
            }]

        return params

    def train(self):
        """
          Train function of every epoch during train phase.

          Runs one pass over ``self.train_loader``. Every iteration calls
          Trainer.update, performs forward / backward / optimizer step,
          updates the meters and increments ``runner_state['iters']``. Then,
          on the intervals configured in the solver dict, it logs the
          statistics, saves the network (only when 'local_rank' is 0) and
          runs val().

          When the lr metric is 'iters' the loop stops as soon as
          'max_iters' is reached; that check comes before the save / val
          checks, so neither of them runs on that last iteration.
        """
        self.cls_net.train()
        start_time = time.time()
        # One more epoch has started.
        self.runner_state['epoch'] += 1
        for data_dict in self.train_loader:
            # The two warm lrs line up with the two lr groups returned by
            # _get_parameters() (backbone first, then the rest).
            # NOTE(review): presumably Trainer.update applies the lr
            # schedule / warm-up for this iteration -- confirm in Trainer.
            Trainer.update(
                self,
                warm_list=(0, 1),
                warm_lr_list=(self.solver_dict['lr']['base_lr'] *
                              self.configer.get('solver.lr.bb_lr_scale'),
                              self.solver_dict['lr']['base_lr']),
                solver_dict=self.solver_dict)
            self.data_time.update(time.time() - start_time)
            data_dict = RunnerHelper.to_device(self, data_dict)
            # Forward pass.
            out = self.cls_net(data_dict)
            # Compute the loss of the train batch & backward.
            loss_dict = self.loss(out)

            loss = loss_dict['loss']
            # Record every entry of the loss dict, weighted by batch size.
            self.train_losses.update(
                {key: value.item()
                 for key, value in loss_dict.items()},
                data_dict['img'].size(0))
            self.optimizer.zero_grad()
            loss.backward()
            if self.configer.get('network', 'clip_grad', default=False):
                RunnerHelper.clip_grad(self.cls_net, 10.)

            self.optimizer.step()

            # Update the vars of the train phase.
            self.batch_time.update(time.time() - start_time)
            start_time = time.time()
            self.runner_state['iters'] += 1

            # Print the log info & reset the states.
            if self.runner_state['iters'] % self.solver_dict[
                    'display_iter'] == 0:
                Log.info(
                    'Train Epoch: {0}\tTrain Iteration: {1}\t'
                    'Time {batch_time.sum:.3f}s / {2}iters, ({batch_time.avg:.3f})\t'
                    'Data load {data_time.sum:.3f}s / {2}iters, ({data_time.avg:.3f})\n'
                    'Learning rate = {4}\tLoss = {3}\n'.format(
                        self.runner_state['epoch'],
                        self.runner_state['iters'],
                        self.solver_dict['display_iter'],
                        self.train_losses.info(),
                        RunnerHelper.get_lr(self.optimizer),
                        batch_time=self.batch_time,
                        data_time=self.data_time))

                self.batch_time.reset()
                self.data_time.reset()
                self.train_losses.reset()

            # Stop once the iteration budget is used up.
            if self.solver_dict['lr'][
                    'metric'] == 'iters' and self.runner_state[
                        'iters'] == self.solver_dict['max_iters']:
                break

            # Only the process with local rank 0 writes the checkpoint.
            if self.runner_state['iters'] % self.solver_dict[
                    'save_iters'] == 0 and self.configer.get(
                        'local_rank') == 0:
                RunnerHelper.save_net(self, self.cls_net)

            # Check to val the current model.
            if self.runner_state['iters'] % self.solver_dict[
                    'test_interval'] == 0:
                self.val()

    def val(self):
        """
          Validation function during the train phase.

          Switches the network to eval mode, runs every batch of
          ``self.val_loader`` under ``torch.no_grad()``, and accumulates the
          losses and the running score. It then saves the network, logs the
          elapsed time, the losses and the top-1 / top-3 / top-5 accuracy,
          resets the meters it used and puts the network back in train mode.

          Note: ``self.batch_time`` is the same meter train() uses, so it is
          overwritten and reset here.
        """
        self.cls_net.eval()
        start_time = time.time()
        with torch.no_grad():
            for data_dict in self.val_loader:
                # Forward pass.
                data_dict = RunnerHelper.to_device(self, data_dict)
                out = self.cls_net(data_dict)
                loss_dict = self.loss(out)
                out_dict, label_dict, _ = RunnerHelper.gather(self, out)
                self.running_score.update(out_dict, label_dict)
                # Record every entry of the loss dict, weighted by batch size.
                self.val_losses.update(
                    {key: value.item()
                     for key, value in loss_dict.items()},
                    data_dict['img'].size(0))

                # Update the vars of the val phase.
                self.batch_time.update(time.time() - start_time)
                start_time = time.time()

            RunnerHelper.save_net(self, self.cls_net)
            # Print the log info & reset the states.
            Log.info('Test Time {batch_time.sum:.3f}s'.format(
                batch_time=self.batch_time))
            Log.info('TestLoss = {}'.format(self.val_losses.info()))
            Log.info('Top1 ACC = {}'.format(
                RunnerHelper.dist_avg(self,
                                      self.running_score.get_top1_acc())))
            Log.info('Top3 ACC = {}'.format(
                RunnerHelper.dist_avg(self,
                                      self.running_score.get_top3_acc())))
            Log.info('Top5 ACC = {}'.format(
                RunnerHelper.dist_avg(self,
                                      self.running_score.get_top5_acc())))
            # Reset each meter used by this phase exactly once.
            self.batch_time.reset()
            self.val_losses.reset()
            self.running_score.reset()
            self.cls_net.train()