Example 1
    def _train(self, data_loader, epoch):
        self.meters.reset()
        lbs = self.args.labeled_batch_size

        self.l_model.train()
        self.r_model.train()
        self.fd_model.train()

        # both 'inp' and 'gt' are tuples
        for idx, (inp, gt) in enumerate(data_loader):
            timer = time.time()

            (l_inp, l_gt), (r_inp, r_gt) = self._batch_prehandle(inp, gt)
            if len(l_gt) == len(r_gt) > 1 and idx == 0:
                self._inp_warn()

            # calculate the ramp-up coefficient of the dynamic consistency constraint
            cur_steps = len(data_loader) * epoch + idx
            total_steps = len(data_loader) * self.args.dc_rampup_epochs
            dc_rampup_scale = func.sigmoid_rampup(cur_steps, total_steps)

            # -----------------------------------------------------------------------------
            # step-0: pre-forwarding to save GPU memory
            #   - forward the task models and the flaw detector
            #   - generate pseudo ground truth for the unlabeled data if the dynamic
            #     consistency constraint is enabled
            # -----------------------------------------------------------------------------
            with torch.no_grad():
                l_resulter, l_debugger = self.l_model.forward(l_inp)
                l_activated_pred = tool.dict_value(l_resulter,
                                                   'activated_pred')
                r_resulter, r_debugger = self.r_model.forward(r_inp)
                r_activated_pred = tool.dict_value(r_resulter,
                                                   'activated_pred')

            # 'l_flawmap' and 'r_flawmap' will be used in step-2
            fd_resulter, fd_debugger = self.fd_model.forward(
                l_inp, l_activated_pred[0])
            l_flawmap = tool.dict_value(fd_resulter, 'flawmap')
            fd_resulter, fd_debugger = self.fd_model.forward(
                r_inp, r_activated_pred[0])
            r_flawmap = tool.dict_value(fd_resulter, 'flawmap')

            l_dc_gt, r_dc_gt = None, None
            l_fc_mask, r_fc_mask = None, None

            # generate the pseudo ground truth for the dynamic consistency constraint
            if self.args.ssl_mode in [MODE_GCT, MODE_DC]:
                with torch.no_grad():
                    l_handled_flawmap = self.flawmap_handler.forward(l_flawmap)
                    r_handled_flawmap = self.flawmap_handler.forward(r_flawmap)
                    l_dc_gt, r_dc_gt, l_fc_mask, r_fc_mask = self.dcgt_generator.forward(
                        l_activated_pred[0].detach(),
                        r_activated_pred[0].detach(), l_handled_flawmap,
                        r_handled_flawmap)

            # -----------------------------------------------------------------------------
            # step-1: train the task models
            # -----------------------------------------------------------------------------
            for param in self.fd_model.parameters():
                param.requires_grad = False

            # train the 'l' task model
            l_loss = self._task_model_iter(epoch, idx, True, 'l', lbs, l_inp,
                                           l_gt, l_dc_gt, l_fc_mask,
                                           dc_rampup_scale)
            self.l_optimizer.zero_grad()
            l_loss.backward()
            self.l_optimizer.step()

            # train the 'r' task model
            r_loss = self._task_model_iter(epoch, idx, True, 'r', lbs, r_inp,
                                           r_gt, r_dc_gt, r_fc_mask,
                                           dc_rampup_scale)
            self.r_optimizer.zero_grad()
            r_loss.backward()
            self.r_optimizer.step()

            # -----------------------------------------------------------------------------
            # step-2: train the flaw detector
            # -----------------------------------------------------------------------------
            for param in self.fd_model.parameters():
                param.requires_grad = True

            # generate the ground truth for the flaw detector (on labeled data only)
            with torch.no_grad():
                l_flawmap_gt = self.fdgt_generator.forward(
                    l_activated_pred[0][:lbs, ...].detach(),
                    self.task_func.sslgct_prepare_task_gt_for_fdgt(
                        l_gt[0][:lbs, ...]))
                r_flawmap_gt = self.fdgt_generator.forward(
                    r_activated_pred[0][:lbs, ...].detach(),
                    self.task_func.sslgct_prepare_task_gt_for_fdgt(
                        r_gt[0][:lbs, ...]))

            l_fd_loss = self.fd_criterion.forward(l_flawmap[:lbs, ...],
                                                  l_flawmap_gt)
            l_fd_loss = self.args.fd_scale * torch.mean(l_fd_loss)
            self.meters.update('l_fd_loss', l_fd_loss.data)

            r_fd_loss = self.fd_criterion.forward(r_flawmap[:lbs, ...],
                                                  r_flawmap_gt)
            r_fd_loss = self.args.fd_scale * torch.mean(r_fd_loss)
            self.meters.update('r_fd_loss', r_fd_loss.data)

            fd_loss = (l_fd_loss + r_fd_loss) / 2

            self.fd_optimizer.zero_grad()
            fd_loss.backward()
            self.fd_optimizer.step()

            # logging
            self.meters.update('batch_time', time.time() - timer)
            if idx % self.args.log_freq == 0:
                logger.log_info(
                    'step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                    '  l-{3}\t=>\t'
                    'l-task-loss: {meters[l_task_loss]:.6f}\t'
                    'l-dc-loss: {meters[l_dc_loss]:.6f}\t'
                    'l-fc-loss: {meters[l_fc_loss]:.6f}\n'
                    '  r-{3}\t=>\t'
                    'r-task-loss: {meters[r_task_loss]:.6f}\t'
                    'r-dc-loss: {meters[r_dc_loss]:.6f}\t'
                    'r-fc-loss: {meters[r_fc_loss]:.6f}\n'
                    '  fd\t=>\t'
                    'l-fd-loss: {meters[l_fd_loss]:.6f}\t'
                    'r-fd-loss: {meters[r_fd_loss]:.6f}\n'.format(
                        epoch,
                        idx,
                        len(data_loader),
                        self.args.task,
                        meters=self.meters))

            # the flaw detector uses an iteration-based PolynomialLR scheduler [ITER_LRERS]
            self.fd_lrer.step()
            # update iteration-based lrers
            if not self.args.is_epoch_lrer:
                self.l_lrer.step()
                self.r_lrer.step()

        # update epoch-based lrers
        if self.args.is_epoch_lrer:
            self.l_lrer.step()
            self.r_lrer.step()
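
The ramp-up coefficient in step-0 comes from func.sigmoid_rampup. For reference, here is a minimal sketch of the standard sigmoid ramp-up schedule (the usual Mean Teacher formulation); the actual helper in the codebase may differ in detail:

import numpy as np

def sigmoid_rampup(cur_steps, total_steps):
    # ramp a coefficient from 0 to 1 over 'total_steps' training steps
    if total_steps == 0:
        return 1.0
    cur_steps = float(np.clip(cur_steps, 0.0, total_steps))
    phase = 1.0 - cur_steps / total_steps
    return float(np.exp(-5.0 * phase * phase))
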
Example 2
    def _train(self, data_loader, epoch):
        self.meters.reset()
        lbs = self.args.labeled_batch_size
        ubs = self.args.unlabeled_batch_size

        self.s_model.train()
        self.t_model.train()

        for idx, (inp, gt) in enumerate(data_loader):
            timer = time.time()

            # 'inp' and 'gt' are tuples
            inp, gt, mix_u_inp, mix_u_mask = self._batch_prehandle(
                inp, gt, True)
            if len(inp) > 1 and idx == 0:
                self._inp_warn()
            if len(gt) > 1 and idx == 0:
                self._gt_warn()

            # calculate the ramp-up coefficient of the consistency constraint
            cur_step = len(data_loader) * epoch + idx
            total_steps = len(data_loader) * self.args.cons_rampup_epochs
            cons_rampup_scale = func.sigmoid_rampup(cur_step, total_steps)

            self.s_optimizer.zero_grad()

            # -------------------------------------------------
            # For Labeled Samples
            # -------------------------------------------------
            l_inp = func.split_tensor_tuple(inp, 0, lbs)
            l_gt = func.split_tensor_tuple(gt, 0, lbs)

            # forward the labeled samples by the student model
            l_s_resulter, l_s_debugger = self.s_model.forward(l_inp)
            if 'pred' not in l_s_resulter or 'activated_pred' not in l_s_resulter:
                self._pred_err()
            l_s_pred = tool.dict_value(l_s_resulter, 'pred')
            l_s_activated_pred = tool.dict_value(l_s_resulter,
                                                 'activated_pred')

            # calculate the supervised task loss on the labeled samples
            task_loss = self.s_criterion.forward(l_s_pred, l_gt, l_inp)
            task_loss = torch.mean(task_loss)
            self.meters.update('task_loss', task_loss.data)

            # -------------------------------------------------
            # For Unlabeled Samples
            # -------------------------------------------------
            if self.args.unlabeled_batch_size > 0:
                u_inp = func.split_tensor_tuple(inp, lbs, self.args.batch_size)

                # forward the original samples by the teacher model
                with torch.no_grad():
                    u_t_resulter, u_t_debugger = self.t_model.forward(u_inp)
                if 'pred' not in u_t_resulter or 'activated_pred' not in u_t_resulter:
                    self._pred_err()
                u_t_activated_pred = tool.dict_value(u_t_resulter,
                                                     'activated_pred')

                # mix the activated pred from the teacher model as the pseudo gt
                u_t_activated_pred_1 = func.split_tensor_tuple(
                    u_t_activated_pred, 0, int(ubs / 2))
                u_t_activated_pred_2 = func.split_tensor_tuple(
                    u_t_activated_pred, int(ubs / 2), ubs)

                mix_u_t_activated_pred = []
                mix_u_t_confidence = []
                for up_1, up_2 in zip(u_t_activated_pred_1,
                                      u_t_activated_pred_2):
                    mp = mix_u_mask * up_1 + (1 - mix_u_mask) * up_2
                    mix_u_t_activated_pred.append(mp.detach())

                    # NOTE: we follow the official CutMix code to compute the confidence,
                    #       although it is odd that all samples share the same (mean) confidence
                    u_t_confidence = (mp.max(dim=1)[0] >
                                      self.args.cons_threshold).float().mean()
                    mix_u_t_confidence.append(u_t_confidence.detach())

                mix_u_t_activated_pred = tuple(mix_u_t_activated_pred)

                # forward the mixed samples by the student model
                u_s_resulter, u_s_debugger = self.s_model.forward(mix_u_inp)
                if 'pred' not in u_s_resulter or 'activated_pred' not in u_s_resulter:
                    self._pred_err()
                mix_u_s_activated_pred = tool.dict_value(
                    u_s_resulter, 'activated_pred')

                # calculate the consistency constraint
                cons_loss = 0
                for msap, mtap, confidence in zip(mix_u_s_activated_pred,
                                                  mix_u_t_activated_pred,
                                                  mix_u_t_confidence):
                    cons_loss += torch.mean(self.cons_criterion(
                        msap, mtap)) * confidence
                cons_loss = cons_rampup_scale * self.args.cons_scale * torch.mean(
                    cons_loss)
                self.meters.update('cons_loss', cons_loss.data)
            else:
                cons_loss = 0
                self.meters.update('cons_loss', cons_loss)

            # backward and update the student model
            loss = task_loss + cons_loss
            loss.backward()
            self.s_optimizer.step()

            # update the teacher model by EMA
            self._update_ema_variables(self.s_model, self.t_model,
                                       self.args.ema_decay, cur_step)

            # logging
            self.meters.update('batch_time', time.time() - timer)
            if idx % self.args.log_freq == 0:
                logger.log_info(
                    'step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                    '  student-{3}\t=>\t'
                    's-task-loss: {meters[task_loss]:.6f}\t'
                    's-cons-loss: {meters[cons_loss]:.6f}\n'.format(
                        epoch + 1,
                        idx,
                        len(data_loader),
                        self.args.task,
                        meters=self.meters))

            # visualization
            if self.args.visualize and idx % self.args.visual_freq == 0:
                self._visualize(
                    epoch, idx, True,
                    func.split_tensor_tuple(l_inp, 0, 1, reduce_dim=True),
                    func.split_tensor_tuple(l_s_activated_pred,
                                            0,
                                            1,
                                            reduce_dim=True),
                    func.split_tensor_tuple(l_gt, 0, 1, reduce_dim=True),
                    func.split_tensor_tuple(mix_u_inp, 0, 1, reduce_dim=True),
                    func.split_tensor_tuple(mix_u_s_activated_pred,
                                            0,
                                            1,
                                            reduce_dim=True),
                    func.split_tensor_tuple(mix_u_t_activated_pred,
                                            0,
                                            1,
                                            reduce_dim=True), mix_u_mask[0])

            # update iteration-based lrers
            if not self.args.is_epoch_lrer:
                self.s_lrer.step()

        # update epoch-based lrers
        if self.args.is_epoch_lrer:
            self.s_lrer.step()
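
Example 2 (and Example 4 below) refreshes the teacher weights through self._update_ema_variables after every student update. A minimal sketch of the usual exponential-moving-average update, assuming the common Mean Teacher formulation (the real helper may differ):

import torch

def update_ema_variables(s_model, t_model, ema_decay, cur_step):
    # ramp the decay up from 0 towards 'ema_decay' during the first steps
    alpha = min(1.0 - 1.0 / (cur_step + 1), ema_decay)
    with torch.no_grad():
        for t_param, s_param in zip(t_model.parameters(), s_model.parameters()):
            t_param.mul_(alpha).add_(s_param, alpha=1.0 - alpha)
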
Example 3
    def _train(self, data_loader, epoch):
        self.meters.reset()
        lbs = self.args.labeled_batch_size

        self.model.train()

        for idx, (inp, gt) in enumerate(data_loader):
            timer = time.time()

            inp, gt = self._batch_prehandle(inp, gt)
            if len(gt) > 1 and idx == 0:
                self._data_err()

            # TODO: support more ramp-up functions
            # calculate the ramp-up coefficient of the consistency constraint
            cur_step = len(data_loader) * epoch + idx
            total_steps = len(data_loader) * self.args.cons_rampup_epochs
            cons_rampup_scale = func.sigmoid_rampup(cur_step, total_steps)

            self.optimizer.zero_grad()

            # -----------------------------------------------------------
            # For Labeled Data
            # -----------------------------------------------------------
            l_gt = func.split_tensor_tuple(gt, 0, lbs)
            l_inp = func.split_tensor_tuple(inp, 0, lbs)

            # forward the wrapped CCT model
            resulter, debugger = self.model.forward(l_inp, l_gt, False)
            l_pred = tool.dict_value(resulter, 'pred')
            l_activated_pred = tool.dict_value(resulter, 'activated_pred')

            task_loss = tool.dict_value(resulter, 'task_loss', err=True)
            task_loss = task_loss.mean()
            self.meters.update('task_loss', task_loss.data)

            # -----------------------------------------------------------
            # For Unlabeled Data
            # -----------------------------------------------------------
            if self.args.unlabeled_batch_size > 0:
                ul_gt = func.split_tensor_tuple(gt, lbs, self.args.batch_size)
                ul_inp = func.split_tensor_tuple(inp, lbs, self.args.batch_size)

                # forward the wrapped CCT model
                resulter, debugger = self.model.forward(ul_inp, ul_gt, True)
                ul_pred = tool.dict_value(resulter, 'pred')
                ul_activated_pred = tool.dict_value(resulter, 'activated_pred')
                ul_ad_preds = tool.dict_value(resulter, 'ul_ad_preds')

                cons_loss = tool.dict_value(resulter, 'cons_loss', err=True)
                cons_loss = cons_loss.mean()
                cons_loss = cons_rampup_scale * self.args.cons_scale * cons_loss
                self.meters.update('cons_loss', cons_loss.data)

            else:
                cons_loss = 0
                self.meters.update('cons_loss', cons_loss)

            # backward and update the wrapped CCT model
            loss = task_loss + cons_loss
            loss.backward()
            self.optimizer.step()

            # logging
            self.meters.update('batch_time', time.time() - timer)
            if idx % self.args.log_freq == 0:
                logger.log_info('step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                                '  task-{3}\t=>\t'
                                'task-loss: {meters[task_loss]:.6f}\t'
                                'cons-loss: {meters[cons_loss]:.6f}\n'
                                .format(epoch, idx, len(data_loader), self.args.task, meters=self.meters))

            # visualization
            if self.args.visualize and idx % self.args.visual_freq == 0:
                self._visualize(epoch, idx, True, 
                                func.split_tensor_tuple(inp, 0, 1, reduce_dim=True),
                                func.split_tensor_tuple(l_activated_pred, 0, 1, reduce_dim=True),
                                func.split_tensor_tuple(gt, 0, 1, reduce_dim=True))

            # update iteration-based lrers
            if not self.args.is_epoch_lrer:
                self.lrer.step()
        
        # update epoch-based lrers
        if self.args.is_epoch_lrer:
            self.lrer.step()
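
Examples 2-4 slice the labeled and unlabeled portions of the batch with func.split_tensor_tuple. A minimal sketch of what such a helper is assumed to do here (the actual implementation may differ):

def split_tensor_tuple(tensor_tuple, start, end, reduce_dim=False):
    # slice every tensor in the tuple along the batch dimension
    sliced = []
    for t in tensor_tuple:
        s = t[start:end, ...]
        if reduce_dim:
            # drop the leading batch dimension when a single sample is requested
            s = s.squeeze(0)
        sliced.append(s)
    return tuple(sliced)
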
Example 4
    def _train(self, data_loader, epoch):
        self.meters.reset()
        lbs = self.args.labeled_batch_size

        self.s_model.train()
        self.t_model.train()

        for idx, (inp, gt) in enumerate(data_loader):
            timer = time.time()

            # 's_inp', 't_inp' and 'gt' are tuples
            s_inp, t_inp, gt = self._batch_prehandle(inp, gt, True)
            if len(gt) > 1 and idx == 0:
                self._inp_warn()

            # calculate the ramp-up coefficient of the consistency constraint
            cur_step = len(data_loader) * epoch + idx
            total_steps = len(data_loader) * self.args.cons_rampup_epochs
            cons_rampup_scale = func.sigmoid_rampup(cur_step, total_steps)

            self.s_optimizer.zero_grad()

            # forward the student model
            s_resulter, s_debugger = self.s_model.forward(s_inp)
            if 'pred' not in s_resulter or 'activated_pred' not in s_resulter:
                self._pred_err()
            s_pred = tool.dict_value(s_resulter, 'pred')
            s_activated_pred = tool.dict_value(s_resulter, 'activated_pred')

            # calculate the supervised task constraint on the labeled data
            l_s_pred = func.split_tensor_tuple(s_pred, 0, lbs)
            l_gt = func.split_tensor_tuple(gt, 0, lbs)
            l_s_inp = func.split_tensor_tuple(s_inp, 0, lbs)

            # 's_task_loss' is a 1-D tensor with one element per labeled sample
            s_task_loss = self.s_criterion.forward(l_s_pred, l_gt, l_s_inp)
            s_task_loss = torch.mean(s_task_loss)
            self.meters.update('s_task_loss', s_task_loss.data)

            # forward the teacher model
            with torch.no_grad():
                t_resulter, t_debugger = self.t_model.forward(t_inp)
                if 'pred' not in t_resulter:
                    self._pred_err()
                t_pred = tool.dict_value(t_resulter, 'pred')
                t_activated_pred = tool.dict_value(t_resulter,
                                                   'activated_pred')

                # calculate 't_task_loss' for recording
                l_t_pred = func.split_tensor_tuple(t_pred, 0, lbs)
                l_t_inp = func.split_tensor_tuple(t_inp, 0, lbs)
                t_task_loss = self.s_criterion.forward(l_t_pred, l_gt, l_t_inp)
                t_task_loss = torch.mean(t_task_loss)
                self.meters.update('t_task_loss', t_task_loss.data)

            # calculate the consistency constraint from the teacher model to the student model
            # detaching the teacher prediction is enough to block its gradients
            t_pseudo_gt = t_pred[0].detach()

            if self.args.cons_for_labeled:
                cons_loss = self.cons_criterion(s_pred[0], t_pseudo_gt)
            elif self.args.unlabeled_batch_size > 0:
                cons_loss = self.cons_criterion(s_pred[0][lbs:, ...],
                                                t_pseudo_gt[lbs:, ...])
            else:
                cons_loss = self.zero_tensor
            cons_loss = cons_rampup_scale * self.args.cons_scale * torch.mean(
                cons_loss)
            self.meters.update('cons_loss', cons_loss.data)

            # backward and update the student model
            loss = s_task_loss + cons_loss
            loss.backward()
            self.s_optimizer.step()

            # update the teacher model by EMA
            self._update_ema_variables(self.s_model, self.t_model,
                                       self.args.ema_decay, cur_step)

            # logging
            self.meters.update('batch_time', time.time() - timer)
            if idx % self.args.log_freq == 0:
                logger.log_info(
                    'step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                    '  student-{3}\t=>\t'
                    's-task-loss: {meters[s_task_loss]:.6f}\t'
                    's-cons-loss: {meters[cons_loss]:.6f}\n'
                    '  teacher-{3}\t=>\t'
                    't-task-loss: {meters[t_task_loss]:.6f}\n'.format(
                        epoch + 1,
                        idx,
                        len(data_loader),
                        self.args.task,
                        meters=self.meters))

            # visualization
            if self.args.visualize and idx % self.args.visual_freq == 0:
                self._visualize(
                    epoch, idx, True,
                    func.split_tensor_tuple(s_inp, 0, 1, reduce_dim=True),
                    func.split_tensor_tuple(s_activated_pred,
                                            0,
                                            1,
                                            reduce_dim=True),
                    func.split_tensor_tuple(t_inp, 0, 1, reduce_dim=True),
                    func.split_tensor_tuple(t_activated_pred,
                                            0,
                                            1,
                                            reduce_dim=True),
                    func.split_tensor_tuple(gt, 0, 1, reduce_dim=True))

            # update iteration-based lrers
            if not self.args.is_epoch_lrer:
                self.s_lrer.step()

        # update epoch-based lrers
        if self.args.is_epoch_lrer:
            self.s_lrer.step()
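
The consistency term in Example 4 compares the student logits against the detached teacher logits through self.cons_criterion. A common choice for such a criterion is a per-element softmax MSE; a minimal sketch under that assumption (the criterion actually used is configured elsewhere and may differ):

import torch
import torch.nn.functional as F

class SoftmaxMSELoss(torch.nn.Module):
    # per-element MSE between the softmax outputs of two logit tensors
    def forward(self, pred, pseudo_gt):
        # 'pseudo_gt' is already detached by the caller
        return F.mse_loss(F.softmax(pred, dim=1),
                          F.softmax(pseudo_gt, dim=1),
                          reduction='none')

The trainer then takes torch.mean over the returned tensor and scales it by the ramp-up coefficient and self.args.cons_scale, matching the usage above.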