Example #1
0
    def _build(self, model_funcs, optimizer_funcs, lrer_funcs, criterion_funcs, task_func):
        """Instantiate the task model and the FC discriminator, together with
        their optimizers, lr schedulers, and criterions.

        Args:
            model_funcs: factories for the task model (index 0 is used).
            optimizer_funcs: factories for the task-model optimizer (index 0).
            lrer_funcs: factories for the task-model lr scheduler (index 0).
            criterion_funcs: factories for the task criterion (index 0).
            task_func: task-specific helper object; provides the
                discriminator's input channel count.
        """
        self.task_func = task_func

        # models: the task model plus an FC discriminator
        self.model = func.create_model(model_funcs[0], 'model', args=self.args)
        fcd_in_channels = self.task_func.ssladv_fcd_in_channels()
        self.d_model = func.create_model(FCDiscriminator, 'd_model', in_channels=fcd_in_channels)
        # enable the `sync_batchnorm` layer on both models
        for wrapped_model in (self.model, self.d_model):
            patch_replication_callback(wrapped_model)
        self.models = {'model': self.model, 'd_model': self.d_model}

        # optimizers: the discriminator gets its own Adam with a dedicated lr
        self.optimizer = optimizer_funcs[0](self.model.module.param_groups)
        d_params = filter(lambda p: p.requires_grad, self.d_model.parameters())
        self.d_optimizer = optim.Adam(d_params, lr=self.args.discriminator_lr, betas=(0.9, 0.99))
        self.optimizers = {'optimizer': self.optimizer, 'd_optimizer': self.d_optimizer}

        # lr schedulers: the discriminator's lr decays polynomially
        self.lrer = lrer_funcs[0](self.optimizer)
        self.d_lrer = PolynomialLR(self.d_optimizer, self.args.epochs, self.args.iters_per_epoch,
                                   power=self.args.discriminator_power, last_epoch=-1)
        self.lrers = {'lrer': self.lrer, 'd_lrer': self.d_lrer}

        # criterions
        self.criterion = criterion_funcs[0](self.args)
        self.d_criterion = FCDiscriminatorCriterion()
        self.criterions = {'criterion': self.criterion, 'd_criterion': self.d_criterion}

        self._algorithm_warn()
Example #2
0
    def _build(self, model_funcs, optimizer_funcs, lrer_funcs, criterion_funcs, task_func):
        """Instantiate the student/teacher model pair, the student's
        optimizer, lr scheduler, and criterions, plus the gaussian noiser.

        Args:
            model_funcs: factories for the models (index 0 builds both the
                student and the teacher).
            optimizer_funcs: factories for the student optimizer (index 0).
            lrer_funcs: factories for the student lr scheduler (index 0).
            criterion_funcs: factories for the student criterion (index 0).
            task_func: task-specific helper object.
        """
        self.task_func = task_func

        # models: student and teacher built from the same model func
        self.s_model = func.create_model(model_funcs[0], 's_model', args=self.args)
        self.t_model = func.create_model(model_funcs[0], 't_model', args=self.args)
        # enable the `sync_batchnorm` layer on both models
        for wrapped_model in (self.s_model, self.t_model):
            patch_replication_callback(wrapped_model)
        # detach the teacher's parameters so no gradients flow into them
        for param in self.t_model.parameters():
            param.detach_()
        self.models = {'s_model': self.s_model, 't_model': self.t_model}

        # optimizer / lr scheduler: only the student is optimized directly
        self.s_optimizer = optimizer_funcs[0](self.s_model.module.param_groups)
        self.optimizers = {'s_optimizer': self.s_optimizer}

        self.s_lrer = lrer_funcs[0](self.s_optimizer)
        self.lrers = {'s_lrer': self.s_lrer}

        # criterions: task criterion for the student + MSE consistency loss
        # TODO: support more types of the consistency criterion
        self.s_criterion = criterion_funcs[0](self.args)
        self.cons_criterion = nn.MSELoss()
        self.criterions = {'s_criterion': self.s_criterion, 'cons_criterion': self.cons_criterion}

        # gaussian noise layer (std taken from args), data-parallel on GPU
        self.gaussian_noiser = nn.DataParallel(GaussianNoiseLayer(self.args.gaussian_noise_std)).cuda()

        self._algorithm_warn()
Example #3
0
    def _build(self, model_funcs, optimizer_funcs, lrer_funcs, criterion_funcs,
               task_func):
        """Instantiate the student/teacher models, the student's optimizer,
        lr scheduler, the criterions, and the CutMix mask generator.

        Args:
            model_funcs: factories for the models (index 0 builds both the
                student and the teacher).
            optimizer_funcs: factories for the student optimizer (index 0).
            lrer_funcs: factories for the student lr scheduler (index 0).
            criterion_funcs: factories for the student criterion (index 0).
            task_func: task-specific helper object.

        Raises:
            NotImplementedError: if `self.args.cons_type` is not a supported
                consistency criterion type (currently only 'mse').
        """
        self.task_func = task_func

        # create models
        self.s_model = func.create_model(model_funcs[0],
                                         's_model',
                                         args=self.args)
        self.t_model = func.create_model(model_funcs[0],
                                         't_model',
                                         args=self.args)
        # call 'patch_replication_callback' to use the `sync_batchnorm` layer
        patch_replication_callback(self.s_model)
        patch_replication_callback(self.t_model)
        # detach the teacher model so no gradients flow into its parameters
        for param in self.t_model.parameters():
            param.detach_()
        self.models = {'s_model': self.s_model, 't_model': self.t_model}

        # create optimizers (only the student model is optimized directly)
        self.s_optimizer = optimizer_funcs[0](self.s_model.module.param_groups)
        self.optimizers = {'s_optimizer': self.s_optimizer}

        # create lrers
        self.s_lrer = lrer_funcs[0](self.s_optimizer)
        self.lrers = {'s_lrer': self.s_lrer}

        # create criterions
        self.s_criterion = criterion_funcs[0](self.args)
        # TODO: support more types of the consistency criterion
        if self.args.cons_type == 'mse':
            self.cons_criterion = nn.MSELoss()
        else:
            # an unsupported type previously left `self.cons_criterion`
            # unassigned and crashed below with an obscure AttributeError;
            # fail fast with an explicit message instead
            raise NotImplementedError(
                'Unsupported consistency criterion type: {0}'.format(self.args.cons_type))
        self.criterions = {
            's_criterion': self.s_criterion,
            'cons_criterion': self.cons_criterion
        }

        # build the auxiliary modules required by CUTMIX
        # NOTE: this setting follows the original paper of CUTMIX
        self.mask_generator = BoxMaskGenerator(
            prop_range=self.args.mask_prop_range,
            boxes_num=1,
            random_aspect_ratio=True,
            area_prop=True,
            within_bounds=True,
            invert=True)

        self._algorithm_warn()
Example #4
0
    def _build(self, model_funcs, optimizer_funcs, lrer_funcs, criterion_funcs,
               task_func):
        """Instantiate the wrapped S4L model (task model + rotation
        classifier) and its optimizer, lr scheduler, and criterions.

        Also doubles the batch-size arguments, since S4L creates one extra
        rotated sample per input sample.

        Args:
            model_funcs: factories for the task model (index 0 is used).
            optimizer_funcs: factories for the optimizer (index 0).
            lrer_funcs: factories for the lr scheduler (index 0).
            criterion_funcs: factories for the task criterion (index 0).
            task_func: task-specific helper object; provides the rotation
                classifier's input channel count.
        """
        self.task_func = task_func

        # create models
        # `.module` unwraps the DataParallel container returned by create_model
        self.task_model = func.create_model(model_funcs[0],
                                            'task_model',
                                            args=self.args).module
        self.rotation_classifier = RotationClassifer(
            self.task_func.ssls4l_rc_in_channels())

        # wrap 'self.task_model' and 'self.rotation_classifier' into a single model
        self.model = WrappedS4LModel(self.args, self.task_model,
                                     self.rotation_classifier)
        self.model = nn.DataParallel(self.model).cuda()

        # call 'patch_replication_callback' to use the `sync_batchnorm` layer
        patch_replication_callback(self.model)
        self.models = {'model': self.model}

        # create optimizers
        self.optimizer = optimizer_funcs[0](self.model.module.param_groups)
        self.optimizers = {'optimizer': self.optimizer}

        # create lrers
        self.lrer = lrer_funcs[0](self.optimizer)
        self.lrers = {'lrer': self.lrer}

        # create criterions
        self.criterion = criterion_funcs[0](self.args)
        self.rotation_criterion = nn.CrossEntropyLoss()
        self.criterions = {
            'criterion': self.criterion,
            'rotation_criterion': self.rotation_criterion
        }

        # the batch size is doubled in S4L since it creates an extra rotated sample for each sample
        self.args.batch_size *= 2
        self.args.labeled_batch_size *= 2
        self.args.unlabeled_batch_size *= 2

        # FIX: the original call passed a spurious first argument
        # (`self.args.lr`) and used 1-based format indices `{1}`/`{2}`;
        # the rendered message is unchanged, only the arguments are cleaned up
        logger.log_info('In SSL_S4L algorithm, batch size are doubled: \n'
                        '  Total labeled batch size: {0}\n'
                        '  Total unlabeled batch size: {1}\n'.format(
                            self.args.labeled_batch_size,
                            self.args.unlabeled_batch_size))

        self._algorithm_warn()
Example #5
0
    def _build(self, model_funcs, optimizer_funcs, lrer_funcs, criterion_funcs,
               task_func):
        """Instantiate the single task model together with its optimizer,
        lr scheduler, and criterion.

        Args:
            model_funcs: factories for the task model (index 0 is used).
            optimizer_funcs: factories for the optimizer (index 0).
            lrer_funcs: factories for the lr scheduler (index 0).
            criterion_funcs: factories for the criterion (index 0).
            task_func: task-specific helper object.
        """
        self.task_func = task_func

        # model (patched to enable the `sync_batchnorm` layer)
        model = func.create_model(model_funcs[0], 'model', args=self.args)
        patch_replication_callback(model)
        self.model = model
        self.models = {'model': model}

        # optimizer
        optimizer = optimizer_funcs[0](model.module.param_groups)
        self.optimizer = optimizer
        self.optimizers = {'optimizer': optimizer}

        # lr scheduler
        lrer = lrer_funcs[0](optimizer)
        self.lrer = lrer
        self.lrers = {'lrer': lrer}

        # criterion
        criterion = criterion_funcs[0](self.args)
        self.criterion = criterion
        self.criterions = {'criterion': criterion}

        self._algorithm_warn()
Example #6
0
    def _build(self, model_funcs, optimizer_funcs, lrer_funcs, criterion_funcs,
               task_func):
        """Instantiate the two task models and the flaw detector, their
        optimizers, lr schedulers, criterions, and the extra GCT modules.

        Args:
            model_funcs: factories for the two task models (indices 0 and 1).
            optimizer_funcs: factories for the two task optimizers (0 and 1).
            lrer_funcs: factories for the two task lr schedulers (0 and 1).
            criterion_funcs: factories for the two task criterions (0 and 1).
            task_func: task-specific helper object; provides the flaw
                detector's input channel count.
        """
        self.task_func = task_func

        # models: 'l_' is the first task model, 'r_' is the second one,
        # and 'fd_' is the flaw detector
        self.l_model = func.create_model(model_funcs[0], 'l_model', args=self.args)
        self.r_model = func.create_model(model_funcs[1], 'r_model', args=self.args)
        self.fd_model = func.create_model(
            FlawDetector, 'fd_model',
            in_channels=self.task_func.sslgct_fd_in_channels())
        # enable the `sync_batchnorm` layer on every model
        for wrapped_model in (self.l_model, self.r_model, self.fd_model):
            patch_replication_callback(wrapped_model)
        self.models = {
            'l_model': self.l_model,
            'r_model': self.r_model,
            'fd_model': self.fd_model
        }

        # optimizers: the flaw detector gets its own Adam with a dedicated lr
        self.l_optimizer = optimizer_funcs[0](self.l_model.module.param_groups)
        self.r_optimizer = optimizer_funcs[1](self.r_model.module.param_groups)
        fd_params = filter(lambda p: p.requires_grad, self.fd_model.parameters())
        self.fd_optimizer = optim.Adam(fd_params, lr=self.args.fd_lr, betas=(0.9, 0.99))
        self.optimizers = {
            'l_optimizer': self.l_optimizer,
            'r_optimizer': self.r_optimizer,
            'fd_optimizer': self.fd_optimizer
        }

        # lr schedulers: the flaw detector's lr decays polynomially
        self.l_lrer = lrer_funcs[0](self.l_optimizer)
        self.r_lrer = lrer_funcs[1](self.r_optimizer)
        self.fd_lrer = PolynomialLR(self.fd_optimizer, self.args.epochs,
                                    self.args.iters_per_epoch,
                                    power=0.9, last_epoch=-1)
        self.lrers = {
            'l_lrer': self.l_lrer,
            'r_lrer': self.r_lrer,
            'fd_lrer': self.fd_lrer
        }

        # criterions
        self.l_criterion = criterion_funcs[0](self.args)
        self.r_criterion = criterion_funcs[1](self.args)
        self.fd_criterion = FlawDetectorCriterion()
        self.dc_criterion = torch.nn.MSELoss()
        self.criterions = {
            'l_criterion': self.l_criterion,
            'r_criterion': self.r_criterion,
            'fd_criterion': self.fd_criterion,
            'dc_criterion': self.dc_criterion
        }

        # extra modules required by GCT, each data-parallel on GPU
        self.flawmap_handler = nn.DataParallel(FlawmapHandler(self.args)).cuda()
        self.dcgt_generator = nn.DataParallel(DCGTGenerator(self.args)).cuda()
        self.fdgt_generator = nn.DataParallel(FDGTGenerator(self.args)).cuda()
Example #7
0
    def _build(self, model_funcs, optimizer_funcs, lrer_funcs, criterion_funcs, task_func):
        """Instantiate the wrapped CCT model (main model + auxiliary
        decoders) and its optimizer, lr scheduler, and criterions.

        Args:
            model_funcs: factories for the main task model (index 0 is used).
            optimizer_funcs: factories for the optimizer (index 0).
            lrer_funcs: factories for the lr scheduler (index 0).
            criterion_funcs: factories for the task criterion (index 0).
            task_func: task-specific helper object; provides the auxiliary
                decoders' upsample scale and channel counts.
        """
        self.task_func = task_func

        # create criterions
        # TODO: support more types of the consistency criterion
        self.cons_criterion = nn.MSELoss()
        self.criterion = criterion_funcs[0](self.args)
        self.criterions = {'criterion': self.criterion, 'cons_criterion': self.cons_criterion}

        # main task model (`.module` unwraps the DataParallel container)
        self.main_model = func.create_model(model_funcs[0], 'main_model', args=self.args).module

        # positional arguments shared by every auxiliary decoder
        # (queried per decoder, as in the original construction)
        def _ad_args():
            return (self.task_func.sslcct_ad_upsample_scale(),
                    self.task_func.sslcct_ad_in_channels(),
                    self.task_func.sslcct_ad_out_channels())

        # build the auxiliary decoders, grouped by perturbation type
        # NOTE: the group order (vat, drop, cut, context, object,
        # feature-drop, feature-noise) matches the original list layout
        decoders = []
        for _ in range(self.args.vat_dec_num):
            decoders.append(VATDecoder(*_ad_args(),
                                       xi=self.args.vat_dec_xi,
                                       eps=self.args.vat_dec_eps))
        for _ in range(self.args.drop_dec_num):
            decoders.append(DropOutDecoder(*_ad_args(),
                                           drop_rate=self.args.drop_dec_rate,
                                           spatial_dropout=self.args.drop_dec_spatial))
        for _ in range(self.args.cut_dec_num):
            decoders.append(CutOutDecoder(*_ad_args(), erase=self.args.cut_dec_erase))
        for _ in range(self.args.context_dec_num):
            decoders.append(ContextMaskingDecoder(*_ad_args()))
        for _ in range(self.args.object_dec_num):
            decoders.append(ObjectMaskingDecoder(*_ad_args()))
        for _ in range(self.args.fd_dec_num):
            decoders.append(FeatureDropDecoder(*_ad_args()))
        for _ in range(self.args.fn_dec_num):
            decoders.append(FeatureNoiseDecoder(*_ad_args(),
                                                uniform_range=self.args.fn_dec_uniform))
        self.auxiliary_decoders = nn.ModuleList(decoders)

        # wrap 'self.main_model' and 'self.auxiliary decoders' into a single model
        # NOTE: all criterions are wrapped into the model to save the memory of the main GPU
        self.model = WrappedCCTModel(self.args, self.main_model, self.auxiliary_decoders,
                                     self.criterion, self.cons_criterion,
                                     self.task_func.sslcct_activate_ad_preds)
        self.model = nn.DataParallel(self.model).cuda()
        # call 'patch_replication_callback' to use the `sync_batchnorm` layer
        patch_replication_callback(self.model)
        self.models = {'model': self.model}

        # create optimizers
        self.optimizer = optimizer_funcs[0](self.model.module.param_groups)
        self.optimizers = {'optimizer': self.optimizer}

        # create lrers
        self.lrer = lrer_funcs[0](self.optimizer)
        self.lrers = {'lrer': self.lrer}

        self._algorithm_warn()