Example #1
    def _pred_err(self):
        logger.log_err(
            'In SSL_MT, the \'resulter\' dict returned by the task model should contain the following keys:\n'
            '   (1) \'pred\'\t=>\tunactivated task predictions\n'
            '   (2) \'activated_pred\'\t=>\tactivated task predictions\n'
            'We need both of them since some losses include the activation functions,\n'
            'e.g., CrossEntropyLoss already contains Softmax\n')
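
For context, a minimal sketch of a task model forward that satisfies this contract; the backbone is a placeholder, and the tuple-valued resulter entries follow the convention used by the other examples here:

import torch.nn as nn
import torch.nn.functional as F

class ToyTaskModel(nn.Module):
    # hypothetical task model; any backbone that returns logits would do
    def __init__(self, backbone):
        super(ToyTaskModel, self).__init__()
        self.backbone = backbone

    def forward(self, inp):
        resulter, debugger = {}, {}
        logits = self.backbone(inp)
        # 'pred' keeps the unactivated logits; 'activated_pred' applies Softmax
        resulter['pred'] = (logits, )
        resulter['activated_pred'] = (F.softmax(logits, dim=1), )
        return resulter, debugger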
Example #2
    def __init__(self, args):
        super(SSLMT, self).__init__(args)

        # define the student model and the teacher model
        self.s_model, self.t_model = None, None
        self.s_optimizer = None
        self.s_lrer = None
        self.s_criterion = None

        self.cons_criterion = None

        self.gaussian_noiser = None
        self.zero_tensor = torch.zeros(1)

        # check SSL arguments
        if self.args.cons_for_labeled or self.args.unlabeled_batch_size > 0:
            if self.args.cons_scale < 0:
                logger.log_err(
                    'The argument - cons_scale - is not set (or invalid)\n'
                    'You set argument - cons_for_labeled - to True\n'
                    'or\n'
                    'You set argument - unlabeled_batch_size - larger than 0\n'
                    'Please set - cons_scale >= 0 - for training\n')
            if self.args.cons_rampup_epochs < 0:
                logger.log_err(
                    'The argument - cons_rampup_epochs - is not set (or invalid)\n'
                    'You set argument - cons_for_labeled - to True\n'
                    'or\n'
                    'You set argument - unlabeled_batch_size - larger than 0\n'
                    'Please set - cons_rampup_epochs >= 0 - for training\n')
Example #3
    def forward(self, inp):
        resulter, debugger = {}, {}

        t_resulter, t_debugger = self.task_model.forward(inp)

        if 'pred' not in t_resulter or 'activated_pred' not in t_resulter:
            logger.log_err(
                'In SSL_S4L, the \'resulter\' dict returned by the task model should contain the following keys:\n'
                '   (1) \'pred\'\t=>\tunactivated task predictions\n'
                '   (2) \'activated_pred\'\t=>\tactivated task predictions\n'
                'We need both of them since some losses include the activation functions,\n'
                'e.g., CrossEntropyLoss already contains Softmax\n')

        if 'ssls4l_rc_inp' not in t_resulter:
            logger.log_err(
                'In SSL_S4L, the \'resulter\' dict returned by the task model should contain the key:\n'
                '    \'ssls4l_rc_inp\'\t=>\tinputs of the rotation classifier (a 4-dim tensor)\n'
                'It can be the feature map encoded by the task model or the output of the task model\n'
                'Please add the key \'ssls4l_rc_inp\' in your task model\'s resulter\n'
            )

        rc_inp = tool.dict_value(t_resulter, 'ssls4l_rc_inp')
        pred_rotation = self.rotation_classifier.forward(rc_inp)

        resulter['pred'] = tool.dict_value(t_resulter, 'pred')
        resulter['activated_pred'] = tool.dict_value(t_resulter,
                                                     'activated_pred')
        resulter['rotation'] = pred_rotation

        return resulter, debugger
Example #4
    def __init__(self, args):
        super(SSLADV, self).__init__(args)

        # define the task model and the FC discriminator
        self.model, self.d_model = None, None
        self.optimizer, self.d_optimizer = None, None
        self.lrer, self.d_lrer = None, None
        self.criterion, self.d_criterion = None, None

        # prepare the arguments for multiple GPUs
        self.args.discriminator_lr *= self.args.gpus

        # check SSL arguments
        if self.args.adv_for_labeled:
            if self.args.labeled_adv_scale < 0:
                logger.log_err('The argument - labeled_adv_scale - is not set (or invalid)\n'
                               'You set argument - adv_for_labeled - to True\n'
                               'Please set - labeled_adv_scale >= 0 - for calculating the '
                               'adversarial loss on the labeled data\n')
        if self.args.unlabeled_batch_size > 0:
            if self.args.unlabeled_adv_scale < 0:
                logger.log_err('The argument - unlabeled_adv_scale - is not set (or invalid)\n'
                               'You set argument - unlabeled_batch_size - larger than 0\n'
                               'Please set - unlabeled_adv_scale >= 0 - for calculating the '
                               'adversarial loss on the unlabeled data\n')
Example #5
    def _build_ssl_algorithm(self):
        """ Build the semi-supervised learning algorithm given in the script.
        """

        for cname in self.args.models.keys():
            self.model_dict[cname] = self.model.__dict__[
                self.args.models[cname]]()
            self.criterion_dict[cname] = self.criterion.__dict__[
                self.args.criterions[cname]]()
            self.lrer_dict[cname] = nnlrer.__dict__[self.args.lrers[cname]](
                self.args)
            self.optimizer_dict[cname] = nnoptimizer.__dict__[
                self.args.optimizers[cname]](self.args)

        logger.log_info('SSL algorithm: \n  {0}\n'.format(
            self.args.ssl_algorithm))
        logger.log_info('Models: ')
        self.ssl_algorithm = pixelssl.ssl_algorithm.__dict__[
            self.args.ssl_algorithm].__dict__[self.args.ssl_algorithm](
                self.args, self.model_dict,
                self.optimizer_dict, self.lrer_dict, self.criterion_dict,
                self.func.task_func()(self.args))

        # check whether the SSL algorithm supports the given task
        if self.TASK_TYPE not in self.ssl_algorithm.SUPPORTED_TASK_TYPES:
            logger.log_err(
                'SSL algorithm - {0} - supports task types {1}\n'
                'However, the given task - {2} - belongs to {3}\n'.format(
                    self.ssl_algorithm.NAME,
                    self.ssl_algorithm.SUPPORTED_TASK_TYPES, self.args.task,
                    self.TASK_TYPE))
Example #6
def pytorch_support(required_version='1.0.0', info_str=''):
    # NOTE: compare parsed version tuples; a plain string comparison
    # misorders releases such as '1.10.0' vs '1.9.0'
    def _version_tuple(version):
        return tuple(int(v) for v in version.split('+')[0].split('.')[:3])

    if _version_tuple(torch.__version__) < _version_tuple(required_version):
        logger.log_err('{0} requires PyTorch >= {1}\n'
                       'However, the current PyTorch == {2}\n'.format(
                           info_str, required_version, torch.__version__))
    else:
        return True
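
A hypothetical call site; the version string and info_str value are illustrative:

# guard a code path that relies on a newer PyTorch API
if pytorch_support(required_version='1.1.0', info_str='SSL_MT'):
    pass  # safe to use the newer API here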
Example #7
    def _batch_prehandle(self, inp, gt, is_train):
        # add extra data augmentation process here if necessary

        inp_var = []
        for i in inp:
            inp_var.append(Variable(i).cuda())
        inp = tuple(inp_var)

        gt_var = []
        for g in gt:
            gt_var.append(Variable(g).cuda())
        gt = tuple(gt_var)

        mix_u_inp = None
        mix_u_mask = None

        # -------------------------------------------------
        # Operations for CUTMIX
        # -------------------------------------------------
        if is_train:
            lbs = self.args.labeled_batch_size
            ubs = self.args.unlabeled_batch_size

            # check the shape of input and gt
            # NOTE: this implementation of CUTMIX supports multiple inputs and gts,
            #       but all of them should have the same spatial size

            sample_shape = (inp[0].shape[2], inp[0].shape[3])
            for i in inp:
                if tuple(i.shape[2:]) != sample_shape:
                    logger.log_err(
                        'The SSL_CUTMIX algorithm requires all inputs to have the same shape\n'
                    )
            for g in gt:
                if tuple(g.shape[2:]) != sample_shape:
                    logger.log_err(
                        'The SSL_CUTMIX algorithm requires all ground truths to have the same shape\n'
                    )

            # generate the mask for mixing the unlabeled samples
            mix_u_mask = self.mask_generator.produce(ubs // 2, sample_shape)
            mix_u_mask = torch.tensor(mix_u_mask).cuda()

            # mix the unlabeled samples
            u_inp_1 = func.split_tensor_tuple(inp, lbs, lbs + ubs // 2)
            u_inp_2 = func.split_tensor_tuple(inp, lbs + ubs // 2,
                                              self.args.batch_size)

            mix_u_inp = []
            for ui_1, ui_2 in zip(u_inp_1, u_inp_2):
                mi = mix_u_mask * ui_1 + (1 - mix_u_mask) * ui_2
                mix_u_inp.append(mi)
            mix_u_inp = tuple(mix_u_inp)

        return inp, gt, mix_u_inp, mix_u_mask
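
The mixing rule above is a per-pixel selection between two unlabeled sub-batches. A self-contained numeric sketch, with all shapes illustrative:

import torch

half_ubs, c, h, w = 2, 3, 8, 8                      # illustrative sizes
u_inp_1 = torch.randn(half_ubs, c, h, w)
u_inp_2 = torch.randn(half_ubs, c, h, w)
# binary mask, broadcast over the channel dimension
mix_u_mask = (torch.rand(half_ubs, 1, h, w) > 0.5).float()

# each pixel comes from u_inp_1 where mask == 1, else from u_inp_2
mix_u_inp = mix_u_mask * u_inp_1 + (1 - mix_u_mask) * u_inp_2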
Example #8
    def _run(self):
        """ Main pipeline of experiment.

        Please override this function if you want a special pipeline.
        """

        start_epoch = 0
        if self.args.resume is not None and self.args.resume != '':
            logger.log_info('Load checkpoint from: {0}'.format(
                self.args.resume))
            start_epoch = self.ssl_algorithm.load_checkpoint()

        if self.args.validation:
            if self.val_loader is None:
                logger.log_err('No data loader for validation.\n'
                               'Please set a valid \'valset\' in the script.\n')

            logger.log_info(
                ['=' * 78, '\nStart to validate model\n', '=' * 78])
            with torch.no_grad():
                self.ssl_algorithm.validate(self.val_loader, start_epoch)
            return

        # NOTE: the first epoch index for 'train' and 'validate' is 0
        for epoch in range(start_epoch, self.args.epochs):
            timer = time.time()

            logger.log_info([
                '=' * 78, '\nStart to train epoch-{0}\n'.format(epoch + 1),
                '=' * 78
            ])
            self.ssl_algorithm.train(self.train_loader, epoch)

            if (epoch + 1) % self.args.val_freq == 0 and self.val_loader is not None:
                logger.log_info([
                    '=' * 78,
                    '\nStart to validate epoch-{0}\n'.format(epoch + 1),
                    '=' * 78
                ])
                with torch.no_grad():
                    self.ssl_algorithm.validate(self.val_loader, epoch)

            if (epoch + 1) % self.args.checkpoint_freq == 0:
                self.ssl_algorithm.save_checkpoint(epoch + 1)
                logger.log_info("Save checkpoint for epoch {0}".format(epoch +
                                                                       1))

            logger.log_info(
                'Finish epoch in {0} seconds\n'.format(time.time() - timer))

        logger.log_info('Finish experiment {0}\n'.format(self.args.exp_id))
Example #9
    def __init__(self, args):
        super(SSLCUTMIX, self).__init__(args)

        self.s_model, self.t_model = None, None
        self.s_optimizer = None
        self.s_lrer = None
        self.s_criterion = None

        # define the auxiliary modules required by CUTMIX
        self.mask_generator = None

        # check SSL arguments
        if self.args.unlabeled_batch_size > 0:
            if self.args.unlabeled_batch_size <= 2 or self.args.unlabeled_batch_size % 2 != 0:
                logger.log_err(
                    'This implementation of SSL_CUTMIX requires an unlabeled batch size that is: \n'
                    '    1. larger than 2 \n'
                    '    2. divisible by 2 \n')
            if self.args.cons_scale < 0:
                logger.log_err(
                    'The argument - cons_scale - is not set (or invalid)\n'
                    'Please set - cons_scale >= 0 - for training\n')
            if self.args.cons_rampup_epochs < 0:
                logger.log_err(
                    'The argument - cons_rampup_epochs - is not set (or invalid)\n'
                    'Please set - cons_rampup_epochs >= 0 - for training\n')
            if self.args.cons_threshold < 0 or self.args.cons_threshold > 1:
                logger.log_err(
                    'The argument - cons_threshold - is not set (or invalid)\n'
                    'Please set - 0 <= cons_threshold <= 1 - for training\n')
Example #10
    def _load_checkpoint(self):
        checkpoint = torch.load(self.args.resume)

        checkpoint_algorithm = tool.dict_value(checkpoint, 'algorithm', default='unknown')
        if checkpoint_algorithm != self.NAME:
            logger.log_err('Unmatched SSL algorithm in checkpoint => required: {0} - given: {1}\n'
                           .format(self.NAME, checkpoint_algorithm))

        self.s_model.load_state_dict(checkpoint['s_model'])
        self.t_model.load_state_dict(checkpoint['t_model'])
        self.s_optimizer.load_state_dict(checkpoint['s_optimizer'])
        self.s_lrer.load_state_dict(checkpoint['s_lrer'])

        return checkpoint['epoch']
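
A minimal sketch of the matching save side, assuming the same checkpoint keys; the real save routine may store additional state, and the file name pattern here is an assumption:

    def _save_checkpoint(self, epoch):
        # the keys below mirror those read back in _load_checkpoint above
        checkpoint = {
            'algorithm': self.NAME,
            'epoch': epoch,
            's_model': self.s_model.state_dict(),
            't_model': self.t_model.state_dict(),
            's_optimizer': self.s_optimizer.state_dict(),
            's_lrer': self.s_lrer.state_dict(),
        }
        # hypothetical file name pattern under the experiment's checkpoint dir
        torch.save(checkpoint, os.path.join(self.args.checkpoint_path,
                                            'checkpoint_{0}.ckpt'.format(epoch)))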
Example #11
def create_parser(algorithm):
    parser = argparse.ArgumentParser(
        description='PixelSSL Static Script Parser')

    if algorithm not in ssl_algorithm.SSL_ALGORITHMS:
        logger.log_err('Unknown semi-supervised learning algorithm: {0}\n'
                       'The supported algorithms are: {1}\n'.format(
                           algorithm, ssl_algorithm.SSL_ALGORITHMS))

    optimizer.add_parser_arguments(parser)
    lrer.add_parser_arguments(parser)
    ssl_algorithm.__dict__[algorithm].add_parser_arguments(parser)

    return parser
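
Hypothetical usage; the algorithm name is illustrative and must be one of ssl_algorithm.SSL_ALGORITHMS:

parser = create_parser('ssl_mt')
args = parser.parse_args()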
Example #12
def ssl_cct(args, model_dict, optimizer_dict, lrer_dict, criterion_dict, task_func):
    if not len(model_dict) == len(optimizer_dict) == len(lrer_dict) == len(criterion_dict) == 1:
        logger.log_err('The len(element_dict) of SSL_CCT should be 1\n')
    elif list(model_dict.keys())[0] != 'model':
        logger.log_err('In SSL_CCT, the key of element_dict should be \'model\',\n'
                       'but \'{0}\' is given\n'.format(model_dict.keys()))

    model_funcs = [model_dict['model']]
    optimizer_funcs = [optimizer_dict['model']]
    lrer_funcs = [lrer_dict['model']]
    criterion_funcs = [criterion_dict['model']]

    algorithm = SSLCCT(args)
    algorithm.build(model_funcs, optimizer_funcs, lrer_funcs, criterion_funcs, task_func)
    return algorithm
Example #13
    def _load_dataset(self, dataset_name, dataset_dir, is_train=True):
        """ Load one dataset.
        """

        if dataset_name not in self.data.__dict__:
            logger.log_err('Unknown dataset type: {0}\n'.format(dataset_name))
        elif not os.path.exists(dataset_dir):
            logger.log_err('Cannot find the path of dataset: {0}\n'.format(dataset_dir))
        else:
            dataset_args = copy.deepcopy(self.args)
            if is_train:
                dataset_args.trainset = {dataset_name: dataset_dir}
            else:
                dataset_args.valset = {dataset_name: dataset_dir}
            return self.data.__dict__[dataset_name]()(dataset_args, is_train)
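
Hypothetical usage from inside the proxy class; the dataset name and directory are illustrative:

trainset = self._load_dataset('pascal_voc', '/path/to/pascal_voc', is_train=True)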
Example #14
    def _load_checkpoint(self):
        checkpoint = torch.load(self.args.resume)

        checkpoint_algorithm = tool.dict_value(checkpoint, 'algorithm', default='unknown')

        if checkpoint_algorithm != self.NAME:
            logger.log_err('Unmatched SSL algorithm in checkpoint => required: {0} - given: {1}\n'
                           .format(self.NAME, checkpoint_algorithm))

        self.model.load_state_dict(checkpoint['model'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.lrer.load_state_dict(checkpoint['lrer'])

        self.main_model = self.model.module.main_model
        self.auxiliary_decoders = self.model.module.auxiliary_decoders

        return checkpoint['epoch']
Example #15
    def _load_checkpoint(self):
        checkpoint = torch.load(self.args.resume)

        checkpoint_algorithm = tool.dict_value(checkpoint,
                                               'algorithm',
                                               default='unknown')
        if checkpoint_algorithm != self.NAME:
            logger.log_err(
                'Unmatched SSL algorithm in checkpoint => required: {0} - given: {1}\n'
                .format(self.NAME, checkpoint_algorithm))

        self.model.load_state_dict(checkpoint['model'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.lrer.load_state_dict(checkpoint['lrer'])

        self.task_model = self.model.module.task_model
        self.rotation_classifier = self.model.module.rotation_classifier

        return checkpoint['epoch']
Example #16
    def forward(self, x):
        """
        Arguments:
            x (torch.Tensor): input 4D tensor

        Returns:
            torch.Tensor: Blurred version of the input 
        """

        if x.dim() != 4:
            logger.log_err(
                '\'GaussianBlurLayer\' requires a 4D tensor as input\n')
        elif x.shape[1] != self.channels:
            logger.log_err(
                'In \'GaussianBlurLayer\', the required channel ({0}) is '
                'not the same as input ({1})\n'.format(self.channels,
                                                       x.shape[1]))

        return self.op(x)
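
A quick sanity check, assuming the layer was constructed for 3-channel input; the constructor arguments shown are assumptions:

blur = GaussianBlurLayer(3, kernel_size=7)   # hypothetical constructor call
x = torch.randn(2, 3, 32, 32)                # 4D input: N x C x H x W
y = blur(x)                                  # blurred output, same shape as x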
Example #17
    def __init__(self, args):
        super(SSLS4L, self).__init__(args)

        self.task_model = None
        self.rotation_classifier = None

        self.model = None
        self.optimizer = None
        self.lrer = None
        self.criterion = None

        # check SSL arguments
        if self.args.rotation_scale < 0:
            logger.log_err(
                'The argument - rotation_scale - is not set (or invalid)\n'
                'Please set - rotation_scale >= 0 - for training\n')
        if self.args.rotated_sup_scale < 0:
            logger.log_err(
                'The argument - rotated_sup_scale - is not set (or invalid)\n'
                'Please set - rotated_sup_scale >= 0 - for training\n')
Example #18
    def __init__(self, args):
        super(SSLCCT, self).__init__(args)

        self.main_model = None
        self.auxiliary_decoders = None

        self.model = None
        self.optimizer = None
        self.lrer = None
        self.criterion = None

        self.cons_criterion = None

        # check SSL arguments
        if self.args.unlabeled_batch_size > 0:
            if self.args.cons_scale < 0:
                logger.log_err('The argument - cons_scale - is not set (or invalid)\n'
                               'You set argument - unlabeled_batch_size - larger than 0\n'
                               'Please set - cons_scale >= 0 - for training\n')
            elif self.args.cons_rampup_epochs < 0:
                logger.log_err('The argument - cons_rampup_epochs - is not set (or invalid)\n'
                               'You set argument - unlabeled_batch_size - larger than 0\n'
                               'Please set - cons_rampup_epochs >= 0 - for training\n')
            if self.args.ad_lr_scale < 0:
                logger.log_err('The argument - ad_lr_scale - is not set (or invalid)\n'
                               'You set argument - unlabeled_batch_size - larger than 0\n'
                               'Please set - ad_lr_scale >= 0 - for training\n')
        else:
            self.args.ad_lr_scale = 0
Example #19
    def forward(self, inp, gt, is_unlabeled):
        resulter, debugger = {}, {}

        # forward the task model
        m_resulter, m_debugger = self.main_model.forward(inp)

        if 'pred' not in m_resulter or 'activated_pred' not in m_resulter:
            logger.log_err('In SSL_CCT, the \'resulter\' dict returned by the task model should contain the following keys:\n'
                           '   (1) \'pred\'\t=>\tunactivated task predictions\n'
                           '   (2) \'activated_pred\'\t=>\tactivated task predictions\n'
                           'We need both of them since some losses include the activation functions,\n'
                           'e.g., CrossEntropyLoss already contains Softmax\n')

        resulter['pred'] = tool.dict_value(m_resulter, 'pred')
        resulter['activated_pred'] = tool.dict_value(m_resulter, 'activated_pred')

        if not len(resulter['pred']) == len(resulter['activated_pred']) == 1:
            logger.log_err('This implementation of SSL_CCT only supports task models with a single prediction (output).\n'
                           'However, there are {0} predictions.\n'.format(len(resulter['pred'])))

        # calculate the task loss
        resulter['task_loss'] = None if is_unlabeled else torch.mean(self.task_criterion.forward(resulter['pred'], gt, inp))

        # for the unlabeled data
        if is_unlabeled and self.args.unlabeled_batch_size > 0:
            if 'sslcct_ad_inp' not in m_resulter:
                logger.log_err('In SSL_CCT, the \'resulter\' dict returned by the task model should contain the key:\n'
                               '    \'sslcct_ad_inp\'\t=>\tinputs of the auxiliary decoders (a 4-dim tensor)\n'
                               'It is the feature map encoded by the task model\n'
                               'Please add the key \'sslcct_ad_inp\' in your task model\'s resulter\n'
                               'Note that for different task models, the shape of \'sslcct_ad_inp\' may be different\n')

            ul_ad_inp = tool.dict_value(m_resulter, 'sslcct_ad_inp')
            ul_main_pred = resulter['pred'][0].detach()

            # forward the auxiliary decoders
            ul_ad_preds = []
            for ad in self.auxiliary_decoders:
                ul_ad_preds.append(ad.forward(ul_ad_inp, pred_of_main_decoder=ul_main_pred))

            resulter['ul_ad_preds'] = ul_ad_preds

            # calculate the consistency loss
            ul_ad_gt = resulter['activated_pred'][0].detach()
            ul_ad_preds = [F.interpolate(ul_ad_pred, size=(ul_ad_gt.shape[2], ul_ad_gt.shape[3]), mode='bilinear') for ul_ad_pred in ul_ad_preds]
            ul_activated_ad_preds = self.ad_activation_func(ul_ad_preds)
            cons_loss = sum([self.cons_criterion.forward(ul_activated_ad_pred, ul_ad_gt) for ul_activated_ad_pred in ul_activated_ad_preds])
            cons_loss = torch.mean(cons_loss) / len(ul_activated_ad_preds)
            resulter['cons_loss'] = cons_loss
        else:
            resulter['ul_ad_preds'] = None
            resulter['cons_loss'] = None

        return resulter, debugger
Example #20
def ssl_gct(args, model_dict, optimizer_dict, lrer_dict, criterion_dict,
            task_func):
    if not len(model_dict) == len(optimizer_dict) == len(lrer_dict) == len(
            criterion_dict):
        logger.log_err('All the element dicts of SSL_GCT should have the same length\n')

    if len(model_dict) == 1:
        if list(model_dict.keys())[0] != 'model':
            logger.log_err(
                'In SSL_GCT, the key of 1-value element_dict should be \'model\',\n'
                'but \'{0}\' is given\n'.format(model_dict.keys()))

        model_funcs = [model_dict['model'], model_dict['model']]
        optimizer_funcs = [optimizer_dict['model'], optimizer_dict['model']]
        lrer_funcs = [lrer_dict['model'], lrer_dict['model']]
        criterion_funcs = [criterion_dict['model'], criterion_dict['model']]

    elif len(model_dict) == 2:
        if 'lmodel' not in list(model_dict.keys()) or 'rmodel' not in list(
                model_dict.keys()):
            logger.log_err(
                'In SSL_GCT, the key of 2-value element_dict should be \'(lmodel, rmodel)\', '
                'but \'{0}\' is given\n'.format(model_dict.keys()))

        model_funcs = [model_dict['lmodel'], model_dict['rmodel']]
        optimizer_funcs = [optimizer_dict['lmodel'], optimizer_dict['rmodel']]
        lrer_funcs = [lrer_dict['lmodel'], lrer_dict['rmodel']]
        criterion_funcs = [criterion_dict['lmodel'], criterion_dict['rmodel']]

    else:
        logger.log_err(
            'The SSL_GCT algorithm supports element_dict with 1 or 2 elements, '
            'but given {0} elements\n'.format(len(model_dict)))

    algorithm = SSLGCT(args)
    algorithm.build(model_funcs, optimizer_funcs, lrer_funcs, criterion_funcs,
                    task_func)
    return algorithm
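
The two accepted element_dict layouts, sketched with placeholder export functions; all names below are hypothetical:

# 1-element form: both task models share the same components
algorithm = ssl_gct(args,
                    {'model': model_func},            # hypothetical export funcs
                    {'model': optimizer_func},
                    {'model': lrer_func},
                    {'model': criterion_func},
                    task_func)

# 2-element form: the left and right models are configured independently
algorithm = ssl_gct(args,
                    {'lmodel': l_model_func, 'rmodel': r_model_func},
                    {'lmodel': l_optimizer_func, 'rmodel': r_optimizer_func},
                    {'lmodel': l_lrer_func, 'rmodel': r_lrer_func},
                    {'lmodel': l_criterion_func, 'rmodel': r_criterion_func},
                    task_func)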
Example #21
    def _preprocess_arguments(self):
        """ Preprocess the arguments in the script.
        """

        # create the output folder to store the results
        self.args.out_path = "{root}/{exp_id}/{date:%Y-%m-%d_%H:%M:%S}/".format(
            root=self.args.out_path,
            exp_id=self.args.exp_id,
            date=datetime.now())
        if not os.path.exists(self.args.out_path):
            os.makedirs(self.args.out_path)

        # prepare logger
        exp_op = 'val' if self.args.validation else 'train'
        logger.log_mode(self.args.debug)
        logger.log_file(
            os.path.join(self.args.out_path, '{0}.log'.format(exp_op)),
            self.args.debug)

        logger.log_info('Result folder: \n  {0} \n'.format(self.args.out_path))

        # print experimental args
        cmd.print_args()

        # set task name
        self.args.task = self.NAME

        # check the task-specific components dicts required by the SSL algorithm
        if not len(self.args.models) == len(self.args.optimizers) == len(
                self.args.lrers) == len(self.args.criterions):
            logger.log_err(
                'Condition:\n'
                '\tlen(self.args.models) == len(self.args.optimizers) == len(self.args.lrers) == len(self.args.criterions)\n'
                'is not satisfied in the script\n')

        for (model, criterion, optimizer, lrer) in \
            zip(self.args.models.values(), self.args.criterions.values(), self.args.optimizers.values(), self.args.lrers.values()):
            if model not in self.model.__dict__:
                logger.log_err(
                    'Unsupported model: {0} for task: {1}\n'
                    'Please add the export function in the task\'s \'model.py\'\n'.format(
                        model, self.args.task))
            elif criterion not in self.criterion.__dict__:
                logger.log_err(
                    'Unsupported criterion: {0} for task: {1}\n'
                    'Please add the export function in the task\'s \'criterion.py\'\n'.format(
                        criterion, self.args.task))
            elif optimizer not in nnoptimizer.__dict__:
                logger.log_err(
                    'Unsupported optimizer: {0}\n'
                    'Please implement the optimizer wrapper in \'pixelssl/nn/optimizer.py\'\n'.format(
                        optimizer))
            elif lrer not in nnlrer.__dict__:
                logger.log_err(
                    'Unsupported learning rate scheduler: {0}\n'
                    'Please implement the lr scheduler wrapper in \'pixelssl/nn/lrer.py\'\n'.format(
                        lrer))

        # check the types of lrers
        for lrer in self.args.lrers.values():
            if lrer in nnlrer.EPOCH_LRERS:
                is_epoch_lrer = True
            elif lrer in nnlrer.ITER_LRERS:
                is_epoch_lrer = False
            else:
                logger.log_err(
                    'Unknown learning rate scheduler type ({0})\n'
                    'Please add it to either EPOCH_LRERS or ITER_LRERS in \'pixelssl/nn/lrer.py\'\n'
                    'PixelSSL supports: \n'
                    '  EPOCH_LRERS\t=>\t{1}\n  ITER_LRERS\t=>\t{2}\n'.format(
                        lrer, nnlrer.EPOCH_LRERS, nnlrer.ITER_LRERS))

            if self.args.is_epoch_lrer is None:
                self.args.is_epoch_lrer = is_epoch_lrer
            elif self.args.is_epoch_lrer != is_epoch_lrer:
                logger.log_err(
                    'Unmatched lr scheduler types\t=>\t{0}\n'
                    'All lrers of the task models should have the same type (either EPOCH_LRERS or ITER_LRERS)\n'
                    'PixelSSL supports: \n'
                    '  EPOCH_LRERS\t=>\t{1}\n  ITER_LRERS\t=>\t{2}\n'.format(
                        self.args.lrers, nnlrer.EPOCH_LRERS,
                        nnlrer.ITER_LRERS))

        self.args.checkpoint_path = os.path.join(self.args.out_path, 'ckpt')
        if not os.path.exists(self.args.checkpoint_path):
            os.makedirs(self.args.checkpoint_path)

        if self.args.visualize:
            self.args.visual_debug_path = os.path.join(self.args.out_path,
                                                       'visualization/debug')
            self.args.visual_train_path = os.path.join(self.args.out_path,
                                                       'visualization/train')
            self.args.visual_val_path = os.path.join(self.args.out_path,
                                                     'visualization/val')
            for vpath in [
                    self.args.visual_debug_path, self.args.visual_train_path,
                    self.args.visual_val_path
            ]:
                if not os.path.exists(vpath):
                    os.makedirs(vpath)

        # handle arguments for multi-GPU training
        self.args.gpus = torch.cuda.device_count()
        if self.args.gpus < 1:
            logger.log_err('No GPU detected\n'
                           'PixelSSL requires at least one NVIDIA GPU\n')

        logger.log_info('GPU: \n  Total GPU(s): {0}'.format(self.args.gpus))
        self.args.lr *= self.args.gpus
        self.args.num_workers *= self.args.gpus
        self.args.batch_size *= self.args.gpus
        self.args.unlabeled_batch_size *= self.args.gpus

        # TODO: support unsupervised and self-supervised training
        if self.args.unlabeled_batch_size >= self.args.batch_size:
            logger.log_err(
                'The argument \'unlabeled_batch_size\' ({0}) should be smaller than \'batch_size\' ({1}) '
                'since PixelSSL only supports semi-supervised learning now\n'.format(
                    self.args.unlabeled_batch_size, self.args.batch_size))

        self.args.labeled_batch_size = self.args.batch_size - self.args.unlabeled_batch_size
        logger.log_info(
            '  Total learning rate: {0}\n  Total labeled batch size: {1}\n'
            '  Total unlabeled batch size: {2}\n  Total data workers: {3}\n'.format(
                self.args.lr, self.args.labeled_batch_size,
                self.args.unlabeled_batch_size, self.args.num_workers))
Example #22
    def _task_model_iter(self, epoch, idx, is_train, mid, lbs, inp, gt, dc_gt,
                         fc_mask, dc_rampup_scale):
        if mid == 'l':
            model, criterion = self.l_model, self.l_criterion
        elif mid == 'r':
            model, criterion = self.r_model, self.r_criterion
        else:
            model, criterion = None, None

        # forward the task model
        resulter, debugger = model.forward(inp)
        if 'pred' not in resulter or 'activated_pred' not in resulter:
            self._pred_err()

        pred = tool.dict_value(resulter, 'pred')
        activated_pred = tool.dict_value(resulter, 'activated_pred')

        fd_resulter, fd_debugger = self.fd_model.forward(
            inp, activated_pred[0])
        flawmap = tool.dict_value(fd_resulter, 'flawmap')

        # calculate the supervised task constraint on the labeled data
        labeled_pred = func.split_tensor_tuple(pred, 0, lbs)
        labeled_gt = func.split_tensor_tuple(gt, 0, lbs)
        labeled_inp = func.split_tensor_tuple(inp, 0, lbs)
        task_loss = torch.mean(
            criterion.forward(labeled_pred, labeled_gt, labeled_inp))
        self.meters.update('{0}_task_loss'.format(mid), task_loss.data)

        # calculate the flaw correction constraint
        if self.args.ssl_mode in [MODE_GCT, MODE_FC]:
            if flawmap.shape == self.zero_df_gt.shape:
                fc_ssl_loss = self.fd_criterion.forward(flawmap,
                                                        self.zero_df_gt,
                                                        is_ssl=True,
                                                        reduction=False)
            else:
                fc_ssl_loss = self.fd_criterion.forward(
                    flawmap,
                    torch.zeros(flawmap.shape).cuda(),
                    is_ssl=True,
                    reduction=False)

            if self.args.ssl_mode == MODE_GCT:
                fc_ssl_loss = fc_mask * fc_ssl_loss

            fc_ssl_loss = self.args.fc_ssl_scale * torch.mean(fc_ssl_loss)
            self.meters.update('{0}_fc_loss'.format(mid), fc_ssl_loss.data)
        else:
            fc_ssl_loss = 0
            self.meters.update('{0}_fc_loss'.format(mid), fc_ssl_loss)

        # calculate the dynamic consistency constraint
        if self.args.ssl_mode in [MODE_GCT, MODE_DC]:
            if dc_gt is None:
                logger.log_err(
                    'The dynamic consistency constraint is enabled, '
                    'but no pseudo ground truth is given.')

            dc_ssl_loss = self.dc_criterion.forward(activated_pred[0], dc_gt)
            dc_ssl_loss = dc_rampup_scale * self.args.dc_ssl_scale * torch.mean(
                dc_ssl_loss)
            self.meters.update('{0}_dc_loss'.format(mid), dc_ssl_loss.data)
        else:
            dc_ssl_loss = 0
            self.meters.update('{0}_dc_loss'.format(mid), dc_ssl_loss)

        with torch.no_grad():
            flawmap_gt = self.fdgt_generator.forward(
                activated_pred[0],
                self.task_func.sslgct_prepare_task_gt_for_fdgt(gt[0]))

        # for validation
        if not is_train:
            fd_loss = self.args.fd_scale * self.fd_criterion.forward(
                flawmap, flawmap_gt)
            self.meters.update('{0}_fd_loss'.format(mid),
                               torch.mean(fd_loss).data)

            self.task_func.metrics(activated_pred,
                                   gt,
                                   inp,
                                   self.meters,
                                   id_str=mid)

        # visualization
        if self.args.visualize and idx % self.args.visual_freq == 0:
            with torch.no_grad():
                handled_flawmap = self.flawmap_handler(flawmap)[0]

            self._visualize(
                epoch, idx, is_train, mid,
                func.split_tensor_tuple(inp, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(activated_pred, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(gt, 0, 1, reduce_dim=True),
                handled_flawmap, flawmap_gt[0], dc_gt[0])

        loss = task_loss + fc_ssl_loss + dc_ssl_loss
        return loss
Example #23
    def _create_dataloader(self):
        """ Create data loaders for experiment.
        """

        # ---------------------------------------------------------------------
        # create the dataloader for training
        # ---------------------------------------------------------------------

        # ignore_unlabeled == False & unlabeled_batch_size != 0
        #   means that both labeled and unlabeled data are used
        with_unlabeled_data = not self.args.ignore_unlabeled and self.args.unlabeled_batch_size != 0
        # ignore_unlabeled == True & unlabeled_batch_size == 0
        #   means that only the labeled data is used
        without_unlabeled_data = self.args.ignore_unlabeled and self.args.unlabeled_batch_size == 0

        labeled_train_samples, unlabeled_train_samples = 0, 0
        if not self.args.validation:
            # ignore_unlabeled == True & unlabeled_batch_size != 0 -> error
            if self.args.ignore_unlabeled and self.args.unlabeled_batch_size != 0:
                logger.log_err(
                    'Arguments conflict => ignore_unlabeled == True requires unlabeled_batch_size == 0\n'
                )
            # ignore_unlabeled == False & unlabeled_batch_size == 0 -> error
            if not self.args.ignore_unlabeled and self.args.unlabeled_batch_size == 0:
                logger.log_err(
                    'Arguments conflict => ignore_unlabeled == False requires unlabeled_batch_size != 0\n'
                )

            # calculate the number of trainsets
            trainset_num = 0
            for key, value in self.args.trainset.items():
                trainset_num += len(value)

            # calculate the number of unlabeledsets
            unlabeledset_num = 0
            for key, value in self.args.unlabeledset.items():
                unlabeledset_num += len(value)

            # if only one labeled training set and without any unlabeled set
            if trainset_num == 1 and unlabeledset_num == 0:
                trainset = self._load_dataset(
                    list(self.args.trainset.keys())[0],
                    list(self.args.trainset.values())[0][0])
                labeled_train_samples = len(trainset.idxs)

                # if the 'sublabeled_path' is given
                sublabeled_prefix = None
                if self.args.sublabeled_path is not None and self.args.sublabeled_path != '':
                    if not os.path.exists(self.args.sublabeled_path):
                        logger.log_err(
                            'Cannot find labeled file: {0}\n'.format(
                                self.args.sublabeled_path))
                    else:
                        with open(self.args.sublabeled_path) as f:
                            sublabeled_prefix = [
                                line.strip() for line in f.read().splitlines()
                            ]
                        sublabeled_prefix = None if len(
                            sublabeled_prefix) == 0 else sublabeled_prefix

                if sublabeled_prefix is not None:
                    # wrap the trainset by 'SplitUnlabeledWrapper'
                    trainset = nndata.SplitUnlabeledWrapper(
                        trainset,
                        sublabeled_prefix,
                        ignore_unlabeled=self.args.ignore_unlabeled)
                    labeled_train_samples = len(trainset.labeled_idxs)
                    unlabeled_train_samples = len(trainset.unlabeled_idxs)

                # if 'sublabeled_prefix' is None but you want to use the unlabeled data for training
                elif with_unlabeled_data:
                    logger.log_err(
                        'Trying to use unlabeled samples without any SSL dataset wrapper\n'
                    )

            # if more than one labeled training set is given, or any unlabeled training set is given
            elif trainset_num > 1 or unlabeledset_num > 0:
                # 'arg.sublabeled_path' is disabled
                if self.args.sublabeled_path is not None and self.args.sublabeled_path != '':
                    logger.log_err(
                        'Multiple training datasets are given. \n'
                        'Inter-split unlabeled set is not allowed.\n'
                        'Please remove the argument \'sublabeled_path\' in the script\n'
                    )

                # load all training sets
                labeled_sets = []
                for set_name, set_dirs in self.args.trainset.items():
                    for set_dir in set_dirs:
                        labeled_sets.append(
                            self._load_dataset(set_name, set_dir))

                # load all extra unlabeled sets
                unlabeled_sets = []
                # if any extra unlabeled set is given
                if unlabeledset_num > 0:
                    for set_name, set_dirs in self.args.unlabeledset.items():
                        for set_dir in set_dirs:
                            unlabeled_sets.append(
                                self._load_dataset(set_name, set_dir))

                # if unlabeledset_num == 0 but you want to use the unlabeled data for training
                elif with_unlabeled_data:
                    logger.log_err(
                        'Trying to use unlabeled samples without any SSL dataset wrapper\n'
                        'Please add the argument \'unlabeledset\' in the script\n'
                    )

                # wrap both 'labeled_set' and 'unlabeled_set' by 'JointDatasetsWrapper'
                trainset = nndata.JointDatasetsWrapper(
                    labeled_sets,
                    unlabeled_sets,
                    ignore_unlabeled=self.args.ignore_unlabeled)
                labeled_train_samples = len(trainset.labeled_idxs)
                unlabeled_train_samples = len(trainset.unlabeled_idxs)

            # if use labeled data only
            if without_unlabeled_data:
                self.train_loader = torch.utils.data.DataLoader(
                    trainset,
                    batch_size=self.args.batch_size,
                    shuffle=True,
                    num_workers=self.args.num_workers,
                    pin_memory=True,
                    drop_last=True)
            # if use both labeled and unlabeled data
            elif with_unlabeled_data:
                train_sampler = nndata.TwoStreamBatchSampler(
                    trainset.labeled_idxs, trainset.unlabeled_idxs,
                    self.args.labeled_batch_size,
                    self.args.unlabeled_batch_size)
                self.train_loader = torch.utils.data.DataLoader(
                    trainset,
                    batch_sampler=train_sampler,
                    num_workers=self.args.num_workers,
                    pin_memory=True)

        # ---------------------------------------------------------------------
        # create dataloader for validation
        # ---------------------------------------------------------------------

        # calculate the number of valsets
        valset, val_samples = None, 0
        valset_num = 0
        for key, value in self.args.valset.items():
            valset_num += len(value)

        # if only one validation set is given
        if valset_num == 1:
            valset = self._load_dataset(list(self.args.valset.keys())[0],
                                        list(self.args.valset.values())[0][0],
                                        is_train=False)
            val_samples = len(valset.idxs)

        # if more than one validation set is given
        elif valset_num > 1:
            valsets = []
            for set_name, set_dirs in self.args.valset.items():
                for set_dir in set_dirs:
                    valsets.append(
                        self._load_dataset(set_name, set_dir, is_train=False))
            valset = nndata.JointDatasetsWrapper(valsets, [],
                                                 ignore_unlabeled=True)
            val_samples = len(valset.labeled_idxs)

        # NOTE: the batch size is set to 1 during validation
        if valset is not None:
            self.val_loader = torch.utils.data.DataLoader(
                valset,
                batch_size=1,
                shuffle=False,
                num_workers=self.args.num_workers,
                pin_memory=True)

        # check the data loaders
        if self.train_loader is None and not self.args.validation:
            logger.log_err(
                'A training data loader is required when validation mode is disabled\n')
        elif self.val_loader is None and self.args.validation:
            logger.log_err(
                'A validation data loader is required when validation mode is enabled\n')
        elif self.val_loader is None:
            logger.log_warn(
                'No validation data loader; no validation will be performed during training\n')

        # set 'iters_per_epoch', which is required by ITER_LRERS
        self.args.iters_per_epoch = len(
            self.train_loader) if self.train_loader is not None else -1

        logger.log_info(
            'Dataset:\n'
            '  Trainset\t=>\tlabeled samples = {0}\t\tunlabeled samples = {1}\n'
            '  Valset\t=>\tsamples = {2}\n'.format(labeled_train_samples,
                                                   unlabeled_train_samples,
                                                   val_samples))
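
With the two-stream sampler, each training batch concatenates labeled samples first and unlabeled samples after them. A sketch of the indexing convention that func.split_tensor_tuple relies on in the other examples, treating inp as a single tensor for simplicity:

lbs = args.labeled_batch_size
ubs = args.unlabeled_batch_size

for inp, gt in train_loader:
    # indices [0, lbs) hold labeled samples, [lbs, lbs + ubs) unlabeled ones
    labeled_inp, unlabeled_inp = inp[:lbs], inp[lbs:lbs + ubs]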
Example #24
    def __init__(self, args):
        super(SSLGCT, self).__init__(args)

        # define the task model and the flaw detector
        self.l_model, self.r_model, self.fd_model = None, None, None
        self.l_optimizer, self.r_optimizer, self.fd_optimizer = None, None, None
        self.l_lrer, self.r_lrer, self.fd_lrer = None, None, None
        self.l_criterion, self.r_criterion, self.fd_criterion = None, None, None

        # define the extra modules required by GCT
        self.flawmap_handler = None
        self.dcgt_generator = None
        self.fdgt_generator = None
        self.zero_df_gt = torch.zeros(
            [self.args.batch_size, 1, self.args.im_size,
             self.args.im_size]).cuda()

        # prepare the arguments for multiple GPUs
        self.args.fd_lr *= self.args.gpus

        # check SSL arguments
        if self.args.unlabeled_batch_size > 0:
            if self.args.ssl_mode in [MODE_GCT, MODE_FC]:
                if self.args.fc_ssl_scale < 0:
                    logger.log_err(
                        'The argument - fc_ssl_scale - is not set (or invalid)\n'
                        'You enable the flaw correction constraint\n'
                        'Please set - fc_ssl_scale >= 0 - for training\n')
            if self.args.ssl_mode in [MODE_GCT, MODE_DC]:
                if self.args.dc_rampup_epochs < 0:
                    logger.log_err(
                        'The argument - dc_rampup_epochs - is not set (or invalid)\n'
                        'You enable the dynamic consistency constraint\n'
                        'Please set - dc_rampup_epochs >= 0 - for training\n')
                elif self.args.dc_ssl_scale < 0:
                    logger.log_err(
                        'The argument - dc_ssl_scale - is not set (or invalid)\n'
                        'You enable the dynamic consistency constraint\n'
                        'Please set - dc_ssl_scale >= 0 - for training\n')
                elif self.args.dc_threshold < 0:
                    logger.log_err(
                        'The argument - dc_threshold - is not set (or invalid)\n'
                        'You enable the dynamic consistency constraint\n'
                        'Please set - dc_threshold >= 0 - for training\n')
                elif not 0 < self.args.mu <= 1:
                    logger.log_err(
                        'The argument - mu - is not set (or invalid)\n'
                        'Please set - 0 < mu <= 1 - for training\n')
                elif self.args.nu <= 0:
                    logger.log_err(
                        'The argument - nu - is not set (or invalid)\n'
                        'Please set - nu > 0 - for training\n')
Example #25
    def _train(self, data_loader, epoch):
        # disable unlabeled data
        without_unlabeled_data = self.args.ignore_unlabeled and self.args.unlabeled_batch_size == 0
        if not without_unlabeled_data:
            logger.log_err(
                'SSL_NULL is a supervised-only algorithm\n'
                'Please set ignore_unlabeled = True and unlabeled_batch_size = 0\n'
            )

        self.meters.reset()
        lbs = self.args.labeled_batch_size

        self.model.train()

        for idx, (inp, gt) in enumerate(data_loader):
            timer = time.time()

            # both 'inp' and 'gt' are tuples
            inp, gt = self._batch_prehandle(inp, gt)
            if len(gt) > 1 and idx == 0:
                self._inp_warn()

            self.optimizer.zero_grad()

            # forward the task model
            resulter, debugger = self.model.forward(inp)
            if 'pred' not in resulter or 'activated_pred' not in resulter:
                self._pred_err()

            pred = tool.dict_value(resulter, 'pred')
            activated_pred = tool.dict_value(resulter, 'activated_pred')

            # calculate the supervised task constraint on the labeled data
            l_pred = func.split_tensor_tuple(pred, 0, lbs)
            l_gt = func.split_tensor_tuple(gt, 0, lbs)
            l_inp = func.split_tensor_tuple(inp, 0, lbs)

            # 'task_loss' is a tensor of 1-dim & n elements, where n == batch_size
            task_loss = self.criterion.forward(l_pred, l_gt, l_inp)
            task_loss = torch.mean(task_loss)
            self.meters.update('task_loss', task_loss.data)

            # backward and update the task model
            loss = task_loss
            loss.backward()
            self.optimizer.step()

            # logging
            self.meters.update('batch_time', time.time() - timer)
            if idx % self.args.log_freq == 0:
                logger.log_info(
                    'step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                    '  task-{3}\t=>\t'
                    'task-loss: {meters[task_loss]:.6f}\t'.format(
                        epoch,
                        idx,
                        len(data_loader),
                        self.args.task,
                        meters=self.meters))

            # visualization
            if self.args.visualize and idx % self.args.visual_freq == 0:
                self._visualize(
                    epoch, idx, True,
                    func.split_tensor_tuple(inp, 0, 1, reduce_dim=True),
                    func.split_tensor_tuple(activated_pred,
                                            0,
                                            1,
                                            reduce_dim=True),
                    func.split_tensor_tuple(gt, 0, 1, reduce_dim=True))

            # update iteration-based lrers
            if not self.args.is_epoch_lrer:
                self.lrer.step()

        # update epoch-based lrers
        if self.args.is_epoch_lrer:
            self.lrer.step()