# Standard imports for these trainers; project-specific modules (DeepLab, Deeplab,
# FCDiscriminator, SegmentationLosses, LR_Scheduler, patch_replication_callback,
# Wasserstein, FastGradientSignUntargeted, SynchronizedBatchNorm2d,
# DataParallelWithCallback, get_scheduler, cross_entropy2d, affine_sample,
# feat_prototype_distance_module, freeze_bn) are assumed to be importable from
# the surrounding repository.
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class multisource_metatrainer(object):
    def __init__(self,
                 args,
                 nnclass,
                 meta_update_lr,
                 meta_update_step,
                 beta,
                 pretrain_mode='meta'):
        self.device = 1
        self.generator_model = None
        self.generator_optim = None
        self.generator_criterion = None
        self.pretrain_mode = pretrain_mode
        self.batch_size = args.batch_size
        self.nnclass = nnclass
        self.init_generator(args)
        self.init_discriminator(args)
        self.init_optimizer(args)
        self.meta_update_lr = meta_update_lr
        self.meta_update_step = meta_update_step
        self.beta = beta

    def init_generator(self, args):

        self.generator_model = DeepLab(num_classes=self.nnclass,
                                       backbone='resnet',
                                       output_stride=16,
                                       sync_bn=None,
                                       freeze_bn=False).cuda()

        self.generator_model = torch.nn.DataParallel(
            self.generator_model).cuda()
        patch_replication_callback(self.generator_model)
        if args.resume:
            print('#--------- load pretrained model --------------#')
            model_dict = self.generator_model.module.state_dict()
            checkpoint = torch.load(args.resume)
            pretrained_dict = {
                k: v
                for k, v in checkpoint['state_dict'].items()
                if 'last_conv' not in k and k in model_dict.keys()
            }
            #pretrained_dict = {k.replace('module.',''):v for k,v in checkpoint['state_dict'].items()  if 'last_conv' not in k}
            model_dict.update(pretrained_dict)
            self.generator_model.module.load_state_dict(model_dict)
        for param in self.generator_model.parameters():
            param.requires_grad = True

    def init_discriminator(self, args):
        # init D
        self.discriminator_model = FCDiscriminator(num_classes=2).cuda()
        self.interp = nn.Upsample(size=400, mode='bilinear')
        self.disc_criterion = SegmentationLosses(
            weight=None, cuda=args.cuda).build_loss(mode=args.loss_type)
        return

    def init_optimizer(self, args):
        self.generator_criterion = SegmentationLosses(
            weight=None, cuda=args.cuda).build_loss(
                mode='bce')  #torch.nn.BCELoss(reduce ='mean')
        self.generator_params = [{
            'params':
            self.generator_model.module.get_1x_lr_params(),
            'lr':
            args.lr
        }, {
            'params':
            self.generator_model.module.get_10x_lr_params(),
            'lr':
            args.lr * 10
        }]
        self.discriminator_params = [{
            'params':
            self.discriminator_model.parameters(),
            'lr':
            args.lr * 5
        }]
        self.model_optim = torch.optim.Adadelta(self.generator_params +
                                                self.discriminator_params)
        self.scheduler = LR_Scheduler(args.lr_scheduler,
                                      args.lr,
                                      args.epochs,
                                      lr_step=30,
                                      iters_per_epoch=100)

    # srca/src_b each have shape B x channel x H x W; the two source batches are
    # concatenated along the batch dimension before the forward pass.
    def update_weights(self, srca, srca_labels, src_b, srcb_labels, target_img,
                       target_label):
        #self.pretrain_mode = 'meta'
        src_labels = torch.cat([srca_labels.squeeze(),
                                srcb_labels.squeeze()],
                               0).type(torch.LongTensor).cuda()
        src_image = torch.cat([srca.squeeze(), src_b.squeeze()])
        if self.pretrain_mode == 'meta':
            seg_loss = self.meta_mldg(src_image, src_labels, self.batch_size)
        else:
            print('a default training is enabled')
            src_out, source_feature = self.generator_model(src_image)
            seg_loss = self.generator_criterion(src_out, src_labels)
        self.model_optim.zero_grad()
        seg_loss.backward()
        self.model_optim.step()
        target_logit, _ = self.generator_model(target_img.cuda())
        tgt_loss = self.generator_criterion(target_logit, target_label)
        tgt_loss = tgt_loss.detach()
        seg_loss = seg_loss.detach()
        return seg_loss, tgt_loss

    def meta_mldg(self, src_image, src_labels, batch_size):
        batch_size = 4  # per-source mini-batch size used to slice src_image/src_labels
        num_src = 2
        # randomly pick the meta-train source S; the remaining source V is meta-test
        S = np.random.choice(num_src)
        V = abs(S - 1)
        # meta-train loss on source S
        source_out, _ = self.generator_model(src_image[S * batch_size:(S + 1) *
                                                       batch_size].squeeze())
        losses = self.generator_criterion(
            source_out, src_labels[S * batch_size:(S + 1) * batch_size])
        for k in range(1, self.meta_update_step):
            source_out, _ = self.generator_model(
                src_image[S * batch_size:(S + 1) * batch_size].squeeze())
            loss = self.generator_criterion(
                source_out, src_labels[S * batch_size:(S + 1) * batch_size])
            # inner-loop gradient step (first-order: no create_graph)
            grad = torch.autograd.grad(loss, self.generator_model.parameters())
            fast_weights = list(
                map(lambda p: p[1] - self.meta_update_lr * p[0],
                    zip(grad, self.generator_model.parameters())))
            # meta-test: evaluate the fast weights on the held-out source V
            # (assumes the backbone supports a MAML-style forward that accepts
            # explicit weights and a bn_training flag)
            test_out, _ = self.generator_model(
                src_image[V * batch_size:(V + 1) * batch_size],
                fast_weights,
                bn_training=True)
            test_loss = self.generator_criterion(
                test_out, src_labels[V * batch_size:(V + 1) * batch_size])
            losses += self.beta * test_loss
        return losses
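

# Illustrative sketch only: a minimal MLDG-style episode on a generic nn.Module,
# showing the inner update on a meta-train domain and evaluation of the resulting
# fast weights on a meta-test domain. The helper name `mldg_episode`, the use of
# torch.func.functional_call (PyTorch >= 2.0), and cross-entropy as the task loss
# are assumptions for this sketch and are not part of the trainers in this file.
import torch
import torch.nn.functional as F


def mldg_episode(model, x_train, y_train, x_test, y_test,
                 inner_lr=1e-3, beta=1.0):
    # meta-train loss and a single inner gradient step (first-order approximation)
    train_loss = F.cross_entropy(model(x_train), y_train)
    names = [n for n, _ in model.named_parameters()]
    grads = torch.autograd.grad(train_loss, list(model.parameters()))
    fast_weights = {
        n: p - inner_lr * g
        for n, p, g in zip(names, model.parameters(), grads)
    }
    # meta-test loss evaluated with the fast weights via a functional forward
    test_logits = torch.func.functional_call(model, fast_weights, (x_test,))
    test_loss = F.cross_entropy(test_logits, y_test)
    # outer objective: meta-train loss plus beta * meta-test loss, as in meta_mldg
    return train_loss + beta * test_loss
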
class adda_trainer(object):
    def __init__(self, args, nnclass):
        self.target_model = None
        self.target_optim = None
        self.target_criterion = None
        self.batch_size = args.batch_size
        self.nnclass = nnclass
        self.init_target(args)
        self.init_discriminator(args)
        self.scheduler = LR_Scheduler(args.lr_scheduler,
                                      args.lr,
                                      args.epochs,
                                      lr_step=40,
                                      iters_per_epoch=100)
        self.disc_params = [{
            'params': self.disc_model.parameters(),
            'lr': args.lr * 5
        }]
        self.dda_optim = torch.optim.Adam(self.train_params)
        self.discriminator_optim = torch.optim.Adam(self.disc_params)
        #self.dda_optim = torch.optim.SGD(self.train_params, momentum=args.momentum,
        #                            weight_decay=args.weight_decay, nesterov=args.nesterov)
        #self.discriminator_optim = torch.optim.SGD(self.disc_params, momentum=args.momentum,
        #                            weight_decay=args.weight_decay, nesterov=args.nesterov)
        self.adv_aug = FastGradientSignUntargeted(self.target_model,
                                                  0.0157,
                                                  0.00784,
                                                  min_val=0,
                                                  max_val=1,
                                                  max_iters=2,
                                                  _type='linf')

    def init_target(self, args):

        self.target_model = DeepLab(num_classes=self.nnclass,
                                    backbone='resnet',
                                    output_stride=16,
                                    sync_bn=None,
                                    freeze_bn=False)
        self.train_params = [{
            'params': self.target_model.get_1x_lr_params(),
            'lr': args.lr
        }, {
            'params': self.target_model.get_10x_lr_params(),
            'lr': args.lr * 10
        }]
        self.target_model = torch.nn.DataParallel(self.target_model)
        self.target_criterion = SegmentationLosses(
            weight=None, cuda=args.cuda).build_loss(
                mode='bce')  #torch.nn.BCELoss(reduce ='mean')
        patch_replication_callback(self.target_model)
        model_dict = self.target_model.module.state_dict()
        checkpoint = torch.load(args.resume)
        pretrained_dict = {
            k.replace('module.', ''): v
            for k, v in checkpoint['state_dict'].items()
        }
        #pretrained_dict = {k:v for k,v in checkpoint['state_dict'].items() if 'last_conv' not in k }
        model_dict.update(pretrained_dict)
        self.target_model.module.load_state_dict(model_dict)
        self.target_model = self.target_model.cuda()
        return

    def init_discriminator(self, args):
        # init D
        self.disc_model = FCDiscriminator(num_classes=2).cuda()
        self.interp = nn.Upsample(size=400, mode='bilinear')
        self.disc_criterion = SegmentationLosses(
            weight=None, cuda=args.cuda).build_loss(mode=args.loss_type)
        return

    def update_weights(self, input_, src_labels, target, tgt_labels, lamda_g,
                       trainmodel):

        self.dda_optim.zero_grad()
        self.discriminator_optim.zero_grad()
        if trainmodel == 'train_gen':
            for param in self.target_model.parameters():
                param.requires_grad = True
            for param in self.disc_model.parameters():
                param.requires_grad = False
            self.disc_model.eval()
            self.target_model.train()
        else:
            for param in self.target_model.parameters():
                param.requires_grad = False
            for param in self.disc_model.parameters():
                param.requires_grad = True
            self.disc_model.train()
            self.target_model.eval()
        #tot_input = torch.cat([input_, target])
        src_out, source_feature = self.target_model(input_)
        seg_loss = self.target_criterion(src_out, src_labels)
        #print(target.shape)
        targ_out, target_feature = self.target_model(target)

        # discriminator
        discriminator_x = torch.cat([source_feature, target_feature]).squeeze()
        discriminator_adv_logit = torch.cat([
            torch.zeros(source_feature.shape),
            torch.ones(target_feature.shape)
        ])
        discriminator_real_logit = torch.cat([
            torch.ones(source_feature.shape),
            torch.zeros(target_feature.shape)
        ])
        disc_out = self.disc_model(discriminator_x)
        #print(source_feature.shape, input_.shape,discriminator_adv_logit.shape, disc_out.shape)
        adv_loss = self.target_criterion(
            disc_out, discriminator_adv_logit[:, 0, :, :].cuda())
        adv_loss += self.target_criterion(
            disc_out, discriminator_adv_logit[:, 1, :, :].cuda())
        disc_loss = self.disc_criterion(
            disc_out, discriminator_real_logit[:, 0, :, :].cuda())
        disc_loss += self.disc_criterion(
            disc_out, discriminator_real_logit[:, 1, :, :].cuda())
        if trainmodel == 'train_gen':
            loss_seg = seg_loss + lamda_g * adv_loss
            loss_seg.backward()
            self.dda_optim.step()
        else:
            disc_loss.backward()
            self.discriminator_optim.step()
        tgt_loss = self.target_criterion(targ_out, tgt_labels)
        return seg_loss.data.cpu().numpy(), tgt_loss.data.cpu().numpy()
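

# Hedged usage sketch for adda_trainer (not part of the class): ADDA-style training
# alternates generator and discriminator updates via the `trainmodel` flag. The
# loader variables and the lamda_g value below are assumptions for illustration.
#
#   trainer = adda_trainer(args, nnclass=19)
#   for step, (src_img, src_lbl, tgt_img, tgt_lbl) in enumerate(loader):
#       mode = 'train_gen' if step % 2 == 0 else 'train_disc'
#       seg_loss, tgt_loss = trainer.update_weights(
#           src_img.cuda(), src_lbl.cuda(), tgt_img.cuda(), tgt_lbl.cuda(),
#           lamda_g=0.01, trainmodel=mode)
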
class madan_trainer(object):
    def __init__(self, args, nnclass, ndomains):
        self.device = 1
        self.generator_model = None
        self.generator_optim = None
        self.generator_criterion = None
        self.batch_size = args.batch_size
        self.nnclass = nnclass
        self.num_domains = ndomains
        self.init_wasserstein = Wasserstein()
        self.init_generator(args)
        self.init_discriminator(args)
        self.init_optimizer(args)

    def init_generator(self, args):

        self.generator_model = DeepLab(num_classes=self.nnclass,
                                       backbone='resnet',
                                       output_stride=16,
                                       sync_bn=None,
                                       freeze_bn=False).cuda()

        self.generator_model = torch.nn.DataParallel(
            self.generator_model).cuda()
        patch_replication_callback(self.generator_model)
        if args.resume:
            print('#--------- load pretrained model --------------#')
            model_dict = self.generator_model.module.state_dict()
            checkpoint = torch.load(args.resume)
            pretrained_dict = {
                k: v
                for k, v in checkpoint['state_dict'].items()
                if 'last_conv' not in k and k in model_dict.keys()
            }
            #pretrained_dict = {k.replace('module.',''):v for k,v in checkpoint['state_dict'].items()  if 'last_conv' not in k}
            model_dict.update(pretrained_dict)
            self.generator_model.module.load_state_dict(model_dict)
        for param in self.generator_model.parameters():
            param.requires_grad = True

    def init_discriminator(self, args):
        # init D
        self.discriminator_model = FCDiscriminator(num_classes=2).cuda()
        self.interp = nn.Upsample(size=400, mode='bilinear')
        self.disc_criterion = SegmentationLosses(
            weight=None, cuda=args.cuda).build_loss(mode=args.loss_type)
        return

    def init_optimizer(self, args):
        self.generator_criterion = SegmentationLosses(
            weight=None, cuda=args.cuda).build_loss(
                mode='bce')  #torch.nn.BCELoss(reduce ='mean')
        self.generator_params = [{
            'params':
            self.generator_model.module.get_1x_lr_params(),
            'lr':
            args.lr
        }, {
            'params':
            self.generator_model.module.get_10x_lr_params(),
            'lr':
            args.lr * 10
        }]
        self.discriminator_params = [{
            'params':
            self.discriminator_model.parameters(),
            'lr':
            args.lr * 5
        }]
        self.model_optim = torch.optim.Adadelta(self.generator_params +
                                                self.discriminator_params)
        self.scheduler = LR_Scheduler(args.lr_scheduler,
                                      args.lr,
                                      args.epochs,
                                      lr_step=30,
                                      iters_per_epoch=100)

    # for madan the src_image has shape B x source_index x channel x H x W
    def update_weights(self, src_image, src_labels, targ_image, targ_labels,
                       options):
        running_loss = 0.0
        src_labels = torch.cat(
            [src_labels[:, 0].squeeze(), src_labels[:, 1].squeeze()],
            0).type(torch.LongTensor).cuda()
        self.model_optim.zero_grad()
        # src image shape batch_size x domain x 3 channels x height x width
        src_out, source_feature = self.generator_model(
            torch.cat([src_image[:, 0].squeeze(), src_image[:, 1].squeeze()]))
        targ_out, target_feature = self.generator_model(targ_image)
        #  Discriminator
        discriminator_x = torch.cat([source_feature, target_feature]).squeeze()
        disc_clf = self.discriminator_model(discriminator_x)
        # Losses
        losses = torch.stack([
            self.generator_criterion(
                src_out[j * self.batch_size:(j + 1) * self.batch_size],
                src_labels[j * self.batch_size:(j + 1) * self.batch_size])
            for j in range(self.num_domains)
        ])
        slabels = torch.ones(self.batch_size,
                             disc_clf.shape[2],
                             disc_clf.shape[3],
                             requires_grad=False).type(
                                 torch.LongTensor).cuda()
        tlabels = torch.zeros(self.batch_size * 2,
                              disc_clf.shape[2],
                              disc_clf.shape[3],
                              requires_grad=False).type(
                                  torch.LongTensor).cuda()
        domain_losses = torch.stack([
            self.generator_criterion(
                disc_clf[j * self.batch_size:(j + 1) *
                         self.batch_size].squeeze(), slabels)
            for j in range(self.num_domains)
        ])
        domain_losses = torch.cat([
            domain_losses,
            self.generator_criterion(
                disc_clf[2 * self.batch_size:4 * self.batch_size].squeeze(),
                tlabels).view(-1)
        ])
        # Different final loss function depending on different training modes.
        if options['mode'] == "maxmin":
            loss = torch.max(losses) + options['mu'] * torch.min(domain_losses)
        elif options['mode'] == "dynamic":
            loss = torch.log(
                torch.sum(
                    torch.exp(options['gamma'] *
                              (losses + options['mu'] * domain_losses)))
            ) / options['gamma']
        else:
            raise ValueError(
                "No support for the training mode on madanNet: {}.".format(
                    options['mode']))
        loss.backward()
        self.model_optim.step()
        running_loss += loss.detach().cpu().numpy()
        # compute target loss
        target_loss = self.generator_criterion(
            targ_out, targ_labels).detach().cpu().numpy()
        return running_loss, target_loss

    def update_wasserstein(self, src_image, src_labels, targ_image,
                           targ_labels, options):
        running_loss = 0.0
        src_labels = torch.cat(
            [src_labels[:, 0].squeeze(), src_labels[:, 1].squeeze()],
            0).type(torch.LongTensor).cuda()
        self.model_optim.zero_grad()
        # src image shape batch_size x domain x 3 channels x height x width
        src_out, source_feature = self.generator_model(
            torch.cat([src_image[:, 0].squeeze(), src_image[:, 1].squeeze()]))
        targ_out, target_feature = self.generator_model(targ_image)
        #  Discriminator
        discriminator_x = torch.cat([source_feature, target_feature]).squeeze()
        disc_clf = self.discriminator_model(discriminator_x)
        # Losses
        losses = torch.stack([
            self.generator_criterion(
                src_out[j * self.batch_size:j * self.batch_size +
                        self.batch_size],
                src_labels[j * self.batch_size:j * self.batch_size +
                           self.batch_size]) for j in range(self.num_domains)
        ])
        wass_loss = [
            self.init_wasserstein.update_wasserstein_dual_source(
                disc_clf[j * self.batch_size:j * self.batch_size +
                         self.batch_size].squeeze(),
                disc_clf[self.num_domains *
                         self.batch_size:self.num_domains * self.batch_size +
                         self.batch_size].squeeze())
            for j in range(self.num_domains)
        ]

        domain_losses = torch.stack(wass_loss)
        # compute gradient penalty
        penalty_cup, penalty_disc = self.init_wasserstein.gradient_regularization_dual_source(
            self.discriminator_model, source_feature.detach(),
            target_feature.detach(), options['batch_size'],
            options['num_domains'])
        # Different final loss function depending on different training modes.
        if options['mode'] == "maxmin":
            loss = torch.max(
                losses) + options['mu'] * torch.min(domain_losses) + options[
                    'gamma'] * penalty_cup + options['gamma'] * penalty_disc
        elif options['mode'] == "dynamic":
            # TODO Wasserstein not implemented yet for this
            loss = torch.log(
                torch.sum(
                    torch.exp(options['gamma'] *
                              (losses + options['mu'] * domain_losses)))
            ) / options['gamma']
        else:
            raise ValueError(
                "No support for the training mode on madanNet: {}.".format(
                    options['mode']))
        loss.backward()
        self.model_optim.step()
        for p in self.discriminator_model.parameters():
            p.data.clamp_(-0.01, 0.01)
        running_loss += loss.detach().cpu().numpy()
        # compute target loss
        target_loss = self.generator_criterion(
            targ_out, targ_labels).detach().cpu().numpy()
        return running_loss, target_loss
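

# Illustrative sketch only: the "dynamic" mode above combines per-domain task and
# domain losses as log(sum(exp(gamma * (losses + mu * domain_losses)))) / gamma,
# a smooth approximation of the maximum over domains. `dynamic_aggregate` is a
# hypothetical helper (it assumes the two loss tensors have the same length) and
# uses torch.logsumexp, which computes the same value with better numerical
# stability than the explicit log/sum/exp composition.
import torch


def dynamic_aggregate(task_losses, domain_losses, mu, gamma):
    combined = gamma * (task_losses + mu * domain_losses)
    return torch.logsumexp(combined, dim=0) / gamma
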
class CustomModel():
    def __init__(self, opt, logger, isTrain=True):
        self.opt = opt
        self.class_numbers = opt.n_class
        self.logger = logger
        self.best_iou = -100
        self.nets = []
        self.nets_DP = []
        self.default_gpu = 0
        self.objective_vectors = torch.zeros([self.class_numbers, 256])
        self.objective_vectors_num = torch.zeros([self.class_numbers])

        if opt.bn == 'sync_bn':
            BatchNorm = SynchronizedBatchNorm2d
        elif opt.bn == 'bn':
            BatchNorm = nn.BatchNorm2d
        else:
            raise NotImplementedError('batch norm choice {} is not implemented'.format(opt.bn))

        if self.opt.no_resume:
            restore_from = None
        else:
            restore_from = opt.resume_path
            self.best_iou = 0
        if self.opt.student_init == 'imagenet':
            self.BaseNet = Deeplab(BatchNorm, num_classes=self.class_numbers, freeze_bn=False, restore_from=restore_from)
        elif self.opt.student_init == 'simclr':
            self.BaseNet = Deeplab(BatchNorm, num_classes=self.class_numbers, freeze_bn=False, restore_from=restore_from, 
                initialization=os.path.join(opt.root, 'Code/ProDA', 'pretrained/simclr/r101_1x_sk0.pth'), bn_clr=opt.bn_clr)
        else:
            self.BaseNet = Deeplab(BatchNorm, num_classes=self.class_numbers, freeze_bn=False, restore_from=restore_from)
            
        logger.info('the backbone is {}'.format(opt.model_name))

        self.nets.extend([self.BaseNet])

        self.optimizers = []
        self.schedulers = []        
        optimizer_cls = torch.optim.SGD
        optimizer_params = {'lr':opt.lr, 'weight_decay':2e-4, 'momentum':0.9}

        if self.opt.stage == 'warm_up':
            self.net_D = FCDiscriminator(inplanes=self.class_numbers)
            self.net_D_DP = self.init_device(self.net_D, gpu_id=self.default_gpu, whether_DP=True)
            self.nets.extend([self.net_D])
            self.nets_DP.append(self.net_D_DP)

            self.optimizer_D = torch.optim.Adam(self.net_D.parameters(), lr=1e-4, betas=(0.9, 0.99))
            self.optimizers.extend([self.optimizer_D])
            self.DSchedule = get_scheduler(self.optimizer_D, opt)
            self.schedulers.extend([self.DSchedule])

        if self.opt.finetune or self.opt.stage == 'warm_up':
            self.BaseOpti = optimizer_cls([{'params':self.BaseNet.get_1x_lr_params(), 'lr':optimizer_params['lr']},
                                           {'params':self.BaseNet.get_10x_lr_params(), 'lr':optimizer_params['lr']*10}], **optimizer_params)
        else:
            self.BaseOpti = optimizer_cls(self.BaseNet.parameters(), **optimizer_params)
        self.optimizers.extend([self.BaseOpti])

        self.BaseSchedule = get_scheduler(self.BaseOpti, opt)
        self.schedulers.extend([self.BaseSchedule])

        if self.opt.ema:
            self.BaseNet_ema = Deeplab(BatchNorm, num_classes=self.class_numbers, freeze_bn=False, restore_from=restore_from, bn_clr=opt.ema_bn)
            self.BaseNet_ema.load_state_dict(self.BaseNet.state_dict().copy())

        if self.opt.distillation > 0:
            self.teacher = Deeplab(BatchNorm, num_classes=self.class_numbers, freeze_bn=False, restore_from=opt.resume_path, bn_clr=opt.ema_bn)
            self.teacher.eval()
            self.teacher_DP = self.init_device(self.teacher, gpu_id=self.default_gpu, whether_DP=True)


        self.adv_source_label = 0
        self.adv_target_label = 1
        if self.opt.gan == 'Vanilla':
            self.bceloss = nn.BCEWithLogitsLoss(reduction='mean')
        elif self.opt.gan == 'LS':
            self.bceloss = torch.nn.MSELoss()
        self.feat_prototype_distance_DP = self.init_device(feat_prototype_distance_module(), gpu_id=self.default_gpu, whether_DP=True)

        self.BaseNet_DP = self.init_device(self.BaseNet, gpu_id=self.default_gpu, whether_DP=True)
        self.nets_DP.append(self.BaseNet_DP)
        if self.opt.ema:
            self.BaseNet_ema_DP = self.init_device(self.BaseNet_ema, gpu_id=self.default_gpu, whether_DP=True)

    def calculate_mean_vector(self, feat_cls, outputs, labels=None, thresh=None):
        outputs_softmax = F.softmax(outputs, dim=1)
        if thresh is None:
            thresh = -1
        conf = outputs_softmax.max(dim=1, keepdim=True)[0]
        mask = conf.ge(thresh)
        outputs_argmax = outputs_softmax.argmax(dim=1, keepdim=True)
        outputs_argmax = self.process_label(outputs_argmax.float())
        if labels is None:
            outputs_pred = outputs_argmax
        else:
            labels_expanded = self.process_label(labels)
            outputs_pred = labels_expanded * outputs_argmax
        scale_factor = F.adaptive_avg_pool2d(outputs_pred * mask, 1)
        vectors = []
        ids = []
        for n in range(feat_cls.size()[0]):
            for t in range(self.class_numbers):
                if scale_factor[n][t].item()==0:
                    continue
                if (outputs_pred[n][t] > 0).sum() < 10:
                    continue
                s = feat_cls[n] * outputs_pred[n][t] * mask[n]
                # scale = torch.sum(outputs_pred[n][t]) / labels.shape[2] / labels.shape[3] * 2
                # s = normalisation_pooling()(s, scale)
                s = F.adaptive_avg_pool2d(s, 1) / scale_factor[n][t]
                vectors.append(s)
                ids.append(t)
        return vectors, ids
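
    # In short, calculate_mean_vector performs confidence-masked average pooling per
    # class: for each image n and class t it averages feat_cls[n] over the pixels
    # whose (pseudo-)label equals t and whose prediction confidence passes `thresh`,
    # i.e. roughly
    #   s = sum_{pixels p labelled t} feat_cls[n, :, p] / count(pixels labelled t),
    # implemented above as adaptive_avg_pool2d of the masked features divided by the
    # pooled mask (`scale_factor`). Classes with fewer than 10 assigned pixels in an
    # image are skipped.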

    def step_adv(self, source_x, source_label, target_x, source_imageS, source_params):
        for param in self.net_D.parameters():
            param.requires_grad = False
        self.BaseOpti.zero_grad()
        
        if self.opt.S_pseudo_src > 0:
            source_output = self.BaseNet_DP(source_imageS)
            source_label_d4 = F.interpolate(source_label.unsqueeze(1).float(), size=source_output['out'].size()[2:])
            source_labelS = self.label_strong_T(source_label_d4.clone().float(), source_params, padding=250, scale=4).to(torch.int64)
            loss_ = cross_entropy2d(input=source_output['out'], target=source_labelS.squeeze(1))
            loss_GTA = loss_ * self.opt.S_pseudo_src
            source_outputUp = F.interpolate(source_output['out'], size=source_x.size()[2:], mode='bilinear', align_corners=True)
        else:
            source_output = self.BaseNet_DP(source_x, ssl=True)
            source_outputUp = F.interpolate(source_output['out'], size=source_x.size()[2:], mode='bilinear', align_corners=True)

            loss_GTA = cross_entropy2d(input=source_outputUp, target=source_label, size_average=True, reduction='mean')

        target_output = self.BaseNet_DP(target_x, ssl=True)
        target_outputUp = F.interpolate(target_output['out'], size=target_x.size()[2:], mode='bilinear', align_corners=True)
        target_D_out = self.net_D_DP(F.softmax(target_outputUp, dim=1))
        loss_adv_G = self.bceloss(target_D_out, torch.FloatTensor(target_D_out.data.size()).fill_(self.adv_source_label).to(target_D_out.device)) * self.opt.adv
        loss_G = loss_adv_G + loss_GTA
        loss_G.backward()
        self.BaseOpti.step()

        for param in self.net_D.parameters():
            param.requires_grad = True
        self.optimizer_D.zero_grad()
        source_D_out = self.net_D_DP(F.softmax(source_outputUp.detach(), dim=1))
        target_D_out = self.net_D_DP(F.softmax(target_outputUp.detach(), dim=1))
        loss_D = self.bceloss(source_D_out, torch.FloatTensor(source_D_out.data.size()).fill_(self.adv_source_label).to(source_D_out.device)) + \
                    self.bceloss(target_D_out, torch.FloatTensor(target_D_out.data.size()).fill_(self.adv_target_label).to(target_D_out.device))
        loss_D.backward()
        self.optimizer_D.step()

        return loss_GTA.item(), loss_adv_G.item(), loss_D.item()

    def step(self, source_x, source_label, target_x, target_imageS=None, target_params=None, target_lp=None, 
            target_lpsoft=None, target_image_full=None, target_weak_params=None):

        source_out = self.BaseNet_DP(source_x, ssl=True)
        source_outputUp = F.interpolate(source_out['out'], size=source_x.size()[2:], mode='bilinear', align_corners=True)

        loss_GTA = cross_entropy2d(input=source_outputUp, target=source_label)
        loss_GTA.backward()        

        if self.opt.proto_rectify:
            threshold_arg = F.interpolate(target_lpsoft, scale_factor=0.25, mode='bilinear', align_corners=True)
        else:
            threshold_arg = F.interpolate(target_lp.unsqueeze(1).float(), scale_factor=0.25).long()

        if self.opt.ema:
            ema_input = target_image_full
            with torch.no_grad():
                ema_out = self.BaseNet_ema_DP(ema_input)
            ema_out['feat'] = F.interpolate(ema_out['feat'], size=(int(ema_input.shape[2]/4), int(ema_input.shape[3]/4)), mode='bilinear', align_corners=True)
            ema_out['out'] = F.interpolate(ema_out['out'], size=(int(ema_input.shape[2]/4), int(ema_input.shape[3]/4)), mode='bilinear', align_corners=True)

        target_out = self.BaseNet_DP(target_imageS) if self.opt.S_pseudo > 0 else self.BaseNet_DP(target_x)
        target_out['out'] = F.interpolate(target_out['out'], size=threshold_arg.shape[2:], mode='bilinear', align_corners=True)
        target_out['feat'] = F.interpolate(target_out['feat'], size=threshold_arg.shape[2:], mode='bilinear', align_corners=True)

        loss = torch.Tensor([0]).to(self.default_gpu)
        batch, _, w, h = threshold_arg.shape
        if self.opt.proto_rectify:
            weights = self.get_prototype_weight(ema_out['feat'], target_weak_params=target_weak_params)
            rectified = weights * threshold_arg
            threshold_arg = rectified.max(1, keepdim=True)[1]
            rectified = rectified / rectified.sum(1, keepdim=True)
            argmax = rectified.max(1, keepdim=True)[0]
            threshold_arg[argmax < self.opt.train_thred] = 250
        if self.opt.S_pseudo > 0:
            threshold_argS = self.label_strong_T(threshold_arg.clone().float(), target_params, padding=250, scale=4).to(torch.int64)
            threshold_arg = threshold_argS

        loss_CTS = cross_entropy2d(input=target_out['out'], target=threshold_arg.reshape([batch, w, h]))

        if self.opt.rce:
            rce = self.rce(target_out['out'], threshold_arg.reshape([batch, w, h]).clone())
            loss_CTS = self.opt.rce_alpha * loss_CTS + self.opt.rce_beta * rce

        if self.opt.regular_w > 0:
            regular_loss = self.regular_loss(target_out['out'])
            loss_CTS = loss_CTS + regular_loss * self.opt.regular_w

        loss_consist = torch.Tensor([0]).to(self.default_gpu)
        if self.opt.proto_consistW > 0:
            ema2weak_feat = self.full2weak(ema_out['feat'], target_weak_params)         #N*256*H*W
            ema2weak_feat_proto_distance = self.feat_prototype_distance(ema2weak_feat)  #N*19*H*W
            ema2strong_feat_proto_distance = self.label_strong_T(ema2weak_feat_proto_distance, target_params, padding=250, scale=4)
            mask = (ema2strong_feat_proto_distance != 250).float()
            teacher = F.softmax(-ema2strong_feat_proto_distance * self.opt.proto_temperature, dim=1)

            targetS_out = target_out if self.opt.S_pseudo > 0 else self.BaseNet_DP(target_imageS)
            targetS_out['out'] = F.interpolate(targetS_out['out'], size=threshold_arg.shape[2:], mode='bilinear', align_corners=True)
            targetS_out['feat'] = F.interpolate(targetS_out['feat'], size=threshold_arg.shape[2:], mode='bilinear', align_corners=True)

            prototype_tmp = self.objective_vectors.expand(4, -1, -1)  #gpu memory limitation
            strong_feat_proto_distance = self.feat_prototype_distance_DP(targetS_out['feat'], prototype_tmp, self.class_numbers)
            student = F.log_softmax(-strong_feat_proto_distance * self.opt.proto_temperature, dim=1)

            loss_consist = F.kl_div(student, teacher, reduction='none')
            loss_consist = (loss_consist * mask).sum() / mask.sum()
            loss = loss + self.opt.proto_consistW * loss_consist

        loss = loss + loss_CTS
        loss.backward()
        self.BaseOpti.step()
        self.BaseOpti.zero_grad()

        if self.opt.moving_prototype: #update prototype
            ema_vectors, ema_ids = self.calculate_mean_vector(ema_out['feat'].detach(), ema_out['out'].detach())
            for t in range(len(ema_ids)):
                self.update_objective_SingleVector(ema_ids[t], ema_vectors[t].detach(), start_mean=False)
        
        if self.opt.ema: #update ema model
            for param_q, param_k in zip(self.BaseNet.parameters(), self.BaseNet_ema.parameters()):
                param_k.data = param_k.data.clone() * 0.999 + param_q.data.clone() * (1. - 0.999)
            for buffer_q, buffer_k in zip(self.BaseNet.buffers(), self.BaseNet_ema.buffers()):
                buffer_k.data = buffer_q.data.clone()

        return loss.item(), loss_CTS.item(), loss_consist.item()
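
    # The EMA branch in step() keeps a slowly-moving teacher copy of BaseNet:
    # param_k <- 0.999 * param_k + 0.001 * param_q, while buffers (e.g. BatchNorm
    # running statistics) are copied directly. A hypothetical stand-alone
    # equivalent, for illustration only:
    #
    #   def ema_update(student, teacher, decay=0.999):
    #       with torch.no_grad():
    #           for p_s, p_t in zip(student.parameters(), teacher.parameters()):
    #               p_t.mul_(decay).add_(p_s, alpha=1.0 - decay)
    #           for b_s, b_t in zip(student.buffers(), teacher.buffers()):
    #               b_t.copy_(b_s)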

    def regular_loss(self, activation):
        logp = F.log_softmax(activation, dim=1)
        if self.opt.regular_type == 'MRENT':
            p = F.softmax(activation, dim=1)
            loss = (p * logp).sum() / (p.shape[0]*p.shape[2]*p.shape[3])
        elif self.opt.regular_type == 'MRKLD':
            loss = - logp.sum() / (logp.shape[0]*logp.shape[1]*logp.shape[2]*logp.shape[3])
        return loss

    def rce(self, pred, labels):
        pred = F.softmax(pred, dim=1)
        pred = torch.clamp(pred, min=1e-7, max=1.0)
        mask = (labels != 250).float()
        labels[labels==250] = self.class_numbers
        label_one_hot = torch.nn.functional.one_hot(labels, self.class_numbers + 1).float().to(self.default_gpu)
        label_one_hot = torch.clamp(label_one_hot.permute(0,3,1,2)[:,:-1,:,:], min=1e-4, max=1.0)
        rce = -(torch.sum(pred * torch.log(label_one_hot), dim=1) * mask).sum() / (mask.sum() + 1e-6)
        return rce

    def step_distillation(self, source_x, source_label, target_x, target_imageS=None, target_params=None, target_lp=None):

        source_out = self.BaseNet_DP(source_x, ssl=True)
        source_outputUp = F.interpolate(source_out['out'], size=source_x.size()[2:], mode='bilinear', align_corners=True)
        loss_GTA = cross_entropy2d(input=source_outputUp, target=source_label)
        loss_GTA.backward()

        threshold_arg = F.interpolate(target_lp.unsqueeze(1).float(), scale_factor=0.25).long()
        if self.opt.S_pseudo > 0:
            threshold_arg = self.label_strong_T(threshold_arg.clone().float(), target_params, padding=250, scale=4).to(torch.int64)
            target_out = self.BaseNet_DP(target_imageS)
        else:
            target_out = self.BaseNet_DP(target_x)
        target_out['out'] = F.interpolate(target_out['out'], size=threshold_arg.shape[2:], mode='bilinear', align_corners=True)
        batch, _, w, h = threshold_arg.shape
        loss = cross_entropy2d(input=target_out['out'], target=threshold_arg.reshape([batch, w, h]), size_average=True, reduction='mean')
        if self.opt.rce:
            rce = self.rce(target_out['out'], threshold_arg.reshape([batch, w, h]).clone())
            loss = self.opt.rce_alpha * loss + self.opt.rce_beta * rce

        if self.opt.distillation > 0:
            # F.kl_div expects log-probabilities for its first argument
            student = F.log_softmax(target_out['out'], dim=1)
            with torch.no_grad():
                teacher_out = self.teacher_DP(target_imageS)
                teacher_out['out'] = F.interpolate(teacher_out['out'], size=threshold_arg.shape[2:], mode='bilinear', align_corners=True)
                teacher = F.softmax(teacher_out['out'], dim=1)

            loss_kd = F.kl_div(student, teacher, reduction='none')
            mask = (teacher != 250).float()
            loss_kd = (loss_kd * mask).sum() / mask.sum()
            loss = loss + self.opt.distillation * loss_kd

        loss.backward()
        self.BaseOpti.step()
        self.BaseOpti.zero_grad()
        return loss_GTA.item(), loss.item()

    def full2weak(self, feat, target_weak_params):
        tmp = []
        for i in range(feat.shape[0]):
            h, w = target_weak_params['RandomSized'][0][i], target_weak_params['RandomSized'][1][i]
            feat_ = F.interpolate(feat[i:i+1], size=[int(h/4), int(w/4)], mode='bilinear', align_corners=True)
            y1, y2, x1, x2 = target_weak_params['RandomCrop'][0][i], target_weak_params['RandomCrop'][1][i], target_weak_params['RandomCrop'][2][i], target_weak_params['RandomCrop'][3][i]
            y1, th, x1, tw = int(y1/4), int((y2-y1)/4), int(x1/4), int((x2-x1)/4)
            feat_ = feat_[:, :, y1:y1+th, x1:x1+tw]
            if target_weak_params['RandomHorizontallyFlip'][i]:
                inv_idx = torch.arange(feat_.size(3)-1,-1,-1).long().to(feat_.device)
                feat_ = feat_.index_select(3,inv_idx)
            tmp.append(feat_)
        feat = torch.cat(tmp, 0)
        return feat

    def feat_prototype_distance(self, feat):
        N, C, H, W = feat.shape
        feat_proto_distance = -torch.ones((N, self.class_numbers, H, W)).to(feat.device)
        for i in range(self.class_numbers):
            #feat_proto_distance[:, i, :, :] = torch.norm(torch.Tensor(self.objective_vectors[i]).reshape(-1,1,1).expand(-1, H, W).to(feat.device) - feat, 2, dim=1,)
            feat_proto_distance[:, i, :, :] = torch.norm(self.objective_vectors[i].reshape(-1,1,1).expand(-1, H, W) - feat, 2, dim=1,)
        return feat_proto_distance

    def get_prototype_weight(self, feat, label=None, target_weak_params=None):
        feat = self.full2weak(feat, target_weak_params)
        feat_proto_distance = self.feat_prototype_distance(feat)
        feat_nearest_proto_distance, feat_nearest_proto = feat_proto_distance.min(dim=1, keepdim=True)

        feat_proto_distance = feat_proto_distance - feat_nearest_proto_distance
        weight = F.softmax(-feat_proto_distance * self.opt.proto_temperature, dim=1)
        return weight

    def label_strong_T(self, label, params, padding, scale=1):
        label = label + 1
        for i in range(label.shape[0]):
            for (Tform, param) in params.items():
                if Tform == 'Hflip' and param[i].item() == 1:
                    label[i] = label[i].clone().flip(-1)
                elif (Tform == 'ShearX' or Tform == 'ShearY' or Tform == 'TranslateX' or Tform == 'TranslateY' or Tform == 'Rotate') and param[i].item() != 1e4:
                    v = int(param[i].item() // scale) if Tform == 'TranslateX' or Tform == 'TranslateY' else param[i].item()
                    label[i:i+1] = affine_sample(label[i:i+1].clone(), v, Tform)
                elif Tform == 'CutoutAbs' and isinstance(param, list):
                    x0 = int(param[0][i].item() // scale)
                    y0 = int(param[1][i].item() // scale)
                    x1 = int(param[2][i].item() // scale)
                    y1 = int(param[3][i].item() // scale)
                    label[i, :, y0:y1, x0:x1] = 0
        label[label == 0] = padding + 1  # for strong augmentation, constant padding
        label = label - 1
        return label

    def process_label(self, label):
        batch, channel, w, h = label.size()
        pred1 = torch.zeros(batch, self.class_numbers + 1, w, h).to(self.default_gpu)
        id = torch.where(label < self.class_numbers, label, torch.Tensor([self.class_numbers]).to(self.default_gpu))
        pred1 = pred1.scatter_(1, id.long(), 1)
        return pred1

    def freeze_bn_apply(self):
        for net in self.nets:
            net.apply(freeze_bn)
        for net in self.nets_DP:
            net.apply(freeze_bn)

    def scheduler_step(self):
        for scheduler in self.schedulers:
            scheduler.step()
    
    def optimizer_zerograd(self):
        for optimizer in self.optimizers:
            optimizer.zero_grad()
    

    def init_device(self, net, gpu_id=None, whether_DP=False):
        gpu_id = gpu_id or self.default_gpu
        device = torch.device("cuda:{}".format(gpu_id) if torch.cuda.is_available() else 'cpu')
        net = net.to(device)
        # if torch.cuda.is_available():
        if whether_DP:
            #net = DataParallelWithCallback(net, device_ids=[0])
            net = DataParallelWithCallback(net, device_ids=range(torch.cuda.device_count()))
        return net
    
    def eval(self, net=None, logger=None):
        """Make specific models eval mode during test time"""
        # if issubclass(net, nn.Module) or issubclass(net, BaseModel):
        if net is None:
            for net in self.nets:
                net.eval()
            for net in self.nets_DP:
                net.eval()
            if logger is not None:
                logger.info("Successfully set the model eval mode")
        else:
            net.eval()
            if logger is not None:
                logger.info("Successfully set {} eval mode".format(net.__class__.__name__))
        return

    def train(self, net=None, logger=None):
        if net is None:
            for net in self.nets:
                net.train()
            for net in self.nets_DP:
                net.train()
        else:
            net.train()
        return

    def update_objective_SingleVector(self, id, vector, name='moving_average', start_mean=True):
        if vector.sum().item() == 0:
            return
        if start_mean and self.objective_vectors_num[id].item() < 100:
            name = 'mean'
        if name == 'moving_average':
            self.objective_vectors[id] = self.objective_vectors[id] * (1 - self.opt.proto_momentum) + self.opt.proto_momentum * vector.squeeze()
            self.objective_vectors_num[id] += 1
            self.objective_vectors_num[id] = min(self.objective_vectors_num[id], 3000)
        elif name == 'mean':
            self.objective_vectors[id] = self.objective_vectors[id] * self.objective_vectors_num[id] + vector.squeeze()
            self.objective_vectors_num[id] += 1
            self.objective_vectors[id] = self.objective_vectors[id] / self.objective_vectors_num[id]
            self.objective_vectors_num[id] = min(self.objective_vectors_num[id], 3000)
            pass
        else:
            raise NotImplementedError('no such updating way of objective vectors {}'.format(name))
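
    # Prototype bookkeeping in update_objective_SingleVector: for roughly the first
    # 100 vectors of a class (start_mean=True) the prototype is a running mean;
    # afterwards it switches to an exponential moving average with momentum
    # opt.proto_momentum, i.e. proto[id] <- (1 - m) * proto[id] + m * vector,
    # with the per-class counter capped at 3000.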