Example #1
    def _build(self, model_funcs, optimizer_funcs, lrer_funcs, criterion_funcs, task_func):
        self.task_func = task_func

        # create models
        self.model = func.create_model(model_funcs[0], 'model', args=self.args)
        self.d_model = func.create_model(FCDiscriminator, 'd_model', in_channels=self.task_func.ssladv_fcd_in_channels())
        # call 'patch_replication_callback' to enable the `sync_batchnorm` layer
        patch_replication_callback(self.model)
        patch_replication_callback(self.d_model)
        self.models = {'model': self.model, 'd_model': self.d_model}

        # create optimizers
        self.optimizer = optimizer_funcs[0](self.model.module.param_groups)
        self.d_optimizer = optim.Adam(filter(lambda p: p.requires_grad, self.d_model.parameters()), 
                                      lr=self.args.discriminator_lr, betas=(0.9, 0.99))
        self.optimizers = {'optimizer': self.optimizer, 'd_optimizer': self.d_optimizer}

        # create lrers
        self.lrer = lrer_funcs[0](self.optimizer)
        self.d_lrer = PolynomialLR(self.d_optimizer, self.args.epochs, self.args.iters_per_epoch, 
                                   power=self.args.discriminator_power, last_epoch=-1)
        self.lrers = {'lrer': self.lrer, 'd_lrer': self.d_lrer}

        # create criterions
        self.criterion = criterion_funcs[0](self.args)
        self.d_criterion = FCDiscriminatorCriterion()
        self.criterions = {'criterion': self.criterion, 'd_criterion': self.d_criterion}

        self._algorithm_warn()
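The discriminator branch above is scheduled with PolynomialLR, stepped once per training iteration. Below is a minimal sketch of what such a per-iteration polynomial decay typically computes, assuming the common `(1 - iter / max_iter) ** power` rule; the class name `PolynomialLRSketch` is hypothetical and PixelSSL's actual `PolynomialLR` may differ in details.

# Hedged sketch: per-iteration polynomial LR decay, assuming the common
# (1 - iter / max_iter) ** power rule; not PixelSSL's actual implementation.
from torch.optim.lr_scheduler import _LRScheduler

class PolynomialLRSketch(_LRScheduler):
    def __init__(self, optimizer, epochs, iters_per_epoch, power=0.9, last_epoch=-1):
        self.max_iters = max(epochs * iters_per_epoch, 1)
        self.power = power
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        # 'last_epoch' counts scheduler steps, i.e., training iterations here
        progress = min(self.last_epoch, self.max_iters) / self.max_iters
        scale = (1.0 - progress) ** self.power
        return [base_lr * scale for base_lr in self.base_lrs]

Constructed like the `d_lrer` above, e.g. `PolynomialLRSketch(d_optimizer, args.epochs, args.iters_per_epoch, power=args.discriminator_power)`, and stepped once per batch.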
Example #2
    def _build(self, model_funcs, optimizer_funcs, lrer_funcs, criterion_funcs,
               task_func):
        self.task_func = task_func

        # create models
        # 'l_' denotes the first task model while 'r_' denotes the second task model
        self.l_model = func.create_model(model_funcs[0],
                                         'l_model',
                                         args=self.args)
        self.r_model = func.create_model(model_funcs[1],
                                         'r_model',
                                         args=self.args)
        self.fd_model = func.create_model(
            FlawDetector,
            'fd_model',
            in_channels=self.task_func.sslgct_fd_in_channels())
        # call 'patch_replication_callback' to enable the `sync_batchnorm` layer
        patch_replication_callback(self.l_model)
        patch_replication_callback(self.r_model)
        patch_replication_callback(self.fd_model)
        self.models = {
            'l_model': self.l_model,
            'r_model': self.r_model,
            'fd_model': self.fd_model
        }

        # create optimizers
        self.l_optimizer = optimizer_funcs[0](self.l_model.module.param_groups)
        self.r_optimizer = optimizer_funcs[1](self.r_model.module.param_groups)
        self.fd_optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                              self.fd_model.parameters()),
                                       lr=self.args.fd_lr,
                                       betas=(0.9, 0.99))
        self.optimizers = {
            'l_optimizer': self.l_optimizer,
            'r_optimizer': self.r_optimizer,
            'fd_optimizer': self.fd_optimizer
        }

        # create lrers
        self.l_lrer = lrer_funcs[0](self.l_optimizer)
        self.r_lrer = lrer_funcs[1](self.r_optimizer)
        self.fd_lrer = PolynomialLR(self.fd_optimizer,
                                    self.args.epochs,
                                    self.args.iters_per_epoch,
                                    power=0.9,
                                    last_epoch=-1)
        self.lrers = {
            'l_lrer': self.l_lrer,
            'r_lrer': self.r_lrer,
            'fd_lrer': self.fd_lrer
        }

        # create criterions
        self.l_criterion = criterion_funcs[0](self.args)
        self.r_criterion = criterion_funcs[1](self.args)
        self.fd_criterion = FlawDetectorCriterion()
        self.dc_criterion = torch.nn.MSELoss()
        self.criterions = {
            'l_criterion': self.l_criterion,
            'r_criterion': self.r_criterion,
            'fd_criterion': self.fd_criterion,
            'dc_criterion': self.dc_criterion
        }

        # build the extra modules required by GCT
        self.flawmap_handler = nn.DataParallel(FlawmapHandler(
            self.args)).cuda()
        self.dcgt_generator = nn.DataParallel(DCGTGenerator(self.args)).cuda()
        self.fdgt_generator = nn.DataParallel(FDGTGenerator(self.args)).cuda()
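The `_build` above immediately accesses `self.l_model.module.param_groups` and calls `patch_replication_callback` on each model, which implies that `func.create_model` returns the network wrapped in `nn.DataParallel`. A minimal sketch of such a helper follows; the name `create_model_sketch` and its signature are assumptions, not PixelSSL's actual `func.create_model`.

# Hedged sketch of a 'create_model'-style helper: instantiate the network and
# wrap it in nn.DataParallel so that '.module' access and the sync-batchnorm
# replication patch used above can be applied.
import torch.nn as nn

def create_model_sketch(model_func, name, **kwargs):
    model = model_func(**kwargs)           # build the underlying network
    model = nn.DataParallel(model).cuda()  # replicate across the visible GPUs
    print('Created model: {0}'.format(name))
    return model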
Example #3
class SSLGCT(ssl_base._SSLBase):
    NAME = 'ssl_gct'
    SUPPORTED_TASK_TYPES = [REGRESSION, CLASSIFICATION]

    def __init__(self, args):
        super(SSLGCT, self).__init__(args)

        # define the task model and the flaw detector
        self.l_model, self.r_model, self.fd_model = None, None, None
        self.l_optimizer, self.r_optimizer, self.fd_optimizer = None, None, None
        self.l_lrer, self.r_lrer, self.fd_lrer = None, None, None
        self.l_criterion, self.r_criterion, self.fd_criterion = None, None, None

        # define the extra modules required by GCT
        self.flawmap_handler = None
        self.dcgt_generator = None
        self.fdgt_generator = None
        self.zero_df_gt = torch.zeros(
            [self.args.batch_size, 1, self.args.im_size,
             self.args.im_size]).cuda()

        # prepare the arguments for multiple GPUs
        self.args.fd_lr *= self.args.gpus

        # check SSL arguments
        if self.args.unlabeled_batch_size > 0:
            if self.args.ssl_mode in [MODE_GCT, MODE_FC]:
                if self.args.fc_ssl_scale < 0:
                    logger.log_err(
                        'The argument - fc_ssl_scale - is not set (or invalid)\n'
                        'You have enabled the flaw correction constraint\n'
                        'Please set - fc_ssl_scale >= 0 - for training\n')
            if self.args.ssl_mode in [MODE_GCT, MODE_DC]:
                if self.args.dc_rampup_epochs < 0:
                    logger.log_err(
                        'The argument - dc_rampup_epochs - is not set (or invalid)\n'
                        'You have enabled the dynamic consistency constraint\n'
                        'Please set - dc_rampup_epochs >= 0 - for training\n')
                elif self.args.dc_ssl_scale < 0:
                    logger.log_err(
                        'The argument - dc_ssl_scale - is not set (or invalid)\n'
                        'You have enabled the dynamic consistency constraint\n'
                        'Please set - dc_ssl_scale >= 0 - for training\n')
                elif self.args.dc_threshold < 0:
                    logger.log_err(
                        'The argument - dc_threshold - is not set (or invalid)\n'
                        'You have enabled the dynamic consistency constraint\n'
                        'Please set - dc_threshold >= 0 - for training\n')
                elif self.args.mu < 0:
                    logger.log_err(
                        'The argument - mu - is not set (or invalid)\n'
                        'Please set - 0 < mu <= 1 - for training\n')
                elif self.args.nu < 0:
                    logger.log_err(
                        'The argument - nu - is not set (or invalid)\n'
                        'Please set - nu > 0 - for training\n')

    def _build(self, model_funcs, optimizer_funcs, lrer_funcs, criterion_funcs,
               task_func):
        self.task_func = task_func

        # create models
        # 'l_' denotes the first task model while 'r_' denotes the second task model
        self.l_model = func.create_model(model_funcs[0],
                                         'l_model',
                                         args=self.args)
        self.r_model = func.create_model(model_funcs[1],
                                         'r_model',
                                         args=self.args)
        self.fd_model = func.create_model(
            FlawDetector,
            'fd_model',
            in_channels=self.task_func.sslgct_fd_in_channels())
        # call 'patch_replication_callback' to enable the `sync_batchnorm` layer
        patch_replication_callback(self.l_model)
        patch_replication_callback(self.r_model)
        patch_replication_callback(self.fd_model)
        self.models = {
            'l_model': self.l_model,
            'r_model': self.r_model,
            'fd_model': self.fd_model
        }

        # create optimizers
        self.l_optimizer = optimizer_funcs[0](self.l_model.module.param_groups)
        self.r_optimizer = optimizer_funcs[1](self.r_model.module.param_groups)
        self.fd_optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                              self.fd_model.parameters()),
                                       lr=self.args.fd_lr,
                                       betas=(0.9, 0.99))
        self.optimizers = {
            'l_optimizer': self.l_optimizer,
            'r_optimizer': self.r_optimizer,
            'fd_optimizer': self.fd_optimizer
        }

        # create lrers
        self.l_lrer = lrer_funcs[0](self.l_optimizer)
        self.r_lrer = lrer_funcs[1](self.r_optimizer)
        self.fd_lrer = PolynomialLR(self.fd_optimizer,
                                    self.args.epochs,
                                    self.args.iters_per_epoch,
                                    power=0.9,
                                    last_epoch=-1)
        self.lrers = {
            'l_lrer': self.l_lrer,
            'r_lrer': self.r_lrer,
            'fd_lrer': self.fd_lrer
        }

        # create criterions
        self.l_criterion = criterion_funcs[0](self.args)
        self.r_criterion = criterion_funcs[1](self.args)
        self.fd_criterion = FlawDetectorCriterion()
        self.dc_criterion = torch.nn.MSELoss()
        self.criterions = {
            'l_criterion': self.l_criterion,
            'r_criterion': self.r_criterion,
            'fd_criterion': self.fd_criterion,
            'dc_criterion': self.dc_criterion
        }

        # build the extra modules required by GCT
        self.flawmap_handler = nn.DataParallel(FlawmapHandler(
            self.args)).cuda()
        self.dcgt_generator = nn.DataParallel(DCGTGenerator(self.args)).cuda()
        self.fdgt_generator = nn.DataParallel(FDGTGenerator(self.args)).cuda()

    def _train(self, data_loader, epoch):
        self.meters.reset()
        lbs = self.args.labeled_batch_size

        self.l_model.train()
        self.r_model.train()
        self.fd_model.train()

        # both 'inp' and 'gt' are tuples
        for idx, (inp, gt) in enumerate(data_loader):
            timer = time.time()

            (l_inp, l_gt), (r_inp, r_gt) = self._batch_prehandle(inp, gt)
            if len(l_gt) == len(r_gt) > 1 and idx == 0:
                self._inp_warn()

            # calculate the ramp-up coefficient of the dynamic consistency constraint
            cur_steps = len(data_loader) * epoch + idx
            total_steps = len(data_loader) * self.args.dc_rampup_epochs
            dc_rampup_scale = func.sigmoid_rampup(cur_steps, total_steps)

            # -----------------------------------------------------------------------------
            # step-0: pre-forwarding to save GPU memory
            #   - forward the task models and the flaw detector
            #   - generate pseudo ground truth for the unlabeled data if the dynamic
            #     consistency constraint is enabled
            # -----------------------------------------------------------------------------
            with torch.no_grad():
                l_resulter, l_debugger = self.l_model.forward(l_inp)
                l_activated_pred = tool.dict_value(l_resulter,
                                                   'activated_pred')
                r_resulter, r_debugger = self.r_model.forward(r_inp)
                r_activated_pred = tool.dict_value(r_resulter,
                                                   'activated_pred')

            # 'l_flawmap' and 'r_flawmap' will be used in step-2
            fd_resulter, fd_debugger = self.fd_model.forward(
                l_inp, l_activated_pred[0])
            l_flawmap = tool.dict_value(fd_resulter, 'flawmap')
            fd_resulter, fd_debugger = self.fd_model.forward(
                r_inp, r_activated_pred[0])
            r_flawmap = tool.dict_value(fd_resulter, 'flawmap')

            l_dc_gt, r_dc_gt = None, None
            l_fc_mask, r_fc_mask = None, None

            # generate the pseudo ground truth for the dynamic consistency constraint
            if self.args.ssl_mode in [MODE_GCT, MODE_DC]:
                with torch.no_grad():
                    l_handled_flawmap = self.flawmap_handler.forward(l_flawmap)
                    r_handled_flawmap = self.flawmap_handler.forward(r_flawmap)
                    l_dc_gt, r_dc_gt, l_fc_mask, r_fc_mask = self.dcgt_generator.forward(
                        l_activated_pred[0].detach(),
                        r_activated_pred[0].detach(), l_handled_flawmap,
                        r_handled_flawmap)

            # -----------------------------------------------------------------------------
            # step-1: train the task models
            # -----------------------------------------------------------------------------
            for param in self.fd_model.parameters():
                param.requires_grad = False

            # train the 'l' task model
            l_loss = self._task_model_iter(epoch, idx, True, 'l', lbs, l_inp,
                                           l_gt, l_dc_gt, l_fc_mask,
                                           dc_rampup_scale)
            self.l_optimizer.zero_grad()
            l_loss.backward()
            self.l_optimizer.step()

            # train the 'r' task model
            r_loss = self._task_model_iter(epoch, idx, True, 'r', lbs, r_inp,
                                           r_gt, r_dc_gt, r_fc_mask,
                                           dc_rampup_scale)
            self.r_optimizer.zero_grad()
            r_loss.backward()
            self.r_optimizer.step()

            # -----------------------------------------------------------------------------
            # step-2: train the flaw detector
            # -----------------------------------------------------------------------------
            for param in self.fd_model.parameters():
                param.requires_grad = True

            # generate the ground truth for the flaw detector (on labeled data only)
            with torch.no_grad():
                l_flawmap_gt = self.fdgt_generator.forward(
                    l_activated_pred[0][:lbs, ...].detach(),
                    self.task_func.sslgct_prepare_task_gt_for_fdgt(
                        l_gt[0][:lbs, ...]))
                r_flawmap_gt = self.fdgt_generator.forward(
                    r_activated_pred[0][:lbs, ...].detach(),
                    self.task_func.sslgct_prepare_task_gt_for_fdgt(
                        r_gt[0][:lbs, ...]))

            l_fd_loss = self.fd_criterion.forward(l_flawmap[:lbs, ...],
                                                  l_flawmap_gt)
            l_fd_loss = self.args.fd_scale * torch.mean(l_fd_loss)
            self.meters.update('l_fd_loss', l_fd_loss.data)

            r_fd_loss = self.fd_criterion.forward(r_flawmap[:lbs, ...],
                                                  r_flawmap_gt)
            r_fd_loss = self.args.fd_scale * torch.mean(r_fd_loss)
            self.meters.update('r_fd_loss', r_fd_loss.data)

            fd_loss = (l_fd_loss + r_fd_loss) / 2

            self.fd_optimizer.zero_grad()
            fd_loss.backward()
            self.fd_optimizer.step()

            # logging
            self.meters.update('batch_time', time.time() - timer)
            if idx % self.args.log_freq == 0:
                logger.log_info(
                    'step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                    '  l-{3}\t=>\t'
                    'l-task-loss: {meters[l_task_loss]:.6f}\t'
                    'l-dc-loss: {meters[l_dc_loss]:.6f}\t'
                    'l-fc-loss: {meters[l_fc_loss]:.6f}\n'
                    '  r-{3}\t=>\t'
                    'r-task-loss: {meters[r_task_loss]:.6f}\t'
                    'r-dc-loss: {meters[r_dc_loss]:.6f}\t'
                    'r-fc-loss: {meters[r_fc_loss]:.6f}\n'
                    '  fd\t=>\t'
                    'l-fd-loss: {meters[l_fd_loss]:.6f}\t'
                    'r-fd-loss: {meters[r_fd_loss]:.6f}\n'.format(
                        epoch,
                        idx,
                        len(data_loader),
                        self.args.task,
                        meters=self.meters))

            # the flaw detector uses PolynomialLR [ITER_LRERS]
            self.fd_lrer.step()
            # update iteration-based lrers
            if not self.args.is_epoch_lrer:
                self.l_lrer.step()
                self.r_lrer.step()

        # update epoch-based lrers
        if self.args.is_epoch_lrer:
            self.l_lrer.step()
            self.r_lrer.step()

    def _validate(self, data_loader, epoch):
        self.meters.reset()
        lbs = self.args.labeled_batch_size

        self.l_model.eval()
        self.r_model.eval()
        self.fd_model.eval()

        for idx, (inp, gt) in enumerate(data_loader):
            timer = time.time()

            (l_inp, l_gt), (r_inp, r_gt) = self._batch_prehandle(inp, gt)
            if len(l_gt) == len(r_gt) > 1 and idx == 0:
                self._inp_warn()

            l_dc_gt, r_dc_gt = None, None
            l_fc_mask, r_fc_mask = None, None
            if self.args.ssl_mode in [MODE_GCT, MODE_DC]:
                l_resulter, l_debugger = self.l_model.forward(l_inp)
                l_activated_pred = tool.dict_value(l_resulter,
                                                   'activated_pred')
                r_resulter, r_debugger = self.r_model.forward(r_inp)
                r_activated_pred = tool.dict_value(r_resulter,
                                                   'activated_pred')

                fd_resulter, fd_debugger = self.fd_model.forward(
                    l_inp, l_activated_pred[0])
                l_flawmap = tool.dict_value(fd_resulter, 'flawmap')
                fd_resulter, fd_debugger = self.fd_model.forward(
                    r_inp, r_activated_pred[0])
                r_flawmap = tool.dict_value(fd_resulter, 'flawmap')

                l_handled_flawmap = self.flawmap_handler.forward(l_flawmap)
                r_handled_flawmap = self.flawmap_handler.forward(r_flawmap)
                l_dc_gt, r_dc_gt, l_fc_mask, r_fc_mask = self.dcgt_generator.forward(
                    l_activated_pred[0].detach(), r_activated_pred[0].detach(),
                    l_handled_flawmap, r_handled_flawmap)

            l_loss = self._task_model_iter(epoch, idx, False, 'l', lbs, l_inp,
                                           l_gt, l_dc_gt, l_fc_mask, 1)
            r_loss = self._task_model_iter(epoch, idx, False, 'r', lbs, r_inp,
                                           r_gt, r_dc_gt, r_fc_mask, 1)

            self.meters.update('batch_time', time.time() - timer)
            if idx % self.args.log_freq == 0:
                logger.log_info(
                    'step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                    '  l-{3}\t=>\t'
                    'l-task-loss: {meters[l_task_loss]:.6f}\t'
                    'l-dc-loss: {meters[l_dc_loss]:.6f}\t'
                    'l-fc-loss: {meters[l_fc_loss]:.6f}\n'
                    '  r-{3}\t=>\t'
                    'r-task-loss: {meters[r_task_loss]:.6f}\t'
                    'r-dc-loss: {meters[r_dc_loss]:.6f}\t'
                    'r-fc-loss: {meters[r_fc_loss]:.6f}\n'
                    '  fd\t=>\t'
                    'l-fd-loss: {meters[l_fd_loss]:.6f}\t'
                    'r-fd-loss: {meters[r_fd_loss]:.6f}\n'.format(
                        epoch,
                        idx,
                        len(data_loader),
                        self.args.task,
                        meters=self.meters))

        # metrics
        metrics_info = {'l': '', 'r': ''}
        for key in sorted(list(self.meters.keys())):
            if self.task_func.METRIC_STR in key:
                for id_str in metrics_info.keys():
                    if key.startswith(id_str):
                        metrics_info[id_str] += '{0}: {1:.6f}\t'.format(
                            key, self.meters[key])

        logger.log_info(
            'Validation metrics:\n  l-metrics\t=>\t{0}\n  r-metrics\t=>\t{1}\n'
            .format(metrics_info['l'].replace('_', '-'),
                    metrics_info['r'].replace('_', '-')))

    def _save_checkpoint(self, epoch):
        state = {
            'algorithm': self.NAME,
            'epoch': epoch + 1,
            'l_model': self.l_model.state_dict(),
            'r_model': self.r_model.state_dict(),
            'fd_model': self.fd_model.state_dict(),
            'l_optimizer': self.l_optimizer.state_dict(),
            'r_optimizer': self.r_optimizer.state_dict(),
            'fd_optimizer': self.fd_optimizer.state_dict(),
            'l_lrer': self.l_lrer.state_dict(),
            'r_lrer': self.r_lrer.state_dict(),
            'fd_lrer': self.fd_lrer.state_dict()
        }

        checkpoint = os.path.join(self.args.checkpoint_path,
                                  'checkpoint_{0}.ckpt'.format(epoch))
        torch.save(state, checkpoint)

    def _load_checkpoint(self):
        checkpoint = torch.load(self.args.resume)

        checkpoint_algorithm = tool.dict_value(checkpoint,
                                               'algorithm',
                                               default='unknown')
        if checkpoint_algorithm != self.NAME:
            logger.log_err(
                'Unmatched SSL algorithm format in checkpoint => required: {0} - given: {1}\n'
                .format(self.NAME, checkpoint_algorithm))

        self.l_model.load_state_dict(checkpoint['l_model'])
        self.r_model.load_state_dict(checkpoint['r_model'])
        self.fd_model.load_state_dict(checkpoint['fd_model'])
        self.l_optimizer.load_state_dict(checkpoint['l_optimizer'])
        self.r_optimizer.load_state_dict(checkpoint['r_optimizer'])
        self.fd_optimizer.load_state_dict(checkpoint['fd_optimizer'])
        self.l_lrer.load_state_dict(checkpoint['l_lrer'])
        self.r_lrer.load_state_dict(checkpoint['r_lrer'])
        self.fd_lrer.load_state_dict(checkpoint['fd_lrer'])

        return checkpoint['epoch']

    def _task_model_iter(self, epoch, idx, is_train, mid, lbs, inp, gt, dc_gt,
                         fc_mask, dc_rampup_scale):
        if mid == 'l':
            model, criterion = self.l_model, self.l_criterion
        elif mid == 'r':
            model, criterion = self.r_model, self.r_criterion
        else:
            model, criterion = None, None

        # forward the task model
        resulter, debugger = model.forward(inp)
        if 'pred' not in resulter.keys() or 'activated_pred' not in resulter.keys():
            self._pred_err()

        pred = tool.dict_value(resulter, 'pred')
        activated_pred = tool.dict_value(resulter, 'activated_pred')

        fd_resulter, fd_debugger = self.fd_model.forward(
            inp, activated_pred[0])
        flawmap = tool.dict_value(fd_resulter, 'flawmap')

        # calculate the supervised task constraint on the labeled data
        labeled_pred = func.split_tensor_tuple(pred, 0, lbs)
        labeled_gt = func.split_tensor_tuple(gt, 0, lbs)
        labeled_inp = func.split_tensor_tuple(inp, 0, lbs)
        task_loss = torch.mean(
            criterion.forward(labeled_pred, labeled_gt, labeled_inp))
        self.meters.update('{0}_task_loss'.format(mid), task_loss.data)

        # calculate the flaw correction constraint
        if self.args.ssl_mode in [MODE_GCT, MODE_FC]:
            if flawmap.shape == self.zero_df_gt.shape:
                fc_ssl_loss = self.fd_criterion.forward(flawmap,
                                                        self.zero_df_gt,
                                                        is_ssl=True,
                                                        reduction=False)
            else:
                fc_ssl_loss = self.fd_criterion.forward(
                    flawmap,
                    torch.zeros(flawmap.shape).cuda(),
                    is_ssl=True,
                    reduction=False)

            if self.args.ssl_mode == MODE_GCT:
                fc_ssl_loss = fc_mask * fc_ssl_loss

            fc_ssl_loss = self.args.fc_ssl_scale * torch.mean(fc_ssl_loss)
            self.meters.update('{0}_fc_loss'.format(mid), fc_ssl_loss.data)
        else:
            fc_ssl_loss = 0
            self.meters.update('{0}_fc_loss'.format(mid), fc_ssl_loss)

        # calculate the dynamic consistency constraint
        if self.args.ssl_mode in [MODE_GCT, MODE_DC]:
            if dc_gt is None:
                logger.log_err(
                    'The dynamic consistency constraint is enabled, '
                    'but no pseudo ground truth is given.')

            dc_ssl_loss = self.dc_criterion.forward(activated_pred[0], dc_gt)
            dc_ssl_loss = dc_rampup_scale * self.args.dc_ssl_scale * torch.mean(
                dc_ssl_loss)
            self.meters.update('{0}_dc_loss'.format(mid), dc_ssl_loss.data)
        else:
            dc_ssl_loss = 0
            self.meters.update('{0}_dc_loss'.format(mid), dc_ssl_loss)

        with torch.no_grad():
            flawmap_gt = self.fdgt_generator.forward(
                activated_pred[0],
                self.task_func.sslgct_prepare_task_gt_for_fdgt(gt[0]))

        # for validation
        if not is_train:
            fd_loss = self.args.fd_scale * self.fd_criterion.forward(
                flawmap, flawmap_gt)
            self.meters.update('{0}_fd_loss'.format(mid),
                               torch.mean(fd_loss).data)

            self.task_func.metrics(activated_pred,
                                   gt,
                                   inp,
                                   self.meters,
                                   id_str=mid)

        # visualization
        if self.args.visualize and idx % self.args.visual_freq == 0:
            with torch.no_grad():
                handled_flawmap = self.flawmap_handler(flawmap)[0]

            self._visualize(
                epoch, idx, is_train, mid,
                func.split_tensor_tuple(inp, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(activated_pred, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(gt, 0, 1, reduce_dim=True),
                handled_flawmap, flawmap_gt[0], dc_gt[0])

        loss = task_loss + fc_ssl_loss + dc_ssl_loss
        return loss

    def _visualize(self,
                   epoch,
                   idx,
                   is_train,
                   id_str,
                   inp,
                   pred,
                   gt,
                   flawmap,
                   flawmap_gt,
                   dc_gt=None):
        visualize_path = self.args.visual_train_path if is_train else self.args.visual_val_path
        out_path = os.path.join(visualize_path, '{0}_{1}'.format(epoch, idx))

        self.task_func.visualize(out_path,
                                 id_str=id_str,
                                 inp=inp,
                                 pred=pred,
                                 gt=gt)

        flawmap = flawmap[0].data.cpu().numpy()
        Image.fromarray(
            (flawmap * 255).astype('uint8'),
            mode='L').save(out_path + '_{0}-fmap.png'.format(id_str))
        flawmap_gt = flawmap_gt[0].data.cpu().numpy()
        Image.fromarray(
            (flawmap_gt * 255).astype('uint8'),
            mode='L').save(out_path + '_{0}-fmap-gt.png'.format(id_str))

        if dc_gt is not None:
            self.task_func.visualize(out_path,
                                     id_str=id_str + '_dc',
                                     inp=None,
                                     pred=[dc_gt],
                                     gt=None)

    def _batch_prehandle(self, inp, gt):
        # add extra data augmentation process here if necessary

        inp_var = []
        for i in inp:
            inp_var.append(Variable(i).cuda())
        inp = tuple(inp_var)

        gt_var = []
        for g in gt:
            gt_var.append(Variable(g).cuda())
        gt = tuple(gt_var)

        # augment 'inp' for different task models here if necessary
        l_inp, r_inp = inp, inp
        l_gt, r_gt = gt, gt

        return (l_inp, l_gt), (r_inp, r_gt)

    def _inp_warn(self):
        logger.log_warn(
            'More than one ground truth of the task model is given in SSL_GCT\n'
            'You try to train the task model with more than one (pred & gt) pairs\n'
            'Please make sure that:\n'
            '  (1) The prediction tuple has the same size as the ground truth tuple\n'
            '  (2) The elements with the same index in the two tuples are corresponding\n'
            '  (3) The first element of (pred & gt) will be used to train the flaw detector\n'
            '      and calculate the SSL constraints\n'
            'Please implement a new SSL algorithm if you want a variant of SSL_GCT with\n'
            'multiple flaw detectors (for multiple predictions)\n')

    def _pred_err(self):
        logger.log_err(
            'In SSL_GCT, the \'resulter\' dict returned by the task model should contain the following keys:\n'
            '   (1) \'pred\'\t=>\tunactivated task predictions\n'
            '   (2) \'activated_pred\'\t=>\tactivated task predictions\n'
            'We need both of them since some losses already include the activation function,\n'
            'e.g., CrossEntropyLoss contains SoftMax\n')
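The training loop in this example scales the dynamic consistency loss by `func.sigmoid_rampup(cur_steps, total_steps)`. A minimal sketch of such a ramp-up is shown below, assuming the widely used `exp(-5 * (1 - t)^2)` schedule; the function name `sigmoid_rampup_sketch` and the exact constants are assumptions and may differ from PixelSSL's helper.

# Hedged sketch of a sigmoid ramp-up: returns a weight in [0, 1] that rises
# smoothly from ~0 to 1 over 'total_steps' steps (assumed exp(-5 * (1 - t)^2)).
import numpy as np

def sigmoid_rampup_sketch(current_step, total_steps):
    if total_steps <= 0:
        return 1.0
    t = float(np.clip(current_step / total_steps, 0.0, 1.0))
    return float(np.exp(-5.0 * (1.0 - t) ** 2))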
Example #4
class SSLADV(ssl_base._SSLBase):
    NAME = 'ssl_adv'
    SUPPORTED_TASK_TYPES = [REGRESSION, CLASSIFICATION]

    def __init__(self, args):
        super(SSLADV, self).__init__(args)

        # define the task model and the FC discriminator
        self.model, self.d_model = None, None
        self.optimizer, self.d_optimizer = None, None
        self.lrer, self.d_lrer = None, None
        self.criterion, self.d_criterion = None, None

        # prepare the arguments for multiple GPUs
        self.args.discriminator_lr *= self.args.gpus

        # check SSL arguments
        if self.args.adv_for_labeled:
            if self.args.labeled_adv_scale < 0:
                logger.log_err(
                    'The argument - labeled_adv_scale - is not set (or invalid)\n'
                    'You set the argument - adv_for_labeled - to True\n'
                    'Please set - labeled_adv_scale >= 0 - for calculating the '
                    'adversarial loss on the labeled data\n')
        if self.args.unlabeled_batch_size > 0:
            if self.args.unlabeled_adv_scale < 0:
                logger.log_err(
                    'The argument - unlabeled_adv_scale - is not set (or invalid)\n'
                    'You set the argument - unlabeled_batch_size - to a value larger than 0\n'
                    'Please set - unlabeled_adv_scale >= 0 - for calculating the '
                    'adversarial loss on the unlabeled data\n')

    def _build(self, model_funcs, optimizer_funcs, lrer_funcs, criterion_funcs,
               task_func):
        self.task_func = task_func

        # create models
        self.model = func.create_model(model_funcs[0], 'model', args=self.args)
        self.d_model = func.create_model(
            FCDiscriminator,
            'd_model',
            in_channels=self.task_func.ssladv_fcd_in_channels())
        # call 'patch_replication_callback' to enable the `sync_batchnorm` layer
        patch_replication_callback(self.model)
        patch_replication_callback(self.d_model)
        self.models = {'model': self.model, 'd_model': self.d_model}

        # create optimizers
        self.optimizer = optimizer_funcs[0](self.model.module.param_groups)
        self.d_optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                             self.d_model.parameters()),
                                      lr=self.args.discriminator_lr,
                                      betas=(0.9, 0.99))
        self.optimizers = {
            'optimizer': self.optimizer,
            'd_optimizer': self.d_optimizer
        }

        # create lrers
        self.lrer = lrer_funcs[0](self.optimizer)
        self.d_lrer = PolynomialLR(self.d_optimizer,
                                   self.args.epochs,
                                   self.args.iters_per_epoch,
                                   power=self.args.discriminator_power,
                                   last_epoch=-1)
        self.lrers = {'lrer': self.lrer, 'd_lrer': self.d_lrer}

        # create criterions
        self.criterion = criterion_funcs[0](self.args)
        self.d_criterion = FCDiscriminatorCriterion()
        self.criterions = {
            'criterion': self.criterion,
            'd_criterion': self.d_criterion
        }

        self._algorithm_warn()

    def _train(self, data_loader, epoch):
        self.meters.reset()
        lbs = self.args.labeled_batch_size

        self.model.train()
        self.d_model.train()

        # both 'inp' and 'gt' are tuples
        for idx, (inp, gt) in enumerate(data_loader):
            timer = time.time()

            inp, gt = self._batch_prehandle(inp, gt)
            if len(gt) > 1 and idx == 0:
                self._inp_warn()

            # -----------------------------------------------------------------------------
            # step-1: train the task model
            # -----------------------------------------------------------------------------
            self.optimizer.zero_grad()

            # forward the task model
            resulter, debugger = self.model.forward(inp)
            if 'pred' not in resulter.keys() or 'activated_pred' not in resulter.keys():
                self._pred_err()

            pred = tool.dict_value(resulter, 'pred')
            activated_pred = tool.dict_value(resulter, 'activated_pred')

            # forward the FC discriminator
            # 'confidence_map' is a tensor
            d_resulter, d_debugger = self.d_model.forward(activated_pred[0])
            confidence_map = tool.dict_value(d_resulter, 'confidence')

            # calculate the supervised task constraint on the labeled data
            l_pred = func.split_tensor_tuple(pred, 0, lbs)
            l_gt = func.split_tensor_tuple(gt, 0, lbs)
            l_inp = func.split_tensor_tuple(inp, 0, lbs)

            # 'task_loss' is a tensor of 1-dim & n elements, where n == batch_size
            task_loss = self.criterion.forward(l_pred, l_gt, l_inp)
            task_loss = torch.mean(task_loss)
            self.meters.update('task_loss', task_loss.data)

            # calculate the adversarial constraint
            # calculate the adversarial constraint for the labeled data
            if self.args.adv_for_labeled:
                l_confidence_map = confidence_map[:lbs, ...]

                # preprocess prediction and ground truth for the adversarial constraint
                l_adv_confidence_map, l_adv_confidence_gt = \
                    self.task_func.ssladv_preprocess_fcd_criterion(l_confidence_map, l_gt[0], True)
                l_adv_loss = self.d_criterion(l_adv_confidence_map,
                                              l_adv_confidence_gt)
                labeled_adv_loss = self.args.labeled_adv_scale * torch.mean(
                    l_adv_loss)
                self.meters.update('labeled_adv_loss', labeled_adv_loss.data)
            else:
                labeled_adv_loss = 0
                self.meters.update('labeled_adv_loss', labeled_adv_loss)

            # calculate the adversarial constraint for the unlabeled data
            if self.args.unlabeled_batch_size > 0:
                u_confidence_map = confidence_map[lbs:self.args.batch_size,
                                                  ...]

                # preprocess prediction and ground truth for the adversarial constraint
                u_adv_confidence_map, u_adv_confidence_gt = \
                    self.task_func.ssladv_preprocess_fcd_criterion(u_confidence_map, None, True)
                u_adv_loss = self.d_criterion(u_adv_confidence_map,
                                              u_adv_confidence_gt)
                unlabeled_adv_loss = self.args.unlabeled_adv_scale * torch.mean(
                    u_adv_loss)
                self.meters.update('unlabeled_adv_loss',
                                   unlabeled_adv_loss.data)
            else:
                unlabeled_adv_loss = 0
                self.meters.update('unlabeled_adv_loss', unlabeled_adv_loss)

            adv_loss = labeled_adv_loss + unlabeled_adv_loss

            # backward and update the task model
            loss = task_loss + adv_loss
            loss.backward()
            self.optimizer.step()

            # -----------------------------------------------------------------------------
            # step-2: train the FC discriminator
            # -----------------------------------------------------------------------------
            self.d_optimizer.zero_grad()

            # forward the task prediction (fake)
            if self.args.unlabeled_for_discriminator:
                fake_pred = activated_pred[0].detach()
            else:
                fake_pred = activated_pred[0][:lbs, ...].detach()

            d_resulter, d_debugger = self.d_model.forward(fake_pred)
            fake_confidence_map = tool.dict_value(d_resulter, 'confidence')

            l_fake_confidence_map = fake_confidence_map[:lbs, ...]
            l_fake_confidence_map, l_fake_confidence_gt = \
                self.task_func.ssladv_preprocess_fcd_criterion(l_fake_confidence_map, l_gt[0], False)

            if self.args.unlabeled_for_discriminator and self.args.unlabeled_batch_size != 0:
                u_fake_confidence_map = fake_confidence_map[
                    lbs:self.args.batch_size, ...]
                u_fake_confidence_map, u_fake_confidence_gt = \
                    self.task_func.ssladv_preprocess_fcd_criterion(u_fake_confidence_map, None, False)

                fake_confidence_map = torch.cat(
                    (l_fake_confidence_map, u_fake_confidence_map), dim=0)
                fake_confidence_gt = torch.cat(
                    (l_fake_confidence_gt, u_fake_confidence_gt), dim=0)

            else:
                fake_confidence_map, fake_confidence_gt = l_fake_confidence_map, l_fake_confidence_gt

            fake_d_loss = self.d_criterion.forward(fake_confidence_map,
                                                   fake_confidence_gt)
            fake_d_loss = self.args.discriminator_scale * torch.mean(
                fake_d_loss)
            self.meters.update('fake_d_loss', fake_d_loss.data)

            # forward the ground truth (real)
            # convert the format of the ground truth
            real_gt = self.task_func.ssladv_convert_task_gt_to_fcd_input(
                l_gt[0])
            d_resulter, d_debugger = self.d_model.forward(real_gt)
            real_confidence_map = tool.dict_value(d_resulter, 'confidence')

            real_confidence_map, real_confidence_gt = \
                self.task_func.ssladv_preprocess_fcd_criterion(real_confidence_map, l_gt[0], True)

            real_d_loss = self.d_criterion(real_confidence_map,
                                           real_confidence_gt)
            real_d_loss = self.args.discriminator_scale * torch.mean(
                real_d_loss)
            self.meters.update('real_d_loss', real_d_loss.data)

            # backward and update the FC discriminator
            d_loss = (fake_d_loss + real_d_loss) / 2
            d_loss.backward()
            self.d_optimizer.step()

            # logging
            self.meters.update('batch_time', time.time() - timer)
            if idx % self.args.log_freq == 0:
                logger.log_info(
                    'step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                    '  task-{3}\t=>\t'
                    'task-loss: {meters[task_loss]:.6f}\t'
                    'labeled-adv-loss: {meters[labeled_adv_loss]:.6f}\t'
                    'unlabeled-adv-loss: {meters[unlabeled_adv_loss]:.6f}\n'
                    '  fc-discriminator\t=>\t'
                    'fake-d-loss: {meters[fake_d_loss]:.6f}\t'
                    'real-d-loss: {meters[real_d_loss]:.6f}\n'.format(
                        epoch + 1,
                        idx,
                        len(data_loader),
                        self.args.task,
                        meters=self.meters))

            # visualization
            if self.args.visualize and idx % self.args.visual_freq == 0:
                u_inp_sample, u_pred_sample, u_cmap_sample = None, None, None
                if self.args.unlabeled_batch_size > 0:
                    u_inp_sample = func.split_tensor_tuple(inp,
                                                           lbs,
                                                           lbs + 1,
                                                           reduce_dim=True)
                    u_pred_sample = func.split_tensor_tuple(activated_pred,
                                                            lbs,
                                                            lbs + 1,
                                                            reduce_dim=True)
                    u_cmap_sample = torch.sigmoid(fake_confidence_map[lbs])

                self._visualize(
                    epoch, idx, True,
                    func.split_tensor_tuple(inp, 0, 1, reduce_dim=True),
                    func.split_tensor_tuple(activated_pred,
                                            0,
                                            1,
                                            reduce_dim=True),
                    func.split_tensor_tuple(gt, 0, 1, reduce_dim=True),
                    torch.sigmoid(confidence_map[0]), u_inp_sample,
                    u_pred_sample, u_cmap_sample)

            # the FC discriminator uses PolynomialLR [ITER_LRERS]
            self.d_lrer.step()
            # update iteration-based lrers
            if not self.args.is_epoch_lrer:
                self.lrer.step()

        # update epoch-based lrers
        if self.args.is_epoch_lrer:
            self.lrer.step()

    def _validate(self, data_loader, epoch):
        self.meters.reset()

        self.model.eval()
        self.d_model.eval()

        for idx, (inp, gt) in enumerate(data_loader):
            timer = time.time()

            inp, gt = self._batch_prehandle(inp, gt)
            if len(gt) > 1 and idx == 0:
                self._inp_warn()

            resulter, debugger = self.model.forward(inp)
            if 'pred' not in resulter.keys() or 'activated_pred' not in resulter.keys():
                self._pred_err()

            pred = tool.dict_value(resulter, 'pred')
            activated_pred = tool.dict_value(resulter, 'activated_pred')

            task_loss = self.criterion.forward(pred, gt, inp)
            task_loss = torch.mean(task_loss)
            self.meters.update('task_loss', task_loss.data)

            d_resulter, d_debugger = self.d_model.forward(activated_pred[0])
            unhandled_fake_confidence_map = tool.dict_value(
                d_resulter, 'confidence')
            fake_confidence_map, fake_confidence_gt = \
                self.task_func.ssladv_preprocess_fcd_criterion(unhandled_fake_confidence_map, gt[0], False)

            fake_d_loss = self.d_criterion.forward(fake_confidence_map,
                                                   fake_confidence_gt)
            fake_d_loss = self.args.discriminator_scale * torch.mean(
                fake_d_loss)
            self.meters.update('fake_d_loss', fake_d_loss.data)

            real_gt = self.task_func.ssladv_convert_task_gt_to_fcd_input(gt[0])
            d_resulter, d_debugger = self.d_model.forward(real_gt)
            unhandled_real_confidence_map = tool.dict_value(
                d_resulter, 'confidence')
            real_confidence_map, real_confidence_gt = \
                self.task_func.ssladv_preprocess_fcd_criterion(unhandled_real_confidence_map, gt[0], True)

            real_d_loss = self.d_criterion.forward(real_confidence_map,
                                                   real_confidence_gt)
            real_d_loss = self.args.discriminator_scale * torch.mean(
                real_d_loss)
            self.meters.update('real_d_loss', real_d_loss.data)

            self.task_func.metrics(activated_pred,
                                   gt,
                                   inp,
                                   self.meters,
                                   id_str='task')

            self.meters.update('batch_time', time.time() - timer)
            if idx % self.args.log_freq == 0:
                logger.log_info(
                    'step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                    '  task-{3}\t=>\t'
                    'task-loss: {meters[task_loss]:.6f}\t'
                    '  fc-discriminator\t=>\t'
                    'fake-d-loss: {meters[fake_d_loss]:.6f}\t'
                    'real-d-loss: {meters[real_d_loss]:.6f}\n'.format(
                        epoch + 1,
                        idx,
                        len(data_loader),
                        self.args.task,
                        meters=self.meters))

            if self.args.visualize and idx % self.args.visual_freq == 0:
                self._visualize(
                    epoch, idx, False,
                    func.split_tensor_tuple(inp, 0, 1, reduce_dim=True),
                    func.split_tensor_tuple(activated_pred,
                                            0,
                                            1,
                                            reduce_dim=True),
                    func.split_tensor_tuple(gt, 0, 1, reduce_dim=True),
                    torch.sigmoid(unhandled_fake_confidence_map[0]))

        # metrics
        metrics_info = {'task': ''}
        for key in sorted(list(self.meters.keys())):
            if self.task_func.METRIC_STR in key:
                for id_str in metrics_info.keys():
                    if key.startswith(id_str):
                        metrics_info[id_str] += '{0}: {1:.6f}\t'.format(
                            key, self.meters[key])

        logger.log_info('Validation metrics:\n task-metrics\t=>\t{0}\n'.format(
            metrics_info['task'].replace('_', '-')))

    def _save_checkpoint(self, epoch):
        state = {
            'algorithm': self.NAME,
            'epoch': epoch,
            'model': self.model.state_dict(),
            'd_model': self.d_model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'd_optimizer': self.d_optimizer.state_dict(),
            'lrer': self.lrer.state_dict(),
            'd_lrer': self.d_lrer.state_dict()
        }

        checkpoint = os.path.join(self.args.checkpoint_path,
                                  'checkpoint_{0}.ckpt'.format(epoch))
        torch.save(state, checkpoint)

    def _load_checkpoint(self):
        checkpoint = torch.load(self.args.resume)

        checkpoint_algorithm = tool.dict_value(checkpoint,
                                               'algorithm',
                                               default='unknown')
        if checkpoint_algorithm != self.NAME:
            logger.log_err(
                'Unmatched SSL algorithm format in checkpoint => required: {0} - given: {1}\n'
                .format(self.NAME, checkpoint_algorithm))

        self.model.load_state_dict(checkpoint['model'])
        self.d_model.load_state_dict(checkpoint['d_model'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.d_optimizer.load_state_dict(checkpoint['d_optimizer'])
        self.lrer.load_state_dict(checkpoint['lrer'])
        self.d_lrer.load_state_dict(checkpoint['d_lrer'])

        return checkpoint['epoch']

    # -------------------------------------------------------------------------------------------
    # Tool Functions for SSL_ADV
    # -------------------------------------------------------------------------------------------

    def _visualize(self,
                   epoch,
                   idx,
                   is_train,
                   l_inp,
                   l_pred,
                   l_gt,
                   l_cmap,
                   u_inp=None,
                   u_pred=None,
                   u_cmap=None):
        # 'cmap' is the output of the FC discriminator

        visualize_path = self.args.visual_train_path if is_train else self.args.visual_val_path
        out_path = os.path.join(visualize_path, '{0}_{1}'.format(epoch, idx))

        self.task_func.visualize(out_path,
                                 id_str='labeled',
                                 inp=l_inp,
                                 pred=l_pred,
                                 gt=l_gt)
        l_cmap = l_cmap[0].data.cpu().numpy()
        Image.fromarray((l_cmap * 255).astype('uint8'),
                        mode='L').save(out_path + '_labeled-cmap.png')

        if u_inp is not None and u_pred is not None and u_cmap is not None:
            self.task_func.visualize(out_path,
                                     id_str='unlabeled',
                                     inp=u_inp,
                                     pred=u_pred,
                                     gt=None)
            u_cmap = u_cmap[0].data.cpu().numpy()
            Image.fromarray((u_cmap * 255).astype('uint8'),
                            mode='L').save(out_path + '_unlabeled-cmap.png')

    def _batch_prehandle(self, inp, gt):
        # add extra data augmentation process here if necessary

        inp_var = []
        for i in inp:
            inp_var.append(Variable(i).cuda())
        inp = tuple(inp_var)

        gt_var = []
        for g in gt:
            gt_var.append(Variable(g).cuda())
        gt = tuple(gt_var)

        return inp, gt

    def _algorithm_warn(self):
        logger.log_warn(
            'This SSL_ADV algorithm reproduces the SSL algorithm from the paper:\n'
            '  \'Adversarial Learning for Semi-supervised Semantic Segmentation\'\n'
            'The main differences between this implementation and the original paper are:\n'
            '  (1) This implementation does not support the constraint named \'L_semi\' in the\n'
            '      original paper since it can only be used for pixel-wise classification\n'
            '\nThe semi-supervised constraint in this implementation refers to the constraint\n'
            'named \'L_adv\' in the original paper\n'
            '\nAs in the original paper, the FC discriminator is trained by the Adam optimizer\n'
            'with the PolynomialLR scheduler\n')

    def _inp_warn(self):
        logger.log_warn(
            'More than one ground truth of the task model is given in SSL_ADV\n'
            'You try to train the task model with more than one (pred & gt) pairs\n'
            'Please make sure that:\n'
            '  (1) The prediction tuple has the same size as the ground truth tuple\n'
            '  (2) The elements with the same index in the two tuples are corresponding\n'
            '  (3) The first element of (pred & gt) will be used to train the FC discriminator\n'
            '      and calculate the SSL constraints\n'
            'Please implement a new SSL algorithm if you want a variant of SSL_ADV with\n'
            'multiple FC discriminators (for multiple predictions)\n')

    def _pred_err(self):
        logger.log_err(
            'In SSL_ADV, the \'resulter\' dict returned by the task model should contain the following keys:\n'
            '   (1) \'pred\'\t=>\tunactivated task predictions\n'
            '   (2) \'activated_pred\'\t=>\tactivated task predictions\n'
            'We need both of them since some losses already include the activation function,\n'
            'e.g., CrossEntropyLoss contains SoftMax\n')
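Both SSL algorithms above slice the labeled portion of each batch with `func.split_tensor_tuple(...)`. A minimal sketch of what such a helper might do is given below, assuming it slices every tensor in a tuple along the batch dimension and optionally squeezes that dimension when a single sample is requested (as in the visualization calls); the name `split_tensor_tuple_sketch` is hypothetical and the real PixelSSL helper may behave differently.

# Hedged sketch of a 'split_tensor_tuple'-style helper: slice each tensor in a
# tuple along dim 0 and optionally drop the batch dimension for single samples.
def split_tensor_tuple_sketch(tensors, start, end, reduce_dim=False):
    result = []
    for t in tensors:
        sliced = t[start:end, ...]
        if reduce_dim and sliced.shape[0] == 1:
            sliced = sliced.squeeze(0)
        result.append(sliced)
    return tuple(result)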