Example #1
0
    def _build(self, model_funcs, optimizer_funcs, lrer_funcs, criterion_funcs, task_func):
        """Create the task model, the discriminator and everything needed to train them.

        Args:
            model_funcs: sequence whose first element is the task-model
                constructor passed to ``func.create_model``.
            optimizer_funcs: sequence whose first element is called with the
                task model's ``module.param_groups`` to build its optimizer.
            lrer_funcs: sequence whose first element is called with the
                task-model optimizer to build its LR scheduler.
            criterion_funcs: sequence whose first element is called with
                ``self.args`` to build the task criterion.
            task_func: task helper object; stored as ``self.task_func`` and
                asked for the discriminator's input channel count.

        Side effects:
            Sets ``self.models``, ``self.optimizers``, ``self.lrers`` and
            ``self.criterions`` (each a dict keyed by attribute name), the
            individual attributes they contain, and finally calls
            ``self._algorithm_warn()``.
        """
        self.task_func = task_func
        args = self.args

        # ---- models ---------------------------------------------------------
        self.model = func.create_model(model_funcs[0], 'model', args=args)
        self.d_model = func.create_model(
            FCDiscriminator, 'd_model',
            in_channels=self.task_func.ssladv_fcd_in_channels())
        # the `sync_batchnorm` layer only works after 'patch_replication_callback'
        for net in (self.model, self.d_model):
            patch_replication_callback(net)
        self.models = {
            'model': self.model,
            'd_model': self.d_model,
        }

        # ---- optimizers -----------------------------------------------------
        # NOTE(review): `.module` suggests create_model returns a wrapper
        # (presumably nn.DataParallel) — confirm against func.create_model.
        self.optimizer = optimizer_funcs[0](self.model.module.param_groups)
        # only discriminator parameters that require gradients are optimized
        d_trainable = (p for p in self.d_model.parameters() if p.requires_grad)
        self.d_optimizer = optim.Adam(d_trainable,
                                      lr=args.discriminator_lr,
                                      betas=(0.9, 0.99))
        self.optimizers = {
            'optimizer': self.optimizer,
            'd_optimizer': self.d_optimizer,
        }

        # ---- LR schedulers --------------------------------------------------
        self.lrer = lrer_funcs[0](self.optimizer)
        self.d_lrer = PolynomialLR(self.d_optimizer,
                                   args.epochs,
                                   args.iters_per_epoch,
                                   power=args.discriminator_power,
                                   last_epoch=-1)
        self.lrers = {
            'lrer': self.lrer,
            'd_lrer': self.d_lrer,
        }

        # ---- criterions -----------------------------------------------------
        self.criterion = criterion_funcs[0](args)
        self.d_criterion = FCDiscriminatorCriterion()
        self.criterions = {
            'criterion': self.criterion,
            'd_criterion': self.d_criterion,
        }

        self._algorithm_warn()
Example #2
0
    def _build(self, model_funcs, optimizer_funcs, lrer_funcs, criterion_funcs,
               task_func):
        """Create every component used by GCT training.

        Two task models are built: 'l_' denotes the first one and 'r_' the
        second one. A flaw detector ('fd_') is built as well. Each of the three
        gets its own optimizer and LR scheduler. After that come the criterions
        and the extra GCT modules.

        Args:
            model_funcs: task-model constructors; index 0 builds ``l_model`` and
                index 1 builds ``r_model`` (both via ``func.create_model``).
            optimizer_funcs: optimizer factories; index 0 / 1 are called with
                the ``module.param_groups`` of ``l_model`` / ``r_model``.
            lrer_funcs: LR-scheduler factories; index 0 / 1 are called with the
                ``l_optimizer`` / ``r_optimizer``.
            criterion_funcs: criterion factories; index 0 / 1 are called with
                ``self.args`` to build ``l_criterion`` / ``r_criterion``.
            task_func: task helper object; stored as ``self.task_func`` and
                asked for the flaw detector's input channel count.

        Side effects:
            Sets ``self.models``, ``self.optimizers``, ``self.lrers``,
            ``self.criterions`` (dicts keyed by attribute name), the attributes
            they contain, and ``flawmap_handler`` / ``dcgt_generator`` /
            ``fdgt_generator``.
        """
        self.task_func = task_func
        args = self.args

        # ---- models ---------------------------------------------------------
        self.l_model = func.create_model(model_funcs[0], 'l_model', args=args)
        self.r_model = func.create_model(model_funcs[1], 'r_model', args=args)
        fd_in_channels = self.task_func.sslgct_fd_in_channels()
        self.fd_model = func.create_model(FlawDetector, 'fd_model',
                                          in_channels=fd_in_channels)
        # the `sync_batchnorm` layer only works after 'patch_replication_callback'
        for net in (self.l_model, self.r_model, self.fd_model):
            patch_replication_callback(net)
        self.models = {'l_model': self.l_model,
                       'r_model': self.r_model,
                       'fd_model': self.fd_model}

        # ---- optimizers -----------------------------------------------------
        # NOTE(review): `.module` suggests create_model returns a wrapper
        # (presumably nn.DataParallel) — confirm against func.create_model.
        self.l_optimizer = optimizer_funcs[0](self.l_model.module.param_groups)
        self.r_optimizer = optimizer_funcs[1](self.r_model.module.param_groups)
        # only flaw-detector parameters that require gradients are optimized
        fd_trainable = (p for p in self.fd_model.parameters() if p.requires_grad)
        self.fd_optimizer = optim.Adam(fd_trainable, lr=args.fd_lr,
                                       betas=(0.9, 0.99))
        self.optimizers = {'l_optimizer': self.l_optimizer,
                           'r_optimizer': self.r_optimizer,
                           'fd_optimizer': self.fd_optimizer}

        # ---- LR schedulers --------------------------------------------------
        self.l_lrer = lrer_funcs[0](self.l_optimizer)
        self.r_lrer = lrer_funcs[1](self.r_optimizer)
        fd_lr_power = 0.9  # hard-coded power for the flaw detector's PolynomialLR
        self.fd_lrer = PolynomialLR(self.fd_optimizer, args.epochs,
                                    args.iters_per_epoch,
                                    power=fd_lr_power, last_epoch=-1)
        self.lrers = {'l_lrer': self.l_lrer,
                      'r_lrer': self.r_lrer,
                      'fd_lrer': self.fd_lrer}

        # ---- criterions -----------------------------------------------------
        self.l_criterion = criterion_funcs[0](args)
        self.r_criterion = criterion_funcs[1](args)
        self.fd_criterion = FlawDetectorCriterion()
        self.dc_criterion = torch.nn.MSELoss()
        self.criterions = {'l_criterion': self.l_criterion,
                           'r_criterion': self.r_criterion,
                           'fd_criterion': self.fd_criterion,
                           'dc_criterion': self.dc_criterion}

        # ---- extra modules required by GCT ----------------------------------
        def _parallel_on_gpu(module):
            # wrap in nn.DataParallel, then move to the GPU with .cuda()
            return nn.DataParallel(module).cuda()

        self.flawmap_handler = _parallel_on_gpu(FlawmapHandler(args))
        self.dcgt_generator = _parallel_on_gpu(DCGTGenerator(args))
        self.fdgt_generator = _parallel_on_gpu(FDGTGenerator(args))