Ejemplo n.º 1
0
    def forward(self, inp):
        """Forward ``inp`` through the task model and then the rotation classifier.

        The result dict returned by the task model has to provide the keys
        'pred', 'activated_pred' and 'ssls4l_rc_inp'. A missing key is reported
        through ``logger.log_err``.

        Args:
            inp: input handed unchanged to ``self.task_model.forward``.

        Returns:
            A ``(resulter, debugger)`` pair. ``resulter`` carries the task
            model's 'pred' and 'activated_pred' plus 'rotation' (the output of
            the rotation classifier). ``debugger`` is always an empty dict.
        """
        task_out, _ = self.task_model.forward(inp)

        # both the raw and the activated predictions are required
        has_both_preds = ('pred' in task_out.keys()
                          and 'activated_pred' in task_out.keys())
        if not has_both_preds:
            logger.log_err(
                'In SSL_S4L, the \'resulter\' dict returned by the task model should contain the following keys:\n'
                '   (1) \'pred\'\t=>\tunactivated task predictions\n'
                '   (2) \'activated_pred\'\t=>\tactivated task predictions\n'
                'We need both of them since some losses include the activation functions,\n'
                'e.g., the CrossEntropyLoss has contained SoftMax\n')

        # the rotation classifier needs its own input exposed by the task model
        if not ('ssls4l_rc_inp' in task_out.keys()):
            logger.log_err(
                'In SSL_S4L, the \'resulter\' dict returned by the task model should contain the key:\n'
                '    \'ssls4l_rc_inp\'\t=>\tinputs of the rotation classifier (a 4-dim tensor)\n'
                'It can be the feature map encoded by the task model or the output of the task model\n'
                'Please add the key \'ssls4l_rc_inp\' in your task model\'s resulter\n'
            )

        rotation_out = self.rotation_classifier.forward(
            tool.dict_value(task_out, 'ssls4l_rc_inp'))

        outputs = {
            'pred': tool.dict_value(task_out, 'pred'),
            'activated_pred': tool.dict_value(task_out, 'activated_pred'),
            'rotation': rotation_out,
        }
        return outputs, {}
Ejemplo n.º 2
0
    def _validate(self, data_loader, epoch):
        """Run one validation pass and log the task loss and task metrics.

        Args:
            data_loader: iterable yielding ``(inp, gt)`` batches; ``len()`` of it
                is used in the progress log.
            epoch: current epoch index; logged as ``epoch + 1`` and forwarded
                as-is to ``self._visualize``.
        """
        self.meters.reset()
        self.model.eval()

        for step, (inp, gt) in enumerate(data_loader):
            tic = time.time()

            inp, gt = self._batch_prehandle(inp, gt)
            if len(gt) > 1 and step == 0:
                self._data_err()

            # the third argument is presumably 'is_unlabeled' -- TODO confirm
            outputs, _ = self.model.forward(inp, gt, False)
            # looked up exactly as before; the unactivated prediction is not used below
            pred = tool.dict_value(outputs, 'pred')
            activated_pred = tool.dict_value(outputs, 'activated_pred')

            # the loss is computed by the wrapped model; a missing key is an error
            task_loss = tool.dict_value(outputs, 'task_loss', err=True).mean()
            self.meters.update('task_loss', task_loss.data)

            self.task_func.metrics(
                activated_pred, gt, inp, self.meters, id_str='task')

            self.meters.update('batch_time', time.time() - tic)
            if step % self.args.log_freq == 0:
                progress = (
                    'step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                    '  task-{3}\t=>\t'
                    'task-loss: {meters[task_loss]:.6f}\t')
                logger.log_info(progress.format(
                    epoch + 1, step, len(data_loader), self.args.task,
                    meters=self.meters))

            # visualize the first sample of the batch
            if self.args.visualize and step % self.args.visual_freq == 0:
                self._visualize(
                    epoch, step, False,
                    func.split_tensor_tuple(inp, 0, 1, reduce_dim=True),
                    func.split_tensor_tuple(
                        activated_pred, 0, 1, reduce_dim=True),
                    func.split_tensor_tuple(gt, 0, 1, reduce_dim=True))

        # collect every metric meter whose name starts with 'task'
        task_summary = ''
        for key in sorted(self.meters.keys()):
            if self.task_func.METRIC_STR not in key:
                continue
            if key.startswith('task'):
                task_summary += '{0}: {1:.6}\t'.format(key, self.meters[key])

        logger.log_info('Validation metrics:\n task-metrics\t=>\t{0}\n'.format(
            task_summary.replace('_', '-')))
Ejemplo n.º 3
0
    def forward(self, inp, gt, is_unlabeled):
        """Forward the main model and, for unlabeled data, the auxiliary decoders.

        Args:
            inp: input handed to ``self.main_model.forward`` (and to the task
                criterion for labeled data).
            gt: ground truth used by the task criterion when ``is_unlabeled``
                is false.
            is_unlabeled: when true, no task loss is computed; the consistency
                loss is computed if ``self.args.unlabeled_batch_size > 0``.

        Returns:
            A ``(resulter, debugger)`` pair. ``resulter`` has the keys 'pred',
            'activated_pred', 'task_loss', 'ul_ad_preds' and 'cons_loss'; the
            last three are ``None`` when they do not apply. ``debugger`` is an
            empty dict.
        """
        resulter, debugger = {}, {}

        main_out, _ = self.main_model.forward(inp)

        if not ('pred' in main_out.keys() and 'activated_pred' in main_out.keys()):
            logger.log_err('In SSL_CCT, the \'resulter\' dict returned by the task model should contain the following keys:\n'
                           '   (1) \'pred\'\t=>\tunactivated task predictions\n'
                           '   (2) \'activated_pred\'\t=>\tactivated task predictions\n'
                           'We need both of them since some losses include the activation functions,\n'
                           'e.g., the CrossEntropyLoss has contained SoftMax\n')

        resulter['pred'] = tool.dict_value(main_out, 'pred')
        resulter['activated_pred'] = tool.dict_value(main_out, 'activated_pred')

        # only a single prediction (output) per task model is supported
        num_preds = len(resulter['pred'])
        if not num_preds == len(resulter['activated_pred']) == 1:
            logger.log_err('This implementation of SSL_CCT only support the task model with only one prediction (output). \n'
                           'However, there are {0} predictions.\n'.format(num_preds))

        # task loss exists for labeled data only
        if is_unlabeled:
            resulter['task_loss'] = None
        else:
            resulter['task_loss'] = torch.mean(
                self.task_criterion.forward(resulter['pred'], gt, inp))

        # nothing more to do unless unlabeled data is being processed
        if not (is_unlabeled and self.args.unlabeled_batch_size > 0):
            resulter['ul_ad_preds'] = None
            resulter['cons_loss'] = None
            return resulter, debugger

        if not ('sslcct_ad_inp' in main_out.keys()):
            logger.log_err('In SSL_CCT, the \'resulter\' dict returned by the task model should contain the key:\n'
                           '    \'sslcct_ad_inp\'\t=>\tinputs of the auxiliary decoders (a 4-dim tensor)\n'
                           'It is the feature map encoded by the task model\n'
                           'Please add the key \'sslcct_ad_inp\' in your task model\'s resulter\n'
                           'Note that for different task models, the shape of \'sslcct_ad_inp\' may be different\n')

        ad_inp = tool.dict_value(main_out, 'sslcct_ad_inp')
        main_pred = resulter['pred'][0].detach()

        # forward every auxiliary decoder on the shared features
        raw_ad_preds = [
            decoder.forward(ad_inp, pred_of_main_decoder=main_pred)
            for decoder in self.auxiliary_decoders
        ]
        # the un-resized decoder outputs are what gets returned
        resulter['ul_ad_preds'] = raw_ad_preds

        # consistency target: the detached activated prediction of the main model
        target = resulter['activated_pred'][0].detach()
        resized_ad_preds = [
            F.interpolate(ad_pred,
                          size=(target.shape[2], target.shape[3]),
                          mode='bilinear')
            for ad_pred in raw_ad_preds
        ]
        activated_ad_preds = self.ad_activation_func(resized_ad_preds)

        per_decoder_losses = [
            self.cons_criterion.forward(activated_ad_pred, target)
            for activated_ad_pred in activated_ad_preds
        ]
        # average over elements, then over the auxiliary decoders
        resulter['cons_loss'] = torch.mean(sum(per_decoder_losses)) / len(activated_ad_preds)

        return resulter, debugger
Ejemplo n.º 4
0
    def _load_checkpoint(self):
        """Restore both models, the student optimizer and its lrer from ``self.args.resume``.

        The checkpoint's 'algorithm' entry is compared with ``self.NAME``; a
        mismatch is reported through ``logger.log_err``.

        Returns:
            The value stored under 'epoch' in the checkpoint.
        """
        # NOTE(review): torch.load unpickles the file -- only resume from trusted checkpoints.
        state = torch.load(self.args.resume)

        saved_algorithm = tool.dict_value(state, 'algorithm', default='unknown')
        if not (saved_algorithm == self.NAME):
            logger.log_err('Unmatched ssl algorithm format in checkpoint => required: {0} - given: {1}\n'
                           .format(self.NAME, saved_algorithm))

        # each attribute is stored in the checkpoint under its own name
        for name in ('s_model', 't_model', 's_optimizer', 's_lrer'):
            getattr(self, name).load_state_dict(state[name])

        return state['epoch']
Ejemplo n.º 5
0
    def _load_checkpoint(self):
        """Restore the model, optimizer and lrer from ``self.args.resume``.

        The checkpoint's 'algorithm' entry is compared with ``self.NAME``; a
        mismatch is reported through ``logger.log_err``. After loading, the
        ``main_model`` and ``auxiliary_decoders`` references are re-bound from
        ``self.model.module``.

        Returns:
            The value stored under 'epoch' in the checkpoint.
        """
        # NOTE(review): torch.load unpickles the file -- only resume from trusted checkpoints.
        state = torch.load(self.args.resume)

        saved_algorithm = tool.dict_value(state, 'algorithm', default='unknown')
        if not (saved_algorithm == self.NAME):
            logger.log_err('Unmatched SSL algorithm format in checkpoint => required: {0} - given: {1}\n'
                           .format(self.NAME, saved_algorithm))

        # each attribute is stored in the checkpoint under its own name
        for name in ('model', 'optimizer', 'lrer'):
            getattr(self, name).load_state_dict(state[name])

        # refresh the shortcuts to the sub-modules of the wrapped model
        wrapped = self.model.module
        self.main_model = wrapped.main_model
        self.auxiliary_decoders = wrapped.auxiliary_decoders

        return state['epoch']
Ejemplo n.º 6
0
    def _load_checkpoint(self):
        """Restore the model, optimizer and lrer from ``self.args.resume``.

        The checkpoint's 'algorithm' entry is compared with ``self.NAME``; a
        mismatch is reported through ``logger.log_err``. After loading, the
        ``task_model`` and ``rotation_classifier`` references are re-bound from
        ``self.model.module``.

        Returns:
            The value stored under 'epoch' in the checkpoint.
        """
        # NOTE(review): torch.load unpickles the file -- only resume from trusted checkpoints.
        state = torch.load(self.args.resume)

        saved_algorithm = tool.dict_value(state, 'algorithm', default='unknown')
        if not (saved_algorithm == self.NAME):
            logger.log_err(
                'Unmatched ssl algorithm format in checkpoint => required: {0} - given: {1}\n'
                .format(self.NAME, saved_algorithm))

        # each attribute is stored in the checkpoint under its own name
        for name in ('model', 'optimizer', 'lrer'):
            getattr(self, name).load_state_dict(state[name])

        # refresh the shortcuts to the sub-modules of the wrapped model
        wrapped = self.model.module
        self.task_model = wrapped.task_model
        self.rotation_classifier = wrapped.rotation_classifier

        return state['epoch']
Ejemplo n.º 7
0
    def _train(self, data_loader, epoch):
        """Train the two task models ('l' and 'r') and the flaw detector for one epoch.

        Each iteration runs three phases:
          step-0: forward both task models under ``no_grad``, forward the flaw
                  detector on their first activated prediction and, when
                  ``self.args.ssl_mode`` is ``MODE_GCT`` or ``MODE_DC``, build the
                  pseudo ground truth and masks via ``self.dcgt_generator``;
          step-1: with the flaw detector's parameters frozen, compute each task
                  model's loss through ``self._task_model_iter`` and step its
                  optimizer;
          step-2: unfreeze the flaw detector and train it on the first ``lbs``
                  samples of each batch.

        Args:
            data_loader: iterable yielding ``(inp, gt)`` batches; ``len()`` of it
                is used for the ramp-up schedule and the progress log.
            epoch: current epoch index, used for the ramp-up schedule and logging.
        """
        self.meters.reset()
        lbs = self.args.labeled_batch_size

        self.l_model.train()
        self.r_model.train()
        self.fd_model.train()

        # both 'inp' and 'gt' are tuples
        for idx, (inp, gt) in enumerate(data_loader):
            timer = time.time()

            # the batch is split into one (inp, gt) pair per task model
            (l_inp, l_gt), (r_inp, r_gt) = self._batch_prehandle(inp, gt)
            if len(l_gt) == len(r_gt) > 1 and idx == 0:
                self._inp_warn()

            # calculate the ramp-up coefficient of the dynamic consistency constraint
            cur_steps = len(data_loader) * epoch + idx
            total_steps = len(data_loader) * self.args.dc_rampup_epochs
            dc_rampup_scale = func.sigmoid_rampup(cur_steps, total_steps)

            # -----------------------------------------------------------------------------
            # step-0: pre-forwarding to save GPU memory
            #   - forward the task models and the flaw detector
            #   - generate pseudo ground truth for the unlabeled data if the dynamic
            #     consistency constraint is enabled
            # -----------------------------------------------------------------------------
            with torch.no_grad():
                l_resulter, l_debugger = self.l_model.forward(l_inp)
                l_activated_pred = tool.dict_value(l_resulter,
                                                   'activated_pred')
                r_resulter, r_debugger = self.r_model.forward(r_inp)
                r_activated_pred = tool.dict_value(r_resulter,
                                                   'activated_pred')

            # 'l_flawmap' and 'r_flawmap' will be used in step-2
            # (forwarded outside 'no_grad' because 'fd_loss.backward()' runs through them)
            fd_resulter, fd_debugger = self.fd_model.forward(
                l_inp, l_activated_pred[0])
            l_flawmap = tool.dict_value(fd_resulter, 'flawmap')
            fd_resulter, fd_debugger = self.fd_model.forward(
                r_inp, r_activated_pred[0])
            r_flawmap = tool.dict_value(fd_resulter, 'flawmap')

            # stay None unless the ssl mode below enables the dynamic consistency constraint
            l_dc_gt, r_dc_gt = None, None
            l_fc_mask, r_fc_mask = None, None

            # generate the pseudo ground truth for the dynamic consistency constraint
            if self.args.ssl_mode in [MODE_GCT, MODE_DC]:
                with torch.no_grad():
                    l_handled_flawmap = self.flawmap_handler.forward(l_flawmap)
                    r_handled_flawmap = self.flawmap_handler.forward(r_flawmap)
                    l_dc_gt, r_dc_gt, l_fc_mask, r_fc_mask = self.dcgt_generator.forward(
                        l_activated_pred[0].detach(),
                        r_activated_pred[0].detach(), l_handled_flawmap,
                        r_handled_flawmap)

            # -----------------------------------------------------------------------------
            # step-1: train the task models
            # -----------------------------------------------------------------------------
            # freeze the flaw detector while the task models are updated
            for param in self.fd_model.parameters():
                param.requires_grad = False

            # train the 'l' task model
            l_loss = self._task_model_iter(epoch, idx, True, 'l', lbs, l_inp,
                                           l_gt, l_dc_gt, l_fc_mask,
                                           dc_rampup_scale)
            self.l_optimizer.zero_grad()
            l_loss.backward()
            self.l_optimizer.step()

            # train the 'r' task model
            r_loss = self._task_model_iter(epoch, idx, True, 'r', lbs, r_inp,
                                           r_gt, r_dc_gt, r_fc_mask,
                                           dc_rampup_scale)
            self.r_optimizer.zero_grad()
            r_loss.backward()
            self.r_optimizer.step()

            # -----------------------------------------------------------------------------
            # step-2: train the flaw detector
            # -----------------------------------------------------------------------------
            # unfreeze the flaw detector for its own update
            for param in self.fd_model.parameters():
                param.requires_grad = True

            # generate the ground truth for the flaw detector (on labeled data only)
            with torch.no_grad():
                l_flawmap_gt = self.fdgt_generator.forward(
                    l_activated_pred[0][:lbs, ...].detach(),
                    self.task_func.sslgct_prepare_task_gt_for_fdgt(
                        l_gt[0][:lbs, ...]))
                r_flawmap_gt = self.fdgt_generator.forward(
                    r_activated_pred[0][:lbs, ...].detach(),
                    self.task_func.sslgct_prepare_task_gt_for_fdgt(
                        r_gt[0][:lbs, ...]))

            # the flaw maps computed in step-0 are compared against the generated ground truth
            l_fd_loss = self.fd_criterion.forward(l_flawmap[:lbs, ...],
                                                  l_flawmap_gt)
            l_fd_loss = self.args.fd_scale * torch.mean(l_fd_loss)
            self.meters.update('l_fd_loss', l_fd_loss.data)

            r_fd_loss = self.fd_criterion.forward(r_flawmap[:lbs, ...],
                                                  r_flawmap_gt)
            r_fd_loss = self.args.fd_scale * torch.mean(r_fd_loss)
            self.meters.update('r_fd_loss', r_fd_loss.data)

            # the flaw detector is updated with the average of both branches
            fd_loss = (l_fd_loss + r_fd_loss) / 2

            self.fd_optimizer.zero_grad()
            fd_loss.backward()
            self.fd_optimizer.step()

            # logging
            # NOTE(review): the task/dc/fc meters read below are not updated in this
            # method -- presumably '_task_model_iter' updates them; confirm.
            self.meters.update('batch_time', time.time() - timer)
            if idx % self.args.log_freq == 0:
                logger.log_info(
                    'step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                    '  l-{3}\t=>\t'
                    'l-task-loss: {meters[l_task_loss]:.6f}\t'
                    'l-dc-loss: {meters[l_dc_loss]:.6f}\t'
                    'l-fc-loss: {meters[l_fc_loss]:.6f}\n'
                    '  r-{3}\t=>\t'
                    'r-task-loss: {meters[r_task_loss]:.6f}\t'
                    'r-dc-loss: {meters[r_dc_loss]:.6f}\t'
                    'r-fc-loss: {meters[r_fc_loss]:.6f}\n'
                    '  fd\t=>\t'
                    'l-fd-loss: {meters[l_fd_loss]:.6f}\t'
                    'r-fd-loss: {meters[r_fd_loss]:.6f}\n'.format(
                        epoch,
                        idx,
                        len(data_loader),
                        self.args.task,
                        meters=self.meters))

            # the flaw detector uses polynomiallr [ITER_LRERS]
            # (its lrer is therefore stepped every iteration, regardless of 'is_epoch_lrer')
            self.fd_lrer.step()
            # update iteration-based lrers
            if not self.args.is_epoch_lrer:
                self.l_lrer.step()
                self.r_lrer.step()

        # update epoch-based lrers
        if self.args.is_epoch_lrer:
            self.l_lrer.step()
            self.r_lrer.step()
Ejemplo n.º 8
0
    def _train(self, data_loader, epoch):
        """Train the task model for one epoch using the supervised task loss only.

        ``logger.log_err`` is called unless ``self.args.ignore_unlabeled`` is
        truthy and ``self.args.unlabeled_batch_size`` equals 0.

        Args:
            data_loader: iterable yielding ``(inp, gt)`` batches; ``len()`` of it
                is used in the progress log.
            epoch: current epoch index, used for logging and visualization.
        """
        # this algorithm must not be given unlabeled data
        if not (self.args.ignore_unlabeled and self.args.unlabeled_batch_size == 0):
            logger.log_err(
                'SSL_NULL is a supervised-only algorithm\n'
                'Please set ignore_unlabeled = True and unlabeled_batch_size = 0\n'
            )

        self.meters.reset()
        lbs = self.args.labeled_batch_size

        self.model.train()

        for step, (inp, gt) in enumerate(data_loader):
            tic = time.time()

            # 'inp' and 'gt' are both tuples
            inp, gt = self._batch_prehandle(inp, gt)
            if len(gt) > 1 and step == 0:
                self._inp_warn()

            self.optimizer.zero_grad()

            # run the task model and make sure both prediction kinds are present
            outputs, _ = self.model.forward(inp)
            if not ('pred' in outputs.keys() and 'activated_pred' in outputs.keys()):
                self._pred_err()

            pred = tool.dict_value(outputs, 'pred')
            activated_pred = tool.dict_value(outputs, 'activated_pred')

            # restrict prediction, ground truth and input to the first 'lbs' samples
            labeled_pred = func.split_tensor_tuple(pred, 0, lbs)
            labeled_gt = func.split_tensor_tuple(gt, 0, lbs)
            labeled_inp = func.split_tensor_tuple(inp, 0, lbs)

            # the criterion output is reduced to a scalar by averaging
            task_loss = torch.mean(
                self.criterion.forward(labeled_pred, labeled_gt, labeled_inp))
            self.meters.update('task_loss', task_loss.data)

            # the task loss is the whole objective
            task_loss.backward()
            self.optimizer.step()

            # progress log
            self.meters.update('batch_time', time.time() - tic)
            if step % self.args.log_freq == 0:
                progress = (
                    'step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                    '  task-{3}\t=>\t'
                    'task-loss: {meters[task_loss]:.6f}\t')
                logger.log_info(progress.format(
                    epoch, step, len(data_loader), self.args.task,
                    meters=self.meters))

            # visualize the first sample of the batch
            if self.args.visualize and step % self.args.visual_freq == 0:
                self._visualize(
                    epoch, step, True,
                    func.split_tensor_tuple(inp, 0, 1, reduce_dim=True),
                    func.split_tensor_tuple(
                        activated_pred, 0, 1, reduce_dim=True),
                    func.split_tensor_tuple(gt, 0, 1, reduce_dim=True))

            # lrers that advance once per iteration
            if not self.args.is_epoch_lrer:
                self.lrer.step()

        # lrers that advance once per epoch
        if self.args.is_epoch_lrer:
            self.lrer.step()
Ejemplo n.º 9
0
    def _train(self, data_loader, epoch):
        """Train the wrapped CCT model for one epoch.

        Each iteration forwards the labeled slice ``[0, lbs)`` of the batch to
        obtain the task loss and, when ``self.args.unlabeled_batch_size > 0``,
        the slice ``[lbs, batch_size)`` to obtain the consistency loss. The
        consistency loss is scaled by ``self.args.cons_scale`` and a sigmoid
        ramp-up coefficient before both losses are summed and back-propagated.

        Args:
            data_loader: iterable yielding ``(inp, gt)`` batches; ``len()`` of it
                is used for the ramp-up schedule and the progress log.
            epoch: current epoch index, used for the ramp-up schedule, logging
                and visualization.
        """
        self.meters.reset()
        lbs = self.args.labeled_batch_size

        self.model.train()

        for idx, (inp, gt) in enumerate(data_loader):
            timer = time.time()

            inp, gt = self._batch_prehandle(inp, gt)
            if len(gt) > 1 and idx == 0:
                self._data_err()

            # TODO: support more ramp-up functions
            # calculate the ramp-up coefficient of the consistency constraint
            cur_step = len(data_loader) * epoch + idx
            total_steps = len(data_loader) * self.args.cons_rampup_epochs
            cons_rampup_scale = func.sigmoid_rampup(cur_step, total_steps)

            self.optimizer.zero_grad()

            # -----------------------------------------------------------
            # For Labeled Data
            # -----------------------------------------------------------
            l_gt = func.split_tensor_tuple(gt, 0, lbs)
            l_inp = func.split_tensor_tuple(inp, 0, lbs)

            # forward the wrapped CCT model
            resulter, debugger = self.model.forward(l_inp, l_gt, False)
            # kept for the visualization below
            l_activated_pred = tool.dict_value(resulter, 'activated_pred')

            task_loss = tool.dict_value(resulter, 'task_loss', err=True)
            task_loss = task_loss.mean()
            self.meters.update('task_loss', task_loss.data)

            # -----------------------------------------------------------
            # For Unlabeled Data
            # -----------------------------------------------------------
            if self.args.unlabeled_batch_size > 0:
                ul_gt = func.split_tensor_tuple(gt, lbs, self.args.batch_size)
                ul_inp = func.split_tensor_tuple(inp, lbs, self.args.batch_size)

                # forward the wrapped CCT model
                resulter, debugger = self.model.forward(ul_inp, ul_gt, True)

                cons_loss = tool.dict_value(resulter, 'cons_loss', err=True)
                cons_loss = cons_loss.mean()
                cons_loss = cons_rampup_scale * self.args.cons_scale * cons_loss
                self.meters.update('cons_loss', cons_loss.data)

            else:
                # no unlabeled data: the consistency term contributes nothing
                cons_loss = 0
                self.meters.update('cons_loss', cons_loss)

            # backward and update the wrapped CCT model
            loss = task_loss + cons_loss
            loss.backward()
            self.optimizer.step()

            # logging
            self.meters.update('batch_time', time.time() - timer)
            if idx % self.args.log_freq == 0:
                logger.log_info('step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                                '  task-{3}\t=>\t'
                                'task-loss: {meters[task_loss]:.6f}\t'
                                'cons-loss: {meters[cons_loss]:.6f}\n'
                                .format(epoch, idx, len(data_loader), self.args.task, meters=self.meters))

            # visualization
            if self.args.visualize and idx % self.args.visual_freq == 0:
                # The previous code referenced an unbound name 'activated_pred' here.
                # Sample 0 of the batch belongs to the labeled slice (assuming
                # lbs >= 1), so the labeled forward pass provides its prediction.
                self._visualize(epoch, idx, True,
                                func.split_tensor_tuple(inp, 0, 1, reduce_dim=True),
                                func.split_tensor_tuple(l_activated_pred, 0, 1, reduce_dim=True),
                                func.split_tensor_tuple(gt, 0, 1, reduce_dim=True))

            # update iteration-based lrers
            if not self.args.is_epoch_lrer:
                self.lrer.step()

        # update epoch-based lrers
        if self.args.is_epoch_lrer:
            self.lrer.step()
Ejemplo n.º 10
0
    def _validate(self, data_loader, epoch):
        """Validate the student and the teacher model and log losses and metrics.

        For every batch the task loss of both models (both through
        ``self.s_criterion``) and the scaled consistency loss between the
        student's activated predictions and the detached teacher predictions
        are recorded in ``self.meters``.

        Args:
            data_loader: iterable yielding ``(inp, gt)`` batches; ``len()`` of it
                is used in the progress log.
            epoch: current epoch index; logged as ``epoch + 1`` and forwarded
                as-is to ``self._visualize``.
        """
        self.meters.reset()

        self.s_model.eval()
        self.t_model.eval()

        for step, (inp, gt) in enumerate(data_loader):
            tic = time.time()

            inp, gt, _, _ = self._batch_prehandle(inp, gt, False)
            if len(inp) > 1 and step == 0:
                self._inp_warn()
            if len(gt) > 1 and step == 0:
                self._gt_warn()

            # ---- student ----
            student_out, _ = self.s_model.forward(inp)
            if not ('pred' in student_out.keys() and 'activated_pred' in student_out.keys()):
                self._pred_err()
            s_pred = tool.dict_value(student_out, 'pred')
            s_activated_pred = tool.dict_value(student_out, 'activated_pred')

            s_task_loss = torch.mean(self.s_criterion.forward(s_pred, gt, inp))
            self.meters.update('s_task_loss', s_task_loss.data)

            # ---- teacher (scored with the student's criterion) ----
            teacher_out, _ = self.t_model.forward(inp)
            if not ('pred' in teacher_out.keys() and 'activated_pred' in teacher_out.keys()):
                self._pred_err()
            t_pred = tool.dict_value(teacher_out, 'pred')
            t_activated_pred = tool.dict_value(teacher_out, 'activated_pred')

            t_task_loss = torch.mean(self.s_criterion.forward(t_pred, gt, inp))
            self.meters.update('t_task_loss', t_task_loss.data)

            # detached teacher predictions act as the consistency targets
            t_pseudo_gt = tuple(teacher_act.detach() for teacher_act in t_activated_pred)

            cons_loss = 0
            for student_act, target in zip(s_activated_pred, t_pseudo_gt):
                cons_loss += torch.mean(self.cons_criterion(student_act, target))
            cons_loss = self.args.cons_scale * torch.mean(cons_loss)
            self.meters.update('cons_loss', cons_loss.data)

            # metrics for both models, student first
            for role, role_activated_pred in (('student', s_activated_pred),
                                              ('teacher', t_activated_pred)):
                self.task_func.metrics(
                    role_activated_pred, gt, inp, self.meters, id_str=role)

            self.meters.update('batch_time', time.time() - tic)
            if step % self.args.log_freq == 0:
                progress = (
                    'step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                    '  student-{3}\t=>\t'
                    's-task-loss: {meters[s_task_loss]:.6f}\t'
                    's-cons-loss: {meters[cons_loss]:.6f}\n'
                    '  teacher-{3}\t=>\t'
                    't-task-loss: {meters[t_task_loss]:.6f}\n')
                logger.log_info(progress.format(
                    epoch + 1, step, len(data_loader), self.args.task,
                    meters=self.meters))

            # visualize the first sample of the batch with the student's prediction
            if self.args.visualize and step % self.args.visual_freq == 0:
                self._visualize(
                    epoch, step, False,
                    func.split_tensor_tuple(inp, 0, 1, reduce_dim=True),
                    func.split_tensor_tuple(
                        s_activated_pred, 0, 1, reduce_dim=True),
                    func.split_tensor_tuple(gt, 0, 1, reduce_dim=True))

        # group the metric meters by the model they belong to
        metrics_info = {'student': '', 'teacher': ''}
        for key in sorted(self.meters.keys()):
            if self.task_func.METRIC_STR not in key:
                continue
            for role in metrics_info.keys():
                if key.startswith(role):
                    metrics_info[role] += '{0}: {1:.6}\t'.format(
                        key, self.meters[key])

        summary = 'Validation metrics:\n  student-metrics\t=>\t{0}\n  teacher-metrics\t=>\t{1}\n'
        logger.log_info(summary.format(
            metrics_info['student'].replace('_', '-'),
            metrics_info['teacher'].replace('_', '-')))
Ejemplo n.º 11
0
    def _validate(self, data_loader, epoch):
        """Validate the task model and the FC discriminator and log losses and metrics.

        For every batch the task loss, the discriminator loss on the task
        model's first activated prediction ('fake') and the discriminator loss
        on the converted ground truth ('real') are recorded in ``self.meters``.

        Args:
            data_loader: iterable yielding ``(inp, gt)`` batches; ``len()`` of it
                is used in the progress log.
            epoch: current epoch index; logged as ``epoch + 1`` and forwarded
                as-is to ``self._visualize``.
        """
        self.meters.reset()

        self.model.eval()
        self.d_model.eval()

        for step, (inp, gt) in enumerate(data_loader):
            tic = time.time()

            inp, gt = self._batch_prehandle(inp, gt)
            if len(gt) > 1 and step == 0:
                self._inp_warn()

            # ---- task model ----
            outputs, _ = self.model.forward(inp)
            if not ('pred' in outputs.keys() and 'activated_pred' in outputs.keys()):
                self._pred_err()

            pred = tool.dict_value(outputs, 'pred')
            activated_pred = tool.dict_value(outputs, 'activated_pred')

            task_loss = torch.mean(self.criterion.forward(pred, gt, inp))
            self.meters.update('task_loss', task_loss.data)

            # ---- discriminator on the task prediction ('fake') ----
            fake_out, _ = self.d_model.forward(activated_pred[0])
            raw_fake_confidence = tool.dict_value(fake_out, 'confidence')
            fake_confidence, fake_target = self.task_func.ssladv_preprocess_fcd_criterion(
                raw_fake_confidence, gt[0], False)

            fake_d_loss = self.d_criterion.forward(fake_confidence, fake_target)
            fake_d_loss = self.args.discriminator_scale * torch.mean(fake_d_loss)
            self.meters.update('fake_d_loss', fake_d_loss.data)

            # ---- discriminator on the converted ground truth ('real') ----
            real_inp = self.task_func.ssladv_convert_task_gt_to_fcd_input(gt[0])
            real_out, _ = self.d_model.forward(real_inp)
            raw_real_confidence = tool.dict_value(real_out, 'confidence')
            real_confidence, real_target = self.task_func.ssladv_preprocess_fcd_criterion(
                raw_real_confidence, gt[0], True)

            real_d_loss = self.d_criterion.forward(real_confidence, real_target)
            real_d_loss = self.args.discriminator_scale * torch.mean(real_d_loss)
            self.meters.update('real_d_loss', real_d_loss.data)

            self.task_func.metrics(
                activated_pred, gt, inp, self.meters, id_str='task')

            self.meters.update('batch_time', time.time() - tic)
            if step % self.args.log_freq == 0:
                # NOTE(review): there is no '\n' before the 'fc-discriminator'
                # section, so it is logged on the task line -- confirm this is intended.
                progress = (
                    'step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                    '  task-{3}\t=>\t'
                    'task-loss: {meters[task_loss]:.6f}\t'
                    '  fc-discriminator\t=>\t'
                    'fake-d-loss: {meters[fake_d_loss]:.6f}\t'
                    'real-d-loss: {meters[real_d_loss]:.6f}\n')
                logger.log_info(progress.format(
                    epoch + 1, step, len(data_loader), self.args.task,
                    meters=self.meters))

            # visualize the first sample together with its sigmoid-activated confidence map
            if self.args.visualize and step % self.args.visual_freq == 0:
                self._visualize(
                    epoch, step, False,
                    func.split_tensor_tuple(inp, 0, 1, reduce_dim=True),
                    func.split_tensor_tuple(
                        activated_pred, 0, 1, reduce_dim=True),
                    func.split_tensor_tuple(gt, 0, 1, reduce_dim=True),
                    torch.sigmoid(raw_fake_confidence[0]))

        # collect every metric meter whose name starts with 'task'
        task_summary = ''
        for key in sorted(self.meters.keys()):
            if self.task_func.METRIC_STR not in key:
                continue
            if key.startswith('task'):
                task_summary += '{0}: {1:.6}\t'.format(key, self.meters[key])

        logger.log_info('Validation metrics:\n task-metrics\t=>\t{0}\n'.format(
            task_summary.replace('_', '-')))
Ejemplo n.º 12
0
    def _train(self, data_loader, epoch):
        """Run one training epoch of the adversarial semi-supervised algorithm.

        Every iteration performs two optimization steps:

        1. The task model is updated with the supervised task loss computed on
           the first ``labeled_batch_size`` samples of the batch, plus the
           (optional) adversarial losses derived from the FC discriminator's
           confidence map.
        2. The FC discriminator is updated to tell task predictions (fake)
           apart from the converted ground truth (real).

        Arguments:
            data_loader: iterable that yields ``(inp, gt)`` pairs; both are
                         turned into tuples by ``self._batch_prehandle``
            epoch: index of the current epoch (used for logging/visualization)
        """
        self.meters.reset()
        lbs = self.args.labeled_batch_size

        self.model.train()
        self.d_model.train()

        # both 'inp' and 'gt' are tuples
        for idx, (inp, gt) in enumerate(data_loader):
            timer = time.time()

            inp, gt = self._batch_prehandle(inp, gt)
            if len(gt) > 1 and idx == 0:
                self._inp_warn()

            # -----------------------------------------------------------------------------
            # step-1: train the task model
            # -----------------------------------------------------------------------------
            self.optimizer.zero_grad()

            # forward the task model
            resulter, _ = self.model.forward(inp)
            if 'pred' not in resulter.keys() or 'activated_pred' not in resulter.keys():
                self._pred_err()

            pred = tool.dict_value(resulter, 'pred')
            activated_pred = tool.dict_value(resulter, 'activated_pred')

            # forward the FC discriminator
            # 'confidence_map' is a tensor that covers the whole batch
            d_resulter, _ = self.d_model.forward(activated_pred[0])
            confidence_map = tool.dict_value(d_resulter, 'confidence')

            # calculate the supervised task constraint on the labeled data
            l_pred = func.split_tensor_tuple(pred, 0, lbs)
            l_gt = func.split_tensor_tuple(gt, 0, lbs)
            l_inp = func.split_tensor_tuple(inp, 0, lbs)

            # 'task_loss' is a tensor of 1-dim & n elements, where n == batch_size
            task_loss = self.criterion.forward(l_pred, l_gt, l_inp)
            task_loss = torch.mean(task_loss)
            self.meters.update('task_loss', task_loss.data)

            # calculate the adversarial constraint
            # calculate the adversarial constraint for the labeled data
            if self.args.adv_for_labeled:
                l_confidence_map = confidence_map[:lbs, ...]

                # preprocess prediction and ground truth for the adversarial constraint
                l_adv_confidence_map, l_adv_confidence_gt = \
                    self.task_func.ssladv_preprocess_fcd_criterion(l_confidence_map, l_gt[0], True)
                l_adv_loss = self.d_criterion(l_adv_confidence_map,
                                              l_adv_confidence_gt)
                labeled_adv_loss = self.args.labeled_adv_scale * torch.mean(
                    l_adv_loss)
                self.meters.update('labeled_adv_loss', labeled_adv_loss.data)
            else:
                labeled_adv_loss = 0
                self.meters.update('labeled_adv_loss', labeled_adv_loss)

            # calculate the adversarial constraint for the unlabeled data
            if self.args.unlabeled_batch_size > 0:
                u_confidence_map = confidence_map[lbs:self.args.batch_size,
                                                  ...]

                # preprocess prediction and ground truth for the adversarial constraint
                u_adv_confidence_map, u_adv_confidence_gt = \
                    self.task_func.ssladv_preprocess_fcd_criterion(u_confidence_map, None, True)
                u_adv_loss = self.d_criterion(u_adv_confidence_map,
                                              u_adv_confidence_gt)
                unlabeled_adv_loss = self.args.unlabeled_adv_scale * torch.mean(
                    u_adv_loss)
                self.meters.update('unlabeled_adv_loss',
                                   unlabeled_adv_loss.data)
            else:
                unlabeled_adv_loss = 0
                self.meters.update('unlabeled_adv_loss', unlabeled_adv_loss)

            adv_loss = labeled_adv_loss + unlabeled_adv_loss

            # backward and update the task model
            loss = task_loss + adv_loss
            loss.backward()
            self.optimizer.step()

            # -----------------------------------------------------------------------------
            # step-2: train the FC discriminator
            # -----------------------------------------------------------------------------
            self.d_optimizer.zero_grad()

            # forward the task prediction (fake)
            # the prediction is detached so that this step never updates the task model
            if self.args.unlabeled_for_discriminator:
                fake_pred = activated_pred[0].detach()
            else:
                fake_pred = activated_pred[0][:lbs, ...].detach()

            d_resulter, _ = self.d_model.forward(fake_pred)
            fake_confidence_map = tool.dict_value(d_resulter, 'confidence')

            l_fake_confidence_map = fake_confidence_map[:lbs, ...]
            l_fake_confidence_map, l_fake_confidence_gt = \
                self.task_func.ssladv_preprocess_fcd_criterion(l_fake_confidence_map, l_gt[0], False)

            if self.args.unlabeled_for_discriminator and self.args.unlabeled_batch_size != 0:
                u_fake_confidence_map = fake_confidence_map[
                    lbs:self.args.batch_size, ...]
                u_fake_confidence_map, u_fake_confidence_gt = \
                    self.task_func.ssladv_preprocess_fcd_criterion(u_fake_confidence_map, None, False)

                fake_confidence_map = torch.cat(
                    (l_fake_confidence_map, u_fake_confidence_map), dim=0)
                fake_confidence_gt = torch.cat(
                    (l_fake_confidence_gt, u_fake_confidence_gt), dim=0)

            else:
                fake_confidence_map, fake_confidence_gt = l_fake_confidence_map, l_fake_confidence_gt

            fake_d_loss = self.d_criterion.forward(fake_confidence_map,
                                                   fake_confidence_gt)
            fake_d_loss = self.args.discriminator_scale * torch.mean(
                fake_d_loss)
            self.meters.update('fake_d_loss', fake_d_loss.data)

            # forward the ground truth (real)
            # convert the format of ground truth
            real_gt = self.task_func.ssladv_convert_task_gt_to_fcd_input(
                l_gt[0])
            d_resulter, _ = self.d_model.forward(real_gt)
            real_confidence_map = tool.dict_value(d_resulter, 'confidence')

            real_confidence_map, real_confidence_gt = \
                self.task_func.ssladv_preprocess_fcd_criterion(real_confidence_map, l_gt[0], True)

            real_d_loss = self.d_criterion(real_confidence_map,
                                           real_confidence_gt)
            real_d_loss = self.args.discriminator_scale * torch.mean(
                real_d_loss)
            self.meters.update('real_d_loss', real_d_loss.data)

            # backward and update the FC discriminator
            d_loss = (fake_d_loss + real_d_loss) / 2
            d_loss.backward()
            self.d_optimizer.step()

            # logging
            self.meters.update('batch_time', time.time() - timer)
            if idx % self.args.log_freq == 0:
                logger.log_info(
                    'step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                    '  task-{3}\t=>\t'
                    'task-loss: {meters[task_loss]:.6f}\t'
                    'labeled-adv-loss: {meters[labeled_adv_loss]:.6f}\t'
                    'unlabeled-adv-loss: {meters[unlabeled_adv_loss]:.6f}\n'
                    '  fc-discriminator\t=>\t'
                    'fake-d-loss: {meters[fake_d_loss]:.6f}\t'
                    'real-d-loss: {meters[real_d_loss]:.6f}\n'.format(
                        epoch + 1,
                        idx,
                        len(data_loader),
                        self.args.task,
                        meters=self.meters))

            # visualization
            if self.args.visualize and idx % self.args.visual_freq == 0:
                u_inp_sample, u_pred_sample, u_cmap_sample = None, None, None
                if self.args.unlabeled_batch_size > 0:
                    u_inp_sample = func.split_tensor_tuple(inp,
                                                           lbs,
                                                           lbs + 1,
                                                           reduce_dim=True)
                    u_pred_sample = func.split_tensor_tuple(activated_pred,
                                                            lbs,
                                                            lbs + 1,
                                                            reduce_dim=True)
                    # NOTE: 'fake_confidence_map' has already been passed through
                    #       'ssladv_preprocess_fcd_criterion' and only contains the
                    #       labeled samples when 'unlabeled_for_discriminator' is
                    #       disabled, so indexing it with 'lbs' can run out of range.
                    #       The raw 'confidence_map' from step-1 covers the whole
                    #       batch and is also what the labeled sample below shows.
                    u_cmap_sample = torch.sigmoid(confidence_map[lbs])

                self._visualize(
                    epoch, idx, True,
                    func.split_tensor_tuple(inp, 0, 1, reduce_dim=True),
                    func.split_tensor_tuple(activated_pred,
                                            0,
                                            1,
                                            reduce_dim=True),
                    func.split_tensor_tuple(gt, 0, 1, reduce_dim=True),
                    torch.sigmoid(confidence_map[0]), u_inp_sample,
                    u_pred_sample, u_cmap_sample)

            # the FC discriminator uses polynomiallr [ITER_LRERS]
            self.d_lrer.step()
            # update iteration-based lrers
            if not self.args.is_epoch_lrer:
                self.lrer.step()

        # update epoch-based lrers
        if self.args.is_epoch_lrer:
            self.lrer.step()
Ejemplo n.º 13
0
    def _train(self, data_loader, epoch):
        """Run one training epoch with the rotation self-supervision task.

        ``self._batch_prehandle`` appends rotated copies of the samples to the
        batch, so the first half of the batch holds the un-rotated samples and
        the second half the rotated ones. The model is optimized with:

        * the supervised task loss on the un-rotated labeled samples,
        * the (scaled) supervised task loss on the rotated labeled samples,
        * the (scaled) rotation classification loss on the whole batch.

        Arguments:
            data_loader: iterable that yields ``(inp, gt)`` pairs
            epoch: index of the current epoch (used for logging/visualization)
        """
        self.meters.reset()
        original_lbs = int(self.args.labeled_batch_size / 2)
        original_bs = int(self.args.batch_size / 2)

        self.model.train()

        for idx, (inp, gt) in enumerate(data_loader):
            timer = time.time()

            # the rotated samples are generated in the 'self._batch_prehandle' function
            # both 'inp' and 'gt' are tuples
            # the last element in the tuple 'gt' is the ground truth of the rotation angle
            inp, gt = self._batch_prehandle(inp, gt, True)
            if len(gt) - 1 > 1 and idx == 0:
                self._inp_warn()

            self.optimizer.zero_grad()

            # forward the model
            resulter, _ = self.model.forward(inp)
            pred = tool.dict_value(resulter, 'pred')
            activated_pred = tool.dict_value(resulter, 'activated_pred')
            pred_rotation = tool.dict_value(resulter, 'rotation')

            # calculate the supervised task constraint on the un-rotated labeled data
            l_pred = func.split_tensor_tuple(pred, 0, original_lbs)
            l_gt = func.split_tensor_tuple(gt, 0, original_lbs)
            l_inp = func.split_tensor_tuple(inp, 0, original_lbs)

            # 'l_gt[:-1]' drops the rotation label, which the task criterion does not use
            unrotated_task_loss = self.criterion.forward(
                l_pred, l_gt[:-1], l_inp)
            unrotated_task_loss = torch.mean(unrotated_task_loss)
            self.meters.update('unrotated_task_loss', unrotated_task_loss.data)

            # calculate the supervised task constraint on the rotated labeled data
            l_rotated_pred = func.split_tensor_tuple(
                pred, original_bs, original_bs + original_lbs)
            l_rotated_gt = func.split_tensor_tuple(gt, original_bs,
                                                   original_bs + original_lbs)
            l_rotated_inp = func.split_tensor_tuple(inp, original_bs,
                                                    original_bs + original_lbs)

            rotated_task_loss = self.criterion.forward(l_rotated_pred,
                                                       l_rotated_gt[:-1],
                                                       l_rotated_inp)
            rotated_task_loss = self.args.rotated_sup_scale * torch.mean(
                rotated_task_loss)
            self.meters.update('rotated_task_loss', rotated_task_loss.data)

            task_loss = unrotated_task_loss + rotated_task_loss

            # calculate the self-supervised rotation constraint
            rotation_gt = gt[-1]
            rotation_loss = self.rotation_criterion.forward(
                pred_rotation, rotation_gt)
            rotation_loss = self.args.rotation_scale * torch.mean(
                rotation_loss)
            self.meters.update('rotation_loss', rotation_loss.data)

            # backward and update the model
            loss = task_loss + rotation_loss
            loss.backward()
            self.optimizer.step()

            # calculate the accuracy of the rotation classifier
            _, angle_idx = pred_rotation.topk(1, 1, True, True)
            angle_idx = angle_idx.t()
            correct = angle_idx.eq(
                rotation_gt.view(1, -1).expand_as(angle_idx))
            # normalize by the number of samples that were actually classified
            # instead of the configured batch size, so the percentage stays
            # correct when a batch holds fewer samples than 'args.batch_size'
            num_samples = max(angle_idx.numel(), 1)
            rotation_acc = correct.view(-1).float().sum(
                0, keepdim=True).mul_(100.0 / num_samples)
            self.meters.update('rotation_acc', rotation_acc.data[0])

            # logging
            self.meters.update('batch_time', time.time() - timer)
            if idx % self.args.log_freq == 0:
                logger.log_info(
                    'step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                    '  task-{3}\t=>\t'
                    'unrotated-task-loss: {meters[unrotated_task_loss]:.6f}\t'
                    'rotated-task-loss: {meters[rotated_task_loss]:.6f}\n'
                    '  rotation-{3}\t=>\t'
                    'rotation-loss: {meters[rotation_loss]:.6f}\t'
                    'rotation-acc: {meters[rotation_acc]:.6f}\n'.format(
                        epoch,
                        idx,
                        len(data_loader),
                        self.args.task,
                        meters=self.meters))

            # visualization
            if self.args.visualize and idx % self.args.visual_freq == 0:
                self._visualize(
                    epoch, idx, True,
                    func.split_tensor_tuple(inp, 0, 1, reduce_dim=True),
                    func.split_tensor_tuple(activated_pred,
                                            0,
                                            1,
                                            reduce_dim=True),
                    func.split_tensor_tuple(gt[:-1], 0, 1, reduce_dim=True))

            # update iteration-based lrers
            if not self.args.is_epoch_lrer:
                self.lrer.step()

        # update epoch-based lrers
        if self.args.is_epoch_lrer:
            self.lrer.step()
Ejemplo n.º 14
0
    def _train(self, data_loader, epoch):
        """Run one training epoch of the student/teacher consistency algorithm.

        For every batch the student model is optimized with the supervised task
        loss on the first ``labeled_batch_size`` samples plus a ramped-up
        consistency loss that pulls the student prediction towards the teacher
        prediction. The teacher is evaluated without gradients and is updated
        afterwards through ``self._update_ema_variables``.

        Arguments:
            data_loader: iterable that yields ``(inp, gt)`` pairs
            epoch: index of the current epoch; together with the iteration
                   index it defines the global step used for the ramp-up
        """
        self.meters.reset()
        lbs = self.args.labeled_batch_size

        self.s_model.train()
        self.t_model.train()

        for idx, (inp, gt) in enumerate(data_loader):
            timer = time.time()

            # 's_inp', 't_inp' and 'gt' are tuples
            s_inp, t_inp, gt = self._batch_prehandle(inp, gt, True)
            if len(gt) > 1 and idx == 0:
                self._inp_warn()

            # calculate the ramp-up coefficient of the consistency constraint
            cur_step = len(data_loader) * epoch + idx
            total_steps = len(data_loader) * self.args.cons_rampup_epochs
            cons_rampup_scale = func.sigmoid_rampup(cur_step, total_steps)

            self.s_optimizer.zero_grad()

            # forward the student model
            s_resulter, _ = self.s_model.forward(s_inp)
            if 'pred' not in s_resulter.keys() or 'activated_pred' not in s_resulter.keys():
                self._pred_err()
            s_pred = tool.dict_value(s_resulter, 'pred')
            s_activated_pred = tool.dict_value(s_resulter, 'activated_pred')

            # calculate the supervised task constraint on the labeled data
            l_s_pred = func.split_tensor_tuple(s_pred, 0, lbs)
            l_gt = func.split_tensor_tuple(gt, 0, lbs)
            l_s_inp = func.split_tensor_tuple(s_inp, 0, lbs)

            # 'task_loss' is a tensor of 1-dim & n elements, where n == batch_size
            s_task_loss = self.s_criterion.forward(l_s_pred, l_gt, l_s_inp)
            s_task_loss = torch.mean(s_task_loss)
            self.meters.update('s_task_loss', s_task_loss.data)

            # forward the teacher model
            with torch.no_grad():
                t_resulter, _ = self.t_model.forward(t_inp)
                # both keys are read right below, so both are required
                # (same check as for the student model)
                if 'pred' not in t_resulter.keys() or 'activated_pred' not in t_resulter.keys():
                    self._pred_err()
                t_pred = tool.dict_value(t_resulter, 'pred')
                t_activated_pred = tool.dict_value(t_resulter,
                                                   'activated_pred')

                # calculate 't_task_loss' for recording
                l_t_pred = func.split_tensor_tuple(t_pred, 0, lbs)
                l_t_inp = func.split_tensor_tuple(t_inp, 0, lbs)
                t_task_loss = self.s_criterion.forward(l_t_pred, l_gt, l_t_inp)
                t_task_loss = torch.mean(t_task_loss)
                self.meters.update('t_task_loss', t_task_loss.data)

            # calculate the consistency constraint from the teacher model to the student model
            # the teacher prediction is a constant target: no gradient flows into the teacher
            t_pseudo_gt = t_pred[0].detach()

            if self.args.cons_for_labeled:
                cons_loss = self.cons_criterion(s_pred[0], t_pseudo_gt)
            elif self.args.unlabeled_batch_size > 0:
                cons_loss = self.cons_criterion(s_pred[0][lbs:, ...],
                                                t_pseudo_gt[lbs:, ...])
            else:
                cons_loss = self.zero_tensor
            cons_loss = cons_rampup_scale * self.args.cons_scale * torch.mean(
                cons_loss)
            self.meters.update('cons_loss', cons_loss.data)

            # backward and update the student model
            loss = s_task_loss + cons_loss
            loss.backward()
            self.s_optimizer.step()

            # update the teacher model by EMA
            self._update_ema_variables(self.s_model, self.t_model,
                                       self.args.ema_decay, cur_step)

            # logging
            self.meters.update('batch_time', time.time() - timer)
            if idx % self.args.log_freq == 0:
                logger.log_info(
                    'step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                    '  student-{3}\t=>\t'
                    's-task-loss: {meters[s_task_loss]:.6f}\t'
                    's-cons-loss: {meters[cons_loss]:.6f}\n'
                    '  teacher-{3}\t=>\t'
                    't-task-loss: {meters[t_task_loss]:.6f}\n'.format(
                        epoch + 1,
                        idx,
                        len(data_loader),
                        self.args.task,
                        meters=self.meters))

            # visualization
            if self.args.visualize and idx % self.args.visual_freq == 0:
                self._visualize(
                    epoch, idx, True,
                    func.split_tensor_tuple(s_inp, 0, 1, reduce_dim=True),
                    func.split_tensor_tuple(s_activated_pred,
                                            0,
                                            1,
                                            reduce_dim=True),
                    func.split_tensor_tuple(t_inp, 0, 1, reduce_dim=True),
                    func.split_tensor_tuple(t_activated_pred,
                                            0,
                                            1,
                                            reduce_dim=True),
                    func.split_tensor_tuple(gt, 0, 1, reduce_dim=True))

            # update iteration-based lrers
            if not self.args.is_epoch_lrer:
                self.s_lrer.step()

        # update epoch-based lrers
        if self.args.is_epoch_lrer:
            self.s_lrer.step()
Ejemplo n.º 15
0
    def _validate(self, data_loader, epoch):
        """Evaluate both task models ('l' and 'r') on a validation loader.

        For each batch the pseudo ground truth / mask of the dynamic
        consistency constraint is generated first (only in the GCT and DC
        modes), then ``self._task_model_iter`` is run for each task model in
        validation mode, which records the losses and metrics in
        ``self.meters``. The collected metrics are logged at the end.

        Arguments:
            data_loader: iterable that yields ``(inp, gt)`` pairs
            epoch: index of the current epoch (used for logging/visualization)
        """
        self.meters.reset()
        lbs = self.args.labeled_batch_size

        for net in (self.l_model, self.r_model, self.fd_model):
            net.eval()

        step_template = (
            'step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
            '  l-{3}\t=>\t'
            'l-task-loss: {meters[l_task_loss]:.6f}\t'
            'l-dc-loss: {meters[l_dc_loss]:.6f}\t'
            'l-fc-loss: {meters[l_fc_loss]:.6f}\n'
            '  r-{3}\t=>\t'
            'r-task-loss: {meters[r_task_loss]:.6f}\t'
            'r-dc-loss: {meters[r_dc_loss]:.6f}\t'
            'r-fc-loss: {meters[r_fc_loss]:.6f}\n'
            '  fd\t=>\t'
            'l-fd-loss: {meters[l_fd_loss]:.6f}\t'
            'r-fd-loss: {meters[r_fd_loss]:.6f}\n')

        for idx, (inp, gt) in enumerate(data_loader):
            start_time = time.time()

            (l_inp, l_gt), (r_inp, r_gt) = self._batch_prehandle(inp, gt)
            if idx == 0 and len(l_gt) == len(r_gt) and len(r_gt) > 1:
                self._inp_warn()

            # targets of the dynamic consistency constraint; they stay 'None'
            # unless the current SSL mode enables that constraint
            l_dc_gt = r_dc_gt = None
            l_fc_mask = r_fc_mask = None

            if self.args.ssl_mode in (MODE_GCT, MODE_DC):
                # predictions of the two task models
                l_out, _ = self.l_model.forward(l_inp)
                l_activated_pred = tool.dict_value(l_out, 'activated_pred')
                r_out, _ = self.r_model.forward(r_inp)
                r_activated_pred = tool.dict_value(r_out, 'activated_pred')

                # flaw maps estimated by the flaw detector for both predictions
                l_fd_out, _ = self.fd_model.forward(l_inp,
                                                    l_activated_pred[0])
                l_flawmap = tool.dict_value(l_fd_out, 'flawmap')
                r_fd_out, _ = self.fd_model.forward(r_inp,
                                                    r_activated_pred[0])
                r_flawmap = tool.dict_value(r_fd_out, 'flawmap')

                l_handled_flawmap = self.flawmap_handler.forward(l_flawmap)
                r_handled_flawmap = self.flawmap_handler.forward(r_flawmap)

                l_dc_gt, r_dc_gt, l_fc_mask, r_fc_mask = \
                    self.dcgt_generator.forward(
                        l_activated_pred[0].detach(),
                        r_activated_pred[0].detach(),
                        l_handled_flawmap,
                        r_handled_flawmap)

            # the returned losses are not needed during validation; the calls
            # are made for the meters/metrics they record
            self._task_model_iter(epoch, idx, False, 'l', lbs, l_inp, l_gt,
                                  l_dc_gt, l_fc_mask, 1)
            self._task_model_iter(epoch, idx, False, 'r', lbs, r_inp, r_gt,
                                  r_dc_gt, r_fc_mask, 1)

            self.meters.update('batch_time', time.time() - start_time)
            if idx % self.args.log_freq == 0:
                logger.log_info(
                    step_template.format(epoch,
                                         idx,
                                         len(data_loader),
                                         self.args.task,
                                         meters=self.meters))

        # metrics
        metrics_info = {'l': '', 'r': ''}
        for key in sorted(self.meters.keys()):
            if self.task_func.METRIC_STR not in key:
                continue
            for id_str in metrics_info.keys():
                if key.startswith(id_str):
                    metrics_info[id_str] += '{0}: {1:.6f}\t'.format(
                        key, self.meters[key])

        l_metrics = metrics_info['l'].replace('_', '-')
        r_metrics = metrics_info['r'].replace('_', '-')
        logger.log_info(
            'Validation metrics:\n  l-metrics\t=>\t{0}\n  r-metrics\t=>\t{1}\n'
            .format(l_metrics, r_metrics))
Ejemplo n.º 16
0
    def _task_model_iter(self, epoch, idx, is_train, mid, lbs, inp, gt, dc_gt,
                         fc_mask, dc_rampup_scale):
        """Run one iteration of a single task model and return its loss.

        The task model selected by ``mid`` is forwarded, its prediction is
        passed to the flaw detector, and the following terms are computed and
        recorded in ``self.meters``:

        * the supervised task loss on the first ``lbs`` samples,
        * the flaw correction loss (GCT and FC modes only),
        * the dynamic consistency loss (GCT and DC modes only).

        In validation (``is_train`` is False) the flaw detector loss and the
        task metrics are recorded as well.

        Arguments:
            epoch: index of the current epoch (forwarded to the visualization)
            idx: index of the current iteration
            is_train: True during training, False during validation
            mid: identifier of the task model, either 'l' or 'r'
            lbs: number of leading samples in the batch treated as labeled
            inp: tuple of input tensors
            gt: tuple of ground truth tensors
            dc_gt: pseudo ground truth of the dynamic consistency constraint
                   (may be None when that constraint is disabled)
            fc_mask: mask applied to the flaw correction loss in the GCT mode
            dc_rampup_scale: ramp-up coefficient of the dynamic consistency loss

        Returns:
            The sum of the task loss, the flaw correction loss and the
            dynamic consistency loss.
        """
        if mid == 'l':
            model, criterion = self.l_model, self.l_criterion
        elif mid == 'r':
            model, criterion = self.r_model, self.r_criterion
        else:
            # report the real problem instead of failing later on 'None.forward'
            # NOTE(review): 'logger.log_err' is assumed to abort, as the 'dc_gt'
            #               check below relies on the same behavior - confirm
            logger.log_err(
                'Unknown task model id \'{0}\', it should be \'l\' or \'r\'.'
                .format(mid))
            model, criterion = None, None

        # forward the task model
        resulter, _ = model.forward(inp)
        if 'pred' not in resulter.keys() or 'activated_pred' not in resulter.keys():
            self._pred_err()

        pred = tool.dict_value(resulter, 'pred')
        activated_pred = tool.dict_value(resulter, 'activated_pred')

        fd_resulter, _ = self.fd_model.forward(inp, activated_pred[0])
        flawmap = tool.dict_value(fd_resulter, 'flawmap')

        # calculate the supervised task constraint on the labeled data
        labeled_pred = func.split_tensor_tuple(pred, 0, lbs)
        labeled_gt = func.split_tensor_tuple(gt, 0, lbs)
        labeled_inp = func.split_tensor_tuple(inp, 0, lbs)
        task_loss = torch.mean(
            criterion.forward(labeled_pred, labeled_gt, labeled_inp))
        self.meters.update('{0}_task_loss'.format(mid), task_loss.data)

        # calculate the flaw correction constraint
        if self.args.ssl_mode in [MODE_GCT, MODE_FC]:
            if flawmap.shape == self.zero_df_gt.shape:
                # reuse the cached all-zero target when the shape matches
                zero_gt = self.zero_df_gt
            else:
                # 'zeros_like' allocates the target directly on the device (and
                # with the dtype) of 'flawmap' instead of assuming CUDA
                zero_gt = torch.zeros_like(flawmap)

            fc_ssl_loss = self.fd_criterion.forward(flawmap,
                                                    zero_gt,
                                                    is_ssl=True,
                                                    reduction=False)

            if self.args.ssl_mode == MODE_GCT:
                fc_ssl_loss = fc_mask * fc_ssl_loss

            fc_ssl_loss = self.args.fc_ssl_scale * torch.mean(fc_ssl_loss)
            self.meters.update('{0}_fc_loss'.format(mid), fc_ssl_loss.data)
        else:
            fc_ssl_loss = 0
            self.meters.update('{0}_fc_loss'.format(mid), fc_ssl_loss)

        # calculate the dynamic consistency constraint
        if self.args.ssl_mode in [MODE_GCT, MODE_DC]:
            if dc_gt is None:
                logger.log_err(
                    'The dynamic consistency constraint is enabled, '
                    'but no pseudo ground truth is given.')

            dc_ssl_loss = self.dc_criterion.forward(activated_pred[0], dc_gt)
            dc_ssl_loss = dc_rampup_scale * self.args.dc_ssl_scale * torch.mean(
                dc_ssl_loss)
            self.meters.update('{0}_dc_loss'.format(mid), dc_ssl_loss.data)
        else:
            dc_ssl_loss = 0
            self.meters.update('{0}_dc_loss'.format(mid), dc_ssl_loss)

        # the ground truth of the flaw detector is a fixed target, hence no gradient
        with torch.no_grad():
            flawmap_gt = self.fdgt_generator.forward(
                activated_pred[0],
                self.task_func.sslgct_prepare_task_gt_for_fdgt(gt[0]))

        # for validation
        if not is_train:
            fd_loss = self.args.fd_scale * self.fd_criterion.forward(
                flawmap, flawmap_gt)
            self.meters.update('{0}_fd_loss'.format(mid),
                               torch.mean(fd_loss).data)

            self.task_func.metrics(activated_pred,
                                   gt,
                                   inp,
                                   self.meters,
                                   id_str=mid)

        # visualization
        if self.args.visualize and idx % self.args.visual_freq == 0:
            with torch.no_grad():
                handled_flawmap = self.flawmap_handler(flawmap)[0]

            # 'dc_gt' is None when the dynamic consistency constraint is disabled;
            # indexing it unconditionally would raise a TypeError
            # NOTE(review): '_visualize' is expected to tolerate a None 'dc_gt'
            #               sample - confirm against its implementation
            dc_gt_sample = dc_gt[0] if dc_gt is not None else None

            self._visualize(
                epoch, idx, is_train, mid,
                func.split_tensor_tuple(inp, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(activated_pred, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(gt, 0, 1, reduce_dim=True),
                handled_flawmap, flawmap_gt[0], dc_gt_sample)

        loss = task_loss + fc_ssl_loss + dc_ssl_loss
        return loss
Ejemplo n.º 17
0
    def _train(self, data_loader, epoch):
        self.meters.reset()
        lbs = self.args.labeled_batch_size
        ubs = self.args.unlabeled_batch_size

        self.s_model.train()
        self.t_model.train()

        for idx, (inp, gt) in enumerate(data_loader):
            timer = time.time()

            # 'inp' and 'gt' are tuples
            inp, gt, mix_u_inp, mix_u_mask = self._batch_prehandle(
                inp, gt, True)
            if len(inp) > 1 and idx == 0:
                self._inp_warn()
            if len(gt) > 1 and idx == 0:
                self._gt_warn()

            # calculate the ramp-up coefficient of the consistency constraint
            cur_step = len(data_loader) * epoch + idx
            total_steps = len(data_loader) * self.args.cons_rampup_epochs
            cons_rampup_scale = func.sigmoid_rampup(cur_step, total_steps)

            self.s_optimizer.zero_grad()

            # -------------------------------------------------
            # For Labeled Samples
            # -------------------------------------------------
            l_inp = func.split_tensor_tuple(inp, 0, lbs)
            l_gt = func.split_tensor_tuple(gt, 0, lbs)

            # forward the labeled samples by the student model
            l_s_resulter, l_s_debugger = self.s_model.forward(l_inp)
            if not 'pred' in l_s_resulter.keys(
            ) or not 'activated_pred' in l_s_resulter.keys():
                self._pred_err()
            l_s_pred = tool.dict_value(l_s_resulter, 'pred')
            l_s_activated_pred = tool.dict_value(l_s_resulter,
                                                 'activated_pred')

            # calculate the supervised task loss on the labeled samples
            task_loss = self.s_criterion.forward(l_s_pred, l_gt, l_inp)
            task_loss = torch.mean(task_loss)
            self.meters.update('task_loss', task_loss.data)

            # -------------------------------------------------
            # For Unlabeled Samples
            # -------------------------------------------------
            if self.args.unlabeled_batch_size > 0:
                u_inp = func.split_tensor_tuple(inp, lbs, self.args.batch_size)

                # forward the original samples by the teacher model
                with torch.no_grad():
                    u_t_resulter, u_t_debugger = self.t_model.forward(u_inp)
                if not 'pred' in u_t_resulter.keys(
                ) or not 'activated_pred' in u_t_resulter.keys():
                    self._pred_err()
                u_t_activated_pred = tool.dict_value(u_t_resulter,
                                                     'activated_pred')

                # mix the activated pred from the teacher model as the pseudo gt
                u_t_activated_pred_1 = func.split_tensor_tuple(
                    u_t_activated_pred, 0, int(ubs / 2))
                u_t_activated_pred_2 = func.split_tensor_tuple(
                    u_t_activated_pred, int(ubs / 2), ubs)

                mix_u_t_activated_pred = []
                mix_u_t_confidence = []
                for up_1, up_2 in zip(u_t_activated_pred_1,
                                      u_t_activated_pred_2):
                    mp = mix_u_mask * up_1 + (1 - mix_u_mask) * up_2
                    mix_u_t_activated_pred.append(mp.detach())

                    # NOTE: here we just follow the official code of CutMix to calculate the confidence
                    #       but it is odd that all the samples use the same confidence (mean confidence)
                    u_t_confidence = (mp.max(dim=1)[0] >
                                      self.args.cons_threshold).float().mean()
                    mix_u_t_confidence.append(u_t_confidence.detach())

                mix_u_t_activated_pred = tuple(mix_u_t_activated_pred)

                # forward the mixed samples by the student model
                u_s_resulter, u_s_debugger = self.s_model.forward(mix_u_inp)
                if not 'pred' in u_s_resulter.keys(
                ) or not 'activated_pred' in u_s_resulter.keys():
                    self._pred_err()
                mix_u_s_activated_pred = tool.dict_value(
                    u_s_resulter, 'activated_pred')

                # calculate the consistency constraint
                cons_loss = 0
                for msap, mtap, confidence in zip(mix_u_s_activated_pred,
                                                  mix_u_t_activated_pred,
                                                  mix_u_t_confidence):
                    cons_loss += torch.mean(self.cons_criterion(
                        msap, mtap)) * confidence
                cons_loss = cons_rampup_scale * self.args.cons_scale * torch.mean(
                    cons_loss)
                self.meters.update('cons_loss', cons_loss.data)
            else:
                cons_loss = 0
                self.meters.update('cons_loss', cons_loss)

            # backward and update the student model
            loss = task_loss + cons_loss
            loss.backward()
            self.s_optimizer.step()

            # update the teacher model by EMA
            self._update_ema_variables(self.s_model, self.t_model,
                                       self.args.ema_decay, cur_step)

            # logging
            self.meters.update('batch_time', time.time() - timer)
            if idx % self.args.log_freq == 0:
                logger.log_info(
                    'step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                    '  student-{3}\t=>\t'
                    's-task-loss: {meters[task_loss]:.6f}\t'
                    's-cons-loss: {meters[cons_loss]:.6f}\n'.format(
                        epoch + 1,
                        idx,
                        len(data_loader),
                        self.args.task,
                        meters=self.meters))

            # visualization
            if self.args.visualize and idx % self.args.visual_freq == 0:
                self._visualize(
                    epoch, idx, True,
                    func.split_tensor_tuple(l_inp, 0, 1, reduce_dim=True),
                    func.split_tensor_tuple(l_s_activated_pred,
                                            0,
                                            1,
                                            reduce_dim=True),
                    func.split_tensor_tuple(l_gt, 0, 1, reduce_dim=True),
                    func.split_tensor_tuple(mix_u_inp, 0, 1, reduce_dim=True),
                    func.split_tensor_tuple(mix_u_s_activated_pred,
                                            0,
                                            1,
                                            reduce_dim=True),
                    func.split_tensor_tuple(mix_u_t_activated_pred,
                                            0,
                                            1,
                                            reduce_dim=True), mix_u_mask[0])

            # update iteration-based lrers
            if not self.args.is_epoch_lrer:
                self.s_lrer.step()

        # update epoch-based lrers
        if self.args.is_epoch_lrer:
            self.s_lrer.step()