Example #1
    def _validate(self, data_loader, epoch):
        self.meters.reset()

        self.model.eval()

        for idx, (inp, gt) in enumerate(data_loader):
            timer = time.time()

            inp, gt = self._batch_prehandle(inp, gt)
            # raise a data error once (on the first batch) if the ground
            # truth tuple contains more than one element
            if len(gt) > 1 and idx == 0:
                self._data_err()

            resulter, debugger = self.model.forward(inp, gt, False)
            pred = tool.dict_value(resulter, 'pred')
            activated_pred = tool.dict_value(resulter, 'activated_pred')

            task_loss = tool.dict_value(resulter, 'task_loss', err=True)
            task_loss = task_loss.mean()
            self.meters.update('task_loss', task_loss.data)

            self.task_func.metrics(activated_pred,
                                   gt,
                                   inp,
                                   self.meters,
                                   id_str='task')

            self.meters.update('batch_time', time.time() - timer)
            if idx % self.args.log_freq == 0:
                logger.log_info(
                    'step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                    '  task-{3}\t=>\t'
                    'task-loss: {meters[task_loss]:.6f}\t'.format(
                        epoch + 1,
                        idx,
                        len(data_loader),
                        self.args.task,
                        meters=self.meters))

            # visualization
            if self.args.visualize and idx % self.args.visual_freq == 0:
                self._visualize(
                    epoch, idx, False,
                    func.split_tensor_tuple(inp, 0, 1, reduce_dim=True),
                    func.split_tensor_tuple(activated_pred,
                                            0,
                                            1,
                                            reduce_dim=True),
                    func.split_tensor_tuple(gt, 0, 1, reduce_dim=True))

        # metrics
        metrics_info = {'task': ''}
        for key in sorted(list(self.meters.keys())):
            if self.task_func.METRIC_STR in key:
                for id_str in metrics_info.keys():
                    if key.startswith(id_str):
                        metrics_info[id_str] += '{0}: {1:.6f}\t'.format(
                            key, self.meters[key])

        logger.log_info('Validation metrics:\n task-metrics\t=>\t{0}\n'.format(
            metrics_info['task'].replace('_', '-')))
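
All four examples slice batched tensor tuples with 'func.split_tensor_tuple'. That helper is not shown on this page; judging from the call sites above (a start index, an end index, and an optional 'reduce_dim' that squeezes a single-sample batch axis), a minimal sketch might look like the following. This is an assumption about the helper's behavior, not the library's actual implementation.

def split_tensor_tuple(tensors, start, end, reduce_dim=False):
    # hypothetical sketch: slice every tensor in the tuple along the
    # batch (first) dimension; with reduce_dim=True, squeeze the batch
    # axis of a single-sample slice so it can be visualized directly
    result = []
    for t in tensors:
        chunk = t[start:end, ...]
        if reduce_dim:
            chunk = chunk.squeeze(0)
        result.append(chunk)
    return tuple(result)

Under this reading, 'func.split_tensor_tuple(inp, 0, 1, reduce_dim=True)' in the visualization calls returns the first sample of each input tensor with its batch axis removed.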
Example #2
    def _train(self, data_loader, epoch):
        self.meters.reset()
        lbs = self.args.labeled_batch_size

        self.model.train()
        self.d_model.train()

        # both 'inp' and 'gt' are tuples
        for idx, (inp, gt) in enumerate(data_loader):
            timer = time.time()

            inp, gt = self._batch_prehandle(inp, gt)
            if len(gt) > 1 and idx == 0:
                self._inp_warn()

            # -----------------------------------------------------------------------------
            # step-1: train the task model
            # -----------------------------------------------------------------------------
            self.optimizer.zero_grad()

            # forward the task model
            resulter, debugger = self.model.forward(inp)
            if 'pred' not in resulter or 'activated_pred' not in resulter:
                self._pred_err()

            pred = tool.dict_value(resulter, 'pred')
            activated_pred = tool.dict_value(resulter, 'activated_pred')

            # forward the FC discriminator
            # 'confidence_map' is a tensor
            d_resulter, d_debugger = self.d_model.forward(activated_pred[0])
            confidence_map = tool.dict_value(d_resulter, 'confidence')

            # calculate the supervised task constraint on the labeled data
            l_pred = func.split_tensor_tuple(pred, 0, lbs)
            l_gt = func.split_tensor_tuple(gt, 0, lbs)
            l_inp = func.split_tensor_tuple(inp, 0, lbs)

            # 'task_loss' is a 1-dim tensor with one element per labeled sample
            task_loss = self.criterion.forward(l_pred, l_gt, l_inp)
            task_loss = torch.mean(task_loss)
            self.meters.update('task_loss', task_loss.data)

            # calculate the adversarial constraint for the labeled data
            if self.args.adv_for_labeled:
                l_confidence_map = confidence_map[:lbs, ...]

                # preprocess prediction and ground truth for the adversarial constraint
                l_adv_confidence_map, l_adv_confidence_gt = \
                    self.task_func.ssladv_preprocess_fcd_criterion(l_confidence_map, l_gt[0], True)
                l_adv_loss = self.d_criterion(l_adv_confidence_map,
                                              l_adv_confidence_gt)
                labeled_adv_loss = self.args.labeled_adv_scale * torch.mean(
                    l_adv_loss)
                self.meters.update('labeled_adv_loss', labeled_adv_loss.data)
            else:
                labeled_adv_loss = 0
                self.meters.update('labeled_adv_loss', labeled_adv_loss)

            # calculate the adversarial constraint for the unlabeled data
            if self.args.unlabeled_batch_size > 0:
                u_confidence_map = confidence_map[lbs:self.args.batch_size,
                                                  ...]

                # preprocess prediction and ground truth for the adversarial constraint
                u_adv_confidence_map, u_adv_confidence_gt = \
                    self.task_func.ssladv_preprocess_fcd_criterion(u_confidence_map, None, True)
                u_adv_loss = self.d_criterion(u_adv_confidence_map,
                                              u_adv_confidence_gt)
                unlabeled_adv_loss = self.args.unlabeled_adv_scale * torch.mean(
                    u_adv_loss)
                self.meters.update('unlabeled_adv_loss',
                                   unlabeled_adv_loss.data)
            else:
                unlabeled_adv_loss = 0
                self.meters.update('unlabeled_adv_loss', unlabeled_adv_loss)

            adv_loss = labeled_adv_loss + unlabeled_adv_loss

            # backward and update the task model
            loss = task_loss + adv_loss
            loss.backward()
            self.optimizer.step()

            # -----------------------------------------------------------------------------
            # step-2: train the FC discriminator
            # -----------------------------------------------------------------------------
            self.d_optimizer.zero_grad()

            # forward the task prediction (fake)
            if self.args.unlabeled_for_discriminator:
                fake_pred = activated_pred[0].detach()
            else:
                fake_pred = activated_pred[0][:lbs, ...].detach()

            d_resulter, d_debugger = self.d_model.forward(fake_pred)
            fake_confidence_map = tool.dict_value(d_resulter, 'confidence')

            l_fake_confidence_map = fake_confidence_map[:lbs, ...]
            l_fake_confidence_map, l_fake_confidence_gt = \
                self.task_func.ssladv_preprocess_fcd_criterion(l_fake_confidence_map, l_gt[0], False)

            if self.args.unlabeled_for_discriminator and self.args.unlabeled_batch_size != 0:
                u_fake_confidence_map = fake_confidence_map[
                    lbs:self.args.batch_size, ...]
                u_fake_confidence_map, u_fake_confidence_gt = \
                    self.task_func.ssladv_preprocess_fcd_criterion(u_fake_confidence_map, None, False)

                fake_confidence_map = torch.cat(
                    (l_fake_confidence_map, u_fake_confidence_map), dim=0)
                fake_confidence_gt = torch.cat(
                    (l_fake_confidence_gt, u_fake_confidence_gt), dim=0)

            else:
                fake_confidence_map, fake_confidence_gt = l_fake_confidence_map, l_fake_confidence_gt

            fake_d_loss = self.d_criterion.forward(fake_confidence_map,
                                                   fake_confidence_gt)
            fake_d_loss = self.args.discriminator_scale * torch.mean(
                fake_d_loss)
            self.meters.update('fake_d_loss', fake_d_loss.data)

            # forward the ground truth (real)
            # convert the format of the ground truth
            real_gt = self.task_func.ssladv_convert_task_gt_to_fcd_input(
                l_gt[0])
            d_resulter, d_debugger = self.d_model.forward(real_gt)
            real_confidence_map = tool.dict_value(d_resulter, 'confidence')

            real_confidence_map, real_confidence_gt = \
                self.task_func.ssladv_preprocess_fcd_criterion(real_confidence_map, l_gt[0], True)

            real_d_loss = self.d_criterion(real_confidence_map,
                                           real_confidence_gt)
            real_d_loss = self.args.discriminator_scale * torch.mean(
                real_d_loss)
            self.meters.update('real_d_loss', real_d_loss.data)

            # backward and update the FC discriminator
            d_loss = (fake_d_loss + real_d_loss) / 2
            d_loss.backward()
            self.d_optimizer.step()

            # logging
            self.meters.update('batch_time', time.time() - timer)
            if idx % self.args.log_freq == 0:
                logger.log_info(
                    'step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                    '  task-{3}\t=>\t'
                    'task-loss: {meters[task_loss]:.6f}\t'
                    'labeled-adv-loss: {meters[labeled_adv_loss]:.6f}\t'
                    'unlabeled-adv-loss: {meters[unlabeled_adv_loss]:.6f}\n'
                    '  fc-discriminator\t=>\t'
                    'fake-d-loss: {meters[fake_d_loss]:.6f}\t'
                    'real-d-loss: {meters[real_d_loss]:.6f}\n'.format(
                        epoch + 1,
                        idx,
                        len(data_loader),
                        self.args.task,
                        meters=self.meters))

            # visualization
            if self.args.visualize and idx % self.args.visual_freq == 0:
                u_inp_sample, u_pred_sample, u_cmap_sample = None, None, None
                if self.args.unlabeled_batch_size > 0:
                    u_inp_sample = func.split_tensor_tuple(inp,
                                                           lbs,
                                                           lbs + 1,
                                                           reduce_dim=True)
                    u_pred_sample = func.split_tensor_tuple(activated_pred,
                                                            lbs,
                                                            lbs + 1,
                                                            reduce_dim=True)
                    u_cmap_sample = torch.sigmoid(fake_confidence_map[lbs])

                self._visualize(
                    epoch, idx, True,
                    func.split_tensor_tuple(inp, 0, 1, reduce_dim=True),
                    func.split_tensor_tuple(activated_pred,
                                            0,
                                            1,
                                            reduce_dim=True),
                    func.split_tensor_tuple(gt, 0, 1, reduce_dim=True),
                    torch.sigmoid(confidence_map[0]), u_inp_sample,
                    u_pred_sample, u_cmap_sample)

            # the FC discriminator uses PolynomialLR, an iteration-based lrer (ITER_LRERS)
            self.d_lrer.step()
            # update iteration-based lrers
            if not self.args.is_epoch_lrer:
                self.lrer.step()

        # update epoch-based lrers
        if self.args.is_epoch_lrer:
            self.lrer.step()
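
Example #2 above trains the task model and a fully convolutional (FC) discriminator in alternating steps, delegating target construction to 'task_func.ssladv_preprocess_fcd_criterion', which is not shown on this page. From its call sites, it appears to pair a confidence map with a per-pixel real/fake target, the final boolean argument selecting the 'real' side. A hypothetical sketch of that preprocessing, ignoring any pixel masking the real implementation may perform with the ground truth:

import torch

def preprocess_fcd_criterion(confidence_map, gt, is_real):
    # hypothetical stand-in for 'ssladv_preprocess_fcd_criterion':
    # build a per-pixel binary target for the FC discriminator,
    # 1.0 where the input should be judged real and 0.0 where fake
    # (the actual helper presumably also uses 'gt' to mask ignored pixels)
    confidence_gt = torch.full_like(confidence_map, 1.0 if is_real else 0.0)
    return confidence_map, confidence_gt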
Example #3
    def _train(self, data_loader, epoch):
        self.meters.reset()
        original_lbs = int(self.args.labeled_batch_size / 2)
        original_bs = int(self.args.batch_size / 2)

        self.model.train()

        for idx, (inp, gt) in enumerate(data_loader):
            timer = time.time()

            # the rotated samples are generated in the 'self._batch_prehandle' function
            # both 'inp' and 'gt' are tuples
            # the last element in the tuple 'gt' is the ground truth of the rotation angle
            inp, gt = self._batch_prehandle(inp, gt, True)
            # warn once (on the first batch) if the task ground truth,
            # excluding the rotation label, contains more than one element
            if len(gt) - 1 > 1 and idx == 0:
                self._inp_warn()

            self.optimizer.zero_grad()

            # forward the model
            resulter, debugger = self.model.forward(inp)
            pred = tool.dict_value(resulter, 'pred')
            activated_pred = tool.dict_value(resulter, 'activated_pred')
            pred_rotation = tool.dict_value(resulter, 'rotation')

            # calculate the supervised task constraint on the un-rotated labeled data
            l_pred = func.split_tensor_tuple(pred, 0, original_lbs)
            l_gt = func.split_tensor_tuple(gt, 0, original_lbs)
            l_inp = func.split_tensor_tuple(inp, 0, original_lbs)

            unrotated_task_loss = self.criterion.forward(
                l_pred, l_gt[:-1], l_inp)
            unrotated_task_loss = torch.mean(unrotated_task_loss)
            self.meters.update('unrotated_task_loss', unrotated_task_loss.data)

            # calculate the supervised task constraint on the rotated labeled data
            l_rotated_pred = func.split_tensor_tuple(
                pred, original_bs, original_bs + original_lbs)
            l_rotated_gt = func.split_tensor_tuple(gt, original_bs,
                                                   original_bs + original_lbs)
            l_rotated_inp = func.split_tensor_tuple(inp, original_bs,
                                                    original_bs + original_lbs)

            rotated_task_loss = self.criterion.forward(l_rotated_pred,
                                                       l_rotated_gt[:-1],
                                                       l_rotated_inp)
            rotated_task_loss = self.args.rotated_sup_scale * torch.mean(
                rotated_task_loss)
            self.meters.update('rotated_task_loss', rotated_task_loss.data)

            task_loss = unrotated_task_loss + rotated_task_loss

            # calculate the self-supervised rotation constraint
            rotation_loss = self.rotation_criterion.forward(
                pred_rotation, gt[-1])
            rotation_loss = self.args.rotation_scale * torch.mean(
                rotation_loss)
            self.meters.update('rotation_loss', rotation_loss.data)

            # backward and update the model
            loss = task_loss + rotation_loss
            loss.backward()
            self.optimizer.step()

            # calculate the accuracy of the rotation classifier
            _, angle_idx = pred_rotation.topk(1, 1, True, True)
            angle_idx = angle_idx.t()
            rotation_acc = angle_idx.eq(gt[-1].view(1,
                                                    -1).expand_as(angle_idx))
            rotation_acc = rotation_acc.view(-1).float().sum(
                0, keepdim=True).mul_(100.0 / self.args.batch_size)
            self.meters.update('rotation_acc', rotation_acc.data[0])

            # logging
            self.meters.update('batch_time', time.time() - timer)
            if idx % self.args.log_freq == 0:
                logger.log_info(
                    'step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                    '  task-{3}\t=>\t'
                    'unrotated-task-loss: {meters[unrotated_task_loss]:.6f}\t'
                    'rotated-task-loss: {meters[rotated_task_loss]:.6f}\n'
                    '  rotation-{3}\t=>\t'
                    'rotation-loss: {meters[rotation_loss]:.6f}\t'
                    'rotation-acc: {meters[rotation_acc]:.6f}\n'.format(
                        epoch + 1,
                        idx,
                        len(data_loader),
                        self.args.task,
                        meters=self.meters))

            # visualization
            if self.args.visualize and idx % self.args.visual_freq == 0:
                self._visualize(
                    epoch, idx, True,
                    func.split_tensor_tuple(inp, 0, 1, reduce_dim=True),
                    func.split_tensor_tuple(activated_pred,
                                            0,
                                            1,
                                            reduce_dim=True),
                    func.split_tensor_tuple(gt[:-1], 0, 1, reduce_dim=True))

            # update iteration-based lrers
            if not self.args.is_epoch_lrer:
                self.lrer.step()

        # update epoch-based lrers
        if self.args.is_epoch_lrer:
            self.lrer.step()
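
Example #3 above relies on 'self._batch_prehandle' to append rotated copies of the batch (hence the 'original_bs' offsets) and to store the rotation class as the last element of 'gt'. The sketch below shows one plausible form of that augmentation, using random multiples of 90 degrees as the four rotation classes; it is a hypothetical reconstruction, not the library's actual prehandler:

import torch

def make_rotated_batch(images):
    # rotate each image by a random multiple of 90 degrees and record
    # the rotation class (0..3) as the self-supervised label
    angles = torch.randint(0, 4, (images.size(0),))
    rotated = torch.stack([torch.rot90(img, k=int(k), dims=(-2, -1))
                           for img, k in zip(images, angles)])
    # place the un-rotated half before the rotated half, matching the
    # indexing with 'original_bs' used in Example #3
    inp = torch.cat((images, rotated), dim=0)
    rotation_gt = torch.cat(
        (torch.zeros(images.size(0), dtype=torch.long), angles), dim=0)
    return inp, rotation_gt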
Example #4
    def _train(self, data_loader, epoch):
        self.meters.reset()
        lbs = self.args.labeled_batch_size

        self.s_model.train()
        self.t_model.train()

        for idx, (inp, gt) in enumerate(data_loader):
            timer = time.time()

            # 's_inp', 't_inp' and 'gt' are tuples
            s_inp, t_inp, gt = self._batch_prehandle(inp, gt, True)
            if len(gt) > 1 and idx == 0:
                self._inp_warn()

            # calculate the ramp-up coefficient of the consistency constraint
            cur_step = len(data_loader) * epoch + idx
            total_steps = len(data_loader) * self.args.cons_rampup_epochs
            cons_rampup_scale = func.sigmoid_rampup(cur_step, total_steps)

            self.s_optimizer.zero_grad()

            # forward the student model
            s_resulter, s_debugger = self.s_model.forward(s_inp)
            if 'pred' not in s_resulter or 'activated_pred' not in s_resulter:
                self._pred_err()
            s_pred = tool.dict_value(s_resulter, 'pred')
            s_activated_pred = tool.dict_value(s_resulter, 'activated_pred')

            # calculate the supervised task constraint on the labeled data
            l_s_pred = func.split_tensor_tuple(s_pred, 0, lbs)
            l_gt = func.split_tensor_tuple(gt, 0, lbs)
            l_s_inp = func.split_tensor_tuple(s_inp, 0, lbs)

            # 's_task_loss' is a 1-dim tensor with one element per labeled sample
            s_task_loss = self.s_criterion.forward(l_s_pred, l_gt, l_s_inp)
            s_task_loss = torch.mean(s_task_loss)
            self.meters.update('s_task_loss', s_task_loss.data)

            # forward the teacher model
            with torch.no_grad():
                t_resulter, t_debugger = self.t_model.forward(t_inp)
                if 'pred' not in t_resulter:
                    self._pred_err()
                t_pred = tool.dict_value(t_resulter, 'pred')
                t_activated_pred = tool.dict_value(t_resulter,
                                                   'activated_pred')

                # calculate 't_task_loss' for recording
                l_t_pred = func.split_tensor_tuple(t_pred, 0, lbs)
                l_t_inp = func.split_tensor_tuple(t_inp, 0, lbs)
                t_task_loss = self.s_criterion.forward(l_t_pred, l_gt, l_t_inp)
                t_task_loss = torch.mean(t_task_loss)
                self.meters.update('t_task_loss', t_task_loss.data)

            # calculate the consistency constraint from the teacher model to the student model
            # detach the teacher prediction so that no gradient flows back
            # into the teacher model through the consistency loss
            t_pseudo_gt = t_pred[0].detach()

            if self.args.cons_for_labeled:
                cons_loss = self.cons_criterion(s_pred[0], t_pseudo_gt)
            elif self.args.unlabeled_batch_size > 0:
                cons_loss = self.cons_criterion(s_pred[0][lbs:, ...],
                                                t_pseudo_gt[lbs:, ...])
            else:
                cons_loss = self.zero_tensor
            cons_loss = cons_rampup_scale * self.args.cons_scale * torch.mean(
                cons_loss)
            self.meters.update('cons_loss', cons_loss.data)

            # backward and update the student model
            loss = s_task_loss + cons_loss
            loss.backward()
            self.s_optimizer.step()

            # update the teacher model by EMA
            self._update_ema_variables(self.s_model, self.t_model,
                                       self.args.ema_decay, cur_step)

            # logging
            self.meters.update('batch_time', time.time() - timer)
            if idx % self.args.log_freq == 0:
                logger.log_info(
                    'step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                    '  student-{3}\t=>\t'
                    's-task-loss: {meters[s_task_loss]:.6f}\t'
                    's-cons-loss: {meters[cons_loss]:.6f}\n'
                    '  teacher-{3}\t=>\t'
                    't-task-loss: {meters[t_task_loss]:.6f}\n'.format(
                        epoch + 1,
                        idx,
                        len(data_loader),
                        self.args.task,
                        meters=self.meters))

            # visualization
            if self.args.visualize and idx % self.args.visual_freq == 0:
                self._visualize(
                    epoch, idx, True,
                    func.split_tensor_tuple(s_inp, 0, 1, reduce_dim=True),
                    func.split_tensor_tuple(s_activated_pred,
                                            0,
                                            1,
                                            reduce_dim=True),
                    func.split_tensor_tuple(t_inp, 0, 1, reduce_dim=True),
                    func.split_tensor_tuple(t_activated_pred,
                                            0,
                                            1,
                                            reduce_dim=True),
                    func.split_tensor_tuple(gt, 0, 1, reduce_dim=True))

            # update iteration-based lrers
            if not self.args.is_epoch_lrer:
                self.s_lrer.step()

        # update epoch-based lrers
        if self.args.is_epoch_lrer:
            self.s_lrer.step()
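
Example #4 above is a mean-teacher loop: 'func.sigmoid_rampup' warms up the weight of the consistency loss, and 'self._update_ema_variables' maintains the teacher as an exponential moving average (EMA) of the student. Neither helper appears on this page; the sketches below follow the standard mean-teacher formulation and are assumptions, not the library's exact code:

import math
import torch

def sigmoid_rampup(current, rampup_length):
    # standard sigmoid ramp-up: exp(-5 * (1 - t)^2) with t clipped to [0, 1]
    if rampup_length == 0:
        return 1.0
    t = max(0.0, min(float(current) / rampup_length, 1.0))
    return math.exp(-5.0 * (1.0 - t) ** 2)

def update_ema_variables(s_model, t_model, ema_decay, global_step):
    # ramp the decay up from 0 towards 'ema_decay' during early steps,
    # then update each teacher parameter as an EMA of the student's
    alpha = min(1.0 - 1.0 / (global_step + 1), ema_decay)
    with torch.no_grad():
        for t_param, s_param in zip(t_model.parameters(),
                                    s_model.parameters()):
            t_param.mul_(alpha).add_(s_param, alpha=1.0 - alpha)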