コード例 #1
0
ファイル: ssl_cct.py プロジェクト: medical-projects/PixelSSL
    def _validate(self, data_loader, epoch):
        """Evaluate the wrapped CCT model on ``data_loader`` for one epoch.

        Each batch is passed through ``self.model.forward(inp, gt, False)``,
        the same call signature the training loop uses for labeled data.
        The mean of the returned ``'task_loss'`` is recorded in
        ``self.meters``, and ``self.task_func.metrics`` accumulates the task
        metrics under the ``'task'`` prefix. Progress is logged every
        ``args.log_freq`` batches. When ``args.visualize`` is set, the first
        sample of a batch is visualized every ``args.visual_freq`` batches.
        Finally, every meter whose key contains ``task_func.METRIC_STR`` is
        summarized in one log message.

        Args:
            data_loader: iterable of ``(inp, gt)`` batches; its length is
                shown in the progress log.
            epoch: epoch index; the progress log shows ``epoch + 1``.
        """
        self.meters.reset()
        self.model.eval()

        step_template = ('step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                         '  task-{3}\t=>\t'
                         'task-loss: {meters[task_loss]:.6f}\t')

        for batch_idx, (inp, gt) in enumerate(data_loader):
            started_at = time.time()

            inp, gt = self._batch_prehandle(inp, gt)
            # only a single ground truth is supported; checked once, on the first batch
            if batch_idx == 0 and len(gt) > 1:
                self._data_err()

            outputs, _debug_info = self.model.forward(inp, gt, False)
            _unused_pred = tool.dict_value(outputs, 'pred')  # not needed for validation
            activated_pred = tool.dict_value(outputs, 'activated_pred')

            mean_task_loss = tool.dict_value(outputs, 'task_loss', err=True).mean()
            self.meters.update('task_loss', mean_task_loss.data)

            self.task_func.metrics(activated_pred, gt, inp, self.meters, id_str='task')

            self.meters.update('batch_time', time.time() - started_at)
            if batch_idx % self.args.log_freq == 0:
                logger.log_info(step_template.format(
                    epoch + 1, batch_idx, len(data_loader), self.args.task,
                    meters=self.meters))

            # visualize the slice [0, 1) of input, prediction and ground truth
            if self.args.visualize and batch_idx % self.args.visual_freq == 0:
                first_inp, first_pred, first_gt = (
                    func.split_tensor_tuple(tensors, 0, 1, reduce_dim=True)
                    for tensors in (inp, activated_pred, gt))
                self._visualize(epoch, batch_idx, False, first_inp, first_pred, first_gt)

        # gather every metric meter into a per-prefix summary string
        metrics_info = {'task': ''}
        metric_keys = [key for key in sorted(self.meters.keys())
                       if self.task_func.METRIC_STR in key]
        for key in metric_keys:
            for prefix in metrics_info:
                if key.startswith(prefix):
                    metrics_info[prefix] += '{0}: {1:.6}\t'.format(key, self.meters[key])

        logger.log_info('Validation metrics:\n task-metrics\t=>\t{0}\n'.format(
            metrics_info['task'].replace('_', '-')))
コード例 #2
0
ファイル: proxy.py プロジェクト: lwzbuaa/PixelSSL
    def _build_ssl_algorithm(self):
        """Build the semi-supervised learning algorithm named in the script.

        For every component name in ``args.models``, this looks up by name
        and instantiates the model (``self.model``), criterion
        (``self.criterion``), lrer (``nnlrer``) and optimizer
        (``nnoptimizer``) entries. The lrer and optimizer entries are called
        with ``self.args``. The results are stored in ``model_dict``,
        ``criterion_dict``, ``lrer_dict`` and ``optimizer_dict``.

        The class ``args.ssl_algorithm`` is then taken from the module of
        the same name in ``pixelssl.ssl_algorithm``. It is constructed with
        those dicts and a task function built from ``self.func.task_func()``.
        ``logger.log_err`` is called if the algorithm's
        ``SUPPORTED_TASK_TYPES`` does not contain ``self.TASK_TYPE``.
        """
        args = self.args

        for cname in args.models:
            self.model_dict[cname] = self.model.__dict__[args.models[cname]]()
            self.criterion_dict[cname] = self.criterion.__dict__[args.criterions[cname]]()
            self.lrer_dict[cname] = nnlrer.__dict__[args.lrers[cname]](args)
            self.optimizer_dict[cname] = nnoptimizer.__dict__[args.optimizers[cname]](args)

        logger.log_info('SSL algorithm: \n  {0}\n'.format(args.ssl_algorithm))
        logger.log_info('Models: ')

        # the algorithm class lives in a module with the same name as the class
        algorithm_name = args.ssl_algorithm
        algorithm_module = pixelssl.ssl_algorithm.__dict__[algorithm_name]
        algorithm_cls = algorithm_module.__dict__[algorithm_name]
        task_func_instance = self.func.task_func()(args)
        self.ssl_algorithm = algorithm_cls(args, self.model_dict,
                                           self.optimizer_dict, self.lrer_dict,
                                           self.criterion_dict, task_func_instance)

        # check whether the SSL algorithm supports the given task
        supported_types = self.ssl_algorithm.SUPPORTED_TASK_TYPES
        if self.TASK_TYPE not in supported_types:
            logger.log_err(
                'SSL algorithm - {0} - supports task types {1}\n'
                'However, the given task - {2} - belongs to {3}\n'.format(
                    self.ssl_algorithm.NAME, supported_types, args.task,
                    self.TASK_TYPE))
コード例 #3
0
def create_model(mclass, mname, **kwargs):
    """Instantiate ``mclass(**kwargs)``, wrap it for multiple GPUs and log it.

    The new module is wrapped in ``torch.nn.DataParallel`` and moved to the
    GPU with ``.cuda()``. Its parameter summary (``model_str``) is written to
    the log under the heading ``mname``.

    Returns:
        The ``DataParallel``-wrapped, CUDA-resident model. The raw module is
        reachable through its ``.module`` attribute.
    """
    wrapped = torch.nn.DataParallel(mclass(**kwargs)).cuda()

    banner = '  ' + '=' * 76
    logger.log_info(banner + '\n  {0} parameters \n{1}'.format(mname, model_str(wrapped)))
    return wrapped
コード例 #4
0
    def _build(self, model_funcs, optimizer_funcs, lrer_funcs, criterion_funcs,
               task_func):
        """Build the models, optimizer, lrer and criterions used by SSL_S4L.

        Args:
            model_funcs: sequence whose first element is the task-model class
                handed to ``func.create_model``.
            optimizer_funcs: sequence whose first element is called with the
                wrapped model's ``param_groups`` to create the optimizer.
            lrer_funcs: sequence whose first element is called with the
                optimizer to create the lrer.
            criterion_funcs: sequence whose first element is called with
                ``self.args`` to create the task criterion.
            task_func: task helper; must provide ``ssls4l_rc_in_channels()``.

        Side effects:
            Doubles ``args.batch_size``, ``args.labeled_batch_size`` and
            ``args.unlabeled_batch_size`` in place, because S4L creates an
            extra rotated sample for each sample.
        """
        self.task_func = task_func

        # create models
        # '.module' unwraps the multi-GPU wrapper added by 'func.create_model'
        # (presumably DataParallel), since the task model is re-wrapped together
        # with the rotation classifier below
        self.task_model = func.create_model(model_funcs[0],
                                            'task_model',
                                            args=self.args).module
        self.rotation_classifier = RotationClassifer(
            self.task_func.ssls4l_rc_in_channels())

        # wrap 'self.task_model' and 'self.rotation_classifier' into a single model
        self.model = WrappedS4LModel(self.args, self.task_model,
                                     self.rotation_classifier)
        self.model = nn.DataParallel(self.model).cuda()

        # call 'patch_replication_callback' to use the `sync_batchnorm` layer
        patch_replication_callback(self.model)
        self.models = {'model': self.model}

        # create optimizers
        self.optimizer = optimizer_funcs[0](self.model.module.param_groups)
        self.optimizers = {'optimizer': self.optimizer}

        # create lrers
        self.lrer = lrer_funcs[0](self.optimizer)
        self.lrers = {'lrer': self.lrer}

        # create criterions
        self.criterion = criterion_funcs[0](self.args)
        self.rotation_criterion = nn.CrossEntropyLoss()
        self.criterions = {
            'criterion': self.criterion,
            'rotation_criterion': self.rotation_criterion
        }

        # the batch size is doubled in S4L since it creates an extra rotated sample for each sample
        self.args.batch_size *= 2
        self.args.labeled_batch_size *= 2
        self.args.unlabeled_batch_size *= 2

        # only the two batch sizes are reported; the previously passed (and
        # never printed) 'self.args.lr' argument has been dropped
        logger.log_info('In SSL_S4L algorithm, batch size are doubled: \n'
                        '  Total labeled batch size: {0}\n'
                        '  Total unlabeled batch size: {1}\n'.format(
                            self.args.labeled_batch_size,
                            self.args.unlabeled_batch_size))

        self._algorithm_warn()
コード例 #5
0
    def _train(self, data_loader, epoch):
        """Train the task model for one epoch using labeled data only.

        SSL_NULL is the supervised-only baseline. ``logger.log_err`` is
        called unless ``args.ignore_unlabeled`` is set and
        ``args.unlabeled_batch_size == 0``. For each batch, the task
        criterion is applied to the first ``args.labeled_batch_size`` samples
        and one optimizer step is taken on its mean. The lrer is stepped
        either every iteration or once per epoch, depending on
        ``args.is_epoch_lrer``.

        Args:
            data_loader: iterable of ``(inp, gt)`` batches.
            epoch: epoch index (logged as-is).
        """
        # SSL_NULL never consumes unlabeled samples, so the arguments must disable them
        labeled_only = (self.args.ignore_unlabeled
                        and self.args.unlabeled_batch_size == 0)
        if not labeled_only:
            logger.log_err(
                'SSL_NULL is a supervised-only algorithm\n'
                'Please set ignore_unlabeled = True and unlabeled_batch_size = 0\n'
            )

        self.meters.reset()
        labeled_bs = self.args.labeled_batch_size

        self.model.train()

        step_template = ('step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                         '  task-{3}\t=>\t'
                         'task-loss: {meters[task_loss]:.6f}\t')

        for batch_idx, (inp, gt) in enumerate(data_loader):
            started_at = time.time()

            # after pre-handling, 'inp' and 'gt' are tuples
            inp, gt = self._batch_prehandle(inp, gt)
            if batch_idx == 0 and len(gt) > 1:
                self._inp_warn()

            self.optimizer.zero_grad()

            # forward the task model; both prediction keys are mandatory
            outputs, _debug_info = self.model.forward(inp)
            if 'pred' not in outputs.keys() or 'activated_pred' not in outputs.keys():
                self._pred_err()

            pred = tool.dict_value(outputs, 'pred')
            activated_pred = tool.dict_value(outputs, 'activated_pred')

            # the labeled part of the batch is the first 'labeled_bs' samples
            labeled_pred, labeled_gt, labeled_inp = (
                func.split_tensor_tuple(tensors, 0, labeled_bs)
                for tensors in (pred, gt, inp))

            # the criterion yields one loss per sample; average them into a scalar
            task_loss = torch.mean(
                self.criterion.forward(labeled_pred, labeled_gt, labeled_inp))
            self.meters.update('task_loss', task_loss.data)

            # the supervised task loss is the only objective of SSL_NULL
            task_loss.backward()
            self.optimizer.step()

            # logging
            self.meters.update('batch_time', time.time() - started_at)
            if batch_idx % self.args.log_freq == 0:
                logger.log_info(step_template.format(
                    epoch, batch_idx, len(data_loader), self.args.task,
                    meters=self.meters))

            # visualize the slice [0, 1) of input, prediction and ground truth
            if self.args.visualize and batch_idx % self.args.visual_freq == 0:
                first_inp, first_pred, first_gt = (
                    func.split_tensor_tuple(tensors, 0, 1, reduce_dim=True)
                    for tensors in (inp, activated_pred, gt))
                self._visualize(epoch, batch_idx, True, first_inp, first_pred, first_gt)

            # iteration-based lrers advance every batch
            if not self.args.is_epoch_lrer:
                self.lrer.step()

        # epoch-based lrers advance once, after the whole epoch
        if self.args.is_epoch_lrer:
            self.lrer.step()
コード例 #6
0
    def _train(self, data_loader, epoch):
        """Train the wrapped CCT model for one epoch.

        For each batch:

        * The first ``args.labeled_batch_size`` samples are forwarded with
          ``self.model.forward(l_inp, l_gt, False)``. The mean of the returned
          ``'task_loss'`` is the supervised loss.
        * If ``args.unlabeled_batch_size > 0``, the samples from index
          ``labeled_batch_size`` up to ``args.batch_size`` are forwarded with
          ``forward(..., True)``. The mean of the returned ``'cons_loss'``,
          scaled by the sigmoid ramp-up factor and ``args.cons_scale``, is the
          consistency loss. Otherwise the consistency loss is 0.
        * One optimizer step is taken on the sum of both losses.

        The lrer is stepped every iteration or once per epoch, depending on
        ``args.is_epoch_lrer``.

        Args:
            data_loader: iterable of ``(inp, gt)`` batches.
            epoch: epoch index; used for the consistency ramp-up and logged
                as-is.
        """
        self.meters.reset()
        lbs = self.args.labeled_batch_size

        self.model.train()

        for idx, (inp, gt) in enumerate(data_loader):
            timer = time.time()

            inp, gt = self._batch_prehandle(inp, gt)
            if len(gt) > 1 and idx == 0:
                self._data_err()

            # TODO: support more ramp-up functions
            # calculate the ramp-up coefficient of the consistency constraint
            cur_step = len(data_loader) * epoch + idx
            total_steps = len(data_loader) * self.args.cons_rampup_epochs
            cons_rampup_scale = func.sigmoid_rampup(cur_step, total_steps)

            self.optimizer.zero_grad()

            # -----------------------------------------------------------
            # For Labeled Data
            # -----------------------------------------------------------
            l_gt = func.split_tensor_tuple(gt, 0, lbs)
            l_inp = func.split_tensor_tuple(inp, 0, lbs)

            # forward the wrapped CCT model
            resulter, debugger = self.model.forward(l_inp, l_gt, False)
            l_pred = tool.dict_value(resulter, 'pred')
            l_activated_pred = tool.dict_value(resulter, 'activated_pred')

            task_loss = tool.dict_value(resulter, 'task_loss', err=True)
            task_loss = task_loss.mean()
            self.meters.update('task_loss', task_loss.data)

            # -----------------------------------------------------------
            # For Unlabeled Data
            # -----------------------------------------------------------
            if self.args.unlabeled_batch_size > 0:
                ul_gt = func.split_tensor_tuple(gt, lbs, self.args.batch_size)
                ul_inp = func.split_tensor_tuple(inp, lbs, self.args.batch_size)

                # forward the wrapped CCT model
                resulter, debugger = self.model.forward(ul_inp, ul_gt, True)
                ul_pred = tool.dict_value(resulter, 'pred')
                ul_activated_pred = tool.dict_value(resulter, 'activated_pred')
                ul_ad_preds = tool.dict_value(resulter, 'ul_ad_preds')

                cons_loss = tool.dict_value(resulter, 'cons_loss', err=True)
                cons_loss = cons_loss.mean()
                cons_loss = cons_rampup_scale * self.args.cons_scale * cons_loss
                self.meters.update('cons_loss', cons_loss.data)

            else:
                cons_loss = 0
                self.meters.update('cons_loss', cons_loss)

            # backward and update the wrapped CCT model
            loss = task_loss + cons_loss
            loss.backward()
            self.optimizer.step()

            # logging
            self.meters.update('batch_time', time.time() - timer)
            if idx % self.args.log_freq == 0:
                logger.log_info('step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                                '  task-{3}\t=>\t'
                                'task-loss: {meters[task_loss]:.6f}\t'
                                'cons-loss: {meters[cons_loss]:.6f}\n'
                                .format(epoch, idx, len(data_loader), self.args.task, meters=self.meters))

            # visualization
            # sample 0 of the batch is a labeled sample, so its prediction is the
            # first entry of 'l_activated_pred' (there is no plain 'activated_pred'
            # in this method; referencing it raised NameError)
            if self.args.visualize and idx % self.args.visual_freq == 0:
                self._visualize(epoch, idx, True,
                                func.split_tensor_tuple(inp, 0, 1, reduce_dim=True),
                                func.split_tensor_tuple(l_activated_pred, 0, 1, reduce_dim=True),
                                func.split_tensor_tuple(gt, 0, 1, reduce_dim=True))

            # update iteration-based lrers
            if not self.args.is_epoch_lrer:
                self.lrer.step()

        # update epoch-based lrers
        if self.args.is_epoch_lrer:
            self.lrer.step()
コード例 #7
0
    def _validate(self, data_loader, epoch):
        """Evaluate both the student and the teacher model for one epoch.

        For every batch, ``self.s_model`` and ``self.t_model`` are run on the
        same input. The following are recorded in ``self.meters``:

        * ``s_task_loss`` / ``t_task_loss``: the mean of ``self.s_criterion``
          applied to each model's raw prediction. The teacher is also scored
          with the student's criterion.
        * ``cons_loss``: ``args.cons_scale`` times the sum, over the
          prediction tuple, of the mean ``self.cons_criterion`` between the
          student's activated prediction and the detached teacher one.
        * Task metrics under the ``'student'`` and ``'teacher'`` prefixes.

        Progress is logged every ``args.log_freq`` batches. If
        ``args.visualize`` is set, the first sample of a batch (with the
        student prediction) is visualized every ``args.visual_freq`` batches.
        At the end, the metric meters are summarized per prefix.

        Args:
            data_loader: iterable of ``(inp, gt)`` batches.
            epoch: epoch index; the progress log shows ``epoch + 1``.
        """
        self.meters.reset()

        self.s_model.eval()
        self.t_model.eval()

        def predict(model, model_inp):
            # forward 'model' and return its (pred, activated_pred) pair
            outputs, _debug_info = model.forward(model_inp)
            if 'pred' not in outputs.keys() or 'activated_pred' not in outputs.keys():
                self._pred_err()
            return (tool.dict_value(outputs, 'pred'),
                    tool.dict_value(outputs, 'activated_pred'))

        step_template = ('step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                         '  student-{3}\t=>\t'
                         's-task-loss: {meters[s_task_loss]:.6f}\t'
                         's-cons-loss: {meters[cons_loss]:.6f}\n'
                         '  teacher-{3}\t=>\t'
                         't-task-loss: {meters[t_task_loss]:.6f}\n')

        for batch_idx, (inp, gt) in enumerate(data_loader):
            started_at = time.time()

            # the two extra outputs of the pre-handler are not used during validation
            inp, gt, _, _ = self._batch_prehandle(inp, gt, False)
            if batch_idx == 0:
                if len(inp) > 1:
                    self._inp_warn()
                if len(gt) > 1:
                    self._gt_warn()

            # student: prediction and supervised loss
            s_pred, s_activated_pred = predict(self.s_model, inp)
            s_task_loss = torch.mean(self.s_criterion.forward(s_pred, gt, inp))
            self.meters.update('s_task_loss', s_task_loss.data)

            # teacher: prediction and supervised loss (scored with the student criterion)
            t_pred, t_activated_pred = predict(self.t_model, inp)
            t_task_loss = torch.mean(self.s_criterion.forward(t_pred, gt, inp))
            self.meters.update('t_task_loss', t_task_loss.data)

            # consistency between the student and the detached teacher predictions
            t_pseudo_gt = tuple(tap.detach() for tap in t_activated_pred)
            cons_loss = 0
            for student_ap, teacher_pg in zip(s_activated_pred, t_pseudo_gt):
                cons_loss += torch.mean(self.cons_criterion(student_ap, teacher_pg))
            cons_loss = self.args.cons_scale * torch.mean(cons_loss)
            self.meters.update('cons_loss', cons_loss.data)

            for prefix, activated in (('student', s_activated_pred),
                                      ('teacher', t_activated_pred)):
                self.task_func.metrics(activated, gt, inp, self.meters, id_str=prefix)

            self.meters.update('batch_time', time.time() - started_at)
            if batch_idx % self.args.log_freq == 0:
                logger.log_info(step_template.format(
                    epoch + 1, batch_idx, len(data_loader), self.args.task,
                    meters=self.meters))

            # visualize the slice [0, 1) of input, student prediction and ground truth
            if self.args.visualize and batch_idx % self.args.visual_freq == 0:
                first_inp, first_pred, first_gt = (
                    func.split_tensor_tuple(tensors, 0, 1, reduce_dim=True)
                    for tensors in (inp, s_activated_pred, gt))
                self._visualize(epoch, batch_idx, False, first_inp, first_pred, first_gt)

        # gather every metric meter into a per-prefix summary string
        metrics_info = {'student': '', 'teacher': ''}
        metric_keys = [key for key in sorted(self.meters.keys())
                       if self.task_func.METRIC_STR in key]
        for key in metric_keys:
            for prefix in metrics_info:
                if key.startswith(prefix):
                    metrics_info[prefix] += '{0}: {1:.6}\t'.format(key, self.meters[key])

        summary_template = ('Validation metrics:\n  student-metrics\t=>\t{0}\n'
                            '  teacher-metrics\t=>\t{1}\n')
        logger.log_info(summary_template.format(
            metrics_info['student'].replace('_', '-'),
            metrics_info['teacher'].replace('_', '-')))
コード例 #8
0
    def _train(self, data_loader, epoch):
        """Train the student with CutMix consistency for one epoch; update the teacher by EMA.

        Per batch:

        1. The first ``args.labeled_batch_size`` samples are forwarded by the
           student, and ``self.s_criterion`` gives the supervised task loss.
        2. If ``args.unlabeled_batch_size > 0``:

           * The teacher predicts the unlabeled samples under
             ``torch.no_grad()``.
           * The two halves of those predictions are blended with
             ``mix_u_mask`` (from ``_batch_prehandle``) to form detached
             pseudo targets.
           * The student predicts the mixed input ``mix_u_inp``.
           * The consistency loss is ``self.cons_criterion`` between the
             student and pseudo targets. Each term is weighted by the
             fraction of entries whose maximum over dim 1 exceeds
             ``args.cons_threshold``. The total is then scaled by the
             sigmoid ramp-up factor and ``args.cons_scale``.

           Otherwise the consistency loss is 0.
        3. The student is updated on task + consistency loss, and the teacher
           is updated by ``_update_ema_variables``.

        Visualization needs the mixed unlabeled tensors, so it only runs for
        batches where the unlabeled branch was executed.

        Args:
            data_loader: iterable of ``(inp, gt)`` batches.
            epoch: epoch index; used for the ramp-up/EMA step and logged as
                ``epoch + 1``.
        """
        self.meters.reset()
        lbs = self.args.labeled_batch_size
        ubs = self.args.unlabeled_batch_size

        self.s_model.train()
        self.t_model.train()

        for idx, (inp, gt) in enumerate(data_loader):
            timer = time.time()

            # 'inp' and 'gt' are tuples
            inp, gt, mix_u_inp, mix_u_mask = self._batch_prehandle(
                inp, gt, True)
            if len(inp) > 1 and idx == 0:
                self._inp_warn()
            if len(gt) > 1 and idx == 0:
                self._gt_warn()

            # calculate the ramp-up coefficient of the consistency constraint
            cur_step = len(data_loader) * epoch + idx
            total_steps = len(data_loader) * self.args.cons_rampup_epochs
            cons_rampup_scale = func.sigmoid_rampup(cur_step, total_steps)

            self.s_optimizer.zero_grad()

            # -------------------------------------------------
            # For Labeled Samples
            # -------------------------------------------------
            l_inp = func.split_tensor_tuple(inp, 0, lbs)
            l_gt = func.split_tensor_tuple(gt, 0, lbs)

            # forward the labeled samples by the student model
            l_s_resulter, l_s_debugger = self.s_model.forward(l_inp)
            if 'pred' not in l_s_resulter.keys(
            ) or 'activated_pred' not in l_s_resulter.keys():
                self._pred_err()
            l_s_pred = tool.dict_value(l_s_resulter, 'pred')
            l_s_activated_pred = tool.dict_value(l_s_resulter,
                                                 'activated_pred')

            # calculate the supervised task loss on the labeled samples
            task_loss = self.s_criterion.forward(l_s_pred, l_gt, l_inp)
            task_loss = torch.mean(task_loss)
            self.meters.update('task_loss', task_loss.data)

            # -------------------------------------------------
            # For Unlabeled Samples
            # -------------------------------------------------
            # the mixed tensors used by the visualization below are only bound
            # inside this branch, so remember whether it ran
            has_unlabeled = self.args.unlabeled_batch_size > 0
            if has_unlabeled:
                u_inp = func.split_tensor_tuple(inp, lbs, self.args.batch_size)

                # forward the original samples by the teacher model
                with torch.no_grad():
                    u_t_resulter, u_t_debugger = self.t_model.forward(u_inp)
                if 'pred' not in u_t_resulter.keys(
                ) or 'activated_pred' not in u_t_resulter.keys():
                    self._pred_err()
                u_t_activated_pred = tool.dict_value(u_t_resulter,
                                                     'activated_pred')

                # mix the activated pred from the teacher model as the pseudo gt
                half_ubs = ubs // 2
                u_t_activated_pred_1 = func.split_tensor_tuple(
                    u_t_activated_pred, 0, half_ubs)
                u_t_activated_pred_2 = func.split_tensor_tuple(
                    u_t_activated_pred, half_ubs, ubs)

                mix_u_t_activated_pred = []
                mix_u_t_confidence = []
                for up_1, up_2 in zip(u_t_activated_pred_1,
                                      u_t_activated_pred_2):
                    mp = mix_u_mask * up_1 + (1 - mix_u_mask) * up_2
                    mix_u_t_activated_pred.append(mp.detach())

                    # NOTE: here we just follow the official code of CutMix to calculate the confidence
                    #       but it is odd that all the samples use the same confidence (mean confidence)
                    u_t_confidence = (mp.max(dim=1)[0] >
                                      self.args.cons_threshold).float().mean()
                    mix_u_t_confidence.append(u_t_confidence.detach())

                mix_u_t_activated_pred = tuple(mix_u_t_activated_pred)

                # forward the mixed samples by the student model
                u_s_resulter, u_s_debugger = self.s_model.forward(mix_u_inp)
                if 'pred' not in u_s_resulter.keys(
                ) or 'activated_pred' not in u_s_resulter.keys():
                    self._pred_err()
                mix_u_s_activated_pred = tool.dict_value(
                    u_s_resulter, 'activated_pred')

                # calculate the consistency constraint
                cons_loss = 0
                for msap, mtap, confidence in zip(mix_u_s_activated_pred,
                                                  mix_u_t_activated_pred,
                                                  mix_u_t_confidence):
                    cons_loss += torch.mean(self.cons_criterion(
                        msap, mtap)) * confidence
                cons_loss = cons_rampup_scale * self.args.cons_scale * torch.mean(
                    cons_loss)
                self.meters.update('cons_loss', cons_loss.data)
            else:
                cons_loss = 0
                self.meters.update('cons_loss', cons_loss)

            # backward and update the student model
            loss = task_loss + cons_loss
            loss.backward()
            self.s_optimizer.step()

            # update the teacher model by EMA
            self._update_ema_variables(self.s_model, self.t_model,
                                       self.args.ema_decay, cur_step)

            # logging
            self.meters.update('batch_time', time.time() - timer)
            if idx % self.args.log_freq == 0:
                logger.log_info(
                    'step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                    '  student-{3}\t=>\t'
                    's-task-loss: {meters[task_loss]:.6f}\t'
                    's-cons-loss: {meters[cons_loss]:.6f}\n'.format(
                        epoch + 1,
                        idx,
                        len(data_loader),
                        self.args.task,
                        meters=self.meters))

            # visualization
            # skipped without unlabeled samples: 'mix_u_s_activated_pred' and
            # 'mix_u_t_activated_pred' would be unbound (UnboundLocalError)
            if has_unlabeled and self.args.visualize and idx % self.args.visual_freq == 0:
                self._visualize(
                    epoch, idx, True,
                    func.split_tensor_tuple(l_inp, 0, 1, reduce_dim=True),
                    func.split_tensor_tuple(l_s_activated_pred,
                                            0,
                                            1,
                                            reduce_dim=True),
                    func.split_tensor_tuple(l_gt, 0, 1, reduce_dim=True),
                    func.split_tensor_tuple(mix_u_inp, 0, 1, reduce_dim=True),
                    func.split_tensor_tuple(mix_u_s_activated_pred,
                                            0,
                                            1,
                                            reduce_dim=True),
                    func.split_tensor_tuple(mix_u_t_activated_pred,
                                            0,
                                            1,
                                            reduce_dim=True), mix_u_mask[0])

            # update iteration-based lrers
            if not self.args.is_epoch_lrer:
                self.s_lrer.step()

        # update epoch-based lrers
        if self.args.is_epoch_lrer:
            self.s_lrer.step()
コード例 #9
0
ファイル: proxy.py プロジェクト: lwzbuaa/PixelSSL
    def _create_dataloader(self):
        """ Create data loaders for experiment.
        """

        # ---------------------------------------------------------------------
        # create dataloder for training
        # ---------------------------------------------------------------------

        # ignore_unlabeled == False & unlabeled_batch_size != 0
        #   means that both labeled and unlabeled data are used
        with_unlabeled_data = not self.args.ignore_unlabeled and self.args.unlabeled_batch_size != 0
        # ignore_unlabeled == True & unlabeled_batch_size == 0
        #   means that only the labeled data is used
        without_unlabeled_data = self.args.ignore_unlabeled and self.args.unlabeled_batch_size == 0

        labeled_train_samples, unlabeled_train_samples = 0, 0
        if not self.args.validation:
            # ignore_unlabeled == True & unlabeled_batch_size != 0 -> error
            if self.args.ignore_unlabeled and self.args.unlabeled_batch_size != 0:
                logger.log_err(
                    'Arguments conflict => ignore_unlabeled == True requires unlabeled_batch_size == 0\n'
                )
            # ignore_unlabeled == False & unlabeled_batch_size == 0 -> error
            if not self.args.ignore_unlabeled and self.args.unlabeled_batch_size == 0:
                logger.log_err(
                    'Arguments conflict => ignore_unlabeled == False requires unlabeled_batch_size != 0\n'
                )

            # calculate the number of trainsets
            trainset_num = 0
            for key, value in self.args.trainset.items():
                trainset_num += len(value)

            # calculate the number of unlabeledsets
            unlabeledset_num = 0
            for key, value in self.args.unlabeledset.items():
                unlabeledset_num += len(value)

            # if only one labeled training set and without any unlabeled set
            if trainset_num == 1 and unlabeledset_num == 0:
                trainset = self._load_dataset(
                    list(self.args.trainset.keys())[0],
                    list(self.args.trainset.values())[0][0])
                labeled_train_samples = len(trainset.idxs)

                # if the 'sublabeled_path' is given
                sublabeled_prefix = None
                if self.args.sublabeled_path is not None and self.args.sublabeled_path != '':
                    if not os.path.exists(self.args.sublabeled_path):
                        logger.log_err(
                            'Cannot find labeled file: {0}\n'.format(
                                self.args.sublabeled_path))
                    else:
                        with open(self.args.sublabeled_path) as f:
                            sublabeled_prefix = [
                                line.strip() for line in f.read().splitlines()
                            ]
                        sublabeled_prefix = None if len(
                            sublabeled_prefix) == 0 else sublabeled_prefix

                if sublabeled_prefix is not None:
                    # wrap the trainset by 'SplitUnlabeledWrapper'
                    trainset = nndata.SplitUnlabeledWrapper(
                        trainset,
                        sublabeled_prefix,
                        ignore_unlabeled=self.args.ignore_unlabeled)
                    labeled_train_samples = len(trainset.labeled_idxs)
                    unlabeled_train_samples = len(trainset.unlabeled_idxs)

                # if 'sublabeled_prefix' is None but you want to use the unlabeled data for training
                elif with_unlabeled_data:
                    logger.log_err(
                        'Try to use the unlabeled samples without any SSL dataset wrapper\n'
                    )

            # if more than one labeled training sets are given or the unlabeled training sets are given
            elif trainset_num > 1 or unlabeledset_num > 0:
                # 'arg.sublabeled_path' is disabled
                if self.args.sublabeled_path is not None and self.args.sublabeled_path != '':
                    logger.log_err(
                        'Multiple training datasets are given. \n'
                        'Inter-split unlabeled set is not allowed.\n'
                        'Please remove the argument \'sublabeled_path\' in the script\n'
                    )

                # load all training sets
                labeled_sets = []
                for set_name, set_dirs in self.args.trainset.items():
                    for set_dir in set_dirs:
                        labeled_sets.append(
                            self._load_dataset(set_name, set_dir))

                # load all extra unlabeled sets
                unlabeled_sets = []
                # if any extra unlabeled set is given
                if unlabeledset_num > 0:
                    for set_name, set_dirs in self.args.unlabeledset.items():
                        for set_dir in set_dirs:
                            unlabeled_sets.append(
                                self._load_dataset(set_name, set_dir))

                # if unalbeledset_num == 0 but you want to use the unlabeled data for training
                elif with_unlabeled_data:
                    logger.log_err(
                        'Try to use the unlabeled samples without any SSL dataset wrapper\n'
                        'Please add the argument \'unlabeledset\' in the script\n'
                    )

                # wrap both 'labeled_set' and 'unlabeled_set' by 'JointDatasetsWrapper'
                trainset = nndata.JointDatasetsWrapper(
                    labeled_sets,
                    unlabeled_sets,
                    ignore_unlabeled=self.args.ignore_unlabeled)
                labeled_train_samples = len(trainset.labeled_idxs)
                unlabeled_train_samples = len(trainset.unlabeled_idxs)

            # if use labeled data only
            if without_unlabeled_data:
                self.train_loader = torch.utils.data.DataLoader(
                    trainset,
                    batch_size=self.args.batch_size,
                    shuffle=True,
                    num_workers=self.args.num_workers,
                    pin_memory=True,
                    drop_last=True)
            # if use both labeled and unlabeled data
            elif with_unlabeled_data:
                train_sampler = nndata.TwoStreamBatchSampler(
                    trainset.labeled_idxs, trainset.unlabeled_idxs,
                    self.args.labeled_batch_size,
                    self.args.unlabeled_batch_size)
                self.train_loader = torch.utils.data.DataLoader(
                    trainset,
                    batch_sampler=train_sampler,
                    num_workers=self.args.num_workers,
                    pin_memory=True)

        # ---------------------------------------------------------------------
        # create dataloader for validation
        # ---------------------------------------------------------------------

        # calculate the number of valsets
        valset_num = 0
        for key, value in self.args.valset.items():
            valset_num += len(value)

        # if only one validation set is given
        if valset_num == 1:
            valset = self._load_dataset(list(self.args.valset.keys())[0],
                                        list(self.args.valset.values())[0][0],
                                        is_train=False)
            val_samples = len(valset.idxs)

        # if more than one validation sets are given
        elif valset_num > 1:
            valsets = []
            for set_name, set_dirs in self.args.valset.items():
                for set_dir in set_dirs:
                    valsets.append(
                        self._load_dataset(set_name, set_dir, is_train=False))
            valset = nndata.JointDatasetsWrapper(valsets, [],
                                                 ignore_unlabeled=True)
            val_samples = len(valset.labeled_idxs)

        # NOTE: batch size is set to 1 during the validation
        self.val_loader = torch.utils.data.DataLoader(
            valset,
            batch_size=1,
            shuffle=False,
            num_workers=self.args.num_workers,
            pin_memory=True)

        # check the data loaders
        if self.train_loader is None and not self.args.validation:
            logger.log_err(
                'Train data loader is required if validate mode is closed\n')
        elif self.val_loader is None and self.args.validation:
            logger.log_err(
                'Validate data loader is required if validate mode is opened\n'
            )
        elif self.val_loader is None:
            logger.log_warn(
                'No validate data loader, there are no validation during the training\n'
            )

        # set 'iters_per_epoch', which is required by ITER_LRERS
        self.args.iters_per_epoch = len(
            self.train_loader) if self.train_loader is not None else -1

        logger.log_info(
            'Dataset:\n'
            '  Trainset\t=>\tlabeled samples = {0}\t\tunlabeled samples = {1}\n'
            '  Valset\t=>\tsamples = {2}\n'.format(labeled_train_samples,
                                                   unlabeled_train_samples,
                                                   val_samples))
コード例 #10
0
ファイル: proxy.py プロジェクト: lwzbuaa/PixelSSL
    def _run(self):
        """Run the experiment from start to finish.

        If ``args.resume`` is a non-empty value, the SSL algorithm first loads
        its checkpoint and the returned epoch becomes the starting epoch.

        In validation mode (``args.validation``) the model is evaluated once
        under ``torch.no_grad()`` and the method returns. Otherwise the
        algorithm is trained for epochs ``start_epoch .. args.epochs - 1``.
        After each epoch it:
          - validates when ``epoch % args.val_freq == 0`` and a validation
            loader exists;
          - saves a checkpoint tagged ``epoch + 1`` every
            ``args.checkpoint_freq`` epochs.

        Override this method to implement a different pipeline.
        """
        separator = '=' * 78
        resume = self.args.resume

        start_epoch = 0
        if resume is not None and resume != '':
            logger.log_info('Load checkpoint from: {0}'.format(resume))
            start_epoch = self.ssl_algorithm.load_checkpoint()

        # validation-only mode: a single evaluation pass, then stop
        if self.args.validation:
            if self.val_loader is None:
                logger.log_err('No data loader for validation.\n'
                               'Please set right \'valset\' in the script.\n')

            logger.log_info([separator, '\nStart to validate model\n', separator])
            with torch.no_grad():
                self.ssl_algorithm.validate(self.val_loader, start_epoch - 1)
            return

        for epoch in range(start_epoch, self.args.epochs):
            epoch_start = time.time()

            train_banner = '\nStart to train epoch-{0}\n'.format(epoch)
            logger.log_info([separator, train_banner, separator])
            self.ssl_algorithm.train(self.train_loader, epoch)

            # 'val_freq' is evaluated before the loader check, as in the
            # original short-circuit order
            should_validate = (epoch % self.args.val_freq == 0
                               and self.val_loader is not None)
            if should_validate:
                val_banner = '\nStart to validate epoch-{0}\n'.format(epoch)
                logger.log_info([separator, val_banner, separator])
                with torch.no_grad():
                    self.ssl_algorithm.validate(self.val_loader, epoch)

            finished_epochs = epoch + 1
            if finished_epochs % self.args.checkpoint_freq == 0:
                self.ssl_algorithm.save_checkpoint(finished_epochs)
                logger.log_info("Save checkpoint for epoch {0}".format(epoch))

            elapsed = time.time() - epoch_start
            logger.log_info('Finish epoch in {0} seconds\n'.format(elapsed))

        logger.log_info('Finish experiment {0}\n'.format(self.args.exp_id))
コード例 #11
0
ファイル: ssl_adv.py プロジェクト: medical-projects/PixelSSL
    def _validate(self, data_loader, epoch):
        """Evaluate the task model and the FC discriminator on ``data_loader``.

        For every batch this records:
          - the task loss;
          - the discriminator loss on the task prediction (fake target) and
            on the converted ground truth (real target), both scaled by
            ``args.discriminator_scale``;
          - the task metrics.

        Progress is logged every ``args.log_freq`` steps and samples are
        optionally visualized. At the end, the meters whose name contains
        ``task_func.METRIC_STR`` and starts with ``'task'`` are logged.

        Args:
            data_loader: iterable of ``(inp, gt)`` batches.
            epoch: 0-based epoch index (logged as ``epoch + 1``).
        """
        self.meters.reset()

        self.model.eval()
        self.d_model.eval()

        def discriminate(d_input, task_gt, is_real):
            # run the discriminator and return (raw confidence map, scaled loss)
            d_out, _ = self.d_model.forward(d_input)
            raw_map = tool.dict_value(d_out, 'confidence')
            cmap, cmap_gt = self.task_func.ssladv_preprocess_fcd_criterion(
                raw_map, task_gt, is_real)
            d_loss = self.d_criterion.forward(cmap, cmap_gt)
            return raw_map, self.args.discriminator_scale * torch.mean(d_loss)

        for step, (inp, gt) in enumerate(data_loader):
            step_start = time.time()

            inp, gt = self._batch_prehandle(inp, gt)
            if len(gt) > 1 and step == 0:
                self._inp_warn()

            resulter, _ = self.model.forward(inp)
            if 'pred' not in resulter.keys() or 'activated_pred' not in resulter.keys():
                self._pred_err()

            pred = tool.dict_value(resulter, 'pred')
            activated_pred = tool.dict_value(resulter, 'activated_pred')

            task_loss = torch.mean(self.criterion.forward(pred, gt, inp))
            self.meters.update('task_loss', task_loss.data)

            # discriminator on the prediction, expected to be judged as fake
            fake_map, fake_d_loss = discriminate(activated_pred[0], gt[0], False)
            self.meters.update('fake_d_loss', fake_d_loss.data)

            # discriminator on the ground truth, expected to be judged as real
            real_input = self.task_func.ssladv_convert_task_gt_to_fcd_input(gt[0])
            _, real_d_loss = discriminate(real_input, gt[0], True)
            self.meters.update('real_d_loss', real_d_loss.data)

            self.task_func.metrics(activated_pred,
                                   gt,
                                   inp,
                                   self.meters,
                                   id_str='task')

            self.meters.update('batch_time', time.time() - step_start)
            if step % self.args.log_freq == 0:
                logger.log_info(
                    'step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                    '  task-{3}\t=>\t'
                    'task-loss: {meters[task_loss]:.6f}\t'
                    '  fc-discriminator\t=>\t'
                    'fake-d-loss: {meters[fake_d_loss]:.6f}\t'
                    'real-d-loss: {meters[real_d_loss]:.6f}\n'.format(
                        epoch + 1,
                        step,
                        len(data_loader),
                        self.args.task,
                        meters=self.meters))

            if self.args.visualize and step % self.args.visual_freq == 0:
                first_inp = func.split_tensor_tuple(inp, 0, 1, reduce_dim=True)
                first_pred = func.split_tensor_tuple(activated_pred, 0, 1,
                                                     reduce_dim=True)
                first_gt = func.split_tensor_tuple(gt, 0, 1, reduce_dim=True)
                self._visualize(epoch, step, False, first_inp, first_pred,
                                first_gt, torch.sigmoid(fake_map[0]))

        # collect the task metrics (meters named '<task...METRIC_STR...>')
        task_metrics = ''
        for key in sorted(self.meters.keys()):
            if self.task_func.METRIC_STR in key and key.startswith('task'):
                task_metrics += '{0}: {1:.6}\t'.format(key, self.meters[key])

        logger.log_info('Validation metrics:\n task-metrics\t=>\t{0}\n'.format(
            task_metrics.replace('_', '-')))
コード例 #12
0
ファイル: ssl_adv.py プロジェクト: medical-projects/PixelSSL
    def _train(self, data_loader, epoch):
        """Train the task model and the FC discriminator for one epoch.

        Each iteration has two steps.

        Step 1 updates the task model with:
          - the supervised task loss on the first ``labeled_batch_size``
            (labeled) samples;
          - an adversarial loss that asks the discriminator to rate the
            predictions as real. It is applied to the labeled samples only if
            ``args.adv_for_labeled``, and to the remaining (unlabeled) samples
            when ``args.unlabeled_batch_size > 0``.

        Step 2 updates the discriminator to rate the detached predictions as
        fake and the converted labeled ground truth as real. The fake
        predictions cover the whole batch only if
        ``args.unlabeled_for_discriminator``.

        Args:
            data_loader: iterable of ``(inp, gt)`` batches whose first
                ``labeled_batch_size`` samples are the labeled ones.
            epoch: 0-based epoch index (logged as ``epoch + 1``).
        """
        self.meters.reset()
        lbs = self.args.labeled_batch_size

        self.model.train()
        self.d_model.train()

        # both 'inp' and 'gt' are tuples
        for idx, (inp, gt) in enumerate(data_loader):
            timer = time.time()

            inp, gt = self._batch_prehandle(inp, gt)
            if len(gt) > 1 and idx == 0:
                self._inp_warn()

            # -----------------------------------------------------------------------------
            # step-1: train the task model
            # -----------------------------------------------------------------------------
            self.optimizer.zero_grad()

            # forward the task model
            resulter, debugger = self.model.forward(inp)
            if 'pred' not in resulter.keys() or 'activated_pred' not in resulter.keys():
                self._pred_err()

            pred = tool.dict_value(resulter, 'pred')
            activated_pred = tool.dict_value(resulter, 'activated_pred')

            # forward the FC discriminator on the whole batch
            # 'confidence_map' is a tensor
            d_resulter, d_debugger = self.d_model.forward(activated_pred[0])
            confidence_map = tool.dict_value(d_resulter, 'confidence')

            # calculate the supervised task constraint on the labeled data
            l_pred = func.split_tensor_tuple(pred, 0, lbs)
            l_gt = func.split_tensor_tuple(gt, 0, lbs)
            l_inp = func.split_tensor_tuple(inp, 0, lbs)

            # presumably one loss element per labeled sample (n == lbs);
            # reduced to a scalar by the mean below
            task_loss = self.criterion.forward(l_pred, l_gt, l_inp)
            task_loss = torch.mean(task_loss)
            self.meters.update('task_loss', task_loss.data)

            # calculate the adversarial constraint for the labeled data: the
            # target flag is True ('real'), the same flag step-2 uses for the
            # ground truth, so the task model learns to fool the discriminator
            if self.args.adv_for_labeled:
                l_confidence_map = confidence_map[:lbs, ...]

                # preprocess prediction and ground truth for the adversarial constraint
                l_adv_confidence_map, l_adv_confidence_gt = \
                    self.task_func.ssladv_preprocess_fcd_criterion(l_confidence_map, l_gt[0], True)
                l_adv_loss = self.d_criterion(l_adv_confidence_map,
                                              l_adv_confidence_gt)
                labeled_adv_loss = self.args.labeled_adv_scale * torch.mean(
                    l_adv_loss)
                self.meters.update('labeled_adv_loss', labeled_adv_loss.data)
            else:
                labeled_adv_loss = 0
                self.meters.update('labeled_adv_loss', labeled_adv_loss)

            # calculate the adversarial constraint for the unlabeled data
            # (no task ground truth is available, hence 'None')
            if self.args.unlabeled_batch_size > 0:
                u_confidence_map = confidence_map[lbs:self.args.batch_size,
                                                  ...]

                # preprocess prediction and ground truth for the adversarial constraint
                u_adv_confidence_map, u_adv_confidence_gt = \
                    self.task_func.ssladv_preprocess_fcd_criterion(u_confidence_map, None, True)
                u_adv_loss = self.d_criterion(u_adv_confidence_map,
                                              u_adv_confidence_gt)
                unlabeled_adv_loss = self.args.unlabeled_adv_scale * torch.mean(
                    u_adv_loss)
                self.meters.update('unlabeled_adv_loss',
                                   unlabeled_adv_loss.data)
            else:
                unlabeled_adv_loss = 0
                self.meters.update('unlabeled_adv_loss', unlabeled_adv_loss)

            adv_loss = labeled_adv_loss + unlabeled_adv_loss

            # backward and update the task model
            loss = task_loss + adv_loss
            loss.backward()
            self.optimizer.step()

            # -----------------------------------------------------------------------------
            # step-2: train the FC discriminator
            # -----------------------------------------------------------------------------
            # also clears the discriminator gradients accumulated by the
            # step-1 backward pass
            self.d_optimizer.zero_grad()

            # forward the task prediction (fake), detached so that only the
            # discriminator receives gradients
            if self.args.unlabeled_for_discriminator:
                fake_pred = activated_pred[0].detach()
            else:
                fake_pred = activated_pred[0][:lbs, ...].detach()

            d_resulter, d_debugger = self.d_model.forward(fake_pred)
            fake_confidence_map = tool.dict_value(d_resulter, 'confidence')

            l_fake_confidence_map = fake_confidence_map[:lbs, ...]
            l_fake_confidence_map, l_fake_confidence_gt = \
                self.task_func.ssladv_preprocess_fcd_criterion(l_fake_confidence_map, l_gt[0], False)

            if self.args.unlabeled_for_discriminator and self.args.unlabeled_batch_size != 0:
                u_fake_confidence_map = fake_confidence_map[
                    lbs:self.args.batch_size, ...]
                u_fake_confidence_map, u_fake_confidence_gt = \
                    self.task_func.ssladv_preprocess_fcd_criterion(u_fake_confidence_map, None, False)

                fake_confidence_map = torch.cat(
                    (l_fake_confidence_map, u_fake_confidence_map), dim=0)
                fake_confidence_gt = torch.cat(
                    (l_fake_confidence_gt, u_fake_confidence_gt), dim=0)

            else:
                fake_confidence_map, fake_confidence_gt = l_fake_confidence_map, l_fake_confidence_gt

            fake_d_loss = self.d_criterion.forward(fake_confidence_map,
                                                   fake_confidence_gt)
            fake_d_loss = self.args.discriminator_scale * torch.mean(
                fake_d_loss)
            self.meters.update('fake_d_loss', fake_d_loss.data)

            # forward the ground truth (real)
            # convert the format of ground truth
            real_gt = self.task_func.ssladv_convert_task_gt_to_fcd_input(
                l_gt[0])
            d_resulter, d_debugger = self.d_model.forward(real_gt)
            real_confidence_map = tool.dict_value(d_resulter, 'confidence')

            real_confidence_map, real_confidence_gt = \
                self.task_func.ssladv_preprocess_fcd_criterion(real_confidence_map, l_gt[0], True)

            real_d_loss = self.d_criterion(real_confidence_map,
                                           real_confidence_gt)
            real_d_loss = self.args.discriminator_scale * torch.mean(
                real_d_loss)
            self.meters.update('real_d_loss', real_d_loss.data)

            # backward and update the FC discriminator
            d_loss = (fake_d_loss + real_d_loss) / 2
            d_loss.backward()
            self.d_optimizer.step()

            # logging
            self.meters.update('batch_time', time.time() - timer)
            if idx % self.args.log_freq == 0:
                logger.log_info(
                    'step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                    '  task-{3}\t=>\t'
                    'task-loss: {meters[task_loss]:.6f}\t'
                    'labeled-adv-loss: {meters[labeled_adv_loss]:.6f}\t'
                    'unlabeled-adv-loss: {meters[unlabeled_adv_loss]:.6f}\n'
                    '  fc-discriminator\t=>\t'
                    'fake-d-loss: {meters[fake_d_loss]:.6f}\t'
                    'real-d-loss: {meters[real_d_loss]:.6f}\n'.format(
                        epoch + 1,
                        idx,
                        len(data_loader),
                        self.args.task,
                        meters=self.meters))

            # visualization
            if self.args.visualize and idx % self.args.visual_freq == 0:
                u_inp_sample, u_pred_sample, u_cmap_sample = None, None, None
                if self.args.unlabeled_batch_size > 0:
                    u_inp_sample = func.split_tensor_tuple(inp,
                                                           lbs,
                                                           lbs + 1,
                                                           reduce_dim=True)
                    u_pred_sample = func.split_tensor_tuple(activated_pred,
                                                            lbs,
                                                            lbs + 1,
                                                            reduce_dim=True)
                    # Use the step-1 map, which covers the whole batch (the
                    # labeled sample below uses it too). 'fake_confidence_map'
                    # now holds the preprocessed step-2 map. When
                    # 'unlabeled_for_discriminator' is off, that map only has
                    # the labeled rows, so indexing it at 'lbs' would fail.
                    u_cmap_sample = torch.sigmoid(confidence_map[lbs])

                self._visualize(
                    epoch, idx, True,
                    func.split_tensor_tuple(inp, 0, 1, reduce_dim=True),
                    func.split_tensor_tuple(activated_pred,
                                            0,
                                            1,
                                            reduce_dim=True),
                    func.split_tensor_tuple(gt, 0, 1, reduce_dim=True),
                    torch.sigmoid(confidence_map[0]), u_inp_sample,
                    u_pred_sample, u_cmap_sample)

            # the FC discriminator uses polynomiallr [ITER_LRERS]
            self.d_lrer.step()
            # update iteration-based lrers
            if not self.args.is_epoch_lrer:
                self.lrer.step()

        # update epoch-based lrers
        if self.args.is_epoch_lrer:
            self.lrer.step()
コード例 #13
0
    def _train(self, data_loader, epoch):
        """Train the task model with an auxiliary rotation classifier for one epoch.

        ``self._batch_prehandle(inp, gt, True)`` generates the rotated samples
        and appends the rotation-angle labels as the last element of ``gt``.
        The index arithmetic below assumes the first ``batch_size / 2``
        samples are un-rotated and the second half are their rotated copies,
        each half starting with ``labeled_batch_size / 2`` labeled samples.

        The loss is the sum of:
          - the task loss on the un-rotated labeled samples;
          - the task loss on the rotated labeled samples, scaled by
            ``args.rotated_sup_scale``;
          - the rotation classification loss on the whole batch, scaled by
            ``args.rotation_scale``.

        Args:
            data_loader: iterable of ``(inp, gt)`` batches.
            epoch: 0-based epoch index (logged as ``epoch + 1``, like the
                other SSL algorithms).
        """
        self.meters.reset()
        original_lbs = int(self.args.labeled_batch_size / 2)
        original_bs = int(self.args.batch_size / 2)

        self.model.train()

        for idx, (inp, gt) in enumerate(data_loader):
            timer = time.time()

            # the rotated samples are generated in the 'self._batch_prehandle' function
            # both 'inp' and 'gt' are tuples
            # the last element in the tuple 'gt' is the ground truth of the rotation angle
            inp, gt = self._batch_prehandle(inp, gt, True)
            if len(gt) - 1 > 1 and idx == 0:
                self._inp_warn()

            self.optimizer.zero_grad()

            # forward the model
            resulter, debugger = self.model.forward(inp)
            pred = tool.dict_value(resulter, 'pred')
            activated_pred = tool.dict_value(resulter, 'activated_pred')
            pred_rotation = tool.dict_value(resulter, 'rotation')

            # calculate the supervised task constraint on the un-rotated labeled data
            l_pred = func.split_tensor_tuple(pred, 0, original_lbs)
            l_gt = func.split_tensor_tuple(gt, 0, original_lbs)
            l_inp = func.split_tensor_tuple(inp, 0, original_lbs)

            # 'l_gt[:-1]' drops the rotation-angle labels
            unrotated_task_loss = self.criterion.forward(
                l_pred, l_gt[:-1], l_inp)
            unrotated_task_loss = torch.mean(unrotated_task_loss)
            self.meters.update('unrotated_task_loss', unrotated_task_loss.data)

            # calculate the supervised task constraint on the rotated labeled data
            l_rotated_pred = func.split_tensor_tuple(
                pred, original_bs, original_bs + original_lbs)
            l_rotated_gt = func.split_tensor_tuple(gt, original_bs,
                                                   original_bs + original_lbs)
            l_rotated_inp = func.split_tensor_tuple(inp, original_bs,
                                                    original_bs + original_lbs)

            rotated_task_loss = self.criterion.forward(l_rotated_pred,
                                                       l_rotated_gt[:-1],
                                                       l_rotated_inp)
            rotated_task_loss = self.args.rotated_sup_scale * torch.mean(
                rotated_task_loss)
            self.meters.update('rotated_task_loss', rotated_task_loss.data)

            task_loss = unrotated_task_loss + rotated_task_loss

            # calculate the self-supervised rotation constraint
            rotation_loss = self.rotation_criterion.forward(
                pred_rotation, gt[-1])
            rotation_loss = self.args.rotation_scale * torch.mean(
                rotation_loss)
            self.meters.update('rotation_loss', rotation_loss.data)

            # backward and update the model
            loss = task_loss + rotation_loss
            loss.backward()
            self.optimizer.step()

            # calculate the accuracy (in %) of the rotation classifier
            _, angle_idx = pred_rotation.topk(1, 1, True, True)
            angle_idx = angle_idx.t()  # shape (1, number of predictions)
            rotation_acc = angle_idx.eq(gt[-1].view(1,
                                                    -1).expand_as(angle_idx))
            # divide by the number of predictions actually made rather than
            # 'args.batch_size', so a batch of a different size is not skewed
            num_predictions = angle_idx.size(1)
            rotation_acc = rotation_acc.view(-1).float().sum(
                0, keepdim=True).mul_(100.0 / num_predictions)
            self.meters.update('rotation_acc', rotation_acc.data[0])

            # logging ('epoch + 1' matches the step logs of the other algorithms)
            self.meters.update('batch_time', time.time() - timer)
            if idx % self.args.log_freq == 0:
                logger.log_info(
                    'step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                    '  task-{3}\t=>\t'
                    'unrotated-task-loss: {meters[unrotated_task_loss]:.6f}\t'
                    'rotated-task-loss: {meters[rotated_task_loss]:.6f}\n'
                    '  rotation-{3}\t=>\t'
                    'rotation-loss: {meters[rotation_loss]:.6f}\t'
                    'rotation-acc: {meters[rotation_acc]:.6f}\n'.format(
                        epoch + 1,
                        idx,
                        len(data_loader),
                        self.args.task,
                        meters=self.meters))

            # visualization
            if self.args.visualize and idx % self.args.visual_freq == 0:
                self._visualize(
                    epoch, idx, True,
                    func.split_tensor_tuple(inp, 0, 1, reduce_dim=True),
                    func.split_tensor_tuple(activated_pred,
                                            0,
                                            1,
                                            reduce_dim=True),
                    func.split_tensor_tuple(gt[:-1], 0, 1, reduce_dim=True))

            # update iteration-based lrers
            if not self.args.is_epoch_lrer:
                self.lrer.step()

        # update epoch-based lrers
        if self.args.is_epoch_lrer:
            self.lrer.step()
コード例 #14
0
    def _train(self, data_loader, epoch):
        """Train the student model for one epoch and update the teacher by EMA.

        Per iteration:
          - The student is trained on the supervised task loss over the first
            ``labeled_batch_size`` samples.
          - A consistency loss pulls the student prediction towards the
            detached teacher prediction, scaled by ``args.cons_scale`` and by
            the ramp-up coefficient from ``func.sigmoid_rampup``. It covers
            the whole batch if ``args.cons_for_labeled``, otherwise only the
            unlabeled samples.
          - The teacher is then updated from the student via
            ``_update_ema_variables`` with ``args.ema_decay``.
          - The teacher's task loss is computed without gradients, for
            logging only.

        Args:
            data_loader: iterable of ``(inp, gt)`` batches whose first
                ``labeled_batch_size`` samples are the labeled ones.
            epoch: 0-based epoch index (logged as ``epoch + 1``).
        """
        self.meters.reset()
        lbs = self.args.labeled_batch_size

        self.s_model.train()
        self.t_model.train()

        for idx, (inp, gt) in enumerate(data_loader):
            timer = time.time()

            # 's_inp', 't_inp' and 'gt' are tuples
            s_inp, t_inp, gt = self._batch_prehandle(inp, gt, True)
            if len(gt) > 1 and idx == 0:
                self._inp_warn()

            # calculate the ramp-up coefficient of the consistency constraint
            cur_step = len(data_loader) * epoch + idx
            total_steps = len(data_loader) * self.args.cons_rampup_epochs
            cons_rampup_scale = func.sigmoid_rampup(cur_step, total_steps)

            self.s_optimizer.zero_grad()

            # forward the student model
            s_resulter, s_debugger = self.s_model.forward(s_inp)
            if 'pred' not in s_resulter.keys() or 'activated_pred' not in s_resulter.keys():
                self._pred_err()
            s_pred = tool.dict_value(s_resulter, 'pred')
            s_activated_pred = tool.dict_value(s_resulter, 'activated_pred')

            # calculate the supervised task constraint on the labeled data
            l_s_pred = func.split_tensor_tuple(s_pred, 0, lbs)
            l_gt = func.split_tensor_tuple(gt, 0, lbs)
            l_s_inp = func.split_tensor_tuple(s_inp, 0, lbs)

            # presumably one loss element per labeled sample; reduced by the mean below
            s_task_loss = self.s_criterion.forward(l_s_pred, l_gt, l_s_inp)
            s_task_loss = torch.mean(s_task_loss)
            self.meters.update('s_task_loss', s_task_loss.data)

            # forward the teacher model (no gradients: it is only updated by EMA)
            with torch.no_grad():
                t_resulter, t_debugger = self.t_model.forward(t_inp)
                # 'activated_pred' is read below as well, so check it like
                # the student's outputs are checked
                if 'pred' not in t_resulter.keys() or 'activated_pred' not in t_resulter.keys():
                    self._pred_err()
                t_pred = tool.dict_value(t_resulter, 'pred')
                t_activated_pred = tool.dict_value(t_resulter,
                                                   'activated_pred')

                # calculate 't_task_loss' for recording
                l_t_pred = func.split_tensor_tuple(t_pred, 0, lbs)
                l_t_inp = func.split_tensor_tuple(t_inp, 0, lbs)
                t_task_loss = self.s_criterion.forward(l_t_pred, l_gt, l_t_inp)
                t_task_loss = torch.mean(t_task_loss)
                self.meters.update('t_task_loss', t_task_loss.data)

            # consistency constraint from the teacher model to the student model;
            # a detached tensor replaces the deprecated 'Variable(..., requires_grad=False)'
            t_pseudo_gt = t_pred[0].detach()

            if self.args.cons_for_labeled:
                cons_loss = self.cons_criterion(s_pred[0], t_pseudo_gt)
            elif self.args.unlabeled_batch_size > 0:
                cons_loss = self.cons_criterion(s_pred[0][lbs:, ...],
                                                t_pseudo_gt[lbs:, ...])
            else:
                cons_loss = self.zero_tensor
            cons_loss = cons_rampup_scale * self.args.cons_scale * torch.mean(
                cons_loss)
            self.meters.update('cons_loss', cons_loss.data)

            # backward and update the student model
            loss = s_task_loss + cons_loss
            loss.backward()
            self.s_optimizer.step()

            # update the teacher model by EMA
            self._update_ema_variables(self.s_model, self.t_model,
                                       self.args.ema_decay, cur_step)

            # logging
            self.meters.update('batch_time', time.time() - timer)
            if idx % self.args.log_freq == 0:
                logger.log_info(
                    'step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                    '  student-{3}\t=>\t'
                    's-task-loss: {meters[s_task_loss]:.6f}\t'
                    's-cons-loss: {meters[cons_loss]:.6f}\n'
                    '  teacher-{3}\t=>\t'
                    't-task-loss: {meters[t_task_loss]:.6f}\n'.format(
                        epoch + 1,
                        idx,
                        len(data_loader),
                        self.args.task,
                        meters=self.meters))

            # visualization
            if self.args.visualize and idx % self.args.visual_freq == 0:
                self._visualize(
                    epoch, idx, True,
                    func.split_tensor_tuple(s_inp, 0, 1, reduce_dim=True),
                    func.split_tensor_tuple(s_activated_pred,
                                            0,
                                            1,
                                            reduce_dim=True),
                    func.split_tensor_tuple(t_inp, 0, 1, reduce_dim=True),
                    func.split_tensor_tuple(t_activated_pred,
                                            0,
                                            1,
                                            reduce_dim=True),
                    func.split_tensor_tuple(gt, 0, 1, reduce_dim=True))

            # update iteration-based lrers
            if not self.args.is_epoch_lrer:
                self.s_lrer.step()

        # update epoch-based lrers
        if self.args.is_epoch_lrer:
            self.s_lrer.step()
コード例 #15
0
    def _train(self, data_loader, epoch):
        self.meters.reset()
        lbs = self.args.labeled_batch_size

        self.l_model.train()
        self.r_model.train()
        self.fd_model.train()

        # both 'inp' and 'gt' are tuples
        for idx, (inp, gt) in enumerate(data_loader):
            timer = time.time()

            (l_inp, l_gt), (r_inp, r_gt) = self._batch_prehandle(inp, gt)
            if len(l_gt) == len(r_gt) > 1 and idx == 0:
                self._inp_warn()

            # calculate the ramp-up coefficient of the dynamic consistency constraint
            cur_steps = len(data_loader) * epoch + idx
            total_steps = len(data_loader) * self.args.dc_rampup_epochs
            dc_rampup_scale = func.sigmoid_rampup(cur_steps, total_steps)

            # -----------------------------------------------------------------------------
            # step-0: pre-forwarding to save GPU memory
            #   - forward the task models and the flaw detector
            #   - generate pseudo ground truth for the unlabeled data if the dynamic
            #     consistency constraint is enabled
            # -----------------------------------------------------------------------------
            with torch.no_grad():
                l_resulter, l_debugger = self.l_model.forward(l_inp)
                l_activated_pred = tool.dict_value(l_resulter,
                                                   'activated_pred')
                r_resulter, r_debugger = self.r_model.forward(r_inp)
                r_activated_pred = tool.dict_value(r_resulter,
                                                   'activated_pred')

            # 'l_flawmap' and 'r_flawmap' will be used in step-2
            fd_resulter, fd_debugger = self.fd_model.forward(
                l_inp, l_activated_pred[0])
            l_flawmap = tool.dict_value(fd_resulter, 'flawmap')
            fd_resulter, fd_debugger = self.fd_model.forward(
                r_inp, r_activated_pred[0])
            r_flawmap = tool.dict_value(fd_resulter, 'flawmap')

            l_dc_gt, r_dc_gt = None, None
            l_fc_mask, r_fc_mask = None, None

            # generate the pseudo ground truth for the dynamic consistency constraint
            if self.args.ssl_mode in [MODE_GCT, MODE_DC]:
                with torch.no_grad():
                    l_handled_flawmap = self.flawmap_handler.forward(l_flawmap)
                    r_handled_flawmap = self.flawmap_handler.forward(r_flawmap)
                    l_dc_gt, r_dc_gt, l_fc_mask, r_fc_mask = self.dcgt_generator.forward(
                        l_activated_pred[0].detach(),
                        r_activated_pred[0].detach(), l_handled_flawmap,
                        r_handled_flawmap)

            # -----------------------------------------------------------------------------
            # step-1: train the task models
            # -----------------------------------------------------------------------------
            for param in self.fd_model.parameters():
                param.requires_grad = False

            # train the 'l' task model
            l_loss = self._task_model_iter(epoch, idx, True, 'l', lbs, l_inp,
                                           l_gt, l_dc_gt, l_fc_mask,
                                           dc_rampup_scale)
            self.l_optimizer.zero_grad()
            l_loss.backward()
            self.l_optimizer.step()

            # train the 'r' task model
            r_loss = self._task_model_iter(epoch, idx, True, 'r', lbs, r_inp,
                                           r_gt, r_dc_gt, r_fc_mask,
                                           dc_rampup_scale)
            self.r_optimizer.zero_grad()
            r_loss.backward()
            self.r_optimizer.step()

            # -----------------------------------------------------------------------------
            # step-2: train the flaw detector
            # -----------------------------------------------------------------------------
            for param in self.fd_model.parameters():
                param.requires_grad = True

            # generate the ground truth for the flaw detector (on labeled data only)
            with torch.no_grad():
                l_flawmap_gt = self.fdgt_generator.forward(
                    l_activated_pred[0][:lbs, ...].detach(),
                    self.task_func.sslgct_prepare_task_gt_for_fdgt(
                        l_gt[0][:lbs, ...]))
                r_flawmap_gt = self.fdgt_generator.forward(
                    r_activated_pred[0][:lbs, ...].detach(),
                    self.task_func.sslgct_prepare_task_gt_for_fdgt(
                        r_gt[0][:lbs, ...]))

            l_fd_loss = self.fd_criterion.forward(l_flawmap[:lbs, ...],
                                                  l_flawmap_gt)
            l_fd_loss = self.args.fd_scale * torch.mean(l_fd_loss)
            self.meters.update('l_fd_loss', l_fd_loss.data)

            r_fd_loss = self.fd_criterion.forward(r_flawmap[:lbs, ...],
                                                  r_flawmap_gt)
            r_fd_loss = self.args.fd_scale * torch.mean(r_fd_loss)
            self.meters.update('r_fd_loss', r_fd_loss.data)

            fd_loss = (l_fd_loss + r_fd_loss) / 2

            self.fd_optimizer.zero_grad()
            fd_loss.backward()
            self.fd_optimizer.step()

            # logging
            self.meters.update('batch_time', time.time() - timer)
            if idx % self.args.log_freq == 0:
                logger.log_info(
                    'step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                    '  l-{3}\t=>\t'
                    'l-task-loss: {meters[l_task_loss]:.6f}\t'
                    'l-dc-loss: {meters[l_dc_loss]:.6f}\t'
                    'l-fc-loss: {meters[l_fc_loss]:.6f}\n'
                    '  r-{3}\t=>\t'
                    'r-task-loss: {meters[r_task_loss]:.6f}\t'
                    'r-dc-loss: {meters[r_dc_loss]:.6f}\t'
                    'r-fc-loss: {meters[r_fc_loss]:.6f}\n'
                    '  fd\t=>\t'
                    'l-fd-loss: {meters[l_fd_loss]:.6f}\t'
                    'r-fd-loss: {meters[r_fd_loss]:.6f}\n'.format(
                        epoch,
                        idx,
                        len(data_loader),
                        self.args.task,
                        meters=self.meters))

            # the flaw detector uses polynomiallr [ITER_LRERS]
            self.fd_lrer.step()
            # update iteration-based lrers
            if not self.args.is_epoch_lrer:
                self.l_lrer.step()
                self.r_lrer.step()

        # update epoch-based lrers
        if self.args.is_epoch_lrer:
            self.l_lrer.step()
            self.r_lrer.step()
コード例 #16
0
    def _validate(self, data_loader, epoch):
        """Evaluate the 'l' and 'r' task models on `data_loader` and log metrics.

        The 'l', 'r' and flaw-detector ('fd') models are switched to eval mode.
        For every batch, when `self.args.ssl_mode` is MODE_GCT or MODE_DC, the
        dynamic-consistency pseudo ground truth and the flaw-correction masks
        are rebuilt from the two task models and the flaw detector. Then
        `_task_model_iter` is run for both task models with `is_train=False`
        and a rampup scale of 1. At the end, every meter whose name contains
        `self.task_func.METRIC_STR` is logged, grouped by its 'l'/'r' prefix.

        Args:
            data_loader: iterable of `(inp, gt)` batches; `len(data_loader)` is
                used in the progress log.
            epoch: current epoch index, used only for logging.
        """
        self.meters.reset()
        lbs = self.args.labeled_batch_size

        self.l_model.eval()
        self.r_model.eval()
        self.fd_model.eval()

        for idx, (inp, gt) in enumerate(data_loader):
            timer = time.time()

            (l_inp, l_gt), (r_inp, r_gt) = self._batch_prehandle(inp, gt)
            # chained comparison: both gt tuples have equal length, and it is > 1
            if len(l_gt) == len(r_gt) > 1 and idx == 0:
                self._inp_warn()

            l_dc_gt, r_dc_gt = None, None
            l_fc_mask, r_fc_mask = None, None
            if self.args.ssl_mode in [MODE_GCT, MODE_DC]:
                # Nothing is back-propagated during validation. Building the
                # pseudo ground truth under no_grad (as the training loop does
                # for the same targets) avoids keeping the autograd graphs of
                # these three forward passes alive. Output values are unchanged.
                with torch.no_grad():
                    l_resulter, _ = self.l_model.forward(l_inp)
                    l_activated_pred = tool.dict_value(l_resulter,
                                                       'activated_pred')
                    r_resulter, _ = self.r_model.forward(r_inp)
                    r_activated_pred = tool.dict_value(r_resulter,
                                                       'activated_pred')

                    fd_resulter, _ = self.fd_model.forward(
                        l_inp, l_activated_pred[0])
                    l_flawmap = tool.dict_value(fd_resulter, 'flawmap')
                    fd_resulter, _ = self.fd_model.forward(
                        r_inp, r_activated_pred[0])
                    r_flawmap = tool.dict_value(fd_resulter, 'flawmap')

                    l_handled_flawmap = self.flawmap_handler.forward(l_flawmap)
                    r_handled_flawmap = self.flawmap_handler.forward(r_flawmap)
                    l_dc_gt, r_dc_gt, l_fc_mask, r_fc_mask = self.dcgt_generator.forward(
                        l_activated_pred[0].detach(), r_activated_pred[0].detach(),
                        l_handled_flawmap, r_handled_flawmap)

            # The returned losses are not used here. The log below reads the
            # task/dc/fc meters, which this method never updates itself, so
            # presumably `_task_model_iter` records them in `self.meters` -
            # confirm against its definition.
            self._task_model_iter(epoch, idx, False, 'l', lbs, l_inp,
                                  l_gt, l_dc_gt, l_fc_mask, 1)
            self._task_model_iter(epoch, idx, False, 'r', lbs, r_inp,
                                  r_gt, r_dc_gt, r_fc_mask, 1)

            self.meters.update('batch_time', time.time() - timer)
            if idx % self.args.log_freq == 0:
                # NOTE(review): 'l_fd_loss' / 'r_fd_loss' are not updated anywhere
                # in this method, so these fields show whatever `self.meters`
                # yields for them after `reset()` - confirm they belong here.
                logger.log_info(
                    'step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                    '  l-{3}\t=>\t'
                    'l-task-loss: {meters[l_task_loss]:.6f}\t'
                    'l-dc-loss: {meters[l_dc_loss]:.6f}\t'
                    'l-fc-loss: {meters[l_fc_loss]:.6f}\n'
                    '  r-{3}\t=>\t'
                    'r-task-loss: {meters[r_task_loss]:.6f}\t'
                    'r-dc-loss: {meters[r_dc_loss]:.6f}\t'
                    'r-fc-loss: {meters[r_fc_loss]:.6f}\n'
                    '  fd\t=>\t'
                    'l-fd-loss: {meters[l_fd_loss]:.6f}\t'
                    'r-fd-loss: {meters[r_fd_loss]:.6f}\n'.format(
                        epoch,
                        idx,
                        len(data_loader),
                        self.args.task,
                        meters=self.meters))

        # metrics: gather every metric meter, grouped by its 'l' / 'r' name prefix
        metrics_info = {'l': '', 'r': ''}
        for key in sorted(self.meters.keys()):
            if self.task_func.METRIC_STR not in key:
                continue
            for id_str in metrics_info:
                if key.startswith(id_str):
                    metrics_info[id_str] += '{0}: {1:.6f}\t'.format(
                        key, self.meters[key])

        logger.log_info(
            'Validation metrics:\n  l-metrics\t=>\t{0}\n  r-metrics\t=>\t{1}\n'
            .format(metrics_info['l'].replace('_', '-'),
                    metrics_info['r'].replace('_', '-')))
コード例 #17
0
ファイル: proxy.py プロジェクト: lwzbuaa/PixelSSL
    def _preprocess_arguments(self):
        """ Preprocess the arguments in the script.

        In order, this method:
          - creates a timestamped output folder under `args.out_path` and
            rewrites `args.out_path` to point at it;
          - configures the logger to write '<out_path>/val.log' or
            '<out_path>/train.log' depending on `args.validation`;
          - sets `args.task` to `self.NAME`;
          - checks that every model / criterion / optimizer / lrer named in the
            script is exported by the task or by PixelSSL;
          - derives `args.is_epoch_lrer` and checks that all lrers agree on it;
          - creates the checkpoint folder and, if `args.visualize` is set, the
            visualization folders;
          - multiplies lr, num_workers, batch_size and unlabeled_batch_size by
            the number of visible GPUs and derives `args.labeled_batch_size`.

        Problems are reported through `logger.log_err`.
        """

        # create the output folder to store the results
        # NOTE(review): the timestamp contains ':', which is not a valid path
        # character on Windows.
        self.args.out_path = "{root}/{exp_id}/{date:%Y-%m-%d_%H:%M:%S}/".format(
            root=self.args.out_path,
            exp_id=self.args.exp_id,
            date=datetime.now())
        os.makedirs(self.args.out_path, exist_ok=True)

        # prepare logger
        exp_op = 'val' if self.args.validation else 'train'
        logger.log_mode(self.args.debug)
        logger.log_file(
            os.path.join(self.args.out_path, '{0}.log'.format(exp_op)),
            self.args.debug)

        logger.log_info('Result folder: \n  {0} \n'.format(self.args.out_path))

        # print experimental args
        cmd.print_args()

        # set task name
        self.args.task = self.NAME

        # check the task-specific components dicts required by the SSL algorithm
        if not len(self.args.models) == len(self.args.optimizers) == len(
                self.args.lrers) == len(self.args.criterions):
            logger.log_err(
                'Condition:\n'
                '\tlen(self.args.models) == len(self.args.optimizers) == len(self.args.lrers) == len(self.args.criterions)\n'
                'is not satisfied in the script\n')

        # the four dicts are paired up positionally (by their value order)
        for (model, criterion, optimizer, lrer) in \
            zip(self.args.models.values(), self.args.criterions.values(), self.args.optimizers.values(), self.args.lrers.values()):
            if model not in self.model.__dict__:
                logger.log_err(
                    'Unsupport model: {0} for task: {1}\n'
                    'Please add the export function in task\'s \'model.py\'\n'.
                    format(model, self.args.task))
            elif criterion not in self.criterion.__dict__:
                logger.log_err(
                    'Unsupport criterion: {0} for task: {1}\n'
                    'Please add the export function in task\'s \'criterion.py\'\n'
                    .format(criterion, self.args.task))
            elif optimizer not in nnoptimizer.__dict__:
                logger.log_err(
                    'Unsupport optimizer: {0}\n'
                    'Please implement the optimizer wrapper in \'pixelssl/nn/optimizer.py\'\n'
                    .format(optimizer))
            elif lrer not in nnlrer.__dict__:
                logger.log_err(
                    'Unsupport learning rate scheduler: {0}\n'
                    'Please implement lr scheduler wrapper in \'pixelssl/nn/lrer.py\'\n'
                    .format(lrer))

        # check the types of lrers: every lrer must be epoch-based or every lrer
        # must be iteration-based; the common type is stored in args.is_epoch_lrer
        for lrer in self.args.lrers.values():
            if lrer in nnlrer.EPOCH_LRERS:
                is_epoch_lrer = True
            elif lrer in nnlrer.ITER_LRERS:
                is_epoch_lrer = False
            else:
                logger.log_err(
                    'Unknown learning rate scheduler ({0}) type\n'
                    'Please add it into either EPOCH_LRERS or ITER_LRERS in \'pixelssl/nn/lrer.py\'\n'
                    'PixelSSL supports: \n'
                    '  EPOCH_LRERS\t=>\t{1}\n  ITER_LRERS\t=>\t{2}\n'.format(
                        lrer, nnlrer.EPOCH_LRERS, nnlrer.ITER_LRERS))

            if self.args.is_epoch_lrer is None:
                self.args.is_epoch_lrer = is_epoch_lrer
            elif self.args.is_epoch_lrer != is_epoch_lrer:
                logger.log_err(
                    'Unmatched lr scheduler types\t=>\t{0}\n'
                    'All lrers of the task models should have the same types (either EPOCH_LRERS or ITER_LRERS)\n'
                    'PixelSSL supports: \n'
                    '  EPOCH_LRERS\t=>\t{1}\n  ITER_LRERS\t=>\t{2}\n'.format(
                        self.args.lrers, nnlrer.EPOCH_LRERS,
                        nnlrer.ITER_LRERS))

        self.args.checkpoint_path = os.path.join(self.args.out_path, 'ckpt')
        os.makedirs(self.args.checkpoint_path, exist_ok=True)

        if self.args.visualize:
            self.args.visual_debug_path = os.path.join(self.args.out_path,
                                                       'visualization/debug')
            self.args.visual_train_path = os.path.join(self.args.out_path,
                                                       'visualization/train')
            self.args.visual_val_path = os.path.join(self.args.out_path,
                                                     'visualization/val')
            for vpath in (self.args.visual_debug_path,
                          self.args.visual_train_path,
                          self.args.visual_val_path):
                os.makedirs(vpath, exist_ok=True)

        # handle arguments for multi-GPU training: the values below are
        # multiplied by the GPU count, and the logged values are the totals
        self.args.gpus = torch.cuda.device_count()
        if self.args.gpus < 1:
            logger.log_err('No GPU be detected\n'
                           'PixelSSL requires at least one Nvidia GPU\n')

        logger.log_info('GPU: \n  Total GPU(s): {0}'.format(self.args.gpus))
        self.args.lr *= self.args.gpus
        self.args.num_workers *= self.args.gpus
        self.args.batch_size *= self.args.gpus
        self.args.unlabeled_batch_size *= self.args.gpus

        # TODO: support unsupervised and self-supervised training
        if self.args.unlabeled_batch_size >= self.args.batch_size:
            # the placeholders were previously never filled (missing .format call)
            logger.log_err(
                'The argument \'unlabeled_batch_size\' ({0}) should be smaller than \'batch_size\' ({1}) '
                'since PixelSSL only supports semi-supervised learning now\n'
                .format(self.args.unlabeled_batch_size, self.args.batch_size))

        self.args.labeled_batch_size = self.args.batch_size - self.args.unlabeled_batch_size
        logger.log_info(
            '  Total learn rate: {0}\n  Total labeled batch size: {1}\n'
            '  Total unlabeled batch size: {2}\n  Total data workers: {3}\n'.
            format(self.args.lr, self.args.labeled_batch_size,
                   self.args.unlabeled_batch_size, self.args.num_workers))